diff --git a/.pylintrc b/.pylintrc index 1c71b985f..68f76d980 100644 --- a/.pylintrc +++ b/.pylintrc @@ -124,7 +124,8 @@ disable= keyword-arg-before-vararg, assignment-from-no-return, useless-return, - assignment-from-none + assignment-from-none, + stop-iteration-return [REPORTS] @@ -389,7 +390,7 @@ int-import-graph= [DESIGN] # Maximum number of arguments for function / method -max-args=5 +max-args=10 # Argument names that match this expression will be ignored. Default to name # with leading underscore diff --git a/lbrynet/blob/blob_file.py b/lbrynet/blob/blob_file.py index 918aef56c..4db6b5629 100644 --- a/lbrynet/blob/blob_file.py +++ b/lbrynet/blob/blob_file.py @@ -149,7 +149,7 @@ class BlobFile: def writer_finished(self, writer, err=None): def fire_finished_deferred(): self._verified = True - for p, (w, finished_deferred) in self.writers.items(): + for p, (w, finished_deferred) in list(self.writers.items()): if w == writer: del self.writers[p] finished_deferred.callback(self) diff --git a/lbrynet/blob/creator.py b/lbrynet/blob/creator.py index a90117056..d4edbfaeb 100644 --- a/lbrynet/blob/creator.py +++ b/lbrynet/blob/creator.py @@ -46,8 +46,6 @@ class BlobFileCreator: def write(self, data): if not self._is_open: raise IOError - if not isinstance(data, bytes): - data = data.encode() self._hashsum.update(data) self.len_so_far += len(data) self.buffer.write(data) diff --git a/lbrynet/core/StreamDescriptor.py b/lbrynet/core/StreamDescriptor.py index 381441677..45129269f 100644 --- a/lbrynet/core/StreamDescriptor.py +++ b/lbrynet/core/StreamDescriptor.py @@ -12,6 +12,13 @@ from lbrynet.core.HTTPBlobDownloader import HTTPBlobDownloader log = logging.getLogger(__name__) +class JSONBytesEncoder(json.JSONEncoder): + def default(self, obj): # pylint: disable=E0202 + if isinstance(obj, bytes): + return obj.decode() + return super().default(obj) + + class StreamDescriptorReader: """Classes which derive from this class read a stream descriptor file return a dictionary containing the fields in the file""" @@ -66,16 +73,6 @@ class BlobStreamDescriptorReader(StreamDescriptorReader): return threads.deferToThread(get_data) -def bytes2unicode(value): - if isinstance(value, bytes): - return value.decode() - elif isinstance(value, (list, tuple)): - return [bytes2unicode(v) for v in value] - elif isinstance(value, dict): - return {key: bytes2unicode(v) for key, v in value.items()} - return value - - class StreamDescriptorWriter: """Classes which derive from this class write fields from a dictionary of fields to a stream descriptor""" @@ -83,7 +80,9 @@ class StreamDescriptorWriter: pass def create_descriptor(self, sd_info): - return self._write_stream_descriptor(json.dumps(bytes2unicode(sd_info), sort_keys=True)) + return self._write_stream_descriptor( + json.dumps(sd_info, sort_keys=True, cls=JSONBytesEncoder).encode() + ) def _write_stream_descriptor(self, raw_data): """This method must be overridden by subclasses to write raw data to diff --git a/lbrynet/cryptstream/CryptBlob.py b/lbrynet/cryptstream/CryptBlob.py index 089139352..bc6823c78 100644 --- a/lbrynet/cryptstream/CryptBlob.py +++ b/lbrynet/cryptstream/CryptBlob.py @@ -68,14 +68,14 @@ class StreamBlobDecryptor: def write_bytes(): if self.len_read < self.length: - num_bytes_to_decrypt = greatest_multiple(len(self.buff), (AES.block_size / 8)) + num_bytes_to_decrypt = greatest_multiple(len(self.buff), (AES.block_size // 8)) data_to_decrypt, self.buff = split(self.buff, num_bytes_to_decrypt) write_func(self.cipher.update(data_to_decrypt)) def finish_decrypt(): - bytes_left = len(self.buff) % (AES.block_size / 8) + bytes_left = len(self.buff) % (AES.block_size // 8) if bytes_left != 0: - log.warning(self.buff[-1 * (AES.block_size / 8):].encode('hex')) + log.warning(self.buff[-1 * (AES.block_size // 8):].encode('hex')) raise Exception("blob %s has incorrect padding: %i bytes left" % (self.blob.blob_hash, bytes_left)) data_to_decrypt, self.buff = self.buff, b'' @@ -128,8 +128,6 @@ class CryptStreamBlobMaker: max bytes are written. num_bytes_to_write is the number of bytes that will be written from data in this call """ - if not isinstance(data, bytes): - data = data.encode() max_bytes_to_write = MAX_BLOB_SIZE - self.length - 1 done = False if max_bytes_to_write <= len(data): diff --git a/lbrynet/cryptstream/CryptStreamCreator.py b/lbrynet/cryptstream/CryptStreamCreator.py index ce15f2f07..f9e2494ec 100644 --- a/lbrynet/cryptstream/CryptStreamCreator.py +++ b/lbrynet/cryptstream/CryptStreamCreator.py @@ -102,12 +102,6 @@ class CryptStreamCreator: while 1: yield os.urandom(AES.block_size // 8) - def get_next_iv(self): - iv = next(self.iv_generator) - if not isinstance(iv, bytes): - return iv.encode() - return iv - def setup(self): """Create the symmetric key if it wasn't provided""" @@ -127,7 +121,7 @@ class CryptStreamCreator: yield defer.DeferredList(self.finished_deferreds) self.blob_count += 1 - iv = self.get_next_iv() + iv = next(self.iv_generator) final_blob = self._get_blob_maker(iv, self.blob_manager.get_blob_creator()) stream_terminator = yield final_blob.close() terminator_info = yield self._blob_finished(stream_terminator) @@ -138,7 +132,7 @@ class CryptStreamCreator: if self.current_blob is None: self.next_blob_creator = self.blob_manager.get_blob_creator() self.blob_count += 1 - iv = self.get_next_iv() + iv = next(self.iv_generator) self.current_blob = self._get_blob_maker(iv, self.next_blob_creator) done, num_bytes_written = self.current_blob.write(data) data = data[num_bytes_written:] diff --git a/lbrynet/daemon/auth/server.py b/lbrynet/daemon/auth/server.py index a87cd47bb..cc2105231 100644 --- a/lbrynet/daemon/auth/server.py +++ b/lbrynet/daemon/auth/server.py @@ -19,8 +19,8 @@ from lbrynet.core import utils from lbrynet.core.Error import ComponentsNotStarted, ComponentStartConditionNotMet from lbrynet.core.looping_call_manager import LoopingCallManager from lbrynet.daemon.ComponentManager import ComponentManager -from lbrynet.daemon.auth.util import APIKey, get_auth_message, LBRY_SECRET -from lbrynet.undecorated import undecorated +from .util import APIKey, get_auth_message, LBRY_SECRET +from .undecorated import undecorated from .factory import AuthJSONRPCResource log = logging.getLogger(__name__) diff --git a/lbrynet/undecorated.py b/lbrynet/daemon/auth/undecorated.py similarity index 100% rename from lbrynet/undecorated.py rename to lbrynet/daemon/auth/undecorated.py diff --git a/lbrynet/database/storage.py b/lbrynet/database/storage.py index c24b88d8e..244400ac3 100644 --- a/lbrynet/database/storage.py +++ b/lbrynet/database/storage.py @@ -684,11 +684,8 @@ class SQLiteStorage(WalletDatabase): ).fetchone() if not known_sd_hash: raise Exception("stream not found") - known_sd_hash = known_sd_hash[0] - if not isinstance(known_sd_hash, bytes): - known_sd_hash = known_sd_hash.encode() # check the claim contains the same sd hash - if known_sd_hash != claim.source_hash: + if known_sd_hash[0].encode() != claim.source_hash: raise Exception("stream mismatch") # if there is a current claim associated to the file, check that the new claim is an update to it diff --git a/lbrynet/dht/contact.py b/lbrynet/dht/contact.py index eb0f5c417..ebbda0fba 100644 --- a/lbrynet/dht/contact.py +++ b/lbrynet/dht/contact.py @@ -7,7 +7,7 @@ from lbrynet.dht import constants def is_valid_ipv4(address): try: - ip = ipaddress.ip_address(address.encode().decode()) # this needs to be unicode, thus re-encode-able + ip = ipaddress.ip_address(address) return ip.version == 4 except ipaddress.AddressValueError: return False @@ -22,8 +22,8 @@ class _Contact: def __init__(self, contactManager, id, ipAddress, udpPort, networkProtocol, firstComm): if id is not None: - if not len(id) == constants.key_bits / 8: - raise ValueError("invalid node id: %s" % hexlify(id.encode())) + if not len(id) == constants.key_bits // 8: + raise ValueError("invalid node id: {}".format(hexlify(id).decode())) if not 0 <= udpPort <= 65536: raise ValueError("invalid port") if not is_valid_ipv4(ipAddress): diff --git a/lbrynet/dht/distance.py b/lbrynet/dht/distance.py index 917928211..f07a1474a 100644 --- a/lbrynet/dht/distance.py +++ b/lbrynet/dht/distance.py @@ -1,26 +1,23 @@ from binascii import hexlify from lbrynet.dht import constants -import sys -if sys.version_info > (3,): - long = int class Distance: """Calculate the XOR result between two string variables. Frequently we re-use one of the points so as an optimization - we pre-calculate the long value of that point. + we pre-calculate the value of that point. """ def __init__(self, key): - if len(key) != constants.key_bits / 8: + if len(key) != constants.key_bits // 8: raise ValueError("invalid key length: %i" % len(key)) self.key = key - self.val_key_one = long(hexlify(key), 16) + self.val_key_one = int(hexlify(key), 16) def __call__(self, key_two): - val_key_two = long(hexlify(key_two), 16) + val_key_two = int(hexlify(key_two), 16) return self.val_key_one ^ val_key_two def is_closer(self, a, b): diff --git a/lbrynet/dht/kbucket.py b/lbrynet/dht/kbucket.py index 64027fe1e..a4756bed8 100644 --- a/lbrynet/dht/kbucket.py +++ b/lbrynet/dht/kbucket.py @@ -4,9 +4,6 @@ from binascii import hexlify from . import constants from .distance import Distance from .error import BucketFull -import sys -if sys.version_info > (3,): - long = int log = logging.getLogger(__name__) @@ -141,7 +138,7 @@ class KBucket: @rtype: bool """ if isinstance(key, bytes): - key = long(hexlify(key), 16) + key = int(hexlify(key), 16) return self.rangeMin <= key < self.rangeMax def __len__(self): diff --git a/lbrynet/dht/msgtypes.py b/lbrynet/dht/msgtypes.py index b33f0c035..14e6734f1 100644 --- a/lbrynet/dht/msgtypes.py +++ b/lbrynet/dht/msgtypes.py @@ -17,7 +17,7 @@ class Message: def __init__(self, rpcID, nodeID): if len(rpcID) != constants.rpc_id_length: raise ValueError("invalid rpc id: %i bytes (expected 20)" % len(rpcID)) - if len(nodeID) != constants.key_bits / 8: + if len(nodeID) != constants.key_bits // 8: raise ValueError("invalid node id: %i bytes (expected 48)" % len(nodeID)) self.id = rpcID self.nodeID = nodeID diff --git a/lbrynet/dht/node.py b/lbrynet/dht/node.py index 5decbc782..3433519e4 100644 --- a/lbrynet/dht/node.py +++ b/lbrynet/dht/node.py @@ -1,11 +1,3 @@ -#!/usr/bin/env python -# -# This library is free software, distributed under the terms of -# the GNU Lesser General Public License Version 3, or any later version. -# See the COPYING file included in this archive -# -# The docstrings in this module contain epytext markup; API documentation -# may be created by processing this file with epydoc: http://epydoc.sf.net import binascii import hashlib import struct @@ -34,7 +26,7 @@ def expand_peer(compact_peer_info): host = ".".join([str(ord(d)) for d in compact_peer_info[:4]]) port, = struct.unpack('>H', compact_peer_info[4:6]) peer_node_id = compact_peer_info[6:] - return (peer_node_id, host, port) + return peer_node_id, host, port def rpcmethod(func): @@ -348,7 +340,7 @@ class Node(MockKademliaHelper): stored_to = yield DeferredDict({contact: self.storeToContact(blob_hash, contact) for contact in contacts}) contacted_node_ids = [binascii.hexlify(contact.id) for contact in stored_to.keys() if stored_to[contact]] log.debug("Stored %s to %i of %i attempted peers", binascii.hexlify(blob_hash), - len(list(contacted_node_ids)), len(contacts)) + len(contacted_node_ids), len(contacts)) defer.returnValue(contacted_node_ids) def change_token(self): diff --git a/lbrynet/file_manager/EncryptedFileCreator.py b/lbrynet/file_manager/EncryptedFileCreator.py index 888c7fec7..1bbf6c1e6 100644 --- a/lbrynet/file_manager/EncryptedFileCreator.py +++ b/lbrynet/file_manager/EncryptedFileCreator.py @@ -2,10 +2,9 @@ Utilities for turning plain files into LBRY Files. """ -import six -import binascii import logging import os +from binascii import hexlify from twisted.internet import defer from twisted.protocols.basic import FileSender @@ -38,14 +37,14 @@ class EncryptedFileStreamCreator(CryptStreamCreator): def _finished(self): # calculate the stream hash self.stream_hash = get_stream_hash( - hexlify(self.name), hexlify(self.key), hexlify(self.name), + hexlify(self.name.encode()), hexlify(self.key), hexlify(self.name.encode()), self.blob_infos ) # generate the sd info self.sd_info = format_sd_info( - EncryptedFileStreamType, hexlify(self.name), hexlify(self.key), - hexlify(self.name), self.stream_hash.encode(), self.blob_infos + EncryptedFileStreamType, hexlify(self.name.encode()), hexlify(self.key), + hexlify(self.name.encode()), self.stream_hash.encode(), self.blob_infos ) # sanity check @@ -126,15 +125,7 @@ def create_lbry_file(blob_manager, storage, payment_rate_manager, lbry_file_mana ) log.debug("adding to the file manager") lbry_file = yield lbry_file_manager.add_published_file( - sd_info['stream_hash'], sd_hash, binascii.hexlify(file_directory.encode()), payment_rate_manager, + sd_info['stream_hash'], sd_hash, hexlify(file_directory.encode()), payment_rate_manager, payment_rate_manager.min_blob_data_payment_rate ) defer.returnValue(lbry_file) - - -def hexlify(str_or_unicode): - if isinstance(str_or_unicode, six.text_type): - strng = str_or_unicode.encode('utf-8') - else: - strng = str_or_unicode - return binascii.hexlify(strng) diff --git a/scripts/upload_assets.py b/scripts/upload_assets.py deleted file mode 100644 index ee66ffcc3..000000000 --- a/scripts/upload_assets.py +++ /dev/null @@ -1,149 +0,0 @@ -import glob -import json -import os -import subprocess -import sys - -import github -import uritemplate -import boto3 - - -def main(): - #upload_to_github_if_tagged('lbryio/lbry') - upload_to_s3('daemon') - - -def get_asset_filename(): - root_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) - return glob.glob(os.path.join(root_dir, 'dist/*'))[0] - - -def get_cli_output(command): - return subprocess.check_output(command.split()).decode().strip() - - -def upload_to_s3(folder): - asset_path = get_asset_filename() - branch = get_cli_output('git rev-parse --abbrev-ref HEAD') - if branch = 'master': - tag = get_cli_output('git describe --always --abbrev=8 HEAD') - commit = get_cli_output('git show -s --format=%cd --date=format:%Y%m%d-%H%I%S HEAD') - bucket = 'releases.lbry.io' - key = '{}/{}-{}/{}'.format(folder, commit, tag, os.path.basename(asset_path)) - else: - key = '{}/{}-{}/{}'.format(folder, commit_date, tag, os.path.basename(asset_path)) - - print("Uploading {} to s3://{}/{}".format(asset_path, bucket, key)) - - if 'AWS_ACCESS_KEY_ID' not in os.environ or 'AWS_SECRET_ACCESS_KEY' not in os.environ: - print('Must set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to publish assets to s3') - return 1 - - s3 = boto3.resource( - 's3', - aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'], - aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'], - config=boto3.session.Config(signature_version='s3v4') - ) - s3.meta.client.upload_file(asset_path, bucket, key) - - -def upload_to_github_if_tagged(repo_name): - try: - current_tag = subprocess.check_output( - ['git', 'describe', '--exact-match', 'HEAD']).strip() - except subprocess.CalledProcessError: - print('Not uploading to GitHub as we are not currently on a tag') - return 1 - - print("Current tag: " + current_tag) - - if 'GH_TOKEN' not in os.environ: - print('Must set GH_TOKEN in order to publish assets to a release') - return 1 - - gh_token = os.environ['GH_TOKEN'] - auth = github.Github(gh_token) - repo = auth.get_repo(repo_name) - - if not check_repo_has_tag(repo, current_tag): - print('Tag {} is not in repo {}'.format(current_tag, repo)) - # TODO: maybe this should be an error - return 1 - - asset_path = get_asset_filename() - print("Uploading " + asset_path + " to Github tag " + current_tag) - release = get_github_release(repo, current_tag) - upload_asset_to_github(release, asset_path, gh_token) - - -def check_repo_has_tag(repo, target_tag): - tags = repo.get_tags().get_page(0) - for tag in tags: - if tag.name == target_tag: - return True - return False - - -def get_github_release(repo, current_tag): - for release in repo.get_releases(): - if release.tag_name == current_tag: - return release - raise Exception('No release for {} was found'.format(current_tag)) - - -def upload_asset_to_github(release, asset_to_upload, token): - basename = os.path.basename(asset_to_upload) - for asset in release.raw_data['assets']: - if asset['name'] == basename: - print('File {} has already been uploaded to {}'.format(basename, release.tag_name)) - return - - upload_uri = uritemplate.expand(release.upload_url, {'name': basename}) - count = 0 - while count < 10: - try: - output = _curl_uploader(upload_uri, asset_to_upload, token) - if 'errors' in output: - raise Exception(output) - else: - print('Successfully uploaded to {}'.format(output['browser_download_url'])) - except Exception: - print('Failed uploading on attempt {}'.format(count + 1)) - count += 1 - - -def _curl_uploader(upload_uri, asset_to_upload, token): - # using requests.post fails miserably with SSL EPIPE errors. I spent - # half a day trying to debug before deciding to switch to curl. - # - # TODO: actually set the content type - print('Using curl to upload {} to {}'.format(asset_to_upload, upload_uri)) - cmd = [ - 'curl', - '-sS', - '-X', 'POST', - '-u', ':{}'.format(os.environ['GH_TOKEN']), - '--header', 'Content-Type: application/octet-stream', - '--data-binary', '@-', - upload_uri - ] - # '-d', '{"some_key": "some_value"}', - print('Calling curl:') - print(cmd) - print('') - with open(asset_to_upload, 'rb') as fp: - p = subprocess.Popen(cmd, stdin=fp, stderr=subprocess.PIPE, stdout=subprocess.PIPE) - stdout, stderr = p.communicate() - print('curl return code: {}'.format(p.returncode)) - if stderr: - print('stderr output from curl:') - print(stderr) - print('stdout from curl:') - print(stdout) - return json.loads(stdout) - - -if __name__ == '__main__': - sys.exit(main()) diff --git a/tests/unit/dht/test_contact.py b/tests/unit/dht/test_contact.py index d18baaea3..1ab72a487 100644 --- a/tests/unit/dht/test_contact.py +++ b/tests/unit/dht/test_contact.py @@ -1,3 +1,4 @@ +from binascii import hexlify from twisted.internet import task from twisted.trial import unittest from lbrynet.core.utils import generate_id @@ -10,50 +11,62 @@ class ContactOperatorsTest(unittest.TestCase): def setUp(self): self.contact_manager = ContactManager() self.node_ids = [generate_id(), generate_id(), generate_id()] - self.firstContact = self.contact_manager.make_contact(self.node_ids[1], '127.0.0.1', 1000, None, 1) - self.secondContact = self.contact_manager.make_contact(self.node_ids[0], '192.168.0.1', 1000, None, 32) - self.secondContactCopy = self.contact_manager.make_contact(self.node_ids[0], '192.168.0.1', 1000, None, 32) - self.firstContactDifferentValues = self.contact_manager.make_contact(self.node_ids[1], '192.168.1.20', - 1000, None, 50) - self.assertRaises(ValueError, self.contact_manager.make_contact, self.node_ids[1], '192.168.1.20', - 100000, None) - self.assertRaises(ValueError, self.contact_manager.make_contact, self.node_ids[1], '192.168.1.20.1', - 1000, None) - self.assertRaises(ValueError, self.contact_manager.make_contact, self.node_ids[1], 'this is not an ip', - 1000, None) - self.assertRaises(ValueError, self.contact_manager.make_contact, "this is not a node id", '192.168.1.20.1', - 1000, None) + make_contact = self.contact_manager.make_contact + self.first_contact = make_contact(self.node_ids[1], '127.0.0.1', 1000, None, 1) + self.second_contact = make_contact(self.node_ids[0], '192.168.0.1', 1000, None, 32) + self.second_contact_copy = make_contact(self.node_ids[0], '192.168.0.1', 1000, None, 32) + self.first_contact_different_values = make_contact(self.node_ids[1], '192.168.1.20', 1000, None, 50) - def testNoDuplicateContactObjects(self): - self.assertTrue(self.secondContact is self.secondContactCopy) - self.assertTrue(self.firstContact is not self.firstContactDifferentValues) + def test_make_contact_error_cases(self): + self.assertRaises( + ValueError, self.contact_manager.make_contact, self.node_ids[1], '192.168.1.20', 100000, None) + self.assertRaises( + ValueError, self.contact_manager.make_contact, self.node_ids[1], '192.168.1.20.1', 1000, None) + self.assertRaises( + ValueError, self.contact_manager.make_contact, self.node_ids[1], 'this is not an ip', 1000, None) + self.assertRaises( + ValueError, self.contact_manager.make_contact, b'not valid node id', '192.168.1.20.1', 1000, None) - def testBoolean(self): + def test_no_duplicate_contact_objects(self): + self.assertTrue(self.second_contact is self.second_contact_copy) + self.assertTrue(self.first_contact is not self.first_contact_different_values) + + def test_boolean(self): """ Test "equals" and "not equals" comparisons """ self.assertNotEqual( - self.firstContact, self.secondContact, + self.first_contact, self.second_contact, 'Contacts with different IDs should not be equal.') self.assertEqual( - self.firstContact, self.firstContactDifferentValues, + self.first_contact, self.first_contact_different_values, 'Contacts with same IDs should be equal, even if their other values differ.') self.assertEqual( - self.secondContact, self.secondContactCopy, + self.second_contact, self.second_contact_copy, 'Different copies of the same Contact instance should be equal') - def testIllogicalComparisons(self): + def test_illogical_comparisons(self): """ Test comparisons with non-Contact and non-str types """ msg = '"{}" operator: Contact object should not be equal to {} type' for item in (123, [1, 2, 3], {'key': 'value'}): self.assertNotEqual( - self.firstContact, item, + self.first_contact, item, msg.format('eq', type(item).__name__)) self.assertTrue( - self.firstContact != item, + self.first_contact != item, msg.format('ne', type(item).__name__)) - def testCompactIP(self): - self.assertEqual(self.firstContact.compact_ip(), b'\x7f\x00\x00\x01') - self.assertEqual(self.secondContact.compact_ip(), b'\xc0\xa8\x00\x01') + def test_compact_ip(self): + self.assertEqual(self.first_contact.compact_ip(), b'\x7f\x00\x00\x01') + self.assertEqual(self.second_contact.compact_ip(), b'\xc0\xa8\x00\x01') + + def test_id_log(self): + self.assertEqual(self.first_contact.log_id(False), hexlify(self.node_ids[1])) + self.assertEqual(self.first_contact.log_id(True), hexlify(self.node_ids[1])[:8]) + + def test_hash(self): + # fails with "TypeError: unhashable type: '_Contact'" if __hash__ is not implemented + self.assertEqual( + len({self.first_contact, self.second_contact, self.second_contact_copy}), 2 + ) class TestContactLastReplied(unittest.TestCase): diff --git a/tests/unit/lbryfilemanager/test_EncryptedFileCreator.py b/tests/unit/lbryfilemanager/test_EncryptedFileCreator.py index 1909cae61..feeb14d92 100644 --- a/tests/unit/lbryfilemanager/test_EncryptedFileCreator.py +++ b/tests/unit/lbryfilemanager/test_EncryptedFileCreator.py @@ -1,8 +1,10 @@ -# -*- coding: utf-8 -*- -from cryptography.hazmat.primitives.ciphers.algorithms import AES +import json +import mock from twisted.trial import unittest from twisted.internet import defer +from cryptography.hazmat.primitives.ciphers.algorithms import AES +from lbrynet.database.storage import SQLiteStorage from lbrynet.core.StreamDescriptor import get_sd_info, BlobStreamDescriptorReader from lbrynet.core.StreamDescriptor import StreamDescriptorIdentifier from lbrynet.core.BlobManager import DiskBlobManager @@ -12,7 +14,7 @@ from lbrynet.core.PaymentRateManager import OnlyFreePaymentsManager from lbrynet.database.storage import SQLiteStorage from lbrynet.file_manager import EncryptedFileCreator from lbrynet.file_manager.EncryptedFileManager import EncryptedFileManager -from lbrynet.core.StreamDescriptor import bytes2unicode +from lbrynet.core.StreamDescriptor import JSONBytesEncoder from tests import mocks from tests.util import mk_db_and_blob_dir, rm_db_and_blob_dir @@ -30,7 +32,7 @@ MB = 2**20 def iv_generator(): while True: - yield '3' * (AES.block_size // 8) + yield b'3' * (AES.block_size // 8) class CreateEncryptedFileTest(unittest.TestCase): @@ -87,7 +89,8 @@ class CreateEncryptedFileTest(unittest.TestCase): # this comes from the database, the blobs returned are sorted sd_info = yield get_sd_info(self.storage, lbry_file.stream_hash, include_blobs=True) self.maxDiff = None - self.assertDictEqual(bytes2unicode(sd_info), sd_file_info) + unicode_sd_info = json.loads(json.dumps(sd_info, sort_keys=True, cls=JSONBytesEncoder)) + self.assertDictEqual(unicode_sd_info, sd_file_info) self.assertEqual(sd_info['stream_hash'], expected_stream_hash) self.assertEqual(len(sd_info['blobs']), 3) self.assertNotEqual(sd_info['blobs'][0]['length'], 0)