client wallet db versioning

This commit is contained in:
Lex Berezhny 2019-09-16 12:18:41 -04:00
parent 063eeb34ec
commit d31f09c174
2 changed files with 105 additions and 1 deletions

View file

@ -1,5 +1,7 @@
import unittest
import sqlite3
import tempfile
import os
from torba.client.wallet import Wallet
from torba.client.constants import COIN
@ -268,7 +270,6 @@ class TestQueries(AsyncioTestCase):
self.assertEqual(len(tx.outputs), 1)
last_tx = tx
async def test_queries(self):
self.assertEqual(0, await self.ledger.db.get_address_count())
account1 = await self.create_account()
@ -354,3 +355,79 @@ class TestQueries(AsyncioTestCase):
txs = await self.ledger.db.get_transactions(accounts=[account1, account2])
self.assertEqual([0, 3, 2, 1], [tx.height for tx in txs])
self.assertEqual([tx4.id, tx3.id, tx2.id, tx1.id], [tx.id for tx in txs])
class TestUpgrade(AsyncioTestCase):
def setUp(self) -> None:
self.path = tempfile.mktemp()
def tearDown(self) -> None:
os.remove(self.path)
def get_version(self):
with sqlite3.connect(self.path) as conn:
versions = conn.execute('select version from version').fetchall()
assert len(versions) == 1
return versions[0][0]
def get_tables(self):
with sqlite3.connect(self.path) as conn:
sql = "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;"
return [col[0] for col in conn.execute(sql).fetchall()]
def add_address(self, address):
with sqlite3.connect(self.path) as conn:
conn.execute("""
INSERT INTO pubkey_address (address, account, chain, position, pubkey)
VALUES (?, 'account1', 0, 0, 'pubkey blob')
""", (address,))
def get_addresses(self):
with sqlite3.connect(self.path) as conn:
sql = "SELECT address FROM pubkey_address ORDER BY address;"
return [col[0] for col in conn.execute(sql).fetchall()]
async def test_reset_on_version_change(self):
self.ledger = ledger_class({
'db': ledger_class.database_class(self.path),
'headers': ledger_class.headers_class(':memory:'),
})
# initial open, pre-version enabled db
self.ledger.db.SCHEMA_VERSION = None
self.assertEqual(self.get_tables(), [])
await self.ledger.db.open()
self.assertEqual(self.get_tables(), ['pubkey_address', 'tx', 'txi', 'txo'])
self.assertEqual(self.get_addresses(), [])
self.add_address('address1')
await self.ledger.db.close()
# initial open after version enabled
self.ledger.db.SCHEMA_VERSION = '1.0'
await self.ledger.db.open()
self.assertEqual(self.get_version(), '1.0')
self.assertEqual(self.get_tables(), ['pubkey_address', 'tx', 'txi', 'txo', 'version'])
self.assertEqual(self.get_addresses(), []) # address1 deleted during version upgrade
self.add_address('address2')
await self.ledger.db.close()
# nothing changes
self.assertEqual(self.get_version(), '1.0')
self.assertEqual(self.get_tables(), ['pubkey_address', 'tx', 'txi', 'txo', 'version'])
await self.ledger.db.open()
self.assertEqual(self.get_version(), '1.0')
self.assertEqual(self.get_tables(), ['pubkey_address', 'tx', 'txi', 'txo', 'version'])
self.assertEqual(self.get_addresses(), ['address2'])
await self.ledger.db.close()
# upgrade version, database reset
self.ledger.db.SCHEMA_VERSION = '1.1'
self.ledger.db.CREATE_TABLES_QUERY += """
create table if not exists foo (bar text);
"""
await self.ledger.db.open()
self.assertEqual(self.get_version(), '1.1')
self.assertEqual(self.get_tables(), ['foo', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
self.assertEqual(self.get_addresses(), []) # all tables got reset
await self.ledger.db.close()

View file

@ -50,6 +50,10 @@ class AIOSQLite:
parameters = parameters if parameters is not None else []
return self.run(lambda conn: conn.execute(sql, parameters).fetchall())
def execute_fetchone(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]:
parameters = parameters if parameters is not None else []
return self.run(lambda conn: conn.execute(sql, parameters).fetchone())
def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]:
parameters = parameters if parameters is not None else []
return self.run(lambda conn: conn.execute(sql, parameters))
@ -212,9 +216,16 @@ def rows_to_dict(rows, fields):
class SQLiteMixin:
SCHEMA_VERSION: str = None
CREATE_TABLES_QUERY: str
MAX_QUERY_VARIABLES = 900
CREATE_VERSION_TABLE = """
create table if not exists version (
version text
);
"""
def __init__(self, path):
self._db_path = path
self.db: AIOSQLite = None
@ -223,6 +234,20 @@ class SQLiteMixin:
async def open(self):
log.info("connecting to database: %s", self._db_path)
self.db = await AIOSQLite.connect(self._db_path, isolation_level=None)
if self.SCHEMA_VERSION:
tables = [t[0] for t in await self.db.execute_fetchall(
"SELECT name FROM sqlite_master WHERE type='table';"
)]
if tables:
if 'version' in tables:
version = await self.db.execute_fetchone("SELECT version FROM version LIMIT 1;")
if version == (self.SCHEMA_VERSION,):
return
await self.db.executescript('\n'.join(
f"DROP TABLE {table};" for table in tables
))
await self.db.execute(self.CREATE_VERSION_TABLE)
await self.db.execute("INSERT INTO version VALUES (?)", (self.SCHEMA_VERSION,))
await self.db.executescript(self.CREATE_TABLES_QUERY)
async def close(self):
@ -261,6 +286,8 @@ class SQLiteMixin:
class BaseDatabase(SQLiteMixin):
SCHEMA_VERSION = "1.0"
PRAGMAS = """
pragma journal_mode=WAL;
"""