SPV support

This commit is contained in:
Lex Berezhny 2018-06-25 09:54:35 -04:00
parent 22b897db2e
commit aba3ed7ca0
3 changed files with 115 additions and 41 deletions

View file

@ -29,11 +29,21 @@ class BasicTransactionTests(IntegrationTestCase):
[self.ledger.transaction_class.output_class.pay_pubkey_hash(2*COIN, self.ledger.address_to_hash160(address))], [self.ledger.transaction_class.output_class.pay_pubkey_hash(2*COIN, self.ledger.address_to_hash160(address))],
[account1], account1 [account1], account1
).asFuture(asyncio.get_event_loop()) ).asFuture(asyncio.get_event_loop())
await self.blockchain.decode_raw_transaction(tx)
await self.broadcast(tx) await self.broadcast(tx)
await self.on_transaction(tx.hex_id.decode()) #mempool await self.on_transaction(tx.hex_id.decode()) #mempool
tx2 = await self.ledger.transaction_class.pay(
[self.ledger.transaction_class.output_class.pay_pubkey_hash(1*COIN, self.ledger.address_to_hash160(address))],
[account1], account1
).asFuture(asyncio.get_event_loop())
await self.broadcast(tx2)
await self.on_transaction(tx2.hex_id.decode()) #mempool
await self.blockchain.generate(1) await self.blockchain.generate(1)
await self.on_transaction(tx.hex_id.decode()) #confirmed await self.on_transaction(tx.hex_id.decode()) #confirmed
self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 3.5) #self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 3.5)
self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 2.0) #self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 2.0)
self.assertTrue(await self.ledger.is_valid_transaction(tx, 202).asFuture(asyncio.get_event_loop()))
self.assertTrue(await self.ledger.is_valid_transaction(tx2, 202).asFuture(asyncio.get_event_loop()))

View file

