postgresql support added

This commit is contained in:
Lex Berezhny 2020-04-11 17:27:41 -04:00
parent 2309d6354c
commit ae9d4af8c0
33 changed files with 1376 additions and 1011 deletions

View file

@ -36,6 +36,9 @@ jobs:
- datanetwork - datanetwork
- blockchain - blockchain
- other - other
db:
- postgres
- sqlite
steps: steps:
- uses: actions/checkout@v1 - uses: actions/checkout@v1
- uses: actions/setup-python@v1 - uses: actions/setup-python@v1
@ -44,7 +47,9 @@ jobs:
- if: matrix.test == 'other' - if: matrix.test == 'other'
run: sudo apt install -y --no-install-recommends ffmpeg run: sudo apt install -y --no-install-recommends ffmpeg
- run: pip install tox-travis - run: pip install tox-travis
- run: tox -e ${{ matrix.test }} - env:
TEST_DB: ${{ matrix.db }}
run: tox -e ${{ matrix.test }}
build: build:
needs: ["lint", "tests-unit", "tests-integration"] needs: ["lint", "tests-unit", "tests-integration"]

View file

@ -1,172 +0,0 @@
import os
import asyncio
from concurrent import futures
from collections import namedtuple, deque
import sqlite3
import apsw
DDL = """
pragma journal_mode=WAL;
create table if not exists block (
block_hash bytes not null primary key,
previous_hash bytes not null,
file_number integer not null,
height int
);
create table if not exists tx (
block_hash integer not null,
position integer not null,
tx_hash bytes not null
);
create table if not exists txi (
block_hash bytes not null,
tx_hash bytes not null,
txo_hash bytes not null
);
create table if not exists claim (
txo_hash bytes not null,
claim_hash bytes not null,
claim_name text not null,
amount integer not null,
height integer
);
create table if not exists claim_history (
block_hash bytes not null,
tx_hash bytes not null,
tx_position integer not null,
txo_hash bytes not null,
claim_hash bytes not null,
claim_name text not null,
action integer not null,
amount integer not null,
height integer,
is_spent bool
);
create table if not exists support (
block_hash bytes not null,
tx_hash bytes not null,
txo_hash bytes not null,
claim_hash bytes not null,
amount integer not null
);
"""
class BlockchainDB:
__slots__ = 'db', 'directory'
def __init__(self, path: str):
self.db = None
self.directory = path
@property
def db_file_path(self):
return os.path.join(self.directory, 'blockchain.db')
def open(self):
self.db = sqlite3.connect(self.db_file_path, isolation_level=None, uri=True, timeout=60.0 * 5)
self.db.executescript("""
pragma journal_mode=wal;
""")
# self.db = apsw.Connection(
# self.db_file_path,
# flags=(
# apsw.SQLITE_OPEN_READWRITE |
# apsw.SQLITE_OPEN_CREATE |
# apsw.SQLITE_OPEN_URI
# )
# )
self.execute_ddl(DDL)
self.execute(f"ATTACH ? AS block_index", ('file:'+os.path.join(self.directory, 'block_index.sqlite')+'?mode=ro',))
#def exec_factory(cursor, statement, bindings):
# tpl = namedtuple('row', (d[0] for d in cursor.getdescription()))
# cursor.setrowtrace(lambda cursor, row: tpl(*row))
# return True
#self.db.setexectrace(exec_factory)
def row_factory(cursor, row):
tpl = namedtuple('row', (d[0] for d in cursor.description))
return tpl(*row)
self.db.row_factory = row_factory
return self
def close(self):
if self.db is not None:
self.db.close()
def execute(self, *args):
return self.db.cursor().execute(*args)
def execute_many(self, *args):
return self.db.cursor().executemany(*args)
def execute_many_tx(self, *args):
cursor = self.db.cursor()
cursor.execute('begin;')
result = cursor.executemany(*args)
cursor.execute('commit;')
return result
def execute_ddl(self, *args):
self.db.executescript(*args)
#deque(self.execute(*args), maxlen=0)
def begin(self):
self.execute('begin;')
def commit(self):
self.execute('commit;')
def get_block_file_path_from_number(self, block_file_number):
return os.path.join(self.directory, 'blocks', f'blk{block_file_number:05}.dat')
def get_block_files_not_synced(self):
return list(self.execute(
"""
SELECT file as file_number, COUNT(hash) as blocks, SUM(txcount) as txs
FROM block_index.block_info
WHERE hash NOT IN (SELECT block_hash FROM block)
GROUP BY file ORDER BY file ASC;
"""
))
def get_blocks_not_synced(self, block_file):
return self.execute(
"""
SELECT datapos as data_offset, height, hash as block_hash, txCount as txs
FROM block_index.block_info
WHERE file = ? AND hash NOT IN (SELECT block_hash FROM block)
ORDER BY datapos ASC;
""", (block_file,)
)
class AsyncBlockchainDB:
def __init__(self, db: BlockchainDB):
self.sync_db = db
self.executor = futures.ThreadPoolExecutor(max_workers=1)
@classmethod
def from_path(cls, path: str) -> 'AsyncBlockchainDB':
return cls(BlockchainDB(path))
def get_block_file_path_from_number(self, block_file_number):
return self.sync_db.get_block_file_path_from_number(block_file_number)
async def run_in_executor(self, func, *args):
return await asyncio.get_running_loop().run_in_executor(
self.executor, func, *args
)
async def open(self):
return await self.run_in_executor(self.sync_db.open)
async def close(self):
return await self.run_in_executor(self.sync_db.close)
async def get_block_files_not_synced(self):
return await self.run_in_executor(self.sync_db.get_block_files_not_synced)

6
lbry/db/__init__.py Normal file
View file

@ -0,0 +1,6 @@
from .database import Database, in_account
from .tables import (
Table, Version, metadata,
AccountAddress, PubkeyAddress,
Block, TX, TXO, TXI
)

1004
lbry/db/database.py Normal file

File diff suppressed because it is too large Load diff

82
lbry/db/tables.py Normal file
View file

@ -0,0 +1,82 @@
from sqlalchemy import (
MetaData, Table, Column, ForeignKey,
Binary, Text, SmallInteger, Integer, Boolean
)
metadata = MetaData()
Version = Table(
'version', metadata,
Column('version', Text, primary_key=True),
)
PubkeyAddress = Table(
'pubkey_address', metadata,
Column('address', Text, primary_key=True),
Column('history', Text, nullable=True),
Column('used_times', Integer, server_default='0'),
)
AccountAddress = Table(
'account_address', metadata,
Column('account', Text, primary_key=True),
Column('address', Text, ForeignKey(PubkeyAddress.columns.address), primary_key=True),
Column('chain', Integer),
Column('pubkey', Binary),
Column('chain_code', Binary),
Column('n', Integer),
Column('depth', Integer),
)
Block = Table(
'block', metadata,
Column('block_hash', Binary, primary_key=True),
Column('previous_hash', Binary),
Column('file_number', SmallInteger),
Column('height', Integer),
)
TX = Table(
'tx', metadata,
Column('block_hash', Binary, nullable=True),
Column('tx_hash', Binary, primary_key=True),
Column('raw', Binary),
Column('height', Integer),
Column('position', SmallInteger),
Column('is_verified', Boolean, server_default='FALSE'),
Column('purchased_claim_hash', Binary, nullable=True),
Column('day', Integer, nullable=True),
)
TXO = Table(
'txo', metadata,
Column('tx_hash', Binary, ForeignKey(TX.columns.tx_hash)),
Column('txo_hash', Binary, primary_key=True),
Column('address', Text),
Column('position', Integer),
Column('amount', Integer),
Column('script', Binary),
Column('is_reserved', Boolean, server_default='0'),
Column('txo_type', Integer, server_default='0'),
Column('claim_id', Text, nullable=True),
Column('claim_hash', Binary, nullable=True),
Column('claim_name', Text, nullable=True),
Column('channel_hash', Binary, nullable=True),
Column('reposted_claim_hash', Binary, nullable=True),
)
TXI = Table(
'txi', metadata,
Column('tx_hash', Binary, ForeignKey(TX.columns.tx_hash)),
Column('txo_hash', Binary, ForeignKey(TXO.columns.txo_hash), primary_key=True),
Column('address', Text),
Column('position', Integer),
)

View file

