forked from LBRYCommunity/lbry-sdk
Merge remote-tracking branch 'torba/move-to-lbry'
This commit is contained in:
commit
d530599b5d
102 changed files with 30328 additions and 0 deletions
21
torba/.gitignore
vendored
Normal file
21
torba/.gitignore
vendored
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
# packaging
|
||||||
|
torba.egg-info/
|
||||||
|
dist/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
.idea/
|
||||||
|
|
||||||
|
# testing
|
||||||
|
.tox/
|
||||||
|
tests/client_tests/unit/bitcoin_headers
|
||||||
|
torba/bin
|
||||||
|
|
||||||
|
# cache and logs
|
||||||
|
__pycache__/
|
||||||
|
.mypy_cache/
|
||||||
|
_trial_temp/
|
||||||
|
_trial_temp-*/
|
||||||
|
|
||||||
|
# OS X DS_Store
|
||||||
|
*.DS_Store
|
||||||
|
|
36
torba/.travis.yml
Normal file
36
torba/.travis.yml
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
dist: xenial
|
||||||
|
sudo: true
|
||||||
|
language: python
|
||||||
|
python: "3.7"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
include:
|
||||||
|
|
||||||
|
- stage: code quality
|
||||||
|
name: "pylint & mypy"
|
||||||
|
install:
|
||||||
|
- pip install pylint mypy
|
||||||
|
- pip install -e .
|
||||||
|
script:
|
||||||
|
- pylint --rcfile=setup.cfg torba
|
||||||
|
- mypy --ignore-missing-imports torba
|
||||||
|
after_success: skip
|
||||||
|
|
||||||
|
- &tests
|
||||||
|
stage: tests
|
||||||
|
env: TESTTYPE=unit
|
||||||
|
install:
|
||||||
|
- pip install tox-travis
|
||||||
|
script: tox
|
||||||
|
- <<: *tests
|
||||||
|
env: TESTTYPE=integration
|
||||||
|
|
||||||
|
after_success:
|
||||||
|
- pip install coverage
|
||||||
|
- coverage combine tests/
|
||||||
|
- bash <(curl -s https://codecov.io/bash)
|
||||||
|
|
||||||
|
cache:
|
||||||
|
directories:
|
||||||
|
- $HOME/.cache/pip
|
||||||
|
- $TRAVIS_BUILD_DIR/.tox
|
0
torba/CHANGELOG.md
Normal file
0
torba/CHANGELOG.md
Normal file
21
torba/LICENSE
Normal file
21
torba/LICENSE
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2018 LBRY Inc.
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
4
torba/MANIFEST.in
Normal file
4
torba/MANIFEST.in
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
include README.md
|
||||||
|
include CHANGELOG.md
|
||||||
|
include LICENSE
|
||||||
|
recursive-include torba *.txt *.py
|
3
torba/README.md
Normal file
3
torba/README.md
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
# <img src="https://raw.githubusercontent.com/lbryio/torba/master/torba.png" alt="Torba" width="42" height="30" /> Torba [![Build Status](https://travis-ci.org/lbryio/torba.svg?branch=master)](https://travis-ci.org/lbryio/torba) [![Test Coverage](https://codecov.io/gh/lbryio/torba/branch/master/graph/badge.svg)](https://codecov.io/gh/lbryio/torba)
|
||||||
|
|
||||||
|
A new wallet library to help bitcoin based projects build fast, correct and scalable crypto currency wallets in Python.
|
35
torba/setup.cfg
Normal file
35
torba/setup.cfg
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
[coverage:run]
|
||||||
|
branch = True
|
||||||
|
|
||||||
|
[coverage:paths]
|
||||||
|
source =
|
||||||
|
torba
|
||||||
|
.tox/*/lib/python*/site-packages/torba
|
||||||
|
|
||||||
|
[cryptography.*,coincurve.*,pbkdf2]
|
||||||
|
ignore_missing_imports = True
|
||||||
|
|
||||||
|
[pylint]
|
||||||
|
ignore=words,server,workbench,rpc
|
||||||
|
max-args=10
|
||||||
|
max-line-length=110
|
||||||
|
good-names=T,t,n,i,j,k,x,y,s,f,d,h,c,e,op,db,tx,io,cachedproperty,log,id
|
||||||
|
valid-metaclass-classmethod-first-arg=mcs
|
||||||
|
disable=
|
||||||
|
fixme,
|
||||||
|
broad-except,
|
||||||
|
no-else-return,
|
||||||
|
cyclic-import,
|
||||||
|
missing-docstring,
|
||||||
|
duplicate-code,
|
||||||
|
expression-not-assigned,
|
||||||
|
inconsistent-return-statements,
|
||||||
|
too-few-public-methods,
|
||||||
|
too-many-locals,
|
||||||
|
too-many-branches,
|
||||||
|
too-many-arguments,
|
||||||
|
too-many-statements,
|
||||||
|
too-many-public-methods,
|
||||||
|
too-many-instance-attributes,
|
||||||
|
protected-access,
|
||||||
|
unused-argument
|
68
torba/setup.py
Normal file
68
torba/setup.py
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
|
import torba
|
||||||
|
|
||||||
|
BASE = os.path.dirname(__file__)
|
||||||
|
with open(os.path.join(BASE, 'README.md'), encoding='utf-8') as fh:
|
||||||
|
long_description = fh.read()
|
||||||
|
|
||||||
|
REQUIRES = [
|
||||||
|
'aiohttp==3.5.4',
|
||||||
|
'cffi==1.12.1', # TODO: 1.12.2 fails on travis in wine
|
||||||
|
'coincurve==11.0.0',
|
||||||
|
'pbkdf2==1.3',
|
||||||
|
'cryptography==2.5',
|
||||||
|
'attrs==18.2.0',
|
||||||
|
'pylru==1.1.0'
|
||||||
|
]
|
||||||
|
if sys.platform.startswith('linux'):
|
||||||
|
REQUIRES.append('plyvel==1.0.5')
|
||||||
|
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name='torba',
|
||||||
|
version=torba.__version__,
|
||||||
|
url='https://github.com/lbryio/torba',
|
||||||
|
license='MIT',
|
||||||
|
author='LBRY Inc.',
|
||||||
|
author_email='hello@lbry.io',
|
||||||
|
description='Wallet client/server framework for bitcoin based currencies.',
|
||||||
|
long_description=long_description,
|
||||||
|
long_description_content_type="text/markdown",
|
||||||
|
keywords='wallet,crypto,currency,money,bitcoin,electrum,electrumx',
|
||||||
|
classifiers=[
|
||||||
|
'Framework :: AsyncIO',
|
||||||
|
'Intended Audience :: Developers',
|
||||||
|
'Intended Audience :: System Administrators',
|
||||||
|
'License :: OSI Approved :: MIT License',
|
||||||
|
'Programming Language :: Python :: 3',
|
||||||
|
'Operating System :: OS Independent',
|
||||||
|
'Topic :: Internet',
|
||||||
|
'Topic :: Software Development :: Testing',
|
||||||
|
'Topic :: Software Development :: Libraries :: Python Modules',
|
||||||
|
'Topic :: System :: Benchmark',
|
||||||
|
'Topic :: System :: Distributed Computing',
|
||||||
|
'Topic :: Utilities',
|
||||||
|
],
|
||||||
|
packages=find_packages(exclude=('tests',)),
|
||||||
|
python_requires='>=3.6',
|
||||||
|
install_requires=REQUIRES,
|
||||||
|
extras_require={
|
||||||
|
'gui': (
|
||||||
|
'pyside2',
|
||||||
|
)
|
||||||
|
},
|
||||||
|
entry_points={
|
||||||
|
'console_scripts': [
|
||||||
|
'torba-client=torba.client.cli:main',
|
||||||
|
'torba-server=torba.server.cli:main',
|
||||||
|
'orchstr8=torba.orchstr8.cli:main',
|
||||||
|
],
|
||||||
|
'gui_scripts': [
|
||||||
|
'torba=torba.ui:main [gui]',
|
||||||
|
'torba-workbench=torba.workbench:main [gui]',
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
0
torba/tests/client_tests/__init__.py
Normal file
0
torba/tests/client_tests/__init__.py
Normal file
0
torba/tests/client_tests/integration/__init__.py
Normal file
0
torba/tests/client_tests/integration/__init__.py
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
import logging
|
||||||
|
from torba.testcase import IntegrationTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class BlockchainReorganizationTests(IntegrationTestCase):
|
||||||
|
|
||||||
|
VERBOSITY = logging.WARN
|
||||||
|
|
||||||
|
async def assertBlockHash(self, height):
|
||||||
|
self.assertEqual(
|
||||||
|
self.ledger.headers.hash(height).decode(),
|
||||||
|
await self.blockchain.get_block_hash(height)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_reorg(self):
|
||||||
|
# invalidate current block, move forward 2
|
||||||
|
self.assertEqual(self.ledger.headers.height, 200)
|
||||||
|
await self.assertBlockHash(200)
|
||||||
|
await self.blockchain.invalidate_block(self.ledger.headers.hash(200).decode())
|
||||||
|
await self.blockchain.generate(2)
|
||||||
|
await self.ledger.on_header.where(lambda e: e.height == 201)
|
||||||
|
self.assertEqual(self.ledger.headers.height, 201)
|
||||||
|
await self.assertBlockHash(200)
|
||||||
|
await self.assertBlockHash(201)
|
||||||
|
|
||||||
|
# invalidate current block, move forward 3
|
||||||
|
await self.blockchain.invalidate_block(self.ledger.headers.hash(200).decode())
|
||||||
|
await self.blockchain.generate(3)
|
||||||
|
await self.ledger.on_header.where(lambda e: e.height == 202)
|
||||||
|
self.assertEqual(self.ledger.headers.height, 202)
|
||||||
|
await self.assertBlockHash(200)
|
||||||
|
await self.assertBlockHash(201)
|
||||||
|
await self.assertBlockHash(202)
|
74
torba/tests/client_tests/integration/test_reconnect.py
Normal file
74
torba/tests/client_tests/integration/test_reconnect.py
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from torba.client.basenetwork import BaseNetwork
|
||||||
|
from torba.rpc import RPCSession
|
||||||
|
from torba.testcase import IntegrationTestCase, AsyncioTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class ReconnectTests(IntegrationTestCase):
|
||||||
|
|
||||||
|
VERBOSITY = logging.WARN
|
||||||
|
|
||||||
|
async def test_connection_drop_still_receives_events_after_reconnected(self):
|
||||||
|
address1 = await self.account.receiving.get_or_create_usable_address()
|
||||||
|
self.ledger.network.client.connection_lost(Exception())
|
||||||
|
sendtxid = await self.blockchain.send_to_address(address1, 1.1337)
|
||||||
|
await self.on_transaction_id(sendtxid) # mempool
|
||||||
|
await self.blockchain.generate(1)
|
||||||
|
await self.on_transaction_id(sendtxid) # confirmed
|
||||||
|
|
||||||
|
await self.assertBalance(self.account, '1.1337')
|
||||||
|
# is it real? are we rich!? let me see this tx...
|
||||||
|
d = self.ledger.network.get_transaction(sendtxid)
|
||||||
|
# what's that smoke on my ethernet cable? oh no!
|
||||||
|
self.ledger.network.client.connection_lost(Exception())
|
||||||
|
with self.assertRaises(asyncio.CancelledError):
|
||||||
|
await d
|
||||||
|
# rich but offline? no way, no water, let's retry
|
||||||
|
with self.assertRaisesRegex(ConnectionError, 'connection is not available'):
|
||||||
|
await self.ledger.network.get_transaction(sendtxid)
|
||||||
|
# * goes to pick some water outside... * time passes by and another donation comes in
|
||||||
|
sendtxid = await self.blockchain.send_to_address(address1, 42)
|
||||||
|
await self.blockchain.generate(1)
|
||||||
|
# omg, the burned cable still works! torba is fire proof!
|
||||||
|
await self.ledger.network.get_transaction(sendtxid)
|
||||||
|
|
||||||
|
async def test_timeout_then_reconnect(self):
|
||||||
|
await self.conductor.spv_node.stop()
|
||||||
|
self.assertFalse(self.ledger.network.is_connected)
|
||||||
|
await self.conductor.spv_node.start(self.conductor.blockchain_node)
|
||||||
|
await self.ledger.network.on_connected.first
|
||||||
|
self.assertTrue(self.ledger.network.is_connected)
|
||||||
|
|
||||||
|
|
||||||
|
class ServerPickingTestCase(AsyncioTestCase):
|
||||||
|
async def _make_fake_server(self, latency=1.0, port=1337):
|
||||||
|
# local fake server with artificial latency
|
||||||
|
proto = RPCSession()
|
||||||
|
proto.handle_request = lambda _: asyncio.sleep(latency)
|
||||||
|
server = await self.loop.create_server(lambda: proto, host='127.0.0.1', port=port)
|
||||||
|
self.addCleanup(server.close)
|
||||||
|
return ('127.0.0.1', port)
|
||||||
|
|
||||||
|
async def test_pick_fastest(self):
|
||||||
|
ledger = Mock(config={
|
||||||
|
'default_servers': [
|
||||||
|
await self._make_fake_server(latency=1.5, port=1340),
|
||||||
|
await self._make_fake_server(latency=0.1, port=1337),
|
||||||
|
await self._make_fake_server(latency=1.0, port=1339),
|
||||||
|
await self._make_fake_server(latency=0.5, port=1338),
|
||||||
|
],
|
||||||
|
'connect_timeout': 30
|
||||||
|
})
|
||||||
|
|
||||||
|
network = BaseNetwork(ledger)
|
||||||
|
self.addCleanup(network.stop)
|
||||||
|
asyncio.ensure_future(network.start())
|
||||||
|
await asyncio.wait_for(network.on_connected.first, timeout=1)
|
||||||
|
self.assertTrue(network.is_connected)
|
||||||
|
self.assertEqual(network.client.server, ('127.0.0.1', 1337))
|
||||||
|
# ensure we are connected to all of them
|
||||||
|
self.assertEqual(len(network.session_pool.sessions), 4)
|
||||||
|
self.assertTrue(all([not session.is_closing() for session in network.session_pool.sessions]))
|
97
torba/tests/client_tests/integration/test_sync.py
Normal file
97
torba/tests/client_tests/integration/test_sync.py
Normal file
|
@ -0,0 +1,97 @@
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from torba.testcase import IntegrationTestCase, WalletNode
|
||||||
|
from torba.client.constants import CENT
|
||||||
|
|
||||||
|
|
||||||
|
class SyncTests(IntegrationTestCase):
|
||||||
|
|
||||||
|
VERBOSITY = logging.WARN
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.api_port = 5280
|
||||||
|
self.started_nodes = []
|
||||||
|
|
||||||
|
async def asyncTearDown(self):
|
||||||
|
for node in self.started_nodes:
|
||||||
|
try:
|
||||||
|
await node.stop(cleanup=True)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
await super().asyncTearDown()
|
||||||
|
|
||||||
|
async def make_wallet_node(self, seed=None):
|
||||||
|
self.api_port += 1
|
||||||
|
wallet_node = WalletNode(
|
||||||
|
self.wallet_node.manager_class,
|
||||||
|
self.wallet_node.ledger_class,
|
||||||
|
port=self.api_port
|
||||||
|
)
|
||||||
|
await wallet_node.start(self.conductor.spv_node, seed)
|
||||||
|
self.started_nodes.append(wallet_node)
|
||||||
|
return wallet_node
|
||||||
|
|
||||||
|
async def test_nodes_with_same_account_stay_in_sync(self):
|
||||||
|
# destination node/account for receiving TXs
|
||||||
|
node0 = await self.make_wallet_node()
|
||||||
|
account0 = node0.account
|
||||||
|
# main node/account creating TXs
|
||||||
|
node1 = self.wallet_node
|
||||||
|
account1 = self.wallet_node.account
|
||||||
|
# mirror node/account, expected to reflect everything in main node as it happens
|
||||||
|
node2 = await self.make_wallet_node(account1.seed)
|
||||||
|
account2 = node2.account
|
||||||
|
|
||||||
|
self.assertNotEqual(account0.id, account1.id)
|
||||||
|
self.assertEqual(account1.id, account2.id)
|
||||||
|
await self.assertBalance(account0, '0.0')
|
||||||
|
await self.assertBalance(account1, '0.0')
|
||||||
|
await self.assertBalance(account2, '0.0')
|
||||||
|
self.assertEqual(await account0.get_address_count(chain=0), 20)
|
||||||
|
self.assertEqual(await account1.get_address_count(chain=0), 20)
|
||||||
|
self.assertEqual(await account2.get_address_count(chain=0), 20)
|
||||||
|
self.assertEqual(await account1.get_address_count(chain=1), 6)
|
||||||
|
self.assertEqual(await account2.get_address_count(chain=1), 6)
|
||||||
|
|
||||||
|
# check that main node and mirror node generate 5 address to fill gap
|
||||||
|
fifth_address = (await account1.receiving.get_addresses())[4]
|
||||||
|
await self.blockchain.send_to_address(fifth_address, 1.00)
|
||||||
|
await asyncio.wait([
|
||||||
|
account1.ledger.on_address.first,
|
||||||
|
account2.ledger.on_address.first
|
||||||
|
])
|
||||||
|
self.assertEqual(await account1.get_address_count(chain=0), 25)
|
||||||
|
self.assertEqual(await account2.get_address_count(chain=0), 25)
|
||||||
|
await self.assertBalance(account1, '1.0')
|
||||||
|
await self.assertBalance(account2, '1.0')
|
||||||
|
|
||||||
|
await self.blockchain.generate(1)
|
||||||
|
|
||||||
|
# pay 0.01 from main node to receiving node, would have increased change addresses
|
||||||
|
address0 = (await account0.receiving.get_addresses())[0]
|
||||||
|
hash0 = self.ledger.address_to_hash160(address0)
|
||||||
|
tx = await account1.ledger.transaction_class.create(
|
||||||
|
[],
|
||||||
|
[self.ledger.transaction_class.output_class.pay_pubkey_hash(CENT, hash0)],
|
||||||
|
[account1], account1
|
||||||
|
)
|
||||||
|
await self.broadcast(tx)
|
||||||
|
await asyncio.wait([
|
||||||
|
account0.ledger.wait(tx),
|
||||||
|
account1.ledger.wait(tx),
|
||||||
|
account2.ledger.wait(tx),
|
||||||
|
])
|
||||||
|
self.assertEqual(await account0.get_address_count(chain=0), 21)
|
||||||
|
self.assertGreater(await account1.get_address_count(chain=1), 6)
|
||||||
|
self.assertGreater(await account2.get_address_count(chain=1), 6)
|
||||||
|
await self.assertBalance(account0, '0.01')
|
||||||
|
await self.assertBalance(account1, '0.989876')
|
||||||
|
await self.assertBalance(account2, '0.989876')
|
||||||
|
|
||||||
|
await self.blockchain.generate(1)
|
||||||
|
|
||||||
|
# create a new mirror node and see if it syncs to same balance from scratch
|
||||||
|
node3 = await self.make_wallet_node(account1.seed)
|
||||||
|
account3 = node3.account
|
||||||
|
await self.assertBalance(account3, '0.989876')
|
131
torba/tests/client_tests/integration/test_transactions.py
Normal file
131
torba/tests/client_tests/integration/test_transactions.py
Normal file
|
@ -0,0 +1,131 @@
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
from itertools import chain
|
||||||
|
from torba.testcase import IntegrationTestCase
|
||||||
|
from torba.client.util import satoshis_to_coins, coins_to_satoshis
|
||||||
|
|
||||||
|
|
||||||
|
class BasicTransactionTests(IntegrationTestCase):
|
||||||
|
|
||||||
|
VERBOSITY = logging.WARN
|
||||||
|
|
||||||
|
async def test_variety_of_transactions_and_longish_history(self):
|
||||||
|
await self.blockchain.generate(300)
|
||||||
|
await self.assertBalance(self.account, '0.0')
|
||||||
|
addresses = await self.account.receiving.get_addresses()
|
||||||
|
|
||||||
|
# send 10 coins to first 10 receiving addresses and then 10 transactions worth 10 coins each
|
||||||
|
# to the 10th receiving address for a total of 30 UTXOs on the entire account
|
||||||
|
sends = list(chain(
|
||||||
|
(self.blockchain.send_to_address(address, 10) for address in addresses[:10]),
|
||||||
|
(self.blockchain.send_to_address(addresses[9], 10) for _ in range(10))
|
||||||
|
))
|
||||||
|
# use batching to reduce issues with send_to_address on cli
|
||||||
|
for batch in range(0, len(sends), 10):
|
||||||
|
txids = await asyncio.gather(*sends[batch:batch+10])
|
||||||
|
await asyncio.wait([self.on_transaction_id(txid) for txid in txids])
|
||||||
|
await self.assertBalance(self.account, '200.0')
|
||||||
|
self.assertEqual(20, await self.account.get_utxo_count())
|
||||||
|
|
||||||
|
# address gap should have increase by 10 to cover the first 10 addresses we've used up
|
||||||
|
addresses = await self.account.receiving.get_addresses()
|
||||||
|
self.assertEqual(30, len(addresses))
|
||||||
|
|
||||||
|
# there used to be a sync bug which failed to save TXIs between
|
||||||
|
# daemon restarts, clearing cache replicates that behavior
|
||||||
|
self.ledger._tx_cache.clear()
|
||||||
|
|
||||||
|
# spend from each of the first 10 addresses to the subsequent 10 addresses
|
||||||
|
txs = []
|
||||||
|
for address in addresses[10:20]:
|
||||||
|
txs.append(await self.ledger.transaction_class.create(
|
||||||
|
[],
|
||||||
|
[self.ledger.transaction_class.output_class.pay_pubkey_hash(
|
||||||
|
coins_to_satoshis('1.0'), self.ledger.address_to_hash160(address)
|
||||||
|
)],
|
||||||
|
[self.account], self.account
|
||||||
|
))
|
||||||
|
await asyncio.wait([self.broadcast(tx) for tx in txs])
|
||||||
|
await asyncio.wait([self.ledger.wait(tx) for tx in txs])
|
||||||
|
|
||||||
|
# verify that a previous bug which failed to save TXIs doesn't come back
|
||||||
|
# this check must happen before generating a new block
|
||||||
|
self.assertTrue(all([
|
||||||
|
tx.inputs[0].txo_ref.txo is not None
|
||||||
|
for tx in await self.ledger.db.get_transactions(txid__in=[tx.id for tx in txs])
|
||||||
|
]))
|
||||||
|
|
||||||
|
await self.blockchain.generate(1)
|
||||||
|
await asyncio.wait([self.ledger.wait(tx) for tx in txs])
|
||||||
|
await self.assertBalance(self.account, '199.99876')
|
||||||
|
|
||||||
|
# 10 of the UTXOs have been split into a 1 coin UTXO and a 9 UTXO change
|
||||||
|
self.assertEqual(30, await self.account.get_utxo_count())
|
||||||
|
|
||||||
|
# spend all 30 UTXOs into a a 199 coin UTXO and change
|
||||||
|
tx = await self.ledger.transaction_class.create(
|
||||||
|
[],
|
||||||
|
[self.ledger.transaction_class.output_class.pay_pubkey_hash(
|
||||||
|
coins_to_satoshis('199.0'), self.ledger.address_to_hash160(addresses[-1])
|
||||||
|
)],
|
||||||
|
[self.account], self.account
|
||||||
|
)
|
||||||
|
await self.broadcast(tx)
|
||||||
|
await self.ledger.wait(tx)
|
||||||
|
await self.blockchain.generate(1)
|
||||||
|
await self.ledger.wait(tx)
|
||||||
|
|
||||||
|
self.assertEqual(2, await self.account.get_utxo_count()) # 199 + change
|
||||||
|
await self.assertBalance(self.account, '199.99649')
|
||||||
|
|
||||||
|
async def test_sending_and_receiving(self):
|
||||||
|
account1, account2 = self.account, self.wallet.generate_account(self.ledger)
|
||||||
|
await self.ledger.subscribe_account(account2)
|
||||||
|
|
||||||
|
await self.assertBalance(account1, '0.0')
|
||||||
|
await self.assertBalance(account2, '0.0')
|
||||||
|
|
||||||
|
addresses = await self.account.receiving.get_addresses()
|
||||||
|
txids = await asyncio.gather(*(
|
||||||
|
self.blockchain.send_to_address(address, 1.1) for address in addresses[:5]
|
||||||
|
))
|
||||||
|
await asyncio.wait([self.on_transaction_id(txid) for txid in txids]) # mempool
|
||||||
|
await self.blockchain.generate(1)
|
||||||
|
await asyncio.wait([self.on_transaction_id(txid) for txid in txids]) # confirmed
|
||||||
|
await self.assertBalance(account1, '5.5')
|
||||||
|
await self.assertBalance(account2, '0.0')
|
||||||
|
|
||||||
|
address2 = await account2.receiving.get_or_create_usable_address()
|
||||||
|
tx = await self.ledger.transaction_class.create(
|
||||||
|
[],
|
||||||
|
[self.ledger.transaction_class.output_class.pay_pubkey_hash(
|
||||||
|
coins_to_satoshis('2.0'), self.ledger.address_to_hash160(address2)
|
||||||
|
)],
|
||||||
|
[account1], account1
|
||||||
|
)
|
||||||
|
await self.broadcast(tx)
|
||||||
|
await self.ledger.wait(tx) # mempool
|
||||||
|
await self.blockchain.generate(1)
|
||||||
|
await self.ledger.wait(tx) # confirmed
|
||||||
|
|
||||||
|
await self.assertBalance(account1, '3.499802')
|
||||||
|
await self.assertBalance(account2, '2.0')
|
||||||
|
|
||||||
|
utxos = await self.account.get_utxos()
|
||||||
|
tx = await self.ledger.transaction_class.create(
|
||||||
|
[self.ledger.transaction_class.input_class.spend(utxos[0])],
|
||||||
|
[],
|
||||||
|
[account1], account1
|
||||||
|
)
|
||||||
|
await self.broadcast(tx)
|
||||||
|
await self.ledger.wait(tx) # mempool
|
||||||
|
await self.blockchain.generate(1)
|
||||||
|
await self.ledger.wait(tx) # confirmed
|
||||||
|
|
||||||
|
tx = (await account1.get_transactions())[1]
|
||||||
|
self.assertEqual(satoshis_to_coins(tx.inputs[0].amount), '1.1')
|
||||||
|
self.assertEqual(satoshis_to_coins(tx.inputs[1].amount), '1.1')
|
||||||
|
self.assertEqual(satoshis_to_coins(tx.outputs[0].amount), '2.0')
|
||||||
|
self.assertEqual(tx.outputs[0].get_address(self.ledger), address2)
|
||||||
|
self.assertEqual(tx.outputs[0].is_change, False)
|
||||||
|
self.assertEqual(tx.outputs[1].is_change, True)
|
0
torba/tests/client_tests/unit/__init__.py
Normal file
0
torba/tests/client_tests/unit/__init__.py
Normal file
65
torba/tests/client_tests/unit/key_fixtures.py
Normal file
65
torba/tests/client_tests/unit/key_fixtures.py
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
expected_ids = [
|
||||||
|
b'948adae2a128c0bd1fa238117fd0d9690961f26e',
|
||||||
|
b'cd9f4f2adde7de0a53ab6d326bb6a62b489876dd',
|
||||||
|
b'c479e02a74a809ffecff60255d1c14f4081a197a',
|
||||||
|
b'4bab2fb2c424f31f170b15ec53c4a596db9d6710',
|
||||||
|
b'689cb7c621f57b7c398e7e04ed9a5098ab8389e9',
|
||||||
|
b'75116d6a689a0f9b56fe7cfec9cbbd0e16814288',
|
||||||
|
b'2439f0993fb298497dd7f317b9737c356f664a86',
|
||||||
|
b'32f1cb4799008cf5496bb8cafdaf59d5dabec6af',
|
||||||
|
b'fa29aa536353904e9cc813b0cf18efcc09e5ad13',
|
||||||
|
b'37df34002f34d7875428a2977df19be3f4f40a31',
|
||||||
|
b'8c8a72b5d2747a3e7e05ed85110188769d5656c3',
|
||||||
|
b'e5c8ef10c5bdaa79c9a237a096f50df4dcac27f0',
|
||||||
|
b'4d5270dc100fba85974665c20cd0f95d4822e8d1',
|
||||||
|
b'e76b07da0cdd59915475cd310599544b9744fa34',
|
||||||
|
b'6f009bccf8be99707161abb279d8ccf8fd953721',
|
||||||
|
b'f32f08b722cc8607c3f7f192b4d5f13a74c85785',
|
||||||
|
b'46f4430a5c91b9b799e9be6b47ac7a749d8d9f30',
|
||||||
|
b'ebbf9850abe0aae2d09e7e3ebd6b51f01282f39b',
|
||||||
|
b'5f6655438f8ddc6b2f6ea8197c8babaffc9f5c09',
|
||||||
|
b'e194e70ee8711b0ed765608121e4cceb551cdf28'
|
||||||
|
]
|
||||||
|
expected_privkeys = [
|
||||||
|
b'95557ee9a2bb7665e67e45246658b5c839f7dcd99b6ebc800eeebccd28bf134a',
|
||||||
|
b'689b6921f65647a8e4fc1497924730c92ad4ad183f10fac2bdee65cc8fb6dcf9',
|
||||||
|
b'977ee018b448c530327b7e927cc3645ca4cb152c5dd98e1bd917c52fd46fc80a',
|
||||||
|
b'3c7fb05b0ab4da8b292e895f574f8213cadfe81b84ded7423eab61c5f884c8ae',
|
||||||
|
b'b21fc7be1e69182827538683a48ac9d95684faf6c1c6deabb6e513d8c76afcc9',
|
||||||
|
b'a5021734dbbf1d090b15509ba00f2c04a3d5afc19939b4594ca0850d4190b923',
|
||||||
|
b'07dfe0aa94c1b948dc935be1f8179f3050353b46f3a3134e77c70e66208be72d',
|
||||||
|
b'c331b2fb82cd91120b0703ee312042a854a51a8d945aa9e70fb14d68b0366fe1',
|
||||||
|
b'3aa59ec4d8f1e7ce2775854b5e82433535b6e3503f9a8e7c4e60aac066d44718',
|
||||||
|
b'ccc8b4ca73b266b4a0c89a9d33c4ec7532b434c9294c26832355e5e2bee2e005',
|
||||||
|
b'280c074d8982e56d70c404072252c309694a6e5c05457a6abbe8fc225c2dfd52',
|
||||||
|
b'546cee26da713a3a64b2066d5e3a52b7c1d927396d1ba8a3d9f6e3e973398856',
|
||||||
|
b'7fbc4615d5e819eee22db440c5bcc4ff25bb046841c41a192003a6d9abfbafbf',
|
||||||
|
b'5b63f13011cab965feea3a41fac2d7a877aa710ab20e2a9a1708474e3c05c050',
|
||||||
|
b'394b36f528947557d317fd40a4adde5514c8745a5f64185421fa2c0c4a158938',
|
||||||
|
b'8f101c8f5290ae6c0dd76d210b7effacd7f12db18f3befab711f533bde084c76',
|
||||||
|
b'6637a656f897a66080fbe60027d32c3f4ebc0e3b5f96123a33f932a091b039c2',
|
||||||
|
b'2815aa6667c042a3a4565fb789890cd33e380d047ed712759d097d479df71051',
|
||||||
|
b'120e761c6382b07a9548650a20b3b9dd74b906093260fa6f92f790ba71f79e8d',
|
||||||
|
b'823c8a613ea539f730a968518993195174bf973ed75c734b6898022867165d7b'
|
||||||
|
]
|
||||||
|
expected_hardened_privkeys = [
|
||||||
|
b'abdba45b0459e7804beb68edb899e58a5c2636bf67d096711904001406afbd4c',
|
||||||
|
b'c9e804d4b8fdd99ef6ab2b0ca627a57f4283c28e11e9152ad9d3f863404d940e',
|
||||||
|
b'4cf87d68ae99711261f8cb8e1bde83b8703ff5d689ef70ce23106d1e6e8ed4bd',
|
||||||
|
b'dbf8d578c77f9bf62bb2ad40975e253af1e1d44d53abf84a22d2be29b9488f7f',
|
||||||
|
b'633bb840505521ffe39cb89a04fb8bff3298d6b64a5d8f170aca1e456d6f89b9',
|
||||||
|
b'92e80a38791bd8ba2105b9867fd58ac2cc4fb9853e18141b7fee1884bc5aae69',
|
||||||
|
b'd3663339af1386d05dd90ee20f627661ae87ddb1db0c2dc73fd8a4485930d0e7',
|
||||||
|
b'09a448303452d241b8a25670b36cc758975b97e88f62b6f25cd9084535e3c13a',
|
||||||
|
b'ee22eb77df05ff53e9c2ba797c1f2ebf97ec4cf5a99528adec94972674aeabed',
|
||||||
|
b'935facccb6120659c5b7c606a457c797e5a10ce4a728346e1a3a963251169651',
|
||||||
|
b'8ac9b4a48da1def375640ca03bc6711040dfd4eea7106d42bb4c2de83d7f595e',
|
||||||
|
b'51ecd3f7565c2b86d5782dbde2175ab26a7b896022564063fafe153588610be9',
|
||||||
|
b'04918252f6b6f51cd75957289b56a324b45cc085df80839137d740f9ada6c062',
|
||||||
|
b'2efbd0c839af971e3769c26938d776990ebf097989df4861535a7547a2701483',
|
||||||
|
b'85c6e31e6b27bd188291a910f4a7faba7fceb3e09df72884b10907ecc1491cd0',
|
||||||
|
b'05e245131885bebda993a31bb14ac98b794062a50af639ad22010aed1e533a54',
|
||||||
|
b'ddca42cf7db93f3a3f0723d5fee4c21bf60b7afac35d5c30eb34bd91b35cc609',
|
||||||
|
b'324a5c16030e0c3947e4dcd2b5057fd3a4d5bed96b23e3b476b2af0ab76369c9',
|
||||||
|
b'da63c41cdb398cdcd93e832f3e198528afbb4065821b026c143cec910d8362f0'
|
||||||
|
]
|
478
torba/tests/client_tests/unit/test_account.py
Normal file
478
torba/tests/client_tests/unit/test_account.py
Normal file
|
@ -0,0 +1,478 @@
|
||||||
|
from binascii import hexlify
|
||||||
|
|
||||||
|
from torba.testcase import AsyncioTestCase
|
||||||
|
|
||||||
|
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
|
||||||
|
from torba.client.baseaccount import HierarchicalDeterministic, SingleKey
|
||||||
|
from torba.client.wallet import Wallet
|
||||||
|
|
||||||
|
|
||||||
|
class TestHierarchicalDeterministicAccount(AsyncioTestCase):
|
||||||
|
|
||||||
|
async def asyncSetUp(self):
|
||||||
|
self.ledger = ledger_class({
|
||||||
|
'db': ledger_class.database_class(':memory:'),
|
||||||
|
'headers': ledger_class.headers_class(':memory:'),
|
||||||
|
})
|
||||||
|
await self.ledger.db.open()
|
||||||
|
self.account = self.ledger.account_class.generate(self.ledger, Wallet(), "torba")
|
||||||
|
|
||||||
|
async def asyncTearDown(self):
|
||||||
|
await self.ledger.db.close()
|
||||||
|
|
||||||
|
async def test_generate_account(self):
|
||||||
|
account = self.account
|
||||||
|
|
||||||
|
self.assertEqual(account.ledger, self.ledger)
|
||||||
|
self.assertIsNotNone(account.seed)
|
||||||
|
self.assertEqual(account.public_key.ledger, self.ledger)
|
||||||
|
self.assertEqual(account.private_key.public_key, account.public_key)
|
||||||
|
|
||||||
|
addresses = await account.receiving.get_addresses()
|
||||||
|
self.assertEqual(len(addresses), 0)
|
||||||
|
addresses = await account.change.get_addresses()
|
||||||
|
self.assertEqual(len(addresses), 0)
|
||||||
|
|
||||||
|
await account.ensure_address_gap()
|
||||||
|
|
||||||
|
addresses = await account.receiving.get_addresses()
|
||||||
|
self.assertEqual(len(addresses), 20)
|
||||||
|
addresses = await account.change.get_addresses()
|
||||||
|
self.assertEqual(len(addresses), 6)
|
||||||
|
|
||||||
|
addresses = await account.get_addresses()
|
||||||
|
self.assertEqual(len(addresses), 26)
|
||||||
|
|
||||||
|
async def test_generate_keys_over_batch_threshold_saves_it_properly(self):
|
||||||
|
async with self.account.receiving.address_generator_lock:
|
||||||
|
await self.account.receiving._generate_keys(0, 200)
|
||||||
|
records = await self.account.receiving.get_address_records()
|
||||||
|
self.assertEqual(201, len(records))
|
||||||
|
|
||||||
|
async def test_ensure_address_gap(self):
|
||||||
|
account = self.account
|
||||||
|
|
||||||
|
self.assertIsInstance(account.receiving, HierarchicalDeterministic)
|
||||||
|
|
||||||
|
async with account.receiving.address_generator_lock:
|
||||||
|
await account.receiving._generate_keys(4, 7)
|
||||||
|
await account.receiving._generate_keys(0, 3)
|
||||||
|
await account.receiving._generate_keys(8, 11)
|
||||||
|
records = await account.receiving.get_address_records()
|
||||||
|
self.assertEqual(
|
||||||
|
[r['position'] for r in records],
|
||||||
|
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
|
||||||
|
)
|
||||||
|
|
||||||
|
# we have 12, but default gap is 20
|
||||||
|
new_keys = await account.receiving.ensure_address_gap()
|
||||||
|
self.assertEqual(len(new_keys), 8)
|
||||||
|
records = await account.receiving.get_address_records()
|
||||||
|
self.assertEqual(
|
||||||
|
[r['position'] for r in records],
|
||||||
|
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
|
||||||
|
)
|
||||||
|
|
||||||
|
# case #1: no new addresses needed
|
||||||
|
empty = await account.receiving.ensure_address_gap()
|
||||||
|
self.assertEqual(len(empty), 0)
|
||||||
|
|
||||||
|
# case #2: only one new addressed needed
|
||||||
|
records = await account.receiving.get_address_records()
|
||||||
|
await self.ledger.db.set_address_history(records[0]['address'], 'a:1:')
|
||||||
|
new_keys = await account.receiving.ensure_address_gap()
|
||||||
|
self.assertEqual(len(new_keys), 1)
|
||||||
|
|
||||||
|
# case #3: 20 addresses needed
|
||||||
|
await self.ledger.db.set_address_history(new_keys[0], 'a:1:')
|
||||||
|
new_keys = await account.receiving.ensure_address_gap()
|
||||||
|
self.assertEqual(len(new_keys), 20)
|
||||||
|
|
||||||
|
async def test_get_or_create_usable_address(self):
|
||||||
|
account = self.account
|
||||||
|
|
||||||
|
keys = await account.receiving.get_addresses()
|
||||||
|
self.assertEqual(len(keys), 0)
|
||||||
|
|
||||||
|
address = await account.receiving.get_or_create_usable_address()
|
||||||
|
self.assertIsNotNone(address)
|
||||||
|
|
||||||
|
keys = await account.receiving.get_addresses()
|
||||||
|
self.assertEqual(len(keys), 20)
|
||||||
|
|
||||||
|
async def test_generate_account_from_seed(self):
|
||||||
|
account = self.ledger.account_class.from_dict(
|
||||||
|
self.ledger, Wallet(), {
|
||||||
|
"seed": "carbon smart garage balance margin twelve chest sword "
|
||||||
|
"toast envelope bottom stomach absent",
|
||||||
|
"address_generator": {
|
||||||
|
'name': 'deterministic-chain',
|
||||||
|
'receiving': {'gap': 3, 'maximum_uses_per_address': 1},
|
||||||
|
'change': {'gap': 2, 'maximum_uses_per_address': 1}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
account.private_key.extended_key_string(),
|
||||||
|
'xprv9s21ZrQH143K3TsAz5efNV8K93g3Ms3FXcjaWB9fVUsMwAoE3ZT4vYymkp5BxK'
|
||||||
|
'Kfnpz8J6sHDFriX1SnpvjNkzcks8XBnxjGLS83BTyfpna'
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
account.public_key.extended_key_string(),
|
||||||
|
'xpub661MyMwAqRbcFwwe67Bfjd53h5WXmKm6tqfBJZZH3pQLoy8Nb6mKUMJFc7UbpV'
|
||||||
|
'NzmwFPN2evn3YHnig1pkKVYcvCV8owTd2yAcEkJfCX53g'
|
||||||
|
)
|
||||||
|
address = await account.receiving.ensure_address_gap()
|
||||||
|
self.assertEqual(address[0], '1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J')
|
||||||
|
|
||||||
|
private_key = await self.ledger.get_private_key_for_address('1CDLuMfwmPqJiNk5C2Bvew6tpgjAGgUk8J')
|
||||||
|
self.assertEqual(
|
||||||
|
private_key.extended_key_string(),
|
||||||
|
'xprv9xV7rhbg6M4yWrdTeLorz3Q1GrQb4aQzzGWboP3du7W7UUztzNTUrEYTnDfz7o'
|
||||||
|
'ptBygDxXYRppyiuenJpoBTgYP2C26E1Ah5FEALM24CsWi'
|
||||||
|
)
|
||||||
|
|
||||||
|
invalid_key = await self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
|
||||||
|
self.assertIsNone(invalid_key)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
hexlify(private_key.wif()),
|
||||||
|
b'1c01ae1e4c7d89e39f6d3aa7792c097a30ca7d40be249b6de52c81ec8cf9aab48b01'
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_load_and_save_account(self):
|
||||||
|
account_data = {
|
||||||
|
'name': 'My Account',
|
||||||
|
'modified_on': 123.456,
|
||||||
|
'seed':
|
||||||
|
"carbon smart garage balance margin twelve chest sword toast envelope bottom stomac"
|
||||||
|
"h absent",
|
||||||
|
'encrypted': False,
|
||||||
|
'private_key':
|
||||||
|
'xprv9s21ZrQH143K3TsAz5efNV8K93g3Ms3FXcjaWB9fVUsMwAoE3ZT4vYymkp'
|
||||||
|
'5BxKKfnpz8J6sHDFriX1SnpvjNkzcks8XBnxjGLS83BTyfpna',
|
||||||
|
'public_key':
|
||||||
|
'xpub661MyMwAqRbcFwwe67Bfjd53h5WXmKm6tqfBJZZH3pQLoy8Nb6mKUMJFc7'
|
||||||
|
'UbpVNzmwFPN2evn3YHnig1pkKVYcvCV8owTd2yAcEkJfCX53g',
|
||||||
|
'address_generator': {
|
||||||
|
'name': 'deterministic-chain',
|
||||||
|
'receiving': {'gap': 5, 'maximum_uses_per_address': 2},
|
||||||
|
'change': {'gap': 5, 'maximum_uses_per_address': 2}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), account_data)
|
||||||
|
|
||||||
|
await account.ensure_address_gap()
|
||||||
|
|
||||||
|
addresses = await account.receiving.get_addresses()
|
||||||
|
self.assertEqual(len(addresses), 5)
|
||||||
|
addresses = await account.change.get_addresses()
|
||||||
|
self.assertEqual(len(addresses), 5)
|
||||||
|
|
||||||
|
self.maxDiff = None
|
||||||
|
account_data['ledger'] = 'btc_mainnet'
|
||||||
|
self.assertDictEqual(account_data, account.to_dict())
|
||||||
|
|
||||||
|
def test_apply_diff(self):
|
||||||
|
account_data = {
|
||||||
|
'name': 'My Account',
|
||||||
|
'modified_on': 123.456,
|
||||||
|
'seed':
|
||||||
|
"carbon smart garage balance margin twelve chest sword toast envelope bottom stomac"
|
||||||
|
"h absent",
|
||||||
|
'encrypted': False,
|
||||||
|
'private_key':
|
||||||
|
'xprv9s21ZrQH143K3TsAz5efNV8K93g3Ms3FXcjaWB9fVUsMwAoE3ZT4vYymkp'
|
||||||
|
'5BxKKfnpz8J6sHDFriX1SnpvjNkzcks8XBnxjGLS83BTyfpna',
|
||||||
|
'public_key':
|
||||||
|
'xpub661MyMwAqRbcFwwe67Bfjd53h5WXmKm6tqfBJZZH3pQLoy8Nb6mKUMJFc7'
|
||||||
|
'UbpVNzmwFPN2evn3YHnig1pkKVYcvCV8owTd2yAcEkJfCX53g',
|
||||||
|
'address_generator': {
|
||||||
|
'name': 'deterministic-chain',
|
||||||
|
'receiving': {'gap': 5, 'maximum_uses_per_address': 2},
|
||||||
|
'change': {'gap': 5, 'maximum_uses_per_address': 2}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), account_data)
|
||||||
|
|
||||||
|
self.assertEqual(account.name, 'My Account')
|
||||||
|
self.assertEqual(account.modified_on, 123.456)
|
||||||
|
self.assertEqual(account.change.gap, 5)
|
||||||
|
self.assertEqual(account.change.maximum_uses_per_address, 2)
|
||||||
|
self.assertEqual(account.receiving.gap, 5)
|
||||||
|
self.assertEqual(account.receiving.maximum_uses_per_address, 2)
|
||||||
|
|
||||||
|
account_data['name'] = 'Changed Name'
|
||||||
|
account_data['address_generator']['change']['gap'] = 6
|
||||||
|
account_data['address_generator']['change']['maximum_uses_per_address'] = 7
|
||||||
|
account_data['address_generator']['receiving']['gap'] = 8
|
||||||
|
account_data['address_generator']['receiving']['maximum_uses_per_address'] = 9
|
||||||
|
|
||||||
|
account.apply(account_data)
|
||||||
|
# no change because modified_on is not newer
|
||||||
|
self.assertEqual(account.name, 'My Account')
|
||||||
|
|
||||||
|
account_data['modified_on'] = 200.00
|
||||||
|
|
||||||
|
account.apply(account_data)
|
||||||
|
self.assertEqual(account.name, 'Changed Name')
|
||||||
|
self.assertEqual(account.change.gap, 6)
|
||||||
|
self.assertEqual(account.change.maximum_uses_per_address, 7)
|
||||||
|
self.assertEqual(account.receiving.gap, 8)
|
||||||
|
self.assertEqual(account.receiving.maximum_uses_per_address, 9)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSingleKeyAccount(AsyncioTestCase):
|
||||||
|
|
||||||
|
async def asyncSetUp(self):
|
||||||
|
self.ledger = ledger_class({
|
||||||
|
'db': ledger_class.database_class(':memory:'),
|
||||||
|
'headers': ledger_class.headers_class(':memory:'),
|
||||||
|
})
|
||||||
|
await self.ledger.db.open()
|
||||||
|
self.account = self.ledger.account_class.generate(
|
||||||
|
self.ledger, Wallet(), "torba", {'name': 'single-address'})
|
||||||
|
|
||||||
|
async def asyncTearDown(self):
|
||||||
|
await self.ledger.db.close()
|
||||||
|
|
||||||
|
async def test_generate_account(self):
|
||||||
|
account = self.account
|
||||||
|
|
||||||
|
self.assertEqual(account.ledger, self.ledger)
|
||||||
|
self.assertIsNotNone(account.seed)
|
||||||
|
self.assertEqual(account.public_key.ledger, self.ledger)
|
||||||
|
self.assertEqual(account.private_key.public_key, account.public_key)
|
||||||
|
|
||||||
|
addresses = await account.receiving.get_addresses()
|
||||||
|
self.assertEqual(len(addresses), 0)
|
||||||
|
addresses = await account.change.get_addresses()
|
||||||
|
self.assertEqual(len(addresses), 0)
|
||||||
|
|
||||||
|
await account.ensure_address_gap()
|
||||||
|
|
||||||
|
addresses = await account.receiving.get_addresses()
|
||||||
|
self.assertEqual(len(addresses), 1)
|
||||||
|
self.assertEqual(addresses[0], account.public_key.address)
|
||||||
|
addresses = await account.change.get_addresses()
|
||||||
|
self.assertEqual(len(addresses), 1)
|
||||||
|
self.assertEqual(addresses[0], account.public_key.address)
|
||||||
|
|
||||||
|
addresses = await account.get_addresses()
|
||||||
|
self.assertEqual(len(addresses), 1)
|
||||||
|
self.assertEqual(addresses[0], account.public_key.address)
|
||||||
|
|
||||||
|
async def test_ensure_address_gap(self):
|
||||||
|
account = self.account
|
||||||
|
|
||||||
|
self.assertIsInstance(account.receiving, SingleKey)
|
||||||
|
addresses = await account.receiving.get_addresses()
|
||||||
|
self.assertEqual(addresses, [])
|
||||||
|
|
||||||
|
# we have 12, but default gap is 20
|
||||||
|
new_keys = await account.receiving.ensure_address_gap()
|
||||||
|
self.assertEqual(len(new_keys), 1)
|
||||||
|
self.assertEqual(new_keys[0], account.public_key.address)
|
||||||
|
records = await account.receiving.get_address_records()
|
||||||
|
self.assertEqual(records, [{
|
||||||
|
'position': 0, 'chain': 0,
|
||||||
|
'account': account.public_key.address,
|
||||||
|
'address': account.public_key.address,
|
||||||
|
'used_times': 0
|
||||||
|
}])
|
||||||
|
|
||||||
|
# case #1: no new addresses needed
|
||||||
|
empty = await account.receiving.ensure_address_gap()
|
||||||
|
self.assertEqual(len(empty), 0)
|
||||||
|
|
||||||
|
# case #2: after use, still no new address needed
|
||||||
|
records = await account.receiving.get_address_records()
|
||||||
|
await self.ledger.db.set_address_history(records[0]['address'], 'a:1:')
|
||||||
|
empty = await account.receiving.ensure_address_gap()
|
||||||
|
self.assertEqual(len(empty), 0)
|
||||||
|
|
||||||
|
async def test_get_or_create_usable_address(self):
|
||||||
|
account = self.account
|
||||||
|
|
||||||
|
addresses = await account.receiving.get_addresses()
|
||||||
|
self.assertEqual(len(addresses), 0)
|
||||||
|
|
||||||
|
address1 = await account.receiving.get_or_create_usable_address()
|
||||||
|
self.assertIsNotNone(address1)
|
||||||
|
|
||||||
|
await self.ledger.db.set_address_history(address1, 'a:1:b:2:c:3:')
|
||||||
|
records = await account.receiving.get_address_records()
|
||||||
|
self.assertEqual(records[0]['used_times'], 3)
|
||||||
|
|
||||||
|
address2 = await account.receiving.get_or_create_usable_address()
|
||||||
|
self.assertEqual(address1, address2)
|
||||||
|
|
||||||
|
keys = await account.receiving.get_addresses()
|
||||||
|
self.assertEqual(len(keys), 1)
|
||||||
|
|
||||||
|
async def test_generate_account_from_seed(self):
|
||||||
|
account = self.ledger.account_class.from_dict(
|
||||||
|
self.ledger, Wallet(), {
|
||||||
|
"seed":
|
||||||
|
"carbon smart garage balance margin twelve chest sword toas"
|
||||||
|
"t envelope bottom stomach absent",
|
||||||
|
'address_generator': {'name': 'single-address'}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
account.private_key.extended_key_string(),
|
||||||
|
'xprv9s21ZrQH143K3TsAz5efNV8K93g3Ms3FXcjaWB9fVUsMwAoE3ZT4vYymkp'
|
||||||
|
'5BxKKfnpz8J6sHDFriX1SnpvjNkzcks8XBnxjGLS83BTyfpna',
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
account.public_key.extended_key_string(),
|
||||||
|
'xpub661MyMwAqRbcFwwe67Bfjd53h5WXmKm6tqfBJZZH3pQLoy8Nb6mKUMJFc7'
|
||||||
|
'UbpVNzmwFPN2evn3YHnig1pkKVYcvCV8owTd2yAcEkJfCX53g',
|
||||||
|
)
|
||||||
|
address = await account.receiving.ensure_address_gap()
|
||||||
|
self.assertEqual(address[0], account.public_key.address)
|
||||||
|
|
||||||
|
private_key = await self.ledger.get_private_key_for_address(address[0])
|
||||||
|
self.assertEqual(
|
||||||
|
private_key.extended_key_string(),
|
||||||
|
'xprv9s21ZrQH143K3TsAz5efNV8K93g3Ms3FXcjaWB9fVUsMwAoE3ZT4vYymkp'
|
||||||
|
'5BxKKfnpz8J6sHDFriX1SnpvjNkzcks8XBnxjGLS83BTyfpna',
|
||||||
|
)
|
||||||
|
|
||||||
|
invalid_key = await self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
|
||||||
|
self.assertIsNone(invalid_key)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
hexlify(private_key.wif()),
|
||||||
|
b'1c92caa0ef99bfd5e2ceb73b66da8cd726a9370be8c368d448a322f3c5b23aaab901'
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_load_and_save_account(self):
|
||||||
|
account_data = {
|
||||||
|
'name': 'My Account',
|
||||||
|
'modified_on': 123.456,
|
||||||
|
'seed':
|
||||||
|
"carbon smart garage balance margin twelve chest sword toast envelope bottom stomac"
|
||||||
|
"h absent",
|
||||||
|
'encrypted': False,
|
||||||
|
'private_key':
|
||||||
|
'xprv9s21ZrQH143K3TsAz5efNV8K93g3Ms3FXcjaWB9fVUsMwAoE3ZT4vYymkp'
|
||||||
|
'5BxKKfnpz8J6sHDFriX1SnpvjNkzcks8XBnxjGLS83BTyfpna',
|
||||||
|
'public_key':
|
||||||
|
'xpub661MyMwAqRbcFwwe67Bfjd53h5WXmKm6tqfBJZZH3pQLoy8Nb6mKUMJFc7'
|
||||||
|
'UbpVNzmwFPN2evn3YHnig1pkKVYcvCV8owTd2yAcEkJfCX53g',
|
||||||
|
'address_generator': {'name': 'single-address'}
|
||||||
|
}
|
||||||
|
|
||||||
|
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), account_data)
|
||||||
|
|
||||||
|
await account.ensure_address_gap()
|
||||||
|
|
||||||
|
addresses = await account.receiving.get_addresses()
|
||||||
|
self.assertEqual(len(addresses), 1)
|
||||||
|
addresses = await account.change.get_addresses()
|
||||||
|
self.assertEqual(len(addresses), 1)
|
||||||
|
|
||||||
|
self.maxDiff = None
|
||||||
|
account_data['ledger'] = 'btc_mainnet'
|
||||||
|
self.assertDictEqual(account_data, account.to_dict())
|
||||||
|
|
||||||
|
|
||||||
|
class AccountEncryptionTests(AsyncioTestCase):
|
||||||
|
password = "password"
|
||||||
|
init_vector = b'0000000000000000'
|
||||||
|
unencrypted_account = {
|
||||||
|
'name': 'My Account',
|
||||||
|
'seed':
|
||||||
|
"carbon smart garage balance margin twelve chest sword toast envelope bottom stomac"
|
||||||
|
"h absent",
|
||||||
|
'encrypted': False,
|
||||||
|
'private_key':
|
||||||
|
'xprv9s21ZrQH143K3TsAz5efNV8K93g3Ms3FXcjaWB9fVUsMwAoE3ZT4vYymkp'
|
||||||
|
'5BxKKfnpz8J6sHDFriX1SnpvjNkzcks8XBnxjGLS83BTyfpna',
|
||||||
|
'public_key':
|
||||||
|
'xpub661MyMwAqRbcFwwe67Bfjd53h5WXmKm6tqfBJZZH3pQLoy8Nb6mKUMJFc7'
|
||||||
|
'UbpVNzmwFPN2evn3YHnig1pkKVYcvCV8owTd2yAcEkJfCX53g',
|
||||||
|
'address_generator': {'name': 'single-address'}
|
||||||
|
}
|
||||||
|
encrypted_account = {
|
||||||
|
'name': 'My Account',
|
||||||
|
'seed':
|
||||||
|
"MDAwMDAwMDAwMDAwMDAwMJ4e4W4pE6nQtPiD6MujNIQ7aFPhUBl63GwPziAgGN"
|
||||||
|
"MBTMoaSjZfyyvw7ELMCqAYTWJ61aV7K4lmd2hR11g9dpdnnpCb9f9j3zLZHRv7+"
|
||||||
|
"bIkZ//trah9AIkmrc/ZvNkC0Q==",
|
||||||
|
'encrypted': True,
|
||||||
|
'private_key':
|
||||||
|
'MDAwMDAwMDAwMDAwMDAwMLkWikOLScA/ZxlFSGU7dl//7Q/1gS9h7vqQyrd8DX+'
|
||||||
|
'jwcp7SwlJ1mkMwuraUaWLq9/LxiaGmqJBUZ50p77YVZbDycaCN1unBr1/i1q6RP'
|
||||||
|
'Ob2MNCaG8nyjxZhQai+V/2JmJ+UnFMp3nHany7F8/Hr0g=',
|
||||||
|
'public_key':
|
||||||
|
'xpub661MyMwAqRbcFwwe67Bfjd53h5WXmKm6tqfBJZZH3pQLoy8Nb6mKUMJFc7'
|
||||||
|
'UbpVNzmwFPN2evn3YHnig1pkKVYcvCV8owTd2yAcEkJfCX53g',
|
||||||
|
'address_generator': {'name': 'single-address'}
|
||||||
|
}
|
||||||
|
|
||||||
|
async def asyncSetUp(self):
|
||||||
|
self.ledger = ledger_class({
|
||||||
|
'db': ledger_class.database_class(':memory:'),
|
||||||
|
'headers': ledger_class.headers_class(':memory:'),
|
||||||
|
})
|
||||||
|
|
||||||
|
def test_encrypt_wallet(self):
|
||||||
|
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), self.unencrypted_account)
|
||||||
|
account.private_key_encryption_init_vector = self.init_vector
|
||||||
|
account.seed_encryption_init_vector = self.init_vector
|
||||||
|
|
||||||
|
self.assertFalse(account.serialize_encrypted)
|
||||||
|
self.assertFalse(account.encrypted)
|
||||||
|
self.assertIsNotNone(account.private_key)
|
||||||
|
account.encrypt(self.password)
|
||||||
|
self.assertFalse(account.serialize_encrypted)
|
||||||
|
self.assertTrue(account.encrypted)
|
||||||
|
self.assertEqual(account.seed, self.encrypted_account['seed'])
|
||||||
|
self.assertEqual(account.private_key_string, self.encrypted_account['private_key'])
|
||||||
|
self.assertIsNone(account.private_key)
|
||||||
|
|
||||||
|
self.assertEqual(account.to_dict()['seed'], self.encrypted_account['seed'])
|
||||||
|
self.assertEqual(account.to_dict()['private_key'], self.encrypted_account['private_key'])
|
||||||
|
|
||||||
|
account.serialize_encrypted = True
|
||||||
|
account.decrypt(self.password)
|
||||||
|
self.assertEqual(account.private_key_encryption_init_vector, self.init_vector)
|
||||||
|
self.assertEqual(account.seed_encryption_init_vector, self.init_vector)
|
||||||
|
|
||||||
|
self.assertEqual(account.seed, self.unencrypted_account['seed'])
|
||||||
|
self.assertEqual(account.private_key.extended_key_string(), self.unencrypted_account['private_key'])
|
||||||
|
|
||||||
|
self.assertEqual(account.to_dict()['seed'], self.encrypted_account['seed'])
|
||||||
|
self.assertEqual(account.to_dict()['private_key'], self.encrypted_account['private_key'])
|
||||||
|
|
||||||
|
self.assertFalse(account.encrypted)
|
||||||
|
self.assertTrue(account.serialize_encrypted)
|
||||||
|
|
||||||
|
account.serialize_encrypted = False
|
||||||
|
self.assertEqual(account.to_dict()['seed'], self.unencrypted_account['seed'])
|
||||||
|
self.assertEqual(account.to_dict()['private_key'], self.unencrypted_account['private_key'])
|
||||||
|
|
||||||
|
def test_decrypt_wallet(self):
|
||||||
|
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), self.encrypted_account)
|
||||||
|
|
||||||
|
self.assertTrue(account.encrypted)
|
||||||
|
self.assertTrue(account.serialize_encrypted)
|
||||||
|
account.decrypt(self.password)
|
||||||
|
self.assertEqual(account.private_key_encryption_init_vector, self.init_vector)
|
||||||
|
self.assertEqual(account.seed_encryption_init_vector, self.init_vector)
|
||||||
|
|
||||||
|
self.assertFalse(account.encrypted)
|
||||||
|
self.assertTrue(account.serialize_encrypted)
|
||||||
|
|
||||||
|
self.assertEqual(account.seed, self.unencrypted_account['seed'])
|
||||||
|
self.assertEqual(account.private_key.extended_key_string(), self.unencrypted_account['private_key'])
|
||||||
|
|
||||||
|
self.assertEqual(account.to_dict()['seed'], self.encrypted_account['seed'])
|
||||||
|
self.assertEqual(account.to_dict()['private_key'], self.encrypted_account['private_key'])
|
||||||
|
|
||||||
|
account.serialize_encrypted = False
|
||||||
|
self.assertEqual(account.to_dict()['seed'], self.unencrypted_account['seed'])
|
||||||
|
self.assertEqual(account.to_dict()['private_key'], self.unencrypted_account['private_key'])
|
23
torba/tests/client_tests/unit/test_bcd_data_stream.py
Normal file
23
torba/tests/client_tests/unit/test_bcd_data_stream.py
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from torba.client.bcd_data_stream import BCDataStream
|
||||||
|
|
||||||
|
|
||||||
|
class TestBCDataStream(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_write_read(self):
|
||||||
|
s = BCDataStream()
|
||||||
|
s.write_string(b'a'*252)
|
||||||
|
s.write_string(b'b'*254)
|
||||||
|
s.write_string(b'c'*(0xFFFF + 1))
|
||||||
|
# s.write_string(b'd'*(0xFFFFFFFF + 1))
|
||||||
|
s.write_boolean(True)
|
||||||
|
s.write_boolean(False)
|
||||||
|
s.reset()
|
||||||
|
|
||||||
|
self.assertEqual(s.read_string(), b'a'*252)
|
||||||
|
self.assertEqual(s.read_string(), b'b'*254)
|
||||||
|
self.assertEqual(s.read_string(), b'c'*(0xFFFF + 1))
|
||||||
|
# self.assertEqual(s.read_string(), b'd'*(0xFFFFFFFF + 1))
|
||||||
|
self.assertEqual(s.read_boolean(), True)
|
||||||
|
self.assertEqual(s.read_boolean(), False)
|
104
torba/tests/client_tests/unit/test_bip32.py
Normal file
104
torba/tests/client_tests/unit/test_bip32.py
Normal file
|
@ -0,0 +1,104 @@
|
||||||
|
from binascii import unhexlify, hexlify
|
||||||
|
|
||||||
|
from torba.testcase import AsyncioTestCase
|
||||||
|
|
||||||
|
from client_tests.unit.key_fixtures import expected_ids, expected_privkeys, expected_hardened_privkeys
|
||||||
|
from torba.client.bip32 import PubKey, PrivateKey, from_extended_key_string
|
||||||
|
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
|
||||||
|
|
||||||
|
|
||||||
|
class BIP32Tests(AsyncioTestCase):
|
||||||
|
|
||||||
|
def test_pubkey_validation(self):
|
||||||
|
with self.assertRaisesRegex(TypeError, 'chain code must be raw bytes'):
|
||||||
|
PubKey(None, None, 1, None, None, None)
|
||||||
|
with self.assertRaisesRegex(ValueError, 'invalid chain code'):
|
||||||
|
PubKey(None, None, b'abcd', None, None, None)
|
||||||
|
with self.assertRaisesRegex(ValueError, 'invalid child number'):
|
||||||
|
PubKey(None, None, b'abcd'*8, -1, None, None)
|
||||||
|
with self.assertRaisesRegex(ValueError, 'invalid depth'):
|
||||||
|
PubKey(None, None, b'abcd'*8, 0, 256, None)
|
||||||
|
with self.assertRaisesRegex(TypeError, 'pubkey must be raw bytes'):
|
||||||
|
PubKey(None, None, b'abcd'*8, 0, 255, None)
|
||||||
|
with self.assertRaisesRegex(ValueError, 'pubkey must be 33 bytes'):
|
||||||
|
PubKey(None, b'abcd', b'abcd'*8, 0, 255, None)
|
||||||
|
with self.assertRaisesRegex(ValueError, 'invalid pubkey prefix byte'):
|
||||||
|
PubKey(
|
||||||
|
None,
|
||||||
|
unhexlify('33d1a3dc8155673bc1e2214fa493ccc82d57961b66054af9b6b653ac28eeef3ffe'),
|
||||||
|
b'abcd'*8, 0, 255, None
|
||||||
|
)
|
||||||
|
pubkey = PubKey( # success
|
||||||
|
None,
|
||||||
|
unhexlify('03d1a3dc8155673bc1e2214fa493ccc82d57961b66054af9b6b653ac28eeef3ffe'),
|
||||||
|
b'abcd'*8, 0, 1, None
|
||||||
|
)
|
||||||
|
with self.assertRaisesRegex(ValueError, 'invalid BIP32 public key child number'):
|
||||||
|
pubkey.child(-1)
|
||||||
|
for i in range(20):
|
||||||
|
new_key = pubkey.child(i)
|
||||||
|
self.assertIsInstance(new_key, PubKey)
|
||||||
|
self.assertEqual(hexlify(new_key.identifier()), expected_ids[i])
|
||||||
|
|
||||||
|
async def test_private_key_validation(self):
|
||||||
|
with self.assertRaisesRegex(TypeError, 'private key must be raw bytes'):
|
||||||
|
PrivateKey(None, None, b'abcd'*8, 0, 255)
|
||||||
|
with self.assertRaisesRegex(ValueError, 'private key must be 32 bytes'):
|
||||||
|
PrivateKey(None, b'abcd', b'abcd'*8, 0, 255)
|
||||||
|
private_key = PrivateKey(
|
||||||
|
ledger_class({
|
||||||
|
'db': ledger_class.database_class(':memory:'),
|
||||||
|
'headers': ledger_class.headers_class(':memory:'),
|
||||||
|
}),
|
||||||
|
unhexlify('2423f3dc6087d9683f73a684935abc0ccd8bc26370588f56653128c6a6f0bf7c'),
|
||||||
|
b'abcd'*8, 0, 1
|
||||||
|
)
|
||||||
|
ec_point = private_key.ec_point()
|
||||||
|
self.assertEqual(
|
||||||
|
ec_point[0], 30487144161998778625547553412379759661411261804838752332906558028921886299019
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
ec_point[1], 86198965946979720220333266272536217633917099472454294641561154971209433250106
|
||||||
|
)
|
||||||
|
self.assertEqual(private_key.address(), '1GVM5dEhThbiyCZ9gqBZBv6p9whga7MTXo' )
|
||||||
|
with self.assertRaisesRegex(ValueError, 'invalid BIP32 private key child number'):
|
||||||
|
private_key.child(-1)
|
||||||
|
self.assertIsInstance(private_key.child(PrivateKey.HARDENED), PrivateKey)
|
||||||
|
|
||||||
|
async def test_private_key_derivation(self):
|
||||||
|
private_key = PrivateKey(
|
||||||
|
ledger_class({
|
||||||
|
'db': ledger_class.database_class(':memory:'),
|
||||||
|
'headers': ledger_class.headers_class(':memory:'),
|
||||||
|
}),
|
||||||
|
unhexlify('2423f3dc6087d9683f73a684935abc0ccd8bc26370588f56653128c6a6f0bf7c'),
|
||||||
|
b'abcd'*8, 0, 1
|
||||||
|
)
|
||||||
|
for i in range(20):
|
||||||
|
new_privkey = private_key.child(i)
|
||||||
|
self.assertIsInstance(new_privkey, PrivateKey)
|
||||||
|
self.assertEqual(hexlify(new_privkey.private_key_bytes), expected_privkeys[i])
|
||||||
|
for i in range(PrivateKey.HARDENED + 1, private_key.HARDENED + 20):
|
||||||
|
new_privkey = private_key.child(i)
|
||||||
|
self.assertIsInstance(new_privkey, PrivateKey)
|
||||||
|
self.assertEqual(hexlify(new_privkey.private_key_bytes), expected_hardened_privkeys[i - 1 - PrivateKey.HARDENED])
|
||||||
|
|
||||||
|
async def test_from_extended_keys(self):
|
||||||
|
ledger = ledger_class({
|
||||||
|
'db': ledger_class.database_class(':memory:'),
|
||||||
|
'headers': ledger_class.headers_class(':memory:'),
|
||||||
|
})
|
||||||
|
self.assertIsInstance(
|
||||||
|
from_extended_key_string(
|
||||||
|
ledger,
|
||||||
|
'xprv9s21ZrQH143K2dyhK7SevfRG72bYDRNv25yKPWWm6dqApNxm1Zb1m5gGcBWYfbsPjTr2v5joit8Af2Zp5P'
|
||||||
|
'6yz3jMbycrLrRMpeAJxR8qDg8',
|
||||||
|
), PrivateKey
|
||||||
|
)
|
||||||
|
self.assertIsInstance(
|
||||||
|
from_extended_key_string(
|
||||||
|
ledger,
|
||||||
|
'xpub661MyMwAqRbcF84AR8yfHoMzf4S2ct6mPJtvBtvNeyN9hBHuZ6uGJszkTSn5fQUCdz3XU17eBzFeAUwV6f'
|
||||||
|
'iW44g14WF52fYC5J483wqQ5ZP',
|
||||||
|
), PubKey
|
||||||
|
)
|
185
torba/tests/client_tests/unit/test_coinselection.py
Normal file
185
torba/tests/client_tests/unit/test_coinselection.py
Normal file
|
@ -0,0 +1,185 @@
|
||||||
|
from types import GeneratorType
|
||||||
|
|
||||||
|
from torba.testcase import AsyncioTestCase
|
||||||
|
|
||||||
|
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
|
||||||
|
from torba.client.coinselection import CoinSelector, MAXIMUM_TRIES
|
||||||
|
from torba.client.constants import CENT
|
||||||
|
|
||||||
|
from client_tests.unit.test_transaction import get_output as utxo
|
||||||
|
|
||||||
|
|
||||||
|
NULL_HASH = b'\x00'*32
|
||||||
|
|
||||||
|
|
||||||
|
def search(*args, **kwargs):
|
||||||
|
selection = CoinSelector(*args, **kwargs).branch_and_bound()
|
||||||
|
return [o.txo.amount for o in selection] if selection else selection
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSelectionTestCase(AsyncioTestCase):
|
||||||
|
|
||||||
|
async def asyncSetUp(self):
|
||||||
|
self.ledger = ledger_class({
|
||||||
|
'db': ledger_class.database_class(':memory:'),
|
||||||
|
'headers': ledger_class.headers_class(':memory:'),
|
||||||
|
})
|
||||||
|
await self.ledger.db.open()
|
||||||
|
|
||||||
|
async def asyncTearDown(self):
|
||||||
|
await self.ledger.db.close()
|
||||||
|
|
||||||
|
def estimates(self, *args):
|
||||||
|
txos = args[0] if isinstance(args[0], (GeneratorType, list)) else args
|
||||||
|
return [txo.get_estimator(self.ledger) for txo in txos]
|
||||||
|
|
||||||
|
|
||||||
|
class TestCoinSelectionTests(BaseSelectionTestCase):
|
||||||
|
|
||||||
|
def test_empty_coins(self):
|
||||||
|
self.assertEqual(CoinSelector([], 0, 0).select(), [])
|
||||||
|
|
||||||
|
def test_skip_binary_search_if_total_not_enough(self):
|
||||||
|
fee = utxo(CENT).get_estimator(self.ledger).fee
|
||||||
|
big_pool = self.estimates(utxo(CENT+fee) for _ in range(100))
|
||||||
|
selector = CoinSelector(big_pool, 101 * CENT, 0)
|
||||||
|
self.assertEqual(selector.select(), [])
|
||||||
|
self.assertEqual(selector.tries, 0) # Never tried.
|
||||||
|
# check happy path
|
||||||
|
selector = CoinSelector(big_pool, 100 * CENT, 0)
|
||||||
|
self.assertEqual(len(selector.select()), 100)
|
||||||
|
self.assertEqual(selector.tries, 201)
|
||||||
|
|
||||||
|
def test_exact_match(self):
|
||||||
|
fee = utxo(CENT).get_estimator(self.ledger).fee
|
||||||
|
utxo_pool = self.estimates(
|
||||||
|
utxo(CENT + fee),
|
||||||
|
utxo(CENT),
|
||||||
|
utxo(CENT - fee)
|
||||||
|
)
|
||||||
|
selector = CoinSelector(utxo_pool, CENT, 0)
|
||||||
|
match = selector.select()
|
||||||
|
self.assertEqual([CENT + fee], [c.txo.amount for c in match])
|
||||||
|
self.assertTrue(selector.exact_match)
|
||||||
|
|
||||||
|
def test_random_draw(self):
|
||||||
|
utxo_pool = self.estimates(
|
||||||
|
utxo(2 * CENT),
|
||||||
|
utxo(3 * CENT),
|
||||||
|
utxo(4 * CENT)
|
||||||
|
)
|
||||||
|
selector = CoinSelector(utxo_pool, CENT, 0, '\x00')
|
||||||
|
match = selector.select()
|
||||||
|
self.assertEqual([2 * CENT], [c.txo.amount for c in match])
|
||||||
|
self.assertFalse(selector.exact_match)
|
||||||
|
|
||||||
|
def test_pick(self):
|
||||||
|
utxo_pool = self.estimates(
|
||||||
|
utxo(1*CENT),
|
||||||
|
utxo(1*CENT),
|
||||||
|
utxo(3*CENT),
|
||||||
|
utxo(5*CENT),
|
||||||
|
utxo(10*CENT),
|
||||||
|
)
|
||||||
|
selector = CoinSelector(utxo_pool, 3*CENT, 0)
|
||||||
|
match = selector.select()
|
||||||
|
self.assertEqual([5*CENT], [c.txo.amount for c in match])
|
||||||
|
|
||||||
|
def test_prefer_confirmed_strategy(self):
|
||||||
|
utxo_pool = self.estimates(
|
||||||
|
utxo(11*CENT, height=5),
|
||||||
|
utxo(11*CENT, height=0),
|
||||||
|
utxo(11*CENT, height=-2),
|
||||||
|
utxo(11*CENT, height=5),
|
||||||
|
)
|
||||||
|
selector = CoinSelector(utxo_pool, 20*CENT, 0)
|
||||||
|
match = selector.select("prefer_confirmed")
|
||||||
|
self.assertEqual([5, 5], [c.txo.tx_ref.height for c in match])
|
||||||
|
|
||||||
|
|
||||||
|
class TestOfficialBitcoinCoinSelectionTests(BaseSelectionTestCase):
|
||||||
|
|
||||||
|
# Bitcoin implementation:
|
||||||
|
# https://github.com/bitcoin/bitcoin/blob/master/src/wallet/coinselection.cpp
|
||||||
|
#
|
||||||
|
# Bitcoin implementation tests:
|
||||||
|
# https://github.com/bitcoin/bitcoin/blob/master/src/wallet/test/coinselector_tests.cpp
|
||||||
|
#
|
||||||
|
# Branch and Bound coin selection white paper:
|
||||||
|
# https://murch.one/wp-content/uploads/2016/11/erhardt2016coinselection.pdf
|
||||||
|
|
||||||
|
def make_hard_case(self, utxos):
|
||||||
|
target = 0
|
||||||
|
utxo_pool = []
|
||||||
|
for i in range(utxos):
|
||||||
|
amount = 1 << (utxos+i)
|
||||||
|
target += amount
|
||||||
|
utxo_pool.append(utxo(amount))
|
||||||
|
utxo_pool.append(utxo(amount + (1 << (utxos-1-i))))
|
||||||
|
return self.estimates(utxo_pool), target
|
||||||
|
|
||||||
|
def test_branch_and_bound_coin_selection(self):
|
||||||
|
self.ledger.fee_per_byte = 0
|
||||||
|
|
||||||
|
utxo_pool = self.estimates(
|
||||||
|
utxo(1 * CENT),
|
||||||
|
utxo(2 * CENT),
|
||||||
|
utxo(3 * CENT),
|
||||||
|
utxo(4 * CENT)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Select 1 Cent
|
||||||
|
self.assertEqual([1 * CENT], search(utxo_pool, 1 * CENT, 0.5 * CENT))
|
||||||
|
|
||||||
|
# Select 2 Cent
|
||||||
|
self.assertEqual([2 * CENT], search(utxo_pool, 2 * CENT, 0.5 * CENT))
|
||||||
|
|
||||||
|
# Select 5 Cent
|
||||||
|
self.assertEqual([3 * CENT, 2 * CENT], search(utxo_pool, 5 * CENT, 0.5 * CENT))
|
||||||
|
|
||||||
|
# Select 11 Cent, not possible
|
||||||
|
self.assertEqual([], search(utxo_pool, 11 * CENT, 0.5 * CENT))
|
||||||
|
|
||||||
|
# Select 10 Cent
|
||||||
|
utxo_pool += self.estimates(utxo(5 * CENT))
|
||||||
|
self.assertEqual(
|
||||||
|
[4 * CENT, 3 * CENT, 2 * CENT, 1 * CENT],
|
||||||
|
search(utxo_pool, 10 * CENT, 0.5 * CENT)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Negative effective value
|
||||||
|
# Select 10 Cent but have 1 Cent not be possible because too small
|
||||||
|
# TODO: bitcoin has [5, 3, 2]
|
||||||
|
self.assertEqual(
|
||||||
|
[4 * CENT, 3 * CENT, 2 * CENT, 1 * CENT],
|
||||||
|
search(utxo_pool, 10 * CENT, 5000)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Select 0.25 Cent, not possible
|
||||||
|
self.assertEqual(search(utxo_pool, 0.25 * CENT, 0.5 * CENT), [])
|
||||||
|
|
||||||
|
# Iteration exhaustion test
|
||||||
|
utxo_pool, target = self.make_hard_case(17)
|
||||||
|
selector = CoinSelector(utxo_pool, target, 0)
|
||||||
|
self.assertEqual(selector.branch_and_bound(), [])
|
||||||
|
self.assertEqual(selector.tries, MAXIMUM_TRIES) # Should exhaust
|
||||||
|
utxo_pool, target = self.make_hard_case(14)
|
||||||
|
self.assertIsNotNone(search(utxo_pool, target, 0)) # Should not exhaust
|
||||||
|
|
||||||
|
# Test same value early bailout optimization
|
||||||
|
utxo_pool = self.estimates([
|
||||||
|
utxo(7 * CENT),
|
||||||
|
utxo(7 * CENT),
|
||||||
|
utxo(7 * CENT),
|
||||||
|
utxo(7 * CENT),
|
||||||
|
utxo(2 * CENT)
|
||||||
|
] + [utxo(5 * CENT)]*50000)
|
||||||
|
self.assertEqual(
|
||||||
|
[7 * CENT, 7 * CENT, 7 * CENT, 7 * CENT, 2 * CENT],
|
||||||
|
search(utxo_pool, 30 * CENT, 5000)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Select 1 Cent with pool of only greater than 5 Cent
|
||||||
|
utxo_pool = self.estimates(utxo(i * CENT) for i in range(5, 21))
|
||||||
|
for _ in range(100):
|
||||||
|
self.assertEqual(search(utxo_pool, 1 * CENT, 2 * CENT), [])
|
310
torba/tests/client_tests/unit/test_database.py
Normal file
310
torba/tests/client_tests/unit/test_database.py
Normal file
|
@ -0,0 +1,310 @@
|
||||||
|
import unittest
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
|
from torba.client.wallet import Wallet
|
||||||
|
from torba.client.constants import COIN
|
||||||
|
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
|
||||||
|
from torba.client.basedatabase import query, constraints_to_sql, AIOSQLite
|
||||||
|
|
||||||
|
from torba.testcase import AsyncioTestCase
|
||||||
|
|
||||||
|
from client_tests.unit.test_transaction import get_output, NULL_HASH
|
||||||
|
|
||||||
|
|
||||||
|
class TestAIOSQLite(AsyncioTestCase):
|
||||||
|
async def asyncSetUp(self):
|
||||||
|
self.db = await AIOSQLite.connect(':memory:')
|
||||||
|
await self.db.executescript("""
|
||||||
|
pragma foreign_keys=on;
|
||||||
|
create table parent (id integer primary key, name);
|
||||||
|
create table child (id integer primary key, parent_id references parent);
|
||||||
|
""")
|
||||||
|
await self.db.execute("insert into parent values (1, 'test')")
|
||||||
|
await self.db.execute("insert into child values (2, 1)")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def delete_item(transaction):
|
||||||
|
transaction.execute('delete from parent where id=1')
|
||||||
|
|
||||||
|
async def test_foreign_keys_integrity_error(self):
|
||||||
|
self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent"))
|
||||||
|
|
||||||
|
with self.assertRaises(sqlite3.IntegrityError):
|
||||||
|
await self.db.run(self.delete_item)
|
||||||
|
self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent"))
|
||||||
|
|
||||||
|
await self.db.executescript("pragma foreign_keys=off;")
|
||||||
|
|
||||||
|
await self.db.run(self.delete_item)
|
||||||
|
self.assertListEqual([], await self.db.execute_fetchall("select * from parent"))
|
||||||
|
|
||||||
|
async def test_run_without_foreign_keys(self):
|
||||||
|
self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent"))
|
||||||
|
await self.db.run_with_foreign_keys_disabled(self.delete_item)
|
||||||
|
self.assertListEqual([], await self.db.execute_fetchall("select * from parent"))
|
||||||
|
|
||||||
|
async def test_integrity_error_when_foreign_keys_disabled_and_skipped(self):
|
||||||
|
await self.db.executescript("pragma foreign_keys=off;")
|
||||||
|
self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent"))
|
||||||
|
with self.assertRaises(sqlite3.IntegrityError):
|
||||||
|
await self.db.run_with_foreign_keys_disabled(self.delete_item)
|
||||||
|
self.assertListEqual([(1, 'test')], await self.db.execute_fetchall("select * from parent"))
|
||||||
|
|
||||||
|
|
||||||
|
class TestQueryBuilder(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_dot(self):
|
||||||
|
self.assertEqual(
|
||||||
|
constraints_to_sql({'txo.position': 18}),
|
||||||
|
('txo.position = :txo_position0', {'txo_position0': 18})
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
constraints_to_sql({'txo.position#6': 18}),
|
||||||
|
('txo.position = :txo_position6', {'txo_position6': 18})
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_any(self):
|
||||||
|
self.assertEqual(
|
||||||
|
constraints_to_sql({
|
||||||
|
'ages__any': {
|
||||||
|
'txo.age__gt': 18,
|
||||||
|
'txo.age__lt': 38
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
('(txo.age > :ages__any0_txo_age__gt0 OR txo.age < :ages__any0_txo_age__lt0)', {
|
||||||
|
'ages__any0_txo_age__gt0': 18,
|
||||||
|
'ages__any0_txo_age__lt0': 38
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_in(self):
|
||||||
|
self.assertEqual(
|
||||||
|
constraints_to_sql({'txo.age__in#2': [18, 38]}),
|
||||||
|
('txo.age IN (:txo_age__in2_0, :txo_age__in2_1)', {
|
||||||
|
'txo_age__in2_0': 18,
|
||||||
|
'txo_age__in2_1': 38
|
||||||
|
})
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
constraints_to_sql({'txo.name__in': ('abc123', 'def456')}),
|
||||||
|
('txo.name IN (:txo_name__in0_0, :txo_name__in0_1)', {
|
||||||
|
'txo_name__in0_0': 'abc123',
|
||||||
|
'txo_name__in0_1': 'def456'
|
||||||
|
})
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
constraints_to_sql({'txo.age__in': 'SELECT age from ages_table'}),
|
||||||
|
('txo.age IN (SELECT age from ages_table)', {})
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_not_in(self):
|
||||||
|
self.assertEqual(
|
||||||
|
constraints_to_sql({'txo.age__not_in': [18, 38]}),
|
||||||
|
('txo.age NOT IN (:txo_age__not_in0_0, :txo_age__not_in0_1)', {
|
||||||
|
'txo_age__not_in0_0': 18,
|
||||||
|
'txo_age__not_in0_1': 38
|
||||||
|
})
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
constraints_to_sql({'txo.name__not_in': ('abc123', 'def456')}),
|
||||||
|
('txo.name NOT IN (:txo_name__not_in0_0, :txo_name__not_in0_1)', {
|
||||||
|
'txo_name__not_in0_0': 'abc123',
|
||||||
|
'txo_name__not_in0_1': 'def456'
|
||||||
|
})
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
constraints_to_sql({'txo.age__not_in': 'SELECT age from ages_table'}),
|
||||||
|
('txo.age NOT IN (SELECT age from ages_table)', {})
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_in_invalid(self):
|
||||||
|
with self.assertRaisesRegex(ValueError, 'list, set or string'):
|
||||||
|
constraints_to_sql({'ages__in': 9})
|
||||||
|
|
||||||
|
def test_query(self):
|
||||||
|
self.assertEqual(
|
||||||
|
query("select * from foo"),
|
||||||
|
("select * from foo", {})
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
query(
|
||||||
|
"select * from foo",
|
||||||
|
a__not='b', b__in='select * from blah where c=:$c',
|
||||||
|
d__any={'one__like': 'o', 'two': 2}, limit=10, order_by='b', **{'$c': 3}),
|
||||||
|
(
|
||||||
|
"select * from foo WHERE a != :a__not0 AND "
|
||||||
|
"b IN (select * from blah where c=:$c) AND "
|
||||||
|
"(one LIKE :d__any0_one__like0 OR two = :d__any0_two0) ORDER BY b LIMIT 10",
|
||||||
|
{'a__not0': 'b', 'd__any0_one__like0': 'o', 'd__any0_two0': 2, '$c': 3}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_query_order_by(self):
|
||||||
|
self.assertEqual(
|
||||||
|
query("select * from foo", order_by='foo'),
|
||||||
|
("select * from foo ORDER BY foo", {})
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
query("select * from foo", order_by=['foo', 'bar']),
|
||||||
|
("select * from foo ORDER BY foo, bar", {})
|
||||||
|
)
|
||||||
|
with self.assertRaisesRegex(ValueError, 'order_by must be string or list'):
|
||||||
|
query("select * from foo", order_by={'foo': 'bar'})
|
||||||
|
|
||||||
|
def test_query_limit_offset(self):
|
||||||
|
self.assertEqual(
|
||||||
|
query("select * from foo", limit=10),
|
||||||
|
("select * from foo LIMIT 10", {})
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
query("select * from foo", offset=10),
|
||||||
|
("select * from foo OFFSET 10", {})
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
query("select * from foo", limit=20, offset=10),
|
||||||
|
("select * from foo LIMIT 20 OFFSET 10", {})
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestQueries(AsyncioTestCase):
|
||||||
|
|
||||||
|
async def asyncSetUp(self):
|
||||||
|
self.ledger = ledger_class({
|
||||||
|
'db': ledger_class.database_class(':memory:'),
|
||||||
|
'headers': ledger_class.headers_class(':memory:'),
|
||||||
|
})
|
||||||
|
await self.ledger.db.open()
|
||||||
|
|
||||||
|
async def asyncTearDown(self):
|
||||||
|
await self.ledger.db.close()
|
||||||
|
|
||||||
|
async def create_account(self):
|
||||||
|
account = self.ledger.account_class.generate(self.ledger, Wallet())
|
||||||
|
await account.ensure_address_gap()
|
||||||
|
return account
|
||||||
|
|
||||||
|
async def create_tx_from_nothing(self, my_account, height):
|
||||||
|
to_address = await my_account.receiving.get_or_create_usable_address()
|
||||||
|
to_hash = ledger_class.address_to_hash160(to_address)
|
||||||
|
tx = ledger_class.transaction_class(height=height, is_verified=True) \
|
||||||
|
.add_inputs([self.txi(self.txo(1, NULL_HASH))]) \
|
||||||
|
.add_outputs([self.txo(1, to_hash)])
|
||||||
|
await self.ledger.db.insert_transaction(tx)
|
||||||
|
await self.ledger.db.save_transaction_io(tx, to_address, to_hash, '')
|
||||||
|
return tx
|
||||||
|
|
||||||
|
async def create_tx_from_txo(self, txo, to_account, height):
|
||||||
|
from_hash = txo.script.values['pubkey_hash']
|
||||||
|
from_address = self.ledger.hash160_to_address(from_hash)
|
||||||
|
to_address = await to_account.receiving.get_or_create_usable_address()
|
||||||
|
to_hash = ledger_class.address_to_hash160(to_address)
|
||||||
|
tx = ledger_class.transaction_class(height=height, is_verified=True) \
|
||||||
|
.add_inputs([self.txi(txo)]) \
|
||||||
|
.add_outputs([self.txo(1, to_hash)])
|
||||||
|
await self.ledger.db.insert_transaction(tx)
|
||||||
|
await self.ledger.db.save_transaction_io(tx, from_address, from_hash, '')
|
||||||
|
await self.ledger.db.save_transaction_io(tx, to_address, to_hash, '')
|
||||||
|
return tx
|
||||||
|
|
||||||
|
async def create_tx_to_nowhere(self, txo, height):
|
||||||
|
from_hash = txo.script.values['pubkey_hash']
|
||||||
|
from_address = self.ledger.hash160_to_address(from_hash)
|
||||||
|
to_hash = NULL_HASH
|
||||||
|
tx = ledger_class.transaction_class(height=height, is_verified=True) \
|
||||||
|
.add_inputs([self.txi(txo)]) \
|
||||||
|
.add_outputs([self.txo(1, to_hash)])
|
||||||
|
await self.ledger.db.insert_transaction(tx)
|
||||||
|
await self.ledger.db.save_transaction_io(tx, from_address, from_hash, '')
|
||||||
|
return tx
|
||||||
|
|
||||||
|
def txo(self, amount, address):
|
||||||
|
return get_output(int(amount*COIN), address)
|
||||||
|
|
||||||
|
def txi(self, txo):
|
||||||
|
return ledger_class.transaction_class.input_class.spend(txo)
|
||||||
|
|
||||||
|
async def test_queries(self):
|
||||||
|
self.assertEqual(0, await self.ledger.db.get_address_count())
|
||||||
|
account1 = await self.create_account()
|
||||||
|
self.assertEqual(26, await self.ledger.db.get_address_count())
|
||||||
|
account2 = await self.create_account()
|
||||||
|
self.assertEqual(52, await self.ledger.db.get_address_count())
|
||||||
|
|
||||||
|
self.assertEqual(0, await self.ledger.db.get_transaction_count())
|
||||||
|
self.assertEqual(0, await self.ledger.db.get_utxo_count())
|
||||||
|
self.assertEqual([], await self.ledger.db.get_utxos())
|
||||||
|
self.assertEqual(0, await self.ledger.db.get_txo_count())
|
||||||
|
self.assertEqual(0, await self.ledger.db.get_balance())
|
||||||
|
self.assertEqual(0, await self.ledger.db.get_balance(account=account1))
|
||||||
|
self.assertEqual(0, await self.ledger.db.get_balance(account=account2))
|
||||||
|
|
||||||
|
tx1 = await self.create_tx_from_nothing(account1, 1)
|
||||||
|
self.assertEqual(1, await self.ledger.db.get_transaction_count(account=account1))
|
||||||
|
self.assertEqual(0, await self.ledger.db.get_transaction_count(account=account2))
|
||||||
|
self.assertEqual(1, await self.ledger.db.get_utxo_count(account=account1))
|
||||||
|
self.assertEqual(1, await self.ledger.db.get_txo_count(account=account1))
|
||||||
|
self.assertEqual(0, await self.ledger.db.get_txo_count(account=account2))
|
||||||
|
self.assertEqual(10**8, await self.ledger.db.get_balance())
|
||||||
|
self.assertEqual(10**8, await self.ledger.db.get_balance(account=account1))
|
||||||
|
self.assertEqual(0, await self.ledger.db.get_balance(account=account2))
|
||||||
|
|
||||||
|
tx2 = await self.create_tx_from_txo(tx1.outputs[0], account2, 2)
|
||||||
|
self.assertEqual(2, await self.ledger.db.get_transaction_count(account=account1))
|
||||||
|
self.assertEqual(1, await self.ledger.db.get_transaction_count(account=account2))
|
||||||
|
self.assertEqual(0, await self.ledger.db.get_utxo_count(account=account1))
|
||||||
|
self.assertEqual(1, await self.ledger.db.get_txo_count(account=account1))
|
||||||
|
self.assertEqual(1, await self.ledger.db.get_utxo_count(account=account2))
|
||||||
|
self.assertEqual(1, await self.ledger.db.get_txo_count(account=account2))
|
||||||
|
self.assertEqual(10**8, await self.ledger.db.get_balance())
|
||||||
|
self.assertEqual(0, await self.ledger.db.get_balance(account=account1))
|
||||||
|
self.assertEqual(10**8, await self.ledger.db.get_balance(account=account2))
|
||||||
|
|
||||||
|
tx3 = await self.create_tx_to_nowhere(tx2.outputs[0], 3)
|
||||||
|
self.assertEqual(2, await self.ledger.db.get_transaction_count(account=account1))
|
||||||
|
self.assertEqual(2, await self.ledger.db.get_transaction_count(account=account2))
|
||||||
|
self.assertEqual(0, await self.ledger.db.get_utxo_count(account=account1))
|
||||||
|
self.assertEqual(1, await self.ledger.db.get_txo_count(account=account1))
|
||||||
|
self.assertEqual(0, await self.ledger.db.get_utxo_count(account=account2))
|
||||||
|
self.assertEqual(1, await self.ledger.db.get_txo_count(account=account2))
|
||||||
|
self.assertEqual(0, await self.ledger.db.get_balance())
|
||||||
|
self.assertEqual(0, await self.ledger.db.get_balance(account=account1))
|
||||||
|
self.assertEqual(0, await self.ledger.db.get_balance(account=account2))
|
||||||
|
|
||||||
|
txs = await self.ledger.db.get_transactions()
|
||||||
|
self.assertEqual([tx3.id, tx2.id, tx1.id], [tx.id for tx in txs])
|
||||||
|
self.assertEqual([3, 2, 1], [tx.height for tx in txs])
|
||||||
|
|
||||||
|
txs = await self.ledger.db.get_transactions(account=account1)
|
||||||
|
self.assertEqual([tx2.id, tx1.id], [tx.id for tx in txs])
|
||||||
|
self.assertEqual(txs[0].inputs[0].is_my_account, True)
|
||||||
|
self.assertEqual(txs[0].outputs[0].is_my_account, False)
|
||||||
|
self.assertEqual(txs[1].inputs[0].is_my_account, False)
|
||||||
|
self.assertEqual(txs[1].outputs[0].is_my_account, True)
|
||||||
|
|
||||||
|
txs = await self.ledger.db.get_transactions(account=account2)
|
||||||
|
self.assertEqual([tx3.id, tx2.id], [tx.id for tx in txs])
|
||||||
|
self.assertEqual(txs[0].inputs[0].is_my_account, True)
|
||||||
|
self.assertEqual(txs[0].outputs[0].is_my_account, False)
|
||||||
|
self.assertEqual(txs[1].inputs[0].is_my_account, False)
|
||||||
|
self.assertEqual(txs[1].outputs[0].is_my_account, True)
|
||||||
|
self.assertEqual(2, await self.ledger.db.get_transaction_count(account=account2))
|
||||||
|
|
||||||
|
tx = await self.ledger.db.get_transaction(txid=tx2.id)
|
||||||
|
self.assertEqual(tx.id, tx2.id)
|
||||||
|
self.assertEqual(tx.inputs[0].is_my_account, False)
|
||||||
|
self.assertEqual(tx.outputs[0].is_my_account, False)
|
||||||
|
tx = await self.ledger.db.get_transaction(txid=tx2.id, account=account1)
|
||||||
|
self.assertEqual(tx.inputs[0].is_my_account, True)
|
||||||
|
self.assertEqual(tx.outputs[0].is_my_account, False)
|
||||||
|
tx = await self.ledger.db.get_transaction(txid=tx2.id, account=account2)
|
||||||
|
self.assertEqual(tx.inputs[0].is_my_account, False)
|
||||||
|
self.assertEqual(tx.outputs[0].is_my_account, True)
|
||||||
|
|
||||||
|
# height 0 sorted to the top with the rest in descending order
|
||||||
|
tx4 = await self.create_tx_from_nothing(account1, 0)
|
||||||
|
txos = await self.ledger.db.get_txos()
|
||||||
|
self.assertEqual([0, 2, 1], [txo.tx_ref.height for txo in txos])
|
||||||
|
self.assertEqual([tx4.id, tx2.id, tx1.id], [txo.tx_ref.id for txo in txos])
|
||||||
|
txs = await self.ledger.db.get_transactions()
|
||||||
|
self.assertEqual([0, 3, 2, 1], [tx.height for tx in txs])
|
||||||
|
self.assertEqual([tx4.id, tx3.id, tx2.id, tx1.id], [tx.id for tx in txs])
|
42
torba/tests/client_tests/unit/test_hash.py
Normal file
42
torba/tests/client_tests/unit/test_hash.py
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
from unittest import TestCase, mock
|
||||||
|
from torba.client.hash import aes_decrypt, aes_encrypt, better_aes_decrypt, better_aes_encrypt
|
||||||
|
|
||||||
|
|
||||||
|
class TestAESEncryptDecrypt(TestCase):
|
||||||
|
message = 'The Times 03/Jan/2009 Chancellor on brink of second bailout for banks'
|
||||||
|
expected = 'ZmZmZmZmZmZmZmZmZmZmZjlrKptoKD+MFwDxcg3XtCD9qz8UWhEhq/TVJT5+Mtp2a8sE' \
|
||||||
|
'CaO6WQj7fYsWGu2Hvbc0qYqxdN0HeTsiO+cZRo3eJISgr3F+rXFYi5oSBlD2'
|
||||||
|
password = 'bubblegum'
|
||||||
|
|
||||||
|
@mock.patch('os.urandom', side_effect=lambda i: b'd'*i)
|
||||||
|
def test_encrypt_iv_f(self, _):
|
||||||
|
self.assertEqual(
|
||||||
|
aes_encrypt(self.password, self.message),
|
||||||
|
'ZGRkZGRkZGRkZGRkZGRkZKBP/4pR+47hLHbHyvDJm9aRKDuoBdTG8SrFvHqfagK6Co1VrHUOd'
|
||||||
|
'oF+6PGSxru3+VR63ybkXLNM75s/qVw+dnKVAkI8OfoVnJvGRSc49e38'
|
||||||
|
)
|
||||||
|
|
||||||
|
@mock.patch('os.urandom', side_effect=lambda i: b'f'*i)
|
||||||
|
def test_encrypt_iv_d(self, _):
|
||||||
|
self.assertEqual(
|
||||||
|
aes_encrypt(self.password, self.message),
|
||||||
|
'ZmZmZmZmZmZmZmZmZmZmZjlrKptoKD+MFwDxcg3XtCD9qz8UWhEhq/TVJT5+Mtp2a8sE'
|
||||||
|
'CaO6WQj7fYsWGu2Hvbc0qYqxdN0HeTsiO+cZRo3eJISgr3F+rXFYi5oSBlD2'
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
aes_decrypt(self.password, self.expected),
|
||||||
|
(self.message, b'f' * 16)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_encrypt_decrypt(self):
|
||||||
|
self.assertEqual(
|
||||||
|
aes_decrypt('bubblegum', aes_encrypt('bubblegum', self.message))[0],
|
||||||
|
self.message
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_better_encrypt_decrypt(self):
|
||||||
|
self.assertEqual(
|
||||||
|
b'valuable value',
|
||||||
|
better_aes_decrypt(
|
||||||
|
'super secret',
|
||||||
|
better_aes_encrypt('super secret', b'valuable value')))
|
108
torba/tests/client_tests/unit/test_headers.py
Normal file
108
torba/tests/client_tests/unit/test_headers.py
Normal file
|
@ -0,0 +1,108 @@
|
||||||
|
import os
|
||||||
|
from urllib.request import Request, urlopen
|
||||||
|
|
||||||
|
from torba.testcase import AsyncioTestCase
|
||||||
|
|
||||||
|
from torba.coin.bitcoinsegwit import MainHeaders
|
||||||
|
|
||||||
|
|
||||||
|
def block_bytes(blocks):
|
||||||
|
return blocks * MainHeaders.header_size
|
||||||
|
|
||||||
|
|
||||||
|
class BitcoinHeadersTestCase(AsyncioTestCase):
|
||||||
|
|
||||||
|
# Download headers instead of storing them in git.
|
||||||
|
HEADER_URL = 'http://headers.electrum.org/blockchain_headers'
|
||||||
|
HEADER_FILE = 'bitcoin_headers'
|
||||||
|
HEADER_BYTES = block_bytes(32260) # 2.6MB
|
||||||
|
RETARGET_BLOCK = 32256 # difficulty: 1 -> 1.18
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.maxDiff = None
|
||||||
|
self.header_file_name = os.path.join(os.path.dirname(__file__), self.HEADER_FILE)
|
||||||
|
if not os.path.exists(self.header_file_name):
|
||||||
|
req = Request(self.HEADER_URL)
|
||||||
|
req.add_header('Range', 'bytes=0-{}'.format(self.HEADER_BYTES-1))
|
||||||
|
with urlopen(req) as response, open(self.header_file_name, 'wb') as header_file:
|
||||||
|
header_file.write(response.read())
|
||||||
|
if os.path.getsize(self.header_file_name) != self.HEADER_BYTES:
|
||||||
|
os.remove(self.header_file_name)
|
||||||
|
raise Exception(
|
||||||
|
"Downloaded headers for testing are not the correct number of bytes. "
|
||||||
|
"They were deleted. Try running the tests again."
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_bytes(self, upto: int = -1, after: int = 0) -> bytes:
|
||||||
|
with open(self.header_file_name, 'rb') as headers:
|
||||||
|
headers.seek(after, os.SEEK_SET)
|
||||||
|
return headers.read(upto)
|
||||||
|
|
||||||
|
async def get_headers(self, upto: int = -1):
|
||||||
|
h = MainHeaders(':memory:')
|
||||||
|
h.io.write(self.get_bytes(upto))
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class BasicHeadersTests(BitcoinHeadersTestCase):
|
||||||
|
|
||||||
|
async def test_serialization(self):
|
||||||
|
h = await self.get_headers()
|
||||||
|
self.assertEqual(h[0], {
|
||||||
|
'bits': 486604799,
|
||||||
|
'block_height': 0,
|
||||||
|
'merkle_root': b'4a5e1e4baab89f3a32518a88c31bc87f618f76673e2cc77ab2127b7afdeda33b',
|
||||||
|
'nonce': 2083236893,
|
||||||
|
'prev_block_hash': b'0000000000000000000000000000000000000000000000000000000000000000',
|
||||||
|
'timestamp': 1231006505,
|
||||||
|
'version': 1
|
||||||
|
})
|
||||||
|
self.assertEqual(h[self.RETARGET_BLOCK-1], {
|
||||||
|
'bits': 486604799,
|
||||||
|
'block_height': 32255,
|
||||||
|
'merkle_root': b'89b4f223789e40b5b475af6483bb05bceda54059e17d2053334b358f6bb310ac',
|
||||||
|
'nonce': 312762301,
|
||||||
|
'prev_block_hash': b'000000006baebaa74cecde6c6787c26ee0a616a3c333261bff36653babdac149',
|
||||||
|
'timestamp': 1262152739,
|
||||||
|
'version': 1
|
||||||
|
})
|
||||||
|
self.assertEqual(h[self.RETARGET_BLOCK], {
|
||||||
|
'bits': 486594666,
|
||||||
|
'block_height': 32256,
|
||||||
|
'merkle_root': b'64b5e5f5a262f47af443a0120609206a3305877693edfe03e994f20a024ab627',
|
||||||
|
'nonce': 121087187,
|
||||||
|
'prev_block_hash': b'00000000984f962134a7291e3693075ae03e521f0ee33378ec30a334d860034b',
|
||||||
|
'timestamp': 1262153464,
|
||||||
|
'version': 1
|
||||||
|
})
|
||||||
|
self.assertEqual(h[self.RETARGET_BLOCK+1], {
|
||||||
|
'bits': 486594666,
|
||||||
|
'block_height': 32257,
|
||||||
|
'merkle_root': b'4d1488981f08b3037878193297dbac701a2054e0f803d4424fe6a4d763d62334',
|
||||||
|
'nonce': 274675219,
|
||||||
|
'prev_block_hash': b'000000004f2886a170adb7204cb0c7a824217dd24d11a74423d564c4e0904967',
|
||||||
|
'timestamp': 1262154352,
|
||||||
|
'version': 1
|
||||||
|
})
|
||||||
|
self.assertEqual(
|
||||||
|
h.serialize(h[0]),
|
||||||
|
h.get_raw_header(0)
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
h.serialize(h[self.RETARGET_BLOCK]),
|
||||||
|
h.get_raw_header(self.RETARGET_BLOCK)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_connect_from_genesis_to_3000_past_first_chunk_at_2016(self):
|
||||||
|
headers = MainHeaders(':memory:')
|
||||||
|
self.assertEqual(headers.height, -1)
|
||||||
|
await headers.connect(0, self.get_bytes(block_bytes(3001)))
|
||||||
|
self.assertEqual(headers.height, 3000)
|
||||||
|
|
||||||
|
async def test_connect_9_blocks_passing_a_retarget_at_32256(self):
|
||||||
|
retarget = block_bytes(self.RETARGET_BLOCK-5)
|
||||||
|
headers = await self.get_headers(upto=retarget)
|
||||||
|
remainder = self.get_bytes(after=retarget)
|
||||||
|
self.assertEqual(headers.height, 32250)
|
||||||
|
await headers.connect(len(headers), remainder)
|
||||||
|
self.assertEqual(headers.height, 32259)
|
166
torba/tests/client_tests/unit/test_ledger.py
Normal file
166
torba/tests/client_tests/unit/test_ledger.py
Normal file
|
@ -0,0 +1,166 @@
|
||||||
|
import os
|
||||||
|
from binascii import hexlify
|
||||||
|
|
||||||
|
from torba.coin.bitcoinsegwit import MainNetLedger
|
||||||
|
from torba.client.wallet import Wallet
|
||||||
|
|
||||||
|
from client_tests.unit.test_transaction import get_transaction, get_output
|
||||||
|
from client_tests.unit.test_headers import BitcoinHeadersTestCase, block_bytes
|
||||||
|
|
||||||
|
|
||||||
|
class MockNetwork:
|
||||||
|
|
||||||
|
def __init__(self, history, transaction):
|
||||||
|
self.history = history
|
||||||
|
self.transaction = transaction
|
||||||
|
self.address = None
|
||||||
|
self.get_history_called = []
|
||||||
|
self.get_transaction_called = []
|
||||||
|
self.is_connected = False
|
||||||
|
|
||||||
|
async def get_history(self, address):
|
||||||
|
self.get_history_called.append(address)
|
||||||
|
self.address = address
|
||||||
|
return self.history
|
||||||
|
|
||||||
|
async def get_merkle(self, txid, height):
|
||||||
|
return {'merkle': ['abcd01'], 'pos': 1}
|
||||||
|
|
||||||
|
async def get_transaction(self, tx_hash):
|
||||||
|
self.get_transaction_called.append(tx_hash)
|
||||||
|
return self.transaction[tx_hash]
|
||||||
|
|
||||||
|
|
||||||
|
class LedgerTestCase(BitcoinHeadersTestCase):
|
||||||
|
|
||||||
|
async def asyncSetUp(self):
|
||||||
|
self.ledger = MainNetLedger({
|
||||||
|
'db': MainNetLedger.database_class(':memory:'),
|
||||||
|
'headers': MainNetLedger.headers_class(':memory:')
|
||||||
|
})
|
||||||
|
await self.ledger.db.open()
|
||||||
|
|
||||||
|
async def asyncTearDown(self):
|
||||||
|
await self.ledger.db.close()
|
||||||
|
|
||||||
|
def make_header(self, **kwargs):
|
||||||
|
header = {
|
||||||
|
'bits': 486604799,
|
||||||
|
'block_height': 0,
|
||||||
|
'merkle_root': b'4a5e1e4baab89f3a32518a88c31bc87f618f76673e2cc77ab2127b7afdeda33b',
|
||||||
|
'nonce': 2083236893,
|
||||||
|
'prev_block_hash': b'0000000000000000000000000000000000000000000000000000000000000000',
|
||||||
|
'timestamp': 1231006505,
|
||||||
|
'version': 1
|
||||||
|
}
|
||||||
|
header.update(kwargs)
|
||||||
|
header['merkle_root'] = header['merkle_root'].ljust(64, b'a')
|
||||||
|
header['prev_block_hash'] = header['prev_block_hash'].ljust(64, b'0')
|
||||||
|
return self.ledger.headers.serialize(header)
|
||||||
|
|
||||||
|
def add_header(self, **kwargs):
|
||||||
|
serialized = self.make_header(**kwargs)
|
||||||
|
self.ledger.headers.io.seek(0, os.SEEK_END)
|
||||||
|
self.ledger.headers.io.write(serialized)
|
||||||
|
self.ledger.headers._size = None
|
||||||
|
|
||||||
|
|
||||||
|
class TestSynchronization(LedgerTestCase):
|
||||||
|
|
||||||
|
async def test_update_history(self):
|
||||||
|
account = self.ledger.account_class.generate(self.ledger, Wallet(), "torba")
|
||||||
|
address = await account.receiving.get_or_create_usable_address()
|
||||||
|
address_details = await self.ledger.db.get_address(address=address)
|
||||||
|
self.assertEqual(address_details['history'], None)
|
||||||
|
|
||||||
|
self.add_header(block_height=0, merkle_root=b'abcd04')
|
||||||
|
self.add_header(block_height=1, merkle_root=b'abcd04')
|
||||||
|
self.add_header(block_height=2, merkle_root=b'abcd04')
|
||||||
|
self.add_header(block_height=3, merkle_root=b'abcd04')
|
||||||
|
self.ledger.network = MockNetwork([
|
||||||
|
{'tx_hash': 'abcd01', 'height': 0},
|
||||||
|
{'tx_hash': 'abcd02', 'height': 1},
|
||||||
|
{'tx_hash': 'abcd03', 'height': 2},
|
||||||
|
], {
|
||||||
|
'abcd01': hexlify(get_transaction(get_output(1)).raw),
|
||||||
|
'abcd02': hexlify(get_transaction(get_output(2)).raw),
|
||||||
|
'abcd03': hexlify(get_transaction(get_output(3)).raw),
|
||||||
|
})
|
||||||
|
await self.ledger.update_history(address, '')
|
||||||
|
self.assertEqual(self.ledger.network.get_history_called, [address])
|
||||||
|
self.assertEqual(self.ledger.network.get_transaction_called, ['abcd01', 'abcd02', 'abcd03'])
|
||||||
|
|
||||||
|
address_details = await self.ledger.db.get_address(address=address)
|
||||||
|
self.assertEqual(
|
||||||
|
address_details['history'],
|
||||||
|
'252bda9b22cc902ca2aa2de3548ee8baf06b8501ff7bfb3b0b7d980dbd1bf792:0:'
|
||||||
|
'ab9c0654dd484ac20437030f2034e25dcb29fc507e84b91138f80adc3af738f9:1:'
|
||||||
|
'a2ae3d1db3c727e7d696122cab39ee20a7f81856dab7019056dd539f38c548a0:2:'
|
||||||
|
)
|
||||||
|
|
||||||
|
self.ledger.network.get_history_called = []
|
||||||
|
self.ledger.network.get_transaction_called = []
|
||||||
|
await self.ledger.update_history(address, '')
|
||||||
|
self.assertEqual(self.ledger.network.get_history_called, [address])
|
||||||
|
self.assertEqual(self.ledger.network.get_transaction_called, [])
|
||||||
|
|
||||||
|
self.ledger.network.history.append({'tx_hash': 'abcd04', 'height': 3})
|
||||||
|
self.ledger.network.transaction['abcd04'] = hexlify(get_transaction(get_output(4)).raw)
|
||||||
|
self.ledger.network.get_history_called = []
|
||||||
|
self.ledger.network.get_transaction_called = []
|
||||||
|
await self.ledger.update_history(address, '')
|
||||||
|
self.assertEqual(self.ledger.network.get_history_called, [address])
|
||||||
|
self.assertEqual(self.ledger.network.get_transaction_called, ['abcd04'])
|
||||||
|
address_details = await self.ledger.db.get_address(address=address)
|
||||||
|
self.assertEqual(
|
||||||
|
address_details['history'],
|
||||||
|
'252bda9b22cc902ca2aa2de3548ee8baf06b8501ff7bfb3b0b7d980dbd1bf792:0:'
|
||||||
|
'ab9c0654dd484ac20437030f2034e25dcb29fc507e84b91138f80adc3af738f9:1:'
|
||||||
|
'a2ae3d1db3c727e7d696122cab39ee20a7f81856dab7019056dd539f38c548a0:2:'
|
||||||
|
'047cf1d53ef68f0fd586d46f90c09ff8e57a4180f67e7f4b8dd0135c3741e828:3:'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MocHeaderNetwork:
|
||||||
|
def __init__(self, responses):
|
||||||
|
self.responses = responses
|
||||||
|
|
||||||
|
async def get_headers(self, height, blocks):
|
||||||
|
return self.responses[height]
|
||||||
|
|
||||||
|
|
||||||
|
class BlockchainReorganizationTests(LedgerTestCase):
|
||||||
|
|
||||||
|
async def test_1_block_reorganization(self):
|
||||||
|
self.ledger.network = MocHeaderNetwork({
|
||||||
|
20: {'height': 20, 'count': 5, 'hex': hexlify(
|
||||||
|
self.get_bytes(after=block_bytes(20), upto=block_bytes(5))
|
||||||
|
)},
|
||||||
|
25: {'height': 25, 'count': 0, 'hex': b''}
|
||||||
|
})
|
||||||
|
headers = self.ledger.headers
|
||||||
|
await headers.connect(0, self.get_bytes(upto=block_bytes(20)))
|
||||||
|
self.add_header(block_height=len(headers))
|
||||||
|
self.assertEqual(headers.height, 20)
|
||||||
|
await self.ledger.receive_header([{
|
||||||
|
'height': 21, 'hex': hexlify(self.make_header(block_height=21))
|
||||||
|
}])
|
||||||
|
|
||||||
|
async def test_3_block_reorganization(self):
|
||||||
|
self.ledger.network = MocHeaderNetwork({
|
||||||
|
20: {'height': 20, 'count': 5, 'hex': hexlify(
|
||||||
|
self.get_bytes(after=block_bytes(20), upto=block_bytes(5))
|
||||||
|
)},
|
||||||
|
21: {'height': 21, 'count': 1, 'hex': hexlify(self.make_header(block_height=21))},
|
||||||
|
22: {'height': 22, 'count': 1, 'hex': hexlify(self.make_header(block_height=22))},
|
||||||
|
25: {'height': 25, 'count': 0, 'hex': b''}
|
||||||
|
})
|
||||||
|
headers = self.ledger.headers
|
||||||
|
await headers.connect(0, self.get_bytes(upto=block_bytes(20)))
|
||||||
|
self.add_header(block_height=len(headers))
|
||||||
|
self.add_header(block_height=len(headers))
|
||||||
|
self.add_header(block_height=len(headers))
|
||||||
|
self.assertEqual(headers.height, 22)
|
||||||
|
await self.ledger.receive_header(({
|
||||||
|
'height': 23, 'hex': hexlify(self.make_header(block_height=23))
|
||||||
|
},))
|
23
torba/tests/client_tests/unit/test_mnemonic.py
Normal file
23
torba/tests/client_tests/unit/test_mnemonic.py
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
import unittest
|
||||||
|
from binascii import hexlify
|
||||||
|
|
||||||
|
from torba.client.mnemonic import Mnemonic
|
||||||
|
|
||||||
|
|
||||||
|
class TestMnemonic(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_mnemonic_to_seed(self):
|
||||||
|
seed = Mnemonic.mnemonic_to_seed(mnemonic=u'foobar', passphrase=u'torba')
|
||||||
|
self.assertEqual(
|
||||||
|
hexlify(seed),
|
||||||
|
b'475a419db4e991cab14f08bde2d357e52b3e7241f72c6d8a2f92782367feeee9f403dc6a37c26a3f02ab9'
|
||||||
|
b'dec7f5063161eb139cea00da64cd77fba2f07c49ddc'
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_make_seed_decode_encode(self):
|
||||||
|
iters = 10
|
||||||
|
m = Mnemonic('en')
|
||||||
|
for _ in range(iters):
|
||||||
|
seed = m.make_seed()
|
||||||
|
i = m.mnemonic_decode(seed)
|
||||||
|
self.assertEqual(m.mnemonic_encode(i), seed)
|
218
torba/tests/client_tests/unit/test_script.py
Normal file
218
torba/tests/client_tests/unit/test_script.py
Normal file
|
@ -0,0 +1,218 @@
|
||||||
|
import unittest
|
||||||
|
from binascii import hexlify, unhexlify
|
||||||
|
|
||||||
|
from torba.client.bcd_data_stream import BCDataStream
|
||||||
|
from torba.client.basescript import Template, ParseError, tokenize, push_data
|
||||||
|
from torba.client.basescript import PUSH_SINGLE, PUSH_INTEGER, PUSH_MANY, OP_HASH160, OP_EQUAL
|
||||||
|
from torba.client.basescript import BaseInputScript, BaseOutputScript
|
||||||
|
|
||||||
|
|
||||||
|
def parse(opcodes, source):
|
||||||
|
template = Template('test', opcodes)
|
||||||
|
s = BCDataStream()
|
||||||
|
for t in source:
|
||||||
|
if isinstance(t, bytes):
|
||||||
|
s.write_many(push_data(t))
|
||||||
|
elif isinstance(t, int):
|
||||||
|
s.write_uint8(t)
|
||||||
|
else:
|
||||||
|
raise ValueError()
|
||||||
|
s.reset()
|
||||||
|
return template.parse(tokenize(s))
|
||||||
|
|
||||||
|
|
||||||
|
class TestScriptTemplates(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_push_data(self):
|
||||||
|
self.assertEqual(parse(
|
||||||
|
(PUSH_SINGLE('script_hash'),),
|
||||||
|
(b'abcdef',)
|
||||||
|
), {
|
||||||
|
'script_hash': b'abcdef'
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.assertEqual(parse(
|
||||||
|
(PUSH_SINGLE('first'), PUSH_INTEGER('rating')),
|
||||||
|
(b'Satoshi', (1000).to_bytes(2, 'little'))
|
||||||
|
), {
|
||||||
|
'first': b'Satoshi',
|
||||||
|
'rating': 1000,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.assertEqual(parse(
|
||||||
|
(OP_HASH160, PUSH_SINGLE('script_hash'), OP_EQUAL),
|
||||||
|
(OP_HASH160, b'abcdef', OP_EQUAL)
|
||||||
|
), {
|
||||||
|
'script_hash': b'abcdef'
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_push_data_many(self):
|
||||||
|
self.assertEqual(parse(
|
||||||
|
(PUSH_MANY('names'),),
|
||||||
|
(b'amit',)
|
||||||
|
), {
|
||||||
|
'names': [b'amit']
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.assertEqual(parse(
|
||||||
|
(PUSH_MANY('names'),),
|
||||||
|
(b'jeremy', b'amit', b'victor')
|
||||||
|
), {
|
||||||
|
'names': [b'jeremy', b'amit', b'victor']
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.assertEqual(parse(
|
||||||
|
(OP_HASH160, PUSH_MANY('names'), OP_EQUAL),
|
||||||
|
(OP_HASH160, b'grin', b'jack', OP_EQUAL)
|
||||||
|
), {
|
||||||
|
'names': [b'grin', b'jack']
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_push_data_mixed(self):
|
||||||
|
self.assertEqual(parse(
|
||||||
|
(PUSH_SINGLE('CEO'), PUSH_MANY('Devs'), PUSH_SINGLE('CTO'), PUSH_SINGLE('State')),
|
||||||
|
(b'jeremy', b'lex', b'amit', b'victor', b'jack', b'grin', b'NH')
|
||||||
|
), {
|
||||||
|
'CEO': b'jeremy',
|
||||||
|
'CTO': b'grin',
|
||||||
|
'Devs': [b'lex', b'amit', b'victor', b'jack'],
|
||||||
|
'State': b'NH'
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_push_data_many_separated(self):
|
||||||
|
self.assertEqual(parse(
|
||||||
|
(PUSH_MANY('Chiefs'), OP_HASH160, PUSH_MANY('Devs')),
|
||||||
|
(b'jeremy', b'grin', OP_HASH160, b'lex', b'jack')
|
||||||
|
), {
|
||||||
|
'Chiefs': [b'jeremy', b'grin'],
|
||||||
|
'Devs': [b'lex', b'jack']
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_push_data_many_not_separated(self):
|
||||||
|
with self.assertRaisesRegex(ParseError, 'consecutive PUSH_MANY'):
|
||||||
|
parse((PUSH_MANY('Chiefs'), PUSH_MANY('Devs')), (b'jeremy', b'grin', b'lex', b'jack'))
|
||||||
|
|
||||||
|
|
||||||
|
class TestRedeemPubKeyHash(unittest.TestCase):
|
||||||
|
|
||||||
|
def redeem_pubkey_hash(self, sig, pubkey):
|
||||||
|
# this checks that factory function correctly sets up the script
|
||||||
|
src1 = BaseInputScript.redeem_pubkey_hash(unhexlify(sig), unhexlify(pubkey))
|
||||||
|
self.assertEqual(src1.template.name, 'pubkey_hash')
|
||||||
|
self.assertEqual(hexlify(src1.values['signature']), sig)
|
||||||
|
self.assertEqual(hexlify(src1.values['pubkey']), pubkey)
|
||||||
|
# now we test that it will round trip
|
||||||
|
src2 = BaseInputScript(src1.source)
|
||||||
|
self.assertEqual(src2.template.name, 'pubkey_hash')
|
||||||
|
self.assertEqual(hexlify(src2.values['signature']), sig)
|
||||||
|
self.assertEqual(hexlify(src2.values['pubkey']), pubkey)
|
||||||
|
return hexlify(src1.source)
|
||||||
|
|
||||||
|
def test_redeem_pubkey_hash_1(self):
|
||||||
|
self.assertEqual(
|
||||||
|
self.redeem_pubkey_hash(
|
||||||
|
b'30450221009dc93f25184a8d483745cd3eceff49727a317c9bfd8be8d3d04517e9cdaf8dd502200e'
|
||||||
|
b'02dc5939cad9562d2b1f303f185957581c4851c98d497af281118825e18a8301',
|
||||||
|
b'025415a06514230521bff3aaface31f6db9d9bbc39bf1ca60a189e78731cfd4e1b'
|
||||||
|
),
|
||||||
|
b'4830450221009dc93f25184a8d483745cd3eceff49727a317c9bfd8be8d3d04517e9cdaf8dd502200e02d'
|
||||||
|
b'c5939cad9562d2b1f303f185957581c4851c98d497af281118825e18a830121025415a06514230521bff3'
|
||||||
|
b'aaface31f6db9d9bbc39bf1ca60a189e78731cfd4e1b'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRedeemScriptHash(unittest.TestCase):
|
||||||
|
|
||||||
|
def redeem_script_hash(self, sigs, pubkeys):
|
||||||
|
# this checks that factory function correctly sets up the script
|
||||||
|
src1 = BaseInputScript.redeem_script_hash(
|
||||||
|
[unhexlify(sig) for sig in sigs],
|
||||||
|
[unhexlify(pubkey) for pubkey in pubkeys]
|
||||||
|
)
|
||||||
|
subscript1 = src1.values['script']
|
||||||
|
self.assertEqual(src1.template.name, 'script_hash')
|
||||||
|
self.assertEqual([hexlify(v) for v in src1.values['signatures']], sigs)
|
||||||
|
self.assertEqual([hexlify(p) for p in subscript1.values['pubkeys']], pubkeys)
|
||||||
|
self.assertEqual(subscript1.values['signatures_count'], len(sigs))
|
||||||
|
self.assertEqual(subscript1.values['pubkeys_count'], len(pubkeys))
|
||||||
|
# now we test that it will round trip
|
||||||
|
src2 = BaseInputScript(src1.source)
|
||||||
|
subscript2 = src2.values['script']
|
||||||
|
self.assertEqual(src2.template.name, 'script_hash')
|
||||||
|
self.assertEqual([hexlify(v) for v in src2.values['signatures']], sigs)
|
||||||
|
self.assertEqual([hexlify(p) for p in subscript2.values['pubkeys']], pubkeys)
|
||||||
|
self.assertEqual(subscript2.values['signatures_count'], len(sigs))
|
||||||
|
self.assertEqual(subscript2.values['pubkeys_count'], len(pubkeys))
|
||||||
|
return hexlify(src1.source)
|
||||||
|
|
||||||
|
def test_redeem_script_hash_1(self):
|
||||||
|
self.assertEqual(
|
||||||
|
self.redeem_script_hash([
|
||||||
|
b'3045022100fec82ed82687874f2a29cbdc8334e114af645c45298e85bb1efe69fcf15c617a0220575'
|
||||||
|
b'e40399f9ada388d8e522899f4ec3b7256896dd9b02742f6567d960b613f0401',
|
||||||
|
b'3044022024890462f731bd1a42a4716797bad94761fc4112e359117e591c07b8520ea33b02201ac68'
|
||||||
|
b'9e35c4648e6beff1d42490207ba14027a638a62663b2ee40153299141eb01',
|
||||||
|
b'30450221009910823e0142967a73c2d16c1560054d71c0625a385904ba2f1f53e0bc1daa8d02205cd'
|
||||||
|
b'70a89c6cf031a8b07d1d5eb0d65d108c4d49c2d403f84fb03ad3dc318777a01'
|
||||||
|
], [
|
||||||
|
b'0372ba1fd35e5f1b1437cba0c4ebfc4025b7349366f9f9c7c8c4b03a47bd3f68a4',
|
||||||
|
b'03061d250182b2db1ba144167fd8b0ef3fe0fc3a2fa046958f835ffaf0dfdb7692',
|
||||||
|
b'02463bfbc1eaec74b5c21c09239ae18dbf6fc07833917df10d0b43e322810cee0c',
|
||||||
|
b'02fa6a6455c26fb516cfa85ea8de81dd623a893ffd579ee2a00deb6cdf3633d6bb',
|
||||||
|
b'0382910eae483ce4213d79d107bfc78f3d77e2a31ea597be45256171ad0abeaa89'
|
||||||
|
]),
|
||||||
|
b'00483045022100fec82ed82687874f2a29cbdc8334e114af645c45298e85bb1efe69fcf15c617a0220575e'
|
||||||
|
b'40399f9ada388d8e522899f4ec3b7256896dd9b02742f6567d960b613f0401473044022024890462f731bd'
|
||||||
|
b'1a42a4716797bad94761fc4112e359117e591c07b8520ea33b02201ac689e35c4648e6beff1d42490207ba'
|
||||||
|
b'14027a638a62663b2ee40153299141eb014830450221009910823e0142967a73c2d16c1560054d71c0625a'
|
||||||
|
b'385904ba2f1f53e0bc1daa8d02205cd70a89c6cf031a8b07d1d5eb0d65d108c4d49c2d403f84fb03ad3dc3'
|
||||||
|
b'18777a014cad53210372ba1fd35e5f1b1437cba0c4ebfc4025b7349366f9f9c7c8c4b03a47bd3f68a42103'
|
||||||
|
b'061d250182b2db1ba144167fd8b0ef3fe0fc3a2fa046958f835ffaf0dfdb76922102463bfbc1eaec74b5c2'
|
||||||
|
b'1c09239ae18dbf6fc07833917df10d0b43e322810cee0c2102fa6a6455c26fb516cfa85ea8de81dd623a89'
|
||||||
|
b'3ffd579ee2a00deb6cdf3633d6bb210382910eae483ce4213d79d107bfc78f3d77e2a31ea597be45256171'
|
||||||
|
b'ad0abeaa8955ae'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPayPubKeyHash(unittest.TestCase):
|
||||||
|
|
||||||
|
def pay_pubkey_hash(self, pubkey_hash):
|
||||||
|
# this checks that factory function correctly sets up the script
|
||||||
|
src1 = BaseOutputScript.pay_pubkey_hash(unhexlify(pubkey_hash))
|
||||||
|
self.assertEqual(src1.template.name, 'pay_pubkey_hash')
|
||||||
|
self.assertEqual(hexlify(src1.values['pubkey_hash']), pubkey_hash)
|
||||||
|
# now we test that it will round trip
|
||||||
|
src2 = BaseOutputScript(src1.source)
|
||||||
|
self.assertEqual(src2.template.name, 'pay_pubkey_hash')
|
||||||
|
self.assertEqual(hexlify(src2.values['pubkey_hash']), pubkey_hash)
|
||||||
|
return hexlify(src1.source)
|
||||||
|
|
||||||
|
def test_pay_pubkey_hash_1(self):
|
||||||
|
self.assertEqual(
|
||||||
|
self.pay_pubkey_hash(b'64d74d12acc93ba1ad495e8d2d0523252d664f4d'),
|
||||||
|
b'76a91464d74d12acc93ba1ad495e8d2d0523252d664f4d88ac'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPayScriptHash(unittest.TestCase):
|
||||||
|
|
||||||
|
def pay_script_hash(self, script_hash):
|
||||||
|
# this checks that factory function correctly sets up the script
|
||||||
|
src1 = BaseOutputScript.pay_script_hash(unhexlify(script_hash))
|
||||||
|
self.assertEqual(src1.template.name, 'pay_script_hash')
|
||||||
|
self.assertEqual(hexlify(src1.values['script_hash']), script_hash)
|
||||||
|
# now we test that it will round trip
|
||||||
|
src2 = BaseOutputScript(src1.source)
|
||||||
|
self.assertEqual(src2.template.name, 'pay_script_hash')
|
||||||
|
self.assertEqual(hexlify(src2.values['script_hash']), script_hash)
|
||||||
|
return hexlify(src1.source)
|
||||||
|
|
||||||
|
def test_pay_pubkey_hash_1(self):
|
||||||
|
self.assertEqual(
|
||||||
|
self.pay_script_hash(b'63d65a2ee8c44426d06050cfd71c0f0ff3fc41ac'),
|
||||||
|
b'a91463d65a2ee8c44426d06050cfd71c0f0ff3fc41ac87'
|
||||||
|
)
|
345
torba/tests/client_tests/unit/test_transaction.py
Normal file
345
torba/tests/client_tests/unit/test_transaction.py
Normal file
|
@ -0,0 +1,345 @@
|
||||||
|
import unittest
|
||||||
|
from binascii import hexlify, unhexlify
|
||||||
|
from itertools import cycle
|
||||||
|
|
||||||
|
from torba.testcase import AsyncioTestCase
|
||||||
|
|
||||||
|
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
|
||||||
|
from torba.client.wallet import Wallet
|
||||||
|
from torba.client.constants import CENT, COIN
|
||||||
|
|
||||||
|
|
||||||
|
NULL_HASH = b'\x00'*32
|
||||||
|
FEE_PER_BYTE = 50
|
||||||
|
FEE_PER_CHAR = 200000
|
||||||
|
|
||||||
|
|
||||||
|
def get_output(amount=CENT, pubkey_hash=NULL_HASH, height=-2):
|
||||||
|
return ledger_class.transaction_class(height=height) \
|
||||||
|
.add_outputs([ledger_class.transaction_class.output_class.pay_pubkey_hash(amount, pubkey_hash)]) \
|
||||||
|
.outputs[0]
|
||||||
|
|
||||||
|
|
||||||
|
def get_input(amount=CENT, pubkey_hash=NULL_HASH):
|
||||||
|
return ledger_class.transaction_class.input_class.spend(get_output(amount, pubkey_hash))
|
||||||
|
|
||||||
|
|
||||||
|
def get_transaction(txo=None):
|
||||||
|
return ledger_class.transaction_class() \
|
||||||
|
.add_inputs([get_input()]) \
|
||||||
|
.add_outputs([txo or ledger_class.transaction_class.output_class.pay_pubkey_hash(CENT, NULL_HASH)])
|
||||||
|
|
||||||
|
|
||||||
|
class TestSizeAndFeeEstimation(AsyncioTestCase):
|
||||||
|
|
||||||
|
async def asyncSetUp(self):
|
||||||
|
self.ledger = ledger_class({
|
||||||
|
'db': ledger_class.database_class(':memory:'),
|
||||||
|
'headers': ledger_class.headers_class(':memory:'),
|
||||||
|
})
|
||||||
|
|
||||||
|
def test_output_size_and_fee(self):
|
||||||
|
txo = get_output()
|
||||||
|
self.assertEqual(txo.size, 46)
|
||||||
|
self.assertEqual(txo.get_fee(self.ledger), 46 * FEE_PER_BYTE)
|
||||||
|
|
||||||
|
def test_input_size_and_fee(self):
|
||||||
|
txi = get_input()
|
||||||
|
self.assertEqual(txi.size, 148)
|
||||||
|
self.assertEqual(txi.get_fee(self.ledger), 148 * FEE_PER_BYTE)
|
||||||
|
|
||||||
|
def test_transaction_size_and_fee(self):
|
||||||
|
tx = get_transaction()
|
||||||
|
self.assertEqual(tx.size, 204)
|
||||||
|
self.assertEqual(tx.base_size, tx.size - tx.inputs[0].size - tx.outputs[0].size)
|
||||||
|
self.assertEqual(tx.get_base_fee(self.ledger), FEE_PER_BYTE * tx.base_size)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAccountBalanceImpactFromTransaction(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_is_my_account_not_set(self):
|
||||||
|
tx = get_transaction()
|
||||||
|
with self.assertRaisesRegex(ValueError, "Cannot access net_account_balance"):
|
||||||
|
_ = tx.net_account_balance
|
||||||
|
tx.inputs[0].txo_ref.txo.is_my_account = True
|
||||||
|
with self.assertRaisesRegex(ValueError, "Cannot access net_account_balance"):
|
||||||
|
_ = tx.net_account_balance
|
||||||
|
tx.outputs[0].is_my_account = True
|
||||||
|
# all inputs/outputs are set now so it should work
|
||||||
|
_ = tx.net_account_balance
|
||||||
|
|
||||||
|
def test_paying_from_my_account_to_other_account(self):
|
||||||
|
tx = ledger_class.transaction_class() \
|
||||||
|
.add_inputs([get_input(300*CENT)]) \
|
||||||
|
.add_outputs([get_output(190*CENT, NULL_HASH),
|
||||||
|
get_output(100*CENT, NULL_HASH)])
|
||||||
|
tx.inputs[0].txo_ref.txo.is_my_account = True
|
||||||
|
tx.outputs[0].is_my_account = False
|
||||||
|
tx.outputs[1].is_my_account = True
|
||||||
|
self.assertEqual(tx.net_account_balance, -200*CENT)
|
||||||
|
|
||||||
|
def test_paying_from_other_account_to_my_account(self):
|
||||||
|
tx = ledger_class.transaction_class() \
|
||||||
|
.add_inputs([get_input(300*CENT)]) \
|
||||||
|
.add_outputs([get_output(190*CENT, NULL_HASH),
|
||||||
|
get_output(100*CENT, NULL_HASH)])
|
||||||
|
tx.inputs[0].txo_ref.txo.is_my_account = False
|
||||||
|
tx.outputs[0].is_my_account = True
|
||||||
|
tx.outputs[1].is_my_account = False
|
||||||
|
self.assertEqual(tx.net_account_balance, 190*CENT)
|
||||||
|
|
||||||
|
def test_paying_from_my_account_to_my_account(self):
|
||||||
|
tx = ledger_class.transaction_class() \
|
||||||
|
.add_inputs([get_input(300*CENT)]) \
|
||||||
|
.add_outputs([get_output(190*CENT, NULL_HASH),
|
||||||
|
get_output(100*CENT, NULL_HASH)])
|
||||||
|
tx.inputs[0].txo_ref.txo.is_my_account = True
|
||||||
|
tx.outputs[0].is_my_account = True
|
||||||
|
tx.outputs[1].is_my_account = True
|
||||||
|
self.assertEqual(tx.net_account_balance, -10*CENT) # lost to fee
|
||||||
|
|
||||||
|
|
||||||
|
class TestTransactionSerialization(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_genesis_transaction(self):
|
||||||
|
raw = unhexlify(
|
||||||
|
'01000000010000000000000000000000000000000000000000000000000000000000000000ffffffff4d04'
|
||||||
|
'ffff001d0104455468652054696d65732030332f4a616e2f32303039204368616e63656c6c6f72206f6e20'
|
||||||
|
'6272696e6b206f66207365636f6e64206261696c6f757420666f722062616e6b73ffffffff0100f2052a01'
|
||||||
|
'000000434104678afdb0fe5548271967f1a67130b7105cd6a828e03909a67962e0ea1f61deb649f6bc3f4c'
|
||||||
|
'ef38c4f35504e51ec112de5c384df7ba0b8d578a4c702b6bf11d5fac00000000'
|
||||||
|
)
|
||||||
|
tx = ledger_class.transaction_class(raw)
|
||||||
|
self.assertEqual(tx.version, 1)
|
||||||
|
self.assertEqual(tx.locktime, 0)
|
||||||
|
self.assertEqual(len(tx.inputs), 1)
|
||||||
|
self.assertEqual(len(tx.outputs), 1)
|
||||||
|
|
||||||
|
coinbase = tx.inputs[0]
|
||||||
|
self.assertTrue(coinbase.txo_ref.is_null, NULL_HASH)
|
||||||
|
self.assertEqual(coinbase.txo_ref.position, 0xFFFFFFFF)
|
||||||
|
self.assertEqual(coinbase.sequence, 4294967295)
|
||||||
|
self.assertIsNotNone(coinbase.coinbase)
|
||||||
|
self.assertIsNone(coinbase.script)
|
||||||
|
self.assertEqual(
|
||||||
|
coinbase.coinbase[8:],
|
||||||
|
b'The Times 03/Jan/2009 Chancellor on brink of second bailout for banks'
|
||||||
|
)
|
||||||
|
|
||||||
|
out = tx.outputs[0]
|
||||||
|
self.assertEqual(out.amount, 5000000000)
|
||||||
|
self.assertEqual(out.position, 0)
|
||||||
|
self.assertTrue(out.script.is_pay_pubkey)
|
||||||
|
self.assertFalse(out.script.is_pay_pubkey_hash)
|
||||||
|
self.assertFalse(out.script.is_pay_script_hash)
|
||||||
|
|
||||||
|
tx._reset()
|
||||||
|
self.assertEqual(tx.raw, raw)
|
||||||
|
|
||||||
|
def test_coinbase_transaction(self):
|
||||||
|
raw = unhexlify(
|
||||||
|
'01000000010000000000000000000000000000000000000000000000000000000000000000ffffffff4e03'
|
||||||
|
'1f5a070473319e592f4254432e434f4d2f4e59412ffabe6d6dcceb2a9d0444c51cabc4ee97a1a000036ca0'
|
||||||
|
'cb48d25b94b78c8367d8b868454b0100000000000000c0309b21000008c5f8f80000ffffffff0291920b5d'
|
||||||
|
'0000000017a914e083685a1097ce1ea9e91987ab9e94eae33d8a13870000000000000000266a24aa21a9ed'
|
||||||
|
'e6c99265a6b9e1d36c962fda0516b35709c49dc3b8176fa7e5d5f1f6197884b400000000'
|
||||||
|
)
|
||||||
|
tx = ledger_class.transaction_class(raw)
|
||||||
|
self.assertEqual(tx.version, 1)
|
||||||
|
self.assertEqual(tx.locktime, 0)
|
||||||
|
self.assertEqual(len(tx.inputs), 1)
|
||||||
|
self.assertEqual(len(tx.outputs), 2)
|
||||||
|
|
||||||
|
coinbase = tx.inputs[0]
|
||||||
|
self.assertTrue(coinbase.txo_ref.is_null)
|
||||||
|
self.assertEqual(coinbase.txo_ref.position, 0xFFFFFFFF)
|
||||||
|
self.assertEqual(coinbase.sequence, 4294967295)
|
||||||
|
self.assertIsNotNone(coinbase.coinbase)
|
||||||
|
self.assertIsNone(coinbase.script)
|
||||||
|
self.assertEqual(coinbase.coinbase[9:22], b'/BTC.COM/NYA/')
|
||||||
|
|
||||||
|
out = tx.outputs[0]
|
||||||
|
self.assertEqual(out.amount, 1561039505)
|
||||||
|
self.assertEqual(out.position, 0)
|
||||||
|
self.assertFalse(out.script.is_pay_pubkey)
|
||||||
|
self.assertFalse(out.script.is_pay_pubkey_hash)
|
||||||
|
self.assertTrue(out.script.is_pay_script_hash)
|
||||||
|
self.assertFalse(out.script.is_return_data)
|
||||||
|
|
||||||
|
out1 = tx.outputs[1]
|
||||||
|
self.assertEqual(out1.amount, 0)
|
||||||
|
self.assertEqual(out1.position, 1)
|
||||||
|
self.assertEqual(
|
||||||
|
hexlify(out1.script.values['data']),
|
||||||
|
b'aa21a9ede6c99265a6b9e1d36c962fda0516b35709c49dc3b8176fa7e5d5f1f6197884b4'
|
||||||
|
)
|
||||||
|
self.assertTrue(out1.script.is_return_data)
|
||||||
|
self.assertFalse(out1.script.is_pay_pubkey)
|
||||||
|
self.assertFalse(out1.script.is_pay_pubkey_hash)
|
||||||
|
self.assertFalse(out1.script.is_pay_script_hash)
|
||||||
|
|
||||||
|
tx._reset()
|
||||||
|
self.assertEqual(tx.raw, raw)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTransactionSigning(AsyncioTestCase):
|
||||||
|
|
||||||
|
async def asyncSetUp(self):
|
||||||
|
self.ledger = ledger_class({
|
||||||
|
'db': ledger_class.database_class(':memory:'),
|
||||||
|
'headers': ledger_class.headers_class(':memory:'),
|
||||||
|
})
|
||||||
|
await self.ledger.db.open()
|
||||||
|
|
||||||
|
async def asyncTearDown(self):
|
||||||
|
await self.ledger.db.close()
|
||||||
|
|
||||||
|
async def test_sign(self):
|
||||||
|
account = self.ledger.account_class.from_dict(
|
||||||
|
self.ledger, Wallet(), {
|
||||||
|
"seed": "carbon smart garage balance margin twelve chest sword "
|
||||||
|
"toast envelope bottom stomach absent"
|
||||||
|
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
await account.ensure_address_gap()
|
||||||
|
address1, address2 = await account.receiving.get_addresses(limit=2)
|
||||||
|
pubkey_hash1 = self.ledger.address_to_hash160(address1)
|
||||||
|
pubkey_hash2 = self.ledger.address_to_hash160(address2)
|
||||||
|
|
||||||
|
tx_class = ledger_class.transaction_class
|
||||||
|
|
||||||
|
tx = tx_class() \
|
||||||
|
.add_inputs([tx_class.input_class.spend(get_output(2*COIN, pubkey_hash1))]) \
|
||||||
|
.add_outputs([tx_class.output_class.pay_pubkey_hash(int(1.9*COIN), pubkey_hash2)]) \
|
||||||
|
|
||||||
|
await tx.sign([account])
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
hexlify(tx.inputs[0].script.values['signature']),
|
||||||
|
b'304402205a1df8cd5d2d2fa5934b756883d6c07e4f83e1350c740992d47a12422'
|
||||||
|
b'226aaa202200098ac8675827aea2b0d6f0e49566143a95d523e311d342172cd99e2021e47cb01'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TransactionIOBalancing(AsyncioTestCase):
|
||||||
|
|
||||||
|
async def asyncSetUp(self):
|
||||||
|
self.ledger = ledger_class({
|
||||||
|
'db': ledger_class.database_class(':memory:'),
|
||||||
|
'headers': ledger_class.headers_class(':memory:'),
|
||||||
|
})
|
||||||
|
await self.ledger.db.open()
|
||||||
|
self.account = self.ledger.account_class.from_dict(
|
||||||
|
self.ledger, Wallet(), {
|
||||||
|
"seed": "carbon smart garage balance margin twelve chest sword "
|
||||||
|
"toast envelope bottom stomach absent"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
addresses = await self.account.ensure_address_gap()
|
||||||
|
self.pubkey_hash = [self.ledger.address_to_hash160(a) for a in addresses]
|
||||||
|
self.hash_cycler = cycle(self.pubkey_hash)
|
||||||
|
|
||||||
|
async def asyncTearDown(self):
|
||||||
|
await self.ledger.db.close()
|
||||||
|
|
||||||
|
def txo(self, amount, address=None):
|
||||||
|
return get_output(int(amount*COIN), address or next(self.hash_cycler))
|
||||||
|
|
||||||
|
def txi(self, txo):
|
||||||
|
return ledger_class.transaction_class.input_class.spend(txo)
|
||||||
|
|
||||||
|
def tx(self, inputs, outputs):
|
||||||
|
return ledger_class.transaction_class.create(inputs, outputs, [self.account], self.account)
|
||||||
|
|
||||||
|
async def create_utxos(self, amounts):
|
||||||
|
utxos = [self.txo(amount) for amount in amounts]
|
||||||
|
|
||||||
|
self.funding_tx = ledger_class.transaction_class(is_verified=True) \
|
||||||
|
.add_inputs([self.txi(self.txo(sum(amounts)+0.1))]) \
|
||||||
|
.add_outputs(utxos)
|
||||||
|
|
||||||
|
await self.ledger.db.insert_transaction(self.funding_tx)
|
||||||
|
|
||||||
|
for utxo in utxos:
|
||||||
|
await self.ledger.db.save_transaction_io(
|
||||||
|
self.funding_tx,
|
||||||
|
self.ledger.hash160_to_address(utxo.script.values['pubkey_hash']),
|
||||||
|
utxo.script.values['pubkey_hash'], ''
|
||||||
|
)
|
||||||
|
|
||||||
|
return utxos
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def inputs(tx):
|
||||||
|
return [round(i.amount/COIN, 2) for i in tx.inputs]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def outputs(tx):
|
||||||
|
return [round(o.amount/COIN, 2) for o in tx.outputs]
|
||||||
|
|
||||||
|
async def test_basic_use_cases(self):
|
||||||
|
self.ledger.fee_per_byte = int(.01*CENT)
|
||||||
|
|
||||||
|
# available UTXOs for filling missing inputs
|
||||||
|
utxos = await self.create_utxos([
|
||||||
|
1, 1, 3, 5, 10
|
||||||
|
])
|
||||||
|
|
||||||
|
# pay 3 coins (3.02 w/ fees)
|
||||||
|
tx = await self.tx(
|
||||||
|
[], # inputs
|
||||||
|
[self.txo(3)] # outputs
|
||||||
|
)
|
||||||
|
# best UTXO match is 5 (as UTXO 3 will be short 0.02 to cover fees)
|
||||||
|
self.assertEqual(self.inputs(tx), [5])
|
||||||
|
# a change of 1.98 is added to reach balance
|
||||||
|
self.assertEqual(self.outputs(tx), [3, 1.98])
|
||||||
|
|
||||||
|
await self.ledger.release_outputs(utxos)
|
||||||
|
|
||||||
|
# pay 2.98 coins (3.00 w/ fees)
|
||||||
|
tx = await self.tx(
|
||||||
|
[], # inputs
|
||||||
|
[self.txo(2.98)] # outputs
|
||||||
|
)
|
||||||
|
# best UTXO match is 3 and no change is needed
|
||||||
|
self.assertEqual(self.inputs(tx), [3])
|
||||||
|
self.assertEqual(self.outputs(tx), [2.98])
|
||||||
|
|
||||||
|
await self.ledger.release_outputs(utxos)
|
||||||
|
|
||||||
|
# supplied input and output, but input is not enough to cover output
|
||||||
|
tx = await self.tx(
|
||||||
|
[self.txi(self.txo(10))], # inputs
|
||||||
|
[self.txo(11)] # outputs
|
||||||
|
)
|
||||||
|
# additional input is chosen (UTXO 3)
|
||||||
|
self.assertEqual([10, 3], self.inputs(tx))
|
||||||
|
# change is now needed to consume extra input
|
||||||
|
self.assertEqual([11, 1.96], self.outputs(tx))
|
||||||
|
|
||||||
|
await self.ledger.release_outputs(utxos)
|
||||||
|
|
||||||
|
# liquidating a UTXO
|
||||||
|
tx = await self.tx(
|
||||||
|
[self.txi(self.txo(10))], # inputs
|
||||||
|
[] # outputs
|
||||||
|
)
|
||||||
|
self.assertEqual([10], self.inputs(tx))
|
||||||
|
# missing change added to consume the amount
|
||||||
|
self.assertEqual([9.98], self.outputs(tx))
|
||||||
|
|
||||||
|
await self.ledger.release_outputs(utxos)
|
||||||
|
|
||||||
|
# liquidating at a loss, requires adding extra inputs
|
||||||
|
tx = await self.tx(
|
||||||
|
[self.txi(self.txo(0.01))], # inputs
|
||||||
|
[] # outputs
|
||||||
|
)
|
||||||
|
# UTXO 1 is added to cover some of the fee
|
||||||
|
self.assertEqual([0.01, 1], self.inputs(tx))
|
||||||
|
# change is now needed to consume extra input
|
||||||
|
self.assertEqual([0.97], self.outputs(tx))
|
95
torba/tests/client_tests/unit/test_utils.py
Normal file
95
torba/tests/client_tests/unit/test_utils.py
Normal file
|
@ -0,0 +1,95 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from torba.client.util import ArithUint256
|
||||||
|
from torba.client.util import coins_to_satoshis as c2s, satoshis_to_coins as s2c
|
||||||
|
|
||||||
|
|
||||||
|
class TestCoinValueParsing(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_good_output(self):
|
||||||
|
self.assertEqual(s2c(1), "0.00000001")
|
||||||
|
self.assertEqual(s2c(10**7), "0.1")
|
||||||
|
self.assertEqual(s2c(2*10**8), "2.0")
|
||||||
|
self.assertEqual(s2c(2*10**17), "2000000000.0")
|
||||||
|
|
||||||
|
def test_good_input(self):
|
||||||
|
self.assertEqual(c2s("0.00000001"), 1)
|
||||||
|
self.assertEqual(c2s("0.1"), 10**7)
|
||||||
|
self.assertEqual(c2s("1.0"), 10**8)
|
||||||
|
self.assertEqual(c2s("2.00000000"), 2*10**8)
|
||||||
|
self.assertEqual(c2s("2000000000.0"), 2*10**17)
|
||||||
|
|
||||||
|
def test_bad_input(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
c2s("1")
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
c2s("-1.0")
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
c2s("10000000000.0")
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
c2s("1.000000000")
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
c2s("-0")
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
c2s("1")
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
c2s(".1")
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
c2s("1e-7")
|
||||||
|
|
||||||
|
|
||||||
|
class TestArithUint256(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_arithunit256(self):
|
||||||
|
# https://github.com/bitcoin/bitcoin/blob/master/src/test/arith_uint256_tests.cpp
|
||||||
|
|
||||||
|
from_compact = ArithUint256.from_compact
|
||||||
|
eq = self.assertEqual
|
||||||
|
|
||||||
|
eq(from_compact(0).value, 0)
|
||||||
|
eq(from_compact(0x00123456).value, 0)
|
||||||
|
eq(from_compact(0x01003456).value, 0)
|
||||||
|
eq(from_compact(0x02000056).value, 0)
|
||||||
|
eq(from_compact(0x03000000).value, 0)
|
||||||
|
eq(from_compact(0x04000000).value, 0)
|
||||||
|
eq(from_compact(0x00923456).value, 0)
|
||||||
|
eq(from_compact(0x01803456).value, 0)
|
||||||
|
eq(from_compact(0x02800056).value, 0)
|
||||||
|
eq(from_compact(0x03800000).value, 0)
|
||||||
|
eq(from_compact(0x04800000).value, 0)
|
||||||
|
|
||||||
|
# Make sure that we don't generate compacts with the 0x00800000 bit set
|
||||||
|
uint = ArithUint256(0x80)
|
||||||
|
eq(uint.compact, 0x02008000)
|
||||||
|
|
||||||
|
uint = from_compact(0x01123456)
|
||||||
|
eq(uint.value, 0x12)
|
||||||
|
eq(uint.compact, 0x01120000)
|
||||||
|
|
||||||
|
uint = from_compact(0x01fedcba)
|
||||||
|
eq(uint.value, 0x7e)
|
||||||
|
eq(uint.negative, 0x01fe0000)
|
||||||
|
|
||||||
|
uint = from_compact(0x02123456)
|
||||||
|
eq(uint.value, 0x1234)
|
||||||
|
eq(uint.compact, 0x02123400)
|
||||||
|
|
||||||
|
uint = from_compact(0x03123456)
|
||||||
|
eq(uint.value, 0x123456)
|
||||||
|
eq(uint.compact, 0x03123456)
|
||||||
|
|
||||||
|
uint = from_compact(0x04123456)
|
||||||
|
eq(uint.value, 0x12345600)
|
||||||
|
eq(uint.compact, 0x04123456)
|
||||||
|
|
||||||
|
uint = from_compact(0x04923456)
|
||||||
|
eq(uint.value, 0x12345600)
|
||||||
|
eq(uint.negative, 0x04923456)
|
||||||
|
|
||||||
|
uint = from_compact(0x05009234)
|
||||||
|
eq(uint.value, 0x92340000)
|
||||||
|
eq(uint.compact, 0x05009234)
|
||||||
|
|
||||||
|
uint = from_compact(0x20123456)
|
||||||
|
eq(uint.value, 0x1234560000000000000000000000000000000000000000000000000000000000)
|
||||||
|
eq(uint.compact, 0x20123456)
|
93
torba/tests/client_tests/unit/test_wallet.py
Normal file
93
torba/tests/client_tests/unit/test_wallet.py
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
import tempfile
|
||||||
|
from binascii import hexlify
|
||||||
|
|
||||||
|
from torba.testcase import AsyncioTestCase
|
||||||
|
|
||||||
|
from torba.coin.bitcoinsegwit import MainNetLedger as BTCLedger
|
||||||
|
from torba.coin.bitcoincash import MainNetLedger as BCHLedger
|
||||||
|
from torba.client.basemanager import BaseWalletManager
|
||||||
|
from torba.client.wallet import Wallet, WalletStorage
|
||||||
|
|
||||||
|
|
||||||
|
class TestWalletCreation(AsyncioTestCase):
|
||||||
|
|
||||||
|
async def asyncSetUp(self):
|
||||||
|
self.manager = BaseWalletManager()
|
||||||
|
config = {'data_path': '/tmp/wallet'}
|
||||||
|
self.btc_ledger = self.manager.get_or_create_ledger(BTCLedger.get_id(), config)
|
||||||
|
self.bch_ledger = self.manager.get_or_create_ledger(BCHLedger.get_id(), config)
|
||||||
|
|
||||||
|
def test_create_wallet_and_accounts(self):
|
||||||
|
wallet = Wallet()
|
||||||
|
self.assertEqual(wallet.name, 'Wallet')
|
||||||
|
self.assertEqual(wallet.accounts, [])
|
||||||
|
|
||||||
|
account1 = wallet.generate_account(self.btc_ledger)
|
||||||
|
wallet.generate_account(self.btc_ledger)
|
||||||
|
wallet.generate_account(self.bch_ledger)
|
||||||
|
self.assertEqual(wallet.default_account, account1)
|
||||||
|
self.assertEqual(len(wallet.accounts), 3)
|
||||||
|
|
||||||
|
def test_load_and_save_wallet(self):
|
||||||
|
wallet_dict = {
|
||||||
|
'version': 1,
|
||||||
|
'name': 'Main Wallet',
|
||||||
|
'accounts': [
|
||||||
|
{
|
||||||
|
'name': 'An Account',
|
||||||
|
'ledger': 'btc_mainnet',
|
||||||
|
'modified_on': 123.456,
|
||||||
|
'seed':
|
||||||
|
"carbon smart garage balance margin twelve chest sword toast envelope bottom stomac"
|
||||||
|
"h absent",
|
||||||
|
'encrypted': False,
|
||||||
|
'private_key':
|
||||||
|
'xprv9s21ZrQH143K3TsAz5efNV8K93g3Ms3FXcjaWB9fVUsMwAoE3Z'
|
||||||
|
'T4vYymkp5BxKKfnpz8J6sHDFriX1SnpvjNkzcks8XBnxjGLS83BTyfpna',
|
||||||
|
'public_key':
|
||||||
|
'xpub661MyMwAqRbcFwwe67Bfjd53h5WXmKm6tqfBJZZH3pQLoy8Nb6'
|
||||||
|
'mKUMJFc7UbpVNzmwFPN2evn3YHnig1pkKVYcvCV8owTd2yAcEkJfCX53g',
|
||||||
|
'address_generator': {
|
||||||
|
'name': 'deterministic-chain',
|
||||||
|
'receiving': {'gap': 17, 'maximum_uses_per_address': 3},
|
||||||
|
'change': {'gap': 10, 'maximum_uses_per_address': 3}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
storage = WalletStorage(default=wallet_dict)
|
||||||
|
wallet = Wallet.from_storage(storage, self.manager)
|
||||||
|
self.assertEqual(wallet.name, 'Main Wallet')
|
||||||
|
self.assertEqual(
|
||||||
|
hexlify(wallet.hash), b'9f462b8dd802eb8c913e54f09a09827ebc14abbc13f33baa90d8aec5ae920fc7'
|
||||||
|
)
|
||||||
|
self.assertEqual(len(wallet.accounts), 1)
|
||||||
|
account = wallet.default_account
|
||||||
|
self.assertIsInstance(account, BTCLedger.account_class)
|
||||||
|
self.maxDiff = None
|
||||||
|
self.assertDictEqual(wallet_dict, wallet.to_dict())
|
||||||
|
|
||||||
|
encrypted = wallet.pack('password')
|
||||||
|
decrypted = Wallet.unpack('password', encrypted)
|
||||||
|
self.assertEqual(decrypted['accounts'][0]['name'], 'An Account')
|
||||||
|
|
||||||
|
def test_read_write(self):
|
||||||
|
manager = BaseWalletManager()
|
||||||
|
config = {'data_path': '/tmp/wallet'}
|
||||||
|
ledger = manager.get_or_create_ledger(BTCLedger.get_id(), config)
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix='.json') as wallet_file:
|
||||||
|
wallet_file.write(b'{"version": 1}')
|
||||||
|
wallet_file.seek(0)
|
||||||
|
|
||||||
|
# create and write wallet to a file
|
||||||
|
wallet = manager.import_wallet(wallet_file.name)
|
||||||
|
account = wallet.generate_account(ledger)
|
||||||
|
wallet.save()
|
||||||
|
|
||||||
|
# read wallet from file
|
||||||
|
wallet_storage = WalletStorage(wallet_file.name)
|
||||||
|
wallet = Wallet.from_storage(wallet_storage, manager)
|
||||||
|
|
||||||
|
self.assertEqual(account.public_key.address, wallet.default_account.public_key.address)
|
BIN
torba/torba.png
Normal file
BIN
torba/torba.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 10 KiB |
2
torba/torba/__init__.py
Normal file
2
torba/torba/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
__path__: str = __import__('pkgutil').extend_path(__path__, __name__)
|
||||||
|
__version__ = '0.5.7'
|
0
torba/torba/client/__init__.py
Normal file
0
torba/torba/client/__init__.py
Normal file
464
torba/torba/client/baseaccount.py
Normal file
464
torba/torba/client/baseaccount.py
Normal file
|
@ -0,0 +1,464 @@
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
import random
|
||||||
|
import typing
|
||||||
|
from typing import Dict, Tuple, Type, Optional, Any, List
|
||||||
|
|
||||||
|
from torba.client.mnemonic import Mnemonic
|
||||||
|
from torba.client.bip32 import PrivateKey, PubKey, from_extended_key_string
|
||||||
|
from torba.client.hash import aes_encrypt, aes_decrypt, sha256
|
||||||
|
from torba.client.constants import COIN
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from torba.client import baseledger, wallet as basewallet
|
||||||
|
|
||||||
|
|
||||||
|
class AddressManager:
|
||||||
|
|
||||||
|
name: str
|
||||||
|
|
||||||
|
__slots__ = 'account', 'public_key', 'chain_number', 'address_generator_lock'
|
||||||
|
|
||||||
|
def __init__(self, account, public_key, chain_number):
|
||||||
|
self.account = account
|
||||||
|
self.public_key = public_key
|
||||||
|
self.chain_number = chain_number
|
||||||
|
self.address_generator_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, account: 'BaseAccount', d: dict) \
|
||||||
|
-> Tuple['AddressManager', 'AddressManager']:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def to_dict(cls, receiving: 'AddressManager', change: 'AddressManager') -> Dict:
|
||||||
|
d: Dict[str, Any] = {'name': cls.name}
|
||||||
|
receiving_dict = receiving.to_dict_instance()
|
||||||
|
if receiving_dict:
|
||||||
|
d['receiving'] = receiving_dict
|
||||||
|
change_dict = change.to_dict_instance()
|
||||||
|
if change_dict:
|
||||||
|
d['change'] = change_dict
|
||||||
|
return d
|
||||||
|
|
||||||
|
def apply(self, d: dict):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def to_dict_instance(self) -> Optional[dict]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _query_addresses(self, **constraints):
|
||||||
|
return self.account.ledger.db.get_addresses(
|
||||||
|
account=self.account,
|
||||||
|
chain=self.chain_number,
|
||||||
|
**constraints
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_private_key(self, index: int) -> PrivateKey:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_public_key(self, index: int) -> PubKey:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_max_gap(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def ensure_address_gap(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_address_records(self, only_usable: bool = False, **constraints):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_addresses(self, only_usable: bool = False, **constraints) -> List[str]:
|
||||||
|
records = await self.get_address_records(only_usable=only_usable, **constraints)
|
||||||
|
return [r['address'] for r in records]
|
||||||
|
|
||||||
|
async def get_or_create_usable_address(self) -> str:
|
||||||
|
addresses = await self.get_addresses(only_usable=True, limit=10)
|
||||||
|
if addresses:
|
||||||
|
return random.choice(addresses)
|
||||||
|
addresses = await self.ensure_address_gap()
|
||||||
|
return addresses[0]
|
||||||
|
|
||||||
|
|
||||||
|
class HierarchicalDeterministic(AddressManager):
|
||||||
|
""" Implements simple version of Bitcoin Hierarchical Deterministic key management. """
|
||||||
|
|
||||||
|
name = "deterministic-chain"
|
||||||
|
|
||||||
|
__slots__ = 'gap', 'maximum_uses_per_address'
|
||||||
|
|
||||||
|
def __init__(self, account: 'BaseAccount', chain: int, gap: int, maximum_uses_per_address: int) -> None:
|
||||||
|
super().__init__(account, account.public_key.child(chain), chain)
|
||||||
|
self.gap = gap
|
||||||
|
self.maximum_uses_per_address = maximum_uses_per_address
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, account: 'BaseAccount', d: dict) -> Tuple[AddressManager, AddressManager]:
|
||||||
|
return (
|
||||||
|
cls(account, 0, **d.get('receiving', {'gap': 20, 'maximum_uses_per_address': 1})),
|
||||||
|
cls(account, 1, **d.get('change', {'gap': 6, 'maximum_uses_per_address': 1}))
|
||||||
|
)
|
||||||
|
|
||||||
|
def apply(self, d: dict):
|
||||||
|
self.gap = d.get('gap', self.gap)
|
||||||
|
self.maximum_uses_per_address = d.get('maximum_uses_per_address', self.maximum_uses_per_address)
|
||||||
|
|
||||||
|
def to_dict_instance(self):
|
||||||
|
return {'gap': self.gap, 'maximum_uses_per_address': self.maximum_uses_per_address}
|
||||||
|
|
||||||
|
def get_private_key(self, index: int) -> PrivateKey:
|
||||||
|
return self.account.private_key.child(self.chain_number).child(index)
|
||||||
|
|
||||||
|
def get_public_key(self, index: int) -> PubKey:
|
||||||
|
return self.account.public_key.child(self.chain_number).child(index)
|
||||||
|
|
||||||
|
async def get_max_gap(self) -> int:
|
||||||
|
addresses = await self._query_addresses(order_by="position ASC")
|
||||||
|
max_gap = 0
|
||||||
|
current_gap = 0
|
||||||
|
for address in addresses:
|
||||||
|
if address['used_times'] == 0:
|
||||||
|
current_gap += 1
|
||||||
|
else:
|
||||||
|
max_gap = max(max_gap, current_gap)
|
||||||
|
current_gap = 0
|
||||||
|
return max_gap
|
||||||
|
|
||||||
|
async def ensure_address_gap(self) -> List[str]:
|
||||||
|
async with self.address_generator_lock:
|
||||||
|
addresses = await self._query_addresses(limit=self.gap, order_by="position DESC")
|
||||||
|
|
||||||
|
existing_gap = 0
|
||||||
|
for address in addresses:
|
||||||
|
if address['used_times'] == 0:
|
||||||
|
existing_gap += 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
if existing_gap == self.gap:
|
||||||
|
return []
|
||||||
|
|
||||||
|
start = addresses[0]['position']+1 if addresses else 0
|
||||||
|
end = start + (self.gap - existing_gap)
|
||||||
|
new_keys = await self._generate_keys(start, end-1)
|
||||||
|
await self.account.ledger.announce_addresses(self, new_keys)
|
||||||
|
return new_keys
|
||||||
|
|
||||||
|
async def _generate_keys(self, start: int, end: int) -> List[str]:
|
||||||
|
if not self.address_generator_lock.locked():
|
||||||
|
raise RuntimeError('Should not be called outside of address_generator_lock.')
|
||||||
|
keys = [(index, self.public_key.child(index)) for index in range(start, end+1)]
|
||||||
|
await self.account.ledger.db.add_keys(self.account, self.chain_number, keys)
|
||||||
|
return [key[1].address for key in keys]
|
||||||
|
|
||||||
|
def get_address_records(self, only_usable: bool = False, **constraints):
|
||||||
|
if only_usable:
|
||||||
|
constraints['used_times__lt'] = self.maximum_uses_per_address
|
||||||
|
if 'order_by' not in constraints:
|
||||||
|
constraints['order_by'] = "used_times ASC, position ASC"
|
||||||
|
return self._query_addresses(**constraints)
|
||||||
|
|
||||||
|
|
||||||
|
class SingleKey(AddressManager):
|
||||||
|
""" Single Key address manager always returns the same address for all operations. """
|
||||||
|
|
||||||
|
name = "single-address"
|
||||||
|
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, account: 'BaseAccount', d: dict)\
|
||||||
|
-> Tuple[AddressManager, AddressManager]:
|
||||||
|
same_address_manager = cls(account, account.public_key, 0)
|
||||||
|
return same_address_manager, same_address_manager
|
||||||
|
|
||||||
|
def to_dict_instance(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_private_key(self, index: int) -> PrivateKey:
|
||||||
|
return self.account.private_key
|
||||||
|
|
||||||
|
def get_public_key(self, index: int) -> PubKey:
|
||||||
|
return self.account.public_key
|
||||||
|
|
||||||
|
async def get_max_gap(self) -> int:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def ensure_address_gap(self) -> List[str]:
|
||||||
|
async with self.address_generator_lock:
|
||||||
|
exists = await self.get_address_records()
|
||||||
|
if not exists:
|
||||||
|
await self.account.ledger.db.add_keys(
|
||||||
|
self.account, self.chain_number, [(0, self.public_key)]
|
||||||
|
)
|
||||||
|
new_keys = [self.public_key.address]
|
||||||
|
await self.account.ledger.announce_addresses(self, new_keys)
|
||||||
|
return new_keys
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_address_records(self, only_usable: bool = False, **constraints):
|
||||||
|
return self._query_addresses(**constraints)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAccount:
|
||||||
|
|
||||||
|
mnemonic_class = Mnemonic
|
||||||
|
private_key_class = PrivateKey
|
||||||
|
public_key_class = PubKey
|
||||||
|
address_generators: Dict[str, Type[AddressManager]] = {
|
||||||
|
SingleKey.name: SingleKey,
|
||||||
|
HierarchicalDeterministic.name: HierarchicalDeterministic,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet', name: str,
|
||||||
|
seed: str, private_key_string: str, encrypted: bool,
|
||||||
|
private_key: Optional[PrivateKey], public_key: PubKey,
|
||||||
|
address_generator: dict, modified_on: float) -> None:
|
||||||
|
self.ledger = ledger
|
||||||
|
self.wallet = wallet
|
||||||
|
self.id = public_key.address
|
||||||
|
self.name = name
|
||||||
|
self.seed = seed
|
||||||
|
self.modified_on = modified_on
|
||||||
|
self.private_key_string = private_key_string
|
||||||
|
self.password: Optional[str] = None
|
||||||
|
self.private_key_encryption_init_vector: Optional[bytes] = None
|
||||||
|
self.seed_encryption_init_vector: Optional[bytes] = None
|
||||||
|
|
||||||
|
self.encrypted = encrypted
|
||||||
|
self.serialize_encrypted = encrypted
|
||||||
|
self.private_key = private_key
|
||||||
|
self.public_key = public_key
|
||||||
|
generator_name = address_generator.get('name', HierarchicalDeterministic.name)
|
||||||
|
self.address_generator = self.address_generators[generator_name]
|
||||||
|
self.receiving, self.change = self.address_generator.from_dict(self, address_generator)
|
||||||
|
self.address_managers = {am.chain_number: am for am in {self.receiving, self.change}}
|
||||||
|
ledger.add_account(self)
|
||||||
|
wallet.add_account(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate(cls, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet',
|
||||||
|
name: str = None, address_generator: dict = None):
|
||||||
|
return cls.from_dict(ledger, wallet, {
|
||||||
|
'name': name,
|
||||||
|
'seed': cls.mnemonic_class().make_seed(),
|
||||||
|
'address_generator': address_generator or {}
|
||||||
|
})
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_private_key_from_seed(cls, ledger: 'baseledger.BaseLedger', seed: str, password: str):
|
||||||
|
return cls.private_key_class.from_seed(
|
||||||
|
ledger, cls.mnemonic_class.mnemonic_to_seed(seed, password)
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def keys_from_dict(cls, ledger: 'baseledger.BaseLedger', d: dict) \
|
||||||
|
-> Tuple[str, Optional[PrivateKey], PubKey]:
|
||||||
|
seed = d.get('seed', '')
|
||||||
|
private_key_string = d.get('private_key', '')
|
||||||
|
private_key = None
|
||||||
|
public_key = None
|
||||||
|
encrypted = d.get('encrypted', False)
|
||||||
|
if not encrypted:
|
||||||
|
if seed:
|
||||||
|
private_key = cls.get_private_key_from_seed(ledger, seed, '')
|
||||||
|
public_key = private_key.public_key
|
||||||
|
elif private_key_string:
|
||||||
|
private_key = from_extended_key_string(ledger, private_key_string)
|
||||||
|
public_key = private_key.public_key
|
||||||
|
if public_key is None:
|
||||||
|
public_key = from_extended_key_string(ledger, d['public_key'])
|
||||||
|
return seed, private_key, public_key
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet', d: dict):
|
||||||
|
seed, private_key, public_key = cls.keys_from_dict(ledger, d)
|
||||||
|
name = d.get('name')
|
||||||
|
if not name:
|
||||||
|
name = 'Account #{}'.format(public_key.address)
|
||||||
|
return cls(
|
||||||
|
ledger=ledger,
|
||||||
|
wallet=wallet,
|
||||||
|
name=name,
|
||||||
|
seed=seed,
|
||||||
|
private_key_string=d.get('private_key', ''),
|
||||||
|
encrypted=d.get('encrypted', False),
|
||||||
|
private_key=private_key,
|
||||||
|
public_key=public_key,
|
||||||
|
address_generator=d.get('address_generator', {}),
|
||||||
|
modified_on=d.get('modified_on', time.time())
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
private_key_string, seed = self.private_key_string, self.seed
|
||||||
|
if not self.encrypted and self.private_key:
|
||||||
|
private_key_string = self.private_key.extended_key_string()
|
||||||
|
if not self.encrypted and self.serialize_encrypted:
|
||||||
|
assert None not in [self.seed_encryption_init_vector, self.private_key_encryption_init_vector]
|
||||||
|
private_key_string = aes_encrypt(
|
||||||
|
self.password, private_key_string, self.private_key_encryption_init_vector
|
||||||
|
)
|
||||||
|
seed = aes_encrypt(self.password, self.seed, self.seed_encryption_init_vector)
|
||||||
|
return {
|
||||||
|
'ledger': self.ledger.get_id(),
|
||||||
|
'name': self.name,
|
||||||
|
'seed': seed,
|
||||||
|
'encrypted': self.serialize_encrypted,
|
||||||
|
'private_key': private_key_string,
|
||||||
|
'public_key': self.public_key.extended_key_string(),
|
||||||
|
'address_generator': self.address_generator.to_dict(self.receiving, self.change),
|
||||||
|
'modified_on': self.modified_on
|
||||||
|
}
|
||||||
|
|
||||||
|
def apply(self, d: dict):
|
||||||
|
if d.get('modified_on', 0) > self.modified_on:
|
||||||
|
self.name = d['name']
|
||||||
|
self.modified_on = d.get('modified_on', time.time())
|
||||||
|
assert self.address_generator.name == d['address_generator']['name']
|
||||||
|
for chain_name in ('change', 'receiving'):
|
||||||
|
if chain_name in d['address_generator']:
|
||||||
|
chain_object = getattr(self, chain_name)
|
||||||
|
chain_object.apply(d['address_generator'][chain_name])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hash(self) -> bytes:
|
||||||
|
return sha256(json.dumps(self.to_dict()).encode())
|
||||||
|
|
||||||
|
async def get_details(self, show_seed=False, **kwargs):
|
||||||
|
satoshis = await self.get_balance(**kwargs)
|
||||||
|
details = {
|
||||||
|
'id': self.id,
|
||||||
|
'name': self.name,
|
||||||
|
'coins': round(satoshis/COIN, 2),
|
||||||
|
'satoshis': satoshis,
|
||||||
|
'encrypted': self.encrypted,
|
||||||
|
'public_key': self.public_key.extended_key_string(),
|
||||||
|
'address_generator': self.address_generator.to_dict(self.receiving, self.change)
|
||||||
|
}
|
||||||
|
if show_seed:
|
||||||
|
details['seed'] = self.seed
|
||||||
|
return details
|
||||||
|
|
||||||
|
def decrypt(self, password: str) -> None:
|
||||||
|
assert self.encrypted, "Key is not encrypted."
|
||||||
|
try:
|
||||||
|
seed, seed_iv = aes_decrypt(password, self.seed)
|
||||||
|
pk_string, pk_iv = aes_decrypt(password, self.private_key_string)
|
||||||
|
except ValueError: # failed to remove padding, password is wrong
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
Mnemonic().mnemonic_decode(seed)
|
||||||
|
except IndexError: # failed to decode the seed, this either means it decrypted and is invalid
|
||||||
|
# or that we hit an edge case where an incorrect password gave valid padding
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
private_key = from_extended_key_string(
|
||||||
|
self.ledger, pk_string
|
||||||
|
)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return
|
||||||
|
self.seed = seed
|
||||||
|
self.seed_encryption_init_vector = seed_iv
|
||||||
|
self.private_key = private_key
|
||||||
|
self.private_key_encryption_init_vector = pk_iv
|
||||||
|
self.password = password
|
||||||
|
self.encrypted = False
|
||||||
|
|
||||||
|
def encrypt(self, password: str) -> None:
|
||||||
|
assert not self.encrypted, "Key is already encrypted."
|
||||||
|
assert isinstance(self.private_key, PrivateKey)
|
||||||
|
|
||||||
|
self.seed = aes_encrypt(password, self.seed, self.seed_encryption_init_vector)
|
||||||
|
self.private_key_string = aes_encrypt(
|
||||||
|
password, self.private_key.extended_key_string(), self.private_key_encryption_init_vector
|
||||||
|
)
|
||||||
|
self.private_key = None
|
||||||
|
self.password = None
|
||||||
|
self.encrypted = True
|
||||||
|
|
||||||
|
async def ensure_address_gap(self):
|
||||||
|
addresses = []
|
||||||
|
for address_manager in self.address_managers.values():
|
||||||
|
new_addresses = await address_manager.ensure_address_gap()
|
||||||
|
addresses.extend(new_addresses)
|
||||||
|
return addresses
|
||||||
|
|
||||||
|
async def get_addresses(self, **constraints) -> List[str]:
|
||||||
|
rows = await self.ledger.db.select_addresses('address', account=self, **constraints)
|
||||||
|
return [r[0] for r in rows]
|
||||||
|
|
||||||
|
def get_address_records(self, **constraints):
|
||||||
|
return self.ledger.db.get_addresses(account=self, **constraints)
|
||||||
|
|
||||||
|
def get_address_count(self, **constraints):
|
||||||
|
return self.ledger.db.get_address_count(account=self, **constraints)
|
||||||
|
|
||||||
|
def get_private_key(self, chain: int, index: int) -> PrivateKey:
|
||||||
|
assert not self.encrypted, "Cannot get private key on encrypted wallet account."
|
||||||
|
return self.address_managers[chain].get_private_key(index)
|
||||||
|
|
||||||
|
def get_public_key(self, chain: int, index: int) -> PubKey:
|
||||||
|
return self.address_managers[chain].get_public_key(index)
|
||||||
|
|
||||||
|
def get_balance(self, confirmations: int = 0, **constraints):
|
||||||
|
if confirmations > 0:
|
||||||
|
height = self.ledger.headers.height - (confirmations-1)
|
||||||
|
constraints.update({'height__lte': height, 'height__gt': 0})
|
||||||
|
return self.ledger.db.get_balance(account=self, **constraints)
|
||||||
|
|
||||||
|
async def get_max_gap(self):
|
||||||
|
change_gap = await self.change.get_max_gap()
|
||||||
|
receiving_gap = await self.receiving.get_max_gap()
|
||||||
|
return {
|
||||||
|
'max_change_gap': change_gap,
|
||||||
|
'max_receiving_gap': receiving_gap,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_utxos(self, **constraints):
|
||||||
|
return self.ledger.db.get_utxos(account=self, **constraints)
|
||||||
|
|
||||||
|
def get_utxo_count(self, **constraints):
|
||||||
|
return self.ledger.db.get_utxo_count(account=self, **constraints)
|
||||||
|
|
||||||
|
def get_transactions(self, **constraints):
|
||||||
|
return self.ledger.db.get_transactions(account=self, **constraints)
|
||||||
|
|
||||||
|
def get_transaction_count(self, **constraints):
|
||||||
|
return self.ledger.db.get_transaction_count(account=self, **constraints)
|
||||||
|
|
||||||
|
async def fund(self, to_account, amount=None, everything=False,
|
||||||
|
outputs=1, broadcast=False, **constraints):
|
||||||
|
assert self.ledger == to_account.ledger, 'Can only transfer between accounts of the same ledger.'
|
||||||
|
tx_class = self.ledger.transaction_class
|
||||||
|
if everything:
|
||||||
|
utxos = await self.get_utxos(**constraints)
|
||||||
|
await self.ledger.reserve_outputs(utxos)
|
||||||
|
tx = await tx_class.create(
|
||||||
|
inputs=[tx_class.input_class.spend(txo) for txo in utxos],
|
||||||
|
outputs=[],
|
||||||
|
funding_accounts=[self],
|
||||||
|
change_account=to_account
|
||||||
|
)
|
||||||
|
elif amount > 0:
|
||||||
|
to_address = await to_account.change.get_or_create_usable_address()
|
||||||
|
to_hash160 = to_account.ledger.address_to_hash160(to_address)
|
||||||
|
tx = await tx_class.create(
|
||||||
|
inputs=[],
|
||||||
|
outputs=[
|
||||||
|
tx_class.output_class.pay_pubkey_hash(amount//outputs, to_hash160)
|
||||||
|
for _ in range(outputs)
|
||||||
|
],
|
||||||
|
funding_accounts=[self],
|
||||||
|
change_account=self
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError('An amount is required.')
|
||||||
|
|
||||||
|
if broadcast:
|
||||||
|
await self.ledger.broadcast(tx)
|
||||||
|
else:
|
||||||
|
await self.ledger.release_tx(tx)
|
||||||
|
|
||||||
|
return tx
|
549
torba/torba/client/basedatabase.py
Normal file
549
torba/torba/client/basedatabase.py
Normal file
|
@ -0,0 +1,549 @@
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
from asyncio import wrap_future
|
||||||
|
from concurrent.futures.thread import ThreadPoolExecutor
|
||||||
|
|
||||||
|
from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
|
from torba.client.basetransaction import BaseTransaction
|
||||||
|
from torba.client.baseaccount import BaseAccount
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AIOSQLite:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# has to be single threaded as there is no mapping of thread:connection
|
||||||
|
self.executor = ThreadPoolExecutor(max_workers=1)
|
||||||
|
self.connection: sqlite3.Connection = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def connect(cls, path: Union[bytes, str], *args, **kwargs):
|
||||||
|
db = cls()
|
||||||
|
db.connection = await wrap_future(db.executor.submit(sqlite3.connect, path, *args, **kwargs))
|
||||||
|
return db
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
def __close(conn):
|
||||||
|
self.executor.submit(conn.close)
|
||||||
|
self.executor.shutdown(wait=True)
|
||||||
|
conn = self.connection
|
||||||
|
if not conn:
|
||||||
|
return
|
||||||
|
self.connection = None
|
||||||
|
return asyncio.get_event_loop_policy().get_event_loop().call_later(0.01, __close, conn)
|
||||||
|
|
||||||
|
def executemany(self, sql: str, params: Iterable):
|
||||||
|
def __executemany_in_a_transaction(conn: sqlite3.Connection, *args, **kwargs):
|
||||||
|
return conn.executemany(*args, **kwargs)
|
||||||
|
return self.run(__executemany_in_a_transaction, sql, params)
|
||||||
|
|
||||||
|
def executescript(self, script: str) -> Awaitable:
|
||||||
|
return wrap_future(self.executor.submit(self.connection.executescript, script))
|
||||||
|
|
||||||
|
def execute_fetchall(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]:
|
||||||
|
parameters = parameters if parameters is not None else []
|
||||||
|
def __fetchall(conn: sqlite3.Connection, *args, **kwargs):
|
||||||
|
return conn.execute(*args, **kwargs).fetchall()
|
||||||
|
return wrap_future(self.executor.submit(__fetchall, self.connection, sql, parameters))
|
||||||
|
|
||||||
|
def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]:
|
||||||
|
parameters = parameters if parameters is not None else []
|
||||||
|
return self.run(lambda conn, sql, parameters: conn.execute(sql, parameters), sql, parameters)
|
||||||
|
|
||||||
|
def run(self, fun, *args, **kwargs) -> Awaitable:
|
||||||
|
return wrap_future(self.executor.submit(self.__run_transaction, fun, *args, **kwargs))
|
||||||
|
|
||||||
|
def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs):
|
||||||
|
self.connection.execute('begin')
|
||||||
|
try:
|
||||||
|
result = fun(self.connection, *args, **kwargs) # type: ignore
|
||||||
|
self.connection.commit()
|
||||||
|
return result
|
||||||
|
except (Exception, OSError): # as e:
|
||||||
|
#log.exception('Error running transaction:', exc_info=e)
|
||||||
|
self.connection.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
def run_with_foreign_keys_disabled(self, fun, *args, **kwargs) -> Awaitable:
|
||||||
|
return wrap_future(
|
||||||
|
self.executor.submit(self.__run_transaction_with_foreign_keys_disabled, fun, *args, **kwargs)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __run_transaction_with_foreign_keys_disabled(self,
|
||||||
|
fun: Callable[[sqlite3.Connection, Any, Any], Any],
|
||||||
|
*args, **kwargs):
|
||||||
|
foreign_keys_enabled, = self.connection.execute("pragma foreign_keys").fetchone()
|
||||||
|
if not foreign_keys_enabled:
|
||||||
|
raise sqlite3.IntegrityError("foreign keys are disabled, use `AIOSQLite.run` instead")
|
||||||
|
try:
|
||||||
|
self.connection.execute('pragma foreign_keys=off')
|
||||||
|
return self.__run_transaction(fun, *args, **kwargs)
|
||||||
|
finally:
|
||||||
|
self.connection.execute('pragma foreign_keys=on')
|
||||||
|
|
||||||
|
|
||||||
|
def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
|
||||||
|
sql, values = [], {}
|
||||||
|
for key, constraint in constraints.items():
|
||||||
|
tag = '0'
|
||||||
|
if '#' in key:
|
||||||
|
key, tag = key[:key.index('#')], key[key.index('#')+1:]
|
||||||
|
col, op, key = key, '=', key.replace('.', '_')
|
||||||
|
if key.startswith('$'):
|
||||||
|
values[key] = constraint
|
||||||
|
continue
|
||||||
|
elif key.endswith('__not'):
|
||||||
|
col, op = col[:-len('__not')], '!='
|
||||||
|
elif key.endswith('__is_null'):
|
||||||
|
col = col[:-len('__is_null')]
|
||||||
|
sql.append(f'{col} IS NULL')
|
||||||
|
continue
|
||||||
|
elif key.endswith('__is_not_null'):
|
||||||
|
col = col[:-len('__is_not_null')]
|
||||||
|
sql.append(f'{col} IS NOT NULL')
|
||||||
|
continue
|
||||||
|
elif key.endswith('__lt'):
|
||||||
|
col, op = col[:-len('__lt')], '<'
|
||||||
|
elif key.endswith('__lte'):
|
||||||
|
col, op = col[:-len('__lte')], '<='
|
||||||
|
elif key.endswith('__gt'):
|
||||||
|
col, op = col[:-len('__gt')], '>'
|
||||||
|
elif key.endswith('__gte'):
|
||||||
|
col, op = col[:-len('__gte')], '>='
|
||||||
|
elif key.endswith('__like'):
|
||||||
|
col, op = col[:-len('__like')], 'LIKE'
|
||||||
|
elif key.endswith('__not_like'):
|
||||||
|
col, op = col[:-len('__not_like')], 'NOT LIKE'
|
||||||
|
elif key.endswith('__in') or key.endswith('__not_in'):
|
||||||
|
if key.endswith('__in'):
|
||||||
|
col, op = col[:-len('__in')], 'IN'
|
||||||
|
else:
|
||||||
|
col, op = col[:-len('__not_in')], 'NOT IN'
|
||||||
|
if constraint:
|
||||||
|
if isinstance(constraint, (list, set, tuple)):
|
||||||
|
keys = []
|
||||||
|
for i, val in enumerate(constraint):
|
||||||
|
keys.append(f':{key}{tag}_{i}')
|
||||||
|
values[f'{key}{tag}_{i}'] = val
|
||||||
|
sql.append(f'{col} {op} ({", ".join(keys)})')
|
||||||
|
elif isinstance(constraint, str):
|
||||||
|
sql.append(f'{col} {op} ({constraint})')
|
||||||
|
else:
|
||||||
|
raise ValueError(f"{col} requires a list, set or string as constraint value.")
|
||||||
|
continue
|
||||||
|
elif key.endswith('__any') or key.endswith('__or'):
|
||||||
|
where, subvalues = constraints_to_sql(constraint, ' OR ', key+tag+'_')
|
||||||
|
sql.append(f'({where})')
|
||||||
|
values.update(subvalues)
|
||||||
|
continue
|
||||||
|
elif key.endswith('__and'):
|
||||||
|
where, subvalues = constraints_to_sql(constraint, ' AND ', key+tag+'_')
|
||||||
|
sql.append(f'({where})')
|
||||||
|
values.update(subvalues)
|
||||||
|
continue
|
||||||
|
sql.append(f'{col} {op} :{prepend_key}{key}{tag}')
|
||||||
|
values[prepend_key+key+tag] = constraint
|
||||||
|
return joiner.join(sql) if sql else '', values
|
||||||
|
|
||||||
|
|
||||||
|
def query(select, **constraints):
|
||||||
|
sql = [select]
|
||||||
|
limit = constraints.pop('limit', None)
|
||||||
|
offset = constraints.pop('offset', None)
|
||||||
|
order_by = constraints.pop('order_by', None)
|
||||||
|
|
||||||
|
constraints.pop('my_account', None)
|
||||||
|
account = constraints.pop('account', None)
|
||||||
|
if account is not None:
|
||||||
|
if not isinstance(account, list):
|
||||||
|
account = [account]
|
||||||
|
constraints['account__in'] = [
|
||||||
|
(a.public_key.address if isinstance(a, BaseAccount) else a) for a in account
|
||||||
|
]
|
||||||
|
|
||||||
|
where, values = constraints_to_sql(constraints)
|
||||||
|
if where:
|
||||||
|
sql.append('WHERE')
|
||||||
|
sql.append(where)
|
||||||
|
|
||||||
|
if order_by:
|
||||||
|
sql.append('ORDER BY')
|
||||||
|
if isinstance(order_by, str):
|
||||||
|
sql.append(order_by)
|
||||||
|
elif isinstance(order_by, list):
|
||||||
|
sql.append(', '.join(order_by))
|
||||||
|
else:
|
||||||
|
raise ValueError("order_by must be string or list")
|
||||||
|
|
||||||
|
if limit is not None:
|
||||||
|
sql.append('LIMIT {}'.format(limit))
|
||||||
|
|
||||||
|
if offset is not None:
|
||||||
|
sql.append('OFFSET {}'.format(offset))
|
||||||
|
|
||||||
|
return ' '.join(sql), values
|
||||||
|
|
||||||
|
|
||||||
|
def rows_to_dict(rows, fields):
|
||||||
|
if rows:
|
||||||
|
return [dict(zip(fields, r)) for r in rows]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class SQLiteMixin:
|
||||||
|
|
||||||
|
CREATE_TABLES_QUERY: str
|
||||||
|
|
||||||
|
def __init__(self, path):
|
||||||
|
self._db_path = path
|
||||||
|
self.db: AIOSQLite = None
|
||||||
|
self.ledger = None
|
||||||
|
|
||||||
|
async def open(self):
|
||||||
|
log.info("connecting to database: %s", self._db_path)
|
||||||
|
self.db = await AIOSQLite.connect(self._db_path)
|
||||||
|
await self.db.executescript(self.CREATE_TABLES_QUERY)
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
await self.db.close()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False) -> Tuple[str, List]:
|
||||||
|
columns, values = [], []
|
||||||
|
for column, value in data.items():
|
||||||
|
columns.append(column)
|
||||||
|
values.append(value)
|
||||||
|
or_ignore = ""
|
||||||
|
if ignore_duplicate:
|
||||||
|
or_ignore = " OR IGNORE"
|
||||||
|
sql = "INSERT{} INTO {} ({}) VALUES ({})".format(
|
||||||
|
or_ignore, table, ', '.join(columns), ', '.join(['?'] * len(values))
|
||||||
|
)
|
||||||
|
return sql, values
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _update_sql(table: str, data: dict, where: str,
|
||||||
|
constraints: Union[list, tuple]) -> Tuple[str, list]:
|
||||||
|
columns, values = [], []
|
||||||
|
for column, value in data.items():
|
||||||
|
columns.append("{} = ?".format(column))
|
||||||
|
values.append(value)
|
||||||
|
values.extend(constraints)
|
||||||
|
sql = "UPDATE {} SET {} WHERE {}".format(
|
||||||
|
table, ', '.join(columns), where
|
||||||
|
)
|
||||||
|
return sql, values
|
||||||
|
|
||||||
|
|
||||||
|
class BaseDatabase(SQLiteMixin):
|
||||||
|
|
||||||
|
PRAGMAS = """
|
||||||
|
pragma journal_mode=WAL;
|
||||||
|
"""
|
||||||
|
|
||||||
|
CREATE_PUBKEY_ADDRESS_TABLE = """
|
||||||
|
create table if not exists pubkey_address (
|
||||||
|
address text primary key,
|
||||||
|
account text not null,
|
||||||
|
chain integer not null,
|
||||||
|
position integer not null,
|
||||||
|
pubkey blob not null,
|
||||||
|
history text,
|
||||||
|
used_times integer not null default 0
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
CREATE_PUBKEY_ADDRESS_INDEX = """
|
||||||
|
create index if not exists pubkey_address_account_idx on pubkey_address (account);
|
||||||
|
"""
|
||||||
|
|
||||||
|
CREATE_TX_TABLE = """
|
||||||
|
create table if not exists tx (
|
||||||
|
txid text primary key,
|
||||||
|
raw blob not null,
|
||||||
|
height integer not null,
|
||||||
|
position integer not null,
|
||||||
|
is_verified boolean not null default 0
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
|
||||||
|
CREATE_TXO_TABLE = """
|
||||||
|
create table if not exists txo (
|
||||||
|
txid text references tx,
|
||||||
|
txoid text primary key,
|
||||||
|
address text references pubkey_address,
|
||||||
|
position integer not null,
|
||||||
|
amount integer not null,
|
||||||
|
script blob not null,
|
||||||
|
is_reserved boolean not null default 0
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
CREATE_TXO_INDEX = """
|
||||||
|
create index if not exists txo_address_idx on txo (address);
|
||||||
|
"""
|
||||||
|
|
||||||
|
CREATE_TXI_TABLE = """
|
||||||
|
create table if not exists txi (
|
||||||
|
txid text references tx,
|
||||||
|
txoid text references txo,
|
||||||
|
address text references pubkey_address
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
CREATE_TXI_INDEX = """
|
||||||
|
create index if not exists txi_address_idx on txi (address);
|
||||||
|
create index if not exists txi_txoid_idx on txi (txoid);
|
||||||
|
"""
|
||||||
|
|
||||||
|
CREATE_TABLES_QUERY = (
|
||||||
|
PRAGMAS +
|
||||||
|
CREATE_TX_TABLE +
|
||||||
|
CREATE_PUBKEY_ADDRESS_TABLE +
|
||||||
|
CREATE_PUBKEY_ADDRESS_INDEX +
|
||||||
|
CREATE_TXO_TABLE +
|
||||||
|
CREATE_TXO_INDEX +
|
||||||
|
CREATE_TXI_TABLE +
|
||||||
|
CREATE_TXI_INDEX
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def txo_to_row(tx, address, txo):
|
||||||
|
return {
|
||||||
|
'txid': tx.id,
|
||||||
|
'txoid': txo.id,
|
||||||
|
'address': address,
|
||||||
|
'position': txo.position,
|
||||||
|
'amount': txo.amount,
|
||||||
|
'script': sqlite3.Binary(txo.script.source)
|
||||||
|
}
|
||||||
|
|
||||||
|
async def insert_transaction(self, tx):
|
||||||
|
await self.db.execute(*self._insert_sql('tx', {
|
||||||
|
'txid': tx.id,
|
||||||
|
'raw': sqlite3.Binary(tx.raw),
|
||||||
|
'height': tx.height,
|
||||||
|
'position': tx.position,
|
||||||
|
'is_verified': tx.is_verified
|
||||||
|
}))
|
||||||
|
|
||||||
|
async def update_transaction(self, tx):
|
||||||
|
await self.db.execute(*self._update_sql("tx", {
|
||||||
|
'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
|
||||||
|
}, 'txid = ?', (tx.id,)))
|
||||||
|
|
||||||
|
def save_transaction_io(self, tx: BaseTransaction, address, txhash, history):
|
||||||
|
|
||||||
|
def _transaction(conn: sqlite3.Connection, tx: BaseTransaction, address, txhash, history):
|
||||||
|
|
||||||
|
for txo in tx.outputs:
|
||||||
|
if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
|
||||||
|
conn.execute(*self._insert_sql(
|
||||||
|
"txo", self.txo_to_row(tx, address, txo), ignore_duplicate=True
|
||||||
|
))
|
||||||
|
elif txo.script.is_pay_script_hash:
|
||||||
|
# TODO: implement script hash payments
|
||||||
|
log.warning('Database.save_transaction_io: pay script hash is not implemented!')
|
||||||
|
|
||||||
|
for txi in tx.inputs:
|
||||||
|
if txi.txo_ref.txo is not None:
|
||||||
|
txo = txi.txo_ref.txo
|
||||||
|
if txo.get_address(self.ledger) == address:
|
||||||
|
conn.execute(*self._insert_sql("txi", {
|
||||||
|
'txid': tx.id,
|
||||||
|
'txoid': txo.id,
|
||||||
|
'address': address,
|
||||||
|
}, ignore_duplicate=True))
|
||||||
|
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
|
||||||
|
(history, history.count(':')//2, address)
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.db.run(_transaction, tx, address, txhash, history)
|
||||||
|
|
||||||
|
async def reserve_outputs(self, txos, is_reserved=True):
|
||||||
|
txoids = ((is_reserved, txo.id) for txo in txos)
|
||||||
|
await self.db.executemany("UPDATE txo SET is_reserved = ? WHERE txoid = ?", txoids)
|
||||||
|
|
||||||
|
async def release_outputs(self, txos):
|
||||||
|
await self.reserve_outputs(txos, is_reserved=False)
|
||||||
|
|
||||||
|
async def rewind_blockchain(self, above_height): # pylint: disable=no-self-use
|
||||||
|
# TODO:
|
||||||
|
# 1. delete transactions above_height
|
||||||
|
# 2. update address histories removing deleted TXs
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def select_transactions(self, cols, account=None, **constraints):
|
||||||
|
if 'txid' not in constraints and account is not None:
|
||||||
|
constraints['$account'] = account.public_key.address
|
||||||
|
constraints['txid__in'] = """
|
||||||
|
SELECT txo.txid FROM txo
|
||||||
|
JOIN pubkey_address USING (address) WHERE pubkey_address.account = :$account
|
||||||
|
UNION
|
||||||
|
SELECT txi.txid FROM txi
|
||||||
|
JOIN pubkey_address USING (address) WHERE pubkey_address.account = :$account
|
||||||
|
"""
|
||||||
|
return await self.db.execute_fetchall(
|
||||||
|
*query("SELECT {} FROM tx".format(cols), **constraints)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_transactions(self, my_account=None, **constraints):
|
||||||
|
my_account = my_account or constraints.get('account', None)
|
||||||
|
|
||||||
|
tx_rows = await self.select_transactions(
|
||||||
|
'txid, raw, height, position, is_verified',
|
||||||
|
order_by=["height=0 DESC", "height DESC", "position DESC"],
|
||||||
|
**constraints
|
||||||
|
)
|
||||||
|
|
||||||
|
if not tx_rows:
|
||||||
|
return []
|
||||||
|
|
||||||
|
txids, txs, txi_txoids = [], [], []
|
||||||
|
for row in tx_rows:
|
||||||
|
txids.append(row[0])
|
||||||
|
txs.append(self.ledger.transaction_class(
|
||||||
|
raw=row[1], height=row[2], position=row[3], is_verified=bool(row[4])
|
||||||
|
))
|
||||||
|
for txi in txs[-1].inputs:
|
||||||
|
txi_txoids.append(txi.txo_ref.id)
|
||||||
|
|
||||||
|
annotated_txos = {
|
||||||
|
txo.id: txo for txo in
|
||||||
|
(await self.get_txos(
|
||||||
|
my_account=my_account,
|
||||||
|
txid__in=txids
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
referenced_txos = {
|
||||||
|
txo.id: txo for txo in
|
||||||
|
(await self.get_txos(
|
||||||
|
my_account=my_account,
|
||||||
|
txoid__in=txi_txoids
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
for tx in txs:
|
||||||
|
for txi in tx.inputs:
|
||||||
|
txo = referenced_txos.get(txi.txo_ref.id)
|
||||||
|
if txo:
|
||||||
|
txi.txo_ref = txo.ref
|
||||||
|
for txo in tx.outputs:
|
||||||
|
_txo = annotated_txos.get(txo.id)
|
||||||
|
if _txo:
|
||||||
|
txo.update_annotations(_txo)
|
||||||
|
else:
|
||||||
|
txo.update_annotations(None)
|
||||||
|
|
||||||
|
return txs
|
||||||
|
|
||||||
|
async def get_transaction_count(self, **constraints):
|
||||||
|
constraints.pop('offset', None)
|
||||||
|
constraints.pop('limit', None)
|
||||||
|
constraints.pop('order_by', None)
|
||||||
|
count = await self.select_transactions('count(*)', **constraints)
|
||||||
|
return count[0][0]
|
||||||
|
|
||||||
|
async def get_transaction(self, **constraints):
|
||||||
|
txs = await self.get_transactions(limit=1, **constraints)
|
||||||
|
if txs:
|
||||||
|
return txs[0]
|
||||||
|
|
||||||
|
async def select_txos(self, cols, **constraints):
|
||||||
|
return await self.db.execute_fetchall(*query(
|
||||||
|
"SELECT {} FROM txo"
|
||||||
|
" JOIN pubkey_address USING (address)"
|
||||||
|
" JOIN tx USING (txid)".format(cols), **constraints
|
||||||
|
))
|
||||||
|
|
||||||
|
async def get_txos(self, my_account=None, **constraints):
|
||||||
|
my_account = my_account or constraints.get('account', None)
|
||||||
|
if isinstance(my_account, BaseAccount):
|
||||||
|
my_account = my_account.public_key.address
|
||||||
|
if 'order_by' not in constraints:
|
||||||
|
constraints['order_by'] = ["tx.height=0 DESC", "tx.height DESC", "tx.position DESC"]
|
||||||
|
rows = await self.select_txos(
|
||||||
|
"tx.txid, raw, tx.height, tx.position, tx.is_verified, txo.position, chain, account",
|
||||||
|
**constraints
|
||||||
|
)
|
||||||
|
txos = []
|
||||||
|
txs = {}
|
||||||
|
for row in rows:
|
||||||
|
if row[0] not in txs:
|
||||||
|
txs[row[0]] = self.ledger.transaction_class(
|
||||||
|
row[1], height=row[2], position=row[3], is_verified=row[4]
|
||||||
|
)
|
||||||
|
txo = txs[row[0]].outputs[row[5]]
|
||||||
|
txo.is_change = row[6] == 1
|
||||||
|
txo.is_my_account = row[7] == my_account
|
||||||
|
txos.append(txo)
|
||||||
|
return txos
|
||||||
|
|
||||||
|
async def get_txo_count(self, **constraints):
|
||||||
|
constraints.pop('offset', None)
|
||||||
|
constraints.pop('limit', None)
|
||||||
|
constraints.pop('order_by', None)
|
||||||
|
count = await self.select_txos('count(*)', **constraints)
|
||||||
|
return count[0][0]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def constrain_utxo(constraints):
|
||||||
|
constraints['is_reserved'] = False
|
||||||
|
constraints['txoid__not_in'] = "SELECT txoid FROM txi"
|
||||||
|
|
||||||
|
def get_utxos(self, **constraints):
|
||||||
|
self.constrain_utxo(constraints)
|
||||||
|
return self.get_txos(**constraints)
|
||||||
|
|
||||||
|
def get_utxo_count(self, **constraints):
|
||||||
|
self.constrain_utxo(constraints)
|
||||||
|
return self.get_txo_count(**constraints)
|
||||||
|
|
||||||
|
async def get_balance(self, **constraints):
|
||||||
|
self.constrain_utxo(constraints)
|
||||||
|
balance = await self.select_txos('SUM(amount)', **constraints)
|
||||||
|
return balance[0][0] or 0
|
||||||
|
|
||||||
|
async def select_addresses(self, cols, **constraints):
|
||||||
|
return await self.db.execute_fetchall(*query(
|
||||||
|
"SELECT {} FROM pubkey_address".format(cols), **constraints
|
||||||
|
))
|
||||||
|
|
||||||
|
async def get_addresses(self, cols=('address', 'account', 'chain', 'position', 'used_times'),
|
||||||
|
**constraints):
|
||||||
|
addresses = await self.select_addresses(', '.join(cols), **constraints)
|
||||||
|
return rows_to_dict(addresses, cols)
|
||||||
|
|
||||||
|
async def get_address_count(self, **constraints):
|
||||||
|
count = await self.select_addresses('count(*)', **constraints)
|
||||||
|
return count[0][0]
|
||||||
|
|
||||||
|
async def get_address(self, **constraints):
|
||||||
|
addresses = await self.get_addresses(
|
||||||
|
cols=('address', 'account', 'chain', 'position', 'pubkey', 'history', 'used_times'),
|
||||||
|
limit=1, **constraints
|
||||||
|
)
|
||||||
|
if addresses:
|
||||||
|
return addresses[0]
|
||||||
|
|
||||||
|
async def add_keys(self, account, chain, keys):
|
||||||
|
await self.db.executemany(
|
||||||
|
"insert into pubkey_address (address, account, chain, position, pubkey) values (?, ?, ?, ?, ?)",
|
||||||
|
((pubkey.address, account.public_key.address, chain,
|
||||||
|
position, sqlite3.Binary(pubkey.pubkey_bytes))
|
||||||
|
for position, pubkey in keys)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _set_address_history(self, address, history):
|
||||||
|
await self.db.execute(
|
||||||
|
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
|
||||||
|
(history, history.count(':')//2, address)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def set_address_history(self, address, history):
|
||||||
|
await self._set_address_history(address, history)
|
189
torba/torba/client/baseheader.py
Normal file
189
torba/torba/client/baseheader.py
Normal file
|
@ -0,0 +1,189 @@
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Optional, Iterator, Tuple
|
||||||
|
from binascii import hexlify
|
||||||
|
|
||||||
|
from torba.client.util import ArithUint256
|
||||||
|
from torba.client.hash import double_sha256
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidHeader(Exception):
|
||||||
|
|
||||||
|
def __init__(self, height, message):
|
||||||
|
super().__init__(message)
|
||||||
|
self.message = message
|
||||||
|
self.height = height
|
||||||
|
|
||||||
|
|
||||||
|
class BaseHeaders:
|
||||||
|
|
||||||
|
header_size: int
|
||||||
|
chunk_size: int
|
||||||
|
|
||||||
|
max_target: int
|
||||||
|
genesis_hash: Optional[bytes]
|
||||||
|
target_timespan: int
|
||||||
|
|
||||||
|
validate_difficulty: bool = True
|
||||||
|
|
||||||
|
def __init__(self, path) -> None:
|
||||||
|
if path == ':memory:':
|
||||||
|
self.io = BytesIO()
|
||||||
|
self.path = path
|
||||||
|
self._size: Optional[int] = None
|
||||||
|
self._header_connect_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def open(self):
|
||||||
|
if self.path != ':memory:':
|
||||||
|
if not os.path.exists(self.path):
|
||||||
|
self.io = open(self.path, 'w+b')
|
||||||
|
else:
|
||||||
|
self.io = open(self.path, 'r+b')
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
self.io.close()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def serialize(header: dict) -> bytes:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def deserialize(height, header):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_next_chunk_target(self, chunk: int) -> ArithUint256:
|
||||||
|
return ArithUint256(self.max_target)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_next_block_target(chunk_target: ArithUint256, previous: Optional[dict],
|
||||||
|
current: Optional[dict]) -> ArithUint256:
|
||||||
|
return chunk_target
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
if self._size is None:
|
||||||
|
self._size = self.io.seek(0, os.SEEK_END) // self.header_size
|
||||||
|
return self._size
|
||||||
|
|
||||||
|
def __bool__(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def __getitem__(self, height) -> dict:
|
||||||
|
assert not isinstance(height, slice), \
|
||||||
|
"Slicing of header chain has not been implemented yet."
|
||||||
|
return self.deserialize(height, self.get_raw_header(height))
|
||||||
|
|
||||||
|
def get_raw_header(self, height) -> bytes:
|
||||||
|
self.io.seek(height * self.header_size, os.SEEK_SET)
|
||||||
|
return self.io.read(self.header_size)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def height(self) -> int:
|
||||||
|
return len(self)-1
|
||||||
|
|
||||||
|
def hash(self, height=None) -> bytes:
|
||||||
|
return self.hash_header(
|
||||||
|
self.get_raw_header(height if height is not None else self.height)
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def hash_header(header: bytes) -> bytes:
|
||||||
|
if header is None:
|
||||||
|
return b'0' * 64
|
||||||
|
return hexlify(double_sha256(header)[::-1])
|
||||||
|
|
||||||
|
async def connect(self, start: int, headers: bytes) -> int:
|
||||||
|
added = 0
|
||||||
|
bail = False
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
async with self._header_connect_lock:
|
||||||
|
for height, chunk in self._iterate_chunks(start, headers):
|
||||||
|
try:
|
||||||
|
# validate_chunk() is CPU bound and reads previous chunks from file system
|
||||||
|
await loop.run_in_executor(None, self.validate_chunk, height, chunk)
|
||||||
|
except InvalidHeader as e:
|
||||||
|
bail = True
|
||||||
|
chunk = chunk[:(height-e.height)*self.header_size]
|
||||||
|
written = 0
|
||||||
|
if chunk:
|
||||||
|
self.io.seek(height * self.header_size, os.SEEK_SET)
|
||||||
|
written = self.io.write(chunk) // self.header_size
|
||||||
|
self.io.truncate()
|
||||||
|
# .seek()/.write()/.truncate() might also .flush() when needed
|
||||||
|
# the goal here is mainly to ensure we're definitely flush()'ing
|
||||||
|
await loop.run_in_executor(None, self.io.flush)
|
||||||
|
self._size = None
|
||||||
|
added += written
|
||||||
|
if bail:
|
||||||
|
break
|
||||||
|
return added
|
||||||
|
|
||||||
|
def validate_chunk(self, height, chunk):
|
||||||
|
previous_hash, previous_header, previous_previous_header = None, None, None
|
||||||
|
if height > 0:
|
||||||
|
previous_header = self[height-1]
|
||||||
|
previous_hash = self.hash(height-1)
|
||||||
|
if height > 1:
|
||||||
|
previous_previous_header = self[height-2]
|
||||||
|
chunk_target = self.get_next_chunk_target(height // 2016 - 1)
|
||||||
|
for current_hash, current_header in self._iterate_headers(height, chunk):
|
||||||
|
block_target = self.get_next_block_target(chunk_target, previous_previous_header, previous_header)
|
||||||
|
self.validate_header(height, current_hash, current_header, previous_hash, block_target)
|
||||||
|
previous_previous_header = previous_header
|
||||||
|
previous_header = current_header
|
||||||
|
previous_hash = current_hash
|
||||||
|
|
||||||
|
def validate_header(self, height: int, current_hash: bytes,
|
||||||
|
header: dict, previous_hash: bytes, target: ArithUint256):
|
||||||
|
|
||||||
|
if previous_hash is None:
|
||||||
|
if self.genesis_hash is not None and self.genesis_hash != current_hash:
|
||||||
|
raise InvalidHeader(
|
||||||
|
height, "genesis header doesn't match: {} vs expected {}".format(
|
||||||
|
current_hash.decode(), self.genesis_hash.decode())
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if header['prev_block_hash'] != previous_hash:
|
||||||
|
raise InvalidHeader(
|
||||||
|
height, "previous hash mismatch: {} vs expected {}".format(
|
||||||
|
header['prev_block_hash'].decode(), previous_hash.decode())
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.validate_difficulty:
|
||||||
|
|
||||||
|
if header['bits'] != target.compact:
|
||||||
|
raise InvalidHeader(
|
||||||
|
height, "bits mismatch: {} vs expected {}".format(
|
||||||
|
header['bits'], target.compact)
|
||||||
|
)
|
||||||
|
|
||||||
|
proof_of_work = self.get_proof_of_work(current_hash)
|
||||||
|
if proof_of_work > target:
|
||||||
|
raise InvalidHeader(
|
||||||
|
height, "insufficient proof of work: {} vs target {}".format(
|
||||||
|
proof_of_work.value, target.value)
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_proof_of_work(header_hash: bytes) -> ArithUint256:
|
||||||
|
return ArithUint256(int(b'0x' + header_hash, 16))
|
||||||
|
|
||||||
|
def _iterate_chunks(self, height: int, headers: bytes) -> Iterator[Tuple[int, bytes]]:
|
||||||
|
assert len(headers) % self.header_size == 0
|
||||||
|
start = 0
|
||||||
|
end = (self.chunk_size - height % self.chunk_size) * self.header_size
|
||||||
|
while start < end:
|
||||||
|
yield height + (start // self.header_size), headers[start:end]
|
||||||
|
start = end
|
||||||
|
end = min(len(headers), end + self.chunk_size * self.header_size)
|
||||||
|
|
||||||
|
def _iterate_headers(self, height: int, headers: bytes) -> Iterator[Tuple[bytes, dict]]:
|
||||||
|
assert len(headers) % self.header_size == 0
|
||||||
|
for idx in range(len(headers) // self.header_size):
|
||||||
|
start, end = idx * self.header_size, (idx + 1) * self.header_size
|
||||||
|
header = headers[start:end]
|
||||||
|
yield self.hash_header(header), self.deserialize(height+idx, header)
|
526
torba/torba/client/baseledger.py
Normal file
526
torba/torba/client/baseledger.py
Normal file
|
@ -0,0 +1,526 @@
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from functools import partial
|
||||||
|
from binascii import hexlify, unhexlify
|
||||||
|
from io import StringIO
|
||||||
|
|
||||||
|
from typing import Dict, Type, Iterable, List, Optional
|
||||||
|
from operator import itemgetter
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
|
from torba.tasks import TaskGroup
|
||||||
|
from torba.client import baseaccount, basenetwork, basetransaction
|
||||||
|
from torba.client.basedatabase import BaseDatabase
|
||||||
|
from torba.client.baseheader import BaseHeaders
|
||||||
|
from torba.client.coinselection import CoinSelector
|
||||||
|
from torba.client.constants import COIN, NULL_HASH32
|
||||||
|
from torba.stream import StreamController
|
||||||
|
from torba.client.hash import hash160, double_sha256, sha256, Base58
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
LedgerType = Type['BaseLedger']
|
||||||
|
|
||||||
|
|
||||||
|
class LedgerRegistry(type):
|
||||||
|
|
||||||
|
ledgers: Dict[str, LedgerType] = {}
|
||||||
|
|
||||||
|
def __new__(mcs, name, bases, attrs):
|
||||||
|
cls: LedgerType = super().__new__(mcs, name, bases, attrs)
|
||||||
|
if not (name == 'BaseLedger' and not bases):
|
||||||
|
ledger_id = cls.get_id()
|
||||||
|
assert ledger_id not in mcs.ledgers,\
|
||||||
|
'Ledger with id "{}" already registered.'.format(ledger_id)
|
||||||
|
mcs.ledgers[ledger_id] = cls
|
||||||
|
return cls
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_ledger_class(mcs, ledger_id: str) -> LedgerType:
|
||||||
|
return mcs.ledgers[ledger_id]
|
||||||
|
|
||||||
|
|
||||||
|
class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx'))):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AddressesGeneratedEvent(namedtuple('AddressesGeneratedEvent', ('address_manager', 'addresses'))):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BlockHeightEvent(namedtuple('BlockHeightEvent', ('height', 'change'))):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TransactionCacheItem:
|
||||||
|
__slots__ = '_tx', 'lock', 'has_tx'
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
tx: Optional[basetransaction.BaseTransaction] = None,
|
||||||
|
lock: Optional[asyncio.Lock] = None):
|
||||||
|
self.has_tx = asyncio.Event()
|
||||||
|
self.lock = lock or asyncio.Lock()
|
||||||
|
self._tx = self.tx = tx
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tx(self) -> Optional[basetransaction.BaseTransaction]:
|
||||||
|
return self._tx
|
||||||
|
|
||||||
|
@tx.setter
|
||||||
|
def tx(self, tx: basetransaction.BaseTransaction):
|
||||||
|
self._tx = tx
|
||||||
|
if tx is not None:
|
||||||
|
self.has_tx.set()
|
||||||
|
|
||||||
|
|
||||||
|
class BaseLedger(metaclass=LedgerRegistry):
|
||||||
|
|
||||||
|
name: str
|
||||||
|
symbol: str
|
||||||
|
network_name: str
|
||||||
|
|
||||||
|
database_class = BaseDatabase
|
||||||
|
account_class = baseaccount.BaseAccount
|
||||||
|
network_class = basenetwork.BaseNetwork
|
||||||
|
transaction_class = basetransaction.BaseTransaction
|
||||||
|
|
||||||
|
headers_class: Type[BaseHeaders]
|
||||||
|
|
||||||
|
pubkey_address_prefix: bytes
|
||||||
|
script_address_prefix: bytes
|
||||||
|
extended_public_key_prefix: bytes
|
||||||
|
extended_private_key_prefix: bytes
|
||||||
|
|
||||||
|
default_fee_per_byte = 10
|
||||||
|
|
||||||
|
def __init__(self, config=None):
|
||||||
|
self.config = config or {}
|
||||||
|
self.db: BaseDatabase = self.config.get('db') or self.database_class(
|
||||||
|
os.path.join(self.path, "blockchain.db")
|
||||||
|
)
|
||||||
|
self.db.ledger = self
|
||||||
|
self.headers: BaseHeaders = self.config.get('headers') or self.headers_class(
|
||||||
|
os.path.join(self.path, "headers")
|
||||||
|
)
|
||||||
|
self.network = self.config.get('network') or self.network_class(self)
|
||||||
|
self.network.on_header.listen(self.receive_header)
|
||||||
|
self.network.on_status.listen(self.process_status_update)
|
||||||
|
|
||||||
|
self.accounts = []
|
||||||
|
self.fee_per_byte: int = self.config.get('fee_per_byte', self.default_fee_per_byte)
|
||||||
|
|
||||||
|
self._on_transaction_controller = StreamController()
|
||||||
|
self.on_transaction = self._on_transaction_controller.stream
|
||||||
|
self.on_transaction.listen(
|
||||||
|
lambda e: log.info(
|
||||||
|
'(%s) on_transaction: address=%s, height=%s, is_verified=%s, tx.id=%s',
|
||||||
|
self.get_id(), e.address, e.tx.height, e.tx.is_verified, e.tx.id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self._on_address_controller = StreamController()
|
||||||
|
self.on_address = self._on_address_controller.stream
|
||||||
|
self.on_address.listen(
|
||||||
|
lambda e: log.info('(%s) on_address: %s', self.get_id(), e.addresses)
|
||||||
|
)
|
||||||
|
|
||||||
|
self._on_header_controller = StreamController()
|
||||||
|
self.on_header = self._on_header_controller.stream
|
||||||
|
self.on_header.listen(
|
||||||
|
lambda change: log.info(
|
||||||
|
'%s: added %s header blocks, final height %s',
|
||||||
|
self.get_id(), change, self.headers.height
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self._tx_cache = {}
|
||||||
|
self._update_tasks = TaskGroup()
|
||||||
|
self._utxo_reservation_lock = asyncio.Lock()
|
||||||
|
self._header_processing_lock = asyncio.Lock()
|
||||||
|
self._address_update_locks: Dict[str, asyncio.Lock] = {}
|
||||||
|
|
||||||
|
self.coin_selection_strategy = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_id(cls):
|
||||||
|
return '{}_{}'.format(cls.symbol.lower(), cls.network_name.lower())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def hash160_to_address(cls, h160):
|
||||||
|
raw_address = cls.pubkey_address_prefix + h160
|
||||||
|
return Base58.encode(bytearray(raw_address + double_sha256(raw_address)[0:4]))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def address_to_hash160(address):
|
||||||
|
return Base58.decode(address)[1:21]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_valid_address(cls, address):
|
||||||
|
decoded = Base58.decode(address)
|
||||||
|
return decoded[0] == cls.pubkey_address_prefix[0]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def public_key_to_address(cls, public_key):
|
||||||
|
return cls.hash160_to_address(hash160(public_key))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def private_key_to_wif(private_key):
|
||||||
|
return b'\x1c' + private_key + b'\x01'
|
||||||
|
|
||||||
|
@property
|
||||||
|
def path(self):
|
||||||
|
return os.path.join(self.config['data_path'], self.get_id())
|
||||||
|
|
||||||
|
def add_account(self, account: baseaccount.BaseAccount):
|
||||||
|
self.accounts.append(account)
|
||||||
|
|
||||||
|
async def _get_account_and_address_info_for_address(self, address):
|
||||||
|
match = await self.db.get_address(address=address)
|
||||||
|
if match:
|
||||||
|
for account in self.accounts:
|
||||||
|
if match['account'] == account.public_key.address:
|
||||||
|
return account, match
|
||||||
|
|
||||||
|
async def get_private_key_for_address(self, address):
|
||||||
|
match = await self._get_account_and_address_info_for_address(address)
|
||||||
|
if match:
|
||||||
|
account, address_info = match
|
||||||
|
return account.get_private_key(address_info['chain'], address_info['position'])
|
||||||
|
|
||||||
|
async def get_public_key_for_address(self, address):
|
||||||
|
match = await self._get_account_and_address_info_for_address(address)
|
||||||
|
if match:
|
||||||
|
account, address_info = match
|
||||||
|
return account.get_public_key(address_info['chain'], address_info['position'])
|
||||||
|
|
||||||
|
async def get_account_for_address(self, address):
|
||||||
|
match = await self._get_account_and_address_info_for_address(address)
|
||||||
|
if match:
|
||||||
|
return match[0]
|
||||||
|
|
||||||
|
async def get_effective_amount_estimators(self, funding_accounts: Iterable[baseaccount.BaseAccount]):
|
||||||
|
estimators = []
|
||||||
|
for account in funding_accounts:
|
||||||
|
utxos = await account.get_utxos()
|
||||||
|
for utxo in utxos:
|
||||||
|
estimators.append(utxo.get_estimator(self))
|
||||||
|
return estimators
|
||||||
|
|
||||||
|
async def get_spendable_utxos(self, amount: int, funding_accounts):
|
||||||
|
async with self._utxo_reservation_lock:
|
||||||
|
txos = await self.get_effective_amount_estimators(funding_accounts)
|
||||||
|
selector = CoinSelector(
|
||||||
|
txos, amount,
|
||||||
|
self.transaction_class.output_class.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self)
|
||||||
|
)
|
||||||
|
spendables = selector.select(self.coin_selection_strategy)
|
||||||
|
if spendables:
|
||||||
|
await self.reserve_outputs(s.txo for s in spendables)
|
||||||
|
return spendables
|
||||||
|
|
||||||
|
def reserve_outputs(self, txos):
|
||||||
|
return self.db.reserve_outputs(txos)
|
||||||
|
|
||||||
|
def release_outputs(self, txos):
|
||||||
|
return self.db.release_outputs(txos)
|
||||||
|
|
||||||
|
def release_tx(self, tx):
|
||||||
|
return self.release_outputs([txi.txo_ref.txo for txi in tx.inputs])
|
||||||
|
|
||||||
|
async def get_local_status_and_history(self, address):
|
||||||
|
address_details = await self.db.get_address(address=address)
|
||||||
|
history = address_details['history'] or ''
|
||||||
|
parts = history.split(':')[:-1]
|
||||||
|
return (
|
||||||
|
hexlify(sha256(history.encode())).decode() if history else None,
|
||||||
|
list(zip(parts[0::2], map(int, parts[1::2])))
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_root_of_merkle_tree(branches, branch_positions, working_branch):
|
||||||
|
for i, branch in enumerate(branches):
|
||||||
|
other_branch = unhexlify(branch)[::-1]
|
||||||
|
other_branch_on_left = bool((branch_positions >> i) & 1)
|
||||||
|
if other_branch_on_left:
|
||||||
|
combined = other_branch + working_branch
|
||||||
|
else:
|
||||||
|
combined = working_branch + other_branch
|
||||||
|
working_branch = double_sha256(combined)
|
||||||
|
return hexlify(working_branch[::-1])
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
if not os.path.exists(self.path):
|
||||||
|
os.mkdir(self.path)
|
||||||
|
await asyncio.wait([
|
||||||
|
self.db.open(),
|
||||||
|
self.headers.open()
|
||||||
|
])
|
||||||
|
first_connection = self.network.on_connected.first
|
||||||
|
asyncio.ensure_future(self.network.start())
|
||||||
|
await first_connection
|
||||||
|
await self.join_network()
|
||||||
|
self.network.on_connected.listen(self.join_network)
|
||||||
|
|
||||||
|
async def join_network(self, *args):
|
||||||
|
log.info("Subscribing and updating accounts.")
|
||||||
|
await self.update_headers()
|
||||||
|
await self.network.subscribe_headers()
|
||||||
|
await self.subscribe_accounts()
|
||||||
|
await self._update_tasks.done.wait()
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
self._update_tasks.cancel()
|
||||||
|
await self._update_tasks.done.wait()
|
||||||
|
await self.network.stop()
|
||||||
|
await self.db.close()
|
||||||
|
await self.headers.close()
|
||||||
|
|
||||||
|
async def update_headers(self, height=None, headers=None, subscription_update=False):
|
||||||
|
rewound = 0
|
||||||
|
while True:
|
||||||
|
|
||||||
|
if height is None or height > len(self.headers):
|
||||||
|
# sometimes header subscription updates are for a header in the future
|
||||||
|
# which can't be connected, so we do a normal header sync instead
|
||||||
|
height = len(self.headers)
|
||||||
|
headers = None
|
||||||
|
subscription_update = False
|
||||||
|
|
||||||
|
if not headers:
|
||||||
|
header_response = await self.network.get_headers(height, 2001)
|
||||||
|
headers = header_response['hex']
|
||||||
|
|
||||||
|
if not headers:
|
||||||
|
# Nothing to do, network thinks we're already at the latest height.
|
||||||
|
return
|
||||||
|
|
||||||
|
added = await self.headers.connect(height, unhexlify(headers))
|
||||||
|
if added > 0:
|
||||||
|
height += added
|
||||||
|
self._on_header_controller.add(
|
||||||
|
BlockHeightEvent(self.headers.height, added))
|
||||||
|
|
||||||
|
if rewound > 0:
|
||||||
|
# we started rewinding blocks and apparently found
|
||||||
|
# a new chain
|
||||||
|
rewound = 0
|
||||||
|
await self.db.rewind_blockchain(height)
|
||||||
|
|
||||||
|
if subscription_update:
|
||||||
|
# subscription updates are for latest header already
|
||||||
|
# so we don't need to check if there are newer / more
|
||||||
|
# on another loop of update_headers(), just return instead
|
||||||
|
return
|
||||||
|
|
||||||
|
elif added == 0:
|
||||||
|
# we had headers to connect but none got connected, probably a reorganization
|
||||||
|
height -= 1
|
||||||
|
rewound += 1
|
||||||
|
log.warning(
|
||||||
|
"Blockchain Reorganization: attempting rewind to height %s from starting height %s",
|
||||||
|
height, height+rewound
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise IndexError("headers.connect() returned negative number ({})".format(added))
|
||||||
|
|
||||||
|
if height < 0:
|
||||||
|
raise IndexError(
|
||||||
|
"Blockchain reorganization rewound all the way back to genesis hash. "
|
||||||
|
"Something is very wrong. Maybe you are on the wrong blockchain?"
|
||||||
|
)
|
||||||
|
|
||||||
|
if rewound >= 100:
|
||||||
|
raise IndexError(
|
||||||
|
"Blockchain reorganization dropped {} headers. This is highly unusual. "
|
||||||
|
"Will not continue to attempt reorganizing. Please, delete the ledger "
|
||||||
|
"synchronization directory inside your wallet directory (folder: '{}') and "
|
||||||
|
"restart the program to synchronize from scratch."
|
||||||
|
.format(rewound, self.get_id())
|
||||||
|
)
|
||||||
|
|
||||||
|
headers = None # ready to download some more headers
|
||||||
|
|
||||||
|
# if we made it this far and this was a subscription_update
|
||||||
|
# it means something went wrong and now we're doing a more
|
||||||
|
# robust sync, turn off subscription update shortcut
|
||||||
|
subscription_update = False
|
||||||
|
|
||||||
|
async def receive_header(self, response):
|
||||||
|
async with self._header_processing_lock:
|
||||||
|
header = response[0]
|
||||||
|
await self.update_headers(
|
||||||
|
height=header['height'], headers=header['hex'], subscription_update=True
|
||||||
|
)
|
||||||
|
|
||||||
|
async def subscribe_accounts(self):
|
||||||
|
if self.network.is_connected and self.accounts:
|
||||||
|
await asyncio.wait([
|
||||||
|
self.subscribe_account(a) for a in self.accounts
|
||||||
|
])
|
||||||
|
|
||||||
|
async def subscribe_account(self, account: baseaccount.BaseAccount):
|
||||||
|
for address_manager in account.address_managers.values():
|
||||||
|
await self.subscribe_addresses(address_manager, await address_manager.get_addresses())
|
||||||
|
await account.ensure_address_gap()
|
||||||
|
|
||||||
|
async def announce_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]):
|
||||||
|
await self.subscribe_addresses(address_manager, addresses)
|
||||||
|
await self._on_address_controller.add(
|
||||||
|
AddressesGeneratedEvent(address_manager, addresses)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def subscribe_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]):
|
||||||
|
if self.network.is_connected and addresses:
|
||||||
|
await asyncio.wait([
|
||||||
|
self.subscribe_address(address_manager, address) for address in addresses
|
||||||
|
])
|
||||||
|
|
||||||
|
async def subscribe_address(self, address_manager: baseaccount.AddressManager, address: str):
|
||||||
|
remote_status = await self.network.subscribe_address(address)
|
||||||
|
self._update_tasks.add(self.update_history(address, remote_status, address_manager))
|
||||||
|
|
||||||
|
def process_status_update(self, update):
|
||||||
|
address, remote_status = update
|
||||||
|
self._update_tasks.add(self.update_history(address, remote_status))
|
||||||
|
|
||||||
|
async def update_history(self, address, remote_status,
|
||||||
|
address_manager: baseaccount.AddressManager = None):
|
||||||
|
|
||||||
|
async with self._address_update_locks.setdefault(address, asyncio.Lock()):
|
||||||
|
|
||||||
|
local_status, local_history = await self.get_local_status_and_history(address)
|
||||||
|
|
||||||
|
if local_status == remote_status:
|
||||||
|
return
|
||||||
|
|
||||||
|
remote_history = await self.network.get_history(address)
|
||||||
|
|
||||||
|
cache_tasks = []
|
||||||
|
synced_history = StringIO()
|
||||||
|
for i, (txid, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
|
||||||
|
if i < len(local_history) and local_history[i] == (txid, remote_height):
|
||||||
|
synced_history.write(f'{txid}:{remote_height}:')
|
||||||
|
else:
|
||||||
|
cache_tasks.append(asyncio.ensure_future(
|
||||||
|
self.cache_transaction(txid, remote_height)
|
||||||
|
))
|
||||||
|
|
||||||
|
for task in cache_tasks:
|
||||||
|
tx = await task
|
||||||
|
|
||||||
|
check_db_for_txos = []
|
||||||
|
for txi in tx.inputs:
|
||||||
|
if txi.txo_ref.txo is not None:
|
||||||
|
continue
|
||||||
|
cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.id)
|
||||||
|
if cache_item is not None:
|
||||||
|
if cache_item.tx is None:
|
||||||
|
await cache_item.has_tx.wait()
|
||||||
|
assert cache_item.tx is not None
|
||||||
|
txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
|
||||||
|
else:
|
||||||
|
check_db_for_txos.append(txi.txo_ref.id)
|
||||||
|
|
||||||
|
referenced_txos = {
|
||||||
|
txo.id: txo for txo in await self.db.get_txos(txoid__in=check_db_for_txos)
|
||||||
|
}
|
||||||
|
|
||||||
|
for txi in tx.inputs:
|
||||||
|
if txi.txo_ref.txo is not None:
|
||||||
|
continue
|
||||||
|
referenced_txo = referenced_txos.get(txi.txo_ref.id)
|
||||||
|
if referenced_txo is not None:
|
||||||
|
txi.txo_ref = referenced_txo.ref
|
||||||
|
|
||||||
|
synced_history.write(f'{tx.id}:{tx.height}:')
|
||||||
|
|
||||||
|
await self.db.save_transaction_io(
|
||||||
|
tx, address, self.address_to_hash160(address), synced_history.getvalue()
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._on_transaction_controller.add(TransactionEvent(address, tx))
|
||||||
|
|
||||||
|
if address_manager is None:
|
||||||
|
address_manager = await self.get_address_manager_for_address(address)
|
||||||
|
|
||||||
|
if address_manager is not None:
|
||||||
|
await address_manager.ensure_address_gap()
|
||||||
|
|
||||||
|
async def cache_transaction(self, txid, remote_height):
|
||||||
|
cache_item = self._tx_cache.get(txid)
|
||||||
|
if cache_item is None:
|
||||||
|
cache_item = self._tx_cache[txid] = TransactionCacheItem()
|
||||||
|
elif cache_item.tx is not None and \
|
||||||
|
cache_item.tx.height >= remote_height and \
|
||||||
|
(cache_item.tx.is_verified or remote_height < 1):
|
||||||
|
return cache_item.tx # cached tx is already up-to-date
|
||||||
|
|
||||||
|
async with cache_item.lock:
|
||||||
|
|
||||||
|
tx = cache_item.tx
|
||||||
|
|
||||||
|
if tx is None:
|
||||||
|
# check local db
|
||||||
|
tx = cache_item.tx = await self.db.get_transaction(txid=txid)
|
||||||
|
|
||||||
|
if tx is None:
|
||||||
|
# fetch from network
|
||||||
|
_raw = await self.network.get_transaction(txid)
|
||||||
|
if _raw:
|
||||||
|
tx = self.transaction_class(unhexlify(_raw))
|
||||||
|
await self.maybe_verify_transaction(tx, remote_height)
|
||||||
|
await self.db.insert_transaction(tx)
|
||||||
|
cache_item.tx = tx # make sure it's saved before caching it
|
||||||
|
return tx
|
||||||
|
|
||||||
|
if tx is None:
|
||||||
|
raise ValueError(f'Transaction {txid} was not in database and not on network.')
|
||||||
|
|
||||||
|
if remote_height > 0 and not tx.is_verified:
|
||||||
|
# tx from cache / db is not up-to-date
|
||||||
|
await self.maybe_verify_transaction(tx, remote_height)
|
||||||
|
await self.db.update_transaction(tx)
|
||||||
|
|
||||||
|
return tx
|
||||||
|
|
||||||
|
async def maybe_verify_transaction(self, tx, remote_height):
|
||||||
|
tx.height = remote_height
|
||||||
|
if 0 < remote_height <= len(self.headers):
|
||||||
|
merkle = await self.network.get_merkle(tx.id, remote_height)
|
||||||
|
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
|
||||||
|
header = self.headers[remote_height]
|
||||||
|
tx.position = merkle['pos']
|
||||||
|
tx.is_verified = merkle_root == header['merkle_root']
|
||||||
|
|
||||||
|
async def get_address_manager_for_address(self, address) -> Optional[baseaccount.AddressManager]:
|
||||||
|
details = await self.db.get_address(address=address)
|
||||||
|
for account in self.accounts:
|
||||||
|
if account.id == details['account']:
|
||||||
|
return account.address_managers[details['chain']]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def broadcast(self, tx):
|
||||||
|
return self.network.broadcast(hexlify(tx.raw).decode())
|
||||||
|
|
||||||
|
async def wait(self, tx: basetransaction.BaseTransaction, height=-1, timeout=None):
|
||||||
|
addresses = set()
|
||||||
|
for txi in tx.inputs:
|
||||||
|
if txi.txo_ref.txo is not None:
|
||||||
|
addresses.add(
|
||||||
|
self.hash160_to_address(txi.txo_ref.txo.script.values['pubkey_hash'])
|
||||||
|
)
|
||||||
|
for txo in tx.outputs:
|
||||||
|
addresses.add(
|
||||||
|
self.hash160_to_address(txo.script.values['pubkey_hash'])
|
||||||
|
)
|
||||||
|
records = await self.db.get_addresses(cols=('address',), address__in=addresses)
|
||||||
|
_, pending = await asyncio.wait([
|
||||||
|
self.on_transaction.where(partial(
|
||||||
|
lambda a, e: a == e.address and e.tx.height >= height and e.tx.id == tx.id,
|
||||||
|
address_record['address']
|
||||||
|
)) for address_record in records
|
||||||
|
], timeout=timeout)
|
||||||
|
if pending:
|
||||||
|
raise TimeoutError('Timed out waiting for transaction.')
|
80
torba/torba/client/basemanager.py
Normal file
80
torba/torba/client/basemanager.py
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Type, MutableSequence, MutableMapping
|
||||||
|
|
||||||
|
from torba.client.baseledger import BaseLedger, LedgerRegistry
|
||||||
|
from torba.client.wallet import Wallet, WalletStorage
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseWalletManager:
|
||||||
|
|
||||||
|
def __init__(self, wallets: MutableSequence[Wallet] = None,
|
||||||
|
ledgers: MutableMapping[Type[BaseLedger], BaseLedger] = None) -> None:
|
||||||
|
self.wallets = wallets or []
|
||||||
|
self.ledgers = ledgers or {}
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: dict) -> 'BaseWalletManager':
|
||||||
|
manager = cls()
|
||||||
|
for ledger_id, ledger_config in config.get('ledgers', {}).items():
|
||||||
|
manager.get_or_create_ledger(ledger_id, ledger_config)
|
||||||
|
for wallet_path in config.get('wallets', []):
|
||||||
|
wallet_storage = WalletStorage(wallet_path)
|
||||||
|
wallet = Wallet.from_storage(wallet_storage, manager)
|
||||||
|
manager.wallets.append(wallet)
|
||||||
|
return manager
|
||||||
|
|
||||||
|
def get_or_create_ledger(self, ledger_id, ledger_config=None):
|
||||||
|
ledger_class = LedgerRegistry.get_ledger_class(ledger_id)
|
||||||
|
ledger = self.ledgers.get(ledger_class)
|
||||||
|
if ledger is None:
|
||||||
|
ledger = ledger_class(ledger_config or {})
|
||||||
|
self.ledgers[ledger_class] = ledger
|
||||||
|
return ledger
|
||||||
|
|
||||||
|
def import_wallet(self, path):
|
||||||
|
storage = WalletStorage(path)
|
||||||
|
wallet = Wallet.from_storage(storage, self)
|
||||||
|
self.wallets.append(wallet)
|
||||||
|
return wallet
|
||||||
|
|
||||||
|
async def get_detailed_accounts(self, **kwargs):
|
||||||
|
ledgers = {}
|
||||||
|
for i, account in enumerate(self.accounts):
|
||||||
|
details = await account.get_details(**kwargs)
|
||||||
|
details['is_default'] = i == 0
|
||||||
|
ledger_id = account.ledger.get_id()
|
||||||
|
ledgers.setdefault(ledger_id, [])
|
||||||
|
ledgers[ledger_id].append(details)
|
||||||
|
return ledgers
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_wallet(self):
|
||||||
|
for wallet in self.wallets:
|
||||||
|
return wallet
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_account(self):
|
||||||
|
for wallet in self.wallets:
|
||||||
|
return wallet.default_account
|
||||||
|
|
||||||
|
@property
|
||||||
|
def accounts(self):
|
||||||
|
for wallet in self.wallets:
|
||||||
|
for account in wallet.accounts:
|
||||||
|
yield account
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
self.running = True
|
||||||
|
await asyncio.gather(*(
|
||||||
|
l.start() for l in self.ledgers.values()
|
||||||
|
))
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
await asyncio.gather(*(
|
||||||
|
l.stop() for l in self.ledgers.values()
|
||||||
|
))
|
||||||
|
self.running = False
|
232
torba/torba/client/basenetwork.py
Normal file
232
torba/torba/client/basenetwork.py
Normal file
|
@ -0,0 +1,232 @@
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
from asyncio import CancelledError
|
||||||
|
from time import time
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
|
||||||
|
|
||||||
|
from torba import __version__
|
||||||
|
from torba.stream import StreamController
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ClientSession(BaseClientSession):
|
||||||
|
|
||||||
|
def __init__(self, *args, network, server, **kwargs):
|
||||||
|
self.network = network
|
||||||
|
self.server = server
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._on_disconnect_controller = StreamController()
|
||||||
|
self.on_disconnected = self._on_disconnect_controller.stream
|
||||||
|
self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32
|
||||||
|
self.max_seconds_idle = 60
|
||||||
|
self.ping_task = None
|
||||||
|
|
||||||
|
async def send_request(self, method, args=()):
|
||||||
|
try:
|
||||||
|
return await super().send_request(method, args)
|
||||||
|
except RPCError as e:
|
||||||
|
log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def ping_forever(self):
|
||||||
|
# TODO: change to 'ping' on newer protocol (above 1.2)
|
||||||
|
while not self.is_closing():
|
||||||
|
if (time() - self.last_send) > self.max_seconds_idle:
|
||||||
|
await self.send_request('server.banner')
|
||||||
|
await asyncio.sleep(self.max_seconds_idle//3)
|
||||||
|
|
||||||
|
async def create_connection(self, timeout=6):
|
||||||
|
connector = Connector(lambda: self, *self.server)
|
||||||
|
await asyncio.wait_for(connector.create_connection(), timeout=timeout)
|
||||||
|
self.ping_task = asyncio.create_task(self.ping_forever())
|
||||||
|
|
||||||
|
async def handle_request(self, request):
|
||||||
|
controller = self.network.subscription_controllers[request.method]
|
||||||
|
controller.add(request.args)
|
||||||
|
|
||||||
|
def connection_lost(self, exc):
|
||||||
|
super().connection_lost(exc)
|
||||||
|
self._on_disconnect_controller.add(True)
|
||||||
|
if self.ping_task:
|
||||||
|
self.ping_task.cancel()
|
||||||
|
|
||||||
|
|
||||||
|
class BaseNetwork:
|
||||||
|
|
||||||
|
def __init__(self, ledger):
|
||||||
|
self.config = ledger.config
|
||||||
|
self.client: ClientSession = None
|
||||||
|
self.session_pool: SessionPool = None
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
self._on_connected_controller = StreamController()
|
||||||
|
self.on_connected = self._on_connected_controller.stream
|
||||||
|
|
||||||
|
self._on_header_controller = StreamController()
|
||||||
|
self.on_header = self._on_header_controller.stream
|
||||||
|
|
||||||
|
self._on_status_controller = StreamController()
|
||||||
|
self.on_status = self._on_status_controller.stream
|
||||||
|
|
||||||
|
self.subscription_controllers = {
|
||||||
|
'blockchain.headers.subscribe': self._on_header_controller,
|
||||||
|
'blockchain.address.subscribe': self._on_status_controller,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
self.running = True
|
||||||
|
connect_timeout = self.config.get('connect_timeout', 6)
|
||||||
|
self.session_pool = SessionPool(network=self, timeout=connect_timeout)
|
||||||
|
self.session_pool.start(self.config['default_servers'])
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
self.client = await self.pick_fastest_session()
|
||||||
|
if self.is_connected:
|
||||||
|
await self.ensure_server_version()
|
||||||
|
log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server)
|
||||||
|
self._on_connected_controller.add(True)
|
||||||
|
await self.client.on_disconnected.first
|
||||||
|
except CancelledError:
|
||||||
|
self.running = False
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
log.warning("Timed out while trying to find a server!")
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
log.exception("Exception while trying to find a server!")
|
||||||
|
if not self.running:
|
||||||
|
return
|
||||||
|
elif self.client:
|
||||||
|
await self.client.close()
|
||||||
|
self.client.connection.cancel_pending_requests()
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
self.running = False
|
||||||
|
if self.session_pool:
|
||||||
|
self.session_pool.stop()
|
||||||
|
if self.is_connected:
|
||||||
|
disconnected = self.client.on_disconnected.first
|
||||||
|
await self.client.close()
|
||||||
|
await disconnected
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self):
|
||||||
|
return self.client is not None and not self.client.is_closing()
|
||||||
|
|
||||||
|
def rpc(self, list_or_method, args):
|
||||||
|
if self.is_connected:
|
||||||
|
return self.client.send_request(list_or_method, args)
|
||||||
|
else:
|
||||||
|
raise ConnectionError("Attempting to send rpc request when connection is not available.")
|
||||||
|
|
||||||
|
async def pick_fastest_session(self):
|
||||||
|
sessions = await self.session_pool.get_online_sessions()
|
||||||
|
done, pending = await asyncio.wait([
|
||||||
|
self.probe_session(session)
|
||||||
|
for session in sessions if not session.is_closing()
|
||||||
|
], return_when='FIRST_COMPLETED')
|
||||||
|
for task in pending:
|
||||||
|
task.cancel()
|
||||||
|
for session in done:
|
||||||
|
return await session
|
||||||
|
|
||||||
|
async def probe_session(self, session: ClientSession):
|
||||||
|
await session.send_request('server.banner')
|
||||||
|
return session
|
||||||
|
|
||||||
|
def ensure_server_version(self, required='1.2'):
|
||||||
|
return self.rpc('server.version', [__version__, required])
|
||||||
|
|
||||||
|
def broadcast(self, raw_transaction):
|
||||||
|
return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
|
||||||
|
|
||||||
|
def get_history(self, address):
|
||||||
|
return self.rpc('blockchain.address.get_history', [address])
|
||||||
|
|
||||||
|
def get_transaction(self, tx_hash):
|
||||||
|
return self.rpc('blockchain.transaction.get', [tx_hash])
|
||||||
|
|
||||||
|
def get_transaction_height(self, tx_hash):
|
||||||
|
return self.rpc('blockchain.transaction.get_height', [tx_hash])
|
||||||
|
|
||||||
|
def get_merkle(self, tx_hash, height):
|
||||||
|
return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height])
|
||||||
|
|
||||||
|
def get_headers(self, height, count=10000):
|
||||||
|
return self.rpc('blockchain.block.headers', [height, count])
|
||||||
|
|
||||||
|
def subscribe_headers(self):
|
||||||
|
return self.rpc('blockchain.headers.subscribe', [True])
|
||||||
|
|
||||||
|
def subscribe_address(self, address):
|
||||||
|
return self.rpc('blockchain.address.subscribe', [address])
|
||||||
|
|
||||||
|
|
||||||
|
class SessionPool:
|
||||||
|
|
||||||
|
def __init__(self, network: BaseNetwork, timeout: float):
|
||||||
|
self.network = network
|
||||||
|
self.sessions: List[ClientSession] = []
|
||||||
|
self._dead_servers: List[ClientSession] = []
|
||||||
|
self.maintain_connections_task = None
|
||||||
|
self.timeout = timeout
|
||||||
|
# triggered when the master server is out, to speed up reconnect
|
||||||
|
self._lost_master = asyncio.Event()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def online(self):
|
||||||
|
for session in self.sessions:
|
||||||
|
if not session.is_closing():
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def start(self, default_servers):
|
||||||
|
self.sessions = [
|
||||||
|
ClientSession(network=self.network, server=server)
|
||||||
|
for server in default_servers
|
||||||
|
]
|
||||||
|
self.maintain_connections_task = asyncio.create_task(self.ensure_connections())
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
if self.maintain_connections_task:
|
||||||
|
self.maintain_connections_task.cancel()
|
||||||
|
for session in self.sessions:
|
||||||
|
if not session.is_closing():
|
||||||
|
session.abort()
|
||||||
|
self.sessions, self._dead_servers, self.maintain_connections_task = [], [], None
|
||||||
|
|
||||||
|
async def ensure_connections(self):
|
||||||
|
while True:
|
||||||
|
await asyncio.gather(*[
|
||||||
|
self.ensure_connection(session)
|
||||||
|
for session in self.sessions
|
||||||
|
], return_exceptions=True)
|
||||||
|
await asyncio.wait([asyncio.sleep(3), self._lost_master.wait()], return_when='FIRST_COMPLETED')
|
||||||
|
self._lost_master.clear()
|
||||||
|
if not self.sessions:
|
||||||
|
self.sessions.extend(self._dead_servers)
|
||||||
|
self._dead_servers = []
|
||||||
|
|
||||||
|
async def ensure_connection(self, session):
|
||||||
|
if not session.is_closing():
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
return await session.create_connection(self.timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
log.warning("Timeout connecting to %s:%d", *session.server)
|
||||||
|
except asyncio.CancelledError: # pylint: disable=try-except-raise
|
||||||
|
raise
|
||||||
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
if 'Connect call failed' in str(err):
|
||||||
|
log.warning("Could not connect to %s:%d", *session.server)
|
||||||
|
else:
|
||||||
|
log.exception("Connecting to %s:%d raised an exception:", *session.server)
|
||||||
|
self._dead_servers.append(session)
|
||||||
|
self.sessions.remove(session)
|
||||||
|
|
||||||
|
async def get_online_sessions(self):
|
||||||
|
self._lost_master.set()
|
||||||
|
while not self.online:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
return self.sessions
|
436
torba/torba/client/basescript.py
Normal file
436
torba/torba/client/basescript.py
Normal file
|
@ -0,0 +1,436 @@
|
||||||
|
from itertools import chain
|
||||||
|
from binascii import hexlify
|
||||||
|
from collections import namedtuple
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from torba.client.bcd_data_stream import BCDataStream
|
||||||
|
from torba.client.util import subclass_tuple
|
||||||
|
|
||||||
|
# bitcoin opcodes
|
||||||
|
OP_0 = 0x00
|
||||||
|
OP_1 = 0x51
|
||||||
|
OP_16 = 0x60
|
||||||
|
OP_VERIFY = 0x69
|
||||||
|
OP_DUP = 0x76
|
||||||
|
OP_HASH160 = 0xa9
|
||||||
|
OP_EQUALVERIFY = 0x88
|
||||||
|
OP_CHECKSIG = 0xac
|
||||||
|
OP_CHECKMULTISIG = 0xae
|
||||||
|
OP_EQUAL = 0x87
|
||||||
|
OP_PUSHDATA1 = 0x4c
|
||||||
|
OP_PUSHDATA2 = 0x4d
|
||||||
|
OP_PUSHDATA4 = 0x4e
|
||||||
|
OP_RETURN = 0x6a
|
||||||
|
OP_2DROP = 0x6d
|
||||||
|
OP_DROP = 0x75
|
||||||
|
|
||||||
|
|
||||||
|
# template matching opcodes (not real opcodes)
|
||||||
|
# base class for PUSH_DATA related opcodes
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
PUSH_DATA_OP = namedtuple('PUSH_DATA_OP', 'name')
|
||||||
|
# opcode for variable length strings
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
PUSH_SINGLE = subclass_tuple('PUSH_SINGLE', PUSH_DATA_OP)
|
||||||
|
# opcode for variable size integers
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
PUSH_INTEGER = subclass_tuple('PUSH_INTEGER', PUSH_DATA_OP)
|
||||||
|
# opcode for variable number of variable length strings
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
PUSH_MANY = subclass_tuple('PUSH_MANY', PUSH_DATA_OP)
|
||||||
|
# opcode with embedded subscript parsing
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
PUSH_SUBSCRIPT = namedtuple('PUSH_SUBSCRIPT', 'name template')
|
||||||
|
|
||||||
|
|
||||||
|
def is_push_data_opcode(opcode):
|
||||||
|
return isinstance(opcode, (PUSH_DATA_OP, PUSH_SUBSCRIPT))
|
||||||
|
|
||||||
|
|
||||||
|
def is_push_data_token(token):
|
||||||
|
return 1 <= token <= OP_PUSHDATA4
|
||||||
|
|
||||||
|
|
||||||
|
def push_data(data):
|
||||||
|
size = len(data)
|
||||||
|
if size < OP_PUSHDATA1:
|
||||||
|
yield BCDataStream.uint8.pack(size)
|
||||||
|
elif size <= 0xFF:
|
||||||
|
yield BCDataStream.uint8.pack(OP_PUSHDATA1)
|
||||||
|
yield BCDataStream.uint8.pack(size)
|
||||||
|
elif size <= 0xFFFF:
|
||||||
|
yield BCDataStream.uint8.pack(OP_PUSHDATA2)
|
||||||
|
yield BCDataStream.uint16.pack(size)
|
||||||
|
else:
|
||||||
|
yield BCDataStream.uint8.pack(OP_PUSHDATA4)
|
||||||
|
yield BCDataStream.uint32.pack(size)
|
||||||
|
yield bytes(data)
|
||||||
|
|
||||||
|
|
||||||
|
def read_data(token, stream):
|
||||||
|
if token < OP_PUSHDATA1:
|
||||||
|
return stream.read(token)
|
||||||
|
if token == OP_PUSHDATA1:
|
||||||
|
return stream.read(stream.read_uint8())
|
||||||
|
if token == OP_PUSHDATA2:
|
||||||
|
return stream.read(stream.read_uint16())
|
||||||
|
return stream.read(stream.read_uint32())
|
||||||
|
|
||||||
|
|
||||||
|
# opcode for OP_1 - OP_16
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
SMALL_INTEGER = namedtuple('SMALL_INTEGER', 'name')
|
||||||
|
|
||||||
|
|
||||||
|
def is_small_integer(token):
|
||||||
|
return OP_1 <= token <= OP_16
|
||||||
|
|
||||||
|
|
||||||
|
def push_small_integer(num):
|
||||||
|
assert 1 <= num <= 16
|
||||||
|
yield BCDataStream.uint8.pack(OP_1 + (num - 1))
|
||||||
|
|
||||||
|
|
||||||
|
def read_small_integer(token):
|
||||||
|
return (token - OP_1) + 1
|
||||||
|
|
||||||
|
|
||||||
|
class Token(namedtuple('Token', 'value')):
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
name = None
|
||||||
|
for var_name, var_value in globals().items():
|
||||||
|
if var_name.startswith('OP_') and var_value == self.value:
|
||||||
|
name = var_name
|
||||||
|
break
|
||||||
|
return name or self.value
|
||||||
|
|
||||||
|
|
||||||
|
class DataToken(Token):
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return '"{}"'.format(hexlify(self.value))
|
||||||
|
|
||||||
|
|
||||||
|
class SmallIntegerToken(Token):
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return 'SmallIntegerToken({})'.format(self.value)
|
||||||
|
|
||||||
|
|
||||||
|
def token_producer(source):
|
||||||
|
token = source.read_uint8()
|
||||||
|
while token is not None:
|
||||||
|
if is_push_data_token(token):
|
||||||
|
yield DataToken(read_data(token, source))
|
||||||
|
elif is_small_integer(token):
|
||||||
|
yield SmallIntegerToken(read_small_integer(token))
|
||||||
|
else:
|
||||||
|
yield Token(token)
|
||||||
|
token = source.read_uint8()
|
||||||
|
|
||||||
|
|
||||||
|
def tokenize(source):
|
||||||
|
return list(token_producer(source))
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptError(Exception):
|
||||||
|
""" General script handling error. """
|
||||||
|
|
||||||
|
|
||||||
|
class ParseError(ScriptError):
|
||||||
|
""" Script parsing error. """
|
||||||
|
|
||||||
|
|
||||||
|
class Parser:
|
||||||
|
|
||||||
|
def __init__(self, opcodes, tokens):
|
||||||
|
self.opcodes = opcodes
|
||||||
|
self.tokens = tokens
|
||||||
|
self.values = {}
|
||||||
|
self.token_index = 0
|
||||||
|
self.opcode_index = 0
|
||||||
|
|
||||||
|
def parse(self):
|
||||||
|
while self.token_index < len(self.tokens) and self.opcode_index < len(self.opcodes):
|
||||||
|
token = self.tokens[self.token_index]
|
||||||
|
opcode = self.opcodes[self.opcode_index]
|
||||||
|
if token.value == 0 and isinstance(opcode, PUSH_SINGLE):
|
||||||
|
token = DataToken(b'')
|
||||||
|
if isinstance(token, DataToken):
|
||||||
|
if isinstance(opcode, (PUSH_SINGLE, PUSH_INTEGER, PUSH_SUBSCRIPT)):
|
||||||
|
self.push_single(opcode, token.value)
|
||||||
|
elif isinstance(opcode, PUSH_MANY):
|
||||||
|
self.consume_many_non_greedy()
|
||||||
|
else:
|
||||||
|
raise ParseError("DataToken found but opcode was '{}'.".format(opcode))
|
||||||
|
elif isinstance(token, SmallIntegerToken):
|
||||||
|
if isinstance(opcode, SMALL_INTEGER):
|
||||||
|
self.values[opcode.name] = token.value
|
||||||
|
else:
|
||||||
|
raise ParseError("SmallIntegerToken found but opcode was '{}'.".format(opcode))
|
||||||
|
elif token.value == opcode:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise ParseError("Token is '{}' and opcode is '{}'.".format(token.value, opcode))
|
||||||
|
self.token_index += 1
|
||||||
|
self.opcode_index += 1
|
||||||
|
|
||||||
|
if self.token_index < len(self.tokens):
|
||||||
|
raise ParseError("Parse completed without all tokens being consumed.")
|
||||||
|
|
||||||
|
if self.opcode_index < len(self.opcodes):
|
||||||
|
raise ParseError("Parse completed without all opcodes being consumed.")
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def consume_many_non_greedy(self):
|
||||||
|
""" Allows PUSH_MANY to consume data without being greedy
|
||||||
|
in cases when one or more PUSH_SINGLEs follow a PUSH_MANY. This will
|
||||||
|
prioritize giving all PUSH_SINGLEs some data and only after that
|
||||||
|
subsume the rest into PUSH_MANY.
|
||||||
|
"""
|
||||||
|
|
||||||
|
token_values = []
|
||||||
|
while self.token_index < len(self.tokens):
|
||||||
|
token = self.tokens[self.token_index]
|
||||||
|
if not isinstance(token, DataToken):
|
||||||
|
self.token_index -= 1
|
||||||
|
break
|
||||||
|
token_values.append(token.value)
|
||||||
|
self.token_index += 1
|
||||||
|
|
||||||
|
push_opcodes = []
|
||||||
|
push_many_count = 0
|
||||||
|
while self.opcode_index < len(self.opcodes):
|
||||||
|
opcode = self.opcodes[self.opcode_index]
|
||||||
|
if not is_push_data_opcode(opcode):
|
||||||
|
self.opcode_index -= 1
|
||||||
|
break
|
||||||
|
if isinstance(opcode, PUSH_MANY):
|
||||||
|
push_many_count += 1
|
||||||
|
push_opcodes.append(opcode)
|
||||||
|
self.opcode_index += 1
|
||||||
|
|
||||||
|
if push_many_count > 1:
|
||||||
|
raise ParseError(
|
||||||
|
"Cannot have more than one consecutive PUSH_MANY, as there is no way to tell which"
|
||||||
|
" token value should go into which PUSH_MANY."
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(push_opcodes) > len(token_values):
|
||||||
|
raise ParseError(
|
||||||
|
"Not enough token values to match all of the PUSH_MANY and PUSH_SINGLE opcodes."
|
||||||
|
)
|
||||||
|
|
||||||
|
many_opcode = push_opcodes.pop(0)
|
||||||
|
|
||||||
|
# consume data into PUSH_SINGLE opcodes, working backwards
|
||||||
|
for opcode in reversed(push_opcodes):
|
||||||
|
self.push_single(opcode, token_values.pop())
|
||||||
|
|
||||||
|
# finally PUSH_MANY gets everything that's left
|
||||||
|
self.values[many_opcode.name] = token_values
|
||||||
|
|
||||||
|
def push_single(self, opcode, value):
|
||||||
|
if isinstance(opcode, PUSH_SINGLE):
|
||||||
|
self.values[opcode.name] = value
|
||||||
|
elif isinstance(opcode, PUSH_INTEGER):
|
||||||
|
self.values[opcode.name] = int.from_bytes(value, 'little')
|
||||||
|
elif isinstance(opcode, PUSH_SUBSCRIPT):
|
||||||
|
self.values[opcode.name] = Script.from_source_with_template(value, opcode.template)
|
||||||
|
else:
|
||||||
|
raise ParseError("Not a push single or subscript: {}".format(opcode))
|
||||||
|
|
||||||
|
|
||||||
|
class Template:
|
||||||
|
|
||||||
|
__slots__ = 'name', 'opcodes'
|
||||||
|
|
||||||
|
def __init__(self, name, opcodes):
|
||||||
|
self.name = name
|
||||||
|
self.opcodes = opcodes
|
||||||
|
|
||||||
|
def parse(self, tokens):
|
||||||
|
return Parser(self.opcodes, tokens).parse().values
|
||||||
|
|
||||||
|
def generate(self, values):
|
||||||
|
source = BCDataStream()
|
||||||
|
for opcode in self.opcodes:
|
||||||
|
if isinstance(opcode, PUSH_SINGLE):
|
||||||
|
data = values[opcode.name]
|
||||||
|
source.write_many(push_data(data))
|
||||||
|
elif isinstance(opcode, PUSH_INTEGER):
|
||||||
|
data = values[opcode.name]
|
||||||
|
source.write_many(push_data(
|
||||||
|
data.to_bytes((data.bit_length() + 7) // 8, byteorder='little')
|
||||||
|
))
|
||||||
|
elif isinstance(opcode, PUSH_SUBSCRIPT):
|
||||||
|
data = values[opcode.name]
|
||||||
|
source.write_many(push_data(data.source))
|
||||||
|
elif isinstance(opcode, PUSH_MANY):
|
||||||
|
for data in values[opcode.name]:
|
||||||
|
source.write_many(push_data(data))
|
||||||
|
elif isinstance(opcode, SMALL_INTEGER):
|
||||||
|
data = values[opcode.name]
|
||||||
|
source.write_many(push_small_integer(data))
|
||||||
|
else:
|
||||||
|
source.write_uint8(opcode)
|
||||||
|
return source.get_bytes()
|
||||||
|
|
||||||
|
|
||||||
|
class Script:
|
||||||
|
|
||||||
|
__slots__ = 'source', '_template', '_values', '_template_hint'
|
||||||
|
|
||||||
|
templates: List[Template] = []
|
||||||
|
|
||||||
|
def __init__(self, source=None, template=None, values=None, template_hint=None):
|
||||||
|
self.source = source
|
||||||
|
self._template = template
|
||||||
|
self._values = values
|
||||||
|
self._template_hint = template_hint
|
||||||
|
if source is None and template and values:
|
||||||
|
self.generate()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def template(self):
|
||||||
|
if self._template is None:
|
||||||
|
self.parse(self._template_hint)
|
||||||
|
return self._template
|
||||||
|
|
||||||
|
@property
|
||||||
|
def values(self):
|
||||||
|
if self._values is None:
|
||||||
|
self.parse(self._template_hint)
|
||||||
|
return self._values
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tokens(self):
|
||||||
|
return tokenize(BCDataStream(self.source))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_source_with_template(cls, source, template):
|
||||||
|
return cls(source, template_hint=template)
|
||||||
|
|
||||||
|
def parse(self, template_hint=None):
|
||||||
|
tokens = self.tokens
|
||||||
|
for template in chain((template_hint,), self.templates):
|
||||||
|
if not template:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
self._values = template.parse(tokens)
|
||||||
|
self._template = template
|
||||||
|
return
|
||||||
|
except ParseError:
|
||||||
|
continue
|
||||||
|
raise ValueError('No matching templates for source: {}'.format(hexlify(self.source)))
|
||||||
|
|
||||||
|
def generate(self):
|
||||||
|
self.source = self.template.generate(self._values)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseInputScript(Script):
|
||||||
|
""" Input / redeem script templates (aka scriptSig) """
|
||||||
|
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
|
REDEEM_PUBKEY = Template('pubkey', (
|
||||||
|
PUSH_SINGLE('signature'),
|
||||||
|
))
|
||||||
|
REDEEM_PUBKEY_HASH = Template('pubkey_hash', (
|
||||||
|
PUSH_SINGLE('signature'), PUSH_SINGLE('pubkey')
|
||||||
|
))
|
||||||
|
REDEEM_SCRIPT = Template('script', (
|
||||||
|
SMALL_INTEGER('signatures_count'), PUSH_MANY('pubkeys'), SMALL_INTEGER('pubkeys_count'),
|
||||||
|
OP_CHECKMULTISIG
|
||||||
|
))
|
||||||
|
REDEEM_SCRIPT_HASH = Template('script_hash', (
|
||||||
|
OP_0, PUSH_MANY('signatures'), PUSH_SUBSCRIPT('script', REDEEM_SCRIPT)
|
||||||
|
))
|
||||||
|
|
||||||
|
templates = [
|
||||||
|
REDEEM_PUBKEY,
|
||||||
|
REDEEM_PUBKEY_HASH,
|
||||||
|
REDEEM_SCRIPT_HASH,
|
||||||
|
REDEEM_SCRIPT
|
||||||
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def redeem_pubkey_hash(cls, signature, pubkey):
|
||||||
|
return cls(template=cls.REDEEM_PUBKEY_HASH, values={
|
||||||
|
'signature': signature,
|
||||||
|
'pubkey': pubkey
|
||||||
|
})
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def redeem_script_hash(cls, signatures, pubkeys):
|
||||||
|
return cls(template=cls.REDEEM_SCRIPT_HASH, values={
|
||||||
|
'signatures': signatures,
|
||||||
|
'script': cls.redeem_script(signatures, pubkeys)
|
||||||
|
})
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def redeem_script(cls, signatures, pubkeys):
|
||||||
|
return cls(template=cls.REDEEM_SCRIPT, values={
|
||||||
|
'signatures_count': len(signatures),
|
||||||
|
'pubkeys': pubkeys,
|
||||||
|
'pubkeys_count': len(pubkeys)
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
class BaseOutputScript(Script):
|
||||||
|
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
|
# output / payment script templates (aka scriptPubKey)
|
||||||
|
PAY_PUBKEY_FULL = Template('pay_pubkey_full', (
|
||||||
|
PUSH_SINGLE('pubkey'), OP_CHECKSIG
|
||||||
|
))
|
||||||
|
PAY_PUBKEY_HASH = Template('pay_pubkey_hash', (
|
||||||
|
OP_DUP, OP_HASH160, PUSH_SINGLE('pubkey_hash'), OP_EQUALVERIFY, OP_CHECKSIG
|
||||||
|
))
|
||||||
|
PAY_SCRIPT_HASH = Template('pay_script_hash', (
|
||||||
|
OP_HASH160, PUSH_SINGLE('script_hash'), OP_EQUAL
|
||||||
|
))
|
||||||
|
RETURN_DATA = Template('return_data', (
|
||||||
|
OP_RETURN, PUSH_SINGLE('data')
|
||||||
|
))
|
||||||
|
|
||||||
|
templates = [
|
||||||
|
PAY_PUBKEY_FULL,
|
||||||
|
PAY_PUBKEY_HASH,
|
||||||
|
PAY_SCRIPT_HASH,
|
||||||
|
RETURN_DATA
|
||||||
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pay_pubkey_hash(cls, pubkey_hash):
|
||||||
|
return cls(template=cls.PAY_PUBKEY_HASH, values={
|
||||||
|
'pubkey_hash': pubkey_hash
|
||||||
|
})
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pay_script_hash(cls, script_hash):
|
||||||
|
return cls(template=cls.PAY_SCRIPT_HASH, values={
|
||||||
|
'script_hash': script_hash
|
||||||
|
})
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_pay_pubkey(self):
|
||||||
|
return self.template.name.endswith('pay_pubkey_full')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_pay_pubkey_hash(self):
|
||||||
|
return self.template.name.endswith('pay_pubkey_hash')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_pay_script_hash(self):
|
||||||
|
return self.template.name.endswith('pay_script_hash')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_return_data(self):
|
||||||
|
return self.template.name.endswith('return_data')
|
541
torba/torba/client/basetransaction.py
Normal file
541
torba/torba/client/basetransaction.py
Normal file
|
@ -0,0 +1,541 @@
|
||||||
|
import logging
|
||||||
|
import typing
|
||||||
|
from typing import List, Iterable, Optional
|
||||||
|
from binascii import hexlify
|
||||||
|
|
||||||
|
from torba.client.basescript import BaseInputScript, BaseOutputScript
|
||||||
|
from torba.client.baseaccount import BaseAccount
|
||||||
|
from torba.client.constants import COIN, NULL_HASH32
|
||||||
|
from torba.client.bcd_data_stream import BCDataStream
|
||||||
|
from torba.client.hash import sha256, TXRef, TXRefImmutable
|
||||||
|
from torba.client.util import ReadOnlyList
|
||||||
|
from torba.client.errors import InsufficientFundsError
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from torba.client import baseledger
|
||||||
|
|
||||||
|
log = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
class TXRefMutable(TXRef):
|
||||||
|
|
||||||
|
__slots__ = ('tx',)
|
||||||
|
|
||||||
|
def __init__(self, tx: 'BaseTransaction') -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.tx = tx
|
||||||
|
|
||||||
|
@property
|
||||||
|
def id(self):
|
||||||
|
if self._id is None:
|
||||||
|
self._id = hexlify(self.hash[::-1]).decode()
|
||||||
|
return self._id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hash(self):
|
||||||
|
if self._hash is None:
|
||||||
|
self._hash = sha256(sha256(self.tx.raw))
|
||||||
|
return self._hash
|
||||||
|
|
||||||
|
@property
|
||||||
|
def height(self):
|
||||||
|
return self.tx.height
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._id = None
|
||||||
|
self._hash = None
|
||||||
|
|
||||||
|
|
||||||
|
class TXORef:
|
||||||
|
|
||||||
|
__slots__ = 'tx_ref', 'position'
|
||||||
|
|
||||||
|
def __init__(self, tx_ref: TXRef, position: int) -> None:
|
||||||
|
self.tx_ref = tx_ref
|
||||||
|
self.position = position
|
||||||
|
|
||||||
|
@property
|
||||||
|
def id(self):
|
||||||
|
return '{}:{}'.format(self.tx_ref.id, self.position)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hash(self):
|
||||||
|
return self.tx_ref.hash + BCDataStream.uint32.pack(self.position)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_null(self):
|
||||||
|
return self.tx_ref.is_null
|
||||||
|
|
||||||
|
@property
|
||||||
|
def txo(self) -> Optional['BaseOutput']:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class TXORefResolvable(TXORef):
|
||||||
|
|
||||||
|
__slots__ = ('_txo',)
|
||||||
|
|
||||||
|
def __init__(self, txo: 'BaseOutput') -> None:
|
||||||
|
assert txo.tx_ref is not None
|
||||||
|
assert txo.position is not None
|
||||||
|
super().__init__(txo.tx_ref, txo.position)
|
||||||
|
self._txo = txo
|
||||||
|
|
||||||
|
@property
|
||||||
|
def txo(self):
|
||||||
|
return self._txo
|
||||||
|
|
||||||
|
|
||||||
|
class InputOutput:
|
||||||
|
|
||||||
|
__slots__ = 'tx_ref', 'position'
|
||||||
|
|
||||||
|
def __init__(self, tx_ref: TXRef = None, position: int = None) -> None:
|
||||||
|
self.tx_ref = tx_ref
|
||||||
|
self.position = position
|
||||||
|
|
||||||
|
@property
|
||||||
|
def size(self) -> int:
|
||||||
|
""" Size of this input / output in bytes. """
|
||||||
|
stream = BCDataStream()
|
||||||
|
self.serialize_to(stream)
|
||||||
|
return len(stream.get_bytes())
|
||||||
|
|
||||||
|
def get_fee(self, ledger):
|
||||||
|
return self.size * ledger.fee_per_byte
|
||||||
|
|
||||||
|
def serialize_to(self, stream, alternate_script=None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class BaseInput(InputOutput):
|
||||||
|
|
||||||
|
script_class = BaseInputScript
|
||||||
|
|
||||||
|
NULL_SIGNATURE = b'\x00'*72
|
||||||
|
NULL_PUBLIC_KEY = b'\x00'*33
|
||||||
|
|
||||||
|
__slots__ = 'txo_ref', 'sequence', 'coinbase', 'script'
|
||||||
|
|
||||||
|
def __init__(self, txo_ref: TXORef, script: BaseInputScript, sequence: int = 0xFFFFFFFF,
|
||||||
|
tx_ref: TXRef = None, position: int = None) -> None:
|
||||||
|
super().__init__(tx_ref, position)
|
||||||
|
self.txo_ref = txo_ref
|
||||||
|
self.sequence = sequence
|
||||||
|
self.coinbase = script if txo_ref.is_null else None
|
||||||
|
self.script = script if not txo_ref.is_null else None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_coinbase(self):
|
||||||
|
return self.coinbase is not None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def spend(cls, txo: 'BaseOutput') -> 'BaseInput':
|
||||||
|
""" Create an input to spend the output."""
|
||||||
|
assert txo.script.is_pay_pubkey_hash, 'Attempting to spend unsupported output.'
|
||||||
|
script = cls.script_class.redeem_pubkey_hash(cls.NULL_SIGNATURE, cls.NULL_PUBLIC_KEY)
|
||||||
|
return cls(txo.ref, script)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def amount(self) -> int:
|
||||||
|
""" Amount this input adds to the transaction. """
|
||||||
|
if self.txo_ref.txo is None:
|
||||||
|
raise ValueError('Cannot resolve output to get amount.')
|
||||||
|
return self.txo_ref.txo.amount
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_my_account(self) -> Optional[bool]:
|
||||||
|
""" True if the output this input spends is yours. """
|
||||||
|
if self.txo_ref.txo is None:
|
||||||
|
return False
|
||||||
|
return self.txo_ref.txo.is_my_account
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def deserialize_from(cls, stream):
|
||||||
|
tx_ref = TXRefImmutable.from_hash(stream.read(32), -1)
|
||||||
|
position = stream.read_uint32()
|
||||||
|
script = stream.read_string()
|
||||||
|
sequence = stream.read_uint32()
|
||||||
|
return cls(
|
||||||
|
TXORef(tx_ref, position),
|
||||||
|
cls.script_class(script) if not tx_ref.is_null else script,
|
||||||
|
sequence
|
||||||
|
)
|
||||||
|
|
||||||
|
def serialize_to(self, stream, alternate_script=None):
|
||||||
|
stream.write(self.txo_ref.tx_ref.hash)
|
||||||
|
stream.write_uint32(self.txo_ref.position)
|
||||||
|
if alternate_script is not None:
|
||||||
|
stream.write_string(alternate_script)
|
||||||
|
else:
|
||||||
|
if self.is_coinbase:
|
||||||
|
stream.write_string(self.coinbase)
|
||||||
|
else:
|
||||||
|
stream.write_string(self.script.source)
|
||||||
|
stream.write_uint32(self.sequence)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseOutputEffectiveAmountEstimator:
|
||||||
|
|
||||||
|
__slots__ = 'txo', 'txi', 'fee', 'effective_amount'
|
||||||
|
|
||||||
|
def __init__(self, ledger: 'baseledger.BaseLedger', txo: 'BaseOutput') -> None:
|
||||||
|
self.txo = txo
|
||||||
|
self.txi = ledger.transaction_class.input_class.spend(txo)
|
||||||
|
self.fee: int = self.txi.get_fee(ledger)
|
||||||
|
self.effective_amount: int = txo.amount - self.fee
|
||||||
|
|
||||||
|
def __lt__(self, other):
|
||||||
|
return self.effective_amount < other.effective_amount
|
||||||
|
|
||||||
|
|
||||||
|
class BaseOutput(InputOutput):
|
||||||
|
|
||||||
|
script_class = BaseOutputScript
|
||||||
|
estimator_class = BaseOutputEffectiveAmountEstimator
|
||||||
|
|
||||||
|
__slots__ = 'amount', 'script', 'is_change', 'is_my_account'
|
||||||
|
|
||||||
|
def __init__(self, amount: int, script: BaseOutputScript,
|
||||||
|
tx_ref: TXRef = None, position: int = None,
|
||||||
|
is_change: Optional[bool] = None, is_my_account: Optional[bool] = None
|
||||||
|
) -> None:
|
||||||
|
super().__init__(tx_ref, position)
|
||||||
|
self.amount = amount
|
||||||
|
self.script = script
|
||||||
|
self.is_change = is_change
|
||||||
|
self.is_my_account = is_my_account
|
||||||
|
|
||||||
|
def update_annotations(self, annotated):
|
||||||
|
if annotated is None:
|
||||||
|
self.is_change = False
|
||||||
|
self.is_my_account = False
|
||||||
|
else:
|
||||||
|
self.is_change = annotated.is_change
|
||||||
|
self.is_my_account = annotated.is_my_account
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ref(self):
|
||||||
|
return TXORefResolvable(self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def id(self):
|
||||||
|
return self.ref.id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pubkey_hash(self):
|
||||||
|
return self.script.values['pubkey_hash']
|
||||||
|
|
||||||
|
def get_address(self, ledger):
|
||||||
|
return ledger.hash160_to_address(self.pubkey_hash)
|
||||||
|
|
||||||
|
def get_estimator(self, ledger):
|
||||||
|
return self.estimator_class(ledger, self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pay_pubkey_hash(cls, amount, pubkey_hash):
|
||||||
|
return cls(amount, cls.script_class.pay_pubkey_hash(pubkey_hash))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def deserialize_from(cls, stream):
|
||||||
|
return cls(
|
||||||
|
amount=stream.read_uint64(),
|
||||||
|
script=cls.script_class(stream.read_string())
|
||||||
|
)
|
||||||
|
|
||||||
|
def serialize_to(self, stream, alternate_script=None):
|
||||||
|
stream.write_uint64(self.amount)
|
||||||
|
stream.write_string(self.script.source)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTransaction:
|
||||||
|
|
||||||
|
input_class = BaseInput
|
||||||
|
output_class = BaseOutput
|
||||||
|
|
||||||
|
def __init__(self, raw=None, version: int = 1, locktime: int = 0, is_verified: bool = False,
|
||||||
|
height: int = -2, position: int = -1) -> None:
|
||||||
|
self._raw = raw
|
||||||
|
self.ref = TXRefMutable(self)
|
||||||
|
self.version = version
|
||||||
|
self.locktime = locktime
|
||||||
|
self._inputs: List[BaseInput] = []
|
||||||
|
self._outputs: List[BaseOutput] = []
|
||||||
|
self.is_verified = is_verified
|
||||||
|
# Height Progression
|
||||||
|
# -2: not broadcast
|
||||||
|
# -1: in mempool but has unconfirmed inputs
|
||||||
|
# 0: in mempool and all inputs confirmed
|
||||||
|
# +num: confirmed in a specific block (height)
|
||||||
|
self.height = height
|
||||||
|
self.position = position
|
||||||
|
if raw is not None:
|
||||||
|
self._deserialize()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_broadcast(self):
|
||||||
|
return self.height > -2
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_mempool(self):
|
||||||
|
return self.height in (-1, 0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_confirmed(self):
|
||||||
|
return self.height > 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def id(self):
|
||||||
|
return self.ref.id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hash(self):
|
||||||
|
return self.ref.hash
|
||||||
|
|
||||||
|
@property
|
||||||
|
def raw(self):
|
||||||
|
if self._raw is None:
|
||||||
|
self._raw = self._serialize()
|
||||||
|
return self._raw
|
||||||
|
|
||||||
|
def _reset(self):
|
||||||
|
self._raw = None
|
||||||
|
self.ref.reset()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> ReadOnlyList[BaseInput]:
|
||||||
|
return ReadOnlyList(self._inputs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def outputs(self) -> ReadOnlyList[BaseOutput]:
|
||||||
|
return ReadOnlyList(self._outputs)
|
||||||
|
|
||||||
|
def _add(self, new_ios: Iterable[InputOutput], existing_ios: List) -> 'BaseTransaction':
|
||||||
|
for txio in new_ios:
|
||||||
|
txio.tx_ref = self.ref
|
||||||
|
txio.position = len(existing_ios)
|
||||||
|
existing_ios.append(txio)
|
||||||
|
self._reset()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def add_inputs(self, inputs: Iterable[BaseInput]) -> 'BaseTransaction':
|
||||||
|
return self._add(inputs, self._inputs)
|
||||||
|
|
||||||
|
def add_outputs(self, outputs: Iterable[BaseOutput]) -> 'BaseTransaction':
|
||||||
|
return self._add(outputs, self._outputs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def size(self) -> int:
|
||||||
|
""" Size in bytes of the entire transaction. """
|
||||||
|
return len(self.raw)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def base_size(self) -> int:
|
||||||
|
""" Size of transaction without inputs or outputs in bytes. """
|
||||||
|
return (
|
||||||
|
self.size
|
||||||
|
- sum(txi.size for txi in self._inputs)
|
||||||
|
- sum(txo.size for txo in self._outputs)
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_sum(self):
|
||||||
|
return sum(i.amount for i in self.inputs if i.txo_ref.txo is not None)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_sum(self):
|
||||||
|
return sum(o.amount for o in self.outputs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def net_account_balance(self) -> int:
|
||||||
|
balance = 0
|
||||||
|
for txi in self.inputs:
|
||||||
|
if txi.txo_ref.txo is None:
|
||||||
|
continue
|
||||||
|
if txi.is_my_account is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot access net_account_balance if inputs/outputs do not "
|
||||||
|
"have is_my_account set properly."
|
||||||
|
)
|
||||||
|
if txi.is_my_account:
|
||||||
|
balance -= txi.amount
|
||||||
|
for txo in self.outputs:
|
||||||
|
if txo.is_my_account is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot access net_account_balance if inputs/outputs do not "
|
||||||
|
"have is_my_account set properly."
|
||||||
|
)
|
||||||
|
if txo.is_my_account:
|
||||||
|
balance += txo.amount
|
||||||
|
return balance
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fee(self) -> int:
|
||||||
|
return self.input_sum - self.output_sum
|
||||||
|
|
||||||
|
def get_base_fee(self, ledger) -> int:
|
||||||
|
""" Fee for base tx excluding inputs and outputs. """
|
||||||
|
return self.base_size * ledger.fee_per_byte
|
||||||
|
|
||||||
|
def get_effective_input_sum(self, ledger) -> int:
|
||||||
|
""" Sum of input values *minus* the cost involved to spend them. """
|
||||||
|
return sum(txi.amount - txi.get_fee(ledger) for txi in self._inputs)
|
||||||
|
|
||||||
|
def get_total_output_sum(self, ledger) -> int:
|
||||||
|
""" Sum of output values *plus* the cost involved to spend them. """
|
||||||
|
return sum(txo.amount + txo.get_fee(ledger) for txo in self._outputs)
|
||||||
|
|
||||||
|
def _serialize(self, with_inputs: bool = True) -> bytes:
|
||||||
|
stream = BCDataStream()
|
||||||
|
stream.write_uint32(self.version)
|
||||||
|
if with_inputs:
|
||||||
|
stream.write_compact_size(len(self._inputs))
|
||||||
|
for txin in self._inputs:
|
||||||
|
txin.serialize_to(stream)
|
||||||
|
stream.write_compact_size(len(self._outputs))
|
||||||
|
for txout in self._outputs:
|
||||||
|
txout.serialize_to(stream)
|
||||||
|
stream.write_uint32(self.locktime)
|
||||||
|
return stream.get_bytes()
|
||||||
|
|
||||||
|
def _serialize_for_signature(self, signing_input: int) -> bytes:
|
||||||
|
stream = BCDataStream()
|
||||||
|
stream.write_uint32(self.version)
|
||||||
|
stream.write_compact_size(len(self._inputs))
|
||||||
|
for i, txin in enumerate(self._inputs):
|
||||||
|
if signing_input == i:
|
||||||
|
assert txin.txo_ref.txo is not None
|
||||||
|
txin.serialize_to(stream, txin.txo_ref.txo.script.source)
|
||||||
|
else:
|
||||||
|
txin.serialize_to(stream, b'')
|
||||||
|
stream.write_compact_size(len(self._outputs))
|
||||||
|
for txout in self._outputs:
|
||||||
|
txout.serialize_to(stream)
|
||||||
|
stream.write_uint32(self.locktime)
|
||||||
|
stream.write_uint32(self.signature_hash_type(1)) # signature hash type: SIGHASH_ALL
|
||||||
|
return stream.get_bytes()
|
||||||
|
|
||||||
|
def _deserialize(self):
|
||||||
|
if self._raw is not None:
|
||||||
|
stream = BCDataStream(self._raw)
|
||||||
|
self.version = stream.read_uint32()
|
||||||
|
input_count = stream.read_compact_size()
|
||||||
|
self.add_inputs([
|
||||||
|
self.input_class.deserialize_from(stream) for _ in range(input_count)
|
||||||
|
])
|
||||||
|
output_count = stream.read_compact_size()
|
||||||
|
self.add_outputs([
|
||||||
|
self.output_class.deserialize_from(stream) for _ in range(output_count)
|
||||||
|
])
|
||||||
|
self.locktime = stream.read_uint32()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def ensure_all_have_same_ledger(cls, funding_accounts: Iterable[BaseAccount],
|
||||||
|
change_account: BaseAccount = None) -> 'baseledger.BaseLedger':
|
||||||
|
ledger = None
|
||||||
|
for account in funding_accounts:
|
||||||
|
if ledger is None:
|
||||||
|
ledger = account.ledger
|
||||||
|
if ledger != account.ledger:
|
||||||
|
raise ValueError(
|
||||||
|
'All funding accounts used to create a transaction must be on the same ledger.'
|
||||||
|
)
|
||||||
|
if change_account is not None and change_account.ledger != ledger:
|
||||||
|
raise ValueError('Change account must use same ledger as funding accounts.')
|
||||||
|
if ledger is None:
|
||||||
|
raise ValueError('No ledger found.')
|
||||||
|
return ledger
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create(cls, inputs: Iterable[BaseInput], outputs: Iterable[BaseOutput],
|
||||||
|
funding_accounts: Iterable[BaseAccount], change_account: BaseAccount,
|
||||||
|
sign: bool = True):
|
||||||
|
""" Find optimal set of inputs when only outputs are provided; add change
|
||||||
|
outputs if only inputs are provided or if inputs are greater than outputs. """
|
||||||
|
|
||||||
|
tx = cls() \
|
||||||
|
.add_inputs(inputs) \
|
||||||
|
.add_outputs(outputs)
|
||||||
|
|
||||||
|
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account)
|
||||||
|
|
||||||
|
# value of the outputs plus associated fees
|
||||||
|
cost = (
|
||||||
|
tx.get_base_fee(ledger) +
|
||||||
|
tx.get_total_output_sum(ledger)
|
||||||
|
)
|
||||||
|
# value of the inputs less the cost to spend those inputs
|
||||||
|
payment = tx.get_effective_input_sum(ledger)
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
for _ in range(5):
|
||||||
|
|
||||||
|
if payment < cost:
|
||||||
|
deficit = cost - payment
|
||||||
|
spendables = await ledger.get_spendable_utxos(deficit, funding_accounts)
|
||||||
|
if not spendables:
|
||||||
|
raise InsufficientFundsError('Not enough funds to cover this transaction.')
|
||||||
|
payment += sum(s.effective_amount for s in spendables)
|
||||||
|
tx.add_inputs(s.txi for s in spendables)
|
||||||
|
|
||||||
|
cost_of_change = (
|
||||||
|
tx.get_base_fee(ledger) +
|
||||||
|
cls.output_class.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(ledger)
|
||||||
|
)
|
||||||
|
if payment > cost:
|
||||||
|
change = payment - cost
|
||||||
|
if change > cost_of_change:
|
||||||
|
change_address = await change_account.change.get_or_create_usable_address()
|
||||||
|
change_hash160 = change_account.ledger.address_to_hash160(change_address)
|
||||||
|
change_amount = change - cost_of_change
|
||||||
|
change_output = cls.output_class.pay_pubkey_hash(change_amount, change_hash160)
|
||||||
|
change_output.is_change = True
|
||||||
|
tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)])
|
||||||
|
|
||||||
|
if tx._outputs:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# this condition and the outer range(5) loop cover an edge case
|
||||||
|
# whereby a single input is just enough to cover the fee and
|
||||||
|
# has some change left over, but the change left over is less
|
||||||
|
# than the cost_of_change: thus the input is completely
|
||||||
|
# consumed and no output is added, which is an invalid tx.
|
||||||
|
# to be able to spend this input we must increase the cost
|
||||||
|
# of the TX and run through the balance algorithm a second time
|
||||||
|
# adding an extra input and change output, making tx valid.
|
||||||
|
# we do this 5 times in case the other UTXOs added are also
|
||||||
|
# less than the fee, after 5 attempts we give up and go home
|
||||||
|
cost += cost_of_change + 1
|
||||||
|
|
||||||
|
if sign:
|
||||||
|
await tx.sign(funding_accounts)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.exception('Failed to create transaction:')
|
||||||
|
await ledger.release_tx(tx)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
return tx
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def signature_hash_type(hash_type):
|
||||||
|
return hash_type
|
||||||
|
|
||||||
|
async def sign(self, funding_accounts: Iterable[BaseAccount]):
|
||||||
|
ledger = self.ensure_all_have_same_ledger(funding_accounts)
|
||||||
|
for i, txi in enumerate(self._inputs):
|
||||||
|
assert txi.script is not None
|
||||||
|
assert txi.txo_ref.txo is not None
|
||||||
|
txo_script = txi.txo_ref.txo.script
|
||||||
|
if txo_script.is_pay_pubkey_hash:
|
||||||
|
address = ledger.hash160_to_address(txo_script.values['pubkey_hash'])
|
||||||
|
private_key = await ledger.get_private_key_for_address(address)
|
||||||
|
tx = self._serialize_for_signature(i)
|
||||||
|
txi.script.values['signature'] = \
|
||||||
|
private_key.sign(tx) + bytes((self.signature_hash_type(1),))
|
||||||
|
txi.script.values['pubkey'] = private_key.public_key.pubkey_bytes
|
||||||
|
txi.script.generate()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Don't know how to spend this output.")
|
||||||
|
self._reset()
|
122
torba/torba/client/bcd_data_stream.py
Normal file
122
torba/torba/client/bcd_data_stream.py
Normal file
|
@ -0,0 +1,122 @@
|
||||||
|
import struct
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
|
||||||
|
class BCDataStream:
|
||||||
|
|
||||||
|
def __init__(self, data=None):
|
||||||
|
self.data = BytesIO(data)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.data.seek(0)
|
||||||
|
|
||||||
|
def get_bytes(self):
|
||||||
|
return self.data.getvalue()
|
||||||
|
|
||||||
|
def read(self, size):
|
||||||
|
return self.data.read(size)
|
||||||
|
|
||||||
|
def write(self, data):
|
||||||
|
self.data.write(data)
|
||||||
|
|
||||||
|
def write_many(self, many):
|
||||||
|
self.data.writelines(many)
|
||||||
|
|
||||||
|
def read_string(self):
|
||||||
|
return self.read(self.read_compact_size())
|
||||||
|
|
||||||
|
def write_string(self, s):
|
||||||
|
self.write_compact_size(len(s))
|
||||||
|
self.write(s)
|
||||||
|
|
||||||
|
def read_compact_size(self):
|
||||||
|
size = self.read_uint8()
|
||||||
|
if size < 253:
|
||||||
|
return size
|
||||||
|
if size == 253:
|
||||||
|
return self.read_uint16()
|
||||||
|
if size == 254:
|
||||||
|
return self.read_uint32()
|
||||||
|
if size == 255:
|
||||||
|
return self.read_uint64()
|
||||||
|
|
||||||
|
def write_compact_size(self, size):
|
||||||
|
if size < 253:
|
||||||
|
self.write_uint8(size)
|
||||||
|
elif size <= 0xFFFF:
|
||||||
|
self.write_uint8(253)
|
||||||
|
self.write_uint16(size)
|
||||||
|
elif size <= 0xFFFFFFFF:
|
||||||
|
self.write_uint8(254)
|
||||||
|
self.write_uint32(size)
|
||||||
|
else:
|
||||||
|
self.write_uint8(255)
|
||||||
|
self.write_uint64(size)
|
||||||
|
|
||||||
|
def read_boolean(self):
|
||||||
|
return self.read_uint8() != 0
|
||||||
|
|
||||||
|
def write_boolean(self, val):
|
||||||
|
return self.write_uint8(1 if val else 0)
|
||||||
|
|
||||||
|
int8 = struct.Struct('b')
|
||||||
|
uint8 = struct.Struct('B')
|
||||||
|
int16 = struct.Struct('<h')
|
||||||
|
uint16 = struct.Struct('<H')
|
||||||
|
int32 = struct.Struct('<i')
|
||||||
|
uint32 = struct.Struct('<I')
|
||||||
|
int64 = struct.Struct('<q')
|
||||||
|
uint64 = struct.Struct('<Q')
|
||||||
|
|
||||||
|
def _read_struct(self, fmt):
|
||||||
|
value = self.read(fmt.size)
|
||||||
|
if value:
|
||||||
|
return fmt.unpack(value)[0]
|
||||||
|
|
||||||
|
def read_int8(self):
|
||||||
|
return self._read_struct(self.int8)
|
||||||
|
|
||||||
|
def read_uint8(self):
|
||||||
|
return self._read_struct(self.uint8)
|
||||||
|
|
||||||
|
def read_int16(self):
|
||||||
|
return self._read_struct(self.int16)
|
||||||
|
|
||||||
|
def read_uint16(self):
|
||||||
|
return self._read_struct(self.uint16)
|
||||||
|
|
||||||
|
def read_int32(self):
|
||||||
|
return self._read_struct(self.int32)
|
||||||
|
|
||||||
|
def read_uint32(self):
|
||||||
|
return self._read_struct(self.uint32)
|
||||||
|
|
||||||
|
def read_int64(self):
|
||||||
|
return self._read_struct(self.int64)
|
||||||
|
|
||||||
|
def read_uint64(self):
|
||||||
|
return self._read_struct(self.uint64)
|
||||||
|
|
||||||
|
def write_int8(self, val):
|
||||||
|
self.write(self.int8.pack(val))
|
||||||
|
|
||||||
|
def write_uint8(self, val):
|
||||||
|
self.write(self.uint8.pack(val))
|
||||||
|
|
||||||
|
def write_int16(self, val):
|
||||||
|
self.write(self.int16.pack(val))
|
||||||
|
|
||||||
|
def write_uint16(self, val):
|
||||||
|
self.write(self.uint16.pack(val))
|
||||||
|
|
||||||
|
def write_int32(self, val):
|
||||||
|
self.write(self.int32.pack(val))
|
||||||
|
|
||||||
|
def write_uint32(self, val):
|
||||||
|
self.write(self.uint32.pack(val))
|
||||||
|
|
||||||
|
def write_int64(self, val):
|
||||||
|
self.write(self.int64.pack(val))
|
||||||
|
|
||||||
|
def write_uint64(self, val):
|
||||||
|
self.write(self.uint64.pack(val))
|
261
torba/torba/client/bip32.py
Normal file
261
torba/torba/client/bip32.py
Normal file
|
@ -0,0 +1,261 @@
|
||||||
|
# Copyright (c) 2017, Neil Booth
|
||||||
|
# Copyright (c) 2018, LBRY Inc.
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# See the file "LICENCE" for information about the copyright
|
||||||
|
# and warranty status of this software.
|
||||||
|
|
||||||
|
""" Logic for BIP32 Hierarchical Key Derivation. """
|
||||||
|
from coincurve import PublicKey, PrivateKey as _PrivateKey
|
||||||
|
|
||||||
|
from torba.client.hash import Base58, hmac_sha512, hash160, double_sha256
|
||||||
|
from torba.client.util import cachedproperty
|
||||||
|
|
||||||
|
|
||||||
|
class DerivationError(Exception):
|
||||||
|
""" Raised when an invalid derivation occurs. """
|
||||||
|
|
||||||
|
|
||||||
|
class _KeyBase:
|
||||||
|
""" A BIP32 Key, public or private. """
|
||||||
|
|
||||||
|
def __init__(self, ledger, chain_code, n, depth, parent):
|
||||||
|
if not isinstance(chain_code, (bytes, bytearray)):
|
||||||
|
raise TypeError('chain code must be raw bytes')
|
||||||
|
if len(chain_code) != 32:
|
||||||
|
raise ValueError('invalid chain code')
|
||||||
|
if not 0 <= n < 1 << 32:
|
||||||
|
raise ValueError('invalid child number')
|
||||||
|
if not 0 <= depth < 256:
|
||||||
|
raise ValueError('invalid depth')
|
||||||
|
if parent is not None:
|
||||||
|
if not isinstance(parent, type(self)):
|
||||||
|
raise TypeError('parent key has bad type')
|
||||||
|
self.ledger = ledger
|
||||||
|
self.chain_code = chain_code
|
||||||
|
self.n = n
|
||||||
|
self.depth = depth
|
||||||
|
self.parent = parent
|
||||||
|
|
||||||
|
def _hmac_sha512(self, msg):
|
||||||
|
""" Use SHA-512 to provide an HMAC, returned as a pair of 32-byte objects. """
|
||||||
|
hmac = hmac_sha512(self.chain_code, msg)
|
||||||
|
return hmac[:32], hmac[32:]
|
||||||
|
|
||||||
|
def _extended_key(self, ver_bytes, raw_serkey):
|
||||||
|
""" Return the 78-byte extended key given prefix version bytes and serialized key bytes. """
|
||||||
|
if not isinstance(ver_bytes, (bytes, bytearray)):
|
||||||
|
raise TypeError('ver_bytes must be raw bytes')
|
||||||
|
if len(ver_bytes) != 4:
|
||||||
|
raise ValueError('ver_bytes must have length 4')
|
||||||
|
if not isinstance(raw_serkey, (bytes, bytearray)):
|
||||||
|
raise TypeError('raw_serkey must be raw bytes')
|
||||||
|
if len(raw_serkey) != 33:
|
||||||
|
raise ValueError('raw_serkey must have length 33')
|
||||||
|
|
||||||
|
return (ver_bytes + bytes((self.depth,))
|
||||||
|
+ self.parent_fingerprint() + self.n.to_bytes(4, 'big')
|
||||||
|
+ self.chain_code + raw_serkey)
|
||||||
|
|
||||||
|
def identifier(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def extended_key(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def fingerprint(self):
|
||||||
|
""" Return the key's fingerprint as 4 bytes. """
|
||||||
|
return self.identifier()[:4]
|
||||||
|
|
||||||
|
def parent_fingerprint(self):
|
||||||
|
""" Return the parent key's fingerprint as 4 bytes. """
|
||||||
|
return self.parent.fingerprint() if self.parent else bytes((0,)*4)
|
||||||
|
|
||||||
|
def extended_key_string(self):
|
||||||
|
""" Return an extended key as a base58 string. """
|
||||||
|
return Base58.encode_check(self.extended_key())
|
||||||
|
|
||||||
|
|
||||||
|
class PubKey(_KeyBase):
|
||||||
|
""" A BIP32 public key. """
|
||||||
|
|
||||||
|
def __init__(self, ledger, pubkey, chain_code, n, depth, parent=None):
|
||||||
|
super().__init__(ledger, chain_code, n, depth, parent)
|
||||||
|
if isinstance(pubkey, PublicKey):
|
||||||
|
self.verifying_key = pubkey
|
||||||
|
else:
|
||||||
|
self.verifying_key = self._verifying_key_from_pubkey(pubkey)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _verifying_key_from_pubkey(cls, pubkey):
|
||||||
|
""" Converts a 33-byte compressed pubkey into an PublicKey object. """
|
||||||
|
if not isinstance(pubkey, (bytes, bytearray)):
|
||||||
|
raise TypeError('pubkey must be raw bytes')
|
||||||
|
if len(pubkey) != 33:
|
||||||
|
raise ValueError('pubkey must be 33 bytes')
|
||||||
|
if pubkey[0] not in (2, 3):
|
||||||
|
raise ValueError('invalid pubkey prefix byte')
|
||||||
|
return PublicKey(pubkey)
|
||||||
|
|
||||||
|
@cachedproperty
|
||||||
|
def pubkey_bytes(self):
|
||||||
|
""" Return the compressed public key as 33 bytes. """
|
||||||
|
return self.verifying_key.format(True)
|
||||||
|
|
||||||
|
@cachedproperty
|
||||||
|
def address(self):
|
||||||
|
""" The public key as a P2PKH address. """
|
||||||
|
return self.ledger.public_key_to_address(self.pubkey_bytes)
|
||||||
|
|
||||||
|
def ec_point(self):
|
||||||
|
return self.verifying_key.point()
|
||||||
|
|
||||||
|
def child(self, n: int):
|
||||||
|
""" Return the derived child extended pubkey at index N. """
|
||||||
|
if not 0 <= n < (1 << 31):
|
||||||
|
raise ValueError('invalid BIP32 public key child number')
|
||||||
|
|
||||||
|
msg = self.pubkey_bytes + n.to_bytes(4, 'big')
|
||||||
|
L_b, R_b = self._hmac_sha512(msg) # pylint: disable=invalid-name
|
||||||
|
derived_key = self.verifying_key.add(L_b)
|
||||||
|
return PubKey(self.ledger, derived_key, R_b, n, self.depth + 1, self)
|
||||||
|
|
||||||
|
def identifier(self):
|
||||||
|
""" Return the key's identifier as 20 bytes. """
|
||||||
|
return hash160(self.pubkey_bytes)
|
||||||
|
|
||||||
|
def extended_key(self):
|
||||||
|
""" Return a raw extended public key. """
|
||||||
|
return self._extended_key(
|
||||||
|
self.ledger.extended_public_key_prefix,
|
||||||
|
self.pubkey_bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PrivateKey(_KeyBase):
|
||||||
|
"""A BIP32 private key."""
|
||||||
|
|
||||||
|
HARDENED = 1 << 31
|
||||||
|
|
||||||
|
def __init__(self, ledger, privkey, chain_code, n, depth, parent=None):
|
||||||
|
super().__init__(ledger, chain_code, n, depth, parent)
|
||||||
|
if isinstance(privkey, _PrivateKey):
|
||||||
|
self.signing_key = privkey
|
||||||
|
else:
|
||||||
|
self.signing_key = self._signing_key_from_privkey(privkey)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _signing_key_from_privkey(cls, private_key):
|
||||||
|
""" Converts a 32-byte private key into an coincurve.PrivateKey object. """
|
||||||
|
return _PrivateKey.from_int(PrivateKey._private_key_secret_exponent(private_key))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _private_key_secret_exponent(cls, private_key):
|
||||||
|
""" Return the private key as a secret exponent if it is a valid private key. """
|
||||||
|
if not isinstance(private_key, (bytes, bytearray)):
|
||||||
|
raise TypeError('private key must be raw bytes')
|
||||||
|
if len(private_key) != 32:
|
||||||
|
raise ValueError('private key must be 32 bytes')
|
||||||
|
return int.from_bytes(private_key, 'big')
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_seed(cls, ledger, seed):
|
||||||
|
# This hard-coded message string seems to be coin-independent...
|
||||||
|
hmac = hmac_sha512(b'Bitcoin seed', seed)
|
||||||
|
privkey, chain_code = hmac[:32], hmac[32:]
|
||||||
|
return cls(ledger, privkey, chain_code, 0, 0)
|
||||||
|
|
||||||
|
@cachedproperty
|
||||||
|
def private_key_bytes(self):
|
||||||
|
""" Return the serialized private key (no leading zero byte). """
|
||||||
|
return self.signing_key.secret
|
||||||
|
|
||||||
|
@cachedproperty
|
||||||
|
def public_key(self):
|
||||||
|
""" Return the corresponding extended public key. """
|
||||||
|
verifying_key = self.signing_key.public_key
|
||||||
|
parent_pubkey = self.parent.public_key if self.parent else None
|
||||||
|
return PubKey(self.ledger, verifying_key, self.chain_code, self.n, self.depth,
|
||||||
|
parent_pubkey)
|
||||||
|
|
||||||
|
def ec_point(self):
|
||||||
|
return self.public_key.ec_point()
|
||||||
|
|
||||||
|
def secret_exponent(self):
|
||||||
|
""" Return the private key as a secret exponent. """
|
||||||
|
return self.signing_key.to_int()
|
||||||
|
|
||||||
|
def wif(self):
|
||||||
|
""" Return the private key encoded in Wallet Import Format. """
|
||||||
|
return self.ledger.private_key_to_wif(self.private_key_bytes)
|
||||||
|
|
||||||
|
def address(self):
|
||||||
|
""" The public key as a P2PKH address. """
|
||||||
|
return self.public_key.address
|
||||||
|
|
||||||
|
def child(self, n):
|
||||||
|
""" Return the derived child extended private key at index N."""
|
||||||
|
if not 0 <= n < (1 << 32):
|
||||||
|
raise ValueError('invalid BIP32 private key child number')
|
||||||
|
|
||||||
|
if n >= self.HARDENED:
|
||||||
|
serkey = b'\0' + self.private_key_bytes
|
||||||
|
else:
|
||||||
|
serkey = self.public_key.pubkey_bytes
|
||||||
|
|
||||||
|
msg = serkey + n.to_bytes(4, 'big')
|
||||||
|
L_b, R_b = self._hmac_sha512(msg) # pylint: disable=invalid-name
|
||||||
|
derived_key = self.signing_key.add(L_b)
|
||||||
|
return PrivateKey(self.ledger, derived_key, R_b, n, self.depth + 1, self)
|
||||||
|
|
||||||
|
def sign(self, data):
|
||||||
|
""" Produce a signature for piece of data by double hashing it and signing the hash. """
|
||||||
|
return self.signing_key.sign(data, hasher=double_sha256)
|
||||||
|
|
||||||
|
def identifier(self):
|
||||||
|
"""Return the key's identifier as 20 bytes."""
|
||||||
|
return self.public_key.identifier()
|
||||||
|
|
||||||
|
def extended_key(self):
|
||||||
|
"""Return a raw extended private key."""
|
||||||
|
return self._extended_key(
|
||||||
|
self.ledger.extended_private_key_prefix,
|
||||||
|
b'\0' + self.private_key_bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _from_extended_key(ledger, ekey):
|
||||||
|
"""Return a PubKey or PrivateKey from an extended key raw bytes."""
|
||||||
|
if not isinstance(ekey, (bytes, bytearray)):
|
||||||
|
raise TypeError('extended key must be raw bytes')
|
||||||
|
if len(ekey) != 78:
|
||||||
|
raise ValueError('extended key must have length 78')
|
||||||
|
|
||||||
|
depth = ekey[4]
|
||||||
|
n = int.from_bytes(ekey[9:13], 'big')
|
||||||
|
chain_code = ekey[13:45]
|
||||||
|
|
||||||
|
if ekey[:4] == ledger.extended_public_key_prefix:
|
||||||
|
pubkey = ekey[45:]
|
||||||
|
key = PubKey(ledger, pubkey, chain_code, n, depth)
|
||||||
|
elif ekey[:4] == ledger.extended_private_key_prefix:
|
||||||
|
if ekey[45] != 0:
|
||||||
|
raise ValueError('invalid extended private key prefix byte')
|
||||||
|
privkey = ekey[46:]
|
||||||
|
key = PrivateKey(ledger, privkey, chain_code, n, depth)
|
||||||
|
else:
|
||||||
|
raise ValueError('version bytes unrecognised')
|
||||||
|
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
|
def from_extended_key_string(ledger, ekey_str):
|
||||||
|
"""Given an extended key string, such as
|
||||||
|
|
||||||
|
xpub6BsnM1W2Y7qLMiuhi7f7dbAwQZ5Cz5gYJCRzTNainXzQXYjFwtuQXHd
|
||||||
|
3qfi3t3KJtHxshXezfjft93w4UE7BGMtKwhqEHae3ZA7d823DVrL
|
||||||
|
|
||||||
|
return a PubKey or PrivateKey.
|
||||||
|
"""
|
||||||
|
return _from_extended_key(ledger, Base58.decode_check(ekey_str))
|
89
torba/torba/client/cli.py
Normal file
89
torba/torba/client/cli.py
Normal file
|
@ -0,0 +1,89 @@
|
||||||
|
import logging
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from torba.orchstr8.node import Conductor, get_ledger_from_environment, get_blockchain_node_from_ledger
|
||||||
|
from torba.orchstr8.service import ConductorService
|
||||||
|
|
||||||
|
|
||||||
|
def get_argument_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog="torba"
|
||||||
|
)
|
||||||
|
subparsers = parser.add_subparsers(dest='command', help='sub-command help')
|
||||||
|
|
||||||
|
subparsers.add_parser("gui", help="Start Qt GUI.")
|
||||||
|
|
||||||
|
subparsers.add_parser("download", help="Download blockchain node binary.")
|
||||||
|
|
||||||
|
start = subparsers.add_parser("start", help="Start orchstr8 service.")
|
||||||
|
start.add_argument("--blockchain", help="Start blockchain node.", action="store_true")
|
||||||
|
start.add_argument("--spv", help="Start SPV server.", action="store_true")
|
||||||
|
start.add_argument("--wallet", help="Start wallet daemon.", action="store_true")
|
||||||
|
|
||||||
|
generate = subparsers.add_parser("generate", help="Call generate method on running orchstr8 instance.")
|
||||||
|
generate.add_argument("blocks", type=int, help="Number of blocks to generate")
|
||||||
|
|
||||||
|
subparsers.add_parser("transfer", help="Call transfer method on running orchstr8 instance.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
async def run_remote_command(command, **kwargs):
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post('http://localhost:7954/'+command, data=kwargs) as resp:
|
||||||
|
print(resp.status)
|
||||||
|
print(await resp.text())
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = get_argument_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
command = getattr(args, 'command', 'help')
|
||||||
|
|
||||||
|
if command == 'gui':
|
||||||
|
from torba.workbench import main as start_app # pylint: disable=E0611,E0401
|
||||||
|
return start_app()
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
ledger = get_ledger_from_environment()
|
||||||
|
|
||||||
|
if command == 'download':
|
||||||
|
logging.getLogger('blockchain').setLevel(logging.INFO)
|
||||||
|
get_blockchain_node_from_ledger(ledger).ensure()
|
||||||
|
|
||||||
|
elif command == 'generate':
|
||||||
|
loop.run_until_complete(run_remote_command(
|
||||||
|
'generate', blocks=args.blocks
|
||||||
|
))
|
||||||
|
|
||||||
|
elif command == 'start':
|
||||||
|
|
||||||
|
conductor = Conductor()
|
||||||
|
if getattr(args, 'blockchain', False):
|
||||||
|
loop.run_until_complete(conductor.start_blockchain())
|
||||||
|
if getattr(args, 'spv', False):
|
||||||
|
loop.run_until_complete(conductor.start_spv())
|
||||||
|
if getattr(args, 'wallet', False):
|
||||||
|
loop.run_until_complete(conductor.start_wallet())
|
||||||
|
|
||||||
|
service = ConductorService(conductor, loop)
|
||||||
|
loop.run_until_complete(service.start())
|
||||||
|
|
||||||
|
try:
|
||||||
|
print('========== Orchstr8 API Service Started ========')
|
||||||
|
loop.run_forever()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
loop.run_until_complete(service.stop())
|
||||||
|
loop.run_until_complete(conductor.stop())
|
||||||
|
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
else:
|
||||||
|
parser.print_help()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
136
torba/torba/client/coinselection.py
Normal file
136
torba/torba/client/coinselection.py
Normal file
|
@ -0,0 +1,136 @@
|
||||||
|
from random import Random
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from torba.client import basetransaction
|
||||||
|
|
||||||
|
MAXIMUM_TRIES = 100000
|
||||||
|
|
||||||
|
STRATEGIES = []
|
||||||
|
|
||||||
|
def strategy(method):
|
||||||
|
STRATEGIES.append(method.__name__)
|
||||||
|
return method
|
||||||
|
|
||||||
|
|
||||||
|
class CoinSelector:
|
||||||
|
|
||||||
|
def __init__(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator],
|
||||||
|
target: int, cost_of_change: int, seed: str = None) -> None:
|
||||||
|
self.txos = txos
|
||||||
|
self.target = target
|
||||||
|
self.cost_of_change = cost_of_change
|
||||||
|
self.exact_match = False
|
||||||
|
self.tries = 0
|
||||||
|
self.available = sum(c.effective_amount for c in self.txos)
|
||||||
|
self.random = Random(seed)
|
||||||
|
if seed is not None:
|
||||||
|
self.random.seed(seed, version=1)
|
||||||
|
|
||||||
|
def select(self, strategy_name: str = None) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||||
|
if not self.txos:
|
||||||
|
return []
|
||||||
|
if self.target > self.available:
|
||||||
|
return []
|
||||||
|
if strategy_name is not None:
|
||||||
|
return getattr(self, strategy_name)()
|
||||||
|
return (
|
||||||
|
self.branch_and_bound() or
|
||||||
|
self.closest_match() or
|
||||||
|
self.random_draw()
|
||||||
|
)
|
||||||
|
|
||||||
|
@strategy
|
||||||
|
def prefer_confirmed(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||||
|
self.txos = [t for t in self.txos if t.txo.tx_ref and t.txo.tx_ref.height > 0] or self.txos
|
||||||
|
self.available = sum(c.effective_amount for c in self.txos)
|
||||||
|
return (
|
||||||
|
self.branch_and_bound() or
|
||||||
|
self.closest_match() or
|
||||||
|
self.random_draw()
|
||||||
|
)
|
||||||
|
|
||||||
|
@strategy
|
||||||
|
def branch_and_bound(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||||
|
# see bitcoin implementation for more info:
|
||||||
|
# https://github.com/bitcoin/bitcoin/blob/master/src/wallet/coinselection.cpp
|
||||||
|
|
||||||
|
self.txos.sort(reverse=True)
|
||||||
|
|
||||||
|
current_value = 0
|
||||||
|
current_available_value = self.available
|
||||||
|
current_selection: List[bool] = []
|
||||||
|
best_waste = self.cost_of_change
|
||||||
|
best_selection: List[bool] = []
|
||||||
|
|
||||||
|
while self.tries < MAXIMUM_TRIES:
|
||||||
|
self.tries += 1
|
||||||
|
|
||||||
|
backtrack = False
|
||||||
|
if current_value + current_available_value < self.target or \
|
||||||
|
current_value > self.target + self.cost_of_change:
|
||||||
|
backtrack = True
|
||||||
|
elif current_value >= self.target:
|
||||||
|
new_waste = current_value - self.target
|
||||||
|
if new_waste <= best_waste:
|
||||||
|
best_waste = new_waste
|
||||||
|
best_selection = current_selection[:]
|
||||||
|
backtrack = True
|
||||||
|
|
||||||
|
if backtrack:
|
||||||
|
while current_selection and not current_selection[-1]:
|
||||||
|
current_selection.pop()
|
||||||
|
current_available_value += self.txos[len(current_selection)].effective_amount
|
||||||
|
|
||||||
|
if not current_selection:
|
||||||
|
break
|
||||||
|
|
||||||
|
current_selection[-1] = False
|
||||||
|
utxo = self.txos[len(current_selection) - 1]
|
||||||
|
current_value -= utxo.effective_amount
|
||||||
|
|
||||||
|
else:
|
||||||
|
utxo = self.txos[len(current_selection)]
|
||||||
|
current_available_value -= utxo.effective_amount
|
||||||
|
previous_utxo = self.txos[len(current_selection) - 1] if current_selection else None
|
||||||
|
if current_selection and not current_selection[-1] and previous_utxo and \
|
||||||
|
utxo.effective_amount == previous_utxo.effective_amount and \
|
||||||
|
utxo.fee == previous_utxo.fee:
|
||||||
|
current_selection.append(False)
|
||||||
|
else:
|
||||||
|
current_selection.append(True)
|
||||||
|
current_value += utxo.effective_amount
|
||||||
|
|
||||||
|
if best_selection:
|
||||||
|
self.exact_match = True
|
||||||
|
return [
|
||||||
|
self.txos[i] for i, include in enumerate(best_selection) if include
|
||||||
|
]
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
@strategy
|
||||||
|
def closest_match(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||||
|
""" Pick one UTXOs that is larger than the target but with the smallest change. """
|
||||||
|
target = self.target + self.cost_of_change
|
||||||
|
smallest_change = None
|
||||||
|
best_match = None
|
||||||
|
for txo in self.txos:
|
||||||
|
if txo.effective_amount >= target:
|
||||||
|
change = txo.effective_amount - target
|
||||||
|
if smallest_change is None or change < smallest_change:
|
||||||
|
smallest_change, best_match = change, txo
|
||||||
|
return [best_match] if best_match else []
|
||||||
|
|
||||||
|
@strategy
|
||||||
|
def random_draw(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||||
|
""" Accumulate UTXOs at random until there is enough to cover the target. """
|
||||||
|
target = self.target + self.cost_of_change
|
||||||
|
self.random.shuffle(self.txos, self.random.random)
|
||||||
|
selection = []
|
||||||
|
amount = 0
|
||||||
|
for coin in self.txos:
|
||||||
|
selection.append(coin)
|
||||||
|
amount += coin.effective_amount
|
||||||
|
if amount >= target:
|
||||||
|
return selection
|
||||||
|
return []
|
6
torba/torba/client/constants.py
Normal file
6
torba/torba/client/constants.py
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
NULL_HASH32 = b'\x00'*32
|
||||||
|
|
||||||
|
CENT = 1000000
|
||||||
|
COIN = 100*CENT
|
||||||
|
|
||||||
|
TIMEOUT = 30.0
|
2
torba/torba/client/errors.py
Normal file
2
torba/torba/client/errors.py
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
class InsufficientFundsError(Exception):
|
||||||
|
pass
|
254
torba/torba/client/hash.py
Normal file
254
torba/torba/client/hash.py
Normal file
|
@ -0,0 +1,254 @@
|
||||||
|
# Copyright (c) 2016-2017, Neil Booth
|
||||||
|
# Copyright (c) 2018, LBRY Inc.
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# See the file "LICENCE" for information about the copyright
|
||||||
|
# and warranty status of this software.
|
||||||
|
|
||||||
|
""" Cryptography hash functions and related classes. """
|
||||||
|
|
||||||
|
import os
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import typing
|
||||||
|
from binascii import hexlify, unhexlify
|
||||||
|
from cryptography.hazmat.primitives.kdf.scrypt import Scrypt
|
||||||
|
from cryptography.hazmat.primitives.ciphers import Cipher, modes
|
||||||
|
from cryptography.hazmat.primitives.ciphers.algorithms import AES
|
||||||
|
from cryptography.hazmat.primitives.padding import PKCS7
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
|
||||||
|
from torba.client.util import bytes_to_int, int_to_bytes
|
||||||
|
from torba.client.constants import NULL_HASH32
|
||||||
|
|
||||||
|
|
||||||
|
class TXRef:
|
||||||
|
|
||||||
|
__slots__ = '_id', '_hash'
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._id = None
|
||||||
|
self._hash = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def id(self):
|
||||||
|
return self._id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hash(self):
|
||||||
|
return self._hash
|
||||||
|
|
||||||
|
@property
|
||||||
|
def height(self):
|
||||||
|
return -1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_null(self):
|
||||||
|
return self.hash == NULL_HASH32
|
||||||
|
|
||||||
|
|
||||||
|
class TXRefImmutable(TXRef):
|
||||||
|
|
||||||
|
__slots__ = ('_height',)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self._height = -1
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_hash(cls, tx_hash: bytes, height: int) -> 'TXRefImmutable':
|
||||||
|
ref = cls()
|
||||||
|
ref._hash = tx_hash
|
||||||
|
ref._id = hexlify(tx_hash[::-1]).decode()
|
||||||
|
ref._height = height
|
||||||
|
return ref
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_id(cls, tx_id: str, height: int) -> 'TXRefImmutable':
|
||||||
|
ref = cls()
|
||||||
|
ref._id = tx_id
|
||||||
|
ref._hash = unhexlify(tx_id)[::-1]
|
||||||
|
ref._height = height
|
||||||
|
return ref
|
||||||
|
|
||||||
|
@property
|
||||||
|
def height(self):
|
||||||
|
return self._height
|
||||||
|
|
||||||
|
|
||||||
|
def sha256(x):
|
||||||
|
""" Simple wrapper of hashlib sha256. """
|
||||||
|
return hashlib.sha256(x).digest()
|
||||||
|
|
||||||
|
|
||||||
|
def sha512(x):
|
||||||
|
""" Simple wrapper of hashlib sha512. """
|
||||||
|
return hashlib.sha512(x).digest()
|
||||||
|
|
||||||
|
|
||||||
|
def ripemd160(x):
|
||||||
|
""" Simple wrapper of hashlib ripemd160. """
|
||||||
|
h = hashlib.new('ripemd160')
|
||||||
|
h.update(x)
|
||||||
|
return h.digest()
|
||||||
|
|
||||||
|
|
||||||
|
def double_sha256(x):
|
||||||
|
""" SHA-256 of SHA-256, as used extensively in bitcoin. """
|
||||||
|
return sha256(sha256(x))
|
||||||
|
|
||||||
|
|
||||||
|
def hmac_sha512(key, msg):
|
||||||
|
""" Use SHA-512 to provide an HMAC. """
|
||||||
|
return hmac.new(key, msg, hashlib.sha512).digest()
|
||||||
|
|
||||||
|
|
||||||
|
def hash160(x):
|
||||||
|
""" RIPEMD-160 of SHA-256.
|
||||||
|
Used to make bitcoin addresses from pubkeys. """
|
||||||
|
return ripemd160(sha256(x))
|
||||||
|
|
||||||
|
|
||||||
|
def hash_to_hex_str(x):
|
||||||
|
""" Convert a big-endian binary hash to displayed hex string.
|
||||||
|
Display form of a binary hash is reversed and converted to hex. """
|
||||||
|
return hexlify(reversed(x))
|
||||||
|
|
||||||
|
|
||||||
|
def hex_str_to_hash(x):
|
||||||
|
""" Convert a displayed hex string to a binary hash. """
|
||||||
|
return reversed(unhexlify(x))
|
||||||
|
|
||||||
|
|
||||||
|
def aes_encrypt(secret: str, value: str, init_vector: bytes = None) -> str:
|
||||||
|
if init_vector is not None:
|
||||||
|
assert len(init_vector) == 16
|
||||||
|
else:
|
||||||
|
init_vector = os.urandom(16)
|
||||||
|
key = double_sha256(secret.encode())
|
||||||
|
encryptor = Cipher(AES(key), modes.CBC(init_vector), default_backend()).encryptor()
|
||||||
|
padder = PKCS7(AES.block_size).padder()
|
||||||
|
padded_data = padder.update(value.encode()) + padder.finalize()
|
||||||
|
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
|
||||||
|
return base64.b64encode(init_vector + encrypted_data).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def aes_decrypt(secret: str, value: str) -> typing.Tuple[str, bytes]:
|
||||||
|
data = base64.b64decode(value.encode())
|
||||||
|
key = double_sha256(secret.encode())
|
||||||
|
init_vector, data = data[:16], data[16:]
|
||||||
|
decryptor = Cipher(AES(key), modes.CBC(init_vector), default_backend()).decryptor()
|
||||||
|
unpadder = PKCS7(AES.block_size).unpadder()
|
||||||
|
result = unpadder.update(decryptor.update(data)) + unpadder.finalize()
|
||||||
|
return result.decode(), init_vector
|
||||||
|
|
||||||
|
|
||||||
|
def better_aes_encrypt(secret: str, value: bytes) -> bytes:
|
||||||
|
init_vector = os.urandom(16)
|
||||||
|
key = scrypt(secret.encode(), salt=init_vector)
|
||||||
|
encryptor = Cipher(AES(key), modes.CBC(init_vector), default_backend()).encryptor()
|
||||||
|
padder = PKCS7(AES.block_size).padder()
|
||||||
|
padded_data = padder.update(value) + padder.finalize()
|
||||||
|
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
|
||||||
|
return base64.b64encode(b's:8192:16:1:' + init_vector + encrypted_data)
|
||||||
|
|
||||||
|
|
||||||
|
def better_aes_decrypt(secret: str, value: bytes) -> bytes:
|
||||||
|
data = base64.b64decode(value)
|
||||||
|
_, scryp_n, scrypt_r, scrypt_p, data = data.split(b':', maxsplit=4)
|
||||||
|
init_vector, data = data[:16], data[16:]
|
||||||
|
key = scrypt(secret.encode(), init_vector, int(scryp_n), int(scrypt_r), int(scrypt_p))
|
||||||
|
decryptor = Cipher(AES(key), modes.CBC(init_vector), default_backend()).decryptor()
|
||||||
|
unpadder = PKCS7(AES.block_size).unpadder()
|
||||||
|
return unpadder.update(decryptor.update(data)) + unpadder.finalize()
|
||||||
|
|
||||||
|
|
||||||
|
def scrypt(passphrase, salt, scrypt_n=1<<13, scrypt_r=16, scrypt_p=1):
|
||||||
|
kdf = Scrypt(salt, length=32, n=scrypt_n, r=scrypt_r, p=scrypt_p, backend=default_backend())
|
||||||
|
return kdf.derive(passphrase)
|
||||||
|
|
||||||
|
|
||||||
|
class Base58Error(Exception):
|
||||||
|
""" Exception used for Base58 errors. """
|
||||||
|
|
||||||
|
|
||||||
|
class Base58:
|
||||||
|
""" Class providing base 58 functionality. """
|
||||||
|
|
||||||
|
chars = u'123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz'
|
||||||
|
assert len(chars) == 58
|
||||||
|
char_map = {c: n for n, c in enumerate(chars)}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def char_value(cls, c):
|
||||||
|
val = cls.char_map.get(c)
|
||||||
|
if val is None:
|
||||||
|
raise Base58Error('invalid base 58 character "{}"'.format(c))
|
||||||
|
return val
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def decode(cls, txt):
|
||||||
|
""" Decodes txt into a big-endian bytearray. """
|
||||||
|
if isinstance(txt, memoryview):
|
||||||
|
txt = str(txt)
|
||||||
|
|
||||||
|
if isinstance(txt, bytes):
|
||||||
|
txt = txt.decode()
|
||||||
|
|
||||||
|
if not isinstance(txt, str):
|
||||||
|
raise TypeError('a string is required')
|
||||||
|
|
||||||
|
if not txt:
|
||||||
|
raise Base58Error('string cannot be empty')
|
||||||
|
|
||||||
|
value = 0
|
||||||
|
for c in txt:
|
||||||
|
value = value * 58 + cls.char_value(c)
|
||||||
|
|
||||||
|
result = int_to_bytes(value)
|
||||||
|
|
||||||
|
# Prepend leading zero bytes if necessary
|
||||||
|
count = 0
|
||||||
|
for c in txt:
|
||||||
|
if c != u'1':
|
||||||
|
break
|
||||||
|
count += 1
|
||||||
|
if count:
|
||||||
|
result = bytes((0,)) * count + result
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def encode(cls, be_bytes):
|
||||||
|
"""Converts a big-endian bytearray into a base58 string."""
|
||||||
|
value = bytes_to_int(be_bytes)
|
||||||
|
|
||||||
|
txt = u''
|
||||||
|
while value:
|
||||||
|
value, mod = divmod(value, 58)
|
||||||
|
txt += cls.chars[mod]
|
||||||
|
|
||||||
|
for byte in be_bytes:
|
||||||
|
if byte != 0:
|
||||||
|
break
|
||||||
|
txt += u'1'
|
||||||
|
|
||||||
|
return txt[::-1]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def decode_check(cls, txt, hash_fn=double_sha256):
|
||||||
|
""" Decodes a Base58Check-encoded string to a payload. The version prefixes it. """
|
||||||
|
be_bytes = cls.decode(txt)
|
||||||
|
result, check = be_bytes[:-4], be_bytes[-4:]
|
||||||
|
if check != hash_fn(result)[:4]:
|
||||||
|
raise Base58Error('invalid base 58 checksum for {}'.format(txt))
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def encode_check(cls, payload, hash_fn=double_sha256):
|
||||||
|
""" Encodes a payload bytearray (which includes the version byte(s))
|
||||||
|
into a Base58Check string."""
|
||||||
|
be_bytes = payload + hash_fn(payload)[:4]
|
||||||
|
return cls.encode(be_bytes)
|
159
torba/torba/client/mnemonic.py
Normal file
159
torba/torba/client/mnemonic.py
Normal file
|
@ -0,0 +1,159 @@
|
||||||
|
# Copyright (C) 2014 Thomas Voegtlin
|
||||||
|
# Copyright (C) 2018 LBRY Inc.
|
||||||
|
|
||||||
|
import hmac
|
||||||
|
import math
|
||||||
|
import hashlib
|
||||||
|
import importlib
|
||||||
|
import unicodedata
|
||||||
|
import string
|
||||||
|
from binascii import hexlify
|
||||||
|
from secrets import randbelow
|
||||||
|
|
||||||
|
import pbkdf2
|
||||||
|
|
||||||
|
from torba.client.hash import hmac_sha512
|
||||||
|
from torba.client.words import english
|
||||||
|
|
||||||
|
# The hash of the mnemonic seed must begin with this
|
||||||
|
SEED_PREFIX = b'01' # Standard wallet
|
||||||
|
SEED_PREFIX_2FA = b'101' # Two-factor authentication
|
||||||
|
SEED_PREFIX_SW = b'100' # Segwit wallet
|
||||||
|
|
||||||
|
# http://www.asahi-net.or.jp/~ax2s-kmtn/ref/unicode/e_asia.html
|
||||||
|
CJK_INTERVALS = [
|
||||||
|
(0x4E00, 0x9FFF, 'CJK Unified Ideographs'),
|
||||||
|
(0x3400, 0x4DBF, 'CJK Unified Ideographs Extension A'),
|
||||||
|
(0x20000, 0x2A6DF, 'CJK Unified Ideographs Extension B'),
|
||||||
|
(0x2A700, 0x2B73F, 'CJK Unified Ideographs Extension C'),
|
||||||
|
(0x2B740, 0x2B81F, 'CJK Unified Ideographs Extension D'),
|
||||||
|
(0xF900, 0xFAFF, 'CJK Compatibility Ideographs'),
|
||||||
|
(0x2F800, 0x2FA1D, 'CJK Compatibility Ideographs Supplement'),
|
||||||
|
(0x3190, 0x319F, 'Kanbun'),
|
||||||
|
(0x2E80, 0x2EFF, 'CJK Radicals Supplement'),
|
||||||
|
(0x2F00, 0x2FDF, 'CJK Radicals'),
|
||||||
|
(0x31C0, 0x31EF, 'CJK Strokes'),
|
||||||
|
(0x2FF0, 0x2FFF, 'Ideographic Description Characters'),
|
||||||
|
(0xE0100, 0xE01EF, 'Variation Selectors Supplement'),
|
||||||
|
(0x3100, 0x312F, 'Bopomofo'),
|
||||||
|
(0x31A0, 0x31BF, 'Bopomofo Extended'),
|
||||||
|
(0xFF00, 0xFFEF, 'Halfwidth and Fullwidth Forms'),
|
||||||
|
(0x3040, 0x309F, 'Hiragana'),
|
||||||
|
(0x30A0, 0x30FF, 'Katakana'),
|
||||||
|
(0x31F0, 0x31FF, 'Katakana Phonetic Extensions'),
|
||||||
|
(0x1B000, 0x1B0FF, 'Kana Supplement'),
|
||||||
|
(0xAC00, 0xD7AF, 'Hangul Syllables'),
|
||||||
|
(0x1100, 0x11FF, 'Hangul Jamo'),
|
||||||
|
(0xA960, 0xA97F, 'Hangul Jamo Extended A'),
|
||||||
|
(0xD7B0, 0xD7FF, 'Hangul Jamo Extended B'),
|
||||||
|
(0x3130, 0x318F, 'Hangul Compatibility Jamo'),
|
||||||
|
(0xA4D0, 0xA4FF, 'Lisu'),
|
||||||
|
(0x16F00, 0x16F9F, 'Miao'),
|
||||||
|
(0xA000, 0xA48F, 'Yi Syllables'),
|
||||||
|
(0xA490, 0xA4CF, 'Yi Radicals'),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def is_cjk(c):
|
||||||
|
n = ord(c)
|
||||||
|
for start, end, _ in CJK_INTERVALS:
|
||||||
|
if start <= n <= end:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_text(seed):
|
||||||
|
seed = unicodedata.normalize('NFKD', seed)
|
||||||
|
seed = seed.lower()
|
||||||
|
# remove accents
|
||||||
|
seed = u''.join([c for c in seed if not unicodedata.combining(c)])
|
||||||
|
# normalize whitespaces
|
||||||
|
seed = u' '.join(seed.split())
|
||||||
|
# remove whitespaces between CJK
|
||||||
|
seed = u''.join([
|
||||||
|
seed[i] for i in range(len(seed))
|
||||||
|
if not (seed[i] in string.whitespace and is_cjk(seed[i-1]) and is_cjk(seed[i+1]))
|
||||||
|
])
|
||||||
|
return seed
|
||||||
|
|
||||||
|
|
||||||
|
def load_words(language_name):
|
||||||
|
if language_name == 'english':
|
||||||
|
return english.words
|
||||||
|
language_module = importlib.import_module('torba.words.'+language_name)
|
||||||
|
return list(map(
|
||||||
|
lambda s: unicodedata.normalize('NFKD', s),
|
||||||
|
language_module.words
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
LANGUAGE_NAMES = {
|
||||||
|
'en': 'english',
|
||||||
|
'es': 'spanish',
|
||||||
|
'ja': 'japanese',
|
||||||
|
'pt': 'portuguese',
|
||||||
|
'zh': 'chinese_simplified'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Mnemonic:
|
||||||
|
# Seed derivation no longer follows BIP39
|
||||||
|
# Mnemonic phrase uses a hash based checksum, instead of a words-dependent checksum
|
||||||
|
|
||||||
|
def __init__(self, lang='en'):
|
||||||
|
language_name = LANGUAGE_NAMES.get(lang, 'english')
|
||||||
|
self.words = load_words(language_name)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def mnemonic_to_seed(mnemonic, passphrase=u''):
|
||||||
|
pbkdf2_rounds = 2048
|
||||||
|
mnemonic = normalize_text(mnemonic)
|
||||||
|
passphrase = normalize_text(passphrase)
|
||||||
|
return pbkdf2.PBKDF2(
|
||||||
|
mnemonic, passphrase, iterations=pbkdf2_rounds, macmodule=hmac, digestmodule=hashlib.sha512
|
||||||
|
).read(64)
|
||||||
|
|
||||||
|
def mnemonic_encode(self, i):
|
||||||
|
n = len(self.words)
|
||||||
|
words = []
|
||||||
|
while i:
|
||||||
|
x = i%n
|
||||||
|
i = i//n
|
||||||
|
words.append(self.words[x])
|
||||||
|
return ' '.join(words)
|
||||||
|
|
||||||
|
def mnemonic_decode(self, seed):
|
||||||
|
n = len(self.words)
|
||||||
|
words = seed.split()
|
||||||
|
i = 0
|
||||||
|
while words:
|
||||||
|
word = words.pop()
|
||||||
|
k = self.words.index(word)
|
||||||
|
i = i*n + k
|
||||||
|
return i
|
||||||
|
|
||||||
|
def make_seed(self, prefix=SEED_PREFIX, num_bits=132):
|
||||||
|
# increase num_bits in order to obtain a uniform distribution for the last word
|
||||||
|
bpw = math.log(len(self.words), 2)
|
||||||
|
# rounding
|
||||||
|
n = int(math.ceil(num_bits/bpw) * bpw)
|
||||||
|
entropy = 1
|
||||||
|
while 0 < entropy < pow(2, n - bpw):
|
||||||
|
# try again if seed would not contain enough words
|
||||||
|
entropy = randbelow(pow(2, n))
|
||||||
|
nonce = 0
|
||||||
|
while True:
|
||||||
|
nonce += 1
|
||||||
|
i = entropy + nonce
|
||||||
|
seed = self.mnemonic_encode(i)
|
||||||
|
if i != self.mnemonic_decode(seed):
|
||||||
|
raise Exception('Cannot extract same entropy from mnemonic!')
|
||||||
|
if is_new_seed(seed, prefix):
|
||||||
|
break
|
||||||
|
return seed
|
||||||
|
|
||||||
|
|
||||||
|
def is_new_seed(seed, prefix):
|
||||||
|
seed = normalize_text(seed)
|
||||||
|
seed_hash = hexlify(hmac_sha512(b"Seed version", seed.encode('utf8')))
|
||||||
|
return seed_hash.startswith(prefix)
|
142
torba/torba/client/util.py
Normal file
142
torba/torba/client/util.py
Normal file
|
@ -0,0 +1,142 @@
|
||||||
|
import re
|
||||||
|
from binascii import unhexlify, hexlify
|
||||||
|
from typing import TypeVar, Sequence, Optional
|
||||||
|
from torba.client.constants import COIN
|
||||||
|
|
||||||
|
|
||||||
|
def coins_to_satoshis(coins):
|
||||||
|
if not isinstance(coins, str):
|
||||||
|
raise ValueError("{coins} must be a string")
|
||||||
|
result = re.search(r'^(\d{1,10})\.(\d{1,8})$', coins)
|
||||||
|
if result is not None:
|
||||||
|
whole, fractional = result.groups()
|
||||||
|
return int(whole+fractional.ljust(8, "0"))
|
||||||
|
raise ValueError("'{lbc}' is not a valid coin decimal")
|
||||||
|
|
||||||
|
|
||||||
|
def satoshis_to_coins(satoshis):
|
||||||
|
coins = '{:.8f}'.format(satoshis / COIN).rstrip('0')
|
||||||
|
if coins.endswith('.'):
|
||||||
|
return coins+'0'
|
||||||
|
else:
|
||||||
|
return coins
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar('T')
|
||||||
|
|
||||||
|
|
||||||
|
class ReadOnlyList(Sequence[T]):
|
||||||
|
|
||||||
|
def __init__(self, lst):
|
||||||
|
self.lst = lst
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return self.lst[key]
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.lst)
|
||||||
|
|
||||||
|
|
||||||
|
def subclass_tuple(name, base):
|
||||||
|
return type(name, (base,), {'__slots__': ()})
|
||||||
|
|
||||||
|
|
||||||
|
class cachedproperty:
|
||||||
|
|
||||||
|
def __init__(self, f):
|
||||||
|
self.f = f
|
||||||
|
|
||||||
|
def __get__(self, obj, objtype):
|
||||||
|
obj = obj or objtype
|
||||||
|
value = self.f(obj)
|
||||||
|
setattr(obj, self.f.__name__, value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def bytes_to_int(be_bytes):
|
||||||
|
""" Interprets a big-endian sequence of bytes as an integer. """
|
||||||
|
return int(hexlify(be_bytes), 16)
|
||||||
|
|
||||||
|
|
||||||
|
def int_to_bytes(value):
|
||||||
|
""" Converts an integer to a big-endian sequence of bytes. """
|
||||||
|
length = (value.bit_length() + 7) // 8
|
||||||
|
s = '%x' % value
|
||||||
|
return unhexlify(('0' * (len(s) % 2) + s).zfill(length * 2))
|
||||||
|
|
||||||
|
|
||||||
|
class ArithUint256:
|
||||||
|
# https://github.com/bitcoin/bitcoin/blob/master/src/arith_uint256.cpp
|
||||||
|
|
||||||
|
__slots__ = '_value', '_compact'
|
||||||
|
|
||||||
|
def __init__(self, value: int) -> None:
|
||||||
|
self._value = value
|
||||||
|
self._compact: Optional[int] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_compact(cls, compact) -> 'ArithUint256':
|
||||||
|
size = compact >> 24
|
||||||
|
word = compact & 0x007fffff
|
||||||
|
if size <= 3:
|
||||||
|
return cls(word >> 8 * (3 - size))
|
||||||
|
else:
|
||||||
|
return cls(word << 8 * (size - 3))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def value(self) -> int:
|
||||||
|
return self._value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def compact(self) -> int:
|
||||||
|
if self._compact is None:
|
||||||
|
self._compact = self._calculate_compact()
|
||||||
|
return self._compact
|
||||||
|
|
||||||
|
@property
|
||||||
|
def negative(self) -> int:
|
||||||
|
return self._calculate_compact(negative=True)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bits(self) -> int:
|
||||||
|
""" Returns the position of the highest bit set plus one. """
|
||||||
|
bits = bin(self._value)[2:]
|
||||||
|
for i, d in enumerate(bits):
|
||||||
|
if d:
|
||||||
|
return (len(bits) - i) + 1
|
||||||
|
return 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def low64(self) -> int:
|
||||||
|
return self._value & 0xffffffffffffffff
|
||||||
|
|
||||||
|
def _calculate_compact(self, negative=False) -> int:
|
||||||
|
size = (self.bits + 7) // 8
|
||||||
|
if size <= 3:
|
||||||
|
compact = self.low64 << 8 * (3 - size)
|
||||||
|
else:
|
||||||
|
compact = ArithUint256(self._value >> 8 * (size - 3)).low64
|
||||||
|
# The 0x00800000 bit denotes the sign.
|
||||||
|
# Thus, if it is already set, divide the mantissa by 256 and increase the exponent.
|
||||||
|
if compact & 0x00800000:
|
||||||
|
compact >>= 8
|
||||||
|
size += 1
|
||||||
|
assert (compact & ~0x007fffff) == 0
|
||||||
|
assert size < 256
|
||||||
|
compact |= size << 24
|
||||||
|
if negative and compact & 0x007fffff:
|
||||||
|
compact |= 0x00800000
|
||||||
|
return compact
|
||||||
|
|
||||||
|
def __mul__(self, x):
|
||||||
|
# Take the mod because we are limited to an unsigned 256 bit number
|
||||||
|
return ArithUint256((self._value * x) % 2 ** 256)
|
||||||
|
|
||||||
|
def __truediv__(self, x):
|
||||||
|
return ArithUint256(int(self._value / x))
|
||||||
|
|
||||||
|
def __gt__(self, other):
|
||||||
|
return self._value > other
|
||||||
|
|
||||||
|
def __lt__(self, other):
|
||||||
|
return self._value < other
|
136
torba/torba/client/wallet.py
Normal file
136
torba/torba/client/wallet.py
Normal file
|
@ -0,0 +1,136 @@
|
||||||
|
import os
|
||||||
|
import stat
|
||||||
|
import json
|
||||||
|
import zlib
|
||||||
|
import typing
|
||||||
|
from typing import Sequence, MutableSequence
|
||||||
|
from hashlib import sha256
|
||||||
|
from operator import attrgetter
|
||||||
|
from torba.client.hash import better_aes_encrypt, better_aes_decrypt
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from torba.client import basemanager, baseaccount, baseledger
|
||||||
|
|
||||||
|
|
||||||
|
class Wallet:
|
||||||
|
""" The primary role of Wallet is to encapsulate a collection
|
||||||
|
of accounts (seed/private keys) and the spending rules / settings
|
||||||
|
for the coins attached to those accounts. Wallets are represented
|
||||||
|
by physical files on the filesystem.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, name: str = 'Wallet', accounts: MutableSequence['baseaccount.BaseAccount'] = None,
|
||||||
|
storage: 'WalletStorage' = None) -> None:
|
||||||
|
self.name = name
|
||||||
|
self.accounts = accounts or []
|
||||||
|
self.storage = storage or WalletStorage()
|
||||||
|
|
||||||
|
def add_account(self, account):
|
||||||
|
self.accounts.append(account)
|
||||||
|
|
||||||
|
def generate_account(self, ledger: 'baseledger.BaseLedger') -> 'baseaccount.BaseAccount':
|
||||||
|
return ledger.account_class.generate(ledger, self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_storage(cls, storage: 'WalletStorage', manager: 'basemanager.BaseWalletManager') -> 'Wallet':
|
||||||
|
json_dict = storage.read()
|
||||||
|
wallet = cls(
|
||||||
|
name=json_dict.get('name', 'Wallet'),
|
||||||
|
storage=storage
|
||||||
|
)
|
||||||
|
account_dicts: Sequence[dict] = json_dict.get('accounts', [])
|
||||||
|
for account_dict in account_dicts:
|
||||||
|
ledger = manager.get_or_create_ledger(account_dict['ledger'])
|
||||||
|
ledger.account_class.from_dict(ledger, wallet, account_dict)
|
||||||
|
return wallet
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {
|
||||||
|
'version': WalletStorage.LATEST_VERSION,
|
||||||
|
'name': self.name,
|
||||||
|
'accounts': [a.to_dict() for a in self.accounts]
|
||||||
|
}
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
self.storage.write(self.to_dict())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_account(self):
|
||||||
|
for account in self.accounts:
|
||||||
|
return account
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hash(self) -> bytes:
|
||||||
|
h = sha256()
|
||||||
|
for account in sorted(self.accounts, key=attrgetter('id')):
|
||||||
|
h.update(account.hash)
|
||||||
|
return h.digest()
|
||||||
|
|
||||||
|
def pack(self, password):
|
||||||
|
new_data = json.dumps(self.to_dict())
|
||||||
|
new_data_compressed = zlib.compress(new_data.encode())
|
||||||
|
return better_aes_encrypt(password, new_data_compressed)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def unpack(cls, password, encrypted):
|
||||||
|
decrypted = better_aes_decrypt(password, encrypted)
|
||||||
|
decompressed = zlib.decompress(decrypted)
|
||||||
|
return json.loads(decompressed)
|
||||||
|
|
||||||
|
|
||||||
|
class WalletStorage:
|
||||||
|
|
||||||
|
LATEST_VERSION = 1
|
||||||
|
|
||||||
|
def __init__(self, path=None, default=None):
|
||||||
|
self.path = path
|
||||||
|
self._default = default or {
|
||||||
|
'version': self.LATEST_VERSION,
|
||||||
|
'name': 'My Wallet',
|
||||||
|
'accounts': []
|
||||||
|
}
|
||||||
|
|
||||||
|
def read(self):
|
||||||
|
if self.path and os.path.exists(self.path):
|
||||||
|
with open(self.path, 'r') as f:
|
||||||
|
json_data = f.read()
|
||||||
|
json_dict = json.loads(json_data)
|
||||||
|
if json_dict.get('version') == self.LATEST_VERSION and \
|
||||||
|
set(json_dict) == set(self._default):
|
||||||
|
return json_dict
|
||||||
|
else:
|
||||||
|
return self.upgrade(json_dict)
|
||||||
|
else:
|
||||||
|
return self._default.copy()
|
||||||
|
|
||||||
|
def upgrade(self, json_dict):
|
||||||
|
json_dict = json_dict.copy()
|
||||||
|
version = json_dict.pop('version', -1)
|
||||||
|
if version == -1:
|
||||||
|
pass
|
||||||
|
upgraded = self._default.copy()
|
||||||
|
upgraded.update(json_dict)
|
||||||
|
return json_dict
|
||||||
|
|
||||||
|
def write(self, json_dict):
|
||||||
|
|
||||||
|
json_data = json.dumps(json_dict, indent=4, sort_keys=True)
|
||||||
|
if self.path is None:
|
||||||
|
return json_data
|
||||||
|
|
||||||
|
temp_path = "%s.tmp.%s" % (self.path, os.getpid())
|
||||||
|
with open(temp_path, "w") as f:
|
||||||
|
f.write(json_data)
|
||||||
|
f.flush()
|
||||||
|
os.fsync(f.fileno())
|
||||||
|
|
||||||
|
if os.path.exists(self.path):
|
||||||
|
mode = os.stat(self.path).st_mode
|
||||||
|
else:
|
||||||
|
mode = stat.S_IREAD | stat.S_IWRITE
|
||||||
|
try:
|
||||||
|
os.rename(temp_path, self.path)
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
os.remove(self.path)
|
||||||
|
os.rename(temp_path, self.path)
|
||||||
|
os.chmod(self.path, mode)
|
0
torba/torba/client/words/__init__.py
Normal file
0
torba/torba/client/words/__init__.py
Normal file
2050
torba/torba/client/words/chinese_simplified.py
Normal file
2050
torba/torba/client/words/chinese_simplified.py
Normal file
File diff suppressed because it is too large
Load diff
2050
torba/torba/client/words/english.py
Normal file
2050
torba/torba/client/words/english.py
Normal file
File diff suppressed because it is too large
Load diff
2050
torba/torba/client/words/japanese.py
Normal file
2050
torba/torba/client/words/japanese.py
Normal file
File diff suppressed because it is too large
Load diff
1628
torba/torba/client/words/portuguese.py
Normal file
1628
torba/torba/client/words/portuguese.py
Normal file
File diff suppressed because it is too large
Load diff
2050
torba/torba/client/words/spanish.py
Normal file
2050
torba/torba/client/words/spanish.py
Normal file
File diff suppressed because it is too large
Load diff
1
torba/torba/coin/__init__.py
Normal file
1
torba/torba/coin/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
__path__: str = __import__('pkgutil').extend_path(__path__, __name__)
|
49
torba/torba/coin/bitcoincash.py
Normal file
49
torba/torba/coin/bitcoincash.py
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
__node_daemon__ = 'bitcoind'
|
||||||
|
__node_cli__ = 'bitcoin-cli'
|
||||||
|
__node_bin__ = 'bitcoin-abc-0.17.2/bin'
|
||||||
|
__node_url__ = (
|
||||||
|
'https://download.bitcoinabc.org/0.17.2/linux/bitcoin-abc-0.17.2-x86_64-linux-gnu.tar.gz'
|
||||||
|
)
|
||||||
|
__spvserver__ = 'torba.server.coins.BitcoinCashRegtest'
|
||||||
|
|
||||||
|
from binascii import unhexlify
|
||||||
|
from torba.client.baseledger import BaseLedger
|
||||||
|
from torba.client.basetransaction import BaseTransaction
|
||||||
|
from .bitcoinsegwit import MainHeaders, UnverifiedHeaders
|
||||||
|
|
||||||
|
|
||||||
|
class Transaction(BaseTransaction):
|
||||||
|
|
||||||
|
def signature_hash_type(self, hash_type):
|
||||||
|
return hash_type | 0x40
|
||||||
|
|
||||||
|
|
||||||
|
class MainNetLedger(BaseLedger):
|
||||||
|
name = 'BitcoinCash'
|
||||||
|
symbol = 'BCH'
|
||||||
|
network_name = 'mainnet'
|
||||||
|
|
||||||
|
headers_class = MainHeaders
|
||||||
|
transaction_class = Transaction
|
||||||
|
|
||||||
|
pubkey_address_prefix = bytes((0,))
|
||||||
|
script_address_prefix = bytes((5,))
|
||||||
|
extended_public_key_prefix = unhexlify('0488b21e')
|
||||||
|
extended_private_key_prefix = unhexlify('0488ade4')
|
||||||
|
|
||||||
|
default_fee_per_byte = 50
|
||||||
|
|
||||||
|
|
||||||
|
class RegTestLedger(MainNetLedger):
|
||||||
|
headers_class = UnverifiedHeaders
|
||||||
|
network_name = 'regtest'
|
||||||
|
|
||||||
|
pubkey_address_prefix = bytes((111,))
|
||||||
|
script_address_prefix = bytes((196,))
|
||||||
|
extended_public_key_prefix = unhexlify('043587cf')
|
||||||
|
extended_private_key_prefix = unhexlify('04358394')
|
||||||
|
|
||||||
|
max_target = 0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
|
||||||
|
genesis_hash = '0f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206'
|
||||||
|
genesis_bits = 0x207fffff
|
||||||
|
target_timespan = 1
|
86
torba/torba/coin/bitcoinsegwit.py
Normal file
86
torba/torba/coin/bitcoinsegwit.py
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
__node_daemon__ = 'bitcoind'
|
||||||
|
__node_cli__ = 'bitcoin-cli'
|
||||||
|
__node_bin__ = 'bitcoin-0.16.3/bin'
|
||||||
|
__node_url__ = (
|
||||||
|
'https://bitcoin.org/bin/bitcoin-core-0.16.3/bitcoin-0.16.3-x86_64-linux-gnu.tar.gz'
|
||||||
|
)
|
||||||
|
__spvserver__ = 'torba.server.coins.BitcoinSegwitRegtest'
|
||||||
|
|
||||||
|
import struct
|
||||||
|
from typing import Optional
|
||||||
|
from binascii import hexlify, unhexlify
|
||||||
|
from torba.client.baseledger import BaseLedger
|
||||||
|
from torba.client.baseheader import BaseHeaders, ArithUint256
|
||||||
|
|
||||||
|
|
||||||
|
class MainHeaders(BaseHeaders):
|
||||||
|
header_size = 80
|
||||||
|
chunk_size = 2016
|
||||||
|
max_target = 0x00000000ffffffffffffffffffffffffffffffffffffffffffffffffffffffff
|
||||||
|
genesis_hash: Optional[bytes] = b'000000000019d6689c085ae165831e934ff763ae46a2a6c172b3f1b60a8ce26f'
|
||||||
|
target_timespan = 14 * 24 * 60 * 60
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def serialize(header: dict) -> bytes:
|
||||||
|
return b''.join([
|
||||||
|
struct.pack('<I', header['version']),
|
||||||
|
unhexlify(header['prev_block_hash'])[::-1],
|
||||||
|
unhexlify(header['merkle_root'])[::-1],
|
||||||
|
struct.pack('<III', header['timestamp'], header['bits'], header['nonce'])
|
||||||
|
])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def deserialize(height, header):
|
||||||
|
version, = struct.unpack('<I', header[:4])
|
||||||
|
timestamp, bits, nonce = struct.unpack('<III', header[68:80])
|
||||||
|
return {
|
||||||
|
'block_height': height,
|
||||||
|
'version': version,
|
||||||
|
'prev_block_hash': hexlify(header[4:36][::-1]),
|
||||||
|
'merkle_root': hexlify(header[36:68][::-1]),
|
||||||
|
'timestamp': timestamp,
|
||||||
|
'bits': bits,
|
||||||
|
'nonce': nonce
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_next_chunk_target(self, chunk: int) -> ArithUint256:
|
||||||
|
if chunk == -1:
|
||||||
|
return ArithUint256(self.max_target)
|
||||||
|
previous = self[chunk * 2016]
|
||||||
|
current = self[chunk * 2016 + 2015]
|
||||||
|
actual_timespan = current['timestamp'] - previous['timestamp']
|
||||||
|
actual_timespan = max(actual_timespan, int(self.target_timespan / 4))
|
||||||
|
actual_timespan = min(actual_timespan, self.target_timespan * 4)
|
||||||
|
target = ArithUint256.from_compact(current['bits'])
|
||||||
|
new_target = min(ArithUint256(self.max_target), (target * actual_timespan) / self.target_timespan)
|
||||||
|
return new_target
|
||||||
|
|
||||||
|
|
||||||
|
class MainNetLedger(BaseLedger):
|
||||||
|
name = 'BitcoinSegwit'
|
||||||
|
symbol = 'BTC'
|
||||||
|
network_name = 'mainnet'
|
||||||
|
headers_class = MainHeaders
|
||||||
|
|
||||||
|
pubkey_address_prefix = bytes((0,))
|
||||||
|
script_address_prefix = bytes((5,))
|
||||||
|
extended_public_key_prefix = unhexlify('0488b21e')
|
||||||
|
extended_private_key_prefix = unhexlify('0488ade4')
|
||||||
|
|
||||||
|
default_fee_per_byte = 50
|
||||||
|
|
||||||
|
|
||||||
|
class UnverifiedHeaders(MainHeaders):
|
||||||
|
max_target = 0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
|
||||||
|
genesis_hash = None
|
||||||
|
validate_difficulty = False
|
||||||
|
|
||||||
|
|
||||||
|
class RegTestLedger(MainNetLedger):
|
||||||
|
network_name = 'regtest'
|
||||||
|
headers_class = UnverifiedHeaders
|
||||||
|
|
||||||
|
pubkey_address_prefix = bytes((111,))
|
||||||
|
script_address_prefix = bytes((196,))
|
||||||
|
extended_public_key_prefix = unhexlify('043587cf')
|
||||||
|
extended_private_key_prefix = unhexlify('04358394')
|
2
torba/torba/orchstr8/__init__.py
Normal file
2
torba/torba/orchstr8/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
from .node import Conductor
|
||||||
|
from .service import ConductorService
|
93
torba/torba/orchstr8/cli.py
Normal file
93
torba/torba/orchstr8/cli.py
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
import logging
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from torba.orchstr8.node import Conductor, get_ledger_from_environment, get_blockchain_node_from_ledger
|
||||||
|
from torba.orchstr8.service import ConductorService
|
||||||
|
|
||||||
|
|
||||||
|
def get_argument_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog="torba"
|
||||||
|
)
|
||||||
|
subparsers = parser.add_subparsers(dest='command', help='sub-command help')
|
||||||
|
|
||||||
|
subparsers.add_parser("gui", help="Start Qt GUI.")
|
||||||
|
|
||||||
|
subparsers.add_parser("download", help="Download blockchain node binary.")
|
||||||
|
|
||||||
|
start = subparsers.add_parser("start", help="Start orchstr8 service.")
|
||||||
|
start.add_argument("--blockchain", help="Hostname to start blockchain node.")
|
||||||
|
start.add_argument("--spv", help="Hostname to start SPV server.")
|
||||||
|
start.add_argument("--wallet", help="Hostname to start wallet daemon.")
|
||||||
|
|
||||||
|
generate = subparsers.add_parser("generate", help="Call generate method on running orchstr8 instance.")
|
||||||
|
generate.add_argument("blocks", type=int, help="Number of blocks to generate")
|
||||||
|
|
||||||
|
subparsers.add_parser("transfer", help="Call transfer method on running orchstr8 instance.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
async def run_remote_command(command, **kwargs):
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post('http://localhost:7954/'+command, data=kwargs) as resp:
|
||||||
|
print(resp.status)
|
||||||
|
print(await resp.text())
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = get_argument_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
command = getattr(args, 'command', 'help')
|
||||||
|
|
||||||
|
if command == 'gui':
|
||||||
|
from torba.workbench import main as start_app # pylint: disable=E0611,E0401
|
||||||
|
return start_app()
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
ledger = get_ledger_from_environment()
|
||||||
|
|
||||||
|
if command == 'download':
|
||||||
|
logging.getLogger('blockchain').setLevel(logging.INFO)
|
||||||
|
get_blockchain_node_from_ledger(ledger).ensure()
|
||||||
|
|
||||||
|
elif command == 'generate':
|
||||||
|
loop.run_until_complete(run_remote_command(
|
||||||
|
'generate', blocks=args.blocks
|
||||||
|
))
|
||||||
|
|
||||||
|
elif command == 'start':
|
||||||
|
|
||||||
|
conductor = Conductor()
|
||||||
|
if getattr(args, 'blockchain', False):
|
||||||
|
conductor.blockchain_node.hostname = args.blockchain
|
||||||
|
loop.run_until_complete(conductor.start_blockchain())
|
||||||
|
if getattr(args, 'spv', False):
|
||||||
|
conductor.spv_node.hostname = args.spv
|
||||||
|
loop.run_until_complete(conductor.start_spv())
|
||||||
|
if getattr(args, 'wallet', False):
|
||||||
|
conductor.wallet_node.hostname = args.wallet
|
||||||
|
loop.run_until_complete(conductor.start_wallet())
|
||||||
|
|
||||||
|
service = ConductorService(conductor, loop)
|
||||||
|
loop.run_until_complete(service.start())
|
||||||
|
|
||||||
|
try:
|
||||||
|
print('========== Orchstr8 API Service Started ========')
|
||||||
|
loop.run_forever()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
loop.run_until_complete(service.stop())
|
||||||
|
loop.run_until_complete(conductor.stop())
|
||||||
|
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
else:
|
||||||
|
parser.print_help()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
389
torba/torba/orchstr8/node.py
Normal file
389
torba/torba/orchstr8/node.py
Normal file
|
@ -0,0 +1,389 @@
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import asyncio
|
||||||
|
import zipfile
|
||||||
|
import tarfile
|
||||||
|
import logging
|
||||||
|
import tempfile
|
||||||
|
import subprocess
|
||||||
|
import importlib
|
||||||
|
from binascii import hexlify
|
||||||
|
from typing import Type, Optional
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
from torba.server.server import Server
|
||||||
|
from torba.server.env import Env
|
||||||
|
from torba.client.wallet import Wallet
|
||||||
|
from torba.client.baseledger import BaseLedger, BlockHeightEvent
|
||||||
|
from torba.client.basemanager import BaseWalletManager
|
||||||
|
from torba.client.baseaccount import BaseAccount
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_manager_from_environment(default_manager=BaseWalletManager):
|
||||||
|
if 'TORBA_MANAGER' not in os.environ:
|
||||||
|
return default_manager
|
||||||
|
module_name = os.environ['TORBA_MANAGER'].split('-')[-1] # tox support
|
||||||
|
return importlib.import_module(module_name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_ledger_from_environment():
|
||||||
|
if 'TORBA_LEDGER' not in os.environ:
|
||||||
|
raise ValueError('Environment variable TORBA_LEDGER must point to a torba based ledger module.')
|
||||||
|
module_name = os.environ['TORBA_LEDGER'].split('-')[-1] # tox support
|
||||||
|
return importlib.import_module(module_name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_spvserver_from_ledger(ledger_module):
|
||||||
|
spvserver_path, regtest_class_name = ledger_module.__spvserver__.rsplit('.', 1)
|
||||||
|
spvserver_module = importlib.import_module(spvserver_path)
|
||||||
|
return getattr(spvserver_module, regtest_class_name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_blockchain_node_from_ledger(ledger_module):
|
||||||
|
return BlockchainNode(
|
||||||
|
ledger_module.__node_url__,
|
||||||
|
os.path.join(ledger_module.__node_bin__, ledger_module.__node_daemon__),
|
||||||
|
os.path.join(ledger_module.__node_bin__, ledger_module.__node_cli__)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_logging(ledger_module, level, handler=None):
|
||||||
|
modules = [
|
||||||
|
'torba',
|
||||||
|
'torba.client',
|
||||||
|
'torba.server',
|
||||||
|
'blockchain',
|
||||||
|
ledger_module.__name__
|
||||||
|
]
|
||||||
|
for module_name in modules:
|
||||||
|
module = logging.getLogger(module_name)
|
||||||
|
module.setLevel(level)
|
||||||
|
if handler is not None:
|
||||||
|
module.addHandler(handler)
|
||||||
|
|
||||||
|
|
||||||
|
class Conductor:
|
||||||
|
|
||||||
|
def __init__(self, ledger_module=None, manager_module=None, verbosity=logging.WARNING):
|
||||||
|
self.ledger_module = ledger_module or get_ledger_from_environment()
|
||||||
|
self.manager_module = manager_module or get_manager_from_environment()
|
||||||
|
self.spv_module = get_spvserver_from_ledger(self.ledger_module)
|
||||||
|
|
||||||
|
self.blockchain_node = get_blockchain_node_from_ledger(self.ledger_module)
|
||||||
|
self.spv_node = SPVNode(self.spv_module)
|
||||||
|
self.wallet_node = WalletNode(self.manager_module, self.ledger_module.RegTestLedger)
|
||||||
|
|
||||||
|
set_logging(self.ledger_module, verbosity)
|
||||||
|
|
||||||
|
self.blockchain_started = False
|
||||||
|
self.spv_started = False
|
||||||
|
self.wallet_started = False
|
||||||
|
|
||||||
|
self.log = log.getChild('conductor')
|
||||||
|
|
||||||
|
async def start_blockchain(self):
|
||||||
|
if not self.blockchain_started:
|
||||||
|
await self.blockchain_node.start()
|
||||||
|
await self.blockchain_node.generate(200)
|
||||||
|
self.blockchain_started = True
|
||||||
|
|
||||||
|
async def stop_blockchain(self):
|
||||||
|
if self.blockchain_started:
|
||||||
|
await self.blockchain_node.stop(cleanup=True)
|
||||||
|
self.blockchain_started = False
|
||||||
|
|
||||||
|
async def start_spv(self):
|
||||||
|
if not self.spv_started:
|
||||||
|
await self.spv_node.start(self.blockchain_node)
|
||||||
|
self.spv_started = True
|
||||||
|
|
||||||
|
async def stop_spv(self):
|
||||||
|
if self.spv_started:
|
||||||
|
await self.spv_node.stop(cleanup=True)
|
||||||
|
self.spv_started = False
|
||||||
|
|
||||||
|
async def start_wallet(self):
|
||||||
|
if not self.wallet_started:
|
||||||
|
await self.wallet_node.start(self.spv_node)
|
||||||
|
self.wallet_started = True
|
||||||
|
|
||||||
|
async def stop_wallet(self):
|
||||||
|
if self.wallet_started:
|
||||||
|
await self.wallet_node.stop(cleanup=True)
|
||||||
|
self.wallet_started = False
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
await self.start_blockchain()
|
||||||
|
await self.start_spv()
|
||||||
|
await self.start_wallet()
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
all_the_stops = [
|
||||||
|
self.stop_wallet,
|
||||||
|
self.stop_spv,
|
||||||
|
self.stop_blockchain
|
||||||
|
]
|
||||||
|
for stop in all_the_stops:
|
||||||
|
try:
|
||||||
|
await stop()
|
||||||
|
except Exception as e:
|
||||||
|
log.exception('Exception raised while stopping services:', exc_info=e)
|
||||||
|
|
||||||
|
|
||||||
|
class WalletNode:
|
||||||
|
|
||||||
|
def __init__(self, manager_class: Type[BaseWalletManager], ledger_class: Type[BaseLedger],
|
||||||
|
verbose: bool = False, port: int = 5280) -> None:
|
||||||
|
self.manager_class = manager_class
|
||||||
|
self.ledger_class = ledger_class
|
||||||
|
self.verbose = verbose
|
||||||
|
self.manager: Optional[BaseWalletManager] = None
|
||||||
|
self.ledger: Optional[BaseLedger] = None
|
||||||
|
self.wallet: Optional[Wallet] = None
|
||||||
|
self.account: Optional[BaseAccount] = None
|
||||||
|
self.data_path: Optional[str] = None
|
||||||
|
self.port = port
|
||||||
|
|
||||||
|
async def start(self, spv_node: 'SPVNode', seed=None, connect=True):
|
||||||
|
self.data_path = tempfile.mkdtemp()
|
||||||
|
wallet_file_name = os.path.join(self.data_path, 'my_wallet.json')
|
||||||
|
with open(wallet_file_name, 'w') as wallet_file:
|
||||||
|
wallet_file.write('{"version": 1, "accounts": []}\n')
|
||||||
|
self.manager = self.manager_class.from_config({
|
||||||
|
'ledgers': {
|
||||||
|
self.ledger_class.get_id(): {
|
||||||
|
'api_port': self.port,
|
||||||
|
'default_servers': [(spv_node.hostname, spv_node.port)],
|
||||||
|
'data_path': self.data_path
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'wallets': [wallet_file_name]
|
||||||
|
})
|
||||||
|
self.ledger = self.manager.ledgers[self.ledger_class]
|
||||||
|
self.wallet = self.manager.default_wallet
|
||||||
|
if seed is None and self.wallet is not None:
|
||||||
|
self.wallet.generate_account(self.ledger)
|
||||||
|
elif self.wallet is not None:
|
||||||
|
self.ledger.account_class.from_dict(
|
||||||
|
self.ledger, self.wallet, {'seed': seed}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError('Wallet is required.')
|
||||||
|
self.account = self.wallet.default_account
|
||||||
|
if connect:
|
||||||
|
await self.manager.start()
|
||||||
|
|
||||||
|
async def stop(self, cleanup=True):
|
||||||
|
try:
|
||||||
|
await self.manager.stop()
|
||||||
|
finally:
|
||||||
|
cleanup and self.cleanup()
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
shutil.rmtree(self.data_path, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
|
class SPVNode:
|
||||||
|
|
||||||
|
def __init__(self, coin_class):
|
||||||
|
self.coin_class = coin_class
|
||||||
|
self.controller = None
|
||||||
|
self.data_path = None
|
||||||
|
self.server = None
|
||||||
|
self.hostname = 'localhost'
|
||||||
|
self.port = 50001 + 1 # avoid conflict with default daemon
|
||||||
|
|
||||||
|
async def start(self, blockchain_node: 'BlockchainNode'):
|
||||||
|
self.data_path = tempfile.mkdtemp()
|
||||||
|
conf = {
|
||||||
|
'DB_DIRECTORY': self.data_path,
|
||||||
|
'DAEMON_URL': blockchain_node.rpc_url,
|
||||||
|
'REORG_LIMIT': '100',
|
||||||
|
'HOST': self.hostname,
|
||||||
|
'TCP_PORT': str(self.port)
|
||||||
|
}
|
||||||
|
# TODO: don't use os.environ
|
||||||
|
os.environ.update(conf)
|
||||||
|
self.server = Server(Env(self.coin_class))
|
||||||
|
self.server.mempool.refresh_secs = self.server.bp.prefetcher.polling_delay = 0.5
|
||||||
|
await self.server.start()
|
||||||
|
|
||||||
|
async def stop(self, cleanup=True):
|
||||||
|
try:
|
||||||
|
await self.server.stop()
|
||||||
|
finally:
|
||||||
|
cleanup and self.cleanup()
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
shutil.rmtree(self.data_path, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
|
class BlockchainProcess(asyncio.SubprocessProtocol):
|
||||||
|
|
||||||
|
IGNORE_OUTPUT = [
|
||||||
|
b'keypool keep',
|
||||||
|
b'keypool reserve',
|
||||||
|
b'keypool return',
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.ready = asyncio.Event()
|
||||||
|
self.stopped = asyncio.Event()
|
||||||
|
self.log = log.getChild('blockchain')
|
||||||
|
|
||||||
|
def pipe_data_received(self, fd, data):
|
||||||
|
if self.log and not any(ignore in data for ignore in self.IGNORE_OUTPUT):
|
||||||
|
if b'Error:' in data:
|
||||||
|
self.log.error(data.decode())
|
||||||
|
else:
|
||||||
|
self.log.info(data.decode())
|
||||||
|
if b'Error:' in data:
|
||||||
|
self.ready.set()
|
||||||
|
raise SystemError(data.decode())
|
||||||
|
if b'Done loading' in data:
|
||||||
|
self.ready.set()
|
||||||
|
|
||||||
|
def process_exited(self):
|
||||||
|
self.stopped.set()
|
||||||
|
|
||||||
|
|
||||||
|
class BlockchainNode:
|
||||||
|
|
||||||
|
def __init__(self, url, daemon, cli):
|
||||||
|
self.latest_release_url = url
|
||||||
|
self.project_dir = os.path.dirname(os.path.dirname(__file__))
|
||||||
|
self.bin_dir = os.path.join(self.project_dir, 'bin')
|
||||||
|
self.daemon_bin = os.path.join(self.bin_dir, daemon)
|
||||||
|
self.cli_bin = os.path.join(self.bin_dir, cli)
|
||||||
|
self.log = log.getChild('blockchain')
|
||||||
|
self.data_path = None
|
||||||
|
self.protocol = None
|
||||||
|
self.transport = None
|
||||||
|
self._block_expected = 0
|
||||||
|
self.hostname = 'localhost'
|
||||||
|
self.peerport = 9246 + 2 # avoid conflict with default peer port
|
||||||
|
self.rpcport = 9245 + 2 # avoid conflict with default rpc port
|
||||||
|
self.rpcuser = 'rpcuser'
|
||||||
|
self.rpcpassword = 'rpcpassword'
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rpc_url(self):
|
||||||
|
return f'http://{self.rpcuser}:{self.rpcpassword}@{self.hostname}:{self.rpcport}/'
|
||||||
|
|
||||||
|
def is_expected_block(self, e: BlockHeightEvent):
|
||||||
|
return self._block_expected == e.height
|
||||||
|
|
||||||
|
@property
|
||||||
|
def exists(self):
|
||||||
|
return (
|
||||||
|
os.path.exists(self.cli_bin) and
|
||||||
|
os.path.exists(self.daemon_bin)
|
||||||
|
)
|
||||||
|
|
||||||
|
def download(self):
|
||||||
|
downloaded_file = os.path.join(
|
||||||
|
self.bin_dir,
|
||||||
|
self.latest_release_url[self.latest_release_url.rfind('/')+1:]
|
||||||
|
)
|
||||||
|
|
||||||
|
if not os.path.exists(self.bin_dir):
|
||||||
|
os.mkdir(self.bin_dir)
|
||||||
|
|
||||||
|
if not os.path.exists(downloaded_file):
|
||||||
|
self.log.info('Downloading: %s', self.latest_release_url)
|
||||||
|
with urllib.request.urlopen(self.latest_release_url) as response:
|
||||||
|
with open(downloaded_file, 'wb') as out_file:
|
||||||
|
shutil.copyfileobj(response, out_file)
|
||||||
|
|
||||||
|
self.log.info('Extracting: %s', downloaded_file)
|
||||||
|
|
||||||
|
if downloaded_file.endswith('.zip'):
|
||||||
|
with zipfile.ZipFile(downloaded_file) as dotzip:
|
||||||
|
dotzip.extractall(self.bin_dir)
|
||||||
|
# zipfile bug https://bugs.python.org/issue15795
|
||||||
|
os.chmod(self.cli_bin, 0o755)
|
||||||
|
os.chmod(self.daemon_bin, 0o755)
|
||||||
|
|
||||||
|
elif downloaded_file.endswith('.tar.gz'):
|
||||||
|
with tarfile.open(downloaded_file) as tar:
|
||||||
|
tar.extractall(self.bin_dir)
|
||||||
|
|
||||||
|
return self.exists
|
||||||
|
|
||||||
|
def ensure(self):
|
||||||
|
return self.exists or self.download()
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
assert self.ensure()
|
||||||
|
self.data_path = tempfile.mkdtemp()
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
asyncio.get_child_watcher().attach_loop(loop)
|
||||||
|
command = (
|
||||||
|
self.daemon_bin,
|
||||||
|
f'-datadir={self.data_path}', '-printtoconsole', '-regtest', '-server', '-txindex',
|
||||||
|
f'-rpcuser={self.rpcuser}', f'-rpcpassword={self.rpcpassword}', f'-rpcport={self.rpcport}',
|
||||||
|
f'-port={self.peerport}'
|
||||||
|
)
|
||||||
|
self.log.info(' '.join(command))
|
||||||
|
self.transport, self.protocol = await loop.subprocess_exec(
|
||||||
|
BlockchainProcess, *command
|
||||||
|
)
|
||||||
|
await self.protocol.ready.wait()
|
||||||
|
|
||||||
|
async def stop(self, cleanup=True):
|
||||||
|
try:
|
||||||
|
self.transport.terminate()
|
||||||
|
await self.protocol.stopped.wait()
|
||||||
|
self.transport.close()
|
||||||
|
finally:
|
||||||
|
if cleanup:
|
||||||
|
self.cleanup()
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
shutil.rmtree(self.data_path, ignore_errors=True)
|
||||||
|
|
||||||
|
async def _cli_cmnd(self, *args):
|
||||||
|
cmnd_args = [
|
||||||
|
self.cli_bin, f'-datadir={self.data_path}', '-regtest',
|
||||||
|
f'-rpcuser={self.rpcuser}', f'-rpcpassword={self.rpcpassword}', f'-rpcport={self.rpcport}'
|
||||||
|
] + list(args)
|
||||||
|
self.log.info(' '.join(cmnd_args))
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
asyncio.get_child_watcher().attach_loop(loop)
|
||||||
|
process = await asyncio.create_subprocess_exec(
|
||||||
|
*cmnd_args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
|
||||||
|
)
|
||||||
|
out, _ = await process.communicate()
|
||||||
|
self.log.info(out.decode().strip())
|
||||||
|
return out.decode().strip()
|
||||||
|
|
||||||
|
def generate(self, blocks):
|
||||||
|
self._block_expected += blocks
|
||||||
|
return self._cli_cmnd('generate', str(blocks))
|
||||||
|
|
||||||
|
def invalidate_block(self, blockhash):
|
||||||
|
return self._cli_cmnd('invalidateblock', blockhash)
|
||||||
|
|
||||||
|
def get_block_hash(self, block):
|
||||||
|
return self._cli_cmnd('getblockhash', str(block))
|
||||||
|
|
||||||
|
def get_raw_change_address(self):
|
||||||
|
return self._cli_cmnd('getrawchangeaddress')
|
||||||
|
|
||||||
|
async def get_balance(self):
|
||||||
|
return float(await self._cli_cmnd('getbalance'))
|
||||||
|
|
||||||
|
def send_to_address(self, address, amount):
|
||||||
|
return self._cli_cmnd('sendtoaddress', address, str(amount))
|
||||||
|
|
||||||
|
def send_raw_transaction(self, tx):
|
||||||
|
return self._cli_cmnd('sendrawtransaction', tx.decode())
|
||||||
|
|
||||||
|
def decode_raw_transaction(self, tx):
|
||||||
|
return self._cli_cmnd('decoderawtransaction', hexlify(tx.raw).decode())
|
||||||
|
|
||||||
|
def get_raw_transaction(self, txid):
|
||||||
|
return self._cli_cmnd('getrawtransaction', txid, '1')
|
137
torba/torba/orchstr8/service.py
Normal file
137
torba/torba/orchstr8/service.py
Normal file
|
@ -0,0 +1,137 @@
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from aiohttp.web import Application, WebSocketResponse, json_response
|
||||||
|
from aiohttp.http_websocket import WSMsgType, WSCloseCode
|
||||||
|
|
||||||
|
from torba.client.util import satoshis_to_coins
|
||||||
|
from .node import Conductor, set_logging
|
||||||
|
|
||||||
|
|
||||||
|
PORT = 7954
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketLogHandler(logging.Handler):
|
||||||
|
|
||||||
|
def __init__(self, send_message):
|
||||||
|
super().__init__()
|
||||||
|
self.send_message = send_message
|
||||||
|
|
||||||
|
def emit(self, record):
|
||||||
|
try:
|
||||||
|
self.send_message({
|
||||||
|
'type': 'log',
|
||||||
|
'name': record.name,
|
||||||
|
'message': self.format(record)
|
||||||
|
})
|
||||||
|
except Exception:
|
||||||
|
self.handleError(record)
|
||||||
|
|
||||||
|
|
||||||
|
class ConductorService:
|
||||||
|
|
||||||
|
def __init__(self, stack: Conductor, loop: asyncio.AbstractEventLoop) -> None:
|
||||||
|
self.stack = stack
|
||||||
|
self.loop = loop
|
||||||
|
self.app = Application()
|
||||||
|
self.app.router.add_post('/start', self.start_stack)
|
||||||
|
self.app.router.add_post('/generate', self.generate)
|
||||||
|
self.app.router.add_post('/transfer', self.transfer)
|
||||||
|
self.app.router.add_post('/balance', self.balance)
|
||||||
|
self.app.router.add_get('/log', self.log)
|
||||||
|
self.app['websockets'] = set()
|
||||||
|
self.app.on_shutdown.append(self.on_shutdown)
|
||||||
|
self.handler = self.app.make_handler()
|
||||||
|
self.server = None
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
self.server = await self.loop.create_server(
|
||||||
|
self.handler, '0.0.0.0', PORT
|
||||||
|
)
|
||||||
|
print('serving on', self.server.sockets[0].getsockname())
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
await self.stack.stop()
|
||||||
|
self.server.close()
|
||||||
|
await self.server.wait_closed()
|
||||||
|
await self.app.shutdown()
|
||||||
|
await self.handler.shutdown(60.0)
|
||||||
|
await self.app.cleanup()
|
||||||
|
|
||||||
|
async def start_stack(self, _):
|
||||||
|
set_logging(
|
||||||
|
self.stack.ledger_module, logging.DEBUG, WebSocketLogHandler(self.send_message)
|
||||||
|
)
|
||||||
|
self.stack.blockchain_started or await self.stack.start_blockchain()
|
||||||
|
self.send_message({'type': 'service', 'name': 'blockchain', 'port': self.stack.blockchain_node.port})
|
||||||
|
self.stack.spv_started or await self.stack.start_spv()
|
||||||
|
self.send_message({'type': 'service', 'name': 'spv', 'port': self.stack.spv_node.port})
|
||||||
|
self.stack.wallet_started or await self.stack.start_wallet()
|
||||||
|
self.send_message({'type': 'service', 'name': 'wallet', 'port': self.stack.wallet_node.port})
|
||||||
|
self.stack.wallet_node.ledger.on_header.listen(self.on_status)
|
||||||
|
self.stack.wallet_node.ledger.on_transaction.listen(self.on_status)
|
||||||
|
return json_response({'started': True})
|
||||||
|
|
||||||
|
async def generate(self, request):
|
||||||
|
data = await request.post()
|
||||||
|
blocks = data.get('blocks', 1)
|
||||||
|
await self.stack.blockchain_node.generate(int(blocks))
|
||||||
|
return json_response({'blocks': blocks})
|
||||||
|
|
||||||
|
async def transfer(self, request):
|
||||||
|
data = await request.post()
|
||||||
|
address = data.get('address')
|
||||||
|
if not address and self.stack.wallet_started:
|
||||||
|
address = await self.stack.wallet_node.account.receiving.get_or_create_usable_address()
|
||||||
|
if not address:
|
||||||
|
raise ValueError("No address was provided.")
|
||||||
|
amount = data.get('amount', 1)
|
||||||
|
txid = await self.stack.blockchain_node.send_to_address(address, amount)
|
||||||
|
if self.stack.wallet_started:
|
||||||
|
await self.stack.wallet_node.ledger.on_transaction.where(
|
||||||
|
lambda e: e.tx.id == txid and e.address == address
|
||||||
|
)
|
||||||
|
return json_response({
|
||||||
|
'address': address,
|
||||||
|
'amount': amount,
|
||||||
|
'txid': txid
|
||||||
|
})
|
||||||
|
|
||||||
|
async def balance(self, _):
|
||||||
|
return json_response({
|
||||||
|
'balance': await self.stack.blockchain_node.get_balance()
|
||||||
|
})
|
||||||
|
|
||||||
|
async def log(self, request):
|
||||||
|
web_socket = WebSocketResponse()
|
||||||
|
await web_socket.prepare(request)
|
||||||
|
self.app['websockets'].add(web_socket)
|
||||||
|
try:
|
||||||
|
async for msg in web_socket:
|
||||||
|
if msg.type == WSMsgType.TEXT:
|
||||||
|
if msg.data == 'close':
|
||||||
|
await web_socket.close()
|
||||||
|
elif msg.type == WSMsgType.ERROR:
|
||||||
|
print('web socket connection closed with exception %s' %
|
||||||
|
web_socket.exception())
|
||||||
|
finally:
|
||||||
|
self.app['websockets'].remove(web_socket)
|
||||||
|
return web_socket
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def on_shutdown(app):
|
||||||
|
for web_socket in app['websockets']:
|
||||||
|
await web_socket.close(code=WSCloseCode.GOING_AWAY, message='Server shutdown')
|
||||||
|
|
||||||
|
async def on_status(self, _):
|
||||||
|
if not self.app['websockets']:
|
||||||
|
return
|
||||||
|
self.send_message({
|
||||||
|
'type': 'status',
|
||||||
|
'height': self.stack.wallet_node.ledger.headers.height,
|
||||||
|
'balance': satoshis_to_coins(await self.stack.wallet_node.account.get_balance()),
|
||||||
|
'miner': await self.stack.blockchain_node.get_balance()
|
||||||
|
})
|
||||||
|
|
||||||
|
def send_message(self, msg):
|
||||||
|
for web_socket in self.app['websockets']:
|
||||||
|
self.loop.create_task(web_socket.send_json(msg))
|
11
torba/torba/rpc/__init__.py
Normal file
11
torba/torba/rpc/__init__.py
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
from .framing import *
|
||||||
|
from .jsonrpc import *
|
||||||
|
from .socks import *
|
||||||
|
from .session import *
|
||||||
|
from .util import *
|
||||||
|
|
||||||
|
__all__ = (framing.__all__ +
|
||||||
|
jsonrpc.__all__ +
|
||||||
|
socks.__all__ +
|
||||||
|
session.__all__ +
|
||||||
|
util.__all__)
|
239
torba/torba/rpc/framing.py
Normal file
239
torba/torba/rpc/framing.py
Normal file
|
@ -0,0 +1,239 @@
|
||||||
|
# Copyright (c) 2018, Neil Booth
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# The MIT License (MIT)
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
# a copy of this software and associated documentation files (the
|
||||||
|
# "Software"), to deal in the Software without restriction, including
|
||||||
|
# without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
# permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
# the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be
|
||||||
|
# included in all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
|
||||||
|
"""RPC message framing in a byte stream."""
|
||||||
|
|
||||||
|
__all__ = ('FramerBase', 'NewlineFramer', 'BinaryFramer', 'BitcoinFramer',
|
||||||
|
'OversizedPayloadError', 'BadChecksumError', 'BadMagicError')
|
||||||
|
|
||||||
|
from hashlib import sha256 as _sha256
|
||||||
|
from struct import Struct
|
||||||
|
from asyncio import Queue
|
||||||
|
|
||||||
|
|
||||||
|
class FramerBase(object):
|
||||||
|
"""Abstract base class for a framer.
|
||||||
|
|
||||||
|
A framer breaks an incoming byte stream into protocol messages,
|
||||||
|
buffering if necesary. It also frames outgoing messages into
|
||||||
|
a byte stream.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def frame(self, message):
|
||||||
|
"""Return the framed message."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def received_bytes(self, data):
|
||||||
|
"""Pass incoming network bytes."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def receive_message(self):
|
||||||
|
"""Wait for a complete unframed message to arrive, and return it."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class NewlineFramer(FramerBase):
|
||||||
|
"""A framer for a protocol where messages are separated by newlines."""
|
||||||
|
|
||||||
|
# The default max_size value is motivated by JSONRPC, where a
|
||||||
|
# normal request will be 250 bytes or less, and a reasonable
|
||||||
|
# batch may contain 4000 requests.
|
||||||
|
def __init__(self, max_size=250 * 4000):
|
||||||
|
"""max_size - an anti-DoS measure. If, after processing an incoming
|
||||||
|
message, buffered data would exceed max_size bytes, that
|
||||||
|
buffered data is dropped entirely and the framer waits for a
|
||||||
|
newline character to re-synchronize the stream.
|
||||||
|
"""
|
||||||
|
self.max_size = max_size
|
||||||
|
self.queue = Queue()
|
||||||
|
self.received_bytes = self.queue.put_nowait
|
||||||
|
self.synchronizing = False
|
||||||
|
self.residual = b''
|
||||||
|
|
||||||
|
def frame(self, message):
|
||||||
|
return message + b'\n'
|
||||||
|
|
||||||
|
async def receive_message(self):
|
||||||
|
parts = []
|
||||||
|
buffer_size = 0
|
||||||
|
while True:
|
||||||
|
part = self.residual
|
||||||
|
self.residual = b''
|
||||||
|
if not part:
|
||||||
|
part = await self.queue.get()
|
||||||
|
|
||||||
|
npos = part.find(b'\n')
|
||||||
|
if npos == -1:
|
||||||
|
parts.append(part)
|
||||||
|
buffer_size += len(part)
|
||||||
|
# Ignore over-sized messages; re-synchronize
|
||||||
|
if buffer_size <= self.max_size:
|
||||||
|
continue
|
||||||
|
self.synchronizing = True
|
||||||
|
raise MemoryError(f'dropping message over {self.max_size:,d} '
|
||||||
|
f'bytes and re-synchronizing')
|
||||||
|
|
||||||
|
tail, self.residual = part[:npos], part[npos + 1:]
|
||||||
|
if self.synchronizing:
|
||||||
|
self.synchronizing = False
|
||||||
|
return await self.receive_message()
|
||||||
|
else:
|
||||||
|
parts.append(tail)
|
||||||
|
return b''.join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
class ByteQueue(object):
|
||||||
|
"""A producer-comsumer queue. Incoming network data is put as it
|
||||||
|
arrives, and the consumer calls an async method waiting for data of
|
||||||
|
a specific length."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.queue = Queue()
|
||||||
|
self.parts = []
|
||||||
|
self.parts_len = 0
|
||||||
|
self.put_nowait = self.queue.put_nowait
|
||||||
|
|
||||||
|
async def receive(self, size):
|
||||||
|
while self.parts_len < size:
|
||||||
|
part = await self.queue.get()
|
||||||
|
self.parts.append(part)
|
||||||
|
self.parts_len += len(part)
|
||||||
|
self.parts_len -= size
|
||||||
|
whole = b''.join(self.parts)
|
||||||
|
self.parts = [whole[size:]]
|
||||||
|
return whole[:size]
|
||||||
|
|
||||||
|
|
||||||
|
class BinaryFramer(object):
|
||||||
|
"""A framer for binary messaging protocols."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.byte_queue = ByteQueue()
|
||||||
|
self.message_queue = Queue()
|
||||||
|
self.received_bytes = self.byte_queue.put_nowait
|
||||||
|
|
||||||
|
def frame(self, message):
|
||||||
|
command, payload = message
|
||||||
|
return b''.join((
|
||||||
|
self._build_header(command, payload),
|
||||||
|
payload
|
||||||
|
))
|
||||||
|
|
||||||
|
async def receive_message(self):
|
||||||
|
command, payload_len, checksum = await self._receive_header()
|
||||||
|
payload = await self.byte_queue.receive(payload_len)
|
||||||
|
payload_checksum = self._checksum(payload)
|
||||||
|
if payload_checksum != checksum:
|
||||||
|
raise BadChecksumError(payload_checksum, checksum)
|
||||||
|
return command, payload
|
||||||
|
|
||||||
|
def _checksum(self, payload):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _build_header(self, command, payload):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def _receive_header(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
# Helpers
|
||||||
|
struct_le_I = Struct('<I')
|
||||||
|
pack_le_uint32 = struct_le_I.pack
|
||||||
|
|
||||||
|
|
||||||
|
def sha256(x):
|
||||||
|
"""Simple wrapper of hashlib sha256."""
|
||||||
|
return _sha256(x).digest()
|
||||||
|
|
||||||
|
|
||||||
|
def double_sha256(x):
|
||||||
|
"""SHA-256 of SHA-256, as used extensively in bitcoin."""
|
||||||
|
return sha256(sha256(x))
|
||||||
|
|
||||||
|
|
||||||
|
class BadChecksumError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BadMagicError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class OversizedPayloadError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BitcoinFramer(BinaryFramer):
|
||||||
|
"""Provides a framer of binary message payloads in the style of the
|
||||||
|
Bitcoin network protocol.
|
||||||
|
|
||||||
|
Each binary message has the following elements, in order:
|
||||||
|
|
||||||
|
Magic - to confirm network (currently unused for stream sync)
|
||||||
|
Command - padded command
|
||||||
|
Length - payload length in bytes
|
||||||
|
Checksum - checksum of the payload
|
||||||
|
Payload - binary payload
|
||||||
|
|
||||||
|
Call frame(command, payload) to get a framed message.
|
||||||
|
Pass incoming network bytes to received_bytes().
|
||||||
|
Wait on receive_message() to get incoming (command, payload) pairs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, magic, max_block_size):
|
||||||
|
def pad_command(command):
|
||||||
|
fill = 12 - len(command)
|
||||||
|
if fill < 0:
|
||||||
|
raise ValueError(f'command {command} too long')
|
||||||
|
return command + bytes(fill)
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self._magic = magic
|
||||||
|
self._max_block_size = max_block_size
|
||||||
|
self._pad_command = pad_command
|
||||||
|
self._unpack = Struct(f'<4s12sI4s').unpack
|
||||||
|
|
||||||
|
def _checksum(self, payload):
|
||||||
|
return double_sha256(payload)[:4]
|
||||||
|
|
||||||
|
def _build_header(self, command, payload):
|
||||||
|
return b''.join((
|
||||||
|
self._magic,
|
||||||
|
self._pad_command(command),
|
||||||
|
pack_le_uint32(len(payload)),
|
||||||
|
self._checksum(payload)
|
||||||
|
))
|
||||||
|
|
||||||
|
async def _receive_header(self):
|
||||||
|
header = await self.byte_queue.receive(24)
|
||||||
|
magic, command, payload_len, checksum = self._unpack(header)
|
||||||
|
if magic != self._magic:
|
||||||
|
raise BadMagicError(magic, self._magic)
|
||||||
|
command = command.rstrip(b'\0')
|
||||||
|
if payload_len > 1024 * 1024:
|
||||||
|
if command != b'block' or payload_len > self._max_block_size:
|
||||||
|
raise OversizedPayloadError(command, payload_len)
|
||||||
|
return command, payload_len, checksum
|
801
torba/torba/rpc/jsonrpc.py
Normal file
801
torba/torba/rpc/jsonrpc.py
Normal file
|
@ -0,0 +1,801 @@
|
||||||
|
# Copyright (c) 2018, Neil Booth
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# The MIT License (MIT)
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
# a copy of this software and associated documentation files (the
|
||||||
|
# "Software"), to deal in the Software without restriction, including
|
||||||
|
# without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
# permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
# the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be
|
||||||
|
# included in all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
|
||||||
|
"""Classes for JSONRPC versions 1.0 and 2.0, and a loose interpretation."""
|
||||||
|
|
||||||
|
__all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose',
|
||||||
|
'JSONRPCAutoDetect', 'Request', 'Notification', 'Batch',
|
||||||
|
'RPCError', 'ProtocolError',
|
||||||
|
'JSONRPCConnection', 'handler_invocation')
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
import json
|
||||||
|
from functools import partial
|
||||||
|
from numbers import Number
|
||||||
|
|
||||||
|
import attr
|
||||||
|
from asyncio import Queue, Event, CancelledError
|
||||||
|
from .util import signature_info
|
||||||
|
|
||||||
|
|
||||||
|
class SingleRequest(object):
|
||||||
|
__slots__ = ('method', 'args')
|
||||||
|
|
||||||
|
def __init__(self, method, args):
|
||||||
|
if not isinstance(method, str):
|
||||||
|
raise ProtocolError(JSONRPC.METHOD_NOT_FOUND,
|
||||||
|
'method must be a string')
|
||||||
|
if not isinstance(args, (list, tuple, dict)):
|
||||||
|
raise ProtocolError.invalid_args('request arguments must be a '
|
||||||
|
'list or a dictionary')
|
||||||
|
self.args = args
|
||||||
|
self.method = method
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f'{self.__class__.__name__}({self.method!r}, {self.args!r})'
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return (isinstance(other, self.__class__) and
|
||||||
|
self.method == other.method and self.args == other.args)
|
||||||
|
|
||||||
|
|
||||||
|
class Request(SingleRequest):
|
||||||
|
def send_result(self, response):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class Notification(SingleRequest):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Batch:
|
||||||
|
__slots__ = ('items', )
|
||||||
|
|
||||||
|
def __init__(self, items):
|
||||||
|
if not isinstance(items, (list, tuple)):
|
||||||
|
raise ProtocolError.invalid_request('items must be a list')
|
||||||
|
if not items:
|
||||||
|
raise ProtocolError.empty_batch()
|
||||||
|
if not (all(isinstance(item, SingleRequest) for item in items) or
|
||||||
|
all(isinstance(item, Response) for item in items)):
|
||||||
|
raise ProtocolError.invalid_request('batch must be homogeneous')
|
||||||
|
self.items = items
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.items)
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return self.items[item]
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self.items)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f'Batch({len(self.items)} items)'
|
||||||
|
|
||||||
|
|
||||||
|
class Response(object):
|
||||||
|
__slots__ = ('result', )
|
||||||
|
|
||||||
|
def __init__(self, result):
|
||||||
|
# Type checking happens when converting to a message
|
||||||
|
self.result = result
|
||||||
|
|
||||||
|
|
||||||
|
class CodeMessageError(Exception):
|
||||||
|
|
||||||
|
def __init__(self, code, message):
|
||||||
|
super().__init__(code, message)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def code(self):
|
||||||
|
return self.args[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def message(self):
|
||||||
|
return self.args[1]
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return (isinstance(other, self.__class__) and
|
||||||
|
self.code == other.code and self.message == other.message)
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
# overridden to make the exception hashable
|
||||||
|
# see https://bugs.python.org/issue28603
|
||||||
|
return hash((self.code, self.message))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def invalid_args(cls, message):
|
||||||
|
return cls(JSONRPC.INVALID_ARGS, message)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def invalid_request(cls, message):
|
||||||
|
return cls(JSONRPC.INVALID_REQUEST, message)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def empty_batch(cls):
|
||||||
|
return cls.invalid_request('batch is empty')
|
||||||
|
|
||||||
|
|
||||||
|
class RPCError(CodeMessageError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ProtocolError(CodeMessageError):
|
||||||
|
|
||||||
|
def __init__(self, code, message):
|
||||||
|
super().__init__(code, message)
|
||||||
|
# If not None send this unframed message over the network
|
||||||
|
self.error_message = None
|
||||||
|
# If the error was in a JSON response message; its message ID.
|
||||||
|
# Since None can be a response message ID, "id" means the
|
||||||
|
# error was not sent in a JSON response
|
||||||
|
self.response_msg_id = id
|
||||||
|
|
||||||
|
|
||||||
|
class JSONRPC(object):
|
||||||
|
"""Abstract base class that interprets and constructs JSON RPC messages."""
|
||||||
|
|
||||||
|
# Error codes. See http://www.jsonrpc.org/specification
|
||||||
|
PARSE_ERROR = -32700
|
||||||
|
INVALID_REQUEST = -32600
|
||||||
|
METHOD_NOT_FOUND = -32601
|
||||||
|
INVALID_ARGS = -32602
|
||||||
|
INTERNAL_ERROR = -32603
|
||||||
|
# Codes specific to this library
|
||||||
|
ERROR_CODE_UNAVAILABLE = -100
|
||||||
|
|
||||||
|
# Can be overridden by derived classes
|
||||||
|
allow_batches = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _message_id(cls, message, require_id):
|
||||||
|
"""Validate the message is a dictionary and return its ID.
|
||||||
|
|
||||||
|
Raise an error if the message is invalid or the ID is of an
|
||||||
|
invalid type. If it has no ID, raise an error if require_id
|
||||||
|
is True, otherwise return None.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_message(cls, message):
|
||||||
|
"""Validate other parts of the message other than those
|
||||||
|
done in _message_id."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _request_args(cls, request):
|
||||||
|
"""Validate the existence and type of the arguments passed
|
||||||
|
in the request dictionary."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _process_request(cls, payload):
|
||||||
|
request_id = None
|
||||||
|
try:
|
||||||
|
request_id = cls._message_id(payload, False)
|
||||||
|
cls._validate_message(payload)
|
||||||
|
method = payload.get('method')
|
||||||
|
if request_id is None:
|
||||||
|
item = Notification(method, cls._request_args(payload))
|
||||||
|
else:
|
||||||
|
item = Request(method, cls._request_args(payload))
|
||||||
|
return item, request_id
|
||||||
|
except ProtocolError as error:
|
||||||
|
code, message = error.code, error.message
|
||||||
|
raise cls._error(code, message, True, request_id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _process_response(cls, payload):
|
||||||
|
request_id = None
|
||||||
|
try:
|
||||||
|
request_id = cls._message_id(payload, True)
|
||||||
|
cls._validate_message(payload)
|
||||||
|
return Response(cls.response_value(payload)), request_id
|
||||||
|
except ProtocolError as error:
|
||||||
|
code, message = error.code, error.message
|
||||||
|
raise cls._error(code, message, False, request_id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _message_to_payload(cls, message):
|
||||||
|
"""Returns a Python object or a ProtocolError."""
|
||||||
|
try:
|
||||||
|
return json.loads(message.decode())
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
message = 'messages must be encoded in UTF-8'
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
message = 'invalid JSON'
|
||||||
|
raise cls._error(cls.PARSE_ERROR, message, True, None)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _error(cls, code, message, send, msg_id):
|
||||||
|
error = ProtocolError(code, message)
|
||||||
|
if send:
|
||||||
|
error.error_message = cls.response_message(error, msg_id)
|
||||||
|
else:
|
||||||
|
error.response_msg_id = msg_id
|
||||||
|
return error
|
||||||
|
|
||||||
|
#
|
||||||
|
# External API
|
||||||
|
#
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def message_to_item(cls, message):
|
||||||
|
"""Translate an unframed received message and return an
|
||||||
|
(item, request_id) pair.
|
||||||
|
|
||||||
|
The item can be a Request, Notification, Response or a list.
|
||||||
|
|
||||||
|
A JSON RPC error response is returned as an RPCError inside a
|
||||||
|
Response object.
|
||||||
|
|
||||||
|
If a Batch is returned, request_id is an iterable of request
|
||||||
|
ids, one per batch member.
|
||||||
|
|
||||||
|
If the message violates the protocol in some way a
|
||||||
|
ProtocolError is returned, except if the message was
|
||||||
|
determined to be a response, in which case the ProtocolError
|
||||||
|
is placed inside a Response object. This is so that client
|
||||||
|
code can mark a request as having been responded to even if
|
||||||
|
the response was bad.
|
||||||
|
|
||||||
|
raises: ProtocolError
|
||||||
|
"""
|
||||||
|
payload = cls._message_to_payload(message)
|
||||||
|
if isinstance(payload, dict):
|
||||||
|
if 'method' in payload:
|
||||||
|
return cls._process_request(payload)
|
||||||
|
else:
|
||||||
|
return cls._process_response(payload)
|
||||||
|
elif isinstance(payload, list) and cls.allow_batches:
|
||||||
|
if not payload:
|
||||||
|
raise cls._error(JSONRPC.INVALID_REQUEST, 'batch is empty',
|
||||||
|
True, None)
|
||||||
|
return payload, None
|
||||||
|
raise cls._error(cls.INVALID_REQUEST,
|
||||||
|
'request object must be a dictionary', True, None)
|
||||||
|
|
||||||
|
# Message formation
|
||||||
|
@classmethod
|
||||||
|
def request_message(cls, item, request_id):
|
||||||
|
"""Convert an RPCRequest item to a message."""
|
||||||
|
assert isinstance(item, Request)
|
||||||
|
return cls.encode_payload(cls.request_payload(item, request_id))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def notification_message(cls, item):
|
||||||
|
"""Convert an RPCRequest item to a message."""
|
||||||
|
assert isinstance(item, Notification)
|
||||||
|
return cls.encode_payload(cls.request_payload(item, None))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def response_message(cls, result, request_id):
|
||||||
|
"""Convert a response result (or RPCError) to a message."""
|
||||||
|
if isinstance(result, CodeMessageError):
|
||||||
|
payload = cls.error_payload(result, request_id)
|
||||||
|
else:
|
||||||
|
payload = cls.response_payload(result, request_id)
|
||||||
|
return cls.encode_payload(payload)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def batch_message(cls, batch, request_ids):
|
||||||
|
"""Convert a request Batch to a message."""
|
||||||
|
assert isinstance(batch, Batch)
|
||||||
|
if not cls.allow_batches:
|
||||||
|
raise ProtocolError.invalid_request(
|
||||||
|
'protocol does not permit batches')
|
||||||
|
id_iter = iter(request_ids)
|
||||||
|
rm = cls.request_message
|
||||||
|
nm = cls.notification_message
|
||||||
|
parts = (rm(request, next(id_iter)) if isinstance(request, Request)
|
||||||
|
else nm(request) for request in batch)
|
||||||
|
return cls.batch_message_from_parts(parts)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def batch_message_from_parts(cls, messages):
|
||||||
|
"""Convert messages, one per batch item, into a batch message. At
|
||||||
|
least one message must be passed.
|
||||||
|
"""
|
||||||
|
# Comma-separate the messages and wrap the lot in square brackets
|
||||||
|
middle = b', '.join(messages)
|
||||||
|
if not middle:
|
||||||
|
raise ProtocolError.empty_batch()
|
||||||
|
return b''.join([b'[', middle, b']'])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def encode_payload(cls, payload):
|
||||||
|
"""Encode a Python object as JSON and convert it to bytes."""
|
||||||
|
try:
|
||||||
|
return json.dumps(payload).encode()
|
||||||
|
except TypeError:
|
||||||
|
msg = f'JSON payload encoding error: {payload}'
|
||||||
|
raise ProtocolError(cls.INTERNAL_ERROR, msg) from None
|
||||||
|
|
||||||
|
|
||||||
|
class JSONRPCv1(JSONRPC):
|
||||||
|
"""JSON RPC version 1.0."""
|
||||||
|
|
||||||
|
allow_batches = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _message_id(cls, message, require_id):
|
||||||
|
# JSONv1 requires an ID always, but without constraint on its type
|
||||||
|
# No need to test for a dictionary here as we don't handle batches.
|
||||||
|
if 'id' not in message:
|
||||||
|
raise ProtocolError.invalid_request('request has no "id"')
|
||||||
|
return message['id']
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _request_args(cls, request):
|
||||||
|
args = request.get('params')
|
||||||
|
if not isinstance(args, list):
|
||||||
|
raise ProtocolError.invalid_args(
|
||||||
|
f'invalid request arguments: {args}')
|
||||||
|
return args
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _best_effort_error(cls, error):
|
||||||
|
# Do our best to interpret the error
|
||||||
|
code = cls.ERROR_CODE_UNAVAILABLE
|
||||||
|
message = 'no error message provided'
|
||||||
|
if isinstance(error, str):
|
||||||
|
message = error
|
||||||
|
elif isinstance(error, int):
|
||||||
|
code = error
|
||||||
|
elif isinstance(error, dict):
|
||||||
|
if isinstance(error.get('message'), str):
|
||||||
|
message = error['message']
|
||||||
|
if isinstance(error.get('code'), int):
|
||||||
|
code = error['code']
|
||||||
|
|
||||||
|
return RPCError(code, message)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def response_value(cls, payload):
|
||||||
|
if 'result' not in payload or 'error' not in payload:
|
||||||
|
raise ProtocolError.invalid_request(
|
||||||
|
'response must contain both "result" and "error"')
|
||||||
|
|
||||||
|
result = payload['result']
|
||||||
|
error = payload['error']
|
||||||
|
if error is None:
|
||||||
|
return result # It seems None can be a valid result
|
||||||
|
if result is not None:
|
||||||
|
raise ProtocolError.invalid_request(
|
||||||
|
'response has a "result" and an "error"')
|
||||||
|
|
||||||
|
return cls._best_effort_error(error)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def request_payload(cls, request, request_id):
|
||||||
|
"""JSON v1 request (or notification) payload."""
|
||||||
|
if isinstance(request.args, dict):
|
||||||
|
raise ProtocolError.invalid_args(
|
||||||
|
'JSONRPCv1 does not support named arguments')
|
||||||
|
return {
|
||||||
|
'method': request.method,
|
||||||
|
'params': request.args,
|
||||||
|
'id': request_id
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def response_payload(cls, result, request_id):
|
||||||
|
"""JSON v1 response payload."""
|
||||||
|
return {
|
||||||
|
'result': result,
|
||||||
|
'error': None,
|
||||||
|
'id': request_id
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def error_payload(cls, error, request_id):
|
||||||
|
return {
|
||||||
|
'result': None,
|
||||||
|
'error': {'code': error.code, 'message': error.message},
|
||||||
|
'id': request_id
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class JSONRPCv2(JSONRPC):
|
||||||
|
"""JSON RPC version 2.0."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _message_id(cls, message, require_id):
|
||||||
|
if not isinstance(message, dict):
|
||||||
|
raise ProtocolError.invalid_request(
|
||||||
|
'request object must be a dictionary')
|
||||||
|
if 'id' in message:
|
||||||
|
request_id = message['id']
|
||||||
|
if not isinstance(request_id, (Number, str, type(None))):
|
||||||
|
raise ProtocolError.invalid_request(
|
||||||
|
f'invalid "id": {request_id}')
|
||||||
|
return request_id
|
||||||
|
else:
|
||||||
|
if require_id:
|
||||||
|
raise ProtocolError.invalid_request('request has no "id"')
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_message(cls, message):
|
||||||
|
if message.get('jsonrpc') != '2.0':
|
||||||
|
raise ProtocolError.invalid_request('"jsonrpc" is not "2.0"')
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _request_args(cls, request):
|
||||||
|
args = request.get('params', [])
|
||||||
|
if not isinstance(args, (dict, list)):
|
||||||
|
raise ProtocolError.invalid_args(
|
||||||
|
f'invalid request arguments: {args}')
|
||||||
|
return args
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def response_value(cls, payload):
|
||||||
|
if 'result' in payload:
|
||||||
|
if 'error' in payload:
|
||||||
|
raise ProtocolError.invalid_request(
|
||||||
|
'response contains both "result" and "error"')
|
||||||
|
return payload['result']
|
||||||
|
|
||||||
|
if 'error' not in payload:
|
||||||
|
raise ProtocolError.invalid_request(
|
||||||
|
'response contains neither "result" nor "error"')
|
||||||
|
|
||||||
|
# Return an RPCError object
|
||||||
|
error = payload['error']
|
||||||
|
if isinstance(error, dict):
|
||||||
|
code = error.get('code')
|
||||||
|
message = error.get('message')
|
||||||
|
if isinstance(code, int) and isinstance(message, str):
|
||||||
|
return RPCError(code, message)
|
||||||
|
|
||||||
|
raise ProtocolError.invalid_request(
|
||||||
|
f'ill-formed response error object: {error}')
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def request_payload(cls, request, request_id):
|
||||||
|
"""JSON v2 request (or notification) payload."""
|
||||||
|
payload = {
|
||||||
|
'jsonrpc': '2.0',
|
||||||
|
'method': request.method,
|
||||||
|
}
|
||||||
|
# A notification?
|
||||||
|
if request_id is not None:
|
||||||
|
payload['id'] = request_id
|
||||||
|
# Preserve empty dicts as missing params is read as an array
|
||||||
|
if request.args or request.args == {}:
|
||||||
|
payload['params'] = request.args
|
||||||
|
return payload
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def response_payload(cls, result, request_id):
|
||||||
|
"""JSON v2 response payload."""
|
||||||
|
return {
|
||||||
|
'jsonrpc': '2.0',
|
||||||
|
'result': result,
|
||||||
|
'id': request_id
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def error_payload(cls, error, request_id):
|
||||||
|
return {
|
||||||
|
'jsonrpc': '2.0',
|
||||||
|
'error': {'code': error.code, 'message': error.message},
|
||||||
|
'id': request_id
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class JSONRPCLoose(JSONRPC):
|
||||||
|
"""A relaxed versin of JSON RPC."""
|
||||||
|
|
||||||
|
# Don't be so loose we accept any old message ID
|
||||||
|
_message_id = JSONRPCv2._message_id
|
||||||
|
_validate_message = JSONRPC._validate_message
|
||||||
|
_request_args = JSONRPCv2._request_args
|
||||||
|
# Outoing messages are JSONRPCv2 so we give the other side the
|
||||||
|
# best chance to assume / detect JSONRPCv2 as default protocol.
|
||||||
|
error_payload = JSONRPCv2.error_payload
|
||||||
|
request_payload = JSONRPCv2.request_payload
|
||||||
|
response_payload = JSONRPCv2.response_payload
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def response_value(cls, payload):
|
||||||
|
# Return result, unless it is None and there is an error
|
||||||
|
if payload.get('error') is not None:
|
||||||
|
if payload.get('result') is not None:
|
||||||
|
raise ProtocolError.invalid_request(
|
||||||
|
'response contains both "result" and "error"')
|
||||||
|
return JSONRPCv1._best_effort_error(payload['error'])
|
||||||
|
|
||||||
|
if 'result' not in payload:
|
||||||
|
raise ProtocolError.invalid_request(
|
||||||
|
'response contains neither "result" nor "error"')
|
||||||
|
|
||||||
|
# Can be None
|
||||||
|
return payload['result']
|
||||||
|
|
||||||
|
|
||||||
|
class JSONRPCAutoDetect(JSONRPCv2):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def message_to_item(cls, message):
|
||||||
|
return cls.detect_protocol(message), None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def detect_protocol(cls, message):
|
||||||
|
"""Attempt to detect the protocol from the message."""
|
||||||
|
main = cls._message_to_payload(message)
|
||||||
|
|
||||||
|
def protocol_for_payload(payload):
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
return JSONRPCLoose # Will error
|
||||||
|
# Obey an explicit "jsonrpc"
|
||||||
|
version = payload.get('jsonrpc')
|
||||||
|
if version == '2.0':
|
||||||
|
return JSONRPCv2
|
||||||
|
if version == '1.0':
|
||||||
|
return JSONRPCv1
|
||||||
|
|
||||||
|
# Now to decide between JSONRPCLoose and JSONRPCv1 if possible
|
||||||
|
if 'result' in payload and 'error' in payload:
|
||||||
|
return JSONRPCv1
|
||||||
|
return JSONRPCLoose
|
||||||
|
|
||||||
|
if isinstance(main, list):
|
||||||
|
parts = set(protocol_for_payload(payload) for payload in main)
|
||||||
|
# If all same protocol, return it
|
||||||
|
if len(parts) == 1:
|
||||||
|
return parts.pop()
|
||||||
|
# If strict protocol detected, return it, preferring JSONRPCv2.
|
||||||
|
# This means a batch of JSONRPCv1 will fail
|
||||||
|
for protocol in (JSONRPCv2, JSONRPCv1):
|
||||||
|
if protocol in parts:
|
||||||
|
return protocol
|
||||||
|
# Will error if no parts
|
||||||
|
return JSONRPCLoose
|
||||||
|
|
||||||
|
return protocol_for_payload(main)
|
||||||
|
|
||||||
|
|
||||||
|
class JSONRPCConnection(object):
|
||||||
|
"""Maintains state of a JSON RPC connection, in particular
|
||||||
|
encapsulating the handling of request IDs.
|
||||||
|
|
||||||
|
protocol - the JSON RPC protocol to follow
|
||||||
|
max_response_size - responses over this size send an error response
|
||||||
|
instead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_id_counter = itertools.count()
|
||||||
|
|
||||||
|
def __init__(self, protocol):
|
||||||
|
self._protocol = protocol
|
||||||
|
# Sent Requests and Batches that have not received a response.
|
||||||
|
# The key is its request ID; for a batch it is sorted tuple
|
||||||
|
# of request IDs
|
||||||
|
self._requests = {}
|
||||||
|
# A public attribute intended to be settable dynamically
|
||||||
|
self.max_response_size = 0
|
||||||
|
|
||||||
|
def _oversized_response_message(self, request_id):
|
||||||
|
text = f'response too large (over {self.max_response_size:,d} bytes'
|
||||||
|
error = RPCError.invalid_request(text)
|
||||||
|
return self._protocol.response_message(error, request_id)
|
||||||
|
|
||||||
|
def _receive_response(self, result, request_id):
|
||||||
|
if request_id not in self._requests:
|
||||||
|
if request_id is None and isinstance(result, RPCError):
|
||||||
|
message = f'diagnostic error received: {result}'
|
||||||
|
else:
|
||||||
|
message = f'response to unsent request (ID: {request_id})'
|
||||||
|
raise ProtocolError.invalid_request(message) from None
|
||||||
|
request, event = self._requests.pop(request_id)
|
||||||
|
event.result = result
|
||||||
|
event.set()
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _receive_request_batch(self, payloads):
|
||||||
|
def item_send_result(request_id, result):
|
||||||
|
nonlocal size
|
||||||
|
part = protocol.response_message(result, request_id)
|
||||||
|
size += len(part) + 2
|
||||||
|
if size > self.max_response_size > 0:
|
||||||
|
part = self._oversized_response_message(request_id)
|
||||||
|
parts.append(part)
|
||||||
|
if len(parts) == count:
|
||||||
|
return protocol.batch_message_from_parts(parts)
|
||||||
|
return None
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
items = []
|
||||||
|
size = 0
|
||||||
|
count = 0
|
||||||
|
protocol = self._protocol
|
||||||
|
for payload in payloads:
|
||||||
|
try:
|
||||||
|
item, request_id = protocol._process_request(payload)
|
||||||
|
items.append(item)
|
||||||
|
if isinstance(item, Request):
|
||||||
|
count += 1
|
||||||
|
item.send_result = partial(item_send_result, request_id)
|
||||||
|
except ProtocolError as error:
|
||||||
|
count += 1
|
||||||
|
parts.append(error.error_message)
|
||||||
|
|
||||||
|
if not items and parts:
|
||||||
|
protocol_error = ProtocolError(0, "")
|
||||||
|
protocol_error.error_message = protocol.batch_message_from_parts(parts)
|
||||||
|
raise protocol_error
|
||||||
|
return items
|
||||||
|
|
||||||
|
def _receive_response_batch(self, payloads):
|
||||||
|
request_ids = []
|
||||||
|
results = []
|
||||||
|
for payload in payloads:
|
||||||
|
# Let ProtocolError exceptions through
|
||||||
|
item, request_id = self._protocol._process_response(payload)
|
||||||
|
request_ids.append(request_id)
|
||||||
|
results.append(item.result)
|
||||||
|
|
||||||
|
ordered = sorted(zip(request_ids, results), key=lambda t: t[0])
|
||||||
|
ordered_ids, ordered_results = zip(*ordered)
|
||||||
|
if ordered_ids not in self._requests:
|
||||||
|
raise ProtocolError.invalid_request('response to unsent batch')
|
||||||
|
request_batch, event = self._requests.pop(ordered_ids)
|
||||||
|
event.result = ordered_results
|
||||||
|
event.set()
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _send_result(self, request_id, result):
|
||||||
|
message = self._protocol.response_message(result, request_id)
|
||||||
|
if len(message) > self.max_response_size > 0:
|
||||||
|
message = self._oversized_response_message(request_id)
|
||||||
|
return message
|
||||||
|
|
||||||
|
def _event(self, request, request_id):
|
||||||
|
event = Event()
|
||||||
|
self._requests[request_id] = (request, event)
|
||||||
|
return event
|
||||||
|
|
||||||
|
#
|
||||||
|
# External API
|
||||||
|
#
|
||||||
|
def send_request(self, request):
|
||||||
|
"""Send a Request. Return a (message, event) pair.
|
||||||
|
|
||||||
|
The message is an unframed message to send over the network.
|
||||||
|
Wait on the event for the response; which will be in the
|
||||||
|
"result" attribute.
|
||||||
|
|
||||||
|
Raises: ProtocolError if the request violates the protocol
|
||||||
|
in some way..
|
||||||
|
"""
|
||||||
|
request_id = next(self._id_counter)
|
||||||
|
message = self._protocol.request_message(request, request_id)
|
||||||
|
return message, self._event(request, request_id)
|
||||||
|
|
||||||
|
def send_notification(self, notification):
|
||||||
|
return self._protocol.notification_message(notification)
|
||||||
|
|
||||||
|
def send_batch(self, batch):
|
||||||
|
ids = tuple(next(self._id_counter)
|
||||||
|
for request in batch if isinstance(request, Request))
|
||||||
|
message = self._protocol.batch_message(batch, ids)
|
||||||
|
event = self._event(batch, ids) if ids else None
|
||||||
|
return message, event
|
||||||
|
|
||||||
|
def receive_message(self, message):
|
||||||
|
"""Call with an unframed message received from the network.
|
||||||
|
|
||||||
|
Raises: ProtocolError if the message violates the protocol in
|
||||||
|
some way. However, if it happened in a response that can be
|
||||||
|
paired with a request, the ProtocolError is instead set in the
|
||||||
|
result attribute of the send_request() that caused the error.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
item, request_id = self._protocol.message_to_item(message)
|
||||||
|
except ProtocolError as e:
|
||||||
|
if e.response_msg_id is not id:
|
||||||
|
return self._receive_response(e, e.response_msg_id)
|
||||||
|
raise
|
||||||
|
|
||||||
|
if isinstance(item, Request):
|
||||||
|
item.send_result = partial(self._send_result, request_id)
|
||||||
|
return [item]
|
||||||
|
if isinstance(item, Notification):
|
||||||
|
return [item]
|
||||||
|
if isinstance(item, Response):
|
||||||
|
return self._receive_response(item.result, request_id)
|
||||||
|
if isinstance(item, list):
|
||||||
|
if all(isinstance(payload, dict)
|
||||||
|
and ('result' in payload or 'error' in payload)
|
||||||
|
for payload in item):
|
||||||
|
return self._receive_response_batch(item)
|
||||||
|
else:
|
||||||
|
return self._receive_request_batch(item)
|
||||||
|
else:
|
||||||
|
# Protocol auto-detection hack
|
||||||
|
assert issubclass(item, JSONRPC)
|
||||||
|
self._protocol = item
|
||||||
|
return self.receive_message(message)
|
||||||
|
|
||||||
|
def cancel_pending_requests(self):
|
||||||
|
"""Cancel all pending requests."""
|
||||||
|
exception = CancelledError()
|
||||||
|
for request, event in self._requests.values():
|
||||||
|
event.result = exception
|
||||||
|
event.set()
|
||||||
|
self._requests.clear()
|
||||||
|
|
||||||
|
def pending_requests(self):
|
||||||
|
"""All sent requests that have not received a response."""
|
||||||
|
return [request for request, event in self._requests.values()]
|
||||||
|
|
||||||
|
|
||||||
|
def handler_invocation(handler, request):
|
||||||
|
method, args = request.method, request.args
|
||||||
|
if handler is None:
|
||||||
|
raise RPCError(JSONRPC.METHOD_NOT_FOUND,
|
||||||
|
f'unknown method "{method}"')
|
||||||
|
|
||||||
|
# We must test for too few and too many arguments. How
|
||||||
|
# depends on whether the arguments were passed as a list or as
|
||||||
|
# a dictionary.
|
||||||
|
info = signature_info(handler)
|
||||||
|
if isinstance(args, (tuple, list)):
|
||||||
|
if len(args) < info.min_args:
|
||||||
|
s = '' if len(args) == 1 else 's'
|
||||||
|
raise RPCError.invalid_args(
|
||||||
|
f'{len(args)} argument{s} passed to method '
|
||||||
|
f'"{method}" but it requires {info.min_args}')
|
||||||
|
if info.max_args is not None and len(args) > info.max_args:
|
||||||
|
s = '' if len(args) == 1 else 's'
|
||||||
|
raise RPCError.invalid_args(
|
||||||
|
f'{len(args)} argument{s} passed to method '
|
||||||
|
f'{method} taking at most {info.max_args}')
|
||||||
|
return partial(handler, *args)
|
||||||
|
|
||||||
|
# Arguments passed by name
|
||||||
|
if info.other_names is None:
|
||||||
|
raise RPCError.invalid_args(f'method "{method}" cannot '
|
||||||
|
f'be called with named arguments')
|
||||||
|
|
||||||
|
missing = set(info.required_names).difference(args)
|
||||||
|
if missing:
|
||||||
|
s = '' if len(missing) == 1 else 's'
|
||||||
|
missing = ', '.join(sorted(f'"{name}"' for name in missing))
|
||||||
|
raise RPCError.invalid_args(f'method "{method}" requires '
|
||||||
|
f'parameter{s} {missing}')
|
||||||
|
|
||||||
|
if info.other_names is not any:
|
||||||
|
excess = set(args).difference(info.required_names)
|
||||||
|
excess = excess.difference(info.other_names)
|
||||||
|
if excess:
|
||||||
|
s = '' if len(excess) == 1 else 's'
|
||||||
|
excess = ', '.join(sorted(f'"{name}"' for name in excess))
|
||||||
|
raise RPCError.invalid_args(f'method "{method}" does not '
|
||||||
|
f'take parameter{s} {excess}')
|
||||||
|
return partial(handler, **args)
|
529
torba/torba/rpc/session.py
Normal file
529
torba/torba/rpc/session.py
Normal file
|
@ -0,0 +1,529 @@
|
||||||
|
# Copyright (c) 2018, Neil Booth
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# The MIT License (MIT)
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
# a copy of this software and associated documentation files (the
|
||||||
|
# "Software"), to deal in the Software without restriction, including
|
||||||
|
# without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
# permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
# the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be
|
||||||
|
# included in all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ('Connector', 'RPCSession', 'MessageSession', 'Server',
|
||||||
|
'BatchError')
|
||||||
|
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from asyncio import Event, CancelledError
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from contextlib import suppress
|
||||||
|
|
||||||
|
from torba.tasks import TaskGroup
|
||||||
|
|
||||||
|
from .jsonrpc import Request, JSONRPCConnection, JSONRPCv2, JSONRPC, Batch, Notification
|
||||||
|
from .jsonrpc import RPCError, ProtocolError
|
||||||
|
from .framing import BadMagicError, BadChecksumError, OversizedPayloadError, BitcoinFramer, NewlineFramer
|
||||||
|
from .util import Concurrency
|
||||||
|
|
||||||
|
|
||||||
|
class Connector:
|
||||||
|
|
||||||
|
def __init__(self, session_factory, host=None, port=None, proxy=None,
|
||||||
|
**kwargs):
|
||||||
|
self.session_factory = session_factory
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.proxy = proxy
|
||||||
|
self.loop = kwargs.get('loop', asyncio.get_event_loop())
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
async def create_connection(self):
|
||||||
|
"""Initiate a connection."""
|
||||||
|
connector = self.proxy or self.loop
|
||||||
|
return await connector.create_connection(
|
||||||
|
self.session_factory, self.host, self.port, **self.kwargs)
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
transport, self.protocol = await self.create_connection()
|
||||||
|
# By default, do not limit outgoing connections
|
||||||
|
self.protocol.bw_limit = 0
|
||||||
|
return self.protocol
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||||
|
await self.protocol.close()
|
||||||
|
|
||||||
|
|
||||||
|
class SessionBase(asyncio.Protocol):
|
||||||
|
"""Base class of networking sessions.
|
||||||
|
|
||||||
|
There is no client / server distinction other than who initiated
|
||||||
|
the connection.
|
||||||
|
|
||||||
|
To initiate a connection to a remote server pass host, port and
|
||||||
|
proxy to the constructor, and then call create_connection(). Each
|
||||||
|
successful call should have a corresponding call to close().
|
||||||
|
|
||||||
|
Alternatively if used in a with statement, the connection is made
|
||||||
|
on entry to the block, and closed on exit from the block.
|
||||||
|
"""
|
||||||
|
|
||||||
|
max_errors = 10
|
||||||
|
|
||||||
|
def __init__(self, *, framer=None, loop=None):
|
||||||
|
self.framer = framer or self.default_framer()
|
||||||
|
self.loop = loop or asyncio.get_event_loop()
|
||||||
|
self.logger = logging.getLogger(self.__class__.__name__)
|
||||||
|
self.transport = None
|
||||||
|
# Set when a connection is made
|
||||||
|
self._address = None
|
||||||
|
self._proxy_address = None
|
||||||
|
# For logger.debug messsages
|
||||||
|
self.verbosity = 0
|
||||||
|
# Cleared when the send socket is full
|
||||||
|
self._can_send = Event()
|
||||||
|
self._can_send.set()
|
||||||
|
self._pm_task = None
|
||||||
|
self._task_group = TaskGroup(self.loop)
|
||||||
|
# Force-close a connection if a send doesn't succeed in this time
|
||||||
|
self.max_send_delay = 60
|
||||||
|
# Statistics. The RPC object also keeps its own statistics.
|
||||||
|
self.start_time = time.time()
|
||||||
|
self.errors = 0
|
||||||
|
self.send_count = 0
|
||||||
|
self.send_size = 0
|
||||||
|
self.last_send = self.start_time
|
||||||
|
self.recv_count = 0
|
||||||
|
self.recv_size = 0
|
||||||
|
self.last_recv = self.start_time
|
||||||
|
# Bandwidth usage per hour before throttling starts
|
||||||
|
self.bw_limit = 2000000
|
||||||
|
self.bw_time = self.start_time
|
||||||
|
self.bw_charge = 0
|
||||||
|
# Concurrency control
|
||||||
|
self.max_concurrent = 6
|
||||||
|
self._concurrency = Concurrency(self.max_concurrent)
|
||||||
|
|
||||||
|
async def _update_concurrency(self):
|
||||||
|
# A non-positive value means not to limit concurrency
|
||||||
|
if self.bw_limit <= 0:
|
||||||
|
return
|
||||||
|
now = time.time()
|
||||||
|
# Reduce the recorded usage in proportion to the elapsed time
|
||||||
|
refund = (now - self.bw_time) * (self.bw_limit / 3600)
|
||||||
|
self.bw_charge = max(0, self.bw_charge - int(refund))
|
||||||
|
self.bw_time = now
|
||||||
|
# Reduce concurrency allocation by 1 for each whole bw_limit used
|
||||||
|
throttle = int(self.bw_charge / self.bw_limit)
|
||||||
|
target = max(1, self.max_concurrent - throttle)
|
||||||
|
current = self._concurrency.max_concurrent
|
||||||
|
if target != current:
|
||||||
|
self.logger.info(f'changing task concurrency from {current} '
|
||||||
|
f'to {target}')
|
||||||
|
await self._concurrency.set_max_concurrent(target)
|
||||||
|
|
||||||
|
def _using_bandwidth(self, size):
|
||||||
|
"""Called when sending or receiving size bytes."""
|
||||||
|
self.bw_charge += size
|
||||||
|
|
||||||
|
async def _limited_wait(self, secs):
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._can_send.wait(), secs)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self.abort()
|
||||||
|
raise asyncio.CancelledError(f'task timed out after {secs}s')
|
||||||
|
|
||||||
|
async def _send_message(self, message):
|
||||||
|
if not self._can_send.is_set():
|
||||||
|
await self._limited_wait(self.max_send_delay)
|
||||||
|
if not self.is_closing():
|
||||||
|
framed_message = self.framer.frame(message)
|
||||||
|
self.send_size += len(framed_message)
|
||||||
|
self._using_bandwidth(len(framed_message))
|
||||||
|
self.send_count += 1
|
||||||
|
self.last_send = time.time()
|
||||||
|
if self.verbosity >= 4:
|
||||||
|
self.logger.debug(f'Sending framed message {framed_message}')
|
||||||
|
self.transport.write(framed_message)
|
||||||
|
|
||||||
|
def _bump_errors(self):
|
||||||
|
self.errors += 1
|
||||||
|
if self.errors >= self.max_errors:
|
||||||
|
# Don't await self.close() because that is self-cancelling
|
||||||
|
self._close()
|
||||||
|
|
||||||
|
def _close(self):
|
||||||
|
if self.transport:
|
||||||
|
self.transport.close()
|
||||||
|
|
||||||
|
# asyncio framework
|
||||||
|
def data_received(self, framed_message):
|
||||||
|
"""Called by asyncio when a message comes in."""
|
||||||
|
if self.verbosity >= 4:
|
||||||
|
self.logger.debug(f'Received framed message {framed_message}')
|
||||||
|
self.recv_size += len(framed_message)
|
||||||
|
self._using_bandwidth(len(framed_message))
|
||||||
|
self.framer.received_bytes(framed_message)
|
||||||
|
|
||||||
|
def pause_writing(self):
|
||||||
|
"""Transport calls when the send buffer is full."""
|
||||||
|
if not self.is_closing():
|
||||||
|
self._can_send.clear()
|
||||||
|
self.transport.pause_reading()
|
||||||
|
|
||||||
|
def resume_writing(self):
|
||||||
|
"""Transport calls when the send buffer has room."""
|
||||||
|
if not self._can_send.is_set():
|
||||||
|
self._can_send.set()
|
||||||
|
self.transport.resume_reading()
|
||||||
|
|
||||||
|
def connection_made(self, transport):
|
||||||
|
"""Called by asyncio when a connection is established.
|
||||||
|
|
||||||
|
Derived classes overriding this method must call this first."""
|
||||||
|
self.transport = transport
|
||||||
|
# This would throw if called on a closed SSL transport. Fixed
|
||||||
|
# in asyncio in Python 3.6.1 and 3.5.4
|
||||||
|
peer_address = transport.get_extra_info('peername')
|
||||||
|
# If the Socks proxy was used then _address is already set to
|
||||||
|
# the remote address
|
||||||
|
if self._address:
|
||||||
|
self._proxy_address = peer_address
|
||||||
|
else:
|
||||||
|
self._address = peer_address
|
||||||
|
self._pm_task = self.loop.create_task(self._receive_messages())
|
||||||
|
|
||||||
|
def connection_lost(self, exc):
|
||||||
|
"""Called by asyncio when the connection closes.
|
||||||
|
|
||||||
|
Tear down things done in connection_made."""
|
||||||
|
self._address = None
|
||||||
|
self.transport = None
|
||||||
|
self._task_group.cancel()
|
||||||
|
self._pm_task.cancel()
|
||||||
|
# Release waiting tasks
|
||||||
|
self._can_send.set()
|
||||||
|
|
||||||
|
# External API
|
||||||
|
def default_framer(self):
|
||||||
|
"""Return a default framer."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def peer_address(self):
|
||||||
|
"""Returns the peer's address (Python networking address), or None if
|
||||||
|
no connection or an error.
|
||||||
|
|
||||||
|
This is the result of socket.getpeername() when the connection
|
||||||
|
was made.
|
||||||
|
"""
|
||||||
|
return self._address
|
||||||
|
|
||||||
|
def peer_address_str(self):
|
||||||
|
"""Returns the peer's IP address and port as a human-readable
|
||||||
|
string."""
|
||||||
|
if not self._address:
|
||||||
|
return 'unknown'
|
||||||
|
ip_addr_str, port = self._address[:2]
|
||||||
|
if ':' in ip_addr_str:
|
||||||
|
return f'[{ip_addr_str}]:{port}'
|
||||||
|
else:
|
||||||
|
return f'{ip_addr_str}:{port}'
|
||||||
|
|
||||||
|
def is_closing(self):
|
||||||
|
"""Return True if the connection is closing."""
|
||||||
|
return not self.transport or self.transport.is_closing()
|
||||||
|
|
||||||
|
def abort(self):
|
||||||
|
"""Forcefully close the connection."""
|
||||||
|
if self.transport:
|
||||||
|
self.transport.abort()
|
||||||
|
|
||||||
|
async def close(self, *, force_after=30):
|
||||||
|
"""Close the connection and return when closed."""
|
||||||
|
self._close()
|
||||||
|
if self._pm_task:
|
||||||
|
with suppress(CancelledError):
|
||||||
|
await asyncio.wait([self._pm_task], timeout=force_after)
|
||||||
|
self.abort()
|
||||||
|
await self._pm_task
|
||||||
|
|
||||||
|
|
||||||
|
class MessageSession(SessionBase):
|
||||||
|
"""Session class for protocols where messages are not tied to responses,
|
||||||
|
such as the Bitcoin protocol.
|
||||||
|
|
||||||
|
To use as a client (connection-opening) session, pass host, port
|
||||||
|
and perhaps a proxy.
|
||||||
|
"""
|
||||||
|
async def _receive_messages(self):
|
||||||
|
while not self.is_closing():
|
||||||
|
try:
|
||||||
|
message = await self.framer.receive_message()
|
||||||
|
except BadMagicError as e:
|
||||||
|
magic, expected = e.args
|
||||||
|
self.logger.error(
|
||||||
|
f'bad network magic: got {magic} expected {expected}, '
|
||||||
|
f'disconnecting'
|
||||||
|
)
|
||||||
|
self._close()
|
||||||
|
except OversizedPayloadError as e:
|
||||||
|
command, payload_len = e.args
|
||||||
|
self.logger.error(
|
||||||
|
f'oversized payload of {payload_len:,d} bytes to command '
|
||||||
|
f'{command}, disconnecting'
|
||||||
|
)
|
||||||
|
self._close()
|
||||||
|
except BadChecksumError as e:
|
||||||
|
payload_checksum, claimed_checksum = e.args
|
||||||
|
self.logger.warning(
|
||||||
|
f'checksum mismatch: actual {payload_checksum.hex()} '
|
||||||
|
f'vs claimed {claimed_checksum.hex()}'
|
||||||
|
)
|
||||||
|
self._bump_errors()
|
||||||
|
else:
|
||||||
|
self.last_recv = time.time()
|
||||||
|
self.recv_count += 1
|
||||||
|
if self.recv_count % 10 == 0:
|
||||||
|
await self._update_concurrency()
|
||||||
|
await self._task_group.add(self._throttled_message(message))
|
||||||
|
|
||||||
|
async def _throttled_message(self, message):
|
||||||
|
"""Process a single request, respecting the concurrency limit."""
|
||||||
|
async with self._concurrency.semaphore:
|
||||||
|
try:
|
||||||
|
await self.handle_message(message)
|
||||||
|
except ProtocolError as e:
|
||||||
|
self.logger.error(f'{e}')
|
||||||
|
self._bump_errors()
|
||||||
|
except CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
self.logger.exception(f'exception handling {message}')
|
||||||
|
self._bump_errors()
|
||||||
|
|
||||||
|
# External API
|
||||||
|
def default_framer(self):
|
||||||
|
"""Return a bitcoin framer."""
|
||||||
|
return BitcoinFramer(bytes.fromhex('e3e1f3e8'), 128_000_000)
|
||||||
|
|
||||||
|
async def handle_message(self, message):
|
||||||
|
"""message is a (command, payload) pair."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def send_message(self, message):
|
||||||
|
"""Send a message (command, payload) over the network."""
|
||||||
|
await self._send_message(message)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchError(Exception):
|
||||||
|
|
||||||
|
def __init__(self, request):
|
||||||
|
self.request = request # BatchRequest object
|
||||||
|
|
||||||
|
|
||||||
|
class BatchRequest(object):
|
||||||
|
"""Used to build a batch request to send to the server. Stores
|
||||||
|
the
|
||||||
|
|
||||||
|
Attributes batch and results are initially None.
|
||||||
|
|
||||||
|
Adding an invalid request or notification immediately raises a
|
||||||
|
ProtocolError.
|
||||||
|
|
||||||
|
On exiting the with clause, it will:
|
||||||
|
|
||||||
|
1) create a Batch object for the requests in the order they were
|
||||||
|
added. If the batch is empty this raises a ProtocolError.
|
||||||
|
|
||||||
|
2) set the "batch" attribute to be that batch
|
||||||
|
|
||||||
|
3) send the batch request and wait for a response
|
||||||
|
|
||||||
|
4) raise a ProtocolError if the protocol was violated by the
|
||||||
|
server. Currently this only happens if it gave more than one
|
||||||
|
response to any request
|
||||||
|
|
||||||
|
5) otherwise there is precisely one response to each Request. Set
|
||||||
|
the "results" attribute to the tuple of results; the responses
|
||||||
|
are ordered to match the Requests in the batch. Notifications
|
||||||
|
do not get a response.
|
||||||
|
|
||||||
|
6) if raise_errors is True and any individual response was a JSON
|
||||||
|
RPC error response, or violated the protocol in some way, a
|
||||||
|
BatchError exception is raised. Otherwise the caller can be
|
||||||
|
certain each request returned a standard result.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, session, raise_errors):
|
||||||
|
self._session = session
|
||||||
|
self._raise_errors = raise_errors
|
||||||
|
self._requests = []
|
||||||
|
self.batch = None
|
||||||
|
self.results = None
|
||||||
|
|
||||||
|
def add_request(self, method, args=()):
|
||||||
|
self._requests.append(Request(method, args))
|
||||||
|
|
||||||
|
def add_notification(self, method, args=()):
|
||||||
|
self._requests.append(Notification(method, args))
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._requests)
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||||
|
if exc_type is None:
|
||||||
|
self.batch = Batch(self._requests)
|
||||||
|
message, event = self._session.connection.send_batch(self.batch)
|
||||||
|
await self._session._send_message(message)
|
||||||
|
await event.wait()
|
||||||
|
self.results = event.result
|
||||||
|
if self._raise_errors:
|
||||||
|
if any(isinstance(item, Exception) for item in event.result):
|
||||||
|
raise BatchError(self)
|
||||||
|
|
||||||
|
|
||||||
|
class RPCSession(SessionBase):
|
||||||
|
"""Base class for protocols where a message can lead to a response,
|
||||||
|
for example JSON RPC."""
|
||||||
|
|
||||||
|
def __init__(self, *, framer=None, loop=None, connection=None):
|
||||||
|
super().__init__(framer=framer, loop=loop)
|
||||||
|
self.connection = connection or self.default_connection()
|
||||||
|
|
||||||
|
async def _receive_messages(self):
|
||||||
|
while not self.is_closing():
|
||||||
|
try:
|
||||||
|
message = await self.framer.receive_message()
|
||||||
|
except MemoryError as e:
|
||||||
|
self.logger.warning(f'{e!r}')
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.last_recv = time.time()
|
||||||
|
self.recv_count += 1
|
||||||
|
if self.recv_count % 10 == 0:
|
||||||
|
await self._update_concurrency()
|
||||||
|
|
||||||
|
try:
|
||||||
|
requests = self.connection.receive_message(message)
|
||||||
|
except ProtocolError as e:
|
||||||
|
self.logger.debug(f'{e}')
|
||||||
|
if e.error_message:
|
||||||
|
await self._send_message(e.error_message)
|
||||||
|
if e.code == JSONRPC.PARSE_ERROR:
|
||||||
|
self.max_errors = 0
|
||||||
|
self._bump_errors()
|
||||||
|
else:
|
||||||
|
for request in requests:
|
||||||
|
await self._task_group.add(self._throttled_request(request))
|
||||||
|
|
||||||
|
async def _throttled_request(self, request):
|
||||||
|
"""Process a single request, respecting the concurrency limit."""
|
||||||
|
async with self._concurrency.semaphore:
|
||||||
|
try:
|
||||||
|
result = await self.handle_request(request)
|
||||||
|
except (ProtocolError, RPCError) as e:
|
||||||
|
result = e
|
||||||
|
except CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
self.logger.exception(f'exception handling {request}')
|
||||||
|
result = RPCError(JSONRPC.INTERNAL_ERROR,
|
||||||
|
'internal server error')
|
||||||
|
if isinstance(request, Request):
|
||||||
|
message = request.send_result(result)
|
||||||
|
if message:
|
||||||
|
await self._send_message(message)
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
self._bump_errors()
|
||||||
|
|
||||||
|
def connection_lost(self, exc):
|
||||||
|
# Cancel pending requests and message processing
|
||||||
|
self.connection.cancel_pending_requests()
|
||||||
|
super().connection_lost(exc)
|
||||||
|
|
||||||
|
# External API
|
||||||
|
def default_connection(self):
|
||||||
|
"""Return a default connection if the user provides none."""
|
||||||
|
return JSONRPCConnection(JSONRPCv2)
|
||||||
|
|
||||||
|
def default_framer(self):
|
||||||
|
"""Return a default framer."""
|
||||||
|
return NewlineFramer()
|
||||||
|
|
||||||
|
async def handle_request(self, request):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def send_request(self, method, args=()):
|
||||||
|
"""Send an RPC request over the network."""
|
||||||
|
message, event = self.connection.send_request(Request(method, args))
|
||||||
|
await self._send_message(message)
|
||||||
|
await event.wait()
|
||||||
|
result = event.result
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
raise result
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def send_notification(self, method, args=()):
|
||||||
|
"""Send an RPC notification over the network."""
|
||||||
|
message = self.connection.send_notification(Notification(method, args))
|
||||||
|
await self._send_message(message)
|
||||||
|
|
||||||
|
def send_batch(self, raise_errors=False):
|
||||||
|
"""Return a BatchRequest. Intended to be used like so:
|
||||||
|
|
||||||
|
async with session.send_batch() as batch:
|
||||||
|
batch.add_request("method1")
|
||||||
|
batch.add_request("sum", (x, y))
|
||||||
|
batch.add_notification("updated")
|
||||||
|
|
||||||
|
for result in batch.results:
|
||||||
|
...
|
||||||
|
|
||||||
|
Note that in some circumstances exceptions can be raised; see
|
||||||
|
BatchRequest doc string.
|
||||||
|
"""
|
||||||
|
return BatchRequest(self, raise_errors)
|
||||||
|
|
||||||
|
|
||||||
|
class Server(object):
|
||||||
|
"""A simple wrapper around an asyncio.Server object."""
|
||||||
|
|
||||||
|
def __init__(self, session_factory, host=None, port=None, *,
|
||||||
|
loop=None, **kwargs):
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.loop = loop or asyncio.get_event_loop()
|
||||||
|
self.server = None
|
||||||
|
self._session_factory = session_factory
|
||||||
|
self._kwargs = kwargs
|
||||||
|
|
||||||
|
async def listen(self):
|
||||||
|
self.server = await self.loop.create_server(
|
||||||
|
self._session_factory, self.host, self.port, **self._kwargs)
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""Close the listening socket. This does not close any ServerSession
|
||||||
|
objects created to handle incoming connections.
|
||||||
|
"""
|
||||||
|
if self.server:
|
||||||
|
self.server.close()
|
||||||
|
await self.server.wait_closed()
|
||||||
|
self.server = None
|
439
torba/torba/rpc/socks.py
Normal file
439
torba/torba/rpc/socks.py
Normal file
|
@ -0,0 +1,439 @@
|
||||||
|
# Copyright (c) 2018, Neil Booth
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# The MIT License (MIT)
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
# a copy of this software and associated documentation files (the
|
||||||
|
# "Software"), to deal in the Software without restriction, including
|
||||||
|
# without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
# permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
# the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be
|
||||||
|
# included in all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
|
||||||
|
"""SOCKS proxying."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import asyncio
|
||||||
|
import collections
|
||||||
|
import ipaddress
|
||||||
|
import socket
|
||||||
|
import struct
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ('SOCKSUserAuth', 'SOCKS4', 'SOCKS4a', 'SOCKS5', 'SOCKSProxy',
|
||||||
|
'SOCKSError', 'SOCKSProtocolError', 'SOCKSFailure')
|
||||||
|
|
||||||
|
|
||||||
|
SOCKSUserAuth = collections.namedtuple("SOCKSUserAuth", "username password")
|
||||||
|
|
||||||
|
|
||||||
|
class SOCKSError(Exception):
|
||||||
|
"""Base class for SOCKS exceptions. Each raised exception will be
|
||||||
|
an instance of a derived class."""
|
||||||
|
|
||||||
|
|
||||||
|
class SOCKSProtocolError(SOCKSError):
|
||||||
|
"""Raised when the proxy does not follow the SOCKS protocol"""
|
||||||
|
|
||||||
|
|
||||||
|
class SOCKSFailure(SOCKSError):
|
||||||
|
"""Raised when the proxy refuses or fails to make a connection"""
|
||||||
|
|
||||||
|
|
||||||
|
class NeedData(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SOCKSBase(object):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def name(cls):
|
||||||
|
return cls.__name__
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._buffer = bytes()
|
||||||
|
self._state = self._start
|
||||||
|
|
||||||
|
def _read(self, size):
|
||||||
|
if len(self._buffer) < size:
|
||||||
|
raise NeedData(size - len(self._buffer))
|
||||||
|
result = self._buffer[:size]
|
||||||
|
self._buffer = self._buffer[size:]
|
||||||
|
return result
|
||||||
|
|
||||||
|
def receive_data(self, data):
|
||||||
|
self._buffer += data
|
||||||
|
|
||||||
|
def next_message(self):
|
||||||
|
return self._state()
|
||||||
|
|
||||||
|
|
||||||
|
class SOCKS4(SOCKSBase):
|
||||||
|
"""SOCKS4 protocol wrapper."""
|
||||||
|
|
||||||
|
# See http://ftp.icm.edu.pl/packages/socks/socks4/SOCKS4.protocol
|
||||||
|
REPLY_CODES = {
|
||||||
|
90: 'request granted',
|
||||||
|
91: 'request rejected or failed',
|
||||||
|
92: ('request rejected because SOCKS server cannot connect '
|
||||||
|
'to identd on the client'),
|
||||||
|
93: ('request rejected because the client program and identd '
|
||||||
|
'report different user-ids')
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, dst_host, dst_port, auth):
|
||||||
|
super().__init__()
|
||||||
|
self._dst_host = self._check_host(dst_host)
|
||||||
|
self._dst_port = dst_port
|
||||||
|
self._auth = auth
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _check_host(cls, host):
|
||||||
|
if not isinstance(host, ipaddress.IPv4Address):
|
||||||
|
try:
|
||||||
|
host = ipaddress.IPv4Address(host)
|
||||||
|
except ValueError:
|
||||||
|
raise SOCKSProtocolError(
|
||||||
|
f'SOCKS4 requires an IPv4 address: {host}') from None
|
||||||
|
return host
|
||||||
|
|
||||||
|
def _start(self):
|
||||||
|
self._state = self._first_response
|
||||||
|
|
||||||
|
if isinstance(self._dst_host, ipaddress.IPv4Address):
|
||||||
|
# SOCKS4
|
||||||
|
dst_ip_packed = self._dst_host.packed
|
||||||
|
host_bytes = b''
|
||||||
|
else:
|
||||||
|
# SOCKS4a
|
||||||
|
dst_ip_packed = b'\0\0\0\1'
|
||||||
|
host_bytes = self._dst_host.encode() + b'\0'
|
||||||
|
|
||||||
|
if isinstance(self._auth, SOCKSUserAuth):
|
||||||
|
user_id = self._auth.username.encode()
|
||||||
|
else:
|
||||||
|
user_id = b''
|
||||||
|
|
||||||
|
# Send TCP/IP stream CONNECT request
|
||||||
|
return b''.join([b'\4\1', struct.pack('>H', self._dst_port),
|
||||||
|
dst_ip_packed, user_id, b'\0', host_bytes])
|
||||||
|
|
||||||
|
def _first_response(self):
|
||||||
|
# Wait for 8-byte response
|
||||||
|
data = self._read(8)
|
||||||
|
if data[0] != 0:
|
||||||
|
raise SOCKSProtocolError(f'invalid {self.name()} proxy '
|
||||||
|
f'response: {data}')
|
||||||
|
reply_code = data[1]
|
||||||
|
if reply_code != 90:
|
||||||
|
msg = self.REPLY_CODES.get(
|
||||||
|
reply_code, f'unknown {self.name()} reply code {reply_code}')
|
||||||
|
raise SOCKSFailure(f'{self.name()} proxy request failed: {msg}')
|
||||||
|
|
||||||
|
# Other fields ignored
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class SOCKS4a(SOCKS4):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _check_host(cls, host):
|
||||||
|
if not isinstance(host, (str, ipaddress.IPv4Address)):
|
||||||
|
raise SOCKSProtocolError(
|
||||||
|
f'SOCKS4a requires an IPv4 address or host name: {host}')
|
||||||
|
return host
|
||||||
|
|
||||||
|
|
||||||
|
class SOCKS5(SOCKSBase):
|
||||||
|
"""SOCKS protocol wrapper."""
|
||||||
|
|
||||||
|
# See https://tools.ietf.org/html/rfc1928
|
||||||
|
ERROR_CODES = {
|
||||||
|
1: 'general SOCKS server failure',
|
||||||
|
2: 'connection not allowed by ruleset',
|
||||||
|
3: 'network unreachable',
|
||||||
|
4: 'host unreachable',
|
||||||
|
5: 'connection refused',
|
||||||
|
6: 'TTL expired',
|
||||||
|
7: 'command not supported',
|
||||||
|
8: 'address type not supported',
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, dst_host, dst_port, auth):
|
||||||
|
super().__init__()
|
||||||
|
self._dst_bytes = self._destination_bytes(dst_host, dst_port)
|
||||||
|
self._auth_bytes, self._auth_methods = self._authentication(auth)
|
||||||
|
|
||||||
|
def _destination_bytes(self, host, port):
|
||||||
|
if isinstance(host, ipaddress.IPv4Address):
|
||||||
|
addr_bytes = b'\1' + host.packed
|
||||||
|
elif isinstance(host, ipaddress.IPv6Address):
|
||||||
|
addr_bytes = b'\4' + host.packed
|
||||||
|
elif isinstance(host, str):
|
||||||
|
host = host.encode()
|
||||||
|
if len(host) > 255:
|
||||||
|
raise SOCKSProtocolError(f'hostname too long: '
|
||||||
|
f'{len(host)} bytes')
|
||||||
|
addr_bytes = b'\3' + bytes([len(host)]) + host
|
||||||
|
else:
|
||||||
|
raise SOCKSProtocolError(f'SOCKS5 requires an IPv4 address, IPv6 '
|
||||||
|
f'address, or host name: {host}')
|
||||||
|
return addr_bytes + struct.pack('>H', port)
|
||||||
|
|
||||||
|
def _authentication(self, auth):
|
||||||
|
if isinstance(auth, SOCKSUserAuth):
|
||||||
|
user_bytes = auth.username.encode()
|
||||||
|
if not 0 < len(user_bytes) < 256:
|
||||||
|
raise SOCKSProtocolError(f'username {auth.username} has '
|
||||||
|
f'invalid length {len(user_bytes)}')
|
||||||
|
pwd_bytes = auth.password.encode()
|
||||||
|
if not 0 < len(pwd_bytes) < 256:
|
||||||
|
raise SOCKSProtocolError(f'password has invalid length '
|
||||||
|
f'{len(pwd_bytes)}')
|
||||||
|
return b''.join([bytes([1, len(user_bytes)]), user_bytes,
|
||||||
|
bytes([len(pwd_bytes)]), pwd_bytes]), [0, 2]
|
||||||
|
return b'', [0]
|
||||||
|
|
||||||
|
def _start(self):
|
||||||
|
self._state = self._first_response
|
||||||
|
return (b'\5' + bytes([len(self._auth_methods)])
|
||||||
|
+ bytes(m for m in self._auth_methods))
|
||||||
|
|
||||||
|
def _first_response(self):
|
||||||
|
# Wait for 2-byte response
|
||||||
|
data = self._read(2)
|
||||||
|
if data[0] != 5:
|
||||||
|
raise SOCKSProtocolError(f'invalid SOCKS5 proxy response: {data}')
|
||||||
|
if data[1] not in self._auth_methods:
|
||||||
|
raise SOCKSFailure('SOCKS5 proxy rejected authentication methods')
|
||||||
|
|
||||||
|
# Authenticate if user-password authentication
|
||||||
|
if data[1] == 2:
|
||||||
|
self._state = self._auth_response
|
||||||
|
return self._auth_bytes
|
||||||
|
return self._request_connection()
|
||||||
|
|
||||||
|
def _auth_response(self):
|
||||||
|
data = self._read(2)
|
||||||
|
if data[0] != 1:
|
||||||
|
raise SOCKSProtocolError(f'invalid SOCKS5 proxy auth '
|
||||||
|
f'response: {data}')
|
||||||
|
if data[1] != 0:
|
||||||
|
raise SOCKSFailure(f'SOCKS5 proxy auth failure code: '
|
||||||
|
f'{data[1]}')
|
||||||
|
|
||||||
|
return self._request_connection()
|
||||||
|
|
||||||
|
def _request_connection(self):
|
||||||
|
# Send connection request
|
||||||
|
self._state = self._connect_response
|
||||||
|
return b'\5\1\0' + self._dst_bytes
|
||||||
|
|
||||||
|
def _connect_response(self):
|
||||||
|
data = self._read(5)
|
||||||
|
if data[0] != 5 or data[2] != 0 or data[3] not in (1, 3, 4):
|
||||||
|
raise SOCKSProtocolError(f'invalid SOCKS5 proxy response: {data}')
|
||||||
|
if data[1] != 0:
|
||||||
|
raise SOCKSFailure(self.ERROR_CODES.get(
|
||||||
|
data[1], f'unknown SOCKS5 error code: {data[1]}'))
|
||||||
|
|
||||||
|
if data[3] == 1:
|
||||||
|
addr_len = 3 # IPv4
|
||||||
|
elif data[3] == 3:
|
||||||
|
addr_len = data[4] # Hostname
|
||||||
|
else:
|
||||||
|
addr_len = 15 # IPv6
|
||||||
|
|
||||||
|
self._state = partial(self._connect_response_rest, addr_len)
|
||||||
|
return self.next_message()
|
||||||
|
|
||||||
|
def _connect_response_rest(self, addr_len):
|
||||||
|
self._read(addr_len + 2)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class SOCKSProxy(object):
|
||||||
|
|
||||||
|
def __init__(self, address, protocol, auth):
|
||||||
|
"""A SOCKS proxy at an address following a SOCKS protocol. auth is an
|
||||||
|
authentication method to use when connecting, or None.
|
||||||
|
|
||||||
|
address is a (host, port) pair; for IPv6 it can instead be a
|
||||||
|
(host, port, flowinfo, scopeid) 4-tuple.
|
||||||
|
"""
|
||||||
|
self.address = address
|
||||||
|
self.protocol = protocol
|
||||||
|
self.auth = auth
|
||||||
|
# Set on each successful connection via the proxy to the
|
||||||
|
# result of socket.getpeername()
|
||||||
|
self.peername = None
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
auth = 'username' if self.auth else 'none'
|
||||||
|
return f'{self.protocol.name()} proxy at {self.address}, auth: {auth}'
|
||||||
|
|
||||||
|
async def _handshake(self, client, sock, loop):
|
||||||
|
while True:
|
||||||
|
count = 0
|
||||||
|
try:
|
||||||
|
message = client.next_message()
|
||||||
|
except NeedData as e:
|
||||||
|
count = e.args[0]
|
||||||
|
else:
|
||||||
|
if message is None:
|
||||||
|
return
|
||||||
|
await loop.sock_sendall(sock, message)
|
||||||
|
|
||||||
|
if count:
|
||||||
|
data = await loop.sock_recv(sock, count)
|
||||||
|
if not data:
|
||||||
|
raise SOCKSProtocolError("EOF received")
|
||||||
|
client.receive_data(data)
|
||||||
|
|
||||||
|
async def _connect_one(self, host, port):
|
||||||
|
"""Connect to the proxy and perform a handshake requesting a
|
||||||
|
connection to (host, port).
|
||||||
|
|
||||||
|
Return the open socket on success, or the exception on failure.
|
||||||
|
"""
|
||||||
|
client = self.protocol(host, port, self.auth)
|
||||||
|
sock = socket.socket()
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
try:
|
||||||
|
# A non-blocking socket is required by loop socket methods
|
||||||
|
sock.setblocking(False)
|
||||||
|
await loop.sock_connect(sock, self.address)
|
||||||
|
await self._handshake(client, sock, loop)
|
||||||
|
self.peername = sock.getpeername()
|
||||||
|
return sock
|
||||||
|
except Exception as e:
|
||||||
|
# Don't close - see https://github.com/kyuupichan/aiorpcX/issues/8
|
||||||
|
if sys.platform.startswith('linux') or sys.platform == "darwin":
|
||||||
|
sock.close()
|
||||||
|
return e
|
||||||
|
|
||||||
|
async def _connect(self, addresses):
|
||||||
|
"""Connect to the proxy and perform a handshake requesting a
|
||||||
|
connection to each address in addresses.
|
||||||
|
|
||||||
|
Return an (open_socket, address) pair on success.
|
||||||
|
"""
|
||||||
|
assert len(addresses) > 0
|
||||||
|
|
||||||
|
exceptions = []
|
||||||
|
for address in addresses:
|
||||||
|
host, port = address[:2]
|
||||||
|
sock = await self._connect_one(host, port)
|
||||||
|
if isinstance(sock, socket.socket):
|
||||||
|
return sock, address
|
||||||
|
exceptions.append(sock)
|
||||||
|
|
||||||
|
strings = set(f'{exc!r}' for exc in exceptions)
|
||||||
|
raise (exceptions[0] if len(strings) == 1 else
|
||||||
|
OSError(f'multiple exceptions: {", ".join(strings)}'))
|
||||||
|
|
||||||
|
async def _detect_proxy(self):
|
||||||
|
"""Return True if it appears we can connect to a SOCKS proxy,
|
||||||
|
otherwise False.
|
||||||
|
"""
|
||||||
|
if self.protocol is SOCKS4a:
|
||||||
|
host, port = 'www.apple.com', 80
|
||||||
|
else:
|
||||||
|
host, port = ipaddress.IPv4Address('8.8.8.8'), 53
|
||||||
|
|
||||||
|
sock = await self._connect_one(host, port)
|
||||||
|
if isinstance(sock, socket.socket):
|
||||||
|
sock.close()
|
||||||
|
return True
|
||||||
|
|
||||||
|
# SOCKSFailure indicates something failed, but that we are
|
||||||
|
# likely talking to a proxy
|
||||||
|
return isinstance(sock, SOCKSFailure)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def auto_detect_address(cls, address, auth):
|
||||||
|
"""Try to detect a SOCKS proxy at address using the authentication
|
||||||
|
method (or None). SOCKS5, SOCKS4a and SOCKS are tried in
|
||||||
|
order. If a SOCKS proxy is detected a SOCKSProxy object is
|
||||||
|
returned.
|
||||||
|
|
||||||
|
Returning a SOCKSProxy does not mean it is functioning - for
|
||||||
|
example, it may have no network connectivity.
|
||||||
|
|
||||||
|
If no proxy is detected return None.
|
||||||
|
"""
|
||||||
|
for protocol in (SOCKS5, SOCKS4a, SOCKS4):
|
||||||
|
proxy = cls(address, protocol, auth)
|
||||||
|
if await proxy._detect_proxy():
|
||||||
|
return proxy
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def auto_detect_host(cls, host, ports, auth):
|
||||||
|
"""Try to detect a SOCKS proxy on a host on one of the ports.
|
||||||
|
|
||||||
|
Calls auto_detect for the ports in order. Returns SOCKS are
|
||||||
|
tried in order; a SOCKSProxy object for the first detected
|
||||||
|
proxy is returned.
|
||||||
|
|
||||||
|
Returning a SOCKSProxy does not mean it is functioning - for
|
||||||
|
example, it may have no network connectivity.
|
||||||
|
|
||||||
|
If no proxy is detected return None.
|
||||||
|
"""
|
||||||
|
for port in ports:
|
||||||
|
address = (host, port)
|
||||||
|
proxy = await cls.auto_detect_address(address, auth)
|
||||||
|
if proxy:
|
||||||
|
return proxy
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def create_connection(self, protocol_factory, host, port, *,
|
||||||
|
resolve=False, ssl=None,
|
||||||
|
family=0, proto=0, flags=0):
|
||||||
|
"""Set up a connection to (host, port) through the proxy.
|
||||||
|
|
||||||
|
If resolve is True then host is resolved locally with
|
||||||
|
getaddrinfo using family, proto and flags, otherwise the proxy
|
||||||
|
is asked to resolve host.
|
||||||
|
|
||||||
|
The function signature is similar to loop.create_connection()
|
||||||
|
with the same result. The attribute _address is set on the
|
||||||
|
protocol to the address of the successful remote connection.
|
||||||
|
Additionally raises SOCKSError if something goes wrong with
|
||||||
|
the proxy handshake.
|
||||||
|
"""
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
if resolve:
|
||||||
|
infos = await loop.getaddrinfo(host, port, family=family,
|
||||||
|
type=socket.SOCK_STREAM,
|
||||||
|
proto=proto, flags=flags)
|
||||||
|
addresses = [info[4] for info in infos]
|
||||||
|
else:
|
||||||
|
addresses = [(host, port)]
|
||||||
|
|
||||||
|
sock, address = await self._connect(addresses)
|
||||||
|
|
||||||
|
def set_address():
|
||||||
|
protocol = protocol_factory()
|
||||||
|
protocol._address = address
|
||||||
|
return protocol
|
||||||
|
|
||||||
|
return await loop.create_connection(
|
||||||
|
set_address, sock=sock, ssl=ssl,
|
||||||
|
server_hostname=host if ssl else None)
|
95
torba/torba/rpc/util.py
Normal file
95
torba/torba/rpc/util.py
Normal file
|
@ -0,0 +1,95 @@
|
||||||
|
# Copyright (c) 2018, Neil Booth
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# The MIT License (MIT)
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
# a copy of this software and associated documentation files (the
|
||||||
|
# "Software"), to deal in the Software without restriction, including
|
||||||
|
# without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
# permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
# the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be
|
||||||
|
# included in all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
|
||||||
|
__all__ = ()
|
||||||
|
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections import namedtuple
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
# other_params: None means cannot be called with keyword arguments only
|
||||||
|
# any means any name is good
|
||||||
|
SignatureInfo = namedtuple('SignatureInfo', 'min_args max_args '
|
||||||
|
'required_names other_names')
|
||||||
|
|
||||||
|
|
||||||
|
def signature_info(func):
|
||||||
|
params = inspect.signature(func).parameters
|
||||||
|
min_args = max_args = 0
|
||||||
|
required_names = []
|
||||||
|
other_names = []
|
||||||
|
no_names = False
|
||||||
|
for p in params.values():
|
||||||
|
if p.kind == p.POSITIONAL_OR_KEYWORD:
|
||||||
|
max_args += 1
|
||||||
|
if p.default is p.empty:
|
||||||
|
min_args += 1
|
||||||
|
required_names.append(p.name)
|
||||||
|
else:
|
||||||
|
other_names.append(p.name)
|
||||||
|
elif p.kind == p.KEYWORD_ONLY:
|
||||||
|
other_names.append(p.name)
|
||||||
|
elif p.kind == p.VAR_POSITIONAL:
|
||||||
|
max_args = None
|
||||||
|
elif p.kind == p.VAR_KEYWORD:
|
||||||
|
other_names = any
|
||||||
|
elif p.kind == p.POSITIONAL_ONLY:
|
||||||
|
max_args += 1
|
||||||
|
if p.default is p.empty:
|
||||||
|
min_args += 1
|
||||||
|
no_names = True
|
||||||
|
|
||||||
|
if no_names:
|
||||||
|
other_names = None
|
||||||
|
|
||||||
|
return SignatureInfo(min_args, max_args, required_names, other_names)
|
||||||
|
|
||||||
|
|
||||||
|
class Concurrency(object):
|
||||||
|
|
||||||
|
def __init__(self, max_concurrent):
|
||||||
|
self._require_non_negative(max_concurrent)
|
||||||
|
self._max_concurrent = max_concurrent
|
||||||
|
self.semaphore = asyncio.Semaphore(max_concurrent)
|
||||||
|
|
||||||
|
def _require_non_negative(self, value):
|
||||||
|
if not isinstance(value, int) or value < 0:
|
||||||
|
raise RuntimeError('concurrency must be a natural number')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_concurrent(self):
|
||||||
|
return self._max_concurrent
|
||||||
|
|
||||||
|
async def set_max_concurrent(self, value):
|
||||||
|
self._require_non_negative(value)
|
||||||
|
diff = value - self._max_concurrent
|
||||||
|
self._max_concurrent = value
|
||||||
|
if diff >= 0:
|
||||||
|
for _ in range(diff):
|
||||||
|
self.semaphore.release()
|
||||||
|
else:
|
||||||
|
for _ in range(-diff):
|
||||||
|
await self.semaphore.acquire()
|
1
torba/torba/server/__init__.py
Normal file
1
torba/torba/server/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
from .server import Server
|
713
torba/torba/server/block_processor.py
Normal file
713
torba/torba/server/block_processor.py
Normal file
|
@ -0,0 +1,713 @@
|
||||||
|
# Copyright (c) 2016-2017, Neil Booth
|
||||||
|
# Copyright (c) 2017, the ElectrumX authors
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# See the file "LICENCE" for information about the copyright
|
||||||
|
# and warranty status of this software.
|
||||||
|
|
||||||
|
"""Block prefetcher and chain processor."""
|
||||||
|
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from struct import pack, unpack
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torba
|
||||||
|
from torba.server.daemon import DaemonError
|
||||||
|
from torba.server.hash import hash_to_hex_str, HASHX_LEN
|
||||||
|
from torba.server.util import chunks, class_logger
|
||||||
|
from torba.server.db import FlushData
|
||||||
|
|
||||||
|
|
||||||
|
class Prefetcher:
|
||||||
|
"""Prefetches blocks (in the forward direction only)."""
|
||||||
|
|
||||||
|
def __init__(self, daemon, coin, blocks_event):
|
||||||
|
self.logger = class_logger(__name__, self.__class__.__name__)
|
||||||
|
self.daemon = daemon
|
||||||
|
self.coin = coin
|
||||||
|
self.blocks_event = blocks_event
|
||||||
|
self.blocks = []
|
||||||
|
self.caught_up = False
|
||||||
|
# Access to fetched_height should be protected by the semaphore
|
||||||
|
self.fetched_height = None
|
||||||
|
self.semaphore = asyncio.Semaphore()
|
||||||
|
self.refill_event = asyncio.Event()
|
||||||
|
# The prefetched block cache size. The min cache size has
|
||||||
|
# little effect on sync time.
|
||||||
|
self.cache_size = 0
|
||||||
|
self.min_cache_size = 10 * 1024 * 1024
|
||||||
|
# This makes the first fetch be 10 blocks
|
||||||
|
self.ave_size = self.min_cache_size // 10
|
||||||
|
self.polling_delay = 5
|
||||||
|
|
||||||
|
async def main_loop(self, bp_height):
|
||||||
|
"""Loop forever polling for more blocks."""
|
||||||
|
await self.reset_height(bp_height)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# Sleep a while if there is nothing to prefetch
|
||||||
|
await self.refill_event.wait()
|
||||||
|
if not await self._prefetch_blocks():
|
||||||
|
await asyncio.sleep(self.polling_delay)
|
||||||
|
except DaemonError as e:
|
||||||
|
self.logger.info(f'ignoring daemon error: {e}')
|
||||||
|
|
||||||
|
def get_prefetched_blocks(self):
|
||||||
|
"""Called by block processor when it is processing queued blocks."""
|
||||||
|
blocks = self.blocks
|
||||||
|
self.blocks = []
|
||||||
|
self.cache_size = 0
|
||||||
|
self.refill_event.set()
|
||||||
|
return blocks
|
||||||
|
|
||||||
|
async def reset_height(self, height):
|
||||||
|
"""Reset to prefetch blocks from the block processor's height.
|
||||||
|
|
||||||
|
Used in blockchain reorganisations. This coroutine can be
|
||||||
|
called asynchronously to the _prefetch_blocks coroutine so we
|
||||||
|
must synchronize with a semaphore.
|
||||||
|
"""
|
||||||
|
async with self.semaphore:
|
||||||
|
self.blocks.clear()
|
||||||
|
self.cache_size = 0
|
||||||
|
self.fetched_height = height
|
||||||
|
self.refill_event.set()
|
||||||
|
|
||||||
|
daemon_height = await self.daemon.height()
|
||||||
|
behind = daemon_height - height
|
||||||
|
if behind > 0:
|
||||||
|
self.logger.info('catching up to daemon height {:,d} '
|
||||||
|
'({:,d} blocks behind)'
|
||||||
|
.format(daemon_height, behind))
|
||||||
|
else:
|
||||||
|
self.logger.info('caught up to daemon height {:,d}'
|
||||||
|
.format(daemon_height))
|
||||||
|
|
||||||
|
async def _prefetch_blocks(self):
|
||||||
|
"""Prefetch some blocks and put them on the queue.
|
||||||
|
|
||||||
|
Repeats until the queue is full or caught up.
|
||||||
|
"""
|
||||||
|
daemon = self.daemon
|
||||||
|
daemon_height = await daemon.height()
|
||||||
|
async with self.semaphore:
|
||||||
|
while self.cache_size < self.min_cache_size:
|
||||||
|
# Try and catch up all blocks but limit to room in cache.
|
||||||
|
# Constrain fetch count to between 0 and 500 regardless;
|
||||||
|
# testnet can be lumpy.
|
||||||
|
cache_room = self.min_cache_size // self.ave_size
|
||||||
|
count = min(daemon_height - self.fetched_height, cache_room)
|
||||||
|
count = min(500, max(count, 0))
|
||||||
|
if not count:
|
||||||
|
self.caught_up = True
|
||||||
|
return False
|
||||||
|
|
||||||
|
first = self.fetched_height + 1
|
||||||
|
hex_hashes = await daemon.block_hex_hashes(first, count)
|
||||||
|
if self.caught_up:
|
||||||
|
self.logger.info('new block height {:,d} hash {}'
|
||||||
|
.format(first + count-1, hex_hashes[-1]))
|
||||||
|
blocks = await daemon.raw_blocks(hex_hashes)
|
||||||
|
|
||||||
|
assert count == len(blocks)
|
||||||
|
|
||||||
|
# Special handling for genesis block
|
||||||
|
if first == 0:
|
||||||
|
blocks[0] = self.coin.genesis_block(blocks[0])
|
||||||
|
self.logger.info('verified genesis block with hash {}'
|
||||||
|
.format(hex_hashes[0]))
|
||||||
|
|
||||||
|
# Update our recent average block size estimate
|
||||||
|
size = sum(len(block) for block in blocks)
|
||||||
|
if count >= 10:
|
||||||
|
self.ave_size = size // count
|
||||||
|
else:
|
||||||
|
self.ave_size = (size + (10 - count) * self.ave_size) // 10
|
||||||
|
|
||||||
|
self.blocks.extend(blocks)
|
||||||
|
self.cache_size += size
|
||||||
|
self.fetched_height += count
|
||||||
|
self.blocks_event.set()
|
||||||
|
|
||||||
|
self.refill_event.clear()
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class ChainError(Exception):
|
||||||
|
"""Raised on error processing blocks."""
|
||||||
|
|
||||||
|
|
||||||
|
class BlockProcessor:
|
||||||
|
"""Process blocks and update the DB state to match.
|
||||||
|
|
||||||
|
Employ a prefetcher to prefetch blocks in batches for processing.
|
||||||
|
Coordinate backing up in case of chain reorganisations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, env, db, daemon, notifications):
|
||||||
|
self.env = env
|
||||||
|
self.db = db
|
||||||
|
self.daemon = daemon
|
||||||
|
self.notifications = notifications
|
||||||
|
|
||||||
|
self.coin = env.coin
|
||||||
|
self.blocks_event = asyncio.Event()
|
||||||
|
self.prefetcher = Prefetcher(daemon, env.coin, self.blocks_event)
|
||||||
|
self.logger = class_logger(__name__, self.__class__.__name__)
|
||||||
|
|
||||||
|
# Meta
|
||||||
|
self.next_cache_check = 0
|
||||||
|
self.touched = set()
|
||||||
|
self.reorg_count = 0
|
||||||
|
|
||||||
|
# Caches of unflushed items.
|
||||||
|
self.headers = []
|
||||||
|
self.tx_hashes = []
|
||||||
|
self.undo_infos = []
|
||||||
|
|
||||||
|
# UTXO cache
|
||||||
|
self.utxo_cache = {}
|
||||||
|
self.db_deletes = []
|
||||||
|
|
||||||
|
# If the lock is successfully acquired, in-memory chain state
|
||||||
|
# is consistent with self.height
|
||||||
|
self.state_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def run_in_thread_with_lock(self, func, *args):
|
||||||
|
# Run in a thread to prevent blocking. Shielded so that
|
||||||
|
# cancellations from shutdown don't lose work - when the task
|
||||||
|
# completes the data will be flushed and then we shut down.
|
||||||
|
# Take the state lock to be certain in-memory state is
|
||||||
|
# consistent and not being updated elsewhere.
|
||||||
|
async def run_in_thread_locked():
|
||||||
|
async with self.state_lock:
|
||||||
|
return await asyncio.get_event_loop().run_in_executor(None, func, *args)
|
||||||
|
return await asyncio.shield(run_in_thread_locked())
|
||||||
|
|
||||||
|
async def check_and_advance_blocks(self, raw_blocks):
|
||||||
|
"""Process the list of raw blocks passed. Detects and handles
|
||||||
|
reorgs.
|
||||||
|
"""
|
||||||
|
if not raw_blocks:
|
||||||
|
return
|
||||||
|
first = self.height + 1
|
||||||
|
blocks = [self.coin.block(raw_block, first + n)
|
||||||
|
for n, raw_block in enumerate(raw_blocks)]
|
||||||
|
headers = [block.header for block in blocks]
|
||||||
|
hprevs = [self.coin.header_prevhash(h) for h in headers]
|
||||||
|
chain = [self.tip] + [self.coin.header_hash(h) for h in headers[:-1]]
|
||||||
|
|
||||||
|
if hprevs == chain:
|
||||||
|
start = time.time()
|
||||||
|
await self.run_in_thread_with_lock(self.advance_blocks, blocks)
|
||||||
|
await self._maybe_flush()
|
||||||
|
if not self.db.first_sync:
|
||||||
|
s = '' if len(blocks) == 1 else 's'
|
||||||
|
self.logger.info('processed {:,d} block{} in {:.1f}s'
|
||||||
|
.format(len(blocks), s,
|
||||||
|
time.time() - start))
|
||||||
|
if self._caught_up_event.is_set():
|
||||||
|
await self.notifications.on_block(self.touched, self.height)
|
||||||
|
self.touched = set()
|
||||||
|
elif hprevs[0] != chain[0]:
|
||||||
|
await self.reorg_chain()
|
||||||
|
else:
|
||||||
|
# It is probably possible but extremely rare that what
|
||||||
|
# bitcoind returns doesn't form a chain because it
|
||||||
|
# reorg-ed the chain as it was processing the batched
|
||||||
|
# block hash requests. Should this happen it's simplest
|
||||||
|
# just to reset the prefetcher and try again.
|
||||||
|
self.logger.warning('daemon blocks do not form a chain; '
|
||||||
|
'resetting the prefetcher')
|
||||||
|
await self.prefetcher.reset_height(self.height)
|
||||||
|
|
||||||
|
async def reorg_chain(self, count=None):
|
||||||
|
"""Handle a chain reorganisation.
|
||||||
|
|
||||||
|
Count is the number of blocks to simulate a reorg, or None for
|
||||||
|
a real reorg."""
|
||||||
|
if count is None:
|
||||||
|
self.logger.info('chain reorg detected')
|
||||||
|
else:
|
||||||
|
self.logger.info(f'faking a reorg of {count:,d} blocks')
|
||||||
|
await self.flush(True)
|
||||||
|
|
||||||
|
async def get_raw_blocks(last_height, hex_hashes):
|
||||||
|
heights = range(last_height, last_height - len(hex_hashes), -1)
|
||||||
|
try:
|
||||||
|
blocks = [self.db.read_raw_block(height) for height in heights]
|
||||||
|
self.logger.info(f'read {len(blocks)} blocks from disk')
|
||||||
|
return blocks
|
||||||
|
except FileNotFoundError:
|
||||||
|
return await self.daemon.raw_blocks(hex_hashes)
|
||||||
|
|
||||||
|
def flush_backup():
|
||||||
|
# self.touched can include other addresses which is
|
||||||
|
# harmless, but remove None.
|
||||||
|
self.touched.discard(None)
|
||||||
|
self.db.flush_backup(self.flush_data(), self.touched)
|
||||||
|
|
||||||
|
start, last, hashes = await self.reorg_hashes(count)
|
||||||
|
# Reverse and convert to hex strings.
|
||||||
|
hashes = [hash_to_hex_str(hash) for hash in reversed(hashes)]
|
||||||
|
for hex_hashes in chunks(hashes, 50):
|
||||||
|
raw_blocks = await get_raw_blocks(last, hex_hashes)
|
||||||
|
await self.run_in_thread_with_lock(self.backup_blocks, raw_blocks)
|
||||||
|
await self.run_in_thread_with_lock(flush_backup)
|
||||||
|
last -= len(raw_blocks)
|
||||||
|
await self.prefetcher.reset_height(self.height)
|
||||||
|
|
||||||
|
async def reorg_hashes(self, count):
|
||||||
|
"""Return a pair (start, last, hashes) of blocks to back up during a
|
||||||
|
reorg.
|
||||||
|
|
||||||
|
The hashes are returned in order of increasing height. Start
|
||||||
|
is the height of the first hash, last of the last.
|
||||||
|
"""
|
||||||
|
start, count = await self.calc_reorg_range(count)
|
||||||
|
last = start + count - 1
|
||||||
|
s = '' if count == 1 else 's'
|
||||||
|
self.logger.info(f'chain was reorganised replacing {count:,d} '
|
||||||
|
f'block{s} at heights {start:,d}-{last:,d}')
|
||||||
|
|
||||||
|
return start, last, await self.db.fs_block_hashes(start, count)
|
||||||
|
|
||||||
|
async def calc_reorg_range(self, count):
|
||||||
|
"""Calculate the reorg range"""
|
||||||
|
|
||||||
|
def diff_pos(hashes1, hashes2):
|
||||||
|
"""Returns the index of the first difference in the hash lists.
|
||||||
|
If both lists match returns their length."""
|
||||||
|
for n, (hash1, hash2) in enumerate(zip(hashes1, hashes2)):
|
||||||
|
if hash1 != hash2:
|
||||||
|
return n
|
||||||
|
return len(hashes)
|
||||||
|
|
||||||
|
if count is None:
|
||||||
|
# A real reorg
|
||||||
|
start = self.height - 1
|
||||||
|
count = 1
|
||||||
|
while start > 0:
|
||||||
|
hashes = await self.db.fs_block_hashes(start, count)
|
||||||
|
hex_hashes = [hash_to_hex_str(hash) for hash in hashes]
|
||||||
|
d_hex_hashes = await self.daemon.block_hex_hashes(start, count)
|
||||||
|
n = diff_pos(hex_hashes, d_hex_hashes)
|
||||||
|
if n > 0:
|
||||||
|
start += n
|
||||||
|
break
|
||||||
|
count = min(count * 2, start)
|
||||||
|
start -= count
|
||||||
|
|
||||||
|
count = (self.height - start) + 1
|
||||||
|
else:
|
||||||
|
start = (self.height - count) + 1
|
||||||
|
|
||||||
|
return start, count
|
||||||
|
|
||||||
|
def estimate_txs_remaining(self):
|
||||||
|
# Try to estimate how many txs there are to go
|
||||||
|
daemon_height = self.daemon.cached_height()
|
||||||
|
coin = self.coin
|
||||||
|
tail_count = daemon_height - max(self.height, coin.TX_COUNT_HEIGHT)
|
||||||
|
# Damp the initial enthusiasm
|
||||||
|
realism = max(2.0 - 0.9 * self.height / coin.TX_COUNT_HEIGHT, 1.0)
|
||||||
|
return (tail_count * coin.TX_PER_BLOCK +
|
||||||
|
max(coin.TX_COUNT - self.tx_count, 0)) * realism
|
||||||
|
|
||||||
|
# - Flushing
|
||||||
|
def flush_data(self):
|
||||||
|
"""The data for a flush. The lock must be taken."""
|
||||||
|
assert self.state_lock.locked()
|
||||||
|
return FlushData(self.height, self.tx_count, self.headers,
|
||||||
|
self.tx_hashes, self.undo_infos, self.utxo_cache,
|
||||||
|
self.db_deletes, self.tip)
|
||||||
|
|
||||||
|
async def flush(self, flush_utxos):
|
||||||
|
def flush():
|
||||||
|
self.db.flush_dbs(self.flush_data(), flush_utxos,
|
||||||
|
self.estimate_txs_remaining)
|
||||||
|
await self.run_in_thread_with_lock(flush)
|
||||||
|
|
||||||
|
async def _maybe_flush(self):
|
||||||
|
# If caught up, flush everything as client queries are
|
||||||
|
# performed on the DB.
|
||||||
|
if self._caught_up_event.is_set():
|
||||||
|
await self.flush(True)
|
||||||
|
elif time.time() > self.next_cache_check:
|
||||||
|
flush_arg = self.check_cache_size()
|
||||||
|
if flush_arg is not None:
|
||||||
|
await self.flush(flush_arg)
|
||||||
|
self.next_cache_check = time.time() + 30
|
||||||
|
|
||||||
|
def check_cache_size(self):
|
||||||
|
"""Flush a cache if it gets too big."""
|
||||||
|
# Good average estimates based on traversal of subobjects and
|
||||||
|
# requesting size from Python (see deep_getsizeof).
|
||||||
|
one_MB = 1000*1000
|
||||||
|
utxo_cache_size = len(self.utxo_cache) * 205
|
||||||
|
db_deletes_size = len(self.db_deletes) * 57
|
||||||
|
hist_cache_size = self.db.history.unflushed_memsize()
|
||||||
|
# Roughly ntxs * 32 + nblocks * 42
|
||||||
|
tx_hash_size = ((self.tx_count - self.db.fs_tx_count) * 32
|
||||||
|
+ (self.height - self.db.fs_height) * 42)
|
||||||
|
utxo_MB = (db_deletes_size + utxo_cache_size) // one_MB
|
||||||
|
hist_MB = (hist_cache_size + tx_hash_size) // one_MB
|
||||||
|
|
||||||
|
self.logger.info('our height: {:,d} daemon: {:,d} '
|
||||||
|
'UTXOs {:,d}MB hist {:,d}MB'
|
||||||
|
.format(self.height, self.daemon.cached_height(),
|
||||||
|
utxo_MB, hist_MB))
|
||||||
|
|
||||||
|
# Flush history if it takes up over 20% of cache memory.
|
||||||
|
# Flush UTXOs once they take up 80% of cache memory.
|
||||||
|
cache_MB = self.env.cache_MB
|
||||||
|
if utxo_MB + hist_MB >= cache_MB or hist_MB >= cache_MB // 5:
|
||||||
|
return utxo_MB >= cache_MB * 4 // 5
|
||||||
|
return None
|
||||||
|
|
||||||
|
def advance_blocks(self, blocks):
|
||||||
|
"""Synchronously advance the blocks.
|
||||||
|
|
||||||
|
It is already verified they correctly connect onto our tip.
|
||||||
|
"""
|
||||||
|
min_height = self.db.min_undo_height(self.daemon.cached_height())
|
||||||
|
height = self.height
|
||||||
|
|
||||||
|
for block in blocks:
|
||||||
|
height += 1
|
||||||
|
undo_info = self.advance_txs(
|
||||||
|
height, block.transactions, self.coin.electrum_header(block.header, height)
|
||||||
|
)
|
||||||
|
if height >= min_height:
|
||||||
|
self.undo_infos.append((undo_info, height))
|
||||||
|
self.db.write_raw_block(block.raw, height)
|
||||||
|
|
||||||
|
headers = [block.header for block in blocks]
|
||||||
|
self.height = height
|
||||||
|
self.headers.extend(headers)
|
||||||
|
self.tip = self.coin.header_hash(headers[-1])
|
||||||
|
|
||||||
|
def advance_txs(self, height, txs, header):
|
||||||
|
self.tx_hashes.append(b''.join(tx_hash for tx, tx_hash in txs))
|
||||||
|
|
||||||
|
# Use local vars for speed in the loops
|
||||||
|
undo_info = []
|
||||||
|
tx_num = self.tx_count
|
||||||
|
script_hashX = self.coin.hashX_from_script
|
||||||
|
s_pack = pack
|
||||||
|
put_utxo = self.utxo_cache.__setitem__
|
||||||
|
spend_utxo = self.spend_utxo
|
||||||
|
undo_info_append = undo_info.append
|
||||||
|
update_touched = self.touched.update
|
||||||
|
hashXs_by_tx = []
|
||||||
|
append_hashXs = hashXs_by_tx.append
|
||||||
|
|
||||||
|
for tx, tx_hash in txs:
|
||||||
|
hashXs = []
|
||||||
|
append_hashX = hashXs.append
|
||||||
|
tx_numb = s_pack('<I', tx_num)
|
||||||
|
|
||||||
|
# Spend the inputs
|
||||||
|
for txin in tx.inputs:
|
||||||
|
if txin.is_generation():
|
||||||
|
continue
|
||||||
|
cache_value = spend_utxo(txin.prev_hash, txin.prev_idx)
|
||||||
|
undo_info_append(cache_value)
|
||||||
|
append_hashX(cache_value[:-12])
|
||||||
|
|
||||||
|
# Add the new UTXOs
|
||||||
|
for idx, txout in enumerate(tx.outputs):
|
||||||
|
# Get the hashX. Ignore unspendable outputs
|
||||||
|
hashX = script_hashX(txout.pk_script)
|
||||||
|
if hashX:
|
||||||
|
append_hashX(hashX)
|
||||||
|
put_utxo(tx_hash + s_pack('<H', idx),
|
||||||
|
hashX + tx_numb + s_pack('<Q', txout.value))
|
||||||
|
|
||||||
|
append_hashXs(hashXs)
|
||||||
|
update_touched(hashXs)
|
||||||
|
tx_num += 1
|
||||||
|
|
||||||
|
self.db.history.add_unflushed(hashXs_by_tx, self.tx_count)
|
||||||
|
|
||||||
|
self.tx_count = tx_num
|
||||||
|
self.db.tx_counts.append(tx_num)
|
||||||
|
|
||||||
|
return undo_info
|
||||||
|
|
||||||
|
def backup_blocks(self, raw_blocks):
|
||||||
|
"""Backup the raw blocks and flush.
|
||||||
|
|
||||||
|
The blocks should be in order of decreasing height, starting at.
|
||||||
|
self.height. A flush is performed once the blocks are backed up.
|
||||||
|
"""
|
||||||
|
self.db.assert_flushed(self.flush_data())
|
||||||
|
assert self.height >= len(raw_blocks)
|
||||||
|
|
||||||
|
coin = self.coin
|
||||||
|
for raw_block in raw_blocks:
|
||||||
|
# Check and update self.tip
|
||||||
|
block = coin.block(raw_block, self.height)
|
||||||
|
header_hash = coin.header_hash(block.header)
|
||||||
|
if header_hash != self.tip:
|
||||||
|
raise ChainError('backup block {} not tip {} at height {:,d}'
|
||||||
|
.format(hash_to_hex_str(header_hash),
|
||||||
|
hash_to_hex_str(self.tip),
|
||||||
|
self.height))
|
||||||
|
self.tip = coin.header_prevhash(block.header)
|
||||||
|
self.backup_txs(block.transactions)
|
||||||
|
self.height -= 1
|
||||||
|
self.db.tx_counts.pop()
|
||||||
|
|
||||||
|
self.logger.info('backed up to height {:,d}'.format(self.height))
|
||||||
|
|
||||||
|
def backup_txs(self, txs):
|
||||||
|
# Prevout values, in order down the block (coinbase first if present)
|
||||||
|
# undo_info is in reverse block order
|
||||||
|
undo_info = self.db.read_undo_info(self.height)
|
||||||
|
if undo_info is None:
|
||||||
|
raise ChainError('no undo information found for height {:,d}'
|
||||||
|
.format(self.height))
|
||||||
|
n = len(undo_info)
|
||||||
|
|
||||||
|
# Use local vars for speed in the loops
|
||||||
|
s_pack = pack
|
||||||
|
put_utxo = self.utxo_cache.__setitem__
|
||||||
|
spend_utxo = self.spend_utxo
|
||||||
|
script_hashX = self.coin.hashX_from_script
|
||||||
|
touched = self.touched
|
||||||
|
undo_entry_len = 12 + HASHX_LEN
|
||||||
|
|
||||||
|
for tx, tx_hash in reversed(txs):
|
||||||
|
for idx, txout in enumerate(tx.outputs):
|
||||||
|
# Spend the TX outputs. Be careful with unspendable
|
||||||
|
# outputs - we didn't save those in the first place.
|
||||||
|
hashX = script_hashX(txout.pk_script)
|
||||||
|
if hashX:
|
||||||
|
cache_value = spend_utxo(tx_hash, idx)
|
||||||
|
touched.add(cache_value[:-12])
|
||||||
|
|
||||||
|
# Restore the inputs
|
||||||
|
for txin in reversed(tx.inputs):
|
||||||
|
if txin.is_generation():
|
||||||
|
continue
|
||||||
|
n -= undo_entry_len
|
||||||
|
undo_item = undo_info[n:n + undo_entry_len]
|
||||||
|
put_utxo(txin.prev_hash + s_pack('<H', txin.prev_idx),
|
||||||
|
undo_item)
|
||||||
|
touched.add(undo_item[:-12])
|
||||||
|
|
||||||
|
assert n == 0
|
||||||
|
self.tx_count -= len(txs)
|
||||||
|
|
||||||
|
"""An in-memory UTXO cache, representing all changes to UTXO state
|
||||||
|
since the last DB flush.
|
||||||
|
|
||||||
|
We want to store millions of these in memory for optimal
|
||||||
|
performance during initial sync, because then it is possible to
|
||||||
|
spend UTXOs without ever going to the database (other than as an
|
||||||
|
entry in the address history, and there is only one such entry per
|
||||||
|
TX not per UTXO). So store them in a Python dictionary with
|
||||||
|
binary keys and values.
|
||||||
|
|
||||||
|
Key: TX_HASH + TX_IDX (32 + 2 = 34 bytes)
|
||||||
|
Value: HASHX + TX_NUM + VALUE (11 + 4 + 8 = 23 bytes)
|
||||||
|
|
||||||
|
That's 57 bytes of raw data in-memory. Python dictionary overhead
|
||||||
|
means each entry actually uses about 205 bytes of memory. So
|
||||||
|
almost 5 million UTXOs can fit in 1GB of RAM. There are
|
||||||
|
approximately 42 million UTXOs on bitcoin mainnet at height
|
||||||
|
433,000.
|
||||||
|
|
||||||
|
Semantics:
|
||||||
|
|
||||||
|
add: Add it to the cache dictionary.
|
||||||
|
|
||||||
|
spend: Remove it if in the cache dictionary. Otherwise it's
|
||||||
|
been flushed to the DB. Each UTXO is responsible for two
|
||||||
|
entries in the DB. Mark them for deletion in the next
|
||||||
|
cache flush.
|
||||||
|
|
||||||
|
The UTXO database format has to be able to do two things efficiently:
|
||||||
|
|
||||||
|
1. Given an address be able to list its UTXOs and their values
|
||||||
|
so its balance can be efficiently computed.
|
||||||
|
|
||||||
|
2. When processing transactions, for each prevout spent - a (tx_hash,
|
||||||
|
idx) pair - we have to be able to remove it from the DB. To send
|
||||||
|
notifications to clients we also need to know any address it paid
|
||||||
|
to.
|
||||||
|
|
||||||
|
To this end we maintain two "tables", one for each point above:
|
||||||
|
|
||||||
|
1. Key: b'u' + address_hashX + tx_idx + tx_num
|
||||||
|
Value: the UTXO value as a 64-bit unsigned integer
|
||||||
|
|
||||||
|
2. Key: b'h' + compressed_tx_hash + tx_idx + tx_num
|
||||||
|
Value: hashX
|
||||||
|
|
||||||
|
The compressed tx hash is just the first few bytes of the hash of
|
||||||
|
the tx in which the UTXO was created. As this is not unique there
|
||||||
|
will be potential collisions so tx_num is also in the key. When
|
||||||
|
looking up a UTXO the prefix space of the compressed hash needs to
|
||||||
|
be searched and resolved if necessary with the tx_num. The
|
||||||
|
collision rate is low (<0.1%).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def spend_utxo(self, tx_hash, tx_idx):
|
||||||
|
"""Spend a UTXO and return the 33-byte value.
|
||||||
|
|
||||||
|
If the UTXO is not in the cache it must be on disk. We store
|
||||||
|
all UTXOs so not finding one indicates a logic error or DB
|
||||||
|
corruption.
|
||||||
|
"""
|
||||||
|
# Fast track is it being in the cache
|
||||||
|
idx_packed = pack('<H', tx_idx)
|
||||||
|
cache_value = self.utxo_cache.pop(tx_hash + idx_packed, None)
|
||||||
|
if cache_value:
|
||||||
|
return cache_value
|
||||||
|
|
||||||
|
# Spend it from the DB.
|
||||||
|
|
||||||
|
# Key: b'h' + compressed_tx_hash + tx_idx + tx_num
|
||||||
|
# Value: hashX
|
||||||
|
prefix = b'h' + tx_hash[:4] + idx_packed
|
||||||
|
candidates = {db_key: hashX for db_key, hashX
|
||||||
|
in self.db.utxo_db.iterator(prefix=prefix)}
|
||||||
|
|
||||||
|
for hdb_key, hashX in candidates.items():
|
||||||
|
tx_num_packed = hdb_key[-4:]
|
||||||
|
|
||||||
|
if len(candidates) > 1:
|
||||||
|
tx_num, = unpack('<I', tx_num_packed)
|
||||||
|
hash, height = self.db.fs_tx_hash(tx_num)
|
||||||
|
if hash != tx_hash:
|
||||||
|
assert hash is not None # Should always be found
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Key: b'u' + address_hashX + tx_idx + tx_num
|
||||||
|
# Value: the UTXO value as a 64-bit unsigned integer
|
||||||
|
udb_key = b'u' + hashX + hdb_key[-6:]
|
||||||
|
utxo_value_packed = self.db.utxo_db.get(udb_key)
|
||||||
|
if utxo_value_packed:
|
||||||
|
# Remove both entries for this UTXO
|
||||||
|
self.db_deletes.append(hdb_key)
|
||||||
|
self.db_deletes.append(udb_key)
|
||||||
|
return hashX + tx_num_packed + utxo_value_packed
|
||||||
|
|
||||||
|
raise ChainError('UTXO {} / {:,d} not found in "h" table'
|
||||||
|
.format(hash_to_hex_str(tx_hash), tx_idx))
|
||||||
|
|
||||||
|
async def _process_prefetched_blocks(self):
|
||||||
|
"""Loop forever processing blocks as they arrive."""
|
||||||
|
while True:
|
||||||
|
if self.height == self.daemon.cached_height():
|
||||||
|
if not self._caught_up_event.is_set():
|
||||||
|
await self._first_caught_up()
|
||||||
|
self._caught_up_event.set()
|
||||||
|
await self.blocks_event.wait()
|
||||||
|
self.blocks_event.clear()
|
||||||
|
if self.reorg_count:
|
||||||
|
await self.reorg_chain(self.reorg_count)
|
||||||
|
self.reorg_count = 0
|
||||||
|
else:
|
||||||
|
blocks = self.prefetcher.get_prefetched_blocks()
|
||||||
|
await self.check_and_advance_blocks(blocks)
|
||||||
|
|
||||||
|
async def _first_caught_up(self):
|
||||||
|
self.logger.info(f'caught up to height {self.height}')
|
||||||
|
# Flush everything but with first_sync->False state.
|
||||||
|
first_sync = self.db.first_sync
|
||||||
|
self.db.first_sync = False
|
||||||
|
await self.flush(True)
|
||||||
|
if first_sync:
|
||||||
|
self.logger.info(f'{torba.__version__} synced to '
|
||||||
|
f'height {self.height:,d}')
|
||||||
|
# Reopen for serving
|
||||||
|
await self.db.open_for_serving()
|
||||||
|
|
||||||
|
async def _first_open_dbs(self):
|
||||||
|
await self.db.open_for_sync()
|
||||||
|
self.height = self.db.db_height
|
||||||
|
self.tip = self.db.db_tip
|
||||||
|
self.tx_count = self.db.db_tx_count
|
||||||
|
|
||||||
|
# --- External API
|
||||||
|
|
||||||
|
async def fetch_and_process_blocks(self, caught_up_event):
|
||||||
|
"""Fetch, process and index blocks from the daemon.
|
||||||
|
|
||||||
|
Sets caught_up_event when first caught up. Flushes to disk
|
||||||
|
and shuts down cleanly if cancelled.
|
||||||
|
|
||||||
|
This is mainly because if, during initial sync ElectrumX is
|
||||||
|
asked to shut down when a large number of blocks have been
|
||||||
|
processed but not written to disk, it should write those to
|
||||||
|
disk before exiting, as otherwise a significant amount of work
|
||||||
|
could be lost.
|
||||||
|
"""
|
||||||
|
self._caught_up_event = caught_up_event
|
||||||
|
await self._first_open_dbs()
|
||||||
|
try:
|
||||||
|
await asyncio.wait([
|
||||||
|
self.prefetcher.main_loop(self.height),
|
||||||
|
self._process_prefetched_blocks()
|
||||||
|
])
|
||||||
|
finally:
|
||||||
|
# Shut down block processing
|
||||||
|
self.logger.info('flushing to DB for a clean shutdown...')
|
||||||
|
await self.flush(True)
|
||||||
|
self.db.close()
|
||||||
|
|
||||||
|
def force_chain_reorg(self, count):
|
||||||
|
"""Force a reorg of the given number of blocks.
|
||||||
|
|
||||||
|
Returns True if a reorg is queued, false if not caught up.
|
||||||
|
"""
|
||||||
|
if self._caught_up_event.is_set():
|
||||||
|
self.reorg_count = count
|
||||||
|
self.blocks_event.set()
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class DecredBlockProcessor(BlockProcessor):
|
||||||
|
async def calc_reorg_range(self, count):
|
||||||
|
start, count = await super().calc_reorg_range(count)
|
||||||
|
if start > 0:
|
||||||
|
# A reorg in Decred can invalidate the previous block
|
||||||
|
start -= 1
|
||||||
|
count += 1
|
||||||
|
return start, count
|
||||||
|
|
||||||
|
|
||||||
|
class NamecoinBlockProcessor(BlockProcessor):
|
||||||
|
def advance_txs(self, txs):
|
||||||
|
result = super().advance_txs(txs)
|
||||||
|
|
||||||
|
tx_num = self.tx_count - len(txs)
|
||||||
|
script_name_hashX = self.coin.name_hashX_from_script
|
||||||
|
update_touched = self.touched.update
|
||||||
|
hashXs_by_tx = []
|
||||||
|
append_hashXs = hashXs_by_tx.append
|
||||||
|
|
||||||
|
for tx, tx_hash in txs:
|
||||||
|
hashXs = []
|
||||||
|
append_hashX = hashXs.append
|
||||||
|
|
||||||
|
# Add the new UTXOs and associate them with the name script
|
||||||
|
for idx, txout in enumerate(tx.outputs):
|
||||||
|
# Get the hashX of the name script. Ignore non-name scripts.
|
||||||
|
hashX = script_name_hashX(txout.pk_script)
|
||||||
|
if hashX:
|
||||||
|
append_hashX(hashX)
|
||||||
|
|
||||||
|
append_hashXs(hashXs)
|
||||||
|
update_touched(hashXs)
|
||||||
|
tx_num += 1
|
||||||
|
|
||||||
|
self.db.history.add_unflushed(hashXs_by_tx, self.tx_count - len(txs))
|
||||||
|
|
||||||
|
return result
|
40
torba/torba/server/cli.py
Normal file
40
torba/torba/server/cli.py
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
import argparse
|
||||||
|
import importlib
|
||||||
|
from torba.server.env import Env
|
||||||
|
from torba.server.server import Server
|
||||||
|
|
||||||
|
|
||||||
|
def get_argument_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog="torba-server"
|
||||||
|
)
|
||||||
|
parser.add_argument("spvserver", type=str, help="Python class path to SPV server implementation.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def get_coin_class(spvserver):
|
||||||
|
spvserver_path, coin_class_name = spvserver.rsplit('.', 1)
|
||||||
|
spvserver_module = importlib.import_module(spvserver_path)
|
||||||
|
return getattr(spvserver_module, coin_class_name)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = get_argument_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
coin_class = get_coin_class(args.spvserver)
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logging.info('torba.server starting')
|
||||||
|
try:
|
||||||
|
server = Server(Env(coin_class))
|
||||||
|
server.run()
|
||||||
|
except Exception:
|
||||||
|
traceback.print_exc()
|
||||||
|
logging.critical('torba.server terminated abnormally')
|
||||||
|
else:
|
||||||
|
logging.info('torba.server terminated normally')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
2292
torba/torba/server/coins.py
Normal file
2292
torba/torba/server/coins.py
Normal file
File diff suppressed because it is too large
Load diff
459
torba/torba/server/daemon.py
Normal file
459
torba/torba/server/daemon.py
Normal file
|
@ -0,0 +1,459 @@
|
||||||
|
# Copyright (c) 2016-2017, Neil Booth
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# See the file "LICENCE" for information about the copyright
|
||||||
|
# and warranty status of this software.
|
||||||
|
|
||||||
|
"""Class for handling asynchronous connections to a blockchain
|
||||||
|
daemon."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import itertools
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from calendar import timegm
|
||||||
|
from struct import pack
|
||||||
|
from time import strptime
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from torba.server.util import hex_to_bytes, class_logger,\
|
||||||
|
unpack_le_uint16_from, pack_varint
|
||||||
|
from torba.server.hash import hex_str_to_hash, hash_to_hex_str
|
||||||
|
from torba.server.tx import DeserializerDecred
|
||||||
|
from torba.rpc import JSONRPC
|
||||||
|
|
||||||
|
|
||||||
|
class DaemonError(Exception):
|
||||||
|
"""Raised when the daemon returns an error in its results."""
|
||||||
|
|
||||||
|
|
||||||
|
class WarmingUpError(Exception):
|
||||||
|
"""Internal - when the daemon is warming up."""
|
||||||
|
|
||||||
|
|
||||||
|
class WorkQueueFullError(Exception):
|
||||||
|
"""Internal - when the daemon's work queue is full."""
|
||||||
|
|
||||||
|
|
||||||
|
class Daemon:
|
||||||
|
"""Handles connections to a daemon at the given URL."""
|
||||||
|
|
||||||
|
WARMING_UP = -28
|
||||||
|
id_counter = itertools.count()
|
||||||
|
|
||||||
|
def __init__(self, coin, url, max_workqueue=10, init_retry=0.25,
|
||||||
|
max_retry=4.0):
|
||||||
|
self.coin = coin
|
||||||
|
self.logger = class_logger(__name__, self.__class__.__name__)
|
||||||
|
self.set_url(url)
|
||||||
|
# Limit concurrent RPC calls to this number.
|
||||||
|
# See DEFAULT_HTTP_WORKQUEUE in bitcoind, which is typically 16
|
||||||
|
self.workqueue_semaphore = asyncio.Semaphore(value=max_workqueue)
|
||||||
|
self.init_retry = init_retry
|
||||||
|
self.max_retry = max_retry
|
||||||
|
self._height = None
|
||||||
|
self.available_rpcs = {}
|
||||||
|
|
||||||
|
def set_url(self, url):
|
||||||
|
"""Set the URLS to the given list, and switch to the first one."""
|
||||||
|
urls = url.split(',')
|
||||||
|
urls = [self.coin.sanitize_url(url) for url in urls]
|
||||||
|
for n, url in enumerate(urls):
|
||||||
|
status = '' if n else ' (current)'
|
||||||
|
logged_url = self.logged_url(url)
|
||||||
|
self.logger.info(f'daemon #{n + 1} at {logged_url}{status}')
|
||||||
|
self.url_index = 0
|
||||||
|
self.urls = urls
|
||||||
|
|
||||||
|
def current_url(self):
|
||||||
|
"""Returns the current daemon URL."""
|
||||||
|
return self.urls[self.url_index]
|
||||||
|
|
||||||
|
def logged_url(self, url=None):
|
||||||
|
"""The host and port part, for logging."""
|
||||||
|
url = url or self.current_url()
|
||||||
|
return url[url.rindex('@') + 1:]
|
||||||
|
|
||||||
|
def failover(self):
|
||||||
|
"""Call to fail-over to the next daemon URL.
|
||||||
|
|
||||||
|
Returns False if there is only one, otherwise True.
|
||||||
|
"""
|
||||||
|
if len(self.urls) > 1:
|
||||||
|
self.url_index = (self.url_index + 1) % len(self.urls)
|
||||||
|
self.logger.info(f'failing over to {self.logged_url()}')
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def client_session(self):
|
||||||
|
"""An aiohttp client session."""
|
||||||
|
return aiohttp.ClientSession()
|
||||||
|
|
||||||
|
async def _send_data(self, data):
|
||||||
|
async with self.workqueue_semaphore:
|
||||||
|
async with self.client_session() as session:
|
||||||
|
async with session.post(self.current_url(), data=data) as resp:
|
||||||
|
kind = resp.headers.get('Content-Type', None)
|
||||||
|
if kind == 'application/json':
|
||||||
|
return await resp.json()
|
||||||
|
# bitcoind's HTTP protocol "handling" is a bad joke
|
||||||
|
text = await resp.text()
|
||||||
|
if 'Work queue depth exceeded' in text:
|
||||||
|
raise WorkQueueFullError
|
||||||
|
text = text.strip() or resp.reason
|
||||||
|
self.logger.error(text)
|
||||||
|
raise DaemonError(text)
|
||||||
|
|
||||||
|
async def _send(self, payload, processor):
|
||||||
|
"""Send a payload to be converted to JSON.
|
||||||
|
|
||||||
|
Handles temporary connection issues. Daemon reponse errors
|
||||||
|
are raise through DaemonError.
|
||||||
|
"""
|
||||||
|
def log_error(error):
|
||||||
|
nonlocal last_error_log, retry
|
||||||
|
now = time.time()
|
||||||
|
if now - last_error_log > 60:
|
||||||
|
last_error_log = now
|
||||||
|
self.logger.error(f'{error} Retrying occasionally...')
|
||||||
|
if retry == self.max_retry and self.failover():
|
||||||
|
retry = 0
|
||||||
|
|
||||||
|
on_good_message = None
|
||||||
|
last_error_log = 0
|
||||||
|
data = json.dumps(payload)
|
||||||
|
retry = self.init_retry
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
result = await self._send_data(data)
|
||||||
|
result = processor(result)
|
||||||
|
if on_good_message:
|
||||||
|
self.logger.info(on_good_message)
|
||||||
|
return result
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
log_error('timeout error.')
|
||||||
|
except aiohttp.ServerDisconnectedError:
|
||||||
|
log_error('disconnected.')
|
||||||
|
on_good_message = 'connection restored'
|
||||||
|
except aiohttp.ClientConnectionError:
|
||||||
|
log_error('connection problem - is your daemon running?')
|
||||||
|
on_good_message = 'connection restored'
|
||||||
|
except aiohttp.ClientError as e:
|
||||||
|
log_error(f'daemon error: {e}')
|
||||||
|
on_good_message = 'running normally'
|
||||||
|
except WarmingUpError:
|
||||||
|
log_error('starting up checking blocks.')
|
||||||
|
on_good_message = 'running normally'
|
||||||
|
except WorkQueueFullError:
|
||||||
|
log_error('work queue full.')
|
||||||
|
on_good_message = 'running normally'
|
||||||
|
|
||||||
|
await asyncio.sleep(retry)
|
||||||
|
retry = max(min(self.max_retry, retry * 2), self.init_retry)
|
||||||
|
|
||||||
|
async def _send_single(self, method, params=None):
|
||||||
|
"""Send a single request to the daemon."""
|
||||||
|
def processor(result):
|
||||||
|
err = result['error']
|
||||||
|
if not err:
|
||||||
|
return result['result']
|
||||||
|
if err.get('code') == self.WARMING_UP:
|
||||||
|
raise WarmingUpError
|
||||||
|
raise DaemonError(err)
|
||||||
|
|
||||||
|
payload = {'method': method, 'id': next(self.id_counter)}
|
||||||
|
if params:
|
||||||
|
payload['params'] = params
|
||||||
|
return await self._send(payload, processor)
|
||||||
|
|
||||||
|
async def _send_vector(self, method, params_iterable, replace_errs=False):
|
||||||
|
"""Send several requests of the same method.
|
||||||
|
|
||||||
|
The result will be an array of the same length as params_iterable.
|
||||||
|
If replace_errs is true, any item with an error is returned as None,
|
||||||
|
otherwise an exception is raised."""
|
||||||
|
def processor(result):
|
||||||
|
errs = [item['error'] for item in result if item['error']]
|
||||||
|
if any(err.get('code') == self.WARMING_UP for err in errs):
|
||||||
|
raise WarmingUpError
|
||||||
|
if not errs or replace_errs:
|
||||||
|
return [item['result'] for item in result]
|
||||||
|
raise DaemonError(errs)
|
||||||
|
|
||||||
|
payload = [{'method': method, 'params': p, 'id': next(self.id_counter)}
|
||||||
|
for p in params_iterable]
|
||||||
|
if payload:
|
||||||
|
return await self._send(payload, processor)
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _is_rpc_available(self, method):
|
||||||
|
"""Return whether given RPC method is available in the daemon.
|
||||||
|
|
||||||
|
Results are cached and the daemon will generally not be queried with
|
||||||
|
the same method more than once."""
|
||||||
|
available = self.available_rpcs.get(method)
|
||||||
|
if available is None:
|
||||||
|
available = True
|
||||||
|
try:
|
||||||
|
await self._send_single(method)
|
||||||
|
except DaemonError as e:
|
||||||
|
err = e.args[0]
|
||||||
|
error_code = err.get("code")
|
||||||
|
available = error_code != JSONRPC.METHOD_NOT_FOUND
|
||||||
|
self.available_rpcs[method] = available
|
||||||
|
return available
|
||||||
|
|
||||||
|
async def block_hex_hashes(self, first, count):
|
||||||
|
"""Return the hex hashes of count block starting at height first."""
|
||||||
|
params_iterable = ((h, ) for h in range(first, first + count))
|
||||||
|
return await self._send_vector('getblockhash', params_iterable)
|
||||||
|
|
||||||
|
async def deserialised_block(self, hex_hash):
|
||||||
|
"""Return the deserialised block with the given hex hash."""
|
||||||
|
return await self._send_single('getblock', (hex_hash, True))
|
||||||
|
|
||||||
|
async def raw_blocks(self, hex_hashes):
|
||||||
|
"""Return the raw binary blocks with the given hex hashes."""
|
||||||
|
params_iterable = ((h, False) for h in hex_hashes)
|
||||||
|
blocks = await self._send_vector('getblock', params_iterable)
|
||||||
|
# Convert hex string to bytes
|
||||||
|
return [hex_to_bytes(block) for block in blocks]
|
||||||
|
|
||||||
|
async def mempool_hashes(self):
|
||||||
|
"""Update our record of the daemon's mempool hashes."""
|
||||||
|
return await self._send_single('getrawmempool')
|
||||||
|
|
||||||
|
async def estimatefee(self, block_count):
|
||||||
|
"""Return the fee estimate for the block count. Units are whole
|
||||||
|
currency units per KB, e.g. 0.00000995, or -1 if no estimate
|
||||||
|
is available.
|
||||||
|
"""
|
||||||
|
args = (block_count, )
|
||||||
|
if await self._is_rpc_available('estimatesmartfee'):
|
||||||
|
estimate = await self._send_single('estimatesmartfee', args)
|
||||||
|
return estimate.get('feerate', -1)
|
||||||
|
return await self._send_single('estimatefee', args)
|
||||||
|
|
||||||
|
async def getnetworkinfo(self):
|
||||||
|
"""Return the result of the 'getnetworkinfo' RPC call."""
|
||||||
|
return await self._send_single('getnetworkinfo')
|
||||||
|
|
||||||
|
async def relayfee(self):
|
||||||
|
"""The minimum fee a low-priority tx must pay in order to be accepted
|
||||||
|
to the daemon's memory pool."""
|
||||||
|
network_info = await self.getnetworkinfo()
|
||||||
|
return network_info['relayfee']
|
||||||
|
|
||||||
|
async def getrawtransaction(self, hex_hash, verbose=False):
|
||||||
|
"""Return the serialized raw transaction with the given hash."""
|
||||||
|
# Cast to int because some coin daemons are old and require it
|
||||||
|
return await self._send_single('getrawtransaction',
|
||||||
|
(hex_hash, int(verbose)))
|
||||||
|
|
||||||
|
async def getrawtransactions(self, hex_hashes, replace_errs=True):
|
||||||
|
"""Return the serialized raw transactions with the given hashes.
|
||||||
|
|
||||||
|
Replaces errors with None by default."""
|
||||||
|
params_iterable = ((hex_hash, 0) for hex_hash in hex_hashes)
|
||||||
|
txs = await self._send_vector('getrawtransaction', params_iterable,
|
||||||
|
replace_errs=replace_errs)
|
||||||
|
# Convert hex strings to bytes
|
||||||
|
return [hex_to_bytes(tx) if tx else None for tx in txs]
|
||||||
|
|
||||||
|
async def broadcast_transaction(self, raw_tx):
|
||||||
|
"""Broadcast a transaction to the network."""
|
||||||
|
return await self._send_single('sendrawtransaction', (raw_tx, ))
|
||||||
|
|
||||||
|
async def height(self):
|
||||||
|
"""Query the daemon for its current height."""
|
||||||
|
self._height = await self._send_single('getblockcount')
|
||||||
|
return self._height
|
||||||
|
|
||||||
|
def cached_height(self):
|
||||||
|
"""Return the cached daemon height.
|
||||||
|
|
||||||
|
If the daemon has not been queried yet this returns None."""
|
||||||
|
return self._height
|
||||||
|
|
||||||
|
|
||||||
|
class DashDaemon(Daemon):
|
||||||
|
|
||||||
|
async def masternode_broadcast(self, params):
|
||||||
|
"""Broadcast a transaction to the network."""
|
||||||
|
return await self._send_single('masternodebroadcast', params)
|
||||||
|
|
||||||
|
async def masternode_list(self, params):
|
||||||
|
"""Return the masternode status."""
|
||||||
|
return await self._send_single('masternodelist', params)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeEstimateFeeDaemon(Daemon):
|
||||||
|
"""Daemon that simulates estimatefee and relayfee RPC calls. Coin that
|
||||||
|
wants to use this daemon must define ESTIMATE_FEE & RELAY_FEE"""
|
||||||
|
|
||||||
|
async def estimatefee(self, block_count):
|
||||||
|
"""Return the fee estimate for the given parameters."""
|
||||||
|
return self.coin.ESTIMATE_FEE
|
||||||
|
|
||||||
|
async def relayfee(self):
|
||||||
|
"""The minimum fee a low-priority tx must pay in order to be accepted
|
||||||
|
to the daemon's memory pool."""
|
||||||
|
return self.coin.RELAY_FEE
|
||||||
|
|
||||||
|
|
||||||
|
class LegacyRPCDaemon(Daemon):
|
||||||
|
"""Handles connections to a daemon at the given URL.
|
||||||
|
|
||||||
|
This class is useful for daemons that don't have the new 'getblock'
|
||||||
|
RPC call that returns the block in hex, the workaround is to manually
|
||||||
|
recreate the block bytes. The recreated block bytes may not be the exact
|
||||||
|
as in the underlying blockchain but it is good enough for our indexing
|
||||||
|
purposes."""
|
||||||
|
|
||||||
|
async def raw_blocks(self, hex_hashes):
|
||||||
|
"""Return the raw binary blocks with the given hex hashes."""
|
||||||
|
params_iterable = ((h, ) for h in hex_hashes)
|
||||||
|
block_info = await self._send_vector('getblock', params_iterable)
|
||||||
|
|
||||||
|
blocks = []
|
||||||
|
for i in block_info:
|
||||||
|
raw_block = await self.make_raw_block(i)
|
||||||
|
blocks.append(raw_block)
|
||||||
|
|
||||||
|
# Convert hex string to bytes
|
||||||
|
return blocks
|
||||||
|
|
||||||
|
async def make_raw_header(self, b):
|
||||||
|
pbh = b.get('previousblockhash')
|
||||||
|
if pbh is None:
|
||||||
|
pbh = '0' * 64
|
||||||
|
return b''.join([
|
||||||
|
pack('<L', b.get('version')),
|
||||||
|
hex_str_to_hash(pbh),
|
||||||
|
hex_str_to_hash(b.get('merkleroot')),
|
||||||
|
pack('<L', self.timestamp_safe(b['time'])),
|
||||||
|
pack('<L', int(b.get('bits'), 16)),
|
||||||
|
pack('<L', int(b.get('nonce')))
|
||||||
|
])
|
||||||
|
|
||||||
|
async def make_raw_block(self, b):
|
||||||
|
"""Construct a raw block"""
|
||||||
|
|
||||||
|
header = await self.make_raw_header(b)
|
||||||
|
|
||||||
|
transactions = []
|
||||||
|
if b.get('height') > 0:
|
||||||
|
transactions = await self.getrawtransactions(b.get('tx'), False)
|
||||||
|
|
||||||
|
raw_block = header
|
||||||
|
num_txs = len(transactions)
|
||||||
|
if num_txs > 0:
|
||||||
|
raw_block += pack_varint(num_txs)
|
||||||
|
raw_block += b''.join(transactions)
|
||||||
|
else:
|
||||||
|
raw_block += b'\x00'
|
||||||
|
|
||||||
|
return raw_block
|
||||||
|
|
||||||
|
def timestamp_safe(self, t):
|
||||||
|
if isinstance(t, int):
|
||||||
|
return t
|
||||||
|
return timegm(strptime(t, "%Y-%m-%d %H:%M:%S %Z"))
|
||||||
|
|
||||||
|
|
||||||
|
class DecredDaemon(Daemon):
|
||||||
|
async def raw_blocks(self, hex_hashes):
|
||||||
|
"""Return the raw binary blocks with the given hex hashes."""
|
||||||
|
|
||||||
|
params_iterable = ((h, False) for h in hex_hashes)
|
||||||
|
blocks = await self._send_vector('getblock', params_iterable)
|
||||||
|
|
||||||
|
raw_blocks = []
|
||||||
|
valid_tx_tree = {}
|
||||||
|
for block in blocks:
|
||||||
|
# Convert to bytes from hex
|
||||||
|
raw_block = hex_to_bytes(block)
|
||||||
|
raw_blocks.append(raw_block)
|
||||||
|
# Check if previous block is valid
|
||||||
|
prev = self.prev_hex_hash(raw_block)
|
||||||
|
votebits = unpack_le_uint16_from(raw_block[100:102])[0]
|
||||||
|
valid_tx_tree[prev] = self.is_valid_tx_tree(votebits)
|
||||||
|
|
||||||
|
processed_raw_blocks = []
|
||||||
|
for hash, raw_block in zip(hex_hashes, raw_blocks):
|
||||||
|
if hash in valid_tx_tree:
|
||||||
|
is_valid = valid_tx_tree[hash]
|
||||||
|
else:
|
||||||
|
# Do something complicated to figure out if this block is valid
|
||||||
|
header = await self._send_single('getblockheader', (hash, ))
|
||||||
|
if 'nextblockhash' not in header:
|
||||||
|
raise DaemonError(f'Could not find next block for {hash}')
|
||||||
|
next_hash = header['nextblockhash']
|
||||||
|
next_header = await self._send_single('getblockheader',
|
||||||
|
(next_hash, ))
|
||||||
|
is_valid = self.is_valid_tx_tree(next_header['votebits'])
|
||||||
|
|
||||||
|
if is_valid:
|
||||||
|
processed_raw_blocks.append(raw_block)
|
||||||
|
else:
|
||||||
|
# If this block is invalid remove the normal transactions
|
||||||
|
self.logger.info(f'block {hash} is invalidated')
|
||||||
|
processed_raw_blocks.append(self.strip_tx_tree(raw_block))
|
||||||
|
|
||||||
|
return processed_raw_blocks
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def prev_hex_hash(raw_block):
|
||||||
|
return hash_to_hex_str(raw_block[4:36])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_valid_tx_tree(votebits):
|
||||||
|
# Check if previous block was invalidated.
|
||||||
|
return bool(votebits & (1 << 0) != 0)
|
||||||
|
|
||||||
|
def strip_tx_tree(self, raw_block):
|
||||||
|
c = self.coin
|
||||||
|
assert issubclass(c.DESERIALIZER, DeserializerDecred)
|
||||||
|
d = c.DESERIALIZER(raw_block, start=c.BASIC_HEADER_SIZE)
|
||||||
|
d.read_tx_tree() # Skip normal transactions
|
||||||
|
# Create a fake block without any normal transactions
|
||||||
|
return raw_block[:c.BASIC_HEADER_SIZE] + b'\x00' + raw_block[d.cursor:]
|
||||||
|
|
||||||
|
async def height(self):
|
||||||
|
height = await super().height()
|
||||||
|
if height > 0:
|
||||||
|
# Lie about the daemon height as the current tip can be invalidated
|
||||||
|
height -= 1
|
||||||
|
self._height = height
|
||||||
|
return height
|
||||||
|
|
||||||
|
async def mempool_hashes(self):
|
||||||
|
mempool = await super().mempool_hashes()
|
||||||
|
# Add current tip transactions to the 'fake' mempool.
|
||||||
|
real_height = await self._send_single('getblockcount')
|
||||||
|
tip_hash = await self._send_single('getblockhash', (real_height,))
|
||||||
|
tip = await self.deserialised_block(tip_hash)
|
||||||
|
# Add normal transactions except coinbase
|
||||||
|
mempool += tip['tx'][1:]
|
||||||
|
# Add stake transactions if applicable
|
||||||
|
mempool += tip.get('stx', [])
|
||||||
|
return mempool
|
||||||
|
|
||||||
|
def client_session(self):
|
||||||
|
# FIXME allow self signed certificates
|
||||||
|
connector = aiohttp.TCPConnector(verify_ssl=False)
|
||||||
|
return aiohttp.ClientSession(connector=connector)
|
||||||
|
|
||||||
|
|
||||||
|
class PreLegacyRPCDaemon(LegacyRPCDaemon):
|
||||||
|
"""Handles connections to a daemon at the given URL.
|
||||||
|
|
||||||
|
This class is useful for daemons that don't have the new 'getblock'
|
||||||
|
RPC call that returns the block in hex, and need the False parameter
|
||||||
|
for the getblock"""
|
||||||
|
|
||||||
|
async def deserialised_block(self, hex_hash):
|
||||||
|
"""Return the deserialised block with the given hex hash."""
|
||||||
|
return await self._send_single('getblock', (hex_hash, False))
|
670
torba/torba/server/db.py
Normal file
670
torba/torba/server/db.py
Normal file
|
@ -0,0 +1,670 @@
|
||||||
|
# Copyright (c) 2016, Neil Booth
|
||||||
|
# Copyright (c) 2017, the ElectrumX authors
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# See the file "LICENCE" for information about the copyright
|
||||||
|
# and warranty status of this software.
|
||||||
|
|
||||||
|
"""Interface to the blockchain database."""
|
||||||
|
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import array
|
||||||
|
import ast
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from asyncio import sleep
|
||||||
|
from bisect import bisect_right
|
||||||
|
from collections import namedtuple
|
||||||
|
from glob import glob
|
||||||
|
from struct import pack, unpack
|
||||||
|
|
||||||
|
import attr
|
||||||
|
|
||||||
|
from torba.server import util
|
||||||
|
from torba.server.hash import hash_to_hex_str, HASHX_LEN
|
||||||
|
from torba.server.merkle import Merkle, MerkleCache
|
||||||
|
from torba.server.util import formatted_time
|
||||||
|
from torba.server.storage import db_class
|
||||||
|
from torba.server.history import History
|
||||||
|
|
||||||
|
|
||||||
|
UTXO = namedtuple("UTXO", "tx_num tx_pos tx_hash height value")
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True)
|
||||||
|
class FlushData:
|
||||||
|
height = attr.ib()
|
||||||
|
tx_count = attr.ib()
|
||||||
|
headers = attr.ib()
|
||||||
|
block_tx_hashes = attr.ib()
|
||||||
|
# The following are flushed to the UTXO DB if undo_infos is not None
|
||||||
|
undo_infos = attr.ib()
|
||||||
|
adds = attr.ib()
|
||||||
|
deletes = attr.ib()
|
||||||
|
tip = attr.ib()
|
||||||
|
|
||||||
|
|
||||||
|
class DB:
|
||||||
|
"""Simple wrapper of the backend database for querying.
|
||||||
|
|
||||||
|
Performs no DB update, though the DB will be cleaned on opening if
|
||||||
|
it was shutdown uncleanly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
DB_VERSIONS = [6]
|
||||||
|
|
||||||
|
class DBError(Exception):
|
||||||
|
"""Raised on general DB errors generally indicating corruption."""
|
||||||
|
|
||||||
|
def __init__(self, env):
|
||||||
|
self.logger = util.class_logger(__name__, self.__class__.__name__)
|
||||||
|
self.env = env
|
||||||
|
self.coin = env.coin
|
||||||
|
|
||||||
|
# Setup block header size handlers
|
||||||
|
if self.coin.STATIC_BLOCK_HEADERS:
|
||||||
|
self.header_offset = self.coin.static_header_offset
|
||||||
|
self.header_len = self.coin.static_header_len
|
||||||
|
else:
|
||||||
|
self.header_offset = self.dynamic_header_offset
|
||||||
|
self.header_len = self.dynamic_header_len
|
||||||
|
|
||||||
|
self.logger.info(f'switching current directory to {env.db_dir}')
|
||||||
|
os.chdir(env.db_dir)
|
||||||
|
|
||||||
|
self.db_class = db_class(self.env.db_engine)
|
||||||
|
self.history = History()
|
||||||
|
self.utxo_db = None
|
||||||
|
self.tx_counts = None
|
||||||
|
self.last_flush = time.time()
|
||||||
|
|
||||||
|
self.logger.info(f'using {self.env.db_engine} for DB backend')
|
||||||
|
|
||||||
|
# Header merkle cache
|
||||||
|
self.merkle = Merkle()
|
||||||
|
self.header_mc = MerkleCache(self.merkle, self.fs_block_hashes)
|
||||||
|
|
||||||
|
self.headers_file = util.LogicalFile('meta/headers', 2, 16000000)
|
||||||
|
self.tx_counts_file = util.LogicalFile('meta/txcounts', 2, 2000000)
|
||||||
|
self.hashes_file = util.LogicalFile('meta/hashes', 4, 16000000)
|
||||||
|
if not self.coin.STATIC_BLOCK_HEADERS:
|
||||||
|
self.headers_offsets_file = util.LogicalFile(
|
||||||
|
'meta/headers_offsets', 2, 16000000)
|
||||||
|
|
||||||
|
async def _read_tx_counts(self):
|
||||||
|
if self.tx_counts is not None:
|
||||||
|
return
|
||||||
|
# tx_counts[N] has the cumulative number of txs at the end of
|
||||||
|
# height N. So tx_counts[0] is 1 - the genesis coinbase
|
||||||
|
size = (self.db_height + 1) * 4
|
||||||
|
tx_counts = self.tx_counts_file.read(0, size)
|
||||||
|
assert len(tx_counts) == size
|
||||||
|
self.tx_counts = array.array('I', tx_counts)
|
||||||
|
if self.tx_counts:
|
||||||
|
assert self.db_tx_count == self.tx_counts[-1]
|
||||||
|
else:
|
||||||
|
assert self.db_tx_count == 0
|
||||||
|
|
||||||
|
async def _open_dbs(self, for_sync, compacting):
|
||||||
|
assert self.utxo_db is None
|
||||||
|
|
||||||
|
# First UTXO DB
|
||||||
|
self.utxo_db = self.db_class('utxo', for_sync)
|
||||||
|
if self.utxo_db.is_new:
|
||||||
|
self.logger.info('created new database')
|
||||||
|
self.logger.info('creating metadata directory')
|
||||||
|
os.mkdir('meta')
|
||||||
|
with util.open_file('COIN', create=True) as f:
|
||||||
|
f.write(f'ElectrumX databases and metadata for '
|
||||||
|
f'{self.coin.NAME} {self.coin.NET}'.encode())
|
||||||
|
if not self.coin.STATIC_BLOCK_HEADERS:
|
||||||
|
self.headers_offsets_file.write(0, bytes(8))
|
||||||
|
else:
|
||||||
|
self.logger.info(f'opened UTXO DB (for sync: {for_sync})')
|
||||||
|
self.read_utxo_state()
|
||||||
|
|
||||||
|
# Then history DB
|
||||||
|
self.utxo_flush_count = self.history.open_db(self.db_class, for_sync,
|
||||||
|
self.utxo_flush_count,
|
||||||
|
compacting)
|
||||||
|
self.clear_excess_undo_info()
|
||||||
|
|
||||||
|
# Read TX counts (requires meta directory)
|
||||||
|
await self._read_tx_counts()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.utxo_db.close()
|
||||||
|
self.history.close_db()
|
||||||
|
|
||||||
|
async def open_for_compacting(self):
|
||||||
|
await self._open_dbs(True, True)
|
||||||
|
|
||||||
|
async def open_for_sync(self):
|
||||||
|
"""Open the databases to sync to the daemon.
|
||||||
|
|
||||||
|
When syncing we want to reserve a lot of open files for the
|
||||||
|
synchronization. When serving clients we want the open files for
|
||||||
|
serving network connections.
|
||||||
|
"""
|
||||||
|
await self._open_dbs(True, False)
|
||||||
|
|
||||||
|
async def open_for_serving(self):
|
||||||
|
"""Open the databases for serving. If they are already open they are
|
||||||
|
closed first.
|
||||||
|
"""
|
||||||
|
if self.utxo_db:
|
||||||
|
self.logger.info('closing DBs to re-open for serving')
|
||||||
|
self.utxo_db.close()
|
||||||
|
self.history.close_db()
|
||||||
|
self.utxo_db = None
|
||||||
|
await self._open_dbs(False, False)
|
||||||
|
|
||||||
|
# Header merkle cache
|
||||||
|
|
||||||
|
async def populate_header_merkle_cache(self):
|
||||||
|
self.logger.info('populating header merkle cache...')
|
||||||
|
length = max(1, self.db_height - self.env.reorg_limit)
|
||||||
|
start = time.time()
|
||||||
|
await self.header_mc.initialize(length)
|
||||||
|
elapsed = time.time() - start
|
||||||
|
self.logger.info(f'header merkle cache populated in {elapsed:.1f}s')
|
||||||
|
|
||||||
|
async def header_branch_and_root(self, length, height):
|
||||||
|
return await self.header_mc.branch_and_root(length, height)
|
||||||
|
|
||||||
|
# Flushing
|
||||||
|
def assert_flushed(self, flush_data):
|
||||||
|
"""Asserts state is fully flushed."""
|
||||||
|
assert flush_data.tx_count == self.fs_tx_count == self.db_tx_count
|
||||||
|
assert flush_data.height == self.fs_height == self.db_height
|
||||||
|
assert flush_data.tip == self.db_tip
|
||||||
|
assert not flush_data.headers
|
||||||
|
assert not flush_data.block_tx_hashes
|
||||||
|
assert not flush_data.adds
|
||||||
|
assert not flush_data.deletes
|
||||||
|
assert not flush_data.undo_infos
|
||||||
|
self.history.assert_flushed()
|
||||||
|
|
||||||
|
def flush_dbs(self, flush_data, flush_utxos, estimate_txs_remaining):
|
||||||
|
"""Flush out cached state. History is always flushed; UTXOs are
|
||||||
|
flushed if flush_utxos."""
|
||||||
|
if flush_data.height == self.db_height:
|
||||||
|
self.assert_flushed(flush_data)
|
||||||
|
return
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
prior_flush = self.last_flush
|
||||||
|
tx_delta = flush_data.tx_count - self.last_flush_tx_count
|
||||||
|
|
||||||
|
# Flush to file system
|
||||||
|
self.flush_fs(flush_data)
|
||||||
|
|
||||||
|
# Then history
|
||||||
|
self.flush_history()
|
||||||
|
|
||||||
|
# Flush state last as it reads the wall time.
|
||||||
|
with self.utxo_db.write_batch() as batch:
|
||||||
|
if flush_utxos:
|
||||||
|
self.flush_utxo_db(batch, flush_data)
|
||||||
|
self.flush_state(batch)
|
||||||
|
|
||||||
|
# Update and put the wall time again - otherwise we drop the
|
||||||
|
# time it took to commit the batch
|
||||||
|
self.flush_state(self.utxo_db)
|
||||||
|
|
||||||
|
elapsed = self.last_flush - start_time
|
||||||
|
self.logger.info(f'flush #{self.history.flush_count:,d} took '
|
||||||
|
f'{elapsed:.1f}s. Height {flush_data.height:,d} '
|
||||||
|
f'txs: {flush_data.tx_count:,d} ({tx_delta:+,d})')
|
||||||
|
|
||||||
|
# Catch-up stats
|
||||||
|
if self.utxo_db.for_sync:
|
||||||
|
flush_interval = self.last_flush - prior_flush
|
||||||
|
tx_per_sec_gen = int(flush_data.tx_count / self.wall_time)
|
||||||
|
tx_per_sec_last = 1 + int(tx_delta / flush_interval)
|
||||||
|
eta = estimate_txs_remaining() / tx_per_sec_last
|
||||||
|
self.logger.info(f'tx/sec since genesis: {tx_per_sec_gen:,d}, '
|
||||||
|
f'since last flush: {tx_per_sec_last:,d}')
|
||||||
|
self.logger.info(f'sync time: {formatted_time(self.wall_time)} '
|
||||||
|
f'ETA: {formatted_time(eta)}')
|
||||||
|
|
||||||
|
def flush_fs(self, flush_data):
|
||||||
|
"""Write headers, tx counts and block tx hashes to the filesystem.
|
||||||
|
|
||||||
|
The first height to write is self.fs_height + 1. The FS
|
||||||
|
metadata is all append-only, so in a crash we just pick up
|
||||||
|
again from the height stored in the DB.
|
||||||
|
"""
|
||||||
|
prior_tx_count = (self.tx_counts[self.fs_height]
|
||||||
|
if self.fs_height >= 0 else 0)
|
||||||
|
assert len(flush_data.block_tx_hashes) == len(flush_data.headers)
|
||||||
|
assert flush_data.height == self.fs_height + len(flush_data.headers)
|
||||||
|
assert flush_data.tx_count == (self.tx_counts[-1] if self.tx_counts
|
||||||
|
else 0)
|
||||||
|
assert len(self.tx_counts) == flush_data.height + 1
|
||||||
|
hashes = b''.join(flush_data.block_tx_hashes)
|
||||||
|
flush_data.block_tx_hashes.clear()
|
||||||
|
assert len(hashes) % 32 == 0
|
||||||
|
assert len(hashes) // 32 == flush_data.tx_count - prior_tx_count
|
||||||
|
|
||||||
|
# Write the headers, tx counts, and tx hashes
|
||||||
|
start_time = time.time()
|
||||||
|
height_start = self.fs_height + 1
|
||||||
|
offset = self.header_offset(height_start)
|
||||||
|
self.headers_file.write(offset, b''.join(flush_data.headers))
|
||||||
|
self.fs_update_header_offsets(offset, height_start, flush_data.headers)
|
||||||
|
flush_data.headers.clear()
|
||||||
|
|
||||||
|
offset = height_start * self.tx_counts.itemsize
|
||||||
|
self.tx_counts_file.write(offset,
|
||||||
|
self.tx_counts[height_start:].tobytes())
|
||||||
|
offset = prior_tx_count * 32
|
||||||
|
self.hashes_file.write(offset, hashes)
|
||||||
|
|
||||||
|
self.fs_height = flush_data.height
|
||||||
|
self.fs_tx_count = flush_data.tx_count
|
||||||
|
|
||||||
|
if self.utxo_db.for_sync:
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
self.logger.info(f'flushed filesystem data in {elapsed:.2f}s')
|
||||||
|
|
||||||
|
def flush_history(self):
|
||||||
|
self.history.flush()
|
||||||
|
|
||||||
|
def flush_utxo_db(self, batch, flush_data):
|
||||||
|
"""Flush the cached DB writes and UTXO set to the batch."""
|
||||||
|
# Care is needed because the writes generated by flushing the
|
||||||
|
# UTXO state may have keys in common with our write cache or
|
||||||
|
# may be in the DB already.
|
||||||
|
start_time = time.time()
|
||||||
|
add_count = len(flush_data.adds)
|
||||||
|
spend_count = len(flush_data.deletes) // 2
|
||||||
|
|
||||||
|
# Spends
|
||||||
|
batch_delete = batch.delete
|
||||||
|
for key in sorted(flush_data.deletes):
|
||||||
|
batch_delete(key)
|
||||||
|
flush_data.deletes.clear()
|
||||||
|
|
||||||
|
# New UTXOs
|
||||||
|
batch_put = batch.put
|
||||||
|
for key, value in flush_data.adds.items():
|
||||||
|
# suffix = tx_idx + tx_num
|
||||||
|
hashX = value[:-12]
|
||||||
|
suffix = key[-2:] + value[-12:-8]
|
||||||
|
batch_put(b'h' + key[:4] + suffix, hashX)
|
||||||
|
batch_put(b'u' + hashX + suffix, value[-8:])
|
||||||
|
flush_data.adds.clear()
|
||||||
|
|
||||||
|
# New undo information
|
||||||
|
self.flush_undo_infos(batch_put, flush_data.undo_infos)
|
||||||
|
flush_data.undo_infos.clear()
|
||||||
|
|
||||||
|
if self.utxo_db.for_sync:
|
||||||
|
block_count = flush_data.height - self.db_height
|
||||||
|
tx_count = flush_data.tx_count - self.db_tx_count
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
self.logger.info(f'flushed {block_count:,d} blocks with '
|
||||||
|
f'{tx_count:,d} txs, {add_count:,d} UTXO adds, '
|
||||||
|
f'{spend_count:,d} spends in '
|
||||||
|
f'{elapsed:.1f}s, committing...')
|
||||||
|
|
||||||
|
self.utxo_flush_count = self.history.flush_count
|
||||||
|
self.db_height = flush_data.height
|
||||||
|
self.db_tx_count = flush_data.tx_count
|
||||||
|
self.db_tip = flush_data.tip
|
||||||
|
|
||||||
|
def flush_state(self, batch):
|
||||||
|
"""Flush chain state to the batch."""
|
||||||
|
now = time.time()
|
||||||
|
self.wall_time += now - self.last_flush
|
||||||
|
self.last_flush = now
|
||||||
|
self.last_flush_tx_count = self.fs_tx_count
|
||||||
|
self.write_utxo_state(batch)
|
||||||
|
|
||||||
|
def flush_backup(self, flush_data, touched):
|
||||||
|
"""Like flush_dbs() but when backing up. All UTXOs are flushed."""
|
||||||
|
assert not flush_data.headers
|
||||||
|
assert not flush_data.block_tx_hashes
|
||||||
|
assert flush_data.height < self.db_height
|
||||||
|
self.history.assert_flushed()
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
tx_delta = flush_data.tx_count - self.last_flush_tx_count
|
||||||
|
|
||||||
|
self.backup_fs(flush_data.height, flush_data.tx_count)
|
||||||
|
self.history.backup(touched, flush_data.tx_count)
|
||||||
|
with self.utxo_db.write_batch() as batch:
|
||||||
|
self.flush_utxo_db(batch, flush_data)
|
||||||
|
# Flush state last as it reads the wall time.
|
||||||
|
self.flush_state(batch)
|
||||||
|
|
||||||
|
elapsed = self.last_flush - start_time
|
||||||
|
self.logger.info(f'backup flush #{self.history.flush_count:,d} took '
|
||||||
|
f'{elapsed:.1f}s. Height {flush_data.height:,d} '
|
||||||
|
f'txs: {flush_data.tx_count:,d} ({tx_delta:+,d})')
|
||||||
|
|
||||||
|
def fs_update_header_offsets(self, offset_start, height_start, headers):
|
||||||
|
if self.coin.STATIC_BLOCK_HEADERS:
|
||||||
|
return
|
||||||
|
offset = offset_start
|
||||||
|
offsets = []
|
||||||
|
for h in headers:
|
||||||
|
offset += len(h)
|
||||||
|
offsets.append(pack("<Q", offset))
|
||||||
|
# For each header we get the offset of the next header, hence we
|
||||||
|
# start writing from the next height
|
||||||
|
pos = (height_start + 1) * 8
|
||||||
|
self.headers_offsets_file.write(pos, b''.join(offsets))
|
||||||
|
|
||||||
|
def dynamic_header_offset(self, height):
|
||||||
|
assert not self.coin.STATIC_BLOCK_HEADERS
|
||||||
|
offset, = unpack('<Q', self.headers_offsets_file.read(height * 8, 8))
|
||||||
|
return offset
|
||||||
|
|
||||||
|
def dynamic_header_len(self, height):
|
||||||
|
return self.dynamic_header_offset(height + 1)\
|
||||||
|
- self.dynamic_header_offset(height)
|
||||||
|
|
||||||
|
def backup_fs(self, height, tx_count):
|
||||||
|
"""Back up during a reorg. This just updates our pointers."""
|
||||||
|
self.fs_height = height
|
||||||
|
self.fs_tx_count = tx_count
|
||||||
|
# Truncate header_mc: header count is 1 more than the height.
|
||||||
|
self.header_mc.truncate(height + 1)
|
||||||
|
|
||||||
|
async def raw_header(self, height):
|
||||||
|
"""Return the binary header at the given height."""
|
||||||
|
header, n = await self.read_headers(height, 1)
|
||||||
|
if n != 1:
|
||||||
|
raise IndexError(f'height {height:,d} out of range')
|
||||||
|
return header
|
||||||
|
|
||||||
|
async def read_headers(self, start_height, count):
|
||||||
|
"""Requires start_height >= 0, count >= 0. Reads as many headers as
|
||||||
|
are available starting at start_height up to count. This
|
||||||
|
would be zero if start_height is beyond self.db_height, for
|
||||||
|
example.
|
||||||
|
|
||||||
|
Returns a (binary, n) pair where binary is the concatenated
|
||||||
|
binary headers, and n is the count of headers returned.
|
||||||
|
"""
|
||||||
|
if start_height < 0 or count < 0:
|
||||||
|
raise self.DBError(f'{count:,d} headers starting at '
|
||||||
|
f'{start_height:,d} not on disk')
|
||||||
|
|
||||||
|
def read_headers():
|
||||||
|
# Read some from disk
|
||||||
|
disk_count = max(0, min(count, self.db_height + 1 - start_height))
|
||||||
|
if disk_count:
|
||||||
|
offset = self.header_offset(start_height)
|
||||||
|
size = self.header_offset(start_height + disk_count) - offset
|
||||||
|
return self.headers_file.read(offset, size), disk_count
|
||||||
|
return b'', 0
|
||||||
|
|
||||||
|
return await asyncio.get_event_loop().run_in_executor(None, read_headers)
|
||||||
|
|
||||||
|
def fs_tx_hash(self, tx_num):
|
||||||
|
"""Return a par (tx_hash, tx_height) for the given tx number.
|
||||||
|
|
||||||
|
If the tx_height is not on disk, returns (None, tx_height)."""
|
||||||
|
tx_height = bisect_right(self.tx_counts, tx_num)
|
||||||
|
if tx_height > self.db_height:
|
||||||
|
tx_hash = None
|
||||||
|
else:
|
||||||
|
tx_hash = self.hashes_file.read(tx_num * 32, 32)
|
||||||
|
return tx_hash, tx_height
|
||||||
|
|
||||||
|
async def fs_block_hashes(self, height, count):
|
||||||
|
headers_concat, headers_count = await self.read_headers(height, count)
|
||||||
|
if headers_count != count:
|
||||||
|
raise self.DBError('only got {:,d} headers starting at {:,d}, not '
|
||||||
|
'{:,d}'.format(headers_count, height, count))
|
||||||
|
offset = 0
|
||||||
|
headers = []
|
||||||
|
for n in range(count):
|
||||||
|
hlen = self.header_len(height + n)
|
||||||
|
headers.append(headers_concat[offset:offset + hlen])
|
||||||
|
offset += hlen
|
||||||
|
|
||||||
|
return [self.coin.header_hash(header) for header in headers]
|
||||||
|
|
||||||
|
async def limited_history(self, hashX, *, limit=1000):
|
||||||
|
"""Return an unpruned, sorted list of (tx_hash, height) tuples of
|
||||||
|
confirmed transactions that touched the address, earliest in
|
||||||
|
the blockchain first. Includes both spending and receiving
|
||||||
|
transactions. By default returns at most 1000 entries. Set
|
||||||
|
limit to None to get them all.
|
||||||
|
"""
|
||||||
|
def read_history():
|
||||||
|
tx_nums = list(self.history.get_txnums(hashX, limit))
|
||||||
|
fs_tx_hash = self.fs_tx_hash
|
||||||
|
return [fs_tx_hash(tx_num) for tx_num in tx_nums]
|
||||||
|
|
||||||
|
while True:
|
||||||
|
history = await asyncio.get_event_loop().run_in_executor(None, read_history)
|
||||||
|
if all(hash is not None for hash, height in history):
|
||||||
|
return history
|
||||||
|
self.logger.warning(f'limited_history: tx hash '
|
||||||
|
f'not found (reorg?), retrying...')
|
||||||
|
await sleep(0.25)
|
||||||
|
|
||||||
|
# -- Undo information
|
||||||
|
|
||||||
|
def min_undo_height(self, max_height):
|
||||||
|
"""Returns a height from which we should store undo info."""
|
||||||
|
return max_height - self.env.reorg_limit + 1
|
||||||
|
|
||||||
|
def undo_key(self, height):
|
||||||
|
"""DB key for undo information at the given height."""
|
||||||
|
return b'U' + pack('>I', height)
|
||||||
|
|
||||||
|
def read_undo_info(self, height):
|
||||||
|
"""Read undo information from a file for the current height."""
|
||||||
|
return self.utxo_db.get(self.undo_key(height))
|
||||||
|
|
||||||
|
def flush_undo_infos(self, batch_put, undo_infos):
|
||||||
|
"""undo_infos is a list of (undo_info, height) pairs."""
|
||||||
|
for undo_info, height in undo_infos:
|
||||||
|
batch_put(self.undo_key(height), b''.join(undo_info))
|
||||||
|
|
||||||
|
def raw_block_prefix(self):
|
||||||
|
return 'meta/block'
|
||||||
|
|
||||||
|
def raw_block_path(self, height):
|
||||||
|
return f'{self.raw_block_prefix()}{height:d}'
|
||||||
|
|
||||||
|
def read_raw_block(self, height):
|
||||||
|
"""Returns a raw block read from disk. Raises FileNotFoundError
|
||||||
|
if the block isn't on-disk."""
|
||||||
|
with util.open_file(self.raw_block_path(height)) as f:
|
||||||
|
return f.read(-1)
|
||||||
|
|
||||||
|
def write_raw_block(self, block, height):
|
||||||
|
"""Write a raw block to disk."""
|
||||||
|
with util.open_truncate(self.raw_block_path(height)) as f:
|
||||||
|
f.write(block)
|
||||||
|
# Delete old blocks to prevent them accumulating
|
||||||
|
try:
|
||||||
|
del_height = self.min_undo_height(height) - 1
|
||||||
|
os.remove(self.raw_block_path(del_height))
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def clear_excess_undo_info(self):
|
||||||
|
"""Clear excess undo info. Only most recent N are kept."""
|
||||||
|
prefix = b'U'
|
||||||
|
min_height = self.min_undo_height(self.db_height)
|
||||||
|
keys = []
|
||||||
|
for key, hist in self.utxo_db.iterator(prefix=prefix):
|
||||||
|
height, = unpack('>I', key[-4:])
|
||||||
|
if height >= min_height:
|
||||||
|
break
|
||||||
|
keys.append(key)
|
||||||
|
|
||||||
|
if keys:
|
||||||
|
with self.utxo_db.write_batch() as batch:
|
||||||
|
for key in keys:
|
||||||
|
batch.delete(key)
|
||||||
|
self.logger.info(f'deleted {len(keys):,d} stale undo entries')
|
||||||
|
|
||||||
|
# delete old block files
|
||||||
|
prefix = self.raw_block_prefix()
|
||||||
|
paths = [path for path in glob(f'{prefix}[0-9]*')
|
||||||
|
if len(path) > len(prefix)
|
||||||
|
and int(path[len(prefix):]) < min_height]
|
||||||
|
if paths:
|
||||||
|
for path in paths:
|
||||||
|
try:
|
||||||
|
os.remove(path)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
self.logger.info(f'deleted {len(paths):,d} stale block files')
|
||||||
|
|
||||||
|
# -- UTXO database
|
||||||
|
|
||||||
|
def read_utxo_state(self):
|
||||||
|
state = self.utxo_db.get(b'state')
|
||||||
|
if not state:
|
||||||
|
self.db_height = -1
|
||||||
|
self.db_tx_count = 0
|
||||||
|
self.db_tip = b'\0' * 32
|
||||||
|
self.db_version = max(self.DB_VERSIONS)
|
||||||
|
self.utxo_flush_count = 0
|
||||||
|
self.wall_time = 0
|
||||||
|
self.first_sync = True
|
||||||
|
else:
|
||||||
|
state = ast.literal_eval(state.decode())
|
||||||
|
if not isinstance(state, dict):
|
||||||
|
raise self.DBError('failed reading state from DB')
|
||||||
|
self.db_version = state['db_version']
|
||||||
|
if self.db_version not in self.DB_VERSIONS:
|
||||||
|
raise self.DBError('your UTXO DB version is {} but this '
|
||||||
|
'software only handles versions {}'
|
||||||
|
.format(self.db_version, self.DB_VERSIONS))
|
||||||
|
# backwards compat
|
||||||
|
genesis_hash = state['genesis']
|
||||||
|
if isinstance(genesis_hash, bytes):
|
||||||
|
genesis_hash = genesis_hash.decode()
|
||||||
|
if genesis_hash != self.coin.GENESIS_HASH:
|
||||||
|
raise self.DBError('DB genesis hash {} does not match coin {}'
|
||||||
|
.format(genesis_hash,
|
||||||
|
self.coin.GENESIS_HASH))
|
||||||
|
self.db_height = state['height']
|
||||||
|
self.db_tx_count = state['tx_count']
|
||||||
|
self.db_tip = state['tip']
|
||||||
|
self.utxo_flush_count = state['utxo_flush_count']
|
||||||
|
self.wall_time = state['wall_time']
|
||||||
|
self.first_sync = state['first_sync']
|
||||||
|
|
||||||
|
# These are our state as we move ahead of DB state
|
||||||
|
self.fs_height = self.db_height
|
||||||
|
self.fs_tx_count = self.db_tx_count
|
||||||
|
self.last_flush_tx_count = self.fs_tx_count
|
||||||
|
|
||||||
|
# Log some stats
|
||||||
|
self.logger.info('DB version: {:d}'.format(self.db_version))
|
||||||
|
self.logger.info('coin: {}'.format(self.coin.NAME))
|
||||||
|
self.logger.info('network: {}'.format(self.coin.NET))
|
||||||
|
self.logger.info('height: {:,d}'.format(self.db_height))
|
||||||
|
self.logger.info('tip: {}'.format(hash_to_hex_str(self.db_tip)))
|
||||||
|
self.logger.info('tx count: {:,d}'.format(self.db_tx_count))
|
||||||
|
if self.utxo_db.for_sync:
|
||||||
|
self.logger.info(f'flushing DB cache at {self.env.cache_MB:,d} MB')
|
||||||
|
if self.first_sync:
|
||||||
|
self.logger.info('sync time so far: {}'
|
||||||
|
.format(util.formatted_time(self.wall_time)))
|
||||||
|
|
||||||
|
def write_utxo_state(self, batch):
|
||||||
|
"""Write (UTXO) state to the batch."""
|
||||||
|
state = {
|
||||||
|
'genesis': self.coin.GENESIS_HASH,
|
||||||
|
'height': self.db_height,
|
||||||
|
'tx_count': self.db_tx_count,
|
||||||
|
'tip': self.db_tip,
|
||||||
|
'utxo_flush_count': self.utxo_flush_count,
|
||||||
|
'wall_time': self.wall_time,
|
||||||
|
'first_sync': self.first_sync,
|
||||||
|
'db_version': self.db_version,
|
||||||
|
}
|
||||||
|
batch.put(b'state', repr(state).encode())
|
||||||
|
|
||||||
|
def set_flush_count(self, count):
|
||||||
|
self.utxo_flush_count = count
|
||||||
|
with self.utxo_db.write_batch() as batch:
|
||||||
|
self.write_utxo_state(batch)
|
||||||
|
|
||||||
|
async def all_utxos(self, hashX):
|
||||||
|
"""Return all UTXOs for an address sorted in no particular order."""
|
||||||
|
def read_utxos():
|
||||||
|
utxos = []
|
||||||
|
utxos_append = utxos.append
|
||||||
|
s_unpack = unpack
|
||||||
|
# Key: b'u' + address_hashX + tx_idx + tx_num
|
||||||
|
# Value: the UTXO value as a 64-bit unsigned integer
|
||||||
|
prefix = b'u' + hashX
|
||||||
|
for db_key, db_value in self.utxo_db.iterator(prefix=prefix):
|
||||||
|
tx_pos, tx_num = s_unpack('<HI', db_key[-6:])
|
||||||
|
value, = unpack('<Q', db_value)
|
||||||
|
tx_hash, height = self.fs_tx_hash(tx_num)
|
||||||
|
utxos_append(UTXO(tx_num, tx_pos, tx_hash, height, value))
|
||||||
|
return utxos
|
||||||
|
|
||||||
|
while True:
|
||||||
|
utxos = await asyncio.get_event_loop().run_in_executor(None, read_utxos)
|
||||||
|
if all(utxo.tx_hash is not None for utxo in utxos):
|
||||||
|
return utxos
|
||||||
|
self.logger.warning(f'all_utxos: tx hash not '
|
||||||
|
f'found (reorg?), retrying...')
|
||||||
|
await sleep(0.25)
|
||||||
|
|
||||||
|
async def lookup_utxos(self, prevouts):
|
||||||
|
"""For each prevout, lookup it up in the DB and return a (hashX,
|
||||||
|
value) pair or None if not found.
|
||||||
|
|
||||||
|
Used by the mempool code.
|
||||||
|
"""
|
||||||
|
def lookup_hashXs():
|
||||||
|
"""Return (hashX, suffix) pairs, or None if not found,
|
||||||
|
for each prevout.
|
||||||
|
"""
|
||||||
|
def lookup_hashX(tx_hash, tx_idx):
|
||||||
|
idx_packed = pack('<H', tx_idx)
|
||||||
|
|
||||||
|
# Key: b'h' + compressed_tx_hash + tx_idx + tx_num
|
||||||
|
# Value: hashX
|
||||||
|
prefix = b'h' + tx_hash[:4] + idx_packed
|
||||||
|
|
||||||
|
# Find which entry, if any, the TX_HASH matches.
|
||||||
|
for db_key, hashX in self.utxo_db.iterator(prefix=prefix):
|
||||||
|
tx_num_packed = db_key[-4:]
|
||||||
|
tx_num, = unpack('<I', tx_num_packed)
|
||||||
|
hash, height = self.fs_tx_hash(tx_num)
|
||||||
|
if hash == tx_hash:
|
||||||
|
return hashX, idx_packed + tx_num_packed
|
||||||
|
return None, None
|
||||||
|
return [lookup_hashX(*prevout) for prevout in prevouts]
|
||||||
|
|
||||||
|
def lookup_utxos(hashX_pairs):
|
||||||
|
def lookup_utxo(hashX, suffix):
|
||||||
|
if not hashX:
|
||||||
|
# This can happen when the daemon is a block ahead
|
||||||
|
# of us and has mempool txs spending outputs from
|
||||||
|
# that new block
|
||||||
|
return None
|
||||||
|
# Key: b'u' + address_hashX + tx_idx + tx_num
|
||||||
|
# Value: the UTXO value as a 64-bit unsigned integer
|
||||||
|
key = b'u' + hashX + suffix
|
||||||
|
db_value = self.utxo_db.get(key)
|
||||||
|
if not db_value:
|
||||||
|
# This can happen if the DB was updated between
|
||||||
|
# getting the hashXs and getting the UTXOs
|
||||||
|
return None
|
||||||
|
value, = unpack('<Q', db_value)
|
||||||
|
return hashX, value
|
||||||
|
return [lookup_utxo(*hashX_pair) for hashX_pair in hashX_pairs]
|
||||||
|
|
||||||
|
hashX_pairs = await asyncio.get_event_loop().run_in_executor(None, lookup_hashXs)
|
||||||
|
return await asyncio.get_event_loop().run_in_executor(None, lookup_utxos, hashX_pairs)
|
54
torba/torba/server/enum.py
Normal file
54
torba/torba/server/enum.py
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
# Copyright (c) 2016, Neil Booth
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# See the file "LICENCE" for information about the copyright
|
||||||
|
# and warranty status of this software.
|
||||||
|
|
||||||
|
"""An enum-like type with reverse lookup.
|
||||||
|
|
||||||
|
Source: Python Cookbook, http://code.activestate.com/recipes/67107/
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class EnumError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Enumeration:
|
||||||
|
|
||||||
|
def __init__(self, name, enumList):
|
||||||
|
self.__doc__ = name
|
||||||
|
|
||||||
|
lookup = {}
|
||||||
|
reverseLookup = {}
|
||||||
|
i = 0
|
||||||
|
uniqueNames = set()
|
||||||
|
uniqueValues = set()
|
||||||
|
for x in enumList:
|
||||||
|
if isinstance(x, tuple):
|
||||||
|
x, i = x
|
||||||
|
if not isinstance(x, str):
|
||||||
|
raise EnumError("enum name {} not a string".format(x))
|
||||||
|
if not isinstance(i, int):
|
||||||
|
raise EnumError("enum value {} not an integer".format(i))
|
||||||
|
if x in uniqueNames:
|
||||||
|
raise EnumError("enum name {} not unique".format(x))
|
||||||
|
if i in uniqueValues:
|
||||||
|
raise EnumError("enum value {} not unique".format(x))
|
||||||
|
uniqueNames.add(x)
|
||||||
|
uniqueValues.add(i)
|
||||||
|
lookup[x] = i
|
||||||
|
reverseLookup[i] = x
|
||||||
|
i = i + 1
|
||||||
|
self.lookup = lookup
|
||||||
|
self.reverseLookup = reverseLookup
|
||||||
|
|
||||||
|
def __getattr__(self, attr):
|
||||||
|
result = self.lookup.get(attr)
|
||||||
|
if result is None:
|
||||||
|
raise AttributeError('enumeration has no member {}'.format(attr))
|
||||||
|
return result
|
||||||
|
|
||||||
|
def whatis(self, value):
|
||||||
|
return self.reverseLookup[value]
|
246
torba/torba/server/env.py
Normal file
246
torba/torba/server/env.py
Normal file
|
@ -0,0 +1,246 @@
|
||||||
|
# Copyright (c) 2016, Neil Booth
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# See the file "LICENCE" for information about the copyright
|
||||||
|
# and warranty status of this software.
|
||||||
|
|
||||||
|
|
||||||
|
import re
|
||||||
|
import resource
|
||||||
|
from os import environ
|
||||||
|
from collections import namedtuple
|
||||||
|
from ipaddress import ip_address
|
||||||
|
|
||||||
|
from torba.server.util import class_logger
|
||||||
|
from torba.server.coins import Coin
|
||||||
|
import torba.server.util as lib_util
|
||||||
|
|
||||||
|
|
||||||
|
NetIdentity = namedtuple('NetIdentity', 'host tcp_port ssl_port nick_suffix')
|
||||||
|
|
||||||
|
|
||||||
|
class Env:
|
||||||
|
|
||||||
|
# Peer discovery
|
||||||
|
PD_OFF, PD_SELF, PD_ON = range(3)
|
||||||
|
|
||||||
|
class Error(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __init__(self, coin=None):
|
||||||
|
self.logger = class_logger(__name__, self.__class__.__name__)
|
||||||
|
self.allow_root = self.boolean('ALLOW_ROOT', False)
|
||||||
|
self.host = self.default('HOST', 'localhost')
|
||||||
|
self.rpc_host = self.default('RPC_HOST', 'localhost')
|
||||||
|
self.loop_policy = self.event_loop_policy()
|
||||||
|
self.obsolete(['UTXO_MB', 'HIST_MB', 'NETWORK'])
|
||||||
|
self.db_dir = self.required('DB_DIRECTORY')
|
||||||
|
self.db_engine = self.default('DB_ENGINE', 'leveldb')
|
||||||
|
self.daemon_url = self.required('DAEMON_URL')
|
||||||
|
if coin is not None:
|
||||||
|
assert issubclass(coin, Coin)
|
||||||
|
self.coin = coin
|
||||||
|
else:
|
||||||
|
coin_name = self.required('COIN').strip()
|
||||||
|
network = self.default('NET', 'mainnet').strip()
|
||||||
|
self.coin = Coin.lookup_coin_class(coin_name, network)
|
||||||
|
self.cache_MB = self.integer('CACHE_MB', 1200)
|
||||||
|
self.host = self.default('HOST', 'localhost')
|
||||||
|
self.reorg_limit = self.integer('REORG_LIMIT', self.coin.REORG_LIMIT)
|
||||||
|
# Server stuff
|
||||||
|
self.tcp_port = self.integer('TCP_PORT', None)
|
||||||
|
self.ssl_port = self.integer('SSL_PORT', None)
|
||||||
|
if self.ssl_port:
|
||||||
|
self.ssl_certfile = self.required('SSL_CERTFILE')
|
||||||
|
self.ssl_keyfile = self.required('SSL_KEYFILE')
|
||||||
|
self.rpc_port = self.integer('RPC_PORT', 8000)
|
||||||
|
self.max_subscriptions = self.integer('MAX_SUBSCRIPTIONS', 10000)
|
||||||
|
self.banner_file = self.default('BANNER_FILE', None)
|
||||||
|
self.tor_banner_file = self.default('TOR_BANNER_FILE', self.banner_file)
|
||||||
|
self.anon_logs = self.boolean('ANON_LOGS', False)
|
||||||
|
self.log_sessions = self.integer('LOG_SESSIONS', 3600)
|
||||||
|
# Peer discovery
|
||||||
|
self.peer_discovery = self.peer_discovery_enum()
|
||||||
|
self.peer_announce = self.boolean('PEER_ANNOUNCE', True)
|
||||||
|
self.force_proxy = self.boolean('FORCE_PROXY', False)
|
||||||
|
self.tor_proxy_host = self.default('TOR_PROXY_HOST', 'localhost')
|
||||||
|
self.tor_proxy_port = self.integer('TOR_PROXY_PORT', None)
|
||||||
|
# The electrum client takes the empty string as unspecified
|
||||||
|
self.donation_address = self.default('DONATION_ADDRESS', '')
|
||||||
|
# Server limits to help prevent DoS
|
||||||
|
self.max_send = self.integer('MAX_SEND', 1000000)
|
||||||
|
self.max_subs = self.integer('MAX_SUBS', 250000)
|
||||||
|
self.max_sessions = self.sane_max_sessions()
|
||||||
|
self.max_session_subs = self.integer('MAX_SESSION_SUBS', 50000)
|
||||||
|
self.bandwidth_limit = self.integer('BANDWIDTH_LIMIT', 2000000)
|
||||||
|
self.session_timeout = self.integer('SESSION_TIMEOUT', 600)
|
||||||
|
self.drop_client = self.custom("DROP_CLIENT", None, re.compile)
|
||||||
|
|
||||||
|
# Identities
|
||||||
|
clearnet_identity = self.clearnet_identity()
|
||||||
|
tor_identity = self.tor_identity(clearnet_identity)
|
||||||
|
self.identities = [identity
|
||||||
|
for identity in (clearnet_identity, tor_identity)
|
||||||
|
if identity is not None]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def default(cls, envvar, default):
|
||||||
|
return environ.get(envvar, default)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def boolean(cls, envvar, default):
|
||||||
|
default = 'Yes' if default else ''
|
||||||
|
return bool(cls.default(envvar, default).strip())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def required(cls, envvar):
|
||||||
|
value = environ.get(envvar)
|
||||||
|
if value is None:
|
||||||
|
raise cls.Error('required envvar {} not set'.format(envvar))
|
||||||
|
return value
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def integer(cls, envvar, default):
|
||||||
|
value = environ.get(envvar)
|
||||||
|
if value is None:
|
||||||
|
return default
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except Exception:
|
||||||
|
raise cls.Error('cannot convert envvar {} value {} to an integer'
|
||||||
|
.format(envvar, value))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def custom(cls, envvar, default, parse):
|
||||||
|
value = environ.get(envvar)
|
||||||
|
if value is None:
|
||||||
|
return default
|
||||||
|
try:
|
||||||
|
return parse(value)
|
||||||
|
except Exception as e:
|
||||||
|
raise cls.Error('cannot parse envvar {} value {}'
|
||||||
|
.format(envvar, value)) from e
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def obsolete(cls, envvars):
|
||||||
|
bad = [envvar for envvar in envvars if environ.get(envvar)]
|
||||||
|
if bad:
|
||||||
|
raise cls.Error('remove obsolete environment variables {}'
|
||||||
|
.format(bad))
|
||||||
|
|
||||||
|
def event_loop_policy(self):
|
||||||
|
policy = self.default('EVENT_LOOP_POLICY', None)
|
||||||
|
if policy is None:
|
||||||
|
return None
|
||||||
|
if policy == 'uvloop':
|
||||||
|
import uvloop
|
||||||
|
return uvloop.EventLoopPolicy()
|
||||||
|
raise self.Error('unknown event loop policy "{}"'.format(policy))
|
||||||
|
|
||||||
|
def cs_host(self, *, for_rpc):
|
||||||
|
"""Returns the 'host' argument to pass to asyncio's create_server
|
||||||
|
call. The result can be a single host name string, a list of
|
||||||
|
host name strings, or an empty string to bind to all interfaces.
|
||||||
|
|
||||||
|
If rpc is True the host to use for the RPC server is returned.
|
||||||
|
Otherwise the host to use for SSL/TCP servers is returned.
|
||||||
|
"""
|
||||||
|
host = self.rpc_host if for_rpc else self.host
|
||||||
|
result = [part.strip() for part in host.split(',')]
|
||||||
|
if len(result) == 1:
|
||||||
|
result = result[0]
|
||||||
|
# An empty result indicates all interfaces, which we do not
|
||||||
|
# permitted for an RPC server.
|
||||||
|
if for_rpc and not result:
|
||||||
|
result = 'localhost'
|
||||||
|
if result == 'localhost':
|
||||||
|
# 'localhost' resolves to ::1 (ipv6) on many systems, which fails on default setup of
|
||||||
|
# docker, using 127.0.0.1 instead forces ipv4
|
||||||
|
result = '127.0.0.1'
|
||||||
|
return result
|
||||||
|
|
||||||
|
def sane_max_sessions(self):
|
||||||
|
"""Return the maximum number of sessions to permit. Normally this
|
||||||
|
is MAX_SESSIONS. However, to prevent open file exhaustion, ajdust
|
||||||
|
downwards if running with a small open file rlimit."""
|
||||||
|
env_value = self.integer('MAX_SESSIONS', 1000)
|
||||||
|
nofile_limit = resource.getrlimit(resource.RLIMIT_NOFILE)[0]
|
||||||
|
# We give the DB 250 files; allow ElectrumX 100 for itself
|
||||||
|
value = max(0, min(env_value, nofile_limit - 350))
|
||||||
|
if value < env_value:
|
||||||
|
self.logger.warning('lowered maximum sessions from {:,d} to {:,d} '
|
||||||
|
'because your open file limit is {:,d}'
|
||||||
|
.format(env_value, value, nofile_limit))
|
||||||
|
return value
|
||||||
|
|
||||||
|
def clearnet_identity(self):
|
||||||
|
host = self.default('REPORT_HOST', None)
|
||||||
|
if host is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
ip = ip_address(host)
|
||||||
|
except ValueError:
|
||||||
|
bad = (not lib_util.is_valid_hostname(host)
|
||||||
|
or host.lower() == 'localhost')
|
||||||
|
else:
|
||||||
|
bad = (ip.is_multicast or ip.is_unspecified
|
||||||
|
or (ip.is_private and self.peer_announce))
|
||||||
|
if bad:
|
||||||
|
raise self.Error('"{}" is not a valid REPORT_HOST'.format(host))
|
||||||
|
tcp_port = self.integer('REPORT_TCP_PORT', self.tcp_port) or None
|
||||||
|
ssl_port = self.integer('REPORT_SSL_PORT', self.ssl_port) or None
|
||||||
|
if tcp_port == ssl_port:
|
||||||
|
raise self.Error('REPORT_TCP_PORT and REPORT_SSL_PORT '
|
||||||
|
'both resolve to {}'.format(tcp_port))
|
||||||
|
return NetIdentity(
|
||||||
|
host,
|
||||||
|
tcp_port,
|
||||||
|
ssl_port,
|
||||||
|
''
|
||||||
|
)
|
||||||
|
|
||||||
|
def tor_identity(self, clearnet):
|
||||||
|
host = self.default('REPORT_HOST_TOR', None)
|
||||||
|
if host is None:
|
||||||
|
return None
|
||||||
|
if not host.endswith('.onion'):
|
||||||
|
raise self.Error('tor host "{}" must end with ".onion"'
|
||||||
|
.format(host))
|
||||||
|
|
||||||
|
def port(port_kind):
|
||||||
|
"""Returns the clearnet identity port, if any and not zero,
|
||||||
|
otherwise the listening port."""
|
||||||
|
result = 0
|
||||||
|
if clearnet:
|
||||||
|
result = getattr(clearnet, port_kind)
|
||||||
|
return result or getattr(self, port_kind)
|
||||||
|
|
||||||
|
tcp_port = self.integer('REPORT_TCP_PORT_TOR',
|
||||||
|
port('tcp_port')) or None
|
||||||
|
ssl_port = self.integer('REPORT_SSL_PORT_TOR',
|
||||||
|
port('ssl_port')) or None
|
||||||
|
if tcp_port == ssl_port:
|
||||||
|
raise self.Error('REPORT_TCP_PORT_TOR and REPORT_SSL_PORT_TOR '
|
||||||
|
'both resolve to {}'.format(tcp_port))
|
||||||
|
|
||||||
|
return NetIdentity(
|
||||||
|
host,
|
||||||
|
tcp_port,
|
||||||
|
ssl_port,
|
||||||
|
'_tor',
|
||||||
|
)
|
||||||
|
|
||||||
|
def hosts_dict(self):
|
||||||
|
return {identity.host: {'tcp_port': identity.tcp_port,
|
||||||
|
'ssl_port': identity.ssl_port}
|
||||||
|
for identity in self.identities}
|
||||||
|
|
||||||
|
def peer_discovery_enum(self):
|
||||||
|
pd = self.default('PEER_DISCOVERY', 'on').strip().lower()
|
||||||
|
if pd in ('off', ''):
|
||||||
|
return self.PD_OFF
|
||||||
|
elif pd == 'self':
|
||||||
|
return self.PD_SELF
|
||||||
|
else:
|
||||||
|
return self.PD_ON
|
159
torba/torba/server/hash.py
Normal file
159
torba/torba/server/hash.py
Normal file
|
@ -0,0 +1,159 @@
|
||||||
|
# Copyright (c) 2016-2017, Neil Booth
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# The MIT License (MIT)
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
# a copy of this software and associated documentation files (the
|
||||||
|
# "Software"), to deal in the Software without restriction, including
|
||||||
|
# without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
# permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
# the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be
|
||||||
|
# included in all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
|
||||||
|
"""Cryptograph hash functions and related classes."""
|
||||||
|
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
|
||||||
|
from torba.server.util import bytes_to_int, int_to_bytes, hex_to_bytes
|
||||||
|
|
||||||
|
_sha256 = hashlib.sha256
|
||||||
|
_sha512 = hashlib.sha512
|
||||||
|
_new_hash = hashlib.new
|
||||||
|
_new_hmac = hmac.new
|
||||||
|
HASHX_LEN = 11
|
||||||
|
|
||||||
|
|
||||||
|
def sha256(x):
|
||||||
|
"""Simple wrapper of hashlib sha256."""
|
||||||
|
return _sha256(x).digest()
|
||||||
|
|
||||||
|
|
||||||
|
def ripemd160(x):
|
||||||
|
"""Simple wrapper of hashlib ripemd160."""
|
||||||
|
h = _new_hash('ripemd160')
|
||||||
|
h.update(x)
|
||||||
|
return h.digest()
|
||||||
|
|
||||||
|
|
||||||
|
def double_sha256(x):
|
||||||
|
"""SHA-256 of SHA-256, as used extensively in bitcoin."""
|
||||||
|
return sha256(sha256(x))
|
||||||
|
|
||||||
|
|
||||||
|
def hmac_sha512(key, msg):
|
||||||
|
"""Use SHA-512 to provide an HMAC."""
|
||||||
|
return _new_hmac(key, msg, _sha512).digest()
|
||||||
|
|
||||||
|
|
||||||
|
def hash160(x):
|
||||||
|
"""RIPEMD-160 of SHA-256.
|
||||||
|
|
||||||
|
Used to make bitcoin addresses from pubkeys."""
|
||||||
|
return ripemd160(sha256(x))
|
||||||
|
|
||||||
|
|
||||||
|
def hash_to_hex_str(x):
|
||||||
|
"""Convert a big-endian binary hash to displayed hex string.
|
||||||
|
|
||||||
|
Display form of a binary hash is reversed and converted to hex.
|
||||||
|
"""
|
||||||
|
return bytes(reversed(x)).hex()
|
||||||
|
|
||||||
|
|
||||||
|
def hex_str_to_hash(x):
|
||||||
|
"""Convert a displayed hex string to a binary hash."""
|
||||||
|
return bytes(reversed(hex_to_bytes(x)))
|
||||||
|
|
||||||
|
|
||||||
|
class Base58Error(Exception):
|
||||||
|
"""Exception used for Base58 errors."""
|
||||||
|
|
||||||
|
|
||||||
|
class Base58:
|
||||||
|
"""Class providing base 58 functionality."""
|
||||||
|
|
||||||
|
chars = '123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz'
|
||||||
|
assert len(chars) == 58
|
||||||
|
cmap = {c: n for n, c in enumerate(chars)}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def char_value(c):
|
||||||
|
val = Base58.cmap.get(c)
|
||||||
|
if val is None:
|
||||||
|
raise Base58Error('invalid base 58 character "{}"'.format(c))
|
||||||
|
return val
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def decode(txt):
|
||||||
|
"""Decodes txt into a big-endian bytearray."""
|
||||||
|
if not isinstance(txt, str):
|
||||||
|
raise TypeError('a string is required')
|
||||||
|
|
||||||
|
if not txt:
|
||||||
|
raise Base58Error('string cannot be empty')
|
||||||
|
|
||||||
|
value = 0
|
||||||
|
for c in txt:
|
||||||
|
value = value * 58 + Base58.char_value(c)
|
||||||
|
|
||||||
|
result = int_to_bytes(value)
|
||||||
|
|
||||||
|
# Prepend leading zero bytes if necessary
|
||||||
|
count = 0
|
||||||
|
for c in txt:
|
||||||
|
if c != '1':
|
||||||
|
break
|
||||||
|
count += 1
|
||||||
|
if count:
|
||||||
|
result = bytes(count) + result
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def encode(be_bytes):
|
||||||
|
"""Converts a big-endian bytearray into a base58 string."""
|
||||||
|
value = bytes_to_int(be_bytes)
|
||||||
|
|
||||||
|
txt = ''
|
||||||
|
while value:
|
||||||
|
value, mod = divmod(value, 58)
|
||||||
|
txt += Base58.chars[mod]
|
||||||
|
|
||||||
|
for byte in be_bytes:
|
||||||
|
if byte != 0:
|
||||||
|
break
|
||||||
|
txt += '1'
|
||||||
|
|
||||||
|
return txt[::-1]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def decode_check(txt, *, hash_fn=double_sha256):
|
||||||
|
"""Decodes a Base58Check-encoded string to a payload. The version
|
||||||
|
prefixes it."""
|
||||||
|
be_bytes = Base58.decode(txt)
|
||||||
|
result, check = be_bytes[:-4], be_bytes[-4:]
|
||||||
|
if check != hash_fn(result)[:4]:
|
||||||
|
raise Base58Error('invalid base 58 checksum for {}'.format(txt))
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def encode_check(payload, *, hash_fn=double_sha256):
|
||||||
|
"""Encodes a payload bytearray (which includes the version byte(s))
|
||||||
|
into a Base58Check string."""
|
||||||
|
be_bytes = payload + hash_fn(payload)[:4]
|
||||||
|
return Base58.encode(be_bytes)
|
324
torba/torba/server/history.py
Normal file
324
torba/torba/server/history.py
Normal file
|
@ -0,0 +1,324 @@
|
||||||
|
# Copyright (c) 2016-2018, Neil Booth
|
||||||
|
# Copyright (c) 2017, the ElectrumX authors
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# See the file "LICENCE" for information about the copyright
|
||||||
|
# and warranty status of this software.
|
||||||
|
|
||||||
|
"""History by script hash (address)."""
|
||||||
|
|
||||||
|
import array
|
||||||
|
import ast
|
||||||
|
import bisect
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from torba.server import util
|
||||||
|
from torba.server.util import pack_be_uint16, unpack_be_uint16_from
|
||||||
|
from torba.server.hash import hash_to_hex_str, HASHX_LEN
|
||||||
|
|
||||||
|
|
||||||
|
class History:
|
||||||
|
|
||||||
|
DB_VERSIONS = [0]
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.logger = util.class_logger(__name__, self.__class__.__name__)
|
||||||
|
# For history compaction
|
||||||
|
self.max_hist_row_entries = 12500
|
||||||
|
self.unflushed = defaultdict(partial(array.array, 'I'))
|
||||||
|
self.unflushed_count = 0
|
||||||
|
self.db = None
|
||||||
|
|
||||||
|
def open_db(self, db_class, for_sync, utxo_flush_count, compacting):
|
||||||
|
self.db = db_class('hist', for_sync)
|
||||||
|
self.read_state()
|
||||||
|
self.clear_excess(utxo_flush_count)
|
||||||
|
# An incomplete compaction needs to be cancelled otherwise
|
||||||
|
# restarting it will corrupt the history
|
||||||
|
if not compacting:
|
||||||
|
self._cancel_compaction()
|
||||||
|
return self.flush_count
|
||||||
|
|
||||||
|
def close_db(self):
|
||||||
|
if self.db:
|
||||||
|
self.db.close()
|
||||||
|
self.db = None
|
||||||
|
|
||||||
|
def read_state(self):
|
||||||
|
state = self.db.get(b'state\0\0')
|
||||||
|
if state:
|
||||||
|
state = ast.literal_eval(state.decode())
|
||||||
|
if not isinstance(state, dict):
|
||||||
|
raise RuntimeError('failed reading state from history DB')
|
||||||
|
self.flush_count = state['flush_count']
|
||||||
|
self.comp_flush_count = state.get('comp_flush_count', -1)
|
||||||
|
self.comp_cursor = state.get('comp_cursor', -1)
|
||||||
|
self.db_version = state.get('db_version', 0)
|
||||||
|
else:
|
||||||
|
self.flush_count = 0
|
||||||
|
self.comp_flush_count = -1
|
||||||
|
self.comp_cursor = -1
|
||||||
|
self.db_version = max(self.DB_VERSIONS)
|
||||||
|
|
||||||
|
self.logger.info(f'history DB version: {self.db_version}')
|
||||||
|
if self.db_version not in self.DB_VERSIONS:
|
||||||
|
msg = f'this software only handles DB versions {self.DB_VERSIONS}'
|
||||||
|
self.logger.error(msg)
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
self.logger.info(f'flush count: {self.flush_count:,d}')
|
||||||
|
|
||||||
|
def clear_excess(self, utxo_flush_count):
|
||||||
|
# < might happen at end of compaction as both DBs cannot be
|
||||||
|
# updated atomically
|
||||||
|
if self.flush_count <= utxo_flush_count:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.logger.info('DB shut down uncleanly. Scanning for '
|
||||||
|
'excess history flushes...')
|
||||||
|
|
||||||
|
keys = []
|
||||||
|
for key, hist in self.db.iterator(prefix=b''):
|
||||||
|
flush_id, = unpack_be_uint16_from(key[-2:])
|
||||||
|
if flush_id > utxo_flush_count:
|
||||||
|
keys.append(key)
|
||||||
|
|
||||||
|
self.logger.info(f'deleting {len(keys):,d} history entries')
|
||||||
|
|
||||||
|
self.flush_count = utxo_flush_count
|
||||||
|
with self.db.write_batch() as batch:
|
||||||
|
for key in keys:
|
||||||
|
batch.delete(key)
|
||||||
|
self.write_state(batch)
|
||||||
|
|
||||||
|
self.logger.info('deleted excess history entries')
|
||||||
|
|
||||||
|
def write_state(self, batch):
|
||||||
|
"""Write state to the history DB."""
|
||||||
|
state = {
|
||||||
|
'flush_count': self.flush_count,
|
||||||
|
'comp_flush_count': self.comp_flush_count,
|
||||||
|
'comp_cursor': self.comp_cursor,
|
||||||
|
'db_version': self.db_version,
|
||||||
|
}
|
||||||
|
# History entries are not prefixed; the suffix \0\0 ensures we
|
||||||
|
# look similar to other entries and aren't interfered with
|
||||||
|
batch.put(b'state\0\0', repr(state).encode())
|
||||||
|
|
||||||
|
def add_unflushed(self, hashXs_by_tx, first_tx_num):
|
||||||
|
unflushed = self.unflushed
|
||||||
|
count = 0
|
||||||
|
for tx_num, hashXs in enumerate(hashXs_by_tx, start=first_tx_num):
|
||||||
|
hashXs = set(hashXs)
|
||||||
|
for hashX in hashXs:
|
||||||
|
unflushed[hashX].append(tx_num)
|
||||||
|
count += len(hashXs)
|
||||||
|
self.unflushed_count += count
|
||||||
|
|
||||||
|
def unflushed_memsize(self):
|
||||||
|
return len(self.unflushed) * 180 + self.unflushed_count * 4
|
||||||
|
|
||||||
|
def assert_flushed(self):
|
||||||
|
assert not self.unflushed
|
||||||
|
|
||||||
|
def flush(self):
|
||||||
|
start_time = time.time()
|
||||||
|
self.flush_count += 1
|
||||||
|
flush_id = pack_be_uint16(self.flush_count)
|
||||||
|
unflushed = self.unflushed
|
||||||
|
|
||||||
|
with self.db.write_batch() as batch:
|
||||||
|
for hashX in sorted(unflushed):
|
||||||
|
key = hashX + flush_id
|
||||||
|
batch.put(key, unflushed[hashX].tobytes())
|
||||||
|
self.write_state(batch)
|
||||||
|
|
||||||
|
count = len(unflushed)
|
||||||
|
unflushed.clear()
|
||||||
|
self.unflushed_count = 0
|
||||||
|
|
||||||
|
if self.db.for_sync:
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
self.logger.info(f'flushed history in {elapsed:.1f}s '
|
||||||
|
f'for {count:,d} addrs')
|
||||||
|
|
||||||
|
def backup(self, hashXs, tx_count):
|
||||||
|
# Not certain this is needed, but it doesn't hurt
|
||||||
|
self.flush_count += 1
|
||||||
|
nremoves = 0
|
||||||
|
bisect_left = bisect.bisect_left
|
||||||
|
|
||||||
|
with self.db.write_batch() as batch:
|
||||||
|
for hashX in sorted(hashXs):
|
||||||
|
deletes = []
|
||||||
|
puts = {}
|
||||||
|
for key, hist in self.db.iterator(prefix=hashX, reverse=True):
|
||||||
|
a = array.array('I')
|
||||||
|
a.frombytes(hist)
|
||||||
|
# Remove all history entries >= tx_count
|
||||||
|
idx = bisect_left(a, tx_count)
|
||||||
|
nremoves += len(a) - idx
|
||||||
|
if idx > 0:
|
||||||
|
puts[key] = a[:idx].tobytes()
|
||||||
|
break
|
||||||
|
deletes.append(key)
|
||||||
|
|
||||||
|
for key in deletes:
|
||||||
|
batch.delete(key)
|
||||||
|
for key, value in puts.items():
|
||||||
|
batch.put(key, value)
|
||||||
|
self.write_state(batch)
|
||||||
|
|
||||||
|
self.logger.info(f'backing up removed {nremoves:,d} history entries')
|
||||||
|
|
||||||
|
def get_txnums(self, hashX, limit=1000):
|
||||||
|
"""Generator that returns an unpruned, sorted list of tx_nums in the
|
||||||
|
history of a hashX. Includes both spending and receiving
|
||||||
|
transactions. By default yields at most 1000 entries. Set
|
||||||
|
limit to None to get them all. """
|
||||||
|
limit = util.resolve_limit(limit)
|
||||||
|
for key, hist in self.db.iterator(prefix=hashX):
|
||||||
|
a = array.array('I')
|
||||||
|
a.frombytes(hist)
|
||||||
|
for tx_num in a:
|
||||||
|
if limit == 0:
|
||||||
|
return
|
||||||
|
yield tx_num
|
||||||
|
limit -= 1
|
||||||
|
|
||||||
|
#
|
||||||
|
# History compaction
|
||||||
|
#
|
||||||
|
|
||||||
|
# comp_cursor is a cursor into compaction progress.
|
||||||
|
# -1: no compaction in progress
|
||||||
|
# 0-65535: Compaction in progress; all prefixes < comp_cursor have
|
||||||
|
# been compacted, and later ones have not.
|
||||||
|
# 65536: compaction complete in-memory but not flushed
|
||||||
|
#
|
||||||
|
# comp_flush_count applies during compaction, and is a flush count
|
||||||
|
# for history with prefix < comp_cursor. flush_count applies
|
||||||
|
# to still uncompacted history. It is -1 when no compaction is
|
||||||
|
# taking place. Key suffixes up to and including comp_flush_count
|
||||||
|
# are used, so a parallel history flush must first increment this
|
||||||
|
#
|
||||||
|
# When compaction is complete and the final flush takes place,
|
||||||
|
# flush_count is reset to comp_flush_count, and comp_flush_count to -1
|
||||||
|
|
||||||
|
def _flush_compaction(self, cursor, write_items, keys_to_delete):
|
||||||
|
"""Flush a single compaction pass as a batch."""
|
||||||
|
# Update compaction state
|
||||||
|
if cursor == 65536:
|
||||||
|
self.flush_count = self.comp_flush_count
|
||||||
|
self.comp_cursor = -1
|
||||||
|
self.comp_flush_count = -1
|
||||||
|
else:
|
||||||
|
self.comp_cursor = cursor
|
||||||
|
|
||||||
|
# History DB. Flush compacted history and updated state
|
||||||
|
with self.db.write_batch() as batch:
|
||||||
|
# Important: delete first! The keyspace may overlap.
|
||||||
|
for key in keys_to_delete:
|
||||||
|
batch.delete(key)
|
||||||
|
for key, value in write_items:
|
||||||
|
batch.put(key, value)
|
||||||
|
self.write_state(batch)
|
||||||
|
|
||||||
|
def _compact_hashX(self, hashX, hist_map, hist_list,
|
||||||
|
write_items, keys_to_delete):
|
||||||
|
"""Compres history for a hashX. hist_list is an ordered list of
|
||||||
|
the histories to be compressed."""
|
||||||
|
# History entries (tx numbers) are 4 bytes each. Distribute
|
||||||
|
# over rows of up to 50KB in size. A fixed row size means
|
||||||
|
# future compactions will not need to update the first N - 1
|
||||||
|
# rows.
|
||||||
|
max_row_size = self.max_hist_row_entries * 4
|
||||||
|
full_hist = b''.join(hist_list)
|
||||||
|
nrows = (len(full_hist) + max_row_size - 1) // max_row_size
|
||||||
|
if nrows > 4:
|
||||||
|
self.logger.info('hashX {} is large: {:,d} entries across '
|
||||||
|
'{:,d} rows'
|
||||||
|
.format(hash_to_hex_str(hashX),
|
||||||
|
len(full_hist) // 4, nrows))
|
||||||
|
|
||||||
|
# Find what history needs to be written, and what keys need to
|
||||||
|
# be deleted. Start by assuming all keys are to be deleted,
|
||||||
|
# and then remove those that are the same on-disk as when
|
||||||
|
# compacted.
|
||||||
|
write_size = 0
|
||||||
|
keys_to_delete.update(hist_map)
|
||||||
|
for n, chunk in enumerate(util.chunks(full_hist, max_row_size)):
|
||||||
|
key = hashX + pack_be_uint16(n)
|
||||||
|
if hist_map.get(key) == chunk:
|
||||||
|
keys_to_delete.remove(key)
|
||||||
|
else:
|
||||||
|
write_items.append((key, chunk))
|
||||||
|
write_size += len(chunk)
|
||||||
|
|
||||||
|
assert n + 1 == nrows
|
||||||
|
self.comp_flush_count = max(self.comp_flush_count, n)
|
||||||
|
|
||||||
|
return write_size
|
||||||
|
|
||||||
|
def _compact_prefix(self, prefix, write_items, keys_to_delete):
|
||||||
|
"""Compact all history entries for hashXs beginning with the
|
||||||
|
given prefix. Update keys_to_delete and write."""
|
||||||
|
prior_hashX = None
|
||||||
|
hist_map = {}
|
||||||
|
hist_list = []
|
||||||
|
|
||||||
|
key_len = HASHX_LEN + 2
|
||||||
|
write_size = 0
|
||||||
|
for key, hist in self.db.iterator(prefix=prefix):
|
||||||
|
# Ignore non-history entries
|
||||||
|
if len(key) != key_len:
|
||||||
|
continue
|
||||||
|
hashX = key[:-2]
|
||||||
|
if hashX != prior_hashX and prior_hashX:
|
||||||
|
write_size += self._compact_hashX(prior_hashX, hist_map,
|
||||||
|
hist_list, write_items,
|
||||||
|
keys_to_delete)
|
||||||
|
hist_map.clear()
|
||||||
|
hist_list.clear()
|
||||||
|
prior_hashX = hashX
|
||||||
|
hist_map[key] = hist
|
||||||
|
hist_list.append(hist)
|
||||||
|
|
||||||
|
if prior_hashX:
|
||||||
|
write_size += self._compact_hashX(prior_hashX, hist_map, hist_list,
|
||||||
|
write_items, keys_to_delete)
|
||||||
|
return write_size
|
||||||
|
|
||||||
|
def _compact_history(self, limit):
|
||||||
|
"""Inner loop of history compaction. Loops until limit bytes have
|
||||||
|
been processed.
|
||||||
|
"""
|
||||||
|
keys_to_delete = set()
|
||||||
|
write_items = [] # A list of (key, value) pairs
|
||||||
|
write_size = 0
|
||||||
|
|
||||||
|
# Loop over 2-byte prefixes
|
||||||
|
cursor = self.comp_cursor
|
||||||
|
while write_size < limit and cursor < 65536:
|
||||||
|
prefix = pack_be_uint16(cursor)
|
||||||
|
write_size += self._compact_prefix(prefix, write_items,
|
||||||
|
keys_to_delete)
|
||||||
|
cursor += 1
|
||||||
|
|
||||||
|
max_rows = self.comp_flush_count + 1
|
||||||
|
self._flush_compaction(cursor, write_items, keys_to_delete)
|
||||||
|
|
||||||
|
self.logger.info('history compaction: wrote {:,d} rows ({:.1f} MB), '
|
||||||
|
'removed {:,d} rows, largest: {:,d}, {:.1f}% complete'
|
||||||
|
.format(len(write_items), write_size / 1000000,
|
||||||
|
len(keys_to_delete), max_rows,
|
||||||
|
100 * cursor / 65536))
|
||||||
|
return write_size
|
||||||
|
|
||||||
|
def _cancel_compaction(self):
|
||||||
|
if self.comp_cursor != -1:
|
||||||
|
self.logger.warning('cancelling in-progress history compaction')
|
||||||
|
self.comp_flush_count = -1
|
||||||
|
self.comp_cursor = -1
|
365
torba/torba/server/mempool.py
Normal file
365
torba/torba/server/mempool.py
Normal file
|
@ -0,0 +1,365 @@
|
||||||
|
# Copyright (c) 2016-2018, Neil Booth
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# See the file "LICENCE" for information about the copyright
|
||||||
|
# and warranty status of this software.
|
||||||
|
|
||||||
|
"""Mempool handling."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import itertools
|
||||||
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from asyncio import Lock, sleep
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import attr
|
||||||
|
|
||||||
|
from torba.server.hash import hash_to_hex_str, hex_str_to_hash
|
||||||
|
from torba.server.util import class_logger, chunks
|
||||||
|
from torba.server.db import UTXO
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True)
|
||||||
|
class MemPoolTx:
|
||||||
|
prevouts = attr.ib()
|
||||||
|
# A pair is a (hashX, value) tuple
|
||||||
|
in_pairs = attr.ib()
|
||||||
|
out_pairs = attr.ib()
|
||||||
|
fee = attr.ib()
|
||||||
|
size = attr.ib()
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True)
|
||||||
|
class MemPoolTxSummary:
|
||||||
|
hash = attr.ib()
|
||||||
|
fee = attr.ib()
|
||||||
|
has_unconfirmed_inputs = attr.ib()
|
||||||
|
|
||||||
|
|
||||||
|
class MemPoolAPI(ABC):
|
||||||
|
"""A concrete instance of this class is passed to the MemPool object
|
||||||
|
and used by it to query DB and blockchain state."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def height(self):
|
||||||
|
"""Query bitcoind for its height."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def cached_height(self):
|
||||||
|
"""Return the height of bitcoind the last time it was queried,
|
||||||
|
for any reason, without actually querying it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def mempool_hashes(self):
|
||||||
|
"""Query bitcoind for the hashes of all transactions in its
|
||||||
|
mempool, returned as a list."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def raw_transactions(self, hex_hashes):
|
||||||
|
"""Query bitcoind for the serialized raw transactions with the given
|
||||||
|
hashes. Missing transactions are returned as None.
|
||||||
|
|
||||||
|
hex_hashes is an iterable of hexadecimal hash strings."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def lookup_utxos(self, prevouts):
|
||||||
|
"""Return a list of (hashX, value) pairs each prevout if unspent,
|
||||||
|
otherwise return None if spent or not found.
|
||||||
|
|
||||||
|
prevouts - an iterable of (hash, index) pairs
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def on_mempool(self, touched, height):
|
||||||
|
"""Called each time the mempool is synchronized. touched is a set of
|
||||||
|
hashXs touched since the previous call. height is the
|
||||||
|
daemon's height at the time the mempool was obtained."""
|
||||||
|
|
||||||
|
|
||||||
|
class MemPool:
|
||||||
|
"""Representation of the daemon's mempool.
|
||||||
|
|
||||||
|
coin - a coin class from coins.py
|
||||||
|
api - an object implementing MemPoolAPI
|
||||||
|
|
||||||
|
Updated regularly in caught-up state. Goal is to enable efficient
|
||||||
|
response to the calls in the external interface. To that end we
|
||||||
|
maintain the following maps:
|
||||||
|
|
||||||
|
tx: tx_hash -> MemPoolTx
|
||||||
|
hashXs: hashX -> set of all hashes of txs touching the hashX
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, coin, api, refresh_secs=5.0, log_status_secs=120.0):
|
||||||
|
assert isinstance(api, MemPoolAPI)
|
||||||
|
self.coin = coin
|
||||||
|
self.api = api
|
||||||
|
self.logger = class_logger(__name__, self.__class__.__name__)
|
||||||
|
self.txs = {}
|
||||||
|
self.hashXs = defaultdict(set) # None can be a key
|
||||||
|
self.cached_compact_histogram = []
|
||||||
|
self.refresh_secs = refresh_secs
|
||||||
|
self.log_status_secs = log_status_secs
|
||||||
|
# Prevents mempool refreshes during fee histogram calculation
|
||||||
|
self.lock = Lock()
|
||||||
|
|
||||||
|
async def _logging(self, synchronized_event):
|
||||||
|
"""Print regular logs of mempool stats."""
|
||||||
|
self.logger.info('beginning processing of daemon mempool. '
|
||||||
|
'This can take some time...')
|
||||||
|
start = time.time()
|
||||||
|
await synchronized_event.wait()
|
||||||
|
elapsed = time.time() - start
|
||||||
|
self.logger.info(f'synced in {elapsed:.2f}s')
|
||||||
|
while True:
|
||||||
|
self.logger.info(f'{len(self.txs):,d} txs '
|
||||||
|
f'touching {len(self.hashXs):,d} addresses')
|
||||||
|
await sleep(self.log_status_secs)
|
||||||
|
await synchronized_event.wait()
|
||||||
|
|
||||||
|
async def _refresh_histogram(self, synchronized_event):
|
||||||
|
while True:
|
||||||
|
await synchronized_event.wait()
|
||||||
|
async with self.lock:
|
||||||
|
# Threaded as can be expensive
|
||||||
|
await asyncio.get_event_loop().run_in_executor(None, self._update_histogram, 100_000)
|
||||||
|
await sleep(self.coin.MEMPOOL_HISTOGRAM_REFRESH_SECS)
|
||||||
|
|
||||||
|
def _update_histogram(self, bin_size):
|
||||||
|
# Build a histogram by fee rate
|
||||||
|
histogram = defaultdict(int)
|
||||||
|
for tx in self.txs.values():
|
||||||
|
histogram[tx.fee // tx.size] += tx.size
|
||||||
|
|
||||||
|
# Now compact it. For efficiency, get_fees returns a
|
||||||
|
# compact histogram with variable bin size. The compact
|
||||||
|
# histogram is an array of (fee_rate, vsize) values.
|
||||||
|
# vsize_n is the cumulative virtual size of mempool
|
||||||
|
# transactions with a fee rate in the interval
|
||||||
|
# [rate_(n-1), rate_n)], and rate_(n-1) > rate_n.
|
||||||
|
# Intervals are chosen to create tranches containing at
|
||||||
|
# least 100kb of transactions
|
||||||
|
compact = []
|
||||||
|
cum_size = 0
|
||||||
|
r = 0 # ?
|
||||||
|
for fee_rate, size in sorted(histogram.items(), reverse=True):
|
||||||
|
cum_size += size
|
||||||
|
if cum_size + r > bin_size:
|
||||||
|
compact.append((fee_rate, cum_size))
|
||||||
|
r += cum_size - bin_size
|
||||||
|
cum_size = 0
|
||||||
|
bin_size *= 1.1
|
||||||
|
self.logger.info(f'compact fee histogram: {compact}')
|
||||||
|
self.cached_compact_histogram = compact
|
||||||
|
|
||||||
|
def _accept_transactions(self, tx_map, utxo_map, touched):
|
||||||
|
"""Accept transactions in tx_map to the mempool if all their inputs
|
||||||
|
can be found in the existing mempool or a utxo_map from the
|
||||||
|
DB.
|
||||||
|
|
||||||
|
Returns an (unprocessed tx_map, unspent utxo_map) pair.
|
||||||
|
"""
|
||||||
|
hashXs = self.hashXs
|
||||||
|
txs = self.txs
|
||||||
|
|
||||||
|
deferred = {}
|
||||||
|
unspent = set(utxo_map)
|
||||||
|
# Try to find all prevouts so we can accept the TX
|
||||||
|
for hash, tx in tx_map.items():
|
||||||
|
in_pairs = []
|
||||||
|
try:
|
||||||
|
for prevout in tx.prevouts:
|
||||||
|
utxo = utxo_map.get(prevout)
|
||||||
|
if not utxo:
|
||||||
|
prev_hash, prev_index = prevout
|
||||||
|
# Raises KeyError if prev_hash is not in txs
|
||||||
|
utxo = txs[prev_hash].out_pairs[prev_index]
|
||||||
|
in_pairs.append(utxo)
|
||||||
|
except KeyError:
|
||||||
|
deferred[hash] = tx
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Spend the prevouts
|
||||||
|
unspent.difference_update(tx.prevouts)
|
||||||
|
|
||||||
|
# Save the in_pairs, compute the fee and accept the TX
|
||||||
|
tx.in_pairs = tuple(in_pairs)
|
||||||
|
# Avoid negative fees if dealing with generation-like transactions
|
||||||
|
# because some in_parts would be missing
|
||||||
|
tx.fee = max(0, (sum(v for _, v in tx.in_pairs) -
|
||||||
|
sum(v for _, v in tx.out_pairs)))
|
||||||
|
txs[hash] = tx
|
||||||
|
|
||||||
|
for hashX, value in itertools.chain(tx.in_pairs, tx.out_pairs):
|
||||||
|
touched.add(hashX)
|
||||||
|
hashXs[hashX].add(hash)
|
||||||
|
|
||||||
|
return deferred, {prevout: utxo_map[prevout] for prevout in unspent}
|
||||||
|
|
||||||
|
async def _refresh_hashes(self, synchronized_event):
|
||||||
|
"""Refresh our view of the daemon's mempool."""
|
||||||
|
while True:
|
||||||
|
height = self.api.cached_height()
|
||||||
|
hex_hashes = await self.api.mempool_hashes()
|
||||||
|
if height != await self.api.height():
|
||||||
|
continue
|
||||||
|
hashes = set(hex_str_to_hash(hh) for hh in hex_hashes)
|
||||||
|
async with self.lock:
|
||||||
|
touched = await self._process_mempool(hashes)
|
||||||
|
synchronized_event.set()
|
||||||
|
synchronized_event.clear()
|
||||||
|
await self.api.on_mempool(touched, height)
|
||||||
|
await sleep(self.refresh_secs)
|
||||||
|
|
||||||
|
async def _process_mempool(self, all_hashes):
|
||||||
|
# Re-sync with the new set of hashes
|
||||||
|
txs = self.txs
|
||||||
|
hashXs = self.hashXs
|
||||||
|
touched = set()
|
||||||
|
|
||||||
|
# First handle txs that have disappeared
|
||||||
|
for tx_hash in set(txs).difference(all_hashes):
|
||||||
|
tx = txs.pop(tx_hash)
|
||||||
|
tx_hashXs = set(hashX for hashX, value in tx.in_pairs)
|
||||||
|
tx_hashXs.update(hashX for hashX, value in tx.out_pairs)
|
||||||
|
for hashX in tx_hashXs:
|
||||||
|
hashXs[hashX].remove(tx_hash)
|
||||||
|
if not hashXs[hashX]:
|
||||||
|
del hashXs[hashX]
|
||||||
|
touched.update(tx_hashXs)
|
||||||
|
|
||||||
|
# Process new transactions
|
||||||
|
new_hashes = list(all_hashes.difference(txs))
|
||||||
|
if new_hashes:
|
||||||
|
fetches = []
|
||||||
|
for hashes in chunks(new_hashes, 200):
|
||||||
|
fetches.append(self._fetch_and_accept(hashes, all_hashes, touched))
|
||||||
|
tx_map = {}
|
||||||
|
utxo_map = {}
|
||||||
|
for fetch in asyncio.as_completed(fetches):
|
||||||
|
deferred, unspent = await fetch
|
||||||
|
tx_map.update(deferred)
|
||||||
|
utxo_map.update(unspent)
|
||||||
|
|
||||||
|
prior_count = 0
|
||||||
|
# FIXME: this is not particularly efficient
|
||||||
|
while tx_map and len(tx_map) != prior_count:
|
||||||
|
prior_count = len(tx_map)
|
||||||
|
tx_map, utxo_map = self._accept_transactions(tx_map, utxo_map,
|
||||||
|
touched)
|
||||||
|
if tx_map:
|
||||||
|
self.logger.info(f'{len(tx_map)} txs dropped')
|
||||||
|
|
||||||
|
return touched
|
||||||
|
|
||||||
|
async def _fetch_and_accept(self, hashes, all_hashes, touched):
|
||||||
|
"""Fetch a list of mempool transactions."""
|
||||||
|
hex_hashes_iter = (hash_to_hex_str(hash) for hash in hashes)
|
||||||
|
raw_txs = await self.api.raw_transactions(hex_hashes_iter)
|
||||||
|
|
||||||
|
def deserialize_txs(): # This function is pure
|
||||||
|
to_hashX = self.coin.hashX_from_script
|
||||||
|
deserializer = self.coin.DESERIALIZER
|
||||||
|
|
||||||
|
txs = {}
|
||||||
|
for hash, raw_tx in zip(hashes, raw_txs):
|
||||||
|
# The daemon may have evicted the tx from its
|
||||||
|
# mempool or it may have gotten in a block
|
||||||
|
if not raw_tx:
|
||||||
|
continue
|
||||||
|
tx, tx_size = deserializer(raw_tx).read_tx_and_vsize()
|
||||||
|
# Convert the inputs and outputs into (hashX, value) pairs
|
||||||
|
# Drop generation-like inputs from MemPoolTx.prevouts
|
||||||
|
txin_pairs = tuple((txin.prev_hash, txin.prev_idx)
|
||||||
|
for txin in tx.inputs
|
||||||
|
if not txin.is_generation())
|
||||||
|
txout_pairs = tuple((to_hashX(txout.pk_script), txout.value)
|
||||||
|
for txout in tx.outputs)
|
||||||
|
txs[hash] = MemPoolTx(txin_pairs, None, txout_pairs,
|
||||||
|
0, tx_size)
|
||||||
|
return txs
|
||||||
|
|
||||||
|
# Thread this potentially slow operation so as not to block
|
||||||
|
tx_map = await asyncio.get_event_loop().run_in_executor(None, deserialize_txs)
|
||||||
|
|
||||||
|
# Determine all prevouts not in the mempool, and fetch the
|
||||||
|
# UTXO information from the database. Failed prevout lookups
|
||||||
|
# return None - concurrent database updates happen - which is
|
||||||
|
# relied upon by _accept_transactions. Ignore prevouts that are
|
||||||
|
# generation-like.
|
||||||
|
prevouts = tuple(prevout for tx in tx_map.values()
|
||||||
|
for prevout in tx.prevouts
|
||||||
|
if prevout[0] not in all_hashes)
|
||||||
|
utxos = await self.api.lookup_utxos(prevouts)
|
||||||
|
utxo_map = {prevout: utxo for prevout, utxo in zip(prevouts, utxos)}
|
||||||
|
|
||||||
|
return self._accept_transactions(tx_map, utxo_map, touched)
|
||||||
|
|
||||||
|
#
|
||||||
|
# External interface
|
||||||
|
#
|
||||||
|
|
||||||
|
async def keep_synchronized(self, synchronized_event):
|
||||||
|
"""Keep the mempool synchronized with the daemon."""
|
||||||
|
await asyncio.wait([
|
||||||
|
self._refresh_hashes(synchronized_event),
|
||||||
|
self._refresh_histogram(synchronized_event),
|
||||||
|
self._logging(synchronized_event)
|
||||||
|
])
|
||||||
|
|
||||||
|
async def balance_delta(self, hashX):
|
||||||
|
"""Return the unconfirmed amount in the mempool for hashX.
|
||||||
|
|
||||||
|
Can be positive or negative.
|
||||||
|
"""
|
||||||
|
value = 0
|
||||||
|
if hashX in self.hashXs:
|
||||||
|
for hash in self.hashXs[hashX]:
|
||||||
|
tx = self.txs[hash]
|
||||||
|
value -= sum(v for h168, v in tx.in_pairs if h168 == hashX)
|
||||||
|
value += sum(v for h168, v in tx.out_pairs if h168 == hashX)
|
||||||
|
return value
|
||||||
|
|
||||||
|
async def compact_fee_histogram(self):
|
||||||
|
"""Return a compact fee histogram of the current mempool."""
|
||||||
|
return self.cached_compact_histogram
|
||||||
|
|
||||||
|
async def potential_spends(self, hashX):
|
||||||
|
"""Return a set of (prev_hash, prev_idx) pairs from mempool
|
||||||
|
transactions that touch hashX.
|
||||||
|
|
||||||
|
None, some or all of these may be spends of the hashX, but all
|
||||||
|
actual spends of it (in the DB or mempool) will be included.
|
||||||
|
"""
|
||||||
|
result = set()
|
||||||
|
for tx_hash in self.hashXs.get(hashX, ()):
|
||||||
|
tx = self.txs[tx_hash]
|
||||||
|
result.update(tx.prevouts)
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def transaction_summaries(self, hashX):
|
||||||
|
"""Return a list of MemPoolTxSummary objects for the hashX."""
|
||||||
|
result = []
|
||||||
|
for tx_hash in self.hashXs.get(hashX, ()):
|
||||||
|
tx = self.txs[tx_hash]
|
||||||
|
has_ui = any(hash in self.txs for hash, idx in tx.prevouts)
|
||||||
|
result.append(MemPoolTxSummary(tx_hash, tx.fee, has_ui))
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def unordered_UTXOs(self, hashX):
|
||||||
|
"""Return an unordered list of UTXO named tuples from mempool
|
||||||
|
transactions that pay to hashX.
|
||||||
|
|
||||||
|
This does not consider if any other mempool transactions spend
|
||||||
|
the outputs.
|
||||||
|
"""
|
||||||
|
utxos = []
|
||||||
|
for tx_hash in self.hashXs.get(hashX, ()):
|
||||||
|
tx = self.txs.get(tx_hash)
|
||||||
|
for pos, (hX, value) in enumerate(tx.out_pairs):
|
||||||
|
if hX == hashX:
|
||||||
|
utxos.append(UTXO(-1, pos, tx_hash, 0, value))
|
||||||
|
return utxos
|
253
torba/torba/server/merkle.py
Normal file
253
torba/torba/server/merkle.py
Normal file
|
@ -0,0 +1,253 @@
|
||||||
|
# Copyright (c) 2018, Neil Booth
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# The MIT License (MIT)
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
# a copy of this software and associated documentation files (the
|
||||||
|
# "Software"), to deal in the Software without restriction, including
|
||||||
|
# without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
# permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
# the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be
|
||||||
|
# included in all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
# and warranty status of this software.
|
||||||
|
|
||||||
|
"""Merkle trees, branches, proofs and roots."""
|
||||||
|
|
||||||
|
from asyncio import Event
|
||||||
|
from math import ceil, log
|
||||||
|
|
||||||
|
from torba.server.hash import double_sha256
|
||||||
|
|
||||||
|
|
||||||
|
class Merkle:
|
||||||
|
"""Perform merkle tree calculations on binary hashes using a given hash
|
||||||
|
function.
|
||||||
|
|
||||||
|
If the hash count is not even, the final hash is repeated when
|
||||||
|
calculating the next merkle layer up the tree.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hash_func=double_sha256):
|
||||||
|
self.hash_func = hash_func
|
||||||
|
|
||||||
|
def tree_depth(self, hash_count):
|
||||||
|
return self.branch_length(hash_count) + 1
|
||||||
|
|
||||||
|
def branch_length(self, hash_count):
|
||||||
|
"""Return the length of a merkle branch given the number of hashes."""
|
||||||
|
if not isinstance(hash_count, int):
|
||||||
|
raise TypeError('hash_count must be an integer')
|
||||||
|
if hash_count < 1:
|
||||||
|
raise ValueError('hash_count must be at least 1')
|
||||||
|
return ceil(log(hash_count, 2))
|
||||||
|
|
||||||
|
def branch_and_root(self, hashes, index, length=None):
|
||||||
|
"""Return a (merkle branch, merkle_root) pair given hashes, and the
|
||||||
|
index of one of those hashes.
|
||||||
|
"""
|
||||||
|
hashes = list(hashes)
|
||||||
|
if not isinstance(index, int):
|
||||||
|
raise TypeError('index must be an integer')
|
||||||
|
# This also asserts hashes is not empty
|
||||||
|
if not 0 <= index < len(hashes):
|
||||||
|
raise ValueError('index out of range')
|
||||||
|
natural_length = self.branch_length(len(hashes))
|
||||||
|
if length is None:
|
||||||
|
length = natural_length
|
||||||
|
else:
|
||||||
|
if not isinstance(length, int):
|
||||||
|
raise TypeError('length must be an integer')
|
||||||
|
if length < natural_length:
|
||||||
|
raise ValueError('length out of range')
|
||||||
|
|
||||||
|
hash_func = self.hash_func
|
||||||
|
branch = []
|
||||||
|
for _ in range(length):
|
||||||
|
if len(hashes) & 1:
|
||||||
|
hashes.append(hashes[-1])
|
||||||
|
branch.append(hashes[index ^ 1])
|
||||||
|
index >>= 1
|
||||||
|
hashes = [hash_func(hashes[n] + hashes[n + 1])
|
||||||
|
for n in range(0, len(hashes), 2)]
|
||||||
|
|
||||||
|
return branch, hashes[0]
|
||||||
|
|
||||||
|
def root(self, hashes, length=None):
|
||||||
|
"""Return the merkle root of a non-empty iterable of binary hashes."""
|
||||||
|
branch, root = self.branch_and_root(hashes, 0, length)
|
||||||
|
return root
|
||||||
|
|
||||||
|
def root_from_proof(self, hash, branch, index):
|
||||||
|
"""Return the merkle root given a hash, a merkle branch to it, and
|
||||||
|
its index in the hashes array.
|
||||||
|
|
||||||
|
branch is an iterable sorted deepest to shallowest. If the
|
||||||
|
returned root is the expected value then the merkle proof is
|
||||||
|
verified.
|
||||||
|
|
||||||
|
The caller should have confirmed the length of the branch with
|
||||||
|
branch_length(). Unfortunately this is not easily done for
|
||||||
|
bitcoin transactions as the number of transactions in a block
|
||||||
|
is unknown to an SPV client.
|
||||||
|
"""
|
||||||
|
hash_func = self.hash_func
|
||||||
|
for elt in branch:
|
||||||
|
if index & 1:
|
||||||
|
hash = hash_func(elt + hash)
|
||||||
|
else:
|
||||||
|
hash = hash_func(hash + elt)
|
||||||
|
index >>= 1
|
||||||
|
if index:
|
||||||
|
raise ValueError('index out of range for branch')
|
||||||
|
return hash
|
||||||
|
|
||||||
|
def level(self, hashes, depth_higher):
|
||||||
|
"""Return a level of the merkle tree of hashes the given depth
|
||||||
|
higher than the bottom row of the original tree."""
|
||||||
|
size = 1 << depth_higher
|
||||||
|
root = self.root
|
||||||
|
return [root(hashes[n: n + size], depth_higher)
|
||||||
|
for n in range(0, len(hashes), size)]
|
||||||
|
|
||||||
|
def branch_and_root_from_level(self, level, leaf_hashes, index,
|
||||||
|
depth_higher):
|
||||||
|
"""Return a (merkle branch, merkle_root) pair when a merkle-tree has a
|
||||||
|
level cached.
|
||||||
|
|
||||||
|
To maximally reduce the amount of data hashed in computing a
|
||||||
|
markle branch, cache a tree of depth N at level N // 2.
|
||||||
|
|
||||||
|
level is a list of hashes in the middle of the tree (returned
|
||||||
|
by level())
|
||||||
|
|
||||||
|
leaf_hashes are the leaves needed to calculate a partial branch
|
||||||
|
up to level.
|
||||||
|
|
||||||
|
depth_higher is how much higher level is than the leaves of the tree
|
||||||
|
|
||||||
|
index is the index in the full list of hashes of the hash whose
|
||||||
|
merkle branch we want.
|
||||||
|
"""
|
||||||
|
if not isinstance(level, list):
|
||||||
|
raise TypeError("level must be a list")
|
||||||
|
if not isinstance(leaf_hashes, list):
|
||||||
|
raise TypeError("leaf_hashes must be a list")
|
||||||
|
leaf_index = (index >> depth_higher) << depth_higher
|
||||||
|
leaf_branch, leaf_root = self.branch_and_root(
|
||||||
|
leaf_hashes, index - leaf_index, depth_higher)
|
||||||
|
index >>= depth_higher
|
||||||
|
level_branch, root = self.branch_and_root(level, index)
|
||||||
|
# Check last so that we know index is in-range
|
||||||
|
if leaf_root != level[index]:
|
||||||
|
raise ValueError('leaf hashes inconsistent with level')
|
||||||
|
return leaf_branch + level_branch, root
|
||||||
|
|
||||||
|
|
||||||
|
class MerkleCache:
|
||||||
|
"""A cache to calculate merkle branches efficiently."""
|
||||||
|
|
||||||
|
def __init__(self, merkle, source_func):
|
||||||
|
"""Initialise a cache hashes taken from source_func:
|
||||||
|
|
||||||
|
async def source_func(index, count):
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
self.merkle = merkle
|
||||||
|
self.source_func = source_func
|
||||||
|
self.length = 0
|
||||||
|
self.depth_higher = 0
|
||||||
|
self.initialized = Event()
|
||||||
|
|
||||||
|
def _segment_length(self):
|
||||||
|
return 1 << self.depth_higher
|
||||||
|
|
||||||
|
def _leaf_start(self, index):
|
||||||
|
"""Given a level's depth higher and a hash index, return the leaf
|
||||||
|
index and leaf hash count needed to calculate a merkle branch.
|
||||||
|
"""
|
||||||
|
depth_higher = self.depth_higher
|
||||||
|
return (index >> depth_higher) << depth_higher
|
||||||
|
|
||||||
|
def _level(self, hashes):
|
||||||
|
return self.merkle.level(hashes, self.depth_higher)
|
||||||
|
|
||||||
|
async def _extend_to(self, length):
|
||||||
|
"""Extend the length of the cache if necessary."""
|
||||||
|
if length <= self.length:
|
||||||
|
return
|
||||||
|
# Start from the beginning of any final partial segment.
|
||||||
|
# Retain the value of depth_higher; in practice this is fine
|
||||||
|
start = self._leaf_start(self.length)
|
||||||
|
hashes = await self.source_func(start, length - start)
|
||||||
|
self.level[start >> self.depth_higher:] = self._level(hashes)
|
||||||
|
self.length = length
|
||||||
|
|
||||||
|
async def _level_for(self, length):
|
||||||
|
"""Return a (level_length, final_hash) pair for a truncation
|
||||||
|
of the hashes to the given length."""
|
||||||
|
if length == self.length:
|
||||||
|
return self.level
|
||||||
|
level = self.level[:length >> self.depth_higher]
|
||||||
|
leaf_start = self._leaf_start(length)
|
||||||
|
count = min(self._segment_length(), length - leaf_start)
|
||||||
|
hashes = await self.source_func(leaf_start, count)
|
||||||
|
level += self._level(hashes)
|
||||||
|
return level
|
||||||
|
|
||||||
|
async def initialize(self, length):
|
||||||
|
"""Call to initialize the cache to a source of given length."""
|
||||||
|
self.length = length
|
||||||
|
self.depth_higher = self.merkle.tree_depth(length) // 2
|
||||||
|
self.level = self._level(await self.source_func(0, length))
|
||||||
|
self.initialized.set()
|
||||||
|
|
||||||
|
def truncate(self, length):
|
||||||
|
"""Truncate the cache so it covers no more than length underlying
|
||||||
|
hashes."""
|
||||||
|
if not isinstance(length, int):
|
||||||
|
raise TypeError('length must be an integer')
|
||||||
|
if length <= 0:
|
||||||
|
raise ValueError('length must be positive')
|
||||||
|
if length >= self.length:
|
||||||
|
return
|
||||||
|
length = self._leaf_start(length)
|
||||||
|
self.length = length
|
||||||
|
self.level[length >> self.depth_higher:] = []
|
||||||
|
|
||||||
|
async def branch_and_root(self, length, index):
|
||||||
|
"""Return a merkle branch and root. Length is the number of
|
||||||
|
hashes used to calculate the merkle root, index is the position
|
||||||
|
of the hash to calculate the branch of.
|
||||||
|
|
||||||
|
index must be less than length, which must be at least 1."""
|
||||||
|
if not isinstance(length, int):
|
||||||
|
raise TypeError('length must be an integer')
|
||||||
|
if not isinstance(index, int):
|
||||||
|
raise TypeError('index must be an integer')
|
||||||
|
if length <= 0:
|
||||||
|
raise ValueError('length must be positive')
|
||||||
|
if index >= length:
|
||||||
|
raise ValueError('index must be less than length')
|
||||||
|
await self.initialized.wait()
|
||||||
|
await self._extend_to(length)
|
||||||
|
leaf_start = self._leaf_start(index)
|
||||||
|
count = min(self._segment_length(), length - leaf_start)
|
||||||
|
leaf_hashes = await self.source_func(leaf_start, count)
|
||||||
|
if length < self._segment_length():
|
||||||
|
return self.merkle.branch_and_root(leaf_hashes, index)
|
||||||
|
level = await self._level_for(length)
|
||||||
|
return self.merkle.branch_and_root_from_level(
|
||||||
|
level, leaf_hashes, index, self.depth_higher)
|
301
torba/torba/server/peer.py
Normal file
301
torba/torba/server/peer.py
Normal file
|
@ -0,0 +1,301 @@
|
||||||
|
# Copyright (c) 2017, Neil Booth
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# The MIT License (MIT)
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
# a copy of this software and associated documentation files (the
|
||||||
|
# "Software"), to deal in the Software without restriction, including
|
||||||
|
# without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
# permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
# the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be
|
||||||
|
# included in all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
|
||||||
|
"""Representation of a peer server."""
|
||||||
|
|
||||||
|
from ipaddress import ip_address
|
||||||
|
|
||||||
|
from torba.server import util
|
||||||
|
from torba.server.util import cachedproperty
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
|
||||||
|
class Peer:
|
||||||
|
|
||||||
|
# Protocol version
|
||||||
|
ATTRS = ('host', 'features',
|
||||||
|
# metadata
|
||||||
|
'source', 'ip_addr',
|
||||||
|
'last_good', 'last_try', 'try_count')
|
||||||
|
FEATURES = ('pruning', 'server_version', 'protocol_min', 'protocol_max',
|
||||||
|
'ssl_port', 'tcp_port')
|
||||||
|
# This should be set by the application
|
||||||
|
DEFAULT_PORTS: Dict[str, int] = {}
|
||||||
|
|
||||||
|
def __init__(self, host, features, source='unknown', ip_addr=None,
|
||||||
|
last_good=0, last_try=0, try_count=0):
|
||||||
|
"""Create a peer given a host name (or IP address as a string),
|
||||||
|
a dictionary of features, and a record of the source."""
|
||||||
|
assert isinstance(host, str)
|
||||||
|
assert isinstance(features, dict)
|
||||||
|
assert host in features.get('hosts', {})
|
||||||
|
self.host = host
|
||||||
|
self.features = features.copy()
|
||||||
|
# Canonicalize / clean-up
|
||||||
|
for feature in self.FEATURES:
|
||||||
|
self.features[feature] = getattr(self, feature)
|
||||||
|
# Metadata
|
||||||
|
self.source = source
|
||||||
|
self.ip_addr = ip_addr
|
||||||
|
# last_good represents the last connection that was
|
||||||
|
# successful *and* successfully verified, at which point
|
||||||
|
# try_count is set to 0. Failure to connect or failure to
|
||||||
|
# verify increment the try_count.
|
||||||
|
self.last_good = last_good
|
||||||
|
self.last_try = last_try
|
||||||
|
self.try_count = try_count
|
||||||
|
# Transient, non-persisted metadata
|
||||||
|
self.bad = False
|
||||||
|
self.other_port_pairs = set()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def peers_from_features(cls, features, source):
|
||||||
|
peers = []
|
||||||
|
if isinstance(features, dict):
|
||||||
|
hosts = features.get('hosts')
|
||||||
|
if isinstance(hosts, dict):
|
||||||
|
peers = [Peer(host, features, source=source)
|
||||||
|
for host in hosts if isinstance(host, str)]
|
||||||
|
return peers
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def deserialize(cls, item):
|
||||||
|
"""Deserialize from a dictionary."""
|
||||||
|
return cls(**item)
|
||||||
|
|
||||||
|
def matches(self, peers):
|
||||||
|
"""Return peers whose host matches our hostname or IP address.
|
||||||
|
Additionally include all peers whose IP address matches our
|
||||||
|
hostname if that is an IP address.
|
||||||
|
"""
|
||||||
|
candidates = (self.host.lower(), self.ip_addr)
|
||||||
|
return [peer for peer in peers
|
||||||
|
if peer.host.lower() in candidates
|
||||||
|
or peer.ip_addr == self.host]
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.host
|
||||||
|
|
||||||
|
def update_features(self, features):
|
||||||
|
"""Update features in-place."""
|
||||||
|
try:
|
||||||
|
tmp = Peer(self.host, features)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
self.update_features_from_peer(tmp)
|
||||||
|
|
||||||
|
def update_features_from_peer(self, peer):
|
||||||
|
if peer != self:
|
||||||
|
self.features = peer.features
|
||||||
|
for feature in self.FEATURES:
|
||||||
|
setattr(self, feature, getattr(peer, feature))
|
||||||
|
|
||||||
|
def connection_port_pairs(self):
|
||||||
|
"""Return a list of (kind, port) pairs to try when making a
|
||||||
|
connection."""
|
||||||
|
# Use a list not a set - it's important to try the registered
|
||||||
|
# ports first.
|
||||||
|
pairs = [('SSL', self.ssl_port), ('TCP', self.tcp_port)]
|
||||||
|
while self.other_port_pairs:
|
||||||
|
pairs.append(self.other_port_pairs.pop())
|
||||||
|
return [pair for pair in pairs if pair[1]]
|
||||||
|
|
||||||
|
def mark_bad(self):
|
||||||
|
"""Mark as bad to avoid reconnects but also to remember for a
|
||||||
|
while."""
|
||||||
|
self.bad = True
|
||||||
|
|
||||||
|
def check_ports(self, other):
|
||||||
|
"""Remember differing ports in case server operator changed them
|
||||||
|
or removed one."""
|
||||||
|
if other.ssl_port != self.ssl_port:
|
||||||
|
self.other_port_pairs.add(('SSL', other.ssl_port))
|
||||||
|
if other.tcp_port != self.tcp_port:
|
||||||
|
self.other_port_pairs.add(('TCP', other.tcp_port))
|
||||||
|
return bool(self.other_port_pairs)
|
||||||
|
|
||||||
|
@cachedproperty
|
||||||
|
def is_tor(self):
|
||||||
|
return self.host.endswith('.onion')
|
||||||
|
|
||||||
|
@cachedproperty
|
||||||
|
def is_valid(self):
|
||||||
|
ip = self.ip_address
|
||||||
|
if ip:
|
||||||
|
return ((ip.is_global or ip.is_private)
|
||||||
|
and not (ip.is_multicast or ip.is_unspecified))
|
||||||
|
return util.is_valid_hostname(self.host)
|
||||||
|
|
||||||
|
@cachedproperty
|
||||||
|
def is_public(self):
|
||||||
|
ip = self.ip_address
|
||||||
|
if ip:
|
||||||
|
return self.is_valid and not ip.is_private
|
||||||
|
else:
|
||||||
|
return self.is_valid and self.host != 'localhost'
|
||||||
|
|
||||||
|
@cachedproperty
|
||||||
|
def ip_address(self):
|
||||||
|
"""The host as a python ip_address object, or None."""
|
||||||
|
try:
|
||||||
|
return ip_address(self.host)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def bucket(self):
|
||||||
|
if self.is_tor:
|
||||||
|
return 'onion'
|
||||||
|
if not self.ip_addr:
|
||||||
|
return ''
|
||||||
|
return tuple(self.ip_addr.split('.')[:2])
|
||||||
|
|
||||||
|
def serialize(self):
|
||||||
|
"""Serialize to a dictionary."""
|
||||||
|
return {attr: getattr(self, attr) for attr in self.ATTRS}
|
||||||
|
|
||||||
|
def _port(self, key):
|
||||||
|
hosts = self.features.get('hosts')
|
||||||
|
if isinstance(hosts, dict):
|
||||||
|
host = hosts.get(self.host)
|
||||||
|
port = self._integer(key, host)
|
||||||
|
if port and 0 < port < 65536:
|
||||||
|
return port
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _integer(self, key, d=None):
|
||||||
|
d = d or self.features
|
||||||
|
result = d.get(key) if isinstance(d, dict) else None
|
||||||
|
if isinstance(result, str):
|
||||||
|
try:
|
||||||
|
result = int(result)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return result if isinstance(result, int) else None
|
||||||
|
|
||||||
|
def _string(self, key):
|
||||||
|
result = self.features.get(key)
|
||||||
|
return result if isinstance(result, str) else None
|
||||||
|
|
||||||
|
@cachedproperty
|
||||||
|
def genesis_hash(self):
|
||||||
|
"""Returns None if no SSL port, otherwise the port as an integer."""
|
||||||
|
return self._string('genesis_hash')
|
||||||
|
|
||||||
|
@cachedproperty
|
||||||
|
def ssl_port(self):
|
||||||
|
"""Returns None if no SSL port, otherwise the port as an integer."""
|
||||||
|
return self._port('ssl_port')
|
||||||
|
|
||||||
|
@cachedproperty
|
||||||
|
def tcp_port(self):
|
||||||
|
"""Returns None if no TCP port, otherwise the port as an integer."""
|
||||||
|
return self._port('tcp_port')
|
||||||
|
|
||||||
|
@cachedproperty
|
||||||
|
def server_version(self):
|
||||||
|
"""Returns the server version as a string if known, otherwise None."""
|
||||||
|
return self._string('server_version')
|
||||||
|
|
||||||
|
@cachedproperty
|
||||||
|
def pruning(self):
|
||||||
|
"""Returns the pruning level as an integer. None indicates no
|
||||||
|
pruning."""
|
||||||
|
pruning = self._integer('pruning')
|
||||||
|
if pruning and pruning > 0:
|
||||||
|
return pruning
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _protocol_version_string(self, key):
|
||||||
|
version_str = self.features.get(key)
|
||||||
|
ptuple = util.protocol_tuple(version_str)
|
||||||
|
return util.version_string(ptuple)
|
||||||
|
|
||||||
|
@cachedproperty
|
||||||
|
def protocol_min(self):
|
||||||
|
"""Minimum protocol version as a string, e.g., 1.0"""
|
||||||
|
return self._protocol_version_string('protocol_min')
|
||||||
|
|
||||||
|
@cachedproperty
|
||||||
|
def protocol_max(self):
|
||||||
|
"""Maximum protocol version as a string, e.g., 1.1"""
|
||||||
|
return self._protocol_version_string('protocol_max')
|
||||||
|
|
||||||
|
def to_tuple(self):
|
||||||
|
"""The tuple ((ip, host, details) expected in response
|
||||||
|
to a peers subscription."""
|
||||||
|
details = self.real_name().split()[1:]
|
||||||
|
return (self.ip_addr or self.host, self.host, details)
|
||||||
|
|
||||||
|
def real_name(self):
|
||||||
|
"""Real name of this peer as used on IRC."""
|
||||||
|
def port_text(letter, port):
|
||||||
|
if port == self.DEFAULT_PORTS.get(letter):
|
||||||
|
return letter
|
||||||
|
else:
|
||||||
|
return letter + str(port)
|
||||||
|
|
||||||
|
parts = [self.host, 'v' + self.protocol_max]
|
||||||
|
if self.pruning:
|
||||||
|
parts.append('p{:d}'.format(self.pruning))
|
||||||
|
for letter, port in (('s', self.ssl_port), ('t', self.tcp_port)):
|
||||||
|
if port:
|
||||||
|
parts.append(port_text(letter, port))
|
||||||
|
return ' '.join(parts)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_real_name(cls, real_name, source):
|
||||||
|
"""Real name is a real name as on IRC, such as
|
||||||
|
|
||||||
|
"erbium1.sytes.net v1.0 s t"
|
||||||
|
|
||||||
|
Returns an instance of this Peer class.
|
||||||
|
"""
|
||||||
|
host = 'nohost'
|
||||||
|
features = {}
|
||||||
|
ports = {}
|
||||||
|
for n, part in enumerate(real_name.split()):
|
||||||
|
if n == 0:
|
||||||
|
host = part
|
||||||
|
continue
|
||||||
|
if part[0] in ('s', 't'):
|
||||||
|
if len(part) == 1:
|
||||||
|
port = cls.DEFAULT_PORTS[part[0]]
|
||||||
|
else:
|
||||||
|
port = part[1:]
|
||||||
|
if part[0] == 's':
|
||||||
|
ports['ssl_port'] = port
|
||||||
|
else:
|
||||||
|
ports['tcp_port'] = port
|
||||||
|
elif part[0] == 'v':
|
||||||
|
features['protocol_max'] = features['protocol_min'] = part[1:]
|
||||||
|
elif part[0] == 'p':
|
||||||
|
features['pruning'] = part[1:]
|
||||||
|
|
||||||
|
features.update(ports)
|
||||||
|
features['hosts'] = {host: ports}
|
||||||
|
|
||||||
|
return cls(host, features, source)
|
505
torba/torba/server/peers.py
Normal file
505
torba/torba/server/peers.py
Normal file
|
@ -0,0 +1,505 @@
|
||||||
|
# Copyright (c) 2017-2018, Neil Booth
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# See the file "LICENCE" for information about the copyright
|
||||||
|
# and warranty status of this software.
|
||||||
|
|
||||||
|
"""Peer management."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import random
|
||||||
|
import socket
|
||||||
|
import ssl
|
||||||
|
import time
|
||||||
|
from asyncio import Event, sleep
|
||||||
|
from collections import defaultdict, Counter
|
||||||
|
|
||||||
|
from torba.tasks import TaskGroup
|
||||||
|
from torba.rpc import (
|
||||||
|
Connector, RPCSession, SOCKSProxy, Notification, handler_invocation,
|
||||||
|
SOCKSError, RPCError
|
||||||
|
)
|
||||||
|
from torba.server.peer import Peer
|
||||||
|
from torba.server.util import class_logger, protocol_tuple
|
||||||
|
|
||||||
|
PEER_GOOD, PEER_STALE, PEER_NEVER, PEER_BAD = range(4)
|
||||||
|
STALE_SECS = 24 * 3600
|
||||||
|
WAKEUP_SECS = 300
|
||||||
|
|
||||||
|
|
||||||
|
class BadPeerError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def assert_good(message, result, instance):
|
||||||
|
if not isinstance(result, instance):
|
||||||
|
raise BadPeerError(f'{message} returned bad result type '
|
||||||
|
f'{type(result).__name__}')
|
||||||
|
|
||||||
|
|
||||||
|
class PeerSession(RPCSession):
|
||||||
|
"""An outgoing session to a peer."""
|
||||||
|
|
||||||
|
async def handle_request(self, request):
|
||||||
|
# We subscribe so might be unlucky enough to get a notification...
|
||||||
|
if (isinstance(request, Notification) and
|
||||||
|
request.method == 'blockchain.headers.subscribe'):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
await handler_invocation(None, request) # Raises
|
||||||
|
|
||||||
|
|
||||||
|
class PeerManager:
|
||||||
|
"""Looks after the DB of peer network servers.
|
||||||
|
|
||||||
|
Attempts to maintain a connection with up to 8 peers.
|
||||||
|
Issues a 'peers.subscribe' RPC to them and tells them our data.
|
||||||
|
"""
|
||||||
|
def __init__(self, env, db):
|
||||||
|
self.logger = class_logger(__name__, self.__class__.__name__)
|
||||||
|
# Initialise the Peer class
|
||||||
|
Peer.DEFAULT_PORTS = env.coin.PEER_DEFAULT_PORTS
|
||||||
|
self.env = env
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
# Our clearnet and Tor Peers, if any
|
||||||
|
sclass = env.coin.SESSIONCLS
|
||||||
|
self.myselves = [Peer(ident.host, sclass.server_features(env), 'env')
|
||||||
|
for ident in env.identities]
|
||||||
|
self.server_version_args = sclass.server_version_args()
|
||||||
|
# Peers have one entry per hostname. Once connected, the
|
||||||
|
# ip_addr property is either None, an onion peer, or the
|
||||||
|
# IP address that was connected to. Adding a peer will evict
|
||||||
|
# any other peers with the same host name or IP address.
|
||||||
|
self.peers = set()
|
||||||
|
self.permit_onion_peer_time = time.time()
|
||||||
|
self.proxy = None
|
||||||
|
self.group = TaskGroup()
|
||||||
|
|
||||||
|
def _my_clearnet_peer(self):
|
||||||
|
"""Returns the clearnet peer representing this server, if any."""
|
||||||
|
clearnet = [peer for peer in self.myselves if not peer.is_tor]
|
||||||
|
return clearnet[0] if clearnet else None
|
||||||
|
|
||||||
|
def _set_peer_statuses(self):
|
||||||
|
"""Set peer statuses."""
|
||||||
|
cutoff = time.time() - STALE_SECS
|
||||||
|
for peer in self.peers:
|
||||||
|
if peer.bad:
|
||||||
|
peer.status = PEER_BAD
|
||||||
|
elif peer.last_good > cutoff:
|
||||||
|
peer.status = PEER_GOOD
|
||||||
|
elif peer.last_good:
|
||||||
|
peer.status = PEER_STALE
|
||||||
|
else:
|
||||||
|
peer.status = PEER_NEVER
|
||||||
|
|
||||||
|
def _features_to_register(self, peer, remote_peers):
|
||||||
|
"""If we should register ourselves to the remote peer, which has
|
||||||
|
reported the given list of known peers, return the clearnet
|
||||||
|
identity features to register, otherwise None.
|
||||||
|
"""
|
||||||
|
# Announce ourself if not present. Don't if disabled, we
|
||||||
|
# are a non-public IP address, or to ourselves.
|
||||||
|
if not self.env.peer_announce or peer in self.myselves:
|
||||||
|
return None
|
||||||
|
my = self._my_clearnet_peer()
|
||||||
|
if not my or not my.is_public:
|
||||||
|
return None
|
||||||
|
# Register if no matches, or ports have changed
|
||||||
|
for peer in my.matches(remote_peers):
|
||||||
|
if peer.tcp_port == my.tcp_port and peer.ssl_port == my.ssl_port:
|
||||||
|
return None
|
||||||
|
return my.features
|
||||||
|
|
||||||
|
def _permit_new_onion_peer(self):
|
||||||
|
"""Accept a new onion peer only once per random time interval."""
|
||||||
|
now = time.time()
|
||||||
|
if now < self.permit_onion_peer_time:
|
||||||
|
return False
|
||||||
|
self.permit_onion_peer_time = now + random.randrange(0, 1200)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _import_peers(self):
|
||||||
|
"""Import hard-coded peers from a file or the coin defaults."""
|
||||||
|
imported_peers = self.myselves.copy()
|
||||||
|
# Add the hard-coded ones unless only reporting ourself
|
||||||
|
if self.env.peer_discovery != self.env.PD_SELF:
|
||||||
|
imported_peers.extend(Peer.from_real_name(real_name, 'coins.py')
|
||||||
|
for real_name in self.env.coin.PEERS)
|
||||||
|
await self._note_peers(imported_peers, limit=None)
|
||||||
|
|
||||||
|
async def _detect_proxy(self):
|
||||||
|
"""Detect a proxy if we don't have one and some time has passed since
|
||||||
|
the last attempt.
|
||||||
|
|
||||||
|
If found self.proxy is set to a SOCKSProxy instance, otherwise
|
||||||
|
None.
|
||||||
|
"""
|
||||||
|
host = self.env.tor_proxy_host
|
||||||
|
if self.env.tor_proxy_port is None:
|
||||||
|
ports = [9050, 9150, 1080]
|
||||||
|
else:
|
||||||
|
ports = [self.env.tor_proxy_port]
|
||||||
|
while True:
|
||||||
|
self.logger.info(f'trying to detect proxy on "{host}" '
|
||||||
|
f'ports {ports}')
|
||||||
|
proxy = await SOCKSProxy.auto_detect_host(host, ports, None)
|
||||||
|
if proxy:
|
||||||
|
self.proxy = proxy
|
||||||
|
self.logger.info(f'detected {proxy}')
|
||||||
|
return
|
||||||
|
self.logger.info('no proxy detected, will try later')
|
||||||
|
await sleep(900)
|
||||||
|
|
||||||
|
async def _note_peers(self, peers, limit=2, check_ports=False,
|
||||||
|
source=None):
|
||||||
|
"""Add a limited number of peers that are not already present."""
|
||||||
|
new_peers = []
|
||||||
|
for peer in peers:
|
||||||
|
if not peer.is_public or (peer.is_tor and not self.proxy):
|
||||||
|
continue
|
||||||
|
|
||||||
|
matches = peer.matches(self.peers)
|
||||||
|
if not matches:
|
||||||
|
new_peers.append(peer)
|
||||||
|
elif check_ports:
|
||||||
|
for match in matches:
|
||||||
|
if match.check_ports(peer):
|
||||||
|
self.logger.info(f'ports changed for {peer}')
|
||||||
|
match.retry_event.set()
|
||||||
|
|
||||||
|
if new_peers:
|
||||||
|
source = source or new_peers[0].source
|
||||||
|
if limit:
|
||||||
|
random.shuffle(new_peers)
|
||||||
|
use_peers = new_peers[:limit]
|
||||||
|
else:
|
||||||
|
use_peers = new_peers
|
||||||
|
for peer in use_peers:
|
||||||
|
self.logger.info(f'accepted new peer {peer} from {source}')
|
||||||
|
peer.retry_event = Event()
|
||||||
|
self.peers.add(peer)
|
||||||
|
await self.group.add(self._monitor_peer(peer))
|
||||||
|
|
||||||
|
async def _monitor_peer(self, peer):
|
||||||
|
# Stop monitoring if we were dropped (a duplicate peer)
|
||||||
|
while peer in self.peers:
|
||||||
|
if await self._should_drop_peer(peer):
|
||||||
|
self.peers.discard(peer)
|
||||||
|
break
|
||||||
|
# Figure out how long to sleep before retrying. Retry a
|
||||||
|
# good connection when it is about to turn stale, otherwise
|
||||||
|
# exponentially back off retries.
|
||||||
|
if peer.try_count == 0:
|
||||||
|
pause = STALE_SECS - WAKEUP_SECS * 2
|
||||||
|
else:
|
||||||
|
pause = WAKEUP_SECS * 2 ** peer.try_count
|
||||||
|
pending, done = await asyncio.wait([peer.retry_event.wait()], timeout=pause)
|
||||||
|
if done:
|
||||||
|
peer.retry_event.clear()
|
||||||
|
|
||||||
|
async def _should_drop_peer(self, peer):
|
||||||
|
peer.try_count += 1
|
||||||
|
is_good = False
|
||||||
|
for kind, port in peer.connection_port_pairs():
|
||||||
|
peer.last_try = time.time()
|
||||||
|
|
||||||
|
kwargs = {}
|
||||||
|
if kind == 'SSL':
|
||||||
|
kwargs['ssl'] = ssl.SSLContext(ssl.PROTOCOL_TLS)
|
||||||
|
|
||||||
|
host = self.env.cs_host(for_rpc=False)
|
||||||
|
if isinstance(host, list):
|
||||||
|
host = host[0]
|
||||||
|
|
||||||
|
if self.env.force_proxy or peer.is_tor:
|
||||||
|
if not self.proxy:
|
||||||
|
return
|
||||||
|
kwargs['proxy'] = self.proxy
|
||||||
|
kwargs['resolve'] = not peer.is_tor
|
||||||
|
elif host:
|
||||||
|
# Use our listening Host/IP for outgoing non-proxy
|
||||||
|
# connections so our peers see the correct source.
|
||||||
|
kwargs['local_addr'] = (host, None)
|
||||||
|
|
||||||
|
peer_text = f'[{peer}:{port} {kind}]'
|
||||||
|
try:
|
||||||
|
async with Connector(PeerSession, peer.host, port,
|
||||||
|
**kwargs) as session:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self._verify_peer(session, peer),
|
||||||
|
120 if peer.is_tor else 30
|
||||||
|
)
|
||||||
|
is_good = True
|
||||||
|
break
|
||||||
|
except BadPeerError as e:
|
||||||
|
self.logger.error(f'{peer_text} marking bad: ({e})')
|
||||||
|
peer.mark_bad()
|
||||||
|
break
|
||||||
|
except RPCError as e:
|
||||||
|
self.logger.error(f'{peer_text} RPC error: {e.message} '
|
||||||
|
f'({e.code})')
|
||||||
|
except (OSError, SOCKSError, ConnectionError, asyncio.TimeoutError) as e:
|
||||||
|
self.logger.info(f'{peer_text} {e}')
|
||||||
|
|
||||||
|
if is_good:
|
||||||
|
now = time.time()
|
||||||
|
elapsed = now - peer.last_try
|
||||||
|
self.logger.info(f'{peer_text} verified in {elapsed:.1f}s')
|
||||||
|
peer.try_count = 0
|
||||||
|
peer.last_good = now
|
||||||
|
peer.source = 'peer'
|
||||||
|
# At most 2 matches if we're a host name, potentially
|
||||||
|
# several if we're an IP address (several instances
|
||||||
|
# can share a NAT).
|
||||||
|
matches = peer.matches(self.peers)
|
||||||
|
for match in matches:
|
||||||
|
if match.ip_address:
|
||||||
|
if len(matches) > 1:
|
||||||
|
self.peers.remove(match)
|
||||||
|
# Force the peer's monitoring task to exit
|
||||||
|
match.retry_event.set()
|
||||||
|
elif peer.host in match.features['hosts']:
|
||||||
|
match.update_features_from_peer(peer)
|
||||||
|
else:
|
||||||
|
# Forget the peer if long-term unreachable
|
||||||
|
if peer.last_good and not peer.bad:
|
||||||
|
try_limit = 10
|
||||||
|
else:
|
||||||
|
try_limit = 3
|
||||||
|
if peer.try_count >= try_limit:
|
||||||
|
desc = 'bad' if peer.bad else 'unreachable'
|
||||||
|
self.logger.info(f'forgetting {desc} peer: {peer}')
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _verify_peer(self, session, peer):
|
||||||
|
if not peer.is_tor:
|
||||||
|
address = session.peer_address()
|
||||||
|
if address:
|
||||||
|
peer.ip_addr = address[0]
|
||||||
|
|
||||||
|
# server.version goes first
|
||||||
|
message = 'server.version'
|
||||||
|
result = await session.send_request(message, self.server_version_args)
|
||||||
|
assert_good(message, result, list)
|
||||||
|
|
||||||
|
# Protocol version 1.1 returns a pair with the version first
|
||||||
|
if len(result) != 2 or not all(isinstance(x, str) for x in result):
|
||||||
|
raise BadPeerError(f'bad server.version result: {result}')
|
||||||
|
server_version, protocol_version = result
|
||||||
|
peer.server_version = server_version
|
||||||
|
peer.features['server_version'] = server_version
|
||||||
|
ptuple = protocol_tuple(protocol_version)
|
||||||
|
|
||||||
|
await asyncio.wait([
|
||||||
|
self._send_headers_subscribe(session, peer, ptuple),
|
||||||
|
self._send_server_features(session, peer),
|
||||||
|
self._send_peers_subscribe(session, peer)
|
||||||
|
])
|
||||||
|
|
||||||
|
async def _send_headers_subscribe(self, session, peer, ptuple):
|
||||||
|
message = 'blockchain.headers.subscribe'
|
||||||
|
result = await session.send_request(message)
|
||||||
|
assert_good(message, result, dict)
|
||||||
|
|
||||||
|
our_height = self.db.db_height
|
||||||
|
if ptuple < (1, 3):
|
||||||
|
their_height = result.get('block_height')
|
||||||
|
else:
|
||||||
|
their_height = result.get('height')
|
||||||
|
if not isinstance(their_height, int):
|
||||||
|
raise BadPeerError(f'invalid height {their_height}')
|
||||||
|
if abs(our_height - their_height) > 5:
|
||||||
|
raise BadPeerError(f'bad height {their_height:,d} '
|
||||||
|
f'(ours: {our_height:,d})')
|
||||||
|
|
||||||
|
# Check prior header too in case of hard fork.
|
||||||
|
check_height = min(our_height, their_height)
|
||||||
|
raw_header = await self.db.raw_header(check_height)
|
||||||
|
if ptuple >= (1, 4):
|
||||||
|
ours = raw_header.hex()
|
||||||
|
message = 'blockchain.block.header'
|
||||||
|
theirs = await session.send_request(message, [check_height])
|
||||||
|
assert_good(message, theirs, str)
|
||||||
|
if ours != theirs:
|
||||||
|
raise BadPeerError(f'our header {ours} and '
|
||||||
|
f'theirs {theirs} differ')
|
||||||
|
else:
|
||||||
|
ours = self.env.coin.electrum_header(raw_header, check_height)
|
||||||
|
ours = ours.get('prev_block_hash')
|
||||||
|
message = 'blockchain.block.get_header'
|
||||||
|
theirs = await session.send_request(message, [check_height])
|
||||||
|
assert_good(message, theirs, dict)
|
||||||
|
theirs = theirs.get('prev_block_hash')
|
||||||
|
if ours != theirs:
|
||||||
|
raise BadPeerError(f'our header hash {ours} and '
|
||||||
|
f'theirs {theirs} differ')
|
||||||
|
|
||||||
|
async def _send_server_features(self, session, peer):
|
||||||
|
message = 'server.features'
|
||||||
|
features = await session.send_request(message)
|
||||||
|
assert_good(message, features, dict)
|
||||||
|
hosts = [host.lower() for host in features.get('hosts', {})]
|
||||||
|
if self.env.coin.GENESIS_HASH != features.get('genesis_hash'):
|
||||||
|
raise BadPeerError('incorrect genesis hash')
|
||||||
|
elif peer.host.lower() in hosts:
|
||||||
|
peer.update_features(features)
|
||||||
|
else:
|
||||||
|
raise BadPeerError(f'not listed in own hosts list {hosts}')
|
||||||
|
|
||||||
|
async def _send_peers_subscribe(self, session, peer):
|
||||||
|
message = 'server.peers.subscribe'
|
||||||
|
raw_peers = await session.send_request(message)
|
||||||
|
assert_good(message, raw_peers, list)
|
||||||
|
|
||||||
|
# Check the peers list we got from a remote peer.
|
||||||
|
# Each is expected to be of the form:
|
||||||
|
# [ip_addr, hostname, ['v1.0', 't51001', 's51002']]
|
||||||
|
# Call add_peer if the remote doesn't appear to know about us.
|
||||||
|
try:
|
||||||
|
real_names = [' '.join([u[1]] + u[2]) for u in raw_peers]
|
||||||
|
peers = [Peer.from_real_name(real_name, str(peer))
|
||||||
|
for real_name in real_names]
|
||||||
|
except Exception:
|
||||||
|
raise BadPeerError('bad server.peers.subscribe response')
|
||||||
|
|
||||||
|
await self._note_peers(peers)
|
||||||
|
features = self._features_to_register(peer, peers)
|
||||||
|
if not features:
|
||||||
|
return
|
||||||
|
self.logger.info(f'registering ourself with {peer}')
|
||||||
|
# We only care to wait for the response
|
||||||
|
await session.send_request('server.add_peer', [features])
|
||||||
|
|
||||||
|
#
|
||||||
|
# External interface
|
||||||
|
#
|
||||||
|
async def discover_peers(self):
|
||||||
|
"""Perform peer maintenance. This includes
|
||||||
|
|
||||||
|
1) Forgetting unreachable peers.
|
||||||
|
2) Verifying connectivity of new peers.
|
||||||
|
3) Retrying old peers at regular intervals.
|
||||||
|
"""
|
||||||
|
if self.env.peer_discovery != self.env.PD_ON:
|
||||||
|
self.logger.info('peer discovery is disabled')
|
||||||
|
return
|
||||||
|
|
||||||
|
self.logger.info(f'beginning peer discovery. Force use of '
|
||||||
|
f'proxy: {self.env.force_proxy}')
|
||||||
|
|
||||||
|
self.group.add(self._detect_proxy())
|
||||||
|
self.group.add(self._import_peers())
|
||||||
|
|
||||||
|
def info(self):
|
||||||
|
"""The number of peers."""
|
||||||
|
self._set_peer_statuses()
|
||||||
|
counter = Counter(peer.status for peer in self.peers)
|
||||||
|
return {
|
||||||
|
'bad': counter[PEER_BAD],
|
||||||
|
'good': counter[PEER_GOOD],
|
||||||
|
'never': counter[PEER_NEVER],
|
||||||
|
'stale': counter[PEER_STALE],
|
||||||
|
'total': len(self.peers),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def add_localRPC_peer(self, real_name):
|
||||||
|
"""Add a peer passed by the admin over LocalRPC."""
|
||||||
|
await self._note_peers([Peer.from_real_name(real_name, 'RPC')])
|
||||||
|
|
||||||
|
async def on_add_peer(self, features, source_info):
|
||||||
|
"""Add a peer (but only if the peer resolves to the source)."""
|
||||||
|
if not source_info:
|
||||||
|
self.logger.info('ignored add_peer request: no source info')
|
||||||
|
return False
|
||||||
|
source = source_info[0]
|
||||||
|
peers = Peer.peers_from_features(features, source)
|
||||||
|
if not peers:
|
||||||
|
self.logger.info('ignored add_peer request: no peers given')
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Just look at the first peer, require it
|
||||||
|
peer = peers[0]
|
||||||
|
host = peer.host
|
||||||
|
if peer.is_tor:
|
||||||
|
permit = self._permit_new_onion_peer()
|
||||||
|
reason = 'rate limiting'
|
||||||
|
else:
|
||||||
|
getaddrinfo = asyncio.get_event_loop().getaddrinfo
|
||||||
|
try:
|
||||||
|
infos = await getaddrinfo(host, 80, type=socket.SOCK_STREAM)
|
||||||
|
except socket.gaierror:
|
||||||
|
permit = False
|
||||||
|
reason = 'address resolution failure'
|
||||||
|
else:
|
||||||
|
permit = any(source == info[-1][0] for info in infos)
|
||||||
|
reason = 'source-destination mismatch'
|
||||||
|
|
||||||
|
if permit:
|
||||||
|
self.logger.info(f'accepted add_peer request from {source} '
|
||||||
|
f'for {host}')
|
||||||
|
await self._note_peers([peer], check_ports=True)
|
||||||
|
else:
|
||||||
|
self.logger.warning(f'rejected add_peer request from {source} '
|
||||||
|
f'for {host} ({reason})')
|
||||||
|
|
||||||
|
return permit
|
||||||
|
|
||||||
|
def on_peers_subscribe(self, is_tor):
|
||||||
|
"""Returns the server peers as a list of (ip, host, details) tuples.
|
||||||
|
|
||||||
|
We return all peers we've connected to in the last day.
|
||||||
|
Additionally, if we don't have onion routing, we return a few
|
||||||
|
hard-coded onion servers.
|
||||||
|
"""
|
||||||
|
cutoff = time.time() - STALE_SECS
|
||||||
|
recent = [peer for peer in self.peers
|
||||||
|
if peer.last_good > cutoff and
|
||||||
|
not peer.bad and peer.is_public]
|
||||||
|
onion_peers = []
|
||||||
|
|
||||||
|
# Always report ourselves if valid (even if not public)
|
||||||
|
peers = set(myself for myself in self.myselves
|
||||||
|
if myself.last_good > cutoff)
|
||||||
|
|
||||||
|
# Bucket the clearnet peers and select up to two from each
|
||||||
|
buckets = defaultdict(list)
|
||||||
|
for peer in recent:
|
||||||
|
if peer.is_tor:
|
||||||
|
onion_peers.append(peer)
|
||||||
|
else:
|
||||||
|
buckets[peer.bucket()].append(peer)
|
||||||
|
for bucket_peers in buckets.values():
|
||||||
|
random.shuffle(bucket_peers)
|
||||||
|
peers.update(bucket_peers[:2])
|
||||||
|
|
||||||
|
# Add up to 20% onion peers (but up to 10 is OK anyway)
|
||||||
|
random.shuffle(onion_peers)
|
||||||
|
max_onion = 50 if is_tor else max(10, len(peers) // 4)
|
||||||
|
|
||||||
|
peers.update(onion_peers[:max_onion])
|
||||||
|
|
||||||
|
return [peer.to_tuple() for peer in peers]
|
||||||
|
|
||||||
|
def proxy_peername(self):
|
||||||
|
"""Return the peername of the proxy, if there is a proxy, otherwise
|
||||||
|
None."""
|
||||||
|
return self.proxy.peername if self.proxy else None
|
||||||
|
|
||||||
|
def rpc_data(self):
|
||||||
|
"""Peer data for the peers RPC method."""
|
||||||
|
self._set_peer_statuses()
|
||||||
|
descs = ['good', 'stale', 'never', 'bad']
|
||||||
|
|
||||||
|
def peer_data(peer):
|
||||||
|
data = peer.serialize()
|
||||||
|
data['status'] = descs[peer.status]
|
||||||
|
return data
|
||||||
|
|
||||||
|
def peer_key(peer):
|
||||||
|
return (peer.bad, -peer.last_good)
|
||||||
|
|
||||||
|
return [peer_data(peer) for peer in sorted(self.peers, key=peer_key)]
|
251
torba/torba/server/script.py
Normal file
251
torba/torba/server/script.py
Normal file
|
@ -0,0 +1,251 @@
|
||||||
|
# Copyright (c) 2016-2017, Neil Booth
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# The MIT License (MIT)
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
# a copy of this software and associated documentation files (the
|
||||||
|
# "Software"), to deal in the Software without restriction, including
|
||||||
|
# without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
# permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
# the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be
|
||||||
|
# included in all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
# and warranty status of this software.
|
||||||
|
|
||||||
|
"""Script-related classes and functions."""
|
||||||
|
|
||||||
|
|
||||||
|
import struct
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
|
from torba.server.enum import Enumeration
|
||||||
|
from torba.server.hash import hash160
|
||||||
|
from torba.server.util import unpack_le_uint16_from, unpack_le_uint32_from, \
|
||||||
|
pack_le_uint16, pack_le_uint32
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptError(Exception):
|
||||||
|
"""Exception used for script errors."""
|
||||||
|
|
||||||
|
|
||||||
|
OpCodes = Enumeration("Opcodes", [
|
||||||
|
("OP_0", 0), ("OP_PUSHDATA1", 76),
|
||||||
|
"OP_PUSHDATA2", "OP_PUSHDATA4", "OP_1NEGATE",
|
||||||
|
"OP_RESERVED",
|
||||||
|
"OP_1", "OP_2", "OP_3", "OP_4", "OP_5", "OP_6", "OP_7", "OP_8",
|
||||||
|
"OP_9", "OP_10", "OP_11", "OP_12", "OP_13", "OP_14", "OP_15", "OP_16",
|
||||||
|
"OP_NOP", "OP_VER", "OP_IF", "OP_NOTIF", "OP_VERIF", "OP_VERNOTIF",
|
||||||
|
"OP_ELSE", "OP_ENDIF", "OP_VERIFY", "OP_RETURN",
|
||||||
|
"OP_TOALTSTACK", "OP_FROMALTSTACK", "OP_2DROP", "OP_2DUP", "OP_3DUP",
|
||||||
|
"OP_2OVER", "OP_2ROT", "OP_2SWAP", "OP_IFDUP", "OP_DEPTH", "OP_DROP",
|
||||||
|
"OP_DUP", "OP_NIP", "OP_OVER", "OP_PICK", "OP_ROLL", "OP_ROT",
|
||||||
|
"OP_SWAP", "OP_TUCK",
|
||||||
|
"OP_CAT", "OP_SUBSTR", "OP_LEFT", "OP_RIGHT", "OP_SIZE",
|
||||||
|
"OP_INVERT", "OP_AND", "OP_OR", "OP_XOR", "OP_EQUAL", "OP_EQUALVERIFY",
|
||||||
|
"OP_RESERVED1", "OP_RESERVED2",
|
||||||
|
"OP_1ADD", "OP_1SUB", "OP_2MUL", "OP_2DIV", "OP_NEGATE", "OP_ABS",
|
||||||
|
"OP_NOT", "OP_0NOTEQUAL", "OP_ADD", "OP_SUB", "OP_MUL", "OP_DIV", "OP_MOD",
|
||||||
|
"OP_LSHIFT", "OP_RSHIFT", "OP_BOOLAND", "OP_BOOLOR", "OP_NUMEQUAL",
|
||||||
|
"OP_NUMEQUALVERIFY", "OP_NUMNOTEQUAL", "OP_LESSTHAN", "OP_GREATERTHAN",
|
||||||
|
"OP_LESSTHANOREQUAL", "OP_GREATERTHANOREQUAL", "OP_MIN", "OP_MAX",
|
||||||
|
"OP_WITHIN",
|
||||||
|
"OP_RIPEMD160", "OP_SHA1", "OP_SHA256", "OP_HASH160", "OP_HASH256",
|
||||||
|
"OP_CODESEPARATOR", "OP_CHECKSIG", "OP_CHECKSIGVERIFY", "OP_CHECKMULTISIG",
|
||||||
|
"OP_CHECKMULTISIGVERIFY",
|
||||||
|
"OP_NOP1",
|
||||||
|
"OP_CHECKLOCKTIMEVERIFY", "OP_CHECKSEQUENCEVERIFY"
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
# Paranoia to make it hard to create bad scripts
|
||||||
|
assert OpCodes.OP_DUP == 0x76
|
||||||
|
assert OpCodes.OP_HASH160 == 0xa9
|
||||||
|
assert OpCodes.OP_EQUAL == 0x87
|
||||||
|
assert OpCodes.OP_EQUALVERIFY == 0x88
|
||||||
|
assert OpCodes.OP_CHECKSIG == 0xac
|
||||||
|
assert OpCodes.OP_CHECKMULTISIG == 0xae
|
||||||
|
|
||||||
|
|
||||||
|
def _match_ops(ops, pattern):
|
||||||
|
if len(ops) != len(pattern):
|
||||||
|
return False
|
||||||
|
for op, pop in zip(ops, pattern):
|
||||||
|
if pop != op:
|
||||||
|
# -1 means 'data push', whose op is an (op, data) tuple
|
||||||
|
if pop == -1 and isinstance(op, tuple):
|
||||||
|
continue
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptPubKey:
|
||||||
|
"""A class for handling a tx output script that gives conditions
|
||||||
|
necessary for spending.
|
||||||
|
"""
|
||||||
|
|
||||||
|
TO_ADDRESS_OPS = [OpCodes.OP_DUP, OpCodes.OP_HASH160, -1,
|
||||||
|
OpCodes.OP_EQUALVERIFY, OpCodes.OP_CHECKSIG]
|
||||||
|
TO_P2SH_OPS = [OpCodes.OP_HASH160, -1, OpCodes.OP_EQUAL]
|
||||||
|
TO_PUBKEY_OPS = [-1, OpCodes.OP_CHECKSIG]
|
||||||
|
|
||||||
|
PayToHandlers = namedtuple('PayToHandlers', 'address script_hash pubkey '
|
||||||
|
'unspendable strange')
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pay_to(cls, handlers, script):
|
||||||
|
"""Parse a script, invoke the appropriate handler and
|
||||||
|
return the result.
|
||||||
|
|
||||||
|
One of the following handlers is invoked:
|
||||||
|
handlers.address(hash160)
|
||||||
|
handlers.script_hash(hash160)
|
||||||
|
handlers.pubkey(pubkey)
|
||||||
|
handlers.unspendable()
|
||||||
|
handlers.strange(script)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
ops = Script.get_ops(script)
|
||||||
|
except ScriptError:
|
||||||
|
return handlers.unspendable()
|
||||||
|
|
||||||
|
match = _match_ops
|
||||||
|
|
||||||
|
if match(ops, cls.TO_ADDRESS_OPS):
|
||||||
|
return handlers.address(ops[2][-1])
|
||||||
|
if match(ops, cls.TO_P2SH_OPS):
|
||||||
|
return handlers.script_hash(ops[1][-1])
|
||||||
|
if match(ops, cls.TO_PUBKEY_OPS):
|
||||||
|
return handlers.pubkey(ops[0][-1])
|
||||||
|
if ops and ops[0] == OpCodes.OP_RETURN:
|
||||||
|
return handlers.unspendable()
|
||||||
|
return handlers.strange(script)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def P2SH_script(cls, hash160):
|
||||||
|
return (bytes([OpCodes.OP_HASH160])
|
||||||
|
+ Script.push_data(hash160)
|
||||||
|
+ bytes([OpCodes.OP_EQUAL]))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def P2PKH_script(cls, hash160):
|
||||||
|
return (bytes([OpCodes.OP_DUP, OpCodes.OP_HASH160])
|
||||||
|
+ Script.push_data(hash160)
|
||||||
|
+ bytes([OpCodes.OP_EQUALVERIFY, OpCodes.OP_CHECKSIG]))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_pubkey(cls, pubkey, req_compressed=False):
|
||||||
|
if isinstance(pubkey, (bytes, bytearray)):
|
||||||
|
if len(pubkey) == 33 and pubkey[0] in (2, 3):
|
||||||
|
return # Compressed
|
||||||
|
if len(pubkey) == 65 and pubkey[0] == 4:
|
||||||
|
if not req_compressed:
|
||||||
|
return
|
||||||
|
raise PubKeyError('uncompressed pubkeys are invalid')
|
||||||
|
raise PubKeyError('invalid pubkey {}'.format(pubkey))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pubkey_script(cls, pubkey):
|
||||||
|
cls.validate_pubkey(pubkey)
|
||||||
|
return Script.push_data(pubkey) + bytes([OpCodes.OP_CHECKSIG])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def multisig_script(cls, m, pubkeys):
|
||||||
|
"""Returns the script for a pay-to-multisig transaction."""
|
||||||
|
n = len(pubkeys)
|
||||||
|
if not 1 <= m <= n <= 15:
|
||||||
|
raise ScriptError('{:d} of {:d} multisig script not possible'
|
||||||
|
.format(m, n))
|
||||||
|
for pubkey in pubkeys:
|
||||||
|
cls.validate_pubkey(pubkey, req_compressed=True)
|
||||||
|
# See https://bitcoin.org/en/developer-guide
|
||||||
|
# 2 of 3 is: OP_2 pubkey1 pubkey2 pubkey3 OP_3 OP_CHECKMULTISIG
|
||||||
|
return (bytes([OP_1 + m - 1])
|
||||||
|
+ b''.join(cls.push_data(pubkey) for pubkey in pubkeys)
|
||||||
|
+ bytes([OP_1 + n - 1, OP_CHECK_MULTISIG]))
|
||||||
|
|
||||||
|
|
||||||
|
class Script:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_ops(cls, script):
|
||||||
|
ops = []
|
||||||
|
|
||||||
|
# The unpacks or script[n] below throw on truncated scripts
|
||||||
|
try:
|
||||||
|
n = 0
|
||||||
|
while n < len(script):
|
||||||
|
op = script[n]
|
||||||
|
n += 1
|
||||||
|
|
||||||
|
if op <= OpCodes.OP_PUSHDATA4:
|
||||||
|
# Raw bytes follow
|
||||||
|
if op < OpCodes.OP_PUSHDATA1:
|
||||||
|
dlen = op
|
||||||
|
elif op == OpCodes.OP_PUSHDATA1:
|
||||||
|
dlen = script[n]
|
||||||
|
n += 1
|
||||||
|
elif op == OpCodes.OP_PUSHDATA2:
|
||||||
|
dlen, = unpack_le_uint16_from(script[n: n + 2])
|
||||||
|
n += 2
|
||||||
|
else:
|
||||||
|
dlen, = unpack_le_uint32_from(script[n: n + 4])
|
||||||
|
n += 4
|
||||||
|
if n + dlen > len(script):
|
||||||
|
raise IndexError
|
||||||
|
op = (op, script[n:n + dlen])
|
||||||
|
n += dlen
|
||||||
|
|
||||||
|
ops.append(op)
|
||||||
|
except Exception:
|
||||||
|
# Truncated script; e.g. tx_hash
|
||||||
|
# ebc9fa1196a59e192352d76c0f6e73167046b9d37b8302b6bb6968dfd279b767
|
||||||
|
raise ScriptError('truncated script')
|
||||||
|
|
||||||
|
return ops
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def push_data(cls, data):
|
||||||
|
"""Returns the opcodes to push the data on the stack."""
|
||||||
|
assert isinstance(data, (bytes, bytearray))
|
||||||
|
|
||||||
|
n = len(data)
|
||||||
|
if n < OpCodes.OP_PUSHDATA1:
|
||||||
|
return bytes([n]) + data
|
||||||
|
if n < 256:
|
||||||
|
return bytes([OpCodes.OP_PUSHDATA1, n]) + data
|
||||||
|
if n < 65536:
|
||||||
|
return bytes([OpCodes.OP_PUSHDATA2]) + pack_le_uint16(n) + data
|
||||||
|
return bytes([OpCodes.OP_PUSHDATA4]) + pack_le_uint32(n) + data
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def opcode_name(cls, opcode):
|
||||||
|
if OpCodes.OP_0 < opcode < OpCodes.OP_PUSHDATA1:
|
||||||
|
return 'OP_{:d}'.format(opcode)
|
||||||
|
try:
|
||||||
|
return OpCodes.whatis(opcode)
|
||||||
|
except KeyError:
|
||||||
|
return 'OP_UNKNOWN:{:d}'.format(opcode)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dump(cls, script):
|
||||||
|
opcodes, datas = cls.get_ops(script)
|
||||||
|
for opcode, data in zip(opcodes, datas):
|
||||||
|
name = cls.opcode_name(opcode)
|
||||||
|
if data is None:
|
||||||
|
print(name)
|
||||||
|
else:
|
||||||
|
print('{} {} ({:d} bytes)'
|
||||||
|
.format(name, data.hex(), len(data)))
|
134
torba/torba/server/server.py
Normal file
134
torba/torba/server/server.py
Normal file
|
@ -0,0 +1,134 @@
|
||||||
|
import signal
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
from concurrent.futures.thread import ThreadPoolExecutor
|
||||||
|
|
||||||
|
import torba
|
||||||
|
from torba.server.mempool import MemPool, MemPoolAPI
|
||||||
|
from torba.server.session import SessionManager
|
||||||
|
|
||||||
|
|
||||||
|
class Notifications:
|
||||||
|
# hashX notifications come from two sources: new blocks and
|
||||||
|
# mempool refreshes.
|
||||||
|
#
|
||||||
|
# A user with a pending transaction is notified after the block it
|
||||||
|
# gets in is processed. Block processing can take an extended
|
||||||
|
# time, and the prefetcher might poll the daemon after the mempool
|
||||||
|
# code in any case. In such cases the transaction will not be in
|
||||||
|
# the mempool after the mempool refresh. We want to avoid
|
||||||
|
# notifying clients twice - for the mempool refresh and when the
|
||||||
|
# block is done. This object handles that logic by deferring
|
||||||
|
# notifications appropriately.
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._touched_mp = {}
|
||||||
|
self._touched_bp = {}
|
||||||
|
self._highest_block = -1
|
||||||
|
|
||||||
|
async def _maybe_notify(self):
|
||||||
|
tmp, tbp = self._touched_mp, self._touched_bp
|
||||||
|
common = set(tmp).intersection(tbp)
|
||||||
|
if common:
|
||||||
|
height = max(common)
|
||||||
|
elif tmp and max(tmp) == self._highest_block:
|
||||||
|
height = self._highest_block
|
||||||
|
else:
|
||||||
|
# Either we are processing a block and waiting for it to
|
||||||
|
# come in, or we have not yet had a mempool update for the
|
||||||
|
# new block height
|
||||||
|
return
|
||||||
|
touched = tmp.pop(height)
|
||||||
|
for old in [h for h in tmp if h <= height]:
|
||||||
|
del tmp[old]
|
||||||
|
for old in [h for h in tbp if h <= height]:
|
||||||
|
touched.update(tbp.pop(old))
|
||||||
|
await self.notify(height, touched)
|
||||||
|
|
||||||
|
async def notify(self, height, touched):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def start(self, height, notify_func):
|
||||||
|
self._highest_block = height
|
||||||
|
self.notify = notify_func
|
||||||
|
await self.notify(height, set())
|
||||||
|
|
||||||
|
async def on_mempool(self, touched, height):
|
||||||
|
self._touched_mp[height] = touched
|
||||||
|
await self._maybe_notify()
|
||||||
|
|
||||||
|
async def on_block(self, touched, height):
|
||||||
|
self._touched_bp[height] = touched
|
||||||
|
self._highest_block = height
|
||||||
|
await self._maybe_notify()
|
||||||
|
|
||||||
|
|
||||||
|
class Server:
|
||||||
|
|
||||||
|
def __init__(self, env):
|
||||||
|
self.env = env
|
||||||
|
self.log = logging.getLogger(__name__).getChild(self.__class__.__name__)
|
||||||
|
self.shutdown_event = asyncio.Event()
|
||||||
|
self.cancellable_tasks = []
|
||||||
|
|
||||||
|
self.notifications = notifications = Notifications()
|
||||||
|
self.daemon = daemon = env.coin.DAEMON(env.coin, env.daemon_url)
|
||||||
|
self.db = db = env.coin.DB(env)
|
||||||
|
self.bp = bp = env.coin.BLOCK_PROCESSOR(env, db, daemon, notifications)
|
||||||
|
|
||||||
|
# Set notifications up to implement the MemPoolAPI
|
||||||
|
notifications.height = daemon.height
|
||||||
|
notifications.cached_height = daemon.cached_height
|
||||||
|
notifications.mempool_hashes = daemon.mempool_hashes
|
||||||
|
notifications.raw_transactions = daemon.getrawtransactions
|
||||||
|
notifications.lookup_utxos = db.lookup_utxos
|
||||||
|
MemPoolAPI.register(Notifications)
|
||||||
|
self.mempool = mempool = MemPool(env.coin, notifications)
|
||||||
|
|
||||||
|
self.session_mgr = SessionManager(
|
||||||
|
env, db, bp, daemon, mempool, self.shutdown_event
|
||||||
|
)
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
env = self.env
|
||||||
|
min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings()
|
||||||
|
self.log.info(f'software version: {torba.__version__}')
|
||||||
|
self.log.info(f'supported protocol versions: {min_str}-{max_str}')
|
||||||
|
self.log.info(f'event loop policy: {env.loop_policy}')
|
||||||
|
self.log.info(f'reorg limit is {env.reorg_limit:,d} blocks')
|
||||||
|
|
||||||
|
await self.daemon.height()
|
||||||
|
|
||||||
|
def _start_cancellable(run, *args):
|
||||||
|
_flag = asyncio.Event()
|
||||||
|
self.cancellable_tasks.append(asyncio.ensure_future(run(*args, _flag)))
|
||||||
|
return _flag.wait()
|
||||||
|
|
||||||
|
await _start_cancellable(self.bp.fetch_and_process_blocks)
|
||||||
|
await self.db.populate_header_merkle_cache()
|
||||||
|
await _start_cancellable(self.mempool.keep_synchronized)
|
||||||
|
await _start_cancellable(self.session_mgr.serve, self.notifications)
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
for task in reversed(self.cancellable_tasks):
|
||||||
|
task.cancel()
|
||||||
|
await asyncio.wait(self.cancellable_tasks)
|
||||||
|
self.shutdown_event.set()
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
executor = ThreadPoolExecutor(1)
|
||||||
|
loop.set_default_executor(executor)
|
||||||
|
|
||||||
|
def __exit():
|
||||||
|
raise SystemExit()
|
||||||
|
try:
|
||||||
|
loop.add_signal_handler(signal.SIGINT, __exit)
|
||||||
|
loop.add_signal_handler(signal.SIGTERM, __exit)
|
||||||
|
loop.run_until_complete(self.start())
|
||||||
|
loop.run_until_complete(self.shutdown_event.wait())
|
||||||
|
except (SystemExit, KeyboardInterrupt):
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
loop.run_until_complete(self.stop())
|
||||||
|
executor.shutdown(True)
|
1437
torba/torba/server/session.py
Normal file
1437
torba/torba/server/session.py
Normal file
File diff suppressed because it is too large
Load diff
166
torba/torba/server/storage.py
Normal file
166
torba/torba/server/storage.py
Normal file
|
@ -0,0 +1,166 @@
|
||||||
|
# Copyright (c) 2016-2017, the ElectrumX authors
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# See the file "LICENCE" for information about the copyright
|
||||||
|
# and warranty status of this software.
|
||||||
|
|
||||||
|
"""Backend database abstraction."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from torba.server import util
|
||||||
|
|
||||||
|
|
||||||
|
def db_class(name):
|
||||||
|
"""Returns a DB engine class."""
|
||||||
|
for db_class in util.subclasses(Storage):
|
||||||
|
if db_class.__name__.lower() == name.lower():
|
||||||
|
db_class.import_module()
|
||||||
|
return db_class
|
||||||
|
raise RuntimeError('unrecognised DB engine "{}"'.format(name))
|
||||||
|
|
||||||
|
|
||||||
|
class Storage:
|
||||||
|
"""Abstract base class of the DB backend abstraction."""
|
||||||
|
|
||||||
|
def __init__(self, name, for_sync):
|
||||||
|
self.is_new = not os.path.exists(name)
|
||||||
|
self.for_sync = for_sync or self.is_new
|
||||||
|
self.open(name, create=self.is_new)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def import_module(cls):
|
||||||
|
"""Import the DB engine module."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def open(self, name, create):
|
||||||
|
"""Open an existing database or create a new one."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""Close an existing database."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get(self, key):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def put(self, key, value):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def write_batch(self):
|
||||||
|
"""Return a context manager that provides `put` and `delete`.
|
||||||
|
|
||||||
|
Changes should only be committed when the context manager
|
||||||
|
closes without an exception.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def iterator(self, prefix=b'', reverse=False):
|
||||||
|
"""Return an iterator that yields (key, value) pairs from the
|
||||||
|
database sorted by key.
|
||||||
|
|
||||||
|
If `prefix` is set, only keys starting with `prefix` will be
|
||||||
|
included. If `reverse` is True the items are returned in
|
||||||
|
reverse order.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class LevelDB(Storage):
|
||||||
|
"""LevelDB database engine."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def import_module(cls):
|
||||||
|
import plyvel
|
||||||
|
cls.module = plyvel
|
||||||
|
|
||||||
|
def open(self, name, create):
|
||||||
|
mof = 512 if self.for_sync else 128
|
||||||
|
# Use snappy compression (the default)
|
||||||
|
self.db = self.module.DB(name, create_if_missing=create,
|
||||||
|
max_open_files=mof)
|
||||||
|
self.close = self.db.close
|
||||||
|
self.get = self.db.get
|
||||||
|
self.put = self.db.put
|
||||||
|
self.iterator = self.db.iterator
|
||||||
|
self.write_batch = partial(self.db.write_batch, transaction=True,
|
||||||
|
sync=True)
|
||||||
|
|
||||||
|
|
||||||
|
class RocksDB(Storage):
|
||||||
|
"""RocksDB database engine."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def import_module(cls):
|
||||||
|
import rocksdb
|
||||||
|
cls.module = rocksdb
|
||||||
|
|
||||||
|
def open(self, name, create):
|
||||||
|
mof = 512 if self.for_sync else 128
|
||||||
|
# Use snappy compression (the default)
|
||||||
|
options = self.module.Options(create_if_missing=create,
|
||||||
|
use_fsync=True,
|
||||||
|
target_file_size_base=33554432,
|
||||||
|
max_open_files=mof)
|
||||||
|
self.db = self.module.DB(name, options)
|
||||||
|
self.get = self.db.get
|
||||||
|
self.put = self.db.put
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
# PyRocksDB doesn't provide a close method; hopefully this is enough
|
||||||
|
self.db = self.get = self.put = None
|
||||||
|
import gc
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
def write_batch(self):
|
||||||
|
return RocksDBWriteBatch(self.db)
|
||||||
|
|
||||||
|
def iterator(self, prefix=b'', reverse=False):
|
||||||
|
return RocksDBIterator(self.db, prefix, reverse)
|
||||||
|
|
||||||
|
|
||||||
|
class RocksDBWriteBatch:
|
||||||
|
"""A write batch for RocksDB."""
|
||||||
|
|
||||||
|
def __init__(self, db):
|
||||||
|
self.batch = RocksDB.module.WriteBatch()
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self.batch
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
if not exc_val:
|
||||||
|
self.db.write(self.batch)
|
||||||
|
|
||||||
|
|
||||||
|
class RocksDBIterator:
|
||||||
|
"""An iterator for RocksDB."""
|
||||||
|
|
||||||
|
def __init__(self, db, prefix, reverse):
|
||||||
|
self.prefix = prefix
|
||||||
|
if reverse:
|
||||||
|
self.iterator = reversed(db.iteritems())
|
||||||
|
nxt_prefix = util.increment_byte_string(prefix)
|
||||||
|
if nxt_prefix:
|
||||||
|
self.iterator.seek(nxt_prefix)
|
||||||
|
try:
|
||||||
|
next(self.iterator)
|
||||||
|
except StopIteration:
|
||||||
|
self.iterator.seek(nxt_prefix)
|
||||||
|
else:
|
||||||
|
self.iterator.seek_to_last()
|
||||||
|
else:
|
||||||
|
self.iterator = db.iteritems()
|
||||||
|
self.iterator.seek(prefix)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
k, v = next(self.iterator)
|
||||||
|
if not k.startswith(self.prefix):
|
||||||
|
raise StopIteration
|
||||||
|
return k, v
|
82
torba/torba/server/text.py
Normal file
82
torba/torba/server/text.py
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
import time
|
||||||
|
|
||||||
|
from torba.server import util
|
||||||
|
|
||||||
|
|
||||||
|
def sessions_lines(data):
|
||||||
|
"""A generator returning lines for a list of sessions.
|
||||||
|
|
||||||
|
data is the return value of rpc_sessions()."""
|
||||||
|
fmt = ('{:<6} {:<5} {:>17} {:>5} {:>5} {:>5} '
|
||||||
|
'{:>7} {:>7} {:>7} {:>7} {:>7} {:>9} {:>21}')
|
||||||
|
yield fmt.format('ID', 'Flags', 'Client', 'Proto',
|
||||||
|
'Reqs', 'Txs', 'Subs',
|
||||||
|
'Recv', 'Recv KB', 'Sent', 'Sent KB', 'Time', 'Peer')
|
||||||
|
for (id_, flags, peer, client, proto, reqs, txs_sent, subs,
|
||||||
|
recv_count, recv_size, send_count, send_size, time) in data:
|
||||||
|
yield fmt.format(id_, flags, client, proto,
|
||||||
|
'{:,d}'.format(reqs),
|
||||||
|
'{:,d}'.format(txs_sent),
|
||||||
|
'{:,d}'.format(subs),
|
||||||
|
'{:,d}'.format(recv_count),
|
||||||
|
'{:,d}'.format(recv_size // 1024),
|
||||||
|
'{:,d}'.format(send_count),
|
||||||
|
'{:,d}'.format(send_size // 1024),
|
||||||
|
util.formatted_time(time, sep=''), peer)
|
||||||
|
|
||||||
|
|
||||||
|
def groups_lines(data):
|
||||||
|
"""A generator returning lines for a list of groups.
|
||||||
|
|
||||||
|
data is the return value of rpc_groups()."""
|
||||||
|
|
||||||
|
fmt = ('{:<6} {:>9} {:>9} {:>6} {:>6} {:>8}'
|
||||||
|
'{:>7} {:>9} {:>7} {:>9}')
|
||||||
|
yield fmt.format('ID', 'Sessions', 'Bwidth KB', 'Reqs', 'Txs', 'Subs',
|
||||||
|
'Recv', 'Recv KB', 'Sent', 'Sent KB')
|
||||||
|
for (id_, session_count, bandwidth, reqs, txs_sent, subs,
|
||||||
|
recv_count, recv_size, send_count, send_size) in data:
|
||||||
|
yield fmt.format(id_,
|
||||||
|
'{:,d}'.format(session_count),
|
||||||
|
'{:,d}'.format(bandwidth // 1024),
|
||||||
|
'{:,d}'.format(reqs),
|
||||||
|
'{:,d}'.format(txs_sent),
|
||||||
|
'{:,d}'.format(subs),
|
||||||
|
'{:,d}'.format(recv_count),
|
||||||
|
'{:,d}'.format(recv_size // 1024),
|
||||||
|
'{:,d}'.format(send_count),
|
||||||
|
'{:,d}'.format(send_size // 1024))
|
||||||
|
|
||||||
|
|
||||||
|
def peers_lines(data):
|
||||||
|
"""A generator returning lines for a list of peers.
|
||||||
|
|
||||||
|
data is the return value of rpc_peers()."""
|
||||||
|
def time_fmt(t):
|
||||||
|
if not t:
|
||||||
|
return 'Never'
|
||||||
|
return util.formatted_time(now - t)
|
||||||
|
|
||||||
|
now = time.time()
|
||||||
|
fmt = ('{:<30} {:<6} {:>5} {:>5} {:<17} {:>4} '
|
||||||
|
'{:>4} {:>8} {:>11} {:>11} {:>5} {:>20} {:<15}')
|
||||||
|
yield fmt.format('Host', 'Status', 'TCP', 'SSL', 'Server', 'Min',
|
||||||
|
'Max', 'Pruning', 'Last Good', 'Last Try',
|
||||||
|
'Tries', 'Source', 'IP Address')
|
||||||
|
for item in data:
|
||||||
|
features = item['features']
|
||||||
|
hostname = item['host']
|
||||||
|
host = features['hosts'][hostname]
|
||||||
|
yield fmt.format(hostname[:30],
|
||||||
|
item['status'],
|
||||||
|
host.get('tcp_port') or '',
|
||||||
|
host.get('ssl_port') or '',
|
||||||
|
features['server_version'] or 'unknown',
|
||||||
|
features['protocol_min'],
|
||||||
|
features['protocol_max'],
|
||||||
|
features['pruning'] or '',
|
||||||
|
time_fmt(item['last_good']),
|
||||||
|
time_fmt(item['last_try']),
|
||||||
|
item['try_count'],
|
||||||
|
item['source'][:20],
|
||||||
|
item['ip_addr'] or '')
|
625
torba/torba/server/tx.py
Normal file
625
torba/torba/server/tx.py
Normal file
|
@ -0,0 +1,625 @@
|
||||||
|
# Copyright (c) 2016-2017, Neil Booth
|
||||||
|
# Copyright (c) 2017, the ElectrumX authors
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# The MIT License (MIT)
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
# a copy of this software and associated documentation files (the
|
||||||
|
# "Software"), to deal in the Software without restriction, including
|
||||||
|
# without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
# permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
# the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be
|
||||||
|
# included in all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
# and warranty status of this software.
|
||||||
|
|
||||||
|
"""Transaction-related classes and functions."""
|
||||||
|
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
|
from torba.server.hash import sha256, double_sha256, hash_to_hex_str
|
||||||
|
from torba.server.script import OpCodes
|
||||||
|
from torba.server.util import (
|
||||||
|
unpack_le_int32_from, unpack_le_int64_from, unpack_le_uint16_from,
|
||||||
|
unpack_le_uint32_from, unpack_le_uint64_from, pack_le_int32, pack_varint,
|
||||||
|
pack_le_uint32, pack_le_int64, pack_varbytes,
|
||||||
|
)
|
||||||
|
|
||||||
|
ZERO = bytes(32)
|
||||||
|
MINUS_1 = 4294967295
|
||||||
|
|
||||||
|
|
||||||
|
class Tx(namedtuple("Tx", "version inputs outputs locktime")):
|
||||||
|
"""Class representing a transaction."""
|
||||||
|
|
||||||
|
def serialize(self):
|
||||||
|
return b''.join((
|
||||||
|
pack_le_int32(self.version),
|
||||||
|
pack_varint(len(self.inputs)),
|
||||||
|
b''.join(tx_in.serialize() for tx_in in self.inputs),
|
||||||
|
pack_varint(len(self.outputs)),
|
||||||
|
b''.join(tx_out.serialize() for tx_out in self.outputs),
|
||||||
|
pack_le_uint32(self.locktime)
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
class TxInput(namedtuple("TxInput", "prev_hash prev_idx script sequence")):
|
||||||
|
"""Class representing a transaction input."""
|
||||||
|
def __str__(self):
|
||||||
|
script = self.script.hex()
|
||||||
|
prev_hash = hash_to_hex_str(self.prev_hash)
|
||||||
|
return ("Input({}, {:d}, script={}, sequence={:d})"
|
||||||
|
.format(prev_hash, self.prev_idx, script, self.sequence))
|
||||||
|
|
||||||
|
def is_generation(self):
|
||||||
|
"""Test if an input is generation/coinbase like"""
|
||||||
|
return self.prev_idx == MINUS_1 and self.prev_hash == ZERO
|
||||||
|
|
||||||
|
def serialize(self):
|
||||||
|
return b''.join((
|
||||||
|
self.prev_hash,
|
||||||
|
pack_le_uint32(self.prev_idx),
|
||||||
|
pack_varbytes(self.script),
|
||||||
|
pack_le_uint32(self.sequence),
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
class TxOutput(namedtuple("TxOutput", "value pk_script")):
|
||||||
|
|
||||||
|
def serialize(self):
|
||||||
|
return b''.join((
|
||||||
|
pack_le_int64(self.value),
|
||||||
|
pack_varbytes(self.pk_script),
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
class Deserializer:
|
||||||
|
"""Deserializes blocks into transactions.
|
||||||
|
|
||||||
|
External entry points are read_tx(), read_tx_and_hash(),
|
||||||
|
read_tx_and_vsize() and read_block().
|
||||||
|
|
||||||
|
This code is performance sensitive as it is executed 100s of
|
||||||
|
millions of times during sync.
|
||||||
|
"""
|
||||||
|
|
||||||
|
TX_HASH_FN = staticmethod(double_sha256)
|
||||||
|
|
||||||
|
def __init__(self, binary, start=0):
|
||||||
|
assert isinstance(binary, bytes)
|
||||||
|
self.binary = binary
|
||||||
|
self.binary_length = len(binary)
|
||||||
|
self.cursor = start
|
||||||
|
|
||||||
|
def read_tx(self):
|
||||||
|
"""Return a deserialized transaction."""
|
||||||
|
return Tx(
|
||||||
|
self._read_le_int32(), # version
|
||||||
|
self._read_inputs(), # inputs
|
||||||
|
self._read_outputs(), # outputs
|
||||||
|
self._read_le_uint32() # locktime
|
||||||
|
)
|
||||||
|
|
||||||
|
def read_tx_and_hash(self):
|
||||||
|
"""Return a (deserialized TX, tx_hash) pair.
|
||||||
|
|
||||||
|
The hash needs to be reversed for human display; for efficiency
|
||||||
|
we process it in the natural serialized order.
|
||||||
|
"""
|
||||||
|
start = self.cursor
|
||||||
|
return self.read_tx(), self.TX_HASH_FN(self.binary[start:self.cursor])
|
||||||
|
|
||||||
|
def read_tx_and_vsize(self):
|
||||||
|
"""Return a (deserialized TX, vsize) pair."""
|
||||||
|
return self.read_tx(), self.binary_length
|
||||||
|
|
||||||
|
def read_tx_block(self):
|
||||||
|
"""Returns a list of (deserialized_tx, tx_hash) pairs."""
|
||||||
|
read = self.read_tx_and_hash
|
||||||
|
# Some coins have excess data beyond the end of the transactions
|
||||||
|
return [read() for _ in range(self._read_varint())]
|
||||||
|
|
||||||
|
def _read_inputs(self):
|
||||||
|
read_input = self._read_input
|
||||||
|
return [read_input() for i in range(self._read_varint())]
|
||||||
|
|
||||||
|
def _read_input(self):
|
||||||
|
return TxInput(
|
||||||
|
self._read_nbytes(32), # prev_hash
|
||||||
|
self._read_le_uint32(), # prev_idx
|
||||||
|
self._read_varbytes(), # script
|
||||||
|
self._read_le_uint32() # sequence
|
||||||
|
)
|
||||||
|
|
||||||
|
def _read_outputs(self):
|
||||||
|
read_output = self._read_output
|
||||||
|
return [read_output() for i in range(self._read_varint())]
|
||||||
|
|
||||||
|
def _read_output(self):
|
||||||
|
return TxOutput(
|
||||||
|
self._read_le_int64(), # value
|
||||||
|
self._read_varbytes(), # pk_script
|
||||||
|
)
|
||||||
|
|
||||||
|
def _read_byte(self):
|
||||||
|
cursor = self.cursor
|
||||||
|
self.cursor += 1
|
||||||
|
return self.binary[cursor]
|
||||||
|
|
||||||
|
def _read_nbytes(self, n):
|
||||||
|
cursor = self.cursor
|
||||||
|
self.cursor = end = cursor + n
|
||||||
|
assert self.binary_length >= end
|
||||||
|
return self.binary[cursor:end]
|
||||||
|
|
||||||
|
def _read_varbytes(self):
|
||||||
|
return self._read_nbytes(self._read_varint())
|
||||||
|
|
||||||
|
def _read_varint(self):
|
||||||
|
n = self.binary[self.cursor]
|
||||||
|
self.cursor += 1
|
||||||
|
if n < 253:
|
||||||
|
return n
|
||||||
|
if n == 253:
|
||||||
|
return self._read_le_uint16()
|
||||||
|
if n == 254:
|
||||||
|
return self._read_le_uint32()
|
||||||
|
return self._read_le_uint64()
|
||||||
|
|
||||||
|
def _read_le_int32(self):
|
||||||
|
result, = unpack_le_int32_from(self.binary, self.cursor)
|
||||||
|
self.cursor += 4
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _read_le_int64(self):
|
||||||
|
result, = unpack_le_int64_from(self.binary, self.cursor)
|
||||||
|
self.cursor += 8
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _read_le_uint16(self):
|
||||||
|
result, = unpack_le_uint16_from(self.binary, self.cursor)
|
||||||
|
self.cursor += 2
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _read_le_uint32(self):
|
||||||
|
result, = unpack_le_uint32_from(self.binary, self.cursor)
|
||||||
|
self.cursor += 4
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _read_le_uint64(self):
|
||||||
|
result, = unpack_le_uint64_from(self.binary, self.cursor)
|
||||||
|
self.cursor += 8
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class TxSegWit(namedtuple("Tx", "version marker flag inputs outputs "
|
||||||
|
"witness locktime")):
|
||||||
|
"""Class representing a SegWit transaction."""
|
||||||
|
|
||||||
|
|
||||||
|
class DeserializerSegWit(Deserializer):
|
||||||
|
|
||||||
|
# https://bitcoincore.org/en/segwit_wallet_dev/#transaction-serialization
|
||||||
|
|
||||||
|
def _read_witness(self, fields):
|
||||||
|
read_witness_field = self._read_witness_field
|
||||||
|
return [read_witness_field() for i in range(fields)]
|
||||||
|
|
||||||
|
def _read_witness_field(self):
|
||||||
|
read_varbytes = self._read_varbytes
|
||||||
|
return [read_varbytes() for i in range(self._read_varint())]
|
||||||
|
|
||||||
|
def _read_tx_parts(self):
|
||||||
|
"""Return a (deserialized TX, tx_hash, vsize) tuple."""
|
||||||
|
start = self.cursor
|
||||||
|
marker = self.binary[self.cursor + 4]
|
||||||
|
if marker:
|
||||||
|
tx = super().read_tx()
|
||||||
|
tx_hash = self.TX_HASH_FN(self.binary[start:self.cursor])
|
||||||
|
return tx, tx_hash, self.binary_length
|
||||||
|
|
||||||
|
# Ugh, this is nasty.
|
||||||
|
version = self._read_le_int32()
|
||||||
|
orig_ser = self.binary[start:self.cursor]
|
||||||
|
|
||||||
|
marker = self._read_byte()
|
||||||
|
flag = self._read_byte()
|
||||||
|
|
||||||
|
start = self.cursor
|
||||||
|
inputs = self._read_inputs()
|
||||||
|
outputs = self._read_outputs()
|
||||||
|
orig_ser += self.binary[start:self.cursor]
|
||||||
|
|
||||||
|
base_size = self.cursor - start
|
||||||
|
witness = self._read_witness(len(inputs))
|
||||||
|
|
||||||
|
start = self.cursor
|
||||||
|
locktime = self._read_le_uint32()
|
||||||
|
orig_ser += self.binary[start:self.cursor]
|
||||||
|
vsize = (3 * base_size + self.binary_length) // 4
|
||||||
|
|
||||||
|
return TxSegWit(version, marker, flag, inputs, outputs, witness,
|
||||||
|
locktime), self.TX_HASH_FN(orig_ser), vsize
|
||||||
|
|
||||||
|
def read_tx(self):
|
||||||
|
return self._read_tx_parts()[0]
|
||||||
|
|
||||||
|
def read_tx_and_hash(self):
|
||||||
|
tx, tx_hash, vsize = self._read_tx_parts()
|
||||||
|
return tx, tx_hash
|
||||||
|
|
||||||
|
def read_tx_and_vsize(self):
|
||||||
|
tx, tx_hash, vsize = self._read_tx_parts()
|
||||||
|
return tx, vsize
|
||||||
|
|
||||||
|
|
||||||
|
class DeserializerAuxPow(Deserializer):
|
||||||
|
VERSION_AUXPOW = (1 << 8)
|
||||||
|
|
||||||
|
def read_header(self, height, static_header_size):
|
||||||
|
"""Return the AuxPow block header bytes"""
|
||||||
|
start = self.cursor
|
||||||
|
version = self._read_le_uint32()
|
||||||
|
if version & self.VERSION_AUXPOW:
|
||||||
|
# We are going to calculate the block size then read it as bytes
|
||||||
|
self.cursor = start
|
||||||
|
self.cursor += static_header_size # Block normal header
|
||||||
|
self.read_tx() # AuxPow transaction
|
||||||
|
self.cursor += 32 # Parent block hash
|
||||||
|
merkle_size = self._read_varint()
|
||||||
|
self.cursor += 32 * merkle_size # Merkle branch
|
||||||
|
self.cursor += 4 # Index
|
||||||
|
merkle_size = self._read_varint()
|
||||||
|
self.cursor += 32 * merkle_size # Chain merkle branch
|
||||||
|
self.cursor += 4 # Chain index
|
||||||
|
self.cursor += 80 # Parent block header
|
||||||
|
header_end = self.cursor
|
||||||
|
else:
|
||||||
|
header_end = static_header_size
|
||||||
|
self.cursor = start
|
||||||
|
return self._read_nbytes(header_end)
|
||||||
|
|
||||||
|
|
||||||
|
class DeserializerAuxPowSegWit(DeserializerSegWit, DeserializerAuxPow):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DeserializerEquihash(Deserializer):
|
||||||
|
def read_header(self, height, static_header_size):
|
||||||
|
"""Return the block header bytes"""
|
||||||
|
start = self.cursor
|
||||||
|
# We are going to calculate the block size then read it as bytes
|
||||||
|
self.cursor += static_header_size
|
||||||
|
solution_size = self._read_varint()
|
||||||
|
self.cursor += solution_size
|
||||||
|
header_end = self.cursor
|
||||||
|
self.cursor = start
|
||||||
|
return self._read_nbytes(header_end)
|
||||||
|
|
||||||
|
|
||||||
|
class DeserializerEquihashSegWit(DeserializerSegWit, DeserializerEquihash):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TxJoinSplit(namedtuple("Tx", "version inputs outputs locktime")):
|
||||||
|
"""Class representing a JoinSplit transaction."""
|
||||||
|
|
||||||
|
|
||||||
|
class DeserializerZcash(DeserializerEquihash):
|
||||||
|
def read_tx(self):
|
||||||
|
header = self._read_le_uint32()
|
||||||
|
overwintered = ((header >> 31) == 1)
|
||||||
|
if overwintered:
|
||||||
|
version = header & 0x7fffffff
|
||||||
|
self.cursor += 4 # versionGroupId
|
||||||
|
else:
|
||||||
|
version = header
|
||||||
|
|
||||||
|
is_overwinter_v3 = version == 3
|
||||||
|
is_sapling_v4 = version == 4
|
||||||
|
|
||||||
|
base_tx = TxJoinSplit(
|
||||||
|
version,
|
||||||
|
self._read_inputs(), # inputs
|
||||||
|
self._read_outputs(), # outputs
|
||||||
|
self._read_le_uint32() # locktime
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_overwinter_v3 or is_sapling_v4:
|
||||||
|
self.cursor += 4 # expiryHeight
|
||||||
|
|
||||||
|
has_shielded = False
|
||||||
|
if is_sapling_v4:
|
||||||
|
self.cursor += 8 # valueBalance
|
||||||
|
shielded_spend_size = self._read_varint()
|
||||||
|
self.cursor += shielded_spend_size * 384 # vShieldedSpend
|
||||||
|
shielded_output_size = self._read_varint()
|
||||||
|
self.cursor += shielded_output_size * 948 # vShieldedOutput
|
||||||
|
has_shielded = shielded_spend_size > 0 or shielded_output_size > 0
|
||||||
|
|
||||||
|
if base_tx.version >= 2:
|
||||||
|
joinsplit_size = self._read_varint()
|
||||||
|
if joinsplit_size > 0:
|
||||||
|
joinsplit_desc_len = 1506 + (192 if is_sapling_v4 else 296)
|
||||||
|
# JSDescription
|
||||||
|
self.cursor += joinsplit_size * joinsplit_desc_len
|
||||||
|
self.cursor += 32 # joinSplitPubKey
|
||||||
|
self.cursor += 64 # joinSplitSig
|
||||||
|
|
||||||
|
if is_sapling_v4 and has_shielded:
|
||||||
|
self.cursor += 64 # bindingSig
|
||||||
|
|
||||||
|
return base_tx
|
||||||
|
|
||||||
|
|
||||||
|
class TxTime(namedtuple("Tx", "version time inputs outputs locktime")):
|
||||||
|
"""Class representing transaction that has a time field."""
|
||||||
|
|
||||||
|
|
||||||
|
class DeserializerTxTime(Deserializer):
|
||||||
|
def read_tx(self):
|
||||||
|
return TxTime(
|
||||||
|
self._read_le_int32(), # version
|
||||||
|
self._read_le_uint32(), # time
|
||||||
|
self._read_inputs(), # inputs
|
||||||
|
self._read_outputs(), # outputs
|
||||||
|
self._read_le_uint32(), # locktime
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DeserializerReddcoin(Deserializer):
|
||||||
|
def read_tx(self):
|
||||||
|
version = self._read_le_int32()
|
||||||
|
inputs = self._read_inputs()
|
||||||
|
outputs = self._read_outputs()
|
||||||
|
locktime = self._read_le_uint32()
|
||||||
|
if version > 1:
|
||||||
|
time = self._read_le_uint32()
|
||||||
|
else:
|
||||||
|
time = 0
|
||||||
|
|
||||||
|
return TxTime(version, time, inputs, outputs, locktime)
|
||||||
|
|
||||||
|
|
||||||
|
class DeserializerTxTimeAuxPow(DeserializerTxTime):
|
||||||
|
VERSION_AUXPOW = (1 << 8)
|
||||||
|
|
||||||
|
def is_merged_block(self):
|
||||||
|
start = self.cursor
|
||||||
|
self.cursor = 0
|
||||||
|
version = self._read_le_uint32()
|
||||||
|
self.cursor = start
|
||||||
|
if version & self.VERSION_AUXPOW:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def read_header(self, height, static_header_size):
|
||||||
|
"""Return the AuxPow block header bytes"""
|
||||||
|
start = self.cursor
|
||||||
|
version = self._read_le_uint32()
|
||||||
|
if version & self.VERSION_AUXPOW:
|
||||||
|
# We are going to calculate the block size then read it as bytes
|
||||||
|
self.cursor = start
|
||||||
|
self.cursor += static_header_size # Block normal header
|
||||||
|
self.read_tx() # AuxPow transaction
|
||||||
|
self.cursor += 32 # Parent block hash
|
||||||
|
merkle_size = self._read_varint()
|
||||||
|
self.cursor += 32 * merkle_size # Merkle branch
|
||||||
|
self.cursor += 4 # Index
|
||||||
|
merkle_size = self._read_varint()
|
||||||
|
self.cursor += 32 * merkle_size # Chain merkle branch
|
||||||
|
self.cursor += 4 # Chain index
|
||||||
|
self.cursor += 80 # Parent block header
|
||||||
|
header_end = self.cursor
|
||||||
|
else:
|
||||||
|
header_end = static_header_size
|
||||||
|
self.cursor = start
|
||||||
|
return self._read_nbytes(header_end)
|
||||||
|
|
||||||
|
|
||||||
|
class DeserializerBitcoinAtom(DeserializerSegWit):
|
||||||
|
FORK_BLOCK_HEIGHT = 505888
|
||||||
|
|
||||||
|
def read_header(self, height, static_header_size):
|
||||||
|
"""Return the block header bytes"""
|
||||||
|
header_len = static_header_size
|
||||||
|
if height >= self.FORK_BLOCK_HEIGHT:
|
||||||
|
header_len += 4 # flags
|
||||||
|
return self._read_nbytes(header_len)
|
||||||
|
|
||||||
|
|
||||||
|
class DeserializerGroestlcoin(DeserializerSegWit):
|
||||||
|
TX_HASH_FN = staticmethod(sha256)
|
||||||
|
|
||||||
|
|
||||||
|
class TxInputTokenPay(TxInput):
|
||||||
|
"""Class representing a TokenPay transaction input."""
|
||||||
|
|
||||||
|
OP_ANON_MARKER = 0xb9
|
||||||
|
# 2byte marker (cpubkey + sigc + sigr)
|
||||||
|
MIN_ANON_IN_SIZE = 2 + (33 + 32 + 32)
|
||||||
|
|
||||||
|
def _is_anon_input(self):
|
||||||
|
return (len(self.script) >= self.MIN_ANON_IN_SIZE and
|
||||||
|
self.script[0] == OpCodes.OP_RETURN and
|
||||||
|
self.script[1] == self.OP_ANON_MARKER)
|
||||||
|
|
||||||
|
def is_generation(self):
|
||||||
|
# Transactions comming in from stealth addresses are seen by
|
||||||
|
# the blockchain as newly minted coins. The reverse, where coins
|
||||||
|
# are sent TO a stealth address, are seen by the blockchain as
|
||||||
|
# a coin burn.
|
||||||
|
if self._is_anon_input():
|
||||||
|
return True
|
||||||
|
return super(TxInputTokenPay, self).is_generation()
|
||||||
|
|
||||||
|
|
||||||
|
class TxInputTokenPayStealth(
|
||||||
|
namedtuple("TxInput", "keyimage ringsize script sequence")):
|
||||||
|
"""Class representing a TokenPay stealth transaction input."""
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
script = self.script.hex()
|
||||||
|
keyimage = bytes(self.keyimage).hex()
|
||||||
|
return ("Input({}, {:d}, script={}, sequence={:d})"
|
||||||
|
.format(keyimage, self.ringsize[1], script, self.sequence))
|
||||||
|
|
||||||
|
def is_generation(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def serialize(self):
|
||||||
|
return b''.join((
|
||||||
|
self.keyimage,
|
||||||
|
self.ringsize,
|
||||||
|
pack_varbytes(self.script),
|
||||||
|
pack_le_uint32(self.sequence),
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
class DeserializerTokenPay(DeserializerTxTime):
|
||||||
|
|
||||||
|
def _read_input(self):
|
||||||
|
txin = TxInputTokenPay(
|
||||||
|
self._read_nbytes(32), # prev_hash
|
||||||
|
self._read_le_uint32(), # prev_idx
|
||||||
|
self._read_varbytes(), # script
|
||||||
|
self._read_le_uint32(), # sequence
|
||||||
|
)
|
||||||
|
if txin._is_anon_input():
|
||||||
|
# Not sure if this is actually needed, and seems
|
||||||
|
# extra work for no immediate benefit, but it at
|
||||||
|
# least correctly represents a stealth input
|
||||||
|
raw = txin.serialize()
|
||||||
|
deserializer = Deserializer(raw)
|
||||||
|
txin = TxInputTokenPayStealth(
|
||||||
|
deserializer._read_nbytes(33), # keyimage
|
||||||
|
deserializer._read_nbytes(3), # ringsize
|
||||||
|
deserializer._read_varbytes(), # script
|
||||||
|
deserializer._read_le_uint32() # sequence
|
||||||
|
)
|
||||||
|
return txin
|
||||||
|
|
||||||
|
|
||||||
|
# Decred
|
||||||
|
class TxInputDcr(namedtuple("TxInput", "prev_hash prev_idx tree sequence")):
|
||||||
|
"""Class representing a Decred transaction input."""
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
prev_hash = hash_to_hex_str(self.prev_hash)
|
||||||
|
return ("Input({}, {:d}, tree={}, sequence={:d})"
|
||||||
|
.format(prev_hash, self.prev_idx, self.tree, self.sequence))
|
||||||
|
|
||||||
|
def is_generation(self):
|
||||||
|
"""Test if an input is generation/coinbase like"""
|
||||||
|
return self.prev_idx == MINUS_1 and self.prev_hash == ZERO
|
||||||
|
|
||||||
|
|
||||||
|
class TxOutputDcr(namedtuple("TxOutput", "value version pk_script")):
|
||||||
|
"""Class representing a Decred transaction output."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TxDcr(namedtuple("Tx", "version inputs outputs locktime expiry "
|
||||||
|
"witness")):
|
||||||
|
"""Class representing a Decred transaction."""
|
||||||
|
|
||||||
|
|
||||||
|
class DeserializerDecred(Deserializer):
|
||||||
|
@staticmethod
|
||||||
|
def blake256(data):
|
||||||
|
from blake256.blake256 import blake_hash
|
||||||
|
return blake_hash(data)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def blake256d(data):
|
||||||
|
from blake256.blake256 import blake_hash
|
||||||
|
return blake_hash(blake_hash(data))
|
||||||
|
|
||||||
|
def read_tx(self):
|
||||||
|
return self._read_tx_parts(produce_hash=False)[0]
|
||||||
|
|
||||||
|
def read_tx_and_hash(self):
|
||||||
|
tx, tx_hash, vsize = self._read_tx_parts()
|
||||||
|
return tx, tx_hash
|
||||||
|
|
||||||
|
def read_tx_and_vsize(self):
|
||||||
|
tx, tx_hash, vsize = self._read_tx_parts(produce_hash=False)
|
||||||
|
return tx, vsize
|
||||||
|
|
||||||
|
def read_tx_block(self):
|
||||||
|
"""Returns a list of (deserialized_tx, tx_hash) pairs."""
|
||||||
|
read = self.read_tx_and_hash
|
||||||
|
txs = [read() for _ in range(self._read_varint())]
|
||||||
|
stxs = [read() for _ in range(self._read_varint())]
|
||||||
|
return txs + stxs
|
||||||
|
|
||||||
|
def read_tx_tree(self):
|
||||||
|
"""Returns a list of deserialized_tx without tx hashes."""
|
||||||
|
read_tx = self.read_tx
|
||||||
|
return [read_tx() for _ in range(self._read_varint())]
|
||||||
|
|
||||||
|
def _read_input(self):
|
||||||
|
return TxInputDcr(
|
||||||
|
self._read_nbytes(32), # prev_hash
|
||||||
|
self._read_le_uint32(), # prev_idx
|
||||||
|
self._read_byte(), # tree
|
||||||
|
self._read_le_uint32(), # sequence
|
||||||
|
)
|
||||||
|
|
||||||
|
def _read_output(self):
|
||||||
|
return TxOutputDcr(
|
||||||
|
self._read_le_int64(), # value
|
||||||
|
self._read_le_uint16(), # version
|
||||||
|
self._read_varbytes(), # pk_script
|
||||||
|
)
|
||||||
|
|
||||||
|
def _read_witness(self, fields):
|
||||||
|
read_witness_field = self._read_witness_field
|
||||||
|
assert fields == self._read_varint()
|
||||||
|
return [read_witness_field() for _ in range(fields)]
|
||||||
|
|
||||||
|
def _read_witness_field(self):
|
||||||
|
value_in = self._read_le_int64()
|
||||||
|
block_height = self._read_le_uint32()
|
||||||
|
block_index = self._read_le_uint32()
|
||||||
|
script = self._read_varbytes()
|
||||||
|
return value_in, block_height, block_index, script
|
||||||
|
|
||||||
|
def _read_tx_parts(self, produce_hash=True):
|
||||||
|
start = self.cursor
|
||||||
|
version = self._read_le_int32()
|
||||||
|
inputs = self._read_inputs()
|
||||||
|
outputs = self._read_outputs()
|
||||||
|
locktime = self._read_le_uint32()
|
||||||
|
expiry = self._read_le_uint32()
|
||||||
|
end_prefix = self.cursor
|
||||||
|
witness = self._read_witness(len(inputs))
|
||||||
|
|
||||||
|
if produce_hash:
|
||||||
|
# TxSerializeNoWitness << 16 == 0x10000
|
||||||
|
no_witness_header = pack_le_uint32(0x10000 | (version & 0xffff))
|
||||||
|
prefix_tx = no_witness_header + self.binary[start+4:end_prefix]
|
||||||
|
tx_hash = self.blake256(prefix_tx)
|
||||||
|
else:
|
||||||
|
tx_hash = None
|
||||||
|
|
||||||
|
return TxDcr(
|
||||||
|
version,
|
||||||
|
inputs,
|
||||||
|
outputs,
|
||||||
|
locktime,
|
||||||
|
expiry,
|
||||||
|
witness
|
||||||
|
), tx_hash, self.cursor - start
|
359
torba/torba/server/util.py
Normal file
359
torba/torba/server/util.py
Normal file
|
@ -0,0 +1,359 @@
|
||||||
|
# Copyright (c) 2016-2017, Neil Booth
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# The MIT License (MIT)
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
# a copy of this software and associated documentation files (the
|
||||||
|
# "Software"), to deal in the Software without restriction, including
|
||||||
|
# without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
# permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
# the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be
|
||||||
|
# included in all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
# and warranty status of this software.
|
||||||
|
|
||||||
|
"""Miscellaneous utility classes and functions."""
|
||||||
|
|
||||||
|
|
||||||
|
import array
|
||||||
|
import inspect
|
||||||
|
from ipaddress import ip_address
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from collections import Container, Mapping
|
||||||
|
from struct import pack, Struct
|
||||||
|
|
||||||
|
# Logging utilities
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionLogger(logging.LoggerAdapter):
|
||||||
|
"""Prepends a connection identifier to a logging message."""
|
||||||
|
def process(self, msg, kwargs):
|
||||||
|
conn_id = self.extra.get('conn_id', 'unknown')
|
||||||
|
return f'[{conn_id}] {msg}', kwargs
|
||||||
|
|
||||||
|
|
||||||
|
class CompactFormatter(logging.Formatter):
|
||||||
|
"""Strips the module from the logger name to leave the class only."""
|
||||||
|
def format(self, record):
|
||||||
|
record.name = record.name.rpartition('.')[-1]
|
||||||
|
return super().format(record)
|
||||||
|
|
||||||
|
|
||||||
|
def make_logger(name, *, handler, level):
|
||||||
|
"""Return the root ElectrumX logger."""
|
||||||
|
logger = logging.getLogger(name)
|
||||||
|
logger.addHandler(handler)
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
logger.propagate = False
|
||||||
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
def class_logger(path, classname):
|
||||||
|
"""Return a hierarchical logger for a class."""
|
||||||
|
return logging.getLogger(path).getChild(classname)
|
||||||
|
|
||||||
|
|
||||||
|
# Method decorator. To be used for calculations that will always
|
||||||
|
# deliver the same result. The method cannot take any arguments
|
||||||
|
# and should be accessed as an attribute.
|
||||||
|
class cachedproperty:
|
||||||
|
|
||||||
|
def __init__(self, f):
|
||||||
|
self.f = f
|
||||||
|
|
||||||
|
def __get__(self, obj, type):
|
||||||
|
obj = obj or type
|
||||||
|
value = self.f(obj)
|
||||||
|
setattr(obj, self.f.__name__, value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def formatted_time(t, sep=' '):
|
||||||
|
"""Return a number of seconds as a string in days, hours, mins and
|
||||||
|
maybe secs."""
|
||||||
|
t = int(t)
|
||||||
|
fmts = (('{:d}d', 86400), ('{:02d}h', 3600), ('{:02d}m', 60))
|
||||||
|
parts = []
|
||||||
|
for fmt, n in fmts:
|
||||||
|
val = t // n
|
||||||
|
if parts or val:
|
||||||
|
parts.append(fmt.format(val))
|
||||||
|
t %= n
|
||||||
|
if len(parts) < 3:
|
||||||
|
parts.append('{:02d}s'.format(t))
|
||||||
|
return sep.join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def deep_getsizeof(obj):
|
||||||
|
"""Find the memory footprint of a Python object.
|
||||||
|
|
||||||
|
Based on code from code.tutsplus.com: http://goo.gl/fZ0DXK
|
||||||
|
|
||||||
|
This is a recursive function that drills down a Python object graph
|
||||||
|
like a dictionary holding nested dictionaries with lists of lists
|
||||||
|
and tuples and sets.
|
||||||
|
|
||||||
|
The sys.getsizeof function does a shallow size of only. It counts each
|
||||||
|
object inside a container as pointer only regardless of how big it
|
||||||
|
really is.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ids = set()
|
||||||
|
|
||||||
|
def size(o):
|
||||||
|
if id(o) in ids:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
r = sys.getsizeof(o)
|
||||||
|
ids.add(id(o))
|
||||||
|
|
||||||
|
if isinstance(o, (str, bytes, bytearray, array.array)):
|
||||||
|
return r
|
||||||
|
|
||||||
|
if isinstance(o, Mapping):
|
||||||
|
return r + sum(size(k) + size(v) for k, v in o.items())
|
||||||
|
|
||||||
|
if isinstance(o, Container):
|
||||||
|
return r + sum(size(x) for x in o)
|
||||||
|
|
||||||
|
return r
|
||||||
|
|
||||||
|
return size(obj)
|
||||||
|
|
||||||
|
|
||||||
|
def subclasses(base_class, strict=True):
|
||||||
|
"""Return a list of subclasses of base_class in its module."""
|
||||||
|
def select(obj):
|
||||||
|
return (inspect.isclass(obj) and issubclass(obj, base_class) and
|
||||||
|
(not strict or obj != base_class))
|
||||||
|
|
||||||
|
pairs = inspect.getmembers(sys.modules[base_class.__module__], select)
|
||||||
|
return [pair[1] for pair in pairs]
|
||||||
|
|
||||||
|
|
||||||
|
def chunks(items, size):
|
||||||
|
"""Break up items, an iterable, into chunks of length size."""
|
||||||
|
for i in range(0, len(items), size):
|
||||||
|
yield items[i: i + size]
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_limit(limit):
|
||||||
|
if limit is None:
|
||||||
|
return -1
|
||||||
|
assert isinstance(limit, int) and limit >= 0
|
||||||
|
return limit
|
||||||
|
|
||||||
|
|
||||||
|
def bytes_to_int(be_bytes):
|
||||||
|
"""Interprets a big-endian sequence of bytes as an integer"""
|
||||||
|
return int.from_bytes(be_bytes, 'big')
|
||||||
|
|
||||||
|
|
||||||
|
def int_to_bytes(value):
|
||||||
|
"""Converts an integer to a big-endian sequence of bytes"""
|
||||||
|
return value.to_bytes((value.bit_length() + 7) // 8, 'big')
|
||||||
|
|
||||||
|
|
||||||
|
def increment_byte_string(bs):
|
||||||
|
"""Return the lexicographically next byte string of the same length.
|
||||||
|
|
||||||
|
Return None if there is none (when the input is all 0xff bytes)."""
|
||||||
|
for n in range(1, len(bs) + 1):
|
||||||
|
if bs[-n] != 0xff:
|
||||||
|
return bs[:-n] + bytes([bs[-n] + 1]) + bytes(n - 1)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class LogicalFile:
|
||||||
|
"""A logical binary file split across several separate files on disk."""
|
||||||
|
|
||||||
|
def __init__(self, prefix, digits, file_size):
|
||||||
|
digit_fmt = '{' + ':0{:d}d'.format(digits) + '}'
|
||||||
|
self.filename_fmt = prefix + digit_fmt
|
||||||
|
self.file_size = file_size
|
||||||
|
|
||||||
|
def read(self, start, size=-1):
|
||||||
|
"""Read up to size bytes from the virtual file, starting at offset
|
||||||
|
start, and return them.
|
||||||
|
|
||||||
|
If size is -1 all bytes are read."""
|
||||||
|
parts = []
|
||||||
|
while size != 0:
|
||||||
|
try:
|
||||||
|
with self.open_file(start, False) as f:
|
||||||
|
part = f.read(size)
|
||||||
|
if not part:
|
||||||
|
break
|
||||||
|
except FileNotFoundError:
|
||||||
|
break
|
||||||
|
parts.append(part)
|
||||||
|
start += len(part)
|
||||||
|
if size > 0:
|
||||||
|
size -= len(part)
|
||||||
|
return b''.join(parts)
|
||||||
|
|
||||||
|
def write(self, start, b):
|
||||||
|
"""Write the bytes-like object, b, to the underlying virtual file."""
|
||||||
|
while b:
|
||||||
|
size = min(len(b), self.file_size - (start % self.file_size))
|
||||||
|
with self.open_file(start, True) as f:
|
||||||
|
f.write(b if size == len(b) else b[:size])
|
||||||
|
b = b[size:]
|
||||||
|
start += size
|
||||||
|
|
||||||
|
def open_file(self, start, create):
|
||||||
|
"""Open the virtual file and seek to start. Return a file handle.
|
||||||
|
Raise FileNotFoundError if the file does not exist and create
|
||||||
|
is False.
|
||||||
|
"""
|
||||||
|
file_num, offset = divmod(start, self.file_size)
|
||||||
|
filename = self.filename_fmt.format(file_num)
|
||||||
|
f = open_file(filename, create)
|
||||||
|
f.seek(offset)
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
def open_file(filename, create=False):
|
||||||
|
"""Open the file name. Return its handle."""
|
||||||
|
try:
|
||||||
|
return open(filename, 'rb+')
|
||||||
|
except FileNotFoundError:
|
||||||
|
if create:
|
||||||
|
return open(filename, 'wb+')
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def open_truncate(filename):
|
||||||
|
"""Open the file name. Return its handle."""
|
||||||
|
return open(filename, 'wb+')
|
||||||
|
|
||||||
|
|
||||||
|
def address_string(address):
|
||||||
|
"""Return an address as a correctly formatted string."""
|
||||||
|
fmt = '{}:{:d}'
|
||||||
|
host, port = address
|
||||||
|
try:
|
||||||
|
host = ip_address(host)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
if host.version == 6:
|
||||||
|
fmt = '[{}]:{:d}'
|
||||||
|
return fmt.format(host, port)
|
||||||
|
|
||||||
|
# See http://stackoverflow.com/questions/2532053/validate-a-hostname-string
|
||||||
|
# Note underscores are valid in domain names, but strictly invalid in host
|
||||||
|
# names. We ignore that distinction.
|
||||||
|
|
||||||
|
|
||||||
|
SEGMENT_REGEX = re.compile("(?!-)[A-Z_\\d-]{1,63}(?<!-)$", re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_hostname(hostname):
|
||||||
|
if len(hostname) > 255:
|
||||||
|
return False
|
||||||
|
# strip exactly one dot from the right, if present
|
||||||
|
if hostname and hostname[-1] == ".":
|
||||||
|
hostname = hostname[:-1]
|
||||||
|
return all(SEGMENT_REGEX.match(x) for x in hostname.split("."))
|
||||||
|
|
||||||
|
|
||||||
|
def protocol_tuple(s):
|
||||||
|
"""Converts a protocol version number, such as "1.0" to a tuple (1, 0).
|
||||||
|
|
||||||
|
If the version number is bad, (0, ) indicating version 0 is returned."""
|
||||||
|
try:
|
||||||
|
return tuple(int(part) for part in s.split('.'))
|
||||||
|
except Exception:
|
||||||
|
return (0, )
|
||||||
|
|
||||||
|
|
||||||
|
def version_string(ptuple):
|
||||||
|
"""Convert a version tuple such as (1, 2) to "1.2".
|
||||||
|
There is always at least one dot, so (1, ) becomes "1.0"."""
|
||||||
|
while len(ptuple) < 2:
|
||||||
|
ptuple += (0, )
|
||||||
|
return '.'.join(str(p) for p in ptuple)
|
||||||
|
|
||||||
|
|
||||||
|
def protocol_version(client_req, min_tuple, max_tuple):
|
||||||
|
"""Given a client's protocol version string, return a pair of
|
||||||
|
protocol tuples:
|
||||||
|
|
||||||
|
(negotiated version, client min request)
|
||||||
|
|
||||||
|
If the request is unsupported, the negotiated protocol tuple is
|
||||||
|
None.
|
||||||
|
"""
|
||||||
|
if client_req is None:
|
||||||
|
client_min = client_max = min_tuple
|
||||||
|
else:
|
||||||
|
if isinstance(client_req, list) and len(client_req) == 2:
|
||||||
|
client_min, client_max = client_req
|
||||||
|
else:
|
||||||
|
client_min = client_max = client_req
|
||||||
|
client_min = protocol_tuple(client_min)
|
||||||
|
client_max = protocol_tuple(client_max)
|
||||||
|
|
||||||
|
result = min(client_max, max_tuple)
|
||||||
|
if result < max(client_min, min_tuple) or result == (0, ):
|
||||||
|
result = None
|
||||||
|
|
||||||
|
return result, client_min
|
||||||
|
|
||||||
|
|
||||||
|
struct_le_i = Struct('<i')
|
||||||
|
struct_le_q = Struct('<q')
|
||||||
|
struct_le_H = Struct('<H')
|
||||||
|
struct_le_I = Struct('<I')
|
||||||
|
struct_le_Q = Struct('<Q')
|
||||||
|
struct_be_H = Struct('>H')
|
||||||
|
struct_be_I = Struct('>I')
|
||||||
|
structB = Struct('B')
|
||||||
|
|
||||||
|
unpack_le_int32_from = struct_le_i.unpack_from
|
||||||
|
unpack_le_int64_from = struct_le_q.unpack_from
|
||||||
|
unpack_le_uint16_from = struct_le_H.unpack_from
|
||||||
|
unpack_le_uint32_from = struct_le_I.unpack_from
|
||||||
|
unpack_le_uint64_from = struct_le_Q.unpack_from
|
||||||
|
unpack_be_uint16_from = struct_be_H.unpack_from
|
||||||
|
unpack_be_uint32_from = struct_be_I.unpack_from
|
||||||
|
|
||||||
|
pack_le_int32 = struct_le_i.pack
|
||||||
|
pack_le_int64 = struct_le_q.pack
|
||||||
|
pack_le_uint16 = struct_le_H.pack
|
||||||
|
pack_le_uint32 = struct_le_I.pack
|
||||||
|
pack_le_uint64 = struct_le_Q.pack
|
||||||
|
pack_be_uint16 = struct_be_H.pack
|
||||||
|
pack_be_uint32 = struct_be_I.pack
|
||||||
|
pack_byte = structB.pack
|
||||||
|
|
||||||
|
hex_to_bytes = bytes.fromhex
|
||||||
|
|
||||||
|
|
||||||
|
def pack_varint(n):
|
||||||
|
if n < 253:
|
||||||
|
return pack_byte(n)
|
||||||
|
if n < 65536:
|
||||||
|
return pack_byte(253) + pack_le_uint16(n)
|
||||||
|
if n < 4294967296:
|
||||||
|
return pack_byte(254) + pack_le_uint32(n)
|
||||||
|
return pack_byte(255) + pack_le_uint64(n)
|
||||||
|
|
||||||
|
|
||||||
|
def pack_varbytes(data):
|
||||||
|
return pack_varint(len(data)) + data
|
157
torba/torba/stream.py
Normal file
157
torba/torba/stream.py
Normal file
|
@ -0,0 +1,157 @@
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
class BroadcastSubscription:
|
||||||
|
|
||||||
|
def __init__(self, controller, on_data, on_error, on_done):
|
||||||
|
self._controller = controller
|
||||||
|
self._previous = self._next = None
|
||||||
|
self._on_data = on_data
|
||||||
|
self._on_error = on_error
|
||||||
|
self._on_done = on_done
|
||||||
|
self.is_paused = False
|
||||||
|
self.is_canceled = False
|
||||||
|
self.is_closed = False
|
||||||
|
|
||||||
|
def pause(self):
|
||||||
|
self.is_paused = True
|
||||||
|
|
||||||
|
def resume(self):
|
||||||
|
self.is_paused = False
|
||||||
|
|
||||||
|
def cancel(self):
|
||||||
|
self._controller._cancel(self)
|
||||||
|
self.is_canceled = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def can_fire(self):
|
||||||
|
return not any((self.is_paused, self.is_canceled, self.is_closed))
|
||||||
|
|
||||||
|
def _add(self, data):
|
||||||
|
if self.can_fire and self._on_data is not None:
|
||||||
|
return self._on_data(data)
|
||||||
|
|
||||||
|
def _add_error(self, exception):
|
||||||
|
if self.can_fire and self._on_error is not None:
|
||||||
|
return self._on_error(exception)
|
||||||
|
|
||||||
|
def _close(self):
|
||||||
|
try:
|
||||||
|
if self.can_fire and self._on_done is not None:
|
||||||
|
return self._on_done()
|
||||||
|
finally:
|
||||||
|
self.is_closed = True
|
||||||
|
|
||||||
|
|
||||||
|
class StreamController:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.stream = Stream(self)
|
||||||
|
self._first_subscription = None
|
||||||
|
self._last_subscription = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_listener(self):
|
||||||
|
return self._first_subscription is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _iterate_subscriptions(self):
|
||||||
|
next_sub = self._first_subscription
|
||||||
|
while next_sub is not None:
|
||||||
|
subscription = next_sub
|
||||||
|
next_sub = next_sub._next
|
||||||
|
yield subscription
|
||||||
|
|
||||||
|
def _notify_and_ensure_future(self, notify):
|
||||||
|
tasks = []
|
||||||
|
for subscription in self._iterate_subscriptions:
|
||||||
|
maybe_coroutine = notify(subscription)
|
||||||
|
if asyncio.iscoroutine(maybe_coroutine):
|
||||||
|
tasks.append(maybe_coroutine)
|
||||||
|
if tasks:
|
||||||
|
return asyncio.ensure_future(asyncio.wait(tasks))
|
||||||
|
else:
|
||||||
|
f = asyncio.get_event_loop().create_future()
|
||||||
|
f.set_result(None)
|
||||||
|
return f
|
||||||
|
|
||||||
|
def add(self, event):
|
||||||
|
return self._notify_and_ensure_future(
|
||||||
|
lambda subscription: subscription._add(event)
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_error(self, exception):
|
||||||
|
return self._notify_and_ensure_future(
|
||||||
|
lambda subscription: subscription._add_error(exception)
|
||||||
|
)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
for subscription in self._iterate_subscriptions:
|
||||||
|
subscription._close()
|
||||||
|
|
||||||
|
def _cancel(self, subscription):
|
||||||
|
previous = subscription._previous
|
||||||
|
next_sub = subscription._next
|
||||||
|
if previous is None:
|
||||||
|
self._first_subscription = next_sub
|
||||||
|
else:
|
||||||
|
previous._next = next_sub
|
||||||
|
if next_sub is None:
|
||||||
|
self._last_subscription = previous
|
||||||
|
else:
|
||||||
|
next_sub._previous = previous
|
||||||
|
subscription._next = subscription._previous = subscription
|
||||||
|
|
||||||
|
def _listen(self, on_data, on_error, on_done):
|
||||||
|
subscription = BroadcastSubscription(self, on_data, on_error, on_done)
|
||||||
|
old_last = self._last_subscription
|
||||||
|
self._last_subscription = subscription
|
||||||
|
subscription._previous = old_last
|
||||||
|
subscription._next = None
|
||||||
|
if old_last is None:
|
||||||
|
self._first_subscription = subscription
|
||||||
|
else:
|
||||||
|
old_last._next = subscription
|
||||||
|
return subscription
|
||||||
|
|
||||||
|
|
||||||
|
class Stream:
|
||||||
|
|
||||||
|
def __init__(self, controller):
|
||||||
|
self._controller = controller
|
||||||
|
|
||||||
|
def listen(self, on_data, on_error=None, on_done=None):
|
||||||
|
return self._controller._listen(on_data, on_error, on_done)
|
||||||
|
|
||||||
|
def where(self, condition) -> asyncio.Future:
|
||||||
|
future = asyncio.get_event_loop().create_future()
|
||||||
|
|
||||||
|
def where_test(value):
|
||||||
|
if condition(value):
|
||||||
|
self._cancel_and_callback(subscription, future, value)
|
||||||
|
|
||||||
|
subscription = self.listen(
|
||||||
|
where_test,
|
||||||
|
lambda exception: self._cancel_and_error(subscription, future, exception)
|
||||||
|
)
|
||||||
|
|
||||||
|
return future
|
||||||
|
|
||||||
|
@property
|
||||||
|
def first(self):
|
||||||
|
future = asyncio.get_event_loop().create_future()
|
||||||
|
subscription = self.listen(
|
||||||
|
lambda value: self._cancel_and_callback(subscription, future, value),
|
||||||
|
lambda exception: self._cancel_and_error(subscription, future, exception)
|
||||||
|
)
|
||||||
|
return future
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _cancel_and_callback(subscription: BroadcastSubscription, future: asyncio.Future, value):
|
||||||
|
subscription.cancel()
|
||||||
|
future.set_result(value)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _cancel_and_error(subscription: BroadcastSubscription, future: asyncio.Future, exception):
|
||||||
|
subscription.cancel()
|
||||||
|
future.set_exception(exception)
|
24
torba/torba/tasks.py
Normal file
24
torba/torba/tasks.py
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
from asyncio import Event, get_event_loop
|
||||||
|
|
||||||
|
|
||||||
|
class TaskGroup:
|
||||||
|
|
||||||
|
def __init__(self, loop=None):
|
||||||
|
self._loop = loop or get_event_loop()
|
||||||
|
self._tasks = set()
|
||||||
|
self.done = Event()
|
||||||
|
|
||||||
|
def add(self, coro):
|
||||||
|
task = self._loop.create_task(coro)
|
||||||
|
self._tasks.add(task)
|
||||||
|
self.done.clear()
|
||||||
|
task.add_done_callback(self._remove)
|
||||||
|
return task
|
||||||
|
|
||||||
|
def _remove(self, task):
|
||||||
|
self._tasks.remove(task)
|
||||||
|
len(self._tasks) < 1 and self.done.set()
|
||||||
|
|
||||||
|
def cancel(self):
|
||||||
|
for task in self._tasks:
|
||||||
|
task.cancel()
|
239
torba/torba/testcase.py
Normal file
239
torba/torba/testcase.py
Normal file
|
@ -0,0 +1,239 @@
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import functools
|
||||||
|
import asyncio
|
||||||
|
from asyncio.runners import _cancel_all_tasks # type: ignore
|
||||||
|
import unittest
|
||||||
|
from unittest.case import _Outcome
|
||||||
|
from typing import Optional
|
||||||
|
from torba.orchstr8 import Conductor
|
||||||
|
from torba.orchstr8.node import BlockchainNode, WalletNode
|
||||||
|
from torba.client.baseledger import BaseLedger
|
||||||
|
from torba.client.baseaccount import BaseAccount
|
||||||
|
from torba.client.basemanager import BaseWalletManager
|
||||||
|
from torba.client.wallet import Wallet
|
||||||
|
from torba.client.util import satoshis_to_coins
|
||||||
|
|
||||||
|
|
||||||
|
class ColorHandler(logging.StreamHandler):
|
||||||
|
|
||||||
|
level_color = {
|
||||||
|
logging.DEBUG: "black",
|
||||||
|
logging.INFO: "light_gray",
|
||||||
|
logging.WARNING: "yellow",
|
||||||
|
logging.ERROR: "red"
|
||||||
|
}
|
||||||
|
|
||||||
|
color_code = dict(
|
||||||
|
black=30,
|
||||||
|
red=31,
|
||||||
|
green=32,
|
||||||
|
yellow=33,
|
||||||
|
blue=34,
|
||||||
|
magenta=35,
|
||||||
|
cyan=36,
|
||||||
|
white=37,
|
||||||
|
light_gray='0;37',
|
||||||
|
dark_gray='1;30'
|
||||||
|
)
|
||||||
|
|
||||||
|
def emit(self, record):
|
||||||
|
try:
|
||||||
|
msg = self.format(record)
|
||||||
|
color_name = self.level_color.get(record.levelno, "black")
|
||||||
|
color_code = self.color_code[color_name]
|
||||||
|
stream = self.stream
|
||||||
|
stream.write('\x1b[%sm%s\x1b[0m' % (color_code, msg))
|
||||||
|
stream.write(self.terminator)
|
||||||
|
self.flush()
|
||||||
|
except Exception:
|
||||||
|
self.handleError(record)
|
||||||
|
|
||||||
|
|
||||||
|
HANDLER = ColorHandler(sys.stdout)
|
||||||
|
HANDLER.setFormatter(
|
||||||
|
logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
)
|
||||||
|
logging.getLogger().addHandler(HANDLER)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncioTestCase(unittest.TestCase):
|
||||||
|
# Implementation inspired by discussion:
|
||||||
|
# https://bugs.python.org/issue32972
|
||||||
|
|
||||||
|
maxDiff = None
|
||||||
|
|
||||||
|
async def asyncSetUp(self): # pylint: disable=C0103
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def asyncTearDown(self): # pylint: disable=C0103
|
||||||
|
pass
|
||||||
|
|
||||||
|
def run(self, result=None): # pylint: disable=R0915
|
||||||
|
orig_result = result
|
||||||
|
if result is None:
|
||||||
|
result = self.defaultTestResult()
|
||||||
|
startTestRun = getattr(result, 'startTestRun', None) # pylint: disable=C0103
|
||||||
|
if startTestRun is not None:
|
||||||
|
startTestRun()
|
||||||
|
|
||||||
|
result.startTest(self)
|
||||||
|
|
||||||
|
testMethod = getattr(self, self._testMethodName) # pylint: disable=C0103
|
||||||
|
if (getattr(self.__class__, "__unittest_skip__", False) or
|
||||||
|
getattr(testMethod, "__unittest_skip__", False)):
|
||||||
|
# If the class or method was skipped.
|
||||||
|
try:
|
||||||
|
skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
|
||||||
|
or getattr(testMethod, '__unittest_skip_why__', ''))
|
||||||
|
self._addSkip(result, self, skip_why)
|
||||||
|
finally:
|
||||||
|
result.stopTest(self)
|
||||||
|
return
|
||||||
|
expecting_failure_method = getattr(testMethod,
|
||||||
|
"__unittest_expecting_failure__", False)
|
||||||
|
expecting_failure_class = getattr(self,
|
||||||
|
"__unittest_expecting_failure__", False)
|
||||||
|
expecting_failure = expecting_failure_class or expecting_failure_method
|
||||||
|
outcome = _Outcome(result)
|
||||||
|
|
||||||
|
self.loop = asyncio.new_event_loop() # pylint: disable=W0201
|
||||||
|
asyncio.set_event_loop(self.loop)
|
||||||
|
self.loop.set_debug(True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._outcome = outcome
|
||||||
|
|
||||||
|
with outcome.testPartExecutor(self):
|
||||||
|
self.setUp()
|
||||||
|
self.loop.run_until_complete(self.asyncSetUp())
|
||||||
|
if outcome.success:
|
||||||
|
outcome.expecting_failure = expecting_failure
|
||||||
|
with outcome.testPartExecutor(self, isTest=True):
|
||||||
|
maybe_coroutine = testMethod()
|
||||||
|
if asyncio.iscoroutine(maybe_coroutine):
|
||||||
|
self.loop.run_until_complete(maybe_coroutine)
|
||||||
|
outcome.expecting_failure = False
|
||||||
|
with outcome.testPartExecutor(self):
|
||||||
|
self.loop.run_until_complete(self.asyncTearDown())
|
||||||
|
self.tearDown()
|
||||||
|
|
||||||
|
self.doAsyncCleanups()
|
||||||
|
|
||||||
|
try:
|
||||||
|
_cancel_all_tasks(self.loop)
|
||||||
|
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
|
||||||
|
finally:
|
||||||
|
asyncio.set_event_loop(None)
|
||||||
|
self.loop.close()
|
||||||
|
|
||||||
|
for test, reason in outcome.skipped:
|
||||||
|
self._addSkip(result, test, reason)
|
||||||
|
self._feedErrorsToResult(result, outcome.errors)
|
||||||
|
if outcome.success:
|
||||||
|
if expecting_failure:
|
||||||
|
if outcome.expectedFailure:
|
||||||
|
self._addExpectedFailure(result, outcome.expectedFailure)
|
||||||
|
else:
|
||||||
|
self._addUnexpectedSuccess(result)
|
||||||
|
else:
|
||||||
|
result.addSuccess(self)
|
||||||
|
return result
|
||||||
|
finally:
|
||||||
|
result.stopTest(self)
|
||||||
|
if orig_result is None:
|
||||||
|
stopTestRun = getattr(result, 'stopTestRun', None) # pylint: disable=C0103
|
||||||
|
if stopTestRun is not None:
|
||||||
|
stopTestRun() # pylint: disable=E1102
|
||||||
|
|
||||||
|
# explicitly break reference cycles:
|
||||||
|
# outcome.errors -> frame -> outcome -> outcome.errors
|
||||||
|
# outcome.expectedFailure -> frame -> outcome -> outcome.expectedFailure
|
||||||
|
outcome.errors.clear()
|
||||||
|
outcome.expectedFailure = None
|
||||||
|
|
||||||
|
# clear the outcome, no more needed
|
||||||
|
self._outcome = None
|
||||||
|
|
||||||
|
def doAsyncCleanups(self): # pylint: disable=C0103
|
||||||
|
outcome = self._outcome or _Outcome()
|
||||||
|
while self._cleanups:
|
||||||
|
function, args, kwargs = self._cleanups.pop()
|
||||||
|
with outcome.testPartExecutor(self):
|
||||||
|
maybe_coroutine = function(*args, **kwargs)
|
||||||
|
if asyncio.iscoroutine(maybe_coroutine):
|
||||||
|
self.loop.run_until_complete(maybe_coroutine)
|
||||||
|
|
||||||
|
|
||||||
|
class AdvanceTimeTestCase(AsyncioTestCase):
|
||||||
|
|
||||||
|
async def asyncSetUp(self):
|
||||||
|
self._time = 0 # pylint: disable=W0201
|
||||||
|
self.loop.time = functools.wraps(self.loop.time)(lambda: self._time)
|
||||||
|
await super().asyncSetUp()
|
||||||
|
|
||||||
|
async def advance(self, seconds):
|
||||||
|
while self.loop._ready:
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
self._time += seconds
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
while self.loop._ready:
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
|
||||||
|
class IntegrationTestCase(AsyncioTestCase):
|
||||||
|
|
||||||
|
LEDGER = None
|
||||||
|
MANAGER = None
|
||||||
|
VERBOSITY = logging.WARN
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.conductor: Optional[Conductor] = None
|
||||||
|
self.blockchain: Optional[BlockchainNode] = None
|
||||||
|
self.wallet_node: Optional[WalletNode] = None
|
||||||
|
self.manager: Optional[BaseWalletManager] = None
|
||||||
|
self.ledger: Optional[BaseLedger] = None
|
||||||
|
self.wallet: Optional[Wallet] = None
|
||||||
|
self.account: Optional[BaseAccount] = None
|
||||||
|
|
||||||
|
async def asyncSetUp(self):
|
||||||
|
self.conductor = Conductor(
|
||||||
|
ledger_module=self.LEDGER, manager_module=self.MANAGER, verbosity=self.VERBOSITY
|
||||||
|
)
|
||||||
|
await self.conductor.start_blockchain()
|
||||||
|
self.addCleanup(self.conductor.stop_blockchain)
|
||||||
|
await self.conductor.start_spv()
|
||||||
|
self.addCleanup(self.conductor.stop_spv)
|
||||||
|
await self.conductor.start_wallet()
|
||||||
|
self.addCleanup(self.conductor.stop_wallet)
|
||||||
|
self.blockchain = self.conductor.blockchain_node
|
||||||
|
self.wallet_node = self.conductor.wallet_node
|
||||||
|
self.manager = self.wallet_node.manager
|
||||||
|
self.ledger = self.wallet_node.ledger
|
||||||
|
self.wallet = self.wallet_node.wallet
|
||||||
|
self.account = self.wallet_node.wallet.default_account
|
||||||
|
|
||||||
|
async def assertBalance(self, account, expected_balance: str): # pylint: disable=C0103
|
||||||
|
balance = await account.get_balance()
|
||||||
|
self.assertEqual(satoshis_to_coins(balance), expected_balance)
|
||||||
|
|
||||||
|
def broadcast(self, tx):
|
||||||
|
return self.ledger.broadcast(tx)
|
||||||
|
|
||||||
|
async def on_header(self, height):
|
||||||
|
if self.ledger.headers.height < height:
|
||||||
|
await self.ledger.on_header.where(
|
||||||
|
lambda e: e.height == height
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def on_transaction_id(self, txid, ledger=None):
|
||||||
|
return (ledger or self.ledger).on_transaction.where(
|
||||||
|
lambda e: e.tx.id == txid
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_transaction_address(self, tx, address):
|
||||||
|
return self.ledger.on_transaction.where(
|
||||||
|
lambda e: e.tx.id == tx.id and e.address == address
|
||||||
|
)
|
0
torba/torba/ui/__init__.py
Normal file
0
torba/torba/ui/__init__.py
Normal file
5
torba/torba/workbench/Makefile
Normal file
5
torba/torba/workbench/Makefile
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
all: _blockchain_dock.py _output_dock.py
|
||||||
|
_blockchain_dock.py: blockchain_dock.ui
|
||||||
|
pyside2-uic -d blockchain_dock.ui -o _blockchain_dock.py
|
||||||
|
_output_dock.py: output_dock.ui
|
||||||
|
pyside2-uic -d output_dock.ui -o _output_dock.py
|
1
torba/torba/workbench/__init__.py
Normal file
1
torba/torba/workbench/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
from .application import main
|
70
torba/torba/workbench/_blockchain_dock.py
Normal file
70
torba/torba/workbench/_blockchain_dock.py
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
# Form implementation generated from reading ui file 'blockchain_dock.ui',
|
||||||
|
# licensing of 'blockchain_dock.ui' applies.
|
||||||
|
#
|
||||||
|
# Created: Sun Jan 13 02:56:21 2019
|
||||||
|
# by: pyside2-uic running on PySide2 5.12.0
|
||||||
|
#
|
||||||
|
# WARNING! All changes made in this file will be lost!
|
||||||
|
|
||||||
|
from PySide2 import QtCore, QtGui, QtWidgets
|
||||||
|
|
||||||
|
class Ui_BlockchainDock(object):
|
||||||
|
def setupUi(self, BlockchainDock):
|
||||||
|
BlockchainDock.setObjectName("BlockchainDock")
|
||||||
|
BlockchainDock.resize(416, 167)
|
||||||
|
BlockchainDock.setFloating(False)
|
||||||
|
BlockchainDock.setFeatures(QtWidgets.QDockWidget.AllDockWidgetFeatures)
|
||||||
|
self.dockWidgetContents = QtWidgets.QWidget()
|
||||||
|
self.dockWidgetContents.setObjectName("dockWidgetContents")
|
||||||
|
self.formLayout = QtWidgets.QFormLayout(self.dockWidgetContents)
|
||||||
|
self.formLayout.setObjectName("formLayout")
|
||||||
|
self.generate = QtWidgets.QPushButton(self.dockWidgetContents)
|
||||||
|
self.generate.setObjectName("generate")
|
||||||
|
self.formLayout.setWidget(0, QtWidgets.QFormLayout.LabelRole, self.generate)
|
||||||
|
self.blocks = QtWidgets.QSpinBox(self.dockWidgetContents)
|
||||||
|
self.blocks.setMinimum(1)
|
||||||
|
self.blocks.setMaximum(9999)
|
||||||
|
self.blocks.setProperty("value", 1)
|
||||||
|
self.blocks.setObjectName("blocks")
|
||||||
|
self.formLayout.setWidget(0, QtWidgets.QFormLayout.FieldRole, self.blocks)
|
||||||
|
self.transfer = QtWidgets.QPushButton(self.dockWidgetContents)
|
||||||
|
self.transfer.setObjectName("transfer")
|
||||||
|
self.formLayout.setWidget(1, QtWidgets.QFormLayout.LabelRole, self.transfer)
|
||||||
|
self.horizontalLayout = QtWidgets.QHBoxLayout()
|
||||||
|
self.horizontalLayout.setObjectName("horizontalLayout")
|
||||||
|
self.amount = QtWidgets.QDoubleSpinBox(self.dockWidgetContents)
|
||||||
|
self.amount.setSuffix("")
|
||||||
|
self.amount.setMaximum(9999.99)
|
||||||
|
self.amount.setProperty("value", 10.0)
|
||||||
|
self.amount.setObjectName("amount")
|
||||||
|
self.horizontalLayout.addWidget(self.amount)
|
||||||
|
self.to_label = QtWidgets.QLabel(self.dockWidgetContents)
|
||||||
|
self.to_label.setObjectName("to_label")
|
||||||
|
self.horizontalLayout.addWidget(self.to_label)
|
||||||
|
self.address = QtWidgets.QLineEdit(self.dockWidgetContents)
|
||||||
|
self.address.setObjectName("address")
|
||||||
|
self.horizontalLayout.addWidget(self.address)
|
||||||
|
self.formLayout.setLayout(1, QtWidgets.QFormLayout.FieldRole, self.horizontalLayout)
|
||||||
|
self.invalidate = QtWidgets.QPushButton(self.dockWidgetContents)
|
||||||
|
self.invalidate.setObjectName("invalidate")
|
||||||
|
self.formLayout.setWidget(2, QtWidgets.QFormLayout.LabelRole, self.invalidate)
|
||||||
|
self.block_hash = QtWidgets.QLineEdit(self.dockWidgetContents)
|
||||||
|
self.block_hash.setObjectName("block_hash")
|
||||||
|
self.formLayout.setWidget(2, QtWidgets.QFormLayout.FieldRole, self.block_hash)
|
||||||
|
BlockchainDock.setWidget(self.dockWidgetContents)
|
||||||
|
|
||||||
|
self.retranslateUi(BlockchainDock)
|
||||||
|
QtCore.QMetaObject.connectSlotsByName(BlockchainDock)
|
||||||
|
|
||||||
|
def retranslateUi(self, BlockchainDock):
|
||||||
|
BlockchainDock.setWindowTitle(QtWidgets.QApplication.translate("BlockchainDock", "Blockchain", None, -1))
|
||||||
|
self.generate.setText(QtWidgets.QApplication.translate("BlockchainDock", "generate", None, -1))
|
||||||
|
self.blocks.setSuffix(QtWidgets.QApplication.translate("BlockchainDock", " block(s)", None, -1))
|
||||||
|
self.transfer.setText(QtWidgets.QApplication.translate("BlockchainDock", "transfer", None, -1))
|
||||||
|
self.to_label.setText(QtWidgets.QApplication.translate("BlockchainDock", "to", None, -1))
|
||||||
|
self.address.setPlaceholderText(QtWidgets.QApplication.translate("BlockchainDock", "recipient address", None, -1))
|
||||||
|
self.invalidate.setText(QtWidgets.QApplication.translate("BlockchainDock", "invalidate", None, -1))
|
||||||
|
self.block_hash.setPlaceholderText(QtWidgets.QApplication.translate("BlockchainDock", "block hash", None, -1))
|
||||||
|
|
34
torba/torba/workbench/_output_dock.py
Normal file
34
torba/torba/workbench/_output_dock.py
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
# Form implementation generated from reading ui file 'output_dock.ui',
|
||||||
|
# licensing of 'output_dock.ui' applies.
|
||||||
|
#
|
||||||
|
# Created: Sat Oct 27 16:41:03 2018
|
||||||
|
# by: pyside2-uic running on PySide2 5.11.2
|
||||||
|
#
|
||||||
|
# WARNING! All changes made in this file will be lost!
|
||||||
|
|
||||||
|
from PySide2 import QtCore, QtGui, QtWidgets
|
||||||
|
|
||||||
|
class Ui_OutputDock(object):
|
||||||
|
def setupUi(self, OutputDock):
|
||||||
|
OutputDock.setObjectName("OutputDock")
|
||||||
|
OutputDock.resize(700, 397)
|
||||||
|
OutputDock.setFloating(False)
|
||||||
|
OutputDock.setFeatures(QtWidgets.QDockWidget.AllDockWidgetFeatures)
|
||||||
|
self.dockWidgetContents = QtWidgets.QWidget()
|
||||||
|
self.dockWidgetContents.setObjectName("dockWidgetContents")
|
||||||
|
self.horizontalLayout = QtWidgets.QHBoxLayout(self.dockWidgetContents)
|
||||||
|
self.horizontalLayout.setObjectName("horizontalLayout")
|
||||||
|
self.textEdit = QtWidgets.QTextEdit(self.dockWidgetContents)
|
||||||
|
self.textEdit.setReadOnly(True)
|
||||||
|
self.textEdit.setObjectName("textEdit")
|
||||||
|
self.horizontalLayout.addWidget(self.textEdit)
|
||||||
|
OutputDock.setWidget(self.dockWidgetContents)
|
||||||
|
|
||||||
|
self.retranslateUi(OutputDock)
|
||||||
|
QtCore.QMetaObject.connectSlotsByName(OutputDock)
|
||||||
|
|
||||||
|
def retranslateUi(self, OutputDock):
|
||||||
|
OutputDock.setWindowTitle(QtWidgets.QApplication.translate("OutputDock", "Output", None, -1))
|
||||||
|
|
401
torba/torba/workbench/application.py
Normal file
401
torba/torba/workbench/application.py
Normal file
|
@ -0,0 +1,401 @@
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
|
||||||
|
from PySide2 import QtCore, QtGui, QtWidgets, QtNetwork, QtWebSockets, QtSvg
|
||||||
|
|
||||||
|
from torba.workbench._output_dock import Ui_OutputDock as OutputDock
|
||||||
|
from torba.workbench._blockchain_dock import Ui_BlockchainDock as BlockchainDock
|
||||||
|
|
||||||
|
|
||||||
|
def dict_to_post_data(d):
|
||||||
|
query = QtCore.QUrlQuery()
|
||||||
|
for key, value in d.items():
|
||||||
|
query.addQueryItem(str(key), str(value))
|
||||||
|
return QtCore.QByteArray(query.toString().encode())
|
||||||
|
|
||||||
|
|
||||||
|
class LoggingOutput(QtWidgets.QDockWidget, OutputDock):
|
||||||
|
|
||||||
|
def __init__(self, title, parent):
|
||||||
|
super().__init__(parent)
|
||||||
|
self.setupUi(self)
|
||||||
|
self.setWindowTitle(title)
|
||||||
|
|
||||||
|
|
||||||
|
class BlockchainControls(QtWidgets.QDockWidget, BlockchainDock):
|
||||||
|
|
||||||
|
def __init__(self, parent):
|
||||||
|
super().__init__(parent)
|
||||||
|
self.setupUi(self)
|
||||||
|
self.generate.clicked.connect(self.on_generate)
|
||||||
|
self.transfer.clicked.connect(self.on_transfer)
|
||||||
|
|
||||||
|
def on_generate(self):
|
||||||
|
print('generating')
|
||||||
|
self.parent().run_command('generate', blocks=self.blocks.value())
|
||||||
|
|
||||||
|
def on_transfer(self):
|
||||||
|
print('transfering')
|
||||||
|
self.parent().run_command('transfer', amount=self.amount.value())
|
||||||
|
|
||||||
|
|
||||||
|
class Arrow(QtWidgets.QGraphicsLineItem):
|
||||||
|
|
||||||
|
def __init__(self, start_node, end_node, parent=None, scene=None):
|
||||||
|
super().__init__(parent, scene)
|
||||||
|
self.start_node = start_node
|
||||||
|
self.start_node.connect_arrow(self)
|
||||||
|
self.end_node = end_node
|
||||||
|
self.end_node.connect_arrow(self)
|
||||||
|
self.arrow_head = QtGui.QPolygonF()
|
||||||
|
self.setFlag(QtWidgets.QGraphicsItem.ItemIsSelectable, True)
|
||||||
|
self.setZValue(-1000.0)
|
||||||
|
self.arrow_color = QtCore.Qt.black
|
||||||
|
self.setPen(QtGui.QPen(
|
||||||
|
self.arrow_color, 2, QtCore.Qt.SolidLine, QtCore.Qt.RoundCap, QtCore.Qt.RoundJoin
|
||||||
|
))
|
||||||
|
|
||||||
|
def boundingRect(self):
|
||||||
|
extra = (self.pen().width() + 20) / 2.0
|
||||||
|
p1 = self.line().p1()
|
||||||
|
p2 = self.line().p2()
|
||||||
|
size = QtCore.QSizeF(p2.x() - p1.x(), p2.y() - p1.y())
|
||||||
|
return QtCore.QRectF(p1, size).normalized().adjusted(-extra, -extra, extra, extra)
|
||||||
|
|
||||||
|
def shape(self):
|
||||||
|
path = super().shape()
|
||||||
|
path.addPolygon(self.arrow_head)
|
||||||
|
return path
|
||||||
|
|
||||||
|
def update_position(self):
|
||||||
|
line = QtCore.QLineF(
|
||||||
|
self.mapFromItem(self.start_node, 0, 0),
|
||||||
|
self.mapFromItem(self.end_node, 0, 0)
|
||||||
|
)
|
||||||
|
self.setLine(line)
|
||||||
|
|
||||||
|
def paint(self, painter, option, widget=None):
|
||||||
|
if self.start_node.collidesWithItem(self.end_node):
|
||||||
|
return
|
||||||
|
|
||||||
|
start_node = self.start_node
|
||||||
|
end_node = self.end_node
|
||||||
|
color = self.arrow_color
|
||||||
|
pen = self.pen()
|
||||||
|
pen.setColor(self.arrow_color)
|
||||||
|
arrow_size = 20.0
|
||||||
|
painter.setPen(pen)
|
||||||
|
painter.setBrush(self.arrow_color)
|
||||||
|
|
||||||
|
end_rectangle = end_node.sceneBoundingRect()
|
||||||
|
start_center = start_node.sceneBoundingRect().center()
|
||||||
|
end_center = end_rectangle.center()
|
||||||
|
center_line = QtCore.QLineF(start_center, end_center)
|
||||||
|
end_polygon = QtGui.QPolygonF(end_rectangle)
|
||||||
|
p1 = end_polygon.at(0)
|
||||||
|
|
||||||
|
intersect_point = QtCore.QPointF()
|
||||||
|
for p2 in end_polygon:
|
||||||
|
poly_line = QtCore.QLineF(p1, p2)
|
||||||
|
intersect_type, intersect_point = poly_line.intersect(center_line)
|
||||||
|
if intersect_type == QtCore.QLineF.BoundedIntersection:
|
||||||
|
break
|
||||||
|
p1 = p2
|
||||||
|
|
||||||
|
self.setLine(QtCore.QLineF(intersect_point, start_center))
|
||||||
|
line = self.line()
|
||||||
|
|
||||||
|
angle = math.acos(line.dx() / line.length())
|
||||||
|
if line.dy() >= 0:
|
||||||
|
angle = (math.pi * 2.0) - angle
|
||||||
|
|
||||||
|
arrow_p1 = line.p1() + QtCore.QPointF(
|
||||||
|
math.sin(angle + math.pi / 3.0) * arrow_size,
|
||||||
|
math.cos(angle + math.pi / 3.0) * arrow_size
|
||||||
|
)
|
||||||
|
arrow_p2 = line.p1() + QtCore.QPointF(
|
||||||
|
math.sin(angle + math.pi - math.pi / 3.0) * arrow_size,
|
||||||
|
math.cos(angle + math.pi - math.pi / 3.0) * arrow_size
|
||||||
|
)
|
||||||
|
|
||||||
|
self.arrow_head.clear()
|
||||||
|
for point in [line.p1(), arrow_p1, arrow_p2]:
|
||||||
|
self.arrow_head.append(point)
|
||||||
|
|
||||||
|
painter.drawLine(line)
|
||||||
|
painter.drawPolygon(self.arrow_head)
|
||||||
|
if self.isSelected():
|
||||||
|
painter.setPen(QtGui.QPen(color, 1, QtCore.Qt.DashLine))
|
||||||
|
line = QtCore.QLineF(line)
|
||||||
|
line.translate(0, 4.0)
|
||||||
|
painter.drawLine(line)
|
||||||
|
line.translate(0, -8.0)
|
||||||
|
painter.drawLine(line)
|
||||||
|
|
||||||
|
|
||||||
|
ONLINE_COLOR = "limegreen"
|
||||||
|
OFFLINE_COLOR = "lightsteelblue"
|
||||||
|
|
||||||
|
|
||||||
|
class NodeItem(QtSvg.QGraphicsSvgItem):
|
||||||
|
|
||||||
|
def __init__(self, context_menu):
|
||||||
|
super().__init__()
|
||||||
|
self._port = ''
|
||||||
|
self._color = OFFLINE_COLOR
|
||||||
|
self.context_menu = context_menu
|
||||||
|
self.arrows = set()
|
||||||
|
self.renderer = QtSvg.QSvgRenderer()
|
||||||
|
self.update_svg()
|
||||||
|
self.setSharedRenderer(self.renderer)
|
||||||
|
#self.setScale(2.0)
|
||||||
|
#self.setTransformOriginPoint(24, 24)
|
||||||
|
self.setFlag(QtWidgets.QGraphicsItem.ItemIsMovable, True)
|
||||||
|
self.setFlag(QtWidgets.QGraphicsItem.ItemIsSelectable, True)
|
||||||
|
|
||||||
|
def get_svg(self):
|
||||||
|
return self.SVG.format(
|
||||||
|
port=self.port,
|
||||||
|
color=self._color
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_svg(self):
|
||||||
|
self.renderer.load(QtCore.QByteArray(self.get_svg().encode()))
|
||||||
|
self.update()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def port(self):
|
||||||
|
return self._port
|
||||||
|
|
||||||
|
@port.setter
|
||||||
|
def port(self, port):
|
||||||
|
self._port = port
|
||||||
|
self.update_svg()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def online(self):
|
||||||
|
return self._color == ONLINE_COLOR
|
||||||
|
|
||||||
|
@online.setter
|
||||||
|
def online(self, online):
|
||||||
|
if online:
|
||||||
|
self._color = ONLINE_COLOR
|
||||||
|
else:
|
||||||
|
self._color = OFFLINE_COLOR
|
||||||
|
self.update_svg()
|
||||||
|
|
||||||
|
def connect_arrow(self, arrow):
|
||||||
|
self.arrows.add(arrow)
|
||||||
|
|
||||||
|
def disconnect_arrow(self, arrow):
|
||||||
|
self.arrows.discard(arrow)
|
||||||
|
|
||||||
|
def contextMenuEvent(self, event):
|
||||||
|
self.scene().clearSelection()
|
||||||
|
self.setSelected(True)
|
||||||
|
self.myContextMenu.exec_(event.screenPos())
|
||||||
|
|
||||||
|
def itemChange(self, change, value):
|
||||||
|
if change == QtWidgets.QGraphicsItem.ItemPositionChange:
|
||||||
|
for arrow in self.arrows:
|
||||||
|
arrow.update_position()
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class BlockchainNode(NodeItem):
|
||||||
|
SVG = """
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" width="48" height="48" viewBox="0 0 24 24">
|
||||||
|
<path fill="white" d="M2 0h20v24H2z"/>
|
||||||
|
<path fill="{color}" d="M8 7A 5.5 5 0 0 0 8 17h8A 5.5 5 0 0 0 16 7z"/>
|
||||||
|
<path d="M17 7h-4v2h4c1.65 0 3 1.35 3 3s-1.35 3-3 3h-4v2h4c2.76 0 5-2.24 5-5s-2.24-5-5-5zm-6 8H7c-1.65 0-3-1.35-3-3s1.35-3 3-3h4V7H7c-2.76 0-5 2.24-5 5s2.24 5 5 5h4v-2zm-3-4h8v2H8z"/>
|
||||||
|
<text x="4" y="6" font-size="6" font-weight="900">{port}</text>
|
||||||
|
<text x="4" y="23" font-size="6" font-weight="900">{block}</text>
|
||||||
|
</svg>
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args):
|
||||||
|
self._block_height = ''
|
||||||
|
super().__init__(*args)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def block_height(self):
|
||||||
|
return self._block_height
|
||||||
|
|
||||||
|
@block_height.setter
|
||||||
|
def block_height(self, block_height):
|
||||||
|
self._block_height = block_height
|
||||||
|
self.update_svg()
|
||||||
|
|
||||||
|
def get_svg(self):
|
||||||
|
return self.SVG.format(
|
||||||
|
port=self.port,
|
||||||
|
block=self.block_height,
|
||||||
|
color=self._color
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SPVNode(NodeItem):
|
||||||
|
SVG = """
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" width="48" height="48" viewBox="0 0 24 24">
|
||||||
|
<path fill="white" d="M3 1h18v10H3z"/>
|
||||||
|
<g transform="translate(0 3)">
|
||||||
|
<path fill="{color}" d="M19.21 12.04l-1.53-.11-.3-1.5C16.88 7.86 14.62 6 12 6 9.94 6 8.08 7.14 7.12 8.96l-.5.95-1.07.11C3.53 10.24 2 11.95 2 14c0 2.21 1.79 4 4 4h13c1.65 0 3-1.35 3-3 0-1.55-1.22-2.86-2.79-2.96z"/>
|
||||||
|
<path d="M19.35 10.04C18.67 6.59 15.64 4 12 4 9.11 4 6.6 5.64 5.35 8.04 2.34 8.36 0 10.91 0 14c0 3.31 2.69 6 6 6h13c2.76 0 5-2.24 5-5 0-2.64-2.05-4.78-4.65-4.96zM19 18H6c-2.21 0-4-1.79-4-4 0-2.05 1.53-3.76 3.56-3.97l1.07-.11.5-.95C8.08 7.14 9.94 6 12 6c2.62 0 4.88 1.86 5.39 4.43l.3 1.5 1.53.11c1.56.1 2.78 1.41 2.78 2.96 0 1.65-1.35 3-3 3z"/>
|
||||||
|
</g>
|
||||||
|
<text x="4" y="6" font-size="6" font-weight="900">{port}</text>
|
||||||
|
</svg>
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args):
|
||||||
|
super().__init__(*args)
|
||||||
|
|
||||||
|
|
||||||
|
class WalletNode(NodeItem):
|
||||||
|
SVG = """
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" width="48" height="48" viewBox="0 0 24 24">
|
||||||
|
<path fill="white" d="M3 3h17v17H3z"/>
|
||||||
|
<g transform="translate(0 -3)">
|
||||||
|
<path fill="{color}" d="M13 17c-1.1 0-2-.9-2-2V9c0-1.1.9-2 2-2h6V5H5v14h14v-2h-6z"/>
|
||||||
|
<path d="M21 7.28V5c0-1.1-.9-2-2-2H5c-1.11 0-2 .9-2 2v14c0 1.1.89 2 2 2h14c1.1 0 2-.9 2-2v-2.28c.59-.35 1-.98 1-1.72V9c0-.74-.41-1.38-1-1.72zM20 9v6h-7V9h7zM5 19V5h14v2h-6c-1.1 0-2 .9-2 2v6c0 1.1.9 2 2 2h6v2H5z"/>
|
||||||
|
<circle cx="16" cy="12" r="1.5"/>
|
||||||
|
</g>
|
||||||
|
<text x="4" y="23" font-size="6" font-weight="900">{coins}</text>
|
||||||
|
</svg>
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args):
|
||||||
|
self._coins = '--'
|
||||||
|
super().__init__(*args)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def coins(self):
|
||||||
|
return self._coins
|
||||||
|
|
||||||
|
@coins.setter
|
||||||
|
def coins(self, coins):
|
||||||
|
self._coins = coins
|
||||||
|
self.update_svg()
|
||||||
|
|
||||||
|
def get_svg(self):
|
||||||
|
return self.SVG.format(
|
||||||
|
coins=self.coins,
|
||||||
|
color=self._color
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Stage(QtWidgets.QGraphicsScene):
|
||||||
|
|
||||||
|
def __init__(self, parent):
|
||||||
|
super().__init__(parent)
|
||||||
|
self.blockchain = b = BlockchainNode(None)
|
||||||
|
b.port = ''
|
||||||
|
b.block_height = ''
|
||||||
|
b.setZValue(0)
|
||||||
|
b.setPos(-25, -100)
|
||||||
|
self.addItem(b)
|
||||||
|
self.spv = s = SPVNode(None)
|
||||||
|
s.port = ''
|
||||||
|
s.setZValue(0)
|
||||||
|
self.addItem(s)
|
||||||
|
s.setPos(-10, -10)
|
||||||
|
self.wallet = w = WalletNode(None)
|
||||||
|
w.coins = ''
|
||||||
|
w.setZValue(0)
|
||||||
|
w.update_svg()
|
||||||
|
self.addItem(w)
|
||||||
|
w.setPos(0, 100)
|
||||||
|
|
||||||
|
self.addItem(Arrow(b, s))
|
||||||
|
self.addItem(Arrow(s, w))
|
||||||
|
|
||||||
|
|
||||||
|
class Orchstr8Workbench(QtWidgets.QMainWindow):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.stage = Stage(self)
|
||||||
|
self.view = QtWidgets.QGraphicsView(self.stage)
|
||||||
|
self.status_bar = QtWidgets.QStatusBar(self)
|
||||||
|
|
||||||
|
self.setWindowTitle('Orchstr8 Workbench')
|
||||||
|
self.setCentralWidget(self.view)
|
||||||
|
self.setStatusBar(self.status_bar)
|
||||||
|
|
||||||
|
self.block_height = self.make_status_label('Height: -- ')
|
||||||
|
self.user_balance = self.make_status_label('User Balance: -- ')
|
||||||
|
self.mining_balance = self.make_status_label('Mining Balance: -- ')
|
||||||
|
|
||||||
|
self.wallet_log = LoggingOutput('Wallet', self)
|
||||||
|
self.addDockWidget(QtCore.Qt.LeftDockWidgetArea, self.wallet_log)
|
||||||
|
self.spv_log = LoggingOutput('SPV Server', self)
|
||||||
|
self.addDockWidget(QtCore.Qt.LeftDockWidgetArea, self.spv_log)
|
||||||
|
self.blockchain_log = LoggingOutput('Blockchain', self)
|
||||||
|
self.addDockWidget(QtCore.Qt.LeftDockWidgetArea, self.blockchain_log)
|
||||||
|
|
||||||
|
self.blockchain_controls = BlockchainControls(self)
|
||||||
|
self.addDockWidget(QtCore.Qt.RightDockWidgetArea, self.blockchain_controls)
|
||||||
|
|
||||||
|
self.network = QtNetwork.QNetworkAccessManager(self)
|
||||||
|
self.socket = QtWebSockets.QWebSocket()
|
||||||
|
self.socket.connected.connect(lambda: self.run_command('start'))
|
||||||
|
self.socket.error.connect(lambda e: print(f'errored: {e}'))
|
||||||
|
self.socket.textMessageReceived.connect(self.on_message)
|
||||||
|
self.socket.open('ws://localhost:7954/log')
|
||||||
|
|
||||||
|
def make_status_label(self, text):
|
||||||
|
label = QtWidgets.QLabel(text)
|
||||||
|
label.setFrameStyle(QtWidgets.QLabel.Panel | QtWidgets.QLabel.Sunken)
|
||||||
|
self.status_bar.addPermanentWidget(label)
|
||||||
|
return label
|
||||||
|
|
||||||
|
def on_message(self, text):
|
||||||
|
msg = json.loads(text)
|
||||||
|
if msg['type'] == 'status':
|
||||||
|
self.stage.wallet.coins = msg['balance']
|
||||||
|
self.stage.blockchain.block_height = msg['height']
|
||||||
|
self.block_height.setText(f"Height: {msg['height']} ")
|
||||||
|
self.user_balance.setText(f"User Balance: {msg['balance']} ")
|
||||||
|
self.mining_balance.setText(f"Mining Balance: {msg['miner']} ")
|
||||||
|
elif msg['type'] == 'service':
|
||||||
|
node = {
|
||||||
|
'blockchain': self.stage.blockchain,
|
||||||
|
'spv': self.stage.spv,
|
||||||
|
'wallet': self.stage.wallet
|
||||||
|
}[msg['name']]
|
||||||
|
node.online = True
|
||||||
|
node.port = f":{msg['port']}"
|
||||||
|
elif msg['type'] == 'log':
|
||||||
|
log = {
|
||||||
|
'blockchain': self.blockchain_log,
|
||||||
|
'electrumx': self.spv_log,
|
||||||
|
'lbryumx': self.spv_log,
|
||||||
|
'Controller': self.spv_log,
|
||||||
|
'LBRYBlockProcessor': self.spv_log,
|
||||||
|
'LBCDaemon': self.spv_log,
|
||||||
|
}.get(msg['name'].split('.')[-1], self.wallet_log)
|
||||||
|
log.textEdit.append(msg['message'])
|
||||||
|
|
||||||
|
def run_command(self, command, **kwargs):
|
||||||
|
request = QtNetwork.QNetworkRequest(QtCore.QUrl('http://localhost:7954/'+command))
|
||||||
|
request.setHeader(QtNetwork.QNetworkRequest.ContentTypeHeader, "application/x-www-form-urlencoded")
|
||||||
|
reply = self.network.post(request, dict_to_post_data(kwargs))
|
||||||
|
# reply.finished.connect(cb)
|
||||||
|
reply.error.connect(self.on_command_error)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def on_command_error(error):
|
||||||
|
print('failed executing command:')
|
||||||
|
print(error)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
app = QtWidgets.QApplication(sys.argv)
|
||||||
|
workbench = Orchstr8Workbench()
|
||||||
|
workbench.setGeometry(100, 100, 1200, 600)
|
||||||
|
workbench.show()
|
||||||
|
return app.exec_()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
104
torba/torba/workbench/blockchain_dock.ui
Normal file
104
torba/torba/workbench/blockchain_dock.ui
Normal file
|
@ -0,0 +1,104 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<ui version="4.0">
|
||||||
|
<class>BlockchainDock</class>
|
||||||
|
<widget class="QDockWidget" name="BlockchainDock">
|
||||||
|
<property name="geometry">
|
||||||
|
<rect>
|
||||||
|
<x>0</x>
|
||||||
|
<y>0</y>
|
||||||
|
<width>416</width>
|
||||||
|
<height>167</height>
|
||||||
|
</rect>
|
||||||
|
</property>
|
||||||
|
<property name="floating">
|
||||||
|
<bool>false</bool>
|
||||||
|
</property>
|
||||||
|
<property name="features">
|
||||||
|
<set>QDockWidget::AllDockWidgetFeatures</set>
|
||||||
|
</property>
|
||||||
|
<property name="windowTitle">
|
||||||
|
<string>Blockchain</string>
|
||||||
|
</property>
|
||||||
|
<widget class="QWidget" name="dockWidgetContents">
|
||||||
|
<layout class="QFormLayout" name="formLayout">
|
||||||
|
<item row="0" column="0">
|
||||||
|
<widget class="QPushButton" name="generate">
|
||||||
|
<property name="text">
|
||||||
|
<string>generate</string>
|
||||||
|
</property>
|
||||||
|
</widget>
|
||||||
|
</item>
|
||||||
|
<item row="0" column="1">
|
||||||
|
<widget class="QSpinBox" name="blocks">
|
||||||
|
<property name="suffix">
|
||||||
|
<string> block(s)</string>
|
||||||
|
</property>
|
||||||
|
<property name="minimum">
|
||||||
|
<number>1</number>
|
||||||
|
</property>
|
||||||
|
<property name="maximum">
|
||||||
|
<number>9999</number>
|
||||||
|
</property>
|
||||||
|
<property name="value">
|
||||||
|
<number>1</number>
|
||||||
|
</property>
|
||||||
|
</widget>
|
||||||
|
</item>
|
||||||
|
<item row="1" column="0">
|
||||||
|
<widget class="QPushButton" name="transfer">
|
||||||
|
<property name="text">
|
||||||
|
<string>transfer</string>
|
||||||
|
</property>
|
||||||
|
</widget>
|
||||||
|
</item>
|
||||||
|
<item row="1" column="1">
|
||||||
|
<layout class="QHBoxLayout" name="horizontalLayout">
|
||||||
|
<item>
|
||||||
|
<widget class="QDoubleSpinBox" name="amount">
|
||||||
|
<property name="suffix">
|
||||||
|
<string/>
|
||||||
|
</property>
|
||||||
|
<property name="maximum">
|
||||||
|
<double>9999.989999999999782</double>
|
||||||
|
</property>
|
||||||
|
<property name="value">
|
||||||
|
<double>10.000000000000000</double>
|
||||||
|
</property>
|
||||||
|
</widget>
|
||||||
|
</item>
|
||||||
|
<item>
|
||||||
|
<widget class="QLabel" name="to_label">
|
||||||
|
<property name="text">
|
||||||
|
<string>to</string>
|
||||||
|
</property>
|
||||||
|
</widget>
|
||||||
|
</item>
|
||||||
|
<item>
|
||||||
|
<widget class="QLineEdit" name="address">
|
||||||
|
<property name="placeholderText">
|
||||||
|
<string>recipient address</string>
|
||||||
|
</property>
|
||||||
|
</widget>
|
||||||
|
</item>
|
||||||
|
</layout>
|
||||||
|
</item>
|
||||||
|
<item row="2" column="0">
|
||||||
|
<widget class="QPushButton" name="invalidate">
|
||||||
|
<property name="text">
|
||||||
|
<string>invalidate</string>
|
||||||
|
</property>
|
||||||
|
</widget>
|
||||||
|
</item>
|
||||||
|
<item row="2" column="1">
|
||||||
|
<widget class="QLineEdit" name="block_hash">
|
||||||
|
<property name="placeholderText">
|
||||||
|
<string>block hash</string>
|
||||||
|
</property>
|
||||||
|
</widget>
|
||||||
|
</item>
|
||||||
|
</layout>
|
||||||
|
</widget>
|
||||||
|
</widget>
|
||||||
|
<resources/>
|
||||||
|
<connections/>
|
||||||
|
</ui>
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue