added Database.temp_sqlite_regtest

This commit is contained in:
Lex Berezhny 2020-05-20 17:59:26 -04:00
parent 12915143b8
commit b7ff6569e4
2 changed files with 17 additions and 11 deletions

View file

@ -1,5 +1,6 @@
import os import os
import asyncio import asyncio
import tempfile
from typing import List, Optional, Tuple, Iterable, TYPE_CHECKING from typing import List, Optional, Tuple, Iterable, TYPE_CHECKING
from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
from functools import partial from functools import partial
@ -56,6 +57,14 @@ class Database:
self.multiprocess = multiprocess self.multiprocess = multiprocess
self.executor: Optional[Executor] = None self.executor: Optional[Executor] = None
@classmethod
def temp_sqlite_regtest(cls):
from lbry import Config, RegTestLedger
directory = tempfile.mkdtemp()
conf = Config.with_same_dir(directory)
ledger = RegTestLedger(conf)
return cls(ledger, conf.db_url_or_default)
@classmethod @classmethod
def from_memory(cls, ledger): def from_memory(cls, ledger):
return cls(ledger, 'sqlite:///:memory:') return cls(ledger, 'sqlite:///:memory:')

View file

@ -54,16 +54,13 @@ def initialize(url: str, ledger: Ledger, track_metrics=False, block_and_filter=N
def check_version_and_create_tables(): def check_version_and_create_tables():
context = ctx() context = ctx()
if SCHEMA_VERSION: if context.has_table('version'):
if context.has_table('version'): version = context.fetchone(select(Version.c.version).limit(1))
version = context.fetchone(select(Version.c.version).limit(1)) if version and version['version'] == SCHEMA_VERSION:
if version and version['version'] == SCHEMA_VERSION: return
return metadata.drop_all(context.engine)
metadata.drop_all(context.engine) metadata.create_all(context.engine)
metadata.create_all(context.engine) context.execute(Version.insert().values(version=SCHEMA_VERSION))
context.execute(Version.insert().values(version=SCHEMA_VERSION))
else:
metadata.create_all(context.engine)
class QueryContext(NamedTuple): class QueryContext(NamedTuple):
@ -375,7 +372,7 @@ def execute_fetchall(sql):
def get_best_height(): def get_best_height():
return ctx().fetchone( return ctx().fetchone(
select(func.coalesce(func.max(TX.c.height), 0).label('total')).select_from(TX) select(func.coalesce(func.max(TX.c.height), -1).label('total')).select_from(TX)
)['total'] )['total']