This commit is contained in:
Lex Berezhny 2020-06-05 00:35:22 -04:00
parent ffecd02fbc
commit 3ef83febc0
39 changed files with 1905 additions and 1698 deletions

View file

@ -2,61 +2,69 @@ name: ci
on: push
jobs:
# lint:
# name: lint
# runs-on: ubuntu-latest
# steps:
# - uses: actions/checkout@v1
# - uses: actions/setup-python@v1
# with:
# python-version: '3.7'
# - run: pip install -e .[lint]
# - run: make lint
#
# tests-unit:
# name: "tests / unit"
# strategy:
# matrix:
# os:
# - ubuntu-latest
# - macos-latest
# - windows-latest
# runs-on: ${{ matrix.os }}
# steps:
# - uses: actions/checkout@v1
# - uses: actions/setup-python@v1
# with:
# python-version: '3.7'
# - name: set pip cache dir
# id: pip-cache
# run: echo "::set-output name=dir::$(pip cache dir)"
# - name: extract pip cache
# uses: actions/cache@v2
# with:
# path: ${{ steps.pip-cache.outputs.dir }}
# key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}
# restore-keys: ${{ runner.os }}-pip-
# - run: |
# pip install --user --upgrade pip wheel
# pip install -e .[test]
# - working-directory: lbry
# env:
# HOME: /tmp
# run: coverage run -p --source=lbry -m unittest -vv tests.unit.test_conf
lint:
name: lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: actions/setup-python@v1
with:
python-version: '3.7'
- name: extract pip cache
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}
restore-keys: ${{ runner.os }}-pip-
- run: |
pip install --user --upgrade pip wheel
pip install -e .[lint]
- run: make lint
tests-unit:
name: "tests / unit"
strategy:
matrix:
os:
- ubuntu-latest
- macos-latest
- windows-latest
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v1
- uses: actions/setup-python@v1
with:
python-version: '3.7'
- name: set pip cache dir
id: pip-cache
run: echo "::set-output name=dir::$(pip cache dir)"
- name: extract pip cache
uses: actions/cache@v2
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}
restore-keys: ${{ runner.os }}-pip-
- run: |
pip install --user --upgrade pip wheel
pip install -e .[test]
- working-directory: lbry
env:
HOME: /tmp
run: coverage run -p --source=lbry -m unittest -vv tests.unit.test_conf
# run: coverage run -p --source=lbry -m unittest discover -vv tests.unit
#
# tests-integration:
# name: "tests / integration"
# runs-on: ubuntu-latest
# strategy:
# matrix:
# test:
tests-integration:
name: "tests / integration"
runs-on: ubuntu-latest
strategy:
matrix:
test:
# - datanetwork
# - blockchain
- blockchain
# - other
# db:
db:
- sqlite
# - postgres
# - sqlite
# services:
# postgres:
# image: postgres:12
@ -67,22 +75,22 @@ jobs:
# ports:
# - 5432:5432
# options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
# steps:
# - uses: actions/checkout@v1
# - uses: actions/setup-python@v1
# with:
# python-version: '3.7'
# - if: matrix.test == 'other'
# run: |
# sudo apt-get update
# sudo apt-get install -y --no-install-recommends ffmpeg
# - run: pip install tox-travis
# - env:
# TEST_DB: ${{ matrix.db }}
# run: tox -e ${{ matrix.test }}
#
steps:
- uses: actions/checkout@v1
- uses: actions/setup-python@v1
with:
python-version: '3.7'
- if: matrix.test == 'other'
run: |
sudo apt-get update
sudo apt-get install -y --no-install-recommends ffmpeg
- run: pip install tox-travis
- env:
TEST_DB: ${{ matrix.db }}
run: tox -e ${{ matrix.test }}
build:
#needs: ["lint", "tests-unit", "tests-integration"]
needs: ["lint", "tests-unit", "tests-integration"]
name: "build"
strategy:
matrix:

View file

@ -1,9 +1,8 @@
import struct
from typing import Set
from binascii import unhexlify
from typing import NamedTuple, List
from chiabip158 import PyBIP158
from chiabip158 import PyBIP158 # pylint: disable=no-name-in-module
from lbry.crypto.hash import double_sha256
from lbry.blockchain.transaction import Transaction

View file

