194 lines
4.8 KiB
Go
194 lines
4.8 KiB
Go
package udp
|
|
|
|
import (
|
|
"crypto/hmac"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/minio/sha256-simd"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/chihaya/chihaya/pkg/log"
|
|
)
|
|
|
|
// golden lists the table-driven cases shared by the connection ID tests.
// Each entry generates an ID for ip at Unix time createdAt using key, then
// checks it at Unix time now; valid is the verdict the validity tests expect
// when they pass time.Minute as the allowed age.
//
// IP strings must be accepted by net.ParseIP. Bracketed forms such as "[::]"
// make it return nil, which would silently test a nil IP instead of IPv6.
var golden = []struct {
	createdAt int64
	now       int64
	ip        string
	key       string
	valid     bool
}{
	{0, 1, "127.0.0.1", "", true},
	{0, 420420, "127.0.0.1", "", false},
	{0, 0, "::", "", true},
}
|
|
|
|
// simpleNewConnectionID is a straightforward reference implementation of
// connection ID generation, used to check the optimized generators.
//
// The returned 8-byte ID is laid out as:
//   - bytes 0..3: uint32(now.Unix()) in big-endian order
//   - bytes 4..7: the first four bytes of HMAC-SHA256 keyed with key over
//     the timestamp bytes followed by ip
func simpleNewConnectionID(ip net.IP, now time.Time, key string) []byte {
	connID := make([]byte, 8)
	timestamp, tag := connID[:4], connID[4:]
	binary.BigEndian.PutUint32(timestamp, uint32(now.Unix()))

	h := hmac.New(sha256.New, []byte(key))
	h.Write(timestamp)
	h.Write(ip)
	// copy stops at len(tag) == 4, keeping only the leading MAC bytes.
	copy(tag, h.Sum(nil))

	// The debug log is kept deliberately: logging affects performance, and
	// this function is benchmarked as well.
	log.Debug("manually generated connection ID", log.Fields{"ip": ip, "now": now, "connID": connID})

	return connID
}
|
|
|
|
// TestVerification generates an ID with NewConnectionID for every golden case
// and checks that ValidConnectionID, called with time.Minute as the allowed
// age, returns the expected verdict.
func TestVerification(t *testing.T) {
	for _, tc := range golden {
		name := fmt.Sprintf("%s created at %d verified at %d", tc.ip, tc.createdAt, tc.now)
		t.Run(name, func(t *testing.T) {
			connID := NewConnectionID(net.ParseIP(tc.ip), time.Unix(tc.createdAt, 0), tc.key)

			valid := ValidConnectionID(connID, net.ParseIP(tc.ip), time.Unix(tc.now, 0), time.Minute, tc.key)
			if valid == tc.valid {
				return
			}
			t.Errorf("expected validity: %t got validity: %t", tc.valid, valid)
		})
	}
}
|
|
|
|
// TestGeneration requires NewConnectionID to produce exactly the same bytes
// as the reference implementation simpleNewConnectionID for every golden case.
func TestGeneration(t *testing.T) {
	for _, tc := range golden {
		createdAt := time.Unix(tc.createdAt, 0)
		t.Run(fmt.Sprintf("%s created at %d", tc.ip, tc.createdAt), func(t *testing.T) {
			expected := simpleNewConnectionID(net.ParseIP(tc.ip), createdAt, tc.key)
			actual := NewConnectionID(net.ParseIP(tc.ip), createdAt, tc.key)
			require.Equal(t, expected, actual)
		})
	}
}
|
|
|
|
// TestReuseGeneratorGenerate checks that one ConnectionIDGenerator, called
// three times in a row, keeps returning the same 8-byte ID that
// NewConnectionID produces for the same inputs.
func TestReuseGeneratorGenerate(t *testing.T) {
	for _, tc := range golden {
		t.Run(fmt.Sprintf("%s created at %d", tc.ip, tc.createdAt), func(t *testing.T) {
			expected := NewConnectionID(net.ParseIP(tc.ip), time.Unix(tc.createdAt, 0), tc.key)
			require.Len(t, expected, 8)

			gen := NewConnectionIDGenerator(tc.key)

			// Repeated calls on the same generator must not leak state
			// between generations.
			for attempt := 1; attempt <= 3; attempt++ {
				require.Equal(t, expected, gen.Generate(net.ParseIP(tc.ip), time.Unix(tc.createdAt, 0)))
			}
		})
	}
}
|
|
|
|
// TestReuseGeneratorValidate generates an ID with a ConnectionIDGenerator and
// validates it three times with that same generator, using time.Minute as the
// allowed age, expecting the golden verdict on every call.
func TestReuseGeneratorValidate(t *testing.T) {
	for _, tc := range golden {
		name := fmt.Sprintf("%s created at %d verified at %d", tc.ip, tc.createdAt, tc.now)
		t.Run(name, func(t *testing.T) {
			gen := NewConnectionIDGenerator(tc.key)
			connID := gen.Generate(net.ParseIP(tc.ip), time.Unix(tc.createdAt, 0))

			for remaining := 3; remaining > 0; remaining-- {
				got := gen.Validate(connID, net.ParseIP(tc.ip), time.Unix(tc.now, 0), time.Minute)
				if got != tc.valid {
					t.Errorf("expected validity: %t got validity: %t", tc.valid, got)
				}
			}
		})
	}
}
|
|
|
|
func BenchmarkSimpleNewConnectionID(b *testing.B) {
|
|
ip := net.ParseIP("127.0.0.1")
|
|
key := "some random string that is hopefully at least this long"
|
|
createdAt := time.Now()
|
|
|
|
b.RunParallel(func(pb *testing.PB) {
|
|
sum := int64(0)
|
|
|
|
for pb.Next() {
|
|
cid := simpleNewConnectionID(ip, createdAt, key)
|
|
sum += int64(cid[7])
|
|
}
|
|
|
|
_ = sum
|
|
})
|
|
}
|
|
|
|
func BenchmarkNewConnectionID(b *testing.B) {
|
|
ip := net.ParseIP("127.0.0.1")
|
|
key := "some random string that is hopefully at least this long"
|
|
createdAt := time.Now()
|
|
|
|
b.RunParallel(func(pb *testing.PB) {
|
|
sum := int64(0)
|
|
|
|
for pb.Next() {
|
|
cid := NewConnectionID(ip, createdAt, key)
|
|
sum += int64(cid[7])
|
|
}
|
|
|
|
_ = sum
|
|
})
|
|
}
|
|
|
|
func BenchmarkConnectionIDGenerator_Generate(b *testing.B) {
|
|
ip := net.ParseIP("127.0.0.1")
|
|
key := "some random string that is hopefully at least this long"
|
|
createdAt := time.Now()
|
|
|
|
pool := &sync.Pool{
|
|
New: func() interface{} {
|
|
return NewConnectionIDGenerator(key)
|
|
},
|
|
}
|
|
|
|
b.RunParallel(func(pb *testing.PB) {
|
|
sum := int64(0)
|
|
for pb.Next() {
|
|
gen := pool.Get().(*ConnectionIDGenerator)
|
|
cid := gen.Generate(ip, createdAt)
|
|
sum += int64(cid[7])
|
|
pool.Put(gen)
|
|
}
|
|
})
|
|
}
|
|
|
|
// BenchmarkValidConnectionID measures ValidConnectionID from parallel
// goroutines. The ID comes from NewConnectionID and is validated at its own
// creation time with a 10-second window, so every call is expected to
// report it as valid.
func BenchmarkValidConnectionID(b *testing.B) {
	ip := net.ParseIP("127.0.0.1")
	key := "some random string that is hopefully at least this long"
	createdAt := time.Now()
	cid := NewConnectionID(ip, createdAt, key)

	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			if !ValidConnectionID(cid, ip, createdAt, 10*time.Second, key) {
				// RunParallel bodies run on worker goroutines, and FailNow may
				// only be called from the benchmark's own goroutine. Error is
				// safe to call concurrently; returning stops this worker.
				b.Error("ValidConnectionID rejected a freshly generated connection ID")
				return
			}
		}
	})
}
|
|
|
|
// BenchmarkConnectionIDGenerator_Validate measures ConnectionIDGenerator.Validate
// from parallel goroutines. The ID comes from NewConnectionID and is validated
// at its own creation time with a 10-second window. Each iteration borrows a
// generator from a sync.Pool.
func BenchmarkConnectionIDGenerator_Validate(b *testing.B) {
	ip := net.ParseIP("127.0.0.1")
	key := "some random string that is hopefully at least this long"
	createdAt := time.Now()
	cid := NewConnectionID(ip, createdAt, key)

	pool := &sync.Pool{
		New: func() interface{} {
			return NewConnectionIDGenerator(key)
		},
	}

	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			gen := pool.Get().(*ConnectionIDGenerator)
			ok := gen.Validate(cid, ip, createdAt, 10*time.Second)
			// Return the generator before any early exit so it is not lost.
			pool.Put(gen)
			if !ok {
				// FailNow must not be called from RunParallel worker
				// goroutines; Error is concurrency-safe, and returning ends
				// this worker.
				b.Error("ConnectionIDGenerator.Validate rejected a freshly generated connection ID")
				return
			}
		}
	})
}
|