client wallet db versioning
This commit is contained in:
parent
063eeb34ec
commit
d31f09c174
2 changed files with 105 additions and 1 deletions
|
@ -1,5 +1,7 @@
|
|||
import unittest
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
from torba.client.wallet import Wallet
|
||||
from torba.client.constants import COIN
|
||||
|
@ -268,7 +270,6 @@ class TestQueries(AsyncioTestCase):
|
|||
self.assertEqual(len(tx.outputs), 1)
|
||||
last_tx = tx
|
||||
|
||||
|
||||
async def test_queries(self):
|
||||
self.assertEqual(0, await self.ledger.db.get_address_count())
|
||||
account1 = await self.create_account()
|
||||
|
@ -354,3 +355,79 @@ class TestQueries(AsyncioTestCase):
|
|||
txs = await self.ledger.db.get_transactions(accounts=[account1, account2])
|
||||
self.assertEqual([0, 3, 2, 1], [tx.height for tx in txs])
|
||||
self.assertEqual([tx4.id, tx3.id, tx2.id, tx1.id], [tx.id for tx in txs])
|
||||
|
||||
|
||||
class TestUpgrade(AsyncioTestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.path = tempfile.mktemp()
|
||||
|
||||
def tearDown(self) -> None:
|
||||
os.remove(self.path)
|
||||
|
||||
def get_version(self):
|
||||
with sqlite3.connect(self.path) as conn:
|
||||
versions = conn.execute('select version from version').fetchall()
|
||||
assert len(versions) == 1
|
||||
return versions[0][0]
|
||||
|
||||
def get_tables(self):
|
||||
with sqlite3.connect(self.path) as conn:
|
||||
sql = "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;"
|
||||
return [col[0] for col in conn.execute(sql).fetchall()]
|
||||
|
||||
def add_address(self, address):
|
||||
with sqlite3.connect(self.path) as conn:
|
||||
conn.execute("""
|
||||
INSERT INTO pubkey_address (address, account, chain, position, pubkey)
|
||||
VALUES (?, 'account1', 0, 0, 'pubkey blob')
|
||||
""", (address,))
|
||||
|
||||
def get_addresses(self):
|
||||
with sqlite3.connect(self.path) as conn:
|
||||
sql = "SELECT address FROM pubkey_address ORDER BY address;"
|
||||
return [col[0] for col in conn.execute(sql).fetchall()]
|
||||
|
||||
async def test_reset_on_version_change(self):
|
||||
self.ledger = ledger_class({
|
||||
'db': ledger_class.database_class(self.path),
|
||||
'headers': ledger_class.headers_class(':memory:'),
|
||||
})
|
||||
|
||||
# initial open, pre-version enabled db
|
||||
self.ledger.db.SCHEMA_VERSION = None
|
||||
self.assertEqual(self.get_tables(), [])
|
||||
await self.ledger.db.open()
|
||||
self.assertEqual(self.get_tables(), ['pubkey_address', 'tx', 'txi', 'txo'])
|
||||
self.assertEqual(self.get_addresses(), [])
|
||||
self.add_address('address1')
|
||||
await self.ledger.db.close()
|
||||
|
||||
# initial open after version enabled
|
||||
self.ledger.db.SCHEMA_VERSION = '1.0'
|
||||
await self.ledger.db.open()
|
||||
self.assertEqual(self.get_version(), '1.0')
|
||||
self.assertEqual(self.get_tables(), ['pubkey_address', 'tx', 'txi', 'txo', 'version'])
|
||||
self.assertEqual(self.get_addresses(), []) # address1 deleted during version upgrade
|
||||
self.add_address('address2')
|
||||
await self.ledger.db.close()
|
||||
|
||||
# nothing changes
|
||||
self.assertEqual(self.get_version(), '1.0')
|
||||
self.assertEqual(self.get_tables(), ['pubkey_address', 'tx', 'txi', 'txo', 'version'])
|
||||
await self.ledger.db.open()
|
||||
self.assertEqual(self.get_version(), '1.0')
|
||||
self.assertEqual(self.get_tables(), ['pubkey_address', 'tx', 'txi', 'txo', 'version'])
|
||||
self.assertEqual(self.get_addresses(), ['address2'])
|
||||
await self.ledger.db.close()
|
||||
|
||||
# upgrade version, database reset
|
||||
self.ledger.db.SCHEMA_VERSION = '1.1'
|
||||
self.ledger.db.CREATE_TABLES_QUERY += """
|
||||
create table if not exists foo (bar text);
|
||||
"""
|
||||
await self.ledger.db.open()
|
||||
self.assertEqual(self.get_version(), '1.1')
|
||||
self.assertEqual(self.get_tables(), ['foo', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
|
||||
self.assertEqual(self.get_addresses(), []) # all tables got reset
|
||||
await self.ledger.db.close()
|
||||
|
|
|
@ -50,6 +50,10 @@ class AIOSQLite:
|
|||
parameters = parameters if parameters is not None else []
|
||||
return self.run(lambda conn: conn.execute(sql, parameters).fetchall())
|
||||
|
||||
def execute_fetchone(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]:
|
||||
parameters = parameters if parameters is not None else []
|
||||
return self.run(lambda conn: conn.execute(sql, parameters).fetchone())
|
||||
|
||||
def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]:
|
||||
parameters = parameters if parameters is not None else []
|
||||
return self.run(lambda conn: conn.execute(sql, parameters))
|
||||
|
@ -212,9 +216,16 @@ def rows_to_dict(rows, fields):
|
|||
|
||||
class SQLiteMixin:
|
||||
|
||||
SCHEMA_VERSION: str = None
|
||||
CREATE_TABLES_QUERY: str
|
||||
MAX_QUERY_VARIABLES = 900
|
||||
|
||||
CREATE_VERSION_TABLE = """
|
||||
create table if not exists version (
|
||||
version text
|
||||
);
|
||||
"""
|
||||
|
||||
def __init__(self, path):
|
||||
self._db_path = path
|
||||
self.db: AIOSQLite = None
|
||||
|
@ -223,6 +234,20 @@ class SQLiteMixin:
|
|||
async def open(self):
|
||||
log.info("connecting to database: %s", self._db_path)
|
||||
self.db = await AIOSQLite.connect(self._db_path, isolation_level=None)
|
||||
if self.SCHEMA_VERSION:
|
||||
tables = [t[0] for t in await self.db.execute_fetchall(
|
||||
"SELECT name FROM sqlite_master WHERE type='table';"
|
||||
)]
|
||||
if tables:
|
||||
if 'version' in tables:
|
||||
version = await self.db.execute_fetchone("SELECT version FROM version LIMIT 1;")
|
||||
if version == (self.SCHEMA_VERSION,):
|
||||
return
|
||||
await self.db.executescript('\n'.join(
|
||||
f"DROP TABLE {table};" for table in tables
|
||||
))
|
||||
await self.db.execute(self.CREATE_VERSION_TABLE)
|
||||
await self.db.execute("INSERT INTO version VALUES (?)", (self.SCHEMA_VERSION,))
|
||||
await self.db.executescript(self.CREATE_TABLES_QUERY)
|
||||
|
||||
async def close(self):
|
||||
|
@ -261,6 +286,8 @@ class SQLiteMixin:
|
|||
|
||||
class BaseDatabase(SQLiteMixin):
|
||||
|
||||
SCHEMA_VERSION = "1.0"
|
||||
|
||||
PRAGMAS = """
|
||||
pragma journal_mode=WAL;
|
||||
"""
|
||||
|
|
Loading…
Reference in a new issue