refactored wallet and mnemonic

This commit is contained in:
Lex Berezhny 2020-05-18 08:26:36 -04:00
parent 7c4f943bcb
commit be6ebf0047
7 changed files with 293 additions and 165 deletions

View file

@ -1,3 +1,3 @@
from .account import Account, AddressManager, SingleKey
from .wallet import Wallet from .wallet import Wallet
from .manager import WalletManager from .manager import WalletManager
from .account import Account, SingleKey, HierarchicalDeterministic

View file

@ -6,35 +6,23 @@ import asyncio
import random import random
from functools import partial from functools import partial
from hashlib import sha256 from hashlib import sha256
from string import hexdigits
from typing import Type, Dict, Tuple, Optional, Any, List from typing import Type, Dict, Tuple, Optional, Any, List
import ecdsa import ecdsa
from lbry.constants import COIN
from lbry.db import Database, CLAIM_TYPE_CODES, TXO_TYPES
from lbry.blockchain import Ledger, Transaction, Input, Output
from lbry.error import InvalidPasswordError from lbry.error import InvalidPasswordError
from lbry.crypto.crypt import aes_encrypt, aes_decrypt from lbry.crypto.crypt import aes_encrypt, aes_decrypt
from lbry.crypto.bip32 import PrivateKey, PubKey, from_extended_key_string from lbry.crypto.bip32 import PrivateKey, PubKey, from_extended_key_string
from lbry.constants import COIN
from lbry.blockchain.transaction import Transaction, Input, Output
from lbry.blockchain.ledger import Ledger
from lbry.db import Database
from lbry.db.constants import CLAIM_TYPE_CODES, TXO_TYPES
from .mnemonic import Mnemonic from . import mnemonic
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def validate_claim_id(claim_id):
if not len(claim_id) == 40:
raise Exception("Incorrect claimid length: %i" % len(claim_id))
if isinstance(claim_id, bytes):
claim_id = claim_id.decode('utf-8')
if set(claim_id).difference(hexdigits):
raise Exception("Claim id is not hex encoded")
class AddressManager: class AddressManager:
name: str name: str
@ -48,8 +36,7 @@ class AddressManager:
self.address_generator_lock = asyncio.Lock() self.address_generator_lock = asyncio.Lock()
@classmethod @classmethod
def from_dict(cls, account: 'Account', d: dict) \ def from_dict(cls, account: 'Account', d: dict) -> Tuple['AddressManager', 'AddressManager']:
-> Tuple['AddressManager', 'AddressManager']:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
@ -222,24 +209,21 @@ class SingleKey(AddressManager):
class Account: class Account:
mnemonic_class = Mnemonic
private_key_class = PrivateKey
public_key_class = PubKey
address_generators: Dict[str, Type[AddressManager]] = { address_generators: Dict[str, Type[AddressManager]] = {
SingleKey.name: SingleKey, SingleKey.name: SingleKey,
HierarchicalDeterministic.name: HierarchicalDeterministic, HierarchicalDeterministic.name: HierarchicalDeterministic,
} }
def __init__(self, ledger: 'Ledger', db: 'Database', name: str, def __init__(self, ledger: Ledger, db: Database, name: str,
seed: str, private_key_string: str, encrypted: bool, phrase: str, language: str, private_key_string: str,
private_key: Optional[PrivateKey], public_key: PubKey, encrypted: bool, private_key: Optional[PrivateKey], public_key: PubKey,
address_generator: dict, modified_on: float, channel_keys: dict) -> None: address_generator: dict, modified_on: float, channel_keys: dict) -> None:
self.ledger = ledger self.ledger = ledger
self.db = db self.db = db
self.id = public_key.address self.id = public_key.address
self.name = name self.name = name
self.seed = seed self.phrase = phrase
self.language = language
self.modified_on = modified_on self.modified_on = modified_on
self.private_key_string = private_key_string self.private_key_string = private_key_string
self.init_vectors: Dict[str, bytes] = {} self.init_vectors: Dict[str, bytes] = {}
@ -251,6 +235,7 @@ class Account:
self.receiving, self.change = self.address_generator.from_dict(self, address_generator) self.receiving, self.change = self.address_generator.from_dict(self, address_generator)
self.address_managers = {am.chain_number: am for am in {self.receiving, self.change}} self.address_managers = {am.chain_number: am for am in {self.receiving, self.change}}
self.channel_keys = channel_keys self.channel_keys = channel_keys
self._channel_keys_deserialized = {}
def get_init_vector(self, key) -> Optional[bytes]: def get_init_vector(self, key) -> Optional[bytes]:
init_vector = self.init_vectors.get(key, None) init_vector = self.init_vectors.get(key, None)
@ -259,42 +244,40 @@ class Account:
return init_vector return init_vector
@classmethod @classmethod
def generate(cls, ledger: 'Ledger', db: 'Database', async def generate(
name: str = None, address_generator: dict = None): cls, ledger: Ledger, db: Database,
return cls.from_dict(ledger, db, { name: str = None, language: str = 'en',
address_generator: dict = None):
return await cls.from_dict(ledger, db, {
'name': name, 'name': name,
'seed': cls.mnemonic_class().make_seed(), 'seed': await mnemonic.generate_phrase(language),
'language': language,
'address_generator': address_generator or {} 'address_generator': address_generator or {}
}) })
@classmethod @classmethod
def get_private_key_from_seed(cls, ledger: 'Ledger', seed: str, password: str): async def keys_from_dict(cls, ledger: Ledger, d: dict) -> Tuple[str, Optional[PrivateKey], PubKey]:
return cls.private_key_class.from_seed( phrase = d.get('seed', '')
ledger, cls.mnemonic_class.mnemonic_to_seed(seed, password or 'lbryum')
)
@classmethod
def keys_from_dict(cls, ledger: 'Ledger', d: dict) \
-> Tuple[str, Optional[PrivateKey], PubKey]:
seed = d.get('seed', '')
private_key_string = d.get('private_key', '') private_key_string = d.get('private_key', '')
private_key = None private_key = None
public_key = None public_key = None
encrypted = d.get('encrypted', False) encrypted = d.get('encrypted', False)
if not encrypted: if not encrypted:
if seed: if phrase:
private_key = cls.get_private_key_from_seed(ledger, seed, '') private_key = PrivateKey.from_seed(
ledger, await mnemonic.derive_key_from_phrase(phrase)
)
public_key = private_key.public_key public_key = private_key.public_key
elif private_key_string: elif private_key_string:
private_key = from_extended_key_string(ledger, private_key_string) private_key = from_extended_key_string(ledger, private_key_string)
public_key = private_key.public_key public_key = private_key.public_key
if public_key is None: if public_key is None:
public_key = from_extended_key_string(ledger, d['public_key']) public_key = from_extended_key_string(ledger, d['public_key'])
return seed, private_key, public_key return phrase, private_key, public_key
@classmethod @classmethod
def from_dict(cls, ledger: 'Ledger', db: 'Database', d: dict): async def from_dict(cls, ledger: Ledger, db: Database, d: dict):
seed, private_key, public_key = cls.keys_from_dict(ledger, d) phrase, private_key, public_key = await cls.keys_from_dict(ledger, d)
name = d.get('name') name = d.get('name')
if not name: if not name:
name = f'Account #{public_key.address}' name = f'Account #{public_key.address}'
@ -302,7 +285,8 @@ class Account:
ledger=ledger, ledger=ledger,
db=db, db=db,
name=name, name=name,
seed=seed, phrase=phrase,
language=d.get('lang', 'en'),
private_key_string=d.get('private_key', ''), private_key_string=d.get('private_key', ''),
encrypted=d.get('encrypted', False), encrypted=d.get('encrypted', False),
private_key=private_key, private_key=private_key,
@ -313,7 +297,7 @@ class Account:
) )
def to_dict(self, encrypt_password: str = None, include_channel_keys: bool = True): def to_dict(self, encrypt_password: str = None, include_channel_keys: bool = True):
private_key_string, seed = self.private_key_string, self.seed private_key_string, phrase = self.private_key_string, self.phrase
if not self.encrypted and self.private_key: if not self.encrypted and self.private_key:
private_key_string = self.private_key.extended_key_string() private_key_string = self.private_key.extended_key_string()
if not self.encrypted and encrypt_password: if not self.encrypted and encrypt_password:
@ -321,11 +305,12 @@ class Account:
private_key_string = aes_encrypt( private_key_string = aes_encrypt(
encrypt_password, private_key_string, self.get_init_vector('private_key') encrypt_password, private_key_string, self.get_init_vector('private_key')
) )
if seed: if phrase:
seed = aes_encrypt(encrypt_password, self.seed, self.get_init_vector('seed')) phrase = aes_encrypt(encrypt_password, self.phrase, self.get_init_vector('phrase'))
d = { d = {
'name': self.name, 'name': self.name,
'seed': seed, 'seed': phrase,
'lang': self.language,
'encrypted': bool(self.encrypted or encrypt_password), 'encrypted': bool(self.encrypted or encrypt_password),
'private_key': private_key_string, 'private_key': private_key_string,
'public_key': self.public_key.extended_key_string(), 'public_key': self.public_key.extended_key_string(),
@ -367,21 +352,21 @@ class Account:
'address_generator': self.address_generator.to_dict(self.receiving, self.change) 'address_generator': self.address_generator.to_dict(self.receiving, self.change)
} }
if show_seed: if show_seed:
details['seed'] = self.seed details['seed'] = self.phrase
details['certificates'] = len(self.channel_keys) details['certificates'] = len(self.channel_keys)
return details return details
def decrypt(self, password: str) -> bool: def decrypt(self, password: str) -> bool:
assert self.encrypted, "Key is not encrypted." assert self.encrypted, "Key is not encrypted."
try: try:
seed = self._decrypt_seed(password) phrase = self._decrypt_phrase(password)
except (ValueError, InvalidPasswordError): except (ValueError, InvalidPasswordError):
return False return False
try: try:
private_key = self._decrypt_private_key_string(password) private_key = self._decrypt_private_key_string(password)
except (TypeError, ValueError, InvalidPasswordError): except (TypeError, ValueError, InvalidPasswordError):
return False return False
self.seed = seed self.phrase = phrase
self.private_key = private_key self.private_key = private_key
self.private_key_string = "" self.private_key_string = ""
self.encrypted = False self.encrypted = False
@ -397,24 +382,20 @@ class Account:
self.ledger, private_key_string self.ledger, private_key_string
) )
def _decrypt_seed(self, password: str) -> str: def _decrypt_phrase(self, password: str) -> str:
if not self.seed: if not self.phrase:
return "" return ""
seed, self.init_vectors['seed'] = aes_decrypt(password, self.seed) phrase, self.init_vectors['phrase'] = aes_decrypt(password, self.phrase)
if not seed: if not phrase:
return "" return ""
try: if not mnemonic.is_phrase_valid(self.language, phrase):
Mnemonic().mnemonic_decode(seed) raise ValueError("Failed to decode seed phrase.")
except IndexError: return phrase
# failed to decode the seed, this either means it decrypted and is invalid
# or that we hit an edge case where an incorrect password gave valid padding
raise ValueError("Failed to decode seed.")
return seed
def encrypt(self, password: str) -> bool: def encrypt(self, password: str) -> bool:
assert not self.encrypted, "Key is already encrypted." assert not self.encrypted, "Key is already encrypted."
if self.seed: if self.phrase:
self.seed = aes_encrypt(password, self.seed, self.get_init_vector('seed')) self.phrase = aes_encrypt(password, self.phrase, self.get_init_vector('phrase'))
if isinstance(self.private_key, PrivateKey): if isinstance(self.private_key, PrivateKey):
self.private_key_string = aes_encrypt( self.private_key_string = aes_encrypt(
password, self.private_key.extended_key_string(), self.get_init_vector('private_key') password, self.private_key.extended_key_string(), self.get_init_vector('private_key')
@ -504,12 +485,20 @@ class Account:
public_key_bytes = private_key.get_verifying_key().to_der() public_key_bytes = private_key.get_verifying_key().to_der()
channel_pubkey_hash = self.ledger.public_key_to_address(public_key_bytes) channel_pubkey_hash = self.ledger.public_key_to_address(public_key_bytes)
self.channel_keys[channel_pubkey_hash] = private_key.to_pem().decode() self.channel_keys[channel_pubkey_hash] = private_key.to_pem().decode()
self._channel_keys_deserialized[channel_pubkey_hash] = private_key
def get_channel_private_key(self, public_key_bytes): async def get_channel_private_key(self, public_key_bytes):
channel_pubkey_hash = self.ledger.public_key_to_address(public_key_bytes) channel_pubkey_hash = self.ledger.public_key_to_address(public_key_bytes)
private_key = self._channel_keys_deserialized.get(channel_pubkey_hash)
if private_key:
return private_key
private_key_pem = self.channel_keys.get(channel_pubkey_hash) private_key_pem = self.channel_keys.get(channel_pubkey_hash)
if private_key_pem: if private_key_pem:
return ecdsa.SigningKey.from_pem(private_key_pem, hashfunc=sha256) private_key = await asyncio.get_running_loop().run_in_executor(
None, ecdsa.SigningKey.from_pem, private_key_pem, sha256
)
self._channel_keys_deserialized[channel_pubkey_hash] = private_key
return private_key
async def maybe_migrate_certificates(self): async def maybe_migrate_certificates(self):
def to_der(private_key_pem): def to_der(private_key_pem):

View file

@ -29,10 +29,19 @@ class WalletManager:
for wallet in self.wallets.values(): for wallet in self.wallets.values():
return wallet return wallet
def get_or_default(self, wallet_id: Optional[str]) -> Optional[Wallet]: def get_or_default(self, wallet_id: Optional[str]) -> Wallet:
if wallet_id: if wallet_id:
return self[wallet_id] return self[wallet_id]
return self.default wallet = self.default
if not wallet:
raise ValueError("No wallets available.")
return wallet
def get_or_default_for_spending(self, wallet_id: Optional[str]) -> Wallet:
wallet = self.get_or_default(wallet_id)
if wallet.is_locked:
raise ValueError("Cannot spend funds with locked wallet, unlock first.")
return wallet
@property @property
def path(self): def path(self):
@ -72,7 +81,7 @@ class WalletManager:
create_account=self.ledger.conf.create_default_account create_account=self.ledger.conf.create_default_account
) )
elif not default_wallet.has_accounts and self.ledger.conf.create_default_account: elif not default_wallet.has_accounts and self.ledger.conf.create_default_account:
default_wallet.accounts.generate() await default_wallet.accounts.generate()
def add(self, wallet: Wallet) -> Wallet: def add(self, wallet: Wallet) -> Wallet:
self.wallets[wallet.id] = wallet self.wallets[wallet.id] = wallet
@ -92,11 +101,16 @@ class WalletManager:
wallet = await Wallet.from_path(self.ledger, self.db, wallet_path) wallet = await Wallet.from_path(self.ledger, self.db, wallet_path)
return self.add(wallet) return self.add(wallet)
async def create(self, wallet_id: str, name: str, create_account=False, single_key=False) -> Wallet: async def create(
self, wallet_id: str, name: str,
create_account=False, language='en', single_key=False) -> Wallet:
if wallet_id in self.wallets: if wallet_id in self.wallets:
raise Exception(f"Wallet with id '{wallet_id}' is already loaded and cannot be created.") raise Exception(f"Wallet with id '{wallet_id}' is already loaded and cannot be created.")
wallet_path = os.path.join(self.path, wallet_id) wallet_path = os.path.join(self.path, wallet_id)
if os.path.exists(wallet_path): if os.path.exists(wallet_path):
raise Exception(f"Wallet at path '{wallet_path}' already exists, use 'wallet_add' to load wallet.") raise Exception(f"Wallet at path '{wallet_path}' already exists, use 'wallet_add' to load wallet.")
wallet = await Wallet.create(self.ledger, self.db, wallet_path, name, create_account, single_key) wallet = await Wallet.create(
self.ledger, self.db, wallet_path, name,
create_account, language, single_key
)
return self.add(wallet) return self.add(wallet)

