diff --git a/rocksdb/_rocksdb.pyx b/rocksdb/_rocksdb.pyx index 254b209..1a81fee 100644 --- a/rocksdb/_rocksdb.pyx +++ b/rocksdb/_rocksdb.pyx @@ -49,6 +49,8 @@ from interfaces import SliceTransform as ISliceTransform import traceback import errors +ctypedef const filter_policy.FilterPolicy ConstFilterPolicy + cdef extern from "cpp/utils.hpp" namespace "py_rocks": cdef const Slice* vector_data(vector[Slice]&) @@ -202,41 +204,36 @@ cdef class PyFilterPolicy(object): cdef object get_ob(self): return None - cdef const filter_policy.FilterPolicy* get_policy(self): - return NULL + cdef shared_ptr[ConstFilterPolicy] get_policy(self): + return shared_ptr[ConstFilterPolicy]() cdef set_info_log(self, shared_ptr[logger.Logger] info_log): pass @cython.internal cdef class PyGenericFilterPolicy(PyFilterPolicy): - cdef filter_policy.FilterPolicyWrapper* policy + cdef shared_ptr[filter_policy.FilterPolicyWrapper] policy cdef object ob def __cinit__(self, object ob): - self.policy = NULL if not isinstance(ob, IFilterPolicy): raise TypeError("%s is not of type %s" % (ob, IFilterPolicy)) self.ob = ob - self.policy = new filter_policy.FilterPolicyWrapper( + self.policy.reset(new filter_policy.FilterPolicyWrapper( bytes_to_string(ob.name()), ob, create_filter_callback, - key_may_match_callback) - - def __dealloc__(self): - if not self.policy == NULL: - del self.policy + key_may_match_callback)) cdef object get_ob(self): return self.ob - cdef const filter_policy.FilterPolicy* get_policy(self): - return self.policy + cdef shared_ptr[ConstFilterPolicy] get_policy(self): + return (self.policy) cdef set_info_log(self, shared_ptr[logger.Logger] info_log): - self.policy.set_info_log(info_log) + self.policy.get().set_info_log(info_log) cdef void create_filter_callback( @@ -274,18 +271,13 @@ cdef cpp_bool key_may_match_callback( @cython.internal cdef class PyBloomFilterPolicy(PyFilterPolicy): - cdef const filter_policy.FilterPolicy* policy + cdef shared_ptr[ConstFilterPolicy] policy def __cinit__(self, int bits_per_key): - self.policy = NULL - self.policy = filter_policy.NewBloomFilterPolicy(bits_per_key) - - def __dealloc__(self): - if not self.policy == NULL: - del self.policy + self.policy.reset(filter_policy.NewBloomFilterPolicy(bits_per_key)) def name(self): - return PyBytes_FromString(self.policy.Name()) + return PyBytes_FromString(self.policy.get().Name()) def create_filter(self, keys): cdef string dst @@ -294,7 +286,7 @@ cdef class PyBloomFilterPolicy(PyFilterPolicy): for key in keys: c_keys.push_back(bytes_to_slice(key)) - self.policy.CreateFilter( + self.policy.get().CreateFilter( vector_data(c_keys), c_keys.size(), cython.address(dst)) @@ -302,14 +294,14 @@ cdef class PyBloomFilterPolicy(PyFilterPolicy): return string_to_bytes(dst) def key_may_match(self, key, filter_): - return self.policy.KeyMayMatch( + return self.policy.get().KeyMayMatch( bytes_to_slice(key), bytes_to_slice(filter_)) cdef object get_ob(self): return self - cdef const filter_policy.FilterPolicy* get_policy(self): + cdef shared_ptr[ConstFilterPolicy] get_policy(self): return self.policy BloomFilterPolicy = PyBloomFilterPolicy @@ -561,13 +553,19 @@ cdef class PyTableFactory(object): cdef shared_ptr[table_factory.TableFactory] get_table_factory(self): return self.factory + cdef set_info_log(self, shared_ptr[logger.Logger] info_log): + pass + cdef class BlockBasedTableFactory(PyTableFactory): + cdef PyFilterPolicy py_filter_policy + def __init__(self, index_type='binary_search', py_bool hash_index_allow_collision=True, checksum='crc32', PyCache block_cache=None, PyCache block_cache_compressed=None, + filter_policy=None, no_block_cache=False, block_size=None, block_size_deviation=None, @@ -622,8 +620,24 @@ cdef class BlockBasedTableFactory(PyTableFactory): if block_cache_compressed is not None: table_options.block_cache_compressed = block_cache_compressed.get_cache() + # Set the filter_policy + self.py_filter_policy = None + if filter_policy is not None: + if isinstance(filter_policy, PyFilterPolicy): + if (filter_policy).get_policy().get() == NULL: + raise Exception("Cannot set filter policy: %s" % filter_policy) + self.py_filter_policy = filter_policy + else: + self.py_filter_policy = PyGenericFilterPolicy(filter_policy) + + table_options.filter_policy = self.py_filter_policy.get_policy() + self.factory.reset(table_factory.NewBlockBasedTableFactory(table_options)) + cdef set_info_log(self, shared_ptr[logger.Logger] info_log): + if self.py_filter_policy is not None: + self.py_filter_policy.set_info_log(info_log) + cdef class PlainTableFactory(PyTableFactory): def __init__( self, @@ -701,7 +715,6 @@ cdef class Options(object): cdef options.Options* opts cdef PyComparator py_comparator cdef PyMergeOperator py_merge_operator - cdef PyFilterPolicy py_filter_policy cdef PySliceTransform py_prefix_extractor cdef PyTableFactory py_table_factory cdef PyMemtableFactory py_memtable_factory @@ -721,7 +734,6 @@ cdef class Options(object): def __init__(self, **kwargs): self.py_comparator = BytewiseComparator() self.py_merge_operator = None - self.py_filter_policy = None self.py_prefix_extractor = None self.py_table_factory = None self.py_memtable_factory = None @@ -1196,22 +1208,6 @@ cdef class Options(object): self.py_merge_operator = PyMergeOperator(value) self.opts.merge_operator = self.py_merge_operator.get_operator() - property filter_policy: - def __get__(self): - if self.py_filter_policy is None: - return None - return self.py_filter_policy.get_ob() - - def __set__(self, value): - if isinstance(value, PyFilterPolicy): - if (value).get_policy() == NULL: - raise Exception("Cannot set filter policy: %s" % value) - self.py_filter_policy = value - else: - self.py_filter_policy = PyGenericFilterPolicy(value) - - self.opts.filter_policy = self.py_filter_policy.get_policy() - property prefix_extractor: def __get__(self): if self.py_prefix_extractor is None: @@ -1297,8 +1293,8 @@ cdef class DB(object): if opts.py_comparator is not None: opts.py_comparator.set_info_log(info_log) - if opts.py_filter_policy is not None: - opts.py_filter_policy.set_info_log(info_log) + if opts.py_table_factory is not None: + opts.py_table_factory.set_info_log(info_log) if opts.prefix_extractor is not None: opts.py_prefix_extractor.set_info_log(info_log) diff --git a/rocksdb/options.pxd b/rocksdb/options.pxd index b17a5ab..a2be987 100644 --- a/rocksdb/options.pxd +++ b/rocksdb/options.pxd @@ -5,7 +5,6 @@ from libc.stdint cimport uint64_t from std_memory cimport shared_ptr from comparator cimport Comparator from merge_operator cimport MergeOperator -from filter_policy cimport FilterPolicy from logger cimport Logger from slice_ cimport Slice from snapshot cimport Snapshot @@ -32,7 +31,6 @@ cdef extern from "rocksdb/options.h" namespace "rocksdb": cdef cppclass Options: const Comparator* comparator shared_ptr[MergeOperator] merge_operator - const FilterPolicy* filter_policy # TODO: compaction_filter # TODO: compaction_filter_factory cpp_bool create_if_missing diff --git a/rocksdb/table_factory.pxd b/rocksdb/table_factory.pxd index 2c61e64..2359292 100644 --- a/rocksdb/table_factory.pxd +++ b/rocksdb/table_factory.pxd @@ -3,6 +3,7 @@ from libcpp cimport bool as cpp_bool from std_memory cimport shared_ptr from cache cimport Cache +from filter_policy cimport FilterPolicy cdef extern from "rocksdb/table.h" namespace "rocksdb": cdef cppclass TableFactory: @@ -28,6 +29,7 @@ cdef extern from "rocksdb/table.h" namespace "rocksdb": cpp_bool whole_key_filtering shared_ptr[Cache] block_cache shared_ptr[Cache] block_cache_compressed + shared_ptr[FilterPolicy] filter_policy cdef TableFactory* NewBlockBasedTableFactory(const BlockBasedTableOptions&)