From 43bef9447c0fd2cec62e9be35d49fb07eb327f8f Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Thu, 12 Jul 2018 15:44:07 -0400 Subject: [PATCH] progress on publish command: py3 porting and integration tests --- lbrynet/core/StreamDescriptor.py | 21 ++++++-- lbrynet/daemon/Components.py | 1 + lbrynet/daemon/Publisher.py | 6 +-- lbrynet/database/storage.py | 8 ++- lbrynet/file_manager/EncryptedFileCreator.py | 7 +-- lbrynet/reflector/client/client.py | 8 +-- lbrynet/wallet/manager.py | 17 ++++++ tests/integration/wallet/test_commands.py | 57 +++++++++++++++++--- 8 files changed, 101 insertions(+), 24 deletions(-) diff --git a/lbrynet/core/StreamDescriptor.py b/lbrynet/core/StreamDescriptor.py index d0dd951ad..b53b268d2 100644 --- a/lbrynet/core/StreamDescriptor.py +++ b/lbrynet/core/StreamDescriptor.py @@ -1,3 +1,4 @@ +import six import binascii from collections import defaultdict import json @@ -66,6 +67,16 @@ class BlobStreamDescriptorReader(StreamDescriptorReader): return threads.deferToThread(get_data) +def bytes2unicode(value): + if isinstance(value, bytes): + return value.decode() + elif isinstance(value, (list, tuple)): + return [bytes2unicode(v) for v in value] + elif isinstance(value, dict): + return {key: bytes2unicode(v) for key, v in value.items()} + return value + + class StreamDescriptorWriter(object): """Classes which derive from this class write fields from a dictionary of fields to a stream descriptor""" @@ -73,7 +84,7 @@ class StreamDescriptorWriter(object): pass def create_descriptor(self, sd_info): - return self._write_stream_descriptor(json.dumps(sd_info, sort_keys=True)) + return self._write_stream_descriptor(json.dumps(bytes2unicode(sd_info), sort_keys=True)) def _write_stream_descriptor(self, raw_data): """This method must be overridden by subclasses to write raw data to @@ -345,9 +356,9 @@ def get_blob_hashsum(b): blob_hashsum = get_lbry_hash_obj() if length != 0: blob_hashsum.update(blob_hash) - blob_hashsum.update(str(blob_num)) + blob_hashsum.update(str(blob_num).encode()) blob_hashsum.update(iv) - blob_hashsum.update(str(length)) + blob_hashsum.update(str(length).encode()) return blob_hashsum.digest() @@ -365,7 +376,7 @@ def get_stream_hash(hex_stream_name, key, hex_suggested_file_name, blob_infos): def verify_hex(text, field_name): for c in text: - if c not in '0123456789abcdef': + if c not in b'0123456789abcdef': raise InvalidStreamDescriptorError("%s is not a hex-encoded string" % field_name) @@ -391,7 +402,7 @@ def validate_descriptor(stream_info): calculated_stream_hash = get_stream_hash( hex_stream_name, key, hex_suggested_file_name, blobs - ) + ).encode() if calculated_stream_hash != stream_hash: raise InvalidStreamDescriptorError("Stream hash does not match stream metadata") return True diff --git a/lbrynet/daemon/Components.py b/lbrynet/daemon/Components.py index dd84ff734..424290ed4 100644 --- a/lbrynet/daemon/Components.py +++ b/lbrynet/daemon/Components.py @@ -332,6 +332,7 @@ class WalletComponent(Component): storage = self.component_manager.get_component(DATABASE_COMPONENT) lbryschema.BLOCKCHAIN_NAME = conf.settings['blockchain_name'] self.wallet = LbryWalletManager.from_old_config(conf.settings) + self.wallet.old_db = storage yield self.wallet.start() @defer.inlineCallbacks diff --git a/lbrynet/daemon/Publisher.py b/lbrynet/daemon/Publisher.py index 16f13dd28..3bd8d72e2 100644 --- a/lbrynet/daemon/Publisher.py +++ b/lbrynet/daemon/Publisher.py @@ -49,7 +49,7 @@ class Publisher(object): # check if we have a file already for this claim (if this is a publish update with a new stream) old_stream_hashes = yield self.storage.get_old_stream_hashes_for_claim_id( - tx.get_claim_id(0).decode(), self.lbry_file.stream_hash + tx.get_claim_id(0), self.lbry_file.stream_hash.decode() ) if old_stream_hashes: for lbry_file in filter(lambda l: l.stream_hash in old_stream_hashes, @@ -58,7 +58,7 @@ class Publisher(object): log.info("Removed old stream for claim update: %s", lbry_file.stream_hash) yield self.storage.save_content_claim( - self.lbry_file.stream_hash, get_certificate_lookup(tx, 0).decode() + self.lbry_file.stream_hash.decode(), get_certificate_lookup(tx, 0) ) defer.returnValue(tx) @@ -70,7 +70,7 @@ class Publisher(object): ) if stream_hash: # the stream_hash returned from the db will be None if this isn't a stream we have yield self.storage.save_content_claim( - stream_hash, get_certificate_lookup(tx, 0).decode() + stream_hash.decode(), get_certificate_lookup(tx, 0) ) self.lbry_file = [f for f in self.lbry_file_manager.lbry_files if f.stream_hash == stream_hash][0] defer.returnValue(tx) diff --git a/lbrynet/database/storage.py b/lbrynet/database/storage.py index b8e92251e..861b99515 100644 --- a/lbrynet/database/storage.py +++ b/lbrynet/database/storage.py @@ -2,6 +2,7 @@ import logging import os import sqlite3 import traceback +from binascii import hexlify, unhexlify from decimal import Decimal from twisted.internet import defer, task, threads from twisted.enterprise import adbapi @@ -613,7 +614,7 @@ class SQLiteStorage(WalletDatabase): source_hash = None except AttributeError: source_hash = None - serialized = claim_info.get('hex') or smart_decode(claim_info['value']).serialized.encode('hex') + serialized = claim_info.get('hex') or hexlify(smart_decode(claim_info['value']).serialized) transaction.execute( "insert or replace into claim values (?, ?, ?, ?, ?, ?, ?, ?, ?)", (outpoint, claim_id, name, amount, height, serialized, certificate_id, address, sequence) @@ -671,12 +672,15 @@ class SQLiteStorage(WalletDatabase): ).fetchone() if not claim_info: raise Exception("claim not found") - new_claim_id, claim = claim_info[0], ClaimDict.deserialize(claim_info[1].decode('hex')) + new_claim_id, claim = claim_info[0], ClaimDict.deserialize(unhexlify(claim_info[1])) # certificate claims should not be in the content_claim table if not claim.is_stream: raise Exception("claim does not contain a stream") + if not isinstance(stream_hash, bytes): + stream_hash = stream_hash.encode() + # get the known sd hash for this stream known_sd_hash = transaction.execute( "select sd_hash from stream where stream_hash=?", (stream_hash,) diff --git a/lbrynet/file_manager/EncryptedFileCreator.py b/lbrynet/file_manager/EncryptedFileCreator.py index a5411d2ec..91ab26d93 100644 --- a/lbrynet/file_manager/EncryptedFileCreator.py +++ b/lbrynet/file_manager/EncryptedFileCreator.py @@ -2,6 +2,7 @@ Utilities for turning plain files into LBRY Files. """ +import six import binascii import logging import os @@ -44,7 +45,7 @@ class EncryptedFileStreamCreator(CryptStreamCreator): # generate the sd info self.sd_info = format_sd_info( EncryptedFileStreamType, hexlify(self.name), hexlify(self.key), - hexlify(self.name), self.stream_hash, self.blob_infos + hexlify(self.name), self.stream_hash.encode(), self.blob_infos ) # sanity check @@ -125,14 +126,14 @@ def create_lbry_file(blob_manager, storage, payment_rate_manager, lbry_file_mana ) log.debug("adding to the file manager") lbry_file = yield lbry_file_manager.add_published_file( - sd_info['stream_hash'], sd_hash, binascii.hexlify(file_directory), payment_rate_manager, + sd_info['stream_hash'], sd_hash, binascii.hexlify(file_directory.encode()), payment_rate_manager, payment_rate_manager.min_blob_data_payment_rate ) defer.returnValue(lbry_file) def hexlify(str_or_unicode): - if isinstance(str_or_unicode, unicode): + if isinstance(str_or_unicode, six.text_type): strng = str_or_unicode.encode('utf-8') else: strng = str_or_unicode diff --git a/lbrynet/reflector/client/client.py b/lbrynet/reflector/client/client.py index 09c4694c4..1dd33144e 100644 --- a/lbrynet/reflector/client/client.py +++ b/lbrynet/reflector/client/client.py @@ -16,8 +16,8 @@ class EncryptedFileReflectorClient(Protocol): # Protocol stuff def connectionMade(self): log.debug("Connected to reflector") - self.response_buff = '' - self.outgoing_buff = '' + self.response_buff = b'' + self.outgoing_buff = b'' self.blob_hashes_to_send = [] self.failed_blob_hashes = [] self.next_blob_to_send = None @@ -50,7 +50,7 @@ class EncryptedFileReflectorClient(Protocol): except IncompleteResponse: pass else: - self.response_buff = '' + self.response_buff = b'' d = self.handle_response(msg) d.addCallback(lambda _: self.send_next_request()) d.addErrback(self.response_failure_handler) @@ -143,7 +143,7 @@ class EncryptedFileReflectorClient(Protocol): return d def send_request(self, request_dict): - self.write(json.dumps(request_dict)) + self.write(json.dumps(request_dict).encode()) def send_handshake(self): self.send_request({'version': self.protocol_version}) diff --git a/lbrynet/wallet/manager.py b/lbrynet/wallet/manager.py index 3353fcf60..8df66573c 100644 --- a/lbrynet/wallet/manager.py +++ b/lbrynet/wallet/manager.py @@ -1,5 +1,6 @@ import os import json +from binascii import hexlify from twisted.internet import defer from torba.manager import WalletManager as BaseWalletManager @@ -160,9 +161,25 @@ class LbryWalletManager(BaseWalletManager): ) tx = yield Transaction.claim(name.encode(), claim, amount, claim_address, [account], account) yield account.ledger.broadcast(tx) + yield self.old_db.save_claims([self._old_get_temp_claim_info( + tx, tx.outputs[0], claim_address, claim_dict, name, amount + )]) # TODO: release reserved tx outputs in case anything fails by this point defer.returnValue(tx) + def _old_get_temp_claim_info(self, tx, txo, address, claim_dict, name, bid): + return { + "claim_id": hexlify(tx.get_claim_id(txo.index)).decode(), + "name": name, + "amount": bid, + "address": address.decode(), + "txid": tx.hex_id.decode(), + "nout": txo.index, + "value": claim_dict, + "height": -1, + "claim_sequence": -1, + } + @defer.inlineCallbacks def claim_new_channel(self, channel_name, amount): try: diff --git a/tests/integration/wallet/test_commands.py b/tests/integration/wallet/test_commands.py index 8102888b5..a6ee34892 100644 --- a/tests/integration/wallet/test_commands.py +++ b/tests/integration/wallet/test_commands.py @@ -1,4 +1,7 @@ +import six import tempfile +from types import SimpleNamespace +from binascii import hexlify from twisted.internet import defer from orchstr8.testcase import IntegrationTestCase, d2f @@ -25,14 +28,49 @@ class FakeAnalytics: pass +class FakeBlob: + def __init__(self): + self.data = [] + self.blob_hash = 'abc' + self.length = 3 + + def write(self, data): + self.data.append(data) + + def close(self): + if self.data: + return defer.succeed(hexlify(b'a'*48)) + return defer.succeed(None) + + def get_is_verified(self): + return True + + def open_for_reading(self): + return six.StringIO('foo') + + class FakeBlobManager: def get_blob_creator(self): - return None + return FakeBlob() + + def creator_finished(self, blob_info, should_announce): + pass + + def get_blob(self, sd_hash): + return FakeBlob() class FakeSession: - storage = None blob_manager = FakeBlobManager() + peer_finder = None + rate_limiter = None + + + @property + def payment_rate_manager(self): + obj = SimpleNamespace() + obj.min_blob_data_payment_rate = 1 + return obj class CommandTestCase(IntegrationTestCase): @@ -68,22 +106,27 @@ class CommandTestCase(IntegrationTestCase): self.daemon.wallet = self.manager self.daemon.component_manager.components.add(wallet_component) + storage_component = DatabaseComponent(self.daemon.component_manager) + await d2f(storage_component.start()) + self.daemon.storage = storage_component.storage + self.daemon.wallet.old_db = self.daemon.storage + self.daemon.component_manager.components.add(storage_component) + session_component = SessionComponent(self.daemon.component_manager) session_component.session = FakeSession() session_component._running = True self.daemon.session = session_component.session + self.daemon.session.storage = self.daemon.storage + self.daemon.session.wallet = self.daemon.wallet + self.daemon.session.blob_manager.storage = self.daemon.storage self.daemon.component_manager.components.add(session_component) file_manager = FileManager(self.daemon.component_manager) file_manager.file_manager = EncryptedFileManager(session_component.session, True) file_manager._running = True + self.daemon.file_manager = file_manager.file_manager self.daemon.component_manager.components.add(file_manager) - storage_component = DatabaseComponent(self.daemon.component_manager) - await d2f(storage_component.start()) - self.daemon.storage = storage_component.storage - self.daemon.component_manager.components.add(storage_component) - class ChannelNewCommandTests(CommandTestCase):