lbry-sdk/lbry/wallet/server/db/interface.py

228 lines
8.4 KiB
Python
Raw Normal View History

2022-01-06 18:28:15 +01:00
import struct
import rocksdb
from typing import Optional
from lbry.wallet.server.db import DB_PREFIXES
from lbry.wallet.server.db.revertable import RevertableOpStack, RevertablePut, RevertableDelete
class RocksDBStore:
    """
    Thin wrapper around a `rocksdb.DB` handle.

    Remembers the path, optional secondary path and open-file limit it was created with and
    exposes get / iterate / batched-write / close helpers on top of the handle.
    """

    def __init__(self, path: str, cache_mb: int, max_open_files: int, secondary_path: str = ''):
        """
        Open the database.

        :param path: passed to `rocksdb.DB` as the database path
        :param cache_mb: accepted as part of the signature but not used by this class
        :param max_open_files: forwarded to the rocksdb options when no secondary path is given
        :param secondary_path: passed to `rocksdb.DB` as `secondary_name`; when non-empty the
            options are built with `max_open_files=-1` instead of `max_open_files`
        """
        # Use snappy compression (the default)
        self.path = path
        self.secondary_path = secondary_path
        self._max_open_files = max_open_files
        # get_options() reads secondary_path and _max_open_files, so they must be set first
        self.db = rocksdb.DB(path, self.get_options(), secondary_name=secondary_path)
        # self.multi_get = self.db.multi_get

    def get_options(self):
        """
        Build the `rocksdb.Options` object used to open the database.
        """
        return rocksdb.Options(
            create_if_missing=True, use_fsync=True,
            target_file_size_base=33554432,  # 32 * 1024 * 1024
            max_open_files=self._max_open_files if not self.secondary_path else -1
        )

    def get(self, key: bytes, fill_cache: bool = True) -> Optional[bytes]:
        """
        Return whatever the underlying handle's `get` returns for `key`.
        """
        return self.db.get(key, fill_cache=fill_cache)

    def iterator(self, reverse=False, start=None, stop=None, include_start=True, include_stop=False, prefix=None,
                 include_key=True, include_value=True, fill_cache=True):
        """
        Create a `RocksDBIterator` over the database.

        `prefix` is only forwarded when neither `start` nor `stop` is given, because the iterator
        refuses a prefix combined with start/stop. `fill_cache` is accepted but not forwarded.
        """
        return RocksDBIterator(
            self.db, reverse=reverse, start=start, stop=stop, include_start=include_start, include_stop=include_stop,
            prefix=prefix if start is None and stop is None else None, include_key=include_key,
            include_value=include_value
        )

    def write_batch(self, disable_wal: bool = False, sync: bool = False):
        """
        Return a `RocksDBWriteBatch` context manager bound to this database.
        """
        return RocksDBWriteBatch(self.db, sync=sync, disable_wal=disable_wal)

    def close(self):
        """
        Close the underlying handle and mark the store as closed.

        Calling this on an already closed store is a no-op.
        """
        if self.db is None:
            # already closed; without this guard a second call would fail on `None.close()`
            return
        self.db.close()
        self.db = None

    @property
    def closed(self) -> bool:
        """
        True once `close()` has dropped the underlying handle.
        """
        return self.db is None

    def try_catch_up_with_primary(self):
        """
        Delegate to the underlying handle's `try_catch_up_with_primary()`.
        """
        self.db.try_catch_up_with_primary()
class RocksDBWriteBatch:
    """
    Context manager around a `rocksdb.WriteBatch`.

    Entering the `with` block yields the batch; leaving it without an exception writes the batch
    to the database with the `sync` / `disable_wal` flags given at construction. If the block
    raises, nothing is written and the exception propagates.
    """

    def __init__(self, db: rocksdb.DB, sync: bool = False, disable_wal: bool = False):
        """
        :param db: database handle whose `write()` receives the batch
        :param sync: forwarded to `db.write` as `sync`
        :param disable_wal: forwarded to `db.write` as `disable_wal`
        """
        self.batch = rocksdb.WriteBatch()
        self.db = db
        self.sync = sync
        self.disable_wal = disable_wal

    def __enter__(self):
        return self.batch

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Test exc_type rather than the truthiness of exc_val: an exception instance may be falsy
        # (e.g. it defines __bool__ or __len__), and such a failure must not be committed.
        if exc_type is None:
            self.db.write(self.batch, sync=self.sync, disable_wal=self.disable_wal)
