unnecessary list() added during py3 port

instead of recursive bytes2unicode use a proper JSONEncoder to conver bytes->unicode for json.dumps()
removing excessive isinstance(data, bytes) checks
py3: / -> // and list() around .items() that gets modified in loop
moved lbrynet.undecorated to where its actually used and hopefully we can delete it eventually
removed build/upload_assets.py, travis can do all this now
This commit is contained in:
Lex Berezhny 2018-07-31 13:20:25 -04:00 committed by Jack Robison
parent f061ca2b15
commit 10b34d6b33
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
18 changed files with 85 additions and 254 deletions

View file

@ -124,7 +124,8 @@ disable=
keyword-arg-before-vararg, keyword-arg-before-vararg,
assignment-from-no-return, assignment-from-no-return,
useless-return, useless-return,
assignment-from-none assignment-from-none,
stop-iteration-return
[REPORTS] [REPORTS]
@ -389,7 +390,7 @@ int-import-graph=
[DESIGN] [DESIGN]
# Maximum number of arguments for function / method # Maximum number of arguments for function / method
max-args=5 max-args=10
# Argument names that match this expression will be ignored. Default to name # Argument names that match this expression will be ignored. Default to name
# with leading underscore # with leading underscore

View file

@ -149,7 +149,7 @@ class BlobFile:
def writer_finished(self, writer, err=None): def writer_finished(self, writer, err=None):
def fire_finished_deferred(): def fire_finished_deferred():
self._verified = True self._verified = True
for p, (w, finished_deferred) in self.writers.items(): for p, (w, finished_deferred) in list(self.writers.items()):
if w == writer: if w == writer:
del self.writers[p] del self.writers[p]
finished_deferred.callback(self) finished_deferred.callback(self)

View file

@ -46,8 +46,6 @@ class BlobFileCreator:
def write(self, data): def write(self, data):
if not self._is_open: if not self._is_open:
raise IOError raise IOError
if not isinstance(data, bytes):
data = data.encode()
self._hashsum.update(data) self._hashsum.update(data)
self.len_so_far += len(data) self.len_so_far += len(data)
self.buffer.write(data) self.buffer.write(data)

View file

@ -12,6 +12,13 @@ from lbrynet.core.HTTPBlobDownloader import HTTPBlobDownloader
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class JSONBytesEncoder(json.JSONEncoder):
def default(self, obj): # pylint: disable=E0202
if isinstance(obj, bytes):
return obj.decode()
return super().default(obj)
class StreamDescriptorReader: class StreamDescriptorReader:
"""Classes which derive from this class read a stream descriptor file return """Classes which derive from this class read a stream descriptor file return
a dictionary containing the fields in the file""" a dictionary containing the fields in the file"""
@ -66,16 +73,6 @@ class BlobStreamDescriptorReader(StreamDescriptorReader):
return threads.deferToThread(get_data) return threads.deferToThread(get_data)
def bytes2unicode(value):
if isinstance(value, bytes):
return value.decode()
elif isinstance(value, (list, tuple)):
return [bytes2unicode(v) for v in value]
elif isinstance(value, dict):
return {key: bytes2unicode(v) for key, v in value.items()}
return value
class StreamDescriptorWriter: class StreamDescriptorWriter:
"""Classes which derive from this class write fields from a dictionary """Classes which derive from this class write fields from a dictionary
of fields to a stream descriptor""" of fields to a stream descriptor"""
@ -83,7 +80,9 @@ class StreamDescriptorWriter:
pass pass
def create_descriptor(self, sd_info): def create_descriptor(self, sd_info):
return self._write_stream_descriptor(json.dumps(bytes2unicode(sd_info), sort_keys=True)) return self._write_stream_descriptor(
json.dumps(sd_info, sort_keys=True, cls=JSONBytesEncoder).encode()
)
def _write_stream_descriptor(self, raw_data): def _write_stream_descriptor(self, raw_data):
"""This method must be overridden by subclasses to write raw data to """This method must be overridden by subclasses to write raw data to

View file

