diff --git a/lbry/wallet/server/db/revertable.py b/lbry/wallet/server/db/revertable.py index d6e957a7b..604c7c60f 100644 --- a/lbry/wallet/server/db/revertable.py +++ b/lbry/wallet/server/db/revertable.py @@ -1,10 +1,10 @@ import struct from string import printable -from collections import OrderedDict, defaultdict -from typing import Tuple, List, Iterable, Callable, Optional +from collections import defaultdict +from typing import Tuple, Iterable, Callable, Optional from lbry.wallet.server.db import DB_PREFIXES -_OP_STRUCT = struct.Struct('>BHH') +_OP_STRUCT = struct.Struct('>BLL') class RevertableOp: @@ -30,7 +30,7 @@ class RevertableOp: Serialize to bytes """ return struct.pack( - f'>BHH{len(self.key)}s{len(self.value)}s', int(self.is_put), len(self.key), len(self.value), self.key, + f'>BLL{len(self.key)}s{len(self.value)}s', int(self.is_put), len(self.key), len(self.value), self.key, self.value ) @@ -42,23 +42,12 @@ class RevertableOp: :param packed: bytes containing at least one packed revertable op :return: tuple of the deserialized op (a put or a delete) and the remaining serialized bytes """ - is_put, key_len, val_len = _OP_STRUCT.unpack(packed[:5]) - key = packed[5:5 + key_len] - value = packed[5 + key_len:5 + key_len + val_len] + is_put, key_len, val_len = _OP_STRUCT.unpack(packed[:9]) + key = packed[9:9 + key_len] + value = packed[9 + key_len:9 + key_len + val_len] if is_put == 1: - return RevertablePut(key, value), packed[5 + key_len + val_len:] - return RevertableDelete(key, value), packed[5 + key_len + val_len:] - - @classmethod - def unpack_stack(cls, packed: bytes) -> List['RevertableOp']: - """ - Deserialize multiple from bytes - """ - ops = [] - while packed: - op, packed = cls.unpack(packed) - ops.append(op) - return ops + return RevertablePut(key, value), packed[9 + key_len + val_len:] + return RevertableDelete(key, value), packed[9 + key_len + val_len:] def __eq__(self, other: 'RevertableOp') -> bool: return (self.is_put, self.key, self.value) == (other.is_put, other.key, other.value) @@ -134,3 +123,16 @@ class RevertableOpStack: for key, ops in self._items.items(): for op in ops: yield op + + def __reversed__(self): + for key, ops in self._items.items(): + for op in reversed(ops): + yield op + + def get_undo_ops(self) -> bytes: + return b''.join(op.invert().pack() for op in reversed(self)) + + def apply_packed_undo_ops(self, packed: bytes): + while packed: + op, packed = RevertableOp.unpack(packed) + self.append(op)