# Forked from LBRYCommunity/lbry-sdk (227 lines, 8.3 KiB, Python).
import os
|
|
import stat
|
|
import json
|
|
import asyncio
|
|
import logging
|
|
from typing import Optional, Dict
|
|
|
|
from lbry.db import Database
|
|
|
|
from .wallet import Wallet
|
|
from .account import SingleKey, HierarchicalDeterministic
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
class WalletManager:
    """Keeps track of the wallets loaded for a ledger and persists them.

    Wallets are held in ``self.wallets`` keyed by wallet id. Persistence is
    delegated to ``self.storage``, which is chosen from the ledger's
    ``wallet_storage`` configuration value (``"file"`` or ``"database"``).
    """

    def __init__(self, db: Database):
        """Bind the manager to ``db`` and select the storage backend.

        Raises:
            Exception: if ``wallet_storage`` is neither ``"file"`` nor
                ``"database"``.
        """
        self.db = db
        self.ledger = db.ledger
        self.wallets: Dict[str, Wallet] = {}
        if self.ledger.conf.wallet_storage == "file":
            self.storage = FileWallet(self.db, self.ledger.conf.wallet_dir)
        elif self.ledger.conf.wallet_storage == "database":
            self.storage = DatabaseWallet(self.db)
        else:
            raise Exception(f"Unknown wallet storage format: {self.ledger.conf.wallet_storage}")

    def __getitem__(self, wallet_id: str) -> Wallet:
        """Return the loaded wallet with ``wallet_id``.

        Raises:
            ValueError: if no wallet with that id is loaded.
        """
        try:
            return self.wallets[wallet_id]
        except KeyError:
            # The KeyError carries no extra information; suppress it so the
            # traceback shows only the user-facing error.
            raise ValueError(f"Couldn't find wallet: {wallet_id}.") from None

    async def generate_addresses(self):
        """Await ``generate_addresses()`` on every loaded wallet, in order."""
        for wallet in self.wallets.values():
            await wallet.generate_addresses()

    @property
    def default(self) -> Optional[Wallet]:
        """The first wallet that was added, or ``None`` if none are loaded."""
        return next(iter(self.wallets.values()), None)

    def get_or_default(self, wallet_id: Optional[str]) -> Wallet:
        """Return the wallet named by ``wallet_id``, or the default wallet.

        Raises:
            ValueError: if ``wallet_id`` is given but not loaded, or if it is
                empty/``None`` and there is no default wallet.
        """
        if wallet_id:
            return self[wallet_id]
        wallet = self.default
        if not wallet:
            raise ValueError("No wallets available.")
        return wallet

    def get_or_default_for_spending(self, wallet_id: Optional[str]) -> Wallet:
        """Like :meth:`get_or_default`, but refuse a locked wallet.

        Raises:
            ValueError: if the wallet cannot be found or ``is_locked`` is set.
        """
        wallet = self.get_or_default(wallet_id)
        if wallet.is_locked:
            raise ValueError("Cannot spend funds with locked wallet, unlock first.")
        return wallet

    async def initialize(self):
        """Load the configured wallets and create defaults when requested.

        Every id in the ``wallets`` config setting is loaded from storage
        (duplicates are skipped with a warning). If nothing could be loaded
        and ``create_default_wallet`` is set, a wallet is created under the
        first configured id. If a default wallet exists but has no accounts
        and ``create_default_account`` is set, an account is generated.

        Raises:
            AssertionError: if a default wallet has to be created but the
                ``wallets`` setting is empty or does not start with
                ``"default_wallet"``.
        """
        conf = self.ledger.conf
        for wallet_id in conf.wallets:
            if wallet_id in self.wallets:
                log.warning("Ignoring duplicate wallet_id in config: %s", wallet_id)
                continue
            await self.load(wallet_id)
        default_wallet = self.default
        if default_wallet is None:
            if conf.create_default_wallet:
                # Raised explicitly rather than via `assert` so the check is
                # not stripped under `python -O` and an empty list gives the
                # same descriptive error instead of an IndexError.
                if not conf.wallets or conf.wallets[0] != "default_wallet":
                    raise AssertionError(
                        "Requesting to generate the default wallet but the 'wallets' "
                        "config setting does not include 'default_wallet' as the first wallet."
                    )
                await self.create(
                    conf.wallets[0], 'Wallet',
                    create_account=conf.create_default_account
                )
        elif not default_wallet.has_accounts and conf.create_default_account:
            await default_wallet.accounts.generate()

    async def load(self, wallet_id: str) -> Optional[Wallet]:
        """Fetch ``wallet_id`` from storage and register it.

        Returns:
            The registered wallet, or ``None`` if storage has no such wallet.
        """
        wallet = await self.storage.get(wallet_id)
        if wallet is None:
            return None
        return self.add(wallet)

    async def create(
            self, wallet_id: str, name: str = "",
            create_account=False, language="en", single_key=False
    ) -> Wallet:
        """Create a new wallet, save it to storage and register it.

        Args:
            wallet_id: id of the new wallet.
            name: name passed to the ``Wallet`` constructor.
            create_account: when true, generate an account in the new wallet.
            language: forwarded to account generation.
            single_key: selects the ``SingleKey`` address generator instead
                of ``HierarchicalDeterministic`` for the generated account.

        Raises:
            Exception: if the id is already loaded or already in storage.
        """
        if wallet_id in self.wallets:
            raise Exception(f"Wallet with id '{wallet_id}' is already loaded and cannot be created.")
        if await self.storage.exists(wallet_id):
            raise Exception(f"Wallet '{wallet_id}' already exists, use 'wallet_add' to load wallet.")
        wallet = Wallet(wallet_id, self.db, name)
        if create_account:
            await wallet.accounts.generate(language=language, address_generator={
                'name': SingleKey.name if single_key else HierarchicalDeterministic.name
            })
        await self.storage.save(wallet)
        return self.add(wallet)

    def add(self, wallet: Wallet) -> Wallet:
        """Register ``wallet`` and save it to storage on every change event."""
        self.wallets[wallet.id] = wallet
        wallet.on_change.listen(lambda _: self.storage.save(wallet))
        return wallet

    async def _report_state(self):
        """Log a summary (balance, addresses, channels, claims) per account.

        Any failure other than task cancellation is logged, not raised.
        """
        # NOTE(review): neither `self.accounts` nor `dewies_to_lbc` is defined
        # in the visible part of this module; as written this presumably
        # always ends up in the `except Exception` branch -- confirm and
        # either wire these up or remove the method.
        try:
            for account in self.accounts:
                balance = dewies_to_lbc(await account.get_balance(include_claims=True))
                _, channel_count = await account.get_channels(limit=1)
                claim_count = await account.get_claim_count()
                if isinstance(account.receiving, SingleKey):
                    log.info("Loaded single key account %s with %s LBC. "
                             "%d channels, %d certificates and %d claims",
                             account.id, balance, channel_count, len(account.channel_keys), claim_count)
                else:
                    total_receiving = len(await account.receiving.get_addresses())
                    total_change = len(await account.change.get_addresses())
                    log.info("Loaded account %s with %s LBC, %d receiving addresses (gap: %d), "
                             "%d change addresses (gap: %d), %d channels, %d certificates and %d claims. ",
                             account.id, balance, total_receiving, account.receiving.gap, total_change,
                             account.change.gap, channel_count, len(account.channel_keys), claim_count)
        except asyncio.CancelledError:
            # Before Python 3.8 CancelledError derives from Exception, so it
            # must be re-raised ahead of the broad handler below.
            raise
        except Exception:
            log.exception(
                'Failed to display wallet state, please file issue '
                'for this bug along with the traceback you see below:'
            )
