From 3a108d447554e7a1346214193952a3d96d1ed4f0 Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Sun, 16 Jan 2022 14:04:30 -0500 Subject: [PATCH] update column family param type for iterator to match the rest of the api --- rocksdb/_rocksdb.pyx | 12 +++++------- tests/test_db.py | 25 +++++++++++++------------ 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/rocksdb/_rocksdb.pyx b/rocksdb/_rocksdb.pyx index 091fc6d..8f87daa 100644 --- a/rocksdb/_rocksdb.pyx +++ b/rocksdb/_rocksdb.pyx @@ -1949,7 +1949,7 @@ cdef class DB(object): st = self.db.Write(opts, batch.batch) check_status(st) - def iterator(self, start: bytes, column_family: bytes = None, iterate_lower_bound: bytes = None, + def iterator(self, start: bytes, column_family: ColumnFamilyHandle = None, iterate_lower_bound: bytes = None, iterate_upper_bound: bytes = None, reverse: bool = False, include_key: bool = True, include_value: bool = True, fill_cache: bool = True, prefix_same_as_start: bool = False, auto_prefix_mode: bool = False): @@ -1957,7 +1957,7 @@ cdef class DB(object): RocksDB Iterator Args: - column_family (bytes): the name of the column family + column_family (ColumnFamilyHandle): column family handle start (bytes): prefix to seek to iterate_lower_bound (bytes): defines the smallest key at which the backward iterator can return an entry. Once the bound is passed, Valid() will be false. `iterate_lower_bound` is @@ -2003,23 +2003,21 @@ cdef class DB(object): The iterator supports being `reversed` """ - cf = self.get_column_family(column_family) - if not include_value: iterator = self.iterkeys( - column_family=cf, fill_cache=fill_cache, prefix_same_as_start=prefix_same_as_start, + column_family=column_family, fill_cache=fill_cache, prefix_same_as_start=prefix_same_as_start, iterate_lower_bound=iterate_lower_bound, iterate_upper_bound=iterate_upper_bound, auto_prefix_mode=auto_prefix_mode ) elif not include_key: iterator = self.itervalues( - column_family=cf, fill_cache=fill_cache, prefix_same_as_start=prefix_same_as_start, + column_family=column_family, fill_cache=fill_cache, prefix_same_as_start=prefix_same_as_start, iterate_lower_bound=iterate_lower_bound, iterate_upper_bound=iterate_upper_bound, auto_prefix_mode=auto_prefix_mode ) else: iterator = self.iteritems( - column_family=cf, fill_cache=fill_cache, prefix_same_as_start=prefix_same_as_start, + column_family=column_family, fill_cache=fill_cache, prefix_same_as_start=prefix_same_as_start, iterate_lower_bound=iterate_lower_bound, iterate_upper_bound=iterate_upper_bound, auto_prefix_mode=auto_prefix_mode ) diff --git a/tests/test_db.py b/tests/test_db.py index a5de895..35ab4c4 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -723,8 +723,9 @@ class TestDBColumnFamilies(TestHelper): b'A': rocksdb.ColumnFamilyOptions(), b'B': rocksdb.ColumnFamilyOptions() } - secondary = rocksdb.get_db_with_options( - os.path.join(self.db_loc, "test"), create_if_missing=True, max_open_files=-1, + secondary = rocksdb.DB( + os.path.join(self.db_loc, "test"), + rocksdb.Options(create_if_missing=True, max_open_files=-1), secondary_name=secondary_location, column_families=cf ) self.addCleanup(secondary.close) @@ -781,11 +782,11 @@ class TestPrefixIterator(TestHelper): self.assertListEqual( [(b'a0', b'a0_value'), (b'a1', b'a1_value'), (b'a1b', b'a1b_value'), (b'a2b', b'a2b_value'), (b'a3', b'a3_value'), (b'a4', b'a4_value')], - list(self.db.iterator(start=b'a', iterate_upper_bound=b'b', prefix_same_as_start=True)) + list(self.db.iterator(start=b'a', iterate_upper_bound=b'b')) ) self.assertListEqual( [b'a0', b'a1', b'a1b', b'a2b', b'a3', b'a4'], - list(self.db.iterator(start=b'a', iterate_upper_bound=b'b', prefix_same_as_start=True, include_value=False)) + list(self.db.iterator(start=b'a', iterate_upper_bound=b'b', include_value=False)) ) self.assertListEqual( [b'a0', b'a1', b'a1b', b'a2b', b'a3', b'a4'], @@ -885,37 +886,37 @@ class TestPrefixIteratorWithExtractor(TestHelper): self.assertListEqual( [(b'a0', b'a0_value'), (b'a1', b'a1_value'), (b'a1b', b'a1b_value'), (b'a2b', b'a2b_value'), (b'a3', b'a3_value'), (b'a4', b'a4_value')], - list(map(lambda x: (x[0][-1], x[1]), self.db.iterator(column_family=b'first', start=b'a', prefix_same_as_start=True))) + list(map(lambda x: (x[0][-1], x[1]), self.db.iterator(column_family=cf_a, start=b'a', prefix_same_as_start=True))) ) self.assertListEqual( [b'a0', b'a1', b'a1b', b'a2b', b'a3', b'a4'], - list(map(lambda x: x[-1], self.db.iterator(column_family=b'first', start=b'a', include_value=False, prefix_same_as_start=True))) + list(map(lambda x: x[-1], self.db.iterator(column_family=cf_a, start=b'a', include_value=False, prefix_same_as_start=True))) ) self.assertListEqual( [b'a0', b'a1', b'a1b', b'a2b', b'a3', b'a4'], - list(map(lambda x: x[-1], self.db.iterator(column_family=b'first', start=b'a0', iterate_upper_bound=b'a5', include_value=False))) + list(map(lambda x: x[-1], self.db.iterator(column_family=cf_a, start=b'a0', iterate_upper_bound=b'a5', include_value=False))) ) self.assertListEqual( [b'a4', b'a3', b'a2b', b'a1b', b'a1', b'a0'], list(map(lambda x: x[-1], reversed(self.db.iterator( - column_family=b'first', start=b'a0', iterate_upper_bound=b'a5', include_value=False + column_family=cf_a, start=b'a0', iterate_upper_bound=b'a5', include_value=False )))) ) self.assertListEqual( [b'a0', b'a1', b'a1b', b'a2b', b'a3'], - list(map(lambda x: x[-1], self.db.iterator(column_family=b'first', start=b'a0', iterate_upper_bound=b'a4', include_value=False))) + list(map(lambda x: x[-1], self.db.iterator(column_family=cf_a, start=b'a0', iterate_upper_bound=b'a4', include_value=False))) ) self.assertListEqual( [b'a0', b'a1', b'a1b'], - list(map(lambda x: x[-1], self.db.iterator(column_family=b'first', start=b'a0', iterate_upper_bound=b'a2', include_value=False))) + list(map(lambda x: x[-1], self.db.iterator(column_family=cf_a, start=b'a0', iterate_upper_bound=b'a2', include_value=False))) ) self.assertListEqual( [b'a1b', b'a1', b'a0'], list(map(lambda x: x[-1], reversed( - self.db.iterator(column_family=b'first', start=b'a0', iterate_upper_bound=b'a2', include_value=False)))) + self.db.iterator(column_family=cf_a, start=b'a0', iterate_upper_bound=b'a2', include_value=False)))) ) self.assertListEqual( [b'b0'], - list(map(lambda x: x[-1], self.db.iterator(column_family=b'second', start=b'b', include_value=False))) + list(map(lambda x: x[-1], self.db.iterator(column_family=cf_b, start=b'b', include_value=False))) )