View file

@ -12,19 +12,19 @@ def get_languages():
return words.languages return words.languages
def normalize(mnemonic: str) -> str: def normalize(phrase: str) -> str:
return ' '.join(unicodedata.normalize('NFKD', mnemonic).lower().split()) return ' '.join(unicodedata.normalize('NFKD', phrase).lower().split())
def is_valid(language, mnemonic): def is_phrase_valid(language, phrase):
local_words = getattr(words, language) local_words = getattr(words, language)
for word in normalize(mnemonic).split(): for word in normalize(phrase).split():
if word not in local_words: if word not in local_words:
return False return False
return bool(mnemonic) return bool(phrase)
def sync_generate(language: str) -> str: def sync_generate_phrase(language: str) -> str:
local_words = getattr(words, language) local_words = getattr(words, language)
entropy = randbits(132) entropy = randbits(132)
nonce = 0 nonce = 0
@ -41,17 +41,17 @@ def sync_generate(language: str) -> str:
return seed return seed
def sync_to_seed(mnemonic: str) -> bytes: def sync_derive_key_from_phrase(phrase: str) -> bytes:
return hashlib.pbkdf2_hmac('sha512', normalize(mnemonic).encode(), b'lbryum', 2048) return hashlib.pbkdf2_hmac('sha512', normalize(phrase).encode(), b'lbryum', 2048)
async def generate(language: str) -> str: async def generate_phrase(language: str) -> str:
return await asyncio.get_running_loop().run_in_executor( return await asyncio.get_running_loop().run_in_executor(
None, sync_generate, language None, sync_generate_phrase, language
) )
async def to_seed(mnemonic: str) -> bytes: async def derive_key_from_phrase(phrase: str) -> bytes:
return await asyncio.get_running_loop().run_in_executor( return await asyncio.get_running_loop().run_in_executor(
None, sync_to_seed, mnemonic None, sync_derive_key_from_phrase, phrase
) )