|
|
|
|
|
|
class WalletStorage:
    """Interface implemented by the wallet persistence backends.

    Every method raises ``NotImplementedError`` here; subclasses provide the
    actual behaviour.
    """

    async def prepare(self):
        """Perform any setup the backend needs before it is used."""
        raise NotImplementedError

    async def exists(self, wallet_id: str) -> bool:
        """Report whether a wallet with ``wallet_id`` is already stored."""
        raise NotImplementedError

    async def get(self, wallet_id: str) -> Optional[Wallet]:
        """Return the stored wallet for ``wallet_id``.

        The implementations in this module return ``None`` when the wallet
        is not found.
        """
        raise NotImplementedError

    async def save(self, wallet: Wallet):
        """Persist ``wallet``."""
        raise NotImplementedError
|
|
|
|
|
|
class FileWallet(WalletStorage):
    """Wallet storage that keeps one file per wallet inside ``wallet_dir``.

    Files are named after the wallet id and are parsed as JSON when read.
    Blocking reads and writes run in the event loop's default executor.
    """

    def __init__(self, db, wallet_dir):
        self.db = db
        self.wallet_dir = wallet_dir

    def get_wallet_path(self, wallet_id: str):
        """Return the path of the file that stores ``wallet_id``."""
        return os.path.join(self.wallet_dir, wallet_id)

    async def prepare(self):
        """Make sure the wallet directory exists (runs in the executor)."""
        await asyncio.get_running_loop().run_in_executor(
            None, self.sync_ensure_wallets_directory_exists
        )

    def sync_ensure_wallets_directory_exists(self):
        """Create ``wallet_dir`` unless something already exists at that path."""
        # Attempt the creation and tolerate "already exists" instead of
        # checking first, so two concurrent callers cannot race between the
        # check and the mkdir.
        try:
            os.mkdir(self.wallet_dir)
        except FileExistsError:
            pass

    async def exists(self, wallet_id: str) -> bool:
        """Report whether a file for ``wallet_id`` is present."""
        return os.path.exists(self.get_wallet_path(wallet_id))

    async def get(self, wallet_id: str) -> Optional[Wallet]:
        """Load ``wallet_id`` from disk.

        Returns:
            The wallet built by ``Wallet.from_dict``, or ``None`` when there
            is no file for that id (or the file contains JSON ``null``).
        """
        wallet_dict = await asyncio.get_running_loop().run_in_executor(
            None, self.sync_read, wallet_id
        )
        if wallet_dict is None:
            return None
        return await Wallet.from_dict(wallet_id, wallet_dict, self.db)

    def sync_read(self, wallet_id):
        """Read and JSON-decode the wallet file; ``None`` if it is missing."""
        try:
            with open(self.get_wallet_path(wallet_id), 'r') as f:
                return json.load(f)
        except FileNotFoundError:
            return None

    async def save(self, wallet: Wallet):
        """Write ``wallet`` to disk (runs in the executor)."""
        return await asyncio.get_running_loop().run_in_executor(
            None, self.sync_write, wallet
        )

    def sync_write(self, wallet: Wallet):
        """Write ``wallet`` to a temp file, then move it over the wallet file.

        The serialized wallet is flushed and fsynced to a ``.tmp.<id>`` file
        in ``wallet_dir`` first. An existing wallet file keeps its current
        mode; a new one gets owner read/write only.
        """
        temp_path = os.path.join(self.wallet_dir, f".tmp.{os.path.basename(wallet.id)}")
        with open(temp_path, "w") as f:
            f.write(wallet.to_serialized())
            f.flush()
            os.fsync(f.fileno())

        wallet_path = self.get_wallet_path(wallet.id)
        if os.path.exists(wallet_path):
            mode = os.stat(wallet_path).st_mode
        else:
            mode = stat.S_IREAD | stat.S_IWRITE
        # Set the permissions on the temp file before it is moved, so the
        # wallet path never holds a file with the wrong mode.
        os.chmod(temp_path, mode)
        # os.replace overwrites an existing destination on every platform,
        # so there is no remove-then-rename window in which a crash could
        # leave no wallet file at all.
        os.replace(temp_path, wallet_path)
