Merge branch 'remove_tempblobmanager'

This commit is contained in:
Jack Robison 2017-08-03 22:27:21 -04:00
commit 8d8946b96e
No known key found for this signature in database
GPG key ID: 284699E7404E3CFF
7 changed files with 67 additions and 257 deletions

View file

@ -33,7 +33,7 @@ at anytime.
* *
### Removed ### Removed
* * Removed TempBlobManager
* *

View file

@ -4,82 +4,17 @@ import time
import sqlite3 import sqlite3
from twisted.internet import threads, defer from twisted.internet import threads, defer
from twisted.python.failure import Failure
from twisted.enterprise import adbapi from twisted.enterprise import adbapi
from lbrynet.core.HashBlob import BlobFile, TempBlob, BlobFileCreator, TempBlobCreator from lbrynet.core.HashBlob import BlobFile, BlobFileCreator
from lbrynet.core.server.DHTHashAnnouncer import DHTHashSupplier from lbrynet.core.server.DHTHashAnnouncer import DHTHashSupplier
from lbrynet.core.Error import NoSuchBlobError
from lbrynet.core.sqlite_helpers import rerun_if_locked from lbrynet.core.sqlite_helpers import rerun_if_locked
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class DiskBlobManager(DHTHashSupplier):
class BlobManager(DHTHashSupplier):
"""This class is subclassed by classes which keep track of which blobs are available
and which give access to new/existing blobs"""
def __init__(self, hash_announcer):
DHTHashSupplier.__init__(self, hash_announcer)
def setup(self):
pass
def get_blob(self, blob_hash, length=None):
pass
def get_blob_creator(self):
pass
def _make_new_blob(self, blob_hash, length):
pass
def blob_completed(self, blob, next_announce_time=None):
pass
def completed_blobs(self, blobhashes_to_check):
pass
def hashes_to_announce(self):
pass
def creator_finished(self, blob_creator):
pass
def delete_blob(self, blob_hash):
pass
def blob_requested(self, blob_hash):
pass
def blob_downloaded(self, blob_hash):
pass
def blob_searched_on(self, blob_hash):
pass
def blob_paid_for(self, blob_hash, amount):
pass
def get_all_verified_blobs(self):
pass
def add_blob_to_download_history(self, blob_hash, host, rate):
pass
def add_blob_to_upload_history(self, blob_hash, host, rate):
pass
def _immediate_announce(self, blob_hashes):
if self.hash_announcer:
return self.hash_announcer.immediate_announce(blob_hashes)
# TODO: Having different managers for different blobs breaks the
# abstraction of a HashBlob. Why should the management of blobs
# care what kind of Blob it has?
class DiskBlobManager(BlobManager):
"""This class stores blobs on the hard disk""" """This class stores blobs on the hard disk"""
def __init__(self, hash_announcer, blob_dir, db_dir): def __init__(self, hash_announcer, blob_dir, db_dir):
BlobManager.__init__(self, hash_announcer) DHTHashSupplier.__init__(self, hash_announcer)
self.blob_dir = blob_dir self.blob_dir = blob_dir
self.db_file = os.path.join(db_dir, "blobs.db") self.db_file = os.path.join(db_dir, "blobs.db")
self.db_conn = adbapi.ConnectionPool('sqlite3', self.db_file, check_same_thread=False) self.db_conn = adbapi.ConnectionPool('sqlite3', self.db_file, check_same_thread=False)
@ -120,6 +55,10 @@ class DiskBlobManager(BlobManager):
self.blobs[blob_hash] = blob self.blobs[blob_hash] = blob
return defer.succeed(blob) return defer.succeed(blob)
def _immediate_announce(self, blob_hashes):
if self.hash_announcer:
return self.hash_announcer.immediate_announce(blob_hashes)
def blob_completed(self, blob, next_announce_time=None): def blob_completed(self, blob, next_announce_time=None):
if next_announce_time is None: if next_announce_time is None:
next_announce_time = self.get_next_announce_time() next_announce_time = self.get_next_announce_time()
@ -293,134 +232,3 @@ class DiskBlobManager(BlobManager):
return d return d
# TODO: Having different managers for different blobs breaks the
# abstraction of a HashBlob. Why should the management of blobs
# care what kind of Blob it has?
class TempBlobManager(BlobManager):
"""This class stores blobs in memory"""
def __init__(self, hash_announcer):
BlobManager.__init__(self, hash_announcer)
self.blob_type = TempBlob
self.blob_creator_type = TempBlobCreator
self.blobs = {}
self.blob_next_announces = {}
self.blob_hashes_to_delete = {} # {blob_hash: being_deleted (True/False)}
self._next_manage_call = None
def setup(self):
self._manage()
return defer.succeed(True)
def stop(self):
if self._next_manage_call is not None and self._next_manage_call.active():
self._next_manage_call.cancel()
self._next_manage_call = None
def get_blob(self, blob_hash, length=None):
if blob_hash in self.blobs:
return defer.succeed(self.blobs[blob_hash])
return self._make_new_blob(blob_hash, length)
def get_blob_creator(self):
return self.blob_creator_type(self)
def _make_new_blob(self, blob_hash, length=None):
blob = self.blob_type(blob_hash, length)
self.blobs[blob_hash] = blob
return defer.succeed(blob)
def blob_completed(self, blob, next_announce_time=None):
if next_announce_time is None:
next_announce_time = time.time()
self.blob_next_announces[blob.blob_hash] = next_announce_time
return defer.succeed(True)
def completed_blobs(self, blobhashes_to_check):
blobs = [
b.blob_hash for b in self.blobs.itervalues()
if b.blob_hash in blobhashes_to_check and b.is_validated()
]
return defer.succeed(blobs)
def get_all_verified_blobs(self):
d = self.completed_blobs(self.blobs)
return d
def hashes_to_announce(self):
now = time.time()
blobs = [
blob_hash for blob_hash, announce_time in self.blob_next_announces.iteritems()
if announce_time < now
]
next_announce_time = self.get_next_announce_time(len(blobs))
for b in blobs:
self.blob_next_announces[b] = next_announce_time
return defer.succeed(blobs)
def creator_finished(self, blob_creator):
assert blob_creator.blob_hash is not None
assert blob_creator.blob_hash not in self.blobs
assert blob_creator.length is not None
new_blob = self.blob_type(blob_creator.blob_hash, blob_creator.length)
# TODO: change this; its breaks the encapsulation of the
# blob. Maybe better would be to have the blob_creator
# produce a blob.
new_blob.data_buffer = blob_creator.data_buffer
new_blob._verified = True
self.blobs[blob_creator.blob_hash] = new_blob
self._immediate_announce([blob_creator.blob_hash])
next_announce_time = self.get_next_announce_time()
d = self.blob_completed(new_blob, next_announce_time)
d.addCallback(lambda _: new_blob)
return d
def delete_blobs(self, blob_hashes):
for blob_hash in blob_hashes:
if not blob_hash in self.blob_hashes_to_delete:
self.blob_hashes_to_delete[blob_hash] = False
def immediate_announce_all_blobs(self):
if self.hash_announcer:
return self.hash_announcer.immediate_announce(self.blobs.iterkeys())
def _manage(self):
from twisted.internet import reactor
d = self._delete_blobs_marked_for_deletion()
def set_next_manage_call():
log.info("Setting the next manage call in %s", str(self))
self._next_manage_call = reactor.callLater(1, self._manage)
d.addCallback(lambda _: set_next_manage_call())
def _delete_blobs_marked_for_deletion(self):
def remove_from_list(b_h):
del self.blob_hashes_to_delete[b_h]
log.info("Deleted blob %s", blob_hash)
return b_h
def set_not_deleting(err, b_h):
log.warning("Failed to delete blob %s. Reason: %s", str(b_h), err.getErrorMessage())
self.blob_hashes_to_delete[b_h] = False
return b_h
ds = []
for blob_hash, being_deleted in self.blob_hashes_to_delete.items():
if being_deleted is False:
if blob_hash in self.blobs:
self.blob_hashes_to_delete[blob_hash] = True
log.info("Found a blob marked for deletion: %s", blob_hash)
blob = self.blobs[blob_hash]
d = blob.delete()
d.addCallbacks(lambda _: remove_from_list(blob_hash), set_not_deleting,
errbackArgs=(blob_hash,))
ds.append(d)
else:
remove_from_list(blob_hash)
d = defer.fail(Failure(NoSuchBlobError(blob_hash)))
log.warning("Blob %s cannot be deleted because it is unknown")
ds.append(d)
return defer.DeferredList(ds)

