pkg/prand: replace with pkg/xorshift
This commit is contained in:
parent
6c3ddaefb3
commit
03b98e0090
5 changed files with 134 additions and 110 deletions
|
@ -3,12 +3,13 @@ package varinterval
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"math/rand"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/chihaya/chihaya/bittorrent"
|
"github.com/chihaya/chihaya/bittorrent"
|
||||||
"github.com/chihaya/chihaya/middleware"
|
"github.com/chihaya/chihaya/middleware"
|
||||||
"github.com/chihaya/chihaya/pkg/prand"
|
"github.com/chihaya/chihaya/pkg/xorshift"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrInvalidModifyResponseProbability is returned for a config with an invalid
|
// ErrInvalidModifyResponseProbability is returned for a config with an invalid
|
||||||
|
@ -47,7 +48,7 @@ func checkConfig(cfg Config) error {
|
||||||
|
|
||||||
type hook struct {
|
type hook struct {
|
||||||
cfg Config
|
cfg Config
|
||||||
pr *prand.Container
|
pr [1024]*xorshift.LockedXORShift128Plus
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -59,18 +60,25 @@ func New(cfg Config) (middleware.Hook, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &hook{
|
h := &hook{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
pr: prand.New(1024),
|
pr: [1024]*xorshift.LockedXORShift128Plus{},
|
||||||
}, nil
|
}
|
||||||
|
for i := range h.pr {
|
||||||
|
h.pr[i] = xorshift.NewLockedXORShift128Plus(rand.Uint64(), rand.Uint64())
|
||||||
|
}
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *hook) getXORShiftByInfohash(ih *bittorrent.InfoHash) *xorshift.LockedXORShift128Plus {
|
||||||
|
return h.pr[(int(ih[1])|int(ih[0])<<8)%len(h.pr)]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) (context.Context, error) {
|
func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) (context.Context, error) {
|
||||||
r := h.pr.GetByInfohash(req.InfoHash)
|
r := h.getXORShiftByInfohash(&req.InfoHash)
|
||||||
|
|
||||||
if h.cfg.ModifyResponseProbability == 1 || r.Float32() < h.cfg.ModifyResponseProbability {
|
if h.cfg.ModifyResponseProbability == 1 || float32(xorshift.Intn(r, 1<<24))/(1<<24) < h.cfg.ModifyResponseProbability {
|
||||||
addSeconds := time.Duration(r.Intn(h.cfg.MaxIncreaseDelta)+1) * time.Second
|
addSeconds := time.Duration(xorshift.Intn(r, h.cfg.MaxIncreaseDelta)+1) * time.Second
|
||||||
h.pr.ReturnByInfohash(req.InfoHash)
|
|
||||||
|
|
||||||
resp.Interval += addSeconds
|
resp.Interval += addSeconds
|
||||||
|
|
||||||
|
@ -81,7 +89,6 @@ func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceReque
|
||||||
return ctx, nil
|
return ctx, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
h.pr.ReturnByInfohash(req.InfoHash)
|
|
||||||
return ctx, nil
|
return ctx, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,75 +0,0 @@
|
||||||
// Package prand allows parallel access to randomness based on indices or
|
|
||||||
// infohashes.
|
|
||||||
package prand
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"math/rand"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/chihaya/chihaya/bittorrent"
|
|
||||||
)
|
|
||||||
|
|
||||||
type lockableRand struct {
|
|
||||||
*rand.Rand
|
|
||||||
*sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// Container is a container for sources of random numbers that can be locked
|
|
||||||
// individually.
|
|
||||||
type Container struct {
|
|
||||||
rands []lockableRand
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSeeded returns a new Container with num sources that are seeeded with
|
|
||||||
// seed.
|
|
||||||
func NewSeeded(num int, seed int64) *Container {
|
|
||||||
toReturn := Container{
|
|
||||||
rands: make([]lockableRand, num),
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < num; i++ {
|
|
||||||
toReturn.rands[i].Rand = rand.New(rand.NewSource(seed))
|
|
||||||
toReturn.rands[i].Mutex = &sync.Mutex{}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &toReturn
|
|
||||||
}
|
|
||||||
|
|
||||||
// New returns a new Container with num sources that are seeded with the current
|
|
||||||
// time.
|
|
||||||
func New(num int) *Container {
|
|
||||||
return NewSeeded(num, time.Now().UnixNano())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get locks and returns the nth source.
|
|
||||||
//
|
|
||||||
// Get panics if n is not a valid index for this Container.
|
|
||||||
func (s *Container) Get(n int) *rand.Rand {
|
|
||||||
r := s.rands[n]
|
|
||||||
r.Lock()
|
|
||||||
return r.Rand
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetByInfohash locks and returns a source derived from the infohash.
|
|
||||||
func (s *Container) GetByInfohash(ih bittorrent.InfoHash) *rand.Rand {
|
|
||||||
u := int(binary.BigEndian.Uint32(ih[:4])) % len(s.rands)
|
|
||||||
return s.Get(u)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return returns the nth source to be available again.
|
|
||||||
//
|
|
||||||
// Return panics if n is not a valid index for this Container.
|
|
||||||
// Return also panics if the nth source is unlocked already.
|
|
||||||
func (s *Container) Return(n int) {
|
|
||||||
s.rands[n].Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReturnByInfohash returns the source derived from the infohash.
|
|
||||||
//
|
|
||||||
// ReturnByInfohash panics if the source is unlocked already.
|
|
||||||
func (s *Container) ReturnByInfohash(ih bittorrent.InfoHash) {
|
|
||||||
u := int(binary.BigEndian.Uint32(ih[:4])) % len(s.rands)
|
|
||||||
s.Return(u)
|
|
||||||
}
|
|
|
@ -1,25 +0,0 @@
|
||||||
package prand
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math/rand"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func BenchmarkContainer_GetReturn(b *testing.B) {
|
|
||||||
c := New(1024)
|
|
||||||
a := uint64(0)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
b.RunParallel(func(p *testing.PB) {
|
|
||||||
i := int(atomic.AddUint64(&a, 1))
|
|
||||||
var r *rand.Rand
|
|
||||||
|
|
||||||
for p.Next() {
|
|
||||||
r = c.Get(i)
|
|
||||||
c.Return(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = r
|
|
||||||
})
|
|
||||||
}
|
|
71
pkg/xorshift/xorshift.go
Normal file
71
pkg/xorshift/xorshift.go
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
// Package xorshift implements the XORShift PRNG.
|
||||||
|
package xorshift
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
// XORShift describes the functionality of an XORShift PRNG.
|
||||||
|
type XORShift interface {
|
||||||
|
Next() uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// XORShift128Plus holds the state of an XORShift128Plus PRNG.
|
||||||
|
type XORShift128Plus struct {
|
||||||
|
state [2]uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next generates a pseudorandom number and advances the state of s.
|
||||||
|
func (s *XORShift128Plus) Next() uint64 {
|
||||||
|
s1 := s.state[0]
|
||||||
|
s0 := s.state[1]
|
||||||
|
s1Tmp := s1 // need this for result computation
|
||||||
|
s.state[0] = s0
|
||||||
|
s1 ^= (s1 << 23) // a
|
||||||
|
s.state[1] = s1 ^ s0 ^ (s1 >> 18) ^ (s0 >> 5) // b, c
|
||||||
|
return s0 + s1Tmp
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewXORShift128Plus creates a new XORShift PRNG.
|
||||||
|
func NewXORShift128Plus(s0, s1 uint64) *XORShift128Plus {
|
||||||
|
return &XORShift128Plus{
|
||||||
|
state: [2]uint64{s0, s1},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LockedXORShift128Plus is a thread-safe XORShift128Plus.
|
||||||
|
type LockedXORShift128Plus struct {
|
||||||
|
sync.Mutex
|
||||||
|
state [2]uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLockedXORShift128Plus creates a new LockedXORShift128Plus.
|
||||||
|
func NewLockedXORShift128Plus(s0, s1 uint64) *LockedXORShift128Plus {
|
||||||
|
return &LockedXORShift128Plus{
|
||||||
|
state: [2]uint64{s0, s1},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next generates a pseudorandom number and advances the state of s.
|
||||||
|
func (s *LockedXORShift128Plus) Next() uint64 {
|
||||||
|
s.Lock()
|
||||||
|
s1 := s.state[0]
|
||||||
|
s0 := s.state[1]
|
||||||
|
s1Tmp := s1 // need this for result computation
|
||||||
|
s.state[0] = s0
|
||||||
|
s1 ^= (s1 << 23) // a
|
||||||
|
s.state[1] = s1 ^ s0 ^ (s1 >> 18) ^ (s0 >> 5) // b, c
|
||||||
|
s.Unlock()
|
||||||
|
return s0 + s1Tmp
|
||||||
|
}
|
||||||
|
|
||||||
|
// Intn generates an int k that satisfies k >= 0 && k < n.
|
||||||
|
// n must be > 0.
|
||||||
|
func Intn(s XORShift, n int) int {
|
||||||
|
if n <= 0 {
|
||||||
|
panic("invalid n <= 0")
|
||||||
|
}
|
||||||
|
v := int(s.Next())
|
||||||
|
if v < 0 {
|
||||||
|
v = -v
|
||||||
|
}
|
||||||
|
return v % n
|
||||||
|
}
|
46
pkg/xorshift/xorshift_test.go
Normal file
46
pkg/xorshift/xorshift_test.go
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
package xorshift
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIntn(t *testing.T) {
|
||||||
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
s := NewXORShift128Plus(rand.Uint64(), rand.Uint64())
|
||||||
|
for i := 0; i < 10000; i++ {
|
||||||
|
k := Intn(s, 10)
|
||||||
|
require.True(t, k >= 0, "Intn() must be >= 0")
|
||||||
|
require.True(t, k < 10, "Intn(k) must be < k")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkXORShift128Plus_Next(b *testing.B) {
|
||||||
|
s := NewXORShift128Plus(rand.Uint64(), rand.Uint64())
|
||||||
|
var k uint64
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
k = s.Next()
|
||||||
|
}
|
||||||
|
_ = k
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkIntnXORShift128Plus(b *testing.B) {
|
||||||
|
s := NewXORShift128Plus(rand.Uint64(), rand.Uint64())
|
||||||
|
var k int
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
k = Intn(s, 1000)
|
||||||
|
}
|
||||||
|
_ = k
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLockedXORShift128Plus_Next(b *testing.B) {
|
||||||
|
s := NewLockedXORShift128Plus(rand.Uint64(), rand.Uint64())
|
||||||
|
var k uint64
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
k = s.Next()
|
||||||
|
}
|
||||||
|
_ = k
|
||||||
|
}
|
Loading…
Reference in a new issue