forked from LBRYCommunity/lbry-sdk
01cc4f28e0
Convert some of the calls to inlineCallbacks, which allowed the code to be cleaned up and made more clear
353 lines
14 KiB
Python
353 lines
14 KiB
Python
import logging
|
|
import sqlite3
|
|
import os
|
|
from twisted.internet import defer
|
|
from twisted.python.failure import Failure
|
|
from twisted.enterprise import adbapi
|
|
from lbrynet.core.Error import DuplicateStreamHashError, NoSuchStreamHashError
|
|
from lbrynet.core.sqlite_helpers import rerun_if_locked
|
|
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
class DBEncryptedFileMetadataManager(object):
|
|
"""Store and provide access to LBRY file metadata using sqlite"""
|
|
|
|
def __init__(self, db_dir):
|
|
self.db_dir = db_dir
|
|
self.stream_info_db = None
|
|
self.stream_blob_db = None
|
|
self.stream_desc_db = None
|
|
|
|
def setup(self):
|
|
return self._open_db()
|
|
|
|
def stop(self):
|
|
self.db_conn = None
|
|
return defer.succeed(True)
|
|
|
|
def get_all_streams(self):
|
|
return self._get_all_streams()
|
|
|
|
def save_stream(self, stream_hash, file_name, key, suggested_file_name, blobs):
|
|
d = self._store_stream(stream_hash, file_name, key, suggested_file_name)
|
|
d.addCallback(lambda _: self.add_blobs_to_stream(stream_hash, blobs))
|
|
return d
|
|
|
|
def get_stream_info(self, stream_hash):
|
|
return self._get_stream_info(stream_hash)
|
|
|
|
def check_if_stream_exists(self, stream_hash):
|
|
return self._check_if_stream_exists(stream_hash)
|
|
|
|
def delete_stream(self, stream_hash):
|
|
return self._delete_stream(stream_hash)
|
|
|
|
def add_blobs_to_stream(self, stream_hash, blobs):
|
|
return self._add_blobs_to_stream(stream_hash, blobs, ignore_duplicate_error=True)
|
|
|
|
def get_blobs_for_stream(self, stream_hash, start_blob=None,
|
|
end_blob=None, count=None, reverse=False):
|
|
log.debug("Getting blobs for stream %s. Count is %s", stream_hash, count)
|
|
|
|
def get_positions_of_start_and_end():
|
|
if start_blob is not None:
|
|
d1 = self._get_blob_num_by_hash(stream_hash, start_blob)
|
|
else:
|
|
d1 = defer.succeed(None)
|
|
if end_blob is not None:
|
|
d2 = self._get_blob_num_by_hash(stream_hash, end_blob)
|
|
else:
|
|
d2 = defer.succeed(None)
|
|
|
|
dl = defer.DeferredList([d1, d2])
|
|
|
|
def get_positions(results):
|
|
start_num = None
|
|
end_num = None
|
|
if results[0][0] is True:
|
|
start_num = results[0][1]
|
|
if results[1][0] is True:
|
|
end_num = results[1][1]
|
|
return start_num, end_num
|
|
|
|
dl.addCallback(get_positions)
|
|
return dl
|
|
|
|
def get_blob_infos(nums):
|
|
start_num, end_num = nums
|
|
return self._get_further_blob_infos(stream_hash, start_num, end_num,
|
|
count, reverse)
|
|
|
|
d = get_positions_of_start_and_end()
|
|
d.addCallback(get_blob_infos)
|
|
return d
|
|
|
|
def get_stream_of_blob(self, blob_hash):
|
|
return self._get_stream_of_blobhash(blob_hash)
|
|
|
|
def save_sd_blob_hash_to_stream(self, stream_hash, sd_blob_hash):
|
|
return self._save_sd_blob_hash_to_stream(stream_hash, sd_blob_hash)
|
|
|
|
def get_sd_blob_hashes_for_stream(self, stream_hash):
|
|
return self._get_sd_blob_hashes_for_stream(stream_hash)
|
|
|
|
def _open_db(self):
|
|
# check_same_thread=False is solely to quiet a spurious error that appears to be due
|
|
# to a bug in twisted, where the connection is closed by a different thread than the
|
|
# one that opened it. The individual connections in the pool are not used in multiple
|
|
# threads.
|
|
self.db_conn = adbapi.ConnectionPool(
|
|
"sqlite3",
|
|
(os.path.join(self.db_dir, "lbryfile_info.db")),
|
|
check_same_thread=False)
|
|
|
|
def create_tables(transaction):
|
|
transaction.execute("create table if not exists lbry_files (" +
|
|
" stream_hash text primary key, " +
|
|
" key text, " +
|
|
" stream_name text, " +
|
|
" suggested_file_name text" +
|
|
")")
|
|
transaction.execute("create table if not exists lbry_file_blobs (" +
|
|
" blob_hash text, " +
|
|
" stream_hash text, " +
|
|
" position integer, " +
|
|
" iv text, " +
|
|
" length integer, " +
|
|
" foreign key(stream_hash) references lbry_files(stream_hash)" +
|
|
")")
|
|
transaction.execute("create table if not exists lbry_file_descriptors (" +
|
|
" sd_blob_hash TEXT PRIMARY KEY, " +
|
|
" stream_hash TEXT, " +
|
|
" foreign key(stream_hash) references lbry_files(stream_hash)" +
|
|
")")
|
|
|
|
return self.db_conn.runInteraction(create_tables)
|
|
|
|
@rerun_if_locked
|
|
def _delete_stream(self, stream_hash):
|
|
d = self.db_conn.runQuery(
|
|
"select stream_hash from lbry_files where stream_hash = ?", (stream_hash,))
|
|
d.addCallback(
|
|
lambda result: result[0][0] if result else Failure(NoSuchStreamHashError(stream_hash)))
|
|
|
|
def do_delete(transaction, s_h):
|
|
transaction.execute("delete from lbry_files where stream_hash = ?", (s_h,))
|
|
transaction.execute("delete from lbry_file_blobs where stream_hash = ?", (s_h,))
|
|
transaction.execute("delete from lbry_file_descriptors where stream_hash = ?", (s_h,))
|
|
|
|
d.addCallback(lambda s_h: self.db_conn.runInteraction(do_delete, s_h))
|
|
return d
|
|
|
|
@rerun_if_locked
|
|
def _store_stream(self, stream_hash, name, key, suggested_file_name):
|
|
d = self.db_conn.runQuery("insert into lbry_files values (?, ?, ?, ?)",
|
|
(stream_hash, key, name, suggested_file_name))
|
|
|
|
def check_duplicate(err):
|
|
if err.check(sqlite3.IntegrityError):
|
|
raise DuplicateStreamHashError(stream_hash)
|
|
return err
|
|
|
|
d.addErrback(check_duplicate)
|
|
return d
|
|
|
|
@rerun_if_locked
|
|
def _get_all_streams(self):
|
|
d = self.db_conn.runQuery("select stream_hash from lbry_files")
|
|
d.addCallback(lambda results: [r[0] for r in results])
|
|
return d
|
|
|
|
@rerun_if_locked
|
|
def _get_stream_info(self, stream_hash):
|
|
def get_result(res):
|
|
if res:
|
|
return res[0]
|
|
else:
|
|
raise NoSuchStreamHashError(stream_hash)
|
|
|
|
d = self.db_conn.runQuery(
|
|
"select key, stream_name, suggested_file_name from lbry_files where stream_hash = ?",
|
|
(stream_hash,))
|
|
d.addCallback(get_result)
|
|
return d
|
|
|
|
@rerun_if_locked
|
|
def _check_if_stream_exists(self, stream_hash):
|
|
d = self.db_conn.runQuery(
|
|
"select stream_hash from lbry_files where stream_hash = ?", (stream_hash,))
|
|
d.addCallback(lambda r: True if len(r) else False)
|
|
return d
|
|
|
|
@rerun_if_locked
|
|
def _get_blob_num_by_hash(self, stream_hash, blob_hash):
|
|
d = self.db_conn.runQuery(
|
|
"select position from lbry_file_blobs where stream_hash = ? and blob_hash = ?",
|
|
(stream_hash, blob_hash))
|
|
d.addCallback(lambda r: r[0][0] if len(r) else None)
|
|
return d
|
|
|
|
@rerun_if_locked
|
|
def _get_further_blob_infos(self, stream_hash, start_num, end_num, count=None, reverse=False):
|
|
params = []
|
|
q_string = "select * from ("
|
|
q_string += " select blob_hash, position, iv, length from lbry_file_blobs "
|
|
q_string += " where stream_hash = ? "
|
|
params.append(stream_hash)
|
|
if start_num is not None:
|
|
q_string += " and position > ? "
|
|
params.append(start_num)
|
|
if end_num is not None:
|
|
q_string += " and position < ? "
|
|
params.append(end_num)
|
|
q_string += " order by position "
|
|
if reverse is True:
|
|
q_string += " DESC "
|
|
if count is not None:
|
|
q_string += " limit ? "
|
|
params.append(count)
|
|
q_string += ") order by position"
|
|
# Order by position is done twice so that it always returns them from lowest position to
|
|
# greatest, but the limit by clause can select the 'count' greatest or 'count' least
|
|
return self.db_conn.runQuery(q_string, tuple(params))
|
|
|
|
@rerun_if_locked
|
|
def _add_blobs_to_stream(self, stream_hash, blob_infos, ignore_duplicate_error=False):
|
|
|
|
def add_blobs(transaction):
|
|
for blob_info in blob_infos:
|
|
try:
|
|
transaction.execute("insert into lbry_file_blobs values (?, ?, ?, ?, ?)",
|
|
(blob_info.blob_hash, stream_hash, blob_info.blob_num,
|
|
blob_info.iv, blob_info.length))
|
|
except sqlite3.IntegrityError:
|
|
if ignore_duplicate_error is False:
|
|
raise
|
|
|
|
return self.db_conn.runInteraction(add_blobs)
|
|
|
|
@rerun_if_locked
|
|
def _get_stream_of_blobhash(self, blob_hash):
|
|
d = self.db_conn.runQuery("select stream_hash from lbry_file_blobs where blob_hash = ?",
|
|
(blob_hash,))
|
|
d.addCallback(lambda r: r[0][0] if len(r) else None)
|
|
return d
|
|
|
|
@rerun_if_locked
|
|
def _save_sd_blob_hash_to_stream(self, stream_hash, sd_blob_hash):
|
|
log.info("Saving sd blob hash %s to stream hash %s", str(sd_blob_hash), str(stream_hash))
|
|
d = self.db_conn.runQuery("insert into lbry_file_descriptors values (?, ?)",
|
|
(sd_blob_hash, stream_hash))
|
|
|
|
def ignore_duplicate(err):
|
|
err.trap(sqlite3.IntegrityError)
|
|
log.info("sd blob hash already known")
|
|
|
|
d.addErrback(ignore_duplicate)
|
|
return d
|
|
|
|
@rerun_if_locked
|
|
def _get_sd_blob_hashes_for_stream(self, stream_hash):
|
|
log.debug("Looking up sd blob hashes for stream hash %s", str(stream_hash))
|
|
d = self.db_conn.runQuery(
|
|
"select sd_blob_hash from lbry_file_descriptors where stream_hash = ?",
|
|
(stream_hash,))
|
|
d.addCallback(lambda results: [r[0] for r in results])
|
|
return d
|
|
|
|
|
|
class TempEncryptedFileMetadataManager(object):
|
|
def __init__(self):
|
|
self.streams = {}
|
|
self.stream_blobs = {}
|
|
self.sd_files = {}
|
|
|
|
def setup(self):
|
|
return defer.succeed(True)
|
|
|
|
def stop(self):
|
|
return defer.succeed(True)
|
|
|
|
def get_all_streams(self):
|
|
return defer.succeed(self.streams.keys())
|
|
|
|
def save_stream(self, stream_hash, file_name, key, suggested_file_name, blobs):
|
|
self.streams[stream_hash] = {'suggested_file_name': suggested_file_name,
|
|
'stream_name': file_name,
|
|
'key': key}
|
|
d = self.add_blobs_to_stream(stream_hash, blobs)
|
|
d.addCallback(lambda _: stream_hash)
|
|
return d
|
|
|
|
def get_stream_info(self, stream_hash):
|
|
if stream_hash in self.streams:
|
|
stream_info = self.streams[stream_hash]
|
|
return defer.succeed([stream_info['key'], stream_info['stream_name'],
|
|
stream_info['suggested_file_name']])
|
|
return defer.succeed(None)
|
|
|
|
def delete_stream(self, stream_hash):
|
|
if stream_hash in self.streams:
|
|
del self.streams[stream_hash]
|
|
for (s_h, b_h) in self.stream_blobs.keys():
|
|
if s_h == stream_hash:
|
|
del self.stream_blobs[(s_h, b_h)]
|
|
return defer.succeed(True)
|
|
|
|
def add_blobs_to_stream(self, stream_hash, blobs):
|
|
assert stream_hash in self.streams, "Can't add blobs to a stream that isn't known"
|
|
for blob in blobs:
|
|
info = {}
|
|
info['blob_num'] = blob.blob_num
|
|
info['length'] = blob.length
|
|
info['iv'] = blob.iv
|
|
self.stream_blobs[(stream_hash, blob.blob_hash)] = info
|
|
return defer.succeed(True)
|
|
|
|
def get_blobs_for_stream(self, stream_hash, start_blob=None,
|
|
end_blob=None, count=None, reverse=False):
|
|
|
|
if start_blob is not None:
|
|
start_num = self._get_blob_num_by_hash(stream_hash, start_blob)
|
|
else:
|
|
start_num = None
|
|
if end_blob is not None:
|
|
end_num = self._get_blob_num_by_hash(stream_hash, end_blob)
|
|
else:
|
|
end_num = None
|
|
return self._get_further_blob_infos(stream_hash, start_num, end_num, count, reverse)
|
|
|
|
def get_stream_of_blob(self, blob_hash):
|
|
for (s_h, b_h) in self.stream_blobs.iterkeys():
|
|
if b_h == blob_hash:
|
|
return defer.succeed(s_h)
|
|
return defer.succeed(None)
|
|
|
|
def _get_further_blob_infos(self, stream_hash, start_num, end_num, count=None, reverse=False):
|
|
blob_infos = []
|
|
for (s_h, b_h), info in self.stream_blobs.iteritems():
|
|
if stream_hash == s_h:
|
|
position = info['blob_num']
|
|
length = info['length']
|
|
iv = info['iv']
|
|
if (start_num is None) or (position > start_num):
|
|
if (end_num is None) or (position < end_num):
|
|
blob_infos.append((b_h, position, iv, length))
|
|
blob_infos.sort(key=lambda i: i[1], reverse=reverse)
|
|
if count is not None:
|
|
blob_infos = blob_infos[:count]
|
|
return defer.succeed(blob_infos)
|
|
|
|
def _get_blob_num_by_hash(self, stream_hash, blob_hash):
|
|
if (stream_hash, blob_hash) in self.stream_blobs:
|
|
return defer.succeed(self.stream_blobs[(stream_hash, blob_hash)]['blob_num'])
|
|
|
|
def save_sd_blob_hash_to_stream(self, stream_hash, sd_blob_hash):
|
|
self.sd_files[sd_blob_hash] = stream_hash
|
|
return defer.succeed(True)
|
|
|
|
def get_sd_blob_hashes_for_stream(self, stream_hash):
|
|
return defer.succeed(
|
|
[sd_hash for sd_hash, s_h in self.sd_files.iteritems() if stream_hash == s_h])
|