View file

@ -1,6 +1,6 @@
import logging import logging
import miniupnpc import miniupnpc
from lbrynet.core.BlobManager import DiskBlobManager, TempBlobManager from lbrynet.core.BlobManager import DiskBlobManager
from lbrynet.dht import node from lbrynet.dht import node
from lbrynet.core.PeerManager import PeerManager from lbrynet.core.PeerManager import PeerManager
from lbrynet.core.RateLimiter import RateLimiter from lbrynet.core.RateLimiter import RateLimiter
@ -294,7 +294,8 @@ class Session(object):
if self.blob_manager is None: if self.blob_manager is None:
if self.blob_dir is None: if self.blob_dir is None:
self.blob_manager = TempBlobManager(self.hash_announcer) raise Exception(
"TempBlobManager is no longer supported, specify BlobManager or db_dir")
else: else:
self.blob_manager = DiskBlobManager(self.hash_announcer, self.blob_manager = DiskBlobManager(self.hash_announcer,
self.blob_dir, self.blob_dir,

View file

@ -33,6 +33,7 @@ from lbrynet.core.server.BlobRequestHandler import BlobRequestHandlerFactory
from lbrynet.core.server.ServerProtocol import ServerProtocolFactory from lbrynet.core.server.ServerProtocol import ServerProtocolFactory
from tests import mocks from tests import mocks
from tests.util import mk_db_and_blob_dir, rm_db_and_blob_dir
FakeNode = mocks.Node FakeNode = mocks.Node
FakeWallet = mocks.Wallet FakeWallet = mocks.Wallet
@ -45,7 +46,6 @@ DummyBlobAvailabilityTracker = mocks.BlobAvailabilityTracker
log_format = "%(funcName)s(): %(message)s" log_format = "%(funcName)s(): %(message)s"
logging.basicConfig(level=logging.CRITICAL, format=log_format) logging.basicConfig(level=logging.CRITICAL, format=log_format)
def require_system(system): def require_system(system):
def wrapper(fn): def wrapper(fn):
return fn return fn
@ -111,13 +111,12 @@ class LbryUploader(object):
hash_announcer = FakeAnnouncer() hash_announcer = FakeAnnouncer()
rate_limiter = RateLimiter() rate_limiter = RateLimiter()
self.sd_identifier = StreamDescriptorIdentifier() self.sd_identifier = StreamDescriptorIdentifier()
db_dir = "server" self.db_dir, self.blob_dir = mk_db_and_blob_dir()
os.mkdir(db_dir)
self.session = Session( self.session = Session(
conf.ADJUSTABLE_SETTINGS['data_rate'][1], db_dir=db_dir, lbryid="abcd", conf.ADJUSTABLE_SETTINGS['data_rate'][1], db_dir=self.db_dir, blob_dir=self.blob_dir,
peer_finder=peer_finder, hash_announcer=hash_announcer, peer_port=5553, lbryid="abcd", peer_finder=peer_finder, hash_announcer=hash_announcer,
use_upnp=False, rate_limiter=rate_limiter, wallet=wallet, peer_port=5553, use_upnp=False, rate_limiter=rate_limiter, wallet=wallet,
blob_tracker_class=DummyBlobAvailabilityTracker, blob_tracker_class=DummyBlobAvailabilityTracker,
dht_node_class=Node, is_generous=self.is_generous) dht_node_class=Node, is_generous=self.is_generous)
stream_info_manager = TempEncryptedFileMetadataManager() stream_info_manager = TempEncryptedFileMetadataManager()
@ -173,6 +172,7 @@ class LbryUploader(object):
self.kill_check.stop() self.kill_check.stop()
self.dead_event.set() self.dead_event.set()
dl = defer.DeferredList(ds) dl = defer.DeferredList(ds)
dl.addCallback(lambda _: rm_db_and_blob_dir(self.db_dir, self.blob_dir))
dl.addCallback(lambda _: self.reactor.stop()) dl.addCallback(lambda _: self.reactor.stop())
return dl return dl
@ -216,15 +216,11 @@ def start_lbry_reuploader(sd_hash, kill_event, dead_event,
rate_limiter = RateLimiter() rate_limiter = RateLimiter()
sd_identifier = StreamDescriptorIdentifier() sd_identifier = StreamDescriptorIdentifier()
db_dir = "server_" + str(n) db_dir, blob_dir = mk_db_and_blob_dir()
blob_dir = os.path.join(db_dir, "blobfiles")
os.mkdir(db_dir)
os.mkdir(blob_dir)
session = Session(conf.ADJUSTABLE_SETTINGS['data_rate'][1], db_dir=db_dir, session = Session(conf.ADJUSTABLE_SETTINGS['data_rate'][1], db_dir=db_dir,
lbryid="abcd" + str(n), lbryid="abcd" + str(n),
peer_finder=peer_finder, hash_announcer=hash_announcer, peer_finder=peer_finder, hash_announcer=hash_announcer,
blob_dir=None, peer_port=peer_port, blob_dir=blob_dir, peer_port=peer_port,
use_upnp=False, rate_limiter=rate_limiter, wallet=wallet, use_upnp=False, rate_limiter=rate_limiter, wallet=wallet,
blob_tracker_class=DummyBlobAvailabilityTracker, blob_tracker_class=DummyBlobAvailabilityTracker,
is_generous=conf.ADJUSTABLE_SETTINGS['is_generous_host'][1]) is_generous=conf.ADJUSTABLE_SETTINGS['is_generous_host'][1])
@ -289,6 +285,7 @@ def start_lbry_reuploader(sd_hash, kill_event, dead_event,
ds.append(lbry_file_manager.stop()) ds.append(lbry_file_manager.stop())
if server_port: if server_port:
ds.append(server_port.stopListening()) ds.append(server_port.stopListening())
ds.append(rm_db_and_blob_dir(db_dir, blob_dir))
kill_check.stop() kill_check.stop()
dead_event.set() dead_event.set()
dl = defer.DeferredList(ds) dl = defer.DeferredList(ds)
@ -327,13 +324,11 @@ def start_blob_uploader(blob_hash_queue, kill_event, dead_event, slow, is_genero
if slow is True: if slow is True:
peer_port = 5553 peer_port = 5553
db_dir = "server1"
else: else:
peer_port = 5554 peer_port = 5554
db_dir = "server2"
blob_dir = os.path.join(db_dir, "blobfiles")
os.mkdir(db_dir) db_dir, blob_dir = mk_db_and_blob_dir()
os.mkdir(blob_dir)
session = Session(conf.ADJUSTABLE_SETTINGS['data_rate'][1], db_dir=db_dir, lbryid="efgh", session = Session(conf.ADJUSTABLE_SETTINGS['data_rate'][1], db_dir=db_dir, lbryid="efgh",
peer_finder=peer_finder, hash_announcer=hash_announcer, peer_finder=peer_finder, hash_announcer=hash_announcer,
@ -385,6 +380,7 @@ def start_blob_uploader(blob_hash_queue, kill_event, dead_event, slow, is_genero
dead_event.set() dead_event.set()
dl = defer.DeferredList(ds) dl = defer.DeferredList(ds)
dl.addCallback(lambda _: reactor.stop()) dl.addCallback(lambda _: reactor.stop())
dl.addCallback(lambda _: rm_db_and_blob_dir(db_dir, blob_dir))
return dl return dl
def check_for_kill(): def check_for_kill():
@ -509,14 +505,10 @@ class TestTransfer(TestCase):
rate_limiter = DummyRateLimiter() rate_limiter = DummyRateLimiter()
sd_identifier = StreamDescriptorIdentifier() sd_identifier = StreamDescriptorIdentifier()
db_dir = "client" db_dir, blob_dir = mk_db_and_blob_dir()
blob_dir = os.path.join(db_dir, "blobfiles")
os.mkdir(db_dir)
os.mkdir(blob_dir)
self.session = Session( self.session = Session(
conf.ADJUSTABLE_SETTINGS['data_rate'][1], db_dir=db_dir, lbryid="abcd", conf.ADJUSTABLE_SETTINGS['data_rate'][1], db_dir=db_dir,
peer_finder=peer_finder, hash_announcer=hash_announcer, lbryid="abcd", peer_finder=peer_finder, hash_announcer=hash_announcer,
blob_dir=blob_dir, peer_port=5553, blob_dir=blob_dir, peer_port=5553,
use_upnp=False, rate_limiter=rate_limiter, wallet=wallet, use_upnp=False, rate_limiter=rate_limiter, wallet=wallet,
blob_tracker_class=DummyBlobAvailabilityTracker, blob_tracker_class=DummyBlobAvailabilityTracker,
@ -572,6 +564,7 @@ class TestTransfer(TestCase):
logging.info("Client is shutting down") logging.info("Client is shutting down")
d.addCallback(lambda _: print_shutting_down()) d.addCallback(lambda _: print_shutting_down())
d.addCallback(lambda _: rm_db_and_blob_dir(db_dir, blob_dir))
d.addCallback(lambda _: arg) d.addCallback(lambda _: arg)
return d return d
@ -604,11 +597,7 @@ class TestTransfer(TestCase):
hash_announcer = FakeAnnouncer() hash_announcer = FakeAnnouncer()
rate_limiter = DummyRateLimiter() rate_limiter = DummyRateLimiter()
db_dir = "client" db_dir, blob_dir = mk_db_and_blob_dir()
blob_dir = os.path.join(db_dir, "blobfiles")
os.mkdir(db_dir)
os.mkdir(blob_dir)
self.session = Session( self.session = Session(
conf.ADJUSTABLE_SETTINGS['data_rate'][1], db_dir=db_dir, lbryid="abcd", conf.ADJUSTABLE_SETTINGS['data_rate'][1], db_dir=db_dir, lbryid="abcd",
peer_finder=peer_finder, hash_announcer=hash_announcer, peer_finder=peer_finder, hash_announcer=hash_announcer,
@ -660,6 +649,7 @@ class TestTransfer(TestCase):
logging.info("Client is shutting down") logging.info("Client is shutting down")
dl.addCallback(lambda _: print_shutting_down()) dl.addCallback(lambda _: print_shutting_down())
dl.addCallback(lambda _: rm_db_and_blob_dir(db_dir, blob_dir))
dl.addCallback(lambda _: arg) dl.addCallback(lambda _: arg)
return dl return dl
@ -686,11 +676,7 @@ class TestTransfer(TestCase):
downloaders = [] downloaders = []
db_dir = "client" db_dir, blob_dir = mk_db_and_blob_dir()
blob_dir = os.path.join(db_dir, "blobfiles")
os.mkdir(db_dir)
os.mkdir(blob_dir)
self.session = Session(conf.ADJUSTABLE_SETTINGS['data_rate'][1], db_dir=db_dir, self.session = Session(conf.ADJUSTABLE_SETTINGS['data_rate'][1], db_dir=db_dir,
lbryid="abcd", peer_finder=peer_finder, lbryid="abcd", peer_finder=peer_finder,
hash_announcer=hash_announcer, blob_dir=blob_dir, peer_port=5553, hash_announcer=hash_announcer, blob_dir=blob_dir, peer_port=5553,
@ -781,6 +767,7 @@ class TestTransfer(TestCase):
logging.info("Client is shutting down") logging.info("Client is shutting down")
d.addCallback(lambda _: print_shutting_down()) d.addCallback(lambda _: print_shutting_down())
d.addCallback(lambda _: rm_db_and_blob_dir(db_dir, blob_dir))
d.addCallback(lambda _: arg) d.addCallback(lambda _: arg)
return d return d
@ -811,14 +798,10 @@ class TestTransfer(TestCase):
rate_limiter = DummyRateLimiter() rate_limiter = DummyRateLimiter()
sd_identifier = StreamDescriptorIdentifier() sd_identifier = StreamDescriptorIdentifier()
db_dir = "client" db_dir, blob_dir = mk_db_and_blob_dir()
blob_dir = os.path.join(db_dir, "blobfiles")
os.mkdir(db_dir)
os.mkdir(blob_dir)
self.session = Session(conf.ADJUSTABLE_SETTINGS['data_rate'][1], db_dir=db_dir, self.session = Session(conf.ADJUSTABLE_SETTINGS['data_rate'][1], db_dir=db_dir,
lbryid="abcd", peer_finder=peer_finder, lbryid="abcd", peer_finder=peer_finder,
hash_announcer=hash_announcer, blob_dir=None, hash_announcer=hash_announcer, blob_dir=blob_dir,
peer_port=5553, use_upnp=False, rate_limiter=rate_limiter, peer_port=5553, use_upnp=False, rate_limiter=rate_limiter,
wallet=wallet, blob_tracker_class=DummyBlobAvailabilityTracker, wallet=wallet, blob_tracker_class=DummyBlobAvailabilityTracker,
is_generous=conf.ADJUSTABLE_SETTINGS['is_generous_host'][1]) is_generous=conf.ADJUSTABLE_SETTINGS['is_generous_host'][1])
@ -892,6 +875,7 @@ class TestTransfer(TestCase):
logging.info("Client is shutting down") logging.info("Client is shutting down")
d.addCallback(lambda _: print_shutting_down()) d.addCallback(lambda _: print_shutting_down())
d.addCallback(lambda _: rm_db_and_blob_dir(db_dir, blob_dir))
d.addCallback(lambda _: arg) d.addCallback(lambda _: arg)
return d return d

View file

@ -1,5 +1,6 @@
import os import os
import shutil import shutil
import tempfile
from twisted.internet import defer, threads, error from twisted.internet import defer, threads, error
from twisted.trial import unittest from twisted.trial import unittest
@ -19,7 +20,7 @@ from lbrynet.file_manager import EncryptedFileCreator
from lbrynet.file_manager import EncryptedFileManager from lbrynet.file_manager import EncryptedFileManager
from tests import mocks from tests import mocks
from tests.util import mk_db_and_blob_dir, rm_db_and_blob_dir
class TestReflector(unittest.TestCase): class TestReflector(unittest.TestCase):
def setUp(self): def setUp(self):
@ -56,16 +57,14 @@ class TestReflector(unittest.TestCase):
), ),
] ]
db_dir = "client" self.db_dir, self.blob_dir = mk_db_and_blob_dir()
os.mkdir(db_dir)
self.session = Session.Session( self.session = Session.Session(
conf.settings['data_rate'], conf.settings['data_rate'],
db_dir=db_dir, db_dir=self.db_dir,
lbryid="abcd", lbryid="abcd",
peer_finder=peer_finder, peer_finder=peer_finder,
hash_announcer=hash_announcer, hash_announcer=hash_announcer,
blob_dir=None, blob_dir=self.blob_dir,
peer_port=5553, peer_port=5553,
use_upnp=False, use_upnp=False,
rate_limiter=rate_limiter, rate_limiter=rate_limiter,
@ -74,12 +73,14 @@ class TestReflector(unittest.TestCase):
dht_node_class=Node dht_node_class=Node
) )
self.stream_info_manager = EncryptedFileMetadataManager.DBEncryptedFileMetadataManager(db_dir) self.stream_info_manager = EncryptedFileMetadataManager.DBEncryptedFileMetadataManager(self.db_dir)
self.lbry_file_manager = EncryptedFileManager.EncryptedFileManager( self.lbry_file_manager = EncryptedFileManager.EncryptedFileManager(
self.session, self.stream_info_manager, sd_identifier) self.session, self.stream_info_manager, sd_identifier)
self.server_blob_manager = BlobManager.TempBlobManager(hash_announcer) self.server_db_dir, self.server_blob_dir = mk_db_and_blob_dir()
self.server_blob_manager = BlobManager.DiskBlobManager(
hash_announcer, self.server_blob_dir, self.server_db_dir)
d = self.session.setup() d = self.session.setup()
d.addCallback(lambda _: self.stream_info_manager.setup()) d.addCallback(lambda _: self.stream_info_manager.setup())
@ -149,7 +150,8 @@ class TestReflector(unittest.TestCase):
def delete_test_env(): def delete_test_env():
try: try:
shutil.rmtree('client') rm_db_and_blob_dir(self.db_dir, self.blob_dir)
rm_db_and_blob_dir(self.server_db_dir, self.server_blob_dir)
except: except:
raise unittest.SkipTest("TODO: fix this for windows") raise unittest.SkipTest("TODO: fix this for windows")

View file

@ -5,6 +5,7 @@ import tempfile
from Crypto.Cipher import AES from Crypto.Cipher import AES
import mock import mock
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet import defer
from lbrynet.core import BlobManager from lbrynet.core import BlobManager
from lbrynet.core import Session from lbrynet.core import Session
@ -13,7 +14,7 @@ from lbrynet.file_manager import EncryptedFileCreator
from lbrynet.file_manager import EncryptedFileManager from lbrynet.file_manager import EncryptedFileManager
from tests import mocks from tests import mocks
from tests.util import mk_db_and_blob_dir, rm_db_and_blob_dir
MB = 2**20 MB = 2**20
@ -27,16 +28,20 @@ class CreateEncryptedFileTest(unittest.TestCase):
timeout = 5 timeout = 5
def setUp(self): def setUp(self):
mocks.mock_conf_settings(self) mocks.mock_conf_settings(self)
self.tmp_dir = tempfile.mkdtemp() self.tmp_db_dir, self.tmp_blob_dir = mk_db_and_blob_dir()
@defer.inlineCallbacks
def tearDown(self): def tearDown(self):
shutil.rmtree(self.tmp_dir) yield self.blob_manager.stop()
rm_db_and_blob_dir(self.tmp_db_dir, self.tmp_blob_dir)
def create_file(self, filename): def create_file(self, filename):
session = mock.Mock(spec=Session.Session)(None, None) session = mock.Mock(spec=Session.Session)(None, None)
hash_announcer = DHTHashAnnouncer.DHTHashAnnouncer(None, None) hash_announcer = DHTHashAnnouncer.DHTHashAnnouncer(None, None)
session.blob_manager = BlobManager.TempBlobManager(hash_announcer) self.blob_manager = BlobManager.DiskBlobManager(hash_announcer, self.tmp_blob_dir, self.tmp_db_dir)
session.db_dir = self.tmp_dir session.blob_manager = self.blob_manager
session.blob_manager.setup()
session.db_dir = self.tmp_db_dir
manager = mock.Mock(spec=EncryptedFileManager.EncryptedFileManager)() manager = mock.Mock(spec=EncryptedFileManager.EncryptedFileManager)()
handle = mocks.GenFile(3*MB, '1') handle = mocks.GenFile(3*MB, '1')
key = '2'*AES.block_size key = '2'*AES.block_size

View file

@ -2,7 +2,8 @@ import datetime
import time import time
import binascii import binascii
import os import os
import tempfile
import shutil
import mock import mock
@ -10,6 +11,15 @@ DEFAULT_TIMESTAMP = datetime.datetime(2016, 1, 1)
DEFAULT_ISO_TIME = time.mktime(DEFAULT_TIMESTAMP.timetuple()) DEFAULT_ISO_TIME = time.mktime(DEFAULT_TIMESTAMP.timetuple())
def mk_db_and_blob_dir():
db_dir = tempfile.mkdtemp()
blob_dir = tempfile.mkdtemp()
return db_dir, blob_dir
def rm_db_and_blob_dir(db_dir, blob_dir):
shutil.rmtree(db_dir, ignore_errors=True)
shutil.rmtree(blob_dir, ignore_errors=True)
def random_lbry_hash(): def random_lbry_hash():
return binascii.b2a_hex(os.urandom(48)) return binascii.b2a_hex(os.urandom(48))