diff --git a/db/db.go b/db/db.go index 5ac31e3..d546cd5 100644 --- a/db/db.go +++ b/db/db.go @@ -48,10 +48,10 @@ type ReadOnlyDBColumnFamily struct { DB *grocksdb.DB Handles map[string]*grocksdb.ColumnFamilyHandle Opts *grocksdb.ReadOptions - TxCounts *stack.SliceBacked + TxCounts *stack.SliceBacked[uint32] Height uint32 LastState *prefixes.DBStateValue - Headers *stack.SliceBacked + Headers *stack.SliceBacked[[]byte] BlockingChannelHashes [][]byte FilteringChannelHashes [][]byte BlockedStreams map[string][]byte @@ -556,7 +556,7 @@ func (db *ReadOnlyDBColumnFamily) Advance(height uint32) { } txCount := txCountObj.TxCount - if db.TxCounts.GetTip().(uint32) >= txCount { + if db.TxCounts.GetTip() >= txCount { log.Error("current tip should be less than new txCount", "tx count tip:", db.TxCounts.GetTip(), "tx count:", txCount) } @@ -636,11 +636,10 @@ func (db *ReadOnlyDBColumnFamily) detectChanges(notifCh chan *internal.HeightHas if err != nil { return err } - curHeaderObj := db.Headers.GetTip() - if curHeaderObj == nil { + curHeader := db.Headers.GetTip() + if curHeader == nil { break } - curHeader := curHeaderObj.([]byte) log.Debugln("lastHeightHeader: ", hex.EncodeToString(lastHeightHeader)) log.Debugln("curHeader: ", hex.EncodeToString(curHeader)) if bytes.Equal(curHeader, lastHeightHeader) { @@ -716,7 +715,7 @@ func (db *ReadOnlyDBColumnFamily) InitHeaders() error { } //TODO: figure out a reasonable default and make it a constant - db.Headers = stack.NewSliceBacked(12000) + db.Headers = stack.NewSliceBacked[[]byte](12000) startKey := prefixes.NewHeaderKey(0) // endKey := prefixes.NewHeaderKey(db.LastState.Height) @@ -743,7 +742,7 @@ func (db *ReadOnlyDBColumnFamily) InitTxCounts() error { return err } - db.TxCounts = stack.NewSliceBacked(InitialTxCountSize) + db.TxCounts = stack.NewSliceBacked[uint32](InitialTxCountSize) options := NewIterateOptions().WithPrefix([]byte{prefixes.TxCount}).WithCfHandle(handle) options = options.WithIncludeKey(false).WithIncludeValue(true).WithIncludeStop(true) diff --git a/db/db_resolve.go b/db/db_resolve.go index 33970fc..9020791 100644 --- a/db/db_resolve.go +++ b/db/db_resolve.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/lbryio/herald.go/db/prefixes" + "github.com/lbryio/herald.go/db/stack" "github.com/lbryio/herald.go/internal" pb "github.com/lbryio/herald.go/protobuf/go" lbryurl "github.com/lbryio/lbry.go/v3/url" @@ -40,7 +41,8 @@ func PrepareResolveResult( return nil, err } - height, createdHeight := db.TxCounts.TxCountsBisectRight(txNum, rootTxNum) + heights := stack.BisectRight(db.TxCounts, []uint32{txNum, rootTxNum}) + height, createdHeight := heights[0], heights[1] lastTakeoverHeight := controllingClaim.Height expirationHeight := GetExpirationHeight(height) @@ -86,7 +88,7 @@ func PrepareResolveResult( return nil, err } repostTxPostition = repostTxo.Position - repostHeight, _ = db.TxCounts.TxCountsBisectRight(repostTxo.TxNum, rootTxNum) + repostHeight = stack.BisectRight(db.TxCounts, []uint32{repostTxo.TxNum})[0] } } @@ -122,7 +124,7 @@ func PrepareResolveResult( return nil, err } channelTxPostition = channelVals.Position - channelHeight, _ = db.TxCounts.TxCountsBisectRight(channelVals.TxNum, rootTxNum) + channelHeight = stack.BisectRight(db.TxCounts, []uint32{channelVals.TxNum})[0] } } diff --git a/db/stack/stack.go b/db/stack/stack.go index 2cfab34..d69475e 100644 --- a/db/stack/stack.go +++ b/db/stack/stack.go @@ -7,23 +7,24 @@ import ( "sync" "github.com/lbryio/herald.go/internal" + "golang.org/x/exp/constraints" ) -type SliceBacked struct { - slice []interface{} +type SliceBacked[T any] struct { + slice []T len uint32 mut sync.RWMutex } -func NewSliceBacked(size int) *SliceBacked { - return &SliceBacked{ - slice: make([]interface{}, size), +func NewSliceBacked[T any](size int) *SliceBacked[T] { + return &SliceBacked[T]{ + slice: make([]T, size), len: 0, mut: sync.RWMutex{}, } } -func (s *SliceBacked) Push(v interface{}) { +func (s *SliceBacked[T]) Push(v T) { s.mut.Lock() defer s.mut.Unlock() @@ -35,64 +36,67 @@ func (s *SliceBacked) Push(v interface{}) { s.len++ } -func (s *SliceBacked) Pop() interface{} { +func (s *SliceBacked[T]) Pop() T { s.mut.Lock() defer s.mut.Unlock() if s.len == 0 { - return nil + var null T + return null } s.len-- return s.slice[s.len] } -func (s *SliceBacked) Get(i uint32) interface{} { +func (s *SliceBacked[T]) Get(i uint32) T { s.mut.RLock() defer s.mut.RUnlock() if i >= s.len { - return nil + var null T + return null } return s.slice[i] } -func (s *SliceBacked) GetTip() interface{} { +func (s *SliceBacked[T]) GetTip() T { s.mut.RLock() defer s.mut.RUnlock() if s.len == 0 { - return nil + var null T + return null } return s.slice[s.len-1] } -func (s *SliceBacked) Len() uint32 { +func (s *SliceBacked[T]) Len() uint32 { s.mut.RLock() defer s.mut.RUnlock() return s.len } -func (s *SliceBacked) Cap() int { +func (s *SliceBacked[T]) Cap() int { s.mut.RLock() defer s.mut.RUnlock() return cap(s.slice) } -func (s *SliceBacked) GetSlice() []interface{} { +func (s *SliceBacked[T]) GetSlice() []T { // This is not thread safe so I won't bother with locking return s.slice } -// This function is dangerous because it assumes underlying types -func (s *SliceBacked) TxCountsBisectRight(txNum, rootTxNum uint32) (uint32, uint32) { +func BisectRight[T constraints.Ordered](s *SliceBacked[T], searchKeys []T) []uint32 { s.mut.RLock() defer s.mut.RUnlock() - txCounts := s.slice[:s.Len()] - height := internal.BisectRight(txCounts, txNum) - createdHeight := internal.BisectRight(txCounts, rootTxNum) + found := make([]uint32, len(searchKeys)) + for i, k := range searchKeys { + found[i] = internal.BisectRight(s.slice[:s.Len()], k) + } - return height, createdHeight + return found } diff --git a/db/stack/stack_test.go b/db/stack/stack_test.go index 05709e0..87cdfc2 100644 --- a/db/stack/stack_test.go +++ b/db/stack/stack_test.go @@ -10,7 +10,7 @@ import ( func TestPush(t *testing.T) { var want uint32 = 3 - stack := stack.NewSliceBacked(10) + stack := stack.NewSliceBacked[int](10) stack.Push(0) stack.Push(1) @@ -22,7 +22,7 @@ func TestPush(t *testing.T) { } func TestPushPop(t *testing.T) { - stack := stack.NewSliceBacked(10) + stack := stack.NewSliceBacked[int](10) for i := 0; i < 5; i++ { stack.Push(i) @@ -46,20 +46,20 @@ func TestPushPop(t *testing.T) { } } -func doPushes(stack *stack.SliceBacked, numPushes int) { +func doPushes(stack *stack.SliceBacked[int], numPushes int) { for i := 0; i < numPushes; i++ { stack.Push(i) } } -func doPops(stack *stack.SliceBacked, numPops int) { +func doPops(stack *stack.SliceBacked[int], numPops int) { for i := 0; i < numPops; i++ { stack.Pop() } } func TestMultiThreaded(t *testing.T) { - stack := stack.NewSliceBacked(100000) + stack := stack.NewSliceBacked[int](100000) go doPushes(stack, 100000) go doPushes(stack, 100000) @@ -83,7 +83,7 @@ func TestMultiThreaded(t *testing.T) { } func TestGet(t *testing.T) { - stack := stack.NewSliceBacked(10) + stack := stack.NewSliceBacked[int](10) for i := 0; i < 5; i++ { stack.Push(i) @@ -99,6 +99,10 @@ func TestGet(t *testing.T) { } } + if got := stack.Get(5); got != 0 { + t.Errorf("got %v, want %v", got, 0) + } + slice := stack.GetSlice() if len(slice) != 10 { @@ -107,7 +111,7 @@ func TestGet(t *testing.T) { } func TestLenCap(t *testing.T) { - stack := stack.NewSliceBacked(10) + stack := stack.NewSliceBacked[int](10) if got := stack.Len(); got != 0 { t.Errorf("got %v, want %v", got, 0) diff --git a/internal/search.go b/internal/search.go index d0cfeb9..ff46114 100644 --- a/internal/search.go +++ b/internal/search.go @@ -1,10 +1,14 @@ package internal -import "sort" +import ( + "sort" + + "golang.org/x/exp/constraints" +) // BisectRight returns the index of the first element in the list that is greater than or equal to the value. // https://stackoverflow.com/questions/29959506/is-there-a-go-analog-of-pythons-bisect-module -func BisectRight(arr []interface{}, val uint32) uint32 { - i := sort.Search(len(arr), func(i int) bool { return arr[i].(uint32) >= val }) +func BisectRight[T constraints.Ordered](arr []T, val T) uint32 { + i := sort.Search(len(arr), func(i int) bool { return arr[i] >= val }) return uint32(i) }