moving and refactoring electrumx into torba

Lex Berezhny 2018-11-03 18:50:34 -04:00
parent 41332f22c3
commit f60435c878
35 changed files with 10200 additions and 22 deletions

1
.gitignore vendored

@@ -8,6 +8,7 @@ dist/
# testing
.tox/
tests/unit/bitcoin_headers
torba/bin
# cache and logs
__pycache__/


@@ -22,8 +22,6 @@ jobs:
env: TESTTYPE=unit
install:
- pip install tox-travis
- pushd .. && git clone https://github.com/lbryio/electrumx.git --branch lbryumx && popd
- pushd .. && git clone https://github.com/lbryio/orchstr8.git && popd
script: tox
- <<: *tests
python: "3.6"


@@ -10,7 +10,7 @@ source =
ignore_missing_imports = True
[pylint]
ignore=words
ignore=words,server
max-args=10
max-line-length=110
good-names=T,t,n,i,j,k,x,y,s,f,d,h,c,e,op,db,tx,io,cachedproperty,log,id


@@ -3,6 +3,10 @@ from setuptools import setup, find_packages
import torba
BASE = os.path.dirname(__file__)
with open(os.path.join(BASE, 'README.md'), encoding='utf-8') as fh:
long_description = fh.read()
setup(
name='torba',
version=torba.__version__,
@@ -10,11 +14,10 @@ setup(
license='MIT',
author='LBRY Inc.',
author_email='hello@lbry.io',
description='Wallet library for bitcoin based currencies.',
long_description=open(os.path.join(os.path.dirname(__file__), 'README.md'),
encoding='utf-8').read(),
description='Wallet client/server framework for bitcoin based currencies.',
long_description=long_description,
long_description_content_type="text/markdown",
keywords='wallet,crypto,currency,money,bitcoin,lbry',
keywords='wallet,crypto,currency,money,bitcoin,electrum,electrumx',
classifiers=(
'Framework :: AsyncIO',
'Intended Audience :: Developers',
@@ -23,13 +26,16 @@ setup(
'Programming Language :: Python :: 3',
'Operating System :: OS Independent',
'Topic :: Internet',
'Topic :: Software Development :: Testing',
'Topic :: Software Development :: Libraries :: Python Modules',
'Topic :: System :: Benchmark',
'Topic :: System :: Distributed Computing',
'Topic :: Utilities',
),
packages=find_packages(exclude=('tests',)),
python_requires='>=3.6',
install_requires=(
'aiohttp',
'aiorpcx==0.9.0',
'coincurve',
'pbkdf2',
@@ -38,6 +44,13 @@ setup(
extras_require={
'test': (
'mock',
)
}
'requests',
),
'server': (
'attrs',
'plyvel',
'pylru'
),
},
entry_points={'console_scripts': ['torba=torba.cli:main']}
)
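With this layout the SPV-server dependencies (attrs, plyvel, pylru) become an optional `server` extra, so pulling them in would look something like `pip install torba[server]`, while a bare `pip install torba` keeps only the wallet-client requirements.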


@@ -1,3 +1,4 @@
import logging
from asyncio import CancelledError
from orchstr8.testcase import IntegrationTestCase
@@ -6,7 +7,7 @@ from torba.constants import COIN
class ReconnectTests(IntegrationTestCase):
VERBOSE = False
VERBOSITY = logging.DEBUG
async def test_connection_drop_still_receives_events_after_reconnected(self):
address1 = await self.account.receiving.get_or_create_usable_address()
@@ -31,4 +32,3 @@ class ReconnectTests(IntegrationTestCase):
await self.blockchain.generate(1)
# omg, the burned cable still works! torba is fire proof!
await self.ledger.network.get_transaction(sendtxid)


@@ -1,11 +1,12 @@
import logging
import asyncio
from orchstr8.testcase import IntegrationTestCase
from torba.testing import IntegrationTestCase
from torba.constants import COIN
class BasicTransactionTests(IntegrationTestCase):
VERBOSE = False
VERBOSITY = logging.WARNING
async def test_sending_and_receiving(self):
account1, account2 = self.account, self.wallet.generate_account(self.ledger)

89
torba/cli.py Normal file

@@ -0,0 +1,89 @@
import logging
import argparse
import asyncio
import aiohttp
from torba.testing.node import Conductor, get_ledger_from_environment, get_blockchain_node_from_ledger
from torba.testing.service import TestingServiceAPI
def get_argument_parser():
parser = argparse.ArgumentParser(
prog="torba"
)
subparsers = parser.add_subparsers(dest='command', help='sub-command help')
gui = subparsers.add_parser("gui", help="Start Qt GUI.")
download = subparsers.add_parser("download", help="Download blockchain node binary.")
start = subparsers.add_parser("start", help="Start orchstr8 service.")
start.add_argument("--blockchain", help="Start blockchain node.", action="store_true")
start.add_argument("--spv", help="Start SPV server.", action="store_true")
start.add_argument("--wallet", help="Start wallet daemon.", action="store_true")
generate = subparsers.add_parser("generate", help="Call generate method on running orchstr8 instance.")
generate.add_argument("blocks", type=int, help="Number of blocks to generate")
transfer = subparsers.add_parser("transfer", help="Call transfer method on running orchstr8 instance.")
return parser
async def run_remote_command(command, **kwargs):
async with aiohttp.ClientSession() as session:
async with session.post('http://localhost:7954/'+command, data=kwargs) as resp:
print(resp.status)
print(await resp.text())
def main():
parser = get_argument_parser()
args = parser.parse_args()
command = getattr(args, 'command', 'help')
if command == 'gui':
from torba.workbench import main as start_app
return start_app()
loop = asyncio.get_event_loop()
ledger = get_ledger_from_environment()
if command == 'download':
logging.getLogger('blockchain').setLevel(logging.INFO)
get_blockchain_node_from_ledger(ledger).ensure()
elif command == 'generate':
loop.run_until_complete(run_remote_command(
'generate', blocks=args.blocks
))
elif command == 'start':
conductor = Conductor()
if getattr(args, 'blockchain', False):
loop.run_until_complete(conductor.start_blockchain())
if getattr(args, 'spv', False):
loop.run_until_complete(conductor.start_spv())
if getattr(args, 'wallet', False):
loop.run_until_complete(conductor.start_wallet())
service = TestingServiceAPI(conductor, loop)
loop.run_until_complete(service.start())
try:
print('========== Orchstr8 API Service Started ========')
loop.run_forever()
except KeyboardInterrupt:
pass
finally:
loop.run_until_complete(service.stop())
loop.run_until_complete(conductor.stop())
loop.close()
else:
parser.print_help()
if __name__ == "__main__":
main()
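# A possible workflow with this CLI (the exact commands are illustrative, but the
# flags and port come from the code above): "torba start --blockchain --spv" boots
# a regtest blockchain node plus the SPV server and starts the orchstr8 HTTP API
# that run_remote_command targets at localhost:7954; "torba generate 5" then POSTs
# to that API to mine 5 blocks, and "torba download" only fetches the blockchain
# node binary for the ledger selected via the environment.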


@@ -4,7 +4,7 @@ __node_bin__ = 'bitcoin-abc-0.17.2/bin'
__node_url__ = (
'https://download.bitcoinabc.org/0.17.2/linux/bitcoin-abc-0.17.2-x86_64-linux-gnu.tar.gz'
)
__electrumx__ = 'electrumx.lib.coins.BitcoinCashRegtest'
__spvserver__ = 'torba.server.coins.BitcoinCashRegtest'
from binascii import unhexlify
from torba.baseledger import BaseLedger


@@ -4,7 +4,7 @@ __node_bin__ = 'bitcoin-0.16.3/bin'
__node_url__ = (
'https://bitcoin.org/bin/bitcoin-core-0.16.3/bitcoin-0.16.3-x86_64-linux-gnu.tar.gz'
)
__electrumx__ = 'electrumx.lib.coins.BitcoinSegwitRegtest'
__spvserver__ = 'torba.server.coins.BitcoinSegwitRegtest'
import struct
from typing import Optional

1
torba/server/__init__.py Normal file

@@ -0,0 +1 @@
from .server import Server


@@ -0,0 +1,713 @@
# Copyright (c) 2016-2017, Neil Booth
# Copyright (c) 2017, the ElectrumX authors
#
# All rights reserved.
#
# See the file "LICENCE" for information about the copyright
# and warranty status of this software.
'''Block prefetcher and chain processor.'''
import array
import asyncio
from struct import pack, unpack
import time
from functools import partial
from aiorpcx import TaskGroup, run_in_thread
import torba
from torba.server.daemon import DaemonError
from torba.server.hash import hash_to_hex_str, HASHX_LEN
from torba.server.util import chunks, class_logger
from torba.server.db import FlushData
class Prefetcher(object):
'''Prefetches blocks (in the forward direction only).'''
def __init__(self, daemon, coin, blocks_event):
self.logger = class_logger(__name__, self.__class__.__name__)
self.daemon = daemon
self.coin = coin
self.blocks_event = blocks_event
self.blocks = []
self.caught_up = False
# Access to fetched_height should be protected by the semaphore
self.fetched_height = None
self.semaphore = asyncio.Semaphore()
self.refill_event = asyncio.Event()
# The prefetched block cache size. The min cache size has
# little effect on sync time.
self.cache_size = 0
self.min_cache_size = 10 * 1024 * 1024
# This makes the first fetch be 10 blocks
self.ave_size = self.min_cache_size // 10
self.polling_delay = 5
async def main_loop(self, bp_height):
'''Loop forever polling for more blocks.'''
await self.reset_height(bp_height)
while True:
try:
# Sleep a while if there is nothing to prefetch
await self.refill_event.wait()
if not await self._prefetch_blocks():
await asyncio.sleep(self.polling_delay)
except DaemonError as e:
self.logger.info(f'ignoring daemon error: {e}')
def get_prefetched_blocks(self):
'''Called by block processor when it is processing queued blocks.'''
blocks = self.blocks
self.blocks = []
self.cache_size = 0
self.refill_event.set()
return blocks
async def reset_height(self, height):
'''Reset to prefetch blocks from the block processor's height.
Used in blockchain reorganisations. This coroutine can be
called asynchronously to the _prefetch_blocks coroutine so we
must synchronize with a semaphore.
'''
async with self.semaphore:
self.blocks.clear()
self.cache_size = 0
self.fetched_height = height
self.refill_event.set()
daemon_height = await self.daemon.height()
behind = daemon_height - height
if behind > 0:
self.logger.info('catching up to daemon height {:,d} '
'({:,d} blocks behind)'
.format(daemon_height, behind))
else:
self.logger.info('caught up to daemon height {:,d}'
.format(daemon_height))
async def _prefetch_blocks(self):
'''Prefetch some blocks and put them on the queue.
Repeats until the queue is full or caught up.
'''
daemon = self.daemon
daemon_height = await daemon.height()
async with self.semaphore:
while self.cache_size < self.min_cache_size:
# Try and catch up all blocks but limit to room in cache.
# Constrain fetch count to between 0 and 500 regardless;
# testnet can be lumpy.
cache_room = self.min_cache_size // self.ave_size
count = min(daemon_height - self.fetched_height, cache_room)
count = min(500, max(count, 0))
if not count:
self.caught_up = True
return False
first = self.fetched_height + 1
hex_hashes = await daemon.block_hex_hashes(first, count)
if self.caught_up:
self.logger.info('new block height {:,d} hash {}'
.format(first + count-1, hex_hashes[-1]))
blocks = await daemon.raw_blocks(hex_hashes)
assert count == len(blocks)
# Special handling for genesis block
if first == 0:
blocks[0] = self.coin.genesis_block(blocks[0])
self.logger.info('verified genesis block with hash {}'
.format(hex_hashes[0]))
# Update our recent average block size estimate
size = sum(len(block) for block in blocks)
if count >= 10:
self.ave_size = size // count
else:
self.ave_size = (size + (10 - count) * self.ave_size) // 10
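# Illustration (numbers are hypothetical): with ave_size at 1 MB and a fetch of
# count == 3 blocks totalling 6 MB, the blend above gives (6 MB + 7 * 1 MB) // 10,
# about 1.3 MB, so the estimate moves toward the new sample without jumping to it.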
self.blocks.extend(blocks)
self.cache_size += size
self.fetched_height += count
self.blocks_event.set()
self.refill_event.clear()
return True
class ChainError(Exception):
'''Raised on error processing blocks.'''
class BlockProcessor(object):
'''Process blocks and update the DB state to match.
Employ a prefetcher to prefetch blocks in batches for processing.
Coordinate backing up in case of chain reorganisations.
'''
def __init__(self, env, db, daemon, notifications):
self.env = env
self.db = db
self.daemon = daemon
self.notifications = notifications
self.coin = env.coin
self.blocks_event = asyncio.Event()
self.prefetcher = Prefetcher(daemon, env.coin, self.blocks_event)
self.logger = class_logger(__name__, self.__class__.__name__)
# Meta
self.next_cache_check = 0
self.touched = set()
self.reorg_count = 0
# Caches of unflushed items.
self.headers = []
self.tx_hashes = []
self.undo_infos = []
# UTXO cache
self.utxo_cache = {}
self.db_deletes = []
# If the lock is successfully acquired, in-memory chain state
# is consistent with self.height
self.state_lock = asyncio.Lock()
async def run_in_thread_with_lock(self, func, *args):
# Run in a thread to prevent blocking. Shielded so that
# cancellations from shutdown don't lose work - when the task
# completes the data will be flushed and then we shut down.
# Take the state lock to be certain in-memory state is
# consistent and not being updated elsewhere.
async def run_in_thread_locked():
async with self.state_lock:
return await run_in_thread(func, *args)
return await asyncio.shield(run_in_thread_locked())
async def check_and_advance_blocks(self, raw_blocks):
'''Process the list of raw blocks passed. Detects and handles
reorgs.
'''
if not raw_blocks:
return
first = self.height + 1
blocks = [self.coin.block(raw_block, first + n)
for n, raw_block in enumerate(raw_blocks)]
headers = [block.header for block in blocks]
hprevs = [self.coin.header_prevhash(h) for h in headers]
chain = [self.tip] + [self.coin.header_hash(h) for h in headers[:-1]]
if hprevs == chain:
start = time.time()
await self.run_in_thread_with_lock(self.advance_blocks, blocks)
await self._maybe_flush()
if not self.db.first_sync:
s = '' if len(blocks) == 1 else 's'
self.logger.info('processed {:,d} block{} in {:.1f}s'
.format(len(blocks), s,
time.time() - start))
if self._caught_up_event.is_set():
await self.notifications.on_block(self.touched, self.height)
self.touched = set()
elif hprevs[0] != chain[0]:
await self.reorg_chain()
else:
# It is probably possible but extremely rare that what
# bitcoind returns doesn't form a chain because it
# reorg-ed the chain as it was processing the batched
# block hash requests. Should this happen it's simplest
# just to reset the prefetcher and try again.
self.logger.warning('daemon blocks do not form a chain; '
'resetting the prefetcher')
await self.prefetcher.reset_height(self.height)
async def reorg_chain(self, count=None):
'''Handle a chain reorganisation.
Count is the number of blocks to simulate a reorg, or None for
a real reorg.'''
if count is None:
self.logger.info('chain reorg detected')
else:
self.logger.info(f'faking a reorg of {count:,d} blocks')
await self.flush(True)
async def get_raw_blocks(last_height, hex_hashes):
heights = range(last_height, last_height - len(hex_hashes), -1)
try:
blocks = [self.db.read_raw_block(height) for height in heights]
self.logger.info(f'read {len(blocks)} blocks from disk')
return blocks
except FileNotFoundError:
return await self.daemon.raw_blocks(hex_hashes)
def flush_backup():
# self.touched can include other addresses which is
# harmless, but remove None.
self.touched.discard(None)
self.db.flush_backup(self.flush_data(), self.touched)
start, last, hashes = await self.reorg_hashes(count)
# Reverse and convert to hex strings.
hashes = [hash_to_hex_str(hash) for hash in reversed(hashes)]
for hex_hashes in chunks(hashes, 50):
raw_blocks = await get_raw_blocks(last, hex_hashes)
await self.run_in_thread_with_lock(self.backup_blocks, raw_blocks)
await self.run_in_thread_with_lock(flush_backup)
last -= len(raw_blocks)
await self.prefetcher.reset_height(self.height)
async def reorg_hashes(self, count):
'''Return a pair (start, last, hashes) of blocks to back up during a
reorg.
The hashes are returned in order of increasing height. Start
is the height of the first hash, last of the last.
'''
start, count = await self.calc_reorg_range(count)
last = start + count - 1
s = '' if count == 1 else 's'
self.logger.info(f'chain was reorganised replacing {count:,d} '
f'block{s} at heights {start:,d}-{last:,d}')
return start, last, await self.db.fs_block_hashes(start, count)
async def calc_reorg_range(self, count):
'''Calculate the reorg range'''
def diff_pos(hashes1, hashes2):
'''Returns the index of the first difference in the hash lists.
If both lists match returns their length.'''
for n, (hash1, hash2) in enumerate(zip(hashes1, hashes2)):
if hash1 != hash2:
return n
return len(hashes)
if count is None:
# A real reorg
start = self.height - 1
count = 1
while start > 0:
hashes = await self.db.fs_block_hashes(start, count)
hex_hashes = [hash_to_hex_str(hash) for hash in hashes]
d_hex_hashes = await self.daemon.block_hex_hashes(start, count)
n = diff_pos(hex_hashes, d_hex_hashes)
if n > 0:
start += n
break
count = min(count * 2, start)
start -= count
count = (self.height - start) + 1
else:
start = (self.height - count) + 1
return start, count
def estimate_txs_remaining(self):
# Try to estimate how many txs there are to go
daemon_height = self.daemon.cached_height()
coin = self.coin
tail_count = daemon_height - max(self.height, coin.TX_COUNT_HEIGHT)
# Damp the initial enthusiasm
realism = max(2.0 - 0.9 * self.height / coin.TX_COUNT_HEIGHT, 1.0)
return (tail_count * coin.TX_PER_BLOCK +
max(coin.TX_COUNT - self.tx_count, 0)) * realism
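# Worked example (hypothetical numbers): if TX_COUNT_HEIGHT is 400,000 and the
# current height is 100,000, realism is max(2.0 - 0.9 * 0.25, 1.0) = 1.775, so the
# raw estimate is inflated by roughly 78%; once height passes about
# 1.11 * TX_COUNT_HEIGHT the damping factor bottoms out at 1.0.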
# - Flushing
def flush_data(self):
'''The data for a flush. The lock must be taken.'''
assert self.state_lock.locked()
return FlushData(self.height, self.tx_count, self.headers,
self.tx_hashes, self.undo_infos, self.utxo_cache,
self.db_deletes, self.tip)
async def flush(self, flush_utxos):
def flush():
self.db.flush_dbs(self.flush_data(), flush_utxos,
self.estimate_txs_remaining)
await self.run_in_thread_with_lock(flush)
async def _maybe_flush(self):
# If caught up, flush everything as client queries are
# performed on the DB.
if self._caught_up_event.is_set():
await self.flush(True)
elif time.time() > self.next_cache_check:
flush_arg = self.check_cache_size()
if flush_arg is not None:
await self.flush(flush_arg)
self.next_cache_check = time.time() + 30
def check_cache_size(self):
'''Flush a cache if it gets too big.'''
# Good average estimates based on traversal of subobjects and
# requesting size from Python (see deep_getsizeof).
one_MB = 1000*1000
utxo_cache_size = len(self.utxo_cache) * 205
db_deletes_size = len(self.db_deletes) * 57
hist_cache_size = self.db.history.unflushed_memsize()
# Roughly ntxs * 32 + nblocks * 42
tx_hash_size = ((self.tx_count - self.db.fs_tx_count) * 32
+ (self.height - self.db.fs_height) * 42)
utxo_MB = (db_deletes_size + utxo_cache_size) // one_MB
hist_MB = (hist_cache_size + tx_hash_size) // one_MB
self.logger.info('our height: {:,d} daemon: {:,d} '
'UTXOs {:,d}MB hist {:,d}MB'
.format(self.height, self.daemon.cached_height(),
utxo_MB, hist_MB))
# Flush history if it takes up over 20% of cache memory.
# Flush UTXOs once they take up 80% of cache memory.
cache_MB = self.env.cache_MB
if utxo_MB + hist_MB >= cache_MB or hist_MB >= cache_MB // 5:
return utxo_MB >= cache_MB * 4 // 5
return None
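# Illustration (the cache_MB value is hypothetical): with cache_MB = 1000,
# utxo_MB = 700 and hist_MB = 250, the first test (950 >= 1000) fails but
# hist_MB >= 200 holds, so a flush is triggered; the method returns 700 >= 800,
# i.e. False, meaning only history is flushed this time.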
def advance_blocks(self, blocks):
'''Synchronously advance the blocks.
It is already verified they correctly connect onto our tip.
'''
min_height = self.db.min_undo_height(self.daemon.cached_height())
height = self.height
for block in blocks:
height += 1
undo_info = self.advance_txs(block.transactions)
if height >= min_height:
self.undo_infos.append((undo_info, height))
self.db.write_raw_block(block.raw, height)
headers = [block.header for block in blocks]
self.height = height
self.headers.extend(headers)
self.tip = self.coin.header_hash(headers[-1])
def advance_txs(self, txs):
self.tx_hashes.append(b''.join(tx_hash for tx, tx_hash in txs))
# Use local vars for speed in the loops
undo_info = []
tx_num = self.tx_count
script_hashX = self.coin.hashX_from_script
s_pack = pack
put_utxo = self.utxo_cache.__setitem__
spend_utxo = self.spend_utxo
undo_info_append = undo_info.append
update_touched = self.touched.update
hashXs_by_tx = []
append_hashXs = hashXs_by_tx.append
for tx, tx_hash in txs:
hashXs = []
append_hashX = hashXs.append
tx_numb = s_pack('<I', tx_num)
# Spend the inputs
for txin in tx.inputs:
if txin.is_generation():
continue
cache_value = spend_utxo(txin.prev_hash, txin.prev_idx)
undo_info_append(cache_value)
append_hashX(cache_value[:-12])
# Add the new UTXOs
for idx, txout in enumerate(tx.outputs):
# Get the hashX. Ignore unspendable outputs
hashX = script_hashX(txout.pk_script)
if hashX:
append_hashX(hashX)
put_utxo(tx_hash + s_pack('<H', idx),
hashX + tx_numb + s_pack('<Q', txout.value))
append_hashXs(hashXs)
update_touched(hashXs)
tx_num += 1
self.db.history.add_unflushed(hashXs_by_tx, self.tx_count)
self.tx_count = tx_num
self.db.tx_counts.append(tx_num)
return undo_info
def backup_blocks(self, raw_blocks):
'''Backup the raw blocks and flush.
The blocks should be in order of decreasing height, starting at
self.height. A flush is performed once the blocks are backed up.
'''
self.db.assert_flushed(self.flush_data())
assert self.height >= len(raw_blocks)
coin = self.coin
for raw_block in raw_blocks:
# Check and update self.tip
block = coin.block(raw_block, self.height)
header_hash = coin.header_hash(block.header)
if header_hash != self.tip:
raise ChainError('backup block {} not tip {} at height {:,d}'
.format(hash_to_hex_str(header_hash),
hash_to_hex_str(self.tip),
self.height))
self.tip = coin.header_prevhash(block.header)
self.backup_txs(block.transactions)
self.height -= 1
self.db.tx_counts.pop()
self.logger.info('backed up to height {:,d}'.format(self.height))
def backup_txs(self, txs):
# Prevout values, in order down the block (coinbase first if present)
# undo_info is in reverse block order
undo_info = self.db.read_undo_info(self.height)
if undo_info is None:
raise ChainError('no undo information found for height {:,d}'
.format(self.height))
n = len(undo_info)
# Use local vars for speed in the loops
s_pack = pack
put_utxo = self.utxo_cache.__setitem__
spend_utxo = self.spend_utxo
script_hashX = self.coin.hashX_from_script
touched = self.touched
undo_entry_len = 12 + HASHX_LEN
for tx, tx_hash in reversed(txs):
for idx, txout in enumerate(tx.outputs):
# Spend the TX outputs. Be careful with unspendable
# outputs - we didn't save those in the first place.
hashX = script_hashX(txout.pk_script)
if hashX:
cache_value = spend_utxo(tx_hash, idx)
touched.add(cache_value[:-12])
# Restore the inputs
for txin in reversed(tx.inputs):
if txin.is_generation():
continue
n -= undo_entry_len
undo_item = undo_info[n:n + undo_entry_len]
put_utxo(txin.prev_hash + s_pack('<H', txin.prev_idx),
undo_item)
touched.add(undo_item[:-12])
assert n == 0
self.tx_count -= len(txs)
'''An in-memory UTXO cache, representing all changes to UTXO state
since the last DB flush.
We want to store millions of these in memory for optimal
performance during initial sync, because then it is possible to
spend UTXOs without ever going to the database (other than as an
entry in the address history, and there is only one such entry per
TX not per UTXO). So store them in a Python dictionary with
binary keys and values.
Key: TX_HASH + TX_IDX (32 + 2 = 34 bytes)
Value: HASHX + TX_NUM + VALUE (11 + 4 + 8 = 23 bytes)
That's 57 bytes of raw data in-memory. Python dictionary overhead
means each entry actually uses about 205 bytes of memory. So
almost 5 million UTXOs can fit in 1GB of RAM. There are
approximately 42 million UTXOs on bitcoin mainnet at height
433,000.
Semantics:
add: Add it to the cache dictionary.
spend: Remove it if in the cache dictionary. Otherwise it's
been flushed to the DB. Each UTXO is responsible for two
entries in the DB. Mark them for deletion in the next
cache flush.
The UTXO database format has to be able to do two things efficiently:
1. Given an address be able to list its UTXOs and their values
so its balance can be efficiently computed.
2. When processing transactions, for each prevout spent - a (tx_hash,
idx) pair - we have to be able to remove it from the DB. To send
notifications to clients we also need to know any address it paid
to.
To this end we maintain two "tables", one for each point above:
1. Key: b'u' + address_hashX + tx_idx + tx_num
Value: the UTXO value as a 64-bit unsigned integer
2. Key: b'h' + compressed_tx_hash + tx_idx + tx_num
Value: hashX
The compressed tx hash is just the first few bytes of the hash of
the tx in which the UTXO was created. As this is not unique there
will be potential collisions so tx_num is also in the key. When
looking up a UTXO the prefix space of the compressed hash needs to
be searched and resolved if necessary with the tx_num. The
collision rate is low (<0.1%).
'''
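# A tiny sketch of the two key layouts described above (the concrete values are
# made up; HASHX_LEN is the 11 bytes stated in the docstring):
#
#   from struct import pack
#   tx_hash, hashX = bytes(32), bytes(11)
#   tx_idx, tx_num, amount = 1, 42, 50_000
#   suffix = pack('<H', tx_idx) + pack('<I', tx_num)
#   h_key = b'h' + tx_hash[:4] + suffix   # value stored: hashX
#   u_key = b'u' + hashX + suffix         # value stored: pack('<Q', amount)
#   assert len(h_key) == 11 and len(u_key) == 18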
def spend_utxo(self, tx_hash, tx_idx):
'''Spend a UTXO and return the 23-byte value.
If the UTXO is not in the cache it must be on disk. We store
all UTXOs so not finding one indicates a logic error or DB
corruption.
'''
# Fast track: the UTXO may still be in the cache
idx_packed = pack('<H', tx_idx)
cache_value = self.utxo_cache.pop(tx_hash + idx_packed, None)
if cache_value:
return cache_value
# Spend it from the DB.
# Key: b'h' + compressed_tx_hash + tx_idx + tx_num
# Value: hashX
prefix = b'h' + tx_hash[:4] + idx_packed
candidates = {db_key: hashX for db_key, hashX
in self.db.utxo_db.iterator(prefix=prefix)}
for hdb_key, hashX in candidates.items():
tx_num_packed = hdb_key[-4:]
if len(candidates) > 1:
tx_num, = unpack('<I', tx_num_packed)
hash, height = self.db.fs_tx_hash(tx_num)
if hash != tx_hash:
assert hash is not None # Should always be found
continue
# Key: b'u' + address_hashX + tx_idx + tx_num
# Value: the UTXO value as a 64-bit unsigned integer
udb_key = b'u' + hashX + hdb_key[-6:]
utxo_value_packed = self.db.utxo_db.get(udb_key)
if utxo_value_packed:
# Remove both entries for this UTXO
self.db_deletes.append(hdb_key)
self.db_deletes.append(udb_key)
return hashX + tx_num_packed + utxo_value_packed
raise ChainError('UTXO {} / {:,d} not found in "h" table'
.format(hash_to_hex_str(tx_hash), tx_idx))
async def _process_prefetched_blocks(self):
'''Loop forever processing blocks as they arrive.'''
while True:
if self.height == self.daemon.cached_height():
if not self._caught_up_event.is_set():
await self._first_caught_up()
self._caught_up_event.set()
await self.blocks_event.wait()
self.blocks_event.clear()
if self.reorg_count:
await self.reorg_chain(self.reorg_count)
self.reorg_count = 0
else:
blocks = self.prefetcher.get_prefetched_blocks()
await self.check_and_advance_blocks(blocks)
async def _first_caught_up(self):
self.logger.info(f'caught up to height {self.height}')
# Flush everything but with first_sync->False state.
first_sync = self.db.first_sync
self.db.first_sync = False
await self.flush(True)
if first_sync:
self.logger.info(f'{torba.__version__} synced to '
f'height {self.height:,d}')
# Reopen for serving
await self.db.open_for_serving()
async def _first_open_dbs(self):
await self.db.open_for_sync()
self.height = self.db.db_height
self.tip = self.db.db_tip
self.tx_count = self.db.db_tx_count
# --- External API
async def fetch_and_process_blocks(self, caught_up_event):
'''Fetch, process and index blocks from the daemon.
Sets caught_up_event when first caught up. Flushes to disk
and shuts down cleanly if cancelled.
This is mainly because if, during initial sync, ElectrumX is
asked to shut down when a large number of blocks have been
processed but not written to disk, it should write those to
disk before exiting, as otherwise a significant amount of work
could be lost.
'''
self._caught_up_event = caught_up_event
await self._first_open_dbs()
try:
async with TaskGroup() as group:
await group.spawn(self.prefetcher.main_loop(self.height))
await group.spawn(self._process_prefetched_blocks())
finally:
# Shut down block processing
self.logger.info('flushing to DB for a clean shutdown...')
await self.flush(True)
def force_chain_reorg(self, count):
'''Force a reorg of the given number of blocks.
Returns True if a reorg is queued, False if not caught up.
'''
if self._caught_up_event.is_set():
self.reorg_count = count
self.blocks_event.set()
return True
return False
class DecredBlockProcessor(BlockProcessor):
async def calc_reorg_range(self, count):
start, count = await super().calc_reorg_range(count)
if start > 0:
# A reorg in Decred can invalidate the previous block
start -= 1
count += 1
return start, count
class NamecoinBlockProcessor(BlockProcessor):
def advance_txs(self, txs):
result = super().advance_txs(txs)
tx_num = self.tx_count - len(txs)
script_name_hashX = self.coin.name_hashX_from_script
update_touched = self.touched.update
hashXs_by_tx = []
append_hashXs = hashXs_by_tx.append
for tx, tx_hash in txs:
hashXs = []
append_hashX = hashXs.append
# Add the new UTXOs and associate them with the name script
for idx, txout in enumerate(tx.outputs):
# Get the hashX of the name script. Ignore non-name scripts.
hashX = script_name_hashX(txout.pk_script)
if hashX:
append_hashX(hashX)
append_hashXs(hashXs)
update_touched(hashXs)
tx_num += 1
self.db.history.add_unflushed(hashXs_by_tx, self.tx_count - len(txs))
return result

2290
torba/server/coins.py Normal file

File diff suppressed because it is too large

459
torba/server/daemon.py Normal file

@@ -0,0 +1,459 @@
# Copyright (c) 2016-2017, Neil Booth
#
# All rights reserved.
#
# See the file "LICENCE" for information about the copyright
# and warranty status of this software.
'''Class for handling asynchronous connections to a blockchain
daemon.'''
import asyncio
import itertools
import json
import time
from calendar import timegm
from struct import pack
from time import strptime
import aiohttp
from torba.server.util import hex_to_bytes, class_logger,\
unpack_le_uint16_from, pack_varint
from torba.server.hash import hex_str_to_hash, hash_to_hex_str
from torba.server.tx import DeserializerDecred
from aiorpcx import JSONRPC
class DaemonError(Exception):
'''Raised when the daemon returns an error in its results.'''
class WarmingUpError(Exception):
'''Internal - when the daemon is warming up.'''
class WorkQueueFullError(Exception):
'''Internal - when the daemon's work queue is full.'''
class Daemon(object):
'''Handles connections to a daemon at the given URL.'''
WARMING_UP = -28
id_counter = itertools.count()
def __init__(self, coin, url, max_workqueue=10, init_retry=0.25,
max_retry=4.0):
self.coin = coin
self.logger = class_logger(__name__, self.__class__.__name__)
self.set_url(url)
# Limit concurrent RPC calls to this number.
# See DEFAULT_HTTP_WORKQUEUE in bitcoind, which is typically 16
self.workqueue_semaphore = asyncio.Semaphore(value=max_workqueue)
self.init_retry = init_retry
self.max_retry = max_retry
self._height = None
self.available_rpcs = {}
def set_url(self, url):
'''Set the URLs to the given list, and switch to the first one.'''
urls = url.split(',')
urls = [self.coin.sanitize_url(url) for url in urls]
for n, url in enumerate(urls):
status = '' if n else ' (current)'
logged_url = self.logged_url(url)
self.logger.info(f'daemon #{n + 1} at {logged_url}{status}')
self.url_index = 0
self.urls = urls
def current_url(self):
'''Returns the current daemon URL.'''
return self.urls[self.url_index]
def logged_url(self, url=None):
'''The host and port part, for logging.'''
url = url or self.current_url()
return url[url.rindex('@') + 1:]
def failover(self):
'''Call to fail-over to the next daemon URL.
Returns False if there is only one, otherwise True.
'''
if len(self.urls) > 1:
self.url_index = (self.url_index + 1) % len(self.urls)
self.logger.info(f'failing over to {self.logged_url()}')
return True
return False
def client_session(self):
'''An aiohttp client session.'''
return aiohttp.ClientSession()
async def _send_data(self, data):
async with self.workqueue_semaphore:
async with self.client_session() as session:
async with session.post(self.current_url(), data=data) as resp:
kind = resp.headers.get('Content-Type', None)
if kind == 'application/json':
return await resp.json()
# bitcoind's HTTP protocol "handling" is a bad joke
text = await resp.text()
if 'Work queue depth exceeded' in text:
raise WorkQueueFullError
text = text.strip() or resp.reason
self.logger.error(text)
raise DaemonError(text)
async def _send(self, payload, processor):
'''Send a payload to be converted to JSON.
Handles temporary connection issues. Daemon response errors
are raised through DaemonError.
'''
def log_error(error):
nonlocal last_error_log, retry
now = time.time()
if now - last_error_log > 60:
last_error_log = now
self.logger.error(f'{error} Retrying occasionally...')
if retry == self.max_retry and self.failover():
retry = 0
on_good_message = None
last_error_log = 0
data = json.dumps(payload)
retry = self.init_retry
while True:
try:
result = await self._send_data(data)
result = processor(result)
if on_good_message:
self.logger.info(on_good_message)
return result
except asyncio.TimeoutError:
log_error('timeout error.')
except aiohttp.ServerDisconnectedError:
log_error('disconnected.')
on_good_message = 'connection restored'
except aiohttp.ClientConnectionError:
log_error('connection problem - is your daemon running?')
on_good_message = 'connection restored'
except aiohttp.ClientError as e:
log_error(f'daemon error: {e}')
on_good_message = 'running normally'
except WarmingUpError:
log_error('starting up checking blocks.')
on_good_message = 'running normally'
except WorkQueueFullError:
log_error('work queue full.')
on_good_message = 'running normally'
await asyncio.sleep(retry)
retry = max(min(self.max_retry, retry * 2), self.init_retry)
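# With the defaults above (init_retry=0.25, max_retry=4.0) the sleep schedule
# grows 0.25, 0.5, 1, 2, 4, 4, ... seconds; once retry reaches max_retry, the
# next error makes log_error() fail over to the next daemon URL (when more than
# one is configured) and the backoff restarts from init_retry.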
async def _send_single(self, method, params=None):
'''Send a single request to the daemon.'''
def processor(result):
err = result['error']
if not err:
return result['result']
if err.get('code') == self.WARMING_UP:
raise WarmingUpError
raise DaemonError(err)
payload = {'method': method, 'id': next(self.id_counter)}
if params:
payload['params'] = params
return await self._send(payload, processor)
async def _send_vector(self, method, params_iterable, replace_errs=False):
'''Send several requests of the same method.
The result will be an array of the same length as params_iterable.
If replace_errs is true, any item with an error is returned as None,
otherwise an exception is raised.'''
def processor(result):
errs = [item['error'] for item in result if item['error']]
if any(err.get('code') == self.WARMING_UP for err in errs):
raise WarmingUpError
if not errs or replace_errs:
return [item['result'] for item in result]
raise DaemonError(errs)
payload = [{'method': method, 'params': p, 'id': next(self.id_counter)}
for p in params_iterable]
if payload:
return await self._send(payload, processor)
return []
async def _is_rpc_available(self, method):
'''Return whether given RPC method is available in the daemon.
Results are cached and the daemon will generally not be queried with
the same method more than once.'''
available = self.available_rpcs.get(method)
if available is None:
available = True
try:
await self._send_single(method)
except DaemonError as e:
err = e.args[0]
error_code = err.get("code")
available = error_code != JSONRPC.METHOD_NOT_FOUND
self.available_rpcs[method] = available
return available
async def block_hex_hashes(self, first, count):
'''Return the hex hashes of count blocks starting at height first.'''
params_iterable = ((h, ) for h in range(first, first + count))
return await self._send_vector('getblockhash', params_iterable)
async def deserialised_block(self, hex_hash):
'''Return the deserialised block with the given hex hash.'''
return await self._send_single('getblock', (hex_hash, True))
async def raw_blocks(self, hex_hashes):
'''Return the raw binary blocks with the given hex hashes.'''
params_iterable = ((h, False) for h in hex_hashes)
blocks = await self._send_vector('getblock', params_iterable)
# Convert hex string to bytes
return [hex_to_bytes(block) for block in blocks]
async def mempool_hashes(self):
'''Update our record of the daemon's mempool hashes.'''
return await self._send_single('getrawmempool')
async def estimatefee(self, block_count):
'''Return the fee estimate for the block count. Units are whole
currency units per KB, e.g. 0.00000995, or -1 if no estimate
is available.
'''
args = (block_count, )
if await self._is_rpc_available('estimatesmartfee'):
estimate = await self._send_single('estimatesmartfee', args)
return estimate.get('feerate', -1)
return await self._send_single('estimatefee', args)
async def getnetworkinfo(self):
'''Return the result of the 'getnetworkinfo' RPC call.'''
return await self._send_single('getnetworkinfo')
async def relayfee(self):
'''The minimum fee a low-priority tx must pay in order to be accepted
to the daemon's memory pool.'''
network_info = await self.getnetworkinfo()
return network_info['relayfee']
async def getrawtransaction(self, hex_hash, verbose=False):
'''Return the serialized raw transaction with the given hash.'''
# Cast to int because some coin daemons are old and require it
return await self._send_single('getrawtransaction',
(hex_hash, int(verbose)))
async def getrawtransactions(self, hex_hashes, replace_errs=True):
'''Return the serialized raw transactions with the given hashes.
Replaces errors with None by default.'''
params_iterable = ((hex_hash, 0) for hex_hash in hex_hashes)
txs = await self._send_vector('getrawtransaction', params_iterable,
replace_errs=replace_errs)
# Convert hex strings to bytes
return [hex_to_bytes(tx) if tx else None for tx in txs]
async def broadcast_transaction(self, raw_tx):
'''Broadcast a transaction to the network.'''
return await self._send_single('sendrawtransaction', (raw_tx, ))
async def height(self):
'''Query the daemon for its current height.'''
self._height = await self._send_single('getblockcount')
return self._height
def cached_height(self):
'''Return the cached daemon height.
If the daemon has not been queried yet this returns None.'''
return self._height
class DashDaemon(Daemon):
async def masternode_broadcast(self, params):
'''Broadcast a transaction to the network.'''
return await self._send_single('masternodebroadcast', params)
async def masternode_list(self, params):
'''Return the masternode status.'''
return await self._send_single('masternodelist', params)
class FakeEstimateFeeDaemon(Daemon):
'''Daemon that simulates estimatefee and relayfee RPC calls. Coin that
wants to use this daemon must define ESTIMATE_FEE & RELAY_FEE'''
async def estimatefee(self, block_count):
'''Return the fee estimate for the given parameters.'''
return self.coin.ESTIMATE_FEE
async def relayfee(self):
'''The minimum fee a low-priority tx must pay in order to be accepted
to the daemon's memory pool.'''
return self.coin.RELAY_FEE
class LegacyRPCDaemon(Daemon):
'''Handles connections to a daemon at the given URL.
This class is useful for daemons that don't have the new 'getblock'
RPC call that returns the block in hex; the workaround is to manually
recreate the block bytes. The recreated block bytes may not be exactly
the same as in the underlying blockchain, but they are good enough for
our indexing purposes.'''
async def raw_blocks(self, hex_hashes):
'''Return the raw binary blocks with the given hex hashes.'''
params_iterable = ((h, ) for h in hex_hashes)
block_info = await self._send_vector('getblock', params_iterable)
blocks = []
for i in block_info:
raw_block = await self.make_raw_block(i)
blocks.append(raw_block)
# Convert hex string to bytes
return blocks
async def make_raw_header(self, b):
pbh = b.get('previousblockhash')
if pbh is None:
pbh = '0' * 64
return b''.join([
pack('<L', b.get('version')),
hex_str_to_hash(pbh),
hex_str_to_hash(b.get('merkleroot')),
pack('<L', self.timestamp_safe(b['time'])),
pack('<L', int(b.get('bits'), 16)),
pack('<L', int(b.get('nonce')))
])
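# The fields packed above total 4 + 32 + 32 + 4 + 4 + 4 = 80 bytes, i.e. a
# standard Bitcoin-style block header reconstructed from the daemon's
# deserialised 'getblock' response.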
async def make_raw_block(self, b):
'''Construct a raw block'''
header = await self.make_raw_header(b)
transactions = []
if b.get('height') > 0:
transactions = await self.getrawtransactions(b.get('tx'), False)
raw_block = header
num_txs = len(transactions)
if num_txs > 0:
raw_block += pack_varint(num_txs)
raw_block += b''.join(transactions)
else:
raw_block += b'\x00'
return raw_block
def timestamp_safe(self, t):
if isinstance(t, int):
return t
return timegm(strptime(t, "%Y-%m-%d %H:%M:%S %Z"))
class DecredDaemon(Daemon):
async def raw_blocks(self, hex_hashes):
'''Return the raw binary blocks with the given hex hashes.'''
params_iterable = ((h, False) for h in hex_hashes)
blocks = await self._send_vector('getblock', params_iterable)
raw_blocks = []
valid_tx_tree = {}
for block in blocks:
# Convert to bytes from hex
raw_block = hex_to_bytes(block)
raw_blocks.append(raw_block)
# Check if previous block is valid
prev = self.prev_hex_hash(raw_block)
votebits = unpack_le_uint16_from(raw_block[100:102])[0]
valid_tx_tree[prev] = self.is_valid_tx_tree(votebits)
processed_raw_blocks = []
for hash, raw_block in zip(hex_hashes, raw_blocks):
if hash in valid_tx_tree:
is_valid = valid_tx_tree[hash]
else:
# Do something complicated to figure out if this block is valid
header = await self._send_single('getblockheader', (hash, ))
if 'nextblockhash' not in header:
raise DaemonError(f'Could not find next block for {hash}')
next_hash = header['nextblockhash']
next_header = await self._send_single('getblockheader',
(next_hash, ))
is_valid = self.is_valid_tx_tree(next_header['votebits'])
if is_valid:
processed_raw_blocks.append(raw_block)
else:
# If this block is invalid remove the normal transactions
self.logger.info(f'block {hash} is invalidated')
processed_raw_blocks.append(self.strip_tx_tree(raw_block))
return processed_raw_blocks
@staticmethod
def prev_hex_hash(raw_block):
return hash_to_hex_str(raw_block[4:36])
@staticmethod
def is_valid_tx_tree(votebits):
# Check if previous block was invalidated.
return bool(votebits & (1 << 0) != 0)
def strip_tx_tree(self, raw_block):
c = self.coin
assert issubclass(c.DESERIALIZER, DeserializerDecred)
d = c.DESERIALIZER(raw_block, start=c.BASIC_HEADER_SIZE)
d.read_tx_tree() # Skip normal transactions
# Create a fake block without any normal transactions
return raw_block[:c.BASIC_HEADER_SIZE] + b'\x00' + raw_block[d.cursor:]
async def height(self):
height = await super().height()
if height > 0:
# Lie about the daemon height as the current tip can be invalidated
height -= 1
self._height = height
return height
async def mempool_hashes(self):
mempool = await super().mempool_hashes()
# Add current tip transactions to the 'fake' mempool.
real_height = await self._send_single('getblockcount')
tip_hash = await self._send_single('getblockhash', (real_height,))
tip = await self.deserialised_block(tip_hash)
# Add normal transactions except coinbase
mempool += tip['tx'][1:]
# Add stake transactions if applicable
mempool += tip.get('stx', [])
return mempool
def client_session(self):
# FIXME allow self signed certificates
connector = aiohttp.TCPConnector(verify_ssl=False)
return aiohttp.ClientSession(connector=connector)
class PreLegacyRPCDaemon(LegacyRPCDaemon):
'''Handles connections to a daemon at the given URL.
This class is useful for daemons that don't have the new 'getblock'
RPC call that returns the block in hex, and that need the False
parameter passed to getblock.'''
async def deserialised_block(self, hex_hash):
'''Return the deserialised block with the given hex hash.'''
return await self._send_single('getblock', (hex_hash, False))

665
torba/server/db.py Normal file

@@ -0,0 +1,665 @@
# Copyright (c) 2016, Neil Booth
# Copyright (c) 2017, the ElectrumX authors
#
# All rights reserved.
#
# See the file "LICENCE" for information about the copyright
# and warranty status of this software.
'''Interface to the blockchain database.'''
import array
import ast
import os
import time
from bisect import bisect_right
from collections import namedtuple
from glob import glob
from struct import pack, unpack
import attr
from aiorpcx import run_in_thread, sleep
import torba.server.util as util
from torba.server.hash import hash_to_hex_str, HASHX_LEN
from torba.server.merkle import Merkle, MerkleCache
from torba.server.util import formatted_time
from torba.server.storage import db_class
from torba.server.history import History
UTXO = namedtuple("UTXO", "tx_num tx_pos tx_hash height value")
@attr.s(slots=True)
class FlushData(object):
height = attr.ib()
tx_count = attr.ib()
headers = attr.ib()
block_tx_hashes = attr.ib()
# The following are flushed to the UTXO DB if undo_infos is not None
undo_infos = attr.ib()
adds = attr.ib()
deletes = attr.ib()
tip = attr.ib()
class DB(object):
'''Simple wrapper of the backend database for querying.
Performs no DB update, though the DB will be cleaned on opening if
it was shut down uncleanly.
'''
DB_VERSIONS = [6]
class DBError(Exception):
'''Raised on general DB errors generally indicating corruption.'''
def __init__(self, env):
self.logger = util.class_logger(__name__, self.__class__.__name__)
self.env = env
self.coin = env.coin
# Setup block header size handlers
if self.coin.STATIC_BLOCK_HEADERS:
self.header_offset = self.coin.static_header_offset
self.header_len = self.coin.static_header_len
else:
self.header_offset = self.dynamic_header_offset
self.header_len = self.dynamic_header_len
self.logger.info(f'switching current directory to {env.db_dir}')
os.chdir(env.db_dir)
self.db_class = db_class(self.env.db_engine)
self.history = History()
self.utxo_db = None
self.tx_counts = None
self.last_flush = time.time()
self.logger.info(f'using {self.env.db_engine} for DB backend')
# Header merkle cache
self.merkle = Merkle()
self.header_mc = MerkleCache(self.merkle, self.fs_block_hashes)
self.headers_file = util.LogicalFile('meta/headers', 2, 16000000)
self.tx_counts_file = util.LogicalFile('meta/txcounts', 2, 2000000)
self.hashes_file = util.LogicalFile('meta/hashes', 4, 16000000)
if not self.coin.STATIC_BLOCK_HEADERS:
self.headers_offsets_file = util.LogicalFile(
'meta/headers_offsets', 2, 16000000)
async def _read_tx_counts(self):
if self.tx_counts is not None:
return
# tx_counts[N] has the cumulative number of txs at the end of
# height N. So tx_counts[0] is 1 - the genesis coinbase
size = (self.db_height + 1) * 4
tx_counts = self.tx_counts_file.read(0, size)
assert len(tx_counts) == size
self.tx_counts = array.array('I', tx_counts)
if self.tx_counts:
assert self.db_tx_count == self.tx_counts[-1]
else:
assert self.db_tx_count == 0
async def _open_dbs(self, for_sync, compacting):
assert self.utxo_db is None
# First UTXO DB
self.utxo_db = self.db_class('utxo', for_sync)
if self.utxo_db.is_new:
self.logger.info('created new database')
self.logger.info('creating metadata directory')
os.mkdir('meta')
with util.open_file('COIN', create=True) as f:
f.write(f'ElectrumX databases and metadata for '
f'{self.coin.NAME} {self.coin.NET}'.encode())
if not self.coin.STATIC_BLOCK_HEADERS:
self.headers_offsets_file.write(0, bytes(8))
else:
self.logger.info(f'opened UTXO DB (for sync: {for_sync})')
self.read_utxo_state()
# Then history DB
self.utxo_flush_count = self.history.open_db(self.db_class, for_sync,
self.utxo_flush_count,
compacting)
self.clear_excess_undo_info()
# Read TX counts (requires meta directory)
await self._read_tx_counts()
async def open_for_compacting(self):
await self._open_dbs(True, True)
async def open_for_sync(self):
'''Open the databases to sync to the daemon.
When syncing we want to reserve a lot of open files for the
synchronization. When serving clients we want the open files for
serving network connections.
'''
await self._open_dbs(True, False)
async def open_for_serving(self):
'''Open the databases for serving. If they are already open they are
closed first.
'''
if self.utxo_db:
self.logger.info('closing DBs to re-open for serving')
self.utxo_db.close()
self.history.close_db()
self.utxo_db = None
await self._open_dbs(False, False)
# Header merkle cache
async def populate_header_merkle_cache(self):
self.logger.info('populating header merkle cache...')
length = max(1, self.db_height - self.env.reorg_limit)
start = time.time()
await self.header_mc.initialize(length)
elapsed = time.time() - start
self.logger.info(f'header merkle cache populated in {elapsed:.1f}s')
async def header_branch_and_root(self, length, height):
return await self.header_mc.branch_and_root(length, height)
# Flushing
def assert_flushed(self, flush_data):
'''Asserts state is fully flushed.'''
assert flush_data.tx_count == self.fs_tx_count == self.db_tx_count
assert flush_data.height == self.fs_height == self.db_height
assert flush_data.tip == self.db_tip
assert not flush_data.headers
assert not flush_data.block_tx_hashes
assert not flush_data.adds
assert not flush_data.deletes
assert not flush_data.undo_infos
self.history.assert_flushed()
def flush_dbs(self, flush_data, flush_utxos, estimate_txs_remaining):
'''Flush out cached state. History is always flushed; UTXOs are
flushed if flush_utxos.'''
if flush_data.height == self.db_height:
self.assert_flushed(flush_data)
return
start_time = time.time()
prior_flush = self.last_flush
tx_delta = flush_data.tx_count - self.last_flush_tx_count
# Flush to file system
self.flush_fs(flush_data)
# Then history
self.flush_history()
# Flush state last as it reads the wall time.
with self.utxo_db.write_batch() as batch:
if flush_utxos:
self.flush_utxo_db(batch, flush_data)
self.flush_state(batch)
# Update and put the wall time again - otherwise we drop the
# time it took to commit the batch
self.flush_state(self.utxo_db)
elapsed = self.last_flush - start_time
self.logger.info(f'flush #{self.history.flush_count:,d} took '
f'{elapsed:.1f}s. Height {flush_data.height:,d} '
f'txs: {flush_data.tx_count:,d} ({tx_delta:+,d})')
# Catch-up stats
if self.utxo_db.for_sync:
flush_interval = self.last_flush - prior_flush
tx_per_sec_gen = int(flush_data.tx_count / self.wall_time)
tx_per_sec_last = 1 + int(tx_delta / flush_interval)
eta = estimate_txs_remaining() / tx_per_sec_last
self.logger.info(f'tx/sec since genesis: {tx_per_sec_gen:,d}, '
f'since last flush: {tx_per_sec_last:,d}')
self.logger.info(f'sync time: {formatted_time(self.wall_time)} '
f'ETA: {formatted_time(eta)}')
def flush_fs(self, flush_data):
'''Write headers, tx counts and block tx hashes to the filesystem.
The first height to write is self.fs_height + 1. The FS
metadata is all append-only, so in a crash we just pick up
again from the height stored in the DB.
'''
prior_tx_count = (self.tx_counts[self.fs_height]
if self.fs_height >= 0 else 0)
assert len(flush_data.block_tx_hashes) == len(flush_data.headers)
assert flush_data.height == self.fs_height + len(flush_data.headers)
assert flush_data.tx_count == (self.tx_counts[-1] if self.tx_counts
else 0)
assert len(self.tx_counts) == flush_data.height + 1
hashes = b''.join(flush_data.block_tx_hashes)
flush_data.block_tx_hashes.clear()
assert len(hashes) % 32 == 0
assert len(hashes) // 32 == flush_data.tx_count - prior_tx_count
# Write the headers, tx counts, and tx hashes
start_time = time.time()
height_start = self.fs_height + 1
offset = self.header_offset(height_start)
self.headers_file.write(offset, b''.join(flush_data.headers))
self.fs_update_header_offsets(offset, height_start, flush_data.headers)
flush_data.headers.clear()
offset = height_start * self.tx_counts.itemsize
self.tx_counts_file.write(offset,
self.tx_counts[height_start:].tobytes())
offset = prior_tx_count * 32
self.hashes_file.write(offset, hashes)
self.fs_height = flush_data.height
self.fs_tx_count = flush_data.tx_count
if self.utxo_db.for_sync:
elapsed = time.time() - start_time
self.logger.info(f'flushed filesystem data in {elapsed:.2f}s')
def flush_history(self):
self.history.flush()
def flush_utxo_db(self, batch, flush_data):
'''Flush the cached DB writes and UTXO set to the batch.'''
# Care is needed because the writes generated by flushing the
# UTXO state may have keys in common with our write cache or
# may be in the DB already.
start_time = time.time()
add_count = len(flush_data.adds)
spend_count = len(flush_data.deletes) // 2
# Spends
batch_delete = batch.delete
for key in sorted(flush_data.deletes):
batch_delete(key)
flush_data.deletes.clear()
# New UTXOs
batch_put = batch.put
for key, value in flush_data.adds.items():
# suffix = tx_idx + tx_num
hashX = value[:-12]
suffix = key[-2:] + value[-12:-8]
batch_put(b'h' + key[:4] + suffix, hashX)
batch_put(b'u' + hashX + suffix, value[-8:])
flush_data.adds.clear()
# New undo information
self.flush_undo_infos(batch_put, flush_data.undo_infos)
flush_data.undo_infos.clear()
if self.utxo_db.for_sync:
block_count = flush_data.height - self.db_height
tx_count = flush_data.tx_count - self.db_tx_count
elapsed = time.time() - start_time
self.logger.info(f'flushed {block_count:,d} blocks with '
f'{tx_count:,d} txs, {add_count:,d} UTXO adds, '
f'{spend_count:,d} spends in '
f'{elapsed:.1f}s, committing...')
self.utxo_flush_count = self.history.flush_count
self.db_height = flush_data.height
self.db_tx_count = flush_data.tx_count
self.db_tip = flush_data.tip
def flush_state(self, batch):
'''Flush chain state to the batch.'''
now = time.time()
self.wall_time += now - self.last_flush
self.last_flush = now
self.last_flush_tx_count = self.fs_tx_count
self.write_utxo_state(batch)
def flush_backup(self, flush_data, touched):
'''Like flush_dbs() but when backing up. All UTXOs are flushed.'''
assert not flush_data.headers
assert not flush_data.block_tx_hashes
assert flush_data.height < self.db_height
self.history.assert_flushed()
start_time = time.time()
tx_delta = flush_data.tx_count - self.last_flush_tx_count
self.backup_fs(flush_data.height, flush_data.tx_count)
self.history.backup(touched, flush_data.tx_count)
with self.utxo_db.write_batch() as batch:
self.flush_utxo_db(batch, flush_data)
# Flush state last as it reads the wall time.
self.flush_state(batch)
elapsed = self.last_flush - start_time
self.logger.info(f'backup flush #{self.history.flush_count:,d} took '
f'{elapsed:.1f}s. Height {flush_data.height:,d} '
f'txs: {flush_data.tx_count:,d} ({tx_delta:+,d})')
def fs_update_header_offsets(self, offset_start, height_start, headers):
if self.coin.STATIC_BLOCK_HEADERS:
return
offset = offset_start
offsets = []
for h in headers:
offset += len(h)
offsets.append(pack("<Q", offset))
# For each header we get the offset of the next header, hence we
# start writing from the next height
pos = (height_start + 1) * 8
self.headers_offsets_file.write(pos, b''.join(offsets))
def dynamic_header_offset(self, height):
assert not self.coin.STATIC_BLOCK_HEADERS
offset, = unpack('<Q', self.headers_offsets_file.read(height * 8, 8))
return offset
def dynamic_header_len(self, height):
return self.dynamic_header_offset(height + 1)\
- self.dynamic_header_offset(height)
def backup_fs(self, height, tx_count):
'''Back up during a reorg. This just updates our pointers.'''
self.fs_height = height
self.fs_tx_count = tx_count
# Truncate header_mc: header count is 1 more than the height.
self.header_mc.truncate(height + 1)
async def raw_header(self, height):
'''Return the binary header at the given height.'''
header, n = await self.read_headers(height, 1)
if n != 1:
raise IndexError(f'height {height:,d} out of range')
return header
async def read_headers(self, start_height, count):
'''Requires start_height >= 0, count >= 0. Reads as many headers as
are available starting at start_height up to count. This
would be zero if start_height is beyond self.db_height, for
example.
Returns a (binary, n) pair where binary is the concatenated
binary headers, and n is the count of headers returned.
'''
if start_height < 0 or count < 0:
raise self.DBError(f'{count:,d} headers starting at '
f'{start_height:,d} not on disk')
def read_headers():
# Read some from disk
disk_count = max(0, min(count, self.db_height + 1 - start_height))
if disk_count:
offset = self.header_offset(start_height)
size = self.header_offset(start_height + disk_count) - offset
return self.headers_file.read(offset, size), disk_count
return b'', 0
return await run_in_thread(read_headers)
def fs_tx_hash(self, tx_num):
'''Return a pair (tx_hash, tx_height) for the given tx number.
If the tx_height is not on disk, returns (None, tx_height).'''
tx_height = bisect_right(self.tx_counts, tx_num)
if tx_height > self.db_height:
tx_hash = None
else:
tx_hash = self.hashes_file.read(tx_num * 32, 32)
return tx_hash, tx_height
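# Example: with tx_counts == [1, 3, 7] (cumulative tx totals through heights
# 0..2), tx_num 4 gives bisect_right(...) == 2, so the hash read at file offset
# 4 * 32 belongs to the block at height 2.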
async def fs_block_hashes(self, height, count):
headers_concat, headers_count = await self.read_headers(height, count)
if headers_count != count:
raise self.DBError('only got {:,d} headers starting at {:,d}, not '
'{:,d}'.format(headers_count, height, count))
offset = 0
headers = []
for n in range(count):
hlen = self.header_len(height + n)
headers.append(headers_concat[offset:offset + hlen])
offset += hlen
return [self.coin.header_hash(header) for header in headers]
async def limited_history(self, hashX, *, limit=1000):
'''Return an unpruned, sorted list of (tx_hash, height) tuples of
confirmed transactions that touched the address, earliest in
the blockchain first. Includes both spending and receiving
transactions. By default returns at most 1000 entries. Set
limit to None to get them all.
'''
def read_history():
tx_nums = list(self.history.get_txnums(hashX, limit))
fs_tx_hash = self.fs_tx_hash
return [fs_tx_hash(tx_num) for tx_num in tx_nums]
while True:
history = await run_in_thread(read_history)
if all(hash is not None for hash, height in history):
return history
self.logger.warning(f'limited_history: tx hash '
f'not found (reorg?), retrying...')
await sleep(0.25)
# -- Undo information
def min_undo_height(self, max_height):
'''Returns a height from which we should store undo info.'''
return max_height - self.env.reorg_limit + 1
def undo_key(self, height):
'''DB key for undo information at the given height.'''
return b'U' + pack('>I', height)
def read_undo_info(self, height):
'''Read undo information from a file for the current height.'''
return self.utxo_db.get(self.undo_key(height))
def flush_undo_infos(self, batch_put, undo_infos):
'''undo_infos is a list of (undo_info, height) pairs.'''
for undo_info, height in undo_infos:
batch_put(self.undo_key(height), b''.join(undo_info))
def raw_block_prefix(self):
return 'meta/block'
def raw_block_path(self, height):
return f'{self.raw_block_prefix()}{height:d}'
def read_raw_block(self, height):
'''Returns a raw block read from disk. Raises FileNotFoundError
if the block isn't on-disk.'''
with util.open_file(self.raw_block_path(height)) as f:
return f.read(-1)
def write_raw_block(self, block, height):
'''Write a raw block to disk.'''
with util.open_truncate(self.raw_block_path(height)) as f:
f.write(block)
# Delete old blocks to prevent them accumulating
try:
del_height = self.min_undo_height(height) - 1
os.remove(self.raw_block_path(del_height))
except FileNotFoundError:
pass
def clear_excess_undo_info(self):
'''Clear excess undo info. Only most recent N are kept.'''
prefix = b'U'
min_height = self.min_undo_height(self.db_height)
keys = []
for key, hist in self.utxo_db.iterator(prefix=prefix):
height, = unpack('>I', key[-4:])
if height >= min_height:
break
keys.append(key)
if keys:
with self.utxo_db.write_batch() as batch:
for key in keys:
batch.delete(key)
self.logger.info(f'deleted {len(keys):,d} stale undo entries')
# delete old block files
prefix = self.raw_block_prefix()
paths = [path for path in glob(f'{prefix}[0-9]*')
if len(path) > len(prefix)
and int(path[len(prefix):]) < min_height]
if paths:
for path in paths:
try:
os.remove(path)
except FileNotFoundError:
pass
self.logger.info(f'deleted {len(paths):,d} stale block files')
# -- UTXO database
def read_utxo_state(self):
state = self.utxo_db.get(b'state')
if not state:
self.db_height = -1
self.db_tx_count = 0
self.db_tip = b'\0' * 32
self.db_version = max(self.DB_VERSIONS)
self.utxo_flush_count = 0
self.wall_time = 0
self.first_sync = True
else:
state = ast.literal_eval(state.decode())
if not isinstance(state, dict):
raise self.DBError('failed reading state from DB')
self.db_version = state['db_version']
if self.db_version not in self.DB_VERSIONS:
raise self.DBError('your UTXO DB version is {} but this '
'software only handles versions {}'
.format(self.db_version, self.DB_VERSIONS))
# backwards compat
genesis_hash = state['genesis']
if isinstance(genesis_hash, bytes):
genesis_hash = genesis_hash.decode()
if genesis_hash != self.coin.GENESIS_HASH:
raise self.DBError('DB genesis hash {} does not match coin {}'
.format(genesis_hash,
self.coin.GENESIS_HASH))
self.db_height = state['height']
self.db_tx_count = state['tx_count']
self.db_tip = state['tip']
self.utxo_flush_count = state['utxo_flush_count']
self.wall_time = state['wall_time']
self.first_sync = state['first_sync']
# These are our state as we move ahead of DB state
self.fs_height = self.db_height
self.fs_tx_count = self.db_tx_count
self.last_flush_tx_count = self.fs_tx_count
# Log some stats
self.logger.info('DB version: {:d}'.format(self.db_version))
self.logger.info('coin: {}'.format(self.coin.NAME))
self.logger.info('network: {}'.format(self.coin.NET))
self.logger.info('height: {:,d}'.format(self.db_height))
self.logger.info('tip: {}'.format(hash_to_hex_str(self.db_tip)))
self.logger.info('tx count: {:,d}'.format(self.db_tx_count))
if self.utxo_db.for_sync:
self.logger.info(f'flushing DB cache at {self.env.cache_MB:,d} MB')
if self.first_sync:
self.logger.info('sync time so far: {}'
.format(util.formatted_time(self.wall_time)))
def write_utxo_state(self, batch):
'''Write (UTXO) state to the batch.'''
state = {
'genesis': self.coin.GENESIS_HASH,
'height': self.db_height,
'tx_count': self.db_tx_count,
'tip': self.db_tip,
'utxo_flush_count': self.utxo_flush_count,
'wall_time': self.wall_time,
'first_sync': self.first_sync,
'db_version': self.db_version,
}
batch.put(b'state', repr(state).encode())
def set_flush_count(self, count):
self.utxo_flush_count = count
with self.utxo_db.write_batch() as batch:
self.write_utxo_state(batch)
async def all_utxos(self, hashX):
'''Return all UTXOs for an address sorted in no particular order.'''
def read_utxos():
utxos = []
utxos_append = utxos.append
s_unpack = unpack
# Key: b'u' + address_hashX + tx_idx + tx_num
# Value: the UTXO value as a 64-bit unsigned integer
prefix = b'u' + hashX
for db_key, db_value in self.utxo_db.iterator(prefix=prefix):
tx_pos, tx_num = s_unpack('<HI', db_key[-6:])
value, = unpack('<Q', db_value)
tx_hash, height = self.fs_tx_hash(tx_num)
utxos_append(UTXO(tx_num, tx_pos, tx_hash, height, value))
return utxos
while True:
utxos = await run_in_thread(read_utxos)
if all(utxo.tx_hash is not None for utxo in utxos):
return utxos
self.logger.warning(f'all_utxos: tx hash not '
f'found (reorg?), retrying...')
await sleep(0.25)
async def lookup_utxos(self, prevouts):
'''For each prevout, look it up in the DB and return a (hashX,
value) pair or None if not found.
Used by the mempool code.
'''
def lookup_hashXs():
'''Return (hashX, suffix) pairs, or None if not found,
for each prevout.
'''
def lookup_hashX(tx_hash, tx_idx):
idx_packed = pack('<H', tx_idx)
# Key: b'h' + compressed_tx_hash + tx_idx + tx_num
# Value: hashX
prefix = b'h' + tx_hash[:4] + idx_packed
# Find which entry, if any, the TX_HASH matches.
for db_key, hashX in self.utxo_db.iterator(prefix=prefix):
tx_num_packed = db_key[-4:]
tx_num, = unpack('<I', tx_num_packed)
hash, height = self.fs_tx_hash(tx_num)
if hash == tx_hash:
return hashX, idx_packed + tx_num_packed
return None, None
return [lookup_hashX(*prevout) for prevout in prevouts]
def lookup_utxos(hashX_pairs):
def lookup_utxo(hashX, suffix):
if not hashX:
# This can happen when the daemon is a block ahead
# of us and has mempool txs spending outputs from
# that new block
return None
# Key: b'u' + address_hashX + tx_idx + tx_num
# Value: the UTXO value as a 64-bit unsigned integer
key = b'u' + hashX + suffix
db_value = self.utxo_db.get(key)
if not db_value:
# This can happen if the DB was updated between
# getting the hashXs and getting the UTXOs
return None
value, = unpack('<Q', db_value)
return hashX, value
return [lookup_utxo(*hashX_pair) for hashX_pair in hashX_pairs]
hashX_pairs = await run_in_thread(lookup_hashXs)
return await run_in_thread(lookup_utxos, hashX_pairs)
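# Standalone sketch (made-up values, not part of this commit) of the two key
# layouts referenced in the comments above: the b'h' index maps a truncated
# tx hash plus output index to a hashX, and the b'u' index maps the hashX
# plus the same suffix to the 64-bit UTXO value.
from struct import pack
from torba.server.hash import HASHX_LEN

example_hashX = bytes(HASHX_LEN)                # 11-byte script hash prefix
example_tx_hash = bytes(32)                     # made-up transaction hash
suffix = pack('<H', 1) + pack('<I', 123456)     # tx_idx + tx_num
h_key = b'h' + example_tx_hash[:4] + suffix     # looked up to get the hashX
u_key = b'u' + example_hashX + suffix           # looked up to get the value
assert len(h_key) == 11 and len(u_key) == 1 + HASHX_LEN + 6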

54
torba/server/enum.py Normal file
View file

@@ -0,0 +1,54 @@
# Copyright (c) 2016, Neil Booth
#
# All rights reserved.
#
# See the file "LICENCE" for information about the copyright
# and warranty status of this software.
'''An enum-like type with reverse lookup.
Source: Python Cookbook, http://code.activestate.com/recipes/67107/
'''
class EnumError(Exception):
pass
class Enumeration:
def __init__(self, name, enumList):
self.__doc__ = name
lookup = {}
reverseLookup = {}
i = 0
uniqueNames = set()
uniqueValues = set()
for x in enumList:
if isinstance(x, tuple):
x, i = x
if not isinstance(x, str):
raise EnumError("enum name {} not a string".format(x))
if not isinstance(i, int):
raise EnumError("enum value {} not an integer".format(i))
if x in uniqueNames:
raise EnumError("enum name {} not unique".format(x))
if i in uniqueValues:
raise EnumError("enum value {} not unique".format(x))
uniqueNames.add(x)
uniqueValues.add(i)
lookup[x] = i
reverseLookup[i] = x
i = i + 1
self.lookup = lookup
self.reverseLookup = reverseLookup
def __getattr__(self, attr):
result = self.lookup.get(attr)
if result is None:
raise AttributeError('enumeration has no member {}'.format(attr))
return result
def whatis(self, value):
return self.reverseLookup[value]
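# A brief usage sketch of the reverse lookup described above; Status and its
# members are hypothetical names.
Status = Enumeration('Status', ['GOOD', ('BAD', 10), 'UGLY'])
assert Status.GOOD == 0 and Status.BAD == 10 and Status.UGLY == 11
assert Status.whatis(10) == 'BAD'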

169
torba/server/env.py Normal file
View file

@@ -0,0 +1,169 @@
# Copyright (c) 2016, Neil Booth
#
# All rights reserved.
#
# See the file "LICENCE" for information about the copyright
# and warranty status of this software.
'''Class for handling environment configuration and defaults.'''
import re
import resource
from collections import namedtuple
from ipaddress import ip_address
from torba.server.coins import Coin
from torba.server.env_base import EnvBase
import torba.server.util as lib_util
NetIdentity = namedtuple('NetIdentity', 'host tcp_port ssl_port nick_suffix')
class Env(EnvBase):
'''Wraps environment configuration. Optionally, accepts a Coin class
as first argument to have ElectrumX serve custom coins not part of
the standard distribution.
'''
# Peer discovery
PD_OFF, PD_SELF, PD_ON = range(3)
def __init__(self, coin=None):
super().__init__()
self.obsolete(['UTXO_MB', 'HIST_MB', 'NETWORK'])
self.db_dir = self.required('DB_DIRECTORY')
self.db_engine = self.default('DB_ENGINE', 'leveldb')
self.daemon_url = self.required('DAEMON_URL')
if coin is not None:
assert issubclass(coin, Coin)
self.coin = coin
else:
coin_name = self.required('COIN').strip()
network = self.default('NET', 'mainnet').strip()
self.coin = Coin.lookup_coin_class(coin_name, network)
self.cache_MB = self.integer('CACHE_MB', 1200)
self.host = self.default('HOST', 'localhost')
self.reorg_limit = self.integer('REORG_LIMIT', self.coin.REORG_LIMIT)
# Server stuff
self.tcp_port = self.integer('TCP_PORT', None)
self.ssl_port = self.integer('SSL_PORT', None)
if self.ssl_port:
self.ssl_certfile = self.required('SSL_CERTFILE')
self.ssl_keyfile = self.required('SSL_KEYFILE')
self.rpc_port = self.integer('RPC_PORT', 8000)
self.max_subscriptions = self.integer('MAX_SUBSCRIPTIONS', 10000)
self.banner_file = self.default('BANNER_FILE', None)
self.tor_banner_file = self.default('TOR_BANNER_FILE',
self.banner_file)
self.anon_logs = self.boolean('ANON_LOGS', False)
self.log_sessions = self.integer('LOG_SESSIONS', 3600)
# Peer discovery
self.peer_discovery = self.peer_discovery_enum()
self.peer_announce = self.boolean('PEER_ANNOUNCE', True)
self.force_proxy = self.boolean('FORCE_PROXY', False)
self.tor_proxy_host = self.default('TOR_PROXY_HOST', 'localhost')
self.tor_proxy_port = self.integer('TOR_PROXY_PORT', None)
# The electrum client takes the empty string as unspecified
self.donation_address = self.default('DONATION_ADDRESS', '')
# Server limits to help prevent DoS
self.max_send = self.integer('MAX_SEND', 1000000)
self.max_subs = self.integer('MAX_SUBS', 250000)
self.max_sessions = self.sane_max_sessions()
self.max_session_subs = self.integer('MAX_SESSION_SUBS', 50000)
self.bandwidth_limit = self.integer('BANDWIDTH_LIMIT', 2000000)
self.session_timeout = self.integer('SESSION_TIMEOUT', 600)
self.drop_client = self.custom("DROP_CLIENT", None, re.compile)
# Identities
clearnet_identity = self.clearnet_identity()
tor_identity = self.tor_identity(clearnet_identity)
self.identities = [identity
for identity in (clearnet_identity, tor_identity)
if identity is not None]
def sane_max_sessions(self):
'''Return the maximum number of sessions to permit. Normally this
is MAX_SESSIONS. However, to prevent open file exhaustion, adjust
downwards if running with a small open file rlimit.'''
env_value = self.integer('MAX_SESSIONS', 1000)
nofile_limit = resource.getrlimit(resource.RLIMIT_NOFILE)[0]
# We give the DB 250 files; allow ElectrumX 100 for itself
value = max(0, min(env_value, nofile_limit - 350))
if value < env_value:
self.logger.warning('lowered maximum sessions from {:,d} to {:,d} '
'because your open file limit is {:,d}'
.format(env_value, value, nofile_limit))
return value
def clearnet_identity(self):
host = self.default('REPORT_HOST', None)
if host is None:
return None
try:
ip = ip_address(host)
except ValueError:
bad = (not lib_util.is_valid_hostname(host)
or host.lower() == 'localhost')
else:
bad = (ip.is_multicast or ip.is_unspecified
or (ip.is_private and self.peer_announce))
if bad:
raise self.Error('"{}" is not a valid REPORT_HOST'.format(host))
tcp_port = self.integer('REPORT_TCP_PORT', self.tcp_port) or None
ssl_port = self.integer('REPORT_SSL_PORT', self.ssl_port) or None
if tcp_port == ssl_port:
raise self.Error('REPORT_TCP_PORT and REPORT_SSL_PORT '
'both resolve to {}'.format(tcp_port))
return NetIdentity(
host,
tcp_port,
ssl_port,
''
)
def tor_identity(self, clearnet):
host = self.default('REPORT_HOST_TOR', None)
if host is None:
return None
if not host.endswith('.onion'):
raise self.Error('tor host "{}" must end with ".onion"'
.format(host))
def port(port_kind):
'''Returns the clearnet identity port, if any and not zero,
otherwise the listening port.'''
result = 0
if clearnet:
result = getattr(clearnet, port_kind)
return result or getattr(self, port_kind)
tcp_port = self.integer('REPORT_TCP_PORT_TOR',
port('tcp_port')) or None
ssl_port = self.integer('REPORT_SSL_PORT_TOR',
port('ssl_port')) or None
if tcp_port == ssl_port:
raise self.Error('REPORT_TCP_PORT_TOR and REPORT_SSL_PORT_TOR '
'both resolve to {}'.format(tcp_port))
return NetIdentity(
host,
tcp_port,
ssl_port,
'_tor',
)
def hosts_dict(self):
return {identity.host: {'tcp_port': identity.tcp_port,
'ssl_port': identity.ssl_port}
for identity in self.identities}
def peer_discovery_enum(self):
pd = self.default('PEER_DISCOVERY', 'on').strip().lower()
if pd in ('off', ''):
return self.PD_OFF
elif pd == 'self':
return self.PD_SELF
else:
return self.PD_ON
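# A hedged sketch (hypothetical values, not a working configuration) of the
# minimum environment the constructor above expects; COIN and NET must name
# a coin class and network actually present in torba.server.coins.
import os

os.environ.setdefault('DB_DIRECTORY', '/tmp/torba-db')
os.environ.setdefault('DAEMON_URL', 'http://rpcuser:rpcpass@localhost:8332/')
os.environ.setdefault('COIN', 'BitcoinSegwit')   # assumed coin name
os.environ.setdefault('NET', 'regtest')
# env = Env()  # would now resolve the Coin class and read the values above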

99
torba/server/env_base.py Normal file
View file

@@ -0,0 +1,99 @@
# Copyright (c) 2017, Neil Booth
#
# All rights reserved.
#
# See the file "LICENCE" for information about the copyright
# and warranty status of this software.
'''Class for server environment configuration and defaults.'''
from os import environ
from torba.server.util import class_logger
class EnvBase(object):
'''Wraps environment configuration.'''
class Error(Exception):
pass
def __init__(self):
self.logger = class_logger(__name__, self.__class__.__name__)
self.allow_root = self.boolean('ALLOW_ROOT', False)
self.host = self.default('HOST', 'localhost')
self.rpc_host = self.default('RPC_HOST', 'localhost')
self.loop_policy = self.event_loop_policy()
@classmethod
def default(cls, envvar, default):
return environ.get(envvar, default)
@classmethod
def boolean(cls, envvar, default):
default = 'Yes' if default else ''
return bool(cls.default(envvar, default).strip())
@classmethod
def required(cls, envvar):
value = environ.get(envvar)
if value is None:
raise cls.Error('required envvar {} not set'.format(envvar))
return value
@classmethod
def integer(cls, envvar, default):
value = environ.get(envvar)
if value is None:
return default
try:
return int(value)
except Exception:
raise cls.Error('cannot convert envvar {} value {} to an integer'
.format(envvar, value))
@classmethod
def custom(cls, envvar, default, parse):
value = environ.get(envvar)
if value is None:
return default
try:
return parse(value)
except Exception as e:
raise cls.Error('cannot parse envvar {} value {}'
.format(envvar, value)) from e
@classmethod
def obsolete(cls, envvars):
bad = [envvar for envvar in envvars if environ.get(envvar)]
if bad:
raise cls.Error('remove obsolete environment variables {}'
.format(bad))
def event_loop_policy(self):
policy = self.default('EVENT_LOOP_POLICY', None)
if policy is None:
return None
if policy == 'uvloop':
import uvloop
return uvloop.EventLoopPolicy()
raise self.Error('unknown event loop policy "{}"'.format(policy))
def cs_host(self, *, for_rpc):
'''Returns the 'host' argument to pass to asyncio's create_server
call. The result can be a single host name string, a list of
host name strings, or an empty string to bind to all interfaces.
If for_rpc is True the host to use for the RPC server is returned.
Otherwise the host to use for SSL/TCP servers is returned.
'''
host = self.rpc_host if for_rpc else self.host
result = [part.strip() for part in host.split(',')]
if len(result) == 1:
result = result[0]
# An empty result indicates all interfaces, which we do not
# permit for an RPC server.
if for_rpc and not result:
result = 'localhost'
return result
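# A minimal usage sketch of the parsing helpers above; the variable names are
# hypothetical and assumed to be otherwise unset in the environment.
import os

os.environ['EXAMPLE_PORT'] = '50001'
env = EnvBase()
assert env.integer('EXAMPLE_PORT', None) == 50001
assert env.boolean('EXAMPLE_FLAG', False) is False    # unset, so the default
assert env.default('EXAMPLE_HOST', 'localhost') == 'localhost'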

159
torba/server/hash.py Normal file
View file

@@ -0,0 +1,159 @@
# Copyright (c) 2016-2017, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''Cryptographic hash functions and related classes.'''
import hashlib
import hmac
from torba.server.util import bytes_to_int, int_to_bytes, hex_to_bytes
_sha256 = hashlib.sha256
_sha512 = hashlib.sha512
_new_hash = hashlib.new
_new_hmac = hmac.new
HASHX_LEN = 11
def sha256(x):
'''Simple wrapper of hashlib sha256.'''
return _sha256(x).digest()
def ripemd160(x):
'''Simple wrapper of hashlib ripemd160.'''
h = _new_hash('ripemd160')
h.update(x)
return h.digest()
def double_sha256(x):
'''SHA-256 of SHA-256, as used extensively in bitcoin.'''
return sha256(sha256(x))
def hmac_sha512(key, msg):
'''Use SHA-512 to provide an HMAC.'''
return _new_hmac(key, msg, _sha512).digest()
def hash160(x):
'''RIPEMD-160 of SHA-256.
Used to make bitcoin addresses from pubkeys.'''
return ripemd160(sha256(x))
def hash_to_hex_str(x):
'''Convert a big-endian binary hash to displayed hex string.
Display form of a binary hash is reversed and converted to hex.
'''
return bytes(reversed(x)).hex()
def hex_str_to_hash(x):
'''Convert a displayed hex string to a binary hash.'''
return bytes(reversed(hex_to_bytes(x)))
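# A short sketch of the display convention above: hashes are shown as
# byte-reversed hex, so these round trips hold for any 32-byte hash.
example_hash = double_sha256(b'torba')
assert hash_to_hex_str(example_hash) == example_hash[::-1].hex()
assert hex_str_to_hash(hash_to_hex_str(example_hash)) == example_hash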
class Base58Error(Exception):
'''Exception used for Base58 errors.'''
class Base58(object):
'''Class providing base 58 functionality.'''
chars = '123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz'
assert len(chars) == 58
cmap = {c: n for n, c in enumerate(chars)}
@staticmethod
def char_value(c):
val = Base58.cmap.get(c)
if val is None:
raise Base58Error('invalid base 58 character "{}"'.format(c))
return val
@staticmethod
def decode(txt):
"""Decodes txt into a big-endian bytearray."""
if not isinstance(txt, str):
raise TypeError('a string is required')
if not txt:
raise Base58Error('string cannot be empty')
value = 0
for c in txt:
value = value * 58 + Base58.char_value(c)
result = int_to_bytes(value)
# Prepend leading zero bytes if necessary
count = 0
for c in txt:
if c != '1':
break
count += 1
if count:
result = bytes(count) + result
return result
@staticmethod
def encode(be_bytes):
"""Converts a big-endian bytearray into a base58 string."""
value = bytes_to_int(be_bytes)
txt = ''
while value:
value, mod = divmod(value, 58)
txt += Base58.chars[mod]
for byte in be_bytes:
if byte != 0:
break
txt += '1'
return txt[::-1]
@staticmethod
def decode_check(txt, *, hash_fn=double_sha256):
'''Decodes a Base58Check-encoded string to a payload. The version
byte(s) remain prefixed to the result.'''
be_bytes = Base58.decode(txt)
result, check = be_bytes[:-4], be_bytes[-4:]
if check != hash_fn(result)[:4]:
raise Base58Error('invalid base 58 checksum for {}'.format(txt))
return result
@staticmethod
def encode_check(payload, *, hash_fn=double_sha256):
"""Encodes a payload bytearray (which includes the version byte(s))
into a Base58Check string."""
be_bytes = payload + hash_fn(payload)[:4]
return Base58.encode(be_bytes)
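# A hedged sketch of a Base58Check round trip; the zero version byte and the
# all-0x02 "public key" are made-up values for illustration only.
example_payload = b'\x00' + hash160(b'\x02' * 33)
example_address = Base58.encode_check(example_payload)
assert Base58.decode_check(example_address) == example_payload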

324
torba/server/history.py Normal file
View file

@@ -0,0 +1,324 @@
# Copyright (c) 2016-2018, Neil Booth
# Copyright (c) 2017, the ElectrumX authors
#
# All rights reserved.
#
# See the file "LICENCE" for information about the copyright
# and warranty status of this software.
'''History by script hash (address).'''
import array
import ast
import bisect
import time
from collections import defaultdict
from functools import partial
import torba.server.util as util
from torba.server.util import pack_be_uint16, unpack_be_uint16_from
from torba.server.hash import hash_to_hex_str, HASHX_LEN
class History(object):
DB_VERSIONS = [0]
def __init__(self):
self.logger = util.class_logger(__name__, self.__class__.__name__)
# For history compaction
self.max_hist_row_entries = 12500
self.unflushed = defaultdict(partial(array.array, 'I'))
self.unflushed_count = 0
self.db = None
def open_db(self, db_class, for_sync, utxo_flush_count, compacting):
self.db = db_class('hist', for_sync)
self.read_state()
self.clear_excess(utxo_flush_count)
# An incomplete compaction needs to be cancelled otherwise
# restarting it will corrupt the history
if not compacting:
self._cancel_compaction()
return self.flush_count
def close_db(self):
if self.db:
self.db.close()
self.db = None
def read_state(self):
state = self.db.get(b'state\0\0')
if state:
state = ast.literal_eval(state.decode())
if not isinstance(state, dict):
raise RuntimeError('failed reading state from history DB')
self.flush_count = state['flush_count']
self.comp_flush_count = state.get('comp_flush_count', -1)
self.comp_cursor = state.get('comp_cursor', -1)
self.db_version = state.get('db_version', 0)
else:
self.flush_count = 0
self.comp_flush_count = -1
self.comp_cursor = -1
self.db_version = max(self.DB_VERSIONS)
self.logger.info(f'history DB version: {self.db_version}')
if self.db_version not in self.DB_VERSIONS:
msg = f'this software only handles DB versions {self.DB_VERSIONS}'
self.logger.error(msg)
raise RuntimeError(msg)
self.logger.info(f'flush count: {self.flush_count:,d}')
def clear_excess(self, utxo_flush_count):
# < might happen at end of compaction as both DBs cannot be
# updated atomically
if self.flush_count <= utxo_flush_count:
return
self.logger.info('DB shut down uncleanly. Scanning for '
'excess history flushes...')
keys = []
for key, hist in self.db.iterator(prefix=b''):
flush_id, = unpack_be_uint16_from(key[-2:])
if flush_id > utxo_flush_count:
keys.append(key)
self.logger.info(f'deleting {len(keys):,d} history entries')
self.flush_count = utxo_flush_count
with self.db.write_batch() as batch:
for key in keys:
batch.delete(key)
self.write_state(batch)
self.logger.info('deleted excess history entries')
def write_state(self, batch):
'''Write state to the history DB.'''
state = {
'flush_count': self.flush_count,
'comp_flush_count': self.comp_flush_count,
'comp_cursor': self.comp_cursor,
'db_version': self.db_version,
}
# History entries are not prefixed; the suffix \0\0 ensures we
# look similar to other entries and aren't interfered with
batch.put(b'state\0\0', repr(state).encode())
def add_unflushed(self, hashXs_by_tx, first_tx_num):
unflushed = self.unflushed
count = 0
for tx_num, hashXs in enumerate(hashXs_by_tx, start=first_tx_num):
hashXs = set(hashXs)
for hashX in hashXs:
unflushed[hashX].append(tx_num)
count += len(hashXs)
self.unflushed_count += count
def unflushed_memsize(self):
return len(self.unflushed) * 180 + self.unflushed_count * 4
def assert_flushed(self):
assert not self.unflushed
def flush(self):
start_time = time.time()
self.flush_count += 1
flush_id = pack_be_uint16(self.flush_count)
unflushed = self.unflushed
with self.db.write_batch() as batch:
for hashX in sorted(unflushed):
key = hashX + flush_id
batch.put(key, unflushed[hashX].tobytes())
self.write_state(batch)
count = len(unflushed)
unflushed.clear()
self.unflushed_count = 0
if self.db.for_sync:
elapsed = time.time() - start_time
self.logger.info(f'flushed history in {elapsed:.1f}s '
f'for {count:,d} addrs')
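# Standalone sketch (made-up numbers, not part of this class) of the row
# format flushed above: the key is hashX + a big-endian 2-byte flush id and
# the value is a packed array of 32-bit tx numbers.
example_key = bytes(HASHX_LEN) + pack_be_uint16(1)
assert len(example_key) == HASHX_LEN + 2
example_value = array.array('I', [7, 42, 1337]).tobytes()
recovered = array.array('I')
recovered.frombytes(example_value)
assert list(recovered) == [7, 42, 1337]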
def backup(self, hashXs, tx_count):
# Not certain this is needed, but it doesn't hurt
self.flush_count += 1
nremoves = 0
bisect_left = bisect.bisect_left
with self.db.write_batch() as batch:
for hashX in sorted(hashXs):
deletes = []
puts = {}
for key, hist in self.db.iterator(prefix=hashX, reverse=True):
a = array.array('I')
a.frombytes(hist)
# Remove all history entries >= tx_count
idx = bisect_left(a, tx_count)
nremoves += len(a) - idx
if idx > 0:
puts[key] = a[:idx].tobytes()
break
deletes.append(key)
for key in deletes:
batch.delete(key)
for key, value in puts.items():
batch.put(key, value)
self.write_state(batch)
self.logger.info(f'backing up removed {nremoves:,d} history entries')
def get_txnums(self, hashX, limit=1000):
'''Generator that returns an unpruned, sorted list of tx_nums in the
history of a hashX. Includes both spending and receiving
transactions. By default yields at most 1000 entries. Set
limit to None to get them all. '''
limit = util.resolve_limit(limit)
for key, hist in self.db.iterator(prefix=hashX):
a = array.array('I')
a.frombytes(hist)
for tx_num in a:
if limit == 0:
return
yield tx_num
limit -= 1
#
# History compaction
#
# comp_cursor is a cursor into compaction progress.
# -1: no compaction in progress
# 0-65535: Compaction in progress; all prefixes < comp_cursor have
# been compacted, and later ones have not.
# 65536: compaction complete in-memory but not flushed
#
# comp_flush_count applies during compaction, and is a flush count
# for history with prefix < comp_cursor. flush_count applies
# to still uncompacted history. It is -1 when no compaction is
# taking place. Key suffixes up to and including comp_flush_count
# are used, so a parallel history flush must first increment this
#
# When compaction is complete and the final flush takes place,
# flush_count is reset to comp_flush_count, and comp_flush_count to -1
def _flush_compaction(self, cursor, write_items, keys_to_delete):
'''Flush a single compaction pass as a batch.'''
# Update compaction state
if cursor == 65536:
self.flush_count = self.comp_flush_count
self.comp_cursor = -1
self.comp_flush_count = -1
else:
self.comp_cursor = cursor
# History DB. Flush compacted history and updated state
with self.db.write_batch() as batch:
# Important: delete first! The keyspace may overlap.
for key in keys_to_delete:
batch.delete(key)
for key, value in write_items:
batch.put(key, value)
self.write_state(batch)
def _compact_hashX(self, hashX, hist_map, hist_list,
write_items, keys_to_delete):
'''Compress history for a hashX. hist_list is an ordered list of
the histories to be compressed.'''
# History entries (tx numbers) are 4 bytes each. Distribute
# over rows of up to 50KB in size. A fixed row size means
# future compactions will not need to update the first N - 1
# rows.
max_row_size = self.max_hist_row_entries * 4
full_hist = b''.join(hist_list)
nrows = (len(full_hist) + max_row_size - 1) // max_row_size
if nrows > 4:
self.logger.info('hashX {} is large: {:,d} entries across '
'{:,d} rows'
.format(hash_to_hex_str(hashX),
len(full_hist) // 4, nrows))
# Find what history needs to be written, and what keys need to
# be deleted. Start by assuming all keys are to be deleted,
# and then remove those that are the same on-disk as when
# compacted.
write_size = 0
keys_to_delete.update(hist_map)
for n, chunk in enumerate(util.chunks(full_hist, max_row_size)):
key = hashX + pack_be_uint16(n)
if hist_map.get(key) == chunk:
keys_to_delete.remove(key)
else:
write_items.append((key, chunk))
write_size += len(chunk)
assert n + 1 == nrows
self.comp_flush_count = max(self.comp_flush_count, n)
return write_size
def _compact_prefix(self, prefix, write_items, keys_to_delete):
'''Compact all history entries for hashXs beginning with the
given prefix. Update keys_to_delete and write.'''
prior_hashX = None
hist_map = {}
hist_list = []
key_len = HASHX_LEN + 2
write_size = 0
for key, hist in self.db.iterator(prefix=prefix):
# Ignore non-history entries
if len(key) != key_len:
continue
hashX = key[:-2]
if hashX != prior_hashX and prior_hashX:
write_size += self._compact_hashX(prior_hashX, hist_map,
hist_list, write_items,
keys_to_delete)
hist_map.clear()
hist_list.clear()
prior_hashX = hashX
hist_map[key] = hist
hist_list.append(hist)
if prior_hashX:
write_size += self._compact_hashX(prior_hashX, hist_map, hist_list,
write_items, keys_to_delete)
return write_size
def _compact_history(self, limit):
'''Inner loop of history compaction. Loops until limit bytes have
been processed.
'''
keys_to_delete = set()
write_items = [] # A list of (key, value) pairs
write_size = 0
# Loop over 2-byte prefixes
cursor = self.comp_cursor
while write_size < limit and cursor < 65536:
prefix = pack_be_uint16(cursor)
write_size += self._compact_prefix(prefix, write_items,
keys_to_delete)
cursor += 1
max_rows = self.comp_flush_count + 1
self._flush_compaction(cursor, write_items, keys_to_delete)
self.logger.info('history compaction: wrote {:,d} rows ({:.1f} MB), '
'removed {:,d} rows, largest: {:,d}, {:.1f}% complete'
.format(len(write_items), write_size / 1000000,
len(keys_to_delete), max_rows,
100 * cursor / 65536))
return write_size
def _cancel_compaction(self):
if self.comp_cursor != -1:
self.logger.warning('cancelling in-progress history compaction')
self.comp_flush_count = -1
self.comp_cursor = -1

365
torba/server/mempool.py Normal file
View file

@@ -0,0 +1,365 @@
# Copyright (c) 2016-2018, Neil Booth
#
# All rights reserved.
#
# See the file "LICENCE" for information about the copyright
# and warranty status of this software.
'''Mempool handling.'''
import itertools
import time
from abc import ABC, abstractmethod
from asyncio import Lock
from collections import defaultdict
import attr
from aiorpcx import TaskGroup, run_in_thread, sleep
from torba.server.hash import hash_to_hex_str, hex_str_to_hash
from torba.server.util import class_logger, chunks
from torba.server.db import UTXO
@attr.s(slots=True)
class MemPoolTx(object):
prevouts = attr.ib()
# A pair is a (hashX, value) tuple
in_pairs = attr.ib()
out_pairs = attr.ib()
fee = attr.ib()
size = attr.ib()
@attr.s(slots=True)
class MemPoolTxSummary(object):
hash = attr.ib()
fee = attr.ib()
has_unconfirmed_inputs = attr.ib()
class MemPoolAPI(ABC):
'''A concrete instance of this class is passed to the MemPool object
and used by it to query DB and blockchain state.'''
@abstractmethod
async def height(self):
'''Query bitcoind for its height.'''
@abstractmethod
def cached_height(self):
'''Return the height of bitcoind the last time it was queried,
for any reason, without actually querying it.
'''
@abstractmethod
async def mempool_hashes(self):
'''Query bitcoind for the hashes of all transactions in its
mempool, returned as a list.'''
@abstractmethod
async def raw_transactions(self, hex_hashes):
'''Query bitcoind for the serialized raw transactions with the given
hashes. Missing transactions are returned as None.
hex_hashes is an iterable of hexadecimal hash strings.'''
@abstractmethod
async def lookup_utxos(self, prevouts):
'''Return a list with an entry for each prevout: a (hashX, value)
pair if it is unspent, otherwise None if spent or not found.
prevouts - an iterable of (hash, index) pairs
'''
@abstractmethod
async def on_mempool(self, touched, height):
'''Called each time the mempool is synchronized. touched is a set of
hashXs touched since the previous call. height is the
daemon's height at the time the mempool was obtained.'''
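# A hypothetical, do-nothing MemPoolAPI implementation (not part of this
# commit), shown only to illustrate the surface a backend must provide,
# for example in tests.
class NullMemPoolAPI(MemPoolAPI):

    async def height(self):
        return 0

    def cached_height(self):
        return 0

    async def mempool_hashes(self):
        return []

    async def raw_transactions(self, hex_hashes):
        return [None for _ in hex_hashes]

    async def lookup_utxos(self, prevouts):
        return [None for _ in prevouts]

    async def on_mempool(self, touched, height):
        pass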
class MemPool(object):
'''Representation of the daemon's mempool.
coin - a coin class from coins.py
api - an object implementing MemPoolAPI
Updated regularly in caught-up state. Goal is to enable efficient
response to the calls in the external interface. To that end we
maintain the following maps:
tx: tx_hash -> MemPoolTx
hashXs: hashX -> set of all hashes of txs touching the hashX
'''
def __init__(self, coin, api, refresh_secs=5.0, log_status_secs=120.0):
assert isinstance(api, MemPoolAPI)
self.coin = coin
self.api = api
self.logger = class_logger(__name__, self.__class__.__name__)
self.txs = {}
self.hashXs = defaultdict(set) # None can be a key
self.cached_compact_histogram = []
self.refresh_secs = refresh_secs
self.log_status_secs = log_status_secs
# Prevents mempool refreshes during fee histogram calculation
self.lock = Lock()
async def _logging(self, synchronized_event):
'''Print regular logs of mempool stats.'''
self.logger.info('beginning processing of daemon mempool. '
'This can take some time...')
start = time.time()
await synchronized_event.wait()
elapsed = time.time() - start
self.logger.info(f'synced in {elapsed:.2f}s')
while True:
self.logger.info(f'{len(self.txs):,d} txs '
f'touching {len(self.hashXs):,d} addresses')
await sleep(self.log_status_secs)
await synchronized_event.wait()
async def _refresh_histogram(self, synchronized_event):
while True:
await synchronized_event.wait()
async with self.lock:
# Threaded as can be expensive
await run_in_thread(self._update_histogram, 100_000)
await sleep(self.coin.MEMPOOL_HISTOGRAM_REFRESH_SECS)
def _update_histogram(self, bin_size):
# Build a histogram by fee rate
histogram = defaultdict(int)
for tx in self.txs.values():
histogram[tx.fee // tx.size] += tx.size
# Now compact it. For efficiency, get_fees returns a
# compact histogram with variable bin size. The compact
# histogram is an array of (fee_rate, vsize) values.
# vsize_n is the cumulative virtual size of mempool
# transactions with a fee rate in the interval
# [rate_(n-1), rate_n)], and rate_(n-1) > rate_n.
# Intervals are chosen to create tranches containing at
# least 100kb of transactions
compact = []
cum_size = 0
r = 0  # excess carried over from previously closed bins
for fee_rate, size in sorted(histogram.items(), reverse=True):
cum_size += size
if cum_size + r > bin_size:
compact.append((fee_rate, cum_size))
r += cum_size - bin_size
cum_size = 0
bin_size *= 1.1
self.logger.info(f'compact fee histogram: {compact}')
self.cached_compact_histogram = compact
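# Standalone sketch (made-up fee rates and sizes, not part of this class) of
# the compaction loop above: it collapses a fee_rate -> size histogram into
# (fee_rate, cumulative_size) tranches of roughly bin_size bytes each.
histogram = {50: 40_000, 20: 80_000, 10: 30_000, 1: 500_000}
compact, cum_size, r, bin_size = [], 0, 0, 100_000
for fee_rate, size in sorted(histogram.items(), reverse=True):
    cum_size += size
    if cum_size + r > bin_size:
        compact.append((fee_rate, cum_size))
        r += cum_size - bin_size
        cum_size = 0
    bin_size *= 1.1
assert compact == [(20, 120_000), (1, 530_000)]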
def _accept_transactions(self, tx_map, utxo_map, touched):
'''Accept transactions in tx_map to the mempool if all their inputs
can be found in the existing mempool or a utxo_map from the
DB.
Returns an (unprocessed tx_map, unspent utxo_map) pair.
'''
hashXs = self.hashXs
txs = self.txs
deferred = {}
unspent = set(utxo_map)
# Try to find all prevouts so we can accept the TX
for hash, tx in tx_map.items():
in_pairs = []
try:
for prevout in tx.prevouts:
utxo = utxo_map.get(prevout)
if not utxo:
prev_hash, prev_index = prevout
# Raises KeyError if prev_hash is not in txs
utxo = txs[prev_hash].out_pairs[prev_index]
in_pairs.append(utxo)
except KeyError:
deferred[hash] = tx
continue
# Spend the prevouts
unspent.difference_update(tx.prevouts)
# Save the in_pairs, compute the fee and accept the TX
tx.in_pairs = tuple(in_pairs)
# Avoid negative fees if dealing with generation-like transactions
# because some in_parts would be missing
tx.fee = max(0, (sum(v for _, v in tx.in_pairs) -
sum(v for _, v in tx.out_pairs)))
txs[hash] = tx
for hashX, value in itertools.chain(tx.in_pairs, tx.out_pairs):
touched.add(hashX)
hashXs[hashX].add(hash)
return deferred, {prevout: utxo_map[prevout] for prevout in unspent}
async def _refresh_hashes(self, synchronized_event):
'''Refresh our view of the daemon's mempool.'''
while True:
height = self.api.cached_height()
hex_hashes = await self.api.mempool_hashes()
if height != await self.api.height():
continue
hashes = set(hex_str_to_hash(hh) for hh in hex_hashes)
async with self.lock:
touched = await self._process_mempool(hashes)
synchronized_event.set()
synchronized_event.clear()
await self.api.on_mempool(touched, height)
await sleep(self.refresh_secs)
async def _process_mempool(self, all_hashes):
# Re-sync with the new set of hashes
txs = self.txs
hashXs = self.hashXs
touched = set()
# First handle txs that have disappeared
for tx_hash in set(txs).difference(all_hashes):
tx = txs.pop(tx_hash)
tx_hashXs = set(hashX for hashX, value in tx.in_pairs)
tx_hashXs.update(hashX for hashX, value in tx.out_pairs)
for hashX in tx_hashXs:
hashXs[hashX].remove(tx_hash)
if not hashXs[hashX]:
del hashXs[hashX]
touched.update(tx_hashXs)
# Process new transactions
new_hashes = list(all_hashes.difference(txs))
if new_hashes:
group = TaskGroup()
for hashes in chunks(new_hashes, 200):
coro = self._fetch_and_accept(hashes, all_hashes, touched)
await group.spawn(coro)
tx_map = {}
utxo_map = {}
async for task in group:
deferred, unspent = task.result()
tx_map.update(deferred)
utxo_map.update(unspent)
prior_count = 0
# FIXME: this is not particularly efficient
while tx_map and len(tx_map) != prior_count:
prior_count = len(tx_map)
tx_map, utxo_map = self._accept_transactions(tx_map, utxo_map,
touched)
if tx_map:
self.logger.info(f'{len(tx_map)} txs dropped')
return touched
async def _fetch_and_accept(self, hashes, all_hashes, touched):
'''Fetch a list of mempool transactions.'''
hex_hashes_iter = (hash_to_hex_str(hash) for hash in hashes)
raw_txs = await self.api.raw_transactions(hex_hashes_iter)
def deserialize_txs(): # This function is pure
to_hashX = self.coin.hashX_from_script
deserializer = self.coin.DESERIALIZER
txs = {}
for hash, raw_tx in zip(hashes, raw_txs):
# The daemon may have evicted the tx from its
# mempool or it may have gotten in a block
if not raw_tx:
continue
tx, tx_size = deserializer(raw_tx).read_tx_and_vsize()
# Convert the inputs and outputs into (hashX, value) pairs
# Drop generation-like inputs from MemPoolTx.prevouts
txin_pairs = tuple((txin.prev_hash, txin.prev_idx)
for txin in tx.inputs
if not txin.is_generation())
txout_pairs = tuple((to_hashX(txout.pk_script), txout.value)
for txout in tx.outputs)
txs[hash] = MemPoolTx(txin_pairs, None, txout_pairs,
0, tx_size)
return txs
# Thread this potentially slow operation so as not to block
tx_map = await run_in_thread(deserialize_txs)
# Determine all prevouts not in the mempool, and fetch the
# UTXO information from the database. Failed prevout lookups
# return None - concurrent database updates happen - which is
# relied upon by _accept_transactions. Ignore prevouts that are
# generation-like.
prevouts = tuple(prevout for tx in tx_map.values()
for prevout in tx.prevouts
if prevout[0] not in all_hashes)
utxos = await self.api.lookup_utxos(prevouts)
utxo_map = {prevout: utxo for prevout, utxo in zip(prevouts, utxos)}
return self._accept_transactions(tx_map, utxo_map, touched)
#
# External interface
#
async def keep_synchronized(self, synchronized_event):
'''Keep the mempool synchronized with the daemon.'''
async with TaskGroup() as group:
await group.spawn(self._refresh_hashes(synchronized_event))
await group.spawn(self._refresh_histogram(synchronized_event))
await group.spawn(self._logging(synchronized_event))
async def balance_delta(self, hashX):
'''Return the unconfirmed amount in the mempool for hashX.
Can be positive or negative.
'''
value = 0
if hashX in self.hashXs:
for hash in self.hashXs[hashX]:
tx = self.txs[hash]
value -= sum(v for h168, v in tx.in_pairs if h168 == hashX)
value += sum(v for h168, v in tx.out_pairs if h168 == hashX)
return value
async def compact_fee_histogram(self):
'''Return a compact fee histogram of the current mempool.'''
return self.cached_compact_histogram
async def potential_spends(self, hashX):
'''Return a set of (prev_hash, prev_idx) pairs from mempool
transactions that touch hashX.
None, some or all of these may be spends of the hashX, but all
actual spends of it (in the DB or mempool) will be included.
'''
result = set()
for tx_hash in self.hashXs.get(hashX, ()):
tx = self.txs[tx_hash]
result.update(tx.prevouts)
return result
async def transaction_summaries(self, hashX):
'''Return a list of MemPoolTxSummary objects for the hashX.'''
result = []
for tx_hash in self.hashXs.get(hashX, ()):
tx = self.txs[tx_hash]
has_ui = any(hash in self.txs for hash, idx in tx.prevouts)
result.append(MemPoolTxSummary(tx_hash, tx.fee, has_ui))
return result
async def unordered_UTXOs(self, hashX):
'''Return an unordered list of UTXO named tuples from mempool
transactions that pay to hashX.
This does not consider if any other mempool transactions spend
the outputs.
'''
utxos = []
for tx_hash in self.hashXs.get(hashX, ()):
tx = self.txs.get(tx_hash)
for pos, (hX, value) in enumerate(tx.out_pairs):
if hX == hashX:
utxos.append(UTXO(-1, pos, tx_hash, 0, value))
return utxos

254
torba/server/merkle.py Normal file
View file

@@ -0,0 +1,254 @@
# Copyright (c) 2018, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# and warranty status of this software.
'''Merkle trees, branches, proofs and roots.'''
from math import ceil, log
from aiorpcx import Event
from torba.server.hash import double_sha256
class Merkle(object):
'''Perform merkle tree calculations on binary hashes using a given hash
function.
If the hash count is not even, the final hash is repeated when
calculating the next merkle layer up the tree.
'''
def __init__(self, hash_func=double_sha256):
self.hash_func = hash_func
def tree_depth(self, hash_count):
return self.branch_length(hash_count) + 1
def branch_length(self, hash_count):
'''Return the length of a merkle branch given the number of hashes.'''
if not isinstance(hash_count, int):
raise TypeError('hash_count must be an integer')
if hash_count < 1:
raise ValueError('hash_count must be at least 1')
return ceil(log(hash_count, 2))
def branch_and_root(self, hashes, index, length=None):
'''Return a (merkle branch, merkle_root) pair given hashes, and the
index of one of those hashes.
'''
hashes = list(hashes)
if not isinstance(index, int):
raise TypeError('index must be an integer')
# This also asserts hashes is not empty
if not 0 <= index < len(hashes):
raise ValueError('index out of range')
natural_length = self.branch_length(len(hashes))
if length is None:
length = natural_length
else:
if not isinstance(length, int):
raise TypeError('length must be an integer')
if length < natural_length:
raise ValueError('length out of range')
hash_func = self.hash_func
branch = []
for _ in range(length):
if len(hashes) & 1:
hashes.append(hashes[-1])
branch.append(hashes[index ^ 1])
index >>= 1
hashes = [hash_func(hashes[n] + hashes[n + 1])
for n in range(0, len(hashes), 2)]
return branch, hashes[0]
def root(self, hashes, length=None):
'''Return the merkle root of a non-empty iterable of binary hashes.'''
branch, root = self.branch_and_root(hashes, 0, length)
return root
def root_from_proof(self, hash, branch, index):
'''Return the merkle root given a hash, a merkle branch to it, and
its index in the hashes array.
branch is an iterable sorted deepest to shallowest. If the
returned root is the expected value then the merkle proof is
verified.
The caller should have confirmed the length of the branch with
branch_length(). Unfortunately this is not easily done for
bitcoin transactions as the number of transactions in a block
is unknown to an SPV client.
'''
hash_func = self.hash_func
for elt in branch:
if index & 1:
hash = hash_func(elt + hash)
else:
hash = hash_func(hash + elt)
index >>= 1
if index:
raise ValueError('index out of range for branch')
return hash
def level(self, hashes, depth_higher):
'''Return a level of the merkle tree of hashes the given depth
higher than the bottom row of the original tree.'''
size = 1 << depth_higher
root = self.root
return [root(hashes[n: n + size], depth_higher)
for n in range(0, len(hashes), size)]
def branch_and_root_from_level(self, level, leaf_hashes, index,
depth_higher):
'''Return a (merkle branch, merkle_root) pair when a merkle-tree has a
level cached.
To maximally reduce the amount of data hashed in computing a
merkle branch, cache a tree of depth N at level N // 2.
level is a list of hashes in the middle of the tree (returned
by level())
leaf_hashes are the leaves needed to calculate a partial branch
up to level.
depth_higher is how much higher level is than the leaves of the tree
index is the index in the full list of hashes of the hash whose
merkle branch we want.
'''
if not isinstance(level, list):
raise TypeError("level must be a list")
if not isinstance(leaf_hashes, list):
raise TypeError("leaf_hashes must be a list")
leaf_index = (index >> depth_higher) << depth_higher
leaf_branch, leaf_root = self.branch_and_root(
leaf_hashes, index - leaf_index, depth_higher)
index >>= depth_higher
level_branch, root = self.branch_and_root(level, index)
# Check last so that we know index is in-range
if leaf_root != level[index]:
raise ValueError('leaf hashes inconsistent with level')
return leaf_branch + level_branch, root
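# A short sketch (made-up leaf hashes) tying branch_and_root() and
# root_from_proof() together: the branch for leaf 2 of four hashes
# reproduces the same root.
example_hashes = [double_sha256(bytes([i])) for i in range(4)]
example_branch, example_root = Merkle().branch_and_root(example_hashes, 2)
assert Merkle().root_from_proof(example_hashes[2], example_branch, 2) == example_root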
class MerkleCache(object):
'''A cache to calculate merkle branches efficiently.'''
def __init__(self, merkle, source_func):
'''Initialise a cache of hashes taken from source_func:
async def source_func(index, count):
...
'''
self.merkle = merkle
self.source_func = source_func
self.length = 0
self.depth_higher = 0
self.initialized = Event()
def _segment_length(self):
return 1 << self.depth_higher
def _leaf_start(self, index):
'''Return the index of the first leaf in the segment containing the
hash at the given index; used when calculating a merkle branch.
'''
depth_higher = self.depth_higher
return (index >> depth_higher) << depth_higher
def _level(self, hashes):
return self.merkle.level(hashes, self.depth_higher)
async def _extend_to(self, length):
'''Extend the length of the cache if necessary.'''
if length <= self.length:
return
# Start from the beginning of any final partial segment.
# Retain the value of depth_higher; in practice this is fine
start = self._leaf_start(self.length)
hashes = await self.source_func(start, length - start)
self.level[start >> self.depth_higher:] = self._level(hashes)
self.length = length
async def _level_for(self, length):
'''Return the merkle level for a truncation of the hashes to the
given length.'''
if length == self.length:
return self.level
level = self.level[:length >> self.depth_higher]
leaf_start = self._leaf_start(length)
count = min(self._segment_length(), length - leaf_start)
hashes = await self.source_func(leaf_start, count)
level += self._level(hashes)
return level
async def initialize(self, length):
'''Call to initialize the cache to a source of given length.'''
self.length = length
self.depth_higher = self.merkle.tree_depth(length) // 2
self.level = self._level(await self.source_func(0, length))
self.initialized.set()
def truncate(self, length):
'''Truncate the cache so it covers no more than length underlying
hashes.'''
if not isinstance(length, int):
raise TypeError('length must be an integer')
if length <= 0:
raise ValueError('length must be positive')
if length >= self.length:
return
length = self._leaf_start(length)
self.length = length
self.level[length >> self.depth_higher:] = []
async def branch_and_root(self, length, index):
'''Return a merkle branch and root. Length is the number of
hashes used to calculate the merkle root, index is the position
of the hash to calculate the branch of.
index must be less than length, which must be at least 1.'''
if not isinstance(length, int):
raise TypeError('length must be an integer')
if not isinstance(index, int):
raise TypeError('index must be an integer')
if length <= 0:
raise ValueError('length must be positive')
if index >= length:
raise ValueError('index must be less than length')
await self.initialized.wait()
await self._extend_to(length)
leaf_start = self._leaf_start(index)
count = min(self._segment_length(), length - leaf_start)
leaf_hashes = await self.source_func(leaf_start, count)
if length < self._segment_length():
return self.merkle.branch_and_root(leaf_hashes, index)
level = await self._level_for(length)
return self.merkle.branch_and_root_from_level(
level, leaf_hashes, index, self.depth_higher)
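# A hedged end-to-end sketch of MerkleCache backed by an in-memory
# source_func; the leaf values and the event loop handling are assumptions
# made purely for illustration. Call merkle_cache_demo() to run it.
def merkle_cache_demo():
    import asyncio

    leaves = [double_sha256(bytes([i])) for i in range(8)]

    async def source(start, count):
        return leaves[start:start + count]

    async def check():
        cache = MerkleCache(Merkle(), source)
        await cache.initialize(len(leaves))
        branch, root = await cache.branch_and_root(len(leaves), 3)
        assert Merkle().root_from_proof(leaves[3], branch, 3) == root

    asyncio.get_event_loop().run_until_complete(check())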

301
torba/server/peer.py Normal file
View file

@@ -0,0 +1,301 @@
# Copyright (c) 2017, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''Representation of a peer server.'''
from ipaddress import ip_address
from torba.server.util import cachedproperty
import torba.server.util as util
from typing import Dict
class Peer(object):
# Protocol version
ATTRS = ('host', 'features',
# metadata
'source', 'ip_addr',
'last_good', 'last_try', 'try_count')
FEATURES = ('pruning', 'server_version', 'protocol_min', 'protocol_max',
'ssl_port', 'tcp_port')
# This should be set by the application
DEFAULT_PORTS: Dict[str, int] = {}
def __init__(self, host, features, source='unknown', ip_addr=None,
last_good=0, last_try=0, try_count=0):
'''Create a peer given a host name (or IP address as a string),
a dictionary of features, and a record of the source.'''
assert isinstance(host, str)
assert isinstance(features, dict)
assert host in features.get('hosts', {})
self.host = host
self.features = features.copy()
# Canonicalize / clean-up
for feature in self.FEATURES:
self.features[feature] = getattr(self, feature)
# Metadata
self.source = source
self.ip_addr = ip_addr
# last_good represents the last connection that was
# successful *and* successfully verified, at which point
# try_count is set to 0. Failure to connect or failure to
# verify increment the try_count.
self.last_good = last_good
self.last_try = last_try
self.try_count = try_count
# Transient, non-persisted metadata
self.bad = False
self.other_port_pairs = set()
@classmethod
def peers_from_features(cls, features, source):
peers = []
if isinstance(features, dict):
hosts = features.get('hosts')
if isinstance(hosts, dict):
peers = [Peer(host, features, source=source)
for host in hosts if isinstance(host, str)]
return peers
@classmethod
def deserialize(cls, item):
'''Deserialize from a dictionary.'''
return cls(**item)
def matches(self, peers):
'''Return peers whose host matches our hostname or IP address.
Additionally include all peers whose IP address matches our
hostname if that is an IP address.
'''
candidates = (self.host.lower(), self.ip_addr)
return [peer for peer in peers
if peer.host.lower() in candidates
or peer.ip_addr == self.host]
def __str__(self):
return self.host
def update_features(self, features):
'''Update features in-place.'''
try:
tmp = Peer(self.host, features)
except Exception:
pass
else:
self.update_features_from_peer(tmp)
def update_features_from_peer(self, peer):
if peer != self:
self.features = peer.features
for feature in self.FEATURES:
setattr(self, feature, getattr(peer, feature))
def connection_port_pairs(self):
'''Return a list of (kind, port) pairs to try when making a
connection.'''
# Use a list not a set - it's important to try the registered
# ports first.
pairs = [('SSL', self.ssl_port), ('TCP', self.tcp_port)]
while self.other_port_pairs:
pairs.append(self.other_port_pairs.pop())
return [pair for pair in pairs if pair[1]]
def mark_bad(self):
'''Mark as bad to avoid reconnects but also to remember for a
while.'''
self.bad = True
def check_ports(self, other):
'''Remember differing ports in case server operator changed them
or removed one.'''
if other.ssl_port != self.ssl_port:
self.other_port_pairs.add(('SSL', other.ssl_port))
if other.tcp_port != self.tcp_port:
self.other_port_pairs.add(('TCP', other.tcp_port))
return bool(self.other_port_pairs)
@cachedproperty
def is_tor(self):
return self.host.endswith('.onion')
@cachedproperty
def is_valid(self):
ip = self.ip_address
if ip:
return ((ip.is_global or ip.is_private)
and not (ip.is_multicast or ip.is_unspecified))
return util.is_valid_hostname(self.host)
@cachedproperty
def is_public(self):
ip = self.ip_address
if ip:
return self.is_valid and not ip.is_private
else:
return self.is_valid and self.host != 'localhost'
@cachedproperty
def ip_address(self):
'''The host as a python ip_address object, or None.'''
try:
return ip_address(self.host)
except ValueError:
return None
def bucket(self):
if self.is_tor:
return 'onion'
if not self.ip_addr:
return ''
return tuple(self.ip_addr.split('.')[:2])
def serialize(self):
'''Serialize to a dictionary.'''
return {attr: getattr(self, attr) for attr in self.ATTRS}
def _port(self, key):
hosts = self.features.get('hosts')
if isinstance(hosts, dict):
host = hosts.get(self.host)
port = self._integer(key, host)
if port and 0 < port < 65536:
return port
return None
def _integer(self, key, d=None):
d = d or self.features
result = d.get(key) if isinstance(d, dict) else None
if isinstance(result, str):
try:
result = int(result)
except ValueError:
pass
return result if isinstance(result, int) else None
def _string(self, key):
result = self.features.get(key)
return result if isinstance(result, str) else None
@cachedproperty
def genesis_hash(self):
'''Returns the genesis block hash as a string if known, otherwise None.'''
return self._string('genesis_hash')
@cachedproperty
def ssl_port(self):
'''Returns None if no SSL port, otherwise the port as an integer.'''
return self._port('ssl_port')
@cachedproperty
def tcp_port(self):
'''Returns None if no TCP port, otherwise the port as an integer.'''
return self._port('tcp_port')
@cachedproperty
def server_version(self):
'''Returns the server version as a string if known, otherwise None.'''
return self._string('server_version')
@cachedproperty
def pruning(self):
'''Returns the pruning level as an integer. None indicates no
pruning.'''
pruning = self._integer('pruning')
if pruning and pruning > 0:
return pruning
return None
def _protocol_version_string(self, key):
version_str = self.features.get(key)
ptuple = util.protocol_tuple(version_str)
return util.version_string(ptuple)
@cachedproperty
def protocol_min(self):
'''Minimum protocol version as a string, e.g., 1.0'''
return self._protocol_version_string('protocol_min')
@cachedproperty
def protocol_max(self):
'''Maximum protocol version as a string, e.g., 1.1'''
return self._protocol_version_string('protocol_max')
def to_tuple(self):
'''The tuple (ip_addr, host, details) expected in response
to a peers subscription.'''
details = self.real_name().split()[1:]
return (self.ip_addr or self.host, self.host, details)
def real_name(self):
'''Real name of this peer as used on IRC.'''
def port_text(letter, port):
if port == self.DEFAULT_PORTS.get(letter):
return letter
else:
return letter + str(port)
parts = [self.host, 'v' + self.protocol_max]
if self.pruning:
parts.append('p{:d}'.format(self.pruning))
for letter, port in (('s', self.ssl_port), ('t', self.tcp_port)):
if port:
parts.append(port_text(letter, port))
return ' '.join(parts)
@classmethod
def from_real_name(cls, real_name, source):
'''Real name is a real name as on IRC, such as
"erbium1.sytes.net v1.0 s t"
Returns an instance of this Peer class.
'''
host = 'nohost'
features = {}
ports = {}
for n, part in enumerate(real_name.split()):
if n == 0:
host = part
continue
if part[0] in ('s', 't'):
if len(part) == 1:
port = cls.DEFAULT_PORTS[part[0]]
else:
port = part[1:]
if part[0] == 's':
ports['ssl_port'] = port
else:
ports['tcp_port'] = port
elif part[0] == 'v':
features['protocol_max'] = features['protocol_min'] = part[1:]
elif part[0] == 'p':
features['pruning'] = part[1:]
features.update(ports)
features['hosts'] = {host: ports}
return cls(host, features, source)
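# A hypothetical round-trip sketch of the "real name" format parsed above;
# the host, version, pruning level and ports are made-up values, and
# DEFAULT_PORTS is assigned here only because the application is expected
# to provide it.
Peer.DEFAULT_PORTS = {'s': 50002, 't': 50001}
example_peer = Peer.from_real_name('example.com v1.1 p100 s t50005', 'example')
assert example_peer.ssl_port == 50002 and example_peer.tcp_port == 50005
assert example_peer.real_name() == 'example.com v1.1 p100 s t50005'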

510
torba/server/peers.py Normal file
View file

@@ -0,0 +1,510 @@
# Copyright (c) 2017-2018, Neil Booth
#
# All rights reserved.
#
# See the file "LICENCE" for information about the copyright
# and warranty status of this software.
'''Peer management.'''
import asyncio
import random
import socket
import ssl
import time
from collections import defaultdict, Counter
from aiorpcx import (Connector, RPCSession, SOCKSProxy,
Notification, handler_invocation,
SOCKSError, RPCError, TaskTimeout, TaskGroup, Event,
sleep, run_in_thread, ignore_after, timeout_after)
from torba.server.peer import Peer
from torba.server.util import class_logger, protocol_tuple
PEER_GOOD, PEER_STALE, PEER_NEVER, PEER_BAD = range(4)
STALE_SECS = 24 * 3600
WAKEUP_SECS = 300
class BadPeerError(Exception):
pass
def assert_good(message, result, instance):
if not isinstance(result, instance):
raise BadPeerError(f'{message} returned bad result type '
f'{type(result).__name__}')
class PeerSession(RPCSession):
'''An outgoing session to a peer.'''
async def handle_request(self, request):
# We subscribe so might be unlucky enough to get a notification...
if (isinstance(request, Notification) and
request.method == 'blockchain.headers.subscribe'):
pass
else:
await handler_invocation(None, request) # Raises
class PeerManager(object):
'''Looks after the DB of peer network servers.
Attempts to maintain a connection with up to 8 peers.
Issues a 'peers.subscribe' RPC to them and tells them our data.
'''
def __init__(self, env, db):
self.logger = class_logger(__name__, self.__class__.__name__)
# Initialise the Peer class
Peer.DEFAULT_PORTS = env.coin.PEER_DEFAULT_PORTS
self.env = env
self.db = db
# Our clearnet and Tor Peers, if any
sclass = env.coin.SESSIONCLS
self.myselves = [Peer(ident.host, sclass.server_features(env), 'env')
for ident in env.identities]
self.server_version_args = sclass.server_version_args()
# Peers have one entry per hostname. Once connected, the
# ip_addr property is either None, an onion peer, or the
# IP address that was connected to. Adding a peer will evict
# any other peers with the same host name or IP address.
self.peers = set()
self.permit_onion_peer_time = time.time()
self.proxy = None
self.group = TaskGroup()
def _my_clearnet_peer(self):
'''Returns the clearnet peer representing this server, if any.'''
clearnet = [peer for peer in self.myselves if not peer.is_tor]
return clearnet[0] if clearnet else None
def _set_peer_statuses(self):
'''Set peer statuses.'''
cutoff = time.time() - STALE_SECS
for peer in self.peers:
if peer.bad:
peer.status = PEER_BAD
elif peer.last_good > cutoff:
peer.status = PEER_GOOD
elif peer.last_good:
peer.status = PEER_STALE
else:
peer.status = PEER_NEVER
def _features_to_register(self, peer, remote_peers):
'''If we should register ourselves to the remote peer, which has
reported the given list of known peers, return the clearnet
identity features to register, otherwise None.
'''
# Announce ourself if not present. Don't if disabled, we
# are a non-public IP address, or to ourselves.
if not self.env.peer_announce or peer in self.myselves:
return None
my = self._my_clearnet_peer()
if not my or not my.is_public:
return None
# Register if no matches, or ports have changed
for peer in my.matches(remote_peers):
if peer.tcp_port == my.tcp_port and peer.ssl_port == my.ssl_port:
return None
return my.features
def _permit_new_onion_peer(self):
'''Accept a new onion peer only once per random time interval.'''
now = time.time()
if now < self.permit_onion_peer_time:
return False
self.permit_onion_peer_time = now + random.randrange(0, 1200)
return True
async def _import_peers(self):
'''Import hard-coded peers from a file or the coin defaults.'''
imported_peers = self.myselves.copy()
# Add the hard-coded ones unless only reporting ourself
if self.env.peer_discovery != self.env.PD_SELF:
imported_peers.extend(Peer.from_real_name(real_name, 'coins.py')
for real_name in self.env.coin.PEERS)
await self._note_peers(imported_peers, limit=None)
async def _detect_proxy(self):
'''Detect a proxy if we don't have one and some time has passed since
the last attempt.
If found self.proxy is set to a SOCKSProxy instance, otherwise
None.
'''
host = self.env.tor_proxy_host
if self.env.tor_proxy_port is None:
ports = [9050, 9150, 1080]
else:
ports = [self.env.tor_proxy_port]
while True:
self.logger.info(f'trying to detect proxy on "{host}" '
f'ports {ports}')
proxy = await SOCKSProxy.auto_detect_host(host, ports, None)
if proxy:
self.proxy = proxy
self.logger.info(f'detected {proxy}')
return
self.logger.info('no proxy detected, will try later')
await sleep(900)
async def _note_peers(self, peers, limit=2, check_ports=False,
source=None):
'''Add a limited number of peers that are not already present.'''
new_peers = []
for peer in peers:
if not peer.is_public or (peer.is_tor and not self.proxy):
continue
matches = peer.matches(self.peers)
if not matches:
new_peers.append(peer)
elif check_ports:
for match in matches:
if match.check_ports(peer):
self.logger.info(f'ports changed for {peer}')
match.retry_event.set()
if new_peers:
source = source or new_peers[0].source
if limit:
random.shuffle(new_peers)
use_peers = new_peers[:limit]
else:
use_peers = new_peers
for peer in use_peers:
self.logger.info(f'accepted new peer {peer} from {source}')
peer.retry_event = Event()
self.peers.add(peer)
await self.group.spawn(self._monitor_peer(peer))
async def _monitor_peer(self, peer):
# Stop monitoring if we were dropped (a duplicate peer)
while peer in self.peers:
if await self._should_drop_peer(peer):
self.peers.discard(peer)
break
# Figure out how long to sleep before retrying. Retry a
# good connection when it is about to turn stale, otherwise
# exponentially back off retries.
if peer.try_count == 0:
pause = STALE_SECS - WAKEUP_SECS * 2
else:
pause = WAKEUP_SECS * 2 ** peer.try_count
async with ignore_after(pause):
await peer.retry_event.wait()
peer.retry_event.clear()
async def _should_drop_peer(self, peer):
peer.try_count += 1
is_good = False
for kind, port in peer.connection_port_pairs():
peer.last_try = time.time()
kwargs = {}
if kind == 'SSL':
kwargs['ssl'] = ssl.SSLContext(ssl.PROTOCOL_TLS)
host = self.env.cs_host(for_rpc=False)
if isinstance(host, list):
host = host[0]
if self.env.force_proxy or peer.is_tor:
if not self.proxy:
return
kwargs['proxy'] = self.proxy
kwargs['resolve'] = not peer.is_tor
elif host:
# Use our listening Host/IP for outgoing non-proxy
# connections so our peers see the correct source.
kwargs['local_addr'] = (host, None)
peer_text = f'[{peer}:{port} {kind}]'
try:
async with timeout_after(120 if peer.is_tor else 30):
async with Connector(PeerSession, peer.host, port,
**kwargs) as session:
await self._verify_peer(session, peer)
is_good = True
break
except BadPeerError as e:
self.logger.error(f'{peer_text} marking bad: ({e})')
peer.mark_bad()
break
except RPCError as e:
self.logger.error(f'{peer_text} RPC error: {e.message} '
f'({e.code})')
except (OSError, SOCKSError, ConnectionError, TaskTimeout) as e:
self.logger.info(f'{peer_text} {e}')
if is_good:
now = time.time()
elapsed = now - peer.last_try
self.logger.info(f'{peer_text} verified in {elapsed:.1f}s')
peer.try_count = 0
peer.last_good = now
peer.source = 'peer'
# At most 2 matches if we're a host name, potentially
# several if we're an IP address (several instances
# can share a NAT).
matches = peer.matches(self.peers)
for match in matches:
if match.ip_address:
if len(matches) > 1:
self.peers.remove(match)
# Force the peer's monitoring task to exit
match.retry_event.set()
elif peer.host in match.features['hosts']:
match.update_features_from_peer(peer)
else:
# Forget the peer if long-term unreachable
if peer.last_good and not peer.bad:
try_limit = 10
else:
try_limit = 3
if peer.try_count >= try_limit:
desc = 'bad' if peer.bad else 'unreachable'
self.logger.info(f'forgetting {desc} peer: {peer}')
return True
return False
async def _verify_peer(self, session, peer):
if not peer.is_tor:
address = session.peer_address()
if address:
peer.ip_addr = address[0]
# server.version goes first
message = 'server.version'
result = await session.send_request(message, self.server_version_args)
assert_good(message, result, list)
# Protocol version 1.1 returns a pair with the version first
if len(result) != 2 or not all(isinstance(x, str) for x in result):
raise BadPeerError(f'bad server.version result: {result}')
server_version, protocol_version = result
peer.server_version = server_version
peer.features['server_version'] = server_version
ptuple = protocol_tuple(protocol_version)
async with TaskGroup() as g:
await g.spawn(self._send_headers_subscribe(session, peer, ptuple))
await g.spawn(self._send_server_features(session, peer))
await g.spawn(self._send_peers_subscribe(session, peer))
async def _send_headers_subscribe(self, session, peer, ptuple):
message = 'blockchain.headers.subscribe'
result = await session.send_request(message)
assert_good(message, result, dict)
our_height = self.db.db_height
if ptuple < (1, 3):
their_height = result.get('block_height')
else:
their_height = result.get('height')
if not isinstance(their_height, int):
raise BadPeerError(f'invalid height {their_height}')
if abs(our_height - their_height) > 5:
raise BadPeerError(f'bad height {their_height:,d} '
f'(ours: {our_height:,d})')
# Check prior header too in case of hard fork.
check_height = min(our_height, their_height)
raw_header = await self.db.raw_header(check_height)
if ptuple >= (1, 4):
ours = raw_header.hex()
message = 'blockchain.block.header'
theirs = await session.send_request(message, [check_height])
assert_good(message, theirs, str)
if ours != theirs:
raise BadPeerError(f'our header {ours} and '
f'theirs {theirs} differ')
else:
ours = self.env.coin.electrum_header(raw_header, check_height)
ours = ours.get('prev_block_hash')
message = 'blockchain.block.get_header'
theirs = await session.send_request(message, [check_height])
assert_good(message, theirs, dict)
theirs = theirs.get('prev_block_hash')
if ours != theirs:
raise BadPeerError(f'our header hash {ours} and '
f'theirs {theirs} differ')
async def _send_server_features(self, session, peer):
message = 'server.features'
features = await session.send_request(message)
assert_good(message, features, dict)
hosts = [host.lower() for host in features.get('hosts', {})]
if self.env.coin.GENESIS_HASH != features.get('genesis_hash'):
raise BadPeerError('incorrect genesis hash')
elif peer.host.lower() in hosts:
peer.update_features(features)
else:
raise BadPeerError(f'not listed in own hosts list {hosts}')
async def _send_peers_subscribe(self, session, peer):
message = 'server.peers.subscribe'
raw_peers = await session.send_request(message)
assert_good(message, raw_peers, list)
# Check the peers list we got from a remote peer.
# Each is expected to be of the form:
# [ip_addr, hostname, ['v1.0', 't51001', 's51002']]
# Call add_peer if the remote doesn't appear to know about us.
try:
real_names = [' '.join([u[1]] + u[2]) for u in raw_peers]
peers = [Peer.from_real_name(real_name, str(peer))
for real_name in real_names]
except Exception:
raise BadPeerError('bad server.peers.subscribe response')
await self._note_peers(peers)
features = self._features_to_register(peer, peers)
if not features:
return
self.logger.info(f'registering ourself with {peer}')
# We only care to wait for the response
await session.send_request('server.add_peer', [features])
#
# External interface
#
async def discover_peers(self):
'''Perform peer maintenance. This includes
1) Forgetting unreachable peers.
2) Verifying connectivity of new peers.
3) Retrying old peers at regular intervals.
'''
if self.env.peer_discovery != self.env.PD_ON:
self.logger.info('peer discovery is disabled')
return
self.logger.info(f'beginning peer discovery. Force use of '
f'proxy: {self.env.force_proxy}')
forever = Event()
async with self.group as group:
await group.spawn(forever.wait())
await group.spawn(self._detect_proxy())
await group.spawn(self._import_peers())
# Consume tasks as they complete, logging unexpected failures
async for task in group:
if not task.cancelled():
try:
task.result()
except Exception:
self.logger.exception('task failed unexpectedly')
def info(self):
'''The number of peers.'''
self._set_peer_statuses()
counter = Counter(peer.status for peer in self.peers)
return {
'bad': counter[PEER_BAD],
'good': counter[PEER_GOOD],
'never': counter[PEER_NEVER],
'stale': counter[PEER_STALE],
'total': len(self.peers),
}
async def add_localRPC_peer(self, real_name):
'''Add a peer passed by the admin over LocalRPC.'''
await self._note_peers([Peer.from_real_name(real_name, 'RPC')])
async def on_add_peer(self, features, source_info):
'''Add a peer (but only if the peer resolves to the source).'''
if not source_info:
self.logger.info('ignored add_peer request: no source info')
return False
source = source_info[0]
peers = Peer.peers_from_features(features, source)
if not peers:
self.logger.info('ignored add_peer request: no peers given')
return False
# Just look at the first peer, require it
peer = peers[0]
host = peer.host
if peer.is_tor:
permit = self._permit_new_onion_peer()
reason = 'rate limiting'
else:
getaddrinfo = asyncio.get_event_loop().getaddrinfo
try:
infos = await getaddrinfo(host, 80, type=socket.SOCK_STREAM)
except socket.gaierror:
permit = False
reason = 'address resolution failure'
else:
permit = any(source == info[-1][0] for info in infos)
reason = 'source-destination mismatch'
if permit:
self.logger.info(f'accepted add_peer request from {source} '
f'for {host}')
await self._note_peers([peer], check_ports=True)
else:
self.logger.warning(f'rejected add_peer request from {source} '
f'for {host} ({reason})')
return permit
def on_peers_subscribe(self, is_tor):
'''Returns the server peers as a list of (ip, host, details) tuples.
We return all peers we've connected to in the last day.
Additionally, if we don't have onion routing, we return a few
hard-coded onion servers.
'''
cutoff = time.time() - STALE_SECS
recent = [peer for peer in self.peers
if peer.last_good > cutoff and
not peer.bad and peer.is_public]
onion_peers = []
# Always report ourselves if valid (even if not public)
peers = set(myself for myself in self.myselves
if myself.last_good > cutoff)
# Bucket the clearnet peers and select up to two from each
buckets = defaultdict(list)
for peer in recent:
if peer.is_tor:
onion_peers.append(peer)
else:
buckets[peer.bucket()].append(peer)
for bucket_peers in buckets.values():
random.shuffle(bucket_peers)
peers.update(bucket_peers[:2])
# Add up to 20% onion peers (but up to 10 is OK anyway)
random.shuffle(onion_peers)
max_onion = 50 if is_tor else max(10, len(peers) // 4)
peers.update(onion_peers[:max_onion])
return [peer.to_tuple() for peer in peers]
def proxy_peername(self):
'''Return the peername of the proxy, if there is a proxy, otherwise
None.'''
return self.proxy.peername if self.proxy else None
def rpc_data(self):
'''Peer data for the peers RPC method.'''
self._set_peer_statuses()
descs = ['good', 'stale', 'never', 'bad']
def peer_data(peer):
data = peer.serialize()
data['status'] = descs[peer.status]
return data
def peer_key(peer):
return (peer.bad, -peer.last_good)
return [peer_data(peer) for peer in sorted(self.peers, key=peer_key)]

251
torba/server/script.py Normal file
View file

@@ -0,0 +1,251 @@
# Copyright (c) 2016-2017, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# and warranty status of this software.
'''Script-related classes and functions.'''
import struct
from collections import namedtuple
from torba.server.enum import Enumeration
from torba.server.hash import hash160
from torba.server.util import unpack_le_uint16_from, unpack_le_uint32_from, \
pack_le_uint16, pack_le_uint32
class ScriptError(Exception):
'''Exception used for script errors.'''
class PubKeyError(ScriptError):
'''Exception raised for invalid public keys (see validate_pubkey below).'''
OpCodes = Enumeration("Opcodes", [
("OP_0", 0), ("OP_PUSHDATA1", 76),
"OP_PUSHDATA2", "OP_PUSHDATA4", "OP_1NEGATE",
"OP_RESERVED",
"OP_1", "OP_2", "OP_3", "OP_4", "OP_5", "OP_6", "OP_7", "OP_8",
"OP_9", "OP_10", "OP_11", "OP_12", "OP_13", "OP_14", "OP_15", "OP_16",
"OP_NOP", "OP_VER", "OP_IF", "OP_NOTIF", "OP_VERIF", "OP_VERNOTIF",
"OP_ELSE", "OP_ENDIF", "OP_VERIFY", "OP_RETURN",
"OP_TOALTSTACK", "OP_FROMALTSTACK", "OP_2DROP", "OP_2DUP", "OP_3DUP",
"OP_2OVER", "OP_2ROT", "OP_2SWAP", "OP_IFDUP", "OP_DEPTH", "OP_DROP",
"OP_DUP", "OP_NIP", "OP_OVER", "OP_PICK", "OP_ROLL", "OP_ROT",
"OP_SWAP", "OP_TUCK",
"OP_CAT", "OP_SUBSTR", "OP_LEFT", "OP_RIGHT", "OP_SIZE",
"OP_INVERT", "OP_AND", "OP_OR", "OP_XOR", "OP_EQUAL", "OP_EQUALVERIFY",
"OP_RESERVED1", "OP_RESERVED2",
"OP_1ADD", "OP_1SUB", "OP_2MUL", "OP_2DIV", "OP_NEGATE", "OP_ABS",
"OP_NOT", "OP_0NOTEQUAL", "OP_ADD", "OP_SUB", "OP_MUL", "OP_DIV", "OP_MOD",
"OP_LSHIFT", "OP_RSHIFT", "OP_BOOLAND", "OP_BOOLOR", "OP_NUMEQUAL",
"OP_NUMEQUALVERIFY", "OP_NUMNOTEQUAL", "OP_LESSTHAN", "OP_GREATERTHAN",
"OP_LESSTHANOREQUAL", "OP_GREATERTHANOREQUAL", "OP_MIN", "OP_MAX",
"OP_WITHIN",
"OP_RIPEMD160", "OP_SHA1", "OP_SHA256", "OP_HASH160", "OP_HASH256",
"OP_CODESEPARATOR", "OP_CHECKSIG", "OP_CHECKSIGVERIFY", "OP_CHECKMULTISIG",
"OP_CHECKMULTISIGVERIFY",
"OP_NOP1",
"OP_CHECKLOCKTIMEVERIFY", "OP_CHECKSEQUENCEVERIFY"
])
# Paranoia to make it hard to create bad scripts
assert OpCodes.OP_DUP == 0x76
assert OpCodes.OP_HASH160 == 0xa9
assert OpCodes.OP_EQUAL == 0x87
assert OpCodes.OP_EQUALVERIFY == 0x88
assert OpCodes.OP_CHECKSIG == 0xac
assert OpCodes.OP_CHECKMULTISIG == 0xae
def _match_ops(ops, pattern):
if len(ops) != len(pattern):
return False
for op, pop in zip(ops, pattern):
if pop != op:
# -1 means 'data push', whose op is an (op, data) tuple
if pop == -1 and isinstance(op, tuple):
continue
return False
return True
class ScriptPubKey(object):
'''A class for handling a tx output script that gives conditions
necessary for spending.
'''
TO_ADDRESS_OPS = [OpCodes.OP_DUP, OpCodes.OP_HASH160, -1,
OpCodes.OP_EQUALVERIFY, OpCodes.OP_CHECKSIG]
TO_P2SH_OPS = [OpCodes.OP_HASH160, -1, OpCodes.OP_EQUAL]
TO_PUBKEY_OPS = [-1, OpCodes.OP_CHECKSIG]
PayToHandlers = namedtuple('PayToHandlers', 'address script_hash pubkey '
'unspendable strange')
@classmethod
def pay_to(cls, handlers, script):
'''Parse a script, invoke the appropriate handler and
return the result.
One of the following handlers is invoked:
handlers.address(hash160)
handlers.script_hash(hash160)
handlers.pubkey(pubkey)
handlers.unspendable()
handlers.strange(script)
'''
try:
ops = Script.get_ops(script)
except ScriptError:
return handlers.unspendable()
match = _match_ops
if match(ops, cls.TO_ADDRESS_OPS):
return handlers.address(ops[2][-1])
if match(ops, cls.TO_P2SH_OPS):
return handlers.script_hash(ops[1][-1])
if match(ops, cls.TO_PUBKEY_OPS):
return handlers.pubkey(ops[0][-1])
if ops and ops[0] == OpCodes.OP_RETURN:
return handlers.unspendable()
return handlers.strange(script)
@classmethod
def P2SH_script(cls, hash160):
return (bytes([OpCodes.OP_HASH160])
+ Script.push_data(hash160)
+ bytes([OpCodes.OP_EQUAL]))
@classmethod
def P2PKH_script(cls, hash160):
return (bytes([OpCodes.OP_DUP, OpCodes.OP_HASH160])
+ Script.push_data(hash160)
+ bytes([OpCodes.OP_EQUALVERIFY, OpCodes.OP_CHECKSIG]))
@classmethod
def validate_pubkey(cls, pubkey, req_compressed=False):
if isinstance(pubkey, (bytes, bytearray)):
if len(pubkey) == 33 and pubkey[0] in (2, 3):
return # Compressed
if len(pubkey) == 65 and pubkey[0] == 4:
if not req_compressed:
return
raise PubKeyError('uncompressed pubkeys are invalid')
raise PubKeyError('invalid pubkey {}'.format(pubkey))
@classmethod
def pubkey_script(cls, pubkey):
cls.validate_pubkey(pubkey)
return Script.push_data(pubkey) + bytes([OpCodes.OP_CHECKSIG])
@classmethod
def multisig_script(cls, m, pubkeys):
'''Returns the script for a pay-to-multisig transaction.'''
n = len(pubkeys)
if not 1 <= m <= n <= 15:
raise ScriptError('{:d} of {:d} multisig script not possible'
.format(m, n))
for pubkey in pubkeys:
cls.validate_pubkey(pubkey, req_compressed=True)
# See https://bitcoin.org/en/developer-guide
# 2 of 3 is: OP_2 pubkey1 pubkey2 pubkey3 OP_3 OP_CHECKMULTISIG
return (bytes([OpCodes.OP_1 + m - 1])
+ b''.join(Script.push_data(pubkey) for pubkey in pubkeys)
+ bytes([OpCodes.OP_1 + n - 1, OpCodes.OP_CHECKMULTISIG]))
class Script(object):
@classmethod
def get_ops(cls, script):
ops = []
# The unpacks or script[n] below throw on truncated scripts
try:
n = 0
while n < len(script):
op = script[n]
n += 1
if op <= OpCodes.OP_PUSHDATA4:
# Raw bytes follow
if op < OpCodes.OP_PUSHDATA1:
dlen = op
elif op == OpCodes.OP_PUSHDATA1:
dlen = script[n]
n += 1
elif op == OpCodes.OP_PUSHDATA2:
dlen, = unpack_le_uint16_from(script[n: n + 2])
n += 2
else:
dlen, = unpack_le_uint32_from(script[n: n + 4])
n += 4
if n + dlen > len(script):
raise IndexError
op = (op, script[n:n + dlen])
n += dlen
ops.append(op)
except Exception:
# Truncated script; e.g. tx_hash
# ebc9fa1196a59e192352d76c0f6e73167046b9d37b8302b6bb6968dfd279b767
raise ScriptError('truncated script')
return ops
@classmethod
def push_data(cls, data):
'''Returns the opcodes to push the data on the stack.'''
assert isinstance(data, (bytes, bytearray))
n = len(data)
if n < OpCodes.OP_PUSHDATA1:
return bytes([n]) + data
if n < 256:
return bytes([OpCodes.OP_PUSHDATA1, n]) + data
if n < 65536:
return bytes([OpCodes.OP_PUSHDATA2]) + pack_le_uint16(n) + data
return bytes([OpCodes.OP_PUSHDATA4]) + pack_le_uint32(n) + data
@classmethod
def opcode_name(cls, opcode):
if OpCodes.OP_0 < opcode < OpCodes.OP_PUSHDATA1:
return 'OP_{:d}'.format(opcode)
try:
return OpCodes.whatis(opcode)
except KeyError:
return 'OP_UNKNOWN:{:d}'.format(opcode)
@classmethod
def dump(cls, script):
ops = cls.get_ops(script)
for op in ops:
# Data pushes come back from get_ops() as (opcode, data) tuples
op, data = op if isinstance(op, tuple) else (op, None)
name = cls.opcode_name(op)
if data is None:
print(name)
else:
print('{} {} ({:d} bytes)'
.format(name, data.hex(), len(data)))
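# Quick illustrative self-checks, in the spirit of the opcode asserts near the
# top of this module: build a P2PKH script for a dummy all-zero hash160 and
# confirm get_ops() parses it back into the TO_ADDRESS_OPS pattern (the -1
# pattern entry matches the data push).
assert ScriptPubKey.P2PKH_script(bytes(20)).hex() == '76a914' + '00' * 20 + '88ac'
assert _match_ops(Script.get_ops(ScriptPubKey.P2PKH_script(bytes(20))), ScriptPubKey.TO_ADDRESS_OPS)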

129
torba/server/server.py Normal file
View file

@@ -0,0 +1,129 @@
import signal
import time
import logging
from functools import partial
import asyncio
import torba
from torba.server.db import DB
from torba.server.mempool import MemPool, MemPoolAPI
from torba.server.session import SessionManager
class Notifications:
# hashX notifications come from two sources: new blocks and
# mempool refreshes.
#
# A user with a pending transaction is notified after the block it
# gets in is processed. Block processing can take an extended
# time, and the prefetcher might poll the daemon after the mempool
# code in any case. In such cases the transaction will not be in
# the mempool after the mempool refresh. We want to avoid
# notifying clients twice - for the mempool refresh and when the
# block is done. This object handles that logic by deferring
# notifications appropriately.
def __init__(self):
self._touched_mp = {}
self._touched_bp = {}
self._highest_block = -1
async def _maybe_notify(self):
tmp, tbp = self._touched_mp, self._touched_bp
common = set(tmp).intersection(tbp)
if common:
height = max(common)
elif tmp and max(tmp) == self._highest_block:
height = self._highest_block
else:
# Either we are processing a block and waiting for it to
# come in, or we have not yet had a mempool update for the
# new block height
return
touched = tmp.pop(height)
for old in [h for h in tmp if h <= height]:
del tmp[old]
for old in [h for h in tbp if h <= height]:
touched.update(tbp.pop(old))
await self.notify(height, touched)
async def notify(self, height, touched):
pass
async def start(self, height, notify_func):
self._highest_block = height
self.notify = notify_func
await self.notify(height, set())
async def on_mempool(self, touched, height):
self._touched_mp[height] = touched
await self._maybe_notify()
async def on_block(self, touched, height):
self._touched_bp[height] = touched
self._highest_block = height
await self._maybe_notify()
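# Illustrative sequence (hypothetical hashXs and send_func) showing the
# deferral described above: a mempool refresh for height 5 arrives while
# block 5 is still being processed, so nothing is sent until on_block()
# lands, at which point one combined notification goes out for height 5:
#
#   n = Notifications()
#   await n.start(4, send_func)
#   await n.on_mempool({hashX_a}, height=5)   # deferred: block 5 pending
#   await n.on_block({hashX_b}, height=5)     # notifies {hashX_a, hashX_b}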
class Server:
def __init__(self, env):
self.env = env
self.log = logging.getLogger(__name__).getChild(self.__class__.__name__)
self.shutdown_event = asyncio.Event()
self.cancellable_tasks = []
async def start(self):
env = self.env
min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings()
self.log.info(f'software version: {torba.__version__}')
self.log.info(f'supported protocol versions: {min_str}-{max_str}')
self.log.info(f'event loop policy: {env.loop_policy}')
self.log.info(f'reorg limit is {env.reorg_limit:,d} blocks')
notifications = Notifications()
Daemon = env.coin.DAEMON
BlockProcessor = env.coin.BLOCK_PROCESSOR
daemon = Daemon(env.coin, env.daemon_url)
db = DB(env)
bp = BlockProcessor(env, db, daemon, notifications)
# Set notifications up to implement the MemPoolAPI
notifications.height = daemon.height
notifications.cached_height = daemon.cached_height
notifications.mempool_hashes = daemon.mempool_hashes
notifications.raw_transactions = daemon.getrawtransactions
notifications.lookup_utxos = db.lookup_utxos
MemPoolAPI.register(Notifications)
mempool = MemPool(env.coin, notifications)
session_mgr = SessionManager(
env, db, bp, daemon, mempool, self.shutdown_event
)
await daemon.height()
def _start_cancellable(run, *args):
_flag = asyncio.Event()
self.cancellable_tasks.append(asyncio.ensure_future(run(*args, _flag)))
return _flag.wait()
await _start_cancellable(bp.fetch_and_process_blocks)
await db.populate_header_merkle_cache()
await _start_cancellable(mempool.keep_synchronized)
await _start_cancellable(session_mgr.serve, notifications)
def stop(self):
for task in reversed(self.cancellable_tasks):
task.cancel()
def run(self):
loop = asyncio.get_event_loop()
try:
loop.add_signal_handler(signal.SIGINT, self.stop)
loop.add_signal_handler(signal.SIGTERM, self.stop)
loop.run_until_complete(self.start())
loop.run_forever()
finally:
loop.run_until_complete(loop.shutdown_asyncgens())

1436
torba/server/session.py Normal file

File diff suppressed because it is too large

166
torba/server/storage.py Normal file
View file

@@ -0,0 +1,166 @@
# Copyright (c) 2016-2017, the ElectrumX authors
#
# All rights reserved.
#
# See the file "LICENCE" for information about the copyright
# and warranty status of this software.
'''Backend database abstraction.'''
import os
from functools import partial
import torba.server.util as util
def db_class(name):
'''Returns a DB engine class.'''
for db_class in util.subclasses(Storage):
if db_class.__name__.lower() == name.lower():
db_class.import_module()
return db_class
raise RuntimeError('unrecognised DB engine "{}"'.format(name))
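# Example (illustrative): given an engine name such as 'leveldb' or 'rocksdb',
# db_class() returns the matching Storage subclass defined below with its
# backend module imported, ready to be instantiated:
#
#   engine = db_class('leveldb')        # -> LevelDB (imports plyvel)
#   db = engine('utxo', for_sync=False)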
class Storage(object):
'''Abstract base class of the DB backend abstraction.'''
def __init__(self, name, for_sync):
self.is_new = not os.path.exists(name)
self.for_sync = for_sync or self.is_new
self.open(name, create=self.is_new)
@classmethod
def import_module(cls):
'''Import the DB engine module.'''
raise NotImplementedError
def open(self, name, create):
'''Open an existing database or create a new one.'''
raise NotImplementedError
def close(self):
'''Close an existing database.'''
raise NotImplementedError
def get(self, key):
raise NotImplementedError
def put(self, key, value):
raise NotImplementedError
def write_batch(self):
'''Return a context manager that provides `put` and `delete`.
Changes should only be committed when the context manager
closes without an exception.
'''
raise NotImplementedError
def iterator(self, prefix=b'', reverse=False):
'''Return an iterator that yields (key, value) pairs from the
database sorted by key.
If `prefix` is set, only keys starting with `prefix` will be
included. If `reverse` is True the items are returned in
reverse order.
'''
raise NotImplementedError
class LevelDB(Storage):
'''LevelDB database engine.'''
@classmethod
def import_module(cls):
import plyvel
cls.module = plyvel
def open(self, name, create):
mof = 512 if self.for_sync else 128
# Use snappy compression (the default)
self.db = self.module.DB(name, create_if_missing=create,
max_open_files=mof)
self.close = self.db.close
self.get = self.db.get
self.put = self.db.put
self.iterator = self.db.iterator
self.write_batch = partial(self.db.write_batch, transaction=True,
sync=True)
class RocksDB(Storage):
'''RocksDB database engine.'''
@classmethod
def import_module(cls):
import rocksdb
cls.module = rocksdb
def open(self, name, create):
mof = 512 if self.for_sync else 128
# Use snappy compression (the default)
options = self.module.Options(create_if_missing=create,
use_fsync=True,
target_file_size_base=33554432,
max_open_files=mof)
self.db = self.module.DB(name, options)
self.get = self.db.get
self.put = self.db.put
def close(self):
# PyRocksDB doesn't provide a close method; hopefully this is enough
self.db = self.get = self.put = None
import gc
gc.collect()
def write_batch(self):
return RocksDBWriteBatch(self.db)
def iterator(self, prefix=b'', reverse=False):
return RocksDBIterator(self.db, prefix, reverse)
class RocksDBWriteBatch(object):
'''A write batch for RocksDB.'''
def __init__(self, db):
self.batch = RocksDB.module.WriteBatch()
self.db = db
def __enter__(self):
return self.batch
def __exit__(self, exc_type, exc_val, exc_tb):
if not exc_val:
self.db.write(self.batch)
class RocksDBIterator(object):
'''An iterator for RocksDB.'''
def __init__(self, db, prefix, reverse):
self.prefix = prefix
if reverse:
self.iterator = reversed(db.iteritems())
nxt_prefix = util.increment_byte_string(prefix)
if nxt_prefix:
self.iterator.seek(nxt_prefix)
try:
next(self.iterator)
except StopIteration:
self.iterator.seek(nxt_prefix)
else:
self.iterator.seek_to_last()
else:
self.iterator = db.iteritems()
self.iterator.seek(prefix)
def __iter__(self):
return self
def __next__(self):
k, v = next(self.iterator)
if not k.startswith(self.prefix):
raise StopIteration
return k, v

82
torba/server/text.py Normal file
View file

@@ -0,0 +1,82 @@
import time
import torba.server.util as util
def sessions_lines(data):
'''A generator returning lines for a list of sessions.
data is the return value of rpc_sessions().'''
fmt = ('{:<6} {:<5} {:>17} {:>5} {:>5} {:>5} '
'{:>7} {:>7} {:>7} {:>7} {:>7} {:>9} {:>21}')
yield fmt.format('ID', 'Flags', 'Client', 'Proto',
'Reqs', 'Txs', 'Subs',
'Recv', 'Recv KB', 'Sent', 'Sent KB', 'Time', 'Peer')
for (id_, flags, peer, client, proto, reqs, txs_sent, subs,
recv_count, recv_size, send_count, send_size, time) in data:
yield fmt.format(id_, flags, client, proto,
'{:,d}'.format(reqs),
'{:,d}'.format(txs_sent),
'{:,d}'.format(subs),
'{:,d}'.format(recv_count),
'{:,d}'.format(recv_size // 1024),
'{:,d}'.format(send_count),
'{:,d}'.format(send_size // 1024),
util.formatted_time(time, sep=''), peer)
def groups_lines(data):
'''A generator returning lines for a list of groups.
data is the return value of rpc_groups().'''
fmt = ('{:<6} {:>9} {:>9} {:>6} {:>6} {:>8}'
'{:>7} {:>9} {:>7} {:>9}')
yield fmt.format('ID', 'Sessions', 'Bwidth KB', 'Reqs', 'Txs', 'Subs',
'Recv', 'Recv KB', 'Sent', 'Sent KB')
for (id_, session_count, bandwidth, reqs, txs_sent, subs,
recv_count, recv_size, send_count, send_size) in data:
yield fmt.format(id_,
'{:,d}'.format(session_count),
'{:,d}'.format(bandwidth // 1024),
'{:,d}'.format(reqs),
'{:,d}'.format(txs_sent),
'{:,d}'.format(subs),
'{:,d}'.format(recv_count),
'{:,d}'.format(recv_size // 1024),
'{:,d}'.format(send_count),
'{:,d}'.format(send_size // 1024))
def peers_lines(data):
'''A generator returning lines for a list of peers.
data is the return value of rpc_peers().'''
def time_fmt(t):
if not t:
return 'Never'
return util.formatted_time(now - t)
now = time.time()
fmt = ('{:<30} {:<6} {:>5} {:>5} {:<17} {:>4} '
'{:>4} {:>8} {:>11} {:>11} {:>5} {:>20} {:<15}')
yield fmt.format('Host', 'Status', 'TCP', 'SSL', 'Server', 'Min',
'Max', 'Pruning', 'Last Good', 'Last Try',
'Tries', 'Source', 'IP Address')
for item in data:
features = item['features']
hostname = item['host']
host = features['hosts'][hostname]
yield fmt.format(hostname[:30],
item['status'],
host.get('tcp_port') or '',
host.get('ssl_port') or '',
features['server_version'] or 'unknown',
features['protocol_min'],
features['protocol_max'],
features['pruning'] or '',
time_fmt(item['last_good']),
time_fmt(item['last_try']),
item['try_count'],
item['source'][:20],
item['ip_addr'] or '')

625
torba/server/tx.py Normal file
View file

@@ -0,0 +1,625 @@
# Copyright (c) 2016-2017, Neil Booth
# Copyright (c) 2017, the ElectrumX authors
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# and warranty status of this software.
'''Transaction-related classes and functions.'''
from collections import namedtuple
from torba.server.hash import sha256, double_sha256, hash_to_hex_str
from torba.server.script import OpCodes
from torba.server.util import (
unpack_le_int32_from, unpack_le_int64_from, unpack_le_uint16_from,
unpack_le_uint32_from, unpack_le_uint64_from, pack_le_int32, pack_varint,
pack_le_uint32, pack_le_int64, pack_varbytes,
)
ZERO = bytes(32)
MINUS_1 = 4294967295
class Tx(namedtuple("Tx", "version inputs outputs locktime")):
'''Class representing a transaction.'''
def serialize(self):
return b''.join((
pack_le_int32(self.version),
pack_varint(len(self.inputs)),
b''.join(tx_in.serialize() for tx_in in self.inputs),
pack_varint(len(self.outputs)),
b''.join(tx_out.serialize() for tx_out in self.outputs),
pack_le_uint32(self.locktime)
))
class TxInput(namedtuple("TxInput", "prev_hash prev_idx script sequence")):
'''Class representing a transaction input.'''
def __str__(self):
script = self.script.hex()
prev_hash = hash_to_hex_str(self.prev_hash)
return ("Input({}, {:d}, script={}, sequence={:d})"
.format(prev_hash, self.prev_idx, script, self.sequence))
def is_generation(self):
'''Test if an input is generation/coinbase like'''
return self.prev_idx == MINUS_1 and self.prev_hash == ZERO
def serialize(self):
return b''.join((
self.prev_hash,
pack_le_uint32(self.prev_idx),
pack_varbytes(self.script),
pack_le_uint32(self.sequence),
))
class TxOutput(namedtuple("TxOutput", "value pk_script")):
def serialize(self):
return b''.join((
pack_le_int64(self.value),
pack_varbytes(self.pk_script),
))
class Deserializer(object):
'''Deserializes blocks into transactions.
External entry points are read_tx(), read_tx_and_hash(),
read_tx_and_vsize() and read_block().
This code is performance sensitive as it is executed 100s of
millions of times during sync.
'''
TX_HASH_FN = staticmethod(double_sha256)
def __init__(self, binary, start=0):
assert isinstance(binary, bytes)
self.binary = binary
self.binary_length = len(binary)
self.cursor = start
def read_tx(self):
'''Return a deserialized transaction.'''
return Tx(
self._read_le_int32(), # version
self._read_inputs(), # inputs
self._read_outputs(), # outputs
self._read_le_uint32() # locktime
)
def read_tx_and_hash(self):
'''Return a (deserialized TX, tx_hash) pair.
The hash needs to be reversed for human display; for efficiency
we process it in the natural serialized order.
'''
start = self.cursor
return self.read_tx(), self.TX_HASH_FN(self.binary[start:self.cursor])
def read_tx_and_vsize(self):
'''Return a (deserialized TX, vsize) pair.'''
return self.read_tx(), self.binary_length
def read_tx_block(self):
'''Returns a list of (deserialized_tx, tx_hash) pairs.'''
read = self.read_tx_and_hash
# Some coins have excess data beyond the end of the transactions
return [read() for _ in range(self._read_varint())]
def _read_inputs(self):
read_input = self._read_input
return [read_input() for i in range(self._read_varint())]
def _read_input(self):
return TxInput(
self._read_nbytes(32), # prev_hash
self._read_le_uint32(), # prev_idx
self._read_varbytes(), # script
self._read_le_uint32() # sequence
)
def _read_outputs(self):
read_output = self._read_output
return [read_output() for i in range(self._read_varint())]
def _read_output(self):
return TxOutput(
self._read_le_int64(), # value
self._read_varbytes(), # pk_script
)
def _read_byte(self):
cursor = self.cursor
self.cursor += 1
return self.binary[cursor]
def _read_nbytes(self, n):
cursor = self.cursor
self.cursor = end = cursor + n
assert self.binary_length >= end
return self.binary[cursor:end]
def _read_varbytes(self):
return self._read_nbytes(self._read_varint())
def _read_varint(self):
n = self.binary[self.cursor]
self.cursor += 1
if n < 253:
return n
if n == 253:
return self._read_le_uint16()
if n == 254:
return self._read_le_uint32()
return self._read_le_uint64()
def _read_le_int32(self):
result, = unpack_le_int32_from(self.binary, self.cursor)
self.cursor += 4
return result
def _read_le_int64(self):
result, = unpack_le_int64_from(self.binary, self.cursor)
self.cursor += 8
return result
def _read_le_uint16(self):
result, = unpack_le_uint16_from(self.binary, self.cursor)
self.cursor += 2
return result
def _read_le_uint32(self):
result, = unpack_le_uint32_from(self.binary, self.cursor)
self.cursor += 4
return result
def _read_le_uint64(self):
result, = unpack_le_uint64_from(self.binary, self.cursor)
self.cursor += 8
return result
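# Illustrative round-trip of the wire layout handled above, using a degenerate
# empty transaction (le_int32 version | varint #inputs | varint #outputs |
# le_uint32 locktime); real transactions have at least one input and output:
_raw_empty_tx = bytes.fromhex('01000000' '00' '00' '00000000')
_tx, _tx_hash = Deserializer(_raw_empty_tx).read_tx_and_hash()
assert _tx.version == 1 and not _tx.inputs and not _tx.outputs
assert _tx.serialize() == _raw_empty_tx
assert _tx_hash == double_sha256(_raw_empty_tx)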
class TxSegWit(namedtuple("Tx", "version marker flag inputs outputs "
"witness locktime")):
'''Class representing a SegWit transaction.'''
class DeserializerSegWit(Deserializer):
# https://bitcoincore.org/en/segwit_wallet_dev/#transaction-serialization
def _read_witness(self, fields):
read_witness_field = self._read_witness_field
return [read_witness_field() for i in range(fields)]
def _read_witness_field(self):
read_varbytes = self._read_varbytes
return [read_varbytes() for i in range(self._read_varint())]
def _read_tx_parts(self):
'''Return a (deserialized TX, tx_hash, vsize) tuple.'''
start = self.cursor
marker = self.binary[self.cursor + 4]
if marker:
tx = super().read_tx()
tx_hash = self.TX_HASH_FN(self.binary[start:self.cursor])
return tx, tx_hash, self.binary_length
# Ugh, this is nasty.
version = self._read_le_int32()
orig_ser = self.binary[start:self.cursor]
marker = self._read_byte()
flag = self._read_byte()
start = self.cursor
inputs = self._read_inputs()
outputs = self._read_outputs()
orig_ser += self.binary[start:self.cursor]
base_size = self.cursor - start
witness = self._read_witness(len(inputs))
start = self.cursor
locktime = self._read_le_uint32()
orig_ser += self.binary[start:self.cursor]
vsize = (3 * base_size + self.binary_length) // 4
return TxSegWit(version, marker, flag, inputs, outputs, witness,
locktime), self.TX_HASH_FN(orig_ser), vsize
def read_tx(self):
return self._read_tx_parts()[0]
def read_tx_and_hash(self):
tx, tx_hash, vsize = self._read_tx_parts()
return tx, tx_hash
def read_tx_and_vsize(self):
tx, tx_hash, vsize = self._read_tx_parts()
return tx, vsize
class DeserializerAuxPow(Deserializer):
VERSION_AUXPOW = (1 << 8)
def read_header(self, height, static_header_size):
'''Return the AuxPow block header bytes'''
start = self.cursor
version = self._read_le_uint32()
if version & self.VERSION_AUXPOW:
# We are going to calculate the block size then read it as bytes
self.cursor = start
self.cursor += static_header_size # Block normal header
self.read_tx() # AuxPow transaction
self.cursor += 32 # Parent block hash
merkle_size = self._read_varint()
self.cursor += 32 * merkle_size # Merkle branch
self.cursor += 4 # Index
merkle_size = self._read_varint()
self.cursor += 32 * merkle_size # Chain merkle branch
self.cursor += 4 # Chain index
self.cursor += 80 # Parent block header
header_end = self.cursor
else:
header_end = static_header_size
self.cursor = start
return self._read_nbytes(header_end)
class DeserializerAuxPowSegWit(DeserializerSegWit, DeserializerAuxPow):
pass
class DeserializerEquihash(Deserializer):
def read_header(self, height, static_header_size):
'''Return the block header bytes'''
start = self.cursor
# We are going to calculate the block size then read it as bytes
self.cursor += static_header_size
solution_size = self._read_varint()
self.cursor += solution_size
header_end = self.cursor
self.cursor = start
return self._read_nbytes(header_end)
class DeserializerEquihashSegWit(DeserializerSegWit, DeserializerEquihash):
pass
class TxJoinSplit(namedtuple("Tx", "version inputs outputs locktime")):
'''Class representing a JoinSplit transaction.'''
class DeserializerZcash(DeserializerEquihash):
def read_tx(self):
header = self._read_le_uint32()
overwintered = ((header >> 31) == 1)
if overwintered:
version = header & 0x7fffffff
self.cursor += 4 # versionGroupId
else:
version = header
is_overwinter_v3 = version == 3
is_sapling_v4 = version == 4
base_tx = TxJoinSplit(
version,
self._read_inputs(), # inputs
self._read_outputs(), # outputs
self._read_le_uint32() # locktime
)
if is_overwinter_v3 or is_sapling_v4:
self.cursor += 4 # expiryHeight
has_shielded = False
if is_sapling_v4:
self.cursor += 8 # valueBalance
shielded_spend_size = self._read_varint()
self.cursor += shielded_spend_size * 384 # vShieldedSpend
shielded_output_size = self._read_varint()
self.cursor += shielded_output_size * 948 # vShieldedOutput
has_shielded = shielded_spend_size > 0 or shielded_output_size > 0
if base_tx.version >= 2:
joinsplit_size = self._read_varint()
if joinsplit_size > 0:
joinsplit_desc_len = 1506 + (192 if is_sapling_v4 else 296)
# JSDescription
self.cursor += joinsplit_size * joinsplit_desc_len
self.cursor += 32 # joinSplitPubKey
self.cursor += 64 # joinSplitSig
if is_sapling_v4 and has_shielded:
self.cursor += 64 # bindingSig
return base_tx
class TxTime(namedtuple("Tx", "version time inputs outputs locktime")):
'''Class representing transaction that has a time field.'''
class DeserializerTxTime(Deserializer):
def read_tx(self):
return TxTime(
self._read_le_int32(), # version
self._read_le_uint32(), # time
self._read_inputs(), # inputs
self._read_outputs(), # outputs
self._read_le_uint32(), # locktime
)
class DeserializerReddcoin(Deserializer):
def read_tx(self):
version = self._read_le_int32()
inputs = self._read_inputs()
outputs = self._read_outputs()
locktime = self._read_le_uint32()
if version > 1:
time = self._read_le_uint32()
else:
time = 0
return TxTime(version, time, inputs, outputs, locktime)
class DeserializerTxTimeAuxPow(DeserializerTxTime):
VERSION_AUXPOW = (1 << 8)
def is_merged_block(self):
start = self.cursor
self.cursor = 0
version = self._read_le_uint32()
self.cursor = start
if version & self.VERSION_AUXPOW:
return True
return False
def read_header(self, height, static_header_size):
'''Return the AuxPow block header bytes'''
start = self.cursor
version = self._read_le_uint32()
if version & self.VERSION_AUXPOW:
# We are going to calculate the block size then read it as bytes
self.cursor = start
self.cursor += static_header_size # Block normal header
self.read_tx() # AuxPow transaction
self.cursor += 32 # Parent block hash
merkle_size = self._read_varint()
self.cursor += 32 * merkle_size # Merkle branch
self.cursor += 4 # Index
merkle_size = self._read_varint()
self.cursor += 32 * merkle_size # Chain merkle branch
self.cursor += 4 # Chain index
self.cursor += 80 # Parent block header
header_end = self.cursor
else:
header_end = static_header_size
self.cursor = start
return self._read_nbytes(header_end)
class DeserializerBitcoinAtom(DeserializerSegWit):
FORK_BLOCK_HEIGHT = 505888
def read_header(self, height, static_header_size):
'''Return the block header bytes'''
header_len = static_header_size
if height >= self.FORK_BLOCK_HEIGHT:
header_len += 4 # flags
return self._read_nbytes(header_len)
class DeserializerGroestlcoin(DeserializerSegWit):
TX_HASH_FN = staticmethod(sha256)
class TxInputTokenPay(TxInput):
'''Class representing a TokenPay transaction input.'''
OP_ANON_MARKER = 0xb9
# 2byte marker (cpubkey + sigc + sigr)
MIN_ANON_IN_SIZE = 2 + (33 + 32 + 32)
def _is_anon_input(self):
return (len(self.script) >= self.MIN_ANON_IN_SIZE and
self.script[0] == OpCodes.OP_RETURN and
self.script[1] == self.OP_ANON_MARKER)
def is_generation(self):
# Transactions coming in from stealth addresses are seen by
# the blockchain as newly minted coins. The reverse, where coins
# are sent TO a stealth address, are seen by the blockchain as
# a coin burn.
if self._is_anon_input():
return True
return super(TxInputTokenPay, self).is_generation()
class TxInputTokenPayStealth(
namedtuple("TxInput", "keyimage ringsize script sequence")):
'''Class representing a TokenPay stealth transaction input.'''
def __str__(self):
script = self.script.hex()
keyimage = bytes(self.keyimage).hex()
return ("Input({}, {:d}, script={}, sequence={:d})"
.format(keyimage, self.ringsize[1], script, self.sequence))
def is_generation(self):
return True
def serialize(self):
return b''.join((
self.keyimage,
self.ringsize,
pack_varbytes(self.script),
pack_le_uint32(self.sequence),
))
class DeserializerTokenPay(DeserializerTxTime):
def _read_input(self):
txin = TxInputTokenPay(
self._read_nbytes(32), # prev_hash
self._read_le_uint32(), # prev_idx
self._read_varbytes(), # script
self._read_le_uint32(), # sequence
)
if txin._is_anon_input():
# Not sure if this is actually needed, and seems
# extra work for no immediate benefit, but it at
# least correctly represents a stealth input
raw = txin.serialize()
deserializer = Deserializer(raw)
txin = TxInputTokenPayStealth(
deserializer._read_nbytes(33), # keyimage
deserializer._read_nbytes(3), # ringsize
deserializer._read_varbytes(), # script
deserializer._read_le_uint32() # sequence
)
return txin
# Decred
class TxInputDcr(namedtuple("TxInput", "prev_hash prev_idx tree sequence")):
'''Class representing a Decred transaction input.'''
def __str__(self):
prev_hash = hash_to_hex_str(self.prev_hash)
return ("Input({}, {:d}, tree={}, sequence={:d})"
.format(prev_hash, self.prev_idx, self.tree, self.sequence))
def is_generation(self):
'''Test if an input is generation/coinbase like'''
return self.prev_idx == MINUS_1 and self.prev_hash == ZERO
class TxOutputDcr(namedtuple("TxOutput", "value version pk_script")):
'''Class representing a Decred transaction output.'''
pass
class TxDcr(namedtuple("Tx", "version inputs outputs locktime expiry "
"witness")):
'''Class representing a Decred transaction.'''
class DeserializerDecred(Deserializer):
@staticmethod
def blake256(data):
from blake256.blake256 import blake_hash
return blake_hash(data)
@staticmethod
def blake256d(data):
from blake256.blake256 import blake_hash
return blake_hash(blake_hash(data))
def read_tx(self):
return self._read_tx_parts(produce_hash=False)[0]
def read_tx_and_hash(self):
tx, tx_hash, vsize = self._read_tx_parts()
return tx, tx_hash
def read_tx_and_vsize(self):
tx, tx_hash, vsize = self._read_tx_parts(produce_hash=False)
return tx, vsize
def read_tx_block(self):
'''Returns a list of (deserialized_tx, tx_hash) pairs.'''
read = self.read_tx_and_hash
txs = [read() for _ in range(self._read_varint())]
stxs = [read() for _ in range(self._read_varint())]
return txs + stxs
def read_tx_tree(self):
'''Returns a list of deserialized_tx without tx hashes.'''
read_tx = self.read_tx
return [read_tx() for _ in range(self._read_varint())]
def _read_input(self):
return TxInputDcr(
self._read_nbytes(32), # prev_hash
self._read_le_uint32(), # prev_idx
self._read_byte(), # tree
self._read_le_uint32(), # sequence
)
def _read_output(self):
return TxOutputDcr(
self._read_le_int64(), # value
self._read_le_uint16(), # version
self._read_varbytes(), # pk_script
)
def _read_witness(self, fields):
read_witness_field = self._read_witness_field
assert fields == self._read_varint()
return [read_witness_field() for _ in range(fields)]
def _read_witness_field(self):
value_in = self._read_le_int64()
block_height = self._read_le_uint32()
block_index = self._read_le_uint32()
script = self._read_varbytes()
return value_in, block_height, block_index, script
def _read_tx_parts(self, produce_hash=True):
start = self.cursor
version = self._read_le_int32()
inputs = self._read_inputs()
outputs = self._read_outputs()
locktime = self._read_le_uint32()
expiry = self._read_le_uint32()
end_prefix = self.cursor
witness = self._read_witness(len(inputs))
if produce_hash:
# TxSerializeNoWitness << 16 == 0x10000
no_witness_header = pack_le_uint32(0x10000 | (version & 0xffff))
prefix_tx = no_witness_header + self.binary[start+4:end_prefix]
tx_hash = self.blake256(prefix_tx)
else:
tx_hash = None
return TxDcr(
version,
inputs,
outputs,
locktime,
expiry,
witness
), tx_hash, self.cursor - start

359
torba/server/util.py Normal file
View file

@@ -0,0 +1,359 @@
# Copyright (c) 2016-2017, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# and warranty status of this software.
'''Miscellaneous utility classes and functions.'''
import array
import inspect
from ipaddress import ip_address
import logging
import re
import sys
from collections.abc import Container, Mapping
from struct import pack, Struct
# Logging utilities
class ConnectionLogger(logging.LoggerAdapter):
'''Prepends a connection identifier to a logging message.'''
def process(self, msg, kwargs):
conn_id = self.extra.get('conn_id', 'unknown')
return f'[{conn_id}] {msg}', kwargs
class CompactFormatter(logging.Formatter):
'''Strips the module from the logger name to leave the class only.'''
def format(self, record):
record.name = record.name.rpartition('.')[-1]
return super().format(record)
def make_logger(name, *, handler, level):
'''Return a root logger configured with the given handler and level.'''
logger = logging.getLogger(name)
logger.addHandler(handler)
logger.setLevel(level)
logger.propagate = False
return logger
def class_logger(path, classname):
'''Return a hierarchical logger for a class.'''
return logging.getLogger(path).getChild(classname)
# Method decorator. To be used for calculations that will always
# deliver the same result. The method cannot take any arguments
# and should be accessed as an attribute.
class cachedproperty(object):
def __init__(self, f):
self.f = f
def __get__(self, obj, type):
obj = obj or type
value = self.f(obj)
setattr(obj, self.f.__name__, value)
return value
def formatted_time(t, sep=' '):
'''Return a number of seconds as a string in days, hours, mins and
maybe secs.'''
t = int(t)
fmts = (('{:d}d', 86400), ('{:02d}h', 3600), ('{:02d}m', 60))
parts = []
for fmt, n in fmts:
val = t // n
if parts or val:
parts.append(fmt.format(val))
t %= n
if len(parts) < 3:
parts.append('{:02d}s'.format(t))
return sep.join(parts)
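# A few sanity examples of the formatting above (illustrative):
assert formatted_time(0) == '00s'
assert formatted_time(3661) == '01h 01m 01s'
assert formatted_time(90000) == '1d 01h 00m'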
def deep_getsizeof(obj):
"""Find the memory footprint of a Python object.
Based on code from code.tutsplus.com: http://goo.gl/fZ0DXK
This is a recursive function that drills down a Python object graph
like a dictionary holding nested dictionaries with lists of lists
and tuples and sets.
The sys.getsizeof function only does a shallow size. It counts each
object inside a container as a pointer, regardless of how big it
really is.
"""
ids = set()
def size(o):
if id(o) in ids:
return 0
r = sys.getsizeof(o)
ids.add(id(o))
if isinstance(o, (str, bytes, bytearray, array.array)):
return r
if isinstance(o, Mapping):
return r + sum(size(k) + size(v) for k, v in o.items())
if isinstance(o, Container):
return r + sum(size(x) for x in o)
return r
return size(obj)
def subclasses(base_class, strict=True):
'''Return a list of subclasses of base_class in its module.'''
def select(obj):
return (inspect.isclass(obj) and issubclass(obj, base_class) and
(not strict or obj != base_class))
pairs = inspect.getmembers(sys.modules[base_class.__module__], select)
return [pair[1] for pair in pairs]
def chunks(items, size):
'''Break up items, an iterable, into chunks of length size.'''
for i in range(0, len(items), size):
yield items[i: i + size]
def resolve_limit(limit):
if limit is None:
return -1
assert isinstance(limit, int) and limit >= 0
return limit
def bytes_to_int(be_bytes):
'''Interprets a big-endian sequence of bytes as an integer'''
return int.from_bytes(be_bytes, 'big')
def int_to_bytes(value):
'''Converts an integer to a big-endian sequence of bytes'''
return value.to_bytes((value.bit_length() + 7) // 8, 'big')
def increment_byte_string(bs):
'''Return the lexicographically next byte string of the same length.
Return None if there is none (when the input is all 0xff bytes).'''
for n in range(1, len(bs) + 1):
if bs[-n] != 0xff:
return bs[:-n] + bytes([bs[-n] + 1]) + bytes(n - 1)
return None
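# Examples of the successor computation (illustrative), as used by the
# RocksDB iterator in storage.py to find the end of a key prefix:
assert increment_byte_string(b'ab') == b'ac'
assert increment_byte_string(b'\x00\xff') == b'\x01\x00'
assert increment_byte_string(b'\xff\xff') is None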
class LogicalFile(object):
'''A logical binary file split across several separate files on disk.'''
def __init__(self, prefix, digits, file_size):
digit_fmt = '{' + ':0{:d}d'.format(digits) + '}'
self.filename_fmt = prefix + digit_fmt
self.file_size = file_size
def read(self, start, size=-1):
'''Read up to size bytes from the virtual file, starting at offset
start, and return them.
If size is -1 all bytes are read.'''
parts = []
while size != 0:
try:
with self.open_file(start, False) as f:
part = f.read(size)
if not part:
break
except FileNotFoundError:
break
parts.append(part)
start += len(part)
if size > 0:
size -= len(part)
return b''.join(parts)
def write(self, start, b):
'''Write the bytes-like object, b, to the underlying virtual file.'''
while b:
size = min(len(b), self.file_size - (start % self.file_size))
with self.open_file(start, True) as f:
f.write(b if size == len(b) else b[:size])
b = b[size:]
start += size
def open_file(self, start, create):
'''Open the virtual file and seek to start. Return a file handle.
Raise FileNotFoundError if the file does not exist and create
is False.
'''
file_num, offset = divmod(start, self.file_size)
filename = self.filename_fmt.format(file_num)
f = open_file(filename, create)
f.seek(offset)
return f
def open_file(filename, create=False):
'''Open the file name. Return its handle.'''
try:
return open(filename, 'rb+')
except FileNotFoundError:
if create:
return open(filename, 'wb+')
raise
def open_truncate(filename):
'''Open the file name. Return its handle.'''
return open(filename, 'wb+')
def address_string(address):
'''Return an address as a correctly formatted string.'''
fmt = '{}:{:d}'
host, port = address
try:
host = ip_address(host)
except ValueError:
pass
else:
if host.version == 6:
fmt = '[{}]:{:d}'
return fmt.format(host, port)
# See http://stackoverflow.com/questions/2532053/validate-a-hostname-string
# Note underscores are valid in domain names, but strictly invalid in host
# names. We ignore that distinction.
SEGMENT_REGEX = re.compile("(?!-)[A-Z_\\d-]{1,63}(?<!-)$", re.IGNORECASE)
def is_valid_hostname(hostname):
if len(hostname) > 255:
return False
# strip exactly one dot from the right, if present
if hostname and hostname[-1] == ".":
hostname = hostname[:-1]
return all(SEGMENT_REGEX.match(x) for x in hostname.split("."))
def protocol_tuple(s):
'''Converts a protocol version number, such as "1.0" to a tuple (1, 0).
If the version number is bad, (0, ) indicating version 0 is returned.'''
try:
return tuple(int(part) for part in s.split('.'))
except Exception:
return (0, )
def version_string(ptuple):
'''Convert a version tuple such as (1, 2) to "1.2".
There is always at least one dot, so (1, ) becomes "1.0".'''
while len(ptuple) < 2:
ptuple += (0, )
return '.'.join(str(p) for p in ptuple)
def protocol_version(client_req, min_tuple, max_tuple):
'''Given a client's protocol version string, return a pair of
protocol tuples:
(negotiated version, client min request)
If the request is unsupported, the negotiated protocol tuple is
None.
'''
if client_req is None:
client_min = client_max = min_tuple
else:
if isinstance(client_req, list) and len(client_req) == 2:
client_min, client_max = client_req
else:
client_min = client_max = client_req
client_min = protocol_tuple(client_min)
client_max = protocol_tuple(client_max)
result = min(client_max, max_tuple)
if result < max(client_min, min_tuple) or result == (0, ):
result = None
return result, client_min
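# Worked examples (not in the original source; values follow from the branches above):
#   protocol_version(["1.1", "1.4"], (1, 2), (1, 8)) == ((1, 4), (1, 1))   # negotiates 1.4
#   protocol_version("1.0", (1, 2), (1, 8)) == (None, (1, 0))              # unsupported request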
struct_le_i = Struct('<i')
struct_le_q = Struct('<q')
struct_le_H = Struct('<H')
struct_le_I = Struct('<I')
struct_le_Q = Struct('<Q')
struct_be_H = Struct('>H')
struct_be_I = Struct('>I')
structB = Struct('B')
unpack_le_int32_from = struct_le_i.unpack_from
unpack_le_int64_from = struct_le_q.unpack_from
unpack_le_uint16_from = struct_le_H.unpack_from
unpack_le_uint32_from = struct_le_I.unpack_from
unpack_le_uint64_from = struct_le_Q.unpack_from
unpack_be_uint16_from = struct_be_H.unpack_from
unpack_be_uint32_from = struct_be_I.unpack_from
pack_le_int32 = struct_le_i.pack
pack_le_int64 = struct_le_q.pack
pack_le_uint16 = struct_le_H.pack
pack_le_uint32 = struct_le_I.pack
pack_le_uint64 = struct_le_Q.pack
pack_be_uint16 = struct_be_H.pack
pack_be_uint32 = struct_be_I.pack
pack_byte = structB.pack
hex_to_bytes = bytes.fromhex
def pack_varint(n):
if n < 253:
return pack_byte(n)
if n < 65536:
return pack_byte(253) + pack_le_uint16(n)
if n < 4294967296:
return pack_byte(254) + pack_le_uint32(n)
return pack_byte(255) + pack_le_uint64(n)
def pack_varbytes(data):
return pack_varint(len(data)) + data
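The varint helpers above follow Bitcoin's CompactSize convention. A minimal sketch of the resulting byte layouts; the values are checked against the branches above, while the torba.server.util import path is an assumption:
from torba.server.util import pack_varint, pack_varbytes  # assumed module path

assert pack_varint(252) == b'\xfc'                    # single byte below 253
assert pack_varint(253) == b'\xfd\xfd\x00'            # 0xfd marker + little-endian uint16
assert pack_varint(65536) == b'\xfe\x00\x00\x01\x00'  # 0xfe marker + little-endian uint32
assert pack_varbytes(b'abc') == b'\x03abc'            # length prefix followed by the payload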

1
torba/testing/__init__.py Normal file
View file

@ -0,0 +1 @@
from .testcase import IntegrationTestCase

344
torba/testing/node.py Normal file
View file

@ -0,0 +1,344 @@
import os
import sys
import shutil
import asyncio
import zipfile
import tarfile
import logging
import tempfile
import subprocess
import importlib
import requests
from binascii import hexlify
from typing import Type
from torba.server.server import Server
from torba.server.env import Env
from torba.wallet import Wallet
from torba.baseledger import BaseLedger, BlockHeightEvent
from torba.basemanager import BaseWalletManager
from torba.baseaccount import BaseAccount
root = logging.getLogger()
ch = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
root.addHandler(ch)
def get_manager_from_environment(default_manager=BaseWalletManager):
if 'TORBA_MANAGER' not in os.environ:
return default_manager
module_name = os.environ['TORBA_MANAGER'].split('-')[-1] # tox support
return importlib.import_module(module_name)
def get_ledger_from_environment():
if 'TORBA_LEDGER' not in os.environ:
raise ValueError('Environment variable TORBA_LEDGER must point to a torba based ledger module.')
module_name = os.environ['TORBA_LEDGER'].split('-')[-1] # tox support
return importlib.import_module(module_name)
def get_spvserver_from_ledger(ledger_module):
spvserver_path, regtest_class_name = ledger_module.__spvserver__.rsplit('.', 1)
spvserver_module = importlib.import_module(spvserver_path)
return getattr(spvserver_module, regtest_class_name)
def get_blockchain_node_from_ledger(ledger_module):
return BlockchainNode(
ledger_module.__node_url__,
os.path.join(ledger_module.__node_bin__, ledger_module.__node_daemon__),
os.path.join(ledger_module.__node_bin__, ledger_module.__node_cli__)
)
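# The helpers above assume a torba-based ledger module (selected via TORBA_LEDGER,
# and optionally TORBA_MANAGER) exports a handful of attributes. A hypothetical
# module might look like this; every name below is inferred from the lookups in
# this file and defined nowhere else in the diff:
#
#   __spvserver__ = 'mycoin.server.MyCoinRegTest'        # dotted path to the SPV server coin class
#   __node_url__ = 'https://example.com/mycoind.tar.gz'  # release archive for the full node
#   __node_bin__ = 'mycoind/bin'                         # extracted binaries, relative to ./bin
#   __node_daemon__ = 'mycoind'                          # daemon executable name
#   __node_cli__ = 'mycoin-cli'                          # cli executable name
#   RegTestLedger = ...                                  # regtest ledger class handed to WalletNode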
def set_logging(ledger_module, level):
logging.getLogger('torba').setLevel(level)
logging.getLogger('torba.server').setLevel(level)
#logging.getLogger('asyncio').setLevel(level)
logging.getLogger('blockchain').setLevel(level)
logging.getLogger(ledger_module.__name__).setLevel(level)
class Conductor:
def __init__(self, ledger_module=None, manager_module=None, verbosity=logging.WARNING):
self.ledger_module = ledger_module or get_ledger_from_environment()
self.manager_module = manager_module or get_manager_from_environment()
self.spv_module = get_spvserver_from_ledger(self.ledger_module)
self.blockchain_node = get_blockchain_node_from_ledger(self.ledger_module)
self.spv_node = SPVNode(self.spv_module)
self.wallet_node = WalletNode(self.manager_module, self.ledger_module.RegTestLedger)
set_logging(self.ledger_module, verbosity)
self.blockchain_started = False
self.spv_started = False
self.wallet_started = False
async def start_blockchain(self):
await self.blockchain_node.start()
await self.blockchain_node.generate(200)
self.blockchain_started = True
async def start_spv(self):
await self.spv_node.start()
self.spv_started = True
async def start_wallet(self):
await self.wallet_node.start()
self.wallet_started = True
async def start(self):
self.blockchain_started or await self.start_blockchain()
self.spv_started or await self.start_spv()
self.wallet_started or await self.start_wallet()
async def stop(self):
if self.wallet_started:
try:
await self.wallet_node.stop(cleanup=True)
except Exception as e:
print(e)
if self.spv_started:
try:
await self.spv_node.stop(cleanup=True)
except Exception as e:
print(e)
if self.blockchain_started:
try:
await self.blockchain_node.stop(cleanup=True)
except Exception as e:
print(e)
class WalletNode:
def __init__(self, manager_class: Type[BaseWalletManager], ledger_class: Type[BaseLedger],
verbose: bool = False) -> None:
self.manager_class = manager_class
self.ledger_class = ledger_class
self.verbose = verbose
self.manager: BaseWalletManager = None
self.ledger: BaseLedger = None
self.wallet: Wallet = None
self.account: BaseAccount = None
self.data_path: str = None
async def start(self):
self.data_path = tempfile.mkdtemp()
wallet_file_name = os.path.join(self.data_path, 'my_wallet.json')
with open(wallet_file_name, 'w') as wf:
wf.write('{"version": 1, "accounts": []}\n')
self.manager = self.manager_class.from_config({
'ledgers': {
self.ledger_class.get_id(): {
'default_servers': [('localhost', 1984)],
'data_path': self.data_path
}
},
'wallets': [wallet_file_name]
})
self.ledger = self.manager.ledgers[self.ledger_class]
self.wallet = self.manager.default_wallet
self.wallet.generate_account(self.ledger)
self.account = self.wallet.default_account
await self.manager.start()
async def stop(self, cleanup=True):
try:
await self.manager.stop()
finally:
cleanup and self.cleanup()
def cleanup(self):
shutil.rmtree(self.data_path, ignore_errors=True)
class SPVNode:
def __init__(self, coin_class):
self.coin_class = coin_class
self.server = None
self.data_path = None
async def start(self):
self.data_path = tempfile.mkdtemp()
conf = {
'DB_DIRECTORY': self.data_path,
'DAEMON_URL': 'http://rpcuser:rpcpassword@localhost:50001/',
'REORG_LIMIT': '100',
'TCP_PORT': '1984'
}
os.environ.update(conf)
self.server = Server(Env(self.coin_class))
await self.server.start()
async def stop(self, cleanup=True):
try:
await self.server.stop()  # assumes the torba.server Server exposes an async stop()
finally:
cleanup and self.cleanup()
def cleanup(self):
shutil.rmtree(self.data_path, ignore_errors=True)
class BlockchainProcess(asyncio.SubprocessProtocol):
IGNORE_OUTPUT = [
b'keypool keep',
b'keypool reserve',
b'keypool return',
]
def __init__(self, log):
self.ready = asyncio.Event()
self.stopped = asyncio.Event()
self.log = log
def pipe_data_received(self, fd, data):
if self.log and not any(ignore in data for ignore in self.IGNORE_OUTPUT):
if b'Error:' in data:
self.log.error(data.decode('ascii'))
else:
self.log.info(data.decode('ascii'))
if b'Error:' in data:
self.ready.set()
raise SystemError(data.decode('ascii'))
elif b'Done loading' in data:
self.ready.set()
elif b'Shutdown: done' in data:
self.stopped.set()
def process_exited(self):
self.stopped.set()
class BlockchainNode:
def __init__(self, url, daemon, cli):
self.latest_release_url = url
self.project_dir = os.path.dirname(os.path.dirname(__file__))
self.bin_dir = os.path.join(self.project_dir, 'bin')
self.daemon_bin = os.path.join(self.bin_dir, daemon)
self.cli_bin = os.path.join(self.bin_dir, cli)
self.log = logging.getLogger('blockchain')
self.data_path = None
self.protocol = None
self.transport = None
self._block_expected = 0
def is_expected_block(self, e: BlockHeightEvent):
return self._block_expected == e.height
@property
def exists(self):
return (
os.path.exists(self.cli_bin) and
os.path.exists(self.daemon_bin)
)
def download(self):
downloaded_file = os.path.join(
self.bin_dir,
self.latest_release_url[self.latest_release_url.rfind('/')+1:]
)
if not os.path.exists(self.bin_dir):
os.mkdir(self.bin_dir)
if not os.path.exists(downloaded_file):
self.log.info('Downloading: %s', self.latest_release_url)
r = requests.get(self.latest_release_url, stream=True)
with open(downloaded_file, 'wb') as f:
shutil.copyfileobj(r.raw, f)
self.log.info('Extracting: %s', downloaded_file)
if downloaded_file.endswith('.zip'):
with zipfile.ZipFile(downloaded_file) as zf:
zf.extractall(self.bin_dir)
# zipfile bug https://bugs.python.org/issue15795
os.chmod(self.cli_bin, 0o755)
os.chmod(self.daemon_bin, 0o755)
elif downloaded_file.endswith('.tar.gz'):
with tarfile.open(downloaded_file) as tar:
tar.extractall(self.bin_dir)
return self.exists
def ensure(self):
return self.exists or self.download()
async def start(self):
assert self.ensure()
self.data_path = tempfile.mkdtemp()
loop = asyncio.get_event_loop()
asyncio.get_child_watcher().attach_loop(loop)
command = (
self.daemon_bin,
'-datadir={}'.format(self.data_path),
'-printtoconsole', '-regtest', '-server', '-txindex',
'-rpcuser=rpcuser', '-rpcpassword=rpcpassword', '-rpcport=50001'
)
self.log.info(' '.join(command))
self.transport, self.protocol = await loop.subprocess_exec(
lambda: BlockchainProcess(self.log), *command
)
await self.protocol.ready.wait()
async def stop(self, cleanup=True):
try:
self.transport.terminate()
await self.protocol.stopped.wait()
finally:
if cleanup:
self.cleanup()
def cleanup(self):
shutil.rmtree(self.data_path, ignore_errors=True)
async def _cli_cmnd(self, *args):
cmnd_args = [
self.cli_bin, '-datadir={}'.format(self.data_path), '-regtest',
'-rpcuser=rpcuser', '-rpcpassword=rpcpassword', '-rpcport=50001'
] + list(args)
self.log.info(' '.join(cmnd_args))
loop = asyncio.get_event_loop()
asyncio.get_child_watcher().attach_loop(loop)
process = await asyncio.create_subprocess_exec(
*cmnd_args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
)
out, _ = await process.communicate()
self.log.info(out.decode().strip())
return out.decode().strip()
def generate(self, blocks):
self._block_expected += blocks
return self._cli_cmnd('generate', str(blocks))
def invalidateblock(self, block_hash):
return self._cli_cmnd('invalidateblock', block_hash)
def get_raw_change_address(self):
return self._cli_cmnd('getrawchangeaddress')
async def get_balance(self):
return float(await self._cli_cmnd('getbalance'))
def send_to_address(self, address, credits):
return self._cli_cmnd('sendtoaddress', address, str(credits))
def send_raw_transaction(self, tx):
return self._cli_cmnd('sendrawtransaction', tx.decode())
def decode_raw_transaction(self, tx):
return self._cli_cmnd('decoderawtransaction', hexlify(tx.raw).decode())
def get_raw_transaction(self, txid):
return self._cli_cmnd('getrawtransaction', txid, '1')
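A minimal sketch of driving the Conductor outside of a test case; mycoin stands in for a hypothetical torba-based ledger module, everything else uses the classes defined above:
import asyncio
import logging
from torba.testing.node import Conductor

async def main():
    import mycoin  # hypothetical ledger module exporting the attributes described above
    conductor = Conductor(ledger_module=mycoin, verbosity=logging.INFO)
    await conductor.start()  # full node (plus 200 regtest blocks), SPV server and wallet
    try:
        await conductor.blockchain_node.generate(1)
        print(await conductor.blockchain_node.get_balance())
    finally:
        await conductor.stop()

asyncio.get_event_loop().run_until_complete(main())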

147
torba/testing/service.py Normal file
View file

@ -0,0 +1,147 @@
import asyncio
import logging
from aiohttp.web import Application, WebSocketResponse, json_response
from aiohttp.http_websocket import WSMsgType, WSCloseCode
from .node import Conductor
PORT = 7954
class WebSocketLogHandler(logging.Handler):
def __init__(self, send_message):
super().__init__()
self.send_message = send_message
def emit(self, record):
try:
self.send_message({
'type': 'log',
'name': record.name,
'message': self.format(record)
})
except Exception:
self.handleError(record)
class TestingServiceAPI:
def __init__(self, stack: Conductor, loop: asyncio.AbstractEventLoop):
self.stack = stack
self.loop = loop
self.app = Application()
self.app.router.add_post('/start', self.start_stack)
self.app.router.add_post('/generate', self.generate)
self.app.router.add_post('/transfer', self.transfer)
self.app.router.add_post('/balance', self.balance)
self.app.router.add_get('/log', self.log)
self.app['websockets'] = set()
self.app.on_shutdown.append(self.on_shutdown)
self.handler = self.app.make_handler()
self.server = None
async def start(self):
self.server = await self.loop.create_server(
self.handler, '0.0.0.0', PORT
)
print('serving on', self.server.sockets[0].getsockname())
async def stop(self):
await self.stack.stop()
self.server.close()
await self.server.wait_closed()
await self.app.shutdown()
await self.handler.shutdown(60.0)
await self.app.cleanup()
async def start_stack(self, _):
handler = WebSocketLogHandler(self.send_message)
logging.getLogger('blockchain').setLevel(logging.DEBUG)
logging.getLogger('blockchain').addHandler(handler)
logging.getLogger('electrumx').setLevel(logging.DEBUG)
logging.getLogger('electrumx').addHandler(handler)
logging.getLogger('Controller').setLevel(logging.DEBUG)
logging.getLogger('Controller').addHandler(handler)
logging.getLogger('LBRYBlockProcessor').setLevel(logging.DEBUG)
logging.getLogger('LBRYBlockProcessor').addHandler(handler)
logging.getLogger('LBCDaemon').setLevel(logging.DEBUG)
logging.getLogger('LBCDaemon').addHandler(handler)
logging.getLogger('torba').setLevel(logging.DEBUG)
logging.getLogger('torba').addHandler(handler)
logging.getLogger(self.stack.ledger_module.__name__).setLevel(logging.DEBUG)
logging.getLogger(self.stack.ledger_module.__name__).addHandler(handler)
logging.getLogger(self.stack.ledger_module.__electrumx__.split('.')[0]).setLevel(logging.DEBUG)
logging.getLogger(self.stack.ledger_module.__electrumx__.split('.')[0]).addHandler(handler)
#await self.stack.start()
self.stack.blockchain_started or await self.stack.start_blockchain()
self.send_message({'type': 'service', 'name': 'blockchain'})
self.stack.spv_started or await self.stack.start_spv()
self.send_message({'type': 'service', 'name': 'spv'})
self.stack.wallet_started or await self.stack.start_wallet()
self.send_message({'type': 'service', 'name': 'wallet'})
self.stack.wallet_node.ledger.on_header.listen(self.on_status)
self.stack.wallet_node.ledger.on_transaction.listen(self.on_status)
return json_response({'started': True})
async def generate(self, request):
data = await request.post()
blocks = data.get('blocks', 1)
await self.stack.blockchain_node.generate(int(blocks))
return json_response({'blocks': blocks})
async def transfer(self, request):
data = await request.post()
address = data.get('address')
if not address:
address = await self.stack.wallet_node.account.receiving.get_or_create_usable_address()
amount = data.get('amount', 1)
txid = await self.stack.blockchain_node.send_to_address(address, amount)
await self.stack.wallet_node.ledger.on_transaction.where(
lambda e: e.tx.id == txid and e.address == address
)
return json_response({
'address': address,
'amount': amount,
'txid': txid
})
async def balance(self, _):
return json_response({
'balance': await self.stack.blockchain_node.get_balance()
})
async def log(self, request):
ws = WebSocketResponse()
await ws.prepare(request)
self.app['websockets'].add(ws)
try:
async for msg in ws:
if msg.type == WSMsgType.TEXT:
if msg.data == 'close':
await ws.close()
elif msg.type == WSMsgType.ERROR:
print('ws connection closed with exception %s' %
ws.exception())
finally:
self.app['websockets'].remove(ws)
return ws
@staticmethod
async def on_shutdown(app):
for ws in app['websockets']:
await ws.close(code=WSCloseCode.GOING_AWAY, message='Server shutdown')
async def on_status(self, _):
if not self.app['websockets']:
return
self.send_message({
'type': 'status',
'height': self.stack.wallet_node.ledger.headers.height,
'balance': await self.stack.wallet_node.account.get_balance(),
'miner': await self.stack.blockchain_node.get_balance()
})
def send_message(self, msg):
for ws in self.app['websockets']:
asyncio.ensure_future(ws.send_json(msg))
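A hedged sketch of a client for this service, assuming the service is already listening on PORT; the endpoint paths and response shapes come from the handlers above, and the websocket loop streams events until interrupted:
import asyncio
import aiohttp

async def drive(base_url='http://localhost:7954'):
    async with aiohttp.ClientSession() as session:
        async with session.post(base_url + '/start') as resp:
            print(await resp.json())   # {'started': True}
        async with session.ws_connect(base_url + '/log') as ws:
            async with session.post(base_url + '/generate', data={'blocks': 1}) as resp:
                print(await resp.json())   # echoes the requested block count
            async for msg in ws:   # 'status' and 'log' events pushed by the service
                print(msg.json())

asyncio.get_event_loop().run_until_complete(drive())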

176
torba/testing/testcase.py Normal file
View file

@ -0,0 +1,176 @@
import asyncio
import unittest
import logging
from unittest.case import _Outcome
from .node import Conductor
try:
from asyncio.runners import _cancel_all_tasks
except ImportError:
# asyncio.runners._cancel_all_tasks only exists in Python 3.7+; on older
# interpreters leftover tasks are simply not cancelled during teardown.
def _cancel_all_tasks(loop):
pass
class AsyncioTestCase(unittest.TestCase):
# Implementation inspired by discussion:
# https://bugs.python.org/issue32972
async def asyncSetUp(self):
pass
async def asyncTearDown(self):
pass
async def doAsyncCleanups(self):
pass
def run(self, result=None):
orig_result = result
if result is None:
result = self.defaultTestResult()
startTestRun = getattr(result, 'startTestRun', None)
if startTestRun is not None:
startTestRun()
result.startTest(self)
testMethod = getattr(self, self._testMethodName)
if (getattr(self.__class__, "__unittest_skip__", False) or
getattr(testMethod, "__unittest_skip__", False)):
# If the class or method was skipped.
try:
skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
or getattr(testMethod, '__unittest_skip_why__', ''))
self._addSkip(result, self, skip_why)
finally:
result.stopTest(self)
return
expecting_failure_method = getattr(testMethod,
"__unittest_expecting_failure__", False)
expecting_failure_class = getattr(self,
"__unittest_expecting_failure__", False)
expecting_failure = expecting_failure_class or expecting_failure_method
outcome = _Outcome(result)
try:
self._outcome = outcome
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
loop.set_debug(True)
with outcome.testPartExecutor(self):
self.setUp()
loop.run_until_complete(self.asyncSetUp())
if outcome.success:
outcome.expecting_failure = expecting_failure
with outcome.testPartExecutor(self, isTest=True):
possible_coroutine = testMethod()
if asyncio.iscoroutine(possible_coroutine):
loop.run_until_complete(possible_coroutine)
outcome.expecting_failure = False
with outcome.testPartExecutor(self):
loop.run_until_complete(self.asyncTearDown())
self.tearDown()
finally:
try:
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
asyncio.set_event_loop(None)
loop.close()
self.doCleanups()
for test, reason in outcome.skipped:
self._addSkip(result, test, reason)
self._feedErrorsToResult(result, outcome.errors)
if outcome.success:
if expecting_failure:
if outcome.expectedFailure:
self._addExpectedFailure(result, outcome.expectedFailure)
else:
self._addUnexpectedSuccess(result)
else:
result.addSuccess(self)
return result
finally:
result.stopTest(self)
if orig_result is None:
stopTestRun = getattr(result, 'stopTestRun', None)
if stopTestRun is not None:
stopTestRun()
# explicitly break reference cycles:
# outcome.errors -> frame -> outcome -> outcome.errors
# outcome.expectedFailure -> frame -> outcome -> outcome.expectedFailure
outcome.errors.clear()
outcome.expectedFailure = None
# clear the outcome, no more needed
self._outcome = None
class IntegrationTestCase(AsyncioTestCase):
LEDGER = None
MANAGER = None
VERBOSITY = logging.WARNING
async def asyncSetUp(self):
self.conductor = Conductor(
ledger_module=self.LEDGER, manager_module=self.MANAGER, verbosity=self.VERBOSITY
)
await self.conductor.start()
self.blockchain = self.conductor.blockchain_node
self.manager = self.conductor.wallet_node.manager
self.ledger = self.conductor.wallet_node.ledger
self.wallet = self.conductor.wallet_node.wallet
self.account = self.conductor.wallet_node.wallet.default_account
async def asyncTearDown(self):
await self.conductor.stop()
def broadcast(self, tx):
return self.ledger.broadcast(tx)
def get_balance(self, account=None, confirmations=0):
if account is None:
return self.manager.get_balance(confirmations=confirmations)
else:
return account.get_balance(confirmations=confirmations)
async def on_header(self, height):
if self.ledger.headers.height < height:
await self.ledger.on_header.where(
lambda e: e.height == height
)
return True
def on_transaction_id(self, txid):
return self.ledger.on_transaction.where(
lambda e: e.tx.id == txid
)
def on_transaction_address(self, tx, address):
return self.ledger.on_transaction.where(
lambda e: e.tx.id == tx.id and e.address == address
)
async def on_transaction(self, tx):
addresses = await self.get_tx_addresses(tx, self.ledger)
await asyncio.wait([
self.ledger.on_transaction.where(lambda e: e.address == address)
for address in addresses
])
async def get_tx_addresses(self, tx, ledger):
addresses = set()
for txo in tx.outputs:
address = ledger.hash160_to_address(txo.script.values['pubkey_hash'])
record = await ledger.db.get_address(address=address)
if record is not None:
addresses.add(address)
return list(addresses)
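A minimal sketch of a test built on this base class; ExampleTransactionTests is hypothetical, but every attribute and helper it touches is set up in asyncSetUp above:
import logging
from torba.testing import IntegrationTestCase

class ExampleTransactionTests(IntegrationTestCase):
    VERBOSITY = logging.WARNING

    async def test_receiving_funds(self):
        address = await self.account.receiving.get_or_create_usable_address()
        txid = await self.blockchain.send_to_address(address, 1.0)
        await self.on_transaction_id(txid)   # wallet has seen the mempool transaction
        await self.blockchain.generate(1)    # mine a block to confirm it
        self.assertGreater(await self.get_balance(self.account), 0)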

11
tox.ini
View file

@ -8,17 +8,14 @@ TESTTYPE =
integration: integration
[testenv]
deps =
coverage
../orchstr8
../electrumx
extras = test
deps = coverage
extras = test,server
changedir = {toxinidir}/tests
setenv =
integration: LEDGER={envname}
integration: TORBA_LEDGER={envname}
commands =
unit: coverage run -p --source={envsitepackagesdir}/torba -m unittest discover -t . unit
integration: orchstr8 download
integration: torba download
integration: coverage run -p --source={envsitepackagesdir}/torba -m unittest integration.test_transactions
integration: coverage run -p --source={envsitepackagesdir}/torba -m unittest integration.test_reconnect
# Too slow on Travis