@ -44,11 +44,23 @@ class SQLiteMixin(object):
for column, value in data.items(): for column, value in data.items():
columns.append(column) columns.append(column)
values.append(value) values.append(value)
sql = "REPLACE INTO {} ({}) VALUES ({})".format( sql = "INSERT INTO {} ({}) VALUES ({})".format(
table, ', '.join(columns), ', '.join(['?'] * len(values)) table, ', '.join(columns), ', '.join(['?'] * len(values))
) )
return sql, values return sql, values
def _update_sql(self, table, data, where, constraints):
# type: (str, dict) -> tuple[str, List]
columns, values = [], []
for column, value in data.items():
columns.append("{} = ?".format(column))
values.append(value)
values.extend(constraints)
sql = "UPDATE {} SET {} WHERE {}".format(
table, ', '.join(columns), where
)
return sql, values
@defer.inlineCallbacks @defer.inlineCallbacks
def query_one_value_list(self, query, params): def query_one_value_list(self, query, params):
# type: (str, Union[dict,tuple]) -> defer.Deferred[List] # type: (str, Union[dict,tuple]) -> defer.Deferred[List]
@ -143,21 +155,21 @@ class BaseDatabase(SQLiteMixin):
CREATE_TXI_TABLE CREATE_TXI_TABLE
) )
def add_transaction(self, address, hash, tx, height, is_verified): def save_transaction_io(self, save_tx, tx, height, is_verified, address, hash, history):
def _steps(t): def _steps(t):
current_height = t.execute("SELECT height FROM tx WHERE txhash=?", (sqlite3.Binary(tx.hash),)).fetchone() if save_tx == 'insert':
if current_height is None:
t.execute(*self._insert_sql('tx', { t.execute(*self._insert_sql('tx', {
'txhash': sqlite3.Binary(tx.hash), 'txhash': sqlite3.Binary(tx.hash),
'raw': sqlite3.Binary(tx.raw), 'raw': sqlite3.Binary(tx.raw),
'height': height, 'height': height,
'is_verified': is_verified 'is_verified': is_verified
})) }))
elif current_height[0] != height: elif save_tx == 'update':
t.execute("UPDATE tx SET height = :height WHERE txhash = :txhash", { t.execute(*self._update_sql("tx", {
'txhash': sqlite3.Binary(tx.hash), 'height': height, 'is_verified': is_verified
'height': height, }, 'WHERE txhash = ?', (sqlite3.Binary(tx.hash),)
}) ))
existing_txos = list(map(itemgetter(0), t.execute( existing_txos = list(map(itemgetter(0), t.execute(
"SELECT position FROM txo WHERE txhash = ?", "SELECT position FROM txo WHERE txhash = ?",
@ -177,7 +189,7 @@ class BaseDatabase(SQLiteMixin):
})) }))
elif txo.script.is_pay_script_hash: elif txo.script.is_pay_script_hash:
# TODO: implement script hash payments # TODO: implement script hash payments
print('Database.add_transaction pay script hash is not implemented!') print('Database.save_transaction_io: pay script hash is not implemented!')
existing_txis = [txi[0] for txi in t.execute( existing_txis = [txi[0] for txi in t.execute(
"SELECT txoid FROM txi WHERE txhash = ? AND address = ?", "SELECT txoid FROM txi WHERE txhash = ? AND address = ?",
@ -195,8 +207,23 @@ class BaseDatabase(SQLiteMixin):
'txoid': txoid[0], 'txoid': txoid[0],
})) }))
t.execute(
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
(sqlite3.Binary(history), history.count(b':')//2, sqlite3.Binary(address))
)
return self.db.runInteraction(_steps) return self.db.runInteraction(_steps)
@defer.inlineCallbacks
def get_transaction(self, txhash):
result = yield self.db.runQuery(
"SELECT raw, height, is_verified FROM tx WHERE txhash = ?", (sqlite3.Binary(txhash),)
)
if result:
defer.returnValue(*result[0])
else:
defer.returnValue((None, None, False))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_balance_for_account(self, account): def get_balance_for_account(self, account):
result = yield self.db.runQuery( result = yield self.db.runQuery(
@ -296,9 +323,3 @@ class BaseDatabase(SQLiteMixin):
('address', 'account', 'chain', 'position', 'pubkey', 'history', 'used_times'), ('address', 'account', 'chain', 'position', 'pubkey', 'history', 'used_times'),
{'address': sqlite3.Binary(address)} {'address': sqlite3.Binary(address)}
) )
def set_address_history(self, address, history):
return self.db.runOperation(
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
(sqlite3.Binary(history), history.count(b':')//2, sqlite3.Binary(address))
)

View file

@ -68,6 +68,8 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
self._on_transaction_controller = StreamController() self._on_transaction_controller = StreamController()
self.on_transaction = self._on_transaction_controller.stream self.on_transaction = self._on_transaction_controller.stream
self._transaction_processing_locks = {}
@classmethod @classmethod
def get_id(cls): def get_id(cls):
return '{}_{}'.format(cls.symbol.lower(), cls.network_name.lower()) return '{}_{}'.format(cls.symbol.lower(), cls.network_name.lower())
@ -101,14 +103,6 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
""" Fee for the transaction header and all outputs; without inputs. """ """ Fee for the transaction header and all outputs; without inputs. """
return self.fee_per_byte * tx.base_size return self.fee_per_byte * tx.base_size
@defer.inlineCallbacks
def add_transaction(self, address, transaction, height):
# type: (bytes, basetransaction.BaseTransaction, int) -> None
yield self.db.add_transaction(
address, self.address_to_hash160(address), transaction, height, False
)
self._on_transaction_controller.add(transaction)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_account(self, account): # type: (baseaccount.BaseAccount) -> None def add_account(self, account): # type: (baseaccount.BaseAccount) -> None
self.accounts.add(account) self.accounts.add(account)
@ -139,7 +133,10 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_local_status(self, address): def get_local_status(self, address):
address_details = yield self.db.get_address(address) address_details = yield self.db.get_address(address)
hash = hashlib.sha256(address_details['history']).digest() history = address_details['history'] or b''
if six.PY2:
history = str(history)
hash = hashlib.sha256(history).digest()
defer.returnValue(hexlify(hash)) defer.returnValue(hexlify(hash))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -151,6 +148,26 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
parts = history.split(b':')[:-1] parts = history.split(b':')[:-1]
defer.returnValue(list(zip(parts[0::2], map(int, parts[1::2])))) defer.returnValue(list(zip(parts[0::2], map(int, parts[1::2]))))
@staticmethod
def get_root_of_merkle_tree(branches, branch_positions, working_branch):
for i, branch in enumerate(branches):
other_branch = unhexlify(branch)[::-1]
other_branch_on_left = bool((branch_positions >> i) & 1)
if other_branch_on_left:
combined = other_branch + working_branch
else:
combined = working_branch + other_branch
working_branch = double_sha256(combined)
return hexlify(working_branch[::-1])
@defer.inlineCallbacks
def is_valid_transaction(self, tx, height):
len(self.headers) < height or defer.returnValue(False)
merkle = yield self.network.get_merkle(tx.hex_id.decode(), height)
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
header = self.headers[height]
defer.returnValue(merkle_root == header['merkle_root'])
@defer.inlineCallbacks @defer.inlineCallbacks
def start(self): def start(self):
if not os.path.exists(self.path): if not os.path.exists(self.path):
@ -173,8 +190,6 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
while True: while True:
height_sought = len(self.headers) height_sought = len(self.headers)
headers = yield self.network.get_headers(height_sought) headers = yield self.network.get_headers(height_sought)
print("received {} headers starting at {} height".format(headers['count'], height_sought))
#log.info("received {} headers starting at {} height".format(headers['count'], height_sought))
if headers['count'] <= 0: if headers['count'] <= 0:
break break
yield self.headers.connect(height_sought, unhexlify(headers['hex'])) yield self.headers.connect(height_sought, unhexlify(headers['hex']))
@ -221,20 +236,48 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
@defer.inlineCallbacks @defer.inlineCallbacks
def update_history(self, address): def update_history(self, address):
remote_history = yield self.network.get_history(address) remote_history = yield self.network.get_history(address)
local = yield self.get_local_history(address) local_history = yield self.get_local_history(address)
history_parts = [] synced_history = []
for i, (hash, height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)): for i, (hash, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
history_parts.append('{}:{}:'.format(hash.decode(), height))
if i < len(local) and local[i] == (hash, height): synced_history.append((hash, remote_height))
if i < len(local_history) and local_history[i] == (hash, remote_height):
continue continue
raw = yield self.network.get_transaction(hash)
transaction = self.transaction_class(unhexlify(raw))
yield self.add_transaction(address, transaction, height)
yield self.db.set_address_history( lock = self._transaction_processing_locks.setdefault(hash, defer.DeferredLock())
address, ''.join(history_parts).encode()
) yield lock.acquire()
try:
# see if we have a local copy of transaction, otherwise fetch it from server
raw, local_height, is_verified = yield self.db.get_transaction(unhexlify(hash))
save_tx = None
if raw is None:
_raw = yield self.network.get_transaction(hash)
tx = self.transaction_class(unhexlify(_raw))
save_tx = 'insert'
else:
tx = self.transaction_class(unhexlify(raw))
if remote_height > 0 and not is_verified:
is_verified = yield self.is_valid_transaction(tx, remote_height)
if save_tx is None:
save_tx = 'update'
yield self.db.save_transaction_io(
save_tx, tx, remote_height, is_verified, address, self.address_to_hash160(address),
''.join('{}:{}:'.format(hash.decode(), height) for hash, height in synced_history).encode()
)
if save_tx is not None:
self._on_transaction_controller.add(tx)
finally:
lock.release()
if not lock.locked:
del self._transaction_processing_locks[hash]
@defer.inlineCallbacks @defer.inlineCallbacks
def subscribe_history(self, address): def subscribe_history(self, address):