@ -18,6 +18,7 @@ from functools import wraps, partial
import ecdsa import ecdsa
import base58 import base58
from sqlalchemy import text
from aiohttp import web from aiohttp import web
from prometheus_client import generate_latest as prom_generate_latest from prometheus_client import generate_latest as prom_generate_latest
from google.protobuf.message import DecodeError from google.protobuf.message import DecodeError
@ -1530,7 +1531,7 @@ class Daemon(metaclass=JSONRPCServerType):
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
account = wallet.get_account_or_default(account_id) account = wallet.get_account_or_default(account_id)
balance = await account.get_detailed_balance( balance = await account.get_detailed_balance(
confirmations=confirmations, reserved_subtotals=True, read_only=True confirmations=confirmations, reserved_subtotals=True,
) )
return dict_values_to_lbc(balance) return dict_values_to_lbc(balance)
@ -1855,7 +1856,7 @@ class Daemon(metaclass=JSONRPCServerType):
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
account = wallet.get_account_or_default(account_id) account = wallet.get_account_or_default(account_id)
match = await self.ledger.db.get_address(read_only=True, address=address, accounts=[account]) match = await self.ledger.db.get_address(address=address, accounts=[account])
if match is not None: if match is not None:
return True return True
return False return False
@ -1879,9 +1880,7 @@ class Daemon(metaclass=JSONRPCServerType):
Returns: {Paginated[Address]} Returns: {Paginated[Address]}
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
constraints = { constraints = {}
'cols': ('address', 'account', 'used_times', 'pubkey', 'chain_code', 'n', 'depth')
}
if address: if address:
constraints['address'] = address constraints['address'] = address
if account_id: if account_id:
@ -1891,7 +1890,7 @@ class Daemon(metaclass=JSONRPCServerType):
return paginate_rows( return paginate_rows(
self.ledger.get_addresses, self.ledger.get_addresses,
self.ledger.get_address_count, self.ledger.get_address_count,
page, page_size, read_only=True, **constraints page, page_size, **constraints
) )
@requires(WALLET_COMPONENT) @requires(WALLET_COMPONENT)
@ -1968,7 +1967,7 @@ class Daemon(metaclass=JSONRPCServerType):
txo.purchased_claim_id: txo for txo in txo.purchased_claim_id: txo for txo in
await self.ledger.db.get_purchases( await self.ledger.db.get_purchases(
accounts=wallet.accounts, accounts=wallet.accounts,
purchased_claim_id__in=[s.claim_id for s in paginated['items']] purchased_claim_hash__in=[unhexlify(s.claim_id)[::-1] for s in paginated['items']]
) )
} }
for stream in paginated['items']: for stream in paginated['items']:
@ -2630,7 +2629,7 @@ class Daemon(metaclass=JSONRPCServerType):
accounts = wallet.accounts accounts = wallet.accounts
existing_channels = await self.ledger.get_claims( existing_channels = await self.ledger.get_claims(
wallet=wallet, accounts=accounts, claim_id=claim_id wallet=wallet, accounts=accounts, claim_hash=unhexlify(claim_id)[::-1]
) )
if len(existing_channels) != 1: if len(existing_channels) != 1:
account_ids = ', '.join(f"'{account.id}'" for account in accounts) account_ids = ', '.join(f"'{account.id}'" for account in accounts)
@ -2721,7 +2720,7 @@ class Daemon(metaclass=JSONRPCServerType):
if txid is not None and nout is not None: if txid is not None and nout is not None:
claims = await self.ledger.get_claims( claims = await self.ledger.get_claims(
wallet=wallet, accounts=accounts, **{'txo.txid': txid, 'txo.position': nout} wallet=wallet, accounts=accounts, tx_hash=unhexlify(txid)[::-1], position=nout
) )
elif claim_id is not None: elif claim_id is not None:
claims = await self.ledger.get_claims( claims = await self.ledger.get_claims(
@ -3477,7 +3476,7 @@ class Daemon(metaclass=JSONRPCServerType):
if txid is not None and nout is not None: if txid is not None and nout is not None:
claims = await self.ledger.get_claims( claims = await self.ledger.get_claims(
wallet=wallet, accounts=accounts, **{'txo.txid': txid, 'txo.position': nout} wallet=wallet, accounts=accounts, tx_hash=unhexlify(txid)[::-1], position=nout
) )
elif claim_id is not None: elif claim_id is not None:
claims = await self.ledger.get_claims( claims = await self.ledger.get_claims(
@ -4053,7 +4052,7 @@ class Daemon(metaclass=JSONRPCServerType):
if txid is not None and nout is not None: if txid is not None and nout is not None:
supports = await self.ledger.get_supports( supports = await self.ledger.get_supports(
wallet=wallet, accounts=accounts, **{'txo.txid': txid, 'txo.position': nout} wallet=wallet, accounts=accounts, tx_hash=unhexlify(txid)[::-1], position=nout
) )
elif claim_id is not None: elif claim_id is not None:
supports = await self.ledger.get_supports( supports = await self.ledger.get_supports(
@ -4165,7 +4164,7 @@ class Daemon(metaclass=JSONRPCServerType):
self.ledger.get_transaction_history, wallet=wallet, accounts=wallet.accounts) self.ledger.get_transaction_history, wallet=wallet, accounts=wallet.accounts)
transaction_count = partial( transaction_count = partial(
self.ledger.get_transaction_history_count, wallet=wallet, accounts=wallet.accounts) self.ledger.get_transaction_history_count, wallet=wallet, accounts=wallet.accounts)
return paginate_rows(transactions, transaction_count, page, page_size, read_only=True) return paginate_rows(transactions, transaction_count, page, page_size)
@requires(WALLET_COMPONENT) @requires(WALLET_COMPONENT)
def jsonrpc_transaction_show(self, txid): def jsonrpc_transaction_show(self, txid):
@ -4180,7 +4179,7 @@ class Daemon(metaclass=JSONRPCServerType):
Returns: {Transaction} Returns: {Transaction}
""" """
return self.wallet_manager.get_transaction(txid) return self.wallet_manager.get_transaction(unhexlify(txid)[::-1])
TXO_DOC = """ TXO_DOC = """
List and sum transaction outputs. List and sum transaction outputs.
@ -4210,12 +4209,13 @@ class Daemon(metaclass=JSONRPCServerType):
constraints['is_my_output'] = True constraints['is_my_output'] = True
elif is_not_my_output is True: elif is_not_my_output is True:
constraints['is_my_output'] = False constraints['is_my_output'] = False
to_hash = lambda x: unhexlify(x)[::-1]
database.constrain_single_or_list(constraints, 'txo_type', type, lambda x: TXO_TYPES[x]) database.constrain_single_or_list(constraints, 'txo_type', type, lambda x: TXO_TYPES[x])
database.constrain_single_or_list(constraints, 'channel_id', channel_id) database.constrain_single_or_list(constraints, 'channel_hash', channel_id, to_hash)
database.constrain_single_or_list(constraints, 'claim_id', claim_id) database.constrain_single_or_list(constraints, 'claim_hash', claim_id, to_hash)
database.constrain_single_or_list(constraints, 'claim_name', name) database.constrain_single_or_list(constraints, 'claim_name', name)
database.constrain_single_or_list(constraints, 'txid', txid) database.constrain_single_or_list(constraints, 'tx_hash', txid, to_hash)
database.constrain_single_or_list(constraints, 'reposted_claim_id', reposted_claim_id) database.constrain_single_or_list(constraints, 'reposted_claim_hash', reposted_claim_id, to_hash)
return constraints return constraints
@requires(WALLET_COMPONENT) @requires(WALLET_COMPONENT)
@ -4274,8 +4274,8 @@ class Daemon(metaclass=JSONRPCServerType):
claims = account.get_txos claims = account.get_txos
claim_count = account.get_txo_count claim_count = account.get_txo_count
else: else:
claims = partial(self.ledger.get_txos, wallet=wallet, accounts=wallet.accounts, read_only=True) claims = partial(self.ledger.get_txos, wallet=wallet, accounts=wallet.accounts)
claim_count = partial(self.ledger.get_txo_count, wallet=wallet, accounts=wallet.accounts, read_only=True) claim_count = partial(self.ledger.get_txo_count, wallet=wallet, accounts=wallet.accounts)
constraints = { constraints = {
'resolve': resolve, 'resolve': resolve,
'include_is_spent': True, 'include_is_spent': True,
@ -4332,7 +4332,7 @@ class Daemon(metaclass=JSONRPCServerType):
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
accounts = [wallet.get_account_or_error(account_id)] if account_id else wallet.accounts accounts = [wallet.get_account_or_error(account_id)] if account_id else wallet.accounts
txos = await self.ledger.get_txos( txos = await self.ledger.get_txos(
wallet=wallet, accounts=accounts, read_only=True, wallet=wallet, accounts=accounts,
**self._constrain_txo_from_kwargs({}, is_not_spent=True, is_my_output=True, **kwargs) **self._constrain_txo_from_kwargs({}, is_not_spent=True, is_my_output=True, **kwargs)
) )
txs = [] txs = []
@ -4391,7 +4391,7 @@ class Daemon(metaclass=JSONRPCServerType):
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
return self.ledger.get_txo_sum( return self.ledger.get_txo_sum(
wallet=wallet, accounts=[wallet.get_account_or_error(account_id)] if account_id else wallet.accounts, wallet=wallet, accounts=[wallet.get_account_or_error(account_id)] if account_id else wallet.accounts,
read_only=True, **self._constrain_txo_from_kwargs({}, **kwargs) **self._constrain_txo_from_kwargs({}, **kwargs)
) )
@requires(WALLET_COMPONENT) @requires(WALLET_COMPONENT)
@ -4447,7 +4447,7 @@ class Daemon(metaclass=JSONRPCServerType):
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
plot = await self.ledger.get_txo_plot( plot = await self.ledger.get_txo_plot(
wallet=wallet, accounts=[wallet.get_account_or_error(account_id)] if account_id else wallet.accounts, wallet=wallet, accounts=[wallet.get_account_or_error(account_id)] if account_id else wallet.accounts,
read_only=True, days_back=days_back, start_day=start_day, days_after=days_after, end_day=end_day, days_back=days_back, start_day=start_day, days_after=days_after, end_day=end_day,
**self._constrain_txo_from_kwargs({}, **kwargs) **self._constrain_txo_from_kwargs({}, **kwargs)
) )
for row in plot: for row in plot:

View file

@ -1,7 +1,7 @@
import logging import logging
from decimal import Decimal from decimal import Decimal
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from datetime import datetime from datetime import datetime, date
from json import JSONEncoder from json import JSONEncoder
from google.protobuf.message import DecodeError from google.protobuf.message import DecodeError
@ -134,6 +134,8 @@ class JSONResponseEncoder(JSONEncoder):
return self.encode_claim(obj) return self.encode_claim(obj)
if isinstance(obj, PubKey): if isinstance(obj, PubKey):
return obj.extended_key_string() return obj.extended_key_string()
if isinstance(obj, date):
return obj.isoformat()
if isinstance(obj, datetime): if isinstance(obj, datetime):
return obj.strftime("%Y%m%dT%H:%M:%S") return obj.strftime("%Y%m%dT%H:%M:%S")
if isinstance(obj, Decimal): if isinstance(obj, Decimal):

View file

@ -6,7 +6,7 @@ import asyncio
import binascii import binascii
import time import time
from typing import Optional from typing import Optional
from lbry.wallet import SQLiteMixin from lbry.wallet.database import SQLiteMixin
from lbry.conf import Config from lbry.conf import Config
from lbry.wallet.dewies import dewies_to_lbc, lbc_to_dewies from lbry.wallet.dewies import dewies_to_lbc, lbc_to_dewies
from lbry.wallet.transaction import Transaction from lbry.wallet.transaction import Transaction

View file

@ -148,7 +148,7 @@ class Outputs:
for txo_message in chain(outputs.txos, outputs.extra_txos): for txo_message in chain(outputs.txos, outputs.extra_txos):
if txo_message.WhichOneof('meta') == 'error': if txo_message.WhichOneof('meta') == 'error':
continue continue
txs.add((hexlify(txo_message.tx_hash[::-1]).decode(), txo_message.height)) txs.add((txo_message.tx_hash, txo_message.height))
return cls( return cls(
outputs.txos, outputs.extra_txos, txs, outputs.txos, outputs.extra_txos, txs,
outputs.offset, outputs.total, outputs.offset, outputs.total,

View file

@ -253,6 +253,11 @@ class IntegrationTestCase(AsyncioTestCase):
lambda e: e.tx.id == txid lambda e: e.tx.id == txid
) )
def on_transaction_hash(self, tx_hash, ledger=None):
return (ledger or self.ledger).on_transaction.where(
lambda e: e.tx.hash == tx_hash
)
def on_address_update(self, address): def on_address_update(self, address):
return self.ledger.on_transaction.where( return self.ledger.on_transaction.where(
lambda e: e.address == address lambda e: e.address == address
@ -316,7 +321,7 @@ class CommandTestCase(IntegrationTestCase):
self.server_config = None self.server_config = None
self.server_storage = None self.server_storage = None
self.extra_wallet_nodes = [] self.extra_wallet_nodes = []
self.extra_wallet_node_port = 5280 self.extra_wallet_node_port = 5281
self.server_blob_manager = None self.server_blob_manager = None
self.server = None self.server = None
self.reflector = None self.reflector = None

View file

@ -6,6 +6,7 @@ __node_url__ = (
) )
__spvserver__ = 'lbry.wallet.server.coin.LBCRegTest' __spvserver__ = 'lbry.wallet.server.coin.LBCRegTest'
from .bip32 import PubKey
from .wallet import Wallet, WalletStorage, TimestampedPreferences, ENCRYPT_ON_DISK from .wallet import Wallet, WalletStorage, TimestampedPreferences, ENCRYPT_ON_DISK
from .manager import WalletManager from .manager import WalletManager
from .network import Network from .network import Network
@ -13,5 +14,4 @@ from .ledger import Ledger, RegTestLedger, TestNetLedger, BlockHeightEvent
from .account import Account, AddressManager, SingleKey, HierarchicalDeterministic from .account import Account, AddressManager, SingleKey, HierarchicalDeterministic
from .transaction import Transaction, Output, Input from .transaction import Transaction, Output, Input
from .script import OutputScript, InputScript from .script import OutputScript, InputScript
from .database import SQLiteMixin, Database
from .header import Headers from .header import Headers

View file

@ -10,6 +10,8 @@ from hashlib import sha256
from string import hexdigits from string import hexdigits
from typing import Type, Dict, Tuple, Optional, Any, List from typing import Type, Dict, Tuple, Optional, Any, List
from sqlalchemy import text
import ecdsa import ecdsa
from lbry.error import InvalidPasswordError from lbry.error import InvalidPasswordError
from lbry.crypto.crypt import aes_encrypt, aes_decrypt from lbry.crypto.crypt import aes_encrypt, aes_decrypt
@ -71,7 +73,6 @@ class AddressManager:
def _query_addresses(self, **constraints): def _query_addresses(self, **constraints):
return self.account.ledger.db.get_addresses( return self.account.ledger.db.get_addresses(
read_only=constraints.pop("read_only", False),
accounts=[self.account], accounts=[self.account],
chain=self.chain_number, chain=self.chain_number,
**constraints **constraints
@ -435,8 +436,8 @@ class Account:
addresses.extend(new_addresses) addresses.extend(new_addresses)
return addresses return addresses
async def get_addresses(self, read_only=False, **constraints) -> List[str]: async def get_addresses(self, **constraints) -> List[str]:
rows = await self.ledger.db.select_addresses('address', read_only=read_only, accounts=[self], **constraints) rows = await self.ledger.db.select_addresses([text('account_address.address')], accounts=[self], **constraints)
return [r['address'] for r in rows] return [r['address'] for r in rows]
def get_address_records(self, **constraints): def get_address_records(self, **constraints):
@ -452,13 +453,13 @@ class Account:
def get_public_key(self, chain: int, index: int) -> PubKey: def get_public_key(self, chain: int, index: int) -> PubKey:
return self.address_managers[chain].get_public_key(index) return self.address_managers[chain].get_public_key(index)
def get_balance(self, confirmations=0, include_claims=False, read_only=False, **constraints): def get_balance(self, confirmations=0, include_claims=False, **constraints):
if not include_claims: if not include_claims:
constraints.update({'txo_type__in': (TXO_TYPES['other'], TXO_TYPES['purchase'])}) constraints.update({'txo_type__in': (TXO_TYPES['other'], TXO_TYPES['purchase'])})
if confirmations > 0: if confirmations > 0:
height = self.ledger.headers.height - (confirmations-1) height = self.ledger.headers.height - (confirmations-1)
constraints.update({'height__lte': height, 'height__gt': 0}) constraints.update({'height__lte': height, 'height__gt': 0})
return self.ledger.db.get_balance(accounts=[self], read_only=read_only, **constraints) return self.ledger.db.get_balance(accounts=[self], **constraints)
async def get_max_gap(self): async def get_max_gap(self):
change_gap = await self.change.get_max_gap() change_gap = await self.change.get_max_gap()
@ -564,9 +565,9 @@ class Account:
if gap_changed: if gap_changed:
self.wallet.save() self.wallet.save()
async def get_detailed_balance(self, confirmations=0, reserved_subtotals=False, read_only=False): async def get_detailed_balance(self, confirmations=0, reserved_subtotals=False):
tips_balance, supports_balance, claims_balance = 0, 0, 0 tips_balance, supports_balance, claims_balance = 0, 0, 0
get_total_balance = partial(self.get_balance, read_only=read_only, confirmations=confirmations, get_total_balance = partial(self.get_balance, confirmations=confirmations,
include_claims=True) include_claims=True)
total = await get_total_balance() total = await get_total_balance()
if reserved_subtotals: if reserved_subtotals:
@ -594,14 +595,14 @@ class Account:
} if reserved_subtotals else None } if reserved_subtotals else None
} }
def get_transaction_history(self, read_only=False, **constraints): def get_transaction_history(self, **constraints):
return self.ledger.get_transaction_history( return self.ledger.get_transaction_history(
read_only=read_only, wallet=self.wallet, accounts=[self], **constraints wallet=self.wallet, accounts=[self], **constraints
) )
def get_transaction_history_count(self, read_only=False, **constraints): def get_transaction_history_count(self, **constraints):
return self.ledger.get_transaction_history_count( return self.ledger.get_transaction_history_count(
read_only=read_only, wallet=self.wallet, accounts=[self], **constraints wallet=self.wallet, accounts=[self], **constraints
) )
def get_claims(self, **constraints): def get_claims(self, **constraints):

View file

@ -9,12 +9,6 @@ from contextvars import ContextVar
from concurrent.futures.thread import ThreadPoolExecutor from concurrent.futures.thread import ThreadPoolExecutor
from concurrent.futures.process import ProcessPoolExecutor from concurrent.futures.process import ProcessPoolExecutor
from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict, Optional from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict, Optional
from datetime import date
from .bip32 import PubKey
from .transaction import Transaction, Output, OutputScript, TXRefImmutable
from .constants import TXO_TYPES, CLAIM_TYPES
from .util import date_to_julian_day
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -389,697 +383,3 @@ def dict_row_factory(cursor, row):
for idx, col in enumerate(cursor.description): for idx, col in enumerate(cursor.description):
d[col[0]] = row[idx] d[col[0]] = row[idx]
return d return d
class Database(SQLiteMixin):
SCHEMA_VERSION = "1.3"
PRAGMAS = """
pragma journal_mode=WAL;
"""
CREATE_ACCOUNT_TABLE = """
create table if not exists account_address (
account text not null,
address text not null,
chain integer not null,
pubkey blob not null,
chain_code blob not null,
n integer not null,
depth integer not null,
primary key (account, address)
);
create index if not exists address_account_idx on account_address (address, account);
"""
CREATE_PUBKEY_ADDRESS_TABLE = """
create table if not exists pubkey_address (
address text primary key,
history text,
used_times integer not null default 0
);
"""
CREATE_TX_TABLE = """
create table if not exists tx (
txid text primary key,
raw blob not null,
height integer not null,
position integer not null,
is_verified boolean not null default 0,
purchased_claim_id text,
day integer
);
create index if not exists tx_purchased_claim_id_idx on tx (purchased_claim_id);
"""
CREATE_TXO_TABLE = """
create table if not exists txo (
txid text references tx,
txoid text primary key,
address text references pubkey_address,
position integer not null,
amount integer not null,
script blob not null,
is_reserved boolean not null default 0,
txo_type integer not null default 0,
claim_id text,
claim_name text,
channel_id text,
reposted_claim_id text
);
create index if not exists txo_txid_idx on txo (txid);
create index if not exists txo_address_idx on txo (address);
create index if not exists txo_claim_id_idx on txo (claim_id, txo_type);
create index if not exists txo_claim_name_idx on txo (claim_name);
create index if not exists txo_txo_type_idx on txo (txo_type);
create index if not exists txo_channel_id_idx on txo (channel_id);
create index if not exists txo_reposted_claim_idx on txo (reposted_claim_id);
"""
CREATE_TXI_TABLE = """
create table if not exists txi (
txid text references tx,
txoid text references txo primary key,
address text references pubkey_address,
position integer not null
);
create index if not exists txi_address_idx on txi (address);
create index if not exists first_input_idx on txi (txid, address) where position=0;
"""
CREATE_TABLES_QUERY = (
PRAGMAS +
CREATE_ACCOUNT_TABLE +
CREATE_PUBKEY_ADDRESS_TABLE +
CREATE_TX_TABLE +
CREATE_TXO_TABLE +
CREATE_TXI_TABLE
)
async def open(self):
await super().open()
self.db.writer_connection.row_factory = dict_row_factory
def txo_to_row(self, tx, txo):
row = {
'txid': tx.id,
'txoid': txo.id,
'address': txo.get_address(self.ledger),
'position': txo.position,
'amount': txo.amount,
'script': sqlite3.Binary(txo.script.source)
}
if txo.is_claim:
if txo.can_decode_claim:
claim = txo.claim
row['txo_type'] = TXO_TYPES.get(claim.claim_type, TXO_TYPES['stream'])
if claim.is_repost:
row['reposted_claim_id'] = claim.repost.reference.claim_id
if claim.is_signed:
row['channel_id'] = claim.signing_channel_id
else:
row['txo_type'] = TXO_TYPES['stream']
elif txo.is_support:
row['txo_type'] = TXO_TYPES['support']
elif txo.purchase is not None:
row['txo_type'] = TXO_TYPES['purchase']
row['claim_id'] = txo.purchased_claim_id
if txo.script.is_claim_involved:
row['claim_id'] = txo.claim_id
row['claim_name'] = txo.claim_name
return row
def tx_to_row(self, tx):
row = {
'txid': tx.id,
'raw': sqlite3.Binary(tx.raw),
'height': tx.height,
'position': tx.position,
'is_verified': tx.is_verified,
'day': tx.get_julian_day(self.ledger),
}
txos = tx.outputs
if len(txos) >= 2 and txos[1].can_decode_purchase_data:
txos[0].purchase = txos[1]
row['purchased_claim_id'] = txos[1].purchase_data.claim_id
return row
async def insert_transaction(self, tx):
await self.db.execute_fetchall(*self._insert_sql('tx', self.tx_to_row(tx)))
async def update_transaction(self, tx):
await self.db.execute_fetchall(*self._update_sql("tx", {
'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
}, 'txid = ?', (tx.id,)))
def _transaction_io(self, conn: sqlite3.Connection, tx: Transaction, address, txhash):
conn.execute(*self._insert_sql('tx', self.tx_to_row(tx), replace=True)).fetchall()
is_my_input = False
for txi in tx.inputs:
if txi.txo_ref.txo is not None:
txo = txi.txo_ref.txo
if txo.has_address and txo.get_address(self.ledger) == address:
is_my_input = True
conn.execute(*self._insert_sql("txi", {
'txid': tx.id,
'txoid': txo.id,
'address': address,
'position': txi.position
}, ignore_duplicate=True)).fetchall()
for txo in tx.outputs:
if txo.script.is_pay_pubkey_hash and (txo.pubkey_hash == txhash or is_my_input):
conn.execute(*self._insert_sql(
"txo", self.txo_to_row(tx, txo), ignore_duplicate=True
)).fetchall()
elif txo.script.is_pay_script_hash:
# TODO: implement script hash payments
log.warning('Database.save_transaction_io: pay script hash is not implemented!')
def save_transaction_io(self, tx: Transaction, address, txhash, history):
return self.save_transaction_io_batch([tx], address, txhash, history)
def save_transaction_io_batch(self, txs: Iterable[Transaction], address, txhash, history):
history_count = history.count(':') // 2
def __many(conn):
for tx in txs:
self._transaction_io(conn, tx, address, txhash)
conn.execute(
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
(history, history_count, address)
).fetchall()
return self.db.run(__many)
async def reserve_outputs(self, txos, is_reserved=True):
txoids = ((is_reserved, txo.id) for txo in txos)
await self.db.executemany("UPDATE txo SET is_reserved = ? WHERE txoid = ?", txoids)
async def release_outputs(self, txos):
await self.reserve_outputs(txos, is_reserved=False)
async def rewind_blockchain(self, above_height): # pylint: disable=no-self-use
# TODO:
# 1. delete transactions above_height
# 2. update address histories removing deleted TXs
return True
async def select_transactions(self, cols, accounts=None, read_only=False, **constraints):
if not {'txid', 'txid__in'}.intersection(constraints):
assert accounts, "'accounts' argument required when no 'txid' constraint is present"
where, values = constraints_to_sql({
'$$account_address.account__in': [a.public_key.address for a in accounts]
})
constraints['txid__in'] = f"""
SELECT txo.txid FROM txo JOIN account_address USING (address) WHERE {where}
UNION
SELECT txi.txid FROM txi JOIN account_address USING (address) WHERE {where}
"""
constraints.update(values)
return await self.db.execute_fetchall(
*query(f"SELECT {cols} FROM tx", **constraints), read_only=read_only
)
TXO_NOT_MINE = Output(None, None, is_my_output=False)
async def get_transactions(self, wallet=None, **constraints):
include_is_spent = constraints.pop('include_is_spent', False)
include_is_my_input = constraints.pop('include_is_my_input', False)
include_is_my_output = constraints.pop('include_is_my_output', False)
tx_rows = await self.select_transactions(
'txid, raw, height, position, is_verified',
order_by=constraints.pop('order_by', ["height=0 DESC", "height DESC", "position DESC"]),
**constraints
)
if not tx_rows:
return []
txids, txs, txi_txoids = [], [], []
for row in tx_rows:
txids.append(row['txid'])
txs.append(Transaction(
raw=row['raw'], height=row['height'], position=row['position'],
is_verified=bool(row['is_verified'])
))
for txi in txs[-1].inputs:
txi_txoids.append(txi.txo_ref.id)
step = self.MAX_QUERY_VARIABLES
annotated_txos = {}
for offset in range(0, len(txids), step):
annotated_txos.update({
txo.id: txo for txo in
(await self.get_txos(
wallet=wallet,
txid__in=txids[offset:offset+step], order_by='txo.txid',
include_is_spent=include_is_spent,
include_is_my_input=include_is_my_input,
include_is_my_output=include_is_my_output,
))
})
referenced_txos = {}
for offset in range(0, len(txi_txoids), step):
referenced_txos.update({
txo.id: txo for txo in
(await self.get_txos(
wallet=wallet,
txoid__in=txi_txoids[offset:offset+step], order_by='txo.txoid',
include_is_my_output=include_is_my_output,
))
})
for tx in txs:
for txi in tx.inputs:
txo = referenced_txos.get(txi.txo_ref.id)
if txo:
txi.txo_ref = txo.ref
for txo in tx.outputs:
_txo = annotated_txos.get(txo.id)
if _txo:
txo.update_annotations(_txo)
else:
txo.update_annotations(self.TXO_NOT_MINE)
for tx in txs:
txos = tx.outputs
if len(txos) >= 2 and txos[1].can_decode_purchase_data:
txos[0].purchase = txos[1]
return txs
async def get_transaction_count(self, **constraints):
constraints.pop('wallet', None)
constraints.pop('offset', None)
constraints.pop('limit', None)
constraints.pop('order_by', None)
count = await self.select_transactions('COUNT(*) as total', **constraints)
return count[0]['total'] or 0
async def get_transaction(self, **constraints):
txs = await self.get_transactions(limit=1, **constraints)
if txs:
return txs[0]
async def select_txos(
self, cols, accounts=None, is_my_input=None, is_my_output=True,
is_my_input_or_output=None, exclude_internal_transfers=False,
include_is_spent=False, include_is_my_input=False,
is_spent=None, read_only=False, **constraints):
for rename_col in ('txid', 'txoid'):
for rename_constraint in (rename_col, rename_col+'__in', rename_col+'__not_in'):
if rename_constraint in constraints:
constraints['txo.'+rename_constraint] = constraints.pop(rename_constraint)
if accounts:
account_in_sql, values = constraints_to_sql({
'$$account__in': [a.public_key.address for a in accounts]
})
my_addresses = f"SELECT address FROM account_address WHERE {account_in_sql}"
constraints.update(values)
if is_my_input_or_output:
include_is_my_input = True
constraints['received_or_sent__or'] = {
'txo.address__in': my_addresses,
'sent__and': {
'txi.address__is_not_null': True,
'txi.address__in': my_addresses
}
}
else:
if is_my_output:
constraints['txo.address__in'] = my_addresses
elif is_my_output is False:
constraints['txo.address__not_in'] = my_addresses
if is_my_input:
include_is_my_input = True
constraints['txi.address__is_not_null'] = True
constraints['txi.address__in'] = my_addresses
elif is_my_input is False:
include_is_my_input = True
constraints['is_my_input_false__or'] = {
'txi.address__is_null': True,
'txi.address__not_in': my_addresses
}
if exclude_internal_transfers:
include_is_my_input = True
constraints['exclude_internal_payments__or'] = {
'txo.txo_type__not': TXO_TYPES['other'],
'txi.address__is_null': True,
'txi.address__not_in': my_addresses
}
sql = [f"SELECT {cols} FROM txo JOIN tx ON (tx.txid=txo.txid)"]
if is_spent:
constraints['spent.txoid__is_not_null'] = True
elif is_spent is False:
constraints['is_reserved'] = False
constraints['spent.txoid__is_null'] = True
if include_is_spent or is_spent is not None:
sql.append("LEFT JOIN txi AS spent ON (spent.txoid=txo.txoid)")
if include_is_my_input:
sql.append("LEFT JOIN txi ON (txi.position=0 AND txi.txid=txo.txid)")
return await self.db.execute_fetchall(*query(' '.join(sql), **constraints), read_only=read_only)
async def get_txos(self, wallet=None, no_tx=False, read_only=False, **constraints):
include_is_spent = constraints.get('include_is_spent', False)
include_is_my_input = constraints.get('include_is_my_input', False)
include_is_my_output = constraints.pop('include_is_my_output', False)
include_received_tips = constraints.pop('include_received_tips', False)
select_columns = [
"tx.txid, raw, tx.height, tx.position as tx_position, tx.is_verified, "
"txo_type, txo.position as txo_position, amount, script"
]
my_accounts = {a.public_key.address for a in wallet.accounts} if wallet else set()
my_accounts_sql = ""
if include_is_my_output or include_is_my_input:
my_accounts_sql, values = constraints_to_sql({'$$account__in#_wallet': my_accounts})
constraints.update(values)
if include_is_my_output and my_accounts:
if constraints.get('is_my_output', None) in (True, False):
select_columns.append(f"{1 if constraints['is_my_output'] else 0} AS is_my_output")
else:
select_columns.append(f"""(
txo.address IN (SELECT address FROM account_address WHERE {my_accounts_sql})
) AS is_my_output""")
if include_is_my_input and my_accounts:
if constraints.get('is_my_input', None) in (True, False):
select_columns.append(f"{1 if constraints['is_my_input'] else 0} AS is_my_input")
else:
select_columns.append(f"""(
txi.address IS NOT NULL AND
txi.address IN (SELECT address FROM account_address WHERE {my_accounts_sql})
) AS is_my_input""")
if include_is_spent:
select_columns.append("spent.txoid IS NOT NULL AS is_spent")
if include_received_tips:
select_columns.append(f"""(
SELECT COALESCE(SUM(support.amount), 0) FROM txo AS support WHERE
support.claim_id = txo.claim_id AND
support.txo_type = {TXO_TYPES['support']} AND
support.address IN (SELECT address FROM account_address WHERE {my_accounts_sql}) AND
support.txoid NOT IN (SELECT txoid FROM txi)
) AS received_tips""")
if 'order_by' not in constraints or constraints['order_by'] == 'height':
constraints['order_by'] = [
"tx.height=0 DESC", "tx.height DESC", "tx.position DESC", "txo.position"
]
elif constraints.get('order_by', None) == 'none':
del constraints['order_by']
rows = await self.select_txos(', '.join(select_columns), read_only=read_only, **constraints)
txos = []
txs = {}
for row in rows:
if no_tx:
txo = Output(
amount=row['amount'],
script=OutputScript(row['script']),
tx_ref=TXRefImmutable.from_id(row['txid'], row['height']),
position=row['txo_position']
)
else:
if row['txid'] not in txs:
txs[row['txid']] = Transaction(
row['raw'], height=row['height'], position=row['tx_position'],
is_verified=bool(row['is_verified'])
)
txo = txs[row['txid']].outputs[row['txo_position']]
if include_is_spent:
txo.is_spent = bool(row['is_spent'])
if include_is_my_input:
txo.is_my_input = bool(row['is_my_input'])
if include_is_my_output:
txo.is_my_output = bool(row['is_my_output'])
if include_is_my_input and include_is_my_output:
if txo.is_my_input and txo.is_my_output and row['txo_type'] == TXO_TYPES['other']:
txo.is_internal_transfer = True
else:
txo.is_internal_transfer = False
if include_received_tips:
txo.received_tips = row['received_tips']
txos.append(txo)
channel_ids = set()
for txo in txos:
if txo.is_claim and txo.can_decode_claim:
if txo.claim.is_signed:
channel_ids.add(txo.claim.signing_channel_id)
if txo.claim.is_channel and wallet:
for account in wallet.accounts:
private_key = account.get_channel_private_key(
txo.claim.channel.public_key_bytes
)
if private_key:
txo.private_key = private_key
break
if channel_ids:
channels = {
txo.claim_id: txo for txo in
(await self.get_channels(
wallet=wallet,
claim_id__in=channel_ids,
read_only=read_only
))
}
for txo in txos:
if txo.is_claim and txo.can_decode_claim:
txo.channel = channels.get(txo.claim.signing_channel_id, None)
return txos
@staticmethod
def _clean_txo_constraints_for_aggregation(constraints):
constraints.pop('include_is_spent', None)
constraints.pop('include_is_my_input', None)
constraints.pop('include_is_my_output', None)
constraints.pop('include_received_tips', None)
constraints.pop('wallet', None)
constraints.pop('resolve', None)
constraints.pop('offset', None)
constraints.pop('limit', None)
constraints.pop('order_by', None)
async def get_txo_count(self, **constraints):
self._clean_txo_constraints_for_aggregation(constraints)
count = await self.select_txos('COUNT(*) AS total', **constraints)
return count[0]['total'] or 0
async def get_txo_sum(self, **constraints):
self._clean_txo_constraints_for_aggregation(constraints)
result = await self.select_txos('SUM(amount) AS total', **constraints)
return result[0]['total'] or 0
async def get_txo_plot(self, start_day=None, days_back=0, end_day=None, days_after=None, **constraints):
self._clean_txo_constraints_for_aggregation(constraints)
if start_day is None:
constraints['day__gte'] = self.ledger.headers.estimated_julian_day(
self.ledger.headers.height
) - days_back
else:
constraints['day__gte'] = date_to_julian_day(
date.fromisoformat(start_day)
)
if end_day is not None:
constraints['day__lte'] = date_to_julian_day(
date.fromisoformat(end_day)
)
elif days_after is not None:
constraints['day__lte'] = constraints['day__gte'] + days_after
return await self.select_txos(
"DATE(day) AS day, SUM(amount) AS total",
group_by='day', order_by='day', **constraints
)
def get_utxos(self, read_only=False, **constraints):
return self.get_txos(is_spent=False, read_only=read_only, **constraints)
def get_utxo_count(self, **constraints):
return self.get_txo_count(is_spent=False, **constraints)
async def get_balance(self, wallet=None, accounts=None, read_only=False, **constraints):
assert wallet or accounts, \
"'wallet' or 'accounts' constraints required to calculate balance"
constraints['accounts'] = accounts or wallet.accounts
balance = await self.select_txos(
'SUM(amount) as total', is_spent=False, read_only=read_only, **constraints
)
return balance[0]['total'] or 0
async def select_addresses(self, cols, read_only=False, **constraints):
return await self.db.execute_fetchall(*query(
f"SELECT {cols} FROM pubkey_address JOIN account_address USING (address)",
**constraints
), read_only=read_only)
async def get_addresses(self, cols=None, read_only=False, **constraints):
cols = cols or (
'address', 'account', 'chain', 'history', 'used_times',
'pubkey', 'chain_code', 'n', 'depth'
)
addresses = await self.select_addresses(', '.join(cols), read_only=read_only, **constraints)
if 'pubkey' in cols:
for address in addresses:
address['pubkey'] = PubKey(
self.ledger, address.pop('pubkey'), address.pop('chain_code'),
address.pop('n'), address.pop('depth')
)
return addresses
async def get_address_count(self, cols=None, read_only=False, **constraints):
count = await self.select_addresses('COUNT(*) as total', read_only=read_only, **constraints)
return count[0]['total'] or 0
async def get_address(self, read_only=False, **constraints):
addresses = await self.get_addresses(read_only=read_only, limit=1, **constraints)
if addresses:
return addresses[0]
async def add_keys(self, account, chain, pubkeys):
await self.db.executemany(
"insert or ignore into account_address "
"(account, address, chain, pubkey, chain_code, n, depth) values "
"(?, ?, ?, ?, ?, ?, ?)", ((
account.id, k.address, chain,
sqlite3.Binary(k.pubkey_bytes),
sqlite3.Binary(k.chain_code),
k.n, k.depth
) for k in pubkeys)
)
await self.db.executemany(
"insert or ignore into pubkey_address (address) values (?)",
((pubkey.address,) for pubkey in pubkeys)
)
async def _set_address_history(self, address, history):
await self.db.execute_fetchall(
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
(history, history.count(':')//2, address)
)
async def set_address_history(self, address, history):
await self._set_address_history(address, history)
@staticmethod
def constrain_purchases(constraints):
accounts = constraints.pop('accounts', None)
assert accounts, "'accounts' argument required to find purchases"
if not {'purchased_claim_id', 'purchased_claim_id__in'}.intersection(constraints):
constraints['purchased_claim_id__is_not_null'] = True
constraints.update({
f'$account{i}': a.public_key.address for i, a in enumerate(accounts)
})
account_values = ', '.join([f':$account{i}' for i in range(len(accounts))])
constraints['txid__in'] = f"""
SELECT txid FROM txi JOIN account_address USING (address)
WHERE account_address.account IN ({account_values})
"""
async def get_purchases(self, **constraints):
self.constrain_purchases(constraints)
return [tx.outputs[0] for tx in await self.get_transactions(**constraints)]
def get_purchase_count(self, **constraints):
self.constrain_purchases(constraints)
return self.get_transaction_count(**constraints)
@staticmethod
def constrain_claims(constraints):
if {'txo_type', 'txo_type__in'}.intersection(constraints):
return
claim_types = constraints.pop('claim_type', None)
if claim_types:
constrain_single_or_list(
constraints, 'txo_type', claim_types, lambda x: TXO_TYPES[x]
)
else:
constraints['txo_type__in'] = CLAIM_TYPES
async def get_claims(self, read_only=False, **constraints) -> List[Output]:
self.constrain_claims(constraints)
return await self.get_utxos(read_only=read_only, **constraints)
def get_claim_count(self, **constraints):
self.constrain_claims(constraints)
return self.get_utxo_count(**constraints)
@staticmethod
def constrain_streams(constraints):
constraints['txo_type'] = TXO_TYPES['stream']
def get_streams(self, read_only=False, **constraints):
self.constrain_streams(constraints)
return self.get_claims(read_only=read_only, **constraints)
def get_stream_count(self, **constraints):
self.constrain_streams(constraints)
return self.get_claim_count(**constraints)
@staticmethod
def constrain_channels(constraints):
constraints['txo_type'] = TXO_TYPES['channel']
def get_channels(self, **constraints):
self.constrain_channels(constraints)
return self.get_claims(**constraints)
def get_channel_count(self, **constraints):
self.constrain_channels(constraints)
return self.get_claim_count(**constraints)
@staticmethod
def constrain_supports(constraints):
constraints['txo_type'] = TXO_TYPES['support']
def get_supports(self, **constraints):
self.constrain_supports(constraints)
return self.get_utxos(**constraints)
def get_support_count(self, **constraints):
self.constrain_supports(constraints)
return self.get_utxo_count(**constraints)
@staticmethod
def constrain_collections(constraints):
constraints['txo_type'] = TXO_TYPES['collection']
def get_collections(self, **constraints):
self.constrain_collections(constraints)
return self.get_utxos(**constraints)
def get_collection_count(self, **constraints):
self.constrain_collections(constraints)
return self.get_utxo_count(**constraints)
async def release_all_outputs(self, account):
await self.db.execute_fetchall(
"UPDATE txo SET is_reserved = 0 WHERE"
" is_reserved = 1 AND txo.address IN ("
" SELECT address from account_address WHERE account = ?"
" )", (account.public_key.address, )
)
def get_supports_summary(self, read_only=False, **constraints):
return self.get_txos(
txo_type=TXO_TYPES['support'],
is_spent=False, is_my_output=True,
include_is_my_input=True,
no_tx=True, read_only=read_only,
**constraints
)

View file

@ -12,7 +12,7 @@ from typing import Optional, Iterator, Tuple, Callable
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from lbry.crypto.hash import sha512, double_sha256, ripemd160 from lbry.crypto.hash import sha512, double_sha256, ripemd160
from lbry.wallet.util import ArithUint256, date_to_julian_day from lbry.wallet.util import ArithUint256
from .checkpoints import HASHES from .checkpoints import HASHES
@ -140,8 +140,8 @@ class Headers:
return return
return int(self.first_block_timestamp + (height * self.timestamp_average_offset)) return int(self.first_block_timestamp + (height * self.timestamp_average_offset))
def estimated_julian_day(self, height): def estimated_date(self, height):
return date_to_julian_day(date.fromtimestamp(self.estimated_timestamp(height))) return date.fromtimestamp(self.estimated_timestamp(height))
async def get_raw_header(self, height) -> bytes: async def get_raw_header(self, height) -> bytes:
if self.chunk_getter: if self.chunk_getter:

View file

@ -16,8 +16,9 @@ from lbry.schema.url import URL
from lbry.crypto.hash import hash160, double_sha256, sha256 from lbry.crypto.hash import hash160, double_sha256, sha256
from lbry.crypto.base58 import Base58 from lbry.crypto.base58 import Base58
from lbry.db import Database, AccountAddress
from .tasks import TaskGroup from .tasks import TaskGroup
from .database import Database
from .stream import StreamController from .stream import StreamController
from .dewies import dewies_to_lbc from .dewies import dewies_to_lbc
from .account import Account, AddressManager, SingleKey from .account import Account, AddressManager, SingleKey
@ -508,7 +509,7 @@ class Ledger(metaclass=LedgerRegistry):
else: else:
check_local = (txid, remote_height) not in we_need check_local = (txid, remote_height) not in we_need
cache_tasks.append(loop.create_task( cache_tasks.append(loop.create_task(
self.cache_transaction(txid, remote_height, check_local=check_local) self.cache_transaction(unhexlify(txid)[::-1], remote_height, check_local=check_local)
)) ))
synced_txs = [] synced_txs = []
@ -519,18 +520,18 @@ class Ledger(metaclass=LedgerRegistry):
for txi in tx.inputs: for txi in tx.inputs:
if txi.txo_ref.txo is not None: if txi.txo_ref.txo is not None:
continue continue
cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.id) cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.hash)
if cache_item is not None: if cache_item is not None:
if cache_item.tx is None: if cache_item.tx is None:
await cache_item.has_tx.wait() await cache_item.has_tx.wait()
assert cache_item.tx is not None assert cache_item.tx is not None
txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
else: else:
check_db_for_txos.append(txi.txo_ref.id) check_db_for_txos.append(txi.txo_ref.hash)
referenced_txos = {} if not check_db_for_txos else { referenced_txos = {} if not check_db_for_txos else {
txo.id: txo for txo in await self.db.get_txos( txo.id: txo for txo in await self.db.get_txos(
txoid__in=check_db_for_txos, order_by='txo.txoid', no_tx=True txo_hash__in=check_db_for_txos, order_by='txo.txo_hash', no_tx=True
) )
} }
@ -574,10 +575,10 @@ class Ledger(metaclass=LedgerRegistry):
else: else:
return True return True
async def cache_transaction(self, txid, remote_height, check_local=True): async def cache_transaction(self, tx_hash, remote_height, check_local=True):
cache_item = self._tx_cache.get(txid) cache_item = self._tx_cache.get(tx_hash)
if cache_item is None: if cache_item is None:
cache_item = self._tx_cache[txid] = TransactionCacheItem() cache_item = self._tx_cache[tx_hash] = TransactionCacheItem()
elif cache_item.tx is not None and \ elif cache_item.tx is not None and \
cache_item.tx.height >= remote_height and \ cache_item.tx.height >= remote_height and \
(cache_item.tx.is_verified or remote_height < 1): (cache_item.tx.is_verified or remote_height < 1):
@ -585,11 +586,11 @@ class Ledger(metaclass=LedgerRegistry):
try: try:
cache_item.pending_verifications += 1 cache_item.pending_verifications += 1
return await self._update_cache_item(cache_item, txid, remote_height, check_local) return await self._update_cache_item(cache_item, tx_hash, remote_height, check_local)
finally: finally:
cache_item.pending_verifications -= 1 cache_item.pending_verifications -= 1
async def _update_cache_item(self, cache_item, txid, remote_height, check_local=True): async def _update_cache_item(self, cache_item, tx_hash, remote_height, check_local=True):
async with cache_item.lock: async with cache_item.lock:
@ -597,13 +598,13 @@ class Ledger(metaclass=LedgerRegistry):
if tx is None and check_local: if tx is None and check_local:
# check local db # check local db
tx = cache_item.tx = await self.db.get_transaction(txid=txid) tx = cache_item.tx = await self.db.get_transaction(tx_hash=tx_hash)
merkle = None merkle = None
if tx is None: if tx is None:
# fetch from network # fetch from network
_raw, merkle = await self.network.retriable_call( _raw, merkle = await self.network.retriable_call(
self.network.get_transaction_and_merkle, txid, remote_height self.network.get_transaction_and_merkle, tx_hash, remote_height
) )
tx = Transaction(unhexlify(_raw), height=merkle.get('block_height')) tx = Transaction(unhexlify(_raw), height=merkle.get('block_height'))
cache_item.tx = tx # make sure it's saved before caching it cache_item.tx = tx # make sure it's saved before caching it
@ -612,16 +613,16 @@ class Ledger(metaclass=LedgerRegistry):
async def maybe_verify_transaction(self, tx, remote_height, merkle=None): async def maybe_verify_transaction(self, tx, remote_height, merkle=None):
tx.height = remote_height tx.height = remote_height
cached = self._tx_cache.get(tx.id) cached = self._tx_cache.get(tx.hash)
if not cached: if not cached:
# cache txs looked up by transaction_show too # cache txs looked up by transaction_show too
cached = TransactionCacheItem() cached = TransactionCacheItem()
cached.tx = tx cached.tx = tx
self._tx_cache[tx.id] = cached self._tx_cache[tx.hash] = cached
if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1: if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1:
# can't be tx.pending_verifications == 1 because we have to handle the transaction_show case # can't be tx.pending_verifications == 1 because we have to handle the transaction_show case
if not merkle: if not merkle:
merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height) merkle = await self.network.retriable_call(self.network.get_merkle, tx.hash, remote_height)
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
header = await self.headers.get(remote_height) header = await self.headers.get(remote_height)
tx.position = merkle['pos'] tx.position = merkle['pos']
@ -703,7 +704,7 @@ class Ledger(metaclass=LedgerRegistry):
txo.purchased_claim_id: txo for txo in txo.purchased_claim_id: txo for txo in
await self.db.get_purchases( await self.db.get_purchases(
accounts=accounts, accounts=accounts,
purchased_claim_id__in=[c.claim_id for c in priced_claims] purchased_claim_hash__in=[c.claim_hash for c in priced_claims]
) )
} }
for txo in txos: for txo in txos:
@ -808,7 +809,7 @@ class Ledger(metaclass=LedgerRegistry):
async def _reset_balance_cache(self, e: TransactionEvent): async def _reset_balance_cache(self, e: TransactionEvent):
account_ids = [ account_ids = [
r['account'] for r in await self.db.get_addresses(('account',), address=e.address) r['account'] for r in await self.db.get_addresses([AccountAddress.c.account], address=e.address)
] ]
for account_id in account_ids: for account_id in account_ids:
if account_id in self._balance_cache: if account_id in self._balance_cache:
@ -917,10 +918,10 @@ class Ledger(metaclass=LedgerRegistry):
def get_support_count(self, **constraints): def get_support_count(self, **constraints):
return self.db.get_support_count(**constraints) return self.db.get_support_count(**constraints)
async def get_transaction_history(self, read_only=False, **constraints): async def get_transaction_history(self, **constraints):
txs: List[Transaction] = await self.db.get_transactions( txs: List[Transaction] = await self.db.get_transactions(
include_is_my_output=True, include_is_spent=True, include_is_my_output=True, include_is_spent=True,
read_only=read_only, **constraints **constraints
) )
headers = self.headers headers = self.headers
history = [] history = []
@ -1030,8 +1031,8 @@ class Ledger(metaclass=LedgerRegistry):
history.append(item) history.append(item)
return history return history
def get_transaction_history_count(self, read_only=False, **constraints): def get_transaction_history_count(self, **constraints):
return self.db.get_transaction_count(read_only=read_only, **constraints) return self.db.get_transaction_count(**constraints)
async def get_detailed_balance(self, accounts, confirmations=0): async def get_detailed_balance(self, accounts, confirmations=0):
result = { result = {

View file

@ -14,11 +14,11 @@ from .dewies import dewies_to_lbc
from .account import Account from .account import Account
from .ledger import Ledger, LedgerRegistry from .ledger import Ledger, LedgerRegistry
from .transaction import Transaction, Output from .transaction import Transaction, Output
from .database import Database
from .wallet import Wallet, WalletStorage, ENCRYPT_ON_DISK from .wallet import Wallet, WalletStorage, ENCRYPT_ON_DISK
from .rpc.jsonrpc import CodeMessageError from .rpc.jsonrpc import CodeMessageError
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from lbry.db import Database
from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager
@ -109,7 +109,7 @@ class WalletManager:
return self.default_account.ledger return self.default_account.ledger
@property @property
def db(self) -> Database: def db(self) -> 'Database':
return self.ledger.db return self.ledger.db
def check_locked(self): def check_locked(self):
@ -256,12 +256,12 @@ class WalletManager:
def get_unused_address(self): def get_unused_address(self):
return self.default_account.receiving.get_or_create_usable_address() return self.default_account.receiving.get_or_create_usable_address()
async def get_transaction(self, txid: str): async def get_transaction(self, tx_hash: bytes):
tx = await self.db.get_transaction(txid=txid) tx = await self.db.get_transaction(tx_hash=tx_hash)
if tx: if tx:
return tx return tx
try: try:
raw, merkle = await self.ledger.network.get_transaction_and_merkle(txid) raw, merkle = await self.ledger.network.get_transaction_and_merkle(tx_hash)
except CodeMessageError as e: except CodeMessageError as e:
if 'No such mempool or blockchain transaction.' in e.message: if 'No such mempool or blockchain transaction.' in e.message:
return {'success': False, 'code': 404, 'message': 'transaction not found'} return {'success': False, 'code': 404, 'message': 'transaction not found'}

View file

@ -4,6 +4,7 @@ import json
from time import perf_counter from time import perf_counter
from operator import itemgetter from operator import itemgetter
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
from binascii import hexlify
from lbry import __version__ from lbry import __version__
from lbry.error import IncompatibleWalletServerError from lbry.error import IncompatibleWalletServerError
@ -254,20 +255,20 @@ class Network:
def get_transaction(self, tx_hash, known_height=None): def get_transaction(self, tx_hash, known_height=None):
# use any server if its old, otherwise restrict to who gave us the history # use any server if its old, otherwise restrict to who gave us the history
restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10 restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.get', [tx_hash], restricted) return self.rpc('blockchain.transaction.get', [hexlify(tx_hash[::-1]).decode()], restricted)
def get_transaction_and_merkle(self, tx_hash, known_height=None): def get_transaction_and_merkle(self, tx_hash, known_height=None):
# use any server if its old, otherwise restrict to who gave us the history # use any server if its old, otherwise restrict to who gave us the history
restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10 restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.info', [tx_hash], restricted) return self.rpc('blockchain.transaction.info', [hexlify(tx_hash[::-1]).decode()], restricted)
def get_transaction_height(self, tx_hash, known_height=None): def get_transaction_height(self, tx_hash, known_height=None):
restricted = not known_height or 0 > known_height > self.remote_height - 10 restricted = not known_height or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted) return self.rpc('blockchain.transaction.get_height', [hexlify(tx_hash[::-1]).decode()], restricted)
def get_merkle(self, tx_hash, height): def get_merkle(self, tx_hash, height):
restricted = 0 > height > self.remote_height - 10 restricted = 0 > height > self.remote_height - 10
return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted) return self.rpc('blockchain.transaction.get_merkle', [hexlify(tx_hash[::-1]).decode(), height], restricted)
def get_headers(self, height, count=10000, b64=False): def get_headers(self, height, count=10000, b64=False):
restricted = height >= self.remote_height - 100 restricted = height >= self.remote_height - 100

View file

@ -13,6 +13,7 @@ from typing import Type, Optional
import urllib.request import urllib.request
import lbry import lbry
from lbry.db import Database
from lbry.wallet.server.server import Server from lbry.wallet.server.server import Server
from lbry.wallet.server.env import Env from lbry.wallet.server.env import Env
from lbry.wallet import Wallet, Ledger, RegTestLedger, WalletManager, Account, BlockHeightEvent from lbry.wallet import Wallet, Ledger, RegTestLedger, WalletManager, Account, BlockHeightEvent
@ -125,12 +126,24 @@ class WalletNode:
wallet_file_name = os.path.join(wallets_dir, 'my_wallet.json') wallet_file_name = os.path.join(wallets_dir, 'my_wallet.json')
with open(wallet_file_name, 'w') as wallet_file: with open(wallet_file_name, 'w') as wallet_file:
wallet_file.write('{"version": 1, "accounts": []}\n') wallet_file.write('{"version": 1, "accounts": []}\n')
db_driver = os.environ.get('TEST_DB', 'sqlite')
if db_driver == 'sqlite':
db = 'sqlite:///'+os.path.join(self.data_path, self.ledger_class.get_id(), 'blockchain.db')
elif db_driver == 'postgres':
db_name = f'lbry_test_{self.port}'
meta_db = Database(f'postgres:///postgres')
await meta_db.drop(db_name)
await meta_db.create(db_name)
db = f'postgres:///{db_name}'
else:
raise RuntimeError(f"Unsupported database driver: {db_driver}")
self.manager = self.manager_class.from_config({ self.manager = self.manager_class.from_config({
'ledgers': { 'ledgers': {
self.ledger_class.get_id(): { self.ledger_class.get_id(): {
'api_port': self.port, 'api_port': self.port,
'default_servers': [(spv_node.hostname, spv_node.port)], 'default_servers': [(spv_node.hostname, spv_node.port)],
'data_path': self.data_path 'data_path': self.data_path,
'db': Database(db)
} }
}, },
'wallets': [wallet_file_name] 'wallets': [wallet_file_name]

View file

@ -268,6 +268,10 @@ class Output(InputOutput):
def id(self): def id(self):
return self.ref.id return self.ref.id
@property
def hash(self):
return self.ref.hash
@property @property
def pubkey_hash(self): def pubkey_hash(self):
return self.script.values['pubkey_hash'] return self.script.values['pubkey_hash']
@ -477,6 +481,13 @@ class Output(InputOutput):
if self.purchased_claim is not None: if self.purchased_claim is not None:
return self.purchased_claim.claim_id return self.purchased_claim.claim_id
@property
def purchased_claim_hash(self):
if self.purchase is not None:
return self.purchase.purchase_data.claim_hash
if self.purchased_claim is not None:
return self.purchased_claim.claim_hash
@property @property
def has_price(self): def has_price(self):
if self.can_decode_claim: if self.can_decode_claim:
@ -536,9 +547,9 @@ class Transaction:
def hash(self): def hash(self):
return self.ref.hash return self.ref.hash
def get_julian_day(self, ledger): def get_ordinal_day(self, ledger):
if self._day is None and self.height > 0: if self._day is None and self.height > 0:
self._day = ledger.headers.estimated_julian_day(self.height) self._day = ledger.headers.estimated_date(self.height).toordinal()
return self._day return self._day
@property @property

View file

@ -3,10 +3,6 @@ from typing import TypeVar, Sequence, Optional
from .constants import COIN from .constants import COIN
def date_to_julian_day(d):
return d.toordinal() + 1721424.5
def coins_to_satoshis(coins): def coins_to_satoshis(coins):
if not isinstance(coins, str): if not isinstance(coins, str):
raise ValueError("{coins} must be a string") raise ValueError("{coins} must be a string")

View file

@ -0,0 +1,87 @@
import os
import time
import asyncio
import logging
from binascii import unhexlify, hexlify
from random import choice
from lbry.testcase import AsyncioTestCase
from lbry.crypto.base58 import Base58
from lbry.blockchain import Lbrycrd, BlockchainSync
from lbry.db import Database
from lbry.blockchain.block import Block
from lbry.schema.claim import Stream
from lbry.wallet.transaction import Transaction, Output
from lbry.wallet.constants import CENT
from lbry.wallet.bcd_data_stream import BCDataStream
#logging.getLogger('lbry.blockchain').setLevel(logging.DEBUG)
log = logging.getLogger(__name__)
class TestBlockchain(AsyncioTestCase):
async def asyncSetUp(self):
await super().asyncSetUp()
#self.chain = Lbrycrd.temp_regtest()
self.chain = Lbrycrd('/tmp/tmp0429f0ku/', True)#.temp_regtest()
await self.chain.ensure()
await self.chain.start('-maxblockfilesize=8', '-rpcworkqueue=128')
self.addCleanup(self.chain.stop, False)
async def test_block_event(self):
msgs = []
self.chain.subscribe()
self.chain.on_block.listen(lambda e: msgs.append(e['msg']))
res = await self.chain.generate(5)
await self.chain.on_block.where(lambda e: e['msg'] == 4)
self.assertEqual([0, 1, 2, 3, 4], msgs)
self.assertEqual(5, len(res))
self.chain.unsubscribe()
res = await self.chain.generate(2)
self.assertEqual(2, len(res))
await asyncio.sleep(0.1) # give some time to "miss" the new block events
self.chain.subscribe()
res = await self.chain.generate(3)
await self.chain.on_block.where(lambda e: e['msg'] == 9)
self.assertEqual(3, len(res))
self.assertEqual([0, 1, 2, 3, 4, 7, 8, 9], msgs)
async def test_sync(self):
if False:
names = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten']
await self.chain.generate(101)
address = Base58.decode(await self.chain.get_new_address())
for _ in range(190):
tx = Transaction().add_outputs([
Output.pay_claim_name_pubkey_hash(
CENT, f'{choice(names)}{i}',
Stream().update(
title='a claim title',
description='Lorem ipsum '*400,
tags=['crypto', 'health', 'space'],
).claim,
address)
for i in range(1, 20)
])
funded = await self.chain.fund_raw_transaction(hexlify(tx.raw).decode())
signed = await self.chain.sign_raw_transaction_with_wallet(funded['hex'])
await self.chain.send_raw_transaction(signed['hex'])
await self.chain.generate(1)
self.assertEqual(
[(0, 191, 280), (1, 89, 178), (2, 12, 24)],
[(file['file_number'], file['blocks'], file['txs'])
for file in await self.chain.get_block_files()]
)
self.assertEqual(191, len(await self.chain.get_file_details(0)))
db = Database(os.path.join(self.chain.actual_data_dir, 'lbry.db'))
self.addCleanup(db.close)
await db.open()
sync = BlockchainSync(self.chain, use_process_pool=False)
await sync.load_blocks()

View file

@ -1,8 +1,9 @@
import asyncio import asyncio
import lbry
from unittest.mock import Mock from unittest.mock import Mock
from binascii import unhexlify
import lbry
from lbry.wallet.network import Network from lbry.wallet.network import Network
from lbry.wallet.orchstr8.node import SPVNode from lbry.wallet.orchstr8.node import SPVNode
from lbry.wallet.rpc import RPCSession from lbry.wallet.rpc import RPCSession
@ -100,15 +101,15 @@ class ReconnectTests(IntegrationTestCase):
# disconnect and send a new tx, should reconnect and get it # disconnect and send a new tx, should reconnect and get it
self.ledger.network.client.connection_lost(Exception()) self.ledger.network.client.connection_lost(Exception())
self.assertFalse(self.ledger.network.is_connected) self.assertFalse(self.ledger.network.is_connected)
sendtxid = await self.blockchain.send_to_address(address1, 1.1337) tx_hash = unhexlify((await self.blockchain.send_to_address(address1, 1.1337)))[::-1]
await asyncio.wait_for(self.on_transaction_id(sendtxid), 1.0) # mempool await asyncio.wait_for(self.on_transaction_hash(tx_hash), 2.0) # mempool
await self.blockchain.generate(1) await self.blockchain.generate(1)
await self.on_transaction_id(sendtxid) # confirmed await self.on_transaction_hash(tx_hash) # confirmed
self.assertLess(self.ledger.network.client.response_time, 1) # response time properly set lower, we are fine self.assertLess(self.ledger.network.client.response_time, 1) # response time properly set lower, we are fine
await self.assertBalance(self.account, '1.1337') await self.assertBalance(self.account, '1.1337')
# is it real? are we rich!? let me see this tx... # is it real? are we rich!? let me see this tx...
d = self.ledger.network.get_transaction(sendtxid) d = self.ledger.network.get_transaction(tx_hash)
# what's that smoke on my ethernet cable? oh no! # what's that smoke on my ethernet cable? oh no!
master_client = self.ledger.network.client master_client = self.ledger.network.client
self.ledger.network.client.connection_lost(Exception()) self.ledger.network.client.connection_lost(Exception())
@ -117,15 +118,15 @@ class ReconnectTests(IntegrationTestCase):
self.assertIsNone(master_client.response_time) # response time unknown as it failed self.assertIsNone(master_client.response_time) # response time unknown as it failed
# rich but offline? no way, no water, let's retry # rich but offline? no way, no water, let's retry
with self.assertRaisesRegex(ConnectionError, 'connection is not available'): with self.assertRaisesRegex(ConnectionError, 'connection is not available'):
await self.ledger.network.get_transaction(sendtxid) await self.ledger.network.get_transaction(tx_hash)
# * goes to pick some water outside... * time passes by and another donation comes in # * goes to pick some water outside... * time passes by and another donation comes in
sendtxid = await self.blockchain.send_to_address(address1, 42) tx_hash = unhexlify((await self.blockchain.send_to_address(address1, 42)))[::-1]
await self.blockchain.generate(1) await self.blockchain.generate(1)
# (this is just so the test doesn't hang forever if it doesn't reconnect) # (this is just so the test doesn't hang forever if it doesn't reconnect)
if not self.ledger.network.is_connected: if not self.ledger.network.is_connected:
await asyncio.wait_for(self.ledger.network.on_connected.first, timeout=1.0) await asyncio.wait_for(self.ledger.network.on_connected.first, timeout=1.0)
# omg, the burned cable still works! torba is fire proof! # omg, the burned cable still works! torba is fire proof!
await self.ledger.network.get_transaction(sendtxid) await self.ledger.network.get_transaction(tx_hash)
async def test_timeout_then_reconnect(self): async def test_timeout_then_reconnect(self):
# tests that it connects back after some failed attempts # tests that it connects back after some failed attempts

View file

@ -46,13 +46,13 @@ class BasicTransactionTests(IntegrationTestCase):
[self.account], self.account [self.account], self.account
)) ))
await asyncio.wait([self.broadcast(tx) for tx in txs]) await asyncio.wait([self.broadcast(tx) for tx in txs])
await asyncio.wait([self.ledger.wait(tx) for tx in txs]) await asyncio.wait([self.ledger.wait(tx, timeout=2) for tx in txs])
# verify that a previous bug which failed to save TXIs doesn't come back # verify that a previous bug which failed to save TXIs doesn't come back
# this check must happen before generating a new block # this check must happen before generating a new block
self.assertTrue(all([ self.assertTrue(all([
tx.inputs[0].txo_ref.txo is not None tx.inputs[0].txo_ref.txo is not None
for tx in await self.ledger.db.get_transactions(txid__in=[tx.id for tx in txs]) for tx in await self.ledger.db.get_transactions(tx_hash__in=[tx.hash for tx in txs])
])) ]))
await self.blockchain.generate(1) await self.blockchain.generate(1)

View file

@ -1,6 +1,8 @@
import asyncio import asyncio
import json import json
from sqlalchemy import event
from lbry.wallet import ENCRYPT_ON_DISK from lbry.wallet import ENCRYPT_ON_DISK
from lbry.error import InvalidPasswordError from lbry.error import InvalidPasswordError
from lbry.testcase import CommandTestCase from lbry.testcase import CommandTestCase
@ -64,7 +66,14 @@ class WalletCommands(CommandTestCase):
wallet_balance = self.daemon.jsonrpc_wallet_balance wallet_balance = self.daemon.jsonrpc_wallet_balance
ledger = self.ledger ledger = self.ledger
query_count = self.ledger.db.db.query_count
query_count = 0
def catch_queries(*args, **kwargs):
nonlocal query_count
query_count += 1
event.listen(self.ledger.db.engine, "before_cursor_execute", catch_queries)
expected = { expected = {
'total': '20.0', 'total': '20.0',
@ -74,15 +83,14 @@ class WalletCommands(CommandTestCase):
} }
self.assertIsNone(ledger._balance_cache.get(self.account.id)) self.assertIsNone(ledger._balance_cache.get(self.account.id))
query_count += 6
self.assertEqual(await wallet_balance(), expected) self.assertEqual(await wallet_balance(), expected)
self.assertEqual(self.ledger.db.db.query_count, query_count) self.assertEqual(query_count, 6)
self.assertEqual(dict_values_to_lbc(ledger._balance_cache.get(self.account.id))['total'], '10.0') self.assertEqual(dict_values_to_lbc(ledger._balance_cache.get(self.account.id))['total'], '10.0')
self.assertEqual(dict_values_to_lbc(ledger._balance_cache.get(account2.id))['total'], '10.0') self.assertEqual(dict_values_to_lbc(ledger._balance_cache.get(account2.id))['total'], '10.0')
# calling again uses cache # calling again uses cache
self.assertEqual(await wallet_balance(), expected) self.assertEqual(await wallet_balance(), expected)
self.assertEqual(self.ledger.db.db.query_count, query_count) self.assertEqual(query_count, 6)
self.assertEqual(dict_values_to_lbc(ledger._balance_cache.get(self.account.id))['total'], '10.0') self.assertEqual(dict_values_to_lbc(ledger._balance_cache.get(self.account.id))['total'], '10.0')
self.assertEqual(dict_values_to_lbc(ledger._balance_cache.get(account2.id))['total'], '10.0') self.assertEqual(dict_values_to_lbc(ledger._balance_cache.get(account2.id))['total'], '10.0')
@ -96,12 +104,11 @@ class WalletCommands(CommandTestCase):
'reserved_subtotals': {'claims': '1.0', 'supports': '0.0', 'tips': '0.0'} 'reserved_subtotals': {'claims': '1.0', 'supports': '0.0', 'tips': '0.0'}
} }
# on_transaction event reset balance cache # on_transaction event reset balance cache
query_count = self.ledger.db.db.query_count query_count = 0
self.assertEqual(await wallet_balance(), expected) self.assertEqual(await wallet_balance(), expected)
query_count += 3 # only one of the accounts changed
self.assertEqual(dict_values_to_lbc(ledger._balance_cache.get(self.account.id))['total'], '9.979893') self.assertEqual(dict_values_to_lbc(ledger._balance_cache.get(self.account.id))['total'], '9.979893')
self.assertEqual(dict_values_to_lbc(ledger._balance_cache.get(account2.id))['total'], '10.0') self.assertEqual(dict_values_to_lbc(ledger._balance_cache.get(account2.id))['total'], '10.0')
self.assertEqual(self.ledger.db.db.query_count, query_count) self.assertEqual(query_count, 3) # only one of the accounts changed
async def test_granular_balances(self): async def test_granular_balances(self):
account2 = await self.daemon.jsonrpc_account_create("Tip-er") account2 = await self.daemon.jsonrpc_account_create("Tip-er")

