diff --git a/lbry/wallet/server/db/__init__.py b/lbry/wallet/server/db/__init__.py index 5826a5421..f286840be 100644 --- a/lbry/wallet/server/db/__init__.py +++ b/lbry/wallet/server/db/__init__.py @@ -39,7 +39,7 @@ class DB_PREFIXES(enum.Enum): db_state = b's' channel_count = b'Z' support_amount = b'a' - block_txs = b'b' + block_tx = b'b' trending_notifications = b'c' mempool_tx = b'd' touched_hashX = b'e' diff --git a/lbry/wallet/server/db/db.py b/lbry/wallet/server/db/db.py index d381497db..097f91a4b 100644 --- a/lbry/wallet/server/db/db.py +++ b/lbry/wallet/server/db/db.py @@ -394,7 +394,9 @@ class HubDB: def get_claims_for_name(self, name): claims = [] prefix = self.prefix_db.claim_short_id.pack_partial_key(name) + bytes([1]) - for _k, _v in self.prefix_db.iterator(prefix=prefix): + stop = self.prefix_db.claim_short_id.pack_partial_key(name) + int(2).to_bytes(1, byteorder='big') + cf = self.prefix_db.column_families[self.prefix_db.claim_short_id.prefix] + for _v in self.prefix_db.iterator(column_family=cf, start=prefix, iterate_upper_bound=stop, include_key=False): v = self.prefix_db.claim_short_id.unpack_value(_v) claim_hash = self.get_claim_from_txo(v.tx_num, v.position).claim_hash if claim_hash not in claims: @@ -459,7 +461,9 @@ class HubDB: def get_claim_txos_for_name(self, name: str): txos = {} prefix = self.prefix_db.claim_short_id.pack_partial_key(name) + int(1).to_bytes(1, byteorder='big') - for k, v in self.prefix_db.iterator(prefix=prefix): + stop = self.prefix_db.claim_short_id.pack_partial_key(name) + int(2).to_bytes(1, byteorder='big') + cf = self.prefix_db.column_families[self.prefix_db.claim_short_id.prefix] + for v in self.prefix_db.iterator(column_family=cf, start=prefix, iterate_upper_bound=stop, include_key=False): tx_num, nout = self.prefix_db.claim_short_id.unpack_value(v) txos[self.get_claim_from_txo(tx_num, nout).claim_hash] = tx_num, nout return txos diff --git a/lbry/wallet/server/db/interface.py b/lbry/wallet/server/db/interface.py index 7b1bc6039..a4066d18d 100644 --- a/lbry/wallet/server/db/interface.py +++ b/lbry/wallet/server/db/interface.py @@ -1,126 +1,12 @@ import struct +import typing + import rocksdb from typing import Optional from lbry.wallet.server.db import DB_PREFIXES from lbry.wallet.server.db.revertable import RevertableOpStack, RevertablePut, RevertableDelete -class RocksDBStore: - def __init__(self, path: str, cache_mb: int, max_open_files: int, secondary_path: str = ''): - # Use snappy compression (the default) - self.path = path - self.secondary_path = secondary_path - self._max_open_files = max_open_files - self.db = rocksdb.DB(path, self.get_options(), secondary_name=secondary_path) - # self.multi_get = self.db.multi_get - - def get_options(self): - return rocksdb.Options( - create_if_missing=True, use_fsync=True, target_file_size_base=33554432, - max_open_files=self._max_open_files if not self.secondary_path else -1 - ) - - def get(self, key: bytes, fill_cache: bool = True) -> Optional[bytes]: - return self.db.get(key, fill_cache=fill_cache) - - def iterator(self, reverse=False, start=None, stop=None, include_start=True, include_stop=False, prefix=None, - include_key=True, include_value=True, fill_cache=True): - return RocksDBIterator( - self.db, reverse=reverse, start=start, stop=stop, include_start=include_start, include_stop=include_stop, - prefix=prefix if start is None and stop is None else None, include_key=include_key, - include_value=include_value - ) - - def write_batch(self, disable_wal: bool = False, sync: bool = False): - return RocksDBWriteBatch(self.db, sync=sync, disable_wal=disable_wal) - - def close(self): - self.db.close() - self.db = None - - @property - def closed(self) -> bool: - return self.db is None - - def try_catch_up_with_primary(self): - self.db.try_catch_up_with_primary() - - -class RocksDBWriteBatch: - def __init__(self, db: rocksdb.DB, sync: bool = False, disable_wal: bool = False): - self.batch = rocksdb.WriteBatch() - self.db = db - self.sync = sync - self.disable_wal = disable_wal - - def __enter__(self): - return self.batch - - def __exit__(self, exc_type, exc_val, exc_tb): - if not exc_val: - self.db.write(self.batch, sync=self.sync, disable_wal=self.disable_wal) - - -class RocksDBIterator: - """An iterator for RocksDB.""" - - __slots__ = [ - 'start', - 'prefix', - 'stop', - 'iterator', - 'include_key', - 'include_value', - 'prev_k', - 'reverse', - 'include_start', - 'include_stop' - ] - - def __init__(self, db: rocksdb.DB, prefix: bytes = None, start: bytes = None, stop: bytes = None, - include_key: bool = True, include_value: bool = True, reverse: bool = False, - include_start: bool = True, include_stop: bool = False): - assert (start is None and stop is None) or (prefix is None), 'cannot use start/stop and prefix' - self.start = start - self.prefix = prefix - self.stop = stop - self.iterator = db.iteritems() if not reverse else reversed(db.iteritems()) - if prefix is not None: - self.iterator.seek(prefix) - elif start is not None: - self.iterator.seek(start) - self.include_key = include_key - self.include_value = include_value - self.prev_k = None - self.reverse = reverse - self.include_start = include_start - self.include_stop = include_stop - - def __iter__(self): - return self - - def _check_stop_iteration(self, key: bytes): - if self.stop is not None and (key.startswith(self.stop) or self.stop < key[:len(self.stop)]): - raise StopIteration - elif self.start is not None and self.start > key[:len(self.start)]: - raise StopIteration - elif self.prefix is not None and not key.startswith(self.prefix): - raise StopIteration - - def __next__(self): - if self.prev_k is not None: - self._check_stop_iteration(self.prev_k) - k, v = next(self.iterator) - self._check_stop_iteration(k) - self.prev_k = k - - if self.include_key and self.include_value: - return k, v - elif self.include_key: - return k - return v - - class PrefixDB: """ Base class for a revertable rocksdb database (a rocksdb db where each set of applied changes can be undone) @@ -128,9 +14,25 @@ class PrefixDB: UNDO_KEY_STRUCT = struct.Struct(b'>Q32s') PARTIAL_UNDO_KEY_STRUCT = struct.Struct(b'>Q') - def __init__(self, db: RocksDBStore, max_undo_depth: int = 200, unsafe_prefixes=None): - self._db = db - self._op_stack = RevertableOpStack(db.get, unsafe_prefixes=unsafe_prefixes) + def __init__(self, path, max_open_files=64, secondary_path='', max_undo_depth: int = 200, unsafe_prefixes=None): + column_family_options = { + prefix.value: rocksdb.ColumnFamilyOptions() for prefix in DB_PREFIXES + } if secondary_path else {} + self.column_families: typing.Dict[bytes, 'rocksdb.ColumnFamilyHandle'] = {} + self._db = rocksdb.DB( + path, rocksdb.Options( + create_if_missing=True, use_fsync=True, target_file_size_base=33554432, + max_open_files=max_open_files if not secondary_path else -1 + ), secondary_name=secondary_path, column_families=column_family_options + ) + for prefix in DB_PREFIXES: + cf = self._db.get_column_family(prefix.value) + if cf is None and not secondary_path: + self._db.create_column_family(prefix.value, rocksdb.ColumnFamilyOptions()) + cf = self._db.get_column_family(prefix.value) + self.column_families[prefix.value] = cf + + self._op_stack = RevertableOpStack(self.get, unsafe_prefixes=unsafe_prefixes) self._max_undo_depth = max_undo_depth def unsafe_commit(self): @@ -144,11 +46,13 @@ class PrefixDB: with self._db.write_batch(sync=True) as batch: batch_put = batch.put batch_delete = batch.delete + get_column_family = self.column_families.__getitem__ for staged_change in self._op_stack: + column_family = get_column_family(DB_PREFIXES(staged_change.key[:1]).value) if staged_change.is_put: - batch_put(staged_change.key, staged_change.value) + batch_put((column_family, staged_change.key), staged_change.value) else: - batch_delete(staged_change.key) + batch_delete((column_family, staged_change.key)) finally: self._op_stack.clear() @@ -161,21 +65,24 @@ class PrefixDB: if height > self._max_undo_depth: delete_undos.extend(self._db.iterator( start=DB_PREFIXES.undo.value + self.PARTIAL_UNDO_KEY_STRUCT.pack(0), - stop=DB_PREFIXES.undo.value + self.PARTIAL_UNDO_KEY_STRUCT.pack(height - self._max_undo_depth), + iterate_upper_bound=DB_PREFIXES.undo.value + self.PARTIAL_UNDO_KEY_STRUCT.pack(height - self._max_undo_depth), include_value=False )) try: + undo_c_f = self.column_families[DB_PREFIXES.undo.value] with self._db.write_batch(sync=True) as batch: batch_put = batch.put batch_delete = batch.delete + get_column_family = self.column_families.__getitem__ for staged_change in self._op_stack: + column_family = get_column_family(DB_PREFIXES(staged_change.key[:1]).value) if staged_change.is_put: - batch_put(staged_change.key, staged_change.value) + batch_put((column_family, staged_change.key), staged_change.value) else: - batch_delete(staged_change.key) + batch_delete((column_family, staged_change.key)) for undo_to_delete in delete_undos: - batch_delete(undo_to_delete) - batch_put(DB_PREFIXES.undo.value + self.UNDO_KEY_STRUCT.pack(height, block_hash), undo_ops) + batch_delete((undo_c_f, undo_to_delete)) + batch_put((undo_c_f, DB_PREFIXES.undo.value + self.UNDO_KEY_STRUCT.pack(height, block_hash)), undo_ops) finally: self._op_stack.clear() @@ -184,33 +91,41 @@ class PrefixDB: Revert changes for a block height """ undo_key = DB_PREFIXES.undo.value + self.UNDO_KEY_STRUCT.pack(height, block_hash) - undo_info = self._db.get(undo_key) + undo_c_f = self.column_families[DB_PREFIXES.undo.value] + undo_info = self._db.get((undo_c_f, undo_key)) self._op_stack.apply_packed_undo_ops(undo_info) try: with self._db.write_batch(sync=True) as batch: batch_put = batch.put batch_delete = batch.delete + get_column_family = self.column_families.__getitem__ for staged_change in self._op_stack: + column_family = get_column_family(DB_PREFIXES(staged_change.key[:1]).value) if staged_change.is_put: - batch_put(staged_change.key, staged_change.value) + batch_put((column_family, staged_change.key), staged_change.value) else: - batch_delete(staged_change.key) + batch_delete((column_family, staged_change.key)) # batch_delete(undo_key) finally: self._op_stack.clear() def get(self, key: bytes, fill_cache: bool = True) -> Optional[bytes]: - return self._db.get(key, fill_cache=fill_cache) + cf = self.column_families[key[:1]] + return self._db.get((cf, key), fill_cache=fill_cache) - def iterator(self, reverse=False, start=None, stop=None, include_start=True, include_stop=False, prefix=None, - include_key=True, include_value=True, fill_cache=True): + def iterator(self, start: bytes, column_family: 'rocksdb.ColumnFamilyHandle' = None, + iterate_lower_bound: bytes = None, iterate_upper_bound: bytes = None, + reverse: bool = False, include_key: bool = True, include_value: bool = True, + fill_cache: bool = True, prefix_same_as_start: bool = True, auto_prefix_mode: bool = True): return self._db.iterator( - reverse=reverse, start=start, stop=stop, include_start=include_start, include_stop=include_stop, - prefix=prefix, include_key=include_key, include_value=include_value, fill_cache=fill_cache + start=start, column_family=column_family, iterate_lower_bound=iterate_lower_bound, + iterate_upper_bound=iterate_upper_bound, reverse=reverse, include_key=include_key, + include_value=include_value, fill_cache=fill_cache, prefix_same_as_start=prefix_same_as_start, + auto_prefix_mode=auto_prefix_mode ) def close(self): - if not self._db.closed: + if not self._db.is_closed: self._db.close() def try_catch_up_with_primary(self): @@ -218,7 +133,7 @@ class PrefixDB: @property def closed(self) -> bool: - return self._db.closed + return self._db.is_closed def stage_raw_put(self, key: bytes, value: bytes): self._op_stack.append_op(RevertablePut(key, value)) diff --git a/lbry/wallet/server/db/prefixes.py b/lbry/wallet/server/db/prefixes.py index f3baa71c4..82758f56f 100644 --- a/lbry/wallet/server/db/prefixes.py +++ b/lbry/wallet/server/db/prefixes.py @@ -4,10 +4,12 @@ import array import base64 from typing import Union, Tuple, NamedTuple, Optional from lbry.wallet.server.db import DB_PREFIXES -from lbry.wallet.server.db.interface import RocksDBStore, PrefixDB +from lbry.wallet.server.db.interface import PrefixDB from lbry.wallet.server.db.common import TrendingNotification from lbry.wallet.server.db.revertable import RevertableOpStack, RevertablePut, RevertableDelete from lbry.schema.url import normalize_name +if typing.TYPE_CHECKING: + import rocksdb ACTIVATED_CLAIM_TXO_TYPE = 1 ACTIVATED_SUPPORT_TXO_TYPE = 2 @@ -39,21 +41,32 @@ class PrefixRow(metaclass=PrefixRowType): value_struct: struct.Struct key_part_lambdas = [] - def __init__(self, db: RocksDBStore, op_stack: RevertableOpStack): + def __init__(self, db: 'rocksdb.DB', op_stack: RevertableOpStack): self._db = db self._op_stack = op_stack + self._column_family = self._db.get_column_family(self.prefix) + if not self._column_family.is_valid: + raise RuntimeError('column family is not valid') - def iterate(self, prefix=None, start=None, stop=None, - reverse: bool = False, include_key: bool = True, include_value: bool = True, - fill_cache: bool = True, deserialize_key: bool = True, deserialize_value: bool = True): + def iterate(self, prefix=None, start=None, stop=None, reverse: bool = False, include_key: bool = True, + include_value: bool = True, fill_cache: bool = True, deserialize_key: bool = True, + deserialize_value: bool = True): if not prefix and not start and not stop: prefix = () if prefix is not None: prefix = self.pack_partial_key(*prefix) - if start is not None: - start = self.pack_partial_key(*start) - if stop is not None: - stop = self.pack_partial_key(*stop) + if stop is None: + try: + stop = (int.from_bytes(prefix, byteorder='big') + 1).to_bytes(len(prefix), byteorder='big') + except OverflowError: + stop = (int.from_bytes(prefix, byteorder='big') + 1).to_bytes(len(prefix) + 1, byteorder='big') + else: + stop = self.pack_partial_key(*stop) + else: + if start is not None: + start = self.pack_partial_key(*start) + if stop is not None: + stop = self.pack_partial_key(*stop) if deserialize_key: key_getter = lambda k: self.unpack_key(k) @@ -64,25 +77,27 @@ class PrefixRow(metaclass=PrefixRowType): else: value_getter = lambda v: v + it = self._db.iterator( + start or prefix, self._column_family, iterate_lower_bound=None, + iterate_upper_bound=stop, reverse=reverse, include_key=include_key, + include_value=include_value, fill_cache=fill_cache, prefix_same_as_start=True + ) + if include_key and include_value: - for k, v in self._db.iterator(prefix=prefix, start=start, stop=stop, reverse=reverse, - fill_cache=fill_cache): - yield key_getter(k), value_getter(v) + for k, v in it: + yield key_getter(k[1]), value_getter(v) elif include_key: - for k in self._db.iterator(prefix=prefix, start=start, stop=stop, reverse=reverse, include_value=False, - fill_cache=fill_cache): - yield key_getter(k) + for k in it: + yield key_getter(k[1]) elif include_value: - for v in self._db.iterator(prefix=prefix, start=start, stop=stop, reverse=reverse, include_key=False, - fill_cache=fill_cache): + for v in it: yield value_getter(v) else: - for _ in self._db.iterator(prefix=prefix, start=start, stop=stop, reverse=reverse, include_key=False, - include_value=False, fill_cache=fill_cache): + for _ in it: yield None def get(self, *key_args, fill_cache=True, deserialize_value=True): - v = self._db.get(self.pack_key(*key_args), fill_cache=fill_cache) + v = self._db.get((self._column_family, self.pack_key(*key_args)), fill_cache=fill_cache) if v: return v if not deserialize_value else self.unpack_value(v) @@ -94,7 +109,7 @@ class PrefixRow(metaclass=PrefixRowType): return last_op.value if not deserialize_value else self.unpack_value(last_op.value) else: # it's a delete return - v = self._db.get(packed_key, fill_cache=fill_cache) + v = self._db.get((self._column_family, packed_key), fill_cache=fill_cache) if v: return v if not deserialize_value else self.unpack_value(v) @@ -118,7 +133,7 @@ class PrefixRow(metaclass=PrefixRowType): @classmethod def unpack_key(cls, key: bytes): - assert key[:1] == cls.prefix + assert key[:1] == cls.prefix, f"prefix should be {cls.prefix}, got {key[:1]}" return cls.key_struct.unpack(key[1:]) @classmethod @@ -1571,7 +1586,7 @@ class DBStatePrefixRow(PrefixRow): class BlockTxsPrefixRow(PrefixRow): - prefix = DB_PREFIXES.block_txs.value + prefix = DB_PREFIXES.block_tx.value key_struct = struct.Struct(b'>L') key_part_lambdas = [ lambda: b'', @@ -1600,12 +1615,18 @@ class BlockTxsPrefixRow(PrefixRow): return cls.pack_key(height), cls.pack_value(tx_hashes) -class MempoolTxKey(TxKey): - pass +class MempoolTxKey(NamedTuple): + tx_hash: bytes + + def __str__(self): + return f"{self.__class__.__name__}(tx_hash={self.tx_hash[::-1].hex()})" -class MempoolTxValue(TxValue): - pass +class MempoolTxValue(NamedTuple): + raw_tx: bytes + + def __str__(self): + return f"{self.__class__.__name__}(raw_tx={base64.b64encode(self.raw_tx).decode()})" class MempoolTXPrefixRow(PrefixRow): @@ -1724,8 +1745,9 @@ class TouchedHashXPrefixRow(PrefixRow): class HubDB(PrefixDB): def __init__(self, path: str, cache_mb: int = 128, reorg_limit: int = 200, max_open_files: int = 512, secondary_path: str = '', unsafe_prefixes: Optional[typing.Set[bytes]] = None): - db = RocksDBStore(path, cache_mb, max_open_files, secondary_path=secondary_path) - super().__init__(db, reorg_limit, unsafe_prefixes=unsafe_prefixes) + super().__init__(path, max_open_files=max_open_files, secondary_path=secondary_path, + max_undo_depth=reorg_limit, unsafe_prefixes=unsafe_prefixes) + db = self._db self.claim_to_support = ClaimToSupportPrefixRow(db, self._op_stack) self.support_to_claim = SupportToClaimPrefixRow(db, self._op_stack) self.claim_to_txo = ClaimToTXOPrefixRow(db, self._op_stack) diff --git a/tests/unit/wallet/server/test_revertable.py b/tests/unit/wallet/server/test_revertable.py index b3fe03a57..db9e1d0e9 100644 --- a/tests/unit/wallet/server/test_revertable.py +++ b/tests/unit/wallet/server/test_revertable.py @@ -151,3 +151,61 @@ class TestRevertablePrefixDB(unittest.TestCase): self.assertEqual(10000000, self.db.claim_takeover.get(name).height) self.db.rollback(10000000, b'\x00' * 32) self.assertIsNone(self.db.claim_takeover.get(name)) + + def test_hub_db_iterator(self): + name = 'derp' + claim_hash0 = 20 * b'\x00' + claim_hash1 = 20 * b'\x01' + claim_hash2 = 20 * b'\x02' + claim_hash3 = 20 * b'\x03' + overflow_value = 0xffffffff + self.db.claim_expiration.stage_put((99, 999, 0), (claim_hash0, name)) + self.db.claim_expiration.stage_put((100, 1000, 0), (claim_hash1, name)) + self.db.claim_expiration.stage_put((100, 1001, 0), (claim_hash2, name)) + self.db.claim_expiration.stage_put((101, 1002, 0), (claim_hash3, name)) + self.db.claim_expiration.stage_put((overflow_value - 1, 1003, 0), (claim_hash3, name)) + self.db.claim_expiration.stage_put((overflow_value, 1004, 0), (claim_hash3, name)) + self.db.tx_num.stage_put((b'\x00' * 32,), (101,)) + self.db.claim_takeover.stage_put((name,), (claim_hash3, 101)) + self.db.db_state.stage_put((), (b'n?\xcf\x12\x99\xd4\xec]y\xc3\xa4\xc9\x1dbJJ\xcf\x9e.\x17=\x95\xa1\xa0POgvihuV', 0, 1, b'VuhivgOP\xa0\xa1\x95=\x17.\x9e\xcfJJb\x1d\xc9\xa4\xc3y]\xec\xd4\x99\x12\xcf?n', 1, 0, 1, 7, 1, -1, -1, 0)) + self.db.unsafe_commit() + + state = self.db.db_state.get() + self.assertEqual(b'n?\xcf\x12\x99\xd4\xec]y\xc3\xa4\xc9\x1dbJJ\xcf\x9e.\x17=\x95\xa1\xa0POgvihuV', state.genesis) + + self.assertListEqual( + [], list(self.db.claim_expiration.iterate(prefix=(98,))) + ) + self.assertListEqual( + list(self.db.claim_expiration.iterate(start=(98,), stop=(99,))), + list(self.db.claim_expiration.iterate(prefix=(98,))) + ) + self.assertListEqual( + list(self.db.claim_expiration.iterate(start=(99,), stop=(100,))), + list(self.db.claim_expiration.iterate(prefix=(99,))) + ) + self.assertListEqual( + [ + ((99, 999, 0), (claim_hash0, name)), + ], list(self.db.claim_expiration.iterate(prefix=(99,))) + ) + self.assertListEqual( + [ + ((100, 1000, 0), (claim_hash1, name)), + ((100, 1001, 0), (claim_hash2, name)) + ], list(self.db.claim_expiration.iterate(prefix=(100,))) + ) + self.assertListEqual( + list(self.db.claim_expiration.iterate(start=(100,), stop=(101,))), + list(self.db.claim_expiration.iterate(prefix=(100,))) + ) + self.assertListEqual( + [ + ((overflow_value - 1, 1003, 0), (claim_hash3, name)) + ], list(self.db.claim_expiration.iterate(prefix=(overflow_value - 1,))) + ) + self.assertListEqual( + [ + ((overflow_value, 1004, 0), (claim_hash3, name)) + ], list(self.db.claim_expiration.iterate(prefix=(overflow_value,))) + )