View file

@ -7,7 +7,6 @@ from collections import defaultdict
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from typing import List, Optional, DefaultDict, NamedTuple from typing import List, Optional, DefaultDict, NamedTuple
import pylru
from lbry.crypto.hash import double_sha256, sha256 from lbry.crypto.hash import double_sha256, sha256
from lbry.service.api import Client from lbry.service.api import Client
@ -80,7 +79,7 @@ class SPVSync(Sync):
self._on_ready_controller = EventController() self._on_ready_controller = EventController()
self.on_ready = self._on_ready_controller.stream self.on_ready = self._on_ready_controller.stream
self._tx_cache = pylru.lrucache(100000) #self._tx_cache = pylru.lrucache(100000)
self._update_tasks = TaskGroup() self._update_tasks = TaskGroup()
self._other_tasks = TaskGroup() # that we dont need to start self._other_tasks = TaskGroup() # that we dont need to start
self._header_processing_lock = asyncio.Lock() self._header_processing_lock = asyncio.Lock()

View file

@ -3,7 +3,7 @@ import json
import zlib import zlib
import asyncio import asyncio
import logging import logging
from typing import List, Sequence, Tuple, Optional, Iterable from typing import Awaitable, Callable, List, Tuple, Optional, Iterable, Union
from hashlib import sha256 from hashlib import sha256
from operator import attrgetter from operator import attrgetter
from decimal import Decimal from decimal import Decimal
@ -19,6 +19,7 @@ from lbry.crypto.bip32 import PubKey, PrivateKey
from lbry.schema.claim import Claim from lbry.schema.claim import Claim
from lbry.schema.purchase import Purchase from lbry.schema.purchase import Purchase
from lbry.error import InsufficientFundsError, KeyFeeAboveMaxAllowedError from lbry.error import InsufficientFundsError, KeyFeeAboveMaxAllowedError
from lbry.stream.managed_stream import ManagedStream
from .account import Account, SingleKey, HierarchicalDeterministic from .account import Account, SingleKey, HierarchicalDeterministic
from .coinselection import CoinSelector, OutputEffectiveAmountEstimator from .coinselection import CoinSelector, OutputEffectiveAmountEstimator
@ -61,10 +62,12 @@ class Wallet:
return os.path.basename(self.storage.path) if self.storage.path else self.name return os.path.basename(self.storage.path) if self.storage.path else self.name
@classmethod @classmethod
async def create(cls, ledger: Ledger, db: Database, path: str, name: str, create_account=False, single_key=False): async def create(
cls, ledger: Ledger, db: Database, path: str, name: str,
create_account=False, language='en', single_key=False):
wallet = cls(ledger, db, name, WalletStorage(path), {}) wallet = cls(ledger, db, name, WalletStorage(path), {})
if create_account: if create_account:
wallet.accounts.generate(address_generator={ await wallet.accounts.generate(language=language, address_generator={
'name': SingleKey.name if single_key else HierarchicalDeterministic.name 'name': SingleKey.name if single_key else HierarchicalDeterministic.name
}) })
await wallet.save() await wallet.save()
@ -88,7 +91,7 @@ class Wallet:
preferences=json_dict.get('preferences', {}), preferences=json_dict.get('preferences', {}),
) )
for account_dict in json_dict.get('accounts', []): for account_dict in json_dict.get('accounts', []):
wallet.accounts.add_from_dict(account_dict) await wallet.accounts.add_from_dict(account_dict)
return wallet return wallet
def to_dict(self, encrypt_password: str = None): def to_dict(self, encrypt_password: str = None):
@ -135,13 +138,13 @@ class Wallet:
decompressed = zlib.decompress(decrypted) decompressed = zlib.decompress(decrypted)
return json.loads(decompressed) return json.loads(decompressed)
def merge(self, password: str, data: str) -> List[Account]: async def merge(self, password: str, data: str) -> List[Account]:
assert not self.is_locked, "Cannot sync apply on a locked wallet." assert not self.is_locked, "Cannot sync apply on a locked wallet."
added_accounts = [] added_accounts = []
decrypted_data = self.unpack(password, data) decrypted_data = self.unpack(password, data)
self.preferences.merge(decrypted_data.get('preferences', {})) self.preferences.merge(decrypted_data.get('preferences', {}))
for account_dict in decrypted_data['accounts']: for account_dict in decrypted_data['accounts']:
_, _, pubkey = Account.keys_from_dict(self.ledger, account_dict) _, _, pubkey = await Account.keys_from_dict(self.ledger, account_dict)
account_id = pubkey.address account_id = pubkey.address
local_match = None local_match = None
for local_account in self.accounts: for local_account in self.accounts:
@ -182,18 +185,18 @@ class Wallet:
def is_encrypted(self) -> bool: def is_encrypted(self) -> bool:
return self.is_locked or self.preferences.get(ENCRYPT_ON_DISK, False) return self.is_locked or self.preferences.get(ENCRYPT_ON_DISK, False)
def decrypt(self): async def decrypt(self):
assert not self.is_locked, "Cannot decrypt a locked wallet, unlock first." assert not self.is_locked, "Cannot decrypt a locked wallet, unlock first."
self.preferences[ENCRYPT_ON_DISK] = False self.preferences[ENCRYPT_ON_DISK] = False
self.save() await self.save()
return True return True
def encrypt(self, password): async def encrypt(self, password):
assert not self.is_locked, "Cannot re-encrypt a locked wallet, unlock first." assert not self.is_locked, "Cannot re-encrypt a locked wallet, unlock first."
assert password, "Cannot encrypt with blank password." assert password, "Cannot encrypt with blank password."
self.encryption_password = password self.encryption_password = password
self.preferences[ENCRYPT_ON_DISK] = True self.preferences[ENCRYPT_ON_DISK] = True
self.save() await self.save()
return True return True
@property @property
@ -232,7 +235,7 @@ class Wallet:
if await account.save_max_gap(): if await account.save_max_gap():
gap_changed = True gap_changed = True
if gap_changed: if gap_changed:
self.save() await self.save()
async def get_effective_amount_estimators(self, funding_accounts: Iterable[Account]): async def get_effective_amount_estimators(self, funding_accounts: Iterable[Account]):
estimators = [] estimators = []
@ -260,9 +263,9 @@ class Wallet:
**constraints **constraints
), self.ledger) ), self.ledger)
async def create_transaction(self, inputs: Iterable[Input], outputs: Iterable[Output], async def create_transaction(
funding_accounts: Iterable[Account], change_account: Account, self, inputs: Iterable[Input], outputs: Iterable[Output],
sign: bool = True): funding_accounts: Iterable[Account], change_account: Account):
""" Find optimal set of inputs when only outputs are provided; add change """ Find optimal set of inputs when only outputs are provided; add change
outputs if only inputs are provided or if inputs are greater than outputs. """ outputs if only inputs are provided or if inputs are greater than outputs. """
@ -318,11 +321,7 @@ class Wallet:
# less than the fee, after 5 attempts we give up and go home # less than the fee, after 5 attempts we give up and go home
cost += cost_of_change + 1 cost += cost_of_change + 1
if sign:
await self.sign(tx)
except Exception as e: except Exception as e:
log.exception('Failed to create transaction:')
await self.db.release_tx(tx) await self.db.release_tx(tx)
raise e raise e
@ -374,6 +373,15 @@ class Wallet:
'Failed to display wallet state, please file issue ' 'Failed to display wallet state, please file issue '
'for this bug along with the traceback you see below:') 'for this bug along with the traceback you see below:')
async def verify_duplicate(self, name: str, allow_duplicate: bool):
if not allow_duplicate:
claims, _ = await self.claims.list(claim_name=name)
if len(claims) > 0:
raise Exception(
f"You already have a claim published under the name '{name}'. "
f"Use --allow-duplicate-name flag to override."
)
class AccountListManager: class AccountListManager:
__slots__ = 'wallet', '_accounts' __slots__ = 'wallet', '_accounts'
@ -399,13 +407,15 @@ class AccountListManager:
for account in self: for account in self:
return account return account
def generate(self, name: str = None, address_generator: dict = None) -> Account: async def generate(self, name: str = None, language: str = 'en', address_generator: dict = None) -> Account:
account = Account.generate(self.wallet.ledger, self.wallet.db, name, address_generator) account = await Account.generate(
self.wallet.ledger, self.wallet.db, name, language, address_generator
)
self._accounts.append(account) self._accounts.append(account)
return account return account
def add_from_dict(self, account_dict: dict) -> Account: async def add_from_dict(self, account_dict: dict) -> Account:
account = Account.from_dict(self.wallet.ledger, self.wallet.db, account_dict) account = await Account.from_dict(self.wallet.ledger, self.wallet.db, account_dict)
self._accounts.append(account) self._accounts.append(account)
return account return account
@ -424,7 +434,9 @@ class AccountListManager:
return self.default return self.default
return self[account_id] return self[account_id]
def get_or_all(self, account_ids: List[str]) -> List[Account]: def get_or_all(self, account_ids: Union[List[str], str]) -> List[Account]:
if account_ids and isinstance(account_ids, str):
account_ids = [account_ids]
return [self[account_id] for account_id in account_ids] if account_ids else self._accounts return [self[account_id] for account_id in account_ids] if account_ids else self._accounts
async def get_account_details(self, **kwargs): async def get_account_details(self, **kwargs):
@ -437,16 +449,15 @@ class AccountListManager:
class BaseListManager: class BaseListManager:
__slots__ = 'wallet', 'db' __slots__ = 'wallet',
def __init__(self, wallet: Wallet): def __init__(self, wallet: Wallet):
self.wallet = wallet self.wallet = wallet
self.db = wallet.db
async def create(self, **kwargs) -> Transaction: async def create(self, **kwargs) -> Transaction:
raise NotImplementedError raise NotImplementedError
async def delete(self, **constraints): async def delete(self, **constraints) -> Transaction:
raise NotImplementedError raise NotImplementedError
async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]: async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]:
@ -463,21 +474,39 @@ class ClaimListManager(BaseListManager):
name = 'claim' name = 'claim'
__slots__ = () __slots__ = ()
async def create( async def _create(
self, name: str, claim: Claim, amount: int, holding_address: str, self, name: str, claim: Claim, amount: int, holding_address: str,
funding_accounts: List[Account], change_account: Account, signing_channel: Output = None): funding_accounts: List[Account], change_account: Account,
claim_output = Output.pay_claim_name_pubkey_hash( signing_channel: Output = None) -> Transaction:
txo = Output.pay_claim_name_pubkey_hash(
amount, name, claim, self.wallet.ledger.address_to_hash160(holding_address) amount, name, claim, self.wallet.ledger.address_to_hash160(holding_address)
) )
if signing_channel is not None: if signing_channel is not None:
claim_output.sign(signing_channel, b'placeholder txid:nout') txo.sign(signing_channel, b'placeholder txid:nout')
return await self.wallet.create_transaction( tx = await self.wallet.create_transaction(
[], [claim_output], funding_accounts, change_account, sign=False [], [txo], funding_accounts, change_account
) )
return tx
async def create(
self, name: str, claim: Claim, amount: int, holding_address: str,
funding_accounts: List[Account], change_account: Account,
signing_channel: Output = None) -> Transaction:
tx = await self._create(
name, claim, amount, holding_address,
funding_accounts, change_account,
signing_channel
)
txo = tx.outputs[0]
if signing_channel is not None:
txo.sign(signing_channel)
await self.wallet.sign(tx)
return tx
async def update( async def update(
self, previous_claim: Output, claim: Claim, amount: int, holding_address: str, self, previous_claim: Output, claim: Claim, amount: int, holding_address: str,
funding_accounts: List[Account], change_account: Account, signing_channel: Output = None): funding_accounts: List[Account], change_account: Account,
signing_channel: Output = None) -> Transaction:
updated_claim = Output.pay_update_claim_pubkey_hash( updated_claim = Output.pay_update_claim_pubkey_hash(
amount, previous_claim.claim_name, previous_claim.claim_id, amount, previous_claim.claim_name, previous_claim.claim_id,
claim, self.wallet.ledger.address_to_hash160(holding_address) claim, self.wallet.ledger.address_to_hash160(holding_address)
@ -497,7 +526,7 @@ class ClaimListManager(BaseListManager):
) )
async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]: async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]:
return await self.db.get_claims(wallet=self.wallet, **constraints) return await self.wallet.db.get_claims(wallet=self.wallet, **constraints)
async def get(self, claim_id=None, claim_name=None, txid=None, nout=None) -> Output: async def get(self, claim_id=None, claim_name=None, txid=None, nout=None) -> Output:
if txid is not None and nout is not None: if txid is not None and nout is not None:
@ -523,59 +552,157 @@ class ClaimListManager(BaseListManager):
return await self.get(claim_id, claim_name, txid, nout) return await self.get(claim_id, claim_name, txid, nout)
class ChannelListManager(ClaimListManager):
name = 'channel'
__slots__ = ()
async def create(
self, name: str, amount: int, holding_account: Account,
funding_accounts: List[Account], save_key=True, **kwargs) -> Transaction:
holding_address = await holding_account.receiving.get_or_create_usable_address()
claim = Claim()
claim.channel.update(**kwargs)
txo = Output.pay_claim_name_pubkey_hash(
amount, name, claim, self.wallet.ledger.address_to_hash160(holding_address)
)
await txo.generate_channel_private_key()
tx = await self.wallet.create_transaction(
[], [txo], funding_accounts, funding_accounts[0]
)
await self.wallet.sign(tx)
if save_key:
holding_account.add_channel_private_key(txo.private_key)
await self.wallet.save()
return tx
async def update(
self, old: Output, amount: int, new_signing_key: bool, replace: bool,
holding_account: Account, funding_accounts: List[Account],
save_key=True, **kwargs) -> Transaction:
moving_accounts = False
holding_address = old.get_address(self.wallet.ledger)
if holding_account:
old_account = await self.wallet.get_account_for_address(holding_address)
if holding_account.id != old_account.id:
holding_address = await holding_account.receiving.get_or_create_usable_address()
moving_accounts = True
elif new_signing_key:
holding_account = await self.wallet.get_account_for_address(holding_address)
if replace:
claim = Claim()
claim.channel.public_key_bytes = old.claim.channel.public_key_bytes
else:
claim = Claim.from_bytes(old.claim.to_bytes())
claim.channel.update(**kwargs)
txo = Output.pay_update_claim_pubkey_hash(
amount, old.claim_name, old.claim_id, claim,
self.wallet.ledger.address_to_hash160(holding_address)
)
if new_signing_key:
await txo.generate_channel_private_key()
else:
txo.private_key = old.private_key
tx = await self.wallet.create_transaction(
[Input.spend(old)], [txo], funding_accounts, funding_accounts[0]
)
await self.wallet.sign(tx)
if any((new_signing_key, moving_accounts)) and save_key:
holding_account.add_channel_private_key(txo.private_key)
await self.wallet.save()
return tx
async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]:
return await self.wallet.db.get_channels(wallet=self.wallet, **constraints)
async def get_for_signing(self, channel_id=None, channel_name=None) -> Output:
channel = await self.get(claim_id=channel_id, claim_name=channel_name)
if not channel.has_private_key:
raise Exception(
f"Couldn't find private key for channel '{channel.claim_name}', "
f"can't use channel for signing. "
)
return channel
async def get_for_signing_or_none(self, channel_id=None, channel_name=None) -> Optional[Output]:
if channel_id or channel_name:
return await self.get_for_signing(channel_id, channel_name)
class StreamListManager(ClaimListManager): class StreamListManager(ClaimListManager):
__slots__ = () __slots__ = ()
async def create(self, *args, **kwargs): async def create(
return await super().create(*args, **kwargs) self, name: str, amount: int, file_path: str,
create_file_stream: Callable[[str], Awaitable[ManagedStream]],
holding_address: str, funding_accounts: List[Account],
signing_channel: Optional[Output] = None,
preview=False, **kwargs) -> Tuple[Transaction, ManagedStream]:
claim = Claim()
claim.stream.update(file_path=file_path, sd_hash='0' * 96, **kwargs)
# before creating file stream, create TX to ensure we have enough LBC
tx = await self._create(
name, claim, amount, holding_address,
funding_accounts, funding_accounts[0],
signing_channel
)
txo = tx.outputs[0]
file_stream = None
try:
# we have enough LBC to create TX, now try create the file stream
if not preview:
file_stream = await create_file_stream(file_path)
claim.stream.source.sd_hash = file_stream.sd_hash
txo.script.generate()
# creating TX and file stream was successful, now sign all the things
if signing_channel is not None:
txo.sign(signing_channel)
await self.wallet.sign(tx)
except Exception as e:
# creating file stream or something else went wrong, release txos
await self.wallet.db.release_tx(tx)
raise e
return tx, file_stream
async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]: async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]:
return await self.db.get_streams(wallet=self.wallet, **constraints) return await self.wallet.db.get_streams(wallet=self.wallet, **constraints)
class CollectionListManager(ClaimListManager): class CollectionListManager(ClaimListManager):
__slots__ = () __slots__ = ()
async def create(self, *args, **kwargs): async def create(
return await super().create(*args, **kwargs) self, name: str, amount: int, holding_address: str, funding_accounts: List[Account],
channel: Optional[Output] = None, **kwargs) -> Transaction:
async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]:
return await self.db.get_collections(wallet=self.wallet, **constraints)
class ChannelListManager(ClaimListManager):
name = 'channel'
__slots__ = ()
async def create(self, name: str, amount: int, account: Account, funding_accounts: List[Account],
claim_address: str, preview=False, **kwargs):
claim = Claim() claim = Claim()
claim.channel.update(**kwargs) claim.collection.update(**kwargs)
tx = await super().create( return await super().create(
name, claim, amount, claim_address, funding_accounts, funding_accounts[0] name, claim, amount, holding_address, funding_accounts, funding_accounts[0], channel
) )
txo = tx.outputs[0]
txo.generate_channel_private_key()
await self.wallet.sign(tx)
if not preview:
account.add_channel_private_key(txo.private_key)
await self.wallet.save()
return tx
async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]: async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]:
return await self.db.get_channels(wallet=self.wallet, **constraints) return await self.wallet.db.get_collections(wallet=self.wallet, **constraints)
async def get_for_signing(self, **kwargs) -> Output:
channel = await self.get(**kwargs)
if not channel.has_private_key:
raise Exception(
f"Couldn't find private key for channel '{channel.claim_name}', can't use channel for signing. "
)
return channel
async def get_for_signing_or_none(self, **kwargs) -> Optional[Output]:
if any(kwargs.values()):
return await self.get_for_signing(**kwargs)
class SupportListManager(BaseListManager): class SupportListManager(BaseListManager):
@ -591,7 +718,7 @@ class SupportListManager(BaseListManager):
) )
async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]: async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]:
return await self.db.get_supports(**constraints) return await self.wallet.db.get_supports(**constraints)
async def get(self, **constraints) -> Output: async def get(self, **constraints) -> Output:
raise NotImplementedError raise NotImplementedError
@ -645,7 +772,7 @@ class PurchaseListManager(BaseListManager):
) )
async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]: async def list(self, **constraints) -> Tuple[List[Output], Optional[int]]:
return await self.db.get_purchases(**constraints) return await self.wallet.db.get_purchases(**constraints)
async def get(self, **constraints) -> Output: async def get(self, **constraints) -> Output:
raise NotImplementedError raise NotImplementedError

View file

@ -2,7 +2,6 @@ from .english import words as en
from .french import words as fr from .french import words as fr
from .italian import words as it from .italian import words as it
from .japanese import words as ja from .japanese import words as ja
from .portuguese import words as pt
from .spanish import words as es from .spanish import words as es
from .chinese_simplified import words as zh from .chinese import words as zh
languages = 'en', 'fr', 'it', 'ja', 'pt', 'es', 'zh languages = 'en', 'fr', 'it', 'ja', 'es', 'zh'