update_history serially handles unique addresses

This commit is contained in:
Lex Berezhny 2018-12-05 11:02:52 -05:00
parent 2bd60021aa
commit 1b7c5a1373
3 changed files with 76 additions and 66 deletions

View file

@ -18,11 +18,16 @@ class BasicTransactionTests(IntegrationTestCase):
address1 = await self.account.receiving.get_or_create_usable_address() address1 = await self.account.receiving.get_or_create_usable_address()
hash1 = self.ledger.address_to_hash160(address1) hash1 = self.ledger.address_to_hash160(address1)
tasks = [] txids = await asyncio.gather(*(
for _ in range(10): self.blockchain.send_to_address(address1, 100)
sendtxid = await self.blockchain.send_to_address(address1, 100) for _ in range(10)
tasks.append(self.on_transaction_id(sendtxid)) ))
await asyncio.wait(tasks)
await asyncio.wait([
self.on_transaction_id(txid)
for txid in txids
])
await self.assertBalance(self.account, '1000.0') await self.assertBalance(self.account, '1000.0')
tasks = [] tasks = []
@ -37,11 +42,7 @@ class BasicTransactionTests(IntegrationTestCase):
await asyncio.wait(tasks) await asyncio.wait(tasks)
#await asyncio.sleep(5) await self.assertBalance(self.account, '999.99876')
await self.assertBalance(self.account, '1000.0')
await self.blockchain.generate(1)
async def test_sending_and_receiving(self): async def test_sending_and_receiving(self):
account1, account2 = self.account, self.wallet.generate_account(self.ledger) account1, account2 = self.account, self.wallet.generate_account(self.ledger)

View file