@ -60,27 +60,33 @@ class BlockchainDB:
async def execute_fetchall(self, sql: str, *args):
return await self.run_in_executor(self.sync_execute_fetchall, sql, *args)
def sync_get_block_files(self, above_height=-1):
return self.sync_execute_fetchall(
"""
def sync_get_block_files(self, file_number=None, above_height=None):
sql = """
SELECT
file as file_number,
COUNT(hash) as blocks,
SUM(txcount) as txs,
MAX(height) as max_height
FROM block_info WHERE height > ? GROUP BY file ORDER BY file ASC;
""", (above_height,)
)
FROM block_info
WHERE status&1 AND status&4
"""
args = ()
if file_number is not None and above_height is not None:
sql += "AND file = ? AND height > ?"
args = (file_number, above_height)
return self.sync_execute_fetchall(sql + " GROUP BY file ORDER BY file ASC;", args)
async def get_block_files(self, above_height=-1):
return await self.run_in_executor(self.sync_get_block_files, above_height)
async def get_block_files(self, file_number=None, above_height=None):
return await self.run_in_executor(
self.sync_get_block_files, file_number, above_height
)
def sync_get_blocks_in_file(self, block_file, above_height=-1):
return self.sync_execute_fetchall(
"""
SELECT datapos as data_offset, height, hash as block_hash, txCount as txs
FROM block_info
WHERE file = ? AND height > ? AND status&1 > 0
WHERE file = ? AND height > ? AND status&1 AND status&4
ORDER BY datapos ASC;
""", (block_file, above_height)
)

View file

@ -38,14 +38,14 @@ class Process(asyncio.SubprocessProtocol):
def __init__(self):
self.ready = asyncio.Event()
self.stopped = asyncio.Event()
self.log = log.getChild('blockchain')
def pipe_data_received(self, fd, data):
if self.log and not any(ignore in data for ignore in self.IGNORE_OUTPUT):
if not any(ignore in data for ignore in self.IGNORE_OUTPUT):
if b'Error:' in data:
self.log.error(data.decode())
log.error(data.decode())
else:
self.log.info(data.decode())
for line in data.decode().splitlines():
log.debug(line.rstrip())
if b'Error:' in data:
self.ready.set()
raise SystemError(data.decode())

View file

@ -1,98 +1,98 @@
import os
import asyncio
import logging
import multiprocessing as mp
from contextvars import ContextVar
from typing import Tuple, Optional
from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
from typing import Optional
from sqlalchemy import func, bindparam
from sqlalchemy.future import select
from lbry.event import EventController, BroadcastSubscription, EventQueuePublisher
from lbry.event import BroadcastSubscription
from lbry.service.base import Sync, BlockEvent
from lbry.db import Database, queries, TXO_TYPES
from lbry.db.tables import Claim, Claimtrie, TX, TXO, TXI, Block as BlockTable
from lbry.db.tables import Claim, Claimtrie, TXO, TXI, Block as BlockTable
from lbry.db.query_context import progress, context, Event
from lbry.db.utils import chunk
from .lbrycrd import Lbrycrd
from .block import Block, create_block_filter, get_block_filter
from .block import Block, create_block_filter
from .bcd_data_stream import BCDataStream
from .ledger import Ledger
log = logging.getLogger(__name__)
_context: ContextVar[Tuple[Lbrycrd, mp.Queue, mp.Event, int]] = ContextVar('ctx')
_chain: ContextVar[Lbrycrd] = ContextVar('chain')
def ctx():
return _context.get()
def initialize(url: str, ledger: Ledger, progress: mp.Queue, stop: mp.Event, track_metrics: bool):
chain = Lbrycrd(ledger)
def get_or_initialize_lbrycrd(ctx=None) -> Lbrycrd:
chain = _chain.get(None)
if chain is not None:
return chain
chain = Lbrycrd((ctx or context()).ledger)
chain.db.sync_open()
_context.set((chain, progress, stop, os.getpid()))
queries.initialize(url=url, ledger=ledger, track_metrics=track_metrics)
_chain.set(chain)
return chain
PARSING = 1
SAVING = 2
PROCESSED = 3
FINISHED = 4
def process_block_file(block_file_number, current_height):
ctx = context()
chain = get_or_initialize_lbrycrd(ctx)
stop = ctx.stop_event
loader = ctx.get_bulk_loader()
def process_block_file(block_file_number):
chain, progress, stop, pid = ctx()
block_file_path = chain.get_block_file_path_from_number(block_file_number)
current_height = queries.get_best_height()
with progress(Event.BLOCK_READ, 100) as p:
new_blocks = chain.db.sync_get_blocks_in_file(block_file_number, current_height)
if not new_blocks:
return -1
num = 0
total = len(new_blocks)
progress.put_nowait((PARSING, pid, block_file_number, num, total))
collector = queries.RowCollector(queries.ctx())
last_block_processed = -1
done, total, last_block_processed = 0, len(new_blocks), -1
block_file_path = chain.get_block_file_path_from_number(block_file_number)
p.start(total, {'block_file': block_file_number})
with open(block_file_path, 'rb') as fp:
stream = BCDataStream(fp=fp)
for num, block_info in enumerate(new_blocks, start=1):
for done, block_info in enumerate(new_blocks, start=1):
if stop.is_set():
return -1
block_height = block_info['height']
fp.seek(block_info['data_offset'])
block = Block.from_data_stream(stream, block_height, block_file_number)
collector.add_block(block)
loader.add_block(block)
last_block_processed = block_height
if num % 100 == 0:
progress.put_nowait((PARSING, pid, block_file_number, num, total))
progress.put_nowait((PARSING, pid, block_file_number, num, total))
collector.save(
lambda remaining, total: progress.put_nowait(
(SAVING, pid, block_file_number, remaining, total)
)
)
progress.put((PROCESSED, pid, block_file_number))
p.step(done)
with progress(Event.BLOCK_SAVE) as p:
p.extra = {'block_file': block_file_number}
loader.save()
return last_block_processed
def process_claimtrie():
execute = queries.ctx().execute
chain, progress, stop, _ = ctx()
def process_claimtrie(heights):
chain = get_or_initialize_lbrycrd()
execute(Claimtrie.delete())
for record in chain.db.sync_get_claimtrie():
execute(
Claimtrie.insert(), {
'normalized': record['normalized'],
'claim_hash': record['claim_hash'],
'last_take_over_height': record['last_take_over_height'],
}
with progress(Event.TRIE_DELETE) as p:
p.start(1)
p.ctx.execute(Claimtrie.delete())
with progress(Event.TRIE_UPDATE) as p, context().connection.begin():
trie = chain.db.sync_get_claimtrie()
p.start(len(trie))
done = 0
for chunk_size, chunk_rows in chunk(trie, 10000):
p.ctx.execute(
Claimtrie.insert(), [{
'normalized': r['normalized'],
'claim_hash': r['claim_hash'],
'last_take_over_height': r['last_take_over_height'],
} for r in chunk_rows]
)
done += chunk_size
p.step(done)
best_height = queries.get_best_height()
for record in chain.db.sync_get_claims():
execute(
with progress(Event.CLAIM_UPDATE, 250) as p, context().connection.begin():
claims = chain.db.sync_get_claims()
p.start(len(claims))
done = 0
for record in claims:
p.ctx.execute(
Claim.update()
.where(Claim.c.claim_hash == record['claim_hash'])
.values(
@ -100,11 +100,14 @@ def process_claimtrie():
expiration_height=record['expiration_height']
)
)
done += 1
p.step(done)
with context("effective amount update") as ctx:
support = TXO.alias('support')
effective_amount_update = (
Claim.update()
.where(Claim.c.activation_height <= best_height)
.where(Claim.c.activation_height <= heights[-1])
.values(
effective_amount=(
select(func.coalesce(func.sum(support.c.amount), 0) + Claim.c.amount)
@ -116,28 +119,26 @@ def process_claimtrie():
)
)
)
execute(effective_amount_update)
ctx.execute(effective_amount_update)
def process_block_and_tx_filters():
context = queries.ctx()
execute = context.execute
ledger = context.ledger
with context("effective amount update") as ctx:
blocks = []
all_filters = []
all_addresses = []
for block in queries.get_blocks_without_filters():
addresses = {
ledger.address_to_hash160(r['address'])
ctx.ledger.address_to_hash160(r['address'])
for r in queries.get_block_tx_addresses(block_hash=block['block_hash'])
}
all_addresses.extend(addresses)
block_filter = create_block_filter(addresses)
all_filters.append(block_filter)
blocks.append({'pk': block['block_hash'], 'block_filter': block_filter})
filters = [get_block_filter(f) for f in all_filters]
execute(BlockTable.update().where(BlockTable.c.block_hash == bindparam('pk')), blocks)
# filters = [get_block_filter(f) for f in all_filters]
ctx.execute(BlockTable.update().where(BlockTable.c.block_hash == bindparam('pk')), blocks)
# txs = []
# for tx in queries.get_transactions_without_filters():
@ -148,79 +149,17 @@ def process_block_and_tx_filters():
# execute(TX.update().where(TX.c.tx_hash == bindparam('pk')), txs)
class SyncMessageToEvent(EventQueuePublisher):
def message_to_event(self, message):
if message[0] == PARSING:
event = "blockchain.sync.parsing"
elif message[0] == SAVING:
event = "blockchain.sync.saving"
elif message[0] == PROCESSED:
return {
"event": "blockchain.sync.processed",
"data": {"pid": message[1], "block_file": message[2]}
}
elif message[0] == FINISHED:
return {
'event': 'blockchain.sync.finish',
'data': {'finished_height': message[1]}
}
else:
raise ValueError("Unknown message type.")
return {
"event": event,
"data": {
"pid": message[1],
"block_file": message[2],
"step": message[3],
"total": message[4]
}
}
class BlockchainSync(Sync):
def __init__(self, chain: Lbrycrd, db: Database, processes=-1):
def __init__(self, chain: Lbrycrd, db: Database):
super().__init__(chain.ledger, db)
self.chain = chain
self.message_queue = mp.Queue()
self.stop_event = mp.Event()
self.on_block_subscription: Optional[BroadcastSubscription] = None
self.advance_loop_task: Optional[asyncio.Task] = None
self.advance_loop_event = asyncio.Event()
self._on_progress_controller = EventController()
self.on_progress = self._on_progress_controller.stream
self.progress_publisher = SyncMessageToEvent(
self.message_queue, self._on_progress_controller
)
self.track_metrics = False
self.processes = self._normalize_processes(processes)
self.executor = self._create_executor()
@staticmethod
def _normalize_processes(processes):
if processes == 0:
return os.cpu_count()
elif processes > 0:
return processes
return 1
def _create_executor(self) -> Executor:
args = dict(
initializer=initialize,
initargs=(
self.db.url, self.chain.ledger,
self.message_queue, self.stop_event,
self.track_metrics
)
)
if self.processes > 1:
return ProcessPoolExecutor(max_workers=self.processes, **args)
else:
return ThreadPoolExecutor(max_workers=1, **args)
async def start(self):
self.progress_publisher.start()
# initial advance as task so that it can be stop()'ed before finishing
self.advance_loop_task = asyncio.create_task(self.advance())
await self.advance_loop_task
self.chain.subscribe()
@ -233,74 +172,87 @@ class BlockchainSync(Sync):
self.chain.unsubscribe()
if self.on_block_subscription is not None:
self.on_block_subscription.cancel()
self.stop_event.set()
self.db.stop_event.set()
self.advance_loop_task.cancel()
self.progress_publisher.stop()
self.executor.shutdown()
async def run(self, f, *args):
return await asyncio.get_running_loop().run_in_executor(
self.db.executor, f, *args
)
async def load_blocks(self):
tasks = []
best_height = await self.db.get_best_height()
starting_height = None
tx_count = block_count = ending_height = 0
#for file in (await self.chain.db.get_block_files(best_height))[:1]:
for file in await self.chain.db.get_block_files(best_height):
for file in await self.chain.db.get_block_files():
# block files may be read and saved out of order, need to check
# each file individually to see if we have missing blocks
current_height = await self.db.get_best_height_for_file(file['file_number'])
if current_height == file['max_height']:
# we have all blocks in this file, skipping
continue
if -1 < current_height < file['max_height']:
# we have some blocks, need to figure out what we're missing
# call get_block_files again limited to this file and current_height
file = (await self.chain.db.get_block_files(
file_number=file['file_number'], above_height=current_height
))[0]
tx_count += file['txs']
block_count += file['blocks']
starting_height = min(
current_height if starting_height is None else starting_height, current_height
)
ending_height = max(ending_height, file['max_height'])
tasks.append(asyncio.get_running_loop().run_in_executor(
self.executor, process_block_file, file['file_number']
))
tasks.append(self.run(process_block_file, file['file_number'], current_height))
if not tasks:
return None
await self._on_progress_controller.add({
'event': 'blockchain.sync.start',
'data': {
'starting_height': best_height,
'ending_height': ending_height,
'files': len(tasks),
'blocks': block_count,
'txs': tx_count
"event": "blockchain.sync.start",
"data": {
"starting_height": starting_height,
"ending_height": ending_height,
"files": len(tasks),
"blocks": block_count,
"txs": tx_count
}
})
done, pending = await asyncio.wait(
tasks, return_when=asyncio.FIRST_EXCEPTION
)
if pending:
self.stop_event.set()
self.db.stop_event.set()
for future in pending:
future.cancel()
return None
best_height_processed = max(f.result() for f in done)
# putting event in queue instead of add to progress_controller because
# we want this message to appear after all of the queued messages from workers
self.message_queue.put((FINISHED, best_height_processed))
return best_height_processed
async def process_claims(self):
await asyncio.get_event_loop().run_in_executor(
self.executor, queries.process_claims_and_supports
)
async def process_block_and_tx_filters(self):
await asyncio.get_event_loop().run_in_executor(
self.executor, process_block_and_tx_filters
)
async def process_claimtrie(self):
await asyncio.get_event_loop().run_in_executor(
self.executor, process_claimtrie
)
async def post_process(self):
await self.process_claims()
if self.conf.spv_address_filters:
await self.process_block_and_tx_filters()
await self.process_claimtrie()
self.db.message_queue.put((
Event.BLOCK_DONE.value, os.getpid(),
len(done), len(tasks),
{"best_height_processed": best_height_processed}
))
return starting_height, best_height_processed
async def advance(self):
best_height = await self.load_blocks()
await self.post_process()
await self._on_block_controller.add(BlockEvent(best_height))
heights = await self.load_blocks()
if heights and heights[0] < heights[-1]:
await self.db.process_inputs(heights)
await self.db.process_claims(heights)
await self.db.process_supports(heights)
await self.run(process_claimtrie, heights)
if self.conf.spv_address_filters:
await self.run(process_block_and_tx_filters, heights)
await self._on_block_controller.add(BlockEvent(heights[1]))
async def advance_loop(self):
while True:
await self.advance_loop_event.wait()
self.advance_loop_event.clear()
try:
await self.advance()
except asyncio.CancelledError:
return
except Exception as e:
log.exception(e)
await self.stop()

View file

@ -30,7 +30,8 @@ async def add_block_to_lbrycrd(chain: Lbrycrd, block: Block, takeovers: List[str
async def insert_claim(chain, block, tx, txo):
await chain.db.execute("""
await chain.db.execute(
"""
INSERT OR REPLACE INTO claim (
claimID, name, nodeName, txID, txN, originalHeight, updateHeight, validHeight,
activationHeight, expirationHeight, amount

View file

@ -122,6 +122,7 @@ class Input(InputOutput):
NULL_SIGNATURE = b'\x00'*72
NULL_PUBLIC_KEY = b'\x00'*33
NULL_HASH32 = b'\x00'*32
__slots__ = 'txo_ref', 'sequence', 'coinbase', 'script'
@ -144,6 +145,12 @@ class Input(InputOutput):
script = InputScript.redeem_pubkey_hash(cls.NULL_SIGNATURE, cls.NULL_PUBLIC_KEY)
return cls(txo.ref, script)
@classmethod
def create_coinbase(cls) -> 'Input':
tx_ref = TXRefImmutable.from_hash(cls.NULL_HASH32, 0)
txo_ref = TXORef(tx_ref, 0)
return cls(txo_ref, b'beef')
@property
def amount(self) -> int:
""" Amount this input adds to the transaction. """
@ -513,6 +520,9 @@ class Transaction:
if raw is not None:
self.deserialize()
def __repr__(self):
return f"TX({self.id[:10]}...{self.id[-10:]})"
@property
def is_broadcast(self):
return self.height > -2

View file

@ -1,10 +1,14 @@
import os
import sys
import asyncio
import pathlib
import argparse
from docopt import docopt
from lbry import __version__
from lbry.conf import Config, CLIConfig
from lbry.service import API, Daemon
from lbry.service import Daemon
from lbry.service.metadata import interface
from lbry.service.full_node import FullNode
from lbry.blockchain.ledger import Ledger
@ -160,6 +164,44 @@ def ensure_directory_exists(path: str):
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
async def execute_command(conf, method, params):
pass
def normalize_value(x, key=None):
if not isinstance(x, str):
return x
if key in ('uri', 'channel_name', 'name', 'file_name', 'claim_name', 'download_directory'):
return x
if x.lower() == 'true':
return True
if x.lower() == 'false':
return False
if x.isdigit():
return int(x)
return x
def remove_brackets(key):
if key.startswith("<") and key.endswith(">"):
return str(key[1:-1])
return key
def set_kwargs(parsed_args):
kwargs = {}
for key, arg in parsed_args.items():
if arg is None:
continue
k = None
if key.startswith("--") and remove_brackets(key[2:]) not in kwargs:
k = remove_brackets(key[2:])
elif remove_brackets(key) not in kwargs:
k = remove_brackets(key)
kwargs[k] = normalize_value(arg, k)
return kwargs
def main(argv=None):
argv = argv or sys.argv[1:]
parser = get_argument_parser()
@ -170,13 +212,12 @@ def main(argv=None):
ensure_directory_exists(directory)
if args.cli_version:
from lbry import __version__
print(f"lbrynet {__version__}")
elif args.command == 'start':
if args.help:
args.start_parser.print_help()
elif args.full_node:
service = FullNode(Ledger(conf), conf.db_url_or_default)
service = FullNode(Ledger(conf))
if conf.console == "advanced":
console = AdvancedConsole(service)
else:

View file

@ -9,7 +9,7 @@ from lbry.service.full_node import FullNode
from lbry.service.light_client import LightClient
class Console: # silent
class Console:
def __init__(self, service: Service):
self.service = service
@ -35,101 +35,78 @@ class Basic(Console):
elif isinstance(self.service, LightClient):
s.append('Light Client')
if conf.processes == -1:
s.append(f'Single Process')
s.append(f'Threads Only')
elif conf.processes == 0:
s.append(f'{os.cpu_count()} Process(es)')
else:
s.append(f'{conf.processes} Processes')
s.append(f'{conf.processes} Process(es)')
s.append(f'({os.cpu_count()} CPU(s) available)')
print(' '.join(s))
def stopping(self):
@staticmethod
def stopping():
print('bye.')
def on_sync_progress(self, event):
@staticmethod
def on_sync_progress(event):
print(event)
class Advanced(Basic):
FORMAT = '{l_bar}{bar}| {n_fmt:>7}/{total_fmt:>8} [{elapsed}<{remaining:>5}, {rate_fmt:>15}]'
FORMAT = '{l_bar}{bar}| {n_fmt:>8}/{total_fmt:>8} [{elapsed:>7}<{remaining:>8}, {rate_fmt:>15}]'
def __init__(self, service: Service):
super().__init__(service)
self.bars: Dict[Any, tqdm.tqdm] = {}
def get_or_create_bar(self, name, desc, unit, total):
def get_or_create_bar(self, name, desc, unit, total, leave=False):
bar = self.bars.get(name)
if bar is None:
bar = self.bars[name] = tqdm.tqdm(
desc=desc, unit=unit, total=total, bar_format=self.FORMAT, leave=False
desc=desc, unit=unit, total=total,
bar_format=self.FORMAT, leave=leave
)
return bar
def parsing_bar(self, d):
def start_sync_block_bars(self, d):
self.bars.clear()
self.get_or_create_bar("parse", "total parsing", "blocks", d['blocks'], True)
self.get_or_create_bar("save", "total saving", "txs", d['txs'], True)
def close_sync_block_bars(self):
self.bars.pop("parse").close()
self.bars.pop("save").close()
def update_sync_block_bars(self, event, d):
bar_name = f"block-{d['block_file']}"
bar = self.bars.get(bar_name)
if bar is None:
return self.get_or_create_bar(
f"parsing-{d['block_file']}",
bar_name,
f"├─ blk{d['block_file']:05}.dat parsing", 'blocks', d['total']
)
def saving_bar(self, d):
return self.get_or_create_bar(
f"saving-{d['block_file']}",
f"├─ blk{d['block_file']:05}.dat saving", "txs", d['total']
)
if event == "save" and bar.unit == "blocks":
bar.desc = f"├─ blk{d['block_file']:05}.dat saving"
bar.unit = "txs"
bar.reset(d['total'])
return
def initialize_sync_bars(self, d):
self.bars.clear()
self.get_or_create_bar("parsing", "total parsing", "blocks", d['blocks'])
self.get_or_create_bar("saving", "total saving", "txs", d['txs'])
@staticmethod
def update_sync_bars(main, bar, d):
diff = d['step']-bar.last_print_n
main.update(diff)
bar.update(diff)
if d['step'] == d['total']:
self.bars[event].update(diff)
if event == "save" and d['step'] == d['total']:
bar.close()
def on_sync_progress(self, event):
e, d = event['event'], event.get('data', {})
if e.endswith("start"):
self.initialize_sync_bars(d)
elif e.endswith('parsing'):
self.update_sync_bars(self.bars['parsing'], self.parsing_bar(d), d)
elif e.endswith('saving'):
self.update_sync_bars(self.bars['saving'], self.saving_bar(d), d)
return
bars: Dict[int, tqdm.tqdm] = {}
while True:
msg = self.queue.get()
if msg == self.STOP:
return
file_num, msg_type, done = msg
bar, state = bars.get(file_num, None), self.state[file_num]
if msg_type == 1:
if bar is None:
bar = bars[file_num] = tqdm.tqdm(
desc=f'├─ blk{file_num:05}.dat parsing', total=state['total_blocks'],
unit='blocks', bar_format=self.FORMAT
)
change = done - state['done_blocks']
state['done_blocks'] = done
bar.update(change)
block_bar.update(change)
if state['total_blocks'] == done:
bar.set_description(''+bar.desc[3:])
bar.close()
bars.pop(file_num)
elif msg_type == 2:
if bar is None:
bar = bars[file_num] = tqdm.tqdm(
desc=f'├─ blk{file_num:05}.dat loading', total=state['total_txs'],
unit='txs', bar_format=self.FORMAT
)
change = done - state['done_txs']
state['done_txs'] = done
bar.update(change)
tx_bar.update(change)
if state['total_txs'] == done:
bar.set_description(''+bar.desc[3:])
bar.close()
bars.pop(file_num)
self.start_sync_block_bars(d)
elif e.endswith("block.done"):
self.close_sync_block_bars()
elif e.endswith("block.parse"):
self.update_sync_block_bars("parse", d)
elif e.endswith("block.save"):
self.update_sync_block_bars("save", d)

View file

@ -1,4 +1,4 @@
from .database import Database
from .database import Database, Result
from .constants import (
TXO_TYPES, SPENDABLE_TYPE_CODES,
CLAIM_TYPE_CODES, CLAIM_TYPE_NAMES

View file

@ -1,17 +1,20 @@
import os
import asyncio
import tempfile
from typing import List, Optional, Tuple, Iterable, TYPE_CHECKING
import multiprocessing as mp
from typing import List, Optional, Iterable, Iterator, TypeVar, Generic, TYPE_CHECKING
from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
from functools import partial
from sqlalchemy import create_engine, text
from lbry.event import EventController
from lbry.crypto.bip32 import PubKey
from lbry.schema.result import Censor
from lbry.blockchain.transaction import Transaction, Output
from .constants import TXO_TYPES
from .query_context import initialize, ProgressPublisher
from . import queries as q
from . import sync
if TYPE_CHECKING:
@ -48,26 +51,70 @@ async def add_channel_keys_to_txo_results(accounts: List, txos: Iterable[Output]
if sub_channels:
await add_channel_keys_to_txo_results(accounts, sub_channels)
ResultType = TypeVar('ResultType')
class Result(Generic[ResultType]):
__slots__ = 'rows', 'total', 'censor'
def __init__(self, rows: List[ResultType], total, censor=None):
self.rows = rows
self.total = total
self.censor = censor
def __getitem__(self, item: int) -> ResultType:
return self.rows[item]
def __iter__(self) -> Iterator[ResultType]:
return iter(self.rows)
def __len__(self):
return len(self.rows)
def __repr__(self):
return repr(self.rows)
class Database:
def __init__(self, ledger: 'Ledger', url: str, multiprocess=False):
self.url = url
def __init__(self, ledger: 'Ledger', processes=-1):
self.url = ledger.conf.db_url_or_default
self.ledger = ledger
self.multiprocess = multiprocess
self.processes = self._normalize_processes(processes)
self.executor: Optional[Executor] = None
self.message_queue = mp.Queue()
self.stop_event = mp.Event()
self._on_progress_controller = EventController()
self.on_progress = self._on_progress_controller.stream
self.progress_publisher = ProgressPublisher(
self.message_queue, self._on_progress_controller
)
@staticmethod
def _normalize_processes(processes):
if processes == 0:
return os.cpu_count()
elif processes > 0:
return processes
return 1
@classmethod
def temp_sqlite_regtest(cls):
from lbry import Config, RegTestLedger
def temp_sqlite_regtest(cls, lbrycrd_dir=None):
from lbry import Config, RegTestLedger # pylint: disable=import-outside-toplevel
directory = tempfile.mkdtemp()
conf = Config.with_same_dir(directory)
if lbrycrd_dir is not None:
conf.lbrycrd_dir = lbrycrd_dir
ledger = RegTestLedger(conf)
return cls(ledger, conf.db_url_or_default)
return cls(ledger)
@classmethod
def from_memory(cls, ledger):
return cls(ledger, 'sqlite:///:memory:')
def in_memory(cls):
from lbry import Config, Ledger # pylint: disable=import-outside-toplevel
conf = Config.with_same_dir('/dev/null')
conf.db_url = 'sqlite:///:memory:'
return cls(Ledger(conf))
def sync_create(self, name):
engine = create_engine(self.url)
@ -89,21 +136,22 @@ class Database:
async def open(self):
assert self.executor is None, "Database already open."
kwargs = dict(
initializer=q.initialize,
initargs=(self.url, self.ledger)
)
if self.multiprocess:
self.executor = ProcessPoolExecutor(
max_workers=max(os.cpu_count()-1, 4), **kwargs
self.progress_publisher.start()
kwargs = {
"initializer": initialize,
"initargs": (
self.ledger,
self.message_queue, self.stop_event
)
}
if self.processes > 1:
self.executor = ProcessPoolExecutor(max_workers=self.processes, **kwargs)
else:
self.executor = ThreadPoolExecutor(
max_workers=1, **kwargs
)
self.executor = ThreadPoolExecutor(max_workers=1, **kwargs)
return await self.run_in_executor(q.check_version_and_create_tables)
async def close(self):
self.progress_publisher.stop()
if self.executor is not None:
self.executor.shutdown()
self.executor = None
@ -115,15 +163,31 @@ class Database:
self.executor, partial(func, *args, **kwargs)
)
async def fetch_result(self, func, *args, **kwargs) -> Result:
rows, total = await self.run_in_executor(func, *args, **kwargs)
return Result(rows, total)
async def execute(self, sql):
return await self.run_in_executor(q.execute, sql)
async def execute_fetchall(self, sql):
return await self.run_in_executor(q.execute_fetchall, sql)
async def get_best_height(self):
async def process_inputs(self, heights):
return await self.run_in_executor(sync.process_inputs, heights)
async def process_claims(self, heights):
return await self.run_in_executor(sync.process_claims, heights)
async def process_supports(self, heights):
return await self.run_in_executor(sync.process_supports, heights)
async def get_best_height(self) -> int:
return await self.run_in_executor(q.get_best_height)
async def get_best_height_for_file(self, file_number) -> int:
return await self.run_in_executor(q.get_best_height_for_file, file_number)
async def get_blocks_without_filters(self):
return await self.run_in_executor(q.get_blocks_without_filters)
@ -139,8 +203,8 @@ class Database:
async def get_transaction_address_filters(self, block_hash):
return await self.run_in_executor(q.get_transaction_address_filters, block_hash)
async def insert_transaction(self, tx):
return await self.run_in_executor(q.insert_transaction, tx)
async def insert_transaction(self, block_hash, tx):
return await self.run_in_executor(q.insert_transaction, block_hash, tx)
async def update_address_used_times(self, addresses):
return await self.run_in_executor(q.update_address_used_times, addresses)
@ -167,73 +231,70 @@ class Database:
async def get_report(self, accounts):
return await self.run_in_executor(q.get_report, accounts=accounts)
async def get_addresses(self, **constraints) -> Tuple[List[dict], Optional[int]]:
addresses, count = await self.run_in_executor(q.get_addresses, **constraints)
async def get_addresses(self, **constraints) -> Result[dict]:
addresses = await self.fetch_result(q.get_addresses, **constraints)
if addresses and 'pubkey' in addresses[0]:
for address in addresses:
address['pubkey'] = PubKey(
self.ledger, bytes(address.pop('pubkey')), bytes(address.pop('chain_code')),
address.pop('n'), address.pop('depth')
)
return addresses, count
return addresses
async def get_all_addresses(self):
return await self.run_in_executor(q.get_all_addresses)
async def get_address(self, **constraints):
addresses, _ = await self.get_addresses(limit=1, **constraints)
if addresses:
return addresses[0]
for address in await self.get_addresses(limit=1, **constraints):
return address
async def add_keys(self, account, chain, pubkeys):
return await self.run_in_executor(q.add_keys, account, chain, pubkeys)
async def get_raw_transactions(self, tx_hashes):
return await self.run_in_executor(q.get_raw_transactions, tx_hashes)
async def get_transactions(self, **constraints) -> Tuple[List[Transaction], Optional[int]]:
return await self.run_in_executor(q.get_transactions, **constraints)
async def get_transactions(self, **constraints) -> Result[Transaction]:
return await self.fetch_result(q.get_transactions, **constraints)
async def get_transaction(self, **constraints) -> Optional[Transaction]:
txs, _ = await self.get_transactions(limit=1, **constraints)
txs = await self.get_transactions(limit=1, **constraints)
if txs:
return txs[0]
async def get_purchases(self, **constraints) -> Tuple[List[Output], Optional[int]]:
return await self.run_in_executor(q.get_purchases, **constraints)
async def get_purchases(self, **constraints) -> Result[Output]:
return await self.fetch_result(q.get_purchases, **constraints)
async def search_claims(self, **constraints) -> Tuple[List[Output], Optional[int], Censor]:
return await self.run_in_executor(q.search, **constraints)
async def search_claims(self, **constraints) -> Result[Output]:
claims, total, censor = await self.run_in_executor(q.search, **constraints)
return Result(claims, total, censor)
async def get_txo_sum(self, **constraints):
async def get_txo_sum(self, **constraints) -> int:
return await self.run_in_executor(q.get_txo_sum, **constraints)
async def get_txo_plot(self, **constraints):
async def get_txo_plot(self, **constraints) -> List[dict]:
return await self.run_in_executor(q.get_txo_plot, **constraints)
async def get_txos(self, **constraints) -> Tuple[List[Output], Optional[int]]:
txos, count = await self.run_in_executor(q.get_txos, **constraints)
async def get_txos(self, **constraints) -> Result[Output]:
txos = await self.fetch_result(q.get_txos, **constraints)
if 'wallet' in constraints:
await add_channel_keys_to_txo_results(constraints['wallet'].accounts, txos)
return txos, count
return txos
async def get_utxos(self, **constraints) -> Tuple[List[Output], Optional[int]]:
async def get_utxos(self, **constraints) -> Result[Output]:
return await self.get_txos(is_spent=False, **constraints)
async def get_supports(self, **constraints) -> Tuple[List[Output], Optional[int]]:
async def get_supports(self, **constraints) -> Result[Output]:
return await self.get_utxos(txo_type=TXO_TYPES['support'], **constraints)
async def get_claims(self, **constraints) -> Tuple[List[Output], Optional[int]]:
txos, count = await self.run_in_executor(q.get_claims, **constraints)
async def get_claims(self, **constraints) -> Result[Output]:
txos = await self.fetch_result(q.get_claims, **constraints)
if 'wallet' in constraints:
await add_channel_keys_to_txo_results(constraints['wallet'].accounts, txos)
return txos, count
return txos
async def get_streams(self, **constraints) -> Tuple[List[Output], Optional[int]]:
async def get_streams(self, **constraints) -> Result[Output]:
return await self.get_claims(txo_type=TXO_TYPES['stream'], **constraints)
async def get_channels(self, **constraints) -> Tuple[List[Output], Optional[int]]:
async def get_channels(self, **constraints) -> Result[Output]:
return await self.get_claims(txo_type=TXO_TYPES['channel'], **constraints)
async def get_collections(self, **constraints) -> Tuple[List[Output], Optional[int]]:
async def get_collections(self, **constraints) -> Result[Output]:
return await self.get_claims(txo_type=TXO_TYPES['collection'], **constraints)

View file

@ -1,383 +1,91 @@
# pylint: disable=singleton-comparison
import struct
import logging
import itertools
from datetime import date
from decimal import Decimal
from binascii import unhexlify
from operator import itemgetter
from contextvars import ContextVar
from itertools import chain
from typing import NamedTuple, Tuple, Dict, Callable, Optional
from typing import Tuple, List, Dict, Optional, Union
from sqlalchemy import create_engine, union, func, inspect
from sqlalchemy.engine import Engine, Connection
from sqlalchemy.future import select
from sqlalchemy import union, func, text
from sqlalchemy.future import select, Select
from lbry.schema.tags import clean_tags
from lbry.schema.result import Censor, Outputs
from lbry.schema.url import URL, normalize_name
from lbry.schema.mime_types import guess_stream_type
from lbry.error import ResolveCensoredError
from lbry.blockchain.ledger import Ledger
from lbry.blockchain.transaction import Transaction, Output, Input, OutputScript, TXRefImmutable
from lbry.blockchain.transaction import Transaction, Output, OutputScript, TXRefImmutable
from .utils import *
from .tables import *
from .constants import *
from .utils import query, in_account_ids
from .query_context import context
from .constants import (
TXO_TYPES, CLAIM_TYPE_CODES, STREAM_TYPES, ATTRIBUTE_ARRAY_MAX_LENGTH,
SEARCH_PARAMS, SEARCH_INTEGER_PARAMS, SEARCH_ORDER_FIELDS
)
from .tables import (
metadata,
SCHEMA_VERSION, Version,
Block, TX, TXO, TXI, txi_join_account, txo_join_account,
Claim, Claimtrie,
PubkeyAddress, AccountAddress
)
MAX_QUERY_VARIABLES = 900
_context: ContextVar['QueryContext'] = ContextVar('_context')
def ctx():
return _context.get()
def initialize(url: str, ledger: Ledger, track_metrics=False, block_and_filter=None):
engine = create_engine(url)
connection = engine.connect()
if block_and_filter is not None:
blocked_streams, blocked_channels, filtered_streams, filtered_channels = block_and_filter
else:
blocked_streams = blocked_channels = filtered_streams = filtered_channels = {}
_context.set(
QueryContext(
engine=engine, connection=connection, ledger=ledger,
stack=[], metrics={}, is_tracking_metrics=track_metrics,
blocked_streams=blocked_streams, blocked_channels=blocked_channels,
filtered_streams=filtered_streams, filtered_channels=filtered_channels,
)
)
log = logging.getLogger(__name__)
def check_version_and_create_tables():
context = ctx()
if context.has_table('version'):
version = context.fetchone(select(Version.c.version).limit(1))
with context("db.connecting") as ctx:
if ctx.is_sqlite:
ctx.execute(text("PRAGMA journal_mode=WAL;"))
if ctx.has_table('version'):
version = ctx.fetchone(select(Version.c.version).limit(1))
if version and version['version'] == SCHEMA_VERSION:
return
metadata.drop_all(context.engine)
metadata.create_all(context.engine)
context.execute(Version.insert().values(version=SCHEMA_VERSION))
class QueryContext(NamedTuple):
engine: Engine
connection: Connection
ledger: Ledger
stack: List[List]
metrics: Dict
is_tracking_metrics: bool
blocked_streams: Dict
blocked_channels: Dict
filtered_streams: Dict
filtered_channels: Dict
@property
def is_postgres(self):
return self.connection.dialect.name == 'postgresql'
@property
def is_sqlite(self):
return self.connection.dialect.name == 'sqlite'
def raise_unsupported_dialect(self):
raise RuntimeError(f'Unsupported database dialect: {self.connection.dialect.name}.')
def reset_metrics(self):
self.stack = []
self.metrics = {}
def get_resolve_censor(self) -> Censor:
return Censor(self.blocked_streams, self.blocked_channels)
def get_search_censor(self) -> Censor:
return Censor(self.filtered_streams, self.filtered_channels)
def execute(self, sql, *args):
return self.connection.execute(sql, *args)
def fetchone(self, sql, *args):
row = self.connection.execute(sql, *args).fetchone()
return dict(row._mapping) if row else row
def fetchall(self, sql, *args):
rows = self.connection.execute(sql, *args).fetchall()
return [dict(row._mapping) for row in rows]
def insert_or_ignore(self, table):
if self.is_sqlite:
return table.insert().prefix_with("OR IGNORE")
elif self.is_postgres:
return pg_insert(table).on_conflict_do_nothing()
else:
self.raise_unsupported_dialect()
def insert_or_replace(self, table, replace):
if self.is_sqlite:
return table.insert().prefix_with("OR REPLACE")
elif self.is_postgres:
insert = pg_insert(table)
return insert.on_conflict_do_update(
table.primary_key, set_={col: getattr(insert.excluded, col) for col in replace}
)
else:
self.raise_unsupported_dialect()
def has_table(self, table):
return inspect(self.engine).has_table(table)
class RowCollector:
def __init__(self, context: QueryContext):
self.context = context
self.ledger = context.ledger
self.blocks = []
self.txs = []
self.txos = []
self.txis = []
self.claims = []
self.tags = []
@staticmethod
def block_to_row(block):
return {
'block_hash': block.block_hash,
'previous_hash': block.prev_block_hash,
'file_number': block.file_number,
'height': 0 if block.is_first_block else None,
}
@staticmethod
def tx_to_row(block_hash: bytes, tx: Transaction):
row = {
'tx_hash': tx.hash,
'block_hash': block_hash,
'raw': tx.raw,
'height': tx.height,
'position': tx.position,
'is_verified': tx.is_verified,
# TODO: fix
# 'day': tx.get_ordinal_day(self.db.ledger),
'purchased_claim_hash': None,
}
txos = tx.outputs
if len(txos) >= 2 and txos[1].can_decode_purchase_data:
txos[0].purchase = txos[1]
row['purchased_claim_hash'] = txos[1].purchase_data.claim_hash
return row
@staticmethod
def txi_to_row(tx: Transaction, txi: Input):
return {
'tx_hash': tx.hash,
'txo_hash': txi.txo_ref.hash,
'position': txi.position,
}
def txo_to_row(self, tx: Transaction, txo: Output):
row = {
'tx_hash': tx.hash,
'txo_hash': txo.hash,
'address': txo.get_address(self.ledger) if txo.has_address else None,
'position': txo.position,
'amount': txo.amount,
'script_offset': txo.script.offset,
'script_length': txo.script.length,
'txo_type': 0,
'claim_id': None,
'claim_hash': None,
'claim_name': None,
'reposted_claim_hash': None,
'channel_hash': None,
}
if txo.is_claim:
if txo.can_decode_claim:
claim = txo.claim
row['txo_type'] = TXO_TYPES.get(claim.claim_type, TXO_TYPES['stream'])
if claim.is_repost:
row['reposted_claim_hash'] = claim.repost.reference.claim_hash
if claim.is_signed:
row['channel_hash'] = claim.signing_channel_hash
else:
row['txo_type'] = TXO_TYPES['stream']
elif txo.is_support:
row['txo_type'] = TXO_TYPES['support']
elif txo.purchase is not None:
row['txo_type'] = TXO_TYPES['purchase']
row['claim_id'] = txo.purchased_claim_id
row['claim_hash'] = txo.purchased_claim_hash
if txo.script.is_claim_involved:
row['claim_id'] = txo.claim_id
row['claim_hash'] = txo.claim_hash
row['claim_name'] = txo.claim_name
return row
def add_block(self, block):
self.blocks.append(self.block_to_row(block))
for tx in block.txs:
self.add_transaction(block.block_hash, tx)
return self
def add_transaction(self, block_hash: bytes, tx: Transaction):
self.txs.append(self.tx_to_row(block_hash, tx))
for txi in tx.inputs:
if txi.coinbase is None:
self.txis.append(self.txi_to_row(tx, txi))
for txo in tx.outputs:
self.txos.append(self.txo_to_row(tx, txo))
return self
def add_claim(self, txo):
try:
assert txo.claim_name
assert txo.normalized_name
except:
#self.logger.exception(f"Could not decode claim name for {tx.id}:{txo.position}.")
return
tx = txo.tx_ref.tx
claim_hash = txo.claim_hash
claim_record = {
'claim_hash': claim_hash,
'claim_id': txo.claim_id,
'claim_name': txo.claim_name,
'normalized': txo.normalized_name,
'address': txo.get_address(self.ledger),
'txo_hash': txo.ref.hash,
'tx_position': tx.position,
'amount': txo.amount,
'timestamp': 0, # TODO: fix
'creation_timestamp': 0, # TODO: fix
'height': tx.height,
'creation_height': tx.height,
'release_time': None,
'title': None,
'author': None,
'description': None,
'claim_type': None,
# streams
'stream_type': None,
'media_type': None,
'fee_currency': None,
'fee_amount': 0,
'duration': None,
# reposts
'reposted_claim_hash': None,
# claims which are channels
'public_key_bytes': None,
'public_key_hash': None,
}
self.claims.append(claim_record)
try:
claim = txo.claim
except:
#self.logger.exception(f"Could not parse claim protobuf for {tx.id}:{txo.position}.")
return
if claim.is_stream:
claim_record['claim_type'] = TXO_TYPES['stream']
claim_record['media_type'] = claim.stream.source.media_type
claim_record['stream_type'] = STREAM_TYPES[guess_stream_type(claim_record['media_type'])]
claim_record['title'] = claim.stream.title
claim_record['description'] = claim.stream.description
claim_record['author'] = claim.stream.author
if claim.stream.video and claim.stream.video.duration:
claim_record['duration'] = claim.stream.video.duration
if claim.stream.audio and claim.stream.audio.duration:
claim_record['duration'] = claim.stream.audio.duration
if claim.stream.release_time:
claim_record['release_time'] = claim.stream.release_time
if claim.stream.has_fee:
fee = claim.stream.fee
if isinstance(fee.currency, str):
claim_record['fee_currency'] = fee.currency.lower()
if isinstance(fee.amount, Decimal):
claim_record['fee_amount'] = int(fee.amount*1000)
elif claim.is_repost:
claim_record['claim_type'] = TXO_TYPES['repost']
claim_record['reposted_claim_hash'] = claim.repost.reference.claim_hash
elif claim.is_channel:
claim_record['claim_type'] = TXO_TYPES['channel']
claim_record['public_key_bytes'] = txo.claim.channel.public_key_bytes
claim_record['public_key_hash'] = self.ledger.address_to_hash160(
self.ledger.public_key_to_address(txo.claim.channel.public_key_bytes)
)
for tag in clean_tags(claim.message.tags):
self.tags.append({'claim_hash': claim_hash, 'tag': tag})
return self
def save(self, progress: Callable = None):
queries = (
(Block.insert(), self.blocks),
(TX.insert(), self.txs),
(TXO.insert(), self.txos),
(TXI.insert(), self.txis),
(Claim.insert(), self.claims),
(Tag.insert(), self.tags),
)
total_rows = sum(len(query[1]) for query in queries)
inserted_rows = 0
if progress is not None:
progress(inserted_rows, total_rows)
execute = self.context.connection.execute
for sql, rows in queries:
for chunk_size, chunk_rows in chunk(rows, 10000):
execute(sql, list(chunk_rows))
inserted_rows += chunk_size
if progress is not None:
progress(inserted_rows, total_rows)
metadata.drop_all(ctx.engine)
metadata.create_all(ctx.engine)
ctx.execute(Version.insert().values(version=SCHEMA_VERSION))
if ctx.is_postgres:
ctx.execute(text("ALTER TABLE txi DISABLE TRIGGER ALL;"))
ctx.execute(text("ALTER TABLE txo DISABLE TRIGGER ALL;"))
ctx.execute(text("ALTER TABLE tx DISABLE TRIGGER ALL;"))
ctx.execute(text("ALTER TABLE claim DISABLE TRIGGER ALL;"))
ctx.execute(text("ALTER TABLE claimtrie DISABLE TRIGGER ALL;"))
ctx.execute(text("ALTER TABLE block DISABLE TRIGGER ALL;"))
def insert_transaction(block_hash, tx):
RowCollector(ctx()).add_transaction(block_hash, tx).save()
def process_claims_and_supports(block_range=None):
context = ctx()
if context.is_sqlite:
address_query = select(TXO.c.address).where(TXI.c.txo_hash == TXO.c.txo_hash)
sql = (
TXI.update()
.values(address=address_query.scalar_subquery())
.where(TXI.c.address == None)
)
else:
sql = (
TXI.update()
.values({TXI.c.address: TXO.c.address})
.where((TXI.c.address == None) & (TXI.c.txo_hash == TXO.c.txo_hash))
)
context.execute(sql)
context.execute(Claim.delete())
rows = RowCollector(ctx())
for claim in get_txos(txo_type__in=CLAIM_TYPE_CODES, is_spent=False)[0]:
rows.add_claim(claim)
rows.save()
context().get_bulk_loader().add_transaction(block_hash, tx).save()
def execute(sql):
return ctx().execute(text(sql))
return context().execute(text(sql))
def execute_fetchall(sql):
return ctx().fetchall(text(sql))
return context().fetchall(text(sql))
def get_best_height():
return ctx().fetchone(
return context().fetchone(
select(func.coalesce(func.max(TX.c.height), -1).label('total')).select_from(TX)
)['total']
def get_best_height_for_file(file_number):
return context().fetchone(
select(func.coalesce(func.max(Block.c.height), -1).label('height'))
.select_from(Block)
.where(Block.c.file_number == file_number)
)['height']
def get_blocks_without_filters():
return ctx().fetchall(
return context().fetchall(
select(Block.c.block_hash)
.select_from(Block)
.where(Block.c.block_filter == None)
@ -385,7 +93,7 @@ def get_blocks_without_filters():
def get_transactions_without_filters():
return ctx().fetchall(
return context().fetchall(
select(TX.c.tx_hash)
.select_from(TX)
.where(TX.c.tx_filter == None)
@ -399,7 +107,7 @@ def get_block_tx_addresses(block_hash=None, tx_hash=None):
constraint = (TX.c.tx_hash == tx_hash)
else:
raise ValueError('block_hash or tx_hash must be provided.')
return ctx().fetchall(
return context().fetchall(
union(
select(TXO.c.address).select_from(TXO.join(TX)).where((TXO.c.address != None) & constraint),
select(TXI.c.address).select_from(TXI.join(TX)).where((TXI.c.address != None) & constraint),
@ -408,13 +116,13 @@ def get_block_tx_addresses(block_hash=None, tx_hash=None):
def get_block_address_filters():
return ctx().fetchall(
return context().fetchall(
select(Block.c.block_hash, Block.c.block_filter).select_from(Block)
)
def get_transaction_address_filters(block_hash):
return ctx().fetchall(
return context().fetchall(
select(TX.c.tx_hash, TX.c.tx_filter)
.select_from(TX)
.where(TX.c.block_hash == block_hash)
@ -422,7 +130,7 @@ def get_transaction_address_filters(block_hash):
def update_address_used_times(addresses):
ctx().execute(
context().execute(
PubkeyAddress.update()
.values(used_times=(
select(func.count(TXO.c.address)).where((TXO.c.address == PubkeyAddress.c.address)),
@ -432,13 +140,13 @@ def update_address_used_times(addresses):
def reserve_outputs(txo_hashes, is_reserved=True):
ctx().execute(
context().execute(
TXO.update().values(is_reserved=is_reserved).where(TXO.c.txo_hash.in_(txo_hashes))
)
def release_all_outputs(account_id):
ctx().execute(
context().execute(
TXO.update().values(is_reserved=False).where(
(TXO.c.is_reserved == True) &
(TXO.c.address.in_(select(AccountAddress.c.address).where(in_account_ids(account_id))))
@ -456,19 +164,28 @@ def select_transactions(cols, account_ids=None, **constraints):
select(TXI.c.tx_hash).select_from(txi_join_account).where(where)
)
s = s.where(TX.c.tx_hash.in_(tx_hashes))
return ctx().fetchall(query([TX], s, **constraints))
return context().fetchall(query([TX], s, **constraints))
TXO_NOT_MINE = Output(None, None, is_my_output=False)
def get_raw_transactions(tx_hashes):
return ctx().fetchall(
return context().fetchall(
select(TX.c.tx_hash, TX.c.raw).where(TX.c.tx_hash.in_(tx_hashes))
)
def get_transactions(wallet=None, include_total=False, **constraints) -> Tuple[List[Transaction], Optional[int]]:
def get_transactions(**constraints) -> Tuple[List[Transaction], Optional[int]]:
txs = []
sql = select(TX.c.raw, TX.c.height, TX.c.position).select_from(TX)
rows = context().fetchall(query([TX], sql, **constraints))
for row in rows:
txs.append(Transaction(row['raw'], height=row['height'], position=row['position']))
return txs, 0
def _get_transactions(wallet=None, include_total=False, **constraints) -> Tuple[List[Transaction], Optional[int]]:
include_is_spent = constraints.pop('include_is_spent', False)
include_is_my_input = constraints.pop('include_is_my_input', False)
include_is_my_output = constraints.pop('include_is_my_output', False)
@ -599,7 +316,7 @@ def select_txos(
tables.append(Claim)
joins = joins.join(Claim)
s = s.select_from(joins)
return ctx().fetchall(query(tables, s, **constraints))
return context().fetchall(query(tables, s, **constraints))
def get_txos(no_tx=False, include_total=False, **constraints) -> Tuple[List[Output], Optional[int]]:
@ -750,7 +467,6 @@ def get_txo_plot(start_day=None, days_back=0, end_day=None, days_after=None, **c
_clean_txo_constraints_for_aggregation(constraints)
if start_day is None:
# TODO: Fix
raise NotImplementedError
current_ordinal = 0 # self.ledger.headers.estimated_date(self.ledger.headers.height).toordinal()
constraints['day__gte'] = current_ordinal - days_back
else:
@ -774,14 +490,14 @@ def get_purchases(**constraints) -> Tuple[List[Output], Optional[int]]:
if not {'purchased_claim_hash', 'purchased_claim_hash__in'}.intersection(constraints):
constraints['purchased_claim_hash__is_not_null'] = True
constraints['tx_hash__in'] = (
select(TXI.c.tx_hash).select_from(txi_join_account).where(in_account(accounts))
select(TXI.c.tx_hash).select_from(txi_join_account).where(in_account_ids(accounts))
)
txs, count = get_transactions(**constraints)
return [tx.outputs[0] for tx in txs], count
def select_addresses(cols, **constraints):
return ctx().fetchall(query(
return context().fetchall(query(
[AccountAddress, PubkeyAddress],
select(*cols).select_from(PubkeyAddress.join(AccountAddress)),
**constraints
@ -812,11 +528,11 @@ def get_address_count(**constraints):
def get_all_addresses(self):
return ctx().execute(select(PubkeyAddress.c.address))
return context().execute(select(PubkeyAddress.c.address))
def add_keys(account, chain, pubkeys):
c = ctx()
c = context()
c.execute(
c.insert_or_ignore(PubkeyAddress)
.values([{'address': k.address} for k in pubkeys])
@ -846,7 +562,7 @@ def get_supports_summary(self, **constraints):
def search_to_bytes(constraints) -> Union[bytes, Tuple[bytes, Dict]]:
return Outputs.to_bytes(*search(constraints))
return Outputs.to_bytes(*search(**constraints))
def resolve_to_bytes(urls) -> Union[bytes, Tuple[bytes, Dict]]:
@ -854,26 +570,26 @@ def resolve_to_bytes(urls) -> Union[bytes, Tuple[bytes, Dict]]:
def execute_censored(sql, row_offset: int, row_limit: int, censor: Censor) -> List:
context = ctx()
return ctx().fetchall(sql)
c = context.db.cursor()
def row_filter(cursor, row):
nonlocal row_offset
#row = row_factory(cursor, row)
if len(row) > 1 and censor.censor(row):
return
if row_offset:
row_offset -= 1
return
return row
c.setrowtrace(row_filter)
i, rows = 0, []
for row in c.execute(sql):
i += 1
rows.append(row)
if i >= row_limit:
break
return rows
ctx = context()
return ctx.fetchall(sql)
# c = ctx.db.cursor()
# def row_filter(cursor, row):
# nonlocal row_offset
# #row = row_factory(cursor, row)
# if len(row) > 1 and censor.censor(row):
# return
# if row_offset:
# row_offset -= 1
# return
# return row
# c.setrowtrace(row_filter)
# i, rows = 0, []
# for row in c.execute(sql):
# i += 1
# rows.append(row)
# if i >= row_limit:
# break
# return rows
def claims_query(cols, for_count=False, **constraints) -> Tuple[str, Dict]:
@ -938,7 +654,7 @@ def claims_query(cols, for_count=False, **constraints) -> Tuple[str, Dict]:
if 'public_key_id' in constraints:
constraints['public_key_hash'] = (
ctx().ledger.address_to_hash160(constraints.pop('public_key_id')))
context().ledger.address_to_hash160(constraints.pop('public_key_id')))
if 'channel_hash' in constraints:
constraints['channel_hash'] = constraints.pop('channel_hash')
if 'channel_ids' in constraints:
@ -984,7 +700,7 @@ def claims_query(cols, for_count=False, **constraints) -> Tuple[str, Dict]:
claim_types = [claim_types]
if claim_types:
constraints['claim_type__in'] = {
CLAIM_TYPES[claim_type] for claim_type in claim_types
CLAIM_TYPE_CODES[claim_type] for claim_type in claim_types
}
if 'stream_types' in constraints:
stream_types = constraints.pop('stream_types')
@ -1053,23 +769,23 @@ def search_claims(censor: Censor, **constraints) -> List:
TXO.c.position.label('txo_position'),
Claim.c.claim_hash,
Claim.c.txo_hash,
# Claim.c.claims_in_channel,
# Claim.c.reposted,
# Claim.c.height,
# Claim.c.creation_height,
# Claim.c.activation_height,
# Claim.c.expiration_height,
# Claim.c.effective_amount,
# Claim.c.support_amount,
# Claim.c.trending_group,
# Claim.c.trending_mixed,
# Claim.c.trending_local,
# Claim.c.trending_global,
# Claim.c.short_url,
# Claim.c.canonical_url,
# Claim.c.claims_in_channel,
# Claim.c.reposted,
# Claim.c.height,
# Claim.c.creation_height,
# Claim.c.activation_height,
# Claim.c.expiration_height,
# Claim.c.effective_amount,
# Claim.c.support_amount,
# Claim.c.trending_group,
# Claim.c.trending_mixed,
# Claim.c.trending_local,
# Claim.c.trending_global,
# Claim.c.short_url,
# Claim.c.canonical_url,
Claim.c.channel_hash,
Claim.c.reposted_claim_hash,
# Claim.c.signature_valid
# Claim.c.signature_valid
], **constraints
)
@ -1079,9 +795,9 @@ def get_claims(**constraints) -> Tuple[List[Output], Optional[int]]:
def _get_referenced_rows(txo_rows: List[dict], censor_channels: List[bytes]):
censor = ctx().get_resolve_censor()
censor = context().get_resolve_censor()
repost_hashes = set(filter(None, map(itemgetter('reposted_claim_hash'), txo_rows)))
channel_hashes = set(chain(
channel_hashes = set(itertools.chain(
filter(None, map(itemgetter('channel_hash'), txo_rows)),
censor_channels
))
@ -1107,8 +823,8 @@ def old_search(**constraints) -> Tuple[List, List, int, int, Censor]:
total = count_claims(**constraints)
constraints['offset'] = abs(constraints.get('offset', 0))
constraints['limit'] = min(abs(constraints.get('limit', 10)), 50)
context = ctx()
search_censor = context.get_search_censor()
ctx = context()
search_censor = ctx.get_search_censor()
txo_rows = search_claims(search_censor, **constraints)
extra_txo_rows = _get_referenced_rows(txo_rows, search_censor.censored.keys())
return txo_rows, extra_txo_rows, constraints['offset'], total, search_censor
@ -1122,8 +838,8 @@ def search(**constraints) -> Tuple[List, int, Censor]:
total = count_claims(**constraints)
constraints['offset'] = abs(constraints.get('offset', 0))
constraints['limit'] = min(abs(constraints.get('limit', 10)), 50)
context = ctx()
search_censor = context.get_search_censor()
ctx = context()
search_censor = ctx.get_search_censor()
txos = []
for row in search_claims(search_censor, **constraints):
source = row['raw'][row['script_offset']:row['script_offset']+row['script_length']]
@ -1148,7 +864,7 @@ def resolve(urls) -> Tuple[List, List]:
def resolve_url(raw_url):
censor = ctx().get_resolve_censor()
censor = context().get_resolve_censor()
try:
url = URL.parse(raw_url)
@ -1158,12 +874,12 @@ def resolve_url(raw_url):
channel = None
if url.has_channel:
query = url.channel.to_dict()
if set(query) == {'name'}:
query['is_controlling'] = True
q = url.channel.to_dict()
if set(q) == {'name'}:
q['is_controlling'] = True
else:
query['order_by'] = ['^creation_height']
matches = search_claims(censor, **query, limit=1)
q['order_by'] = ['^creation_height']
matches = search_claims(censor, **q, limit=1)
if matches:
channel = matches[0]
elif censor.censored:
@ -1172,18 +888,18 @@ def resolve_url(raw_url):
return LookupError(f'Could not find channel in "{raw_url}".')
if url.has_stream:
query = url.stream.to_dict()
q = url.stream.to_dict()
if channel is not None:
if set(query) == {'name'}:
if set(q) == {'name'}:
# temporarily emulate is_controlling for claims in channel
query['order_by'] = ['effective_amount', '^height']
q['order_by'] = ['effective_amount', '^height']
else:
query['order_by'] = ['^channel_join']
query['channel_hash'] = channel['claim_hash']
query['signature_valid'] = 1
elif set(query) == {'name'}:
query['is_controlling'] = 1
matches = search_claims(censor, **query, limit=1)
q['order_by'] = ['^channel_join']
q['channel_hash'] = channel['claim_hash']
q['signature_valid'] = 1
elif set(q) == {'name'}:
q['is_controlling'] = 1
matches = search_claims(censor, **q, limit=1)
if matches:
return matches[0]
elif censor.censored:

498
lbry/db/query_context.py Normal file
View file

@ -0,0 +1,498 @@
import os
import time
import multiprocessing as mp
from enum import Enum
from decimal import Decimal
from typing import Dict, List, Optional
from dataclasses import dataclass
from contextvars import ContextVar
from sqlalchemy import create_engine, inspect
from sqlalchemy.engine import Engine, Connection
from lbry.event import EventQueuePublisher
from lbry.blockchain.ledger import Ledger
from lbry.blockchain.transaction import Transaction, Output, Input
from lbry.schema.tags import clean_tags
from lbry.schema.result import Censor
from lbry.schema.mime_types import guess_stream_type
from .utils import pg_insert, chunk
from .tables import Block, TX, TXO, TXI, Claim, Tag, Claimtrie, Support
from .constants import TXO_TYPES, STREAM_TYPES
_context: ContextVar['QueryContext'] = ContextVar('_context')
@dataclass
class QueryContext:
engine: Engine
connection: Connection
ledger: Ledger
message_queue: mp.Queue
stop_event: mp.Event
stack: List[List]
metrics: Dict
is_tracking_metrics: bool
blocked_streams: Dict
blocked_channels: Dict
filtered_streams: Dict
filtered_channels: Dict
pid: int
# QueryContext __enter__/__exit__ state
print_timers: List
current_timer_name: Optional[str] = None
current_timer_time: float = 0
current_progress: Optional['ProgressContext'] = None
@property
def is_postgres(self):
return self.connection.dialect.name == 'postgresql'
@property
def is_sqlite(self):
return self.connection.dialect.name == 'sqlite'
def raise_unsupported_dialect(self):
raise RuntimeError(f'Unsupported database dialect: {self.connection.dialect.name}.')
def get_resolve_censor(self) -> Censor:
return Censor(self.blocked_streams, self.blocked_channels)
def get_search_censor(self) -> Censor:
return Censor(self.filtered_streams, self.filtered_channels)
def execute(self, sql, *args):
return self.connection.execute(sql, *args)
def fetchone(self, sql, *args):
row = self.connection.execute(sql, *args).fetchone()
return dict(row._mapping) if row else row
def fetchall(self, sql, *args):
rows = self.connection.execute(sql, *args).fetchall()
return [dict(row._mapping) for row in rows]
def insert_or_ignore(self, table):
if self.is_sqlite:
return table.insert().prefix_with("OR IGNORE")
elif self.is_postgres:
return pg_insert(table).on_conflict_do_nothing()
else:
self.raise_unsupported_dialect()
def insert_or_replace(self, table, replace):
if self.is_sqlite:
return table.insert().prefix_with("OR REPLACE")
elif self.is_postgres:
insert = pg_insert(table)
return insert.on_conflict_do_update(
table.primary_key, set_={col: getattr(insert.excluded, col) for col in replace}
)
else:
self.raise_unsupported_dialect()
def has_table(self, table):
return inspect(self.engine).has_table(table)
def get_bulk_loader(self) -> 'BulkLoader':
return BulkLoader(self)
def reset_metrics(self):
self.stack = []
self.metrics = {}
def with_timer(self, timer_name: str) -> 'QueryContext':
self.current_timer_name = timer_name
return self
def __enter__(self) -> 'QueryContext':
self.current_timer_time = time.perf_counter()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.current_timer_name and self.current_timer_name in self.print_timers:
elapsed = time.perf_counter() - self.current_timer_time
print(f"{self.print_timers} in {elapsed:.6f}s", flush=True)
self.current_timer_name = None
self.current_timer_time = 0
self.current_progress = None
def context(with_timer: str = None) -> 'QueryContext':
if isinstance(with_timer, str):
return _context.get().with_timer(with_timer)
return _context.get()
def initialize(
ledger: Ledger, message_queue: mp.Queue, stop_event: mp.Event,
track_metrics=False, block_and_filter=None, print_timers=None):
url = ledger.conf.db_url_or_default
engine = create_engine(url)
connection = engine.connect()
if block_and_filter is not None:
blocked_streams, blocked_channels, filtered_streams, filtered_channels = block_and_filter
else:
blocked_streams = blocked_channels = filtered_streams = filtered_channels = {}
_context.set(
QueryContext(
pid=os.getpid(),
engine=engine, connection=connection,
ledger=ledger, message_queue=message_queue, stop_event=stop_event,
stack=[], metrics={}, is_tracking_metrics=track_metrics,
blocked_streams=blocked_streams, blocked_channels=blocked_channels,
filtered_streams=filtered_streams, filtered_channels=filtered_channels,
print_timers=print_timers or []
)
)
def uninitialize():
ctx = _context.get(None)
if ctx is not None:
if ctx.connection:
ctx.connection.close()
_context.set(None)
class ProgressUnit(Enum):
NONE = "", None
TASKS = "tasks", None
BLOCKS = "blocks", Block
TXS = "txs", TX
TRIE = "trie", Claimtrie
TXIS = "txis", TXI
CLAIMS = "claims", Claim
SUPPORTS = "supports", Support
def __new__(cls, value, table):
next_id = len(cls.__members__) + 1
obj = object.__new__(cls)
obj._value_ = next_id
obj.label = value
obj.table = table
return obj
class Event(Enum):
# full node specific sync events
BLOCK_READ = "blockchain.sync.block.read", ProgressUnit.BLOCKS
BLOCK_SAVE = "blockchain.sync.block.save", ProgressUnit.TXS
BLOCK_DONE = "blockchain.sync.block.done", ProgressUnit.TASKS
TRIE_DELETE = "blockchain.sync.trie.delete", ProgressUnit.TRIE
TRIE_UPDATE = "blockchain.sync.trie.update", ProgressUnit.TRIE
TRIE_INSERT = "blockchain.sync.trie.insert", ProgressUnit.TRIE
# full node + light client sync events
INPUT_UPDATE = "db.sync.input", ProgressUnit.TXIS
CLAIM_DELETE = "db.sync.claim.delete", ProgressUnit.CLAIMS
CLAIM_UPDATE = "db.sync.claim.update", ProgressUnit.CLAIMS
CLAIM_INSERT = "db.sync.claim.insert", ProgressUnit.CLAIMS
SUPPORT_DELETE = "db.sync.support.delete", ProgressUnit.SUPPORTS
SUPPORT_UPDATE = "db.sync.support.update", ProgressUnit.SUPPORTS
SUPPORT_INSERT = "db.sync.support.insert", ProgressUnit.SUPPORTS
def __new__(cls, value, unit: ProgressUnit):
next_id = len(cls.__members__) + 1
obj = object.__new__(cls)
obj._value_ = next_id
obj.label = value
obj.unit = unit
return obj
class ProgressPublisher(EventQueuePublisher):
def message_to_event(self, message):
event = Event(message[0]) # pylint: disable=no-value-for-parameter
d = {
"event": event.label,
"data": {
"pid": message[1],
"step": message[2],
"total": message[3],
"unit": event.unit.label
}
}
if len(message) > 4 and isinstance(message[4], dict):
d['data'].update(message[4])
return d
class ProgressContext:
def __init__(self, ctx: QueryContext, event: Event, step_size=1):
self.ctx = ctx
self.event = event
self.extra = None
self.step_size = step_size
self.last_step = -1
self.total = 0
def __enter__(self) -> 'ProgressContext':
self.ctx.__enter__()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.ctx.message_queue.put(self.get_event_args(self.total))
return self.ctx.__exit__(exc_type, exc_val, exc_tb)
def start(self, total, extra=None):
self.total = total
if extra is not None:
self.extra = extra
self.step(0)
def step(self, done):
send_condition = (
# enforce step rate
(self.step_size == 1 or done % self.step_size == 0) and
# deduplicate finish event by not sending a step where done == total
done < self.total and
# deduplicate same step
done != self.last_step
)
if send_condition:
self.ctx.message_queue.put_nowait(self.get_event_args(done))
self.last_step = done
def get_event_args(self, done):
if self.extra is not None:
return self.event.value, self.ctx.pid, done, self.total, self.extra
return self.event.value, self.ctx.pid, done, self.total
def progress(e: Event, step_size=1) -> ProgressContext:
ctx = context(e.label)
ctx.current_progress = ProgressContext(ctx, e, step_size=step_size)
return ctx.current_progress
class BulkLoader:
def __init__(self, ctx: QueryContext):
self.ctx = ctx
self.ledger = ctx.ledger
self.blocks = []
self.txs = []
self.txos = []
self.txis = []
self.claims = []
self.tags = []
@staticmethod
def block_to_row(block):
return {
'block_hash': block.block_hash,
'previous_hash': block.prev_block_hash,
'file_number': block.file_number,
'height': 0 if block.is_first_block else block.height,
}
@staticmethod
def tx_to_row(block_hash: bytes, tx: Transaction):
row = {
'tx_hash': tx.hash,
'block_hash': block_hash,
'raw': tx.raw,
'height': tx.height,
'position': tx.position,
'is_verified': tx.is_verified,
# TODO: fix
# 'day': tx.get_ordinal_day(self.db.ledger),
'purchased_claim_hash': None,
}
txos = tx.outputs
if len(txos) >= 2 and txos[1].can_decode_purchase_data:
txos[0].purchase = txos[1]
row['purchased_claim_hash'] = txos[1].purchase_data.claim_hash
return row
@staticmethod
def txi_to_row(tx: Transaction, txi: Input):
return {
'tx_hash': tx.hash,
'txo_hash': txi.txo_ref.hash,
'position': txi.position,
}
def txo_to_row(self, tx: Transaction, txo: Output):
row = {
'tx_hash': tx.hash,
'txo_hash': txo.hash,
'address': txo.get_address(self.ledger) if txo.has_address else None,
'position': txo.position,
'amount': txo.amount,
'script_offset': txo.script.offset,
'script_length': txo.script.length,
'txo_type': 0,
'claim_id': None,
'claim_hash': None,
'claim_name': None,
'reposted_claim_hash': None,
'channel_hash': None,
}
if txo.is_claim:
if txo.can_decode_claim:
claim = txo.claim
row['txo_type'] = TXO_TYPES.get(claim.claim_type, TXO_TYPES['stream'])
if claim.is_repost:
row['reposted_claim_hash'] = claim.repost.reference.claim_hash
if claim.is_signed:
row['channel_hash'] = claim.signing_channel_hash
else:
row['txo_type'] = TXO_TYPES['stream']
#self.add_claim(txo)
elif txo.is_support:
row['txo_type'] = TXO_TYPES['support']
elif txo.purchase is not None:
row['txo_type'] = TXO_TYPES['purchase']
row['claim_id'] = txo.purchased_claim_id
row['claim_hash'] = txo.purchased_claim_hash
if txo.script.is_claim_involved:
row['claim_id'] = txo.claim_id
row['claim_hash'] = txo.claim_hash
try:
claim_name = txo.claim_name
if '\x00' in claim_name:
# log.error(f"Name for claim {txo.claim_id} contains a NULL (\\x00) character, skipping.")
pass
else:
row['claim_name'] = claim_name
except UnicodeDecodeError:
# log.error(f"Name for claim {txo.claim_id} contains invalid unicode, skipping.")
pass
return row
def add_block(self, block):
self.blocks.append(self.block_to_row(block))
for tx in block.txs:
self.add_transaction(block.block_hash, tx)
return self
def add_transaction(self, block_hash: bytes, tx: Transaction):
self.txs.append(self.tx_to_row(block_hash, tx))
for txi in tx.inputs:
if txi.coinbase is None:
self.txis.append(self.txi_to_row(tx, txi))
for txo in tx.outputs:
self.txos.append(self.txo_to_row(tx, txo))
return self
def add_claim(self, txo):
try:
assert txo.claim_name
assert txo.normalized_name
except Exception:
#self.logger.exception(f"Could not decode claim name for {tx.id}:{txo.position}.")
return
tx = txo.tx_ref.tx
claim_hash = txo.claim_hash
claim_record = {
'claim_hash': claim_hash,
'claim_id': txo.claim_id,
'claim_name': txo.claim_name,
'normalized': txo.normalized_name,
'address': txo.get_address(self.ledger),
'txo_hash': txo.ref.hash,
'tx_position': tx.position,
'amount': txo.amount,
'timestamp': 0, # TODO: fix
'creation_timestamp': 0, # TODO: fix
'height': tx.height,
'creation_height': tx.height,
'release_time': None,
'title': None,
'author': None,
'description': None,
'claim_type': None,
# streams
'stream_type': None,
'media_type': None,
'fee_currency': None,
'fee_amount': 0,
'duration': None,
# reposts
'reposted_claim_hash': None,
# claims which are channels
'public_key_bytes': None,
'public_key_hash': None,
}
self.claims.append(claim_record)
try:
claim = txo.claim
except Exception:
#self.logger.exception(f"Could not parse claim protobuf for {tx.id}:{txo.position}.")
return
if claim.is_stream:
claim_record['claim_type'] = TXO_TYPES['stream']
claim_record['media_type'] = claim.stream.source.media_type
claim_record['stream_type'] = STREAM_TYPES[guess_stream_type(claim_record['media_type'])]
claim_record['title'] = claim.stream.title
claim_record['description'] = claim.stream.description
claim_record['author'] = claim.stream.author
if claim.stream.video and claim.stream.video.duration:
claim_record['duration'] = claim.stream.video.duration
if claim.stream.audio and claim.stream.audio.duration:
claim_record['duration'] = claim.stream.audio.duration
if claim.stream.release_time:
claim_record['release_time'] = claim.stream.release_time
if claim.stream.has_fee:
fee = claim.stream.fee
if isinstance(fee.currency, str):
claim_record['fee_currency'] = fee.currency.lower()
if isinstance(fee.amount, Decimal):
claim_record['fee_amount'] = int(fee.amount*1000)
elif claim.is_repost:
claim_record['claim_type'] = TXO_TYPES['repost']
claim_record['reposted_claim_hash'] = claim.repost.reference.claim_hash
elif claim.is_channel:
claim_record['claim_type'] = TXO_TYPES['channel']
claim_record['public_key_bytes'] = txo.claim.channel.public_key_bytes
claim_record['public_key_hash'] = self.ledger.address_to_hash160(
self.ledger.public_key_to_address(txo.claim.channel.public_key_bytes)
)
for tag in clean_tags(claim.message.tags):
self.tags.append({'claim_hash': claim_hash, 'tag': tag})
return self
def save(self, batch_size=10000):
queries = (
(Block, self.blocks),
(TX, self.txs),
(TXO, self.txos),
(TXI, self.txis),
(Claim, self.claims),
(Tag, self.tags),
)
p = self.ctx.current_progress
done = row_scale = 0
if p:
unit_table = p.event.unit.table
progress_total, row_total = 0, sum(len(q[1]) for q in queries)
for table, rows in queries:
if table == unit_table:
progress_total = len(rows)
break
if not progress_total:
assert row_total == 0, "Rows used for progress are empty but other rows present."
return
row_scale = row_total / progress_total
p.start(progress_total)
execute = self.ctx.connection.execute
for table, rows in queries:
sql = table.insert()
for chunk_size, chunk_rows in chunk(rows, batch_size):
execute(sql, list(chunk_rows))
if p:
done += int(chunk_size/row_scale)
p.step(done)

48
lbry/db/sync.py Normal file
View file

@ -0,0 +1,48 @@
# pylint: disable=singleton-comparison
from sqlalchemy.future import select
from .constants import CLAIM_TYPE_CODES
from .queries import get_txos
from .query_context import progress, Event
from .tables import (
TXO, TXI,
Claim
)
def process_inputs(heights):
with progress(Event.INPUT_UPDATE) as p:
if p.ctx.is_sqlite:
address_query = select(TXO.c.address).where(TXI.c.txo_hash == TXO.c.txo_hash)
sql = (
TXI.update()
.values(address=address_query.scalar_subquery())
.where(TXI.c.address == None)
)
else:
sql = (
TXI.update()
.values({TXI.c.address: TXO.c.address})
.where((TXI.c.address == None) & (TXI.c.txo_hash == TXO.c.txo_hash))
)
p.start(1)
p.ctx.execute(sql)
def process_claims(heights):
with progress(Event.CLAIM_DELETE) as p:
p.start(1)
p.ctx.execute(Claim.delete())
with progress(Event.CLAIM_UPDATE) as p:
loader = p.ctx.get_bulk_loader()
for claim in get_txos(
txo_type__in=CLAIM_TYPE_CODES, is_spent=False,
height__gte=heights[0], height__lte=heights[1])[0]:
loader.add_claim(claim)
loader.save()
def process_supports(heights):
pass

View file

@ -70,6 +70,7 @@ TXO = Table(
Column('amount', BigInteger),
Column('script_offset', Integer),
Column('script_length', Integer),
Column('is_spent', Boolean, server_default='0'),
Column('is_reserved', Boolean, server_default='0'),
Column('txo_type', SmallInteger, server_default='0'),
Column('claim_id', Text, nullable=True),
@ -159,6 +160,13 @@ Tag = Table(
)
Support = Table(
'support', metadata,
Column('normalized', Text, primary_key=True),
Column('claim_hash', LargeBinary, ForeignKey(Claim.columns.claim_hash)),
)
Claimtrie = Table(
'claimtrie', metadata,
Column('normalized', Text, primary_key=True),

View file

@ -4,7 +4,7 @@ from typing import List, Union
from sqlalchemy import text, and_
from sqlalchemy.sql.expression import Select
try:
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.dialects.postgresql import insert as pg_insert # pylint: disable=unused-import
except ImportError:
pg_insert = None

View file

@ -162,12 +162,12 @@ class EventStream:
future = asyncio.get_event_loop().create_future()
value = None
def update_value(v):
def update_value(_value):
nonlocal value
value = v
value = _value
subscription = self.listen(
lambda v: update_value(v),
update_value,
lambda exception: not future.done() and self._cancel_and_error(subscription, future, exception),
lambda: not future.done() and self._cancel_and_callback(subscription, future, value),
)
@ -195,7 +195,8 @@ class EventQueuePublisher(threading.Thread):
self.event_controller = event_controller
self.loop = None
def message_to_event(self, message):
@staticmethod
def message_to_event(message):
return message
def start(self):

View file

@ -1,4 +1,4 @@
from .api import API
from .daemon import Daemon
from .daemon import Daemon, jsonrpc_dumps_pretty
from .full_node import FullNode
from .light_client import LightClient

View file

@ -20,6 +20,7 @@ from lbry.stream.managed_stream import ManagedStream
from lbry.event import EventController, EventStream
from .base import Service
from .json_encoder import Paginated
DEFAULT_PAGE_SIZE = 20
@ -32,12 +33,7 @@ async def paginate_rows(get_records: Callable, page: Optional[int], page_size: O
"offset": page_size * (page - 1),
"limit": page_size
})
items, count = await get_records(**constraints)
result = {"items": items, "page": page, "page_size": page_size}
if count is not None:
result["total_pages"] = int((count + (page_size - 1)) / page_size)
result["total_items"] = count
return result
return Paginated(await get_records(**constraints), page, page_size)
def paginate_list(items: List, page: Optional[int], page_size: Optional[int]):
@ -57,7 +53,6 @@ def paginate_list(items: List, page: Optional[int], page_size: Optional[int]):
StrOrList = Union[str, list]
Paginated = List
Address = Dict
@ -128,7 +123,7 @@ def tx_kwargs(
change_account_id: str = None, # account to send excess change (LBC)
fund_account_id: StrOrList = None, # accounts to fund the transaction
preview=False, # do not broadcast the transaction
blocking=False, # wait until transaction is in mempool
no_wait=False, # do not wait for mempool confirmation
):
pass
@ -1668,7 +1663,7 @@ class API:
name=name, amount=amount, holding_account=holding_account, funding_accounts=funding_accounts,
save_key=not tx_dict['preview'], **remove_nulls(channel_dict)
)
await self.service.maybe_broadcast_or_release(tx, tx_dict['blocking'], tx_dict['preview'])
await self.service.maybe_broadcast_or_release(tx, tx_dict['preview'], tx_dict['no_wait'])
return tx
async def channel_update(
@ -2783,7 +2778,7 @@ class API:
days_after: int = None, # end number of days after --start_day (instead of --end_day)
end_day: str = None, # end on specific date (YYYY-MM-DD) (instead of --days_after)
**txo_filter_and_pagination_kwargs
) -> list:
) -> List:
"""
Plot transaction output sum over days.

View file

@ -3,9 +3,7 @@ import asyncio
import logging
from typing import List, Optional, Tuple, NamedTuple
from lbry.conf import Config
from lbry.db import Database
from lbry.db import Database, Result
from lbry.db.constants import TXO_TYPES
from lbry.schema.result import Censor
from lbry.blockchain.transaction import Transaction, Output
@ -30,8 +28,8 @@ class Sync:
self._on_block_controller = EventController()
self.on_block = self._on_block_controller.stream
self._on_progress_controller = EventController()
self.on_progress = self._on_progress_controller.stream
self._on_progress_controller = db._on_progress_controller
self.on_progress = db.on_progress
self._on_ready_controller = EventController()
self.on_ready = self._on_ready_controller.stream
@ -39,9 +37,6 @@ class Sync:
def on_bulk_started(self):
return self.on_progress.where() # filter for bulk started event
def on_bulk_started(self):
return self.on_progress.where() # filter for bulk started event
def on_bulk_finished(self):
return self.on_progress.where() # filter for bulk finished event
@ -59,9 +54,9 @@ class Service:
sync: Sync
def __init__(self, ledger: Ledger, db_url: str):
def __init__(self, ledger: Ledger):
self.ledger, self.conf = ledger, ledger.conf
self.db = Database(ledger, db_url)
self.db = Database(ledger)
self.wallets = WalletManager(ledger, self.db)
#self.on_address = sync.on_address
@ -105,7 +100,7 @@ class Service:
def create_wallet(self, file_name):
path = os.path.join(self.conf.wallet_dir, file_name)
return self.wallet_manager.import_wallet(path)
return self.wallets.add_from_path(path)
async def get_addresses(self, **constraints):
return await self.db.get_addresses(**constraints)
@ -123,11 +118,11 @@ class Service:
self.constraint_spending_utxos(constraints)
return self.db.get_utxos(**constraints)
async def get_txos(self, resolve=False, **constraints) -> Tuple[List[Output], Optional[int]]:
txos, count = await self.db.get_txos(**constraints)
async def get_txos(self, resolve=False, **constraints) -> Result[Output]:
txos = await self.db.get_txos(**constraints)
if resolve:
return await self._resolve_for_local_results(constraints.get('accounts', []), txos), count
return txos, count
return await self._resolve_for_local_results(constraints.get('accounts', []), txos)
return txos
def get_txo_sum(self, **constraints):
return self.db.get_txo_sum(**constraints)
@ -142,17 +137,17 @@ class Service:
tx = await self.db.get_transaction(tx_hash=tx_hash)
if tx:
return tx
try:
raw, merkle = await self.ledger.network.get_transaction_and_merkle(tx_hash)
except CodeMessageError as e:
if 'No such mempool or blockchain transaction.' in e.message:
return {'success': False, 'code': 404, 'message': 'transaction not found'}
return {'success': False, 'code': e.code, 'message': e.message}
height = merkle.get('block_height')
tx = Transaction(unhexlify(raw), height=height)
if height and height > 0:
await self.ledger.maybe_verify_transaction(tx, height, merkle)
return tx
# try:
# raw, merkle = await self.ledger.network.get_transaction_and_merkle(tx_hash)
# except CodeMessageError as e:
# if 'No such mempool or blockchain transaction.' in e.message:
# return {'success': False, 'code': 404, 'message': 'transaction not found'}
# return {'success': False, 'code': e.code, 'message': e.message}
# height = merkle.get('block_height')
# tx = Transaction(unhexlify(raw), height=height)
# if height and height > 0:
# await self.ledger.maybe_verify_transaction(tx, height, merkle)
# return tx
async def search_transactions(self, txids):
raise NotImplementedError
@ -162,16 +157,17 @@ class Service:
async def get_address_manager_for_address(self, address):
details = await self.db.get_address(address=address)
for account in self.accounts:
for wallet in self.wallets:
for account in wallet.accounts:
if account.id == details['account']:
return account.address_managers[details['chain']]
return None
async def reset(self):
self.ledger.config = {
self.ledger.conf = {
'auto_connect': True,
'default_servers': self.config.lbryum_servers,
'data_path': self.config.wallet_dir,
'default_servers': self.conf.lbryum_servers,
'data_path': self.conf.wallet_dir,
}
await self.ledger.stop()
await self.ledger.start()
@ -181,13 +177,13 @@ class Service:
return self.ledger.genesis_hash
return (await self.ledger.headers.hash(self.ledger.headers.height)).decode()
async def maybe_broadcast_or_release(self, tx, blocking=False, preview=False):
async def maybe_broadcast_or_release(self, tx, preview=False, no_wait=False):
if preview:
return await self.release_tx(tx)
try:
await self.broadcast(tx)
if blocking:
await self.wait(tx, timeout=None)
if not no_wait:
await self.wait(tx)
except Exception:
await self.release_tx(tx)
raise
@ -217,7 +213,7 @@ class Service:
if resolve:
claim_ids = [p.purchased_claim_id for p in purchases]
try:
resolved, _, _, _ = await self.claim_search([], claim_ids=claim_ids)
resolved, _, _ = await self.search_claims([], claim_ids=claim_ids)
except Exception as err:
if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise
@ -228,7 +224,7 @@ class Service:
purchase.purchased_claim = lookup.get(purchase.purchased_claim_id)
return purchases
async def _resolve_for_local_results(self, accounts, txos):
async def _resolve_for_local_results(self, accounts, txos: Result) -> Result:
results = []
response = await self.resolve(
accounts, [txo.permanent_url for txo in txos if txo.can_decode_claim]
@ -242,12 +238,13 @@ class Service:
if isinstance(resolved, dict) and 'error' in resolved:
txo.meta['error'] = resolved['error']
results.append(txo)
return results
txos.rows = results
return txos
async def resolve_collection(self, collection, offset=0, page_size=1):
claim_ids = collection.claim.collection.claims.ids[offset:page_size+offset]
try:
resolve_results, _, _, _ = await self.claim_search([], claim_ids=claim_ids)
resolve_results, _, _ = await self.search_claims([], claim_ids=claim_ids)
except Exception as err:
if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise

View file

@ -7,16 +7,16 @@ from aiohttp.web import GracefulExit
from aiohttp.web import Application, AppRunner, WebSocketResponse, TCPSite, Response
from aiohttp.http_websocket import WSMsgType, WSCloseCode
from lbry.extras.daemon.json_response_encoder import JSONResponseEncoder
from lbry.service.json_encoder import JSONResponseEncoder
from lbry.service.base import Service
from lbry.service.api import API
from lbry.console import Console
def jsonrpc_dumps_pretty(obj, **kwargs):
if not isinstance(obj, dict):
data = {"jsonrpc": "2.0", "error": obj.to_dict()}
else:
#if not isinstance(obj, dict):
# data = {"jsonrpc": "2.0", "error": obj.to_dict()}
#else:
data = {"jsonrpc": "2.0", "result": obj}
return json.dumps(data, cls=JSONResponseEncoder, sort_keys=True, indent=2, **kwargs) + "\n"
@ -34,7 +34,7 @@ class WebSocketLogHandler(logging.Handler):
'name': record.name,
'message': self.format(record)
})
except:
except Exception:
self.handleError(record)
@ -44,16 +44,16 @@ class WebSocketManager(WebSocketResponse):
super().__init__(*args, **kwargs)
def subscribe(self, requested: list, subscriptions):
for component in requested:
if component == '*':
for request in requested:
if request == '*':
for _, component in subscriptions.items():
for _, sockets in component.items():
sockets.add(self)
elif '.' not in component:
for _, sockets in subscriptions[component].items():
elif '.' not in request:
for _, sockets in subscriptions[request].items():
sockets.add(self)
elif component.count('.') == 1:
component, stream = component.split('.')
elif request.count('.') == 1:
component, stream = request.split('.')
subscriptions[component][stream].add(self)
def unsubscribe(self, subscriptions):
@ -72,6 +72,7 @@ class Daemon:
self.app = Application()
self.app['websockets'] = WeakSet()
self.app['subscriptions'] = {}
self.components = {}
#for component in components:
# streams = self.app['subscriptions'][component.name] = {}
# for event_name, event_stream in component.event_streams.items():
@ -86,12 +87,12 @@ class Daemon:
def run(self):
loop = asyncio.get_event_loop()
def exit():
def graceful_exit():
raise GracefulExit()
try:
loop.add_signal_handler(signal.SIGINT, exit)
loop.add_signal_handler(signal.SIGTERM, exit)
loop.add_signal_handler(signal.SIGINT, graceful_exit)
loop.add_signal_handler(signal.SIGTERM, graceful_exit)
except NotImplementedError:
pass # Not implemented on Windows

View file

@ -7,7 +7,7 @@ from typing import Optional, Iterable, Type
from aiohttp.client_exceptions import ContentTypeError
from lbry.error import InvalidExchangeRateResponseError, CurrencyConversionError
from lbry.utils import aiohttp_request
from lbry.wallet.dewies import lbc_to_dewies
from lbry.blockchain.dewies import lbc_to_dewies
log = logging.getLogger(__name__)

View file

@ -16,10 +16,10 @@ class FullNode(Service):
sync: BlockchainSync
def __init__(self, ledger: Ledger, db_url: str, chain: Lbrycrd = None):
super().__init__(ledger, db_url)
def __init__(self, ledger: Ledger, chain: Lbrycrd = None):
super().__init__(ledger)
self.chain = chain or Lbrycrd(ledger)
self.sync = BlockchainSync(self.chain, self.db, self.conf.processes)
self.sync = BlockchainSync(self.chain, self.db)
async def start(self):
await self.chain.open()
@ -39,7 +39,7 @@ class FullNode(Service):
tx_hashes = [unhexlify(txid)[::-1] for txid in txids]
return {
hexlify(tx['tx_hash'][::-1]).decode(): hexlify(tx['raw']).decode()
for tx in await self.db.get_raw_transactions(tx_hashes)
for tx in await self.db.get_transactions(tx_hashes=tx_hashes)
}
async def search_claims(self, accounts, **kwargs):

View file

@ -1,8 +1,10 @@
# pylint: disable=invalid-name
import logging
from decimal import Decimal
from binascii import hexlify, unhexlify
from datetime import datetime, date
from json import JSONEncoder
from typing import Iterator, Generic
from google.protobuf.message import DecodeError
@ -12,6 +14,7 @@ from lbry.blockchain.transaction import Transaction, Output
from lbry.crypto.bip32 import PubKey
from lbry.blockchain.dewies import dewies_to_lbc
from lbry.stream.managed_stream import ManagedStream
from lbry.db.database import Result, ResultType
log = logging.getLogger(__name__)
@ -123,6 +126,36 @@ def encode_pagination_doc(items):
}
class Paginated(Generic[ResultType]):
__slots__ = 'result', 'page', 'page_size'
def __init__(self, result: Result, page: int, page_size: int):
self.result = result
self.page = page
self.page_size = page_size
def __getitem__(self, item: int) -> ResultType:
return self.result[item]
def __iter__(self) -> Iterator[ResultType]:
return iter(self.result)
def __len__(self):
return len(self.result)
def __repr__(self):
return repr(self.to_dict())
def to_dict(self):
d = {"items": self.result.rows, "page": self.page, "page_size": self.page_size}
if self.result.total is not None:
count = self.result.total
d["total_pages"] = int((count + (self.page_size - 1)) / self.page_size)
d["total_items"] = count
return d
class JSONResponseEncoder(JSONEncoder):
def __init__(self, *args, service, include_protobuf=False, **kwargs):
@ -131,6 +164,8 @@ class JSONResponseEncoder(JSONEncoder):
self.include_protobuf = include_protobuf
def default(self, obj): # pylint: disable=method-hidden,arguments-differ,too-many-return-statements
if isinstance(obj, Paginated):
return obj.to_dict()
if isinstance(obj, Account):
return self.encode_account(obj)
if isinstance(obj, Wallet):
@ -257,7 +292,7 @@ class JSONResponseEncoder(JSONEncoder):
if isinstance(value, int):
meta[key] = dewies_to_lbc(value)
if 0 < meta.get('creation_height', 0) <= 0: #self.ledger.headers.height:
meta['creation_timestamp'] = self.ledger.headers.estimated_timestamp(meta['creation_height'])
meta['creation_timestamp'] = self.service.ledger.headers.estimated_timestamp(meta['creation_height'])
return meta
def encode_input(self, txi):
@ -266,7 +301,8 @@ class JSONResponseEncoder(JSONEncoder):
'nout': txi.txo_ref.position
}
def encode_account(self, account):
@staticmethod
def encode_account(account):
result = account.to_dict()
result['id'] = account.id
result.pop('certificates', None)

View file

@ -1,8 +1,7 @@
import logging
from lbry.conf import Config
from lbry.blockchain.ledger import Ledger
from lbry.db import Database
from lbry.blockchain import Ledger, Transaction
from lbry.wallet.sync import SPVSync
from .base import Service
@ -13,9 +12,9 @@ log = logging.getLogger(__name__)
class LightClient(Service):
def __init__(self, ledger: Ledger, db_url: str):
super().__init__(ledger, db_url)
self.client = Client(self, Config().api_connection_url)#ledger.conf)
def __init__(self, ledger: Ledger):
super().__init__(ledger)
self.client = Client(Config().api_connection_url)
self.sync = SPVSync(self)
async def search_transactions(self, txids):
@ -26,3 +25,15 @@ class LightClient(Service):
async def get_transaction_address_filters(self, block_hash):
return await self.client.address_transaction_filters(block_hash=block_hash)
async def broadcast(self, tx):
pass
async def wait(self, tx: Transaction, height=-1, timeout=1):
pass
async def resolve(self, accounts, urls, **kwargs):
pass
async def search_claims(self, accounts, **kwargs):
pass

File diff suppressed because one or more lines are too long

View file

@ -54,10 +54,6 @@ def parse_argument(tokens, method_name='') -> dict:
}
if arg['name'] == 'self':
return {}
try:
tokens[0]
except:
a = 9
if tokens[0].string == ':':
tokens.pop(0)
type_tokens = []
@ -68,7 +64,7 @@ def parse_argument(tokens, method_name='') -> dict:
tokens.pop(0)
default = tokens.pop(0)
if default.type == token.NAME:
default_value = eval(default.string)
default_value = eval(default.string) # pylint: disable=eval-used
if default_value is not None:
arg['default'] = default_value
elif default.type == token.NUMBER:
@ -137,7 +133,6 @@ def produce_return_tokens(src: str):
elif in_return:
if t.type == token.INDENT:
break
else:
parsed.append(t)
return parsed
@ -288,6 +283,7 @@ def get_api_definitions(cls):
def write(fp):
fp.write('# pylint: skip-file\n')
fp.write('# DO NOT EDIT: GENERATED FILE\n')
fp.write(f'interface = ')
defs = get_api_definitions(api.API)

View file

@ -22,7 +22,7 @@ if typing.TYPE_CHECKING:
from lbry.dht.node import Node
from lbry.extras.daemon.analytics import AnalyticsManager
from lbry.extras.daemon.storage import SQLiteStorage, StoredContentClaim
from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager
from lbry.service.exchange_rate_manager import ExchangeRateManager
from lbry.service.base import Service
log = logging.getLogger(__name__)

View file

@ -12,31 +12,20 @@ import unittest
from unittest.case import _Outcome
from typing import Optional
from binascii import unhexlify
from functools import partial
from lbry.wallet import WalletManager, Wallet, Account
from lbry.blockchain import (
RegTestLedger, Transaction, Input, Output, dewies_to_lbc
)
from lbry.blockchain.lbrycrd import Lbrycrd
from lbry.constants import CENT, NULL_HASH32
from lbry.service import Daemon, FullNode
from lbry.service import Daemon, FullNode, jsonrpc_dumps_pretty
from lbry.conf import Config
from lbry.console import Console
from lbry.wallet import Wallet, Account
from lbry.extras.daemon.daemon import jsonrpc_dumps_pretty
from lbry.extras.daemon.components import Component, WalletComponent
from lbry.extras.daemon.components import (
DHT_COMPONENT, HASH_ANNOUNCER_COMPONENT, PEER_PROTOCOL_SERVER_COMPONENT,
UPNP_COMPONENT, EXCHANGE_RATE_MANAGER_COMPONENT
)
from lbry.extras.daemon.componentmanager import ComponentManager
from lbry.extras.daemon.exchange_rate_manager import (
from lbry.service.exchange_rate_manager import (
ExchangeRateManager, ExchangeRate, LBRYFeed, LBRYBTCFeed
)
from lbry.extras.daemon.storage import SQLiteStorage
from lbry.blob.blob_manager import BlobManager
from lbry.stream.reflector.server import ReflectorServer
from lbry.blob_exchange.server import BlobServer
def get_output(amount=CENT, pubkey_hash=NULL_HASH32, height=-2):
@ -236,29 +225,14 @@ class IntegrationTestCase(AsyncioTestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.conductor: Optional[Conductor] = None
self.blockchain: Optional[BlockchainNode] = None
self.wallet_node: Optional[WalletNode] = None
self.manager: Optional[WalletManager] = None
self.ledger: Optional['Ledger'] = None
self.ledger: Optional[RegTestLedger] = None
self.chain: Optional[Lbrycrd] = None
self.block_expected = 0
self.service = None
self.api = None
self.wallet: Optional[Wallet] = None
self.account: Optional[Account] = None
async def asyncSetUp(self):
self.conductor = Conductor(seed=self.SEED)
await self.conductor.start_blockchain()
self.addCleanup(self.conductor.stop_blockchain)
await self.conductor.start_spv()
self.addCleanup(self.conductor.stop_spv)
await self.conductor.start_wallet()
self.addCleanup(self.conductor.stop_wallet)
self.blockchain = self.conductor.blockchain_node
self.wallet_node = self.conductor.wallet_node
self.manager = self.wallet_node.manager
self.ledger = self.wallet_node.ledger
self.wallet = self.wallet_node.wallet
self.account = self.wallet_node.wallet.default_account
async def assertBalance(self, account, expected_balance: str): # pylint: disable=C0103
balance = await account.get_balance()
self.assertEqual(dewies_to_lbc(balance), expected_balance)
@ -316,24 +290,6 @@ def get_fake_exchange_rate_manager(rates=None):
)
class ExchangeRateManagerComponent(Component):
component_name = EXCHANGE_RATE_MANAGER_COMPONENT
def __init__(self, component_manager, rates=None):
super().__init__(component_manager)
self.exchange_rate_manager = get_fake_exchange_rate_manager(rates)
@property
def component(self) -> ExchangeRateManager:
return self.exchange_rate_manager
async def start(self):
self.exchange_rate_manager.start()
async def stop(self):
self.exchange_rate_manager.stop()
class CommandTestCase(IntegrationTestCase):
VERBOSITY = logging.WARN
@ -358,7 +314,6 @@ class CommandTestCase(IntegrationTestCase):
self.addCleanup(self.chain.stop)
await self.chain.start('-rpcworkqueue=128')
self.block_expected = 0
await self.generate(200, wait=False)
self.daemon = await self.add_daemon()
@ -382,47 +337,16 @@ class CommandTestCase(IntegrationTestCase):
self.addCleanup(shutil.rmtree, path, True)
ledger = RegTestLedger(Config.with_same_dir(path).set(
api=f'localhost:{self.daemon_port}',
lbrycrd_dir=self.chain.ledger.conf.lbrycrd_dir,
spv_address_filters=False
))
db_url = f"sqlite:///{os.path.join(path,'full_node.db')}"
service = FullNode(ledger, db_url, Lbrycrd(self.chain.ledger))
daemon = Daemon(service)
service = FullNode(ledger)
console = Console(service)
daemon = Daemon(service, console)
self.addCleanup(daemon.stop)
await daemon.start()
return daemon
async def XasyncSetUp(self):
await super().asyncSetUp()
logging.getLogger('lbry.blob_exchange').setLevel(self.VERBOSITY)
logging.getLogger('lbry.daemon').setLevel(self.VERBOSITY)
logging.getLogger('lbry.stream').setLevel(self.VERBOSITY)
logging.getLogger('lbry.wallet').setLevel(self.VERBOSITY)
self.daemon = await self.add_daemon(self.wallet_node)
await self.account.ensure_address_gap()
address = (await self.account.receiving.get_addresses(limit=1, only_usable=True))[0]
sendtxid = await self.blockchain.send_to_address(address, 10)
await self.confirm_tx(sendtxid)
await self.generate(5)
server_tmp_dir = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, server_tmp_dir)
self.server_config = Config()
self.server_storage = SQLiteStorage(self.server_config, ':memory:')
await self.server_storage.open()
self.server_blob_manager = BlobManager(self.loop, server_tmp_dir, self.server_storage, self.server_config)
self.server = BlobServer(self.loop, self.server_blob_manager, 'bQEaw42GXsgCAGio1nxFncJSyRmnztSCjP')
self.server.start_server(5567, '127.0.0.1')
await self.server.started_listening.wait()
self.reflector = ReflectorServer(self.server_blob_manager)
self.reflector.start_server(5566, '127.0.0.1')
await self.reflector.started_listening.wait()
self.addCleanup(self.reflector.stop_server)
async def asyncTearDown(self):
await super().asyncTearDown()
for wallet_node in self.extra_wallet_nodes:
@ -431,56 +355,6 @@ class CommandTestCase(IntegrationTestCase):
daemon.component_manager.get_component('wallet')._running = False
await daemon.stop()
async def Xadd_daemon(self, wallet_node=None, seed=None):
if wallet_node is None:
wallet_node = WalletNode(
self.wallet_node.manager_class,
self.wallet_node.ledger_class,
port=self.extra_wallet_node_port
)
self.extra_wallet_node_port += 1
await wallet_node.start(self.conductor.spv_node, seed=seed)
self.extra_wallet_nodes.append(wallet_node)
upload_dir = os.path.join(wallet_node.data_path, 'uploads')
os.mkdir(upload_dir)
conf = Config()
conf.data_dir = wallet_node.data_path
conf.wallet_dir = wallet_node.data_path
conf.download_dir = wallet_node.data_path
conf.upload_dir = upload_dir # not a real conf setting
conf.share_usage_data = False
conf.use_upnp = False
conf.reflect_streams = True
conf.blockchain_name = 'lbrycrd_regtest'
conf.lbryum_servers = [('127.0.0.1', 50001)]
conf.reflector_servers = [('127.0.0.1', 5566)]
conf.known_dht_nodes = []
conf.blob_lru_cache_size = self.blob_lru_cache_size
conf.components_to_skip = [
DHT_COMPONENT, UPNP_COMPONENT, HASH_ANNOUNCER_COMPONENT,
PEER_PROTOCOL_SERVER_COMPONENT
]
wallet_node.manager.config = conf
def wallet_maker(component_manager):
wallet_component = WalletComponent(component_manager)
wallet_component.wallet_manager = wallet_node.manager
wallet_component._running = True
return wallet_component
daemon = Daemon(conf, ComponentManager(
conf, skip_components=conf.components_to_skip, wallet=wallet_maker,
exchange_rate_manager=partial(ExchangeRateManagerComponent, rates={
'BTCLBC': 1.0, 'USDBTC': 2.0
})
))
await daemon.initialize()
self.daemons.append(daemon)
wallet_node.manager.old_db = daemon.storage
return daemon
async def confirm_tx(self, txid, ledger=None):
""" Wait for tx to be in mempool, then generate a block, wait for tx to be in a block. """
await self.on_transaction_id(txid, ledger)
@ -500,8 +374,8 @@ class CommandTestCase(IntegrationTestCase):
addresses.add(txo['address'])
return list(addresses)
def is_expected_block(self, b):
return self.block_expected == b.height
def is_expected_block(self, event):
return self.block_expected == event.height
async def generate(self, blocks, wait=True):
""" Ask lbrycrd to generate some blocks and wait until ledger has them. """
@ -510,18 +384,6 @@ class CommandTestCase(IntegrationTestCase):
if wait:
await self.service.sync.on_block.where(self.is_expected_block)
async def blockchain_claim_name(self, name: str, value: str, amount: str, confirm=True):
txid = await self.blockchain._cli_cmnd('claimname', name, value, amount)
if confirm:
await self.generate(1)
return txid
async def blockchain_update_name(self, txid: str, value: str, amount: str, confirm=True):
txid = await self.blockchain._cli_cmnd('updateclaim', txid, value, amount)
if confirm:
await self.generate(1)
return txid
async def out(self, awaitable):
""" Serializes lbrynet API results to JSON then loads and returns it as dictionary. """
return json.loads(jsonrpc_dumps_pretty(await awaitable, service=self.service))['result']
@ -533,7 +395,6 @@ class CommandTestCase(IntegrationTestCase):
async def confirm_and_render(self, awaitable, confirm) -> Transaction:
tx = await awaitable
if confirm:
await self.service.wait(tx)
await self.generate(1)
await self.service.wait(tx)
return self.sout(tx)
@ -657,7 +518,7 @@ class CommandTestCase(IntegrationTestCase):
if confirm:
await asyncio.wait([self.ledger.wait(tx) for tx in txs])
await self.generate(1)
await asyncio.wait([self.ledger.wait(tx, self.blockchain.block_expected) for tx in txs])
await asyncio.wait([self.ledger.wait(tx, self.block_expected) for tx in txs])
return self.sout(txs)
async def resolve(self, uri, **kwargs):

View file

@ -54,10 +54,8 @@ def user_data_dir(appname=None, appauthor=None, version=None, roaming=False):
Typical user data directories are:
Mac OS X: ~/Library/Application Support/<AppName>
Unix: ~/.local/share/<AppName> # or in $XDG_DATA_HOME, if defined
Win XP (not roaming): C:\Documents and Settings\<username>\Application Data\<AppAuthor>\<AppName>
Win XP (roaming): C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>
Win 7 (not roaming): C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>
Win 7 (roaming): C:\Users\<username>\AppData\Roaming\<AppAuthor>\<AppName>
Win XP: C:\Documents and Settings\<username>\Application Data\<AppAuthor>\<AppName>
Win 7: C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>
For Unix, we follow the XDG spec and support $XDG_DATA_HOME.
That means, by default "~/.local/share/<AppName>".
@ -65,7 +63,7 @@ def user_data_dir(appname=None, appauthor=None, version=None, roaming=False):
if sys.platform == "win32":
if appauthor is None:
appauthor = appname
const = roaming and "CSIDL_APPDATA" or "CSIDL_LOCAL_APPDATA"
const = "CSIDL_APPDATA" if roaming else "CSIDL_LOCAL_APPDATA"
path = os.path.normpath(_get_win_folder(const))
if appname:
if appauthor is not False:
@ -130,7 +128,7 @@ def user_config_dir(appname=None, appauthor=None, version=None, roaming=False):
def _get_win_folder(csidl_name):
import ctypes
import ctypes # pylint: disable=import-outside-toplevel
csidl_const = {
"CSIDL_APPDATA": 26,
@ -157,9 +155,9 @@ def _get_win_folder(csidl_name):
def _get_win_download_folder():
import ctypes
from ctypes import windll, wintypes
from uuid import UUID
import ctypes # pylint: disable=import-outside-toplevel
from ctypes import windll, wintypes # pylint: disable=import-outside-toplevel
from uuid import UUID # pylint: disable=import-outside-toplevel
class GUID(ctypes.Structure):
_fields_ = [
@ -177,12 +175,12 @@ def _get_win_download_folder():
for i in range(2, 8):
self.data4[i] = rest >> (8-i-1)*8 & 0xff
SHGetKnownFolderPath = windll.shell32.SHGetKnownFolderPath
SHGetKnownFolderPath = windll.shell32.SHGetKnownFolderPath # pylint: disable=invalid-name
SHGetKnownFolderPath.argtypes = [
ctypes.POINTER(GUID), wintypes.DWORD, wintypes.HANDLE, ctypes.POINTER(ctypes.c_wchar_p)
]
FOLDERID_Downloads = '{374DE290-123F-4565-9164-39C4925E467B}'
FOLDERID_Downloads = '{374DE290-123F-4565-9164-39C4925E467B}' # pylint: disable=invalid-name
guid = GUID(FOLDERID_Downloads)
pathptr = ctypes.c_wchar_p()

View file

@ -12,7 +12,7 @@ import ecdsa
from lbry.constants import COIN
from lbry.db import Database, CLAIM_TYPE_CODES, TXO_TYPES
from lbry.blockchain import Ledger, Transaction, Input, Output
from lbry.blockchain import Ledger
from lbry.error import InvalidPasswordError
from lbry.crypto.crypt import aes_encrypt, aes_decrypt
from lbry.crypto.bip32 import PrivateKey, PubKey, from_extended_key_string
@ -29,7 +29,7 @@ class AddressManager:
__slots__ = 'account', 'public_key', 'chain_number', 'address_generator_lock'
def __init__(self, account, public_key, chain_number):
def __init__(self, account: 'Account', public_key, chain_number):
self.account = account
self.public_key = public_key
self.chain_number = chain_number
@ -57,11 +57,11 @@ class AddressManager:
raise NotImplementedError
async def _query_addresses(self, **constraints):
return (await self.account.db.get_addresses(
return await self.account.db.get_addresses(
account=self.account,
chain=self.chain_number,
**constraints
))[0]
)
def get_private_key(self, index: int) -> PrivateKey:
raise NotImplementedError
@ -415,7 +415,7 @@ class Account:
return await self.db.get_addresses(account=self, **constraints)
async def get_addresses(self, **constraints) -> List[str]:
rows, _ = await self.get_address_records(cols=['account_address.address'], **constraints)
rows = await self.get_address_records(cols=['account_address.address'], **constraints)
return [r['address'] for r in rows]
async def get_valid_receiving_address(self, default_address: str) -> str:
@ -447,40 +447,6 @@ class Account:
'max_receiving_gap': receiving_gap,
}
async def fund(self, to_account, amount=None, everything=False,
outputs=1, broadcast=False, **constraints):
assert self.ledger == to_account.ledger, 'Can only transfer between accounts of the same ledger.'
if everything:
utxos = await self.get_utxos(**constraints)
await self.ledger.reserve_outputs(utxos)
tx = await Transaction.create(
inputs=[Input.spend(txo) for txo in utxos],
outputs=[],
funding_accounts=[self],
change_account=to_account
)
elif amount > 0:
to_address = await to_account.change.get_or_create_usable_address()
to_hash160 = to_account.ledger.address_to_hash160(to_address)
tx = await Transaction.create(
inputs=[],
outputs=[
Output.pay_pubkey_hash(amount//outputs, to_hash160)
for _ in range(outputs)
],
funding_accounts=[self],
change_account=self
)
else:
raise ValueError('An amount is required.')
if broadcast:
await self.ledger.broadcast(tx)
else:
await self.ledger.release_tx(tx)
return tx
def add_channel_private_key(self, private_key):
public_key_bytes = private_key.get_verifying_key().to_der()
channel_pubkey_hash = self.ledger.public_key_to_address(public_key_bytes)
@ -516,7 +482,6 @@ class Account:
channel_keys[self.ledger.public_key_to_address(public_key_der)] = private_key_pem
if self.channel_keys != channel_keys:
self.channel_keys = channel_keys
self.wallet.save()
async def save_max_gap(self):
gap_changed = False

View file

@ -60,12 +60,12 @@ class WalletManager:
wallets_directory = self.path
for wallet_id in self.ledger.conf.wallets:
if wallet_id in self.wallets:
log.warning(f"Ignoring duplicate wallet_id in config: {wallet_id}")
log.warning("Ignoring duplicate wallet_id in config: %s", wallet_id)
continue
wallet_path = os.path.join(wallets_directory, wallet_id)
if not os.path.exists(wallet_path):
if not wallet_id == "default_wallet": # we'll probably generate this wallet, don't show error
log.error(f"Could not load wallet, file does not exist: {wallet_path}")
log.error("Could not load wallet, file does not exist: %s", wallet_path)
continue
wallet = await Wallet.from_path(self.ledger, self.db, wallet_path)
self.add(wallet)

View file

@ -31,11 +31,11 @@ def sync_generate_phrase(language: str) -> str:
while True:
nonce += 1
i = entropy + nonce
w = []
word_buffer = []
while i:
w.append(local_words[i % 2048])
word_buffer.append(local_words[i % 2048])
i //= 2048
seed = ' '.join(w)
seed = ' '.join(word_buffer)
if hexlify(hmac_sha512(b"Seed version", seed.encode())).startswith(b"01"):
break
return seed

View file

@ -1,24 +1,24 @@
import asyncio
import logging
from io import StringIO
from functools import partial
from operator import itemgetter
#from io import StringIO
#from functools import partial
#from operator import itemgetter
from collections import defaultdict
from binascii import hexlify, unhexlify
#from binascii import hexlify, unhexlify
from typing import List, Optional, DefaultDict, NamedTuple
from lbry.crypto.hash import double_sha256, sha256
#from lbry.crypto.hash import double_sha256, sha256
from lbry.service.api import Client
from lbry.tasks import TaskGroup
from lbry.blockchain.transaction import Transaction
from lbry.blockchain.ledger import Ledger
from lbry.blockchain.block import get_block_filter
from lbry.db import Database
#from lbry.blockchain.block import get_block_filter
from lbry.event import EventController
from lbry.service.base import Service, Sync
from .account import Account, AddressManager
from .account import AddressManager
log = logging.getLogger(__name__)
class TransactionEvent(NamedTuple):
@ -54,24 +54,16 @@ class TransactionCacheItem:
class SPVSync(Sync):
def __init__(self, service: Service):
super().__init__(service)
return
self.headers = headers
self.network: Network = self.config.get('network') or Network(self)
self.network.on_header.listen(self.receive_header)
self.network.on_status.listen(self.process_status_update)
self.network.on_connected.listen(self.join_network)
super().__init__(service.ledger, service.db)
self.accounts = []
self.on_address = self.ledger.on_address
self._on_header_controller = EventController()
self.on_header = self._on_header_controller.stream
self.on_header.listen(
lambda change: log.info(
'%s: added %s header blocks, final height %s',
self.ledger.get_id(), change, self.headers.height
'%s: added %s header blocks',
self.ledger.get_id(), change
)
)
self._download_height = 0
@ -86,345 +78,343 @@ class SPVSync(Sync):
self._address_update_locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self._known_addresses_out_of_sync = set()
async def advance(self):
address_array = [
bytearray(a['address'].encode())
for a in await self.service.db.get_all_addresses()
]
block_filters = await self.service.get_block_address_filters()
for block_hash, block_filter in block_filters.items():
bf = get_block_filter(block_filter)
if bf.MatchAny(address_array):
print(f'match: {block_hash} - {block_filter}')
tx_filters = await self.service.get_transaction_address_filters(block_hash=block_hash)
for txid, tx_filter in tx_filters.items():
tf = get_block_filter(tx_filter)
if tf.MatchAny(address_array):
print(f' match: {txid} - {tx_filter}')
txs = await self.service.search_transactions([txid])
tx = Transaction(unhexlify(txs[txid]))
await self.service.db.insert_transaction(tx)
async def get_local_status_and_history(self, address, history=None):
if not history:
address_details = await self.db.get_address(address=address)
history = (address_details['history'] if address_details else '') or ''
parts = history.split(':')[:-1]
return (
hexlify(sha256(history.encode())).decode() if history else None,
list(zip(parts[0::2], map(int, parts[1::2])))
)
@staticmethod
def get_root_of_merkle_tree(branches, branch_positions, working_branch):
for i, branch in enumerate(branches):
other_branch = unhexlify(branch)[::-1]
other_branch_on_left = bool((branch_positions >> i) & 1)
if other_branch_on_left:
combined = other_branch + working_branch
else:
combined = working_branch + other_branch
working_branch = double_sha256(combined)
return hexlify(working_branch[::-1])
# async def advance(self):
# address_array = [
# bytearray(a['address'].encode())
# for a in await self.service.db.get_all_addresses()
# ]
# block_filters = await self.service.get_block_address_filters()
# for block_hash, block_filter in block_filters.items():
# bf = get_block_filter(block_filter)
# if bf.MatchAny(address_array):
# print(f'match: {block_hash} - {block_filter}')
# tx_filters = await self.service.get_transaction_address_filters(block_hash=block_hash)
# for txid, tx_filter in tx_filters.items():
# tf = get_block_filter(tx_filter)
# if tf.MatchAny(address_array):
# print(f' match: {txid} - {tx_filter}')
# txs = await self.service.search_transactions([txid])
# tx = Transaction(unhexlify(txs[txid]))
# await self.service.db.insert_transaction(tx)
#
# async def get_local_status_and_history(self, address, history=None):
# if not history:
# address_details = await self.db.get_address(address=address)
# history = (address_details['history'] if address_details else '') or ''
# parts = history.split(':')[:-1]
# return (
# hexlify(sha256(history.encode())).decode() if history else None,
# list(zip(parts[0::2], map(int, parts[1::2])))
# )
#
# @staticmethod
# def get_root_of_merkle_tree(branches, branch_positions, working_branch):
# for i, branch in enumerate(branches):
# other_branch = unhexlify(branch)[::-1]
# other_branch_on_left = bool((branch_positions >> i) & 1)
# if other_branch_on_left:
# combined = other_branch + working_branch
# else:
# combined = working_branch + other_branch
# working_branch = double_sha256(combined)
# return hexlify(working_branch[::-1])
#
async def start(self):
await self.headers.open()
fully_synced = self.on_ready.first
asyncio.create_task(self.network.start())
await self.network.on_connected.first
async with self._header_processing_lock:
await self._update_tasks.add(self.initial_headers_sync())
#asyncio.create_task(self.network.start())
#await self.network.on_connected.first
#async with self._header_processing_lock:
# await self._update_tasks.add(self.initial_headers_sync())
await fully_synced
async def join_network(self, *_):
log.info("Subscribing and updating accounts.")
await self._update_tasks.add(self.subscribe_accounts())
await self._update_tasks.done.wait()
self._on_ready_controller.add(True)
#
# async def join_network(self, *_):
# log.info("Subscribing and updating accounts.")
# await self._update_tasks.add(self.subscribe_accounts())
# await self._update_tasks.done.wait()
# self._on_ready_controller.add(True)
#
async def stop(self):
self._update_tasks.cancel()
self._other_tasks.cancel()
await self._update_tasks.done.wait()
await self._other_tasks.done.wait()
await self.network.stop()
await self.headers.close()
@property
def local_height_including_downloaded_height(self):
return max(self.headers.height, self._download_height)
async def initial_headers_sync(self):
get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=1000, b64=True)
self.headers.chunk_getter = get_chunk
async def doit():
for height in reversed(sorted(self.headers.known_missing_checkpointed_chunks)):
async with self._header_processing_lock:
await self.headers.ensure_chunk_at(height)
self._other_tasks.add(doit())
await self.update_headers()
async def update_headers(self, height=None, headers=None, subscription_update=False):
rewound = 0
while True:
if height is None or height > len(self.headers):
# sometimes header subscription updates are for a header in the future
# which can't be connected, so we do a normal header sync instead
height = len(self.headers)
headers = None
subscription_update = False
if not headers:
header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
headers = header_response['hex']
if not headers:
# Nothing to do, network thinks we're already at the latest height.
return
added = await self.headers.connect(height, unhexlify(headers))
if added > 0:
height += added
self._on_header_controller.add(
BlockHeightEvent(self.headers.height, added))
if rewound > 0:
# we started rewinding blocks and apparently found
# a new chain
rewound = 0
await self.db.rewind_blockchain(height)
if subscription_update:
# subscription updates are for latest header already
# so we don't need to check if there are newer / more
# on another loop of update_headers(), just return instead
return
elif added == 0:
# we had headers to connect but none got connected, probably a reorganization
height -= 1
rewound += 1
log.warning(
"Blockchain Reorganization: attempting rewind to height %s from starting height %s",
height, height+rewound
)
else:
raise IndexError(f"headers.connect() returned negative number ({added})")
if height < 0:
raise IndexError(
"Blockchain reorganization rewound all the way back to genesis hash. "
"Something is very wrong. Maybe you are on the wrong blockchain?"
)
if rewound >= 100:
raise IndexError(
"Blockchain reorganization dropped {} headers. This is highly unusual. "
"Will not continue to attempt reorganizing. Please, delete the ledger "
"synchronization directory inside your wallet directory (folder: '{}') and "
"restart the program to synchronize from scratch."
.format(rewound, self.ledger.get_id())
)
headers = None # ready to download some more headers
# if we made it this far and this was a subscription_update
# it means something went wrong and now we're doing a more
# robust sync, turn off subscription update shortcut
subscription_update = False
async def receive_header(self, response):
async with self._header_processing_lock:
header = response[0]
await self.update_headers(
height=header['height'], headers=header['hex'], subscription_update=True
)
async def subscribe_accounts(self):
if self.network.is_connected and self.accounts:
log.info("Subscribe to %i accounts", len(self.accounts))
await asyncio.wait([
self.subscribe_account(a) for a in self.accounts
])
async def subscribe_account(self, account: Account):
for address_manager in account.address_managers.values():
await self.subscribe_addresses(address_manager, await address_manager.get_addresses())
await account.ensure_address_gap()
async def unsubscribe_account(self, account: Account):
for address in await account.get_addresses():
await self.network.unsubscribe_address(address)
async def subscribe_addresses(self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000):
if self.network.is_connected and addresses:
addresses_remaining = list(addresses)
while addresses_remaining:
batch = addresses_remaining[:batch_size]
results = await self.network.subscribe_address(*batch)
for address, remote_status in zip(batch, results):
self._update_tasks.add(self.update_history(address, remote_status, address_manager))
addresses_remaining = addresses_remaining[batch_size:]
log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining),
len(addresses), *self.network.client.server_address_and_port)
log.info(
"finished subscribing to %i addresses on %s:%i", len(addresses),
*self.network.client.server_address_and_port
)
def process_status_update(self, update):
address, remote_status = update
self._update_tasks.add(self.update_history(address, remote_status))
async def update_history(self, address, remote_status, address_manager: AddressManager = None):
async with self._address_update_locks[address]:
self._known_addresses_out_of_sync.discard(address)
local_status, local_history = await self.get_local_status_and_history(address)
if local_status == remote_status:
return True
remote_history = await self.network.retriable_call(self.network.get_history, address)
remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
we_need = set(remote_history) - set(local_history)
if not we_need:
return True
cache_tasks: List[asyncio.Task[Transaction]] = []
synced_history = StringIO()
loop = asyncio.get_running_loop()
for i, (txid, remote_height) in enumerate(remote_history):
if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks:
synced_history.write(f'{txid}:{remote_height}:')
else:
check_local = (txid, remote_height) not in we_need
cache_tasks.append(loop.create_task(
self.cache_transaction(unhexlify(txid)[::-1], remote_height, check_local=check_local)
))
synced_txs = []
for task in cache_tasks:
tx = await task
check_db_for_txos = []
for txi in tx.inputs:
if txi.txo_ref.txo is not None:
continue
cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.hash)
if cache_item is not None:
if cache_item.tx is None:
await cache_item.has_tx.wait()
assert cache_item.tx is not None
txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
else:
check_db_for_txos.append(txi.txo_ref.hash)
referenced_txos = {} if not check_db_for_txos else {
txo.id: txo for txo in await self.db.get_txos(
txo_hash__in=check_db_for_txos, order_by='txo.txo_hash', no_tx=True
)
}
for txi in tx.inputs:
if txi.txo_ref.txo is not None:
continue
referenced_txo = referenced_txos.get(txi.txo_ref.id)
if referenced_txo is not None:
txi.txo_ref = referenced_txo.ref
synced_history.write(f'{tx.id}:{tx.height}:')
synced_txs.append(tx)
await self.db.save_transaction_io_batch(
synced_txs, address, self.ledger.address_to_hash160(address), synced_history.getvalue()
)
await asyncio.wait([
self.ledger._on_transaction_controller.add(TransactionEvent(address, tx))
for tx in synced_txs
])
if address_manager is None:
address_manager = await self.get_address_manager_for_address(address)
if address_manager is not None:
await address_manager.ensure_address_gap()
local_status, local_history = \
await self.get_local_status_and_history(address, synced_history.getvalue())
if local_status != remote_status:
if local_history == remote_history:
return True
log.warning(
"Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items",
remote_status, len(remote_history), local_status, len(local_history)
)
log.warning("local: %s", local_history)
log.warning("remote: %s", remote_history)
self._known_addresses_out_of_sync.add(address)
return False
else:
return True
async def cache_transaction(self, tx_hash, remote_height, check_local=True):
cache_item = self._tx_cache.get(tx_hash)
if cache_item is None:
cache_item = self._tx_cache[tx_hash] = TransactionCacheItem()
elif cache_item.tx is not None and \
cache_item.tx.height >= remote_height and \
(cache_item.tx.is_verified or remote_height < 1):
return cache_item.tx # cached tx is already up-to-date
try:
cache_item.pending_verifications += 1
return await self._update_cache_item(cache_item, tx_hash, remote_height, check_local)
finally:
cache_item.pending_verifications -= 1
async def _update_cache_item(self, cache_item, tx_hash, remote_height, check_local=True):
async with cache_item.lock:
tx = cache_item.tx
if tx is None and check_local:
# check local db
tx = cache_item.tx = await self.db.get_transaction(tx_hash=tx_hash)
merkle = None
if tx is None:
# fetch from network
_raw, merkle = await self.network.retriable_call(
self.network.get_transaction_and_merkle, tx_hash, remote_height
)
tx = Transaction(unhexlify(_raw), height=merkle.get('block_height'))
cache_item.tx = tx # make sure it's saved before caching it
await self.maybe_verify_transaction(tx, remote_height, merkle)
return tx
async def maybe_verify_transaction(self, tx, remote_height, merkle=None):
tx.height = remote_height
cached = self._tx_cache.get(tx.hash)
if not cached:
# cache txs looked up by transaction_show too
cached = TransactionCacheItem()
cached.tx = tx
self._tx_cache[tx.hash] = cached
if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1:
# can't be tx.pending_verifications == 1 because we have to handle the transaction_show case
if not merkle:
merkle = await self.network.retriable_call(self.network.get_merkle, tx.hash, remote_height)
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
header = await self.headers.get(remote_height)
tx.position = merkle['pos']
tx.is_verified = merkle_root == header['merkle_root']
async def get_address_manager_for_address(self, address) -> Optional[AddressManager]:
details = await self.db.get_address(address=address)
for account in self.accounts:
if account.id == details['account']:
return account.address_managers[details['chain']]
return None
#
# @property
# def local_height_including_downloaded_height(self):
# return max(self.headers.height, self._download_height)
#
# async def initial_headers_sync(self):
# get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=1000, b64=True)
# self.headers.chunk_getter = get_chunk
#
# async def doit():
# for height in reversed(sorted(self.headers.known_missing_checkpointed_chunks)):
# async with self._header_processing_lock:
# await self.headers.ensure_chunk_at(height)
# self._other_tasks.add(doit())
# await self.update_headers()
#
# async def update_headers(self, height=None, headers=None, subscription_update=False):
# rewound = 0
# while True:
#
# if height is None or height > len(self.headers):
# # sometimes header subscription updates are for a header in the future
# # which can't be connected, so we do a normal header sync instead
# height = len(self.headers)
# headers = None
# subscription_update = False
#
# if not headers:
# header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
# headers = header_response['hex']
#
# if not headers:
# # Nothing to do, network thinks we're already at the latest height.
# return
#
# added = await self.headers.connect(height, unhexlify(headers))
# if added > 0:
# height += added
# self._on_header_controller.add(
# BlockHeightEvent(self.headers.height, added))
#
# if rewound > 0:
# # we started rewinding blocks and apparently found
# # a new chain
# rewound = 0
# await self.db.rewind_blockchain(height)
#
# if subscription_update:
# # subscription updates are for latest header already
# # so we don't need to check if there are newer / more
# # on another loop of update_headers(), just return instead
# return
#
# elif added == 0:
# # we had headers to connect but none got connected, probably a reorganization
# height -= 1
# rewound += 1
# log.warning(
# "Blockchain Reorganization: attempting rewind to height %s from starting height %s",
# height, height+rewound
# )
#
# else:
# raise IndexError(f"headers.connect() returned negative number ({added})")
#
# if height < 0:
# raise IndexError(
# "Blockchain reorganization rewound all the way back to genesis hash. "
# "Something is very wrong. Maybe you are on the wrong blockchain?"
# )
#
# if rewound >= 100:
# raise IndexError(
# "Blockchain reorganization dropped {} headers. This is highly unusual. "
# "Will not continue to attempt reorganizing. Please, delete the ledger "
# "synchronization directory inside your wallet directory (folder: '{}') and "
# "restart the program to synchronize from scratch."
# .format(rewound, self.ledger.get_id())
# )
#
# headers = None # ready to download some more headers
#
# # if we made it this far and this was a subscription_update
# # it means something went wrong and now we're doing a more
# # robust sync, turn off subscription update shortcut
# subscription_update = False
#
# async def receive_header(self, response):
# async with self._header_processing_lock:
# header = response[0]
# await self.update_headers(
# height=header['height'], headers=header['hex'], subscription_update=True
# )
#
# async def subscribe_accounts(self):
# if self.network.is_connected and self.accounts:
# log.info("Subscribe to %i accounts", len(self.accounts))
# await asyncio.wait([
# self.subscribe_account(a) for a in self.accounts
# ])
#
# async def subscribe_account(self, account: Account):
# for address_manager in account.address_managers.values():
# await self.subscribe_addresses(address_manager, await address_manager.get_addresses())
# await account.ensure_address_gap()
#
# async def unsubscribe_account(self, account: Account):
# for address in await account.get_addresses():
# await self.network.unsubscribe_address(address)
#
# async def subscribe_addresses(
# self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000):
# if self.network.is_connected and addresses:
# addresses_remaining = list(addresses)
# while addresses_remaining:
# batch = addresses_remaining[:batch_size]
# results = await self.network.subscribe_address(*batch)
# for address, remote_status in zip(batch, results):
# self._update_tasks.add(self.update_history(address, remote_status, address_manager))
# addresses_remaining = addresses_remaining[batch_size:]
# log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining),
# len(addresses), *self.network.client.server_address_and_port)
# log.info(
# "finished subscribing to %i addresses on %s:%i", len(addresses),
# *self.network.client.server_address_and_port
# )
#
# def process_status_update(self, update):
# address, remote_status = update
# self._update_tasks.add(self.update_history(address, remote_status))
#
# async def update_history(self, address, remote_status, address_manager: AddressManager = None):
# async with self._address_update_locks[address]:
# self._known_addresses_out_of_sync.discard(address)
#
# local_status, local_history = await self.get_local_status_and_history(address)
#
# if local_status == remote_status:
# return True
#
# remote_history = await self.network.retriable_call(self.network.get_history, address)
# remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
# we_need = set(remote_history) - set(local_history)
# if not we_need:
# return True
#
# cache_tasks: List[asyncio.Task[Transaction]] = []
# synced_history = StringIO()
# loop = asyncio.get_running_loop()
# for i, (txid, remote_height) in enumerate(remote_history):
# if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks:
# synced_history.write(f'{txid}:{remote_height}:')
# else:
# check_local = (txid, remote_height) not in we_need
# cache_tasks.append(loop.create_task(
# self.cache_transaction(unhexlify(txid)[::-1], remote_height, check_local=check_local)
# ))
#
# synced_txs = []
# for task in cache_tasks:
# tx = await task
#
# check_db_for_txos = []
# for txi in tx.inputs:
# if txi.txo_ref.txo is not None:
# continue
# cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.hash)
# if cache_item is not None:
# if cache_item.tx is None:
# await cache_item.has_tx.wait()
# assert cache_item.tx is not None
# txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
# else:
# check_db_for_txos.append(txi.txo_ref.hash)
#
# referenced_txos = {} if not check_db_for_txos else {
# txo.id: txo for txo in await self.db.get_txos(
# txo_hash__in=check_db_for_txos, order_by='txo.txo_hash', no_tx=True
# )
# }
#
# for txi in tx.inputs:
# if txi.txo_ref.txo is not None:
# continue
# referenced_txo = referenced_txos.get(txi.txo_ref.id)
# if referenced_txo is not None:
# txi.txo_ref = referenced_txo.ref
#
# synced_history.write(f'{tx.id}:{tx.height}:')
# synced_txs.append(tx)
#
# await self.db.save_transaction_io_batch(
# synced_txs, address, self.ledger.address_to_hash160(address), synced_history.getvalue()
# )
# await asyncio.wait([
# self.ledger._on_transaction_controller.add(TransactionEvent(address, tx))
# for tx in synced_txs
# ])
#
# if address_manager is None:
# address_manager = await self.get_address_manager_for_address(address)
#
# if address_manager is not None:
# await address_manager.ensure_address_gap()
#
# local_status, local_history = \
# await self.get_local_status_and_history(address, synced_history.getvalue())
# if local_status != remote_status:
# if local_history == remote_history:
# return True
# log.warning(
# "Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items",
# remote_status, len(remote_history), local_status, len(local_history)
# )
# log.warning("local: %s", local_history)
# log.warning("remote: %s", remote_history)
# self._known_addresses_out_of_sync.add(address)
# return False
# else:
# return True
#
# async def cache_transaction(self, tx_hash, remote_height, check_local=True):
# cache_item = self._tx_cache.get(tx_hash)
# if cache_item is None:
# cache_item = self._tx_cache[tx_hash] = TransactionCacheItem()
# elif cache_item.tx is not None and \
# cache_item.tx.height >= remote_height and \
# (cache_item.tx.is_verified or remote_height < 1):
# return cache_item.tx # cached tx is already up-to-date
#
# try:
# cache_item.pending_verifications += 1
# return await self._update_cache_item(cache_item, tx_hash, remote_height, check_local)
# finally:
# cache_item.pending_verifications -= 1
#
# async def _update_cache_item(self, cache_item, tx_hash, remote_height, check_local=True):
#
# async with cache_item.lock:
#
# tx = cache_item.tx
#
# if tx is None and check_local:
# # check local db
# tx = cache_item.tx = await self.db.get_transaction(tx_hash=tx_hash)
#
# merkle = None
# if tx is None:
# # fetch from network
# _raw, merkle = await self.network.retriable_call(
# self.network.get_transaction_and_merkle, tx_hash, remote_height
# )
# tx = Transaction(unhexlify(_raw), height=merkle.get('block_height'))
# cache_item.tx = tx # make sure it's saved before caching it
# await self.maybe_verify_transaction(tx, remote_height, merkle)
# return tx
#
# async def maybe_verify_transaction(self, tx, remote_height, merkle=None):
# tx.height = remote_height
# cached = self._tx_cache.get(tx.hash)
# if not cached:
# # cache txs looked up by transaction_show too
# cached = TransactionCacheItem()
# cached.tx = tx
# self._tx_cache[tx.hash] = cached
# if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1:
# # can't be tx.pending_verifications == 1 because we have to handle the transaction_show case
# if not merkle:
# merkle = await self.network.retriable_call(self.network.get_merkle, tx.hash, remote_height)
# merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
# header = await self.headers.get(remote_height)
# tx.position = merkle['pos']
# tx.is_verified = merkle_root == header['merkle_root']
#
# async def get_address_manager_for_address(self, address) -> Optional[AddressManager]:
# details = await self.db.get_address(address=address)
# for account in self.accounts:
# if account.id == details['account']:
# return account.address_managers[details['chain']]
# return None

View file

@ -8,7 +8,6 @@ from lbry.error import (
)
from lbry.blockchain.dewies import lbc_to_dewies
from lbry.event import EventController
from lbry.blockchain.transaction import Output, Transaction
log = logging.getLogger(__name__)
@ -51,17 +50,17 @@ class WalletServerPayer:
)
continue
tx = await Transaction.create(
[],
[Output.pay_pubkey_hash(amount, self.ledger.address_to_hash160(address))],
self.wallet.get_accounts_or_all(None),
self.wallet.get_account_or_default(None)
)
await self.ledger.broadcast(tx)
if self.analytics_manager:
await self.analytics_manager.send_credits_sent()
self._on_payment_controller.add(tx)
# tx = await Transaction.create(
# [],
# [Output.pay_pubkey_hash(amount, self.ledger.address_to_hash160(address))],
# self.wallet.get_accounts_or_all(None),
# self.wallet.get_account_or_default(None)
# )
#
# await self.ledger.broadcast(tx)
# if self.analytics_manager:
# await self.analytics_manager.send_credits_sent()
# self._on_payment_controller.add(tx)
async def start(self, ledger=None, wallet=None):
if lbc_to_dewies(self.max_fee) < 1:

View file

@ -1,15 +1,17 @@
# pylint: disable=arguments-differ
import os
import json
import zlib
import asyncio
import logging
from datetime import datetime
from typing import Awaitable, Callable, List, Tuple, Optional, Iterable, Union
from hashlib import sha256
from operator import attrgetter
from decimal import Decimal
from lbry.db import Database, SPENDABLE_TYPE_CODES
from lbry.db import Database, SPENDABLE_TYPE_CODES, Result
from lbry.blockchain.ledger import Ledger
from lbry.constants import COIN, NULL_HASH32
from lbry.blockchain.transaction import Transaction, Input, Output
@ -243,7 +245,7 @@ class Wallet:
accounts=funding_accounts,
txo_type__in=SPENDABLE_TYPE_CODES
)
for utxo in utxos[0]:
for utxo in utxos:
estimators.append(OutputEffectiveAmountEstimator(self.ledger, utxo))
return estimators
@ -349,6 +351,35 @@ class Wallet:
output = Output.pay_pubkey_hash(amount, self.ledger.address_to_hash160(address))
return await self.create_transaction([], [output], funding_accounts, change_account)
async def fund(self, from_account, to_account, amount=None, everything=False,
outputs=1, broadcast=False, **constraints):
assert self.ledger == to_account.ledger, 'Can only transfer between accounts of the same ledger.'
if everything:
utxos = await self.db.get_utxos(**constraints)
await self.db.reserve_outputs(utxos)
tx = await self.create_transaction(
inputs=[Input.spend(txo) for txo in utxos],
outputs=[],
funding_accounts=[from_account],
change_account=to_account
)
elif amount > 0:
to_address = await to_account.change.get_or_create_usable_address()
to_hash160 = to_account.ledger.address_to_hash160(to_address)
tx = await self.create_transaction(
inputs=[],
outputs=[
Output.pay_pubkey_hash(amount//outputs, to_hash160)
for _ in range(outputs)
],
funding_accounts=[from_account],
change_account=from_account
)
else:
raise ValueError('An amount is required.')
return tx
async def _report_state(self):
try:
for account in self.accounts:
@ -375,7 +406,7 @@ class Wallet:
async def verify_duplicate(self, name: str, allow_duplicate: bool):
if not allow_duplicate:
claims, _ = await self.claims.list(claim_name=name)
claims = await self.claims.list(claim_name=name)
if len(claims) > 0:
raise Exception(
f"You already have a claim published under the name '{name}'. "
@ -454,7 +485,7 @@ class BaseListManager:
def __init__(self, wallet: Wallet):
self.wallet = wallet
async def create(self, **kwargs) -> Transaction:
async def create(self, *args, **kwargs) -> Transaction:
raise NotImplementedError
async def delete(self, **constraints) -> Transaction:
@ -525,7 +556,7 @@ class ClaimListManager(BaseListManager):
[Input.spend(claim)], [], self.wallet._accounts, self.wallet._accounts[0]
)
async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]:
async def list(self, **constraints) -> Result[Output]:
return await self.wallet.db.get_claims(wallet=self.wallet, **constraints)
async def get(self, claim_id=None, claim_name=None, txid=None, nout=None) -> Output:
@ -537,7 +568,7 @@ class ClaimListManager(BaseListManager):
key, value, constraints = 'name', claim_name, {'claim_name': claim_name}
else:
raise ValueError(f"Couldn't find {self.name} because an {self.name}_id or name was not provided.")
claims, _ = await self.list(**constraints)
claims = await self.list(**constraints)
if len(claims) == 1:
return claims[0]
elif len(claims) > 1:
@ -626,7 +657,7 @@ class ChannelListManager(ClaimListManager):
return tx
async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]:
async def list(self, **constraints) -> Result[Output]:
return await self.wallet.db.get_channels(wallet=self.wallet, **constraints)
async def get_for_signing(self, channel_id=None, channel_name=None) -> Output:
@ -685,7 +716,7 @@ class StreamListManager(ClaimListManager):
return tx, file_stream
async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]:
async def list(self, **constraints) -> Result[Output]:
return await self.wallet.db.get_streams(wallet=self.wallet, **constraints)
@ -701,7 +732,7 @@ class CollectionListManager(ClaimListManager):
name, claim, amount, holding_address, funding_accounts, funding_accounts[0], channel
)
async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]:
async def list(self, **constraints) -> Result[Output]:
return await self.wallet.db.get_collections(wallet=self.wallet, **constraints)
@ -717,7 +748,7 @@ class SupportListManager(BaseListManager):
[], [support_output], funding_accounts, change_account
)
async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]:
async def list(self, **constraints) -> Result[Output]:
return await self.wallet.db.get_supports(**constraints)
async def get(self, **constraints) -> Output:
@ -741,9 +772,9 @@ class PurchaseListManager(BaseListManager):
def purchase(self, claim_id: str, amount: int, merchant_address: bytes,
funding_accounts: List['Account'], change_account: 'Account'):
payment = Output.pay_pubkey_hash(amount, self.ledger.address_to_hash160(merchant_address))
payment = Output.pay_pubkey_hash(amount, self.wallet.ledger.address_to_hash160(merchant_address))
data = Output.add_purchase_data(Purchase(claim_id))
return self.create_transaction(
return self.wallet.create_transaction(
[], [payment, data], funding_accounts, change_account
)
@ -752,8 +783,8 @@ class PurchaseListManager(BaseListManager):
override_max_key_fee=False):
fee = txo.claim.stream.fee
fee_amount = exchange.to_dewies(fee.currency, fee.amount)
if not override_max_key_fee and self.ledger.conf.max_key_fee:
max_fee = self.ledger.conf.max_key_fee
if not override_max_key_fee and self.wallet.ledger.conf.max_key_fee:
max_fee = self.wallet.ledger.conf.max_key_fee
max_fee_amount = exchange.to_dewies(max_fee['currency'], Decimal(max_fee['amount']))
if max_fee_amount and fee_amount > max_fee_amount:
error_fee = f"{dewies_to_lbc(fee_amount)} LBC"
@ -766,12 +797,12 @@ class PurchaseListManager(BaseListManager):
f"Purchase price of {error_fee} exceeds maximum "
f"configured price of {error_max_fee}."
)
fee_address = fee.address or txo.get_address(self.ledger)
fee_address = fee.address or txo.get_address(self.wallet.ledger)
return await self.purchase(
txo.claim_id, fee_amount, fee_address, accounts, accounts[0]
)
async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]:
async def list(self, **constraints) -> Result[Output]:
return await self.wallet.db.get_purchases(**constraints)
async def get(self, **constraints) -> Output:
@ -784,12 +815,12 @@ class PurchaseListManager(BaseListManager):
def txs_to_dict(txs, ledger):
history = []
for tx in txs: # pylint: disable=too-many-nested-blocks
ts = headers.estimated_timestamp(tx.height)
ts = ledger.headers.estimated_timestamp(tx.height)
item = {
'txid': tx.id,
'timestamp': ts,
'date': datetime.fromtimestamp(ts).isoformat(' ')[:-3] if tx.height > 0 else None,
'confirmations': (headers.height + 1) - tx.height if tx.height > 0 else 0,
'confirmations': (ledger.headers.height + 1) - tx.height if tx.height > 0 else 0,
'claim_info': [],
'update_info': [],
'support_info': [],
@ -807,7 +838,7 @@ def txs_to_dict(txs, ledger):
item['fee'] = '0.0'
for txo in tx.my_claim_outputs:
item['claim_info'].append({
'address': txo.get_address(self.ledger),
'address': txo.get_address(ledger),
'balance_delta': dewies_to_lbc(-txo.amount),
'amount': dewies_to_lbc(txo.amount),
'claim_id': txo.claim_id,
@ -827,7 +858,7 @@ def txs_to_dict(txs, ledger):
break
if previous is not None:
item['update_info'].append({
'address': txo.get_address(self),
'address': txo.get_address(ledger),
'balance_delta': dewies_to_lbc(previous.amount - txo.amount),
'amount': dewies_to_lbc(txo.amount),
'claim_id': txo.claim_id,
@ -837,7 +868,7 @@ def txs_to_dict(txs, ledger):
})
else: # someone sent us their claim
item['update_info'].append({
'address': txo.get_address(self),
'address': txo.get_address(ledger),
'balance_delta': dewies_to_lbc(0),
'amount': dewies_to_lbc(txo.amount),
'claim_id': txo.claim_id,
@ -847,7 +878,7 @@ def txs_to_dict(txs, ledger):
})
for txo in tx.my_support_outputs:
item['support_info'].append({
'address': txo.get_address(self.ledger),
'address': txo.get_address(ledger),
'balance_delta': dewies_to_lbc(txo.amount if not is_my_inputs else -txo.amount),
'amount': dewies_to_lbc(txo.amount),
'claim_id': txo.claim_id,
@ -859,7 +890,7 @@ def txs_to_dict(txs, ledger):
if is_my_inputs:
for txo in tx.other_support_outputs:
item['support_info'].append({
'address': txo.get_address(self.ledger),
'address': txo.get_address(ledger),
'balance_delta': dewies_to_lbc(-txo.amount),
'amount': dewies_to_lbc(txo.amount),
'claim_id': txo.claim_id,
@ -870,7 +901,7 @@ def txs_to_dict(txs, ledger):
})
for txo in tx.my_abandon_outputs:
item['abandon_info'].append({
'address': txo.get_address(self.ledger),
'address': txo.get_address(ledger),
'balance_delta': dewies_to_lbc(txo.amount),
'amount': dewies_to_lbc(txo.amount),
'claim_id': txo.claim_id,
@ -879,7 +910,7 @@ def txs_to_dict(txs, ledger):
})
for txo in tx.any_purchase_outputs:
item['purchase_info'].append({
'address': txo.get_address(self.ledger),
'address': txo.get_address(ledger),
'balance_delta': dewies_to_lbc(txo.amount if not is_my_inputs else -txo.amount),
'amount': dewies_to_lbc(txo.amount),
'claim_id': txo.purchased_claim_id,

View file

@ -11,11 +11,11 @@ ignore_missing_imports = True
[pylint]
jobs=8
ignore=words,server,rpc,schema,winpaths.py,migrator,undecorated.py
ignore=words,schema,migrator,extras,ui,api.py
max-parents=10
max-args=10
max-line-length=120
good-names=T,t,n,i,j,k,x,y,s,f,d,h,c,e,op,db,tx,io,cachedproperty,log,id,r,iv,ts,l,it,fp
good-names=T,t,n,i,j,k,x,y,s,f,d,h,c,e,op,db,tx,io,cachedproperty,log,id,r,iv,ts,l,it,fp,q,p
valid-metaclass-classmethod-first-arg=mcs
disable=
fixme,
@ -24,16 +24,20 @@ disable=
cyclic-import,
missing-docstring,
duplicate-code,
blacklisted-name,
expression-not-assigned,
inconsistent-return-statements,
trailing-comma-tuple,
too-few-public-methods,
too-many-lines,
too-many-locals,
too-many-branches,
too-many-ancestors,
too-many-arguments,
too-many-statements,
too-many-nested-blocks,
too-many-public-methods,
too-many-return-statements,
too-many-instance-attributes,
protected-access,
unused-argument

View file

@ -8,11 +8,7 @@ setenv =
HOME=/tmp
passenv = TEST_DB
commands =
pip install https://github.com/rogerbinns/apsw/releases/download/3.30.1-r1/apsw-3.30.1-r1.zip \
--global-option=fetch \
--global-option=--version --global-option=3.30.1 --global-option=--all \
--global-option=build --global-option=--enable --global-option=fts5
orchstr8 download
blockchain: coverage run -p --source={envsitepackagesdir}/lbry -m unittest discover -vv integration.blockchain {posargs}
blockchain: coverage run -p --source={envsitepackagesdir}/lbry -m unittest -vv integration.blockchain.test_claim_commands.ChannelCommands.test_create_channel_names {posargs}
#blockchain: coverage run -p --source={envsitepackagesdir}/lbry -m unittest discover -vv integration.blockchain {posargs}
datanetwork: coverage run -p --source={envsitepackagesdir}/lbry -m unittest discover -vv integration.datanetwork {posargs}
other: coverage run -p --source={envsitepackagesdir}/lbry -m unittest discover -vv integration.other {posargs}