updated with torba refactoring and working claim_send_to_address

This commit is contained in:
Lex Berezhny 2018-10-03 12:00:21 -04:00
parent 1a5b2c08ee
commit 7b9ff3e8b5
6 changed files with 71 additions and 115 deletions

View file

@ -2426,7 +2426,6 @@ class Daemon(AuthJSONRPCServer):
pass pass
@requires(WALLET_COMPONENT, conditions=[WALLET_IS_UNLOCKED]) @requires(WALLET_COMPONENT, conditions=[WALLET_IS_UNLOCKED])
@defer.inlineCallbacks
def jsonrpc_claim_send_to_address(self, claim_id, address, amount=None): def jsonrpc_claim_send_to_address(self, claim_id, address, amount=None):
""" """
Send a name claim to an address Send a name claim to an address
@ -2453,9 +2452,10 @@ class Daemon(AuthJSONRPCServer):
} }
""" """
result = yield self.wallet_manager.send_claim_to_address(claim_id, address, amount) decode_address(address)
response = yield self._render_response(result) return self.wallet_manager.send_claim_to_address(
return response claim_id, address, self.get_dewies_or_error("amount", amount) if amount else None
)
# TODO: claim_list_mine should be merged into claim_list, but idk how to authenticate it -Grin # TODO: claim_list_mine should be merged into claim_list, but idk how to authenticate it -Grin
@requires(WALLET_COMPONENT) @requires(WALLET_COMPONENT)
@ -3252,16 +3252,18 @@ class Daemon(AuthJSONRPCServer):
defer.returnValue(response) defer.returnValue(response)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_channel_or_error(self, channel_id: str = None, name: str = None): def get_channel_or_error(self, channel_id: str = None, channel_name: str = None):
if channel_id is not None: if channel_id is not None:
certificates = yield self.wallet_manager.get_certificates(claim_id=channel_id) certificates = yield self.wallet_manager.get_certificates(
private_key_accounts=[self.default_account], claim_id=channel_id)
if not certificates: if not certificates:
raise ValueError("Couldn't find channel with claim_id '{}'." .format(channel_id)) raise ValueError("Couldn't find channel with claim_id '{}'." .format(channel_id))
return certificates[0] return certificates[0]
if name is not None: if channel_name is not None:
certificates = yield self.wallet_manager.get_certificates(name=name) certificates = yield self.wallet_manager.get_certificates(
private_key_accounts=[self.default_account], claim_name=channel_name)
if not certificates: if not certificates:
raise ValueError("Couldn't find channel with name '{}'.".format(name)) raise ValueError("Couldn't find channel with name '{}'.".format(channel_name))
return certificates[0] return certificates[0]
raise ValueError("Couldn't find channel because a channel name or channel_id was not provided.") raise ValueError("Couldn't find channel because a channel name or channel_id was not provided.")

View file

@ -167,7 +167,10 @@ class Account(BaseAccount):
return details return details
def get_claim(self, claim_id=None, txid=None, nout=None): def get_claim(self, claim_id=None, txid=None, nout=None):
return self.ledger.db.get_claim(self, claim_id, txid, nout) if claim_id is not None:
return self.ledger.db.get_claims(account=self, claim_id=claim_id)
elif txid is not None and nout is not None:
return self.ledger.db.get_claims(**{'account': self, 'txo.txid': txid, 'txo.position': nout})
def get_claims(self): def get_claims(self):
return self.ledger.db.get_claims(self) return self.ledger.db.get_claims(account=self)

View file

@ -1,5 +1,5 @@
from collections import namedtuple from collections import namedtuple
class Certificate(namedtuple('Certificate', ('txid', 'nout', 'claim_id', 'name', 'private_key'))): class Certificate(namedtuple('Certificate', ('channel', 'private_key'))):
pass pass

View file