@ -68,14 +68,14 @@ class StreamBlobDecryptor:
def write_bytes(): def write_bytes():
if self.len_read < self.length: if self.len_read < self.length:
num_bytes_to_decrypt = greatest_multiple(len(self.buff), (AES.block_size / 8)) num_bytes_to_decrypt = greatest_multiple(len(self.buff), (AES.block_size // 8))
data_to_decrypt, self.buff = split(self.buff, num_bytes_to_decrypt) data_to_decrypt, self.buff = split(self.buff, num_bytes_to_decrypt)
write_func(self.cipher.update(data_to_decrypt)) write_func(self.cipher.update(data_to_decrypt))
def finish_decrypt(): def finish_decrypt():
bytes_left = len(self.buff) % (AES.block_size / 8) bytes_left = len(self.buff) % (AES.block_size // 8)
if bytes_left != 0: if bytes_left != 0:
log.warning(self.buff[-1 * (AES.block_size / 8):].encode('hex')) log.warning(self.buff[-1 * (AES.block_size // 8):].encode('hex'))
raise Exception("blob %s has incorrect padding: %i bytes left" % raise Exception("blob %s has incorrect padding: %i bytes left" %
(self.blob.blob_hash, bytes_left)) (self.blob.blob_hash, bytes_left))
data_to_decrypt, self.buff = self.buff, b'' data_to_decrypt, self.buff = self.buff, b''
@ -128,8 +128,6 @@ class CryptStreamBlobMaker:
max bytes are written. num_bytes_to_write is the number max bytes are written. num_bytes_to_write is the number
of bytes that will be written from data in this call of bytes that will be written from data in this call
""" """
if not isinstance(data, bytes):
data = data.encode()
max_bytes_to_write = MAX_BLOB_SIZE - self.length - 1 max_bytes_to_write = MAX_BLOB_SIZE - self.length - 1
done = False done = False
if max_bytes_to_write <= len(data): if max_bytes_to_write <= len(data):

View file

@ -102,12 +102,6 @@ class CryptStreamCreator:
while 1: while 1:
yield os.urandom(AES.block_size // 8) yield os.urandom(AES.block_size // 8)
def get_next_iv(self):
iv = next(self.iv_generator)
if not isinstance(iv, bytes):
return iv.encode()
return iv
def setup(self): def setup(self):
"""Create the symmetric key if it wasn't provided""" """Create the symmetric key if it wasn't provided"""
@ -127,7 +121,7 @@ class CryptStreamCreator:
yield defer.DeferredList(self.finished_deferreds) yield defer.DeferredList(self.finished_deferreds)
self.blob_count += 1 self.blob_count += 1
iv = self.get_next_iv() iv = next(self.iv_generator)
final_blob = self._get_blob_maker(iv, self.blob_manager.get_blob_creator()) final_blob = self._get_blob_maker(iv, self.blob_manager.get_blob_creator())
stream_terminator = yield final_blob.close() stream_terminator = yield final_blob.close()
terminator_info = yield self._blob_finished(stream_terminator) terminator_info = yield self._blob_finished(stream_terminator)
@ -138,7 +132,7 @@ class CryptStreamCreator:
if self.current_blob is None: if self.current_blob is None:
self.next_blob_creator = self.blob_manager.get_blob_creator() self.next_blob_creator = self.blob_manager.get_blob_creator()
self.blob_count += 1 self.blob_count += 1
iv = self.get_next_iv() iv = next(self.iv_generator)
self.current_blob = self._get_blob_maker(iv, self.next_blob_creator) self.current_blob = self._get_blob_maker(iv, self.next_blob_creator)
done, num_bytes_written = self.current_blob.write(data) done, num_bytes_written = self.current_blob.write(data)
data = data[num_bytes_written:] data = data[num_bytes_written:]

View file

@ -19,8 +19,8 @@ from lbrynet.core import utils
from lbrynet.core.Error import ComponentsNotStarted, ComponentStartConditionNotMet from lbrynet.core.Error import ComponentsNotStarted, ComponentStartConditionNotMet
from lbrynet.core.looping_call_manager import LoopingCallManager from lbrynet.core.looping_call_manager import LoopingCallManager
from lbrynet.daemon.ComponentManager import ComponentManager from lbrynet.daemon.ComponentManager import ComponentManager
from lbrynet.daemon.auth.util import APIKey, get_auth_message, LBRY_SECRET from .util import APIKey, get_auth_message, LBRY_SECRET
from lbrynet.undecorated import undecorated from .undecorated import undecorated
from .factory import AuthJSONRPCResource from .factory import AuthJSONRPCResource
log = logging.getLogger(__name__) log = logging.getLogger(__name__)

View file

@ -684,11 +684,8 @@ class SQLiteStorage(WalletDatabase):
).fetchone() ).fetchone()
if not known_sd_hash: if not known_sd_hash:
raise Exception("stream not found") raise Exception("stream not found")
known_sd_hash = known_sd_hash[0]
if not isinstance(known_sd_hash, bytes):
known_sd_hash = known_sd_hash.encode()
# check the claim contains the same sd hash # check the claim contains the same sd hash
if known_sd_hash != claim.source_hash: if known_sd_hash[0].encode() != claim.source_hash:
raise Exception("stream mismatch") raise Exception("stream mismatch")
# if there is a current claim associated to the file, check that the new claim is an update to it # if there is a current claim associated to the file, check that the new claim is an update to it

View file

@ -7,7 +7,7 @@ from lbrynet.dht import constants
def is_valid_ipv4(address): def is_valid_ipv4(address):
try: try:
ip = ipaddress.ip_address(address.encode().decode()) # this needs to be unicode, thus re-encode-able ip = ipaddress.ip_address(address)
return ip.version == 4 return ip.version == 4
except ipaddress.AddressValueError: except ipaddress.AddressValueError:
return False return False
@ -22,8 +22,8 @@ class _Contact:
def __init__(self, contactManager, id, ipAddress, udpPort, networkProtocol, firstComm): def __init__(self, contactManager, id, ipAddress, udpPort, networkProtocol, firstComm):
if id is not None: if id is not None:
if not len(id) == constants.key_bits / 8: if not len(id) == constants.key_bits // 8:
raise ValueError("invalid node id: %s" % hexlify(id.encode())) raise ValueError("invalid node id: {}".format(hexlify(id).decode()))
if not 0 <= udpPort <= 65536: if not 0 <= udpPort <= 65536:
raise ValueError("invalid port") raise ValueError("invalid port")
if not is_valid_ipv4(ipAddress): if not is_valid_ipv4(ipAddress):

View file

@ -1,26 +1,23 @@
from binascii import hexlify from binascii import hexlify
from lbrynet.dht import constants from lbrynet.dht import constants
import sys
if sys.version_info > (3,):
long = int
class Distance: class Distance:
"""Calculate the XOR result between two string variables. """Calculate the XOR result between two string variables.
Frequently we re-use one of the points so as an optimization Frequently we re-use one of the points so as an optimization
we pre-calculate the long value of that point. we pre-calculate the value of that point.
""" """
def __init__(self, key): def __init__(self, key):
if len(key) != constants.key_bits / 8: if len(key) != constants.key_bits // 8:
raise ValueError("invalid key length: %i" % len(key)) raise ValueError("invalid key length: %i" % len(key))
self.key = key self.key = key
self.val_key_one = long(hexlify(key), 16) self.val_key_one = int(hexlify(key), 16)
def __call__(self, key_two): def __call__(self, key_two):
val_key_two = long(hexlify(key_two), 16) val_key_two = int(hexlify(key_two), 16)
return self.val_key_one ^ val_key_two return self.val_key_one ^ val_key_two
def is_closer(self, a, b): def is_closer(self, a, b):

View file

@ -4,9 +4,6 @@ from binascii import hexlify
from . import constants from . import constants
from .distance import Distance from .distance import Distance
from .error import BucketFull from .error import BucketFull
import sys
if sys.version_info > (3,):
long = int
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -141,7 +138,7 @@ class KBucket:
@rtype: bool @rtype: bool
""" """
if isinstance(key, bytes): if isinstance(key, bytes):
key = long(hexlify(key), 16) key = int(hexlify(key), 16)
return self.rangeMin <= key < self.rangeMax return self.rangeMin <= key < self.rangeMax
def __len__(self): def __len__(self):

View file

@ -17,7 +17,7 @@ class Message:
def __init__(self, rpcID, nodeID): def __init__(self, rpcID, nodeID):
if len(rpcID) != constants.rpc_id_length: if len(rpcID) != constants.rpc_id_length:
raise ValueError("invalid rpc id: %i bytes (expected 20)" % len(rpcID)) raise ValueError("invalid rpc id: %i bytes (expected 20)" % len(rpcID))
if len(nodeID) != constants.key_bits / 8: if len(nodeID) != constants.key_bits // 8:
raise ValueError("invalid node id: %i bytes (expected 48)" % len(nodeID)) raise ValueError("invalid node id: %i bytes (expected 48)" % len(nodeID))
self.id = rpcID self.id = rpcID
self.nodeID = nodeID self.nodeID = nodeID

View file

@ -1,11 +1,3 @@
#!/usr/bin/env python
#
# This library is free software, distributed under the terms of
# the GNU Lesser General Public License Version 3, or any later version.
# See the COPYING file included in this archive
#
# The docstrings in this module contain epytext markup; API documentation
# may be created by processing this file with epydoc: http://epydoc.sf.net
import binascii import binascii
import hashlib import hashlib
import struct import struct
@ -34,7 +26,7 @@ def expand_peer(compact_peer_info):
host = ".".join([str(ord(d)) for d in compact_peer_info[:4]]) host = ".".join([str(ord(d)) for d in compact_peer_info[:4]])
port, = struct.unpack('>H', compact_peer_info[4:6]) port, = struct.unpack('>H', compact_peer_info[4:6])
peer_node_id = compact_peer_info[6:] peer_node_id = compact_peer_info[6:]
return (peer_node_id, host, port) return peer_node_id, host, port
def rpcmethod(func): def rpcmethod(func):
@ -348,7 +340,7 @@ class Node(MockKademliaHelper):
stored_to = yield DeferredDict({contact: self.storeToContact(blob_hash, contact) for contact in contacts}) stored_to = yield DeferredDict({contact: self.storeToContact(blob_hash, contact) for contact in contacts})
contacted_node_ids = [binascii.hexlify(contact.id) for contact in stored_to.keys() if stored_to[contact]] contacted_node_ids = [binascii.hexlify(contact.id) for contact in stored_to.keys() if stored_to[contact]]
log.debug("Stored %s to %i of %i attempted peers", binascii.hexlify(blob_hash), log.debug("Stored %s to %i of %i attempted peers", binascii.hexlify(blob_hash),
len(list(contacted_node_ids)), len(contacts)) len(contacted_node_ids), len(contacts))
defer.returnValue(contacted_node_ids) defer.returnValue(contacted_node_ids)
def change_token(self): def change_token(self):

View file

@ -2,10 +2,9 @@
Utilities for turning plain files into LBRY Files. Utilities for turning plain files into LBRY Files.
""" """
import six
import binascii
import logging import logging
import os import os
from binascii import hexlify
from twisted.internet import defer from twisted.internet import defer
from twisted.protocols.basic import FileSender from twisted.protocols.basic import FileSender
@ -38,14 +37,14 @@ class EncryptedFileStreamCreator(CryptStreamCreator):
def _finished(self): def _finished(self):
# calculate the stream hash # calculate the stream hash
self.stream_hash = get_stream_hash( self.stream_hash = get_stream_hash(
hexlify(self.name), hexlify(self.key), hexlify(self.name), hexlify(self.name.encode()), hexlify(self.key), hexlify(self.name.encode()),
self.blob_infos self.blob_infos
) )
# generate the sd info # generate the sd info
self.sd_info = format_sd_info( self.sd_info = format_sd_info(
EncryptedFileStreamType, hexlify(self.name), hexlify(self.key), EncryptedFileStreamType, hexlify(self.name.encode()), hexlify(self.key),
hexlify(self.name), self.stream_hash.encode(), self.blob_infos hexlify(self.name.encode()), self.stream_hash.encode(), self.blob_infos
) )
# sanity check # sanity check
@ -126,15 +125,7 @@ def create_lbry_file(blob_manager, storage, payment_rate_manager, lbry_file_mana
) )
log.debug("adding to the file manager") log.debug("adding to the file manager")
lbry_file = yield lbry_file_manager.add_published_file( lbry_file = yield lbry_file_manager.add_published_file(
sd_info['stream_hash'], sd_hash, binascii.hexlify(file_directory.encode()), payment_rate_manager, sd_info['stream_hash'], sd_hash, hexlify(file_directory.encode()), payment_rate_manager,
payment_rate_manager.min_blob_data_payment_rate payment_rate_manager.min_blob_data_payment_rate
) )
defer.returnValue(lbry_file) defer.returnValue(lbry_file)
def hexlify(str_or_unicode):
if isinstance(str_or_unicode, six.text_type):
strng = str_or_unicode.encode('utf-8')
else:
strng = str_or_unicode
return binascii.hexlify(strng)

View file

@ -1,149 +0,0 @@
import glob
import json
import os
import subprocess
import sys
import github
import uritemplate
import boto3
def main():
#upload_to_github_if_tagged('lbryio/lbry')
upload_to_s3('daemon')
def get_asset_filename():
root_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
return glob.glob(os.path.join(root_dir, 'dist/*'))[0]
def get_cli_output(command):
return subprocess.check_output(command.split()).decode().strip()
def upload_to_s3(folder):
asset_path = get_asset_filename()
branch = get_cli_output('git rev-parse --abbrev-ref HEAD')
if branch = 'master':
tag = get_cli_output('git describe --always --abbrev=8 HEAD')
commit = get_cli_output('git show -s --format=%cd --date=format:%Y%m%d-%H%I%S HEAD')
bucket = 'releases.lbry.io'
key = '{}/{}-{}/{}'.format(folder, commit, tag, os.path.basename(asset_path))
else:
key = '{}/{}-{}/{}'.format(folder, commit_date, tag, os.path.basename(asset_path))
print("Uploading {} to s3://{}/{}".format(asset_path, bucket, key))
if 'AWS_ACCESS_KEY_ID' not in os.environ or 'AWS_SECRET_ACCESS_KEY' not in os.environ:
print('Must set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to publish assets to s3')
return 1
s3 = boto3.resource(
's3',
aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
config=boto3.session.Config(signature_version='s3v4')
)
s3.meta.client.upload_file(asset_path, bucket, key)
def upload_to_github_if_tagged(repo_name):
try:
current_tag = subprocess.check_output(
['git', 'describe', '--exact-match', 'HEAD']).strip()
except subprocess.CalledProcessError:
print('Not uploading to GitHub as we are not currently on a tag')
return 1
print("Current tag: " + current_tag)
if 'GH_TOKEN' not in os.environ:
print('Must set GH_TOKEN in order to publish assets to a release')
return 1
gh_token = os.environ['GH_TOKEN']
auth = github.Github(gh_token)
repo = auth.get_repo(repo_name)
if not check_repo_has_tag(repo, current_tag):
print('Tag {} is not in repo {}'.format(current_tag, repo))
# TODO: maybe this should be an error
return 1
asset_path = get_asset_filename()
print("Uploading " + asset_path + " to Github tag " + current_tag)
release = get_github_release(repo, current_tag)
upload_asset_to_github(release, asset_path, gh_token)
def check_repo_has_tag(repo, target_tag):
tags = repo.get_tags().get_page(0)
for tag in tags:
if tag.name == target_tag:
return True
return False
def get_github_release(repo, current_tag):
for release in repo.get_releases():
if release.tag_name == current_tag:
return release
raise Exception('No release for {} was found'.format(current_tag))
def upload_asset_to_github(release, asset_to_upload, token):
basename = os.path.basename(asset_to_upload)
for asset in release.raw_data['assets']:
if asset['name'] == basename:
print('File {} has already been uploaded to {}'.format(basename, release.tag_name))
return
upload_uri = uritemplate.expand(release.upload_url, {'name': basename})
count = 0
while count < 10:
try:
output = _curl_uploader(upload_uri, asset_to_upload, token)
if 'errors' in output:
raise Exception(output)
else:
print('Successfully uploaded to {}'.format(output['browser_download_url']))
except Exception:
print('Failed uploading on attempt {}'.format(count + 1))
count += 1
def _curl_uploader(upload_uri, asset_to_upload, token):
# using requests.post fails miserably with SSL EPIPE errors. I spent
# half a day trying to debug before deciding to switch to curl.
#
# TODO: actually set the content type
print('Using curl to upload {} to {}'.format(asset_to_upload, upload_uri))
cmd = [
'curl',
'-sS',
'-X', 'POST',
'-u', ':{}'.format(os.environ['GH_TOKEN']),
'--header', 'Content-Type: application/octet-stream',
'--data-binary', '@-',
upload_uri
]
# '-d', '{"some_key": "some_value"}',
print('Calling curl:')
print(cmd)
print('')
with open(asset_to_upload, 'rb') as fp:
p = subprocess.Popen(cmd, stdin=fp, stderr=subprocess.PIPE, stdout=subprocess.PIPE)
stdout, stderr = p.communicate()
print('curl return code: {}'.format(p.returncode))
if stderr:
print('stderr output from curl:')
print(stderr)
print('stdout from curl:')
print(stdout)
return json.loads(stdout)
if __name__ == '__main__':
sys.exit(main())

View file

@ -1,3 +1,4 @@
from binascii import hexlify
from twisted.internet import task from twisted.internet import task
from twisted.trial import unittest from twisted.trial import unittest
from lbrynet.core.utils import generate_id from lbrynet.core.utils import generate_id
@ -10,50 +11,62 @@ class ContactOperatorsTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.contact_manager = ContactManager() self.contact_manager = ContactManager()
self.node_ids = [generate_id(), generate_id(), generate_id()] self.node_ids = [generate_id(), generate_id(), generate_id()]
self.firstContact = self.contact_manager.make_contact(self.node_ids[1], '127.0.0.1', 1000, None, 1) make_contact = self.contact_manager.make_contact
self.secondContact = self.contact_manager.make_contact(self.node_ids[0], '192.168.0.1', 1000, None, 32) self.first_contact = make_contact(self.node_ids[1], '127.0.0.1', 1000, None, 1)
self.secondContactCopy = self.contact_manager.make_contact(self.node_ids[0], '192.168.0.1', 1000, None, 32) self.second_contact = make_contact(self.node_ids[0], '192.168.0.1', 1000, None, 32)
self.firstContactDifferentValues = self.contact_manager.make_contact(self.node_ids[1], '192.168.1.20', self.second_contact_copy = make_contact(self.node_ids[0], '192.168.0.1', 1000, None, 32)
1000, None, 50) self.first_contact_different_values = make_contact(self.node_ids[1], '192.168.1.20', 1000, None, 50)
self.assertRaises(ValueError, self.contact_manager.make_contact, self.node_ids[1], '192.168.1.20',
100000, None)
self.assertRaises(ValueError, self.contact_manager.make_contact, self.node_ids[1], '192.168.1.20.1',
1000, None)
self.assertRaises(ValueError, self.contact_manager.make_contact, self.node_ids[1], 'this is not an ip',
1000, None)
self.assertRaises(ValueError, self.contact_manager.make_contact, "this is not a node id", '192.168.1.20.1',
1000, None)
def testNoDuplicateContactObjects(self): def test_make_contact_error_cases(self):
self.assertTrue(self.secondContact is self.secondContactCopy) self.assertRaises(
self.assertTrue(self.firstContact is not self.firstContactDifferentValues) ValueError, self.contact_manager.make_contact, self.node_ids[1], '192.168.1.20', 100000, None)
self.assertRaises(
ValueError, self.contact_manager.make_contact, self.node_ids[1], '192.168.1.20.1', 1000, None)
self.assertRaises(
ValueError, self.contact_manager.make_contact, self.node_ids[1], 'this is not an ip', 1000, None)
self.assertRaises(
ValueError, self.contact_manager.make_contact, b'not valid node id', '192.168.1.20.1', 1000, None)
def testBoolean(self): def test_no_duplicate_contact_objects(self):
self.assertTrue(self.second_contact is self.second_contact_copy)
self.assertTrue(self.first_contact is not self.first_contact_different_values)
def test_boolean(self):
""" Test "equals" and "not equals" comparisons """ """ Test "equals" and "not equals" comparisons """
self.assertNotEqual( self.assertNotEqual(
self.firstContact, self.secondContact, self.first_contact, self.second_contact,
'Contacts with different IDs should not be equal.') 'Contacts with different IDs should not be equal.')
self.assertEqual( self.assertEqual(
self.firstContact, self.firstContactDifferentValues, self.first_contact, self.first_contact_different_values,
'Contacts with same IDs should be equal, even if their other values differ.') 'Contacts with same IDs should be equal, even if their other values differ.')
self.assertEqual( self.assertEqual(
self.secondContact, self.secondContactCopy, self.second_contact, self.second_contact_copy,
'Different copies of the same Contact instance should be equal') 'Different copies of the same Contact instance should be equal')
def testIllogicalComparisons(self): def test_illogical_comparisons(self):
""" Test comparisons with non-Contact and non-str types """ """ Test comparisons with non-Contact and non-str types """
msg = '"{}" operator: Contact object should not be equal to {} type' msg = '"{}" operator: Contact object should not be equal to {} type'
for item in (123, [1, 2, 3], {'key': 'value'}): for item in (123, [1, 2, 3], {'key': 'value'}):
self.assertNotEqual( self.assertNotEqual(
self.firstContact, item, self.first_contact, item,
msg.format('eq', type(item).__name__)) msg.format('eq', type(item).__name__))
self.assertTrue( self.assertTrue(
self.firstContact != item, self.first_contact != item,
msg.format('ne', type(item).__name__)) msg.format('ne', type(item).__name__))
def testCompactIP(self): def test_compact_ip(self):
self.assertEqual(self.firstContact.compact_ip(), b'\x7f\x00\x00\x01') self.assertEqual(self.first_contact.compact_ip(), b'\x7f\x00\x00\x01')
self.assertEqual(self.secondContact.compact_ip(), b'\xc0\xa8\x00\x01') self.assertEqual(self.second_contact.compact_ip(), b'\xc0\xa8\x00\x01')
def test_id_log(self):
self.assertEqual(self.first_contact.log_id(False), hexlify(self.node_ids[1]))
self.assertEqual(self.first_contact.log_id(True), hexlify(self.node_ids[1])[:8])
def test_hash(self):
# fails with "TypeError: unhashable type: '_Contact'" if __hash__ is not implemented
self.assertEqual(
len({self.first_contact, self.second_contact, self.second_contact_copy}), 2
)
class TestContactLastReplied(unittest.TestCase): class TestContactLastReplied(unittest.TestCase):

View file

@ -1,8 +1,10 @@
# -*- coding: utf-8 -*- import json
from cryptography.hazmat.primitives.ciphers.algorithms import AES import mock
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet import defer from twisted.internet import defer
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from lbrynet.database.storage import SQLiteStorage
from lbrynet.core.StreamDescriptor import get_sd_info, BlobStreamDescriptorReader from lbrynet.core.StreamDescriptor import get_sd_info, BlobStreamDescriptorReader
from lbrynet.core.StreamDescriptor import StreamDescriptorIdentifier from lbrynet.core.StreamDescriptor import StreamDescriptorIdentifier
from lbrynet.core.BlobManager import DiskBlobManager from lbrynet.core.BlobManager import DiskBlobManager
@ -12,7 +14,7 @@ from lbrynet.core.PaymentRateManager import OnlyFreePaymentsManager
from lbrynet.database.storage import SQLiteStorage from lbrynet.database.storage import SQLiteStorage
from lbrynet.file_manager import EncryptedFileCreator from lbrynet.file_manager import EncryptedFileCreator
from lbrynet.file_manager.EncryptedFileManager import EncryptedFileManager from lbrynet.file_manager.EncryptedFileManager import EncryptedFileManager
from lbrynet.core.StreamDescriptor import bytes2unicode from lbrynet.core.StreamDescriptor import JSONBytesEncoder
from tests import mocks from tests import mocks
from tests.util import mk_db_and_blob_dir, rm_db_and_blob_dir from tests.util import mk_db_and_blob_dir, rm_db_and_blob_dir
@ -30,7 +32,7 @@ MB = 2**20
def iv_generator(): def iv_generator():
while True: while True:
yield '3' * (AES.block_size // 8) yield b'3' * (AES.block_size // 8)
class CreateEncryptedFileTest(unittest.TestCase): class CreateEncryptedFileTest(unittest.TestCase):
@ -87,7 +89,8 @@ class CreateEncryptedFileTest(unittest.TestCase):
# this comes from the database, the blobs returned are sorted # this comes from the database, the blobs returned are sorted
sd_info = yield get_sd_info(self.storage, lbry_file.stream_hash, include_blobs=True) sd_info = yield get_sd_info(self.storage, lbry_file.stream_hash, include_blobs=True)
self.maxDiff = None self.maxDiff = None
self.assertDictEqual(bytes2unicode(sd_info), sd_file_info) unicode_sd_info = json.loads(json.dumps(sd_info, sort_keys=True, cls=JSONBytesEncoder))
self.assertDictEqual(unicode_sd_info, sd_file_info)
self.assertEqual(sd_info['stream_hash'], expected_stream_hash) self.assertEqual(sd_info['stream_hash'], expected_stream_hash)
self.assertEqual(len(sd_info['blobs']), 3) self.assertEqual(len(sd_info['blobs']), 3)
self.assertNotEqual(sd_info['blobs'][0]['length'], 0) self.assertNotEqual(sd_info['blobs'][0]['length'], 0)