Merge remote-tracking branch 'origin/treq_and_cryptography'

This commit is contained in:
Jack Robison 2018-05-11 09:22:28 -04:00
commit b3bf193188
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
17 changed files with 131 additions and 347 deletions

View file

@ -25,6 +25,11 @@ at anytime.
* *
### Changed ### Changed
* check headers file integrity on startup, removing/truncating the file to force re-download when necessary
* support partial headers file download from S3
* changed txrequests for treq
* changed cryptography version to 2.2.2
* removed pycrypto dependency, replacing all calls to cryptography
* several internal dht functions to use inlineCallbacks * several internal dht functions to use inlineCallbacks
* `DHTHashAnnouncer` and `Node` manage functions to use `LoopingCall`s instead of scheduling with `callLater`. * `DHTHashAnnouncer` and `Node` manage functions to use `LoopingCall`s instead of scheduling with `callLater`.
* `store` kademlia rpc method to block on the call finishing and to return storing peer information * `store` kademlia rpc method to block on the call finishing and to return storing peer information

View file

@ -1,8 +1,8 @@
import collections import collections
import logging import logging
import treq
from twisted.internet import defer, task from twisted.internet import defer, task
from requests import auth
from txrequests import Session
from lbrynet import conf from lbrynet import conf
from lbrynet.core import looping_call_manager, utils, system_info from lbrynet.core import looping_call_manager, utils, system_info
@ -216,8 +216,8 @@ class Manager(object):
class Api(object): class Api(object):
def __init__(self, session, url, write_key, enabled): def __init__(self, cookies, url, write_key, enabled):
self.session = session self.cookies = cookies
self.url = url self.url = url
self._write_key = write_key self._write_key = write_key
self._enabled = enabled self._enabled = enabled
@ -232,14 +232,17 @@ class Api(object):
# timeout will have expired. # timeout will have expired.
# #
# by forcing the connection to close, we will disable the keep-alive. # by forcing the connection to close, we will disable the keep-alive.
def update_cookies(response):
self.cookies.update(response.cookies())
return response
assert endpoint[0] == '/' assert endpoint[0] == '/'
headers = {"Connection": "close"} headers = {b"Connection": b"close"}
return self.session.post( d = treq.post(self.url + endpoint, auth=(self._write_key, ''), json=data,
self.url + endpoint, headers=headers, cookies=self.cookies)
json=data, d.addCallback(update_cookies)
auth=auth.HTTPBasicAuth(self._write_key, ''), return d
headers=headers
)
def track(self, event): def track(self, event):
"""Send a single tracking event""" """Send a single tracking event"""
@ -257,11 +260,10 @@ class Api(object):
@classmethod @classmethod
def new_instance(cls, enabled=None): def new_instance(cls, enabled=None):
"""Initialize an instance using values from the configuration""" """Initialize an instance using values from the configuration"""
session = Session()
if enabled is None: if enabled is None:
enabled = conf.settings['share_usage_data'] enabled = conf.settings['share_usage_data']
return cls( return cls(
session, {},
conf.settings['ANALYTICS_ENDPOINT'], conf.settings['ANALYTICS_ENDPOINT'],
utils.deobfuscate(conf.settings['ANALYTICS_TOKEN']), utils.deobfuscate(conf.settings['ANALYTICS_TOKEN']),
enabled, enabled,

View file

@ -236,6 +236,7 @@ FIXED_SETTINGS = {
'SLACK_WEBHOOK': ('nUE0pUZ6Yl9bo29epl5moTSwnl5wo20ip2IlqzywMKZiIQSFZR5' 'SLACK_WEBHOOK': ('nUE0pUZ6Yl9bo29epl5moTSwnl5wo20ip2IlqzywMKZiIQSFZR5'
'AHx4mY0VmF0WQZ1ESEP9kMHZlp1WzJwWOoKN3ImR1M2yUAaMyqGZ='), 'AHx4mY0VmF0WQZ1ESEP9kMHZlp1WzJwWOoKN3ImR1M2yUAaMyqGZ='),
'WALLET_TYPES': [LBRYUM_WALLET, LBRYCRD_WALLET], 'WALLET_TYPES': [LBRYUM_WALLET, LBRYCRD_WALLET],
'HEADERS_FILE_SHA256_CHECKSUM': (366295, 'b0c8197153a33ccbc52fb81a279588b6015b68b7726f73f6a2b81f7e25bfe4b9')
} }
ADJUSTABLE_SETTINGS = { ADJUSTABLE_SETTINGS = {

View file

@ -3,15 +3,14 @@ from collections import defaultdict, deque
import datetime import datetime
import logging import logging
from decimal import Decimal from decimal import Decimal
import treq
from zope.interface import implements from zope.interface import implements
from twisted.internet import threads, reactor, defer, task from twisted.internet import threads, reactor, defer, task
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
from twisted._threads._ithreads import AlreadyQuit
from twisted.internet.error import ConnectionAborted from twisted.internet.error import ConnectionAborted
from txrequests import Session as _TxRequestsSession
from requests import Session as requestsSession
from hashlib import sha256
from lbryum import wallet as lbryum_wallet from lbryum import wallet as lbryum_wallet
from lbryum.network import Network from lbryum.network import Network
from lbryum.simple_config import SimpleConfig from lbryum.simple_config import SimpleConfig
@ -36,29 +35,6 @@ from lbrynet.core.Error import DownloadCanceledError, RequestCanceledError
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class TxRequestsSession(_TxRequestsSession):
# Session from txrequests would throw AlreadyQuit errors, this catches them
def __init__(self, pool=None, minthreads=1, maxthreads=4, **kwargs):
requestsSession.__init__(self, **kwargs) # pylint: disable=non-parent-init-called
self.ownPool = False
if pool is None:
self.ownPool = True
pool = ThreadPool(minthreads=minthreads, maxthreads=maxthreads)
# unclosed ThreadPool leads to reactor hangs at shutdown
# this is a problem in many situation, so better enforce pool stop here
def stop_pool():
try:
pool.stop()
except AlreadyQuit:
pass
reactor.addSystemEventTrigger("after", "shutdown", stop_pool)
self.pool = pool
if self.ownPool:
pool.start()
class ReservedPoints(object): class ReservedPoints(object):
def __init__(self, identifier, amount): def __init__(self, identifier, amount):
self.identifier = identifier self.identifier = identifier
@ -118,25 +94,39 @@ class Wallet(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def fetch_headers_from_s3(self): def fetch_headers_from_s3(self):
with TxRequestsSession() as s: local_header_size = self.local_header_file_size()
r = yield s.get(HEADERS_URL) resume_header = {"Range": "bytes={}-".format(local_header_size)}
raw_headers = r.content response = yield treq.get(HEADERS_URL, headers=resume_header)
if not len(raw_headers) % HEADER_SIZE: # should be divisible by the header size got_406 = response.code == 406 # our file is bigger
s3_height = (len(raw_headers) / HEADER_SIZE) - 1 final_size_after_download = response.length + local_header_size
local_height = self.local_header_file_height() if got_406:
if s3_height > local_height: log.warning("s3 is more out of date than we are")
with open(os.path.join(self.config.path, "blockchain_headers"), "wb") as headers_file: # should have something to download and a final length divisible by the header size
headers_file.write(raw_headers) elif final_size_after_download and not final_size_after_download % HEADER_SIZE:
log.info("fetched headers from s3 (s3 height: %i)", s3_height) s3_height = (final_size_after_download / HEADER_SIZE) - 1
local_height = self.local_header_file_height()
if s3_height > local_height:
if local_header_size:
log.info("Resuming download of %i bytes from s3", response.length)
with open(os.path.join(self.config.path, "blockchain_headers"), "a+b") as headers_file:
yield treq.collect(response, headers_file.write)
else: else:
log.warning("s3 is more out of date than we are") with open(os.path.join(self.config.path, "blockchain_headers"), "wb") as headers_file:
yield treq.collect(response, headers_file.write)
log.info("fetched headers from s3 (s3 height: %i), now verifying integrity after download.", s3_height)
self._check_header_file_integrity()
else: else:
log.error("invalid size for headers from s3") log.warning("s3 is more out of date than we are")
else:
log.error("invalid size for headers from s3")
def local_header_file_height(self): def local_header_file_height(self):
return max((self.local_header_file_size() / HEADER_SIZE) - 1, 0)
def local_header_file_size(self):
headers_path = os.path.join(self.config.path, "blockchain_headers") headers_path = os.path.join(self.config.path, "blockchain_headers")
if os.path.isfile(headers_path): if os.path.isfile(headers_path):
return max((os.stat(headers_path).st_size / 112) - 1, 0) return os.stat(headers_path).st_size
return 0 return 0
@defer.inlineCallbacks @defer.inlineCallbacks
@ -154,6 +144,7 @@ class Wallet(object):
from lbrynet import conf from lbrynet import conf
if conf.settings['blockchain_name'] != "lbrycrd_main": if conf.settings['blockchain_name'] != "lbrycrd_main":
defer.returnValue(False) defer.returnValue(False)
self._check_header_file_integrity()
s3_headers_depth = conf.settings['s3_headers_depth'] s3_headers_depth = conf.settings['s3_headers_depth']
if not s3_headers_depth: if not s3_headers_depth:
defer.returnValue(False) defer.returnValue(False)
@ -163,12 +154,36 @@ class Wallet(object):
try: try:
remote_height = yield self.get_remote_height(server_url, port) remote_height = yield self.get_remote_height(server_url, port)
log.info("%s:%i height: %i, local height: %s", server_url, port, remote_height, local_height) log.info("%s:%i height: %i, local height: %s", server_url, port, remote_height, local_height)
if remote_height > local_height + s3_headers_depth: if remote_height > (local_height + s3_headers_depth):
defer.returnValue(True) defer.returnValue(True)
except Exception as err: except Exception as err:
log.warning("error requesting remote height from %s:%i - %s", server_url, port, err) log.warning("error requesting remote height from %s:%i - %s", server_url, port, err)
defer.returnValue(False) defer.returnValue(False)
def _check_header_file_integrity(self):
# TODO: temporary workaround for usability. move to txlbryum and check headers instead of file integrity
from lbrynet import conf
if conf.settings['blockchain_name'] != "lbrycrd_main":
return
hashsum = sha256()
checksum_height, checksum = conf.settings['HEADERS_FILE_SHA256_CHECKSUM']
checksum_length_in_bytes = checksum_height * HEADER_SIZE
if self.local_header_file_size() < checksum_length_in_bytes:
return
headers_path = os.path.join(self.config.path, "blockchain_headers")
with open(headers_path, "rb") as headers_file:
hashsum.update(headers_file.read(checksum_length_in_bytes))
current_checksum = hashsum.hexdigest()
if current_checksum != checksum:
msg = "Expected checksum {}, got {}".format(checksum, current_checksum)
log.warning("Wallet file corrupted, checksum mismatch. " + msg)
log.warning("Deleting header file so it can be downloaded again.")
os.unlink(headers_path)
elif (self.local_header_file_size() % HEADER_SIZE) != 0:
log.warning("Header file is good up to checkpoint height, but incomplete. Truncating to checkpoint.")
with open(headers_path, "rb+") as headers_file:
headers_file.truncate(checksum_length_in_bytes)
@defer.inlineCallbacks @defer.inlineCallbacks
def start(self): def start(self):
should_download_headers = yield self.should_download_headers_from_s3() should_download_headers = yield self.should_download_headers_from_s3()

View file

@ -6,8 +6,7 @@ import os
import sys import sys
import traceback import traceback
from txrequests import Session import treq
from requests.exceptions import ConnectionError
from twisted.internet import defer from twisted.internet import defer
import twisted.python.log import twisted.python.log
@ -35,13 +34,13 @@ TRACE = 5
class HTTPSHandler(logging.Handler): class HTTPSHandler(logging.Handler):
def __init__(self, url, fqdn=False, localname=None, facility=None, session=None): def __init__(self, url, fqdn=False, localname=None, facility=None, cookies=None):
logging.Handler.__init__(self) logging.Handler.__init__(self)
self.url = url self.url = url
self.fqdn = fqdn self.fqdn = fqdn
self.localname = localname self.localname = localname
self.facility = facility self.facility = facility
self.session = session if session is not None else Session() self.cookies = cookies or {}
def get_full_message(self, record): def get_full_message(self, record):
if record.exc_info: if record.exc_info:
@ -52,10 +51,8 @@ class HTTPSHandler(logging.Handler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _emit(self, record): def _emit(self, record):
payload = self.format(record) payload = self.format(record)
try: response = yield treq.post(self.url, data=payload, cookies=self.cookies)
yield self.session.post(self.url, data=payload) self.cookies.update(response.cookies())
except ConnectionError:
pass
def emit(self, record): def emit(self, record):
return self._emit(record) return self._emit(record)

View file

@ -37,7 +37,7 @@ def get_platform(get_ip=True):
"build": build_type.BUILD, # CI server sets this during build step "build": build_type.BUILD, # CI server sets this during build step
} }
# TODO: remove this from get_platform and add a get_external_ip function using txrequests # TODO: remove this from get_platform and add a get_external_ip function using treq
if get_ip: if get_ip:
try: try:
response = json.loads(urlopen("https://api.lbry.io/ip").read()) response = json.loads(urlopen("https://api.lbry.io/ip").read())

View file

@ -1,12 +1,12 @@
""" """
Utility for creating Crypt Streams, which are encrypted blobs and associated metadata. Utility for creating Crypt Streams, which are encrypted blobs and associated metadata.
""" """
import os
import logging import logging
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from twisted.internet import interfaces, defer from twisted.internet import interfaces, defer
from zope.interface import implements from zope.interface import implements
from Crypto import Random
from Crypto.Cipher import AES
from lbrynet.cryptstream.CryptBlob import CryptStreamBlobMaker from lbrynet.cryptstream.CryptBlob import CryptStreamBlobMaker
@ -101,13 +101,13 @@ class CryptStreamCreator(object):
@staticmethod @staticmethod
def random_iv_generator(): def random_iv_generator():
while 1: while 1:
yield Random.new().read(AES.block_size) yield os.urandom(AES.block_size / 8)
def setup(self): def setup(self):
"""Create the symmetric key if it wasn't provided""" """Create the symmetric key if it wasn't provided"""
if self.key is None: if self.key is None:
self.key = Random.new().read(AES.block_size) self.key = os.urandom(AES.block_size / 8)
return defer.succeed(True) return defer.succeed(True)

View file

@ -1,8 +1,9 @@
import time import time
import requests
import logging import logging
import json import json
from twisted.internet import defer, threads
import treq
from twisted.internet import defer
from twisted.internet.task import LoopingCall from twisted.internet.task import LoopingCall
from lbrynet.core.Error import InvalidExchangeRateResponse from lbrynet.core.Error import InvalidExchangeRateResponse
@ -52,9 +53,10 @@ class MarketFeed(object):
def is_online(self): def is_online(self):
return self._online return self._online
@defer.inlineCallbacks
def _make_request(self): def _make_request(self):
r = requests.get(self.url, self.params, timeout=self.REQUESTS_TIMEOUT) response = yield treq.get(self.url, params=self.params, timeout=self.REQUESTS_TIMEOUT)
return r.text defer.returnValue((yield response.content()))
def _handle_response(self, response): def _handle_response(self, response):
return NotImplementedError return NotImplementedError
@ -75,7 +77,7 @@ class MarketFeed(object):
self._online = False self._online = False
def _update_price(self): def _update_price(self):
d = threads.deferToThread(self._make_request) d = self._make_request()
d.addCallback(self._handle_response) d.addCallback(self._handle_response)
d.addCallback(self._subtract_fee) d.addCallback(self._subtract_fee)
d.addCallback(self._save_price) d.addCallback(self._save_price)

View file

@ -1,10 +0,0 @@
"""
A client library for sending and receiving payments on the point trader network.
The point trader network is a simple payment system used solely for testing lbrynet-console. A user
creates a public key, registers it with the point trader server, and receives free points for
registering. The public key is used to spend points, and also used as an address to which points
are sent. To spend points, the public key signs a message containing the amount and the destination
public key and sends it to the point trader server. To check for payments, the recipient sends a
signed message asking the point trader server for its balance.
"""

View file

@ -1,230 +0,0 @@
from lbrynet import conf
from twisted.web.client import Agent, FileBodyProducer, Headers, ResponseDone
from twisted.internet import threads, defer, protocol
from Crypto.Hash import SHA
from Crypto.PublicKey import RSA
from Crypto.Signature import PKCS1_PSS
from StringIO import StringIO
import time
import json
import binascii
class BeginningPrinter(protocol.Protocol):
def __init__(self, finished):
self.finished = finished
self.data = ""
def dataReceived(self, bytes):
self.data = self.data + bytes
def connectionLost(self, reason):
if reason.check(ResponseDone) is not None:
self.finished.callback(str(self.data))
else:
self.finished.errback(reason)
def read_body(response):
d = defer.Deferred()
response.deliverBody(BeginningPrinter(d))
return d
def get_body(response):
if response.code != 200:
print "\n\n\n\nbad error code\n\n\n\n"
raise ValueError(response.phrase)
else:
return read_body(response)
def get_body_from_request(path, data):
from twisted.internet import reactor
jsondata = FileBodyProducer(StringIO(json.dumps(data)))
agent = Agent(reactor)
d = agent.request(
'POST', conf.settings['pointtrader_server'] + path,
Headers({'Content-Type': ['application/json']}), jsondata)
d.addCallback(get_body)
return d
def print_response(response):
pass
def print_error(err):
print err.getTraceback()
return err
def register_new_account(private_key):
data = {}
data['pub_key'] = private_key.publickey().exportKey()
def get_success_from_body(body):
r = json.loads(body)
if not 'success' in r or r['success'] is False:
return False
return True
d = get_body_from_request('/register/', data)
d.addCallback(get_success_from_body)
return d
def send_points(private_key, recipient_public_key, amount):
encoded_public_key = private_key.publickey().exportKey()
timestamp = time.time()
h = SHA.new()
h.update(encoded_public_key)
h.update(recipient_public_key)
h.update(str(amount))
h.update(str(timestamp))
signer = PKCS1_PSS.new(private_key)
signature = binascii.hexlify(signer.sign(h))
data = {}
data['sender_pub_key'] = encoded_public_key
data['recipient_pub_key'] = recipient_public_key
data['amount'] = amount
data['timestamp'] = timestamp
data['signature'] = signature
def get_success_from_body(body):
r = json.loads(body)
if not 'success' in r or r['success'] is False:
return False
return True
d = get_body_from_request('/send-points/', data)
d.addCallback(get_success_from_body)
return d
def get_recent_transactions(private_key):
encoded_public_key = private_key.publickey().exportKey()
timestamp = time.time()
h = SHA.new()
h.update(encoded_public_key)
h.update(str(timestamp))
signer = PKCS1_PSS.new(private_key)
signature = binascii.hexlify(signer.sign(h))
data = {}
data['pub_key'] = encoded_public_key
data['timestamp'] = timestamp
data['signature'] = signature
data['end_time'] = 0
data['start_time'] = 120
def get_transactions_from_body(body):
r = json.loads(body)
if "transactions" not in r:
raise ValueError("Invalid response: no 'transactions' field")
else:
return r['transactions']
d = get_body_from_request('/get-transactions/', data)
d.addCallback(get_transactions_from_body)
return d
def get_balance(private_key):
encoded_public_key = private_key.publickey().exportKey()
timestamp = time.time()
h = SHA.new()
h.update(encoded_public_key)
h.update(str(timestamp))
signer = PKCS1_PSS.new(private_key)
signature = binascii.hexlify(signer.sign(h))
data = {}
data['pub_key'] = encoded_public_key
data['timestamp'] = timestamp
data['signature'] = signature
def get_balance_from_body(body):
r = json.loads(body)
if not 'balance' in r:
raise ValueError("Invalid response: no 'balance' field")
else:
return float(r['balance'])
d = get_body_from_request('/get-balance/', data)
d.addCallback(get_balance_from_body)
return d
def run_full_test():
keys = []
def save_key(private_key):
keys.append(private_key)
return private_key
def check_balances_and_transactions(unused, bal1, bal2, num_transactions):
def assert_balance_is(actual, expected):
assert abs(actual - expected) < .05
print "correct balance. actual:", str(actual), "expected:", str(expected)
return True
def assert_transaction_length_is(transactions, expected_length):
assert len(transactions) == expected_length
print "correct transaction length"
return True
d1 = get_balance(keys[0])
d1.addCallback(assert_balance_is, bal1)
d2 = get_balance(keys[1])
d2.addCallback(assert_balance_is, bal2)
d3 = get_recent_transactions(keys[0])
d3.addCallback(assert_transaction_length_is, num_transactions)
d4 = get_recent_transactions(keys[1])
d4.addCallback(assert_transaction_length_is, num_transactions)
dl = defer.DeferredList([d1, d2, d3, d4])
return dl
def do_transfer(unused, amount):
d = send_points(keys[0], keys[1].publickey().exportKey(), amount)
return d
d1 = threads.deferToThread(RSA.generate, 4096)
d1.addCallback(save_key)
d1.addCallback(register_new_account)
d2 = threads.deferToThread(RSA.generate, 4096)
d2.addCallback(save_key)
d2.addCallback(register_new_account)
dlist = defer.DeferredList([d1, d2])
dlist.addCallback(check_balances_and_transactions, 1000, 1000, 0)
dlist.addCallback(do_transfer, 50)
dlist.addCallback(check_balances_and_transactions, 950, 1050, 1)
dlist.addCallback(do_transfer, 75)
dlist.addCallback(check_balances_and_transactions, 875, 1125, 2)
dlist.addErrback(print_error)
if __name__ == "__main__":
from twisted.internet import reactor
reactor.callLater(1, run_full_test)
reactor.callLater(25, reactor.stop)
reactor.run()

View file

@ -4,11 +4,9 @@ import os
import platform import platform
import shutil import shutil
import sys import sys
import random
import unittest import unittest
from Crypto import Random from hashlib import md5
from Crypto.Hash import MD5
from lbrynet import conf from lbrynet import conf
from lbrynet.file_manager.EncryptedFileManager import EncryptedFileManager from lbrynet.file_manager.EncryptedFileManager import EncryptedFileManager
from lbrynet.core.Session import Session from lbrynet.core.Session import Session
@ -98,9 +96,6 @@ class LbryUploader(object):
from twisted.internet import reactor from twisted.internet import reactor
self.reactor = reactor self.reactor = reactor
logging.debug("Starting the uploader") logging.debug("Starting the uploader")
Random.atfork()
r = random.Random()
r.seed("start_lbry_uploader")
wallet = FakeWallet() wallet = FakeWallet()
peer_manager = PeerManager() peer_manager = PeerManager()
peer_finder = FakePeerFinder(5553, peer_manager, 1) peer_finder = FakePeerFinder(5553, peer_manager, 1)
@ -191,10 +186,6 @@ def start_lbry_reuploader(sd_hash, kill_event, dead_event,
logging.debug("Starting the uploader") logging.debug("Starting the uploader")
Random.atfork()
r = random.Random()
r.seed("start_lbry_reuploader")
wallet = FakeWallet() wallet = FakeWallet()
peer_port = 5553 + n peer_port = 5553 + n
@ -297,7 +288,6 @@ def start_blob_uploader(blob_hash_queue, kill_event, dead_event, slow, is_genero
logging.debug("Starting the uploader") logging.debug("Starting the uploader")
Random.atfork()
wallet = FakeWallet() wallet = FakeWallet()
peer_manager = PeerManager() peer_manager = PeerManager()
@ -515,7 +505,7 @@ class TestTransfer(TestCase):
def check_md5_sum(): def check_md5_sum():
f = open(os.path.join(db_dir, 'test_file')) f = open(os.path.join(db_dir, 'test_file'))
hashsum = MD5.new() hashsum = md5()
hashsum.update(f.read()) hashsum.update(f.read())
self.assertEqual(hashsum.hexdigest(), "4ca2aafb4101c1e42235aad24fbb83be") self.assertEqual(hashsum.hexdigest(), "4ca2aafb4101c1e42235aad24fbb83be")
@ -688,7 +678,7 @@ class TestTransfer(TestCase):
def check_md5_sum(): def check_md5_sum():
f = open(os.path.join(db_dir, 'test_file')) f = open(os.path.join(db_dir, 'test_file'))
hashsum = MD5.new() hashsum = md5()
hashsum.update(f.read()) hashsum.update(f.read())
self.assertEqual(hashsum.hexdigest(), "4ca2aafb4101c1e42235aad24fbb83be") self.assertEqual(hashsum.hexdigest(), "4ca2aafb4101c1e42235aad24fbb83be")
@ -811,7 +801,7 @@ class TestTransfer(TestCase):
def check_md5_sum(): def check_md5_sum():
f = open('test_file') f = open('test_file')
hashsum = MD5.new() hashsum = md5()
hashsum.update(f.read()) hashsum.update(f.read())
self.assertEqual(hashsum.hexdigest(), "e5941d615f53312fd66638239c1f90d5") self.assertEqual(hashsum.hexdigest(), "e5941d615f53312fd66638239c1f90d5")

View file

@ -2,7 +2,7 @@ import os
import shutil import shutil
import tempfile import tempfile
from Crypto.Hash import MD5 from hashlib import md5
from twisted.trial.unittest import TestCase from twisted.trial.unittest import TestCase
from twisted.internet import defer, threads from twisted.internet import defer, threads
@ -127,7 +127,7 @@ class TestStreamify(TestCase):
self.assertTrue(lbry_file.sd_hash, sd_hash) self.assertTrue(lbry_file.sd_hash, sd_hash)
yield lbry_file.start() yield lbry_file.start()
f = open('test_file') f = open('test_file')
hashsum = MD5.new() hashsum = md5()
hashsum.update(f.read()) hashsum.update(f.read())
self.assertEqual(hashsum.hexdigest(), "68959747edc73df45e45db6379dd7b3b") self.assertEqual(hashsum.hexdigest(), "68959747edc73df45e45db6379dd7b3b")

View file

@ -1,7 +1,10 @@
import base64
import struct import struct
import io import io
from Crypto.PublicKey import RSA from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
from twisted.internet import defer, error from twisted.internet import defer, error
from twisted.python.failure import Failure from twisted.python.failure import Failure
@ -15,6 +18,12 @@ from lbrynet import conf
from util import debug_kademlia_packet from util import debug_kademlia_packet
KB = 2**10 KB = 2**10
PUBLIC_EXPONENT = 65537 # http://www.daemonology.net/blog/2009-06-11-cryptographic-right-answers.html
def decode_rsa_key(pem_key):
decoded = base64.b64decode(''.join(pem_key.splitlines()[1:-1]))
return serialization.load_der_public_key(decoded, default_backend())
class FakeLBRYFile(object): class FakeLBRYFile(object):
@ -137,9 +146,10 @@ class PointTraderKeyQueryHandler(object):
if self.query_identifiers[0] in queries: if self.query_identifiers[0] in queries:
new_encoded_pub_key = queries[self.query_identifiers[0]] new_encoded_pub_key = queries[self.query_identifiers[0]]
try: try:
RSA.importKey(new_encoded_pub_key) decode_rsa_key(new_encoded_pub_key)
except (ValueError, TypeError, IndexError): except (ValueError, TypeError, IndexError):
return defer.fail(Failure(ValueError("Client sent an invalid public key"))) value_error = ValueError("Client sent an invalid public key: {}".format(new_encoded_pub_key))
return defer.fail(Failure(value_error))
self.public_key = new_encoded_pub_key self.public_key = new_encoded_pub_key
self.wallet.set_public_key_for_peer(self.peer, self.public_key) self.wallet.set_public_key_for_peer(self.peer, self.public_key)
fields = {'public_key': self.wallet.encoded_public_key} fields = {'public_key': self.wallet.encoded_public_key}
@ -152,8 +162,10 @@ class PointTraderKeyQueryHandler(object):
class Wallet(object): class Wallet(object):
def __init__(self): def __init__(self):
self.private_key = RSA.generate(1024) self.private_key = rsa.generate_private_key(public_exponent=PUBLIC_EXPONENT,
self.encoded_public_key = self.private_key.publickey().exportKey() key_size=1024, backend=default_backend())
self.encoded_public_key = self.private_key.public_key().public_bytes(serialization.Encoding.PEM,
serialization.PublicFormat.PKCS1)
self._config = None self._config = None
self.network = None self.network = None
self.wallet = None self.wallet = None

View file

@ -5,11 +5,13 @@ from lbrynet.blob.blob_file import MAX_BLOB_SIZE
from lbrynet.tests.mocks import mock_conf_settings from lbrynet.tests.mocks import mock_conf_settings
from Crypto import Random from cryptography.hazmat.primitives.ciphers.algorithms import AES
from Crypto.Cipher import AES
import random import random
import string import string
import StringIO import StringIO
import os
AES_BLOCK_SIZE_BYTES = AES.block_size / 8
class MocBlob(object): class MocBlob(object):
def __init__(self): def __init__(self):
@ -44,8 +46,8 @@ class TestCryptBlob(unittest.TestCase):
# max blob size is 2*2**20 -1 ( -1 due to required padding in the end ) # max blob size is 2*2**20 -1 ( -1 due to required padding in the end )
blob = MocBlob() blob = MocBlob()
blob_num = 0 blob_num = 0
key = Random.new().read(AES.block_size) key = os.urandom(AES_BLOCK_SIZE_BYTES)
iv = Random.new().read(AES.block_size) iv = os.urandom(AES_BLOCK_SIZE_BYTES)
maker = CryptBlob.CryptStreamBlobMaker(key, iv, blob_num, blob) maker = CryptBlob.CryptStreamBlobMaker(key, iv, blob_num, blob)
write_size = size_of_data write_size = size_of_data
string_to_encrypt = random_string(size_of_data) string_to_encrypt = random_string(size_of_data)
@ -54,7 +56,7 @@ class TestCryptBlob(unittest.TestCase):
done, num_bytes = maker.write(string_to_encrypt) done, num_bytes = maker.write(string_to_encrypt)
yield maker.close() yield maker.close()
self.assertEqual(size_of_data, num_bytes) self.assertEqual(size_of_data, num_bytes)
expected_encrypted_blob_size = ((size_of_data / AES.block_size) + 1) * AES.block_size expected_encrypted_blob_size = ((size_of_data / AES_BLOCK_SIZE_BYTES) + 1) * AES_BLOCK_SIZE_BYTES
self.assertEqual(expected_encrypted_blob_size, len(blob.data)) self.assertEqual(expected_encrypted_blob_size, len(blob.data))
if size_of_data < MAX_BLOB_SIZE-1: if size_of_data < MAX_BLOB_SIZE-1:

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from Crypto.Cipher import AES from cryptography.hazmat.primitives.ciphers.algorithms import AES
import mock import mock
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet import defer from twisted.internet import defer
@ -18,7 +18,7 @@ MB = 2**20
def iv_generator(): def iv_generator():
while True: while True:
yield '3' * AES.block_size yield '3' * (AES.block_size / 8)
class CreateEncryptedFileTest(unittest.TestCase): class CreateEncryptedFileTest(unittest.TestCase):
@ -47,7 +47,7 @@ class CreateEncryptedFileTest(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def create_file(self, filename): def create_file(self, filename):
handle = mocks.GenFile(3*MB, '1') handle = mocks.GenFile(3*MB, '1')
key = '2'*AES.block_size key = '2' * (AES.block_size / 8)
out = yield EncryptedFileCreator.create_lbry_file(self.session, self.file_manager, filename, handle, out = yield EncryptedFileCreator.create_lbry_file(self.session, self.file_manager, filename, handle,
key, iv_generator()) key, iv_generator())
defer.returnValue(out) defer.returnValue(out)

View file

@ -1,5 +1,5 @@
Twisted==16.6.0 Twisted==16.6.0
cryptography==2.0.3 cryptography==2.2.2
appdirs==1.4.3 appdirs==1.4.3
argparse==1.2.1 argparse==1.2.1
docopt==0.6.2 docopt==0.6.2
@ -16,15 +16,14 @@ git+https://github.com/lbryio/lbryschema.git@v0.0.15#egg=lbryschema
git+https://github.com/lbryio/lbryum.git@v3.2.1#egg=lbryum git+https://github.com/lbryio/lbryum.git@v3.2.1#egg=lbryum
miniupnpc==1.9 miniupnpc==1.9
pbkdf2==1.3 pbkdf2==1.3
pycrypto==2.6.1
pyyaml==3.12 pyyaml==3.12
PyGithub==1.34 PyGithub==1.34
qrcode==5.2.2 qrcode==5.2.2
requests==2.9.1 requests==2.9.1
txrequests==0.9.5
service_identity==16.0.0 service_identity==16.0.0
six>=1.9.0 six>=1.9.0
slowaes==0.1a1 slowaes==0.1a1
txJSON-RPC==0.5 txJSON-RPC==0.5
wsgiref==0.1.2 wsgiref==0.1.2
zope.interface==4.3.3 zope.interface==4.3.3
treq==17.8.0

View file

@ -23,12 +23,11 @@ requires = [
'lbryschema==0.0.15', 'lbryschema==0.0.15',
'lbryum==3.2.1', 'lbryum==3.2.1',
'miniupnpc', 'miniupnpc',
'pycrypto',
'pyyaml', 'pyyaml',
'requests', 'requests',
'txrequests',
'txJSON-RPC', 'txJSON-RPC',
'zope.interface', 'zope.interface',
'treq',
'docopt' 'docopt'
] ]