Compare commits

..

29 commits

Author SHA1 Message Date
Victor Shyba
24d749ed2b lbry-libtorrent is now on pypi 2020-03-10 21:46:20 -03:00
Victor Shyba
bdd843196d use ubuntu 18 on gitlab, temporarly 2020-03-10 20:25:16 -03:00
Victor Shyba
dd456347c6 calculate total bytes outside of dict 2020-03-10 20:25:16 -03:00
Victor Shyba
e31323fa9d tell progress, stop trying to read first piece 2020-03-10 20:25:16 -03:00
Victor Shyba
ed8572a961 wait started event 2020-03-10 20:25:16 -03:00
Victor Shyba
9b9e55d07d find and show largest file 2020-03-10 20:25:16 -03:00
Victor Shyba
eb9e913788 fix and test delete with torrents 2020-03-10 20:25:16 -03:00
Victor Shyba
539655292a fix moving to a new btih 2020-03-10 20:25:16 -03:00
Victor Shyba
5a7312cbb8 fix not knowing a torrent exists 2020-03-10 20:25:16 -03:00
Victor Shyba
a5a4b224e9 gitlab: use ubuntu on datanetwork tests 2020-03-10 20:25:16 -03:00
Victor Shyba
7f88fcd1b5 add boost on gitlab, fix failing test, add libtorrent to linux build 2020-03-10 20:25:16 -03:00
Victor Shyba
7e2381816d correct wheel with boost 1.65 2020-03-10 20:25:16 -03:00
Victor Shyba
3f2346eb01 linting and minor refactor 2020-03-10 20:25:16 -03:00
Victor Shyba
260a713b08 fixes from rebase, install libtorrent from s3 2020-03-10 20:25:16 -03:00
Victor Shyba
f6b4fee814 pylint 2020-03-10 20:25:16 -03:00
Victor Shyba
e9623fae13 working file list after torrent get 2020-03-10 20:25:16 -03:00
Victor Shyba
a1090679a5 torrent test and misc fixes 2020-03-10 20:25:16 -03:00
Victor Shyba
ad66358d18 fix torrent and stream manager reference leftovers 2020-03-10 20:25:16 -03:00
Victor Shyba
34c6e09e6f adds more torrent parts 2020-03-10 20:25:16 -03:00
Victor Shyba
507db5f79a torrent manager and torrent source 2020-03-10 20:25:16 -03:00
Victor Shyba
e6663603a6 fix unit tests 2020-03-10 20:25:16 -03:00
Victor Shyba
899852d8ba pylint 2020-03-10 20:25:16 -03:00
Victor Shyba
41175a814b fix save from resolve 2020-03-10 20:25:16 -03:00
Victor Shyba
78faf4e691 stream manager component becomes file manager component 2020-03-10 20:25:15 -03:00
Victor Shyba
7e1795b7b1 wip 2020-03-10 20:25:15 -03:00
Victor Shyba
837dc1b078 add torrent component 2020-03-10 20:25:15 -03:00
Jack Robison
8d0f31e466 add lbry.torrent 2020-03-10 20:25:15 -03:00
Jack Robison
404f0dcfe0 file manager refactor 2020-03-10 20:25:15 -03:00
Jack Robison
c9b868a238 ManagedDownloadSource and SourceManager refactor 2020-03-10 20:25:15 -03:00
227 changed files with 17047 additions and 17198 deletions

View file

@ -1,206 +1,79 @@
name: ci name: ci
on: ["push", "pull_request", "workflow_dispatch"] on: pull_request
jobs: jobs:
lint: lint:
name: lint name: lint
runs-on: ubuntu-20.04 runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v1
- uses: actions/setup-python@v4 - uses: actions/setup-python@v1
with: with:
python-version: '3.9' python-version: '3.7'
- name: extract pip cache - run: make install tools
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}
restore-keys: ${{ runner.os }}-pip-
- run: pip install --user --upgrade pip wheel
- run: pip install -e .[lint]
- run: make lint - run: make lint
tests-unit: tests-unit:
name: "tests / unit" name: "tests / unit"
strategy: runs-on: ubuntu-latest
matrix:
os:
- ubuntu-20.04
- macos-latest
- windows-latest
runs-on: ${{ matrix.os }}
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v1
- uses: actions/setup-python@v4 - uses: actions/setup-python@v1
with: with:
python-version: '3.9' python-version: '3.7'
- name: set pip cache dir - run: make install tools
shell: bash - working-directory: lbry
run: echo "PIP_CACHE_DIR=$(pip cache dir)" >> $GITHUB_ENV
- name: extract pip cache
uses: actions/cache@v3
with:
path: ${{ env.PIP_CACHE_DIR }}
key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}
restore-keys: ${{ runner.os }}-pip-
- id: os-name
uses: ASzc/change-string-case-action@v5
with:
string: ${{ runner.os }}
- run: python -m pip install --user --upgrade pip wheel
- if: startsWith(runner.os, 'linux')
run: pip install -e .[test]
- if: startsWith(runner.os, 'linux')
env: env:
HOME: /tmp HOME: /tmp
run: make test-unit-coverage run: coverage run -p --source=lbry -m unittest discover -vv tests.unit
- if: startsWith(runner.os, 'linux') != true
run: pip install -e .[test]
- if: startsWith(runner.os, 'linux') != true
env:
HOME: /tmp
run: coverage run --source=lbry -m unittest tests/unit/test_conf.py
- name: submit coverage report
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
COVERALLS_FLAG_NAME: tests-unit-${{ steps.os-name.outputs.lowercase }}
COVERALLS_PARALLEL: true
run: |
pip install coveralls
coveralls --service=github
tests-integration: tests-integration:
name: "tests / integration" name: "tests / integration"
runs-on: ubuntu-20.04 runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
test: test:
- datanetwork - datanetwork
- blockchain - blockchain
- claims
- takeovers
- transactions
- other - other
steps: steps:
- name: Configure sysctl limits - uses: actions/checkout@v1
run: | - uses: actions/setup-python@v1
sudo swapoff -a
sudo sysctl -w vm.swappiness=1
sudo sysctl -w fs.file-max=262144
sudo sysctl -w vm.max_map_count=262144
- name: Runs Elasticsearch
uses: elastic/elastic-github-actions/elasticsearch@master
with: with:
stack-version: 7.12.1 python-version: '3.7'
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.9'
- if: matrix.test == 'other' - if: matrix.test == 'other'
run: | run: sudo apt install -y --no-install-recommends ffmpeg
sudo apt-get update - run: pip install tox-travis
sudo apt-get install -y --no-install-recommends ffmpeg
- name: extract pip cache
uses: actions/cache@v3
with:
path: ./.tox
key: tox-integration-${{ matrix.test }}-${{ hashFiles('setup.py') }}
restore-keys: txo-integration-${{ matrix.test }}-
- run: pip install tox coverage coveralls
- if: matrix.test == 'claims'
run: rm -rf .tox
- run: tox -e ${{ matrix.test }} - run: tox -e ${{ matrix.test }}
- name: submit coverage report
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
COVERALLS_FLAG_NAME: tests-integration-${{ matrix.test }}
COVERALLS_PARALLEL: true
run: |
coverage combine tests
coveralls --service=github
coverage:
needs: ["tests-unit", "tests-integration"]
runs-on: ubuntu-20.04
steps:
- name: finalize coverage report submission
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
pip install coveralls
coveralls --service=github --finish
build: build:
needs: ["lint", "tests-unit", "tests-integration"] needs: ["lint", "tests-unit", "tests-integration"]
name: "build / binary" name: "build"
strategy: strategy:
matrix: matrix:
os: os:
- ubuntu-20.04 - ubuntu-latest
- macos-latest - macos-latest
- windows-latest - windows-latest
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v1
- uses: actions/setup-python@v4 - uses: actions/setup-python@v1
with: with:
python-version: '3.9' python-version: '3.7'
- id: os-name - name: Setup
uses: ASzc/change-string-case-action@v5 run: |
with: pip install pyinstaller
string: ${{ runner.os }} pip install -e .
- name: set pip cache dir - if: startsWith(matrix.os, 'windows') == false
shell: bash
run: echo "PIP_CACHE_DIR=$(pip cache dir)" >> $GITHUB_ENV
- name: extract pip cache
uses: actions/cache@v3
with:
path: ${{ env.PIP_CACHE_DIR }}
key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}
restore-keys: ${{ runner.os }}-pip-
- run: pip install pyinstaller==4.6
- run: pip install -e .
- if: startsWith(github.ref, 'refs/tags/v')
run: python docker/set_build.py
- if: startsWith(runner.os, 'linux') || startsWith(runner.os, 'mac')
name: Build & Run (Unix) name: Build & Run (Unix)
run: | run: |
pyinstaller --onefile --name lbrynet lbry/extras/cli.py pyinstaller --onefile --name lbrynet lbry/extras/cli.py
chmod +x dist/lbrynet
dist/lbrynet --version dist/lbrynet --version
- if: startsWith(runner.os, 'windows') - if: startsWith(matrix.os, 'windows')
name: Build & Run (Windows) name: Build & Run (Windows)
run: | run: |
pip install pywin32==301 pip install pywin32
pyinstaller --additional-hooks-dir=scripts/. --icon=icons/lbry256.ico --onefile --name lbrynet lbry/extras/cli.py pyinstaller --additional-hooks-dir=scripts/. --icon=icons/lbry256.ico --onefile --name lbrynet lbry/extras/cli.py
dist/lbrynet.exe --version dist/lbrynet.exe --version
- uses: actions/upload-artifact@v3
with:
name: lbrynet-${{ steps.os-name.outputs.lowercase }}
path: dist/
release:
name: "release"
if: startsWith(github.ref, 'refs/tags/v')
needs: ["build"]
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v1
- uses: actions/download-artifact@v2
- name: upload binaries
env:
GITHUB_TOKEN: ${{ secrets.RELEASE_API_TOKEN }}
run: |
pip install githubrelease
chmod +x lbrynet-macos/lbrynet
chmod +x lbrynet-linux/lbrynet
zip --junk-paths lbrynet-mac.zip lbrynet-macos/lbrynet
zip --junk-paths lbrynet-linux.zip lbrynet-linux/lbrynet
zip --junk-paths lbrynet-windows.zip lbrynet-windows/lbrynet.exe
ls -lh
githubrelease release lbryio/lbry-sdk info ${GITHUB_REF#refs/tags/}
githubrelease asset lbryio/lbry-sdk upload ${GITHUB_REF#refs/tags/} \
lbrynet-mac.zip lbrynet-linux.zip lbrynet-windows.zip
githubrelease release lbryio/lbry-sdk publish ${GITHUB_REF#refs/tags/}

View file

@ -1,22 +0,0 @@
name: slack
on:
release:
types: [published]
jobs:
release:
name: "slack notification"
runs-on: ubuntu-20.04
steps:
- uses: LoveToKnow/slackify-markdown-action@v1.0.0
id: markdown
with:
text: "There is a new SDK release: ${{github.event.release.html_url}}\n${{ github.event.release.body }}"
- uses: slackapi/slack-github-action@v1.14.0
env:
CHANGELOG: '<!channel> ${{ steps.markdown.outputs.text }}'
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_RELEASE_BOT_WEBHOOK }}
with:
payload: '{"type": "mrkdwn", "text": ${{ toJSON(env.CHANGELOG) }} }'

7
.gitignore vendored
View file

@ -6,17 +6,12 @@
/.coverage* /.coverage*
/lbry-venv /lbry-venv
/venv /venv
/lbry/blockchain
lbry.egg-info lbry.egg-info
__pycache__ __pycache__
_trial_temp/ _trial_temp/
trending*.log
/tests/integration/claims/files /tests/integration/blockchain/files
/tests/.coverage.* /tests/.coverage.*
/lbry/wallet/bin /lbry/wallet/bin
/.vscode
/.gitignore

211
.gitlab-ci.yml Normal file
View file

@ -0,0 +1,211 @@
default:
image: python:3.7
#cache:
# directories:
# - $HOME/venv
# - $HOME/.cache/pip
# - $HOME/Library/Caches/pip
# - $HOME/Library/Caches/Homebrew
# - $TRAVIS_BUILD_DIR/.tox
stages:
- test
- build
- assets
- release
.tagged:
rules:
- if: '$CI_COMMIT_TAG =~ /^v[0-9\.]+$/'
when: on_success
test:lint:
stage: test
script:
- make install tools
- make lint
test:unit:
stage: test
script:
- make install tools
- HOME=/tmp coverage run -p --source=lbry -m unittest discover -vv tests.unit
test:datanetwork-integration:
stage: test
script:
- pip install tox-travis
- tox -e datanetwork --recreate
test:blockchain-integration:
stage: test
script:
- pip install tox-travis
- tox -e blockchain
test:other-integration:
stage: test
script:
- apt-get update
- apt-get install -y --no-install-recommends ffmpeg
- pip install tox-travis
- tox -e other
test:json-api:
stage: test
script:
- make install tools
- HOME=/tmp coverage run -p --source=lbry scripts/generate_json_api.py
.build:
stage: build
artifacts:
expire_in: 1 day
paths:
- lbrynet-${OS}.zip
script:
- pip install --upgrade 'setuptools<45.0.0'
- pip install pyinstaller
- pip install -e .
- python3.7 docker/set_build.py # must come after lbry is installed because it imports lbry
- pyinstaller --onefile --name lbrynet lbry/extras/cli.py
- chmod +x dist/lbrynet
- zip --junk-paths ${CI_PROJECT_DIR}/lbrynet-${OS}.zip dist/lbrynet # gitlab expects artifacts to be in $CI_PROJECT_DIR
- openssl dgst -sha256 ${CI_PROJECT_DIR}/lbrynet-${OS}.zip | egrep -o [0-9a-f]+$ # get sha256 of asset. works on mac and ubuntu
- dist/lbrynet --version
build:linux:
extends: .build
image: ubuntu:16.04
variables:
OS: linux
before_script:
- apt-get update
- apt-get install -y --no-install-recommends software-properties-common zip curl build-essential
- add-apt-repository -y ppa:deadsnakes/ppa
- apt-get update
- apt-get install -y --no-install-recommends python3.7-dev
- python3.7 <(curl -q https://bootstrap.pypa.io/get-pip.py) # make sure we get pip with python3.7
- pip install lbry-libtorrent
build:mac:
extends: .build
tags: [macos] # makes gitlab use the mac runner
variables:
OS: mac
GIT_DEPTH: 5
VENV: /tmp/gitlab-lbry-sdk-venv
before_script:
# - brew upgrade python || true
- python3 --version | grep -q '^Python 3\.7\.' # dont upgrade python on every run. just make sure we're on the right Python
# - pip3 install --user --upgrade pip virtualenv
- pip3 --version | grep -q '\(python 3\.7\)'
- virtualenv --python=python3.7 "${VENV}"
- source "${VENV}/bin/activate"
after_script:
- rm -rf "${VENV}"
build:windows:
extends: .build
tags: [windows] # makes gitlab use the windows runner
variables:
OS: windows
GIT_DEPTH: 5
before_script:
- ./docker/install_choco.ps1
- choco install -y --x86 python3 7zip checksum
- python --version # | findstr /B "Python 3\.7\." # dont upgrade python on every run. just make sure we're on the right Python
- pip --version # | findstr /E '\(python 3\.7\)'
- pip install virtualenv pywin32
- virtualenv venv
- venv/Scripts/activate.ps1
- pip install pip==19.3.1; $true # $true ignores errors. need this to get the correct coincurve wheel. see commit notes for details.
after_script:
- rmdir -Recurse venv
script:
- pip install --upgrade 'setuptools<45.0.0'
- pip install pyinstaller==3.5
- pip install -e .
- python docker/set_build.py # must come after lbry is installed because it imports lbry
- pyinstaller --additional-hooks-dir=scripts/. --icon=icons/lbry256.ico -F -n lbrynet lbry/extras/cli.py
- 7z a -tzip $env:CI_PROJECT_DIR/lbrynet-${OS}.zip ./dist/lbrynet.exe
- checksum --type=sha256 --file=$env:CI_PROJECT_DIR/lbrynet-${OS}.zip
- dist/lbrynet.exe --version
# s3 = upload asset to s3 (build.lbry.io)
.s3:
stage: assets
variables:
GIT_STRATEGY: none
script:
- "[ -f lbrynet-${OS}.zip ]" # check that asset exists before trying to upload
- pip install awscli
- S3_PATH="daemon/gitlab-build-${CI_PIPELINE_ID}_commit-${CI_COMMIT_SHA:0:7}$( if [ ! -z ${CI_COMMIT_TAG} ]; then echo _tag-${CI_COMMIT_TAG}; else echo _branch-${CI_COMMIT_REF_NAME}; fi )"
- AWS_ACCESS_KEY_ID=${ARTIFACTS_KEY} AWS_SECRET_ACCESS_KEY=${ARTIFACTS_SECRET} AWS_REGION=${ARTIFACTS_REGION}
aws s3 cp lbrynet-${OS}.zip s3://${ARTIFACTS_BUCKET}/${S3_PATH}/lbrynet-${OS}.zip
s3:linux:
extends: .s3
variables: {OS: linux}
needs: ["build:linux"]
s3:mac:
extends: .s3
variables: {OS: mac}
needs: ["build:mac"]
s3:windows:
extends: .s3
variables: {OS: windows}
needs: ["build:windows"]
# github = upload assets to github when there's a tagged release
.github:
extends: .tagged
stage: assets
variables:
GIT_STRATEGY: none
script:
- "[ -f lbrynet-${OS}.zip ]" # check that asset exists before trying to upload. githubrelease won't error if its missing
- pip install githubrelease
- githubrelease --no-progress --github-token ${GITHUB_CI_USER_ACCESS_TOKEN} asset lbryio/lbry-sdk upload ${CI_COMMIT_TAG} lbrynet-${OS}.zip
github:linux:
extends: .github
variables: {OS: linux}
needs: ["build:linux"]
github:mac:
extends: .github
variables: {OS: mac}
needs: ["build:mac"]
github:windows:
extends: .github
variables: {OS: windows}
needs: ["build:windows"]
publish:
extends: .tagged
stage: release
variables:
GIT_STRATEGY: none
script:
- pip install githubrelease
- githubrelease --no-progress --github-token ${GITHUB_CI_USER_ACCESS_TOKEN} release lbryio/lbry-sdk publish ${CI_COMMIT_TAG}
- >
curl -X POST -H 'Content-type: application/json' --data '{"text":"<!channel> There is a new SDK release: https://github.com/lbryio/lbry-sdk/releases/tag/'"${CI_COMMIT_TAG}"'\n'"$(curl -s "https://api.github.com/repos/lbryio/lbry-sdk/releases/tags/${CI_COMMIT_TAG}" | egrep '\w*\"body\":' | cut -d':' -f 2- | tail -c +3 | head -c -2)"'", "channel":"tech"}' "$(echo ${SLACK_WEBHOOK_URL_BASE64} | base64 -d)"

View file

@ -9,29 +9,20 @@ Here's a video walkthrough of this setup, which is itself hosted by the LBRY net
## Prerequisites ## Prerequisites
Running `lbrynet` from source requires Python 3.7. Get the installer for your OS [here](https://www.python.org/downloads/release/python-370/). Running `lbrynet` from source requires Python 3.7 or higher. Get the installer for your OS [here](https://www.python.org/downloads/release/python-370/).
After installing Python 3.7, you'll need to install some additional libraries depending on your operating system. After installing python 3, you'll need to install some additional libraries depending on your operating system.
Because of [issue #2769](https://github.com/lbryio/lbry-sdk/issues/2769)
at the moment the `lbrynet` daemon will only work correctly with Python 3.7.
If Python 3.8+ is used, the daemon will start but the RPC server
may not accept messages, returning the following:
```
Could not connect to daemon. Are you sure it's running?
```
### macOS ### macOS
macOS users will need to install [xcode command line tools](https://developer.xamarin.com/guides/testcloud/calabash/configuring/osx/install-xcode-command-line-tools/) and [homebrew](http://brew.sh/). macOS users will need to install [xcode command line tools](https://developer.xamarin.com/guides/testcloud/calabash/configuring/osx/install-xcode-command-line-tools/) and [homebrew](http://brew.sh/).
These environment variables also need to be set: These environment variables also need to be set:
``` 1. PYTHONUNBUFFERED=1
PYTHONUNBUFFERED=1 2. EVENT_NOKQUEUE=1
EVENT_NOKQUEUE=1
```
Remaining dependencies can then be installed by running: Remaining dependencies can then be installed by running:
``` ```
brew install python protobuf brew install python protobuf
``` ```
@ -40,17 +31,14 @@ Assistance installing Python3: https://docs.python-guide.org/starting/install3/o
### Linux ### Linux
On Ubuntu (we recommend 18.04 or 20.04), install the following: On Ubuntu (16.04 minimum, we recommend 18.04), install the following:
``` ```
sudo add-apt-repository ppa:deadsnakes/ppa sudo add-apt-repository ppa:deadsnakes/ppa
sudo apt-get update sudo apt-get update
sudo apt-get install build-essential python3.7 python3.7-dev git python3.7-venv libssl-dev python-protobuf sudo apt-get install build-essential python3.7 python3.7-dev git python3.7-venv libssl-dev python-protobuf
``` ```
The [deadsnakes PPA](https://launchpad.net/~deadsnakes/+archive/ubuntu/ppa) provides Python 3.7
for those Ubuntu distributions that no longer have it in their
official repositories.
On Raspbian, you will also need to install `python-pyparsing`. On Raspbian, you will also need to install `python-pyparsing`.
If you're running another Linux distro, install the equivalent of the above packages for your system. If you're running another Linux distro, install the equivalent of the above packages for your system.
@ -59,119 +47,62 @@ If you're running another Linux distro, install the equivalent of the above pack
### Linux/Mac ### Linux/Mac
Clone the repository: To install on Linux/Mac:
```bash
git clone https://github.com/lbryio/lbry-sdk.git
cd lbry-sdk
``` ```
Clone the repository:
$ git clone https://github.com/lbryio/lbry-sdk.git
$ cd lbry-sdk
Create a Python virtual environment for lbry-sdk: Create a Python virtual environment for lbry-sdk:
```bash $ python3.7 -m venv lbry-venv
python3.7 -m venv lbry-venv
```
Activate virtual environment: Activating lbry-sdk virtual environment:
```bash $ source lbry-venv/bin/activate
source lbry-venv/bin/activate
```
Make sure you're on Python 3.7+ as default in the virtual environment: Make sure you're on Python 3.7+ (as the default Python in virtual environment):
```bash $ python --version
python --version
```
Install packages: Install packages:
```bash $ make install
make install
```
If you are on Linux and using PyCharm, generates initial configs: If you are on Linux and using PyCharm, generates initial configs:
```bash $ make idea
make idea
``` ```
To verify your installation, `which lbrynet` should return a path inside To verify your installation, `which lbrynet` should return a path inside of the `lbry-venv` folder created by the `python3.7 -m venv lbry-venv` command.
of the `lbry-venv` folder.
```bash
(lbry-venv) $ which lbrynet
/opt/lbry-sdk/lbry-venv/bin/lbrynet
```
To exit the virtual environment simply use the command `deactivate`.
### Windows ### Windows
Clone the repository: To install on Windows:
```bash
git clone https://github.com/lbryio/lbry-sdk.git
cd lbry-sdk
``` ```
Clone the repository:
> git clone https://github.com/lbryio/lbry-sdk.git
> cd lbry-sdk
Create a Python virtual environment for lbry-sdk: Create a Python virtual environment for lbry-sdk:
```bash > python -m venv lbry-venv
python -m venv lbry-venv
```
Activate virtual environment: Activating lbry-sdk virtual environment:
```bash > lbry-venv\Scripts\activate
lbry-venv\Scripts\activate
```
Install packages: Install packages:
```bash > pip install -e .
pip install -e .
``` ```
## Run the tests ## Run the tests
### Elasticsearch
For running integration tests, Elasticsearch is required to be available at localhost:9200/ To run the unit tests from the repo directory:
The easiest way to start it is using docker with:
```bash
make elastic-docker
```
Alternative installation methods are available [at Elasticsearch website](https://www.elastic.co/guide/en/elasticsearch/reference/current/install-elasticsearch.html).
To run the unit and integration tests from the repo directory:
``` ```
python -m unittest discover tests.unit python -m unittest discover tests.unit
python -m unittest discover tests.integration
``` ```
## Usage ## Usage
To start the API server: To start the API server:
``` `lbrynet start`
lbrynet start
```
Whenever the code inside [lbry-sdk/lbry](./lbry)
is modified we should run `make install` to recompile the `lbrynet`
executable with the newest code.
## Development
When developing, remember to enter the environment,
and if you wish start the server interactively.
```bash
$ source lbry-venv/bin/activate
(lbry-venv) $ python lbry/extras/cli.py start
```
Parameters can be passed in the same way.
```bash
(lbry-venv) $ python lbry/extras/cli.py wallet balance
```
If a Python debugger (`pdb` or `ipdb`) is installed we can also start it
in this way, set up break points, and step through the code.
```bash
(lbry-venv) $ pip install ipdb
(lbry-venv) $ ipdb lbry/extras/cli.py
```
Happy hacking! Happy hacking!

View file

@ -1,6 +1,6 @@
The MIT License (MIT) The MIT License (MIT)
Copyright (c) 2015-2022 LBRY Inc Copyright (c) 2015-2020 LBRY Inc
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish,

View file

@ -1,26 +1,24 @@
.PHONY: install tools lint test test-unit test-unit-coverage test-integration idea .PHONY: install tools lint test idea
install: install:
pip install https://s3.amazonaws.com/files.lbry.io/python_libtorrent-1.2.4-py3-none-any.whl
CFLAGS="-DSQLITE_MAX_VARIABLE_NUMBER=2500000" pip install -U https://github.com/rogerbinns/apsw/releases/download/3.30.1-r1/apsw-3.30.1-r1.zip \
--global-option=fetch \
--global-option=--version --global-option=3.30.1 --global-option=--all \
--global-option=build --global-option=--enable --global-option=fts5
pip install -e . pip install -e .
tools:
pip install mypy==0.701
pip install coverage astroid pylint
lint: lint:
pylint --rcfile=setup.cfg lbry pylint --rcfile=setup.cfg lbry
#mypy --ignore-missing-imports lbry #mypy --ignore-missing-imports lbry
test: test-unit test-integration test:
test-unit:
python -m unittest discover tests.unit
test-unit-coverage:
coverage run --source=lbry -m unittest discover -vv tests.unit
test-integration:
tox tox
idea: idea:
mkdir -p .idea mkdir -p .idea
cp -r scripts/idea/* .idea cp -r scripts/idea/* .idea
elastic-docker:
docker run -d -v lbryhub:/usr/share/elasticsearch/data -p 9200:9200 -p 9300:9300 -e"ES_JAVA_OPTS=-Xms512m -Xmx512m" -e "discovery.type=single-node" docker.elastic.co/elasticsearch/elasticsearch:7.12.1

View file

@ -1,10 +1,10 @@
# <img src="https://raw.githubusercontent.com/lbryio/lbry-sdk/master/lbry.png" alt="LBRY" width="48" height="36" /> LBRY SDK [![build](https://github.com/lbryio/lbry-sdk/actions/workflows/main.yml/badge.svg)](https://github.com/lbryio/lbry-sdk/actions/workflows/main.yml) [![coverage](https://coveralls.io/repos/github/lbryio/lbry-sdk/badge.svg)](https://coveralls.io/github/lbryio/lbry-sdk) # <img src="https://raw.githubusercontent.com/lbryio/lbry-sdk/master/lbry.png" alt="LBRY" width="48" height="36" /> LBRY SDK [![Gitlab CI Badge](https://ci.lbry.tech/lbry/lbry-sdk/badges/master/pipeline.svg)](https://ci.lbry.tech/lbry/lbry-sdk)
LBRY is a decentralized peer-to-peer protocol for publishing and accessing digital content. It utilizes the [LBRY blockchain](https://github.com/lbryio/lbrycrd) as a global namespace and database of digital content. Blockchain entries contain searchable content metadata, identities, rights and access rules. LBRY also provides a data network that consists of peers (seeders) uploading and downloading data from other peers, possibly in exchange for payments, as well as a distributed hash table used by peers to discover other peers. LBRY is a decentralized peer-to-peer protocol for publishing and accessing digital content. It utilizes the [LBRY blockchain](https://github.com/lbryio/lbrycrd) as a global namespace and database of digital content. Blockchain entries contain searchable content metadata, identities, rights and access rules. LBRY also provides a data network that consists of peers (seeders) uploading and downloading data from other peers, possibly in exchange for payments, as well as a distributed hash table used by peers to discover other peers.
LBRY SDK for Python is currently the most fully featured implementation of the LBRY Network protocols and includes many useful components and tools for building decentralized applications. Primary features and components include: LBRY SDK for Python is currently the most fully featured implementation of the LBRY Network protocols and includes many useful components and tools for building decentralized applications. Primary features and components include:
* Built on Python 3.7 and `asyncio`. * Built on Python 3.7+ and `asyncio`.
* Kademlia DHT (Distributed Hash Table) implementation for finding peers to download from and announcing to peers what we have to host ([lbry.dht](https://github.com/lbryio/lbry-sdk/tree/master/lbry/dht)). * Kademlia DHT (Distributed Hash Table) implementation for finding peers to download from and announcing to peers what we have to host ([lbry.dht](https://github.com/lbryio/lbry-sdk/tree/master/lbry/dht)).
* Blob exchange protocol for transferring encrypted blobs of content and negotiating payments ([lbry.blob_exchange](https://github.com/lbryio/lbry-sdk/tree/master/lbry/blob_exchange)). * Blob exchange protocol for transferring encrypted blobs of content and negotiating payments ([lbry.blob_exchange](https://github.com/lbryio/lbry-sdk/tree/master/lbry/blob_exchange)).
* Protobuf schema for encoding and decoding metadata stored on the blockchain ([lbry.schema](https://github.com/lbryio/lbry-sdk/tree/master/lbry/schema)). * Protobuf schema for encoding and decoding metadata stored on the blockchain ([lbry.schema](https://github.com/lbryio/lbry-sdk/tree/master/lbry/schema)).
@ -41,7 +41,7 @@ This project is MIT licensed. For the full license, see [LICENSE](LICENSE).
## Security ## Security
We take security seriously. Please contact security@lbry.com regarding any security issues. [Our PGP key is here](https://lbry.com/faq/pgp-key) if you need it. We take security seriously. Please contact security@lbry.com regarding any security issues. [Our GPG key is here](https://lbry.com/faq/gpg-key) if you need it.
## Contact ## Contact

View file

@ -1,9 +0,0 @@
# Security Policy
## Supported Versions
While we are not at v1.0 yet, only the latest release will be supported.
## Reporting a Vulnerability
See https://lbry.com/faq/security

View file

@ -1,43 +0,0 @@
FROM debian:10-slim
ARG user=lbry
ARG projects_dir=/home/$user
ARG db_dir=/database
ARG DOCKER_TAG
ARG DOCKER_COMMIT=docker
ENV DOCKER_TAG=$DOCKER_TAG DOCKER_COMMIT=$DOCKER_COMMIT
RUN apt-get update && \
apt-get -y --no-install-recommends install \
wget \
automake libtool \
tar unzip \
build-essential \
pkg-config \
libleveldb-dev \
python3.7 \
python3-dev \
python3-pip \
python3-wheel \
python3-setuptools && \
update-alternatives --install /usr/bin/pip pip /usr/bin/pip3 1 && \
rm -rf /var/lib/apt/lists/*
RUN groupadd -g 999 $user && useradd -m -u 999 -g $user $user
COPY . $projects_dir
RUN chown -R $user:$user $projects_dir
RUN mkdir -p $db_dir
RUN chown -R $user:$user $db_dir
USER $user
WORKDIR $projects_dir
RUN python3 -m pip install -U setuptools pip
RUN make install
RUN python3 docker/set_build.py
RUN rm ~/.cache -rf
VOLUME $db_dir
ENTRYPOINT ["python3", "scripts/dht_node.py"]

View file

@ -1,4 +1,4 @@
FROM debian:10-slim FROM ubuntu:19.10
ARG user=lbry ARG user=lbry
ARG db_dir=/database ARG db_dir=/database
@ -13,14 +13,10 @@ RUN apt-get update && \
wget \ wget \
tar unzip \ tar unzip \
build-essential \ build-essential \
automake libtool \ python3 \
pkg-config \
libleveldb-dev \
python3.7 \
python3-dev \ python3-dev \
python3-pip \ python3-pip \
python3-wheel \ python3-wheel \
python3-cffi \
python3-setuptools && \ python3-setuptools && \
update-alternatives --install /usr/bin/pip pip /usr/bin/pip3 1 && \ update-alternatives --install /usr/bin/pip pip /usr/bin/pip3 1 && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*

View file

@ -1,45 +0,0 @@
FROM debian:10-slim
ARG user=lbry
ARG downloads_dir=/database
ARG projects_dir=/home/$user
ARG DOCKER_TAG
ARG DOCKER_COMMIT=docker
ENV DOCKER_TAG=$DOCKER_TAG DOCKER_COMMIT=$DOCKER_COMMIT
RUN apt-get update && \
apt-get -y --no-install-recommends install \
wget \
automake libtool \
tar unzip \
build-essential \
pkg-config \
libleveldb-dev \
python3.7 \
python3-dev \
python3-pip \
python3-wheel \
python3-setuptools && \
update-alternatives --install /usr/bin/pip pip /usr/bin/pip3 1 && \
rm -rf /var/lib/apt/lists/*
RUN groupadd -g 999 $user && useradd -m -u 999 -g $user $user
RUN mkdir -p $downloads_dir
RUN chown -R $user:$user $downloads_dir
COPY . $projects_dir
RUN chown -R $user:$user $projects_dir
USER $user
WORKDIR $projects_dir
RUN pip install uvloop
RUN make install
RUN python3 docker/set_build.py
RUN rm ~/.cache -rf
# entry point
VOLUME $downloads_dir
COPY ./docker/webconf.yaml /webconf.yaml
ENTRYPOINT ["/home/lbry/.local/bin/lbrynet", "start", "--config=/webconf.yaml"]

View file

@ -1,9 +0,0 @@
### How to run with docker-compose
1. Edit config file and after that fix permissions with
```
sudo chown -R 999:999 webconf.yaml
```
2. Start SDK with
```
docker-compose up -d
```

View file

@ -1,49 +1,34 @@
version: "3" version: "3"
volumes: volumes:
lbrycrd:
wallet_server: wallet_server:
es01:
services: services:
wallet_server: lbrycrd:
depends_on: image: lbry/lbrycrd:${LBRYCRD_TAG:-latest-release}
- es01 restart: always
image: lbry/wallet-server:${WALLET_SERVER_TAG:-latest-release} ports: # accessible from host
- "9246:9246" # rpc port
expose: # internal to docker network. also this doesn't do anything. its for documentation only.
- "9245" # node-to-node comms port
volumes:
- "lbrycrd:/data/.lbrycrd"
environment:
- RUN_MODE=default
- SNAPSHOT_URL=${LBRYCRD_SNAPSHOT_URL-https://lbry.com/snapshot/blockchain}
- RPC_ALLOW_IP=0.0.0.0/0
wallet_server:
image: lbry/wallet-server:${WALLET_SERVER_TAG:-latest-release}
depends_on:
- lbrycrd
restart: always restart: always
network_mode: host
ports: ports:
- "50001:50001" # rpc port - "50001:50001" # rpc port
- "2112:2112" # uncomment to enable prometheus - "50005:50005" # websocket port
#- "2112:2112" # uncomment to enable prometheus
volumes: volumes:
- "wallet_server:/database" - "wallet_server:/database"
environment: environment:
- DAEMON_URL=http://lbry:lbry@127.0.0.1:9245 - SNAPSHOT_URL=${WALLET_SERVER_SNAPSHOT_URL-https://lbry.com/snapshot/wallet}
- MAX_QUERY_WORKERS=4 - DAEMON_URL=http://lbry:lbry@lbrycrd:9245
- CACHE_MB=1024
- CACHE_ALL_TX_HASHES=
- CACHE_ALL_CLAIM_TXOS=
- MAX_SEND=1000000000000000000
- MAX_RECEIVE=1000000000000000000
- MAX_SESSIONS=100000
- HOST=0.0.0.0
- TCP_PORT=50001
- PROMETHEUS_PORT=2112
- FILTERING_CHANNEL_IDS=770bd7ecba84fd2f7607fb15aedd2b172c2e153f 95e5db68a3101df19763f3a5182e4b12ba393ee8
- BLOCKING_CHANNEL_IDS=dd687b357950f6f271999971f43c785e8067c3a9 06871aa438032244202840ec59a469b303257cad b4a2528f436eca1bf3bf3e10ff3f98c57bd6c4c6
es01:
image: docker.elastic.co/elasticsearch/elasticsearch:7.11.0
container_name: es01
environment:
- node.name=es01
- discovery.type=single-node
- indices.query.bool.max_clause_count=8192
- bootstrap.memory_lock=true
- "ES_JAVA_OPTS=-Xms4g -Xmx4g" # no more than 32, remember to disable swap
ulimits:
memlock:
soft: -1
hard: -1
volumes:
- es01:/usr/share/elasticsearch/data
ports:
- 127.0.0.1:9200:9200

View file

@ -1,9 +0,0 @@
version: '3'
services:
websdk:
image: vshyba/websdk
ports:
- '5279:5279'
- '5280:5280'
volumes:
- ./webconf.yaml:/webconf.yaml

View file

@ -20,7 +20,7 @@ def _check_and_set(d: dict, key: str, value: str):
def main(): def main():
build_info = {item: build_info_mod.__dict__[item] for item in dir(build_info_mod) if not item.startswith("__")} build_info = {item: build_info_mod.__dict__[item] for item in dir(build_info_mod) if not item.startswith("__")}
commit_hash = os.getenv('DOCKER_COMMIT', os.getenv('GITHUB_SHA')) commit_hash = os.getenv('DOCKER_COMMIT', os.getenv('CI_COMMIT_SHA', os.getenv('TRAVIS_COMMIT')))
if commit_hash is None: if commit_hash is None:
raise ValueError("Commit hash not found in env vars") raise ValueError("Commit hash not found in env vars")
_check_and_set(build_info, "COMMIT_HASH", commit_hash[:6]) _check_and_set(build_info, "COMMIT_HASH", commit_hash[:6])
@ -30,10 +30,8 @@ def main():
_check_and_set(build_info, "DOCKER_TAG", docker_tag) _check_and_set(build_info, "DOCKER_TAG", docker_tag)
_check_and_set(build_info, "BUILD", "docker") _check_and_set(build_info, "BUILD", "docker")
else: else:
if re.match(r'refs/tags/v\d+\.\d+\.\d+$', str(os.getenv('GITHUB_REF'))): ci_tag = os.getenv('CI_COMMIT_TAG', os.getenv('TRAVIS_TAG'))
_check_and_set(build_info, "BUILD", "release") _check_and_set(build_info, "BUILD", "release" if re.match(r'v\d+\.\d+\.\d+$', str(ci_tag)) else "qa")
else:
_check_and_set(build_info, "BUILD", "qa")
log.debug("build info: %s", ", ".join([f"{k}={v}" for k, v in build_info.items()])) log.debug("build info: %s", ", ".join([f"{k}={v}" for k, v in build_info.items()]))
with open(build_info_mod.__file__, 'w') as f: with open(build_info_mod.__file__, 'w') as f:

View file

@ -6,7 +6,7 @@ set -euo pipefail
SNAPSHOT_URL="${SNAPSHOT_URL:-}" #off by default. latest snapshot at https://lbry.com/snapshot/wallet SNAPSHOT_URL="${SNAPSHOT_URL:-}" #off by default. latest snapshot at https://lbry.com/snapshot/wallet
if [[ -n "$SNAPSHOT_URL" ]] && [[ ! -f /database/lbry-leveldb ]]; then if [[ -n "$SNAPSHOT_URL" ]] && [[ ! -f /database/claims.db ]]; then
files="$(ls)" files="$(ls)"
echo "Downloading wallet snapshot from $SNAPSHOT_URL" echo "Downloading wallet snapshot from $SNAPSHOT_URL"
wget --no-verbose --trust-server-names --content-disposition "$SNAPSHOT_URL" wget --no-verbose --trust-server-names --content-disposition "$SNAPSHOT_URL"
@ -20,6 +20,4 @@ if [[ -n "$SNAPSHOT_URL" ]] && [[ ! -f /database/lbry-leveldb ]]; then
rm "$filename" rm "$filename"
fi fi
/home/lbry/.local/bin/lbry-hub-elastic-sync /home/lbry/.local/bin/torba-server "$@"
echo 'starting server'
/home/lbry/.local/bin/lbry-hub "$@"

View file

@ -1,9 +0,0 @@
allowed_origin: "*"
max_key_fee: "0.0 USD"
save_files: false
save_blobs: false
streaming_server: "0.0.0.0:5280"
api: "0.0.0.0:5279"
data_dir: /tmp
download_dir: /tmp
wallet_dir: /tmp

File diff suppressed because one or more lines are too long

View file

@ -1,2 +1,2 @@
__version__ = "0.113.0" __version__ = "0.64.0"
version = tuple(map(int, __version__.split('.'))) # pylint: disable=invalid-name version = tuple(map(int, __version__.split('.'))) # pylint: disable=invalid-name

View file

@ -1,6 +1,5 @@
import os import os
import re import re
import time
import asyncio import asyncio
import binascii import binascii
import logging import logging
@ -71,27 +70,21 @@ class AbstractBlob:
'writers', 'writers',
'verified', 'verified',
'writing', 'writing',
'readers', 'readers'
'added_on',
'is_mine',
] ]
def __init__( def __init__(self, loop: asyncio.AbstractEventLoop, blob_hash: str, length: typing.Optional[int] = None,
self, loop: asyncio.AbstractEventLoop, blob_hash: str, length: typing.Optional[int] = None,
blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], asyncio.Task]] = None, blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], asyncio.Task]] = None,
blob_directory: typing.Optional[str] = None, added_on: typing.Optional[int] = None, is_mine: bool = False, blob_directory: typing.Optional[str] = None):
):
self.loop = loop self.loop = loop
self.blob_hash = blob_hash self.blob_hash = blob_hash
self.length = length self.length = length
self.blob_completed_callback = blob_completed_callback self.blob_completed_callback = blob_completed_callback
self.blob_directory = blob_directory self.blob_directory = blob_directory
self.writers: typing.Dict[typing.Tuple[typing.Optional[str], typing.Optional[int]], HashBlobWriter] = {} self.writers: typing.Dict[typing.Tuple[typing.Optional[str], typing.Optional[int]], HashBlobWriter] = {}
self.verified: asyncio.Event = asyncio.Event() self.verified: asyncio.Event = asyncio.Event(loop=self.loop)
self.writing: asyncio.Event = asyncio.Event() self.writing: asyncio.Event = asyncio.Event(loop=self.loop)
self.readers: typing.List[typing.BinaryIO] = [] self.readers: typing.List[typing.BinaryIO] = []
self.added_on = added_on or time.time()
self.is_mine = is_mine
if not is_valid_blobhash(blob_hash): if not is_valid_blobhash(blob_hash):
raise InvalidBlobHashError(blob_hash) raise InvalidBlobHashError(blob_hash)
@ -117,7 +110,7 @@ class AbstractBlob:
if reader in self.readers: if reader in self.readers:
self.readers.remove(reader) self.readers.remove(reader)
def _write_blob(self, blob_bytes: bytes) -> asyncio.Task: def _write_blob(self, blob_bytes: bytes):
raise NotImplementedError() raise NotImplementedError()
def set_length(self, length) -> None: def set_length(self, length) -> None:
@ -188,41 +181,34 @@ class AbstractBlob:
@classmethod @classmethod
async def create_from_unencrypted( async def create_from_unencrypted(
cls, loop: asyncio.AbstractEventLoop, blob_dir: typing.Optional[str], key: bytes, iv: bytes, cls, loop: asyncio.AbstractEventLoop, blob_dir: typing.Optional[str], key: bytes, iv: bytes,
unencrypted: bytes, blob_num: int, added_on: int, is_mine: bool, unencrypted: bytes, blob_num: int,
blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], None]] = None, blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], None]] = None) -> BlobInfo:
) -> BlobInfo:
""" """
Create an encrypted BlobFile from plaintext bytes Create an encrypted BlobFile from plaintext bytes
""" """
blob_bytes, blob_hash = encrypt_blob_bytes(key, iv, unencrypted) blob_bytes, blob_hash = encrypt_blob_bytes(key, iv, unencrypted)
length = len(blob_bytes) length = len(blob_bytes)
blob = cls(loop, blob_hash, length, blob_completed_callback, blob_dir, added_on, is_mine) blob = cls(loop, blob_hash, length, blob_completed_callback, blob_dir)
writer = blob.get_blob_writer() writer = blob.get_blob_writer()
writer.write(blob_bytes) writer.write(blob_bytes)
await blob.verified.wait() await blob.verified.wait()
return BlobInfo(blob_num, length, binascii.hexlify(iv).decode(), added_on, blob_hash, is_mine) return BlobInfo(blob_num, length, binascii.hexlify(iv).decode(), blob_hash)
def save_verified_blob(self, verified_bytes: bytes): def save_verified_blob(self, verified_bytes: bytes):
if self.verified.is_set(): if self.verified.is_set():
return return
def update_events(_):
self.verified.set()
self.writing.clear()
if self.is_writeable(): if self.is_writeable():
self.writing.set() self._write_blob(verified_bytes)
task = self._write_blob(verified_bytes) self.verified.set()
task.add_done_callback(update_events)
if self.blob_completed_callback: if self.blob_completed_callback:
task.add_done_callback(lambda _: self.blob_completed_callback(self)) self.blob_completed_callback(self)
def get_blob_writer(self, peer_address: typing.Optional[str] = None, def get_blob_writer(self, peer_address: typing.Optional[str] = None,
peer_port: typing.Optional[int] = None) -> HashBlobWriter: peer_port: typing.Optional[int] = None) -> HashBlobWriter:
if (peer_address, peer_port) in self.writers and not self.writers[(peer_address, peer_port)].closed(): if (peer_address, peer_port) in self.writers and not self.writers[(peer_address, peer_port)].closed():
raise OSError(f"attempted to download blob twice from {peer_address}:{peer_port}") raise OSError(f"attempted to download blob twice from {peer_address}:{peer_port}")
fut = asyncio.Future() fut = asyncio.Future(loop=self.loop)
writer = HashBlobWriter(self.blob_hash, self.get_length, fut) writer = HashBlobWriter(self.blob_hash, self.get_length, fut)
self.writers[(peer_address, peer_port)] = writer self.writers[(peer_address, peer_port)] = writer
@ -256,13 +242,11 @@ class BlobBuffer(AbstractBlob):
""" """
An in-memory only blob An in-memory only blob
""" """
def __init__( def __init__(self, loop: asyncio.AbstractEventLoop, blob_hash: str, length: typing.Optional[int] = None,
self, loop: asyncio.AbstractEventLoop, blob_hash: str, length: typing.Optional[int] = None,
blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], asyncio.Task]] = None, blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], asyncio.Task]] = None,
blob_directory: typing.Optional[str] = None, added_on: typing.Optional[int] = None, is_mine: bool = False blob_directory: typing.Optional[str] = None):
):
self._verified_bytes: typing.Optional[BytesIO] = None self._verified_bytes: typing.Optional[BytesIO] = None
super().__init__(loop, blob_hash, length, blob_completed_callback, blob_directory, added_on, is_mine) super().__init__(loop, blob_hash, length, blob_completed_callback, blob_directory)
@contextlib.contextmanager @contextlib.contextmanager
def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]: def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]:
@ -277,11 +261,9 @@ class BlobBuffer(AbstractBlob):
self.verified.clear() self.verified.clear()
def _write_blob(self, blob_bytes: bytes): def _write_blob(self, blob_bytes: bytes):
async def write():
if self._verified_bytes: if self._verified_bytes:
raise OSError("already have bytes for blob") raise OSError("already have bytes for blob")
self._verified_bytes = BytesIO(blob_bytes) self._verified_bytes = BytesIO(blob_bytes)
return self.loop.create_task(write())
def delete(self): def delete(self):
if self._verified_bytes: if self._verified_bytes:
@ -299,12 +281,10 @@ class BlobFile(AbstractBlob):
""" """
A blob existing on the local file system A blob existing on the local file system
""" """
def __init__( def __init__(self, loop: asyncio.AbstractEventLoop, blob_hash: str, length: typing.Optional[int] = None,
self, loop: asyncio.AbstractEventLoop, blob_hash: str, length: typing.Optional[int] = None,
blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], asyncio.Task]] = None, blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], asyncio.Task]] = None,
blob_directory: typing.Optional[str] = None, added_on: typing.Optional[int] = None, is_mine: bool = False blob_directory: typing.Optional[str] = None):
): super().__init__(loop, blob_hash, length, blob_completed_callback, blob_directory)
super().__init__(loop, blob_hash, length, blob_completed_callback, blob_directory, added_on, is_mine)
if not blob_directory or not os.path.isdir(blob_directory): if not blob_directory or not os.path.isdir(blob_directory):
raise OSError(f"invalid blob directory '{blob_directory}'") raise OSError(f"invalid blob directory '{blob_directory}'")
self.file_path = os.path.join(self.blob_directory, self.blob_hash) self.file_path = os.path.join(self.blob_directory, self.blob_hash)
@ -339,28 +319,22 @@ class BlobFile(AbstractBlob):
handle.close() handle.close()
def _write_blob(self, blob_bytes: bytes): def _write_blob(self, blob_bytes: bytes):
def _write_blob():
with open(self.file_path, 'wb') as f: with open(self.file_path, 'wb') as f:
f.write(blob_bytes) f.write(blob_bytes)
async def write_blob():
await self.loop.run_in_executor(None, _write_blob)
return self.loop.create_task(write_blob())
def delete(self): def delete(self):
super().delete()
if os.path.isfile(self.file_path): if os.path.isfile(self.file_path):
os.remove(self.file_path) os.remove(self.file_path)
return super().delete()
@classmethod @classmethod
async def create_from_unencrypted( async def create_from_unencrypted(
cls, loop: asyncio.AbstractEventLoop, blob_dir: typing.Optional[str], key: bytes, iv: bytes, cls, loop: asyncio.AbstractEventLoop, blob_dir: typing.Optional[str], key: bytes, iv: bytes,
unencrypted: bytes, blob_num: int, added_on: float, is_mine: bool, unencrypted: bytes, blob_num: int,
blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], asyncio.Task]] = None blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'],
) -> BlobInfo: asyncio.Task]] = None) -> BlobInfo:
if not blob_dir or not os.path.isdir(blob_dir): if not blob_dir or not os.path.isdir(blob_dir):
raise OSError(f"cannot create blob in directory: '{blob_dir}'") raise OSError(f"cannot create blob in directory: '{blob_dir}'")
return await super().create_from_unencrypted( return await super().create_from_unencrypted(
loop, blob_dir, key, iv, unencrypted, blob_num, added_on, is_mine, blob_completed_callback loop, blob_dir, key, iv, unencrypted, blob_num, blob_completed_callback
) )

View file

@ -7,19 +7,13 @@ class BlobInfo:
'blob_num', 'blob_num',
'length', 'length',
'iv', 'iv',
'added_on',
'is_mine'
] ]
def __init__( def __init__(self, blob_num: int, length: int, iv: str, blob_hash: typing.Optional[str] = None):
self, blob_num: int, length: int, iv: str, added_on,
blob_hash: typing.Optional[str] = None, is_mine=False):
self.blob_hash = blob_hash self.blob_hash = blob_hash
self.blob_num = blob_num self.blob_num = blob_num
self.length = length self.length = length
self.iv = iv self.iv = iv
self.added_on = added_on
self.is_mine = is_mine
def as_dict(self) -> typing.Dict: def as_dict(self) -> typing.Dict:
d = { d = {

View file

@ -2,7 +2,7 @@ import os
import typing import typing
import asyncio import asyncio
import logging import logging
from lbry.utils import LRUCacheWithMetrics from lbry.utils import LRUCache
from lbry.blob.blob_file import is_valid_blobhash, BlobFile, BlobBuffer, AbstractBlob from lbry.blob.blob_file import is_valid_blobhash, BlobFile, BlobBuffer, AbstractBlob
from lbry.stream.descriptor import StreamDescriptor from lbry.stream.descriptor import StreamDescriptor
from lbry.connection_manager import ConnectionManager from lbry.connection_manager import ConnectionManager
@ -32,34 +32,34 @@ class BlobManager:
else self._node_data_store.completed_blobs else self._node_data_store.completed_blobs
self.blobs: typing.Dict[str, AbstractBlob] = {} self.blobs: typing.Dict[str, AbstractBlob] = {}
self.config = config self.config = config
self.decrypted_blob_lru_cache = None if not self.config.blob_lru_cache_size else LRUCacheWithMetrics( self.decrypted_blob_lru_cache = None if not self.config.blob_lru_cache_size else LRUCache(
self.config.blob_lru_cache_size) self.config.blob_lru_cache_size)
self.connection_manager = ConnectionManager(loop) self.connection_manager = ConnectionManager(loop)
def _get_blob(self, blob_hash: str, length: typing.Optional[int] = None, is_mine: bool = False): def _get_blob(self, blob_hash: str, length: typing.Optional[int] = None):
if self.config.save_blobs or ( if self.config.save_blobs or (
is_valid_blobhash(blob_hash) and os.path.isfile(os.path.join(self.blob_dir, blob_hash))): is_valid_blobhash(blob_hash) and os.path.isfile(os.path.join(self.blob_dir, blob_hash))):
return BlobFile( return BlobFile(
self.loop, blob_hash, length, self.blob_completed, self.blob_dir, is_mine=is_mine self.loop, blob_hash, length, self.blob_completed, self.blob_dir
) )
return BlobBuffer( return BlobBuffer(
self.loop, blob_hash, length, self.blob_completed, self.blob_dir, is_mine=is_mine self.loop, blob_hash, length, self.blob_completed, self.blob_dir
) )
def get_blob(self, blob_hash, length: typing.Optional[int] = None, is_mine: bool = False): def get_blob(self, blob_hash, length: typing.Optional[int] = None):
if blob_hash in self.blobs: if blob_hash in self.blobs:
if self.config.save_blobs and isinstance(self.blobs[blob_hash], BlobBuffer): if self.config.save_blobs and isinstance(self.blobs[blob_hash], BlobBuffer):
buffer = self.blobs.pop(blob_hash) buffer = self.blobs.pop(blob_hash)
if blob_hash in self.completed_blob_hashes: if blob_hash in self.completed_blob_hashes:
self.completed_blob_hashes.remove(blob_hash) self.completed_blob_hashes.remove(blob_hash)
self.blobs[blob_hash] = self._get_blob(blob_hash, length, is_mine) self.blobs[blob_hash] = self._get_blob(blob_hash, length)
if buffer.is_readable(): if buffer.is_readable():
with buffer.reader_context() as reader: with buffer.reader_context() as reader:
self.blobs[blob_hash].write_blob(reader.read()) self.blobs[blob_hash].write_blob(reader.read())
if length and self.blobs[blob_hash].length is None: if length and self.blobs[blob_hash].length is None:
self.blobs[blob_hash].set_length(length) self.blobs[blob_hash].set_length(length)
else: else:
self.blobs[blob_hash] = self._get_blob(blob_hash, length, is_mine) self.blobs[blob_hash] = self._get_blob(blob_hash, length)
return self.blobs[blob_hash] return self.blobs[blob_hash]
def is_blob_verified(self, blob_hash: str, length: typing.Optional[int] = None) -> bool: def is_blob_verified(self, blob_hash: str, length: typing.Optional[int] = None) -> bool:
@ -83,8 +83,6 @@ class BlobManager:
to_add = await self.storage.sync_missing_blobs(in_blobfiles_dir) to_add = await self.storage.sync_missing_blobs(in_blobfiles_dir)
if to_add: if to_add:
self.completed_blob_hashes.update(to_add) self.completed_blob_hashes.update(to_add)
# check blobs that aren't set as finished but were seen on disk
await self.ensure_completed_blobs_status(in_blobfiles_dir - to_add)
if self.config.track_bandwidth: if self.config.track_bandwidth:
self.connection_manager.start() self.connection_manager.start()
return True return True
@ -107,26 +105,13 @@ class BlobManager:
if isinstance(blob, BlobFile): if isinstance(blob, BlobFile):
if blob.blob_hash not in self.completed_blob_hashes: if blob.blob_hash not in self.completed_blob_hashes:
self.completed_blob_hashes.add(blob.blob_hash) self.completed_blob_hashes.add(blob.blob_hash)
return self.loop.create_task(self.storage.add_blobs( return self.loop.create_task(self.storage.add_blobs((blob.blob_hash, blob.length), finished=True))
(blob.blob_hash, blob.length, blob.added_on, blob.is_mine), finished=True)
)
else: else:
return self.loop.create_task(self.storage.add_blobs( return self.loop.create_task(self.storage.add_blobs((blob.blob_hash, blob.length), finished=False))
(blob.blob_hash, blob.length, blob.added_on, blob.is_mine), finished=False)
)
async def ensure_completed_blobs_status(self, blob_hashes: typing.Iterable[str]): def check_completed_blobs(self, blob_hashes: typing.List[str]) -> typing.List[str]:
"""Ensures that completed blobs from a given list of blob hashes are set as 'finished' in the database.""" """Returns of the blobhashes_to_check, which are valid"""
to_add = [] return [blob_hash for blob_hash in blob_hashes if self.is_blob_verified(blob_hash)]
for blob_hash in blob_hashes:
if not self.is_blob_verified(blob_hash):
continue
blob = self.get_blob(blob_hash)
to_add.append((blob.blob_hash, blob.length, blob.added_on, blob.is_mine))
if len(to_add) > 500:
await self.storage.add_blobs(*to_add, finished=True)
to_add.clear()
return await self.storage.add_blobs(*to_add, finished=True)
def delete_blob(self, blob_hash: str): def delete_blob(self, blob_hash: str):
if not is_valid_blobhash(blob_hash): if not is_valid_blobhash(blob_hash):

View file

@ -1,77 +0,0 @@
import asyncio
import logging
log = logging.getLogger(__name__)
class DiskSpaceManager:
def __init__(self, config, db, blob_manager, cleaning_interval=30 * 60, analytics=None):
self.config = config
self.db = db
self.blob_manager = blob_manager
self.cleaning_interval = cleaning_interval
self.running = False
self.task = None
self.analytics = analytics
self._used_space_bytes = None
async def get_free_space_mb(self, is_network_blob=False):
limit_mb = self.config.network_storage_limit if is_network_blob else self.config.blob_storage_limit
space_used_mb = await self.get_space_used_mb()
space_used_mb = space_used_mb['network_storage'] if is_network_blob else space_used_mb['content_storage']
return max(0, limit_mb - space_used_mb)
async def get_space_used_bytes(self):
self._used_space_bytes = await self.db.get_stored_blob_disk_usage()
return self._used_space_bytes
async def get_space_used_mb(self, cached=True):
cached = cached and self._used_space_bytes is not None
space_used_bytes = self._used_space_bytes if cached else await self.get_space_used_bytes()
return {key: int(value/1024.0/1024.0) for key, value in space_used_bytes.items()}
async def clean(self):
await self._clean(False)
await self._clean(True)
async def _clean(self, is_network_blob=False):
space_used_mb = await self.get_space_used_mb(cached=False)
if is_network_blob:
space_used_mb = space_used_mb['network_storage']
else:
space_used_mb = space_used_mb['content_storage'] + space_used_mb['private_storage']
storage_limit_mb = self.config.network_storage_limit if is_network_blob else self.config.blob_storage_limit
if self.analytics:
asyncio.create_task(
self.analytics.send_disk_space_used(space_used_mb, storage_limit_mb, is_network_blob)
)
delete = []
available = storage_limit_mb - space_used_mb
if storage_limit_mb == 0 if not is_network_blob else available >= 0:
return 0
for blob_hash, file_size, _ in await self.db.get_stored_blobs(is_mine=False, is_network_blob=is_network_blob):
delete.append(blob_hash)
available += int(file_size/1024.0/1024.0)
if available >= 0:
break
if delete:
await self.db.stop_all_files()
await self.blob_manager.delete_blobs(delete, delete_from_db=True)
self._used_space_bytes = None
return len(delete)
async def cleaning_loop(self):
while self.running:
await asyncio.sleep(self.cleaning_interval)
await self.clean()
async def start(self):
self.running = True
self.task = asyncio.create_task(self.cleaning_loop())
self.task.add_done_callback(lambda _: log.info("Stopping blob cleanup service."))
async def stop(self):
if self.running:
self.running = False
self.task.cancel()

View file

@ -32,7 +32,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
self.buf = b'' self.buf = b''
# this is here to handle the race when the downloader is closed right as response_fut gets a result # this is here to handle the race when the downloader is closed right as response_fut gets a result
self.closed = asyncio.Event() self.closed = asyncio.Event(loop=self.loop)
def data_received(self, data: bytes): def data_received(self, data: bytes):
if self.connection_manager: if self.connection_manager:
@ -111,7 +111,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
self.transport.write(msg) self.transport.write(msg)
if self.connection_manager: if self.connection_manager:
self.connection_manager.sent_data(f"{self.peer_address}:{self.peer_port}", len(msg)) self.connection_manager.sent_data(f"{self.peer_address}:{self.peer_port}", len(msg))
response: BlobResponse = await asyncio.wait_for(self._response_fut, self.peer_timeout) response: BlobResponse = await asyncio.wait_for(self._response_fut, self.peer_timeout, loop=self.loop)
availability_response = response.get_availability_response() availability_response = response.get_availability_response()
price_response = response.get_price_response() price_response = response.get_price_response()
blob_response = response.get_blob_response() blob_response = response.get_blob_response()
@ -151,9 +151,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
f" timeout in {self.peer_timeout}" f" timeout in {self.peer_timeout}"
log.debug(msg) log.debug(msg)
msg = f"downloaded {self.blob.blob_hash[:8]} from {self.peer_address}:{self.peer_port}" msg = f"downloaded {self.blob.blob_hash[:8]} from {self.peer_address}:{self.peer_port}"
await asyncio.wait_for(self.writer.finished, self.peer_timeout) await asyncio.wait_for(self.writer.finished, self.peer_timeout, loop=self.loop)
# wait for the io to finish
await self.blob.verified.wait()
log.info("%s at %fMB/s", msg, log.info("%s at %fMB/s", msg,
round((float(self._blob_bytes_received) / round((float(self._blob_bytes_received) /
float(time.perf_counter() - start_time)) / 1000000.0, 2)) float(time.perf_counter() - start_time)) / 1000000.0, 2))
@ -187,7 +185,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
try: try:
self._blob_bytes_received = 0 self._blob_bytes_received = 0
self.blob, self.writer = blob, blob.get_blob_writer(self.peer_address, self.peer_port) self.blob, self.writer = blob, blob.get_blob_writer(self.peer_address, self.peer_port)
self._response_fut = asyncio.Future() self._response_fut = asyncio.Future(loop=self.loop)
return await self._download_blob() return await self._download_blob()
except OSError: except OSError:
# i'm not sure how to fix this race condition - jack # i'm not sure how to fix this race condition - jack
@ -244,7 +242,7 @@ async def request_blob(loop: asyncio.AbstractEventLoop, blob: Optional['Abstract
try: try:
if not connected_protocol: if not connected_protocol:
await asyncio.wait_for(loop.create_connection(lambda: protocol, address, tcp_port), await asyncio.wait_for(loop.create_connection(lambda: protocol, address, tcp_port),
peer_connect_timeout) peer_connect_timeout, loop=loop)
connected_protocol = protocol connected_protocol = protocol
if blob is None or blob.get_is_verified() or not blob.is_writeable(): if blob is None or blob.get_is_verified() or not blob.is_writeable():
# blob is None happens when we are just opening a connection # blob is None happens when we are just opening a connection

View file

@ -3,7 +3,6 @@ import typing
import logging import logging
from lbry.utils import cache_concurrent from lbry.utils import cache_concurrent
from lbry.blob_exchange.client import request_blob from lbry.blob_exchange.client import request_blob
from lbry.dht.node import get_kademlia_peers_from_hosts
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from lbry.conf import Config from lbry.conf import Config
from lbry.dht.node import Node from lbry.dht.node import Node
@ -30,7 +29,7 @@ class BlobDownloader:
self.failures: typing.Dict['KademliaPeer', int] = {} self.failures: typing.Dict['KademliaPeer', int] = {}
self.connection_failures: typing.Set['KademliaPeer'] = set() self.connection_failures: typing.Set['KademliaPeer'] = set()
self.connections: typing.Dict['KademliaPeer', 'BlobExchangeClientProtocol'] = {} self.connections: typing.Dict['KademliaPeer', 'BlobExchangeClientProtocol'] = {}
self.is_running = asyncio.Event() self.is_running = asyncio.Event(loop=self.loop)
def should_race_continue(self, blob: 'AbstractBlob'): def should_race_continue(self, blob: 'AbstractBlob'):
max_probes = self.config.max_connections_per_download * (1 if self.connections else 10) max_probes = self.config.max_connections_per_download * (1 if self.connections else 10)
@ -64,8 +63,8 @@ class BlobDownloader:
self.scores[peer] = bytes_received / elapsed if bytes_received and elapsed else 1 self.scores[peer] = bytes_received / elapsed if bytes_received and elapsed else 1
async def new_peer_or_finished(self): async def new_peer_or_finished(self):
active_tasks = list(self.active_connections.values()) + [asyncio.create_task(asyncio.sleep(1))] active_tasks = list(self.active_connections.values()) + [asyncio.sleep(1)]
await asyncio.wait(active_tasks, return_when='FIRST_COMPLETED') await asyncio.wait(active_tasks, loop=self.loop, return_when='FIRST_COMPLETED')
def cleanup_active(self): def cleanup_active(self):
if not self.active_connections and not self.connections: if not self.active_connections and not self.connections:
@ -88,6 +87,7 @@ class BlobDownloader:
if blob.get_is_verified(): if blob.get_is_verified():
return blob return blob
self.is_running.set() self.is_running.set()
tried_for_this_blob: typing.Set['KademliaPeer'] = set()
try: try:
while not blob.get_is_verified() and self.is_running.is_set(): while not blob.get_is_verified() and self.is_running.is_set():
batch: typing.Set['KademliaPeer'] = set(self.connections.keys()) batch: typing.Set['KademliaPeer'] = set(self.connections.keys())
@ -97,14 +97,23 @@ class BlobDownloader:
"%s running, %d peers, %d ignored, %d active, %s connections", blob_hash[:6], "%s running, %d peers, %d ignored, %d active, %s connections", blob_hash[:6],
len(batch), len(self.ignored), len(self.active_connections), len(self.connections) len(batch), len(self.ignored), len(self.active_connections), len(self.connections)
) )
re_add: typing.Set['KademliaPeer'] = set()
for peer in sorted(batch, key=lambda peer: self.scores.get(peer, 0), reverse=True): for peer in sorted(batch, key=lambda peer: self.scores.get(peer, 0), reverse=True):
if peer in self.ignored: if peer in self.ignored:
continue continue
if peer in self.active_connections or not self.should_race_continue(blob): if peer in tried_for_this_blob:
continue continue
if peer in self.active_connections:
if peer not in re_add:
re_add.add(peer)
continue
if not self.should_race_continue(blob):
break
log.debug("request %s from %s:%i", blob_hash[:8], peer.address, peer.tcp_port) log.debug("request %s from %s:%i", blob_hash[:8], peer.address, peer.tcp_port)
t = self.loop.create_task(self.request_blob_from_peer(blob, peer, connection_id)) t = self.loop.create_task(self.request_blob_from_peer(blob, peer, connection_id))
self.active_connections[peer] = t self.active_connections[peer] = t
tried_for_this_blob.add(peer)
if not re_add:
self.peer_queue.put_nowait(list(batch)) self.peer_queue.put_nowait(list(batch))
await self.new_peer_or_finished() await self.new_peer_or_finished()
self.cleanup_active() self.cleanup_active()
@ -124,14 +133,11 @@ class BlobDownloader:
protocol.close() protocol.close()
async def download_blob(loop, config: 'Config', blob_manager: 'BlobManager', dht_node: 'Node', async def download_blob(loop, config: 'Config', blob_manager: 'BlobManager', node: 'Node',
blob_hash: str) -> 'AbstractBlob': blob_hash: str) -> 'AbstractBlob':
search_queue = asyncio.Queue(maxsize=config.max_connections_per_download) search_queue = asyncio.Queue(loop=loop, maxsize=config.max_connections_per_download)
search_queue.put_nowait(blob_hash) search_queue.put_nowait(blob_hash)
peer_queue, accumulate_task = dht_node.accumulate_peers(search_queue) peer_queue, accumulate_task = node.accumulate_peers(search_queue)
fixed_peers = None if not config.fixed_peers else await get_kademlia_peers_from_hosts(config.fixed_peers)
if fixed_peers:
loop.call_later(config.fixed_peer_delay, peer_queue.put_nowait, fixed_peers)
downloader = BlobDownloader(loop, config, blob_manager, peer_queue) downloader = BlobDownloader(loop, config, blob_manager, peer_queue)
try: try:
return await downloader.download_blob(blob_hash) return await downloader.download_blob(blob_hash)

View file

@ -1,7 +1,6 @@
import asyncio import asyncio
import binascii import binascii
import logging import logging
import socket
import typing import typing
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from lbry.blob_exchange.serialization import BlobResponse, BlobRequest, blob_response_types from lbry.blob_exchange.serialization import BlobResponse, BlobRequest, blob_response_types
@ -25,19 +24,19 @@ class BlobServerProtocol(asyncio.Protocol):
self.idle_timeout = idle_timeout self.idle_timeout = idle_timeout
self.transfer_timeout = transfer_timeout self.transfer_timeout = transfer_timeout
self.server_task: typing.Optional[asyncio.Task] = None self.server_task: typing.Optional[asyncio.Task] = None
self.started_listening = asyncio.Event() self.started_listening = asyncio.Event(loop=self.loop)
self.buf = b'' self.buf = b''
self.transport: typing.Optional[asyncio.Transport] = None self.transport: typing.Optional[asyncio.Transport] = None
self.lbrycrd_address = lbrycrd_address self.lbrycrd_address = lbrycrd_address
self.peer_address_and_port: typing.Optional[str] = None self.peer_address_and_port: typing.Optional[str] = None
self.started_transfer = asyncio.Event() self.started_transfer = asyncio.Event(loop=self.loop)
self.transfer_finished = asyncio.Event() self.transfer_finished = asyncio.Event(loop=self.loop)
self.close_on_idle_task: typing.Optional[asyncio.Task] = None self.close_on_idle_task: typing.Optional[asyncio.Task] = None
async def close_on_idle(self): async def close_on_idle(self):
while self.transport: while self.transport:
try: try:
await asyncio.wait_for(self.started_transfer.wait(), self.idle_timeout) await asyncio.wait_for(self.started_transfer.wait(), self.idle_timeout, loop=self.loop)
except asyncio.TimeoutError: except asyncio.TimeoutError:
log.debug("closing idle connection from %s", self.peer_address_and_port) log.debug("closing idle connection from %s", self.peer_address_and_port)
return self.close() return self.close()
@ -101,7 +100,7 @@ class BlobServerProtocol(asyncio.Protocol):
log.debug("send %s to %s:%i", blob_hash, peer_address, peer_port) log.debug("send %s to %s:%i", blob_hash, peer_address, peer_port)
self.started_transfer.set() self.started_transfer.set()
try: try:
sent = await asyncio.wait_for(blob.sendfile(self), self.transfer_timeout) sent = await asyncio.wait_for(blob.sendfile(self), self.transfer_timeout, loop=self.loop)
if sent and sent > 0: if sent and sent > 0:
self.blob_manager.connection_manager.sent_data(self.peer_address_and_port, sent) self.blob_manager.connection_manager.sent_data(self.peer_address_and_port, sent)
log.info("sent %s (%i bytes) to %s:%i", blob_hash, sent, peer_address, peer_port) log.info("sent %s (%i bytes) to %s:%i", blob_hash, sent, peer_address, peer_port)
@ -138,7 +137,7 @@ class BlobServerProtocol(asyncio.Protocol):
try: try:
request = BlobRequest.deserialize(self.buf + data) request = BlobRequest.deserialize(self.buf + data)
self.buf = remainder self.buf = remainder
except (UnicodeDecodeError, JSONDecodeError): except JSONDecodeError:
log.error("request from %s is not valid json (%i bytes): %s", self.peer_address_and_port, log.error("request from %s is not valid json (%i bytes): %s", self.peer_address_and_port,
len(self.buf + data), '' if not data else binascii.hexlify(self.buf + data).decode()) len(self.buf + data), '' if not data else binascii.hexlify(self.buf + data).decode())
self.close() self.close()
@ -157,7 +156,7 @@ class BlobServer:
self.loop = loop self.loop = loop
self.blob_manager = blob_manager self.blob_manager = blob_manager
self.server_task: typing.Optional[asyncio.Task] = None self.server_task: typing.Optional[asyncio.Task] = None
self.started_listening = asyncio.Event() self.started_listening = asyncio.Event(loop=self.loop)
self.lbrycrd_address = lbrycrd_address self.lbrycrd_address = lbrycrd_address
self.idle_timeout = idle_timeout self.idle_timeout = idle_timeout
self.transfer_timeout = transfer_timeout self.transfer_timeout = transfer_timeout
@ -168,13 +167,6 @@ class BlobServer:
raise Exception("already running") raise Exception("already running")
async def _start_server(): async def _start_server():
# checking if the port is in use
# thx https://stackoverflow.com/a/52872579
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
if s.connect_ex(('localhost', port)) == 0:
# the port is already in use!
log.error("Failed to bind TCP %s:%d", interface, port)
server = await self.loop.create_server( server = await self.loop.create_server(
lambda: self.server_protocol_class(self.loop, self.blob_manager, self.lbrycrd_address, lambda: self.server_protocol_class(self.loop, self.blob_manager, self.lbrycrd_address,
self.idle_timeout, self.transfer_timeout), self.idle_timeout, self.transfer_timeout),

View file

@ -1,12 +1,13 @@
import os import os
import re import re
import sys import sys
import typing
import logging import logging
from typing import List, Dict, Tuple, Union, TypeVar, Generic, Optional
from argparse import ArgumentParser from argparse import ArgumentParser
from contextlib import contextmanager from contextlib import contextmanager
from appdirs import user_data_dir, user_config_dir
import yaml import yaml
from appdirs import user_data_dir, user_config_dir
from lbry.error import InvalidCurrencyError from lbry.error import InvalidCurrencyError
from lbry.dht import constants from lbry.dht import constants
from lbry.wallet.coinselection import STRATEGIES from lbry.wallet.coinselection import STRATEGIES
@ -15,7 +16,7 @@ log = logging.getLogger(__name__)
NOT_SET = type('NOT_SET', (object,), {}) # pylint: disable=invalid-name NOT_SET = type('NOT_SET', (object,), {}) # pylint: disable=invalid-name
T = TypeVar('T') T = typing.TypeVar('T')
CURRENCIES = { CURRENCIES = {
'BTC': {'type': 'crypto'}, 'BTC': {'type': 'crypto'},
@ -24,11 +25,11 @@ CURRENCIES = {
} }
class Setting(Generic[T]): class Setting(typing.Generic[T]):
def __init__(self, doc: str, default: Optional[T] = None, def __init__(self, doc: str, default: typing.Optional[T] = None,
previous_names: Optional[List[str]] = None, previous_names: typing.Optional[typing.List[str]] = None,
metavar: Optional[str] = None): metavar: typing.Optional[str] = None):
self.doc = doc self.doc = doc
self.default = default self.default = default
self.previous_names = previous_names or [] self.previous_names = previous_names or []
@ -45,7 +46,7 @@ class Setting(Generic[T]):
def no_cli_name(self): def no_cli_name(self):
return f"--no-{self.name.replace('_', '-')}" return f"--no-{self.name.replace('_', '-')}"
def __get__(self, obj: Optional['BaseConfig'], owner) -> T: def __get__(self, obj: typing.Optional['BaseConfig'], owner) -> T:
if obj is None: if obj is None:
return self return self
for location in obj.search_order: for location in obj.search_order:
@ -53,7 +54,7 @@ class Setting(Generic[T]):
return location[self.name] return location[self.name]
return self.default return self.default
def __set__(self, obj: 'BaseConfig', val: Union[T, NOT_SET]): def __set__(self, obj: 'BaseConfig', val: typing.Union[T, NOT_SET]):
if val == NOT_SET: if val == NOT_SET:
for location in obj.modify_order: for location in obj.modify_order:
if self.name in location: if self.name in location:
@ -63,18 +64,6 @@ class Setting(Generic[T]):
for location in obj.modify_order: for location in obj.modify_order:
location[self.name] = val location[self.name] = val
def is_set(self, obj: 'BaseConfig') -> bool:
for location in obj.search_order:
if self.name in location:
return True
return False
def is_set_to_default(self, obj: 'BaseConfig') -> bool:
for location in obj.search_order:
if self.name in location:
return location[self.name] == self.default
return False
def validate(self, value): def validate(self, value):
raise NotImplementedError() raise NotImplementedError()
@ -99,7 +88,7 @@ class String(Setting[str]):
f"Setting '{self.name}' must be a string." f"Setting '{self.name}' must be a string."
# TODO: removes this after pylint starts to understand generics # TODO: removes this after pylint starts to understand generics
def __get__(self, obj: Optional['BaseConfig'], owner) -> str: # pylint: disable=useless-super-delegation def __get__(self, obj: typing.Optional['BaseConfig'], owner) -> str: # pylint: disable=useless-super-delegation
return super().__get__(obj, owner) return super().__get__(obj, owner)
@ -203,7 +192,7 @@ class MaxKeyFee(Setting[dict]):
) )
parser.add_argument( parser.add_argument(
self.no_cli_name, self.no_cli_name,
help="Disable maximum key fee check.", help=f"Disable maximum key fee check.",
dest=self.name, dest=self.name,
const=None, const=None,
action="store_const", action="store_const",
@ -212,7 +201,7 @@ class MaxKeyFee(Setting[dict]):
class StringChoice(String): class StringChoice(String):
def __init__(self, doc: str, valid_values: List[str], default: str, *args, **kwargs): def __init__(self, doc: str, valid_values: typing.List[str], default: str, *args, **kwargs):
super().__init__(doc, default, *args, **kwargs) super().__init__(doc, default, *args, **kwargs)
if not valid_values: if not valid_values:
raise ValueError("No valid values provided") raise ValueError("No valid values provided")
@ -285,95 +274,17 @@ class Strings(ListSetting):
f"'{self.name}' must be a string." f"'{self.name}' must be a string."
class KnownHubsList:
def __init__(self, config: 'Config' = None, file_name: str = 'known_hubs.yml'):
self.file_name = file_name
self.path = os.path.join(config.wallet_dir, self.file_name) if config else None
self.hubs: Dict[Tuple[str, int], Dict] = {}
if self.exists:
self.load()
@property
def exists(self):
return self.path and os.path.exists(self.path)
@property
def serialized(self) -> Dict[str, Dict]:
return {f"{host}:{port}": details for (host, port), details in self.hubs.items()}
def filter(self, match_none=False, **kwargs):
if not kwargs:
return self.hubs
result = {}
for hub, details in self.hubs.items():
for key, constraint in kwargs.items():
value = details.get(key)
if value == constraint or (match_none and value is None):
result[hub] = details
break
return result
def load(self):
if self.path:
with open(self.path, 'r') as known_hubs_file:
raw = known_hubs_file.read()
for hub, details in yaml.safe_load(raw).items():
self.set(hub, details)
def save(self):
if self.path:
with open(self.path, 'w') as known_hubs_file:
known_hubs_file.write(yaml.safe_dump(self.serialized, default_flow_style=False))
def set(self, hub: str, details: Dict):
if hub and hub.count(':') == 1:
host, port = hub.split(':')
hub_parts = (host, int(port))
if hub_parts not in self.hubs:
self.hubs[hub_parts] = details
return hub
def add_hubs(self, hubs: List[str]):
added = False
for hub in hubs:
if self.set(hub, {}) is not None:
added = True
return added
def items(self):
return self.hubs.items()
def __bool__(self):
return len(self) > 0
def __len__(self):
return self.hubs.__len__()
def __iter__(self):
return iter(self.hubs)
class EnvironmentAccess: class EnvironmentAccess:
PREFIX = 'LBRY_' PREFIX = 'LBRY_'
def __init__(self, config: 'BaseConfig', environ: dict): def __init__(self, environ: dict):
self.configuration = config self.environ = environ
self.data = {}
if environ:
self.load(environ)
def load(self, environ):
for setting in self.configuration.get_settings():
value = environ.get(f'{self.PREFIX}{setting.name.upper()}', NOT_SET)
if value != NOT_SET and not (isinstance(setting, ListSetting) and value is None):
self.data[setting.name] = setting.deserialize(value)
def __contains__(self, item: str): def __contains__(self, item: str):
return item in self.data return f'{self.PREFIX}{item.upper()}' in self.environ
def __getitem__(self, item: str): def __getitem__(self, item: str):
return self.data[item] return self.environ[f'{self.PREFIX}{item.upper()}']
class ArgumentAccess: class ArgumentAccess:
@ -414,7 +325,7 @@ class ConfigFileAccess:
cls = type(self.configuration) cls = type(self.configuration)
with open(self.path, 'r') as config_file: with open(self.path, 'r') as config_file:
raw = config_file.read() raw = config_file.read()
serialized = yaml.safe_load(raw) or {} serialized = yaml.load(raw) or {}
for key, value in serialized.items(): for key, value in serialized.items():
attr = getattr(cls, key, None) attr = getattr(cls, key, None)
if attr is None: if attr is None:
@ -458,7 +369,7 @@ class ConfigFileAccess:
del self.data[key] del self.data[key]
TBC = TypeVar('TBC', bound='BaseConfig') TBC = typing.TypeVar('TBC', bound='BaseConfig')
class BaseConfig: class BaseConfig:
@ -532,7 +443,7 @@ class BaseConfig:
self.arguments = ArgumentAccess(self, args) self.arguments = ArgumentAccess(self, args)
def set_environment(self, environ=None): def set_environment(self, environ=None):
self.environment = EnvironmentAccess(self, environ or os.environ) self.environment = EnvironmentAccess(environ or os.environ)
def set_persisted(self, config_file_path=None): def set_persisted(self, config_file_path=None):
if config_file_path is None: if config_file_path is None:
@ -553,21 +464,19 @@ class BaseConfig:
class TranscodeConfig(BaseConfig): class TranscodeConfig(BaseConfig):
ffmpeg_path = String('A list of places to check for ffmpeg and ffprobe. ' ffmpeg_folder = String('The path to ffmpeg and ffprobe', '')
f'$data_dir/ffmpeg/bin and $PATH are checked afterward. Separator: {os.pathsep}',
'', previous_names=['ffmpeg_folder'])
video_encoder = String('FFmpeg codec and parameters for the video encoding. ' video_encoder = String('FFmpeg codec and parameters for the video encoding. '
'Example: libaom-av1 -crf 25 -b:v 0 -strict experimental', 'Example: libaom-av1 -crf 25 -b:v 0 -strict experimental',
'libx264 -crf 24 -preset faster -pix_fmt yuv420p') 'libx264 -crf 21 -preset faster -pix_fmt yuv420p')
video_bitrate_maximum = Integer('Maximum bits per second allowed for video streams (0 to disable).', 5_000_000) video_bitrate_maximum = Integer('Maximum bits per second allowed for video streams (0 to disable).', 8400000)
video_scaler = String('FFmpeg scaling parameters for reducing bitrate. ' video_scaler = String('FFmpeg scaling parameters for reducing bitrate. '
'Example: -vf "scale=-2:720,fps=24" -maxrate 5M -bufsize 3M', 'Example: -vf "scale=-2:720,fps=24" -maxrate 5M -bufsize 3M',
r'-vf "scale=if(gte(iw\,ih)\,min(1920\,iw)\,-2):if(lt(iw\,ih)\,min(1920\,ih)\,-2)" ' r'-vf "scale=if(gte(iw\,ih)\,min(2560\,iw)\,-2):if(lt(iw\,ih)\,min(2560\,ih)\,-2)" '
r'-maxrate 5500K -bufsize 5000K') r'-maxrate 8400K -bufsize 5000K')
audio_encoder = String('FFmpeg codec and parameters for the audio encoding. ' audio_encoder = String('FFmpeg codec and parameters for the audio encoding. '
'Example: libopus -b:a 128k', 'Example: libopus -b:a 128k',
'aac -b:a 160k') 'aac -b:a 160k')
volume_filter = String('FFmpeg filter for audio normalization. Exmple: -af loudnorm', '') volume_filter = String('FFmpeg filter for audio normalization.', '-af loudnorm')
volume_analysis_time = Integer('Maximum seconds into the file that we examine audio volume (0 to disable).', 240) volume_analysis_time = Integer('Maximum seconds into the file that we examine audio volume (0 to disable).', 240)
@ -589,9 +498,6 @@ class CLIConfig(TranscodeConfig):
class Config(CLIConfig): class Config(CLIConfig):
jurisdiction = String("Limit interactions to wallet server in this jurisdiction.")
# directories # directories
data_dir = Path("Directory path to store blobs.", metavar='DIR') data_dir = Path("Directory path to store blobs.", metavar='DIR')
download_dir = Path( download_dir = Path(
@ -613,7 +519,7 @@ class Config(CLIConfig):
"ports or have firewall rules you likely want to disable this.", True "ports or have firewall rules you likely want to disable this.", True
) )
udp_port = Integer("UDP port for communicating on the LBRY DHT", 4444, previous_names=['dht_node_port']) udp_port = Integer("UDP port for communicating on the LBRY DHT", 4444, previous_names=['dht_node_port'])
tcp_port = Integer("TCP port to listen for incoming blob requests", 4444, previous_names=['peer_port']) tcp_port = Integer("TCP port to listen for incoming blob requests", 3333, previous_names=['peer_port'])
prometheus_port = Integer("Port to expose prometheus metrics (off by default)", 0) prometheus_port = Integer("Port to expose prometheus metrics (off by default)", 0)
network_interface = String("Interface to use for the DHT and blob exchange", '0.0.0.0') network_interface = String("Interface to use for the DHT and blob exchange", '0.0.0.0')
@ -622,24 +528,17 @@ class Config(CLIConfig):
"Routing table bucket index below which we always split the bucket if given a new key to add to it and " "Routing table bucket index below which we always split the bucket if given a new key to add to it and "
"the bucket is full. As this value is raised the depth of the routing table (and number of peers in it) " "the bucket is full. As this value is raised the depth of the routing table (and number of peers in it) "
"will increase. This setting is used by seed nodes, you probably don't want to change it during normal " "will increase. This setting is used by seed nodes, you probably don't want to change it during normal "
"use.", 2 "use.", 1
)
is_bootstrap_node = Toggle(
"When running as a bootstrap node, disable all logic related to balancing the routing table, so we can "
"add as many peers as possible and better help first-runs.", False
) )
# protocol timeouts # protocol timeouts
download_timeout = Float("Cumulative timeout for a stream to begin downloading before giving up", 30.0) download_timeout = Float("Cumulative timeout for a stream to begin downloading before giving up", 30.0)
blob_download_timeout = Float("Timeout to download a blob from a peer", 30.0) blob_download_timeout = Float("Timeout to download a blob from a peer", 30.0)
hub_timeout = Float("Timeout when making a hub request", 30.0)
peer_connect_timeout = Float("Timeout to establish a TCP connection to a peer", 3.0) peer_connect_timeout = Float("Timeout to establish a TCP connection to a peer", 3.0)
node_rpc_timeout = Float("Timeout when making a DHT request", constants.RPC_TIMEOUT) node_rpc_timeout = Float("Timeout when making a DHT request", constants.RPC_TIMEOUT)
# blob announcement and download # blob announcement and download
save_blobs = Toggle("Save encrypted blob files for hosting, otherwise download blobs to memory only.", True) save_blobs = Toggle("Save encrypted blob files for hosting, otherwise download blobs to memory only.", True)
network_storage_limit = Integer("Disk space in MB to be allocated for helping the P2P network. 0 = disable", 0)
blob_storage_limit = Integer("Disk space in MB to be allocated for blob storage. 0 = no limit", 0)
blob_lru_cache_size = Integer( blob_lru_cache_size = Integer(
"LRU cache size for decrypted downloaded blobs used to minimize re-downloading the same blobs when " "LRU cache size for decrypted downloaded blobs used to minimize re-downloading the same blobs when "
"replying to a range request. Set to 0 to disable.", 32 "replying to a range request. Set to 0 to disable.", 32
@ -656,7 +555,6 @@ class Config(CLIConfig):
"Maximum number of peers to connect to while downloading a blob", 4, "Maximum number of peers to connect to while downloading a blob", 4,
previous_names=['max_connections_per_stream'] previous_names=['max_connections_per_stream']
) )
concurrent_hub_requests = Integer("Maximum number of concurrent hub requests", 32)
fixed_peer_delay = Float( fixed_peer_delay = Float(
"Amount of seconds before adding the reflector servers as potential peers to download from in case dht" "Amount of seconds before adding the reflector servers as potential peers to download from in case dht"
"peers are not found or are slow", 2.0 "peers are not found or are slow", 2.0
@ -677,22 +575,9 @@ class Config(CLIConfig):
) )
# servers # servers
reflector_servers = Servers("Reflector re-hosting servers for mirroring publishes", [ reflector_servers = Servers("Reflector re-hosting servers", [
('reflector.lbry.com', 5566) ('reflector.lbry.com', 5566)
]) ])
fixed_peers = Servers("Fixed peers to fall back to if none are found on P2P for a blob", [
('cdn.reflector.lbry.com', 5567)
])
tracker_servers = Servers("BitTorrent-compatible (BEP15) UDP trackers for helping P2P discovery", [
('tracker.lbry.com', 9252),
('tracker.lbry.grin.io', 9252),
('tracker.lbry.pigg.es', 9252),
('tracker.lizard.technology', 9252),
('s1.lbry.network', 9252),
])
lbryum_servers = Servers("SPV wallet servers", [ lbryum_servers = Servers("SPV wallet servers", [
('spv11.lbry.com', 50001), ('spv11.lbry.com', 50001),
('spv12.lbry.com', 50001), ('spv12.lbry.com', 50001),
@ -703,36 +588,29 @@ class Config(CLIConfig):
('spv17.lbry.com', 50001), ('spv17.lbry.com', 50001),
('spv18.lbry.com', 50001), ('spv18.lbry.com', 50001),
('spv19.lbry.com', 50001), ('spv19.lbry.com', 50001),
('hub.lbry.grin.io', 50001),
('hub.lizard.technology', 50001),
('s1.lbry.network', 50001),
]) ])
known_dht_nodes = Servers("Known nodes for bootstrapping connection to the DHT", [ known_dht_nodes = Servers("Known nodes for bootstrapping connection to the DHT", [
('dht.lbry.grin.io', 4444), # Grin
('dht.lbry.madiator.com', 4444), # Madiator
('dht.lbry.pigg.es', 4444), # Pigges
('lbrynet1.lbry.com', 4444), # US EAST ('lbrynet1.lbry.com', 4444), # US EAST
('lbrynet2.lbry.com', 4444), # US WEST ('lbrynet2.lbry.com', 4444), # US WEST
('lbrynet3.lbry.com', 4444), # EU ('lbrynet3.lbry.com', 4444), # EU
('lbrynet4.lbry.com', 4444), # ASIA ('lbrynet4.lbry.com', 4444) # ASIA
('dht.lizard.technology', 4444), # Jack
('s2.lbry.network', 4444),
]) ])
comment_server = String("Comment server API URL", "https://comments.lbry.com/api")
# blockchain # blockchain
blockchain_name = String("Blockchain name - lbrycrd_main, lbrycrd_regtest, or lbrycrd_testnet", 'lbrycrd_main') blockchain_name = String("Blockchain name - lbrycrd_main, lbrycrd_regtest, or lbrycrd_testnet", 'lbrycrd_main')
s3_headers_depth = Integer("download headers from s3 when the local height is more than 10 chunks behind", 96 * 10)
cache_time = Integer("Time to cache resolved claims", 150) # TODO: use this
# daemon # daemon
save_files = Toggle("Save downloaded files when calling `get` by default", False) save_files = Toggle("Save downloaded files when calling `get` by default", True)
components_to_skip = Strings("components which will be skipped during start-up of daemon", []) components_to_skip = Strings("components which will be skipped during start-up of daemon", [])
share_usage_data = Toggle( share_usage_data = Toggle(
"Whether to share usage stats and diagnostic info with LBRY.", False, "Whether to share usage stats and diagnostic info with LBRY.", False,
previous_names=['upload_log', 'upload_log', 'share_debug_info'] previous_names=['upload_log', 'upload_log', 'share_debug_info']
) )
track_bandwidth = Toggle("Track bandwidth usage", True) track_bandwidth = Toggle("Track bandwidth usage", True)
allowed_origin = String(
"Allowed `Origin` header value for API request (sent by browser), use * to allow "
"all hosts; default is to only allow API requests with no `Origin` value.", "")
# media server # media server
streaming_server = String('Host name and port to serve streaming media over range requests', streaming_server = String('Host name and port to serve streaming media over range requests',
@ -742,10 +620,8 @@ class Config(CLIConfig):
coin_selection_strategy = StringChoice( coin_selection_strategy = StringChoice(
"Strategy to use when selecting UTXOs for a transaction", "Strategy to use when selecting UTXOs for a transaction",
STRATEGIES, "prefer_confirmed" STRATEGIES, "standard")
)
transaction_cache_size = Integer("Transaction cache size", 2 ** 17)
save_resolved_claims = Toggle( save_resolved_claims = Toggle(
"Save content claims to the database when they are resolved to keep file_list up to date, " "Save content claims to the database when they are resolved to keep file_list up to date, "
"only disable this if file_x commands are not needed", True "only disable this if file_x commands are not needed", True
@ -762,7 +638,6 @@ class Config(CLIConfig):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.set_default_paths() self.set_default_paths()
self.known_hubs = KnownHubsList(self)
def set_default_paths(self): def set_default_paths(self):
if 'darwin' in sys.platform.lower(): if 'darwin' in sys.platform.lower():
@ -784,7 +659,7 @@ class Config(CLIConfig):
return os.path.join(self.data_dir, 'lbrynet.log') return os.path.join(self.data_dir, 'lbrynet.log')
def get_windows_directories() -> Tuple[str, str, str]: def get_windows_directories() -> typing.Tuple[str, str, str]:
from lbry.winpaths import get_path, FOLDERID, UserHandle, \ from lbry.winpaths import get_path, FOLDERID, UserHandle, \
PathNotFoundException # pylint: disable=import-outside-toplevel PathNotFoundException # pylint: disable=import-outside-toplevel
@ -806,19 +681,18 @@ def get_windows_directories() -> Tuple[str, str, str]:
return data_dir, lbryum_dir, download_dir return data_dir, lbryum_dir, download_dir
def get_darwin_directories() -> Tuple[str, str, str]: def get_darwin_directories() -> typing.Tuple[str, str, str]:
data_dir = user_data_dir('LBRY') data_dir = user_data_dir('LBRY')
lbryum_dir = os.path.expanduser('~/.lbryum') lbryum_dir = os.path.expanduser('~/.lbryum')
download_dir = os.path.expanduser('~/Downloads') download_dir = os.path.expanduser('~/Downloads')
return data_dir, lbryum_dir, download_dir return data_dir, lbryum_dir, download_dir
def get_linux_directories() -> Tuple[str, str, str]: def get_linux_directories() -> typing.Tuple[str, str, str]:
try: try:
with open(os.path.join(user_config_dir(), 'user-dirs.dirs'), 'r') as xdg: with open(os.path.join(user_config_dir(), 'user-dirs.dirs'), 'r') as xdg:
down_dir = re.search(r'XDG_DOWNLOAD_DIR=(.+)', xdg.read()) down_dir = re.search(r'XDG_DOWNLOAD_DIR=(.+)', xdg.read()).group(1)
if down_dir: down_dir = re.sub(r'\$HOME', os.getenv('HOME') or os.path.expanduser("~/"), down_dir)
down_dir = re.sub(r'\$HOME', os.getenv('HOME') or os.path.expanduser("~/"), down_dir.group(1))
download_dir = re.sub('\"', '', down_dir) download_dir = re.sub('\"', '', down_dir)
except OSError: except OSError:
download_dir = os.getenv('XDG_DOWNLOAD_DIR') download_dir = os.getenv('XDG_DOWNLOAD_DIR')

View file

@ -67,7 +67,7 @@ class ConnectionManager:
while True: while True:
last = time.perf_counter() last = time.perf_counter()
await asyncio.sleep(0.1) await asyncio.sleep(0.1, loop=self.loop)
self._status['incoming_bps'].clear() self._status['incoming_bps'].clear()
self._status['outgoing_bps'].clear() self._status['outgoing_bps'].clear()
now = time.perf_counter() now = time.perf_counter()

View file

@ -1,9 +1,6 @@
import asyncio import asyncio
import typing import typing
import logging import logging
from prometheus_client import Counter, Gauge
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from lbry.dht.node import Node from lbry.dht.node import Node
from lbry.extras.daemon.storage import SQLiteStorage from lbry.extras.daemon.storage import SQLiteStorage
@ -12,59 +9,45 @@ log = logging.getLogger(__name__)
class BlobAnnouncer: class BlobAnnouncer:
announcements_sent_metric = Counter(
"announcements_sent", "Number of announcements sent and their respective status.", namespace="dht_node",
labelnames=("peers", "error"),
)
announcement_queue_size_metric = Gauge(
"announcement_queue_size", "Number of hashes waiting to be announced.", namespace="dht_node",
labelnames=("scope",)
)
def __init__(self, loop: asyncio.AbstractEventLoop, node: 'Node', storage: 'SQLiteStorage'): def __init__(self, loop: asyncio.AbstractEventLoop, node: 'Node', storage: 'SQLiteStorage'):
self.loop = loop self.loop = loop
self.node = node self.node = node
self.storage = storage self.storage = storage
self.announce_task: asyncio.Task = None self.announce_task: asyncio.Task = None
self.announce_queue: typing.List[str] = [] self.announce_queue: typing.List[str] = []
self._done = asyncio.Event()
self.announced = set()
async def _run_consumer(self): async def _submit_announcement(self, blob_hash):
while self.announce_queue:
try: try:
blob_hash = self.announce_queue.pop()
peers = len(await self.node.announce_blob(blob_hash)) peers = len(await self.node.announce_blob(blob_hash))
self.announcements_sent_metric.labels(peers=peers, error=False).inc()
if peers > 4: if peers > 4:
self.announced.add(blob_hash) return blob_hash
else: else:
log.debug("failed to announce %s, could only find %d peers, retrying soon.", blob_hash[:8], peers) log.debug("failed to announce %s, could only find %d peers, retrying soon.", blob_hash[:8], peers)
except Exception as err: except Exception as err:
self.announcements_sent_metric.labels(peers=0, error=True).inc() if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise err
log.warning("error announcing %s: %s", blob_hash[:8], str(err)) log.warning("error announcing %s: %s", blob_hash[:8], str(err))
async def _announce(self, batch_size: typing.Optional[int] = 10): async def _announce(self, batch_size: typing.Optional[int] = 10):
while batch_size: while batch_size:
if not self.node.joined.is_set(): if not self.node.joined.is_set():
await self.node.joined.wait() await self.node.joined.wait()
await asyncio.sleep(60) await asyncio.sleep(60, loop=self.loop)
if not self.node.protocol.routing_table.get_peers(): if not self.node.protocol.routing_table.get_peers():
log.warning("No peers in DHT, announce round skipped") log.warning("No peers in DHT, announce round skipped")
continue continue
self.announce_queue.extend(await self.storage.get_blobs_to_announce()) self.announce_queue.extend(await self.storage.get_blobs_to_announce())
self.announcement_queue_size_metric.labels(scope="global").set(len(self.announce_queue))
log.debug("announcer task wake up, %d blobs to announce", len(self.announce_queue)) log.debug("announcer task wake up, %d blobs to announce", len(self.announce_queue))
while len(self.announce_queue) > 0: while len(self.announce_queue) > 0:
log.info("%i blobs to announce", len(self.announce_queue)) log.info("%i blobs to announce", len(self.announce_queue))
await asyncio.gather(*[self._run_consumer() for _ in range(batch_size)]) announced = await asyncio.gather(*[
announced = list(filter(None, self.announced)) self._submit_announcement(
self.announce_queue.pop()) for _ in range(batch_size) if self.announce_queue
], loop=self.loop)
announced = list(filter(None, announced))
if announced: if announced:
await self.storage.update_last_announced_blobs(announced) await self.storage.update_last_announced_blobs(announced)
log.info("announced %i blobs", len(announced)) log.info("announced %i blobs", len(announced))
self.announced.clear()
self._done.set()
self._done.clear()
def start(self, batch_size: typing.Optional[int] = 10): def start(self, batch_size: typing.Optional[int] = 10):
assert not self.announce_task or self.announce_task.done(), "already running" assert not self.announce_task or self.announce_task.done(), "already running"
@ -73,6 +56,3 @@ class BlobAnnouncer:
def stop(self): def stop(self):
if self.announce_task and not self.announce_task.done(): if self.announce_task and not self.announce_task.done():
self.announce_task.cancel() self.announce_task.cancel()
def wait(self):
return self._done.wait()

View file

@ -20,6 +20,7 @@ MAYBE_PING_DELAY = 300 # 5 minutes
CHECK_REFRESH_INTERVAL = REFRESH_INTERVAL / 5 CHECK_REFRESH_INTERVAL = REFRESH_INTERVAL / 5
RPC_ID_LENGTH = 20 RPC_ID_LENGTH = 20
PROTOCOL_VERSION = 1 PROTOCOL_VERSION = 1
BOTTOM_OUT_LIMIT = 3
MSG_SIZE_LIMIT = 1400 MSG_SIZE_LIMIT = 1400

View file

@ -1,11 +1,9 @@
import logging import logging
import asyncio import asyncio
import typing import typing
import binascii
import socket import socket
from lbry.utils import resolve_host
from prometheus_client import Gauge
from lbry.utils import aclosing, resolve_host
from lbry.dht import constants from lbry.dht import constants
from lbry.dht.peer import make_kademlia_peer from lbry.dht.peer import make_kademlia_peer
from lbry.dht.protocol.distance import Distance from lbry.dht.protocol.distance import Distance
@ -20,32 +18,20 @@ log = logging.getLogger(__name__)
class Node: class Node:
storing_peers_metric = Gauge(
"storing_peers", "Number of peers storing blobs announced to this node", namespace="dht_node",
labelnames=("scope",),
)
stored_blob_with_x_bytes_colliding = Gauge(
"stored_blobs_x_bytes_colliding", "Number of blobs with at least X bytes colliding with this node id prefix",
namespace="dht_node", labelnames=("amount",)
)
def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager', node_id: bytes, udp_port: int, def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager', node_id: bytes, udp_port: int,
internal_udp_port: int, peer_port: int, external_ip: str, rpc_timeout: float = constants.RPC_TIMEOUT, internal_udp_port: int, peer_port: int, external_ip: str, rpc_timeout: float = constants.RPC_TIMEOUT,
split_buckets_under_index: int = constants.SPLIT_BUCKETS_UNDER_INDEX, is_bootstrap_node: bool = False, split_buckets_under_index: int = constants.SPLIT_BUCKETS_UNDER_INDEX,
storage: typing.Optional['SQLiteStorage'] = None): storage: typing.Optional['SQLiteStorage'] = None):
self.loop = loop self.loop = loop
self.internal_udp_port = internal_udp_port self.internal_udp_port = internal_udp_port
self.protocol = KademliaProtocol(loop, peer_manager, node_id, external_ip, udp_port, peer_port, rpc_timeout, self.protocol = KademliaProtocol(loop, peer_manager, node_id, external_ip, udp_port, peer_port, rpc_timeout,
split_buckets_under_index, is_bootstrap_node) split_buckets_under_index)
self.listening_port: asyncio.DatagramTransport = None self.listening_port: asyncio.DatagramTransport = None
self.joined = asyncio.Event() self.joined = asyncio.Event(loop=self.loop)
self._join_task: asyncio.Task = None self._join_task: asyncio.Task = None
self._refresh_task: asyncio.Task = None self._refresh_task: asyncio.Task = None
self._storage = storage self._storage = storage
@property
def stored_blob_hashes(self):
return self.protocol.data_store.keys()
async def refresh_node(self, force_once=False): async def refresh_node(self, force_once=False):
while True: while True:
# remove peers with expired blob announcements from the datastore # remove peers with expired blob announcements from the datastore
@ -55,21 +41,17 @@ class Node:
# add all peers in the routing table # add all peers in the routing table
total_peers.extend(self.protocol.routing_table.get_peers()) total_peers.extend(self.protocol.routing_table.get_peers())
# add all the peers who have announced blobs to us # add all the peers who have announced blobs to us
storing_peers = self.protocol.data_store.get_storing_contacts() total_peers.extend(self.protocol.data_store.get_storing_contacts())
self.storing_peers_metric.labels("global").set(len(storing_peers))
total_peers.extend(storing_peers)
counts = {0: 0, 1: 0, 2: 0}
node_id = self.protocol.node_id
for blob_hash in self.protocol.data_store.keys():
bytes_colliding = 0 if blob_hash[0] != node_id[0] else 2 if blob_hash[1] == node_id[1] else 1
counts[bytes_colliding] += 1
self.stored_blob_with_x_bytes_colliding.labels(amount=0).set(counts[0])
self.stored_blob_with_x_bytes_colliding.labels(amount=1).set(counts[1])
self.stored_blob_with_x_bytes_colliding.labels(amount=2).set(counts[2])
# get ids falling in the midpoint of each bucket that hasn't been recently updated # get ids falling in the midpoint of each bucket that hasn't been recently updated
node_ids = self.protocol.routing_table.get_refresh_list(0, True) node_ids = self.protocol.routing_table.get_refresh_list(0, True)
# if we have 3 or fewer populated buckets get two random ids in the range of each to try and
# populate/split the buckets further
buckets_with_contacts = self.protocol.routing_table.buckets_with_contacts()
if buckets_with_contacts <= 3:
for i in range(buckets_with_contacts):
node_ids.append(self.protocol.routing_table.random_id_in_bucket_range(i))
node_ids.append(self.protocol.routing_table.random_id_in_bucket_range(i))
if self.protocol.routing_table.get_peers(): if self.protocol.routing_table.get_peers():
# if we have node ids to look up, perform the iterative search until we have k results # if we have node ids to look up, perform the iterative search until we have k results
@ -79,7 +61,7 @@ class Node:
else: else:
if force_once: if force_once:
break break
fut = asyncio.Future() fut = asyncio.Future(loop=self.loop)
self.loop.call_later(constants.REFRESH_INTERVAL // 4, fut.set_result, None) self.loop.call_later(constants.REFRESH_INTERVAL // 4, fut.set_result, None)
await fut await fut
continue continue
@ -93,12 +75,12 @@ class Node:
if force_once: if force_once:
break break
fut = asyncio.Future() fut = asyncio.Future(loop=self.loop)
self.loop.call_later(constants.REFRESH_INTERVAL, fut.set_result, None) self.loop.call_later(constants.REFRESH_INTERVAL, fut.set_result, None)
await fut await fut
async def announce_blob(self, blob_hash: str) -> typing.List[bytes]: async def announce_blob(self, blob_hash: str) -> typing.List[bytes]:
hash_value = bytes.fromhex(blob_hash) hash_value = binascii.unhexlify(blob_hash.encode())
assert len(hash_value) == constants.HASH_LENGTH assert len(hash_value) == constants.HASH_LENGTH
peers = await self.peer_search(hash_value) peers = await self.peer_search(hash_value)
@ -108,12 +90,12 @@ class Node:
for peer in peers: for peer in peers:
log.debug("store to %s %s %s", peer.address, peer.udp_port, peer.tcp_port) log.debug("store to %s %s %s", peer.address, peer.udp_port, peer.tcp_port)
stored_to_tup = await asyncio.gather( stored_to_tup = await asyncio.gather(
*(self.protocol.store_to_peer(hash_value, peer) for peer in peers) *(self.protocol.store_to_peer(hash_value, peer) for peer in peers), loop=self.loop
) )
stored_to = [node_id for node_id, contacted in stored_to_tup if contacted] stored_to = [node_id for node_id, contacted in stored_to_tup if contacted]
if stored_to: if stored_to:
log.debug( log.debug(
"Stored %s to %i of %i attempted peers", hash_value.hex()[:8], "Stored %s to %i of %i attempted peers", binascii.hexlify(hash_value).decode()[:8],
len(stored_to), len(peers) len(stored_to), len(peers)
) )
else: else:
@ -182,35 +164,38 @@ class Node:
for address, udp_port in known_node_urls or [] for address, udp_port in known_node_urls or []
])) ]))
except socket.gaierror: except socket.gaierror:
await asyncio.sleep(30) await asyncio.sleep(30, loop=self.loop)
continue continue
self.protocol.peer_manager.reset() self.protocol.peer_manager.reset()
self.protocol.ping_queue.enqueue_maybe_ping(*seed_peers, delay=0.0) self.protocol.ping_queue.enqueue_maybe_ping(*seed_peers, delay=0.0)
await self.peer_search(self.protocol.node_id, shortlist=seed_peers, count=32) await self.peer_search(self.protocol.node_id, shortlist=seed_peers, count=32)
await asyncio.sleep(1) await asyncio.sleep(1, loop=self.loop)
def start(self, interface: str, known_node_urls: typing.Optional[typing.List[typing.Tuple[str, int]]] = None): def start(self, interface: str, known_node_urls: typing.Optional[typing.List[typing.Tuple[str, int]]] = None):
self._join_task = self.loop.create_task(self.join_network(interface, known_node_urls)) self._join_task = self.loop.create_task(self.join_network(interface, known_node_urls))
def get_iterative_node_finder(self, key: bytes, shortlist: typing.Optional[typing.List['KademliaPeer']] = None, def get_iterative_node_finder(self, key: bytes, shortlist: typing.Optional[typing.List['KademliaPeer']] = None,
bottom_out_limit: int = constants.BOTTOM_OUT_LIMIT,
max_results: int = constants.K) -> IterativeNodeFinder: max_results: int = constants.K) -> IterativeNodeFinder:
shortlist = shortlist or self.protocol.routing_table.find_close_peers(key)
return IterativeNodeFinder(self.loop, self.protocol, key, max_results, shortlist) return IterativeNodeFinder(self.loop, self.protocol.peer_manager, self.protocol.routing_table, self.protocol,
key, bottom_out_limit, max_results, None, shortlist)
def get_iterative_value_finder(self, key: bytes, shortlist: typing.Optional[typing.List['KademliaPeer']] = None, def get_iterative_value_finder(self, key: bytes, shortlist: typing.Optional[typing.List['KademliaPeer']] = None,
bottom_out_limit: int = 40,
max_results: int = -1) -> IterativeValueFinder: max_results: int = -1) -> IterativeValueFinder:
shortlist = shortlist or self.protocol.routing_table.find_close_peers(key)
return IterativeValueFinder(self.loop, self.protocol, key, max_results, shortlist) return IterativeValueFinder(self.loop, self.protocol.peer_manager, self.protocol.routing_table, self.protocol,
key, bottom_out_limit, max_results, None, shortlist)
async def peer_search(self, node_id: bytes, count=constants.K, max_results=constants.K * 2, async def peer_search(self, node_id: bytes, count=constants.K, max_results=constants.K * 2,
shortlist: typing.Optional[typing.List['KademliaPeer']] = None bottom_out_limit=20, shortlist: typing.Optional[typing.List['KademliaPeer']] = None
) -> typing.List['KademliaPeer']: ) -> typing.List['KademliaPeer']:
peers = [] peers = []
async with aclosing(self.get_iterative_node_finder( async for iteration_peers in self.get_iterative_node_finder(
node_id, shortlist=shortlist, max_results=max_results)) as node_finder: node_id, shortlist=shortlist, bottom_out_limit=bottom_out_limit, max_results=max_results):
async for iteration_peers in node_finder:
peers.extend(iteration_peers) peers.extend(iteration_peers)
distance = Distance(node_id) distance = Distance(node_id)
peers.sort(key=lambda peer: distance(peer.node_id)) peers.sort(key=lambda peer: distance(peer.node_id))
@ -237,8 +222,8 @@ class Node:
# prioritize peers who reply to a dht ping first # prioritize peers who reply to a dht ping first
# this minimizes attempting to make tcp connections that won't work later to dead or unreachable peers # this minimizes attempting to make tcp connections that won't work later to dead or unreachable peers
async with aclosing(self.get_iterative_value_finder(bytes.fromhex(blob_hash))) as value_finder:
async for results in value_finder: async for results in self.get_iterative_value_finder(binascii.unhexlify(blob_hash.encode())):
to_put = [] to_put = []
for peer in results: for peer in results:
if peer.address == self.protocol.external_ip and self.protocol.peer_port == peer.tcp_port: if peer.address == self.protocol.external_ip and self.protocol.peer_port == peer.tcp_port:
@ -271,12 +256,5 @@ class Node:
def accumulate_peers(self, search_queue: asyncio.Queue, def accumulate_peers(self, search_queue: asyncio.Queue,
peer_queue: typing.Optional[asyncio.Queue] = None peer_queue: typing.Optional[asyncio.Queue] = None
) -> typing.Tuple[asyncio.Queue, asyncio.Task]: ) -> typing.Tuple[asyncio.Queue, asyncio.Task]:
queue = peer_queue or asyncio.Queue() queue = peer_queue or asyncio.Queue(loop=self.loop)
return queue, self.loop.create_task(self._accumulate_peers_for_value(search_queue, queue)) return queue, self.loop.create_task(self._accumulate_peers_for_value(search_queue, queue))
async def get_kademlia_peers_from_hosts(peer_list: typing.List[typing.Tuple[str, int]]) -> typing.List['KademliaPeer']:
peer_address_list = [(await resolve_host(url, port, proto='tcp'), port) for url, port in peer_list]
kademlia_peer_list = [make_kademlia_peer(None, address, None, tcp_port=port, allow_localhost=True)
for address, port in peer_address_list]
return kademlia_peer_list

View file

@ -1,21 +1,18 @@
import typing import typing
import asyncio import asyncio
import logging import logging
import ipaddress
from binascii import hexlify
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import lru_cache from functools import lru_cache
from prometheus_client import Gauge
from lbry.utils import is_valid_public_ipv4 as _is_valid_public_ipv4, LRUCache
from lbry.dht import constants from lbry.dht import constants
from lbry.dht.serialization.datagram import make_compact_address, make_compact_ip, decode_compact_address from lbry.dht.serialization.datagram import make_compact_address, make_compact_ip, decode_compact_address
ALLOW_LOCALHOST = False
CACHE_SIZE = 16384
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@lru_cache(CACHE_SIZE) @lru_cache(1024)
def make_kademlia_peer(node_id: typing.Optional[bytes], address: typing.Optional[str], def make_kademlia_peer(node_id: typing.Optional[bytes], address: typing.Optional[str],
udp_port: typing.Optional[int] = None, udp_port: typing.Optional[int] = None,
tcp_port: typing.Optional[int] = None, tcp_port: typing.Optional[int] = None,
@ -23,32 +20,40 @@ def make_kademlia_peer(node_id: typing.Optional[bytes], address: typing.Optional
return KademliaPeer(address, node_id, udp_port, tcp_port=tcp_port, allow_localhost=allow_localhost) return KademliaPeer(address, node_id, udp_port, tcp_port=tcp_port, allow_localhost=allow_localhost)
# the ipaddress module does not show these subnets as reserved
CARRIER_GRADE_NAT_SUBNET = ipaddress.ip_network('100.64.0.0/10')
IPV4_TO_6_RELAY_SUBNET = ipaddress.ip_network('192.88.99.0/24')
ALLOW_LOCALHOST = False
def is_valid_public_ipv4(address, allow_localhost: bool = False): def is_valid_public_ipv4(address, allow_localhost: bool = False):
allow_localhost = bool(allow_localhost or ALLOW_LOCALHOST) allow_localhost = bool(allow_localhost or ALLOW_LOCALHOST)
return _is_valid_public_ipv4(address, allow_localhost) try:
parsed_ip = ipaddress.ip_address(address)
if parsed_ip.is_loopback and allow_localhost:
return True
return not any((parsed_ip.version != 4, parsed_ip.is_unspecified, parsed_ip.is_link_local,
parsed_ip.is_loopback, parsed_ip.is_multicast, parsed_ip.is_reserved, parsed_ip.is_private,
parsed_ip.is_reserved,
CARRIER_GRADE_NAT_SUBNET.supernet_of(ipaddress.ip_network(f"{address}/32")),
IPV4_TO_6_RELAY_SUBNET.supernet_of(ipaddress.ip_network(f"{address}/32"))))
except ipaddress.AddressValueError:
return False
class PeerManager: class PeerManager:
peer_manager_keys_metric = Gauge(
"peer_manager_keys", "Number of keys tracked by PeerManager dicts (sum)", namespace="dht_node",
labelnames=("scope",)
)
def __init__(self, loop: asyncio.AbstractEventLoop): def __init__(self, loop: asyncio.AbstractEventLoop):
self._loop = loop self._loop = loop
self._rpc_failures: typing.Dict[ self._rpc_failures: typing.Dict[
typing.Tuple[str, int], typing.Tuple[typing.Optional[float], typing.Optional[float]] typing.Tuple[str, int], typing.Tuple[typing.Optional[float], typing.Optional[float]]
] = LRUCache(CACHE_SIZE) ] = {}
self._last_replied: typing.Dict[typing.Tuple[str, int], float] = LRUCache(CACHE_SIZE) self._last_replied: typing.Dict[typing.Tuple[str, int], float] = {}
self._last_sent: typing.Dict[typing.Tuple[str, int], float] = LRUCache(CACHE_SIZE) self._last_sent: typing.Dict[typing.Tuple[str, int], float] = {}
self._last_requested: typing.Dict[typing.Tuple[str, int], float] = LRUCache(CACHE_SIZE) self._last_requested: typing.Dict[typing.Tuple[str, int], float] = {}
self._node_id_mapping: typing.Dict[typing.Tuple[str, int], bytes] = LRUCache(CACHE_SIZE) self._node_id_mapping: typing.Dict[typing.Tuple[str, int], bytes] = {}
self._node_id_reverse_mapping: typing.Dict[bytes, typing.Tuple[str, int]] = LRUCache(CACHE_SIZE) self._node_id_reverse_mapping: typing.Dict[bytes, typing.Tuple[str, int]] = {}
self._node_tokens: typing.Dict[bytes, (float, bytes)] = LRUCache(CACHE_SIZE) self._node_tokens: typing.Dict[bytes, (float, bytes)] = {}
def count_cache_keys(self):
return len(self._rpc_failures) + len(self._last_replied) + len(self._last_sent) + len(
self._last_requested) + len(self._node_id_mapping) + len(self._node_id_reverse_mapping) + len(
self._node_tokens)
def reset(self): def reset(self):
for statistic in (self._rpc_failures, self._last_replied, self._last_sent, self._last_requested): for statistic in (self._rpc_failures, self._last_replied, self._last_sent, self._last_requested):
@ -98,10 +103,6 @@ class PeerManager:
self._node_id_mapping.pop(self._node_id_reverse_mapping.pop(node_id)) self._node_id_mapping.pop(self._node_id_reverse_mapping.pop(node_id))
self._node_id_mapping[(address, udp_port)] = node_id self._node_id_mapping[(address, udp_port)] = node_id
self._node_id_reverse_mapping[node_id] = (address, udp_port) self._node_id_reverse_mapping[node_id] = (address, udp_port)
self.peer_manager_keys_metric.labels("global").set(self.count_cache_keys())
def get_node_id_for_endpoint(self, address, port):
return self._node_id_mapping.get((address, port))
def prune(self): # TODO: periodically call this def prune(self): # TODO: periodically call this
now = self._loop.time() now = self._loop.time()
@ -153,8 +154,7 @@ class PeerManager:
def peer_is_good(self, peer: 'KademliaPeer'): def peer_is_good(self, peer: 'KademliaPeer'):
return self.contact_triple_is_good(peer.node_id, peer.address, peer.udp_port) return self.contact_triple_is_good(peer.node_id, peer.address, peer.udp_port)
def decode_tcp_peer_from_compact_address(self, compact_address: bytes) -> 'KademliaPeer': # pylint: disable=no-self-use
def decode_tcp_peer_from_compact_address(compact_address: bytes) -> 'KademliaPeer': # pylint: disable=no-self-use
node_id, address, tcp_port = decode_compact_address(compact_address) node_id, address, tcp_port = decode_compact_address(compact_address)
return make_kademlia_peer(node_id, address, udp_port=None, tcp_port=tcp_port) return make_kademlia_peer(node_id, address, udp_port=None, tcp_port=tcp_port)
@ -171,11 +171,11 @@ class KademliaPeer:
def __post_init__(self): def __post_init__(self):
if self._node_id is not None: if self._node_id is not None:
if not len(self._node_id) == constants.HASH_LENGTH: if not len(self._node_id) == constants.HASH_LENGTH:
raise ValueError("invalid node_id: {}".format(self._node_id.hex())) raise ValueError("invalid node_id: {}".format(hexlify(self._node_id).decode()))
if self.udp_port is not None and not 1024 <= self.udp_port <= 65535: if self.udp_port is not None and not 1 <= self.udp_port <= 65535:
raise ValueError(f"invalid udp port: {self.address}:{self.udp_port}") raise ValueError("invalid udp port")
if self.tcp_port is not None and not 1024 <= self.tcp_port <= 65535: if self.tcp_port is not None and not 1 <= self.tcp_port <= 65535:
raise ValueError(f"invalid tcp port: {self.address}:{self.tcp_port}") raise ValueError("invalid tcp port")
if not is_valid_public_ipv4(self.address, self.allow_localhost): if not is_valid_public_ipv4(self.address, self.allow_localhost):
raise ValueError(f"invalid ip address: '{self.address}'") raise ValueError(f"invalid ip address: '{self.address}'")
@ -194,6 +194,3 @@ class KademliaPeer:
def compact_ip(self): def compact_ip(self):
return make_compact_ip(self.address) return make_compact_ip(self.address)
def __str__(self):
return f"{self.__class__.__name__}({self.node_id.hex()[:8]}@{self.address}:{self.udp_port}-{self.tcp_port})"

View file

@ -16,12 +16,6 @@ class DictDataStore:
self._peer_manager = peer_manager self._peer_manager = peer_manager
self.completed_blobs: typing.Set[str] = set() self.completed_blobs: typing.Set[str] = set()
def keys(self):
return self._data_store.keys()
def __len__(self):
return self._data_store.__len__()
def removed_expired_peers(self): def removed_expired_peers(self):
now = self.loop.time() now = self.loop.time()
keys = list(self._data_store.keys()) keys = list(self._data_store.keys())

View file

@ -1,17 +1,18 @@
import asyncio import asyncio
from binascii import hexlify
from itertools import chain from itertools import chain
from collections import defaultdict, OrderedDict from collections import defaultdict
from collections.abc import AsyncIterator
import typing import typing
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from lbry.dht import constants from lbry.dht import constants
from lbry.dht.error import RemoteException, TransportNotConnected from lbry.dht.error import RemoteException, TransportNotConnected
from lbry.dht.protocol.distance import Distance from lbry.dht.protocol.distance import Distance
from lbry.dht.peer import make_kademlia_peer, decode_tcp_peer_from_compact_address from lbry.dht.peer import make_kademlia_peer
from lbry.dht.serialization.datagram import PAGE_KEY from lbry.dht.serialization.datagram import PAGE_KEY
if TYPE_CHECKING: if TYPE_CHECKING:
from lbry.dht.protocol.routing_table import TreeRoutingTable
from lbry.dht.protocol.protocol import KademliaProtocol from lbry.dht.protocol.protocol import KademliaProtocol
from lbry.dht.peer import PeerManager, KademliaPeer from lbry.dht.peer import PeerManager, KademliaPeer
@ -26,15 +27,6 @@ class FindResponse:
def get_close_triples(self) -> typing.List[typing.Tuple[bytes, str, int]]: def get_close_triples(self) -> typing.List[typing.Tuple[bytes, str, int]]:
raise NotImplementedError() raise NotImplementedError()
def get_close_kademlia_peers(self, peer_info) -> typing.Generator[typing.Iterator['KademliaPeer'], None, None]:
for contact_triple in self.get_close_triples():
node_id, address, udp_port = contact_triple
try:
yield make_kademlia_peer(node_id, address, udp_port)
except ValueError:
log.warning("misbehaving peer %s:%i returned peer with reserved ip %s:%i", peer_info.address,
peer_info.udp_port, address, udp_port)
class FindNodeResponse(FindResponse): class FindNodeResponse(FindResponse):
def __init__(self, key: bytes, close_triples: typing.List[typing.Tuple[bytes, str, int]]): def __init__(self, key: bytes, close_triples: typing.List[typing.Tuple[bytes, str, int]]):
@ -65,33 +57,57 @@ class FindValueResponse(FindResponse):
return [(node_id, address.decode(), port) for node_id, address, port in self.close_triples] return [(node_id, address.decode(), port) for node_id, address, port in self.close_triples]
class IterativeFinder(AsyncIterator): def get_shortlist(routing_table: 'TreeRoutingTable', key: bytes,
def __init__(self, loop: asyncio.AbstractEventLoop, shortlist: typing.Optional[typing.List['KademliaPeer']]) -> typing.List['KademliaPeer']:
protocol: 'KademliaProtocol', key: bytes, """
max_results: typing.Optional[int] = constants.K, If not provided, initialize the shortlist of peers to probe to the (up to) k closest peers in the routing table
:param routing_table: a TreeRoutingTable
:param key: a 48 byte hash
:param shortlist: optional manually provided shortlist, this is done during bootstrapping when there are no
peers in the routing table. During bootstrap the shortlist is set to be the seed nodes.
"""
if len(key) != constants.HASH_LENGTH:
raise ValueError("invalid key length: %i" % len(key))
return shortlist or routing_table.find_close_peers(key)
class IterativeFinder:
def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager',
routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes,
bottom_out_limit: typing.Optional[int] = 2, max_results: typing.Optional[int] = constants.K,
exclude: typing.Optional[typing.List[typing.Tuple[str, int]]] = None,
shortlist: typing.Optional[typing.List['KademliaPeer']] = None): shortlist: typing.Optional[typing.List['KademliaPeer']] = None):
if len(key) != constants.HASH_LENGTH: if len(key) != constants.HASH_LENGTH:
raise ValueError("invalid key length: %i" % len(key)) raise ValueError("invalid key length: %i" % len(key))
self.loop = loop self.loop = loop
self.peer_manager = protocol.peer_manager self.peer_manager = peer_manager
self.routing_table = routing_table
self.protocol = protocol self.protocol = protocol
self.key = key self.key = key
self.max_results = max(constants.K, max_results) self.bottom_out_limit = bottom_out_limit
self.max_results = max_results
self.exclude = exclude or []
self.active: typing.Dict['KademliaPeer', int] = OrderedDict() # peer: distance, sorted self.active: typing.Set['KademliaPeer'] = set()
self.contacted: typing.Set['KademliaPeer'] = set() self.contacted: typing.Set['KademliaPeer'] = set()
self.distance = Distance(key) self.distance = Distance(key)
self.iteration_queue = asyncio.Queue() self.closest_peer: typing.Optional['KademliaPeer'] = None
self.prev_closest_peer: typing.Optional['KademliaPeer'] = None
self.running_probes: typing.Dict['KademliaPeer', asyncio.Task] = {} self.iteration_queue = asyncio.Queue(loop=self.loop)
self.running_probes: typing.Set[asyncio.Task] = set()
self.iteration_count = 0 self.iteration_count = 0
self.bottom_out_count = 0
self.running = False self.running = False
self.tasks: typing.List[asyncio.Task] = [] self.tasks: typing.List[asyncio.Task] = []
for peer in shortlist: self.delayed_calls: typing.List[asyncio.Handle] = []
for peer in get_shortlist(routing_table, key, shortlist):
if peer.node_id: if peer.node_id:
self._add_active(peer, force=True) self._add_active(peer)
else: else:
# seed nodes # seed nodes
self._schedule_probe(peer) self._schedule_probe(peer)
@ -123,79 +139,66 @@ class IterativeFinder(AsyncIterator):
""" """
return [] return []
def _add_active(self, peer, force=False): def _is_closer(self, peer: 'KademliaPeer') -> bool:
if not force and self.peer_manager.peer_is_good(peer) is False: return not self.closest_peer or self.distance.is_closer(peer.node_id, self.closest_peer.node_id)
return
if peer in self.contacted: def _add_active(self, peer):
return
if peer not in self.active and peer.node_id and peer.node_id != self.protocol.node_id: if peer not in self.active and peer.node_id and peer.node_id != self.protocol.node_id:
self.active[peer] = self.distance(peer.node_id) self.active.add(peer)
self.active = OrderedDict(sorted(self.active.items(), key=lambda item: item[1])) if self._is_closer(peer):
self.prev_closest_peer = self.closest_peer
self.closest_peer = peer
async def _handle_probe_result(self, peer: 'KademliaPeer', response: FindResponse): async def _handle_probe_result(self, peer: 'KademliaPeer', response: FindResponse):
self._add_active(peer) self._add_active(peer)
for new_peer in response.get_close_kademlia_peers(peer): for contact_triple in response.get_close_triples():
self._add_active(new_peer) node_id, address, udp_port = contact_triple
try:
self._add_active(make_kademlia_peer(node_id, address, udp_port))
except ValueError:
log.warning("misbehaving peer %s:%i returned peer with reserved ip %s:%i", peer.address,
peer.udp_port, address, udp_port)
self.check_result_ready(response) self.check_result_ready(response)
self._log_state(reason="check result")
def _reset_closest(self, peer):
if peer in self.active:
del self.active[peer]
async def _send_probe(self, peer: 'KademliaPeer'): async def _send_probe(self, peer: 'KademliaPeer'):
try: try:
response = await self.send_probe(peer) response = await self.send_probe(peer)
except asyncio.TimeoutError: except asyncio.TimeoutError:
self._reset_closest(peer) self.active.discard(peer)
return return
except asyncio.CancelledError:
log.debug("%s[%x] cancelled probe",
type(self).__name__, id(self))
raise
except ValueError as err: except ValueError as err:
log.warning(str(err)) log.warning(str(err))
self._reset_closest(peer) self.active.discard(peer)
return return
except TransportNotConnected: except TransportNotConnected:
await self._aclose(reason="not connected") return self.aclose()
return
except RemoteException: except RemoteException:
self._reset_closest(peer)
return return
return await self._handle_probe_result(peer, response) return await self._handle_probe_result(peer, response)
def _search_round(self): async def _search_round(self):
""" """
Send up to constants.alpha (5) probes to closest active peers Send up to constants.alpha (5) probes to closest active peers
""" """
added = 0 added = 0
for index, peer in enumerate(self.active.keys()): to_probe = list(self.active - self.contacted)
if index == 0: to_probe.sort(key=lambda peer: self.distance(self.key))
log.debug("%s[%x] closest to probe: %s", for peer in to_probe:
type(self).__name__, id(self), if added >= constants.ALPHA:
peer.node_id.hex()[:8])
if peer in self.contacted:
continue
if len(self.running_probes) >= constants.ALPHA:
break
if index > (constants.K + len(self.running_probes)):
break break
origin_address = (peer.address, peer.udp_port) origin_address = (peer.address, peer.udp_port)
if origin_address in self.exclude:
continue
if peer.node_id == self.protocol.node_id: if peer.node_id == self.protocol.node_id:
continue continue
if origin_address == (self.protocol.external_ip, self.protocol.udp_port): if origin_address == (self.protocol.external_ip, self.protocol.udp_port):
continue continue
self._schedule_probe(peer) self._schedule_probe(peer)
added += 1 added += 1
log.debug("%s[%x] running %d probes for key %s", log.debug("running %d probes", len(self.running_probes))
type(self).__name__, id(self),
len(self.running_probes), self.key.hex()[:8])
if not added and not self.running_probes: if not added and not self.running_probes:
log.debug("%s[%x] search for %s exhausted", log.debug("search for %s exhausted", hexlify(self.key)[:8])
type(self).__name__, id(self),
self.key.hex()[:8])
self.search_exhausted() self.search_exhausted()
def _schedule_probe(self, peer: 'KademliaPeer'): def _schedule_probe(self, peer: 'KademliaPeer'):
@ -204,24 +207,33 @@ class IterativeFinder(AsyncIterator):
t = self.loop.create_task(self._send_probe(peer)) t = self.loop.create_task(self._send_probe(peer))
def callback(_): def callback(_):
self.running_probes.pop(peer, None) self.running_probes.difference_update({
if self.running: probe for probe in self.running_probes if probe.done() or probe == t
self._search_round() })
if not self.running_probes:
self.tasks.append(self.loop.create_task(self._search_task(0.0)))
t.add_done_callback(callback) t.add_done_callback(callback)
self.running_probes[peer] = t self.running_probes.add(t)
def _log_state(self, reason="?"): async def _search_task(self, delay: typing.Optional[float] = constants.ITERATIVE_LOOKUP_DELAY):
log.debug("%s[%x] [%s] %s: %i active nodes %i contacted %i produced %i queued", try:
type(self).__name__, id(self), self.key.hex()[:8], if self.running:
reason, len(self.active), len(self.contacted), await self._search_round()
self.iteration_count, self.iteration_queue.qsize()) if self.running:
self.delayed_calls.append(self.loop.call_later(delay, self._search))
except (asyncio.CancelledError, StopAsyncIteration, TransportNotConnected):
if self.running:
self.loop.call_soon(self.aclose)
def _search(self):
self.tasks.append(self.loop.create_task(self._search_task()))
def __aiter__(self): def __aiter__(self):
if self.running: if self.running:
raise Exception("already running") raise Exception("already running")
self.running = True self.running = True
self.loop.call_soon(self._search_round) self._search()
return self return self
async def __anext__(self) -> typing.List['KademliaPeer']: async def __anext__(self) -> typing.List['KademliaPeer']:
@ -234,57 +246,47 @@ class IterativeFinder(AsyncIterator):
raise StopAsyncIteration raise StopAsyncIteration
self.iteration_count += 1 self.iteration_count += 1
return result return result
except asyncio.CancelledError: except (asyncio.CancelledError, StopAsyncIteration):
await self._aclose(reason="cancelled") self.loop.call_soon(self.aclose)
raise
except StopAsyncIteration:
await self._aclose(reason="no more results")
raise raise
async def _aclose(self, reason="?"): def aclose(self):
log.debug("%s[%x] [%s] shutdown because %s: %i active nodes %i contacted %i produced %i queued",
type(self).__name__, id(self), self.key.hex()[:8],
reason, len(self.active), len(self.contacted),
self.iteration_count, self.iteration_queue.qsize())
self.running = False self.running = False
self.iteration_queue.put_nowait(None) self.iteration_queue.put_nowait(None)
for task in chain(self.tasks, self.running_probes.values()): for task in chain(self.tasks, self.running_probes, self.delayed_calls):
task.cancel() task.cancel()
self.tasks.clear() self.tasks.clear()
self.running_probes.clear() self.running_probes.clear()
self.delayed_calls.clear()
async def aclose(self):
if self.running:
await self._aclose(reason="aclose")
log.debug("%s[%x] [%s] async close completed",
type(self).__name__, id(self), self.key.hex()[:8])
class IterativeNodeFinder(IterativeFinder): class IterativeNodeFinder(IterativeFinder):
def __init__(self, loop: asyncio.AbstractEventLoop, def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager',
protocol: 'KademliaProtocol', key: bytes, routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes,
max_results: typing.Optional[int] = constants.K, bottom_out_limit: typing.Optional[int] = 2, max_results: typing.Optional[int] = constants.K,
exclude: typing.Optional[typing.List[typing.Tuple[str, int]]] = None,
shortlist: typing.Optional[typing.List['KademliaPeer']] = None): shortlist: typing.Optional[typing.List['KademliaPeer']] = None):
super().__init__(loop, protocol, key, max_results, shortlist) super().__init__(loop, peer_manager, routing_table, protocol, key, bottom_out_limit, max_results, exclude,
shortlist)
self.yielded_peers: typing.Set['KademliaPeer'] = set() self.yielded_peers: typing.Set['KademliaPeer'] = set()
async def send_probe(self, peer: 'KademliaPeer') -> FindNodeResponse: async def send_probe(self, peer: 'KademliaPeer') -> FindNodeResponse:
log.debug("probe %s:%d (%s) for NODE %s", log.debug("probing %s:%d %s", peer.address, peer.udp_port, hexlify(peer.node_id)[:8] if peer.node_id else '')
peer.address, peer.udp_port, peer.node_id.hex()[:8] if peer.node_id else '', self.key.hex()[:8])
response = await self.protocol.get_rpc_peer(peer).find_node(self.key) response = await self.protocol.get_rpc_peer(peer).find_node(self.key)
return FindNodeResponse(self.key, response) return FindNodeResponse(self.key, response)
def search_exhausted(self): def search_exhausted(self):
self.put_result(self.active.keys(), finish=True) self.put_result(self.active, finish=True)
def put_result(self, from_iter: typing.Iterable['KademliaPeer'], finish=False): def put_result(self, from_iter: typing.Iterable['KademliaPeer'], finish=False):
not_yet_yielded = [ not_yet_yielded = [
peer for peer in from_iter peer for peer in from_iter
if peer not in self.yielded_peers if peer not in self.yielded_peers
and peer.node_id != self.protocol.node_id and peer.node_id != self.protocol.node_id
and self.peer_manager.peer_is_good(peer) is True # return only peers who answered and self.peer_manager.peer_is_good(peer) is not False
] ]
not_yet_yielded.sort(key=lambda peer: self.distance(peer.node_id)) not_yet_yielded.sort(key=lambda peer: self.distance(peer.node_id))
to_yield = not_yet_yielded[:max(constants.K, self.max_results)] to_yield = not_yet_yielded[:min(constants.K, len(not_yet_yielded))]
if to_yield: if to_yield:
self.yielded_peers.update(to_yield) self.yielded_peers.update(to_yield)
self.iteration_queue.put_nowait(to_yield) self.iteration_queue.put_nowait(to_yield)
@ -296,15 +298,27 @@ class IterativeNodeFinder(IterativeFinder):
if found: if found:
log.debug("found") log.debug("found")
return self.put_result(self.active.keys(), finish=True) return self.put_result(self.active, finish=True)
if self.prev_closest_peer and self.closest_peer and not self._is_closer(self.prev_closest_peer):
# log.info("improving, %i %i %i %i %i", len(self.shortlist), len(self.active), len(self.contacted),
# self.bottom_out_count, self.iteration_count)
self.bottom_out_count = 0
elif self.prev_closest_peer and self.closest_peer:
self.bottom_out_count += 1
log.info("bottom out %i %i %i", len(self.active), len(self.contacted), self.bottom_out_count)
if self.bottom_out_count >= self.bottom_out_limit or self.iteration_count >= self.bottom_out_limit:
log.info("limit hit")
self.put_result(self.active, True)
class IterativeValueFinder(IterativeFinder): class IterativeValueFinder(IterativeFinder):
def __init__(self, loop: asyncio.AbstractEventLoop, def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager',
protocol: 'KademliaProtocol', key: bytes, routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes,
max_results: typing.Optional[int] = constants.K, bottom_out_limit: typing.Optional[int] = 2, max_results: typing.Optional[int] = constants.K,
exclude: typing.Optional[typing.List[typing.Tuple[str, int]]] = None,
shortlist: typing.Optional[typing.List['KademliaPeer']] = None): shortlist: typing.Optional[typing.List['KademliaPeer']] = None):
super().__init__(loop, protocol, key, max_results, shortlist) super().__init__(loop, peer_manager, routing_table, protocol, key, bottom_out_limit, max_results, exclude,
shortlist)
self.blob_peers: typing.Set['KademliaPeer'] = set() self.blob_peers: typing.Set['KademliaPeer'] = set()
# this tracks the index of the most recent page we requested from each peer # this tracks the index of the most recent page we requested from each peer
self.peer_pages: typing.DefaultDict['KademliaPeer', int] = defaultdict(int) self.peer_pages: typing.DefaultDict['KademliaPeer', int] = defaultdict(int)
@ -312,8 +326,6 @@ class IterativeValueFinder(IterativeFinder):
self.discovered_peers: typing.Dict['KademliaPeer', typing.Set['KademliaPeer']] = defaultdict(set) self.discovered_peers: typing.Dict['KademliaPeer', typing.Set['KademliaPeer']] = defaultdict(set)
async def send_probe(self, peer: 'KademliaPeer') -> FindValueResponse: async def send_probe(self, peer: 'KademliaPeer') -> FindValueResponse:
log.debug("probe %s:%d (%s) for VALUE %s",
peer.address, peer.udp_port, peer.node_id.hex()[:8], self.key.hex()[:8])
page = self.peer_pages[peer] page = self.peer_pages[peer]
response = await self.protocol.get_rpc_peer(peer).find_value(self.key, page=page) response = await self.protocol.get_rpc_peer(peer).find_value(self.key, page=page)
parsed = FindValueResponse(self.key, response) parsed = FindValueResponse(self.key, response)
@ -323,7 +335,7 @@ class IterativeValueFinder(IterativeFinder):
decoded_peers = set() decoded_peers = set()
for compact_addr in parsed.found_compact_addresses: for compact_addr in parsed.found_compact_addresses:
try: try:
decoded_peers.add(decode_tcp_peer_from_compact_address(compact_addr)) decoded_peers.add(self.peer_manager.decode_tcp_peer_from_compact_address(compact_addr))
except ValueError: except ValueError:
log.warning("misbehaving peer %s:%i returned invalid peer for blob", log.warning("misbehaving peer %s:%i returned invalid peer for blob",
peer.address, peer.udp_port) peer.address, peer.udp_port)
@ -335,6 +347,7 @@ class IterativeValueFinder(IterativeFinder):
already_known + len(parsed.found_compact_addresses)) already_known + len(parsed.found_compact_addresses))
if len(self.discovered_peers[peer]) != already_known + len(parsed.found_compact_addresses): if len(self.discovered_peers[peer]) != already_known + len(parsed.found_compact_addresses):
log.warning("misbehaving peer %s:%i returned duplicate peers for blob", peer.address, peer.udp_port) log.warning("misbehaving peer %s:%i returned duplicate peers for blob", peer.address, peer.udp_port)
parsed.found_compact_addresses.clear()
elif len(parsed.found_compact_addresses) >= constants.K and self.peer_pages[peer] < parsed.pages: elif len(parsed.found_compact_addresses) >= constants.K and self.peer_pages[peer] < parsed.pages:
# the peer returned a full page and indicates it has more # the peer returned a full page and indicates it has more
self.peer_pages[peer] += 1 self.peer_pages[peer] += 1
@ -345,15 +358,26 @@ class IterativeValueFinder(IterativeFinder):
def check_result_ready(self, response: FindValueResponse): def check_result_ready(self, response: FindValueResponse):
if response.found: if response.found:
blob_peers = [decode_tcp_peer_from_compact_address(compact_addr) blob_peers = [self.peer_manager.decode_tcp_peer_from_compact_address(compact_addr)
for compact_addr in response.found_compact_addresses] for compact_addr in response.found_compact_addresses]
to_yield = [] to_yield = []
self.bottom_out_count = 0
for blob_peer in blob_peers: for blob_peer in blob_peers:
if blob_peer not in self.blob_peers: if blob_peer not in self.blob_peers:
self.blob_peers.add(blob_peer) self.blob_peers.add(blob_peer)
to_yield.append(blob_peer) to_yield.append(blob_peer)
if to_yield: if to_yield:
# log.info("found %i new peers for blob", len(to_yield))
self.iteration_queue.put_nowait(to_yield) self.iteration_queue.put_nowait(to_yield)
# if self.max_results and len(self.blob_peers) >= self.max_results:
# log.info("enough blob peers found")
# if not self.finished.is_set():
# self.finished.set()
elif self.prev_closest_peer and self.closest_peer:
self.bottom_out_count += 1
if self.bottom_out_count >= self.bottom_out_limit:
log.info("blob peer search bottomed out")
self.iteration_queue.put_nowait(None)
def get_initial_result(self) -> typing.List['KademliaPeer']: def get_initial_result(self) -> typing.List['KademliaPeer']:
if self.protocol.data_store.has_peers_for_blob(self.key): if self.protocol.data_store.has_peers_for_blob(self.key):

View file

@ -3,16 +3,13 @@ import socket
import functools import functools
import hashlib import hashlib
import asyncio import asyncio
import time
import typing import typing
import binascii
import random import random
from asyncio.protocols import DatagramProtocol from asyncio.protocols import DatagramProtocol
from asyncio.transports import DatagramTransport from asyncio.transports import DatagramTransport
from prometheus_client import Gauge, Counter, Histogram
from lbry.dht import constants from lbry.dht import constants
from lbry.dht.serialization.bencoding import DecodeError
from lbry.dht.serialization.datagram import decode_datagram, ErrorDatagram, ResponseDatagram, RequestDatagram from lbry.dht.serialization.datagram import decode_datagram, ErrorDatagram, ResponseDatagram, RequestDatagram
from lbry.dht.serialization.datagram import RESPONSE_TYPE, ERROR_TYPE, PAGE_KEY from lbry.dht.serialization.datagram import RESPONSE_TYPE, ERROR_TYPE, PAGE_KEY
from lbry.dht.error import RemoteException, TransportNotConnected from lbry.dht.error import RemoteException, TransportNotConnected
@ -33,11 +30,6 @@ OLD_PROTOCOL_ERRORS = {
class KademliaRPC: class KademliaRPC:
stored_blob_metric = Gauge(
"stored_blobs", "Number of blobs announced by other peers", namespace="dht_node",
labelnames=("scope",),
)
def __init__(self, protocol: 'KademliaProtocol', loop: asyncio.AbstractEventLoop, peer_port: int = 3333): def __init__(self, protocol: 'KademliaProtocol', loop: asyncio.AbstractEventLoop, peer_port: int = 3333):
self.protocol = protocol self.protocol = protocol
self.loop = loop self.loop = loop
@ -69,7 +61,6 @@ class KademliaRPC:
self.protocol.data_store.add_peer_to_blob( self.protocol.data_store.add_peer_to_blob(
rpc_contact, blob_hash rpc_contact, blob_hash
) )
self.stored_blob_metric.labels("global").set(len(self.protocol.data_store))
return b'OK' return b'OK'
def find_node(self, rpc_contact: 'KademliaPeer', key: bytes) -> typing.List[typing.Tuple[bytes, str, int]]: def find_node(self, rpc_contact: 'KademliaPeer', key: bytes) -> typing.List[typing.Tuple[bytes, str, int]]:
@ -105,7 +96,7 @@ class KademliaRPC:
if not rpc_contact.tcp_port or peer.compact_address_tcp() != rpc_contact.compact_address_tcp() if not rpc_contact.tcp_port or peer.compact_address_tcp() != rpc_contact.compact_address_tcp()
] ]
# if we don't have k storing peers to return and we have this hash locally, include our contact information # if we don't have k storing peers to return and we have this hash locally, include our contact information
if len(peers) < constants.K and key.hex() in self.protocol.data_store.completed_blobs: if len(peers) < constants.K and binascii.hexlify(key).decode() in self.protocol.data_store.completed_blobs:
peers.append(self.compact_address()) peers.append(self.compact_address())
if not peers: if not peers:
response[PAGE_KEY] = 0 response[PAGE_KEY] = 0
@ -218,10 +209,6 @@ class PingQueue:
def running(self): def running(self):
return self._running return self._running
@property
def busy(self):
return self._running and (any(self._running_pings) or any(self._pending_contacts))
def enqueue_maybe_ping(self, *peers: 'KademliaPeer', delay: typing.Optional[float] = None): def enqueue_maybe_ping(self, *peers: 'KademliaPeer', delay: typing.Optional[float] = None):
delay = delay if delay is not None else self._default_delay delay = delay if delay is not None else self._default_delay
now = self._loop.time() now = self._loop.time()
@ -233,7 +220,7 @@ class PingQueue:
async def ping_task(): async def ping_task():
try: try:
if self._protocol.peer_manager.peer_is_good(peer): if self._protocol.peer_manager.peer_is_good(peer):
if not self._protocol.routing_table.get_peer(peer.node_id): if peer not in self._protocol.routing_table.get_peers():
self._protocol.add_peer(peer) self._protocol.add_peer(peer)
return return
await self._protocol.get_rpc_peer(peer).ping() await self._protocol.get_rpc_peer(peer).ping()
@ -253,7 +240,7 @@ class PingQueue:
del self._pending_contacts[peer] del self._pending_contacts[peer]
self.maybe_ping(peer) self.maybe_ping(peer)
break break
await asyncio.sleep(1) await asyncio.sleep(1, loop=self._loop)
def start(self): def start(self):
assert not self._running assert not self._running
@ -272,33 +259,9 @@ class PingQueue:
class KademliaProtocol(DatagramProtocol): class KademliaProtocol(DatagramProtocol):
request_sent_metric = Counter(
"request_sent", "Number of requests send from DHT RPC protocol", namespace="dht_node",
labelnames=("method",),
)
request_success_metric = Counter(
"request_success", "Number of successful requests", namespace="dht_node",
labelnames=("method",),
)
request_error_metric = Counter(
"request_error", "Number of errors returned from request to other peers", namespace="dht_node",
labelnames=("method",),
)
HISTOGRAM_BUCKETS = (
.005, .01, .025, .05, .075, .1, .25, .5, .75, 1.0, 2.5, 3.0, 3.5, 4.0, 4.50, 5.0, 5.50, 6.0, float('inf')
)
response_time_metric = Histogram(
"response_time", "Response times of DHT RPC requests", namespace="dht_node", buckets=HISTOGRAM_BUCKETS,
labelnames=("method",)
)
received_request_metric = Counter(
"received_request", "Number of received DHT RPC requests", namespace="dht_node",
labelnames=("method",),
)
def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager', node_id: bytes, external_ip: str, def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager', node_id: bytes, external_ip: str,
udp_port: int, peer_port: int, rpc_timeout: float = constants.RPC_TIMEOUT, udp_port: int, peer_port: int, rpc_timeout: float = constants.RPC_TIMEOUT,
split_buckets_under_index: int = constants.SPLIT_BUCKETS_UNDER_INDEX, is_boostrap_node: bool = False): split_buckets_under_index: int = constants.SPLIT_BUCKETS_UNDER_INDEX):
self.peer_manager = peer_manager self.peer_manager = peer_manager
self.loop = loop self.loop = loop
self.node_id = node_id self.node_id = node_id
@ -313,16 +276,15 @@ class KademliaProtocol(DatagramProtocol):
self.transport: DatagramTransport = None self.transport: DatagramTransport = None
self.old_token_secret = constants.generate_id() self.old_token_secret = constants.generate_id()
self.token_secret = constants.generate_id() self.token_secret = constants.generate_id()
self.routing_table = TreeRoutingTable( self.routing_table = TreeRoutingTable(self.loop, self.peer_manager, self.node_id, split_buckets_under_index)
self.loop, self.peer_manager, self.node_id, split_buckets_under_index, is_bootstrap_node=is_boostrap_node)
self.data_store = DictDataStore(self.loop, self.peer_manager) self.data_store = DictDataStore(self.loop, self.peer_manager)
self.ping_queue = PingQueue(self.loop, self) self.ping_queue = PingQueue(self.loop, self)
self.node_rpc = KademliaRPC(self, self.loop, self.peer_port) self.node_rpc = KademliaRPC(self, self.loop, self.peer_port)
self.rpc_timeout = rpc_timeout self.rpc_timeout = rpc_timeout
self._split_lock = asyncio.Lock() self._split_lock = asyncio.Lock(loop=self.loop)
self._to_remove: typing.Set['KademliaPeer'] = set() self._to_remove: typing.Set['KademliaPeer'] = set()
self._to_add: typing.Set['KademliaPeer'] = set() self._to_add: typing.Set['KademliaPeer'] = set()
self._wakeup_routing_task = asyncio.Event() self._wakeup_routing_task = asyncio.Event(loop=self.loop)
self.maintaing_routing_task: typing.Optional[asyncio.Task] = None self.maintaing_routing_task: typing.Optional[asyncio.Task] = None
@functools.lru_cache(128) @functools.lru_cache(128)
@ -361,10 +323,72 @@ class KademliaProtocol(DatagramProtocol):
return args, {} return args, {}
async def _add_peer(self, peer: 'KademliaPeer'): async def _add_peer(self, peer: 'KademliaPeer'):
async def probe(some_peer: 'KademliaPeer'): if not peer.node_id:
rpc_peer = self.get_rpc_peer(some_peer) log.warning("Tried adding a peer with no node id!")
await rpc_peer.ping() return False
return await self.routing_table.add_peer(peer, probe) for my_peer in self.routing_table.get_peers():
if (my_peer.address, my_peer.udp_port) == (peer.address, peer.udp_port) and my_peer.node_id != peer.node_id:
self.routing_table.remove_peer(my_peer)
self.routing_table.join_buckets()
bucket_index = self.routing_table.kbucket_index(peer.node_id)
if self.routing_table.buckets[bucket_index].add_peer(peer):
return True
# The bucket is full; see if it can be split (by checking if its range includes the host node's node_id)
if self.routing_table.should_split(bucket_index, peer.node_id):
self.routing_table.split_bucket(bucket_index)
# Retry the insertion attempt
result = await self._add_peer(peer)
self.routing_table.join_buckets()
return result
else:
# We can't split the k-bucket
#
# The 13 page kademlia paper specifies that the least recently contacted node in the bucket
# shall be pinged. If it fails to reply it is replaced with the new contact. If the ping is successful
# the new contact is ignored and not added to the bucket (sections 2.2 and 2.4).
#
# A reasonable extension to this is BEP 0005, which extends the above:
#
# Not all nodes that we learn about are equal. Some are "good" and some are not.
# Many nodes using the DHT are able to send queries and receive responses,
# but are not able to respond to queries from other nodes. It is important that
# each node's routing table must contain only known good nodes. A good node is
# a node has responded to one of our queries within the last 15 minutes. A node
# is also good if it has ever responded to one of our queries and has sent us a
# query within the last 15 minutes. After 15 minutes of inactivity, a node becomes
# questionable. Nodes become bad when they fail to respond to multiple queries
# in a row. Nodes that we know are good are given priority over nodes with unknown status.
#
# When there are bad or questionable nodes in the bucket, the least recent is selected for
# potential replacement (BEP 0005). When all nodes in the bucket are fresh, the head (least recent)
# contact is selected as described in section 2.2 of the kademlia paper. In both cases the new contact
# is ignored if the pinged node replies.
not_good_contacts = self.routing_table.buckets[bucket_index].get_bad_or_unknown_peers()
not_recently_replied = []
for my_peer in not_good_contacts:
last_replied = self.peer_manager.get_last_replied(my_peer.address, my_peer.udp_port)
if not last_replied or last_replied + 60 < self.loop.time():
not_recently_replied.append(my_peer)
if not_recently_replied:
to_replace = not_recently_replied[0]
else:
to_replace = self.routing_table.buckets[bucket_index].peers[0]
last_replied = self.peer_manager.get_last_replied(to_replace.address, to_replace.udp_port)
if last_replied and last_replied + 60 > self.loop.time():
return False
log.debug("pinging %s:%s", to_replace.address, to_replace.udp_port)
try:
to_replace_rpc = self.get_rpc_peer(to_replace)
await to_replace_rpc.ping()
return False
except asyncio.TimeoutError:
log.debug("Replacing dead contact in bucket %i: %s:%i with %s:%i ", bucket_index,
to_replace.address, to_replace.udp_port, peer.address, peer.udp_port)
if to_replace in self.routing_table.buckets[bucket_index]:
self.routing_table.buckets[bucket_index].remove_peer(to_replace)
return await self._add_peer(peer)
def add_peer(self, peer: 'KademliaPeer'): def add_peer(self, peer: 'KademliaPeer'):
if peer.node_id == self.node_id: if peer.node_id == self.node_id:
@ -382,15 +406,16 @@ class KademliaProtocol(DatagramProtocol):
async with self._split_lock: async with self._split_lock:
peer = self._to_remove.pop() peer = self._to_remove.pop()
self.routing_table.remove_peer(peer) self.routing_table.remove_peer(peer)
self.routing_table.join_buckets()
while self._to_add: while self._to_add:
async with self._split_lock: async with self._split_lock:
await self._add_peer(self._to_add.pop()) await self._add_peer(self._to_add.pop())
await asyncio.gather(self._wakeup_routing_task.wait(), asyncio.sleep(.1)) await asyncio.gather(self._wakeup_routing_task.wait(), asyncio.sleep(.1, loop=self.loop), loop=self.loop)
self._wakeup_routing_task.clear() self._wakeup_routing_task.clear()
def _handle_rpc(self, sender_contact: 'KademliaPeer', message: RequestDatagram): def _handle_rpc(self, sender_contact: 'KademliaPeer', message: RequestDatagram):
assert sender_contact.node_id != self.node_id, (sender_contact.node_id.hex()[:8], assert sender_contact.node_id != self.node_id, (binascii.hexlify(sender_contact.node_id)[:8].decode(),
self.node_id.hex()[:8]) binascii.hexlify(self.node_id)[:8].decode())
method = message.method method = message.method
if method not in [b'ping', b'store', b'findNode', b'findValue']: if method not in [b'ping', b'store', b'findNode', b'findValue']:
raise AttributeError('Invalid method: %s' % message.method.decode()) raise AttributeError('Invalid method: %s' % message.method.decode())
@ -422,15 +447,11 @@ class KademliaProtocol(DatagramProtocol):
def handle_request_datagram(self, address: typing.Tuple[str, int], request_datagram: RequestDatagram): def handle_request_datagram(self, address: typing.Tuple[str, int], request_datagram: RequestDatagram):
# This is an RPC method request # This is an RPC method request
self.received_request_metric.labels(method=request_datagram.method).inc()
self.peer_manager.report_last_requested(address[0], address[1]) self.peer_manager.report_last_requested(address[0], address[1])
peer = self.routing_table.get_peer(request_datagram.node_id)
if not peer:
try: try:
peer = self.routing_table.get_peer(request_datagram.node_id)
except IndexError:
peer = make_kademlia_peer(request_datagram.node_id, address[0], address[1]) peer = make_kademlia_peer(request_datagram.node_id, address[0], address[1])
except ValueError as err:
log.warning("error replying to %s: %s", address[0], str(err))
return
try: try:
self._handle_rpc(peer, request_datagram) self._handle_rpc(peer, request_datagram)
# if the contact is not known to be bad (yet) and we haven't yet queried it, send it a ping so that it # if the contact is not known to be bad (yet) and we haven't yet queried it, send it a ping so that it
@ -530,12 +551,12 @@ class KademliaProtocol(DatagramProtocol):
address[0], address[1], OLD_PROTOCOL_ERRORS[error_datagram.response] address[0], address[1], OLD_PROTOCOL_ERRORS[error_datagram.response]
) )
def datagram_received(self, datagram: bytes, address: typing.Tuple[str, int]) -> None: # pylint: disable=arguments-renamed def datagram_received(self, datagram: bytes, address: typing.Tuple[str, int]) -> None: # pylint: disable=arguments-differ
try: try:
message = decode_datagram(datagram) message = decode_datagram(datagram)
except (ValueError, TypeError, DecodeError): except (ValueError, TypeError):
self.peer_manager.report_failure(address[0], address[1]) self.peer_manager.report_failure(address[0], address[1])
log.warning("Couldn't decode dht datagram from %s: %s", address, datagram.hex()) log.warning("Couldn't decode dht datagram from %s: %s", address, binascii.hexlify(datagram).decode())
return return
if isinstance(message, RequestDatagram): if isinstance(message, RequestDatagram):
@ -550,19 +571,14 @@ class KademliaProtocol(DatagramProtocol):
self._send(peer, request) self._send(peer, request)
response_fut = self.sent_messages[request.rpc_id][1] response_fut = self.sent_messages[request.rpc_id][1]
try: try:
self.request_sent_metric.labels(method=request.method).inc()
start = time.perf_counter()
response = await asyncio.wait_for(response_fut, self.rpc_timeout) response = await asyncio.wait_for(response_fut, self.rpc_timeout)
self.response_time_metric.labels(method=request.method).observe(time.perf_counter() - start)
self.peer_manager.report_last_replied(peer.address, peer.udp_port) self.peer_manager.report_last_replied(peer.address, peer.udp_port)
self.request_success_metric.labels(method=request.method).inc()
return response return response
except asyncio.CancelledError: except asyncio.CancelledError:
if not response_fut.done(): if not response_fut.done():
response_fut.cancel() response_fut.cancel()
raise raise
except (asyncio.TimeoutError, RemoteException): except (asyncio.TimeoutError, RemoteException):
self.request_error_metric.labels(method=request.method).inc()
self.peer_manager.report_failure(peer.address, peer.udp_port) self.peer_manager.report_failure(peer.address, peer.udp_port)
if self.peer_manager.peer_is_good(peer) is False: if self.peer_manager.peer_is_good(peer) is False:
self.remove_peer(peer) self.remove_peer(peer)
@ -582,7 +598,7 @@ class KademliaProtocol(DatagramProtocol):
if len(data) > constants.MSG_SIZE_LIMIT: if len(data) > constants.MSG_SIZE_LIMIT:
log.warning("cannot send datagram larger than %i bytes (packet is %i bytes)", log.warning("cannot send datagram larger than %i bytes (packet is %i bytes)",
constants.MSG_SIZE_LIMIT, len(data)) constants.MSG_SIZE_LIMIT, len(data))
log.debug("Packet is too large to send: %s", data[:3500].hex()) log.debug("Packet is too large to send: %s", binascii.hexlify(data[:3500]).decode())
raise ValueError( raise ValueError(
f"cannot send datagram larger than {constants.MSG_SIZE_LIMIT} bytes (packet is {len(data)} bytes)" f"cannot send datagram larger than {constants.MSG_SIZE_LIMIT} bytes (packet is {len(data)} bytes)"
) )
@ -642,13 +658,13 @@ class KademliaProtocol(DatagramProtocol):
res = await self.get_rpc_peer(peer).store(hash_value) res = await self.get_rpc_peer(peer).store(hash_value)
if res != b"OK": if res != b"OK":
raise ValueError(res) raise ValueError(res)
log.debug("Stored %s to %s", hash_value.hex()[:8], peer) log.debug("Stored %s to %s", binascii.hexlify(hash_value).decode()[:8], peer)
return peer.node_id, True return peer.node_id, True
try: try:
return await __store() return await __store()
except asyncio.TimeoutError: except asyncio.TimeoutError:
log.debug("Timeout while storing blob_hash %s at %s", hash_value.hex()[:8], peer) log.debug("Timeout while storing blob_hash %s at %s", binascii.hexlify(hash_value).decode()[:8], peer)
return peer.node_id, False return peer.node_id, False
except ValueError as err: except ValueError as err:
log.error("Unexpected response: %s", err) log.error("Unexpected response: %s", err)

View file

@ -4,11 +4,7 @@ import logging
import typing import typing
import itertools import itertools
from prometheus_client import Gauge
from lbry import utils
from lbry.dht import constants from lbry.dht import constants
from lbry.dht.error import RemoteException
from lbry.dht.protocol.distance import Distance from lbry.dht.protocol.distance import Distance
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from lbry.dht.peer import KademliaPeer, PeerManager from lbry.dht.peer import KademliaPeer, PeerManager
@ -17,20 +13,10 @@ log = logging.getLogger(__name__)
class KBucket: class KBucket:
""" Description - later
""" """
Kademlia K-bucket implementation.
"""
peer_in_routing_table_metric = Gauge(
"peers_in_routing_table", "Number of peers on routing table", namespace="dht_node",
labelnames=("scope",)
)
peer_with_x_bit_colliding_metric = Gauge(
"peer_x_bit_colliding", "Number of peers with at least X bits colliding with this node id",
namespace="dht_node", labelnames=("amount",)
)
def __init__(self, peer_manager: 'PeerManager', range_min: int, range_max: int, def __init__(self, peer_manager: 'PeerManager', range_min: int, range_max: int, node_id: bytes):
node_id: bytes, capacity: int = constants.K):
""" """
@param range_min: The lower boundary for the range in the n-bit ID @param range_min: The lower boundary for the range in the n-bit ID
space covered by this k-bucket space covered by this k-bucket
@ -38,12 +24,12 @@ class KBucket:
covered by this k-bucket covered by this k-bucket
""" """
self._peer_manager = peer_manager self._peer_manager = peer_manager
self.last_accessed = 0
self.range_min = range_min self.range_min = range_min
self.range_max = range_max self.range_max = range_max
self.peers: typing.List['KademliaPeer'] = [] self.peers: typing.List['KademliaPeer'] = []
self._node_id = node_id self._node_id = node_id
self._distance_to_self = Distance(node_id) self._distance_to_self = Distance(node_id)
self.capacity = capacity
def add_peer(self, peer: 'KademliaPeer') -> bool: def add_peer(self, peer: 'KademliaPeer') -> bool:
""" Add contact to _contact list in the right order. This will move the """ Add contact to _contact list in the right order. This will move the
@ -64,25 +50,24 @@ class KBucket:
self.peers.append(peer) self.peers.append(peer)
return True return True
else: else:
for i, _ in enumerate(self.peers): for i in range(len(self.peers)):
local_peer = self.peers[i] local_peer = self.peers[i]
if local_peer.node_id == peer.node_id: if local_peer.node_id == peer.node_id:
self.peers.remove(local_peer) self.peers.remove(local_peer)
self.peers.append(peer) self.peers.append(peer)
return True return True
if len(self.peers) < self.capacity: if len(self.peers) < constants.K:
self.peers.append(peer) self.peers.append(peer)
self.peer_in_routing_table_metric.labels("global").inc()
bits_colliding = utils.get_colliding_prefix_bits(peer.node_id, self._node_id)
self.peer_with_x_bit_colliding_metric.labels(amount=bits_colliding).inc()
return True return True
else: else:
return False return False
# raise BucketFull("No space in bucket to insert contact")
def get_peer(self, node_id: bytes) -> 'KademliaPeer': def get_peer(self, node_id: bytes) -> 'KademliaPeer':
for peer in self.peers: for peer in self.peers:
if peer.node_id == node_id: if peer.node_id == node_id:
return peer return peer
raise IndexError(node_id)
def get_peers(self, count=-1, exclude_contact=None, sort_distance_to=None) -> typing.List['KademliaPeer']: def get_peers(self, count=-1, exclude_contact=None, sort_distance_to=None) -> typing.List['KademliaPeer']:
""" Returns a list containing up to the first count number of contacts """ Returns a list containing up to the first count number of contacts
@ -139,9 +124,6 @@ class KBucket:
def remove_peer(self, peer: 'KademliaPeer') -> None: def remove_peer(self, peer: 'KademliaPeer') -> None:
self.peers.remove(peer) self.peers.remove(peer)
self.peer_in_routing_table_metric.labels("global").dec()
bits_colliding = utils.get_colliding_prefix_bits(peer.node_id, self._node_id)
self.peer_with_x_bit_colliding_metric.labels(amount=bits_colliding).dec()
def key_in_range(self, key: bytes) -> bool: def key_in_range(self, key: bytes) -> bool:
""" Tests whether the specified key (i.e. node ID) is in the range """ Tests whether the specified key (i.e. node ID) is in the range
@ -179,36 +161,24 @@ class TreeRoutingTable:
version of the Kademlia paper, in section 2.4. It does, however, use the version of the Kademlia paper, in section 2.4. It does, however, use the
ping RPC-based k-bucket eviction algorithm described in section 2.2 of ping RPC-based k-bucket eviction algorithm described in section 2.2 of
that paper. that paper.
BOOTSTRAP MODE: if set to True, we always add all peers. This is so a
bootstrap node does not get a bias towards its own node id and replies are
the best it can provide (joining peer knows its neighbors immediately).
Over time, this will need to be optimized so we use the disk as holding
everything in memory won't be feasible anymore.
See: https://github.com/bittorrent/bootstrap-dht
""" """
bucket_in_routing_table_metric = Gauge(
"buckets_in_routing_table", "Number of buckets on routing table", namespace="dht_node",
labelnames=("scope",)
)
def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager', parent_node_id: bytes, def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager', parent_node_id: bytes,
split_buckets_under_index: int = constants.SPLIT_BUCKETS_UNDER_INDEX, is_bootstrap_node: bool = False): split_buckets_under_index: int = constants.SPLIT_BUCKETS_UNDER_INDEX):
self._loop = loop self._loop = loop
self._peer_manager = peer_manager self._peer_manager = peer_manager
self._parent_node_id = parent_node_id self._parent_node_id = parent_node_id
self._split_buckets_under_index = split_buckets_under_index self._split_buckets_under_index = split_buckets_under_index
self.buckets: typing.List[KBucket] = [ self.buckets: typing.List[KBucket] = [
KBucket( KBucket(
self._peer_manager, range_min=0, range_max=2 ** constants.HASH_BITS, node_id=self._parent_node_id, self._peer_manager, range_min=0, range_max=2 ** constants.HASH_BITS, node_id=self._parent_node_id
capacity=1 << 32 if is_bootstrap_node else constants.K
) )
] ]
def get_peers(self) -> typing.List['KademliaPeer']: def get_peers(self) -> typing.List['KademliaPeer']:
return list(itertools.chain.from_iterable(map(lambda bucket: bucket.peers, self.buckets))) return list(itertools.chain.from_iterable(map(lambda bucket: bucket.peers, self.buckets)))
def _should_split(self, bucket_index: int, to_add: bytes) -> bool: def should_split(self, bucket_index: int, to_add: bytes) -> bool:
# https://stackoverflow.com/questions/32129978/highly-unbalanced-kademlia-routing-table/32187456#32187456 # https://stackoverflow.com/questions/32129978/highly-unbalanced-kademlia-routing-table/32187456#32187456
if bucket_index < self._split_buckets_under_index: if bucket_index < self._split_buckets_under_index:
return True return True
@ -233,32 +203,39 @@ class TreeRoutingTable:
return [] return []
def get_peer(self, contact_id: bytes) -> 'KademliaPeer': def get_peer(self, contact_id: bytes) -> 'KademliaPeer':
return self.buckets[self._kbucket_index(contact_id)].get_peer(contact_id) """
@raise IndexError: No contact with the specified contact ID is known
by this node
"""
return self.buckets[self.kbucket_index(contact_id)].get_peer(contact_id)
def get_refresh_list(self, start_index: int = 0, force: bool = False) -> typing.List[bytes]: def get_refresh_list(self, start_index: int = 0, force: bool = False) -> typing.List[bytes]:
bucket_index = start_index
refresh_ids = [] refresh_ids = []
for offset, _ in enumerate(self.buckets[start_index:]): now = int(self._loop.time())
refresh_ids.append(self._midpoint_id_in_bucket_range(start_index + offset)) for bucket in self.buckets[start_index:]:
# if we have 3 or fewer populated buckets get two random ids in the range of each to try and if force or now - bucket.last_accessed >= constants.REFRESH_INTERVAL:
# populate/split the buckets further to_search = self.midpoint_id_in_bucket_range(bucket_index)
buckets_with_contacts = self.buckets_with_contacts() refresh_ids.append(to_search)
if buckets_with_contacts <= 3: bucket_index += 1
for i in range(buckets_with_contacts):
refresh_ids.append(self._random_id_in_bucket_range(i))
refresh_ids.append(self._random_id_in_bucket_range(i))
return refresh_ids return refresh_ids
def remove_peer(self, peer: 'KademliaPeer') -> None: def remove_peer(self, peer: 'KademliaPeer') -> None:
if not peer.node_id: if not peer.node_id:
return return
bucket_index = self._kbucket_index(peer.node_id) bucket_index = self.kbucket_index(peer.node_id)
try: try:
self.buckets[bucket_index].remove_peer(peer) self.buckets[bucket_index].remove_peer(peer)
self._join_buckets()
except ValueError: except ValueError:
return return
def _kbucket_index(self, key: bytes) -> int: def touch_kbucket(self, key: bytes) -> None:
self.touch_kbucket_by_index(self.kbucket_index(key))
def touch_kbucket_by_index(self, bucket_index: int):
self.buckets[bucket_index].last_accessed = int(self._loop.time())
def kbucket_index(self, key: bytes) -> int:
i = 0 i = 0
for bucket in self.buckets: for bucket in self.buckets:
if bucket.key_in_range(key): if bucket.key_in_range(key):
@ -267,19 +244,19 @@ class TreeRoutingTable:
i += 1 i += 1
return i return i
def _random_id_in_bucket_range(self, bucket_index: int) -> bytes: def random_id_in_bucket_range(self, bucket_index: int) -> bytes:
random_id = int(random.randrange(self.buckets[bucket_index].range_min, self.buckets[bucket_index].range_max)) random_id = int(random.randrange(self.buckets[bucket_index].range_min, self.buckets[bucket_index].range_max))
return Distance( return Distance(
self._parent_node_id self._parent_node_id
)(random_id.to_bytes(constants.HASH_LENGTH, 'big')).to_bytes(constants.HASH_LENGTH, 'big') )(random_id.to_bytes(constants.HASH_LENGTH, 'big')).to_bytes(constants.HASH_LENGTH, 'big')
def _midpoint_id_in_bucket_range(self, bucket_index: int) -> bytes: def midpoint_id_in_bucket_range(self, bucket_index: int) -> bytes:
half = int((self.buckets[bucket_index].range_max - self.buckets[bucket_index].range_min) // 2) half = int((self.buckets[bucket_index].range_max - self.buckets[bucket_index].range_min) // 2)
return Distance(self._parent_node_id)( return Distance(self._parent_node_id)(
int(self.buckets[bucket_index].range_min + half).to_bytes(constants.HASH_LENGTH, 'big') int(self.buckets[bucket_index].range_min + half).to_bytes(constants.HASH_LENGTH, 'big')
).to_bytes(constants.HASH_LENGTH, 'big') ).to_bytes(constants.HASH_LENGTH, 'big')
def _split_bucket(self, old_bucket_index: int) -> None: def split_bucket(self, old_bucket_index: int) -> None:
""" Splits the specified k-bucket into two new buckets which together """ Splits the specified k-bucket into two new buckets which together
cover the same range in the key/ID space cover the same range in the key/ID space
@ -302,9 +279,8 @@ class TreeRoutingTable:
# ...and remove them from the old bucket # ...and remove them from the old bucket
for contact in new_bucket.peers: for contact in new_bucket.peers:
old_bucket.remove_peer(contact) old_bucket.remove_peer(contact)
self.bucket_in_routing_table_metric.labels("global").set(len(self.buckets))
def _join_buckets(self): def join_buckets(self):
if len(self.buckets) == 1: if len(self.buckets) == 1:
return return
to_pop = [i for i, bucket in enumerate(self.buckets) if len(bucket) == 0] to_pop = [i for i, bucket in enumerate(self.buckets) if len(bucket) == 0]
@ -326,8 +302,14 @@ class TreeRoutingTable:
elif can_go_higher: elif can_go_higher:
self.buckets[bucket_index_to_pop + 1].range_min = bucket.range_min self.buckets[bucket_index_to_pop + 1].range_min = bucket.range_min
self.buckets.remove(bucket) self.buckets.remove(bucket)
self.bucket_in_routing_table_metric.labels("global").set(len(self.buckets)) return self.join_buckets()
return self._join_buckets()
def contact_in_routing_table(self, address_tuple: typing.Tuple[str, int]) -> bool:
for bucket in self.buckets:
for contact in bucket.get_peers(sort_distance_to=False):
if address_tuple[0] == contact.address and address_tuple[1] == contact.udp_port:
return True
return False
def buckets_with_contacts(self) -> int: def buckets_with_contacts(self) -> int:
count = 0 count = 0
@ -335,70 +317,3 @@ class TreeRoutingTable:
if len(bucket) > 0: if len(bucket) > 0:
count += 1 count += 1
return count return count
async def add_peer(self, peer: 'KademliaPeer', probe: typing.Callable[['KademliaPeer'], typing.Awaitable]):
if not peer.node_id:
log.warning("Tried adding a peer with no node id!")
return False
for my_peer in self.get_peers():
if (my_peer.address, my_peer.udp_port) == (peer.address, peer.udp_port) and my_peer.node_id != peer.node_id:
self.remove_peer(my_peer)
self._join_buckets()
bucket_index = self._kbucket_index(peer.node_id)
if self.buckets[bucket_index].add_peer(peer):
return True
# The bucket is full; see if it can be split (by checking if its range includes the host node's node_id)
if self._should_split(bucket_index, peer.node_id):
self._split_bucket(bucket_index)
# Retry the insertion attempt
result = await self.add_peer(peer, probe)
self._join_buckets()
return result
else:
# We can't split the k-bucket
#
# The 13 page kademlia paper specifies that the least recently contacted node in the bucket
# shall be pinged. If it fails to reply it is replaced with the new contact. If the ping is successful
# the new contact is ignored and not added to the bucket (sections 2.2 and 2.4).
#
# A reasonable extension to this is BEP 0005, which extends the above:
#
# Not all nodes that we learn about are equal. Some are "good" and some are not.
# Many nodes using the DHT are able to send queries and receive responses,
# but are not able to respond to queries from other nodes. It is important that
# each node's routing table must contain only known good nodes. A good node is
# a node has responded to one of our queries within the last 15 minutes. A node
# is also good if it has ever responded to one of our queries and has sent us a
# query within the last 15 minutes. After 15 minutes of inactivity, a node becomes
# questionable. Nodes become bad when they fail to respond to multiple queries
# in a row. Nodes that we know are good are given priority over nodes with unknown status.
#
# When there are bad or questionable nodes in the bucket, the least recent is selected for
# potential replacement (BEP 0005). When all nodes in the bucket are fresh, the head (least recent)
# contact is selected as described in section 2.2 of the kademlia paper. In both cases the new contact
# is ignored if the pinged node replies.
not_good_contacts = self.buckets[bucket_index].get_bad_or_unknown_peers()
not_recently_replied = []
for my_peer in not_good_contacts:
last_replied = self._peer_manager.get_last_replied(my_peer.address, my_peer.udp_port)
if not last_replied or last_replied + 60 < self._loop.time():
not_recently_replied.append(my_peer)
if not_recently_replied:
to_replace = not_recently_replied[0]
else:
to_replace = self.buckets[bucket_index].peers[0]
last_replied = self._peer_manager.get_last_replied(to_replace.address, to_replace.udp_port)
if last_replied and last_replied + 60 > self._loop.time():
return False
log.debug("pinging %s:%s", to_replace.address, to_replace.udp_port)
try:
await probe(to_replace)
return False
except (asyncio.TimeoutError, RemoteException):
log.debug("Replacing dead contact in bucket %i: %s:%i with %s:%i ", bucket_index,
to_replace.address, to_replace.udp_port, peer.address, peer.udp_port)
if to_replace in self.buckets[bucket_index]:
self.buckets[bucket_index].remove_peer(to_replace)
return await self.add_peer(peer, probe)

View file

@ -144,7 +144,7 @@ class ErrorDatagram(KademliaDatagramBase):
self.response = response.decode() self.response = response.decode()
def _decode_datagram(datagram: bytes): def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDatagram, ErrorDatagram]:
msg_types = { msg_types = {
REQUEST_TYPE: RequestDatagram, REQUEST_TYPE: RequestDatagram,
RESPONSE_TYPE: ResponseDatagram, RESPONSE_TYPE: ResponseDatagram,
@ -152,36 +152,26 @@ def _decode_datagram(datagram: bytes):
} }
primitive: typing.Dict = bdecode(datagram) primitive: typing.Dict = bdecode(datagram)
if primitive[0] in [REQUEST_TYPE, ERROR_TYPE, RESPONSE_TYPE]: # pylint: disable=unsubscriptable-object
converted = { datagram_type = primitive[0] # pylint: disable=unsubscriptable-object
str(k).encode() if not isinstance(k, bytes) else k: v for k, v in primitive.items()
}
if converted[b'0'] in [REQUEST_TYPE, ERROR_TYPE, RESPONSE_TYPE]: # pylint: disable=unsubscriptable-object
datagram_type = converted[b'0'] # pylint: disable=unsubscriptable-object
else: else:
raise ValueError("invalid datagram type") raise ValueError("invalid datagram type")
datagram_class = msg_types[datagram_type] datagram_class = msg_types[datagram_type]
decoded = { decoded = {
k: converted[str(i).encode()] # pylint: disable=unsubscriptable-object k: primitive[i] # pylint: disable=unsubscriptable-object
for i, k in enumerate(datagram_class.required_fields) for i, k in enumerate(datagram_class.required_fields)
if str(i).encode() in converted # pylint: disable=unsupported-membership-test if i in primitive # pylint: disable=unsupported-membership-test
} }
for i, _ in enumerate(OPTIONAL_FIELDS): for i, _ in enumerate(OPTIONAL_FIELDS):
if str(i + OPTIONAL_ARG_OFFSET).encode() in converted: if i + OPTIONAL_ARG_OFFSET in primitive:
decoded[i + OPTIONAL_ARG_OFFSET] = converted[str(i + OPTIONAL_ARG_OFFSET).encode()] decoded[i + OPTIONAL_ARG_OFFSET] = primitive[i + OPTIONAL_ARG_OFFSET]
return decoded, datagram_class
def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDatagram, ErrorDatagram]:
decoded, datagram_class = _decode_datagram(datagram)
return datagram_class(**decoded) return datagram_class(**decoded)
def make_compact_ip(address: str) -> bytearray: def make_compact_ip(address: str) -> bytearray:
compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), address.split('.'), bytearray()) compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), address.split('.'), bytearray())
if len(compact_ip) != 4: if len(compact_ip) != 4:
raise ValueError("invalid IPv4 length") raise ValueError(f"invalid IPv4 length")
return compact_ip return compact_ip
@ -190,7 +180,7 @@ def make_compact_address(node_id: bytes, address: str, port: int) -> bytearray:
if not 0 < port < 65536: if not 0 < port < 65536:
raise ValueError(f'Invalid port: {port}') raise ValueError(f'Invalid port: {port}')
if len(node_id) != constants.HASH_BITS // 8: if len(node_id) != constants.HASH_BITS // 8:
raise ValueError("invalid node node_id length") raise ValueError(f"invalid node node_id length")
return compact_ip + port.to_bytes(2, 'big') + node_id return compact_ip + port.to_bytes(2, 'big') + node_id
@ -201,5 +191,5 @@ def decode_compact_address(compact_address: bytes) -> typing.Tuple[bytes, str, i
if not 0 < port < 65536: if not 0 < port < 65536:
raise ValueError(f'Invalid port: {port}') raise ValueError(f'Invalid port: {port}')
if len(node_id) != constants.HASH_BITS // 8: if len(node_id) != constants.HASH_BITS // 8:
raise ValueError("invalid node node_id length") raise ValueError(f"invalid node node_id length")
return node_id, address, port return node_id, address, port

View file

@ -34,11 +34,6 @@ Code | Name | Message
**11x** | InputValue(ValueError) | Invalid argument value provided to command. **11x** | InputValue(ValueError) | Invalid argument value provided to command.
111 | GenericInputValue | The value '{value}' for argument '{argument}' is not valid. 111 | GenericInputValue | The value '{value}' for argument '{argument}' is not valid.
112 | InputValueIsNone | None or null is not valid value for argument '{argument}'. 112 | InputValueIsNone | None or null is not valid value for argument '{argument}'.
113 | ConflictingInputValue | Only '{first_argument}' or '{second_argument}' is allowed, not both.
114 | InputStringIsBlank | {argument} cannot be blank.
115 | EmptyPublishedFile | Cannot publish empty file: {file_path}
116 | MissingPublishedFile | File does not exist: {file_path}
117 | InvalidStreamURL | Invalid LBRY stream URL: '{url}' -- When an URL cannot be downloaded, such as '@Channel/' or a collection
**2xx** | Configuration | Configuration errors. **2xx** | Configuration | Configuration errors.
201 | ConfigWrite | Cannot write configuration file '{path}'. -- When writing the default config fails on startup, such as due to permission issues. 201 | ConfigWrite | Cannot write configuration file '{path}'. -- When writing the default config fails on startup, such as due to permission issues.
202 | ConfigRead | Cannot find provided configuration file '{path}'. -- Can't open the config file user provided via command line args. 202 | ConfigRead | Cannot find provided configuration file '{path}'. -- Can't open the config file user provided via command line args.
@ -56,22 +51,15 @@ Code | Name | Message
405 | ChannelKeyNotFound | Channel signing key not found. 405 | ChannelKeyNotFound | Channel signing key not found.
406 | ChannelKeyInvalid | Channel signing key is out of date. -- For example, channel was updated but you don't have the updated key. 406 | ChannelKeyInvalid | Channel signing key is out of date. -- For example, channel was updated but you don't have the updated key.
407 | DataDownload | Failed to download blob. *generic* 407 | DataDownload | Failed to download blob. *generic*
408 | PrivateKeyNotFound | Couldn't find private key for {key} '{value}'.
410 | Resolve | Failed to resolve '{url}'. 410 | Resolve | Failed to resolve '{url}'.
411 | ResolveTimeout | Failed to resolve '{url}' within the timeout. 411 | ResolveTimeout | Failed to resolve '{url}' within the timeout.
411 | ResolveCensored | Resolve of '{url}' was censored by channel with claim id '{censor_id}'. 411 | ResolveCensored | Resolve of '{url}' was censored by channel with claim id '{claim_id(censor_hash)}'.
420 | KeyFeeAboveMaxAllowed | {message} 420 | KeyFeeAboveMaxAllowed | {message}
421 | InvalidPassword | Password is invalid. 421 | InvalidPassword | Password is invalid.
422 | IncompatibleWalletServer | '{server}:{port}' has an incompatibly old version. 422 | IncompatibleWalletServer | '{server}:{port}' has an incompatibly old version.
423 | TooManyClaimSearchParameters | {key} cant have more than {limit} items.
424 | AlreadyPurchased | You already have a purchase for claim_id '{claim_id_hex}'. Use --allow-duplicate-purchase flag to override.
431 | ServerPaymentInvalidAddress | Invalid address from wallet server: '{address}' - skipping payment round. 431 | ServerPaymentInvalidAddress | Invalid address from wallet server: '{address}' - skipping payment round.
432 | ServerPaymentWalletLocked | Cannot spend funds with locked wallet, skipping payment round. 432 | ServerPaymentWalletLocked | Cannot spend funds with locked wallet, skipping payment round.
433 | ServerPaymentFeeAboveMaxAllowed | Daily server fee of {daily_fee} exceeds maximum configured of {max_fee} LBC. 433 | ServerPaymentFeeAboveMaxAllowed | Daily server fee of {daily_fee} exceeds maximum configured of {max_fee} LBC.
434 | WalletNotLoaded | Wallet {wallet_id} is not loaded.
435 | WalletAlreadyLoaded | Wallet {wallet_path} is already loaded.
436 | WalletNotFound | Wallet not found at {wallet_path}.
437 | WalletAlreadyExists | Wallet {wallet_path} already exists, use `wallet_add` to load it.
**5xx** | Blob | **Blobs** **5xx** | Blob | **Blobs**
500 | BlobNotFound | Blob not found. 500 | BlobNotFound | Blob not found.
501 | BlobPermissionDenied | Permission denied to read blob. 501 | BlobPermissionDenied | Permission denied to read blob.

View file

@ -76,45 +76,6 @@ class InputValueIsNoneError(InputValueError):
super().__init__(f"None or null is not valid value for argument '{argument}'.") super().__init__(f"None or null is not valid value for argument '{argument}'.")
class ConflictingInputValueError(InputValueError):
def __init__(self, first_argument, second_argument):
self.first_argument = first_argument
self.second_argument = second_argument
super().__init__(f"Only '{first_argument}' or '{second_argument}' is allowed, not both.")
class InputStringIsBlankError(InputValueError):
def __init__(self, argument):
self.argument = argument
super().__init__(f"{argument} cannot be blank.")
class EmptyPublishedFileError(InputValueError):
def __init__(self, file_path):
self.file_path = file_path
super().__init__(f"Cannot publish empty file: {file_path}")
class MissingPublishedFileError(InputValueError):
def __init__(self, file_path):
self.file_path = file_path
super().__init__(f"File does not exist: {file_path}")
class InvalidStreamURLError(InputValueError):
"""
When an URL cannot be downloaded, such as '@Channel/' or a collection
"""
def __init__(self, url):
self.url = url
super().__init__(f"Invalid LBRY stream URL: '{url}'")
class ConfigurationError(BaseError): class ConfigurationError(BaseError):
""" """
Configuration errors. Configuration errors.
@ -238,14 +199,6 @@ class DataDownloadError(WalletError):
super().__init__("Failed to download blob. *generic*") super().__init__("Failed to download blob. *generic*")
class PrivateKeyNotFoundError(WalletError):
def __init__(self, key, value):
self.key = key
self.value = value
super().__init__(f"Couldn't find private key for {key} '{value}'.")
class ResolveError(WalletError): class ResolveError(WalletError):
def __init__(self, url): def __init__(self, url):
@ -262,11 +215,10 @@ class ResolveTimeoutError(WalletError):
class ResolveCensoredError(WalletError): class ResolveCensoredError(WalletError):
def __init__(self, url, censor_id, censor_row): def __init__(self, url, censor_hash):
self.url = url self.url = url
self.censor_id = censor_id self.censor_hash = censor_hash
self.censor_row = censor_row super().__init__(f"Resolve of '{url}' was censored by channel with claim id '{claim_id(censor_hash)}'.")
super().__init__(f"Resolve of '{url}' was censored by channel with claim id '{censor_id}'.")
class KeyFeeAboveMaxAllowedError(WalletError): class KeyFeeAboveMaxAllowedError(WalletError):
@ -290,24 +242,6 @@ class IncompatibleWalletServerError(WalletError):
super().__init__(f"'{server}:{port}' has an incompatibly old version.") super().__init__(f"'{server}:{port}' has an incompatibly old version.")
class TooManyClaimSearchParametersError(WalletError):
def __init__(self, key, limit):
self.key = key
self.limit = limit
super().__init__(f"{key} cant have more than {limit} items.")
class AlreadyPurchasedError(WalletError):
"""
allow-duplicate-purchase flag to override.
"""
def __init__(self, claim_id_hex):
self.claim_id_hex = claim_id_hex
super().__init__(f"You already have a purchase for claim_id '{claim_id_hex}'. Use")
class ServerPaymentInvalidAddressError(WalletError): class ServerPaymentInvalidAddressError(WalletError):
def __init__(self, address): def __init__(self, address):
@ -329,34 +263,6 @@ class ServerPaymentFeeAboveMaxAllowedError(WalletError):
super().__init__(f"Daily server fee of {daily_fee} exceeds maximum configured of {max_fee} LBC.") super().__init__(f"Daily server fee of {daily_fee} exceeds maximum configured of {max_fee} LBC.")
class WalletNotLoadedError(WalletError):
def __init__(self, wallet_id):
self.wallet_id = wallet_id
super().__init__(f"Wallet {wallet_id} is not loaded.")
class WalletAlreadyLoadedError(WalletError):
def __init__(self, wallet_path):
self.wallet_path = wallet_path
super().__init__(f"Wallet {wallet_path} is already loaded.")
class WalletNotFoundError(WalletError):
def __init__(self, wallet_path):
self.wallet_path = wallet_path
super().__init__(f"Wallet not found at {wallet_path}.")
class WalletAlreadyExistsError(WalletError):
def __init__(self, wallet_path):
self.wallet_path = wallet_path
super().__init__(f"Wallet {wallet_path} already exists, use `wallet_add` to load it.")
class BlobError(BaseError): class BlobError(BaseError):
""" """
**Blobs** **Blobs**

View file

@ -63,7 +63,7 @@ class ErrorClass:
@staticmethod @staticmethod
def get_fields(args): def get_fields(args):
if len(args) > 1: if len(args) > 1:
return ''.join(f'\n{INDENT*2}self.{field} = {field}' for field in args[1:]) return f''.join(f'\n{INDENT*2}self.{field} = {field}' for field in args[1:])
return '' return ''
@staticmethod @staticmethod

View file

@ -14,6 +14,7 @@ from aiohttp.web import GracefulExit
from docopt import docopt from docopt import docopt
from lbry import __version__ as lbrynet_version from lbry import __version__ as lbrynet_version
from lbry.extras.daemon.loggly_handler import get_loggly_handler
from lbry.extras.daemon.daemon import Daemon from lbry.extras.daemon.daemon import Daemon
from lbry.conf import Config, CLIConfig from lbry.conf import Config, CLIConfig
@ -101,7 +102,7 @@ class ArgumentParser(argparse.ArgumentParser):
self._optionals.title = 'Options' self._optionals.title = 'Options'
if group_name is None: if group_name is None:
self.epilog = ( self.epilog = (
"Run 'lbrynet COMMAND --help' for more information on a command or group." f"Run 'lbrynet COMMAND --help' for more information on a command or group."
) )
else: else:
self.epilog = ( self.epilog = (
@ -226,9 +227,6 @@ def get_argument_parser():
def ensure_directory_exists(path: str): def ensure_directory_exists(path: str):
if not os.path.isdir(path): if not os.path.isdir(path):
pathlib.Path(path).mkdir(parents=True, exist_ok=True) pathlib.Path(path).mkdir(parents=True, exist_ok=True)
use_effective_ids = os.access in os.supports_effective_ids
if not os.access(path, os.W_OK, effective_ids=use_effective_ids):
raise PermissionError(f"The following directory is not writable: {path}")
LOG_MODULES = 'lbry', 'aioupnp' LOG_MODULES = 'lbry', 'aioupnp'
@ -257,6 +255,10 @@ def setup_logging(logger: logging.Logger, args: argparse.Namespace, conf: Config
else: else:
logger.getChild('lbry').setLevel(logging.DEBUG) logger.getChild('lbry').setLevel(logging.DEBUG)
loggly_handler = get_loggly_handler(conf)
loggly_handler.setLevel(logging.ERROR)
logger.getChild('lbry').addHandler(loggly_handler)
def run_daemon(args: argparse.Namespace, conf: Config): def run_daemon(args: argparse.Namespace, conf: Config):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()

View file

@ -18,7 +18,6 @@ DOWNLOAD_STARTED = 'Download Started'
DOWNLOAD_ERRORED = 'Download Errored' DOWNLOAD_ERRORED = 'Download Errored'
DOWNLOAD_FINISHED = 'Download Finished' DOWNLOAD_FINISHED = 'Download Finished'
HEARTBEAT = 'Heartbeat' HEARTBEAT = 'Heartbeat'
DISK_SPACE = 'Disk Space'
CLAIM_ACTION = 'Claim Action' # publish/create/update/abandon CLAIM_ACTION = 'Claim Action' # publish/create/update/abandon
NEW_CHANNEL = 'New Channel' NEW_CHANNEL = 'New Channel'
CREDITS_SENT = 'Credits Sent' CREDITS_SENT = 'Credits Sent'
@ -67,7 +66,7 @@ def _download_properties(conf: Config, external_ip: str, resolve_duration: float
"node_rpc_timeout": conf.node_rpc_timeout, "node_rpc_timeout": conf.node_rpc_timeout,
"peer_connect_timeout": conf.peer_connect_timeout, "peer_connect_timeout": conf.peer_connect_timeout,
"blob_download_timeout": conf.blob_download_timeout, "blob_download_timeout": conf.blob_download_timeout,
"use_fixed_peers": len(conf.fixed_peers) > 0, "use_fixed_peers": len(conf.reflector_servers) > 0,
"fixed_peer_delay": fixed_peer_delay, "fixed_peer_delay": fixed_peer_delay,
"added_fixed_peers": added_fixed_peers, "added_fixed_peers": added_fixed_peers,
"active_peer_count": active_peer_count, "active_peer_count": active_peer_count,
@ -133,7 +132,7 @@ class AnalyticsManager:
async def run(self): async def run(self):
while True: while True:
if self.enabled: if self.enabled:
self.external_ip, _ = await utils.get_external_ip(self.conf.lbryum_servers) self.external_ip = await utils.get_external_ip()
await self._send_heartbeat() await self._send_heartbeat()
await asyncio.sleep(1800) await asyncio.sleep(1800)
@ -170,15 +169,6 @@ class AnalyticsManager:
}) })
) )
async def send_disk_space_used(self, storage_used, storage_limit, is_from_network_quota):
await self.track(
self._event(DISK_SPACE, {
'used': storage_used,
'limit': storage_limit,
'from_network_quota': is_from_network_quota
})
)
async def send_server_startup(self): async def send_server_startup(self):
await self.track(self._event(SERVER_STARTUP)) await self.track(self._event(SERVER_STARTUP))

View file

@ -1,5 +1,5 @@
from lbry.extras.cli import execute_command
from lbry.conf import Config from lbry.conf import Config
from lbry.extras.cli import execute_command
def daemon_rpc(conf: Config, method: str, **kwargs): def daemon_rpc(conf: Config, method: str, **kwargs):

View file

@ -0,0 +1,66 @@
import logging
import time
import hashlib
import binascii
import ecdsa
from lbry import utils
from lbry.crypto.hash import sha256
from lbry.wallet.transaction import Output
log = logging.getLogger(__name__)
def get_encoded_signature(signature):
signature = signature.encode() if isinstance(signature, str) else signature
r = int(signature[:int(len(signature) / 2)], 16)
s = int(signature[int(len(signature) / 2):], 16)
return ecdsa.util.sigencode_der(r, s, len(signature) * 4)
def cid2hash(claim_id: str) -> bytes:
return binascii.unhexlify(claim_id.encode())[::-1]
def is_comment_signed_by_channel(comment: dict, channel: Output, abandon=False):
if isinstance(channel, Output):
try:
signing_field = comment['comment_id'] if abandon else comment['comment']
pieces = [
comment['signing_ts'].encode(),
cid2hash(comment['channel_id']),
signing_field.encode()
]
return Output.is_signature_valid(
get_encoded_signature(comment['signature']),
sha256(b''.join(pieces)),
channel.claim.channel.public_key_bytes
)
except KeyError:
pass
return False
def sign_comment(comment: dict, channel: Output, abandon=False):
timestamp = str(int(time.time()))
signing_field = comment['comment_id'] if abandon else comment['comment']
pieces = [timestamp.encode(), channel.claim_hash, signing_field.encode()]
digest = sha256(b''.join(pieces))
signature = channel.private_key.sign_digest_deterministic(digest, hashfunc=hashlib.sha256)
comment.update({
'signature': binascii.hexlify(signature).decode(),
'signing_ts': timestamp
})
async def jsonrpc_post(url: str, method: str, params: dict = None, **kwargs) -> any:
params = params or {}
params.update(kwargs)
json_body = {'jsonrpc': '2.0', 'id': None, 'method': method, 'params': params}
async with utils.aiohttp_request('POST', url, json=json_body) as response:
try:
result = await response.json()
return result['result'] if 'result' in result else result
except Exception as cte:
log.exception('Unable to decode response from server: %s', cte)
return await response.text()

View file

@ -37,7 +37,7 @@ class Component(metaclass=ComponentType):
def running(self): def running(self):
return self._running return self._running
async def get_status(self): # pylint: disable=no-self-use async def get_status(self):
return return
async def start(self): async def start(self):

View file

@ -42,7 +42,7 @@ class ComponentManager:
self.analytics_manager = analytics_manager self.analytics_manager = analytics_manager
self.component_classes = {} self.component_classes = {}
self.components = set() self.components = set()
self.started = asyncio.Event() self.started = asyncio.Event(loop=self.loop)
self.peer_manager = peer_manager or PeerManager(asyncio.get_event_loop_policy().get_event_loop()) self.peer_manager = peer_manager or PeerManager(asyncio.get_event_loop_policy().get_event_loop())
for component_name, component_class in self.default_component_classes.items(): for component_name, component_class in self.default_component_classes.items():
@ -118,7 +118,7 @@ class ComponentManager:
component._setup() for component in stage if not component.running component._setup() for component in stage if not component.running
] ]
if needing_start: if needing_start:
await asyncio.wait(map(asyncio.create_task, needing_start)) await asyncio.wait(needing_start)
self.started.set() self.started.set()
async def stop(self): async def stop(self):
@ -131,7 +131,7 @@ class ComponentManager:
component._stop() for component in stage if component.running component._stop() for component in stage if component.running
] ]
if needing_stop: if needing_stop:
await asyncio.wait(map(asyncio.create_task, needing_stop)) await asyncio.wait(needing_stop)
def all_components_running(self, *component_names): def all_components_running(self, *component_names):
""" """
@ -158,14 +158,11 @@ class ComponentManager:
for component in self.components for component in self.components
} }
def get_actual_component(self, component_name): def get_component(self, component_name):
for component in self.components: for component in self.components:
if component.component_name == component_name: if component.component_name == component_name:
return component return component.component
raise NameError(component_name) raise NameError(component_name)
def get_component(self, component_name):
return self.get_actual_component(component_name).component
def has_component(self, component_name): def has_component(self, component_name):
return any(component for component in self.components if component_name == component.component_name) return any(component for component in self.components if component_name == component.component_name)

View file

@ -4,7 +4,6 @@ import asyncio
import logging import logging
import binascii import binascii
import typing import typing
import base58 import base58
from aioupnp import __version__ as aioupnp_version from aioupnp import __version__ as aioupnp_version
@ -16,9 +15,7 @@ from lbry.dht.node import Node
from lbry.dht.peer import is_valid_public_ipv4 from lbry.dht.peer import is_valid_public_ipv4
from lbry.dht.blob_announcer import BlobAnnouncer from lbry.dht.blob_announcer import BlobAnnouncer
from lbry.blob.blob_manager import BlobManager from lbry.blob.blob_manager import BlobManager
from lbry.blob.disk_space_manager import DiskSpaceManager
from lbry.blob_exchange.server import BlobServer from lbry.blob_exchange.server import BlobServer
from lbry.stream.background_downloader import BackgroundDownloader
from lbry.stream.stream_manager import StreamManager from lbry.stream.stream_manager import StreamManager
from lbry.file.file_manager import FileManager from lbry.file.file_manager import FileManager
from lbry.extras.daemon.component import Component from lbry.extras.daemon.component import Component
@ -27,8 +24,10 @@ from lbry.extras.daemon.storage import SQLiteStorage
from lbry.torrent.torrent_manager import TorrentManager from lbry.torrent.torrent_manager import TorrentManager
from lbry.wallet import WalletManager from lbry.wallet import WalletManager
from lbry.wallet.usage_payment import WalletServerPayer from lbry.wallet.usage_payment import WalletServerPayer
from lbry.torrent.tracker import TrackerClient try:
from lbry.torrent.session import TorrentSession from lbry.torrent.session import TorrentSession
except ImportError:
TorrentSession = None
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -41,12 +40,9 @@ WALLET_SERVER_PAYMENTS_COMPONENT = "wallet_server_payments"
DHT_COMPONENT = "dht" DHT_COMPONENT = "dht"
HASH_ANNOUNCER_COMPONENT = "hash_announcer" HASH_ANNOUNCER_COMPONENT = "hash_announcer"
FILE_MANAGER_COMPONENT = "file_manager" FILE_MANAGER_COMPONENT = "file_manager"
DISK_SPACE_COMPONENT = "disk_space"
BACKGROUND_DOWNLOADER_COMPONENT = "background_downloader"
PEER_PROTOCOL_SERVER_COMPONENT = "peer_protocol_server" PEER_PROTOCOL_SERVER_COMPONENT = "peer_protocol_server"
UPNP_COMPONENT = "upnp" UPNP_COMPONENT = "upnp"
EXCHANGE_RATE_MANAGER_COMPONENT = "exchange_rate_manager" EXCHANGE_RATE_MANAGER_COMPONENT = "exchange_rate_manager"
TRACKER_ANNOUNCER_COMPONENT = "tracker_announcer_component"
LIBTORRENT_COMPONENT = "libtorrent_component" LIBTORRENT_COMPONENT = "libtorrent_component"
@ -63,7 +59,7 @@ class DatabaseComponent(Component):
@staticmethod @staticmethod
def get_current_db_revision(): def get_current_db_revision():
return 15 return 14
@property @property
def revision_filename(self): def revision_filename(self):
@ -123,14 +119,13 @@ class WalletComponent(Component):
async def get_status(self): async def get_status(self):
if self.wallet_manager is None: if self.wallet_manager is None:
return return
is_connected = self.wallet_manager.ledger.network.is_connected session_pool = self.wallet_manager.ledger.network.session_pool
sessions = [] sessions = session_pool.sessions
connected = None connected = None
if is_connected: if self.wallet_manager.ledger.network.client:
addr, port = self.wallet_manager.ledger.network.client.server addr_and_port = self.wallet_manager.ledger.network.client.server_address_and_port
connected = f"{addr}:{port}" if addr_and_port:
sessions.append(self.wallet_manager.ledger.network.client) connected = f"{addr_and_port[0]}:{addr_and_port[1]}"
result = { result = {
'connected': connected, 'connected': connected,
'connected_features': self.wallet_manager.ledger.network.server_features, 'connected_features': self.wallet_manager.ledger.network.server_features,
@ -142,8 +137,8 @@ class WalletComponent(Component):
'availability': session.available, 'availability': session.available,
} for session in sessions } for session in sessions
], ],
'known_servers': len(self.wallet_manager.ledger.network.known_hubs), 'known_servers': len(sessions),
'available_servers': 1 if is_connected else 0 'available_servers': len(list(session_pool.available_sessions))
} }
if self.wallet_manager.ledger.network.remote_height: if self.wallet_manager.ledger.network.remote_height:
@ -155,7 +150,7 @@ class WalletComponent(Component):
progress = min(max(math.ceil(float(download_height) / float(target_height) * 100), 0), 100) progress = min(max(math.ceil(float(download_height) / float(target_height) * 100), 0), 100)
else: else:
progress = 100 progress = 100
best_hash = await self.wallet_manager.get_best_blockhash() best_hash = self.wallet_manager.get_best_blockhash()
result.update({ result.update({
'headers_synchronization_progress': progress, 'headers_synchronization_progress': progress,
'blocks': max(local_height, 0), 'blocks': max(local_height, 0),
@ -279,7 +274,7 @@ class DHTComponent(Component):
external_ip = upnp_component.external_ip external_ip = upnp_component.external_ip
storage = self.component_manager.get_component(DATABASE_COMPONENT) storage = self.component_manager.get_component(DATABASE_COMPONENT)
if not external_ip: if not external_ip:
external_ip, _ = await utils.get_external_ip(self.conf.lbryum_servers) external_ip = await utils.get_external_ip()
if not external_ip: if not external_ip:
log.warning("failed to get external ip") log.warning("failed to get external ip")
@ -293,7 +288,6 @@ class DHTComponent(Component):
peer_port=self.external_peer_port, peer_port=self.external_peer_port,
rpc_timeout=self.conf.node_rpc_timeout, rpc_timeout=self.conf.node_rpc_timeout,
split_buckets_under_index=self.conf.split_buckets_under_index, split_buckets_under_index=self.conf.split_buckets_under_index,
is_bootstrap_node=self.conf.is_bootstrap_node,
storage=storage storage=storage
) )
self.dht_node.start(self.conf.network_interface, self.conf.known_dht_nodes) self.dht_node.start(self.conf.network_interface, self.conf.known_dht_nodes)
@ -334,7 +328,7 @@ class HashAnnouncerComponent(Component):
class FileManagerComponent(Component): class FileManagerComponent(Component):
component_name = FILE_MANAGER_COMPONENT component_name = FILE_MANAGER_COMPONENT
depends_on = [BLOB_COMPONENT, DATABASE_COMPONENT, WALLET_COMPONENT] depends_on = [BLOB_COMPONENT, DATABASE_COMPONENT, WALLET_COMPONENT, LIBTORRENT_COMPONENT]
def __init__(self, component_manager): def __init__(self, component_manager):
super().__init__(component_manager) super().__init__(component_manager)
@ -357,6 +351,7 @@ class FileManagerComponent(Component):
wallet = self.component_manager.get_component(WALLET_COMPONENT) wallet = self.component_manager.get_component(WALLET_COMPONENT)
node = self.component_manager.get_component(DHT_COMPONENT) \ node = self.component_manager.get_component(DHT_COMPONENT) \
if self.component_manager.has_component(DHT_COMPONENT) else None if self.component_manager.has_component(DHT_COMPONENT) else None
torrent = self.component_manager.get_component(LIBTORRENT_COMPONENT) if TorrentSession else None
log.info('Starting the file manager') log.info('Starting the file manager')
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
self.file_manager = FileManager( self.file_manager = FileManager(
@ -365,8 +360,7 @@ class FileManagerComponent(Component):
self.file_manager.source_managers['stream'] = StreamManager( self.file_manager.source_managers['stream'] = StreamManager(
loop, self.conf, blob_manager, wallet, storage, node, loop, self.conf, blob_manager, wallet, storage, node,
) )
if self.component_manager.has_component(LIBTORRENT_COMPONENT): if TorrentSession:
torrent = self.component_manager.get_component(LIBTORRENT_COMPONENT)
self.file_manager.source_managers['torrent'] = TorrentManager( self.file_manager.source_managers['torrent'] = TorrentManager(
loop, self.conf, torrent, storage, self.component_manager.analytics_manager loop, self.conf, torrent, storage, self.component_manager.analytics_manager
) )
@ -374,106 +368,7 @@ class FileManagerComponent(Component):
log.info('Done setting up file manager') log.info('Done setting up file manager')
async def stop(self): async def stop(self):
await self.file_manager.stop() self.file_manager.stop()
class BackgroundDownloaderComponent(Component):
MIN_PREFIX_COLLIDING_BITS = 8
component_name = BACKGROUND_DOWNLOADER_COMPONENT
depends_on = [DATABASE_COMPONENT, BLOB_COMPONENT, DISK_SPACE_COMPONENT]
def __init__(self, component_manager):
super().__init__(component_manager)
self.background_task: typing.Optional[asyncio.Task] = None
self.download_loop_delay_seconds = 60
self.ongoing_download: typing.Optional[asyncio.Task] = None
self.space_manager: typing.Optional[DiskSpaceManager] = None
self.blob_manager: typing.Optional[BlobManager] = None
self.background_downloader: typing.Optional[BackgroundDownloader] = None
self.dht_node: typing.Optional[Node] = None
self.space_available: typing.Optional[int] = None
@property
def is_busy(self):
return bool(self.ongoing_download and not self.ongoing_download.done())
@property
def component(self) -> 'BackgroundDownloaderComponent':
return self
async def get_status(self):
return {'running': self.background_task is not None and not self.background_task.done(),
'available_free_space_mb': self.space_available,
'ongoing_download': self.is_busy}
async def download_blobs_in_background(self):
while True:
self.space_available = await self.space_manager.get_free_space_mb(True)
if not self.is_busy and self.space_available > 10:
self._download_next_close_blob_hash()
await asyncio.sleep(self.download_loop_delay_seconds)
def _download_next_close_blob_hash(self):
node_id = self.dht_node.protocol.node_id
for blob_hash in self.dht_node.stored_blob_hashes:
if blob_hash.hex() in self.blob_manager.completed_blob_hashes:
continue
if utils.get_colliding_prefix_bits(node_id, blob_hash) >= self.MIN_PREFIX_COLLIDING_BITS:
self.ongoing_download = asyncio.create_task(self.background_downloader.download_blobs(blob_hash.hex()))
return
async def start(self):
self.space_manager: DiskSpaceManager = self.component_manager.get_component(DISK_SPACE_COMPONENT)
if not self.component_manager.has_component(DHT_COMPONENT):
return
self.dht_node = self.component_manager.get_component(DHT_COMPONENT)
self.blob_manager = self.component_manager.get_component(BLOB_COMPONENT)
storage = self.component_manager.get_component(DATABASE_COMPONENT)
self.background_downloader = BackgroundDownloader(self.conf, storage, self.blob_manager, self.dht_node)
self.background_task = asyncio.create_task(self.download_blobs_in_background())
async def stop(self):
if self.ongoing_download and not self.ongoing_download.done():
self.ongoing_download.cancel()
if self.background_task:
self.background_task.cancel()
class DiskSpaceComponent(Component):
component_name = DISK_SPACE_COMPONENT
depends_on = [DATABASE_COMPONENT, BLOB_COMPONENT]
def __init__(self, component_manager):
super().__init__(component_manager)
self.disk_space_manager: typing.Optional[DiskSpaceManager] = None
@property
def component(self) -> typing.Optional[DiskSpaceManager]:
return self.disk_space_manager
async def get_status(self):
if self.disk_space_manager:
space_used = await self.disk_space_manager.get_space_used_mb(cached=True)
return {
'total_used_mb': space_used['total'],
'published_blobs_storage_used_mb': space_used['private_storage'],
'content_blobs_storage_used_mb': space_used['content_storage'],
'seed_blobs_storage_used_mb': space_used['network_storage'],
'running': self.disk_space_manager.running,
}
return {'space_used': '0', 'network_seeding_space_used': '0', 'running': False}
async def start(self):
db = self.component_manager.get_component(DATABASE_COMPONENT)
blob_manager = self.component_manager.get_component(BLOB_COMPONENT)
self.disk_space_manager = DiskSpaceManager(
self.conf, db, blob_manager,
analytics=self.component_manager.analytics_manager
)
await self.disk_space_manager.start()
async def stop(self):
await self.disk_space_manager.stop()
class TorrentComponent(Component): class TorrentComponent(Component):
@ -495,6 +390,7 @@ class TorrentComponent(Component):
} }
async def start(self): async def start(self):
if TorrentSession:
self.torrent_session = TorrentSession(asyncio.get_event_loop(), None) self.torrent_session = TorrentSession(asyncio.get_event_loop(), None)
await self.torrent_session.bind() # TODO: specify host/port await self.torrent_session.bind() # TODO: specify host/port
@ -551,7 +447,7 @@ class UPnPComponent(Component):
while True: while True:
if now: if now:
await self._maintain_redirects() await self._maintain_redirects()
await asyncio.sleep(360) await asyncio.sleep(360, loop=self.component_manager.loop)
async def _maintain_redirects(self): async def _maintain_redirects(self):
# setup the gateway if necessary # setup the gateway if necessary
@ -560,6 +456,8 @@ class UPnPComponent(Component):
self.upnp = await UPnP.discover(loop=self.component_manager.loop) self.upnp = await UPnP.discover(loop=self.component_manager.loop)
log.info("found upnp gateway: %s", self.upnp.gateway.manufacturer_string) log.info("found upnp gateway: %s", self.upnp.gateway.manufacturer_string)
except Exception as err: except Exception as err:
if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise
log.warning("upnp discovery failed: %s", err) log.warning("upnp discovery failed: %s", err)
self.upnp = None self.upnp = None
@ -574,15 +472,11 @@ class UPnPComponent(Component):
pass pass
if external_ip and not is_valid_public_ipv4(external_ip): if external_ip and not is_valid_public_ipv4(external_ip):
log.warning("UPnP returned a private/reserved ip - %s, checking lbry.com fallback", external_ip) log.warning("UPnP returned a private/reserved ip - %s, checking lbry.com fallback", external_ip)
external_ip, _ = await utils.get_external_ip(self.conf.lbryum_servers) external_ip = await utils.get_external_ip()
if self.external_ip and self.external_ip != external_ip: if self.external_ip and self.external_ip != external_ip:
log.info("external ip changed from %s to %s", self.external_ip, external_ip) log.info("external ip changed from %s to %s", self.external_ip, external_ip)
if external_ip: if external_ip:
self.external_ip = external_ip self.external_ip = external_ip
dht_component = self.component_manager.get_component(DHT_COMPONENT)
if dht_component:
dht_node = dht_component.component
dht_node.protocol.external_ip = external_ip
# assert self.external_ip is not None # TODO: handle going/starting offline # assert self.external_ip is not None # TODO: handle going/starting offline
if not self.upnp_redirects and self.upnp: # setup missing redirects if not self.upnp_redirects and self.upnp: # setup missing redirects
@ -636,15 +530,13 @@ class UPnPComponent(Component):
async def start(self): async def start(self):
log.info("detecting external ip") log.info("detecting external ip")
if not self.use_upnp: if not self.use_upnp:
self.external_ip, _ = await utils.get_external_ip(self.conf.lbryum_servers) self.external_ip = await utils.get_external_ip()
return return
success = False success = False
await self._maintain_redirects() await self._maintain_redirects()
if self.upnp: if self.upnp:
if not self.upnp_redirects and not all( if not self.upnp_redirects and not all([x in self.component_manager.skip_components for x in
x in self.component_manager.skip_components (DHT_COMPONENT, PEER_PROTOCOL_SERVER_COMPONENT)]):
for x in (DHT_COMPONENT, PEER_PROTOCOL_SERVER_COMPONENT)
):
log.error("failed to setup upnp") log.error("failed to setup upnp")
else: else:
success = True success = True
@ -653,9 +545,9 @@ class UPnPComponent(Component):
else: else:
log.error("failed to setup upnp") log.error("failed to setup upnp")
if not self.external_ip: if not self.external_ip:
self.external_ip, probed_url = await utils.get_external_ip(self.conf.lbryum_servers) self.external_ip = await utils.get_external_ip()
if self.external_ip: if self.external_ip:
log.info("detected external ip using %s fallback", probed_url) log.info("detected external ip using lbry.com fallback")
if self.component_manager.analytics_manager: if self.component_manager.analytics_manager:
self.component_manager.loop.create_task( self.component_manager.loop.create_task(
self.component_manager.analytics_manager.send_upnp_setup_success_fail( self.component_manager.analytics_manager.send_upnp_setup_success_fail(
@ -671,7 +563,7 @@ class UPnPComponent(Component):
log.info("Removing upnp redirects: %s", self.upnp_redirects) log.info("Removing upnp redirects: %s", self.upnp_redirects)
await asyncio.wait([ await asyncio.wait([
self.upnp.delete_port_mapping(port, protocol) for protocol, port in self.upnp_redirects.items() self.upnp.delete_port_mapping(port, protocol) for protocol, port in self.upnp_redirects.items()
]) ], loop=self.component_manager.loop)
if self._maintain_redirects_task and not self._maintain_redirects_task.done(): if self._maintain_redirects_task and not self._maintain_redirects_task.done():
self._maintain_redirects_task.cancel() self._maintain_redirects_task.cancel()
@ -702,49 +594,3 @@ class ExchangeRateManagerComponent(Component):
async def stop(self): async def stop(self):
self.exchange_rate_manager.stop() self.exchange_rate_manager.stop()
class TrackerAnnouncerComponent(Component):
component_name = TRACKER_ANNOUNCER_COMPONENT
depends_on = [FILE_MANAGER_COMPONENT]
def __init__(self, component_manager):
super().__init__(component_manager)
self.file_manager = None
self.announce_task = None
self.tracker_client: typing.Optional[TrackerClient] = None
@property
def component(self):
return self.tracker_client
@property
def running(self):
return self._running and self.announce_task and not self.announce_task.done()
async def announce_forever(self):
while True:
sleep_seconds = 60.0
announce_sd_hashes = []
for file in self.file_manager.get_filtered():
if not file.downloader:
continue
announce_sd_hashes.append(bytes.fromhex(file.sd_hash))
await self.tracker_client.announce_many(*announce_sd_hashes)
await asyncio.sleep(sleep_seconds)
async def start(self):
node = self.component_manager.get_component(DHT_COMPONENT) \
if self.component_manager.has_component(DHT_COMPONENT) else None
node_id = node.protocol.node_id if node else None
self.tracker_client = TrackerClient(node_id, self.conf.tcp_port, lambda: self.conf.tracker_servers)
await self.tracker_client.start()
self.file_manager = self.component_manager.get_component(FILE_MANAGER_COMPONENT)
self.announce_task = asyncio.create_task(self.announce_forever())
async def stop(self):
self.file_manager = None
if self.announce_task and not self.announce_task.done():
self.announce_task.cancel()
self.announce_task = None
self.tracker_client.stop()

File diff suppressed because it is too large Load diff

View file

@ -2,10 +2,9 @@ import json
import time import time
import asyncio import asyncio
import logging import logging
from statistics import median
from decimal import Decimal from decimal import Decimal
from typing import Optional, Iterable, Type from typing import Optional, Iterable, Type
from aiohttp.client_exceptions import ContentTypeError, ClientConnectionError from aiohttp.client_exceptions import ContentTypeError
from lbry.error import InvalidExchangeRateResponseError, CurrencyConversionError from lbry.error import InvalidExchangeRateResponseError, CurrencyConversionError
from lbry.utils import aiohttp_request from lbry.utils import aiohttp_request
from lbry.wallet.dewies import lbc_to_dewies from lbry.wallet.dewies import lbc_to_dewies
@ -59,12 +58,9 @@ class MarketFeed:
raise NotImplementedError() raise NotImplementedError()
async def get_response(self): async def get_response(self):
async with aiohttp_request( async with aiohttp_request('get', self.url, params=self.params, timeout=self.request_timeout) as response:
'get', self.url, params=self.params,
timeout=self.request_timeout, headers={"User-Agent": "lbrynet"}
) as response:
try: try:
self._last_response = await response.json(content_type=None) self._last_response = await response.json()
except ContentTypeError as e: except ContentTypeError as e:
self._last_response = {} self._last_response = {}
log.warning("Could not parse exchange rate response from %s: %s", self.name, e.message) log.warning("Could not parse exchange rate response from %s: %s", self.name, e.message)
@ -79,21 +75,18 @@ class MarketFeed:
log.debug("Saving rate update %f for %s from %s", rate, self.market, self.name) log.debug("Saving rate update %f for %s from %s", rate, self.market, self.name)
self.rate = ExchangeRate(self.market, rate, int(time.time())) self.rate = ExchangeRate(self.market, rate, int(time.time()))
self.last_check = time.time() self.last_check = time.time()
self.event.set()
return self.rate return self.rate
except asyncio.CancelledError:
raise
except asyncio.TimeoutError: except asyncio.TimeoutError:
log.warning("Timed out fetching exchange rate from %s.", self.name) log.warning("Timed out fetching exchange rate from %s.", self.name)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
msg = e.doc if '<html>' not in e.doc else 'unexpected content type.' log.warning("Could not parse exchange rate response from %s: %s", self.name, e.doc)
log.warning("Could not parse exchange rate response from %s: %s", self.name, msg)
log.debug(e.doc)
except InvalidExchangeRateResponseError as e: except InvalidExchangeRateResponseError as e:
log.warning(str(e)) log.warning(str(e))
except ClientConnectionError as e:
log.warning("Error trying to connect to exchange rate %s: %s", self.name, str(e))
except Exception as e: except Exception as e:
log.exception("Exchange rate error (%s from %s):", self.market, self.name) log.exception("Exchange rate error (%s from %s):", self.market, self.name)
finally:
self.event.set()
async def keep_updated(self): async def keep_updated(self):
while True: while True:
@ -111,92 +104,70 @@ class MarketFeed:
self.event.clear() self.event.clear()
class BaseBittrexFeed(MarketFeed): class BittrexFeed(MarketFeed):
name = "Bittrex" name = "Bittrex"
market = None market = "BTCLBC"
url = None url = "https://bittrex.com/api/v1.1/public/getmarkethistory"
params = {'market': 'BTC-LBC', 'count': 50}
fee = 0.0025 fee = 0.0025
def get_rate_from_response(self, json_response):
if 'lastTradeRate' not in json_response:
raise InvalidExchangeRateResponseError(self.name, 'result not found')
return 1.0 / float(json_response['lastTradeRate'])
class BittrexBTCFeed(BaseBittrexFeed):
market = "BTCLBC"
url = "https://api.bittrex.com/v3/markets/LBC-BTC/ticker"
class BittrexUSDFeed(BaseBittrexFeed):
market = "USDLBC"
url = "https://api.bittrex.com/v3/markets/LBC-USD/ticker"
class BaseCoinExFeed(MarketFeed):
name = "CoinEx"
market = None
url = None
def get_rate_from_response(self, json_response):
if 'data' not in json_response or \
'ticker' not in json_response['data'] or \
'last' not in json_response['data']['ticker']:
raise InvalidExchangeRateResponseError(self.name, 'result not found')
return 1.0 / float(json_response['data']['ticker']['last'])
class CoinExBTCFeed(BaseCoinExFeed):
market = "BTCLBC"
url = "https://api.coinex.com/v1/market/ticker?market=LBCBTC"
class CoinExUSDFeed(BaseCoinExFeed):
market = "USDLBC"
url = "https://api.coinex.com/v1/market/ticker?market=LBCUSDT"
class BaseHotbitFeed(MarketFeed):
name = "hotbit"
market = None
url = "https://api.hotbit.io/api/v1/market.last"
def get_rate_from_response(self, json_response): def get_rate_from_response(self, json_response):
if 'result' not in json_response: if 'result' not in json_response:
raise InvalidExchangeRateResponseError(self.name, 'result not found') raise InvalidExchangeRateResponseError(self.name, 'result not found')
return 1.0 / float(json_response['result']) trades = json_response['result']
if len(trades) == 0:
raise InvalidExchangeRateResponseError(self.name, 'trades not found')
totals = sum([i['Total'] for i in trades])
qtys = sum([i['Quantity'] for i in trades])
if totals <= 0 or qtys <= 0:
raise InvalidExchangeRateResponseError(self.name, 'quantities were not positive')
vwap = totals / qtys
return float(1.0 / vwap)
class HotbitBTCFeed(BaseHotbitFeed): class LBRYFeed(MarketFeed):
name = "lbry.com"
market = "BTCLBC" market = "BTCLBC"
params = {"market": "LBC/BTC"} url = "https://api.lbry.com/lbc/exchange_rate"
class HotbitUSDFeed(BaseHotbitFeed):
market = "USDLBC"
params = {"market": "LBC/USDT"}
class UPbitBTCFeed(MarketFeed):
name = "UPbit"
market = "BTCLBC"
url = "https://api.upbit.com/v1/ticker"
params = {"markets": "BTC-LBC"}
def get_rate_from_response(self, json_response): def get_rate_from_response(self, json_response):
if "error" in json_response or len(json_response) != 1 or 'trade_price' not in json_response[0]: if 'data' not in json_response:
raise InvalidExchangeRateResponseError(self.name, 'result not found') raise InvalidExchangeRateResponseError(self.name, 'result not found')
return 1.0 / float(json_response[0]['trade_price']) return 1.0 / json_response['data']['lbc_btc']
class LBRYBTCFeed(LBRYFeed):
market = "USDBTC"
def get_rate_from_response(self, json_response):
if 'data' not in json_response:
raise InvalidExchangeRateResponseError(self.name, 'result not found')
return 1.0 / json_response['data']['btc_usd']
class CryptonatorFeed(MarketFeed):
name = "cryptonator.com"
market = "BTCLBC"
url = "https://api.cryptonator.com/api/ticker/btc-lbc"
def get_rate_from_response(self, json_response):
if 'ticker' not in json_response or len(json_response['ticker']) == 0 or \
'success' not in json_response or json_response['success'] is not True:
raise InvalidExchangeRateResponseError(self.name, 'result not found')
return float(json_response['ticker']['price'])
class CryptonatorBTCFeed(CryptonatorFeed):
market = "USDBTC"
url = "https://api.cryptonator.com/api/ticker/usd-btc"
FEEDS: Iterable[Type[MarketFeed]] = ( FEEDS: Iterable[Type[MarketFeed]] = (
BittrexBTCFeed, LBRYFeed,
BittrexUSDFeed, LBRYBTCFeed,
CoinExBTCFeed, BittrexFeed,
CoinExUSDFeed, # CryptonatorFeed,
# HotbitBTCFeed, # CryptonatorBTCFeed,
# HotbitUSDFeed,
# UPbitBTCFeed,
) )
@ -220,23 +191,20 @@ class ExchangeRateManager:
source.stop() source.stop()
def convert_currency(self, from_currency, to_currency, amount): def convert_currency(self, from_currency, to_currency, amount):
log.debug( rates = [market.rate for market in self.market_feeds]
"Converting %f %s to %s, rates: %s", log.debug("Converting %f %s to %s, rates: %s", amount, from_currency, to_currency, rates)
amount, from_currency, to_currency,
[market.rate for market in self.market_feeds]
)
if from_currency == to_currency: if from_currency == to_currency:
return round(amount, 8) return round(amount, 8)
rates = []
for market in self.market_feeds: for market in self.market_feeds:
if (market.has_rate and market.is_online and if (market.has_rate and market.is_online and
market.rate.currency_pair == (from_currency, to_currency)): market.rate.currency_pair == (from_currency, to_currency)):
rates.append(market.rate.spot) return round(amount * Decimal(market.rate.spot), 8)
for market in self.market_feeds:
if rates: if (market.has_rate and market.is_online and
return round(amount * Decimal(median(rates)), 8) market.rate.currency_pair[0] == from_currency):
return round(self.convert_currency(
market.rate.currency_pair[1], to_currency, amount * Decimal(market.rate.spot)), 8)
raise CurrencyConversionError( raise CurrencyConversionError(
f'Unable to convert {amount} from {from_currency} to {to_currency}') f'Unable to convert {amount} from {from_currency} to {to_currency}')

View file

@ -7,10 +7,9 @@ from json import JSONEncoder
from google.protobuf.message import DecodeError from google.protobuf.message import DecodeError
from lbry.schema.claim import Claim from lbry.schema.claim import Claim
from lbry.schema.support import Support
from lbry.torrent.torrent_manager import TorrentSource from lbry.torrent.torrent_manager import TorrentSource
from lbry.wallet import Wallet, Ledger, Account, Transaction, Output from lbry.wallet import Wallet, Ledger, Account, Transaction, Output
from lbry.wallet.bip32 import PublicKey from lbry.wallet.bip32 import PubKey
from lbry.wallet.dewies import dewies_to_lbc from lbry.wallet.dewies import dewies_to_lbc
from lbry.stream.managed_stream import ManagedStream from lbry.stream.managed_stream import ManagedStream
@ -110,9 +109,7 @@ def encode_file_doc():
'metadata': '(dict) None if claim is not found else the claim metadata', 'metadata': '(dict) None if claim is not found else the claim metadata',
'channel_claim_id': '(str) None if claim is not found or not signed', 'channel_claim_id': '(str) None if claim is not found or not signed',
'channel_name': '(str) None if claim is not found or not signed', 'channel_name': '(str) None if claim is not found or not signed',
'claim_name': '(str) None if claim is not found else the claim name', 'claim_name': '(str) None if claim is not found else the claim name'
'reflector_progress': '(int) reflector upload progress, 0 to 100',
'uploading_to_reflector': '(bool) set to True when currently uploading to reflector'
} }
@ -123,12 +120,14 @@ class JSONResponseEncoder(JSONEncoder):
self.ledger = ledger self.ledger = ledger
self.include_protobuf = include_protobuf self.include_protobuf = include_protobuf
def default(self, obj): # pylint: disable=method-hidden,arguments-renamed,too-many-return-statements def default(self, obj): # pylint: disable=method-hidden,arguments-differ,too-many-return-statements
if isinstance(obj, Account): if isinstance(obj, Account):
return self.encode_account(obj) return self.encode_account(obj)
if isinstance(obj, Wallet): if isinstance(obj, Wallet):
return self.encode_wallet(obj) return self.encode_wallet(obj)
if isinstance(obj, (ManagedStream, TorrentSource)): if isinstance(obj, ManagedStream):
return self.encode_file(obj)
if isinstance(obj, TorrentSource):
return self.encode_file(obj) return self.encode_file(obj)
if isinstance(obj, Transaction): if isinstance(obj, Transaction):
return self.encode_transaction(obj) return self.encode_transaction(obj)
@ -136,9 +135,7 @@ class JSONResponseEncoder(JSONEncoder):
return self.encode_output(obj) return self.encode_output(obj)
if isinstance(obj, Claim): if isinstance(obj, Claim):
return self.encode_claim(obj) return self.encode_claim(obj)
if isinstance(obj, Support): if isinstance(obj, PubKey):
return obj.to_dict()
if isinstance(obj, PublicKey):
return obj.extended_key_string() return obj.extended_key_string()
if isinstance(obj, datetime): if isinstance(obj, datetime):
return obj.strftime("%Y%m%dT%H:%M:%S") return obj.strftime("%Y%m%dT%H:%M:%S")
@ -172,22 +169,16 @@ class JSONResponseEncoder(JSONEncoder):
'amount': dewies_to_lbc(txo.amount), 'amount': dewies_to_lbc(txo.amount),
'address': txo.get_address(self.ledger) if txo.has_address else None, 'address': txo.get_address(self.ledger) if txo.has_address else None,
'confirmations': (best_height+1) - tx_height if tx_height > 0 else tx_height, 'confirmations': (best_height+1) - tx_height if tx_height > 0 else tx_height,
'timestamp': self.ledger.headers.estimated_timestamp(tx_height) 'timestamp': self.ledger.headers[tx_height]['timestamp'] if 0 < tx_height <= best_height else None
} }
if txo.is_change is not None:
output['is_change'] = txo.is_change
if txo.is_received is not None:
output['is_received'] = txo.is_received
if txo.is_spent is not None: if txo.is_spent is not None:
output['is_spent'] = txo.is_spent output['is_spent'] = txo.is_spent
if txo.is_my_output is not None: if txo.is_my_account is not None:
output['is_my_output'] = txo.is_my_output output['is_mine'] = txo.is_my_account
if txo.is_my_input is not None:
output['is_my_input'] = txo.is_my_input
if txo.sent_supports is not None:
output['sent_supports'] = dewies_to_lbc(txo.sent_supports)
if txo.sent_tips is not None:
output['sent_tips'] = dewies_to_lbc(txo.sent_tips)
if txo.received_tips is not None:
output['received_tips'] = dewies_to_lbc(txo.received_tips)
if txo.is_internal_transfer is not None:
output['is_internal_transfer'] = txo.is_internal_transfer
if txo.script.is_claim_name: if txo.script.is_claim_name:
output['type'] = 'claim' output['type'] = 'claim'
@ -223,23 +214,22 @@ class JSONResponseEncoder(JSONEncoder):
output['claims'] = [self.encode_output(o) for o in txo.claims] output['claims'] = [self.encode_output(o) for o in txo.claims]
if txo.reposted_claim is not None: if txo.reposted_claim is not None:
output['reposted_claim'] = self.encode_output(txo.reposted_claim) output['reposted_claim'] = self.encode_output(txo.reposted_claim)
if txo.script.is_claim_name or txo.script.is_update_claim or txo.script.is_support_claim_data: if txo.script.is_claim_name or txo.script.is_update_claim:
try: try:
output['value'] = txo.signable output['value'] = txo.claim
output['value_type'] = txo.claim.claim_type
if self.include_protobuf: if self.include_protobuf:
output['protobuf'] = hexlify(txo.signable.to_bytes()) output['protobuf'] = hexlify(txo.claim.to_bytes())
if txo.purchase_receipt is not None: if txo.purchase_receipt is not None:
output['purchase_receipt'] = self.encode_output(txo.purchase_receipt) output['purchase_receipt'] = self.encode_output(txo.purchase_receipt)
if txo.script.is_claim_name or txo.script.is_update_claim:
output['value_type'] = txo.claim.claim_type
if txo.claim.is_channel: if txo.claim.is_channel:
output['has_signing_key'] = txo.has_private_key output['has_signing_key'] = txo.has_private_key
if check_signature and txo.signable.is_signed: if check_signature and txo.claim.is_signed:
if txo.channel is not None: if txo.channel is not None:
output['signing_channel'] = self.encode_output(txo.channel) output['signing_channel'] = self.encode_output(txo.channel)
output['is_channel_signature_valid'] = txo.is_signed_by(txo.channel, self.ledger) output['is_channel_signature_valid'] = txo.is_signed_by(txo.channel, self.ledger)
else: else:
output['signing_channel'] = {'channel_id': txo.signable.signing_channel_id} output['signing_channel'] = {'channel_id': txo.claim.signing_channel_id}
output['is_channel_signature_valid'] = False output['is_channel_signature_valid'] = False
except DecodeError: except DecodeError:
pass pass
@ -251,7 +241,7 @@ class JSONResponseEncoder(JSONEncoder):
if isinstance(value, int): if isinstance(value, int):
meta[key] = dewies_to_lbc(value) meta[key] = dewies_to_lbc(value)
if 0 < meta.get('creation_height', 0) <= self.ledger.headers.height: if 0 < meta.get('creation_height', 0) <= self.ledger.headers.height:
meta['creation_timestamp'] = self.ledger.headers.estimated_timestamp(meta['creation_height']) meta['creation_timestamp'] = self.ledger.headers[meta['creation_height']]['timestamp']
return meta return meta
def encode_input(self, txi): def encode_input(self, txi):
@ -284,26 +274,26 @@ class JSONResponseEncoder(JSONEncoder):
total_bytes = managed_stream.descriptor.upper_bound_decrypted_length() total_bytes = managed_stream.descriptor.upper_bound_decrypted_length()
else: else:
total_bytes_lower_bound = total_bytes = managed_stream.torrent_length total_bytes_lower_bound = total_bytes = managed_stream.torrent_length
result = { return {
'streaming_url': None, 'streaming_url': managed_stream.stream_url if is_stream else f'file://{managed_stream.full_path}',
'completed': managed_stream.completed, 'completed': managed_stream.completed,
'file_name': None, 'file_name': managed_stream.file_name if output_exists else None,
'download_directory': None, 'download_directory': managed_stream.download_directory if output_exists else None,
'download_path': None, 'download_path': managed_stream.full_path if output_exists else None,
'points_paid': 0.0, 'points_paid': 0.0,
'stopped': not managed_stream.running, 'stopped': not managed_stream.running,
'stream_hash': None, 'stream_hash': managed_stream.stream_hash if is_stream else None,
'stream_name': None, 'stream_name': managed_stream.descriptor.stream_name if is_stream else None,
'suggested_file_name': None, 'suggested_file_name': managed_stream.descriptor.suggested_file_name if is_stream else None,
'sd_hash': None, 'sd_hash': managed_stream.descriptor.sd_hash if is_stream else None,
'mime_type': None, 'mime_type': managed_stream.mime_type if is_stream else None,
'key': None, 'key': managed_stream.descriptor.key if is_stream else None,
'total_bytes_lower_bound': total_bytes_lower_bound, 'total_bytes_lower_bound': total_bytes_lower_bound,
'total_bytes': total_bytes, 'total_bytes': total_bytes,
'written_bytes': managed_stream.written_bytes, 'written_bytes': managed_stream.written_bytes if is_stream else managed_stream.written_bytes,
'blobs_completed': None, 'blobs_completed': managed_stream.blobs_completed if is_stream else None,
'blobs_in_stream': None, 'blobs_in_stream': managed_stream.blobs_in_stream if is_stream else None,
'blobs_remaining': None, 'blobs_remaining': managed_stream.blobs_remaining if is_stream else None,
'status': managed_stream.status, 'status': managed_stream.status,
'claim_id': managed_stream.claim_id, 'claim_id': managed_stream.claim_id,
'txid': managed_stream.txid, 'txid': managed_stream.txid,
@ -319,38 +309,9 @@ class JSONResponseEncoder(JSONEncoder):
'added_on': managed_stream.added_on, 'added_on': managed_stream.added_on,
'height': tx_height, 'height': tx_height,
'confirmations': (best_height + 1) - tx_height if tx_height > 0 else tx_height, 'confirmations': (best_height + 1) - tx_height if tx_height > 0 else tx_height,
'timestamp': self.ledger.headers.estimated_timestamp(tx_height), 'timestamp': self.ledger.headers[tx_height]['timestamp'] if 0 < tx_height <= best_height else None,
'is_fully_reflected': False, 'is_fully_reflected': managed_stream.is_fully_reflected if is_stream else False
'reflector_progress': False,
'uploading_to_reflector': False
} }
if is_stream:
result.update({
'streaming_url': managed_stream.stream_url,
'stream_hash': managed_stream.stream_hash,
'stream_name': managed_stream.stream_name,
'suggested_file_name': managed_stream.suggested_file_name,
'sd_hash': managed_stream.descriptor.sd_hash,
'mime_type': managed_stream.mime_type,
'key': managed_stream.descriptor.key,
'blobs_completed': managed_stream.blobs_completed,
'blobs_in_stream': managed_stream.blobs_in_stream,
'blobs_remaining': managed_stream.blobs_remaining,
'is_fully_reflected': managed_stream.is_fully_reflected,
'reflector_progress': managed_stream.reflector_progress,
'uploading_to_reflector': managed_stream.uploading_to_reflector
})
else:
result.update({
'streaming_url': f'file://{managed_stream.full_path}',
})
if output_exists:
result.update({
'file_name': managed_stream.file_name,
'download_directory': managed_stream.download_directory,
'download_path': managed_stream.full_path,
})
return result
def encode_claim(self, claim): def encode_claim(self, claim):
encoded = getattr(claim, claim.claim_type).to_dict() encoded = getattr(claim, claim.claim_type).to_dict()

View file

@ -0,0 +1,95 @@
import asyncio
import json
import logging.handlers
import traceback
import typing
from aiohttp.client_exceptions import ClientError
import aiohttp
from lbry import utils, __version__
if typing.TYPE_CHECKING:
from lbry.conf import Config
LOGGLY_TOKEN = 'BQEzZmMzLJHgAGxkBF00LGD0YGuyATVgAmqxAQEuAQZ2BQH4'
class JsonFormatter(logging.Formatter):
"""Format log records using json serialization"""
def __init__(self, **kwargs):
super().__init__()
self.attributes = kwargs
def format(self, record):
data = {
'loggerName': record.name,
'asciTime': self.formatTime(record),
'fileName': record.filename,
'functionName': record.funcName,
'levelNo': record.levelno,
'lineNo': record.lineno,
'levelName': record.levelname,
'message': record.getMessage(),
}
data.update(self.attributes)
if record.exc_info:
data['exc_info'] = self.formatException(record.exc_info)
return json.dumps(data)
class HTTPSLogglyHandler(logging.Handler):
def __init__(self, loggly_token: str, config: 'Config'):
super().__init__()
self.cookies = {}
self.url = "https://logs-01.loggly.com/inputs/{token}/tag/{tag}".format(
token=utils.deobfuscate(loggly_token), tag='lbrynet-' + __version__
)
self._loop = asyncio.get_event_loop()
self._session = aiohttp.ClientSession()
self._config = config
@property
def enabled(self):
return self._config.share_usage_data
@staticmethod
def get_full_message(record):
if record.exc_info:
return '\n'.join(traceback.format_exception(*record.exc_info))
else:
return record.getMessage()
async def _emit(self, record, retry=True):
data = self.format(record).encode()
try:
async with self._session.post(self.url, data=data,
cookies=self.cookies) as response:
self.cookies.update(response.cookies)
except ClientError:
if self._loop.is_running() and retry and self.enabled:
await self._session.close()
self._session = aiohttp.ClientSession()
return await self._emit(record, retry=False)
def emit(self, record):
if not self.enabled:
return
try:
asyncio.ensure_future(self._emit(record), loop=self._loop)
except RuntimeError: # TODO: use a second loop
print(f"\nfailed to send traceback to loggly, please file an issue with the following traceback:\n"
f"{self.format(record)}")
def close(self):
super().close()
try:
loop = asyncio.get_event_loop()
loop.run_until_complete(self._session.close())
except RuntimeError:
pass
def get_loggly_handler(config):
handler = HTTPSLogglyHandler(LOGGLY_TOKEN, config=config)
handler.setFormatter(JsonFormatter())
return handler

View file

@ -35,10 +35,6 @@ def migrate_db(conf, start, end):
from .migrate12to13 import do_migration from .migrate12to13 import do_migration
elif current == 13: elif current == 13:
from .migrate13to14 import do_migration from .migrate13to14 import do_migration
elif current == 14:
from .migrate14to15 import do_migration
elif current == 15:
from .migrate15to16 import do_migration
else: else:
raise Exception(f"DB migration of version {current} to {current+1} is not available") raise Exception(f"DB migration of version {current} to {current+1} is not available")
try: try:

View file

@ -1,16 +0,0 @@
import os
import sqlite3
def do_migration(conf):
db_path = os.path.join(conf.data_dir, "lbrynet.sqlite")
connection = sqlite3.connect(db_path)
cursor = connection.cursor()
cursor.executescript("""
alter table blob add column added_on integer not null default 0;
alter table blob add column is_mine integer not null default 1;
""")
connection.commit()
connection.close()

View file

@ -1,17 +0,0 @@
import os
import sqlite3
def do_migration(conf):
db_path = os.path.join(conf.data_dir, "lbrynet.sqlite")
connection = sqlite3.connect(db_path)
cursor = connection.cursor()
cursor.executescript("""
update blob set should_announce=0
where should_announce=1 and
blob.blob_hash in (select stream_blob.blob_hash from stream_blob where position=0);
""")
connection.commit()
connection.close()

View file

@ -20,7 +20,7 @@ def do_migration(conf):
"left outer join blob b ON b.blob_hash=s.blob_hash order by s.position").fetchall() "left outer join blob b ON b.blob_hash=s.blob_hash order by s.position").fetchall()
blobs_by_stream = {} blobs_by_stream = {}
for stream_hash, position, iv, blob_hash, blob_length in blobs: for stream_hash, position, iv, blob_hash, blob_length in blobs:
blobs_by_stream.setdefault(stream_hash, []).append(BlobInfo(position, blob_length or 0, iv, 0, blob_hash)) blobs_by_stream.setdefault(stream_hash, []).append(BlobInfo(position, blob_length or 0, iv, blob_hash))
for stream_name, stream_key, suggested_filename, sd_hash, stream_hash in streams: for stream_name, stream_key, suggested_filename, sd_hash, stream_hash in streams:
sd = StreamDescriptor(None, blob_dir, stream_name, stream_key, suggested_filename, sd = StreamDescriptor(None, blob_dir, stream_name, stream_key, suggested_filename,

View file

@ -1,31 +0,0 @@
import logging
from aiohttp import web
log = logging.getLogger(__name__)
def ensure_request_allowed(request, conf):
if is_request_allowed(request, conf):
return
if conf.allowed_origin:
log.warning(
"API requests with Origin '%s' are not allowed, "
"configuration 'allowed_origin' limits requests to: '%s'",
request.headers.get('Origin'), conf.allowed_origin
)
else:
log.warning(
"API requests with Origin '%s' are not allowed, "
"update configuration 'allowed_origin' to enable this origin.",
request.headers.get('Origin')
)
raise web.HTTPForbidden()
def is_request_allowed(request, conf) -> bool:
origin = request.headers.get('Origin')
return (
origin is None or
origin == conf.allowed_origin or
conf.allowed_origin == '*'
)

View file

@ -170,8 +170,8 @@ def get_all_lbry_files(transaction: sqlite3.Connection) -> typing.List[typing.Di
def store_stream(transaction: sqlite3.Connection, sd_blob: 'BlobFile', descriptor: 'StreamDescriptor'): def store_stream(transaction: sqlite3.Connection, sd_blob: 'BlobFile', descriptor: 'StreamDescriptor'):
# add all blobs, except the last one, which is empty # add all blobs, except the last one, which is empty
transaction.executemany( transaction.executemany(
"insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?, ?, ?)", "insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?)",
((blob.blob_hash, blob.length, 0, 0, "pending", 0, 0, blob.added_on, blob.is_mine) ((blob.blob_hash, blob.length, 0, 0, "pending", 0, 0)
for blob in (descriptor.blobs[:-1] if len(descriptor.blobs) > 1 else descriptor.blobs) + [sd_blob]) for blob in (descriptor.blobs[:-1] if len(descriptor.blobs) > 1 else descriptor.blobs) + [sd_blob])
).fetchall() ).fetchall()
# associate the blobs to the stream # associate the blobs to the stream
@ -187,8 +187,8 @@ def store_stream(transaction: sqlite3.Connection, sd_blob: 'BlobFile', descripto
).fetchall() ).fetchall()
# ensure should_announce is set regardless if insert was ignored # ensure should_announce is set regardless if insert was ignored
transaction.execute( transaction.execute(
"update blob set should_announce=1 where blob_hash in (?)", "update blob set should_announce=1 where blob_hash in (?, ?)",
(sd_blob.blob_hash,) (sd_blob.blob_hash, descriptor.blobs[0].blob_hash,)
).fetchall() ).fetchall()
@ -242,9 +242,7 @@ class SQLiteStorage(SQLiteMixin):
should_announce integer not null default 0, should_announce integer not null default 0,
status text not null, status text not null,
last_announced_time integer, last_announced_time integer,
single_announce integer, single_announce integer
added_on integer not null,
is_mine integer not null default 0
); );
create table if not exists stream ( create table if not exists stream (
@ -337,7 +335,6 @@ class SQLiteStorage(SQLiteMixin):
tcp_port integer, tcp_port integer,
unique (address, udp_port) unique (address, udp_port)
); );
create index if not exists blob_data on blob(blob_hash, blob_length, is_mine);
""" """
def __init__(self, conf: Config, path, loop=None, time_getter: typing.Optional[typing.Callable[[], float]] = None): def __init__(self, conf: Config, path, loop=None, time_getter: typing.Optional[typing.Callable[[], float]] = None):
@ -359,19 +356,19 @@ class SQLiteStorage(SQLiteMixin):
# # # # # # # # # blob functions # # # # # # # # # # # # # # # # # # blob functions # # # # # # # # #
async def add_blobs(self, *blob_hashes_and_lengths: typing.Tuple[str, int, int, int], finished=False): async def add_blobs(self, *blob_hashes_and_lengths: typing.Tuple[str, int], finished=False):
def _add_blobs(transaction: sqlite3.Connection): def _add_blobs(transaction: sqlite3.Connection):
transaction.executemany( transaction.executemany(
"insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?, ?, ?)", "insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?)",
( (
(blob_hash, length, 0, 0, "pending" if not finished else "finished", 0, 0, added_on, is_mine) (blob_hash, length, 0, 0, "pending" if not finished else "finished", 0, 0)
for blob_hash, length, added_on, is_mine in blob_hashes_and_lengths for blob_hash, length in blob_hashes_and_lengths
) )
).fetchall() ).fetchall()
if finished: if finished:
transaction.executemany( transaction.executemany(
"update blob set status='finished' where blob.blob_hash=?", ( "update blob set status='finished' where blob.blob_hash=?", (
(blob_hash, ) for blob_hash, _, _, _ in blob_hashes_and_lengths (blob_hash, ) for blob_hash, _ in blob_hashes_and_lengths
) )
).fetchall() ).fetchall()
return await self.db.run(_add_blobs) return await self.db.run(_add_blobs)
@ -381,11 +378,6 @@ class SQLiteStorage(SQLiteMixin):
"select status from blob where blob_hash=?", blob_hash "select status from blob where blob_hash=?", blob_hash
) )
def set_announce(self, *blob_hashes):
return self.db.execute_fetchall(
"update blob set should_announce=1 where blob_hash in (?, ?)", blob_hashes
)
def update_last_announced_blobs(self, blob_hashes: typing.List[str]): def update_last_announced_blobs(self, blob_hashes: typing.List[str]):
def _update_last_announced_blobs(transaction: sqlite3.Connection): def _update_last_announced_blobs(transaction: sqlite3.Connection):
last_announced = self.time_getter() last_announced = self.time_getter()
@ -443,62 +435,6 @@ class SQLiteStorage(SQLiteMixin):
def get_all_blob_hashes(self): def get_all_blob_hashes(self):
return self.run_and_return_list("select blob_hash from blob") return self.run_and_return_list("select blob_hash from blob")
async def get_stored_blobs(self, is_mine: bool, is_network_blob=False):
is_mine = 1 if is_mine else 0
if is_network_blob:
return await self.db.execute_fetchall(
"select blob.blob_hash, blob.blob_length, blob.added_on "
"from blob left join stream_blob using (blob_hash) "
"where stream_blob.stream_hash is null and blob.is_mine=? and blob.status='finished'"
"order by blob.blob_length desc, blob.added_on asc",
(is_mine,)
)
sd_blobs = await self.db.execute_fetchall(
"select blob.blob_hash, blob.blob_length, blob.added_on "
"from blob join stream on blob.blob_hash=stream.sd_hash join file using (stream_hash) "
"where blob.is_mine=? order by blob.added_on asc",
(is_mine,)
)
content_blobs = await self.db.execute_fetchall(
"select blob.blob_hash, blob.blob_length, blob.added_on "
"from blob join stream_blob using (blob_hash) cross join stream using (stream_hash)"
"cross join file using (stream_hash)"
"where blob.is_mine=? and blob.status='finished' order by blob.added_on asc, blob.blob_length asc",
(is_mine,)
)
return content_blobs + sd_blobs
async def get_stored_blob_disk_usage(self):
total, network_size, content_size, private_size = await self.db.execute_fetchone("""
select coalesce(sum(blob_length), 0) as total,
coalesce(sum(case when
stream_blob.stream_hash is null
then blob_length else 0 end), 0) as network_storage,
coalesce(sum(case when
stream_blob.blob_hash is not null and is_mine=0
then blob_length else 0 end), 0) as content_storage,
coalesce(sum(case when
is_mine=1
then blob_length else 0 end), 0) as private_storage
from blob left join stream_blob using (blob_hash)
where blob_hash not in (select sd_hash from stream) and blob.status="finished"
""")
return {
'network_storage': network_size,
'content_storage': content_size,
'private_storage': private_size,
'total': total
}
async def update_blob_ownership(self, sd_hash, is_mine: bool):
is_mine = 1 if is_mine else 0
await self.db.execute_fetchall(
"update blob set is_mine = ? where blob_hash in ("
" select blob_hash from blob natural join stream_blob natural join stream where sd_hash = ?"
") OR blob_hash = ?", (is_mine, sd_hash, sd_hash)
)
def sync_missing_blobs(self, blob_files: typing.Set[str]) -> typing.Awaitable[typing.Set[str]]: def sync_missing_blobs(self, blob_files: typing.Set[str]) -> typing.Awaitable[typing.Set[str]]:
def _sync_blobs(transaction: sqlite3.Connection) -> typing.Set[str]: def _sync_blobs(transaction: sqlite3.Connection) -> typing.Set[str]:
finished_blob_hashes = tuple( finished_blob_hashes = tuple(
@ -534,8 +470,7 @@ class SQLiteStorage(SQLiteMixin):
def _get_blobs_for_stream(transaction): def _get_blobs_for_stream(transaction):
crypt_blob_infos = [] crypt_blob_infos = []
stream_blobs = transaction.execute( stream_blobs = transaction.execute(
"select s.blob_hash, s.position, s.iv, b.added_on " "select blob_hash, position, iv from stream_blob where stream_hash=? "
"from stream_blob s left outer join blob b on b.blob_hash=s.blob_hash where stream_hash=? "
"order by position asc", (stream_hash, ) "order by position asc", (stream_hash, )
).fetchall() ).fetchall()
if only_completed: if only_completed:
@ -555,10 +490,9 @@ class SQLiteStorage(SQLiteMixin):
for blob_hash, length in lengths: for blob_hash, length in lengths:
blob_length_dict[blob_hash] = length blob_length_dict[blob_hash] = length
current_time = time.time() for blob_hash, position, iv in stream_blobs:
for blob_hash, position, iv, added_on in stream_blobs:
blob_length = blob_length_dict.get(blob_hash, 0) blob_length = blob_length_dict.get(blob_hash, 0)
crypt_blob_infos.append(BlobInfo(position, blob_length, iv, added_on or current_time, blob_hash)) crypt_blob_infos.append(BlobInfo(position, blob_length, iv, blob_hash))
if not blob_hash: if not blob_hash:
break break
return crypt_blob_infos return crypt_blob_infos
@ -636,10 +570,6 @@ class SQLiteStorage(SQLiteMixin):
log.debug("update file status %s -> %s", stream_hash, new_status) log.debug("update file status %s -> %s", stream_hash, new_status)
return self.db.execute_fetchall("update file set status=? where stream_hash=?", (new_status, stream_hash)) return self.db.execute_fetchall("update file set status=? where stream_hash=?", (new_status, stream_hash))
def stop_all_files(self):
log.debug("stopping all files")
return self.db.execute_fetchall("update file set status=?", ("stopped",))
async def change_file_download_dir_and_file_name(self, stream_hash: str, download_dir: typing.Optional[str], async def change_file_download_dir_and_file_name(self, stream_hash: str, download_dir: typing.Optional[str],
file_name: typing.Optional[str]): file_name: typing.Optional[str]):
if not file_name or not download_dir: if not file_name or not download_dir:
@ -687,7 +617,7 @@ class SQLiteStorage(SQLiteMixin):
).fetchall() ).fetchall()
download_dir = binascii.hexlify(self.conf.download_dir.encode()).decode() download_dir = binascii.hexlify(self.conf.download_dir.encode()).decode()
transaction.executemany( transaction.executemany(
"update file set download_directory=? where stream_hash=?", f"update file set download_directory=? where stream_hash=?",
((download_dir, stream_hash) for stream_hash in stream_hashes) ((download_dir, stream_hash) for stream_hash in stream_hashes)
).fetchall() ).fetchall()
await self.db.run_with_foreign_keys_disabled(_recover) await self.db.run_with_foreign_keys_disabled(_recover)
@ -793,7 +723,7 @@ class SQLiteStorage(SQLiteMixin):
await self.db.run(_save_claims) await self.db.run(_save_claims)
if update_file_callbacks: if update_file_callbacks:
await asyncio.wait(map(asyncio.create_task, update_file_callbacks)) await asyncio.wait(update_file_callbacks)
if claim_id_to_supports: if claim_id_to_supports:
await self.save_supports(claim_id_to_supports) await self.save_supports(claim_id_to_supports)
@ -931,6 +861,6 @@ class SQLiteStorage(SQLiteMixin):
transaction.execute('delete from peer').fetchall() transaction.execute('delete from peer').fetchall()
transaction.executemany( transaction.executemany(
'insert into peer(node_id, address, udp_port, tcp_port) values (?, ?, ?, ?)', 'insert into peer(node_id, address, udp_port, tcp_port) values (?, ?, ?, ?)',
((binascii.hexlify(p.node_id), p.address, p.udp_port, p.tcp_port) for p in peers) tuple([(binascii.hexlify(p.node_id), p.address, p.udp_port, p.tcp_port) for p in peers])
).fetchall() ).fetchall()
return await self.db.run(_save_kademlia_peers) return await self.db.run(_save_kademlia_peers)

View file

@ -5,7 +5,6 @@ from typing import Optional
from aiohttp.web import Request from aiohttp.web import Request
from lbry.error import ResolveError, DownloadSDTimeoutError, InsufficientFundsError from lbry.error import ResolveError, DownloadSDTimeoutError, InsufficientFundsError
from lbry.error import ResolveTimeoutError, DownloadDataTimeoutError, KeyFeeAboveMaxAllowedError from lbry.error import ResolveTimeoutError, DownloadDataTimeoutError, KeyFeeAboveMaxAllowedError
from lbry.error import InvalidStreamURLError
from lbry.stream.managed_stream import ManagedStream from lbry.stream.managed_stream import ManagedStream
from lbry.torrent.torrent_manager import TorrentSource from lbry.torrent.torrent_manager import TorrentSource
from lbry.utils import cache_concurrent from lbry.utils import cache_concurrent
@ -13,12 +12,11 @@ from lbry.schema.url import URL
from lbry.wallet.dewies import dewies_to_lbc from lbry.wallet.dewies import dewies_to_lbc
from lbry.file.source_manager import SourceManager from lbry.file.source_manager import SourceManager
from lbry.file.source import ManagedDownloadSource from lbry.file.source import ManagedDownloadSource
from lbry.extras.daemon.storage import StoredContentClaim
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from lbry.conf import Config from lbry.conf import Config
from lbry.extras.daemon.analytics import AnalyticsManager from lbry.extras.daemon.analytics import AnalyticsManager
from lbry.extras.daemon.storage import SQLiteStorage from lbry.extras.daemon.storage import SQLiteStorage
from lbry.wallet import WalletManager from lbry.wallet import WalletManager, Output
from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -50,10 +48,10 @@ class FileManager:
await manager.started.wait() await manager.started.wait()
self.started.set() self.started.set()
async def stop(self): def stop(self):
for manager in self.source_managers.values(): for manager in self.source_managers.values():
# fixme: pop or not? # fixme: pop or not?
await manager.stop() manager.stop()
self.started.clear() self.started.clear()
@cache_concurrent @cache_concurrent
@ -83,22 +81,18 @@ class FileManager:
payment = None payment = None
try: try:
# resolve the claim # resolve the claim
try:
if not URL.parse(uri).has_stream: if not URL.parse(uri).has_stream:
raise InvalidStreamURLError(uri) raise ResolveError("cannot download a channel claim, specify a /path")
except ValueError:
raise InvalidStreamURLError(uri)
try: try:
resolved_result = await asyncio.wait_for( resolved_result = await asyncio.wait_for(
self.wallet_manager.ledger.resolve( self.wallet_manager.ledger.resolve(wallet.accounts, [uri]),
wallet.accounts, [uri], resolve_timeout
include_purchase_receipt=True,
include_is_my_output=True
), resolve_timeout
) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
raise ResolveTimeoutError(uri) raise ResolveTimeoutError(uri)
except Exception as err: except Exception as err:
if isinstance(err, asyncio.CancelledError):
raise
log.exception("Unexpected error resolving stream:") log.exception("Unexpected error resolving stream:")
raise ResolveError(f"Unexpected error resolving stream: {str(err)}") raise ResolveError(f"Unexpected error resolving stream: {str(err)}")
if 'error' in resolved_result: if 'error' in resolved_result:
@ -120,11 +114,9 @@ class FileManager:
if claim.stream.source.bt_infohash: if claim.stream.source.bt_infohash:
source_manager = self.source_managers['torrent'] source_manager = self.source_managers['torrent']
existing = source_manager.get_filtered(bt_infohash=claim.stream.source.bt_infohash) existing = source_manager.get_filtered(bt_infohash=claim.stream.source.bt_infohash)
elif claim.stream.source.sd_hash: else:
source_manager = self.source_managers['stream'] source_manager = self.source_managers['stream']
existing = source_manager.get_filtered(sd_hash=claim.stream.source.sd_hash) existing = source_manager.get_filtered(sd_hash=claim.stream.source.sd_hash)
else:
raise ResolveError(f"There is nothing to download at {uri} - Source is unknown or unset")
# resume or update an existing stream, if the stream changed: download it and delete the old one after # resume or update an existing stream, if the stream changed: download it and delete the old one after
to_replace, updated_stream = None, None to_replace, updated_stream = None, None
@ -150,8 +142,6 @@ class FileManager:
log.info("claim contains an update to a stream we have, downloading it") log.info("claim contains an update to a stream we have, downloading it")
if save_file and existing_for_claim_id[0].output_file_exists: if save_file and existing_for_claim_id[0].output_file_exists:
save_file = False save_file = False
if not claim.stream.source.bt_infohash:
existing_for_claim_id[0].downloader.node = source_manager.node
await existing_for_claim_id[0].start(timeout=timeout, save_now=save_file) await existing_for_claim_id[0].start(timeout=timeout, save_now=save_file)
if not existing_for_claim_id[0].output_file_exists and ( if not existing_for_claim_id[0].output_file_exists and (
save_file or file_name or download_directory): save_file or file_name or download_directory):
@ -165,8 +155,6 @@ class FileManager:
log.info("already have stream for %s", uri) log.info("already have stream for %s", uri)
if save_file and updated_stream.output_file_exists: if save_file and updated_stream.output_file_exists:
save_file = False save_file = False
if not claim.stream.source.bt_infohash:
updated_stream.downloader.node = source_manager.node
await updated_stream.start(timeout=timeout, save_now=save_file) await updated_stream.start(timeout=timeout, save_now=save_file)
if not updated_stream.output_file_exists and (save_file or file_name or download_directory): if not updated_stream.output_file_exists and (save_file or file_name or download_directory):
await updated_stream.save_file( await updated_stream.save_file(
@ -178,14 +166,7 @@ class FileManager:
# pay fee # pay fee
#################### ####################
needs_purchasing = ( if not to_replace and txo.has_price and not txo.purchase_receipt:
not to_replace and
not txo.is_my_output and
txo.has_price and
not txo.purchase_receipt
)
if needs_purchasing:
payment = await self.wallet_manager.create_purchase_transaction( payment = await self.wallet_manager.create_purchase_transaction(
wallet.accounts, txo, exchange_rate_manager wallet.accounts, txo, exchange_rate_manager
) )
@ -193,24 +174,21 @@ class FileManager:
#################### ####################
# make downloader and wait for start # make downloader and wait for start
#################### ####################
# temporary with fields we know so downloader can start. Missing fields are populated later.
stored_claim = StoredContentClaim(outpoint=outpoint, claim_id=txo.claim_id, name=txo.claim_name,
amount=txo.amount, height=txo.tx_ref.height,
serialized=claim.to_bytes().hex())
if not claim.stream.source.bt_infohash: if not claim.stream.source.bt_infohash:
# fixme: this shouldnt be here # fixme: this shouldnt be here
stream = ManagedStream( stream = ManagedStream(
self.loop, self.config, source_manager.blob_manager, claim.stream.source.sd_hash, self.loop, self.config, source_manager.blob_manager, claim.stream.source.sd_hash,
download_directory, file_name, ManagedStream.STATUS_RUNNING, content_fee=payment, download_directory, file_name, ManagedStream.STATUS_RUNNING, content_fee=payment,
analytics_manager=self.analytics_manager, claim=stored_claim analytics_manager=self.analytics_manager
) )
stream.downloader.node = source_manager.node stream.downloader.node = source_manager.node
else: else:
stream = TorrentSource( stream = TorrentSource(
self.loop, self.config, self.storage, identifier=claim.stream.source.bt_infohash, self.loop, self.config, self.storage, identifier=claim.stream.source.bt_infohash,
file_name=file_name, download_directory=download_directory or self.config.download_dir, file_name=file_name, download_directory=download_directory or self.config.download_dir,
status=ManagedStream.STATUS_RUNNING, claim=stored_claim, analytics_manager=self.analytics_manager, status=ManagedStream.STATUS_RUNNING,
analytics_manager=self.analytics_manager,
torrent_session=source_manager.torrent_session torrent_session=source_manager.torrent_session
) )
log.info("starting download for %s", uri) log.info("starting download for %s", uri)
@ -242,14 +220,15 @@ class FileManager:
claim_info = await self.storage.get_content_claim_for_torrent(stream.identifier) claim_info = await self.storage.get_content_claim_for_torrent(stream.identifier)
stream.set_claim(claim_info, claim) stream.set_claim(claim_info, claim)
if save_file: if save_file:
await asyncio.wait_for(stream.save_file(), timeout - (self.loop.time() - before_download)) await asyncio.wait_for(stream.save_file(), timeout - (self.loop.time() - before_download),
loop=self.loop)
return stream return stream
except asyncio.TimeoutError: except asyncio.TimeoutError:
error = DownloadDataTimeoutError(stream.sd_hash) error = DownloadDataTimeoutError(stream.sd_hash)
raise error raise error
except (Exception, asyncio.CancelledError) as err: # forgive data timeout, don't delete stream except Exception as err: # forgive data timeout, don't delete stream
expected = (DownloadSDTimeoutError, DownloadDataTimeoutError, InsufficientFundsError, expected = (DownloadSDTimeoutError, DownloadDataTimeoutError, InsufficientFundsError,
KeyFeeAboveMaxAllowedError, ResolveError, InvalidStreamURLError) KeyFeeAboveMaxAllowedError)
if isinstance(err, expected): if isinstance(err, expected):
log.warning("Failed to download %s: %s", uri, str(err)) log.warning("Failed to download %s: %s", uri, str(err))
elif isinstance(err, asyncio.CancelledError): elif isinstance(err, asyncio.CancelledError):

View file

@ -45,12 +45,11 @@ class ManagedDownloadSource:
self.purchase_receipt = None self.purchase_receipt = None
self._added_on = added_on self._added_on = added_on
self.analytics_manager = analytics_manager self.analytics_manager = analytics_manager
self.downloader = None
self.saving = asyncio.Event() self.saving = asyncio.Event(loop=self.loop)
self.finished_writing = asyncio.Event() self.finished_writing = asyncio.Event(loop=self.loop)
self.started_writing = asyncio.Event() self.started_writing = asyncio.Event(loop=self.loop)
self.finished_write_attempt = asyncio.Event() self.finished_write_attempt = asyncio.Event(loop=self.loop)
# @classmethod # @classmethod
# async def create(cls, loop: asyncio.AbstractEventLoop, config: 'Config', file_path: str, # async def create(cls, loop: asyncio.AbstractEventLoop, config: 'Config', file_path: str,
@ -67,7 +66,7 @@ class ManagedDownloadSource:
async def save_file(self, file_name: Optional[str] = None, download_directory: Optional[str] = None): async def save_file(self, file_name: Optional[str] = None, download_directory: Optional[str] = None):
raise NotImplementedError() raise NotImplementedError()
async def stop_tasks(self): def stop_tasks(self):
raise NotImplementedError() raise NotImplementedError()
def set_claim(self, claim_info: typing.Dict, claim: 'Claim'): def set_claim(self, claim_info: typing.Dict, claim: 'Claim'):

View file

@ -27,7 +27,6 @@ class SourceManager:
'status', 'status',
'file_name', 'file_name',
'added_on', 'added_on',
'download_path',
'claim_name', 'claim_name',
'claim_height', 'claim_height',
'claim_id', 'claim_id',
@ -35,14 +34,7 @@ class SourceManager:
'txid', 'txid',
'nout', 'nout',
'channel_claim_id', 'channel_claim_id',
'channel_name', 'channel_name'
'completed'
}
set_filter_fields = {
"claim_ids": "claim_id",
"channel_claim_ids": "channel_claim_id",
"outpoints": "outpoint"
} }
source_class = ManagedDownloadSource source_class = ManagedDownloadSource
@ -54,16 +46,16 @@ class SourceManager:
self.storage = storage self.storage = storage
self.analytics_manager = analytics_manager self.analytics_manager = analytics_manager
self._sources: typing.Dict[str, ManagedDownloadSource] = {} self._sources: typing.Dict[str, ManagedDownloadSource] = {}
self.started = asyncio.Event() self.started = asyncio.Event(loop=self.loop)
def add(self, source: ManagedDownloadSource): def add(self, source: ManagedDownloadSource):
self._sources[source.identifier] = source self._sources[source.identifier] = source
async def remove(self, source: ManagedDownloadSource): def remove(self, source: ManagedDownloadSource):
if source.identifier not in self._sources: if source.identifier not in self._sources:
return return
self._sources.pop(source.identifier) self._sources.pop(source.identifier)
await source.stop_tasks() source.stop_tasks()
async def initialize_from_database(self): async def initialize_from_database(self):
raise NotImplementedError() raise NotImplementedError()
@ -72,10 +64,10 @@ class SourceManager:
await self.initialize_from_database() await self.initialize_from_database()
self.started.set() self.started.set()
async def stop(self): def stop(self):
while self._sources: while self._sources:
_, source = self._sources.popitem() _, source = self._sources.popitem()
await source.stop_tasks() source.stop_tasks()
self.started.clear() self.started.clear()
async def create(self, file_path: str, key: Optional[bytes] = None, async def create(self, file_path: str, key: Optional[bytes] = None,
@ -83,7 +75,7 @@ class SourceManager:
raise NotImplementedError() raise NotImplementedError()
async def delete(self, source: ManagedDownloadSource, delete_file: Optional[bool] = False): async def delete(self, source: ManagedDownloadSource, delete_file: Optional[bool] = False):
await self.remove(source) self.remove(source)
if delete_file and source.output_file_exists: if delete_file and source.output_file_exists:
os.remove(source.full_path) os.remove(source.full_path)
@ -103,36 +95,21 @@ class SourceManager:
raise ValueError(f"'{comparison}' is not a valid comparison") raise ValueError(f"'{comparison}' is not a valid comparison")
if 'full_status' in search_by: if 'full_status' in search_by:
del search_by['full_status'] del search_by['full_status']
for search in search_by: for search in search_by:
if search not in self.filter_fields: if search not in self.filter_fields:
raise ValueError(f"'{search}' is not a valid search operation") raise ValueError(f"'{search}' is not a valid search operation")
if search_by:
compare_sets = {}
if isinstance(search_by.get('claim_id'), list):
compare_sets['claim_ids'] = search_by.pop('claim_id')
if isinstance(search_by.get('outpoint'), list):
compare_sets['outpoints'] = search_by.pop('outpoint')
if isinstance(search_by.get('channel_claim_id'), list):
compare_sets['channel_claim_ids'] = search_by.pop('channel_claim_id')
if search_by or compare_sets:
comparison = comparison or 'eq' comparison = comparison or 'eq'
streams = [] sources = []
for stream in self._sources.values(): for stream in self._sources.values():
if compare_sets and not all( for search, val in search_by.items():
getattr(stream, self.set_filter_fields[set_search]) in val if COMPARISON_OPERATORS[comparison](getattr(stream, search), val):
for set_search, val in compare_sets.items()): sources.append(stream)
continue break
if search_by and not all(
COMPARISON_OPERATORS[comparison](getattr(stream, search), val)
for search, val in search_by.items()):
continue
streams.append(stream)
else: else:
streams = list(self._sources.values()) sources = list(self._sources.values())
if sort_by: if sort_by:
streams.sort(key=lambda s: getattr(s, sort_by) or "") sources.sort(key=lambda s: getattr(s, sort_by))
if reverse: if reverse:
streams.reverse() sources.reverse()
return streams return sources

View file

@ -3,17 +3,16 @@ import json
import logging import logging
import os import os
import pathlib import pathlib
import platform
import re import re
import shlex import shlex
import shutil import shutil
import subprocess import platform
from math import ceil
import lbry.utils import lbry.utils
from lbry.conf import TranscodeConfig from lbry.conf import TranscodeConfig
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
DISABLED = platform.system() == "Windows"
class VideoFileAnalyzer: class VideoFileAnalyzer:
@ -27,95 +26,63 @@ class VideoFileAnalyzer:
def __init__(self, conf: TranscodeConfig): def __init__(self, conf: TranscodeConfig):
self._conf = conf self._conf = conf
self._available_encoders = "" self._available_encoders = ""
self._ffmpeg_installed = None self._ffmpeg_installed = False
self._which_ffmpeg = None self._which = None
self._which_ffprobe = None
self._env_copy = dict(os.environ)
self._checked_ffmpeg = False self._checked_ffmpeg = False
self._env_copy = dict(os.environ)
if lbry.utils.is_running_from_bundle(): if lbry.utils.is_running_from_bundle():
# handle the situation where PyInstaller overrides our runtime environment: # handle the situation where PyInstaller overrides our runtime environment:
self._replace_or_pop_env('LD_LIBRARY_PATH') self._replace_or_pop_env('LD_LIBRARY_PATH')
@staticmethod async def _execute(self, command, arguments):
def _execute(command, environment): if DISABLED:
# log.debug("Executing: %s", command) return "Disabled on Windows", -1
try: args = shlex.split(arguments)
with subprocess.Popen( process = await asyncio.create_subprocess_exec(
shlex.split(command) if platform.system() != 'Windows' else command, os.path.join(self._conf.ffmpeg_folder, command), *args,
stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=environment stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, env=self._env_copy
) as process: )
(stdout, stderr) = process.communicate() # blocks until the process exits stdout, stderr = await process.communicate() # returns when the streams are closed
return stdout.decode(errors='replace') + stderr.decode(errors='replace'), process.returncode return stdout.decode(errors='replace') + stderr.decode(errors='replace'), process.returncode
except subprocess.SubprocessError as e:
return str(e), -1
# This create_subprocess_exec call is broken in Windows Python 3.7, but it's prettier than what's here. async def _verify_executable(self, name):
# The recommended fix is switching to ProactorEventLoop, but that breaks UDP in Linux Python 3.7.
# We work around that issue here by using run_in_executor. Check it again in Python 3.8.
async def _execute_ffmpeg(self, arguments):
arguments = self._which_ffmpeg + " " + arguments
return await asyncio.get_event_loop().run_in_executor(None, self._execute, arguments, self._env_copy)
async def _execute_ffprobe(self, arguments):
arguments = self._which_ffprobe + " " + arguments
return await asyncio.get_event_loop().run_in_executor(None, self._execute, arguments, self._env_copy)
async def _verify_executables(self):
try: try:
await self._execute_ffprobe("-version") version, code = await self._execute(name, "-version")
version, code = await self._execute_ffmpeg("-version")
except Exception as e: except Exception as e:
code = -1 code = -1
version = str(e) version = str(e)
if code != 0 or not version.startswith("ffmpeg"): if code != 0 or not version.startswith(name):
log.warning("Unable to run ffmpeg, but it was requested. Code: %d; Message: %s", code, version) log.warning("Unable to run %s, but it was requested. Code: %d; Message: %s", name, code, version)
raise FileNotFoundError("Unable to locate or run ffmpeg or ffprobe. Please install FFmpeg " raise FileNotFoundError(f"Unable to locate or run {name}. Please install FFmpeg "
"and ensure that it is callable via PATH or conf.ffmpeg_path") f"and ensure that it is callable via PATH or conf.ffmpeg_folder")
log.debug("Using %s at %s", version.splitlines()[0].split(" Copyright")[0], self._which_ffmpeg)
return version return version
@staticmethod
def _which_ffmpeg_and_ffmprobe(path):
return shutil.which("ffmpeg", path=path), shutil.which("ffprobe", path=path)
async def _verify_ffmpeg_installed(self): async def _verify_ffmpeg_installed(self):
if self._ffmpeg_installed: if self._ffmpeg_installed:
return return
self._ffmpeg_installed = False await self._verify_executable("ffprobe")
path = self._conf.ffmpeg_path version = await self._verify_executable("ffmpeg")
if hasattr(self._conf, "data_dir"): self._which = shutil.which(os.path.join(self._conf.ffmpeg_folder, "ffmpeg"))
path += os.path.pathsep + os.path.join(getattr(self._conf, "data_dir"), "ffmpeg", "bin")
path += os.path.pathsep + self._env_copy.get("PATH", "")
self._which_ffmpeg, self._which_ffprobe = await asyncio.get_running_loop().run_in_executor(
None, self._which_ffmpeg_and_ffmprobe, path
)
if not self._which_ffmpeg:
log.warning("Unable to locate ffmpeg executable. Path: %s", path)
raise FileNotFoundError(f"Unable to locate ffmpeg executable. Path: {path}")
if not self._which_ffprobe:
log.warning("Unable to locate ffprobe executable. Path: %s", path)
raise FileNotFoundError(f"Unable to locate ffprobe executable. Path: {path}")
if os.path.dirname(self._which_ffmpeg) != os.path.dirname(self._which_ffprobe):
log.warning("ffmpeg and ffprobe are in different folders!")
await self._verify_executables()
self._ffmpeg_installed = True self._ffmpeg_installed = True
log.debug("Using %s at %s", version.splitlines()[0].split(" Copyright")[0], self._which)
async def status(self, reset=False, recheck=False): async def status(self, reset=False, recheck=False):
if reset: if reset:
self._available_encoders = "" self._available_encoders = ""
self._ffmpeg_installed = None self._ffmpeg_installed = False
self._which = None
if self._checked_ffmpeg and not recheck: if self._checked_ffmpeg and not recheck:
pass installed = self._ffmpeg_installed
elif self._ffmpeg_installed is None: else:
installed = True
try: try:
await self._verify_ffmpeg_installed() await self._verify_ffmpeg_installed()
except FileNotFoundError: except FileNotFoundError:
pass installed = False
self._checked_ffmpeg = True self._checked_ffmpeg = True
return { return {
"available": self._ffmpeg_installed, "available": installed,
"which": self._which_ffmpeg, "which": self._which,
"analyze_audio_volume": int(self._conf.volume_analysis_time) > 0 "analyze_audio_volume": int(self._conf.volume_analysis_time) > 0
} }
@ -123,24 +90,9 @@ class VideoFileAnalyzer:
def _verify_container(scan_data: json): def _verify_container(scan_data: json):
container = scan_data["format"]["format_name"] container = scan_data["format"]["format_name"]
log.debug(" Detected container is %s", container) log.debug(" Detected container is %s", container)
splits = container.split(",") if not {"webm", "mp4", "3gp", "ogg"}.intersection(container.split(",")):
if not {"webm", "mp4", "3gp", "ogg"}.intersection(splits):
return "Container format is not in the approved list of WebM, MP4. " \ return "Container format is not in the approved list of WebM, MP4. " \
f"Actual: {container} [{scan_data['format']['format_long_name']}]" f"Actual: {container} [{scan_data['format']['format_long_name']}]"
if "matroska" in splits:
for stream in scan_data["streams"]:
if stream["codec_type"] == "video":
codec = stream["codec_name"]
if not {"vp8", "vp9", "av1"}.intersection(codec.split(",")):
return "WebM format requires VP8/9 or AV1 video. " \
f"Actual: {codec} [{stream['codec_long_name']}]"
elif stream["codec_type"] == "audio":
codec = stream["codec_name"]
if not {"vorbis", "opus"}.intersection(codec.split(",")):
return "WebM format requires Vorbis or Opus audio. " \
f"Actual: {codec} [{stream['codec_long_name']}]"
return "" return ""
@staticmethod @staticmethod
@ -169,12 +121,12 @@ class VideoFileAnalyzer:
bit_rate = float(scan_data["format"]["bit_rate"]) bit_rate = float(scan_data["format"]["bit_rate"])
else: else:
bit_rate = os.stat(file_path).st_size / float(scan_data["format"]["duration"]) bit_rate = os.stat(file_path).st_size / float(scan_data["format"]["duration"])
log.debug(" Detected bitrate is %s Mbps. Allowed max: %s Mbps", log.debug(" Detected bitrate is %s Mbps. Allowed is %s Mbps",
str(bit_rate / 1000000.0), str(bit_rate_max / 1000000.0)) str(bit_rate / 1000000.0), str(bit_rate_max / 1000000.0))
if bit_rate > bit_rate_max: if bit_rate > bit_rate_max:
return "The bit rate is above the configured maximum. Actual: " \ return "The bit rate is above the configured maximum. Actual: " \
f"{bit_rate / 1000000.0} Mbps; Allowed max: {bit_rate_max / 1000000.0} Mbps" f"{bit_rate / 1000000.0} Mbps; Allowed: {bit_rate_max / 1000000.0} Mbps"
return "" return ""
@ -183,7 +135,7 @@ class VideoFileAnalyzer:
if {"webm", "ogg"}.intersection(container.split(",")): if {"webm", "ogg"}.intersection(container.split(",")):
return "" return ""
result, _ = await self._execute_ffprobe(f'-v debug "{video_file}"') result, _ = await self._execute("ffprobe", f'-v debug "{video_file}"')
match = re.search(r"Before avformat_find_stream_info.+?\s+seeks:(\d+)\s+", result) match = re.search(r"Before avformat_find_stream_info.+?\s+seeks:(\d+)\s+", result)
if match and int(match.group(1)) != 0: if match and int(match.group(1)) != 0:
return "Video stream descriptors are not at the start of the file (the faststart flag was not used)." return "Video stream descriptors are not at the start of the file (the faststart flag was not used)."
@ -199,8 +151,6 @@ class VideoFileAnalyzer:
if not {"aac", "mp3", "flac", "vorbis", "opus"}.intersection(codec.split(",")): if not {"aac", "mp3", "flac", "vorbis", "opus"}.intersection(codec.split(",")):
return "Audio codec is not in the approved list of AAC, FLAC, MP3, Vorbis, and Opus. " \ return "Audio codec is not in the approved list of AAC, FLAC, MP3, Vorbis, and Opus. " \
f"Actual: {codec} [{stream['codec_long_name']}]" f"Actual: {codec} [{stream['codec_long_name']}]"
if int(stream['sample_rate']) > 48000:
return "Sample rate out of range"
return "" return ""
@ -213,7 +163,7 @@ class VideoFileAnalyzer:
if not validate_volume: if not validate_volume:
return "" return ""
result, _ = await self._execute_ffmpeg(f'-i "{video_file}" -t {seconds} ' result, _ = await self._execute("ffmpeg", f'-i "{video_file}" -t {seconds} '
f'-af volumedetect -vn -sn -dn -f null "{os.devnull}"') f'-af volumedetect -vn -sn -dn -f null "{os.devnull}"')
try: try:
mean_volume = float(re.search(r"mean_volume:\s+([-+]?\d*\.\d+|\d+)", result).group(1)) mean_volume = float(re.search(r"mean_volume:\s+([-+]?\d*\.\d+|\d+)", result).group(1))
@ -249,7 +199,7 @@ class VideoFileAnalyzer:
# if we don't have h264 use vp9; it's fairly compatible even though it's slow # if we don't have h264 use vp9; it's fairly compatible even though it's slow
if not self._available_encoders: if not self._available_encoders:
self._available_encoders, _ = await self._execute_ffmpeg("-encoders -v quiet") self._available_encoders, _ = await self._execute("ffmpeg", "-encoders -v quiet")
encoder = self._conf.video_encoder.split(" ", 1)[0] encoder = self._conf.video_encoder.split(" ", 1)[0]
if re.search(fr"^\s*V..... {encoder} ", self._available_encoders, re.MULTILINE): if re.search(fr"^\s*V..... {encoder} ", self._available_encoders, re.MULTILINE):
@ -283,7 +233,7 @@ class VideoFileAnalyzer:
wants_opus = extension != "mp4" wants_opus = extension != "mp4"
if not self._available_encoders: if not self._available_encoders:
self._available_encoders, _ = await self._execute_ffmpeg("-encoders -v quiet") self._available_encoders, _ = await self._execute("ffmpeg", "-encoders -v quiet")
encoder = self._conf.audio_encoder.split(" ", 1)[0] encoder = self._conf.audio_encoder.split(" ", 1)[0]
if wants_opus and 'opus' in encoder: if wants_opus and 'opus' in encoder:
@ -306,6 +256,9 @@ class VideoFileAnalyzer:
raise Exception(f"The audio encoder is not available. Requested: {encoder or 'aac'}") raise Exception(f"The audio encoder is not available. Requested: {encoder or 'aac'}")
async def _get_volume_filter(self):
return self._conf.volume_filter if self._conf.volume_filter else "-af loudnorm"
@staticmethod @staticmethod
def _get_best_container_extension(scan_data, video_encoder): def _get_best_container_extension(scan_data, video_encoder):
# the container is chosen by the video format # the container is chosen by the video format
@ -313,13 +266,7 @@ class VideoFileAnalyzer:
# if we are vp8/vp9/av1 we want webm # if we are vp8/vp9/av1 we want webm
# use mp4 for anything else # use mp4 for anything else
if video_encoder: # not re-encoding video if not video_encoder: # not re-encoding video
if "theora" in video_encoder:
return "ogv"
if re.search(r"vp[89x]|av1", video_encoder.split(" ", 1)[0]):
return "webm"
return "mp4"
for stream in scan_data["streams"]: for stream in scan_data["streams"]:
if stream["codec_type"] != "video": if stream["codec_type"] != "video":
continue continue
@ -329,11 +276,15 @@ class VideoFileAnalyzer:
if {"vp8", "vp9", "av1"}.intersection(codec): if {"vp8", "vp9", "av1"}.intersection(codec):
return "webm" return "webm"
if "theora" in video_encoder:
return "ogv"
elif re.search(r"vp[89x]|av1", video_encoder.split(" ", 1)[0]):
return "webm"
return "mp4" return "mp4"
async def _get_scan_data(self, validate, file_path): async def _get_scan_data(self, validate, file_path):
arguments = f'-v quiet -print_format json -show_format -show_streams "{file_path}"' result, _ = await self._execute("ffprobe",
result, _ = await self._execute_ffprobe(arguments) f'-v quiet -print_format json -show_format -show_streams "{file_path}"')
try: try:
scan_data = json.loads(result) scan_data = json.loads(result)
except Exception as e: except Exception as e:
@ -342,7 +293,7 @@ class VideoFileAnalyzer:
if "format" not in scan_data or "duration" not in scan_data["format"]: if "format" not in scan_data or "duration" not in scan_data["format"]:
log.debug("Format data is missing from ffprobe results for: %s", file_path) log.debug("Format data is missing from ffprobe results for: %s", file_path)
raise ValueError(f'Media file does not appear to contain video content: {file_path}') raise ValueError(f'Media file does not appear to contain video content at: {file_path}')
if float(scan_data["format"]["duration"]) < 0.1: if float(scan_data["format"]["duration"]) < 0.1:
log.debug("Media file appears to be an image: %s", file_path) log.debug("Media file appears to be an image: %s", file_path)
@ -350,46 +301,20 @@ class VideoFileAnalyzer:
return scan_data return scan_data
@staticmethod
def _build_spec(scan_data):
assert scan_data
duration = ceil(float(scan_data["format"]["duration"])) # existence verified when scan_data made
width = -1
height = -1
for stream in scan_data["streams"]:
if stream["codec_type"] != "video":
continue
width = max(width, int(stream["width"]))
height = max(height, int(stream["height"]))
log.debug(" Detected duration: %d sec. with resolution: %d x %d", duration, width, height)
spec = {"duration": duration}
if height >= 0:
spec["height"] = height
if width >= 0:
spec["width"] = width
return spec
async def verify_or_repair(self, validate, repair, file_path, ignore_non_video=False): async def verify_or_repair(self, validate, repair, file_path, ignore_non_video=False):
if not validate and not repair: if not validate and not repair:
return file_path, {} return file_path
if ignore_non_video and not file_path:
return file_path, {}
await self._verify_ffmpeg_installed() await self._verify_ffmpeg_installed()
try: try:
scan_data = await self._get_scan_data(validate, file_path) scan_data = await self._get_scan_data(validate, file_path)
except ValueError: except ValueError:
if ignore_non_video: if ignore_non_video:
return file_path, {} return file_path
raise raise
fast_start_msg = await self._verify_fast_start(scan_data, file_path) fast_start_msg = await self._verify_fast_start(scan_data, file_path)
log.debug("Analyzing %s:", file_path) log.debug("Analyzing %s:", file_path)
spec = self._build_spec(scan_data)
log.debug(" Detected faststart is %s", "false" if fast_start_msg else "true") log.debug(" Detected faststart is %s", "false" if fast_start_msg else "true")
container_msg = self._verify_container(scan_data) container_msg = self._verify_container(scan_data)
bitrate_msg = self._verify_bitrate(scan_data, file_path) bitrate_msg = self._verify_bitrate(scan_data, file_path)
@ -399,7 +324,7 @@ class VideoFileAnalyzer:
messages = [container_msg, bitrate_msg, fast_start_msg, video_msg, audio_msg, volume_msg] messages = [container_msg, bitrate_msg, fast_start_msg, video_msg, audio_msg, volume_msg]
if not any(messages): if not any(messages):
return file_path, spec return file_path
if not repair: if not repair:
errors = ["Streamability verification failed:"] errors = ["Streamability verification failed:"]
@ -409,7 +334,7 @@ class VideoFileAnalyzer:
# the plan for transcoding: # the plan for transcoding:
# we have to re-encode the video if it is in a nonstandard format # we have to re-encode the video if it is in a nonstandard format
# we also re-encode if we are h264 but not yuv420p (both errors caught in video_msg) # we also re-encode if we are h264 but not yuv420p (both errors caught in video_msg)
# we also re-encode if our bitrate or sample rate is too high # we also re-encode if our bitrate is too high
try: try:
transcode_command = [f'-i "{file_path}" -y -c:s copy -c:d copy -c:v'] transcode_command = [f'-i "{file_path}" -y -c:s copy -c:d copy -c:v']
@ -424,26 +349,25 @@ class VideoFileAnalyzer:
transcode_command.append("copy") transcode_command.append("copy")
transcode_command.append("-movflags +faststart -c:a") transcode_command.append("-movflags +faststart -c:a")
path = pathlib.Path(file_path)
extension = self._get_best_container_extension(scan_data, video_encoder) extension = self._get_best_container_extension(scan_data, video_encoder)
if audio_msg or volume_msg: if audio_msg or volume_msg:
audio_encoder = await self._get_audio_encoder(extension) audio_encoder = await self._get_audio_encoder(extension)
transcode_command.append(audio_encoder) transcode_command.append(audio_encoder)
if volume_msg and self._conf.volume_filter: if volume_msg:
transcode_command.append(self._conf.volume_filter) volume_filter = await self._get_volume_filter()
if audio_msg == "Sample rate out of range": transcode_command.append(volume_filter)
transcode_command.append(" -ar 48000 ")
else: else:
transcode_command.append("copy") transcode_command.append("copy")
# TODO: put it in a temp folder and delete it after we upload? # TODO: put it in a temp folder and delete it after we upload?
path = pathlib.Path(file_path)
output = path.parent / f"{path.stem}_fixed.{extension}" output = path.parent / f"{path.stem}_fixed.{extension}"
transcode_command.append(f'"{output}"') transcode_command.append(f'"{output}"')
ffmpeg_command = " ".join(transcode_command) ffmpeg_command = " ".join(transcode_command)
log.info("Proceeding on transcode via: ffmpeg %s", ffmpeg_command) log.info("Proceeding on transcode via: ffmpeg %s", ffmpeg_command)
result, code = await self._execute_ffmpeg(ffmpeg_command) result, code = await self._execute("ffmpeg", ffmpeg_command)
if code != 0: if code != 0:
raise Exception(f"Failure to complete the transcode command. Output: {result}") raise Exception(f"Failure to complete the transcode command. Output: {result}")
except Exception as e: except Exception as e:
@ -451,6 +375,6 @@ class VideoFileAnalyzer:
raise raise
log.info("Unable to transcode %s . Message: %s", file_path, str(e)) log.info("Unable to transcode %s . Message: %s", file_path, str(e))
# TODO: delete partial output file here if it exists? # TODO: delete partial output file here if it exists?
return file_path, spec return file_path
return str(output), spec return str(output)

View file

@ -1,68 +0,0 @@
import time
import logging
import asyncio
import asyncio.tasks
from aiohttp import web
from prometheus_client import generate_latest as prom_generate_latest
from prometheus_client import Counter, Histogram, Gauge
PROBES_IN_FLIGHT = Counter("probes_in_flight", "Number of loop probes in flight", namespace='asyncio')
PROBES_FINISHED = Counter("probes_finished", "Number of finished loop probes", namespace='asyncio')
PROBE_TIMES = Histogram("probe_times", "Loop probe times", namespace='asyncio')
TASK_COUNT = Gauge("running_tasks", "Number of running tasks", namespace='asyncio')
def get_loop_metrics(delay=1):
loop = asyncio.get_event_loop()
def callback(started):
PROBE_TIMES.observe(time.perf_counter() - started - delay)
PROBES_FINISHED.inc()
async def monitor_loop_responsiveness():
while True:
now = time.perf_counter()
loop.call_later(delay, callback, now)
PROBES_IN_FLIGHT.inc()
TASK_COUNT.set(len(asyncio.tasks._all_tasks))
await asyncio.sleep(delay)
return loop.create_task(monitor_loop_responsiveness())
class PrometheusServer:
def __init__(self, logger=None):
self.runner = None
self.logger = logger or logging.getLogger(__name__)
self._monitor_loop_task = None
async def start(self, interface: str, port: int):
self.logger.info("start prometheus metrics")
prom_app = web.Application()
prom_app.router.add_get('/metrics', self.handle_metrics_get_request)
self.runner = web.AppRunner(prom_app)
await self.runner.setup()
metrics_site = web.TCPSite(self.runner, interface, port, shutdown_timeout=.5)
await metrics_site.start()
self.logger.info(
'prometheus metrics server listening on %s:%i', *metrics_site._server.sockets[0].getsockname()[:2]
)
self._monitor_loop_task = get_loop_metrics()
async def handle_metrics_get_request(self, request: web.Request):
try:
return web.Response(
text=prom_generate_latest().decode(),
content_type='text/plain; version=0.0.4'
)
except Exception:
self.logger.exception('could not generate prometheus data')
raise
async def stop(self):
if self._monitor_loop_task and not self._monitor_loop_task.done():
self._monitor_loop_task.cancel()
self._monitor_loop_task = None
await self.runner.cleanup()

View file

@ -2,5 +2,4 @@ build:
rm types/v2/* -rf rm types/v2/* -rf
touch types/v2/__init__.py touch types/v2/__init__.py
cd types/v2/ && protoc --python_out=. -I ../../../../../types/v2/proto/ ../../../../../types/v2/proto/*.proto cd types/v2/ && protoc --python_out=. -I ../../../../../types/v2/proto/ ../../../../../types/v2/proto/*.proto
cd types/v2/ && cp ../../../../../types/jsonschema/* ./
sed -e 's/^import\ \(.*\)_pb2\ /from . import\ \1_pb2\ /g' -i types/v2/*.py sed -e 's/^import\ \(.*\)_pb2\ /from . import\ \1_pb2\ /g' -i types/v2/*.py

View file

@ -1,24 +0,0 @@
Schema
=====
Those files are generated from the [types repo](https://github.com/lbryio/types). If you are modifying/adding a new type, make sure it is cloned in the same root folder as the SDK repo, like:
```
repos/
- lbry-sdk/
- types/
```
Then, [download protoc 3.2.0](https://github.com/protocolbuffers/protobuf/releases/tag/v3.2.0), add it to your PATH. On linux it is:
```bash
cd ~/.local/bin
wget https://github.com/protocolbuffers/protobuf/releases/download/v3.2.0/protoc-3.2.0-linux-x86_64.zip
unzip protoc-3.2.0-linux-x86_64.zip bin/protoc -d..
```
Finally, `make` should update everything in place.
### Why protoc 3.2.0?
Different/newer versions will generate larger diffs and we need to make sure they are good. In theory, we can just update to latest and it will all work, but it is a good practice to check blockchain data and retro compatibility before bumping versions (if you do, please update this section!).

View file

@ -10,7 +10,6 @@ from google.protobuf.json_format import MessageToDict
from lbry.crypto.base58 import Base58 from lbry.crypto.base58 import Base58
from lbry.constants import COIN from lbry.constants import COIN
from lbry.error import MissingPublishedFileError, EmptyPublishedFileError
from lbry.schema.mime_types import guess_media_type from lbry.schema.mime_types import guess_media_type
from lbry.schema.base import Metadata, BaseMessageList from lbry.schema.base import Metadata, BaseMessageList
@ -33,17 +32,6 @@ def calculate_sha384_file_hash(file_path):
return sha384.digest() return sha384.digest()
def country_int_to_str(country: int) -> str:
r = LocationMessage.Country.Name(country)
return r[1:] if r.startswith('R') else r
def country_str_to_int(country: str) -> int:
if len(country) == 3:
country = 'R' + country
return LocationMessage.Country.Value(country)
class Dimmensional(Metadata): class Dimmensional(Metadata):
__slots__ = () __slots__ = ()
@ -140,10 +128,10 @@ class Source(Metadata):
self.name = os.path.basename(file_path) self.name = os.path.basename(file_path)
self.media_type, stream_type = guess_media_type(file_path) self.media_type, stream_type = guess_media_type(file_path)
if not os.path.isfile(file_path): if not os.path.isfile(file_path):
raise MissingPublishedFileError(file_path) raise Exception(f"File does not exist: {file_path}")
self.size = os.path.getsize(file_path) self.size = os.path.getsize(file_path)
if self.size == 0: if self.size == 0:
raise EmptyPublishedFileError(file_path) raise Exception(f"Cannot publish empty file: {file_path}")
self.file_hash_bytes = calculate_sha384_file_hash(file_path) self.file_hash_bytes = calculate_sha384_file_hash(file_path)
return stream_type return stream_type
@ -435,11 +423,14 @@ class Language(Metadata):
@property @property
def region(self) -> str: def region(self) -> str:
if self.message.region: if self.message.region:
return country_int_to_str(self.message.region) r = LocationMessage.Country.Name(self.message.region)
return r[1:] if r.startswith('R') else r
@region.setter @region.setter
def region(self, region: str): def region(self, region: str):
self.message.region = country_str_to_int(region) if len(region) == 3:
region = 'R'+region
self.message.region = LocationMessage.Country.Value(region)
class LanguageList(BaseMessageList[Language]): class LanguageList(BaseMessageList[Language]):

View file

@ -2,9 +2,6 @@ import logging
from typing import List from typing import List
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from asn1crypto.keys import PublicKeyInfo
from coincurve import PublicKey as cPublicKey
from google.protobuf.json_format import MessageToDict from google.protobuf.json_format import MessageToDict
from google.protobuf.message import DecodeError from google.protobuf.message import DecodeError
from hachoir.core.log import log as hachoir_log from hachoir.core.log import log as hachoir_log
@ -33,10 +30,14 @@ class Claim(Signable):
COLLECTION = 'collection' COLLECTION = 'collection'
REPOST = 'repost' REPOST = 'repost'
__slots__ = () __slots__ = 'version',
message_class = ClaimMessage message_class = ClaimMessage
def __init__(self, message=None):
super().__init__(message)
self.version = 2
@property @property
def claim_type(self) -> str: def claim_type(self) -> str:
return self.message.WhichOneof('type') return self.message.WhichOneof('type')
@ -252,7 +253,7 @@ class Stream(BaseClaim):
if stream_type in ('image', 'video', 'audio'): if stream_type in ('image', 'video', 'audio'):
media = getattr(self, stream_type) media = getattr(self, stream_type)
media_args = {'file_metadata': None} media_args = {'file_metadata': None}
if file_path is not None and not all((duration, width, height)): if file_path is not None:
try: try:
media_args['file_metadata'] = binary_file_metadata(binary_file_parser(file_path)) media_args['file_metadata'] = binary_file_metadata(binary_file_parser(file_path))
except: except:
@ -306,10 +307,6 @@ class Stream(BaseClaim):
def has_fee(self) -> bool: def has_fee(self) -> bool:
return self.message.HasField('fee') return self.message.HasField('fee')
@property
def has_source(self) -> bool:
return self.message.HasField('source')
@property @property
def source(self) -> Source: def source(self) -> Source:
return Source(self.message.source) return Source(self.message.source)
@ -349,7 +346,7 @@ class Channel(BaseClaim):
@property @property
def public_key(self) -> str: def public_key(self) -> str:
return hexlify(self.public_key_bytes).decode() return hexlify(self.message.public_key).decode()
@public_key.setter @public_key.setter
def public_key(self, sd_public_key: str): def public_key(self, sd_public_key: str):
@ -357,11 +354,7 @@ class Channel(BaseClaim):
@property @property
def public_key_bytes(self) -> bytes: def public_key_bytes(self) -> bytes:
if len(self.message.public_key) == 33:
return self.message.public_key return self.message.public_key
public_key_info = PublicKeyInfo.load(self.message.public_key)
public_key = cPublicKey(public_key_info.native['public_key'])
return public_key.format(compressed=True)
@public_key_bytes.setter @public_key_bytes.setter
def public_key_bytes(self, public_key: bytes): def public_key_bytes(self, public_key: bytes):
@ -398,12 +391,6 @@ class Repost(BaseClaim):
claim_type = Claim.REPOST claim_type = Claim.REPOST
def to_dict(self):
claim = super().to_dict()
if claim.pop('claim_hash', None):
claim['claim_id'] = self.reference.claim_id
return claim
@property @property
def reference(self) -> ClaimReference: def reference(self) -> ClaimReference:
return ClaimReference(self.message) return ClaimReference(self.message)

View file

@ -1,6 +1,4 @@
import os import os
import filetype
import logging
types_map = { types_map = {
# http://www.iana.org/assignments/media-types # http://www.iana.org/assignments/media-types
@ -48,8 +46,8 @@ types_map = {
'.ksh': ('text/plain', 'document'), '.ksh': ('text/plain', 'document'),
'.latex': ('application/x-latex', 'binary'), '.latex': ('application/x-latex', 'binary'),
'.m1v': ('video/mpeg', 'video'), '.m1v': ('video/mpeg', 'video'),
'.m3u': ('application/x-mpegurl', 'audio'), '.m3u': ('application/vnd.apple.mpegurl', 'audio'),
'.m3u8': ('application/x-mpegurl', 'video'), '.m3u8': ('application/vnd.apple.mpegurl', 'audio'),
'.man': ('application/x-troff-man', 'document'), '.man': ('application/x-troff-man', 'document'),
'.markdown': ('text/markdown', 'document'), '.markdown': ('text/markdown', 'document'),
'.md': ('text/markdown', 'document'), '.md': ('text/markdown', 'document'),
@ -63,12 +61,10 @@ types_map = {
'.mp3': ('audio/mpeg', 'audio'), '.mp3': ('audio/mpeg', 'audio'),
'.mp4': ('video/mp4', 'video'), '.mp4': ('video/mp4', 'video'),
'.mpa': ('video/mpeg', 'video'), '.mpa': ('video/mpeg', 'video'),
'.mpd': ('application/dash+xml', 'video'),
'.mpe': ('video/mpeg', 'video'), '.mpe': ('video/mpeg', 'video'),
'.mpeg': ('video/mpeg', 'video'), '.mpeg': ('video/mpeg', 'video'),
'.mpg': ('video/mpeg', 'video'), '.mpg': ('video/mpeg', 'video'),
'.ms': ('application/x-troff-ms', 'binary'), '.ms': ('application/x-troff-ms', 'binary'),
'.m4s': ('video/iso.segment', 'binary'),
'.nc': ('application/x-netcdf', 'binary'), '.nc': ('application/x-netcdf', 'binary'),
'.nws': ('message/rfc822', 'document'), '.nws': ('message/rfc822', 'document'),
'.o': ('application/octet-stream', 'binary'), '.o': ('application/octet-stream', 'binary'),
@ -122,12 +118,10 @@ types_map = {
'.tif': ('image/tiff', 'image'), '.tif': ('image/tiff', 'image'),
'.tiff': ('image/tiff', 'image'), '.tiff': ('image/tiff', 'image'),
'.tr': ('application/x-troff', 'binary'), '.tr': ('application/x-troff', 'binary'),
'.ts': ('video/mp2t', 'video'),
'.tsv': ('text/tab-separated-values', 'document'), '.tsv': ('text/tab-separated-values', 'document'),
'.txt': ('text/plain', 'document'), '.txt': ('text/plain', 'document'),
'.ustar': ('application/x-ustar', 'binary'), '.ustar': ('application/x-ustar', 'binary'),
'.vcf': ('text/x-vcard', 'document'), '.vcf': ('text/x-vcard', 'document'),
'.vtt': ('text/vtt', 'document'),
'.wav': ('audio/x-wav', 'audio'), '.wav': ('audio/x-wav', 'audio'),
'.webm': ('video/webm', 'video'), '.webm': ('video/webm', 'video'),
'.wiz': ('application/msword', 'document'), '.wiz': ('application/msword', 'document'),
@ -147,7 +141,6 @@ types_map = {
'.cbz': ('application/vnd.comicbook+zip', 'document'), '.cbz': ('application/vnd.comicbook+zip', 'document'),
'.flac': ('audio/flac', 'audio'), '.flac': ('audio/flac', 'audio'),
'.lbry': ('application/x-ext-lbry', 'document'), '.lbry': ('application/x-ext-lbry', 'document'),
'.m4a': ('audio/mp4', 'audio'),
'.m4v': ('video/m4v', 'video'), '.m4v': ('video/m4v', 'video'),
'.mid': ('audio/midi', 'audio'), '.mid': ('audio/midi', 'audio'),
'.midi': ('audio/midi', 'audio'), '.midi': ('audio/midi', 'audio'),
@ -168,38 +161,10 @@ types_map = {
'.wmv': ('video/x-ms-wmv', 'video') '.wmv': ('video/x-ms-wmv', 'video')
} }
# maps detected extensions to the possible analogs
# i.e. .cbz file is actually a .zip
synonyms_map = {
'.zip': ['.cbz'],
'.rar': ['.cbr'],
'.ar': ['.a']
}
log = logging.getLogger(__name__)
def guess_media_type(path): def guess_media_type(path):
_, ext = os.path.splitext(path) _, ext = os.path.splitext(path)
extension = ext.strip().lower() extension = ext.strip().lower()
try:
kind = filetype.guess(path)
if kind:
real_extension = f".{kind.extension}"
if extension != real_extension:
if extension:
log.warning(f"file extension does not match it's contents: {path}, identified as {real_extension}")
else:
log.debug(f"file {path} does not have extension, identified by it's contents as {real_extension}")
if extension not in synonyms_map.get(real_extension, []):
extension = real_extension
except OSError as error:
pass
if extension[1:]: if extension[1:]:
if extension in types_map: if extension in types_map:
return types_map[extension] return types_map[extension]

View file

@ -1,5 +1,6 @@
import base64 import base64
from typing import List, Union, Optional, NamedTuple import struct
from typing import List
from binascii import hexlify from binascii import hexlify
from itertools import chain from itertools import chain
@ -15,70 +16,47 @@ BLOCKED = ErrorMessage.Code.Name(ErrorMessage.BLOCKED)
def set_reference(reference, claim_hash, rows): def set_reference(reference, claim_hash, rows):
if claim_hash: if claim_hash:
for txo in rows: for txo in rows:
if claim_hash == txo.claim_hash: if claim_hash == txo['claim_hash']:
reference.tx_hash = txo.tx_hash reference.tx_hash = txo['txo_hash'][:32]
reference.nout = txo.position reference.nout = struct.unpack('<I', txo['txo_hash'][32:])[0]
reference.height = txo.height reference.height = txo['height']
return return
class ResolveResult(NamedTuple):
name: str
normalized_name: str
claim_hash: bytes
tx_num: int
position: int
tx_hash: bytes
height: int
amount: int
short_url: str
is_controlling: bool
canonical_url: str
creation_height: int
activation_height: int
expiration_height: int
effective_amount: int
support_amount: int
reposted: int
last_takeover_height: Optional[int]
claims_in_channel: Optional[int]
channel_hash: Optional[bytes]
reposted_claim_hash: Optional[bytes]
signature_valid: Optional[bool]
class Censor: class Censor:
NOT_CENSORED = 0 __slots__ = 'streams', 'channels', 'censored', 'total'
SEARCH = 1
RESOLVE = 2
__slots__ = 'censor_type', 'censored' def __init__(self, streams: dict = None, channels: dict = None):
self.streams = streams or {}
def __init__(self, censor_type): self.channels = channels or {}
self.censor_type = censor_type
self.censored = {} self.censored = {}
self.total = 0
def is_censored(self, row): def censor(self, row) -> bool:
return (row.get('censor_type') or self.NOT_CENSORED) >= self.censor_type was_censored = False
for claim_hash, lookup in (
(row['claim_hash'], self.streams),
(row['claim_hash'], self.channels),
(row['channel_hash'], self.channels),
(row['reposted_claim_hash'], self.streams),
(row['reposted_claim_hash'], self.channels)):
censoring_channel_hash = lookup.get(claim_hash)
if censoring_channel_hash:
was_censored = True
self.censored.setdefault(censoring_channel_hash, 0)
self.censored[censoring_channel_hash] += 1
break
if was_censored:
self.total += 1
return was_censored
def apply(self, rows): def to_message(self, outputs: OutputsMessage, extra_txo_rows):
return [row for row in rows if not self.censor(row)] outputs.blocked_total = self.total
def censor(self, row) -> Optional[bytes]:
if self.is_censored(row):
censoring_channel_hash = bytes.fromhex(row['censoring_channel_id'])[::-1]
self.censored.setdefault(censoring_channel_hash, set())
self.censored[censoring_channel_hash].add(row['tx_hash'])
return censoring_channel_hash
return None
def to_message(self, outputs: OutputsMessage, extra_txo_rows: dict):
for censoring_channel_hash, count in self.censored.items(): for censoring_channel_hash, count in self.censored.items():
blocked = outputs.blocked.add() blocked = outputs.blocked.add()
blocked.count = len(count) blocked.count = count
set_reference(blocked.channel, censoring_channel_hash, extra_txo_rows) set_reference(blocked.channel, censoring_channel_hash, extra_txo_rows)
outputs.blocked_total += len(count)
class Outputs: class Outputs:
@ -142,10 +120,10 @@ class Outputs:
'expiration_height': claim.expiration_height, 'expiration_height': claim.expiration_height,
'effective_amount': claim.effective_amount, 'effective_amount': claim.effective_amount,
'support_amount': claim.support_amount, 'support_amount': claim.support_amount,
# 'trending_group': claim.trending_group, 'trending_group': claim.trending_group,
# 'trending_mixed': claim.trending_mixed, 'trending_mixed': claim.trending_mixed,
# 'trending_local': claim.trending_local, 'trending_local': claim.trending_local,
# 'trending_global': claim.trending_global, 'trending_global': claim.trending_global,
} }
if claim.HasField('channel'): if claim.HasField('channel'):
txo.channel = tx_map[claim.channel.tx_hash].outputs[claim.channel.nout] txo.channel = tx_map[claim.channel.tx_hash].outputs[claim.channel.nout]
@ -189,54 +167,44 @@ class Outputs:
page.total = total page.total = total
if blocked is not None: if blocked is not None:
blocked.to_message(page, extra_txo_rows) blocked.to_message(page, extra_txo_rows)
for row in extra_txo_rows:
txo_message: 'OutputsMessage' = page.extra_txos.add()
if not isinstance(row, Exception):
if row.channel_hash:
set_reference(txo_message.claim.channel, row.channel_hash, extra_txo_rows)
if row.reposted_claim_hash:
set_reference(txo_message.claim.repost, row.reposted_claim_hash, extra_txo_rows)
cls.encode_txo(txo_message, row)
for row in txo_rows: for row in txo_rows:
# cls.row_to_message(row, page.txos.add(), extra_txo_rows) cls.row_to_message(row, page.txos.add(), extra_txo_rows)
txo_message: 'OutputsMessage' = page.txos.add() for row in extra_txo_rows:
cls.encode_txo(txo_message, row) cls.row_to_message(row, page.extra_txos.add(), extra_txo_rows)
if not isinstance(row, Exception):
if row.channel_hash:
set_reference(txo_message.claim.channel, row.channel_hash, extra_txo_rows)
if row.reposted_claim_hash:
set_reference(txo_message.claim.repost, row.reposted_claim_hash, extra_txo_rows)
elif isinstance(row, ResolveCensoredError):
set_reference(txo_message.error.blocked.channel, row.censor_id, extra_txo_rows)
return page.SerializeToString() return page.SerializeToString()
@classmethod @classmethod
def encode_txo(cls, txo_message, resolve_result: Union['ResolveResult', Exception]): def row_to_message(cls, txo, txo_message, extra_txo_rows):
if isinstance(resolve_result, Exception): if isinstance(txo, Exception):
txo_message.error.text = resolve_result.args[0] txo_message.error.text = txo.args[0]
if isinstance(resolve_result, ValueError): if isinstance(txo, ValueError):
txo_message.error.code = ErrorMessage.INVALID txo_message.error.code = ErrorMessage.INVALID
elif isinstance(resolve_result, LookupError): elif isinstance(txo, LookupError):
txo_message.error.code = ErrorMessage.NOT_FOUND txo_message.error.code = ErrorMessage.NOT_FOUND
elif isinstance(resolve_result, ResolveCensoredError): elif isinstance(txo, ResolveCensoredError):
txo_message.error.code = ErrorMessage.BLOCKED txo_message.error.code = ErrorMessage.BLOCKED
set_reference(txo_message.error.blocked.channel, txo.censor_hash, extra_txo_rows)
return return
txo_message.tx_hash = resolve_result.tx_hash txo_message.tx_hash = txo['txo_hash'][:32]
txo_message.nout = resolve_result.position txo_message.nout, = struct.unpack('<I', txo['txo_hash'][32:])
txo_message.height = resolve_result.height txo_message.height = txo['height']
txo_message.claim.short_url = resolve_result.short_url txo_message.claim.short_url = txo['short_url']
txo_message.claim.reposted = resolve_result.reposted txo_message.claim.reposted = txo['reposted']
txo_message.claim.is_controlling = resolve_result.is_controlling if txo['canonical_url'] is not None:
txo_message.claim.creation_height = resolve_result.creation_height txo_message.claim.canonical_url = txo['canonical_url']
txo_message.claim.activation_height = resolve_result.activation_height txo_message.claim.is_controlling = bool(txo['is_controlling'])
txo_message.claim.expiration_height = resolve_result.expiration_height if txo['last_take_over_height'] is not None:
txo_message.claim.effective_amount = resolve_result.effective_amount txo_message.claim.take_over_height = txo['last_take_over_height']
txo_message.claim.support_amount = resolve_result.support_amount txo_message.claim.creation_height = txo['creation_height']
txo_message.claim.activation_height = txo['activation_height']
if resolve_result.canonical_url is not None: txo_message.claim.expiration_height = txo['expiration_height']
txo_message.claim.canonical_url = resolve_result.canonical_url if txo['claims_in_channel'] is not None:
if resolve_result.last_takeover_height is not None: txo_message.claim.claims_in_channel = txo['claims_in_channel']
txo_message.claim.take_over_height = resolve_result.last_takeover_height txo_message.claim.effective_amount = txo['effective_amount']
if resolve_result.claims_in_channel is not None: txo_message.claim.support_amount = txo['support_amount']
txo_message.claim.claims_in_channel = resolve_result.claims_in_channel txo_message.claim.trending_group = txo['trending_group']
txo_message.claim.trending_mixed = txo['trending_mixed']
txo_message.claim.trending_local = txo['trending_local']
txo_message.claim.trending_global = txo['trending_global']
set_reference(txo_message.claim.channel, txo['channel_hash'], extra_txo_rows)
set_reference(txo_message.claim.repost, txo['reposted_claim_hash'], extra_txo_rows)

View file

@ -1,23 +1,6 @@
from lbry.schema.base import Signable from lbry.schema.base import Signable
from lbry.schema.types.v2.support_pb2 import Support as SupportMessage
class Support(Signable): class Support(Signable):
__slots__ = () __slots__ = ()
message_class = SupportMessage message_class = None # TODO: add support protobufs
@property
def emoji(self) -> str:
return self.message.emoji
@emoji.setter
def emoji(self, emoji: str):
self.message.emoji = emoji
@property
def comment(self) -> str:
return self.message.comment
@comment.setter
def comment(self, comment: str):
self.message.comment = comment

View file

@ -1,11 +1,13 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT! # Generated by the protocol buffer compiler. DO NOT EDIT!
# source: result.proto # source: result.proto
"""Generated protocol buffer code."""
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message from google.protobuf import message as _message
from google.protobuf import reflection as _reflection from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()
@ -17,10 +19,9 @@ DESCRIPTOR = _descriptor.FileDescriptor(
name='result.proto', name='result.proto',
package='pb', package='pb',
syntax='proto3', syntax='proto3',
serialized_options=b'Z$github.com/lbryio/hub/protobuf/go/pb', serialized_pb=_b('\n\x0cresult.proto\x12\x02pb\"\x97\x01\n\x07Outputs\x12\x18\n\x04txos\x18\x01 \x03(\x0b\x32\n.pb.Output\x12\x1e\n\nextra_txos\x18\x02 \x03(\x0b\x32\n.pb.Output\x12\r\n\x05total\x18\x03 \x01(\r\x12\x0e\n\x06offset\x18\x04 \x01(\r\x12\x1c\n\x07\x62locked\x18\x05 \x03(\x0b\x32\x0b.pb.Blocked\x12\x15\n\rblocked_total\x18\x06 \x01(\r\"{\n\x06Output\x12\x0f\n\x07tx_hash\x18\x01 \x01(\x0c\x12\x0c\n\x04nout\x18\x02 \x01(\r\x12\x0e\n\x06height\x18\x03 \x01(\r\x12\x1e\n\x05\x63laim\x18\x07 \x01(\x0b\x32\r.pb.ClaimMetaH\x00\x12\x1a\n\x05\x65rror\x18\x0f \x01(\x0b\x32\t.pb.ErrorH\x00\x42\x06\n\x04meta\"\xaf\x03\n\tClaimMeta\x12\x1b\n\x07\x63hannel\x18\x01 \x01(\x0b\x32\n.pb.Output\x12\x1a\n\x06repost\x18\x02 \x01(\x0b\x32\n.pb.Output\x12\x11\n\tshort_url\x18\x03 \x01(\t\x12\x15\n\rcanonical_url\x18\x04 \x01(\t\x12\x16\n\x0eis_controlling\x18\x05 \x01(\x08\x12\x18\n\x10take_over_height\x18\x06 \x01(\r\x12\x17\n\x0f\x63reation_height\x18\x07 \x01(\r\x12\x19\n\x11\x61\x63tivation_height\x18\x08 \x01(\r\x12\x19\n\x11\x65xpiration_height\x18\t \x01(\r\x12\x19\n\x11\x63laims_in_channel\x18\n \x01(\r\x12\x10\n\x08reposted\x18\x0b \x01(\r\x12\x18\n\x10\x65\x66\x66\x65\x63tive_amount\x18\x14 \x01(\x04\x12\x16\n\x0esupport_amount\x18\x15 \x01(\x04\x12\x16\n\x0etrending_group\x18\x16 \x01(\r\x12\x16\n\x0etrending_mixed\x18\x17 \x01(\x02\x12\x16\n\x0etrending_local\x18\x18 \x01(\x02\x12\x17\n\x0ftrending_global\x18\x19 \x01(\x02\"\x94\x01\n\x05\x45rror\x12\x1c\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x0e.pb.Error.Code\x12\x0c\n\x04text\x18\x02 \x01(\t\x12\x1c\n\x07\x62locked\x18\x03 \x01(\x0b\x32\x0b.pb.Blocked\"A\n\x04\x43ode\x12\x10\n\x0cUNKNOWN_CODE\x10\x00\x12\r\n\tNOT_FOUND\x10\x01\x12\x0b\n\x07INVALID\x10\x02\x12\x0b\n\x07\x42LOCKED\x10\x03\"5\n\x07\x42locked\x12\r\n\x05\x63ount\x18\x01 \x01(\r\x12\x1b\n\x07\x63hannel\x18\x02 \x01(\x0b\x32\n.pb.Outputb\x06proto3')
create_key=_descriptor._internal_create_key,
serialized_pb=b'\n\x0cresult.proto\x12\x02pb\"\x97\x01\n\x07Outputs\x12\x18\n\x04txos\x18\x01 \x03(\x0b\x32\n.pb.Output\x12\x1e\n\nextra_txos\x18\x02 \x03(\x0b\x32\n.pb.Output\x12\r\n\x05total\x18\x03 \x01(\r\x12\x0e\n\x06offset\x18\x04 \x01(\r\x12\x1c\n\x07\x62locked\x18\x05 \x03(\x0b\x32\x0b.pb.Blocked\x12\x15\n\rblocked_total\x18\x06 \x01(\r\"{\n\x06Output\x12\x0f\n\x07tx_hash\x18\x01 \x01(\x0c\x12\x0c\n\x04nout\x18\x02 \x01(\r\x12\x0e\n\x06height\x18\x03 \x01(\r\x12\x1e\n\x05\x63laim\x18\x07 \x01(\x0b\x32\r.pb.ClaimMetaH\x00\x12\x1a\n\x05\x65rror\x18\x0f \x01(\x0b\x32\t.pb.ErrorH\x00\x42\x06\n\x04meta\"\xe6\x02\n\tClaimMeta\x12\x1b\n\x07\x63hannel\x18\x01 \x01(\x0b\x32\n.pb.Output\x12\x1a\n\x06repost\x18\x02 \x01(\x0b\x32\n.pb.Output\x12\x11\n\tshort_url\x18\x03 \x01(\t\x12\x15\n\rcanonical_url\x18\x04 \x01(\t\x12\x16\n\x0eis_controlling\x18\x05 \x01(\x08\x12\x18\n\x10take_over_height\x18\x06 \x01(\r\x12\x17\n\x0f\x63reation_height\x18\x07 \x01(\r\x12\x19\n\x11\x61\x63tivation_height\x18\x08 \x01(\r\x12\x19\n\x11\x65xpiration_height\x18\t \x01(\r\x12\x19\n\x11\x63laims_in_channel\x18\n \x01(\r\x12\x10\n\x08reposted\x18\x0b \x01(\r\x12\x18\n\x10\x65\x66\x66\x65\x63tive_amount\x18\x14 \x01(\x04\x12\x16\n\x0esupport_amount\x18\x15 \x01(\x04\x12\x16\n\x0etrending_score\x18\x16 \x01(\x01\"\x94\x01\n\x05\x45rror\x12\x1c\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x0e.pb.Error.Code\x12\x0c\n\x04text\x18\x02 \x01(\t\x12\x1c\n\x07\x62locked\x18\x03 \x01(\x0b\x32\x0b.pb.Blocked\"A\n\x04\x43ode\x12\x10\n\x0cUNKNOWN_CODE\x10\x00\x12\r\n\tNOT_FOUND\x10\x01\x12\x0b\n\x07INVALID\x10\x02\x12\x0b\n\x07\x42LOCKED\x10\x03\"5\n\x07\x42locked\x12\r\n\x05\x63ount\x18\x01 \x01(\r\x12\x1b\n\x07\x63hannel\x18\x02 \x01(\x0b\x32\n.pb.OutputB&Z$github.com/lbryio/hub/protobuf/go/pbb\x06proto3'
) )
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
@ -29,33 +30,28 @@ _ERROR_CODE = _descriptor.EnumDescriptor(
full_name='pb.Error.Code', full_name='pb.Error.Code',
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
create_key=_descriptor._internal_create_key,
values=[ values=[
_descriptor.EnumValueDescriptor( _descriptor.EnumValueDescriptor(
name='UNKNOWN_CODE', index=0, number=0, name='UNKNOWN_CODE', index=0, number=0,
serialized_options=None, options=None,
type=None, type=None),
create_key=_descriptor._internal_create_key),
_descriptor.EnumValueDescriptor( _descriptor.EnumValueDescriptor(
name='NOT_FOUND', index=1, number=1, name='NOT_FOUND', index=1, number=1,
serialized_options=None, options=None,
type=None, type=None),
create_key=_descriptor._internal_create_key),
_descriptor.EnumValueDescriptor( _descriptor.EnumValueDescriptor(
name='INVALID', index=2, number=2, name='INVALID', index=2, number=2,
serialized_options=None, options=None,
type=None, type=None),
create_key=_descriptor._internal_create_key),
_descriptor.EnumValueDescriptor( _descriptor.EnumValueDescriptor(
name='BLOCKED', index=3, number=3, name='BLOCKED', index=3, number=3,
serialized_options=None, options=None,
type=None, type=None),
create_key=_descriptor._internal_create_key),
], ],
containing_type=None, containing_type=None,
serialized_options=None, options=None,
serialized_start=744, serialized_start=817,
serialized_end=809, serialized_end=882,
) )
_sym_db.RegisterEnumDescriptor(_ERROR_CODE) _sym_db.RegisterEnumDescriptor(_ERROR_CODE)
@ -66,7 +62,6 @@ _OUTPUTS = _descriptor.Descriptor(
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='txos', full_name='pb.Outputs.txos', index=0, name='txos', full_name='pb.Outputs.txos', index=0,
@ -74,49 +69,49 @@ _OUTPUTS = _descriptor.Descriptor(
has_default_value=False, default_value=[], has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='extra_txos', full_name='pb.Outputs.extra_txos', index=1, name='extra_txos', full_name='pb.Outputs.extra_txos', index=1,
number=2, type=11, cpp_type=10, label=3, number=2, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[], has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='total', full_name='pb.Outputs.total', index=2, name='total', full_name='pb.Outputs.total', index=2,
number=3, type=13, cpp_type=3, label=1, number=3, type=13, cpp_type=3, label=1,
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='offset', full_name='pb.Outputs.offset', index=3, name='offset', full_name='pb.Outputs.offset', index=3,
number=4, type=13, cpp_type=3, label=1, number=4, type=13, cpp_type=3, label=1,
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='blocked', full_name='pb.Outputs.blocked', index=4, name='blocked', full_name='pb.Outputs.blocked', index=4,
number=5, type=11, cpp_type=10, label=3, number=5, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[], has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='blocked_total', full_name='pb.Outputs.blocked_total', index=5, name='blocked_total', full_name='pb.Outputs.blocked_total', index=5,
number=6, type=13, cpp_type=3, label=1, number=6, type=13, cpp_type=3, label=1,
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
], ],
extensions=[ extensions=[
], ],
nested_types=[], nested_types=[],
enum_types=[ enum_types=[
], ],
serialized_options=None, options=None,
is_extendable=False, is_extendable=False,
syntax='proto3', syntax='proto3',
extension_ranges=[], extension_ranges=[],
@ -133,59 +128,56 @@ _OUTPUT = _descriptor.Descriptor(
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='tx_hash', full_name='pb.Output.tx_hash', index=0, name='tx_hash', full_name='pb.Output.tx_hash', index=0,
number=1, type=12, cpp_type=9, label=1, number=1, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"", has_default_value=False, default_value=_b(""),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='nout', full_name='pb.Output.nout', index=1, name='nout', full_name='pb.Output.nout', index=1,
number=2, type=13, cpp_type=3, label=1, number=2, type=13, cpp_type=3, label=1,
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='height', full_name='pb.Output.height', index=2, name='height', full_name='pb.Output.height', index=2,
number=3, type=13, cpp_type=3, label=1, number=3, type=13, cpp_type=3, label=1,
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='claim', full_name='pb.Output.claim', index=3, name='claim', full_name='pb.Output.claim', index=3,
number=7, type=11, cpp_type=10, label=1, number=7, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None, has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='error', full_name='pb.Output.error', index=4, name='error', full_name='pb.Output.error', index=4,
number=15, type=11, cpp_type=10, label=1, number=15, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None, has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
], ],
extensions=[ extensions=[
], ],
nested_types=[], nested_types=[],
enum_types=[ enum_types=[
], ],
serialized_options=None, options=None,
is_extendable=False, is_extendable=False,
syntax='proto3', syntax='proto3',
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
_descriptor.OneofDescriptor( _descriptor.OneofDescriptor(
name='meta', full_name='pb.Output.meta', name='meta', full_name='pb.Output.meta',
index=0, containing_type=None, index=0, containing_type=None, fields=[]),
create_key=_descriptor._internal_create_key,
fields=[]),
], ],
serialized_start=174, serialized_start=174,
serialized_end=297, serialized_end=297,
@ -198,7 +190,6 @@ _CLAIMMETA = _descriptor.Descriptor(
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='channel', full_name='pb.ClaimMeta.channel', index=0, name='channel', full_name='pb.ClaimMeta.channel', index=0,
@ -206,112 +197,133 @@ _CLAIMMETA = _descriptor.Descriptor(
has_default_value=False, default_value=None, has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='repost', full_name='pb.ClaimMeta.repost', index=1, name='repost', full_name='pb.ClaimMeta.repost', index=1,
number=2, type=11, cpp_type=10, label=1, number=2, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None, has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='short_url', full_name='pb.ClaimMeta.short_url', index=2, name='short_url', full_name='pb.ClaimMeta.short_url', index=2,
number=3, type=9, cpp_type=9, label=1, number=3, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'), has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='canonical_url', full_name='pb.ClaimMeta.canonical_url', index=3, name='canonical_url', full_name='pb.ClaimMeta.canonical_url', index=3,
number=4, type=9, cpp_type=9, label=1, number=4, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'), has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='is_controlling', full_name='pb.ClaimMeta.is_controlling', index=4, name='is_controlling', full_name='pb.ClaimMeta.is_controlling', index=4,
number=5, type=8, cpp_type=7, label=1, number=5, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False, has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='take_over_height', full_name='pb.ClaimMeta.take_over_height', index=5, name='take_over_height', full_name='pb.ClaimMeta.take_over_height', index=5,
number=6, type=13, cpp_type=3, label=1, number=6, type=13, cpp_type=3, label=1,
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='creation_height', full_name='pb.ClaimMeta.creation_height', index=6, name='creation_height', full_name='pb.ClaimMeta.creation_height', index=6,
number=7, type=13, cpp_type=3, label=1, number=7, type=13, cpp_type=3, label=1,
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='activation_height', full_name='pb.ClaimMeta.activation_height', index=7, name='activation_height', full_name='pb.ClaimMeta.activation_height', index=7,
number=8, type=13, cpp_type=3, label=1, number=8, type=13, cpp_type=3, label=1,
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='expiration_height', full_name='pb.ClaimMeta.expiration_height', index=8, name='expiration_height', full_name='pb.ClaimMeta.expiration_height', index=8,
number=9, type=13, cpp_type=3, label=1, number=9, type=13, cpp_type=3, label=1,
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='claims_in_channel', full_name='pb.ClaimMeta.claims_in_channel', index=9, name='claims_in_channel', full_name='pb.ClaimMeta.claims_in_channel', index=9,
number=10, type=13, cpp_type=3, label=1, number=10, type=13, cpp_type=3, label=1,
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='reposted', full_name='pb.ClaimMeta.reposted', index=10, name='reposted', full_name='pb.ClaimMeta.reposted', index=10,
number=11, type=13, cpp_type=3, label=1, number=11, type=13, cpp_type=3, label=1,
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='effective_amount', full_name='pb.ClaimMeta.effective_amount', index=11, name='effective_amount', full_name='pb.ClaimMeta.effective_amount', index=11,
number=20, type=4, cpp_type=4, label=1, number=20, type=4, cpp_type=4, label=1,
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='support_amount', full_name='pb.ClaimMeta.support_amount', index=12, name='support_amount', full_name='pb.ClaimMeta.support_amount', index=12,
number=21, type=4, cpp_type=4, label=1, number=21, type=4, cpp_type=4, label=1,
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='trending_score', full_name='pb.ClaimMeta.trending_score', index=13, name='trending_group', full_name='pb.ClaimMeta.trending_group', index=13,
number=22, type=1, cpp_type=5, label=1, number=22, type=13, cpp_type=3, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='trending_mixed', full_name='pb.ClaimMeta.trending_mixed', index=14,
number=23, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=float(0), has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor(
name='trending_local', full_name='pb.ClaimMeta.trending_local', index=15,
number=24, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='trending_global', full_name='pb.ClaimMeta.trending_global', index=16,
number=25, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
], ],
extensions=[ extensions=[
], ],
nested_types=[], nested_types=[],
enum_types=[ enum_types=[
], ],
serialized_options=None, options=None,
is_extendable=False, is_extendable=False,
syntax='proto3', syntax='proto3',
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=300, serialized_start=300,
serialized_end=658, serialized_end=731,
) )
@ -321,7 +333,6 @@ _ERROR = _descriptor.Descriptor(
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='code', full_name='pb.Error.code', index=0, name='code', full_name='pb.Error.code', index=0,
@ -329,21 +340,21 @@ _ERROR = _descriptor.Descriptor(
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='text', full_name='pb.Error.text', index=1, name='text', full_name='pb.Error.text', index=1,
number=2, type=9, cpp_type=9, label=1, number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'), has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='blocked', full_name='pb.Error.blocked', index=2, name='blocked', full_name='pb.Error.blocked', index=2,
number=3, type=11, cpp_type=10, label=1, number=3, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None, has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
], ],
extensions=[ extensions=[
], ],
@ -351,14 +362,14 @@ _ERROR = _descriptor.Descriptor(
enum_types=[ enum_types=[
_ERROR_CODE, _ERROR_CODE,
], ],
serialized_options=None, options=None,
is_extendable=False, is_extendable=False,
syntax='proto3', syntax='proto3',
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=661, serialized_start=734,
serialized_end=809, serialized_end=882,
) )
@ -368,7 +379,6 @@ _BLOCKED = _descriptor.Descriptor(
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='count', full_name='pb.Blocked.count', index=0, name='count', full_name='pb.Blocked.count', index=0,
@ -376,28 +386,28 @@ _BLOCKED = _descriptor.Descriptor(
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='channel', full_name='pb.Blocked.channel', index=1, name='channel', full_name='pb.Blocked.channel', index=1,
number=2, type=11, cpp_type=10, label=1, number=2, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None, has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), options=None),
], ],
extensions=[ extensions=[
], ],
nested_types=[], nested_types=[],
enum_types=[ enum_types=[
], ],
serialized_options=None, options=None,
is_extendable=False, is_extendable=False,
syntax='proto3', syntax='proto3',
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=811, serialized_start=884,
serialized_end=864, serialized_end=937,
) )
_OUTPUTS.fields_by_name['txos'].message_type = _OUTPUT _OUTPUTS.fields_by_name['txos'].message_type = _OUTPUT
@ -422,43 +432,41 @@ DESCRIPTOR.message_types_by_name['Output'] = _OUTPUT
DESCRIPTOR.message_types_by_name['ClaimMeta'] = _CLAIMMETA DESCRIPTOR.message_types_by_name['ClaimMeta'] = _CLAIMMETA
DESCRIPTOR.message_types_by_name['Error'] = _ERROR DESCRIPTOR.message_types_by_name['Error'] = _ERROR
DESCRIPTOR.message_types_by_name['Blocked'] = _BLOCKED DESCRIPTOR.message_types_by_name['Blocked'] = _BLOCKED
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
Outputs = _reflection.GeneratedProtocolMessageType('Outputs', (_message.Message,), { Outputs = _reflection.GeneratedProtocolMessageType('Outputs', (_message.Message,), dict(
'DESCRIPTOR' : _OUTPUTS, DESCRIPTOR = _OUTPUTS,
'__module__' : 'result_pb2' __module__ = 'result_pb2'
# @@protoc_insertion_point(class_scope:pb.Outputs) # @@protoc_insertion_point(class_scope:pb.Outputs)
}) ))
_sym_db.RegisterMessage(Outputs) _sym_db.RegisterMessage(Outputs)
Output = _reflection.GeneratedProtocolMessageType('Output', (_message.Message,), { Output = _reflection.GeneratedProtocolMessageType('Output', (_message.Message,), dict(
'DESCRIPTOR' : _OUTPUT, DESCRIPTOR = _OUTPUT,
'__module__' : 'result_pb2' __module__ = 'result_pb2'
# @@protoc_insertion_point(class_scope:pb.Output) # @@protoc_insertion_point(class_scope:pb.Output)
}) ))
_sym_db.RegisterMessage(Output) _sym_db.RegisterMessage(Output)
ClaimMeta = _reflection.GeneratedProtocolMessageType('ClaimMeta', (_message.Message,), { ClaimMeta = _reflection.GeneratedProtocolMessageType('ClaimMeta', (_message.Message,), dict(
'DESCRIPTOR' : _CLAIMMETA, DESCRIPTOR = _CLAIMMETA,
'__module__' : 'result_pb2' __module__ = 'result_pb2'
# @@protoc_insertion_point(class_scope:pb.ClaimMeta) # @@protoc_insertion_point(class_scope:pb.ClaimMeta)
}) ))
_sym_db.RegisterMessage(ClaimMeta) _sym_db.RegisterMessage(ClaimMeta)
Error = _reflection.GeneratedProtocolMessageType('Error', (_message.Message,), { Error = _reflection.GeneratedProtocolMessageType('Error', (_message.Message,), dict(
'DESCRIPTOR' : _ERROR, DESCRIPTOR = _ERROR,
'__module__' : 'result_pb2' __module__ = 'result_pb2'
# @@protoc_insertion_point(class_scope:pb.Error) # @@protoc_insertion_point(class_scope:pb.Error)
}) ))
_sym_db.RegisterMessage(Error) _sym_db.RegisterMessage(Error)
Blocked = _reflection.GeneratedProtocolMessageType('Blocked', (_message.Message,), { Blocked = _reflection.GeneratedProtocolMessageType('Blocked', (_message.Message,), dict(
'DESCRIPTOR' : _BLOCKED, DESCRIPTOR = _BLOCKED,
'__module__' : 'result_pb2' __module__ = 'result_pb2'
# @@protoc_insertion_point(class_scope:pb.Blocked) # @@protoc_insertion_point(class_scope:pb.Blocked)
}) ))
_sym_db.RegisterMessage(Blocked) _sym_db.RegisterMessage(Blocked)
DESCRIPTOR._options = None
# @@protoc_insertion_point(module_scope) # @@protoc_insertion_point(module_scope)

View file

@ -1,76 +0,0 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: support.proto
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='support.proto',
package='pb',
syntax='proto3',
serialized_pb=_b('\n\rsupport.proto\x12\x02pb\")\n\x07Support\x12\r\n\x05\x65moji\x18\x01 \x01(\t\x12\x0f\n\x07\x63omment\x18\x02 \x01(\tb\x06proto3')
)
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
_SUPPORT = _descriptor.Descriptor(
name='Support',
full_name='pb.Support',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='emoji', full_name='pb.Support.emoji', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='comment', full_name='pb.Support.comment', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=21,
serialized_end=62,
)
DESCRIPTOR.message_types_by_name['Support'] = _SUPPORT
Support = _reflection.GeneratedProtocolMessageType('Support', (_message.Message,), dict(
DESCRIPTOR = _SUPPORT,
__module__ = 'support_pb2'
# @@protoc_insertion_point(class_scope:pb.Support)
))
_sym_db.RegisterMessage(Support)
# @@protoc_insertion_point(module_scope)

View file

@ -1,139 +0,0 @@
{
"title": "Wallet",
"description": "An LBC wallet",
"type": "object",
"required": ["name", "version", "accounts", "preferences"],
"additionalProperties": false,
"properties": {
"name": {
"description": "Human readable name for this wallet",
"type": "string"
},
"version": {
"description": "Wallet spec version",
"type": "integer",
"$comment": "Should this be a string? We may need some sort of decimal type if we want exact decimal versions."
},
"accounts": {
"description": "Accounts associated with this wallet",
"type": "array",
"items": {
"type": "object",
"required": ["address_generator", "certificates", "encrypted", "ledger", "modified_on", "name", "private_key", "public_key", "seed"],
"additionalProperties": false,
"properties": {
"address_generator": {
"description": "Higher level manager of either singular or deterministically generated addresses",
"type": "object",
"oneOf": [
{
"required": ["name", "change", "receiving"],
"additionalProperties": false,
"properties": {
"name": {
"description": "type of address generator: a deterministic chain of addresses",
"enum": ["deterministic-chain"],
"type": "string"
},
"change": {
"$ref": "#/$defs/address_manager",
"description": "Manager for deterministically generated change address (not used for single address)"
},
"receiving": {
"$ref": "#/$defs/address_manager",
"description": "Manager for deterministically generated receiving address (not used for single address)"
}
}
}, {
"required": ["name"],
"additionalProperties": false,
"properties": {
"name": {
"description": "type of address generator: a single address",
"enum": ["single-address"],
"type": "string"
}
}
}
]
},
"certificates": {
"type": "object",
"description": "Channel keys. Mapping from public key address to pem-formatted private key.",
"additionalProperties": {"type": "string"}
},
"encrypted": {
"type": "boolean",
"description": "Whether private key and seed are encrypted with a password"
},
"ledger": {
"description": "Which network to use",
"type": "string",
"examples": [
"lbc_mainnet",
"lbc_testnet"
]
},
"modified_on": {
"description": "last modified time in Unix Time",
"type": "integer"
},
"name": {
"description": "Name for account, possibly human readable",
"type": "string"
},
"private_key": {
"description": "Private key for address if `address_generator` is a single address. Root of chain of private keys for addresses if `address_generator` is a deterministic chain of addresses. Encrypted if `encrypted` is true.",
"type": "string"
},
"public_key": {
"description": "Public key for address if `address_generator` is a single address. Root of chain of public keys for addresses if `address_generator` is a deterministic chain of addresses.",
"type": "string"
},
"seed": {
"description": "Human readable representation of `private_key`. encrypted if `encrypted` is set to `true`",
"type": "string"
}
}
}
},
"preferences": {
"description": "Timestamped application-level preferences. Values can be objects or of a primitive type.",
"$comment": "enable-sync is seen in example wallet. encrypt-on-disk is seen in example wallet. they both have a boolean `value` field. Do we want them explicitly defined here? local and shared seem to have at least a similar structure (type, value [yes, again], version), value being the free-form part. Should we define those here? Or can there be any key under preferences, and `value` be literally be anything in any form?",
"type": "object",
"additionalProperties": {
"type": "object",
"required": ["ts", "value"],
"additionalProperties": false,
"properties": {
"ts": {
"type": "number",
"description": "When the item was set, in Unix time format.",
"$comment": "Do we want a string (decimal)?"
},
"value": {
"$comment": "Sometimes this has been an object, sometimes just a boolean. I don't want to prescribe anything."
}
}
}
}
},
"$defs": {
"address_manager": {
"description": "Manager for deterministically generated addresses",
"type": "object",
"required": ["gap", "maximum_uses_per_address"],
"additionalProperties": false,
"properties": {
"gap": {
"description": "Maximum allowed consecutive generated addresses with no transactions",
"type": "integer"
},
"maximum_uses_per_address": {
"description": "Maximum number of uses for each generated address",
"type": "integer"
}
}
}
}
}

View file

@ -22,7 +22,8 @@ def _create_url_regex():
return _group( return _group(
_named(name+"_name", prefix + invalid_names_regex) + _named(name+"_name", prefix + invalid_names_regex) +
_oneof( _oneof(
_group('[:#]' + _named(name+"_claim_id", "[0-9a-f]{1,40}")), _group('#' + _named(name+"_claim_id", "[0-9a-f]{1,40}")),
_group(':' + _named(name+"_sequence", '[1-9][0-9]*')),
_group(r'\$' + _named(name+"_amount_order", '[1-9][0-9]*')) _group(r'\$' + _named(name+"_amount_order", '[1-9][0-9]*'))
) + '?' ) + '?'
) )
@ -49,31 +50,28 @@ def normalize_name(name):
class PathSegment(NamedTuple): class PathSegment(NamedTuple):
name: str name: str
claim_id: str = None claim_id: str = None
sequence: int = None
amount_order: int = None amount_order: int = None
@property @property
def normalized(self): def normalized(self):
return normalize_name(self.name) return normalize_name(self.name)
@property
def is_shortid(self):
return self.claim_id is not None and len(self.claim_id) < 40
@property
def is_fullid(self):
return self.claim_id is not None and len(self.claim_id) == 40
def to_dict(self): def to_dict(self):
q = {'name': self.name} q = {'name': self.name}
if self.claim_id is not None: if self.claim_id is not None:
q['claim_id'] = self.claim_id q['claim_id'] = self.claim_id
if self.sequence is not None:
q['sequence'] = self.sequence
if self.amount_order is not None: if self.amount_order is not None:
q['amount_order'] = self.amount_order q['amount_order'] = self.amount_order
return q return q
def __str__(self): def __str__(self):
if self.claim_id is not None: if self.claim_id is not None:
return f"{self.name}:{self.claim_id}" return f"{self.name}#{self.claim_id}"
elif self.sequence is not None:
return f"{self.name}:{self.sequence}"
elif self.amount_order is not None: elif self.amount_order is not None:
return f"{self.name}${self.amount_order}" return f"{self.name}${self.amount_order}"
return self.name return self.name
@ -120,6 +118,7 @@ class URL(NamedTuple):
segments[segment] = PathSegment( segments[segment] = PathSegment(
parts[f'{segment}_name'], parts[f'{segment}_name'],
parts[f'{segment}_claim_id'], parts[f'{segment}_claim_id'],
parts[f'{segment}_sequence'],
parts[f'{segment}_amount_order'] parts[f'{segment}_amount_order']
) )

View file

@ -1,31 +0,0 @@
import asyncio
import logging
from lbry.stream.downloader import StreamDownloader
log = logging.getLogger(__name__)
class BackgroundDownloader:
def __init__(self, conf, storage, blob_manager, dht_node=None):
self.storage = storage
self.blob_manager = blob_manager
self.node = dht_node
self.conf = conf
async def download_blobs(self, sd_hash):
downloader = StreamDownloader(asyncio.get_running_loop(), self.conf, self.blob_manager, sd_hash)
try:
await downloader.start(self.node, save_stream=False)
for blob_info in downloader.descriptor.blobs[:-1]:
await downloader.download_stream_blob(blob_info)
except ValueError:
return
except asyncio.CancelledError:
log.debug("Cancelled background downloader")
raise
except Exception:
log.error("Unexpected download error on background downloader")
finally:
downloader.stop()

View file

@ -4,7 +4,6 @@ import binascii
import logging import logging
import typing import typing
import asyncio import asyncio
import time
import re import re
from collections import OrderedDict from collections import OrderedDict
from cryptography.hazmat.primitives.ciphers.algorithms import AES from cryptography.hazmat.primitives.ciphers.algorithms import AES
@ -45,23 +44,16 @@ def random_iv_generator() -> typing.Generator[bytes, None, None]:
yield os.urandom(AES.block_size // 8) yield os.urandom(AES.block_size // 8)
def read_bytes(file_path: str, offset: int, to_read: int): def file_reader(file_path: str):
with open(file_path, 'rb') as f:
f.seek(offset)
return f.read(to_read)
async def file_reader(file_path: str):
length = int(os.stat(file_path).st_size) length = int(os.stat(file_path).st_size)
offset = 0 offset = 0
with open(file_path, 'rb') as stream_file:
while offset < length: while offset < length:
bytes_to_read = min((length - offset), MAX_BLOB_SIZE - 1) bytes_to_read = min((length - offset), MAX_BLOB_SIZE - 1)
if not bytes_to_read: if not bytes_to_read:
break break
blob_bytes = await asyncio.get_event_loop().run_in_executor( blob_bytes = stream_file.read(bytes_to_read)
None, read_bytes, file_path, offset, bytes_to_read
)
yield blob_bytes yield blob_bytes
offset += bytes_to_read offset += bytes_to_read
@ -153,19 +145,15 @@ class StreamDescriptor:
h.update(self.old_sort_json()) h.update(self.old_sort_json())
return h.hexdigest() return h.hexdigest()
async def make_sd_blob( async def make_sd_blob(self, blob_file_obj: typing.Optional[AbstractBlob] = None,
self, blob_file_obj: typing.Optional[AbstractBlob] = None, old_sort: typing.Optional[bool] = False, old_sort: typing.Optional[bool] = False,
blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], None]] = None, blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], None]] = None):
added_on: float = None, is_mine: bool = False
):
sd_hash = self.calculate_sd_hash() if not old_sort else self.calculate_old_sort_sd_hash() sd_hash = self.calculate_sd_hash() if not old_sort else self.calculate_old_sort_sd_hash()
if not old_sort: if not old_sort:
sd_data = self.as_json() sd_data = self.as_json()
else: else:
sd_data = self.old_sort_json() sd_data = self.old_sort_json()
sd_blob = blob_file_obj or BlobFile( sd_blob = blob_file_obj or BlobFile(self.loop, sd_hash, len(sd_data), blob_completed_callback, self.blob_dir)
self.loop, sd_hash, len(sd_data), blob_completed_callback, self.blob_dir, added_on, is_mine
)
if blob_file_obj: if blob_file_obj:
blob_file_obj.set_length(len(sd_data)) blob_file_obj.set_length(len(sd_data))
if not sd_blob.get_is_verified(): if not sd_blob.get_is_verified():
@ -188,19 +176,18 @@ class StreamDescriptor:
raise InvalidStreamDescriptorError("Does not decode as valid JSON") raise InvalidStreamDescriptorError("Does not decode as valid JSON")
if decoded['blobs'][-1]['length'] != 0: if decoded['blobs'][-1]['length'] != 0:
raise InvalidStreamDescriptorError("Does not end with a zero-length blob.") raise InvalidStreamDescriptorError("Does not end with a zero-length blob.")
if any(blob_info['length'] == 0 for blob_info in decoded['blobs'][:-1]): if any([blob_info['length'] == 0 for blob_info in decoded['blobs'][:-1]]):
raise InvalidStreamDescriptorError("Contains zero-length data blob") raise InvalidStreamDescriptorError("Contains zero-length data blob")
if 'blob_hash' in decoded['blobs'][-1]: if 'blob_hash' in decoded['blobs'][-1]:
raise InvalidStreamDescriptorError("Stream terminator blob should not have a hash") raise InvalidStreamDescriptorError("Stream terminator blob should not have a hash")
if any(i != blob_info['blob_num'] for i, blob_info in enumerate(decoded['blobs'])): if any([i != blob_info['blob_num'] for i, blob_info in enumerate(decoded['blobs'])]):
raise InvalidStreamDescriptorError("Stream contains out of order or skipped blobs") raise InvalidStreamDescriptorError("Stream contains out of order or skipped blobs")
added_on = time.time()
descriptor = cls( descriptor = cls(
loop, blob_dir, loop, blob_dir,
binascii.unhexlify(decoded['stream_name']).decode(), binascii.unhexlify(decoded['stream_name']).decode(),
decoded['key'], decoded['key'],
binascii.unhexlify(decoded['suggested_file_name']).decode(), binascii.unhexlify(decoded['suggested_file_name']).decode(),
[BlobInfo(info['blob_num'], info['length'], info['iv'], added_on, info.get('blob_hash')) [BlobInfo(info['blob_num'], info['length'], info['iv'], info.get('blob_hash'))
for info in decoded['blobs']], for info in decoded['blobs']],
decoded['stream_hash'], decoded['stream_hash'],
blob.blob_hash blob.blob_hash
@ -258,25 +245,20 @@ class StreamDescriptor:
iv_generator = iv_generator or random_iv_generator() iv_generator = iv_generator or random_iv_generator()
key = key or os.urandom(AES.block_size // 8) key = key or os.urandom(AES.block_size // 8)
blob_num = -1 blob_num = -1
added_on = time.time() for blob_bytes in file_reader(file_path):
async for blob_bytes in file_reader(file_path):
blob_num += 1 blob_num += 1
blob_info = await BlobFile.create_from_unencrypted( blob_info = await BlobFile.create_from_unencrypted(
loop, blob_dir, key, next(iv_generator), blob_bytes, blob_num, added_on, True, blob_completed_callback loop, blob_dir, key, next(iv_generator), blob_bytes, blob_num, blob_completed_callback
) )
blobs.append(blob_info) blobs.append(blob_info)
blobs.append( blobs.append(
# add the stream terminator BlobInfo(len(blobs), 0, binascii.hexlify(next(iv_generator)).decode())) # add the stream terminator
BlobInfo(len(blobs), 0, binascii.hexlify(next(iv_generator)).decode(), added_on, None, True)
)
file_name = os.path.basename(file_path) file_name = os.path.basename(file_path)
suggested_file_name = sanitize_file_name(file_name) suggested_file_name = sanitize_file_name(file_name)
descriptor = cls( descriptor = cls(
loop, blob_dir, file_name, binascii.hexlify(key).decode(), suggested_file_name, blobs loop, blob_dir, file_name, binascii.hexlify(key).decode(), suggested_file_name, blobs
) )
sd_blob = await descriptor.make_sd_blob( sd_blob = await descriptor.make_sd_blob(old_sort=old_sort, blob_completed_callback=blob_completed_callback)
old_sort=old_sort, blob_completed_callback=blob_completed_callback, added_on=added_on, is_mine=True
)
descriptor.sd_hash = sd_blob.blob_hash descriptor.sd_hash = sd_blob.blob_hash
return descriptor return descriptor

View file

@ -3,13 +3,11 @@ import typing
import logging import logging
import binascii import binascii
from lbry.dht.node import get_kademlia_peers_from_hosts from lbry.dht.peer import make_kademlia_peer
from lbry.error import DownloadSDTimeoutError from lbry.error import DownloadSDTimeoutError
from lbry.utils import lru_cache_concurrent from lbry.utils import resolve_host, lru_cache_concurrent
from lbry.stream.descriptor import StreamDescriptor from lbry.stream.descriptor import StreamDescriptor
from lbry.blob_exchange.downloader import BlobDownloader from lbry.blob_exchange.downloader import BlobDownloader
from lbry.torrent.tracker import enqueue_tracker_search
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from lbry.conf import Config from lbry.conf import Config
from lbry.dht.node import Node from lbry.dht.node import Node
@ -27,8 +25,8 @@ class StreamDownloader:
self.config = config self.config = config
self.blob_manager = blob_manager self.blob_manager = blob_manager
self.sd_hash = sd_hash self.sd_hash = sd_hash
self.search_queue = asyncio.Queue() # blob hashes to feed into the iterative finder self.search_queue = asyncio.Queue(loop=loop) # blob hashes to feed into the iterative finder
self.peer_queue = asyncio.Queue() # new peers to try self.peer_queue = asyncio.Queue(loop=loop) # new peers to try
self.blob_downloader = BlobDownloader(self.loop, self.config, self.blob_manager, self.peer_queue) self.blob_downloader = BlobDownloader(self.loop, self.config, self.blob_manager, self.peer_queue)
self.descriptor: typing.Optional[StreamDescriptor] = descriptor self.descriptor: typing.Optional[StreamDescriptor] = descriptor
self.node: typing.Optional['Node'] = None self.node: typing.Optional['Node'] = None
@ -42,7 +40,7 @@ class StreamDownloader:
async def cached_read_blob(blob_info: 'BlobInfo') -> bytes: async def cached_read_blob(blob_info: 'BlobInfo') -> bytes:
return await self.read_blob(blob_info, 2) return await self.read_blob(blob_info, 2)
if self.blob_manager.decrypted_blob_lru_cache is not None: if self.blob_manager.decrypted_blob_lru_cache:
cached_read_blob = lru_cache_concurrent(override_lru_cache=self.blob_manager.decrypted_blob_lru_cache)( cached_read_blob = lru_cache_concurrent(override_lru_cache=self.blob_manager.decrypted_blob_lru_cache)(
cached_read_blob cached_read_blob
) )
@ -50,19 +48,26 @@ class StreamDownloader:
self.cached_read_blob = cached_read_blob self.cached_read_blob = cached_read_blob
async def add_fixed_peers(self): async def add_fixed_peers(self):
def _add_fixed_peers(fixed_peers): def _delayed_add_fixed_peers():
self.peer_queue.put_nowait(fixed_peers)
self.added_fixed_peers = True self.added_fixed_peers = True
self.peer_queue.put_nowait([
make_kademlia_peer(None, address, None, tcp_port=port + 1, allow_localhost=True)
for address, port in addresses
])
if not self.config.fixed_peers: if not self.config.reflector_servers:
return return
addresses = [
(await resolve_host(url, port + 1, proto='tcp'), port)
for url, port in self.config.reflector_servers
]
if 'dht' in self.config.components_to_skip or not self.node or not \ if 'dht' in self.config.components_to_skip or not self.node or not \
len(self.node.protocol.routing_table.get_peers()) > 0: len(self.node.protocol.routing_table.get_peers()) > 0:
self.fixed_peers_delay = 0.0 self.fixed_peers_delay = 0.0
else: else:
self.fixed_peers_delay = self.config.fixed_peer_delay self.fixed_peers_delay = self.config.fixed_peer_delay
fixed_peers = await get_kademlia_peers_from_hosts(self.config.fixed_peers)
self.fixed_peers_handle = self.loop.call_later(self.fixed_peers_delay, _add_fixed_peers, fixed_peers) self.fixed_peers_handle = self.loop.call_later(self.fixed_peers_delay, _delayed_add_fixed_peers)
async def load_descriptor(self, connection_id: int = 0): async def load_descriptor(self, connection_id: int = 0):
# download or get the sd blob # download or get the sd blob
@ -72,7 +77,7 @@ class StreamDownloader:
now = self.loop.time() now = self.loop.time()
sd_blob = await asyncio.wait_for( sd_blob = await asyncio.wait_for(
self.blob_downloader.download_blob(self.sd_hash, connection_id), self.blob_downloader.download_blob(self.sd_hash, connection_id),
self.config.blob_download_timeout self.config.blob_download_timeout, loop=self.loop
) )
log.info("downloaded sd blob %s", self.sd_hash) log.info("downloaded sd blob %s", self.sd_hash)
self.time_to_descriptor = self.loop.time() - now self.time_to_descriptor = self.loop.time() - now
@ -85,7 +90,7 @@ class StreamDownloader:
) )
log.info("loaded stream manifest %s", self.sd_hash) log.info("loaded stream manifest %s", self.sd_hash)
async def start(self, node: typing.Optional['Node'] = None, connection_id: int = 0, save_stream=True): async def start(self, node: typing.Optional['Node'] = None, connection_id: int = 0):
# set up peer accumulation # set up peer accumulation
self.node = node or self.node # fixme: this shouldnt be set here! self.node = node or self.node # fixme: this shouldnt be set here!
if self.node: if self.node:
@ -93,7 +98,6 @@ class StreamDownloader:
self.accumulate_task.cancel() self.accumulate_task.cancel()
_, self.accumulate_task = self.node.accumulate_peers(self.search_queue, self.peer_queue) _, self.accumulate_task = self.node.accumulate_peers(self.search_queue, self.peer_queue)
await self.add_fixed_peers() await self.add_fixed_peers()
enqueue_tracker_search(bytes.fromhex(self.sd_hash), self.peer_queue)
# start searching for peers for the sd hash # start searching for peers for the sd hash
self.search_queue.put_nowait(self.sd_hash) self.search_queue.put_nowait(self.sd_hash)
log.info("searching for peers for stream %s", self.sd_hash) log.info("searching for peers for stream %s", self.sd_hash)
@ -101,7 +105,11 @@ class StreamDownloader:
if not self.descriptor: if not self.descriptor:
await self.load_descriptor(connection_id) await self.load_descriptor(connection_id)
if not await self.blob_manager.storage.stream_exists(self.sd_hash) and save_stream: # add the head blob to the peer search
self.search_queue.put_nowait(self.descriptor.blobs[0].blob_hash)
log.info("added head blob to peer search for stream %s", self.sd_hash)
if not await self.blob_manager.storage.stream_exists(self.sd_hash):
await self.blob_manager.storage.store_stream( await self.blob_manager.storage.store_stream(
self.blob_manager.get_blob(self.sd_hash, length=self.descriptor.length), self.descriptor self.blob_manager.get_blob(self.sd_hash, length=self.descriptor.length), self.descriptor
) )
@ -111,7 +119,7 @@ class StreamDownloader:
raise ValueError(f"blob {blob_info.blob_hash} is not part of stream with sd hash {self.sd_hash}") raise ValueError(f"blob {blob_info.blob_hash} is not part of stream with sd hash {self.sd_hash}")
blob = await asyncio.wait_for( blob = await asyncio.wait_for(
self.blob_downloader.download_blob(blob_info.blob_hash, blob_info.length, connection_id), self.blob_downloader.download_blob(blob_info.blob_hash, blob_info.length, connection_id),
self.config.blob_download_timeout * 10 self.config.blob_download_timeout * 10, loop=self.loop
) )
return blob return blob

View file

@ -16,8 +16,10 @@ from lbry.file.source import ManagedDownloadSource
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from lbry.conf import Config from lbry.conf import Config
from lbry.schema.claim import Claim
from lbry.blob.blob_manager import BlobManager from lbry.blob.blob_manager import BlobManager
from lbry.blob.blob_info import BlobInfo from lbry.blob.blob_info import BlobInfo
from lbry.dht.node import Node
from lbry.extras.daemon.analytics import AnalyticsManager from lbry.extras.daemon.analytics import AnalyticsManager
from lbry.wallet.transaction import Transaction from lbry.wallet.transaction import Transaction
@ -54,15 +56,12 @@ class ManagedStream(ManagedDownloadSource):
self.purchase_receipt = None self.purchase_receipt = None
self.downloader = StreamDownloader(self.loop, self.config, self.blob_manager, sd_hash, descriptor) self.downloader = StreamDownloader(self.loop, self.config, self.blob_manager, sd_hash, descriptor)
self.analytics_manager = analytics_manager self.analytics_manager = analytics_manager
self.file_output_task: Optional[asyncio.Task] = None
self.reflector_progress = 0 self.delayed_stop_task: Optional[asyncio.Task] = None
self.uploading_to_reflector = False
self.file_output_task: typing.Optional[asyncio.Task] = None
self.delayed_stop_task: typing.Optional[asyncio.Task] = None
self.streaming_responses: typing.List[typing.Tuple[Request, StreamResponse]] = [] self.streaming_responses: typing.List[typing.Tuple[Request, StreamResponse]] = []
self.fully_reflected = asyncio.Event() self.fully_reflected = asyncio.Event(loop=self.loop)
self.streaming = asyncio.Event() self.streaming = asyncio.Event(loop=self.loop)
self._running = asyncio.Event() self._running = asyncio.Event(loop=self.loop)
@property @property
def sd_hash(self) -> str: def sd_hash(self) -> str:
@ -82,19 +81,7 @@ class ManagedStream(ManagedDownloadSource):
@property @property
def file_name(self) -> Optional[str]: def file_name(self) -> Optional[str]:
return self._file_name or self.suggested_file_name return self._file_name or (self.descriptor.suggested_file_name if self.descriptor else None)
@property
def suggested_file_name(self) -> Optional[str]:
first_option = ((self.descriptor and self.descriptor.suggested_file_name) or '').strip()
return sanitize_file_name(first_option or (self.stream_claim_info and self.stream_claim_info.claim and
self.stream_claim_info.claim.stream.source.name))
@property
def stream_name(self) -> Optional[str]:
first_option = ((self.descriptor and self.descriptor.stream_name) or '').strip()
return first_option or (self.stream_claim_info and self.stream_claim_info.claim and
self.stream_claim_info.claim.stream.source.name)
@property @property
def written_bytes(self) -> int: def written_bytes(self) -> int:
@ -128,11 +115,7 @@ class ManagedStream(ManagedDownloadSource):
@property @property
def mime_type(self): def mime_type(self):
return guess_media_type(os.path.basename(self.suggested_file_name))[0] return guess_media_type(os.path.basename(self.descriptor.suggested_file_name))[0]
@property
def download_path(self):
return f"{self.download_directory}/{self._file_name}" if self.download_directory and self._file_name else None
# @classmethod # @classmethod
# async def create(cls, loop: asyncio.AbstractEventLoop, config: 'Config', # async def create(cls, loop: asyncio.AbstractEventLoop, config: 'Config',
@ -161,7 +144,7 @@ class ManagedStream(ManagedDownloadSource):
log.info("start downloader for stream (sd hash: %s)", self.sd_hash) log.info("start downloader for stream (sd hash: %s)", self.sd_hash)
self._running.set() self._running.set()
try: try:
await asyncio.wait_for(self.downloader.start(), timeout) await asyncio.wait_for(self.downloader.start(), timeout, loop=self.loop)
except asyncio.TimeoutError: except asyncio.TimeoutError:
self._running.clear() self._running.clear()
raise DownloadSDTimeoutError(self.sd_hash) raise DownloadSDTimeoutError(self.sd_hash)
@ -174,7 +157,7 @@ class ManagedStream(ManagedDownloadSource):
if not self._file_name: if not self._file_name:
self._file_name = await get_next_available_file_name( self._file_name = await get_next_available_file_name(
self.loop, self.download_directory, self.loop, self.download_directory,
self._file_name or sanitize_file_name(self.suggested_file_name) self._file_name or sanitize_file_name(self.descriptor.suggested_file_name)
) )
file_name, download_dir = self._file_name, self.download_directory file_name, download_dir = self._file_name, self.download_directory
else: else:
@ -191,7 +174,7 @@ class ManagedStream(ManagedDownloadSource):
Stop any running save/stream tasks as well as the downloader and update the status in the database Stop any running save/stream tasks as well as the downloader and update the status in the database
""" """
await self.stop_tasks() self.stop_tasks()
if (finished and self.status != self.STATUS_FINISHED) or self.status == self.STATUS_RUNNING: if (finished and self.status != self.STATUS_FINISHED) or self.status == self.STATUS_RUNNING:
await self.update_status(self.STATUS_FINISHED if finished else self.STATUS_STOPPED) await self.update_status(self.STATUS_FINISHED if finished else self.STATUS_STOPPED)
@ -207,13 +190,13 @@ class ManagedStream(ManagedDownloadSource):
decrypted = await self.downloader.read_blob(blob_info, connection_id) decrypted = await self.downloader.read_blob(blob_info, connection_id)
yield (blob_info, decrypted) yield (blob_info, decrypted)
async def stream_file(self, request: Request) -> StreamResponse: async def stream_file(self, request: Request, node: Optional['Node'] = None) -> StreamResponse:
log.info("stream file to browser for lbry://%s#%s (sd hash %s...)", self.claim_name, self.claim_id, log.info("stream file to browser for lbry://%s#%s (sd hash %s...)", self.claim_name, self.claim_id,
self.sd_hash[:6]) self.sd_hash[:6])
headers, size, skip_blobs, first_blob_start_offset = self._prepare_range_response_headers( headers, size, skip_blobs, first_blob_start_offset = self._prepare_range_response_headers(
request.headers.get('range', 'bytes=0-') request.headers.get('range', 'bytes=0-')
) )
await self.start() await self.start(node)
response = StreamResponse( response = StreamResponse(
status=206, status=206,
headers=headers headers=headers
@ -251,8 +234,7 @@ class ManagedStream(ManagedDownloadSource):
self.streaming.clear() self.streaming.clear()
@staticmethod @staticmethod
def _write_decrypted_blob(output_path: str, data: bytes): def _write_decrypted_blob(handle: typing.IO, data: bytes):
with open(output_path, 'ab') as handle:
handle.write(data) handle.write(data)
handle.flush() handle.flush()
@ -264,10 +246,10 @@ class ManagedStream(ManagedDownloadSource):
self.finished_writing.clear() self.finished_writing.clear()
self.started_writing.clear() self.started_writing.clear()
try: try:
open(output_path, 'wb').close() # pylint: disable=consider-using-with with open(output_path, 'wb') as file_write_handle:
async for blob_info, decrypted in self._aiter_read_stream(connection_id=self.SAVING_ID): async for blob_info, decrypted in self._aiter_read_stream(connection_id=self.SAVING_ID):
log.info("write blob %i/%i", blob_info.blob_num + 1, len(self.descriptor.blobs) - 1) log.info("write blob %i/%i", blob_info.blob_num + 1, len(self.descriptor.blobs) - 1)
await self.loop.run_in_executor(None, self._write_decrypted_blob, output_path, decrypted) await self.loop.run_in_executor(None, self._write_decrypted_blob, file_write_handle, decrypted)
if not self.started_writing.is_set(): if not self.started_writing.is_set():
self.started_writing.set() self.started_writing.set()
await self.update_status(ManagedStream.STATUS_FINISHED) await self.update_status(ManagedStream.STATUS_FINISHED)
@ -279,7 +261,7 @@ class ManagedStream(ManagedDownloadSource):
log.info("finished saving file for lbry://%s#%s (sd hash %s...) -> %s", self.claim_name, self.claim_id, log.info("finished saving file for lbry://%s#%s (sd hash %s...) -> %s", self.claim_name, self.claim_id,
self.sd_hash[:6], self.full_path) self.sd_hash[:6], self.full_path)
await self.blob_manager.storage.set_saved_file(self.stream_hash) await self.blob_manager.storage.set_saved_file(self.stream_hash)
except (Exception, asyncio.CancelledError) as err: except Exception as err:
if os.path.isfile(output_path): if os.path.isfile(output_path):
log.warning("removing incomplete download %s for %s", output_path, self.sd_hash) log.warning("removing incomplete download %s for %s", output_path, self.sd_hash)
os.remove(output_path) os.remove(output_path)
@ -306,14 +288,14 @@ class ManagedStream(ManagedDownloadSource):
self.download_directory = download_directory or self.download_directory or self.config.download_dir self.download_directory = download_directory or self.download_directory or self.config.download_dir
if not self.download_directory: if not self.download_directory:
raise ValueError("no directory to download to") raise ValueError("no directory to download to")
if not (file_name or self._file_name or self.suggested_file_name): if not (file_name or self._file_name or self.descriptor.suggested_file_name):
raise ValueError("no file name to download to") raise ValueError("no file name to download to")
if not os.path.isdir(self.download_directory): if not os.path.isdir(self.download_directory):
log.warning("download directory '%s' does not exist, attempting to make it", self.download_directory) log.warning("download directory '%s' does not exist, attempting to make it", self.download_directory)
os.mkdir(self.download_directory) os.mkdir(self.download_directory)
self._file_name = await get_next_available_file_name( self._file_name = await get_next_available_file_name(
self.loop, self.download_directory, self.loop, self.download_directory,
file_name or self._file_name or sanitize_file_name(self.suggested_file_name) file_name or self._file_name or sanitize_file_name(self.descriptor.suggested_file_name)
) )
await self.blob_manager.storage.change_file_download_dir_and_file_name( await self.blob_manager.storage.change_file_download_dir_and_file_name(
self.stream_hash, self.download_directory, self.file_name self.stream_hash, self.download_directory, self.file_name
@ -321,16 +303,15 @@ class ManagedStream(ManagedDownloadSource):
await self.update_status(ManagedStream.STATUS_RUNNING) await self.update_status(ManagedStream.STATUS_RUNNING)
self.file_output_task = self.loop.create_task(self._save_file(self.full_path)) self.file_output_task = self.loop.create_task(self._save_file(self.full_path))
try: try:
await asyncio.wait_for(self.started_writing.wait(), self.config.download_timeout) await asyncio.wait_for(self.started_writing.wait(), self.config.download_timeout, loop=self.loop)
except asyncio.TimeoutError: except asyncio.TimeoutError:
log.warning("timeout starting to write data for lbry://%s#%s", self.claim_name, self.claim_id) log.warning("timeout starting to write data for lbry://%s#%s", self.claim_name, self.claim_id)
await self.stop_tasks() self.stop_tasks()
await self.update_status(ManagedStream.STATUS_STOPPED) await self.update_status(ManagedStream.STATUS_STOPPED)
async def stop_tasks(self): def stop_tasks(self):
if self.file_output_task and not self.file_output_task.done(): if self.file_output_task and not self.file_output_task.done():
self.file_output_task.cancel() self.file_output_task.cancel()
await asyncio.gather(self.file_output_task, return_exceptions=True)
self.file_output_task = None self.file_output_task = None
while self.streaming_responses: while self.streaming_responses:
req, response = self.streaming_responses.pop() req, response = self.streaming_responses.pop()
@ -343,7 +324,6 @@ class ManagedStream(ManagedDownloadSource):
sent = [] sent = []
protocol = StreamReflectorClient(self.blob_manager, self.descriptor) protocol = StreamReflectorClient(self.blob_manager, self.descriptor)
try: try:
self.uploading_to_reflector = True
await self.loop.create_connection(lambda: protocol, host, port) await self.loop.create_connection(lambda: protocol, host, port)
await protocol.send_handshake() await protocol.send_handshake()
sent_sd, needed = await protocol.send_descriptor() sent_sd, needed = await protocol.send_descriptor()
@ -359,30 +339,19 @@ class ManagedStream(ManagedDownloadSource):
] ]
log.info("we have %i/%i needed blobs needed by reflector for lbry://%s#%s", len(we_have), len(needed), log.info("we have %i/%i needed blobs needed by reflector for lbry://%s#%s", len(we_have), len(needed),
self.claim_name, self.claim_id) self.claim_name, self.claim_id)
for i, blob_hash in enumerate(we_have): for blob_hash in we_have:
await protocol.send_blob(blob_hash) await protocol.send_blob(blob_hash)
sent.append(blob_hash) sent.append(blob_hash)
self.reflector_progress = int((i + 1) / len(we_have) * 100)
except (asyncio.TimeoutError, ValueError): except (asyncio.TimeoutError, ValueError):
return sent return sent
except ConnectionError: except ConnectionRefusedError:
return sent
except (OSError, Exception, asyncio.CancelledError) as err:
if isinstance(err, asyncio.CancelledError):
log.warning("stopped uploading %s#%s to reflector", self.claim_name, self.claim_id)
elif isinstance(err, OSError):
log.warning(
"stopped uploading %s#%s to reflector because blobs were deleted or moved", self.claim_name,
self.claim_id
)
else:
log.exception("unexpected error reflecting %s#%s", self.claim_name, self.claim_id)
return sent return sent
finally: finally:
if protocol.transport: if protocol.transport:
protocol.transport.close() protocol.transport.close()
self.uploading_to_reflector = False if not self.fully_reflected.is_set():
self.fully_reflected.set()
await self.blob_manager.storage.update_reflected_stream(self.sd_hash, f"{host}:{port}")
return sent return sent
async def update_content_claim(self, claim_info: Optional[typing.Dict] = None): async def update_content_claim(self, claim_info: Optional[typing.Dict] = None):
@ -402,7 +371,7 @@ class ManagedStream(ManagedDownloadSource):
self.sd_hash[:6]) self.sd_hash[:6])
await self.stop() await self.stop()
return return
await asyncio.sleep(1) await asyncio.sleep(1, loop=self.loop)
def _prepare_range_response_headers(self, get_range: str) -> typing.Tuple[typing.Dict[str, str], int, int, int]: def _prepare_range_response_headers(self, get_range: str) -> typing.Tuple[typing.Dict[str, str], int, int, int]:
if '=' in get_range: if '=' in get_range:

View file

@ -35,8 +35,6 @@ class StreamReflectorClient(asyncio.Protocol):
def connection_lost(self, exc: typing.Optional[Exception]): def connection_lost(self, exc: typing.Optional[Exception]):
self.transport = None self.transport = None
self.connected.clear() self.connected.clear()
if self.pending_request:
self.pending_request.cancel()
if self.reflected_blobs: if self.reflected_blobs:
log.info("Finished sending reflector %i blobs", len(self.reflected_blobs)) log.info("Finished sending reflector %i blobs", len(self.reflected_blobs))
@ -58,18 +56,12 @@ class StreamReflectorClient(asyncio.Protocol):
self.response_buff = b'' self.response_buff = b''
return return
async def send_request(self, request_dict: typing.Dict, timeout: int = 180): async def send_request(self, request_dict: typing.Dict):
msg = json.dumps(request_dict, sort_keys=True) msg = json.dumps(request_dict)
try:
self.transport.write(msg.encode()) self.transport.write(msg.encode())
self.pending_request = self.loop.create_task(asyncio.wait_for(self.response_queue.get(), timeout)) try:
self.pending_request = self.loop.create_task(self.response_queue.get())
return await self.pending_request return await self.pending_request
except (AttributeError, asyncio.CancelledError) as err:
# attribute error happens when we transport.write after disconnect
# cancelled error happens when the pending_request task is cancelled by a disconnect
if self.transport:
self.transport.close()
raise err if isinstance(err, asyncio.CancelledError) else asyncio.CancelledError()
finally: finally:
self.pending_request = None self.pending_request = None
@ -94,16 +86,8 @@ class StreamReflectorClient(asyncio.Protocol):
needed = response.get('needed_blobs', []) needed = response.get('needed_blobs', [])
sent_sd = False sent_sd = False
if response['send_sd_blob']: if response['send_sd_blob']:
try: await sd_blob.sendfile(self)
sent = await sd_blob.sendfile(self) received = await self.response_queue.get()
if sent == -1:
log.warning("failed to send sd blob")
raise asyncio.CancelledError()
received = await asyncio.wait_for(self.response_queue.get(), 30)
except asyncio.CancelledError as err:
if self.transport:
self.transport.close()
raise err
if received.get('received_sd_blob'): if received.get('received_sd_blob'):
sent_sd = True sent_sd = True
if not needed: if not needed:
@ -126,16 +110,8 @@ class StreamReflectorClient(asyncio.Protocol):
if 'send_blob' not in response: if 'send_blob' not in response:
raise ValueError("I don't know whether to send the blob or not!") raise ValueError("I don't know whether to send the blob or not!")
if response['send_blob']: if response['send_blob']:
try: await blob.sendfile(self)
sent = await blob.sendfile(self) received = await self.response_queue.get()
if sent == -1:
log.warning("failed to send blob")
raise asyncio.CancelledError()
received = await asyncio.wait_for(self.response_queue.get(), 30)
except asyncio.CancelledError as err:
if self.transport:
self.transport.close()
raise err
if received.get('received_blob'): if received.get('received_blob'):
self.reflected_blobs.append(blob.blob_hash) self.reflected_blobs.append(blob.blob_hash)
log.info("Sent reflector blob %s", blob.blob_hash[:8]) log.info("Sent reflector blob %s", blob.blob_hash[:8])

View file

@ -15,13 +15,11 @@ log = logging.getLogger(__name__)
class ReflectorServerProtocol(asyncio.Protocol): class ReflectorServerProtocol(asyncio.Protocol):
def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000, def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000):
stop_event: asyncio.Event = None, incoming_event: asyncio.Event = None,
not_incoming_event: asyncio.Event = None, partial_event: asyncio.Event = None):
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self.blob_manager = blob_manager self.blob_manager = blob_manager
self.server_task: asyncio.Task = None self.server_task: asyncio.Task = None
self.started_listening = asyncio.Event() self.started_listening = asyncio.Event(loop=self.loop)
self.buf = b'' self.buf = b''
self.transport: asyncio.StreamWriter = None self.transport: asyncio.StreamWriter = None
self.writer: typing.Optional['HashBlobWriter'] = None self.writer: typing.Optional['HashBlobWriter'] = None
@ -29,26 +27,11 @@ class ReflectorServerProtocol(asyncio.Protocol):
self.descriptor: typing.Optional['StreamDescriptor'] = None self.descriptor: typing.Optional['StreamDescriptor'] = None
self.sd_blob: typing.Optional['BlobFile'] = None self.sd_blob: typing.Optional['BlobFile'] = None
self.received = [] self.received = []
self.incoming = incoming_event or asyncio.Event() self.incoming = asyncio.Event(loop=self.loop)
self.not_incoming = not_incoming_event or asyncio.Event()
self.stop_event = stop_event or asyncio.Event()
self.chunk_size = response_chunk_size self.chunk_size = response_chunk_size
self.wait_for_stop_task: typing.Optional[asyncio.Task] = None
self.partial_event = partial_event
async def wait_for_stop(self):
await self.stop_event.wait()
if self.transport:
self.transport.close()
def connection_made(self, transport): def connection_made(self, transport):
self.transport = transport self.transport = transport
self.wait_for_stop_task = self.loop.create_task(self.wait_for_stop())
def connection_lost(self, exc):
if self.wait_for_stop_task:
self.wait_for_stop_task.cancel()
self.wait_for_stop_task = None
def data_received(self, data: bytes): def data_received(self, data: bytes):
if self.incoming.is_set(): if self.incoming.is_set():
@ -90,11 +73,10 @@ class ReflectorServerProtocol(asyncio.Protocol):
self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size']) self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size'])
if not self.sd_blob.get_is_verified(): if not self.sd_blob.get_is_verified():
self.writer = self.sd_blob.get_blob_writer(self.transport.get_extra_info('peername')) self.writer = self.sd_blob.get_blob_writer(self.transport.get_extra_info('peername'))
self.not_incoming.clear()
self.incoming.set() self.incoming.set()
self.send_response({"send_sd_blob": True}) self.send_response({"send_sd_blob": True})
try: try:
await asyncio.wait_for(self.sd_blob.verified.wait(), 30) await asyncio.wait_for(self.sd_blob.verified.wait(), 30, loop=self.loop)
self.descriptor = await StreamDescriptor.from_stream_descriptor_blob( self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(
self.loop, self.blob_manager.blob_dir, self.sd_blob self.loop, self.blob_manager.blob_dir, self.sd_blob
) )
@ -104,7 +86,6 @@ class ReflectorServerProtocol(asyncio.Protocol):
self.transport.close() self.transport.close()
finally: finally:
self.incoming.clear() self.incoming.clear()
self.not_incoming.set()
self.writer.close_handle() self.writer.close_handle()
self.writer = None self.writer = None
else: else:
@ -112,18 +93,13 @@ class ReflectorServerProtocol(asyncio.Protocol):
self.loop, self.blob_manager.blob_dir, self.sd_blob self.loop, self.blob_manager.blob_dir, self.sd_blob
) )
self.incoming.clear() self.incoming.clear()
self.not_incoming.set()
if self.writer: if self.writer:
self.writer.close_handle() self.writer.close_handle()
self.writer = None self.writer = None
self.send_response({"send_sd_blob": False, 'needed': [
needs = [blob.blob_hash blob.blob_hash for blob in self.descriptor.blobs[:-1]
for blob in self.descriptor.blobs[:-1] if not self.blob_manager.get_blob(blob.blob_hash).get_is_verified()
if not self.blob_manager.get_blob(blob.blob_hash).get_is_verified()] ]})
if needs and not self.partial_event.is_set():
needs = needs[:3]
self.partial_event.set()
self.send_response({"send_sd_blob": False, 'needed_blobs': needs})
return return
return return
elif self.descriptor: elif self.descriptor:
@ -136,16 +112,14 @@ class ReflectorServerProtocol(asyncio.Protocol):
blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size']) blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size'])
if not blob.get_is_verified(): if not blob.get_is_verified():
self.writer = blob.get_blob_writer(self.transport.get_extra_info('peername')) self.writer = blob.get_blob_writer(self.transport.get_extra_info('peername'))
self.not_incoming.clear()
self.incoming.set() self.incoming.set()
self.send_response({"send_blob": True}) self.send_response({"send_blob": True})
try: try:
await asyncio.wait_for(blob.verified.wait(), 30) await asyncio.wait_for(blob.verified.wait(), 30, loop=self.loop)
self.send_response({"received_blob": True}) self.send_response({"received_blob": True})
except asyncio.TimeoutError: except asyncio.TimeoutError:
self.send_response({"received_blob": False}) self.send_response({"received_blob": False})
self.incoming.clear() self.incoming.clear()
self.not_incoming.set()
self.writer.close_handle() self.writer.close_handle()
self.writer = None self.writer = None
else: else:
@ -156,39 +130,26 @@ class ReflectorServerProtocol(asyncio.Protocol):
class ReflectorServer: class ReflectorServer:
def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000, def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000):
stop_event: asyncio.Event = None, incoming_event: asyncio.Event = None,
not_incoming_event: asyncio.Event = None, partial_needs=False):
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self.blob_manager = blob_manager self.blob_manager = blob_manager
self.server_task: typing.Optional[asyncio.Task] = None self.server_task: typing.Optional[asyncio.Task] = None
self.started_listening = asyncio.Event() self.started_listening = asyncio.Event(loop=self.loop)
self.stopped_listening = asyncio.Event()
self.incoming_event = incoming_event or asyncio.Event()
self.not_incoming_event = not_incoming_event or asyncio.Event()
self.response_chunk_size = response_chunk_size self.response_chunk_size = response_chunk_size
self.stop_event = stop_event
self.partial_needs = partial_needs # for testing cases where it doesn't know what it wants
def start_server(self, port: int, interface: typing.Optional[str] = '0.0.0.0'): def start_server(self, port: int, interface: typing.Optional[str] = '0.0.0.0'):
if self.server_task is not None: if self.server_task is not None:
raise Exception("already running") raise Exception("already running")
async def _start_server(): async def _start_server():
partial_event = asyncio.Event() server = await self.loop.create_server(
if not self.partial_needs: lambda: ReflectorServerProtocol(self.blob_manager, self.response_chunk_size),
partial_event.set() interface, port
server = await self.loop.create_server(lambda: ReflectorServerProtocol( )
self.blob_manager, self.response_chunk_size, self.stop_event, self.incoming_event,
self.not_incoming_event, partial_event), interface, port)
self.started_listening.set() self.started_listening.set()
self.stopped_listening.clear()
log.info("Reflector server listening on TCP %s:%i", interface, port) log.info("Reflector server listening on TCP %s:%i", interface, port)
try:
async with server: async with server:
await server.serve_forever() await server.serve_forever()
finally:
self.stopped_listening.set()
self.server_task = self.loop.create_task(_start_server()) self.server_task = self.loop.create_task(_start_server())

View file

@ -38,9 +38,7 @@ class StreamManager(SourceManager):
'stream_hash', 'stream_hash',
'full_status', # TODO: remove 'full_status', # TODO: remove
'blobs_remaining', 'blobs_remaining',
'blobs_in_stream', 'blobs_in_stream'
'uploading_to_reflector',
'is_fully_reflected'
}) })
def __init__(self, loop: asyncio.AbstractEventLoop, config: 'Config', blob_manager: 'BlobManager', def __init__(self, loop: asyncio.AbstractEventLoop, config: 'Config', blob_manager: 'BlobManager',
@ -54,7 +52,7 @@ class StreamManager(SourceManager):
self.re_reflect_task: Optional[asyncio.Task] = None self.re_reflect_task: Optional[asyncio.Task] = None
self.update_stream_finished_futs: typing.List[asyncio.Future] = [] self.update_stream_finished_futs: typing.List[asyncio.Future] = []
self.running_reflector_uploads: typing.Dict[str, asyncio.Task] = {} self.running_reflector_uploads: typing.Dict[str, asyncio.Task] = {}
self.started = asyncio.Event() self.started = asyncio.Event(loop=self.loop)
@property @property
def streams(self): def streams(self):
@ -70,7 +68,6 @@ class StreamManager(SourceManager):
async def recover_streams(self, file_infos: typing.List[typing.Dict]): async def recover_streams(self, file_infos: typing.List[typing.Dict]):
to_restore = [] to_restore = []
to_check = []
async def recover_stream(sd_hash: str, stream_hash: str, stream_name: str, async def recover_stream(sd_hash: str, stream_hash: str, stream_name: str,
suggested_file_name: str, key: str, suggested_file_name: str, key: str,
@ -83,7 +80,6 @@ class StreamManager(SourceManager):
if not descriptor: if not descriptor:
return return
to_restore.append((descriptor, sd_blob, content_fee)) to_restore.append((descriptor, sd_blob, content_fee))
to_check.extend([sd_blob.blob_hash] + [blob.blob_hash for blob in descriptor.blobs[:-1]])
await asyncio.gather(*[ await asyncio.gather(*[
recover_stream( recover_stream(
@ -95,8 +91,6 @@ class StreamManager(SourceManager):
if to_restore: if to_restore:
await self.storage.recover_streams(to_restore, self.config.download_dir) await self.storage.recover_streams(to_restore, self.config.download_dir)
if to_check:
await self.blob_manager.ensure_completed_blobs_status(to_check)
# if self.blob_manager._save_blobs: # if self.blob_manager._save_blobs:
# log.info("Recovered %i/%i attempted streams", len(to_restore), len(file_infos)) # log.info("Recovered %i/%i attempted streams", len(to_restore), len(file_infos))
@ -150,7 +144,7 @@ class StreamManager(SourceManager):
file_info['added_on'], file_info['fully_reflected'] file_info['added_on'], file_info['fully_reflected']
))) )))
if add_stream_tasks: if add_stream_tasks:
await asyncio.gather(*add_stream_tasks) await asyncio.gather(*add_stream_tasks, loop=self.loop)
log.info("Started stream manager with %i files", len(self._sources)) log.info("Started stream manager with %i files", len(self._sources))
if not self.node: if not self.node:
log.info("no DHT node given, resuming downloads trusting that we can contact reflector") log.info("no DHT node given, resuming downloads trusting that we can contact reflector")
@ -159,19 +153,12 @@ class StreamManager(SourceManager):
self.resume_saving_task = asyncio.ensure_future(asyncio.gather( self.resume_saving_task = asyncio.ensure_future(asyncio.gather(
*(self._sources[sd_hash].save_file(file_name, download_directory) *(self._sources[sd_hash].save_file(file_name, download_directory)
for (file_name, download_directory, sd_hash) in to_resume_saving), for (file_name, download_directory, sd_hash) in to_resume_saving),
loop=self.loop
)) ))
async def reflect_streams(self): async def reflect_streams(self):
try:
return await self._reflect_streams()
except Exception:
log.exception("reflector task encountered an unexpected error!")
async def _reflect_streams(self):
# todo: those debug statements are temporary for #2987 - remove them if its closed
while True: while True:
if self.config.reflect_streams and self.config.reflector_servers: if self.config.reflect_streams and self.config.reflector_servers:
log.debug("collecting streams to reflect")
sd_hashes = await self.storage.get_streams_to_re_reflect() sd_hashes = await self.storage.get_streams_to_re_reflect()
sd_hashes = [sd for sd in sd_hashes if sd in self._sources] sd_hashes = [sd for sd in sd_hashes if sd in self._sources]
batch = [] batch = []
@ -182,22 +169,18 @@ class StreamManager(SourceManager):
stream.fully_reflected.is_set(): stream.fully_reflected.is_set():
batch.append(self.reflect_stream(stream)) batch.append(self.reflect_stream(stream))
if len(batch) >= self.config.concurrent_reflector_uploads: if len(batch) >= self.config.concurrent_reflector_uploads:
log.debug("waiting for batch of %s reflecting streams", len(batch)) await asyncio.gather(*batch, loop=self.loop)
await asyncio.gather(*batch)
log.debug("done processing %s streams", len(batch))
batch = [] batch = []
if batch: if batch:
log.debug("waiting for batch of %s reflecting streams", len(batch)) await asyncio.gather(*batch, loop=self.loop)
await asyncio.gather(*batch) await asyncio.sleep(300, loop=self.loop)
log.debug("done processing %s streams", len(batch))
await asyncio.sleep(300)
async def start(self): async def start(self):
await super().start() await super().start()
self.re_reflect_task = self.loop.create_task(self.reflect_streams()) self.re_reflect_task = self.loop.create_task(self.reflect_streams())
async def stop(self): def stop(self):
await super().stop() super().stop()
if self.resume_saving_task and not self.resume_saving_task.done(): if self.resume_saving_task and not self.resume_saving_task.done():
self.resume_saving_task.cancel() self.resume_saving_task.cancel()
if self.re_reflect_task and not self.re_reflect_task.done(): if self.re_reflect_task and not self.re_reflect_task.done():
@ -216,7 +199,7 @@ class StreamManager(SourceManager):
server, port = random.choice(self.config.reflector_servers) server, port = random.choice(self.config.reflector_servers)
if stream.sd_hash in self.running_reflector_uploads: if stream.sd_hash in self.running_reflector_uploads:
return self.running_reflector_uploads[stream.sd_hash] return self.running_reflector_uploads[stream.sd_hash]
task = self.loop.create_task(self._retriable_reflect_stream(stream, server, port)) task = self.loop.create_task(stream.upload_to_reflector(server, port))
self.running_reflector_uploads[stream.sd_hash] = task self.running_reflector_uploads[stream.sd_hash] = task
task.add_done_callback( task.add_done_callback(
lambda _: None if stream.sd_hash not in self.running_reflector_uploads else lambda _: None if stream.sd_hash not in self.running_reflector_uploads else
@ -224,14 +207,6 @@ class StreamManager(SourceManager):
) )
return task return task
@staticmethod
async def _retriable_reflect_stream(stream, host, port):
sent = await stream.upload_to_reflector(host, port)
while not stream.is_fully_reflected and stream.reflector_progress > 0 and len(sent) > 0:
stream.reflector_progress = 0
sent = await stream.upload_to_reflector(host, port)
return sent
async def create(self, file_path: str, key: Optional[bytes] = None, async def create(self, file_path: str, key: Optional[bytes] = None,
iv_generator: Optional[typing.Generator[bytes, None, None]] = None) -> ManagedStream: iv_generator: Optional[typing.Generator[bytes, None, None]] = None) -> ManagedStream:
descriptor = await StreamDescriptor.create_stream( descriptor = await StreamDescriptor.create_stream(
@ -239,7 +214,7 @@ class StreamManager(SourceManager):
blob_completed_callback=self.blob_manager.blob_completed blob_completed_callback=self.blob_manager.blob_completed
) )
await self.storage.store_stream( await self.storage.store_stream(
self.blob_manager.get_blob(descriptor.sd_hash, is_mine=True), descriptor self.blob_manager.get_blob(descriptor.sd_hash), descriptor
) )
row_id = await self.storage.save_published_file( row_id = await self.storage.save_published_file(
descriptor.stream_hash, os.path.basename(file_path), os.path.dirname(file_path), 0 descriptor.stream_hash, os.path.basename(file_path), os.path.dirname(file_path), 0
@ -260,7 +235,7 @@ class StreamManager(SourceManager):
return return
if source.identifier in self.running_reflector_uploads: if source.identifier in self.running_reflector_uploads:
self.running_reflector_uploads[source.identifier].cancel() self.running_reflector_uploads[source.identifier].cancel()
await source.stop_tasks() source.stop_tasks()
if source.identifier in self.streams: if source.identifier in self.streams:
del self.streams[source.identifier] del self.streams[source.identifier]
blob_hashes = [source.identifier] + [b.blob_hash for b in source.descriptor.blobs[:-1]] blob_hashes = [source.identifier] + [b.blob_hash for b in source.descriptor.blobs[:-1]]
@ -270,7 +245,4 @@ class StreamManager(SourceManager):
os.remove(source.full_path) os.remove(source.full_path)
async def stream_partial_content(self, request: Request, sd_hash: str): async def stream_partial_content(self, request: Request, sd_hash: str):
stream = self._sources[sd_hash] return await self._sources[sd_hash].stream_file(request, self.node)
if not stream.downloader.node:
stream.downloader.node = self.node
return await stream.stream_file(request)

View file

@ -17,21 +17,18 @@ from functools import partial
from lbry.wallet import WalletManager, Wallet, Ledger, Account, Transaction from lbry.wallet import WalletManager, Wallet, Ledger, Account, Transaction
from lbry.conf import Config from lbry.conf import Config
from lbry.wallet.util import satoshis_to_coins from lbry.wallet.util import satoshis_to_coins
from lbry.wallet.dewies import lbc_to_dewies
from lbry.wallet.orchstr8 import Conductor from lbry.wallet.orchstr8 import Conductor
from lbry.wallet.orchstr8.node import LBCWalletNode, WalletNode from lbry.wallet.orchstr8.node import BlockchainNode, WalletNode
from lbry.schema.claim import Claim
from lbry.extras.daemon.daemon import Daemon, jsonrpc_dumps_pretty from lbry.extras.daemon.daemon import Daemon, jsonrpc_dumps_pretty
from lbry.extras.daemon.components import Component, WalletComponent from lbry.extras.daemon.components import Component, WalletComponent
from lbry.extras.daemon.components import ( from lbry.extras.daemon.components import (
DHT_COMPONENT, DHT_COMPONENT, HASH_ANNOUNCER_COMPONENT, PEER_PROTOCOL_SERVER_COMPONENT,
HASH_ANNOUNCER_COMPONENT, PEER_PROTOCOL_SERVER_COMPONENT, UPNP_COMPONENT, EXCHANGE_RATE_MANAGER_COMPONENT
UPNP_COMPONENT, EXCHANGE_RATE_MANAGER_COMPONENT, LIBTORRENT_COMPONENT
) )
from lbry.extras.daemon.componentmanager import ComponentManager from lbry.extras.daemon.componentmanager import ComponentManager
from lbry.extras.daemon.exchange_rate_manager import ( from lbry.extras.daemon.exchange_rate_manager import (
ExchangeRateManager, ExchangeRate, BittrexBTCFeed, BittrexUSDFeed ExchangeRateManager, ExchangeRate, LBRYFeed, LBRYBTCFeed
) )
from lbry.extras.daemon.storage import SQLiteStorage from lbry.extras.daemon.storage import SQLiteStorage
from lbry.blob.blob_manager import BlobManager from lbry.blob.blob_manager import BlobManager
@ -86,7 +83,6 @@ class AsyncioTestCase(unittest.TestCase):
# https://bugs.python.org/issue32972 # https://bugs.python.org/issue32972
LOOP_SLOW_CALLBACK_DURATION = 0.2 LOOP_SLOW_CALLBACK_DURATION = 0.2
TIMEOUT = 120.0
maxDiff = None maxDiff = None
@ -134,18 +130,15 @@ class AsyncioTestCase(unittest.TestCase):
with outcome.testPartExecutor(self): with outcome.testPartExecutor(self):
self.setUp() self.setUp()
self.add_timeout()
self.loop.run_until_complete(self.asyncSetUp()) self.loop.run_until_complete(self.asyncSetUp())
if outcome.success: if outcome.success:
outcome.expecting_failure = expecting_failure outcome.expecting_failure = expecting_failure
with outcome.testPartExecutor(self, isTest=True): with outcome.testPartExecutor(self, isTest=True):
maybe_coroutine = testMethod() maybe_coroutine = testMethod()
if asyncio.iscoroutine(maybe_coroutine): if asyncio.iscoroutine(maybe_coroutine):
self.add_timeout()
self.loop.run_until_complete(maybe_coroutine) self.loop.run_until_complete(maybe_coroutine)
outcome.expecting_failure = False outcome.expecting_failure = False
with outcome.testPartExecutor(self): with outcome.testPartExecutor(self):
self.add_timeout()
self.loop.run_until_complete(self.asyncTearDown()) self.loop.run_until_complete(self.asyncTearDown())
self.tearDown() self.tearDown()
@ -193,25 +186,8 @@ class AsyncioTestCase(unittest.TestCase):
with outcome.testPartExecutor(self): with outcome.testPartExecutor(self):
maybe_coroutine = function(*args, **kwargs) maybe_coroutine = function(*args, **kwargs)
if asyncio.iscoroutine(maybe_coroutine): if asyncio.iscoroutine(maybe_coroutine):
self.add_timeout()
self.loop.run_until_complete(maybe_coroutine) self.loop.run_until_complete(maybe_coroutine)
def cancel(self):
for task in asyncio.all_tasks(self.loop):
if not task.done():
task.print_stack()
task.cancel()
def add_timeout(self):
if self.TIMEOUT:
self.loop.call_later(self.TIMEOUT, self.check_timeout, time())
def check_timeout(self, started):
if time() - started >= self.TIMEOUT:
self.cancel()
else:
self.loop.call_later(self.TIMEOUT, self.check_timeout, started)
class AdvanceTimeTestCase(AsyncioTestCase): class AdvanceTimeTestCase(AsyncioTestCase):
@ -236,7 +212,7 @@ class IntegrationTestCase(AsyncioTestCase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.conductor: Optional[Conductor] = None self.conductor: Optional[Conductor] = None
self.blockchain: Optional[LBCWalletNode] = None self.blockchain: Optional[BlockchainNode] = None
self.wallet_node: Optional[WalletNode] = None self.wallet_node: Optional[WalletNode] = None
self.manager: Optional[WalletManager] = None self.manager: Optional[WalletManager] = None
self.ledger: Optional[Ledger] = None self.ledger: Optional[Ledger] = None
@ -245,15 +221,13 @@ class IntegrationTestCase(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.conductor = Conductor(seed=self.SEED) self.conductor = Conductor(seed=self.SEED)
await self.conductor.start_lbcd() await self.conductor.start_blockchain()
self.addCleanup(self.conductor.stop_lbcd) self.addCleanup(self.conductor.stop_blockchain)
await self.conductor.start_lbcwallet()
self.addCleanup(self.conductor.stop_lbcwallet)
await self.conductor.start_spv() await self.conductor.start_spv()
self.addCleanup(self.conductor.stop_spv) self.addCleanup(self.conductor.stop_spv)
await self.conductor.start_wallet() await self.conductor.start_wallet()
self.addCleanup(self.conductor.stop_wallet) self.addCleanup(self.conductor.stop_wallet)
self.blockchain = self.conductor.lbcwallet_node self.blockchain = self.conductor.blockchain_node
self.wallet_node = self.conductor.wallet_node self.wallet_node = self.conductor.wallet_node
self.manager = self.wallet_node.manager self.manager = self.wallet_node.manager
self.ledger = self.wallet_node.ledger self.ledger = self.wallet_node.ledger
@ -267,13 +241,6 @@ class IntegrationTestCase(AsyncioTestCase):
def broadcast(self, tx): def broadcast(self, tx):
return self.ledger.broadcast(tx) return self.ledger.broadcast(tx)
async def broadcast_and_confirm(self, tx, ledger=None):
ledger = ledger or self.ledger
notifications = asyncio.create_task(ledger.wait(tx))
await ledger.broadcast(tx)
await notifications
await self.generate_and_wait(1, [tx.id], ledger)
async def on_header(self, height): async def on_header(self, height):
if self.ledger.headers.height < height: if self.ledger.headers.height < height:
await self.ledger.on_header.where( await self.ledger.on_header.where(
@ -281,29 +248,11 @@ class IntegrationTestCase(AsyncioTestCase):
) )
return True return True
async def send_to_address_and_wait(self, address, amount, blocks_to_generate=0, ledger=None): def on_transaction_id(self, txid, ledger=None):
tx_watch = [] return (ledger or self.ledger).on_transaction.where(
txid = None lambda e: e.tx.id == txid
done = False
watcher = (ledger or self.ledger).on_transaction.where(
lambda e: e.tx.id == txid or done or tx_watch.append(e.tx.id)
) )
txid = await self.blockchain.send_to_address(address, amount)
done = txid in tx_watch
await watcher
await self.generate_and_wait(blocks_to_generate, [txid], ledger)
return txid
async def generate_and_wait(self, blocks_to_generate, txids, ledger=None):
if blocks_to_generate > 0:
watcher = (ledger or self.ledger).on_transaction.where(
lambda e: ((e.tx.id in txids and txids.remove(e.tx.id)), len(txids) <= 0)[-1] # multi-statement lambda
)
await self.generate(blocks_to_generate)
await watcher
def on_address_update(self, address): def on_address_update(self, address):
return self.ledger.on_transaction.where( return self.ledger.on_transaction.where(
lambda e: e.address == address lambda e: e.address == address
@ -314,22 +263,6 @@ class IntegrationTestCase(AsyncioTestCase):
lambda e: e.tx.id == tx.id and e.address == address lambda e: e.tx.id == tx.id and e.address == address
) )
async def generate(self, blocks):
""" Ask lbrycrd to generate some blocks and wait until ledger has them. """
prepare = self.ledger.on_header.where(self.blockchain.is_expected_block)
self.conductor.spv_node.server.synchronized.clear()
await self.blockchain.generate(blocks)
height = self.blockchain.block_expected
await prepare # no guarantee that it didn't happen already, so start waiting from before calling generate
while True:
await self.conductor.spv_node.server.synchronized.wait()
self.conductor.spv_node.server.synchronized.clear()
if self.conductor.spv_node.server.db.db_height < height:
continue
if self.conductor.spv_node.server._es_height < height:
continue
break
class FakeExchangeRateManager(ExchangeRateManager): class FakeExchangeRateManager(ExchangeRateManager):
@ -348,8 +281,8 @@ class FakeExchangeRateManager(ExchangeRateManager):
def get_fake_exchange_rate_manager(rates=None): def get_fake_exchange_rate_manager(rates=None):
return FakeExchangeRateManager( return FakeExchangeRateManager(
[BittrexBTCFeed(), BittrexUSDFeed()], [LBRYFeed(), LBRYBTCFeed()],
rates or {'BTCLBC': 3.0, 'USDLBC': 2.0} rates or {'BTCLBC': 3.0, 'USDBTC': 2.0}
) )
@ -387,32 +320,26 @@ class CommandTestCase(IntegrationTestCase):
self.server_blob_manager = None self.server_blob_manager = None
self.server = None self.server = None
self.reflector = None self.reflector = None
self.skip_libtorrent = True
async def asyncSetUp(self): async def asyncSetUp(self):
await super().asyncSetUp()
logging.getLogger('lbry.blob_exchange').setLevel(self.VERBOSITY) logging.getLogger('lbry.blob_exchange').setLevel(self.VERBOSITY)
logging.getLogger('lbry.daemon').setLevel(self.VERBOSITY) logging.getLogger('lbry.daemon').setLevel(self.VERBOSITY)
logging.getLogger('lbry.stream').setLevel(self.VERBOSITY) logging.getLogger('lbry.stream').setLevel(self.VERBOSITY)
logging.getLogger('lbry.wallet').setLevel(self.VERBOSITY) logging.getLogger('lbry.wallet').setLevel(self.VERBOSITY)
await super().asyncSetUp()
self.daemon = await self.add_daemon(self.wallet_node) self.daemon = await self.add_daemon(self.wallet_node)
await self.account.ensure_address_gap() await self.account.ensure_address_gap()
address = (await self.account.receiving.get_addresses(limit=1, only_usable=True))[0] address = (await self.account.receiving.get_addresses(limit=1, only_usable=True))[0]
await self.send_to_address_and_wait(address, 10, 6) sendtxid = await self.blockchain.send_to_address(address, 10)
await self.confirm_tx(sendtxid)
await self.generate(5)
server_tmp_dir = tempfile.mkdtemp() server_tmp_dir = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, server_tmp_dir) self.addCleanup(shutil.rmtree, server_tmp_dir)
self.server_config = Config( self.server_config = Config()
data_dir=server_tmp_dir,
wallet_dir=server_tmp_dir,
save_files=True,
download_dir=server_tmp_dir
)
self.server_config.transaction_cache_size = 10000
self.server_storage = SQLiteStorage(self.server_config, ':memory:') self.server_storage = SQLiteStorage(self.server_config, ':memory:')
await self.server_storage.open() await self.server_storage.open()
@ -435,7 +362,6 @@ class CommandTestCase(IntegrationTestCase):
await daemon.stop() await daemon.stop()
async def add_daemon(self, wallet_node=None, seed=None): async def add_daemon(self, wallet_node=None, seed=None):
start_wallet_node = False
if wallet_node is None: if wallet_node is None:
wallet_node = WalletNode( wallet_node = WalletNode(
self.wallet_node.manager_class, self.wallet_node.manager_class,
@ -443,42 +369,30 @@ class CommandTestCase(IntegrationTestCase):
port=self.extra_wallet_node_port port=self.extra_wallet_node_port
) )
self.extra_wallet_node_port += 1 self.extra_wallet_node_port += 1
start_wallet_node = True await wallet_node.start(self.conductor.spv_node, seed=seed)
self.extra_wallet_nodes.append(wallet_node)
upload_dir = os.path.join(wallet_node.data_path, 'uploads') upload_dir = os.path.join(wallet_node.data_path, 'uploads')
os.mkdir(upload_dir) os.mkdir(upload_dir)
conf = Config( conf = Config()
# needed during instantiation to access known_hubs path conf.data_dir = wallet_node.data_path
data_dir=wallet_node.data_path, conf.wallet_dir = wallet_node.data_path
wallet_dir=wallet_node.data_path, conf.download_dir = wallet_node.data_path
save_files=True,
download_dir=wallet_node.data_path
)
conf.upload_dir = upload_dir # not a real conf setting conf.upload_dir = upload_dir # not a real conf setting
conf.share_usage_data = False conf.share_usage_data = False
conf.use_upnp = False conf.use_upnp = False
conf.reflect_streams = True conf.reflect_streams = True
conf.blockchain_name = 'lbrycrd_regtest' conf.blockchain_name = 'lbrycrd_regtest'
conf.lbryum_servers = [(self.conductor.spv_node.hostname, self.conductor.spv_node.port)] conf.lbryum_servers = [('127.0.0.1', 50001)]
conf.reflector_servers = [('127.0.0.1', 5566)] conf.reflector_servers = [('127.0.0.1', 5566)]
conf.fixed_peers = [('127.0.0.1', 5567)]
conf.known_dht_nodes = [] conf.known_dht_nodes = []
conf.blob_lru_cache_size = self.blob_lru_cache_size conf.blob_lru_cache_size = self.blob_lru_cache_size
conf.transaction_cache_size = 10000
conf.components_to_skip = [ conf.components_to_skip = [
DHT_COMPONENT, UPNP_COMPONENT, HASH_ANNOUNCER_COMPONENT, DHT_COMPONENT, UPNP_COMPONENT, HASH_ANNOUNCER_COMPONENT,
PEER_PROTOCOL_SERVER_COMPONENT PEER_PROTOCOL_SERVER_COMPONENT
] ]
if self.skip_libtorrent:
conf.components_to_skip.append(LIBTORRENT_COMPONENT)
if start_wallet_node:
await wallet_node.start(self.conductor.spv_node, seed=seed, config=conf)
self.extra_wallet_nodes.append(wallet_node)
else:
wallet_node.manager.config = conf wallet_node.manager.config = conf
wallet_node.manager.ledger.config['known_hubs'] = conf.known_hubs
def wallet_maker(component_manager): def wallet_maker(component_manager):
wallet_component = WalletComponent(component_manager) wallet_component = WalletComponent(component_manager)
@ -489,7 +403,7 @@ class CommandTestCase(IntegrationTestCase):
daemon = Daemon(conf, ComponentManager( daemon = Daemon(conf, ComponentManager(
conf, skip_components=conf.components_to_skip, wallet=wallet_maker, conf, skip_components=conf.components_to_skip, wallet=wallet_maker,
exchange_rate_manager=partial(ExchangeRateManagerComponent, rates={ exchange_rate_manager=partial(ExchangeRateManagerComponent, rates={
'BTCLBC': 1.0, 'USDLBC': 2.0 'BTCLBC': 1.0, 'USDBTC': 2.0
}) })
)) ))
await daemon.initialize() await daemon.initialize()
@ -499,14 +413,9 @@ class CommandTestCase(IntegrationTestCase):
async def confirm_tx(self, txid, ledger=None): async def confirm_tx(self, txid, ledger=None):
""" Wait for tx to be in mempool, then generate a block, wait for tx to be in a block. """ """ Wait for tx to be in mempool, then generate a block, wait for tx to be in a block. """
# await (ledger or self.ledger).on_transaction.where(lambda e: e.tx.id == txid) await self.on_transaction_id(txid, ledger)
on_tx = (ledger or self.ledger).on_transaction.where(lambda e: e.tx.id == txid) await self.generate(1)
await asyncio.wait([self.generate(1), on_tx], timeout=5) await self.on_transaction_id(txid, ledger)
# # actually, if it's in the mempool or in the block we're fine
# await self.generate_and_wait(1, [txid], ledger=ledger)
# return txid
return txid return txid
async def on_transaction_dict(self, tx): async def on_transaction_dict(self, tx):
@ -521,6 +430,11 @@ class CommandTestCase(IntegrationTestCase):
addresses.add(txo['address']) addresses.add(txo['address'])
return list(addresses) return list(addresses)
async def generate(self, blocks):
""" Ask lbrycrd to generate some blocks and wait until ledger has them. """
await self.blockchain.generate(blocks)
await self.ledger.on_header.where(self.blockchain.is_expected_block)
async def blockchain_claim_name(self, name: str, value: str, amount: str, confirm=True): async def blockchain_claim_name(self, name: str, value: str, amount: str, confirm=True):
txid = await self.blockchain._cli_cmnd('claimname', name, value, amount) txid = await self.blockchain._cli_cmnd('claimname', name, value, amount)
if confirm: if confirm:
@ -541,27 +455,12 @@ class CommandTestCase(IntegrationTestCase):
""" Synchronous version of `out` method. """ """ Synchronous version of `out` method. """
return json.loads(jsonrpc_dumps_pretty(value, ledger=self.ledger))['result'] return json.loads(jsonrpc_dumps_pretty(value, ledger=self.ledger))['result']
async def confirm_and_render(self, awaitable, confirm, return_tx=False) -> Transaction: async def confirm_and_render(self, awaitable, confirm) -> Transaction:
tx = await awaitable tx = await awaitable
if confirm: if confirm:
await self.ledger.wait(tx) await self.ledger.wait(tx)
await self.generate(1) await self.generate(1)
await self.ledger.wait(tx, self.blockchain.block_expected) await self.ledger.wait(tx, self.blockchain.block_expected)
if not return_tx:
return self.sout(tx)
return tx
async def create_nondeterministic_channel(self, name, price, pubkey_bytes, daemon=None, blocking=False):
account = (daemon or self.daemon).wallet_manager.default_account
claim_address = await account.receiving.get_or_create_usable_address()
claim = Claim()
claim.channel.public_key_bytes = pubkey_bytes
tx = await Transaction.claim_create(
name, claim, lbc_to_dewies(price),
claim_address, [self.account], self.account
)
await tx.sign([self.account])
await (daemon or self.daemon).broadcast_or_release(tx, blocking)
return self.sout(tx) return self.sout(tx)
def create_upload_file(self, data, prefix=None, suffix=None): def create_upload_file(self, data, prefix=None, suffix=None):
@ -573,26 +472,26 @@ class CommandTestCase(IntegrationTestCase):
async def stream_create( async def stream_create(
self, name='hovercraft', bid='1.0', file_path=None, self, name='hovercraft', bid='1.0', file_path=None,
data=b'hi!', confirm=True, prefix=None, suffix=None, return_tx=False, **kwargs): data=b'hi!', confirm=True, prefix=None, suffix=None, **kwargs):
if file_path is None and data is not None: if file_path is None:
file_path = self.create_upload_file(data=data, prefix=prefix, suffix=suffix) file_path = self.create_upload_file(data=data, prefix=prefix, suffix=suffix)
return await self.confirm_and_render( return await self.confirm_and_render(
self.daemon.jsonrpc_stream_create(name, bid, file_path=file_path, **kwargs), confirm, return_tx self.daemon.jsonrpc_stream_create(name, bid, file_path=file_path, **kwargs), confirm
) )
async def stream_update( async def stream_update(
self, claim_id, data=None, prefix=None, suffix=None, confirm=True, return_tx=False, **kwargs): self, claim_id, data=None, prefix=None, suffix=None, confirm=True, **kwargs):
if data is not None: if data is not None:
file_path = self.create_upload_file(data=data, prefix=prefix, suffix=suffix) file_path = self.create_upload_file(data=data, prefix=prefix, suffix=suffix)
return await self.confirm_and_render( return await self.confirm_and_render(
self.daemon.jsonrpc_stream_update(claim_id, file_path=file_path, **kwargs), confirm, return_tx self.daemon.jsonrpc_stream_update(claim_id, file_path=file_path, **kwargs), confirm
) )
return await self.confirm_and_render( return await self.confirm_and_render(
self.daemon.jsonrpc_stream_update(claim_id, **kwargs), confirm self.daemon.jsonrpc_stream_update(claim_id, **kwargs), confirm
) )
async def stream_repost(self, claim_id, name='repost', bid='1.0', confirm=True, **kwargs): def stream_repost(self, claim_id, name='repost', bid='1.0', confirm=True, **kwargs):
return await self.confirm_and_render( return self.confirm_and_render(
self.daemon.jsonrpc_stream_repost(claim_id=claim_id, name=name, bid=bid, **kwargs), confirm self.daemon.jsonrpc_stream_repost(claim_id=claim_id, name=name, bid=bid, **kwargs), confirm
) )
@ -603,11 +502,6 @@ class CommandTestCase(IntegrationTestCase):
self.daemon.jsonrpc_stream_abandon(*args, **kwargs), confirm self.daemon.jsonrpc_stream_abandon(*args, **kwargs), confirm
) )
async def purchase_create(self, *args, confirm=True, **kwargs):
return await self.confirm_and_render(
self.daemon.jsonrpc_purchase_create(*args, **kwargs), confirm
)
async def publish(self, name, *args, confirm=True, **kwargs): async def publish(self, name, *args, confirm=True, **kwargs):
return await self.confirm_and_render( return await self.confirm_and_render(
self.daemon.jsonrpc_publish(name, *args, **kwargs), confirm self.daemon.jsonrpc_publish(name, *args, **kwargs), confirm
@ -661,51 +555,18 @@ class CommandTestCase(IntegrationTestCase):
self.daemon.jsonrpc_support_abandon(*args, **kwargs), confirm self.daemon.jsonrpc_support_abandon(*args, **kwargs), confirm
) )
async def account_send(self, *args, confirm=True, **kwargs): async def resolve(self, uri):
return await self.confirm_and_render( return (await self.out(self.daemon.jsonrpc_resolve(uri)))[uri]
self.daemon.jsonrpc_account_send(*args, **kwargs), confirm
)
async def wallet_send(self, *args, confirm=True, **kwargs):
return await self.confirm_and_render(
self.daemon.jsonrpc_wallet_send(*args, **kwargs), confirm
)
async def txo_spend(self, *args, confirm=True, **kwargs):
txs = await self.daemon.jsonrpc_txo_spend(*args, **kwargs)
if confirm:
await asyncio.wait([self.ledger.wait(tx) for tx in txs])
await self.generate(1)
await asyncio.wait([self.ledger.wait(tx, self.blockchain.block_expected) for tx in txs])
return self.sout(txs)
async def blob_clean(self):
return await self.out(self.daemon.jsonrpc_blob_clean())
async def status(self):
return await self.out(self.daemon.jsonrpc_status())
async def resolve(self, uri, **kwargs):
return (await self.out(self.daemon.jsonrpc_resolve(uri, **kwargs)))[uri]
async def claim_search(self, **kwargs): async def claim_search(self, **kwargs):
return (await self.out(self.daemon.jsonrpc_claim_search(**kwargs)))['items'] return (await self.out(self.daemon.jsonrpc_claim_search(**kwargs)))['items']
async def get_claim_by_claim_id(self, claim_id):
return await self.out(self.ledger.get_claim_by_claim_id(claim_id))
async def file_list(self, *args, **kwargs): async def file_list(self, *args, **kwargs):
return (await self.out(self.daemon.jsonrpc_file_list(*args, **kwargs)))['items'] return (await self.out(self.daemon.jsonrpc_file_list(*args, **kwargs)))['items']
async def txo_list(self, *args, **kwargs): async def txo_list(self, *args, **kwargs):
return (await self.out(self.daemon.jsonrpc_txo_list(*args, **kwargs)))['items'] return (await self.out(self.daemon.jsonrpc_txo_list(*args, **kwargs)))['items']
async def txo_sum(self, *args, **kwargs):
return await self.out(self.daemon.jsonrpc_txo_sum(*args, **kwargs))
async def txo_plot(self, *args, **kwargs):
return await self.out(self.daemon.jsonrpc_txo_plot(*args, **kwargs))
async def claim_list(self, *args, **kwargs): async def claim_list(self, *args, **kwargs):
return (await self.out(self.daemon.jsonrpc_claim_list(*args, **kwargs)))['items'] return (await self.out(self.daemon.jsonrpc_claim_list(*args, **kwargs)))['items']
@ -718,9 +579,6 @@ class CommandTestCase(IntegrationTestCase):
async def transaction_list(self, *args, **kwargs): async def transaction_list(self, *args, **kwargs):
return (await self.out(self.daemon.jsonrpc_transaction_list(*args, **kwargs)))['items'] return (await self.out(self.daemon.jsonrpc_transaction_list(*args, **kwargs)))['items']
async def blob_list(self, *args, **kwargs):
return (await self.out(self.daemon.jsonrpc_blob_list(*args, **kwargs)))['items']
@staticmethod @staticmethod
def get_claim_id(tx): def get_claim_id(tx):
return tx['outputs'][0]['claim_id'] return tx['outputs'][0]['claim_id']

View file

@ -10,13 +10,47 @@ from typing import Optional
import libtorrent import libtorrent
NOTIFICATION_MASKS = [
"error",
"peer",
"port_mapping",
"storage",
"tracker",
"debug",
"status",
"progress",
"ip_block",
"dht",
"stats",
"session_log",
"torrent_log",
"peer_log",
"incoming_request",
"dht_log",
"dht_operation",
"port_mapping_log",
"picker_log",
"file_progress",
"piece_progress",
"upload",
"block_progress"
]
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
DEFAULT_FLAGS = ( # fixme: somehow the logic here is inverted? DEFAULT_FLAGS = ( # fixme: somehow the logic here is inverted?
libtorrent.add_torrent_params_flags_t.flag_auto_managed libtorrent.add_torrent_params_flags_t.flag_auto_managed
| libtorrent.add_torrent_params_flags_t.flag_update_subscribe | libtorrent.add_torrent_params_flags_t.flag_update_subscribe
) )
def get_notification_type(notification) -> str:
for i, notification_type in enumerate(NOTIFICATION_MASKS):
if (1 << i) & notification:
return notification_type
raise ValueError("unrecognized notification type")
class TorrentHandle: class TorrentHandle:
def __init__(self, loop, executor, handle): def __init__(self, loop, executor, handle):
self._loop = loop self._loop = loop
@ -87,7 +121,7 @@ class TorrentHandle:
self._show_status() self._show_status()
if self.finished.is_set(): if self.finished.is_set():
break break
await asyncio.sleep(0.1) await asyncio.sleep(0.1, loop=self._loop)
async def pause(self): async def pause(self):
await self._loop.run_in_executor( await self._loop.run_in_executor(
@ -122,8 +156,10 @@ class TorrentSession:
async def bind(self, interface: str = '0.0.0.0', port: int = 10889): async def bind(self, interface: str = '0.0.0.0', port: int = 10889):
settings = { settings = {
'listen_interfaces': f"{interface}:{port}", 'listen_interfaces': f"{interface}:{port}",
'enable_natpmp': False, 'enable_outgoing_utp': True,
'enable_upnp': False 'enable_incoming_utp': True,
'enable_outgoing_tcp': False,
'enable_incoming_tcp': False
} }
self._session = await self._loop.run_in_executor( self._session = await self._loop.run_in_executor(
self._executor, libtorrent.session, settings # pylint: disable=c-extension-no-member self._executor, libtorrent.session, settings # pylint: disable=c-extension-no-member
@ -150,7 +186,7 @@ class TorrentSession:
await self._loop.run_in_executor( await self._loop.run_in_executor(
self._executor, self._pop_alerts self._executor, self._pop_alerts
) )
await asyncio.sleep(1) await asyncio.sleep(1, loop=self._loop)
async def pause(self): async def pause(self):
await self._loop.run_in_executor( await self._loop.run_in_executor(

View file

@ -36,7 +36,7 @@ class Torrent:
def __init__(self, loop, handle): def __init__(self, loop, handle):
self._loop = loop self._loop = loop
self._handle = handle self._handle = handle
self.finished = asyncio.Event() self.finished = asyncio.Event(loop=loop)
def _threaded_update_status(self): def _threaded_update_status(self):
status = self._handle.status() status = self._handle.status()
@ -58,7 +58,7 @@ class Torrent:
log.info("finished downloading torrent!") log.info("finished downloading torrent!")
await self.pause() await self.pause()
break break
await asyncio.sleep(1) await asyncio.sleep(1, loop=self._loop)
async def pause(self): async def pause(self):
log.info("pause torrent") log.info("pause torrent")

View file

@ -74,7 +74,7 @@ class TorrentSource(ManagedDownloadSource):
def bt_infohash(self): def bt_infohash(self):
return self.identifier return self.identifier
async def stop_tasks(self): def stop_tasks(self):
pass pass
@property @property
@ -118,8 +118,8 @@ class TorrentManager(SourceManager):
async def start(self): async def start(self):
await super().start() await super().start()
async def stop(self): def stop(self):
await super().stop() super().stop()
log.info("finished stopping the torrent manager") log.info("finished stopping the torrent manager")
async def delete(self, source: ManagedDownloadSource, delete_file: Optional[bool] = False): async def delete(self, source: ManagedDownloadSource, delete_file: Optional[bool] = False):

View file

@ -1,285 +0,0 @@
import random
import socket
import string
import struct
import asyncio
import logging
import time
import ipaddress
from collections import namedtuple
from functools import reduce
from typing import Optional
from lbry.dht.node import get_kademlia_peers_from_hosts
from lbry.utils import resolve_host, async_timed_cache, cache_concurrent
from lbry.wallet.stream import StreamController
from lbry import version
log = logging.getLogger(__name__)
CONNECTION_EXPIRES_AFTER_SECONDS = 50
PREFIX = 'LB' # todo: PR BEP20 to add ourselves
DEFAULT_TIMEOUT_SECONDS = 10.0
DEFAULT_CONCURRENCY_LIMIT = 100
# see: http://bittorrent.org/beps/bep_0015.html and http://xbtt.sourceforge.net/udp_tracker_protocol.html
ConnectRequest = namedtuple("ConnectRequest", ["connection_id", "action", "transaction_id"])
ConnectResponse = namedtuple("ConnectResponse", ["action", "transaction_id", "connection_id"])
AnnounceRequest = namedtuple("AnnounceRequest",
["connection_id", "action", "transaction_id", "info_hash", "peer_id", "downloaded", "left",
"uploaded", "event", "ip_addr", "key", "num_want", "port"])
AnnounceResponse = namedtuple("AnnounceResponse",
["action", "transaction_id", "interval", "leechers", "seeders", "peers"])
CompactIPv4Peer = namedtuple("CompactPeer", ["address", "port"])
ScrapeRequest = namedtuple("ScrapeRequest", ["connection_id", "action", "transaction_id", "infohashes"])
ScrapeResponse = namedtuple("ScrapeResponse", ["action", "transaction_id", "items"])
ScrapeResponseItem = namedtuple("ScrapeResponseItem", ["seeders", "completed", "leechers"])
ErrorResponse = namedtuple("ErrorResponse", ["action", "transaction_id", "message"])
structs = {
ConnectRequest: struct.Struct(">QII"),
ConnectResponse: struct.Struct(">IIQ"),
AnnounceRequest: struct.Struct(">QII20s20sQQQIIIiH"),
AnnounceResponse: struct.Struct(">IIIII"),
CompactIPv4Peer: struct.Struct(">IH"),
ScrapeRequest: struct.Struct(">QII"),
ScrapeResponse: struct.Struct(">II"),
ScrapeResponseItem: struct.Struct(">III"),
ErrorResponse: struct.Struct(">II")
}
def decode(cls, data, offset=0):
decoder = structs[cls]
if cls is AnnounceResponse:
return AnnounceResponse(*decoder.unpack_from(data, offset),
peers=[decode(CompactIPv4Peer, data, index) for index in range(20, len(data), 6)])
elif cls is ScrapeResponse:
return ScrapeResponse(*decoder.unpack_from(data, offset),
items=[decode(ScrapeResponseItem, data, index) for index in range(8, len(data), 12)])
elif cls is ErrorResponse:
return ErrorResponse(*decoder.unpack_from(data, offset), data[decoder.size:])
return cls(*decoder.unpack_from(data, offset))
def encode(obj):
if isinstance(obj, ScrapeRequest):
return structs[ScrapeRequest].pack(*obj[:-1]) + b''.join(obj.infohashes)
elif isinstance(obj, ErrorResponse):
return structs[ErrorResponse].pack(*obj[:-1]) + obj.message
elif isinstance(obj, AnnounceResponse):
return structs[AnnounceResponse].pack(*obj[:-1]) + b''.join([encode(peer) for peer in obj.peers])
return structs[type(obj)].pack(*obj)
def make_peer_id(random_part: Optional[str] = None) -> bytes:
# see https://wiki.theory.org/BitTorrentSpecification#peer_id and https://www.bittorrent.org/beps/bep_0020.html
# not to confuse with node id; peer id identifies uniquely the software, version and instance
random_part = random_part or ''.join(random.choice(string.ascii_letters) for _ in range(20))
return f"{PREFIX}-{'-'.join(map(str, version))}-{random_part}"[:20].encode()
class UDPTrackerClientProtocol(asyncio.DatagramProtocol):
def __init__(self, timeout: float = DEFAULT_TIMEOUT_SECONDS):
self.transport = None
self.data_queue = {}
self.timeout = timeout
self.semaphore = asyncio.Semaphore(DEFAULT_CONCURRENCY_LIMIT)
def connection_made(self, transport: asyncio.DatagramTransport) -> None:
self.transport = transport
async def request(self, obj, tracker_ip, tracker_port):
self.data_queue[obj.transaction_id] = asyncio.get_running_loop().create_future()
try:
async with self.semaphore:
self.transport.sendto(encode(obj), (tracker_ip, tracker_port))
return await asyncio.wait_for(self.data_queue[obj.transaction_id], self.timeout)
finally:
self.data_queue.pop(obj.transaction_id, None)
async def connect(self, tracker_ip, tracker_port):
transaction_id = random.getrandbits(32)
return decode(ConnectResponse,
await self.request(ConnectRequest(0x41727101980, 0, transaction_id), tracker_ip, tracker_port))
@cache_concurrent
@async_timed_cache(CONNECTION_EXPIRES_AFTER_SECONDS)
async def ensure_connection_id(self, peer_id, tracker_ip, tracker_port):
# peer_id is just to ensure cache coherency
return (await self.connect(tracker_ip, tracker_port)).connection_id
async def announce(self, info_hash, peer_id, port, tracker_ip, tracker_port, stopped=False):
connection_id = await self.ensure_connection_id(peer_id, tracker_ip, tracker_port)
# this should make the key deterministic but unique per info hash + peer id
key = int.from_bytes(info_hash[:4], "big") ^ int.from_bytes(peer_id[:4], "big") ^ port
transaction_id = random.getrandbits(32)
req = AnnounceRequest(
connection_id, 1, transaction_id, info_hash, peer_id, 0, 0, 0, 3 if stopped else 1, 0, key, -1, port)
return decode(AnnounceResponse, await self.request(req, tracker_ip, tracker_port))
async def scrape(self, infohashes, tracker_ip, tracker_port, connection_id=None):
connection_id = await self.ensure_connection_id(None, tracker_ip, tracker_port)
transaction_id = random.getrandbits(32)
reply = await self.request(
ScrapeRequest(connection_id, 2, transaction_id, infohashes), tracker_ip, tracker_port)
return decode(ScrapeResponse, reply), connection_id
def datagram_received(self, data: bytes, addr: (str, int)) -> None:
if len(data) < 8:
return
transaction_id = int.from_bytes(data[4:8], byteorder="big", signed=False)
if transaction_id in self.data_queue:
if not self.data_queue[transaction_id].done():
if data[3] == 3:
return self.data_queue[transaction_id].set_exception(Exception(decode(ErrorResponse, data).message))
return self.data_queue[transaction_id].set_result(data)
log.debug("unexpected packet (can be a response for a previously timed out request): %s", data.hex())
def connection_lost(self, exc: Exception = None) -> None:
self.transport = None
class TrackerClient:
event_controller = StreamController()
def __init__(self, node_id, announce_port, get_servers, timeout=10.0):
self.client = UDPTrackerClientProtocol(timeout=timeout)
self.transport = None
self.peer_id = make_peer_id(node_id.hex() if node_id else None)
self.announce_port = announce_port
self._get_servers = get_servers
self.results = {} # we can't probe the server before the interval, so we keep the result here until it expires
self.tasks = {}
async def start(self):
self.transport, _ = await asyncio.get_running_loop().create_datagram_endpoint(
lambda: self.client, local_addr=("0.0.0.0", 0))
self.event_controller.stream.listen(
lambda request: self.on_hash(request[1], request[2]) if request[0] == 'search' else None)
def stop(self):
while self.tasks:
self.tasks.popitem()[1].cancel()
if self.transport is not None:
self.transport.close()
self.client = None
self.transport = None
self.event_controller.close()
def on_hash(self, info_hash, on_announcement=None):
if info_hash not in self.tasks:
task = asyncio.create_task(self.get_peer_list(info_hash, on_announcement=on_announcement))
task.add_done_callback(lambda *_: self.tasks.pop(info_hash, None))
self.tasks[info_hash] = task
async def announce_many(self, *info_hashes, stopped=False):
await asyncio.gather(
*[self._announce_many(server, info_hashes, stopped=stopped) for server in self._get_servers()],
return_exceptions=True)
async def _announce_many(self, server, info_hashes, stopped=False):
tracker_ip = await resolve_host(*server, 'udp')
still_good_info_hashes = {
info_hash for (info_hash, (next_announcement, _)) in self.results.get(tracker_ip, {}).items()
if time.time() < next_announcement
}
results = await asyncio.gather(
*[self._probe_server(info_hash, tracker_ip, server[1], stopped=stopped)
for info_hash in info_hashes if info_hash not in still_good_info_hashes],
return_exceptions=True)
if results:
errors = sum([1 for result in results if result is None or isinstance(result, Exception)])
log.info("Tracker: finished announcing %d files to %s:%d, %d errors", len(results), *server, errors)
async def get_peer_list(self, info_hash, stopped=False, on_announcement=None, no_port=False):
found = []
probes = [self._probe_server(info_hash, *server, stopped, no_port) for server in self._get_servers()]
for done in asyncio.as_completed(probes):
result = await done
if result is not None:
await asyncio.gather(*filter(asyncio.iscoroutine, [on_announcement(result)] if on_announcement else []))
found.append(result)
return found
async def get_kademlia_peer_list(self, info_hash):
responses = await self.get_peer_list(info_hash, no_port=True)
return await announcement_to_kademlia_peers(*responses)
async def _probe_server(self, info_hash, tracker_host, tracker_port, stopped=False, no_port=False):
result = None
try:
tracker_host = await resolve_host(tracker_host, tracker_port, 'udp')
except socket.error:
log.warning("DNS failure while resolving tracker host: %s, skipping.", tracker_host)
return
self.results.setdefault(tracker_host, {})
if info_hash in self.results[tracker_host]:
next_announcement, result = self.results[tracker_host][info_hash]
if time.time() < next_announcement:
return result
try:
result = await self.client.announce(
info_hash, self.peer_id, 0 if no_port else self.announce_port, tracker_host, tracker_port, stopped)
self.results[tracker_host][info_hash] = (time.time() + result.interval, result)
except asyncio.TimeoutError: # todo: this is UDP, timeout is common, we need a better metric for failures
self.results[tracker_host][info_hash] = (time.time() + 60.0, result)
log.debug("Tracker timed out: %s:%d", tracker_host, tracker_port)
return None
log.debug("Announced: %s found %d peers for %s", tracker_host, len(result.peers), info_hash.hex()[:8])
return result
def enqueue_tracker_search(info_hash: bytes, peer_q: asyncio.Queue):
async def on_announcement(announcement: AnnounceResponse):
peers = await announcement_to_kademlia_peers(announcement)
log.info("Found %d peers from tracker for %s", len(peers), info_hash.hex()[:8])
peer_q.put_nowait(peers)
TrackerClient.event_controller.add(('search', info_hash, on_announcement))
def announcement_to_kademlia_peers(*announcements: AnnounceResponse):
peers = [
(str(ipaddress.ip_address(peer.address)), peer.port)
for announcement in announcements for peer in announcement.peers if peer.port > 1024 # no privileged or 0
]
return get_kademlia_peers_from_hosts(peers)
class UDPTrackerServerProtocol(asyncio.DatagramProtocol): # for testing. Not suitable for production
def __init__(self):
self.transport = None
self.known_conns = set()
self.peers = {}
def connection_made(self, transport: asyncio.DatagramTransport) -> None:
self.transport = transport
def add_peer(self, info_hash, ip_address: str, port: int):
self.peers.setdefault(info_hash, [])
self.peers[info_hash].append(encode_peer(ip_address, port))
def datagram_received(self, data: bytes, addr: (str, int)) -> None:
if len(data) < 16:
return
action = int.from_bytes(data[8:12], "big", signed=False)
if action == 0:
req = decode(ConnectRequest, data)
connection_id = random.getrandbits(32)
self.known_conns.add(connection_id)
return self.transport.sendto(encode(ConnectResponse(0, req.transaction_id, connection_id)), addr)
elif action == 1:
req = decode(AnnounceRequest, data)
if req.connection_id not in self.known_conns:
resp = encode(ErrorResponse(3, req.transaction_id, b'Connection ID missmatch.\x00'))
else:
compact_address = encode_peer(addr[0], req.port)
if req.event != 3:
self.add_peer(req.info_hash, addr[0], req.port)
elif compact_address in self.peers.get(req.info_hash, []):
self.peers[req.info_hash].remove(compact_address)
peers = [decode(CompactIPv4Peer, peer) for peer in self.peers[req.info_hash]]
resp = encode(AnnounceResponse(1, req.transaction_id, 1700, 0, len(peers), peers))
return self.transport.sendto(resp, addr)
def encode_peer(ip_address: str, port: int):
compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), ip_address.split('.'), bytearray())
return compact_ip + port.to_bytes(2, "big", signed=False)

View file

@ -3,7 +3,6 @@ import codecs
import datetime import datetime
import random import random
import socket import socket
import time
import string import string
import sys import sys
import json import json
@ -20,10 +19,8 @@ import pkg_resources
import certifi import certifi
import aiohttp import aiohttp
from prometheus_client import Counter
from lbry.schema.claim import Claim from lbry.schema.claim import Claim
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -104,6 +101,10 @@ def check_connection(server="lbry.com", port=80, timeout=5) -> bool:
return False return False
async def async_check_connection(server="lbry.com", port=80, timeout=1) -> bool:
return await asyncio.get_event_loop().run_in_executor(None, check_connection, server, port, timeout)
def random_string(length=10, chars=string.ascii_lowercase): def random_string(length=10, chars=string.ascii_lowercase):
return ''.join([random.choice(chars) for _ in range(length)]) return ''.join([random.choice(chars) for _ in range(length)])
@ -130,16 +131,21 @@ def get_sd_hash(stream_info):
def json_dumps_pretty(obj, **kwargs): def json_dumps_pretty(obj, **kwargs):
return json.dumps(obj, sort_keys=True, indent=2, separators=(',', ': '), **kwargs) return json.dumps(obj, sort_keys=True, indent=2, separators=(',', ': '), **kwargs)
try:
# the standard contextlib.aclosing() is available in 3.10+ def cancel_task(task: typing.Optional[asyncio.Task]):
from contextlib import aclosing # pylint: disable=unused-import if task and not task.done():
except ImportError: task.cancel()
@contextlib.asynccontextmanager
async def aclosing(thing):
try: def cancel_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
yield thing for task in tasks:
finally: cancel_task(task)
await thing.aclose()
def drain_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
while tasks:
cancel_task(tasks.pop())
def async_timed_cache(duration: int): def async_timed_cache(duration: int):
def wrapper(func): def wrapper(func):
@ -150,7 +156,7 @@ def async_timed_cache(duration: int):
async def _inner(*args, **kwargs) -> typing.Any: async def _inner(*args, **kwargs) -> typing.Any:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
time_now = loop.time() time_now = loop.time()
key = (args, tuple(kwargs.items())) key = tuple([args, tuple([tuple([k, kwargs[k]]) for k in kwargs])])
if key in cache and (time_now - cache[key][1] < duration): if key in cache and (time_now - cache[key][1] < duration):
return cache[key][0] return cache[key][0]
to_cache = await func(*args, **kwargs) to_cache = await func(*args, **kwargs)
@ -168,7 +174,7 @@ def cache_concurrent(async_fn):
@functools.wraps(async_fn) @functools.wraps(async_fn)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
key = (args, tuple(kwargs.items())) key = tuple([args, tuple([tuple([k, kwargs[k]]) for k in kwargs])])
cache[key] = cache.get(key) or asyncio.create_task(async_fn(*args, **kwargs)) cache[key] = cache.get(key) or asyncio.create_task(async_fn(*args, **kwargs))
try: try:
return await cache[key] return await cache[key]
@ -182,8 +188,6 @@ def cache_concurrent(async_fn):
async def resolve_host(url: str, port: int, proto: str) -> str: async def resolve_host(url: str, port: int, proto: str) -> str:
if proto not in ['udp', 'tcp']: if proto not in ['udp', 'tcp']:
raise Exception("invalid protocol") raise Exception("invalid protocol")
if url.lower() == 'localhost':
return '127.0.0.1'
try: try:
if ipaddress.ip_address(url): if ipaddress.ip_address(url):
return url return url
@ -198,95 +202,18 @@ async def resolve_host(url: str, port: int, proto: str) -> str:
))[0][4][0] ))[0][4][0]
class LRUCacheWithMetrics:
__slots__ = [
'capacity',
'cache',
'_track_metrics',
'hits',
'misses'
]
def __init__(self, capacity: int, metric_name: typing.Optional[str] = None, namespace: str = "daemon_cache"):
self.capacity = capacity
self.cache = collections.OrderedDict()
if metric_name is None:
self._track_metrics = False
self.hits = self.misses = None
else:
self._track_metrics = True
try:
self.hits = Counter(
f"{metric_name}_cache_hit_count", "Number of cache hits", namespace=namespace
)
self.misses = Counter(
f"{metric_name}_cache_miss_count", "Number of cache misses", namespace=namespace
)
except ValueError as err:
log.debug("failed to set up prometheus %s_cache_miss_count metric: %s", metric_name, err)
self._track_metrics = False
self.hits = self.misses = None
def get(self, key, default=None):
try:
value = self.cache.pop(key)
if self._track_metrics:
self.hits.inc()
except KeyError:
if self._track_metrics:
self.misses.inc()
return default
self.cache[key] = value
return value
def set(self, key, value):
try:
self.cache.pop(key)
except KeyError:
if len(self.cache) >= self.capacity:
self.cache.popitem(last=False)
self.cache[key] = value
def clear(self):
self.cache.clear()
def pop(self, key):
return self.cache.pop(key)
def __setitem__(self, key, value):
return self.set(key, value)
def __getitem__(self, item):
return self.get(item)
def __contains__(self, item) -> bool:
return item in self.cache
def __len__(self):
return len(self.cache)
def __delitem__(self, key):
self.cache.pop(key)
def __del__(self):
self.clear()
class LRUCache: class LRUCache:
__slots__ = [ __slots__ = [
'capacity', 'capacity',
'cache' 'cache'
] ]
def __init__(self, capacity: int): def __init__(self, capacity):
self.capacity = capacity self.capacity = capacity
self.cache = collections.OrderedDict() self.cache = collections.OrderedDict()
def get(self, key, default=None): def get(self, key):
try:
value = self.cache.pop(key) value = self.cache.pop(key)
except KeyError:
return default
self.cache[key] = value self.cache[key] = value
return value return value
@ -298,46 +225,22 @@ class LRUCache:
self.cache.popitem(last=False) self.cache.popitem(last=False)
self.cache[key] = value self.cache[key] = value
def items(self):
return self.cache.items()
def clear(self):
self.cache.clear()
def pop(self, key, default=None):
return self.cache.pop(key, default)
def __setitem__(self, key, value):
return self.set(key, value)
def __getitem__(self, item):
return self.get(item)
def __contains__(self, item) -> bool: def __contains__(self, item) -> bool:
return item in self.cache return item in self.cache
def __len__(self):
return len(self.cache)
def __delitem__(self, key):
self.cache.pop(key)
def __del__(self):
self.clear()
def lru_cache_concurrent(cache_size: typing.Optional[int] = None, def lru_cache_concurrent(cache_size: typing.Optional[int] = None,
override_lru_cache: typing.Optional[LRUCacheWithMetrics] = None): override_lru_cache: typing.Optional[LRUCache] = None):
if not cache_size and override_lru_cache is None: if not cache_size and override_lru_cache is None:
raise ValueError("invalid cache size") raise ValueError("invalid cache size")
concurrent_cache = {} concurrent_cache = {}
lru_cache = override_lru_cache if override_lru_cache is not None else LRUCacheWithMetrics(cache_size) lru_cache = override_lru_cache or LRUCache(cache_size)
def wrapper(async_fn): def wrapper(async_fn):
@functools.wraps(async_fn) @functools.wraps(async_fn)
async def _inner(*args, **kwargs): async def _inner(*args, **kwargs):
key = (args, tuple(kwargs.items())) key = tuple([args, tuple([tuple([k, kwargs[k]]) for k in kwargs])])
if key in lru_cache: if key in lru_cache:
return lru_cache.get(key) return lru_cache.get(key)
@ -362,125 +265,20 @@ def get_ssl_context() -> ssl.SSLContext:
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def aiohttp_request(method, url, **kwargs) -> typing.AsyncContextManager[aiohttp.ClientResponse]: async def aiohttp_request(method, url, **kwargs) -> typing.AsyncContextManager[aiohttp.ClientResponse]:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.request(method, url, **kwargs) as response: async with session.request(method, url, ssl=get_ssl_context(), **kwargs) as response:
yield response yield response
# the ipaddress module does not show these subnets as reserved async def get_external_ip() -> typing.Optional[str]: # used if upnp is disabled or non-functioning
CARRIER_GRADE_NAT_SUBNET = ipaddress.ip_network('100.64.0.0/10')
IPV4_TO_6_RELAY_SUBNET = ipaddress.ip_network('192.88.99.0/24')
def is_valid_public_ipv4(address, allow_localhost: bool = False, allow_lan: bool = False):
try:
parsed_ip = ipaddress.ip_address(address)
if parsed_ip.is_loopback and allow_localhost:
return True
if allow_lan and parsed_ip.is_private:
return True
if any((parsed_ip.version != 4, parsed_ip.is_unspecified, parsed_ip.is_link_local, parsed_ip.is_loopback,
parsed_ip.is_multicast, parsed_ip.is_reserved, parsed_ip.is_private)):
return False
else:
return not any((CARRIER_GRADE_NAT_SUBNET.supernet_of(ipaddress.ip_network(f"{address}/32")),
IPV4_TO_6_RELAY_SUBNET.supernet_of(ipaddress.ip_network(f"{address}/32"))))
except (ipaddress.AddressValueError, ValueError):
return False
async def fallback_get_external_ip(): # used if spv servers can't be used for ip detection
try: try:
async with aiohttp_request("get", "https://api.lbry.com/ip") as resp: async with aiohttp_request("get", "https://api.lbry.com/ip") as resp:
response = await resp.json() response = await resp.json()
if response['success']: if response['success']:
return response['data']['ip'], None return response['data']['ip']
except Exception: except Exception:
return None, None return
async def _get_external_ip(default_servers) -> typing.Tuple[typing.Optional[str], typing.Optional[str]]:
# used if upnp is disabled or non-functioning
from lbry.wallet.udp import SPVStatusClientProtocol # pylint: disable=C0415
hostname_to_ip = {}
ip_to_hostnames = collections.defaultdict(list)
async def resolve_spv(server, port):
try:
server_addr = await resolve_host(server, port, 'udp')
hostname_to_ip[server] = (server_addr, port)
ip_to_hostnames[(server_addr, port)].append(server)
except Exception:
log.exception("error looking up dns for spv servers")
# accumulate the dns results
await asyncio.gather(*(resolve_spv(server, port) for (server, port) in default_servers))
loop = asyncio.get_event_loop()
pong_responses = asyncio.Queue()
connection = SPVStatusClientProtocol(pong_responses)
try:
await loop.create_datagram_endpoint(lambda: connection, ('0.0.0.0', 0))
# could raise OSError if it cant bind
randomized_servers = list(ip_to_hostnames.keys())
random.shuffle(randomized_servers)
for server in randomized_servers:
connection.ping(server)
try:
_, pong = await asyncio.wait_for(pong_responses.get(), 1)
if is_valid_public_ipv4(pong.ip_address):
return pong.ip_address, ip_to_hostnames[server][0]
except asyncio.TimeoutError:
pass
return None, None
finally:
connection.close()
async def get_external_ip(default_servers) -> typing.Tuple[typing.Optional[str], typing.Optional[str]]:
ip_from_spv_servers = await _get_external_ip(default_servers)
if not ip_from_spv_servers[1]:
return await fallback_get_external_ip()
return ip_from_spv_servers
def is_running_from_bundle(): def is_running_from_bundle():
# see https://pyinstaller.readthedocs.io/en/stable/runtime-information.html # see https://pyinstaller.readthedocs.io/en/stable/runtime-information.html
return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS') return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS')
class LockWithMetrics(asyncio.Lock):
def __init__(self, acquire_metric, held_time_metric):
super().__init__()
self._acquire_metric = acquire_metric
self._lock_held_time_metric = held_time_metric
self._lock_acquired_time = None
async def acquire(self):
start = time.perf_counter()
try:
return await super().acquire()
finally:
self._lock_acquired_time = time.perf_counter()
self._acquire_metric.observe(self._lock_acquired_time - start)
def release(self):
try:
return super().release()
finally:
self._lock_held_time_metric.observe(time.perf_counter() - self._lock_acquired_time)
def get_colliding_prefix_bits(first_value: bytes, second_value: bytes):
"""
Calculates the amount of colliding prefix bits between <first_value> and <second_value>.
This is given by the amount of bits that are the same until the first different one (via XOR),
starting from the most significant bit to the least significant bit.
:param first_value: first value to compare, bigger than size.
:param second_value: second value to compare, bigger than size.
:return: amount of prefix colliding bits.
"""
assert len(first_value) == len(second_value), "length should be the same"
size = len(first_value) * 8
first_value, second_value = int.from_bytes(first_value, "big"), int.from_bytes(second_value, "big")
return size - (first_value ^ second_value).bit_length()

View file

@ -1,23 +1,17 @@
__lbcd__ = 'lbcd' __node_daemon__ = 'lbrycrdd'
__lbcctl__ = 'lbcctl' __node_cli__ = 'lbrycrd-cli'
__lbcwallet__ = 'lbcwallet' __node_bin__ = ''
__lbcd_url__ = ( __node_url__ = (
'https://github.com/lbryio/lbcd/releases/download/' + 'https://github.com/lbryio/lbrycrd/releases/download/v0.17.4.3/lbrycrd-linux-1743.zip'
'v0.22.100-rc.0/lbcd_0.22.100-rc.0_TARGET_PLATFORM.tar.gz'
)
__lbcwallet_url__ = (
'https://github.com/lbryio/lbcwallet/releases/download/' +
'v0.13.100-alpha.0/lbcwallet_0.13.100-alpha.0_TARGET_PLATFORM.tar.gz'
) )
__spvserver__ = 'lbry.wallet.server.coin.LBCRegTest' __spvserver__ = 'lbry.wallet.server.coin.LBCRegTest'
from lbry.wallet.wallet import Wallet, WalletStorage, TimestampedPreferences, ENCRYPT_ON_DISK from .wallet import Wallet, WalletStorage, TimestampedPreferences, ENCRYPT_ON_DISK
from lbry.wallet.manager import WalletManager from .manager import WalletManager
from lbry.wallet.network import Network from .network import Network
from lbry.wallet.ledger import Ledger, RegTestLedger, TestNetLedger, BlockHeightEvent from .ledger import Ledger, RegTestLedger, TestNetLedger, BlockHeightEvent
from lbry.wallet.account import Account, AddressManager, SingleKey, HierarchicalDeterministic, \ from .account import Account, AddressManager, SingleKey, HierarchicalDeterministic
DeterministicChannelKeyManager from .transaction import Transaction, Output, Input
from lbry.wallet.transaction import Transaction, Output, Input from .script import OutputScript, InputScript
from lbry.wallet.script import OutputScript, InputScript from .database import SQLiteMixin, Database
from lbry.wallet.database import SQLiteMixin, Database from .header import Headers
from lbry.wallet.header import Headers

View file

@ -5,16 +5,18 @@ import logging
import typing import typing
import asyncio import asyncio
import random import random
from functools import partial
from hashlib import sha256 from hashlib import sha256
from string import hexdigits from string import hexdigits
from typing import Type, Dict, Tuple, Optional, Any, List from typing import Type, Dict, Tuple, Optional, Any, List
import ecdsa
from lbry.error import InvalidPasswordError from lbry.error import InvalidPasswordError
from lbry.crypto.crypt import aes_encrypt, aes_decrypt from lbry.crypto.crypt import aes_encrypt, aes_decrypt
from .bip32 import PrivateKey, PublicKey, KeyPath, from_extended_key_string from .bip32 import PrivateKey, PubKey, from_extended_key_string
from .mnemonic import Mnemonic from .mnemonic import Mnemonic
from .constants import COIN, TXO_TYPES from .constants import COIN, CLAIM_TYPES, TXO_TYPES
from .transaction import Transaction, Input, Output from .transaction import Transaction, Input, Output
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
@ -33,49 +35,6 @@ def validate_claim_id(claim_id):
raise Exception("Claim id is not hex encoded") raise Exception("Claim id is not hex encoded")
class DeterministicChannelKeyManager:
def __init__(self, account: 'Account'):
self.account = account
self.last_known = 0
self.cache = {}
self._private_key: Optional[PrivateKey] = None
@property
def private_key(self):
if self._private_key is None:
if self.account.private_key is not None:
self._private_key = self.account.private_key.child(KeyPath.CHANNEL)
return self._private_key
def maybe_generate_deterministic_key_for_channel(self, txo):
if self.private_key is None:
return
next_private_key = self.private_key.child(self.last_known)
public_key = next_private_key.public_key
public_key_bytes = public_key.pubkey_bytes
if txo.claim.channel.public_key_bytes == public_key_bytes:
self.cache[public_key.address] = next_private_key
self.last_known += 1
async def ensure_cache_primed(self):
if self.private_key is not None:
await self.generate_next_key()
async def generate_next_key(self) -> PrivateKey:
db = self.account.ledger.db
while True:
next_private_key = self.private_key.child(self.last_known)
public_key = next_private_key.public_key
self.cache[public_key.address] = next_private_key
if not await db.is_channel_key_used(self.account, public_key):
return next_private_key
self.last_known += 1
def get_private_key_from_pubkey_hash(self, pubkey_hash) -> PrivateKey:
return self.cache.get(pubkey_hash)
class AddressManager: class AddressManager:
name: str name: str
@ -112,7 +71,6 @@ class AddressManager:
def _query_addresses(self, **constraints): def _query_addresses(self, **constraints):
return self.account.ledger.db.get_addresses( return self.account.ledger.db.get_addresses(
read_only=constraints.pop("read_only", False),
accounts=[self.account], accounts=[self.account],
chain=self.chain_number, chain=self.chain_number,
**constraints **constraints
@ -121,7 +79,7 @@ class AddressManager:
def get_private_key(self, index: int) -> PrivateKey: def get_private_key(self, index: int) -> PrivateKey:
raise NotImplementedError raise NotImplementedError
def get_public_key(self, index: int) -> PublicKey: def get_public_key(self, index: int) -> PubKey:
raise NotImplementedError raise NotImplementedError
async def get_max_gap(self): async def get_max_gap(self):
@ -138,7 +96,6 @@ class AddressManager:
return [r['address'] for r in records] return [r['address'] for r in records]
async def get_or_create_usable_address(self) -> str: async def get_or_create_usable_address(self) -> str:
async with self.address_generator_lock:
addresses = await self.get_addresses(only_usable=True, limit=10) addresses = await self.get_addresses(only_usable=True, limit=10)
if addresses: if addresses:
return random.choice(addresses) return random.choice(addresses)
@ -161,8 +118,8 @@ class HierarchicalDeterministic(AddressManager):
@classmethod @classmethod
def from_dict(cls, account: 'Account', d: dict) -> Tuple[AddressManager, AddressManager]: def from_dict(cls, account: 'Account', d: dict) -> Tuple[AddressManager, AddressManager]:
return ( return (
cls(account, KeyPath.RECEIVE, **d.get('receiving', {'gap': 20, 'maximum_uses_per_address': 1})), cls(account, 0, **d.get('receiving', {'gap': 20, 'maximum_uses_per_address': 1})),
cls(account, KeyPath.CHANGE, **d.get('change', {'gap': 6, 'maximum_uses_per_address': 1})) cls(account, 1, **d.get('change', {'gap': 6, 'maximum_uses_per_address': 1}))
) )
def merge(self, d: dict): def merge(self, d: dict):
@ -175,7 +132,7 @@ class HierarchicalDeterministic(AddressManager):
def get_private_key(self, index: int) -> PrivateKey: def get_private_key(self, index: int) -> PrivateKey:
return self.account.private_key.child(self.chain_number).child(index) return self.account.private_key.child(self.chain_number).child(index)
def get_public_key(self, index: int) -> PublicKey: def get_public_key(self, index: int) -> PubKey:
return self.account.public_key.child(self.chain_number).child(index) return self.account.public_key.child(self.chain_number).child(index)
async def get_max_gap(self) -> int: async def get_max_gap(self) -> int:
@ -235,7 +192,7 @@ class SingleKey(AddressManager):
@classmethod @classmethod
def from_dict(cls, account: 'Account', d: dict) \ def from_dict(cls, account: 'Account', d: dict) \
-> Tuple[AddressManager, AddressManager]: -> Tuple[AddressManager, AddressManager]:
same_address_manager = cls(account, account.public_key, KeyPath.RECEIVE) same_address_manager = cls(account, account.public_key, 0)
return same_address_manager, same_address_manager return same_address_manager, same_address_manager
def to_dict_instance(self): def to_dict_instance(self):
@ -244,7 +201,7 @@ class SingleKey(AddressManager):
def get_private_key(self, index: int) -> PrivateKey: def get_private_key(self, index: int) -> PrivateKey:
return self.account.private_key return self.account.private_key
def get_public_key(self, index: int) -> PublicKey: def get_public_key(self, index: int) -> PubKey:
return self.account.public_key return self.account.public_key
async def get_max_gap(self) -> int: async def get_max_gap(self) -> int:
@ -266,6 +223,9 @@ class SingleKey(AddressManager):
class Account: class Account:
mnemonic_class = Mnemonic
private_key_class = PrivateKey
public_key_class = PubKey
address_generators: Dict[str, Type[AddressManager]] = { address_generators: Dict[str, Type[AddressManager]] = {
SingleKey.name: SingleKey, SingleKey.name: SingleKey,
HierarchicalDeterministic.name: HierarchicalDeterministic, HierarchicalDeterministic.name: HierarchicalDeterministic,
@ -273,7 +233,7 @@ class Account:
def __init__(self, ledger: 'Ledger', wallet: 'Wallet', name: str, def __init__(self, ledger: 'Ledger', wallet: 'Wallet', name: str,
seed: str, private_key_string: str, encrypted: bool, seed: str, private_key_string: str, encrypted: bool,
private_key: Optional[PrivateKey], public_key: PublicKey, private_key: Optional[PrivateKey], public_key: PubKey,
address_generator: dict, modified_on: float, channel_keys: dict) -> None: address_generator: dict, modified_on: float, channel_keys: dict) -> None:
self.ledger = ledger self.ledger = ledger
self.wallet = wallet self.wallet = wallet
@ -284,14 +244,13 @@ class Account:
self.private_key_string = private_key_string self.private_key_string = private_key_string
self.init_vectors: Dict[str, bytes] = {} self.init_vectors: Dict[str, bytes] = {}
self.encrypted = encrypted self.encrypted = encrypted
self.private_key: Optional[PrivateKey] = private_key self.private_key = private_key
self.public_key: PublicKey = public_key self.public_key = public_key
generator_name = address_generator.get('name', HierarchicalDeterministic.name) generator_name = address_generator.get('name', HierarchicalDeterministic.name)
self.address_generator = self.address_generators[generator_name] self.address_generator = self.address_generators[generator_name]
self.receiving, self.change = self.address_generator.from_dict(self, address_generator) self.receiving, self.change = self.address_generator.from_dict(self, address_generator)
self.address_managers = {am.chain_number: am for am in (self.receiving, self.change)} self.address_managers = {am.chain_number: am for am in {self.receiving, self.change}}
self.channel_keys = channel_keys self.channel_keys = channel_keys
self.deterministic_channel_keys = DeterministicChannelKeyManager(self)
ledger.add_account(self) ledger.add_account(self)
wallet.add_account(self) wallet.add_account(self)
@ -306,19 +265,19 @@ class Account:
name: str = None, address_generator: dict = None): name: str = None, address_generator: dict = None):
return cls.from_dict(ledger, wallet, { return cls.from_dict(ledger, wallet, {
'name': name, 'name': name,
'seed': Mnemonic().make_seed(), 'seed': cls.mnemonic_class().make_seed(),
'address_generator': address_generator or {} 'address_generator': address_generator or {}
}) })
@classmethod @classmethod
def get_private_key_from_seed(cls, ledger: 'Ledger', seed: str, password: str): def get_private_key_from_seed(cls, ledger: 'Ledger', seed: str, password: str):
return PrivateKey.from_seed( return cls.private_key_class.from_seed(
ledger, Mnemonic.mnemonic_to_seed(seed, password or 'lbryum') ledger, cls.mnemonic_class.mnemonic_to_seed(seed, password or 'lbryum')
) )
@classmethod @classmethod
def keys_from_dict(cls, ledger: 'Ledger', d: dict) \ def keys_from_dict(cls, ledger: 'Ledger', d: dict) \
-> Tuple[str, Optional[PrivateKey], PublicKey]: -> Tuple[str, Optional[PrivateKey], PubKey]:
seed = d.get('seed', '') seed = d.get('seed', '')
private_key_string = d.get('private_key', '') private_key_string = d.get('private_key', '')
private_key = None private_key = None
@ -351,7 +310,7 @@ class Account:
private_key=private_key, private_key=private_key,
public_key=public_key, public_key=public_key,
address_generator=d.get('address_generator', {}), address_generator=d.get('address_generator', {}),
modified_on=int(d.get('modified_on', time.time())), modified_on=d.get('modified_on', time.time()),
channel_keys=d.get('certificates', {}) channel_keys=d.get('certificates', {})
) )
@ -383,7 +342,7 @@ class Account:
def merge(self, d: dict): def merge(self, d: dict):
if d.get('modified_on', 0) > self.modified_on: if d.get('modified_on', 0) > self.modified_on:
self.name = d['name'] self.name = d['name']
self.modified_on = int(d.get('modified_on', time.time())) self.modified_on = d.get('modified_on', time.time())
assert self.address_generator.name == d['address_generator']['name'] assert self.address_generator.name == d['address_generator']['name']
for chain_name in ('change', 'receiving'): for chain_name in ('change', 'receiving'):
if chain_name in d['address_generator']: if chain_name in d['address_generator']:
@ -475,9 +434,9 @@ class Account:
addresses.extend(new_addresses) addresses.extend(new_addresses)
return addresses return addresses
async def get_addresses(self, read_only=False, **constraints) -> List[str]: async def get_addresses(self, **constraints) -> List[str]:
rows = await self.ledger.db.select_addresses('address', read_only=read_only, accounts=[self], **constraints) rows = await self.ledger.db.select_addresses('address', accounts=[self], **constraints)
return [r['address'] for r in rows] return [r[0] for r in rows]
def get_address_records(self, **constraints): def get_address_records(self, **constraints):
return self.ledger.db.get_addresses(accounts=[self], **constraints) return self.ledger.db.get_addresses(accounts=[self], **constraints)
@ -489,16 +448,16 @@ class Account:
assert not self.encrypted, "Cannot get private key on encrypted wallet account." assert not self.encrypted, "Cannot get private key on encrypted wallet account."
return self.address_managers[chain].get_private_key(index) return self.address_managers[chain].get_private_key(index)
def get_public_key(self, chain: int, index: int) -> PublicKey: def get_public_key(self, chain: int, index: int) -> PubKey:
return self.address_managers[chain].get_public_key(index) return self.address_managers[chain].get_public_key(index)
def get_balance(self, confirmations=0, include_claims=False, read_only=False, **constraints): def get_balance(self, confirmations: int = 0, include_claims=False, **constraints):
if not include_claims: if not include_claims:
constraints.update({'txo_type__in': (TXO_TYPES['other'], TXO_TYPES['purchase'])}) constraints.update({'txo_type__in': (0, TXO_TYPES['purchase'])})
if confirmations > 0: if confirmations > 0:
height = self.ledger.headers.height - (confirmations-1) height = self.ledger.headers.height - (confirmations-1)
constraints.update({'height__lte': height, 'height__gt': 0}) constraints.update({'height__lte': height, 'height__gt': 0})
return self.ledger.db.get_balance(accounts=[self], read_only=read_only, **constraints) return self.ledger.db.get_balance(accounts=[self], **constraints)
async def get_max_gap(self): async def get_max_gap(self):
change_gap = await self.change.get_max_gap() change_gap = await self.change.get_max_gap()
@ -560,18 +519,16 @@ class Account:
return tx return tx
async def generate_channel_private_key(self): def add_channel_private_key(self, private_key):
return await self.deterministic_channel_keys.generate_next_key() public_key_bytes = private_key.get_verifying_key().to_der()
channel_pubkey_hash = self.ledger.public_key_to_address(public_key_bytes)
self.channel_keys[channel_pubkey_hash] = private_key.to_pem().decode()
def add_channel_private_key(self, private_key: PrivateKey): def get_channel_private_key(self, public_key_bytes):
self.channel_keys[private_key.address] = private_key.to_pem().decode()
async def get_channel_private_key(self, public_key_bytes) -> PrivateKey:
channel_pubkey_hash = self.ledger.public_key_to_address(public_key_bytes) channel_pubkey_hash = self.ledger.public_key_to_address(public_key_bytes)
private_key_pem = self.channel_keys.get(channel_pubkey_hash) private_key_pem = self.channel_keys.get(channel_pubkey_hash)
if private_key_pem: if private_key_pem:
return PrivateKey.from_pem(self.ledger, private_key_pem) return ecdsa.SigningKey.from_pem(private_key_pem, hashfunc=sha256)
return self.deterministic_channel_keys.get_private_key_from_pubkey_hash(channel_pubkey_hash)
async def maybe_migrate_certificates(self): async def maybe_migrate_certificates(self):
if not self.channel_keys: if not self.channel_keys:
@ -580,10 +537,11 @@ class Account:
for private_key_pem in self.channel_keys.values(): for private_key_pem in self.channel_keys.values():
if not isinstance(private_key_pem, str): if not isinstance(private_key_pem, str):
continue continue
if not private_key_pem.startswith("-----BEGIN"): if "-----BEGIN EC PRIVATE KEY-----" not in private_key_pem:
continue continue
private_key = PrivateKey.from_pem(self.ledger, private_key_pem) private_key = ecdsa.SigningKey.from_pem(private_key_pem, hashfunc=sha256)
channel_keys[private_key.address] = private_key_pem public_key_der = private_key.get_verifying_key().to_der()
channel_keys[self.ledger.public_key_to_address(public_key_der)] = private_key_pem
if self.channel_keys != channel_keys: if self.channel_keys != channel_keys:
self.channel_keys = channel_keys self.channel_keys = channel_keys
self.wallet.save() self.wallet.save()
@ -603,24 +561,41 @@ class Account:
if gap_changed: if gap_changed:
self.wallet.save() self.wallet.save()
async def get_detailed_balance(self, confirmations=0, read_only=False): async def get_detailed_balance(self, confirmations=0, reserved_subtotals=False):
constraints = {} tips_balance, supports_balance, claims_balance = 0, 0, 0
if confirmations > 0: get_total_balance = partial(self.get_balance, confirmations=confirmations, include_claims=True)
height = self.ledger.headers.height - (confirmations-1) total = await get_total_balance()
constraints.update({'height__lte': height, 'height__gt': 0}) if reserved_subtotals:
return await self.ledger.db.get_detailed_balance( claims_balance = await get_total_balance(txo_type__in=CLAIM_TYPES)
accounts=[self], read_only=read_only, **constraints for amount, spent, from_me, to_me, height in await self.get_support_summary():
if confirmations > 0 and not 0 < height <= self.ledger.headers.height - (confirmations - 1):
continue
if not spent and to_me:
if from_me:
supports_balance += amount
else:
tips_balance += amount
reserved = claims_balance + supports_balance + tips_balance
else:
reserved = await self.get_balance(
confirmations=confirmations, include_claims=True, txo_type__gt=0
) )
return {
'total': total,
'available': total - reserved,
'reserved': reserved,
'reserved_subtotals': {
'claims': claims_balance,
'supports': supports_balance,
'tips': tips_balance
} if reserved_subtotals else None
}
def get_transaction_history(self, read_only=False, **constraints): def get_transaction_history(self, **constraints):
return self.ledger.get_transaction_history( return self.ledger.get_transaction_history(wallet=self.wallet, accounts=[self], **constraints)
read_only=read_only, wallet=self.wallet, accounts=[self], **constraints
)
def get_transaction_history_count(self, read_only=False, **constraints): def get_transaction_history_count(self, **constraints):
return self.ledger.get_transaction_history_count( return self.ledger.get_transaction_history_count(wallet=self.wallet, accounts=[self], **constraints)
read_only=read_only, wallet=self.wallet, accounts=[self], **constraints
)
def get_claims(self, **constraints): def get_claims(self, **constraints):
return self.ledger.get_claims(wallet=self.wallet, accounts=[self], **constraints) return self.ledger.get_claims(wallet=self.wallet, accounts=[self], **constraints)
@ -653,7 +628,7 @@ class Account:
return self.ledger.get_support_count(wallet=self.wallet, accounts=[self], **constraints) return self.ledger.get_support_count(wallet=self.wallet, accounts=[self], **constraints)
def get_support_summary(self): def get_support_summary(self):
return self.ledger.db.get_supports_summary(wallet=self.wallet, accounts=[self]) return self.ledger.db.get_supports_summary(account_id=self.id)
async def release_all_outputs(self): async def release_all_outputs(self):
await self.ledger.db.release_all_outputs(self) await self.ledger.db.release_all_outputs(self)

View file

@ -1,21 +1,10 @@
from asn1crypto.keys import PrivateKeyInfo, ECPrivateKey from coincurve import PublicKey, PrivateKey as _PrivateKey
from coincurve import PublicKey as cPublicKey, PrivateKey as cPrivateKey
from coincurve.utils import (
pem_to_der, lib as libsecp256k1, ffi as libsecp256k1_ffi
)
from coincurve.ecdsa import CDATA_SIG_LENGTH
from lbry.crypto.hash import hmac_sha512, hash160, double_sha256 from lbry.crypto.hash import hmac_sha512, hash160, double_sha256
from lbry.crypto.base58 import Base58 from lbry.crypto.base58 import Base58
from .util import cachedproperty from .util import cachedproperty
class KeyPath:
RECEIVE = 0
CHANGE = 1
CHANNEL = 2
class DerivationError(Exception): class DerivationError(Exception):
""" Raised when an invalid derivation occurs. """ """ Raised when an invalid derivation occurs. """
@ -57,11 +46,9 @@ class _KeyBase:
if len(raw_serkey) != 33: if len(raw_serkey) != 33:
raise ValueError('raw_serkey must have length 33') raise ValueError('raw_serkey must have length 33')
return ( return (ver_bytes + bytes((self.depth,))
ver_bytes + bytes((self.depth,))
+ self.parent_fingerprint() + self.n.to_bytes(4, 'big') + self.parent_fingerprint() + self.n.to_bytes(4, 'big')
+ self.chain_code + raw_serkey + self.chain_code + raw_serkey)
)
def identifier(self): def identifier(self):
raise NotImplementedError raise NotImplementedError
@ -82,30 +69,26 @@ class _KeyBase:
return Base58.encode_check(self.extended_key()) return Base58.encode_check(self.extended_key())
class PublicKey(_KeyBase): class PubKey(_KeyBase):
""" A BIP32 public key. """ """ A BIP32 public key. """
def __init__(self, ledger, pubkey, chain_code, n, depth, parent=None): def __init__(self, ledger, pubkey, chain_code, n, depth, parent=None):
super().__init__(ledger, chain_code, n, depth, parent) super().__init__(ledger, chain_code, n, depth, parent)
if isinstance(pubkey, cPublicKey): if isinstance(pubkey, PublicKey):
self.verifying_key = pubkey self.verifying_key = pubkey
else: else:
self.verifying_key = self._verifying_key_from_pubkey(pubkey) self.verifying_key = self._verifying_key_from_pubkey(pubkey)
@classmethod
def from_compressed(cls, public_key_bytes, ledger=None) -> 'PublicKey':
return cls(ledger, public_key_bytes, bytes((0,)*32), 0, 0)
@classmethod @classmethod
def _verifying_key_from_pubkey(cls, pubkey): def _verifying_key_from_pubkey(cls, pubkey):
""" Converts a 33-byte compressed pubkey into an coincurve.PublicKey object. """ """ Converts a 33-byte compressed pubkey into an PublicKey object. """
if not isinstance(pubkey, (bytes, bytearray)): if not isinstance(pubkey, (bytes, bytearray)):
raise TypeError('pubkey must be raw bytes') raise TypeError('pubkey must be raw bytes')
if len(pubkey) != 33: if len(pubkey) != 33:
raise ValueError('pubkey must be 33 bytes') raise ValueError('pubkey must be 33 bytes')
if pubkey[0] not in (2, 3): if pubkey[0] not in (2, 3):
raise ValueError('invalid pubkey prefix byte') raise ValueError('invalid pubkey prefix byte')
return cPublicKey(pubkey) return PublicKey(pubkey)
@cachedproperty @cachedproperty
def pubkey_bytes(self): def pubkey_bytes(self):
@ -120,7 +103,7 @@ class PublicKey(_KeyBase):
def ec_point(self): def ec_point(self):
return self.verifying_key.point() return self.verifying_key.point()
def child(self, n: int) -> 'PublicKey': def child(self, n: int):
""" Return the derived child extended pubkey at index N. """ """ Return the derived child extended pubkey at index N. """
if not 0 <= n < (1 << 31): if not 0 <= n < (1 << 31):
raise ValueError('invalid BIP32 public key child number') raise ValueError('invalid BIP32 public key child number')
@ -128,7 +111,7 @@ class PublicKey(_KeyBase):
msg = self.pubkey_bytes + n.to_bytes(4, 'big') msg = self.pubkey_bytes + n.to_bytes(4, 'big')
L_b, R_b = self._hmac_sha512(msg) # pylint: disable=invalid-name L_b, R_b = self._hmac_sha512(msg) # pylint: disable=invalid-name
derived_key = self.verifying_key.add(L_b) derived_key = self.verifying_key.add(L_b)
return PublicKey(self.ledger, derived_key, R_b, n, self.depth + 1, self) return PubKey(self.ledger, derived_key, R_b, n, self.depth + 1, self)
def identifier(self): def identifier(self):
""" Return the key's identifier as 20 bytes. """ """ Return the key's identifier as 20 bytes. """
@ -141,36 +124,6 @@ class PublicKey(_KeyBase):
self.pubkey_bytes self.pubkey_bytes
) )
def verify(self, signature, digest) -> bool:
""" Verify that a signature is valid for a 32 byte digest. """
if len(signature) != 64:
raise ValueError('Signature must be 64 bytes long.')
if len(digest) != 32:
raise ValueError('Digest must be 32 bytes long.')
key = self.verifying_key
raw_signature = libsecp256k1_ffi.new('secp256k1_ecdsa_signature *')
parsed = libsecp256k1.secp256k1_ecdsa_signature_parse_compact(
key.context.ctx, raw_signature, signature
)
assert parsed == 1
normalized_signature = libsecp256k1_ffi.new('secp256k1_ecdsa_signature *')
libsecp256k1.secp256k1_ecdsa_signature_normalize(
key.context.ctx, normalized_signature, raw_signature
)
verified = libsecp256k1.secp256k1_ecdsa_verify(
key.context.ctx, normalized_signature, digest, key.public_key
)
return bool(verified)
class PrivateKey(_KeyBase): class PrivateKey(_KeyBase):
"""A BIP32 private key.""" """A BIP32 private key."""
@ -179,7 +132,7 @@ class PrivateKey(_KeyBase):
def __init__(self, ledger, privkey, chain_code, n, depth, parent=None): def __init__(self, ledger, privkey, chain_code, n, depth, parent=None):
super().__init__(ledger, chain_code, n, depth, parent) super().__init__(ledger, chain_code, n, depth, parent)
if isinstance(privkey, cPrivateKey): if isinstance(privkey, _PrivateKey):
self.signing_key = privkey self.signing_key = privkey
else: else:
self.signing_key = self._signing_key_from_privkey(privkey) self.signing_key = self._signing_key_from_privkey(privkey)
@ -187,7 +140,7 @@ class PrivateKey(_KeyBase):
@classmethod @classmethod
def _signing_key_from_privkey(cls, private_key): def _signing_key_from_privkey(cls, private_key):
""" Converts a 32-byte private key into an coincurve.PrivateKey object. """ """ Converts a 32-byte private key into an coincurve.PrivateKey object. """
return cPrivateKey.from_int(PrivateKey._private_key_secret_exponent(private_key)) return _PrivateKey.from_int(PrivateKey._private_key_secret_exponent(private_key))
@classmethod @classmethod
def _private_key_secret_exponent(cls, private_key): def _private_key_secret_exponent(cls, private_key):
@ -199,40 +152,24 @@ class PrivateKey(_KeyBase):
return int.from_bytes(private_key, 'big') return int.from_bytes(private_key, 'big')
@classmethod @classmethod
def from_seed(cls, ledger, seed) -> 'PrivateKey': def from_seed(cls, ledger, seed):
# This hard-coded message string seems to be coin-independent... # This hard-coded message string seems to be coin-independent...
hmac = hmac_sha512(b'Bitcoin seed', seed) hmac = hmac_sha512(b'Bitcoin seed', seed)
privkey, chain_code = hmac[:32], hmac[32:] privkey, chain_code = hmac[:32], hmac[32:]
return cls(ledger, privkey, chain_code, 0, 0) return cls(ledger, privkey, chain_code, 0, 0)
@classmethod
def from_pem(cls, ledger, pem) -> 'PrivateKey':
der = pem_to_der(pem.encode())
try:
key_int = ECPrivateKey.load(der).native['private_key']
except ValueError:
key_int = PrivateKeyInfo.load(der).native['private_key']['private_key']
private_key = cPrivateKey.from_int(key_int)
return cls(ledger, private_key, bytes((0,)*32), 0, 0)
@classmethod
def from_bytes(cls, ledger, key_bytes) -> 'PrivateKey':
return cls(ledger, cPrivateKey(key_bytes), bytes((0,)*32), 0, 0)
@cachedproperty @cachedproperty
def private_key_bytes(self): def private_key_bytes(self):
""" Return the serialized private key (no leading zero byte). """ """ Return the serialized private key (no leading zero byte). """
return self.signing_key.secret return self.signing_key.secret
@cachedproperty @cachedproperty
def public_key(self) -> PublicKey: def public_key(self):
""" Return the corresponding extended public key. """ """ Return the corresponding extended public key. """
verifying_key = self.signing_key.public_key verifying_key = self.signing_key.public_key
parent_pubkey = self.parent.public_key if self.parent else None parent_pubkey = self.parent.public_key if self.parent else None
return PublicKey( return PubKey(self.ledger, verifying_key, self.chain_code, self.n, self.depth,
self.ledger, verifying_key, self.chain_code, parent_pubkey)
self.n, self.depth, parent_pubkey
)
def ec_point(self): def ec_point(self):
return self.public_key.ec_point() return self.public_key.ec_point()
@ -245,12 +182,11 @@ class PrivateKey(_KeyBase):
""" Return the private key encoded in Wallet Import Format. """ """ Return the private key encoded in Wallet Import Format. """
return self.ledger.private_key_to_wif(self.private_key_bytes) return self.ledger.private_key_to_wif(self.private_key_bytes)
@property
def address(self): def address(self):
""" The public key as a P2PKH address. """ """ The public key as a P2PKH address. """
return self.public_key.address return self.public_key.address
def child(self, n) -> 'PrivateKey': def child(self, n):
""" Return the derived child extended private key at index N.""" """ Return the derived child extended private key at index N."""
if not 0 <= n < (1 << 32): if not 0 <= n < (1 << 32):
raise ValueError('invalid BIP32 private key child number') raise ValueError('invalid BIP32 private key child number')
@ -269,28 +205,6 @@ class PrivateKey(_KeyBase):
""" Produce a signature for piece of data by double hashing it and signing the hash. """ """ Produce a signature for piece of data by double hashing it and signing the hash. """
return self.signing_key.sign(data, hasher=double_sha256) return self.signing_key.sign(data, hasher=double_sha256)
def sign_compact(self, digest):
""" Produce a compact signature. """
key = self.signing_key
signature = libsecp256k1_ffi.new('secp256k1_ecdsa_signature *')
signed = libsecp256k1.secp256k1_ecdsa_sign(
key.context.ctx, signature, digest, key.secret,
libsecp256k1_ffi.NULL, libsecp256k1_ffi.NULL
)
if not signed:
raise ValueError('The private key was invalid.')
serialized = libsecp256k1_ffi.new('unsigned char[%d]' % CDATA_SIG_LENGTH)
compacted = libsecp256k1.secp256k1_ecdsa_signature_serialize_compact(
key.context.ctx, serialized, signature
)
if compacted != 1:
raise ValueError('The signature could not be compacted.')
return bytes(libsecp256k1_ffi.buffer(serialized, CDATA_SIG_LENGTH))
def identifier(self): def identifier(self):
"""Return the key's identifier as 20 bytes.""" """Return the key's identifier as 20 bytes."""
return self.public_key.identifier() return self.public_key.identifier()
@ -302,12 +216,9 @@ class PrivateKey(_KeyBase):
b'\0' + self.private_key_bytes b'\0' + self.private_key_bytes
) )
def to_pem(self):
return self.signing_key.to_pem()
def _from_extended_key(ledger, ekey): def _from_extended_key(ledger, ekey):
"""Return a PublicKey or PrivateKey from an extended key raw bytes.""" """Return a PubKey or PrivateKey from an extended key raw bytes."""
if not isinstance(ekey, (bytes, bytearray)): if not isinstance(ekey, (bytes, bytearray)):
raise TypeError('extended key must be raw bytes') raise TypeError('extended key must be raw bytes')
if len(ekey) != 78: if len(ekey) != 78:
@ -319,7 +230,7 @@ def _from_extended_key(ledger, ekey):
if ekey[:4] == ledger.extended_public_key_prefix: if ekey[:4] == ledger.extended_public_key_prefix:
pubkey = ekey[45:] pubkey = ekey[45:]
key = PublicKey(ledger, pubkey, chain_code, n, depth) key = PubKey(ledger, pubkey, chain_code, n, depth)
elif ekey[:4] == ledger.extended_private_key_prefix: elif ekey[:4] == ledger.extended_private_key_prefix:
if ekey[45] != 0: if ekey[45] != 0:
raise ValueError('invalid extended private key prefix byte') raise ValueError('invalid extended private key prefix byte')
@ -337,6 +248,6 @@ def from_extended_key_string(ledger, ekey_str):
xpub6BsnM1W2Y7qLMiuhi7f7dbAwQZ5Cz5gYJCRzTNainXzQXYjFwtuQXHd xpub6BsnM1W2Y7qLMiuhi7f7dbAwQZ5Cz5gYJCRzTNainXzQXYjFwtuQXHd
3qfi3t3KJtHxshXezfjft93w4UE7BGMtKwhqEHae3ZA7d823DVrL 3qfi3t3KJtHxshXezfjft93w4UE7BGMtKwhqEHae3ZA7d823DVrL
return a PublicKey or PrivateKey. return a PubKey or PrivateKey.
""" """
return _from_extended_key(ledger, Base58.decode_check(ekey_str)) return _from_extended_key(ledger, Base58.decode_check(ekey_str))

File diff suppressed because it is too large Load diff

View file

@ -5,7 +5,7 @@ from lbry.wallet.transaction import OutputEffectiveAmountEstimator
MAXIMUM_TRIES = 100000 MAXIMUM_TRIES = 100000
STRATEGIES = ['sqlite'] # sqlite coin chooser is in database.py STRATEGIES = []
def strategy(method): def strategy(method):
@ -141,7 +141,7 @@ class CoinSelector:
_) -> List[OutputEffectiveAmountEstimator]: _) -> List[OutputEffectiveAmountEstimator]:
""" Accumulate UTXOs at random until there is enough to cover the target. """ """ Accumulate UTXOs at random until there is enough to cover the target. """
target = self.target + self.cost_of_change target = self.target + self.cost_of_change
self.random.shuffle(txos, random=self.random.random) # pylint: disable=deprecated-argument self.random.shuffle(txos, self.random.random)
selection = [] selection = []
amount = 0 amount = 0
for coin in txos: for coin in txos:

View file

@ -2,7 +2,6 @@ NULL_HASH32 = b'\x00'*32
CENT = 1000000 CENT = 1000000
COIN = 100*CENT COIN = 100*CENT
DUST = 1000
TIMEOUT = 30.0 TIMEOUT = 30.0

File diff suppressed because it is too large Load diff

View file

@ -1,18 +1,16 @@
import base64
import os import os
import struct import struct
import asyncio import asyncio
import hashlib
import logging import logging
import zlib
from datetime import date
from io import BytesIO from io import BytesIO
from typing import Optional, Iterator, Tuple, Callable from contextlib import asynccontextmanager
from typing import Optional, Iterator, Tuple
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from lbry.crypto.hash import sha512, double_sha256, ripemd160 from lbry.crypto.hash import sha512, double_sha256, ripemd160
from lbry.wallet.util import ArithUint256, date_to_julian_day from lbry.wallet.util import ArithUint256
from .checkpoints import HASHES
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -34,50 +32,25 @@ class Headers:
max_target = 0x0000ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff max_target = 0x0000ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
genesis_hash = b'9c89283ba0f3227f6c03b70216b9f665f0118d5e0fa729cedf4fb34d6a34f463' genesis_hash = b'9c89283ba0f3227f6c03b70216b9f665f0118d5e0fa729cedf4fb34d6a34f463'
target_timespan = 150 target_timespan = 150
checkpoints = HASHES checkpoint = (600_000, b'100b33ca3d0b86a48f0d6d6f30458a130ecb89d5affefe4afccb134d5a40f4c2')
first_block_timestamp = 1466646588 # block 1, as 0 is off by a lot
timestamp_average_offset = 160.6855883050695 # calculated at 733447
validate_difficulty: bool = True validate_difficulty: bool = True
def __init__(self, path) -> None: def __init__(self, path) -> None:
self.io = None if path == ':memory:':
self.io = BytesIO()
self.path = path self.path = path
self._size: Optional[int] = None self._size: Optional[int] = None
self.chunk_getter: Optional[Callable] = None
self.known_missing_checkpointed_chunks = set()
self.check_chunk_lock = asyncio.Lock()
async def open(self): async def open(self):
self.io = BytesIO()
if self.path != ':memory:': if self.path != ':memory:':
def _readit(): if not os.path.exists(self.path):
if os.path.exists(self.path): self.io = open(self.path, 'w+b')
with open(self.path, 'r+b') as header_file:
self.io.seek(0)
self.io.write(header_file.read())
await asyncio.get_event_loop().run_in_executor(None, _readit)
bytes_size = self.io.seek(0, os.SEEK_END)
self._size = bytes_size // self.header_size
max_checkpointed_height = max(self.checkpoints.keys() or [-1]) + 1000
if bytes_size % self.header_size:
log.warning("Reader file size doesnt match header size. Repairing, might take a while.")
await self.repair()
else: else:
# try repairing any incomplete write on tip from previous runs (outside of checkpoints, that are ok) self.io = open(self.path, 'r+b')
await self.repair(start_height=max_checkpointed_height)
await self.ensure_checkpointed_size()
await self.get_all_missing_headers()
async def close(self): async def close(self):
if self.io is not None:
def _close():
flags = 'r+b' if os.path.exists(self.path) else 'w+b'
with open(self.path, flags) as header_file:
header_file.write(self.io.getbuffer())
await asyncio.get_event_loop().run_in_executor(None, _close)
self.io.close() self.io.close()
self.io = None
@staticmethod @staticmethod
def serialize(header): def serialize(header):
@ -124,92 +97,23 @@ class Headers:
return new_target return new_target
def __len__(self) -> int: def __len__(self) -> int:
if self._size is None:
self._size = self.io.seek(0, os.SEEK_END) // self.header_size
return self._size return self._size
def __bool__(self): def __bool__(self):
return True return True
async def get(self, height) -> dict: def __getitem__(self, height) -> dict:
if isinstance(height, slice): if isinstance(height, slice):
raise NotImplementedError("Slicing of header chain has not been implemented yet.") raise NotImplementedError("Slicing of header chain has not been implemented yet.")
try:
return self.deserialize(height, await self.get_raw_header(height))
except struct.error:
raise IndexError(f"failed to get {height}, at {len(self)}")
def estimated_timestamp(self, height, try_real_headers=True):
if height <= 0:
return
if try_real_headers and self.has_header(height):
offset = height * self.header_size
return struct.unpack('<I', self.io.getbuffer()[offset + 100: offset + 104])[0]
return int(self.first_block_timestamp + (height * self.timestamp_average_offset))
def estimated_julian_day(self, height):
return date_to_julian_day(date.fromtimestamp(self.estimated_timestamp(height, False)))
async def get_raw_header(self, height) -> bytes:
if self.chunk_getter:
await self.ensure_chunk_at(height)
if not 0 <= height <= self.height: if not 0 <= height <= self.height:
raise IndexError(f"{height} is out of bounds, current height: {self.height}") raise IndexError(f"{height} is out of bounds, current height: {self.height}")
return self._read(height) return self.deserialize(height, self.get_raw_header(height))
def _read(self, height, count=1): def get_raw_header(self, height) -> bytes:
offset = height * self.header_size self.io.seek(height * self.header_size, os.SEEK_SET)
return bytes(self.io.getbuffer()[offset: offset + self.header_size * count]) return self.io.read(self.header_size)
def chunk_hash(self, start, count):
return self.hash_header(self._read(start, count)).decode()
async def ensure_checkpointed_size(self):
max_checkpointed_height = max(self.checkpoints.keys() or [-1])
if self.height < max_checkpointed_height:
self._write(max_checkpointed_height, bytes([0] * self.header_size * 1000))
async def ensure_chunk_at(self, height):
async with self.check_chunk_lock:
if self.has_header(height):
log.debug("has header %s", height)
return
return await self.fetch_chunk(height)
async def fetch_chunk(self, height):
log.info("on-demand fetching height %s", height)
start = (height // 1000) * 1000
headers = await self.chunk_getter(start) # pylint: disable=not-callable
chunk = (
zlib.decompress(base64.b64decode(headers['base64']), wbits=-15, bufsize=600_000)
)
chunk_hash = self.hash_header(chunk).decode()
if self.checkpoints.get(start) == chunk_hash:
self._write(start, chunk)
if start in self.known_missing_checkpointed_chunks:
self.known_missing_checkpointed_chunks.remove(start)
return
elif start not in self.checkpoints:
return # todo: fixme
raise Exception(
f"Checkpoint mismatch at height {start}. Expected {self.checkpoints[start]}, but got {chunk_hash} instead."
)
def has_header(self, height):
normalized_height = (height // 1000) * 1000
if normalized_height in self.checkpoints:
return normalized_height not in self.known_missing_checkpointed_chunks
empty = '56944c5d3f98413ef45cf54545538103cc9f298e0575820ad3591376e2e0f65d'
all_zeroes = '789d737d4f448e554b318c94063bbfa63e9ccda6e208f5648ca76ee68896557b'
return self.chunk_hash(height, 1) not in (empty, all_zeroes)
async def get_all_missing_headers(self):
# Heavy operation done in one optimized shot
for chunk_height, expected_hash in reversed(list(self.checkpoints.items())):
if chunk_height in self.known_missing_checkpointed_chunks:
continue
if self.chunk_hash(chunk_height, 1000) != expected_hash:
self.known_missing_checkpointed_chunks.add(chunk_height)
return self.known_missing_checkpointed_chunks
@property @property
def height(self) -> int: def height(self) -> int:
@ -219,9 +123,9 @@ class Headers:
def bytes_size(self): def bytes_size(self):
return len(self) * self.header_size return len(self) * self.header_size
async def hash(self, height=None) -> bytes: def hash(self, height=None) -> bytes:
return self.hash_header( return self.hash_header(
await self.get_raw_header(height if height is not None else self.height) self.get_raw_header(height if height is not None else self.height)
) )
@staticmethod @staticmethod
@ -230,18 +134,44 @@ class Headers:
return b'0' * 64 return b'0' * 64
return hexlify(double_sha256(header)[::-1]) return hexlify(double_sha256(header)[::-1])
@asynccontextmanager
async def checkpointed_connector(self):
buf = BytesIO()
try:
yield buf
finally:
await asyncio.sleep(0)
final_height = len(self) + buf.tell() // self.header_size
verifiable_bytes = (self.checkpoint[0] - len(self)) * self.header_size if self.checkpoint else 0
if verifiable_bytes > 0 and final_height >= self.checkpoint[0]:
buf.seek(0)
self.io.seek(0)
h = hashlib.sha256()
h.update(self.io.read())
h.update(buf.read(verifiable_bytes))
if h.hexdigest().encode() == self.checkpoint[1]:
buf.seek(0)
self._write(len(self), buf.read(verifiable_bytes))
remaining = buf.read()
buf.seek(0)
buf.write(remaining)
buf.truncate()
else:
log.warning("Checkpoint mismatch, connecting headers through slow method.")
if buf.tell() > 0:
await self.connect(len(self), buf.getvalue())
async def connect(self, start: int, headers: bytes) -> int: async def connect(self, start: int, headers: bytes) -> int:
added = 0 added = 0
bail = False bail = False
for height, chunk in self._iterate_chunks(start, headers): for height, chunk in self._iterate_chunks(start, headers):
try: try:
# validate_chunk() is CPU bound and reads previous chunks from file system # validate_chunk() is CPU bound and reads previous chunks from file system
await self.validate_chunk(height, chunk) self.validate_chunk(height, chunk)
except InvalidHeader as e: except InvalidHeader as e:
bail = True bail = True
chunk = chunk[:(height-e.height)*self.header_size] chunk = chunk[:(height-e.height)*self.header_size]
if chunk: added += self._write(height, chunk) if chunk else 0
added += self._write(height, chunk)
if bail: if bail:
break break
return added return added
@ -249,21 +179,20 @@ class Headers:
def _write(self, height, verified_chunk): def _write(self, height, verified_chunk):
self.io.seek(height * self.header_size, os.SEEK_SET) self.io.seek(height * self.header_size, os.SEEK_SET)
written = self.io.write(verified_chunk) // self.header_size written = self.io.write(verified_chunk) // self.header_size
# self.io.truncate() self.io.truncate()
# .seek()/.write()/.truncate() might also .flush() when needed # .seek()/.write()/.truncate() might also .flush() when needed
# the goal here is mainly to ensure we're definitely flush()'ing # the goal here is mainly to ensure we're definitely flush()'ing
self.io.flush() self.io.flush()
self._size = max(self._size or 0, self.io.tell() // self.header_size) self._size = self.io.tell() // self.header_size
return written return written
async def validate_chunk(self, height, chunk): def validate_chunk(self, height, chunk):
previous_hash, previous_header, previous_previous_header = None, None, None previous_hash, previous_header, previous_previous_header = None, None, None
if height > 0: if height > 0:
raw = await self.get_raw_header(height-1) previous_header = self[height-1]
previous_header = self.deserialize(height-1, raw) previous_hash = self.hash(height-1)
previous_hash = self.hash_header(raw)
if height > 1: if height > 1:
previous_previous_header = await self.get(height-2) previous_previous_header = self[height-2]
chunk_target = self.get_next_chunk_target(height // 2016 - 1) chunk_target = self.get_next_chunk_target(height // 2016 - 1)
for current_hash, current_header in self._iterate_headers(height, chunk): for current_hash, current_header in self._iterate_headers(height, chunk):
block_target = self.get_next_block_target(chunk_target, previous_previous_header, previous_header) block_target = self.get_next_block_target(chunk_target, previous_previous_header, previous_header)
@ -302,30 +231,28 @@ class Headers:
height, f"insufficient proof of work: {proof_of_work.value} vs target {target.value}" height, f"insufficient proof of work: {proof_of_work.value} vs target {target.value}"
) )
async def repair(self, start_height=0): async def repair(self):
previous_header_hash = fail = None previous_header_hash = fail = None
batch_size = 36 batch_size = 36
for height in range(start_height, self.height, batch_size): for start_height in range(0, self.height, batch_size):
headers = self._read(height, batch_size) self.io.seek(self.header_size * start_height)
headers = self.io.read(self.header_size*batch_size)
if len(headers) % self.header_size != 0: if len(headers) % self.header_size != 0:
headers = headers[:(len(headers) // self.header_size) * self.header_size] headers = headers[:(len(headers) // self.header_size) * self.header_size]
for header_hash, header in self._iterate_headers(height, headers): for header_hash, header in self._iterate_headers(start_height, headers):
height = header['block_height'] height = header['block_height']
if previous_header_hash: if height:
if header['prev_block_hash'] != previous_header_hash: if header['prev_block_hash'] != previous_header_hash:
fail = True fail = True
elif height == 0: else:
if header_hash != self.genesis_hash: if header_hash != self.genesis_hash:
fail = True fail = True
else:
# for sanity and clarity, since it is the only way we can end up here
assert start_height > 0 and height == start_height
if fail: if fail:
log.warning("Header file corrupted at height %s, truncating it.", height - 1) log.warning("Header file corrupted at height %s, truncating it.", height - 1)
self.io.seek(max(0, (height - 1)) * self.header_size, os.SEEK_SET) self.io.seek(max(0, (height - 1)) * self.header_size, os.SEEK_SET)
self.io.truncate() self.io.truncate()
self.io.flush() self.io.flush()
self._size = self.io.seek(0, os.SEEK_END) // self.header_size self._size = None
return return
previous_header_hash = header_hash previous_header_hash = header_hash
@ -349,6 +276,10 @@ class Headers:
header = headers[start:end] header = headers[start:end]
yield self.hash_header(header), self.deserialize(height+idx, header) yield self.hash_header(header), self.deserialize(height+idx, header)
@property
def claim_trie_root(self):
return self[self.height]['claim_trie_root']
@staticmethod @staticmethod
def header_hash_to_pow_hash(header_hash: bytes): def header_hash_to_pow_hash(header_hash: bytes):
header_hash_bytes = unhexlify(header_hash)[::-1] header_hash_bytes = unhexlify(header_hash)[::-1]
@ -364,4 +295,3 @@ class UnvalidatedHeaders(Headers):
validate_difficulty = False validate_difficulty = False
max_target = 0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff max_target = 0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
genesis_hash = b'6e3fcf1299d4ec5d79c3a4c91d624a4acf9e2e173d95a1a0504f677669687556' genesis_hash = b'6e3fcf1299d4ec5d79c3a4c91d624a4acf9e2e173d95a1a0504f677669687556'
checkpoints = {}

View file

@ -1,33 +1,34 @@
import os import os
import copy import zlib
import time import base64
import asyncio import asyncio
import logging import logging
from io import StringIO
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from operator import itemgetter from operator import itemgetter
from collections import defaultdict from collections import namedtuple, defaultdict
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from typing import Dict, Tuple, Type, Iterable, List, Optional, DefaultDict, NamedTuple from typing import Dict, Tuple, Type, Iterable, List, Optional, DefaultDict
import pylru
from lbry.schema.result import Outputs, INVALID, NOT_FOUND from lbry.schema.result import Outputs, INVALID, NOT_FOUND
from lbry.schema.url import URL from lbry.schema.url import URL
from lbry.crypto.hash import hash160, double_sha256, sha256 from lbry.crypto.hash import hash160, double_sha256, sha256
from lbry.crypto.base58 import Base58 from lbry.crypto.base58 import Base58
from lbry.utils import LRUCacheWithMetrics
from lbry.wallet.tasks import TaskGroup from .tasks import TaskGroup
from lbry.wallet.database import Database from .database import Database
from lbry.wallet.stream import StreamController from .stream import StreamController
from lbry.wallet.dewies import dewies_to_lbc from .dewies import dewies_to_lbc
from lbry.wallet.account import Account, AddressManager, SingleKey from .account import Account, AddressManager, SingleKey
from lbry.wallet.network import Network from .network import Network
from lbry.wallet.transaction import Transaction, Output from .transaction import Transaction, Output
from lbry.wallet.header import Headers, UnvalidatedHeaders from .header import Headers, UnvalidatedHeaders
from lbry.wallet.checkpoints import HASHES from .constants import TXO_TYPES, COIN, NULL_HASH32
from lbry.wallet.constants import TXO_TYPES, CLAIM_TYPES, COIN, NULL_HASH32 from .bip32 import PubKey, PrivateKey
from lbry.wallet.bip32 import PublicKey, PrivateKey from .coinselection import CoinSelector
from lbry.wallet.coinselection import CoinSelector
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -52,29 +53,25 @@ class LedgerRegistry(type):
return mcs.ledgers[ledger_id] return mcs.ledgers[ledger_id]
class TransactionEvent(NamedTuple): class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx'))):
address: str pass
tx: Transaction
class AddressesGeneratedEvent(NamedTuple): class AddressesGeneratedEvent(namedtuple('AddressesGeneratedEvent', ('address_manager', 'addresses'))):
address_manager: AddressManager pass
addresses: List[str]
class BlockHeightEvent(NamedTuple): class BlockHeightEvent(namedtuple('BlockHeightEvent', ('height', 'change'))):
height: int pass
change: int
class TransactionCacheItem: class TransactionCacheItem:
__slots__ = '_tx', 'lock', 'has_tx', 'pending_verifications' __slots__ = '_tx', 'lock', 'has_tx'
def __init__(self, tx: Optional[Transaction] = None, lock: Optional[asyncio.Lock] = None): def __init__(self, tx: Optional[Transaction] = None, lock: Optional[asyncio.Lock] = None):
self.has_tx = asyncio.Event() self.has_tx = asyncio.Event()
self.lock = lock or asyncio.Lock() self.lock = lock or asyncio.Lock()
self._tx = self.tx = tx self._tx = self.tx = tx
self.pending_verifications = 0
@property @property
def tx(self) -> Optional[Transaction]: def tx(self) -> Optional[Transaction]:
@ -106,9 +103,7 @@ class Ledger(metaclass=LedgerRegistry):
target_timespan = 150 target_timespan = 150
default_fee_per_byte = 50 default_fee_per_byte = 50
default_fee_per_name_char = 0 default_fee_per_name_char = 200000
checkpoints = HASHES
def __init__(self, config=None): def __init__(self, config=None):
self.config = config or {} self.config = config or {}
@ -119,10 +114,10 @@ class Ledger(metaclass=LedgerRegistry):
self.headers: Headers = self.config.get('headers') or self.headers_class( self.headers: Headers = self.config.get('headers') or self.headers_class(
os.path.join(self.path, "headers") os.path.join(self.path, "headers")
) )
self.headers.checkpoints = self.checkpoints
self.network: Network = self.config.get('network') or Network(self) self.network: Network = self.config.get('network') or Network(self)
self.network.on_header.listen(self.receive_header) self.network.on_header.listen(self.receive_header)
self.network.on_status.listen(self.process_status_update) self.network.on_status.listen(self.process_status_update)
self.network.on_connected.listen(self.join_network)
self.accounts = [] self.accounts = []
self.fee_per_byte: int = self.config.get('fee_per_byte', self.default_fee_per_byte) self.fee_per_byte: int = self.config.get('fee_per_byte', self.default_fee_per_byte)
@ -155,19 +150,17 @@ class Ledger(metaclass=LedgerRegistry):
self._on_ready_controller = StreamController() self._on_ready_controller = StreamController()
self.on_ready = self._on_ready_controller.stream self.on_ready = self._on_ready_controller.stream
self._tx_cache = LRUCacheWithMetrics(self.config.get("tx_cache_size", 1024), metric_name='tx') self._tx_cache = pylru.lrucache(100000)
self._update_tasks = TaskGroup() self._update_tasks = TaskGroup()
self._other_tasks = TaskGroup() # that we dont need to start
self._utxo_reservation_lock = asyncio.Lock() self._utxo_reservation_lock = asyncio.Lock()
self._header_processing_lock = asyncio.Lock() self._header_processing_lock = asyncio.Lock()
self._address_update_locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock) self._address_update_locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self._history_lock = asyncio.Lock()
self.coin_selection_strategy = None self.coin_selection_strategy = None
self._known_addresses_out_of_sync = set() self._known_addresses_out_of_sync = set()
self.fee_per_name_char = self.config.get('fee_per_name_char', self.default_fee_per_name_char) self.fee_per_name_char = self.config.get('fee_per_name_char', self.default_fee_per_name_char)
self._balance_cache = LRUCacheWithMetrics(2 ** 15) self._balance_cache = pylru.lrucache(100000)
@classmethod @classmethod
def get_id(cls): def get_id(cls):
@ -178,25 +171,15 @@ class Ledger(metaclass=LedgerRegistry):
raw_address = cls.pubkey_address_prefix + h160 raw_address = cls.pubkey_address_prefix + h160
return Base58.encode(bytearray(raw_address + double_sha256(raw_address)[0:4])) return Base58.encode(bytearray(raw_address + double_sha256(raw_address)[0:4]))
@classmethod
def hash160_to_script_address(cls, h160):
raw_address = cls.script_address_prefix + h160
return Base58.encode(bytearray(raw_address + double_sha256(raw_address)[0:4]))
@staticmethod @staticmethod
def address_to_hash160(address): def address_to_hash160(address):
return Base58.decode(address)[1:21] return Base58.decode(address)[1:21]
@classmethod @classmethod
def is_pubkey_address(cls, address): def is_valid_address(cls, address):
decoded = Base58.decode_check(address) decoded = Base58.decode_check(address)
return decoded[0] == cls.pubkey_address_prefix[0] return decoded[0] == cls.pubkey_address_prefix[0]
@classmethod
def is_script_address(cls, address):
decoded = Base58.decode_check(address)
return decoded[0] == cls.script_address_prefix[0]
@classmethod @classmethod
def public_key_to_address(cls, public_key): def public_key_to_address(cls, public_key):
return cls.hash160_to_address(hash160(public_key)) return cls.hash160_to_address(hash160(public_key))
@ -226,7 +209,7 @@ class Ledger(metaclass=LedgerRegistry):
return account.get_private_key(address_info['chain'], address_info['pubkey'].n) return account.get_private_key(address_info['chain'], address_info['pubkey'].n)
return None return None
async def get_public_key_for_address(self, wallet, address) -> Optional[PublicKey]: async def get_public_key_for_address(self, wallet, address) -> Optional[PubKey]:
match = await self._get_account_and_address_info_for_address(wallet, address) match = await self._get_account_and_address_info_for_address(wallet, address)
if match: if match:
_, address_info = match _, address_info = match
@ -241,7 +224,7 @@ class Ledger(metaclass=LedgerRegistry):
async def get_effective_amount_estimators(self, funding_accounts: Iterable[Account]): async def get_effective_amount_estimators(self, funding_accounts: Iterable[Account]):
estimators = [] estimators = []
for account in funding_accounts: for account in funding_accounts:
utxos = await account.get_utxos(no_tx=True, no_channel_info=True) utxos = await account.get_utxos()
for utxo in utxos: for utxo in utxos:
estimators.append(utxo.get_estimator(self)) estimators.append(utxo.get_estimator(self))
return estimators return estimators
@ -252,15 +235,11 @@ class Ledger(metaclass=LedgerRegistry):
def get_address_count(self, **constraints): def get_address_count(self, **constraints):
return self.db.get_address_count(**constraints) return self.db.get_address_count(**constraints)
async def get_spendable_utxos(self, amount: int, funding_accounts: Optional[Iterable['Account']], min_amount=1): async def get_spendable_utxos(self, amount: int, funding_accounts):
min_amount = min(amount // 10, min_amount) async with self._utxo_reservation_lock:
txos = await self.get_effective_amount_estimators(funding_accounts)
fee = Output.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self) fee = Output.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self)
selector = CoinSelector(amount, fee) selector = CoinSelector(amount, fee)
async with self._utxo_reservation_lock:
if self.coin_selection_strategy == 'sqlite':
return await self.db.get_spendable_utxos(self, amount + fee, funding_accounts, min_amount=min_amount,
fee_per_byte=self.fee_per_byte)
txos = await self.get_effective_amount_estimators(funding_accounts)
spendables = selector.select(txos, self.coin_selection_strategy) spendables = selector.select(txos, self.coin_selection_strategy)
if spendables: if spendables:
await self.reserve_outputs(s.txo for s in spendables) await self.reserve_outputs(s.txo for s in spendables)
@ -283,7 +262,7 @@ class Ledger(metaclass=LedgerRegistry):
self.constraint_spending_utxos(constraints) self.constraint_spending_utxos(constraints)
return self.db.get_utxo_count(**constraints) return self.db.get_utxo_count(**constraints)
async def get_txos(self, resolve=False, **constraints) -> List[Output]: async def get_txos(self, resolve=False, **constraints):
txos = await self.db.get_txos(**constraints) txos = await self.db.get_txos(**constraints)
if resolve: if resolve:
return await self._resolve_for_local_results(constraints.get('accounts', []), txos) return await self._resolve_for_local_results(constraints.get('accounts', []), txos)
@ -292,12 +271,6 @@ class Ledger(metaclass=LedgerRegistry):
def get_txo_count(self, **constraints): def get_txo_count(self, **constraints):
return self.db.get_txo_count(**constraints) return self.db.get_txo_count(**constraints)
def get_txo_sum(self, **constraints):
return self.db.get_txo_sum(**constraints)
def get_txo_plot(self, **constraints):
return self.db.get_txo_plot(**constraints)
def get_transactions(self, **constraints): def get_transactions(self, **constraints):
return self.db.get_transactions(**constraints) return self.db.get_transactions(**constraints)
@ -329,19 +302,16 @@ class Ledger(metaclass=LedgerRegistry):
async def start(self): async def start(self):
if not os.path.exists(self.path): if not os.path.exists(self.path):
os.mkdir(self.path) os.mkdir(self.path)
await asyncio.wait(map(asyncio.create_task, [ await asyncio.wait([
self.db.open(), self.db.open(),
self.headers.open() self.headers.open()
])) ])
fully_synced = self.on_ready.first first_connection = self.network.on_connected.first
asyncio.create_task(self.network.start()) asyncio.ensure_future(self.network.start())
await self.network.on_connected.first await first_connection
async with self._header_processing_lock: async with self._header_processing_lock:
await self._update_tasks.add(self.initial_headers_sync()) await self._update_tasks.add(self.initial_headers_sync())
self.network.on_connected.listen(self.join_network) await self._on_ready_controller.stream.first
asyncio.ensure_future(self.join_network())
await fully_synced
await self.db.release_all_outputs()
await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts)) await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts))
await asyncio.gather(*(a.save_max_gap() for a in self.accounts)) await asyncio.gather(*(a.save_max_gap() for a in self.accounts))
if len(self.accounts) > 10: if len(self.accounts) > 10:
@ -352,37 +322,37 @@ class Ledger(metaclass=LedgerRegistry):
async def join_network(self, *_): async def join_network(self, *_):
log.info("Subscribing and updating accounts.") log.info("Subscribing and updating accounts.")
await self._update_tasks.add(self.subscribe_accounts()) async with self._header_processing_lock:
await self.update_headers()
await self.subscribe_accounts()
await self._update_tasks.done.wait() await self._update_tasks.done.wait()
self._on_ready_controller.add(True) self._on_ready_controller.add(True)
async def stop(self): async def stop(self):
self._update_tasks.cancel() self._update_tasks.cancel()
self._other_tasks.cancel()
await self._update_tasks.done.wait() await self._update_tasks.done.wait()
await self._other_tasks.done.wait()
await self.network.stop() await self.network.stop()
await self.db.close() await self.db.close()
await self.headers.close() await self.headers.close()
async def tasks_are_done(self):
await self._update_tasks.done.wait()
await self._other_tasks.done.wait()
@property @property
def local_height_including_downloaded_height(self): def local_height_including_downloaded_height(self):
return max(self.headers.height, self._download_height) return max(self.headers.height, self._download_height)
async def initial_headers_sync(self): async def initial_headers_sync(self):
get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=1000, b64=True) target = self.network.remote_height + 1
self.headers.chunk_getter = get_chunk current = len(self.headers)
get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=4096, b64=True)
async def doit(): chunks = [asyncio.create_task(get_chunk(height)) for height in range(current, target, 4096)]
for height in reversed(sorted(self.headers.known_missing_checkpointed_chunks)): total = 0
async with self._header_processing_lock: async with self.headers.checkpointed_connector() as buffer:
await self.headers.ensure_chunk_at(height) for chunk in chunks:
self._other_tasks.add(doit()) headers = await chunk
await self.update_headers() total += buffer.write(
zlib.decompress(base64.b64decode(headers['base64']), wbits=-15, bufsize=600_000)
)
self._download_height = current + total // self.headers.header_size
log.info("Headers sync: %s / %s", self._download_height, target)
async def update_headers(self, height=None, headers=None, subscription_update=False): async def update_headers(self, height=None, headers=None, subscription_update=False):
rewound = 0 rewound = 0
@ -429,7 +399,6 @@ class Ledger(metaclass=LedgerRegistry):
"Blockchain Reorganization: attempting rewind to height %s from starting height %s", "Blockchain Reorganization: attempting rewind to height %s from starting height %s",
height, height+rewound height, height+rewound
) )
self._tx_cache.clear()
else: else:
raise IndexError(f"headers.connect() returned negative number ({added})") raise IndexError(f"headers.connect() returned negative number ({added})")
@ -466,15 +435,14 @@ class Ledger(metaclass=LedgerRegistry):
async def subscribe_accounts(self): async def subscribe_accounts(self):
if self.network.is_connected and self.accounts: if self.network.is_connected and self.accounts:
log.info("Subscribe to %i accounts", len(self.accounts)) log.info("Subscribe to %i accounts", len(self.accounts))
await asyncio.wait(map(asyncio.create_task, [ await asyncio.wait([
self.subscribe_account(a) for a in self.accounts self.subscribe_account(a) for a in self.accounts
])) ])
async def subscribe_account(self, account: Account): async def subscribe_account(self, account: Account):
for address_manager in account.address_managers.values(): for address_manager in account.address_managers.values():
await self.subscribe_addresses(address_manager, await address_manager.get_addresses()) await self.subscribe_addresses(address_manager, await address_manager.get_addresses())
await account.ensure_address_gap() await account.ensure_address_gap()
await account.deterministic_channel_keys.ensure_cache_primed()
async def unsubscribe_account(self, account: Account): async def unsubscribe_account(self, account: Account):
for address in await account.get_addresses(): for address in await account.get_addresses():
@ -495,10 +463,8 @@ class Ledger(metaclass=LedgerRegistry):
for address, remote_status in zip(batch, results): for address, remote_status in zip(batch, results):
self._update_tasks.add(self.update_history(address, remote_status, address_manager)) self._update_tasks.add(self.update_history(address, remote_status, address_manager))
addresses_remaining = addresses_remaining[batch_size:] addresses_remaining = addresses_remaining[batch_size:]
if self.network.client and self.network.client.server_address_and_port:
log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining), log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining),
len(addresses), *self.network.client.server_address_and_port) len(addresses), *self.network.client.server_address_and_port)
if self.network.client and self.network.client.server_address_and_port:
log.info( log.info(
"finished subscribing to %i addresses on %s:%i", len(addresses), "finished subscribing to %i addresses on %s:%i", len(addresses),
*self.network.client.server_address_and_port *self.network.client.server_address_and_port
@ -508,10 +474,10 @@ class Ledger(metaclass=LedgerRegistry):
address, remote_status = update address, remote_status = update
self._update_tasks.add(self.update_history(address, remote_status)) self._update_tasks.add(self.update_history(address, remote_status))
async def update_history(self, address, remote_status, address_manager: AddressManager = None, async def update_history(self, address, remote_status, address_manager: AddressManager = None):
reattempt_update: bool = True):
async with self._address_update_locks[address]: async with self._address_update_locks[address]:
self._known_addresses_out_of_sync.discard(address) self._known_addresses_out_of_sync.discard(address)
local_status, local_history = await self.get_local_status_and_history(address) local_status, local_history = await self.get_local_status_and_history(address)
if local_status == remote_status: if local_status == remote_status:
@ -521,56 +487,57 @@ class Ledger(metaclass=LedgerRegistry):
remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history)) remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
we_need = set(remote_history) - set(local_history) we_need = set(remote_history) - set(local_history)
if not we_need: if not we_need:
remote_missing = set(local_history) - set(remote_history)
if remote_missing:
log.warning(
"%i transactions we have for %s are not in the remote address history",
len(remote_missing), address
)
return True return True
to_request = {} cache_tasks: List[asyncio.Future[Transaction]] = []
pending_synced_history = {} synced_history = StringIO()
already_synced = set()
already_synced_offset = 0
for i, (txid, remote_height) in enumerate(remote_history): for i, (txid, remote_height) in enumerate(remote_history):
if i == already_synced_offset and i < len(local_history) and local_history[i] == (txid, remote_height): if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks:
pending_synced_history[i] = f'{txid}:{remote_height}:' synced_history.write(f'{txid}:{remote_height}:')
already_synced.add((txid, remote_height)) else:
already_synced_offset += 1 check_local = (txid, remote_height) not in we_need
cache_tasks.append(asyncio.ensure_future(
self.cache_transaction(txid, remote_height, check_local=check_local)
))
synced_txs = []
for task in cache_tasks:
tx = await task
check_db_for_txos = []
for txi in tx.inputs:
if txi.txo_ref.txo is not None:
continue continue
cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.id)
if cache_item is not None:
if cache_item.tx is None:
await cache_item.has_tx.wait()
assert cache_item.tx is not None
txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
else:
check_db_for_txos.append(txi.txo_ref.id)
tx_indexes = {} referenced_txos = {} if not check_db_for_txos else {
txo.id: txo for txo in await self.db.get_txos(txoid__in=check_db_for_txos, no_tx=True)
}
for i, (txid, remote_height) in enumerate(remote_history): for txi in tx.inputs:
tx_indexes[txid] = i if txi.txo_ref.txo is not None:
if (txid, remote_height) in already_synced:
continue continue
to_request[i] = (txid, remote_height) referenced_txo = referenced_txos.get(txi.txo_ref.id)
if referenced_txo is not None:
txi.txo_ref = referenced_txo.ref
log.debug( synced_history.write(f'{tx.id}:{tx.height}:')
"request %i transactions, %i/%i for %s are already synced", len(to_request), len(already_synced), synced_txs.append(tx)
len(remote_history), address
await self.db.save_transaction_io_batch(
synced_txs, address, self.address_to_hash160(address), synced_history.getvalue()
) )
remote_history_txids = {txid for txid, _ in remote_history} await asyncio.wait([
async for tx in self.request_synced_transactions(to_request, remote_history_txids, address): self._on_transaction_controller.add(TransactionEvent(address, tx))
self.maybe_has_channel_key(tx) for tx in synced_txs
pending_synced_history[tx_indexes[tx.id]] = f"{tx.id}:{tx.height}:" ])
if len(pending_synced_history) % 100 == 0:
log.info("Syncing address %s: %d/%d", address, len(pending_synced_history), len(to_request))
log.info("Sync finished for address %s: %d/%d", address, len(pending_synced_history), len(to_request))
assert len(pending_synced_history) == len(remote_history), \
f"{len(pending_synced_history)} vs {len(remote_history)} for {address}"
synced_history = ""
for remote_i, i in zip(range(len(remote_history)), sorted(pending_synced_history.keys())):
assert i == remote_i, f"{i} vs {remote_i}"
txid, height = remote_history[remote_i]
if f"{txid}:{height}:" != pending_synced_history[i]:
log.warning("history mismatch: %s vs %s", remote_history[remote_i], pending_synced_history[i])
synced_history += pending_synced_history[i]
await self.db.set_address_history(address, synced_history)
if address_manager is None: if address_manager is None:
address_manager = await self.get_address_manager_for_address(address) address_manager = await self.get_address_manager_for_address(address)
@ -579,141 +546,55 @@ class Ledger(metaclass=LedgerRegistry):
await address_manager.ensure_address_gap() await address_manager.ensure_address_gap()
local_status, local_history = \ local_status, local_history = \
await self.get_local_status_and_history(address, synced_history) await self.get_local_status_and_history(address, synced_history.getvalue())
if local_status != remote_status: if local_status != remote_status:
if local_history == remote_history: if local_history == remote_history:
log.warning(
"%s has a synced history but a mismatched status", address
)
return True return True
remote_set = set(remote_history)
local_set = set(local_history)
log.warning( log.warning(
"%s is out of sync after syncing.\n" "Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items",
"Remote: %s with %d items (%i unique), local: %s with %d items (%i unique).\n" remote_status, len(remote_history), local_status, len(local_history)
"Histories are mismatched on %i items.\n"
"Local is missing\n"
"%s\n"
"Remote is missing\n"
"%s\n"
"******",
address, remote_status, len(remote_history), len(remote_set),
local_status, len(local_history), len(local_set), len(remote_set.symmetric_difference(local_set)),
"\n".join([f"{txid} - {height}" for txid, height in local_set.difference(remote_set)]),
"\n".join([f"{txid} - {height}" for txid, height in remote_set.difference(local_set)])
) )
log.warning("local: %s", local_history)
log.warning("remote: %s", remote_history)
self._known_addresses_out_of_sync.add(address) self._known_addresses_out_of_sync.add(address)
return False return False
else: else:
log.debug("finished syncing transaction history for %s, %i known txs", address, len(local_history))
return True return True
async def maybe_verify_transaction(self, tx, remote_height, merkle=None): async def cache_transaction(self, txid, remote_height, check_local=True):
cache_item = self._tx_cache.get(txid)
if cache_item is None:
cache_item = self._tx_cache[txid] = TransactionCacheItem()
elif cache_item.tx is not None and \
cache_item.tx.height >= remote_height and \
(cache_item.tx.is_verified or remote_height < 1):
return cache_item.tx # cached tx is already up-to-date
async with cache_item.lock:
tx = cache_item.tx
if tx is None and check_local:
# check local db
tx = cache_item.tx = await self.db.get_transaction(txid=txid)
if tx is None:
# fetch from network
_raw = await self.network.retriable_call(self.network.get_transaction, txid, remote_height)
tx = Transaction(unhexlify(_raw))
cache_item.tx = tx # make sure it's saved before caching it
await self.maybe_verify_transaction(tx, remote_height)
return tx
async def maybe_verify_transaction(self, tx, remote_height):
tx.height = remote_height tx.height = remote_height
if 0 < remote_height < len(self.headers): if 0 < remote_height < len(self.headers):
# can't be tx.pending_verifications == 1 because we have to handle the transaction_show case
if not merkle:
merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height) merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height)
if 'merkle' not in merkle:
return
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
header = await self.headers.get(remote_height) header = self.headers[remote_height]
tx.position = merkle['pos'] tx.position = merkle['pos']
tx.is_verified = merkle_root == header['merkle_root'] tx.is_verified = merkle_root == header['merkle_root']
return tx
def maybe_has_channel_key(self, tx):
for txo in tx._outputs:
if txo.can_decode_claim and txo.claim.is_channel:
for account in self.accounts:
account.deterministic_channel_keys.maybe_generate_deterministic_key_for_channel(txo)
async def request_transactions(self, to_request: Tuple[Tuple[str, int], ...], cached=False):
batches = [[]]
remote_heights = {}
cache_hits = set()
for txid, height in sorted(to_request, key=lambda x: x[1]):
if cached:
cached_tx = self._tx_cache.get(txid)
if cached_tx is not None:
if cached_tx.tx is not None and cached_tx.tx.is_verified:
cache_hits.add(txid)
continue
else:
self._tx_cache[txid] = TransactionCacheItem()
remote_heights[txid] = height
if len(batches[-1]) == 100:
batches.append([])
batches[-1].append(txid)
if not batches[-1]:
batches.pop()
if cached and cache_hits:
yield {txid: self._tx_cache[txid].tx for txid in cache_hits}
for batch in batches:
txs = await self._single_batch(batch, remote_heights)
if cached:
for txid, tx in txs.items():
self._tx_cache[txid].tx = tx
yield txs
async def request_synced_transactions(self, to_request, remote_history, address):
async for txs in self.request_transactions(((txid, height) for txid, height in to_request.values())):
for tx in txs.values():
yield tx
await self._sync_and_save_batch(address, remote_history, txs)
async def _single_batch(self, batch, remote_heights):
heights = {remote_heights[txid] for txid in batch}
unrestriced = 0 < min(heights) < max(heights) < max(self.headers.checkpoints or [0])
batch_result = await self.network.retriable_call(self.network.get_transaction_batch, batch, not unrestriced)
txs = {}
for txid, (raw, merkle) in batch_result.items():
remote_height = remote_heights[txid]
tx = Transaction(unhexlify(raw), height=remote_height)
txs[tx.id] = tx
await self.maybe_verify_transaction(tx, remote_height, merkle)
return txs
async def _sync_and_save_batch(self, address, remote_history, pending_txs):
await asyncio.gather(*(self._sync(tx, remote_history, pending_txs) for tx in pending_txs.values()))
await self.db.save_transaction_io_batch(
pending_txs.values(), address, self.address_to_hash160(address), ""
)
while pending_txs:
self._on_transaction_controller.add(TransactionEvent(address, pending_txs.popitem()[1]))
async def _sync(self, tx, remote_history, pending_txs):
check_db_for_txos = {}
for txi in tx.inputs:
if txi.txo_ref.txo is not None:
continue
wanted_txid = txi.txo_ref.tx_ref.id
if wanted_txid not in remote_history:
continue
if wanted_txid in pending_txs:
txi.txo_ref = pending_txs[wanted_txid].outputs[txi.txo_ref.position].ref
else:
check_db_for_txos[txi] = txi.txo_ref.id
referenced_txos = {} if not check_db_for_txos else {
txo.id: txo for txo in await self.db.get_txos(
txoid__in=list(check_db_for_txos.values()), order_by='txo.txoid', no_tx=True
)
}
for txi in check_db_for_txos:
if txi.txo_ref.id in referenced_txos:
txi.txo_ref = referenced_txos[txi.txo_ref.id].ref
else:
tx_from_db = await self.db.get_transaction(txid=txi.txo_ref.tx_ref.id)
if tx_from_db is None:
log.warning("%s not on db, not on cache, but on remote history!", txi.txo_ref.id)
else:
txi.txo_ref = tx_from_db.outputs[txi.txo_ref.position].ref
return tx
async def get_address_manager_for_address(self, address) -> Optional[AddressManager]: async def get_address_manager_for_address(self, address) -> Optional[AddressManager]:
details = await self.db.get_address(address=address) details = await self.db.get_address(address=address)
@ -722,21 +603,11 @@ class Ledger(metaclass=LedgerRegistry):
return account.address_managers[details['chain']] return account.address_managers[details['chain']]
return None return None
async def broadcast_or_release(self, tx, blocking=False):
try:
await self.broadcast(tx)
except:
await self.release_tx(tx)
raise
if blocking:
await self.wait(tx, timeout=None)
def broadcast(self, tx): def broadcast(self, tx):
# broadcast can't be a retriable call yet # broadcast can't be a retriable call yet
return self.network.broadcast(hexlify(tx.raw).decode()) return self.network.broadcast(hexlify(tx.raw).decode())
async def wait(self, tx: Transaction, height=-1, timeout=1): async def wait(self, tx: Transaction, height=-1, timeout=1):
timeout = timeout or 600 # after 10 minutes there is almost 0 hope
addresses = set() addresses = set()
for txi in tx.inputs: for txi in tx.inputs:
if txi.txo_ref.txo is not None: if txi.txo_ref.txo is not None:
@ -744,84 +615,42 @@ class Ledger(metaclass=LedgerRegistry):
self.hash160_to_address(txi.txo_ref.txo.pubkey_hash) self.hash160_to_address(txi.txo_ref.txo.pubkey_hash)
) )
for txo in tx.outputs: for txo in tx.outputs:
if txo.is_pubkey_hash: if txo.has_address:
addresses.add(self.hash160_to_address(txo.pubkey_hash)) addresses.add(self.hash160_to_address(txo.pubkey_hash))
elif txo.is_script_hash:
addresses.add(self.hash160_to_script_address(txo.script_hash))
start = int(time.perf_counter())
while timeout and (int(time.perf_counter()) - start) <= timeout:
if await self._wait_round(tx, height, addresses):
return
raise asyncio.TimeoutError(f'Timed out waiting for transaction. {tx.id}')
async def _wait_round(self, tx: Transaction, height: int, addresses: Iterable[str]):
records = await self.db.get_addresses(address__in=addresses) records = await self.db.get_addresses(address__in=addresses)
_, pending = await asyncio.wait([ _, pending = await asyncio.wait([
self.on_transaction.where(partial( self.on_transaction.where(partial(
lambda a, e: a == e.address and e.tx.height >= height and e.tx.id == tx.id, lambda a, e: a == e.address and e.tx.height >= height and e.tx.id == tx.id,
address_record['address'] address_record['address']
)) for address_record in records )) for address_record in records
], timeout=1) ], timeout=timeout)
if not pending: if pending:
return True
records = await self.db.get_addresses(address__in=addresses) records = await self.db.get_addresses(address__in=addresses)
for record in records: for record in records:
found = False
local_history = (await self.get_local_status_and_history( local_history = (await self.get_local_status_and_history(
record['address'], history=record['history'] record['address'], history=record['history']
))[1] if record['history'] else [] ))[1] if record['history'] else []
for txid, local_height in local_history: for txid, local_height in local_history:
if txid == tx.id: if txid == tx.id and local_height >= height:
if local_height >= height or (local_height == 0 and height > local_height): found = True
return True if not found:
log.warning( print(record['history'], addresses, tx.id)
"local history has higher height than remote for %s (%i vs %i)", txid, raise asyncio.TimeoutError('Timed out waiting for transaction.')
local_height, height
)
return False
log.warning(
"local history does not contain %s, requested height %i", tx.id, height
)
return False
async def _inflate_outputs( async def _inflate_outputs(self, query, accounts) -> Tuple[List[Output], dict, int, int]:
self, query, accounts,
include_purchase_receipt=False,
include_is_my_output=False,
include_sent_supports=False,
include_sent_tips=False,
include_received_tips=False) -> Tuple[List[Output], dict, int, int]:
encoded_outputs = await query encoded_outputs = await query
outputs = Outputs.from_base64(encoded_outputs or '') # TODO: why is the server returning None? outputs = Outputs.from_base64(encoded_outputs or b'') # TODO: why is the server returning None?
txs: List[Transaction] = [] txs = []
if len(outputs.txs) > 0: if len(outputs.txs) > 0:
async for tx in self.request_transactions(tuple(outputs.txs), cached=True): txs: List[Transaction] = await asyncio.gather(*(
txs.extend(tx.values()) self.cache_transaction(*tx) for tx in outputs.txs
))
_txos, blocked = outputs.inflate(txs) if accounts:
txos = []
for txo in _txos:
if isinstance(txo, Output):
# transactions and outputs are cached and shared between wallets
# we don't want to leak informaion between wallet so we add the
# wallet specific metadata on throw away copies of the txos
txo = copy.copy(txo)
channel = txo.channel
txo.purchase_receipt = None
txo.update_annotations(None)
txo.channel = channel
txos.append(txo)
includes = (
include_purchase_receipt, include_is_my_output,
include_sent_supports, include_sent_tips
)
if accounts and any(includes):
receipts = {}
if include_purchase_receipt:
priced_claims = [] priced_claims = []
for txo in txos: for tx in txs:
if isinstance(txo, Output) and txo.has_price: for txo in tx.outputs:
if txo.has_price:
priced_claims.append(txo) priced_claims.append(txo)
if priced_claims: if priced_claims:
receipts = { receipts = {
@ -831,54 +660,14 @@ class Ledger(metaclass=LedgerRegistry):
purchased_claim_id__in=[c.claim_id for c in priced_claims] purchased_claim_id__in=[c.claim_id for c in priced_claims]
) )
} }
for txo in txos: for txo in priced_claims:
if isinstance(txo, Output) and txo.can_decode_claim:
if include_purchase_receipt:
txo.purchase_receipt = receipts.get(txo.claim_id) txo.purchase_receipt = receipts.get(txo.claim_id)
if include_is_my_output: txos, blocked = outputs.inflate(txs)
mine = await self.db.get_txo_count(
claim_id=txo.claim_id, txo_type__in=CLAIM_TYPES, is_my_output=True,
is_spent=False, accounts=accounts
)
if mine:
txo.is_my_output = True
else:
txo.is_my_output = False
if include_sent_supports:
supports = await self.db.get_txo_sum(
claim_id=txo.claim_id, txo_type=TXO_TYPES['support'],
is_my_input=True, is_my_output=True,
is_spent=False, accounts=accounts
)
txo.sent_supports = supports
if include_sent_tips:
tips = await self.db.get_txo_sum(
claim_id=txo.claim_id, txo_type=TXO_TYPES['support'],
is_my_input=True, is_my_output=False,
accounts=accounts
)
txo.sent_tips = tips
if include_received_tips:
tips = await self.db.get_txo_sum(
claim_id=txo.claim_id, txo_type=TXO_TYPES['support'],
is_my_input=False, is_my_output=True,
accounts=accounts
)
txo.received_tips = tips
return txos, blocked, outputs.offset, outputs.total return txos, blocked, outputs.offset, outputs.total
async def resolve(self, accounts, urls, **kwargs): async def resolve(self, accounts, urls):
txos = []
urls_copy = list(urls)
resolve = partial(self.network.retriable_call, self.network.resolve) resolve = partial(self.network.retriable_call, self.network.resolve)
while urls_copy: txos = (await self._inflate_outputs(resolve(urls), accounts))[0]
batch, urls_copy = urls_copy[:100], urls_copy[100:]
txos.extend(
(await self._inflate_outputs(
resolve(batch), accounts, **kwargs
))[0]
)
assert len(urls) == len(txos), "Mismatch between urls requested for resolve and responses received." assert len(urls) == len(txos), "Mismatch between urls requested for resolve and responses received."
result = {} result = {}
for url, txo in zip(urls, txos): for url, txo in zip(urls, txos):
@ -891,35 +680,12 @@ class Ledger(metaclass=LedgerRegistry):
result[url] = txo result[url] = txo
return result return result
async def sum_supports(self, new_sdk_server, **kwargs) -> List[Dict]: async def claim_search(self, accounts, **kwargs) -> Tuple[List[Output], dict, int, int]:
return await self.network.sum_supports(new_sdk_server, **kwargs) return await self._inflate_outputs(self.network.claim_search(**kwargs), accounts)
async def claim_search( async def get_claim_by_claim_id(self, accounts, claim_id) -> Output:
self, accounts, for claim in (await self.claim_search(accounts, claim_id=claim_id))[0]:
include_purchase_receipt=False, return claim
include_is_my_output=False,
**kwargs) -> Tuple[List[Output], dict, int, int]:
return await self._inflate_outputs(
self.network.claim_search(**kwargs), accounts,
include_purchase_receipt=include_purchase_receipt,
include_is_my_output=include_is_my_output
)
# async def get_claim_by_claim_id(self, accounts, claim_id, **kwargs) -> Output:
# return await self.network.get_claim_by_id(claim_id)
async def get_claim_by_claim_id(self, claim_id, accounts=None, include_purchase_receipt=False,
include_is_my_output=False):
accounts = accounts or []
# return await self.network.get_claim_by_id(claim_id)
inflated = await self._inflate_outputs(
self.network.get_claim_by_id(claim_id), accounts,
include_purchase_receipt=include_purchase_receipt,
include_is_my_output=include_is_my_output,
)
txos = inflated[0]
if txos:
return txos[0]
async def _report_state(self): async def _report_state(self):
try: try:
@ -938,7 +704,9 @@ class Ledger(metaclass=LedgerRegistry):
"%d change addresses (gap: %d), %d channels, %d certificates and %d claims. ", "%d change addresses (gap: %d), %d channels, %d certificates and %d claims. ",
account.id, balance, total_receiving, account.receiving.gap, total_change, account.id, balance, total_receiving, account.receiving.gap, total_change,
account.change.gap, channel_count, len(account.channel_keys), claim_count) account.change.gap, channel_count, len(account.channel_keys), claim_count)
except Exception: except Exception as err:
if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise
log.exception( log.exception(
'Failed to display wallet state, please file issue ' 'Failed to display wallet state, please file issue '
'for this bug along with the traceback you see below:') 'for this bug along with the traceback you see below:')
@ -961,7 +729,9 @@ class Ledger(metaclass=LedgerRegistry):
claim_ids = [p.purchased_claim_id for p in purchases] claim_ids = [p.purchased_claim_id for p in purchases]
try: try:
resolved, _, _, _ = await self.claim_search([], claim_ids=claim_ids) resolved, _, _, _ = await self.claim_search([], claim_ids=claim_ids)
except Exception: except Exception as err:
if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise
log.exception("Resolve failed while looking up purchased claim ids:") log.exception("Resolve failed while looking up purchased claim ids:")
resolved = [] resolved = []
lookup = {claim.claim_id: claim for claim in resolved} lookup = {claim.claim_id: claim for claim in resolved}
@ -973,11 +743,6 @@ class Ledger(metaclass=LedgerRegistry):
return self.db.get_purchase_count(**constraints) return self.db.get_purchase_count(**constraints)
async def _resolve_for_local_results(self, accounts, txos): async def _resolve_for_local_results(self, accounts, txos):
txos = await self._resolve_for_local_claim_results(accounts, txos)
txos = await self._resolve_for_local_support_results(accounts, txos)
return txos
async def _resolve_for_local_claim_results(self, accounts, txos):
results = [] results = []
response = await self.resolve( response = await self.resolve(
accounts, [txo.permanent_url for txo in txos if txo.can_decode_claim] accounts, [txo.permanent_url for txo in txos if txo.can_decode_claim]
@ -993,23 +758,6 @@ class Ledger(metaclass=LedgerRegistry):
results.append(txo) results.append(txo)
return results return results
async def _resolve_for_local_support_results(self, accounts, txos):
channel_ids = set()
signed_support_txos = []
for txo in txos:
support = txo.can_decode_support
if support and support.signing_channel_id:
channel_ids.add(support.signing_channel_id)
signed_support_txos.append(txo)
if channel_ids:
channels = {
channel.claim_id: channel for channel in
(await self.claim_search(accounts, claim_ids=list(channel_ids)))[0]
}
for txo in signed_support_txos:
txo.channel = channels.get(txo.support.signing_channel_id)
return txos
async def get_claims(self, resolve=False, **constraints): async def get_claims(self, resolve=False, **constraints):
claims = await self.db.get_claims(**constraints) claims = await self.db.get_claims(**constraints)
if resolve: if resolve:
@ -1041,7 +789,9 @@ class Ledger(metaclass=LedgerRegistry):
claim_ids = collection.claim.collection.claims.ids[offset:page_size+offset] claim_ids = collection.claim.collection.claims.ids[offset:page_size+offset]
try: try:
resolve_results, _, _, _ = await self.claim_search([], claim_ids=claim_ids) resolve_results, _, _, _ = await self.claim_search([], claim_ids=claim_ids)
except Exception: except Exception as err:
if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise
log.exception("Resolve failed while looking up collection claim ids:") log.exception("Resolve failed while looking up collection claim ids:")
return [] return []
claims = [] claims = []
@ -1056,10 +806,8 @@ class Ledger(metaclass=LedgerRegistry):
claims.append(None) claims.append(None)
return claims return claims
async def get_collections(self, resolve_claims=0, resolve=False, **constraints): async def get_collections(self, resolve_claims=0, **constraints):
collections = await self.db.get_collections(**constraints) collections = await self.db.get_collections(**constraints)
if resolve:
collections = await self._resolve_for_local_results(constraints.get('accounts', []), collections)
if resolve_claims > 0: if resolve_claims > 0:
for collection in collections: for collection in collections:
collection.claims = await self.resolve_collection(collection, page_size=resolve_claims) collection.claims = await self.resolve_collection(collection, page_size=resolve_claims)
@ -1074,15 +822,12 @@ class Ledger(metaclass=LedgerRegistry):
def get_support_count(self, **constraints): def get_support_count(self, **constraints):
return self.db.get_support_count(**constraints) return self.db.get_support_count(**constraints)
async def get_transaction_history(self, read_only=False, **constraints): async def get_transaction_history(self, **constraints):
txs: List[Transaction] = await self.db.get_transactions( txs: List[Transaction] = await self.db.get_transactions(**constraints)
include_is_my_output=True, include_is_spent=True,
read_only=read_only, **constraints
)
headers = self.headers headers = self.headers
history = [] history = []
for tx in txs: # pylint: disable=too-many-nested-blocks for tx in txs: # pylint: disable=too-many-nested-blocks
ts = headers.estimated_timestamp(tx.height) ts = headers[tx.height]['timestamp'] if tx.height > 0 else None
item = { item = {
'txid': tx.id, 'txid': tx.id,
'timestamp': ts, 'timestamp': ts,
@ -1094,7 +839,7 @@ class Ledger(metaclass=LedgerRegistry):
'abandon_info': [], 'abandon_info': [],
'purchase_info': [] 'purchase_info': []
} }
is_my_inputs = all(txi.is_my_input for txi in tx.inputs) is_my_inputs = all([txi.is_my_account for txi in tx.inputs])
if is_my_inputs: if is_my_inputs:
# fees only matter if we are the ones paying them # fees only matter if we are the ones paying them
item['value'] = dewies_to_lbc(tx.net_account_balance+tx.fee) item['value'] = dewies_to_lbc(tx.net_account_balance+tx.fee)
@ -1187,8 +932,8 @@ class Ledger(metaclass=LedgerRegistry):
history.append(item) history.append(item)
return history return history
def get_transaction_history_count(self, read_only=False, **constraints): def get_transaction_history_count(self, **constraints):
return self.db.get_transaction_count(read_only=read_only, **constraints) return self.db.get_transaction_count(**constraints)
async def get_detailed_balance(self, accounts, confirmations=0): async def get_detailed_balance(self, accounts, confirmations=0):
result = { result = {
@ -1205,7 +950,7 @@ class Ledger(metaclass=LedgerRegistry):
balance = self._balance_cache.get(account.id) balance = self._balance_cache.get(account.id)
if not balance: if not balance:
balance = self._balance_cache[account.id] =\ balance = self._balance_cache[account.id] =\
await account.get_detailed_balance(confirmations) await account.get_detailed_balance(confirmations, reserved_subtotals=True)
for key, value in balance.items(): for key, value in balance.items():
if key == 'reserved_subtotals': if key == 'reserved_subtotals':
for subkey, subvalue in value.items(): for subkey, subvalue in value.items():
@ -1221,7 +966,6 @@ class TestNetLedger(Ledger):
script_address_prefix = bytes((196,)) script_address_prefix = bytes((196,))
extended_public_key_prefix = unhexlify('043587cf') extended_public_key_prefix = unhexlify('043587cf')
extended_private_key_prefix = unhexlify('04358394') extended_private_key_prefix = unhexlify('04358394')
checkpoints = {}
class RegTestLedger(Ledger): class RegTestLedger(Ledger):
@ -1236,4 +980,3 @@ class RegTestLedger(Ledger):
genesis_hash = '6e3fcf1299d4ec5d79c3a4c91d624a4acf9e2e173d95a1a0504f677669687556' genesis_hash = '6e3fcf1299d4ec5d79c3a4c91d624a4acf9e2e173d95a1a0504f677669687556'
genesis_bits = 0x207fffff genesis_bits = 0x207fffff
target_timespan = 1 target_timespan = 1
checkpoints = {}

View file

@ -3,21 +3,20 @@ import json
import typing import typing
import logging import logging
import asyncio import asyncio
from binascii import unhexlify from binascii import unhexlify
from decimal import Decimal from decimal import Decimal
from typing import List, Type, MutableSequence, MutableMapping, Optional from typing import List, Type, MutableSequence, MutableMapping, Optional
from lbry.error import KeyFeeAboveMaxAllowedError, WalletNotLoadedError from lbry.error import KeyFeeAboveMaxAllowedError
from lbry.conf import Config, NOT_SET from lbry.conf import Config
from lbry.wallet.dewies import dewies_to_lbc from .dewies import dewies_to_lbc
from lbry.wallet.account import Account from .account import Account
from lbry.wallet.ledger import Ledger, LedgerRegistry from .ledger import Ledger, LedgerRegistry
from lbry.wallet.transaction import Transaction, Output from .transaction import Transaction, Output
from lbry.wallet.database import Database from .database import Database
from lbry.wallet.wallet import Wallet, WalletStorage, ENCRYPT_ON_DISK from .wallet import Wallet, WalletStorage, ENCRYPT_ON_DISK
from lbry.wallet.rpc.jsonrpc import CodeMessageError from .rpc.jsonrpc import CodeMessageError
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager
@ -96,7 +95,7 @@ class WalletManager:
for wallet in self.wallets: for wallet in self.wallets:
if wallet.id == wallet_id: if wallet.id == wallet_id:
return wallet return wallet
raise WalletNotLoadedError(wallet_id) raise ValueError(f"Couldn't find wallet: {wallet_id}.")
@staticmethod @staticmethod
def get_balance(wallet): def get_balance(wallet):
@ -183,17 +182,9 @@ class WalletManager:
ledger_config = { ledger_config = {
'auto_connect': True, 'auto_connect': True,
'explicit_servers': [],
'hub_timeout': config.hub_timeout,
'default_servers': config.lbryum_servers, 'default_servers': config.lbryum_servers,
'known_hubs': config.known_hubs,
'jurisdiction': config.jurisdiction,
'concurrent_hub_requests': config.concurrent_hub_requests,
'data_path': config.wallet_dir, 'data_path': config.wallet_dir,
'tx_cache_size': config.transaction_cache_size
} }
if 'LBRY_FEE_PER_NAME_CHAR' in os.environ:
ledger_config['fee_per_name_char'] = int(os.environ.get('LBRY_FEE_PER_NAME_CHAR'))
wallets_directory = os.path.join(config.wallet_dir, 'wallets') wallets_directory = os.path.join(config.wallet_dir, 'wallets')
if not os.path.exists(wallets_directory): if not os.path.exists(wallets_directory):
@ -203,10 +194,6 @@ class WalletManager:
os.path.join(wallets_directory, 'default_wallet') os.path.join(wallets_directory, 'default_wallet')
) )
if Config.lbryum_servers.is_set_to_default(config):
with config.update_config() as c:
c.lbryum_servers = NOT_SET
manager = cls.from_config({ manager = cls.from_config({
'ledgers': {ledger_id: ledger_config}, 'ledgers': {ledger_id: ledger_config},
'wallets': [ 'wallets': [
@ -237,16 +224,9 @@ class WalletManager:
async def reset(self): async def reset(self):
self.ledger.config = { self.ledger.config = {
'auto_connect': True, 'auto_connect': True,
'explicit_servers': [], 'default_servers': self.config.lbryum_servers,
'default_servers': Config.lbryum_servers.default,
'known_hubs': self.config.known_hubs,
'jurisdiction': self.config.jurisdiction,
'hub_timeout': self.config.hub_timeout,
'concurrent_hub_requests': self.config.concurrent_hub_requests,
'data_path': self.config.wallet_dir, 'data_path': self.config.wallet_dir,
} }
if Config.lbryum_servers.is_set(self.config):
self.ledger.config['explicit_servers'] = self.config.lbryum_servers
await self.ledger.stop() await self.ledger.stop()
await self.ledger.start() await self.ledger.start()
@ -268,28 +248,26 @@ class WalletManager:
log.warning("Failed to migrate %s receiving addresses!", log.warning("Failed to migrate %s receiving addresses!",
len(set(receiving_addresses).difference(set(migrated_receiving)))) len(set(receiving_addresses).difference(set(migrated_receiving))))
async def get_best_blockhash(self): def get_best_blockhash(self):
if len(self.ledger.headers) <= 0: if len(self.ledger.headers) <= 0:
return self.ledger.genesis_hash return self.ledger.genesis_hash
return (await self.ledger.headers.hash(self.ledger.headers.height)).decode() return self.ledger.headers.hash(self.ledger.headers.height).decode()
def get_unused_address(self): def get_unused_address(self):
return self.default_account.receiving.get_or_create_usable_address() return self.default_account.receiving.get_or_create_usable_address()
async def get_transaction(self, txid: str): async def get_transaction(self, txid):
tx = await self.db.get_transaction(txid=txid) tx = await self.db.get_transaction(txid=txid)
if tx: if not tx:
return tx
try: try:
raw, merkle = await self.ledger.network.get_transaction_and_merkle(txid) raw = await self.ledger.network.get_transaction(txid)
height = await self.ledger.network.get_transaction_height(txid)
except CodeMessageError as e: except CodeMessageError as e:
if 'No such mempool or blockchain transaction.' in e.message: if 'No such mempool or blockchain transaction.' in e.message:
return {'success': False, 'code': 404, 'message': 'transaction not found'} return {'success': False, 'code': 404, 'message': 'transaction not found'}
return {'success': False, 'code': e.code, 'message': e.message} return {'success': False, 'code': e.code, 'message': e.message}
height = merkle.get('block_height') tx = Transaction(unhexlify(raw))
tx = Transaction(unhexlify(raw), height=height) await self.ledger.maybe_verify_transaction(tx, height)
if height and height > 0:
await self.ledger.maybe_verify_transaction(tx, height, merkle)
return tx return tx
async def create_purchase_transaction( async def create_purchase_transaction(
@ -317,4 +295,10 @@ class WalletManager:
) )
async def broadcast_or_release(self, tx, blocking=False): async def broadcast_or_release(self, tx, blocking=False):
await self.ledger.broadcast_or_release(tx, blocking=blocking) try:
await self.ledger.broadcast(tx)
if blocking:
await self.ledger.wait(tx, timeout=None)
except:
await self.ledger.release_tx(tx)
raise

View file

@ -1,40 +1,34 @@
import logging import logging
import asyncio import asyncio
import json import json
import socket
import random
from time import perf_counter from time import perf_counter
from collections import defaultdict from operator import itemgetter
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
import aiohttp
from lbry import __version__ from lbry import __version__
from lbry.utils import resolve_host
from lbry.error import IncompatibleWalletServerError from lbry.error import IncompatibleWalletServerError
from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError
from lbry.wallet.stream import StreamController from lbry.wallet.stream import StreamController
from lbry.wallet.udp import SPVStatusClientProtocol, SPVPong
from lbry.conf import KnownHubsList
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class ClientSession(BaseClientSession): class ClientSession(BaseClientSession):
def __init__(self, *args, network: 'Network', server, timeout=30, concurrency=32, **kwargs): def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs):
self.network = network self.network = network
self.server = server self.server = server
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._on_disconnect_controller = StreamController()
self.on_disconnected = self._on_disconnect_controller.stream
self.framer.max_size = self.max_errors = 1 << 32 self.framer.max_size = self.max_errors = 1 << 32
self.timeout = timeout self.timeout = timeout
self.max_seconds_idle = timeout * 2 self.max_seconds_idle = timeout * 2
self.response_time: Optional[float] = None self.response_time: Optional[float] = None
self.connection_latency: Optional[float] = None self.connection_latency: Optional[float] = None
self._response_samples = 0 self._response_samples = 0
self._concurrency = asyncio.Semaphore(concurrency) self.pending_amount = 0
self._on_connect_cb = on_connect_callback or (lambda: None)
@property self.trigger_urgent_reconnect = asyncio.Event()
def concurrency(self):
return self._concurrency._value
@property @property
def available(self): def available(self):
@ -60,9 +54,9 @@ class ClientSession(BaseClientSession):
return result return result
async def send_request(self, method, args=()): async def send_request(self, method, args=()):
log.debug("send %s%s to %s:%i (%i timeout)", method, tuple(args), self.server[0], self.server[1], self.timeout) self.pending_amount += 1
log.debug("send %s to %s:%i", method, *self.server)
try: try:
await self._concurrency.acquire()
if method == 'server.version': if method == 'server.version':
return await self.send_timed_server_version_request(args, self.timeout) return await self.send_timed_server_version_request(args, self.timeout)
request = asyncio.ensure_future(super().send_request(method, args)) request = asyncio.ensure_future(super().send_request(method, args))
@ -72,7 +66,7 @@ class ClientSession(BaseClientSession):
log.debug("Time since last packet: %s", perf_counter() - self.last_packet_received) log.debug("Time since last packet: %s", perf_counter() - self.last_packet_received)
if (perf_counter() - self.last_packet_received) < self.timeout: if (perf_counter() - self.last_packet_received) < self.timeout:
continue continue
log.warning("timeout sending %s to %s:%i", method, *self.server) log.info("timeout sending %s to %s:%i", method, *self.server)
raise asyncio.TimeoutError raise asyncio.TimeoutError
if done: if done:
try: try:
@ -92,11 +86,43 @@ class ClientSession(BaseClientSession):
self.synchronous_close() self.synchronous_close()
raise raise
except asyncio.CancelledError: except asyncio.CancelledError:
log.warning("cancelled sending %s to %s:%i", method, *self.server) log.info("cancelled sending %s to %s:%i", method, *self.server)
# self.synchronous_close() self.synchronous_close()
raise raise
finally: finally:
self._concurrency.release() self.pending_amount -= 1
async def ensure_session(self):
# Handles reconnecting and maintaining a session alive
# TODO: change to 'ping' on newer protocol (above 1.2)
retry_delay = default_delay = 1.0
while True:
try:
if self.is_closing():
await self.create_connection(self.timeout)
await self.ensure_server_version()
self._on_connect_cb()
if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None:
await self.ensure_server_version()
retry_delay = default_delay
except RPCError as e:
await self.close()
log.debug("Server error, ignoring for 1h: %s:%d -- %s", *self.server, e.message)
retry_delay = 60 * 60
except IncompatibleWalletServerError:
await self.close()
retry_delay = 60 * 60
log.debug("Wallet server has an incompatible version, retrying in 1h: %s:%d", *self.server)
except (asyncio.TimeoutError, OSError):
await self.close()
retry_delay = min(60, retry_delay * 2)
log.debug("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
try:
await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay)
except asyncio.TimeoutError:
pass
finally:
self.trigger_urgent_reconnect.clear()
async def ensure_server_version(self, required=None, timeout=3): async def ensure_server_version(self, required=None, timeout=3):
required = required or self.network.PROTOCOL_VERSION required = required or self.network.PROTOCOL_VERSION
@ -107,25 +133,6 @@ class ClientSession(BaseClientSession):
raise IncompatibleWalletServerError(*self.server) raise IncompatibleWalletServerError(*self.server)
return response return response
async def keepalive_loop(self, timeout=3, max_idle=60):
try:
while True:
now = perf_counter()
if min(self.last_send, self.last_packet_received) + max_idle < now:
await asyncio.wait_for(
self.send_request('server.ping', []), timeout=timeout
)
else:
await asyncio.sleep(max(0, max_idle - (now - self.last_send)))
except (Exception, asyncio.CancelledError) as err:
if isinstance(err, asyncio.CancelledError):
log.info("closing connection to %s:%i", *self.server)
else:
log.exception("lost connection to spv")
finally:
if not self.is_closing():
self._close()
async def create_connection(self, timeout=6): async def create_connection(self, timeout=6):
connector = Connector(lambda: self, *self.server) connector = Connector(lambda: self, *self.server)
start = perf_counter() start = perf_counter()
@ -142,23 +149,24 @@ class ClientSession(BaseClientSession):
self.response_time = None self.response_time = None
self.connection_latency = None self.connection_latency = None
self._response_samples = 0 self._response_samples = 0
# self._on_disconnect_controller.add(True) self.pending_amount = 0
if self.network: self._on_disconnect_controller.add(True)
self.network.disconnect()
class Network: class Network:
PROTOCOL_VERSION = __version__ PROTOCOL_VERSION = __version__
MINIMUM_REQUIRED = (0, 65, 0) MINIMUM_REQUIRED = (0, 59, 0)
def __init__(self, ledger): def __init__(self, ledger):
self.ledger = ledger self.ledger = ledger
self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
self.client: Optional[ClientSession] = None self.client: Optional[ClientSession] = None
self.server_features = None self.server_features = None
# self._switch_task: Optional[asyncio.Task] = None self._switch_task: Optional[asyncio.Task] = None
self.running = False self.running = False
self.remote_height: int = 0 self.remote_height: int = 0
self._concurrency = asyncio.Semaphore(16)
self._on_connected_controller = StreamController() self._on_connected_controller = StreamController()
self.on_connected = self._on_connected_controller.stream self.on_connected = self._on_connected_controller.stream
@ -169,255 +177,85 @@ class Network:
self._on_status_controller = StreamController(merge_repeated_events=True) self._on_status_controller = StreamController(merge_repeated_events=True)
self.on_status = self._on_status_controller.stream self.on_status = self._on_status_controller.stream
self._on_hub_controller = StreamController(merge_repeated_events=True)
self.on_hub = self._on_hub_controller.stream
self.subscription_controllers = { self.subscription_controllers = {
'blockchain.headers.subscribe': self._on_header_controller, 'blockchain.headers.subscribe': self._on_header_controller,
'blockchain.address.subscribe': self._on_status_controller, 'blockchain.address.subscribe': self._on_status_controller,
'blockchain.peers.subscribe': self._on_hub_controller,
} }
self.aiohttp_session: Optional[aiohttp.ClientSession] = None
self._urgent_need_reconnect = asyncio.Event()
self._loop_task: Optional[asyncio.Task] = None
self._keepalive_task: Optional[asyncio.Task] = None
@property @property
def config(self): def config(self):
return self.ledger.config return self.ledger.config
@property async def switch_forever(self):
def known_hubs(self):
if 'known_hubs' not in self.config:
return KnownHubsList()
return self.config['known_hubs']
@property
def jurisdiction(self):
return self.config.get("jurisdiction")
def disconnect(self):
if self._keepalive_task and not self._keepalive_task.done():
self._keepalive_task.cancel()
self._keepalive_task = None
async def start(self):
if not self.running:
self.running = True
self.aiohttp_session = aiohttp.ClientSession()
self.on_header.listen(self._update_remote_height)
self.on_hub.listen(self._update_hubs)
self._loop_task = asyncio.create_task(self.network_loop())
self._urgent_need_reconnect.set()
def loop_task_done_callback(f):
try:
f.result()
except (Exception, asyncio.CancelledError):
if self.running:
log.exception("wallet server connection loop crashed")
self._loop_task.add_done_callback(loop_task_done_callback)
async def resolve_spv_dns(self):
hostname_to_ip = {}
ip_to_hostnames = defaultdict(list)
async def resolve_spv(server, port):
try:
server_addr = await resolve_host(server, port, 'udp')
hostname_to_ip[server] = (server_addr, port)
ip_to_hostnames[(server_addr, port)].append(server)
except socket.error:
log.warning("error looking up dns for spv server %s:%i", server, port)
except Exception:
log.exception("error looking up dns for spv server %s:%i", server, port)
# accumulate the dns results
if self.config.get('explicit_servers', []):
hubs = self.config['explicit_servers']
elif self.known_hubs:
hubs = self.known_hubs
else:
hubs = self.config['default_servers']
await asyncio.gather(*(resolve_spv(server, port) for (server, port) in hubs))
return hostname_to_ip, ip_to_hostnames
async def get_n_fastest_spvs(self, timeout=3.0) -> Dict[Tuple[str, int], Optional[SPVPong]]:
loop = asyncio.get_event_loop()
pong_responses = asyncio.Queue()
connection = SPVStatusClientProtocol(pong_responses)
sent_ping_timestamps = {}
_, ip_to_hostnames = await self.resolve_spv_dns()
n = len(ip_to_hostnames)
log.info("%i possible spv servers to try (%i urls in config)", n, len(self.config.get('explicit_servers', [])))
pongs = {}
known_hubs = self.known_hubs
try:
await loop.create_datagram_endpoint(lambda: connection, ('0.0.0.0', 0))
# could raise OSError if it cant bind
start = perf_counter()
for server in ip_to_hostnames:
connection.ping(server)
sent_ping_timestamps[server] = perf_counter()
while len(pongs) < n:
(remote, ts), pong = await asyncio.wait_for(pong_responses.get(), timeout - (perf_counter() - start))
latency = ts - start
log.info("%s:%i has latency of %sms (available: %s, height: %i)",
'/'.join(ip_to_hostnames[remote]), remote[1], round(latency * 1000, 2),
pong.available, pong.height)
known_hubs.hubs.setdefault((ip_to_hostnames[remote][0], remote[1]), {}).update(
{"country": pong.country_name}
)
if pong.available:
pongs[(ip_to_hostnames[remote][0], remote[1])] = pong
return pongs
except asyncio.TimeoutError:
if pongs:
log.info("%i/%i probed spv servers are accepting connections", len(pongs), len(ip_to_hostnames))
return pongs
else:
log.warning("%i spv status probes failed, retrying later. servers tried: %s",
len(sent_ping_timestamps),
', '.join('/'.join(hosts) + f' ({ip})' for ip, hosts in ip_to_hostnames.items()))
random_server = random.choice(list(ip_to_hostnames.keys()))
host, port = random_server
log.warning("trying fallback to randomly selected spv: %s:%i", host, port)
known_hubs.hubs.setdefault((host, port), {})
return {(host, port): None}
finally:
connection.close()
async def connect_to_fastest(self) -> Optional[ClientSession]:
fastest_spvs = await self.get_n_fastest_spvs()
for (host, port), pong in fastest_spvs.items():
if (pong is not None and self.jurisdiction is not None) and \
(pong.country_name != self.jurisdiction):
continue
client = ClientSession(network=self, server=(host, port), timeout=self.config.get('hub_timeout', 30),
concurrency=self.config.get('concurrent_hub_requests', 30))
try:
await client.create_connection()
log.info("Connected to spv server %s:%i", host, port)
await client.ensure_server_version()
return client
except (asyncio.TimeoutError, ConnectionError, OSError, IncompatibleWalletServerError, RPCError):
log.warning("Connecting to %s:%d failed", host, port)
client._close()
return
async def network_loop(self):
sleep_delay = 30
while self.running: while self.running:
await asyncio.wait( if self.is_connected:
map(asyncio.create_task, [asyncio.sleep(30), self._urgent_need_reconnect.wait()]), await self.client.on_disconnected.first
return_when=asyncio.FIRST_COMPLETED self.server_features = None
) self.client = None
if self._urgent_need_reconnect.is_set():
sleep_delay = 30
self._urgent_need_reconnect.clear()
if not self.is_connected:
client = await self.connect_to_fastest()
if not client:
log.warning("failed to connect to any spv servers, retrying later")
sleep_delay *= 2
sleep_delay = min(sleep_delay, 300)
continue continue
log.debug("get spv server features %s:%i", *client.server) self.client = await self.session_pool.wait_for_fastest_session()
features = await client.send_request('server.features', []) log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
self.client, self.server_features = client, features try:
log.debug("discover other hubs %s:%i", *client.server) self.server_features = await self.get_server_features()
await self._update_hubs(await client.send_request('server.peers.subscribe', []))
log.info("subscribe to headers %s:%i", *client.server)
self._update_remote_height((await self.subscribe_headers(),)) self._update_remote_height((await self.subscribe_headers(),))
self._on_connected_controller.add(True) self._on_connected_controller.add(True)
server_str = "%s:%i" % client.server log.info("Subscribed to headers: %s:%d", *self.client.server)
log.info("maintaining connection to spv server %s", server_str) except (asyncio.TimeoutError, ConnectionError):
self._keepalive_task = asyncio.create_task(self.client.keepalive_loop()) log.info("Switching to %s:%d timed out, closing and retrying.", *self.client.server)
try: self.client.synchronous_close()
if not self._urgent_need_reconnect.is_set():
await asyncio.wait(
[self._keepalive_task, asyncio.create_task(self._urgent_need_reconnect.wait())],
return_when=asyncio.FIRST_COMPLETED
)
else:
await self._keepalive_task
if self._urgent_need_reconnect.is_set():
log.warning("urgent reconnect needed")
if self._keepalive_task and not self._keepalive_task.done():
self._keepalive_task.cancel()
except asyncio.CancelledError:
pass
finally:
self._keepalive_task = None
self.client = None
self.server_features = None self.server_features = None
log.info("connection lost to %s", server_str) self.client = None
log.info("network loop finished")
async def start(self):
self.running = True
self._switch_task = asyncio.ensure_future(self.switch_forever())
# this may become unnecessary when there are no more bugs found,
# but for now it helps understanding log reports
self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped."))
self.session_pool.start(self.config['default_servers'])
self.on_header.listen(self._update_remote_height)
async def stop(self): async def stop(self):
if self.running:
self.running = False self.running = False
self.disconnect() self._switch_task.cancel()
if self._loop_task and not self._loop_task.done(): self.session_pool.stop()
self._loop_task.cancel()
self._loop_task = None
if self.aiohttp_session:
await self.aiohttp_session.close()
self.aiohttp_session = None
@property @property
def is_connected(self): def is_connected(self):
return self.client and not self.client.is_closing() return self.client and not self.client.is_closing()
def rpc(self, list_or_method, args, restricted=True, session: Optional[ClientSession] = None): def rpc(self, list_or_method, args, restricted=True):
if session or self.is_connected: session = self.client if restricted else self.session_pool.fastest_session
session = session or self.client if session and not session.is_closing():
return session.send_request(list_or_method, args) return session.send_request(list_or_method, args)
else: else:
self._urgent_need_reconnect.set() self.session_pool.trigger_nodelay_connect()
raise ConnectionError("Attempting to send rpc request when connection is not available.") raise ConnectionError("Attempting to send rpc request when connection is not available.")
async def retriable_call(self, function, *args, **kwargs): async def retriable_call(self, function, *args, **kwargs):
async with self._concurrency:
while self.running: while self.running:
if not self.is_connected: if not self.is_connected:
log.warning("Wallet server unavailable, waiting for it to come back and retry.") log.warning("Wallet server unavailable, waiting for it to come back and retry.")
self._urgent_need_reconnect.set()
await self.on_connected.first await self.on_connected.first
await self.session_pool.wait_for_fastest_session()
try: try:
return await function(*args, **kwargs) return await function(*args, **kwargs)
except asyncio.TimeoutError: except asyncio.TimeoutError:
log.warning("Wallet server call timed out, retrying.") log.warning("Wallet server call timed out, retrying.")
except ConnectionError: except ConnectionError:
log.warning("connection error") pass
raise asyncio.CancelledError() # if we got here, we are shutting down raise asyncio.CancelledError() # if we got here, we are shutting down
def _update_remote_height(self, header_args): def _update_remote_height(self, header_args):
self.remote_height = header_args[0]["height"] self.remote_height = header_args[0]["height"]
async def _update_hubs(self, hubs):
if hubs and hubs != ['']:
try:
if self.known_hubs.add_hubs(hubs):
self.known_hubs.save()
except Exception:
log.exception("could not add hubs: %s", hubs)
def get_transaction(self, tx_hash, known_height=None): def get_transaction(self, tx_hash, known_height=None):
# use any server if its old, otherwise restrict to who gave us the history # use any server if its old, otherwise restrict to who gave us the history
restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10 restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.get', [tx_hash], restricted) return self.rpc('blockchain.transaction.get', [tx_hash], restricted)
def get_transaction_batch(self, txids, restricted=True):
# use any server if its old, otherwise restrict to who gave us the history
return self.rpc('blockchain.transaction.get_batch', txids, restricted)
def get_transaction_and_merkle(self, tx_hash, known_height=None):
# use any server if its old, otherwise restrict to who gave us the history
restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.info', [tx_hash], restricted)
def get_transaction_height(self, tx_hash, known_height=None): def get_transaction_height(self, tx_hash, known_height=None):
restricted = not known_height or 0 > known_height > self.remote_height - 10 restricted = not known_height or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted) return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted)
@ -442,13 +280,12 @@ class Network:
async def subscribe_address(self, address, *addresses): async def subscribe_address(self, address, *addresses):
addresses = list((address, ) + addresses) addresses = list((address, ) + addresses)
server_addr_and_port = self.client.server_address_and_port # on disconnect client will be None
try: try:
return await self.rpc('blockchain.address.subscribe', addresses, True) return await self.rpc('blockchain.address.subscribe', addresses, True)
except asyncio.TimeoutError: except asyncio.TimeoutError:
log.warning( log.warning(
"timed out subscribing to addresses from %s:%i", "timed out subscribing to addresses from %s:%i",
*server_addr_and_port *self.client.server_address_and_port
) )
# abort and cancel, we can't lose a subscription, it will happen again on reconnect # abort and cancel, we can't lose a subscription, it will happen again on reconnect
if self.client: if self.client:
@ -461,20 +298,103 @@ class Network:
def get_server_features(self): def get_server_features(self):
return self.rpc('server.features', (), restricted=True) return self.rpc('server.features', (), restricted=True)
# def get_claims_by_ids(self, claim_ids): def get_claims_by_ids(self, claim_ids):
# return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids) return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids)
def get_claim_by_id(self, claim_id): def resolve(self, urls):
return self.rpc('blockchain.claimtrie.getclaimbyid', [claim_id]) return self.rpc('blockchain.claimtrie.resolve', urls)
def resolve(self, urls, session_override=None): def claim_search(self, **kwargs):
return self.rpc('blockchain.claimtrie.resolve', urls, False, session_override) return self.rpc('blockchain.claimtrie.search', kwargs)
def claim_search(self, session_override=None, **kwargs):
return self.rpc('blockchain.claimtrie.search', kwargs, False, session_override)
async def sum_supports(self, server, **kwargs): class SessionPool:
message = {"method": "support_sum", "params": kwargs}
async with self.aiohttp_session.post(server, json=message) as r: def __init__(self, network: Network, timeout: float):
result = await r.json() self.network = network
return result['result'] self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
self.timeout = timeout
self.new_connection_event = asyncio.Event()
@property
def online(self):
return any(not session.is_closing() for session in self.sessions)
@property
def available_sessions(self):
return (session for session in self.sessions if session.available)
@property
def fastest_session(self):
if not self.online:
return None
return min(
[((session.response_time + session.connection_latency) * (session.pending_amount + 1), session)
for session in self.available_sessions] or [(0, None)],
key=itemgetter(0)
)[1]
def _get_session_connect_callback(self, session: ClientSession):
loop = asyncio.get_event_loop()
def callback():
duplicate_connections = [
s for s in self.sessions
if s is not session and s.server_address_and_port == session.server_address_and_port
]
already_connected = None if not duplicate_connections else duplicate_connections[0]
if already_connected:
self.sessions.pop(session).cancel()
session.synchronous_close()
log.debug("wallet server %s resolves to the same server as %s, rechecking in an hour",
session.server[0], already_connected.server[0])
loop.call_later(3600, self._connect_session, session.server)
return
self.new_connection_event.set()
log.info("connected to %s:%i", *session.server)
return callback
def _connect_session(self, server: Tuple[str, int]):
session = None
for s in self.sessions:
if s.server == server:
session = s
break
if not session:
session = ClientSession(
network=self.network, server=server
)
session._on_connect_cb = self._get_session_connect_callback(session)
task = self.sessions.get(session, None)
if not task or task.done():
task = asyncio.create_task(session.ensure_session())
task.add_done_callback(lambda _: self.ensure_connections())
self.sessions[session] = task
def start(self, default_servers):
for server in default_servers:
self._connect_session(server)
def stop(self):
for session, task in self.sessions.items():
task.cancel()
session.synchronous_close()
self.sessions.clear()
def ensure_connections(self):
for session in self.sessions:
self._connect_session(session.server)
def trigger_nodelay_connect(self):
# used when other parts of the system sees we might have internet back
# bypasses the retry interval
for session in self.sessions:
session.trigger_urgent_reconnect.set()
async def wait_for_fastest_session(self):
while not self.fastest_session:
self.trigger_nodelay_connect()
self.new_connection_event.clear()
await self.new_connection_event.wait()
return self.fastest_session

View file

@ -1,2 +1,2 @@
from lbry.wallet.orchstr8.node import Conductor from .node import Conductor
from lbry.wallet.orchstr8.service import ConductorService from .service import ConductorService

View file

@ -5,9 +5,7 @@ import aiohttp
from lbry import wallet from lbry import wallet
from lbry.wallet.orchstr8.node import ( from lbry.wallet.orchstr8.node import (
Conductor, Conductor, get_blockchain_node_from_ledger
get_lbcd_node_from_ledger,
get_lbcwallet_node_from_ledger
) )
from lbry.wallet.orchstr8.service import ConductorService from lbry.wallet.orchstr8.service import ConductorService
@ -18,11 +16,10 @@ def get_argument_parser():
) )
subparsers = parser.add_subparsers(dest='command', help='sub-command help') subparsers = parser.add_subparsers(dest='command', help='sub-command help')
subparsers.add_parser("download", help="Download lbcd and lbcwallet node binaries.") subparsers.add_parser("download", help="Download blockchain node binary.")
start = subparsers.add_parser("start", help="Start orchstr8 service.") start = subparsers.add_parser("start", help="Start orchstr8 service.")
start.add_argument("--lbcd", help="Hostname to start lbcd node.") start.add_argument("--blockchain", help="Hostname to start blockchain node.")
start.add_argument("--lbcwallet", help="Hostname to start lbcwallet node.")
start.add_argument("--spv", help="Hostname to start SPV server.") start.add_argument("--spv", help="Hostname to start SPV server.")
start.add_argument("--wallet", help="Hostname to start wallet daemon.") start.add_argument("--wallet", help="Hostname to start wallet daemon.")
@ -50,8 +47,7 @@ def main():
if command == 'download': if command == 'download':
logging.getLogger('blockchain').setLevel(logging.INFO) logging.getLogger('blockchain').setLevel(logging.INFO)
get_lbcd_node_from_ledger(wallet).ensure() get_blockchain_node_from_ledger(wallet).ensure()
get_lbcwallet_node_from_ledger(wallet).ensure()
elif command == 'generate': elif command == 'generate':
loop.run_until_complete(run_remote_command( loop.run_until_complete(run_remote_command(
@ -61,12 +57,9 @@ def main():
elif command == 'start': elif command == 'start':
conductor = Conductor() conductor = Conductor()
if getattr(args, 'lbcd', False): if getattr(args, 'blockchain', False):
conductor.lbcd_node.hostname = args.lbcd conductor.blockchain_node.hostname = args.blockchain
loop.run_until_complete(conductor.start_lbcd()) loop.run_until_complete(conductor.start_blockchain())
if getattr(args, 'lbcwallet', False):
conductor.lbcwallet_node.hostname = args.lbcwallet
loop.run_until_complete(conductor.start_lbcwallet())
if getattr(args, 'spv', False): if getattr(args, 'spv', False):
conductor.spv_node.hostname = args.spv conductor.spv_node.hostname = args.spv
loop.run_until_complete(conductor.start_spv()) loop.run_until_complete(conductor.start_spv())

Some files were not shown because too many files have changed in this diff Show more