From 2260608bb609a69f187a978a04f5e04f44bf48bf Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Sat, 30 Mar 2019 19:40:01 -0400 Subject: [PATCH] working sql claimtrie --- lbrynet/extras/daemon/Daemon.py | 55 +-- lbrynet/schema/page.py | 71 +++ lbrynet/schema/types/v2/page_pb2.py | 273 +++++++++++ lbrynet/testcase.py | 8 + lbrynet/wallet/ledger.py | 24 +- lbrynet/wallet/manager.py | 6 - lbrynet/wallet/network.py | 15 +- lbrynet/wallet/server/block_processor.py | 142 +----- lbrynet/wallet/server/db.py | 598 +++++++++++++++++------ lbrynet/wallet/server/model.py | 15 - lbrynet/wallet/server/session.py | 72 +-- lbrynet/wallet/transaction.py | 7 +- tests/integration/test_claim_commands.py | 17 +- tests/integration/test_wallet_server.py | 25 + tests/unit/wallet/server/test_sqldb.py | 169 +++++++ 15 files changed, 1057 insertions(+), 440 deletions(-) create mode 100644 lbrynet/schema/page.py create mode 100644 lbrynet/schema/types/v2/page_pb2.py delete mode 100644 lbrynet/wallet/server/model.py create mode 100644 tests/integration/test_wallet_server.py create mode 100644 tests/unit/wallet/server/test_sqldb.py diff --git a/lbrynet/extras/daemon/Daemon.py b/lbrynet/extras/daemon/Daemon.py index a7dbde6da..c50455fbf 100644 --- a/lbrynet/extras/daemon/Daemon.py +++ b/lbrynet/extras/daemon/Daemon.py @@ -1663,9 +1663,7 @@ class Daemon(metaclass=JSONRPCServerType): ) @requires(WALLET_COMPONENT) - async def jsonrpc_claim_search( - self, name=None, claim_id=None, txid=None, nout=None, - channel_id=None, channel_name=None, winning=False, page=1, page_size=10): + async def jsonrpc_claim_search(self, **kwargs): """ Search for stream and channel claims on the blockchain. @@ -1673,7 +1671,7 @@ class Daemon(metaclass=JSONRPCServerType): Usage: claim_search [ | --name=] [--claim_id=] [--txid= --nout=] - [--channel_id=] [--channel_name=] [--winning] [--page=] + [--channel_id=] [--channel_name=] [--is_winning] [--page=] [--page_size=] Options: @@ -1683,45 +1681,19 @@ class Daemon(metaclass=JSONRPCServerType): --nout= : (str) find a claim with this txid:nout --channel_id= : (str) limit search to specific channel claim id (returns stream claims) --channel_name= : (str) limit search to specific channel name (returns stream claims) - --winning : (bool) limit to winning claims + --is_winning : (bool) limit to winning claims --page= : (int) page to return during paginating --page_size= : (int) number of items on page during pagination Returns: {Paginated[Output]} """ - claims = [] - if name is not None: - claims = await self.ledger.network.get_claims_for_name(name) - elif claim_id is not None: - claim = await self.wallet_manager.get_claim_by_claim_id(claim_id) - if claim and claim != 'claim not found': - claims = {'claims': [claim]} - elif txid is not None and nout is not None: - claim = await self.wallet_manager.get_claim_by_outpoint(txid, int(nout)) - if claim and claim != 'claim not found': - claims = {'claims': [claim]} - elif channel_id is not None or channel_name is not None: - channel_url = f"{channel_name}{('#' + str(channel_id)) if channel_id else ''}" if channel_name else None - if channel_id and not channel_name: - claim = await self.wallet_manager.get_claim_by_claim_id(channel_id) - if claim and claim != 'claim not found': - channel_url = f"{claim['name']}#{claim['claim_id']}" - if channel_url: - resolve = await self.resolve(channel_url, page=page, page_size=page_size) - resolve = resolve.get(channel_url, {}) - claims = resolve.get('claims_in_channel', []) or [] - total_pages = 0 - if claims: - total_pages = int((resolve['total_claims'] + (page_size-1)) / page_size) - #sort_claim_results(claims) - return {"items": claims, "total_pages": total_pages, "page": page, "page_size": page_size} - else: - raise Exception("Must specify either name, claim_id, or txid:nout.") - if claims: - resolutions = await self.resolve(*(f"{claim['name']}#{claim['claim_id']}" for claim in claims['claims'])) - claims = [value.get('claim', value.get('certificate')) for value in resolutions.values()] - sort_claim_results(claims) - return {"items": claims, "total_pages": 1, "page": 1, "page_size": len(claims)} + page_num, page_size = abs(kwargs.pop('page', 1)), min(abs(kwargs.pop('page_size', 10)), 50) + kwargs.update({'offset': page_size * (page_num-1), 'limit': page_size}) + page = await self.ledger.claim_search(**kwargs) + return { + "items": page.txos, "page": page_num, "page_size": page_size, + "total_pages": int((page.total + (page_size-1)) / page_size) + } CHANNEL_DOC = """ Create, update, abandon and list your channel claims. @@ -2641,14 +2613,13 @@ class Daemon(metaclass=JSONRPCServerType): """ account = self.get_account_or_default(account_id) amount = self.get_dewies_or_error("amount", amount) - claim = await account.ledger.get_claim_by_claim_id(claim_id) - claim_name = claim['name'] - claim_address = claim['address'] + claim = await self.ledger.get_claim_by_claim_id(claim_id) + claim_address = claim.get_address(self.ledger) if not tip: claim_address = await account.receiving.get_or_create_usable_address() tx = await Transaction.support( - claim_name, claim_id, amount, claim_address, [account], account + claim.claim_name, claim_id, amount, claim_address, [account], account ) if not preview: diff --git a/lbrynet/schema/page.py b/lbrynet/schema/page.py new file mode 100644 index 000000000..c8262f11e --- /dev/null +++ b/lbrynet/schema/page.py @@ -0,0 +1,71 @@ +import base64 +import struct +from typing import List + +from lbrynet.schema.types.v2.page_pb2 import Page as PageMessage +from lbrynet.wallet.transaction import Transaction, Output + + +class Page: + + __slots__ = 'txs', 'txos', 'offset', 'total' + + def __init__(self, txs, txos, offset, total): + self.txs: List[Transaction] = txs + self.txos: List[Output] = txos + self.offset = offset + self.total = total + + @classmethod + def from_base64(cls, data: str) -> 'Page': + return cls.from_bytes(base64.b64decode(data)) + + @classmethod + def from_bytes(cls, data: bytes) -> 'Page': + page_message = PageMessage() + page_message.ParseFromString(data) + tx_map, txo_list = {}, [] + for tx_message in page_message.txs: + tx = Transaction(tx_message.raw, height=tx_message.height, position=tx_message.position) + tx_map[tx.hash] = tx + for txo_message in page_message.txos: + output = tx_map[txo_message.tx_hash].outputs[txo_message.nout] + if txo_message.WhichOneof('meta') == 'claim': + claim = txo_message.claim + output.meta = { + 'is_winning': claim.is_winning, + 'effective_amount': claim.effective_amount, + 'trending_amount': claim.trending_amount, + } + if claim.HasField('channel'): + output.channel = tx_map[claim.channel.tx_hash].outputs[claim.channel.nout] + txo_list.append(output) + return cls(list(tx_map.values()), txo_list, page_message.offset, page_message.total) + + @classmethod + def to_base64(cls, tx_rows, txo_rows, offset, total) -> str: + return base64.b64encode(cls.to_bytes(tx_rows, txo_rows, offset, total)).decode() + + @classmethod + def to_bytes(cls, tx_rows, txo_rows, offset, total) -> bytes: + page = PageMessage() + page.total = total + page.offset = offset + for tx in tx_rows: + tx_message = page.txs.add() + tx_message.raw = tx['raw'] + tx_message.height = tx['height'] + tx_message.position = tx['position'] + for txo in txo_rows: + txo_message = page.txos.add() + txo_message.tx_hash = txo['txo_hash'][:32] + txo_message.nout, = struct.unpack(' Page: + return Page.from_base64(await self.network.claim_search(**kwargs)) - async def get_claim_by_outpoint(self, txid, nout): - claims = (await self.network.get_claims_in_tx(txid)) or [] - for claim in claims: - if claim['nout'] == nout: - return await self.resolver.parse_and_validate_claim_result(claim) - return 'claim not found' + async def get_claim_by_claim_id(self, claim_id) -> Optional[Output]: + page = await self.claim_search(claim_id=claim_id) + if page.txos: + return page.txos[0] + + async def get_claim_by_outpoint(self, txid, nout) -> Optional[Output]: + page = await self.claim_search(txid=txid, nout=nout) + if page.txos: + return page.txos[0] async def start(self): await super().start() diff --git a/lbrynet/wallet/manager.py b/lbrynet/wallet/manager.py index 8a9f3c2e9..bd8070598 100644 --- a/lbrynet/wallet/manager.py +++ b/lbrynet/wallet/manager.py @@ -307,9 +307,3 @@ class LbryWalletManager(BaseWalletManager): def save(self): for wallet in self.wallets: wallet.save() - - def get_claim_by_claim_id(self, claim_id): - return self.ledger.get_claim_by_claim_id(claim_id) - - def get_claim_by_outpoint(self, txid, nout): - return self.ledger.get_claim_by_outpoint(txid, nout) diff --git a/lbrynet/wallet/network.py b/lbrynet/wallet/network.py index 85aa46a60..af71a87e8 100644 --- a/lbrynet/wallet/network.py +++ b/lbrynet/wallet/network.py @@ -4,20 +4,23 @@ from torba.client.basenetwork import BaseNetwork class Network(BaseNetwork): def get_server_height(self): - return self.rpc('blockchain.block.get_server_height') + return self.rpc('blockchain.block.get_server_height', []) def get_values_for_uris(self, block_hash, *uris): - return self.rpc('blockchain.claimtrie.getvaluesforuris', block_hash, *uris) + return self.rpc('blockchain.claimtrie.getvaluesforuris', [block_hash, *uris]) def get_claims_by_ids(self, *claim_ids): - return self.rpc('blockchain.claimtrie.getclaimsbyids', *claim_ids) + return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids) def get_claims_in_tx(self, txid): - return self.rpc('blockchain.claimtrie.getclaimsintx', txid) + return self.rpc('blockchain.claimtrie.getclaimsintx', [txid]) def get_claims_for_name(self, name): - return self.rpc('blockchain.claimtrie.getclaimsforname', name) + return self.rpc('blockchain.claimtrie.getclaimsforname', [name]) def get_transaction_height(self, txid): # 1.0 protocol specific workaround. Newer protocol should use get_transaction with verbose True - return self.rpc('blockchain.transaction.get_height', txid) + return self.rpc('blockchain.transaction.get_height', [txid]) + + def claim_search(self, **kwargs): + return self.rpc('blockchain.claimtrie.search', kwargs) diff --git a/lbrynet/wallet/server/block_processor.py b/lbrynet/wallet/server/block_processor.py index 7259dd258..3f13528ec 100644 --- a/lbrynet/wallet/server/block_processor.py +++ b/lbrynet/wallet/server/block_processor.py @@ -1,14 +1,7 @@ -import struct - -import msgpack - -from lbrynet.wallet.transaction import Transaction, Output -from torba.server.hash import hash_to_hex_str - from torba.server.block_processor import BlockProcessor -from lbrynet.schema.claim import Claim -from lbrynet.wallet.server.model import ClaimInfo +from lbrynet.schema.claim import Claim +from lbrynet.wallet.server.db import SQLDB class LBRYBlockProcessor(BlockProcessor): @@ -17,126 +10,25 @@ class LBRYBlockProcessor(BlockProcessor): super().__init__(*args, **kwargs) if self.env.coin.NET == "regtest": self.prefetcher.polling_delay = 0.5 - self.should_validate_signatures = self.env.boolean('VALIDATE_CLAIM_SIGNATURES', False) - self.logger.info("LbryumX Block Processor - Validating signatures: {}".format(self.should_validate_signatures)) + self.logger.info(f"LbryumX Block Processor - Validating signatures: {self.should_validate_signatures}") + self.sql: SQLDB = self.db.sql def advance_blocks(self, blocks): - # save height, advance blocks as usual, then hook our claim tx processing - height = self.height + 1 - super().advance_blocks(blocks) - pending_undo = [] - for index, block in enumerate(blocks): - undo = self.advance_claim_txs(block.transactions, height + index) - pending_undo.append((height+index, undo,)) - self.db.write_undo(pending_undo) + self.sql.begin() + try: + super().advance_blocks(blocks) + self.sql.delete_dereferenced_transactions() + except: + self.logger.exception(f'Error while advancing transaction in new block.') + raise + finally: + self.sql.commit() - def advance_claim_txs(self, txs, height): - # TODO: generate claim undo info! - undo_info = [] - add_undo = undo_info.append - update_inputs = set() - for etx, txid in txs: - update_inputs.clear() - tx = Transaction(etx.serialize()) - for index, output in enumerate(tx.outputs): - if not output.is_claim: - continue - if output.script.is_claim_name: - add_undo(self.advance_claim_name_transaction(output, height, txid, index)) - elif output.script.is_update_claim: - update_input = self.db.get_update_input(output.claim_hash, tx.inputs) - if update_input: - update_inputs.add(update_input) - add_undo(self.advance_update_claim(output, height, txid, index)) - else: - info = (hash_to_hex_str(txid), output.claim_id,) - self.logger.error("REJECTED: {} updating {}".format(*info)) - for txin in tx.inputs: - if txin not in update_inputs: - abandoned_claim_id = self.db.abandon_spent(txin.txo_ref.tx_ref.hash, txin.txo_ref.position) - if abandoned_claim_id: - add_undo((abandoned_claim_id, self.db.get_claim_info(abandoned_claim_id))) - return undo_info - - def advance_update_claim(self, output: Output, height, txid, nout): - claim_id = output.claim_hash - claim_info = self.claim_info_from_output(output, txid, nout, height) - old_claim_info = self.db.get_claim_info(claim_id) - self.db.put_claim_id_for_outpoint(old_claim_info.txid, old_claim_info.nout, None) - if old_claim_info.cert_id: - self.db.remove_claim_from_certificate_claims(old_claim_info.cert_id, claim_id) - if claim_info.cert_id: - self.db.put_claim_id_signed_by_cert_id(claim_info.cert_id, claim_id) - self.db.put_claim_info(claim_id, claim_info) - self.db.put_claim_id_for_outpoint(txid, nout, claim_id) - return claim_id, old_claim_info - - def advance_claim_name_transaction(self, output: Output, height, txid, nout): - claim_id = output.claim_hash - claim_info = self.claim_info_from_output(output, txid, nout, height) - if claim_info.cert_id: - self.db.put_claim_id_signed_by_cert_id(claim_info.cert_id, claim_id) - self.db.put_claim_info(claim_id, claim_info) - self.db.put_claim_id_for_outpoint(txid, nout, claim_id) - return claim_id, None - - def backup_from_undo_info(self, claim_id, undo_claim_info): - """ - Undo information holds a claim state **before** a transaction changes it - There are 4 possibilities when processing it, of which only 3 are valid ones: - 1. the claim is known and the undo info has info, it was an update - 2. the claim is known and the undo info doesn't hold any info, it was claimed - 3. the claim in unknown and the undo info has info, it was abandoned - 4. the claim is unknown and the undo info does't hold info, error! - """ - - undo_claim_info = ClaimInfo(*undo_claim_info) if undo_claim_info else None - current_claim_info = self.db.get_claim_info(claim_id) - if current_claim_info and undo_claim_info: - # update, remove current claim - self.db.remove_claim_id_for_outpoint(current_claim_info.txid, current_claim_info.nout) - if current_claim_info.cert_id: - self.db.remove_claim_from_certificate_claims(current_claim_info.cert_id, claim_id) - elif current_claim_info and not undo_claim_info: - # claim, abandon it - self.db.abandon_spent(current_claim_info.txid, current_claim_info.nout) - elif not current_claim_info and undo_claim_info: - # abandon, reclaim it (happens below) - pass - else: - # should never happen, unless the database got into an inconsistent state - raise Exception("Unexpected situation occurred on backup, this means the database is inconsistent. " - "Please report. Resetting the data folder (reindex) solves it for now.") - if undo_claim_info: - self.db.put_claim_info(claim_id, undo_claim_info) - if undo_claim_info.cert_id: - cert_id = self._checksig(undo_claim_info.value, undo_claim_info.address) - self.db.put_claim_id_signed_by_cert_id(cert_id, claim_id) - self.db.put_claim_id_for_outpoint(undo_claim_info.txid, undo_claim_info.nout, claim_id) - - def backup_txs(self, txs): - self.logger.info("Reorg at height {} with {} transactions.".format(self.height, len(txs))) - undo_info = msgpack.loads(self.db.claim_undo_db.get(struct.pack(">I", self.height)), use_list=False) - for claim_id, undo_claim_info in reversed(undo_info): - self.backup_from_undo_info(claim_id, undo_claim_info) - return super().backup_txs(txs) - - def backup_blocks(self, raw_blocks): - self.db.batched_flush_claims() - super().backup_blocks(raw_blocks=raw_blocks) - self.db.batched_flush_claims() - - async def flush(self, flush_utxos): - self.db.batched_flush_claims() - return await super().flush(flush_utxos) - - def claim_info_from_output(self, output: Output, txid, nout, height): - address = self.coin.address_from_script(output.script.source) - name, value, cert_id = output.script.values['claim_name'], output.script.values['claim'], None - assert txid and address - cert_id = self._checksig(value, address) - return ClaimInfo(name, value, txid, nout, output.amount, address, height, cert_id) + def advance_txs(self, height, txs): + undo = super().advance_txs(height, txs) + self.sql.advance_txs(height, txs) + return undo def _checksig(self, value, address): try: diff --git a/lbrynet/wallet/server/db.py b/lbrynet/wallet/server/db.py index c7c1b6062..17ccedf1e 100644 --- a/lbrynet/wallet/server/db.py +++ b/lbrynet/wallet/server/db.py @@ -1,176 +1,450 @@ -import msgpack +import sqlite3 import struct - -import time -from torba.server.hash import hash_to_hex_str +from typing import Union, Tuple, Set, List +from binascii import unhexlify from torba.server.db import DB +from torba.server.util import class_logger +from torba.client.basedatabase import query, constraints_to_sql +from google.protobuf.message import DecodeError -from lbrynet.wallet.server.model import ClaimInfo +from lbrynet.wallet.transaction import Transaction, Output + + +class SQLDB: + + TRENDING_BLOCKS = 300 # number of blocks over which to calculate trending + + PRAGMAS = """ + pragma journal_mode=WAL; + """ + + CREATE_TX_TABLE = """ + create table if not exists tx ( + tx_hash bytes primary key, + raw bytes not null, + position integer not null, + height integer not null + ); + """ + + CREATE_CLAIM_TABLE = """ + create table if not exists claim ( + claim_hash bytes primary key, + tx_hash bytes not null, + txo_hash bytes not null, + height integer not null, + activation_height integer, + amount integer not null, + effective_amount integer not null default 0, + trending_amount integer not null default 0, + claim_name text not null, + channel_hash bytes + ); + create index if not exists claim_tx_hash_idx on claim (tx_hash); + create index if not exists claim_txo_hash_idx on claim (txo_hash); + create index if not exists claim_activation_height_idx on claim (activation_height); + create index if not exists claim_channel_hash_idx on claim (channel_hash); + create index if not exists claim_claim_name_idx on claim (claim_name); + """ + + CREATE_SUPPORT_TABLE = """ + create table if not exists support ( + txo_hash bytes primary key, + tx_hash bytes not null, + claim_hash bytes not null, + position integer not null, + height integer not null, + amount integer not null, + is_comment bool not null default false + ); + create index if not exists support_tx_hash_idx on support (tx_hash); + create index if not exists support_claim_hash_idx on support (claim_hash, height); + """ + + CREATE_TAG_TABLE = """ + create table if not exists tag ( + tag text not null, + txo_hash bytes not null, + height integer not null + ); + create index if not exists tag_tag_idx on tag (tag); + create index if not exists tag_txo_hash_idx on tag (txo_hash); + create index if not exists tag_height_idx on tag (height); + """ + + CREATE_CLAIMTRIE_TABLE = """ + create table if not exists claimtrie ( + claim_name text primary key, + claim_hash bytes not null, + last_take_over_height integer not null + ); + create index if not exists claimtrie_claim_hash_idx on claimtrie (claim_hash); + """ + + CREATE_TABLES_QUERY = ( + PRAGMAS + + CREATE_TX_TABLE + + CREATE_CLAIM_TABLE + + CREATE_SUPPORT_TABLE + + CREATE_CLAIMTRIE_TABLE + + CREATE_TAG_TABLE + ) + + def __init__(self, path): + self._db_path = path + self.db = None + self.logger = class_logger(__name__, self.__class__.__name__) + + def open(self): + self.db = sqlite3.connect(self._db_path, isolation_level=None, check_same_thread=False) + self.db.row_factory = sqlite3.Row + self.db.executescript(self.CREATE_TABLES_QUERY) + + def close(self): + self.db.close() + + @staticmethod + def _insert_sql(table: str, data: dict) -> Tuple[str, list]: + columns, values = [], [] + for column, value in data.items(): + columns.append(column) + values.append(value) + sql = ( + f"INSERT INTO {table} ({', '.join(columns)}) " + f"VALUES ({', '.join(['?'] * len(values))})" + ) + return sql, values + + @staticmethod + def _update_sql(table: str, data: dict, where: str, + constraints: Union[list, tuple]) -> Tuple[str, list]: + columns, values = [], [] + for column, value in data.items(): + columns.append("{} = ?".format(column)) + values.append(value) + values.extend(constraints) + return f"UPDATE {table} SET {', '.join(columns)} WHERE {where}", values + + @staticmethod + def _delete_sql(table: str, constraints: dict) -> Tuple[str, dict]: + where, values = constraints_to_sql(constraints) + return f"DELETE FROM {table} WHERE {where}", values + + def execute(self, *args): + return self.db.execute(*args) + + def begin(self): + self.execute('begin;') + + def commit(self): + self.execute('commit;') + + def insert_txs(self, txs: Set[Transaction]): + if txs: + self.db.executemany( + "INSERT INTO tx (tx_hash, raw, position, height) VALUES (?, ?, ?, ?)", + [(sqlite3.Binary(tx.hash), sqlite3.Binary(tx.raw), tx.position, tx.height) for tx in txs] + ) + + def _upsertable_claims(self, txos: Set[Output]): + claims, tags = [], [] + for txo in txos: + tx = txo.tx_ref.tx + try: + assert txo.claim_name + except (AssertionError, UnicodeDecodeError): + self.logger.exception(f"Could not decode claim name for {tx.id}:{txo.position}.") + continue + try: + claim = txo.claim + if claim.is_channel: + metadata = claim.channel + else: + metadata = claim.stream + except DecodeError: + self.logger.exception(f"Could not parse claim protobuf for {tx.id}:{txo.position}.") + continue + txo_hash = sqlite3.Binary(txo.ref.hash) + channel_hash = sqlite3.Binary(claim.signing_channel_hash) if claim.signing_channel_hash else None + claims.append({ + 'claim_hash': sqlite3.Binary(txo.claim_hash), + 'tx_hash': sqlite3.Binary(tx.hash), + 'txo_hash': txo_hash, + 'channel_hash': channel_hash, + 'amount': txo.amount, + 'claim_name': txo.claim_name, + 'height': tx.height + }) + for tag in metadata.tags: + tags.append((tag, txo_hash, tx.height)) + if tags: + self.db.executemany( + "INSERT INTO tag (tag, txo_hash, height) VALUES (?, ?, ?)", tags + ) + return claims + + def insert_claims(self, txos: Set[Output]): + claims = self._upsertable_claims(txos) + if claims: + self.db.executemany( + "INSERT INTO claim (claim_hash, tx_hash, txo_hash, channel_hash, amount, claim_name, height) " + "VALUES (:claim_hash, :tx_hash, :txo_hash, :channel_hash, :amount, :claim_name, :height) ", + claims + ) + + def update_claims(self, txos: Set[Output]): + claims = self._upsertable_claims(txos) + if claims: + self.db.executemany( + "UPDATE claim SET " + " tx_hash=:tx_hash, txo_hash=:txo_hash, channel_hash=:channel_hash, " + " amount=:amount, height=:height " + "WHERE claim_hash=:claim_hash;", + claims + ) + + def clear_claim_metadata(self, txo_hashes: Set[bytes]): + """ Deletes metadata associated with claim in case of an update or an abandon. """ + if txo_hashes: + binary_txo_hashes = [sqlite3.Binary(txo_hash) for txo_hash in txo_hashes] + for table in ('tag',): # 'language', 'location', etc + self.execute(*self._delete_sql(table, {'txo_hash__in': binary_txo_hashes})) + + def abandon_claims(self, claim_hashes: Set[bytes]): + """ Deletes claim supports and from claimtrie in case of an abandon. """ + if claim_hashes: + binary_claim_hashes = [sqlite3.Binary(claim_hash) for claim_hash in claim_hashes] + for table in ('claim', 'support', 'claimtrie'): + self.execute(*self._delete_sql(table, {'claim_hash__in': binary_claim_hashes})) + + def split_inputs_into_claims_and_other(self, txis): + all = set(txi.txo_ref.hash for txi in txis) + claims = dict(self.execute(*query( + "SELECT txo_hash, claim_hash FROM claim", + txo_hash__in=[sqlite3.Binary(txo_hash) for txo_hash in all] + ))) + return claims, all-set(claims) + + def insert_supports(self, txos: Set[Output]): + supports = [] + for txo in txos: + tx = txo.tx_ref.tx + supports.append(( + sqlite3.Binary(txo.ref.hash), sqlite3.Binary(tx.hash), + sqlite3.Binary(txo.claim_hash), tx.position, tx.height, + txo.amount, False + )) + if supports: + self.db.executemany( + "INSERT INTO support (txo_hash, tx_hash, claim_hash, position, height, amount, is_comment) " + "VALUES (?, ?, ?, ?, ?, ?, ?)", supports + ) + + def delete_other_txos(self, txo_hashes: Set[bytes]): + if txo_hashes: + self.execute(*self._delete_sql( + 'support', {'txo_hash__in': [sqlite3.Binary(txo_hash) for txo_hash in txo_hashes]} + )) + + def delete_dereferenced_transactions(self): + self.execute(""" + DELETE FROM tx WHERE ( + (SELECT COUNT(*) FROM claim WHERE claim.tx_hash=tx.tx_hash) + + (SELECT COUNT(*) FROM support WHERE support.tx_hash=tx.tx_hash) + ) = 0 + """) + + def _make_claims_without_competition_become_controlling(self, height): + self.execute(f""" + INSERT INTO claimtrie (claim_name, claim_hash, last_take_over_height) + SELECT claim.claim_name, claim.claim_hash, {height} FROM claim + LEFT JOIN claimtrie USING (claim_name) + WHERE claimtrie.claim_hash IS NULL + GROUP BY claim.claim_name HAVING COUNT(*) = 1 + """) + self.execute(f""" + UPDATE claim SET activation_height = {height} + WHERE activation_height IS NULL AND claim_hash IN ( + SELECT claim_hash FROM claimtrie + ) + """) + + def _update_trending_amount(self, height): + self.execute(f""" + UPDATE claim SET + trending_amount = COALESCE( + (SELECT SUM(amount) FROM support WHERE support.claim_hash=claim.claim_hash + AND support.height > {height-self.TRENDING_BLOCKS}), 0 + ) + """) + + def _update_effective_amount(self, height): + self.execute(f""" + UPDATE claim SET + effective_amount = claim.amount + COALESCE( + (SELECT SUM(amount) FROM support WHERE support.claim_hash=claim.claim_hash), 0 + ) + WHERE activation_height <= {height} + """) + + def _set_activation_height(self, height): + self.execute(f""" + UPDATE claim SET + activation_height = {height} + min(4032, cast( + ( + {height} - + (SELECT last_take_over_height FROM claimtrie + WHERE claimtrie.claim_name=claim.claim_name) + ) / 32 AS INT)) + WHERE activation_height IS NULL + """) + + def get_overtakings(self): + return self.execute(f""" + SELECT winner.claim_name, winner.claim_hash FROM ( + SELECT claim_name, claim_hash, MAX(effective_amount) + FROM claim GROUP BY claim_name + ) AS winner JOIN claimtrie USING (claim_name) + WHERE claimtrie.claim_hash <> winner.claim_hash + """) + + def _perform_overtake(self, height): + for overtake in self.get_overtakings(): + self.execute( + f"UPDATE claim SET activation_height = {height} WHERE claim_name = ?", + (overtake['claim_name'],) + ) + self.execute( + f"UPDATE claimtrie SET claim_hash = ?, last_take_over_height = {height}", + (sqlite3.Binary(overtake['claim_hash']),) + ) + + def update_claimtrie(self, height): + self._make_claims_without_competition_become_controlling(height) + self._update_trending_amount(height) + self._update_effective_amount(height) + self._set_activation_height(height) + self._perform_overtake(height) + self._update_effective_amount(height) + self._perform_overtake(height) + + def get_transactions(self, tx_hashes): + cur = self.db.cursor() + cur.execute(*query("SELECT * FROM tx", tx_hash__in=tx_hashes)) + return cur.fetchall() + + def get_claims(self, cols, **constraints): + if 'is_winning' in constraints: + constraints['claimtrie.claim_hash__is_not_null'] = '' + del constraints['is_winning'] + if 'name' in constraints: + constraints['claim.claim_name__like'] = constraints.pop('name') + if 'claim_id' in constraints: + constraints['claim.claim_hash'] = sqlite3.Binary( + unhexlify(constraints.pop('claim_id'))[::-1] + ) + if 'channel_id' in constraints: + constraints['claim.channel_hash'] = sqlite3.Binary( + unhexlify(constraints.pop('channel_id'))[::-1] + ) + if 'txid' in constraints: + tx_hash = unhexlify(constraints.pop('txid'))[::-1] + if 'nout' in constraints: + nout = constraints.pop('nout') + constraints['claim.txo_hash'] = sqlite3.Binary( + tx_hash + struct.pack(' Tuple[List, List, int, int]: + assert set(constraints).issubset(self.SEARCH_PARAMS), \ + f"Search query contains invalid arguments: {set(constraints).difference(self.SEARCH_PARAMS)}" + total = self.get_claims_count(**constraints) + constraints['offset'] = abs(constraints.get('offset', 0)) + constraints['limit'] = min(abs(constraints.get('limit', 10)), 50) + constraints['order_by'] = ["claim.height DESC", "claim.claim_name ASC"] + txo_rows = self.get_claims( + """ + claim.txo_hash, channel.txo_hash as channel_txo_hash, + claim.activation_height, claimtrie.claim_hash as is_winning, + claim.effective_amount, claim.trending_amount + """, **constraints + ) + tx_hashes = set() + for claim in txo_rows: + tx_hashes.add(claim['txo_hash'][:32]) + if claim['channel_txo_hash'] is not None: + tx_hashes.add(claim['channel_txo_hash'][:32]) + tx_rows = self.get_transactions([sqlite3.Binary(h) for h in tx_hashes]) + return tx_rows, txo_rows, constraints['offset'], total + + def advance_txs(self, height, all_txs): + sql, txs = self, set() + abandon_claim_hashes, stale_claim_metadata_txo_hashes = set(), set() + insert_claims, update_claims = set(), set() + delete_txo_hashes, insert_supports = set(), set() + for position, (etx, txid) in enumerate(all_txs): + tx = Transaction(etx.serialize(), height=height, position=position) + claim_abandon_map, delete_txo_hashes = sql.split_inputs_into_claims_and_other(tx.inputs) + stale_claim_metadata_txo_hashes.update(claim_abandon_map) + for output in tx.outputs: + if output.is_support: + txs.add(tx) + insert_supports.add(output) + elif output.script.is_claim_name: + txs.add(tx) + insert_claims.add(output) + elif output.script.is_update_claim: + txs.add(tx) + update_claims.add(output) + # don't abandon update claims (removes supports & removes from claimtrie) + for txo_hash, input_claim_hash in claim_abandon_map.items(): + if output.claim_hash == input_claim_hash: + del claim_abandon_map[txo_hash] + break + abandon_claim_hashes.update(claim_abandon_map.values()) + sql.abandon_claims(abandon_claim_hashes) + sql.clear_claim_metadata(stale_claim_metadata_txo_hashes) + sql.delete_other_txos(delete_txo_hashes) + sql.insert_txs(txs) + sql.insert_claims(insert_claims) + sql.update_claims(update_claims) + sql.insert_supports(insert_supports) + sql.update_claimtrie(height) class LBRYDB(DB): def __init__(self, *args, **kwargs): - self.claim_cache = {} - self.claims_signed_by_cert_cache = {} - self.outpoint_to_claim_id_cache = {} - self.claims_db = self.signatures_db = self.outpoint_to_claim_id_db = self.claim_undo_db = None - # stores deletes not yet flushed to disk - self.pending_abandons = {} super().__init__(*args, **kwargs) + self.sql = SQLDB('claims.db') def close(self): - self.batched_flush_claims() - self.claims_db.close() - self.signatures_db.close() - self.outpoint_to_claim_id_db.close() - self.claim_undo_db.close() - self.utxo_db.close() super().close() + self.sql.close() - async def _open_dbs(self, for_sync, compacting): - await super()._open_dbs(for_sync=for_sync, compacting=compacting) - def log_reason(message, is_for_sync): - reason = 'sync' if is_for_sync else 'serving' - self.logger.info('{} for {}'.format(message, reason)) - - if self.claims_db: - if self.claims_db.for_sync == for_sync: - return - log_reason('closing claim DBs to re-open', for_sync) - self.claims_db.close() - self.signatures_db.close() - self.outpoint_to_claim_id_db.close() - self.claim_undo_db.close() - self.claims_db = self.db_class('claims', for_sync) - self.signatures_db = self.db_class('signatures', for_sync) - self.outpoint_to_claim_id_db = self.db_class('outpoint_claim_id', for_sync) - self.claim_undo_db = self.db_class('claim_undo', for_sync) - log_reason('opened claim DBs', self.claims_db.for_sync) - - def flush_dbs(self, flush_data, flush_utxos, estimate_txs_remaining): - # flush claims together with utxos as they are parsed together - self.batched_flush_claims() - return super().flush_dbs(flush_data, flush_utxos, estimate_txs_remaining) - - def batched_flush_claims(self): - with self.claims_db.write_batch() as claims_batch: - with self.signatures_db.write_batch() as signed_claims_batch: - with self.outpoint_to_claim_id_db.write_batch() as outpoint_batch: - self.flush_claims(claims_batch, signed_claims_batch, outpoint_batch) - - def flush_claims(self, batch, signed_claims_batch, outpoint_batch): - flush_start = time.time() - write_claim, write_cert = batch.put, signed_claims_batch.put - write_outpoint = outpoint_batch.put - delete_claim, delete_outpoint = batch.delete, outpoint_batch.delete - delete_cert = signed_claims_batch.delete - for claim_id, outpoints in self.pending_abandons.items(): - claim = self.get_claim_info(claim_id) - if claim.cert_id: - self.remove_claim_from_certificate_claims(claim.cert_id, claim_id) - self.remove_certificate(claim_id) - self.claim_cache[claim_id] = None - for txid, tx_index in outpoints: - self.put_claim_id_for_outpoint(txid, tx_index, None) - for key, claim in self.claim_cache.items(): - if claim: - write_claim(key, claim) - else: - delete_claim(key) - for cert_id, claims in self.claims_signed_by_cert_cache.items(): - if not claims: - delete_cert(cert_id) - else: - write_cert(cert_id, msgpack.dumps(claims)) - for key, claim_id in self.outpoint_to_claim_id_cache.items(): - if claim_id: - write_outpoint(key, claim_id) - else: - delete_outpoint(key) - self.logger.info('flushed at height {:,d} with {:,d} claims, {:,d} outpoints ' - 'and {:,d} certificates added while {:,d} were abandoned in {:.1f}s, committing...' - .format(self.db_height, - len(self.claim_cache), len(self.outpoint_to_claim_id_cache), - len(self.claims_signed_by_cert_cache), len(self.pending_abandons), - time.time() - flush_start)) - self.claim_cache = {} - self.claims_signed_by_cert_cache = {} - self.outpoint_to_claim_id_cache = {} - self.pending_abandons = {} - - def assert_flushed(self, flush_data): - super().assert_flushed(flush_data) - assert not self.claim_cache - assert not self.claims_signed_by_cert_cache - assert not self.outpoint_to_claim_id_cache - assert not self.pending_abandons - - def abandon_spent(self, tx_hash, tx_idx): - claim_id = self.get_claim_id_from_outpoint(tx_hash, tx_idx) - if claim_id: - self.logger.info("[!] Abandon: {}".format(hash_to_hex_str(claim_id))) - self.pending_abandons.setdefault(claim_id, []).append((tx_hash, tx_idx,)) - return claim_id - - def put_claim_id_for_outpoint(self, tx_hash, tx_idx, claim_id): - self.logger.info("[+] Adding outpoint: {}:{} for {}.".format(hash_to_hex_str(tx_hash), tx_idx, - hash_to_hex_str(claim_id) if claim_id else None)) - self.outpoint_to_claim_id_cache[tx_hash + struct.pack('>I', tx_idx)] = claim_id - - def remove_claim_id_for_outpoint(self, tx_hash, tx_idx): - self.logger.info("[-] Remove outpoint: {}:{}.".format(hash_to_hex_str(tx_hash), tx_idx)) - self.outpoint_to_claim_id_cache[tx_hash + struct.pack('>I', tx_idx)] = None - - def get_claim_id_from_outpoint(self, tx_hash, tx_idx): - key = tx_hash + struct.pack('>I', tx_idx) - return self.outpoint_to_claim_id_cache.get(key) or self.outpoint_to_claim_id_db.get(key) - - def get_signed_claim_ids_by_cert_id(self, cert_id): - if cert_id in self.claims_signed_by_cert_cache: - return self.claims_signed_by_cert_cache[cert_id] - db_claims = self.signatures_db.get(cert_id) - return msgpack.loads(db_claims, use_list=True) if db_claims else [] - - def put_claim_id_signed_by_cert_id(self, cert_id, claim_id): - msg = "[+] Adding signature: {} - {}".format(hash_to_hex_str(claim_id), hash_to_hex_str(cert_id)) - self.logger.info(msg) - certs = self.get_signed_claim_ids_by_cert_id(cert_id) - certs.append(claim_id) - self.claims_signed_by_cert_cache[cert_id] = certs - - def remove_certificate(self, cert_id): - msg = "[-] Removing certificate: {}".format(hash_to_hex_str(cert_id)) - self.logger.info(msg) - self.claims_signed_by_cert_cache[cert_id] = [] - - def remove_claim_from_certificate_claims(self, cert_id, claim_id): - msg = "[-] Removing signature: {} - {}".format(hash_to_hex_str(claim_id), hash_to_hex_str(cert_id)) - self.logger.info(msg) - certs = self.get_signed_claim_ids_by_cert_id(cert_id) - if claim_id in certs: - certs.remove(claim_id) - self.claims_signed_by_cert_cache[cert_id] = certs - - def get_claim_info(self, claim_id): - serialized = self.claim_cache.get(claim_id) or self.claims_db.get(claim_id) - return ClaimInfo.from_serialized(serialized) if serialized else None - - def put_claim_info(self, claim_id, claim_info): - self.logger.info("[+] Adding claim info for: {}".format(hash_to_hex_str(claim_id))) - self.claim_cache[claim_id] = claim_info.serialized - - def get_update_input(self, claim_id, inputs): - claim_info = self.get_claim_info(claim_id) - if not claim_info: - return False - for input in inputs: - if (input.txo_ref.tx_ref.hash, input.txo_ref.position) == (claim_info.txid, claim_info.nout): - return input - return False - - def write_undo(self, pending_undo): - with self.claim_undo_db.write_batch() as writer: - for height, undo_info in pending_undo: - writer.put(struct.pack(">I", height), msgpack.dumps(undo_info)) + async def _open_dbs(self, *args): + await super()._open_dbs(*args) + self.sql.open() diff --git a/lbrynet/wallet/server/model.py b/lbrynet/wallet/server/model.py deleted file mode 100644 index a8ad9baf0..000000000 --- a/lbrynet/wallet/server/model.py +++ /dev/null @@ -1,15 +0,0 @@ -from collections import namedtuple -import msgpack -# Classes representing data and their serializers, if any. - - -class ClaimInfo(namedtuple("NameClaim", "name value txid nout amount address height cert_id")): - '''Claim information as its stored on database''' - - @classmethod - def from_serialized(cls, serialized): - return cls(*msgpack.loads(serialized)) - - @property - def serialized(self): - return msgpack.dumps(self) diff --git a/lbrynet/wallet/server/session.py b/lbrynet/wallet/server/session.py index 074cea0c7..f1d265b65 100644 --- a/lbrynet/wallet/server/session.py +++ b/lbrynet/wallet/server/session.py @@ -7,6 +7,7 @@ from torba.server.hash import hash_to_hex_str from torba.server.session import ElectrumX from torba.server import util +from lbrynet.schema.page import Page from lbrynet.schema.uri import parse_lbry_uri, CLAIM_ID_MAX_LENGTH, URIParseError from lbrynet.wallet.server.block_processor import LBRYBlockProcessor from lbrynet.wallet.server.db import LBRYDB @@ -22,13 +23,12 @@ class LBRYElectrumX(ElectrumX): self.daemon = self.session_mgr.daemon self.bp: LBRYBlockProcessor = self.session_mgr.bp self.db: LBRYDB = self.bp.db - # fixme: lbryum specific subscribe - self.subscribe_height = False def set_request_handlers(self, ptuple): super().set_request_handlers(ptuple) handlers = { 'blockchain.transaction.get_height': self.transaction_get_height, + 'blockchain.claimtrie.search': self.claimtrie_search, 'blockchain.claimtrie.getclaimbyid': self.claimtrie_getclaimbyid, 'blockchain.claimtrie.getclaimsforname': self.claimtrie_getclaimsforname, 'blockchain.claimtrie.getclaimsbyids': self.claimtrie_getclaimsbyids, @@ -42,67 +42,8 @@ class LBRYElectrumX(ElectrumX): 'blockchain.claimtrie.getclaimssignedbyid': self.claimtrie_getclaimssignedbyid, 'blockchain.block.get_server_height': self.get_server_height, } - # fixme: methods we use but shouldnt be using anymore. To be removed when torba goes out - handlers.update({ - 'blockchain.numblocks.subscribe': self.numblocks_subscribe, - 'blockchain.utxo.get_address': self.utxo_get_address, - 'blockchain.transaction.broadcast': - self.transaction_broadcast_1_0, - 'blockchain.transaction.get': self.transaction_get, - }) self.request_handlers.update(handlers) - async def utxo_get_address(self, tx_hash, index): - # fixme: lbryum - # Used only for electrum client command-line requests. We no - # longer index by address, so need to request the raw - # transaction. So it works for any TXO not just UTXOs. - self.assert_tx_hash(tx_hash) - try: - index = int(index) - if index < 0: - raise ValueError - except ValueError: - raise RPCError(1, "index has to be >= 0 and integer") - raw_tx = await self.daemon_request('getrawtransaction', tx_hash) - if not raw_tx: - return None - raw_tx = util.hex_to_bytes(raw_tx) - tx = self.coin.DESERIALIZER(raw_tx).read_tx() - if index >= len(tx.outputs): - return None - return self.coin.address_from_script(tx.outputs[index].pk_script) - - async def transaction_broadcast_1_0(self, raw_tx): - # fixme: lbryum - # An ugly API: current Electrum clients only pass the raw - # transaction in hex and expect error messages to be returned in - # the result field. And the server shouldn't be doing the client's - # user interface job here. - try: - return await self.transaction_broadcast(raw_tx) - except RPCError as e: - return e.message - - async def numblocks_subscribe(self): - # fixme workaround for lbryum - '''Subscribe to get height of new blocks.''' - self.subscribe_height = True - return self.bp.height - - async def notify(self, height, touched): - # fixme workaround for lbryum - await super().notify(height, touched) - if self.subscribe_height and height != self.notified_height: - self.send_notification('blockchain.numblocks.subscribe', (height,)) - - async def transaction_get(self, tx_hash, verbose=False): - # fixme: workaround for lbryum sending the height instead of True/False. - # fixme: lbryum_server ignored that and always used False, but this is out of spec - if verbose not in (True, False): - verbose = False - return await self.daemon_request('getrawtransaction', tx_hash, verbose) - async def get_server_height(self): return self.bp.height @@ -197,6 +138,11 @@ class LBRYElectrumX(ElectrumX): return claims return {} + async def claimtrie_search(self, **kwargs): + if 'claim_id' in kwargs: + self.assert_claim_id(kwargs['claim_id']) + return Page.to_base64(*self.db.sql.claim_search(kwargs)) + async def batched_formatted_claims_from_daemon(self, claim_ids): claims = await self.daemon.getclaimsbyids(claim_ids) result = [] @@ -217,9 +163,7 @@ class LBRYElectrumX(ElectrumX): if 'name' in claim: name = claim['name'].encode('ISO-8859-1').decode() - claim_id = claim['claimId'] - raw_claim_id = unhexlify(claim_id)[::-1] - info = self.db.get_claim_info(raw_claim_id) + info = self.db.sql.get_claims(claim_id=claim['claimId']) if not info: # raise RPCError("Lbrycrd has {} but not lbryumx, please submit a bug report.".format(claim_id)) return {} diff --git a/lbrynet/wallet/transaction.py b/lbrynet/wallet/transaction.py index 53007d6cb..5e92b10ce 100644 --- a/lbrynet/wallet/transaction.py +++ b/lbrynet/wallet/transaction.py @@ -27,13 +27,14 @@ class Output(BaseOutput): script: OutputScript script_class = OutputScript - __slots__ = 'channel', 'private_key' + __slots__ = 'channel', 'private_key', 'meta' def __init__(self, *args, channel: Optional['Output'] = None, private_key: Optional[str] = None, **kwargs) -> None: super().__init__(*args, **kwargs) self.channel = channel self.private_key = private_key + self.meta = {} def update_annotations(self, annotated): super().update_annotations(annotated) @@ -50,6 +51,10 @@ class Output(BaseOutput): def is_claim(self) -> bool: return self.script.is_claim_name or self.script.is_update_claim + @property + def is_support(self) -> bool: + return self.script.is_support_claim + @property def claim_hash(self) -> bytes: if self.script.is_claim_name: diff --git a/tests/integration/test_claim_commands.py b/tests/integration/test_claim_commands.py index 534661e19..75f77a0af 100644 --- a/tests/integration/test_claim_commands.py +++ b/tests/integration/test_claim_commands.py @@ -676,7 +676,7 @@ class StreamCommands(CommandTestCase): channel_id, txid = channel['outputs'][0]['claim_id'], channel['txid'] value = channel['outputs'][0]['value'] - claims = await self.claim_search('@abc') + claims = await self.claim_search(name='@abc') self.assertEqual(claims[0]['value'], value) claims = await self.claim_search(txid=txid, nout=0) @@ -695,10 +695,10 @@ class StreamCommands(CommandTestCase): signed = await self.stream_create('on-channel-claim', '0.0001', channel_id=channel_id) unsigned = await self.stream_create('unsigned', '0.0001') - claims = await self.claim_search('on-channel-claim') + claims = await self.claim_search(name='on-channel-claim') self.assertEqual(claims[0]['value'], signed['outputs'][0]['value']) - claims = await self.claim_search('unsigned') + claims = await self.claim_search(name='unsigned') self.assertEqual(claims[0]['value'], unsigned['outputs'][0]['value']) # list streams in a channel @@ -725,25 +725,24 @@ class StreamCommands(CommandTestCase): tx = await self.daemon.jsonrpc_account_fund(None, None, '0.001', outputs=100, broadcast=True) await self.confirm_tx(tx.id) - # 4 claims per block, 3 blocks. Sorted by height (descending) then claim_id (ascending). + # 4 claims per block, 3 blocks. Sorted by height (descending) then claim name (ascending). claims = [] for j in range(3): same_height_claims = [] for k in range(3): claim_tx = await self.stream_create(f'c{j}-{k}', '0.000001', channel_id=channel_id, confirm=False) - same_height_claims.append(claim_tx['outputs'][0]['claim_id']) + same_height_claims.append(claim_tx['outputs'][0]['name']) await self.on_transaction_dict(claim_tx) claim_tx = await self.stream_create(f'c{j}-4', '0.000001', channel_id=channel_id, confirm=True) - same_height_claims.append(claim_tx['outputs'][0]['claim_id']) - same_height_claims.sort(key=lambda x: int(x, 16)) + same_height_claims.append(claim_tx['outputs'][0]['name']) claims = same_height_claims + claims page = await self.claim_search(page_size=20, channel_id=channel_id) - page_claim_ids = [item['claim_id'] for item in page] + page_claim_ids = [item['name'] for item in page] self.assertEqual(page_claim_ids, claims) page = await self.claim_search(page_size=6, channel_id=channel_id) - page_claim_ids = [item['claim_id'] for item in page] + page_claim_ids = [item['name'] for item in page] self.assertEqual(page_claim_ids, claims[:6]) out_of_bounds = await self.claim_search(page=2, page_size=20, channel_id=channel_id) diff --git a/tests/integration/test_wallet_server.py b/tests/integration/test_wallet_server.py new file mode 100644 index 000000000..a3cf51768 --- /dev/null +++ b/tests/integration/test_wallet_server.py @@ -0,0 +1,25 @@ +from lbrynet.testcase import CommandTestCase + + +class TestClaimtrie(CommandTestCase): + + def get_claim_id(self, tx): + return tx['outputs'][0]['claim_id'] + + async def assertWinningClaim(self, name, tx): + other = (await self.out(self.daemon.jsonrpc_claim_search(name=name, is_winning=True)))['items'][0] + self.assertEqual(self.get_claim_id(tx), other['claim_id']) + + async def test_designed_edge_cases(self): + tx1 = await self.channel_create('@foo', allow_duplicate_name=True) + await self.assertWinningClaim('@foo', tx1) + tx2 = await self.channel_create('@foo', allow_duplicate_name=True) + await self.assertWinningClaim('@foo', tx1) + tx3 = await self.channel_create('@foo', allow_duplicate_name=True) + await self.assertWinningClaim('@foo', tx1) + await self.support_create(self.get_claim_id(tx3), '0.09') + await self.assertWinningClaim('@foo', tx3) + await self.support_create(self.get_claim_id(tx2), '0.19') + await self.assertWinningClaim('@foo', tx2) + await self.support_create(self.get_claim_id(tx1), '0.19') + await self.assertWinningClaim('@foo', tx1) diff --git a/tests/unit/wallet/server/test_sqldb.py b/tests/unit/wallet/server/test_sqldb.py new file mode 100644 index 000000000..cc4d3e39b --- /dev/null +++ b/tests/unit/wallet/server/test_sqldb.py @@ -0,0 +1,169 @@ +import unittest +from torba.client.constants import COIN, NULL_HASH32 + +from lbrynet.schema.claim import Claim +from lbrynet.wallet.server.db import SQLDB +from lbrynet.wallet.transaction import Transaction, Input, Output + + +def get_output(amount=COIN, pubkey_hash=NULL_HASH32): + return Transaction() \ + .add_outputs([Output.pay_pubkey_hash(amount, pubkey_hash)]) \ + .outputs[0] + + +def get_input(): + return Input.spend(get_output()) + + +def get_tx(): + return Transaction().add_inputs([get_input()]) + + +class OldWalletServerTransaction: + def __init__(self, tx): + self.tx = tx + + def serialize(self): + return self.tx.raw + + +class TestSQLDB(unittest.TestCase): + + def setUp(self): + self.sql = SQLDB(':memory:') + self.sql.open() + self._current_height = 0 + + def _make_tx(self, output): + tx = get_tx().add_outputs([output]) + return OldWalletServerTransaction(tx), tx.hash + + def get_channel(self, title, amount, name='@foo'): + claim = Claim() + claim.channel.title = title + channel = Output.pay_claim_name_pubkey_hash(amount, name, claim, b'abc') + channel.generate_channel_private_key() + return self._make_tx(channel) + + def get_stream(self, title, amount, name='foo'): + claim = Claim() + claim.stream.title = title + return self._make_tx(Output.pay_claim_name_pubkey_hash(amount, name, claim, b'abc')) + + def get_stream_update(self, tx, amount): + claim = Transaction(tx[0].serialize()).outputs[0] + return self._make_tx( + Output.pay_update_claim_pubkey_hash( + amount, claim.claim_name, claim.claim_id, claim.claim, b'abc' + ) + ) + + def get_support(self, tx, amount): + claim = Transaction(tx[0].serialize()).outputs[0] + return self._make_tx( + Output.pay_support_pubkey_hash( + amount, claim.claim_name, claim.claim_id, b'abc' + ) + ) + + def get_controlling(self): + for claim in self.sql.execute("select claim.*, raw from claimtrie natural join claim natural join tx"): + txo = Transaction(claim['raw']).outputs[0] + controlling = txo.claim.stream.title, claim['amount'], claim['effective_amount'], claim['activation_height'] + return controlling + + def get_active(self): + controlling = self.get_controlling() + active = [] + for claim in self.sql.execute( + f"select claim.*, raw from claim join tx using (tx_hash) " + f"where activation_height <= {self._current_height}"): + txo = Transaction(claim['raw']).outputs[0] + if controlling and controlling[0] == txo.claim.stream.title: + continue + active.append((txo.claim.stream.title, claim['amount'], claim['effective_amount'], claim['activation_height'])) + return active + + def get_accepted(self): + accepted = [] + for claim in self.sql.execute( + f"select claim.*, raw from claim join tx using (tx_hash) " + f"where activation_height > {self._current_height}"): + txo = Transaction(claim['raw']).outputs[0] + accepted.append((txo.claim.stream.title, claim['amount'], claim['effective_amount'], claim['activation_height'])) + return accepted + + def advance(self, height, txs): + self._current_height = height + self.sql.advance_txs(height, txs) + + def state(self, controlling=None, active=None, accepted=None): + self.assertEqual(controlling or [], self.get_controlling()) + self.assertEqual(active or [], self.get_active()) + self.assertEqual(accepted or [], self.get_accepted()) + + def test_example_from_spec(self): + # https://spec.lbry.com/#claim-activation-example + advance, state = self.advance, self.state + stream = self.get_stream('Claim A', 10*COIN) + advance(13, [stream]) + state( + controlling=('Claim A', 10*COIN, 10*COIN, 13), + active=[], + accepted=[] + ) + advance(1001, [self.get_stream('Claim B', 20*COIN)]) + state( + controlling=('Claim A', 10*COIN, 10*COIN, 13), + active=[], + accepted=[('Claim B', 20*COIN, 0, 1031)] + ) + advance(1010, [self.get_support(stream, 14*COIN)]) + state( + controlling=('Claim A', 10*COIN, 24*COIN, 13), + active=[], + accepted=[('Claim B', 20*COIN, 0, 1031)] + ) + advance(1020, [self.get_stream('Claim C', 50*COIN)]) + state( + controlling=('Claim A', 10*COIN, 24*COIN, 13), + active=[], + accepted=[ + ('Claim B', 20*COIN, 0, 1031), + ('Claim C', 50*COIN, 0, 1051)] + ) + advance(1031, []) + state( + controlling=('Claim A', 10*COIN, 24*COIN, 13), + active=[('Claim B', 20*COIN, 20*COIN, 1031)], + accepted=[('Claim C', 50*COIN, 0, 1051)] + ) + advance(1040, [self.get_stream('Claim D', 300*COIN)]) + state( + controlling=('Claim A', 10*COIN, 24*COIN, 13), + active=[('Claim B', 20*COIN, 20*COIN, 1031)], + accepted=[ + ('Claim C', 50*COIN, 0, 1051), + ('Claim D', 300*COIN, 0, 1072)] + ) + advance(1051, []) + state( + controlling=('Claim D', 300*COIN, 300*COIN, 1051), + active=[ + ('Claim A', 10*COIN, 24*COIN, 1051), + ('Claim B', 20*COIN, 20*COIN, 1051), + ('Claim C', 50*COIN, 50*COIN, 1051)], + accepted=[] + ) + # beyond example + advance(1052, [self.get_stream_update(stream, 290*COIN)]) + state( + controlling=('Claim A', 290*COIN, 304*COIN, 1052), + active=[ + ('Claim B', 20*COIN, 20*COIN, 1052), + ('Claim C', 50*COIN, 50*COIN, 1052), + ('Claim D', 300*COIN, 300*COIN, 1052), + ], + accepted=[] + )