diff --git a/docs/api/database.rst b/docs/api/database.rst index 2ebab20..1e21ebf 100644 --- a/docs/api/database.rst +++ b/docs/api/database.rst @@ -8,7 +8,7 @@ Database object .. py:method:: __init__(db_name, Options opts, read_only=False) - :param string db_name: Name of the database to open + :param unicode db_name: Name of the database to open :param opts: Options for this specific database :type opts: :py:class:`rocksdb.Options` :param bool read_only: If ``True`` the database is opened read-only. diff --git a/docs/api/options.rst b/docs/api/options.rst index 14d3579..4593250 100644 --- a/docs/api/options.rst +++ b/docs/api/options.rst @@ -329,7 +329,7 @@ Options object and the db data dir's absolute path will be used as the log file name's prefix. - | *Type:* ``string`` + | *Type:* ``unicode`` | *Default:* ``""`` .. py:attribute:: wal_dir @@ -340,7 +340,7 @@ Options object If it is non empty, the log files will be in kept the specified dir. When destroying the db, all log files in wal_dir and the dir itself is deleted - | *Type:* ``string`` + | *Type:* ``unicode`` | *Default:* ``""`` .. py:attribute:: disable_seek_compaction diff --git a/rocksdb/_rocksdb.pyx b/rocksdb/_rocksdb.pyx index e357b66..d9f6541 100644 --- a/rocksdb/_rocksdb.pyx +++ b/rocksdb/_rocksdb.pyx @@ -7,6 +7,7 @@ from cython.operator cimport dereference as deref from cpython.string cimport PyString_AsString from cpython.string cimport PyString_Size from cpython.string cimport PyString_FromString +from cpython.unicode cimport PyUnicode_Decode from std_memory cimport shared_ptr cimport options @@ -24,6 +25,7 @@ from slice_ cimport slice_to_str from slice_ cimport str_to_slice from status cimport Status +import sys from interfaces import MergeOperator as IMergeOperator from interfaces import AssociativeMergeOperator as IAssociativeMergeOperator from interfaces import FilterPolicy as IFilterPolicy @@ -64,6 +66,23 @@ cdef check_status(const Status& st): ###################################################### +cdef string bytes_to_string(bytes path) except *: + return string(PyBytes_AsString(path), PyBytes_Size(path)) + +## only for filsystem paths +cdef string path_to_string(object path) except *: + if isinstance(path, bytes): + return bytes_to_string(path) + if isinstance(path, unicode): + path = path.encode(sys.getfilesystemencoding()) + return bytes_to_string(path) + else: + raise TypeError("Wrong type for path: %s" % path) + +cdef object string_to_path(string path): + fs_encoding = sys.getfilesystemencoding() + return PyUnicode_Decode(path.c_str(), path.size(), fs_encoding, "replace") + ## Here comes the stuff for the comparator @cython.internal cdef class PyComparator(object): @@ -609,15 +628,15 @@ cdef class Options(object): property db_log_dir: def __get__(self): - return self.opts.db_log_dir + return string_to_path(self.opts.db_log_dir) def __set__(self, value): - self.opts.db_log_dir = value + self.opts.db_log_dir = path_to_string(value) property wal_dir: def __get__(self): - return self.opts.wal_dir + return string_to_path(self.opts.wal_dir) def __set__(self, value): - self.opts.wal_dir = value + self.opts.wal_dir = path_to_string(value) property disable_seek_compaction: def __get__(self): @@ -946,14 +965,14 @@ cdef class DB(object): check_status( db.DB_OpenForReadOnly( deref(opts.opts), - db_name, + path_to_string(db_name), cython.address(self.db), False)) else: check_status( db.DB_Open( deref(opts.opts), - db_name, + path_to_string(db_name), cython.address(self.db))) self.opts = opts diff --git a/rocksdb/tests/test_db.py b/rocksdb/tests/test_db.py index 202cc31..0348c25 100644 --- a/rocksdb/tests/test_db.py +++ b/rocksdb/tests/test_db.py @@ -24,6 +24,12 @@ class TestDB(unittest.TestCase, TestHelper): def tearDown(self): self._close_db() + def test_unicode_path(self): + name = b'/tmp/M\xc3\xbcnchen'.decode('utf8') + rocksdb.DB(name, rocksdb.Options(create_if_missing=True)) + self.addCleanup(shutil.rmtree, name) + self.assertTrue(os.path.isdir(name)) + def test_get_none(self): self.assertIsNone(self.db.get('xxx')) diff --git a/rocksdb/tests/test_options.py b/rocksdb/tests/test_options.py index 599f3f3..a9b1851 100644 --- a/rocksdb/tests/test_options.py +++ b/rocksdb/tests/test_options.py @@ -52,3 +52,12 @@ class TestOptions(unittest.TestCase): ob = rocksdb.LRUCache(100) opts.block_cache = ob self.assertEqual(ob, opts.block_cache) + + def test_unicode_path(self): + name = b'/tmp/M\xc3\xbcnchen'.decode('utf8') + opts = rocksdb.Options() + opts.db_log_dir = name + opts.wal_dir = name + + self.assertEqual(name, opts.db_log_dir) + self.assertEqual(name, opts.wal_dir)