check headers bounds on access

This commit is contained in:
Victor Shyba 2019-08-16 03:07:46 -03:00
parent 92823dcea9
commit 4c41e7cb29
2 changed files with 15 additions and 2 deletions

View file

@ -94,6 +94,17 @@ class BasicHeadersTests(BitcoinHeadersTestCase):
await headers.connect(len(headers), remainder)
self.assertEqual(headers.height, 32259)
async def test_bounds(self):
headers = MainHeaders(':memory:')
await headers.connect(0, self.get_bytes(block_bytes(3001)))
self.assertEqual(headers.height, 3000)
with self.assertRaises(IndexError):
_ = headers[3001]
with self.assertRaises(IndexError):
_ = headers[-1]
self.assertIsNotNone(headers[3000])
self.assertIsNotNone(headers[0])
async def test_repair(self):
headers = MainHeaders(':memory:')
await headers.connect(0, self.get_bytes(block_bytes(3001)))

View file

@ -70,8 +70,10 @@ class BaseHeaders:
return True
def __getitem__(self, height) -> dict:
assert not isinstance(height, slice), \
"Slicing of header chain has not been implemented yet."
if isinstance(height, slice):
raise NotImplementedError("Slicing of header chain has not been implemented yet.")
if not 0 <= height <= self.height:
raise IndexError(f"{height} is out of bounds, current height: {self.height}")
return self.deserialize(height, self.get_raw_header(height))
def get_raw_header(self, height) -> bytes: