This commit is contained in:
Lex Berezhny 2021-01-07 23:06:55 -05:00
parent c42b08b090
commit 51467c546b
6 changed files with 34 additions and 38 deletions

View file

@ -87,10 +87,10 @@ def generate_addresses_using_filters(best_height, allowed_gap, address_manager)
matchers = get_filter_matchers(best_height) matchers = get_filter_matchers(best_height)
with PersistingAddressIterator(*address_manager) as addresses: with PersistingAddressIterator(*address_manager) as addresses:
gap = 0 gap = 0
for address_hash, n, is_new in addresses: for address_hash, n, is_new in addresses: # pylint: disable=unused-variable
gap += 1 gap += 1
address_bytes = bytearray(address_hash) address_bytes = bytearray(address_hash)
for granularity, height, matcher, filter_range in matchers: for matcher, filter_range in matchers:
if matcher.Match(address_bytes): if matcher.Match(address_bytes):
gap = 0 gap = 0
if filter_range not in need and filter_range not in have: if filter_range not in need and filter_range not in have:
@ -105,8 +105,8 @@ def generate_addresses_using_filters(best_height, allowed_gap, address_manager)
def get_missing_sub_filters_for_addresses(granularity, address_manager): def get_missing_sub_filters_for_addresses(granularity, address_manager):
need = set() need = set()
for height, matcher, filter_range in get_filter_matchers_at_granularity(granularity): for matcher, filter_range in get_filter_matchers_at_granularity(granularity):
for address_hash, n, is_new in DatabaseAddressIterator(*address_manager): for address_hash, _, _ in DatabaseAddressIterator(*address_manager):
address_bytes = bytearray(address_hash) address_bytes = bytearray(address_hash)
if matcher.Match(address_bytes) and not has_filter_range(*filter_range): if matcher.Match(address_bytes) and not has_filter_range(*filter_range):
need.add(filter_range) need.add(filter_range)
@ -117,7 +117,7 @@ def get_missing_sub_filters_for_addresses(granularity, address_manager):
def get_missing_tx_for_addresses(address_manager): def get_missing_tx_for_addresses(address_manager):
need = set() need = set()
for tx_hash, matcher in get_tx_matchers_for_missing_txs(): for tx_hash, matcher in get_tx_matchers_for_missing_txs():
for address_hash, n, is_new in DatabaseAddressIterator(*address_manager): for address_hash, _, _ in DatabaseAddressIterator(*address_manager):
address_bytes = bytearray(address_hash) address_bytes = bytearray(address_hash)
if matcher.Match(address_bytes): if matcher.Match(address_bytes):
need.add(tx_hash) need.add(tx_hash)

View file

@ -113,7 +113,7 @@ def get_missing_required_filters(height) -> Set[Tuple[int, int, int]]:
return missing_filters return missing_filters
def get_filter_matchers(height) -> List[Tuple[int, int, PyBIP158]]: def get_filter_matchers(height) -> List[Tuple[PyBIP158, Tuple[int, int, int]]]:
conditions = [] conditions = []
for granularity, (start, end) in get_minimal_required_filter_ranges(height).items(): for granularity, (start, end) in get_minimal_required_filter_ranges(height).items():
conditions.append( conditions.append(
@ -127,20 +127,18 @@ def get_filter_matchers(height) -> List[Tuple[int, int, PyBIP158]]:
.order_by(BlockFilter.c.height.desc()) .order_by(BlockFilter.c.height.desc())
) )
return [( return [(
bf["factor"], bf["height"],
get_address_filter(bf["address_filter"]), get_address_filter(bf["address_filter"]),
get_sub_filter_range(bf["factor"], bf["height"]) get_sub_filter_range(bf["factor"], bf["height"])
) for bf in context().fetchall(query)] ) for bf in context().fetchall(query)]
def get_filter_matchers_at_granularity(granularity) -> List[Tuple[int, PyBIP158, Tuple]]: def get_filter_matchers_at_granularity(granularity) -> List[Tuple[PyBIP158, Tuple]]:
query = ( query = (
select(BlockFilter.c.height, BlockFilter.c.address_filter) select(BlockFilter.c.height, BlockFilter.c.address_filter)
.where(BlockFilter.c.factor == granularity) .where(BlockFilter.c.factor == granularity)
.order_by(BlockFilter.c.height.desc()) .order_by(BlockFilter.c.height.desc())
) )
return [( return [(
bf["height"],
get_address_filter(bf["address_filter"]), get_address_filter(bf["address_filter"]),
get_sub_filter_range(granularity, bf["height"]) get_sub_filter_range(granularity, bf["height"])
) for bf in context().fetchall(query)] ) for bf in context().fetchall(query)]

View file

@ -5,6 +5,7 @@ from typing import Tuple, List, Optional, Union
from sqlalchemy import union, func, text, between, distinct, case, false from sqlalchemy import union, func, text, between, distinct, case, false
from sqlalchemy.future import select, Select from sqlalchemy.future import select, Select
from lbry.constants import INVALIDATED_SIGNATURE_GRACE_PERIOD
from ...blockchain.transaction import ( from ...blockchain.transaction import (
Transaction, Output, OutputScript, TXRefImmutable Transaction, Output, OutputScript, TXRefImmutable
) )
@ -15,7 +16,6 @@ from ..tables import (
from ..utils import query, in_account_ids from ..utils import query, in_account_ids
from ..query_context import context from ..query_context import context
from ..constants import TXO_TYPES, CLAIM_TYPE_CODES, MAX_QUERY_VARIABLES from ..constants import TXO_TYPES, CLAIM_TYPE_CODES, MAX_QUERY_VARIABLES
from lbry.constants import INVALIDATED_SIGNATURE_GRACE_PERIOD
log = logging.getLogger(__name__) log = logging.getLogger(__name__)

View file

@ -162,22 +162,22 @@ class FilterManager:
working_branch = double_sha256(combined) working_branch = double_sha256(combined)
return hexlify(working_branch[::-1]) return hexlify(working_branch[::-1])
async def maybe_verify_transaction(self, tx, remote_height, merkle=None): # async def maybe_verify_transaction(self, tx, remote_height, merkle=None):
tx.height = remote_height # tx.height = remote_height
cached = self._tx_cache.get(tx.hash) # cached = self._tx_cache.get(tx.hash)
if not cached: # if not cached:
# cache txs looked up by transaction_show too # # cache txs looked up by transaction_show too
cached = TransactionCacheItem() # cached = TransactionCacheItem()
cached.tx = tx # cached.tx = tx
self._tx_cache[tx.hash] = cached # self._tx_cache[tx.hash] = cached
if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1: # if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1:
# can't be tx.pending_verifications == 1 because we have to handle the transaction_show case # # can't be tx.pending_verifications == 1 because we have to handle the transaction_show case
if not merkle: # if not merkle:
merkle = await self.network.retriable_call(self.network.get_merkle, tx.hash, remote_height) # merkle = await self.network.retriable_call(self.network.get_merkle, tx.hash, remote_height)
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) # merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
header = await self.headers.get(remote_height) # header = await self.headers.get(remote_height)
tx.position = merkle['pos'] # tx.position = merkle['pos']
tx.is_verified = merkle_root == header['merkle_root'] # tx.is_verified = merkle_root == header['merkle_root']
class BlockHeaderManager: class BlockHeaderManager:

View file

@ -15,7 +15,6 @@ import multiprocessing as mp
from unittest.case import _Outcome from unittest.case import _Outcome
from typing import Optional, List, Union, Tuple from typing import Optional, List, Union, Tuple
from binascii import unhexlify, hexlify from binascii import unhexlify, hexlify
from distutils.dir_util import remove_tree
import ecdsa import ecdsa
@ -32,7 +31,6 @@ from lbry.constants import COIN, CENT, NULL_HASH32
from lbry.service import API, Daemon, Service, FullNode, FullEndpoint, LightClient, jsonrpc_dumps_pretty from lbry.service import API, Daemon, Service, FullNode, FullEndpoint, LightClient, jsonrpc_dumps_pretty
from lbry.conf import Config from lbry.conf import Config
from lbry.console import Console from lbry.console import Console
from lbry.wallet import Wallet, Account
from lbry.schema.claim import Claim from lbry.schema.claim import Claim
from lbry.service.exchange_rate_manager import ( from lbry.service.exchange_rate_manager import (
@ -594,22 +592,22 @@ class CommandTestCase(IntegrationTestCase):
await self.generate(5) await self.generate(5)
def broadcast(self, tx): def broadcast(self, tx):
return self.ledger.broadcast(tx) return self.service.broadcast(tx)
async def on_header(self, height): async def on_header(self, height):
if self.ledger.headers.height < height: if self.service.headers.height < height:
await self.ledger.on_header.where( await self.service.on_header.where(
lambda e: e.height == height lambda e: e.height == height
) )
return True return True
def on_transaction_id(self, txid, ledger=None): def on_transaction_id(self, txid, ledger=None):
return (ledger or self.ledger).on_transaction.where( return (ledger or self.service).on_transaction.where(
lambda e: e.tx.id == txid lambda e: e.tx.id == txid
) )
def on_transaction_hash(self, tx_hash, ledger=None): def on_transaction_hash(self, tx_hash, ledger=None):
return (ledger or self.ledger).on_transaction.where( return (ledger or self.service).on_transaction.where(
lambda e: e.tx.hash == tx_hash lambda e: e.tx.hash == tx_hash
) )
@ -617,12 +615,12 @@ class CommandTestCase(IntegrationTestCase):
await self.service.wait(Transaction(unhexlify(tx['hex']))) await self.service.wait(Transaction(unhexlify(tx['hex'])))
def on_address_update(self, address): def on_address_update(self, address):
return self.ledger.on_transaction.where( return self.service.on_transaction.where(
lambda e: e.address == address lambda e: e.address == address
) )
def on_transaction_address(self, tx, address): def on_transaction_address(self, tx, address):
return self.ledger.on_transaction.where( return self.service.on_transaction.where(
lambda e: e.tx.id == tx.id and e.address == address lambda e: e.tx.id == tx.id and e.address == address
) )
@ -807,9 +805,9 @@ class CommandTestCase(IntegrationTestCase):
async def txo_spend(self, *args, confirm=True, **kwargs): async def txo_spend(self, *args, confirm=True, **kwargs):
txs = await self.api.txo_spend(*args, **kwargs) txs = await self.api.txo_spend(*args, **kwargs)
if confirm: if confirm:
await asyncio.wait([self.ledger.wait(tx) for tx in txs]) await asyncio.wait([self.service.wait(tx) for tx in txs])
await self.generate(1) await self.generate(1)
await asyncio.wait([self.ledger.wait(tx, self.block_expected) for tx in txs]) await asyncio.wait([self.service.wait(tx, self.block_expected) for tx in txs])
return self.sout(txs) return self.sout(txs)
async def resolve(self, uri, **kwargs): async def resolve(self, uri, **kwargs):

View file

@ -97,7 +97,7 @@ class TestAddressGenerationAndTXSync(UnitDBTestCase):
)] )]
def test_generator_persisting(self): def test_generator_persisting(self):
expected = [self.receiving_pubkey.child(n).addresses for n in range(30)] expected = [self.receiving_pubkey.child(n).address for n in range(30)]
self.assertEqual([], self.get_ordered_addresses()) self.assertEqual([], self.get_ordered_addresses())
self.generate(5, 0) self.generate(5, 0)
self.assertEqual(expected[:6], self.get_ordered_addresses()) self.assertEqual(expected[:6], self.get_ordered_addresses())