View file

@ -10,9 +10,10 @@ from lbry.testcase import get_fake_exchange_rate_manager
from lbry.utils import generate_id from lbry.utils import generate_id
from lbry.error import InsufficientFundsError from lbry.error import InsufficientFundsError
from lbry.error import KeyFeeAboveMaxAllowedError, ResolveError, DownloadSDTimeoutError, DownloadDataTimeoutError from lbry.error import KeyFeeAboveMaxAllowedError, ResolveError, DownloadSDTimeoutError, DownloadDataTimeoutError
from lbry.wallet import WalletManager, Wallet, Ledger, Transaction, Input, Output, Database from lbry.wallet import WalletManager, Wallet, Ledger, Transaction, Input, Output
from lbry.wallet.constants import CENT, NULL_HASH32 from lbry.wallet.constants import CENT, NULL_HASH32
from lbry.wallet.network import ClientSession from lbry.wallet.network import ClientSession
from lbry.db import Database
from lbry.conf import Config from lbry.conf import Config
from lbry.extras.daemon.analytics import AnalyticsManager from lbry.extras.daemon.analytics import AnalyticsManager
from lbry.stream.stream_manager import StreamManager from lbry.stream.stream_manager import StreamManager
@ -95,7 +96,7 @@ async def get_mock_wallet(sd_hash, storage, balance=10.0, fee=None):
wallet = Wallet() wallet = Wallet()
ledger = Ledger({ ledger = Ledger({
'db': Database(':memory:'), 'db': Database('sqlite:///:memory:'),
'headers': FakeHeaders(514082) 'headers': FakeHeaders(514082)
}) })
await ledger.db.open() await ledger.db.open()

