pkg/prand: replace with pkg/xorshift

This commit is contained in:
Leo Balduf 2017-06-10 12:26:42 +02:00
parent 6c3ddaefb3
commit 03b98e0090
5 changed files with 134 additions and 110 deletions

View file

@ -3,12 +3,13 @@ package varinterval
import (
"context"
"errors"
"math/rand"
"sync"
"time"
"github.com/chihaya/chihaya/bittorrent"
"github.com/chihaya/chihaya/middleware"
"github.com/chihaya/chihaya/pkg/prand"
"github.com/chihaya/chihaya/pkg/xorshift"
)
// ErrInvalidModifyResponseProbability is returned for a config with an invalid
@ -47,7 +48,7 @@ func checkConfig(cfg Config) error {
type hook struct {
cfg Config
pr *prand.Container
pr [1024]*xorshift.LockedXORShift128Plus
sync.Mutex
}
@ -59,18 +60,25 @@ func New(cfg Config) (middleware.Hook, error) {
return nil, err
}
return &hook{
h := &hook{
cfg: cfg,
pr: prand.New(1024),
}, nil
pr: [1024]*xorshift.LockedXORShift128Plus{},
}
for i := range h.pr {
h.pr[i] = xorshift.NewLockedXORShift128Plus(rand.Uint64(), rand.Uint64())
}
return h, nil
}
func (h *hook) getXORShiftByInfohash(ih *bittorrent.InfoHash) *xorshift.LockedXORShift128Plus {
return h.pr[(int(ih[1])|int(ih[0])<<8)%len(h.pr)]
}
func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) (context.Context, error) {
r := h.pr.GetByInfohash(req.InfoHash)
r := h.getXORShiftByInfohash(&req.InfoHash)
if h.cfg.ModifyResponseProbability == 1 || r.Float32() < h.cfg.ModifyResponseProbability {
addSeconds := time.Duration(r.Intn(h.cfg.MaxIncreaseDelta)+1) * time.Second
h.pr.ReturnByInfohash(req.InfoHash)
if h.cfg.ModifyResponseProbability == 1 || float32(xorshift.Intn(r, 1<<24))/(1<<24) < h.cfg.ModifyResponseProbability {
addSeconds := time.Duration(xorshift.Intn(r, h.cfg.MaxIncreaseDelta)+1) * time.Second
resp.Interval += addSeconds
@ -81,7 +89,6 @@ func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceReque
return ctx, nil
}
h.pr.ReturnByInfohash(req.InfoHash)
return ctx, nil
}

View file

@ -1,75 +0,0 @@
// Package prand allows parallel access to randomness based on indices or
// infohashes.
package prand
import (
"encoding/binary"
"math/rand"
"sync"
"time"
"github.com/chihaya/chihaya/bittorrent"
)
type lockableRand struct {
*rand.Rand
*sync.Mutex
}
// Container is a container for sources of random numbers that can be locked
// individually.
type Container struct {
rands []lockableRand
}
// NewSeeded returns a new Container with num sources that are seeeded with
// seed.
func NewSeeded(num int, seed int64) *Container {
toReturn := Container{
rands: make([]lockableRand, num),
}
for i := 0; i < num; i++ {
toReturn.rands[i].Rand = rand.New(rand.NewSource(seed))
toReturn.rands[i].Mutex = &sync.Mutex{}
}
return &toReturn
}
// New returns a new Container with num sources that are seeded with the current
// time.
func New(num int) *Container {
return NewSeeded(num, time.Now().UnixNano())
}
// Get locks and returns the nth source.
//
// Get panics if n is not a valid index for this Container.
func (s *Container) Get(n int) *rand.Rand {
r := s.rands[n]
r.Lock()
return r.Rand
}
// GetByInfohash locks and returns a source derived from the infohash.
func (s *Container) GetByInfohash(ih bittorrent.InfoHash) *rand.Rand {
u := int(binary.BigEndian.Uint32(ih[:4])) % len(s.rands)
return s.Get(u)
}
// Return returns the nth source to be available again.
//
// Return panics if n is not a valid index for this Container.
// Return also panics if the nth source is unlocked already.
func (s *Container) Return(n int) {
s.rands[n].Unlock()
}
// ReturnByInfohash returns the source derived from the infohash.
//
// ReturnByInfohash panics if the source is unlocked already.
func (s *Container) ReturnByInfohash(ih bittorrent.InfoHash) {
u := int(binary.BigEndian.Uint32(ih[:4])) % len(s.rands)
s.Return(u)
}

View file

@ -1,25 +0,0 @@
package prand
import (
"math/rand"
"sync/atomic"
"testing"
)
func BenchmarkContainer_GetReturn(b *testing.B) {
c := New(1024)
a := uint64(0)
b.ResetTimer()
b.RunParallel(func(p *testing.PB) {
i := int(atomic.AddUint64(&a, 1))
var r *rand.Rand
for p.Next() {
r = c.Get(i)
c.Return(i)
}
_ = r
})
}

71
pkg/xorshift/xorshift.go Normal file
View file

@ -0,0 +1,71 @@
// Package xorshift implements the XORShift PRNG.
package xorshift
import "sync"
// XORShift describes the functionality of an XORShift PRNG.
type XORShift interface {
Next() uint64
}
// XORShift128Plus holds the state of an XORShift128Plus PRNG.
type XORShift128Plus struct {
state [2]uint64
}
// Next generates a pseudorandom number and advances the state of s.
func (s *XORShift128Plus) Next() uint64 {
s1 := s.state[0]
s0 := s.state[1]
s1Tmp := s1 // need this for result computation
s.state[0] = s0
s1 ^= (s1 << 23) // a
s.state[1] = s1 ^ s0 ^ (s1 >> 18) ^ (s0 >> 5) // b, c
return s0 + s1Tmp
}
// NewXORShift128Plus creates a new XORShift PRNG.
func NewXORShift128Plus(s0, s1 uint64) *XORShift128Plus {
return &XORShift128Plus{
state: [2]uint64{s0, s1},
}
}
// LockedXORShift128Plus is a thread-safe XORShift128Plus.
type LockedXORShift128Plus struct {
sync.Mutex
state [2]uint64
}
// NewLockedXORShift128Plus creates a new LockedXORShift128Plus.
func NewLockedXORShift128Plus(s0, s1 uint64) *LockedXORShift128Plus {
return &LockedXORShift128Plus{
state: [2]uint64{s0, s1},
}
}
// Next generates a pseudorandom number and advances the state of s.
func (s *LockedXORShift128Plus) Next() uint64 {
s.Lock()
s1 := s.state[0]
s0 := s.state[1]
s1Tmp := s1 // need this for result computation
s.state[0] = s0
s1 ^= (s1 << 23) // a
s.state[1] = s1 ^ s0 ^ (s1 >> 18) ^ (s0 >> 5) // b, c
s.Unlock()
return s0 + s1Tmp
}
// Intn generates an int k that satisfies k >= 0 && k < n.
// n must be > 0.
func Intn(s XORShift, n int) int {
if n <= 0 {
panic("invalid n <= 0")
}
v := int(s.Next())
if v < 0 {
v = -v
}
return v % n
}

View file

@ -0,0 +1,46 @@
package xorshift
import (
"math/rand"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestIntn(t *testing.T) {
rand.Seed(time.Now().UnixNano())
s := NewXORShift128Plus(rand.Uint64(), rand.Uint64())
for i := 0; i < 10000; i++ {
k := Intn(s, 10)
require.True(t, k >= 0, "Intn() must be >= 0")
require.True(t, k < 10, "Intn(k) must be < k")
}
}
func BenchmarkXORShift128Plus_Next(b *testing.B) {
s := NewXORShift128Plus(rand.Uint64(), rand.Uint64())
var k uint64
for i := 0; i < b.N; i++ {
k = s.Next()
}
_ = k
}
func BenchmarkIntnXORShift128Plus(b *testing.B) {
s := NewXORShift128Plus(rand.Uint64(), rand.Uint64())
var k int
for i := 0; i < b.N; i++ {
k = Intn(s, 1000)
}
_ = k
}
func BenchmarkLockedXORShift128Plus_Next(b *testing.B) {
s := NewLockedXORShift128Plus(rand.Uint64(), rand.Uint64())
var k uint64
for i := 0; i < b.N; i++ {
k = s.Next()
}
_ = k
}