|
|
|
|
|
|
class DatabaseWallet(WalletStorage):
    """Wallet storage backed by the wallet methods of the ``Database`` object.

    Uses ``has_wallet``, ``get_wallet`` and ``add_wallet`` on ``db``; the
    stored record's ``'data'`` field is decoded as JSON when loading.
    """

    def __init__(self, db: 'Database'):
        self.db = db

    async def prepare(self):
        """Nothing to set up for this backend."""
        return None

    async def exists(self, wallet_id: str) -> bool:
        """Ask the database whether ``wallet_id`` is stored."""
        found = await self.db.has_wallet(wallet_id)
        return found

    async def get(self, wallet_id: str) -> Wallet:
        """Load ``wallet_id`` from the database.

        Returns ``None`` when the database has no record for the id or the
        stored JSON decodes to ``null``; otherwise the wallet built by
        ``Wallet.from_dict``.
        """
        record = await self.db.get_wallet(wallet_id)
        if not record:
            return None
        decoded = json.loads(record['data'])
        if decoded is None:
            return None
        return await Wallet.from_dict(wallet_id, decoded, self.db)

    async def save(self, wallet: Wallet):
        """Store the serialized wallet under its id."""
        store = self.db.add_wallet
        key, payload = wallet.id, wallet.to_serialized()
        await store(key, payload)
|