diff --git a/rocksdb/_rocksdb.pyx b/rocksdb/_rocksdb.pyx index d9f6541..103a7bd 100644 --- a/rocksdb/_rocksdb.pyx +++ b/rocksdb/_rocksdb.pyx @@ -4,9 +4,10 @@ from libcpp.deque cimport deque from libcpp.vector cimport vector from libcpp cimport bool as cpp_bool from cython.operator cimport dereference as deref -from cpython.string cimport PyString_AsString -from cpython.string cimport PyString_Size -from cpython.string cimport PyString_FromString +from cpython.bytes cimport PyBytes_AsString +from cpython.bytes cimport PyBytes_Size +from cpython.bytes cimport PyBytes_FromString +from cpython.bytes cimport PyBytes_FromStringAndSize from cpython.unicode cimport PyUnicode_Decode from std_memory cimport shared_ptr @@ -21,8 +22,6 @@ cimport snapshot cimport db cimport iterator -from slice_ cimport slice_to_str -from slice_ cimport str_to_slice from status cimport Status import sys @@ -66,9 +65,18 @@ cdef check_status(const Status& st): ###################################################### -cdef string bytes_to_string(bytes path) except *: +cdef string bytes_to_string(path) except *: return string(PyBytes_AsString(path), PyBytes_Size(path)) +cdef string_to_bytes(string ob): + return PyBytes_FromStringAndSize(ob.c_str(), ob.size()) + +cdef slice_.Slice bytes_to_slice(ob) except *: + return slice_.Slice(PyBytes_AsString(ob), PyBytes_Size(ob)) + +cdef slice_to_bytes(slice_.Slice sl): + return PyBytes_FromStringAndSize(sl.data(), sl.size()) + ## only for filsystem paths cdef string path_to_string(object path) except *: if isinstance(path, bytes): @@ -80,7 +88,7 @@ cdef string path_to_string(object path) except *: raise TypeError("Wrong type for path: %s" % path) cdef object string_to_path(string path): - fs_encoding = sys.getfilesystemencoding() + fs_encoding = sys.getfilesystemencoding().encode('ascii') return PyUnicode_Decode(path.c_str(), path.size(), fs_encoding, "replace") ## Here comes the stuff for the comparator @@ -105,7 +113,7 @@ cdef class PyGenericComparator(PyComparator): self.ob = ob self.comparator_ptr = ( new comparator.ComparatorWrapper( - ob.name(), + bytes_to_string(ob.name()), ob, compare_callback)) @@ -126,12 +134,12 @@ cdef class PyBytewiseComparator(PyComparator): self.comparator_ptr = comparator.BytewiseComparator() def name(self): - return PyString_FromString(self.comparator_ptr.Name()) + return PyBytes_FromString(self.comparator_ptr.Name()) - def compare(self, str a, str b): + def compare(self, a, b): return self.comparator_ptr.Compare( - str_to_slice(a), - str_to_slice(b)) + bytes_to_slice(a), + bytes_to_slice(b)) cdef object get_ob(self): return self @@ -144,7 +152,7 @@ cdef int compare_callback( const slice_.Slice& a, const slice_.Slice& b) with gil: - return (ctx).compare(slice_to_str(a), slice_to_str(b)) + return (ctx).compare(slice_to_bytes(a), slice_to_bytes(b)) BytewiseComparator = PyBytewiseComparator ######################################### @@ -160,6 +168,7 @@ cdef class PyFilterPolicy(object): cdef const filter_policy.FilterPolicy* get_policy(self): return NULL + @cython.internal cdef class PyGenericFilterPolicy(PyFilterPolicy): cdef filter_policy.FilterPolicy* policy @@ -171,7 +180,7 @@ cdef class PyGenericFilterPolicy(PyFilterPolicy): self.ob = ob self.policy = new filter_policy.FilterPolicyWrapper( - ob.name(), + bytes_to_string(ob.name()), ob, ob, create_filter_callback, @@ -192,16 +201,18 @@ cdef void create_filter_callback( int n, string* dst) with gil: - cdef string ret = (ctx).create_filter( - [slice_to_str(keys[i]) for i in range(n)]) - dst.append(ret) + ret = (ctx).create_filter( + [slice_to_bytes(keys[i]) for i in range(n)]) + dst.append(bytes_to_string(ret)) cdef cpp_bool key_may_match_callback( void* ctx, const slice_.Slice& key, const slice_.Slice& filt) with gil: - return (ctx).key_may_match(slice_to_str(key), slice_to_str(filt)) + return (ctx).key_may_match( + slice_to_bytes(key), + slice_to_bytes(filt)) @cython.internal cdef class PyBloomFilterPolicy(PyFilterPolicy): @@ -214,26 +225,26 @@ cdef class PyBloomFilterPolicy(PyFilterPolicy): del self.policy def name(self): - return PyString_FromString(self.policy.Name()) + return PyBytes_FromString(self.policy.Name()) def create_filter(self, keys): cdef string dst cdef vector[slice_.Slice] c_keys for key in keys: - c_keys.push_back(str_to_slice(key)) + c_keys.push_back(bytes_to_slice(key)) self.policy.CreateFilter( vector_data(c_keys), c_keys.size(), cython.address(dst)) - return dst + return string_to_bytes(dst) def key_may_match(self, key, filter_): return self.policy.KeyMayMatch( - str_to_slice(key), - str_to_slice(filter_)) + bytes_to_slice(key), + bytes_to_slice(filter_)) cdef object get_ob(self): return self @@ -258,7 +269,7 @@ cdef class PyMergeOperator(object): self.merge_op.reset( new merge_operator.AssociativeMergeOperatorWrapper( - ob.name(), + bytes_to_string(ob.name()), (ob), merge_callback)) @@ -267,7 +278,7 @@ cdef class PyMergeOperator(object): self.merge_op.reset( new merge_operator.MergeOperatorWrapper( - ob.name(), + bytes_to_string(ob.name()), ob, ob, full_merge_callback, @@ -292,18 +303,16 @@ cdef cpp_bool merge_callback( if existing_value == NULL: py_existing_value = None else: - py_existing_value = slice_to_str(deref(existing_value)) + py_existing_value = slice_to_bytes(deref(existing_value)) try: ret = (ctx).merge( - slice_to_str(key), + slice_to_bytes(key), py_existing_value, - slice_to_str(value)) + slice_to_bytes(value)) if ret[0]: - new_value.assign( - PyString_AsString(ret[1]), - PyString_Size(ret[1])) + new_value.assign(bytes_to_string(ret[1])) return True return False @@ -319,25 +328,23 @@ cdef cpp_bool full_merge_callback( void* ctx, const slice_.Slice& key, const slice_.Slice* existing_value, - const deque[string]& operand_list, + const deque[string]& op_list, string* new_value, logger.Logger* log) with gil: if existing_value == NULL: py_existing_value = None else: - py_existing_value = slice_to_str(deref(existing_value)) + py_existing_value = slice_to_bytes(deref(existing_value)) try: ret = (ctx).full_merge( - slice_to_str(key), + slice_to_bytes(key), py_existing_value, - [operand_list[i] for i in range(operand_list.size())]) + [string_to_bytes(op_list[i]) for i in range(op_list.size())]) if ret[0]: - new_value.assign( - PyString_AsString(ret[1]), - PyString_Size(ret[1])) + new_value.assign(bytes_to_string(ret[1])) return True return False @@ -359,14 +366,12 @@ cdef cpp_bool partial_merge_callback( try: ret = (ctx).partial_merge( - slice_to_str(key), - slice_to_str(left_op), - slice_to_str(right_op)) + slice_to_bytes(key), + slice_to_bytes(left_op), + slice_to_bytes(right_op)) if ret[0]: - new_value.assign( - PyString_AsString(ret[1]), - PyString_Size(ret[1])) + new_value.assign(bytes_to_string(ret[1])) return True return False @@ -416,10 +421,10 @@ LRUCache = PyLRUCache cdef class CompressionType(object): - no_compression = 'no_compression' - snappy_compression = 'snappy_compression' - zlib_compression = 'zlib_compression' - bzip2_compression = 'bzip2_compression' + no_compression = u'no_compression' + snappy_compression = u'snappy_compression' + zlib_compression = u'zlib_compression' + bzip2_compression = u'bzip2_compression' cdef class Options(object): cdef options.Options* opts @@ -939,13 +944,13 @@ cdef class WriteBatch(object): del self.batch def put(self, key, value): - self.batch.Put(str_to_slice(key), str_to_slice(value)) + self.batch.Put(bytes_to_slice(key), bytes_to_slice(value)) def merge(self, key, value): - self.batch.Merge(str_to_slice(key), str_to_slice(value)) + self.batch.Merge(bytes_to_slice(key), bytes_to_slice(value)) def delete(self, key): - self.batch.Delete(str_to_slice(key)) + self.batch.Delete(bytes_to_slice(key)) def clear(self): self.batch.Clear() @@ -986,7 +991,7 @@ cdef class DB(object): opts.disableWAL = disable_wal check_status( - self.db.Put(opts, str_to_slice(key), str_to_slice(value))) + self.db.Put(opts, bytes_to_slice(key), bytes_to_slice(value))) def delete(self, key, sync=False, disable_wal=False): cdef options.WriteOptions opts @@ -994,7 +999,7 @@ cdef class DB(object): opts.disableWAL = disable_wal check_status( - self.db.Delete(opts, str_to_slice(key))) + self.db.Delete(opts, bytes_to_slice(key))) def merge(self, key, value, sync=False, disable_wal=False): cdef options.WriteOptions opts @@ -1002,7 +1007,7 @@ cdef class DB(object): opts.disableWAL = disable_wal check_status( - self.db.Merge(opts, str_to_slice(key), str_to_slice(value))) + self.db.Merge(opts, bytes_to_slice(key), bytes_to_slice(value))) def write(self, WriteBatch batch, sync=False, disable_wal=False): cdef options.WriteOptions opts @@ -1018,11 +1023,11 @@ cdef class DB(object): st = self.db.Get( self.build_read_opts(self.__parse_read_opts(*args, **kwargs)), - str_to_slice(key), + bytes_to_slice(key), cython.address(res)) if st.ok(): - return res + return string_to_bytes(res) elif st.IsNotFound(): return None else: @@ -1034,7 +1039,7 @@ cdef class DB(object): cdef vector[slice_.Slice] c_keys for key in keys: - c_keys.push_back(str_to_slice(key)) + c_keys.push_back(bytes_to_slice(key)) cdef vector[Status] res = self.db.MultiGet( self.build_read_opts(self.__parse_read_opts(*args, **kwargs)), @@ -1044,7 +1049,7 @@ cdef class DB(object): cdef dict ret_dict = {} for index in range(len(keys)): if res[index].ok(): - ret_dict[keys[index]] = values[index] + ret_dict[keys[index]] = string_to_bytes(values[index]) elif res[index].IsNotFound(): ret_dict[keys[index]] = None else: @@ -1063,13 +1068,13 @@ cdef class DB(object): value_found = False exists = self.db.KeyMayExist( opts, - str_to_slice(key), + bytes_to_slice(key), cython.address(value), cython.address(value_found)) if exists: if value_found: - return (True, value) + return (True, string_to_bytes(value)) else: return (True, None) else: @@ -1077,7 +1082,7 @@ cdef class DB(object): else: exists = self.db.KeyMayExist( opts, - str_to_slice(key), + bytes_to_slice(key), cython.address(value)) return (exists, None) @@ -1112,8 +1117,8 @@ cdef class DB(object): def get_property(self, prop): cdef string value - if self.db.GetProperty(str_to_slice(prop), cython.address(value)): - return value + if self.db.GetProperty(bytes_to_slice(prop), cython.address(value)): + return string_to_bytes(value) else: return None @@ -1125,11 +1130,11 @@ cdef class DB(object): ret = [] for ob in metadata: t = {} - t['name'] = ob.name + t['name'] = string_to_path(ob.name) t['level'] = ob.level t['size'] = ob.size - t['smallestkey'] = ob.smallestkey - t['largestkey'] = ob.largestkey + t['smallestkey'] = string_to_bytes(ob.smallestkey) + t['largestkey'] = string_to_bytes(ob.largestkey) t['smallest_seqno'] = ob.smallest_seqno t['largest_seqno'] = ob.largest_seqno @@ -1216,7 +1221,7 @@ cdef class BaseIterator(object): self.ptr.SeekToLast() cpdef seek(self, key): - self.ptr.Seek(str_to_slice(key)) + self.ptr.Seek(bytes_to_slice(key)) cdef object get_ob(self): return None @@ -1224,17 +1229,17 @@ cdef class BaseIterator(object): @cython.internal cdef class KeysIterator(BaseIterator): cdef object get_ob(self): - return slice_to_str(self.ptr.key()) + return slice_to_bytes(self.ptr.key()) @cython.internal cdef class ValuesIterator(BaseIterator): cdef object get_ob(self): - return slice_to_str(self.ptr.value()) + return slice_to_bytes(self.ptr.value()) @cython.internal cdef class ItemsIterator(BaseIterator): cdef object get_ob(self): - return (slice_to_str(self.ptr.key()), slice_to_str(self.ptr.value())) + return (slice_to_bytes(self.ptr.key()), slice_to_bytes(self.ptr.value())) @cython.internal cdef class ReversedIterator(object): diff --git a/rocksdb/slice_.pxd b/rocksdb/slice_.pxd index c4e1180..9d8a3dd 100644 --- a/rocksdb/slice_.pxd +++ b/rocksdb/slice_.pxd @@ -1,8 +1,5 @@ from libcpp.string cimport string from libcpp cimport bool as cpp_bool -from cpython.string cimport PyString_Size -from cpython.string cimport PyString_AsString -from cpython.string cimport PyString_FromStringAndSize cdef extern from "rocksdb/slice.h" namespace "rocksdb": cdef cppclass Slice: @@ -21,9 +18,3 @@ cdef extern from "rocksdb/slice.h" namespace "rocksdb": string ToString(cpp_bool) int compare(const Slice&) cpp_bool starts_with(const Slice&) - -cdef inline Slice str_to_slice(str ob): - return Slice(PyString_AsString(ob), PyString_Size(ob)) - -cdef inline str slice_to_str(Slice ob): - return PyString_FromStringAndSize(ob.data(), ob.size()) diff --git a/rocksdb/tests/test_db.py b/rocksdb/tests/test_db.py index 0348c25..e00868a 100644 --- a/rocksdb/tests/test_db.py +++ b/rocksdb/tests/test_db.py @@ -4,6 +4,8 @@ import gc import unittest import rocksdb +def int_to_bytes(ob): + return str(ob).encode('ascii') class TestHelper(object): def _clean(self): @@ -31,153 +33,165 @@ class TestDB(unittest.TestCase, TestHelper): self.assertTrue(os.path.isdir(name)) def test_get_none(self): - self.assertIsNone(self.db.get('xxx')) + self.assertIsNone(self.db.get(b'xxx')) def test_put_get(self): - self.db.put("a", "b") - self.assertEqual("b", self.db.get("a")) + self.db.put(b"a", b"b") + self.assertEqual(b"b", self.db.get(b"a")) def test_multi_get(self): - self.db.put("a", "1") - self.db.put("b", "2") - self.db.put("c", "3") + self.db.put(b"a", b"1") + self.db.put(b"b", b"2") + self.db.put(b"c", b"3") - ret = self.db.multi_get(['a', 'b', 'c']) - ref = {'a': '1', 'c': '3', 'b': '2'} + ret = self.db.multi_get([b'a', b'b', b'c']) + ref = {b'a': b'1', b'c': b'3', b'b': b'2'} self.assertEqual(ref, ret) def test_delete(self): - self.db.put("a", "b") - self.assertEqual("b", self.db.get("a")) - self.db.delete("a") - self.assertIsNone(self.db.get("a")) + self.db.put(b"a", b"b") + self.assertEqual(b"b", self.db.get(b"a")) + self.db.delete(b"a") + self.assertIsNone(self.db.get(b"a")) def test_write_batch(self): batch = rocksdb.WriteBatch() - batch.put("key", "v1") - batch.delete("key") - batch.put("key", "v2") - batch.put("key", "v3") - batch.put("a", "b") + batch.put(b"key", b"v1") + batch.delete(b"key") + batch.put(b"key", b"v2") + batch.put(b"key", b"v3") + batch.put(b"a", b"b") self.db.write(batch) - ref = {'a': 'b', 'key': 'v3'} - ret = self.db.multi_get(['key', 'a']) + ref = {b'a': b'b', b'key': b'v3'} + ret = self.db.multi_get([b'key', b'a']) self.assertEqual(ref, ret) def test_key_may_exists(self): - self.db.put("a", '1') + self.db.put(b"a", b'1') - self.assertEqual((False, None), self.db.key_may_exist("x")) - self.assertEqual((False, None), self.db.key_may_exist('x', True)) - self.assertEqual((True, None), self.db.key_may_exist('a')) - self.assertEqual((True, '1'), self.db.key_may_exist('a', True)) + self.assertEqual((False, None), self.db.key_may_exist(b"x")) + self.assertEqual((False, None), self.db.key_may_exist(b'x', True)) + self.assertEqual((True, None), self.db.key_may_exist(b'a')) + self.assertEqual((True, b'1'), self.db.key_may_exist(b'a', True)) def test_iter_keys(self): for x in range(300): - self.db.put(str(x), str(x)) + self.db.put(int_to_bytes(x), int_to_bytes(x)) it = self.db.iterkeys() self.assertEqual([], list(it)) it.seek_to_last() - self.assertEqual(['99'], list(it)) + self.assertEqual([b'99'], list(it)) - ref = sorted([str(x) for x in range(300)]) + ref = sorted([int_to_bytes(x) for x in range(300)]) it.seek_to_first() self.assertEqual(ref, list(it)) - it.seek('90') - ref = ['90', '91', '92', '93', '94', '95', '96', '97', '98', '99'] + it.seek(b'90') + ref = [ + b'90', + b'91', + b'92', + b'93', + b'94', + b'95', + b'96', + b'97', + b'98', + b'99' + ] self.assertEqual(ref, list(it)) def test_iter_values(self): for x in range(300): - self.db.put(str(x), str(x * 1000)) + self.db.put(int_to_bytes(x), int_to_bytes(x * 1000)) it = self.db.itervalues() self.assertEqual([], list(it)) it.seek_to_last() - self.assertEqual(['99000'], list(it)) + self.assertEqual([b'99000'], list(it)) - ref = sorted([str(x) for x in range(300)]) - ref = [str(int(x) * 1000) for x in ref] + ref = sorted([int_to_bytes(x) for x in range(300)]) + ref = [int_to_bytes(int(x) * 1000) for x in ref] it.seek_to_first() self.assertEqual(ref, list(it)) - it.seek('90') - ref = [str(x * 1000) for x in range(90, 100)] + it.seek(b'90') + ref = [int_to_bytes(x * 1000) for x in range(90, 100)] self.assertEqual(ref, list(it)) def test_iter_items(self): for x in range(300): - self.db.put(str(x), str(x * 1000)) + self.db.put(int_to_bytes(x), int_to_bytes(x * 1000)) it = self.db.iteritems() self.assertEqual([], list(it)) it.seek_to_last() - self.assertEqual([('99', '99000')], list(it)) + self.assertEqual([(b'99', b'99000')], list(it)) - ref = sorted([str(x) for x in range(300)]) - ref = [(x, str(int(x) * 1000)) for x in ref] + ref = sorted([int_to_bytes(x) for x in range(300)]) + ref = [(x, int_to_bytes(int(x) * 1000)) for x in ref] it.seek_to_first() self.assertEqual(ref, list(it)) - it.seek('90') - ref = [(str(x), str(x * 1000)) for x in range(90, 100)] + it.seek(b'90') + ref = [(int_to_bytes(x), int_to_bytes(x * 1000)) for x in range(90, 100)] self.assertEqual(ref, list(it)) def test_reverse_iter(self): for x in range(100): - self.db.put(str(x), str(x * 1000)) + self.db.put(int_to_bytes(x), int_to_bytes(x * 1000)) it = self.db.iteritems() it.seek_to_last() - ref = reversed(sorted([str(x) for x in range(100)])) - ref = [(x, str(int(x) * 1000)) for x in ref] + ref = reversed(sorted([int_to_bytes(x) for x in range(100)])) + ref = [(x, int_to_bytes(int(x) * 1000)) for x in ref] self.assertEqual(ref, list(reversed(it))) def test_snapshot(self): - self.db.put("a", "1") - self.db.put("b", "2") + self.db.put(b"a", b"1") + self.db.put(b"b", b"2") snapshot = self.db.snapshot() - self.db.put("a", "2") - self.db.delete("b") + self.db.put(b"a", b"2") + self.db.delete(b"b") it = self.db.iteritems() it.seek_to_first() - self.assertEqual({'a': '2'}, dict(it)) + self.assertEqual({b'a': b'2'}, dict(it)) it = self.db.iteritems(snapshot=snapshot) it.seek_to_first() - self.assertEqual({'a': '1', 'b': '2'}, dict(it)) + self.assertEqual({b'a': b'1', b'b': b'2'}, dict(it)) def test_get_property(self): for x in range(300): - self.db.put(str(x), str(x)) + x = int_to_bytes(x) + self.db.put(x, x) - self.assertIsNotNone(self.db.get_property('rocksdb.stats')) - self.assertIsNotNone(self.db.get_property('rocksdb.sstables')) - self.assertIsNotNone(self.db.get_property('rocksdb.num-files-at-level0')) - self.assertIsNone(self.db.get_property('does not exsits')) + self.assertIsNotNone(self.db.get_property(b'rocksdb.stats')) + self.assertIsNotNone(self.db.get_property(b'rocksdb.sstables')) + self.assertIsNotNone(self.db.get_property(b'rocksdb.num-files-at-level0')) + self.assertIsNone(self.db.get_property(b'does not exsits')) class AssocCounter(rocksdb.interfaces.AssociativeMergeOperator): def merge(self, key, existing_value, value): if existing_value: - return (True, str(int(existing_value) + int(value))) + return (True, int_to_bytes(int(existing_value) + int(value))) return (True, value) def name(self): - return 'AssocCounter' + return b'AssocCounter' class TestAssocMerge(unittest.TestCase, TestHelper): @@ -193,23 +207,23 @@ class TestAssocMerge(unittest.TestCase, TestHelper): def test_merge(self): for x in range(1000): - self.db.merge("a", str(x)) - self.assertEqual(str(sum(range(1000))), self.db.get('a')) + self.db.merge(b"a", int_to_bytes(x)) + self.assertEqual(sum(range(1000)), int(self.db.get(b'a'))) class FullCounter(rocksdb.interfaces.MergeOperator): def name(self): - return 'fullcounter' + return b'fullcounter' def full_merge(self, key, existing_value, operand_list): ret = sum([int(x) for x in operand_list]) if existing_value: ret += int(existing_value) - return (True, str(ret)) + return (True, int_to_bytes(ret)) def partial_merge(self, key, left, right): - return (True, str(int(left) + int(right))) + return (True, int_to_bytes(int(left) + int(right))) class TestFullMerge(unittest.TestCase, TestHelper): @@ -225,13 +239,13 @@ class TestFullMerge(unittest.TestCase, TestHelper): def test_merge(self): for x in range(1000): - self.db.merge("a", str(x)) - self.assertEqual(str(sum(range(1000))), self.db.get('a')) + self.db.merge(b"a", int_to_bytes(x)) + self.assertEqual(sum(range(1000)), int(self.db.get(b'a'))) class SimpleComparator(rocksdb.interfaces.Comparator): def name(self): - return 'mycompare' + return b'mycompare' def compare(self, a, b): a = int(a) @@ -257,6 +271,6 @@ class TestComparator(unittest.TestCase, TestHelper): def test_compare(self): for x in range(1000): - self.db.put(str(x), str(x)) + self.db.put(int_to_bytes(x), int_to_bytes(x)) - self.assertEqual('300', self.db.get('300')) + self.assertEqual(b'300', self.db.get(b'300')) diff --git a/rocksdb/tests/test_options.py b/rocksdb/tests/test_options.py index a9b1851..a993861 100644 --- a/rocksdb/tests/test_options.py +++ b/rocksdb/tests/test_options.py @@ -3,13 +3,13 @@ import rocksdb class TestFilterPolicy(rocksdb.interfaces.FilterPolicy): def create_filter(self, keys): - return 'nix' + return b'nix' def key_may_match(self, key, fil): return True def name(self): - return 'testfilter' + return b'testfilter' class TestMergeOperator(rocksdb.interfaces.MergeOperator): def full_merge(self, *args, **kwargs): @@ -19,7 +19,7 @@ class TestMergeOperator(rocksdb.interfaces.MergeOperator): return (False, None) def name(self): - return 'testmergeop' + return b'testmergeop' class TestOptions(unittest.TestCase): def test_simple(self):