@ -75,20 +75,21 @@ class TransactionCacheItem:
class SynchronizationMonitor: class SynchronizationMonitor:
def __init__(self): def __init__(self, loop=None):
self.done = asyncio.Event() self.done = asyncio.Event()
self.tasks = [] self.tasks = []
self.loop = loop or asyncio.get_event_loop()
def add(self, coro): def add(self, coro):
len(self.tasks) < 1 and self.done.clear() len(self.tasks) < 1 and self.done.clear()
asyncio.ensure_future(self._monitor(coro)) self.loop.create_task(self._monitor(coro))
def cancel(self): def cancel(self):
for task in self.tasks: for task in self.tasks:
task.cancel() task.cancel()
async def _monitor(self, coro): async def _monitor(self, coro):
task = asyncio.ensure_future(coro) task = self.loop.create_task(coro)
self.tasks.append(task) self.tasks.append(task)
try: try:
await task await task
@ -161,6 +162,7 @@ class BaseLedger(metaclass=LedgerRegistry):
self.sync = SynchronizationMonitor() self.sync = SynchronizationMonitor()
self._utxo_reservation_lock = asyncio.Lock() self._utxo_reservation_lock = asyncio.Lock()
self._header_processing_lock = asyncio.Lock() self._header_processing_lock = asyncio.Lock()
self._address_update_locks: Dict[str, asyncio.Lock] = {}
@classmethod @classmethod
def get_id(cls): def get_id(cls):
@ -382,63 +384,66 @@ class BaseLedger(metaclass=LedgerRegistry):
async def update_history(self, address, remote_status, async def update_history(self, address, remote_status,
address_manager: baseaccount.AddressManager = None): address_manager: baseaccount.AddressManager = None):
local_status, local_history = await self.get_local_status_and_history(address)
if local_status == remote_status: async with self._address_update_locks.setdefault(address, asyncio.Lock()):
return
remote_history = await self.network.get_history(address) local_status, local_history = await self.get_local_status_and_history(address)
cache_tasks = [] if local_status == remote_status:
synced_history = StringIO() return
for i, (txid, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
if i < len(local_history) and local_history[i] == (txid, remote_height):
synced_history.write(f'{txid}:{remote_height}:')
else:
cache_tasks.append(asyncio.ensure_future(
self.cache_transaction(txid, remote_height)
))
for task in cache_tasks: remote_history = await self.network.get_history(address)
tx = await task
check_db_for_txos = [] cache_tasks = []
for txi in tx.inputs: synced_history = StringIO()
if txi.txo_ref.txo is not None: for i, (txid, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
continue if i < len(local_history) and local_history[i] == (txid, remote_height):
cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.id) synced_history.write(f'{txid}:{remote_height}:')
if cache_item is not None:
if cache_item.tx is None:
await cache_item.has_tx.wait()
assert cache_item.tx is not None
txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
else: else:
check_db_for_txos.append(txi.txo_ref.tx_ref.id) cache_tasks.append(asyncio.ensure_future(
self.cache_transaction(txid, remote_height)
))
referenced_txos = { for task in cache_tasks:
txo.id: txo for txo in await self.db.get_txos(txoid__in=check_db_for_txos) tx = await task
}
for txi in tx.inputs: check_db_for_txos = []
if txi.txo_ref.txo is not None: for txi in tx.inputs:
continue if txi.txo_ref.txo is not None:
referenced_txo = referenced_txos.get(txi.txo_ref.tx_ref.id) continue
if referenced_txo is not None: cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.id)
txi.txo_ref = referenced_txo.ref if cache_item is not None:
if cache_item.tx is None:
await cache_item.has_tx.wait()
assert cache_item.tx is not None
txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
else:
check_db_for_txos.append(txi.txo_ref.tx_ref.id)
synced_history.write(f'{tx.id}:{tx.height}:') referenced_txos = {
txo.id: txo for txo in await self.db.get_txos(txoid__in=check_db_for_txos)
}
await self.db.save_transaction_io( for txi in tx.inputs:
tx, address, self.address_to_hash160(address), synced_history.getvalue() if txi.txo_ref.txo is not None:
) continue
referenced_txo = referenced_txos.get(txi.txo_ref.tx_ref.id)
if referenced_txo is not None:
txi.txo_ref = referenced_txo.ref
self._on_transaction_controller.add(TransactionEvent(address, tx)) synced_history.write(f'{tx.id}:{tx.height}:')
if address_manager is None: await self.db.save_transaction_io(
address_manager = await self.get_address_manager_for_address(address) tx, address, self.address_to_hash160(address), synced_history.getvalue()
)
if address_manager is not None: await self._on_transaction_controller.add(TransactionEvent(address, tx))
await address_manager.ensure_address_gap()
if address_manager is None:
address_manager = await self.get_address_manager_for_address(address)
if address_manager is not None:
await address_manager.ensure_address_gap()
async def cache_transaction(self, txid, remote_height): async def cache_transaction(self, txid, remote_height):
cache_item = self._tx_cache.get(txid) cache_item = self._tx_cache.get(txid)
@ -449,9 +454,8 @@ class BaseLedger(metaclass=LedgerRegistry):
(cache_item.tx.is_verified or remote_height < 1): (cache_item.tx.is_verified or remote_height < 1):
return cache_item.tx # cached tx is already up-to-date return cache_item.tx # cached tx is already up-to-date
await cache_item.lock.acquire() async with cache_item.lock:
try:
tx = cache_item.tx tx = cache_item.tx
if tx is None: if tx is None:
@ -478,9 +482,6 @@ class BaseLedger(metaclass=LedgerRegistry):
return tx return tx
finally:
cache_item.lock.release()
async def maybe_verify_transaction(self, tx, remote_height): async def maybe_verify_transaction(self, tx, remote_height):
tx.height = remote_height tx.height = remote_height
if 0 < remote_height <= len(self.headers): if 0 < remote_height <= len(self.headers):
@ -514,6 +515,6 @@ class BaseLedger(metaclass=LedgerRegistry):
records = await self.db.get_addresses(cols=('address',), address__in=addresses) records = await self.db.get_addresses(cols=('address',), address__in=addresses)
await asyncio.wait([ await asyncio.wait([
self.on_transaction.where(partial( self.on_transaction.where(partial(
lambda a, e: a == e.address and e.tx.height >= height, address_record['address'] lambda a, e: a == e.address and e.tx.height >= height and e.tx.id == tx.id, address_record['address']
)) for address_record in records )) for address_record in records
]) ])

View file

@ -18,14 +18,22 @@ class ColorHandler(logging.StreamHandler):
level_color = { level_color = {
logging.DEBUG: "black", logging.DEBUG: "black",
logging.INFO: "black", logging.INFO: "light_gray",
logging.WARNING: "yellow", logging.WARNING: "yellow",
logging.ERROR: "red" logging.ERROR: "red"
} }
color_code = dict( color_code = dict(
black=30, red=31, green=32, yellow=33, black=30,
blue=34, magenta=35, cyan=36, white=37 red=31,
green=32,
yellow=33,
blue=34,
magenta=35,
cyan=36,
white=37,
light_gray='0;37',
dark_gray='1;30'
) )
def emit(self, record): def emit(self, record):