@ -1,7 +1,5 @@
from twisted.internet import defer from twisted.internet import defer
from torba.basedatabase import BaseDatabase from torba.basedatabase import BaseDatabase
from torba.hash import TXRefImmutable
from torba.basetransaction import TXORef
from .certificate import Certificate from .certificate import Certificate
@ -48,96 +46,27 @@ class WalletDatabase(BaseDatabase):
row['claim_name'] = txo.claim_name row['claim_name'] = txo.claim_name
return row return row
def get_claims(self, **constraints):
constraints['claim_type__any'] = {'is_claim': 1, 'is_update': 1}
return self.get_utxos(**constraints)
def get_channels(self, **constraints):
if 'claim_name' not in constraints or 'claim_id' not in constraints:
constraints['claim_name__like'] = '@%'
return self.get_claims(**constraints)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_certificates(self, name=None, channel_id=None, private_key_accounts=None, exclude_without_key=False): def get_certificates(self, private_key_accounts, exclude_without_key=False, **constraints):
if name is not None: channels = yield self.get_channels(**constraints)
filter_sql = 'claim_name=?'
filter_value = name
elif channel_id is not None:
filter_sql = 'claim_id=?'
filter_value = channel_id
else:
raise ValueError("'name' or 'claim_id' is required")
txos = yield self.db.runQuery(
"""
SELECT tx.txid, txo.position, txo.claim_id
FROM txo JOIN tx ON tx.txid=txo.txid
WHERE {} AND (is_claim OR is_update)
GROUP BY txo.claim_id ORDER BY tx.height DESC, tx.position ASC;
""".format(filter_sql), (filter_value,)
)
certificates = [] certificates = []
# Lookup private keys for each certificate.
if private_key_accounts is not None: if private_key_accounts is not None:
for txid, nout, claim_id in txos: for channel in channels:
private_key = None
for account in private_key_accounts: for account in private_key_accounts:
private_key = account.get_certificate_private_key( private_key = account.get_certificate_private_key(channel.ref)
TXORef(TXRefImmutable.from_id(txid), nout) if private_key is not None:
) break
certificates.append(Certificate(txid, nout, claim_id, name, private_key)) if private_key is None and exclude_without_key:
continue
if exclude_without_key: certificates.append(Certificate(channel, private_key))
return [c for c in certificates if c.private_key is not None]
return certificates return certificates
@defer.inlineCallbacks
def get_claim(self, account, claim_id=None, txid=None, nout=None):
if claim_id is not None:
filter_sql = "claim_id=?"
filter_value = (claim_id,)
else:
filter_sql = "txo.txid=? AND txo.position=?"
filter_value = (txid, nout)
utxos = yield self.db.runQuery(
"""
SELECT amount, script, txo.txid, txo.position, account
FROM txo
JOIN tx ON tx.txid=txo.txid
JOIN pubkey_address ON pubkey_address.address=txo.address
WHERE {}
AND (is_claim OR is_update)
AND txoid NOT IN (SELECT txoid FROM txi)
ORDER BY tx.height DESC, tx.position ASC LIMIT 1;
""".format(filter_sql), filter_value
)
output_class = account.ledger.transaction_class.output_class
account_id = account.public_key.address
return [
output_class(
values[0],
output_class.script_class(values[1]),
TXRefImmutable.from_id(values[2]),
position=values[3],
is_change=False,
is_my_account=values[4] == account_id
) for values in utxos
]
@defer.inlineCallbacks
def get_claims(self, account):
utxos = yield self.db.runQuery(
"""
SELECT amount, script, txo.txid, txo.position
FROM txo
JOIN tx ON tx.txid=txo.txid
JOIN pubkey_address ON pubkey_address.address=txo.address
WHERE (is_claim OR is_update)
AND txoid NOT IN (SELECT txoid FROM txi)
AND account = :account
ORDER BY tx.height DESC, tx.position ASC;
""", {'account': account.public_key.address}
)
output_class = account.ledger.transaction_class.output_class
return [
output_class(
values[0],
output_class.script_class(values[1]),
TXRefImmutable.from_id(values[2]),
position=values[3],
is_change=False,
is_my_account=True
) for values in utxos
]

View file