class RocksDBIterator:
    """
    Iterator over a RocksDB database, bounded either by a key prefix or by a start/stop range.

    Yields `(key, value)` tuples, bare keys or bare values depending on `include_key` and
    `include_value`.
    """

    __slots__ = [
        'start',
        'prefix',
        'stop',
        'iterator',
        'include_key',
        'include_value',
        'prev_k',
        'reverse',
        'include_start',
        'include_stop'
    ]

    def __init__(self, db: rocksdb.DB, prefix: bytes = None, start: bytes = None, stop: bytes = None,
                 include_key: bool = True, include_value: bool = True, reverse: bool = False,
                 include_start: bool = True, include_stop: bool = False):
        """
        :param db: handle providing `iteritems()`
        :param prefix: seek here and stop at the first key that does not start with it
        :param start: seek here; mutually exclusive with `prefix`
        :param stop: iteration ends at the first key whose leading bytes are >= `stop`
        :param include_key: include the key in each yielded item
        :param include_value: include the value in each yielded item
        :param reverse: wrap the underlying iterator in `reversed()`
        :param include_start: stored on the instance; not consulted by this class
        :param include_stop: stored on the instance; not consulted by this class
        """
        assert (start is None and stop is None) or (prefix is None), 'cannot use start/stop and prefix'
        self.start = start
        self.prefix = prefix
        self.stop = stop

        items = db.iteritems()
        if reverse:
            items = reversed(items)
        self.iterator = items

        # a prefix takes precedence as the seek position, otherwise fall back to start
        seek_to = prefix if prefix is not None else start
        if seek_to is not None:
            self.iterator.seek(seek_to)

        self.include_key = include_key
        self.include_value = include_value
        self.prev_k = None
        self.reverse = reverse
        self.include_start = include_start
        self.include_stop = include_stop

    def __iter__(self):
        return self

    def _check_stop_iteration(self, key: bytes):
        """
        Raise StopIteration when `key` lies outside the configured bounds.
        """
        stop, start, prefix = self.stop, self.start, self.prefix
        # the key is truncated to the bound's length before comparing
        if stop is not None and stop <= key[:len(stop)]:
            raise StopIteration
        if start is not None and key[:len(start)] < start:
            raise StopIteration
        if prefix is not None and not key.startswith(prefix):
            raise StopIteration

    def __next__(self):
        previous = self.prev_k
        if previous is not None:
            self._check_stop_iteration(previous)
        key, value = next(self.iterator)
        self._check_stop_iteration(key)
        self.prev_k = key
        if not self.include_key:
            return value
        if self.include_value:
            return key, value
        return key
class PrefixDB:
    """
    Base class for a revertable rocksdb database (a rocksdb db where each set of applied changes can be undone)

    Changes are staged on an op stack and written in a single synced batch by `unsafe_commit()`,
    `commit()` or `rollback()`; each of those clears the op stack once its write has been attempted.
    """
    # undo key suffix: big-endian unsigned 64-bit height followed by a 32-byte block hash
    UNDO_KEY_STRUCT = struct.Struct(b'>Q32s')
    # height-only form of the suffix, used to build range bounds over undo keys
    PARTIAL_UNDO_KEY_STRUCT = struct.Struct(b'>Q')

    def __init__(self, db: RocksDBStore, max_undo_depth: int = 200, unsafe_prefixes=None):
        """
        :param db: the store that is read from and written to
        :param max_undo_depth: undo records more than this many heights below a committed height
            are deleted by `commit()`
        :param unsafe_prefixes: passed through to `RevertableOpStack`
        """
        self._db = db
        self._op_stack = RevertableOpStack(db.get, unsafe_prefixes=unsafe_prefixes)
        self._max_undo_depth = max_undo_depth

    def _write_staged_ops(self, batch):
        """
        Copy every staged operation from the op stack into `batch`, in stack order.
        """
        batch_put = batch.put
        batch_delete = batch.delete
        for staged_change in self._op_stack:
            if staged_change.is_put:
                batch_put(staged_change.key, staged_change.value)
            else:
                batch_delete(staged_change.key)

    def unsafe_commit(self):
        """
        Write staged changes to the database without keeping undo information
        Changes written cannot be undone
        """
        try:
            if not len(self._op_stack):
                return
            with self._db.write_batch(sync=True) as batch:
                self._write_staged_ops(batch)
        finally:
            self._op_stack.clear()

    def commit(self, height: int, block_hash: bytes):
        """
        Write changes for a block height to the database and keep undo information so that the changes can be reverted

        In the same batch, undo records for heights below `height - max_undo_depth` are deleted
        and the undo record for `(height, block_hash)` is stored.
        """
        undo_ops = self._op_stack.get_undo_ops()
        delete_undos = []
        if height > self._max_undo_depth:
            # collect the keys of undo records that have fallen out of the undo window
            delete_undos.extend(self._db.iterator(
                start=DB_PREFIXES.undo.value + self.PARTIAL_UNDO_KEY_STRUCT.pack(0),
                stop=DB_PREFIXES.undo.value + self.PARTIAL_UNDO_KEY_STRUCT.pack(height - self._max_undo_depth),
                include_value=False
            ))
        try:
            with self._db.write_batch(sync=True) as batch:
                self._write_staged_ops(batch)
                batch_delete = batch.delete
                for undo_to_delete in delete_undos:
                    batch_delete(undo_to_delete)
                batch.put(DB_PREFIXES.undo.value + self.UNDO_KEY_STRUCT.pack(height, block_hash), undo_ops)
        finally:
            self._op_stack.clear()

    def rollback(self, height: int, block_hash: bytes):
        """
        Revert changes for a block height

        Loads the undo record stored for `(height, block_hash)`, stages its operations and writes
        them in one batch. The undo record itself is left in the database.
        """
        undo_key = DB_PREFIXES.undo.value + self.UNDO_KEY_STRUCT.pack(height, block_hash)
        undo_info = self._db.get(undo_key)
        # NOTE(review): if no undo record is stored for this key, what `undo_info` holds and how
        # apply_packed_undo_ops() reacts is not visible from this file - confirm
        self._op_stack.apply_packed_undo_ops(undo_info)
        try:
            with self._db.write_batch(sync=True) as batch:
                self._write_staged_ops(batch)
                # batch_delete(undo_key)
        finally:
            self._op_stack.clear()

    def get(self, key: bytes, fill_cache: bool = True) -> Optional[bytes]:
        """
        Read `key` straight from the underlying store (staged changes are not consulted here).
        """
        return self._db.get(key, fill_cache=fill_cache)

    def iterator(self, reverse=False, start=None, stop=None, include_start=True, include_stop=False, prefix=None,
                 include_key=True, include_value=True, fill_cache=True):
        """
        Return the underlying store's iterator, forwarding every argument unchanged.
        """
        return self._db.iterator(
            reverse=reverse, start=start, stop=stop, include_start=include_start, include_stop=include_stop,
            prefix=prefix, include_key=include_key, include_value=include_value, fill_cache=fill_cache
        )

    def close(self):
        """
        Close the underlying store unless it already reports itself closed.
        """
        if not self._db.closed:
            self._db.close()

    def try_catch_up_with_primary(self):
        """
        Delegate to the underlying store's `try_catch_up_with_primary()`.
        """
        self._db.try_catch_up_with_primary()

    @property
    def closed(self) -> bool:
        """
        Whether the underlying store reports itself closed.
        """
        return self._db.closed

    def stage_raw_put(self, key: bytes, value: bytes):
        """
        Stage a `RevertablePut(key, value)` on the op stack.
        """
        self._op_stack.append_op(RevertablePut(key, value))

    def stage_raw_delete(self, key: bytes, value: bytes):
        """
        Stage a `RevertableDelete(key, value)` on the op stack.
        """
        self._op_stack.append_op(RevertableDelete(key, value))