View file

@ -1,13 +1,14 @@
from binascii import hexlify from binascii import hexlify
from lbry.testcase import AsyncioTestCase from lbry.testcase import AsyncioTestCase
from lbry.wallet import Wallet, Ledger, Database, Headers, Account, SingleKey, HierarchicalDeterministic from lbry.wallet import Wallet, Ledger, Headers, Account, SingleKey, HierarchicalDeterministic
from lbry.db import Database
class TestAccount(AsyncioTestCase): class TestAccount(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.ledger = Ledger({ self.ledger = Ledger({
'db': Database(':memory:'), 'db': Database('sqlite:///:memory:'),
'headers': Headers(':memory:') 'headers': Headers(':memory:')
}) })
await self.ledger.db.open() await self.ledger.db.open()
@ -233,7 +234,7 @@ class TestSingleKeyAccount(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.ledger = Ledger({ self.ledger = Ledger({
'db': Database(':memory:'), 'db': Database('sqlite:///:memory:'),
'headers': Headers(':memory:') 'headers': Headers(':memory:')
}) })
await self.ledger.db.open() await self.ledger.db.open()

View file

@ -2,7 +2,8 @@ from binascii import unhexlify, hexlify
from lbry.testcase import AsyncioTestCase from lbry.testcase import AsyncioTestCase
from lbry.wallet.bip32 import PubKey, PrivateKey, from_extended_key_string from lbry.wallet.bip32 import PubKey, PrivateKey, from_extended_key_string
from lbry.wallet import Ledger, Database, Headers from lbry.wallet import Ledger, Headers
from lbry.db import Database
from tests.unit.wallet.key_fixtures import expected_ids, expected_privkeys, expected_hardened_privkeys from tests.unit.wallet.key_fixtures import expected_ids, expected_privkeys, expected_hardened_privkeys
@ -47,7 +48,7 @@ class BIP32Tests(AsyncioTestCase):
PrivateKey(None, b'abcd', b'abcd'*8, 0, 255) PrivateKey(None, b'abcd', b'abcd'*8, 0, 255)
private_key = PrivateKey( private_key = PrivateKey(
Ledger({ Ledger({
'db': Database(':memory:'), 'db': Database('sqlite:///:memory:'),
'headers': Headers(':memory:'), 'headers': Headers(':memory:'),
}), }),
unhexlify('2423f3dc6087d9683f73a684935abc0ccd8bc26370588f56653128c6a6f0bf7c'), unhexlify('2423f3dc6087d9683f73a684935abc0ccd8bc26370588f56653128c6a6f0bf7c'),
@ -68,7 +69,7 @@ class BIP32Tests(AsyncioTestCase):
async def test_private_key_derivation(self): async def test_private_key_derivation(self):
private_key = PrivateKey( private_key = PrivateKey(
Ledger({ Ledger({
'db': Database(':memory:'), 'db': Database('sqlite:///:memory:'),
'headers': Headers(':memory:'), 'headers': Headers(':memory:'),
}), }),
unhexlify('2423f3dc6087d9683f73a684935abc0ccd8bc26370588f56653128c6a6f0bf7c'), unhexlify('2423f3dc6087d9683f73a684935abc0ccd8bc26370588f56653128c6a6f0bf7c'),
@ -85,7 +86,7 @@ class BIP32Tests(AsyncioTestCase):
async def test_from_extended_keys(self): async def test_from_extended_keys(self):
ledger = Ledger({ ledger = Ledger({
'db': Database(':memory:'), 'db': Database('sqlite:///:memory:'),
'headers': Headers(':memory:'), 'headers': Headers(':memory:'),
}) })
self.assertIsInstance( self.assertIsInstance(

View file

@ -2,7 +2,8 @@ from types import GeneratorType
from lbry.testcase import AsyncioTestCase from lbry.testcase import AsyncioTestCase
from lbry.wallet import Ledger, Database, Headers from lbry.wallet import Ledger, Headers
from lbry.db import Database
from lbry.wallet.coinselection import CoinSelector, MAXIMUM_TRIES from lbry.wallet.coinselection import CoinSelector, MAXIMUM_TRIES
from lbry.constants import CENT from lbry.constants import CENT
@ -21,7 +22,7 @@ class BaseSelectionTestCase(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.ledger = Ledger({ self.ledger = Ledger({
'db': Database(':memory:'), 'db': Database('sqlite:///:memory:'),
'headers': Headers(':memory:'), 'headers': Headers(':memory:'),
}) })
await self.ledger.db.open() await self.ledger.db.open()

View file

@ -6,9 +6,12 @@ import tempfile
import asyncio import asyncio
from concurrent.futures.thread import ThreadPoolExecutor from concurrent.futures.thread import ThreadPoolExecutor
from sqlalchemy import Column, Text
from lbry.wallet import ( from lbry.wallet import (
Wallet, Account, Ledger, Database, Headers, Transaction, Input Wallet, Account, Ledger, Headers, Transaction, Input
) )
from lbry.db import Table, Version, Database, metadata
from lbry.wallet.constants import COIN from lbry.wallet.constants import COIN
from lbry.wallet.database import query, interpolate, constraints_to_sql, AIOSQLite from lbry.wallet.database import query, interpolate, constraints_to_sql, AIOSQLite
from lbry.crypto.hash import sha256 from lbry.crypto.hash import sha256
@ -208,7 +211,7 @@ class TestQueries(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.ledger = Ledger({ self.ledger = Ledger({
'db': Database(':memory:'), 'db': Database('sqlite:///:memory:'),
'headers': Headers(':memory:') 'headers': Headers(':memory:')
}) })
self.wallet = Wallet() self.wallet = Wallet()
@ -265,13 +268,13 @@ class TestQueries(AsyncioTestCase):
async def test_large_tx_doesnt_hit_variable_limits(self): async def test_large_tx_doesnt_hit_variable_limits(self):
# SQLite is usually compiled with 999 variables limit: https://www.sqlite.org/limits.html # SQLite is usually compiled with 999 variables limit: https://www.sqlite.org/limits.html
# This can be removed when there is a better way. See: https://github.com/lbryio/lbry-sdk/issues/2281 # This can be removed when there is a better way. See: https://github.com/lbryio/lbry-sdk/issues/2281
fetchall = self.ledger.db.db.execute_fetchall fetchall = self.ledger.db.execute_fetchall
def check_parameters_length(sql, parameters, read_only=False): def check_parameters_length(sql, parameters=None):
self.assertLess(len(parameters or []), 999) self.assertLess(len(parameters or []), 999)
return fetchall(sql, parameters, read_only) return fetchall(sql, parameters)
self.ledger.db.db.execute_fetchall = check_parameters_length self.ledger.db.execute_fetchall = check_parameters_length
account = await self.create_account() account = await self.create_account()
tx = await self.create_tx_from_nothing(account, 0) tx = await self.create_tx_from_nothing(account, 0)
for height in range(1, 1200): for height in range(1, 1200):
@ -368,14 +371,14 @@ class TestQueries(AsyncioTestCase):
self.assertEqual(txs[1].outputs[0].is_my_output, True) self.assertEqual(txs[1].outputs[0].is_my_output, True)
self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account2])) self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account2]))
tx = await self.ledger.db.get_transaction(txid=tx2.id) tx = await self.ledger.db.get_transaction(tx_hash=tx2.hash)
self.assertEqual(tx.id, tx2.id) self.assertEqual(tx.id, tx2.id)
self.assertIsNone(tx.inputs[0].is_my_input) self.assertIsNone(tx.inputs[0].is_my_input)
self.assertIsNone(tx.outputs[0].is_my_output) self.assertIsNone(tx.outputs[0].is_my_output)
tx = await self.ledger.db.get_transaction(wallet=wallet1, txid=tx2.id, include_is_my_output=True) tx = await self.ledger.db.get_transaction(wallet=wallet1, tx_hash=tx2.hash, include_is_my_output=True)
self.assertTrue(tx.inputs[0].is_my_input) self.assertTrue(tx.inputs[0].is_my_input)
self.assertFalse(tx.outputs[0].is_my_output) self.assertFalse(tx.outputs[0].is_my_output)
tx = await self.ledger.db.get_transaction(wallet=wallet2, txid=tx2.id, include_is_my_output=True) tx = await self.ledger.db.get_transaction(wallet=wallet2, tx_hash=tx2.hash, include_is_my_output=True)
self.assertFalse(tx.inputs[0].is_my_input) self.assertFalse(tx.inputs[0].is_my_input)
self.assertTrue(tx.outputs[0].is_my_output) self.assertTrue(tx.outputs[0].is_my_output)
@ -425,7 +428,7 @@ class TestUpgrade(AsyncioTestCase):
async def test_reset_on_version_change(self): async def test_reset_on_version_change(self):
self.ledger = Ledger({ self.ledger = Ledger({
'db': Database(self.path), 'db': Database('sqlite:///'+self.path),
'headers': Headers(':memory:') 'headers': Headers(':memory:')
}) })
@ -433,7 +436,8 @@ class TestUpgrade(AsyncioTestCase):
self.ledger.db.SCHEMA_VERSION = None self.ledger.db.SCHEMA_VERSION = None
self.assertListEqual(self.get_tables(), []) self.assertListEqual(self.get_tables(), [])
await self.ledger.db.open() await self.ledger.db.open()
self.assertEqual(self.get_tables(), ['account_address', 'pubkey_address', 'tx', 'txi', 'txo']) metadata.drop_all(self.ledger.db.engine, [Version]) # simulate pre-version table db
self.assertEqual(self.get_tables(), ['account_address', 'block', 'pubkey_address', 'tx', 'txi', 'txo'])
self.assertListEqual(self.get_addresses(), []) self.assertListEqual(self.get_addresses(), [])
self.add_address('address1') self.add_address('address1')
await self.ledger.db.close() await self.ledger.db.close()
@ -442,28 +446,27 @@ class TestUpgrade(AsyncioTestCase):
self.ledger.db.SCHEMA_VERSION = '1.0' self.ledger.db.SCHEMA_VERSION = '1.0'
await self.ledger.db.open() await self.ledger.db.open()
self.assertEqual(self.get_version(), '1.0') self.assertEqual(self.get_version(), '1.0')
self.assertListEqual(self.get_tables(), ['account_address', 'pubkey_address', 'tx', 'txi', 'txo', 'version']) self.assertListEqual(self.get_tables(), ['account_address', 'block', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
self.assertListEqual(self.get_addresses(), []) # address1 deleted during version upgrade self.assertListEqual(self.get_addresses(), []) # address1 deleted during version upgrade
self.add_address('address2') self.add_address('address2')
await self.ledger.db.close() await self.ledger.db.close()
# nothing changes # nothing changes
self.assertEqual(self.get_version(), '1.0') self.assertEqual(self.get_version(), '1.0')
self.assertListEqual(self.get_tables(), ['account_address', 'pubkey_address', 'tx', 'txi', 'txo', 'version']) self.assertListEqual(self.get_tables(), ['account_address', 'block', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
await self.ledger.db.open() await self.ledger.db.open()
self.assertEqual(self.get_version(), '1.0') self.assertEqual(self.get_version(), '1.0')
self.assertListEqual(self.get_tables(), ['account_address', 'pubkey_address', 'tx', 'txi', 'txo', 'version']) self.assertListEqual(self.get_tables(), ['account_address', 'block', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
self.assertListEqual(self.get_addresses(), ['address2']) self.assertListEqual(self.get_addresses(), ['address2'])
await self.ledger.db.close() await self.ledger.db.close()
# upgrade version, database reset # upgrade version, database reset
foo = Table('foo', metadata, Column('bar', Text, primary_key=True))
self.addCleanup(metadata.remove, foo)
self.ledger.db.SCHEMA_VERSION = '1.1' self.ledger.db.SCHEMA_VERSION = '1.1'
self.ledger.db.CREATE_TABLES_QUERY += """
create table if not exists foo (bar text);
"""
await self.ledger.db.open() await self.ledger.db.open()
self.assertEqual(self.get_version(), '1.1') self.assertEqual(self.get_version(), '1.1')
self.assertListEqual(self.get_tables(), ['account_address', 'foo', 'pubkey_address', 'tx', 'txi', 'txo', 'version']) self.assertListEqual(self.get_tables(), ['account_address', 'block', 'foo', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
self.assertListEqual(self.get_addresses(), []) # all tables got reset self.assertListEqual(self.get_addresses(), []) # all tables got reset
await self.ledger.db.close() await self.ledger.db.close()

View file

@ -1,8 +1,9 @@
import os import os
from binascii import hexlify from binascii import hexlify, unhexlify
from lbry.testcase import AsyncioTestCase from lbry.testcase import AsyncioTestCase
from lbry.wallet import Wallet, Account, Transaction, Output, Input, Ledger, Database, Headers from lbry.wallet import Wallet, Account, Transaction, Output, Input, Ledger, Headers
from lbry.db import Database
from tests.unit.wallet.test_transaction import get_transaction, get_output from tests.unit.wallet.test_transaction import get_transaction, get_output
from tests.unit.wallet.test_headers import HEADERS, block_bytes from tests.unit.wallet.test_headers import HEADERS, block_bytes
@ -45,7 +46,7 @@ class LedgerTestCase(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.ledger = Ledger({ self.ledger = Ledger({
'db': Database(':memory:'), 'db': Database('sqlite:///:memory:'),
'headers': Headers(':memory:') 'headers': Headers(':memory:')
}) })
self.account = Account.generate(self.ledger, Wallet(), "lbryum") self.account = Account.generate(self.ledger, Wallet(), "lbryum")
@ -84,6 +85,10 @@ class TestSynchronization(LedgerTestCase):
txid2 = 'ab9c0654dd484ac20437030f2034e25dcb29fc507e84b91138f80adc3af738f9' txid2 = 'ab9c0654dd484ac20437030f2034e25dcb29fc507e84b91138f80adc3af738f9'
txid3 = 'a2ae3d1db3c727e7d696122cab39ee20a7f81856dab7019056dd539f38c548a0' txid3 = 'a2ae3d1db3c727e7d696122cab39ee20a7f81856dab7019056dd539f38c548a0'
txid4 = '047cf1d53ef68f0fd586d46f90c09ff8e57a4180f67e7f4b8dd0135c3741e828' txid4 = '047cf1d53ef68f0fd586d46f90c09ff8e57a4180f67e7f4b8dd0135c3741e828'
txhash1 = unhexlify(txid1)[::-1]
txhash2 = unhexlify(txid2)[::-1]
txhash3 = unhexlify(txid3)[::-1]
txhash4 = unhexlify(txid4)[::-1]
account = Account.generate(self.ledger, Wallet(), "torba") account = Account.generate(self.ledger, Wallet(), "torba")
address = await account.receiving.get_or_create_usable_address() address = await account.receiving.get_or_create_usable_address()
@ -99,13 +104,13 @@ class TestSynchronization(LedgerTestCase):
{'tx_hash': txid2, 'height': 1}, {'tx_hash': txid2, 'height': 1},
{'tx_hash': txid3, 'height': 2}, {'tx_hash': txid3, 'height': 2},
], { ], {
txid1: hexlify(get_transaction(get_output(1)).raw), txhash1: hexlify(get_transaction(get_output(1)).raw),
txid2: hexlify(get_transaction(get_output(2)).raw), txhash2: hexlify(get_transaction(get_output(2)).raw),
txid3: hexlify(get_transaction(get_output(3)).raw), txhash3: hexlify(get_transaction(get_output(3)).raw),
}) })
await self.ledger.update_history(address, '') await self.ledger.update_history(address, '')
self.assertListEqual(self.ledger.network.get_history_called, [address]) self.assertListEqual(self.ledger.network.get_history_called, [address])
self.assertListEqual(self.ledger.network.get_transaction_called, [txid1, txid2, txid3]) self.assertListEqual(self.ledger.network.get_transaction_called, [txhash1, txhash2, txhash3])
address_details = await self.ledger.db.get_address(address=address) address_details = await self.ledger.db.get_address(address=address)
@ -125,12 +130,12 @@ class TestSynchronization(LedgerTestCase):
self.assertListEqual(self.ledger.network.get_transaction_called, []) self.assertListEqual(self.ledger.network.get_transaction_called, [])
self.ledger.network.history.append({'tx_hash': txid4, 'height': 3}) self.ledger.network.history.append({'tx_hash': txid4, 'height': 3})
self.ledger.network.transaction[txid4] = hexlify(get_transaction(get_output(4)).raw) self.ledger.network.transaction[txhash4] = hexlify(get_transaction(get_output(4)).raw)
self.ledger.network.get_history_called = [] self.ledger.network.get_history_called = []
self.ledger.network.get_transaction_called = [] self.ledger.network.get_transaction_called = []
await self.ledger.update_history(address, '') await self.ledger.update_history(address, '')
self.assertListEqual(self.ledger.network.get_history_called, [address]) self.assertListEqual(self.ledger.network.get_history_called, [address])
self.assertListEqual(self.ledger.network.get_transaction_called, [txid4]) self.assertListEqual(self.ledger.network.get_transaction_called, [txhash4])
address_details = await self.ledger.db.get_address(address=address) address_details = await self.ledger.db.get_address(address=address)
self.assertEqual( self.assertEqual(
address_details['history'], address_details['history'],

View file

@ -3,7 +3,8 @@ from binascii import unhexlify
from lbry.testcase import AsyncioTestCase from lbry.testcase import AsyncioTestCase
from lbry.wallet.constants import CENT, NULL_HASH32 from lbry.wallet.constants import CENT, NULL_HASH32
from lbry.wallet import Ledger, Database, Headers, Transaction, Input, Output from lbry.wallet import Ledger, Headers, Transaction, Input, Output
from lbry.db import Database
from lbry.schema.claim import Claim from lbry.schema.claim import Claim

View file

@ -4,7 +4,8 @@ from itertools import cycle
from lbry.testcase import AsyncioTestCase from lbry.testcase import AsyncioTestCase
from lbry.wallet.constants import CENT, COIN, NULL_HASH32 from lbry.wallet.constants import CENT, COIN, NULL_HASH32
from lbry.wallet import Wallet, Account, Ledger, Database, Headers, Transaction, Output, Input from lbry.wallet import Wallet, Account, Ledger, Headers, Transaction, Output, Input
from lbry.db import Database
NULL_HASH = b'\x00'*32 NULL_HASH = b'\x00'*32
@ -38,7 +39,7 @@ class TestSizeAndFeeEstimation(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.ledger = Ledger({ self.ledger = Ledger({
'db': Database(':memory:'), 'db': Database('sqlite:///:memory:'),
'headers': Headers(':memory:') 'headers': Headers(':memory:')
}) })
await self.ledger.db.open() await self.ledger.db.open()
@ -264,7 +265,7 @@ class TestTransactionSigning(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.ledger = Ledger({ self.ledger = Ledger({
'db': Database(':memory:'), 'db': Database('sqlite:///:memory:'),
'headers': Headers(':memory:') 'headers': Headers(':memory:')
}) })
await self.ledger.db.open() await self.ledger.db.open()
@ -303,7 +304,7 @@ class TransactionIOBalancing(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.ledger = Ledger({ self.ledger = Ledger({
'db': Database(':memory:'), 'db': Database('sqlite:///:memory:'),
'headers': Headers(':memory:') 'headers': Headers(':memory:')
}) })
await self.ledger.db.open() await self.ledger.db.open()

View file

@ -6,6 +6,7 @@ extras = test
changedir = {toxinidir}/tests changedir = {toxinidir}/tests
setenv = setenv =
HOME=/tmp HOME=/tmp
passenv = TEST_DB
commands = commands =
pip install https://github.com/rogerbinns/apsw/releases/download/3.30.1-r1/apsw-3.30.1-r1.zip \ pip install https://github.com/rogerbinns/apsw/releases/download/3.30.1-r1/apsw-3.30.1-r1.zip \
--global-option=fetch \ --global-option=fetch \