@ -2,7 +2,7 @@ import os
import json import json
import logging import logging
from datetime import datetime from datetime import datetime
from typing import List from typing import List, Optional
from twisted.internet import defer from twisted.internet import defer
@ -200,6 +200,20 @@ class LbryWalletManager(BaseWalletManager):
yield account.ledger.broadcast(tx) yield account.ledger.broadcast(tx)
return tx return tx
@defer.inlineCallbacks
def send_claim_to_address(self, claim_id: str, destination_address: str, amount: Optional[int],
account=None):
account = account or self.default_account
claims = account.ledger.db.get_utxos(claim_id=claim_id)
if not claims:
raise NameError("Claim not found: {}".format(claim_id))
tx = yield Transaction.update(
claims[0], ClaimDict.deserialize(claims[0].script.value['claim']), amount,
destination_address.encode(), [account], account
)
yield self.ledger.broadcast(tx)
return tx
def send_points_to_address(self, reserved: ReservedPoints, amount: int, account=None): def send_points_to_address(self, reserved: ReservedPoints, amount: int, account=None):
destination_address: bytes = reserved.identifier.encode('latin1') destination_address: bytes = reserved.identifier.encode('latin1')
return self.send_amount_to_address(amount, destination_address, account) return self.send_amount_to_address(amount, destination_address, account)
@ -297,7 +311,7 @@ class LbryWalletManager(BaseWalletManager):
claim_address = yield account.receiving.get_or_create_usable_address() claim_address = yield account.receiving.get_or_create_usable_address()
if certificate: if certificate:
claim = claim.sign( claim = claim.sign(
certificate.private_key, claim_address, certificate.claim_id certificate.private_key, claim_address, certificate.channel.claim_id
) )
existing_claims = yield account.get_unspent_outputs(include_claims=True, claim_name=name) existing_claims = yield account.get_unspent_outputs(include_claims=True, claim_name=name)
if len(existing_claims) == 0: if len(existing_claims) == 0:
@ -315,7 +329,7 @@ class LbryWalletManager(BaseWalletManager):
tx, tx.outputs[0], claim_address, claim_dict, name, amount tx, tx.outputs[0], claim_address, claim_dict, name, amount
)]) )])
# TODO: release reserved tx outputs in case anything fails by this point # TODO: release reserved tx outputs in case anything fails by this point
defer.returnValue(tx) return tx
def _old_get_temp_claim_info(self, tx, txo, address, claim_dict, name, bid): def _old_get_temp_claim_info(self, tx, txo, address, claim_dict, name, bid):
return { return {
@ -371,8 +385,12 @@ class LbryWalletManager(BaseWalletManager):
def channel_list(self): def channel_list(self):
return self.default_account.get_channels() return self.default_account.get_channels()
def get_certificates(self, name=None, claim_id=None): def get_certificates(self, private_key_accounts, exclude_without_key=True, **constraints):
return self.db.get_certificates(name, claim_id, self.accounts, exclude_without_key=True) return self.db.get_certificates(
private_key_accounts=private_key_accounts,
exclude_without_key=exclude_without_key,
**constraints
)
def update_peer_address(self, peer, address): def update_peer_address(self, peer, address):
pass # TODO: Data payments is disabled pass # TODO: Data payments is disabled

View file

@ -34,16 +34,18 @@ class BasicAccountingTests(LedgerTestCase):
address = yield self.account.receiving.get_or_create_usable_address() address = yield self.account.receiving.get_or_create_usable_address()
hash160 = self.ledger.address_to_hash160(address) hash160 = self.ledger.address_to_hash160(address)
tx = Transaction().add_outputs([Output.pay_pubkey_hash(100, hash160)]) tx = Transaction(is_verified=True)\
.add_outputs([Output.pay_pubkey_hash(100, hash160)])
yield self.ledger.db.save_transaction_io( yield self.ledger.db.save_transaction_io(
'insert', tx, True, address, hash160, '{}:{}:'.format(tx.id, 1) 'insert', tx, address, hash160, '{}:{}:'.format(tx.id, 1)
) )
balance = yield self.account.get_balance(0) balance = yield self.account.get_balance(0)
self.assertEqual(balance, 100) self.assertEqual(balance, 100)
tx = Transaction().add_outputs([Output.pay_claim_name_pubkey_hash(100, 'foo', b'', hash160)]) tx = Transaction(is_verified=True)\
.add_outputs([Output.pay_claim_name_pubkey_hash(100, 'foo', b'', hash160)])
yield self.ledger.db.save_transaction_io( yield self.ledger.db.save_transaction_io(
'insert', tx, True, address, hash160, '{}:{}:'.format(tx.id, 1) 'insert', tx, address, hash160, '{}:{}:'.format(tx.id, 1)
) )
balance = yield self.account.get_balance(0) balance = yield self.account.get_balance(0)
self.assertEqual(balance, 100) # claim names don't count towards balance self.assertEqual(balance, 100) # claim names don't count towards balance
@ -55,17 +57,19 @@ class BasicAccountingTests(LedgerTestCase):
address = yield self.account.receiving.get_or_create_usable_address() address = yield self.account.receiving.get_or_create_usable_address()
hash160 = self.ledger.address_to_hash160(address) hash160 = self.ledger.address_to_hash160(address)
tx = Transaction().add_outputs([Output.pay_pubkey_hash(100, hash160)]) tx = Transaction(is_verified=True)\
.add_outputs([Output.pay_pubkey_hash(100, hash160)])
yield self.ledger.db.save_transaction_io( yield self.ledger.db.save_transaction_io(
'insert', tx, True, address, hash160, '{}:{}:'.format(tx.id, 1) 'insert', tx, address, hash160, '{}:{}:'.format(tx.id, 1)
) )
utxos = yield self.account.get_unspent_outputs() utxos = yield self.account.get_unspent_outputs()
self.assertEqual(len(utxos), 1) self.assertEqual(len(utxos), 1)
tx = Transaction().add_inputs([Input.spend(utxos[0])]) tx = Transaction(is_verified=True)\
.add_inputs([Input.spend(utxos[0])])
yield self.ledger.db.save_transaction_io( yield self.ledger.db.save_transaction_io(
'insert', tx, True, address, hash160, '{}:{}:'.format(tx.id, 1) 'insert', tx, address, hash160, '{}:{}:'.format(tx.id, 1)
) )
balance = yield self.account.get_balance(0, include_claims=True) balance = yield self.account.get_balance(0, include_claims=True)
self.assertEqual(balance, 0) self.assertEqual(balance, 0)