pkg/prand: replace with pkg/xorshift
This commit is contained in:
parent
6c3ddaefb3
commit
03b98e0090
5 changed files with 134 additions and 110 deletions
|
@ -3,12 +3,13 @@ package varinterval
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/chihaya/chihaya/bittorrent"
|
||||
"github.com/chihaya/chihaya/middleware"
|
||||
"github.com/chihaya/chihaya/pkg/prand"
|
||||
"github.com/chihaya/chihaya/pkg/xorshift"
|
||||
)
|
||||
|
||||
// ErrInvalidModifyResponseProbability is returned for a config with an invalid
|
||||
|
@ -47,7 +48,7 @@ func checkConfig(cfg Config) error {
|
|||
|
||||
type hook struct {
|
||||
cfg Config
|
||||
pr *prand.Container
|
||||
pr [1024]*xorshift.LockedXORShift128Plus
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
|
@ -59,18 +60,25 @@ func New(cfg Config) (middleware.Hook, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
return &hook{
|
||||
h := &hook{
|
||||
cfg: cfg,
|
||||
pr: prand.New(1024),
|
||||
}, nil
|
||||
pr: [1024]*xorshift.LockedXORShift128Plus{},
|
||||
}
|
||||
for i := range h.pr {
|
||||
h.pr[i] = xorshift.NewLockedXORShift128Plus(rand.Uint64(), rand.Uint64())
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func (h *hook) getXORShiftByInfohash(ih *bittorrent.InfoHash) *xorshift.LockedXORShift128Plus {
|
||||
return h.pr[(int(ih[1])|int(ih[0])<<8)%len(h.pr)]
|
||||
}
|
||||
|
||||
func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) (context.Context, error) {
|
||||
r := h.pr.GetByInfohash(req.InfoHash)
|
||||
r := h.getXORShiftByInfohash(&req.InfoHash)
|
||||
|
||||
if h.cfg.ModifyResponseProbability == 1 || r.Float32() < h.cfg.ModifyResponseProbability {
|
||||
addSeconds := time.Duration(r.Intn(h.cfg.MaxIncreaseDelta)+1) * time.Second
|
||||
h.pr.ReturnByInfohash(req.InfoHash)
|
||||
if h.cfg.ModifyResponseProbability == 1 || float32(xorshift.Intn(r, 1<<24))/(1<<24) < h.cfg.ModifyResponseProbability {
|
||||
addSeconds := time.Duration(xorshift.Intn(r, h.cfg.MaxIncreaseDelta)+1) * time.Second
|
||||
|
||||
resp.Interval += addSeconds
|
||||
|
||||
|
@ -81,7 +89,6 @@ func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceReque
|
|||
return ctx, nil
|
||||
}
|
||||
|
||||
h.pr.ReturnByInfohash(req.InfoHash)
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,75 +0,0 @@
|
|||
// Package prand allows parallel access to randomness based on indices or
|
||||
// infohashes.
|
||||
package prand
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/chihaya/chihaya/bittorrent"
|
||||
)
|
||||
|
||||
type lockableRand struct {
|
||||
*rand.Rand
|
||||
*sync.Mutex
|
||||
}
|
||||
|
||||
// Container is a container for sources of random numbers that can be locked
|
||||
// individually.
|
||||
type Container struct {
|
||||
rands []lockableRand
|
||||
}
|
||||
|
||||
// NewSeeded returns a new Container with num sources that are seeeded with
|
||||
// seed.
|
||||
func NewSeeded(num int, seed int64) *Container {
|
||||
toReturn := Container{
|
||||
rands: make([]lockableRand, num),
|
||||
}
|
||||
|
||||
for i := 0; i < num; i++ {
|
||||
toReturn.rands[i].Rand = rand.New(rand.NewSource(seed))
|
||||
toReturn.rands[i].Mutex = &sync.Mutex{}
|
||||
}
|
||||
|
||||
return &toReturn
|
||||
}
|
||||
|
||||
// New returns a new Container with num sources that are seeded with the current
|
||||
// time.
|
||||
func New(num int) *Container {
|
||||
return NewSeeded(num, time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// Get locks and returns the nth source.
|
||||
//
|
||||
// Get panics if n is not a valid index for this Container.
|
||||
func (s *Container) Get(n int) *rand.Rand {
|
||||
r := s.rands[n]
|
||||
r.Lock()
|
||||
return r.Rand
|
||||
}
|
||||
|
||||
// GetByInfohash locks and returns a source derived from the infohash.
|
||||
func (s *Container) GetByInfohash(ih bittorrent.InfoHash) *rand.Rand {
|
||||
u := int(binary.BigEndian.Uint32(ih[:4])) % len(s.rands)
|
||||
return s.Get(u)
|
||||
}
|
||||
|
||||
// Return returns the nth source to be available again.
|
||||
//
|
||||
// Return panics if n is not a valid index for this Container.
|
||||
// Return also panics if the nth source is unlocked already.
|
||||
func (s *Container) Return(n int) {
|
||||
s.rands[n].Unlock()
|
||||
}
|
||||
|
||||
// ReturnByInfohash returns the source derived from the infohash.
|
||||
//
|
||||
// ReturnByInfohash panics if the source is unlocked already.
|
||||
func (s *Container) ReturnByInfohash(ih bittorrent.InfoHash) {
|
||||
u := int(binary.BigEndian.Uint32(ih[:4])) % len(s.rands)
|
||||
s.Return(u)
|
||||
}
|
|
@ -1,25 +0,0 @@
|
|||
package prand
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkContainer_GetReturn(b *testing.B) {
|
||||
c := New(1024)
|
||||
a := uint64(0)
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(p *testing.PB) {
|
||||
i := int(atomic.AddUint64(&a, 1))
|
||||
var r *rand.Rand
|
||||
|
||||
for p.Next() {
|
||||
r = c.Get(i)
|
||||
c.Return(i)
|
||||
}
|
||||
|
||||
_ = r
|
||||
})
|
||||
}
|
71
pkg/xorshift/xorshift.go
Normal file
71
pkg/xorshift/xorshift.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
// Package xorshift implements the XORShift PRNG.
|
||||
package xorshift
|
||||
|
||||
import "sync"
|
||||
|
||||
// XORShift describes the functionality of an XORShift PRNG.
|
||||
type XORShift interface {
|
||||
Next() uint64
|
||||
}
|
||||
|
||||
// XORShift128Plus holds the state of an XORShift128Plus PRNG.
|
||||
type XORShift128Plus struct {
|
||||
state [2]uint64
|
||||
}
|
||||
|
||||
// Next generates a pseudorandom number and advances the state of s.
|
||||
func (s *XORShift128Plus) Next() uint64 {
|
||||
s1 := s.state[0]
|
||||
s0 := s.state[1]
|
||||
s1Tmp := s1 // need this for result computation
|
||||
s.state[0] = s0
|
||||
s1 ^= (s1 << 23) // a
|
||||
s.state[1] = s1 ^ s0 ^ (s1 >> 18) ^ (s0 >> 5) // b, c
|
||||
return s0 + s1Tmp
|
||||
}
|
||||
|
||||
// NewXORShift128Plus creates a new XORShift PRNG.
|
||||
func NewXORShift128Plus(s0, s1 uint64) *XORShift128Plus {
|
||||
return &XORShift128Plus{
|
||||
state: [2]uint64{s0, s1},
|
||||
}
|
||||
}
|
||||
|
||||
// LockedXORShift128Plus is a thread-safe XORShift128Plus.
|
||||
type LockedXORShift128Plus struct {
|
||||
sync.Mutex
|
||||
state [2]uint64
|
||||
}
|
||||
|
||||
// NewLockedXORShift128Plus creates a new LockedXORShift128Plus.
|
||||
func NewLockedXORShift128Plus(s0, s1 uint64) *LockedXORShift128Plus {
|
||||
return &LockedXORShift128Plus{
|
||||
state: [2]uint64{s0, s1},
|
||||
}
|
||||
}
|
||||
|
||||
// Next generates a pseudorandom number and advances the state of s.
|
||||
func (s *LockedXORShift128Plus) Next() uint64 {
|
||||
s.Lock()
|
||||
s1 := s.state[0]
|
||||
s0 := s.state[1]
|
||||
s1Tmp := s1 // need this for result computation
|
||||
s.state[0] = s0
|
||||
s1 ^= (s1 << 23) // a
|
||||
s.state[1] = s1 ^ s0 ^ (s1 >> 18) ^ (s0 >> 5) // b, c
|
||||
s.Unlock()
|
||||
return s0 + s1Tmp
|
||||
}
|
||||
|
||||
// Intn generates an int k that satisfies k >= 0 && k < n.
|
||||
// n must be > 0.
|
||||
func Intn(s XORShift, n int) int {
|
||||
if n <= 0 {
|
||||
panic("invalid n <= 0")
|
||||
}
|
||||
v := int(s.Next())
|
||||
if v < 0 {
|
||||
v = -v
|
||||
}
|
||||
return v % n
|
||||
}
|
46
pkg/xorshift/xorshift_test.go
Normal file
46
pkg/xorshift/xorshift_test.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package xorshift
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIntn(t *testing.T) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
s := NewXORShift128Plus(rand.Uint64(), rand.Uint64())
|
||||
for i := 0; i < 10000; i++ {
|
||||
k := Intn(s, 10)
|
||||
require.True(t, k >= 0, "Intn() must be >= 0")
|
||||
require.True(t, k < 10, "Intn(k) must be < k")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkXORShift128Plus_Next(b *testing.B) {
|
||||
s := NewXORShift128Plus(rand.Uint64(), rand.Uint64())
|
||||
var k uint64
|
||||
for i := 0; i < b.N; i++ {
|
||||
k = s.Next()
|
||||
}
|
||||
_ = k
|
||||
}
|
||||
|
||||
func BenchmarkIntnXORShift128Plus(b *testing.B) {
|
||||
s := NewXORShift128Plus(rand.Uint64(), rand.Uint64())
|
||||
var k int
|
||||
for i := 0; i < b.N; i++ {
|
||||
k = Intn(s, 1000)
|
||||
}
|
||||
_ = k
|
||||
}
|
||||
|
||||
func BenchmarkLockedXORShift128Plus_Next(b *testing.B) {
|
||||
s := NewLockedXORShift128Plus(rand.Uint64(), rand.Uint64())
|
||||
var k uint64
|
||||
for i := 0; i < b.N; i++ {
|
||||
k = s.Next()
|
||||
}
|
||||
_ = k
|
||||
}
|
Loading…
Reference in a new issue