Separate strictly between bytes/unicode

which makes py3 support possible
This commit is contained in:
hofmockel 2014-01-16 21:32:00 +01:00
parent 6277f9ab5c
commit 7d61b63b79
4 changed files with 159 additions and 149 deletions

View file

@ -4,9 +4,10 @@ from libcpp.deque cimport deque
from libcpp.vector cimport vector from libcpp.vector cimport vector
from libcpp cimport bool as cpp_bool from libcpp cimport bool as cpp_bool
from cython.operator cimport dereference as deref from cython.operator cimport dereference as deref
from cpython.string cimport PyString_AsString from cpython.bytes cimport PyBytes_AsString
from cpython.string cimport PyString_Size from cpython.bytes cimport PyBytes_Size
from cpython.string cimport PyString_FromString from cpython.bytes cimport PyBytes_FromString
from cpython.bytes cimport PyBytes_FromStringAndSize
from cpython.unicode cimport PyUnicode_Decode from cpython.unicode cimport PyUnicode_Decode
from std_memory cimport shared_ptr from std_memory cimport shared_ptr
@ -21,8 +22,6 @@ cimport snapshot
cimport db cimport db
cimport iterator cimport iterator
from slice_ cimport slice_to_str
from slice_ cimport str_to_slice
from status cimport Status from status cimport Status
import sys import sys
@ -66,9 +65,18 @@ cdef check_status(const Status& st):
###################################################### ######################################################
cdef string bytes_to_string(bytes path) except *: cdef string bytes_to_string(path) except *:
return string(PyBytes_AsString(path), PyBytes_Size(path)) return string(PyBytes_AsString(path), PyBytes_Size(path))
cdef string_to_bytes(string ob):
return PyBytes_FromStringAndSize(ob.c_str(), ob.size())
cdef slice_.Slice bytes_to_slice(ob) except *:
return slice_.Slice(PyBytes_AsString(ob), PyBytes_Size(ob))
cdef slice_to_bytes(slice_.Slice sl):
return PyBytes_FromStringAndSize(sl.data(), sl.size())
## only for filsystem paths ## only for filsystem paths
cdef string path_to_string(object path) except *: cdef string path_to_string(object path) except *:
if isinstance(path, bytes): if isinstance(path, bytes):
@ -80,7 +88,7 @@ cdef string path_to_string(object path) except *:
raise TypeError("Wrong type for path: %s" % path) raise TypeError("Wrong type for path: %s" % path)
cdef object string_to_path(string path): cdef object string_to_path(string path):
fs_encoding = sys.getfilesystemencoding() fs_encoding = sys.getfilesystemencoding().encode('ascii')
return PyUnicode_Decode(path.c_str(), path.size(), fs_encoding, "replace") return PyUnicode_Decode(path.c_str(), path.size(), fs_encoding, "replace")
## Here comes the stuff for the comparator ## Here comes the stuff for the comparator
@ -105,7 +113,7 @@ cdef class PyGenericComparator(PyComparator):
self.ob = ob self.ob = ob
self.comparator_ptr = <comparator.Comparator*>( self.comparator_ptr = <comparator.Comparator*>(
new comparator.ComparatorWrapper( new comparator.ComparatorWrapper(
ob.name(), bytes_to_string(ob.name()),
<void*>ob, <void*>ob,
compare_callback)) compare_callback))
@ -126,12 +134,12 @@ cdef class PyBytewiseComparator(PyComparator):
self.comparator_ptr = comparator.BytewiseComparator() self.comparator_ptr = comparator.BytewiseComparator()
def name(self): def name(self):
return PyString_FromString(self.comparator_ptr.Name()) return PyBytes_FromString(self.comparator_ptr.Name())
def compare(self, str a, str b): def compare(self, a, b):
return self.comparator_ptr.Compare( return self.comparator_ptr.Compare(
str_to_slice(a), bytes_to_slice(a),
str_to_slice(b)) bytes_to_slice(b))
cdef object get_ob(self): cdef object get_ob(self):
return self return self
@ -144,7 +152,7 @@ cdef int compare_callback(
const slice_.Slice& a, const slice_.Slice& a,
const slice_.Slice& b) with gil: const slice_.Slice& b) with gil:
return (<object>ctx).compare(slice_to_str(a), slice_to_str(b)) return (<object>ctx).compare(slice_to_bytes(a), slice_to_bytes(b))
BytewiseComparator = PyBytewiseComparator BytewiseComparator = PyBytewiseComparator
######################################### #########################################
@ -160,6 +168,7 @@ cdef class PyFilterPolicy(object):
cdef const filter_policy.FilterPolicy* get_policy(self): cdef const filter_policy.FilterPolicy* get_policy(self):
return NULL return NULL
@cython.internal @cython.internal
cdef class PyGenericFilterPolicy(PyFilterPolicy): cdef class PyGenericFilterPolicy(PyFilterPolicy):
cdef filter_policy.FilterPolicy* policy cdef filter_policy.FilterPolicy* policy
@ -171,7 +180,7 @@ cdef class PyGenericFilterPolicy(PyFilterPolicy):
self.ob = ob self.ob = ob
self.policy = <filter_policy.FilterPolicy*> new filter_policy.FilterPolicyWrapper( self.policy = <filter_policy.FilterPolicy*> new filter_policy.FilterPolicyWrapper(
ob.name(), bytes_to_string(ob.name()),
<void*>ob, <void*>ob,
<void*>ob, <void*>ob,
create_filter_callback, create_filter_callback,
@ -192,16 +201,18 @@ cdef void create_filter_callback(
int n, int n,
string* dst) with gil: string* dst) with gil:
cdef string ret = (<object>ctx).create_filter( ret = (<object>ctx).create_filter(
[slice_to_str(keys[i]) for i in range(n)]) [slice_to_bytes(keys[i]) for i in range(n)])
dst.append(ret) dst.append(bytes_to_string(ret))
cdef cpp_bool key_may_match_callback( cdef cpp_bool key_may_match_callback(
void* ctx, void* ctx,
const slice_.Slice& key, const slice_.Slice& key,
const slice_.Slice& filt) with gil: const slice_.Slice& filt) with gil:
return (<object>ctx).key_may_match(slice_to_str(key), slice_to_str(filt)) return (<object>ctx).key_may_match(
slice_to_bytes(key),
slice_to_bytes(filt))
@cython.internal @cython.internal
cdef class PyBloomFilterPolicy(PyFilterPolicy): cdef class PyBloomFilterPolicy(PyFilterPolicy):
@ -214,26 +225,26 @@ cdef class PyBloomFilterPolicy(PyFilterPolicy):
del self.policy del self.policy
def name(self): def name(self):
return PyString_FromString(self.policy.Name()) return PyBytes_FromString(self.policy.Name())
def create_filter(self, keys): def create_filter(self, keys):
cdef string dst cdef string dst
cdef vector[slice_.Slice] c_keys cdef vector[slice_.Slice] c_keys
for key in keys: for key in keys:
c_keys.push_back(str_to_slice(key)) c_keys.push_back(bytes_to_slice(key))
self.policy.CreateFilter( self.policy.CreateFilter(
vector_data(c_keys), vector_data(c_keys),
c_keys.size(), c_keys.size(),
cython.address(dst)) cython.address(dst))
return dst return string_to_bytes(dst)
def key_may_match(self, key, filter_): def key_may_match(self, key, filter_):
return self.policy.KeyMayMatch( return self.policy.KeyMayMatch(
str_to_slice(key), bytes_to_slice(key),
str_to_slice(filter_)) bytes_to_slice(filter_))
cdef object get_ob(self): cdef object get_ob(self):
return self return self
@ -258,7 +269,7 @@ cdef class PyMergeOperator(object):
self.merge_op.reset( self.merge_op.reset(
<merge_operator.MergeOperator*> <merge_operator.MergeOperator*>
new merge_operator.AssociativeMergeOperatorWrapper( new merge_operator.AssociativeMergeOperatorWrapper(
ob.name(), bytes_to_string(ob.name()),
<void*>(ob), <void*>(ob),
merge_callback)) merge_callback))
@ -267,7 +278,7 @@ cdef class PyMergeOperator(object):
self.merge_op.reset( self.merge_op.reset(
<merge_operator.MergeOperator*> <merge_operator.MergeOperator*>
new merge_operator.MergeOperatorWrapper( new merge_operator.MergeOperatorWrapper(
ob.name(), bytes_to_string(ob.name()),
<void*>ob, <void*>ob,
<void*>ob, <void*>ob,
full_merge_callback, full_merge_callback,
@ -292,18 +303,16 @@ cdef cpp_bool merge_callback(
if existing_value == NULL: if existing_value == NULL:
py_existing_value = None py_existing_value = None
else: else:
py_existing_value = slice_to_str(deref(existing_value)) py_existing_value = slice_to_bytes(deref(existing_value))
try: try:
ret = (<object>ctx).merge( ret = (<object>ctx).merge(
slice_to_str(key), slice_to_bytes(key),
py_existing_value, py_existing_value,
slice_to_str(value)) slice_to_bytes(value))
if ret[0]: if ret[0]:
new_value.assign( new_value.assign(bytes_to_string(ret[1]))
PyString_AsString(ret[1]),
PyString_Size(ret[1]))
return True return True
return False return False
@ -319,25 +328,23 @@ cdef cpp_bool full_merge_callback(
void* ctx, void* ctx,
const slice_.Slice& key, const slice_.Slice& key,
const slice_.Slice* existing_value, const slice_.Slice* existing_value,
const deque[string]& operand_list, const deque[string]& op_list,
string* new_value, string* new_value,
logger.Logger* log) with gil: logger.Logger* log) with gil:
if existing_value == NULL: if existing_value == NULL:
py_existing_value = None py_existing_value = None
else: else:
py_existing_value = slice_to_str(deref(existing_value)) py_existing_value = slice_to_bytes(deref(existing_value))
try: try:
ret = (<object>ctx).full_merge( ret = (<object>ctx).full_merge(
slice_to_str(key), slice_to_bytes(key),
py_existing_value, py_existing_value,
[operand_list[i] for i in range(operand_list.size())]) [string_to_bytes(op_list[i]) for i in range(op_list.size())])
if ret[0]: if ret[0]:
new_value.assign( new_value.assign(bytes_to_string(ret[1]))
PyString_AsString(ret[1]),
PyString_Size(ret[1]))
return True return True
return False return False
@ -359,14 +366,12 @@ cdef cpp_bool partial_merge_callback(
try: try:
ret = (<object>ctx).partial_merge( ret = (<object>ctx).partial_merge(
slice_to_str(key), slice_to_bytes(key),
slice_to_str(left_op), slice_to_bytes(left_op),
slice_to_str(right_op)) slice_to_bytes(right_op))
if ret[0]: if ret[0]:
new_value.assign( new_value.assign(bytes_to_string(ret[1]))
PyString_AsString(ret[1]),
PyString_Size(ret[1]))
return True return True
return False return False
@ -416,10 +421,10 @@ LRUCache = PyLRUCache
cdef class CompressionType(object): cdef class CompressionType(object):
no_compression = 'no_compression' no_compression = u'no_compression'
snappy_compression = 'snappy_compression' snappy_compression = u'snappy_compression'
zlib_compression = 'zlib_compression' zlib_compression = u'zlib_compression'
bzip2_compression = 'bzip2_compression' bzip2_compression = u'bzip2_compression'
cdef class Options(object): cdef class Options(object):
cdef options.Options* opts cdef options.Options* opts
@ -939,13 +944,13 @@ cdef class WriteBatch(object):
del self.batch del self.batch
def put(self, key, value): def put(self, key, value):
self.batch.Put(str_to_slice(key), str_to_slice(value)) self.batch.Put(bytes_to_slice(key), bytes_to_slice(value))
def merge(self, key, value): def merge(self, key, value):
self.batch.Merge(str_to_slice(key), str_to_slice(value)) self.batch.Merge(bytes_to_slice(key), bytes_to_slice(value))
def delete(self, key): def delete(self, key):
self.batch.Delete(str_to_slice(key)) self.batch.Delete(bytes_to_slice(key))
def clear(self): def clear(self):
self.batch.Clear() self.batch.Clear()
@ -986,7 +991,7 @@ cdef class DB(object):
opts.disableWAL = disable_wal opts.disableWAL = disable_wal
check_status( check_status(
self.db.Put(opts, str_to_slice(key), str_to_slice(value))) self.db.Put(opts, bytes_to_slice(key), bytes_to_slice(value)))
def delete(self, key, sync=False, disable_wal=False): def delete(self, key, sync=False, disable_wal=False):
cdef options.WriteOptions opts cdef options.WriteOptions opts
@ -994,7 +999,7 @@ cdef class DB(object):
opts.disableWAL = disable_wal opts.disableWAL = disable_wal
check_status( check_status(
self.db.Delete(opts, str_to_slice(key))) self.db.Delete(opts, bytes_to_slice(key)))
def merge(self, key, value, sync=False, disable_wal=False): def merge(self, key, value, sync=False, disable_wal=False):
cdef options.WriteOptions opts cdef options.WriteOptions opts
@ -1002,7 +1007,7 @@ cdef class DB(object):
opts.disableWAL = disable_wal opts.disableWAL = disable_wal
check_status( check_status(
self.db.Merge(opts, str_to_slice(key), str_to_slice(value))) self.db.Merge(opts, bytes_to_slice(key), bytes_to_slice(value)))
def write(self, WriteBatch batch, sync=False, disable_wal=False): def write(self, WriteBatch batch, sync=False, disable_wal=False):
cdef options.WriteOptions opts cdef options.WriteOptions opts
@ -1018,11 +1023,11 @@ cdef class DB(object):
st = self.db.Get( st = self.db.Get(
self.build_read_opts(self.__parse_read_opts(*args, **kwargs)), self.build_read_opts(self.__parse_read_opts(*args, **kwargs)),
str_to_slice(key), bytes_to_slice(key),
cython.address(res)) cython.address(res))
if st.ok(): if st.ok():
return res return string_to_bytes(res)
elif st.IsNotFound(): elif st.IsNotFound():
return None return None
else: else:
@ -1034,7 +1039,7 @@ cdef class DB(object):
cdef vector[slice_.Slice] c_keys cdef vector[slice_.Slice] c_keys
for key in keys: for key in keys:
c_keys.push_back(str_to_slice(key)) c_keys.push_back(bytes_to_slice(key))
cdef vector[Status] res = self.db.MultiGet( cdef vector[Status] res = self.db.MultiGet(
self.build_read_opts(self.__parse_read_opts(*args, **kwargs)), self.build_read_opts(self.__parse_read_opts(*args, **kwargs)),
@ -1044,7 +1049,7 @@ cdef class DB(object):
cdef dict ret_dict = {} cdef dict ret_dict = {}
for index in range(len(keys)): for index in range(len(keys)):
if res[index].ok(): if res[index].ok():
ret_dict[keys[index]] = values[index] ret_dict[keys[index]] = string_to_bytes(values[index])
elif res[index].IsNotFound(): elif res[index].IsNotFound():
ret_dict[keys[index]] = None ret_dict[keys[index]] = None
else: else:
@ -1063,13 +1068,13 @@ cdef class DB(object):
value_found = False value_found = False
exists = self.db.KeyMayExist( exists = self.db.KeyMayExist(
opts, opts,
str_to_slice(key), bytes_to_slice(key),
cython.address(value), cython.address(value),
cython.address(value_found)) cython.address(value_found))
if exists: if exists:
if value_found: if value_found:
return (True, value) return (True, string_to_bytes(value))
else: else:
return (True, None) return (True, None)
else: else:
@ -1077,7 +1082,7 @@ cdef class DB(object):
else: else:
exists = self.db.KeyMayExist( exists = self.db.KeyMayExist(
opts, opts,
str_to_slice(key), bytes_to_slice(key),
cython.address(value)) cython.address(value))
return (exists, None) return (exists, None)
@ -1112,8 +1117,8 @@ cdef class DB(object):
def get_property(self, prop): def get_property(self, prop):
cdef string value cdef string value
if self.db.GetProperty(str_to_slice(prop), cython.address(value)): if self.db.GetProperty(bytes_to_slice(prop), cython.address(value)):
return value return string_to_bytes(value)
else: else:
return None return None
@ -1125,11 +1130,11 @@ cdef class DB(object):
ret = [] ret = []
for ob in metadata: for ob in metadata:
t = {} t = {}
t['name'] = ob.name t['name'] = string_to_path(ob.name)
t['level'] = ob.level t['level'] = ob.level
t['size'] = ob.size t['size'] = ob.size
t['smallestkey'] = ob.smallestkey t['smallestkey'] = string_to_bytes(ob.smallestkey)
t['largestkey'] = ob.largestkey t['largestkey'] = string_to_bytes(ob.largestkey)
t['smallest_seqno'] = ob.smallest_seqno t['smallest_seqno'] = ob.smallest_seqno
t['largest_seqno'] = ob.largest_seqno t['largest_seqno'] = ob.largest_seqno
@ -1216,7 +1221,7 @@ cdef class BaseIterator(object):
self.ptr.SeekToLast() self.ptr.SeekToLast()
cpdef seek(self, key): cpdef seek(self, key):
self.ptr.Seek(str_to_slice(key)) self.ptr.Seek(bytes_to_slice(key))
cdef object get_ob(self): cdef object get_ob(self):
return None return None
@ -1224,17 +1229,17 @@ cdef class BaseIterator(object):
@cython.internal @cython.internal
cdef class KeysIterator(BaseIterator): cdef class KeysIterator(BaseIterator):
cdef object get_ob(self): cdef object get_ob(self):
return slice_to_str(self.ptr.key()) return slice_to_bytes(self.ptr.key())
@cython.internal @cython.internal
cdef class ValuesIterator(BaseIterator): cdef class ValuesIterator(BaseIterator):
cdef object get_ob(self): cdef object get_ob(self):
return slice_to_str(self.ptr.value()) return slice_to_bytes(self.ptr.value())
@cython.internal @cython.internal
cdef class ItemsIterator(BaseIterator): cdef class ItemsIterator(BaseIterator):
cdef object get_ob(self): cdef object get_ob(self):
return (slice_to_str(self.ptr.key()), slice_to_str(self.ptr.value())) return (slice_to_bytes(self.ptr.key()), slice_to_bytes(self.ptr.value()))
@cython.internal @cython.internal
cdef class ReversedIterator(object): cdef class ReversedIterator(object):

View file

@ -1,8 +1,5 @@
from libcpp.string cimport string from libcpp.string cimport string
from libcpp cimport bool as cpp_bool from libcpp cimport bool as cpp_bool
from cpython.string cimport PyString_Size
from cpython.string cimport PyString_AsString
from cpython.string cimport PyString_FromStringAndSize
cdef extern from "rocksdb/slice.h" namespace "rocksdb": cdef extern from "rocksdb/slice.h" namespace "rocksdb":
cdef cppclass Slice: cdef cppclass Slice:
@ -21,9 +18,3 @@ cdef extern from "rocksdb/slice.h" namespace "rocksdb":
string ToString(cpp_bool) string ToString(cpp_bool)
int compare(const Slice&) int compare(const Slice&)
cpp_bool starts_with(const Slice&) cpp_bool starts_with(const Slice&)
cdef inline Slice str_to_slice(str ob):
return Slice(PyString_AsString(ob), PyString_Size(ob))
cdef inline str slice_to_str(Slice ob):
return PyString_FromStringAndSize(ob.data(), ob.size())

View file

@ -4,6 +4,8 @@ import gc
import unittest import unittest
import rocksdb import rocksdb
def int_to_bytes(ob):
return str(ob).encode('ascii')
class TestHelper(object): class TestHelper(object):
def _clean(self): def _clean(self):
@ -31,153 +33,165 @@ class TestDB(unittest.TestCase, TestHelper):
self.assertTrue(os.path.isdir(name)) self.assertTrue(os.path.isdir(name))
def test_get_none(self): def test_get_none(self):
self.assertIsNone(self.db.get('xxx')) self.assertIsNone(self.db.get(b'xxx'))
def test_put_get(self): def test_put_get(self):
self.db.put("a", "b") self.db.put(b"a", b"b")
self.assertEqual("b", self.db.get("a")) self.assertEqual(b"b", self.db.get(b"a"))
def test_multi_get(self): def test_multi_get(self):
self.db.put("a", "1") self.db.put(b"a", b"1")
self.db.put("b", "2") self.db.put(b"b", b"2")
self.db.put("c", "3") self.db.put(b"c", b"3")
ret = self.db.multi_get(['a', 'b', 'c']) ret = self.db.multi_get([b'a', b'b', b'c'])
ref = {'a': '1', 'c': '3', 'b': '2'} ref = {b'a': b'1', b'c': b'3', b'b': b'2'}
self.assertEqual(ref, ret) self.assertEqual(ref, ret)
def test_delete(self): def test_delete(self):
self.db.put("a", "b") self.db.put(b"a", b"b")
self.assertEqual("b", self.db.get("a")) self.assertEqual(b"b", self.db.get(b"a"))
self.db.delete("a") self.db.delete(b"a")
self.assertIsNone(self.db.get("a")) self.assertIsNone(self.db.get(b"a"))
def test_write_batch(self): def test_write_batch(self):
batch = rocksdb.WriteBatch() batch = rocksdb.WriteBatch()
batch.put("key", "v1") batch.put(b"key", b"v1")
batch.delete("key") batch.delete(b"key")
batch.put("key", "v2") batch.put(b"key", b"v2")
batch.put("key", "v3") batch.put(b"key", b"v3")
batch.put("a", "b") batch.put(b"a", b"b")
self.db.write(batch) self.db.write(batch)
ref = {'a': 'b', 'key': 'v3'} ref = {b'a': b'b', b'key': b'v3'}
ret = self.db.multi_get(['key', 'a']) ret = self.db.multi_get([b'key', b'a'])
self.assertEqual(ref, ret) self.assertEqual(ref, ret)
def test_key_may_exists(self): def test_key_may_exists(self):
self.db.put("a", '1') self.db.put(b"a", b'1')
self.assertEqual((False, None), self.db.key_may_exist("x")) self.assertEqual((False, None), self.db.key_may_exist(b"x"))
self.assertEqual((False, None), self.db.key_may_exist('x', True)) self.assertEqual((False, None), self.db.key_may_exist(b'x', True))
self.assertEqual((True, None), self.db.key_may_exist('a')) self.assertEqual((True, None), self.db.key_may_exist(b'a'))
self.assertEqual((True, '1'), self.db.key_may_exist('a', True)) self.assertEqual((True, b'1'), self.db.key_may_exist(b'a', True))
def test_iter_keys(self): def test_iter_keys(self):
for x in range(300): for x in range(300):
self.db.put(str(x), str(x)) self.db.put(int_to_bytes(x), int_to_bytes(x))
it = self.db.iterkeys() it = self.db.iterkeys()
self.assertEqual([], list(it)) self.assertEqual([], list(it))
it.seek_to_last() it.seek_to_last()
self.assertEqual(['99'], list(it)) self.assertEqual([b'99'], list(it))
ref = sorted([str(x) for x in range(300)]) ref = sorted([int_to_bytes(x) for x in range(300)])
it.seek_to_first() it.seek_to_first()
self.assertEqual(ref, list(it)) self.assertEqual(ref, list(it))
it.seek('90') it.seek(b'90')
ref = ['90', '91', '92', '93', '94', '95', '96', '97', '98', '99'] ref = [
b'90',
b'91',
b'92',
b'93',
b'94',
b'95',
b'96',
b'97',
b'98',
b'99'
]
self.assertEqual(ref, list(it)) self.assertEqual(ref, list(it))
def test_iter_values(self): def test_iter_values(self):
for x in range(300): for x in range(300):
self.db.put(str(x), str(x * 1000)) self.db.put(int_to_bytes(x), int_to_bytes(x * 1000))
it = self.db.itervalues() it = self.db.itervalues()
self.assertEqual([], list(it)) self.assertEqual([], list(it))
it.seek_to_last() it.seek_to_last()
self.assertEqual(['99000'], list(it)) self.assertEqual([b'99000'], list(it))
ref = sorted([str(x) for x in range(300)]) ref = sorted([int_to_bytes(x) for x in range(300)])
ref = [str(int(x) * 1000) for x in ref] ref = [int_to_bytes(int(x) * 1000) for x in ref]
it.seek_to_first() it.seek_to_first()
self.assertEqual(ref, list(it)) self.assertEqual(ref, list(it))
it.seek('90') it.seek(b'90')
ref = [str(x * 1000) for x in range(90, 100)] ref = [int_to_bytes(x * 1000) for x in range(90, 100)]
self.assertEqual(ref, list(it)) self.assertEqual(ref, list(it))
def test_iter_items(self): def test_iter_items(self):
for x in range(300): for x in range(300):
self.db.put(str(x), str(x * 1000)) self.db.put(int_to_bytes(x), int_to_bytes(x * 1000))
it = self.db.iteritems() it = self.db.iteritems()
self.assertEqual([], list(it)) self.assertEqual([], list(it))
it.seek_to_last() it.seek_to_last()
self.assertEqual([('99', '99000')], list(it)) self.assertEqual([(b'99', b'99000')], list(it))
ref = sorted([str(x) for x in range(300)]) ref = sorted([int_to_bytes(x) for x in range(300)])
ref = [(x, str(int(x) * 1000)) for x in ref] ref = [(x, int_to_bytes(int(x) * 1000)) for x in ref]
it.seek_to_first() it.seek_to_first()
self.assertEqual(ref, list(it)) self.assertEqual(ref, list(it))
it.seek('90') it.seek(b'90')
ref = [(str(x), str(x * 1000)) for x in range(90, 100)] ref = [(int_to_bytes(x), int_to_bytes(x * 1000)) for x in range(90, 100)]
self.assertEqual(ref, list(it)) self.assertEqual(ref, list(it))
def test_reverse_iter(self): def test_reverse_iter(self):
for x in range(100): for x in range(100):
self.db.put(str(x), str(x * 1000)) self.db.put(int_to_bytes(x), int_to_bytes(x * 1000))
it = self.db.iteritems() it = self.db.iteritems()
it.seek_to_last() it.seek_to_last()
ref = reversed(sorted([str(x) for x in range(100)])) ref = reversed(sorted([int_to_bytes(x) for x in range(100)]))
ref = [(x, str(int(x) * 1000)) for x in ref] ref = [(x, int_to_bytes(int(x) * 1000)) for x in ref]
self.assertEqual(ref, list(reversed(it))) self.assertEqual(ref, list(reversed(it)))
def test_snapshot(self): def test_snapshot(self):
self.db.put("a", "1") self.db.put(b"a", b"1")
self.db.put("b", "2") self.db.put(b"b", b"2")
snapshot = self.db.snapshot() snapshot = self.db.snapshot()
self.db.put("a", "2") self.db.put(b"a", b"2")
self.db.delete("b") self.db.delete(b"b")
it = self.db.iteritems() it = self.db.iteritems()
it.seek_to_first() it.seek_to_first()
self.assertEqual({'a': '2'}, dict(it)) self.assertEqual({b'a': b'2'}, dict(it))
it = self.db.iteritems(snapshot=snapshot) it = self.db.iteritems(snapshot=snapshot)
it.seek_to_first() it.seek_to_first()
self.assertEqual({'a': '1', 'b': '2'}, dict(it)) self.assertEqual({b'a': b'1', b'b': b'2'}, dict(it))
def test_get_property(self): def test_get_property(self):
for x in range(300): for x in range(300):
self.db.put(str(x), str(x)) x = int_to_bytes(x)
self.db.put(x, x)
self.assertIsNotNone(self.db.get_property('rocksdb.stats')) self.assertIsNotNone(self.db.get_property(b'rocksdb.stats'))
self.assertIsNotNone(self.db.get_property('rocksdb.sstables')) self.assertIsNotNone(self.db.get_property(b'rocksdb.sstables'))
self.assertIsNotNone(self.db.get_property('rocksdb.num-files-at-level0')) self.assertIsNotNone(self.db.get_property(b'rocksdb.num-files-at-level0'))
self.assertIsNone(self.db.get_property('does not exsits')) self.assertIsNone(self.db.get_property(b'does not exsits'))
class AssocCounter(rocksdb.interfaces.AssociativeMergeOperator): class AssocCounter(rocksdb.interfaces.AssociativeMergeOperator):
def merge(self, key, existing_value, value): def merge(self, key, existing_value, value):
if existing_value: if existing_value:
return (True, str(int(existing_value) + int(value))) return (True, int_to_bytes(int(existing_value) + int(value)))
return (True, value) return (True, value)
def name(self): def name(self):
return 'AssocCounter' return b'AssocCounter'
class TestAssocMerge(unittest.TestCase, TestHelper): class TestAssocMerge(unittest.TestCase, TestHelper):
@ -193,23 +207,23 @@ class TestAssocMerge(unittest.TestCase, TestHelper):
def test_merge(self): def test_merge(self):
for x in range(1000): for x in range(1000):
self.db.merge("a", str(x)) self.db.merge(b"a", int_to_bytes(x))
self.assertEqual(str(sum(range(1000))), self.db.get('a')) self.assertEqual(sum(range(1000)), int(self.db.get(b'a')))
class FullCounter(rocksdb.interfaces.MergeOperator): class FullCounter(rocksdb.interfaces.MergeOperator):
def name(self): def name(self):
return 'fullcounter' return b'fullcounter'
def full_merge(self, key, existing_value, operand_list): def full_merge(self, key, existing_value, operand_list):
ret = sum([int(x) for x in operand_list]) ret = sum([int(x) for x in operand_list])
if existing_value: if existing_value:
ret += int(existing_value) ret += int(existing_value)
return (True, str(ret)) return (True, int_to_bytes(ret))
def partial_merge(self, key, left, right): def partial_merge(self, key, left, right):
return (True, str(int(left) + int(right))) return (True, int_to_bytes(int(left) + int(right)))
class TestFullMerge(unittest.TestCase, TestHelper): class TestFullMerge(unittest.TestCase, TestHelper):
@ -225,13 +239,13 @@ class TestFullMerge(unittest.TestCase, TestHelper):
def test_merge(self): def test_merge(self):
for x in range(1000): for x in range(1000):
self.db.merge("a", str(x)) self.db.merge(b"a", int_to_bytes(x))
self.assertEqual(str(sum(range(1000))), self.db.get('a')) self.assertEqual(sum(range(1000)), int(self.db.get(b'a')))
class SimpleComparator(rocksdb.interfaces.Comparator): class SimpleComparator(rocksdb.interfaces.Comparator):
def name(self): def name(self):
return 'mycompare' return b'mycompare'
def compare(self, a, b): def compare(self, a, b):
a = int(a) a = int(a)
@ -257,6 +271,6 @@ class TestComparator(unittest.TestCase, TestHelper):
def test_compare(self): def test_compare(self):
for x in range(1000): for x in range(1000):
self.db.put(str(x), str(x)) self.db.put(int_to_bytes(x), int_to_bytes(x))
self.assertEqual('300', self.db.get('300')) self.assertEqual(b'300', self.db.get(b'300'))

View file

@ -3,13 +3,13 @@ import rocksdb
class TestFilterPolicy(rocksdb.interfaces.FilterPolicy): class TestFilterPolicy(rocksdb.interfaces.FilterPolicy):
def create_filter(self, keys): def create_filter(self, keys):
return 'nix' return b'nix'
def key_may_match(self, key, fil): def key_may_match(self, key, fil):
return True return True
def name(self): def name(self):
return 'testfilter' return b'testfilter'
class TestMergeOperator(rocksdb.interfaces.MergeOperator): class TestMergeOperator(rocksdb.interfaces.MergeOperator):
def full_merge(self, *args, **kwargs): def full_merge(self, *args, **kwargs):
@ -19,7 +19,7 @@ class TestMergeOperator(rocksdb.interfaces.MergeOperator):
return (False, None) return (False, None)
def name(self): def name(self):
return 'testmergeop' return b'testmergeop'
class TestOptions(unittest.TestCase): class TestOptions(unittest.TestCase):
def test_simple(self): def test_simple(self):