diff --git a/lbry/tests/integration/test_transaction_commands.py b/lbry/tests/integration/test_transaction_commands.py index a7f624650..0e77f3496 100644 --- a/lbry/tests/integration/test_transaction_commands.py +++ b/lbry/tests/integration/test_transaction_commands.py @@ -36,3 +36,34 @@ class TransactionCommandsTestCase(CommandTestCase): await self.assertBalance(self.account, '0.0') await self.daemon.jsonrpc_utxo_release() await self.assertBalance(self.account, '11.0') + + +class TestSegwit(CommandTestCase): + + VERBOSITY = 10 + + async def test_segwit(self): + p2sh_address = await self.blockchain.get_new_address(self.blockchain.P2SH_SEGWIT_ADDRESS) + bech32_address = await self.blockchain.get_new_address(self.blockchain.BECH32_ADDRESS) + p2sh_tx1 = await self.blockchain.send_to_address(p2sh_address, '1.0') + p2sh_tx2 = await self.blockchain.send_to_address(p2sh_address, '1.0') + bech32_tx1 = await self.blockchain.send_to_address(bech32_address, '1.0') + bech32_tx2 = await self.blockchain.send_to_address(bech32_address, '1.0') + + await self.generate(1) + + address = (await self.account.receiving.get_addresses(limit=1, only_usable=True))[0] + + tx = await self.blockchain.create_raw_transaction([ + {"txid": p2sh_tx1, "vout": 0}, + {"txid": p2sh_tx2, "vout": 0}, + {"txid": bech32_tx1, "vout": 0}, + {"txid": bech32_tx2, "vout": 0}, + ], [{address: '3.5'}] + ) + + tx = await self.blockchain.sign_raw_transaction_with_wallet(tx) + txid = await self.blockchain.send_raw_transaction(tx) + await self.on_transaction_id(txid) + + await self.assertBalance(self.account, '13.5') diff --git a/torba/torba/client/basetransaction.py b/torba/torba/client/basetransaction.py index 3b4761401..9f9d69a91 100644 --- a/torba/torba/client/basetransaction.py +++ b/torba/torba/client/basetransaction.py @@ -34,7 +34,7 @@ class TXRefMutable(TXRef): @property def hash(self): if self._hash is None: - self._hash = sha256(sha256(self.tx.raw)) + self._hash = sha256(sha256(self.tx.raw_sans_segwit)) return self._hash @property @@ -260,6 +260,9 @@ class BaseTransaction: def __init__(self, raw=None, version: int = 1, locktime: int = 0, is_verified: bool = False, height: int = -2, position: int = -1) -> None: self._raw = raw + self._raw_sans_segwit = None + self.is_segwit_flag = 0 + self.witnesses = [] self.ref = TXRefMutable(self) self.version = version self.locktime = locktime @@ -302,8 +305,17 @@ class BaseTransaction: self._raw = self._serialize() return self._raw + @property + def raw_sans_segwit(self): + if self.is_segwit_flag: + if self._raw_sans_segwit is None: + self._raw_sans_segwit = self._serialize(sans_segwit=True) + return self._raw_sans_segwit + return self.raw + def _reset(self): self._raw = None + self._raw_sans_segwit = None self.ref.reset() @property @@ -390,7 +402,7 @@ class BaseTransaction: """ Sum of output values *plus* the cost involved to spend them. """ return sum(txo.amount + txo.get_fee(ledger) for txo in self._outputs) - def _serialize(self, with_inputs: bool = True) -> bytes: + def _serialize(self, with_inputs: bool = True, sans_segwit: bool = False) -> bytes: stream = BCDataStream() stream.write_uint32(self.version) if with_inputs: @@ -425,9 +437,8 @@ class BaseTransaction: stream = BCDataStream(self._raw) self.version = stream.read_uint32() input_count = stream.read_compact_size() - flag = 0 if input_count == 0: - flag = stream.read_uint8() + self.is_segwit_flag = stream.read_uint8() input_count = stream.read_compact_size() self._add(self._inputs, [ self.input_class.deserialize_from(stream) for _ in range(input_count) @@ -436,12 +447,13 @@ class BaseTransaction: self._add(self._outputs, [ self.output_class.deserialize_from(stream) for _ in range(output_count) ]) - if flag == 1: + if self.is_segwit_flag: # drain witness portion of transaction # too many witnesses for no crime + self.witnesses = [] for _ in range(input_count): for _ in range(stream.read_compact_size()): - stream.read(stream.read_compact_size()) + self.witnesses.append(stream.read(stream.read_compact_size())) self.locktime = stream.read_uint32() @classmethod diff --git a/torba/torba/orchstr8/node.py b/torba/torba/orchstr8/node.py index 421011d9f..bdcb8ddfe 100644 --- a/torba/torba/orchstr8/node.py +++ b/torba/torba/orchstr8/node.py @@ -1,4 +1,5 @@ import os +import json import shutil import asyncio import zipfile @@ -267,6 +268,9 @@ class BlockchainProcess(asyncio.SubprocessProtocol): class BlockchainNode: + P2SH_SEGWIT_ADDRESS = "p2sh-segwit" + BECH32_ADDRESS = "bech32" + def __init__(self, url, daemon, cli, segwit_enabled=False): self.latest_release_url = url self.project_dir = os.path.dirname(os.path.dirname(__file__)) @@ -391,6 +395,9 @@ class BlockchainNode: def get_raw_change_address(self): return self._cli_cmnd('getrawchangeaddress') + def get_new_address(self, type): + return self._cli_cmnd('getnewaddress', "", type) + async def get_balance(self): return float(await self._cli_cmnd('getbalance')) @@ -400,6 +407,12 @@ class BlockchainNode: def send_raw_transaction(self, tx): return self._cli_cmnd('sendrawtransaction', tx.decode()) + def create_raw_transaction(self, inputs, outputs): + return self._cli_cmnd('createrawtransaction', json.dumps(inputs), json.dumps(outputs)) + + async def sign_raw_transaction_with_wallet(self, tx): + return json.loads(await self._cli_cmnd('signrawtransactionwithwallet', tx))['hex'].encode() + def decode_raw_transaction(self, tx): return self._cli_cmnd('decoderawtransaction', hexlify(tx.raw).decode())