fixes from code review

This commit is contained in:
Victor Shyba 2018-10-17 20:39:38 -03:00 committed by Lex Berezhny
parent 28b638f0aa
commit b830f4c3ca
2 changed files with 3 additions and 12 deletions

View file

@ -110,10 +110,9 @@ class SQLiteMixin:
async def open(self): async def open(self):
log.info("connecting to database: %s", self._db_path) log.info("connecting to database: %s", self._db_path)
self.db = aiosqlite.connect(self._db_path) self.db = aiosqlite.connect(self._db_path, isolation_level=None)
await self.db.__aenter__() await self.db.__aenter__()
await self.db.executescript(self.CREATE_TABLES_QUERY) await self.db.executescript(self.CREATE_TABLES_QUERY)
await self.db.commit()
async def close(self): async def close(self):
await self.db.close() await self.db.close()
@ -273,7 +272,6 @@ class BaseDatabase(SQLiteMixin):
})) }))
await self._set_address_history(address, history) await self._set_address_history(address, history)
await self.db.commit()
async def reserve_outputs(self, txos, is_reserved=True): async def reserve_outputs(self, txos, is_reserved=True):
txoids = [txo.id for txo in txos] txoids = [txo.id for txo in txos]
@ -282,7 +280,6 @@ class BaseDatabase(SQLiteMixin):
', '.join(['?']*len(txoids)) ', '.join(['?']*len(txoids))
), [is_reserved]+txoids ), [is_reserved]+txoids
) )
await self.db.commit()
async def release_outputs(self, txos): async def release_outputs(self, txos):
await self.reserve_outputs(txos, is_reserved=False) await self.reserve_outputs(txos, is_reserved=False)
@ -454,7 +451,6 @@ class BaseDatabase(SQLiteMixin):
sqlite3.Binary(pubkey.pubkey_bytes) sqlite3.Binary(pubkey.pubkey_bytes)
)) ))
await self.db.execute(sql, values) await self.db.execute(sql, values)
await self.db.commit()
async def _set_address_history(self, address, history): async def _set_address_history(self, address, history):
await self.db.execute( await self.db.execute(
@ -464,4 +460,3 @@ class BaseDatabase(SQLiteMixin):
async def set_address_history(self, address, history): async def set_address_history(self, address, history):
await self._set_address_history(address, history) await self._set_address_history(address, history)
await self.db.commit()

View file

@ -306,17 +306,13 @@ class BaseLedger(metaclass=LedgerRegistry):
)) ))
async def update_account(self, account: baseaccount.BaseAccount): async def update_account(self, account: baseaccount.BaseAccount):
# Before subscribing, download history for any addresses that don't have any,
# this avoids situation where we're getting status updates to addresses we know
# need to update anyways. Continue to get history and create more addresses until
# all missing addresses are created and history for them is fully restored.
await account.ensure_address_gap() await account.ensure_address_gap()
addresses = await account.get_addresses(used_times=0) addresses = await account.get_addresses(used_times=0)
while addresses: while addresses:
await asyncio.gather(*(self.subscribe_history(a) for a in addresses)) await asyncio.gather(*(self.subscribe_history(a) for a in addresses))
addresses = await account.ensure_address_gap() addresses = await account.ensure_address_gap()
async def _prefetch_history(self, remote_history, local_history): def _prefetch_history(self, remote_history, local_history):
proofs, network_txs = {}, {} proofs, network_txs = {}, {}
for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)): for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
if i < len(local_history) and local_history[i] == (hex_id, remote_height): if i < len(local_history) and local_history[i] == (hex_id, remote_height):
@ -329,7 +325,7 @@ class BaseLedger(metaclass=LedgerRegistry):
async def update_history(self, address): async def update_history(self, address):
remote_history = await self.network.get_history(address) remote_history = await self.network.get_history(address)
local_history = await self.get_local_history(address) local_history = await self.get_local_history(address)
proofs, network_txs = await self._prefetch_history(remote_history, local_history) proofs, network_txs = self._prefetch_history(remote_history, local_history)
synced_history = StringIO() synced_history = StringIO()
for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)): for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):