From 2040748c62de623b548364cf4932212f051e4b74 Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Tue, 6 Mar 2018 14:30:19 -0500 Subject: [PATCH 1/2] add stratum client protocol --- lbrynet/txlbryum/__init__.py | 0 lbrynet/txlbryum/client.py | 201 +++++++++++++++++++++++++++++++++++ lbrynet/txlbryum/errors.py | 18 ++++ lbrynet/txlbryum/factory.py | 107 +++++++++++++++++++ 4 files changed, 326 insertions(+) create mode 100644 lbrynet/txlbryum/__init__.py create mode 100644 lbrynet/txlbryum/client.py create mode 100644 lbrynet/txlbryum/errors.py create mode 100644 lbrynet/txlbryum/factory.py diff --git a/lbrynet/txlbryum/__init__.py b/lbrynet/txlbryum/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lbrynet/txlbryum/client.py b/lbrynet/txlbryum/client.py new file mode 100644 index 000000000..33bc0c490 --- /dev/null +++ b/lbrynet/txlbryum/client.py @@ -0,0 +1,201 @@ +import json +import logging +import socket + +from twisted.internet import defer, error +from twisted.protocols.basic import LineOnlyReceiver +from errors import RemoteServiceException, ProtocolException, ServiceException + +log = logging.getLogger() + + +class RequestCounter(object): + def __init__(self): + self.on_finish = defer.Deferred() + self.counter = 0 + + def set_count(self, cnt): + self.counter = cnt + + def decrease(self): + self.counter -= 1 + if self.counter <= 0: + self.finish() + + def finish(self): + if not self.on_finish.called: + self.on_finish.callback(True) + + +class StratumClientProtocol(LineOnlyReceiver): + delimiter = '\n' + + def __init__(self): + self._connected = defer.Deferred() + + def _get_id(self): + self.request_id += 1 + return self.request_id + + def _get_ip(self): + return self.transport.getPeer().host + + def get_session(self): + return self.session + + def connectionMade(self): + try: + self.transport.setTcpNoDelay(True) + self.transport.setTcpKeepAlive(True) + self.transport.socket.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE, + 120) # Seconds before sending keepalive probes + self.transport.socket.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL, + 1) # Interval in seconds between keepalive probes + self.transport.socket.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, + 5) # Failed keepalive probles before declaring other end dead + except Exception as err: + # Supported only by the socket transport, + # but there's really no better place in code to trigger this. + log.warning("Error setting up socket: %s", err) + + self.request_id = 0 + self.lookup_table = {} + self.on_finish = None # Will point to defer which is called + # once all client requests are processed + + self._connected.callback(True) + + # Initiate connection session + self.session = {} + + log.debug("Connected %s" % self.transport.getPeer().host) + + def transport_write(self, data): + '''Overwrite this if transport needs some extra care about data written + to the socket, like adding message format in websocket.''' + try: + self.transport.write(data) + except AttributeError: + # Transport is disconnected + log.warning("transport is disconnected") + + def writeJsonRequest(self, method, params, is_notification=False): + request_id = None if is_notification else self._get_id() + serialized = json.dumps({'id': request_id, 'method': method, 'params': params}) + self.transport_write("%s\n" % serialized) + return request_id + + def writeJsonResponse(self, data, message_id): + serialized = json.dumps({'id': message_id, 'result': data, 'error': None}) + self.transport_write("%s\n" % serialized) + + def writeJsonError(self, code, message, traceback, message_id): + serialized = json.dumps( + {'id': message_id, 'result': None, 'error': (code, message, traceback)} + ) + self.transport_write("%s\n" % serialized) + + def writeGeneralError(self, message, code=-1): + log.error(message) + return self.writeJsonError(code, message, None, None) + + def process_response(self, data, message_id, request_counter): + self.writeJsonResponse(data.result, message_id) + request_counter.decrease() + + def process_failure(self, failure, message_id, request_counter): + if not isinstance(failure.value, ServiceException): + # All handled exceptions should inherit from ServiceException class. + # Throwing other exception class means that it is unhandled error + # and we should log it. + log.exception(failure) + code = getattr(failure.value, 'code', -1) + if message_id != None: + tb = failure.getBriefTraceback() + self.writeJsonError(code, failure.getErrorMessage(), tb, message_id) + request_counter.decrease() + + def dataReceived(self, data, request_counter=None): + '''Original code from Twisted, hacked for request_counter proxying. + request_counter is hack for HTTP transport, didn't found cleaner solution how + to indicate end of request processing in asynchronous manner. + + TODO: This would deserve some unit test to be sure that future twisted versions + will work nicely with this.''' + + if request_counter == None: + request_counter = RequestCounter() + + lines = (self._buffer + data).split(self.delimiter) + self._buffer = lines.pop(-1) + request_counter.set_count(len(lines)) + self.on_finish = request_counter.on_finish + + for line in lines: + if self.transport.disconnecting: + request_counter.finish() + return + if len(line) > self.MAX_LENGTH: + request_counter.finish() + return self.lineLengthExceeded(line) + else: + try: + self.lineReceived(line, request_counter) + except Exception as exc: + request_counter.finish() + # log.exception("Processing of message failed") + log.warning("Failed message: %s from %s" % (str(exc), self._get_ip())) + return error.ConnectionLost('Processing of message failed') + + if len(self._buffer) > self.MAX_LENGTH: + request_counter.finish() + return self.lineLengthExceeded(self._buffer) + + def lineReceived(self, line, request_counter): + try: + message = json.loads(line) + except (ValueError, TypeError): + # self.writeGeneralError("Cannot decode message '%s'" % line) + request_counter.finish() + raise ProtocolException("Cannot decode message '%s'" % line.strip()) + msg_id = message.get('id', 0) + msg_result = message.get('result') + msg_error = message.get('error') + if msg_id: + # It's a RPC response + # Perform lookup to the table of waiting requests. + request_counter.decrease() + try: + meta = self.lookup_table[msg_id] + del self.lookup_table[msg_id] + except KeyError: + # When deferred object for given message ID isn't found, it's an error + raise ProtocolException( + "Lookup for deferred object for message ID '%s' failed." % msg_id) + # If there's an error, handle it as errback + # If both result and error are null, handle it as a success with blank result + if msg_error != None: + meta['defer'].errback( + RemoteServiceException(msg_error[0], msg_error[1], msg_error[2]) + ) + else: + meta['defer'].callback(msg_result) + else: + request_counter.decrease() + raise ProtocolException("Cannot handle message '%s'" % line) + + def rpc(self, method, params, is_notification=False): + ''' + This method performs remote RPC call. + + If method should expect an response, it store + request ID to lookup table and wait for corresponding + response message. + ''' + + request_id = self.writeJsonRequest(method, params, is_notification) + if is_notification: + return + d = defer.Deferred() + self.lookup_table[request_id] = {'defer': d, 'method': method, 'params': params} + return d diff --git a/lbrynet/txlbryum/errors.py b/lbrynet/txlbryum/errors.py new file mode 100644 index 000000000..eaa8723dc --- /dev/null +++ b/lbrynet/txlbryum/errors.py @@ -0,0 +1,18 @@ +class TransportException(Exception): + pass + + +class ServiceException(Exception): + code = -2 + + +class RemoteServiceException(Exception): + pass + + +class ProtocolException(Exception): + pass + + +class MethodNotFoundException(ServiceException): + code = -3 diff --git a/lbrynet/txlbryum/factory.py b/lbrynet/txlbryum/factory.py new file mode 100644 index 000000000..72af607d1 --- /dev/null +++ b/lbrynet/txlbryum/factory.py @@ -0,0 +1,107 @@ +import logging +from twisted.internet import defer +from twisted.internet.protocol import ClientFactory +from client import StratumClientProtocol +from errors import TransportException + +log = logging.getLogger() + + +class StratumClient(ClientFactory): + protocol = StratumClientProtocol + + def __init__(self, connected_d=None): + self.client = None + self.connected_d = connected_d or defer.Deferred() + + def buildProtocol(self, addr): + client = self.protocol() + client.factory = self + self.client = client + self.client._connected.addCallback(lambda _: self.connected_d.callback(self)) + return client + + def _rpc(self, method, params, *args, **kwargs): + if not self.client: + raise TransportException("Not connected") + + return self.client.rpc(method, params, *args, **kwargs) + + def blockchain_claimtrie_getvaluesforuris(self, block_hash, *uris): + return self._rpc('blockchain.claimtrie.getvaluesforuris', + [block_hash] + list(uris)) + + def blockchain_claimtrie_getvaluesforuri(self, block_hash, uri): + return self._rpc('blockchain.claimtrie.getvaluesforuri', [block_hash, uri]) + + def blockchain_claimtrie_getclaimssignedbynthtoname(self, name, n): + return self._rpc('blockchain.claimtrie.getclaimssignedbynthtoname', [name, n]) + + def blockchain_claimtrie_getclaimssignedbyid(self, certificate_id): + return self._rpc('blockchain.claimtrie.getclaimssignedbyid', [certificate_id]) + + def blockchain_claimtrie_getclaimssignedby(self, name): + return self._rpc('blockchain.claimtrie.getclaimssignedby', [name]) + + def blockchain_claimtrie_getnthclaimforname(self, name, n): + return self._rpc('blockchain.claimtrie.getnthclaimforname', [name, n]) + + def blockchain_claimtrie_getclaimsbyids(self, *claim_ids): + return self._rpc('blockchain.claimtrie.getclaimsbyids', list(claim_ids)) + + def blockchain_claimtrie_getclaimbyid(self, claim_id): + return self._rpc('blockchain.claimtrie.getclaimbyid', [claim_id]) + + def blockchain_claimtrie_get(self): + return self._rpc('blockchain.claimtrie.get', []) + + def blockchain_block_get_block(self, block_hash): + return self._rpc('blockchain.block.get_block', [block_hash]) + + def blockchain_claimtrie_getclaimsforname(self, name): + return self._rpc('blockchain.claimtrie.getclaimsforname', [name]) + + def blockchain_claimtrie_getclaimsintx(self, txid): + return self._rpc('blockchain.claimtrie.getclaimsintx', [txid]) + + def blockchain_claimtrie_getvalue(self, name, block_hash=None): + return self._rpc('blockchain.claimtrie.getvalue', [name, block_hash]) + + def blockchain_relayfee(self): + return self._rpc('blockchain.relayfee', []) + + def blockchain_estimatefee(self): + return self._rpc('blockchain.estimatefee', []) + + def blockchain_transaction_get(self, txid): + return self._rpc('blockchain.transaction.get', [txid]) + + def blockchain_transaction_get_merkle(self, tx_hash, height, cache_only=False): + return self._rpc('blockchain.transaction.get_merkle', [tx_hash, height, cache_only]) + + def blockchain_transaction_broadcast(self, raw_transaction): + return self._rpc('blockchain.transaction.broadcast', [raw_transaction]) + + def blockchain_block_get_chunk(self, index, cache_only=False): + return self._rpc('blockchain.block.get_chunk', [index, cache_only]) + + def blockchain_block_get_header(self, height, cache_only=False): + return self._rpc('blockchain.block.get_header', [height, cache_only]) + + def blockchain_utxo_get_address(self, txid, pos): + return self._rpc('blockchain.utxo.get_address', [txid, pos]) + + def blockchain_address_listunspent(self, address): + return self._rpc('blockchain.address.listunspent', [address]) + + def blockchain_address_get_proof(self, address): + return self._rpc('blockchain.address.get_proof', [address]) + + def blockchain_address_get_balance(self, address): + return self._rpc('blockchain.address.get_balance', [address]) + + def blockchain_address_get_mempool(self, address): + return self._rpc('blockchain.address.get_mempool', [address]) + + def blockchain_address_get_history(self, address): + return self._rpc('blockchain.address.get_history', [address]) From 633b49da923e6d87bf4703083a2d59bb1a5b8018 Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Wed, 14 Mar 2018 15:01:47 -0400 Subject: [PATCH 2/2] removed RequestCounter --- lbrynet/txlbryum/client.py | 44 +++++--------------------------------- 1 file changed, 5 insertions(+), 39 deletions(-) diff --git a/lbrynet/txlbryum/client.py b/lbrynet/txlbryum/client.py index 33bc0c490..96a6a08e1 100644 --- a/lbrynet/txlbryum/client.py +++ b/lbrynet/txlbryum/client.py @@ -9,24 +9,6 @@ from errors import RemoteServiceException, ProtocolException, ServiceException log = logging.getLogger() -class RequestCounter(object): - def __init__(self): - self.on_finish = defer.Deferred() - self.counter = 0 - - def set_count(self, cnt): - self.counter = cnt - - def decrease(self): - self.counter -= 1 - if self.counter <= 0: - self.finish() - - def finish(self): - if not self.on_finish.called: - self.on_finish.callback(True) - - class StratumClientProtocol(LineOnlyReceiver): delimiter = '\n' @@ -60,8 +42,6 @@ class StratumClientProtocol(LineOnlyReceiver): self.request_id = 0 self.lookup_table = {} - self.on_finish = None # Will point to defer which is called - # once all client requests are processed self._connected.callback(True) @@ -99,11 +79,10 @@ class StratumClientProtocol(LineOnlyReceiver): log.error(message) return self.writeJsonError(code, message, None, None) - def process_response(self, data, message_id, request_counter): + def process_response(self, data, message_id): self.writeJsonResponse(data.result, message_id) - request_counter.decrease() - def process_failure(self, failure, message_id, request_counter): + def process_failure(self, failure, message_id): if not isinstance(failure.value, ServiceException): # All handled exceptions should inherit from ServiceException class. # Throwing other exception class means that it is unhandled error @@ -113,9 +92,8 @@ class StratumClientProtocol(LineOnlyReceiver): if message_id != None: tb = failure.getBriefTraceback() self.writeJsonError(code, failure.getErrorMessage(), tb, message_id) - request_counter.decrease() - def dataReceived(self, data, request_counter=None): + def dataReceived(self, data): '''Original code from Twisted, hacked for request_counter proxying. request_counter is hack for HTTP transport, didn't found cleaner solution how to indicate end of request processing in asynchronous manner. @@ -123,40 +101,30 @@ class StratumClientProtocol(LineOnlyReceiver): TODO: This would deserve some unit test to be sure that future twisted versions will work nicely with this.''' - if request_counter == None: - request_counter = RequestCounter() - lines = (self._buffer + data).split(self.delimiter) self._buffer = lines.pop(-1) - request_counter.set_count(len(lines)) - self.on_finish = request_counter.on_finish for line in lines: if self.transport.disconnecting: - request_counter.finish() return if len(line) > self.MAX_LENGTH: - request_counter.finish() return self.lineLengthExceeded(line) else: try: - self.lineReceived(line, request_counter) + self.lineReceived(line) except Exception as exc: - request_counter.finish() # log.exception("Processing of message failed") log.warning("Failed message: %s from %s" % (str(exc), self._get_ip())) return error.ConnectionLost('Processing of message failed') if len(self._buffer) > self.MAX_LENGTH: - request_counter.finish() return self.lineLengthExceeded(self._buffer) - def lineReceived(self, line, request_counter): + def lineReceived(self, line): try: message = json.loads(line) except (ValueError, TypeError): # self.writeGeneralError("Cannot decode message '%s'" % line) - request_counter.finish() raise ProtocolException("Cannot decode message '%s'" % line.strip()) msg_id = message.get('id', 0) msg_result = message.get('result') @@ -164,7 +132,6 @@ class StratumClientProtocol(LineOnlyReceiver): if msg_id: # It's a RPC response # Perform lookup to the table of waiting requests. - request_counter.decrease() try: meta = self.lookup_table[msg_id] del self.lookup_table[msg_id] @@ -181,7 +148,6 @@ class StratumClientProtocol(LineOnlyReceiver): else: meta['defer'].callback(msg_result) else: - request_counter.decrease() raise ProtocolException("Cannot handle message '%s'" % line) def rpc(self, method, params, is_notification=False):