Separate strictly between bytes/unicode

which makes py3 support possible
This commit is contained in:
hofmockel 2014-01-16 21:32:00 +01:00
parent 6277f9ab5c
commit 7d61b63b79
4 changed files with 159 additions and 149 deletions

View file

@ -4,9 +4,10 @@ from libcpp.deque cimport deque
from libcpp.vector cimport vector
from libcpp cimport bool as cpp_bool
from cython.operator cimport dereference as deref
from cpython.string cimport PyString_AsString
from cpython.string cimport PyString_Size
from cpython.string cimport PyString_FromString
from cpython.bytes cimport PyBytes_AsString
from cpython.bytes cimport PyBytes_Size
from cpython.bytes cimport PyBytes_FromString
from cpython.bytes cimport PyBytes_FromStringAndSize
from cpython.unicode cimport PyUnicode_Decode
from std_memory cimport shared_ptr
@ -21,8 +22,6 @@ cimport snapshot
cimport db
cimport iterator
from slice_ cimport slice_to_str
from slice_ cimport str_to_slice
from status cimport Status
import sys
@ -66,9 +65,18 @@ cdef check_status(const Status& st):
######################################################
cdef string bytes_to_string(bytes path) except *:
cdef string bytes_to_string(path) except *:
return string(PyBytes_AsString(path), PyBytes_Size(path))
cdef string_to_bytes(string ob):
return PyBytes_FromStringAndSize(ob.c_str(), ob.size())
cdef slice_.Slice bytes_to_slice(ob) except *:
return slice_.Slice(PyBytes_AsString(ob), PyBytes_Size(ob))
cdef slice_to_bytes(slice_.Slice sl):
return PyBytes_FromStringAndSize(sl.data(), sl.size())
## only for filsystem paths
cdef string path_to_string(object path) except *:
if isinstance(path, bytes):
@ -80,7 +88,7 @@ cdef string path_to_string(object path) except *:
raise TypeError("Wrong type for path: %s" % path)
cdef object string_to_path(string path):
fs_encoding = sys.getfilesystemencoding()
fs_encoding = sys.getfilesystemencoding().encode('ascii')
return PyUnicode_Decode(path.c_str(), path.size(), fs_encoding, "replace")
## Here comes the stuff for the comparator
@ -105,7 +113,7 @@ cdef class PyGenericComparator(PyComparator):
self.ob = ob
self.comparator_ptr = <comparator.Comparator*>(
new comparator.ComparatorWrapper(
ob.name(),
bytes_to_string(ob.name()),
<void*>ob,
compare_callback))
@ -126,12 +134,12 @@ cdef class PyBytewiseComparator(PyComparator):
self.comparator_ptr = comparator.BytewiseComparator()
def name(self):
return PyString_FromString(self.comparator_ptr.Name())
return PyBytes_FromString(self.comparator_ptr.Name())
def compare(self, str a, str b):
def compare(self, a, b):
return self.comparator_ptr.Compare(
str_to_slice(a),
str_to_slice(b))
bytes_to_slice(a),
bytes_to_slice(b))
cdef object get_ob(self):
return self
@ -144,7 +152,7 @@ cdef int compare_callback(
const slice_.Slice& a,
const slice_.Slice& b) with gil:
return (<object>ctx).compare(slice_to_str(a), slice_to_str(b))
return (<object>ctx).compare(slice_to_bytes(a), slice_to_bytes(b))
BytewiseComparator = PyBytewiseComparator
#########################################
@ -160,6 +168,7 @@ cdef class PyFilterPolicy(object):
cdef const filter_policy.FilterPolicy* get_policy(self):
return NULL
@cython.internal
cdef class PyGenericFilterPolicy(PyFilterPolicy):
cdef filter_policy.FilterPolicy* policy
@ -171,7 +180,7 @@ cdef class PyGenericFilterPolicy(PyFilterPolicy):
self.ob = ob
self.policy = <filter_policy.FilterPolicy*> new filter_policy.FilterPolicyWrapper(
ob.name(),
bytes_to_string(ob.name()),
<void*>ob,
<void*>ob,
create_filter_callback,
@ -192,16 +201,18 @@ cdef void create_filter_callback(
int n,
string* dst) with gil:
cdef string ret = (<object>ctx).create_filter(
[slice_to_str(keys[i]) for i in range(n)])
dst.append(ret)
ret = (<object>ctx).create_filter(
[slice_to_bytes(keys[i]) for i in range(n)])
dst.append(bytes_to_string(ret))
cdef cpp_bool key_may_match_callback(
void* ctx,
const slice_.Slice& key,
const slice_.Slice& filt) with gil:
return (<object>ctx).key_may_match(slice_to_str(key), slice_to_str(filt))
return (<object>ctx).key_may_match(
slice_to_bytes(key),
slice_to_bytes(filt))
@cython.internal
cdef class PyBloomFilterPolicy(PyFilterPolicy):
@ -214,26 +225,26 @@ cdef class PyBloomFilterPolicy(PyFilterPolicy):
del self.policy
def name(self):
return PyString_FromString(self.policy.Name())
return PyBytes_FromString(self.policy.Name())
def create_filter(self, keys):
cdef string dst
cdef vector[slice_.Slice] c_keys
for key in keys:
c_keys.push_back(str_to_slice(key))
c_keys.push_back(bytes_to_slice(key))
self.policy.CreateFilter(
vector_data(c_keys),
c_keys.size(),
cython.address(dst))
return dst
return string_to_bytes(dst)
def key_may_match(self, key, filter_):
return self.policy.KeyMayMatch(
str_to_slice(key),
str_to_slice(filter_))
bytes_to_slice(key),
bytes_to_slice(filter_))
cdef object get_ob(self):
return self
@ -258,7 +269,7 @@ cdef class PyMergeOperator(object):
self.merge_op.reset(
<merge_operator.MergeOperator*>
new merge_operator.AssociativeMergeOperatorWrapper(
ob.name(),
bytes_to_string(ob.name()),
<void*>(ob),
merge_callback))
@ -267,7 +278,7 @@ cdef class PyMergeOperator(object):
self.merge_op.reset(
<merge_operator.MergeOperator*>
new merge_operator.MergeOperatorWrapper(
ob.name(),
bytes_to_string(ob.name()),
<void*>ob,
<void*>ob,
full_merge_callback,
@ -292,18 +303,16 @@ cdef cpp_bool merge_callback(
if existing_value == NULL:
py_existing_value = None
else:
py_existing_value = slice_to_str(deref(existing_value))
py_existing_value = slice_to_bytes(deref(existing_value))
try:
ret = (<object>ctx).merge(
slice_to_str(key),
slice_to_bytes(key),
py_existing_value,
slice_to_str(value))
slice_to_bytes(value))
if ret[0]:
new_value.assign(
PyString_AsString(ret[1]),
PyString_Size(ret[1]))
new_value.assign(bytes_to_string(ret[1]))
return True
return False
@ -319,25 +328,23 @@ cdef cpp_bool full_merge_callback(
void* ctx,
const slice_.Slice& key,
const slice_.Slice* existing_value,
const deque[string]& operand_list,
const deque[string]& op_list,
string* new_value,
logger.Logger* log) with gil:
if existing_value == NULL:
py_existing_value = None
else:
py_existing_value = slice_to_str(deref(existing_value))
py_existing_value = slice_to_bytes(deref(existing_value))
try:
ret = (<object>ctx).full_merge(
slice_to_str(key),
slice_to_bytes(key),
py_existing_value,
[operand_list[i] for i in range(operand_list.size())])
[string_to_bytes(op_list[i]) for i in range(op_list.size())])
if ret[0]:
new_value.assign(
PyString_AsString(ret[1]),
PyString_Size(ret[1]))
new_value.assign(bytes_to_string(ret[1]))
return True
return False
@ -359,14 +366,12 @@ cdef cpp_bool partial_merge_callback(
try:
ret = (<object>ctx).partial_merge(
slice_to_str(key),
slice_to_str(left_op),
slice_to_str(right_op))
slice_to_bytes(key),
slice_to_bytes(left_op),
slice_to_bytes(right_op))
if ret[0]:
new_value.assign(
PyString_AsString(ret[1]),
PyString_Size(ret[1]))
new_value.assign(bytes_to_string(ret[1]))
return True
return False
@ -416,10 +421,10 @@ LRUCache = PyLRUCache
cdef class CompressionType(object):
no_compression = 'no_compression'
snappy_compression = 'snappy_compression'
zlib_compression = 'zlib_compression'
bzip2_compression = 'bzip2_compression'
no_compression = u'no_compression'
snappy_compression = u'snappy_compression'
zlib_compression = u'zlib_compression'
bzip2_compression = u'bzip2_compression'
cdef class Options(object):
cdef options.Options* opts
@ -939,13 +944,13 @@ cdef class WriteBatch(object):
del self.batch
def put(self, key, value):
self.batch.Put(str_to_slice(key), str_to_slice(value))
self.batch.Put(bytes_to_slice(key), bytes_to_slice(value))
def merge(self, key, value):
self.batch.Merge(str_to_slice(key), str_to_slice(value))
self.batch.Merge(bytes_to_slice(key), bytes_to_slice(value))
def delete(self, key):
self.batch.Delete(str_to_slice(key))
self.batch.Delete(bytes_to_slice(key))
def clear(self):
self.batch.Clear()
@ -986,7 +991,7 @@ cdef class DB(object):
opts.disableWAL = disable_wal
check_status(
self.db.Put(opts, str_to_slice(key), str_to_slice(value)))
self.db.Put(opts, bytes_to_slice(key), bytes_to_slice(value)))
def delete(self, key, sync=False, disable_wal=False):
cdef options.WriteOptions opts
@ -994,7 +999,7 @@ cdef class DB(object):
opts.disableWAL = disable_wal
check_status(
self.db.Delete(opts, str_to_slice(key)))
self.db.Delete(opts, bytes_to_slice(key)))
def merge(self, key, value, sync=False, disable_wal=False):
cdef options.WriteOptions opts
@ -1002,7 +1007,7 @@ cdef class DB(object):
opts.disableWAL = disable_wal
check_status(
self.db.Merge(opts, str_to_slice(key), str_to_slice(value)))
self.db.Merge(opts, bytes_to_slice(key), bytes_to_slice(value)))
def write(self, WriteBatch batch, sync=False, disable_wal=False):
cdef options.WriteOptions opts
@ -1018,11 +1023,11 @@ cdef class DB(object):
st = self.db.Get(
self.build_read_opts(self.__parse_read_opts(*args, **kwargs)),
str_to_slice(key),
bytes_to_slice(key),
cython.address(res))
if st.ok():
return res
return string_to_bytes(res)
elif st.IsNotFound():
return None
else:
@ -1034,7 +1039,7 @@ cdef class DB(object):
cdef vector[slice_.Slice] c_keys
for key in keys:
c_keys.push_back(str_to_slice(key))
c_keys.push_back(bytes_to_slice(key))
cdef vector[Status] res = self.db.MultiGet(
self.build_read_opts(self.__parse_read_opts(*args, **kwargs)),
@ -1044,7 +1049,7 @@ cdef class DB(object):
cdef dict ret_dict = {}
for index in range(len(keys)):
if res[index].ok():
ret_dict[keys[index]] = values[index]
ret_dict[keys[index]] = string_to_bytes(values[index])
elif res[index].IsNotFound():
ret_dict[keys[index]] = None
else:
@ -1063,13 +1068,13 @@ cdef class DB(object):
value_found = False
exists = self.db.KeyMayExist(
opts,
str_to_slice(key),
bytes_to_slice(key),
cython.address(value),
cython.address(value_found))
if exists:
if value_found:
return (True, value)
return (True, string_to_bytes(value))
else:
return (True, None)
else:
@ -1077,7 +1082,7 @@ cdef class DB(object):
else:
exists = self.db.KeyMayExist(
opts,
str_to_slice(key),
bytes_to_slice(key),
cython.address(value))
return (exists, None)
@ -1112,8 +1117,8 @@ cdef class DB(object):
def get_property(self, prop):
cdef string value
if self.db.GetProperty(str_to_slice(prop), cython.address(value)):
return value
if self.db.GetProperty(bytes_to_slice(prop), cython.address(value)):
return string_to_bytes(value)
else:
return None
@ -1125,11 +1130,11 @@ cdef class DB(object):
ret = []
for ob in metadata:
t = {}
t['name'] = ob.name
t['name'] = string_to_path(ob.name)
t['level'] = ob.level
t['size'] = ob.size
t['smallestkey'] = ob.smallestkey
t['largestkey'] = ob.largestkey
t['smallestkey'] = string_to_bytes(ob.smallestkey)
t['largestkey'] = string_to_bytes(ob.largestkey)
t['smallest_seqno'] = ob.smallest_seqno
t['largest_seqno'] = ob.largest_seqno
@ -1216,7 +1221,7 @@ cdef class BaseIterator(object):
self.ptr.SeekToLast()
cpdef seek(self, key):
self.ptr.Seek(str_to_slice(key))
self.ptr.Seek(bytes_to_slice(key))
cdef object get_ob(self):
return None
@ -1224,17 +1229,17 @@ cdef class BaseIterator(object):
@cython.internal
cdef class KeysIterator(BaseIterator):
cdef object get_ob(self):
return slice_to_str(self.ptr.key())
return slice_to_bytes(self.ptr.key())
@cython.internal
cdef class ValuesIterator(BaseIterator):
cdef object get_ob(self):
return slice_to_str(self.ptr.value())
return slice_to_bytes(self.ptr.value())
@cython.internal
cdef class ItemsIterator(BaseIterator):
cdef object get_ob(self):
return (slice_to_str(self.ptr.key()), slice_to_str(self.ptr.value()))
return (slice_to_bytes(self.ptr.key()), slice_to_bytes(self.ptr.value()))
@cython.internal
cdef class ReversedIterator(object):

View file

@ -1,8 +1,5 @@
from libcpp.string cimport string
from libcpp cimport bool as cpp_bool
from cpython.string cimport PyString_Size
from cpython.string cimport PyString_AsString
from cpython.string cimport PyString_FromStringAndSize
cdef extern from "rocksdb/slice.h" namespace "rocksdb":
cdef cppclass Slice:
@ -21,9 +18,3 @@ cdef extern from "rocksdb/slice.h" namespace "rocksdb":
string ToString(cpp_bool)
int compare(const Slice&)
cpp_bool starts_with(const Slice&)
cdef inline Slice str_to_slice(str ob):
return Slice(PyString_AsString(ob), PyString_Size(ob))
cdef inline str slice_to_str(Slice ob):
return PyString_FromStringAndSize(ob.data(), ob.size())

View file

@ -4,6 +4,8 @@ import gc
import unittest
import rocksdb
def int_to_bytes(ob):
return str(ob).encode('ascii')
class TestHelper(object):
def _clean(self):
@ -31,153 +33,165 @@ class TestDB(unittest.TestCase, TestHelper):
self.assertTrue(os.path.isdir(name))
def test_get_none(self):
self.assertIsNone(self.db.get('xxx'))
self.assertIsNone(self.db.get(b'xxx'))
def test_put_get(self):
self.db.put("a", "b")
self.assertEqual("b", self.db.get("a"))
self.db.put(b"a", b"b")
self.assertEqual(b"b", self.db.get(b"a"))
def test_multi_get(self):
self.db.put("a", "1")
self.db.put("b", "2")
self.db.put("c", "3")
self.db.put(b"a", b"1")
self.db.put(b"b", b"2")
self.db.put(b"c", b"3")
ret = self.db.multi_get(['a', 'b', 'c'])
ref = {'a': '1', 'c': '3', 'b': '2'}
ret = self.db.multi_get([b'a', b'b', b'c'])
ref = {b'a': b'1', b'c': b'3', b'b': b'2'}
self.assertEqual(ref, ret)
def test_delete(self):
self.db.put("a", "b")
self.assertEqual("b", self.db.get("a"))
self.db.delete("a")
self.assertIsNone(self.db.get("a"))
self.db.put(b"a", b"b")
self.assertEqual(b"b", self.db.get(b"a"))
self.db.delete(b"a")
self.assertIsNone(self.db.get(b"a"))
def test_write_batch(self):
batch = rocksdb.WriteBatch()
batch.put("key", "v1")
batch.delete("key")
batch.put("key", "v2")
batch.put("key", "v3")
batch.put("a", "b")
batch.put(b"key", b"v1")
batch.delete(b"key")
batch.put(b"key", b"v2")
batch.put(b"key", b"v3")
batch.put(b"a", b"b")
self.db.write(batch)
ref = {'a': 'b', 'key': 'v3'}
ret = self.db.multi_get(['key', 'a'])
ref = {b'a': b'b', b'key': b'v3'}
ret = self.db.multi_get([b'key', b'a'])
self.assertEqual(ref, ret)
def test_key_may_exists(self):
self.db.put("a", '1')
self.db.put(b"a", b'1')
self.assertEqual((False, None), self.db.key_may_exist("x"))
self.assertEqual((False, None), self.db.key_may_exist('x', True))
self.assertEqual((True, None), self.db.key_may_exist('a'))
self.assertEqual((True, '1'), self.db.key_may_exist('a', True))
self.assertEqual((False, None), self.db.key_may_exist(b"x"))
self.assertEqual((False, None), self.db.key_may_exist(b'x', True))
self.assertEqual((True, None), self.db.key_may_exist(b'a'))
self.assertEqual((True, b'1'), self.db.key_may_exist(b'a', True))
def test_iter_keys(self):
for x in range(300):
self.db.put(str(x), str(x))
self.db.put(int_to_bytes(x), int_to_bytes(x))
it = self.db.iterkeys()
self.assertEqual([], list(it))
it.seek_to_last()
self.assertEqual(['99'], list(it))
self.assertEqual([b'99'], list(it))
ref = sorted([str(x) for x in range(300)])
ref = sorted([int_to_bytes(x) for x in range(300)])
it.seek_to_first()
self.assertEqual(ref, list(it))
it.seek('90')
ref = ['90', '91', '92', '93', '94', '95', '96', '97', '98', '99']
it.seek(b'90')
ref = [
b'90',
b'91',
b'92',
b'93',
b'94',
b'95',
b'96',
b'97',
b'98',
b'99'
]
self.assertEqual(ref, list(it))
def test_iter_values(self):
for x in range(300):
self.db.put(str(x), str(x * 1000))
self.db.put(int_to_bytes(x), int_to_bytes(x * 1000))
it = self.db.itervalues()
self.assertEqual([], list(it))
it.seek_to_last()
self.assertEqual(['99000'], list(it))
self.assertEqual([b'99000'], list(it))
ref = sorted([str(x) for x in range(300)])
ref = [str(int(x) * 1000) for x in ref]
ref = sorted([int_to_bytes(x) for x in range(300)])
ref = [int_to_bytes(int(x) * 1000) for x in ref]
it.seek_to_first()
self.assertEqual(ref, list(it))
it.seek('90')
ref = [str(x * 1000) for x in range(90, 100)]
it.seek(b'90')
ref = [int_to_bytes(x * 1000) for x in range(90, 100)]
self.assertEqual(ref, list(it))
def test_iter_items(self):
for x in range(300):
self.db.put(str(x), str(x * 1000))
self.db.put(int_to_bytes(x), int_to_bytes(x * 1000))
it = self.db.iteritems()
self.assertEqual([], list(it))
it.seek_to_last()
self.assertEqual([('99', '99000')], list(it))
self.assertEqual([(b'99', b'99000')], list(it))
ref = sorted([str(x) for x in range(300)])
ref = [(x, str(int(x) * 1000)) for x in ref]
ref = sorted([int_to_bytes(x) for x in range(300)])
ref = [(x, int_to_bytes(int(x) * 1000)) for x in ref]
it.seek_to_first()
self.assertEqual(ref, list(it))
it.seek('90')
ref = [(str(x), str(x * 1000)) for x in range(90, 100)]
it.seek(b'90')
ref = [(int_to_bytes(x), int_to_bytes(x * 1000)) for x in range(90, 100)]
self.assertEqual(ref, list(it))
def test_reverse_iter(self):
for x in range(100):
self.db.put(str(x), str(x * 1000))
self.db.put(int_to_bytes(x), int_to_bytes(x * 1000))
it = self.db.iteritems()
it.seek_to_last()
ref = reversed(sorted([str(x) for x in range(100)]))
ref = [(x, str(int(x) * 1000)) for x in ref]
ref = reversed(sorted([int_to_bytes(x) for x in range(100)]))
ref = [(x, int_to_bytes(int(x) * 1000)) for x in ref]
self.assertEqual(ref, list(reversed(it)))
def test_snapshot(self):
self.db.put("a", "1")
self.db.put("b", "2")
self.db.put(b"a", b"1")
self.db.put(b"b", b"2")
snapshot = self.db.snapshot()
self.db.put("a", "2")
self.db.delete("b")
self.db.put(b"a", b"2")
self.db.delete(b"b")
it = self.db.iteritems()
it.seek_to_first()
self.assertEqual({'a': '2'}, dict(it))
self.assertEqual({b'a': b'2'}, dict(it))
it = self.db.iteritems(snapshot=snapshot)
it.seek_to_first()
self.assertEqual({'a': '1', 'b': '2'}, dict(it))
self.assertEqual({b'a': b'1', b'b': b'2'}, dict(it))
def test_get_property(self):
for x in range(300):
self.db.put(str(x), str(x))
x = int_to_bytes(x)
self.db.put(x, x)
self.assertIsNotNone(self.db.get_property('rocksdb.stats'))
self.assertIsNotNone(self.db.get_property('rocksdb.sstables'))
self.assertIsNotNone(self.db.get_property('rocksdb.num-files-at-level0'))
self.assertIsNone(self.db.get_property('does not exsits'))
self.assertIsNotNone(self.db.get_property(b'rocksdb.stats'))
self.assertIsNotNone(self.db.get_property(b'rocksdb.sstables'))
self.assertIsNotNone(self.db.get_property(b'rocksdb.num-files-at-level0'))
self.assertIsNone(self.db.get_property(b'does not exsits'))
class AssocCounter(rocksdb.interfaces.AssociativeMergeOperator):
def merge(self, key, existing_value, value):
if existing_value:
return (True, str(int(existing_value) + int(value)))
return (True, int_to_bytes(int(existing_value) + int(value)))
return (True, value)
def name(self):
return 'AssocCounter'
return b'AssocCounter'
class TestAssocMerge(unittest.TestCase, TestHelper):
@ -193,23 +207,23 @@ class TestAssocMerge(unittest.TestCase, TestHelper):
def test_merge(self):
for x in range(1000):
self.db.merge("a", str(x))
self.assertEqual(str(sum(range(1000))), self.db.get('a'))
self.db.merge(b"a", int_to_bytes(x))
self.assertEqual(sum(range(1000)), int(self.db.get(b'a')))
class FullCounter(rocksdb.interfaces.MergeOperator):
def name(self):
return 'fullcounter'
return b'fullcounter'
def full_merge(self, key, existing_value, operand_list):
ret = sum([int(x) for x in operand_list])
if existing_value:
ret += int(existing_value)
return (True, str(ret))
return (True, int_to_bytes(ret))
def partial_merge(self, key, left, right):
return (True, str(int(left) + int(right)))
return (True, int_to_bytes(int(left) + int(right)))
class TestFullMerge(unittest.TestCase, TestHelper):
@ -225,13 +239,13 @@ class TestFullMerge(unittest.TestCase, TestHelper):
def test_merge(self):
for x in range(1000):
self.db.merge("a", str(x))
self.assertEqual(str(sum(range(1000))), self.db.get('a'))
self.db.merge(b"a", int_to_bytes(x))
self.assertEqual(sum(range(1000)), int(self.db.get(b'a')))
class SimpleComparator(rocksdb.interfaces.Comparator):
def name(self):
return 'mycompare'
return b'mycompare'
def compare(self, a, b):
a = int(a)
@ -257,6 +271,6 @@ class TestComparator(unittest.TestCase, TestHelper):
def test_compare(self):
for x in range(1000):
self.db.put(str(x), str(x))
self.db.put(int_to_bytes(x), int_to_bytes(x))
self.assertEqual('300', self.db.get('300'))
self.assertEqual(b'300', self.db.get(b'300'))

View file

@ -3,13 +3,13 @@ import rocksdb
class TestFilterPolicy(rocksdb.interfaces.FilterPolicy):
def create_filter(self, keys):
return 'nix'
return b'nix'
def key_may_match(self, key, fil):
return True
def name(self):
return 'testfilter'
return b'testfilter'
class TestMergeOperator(rocksdb.interfaces.MergeOperator):
def full_merge(self, *args, **kwargs):
@ -19,7 +19,7 @@ class TestMergeOperator(rocksdb.interfaces.MergeOperator):
return (False, None)
def name(self):
return 'testmergeop'
return b'testmergeop'
class TestOptions(unittest.TestCase):
def test_simple(self):