package udp import ( "crypto/hmac" "encoding/binary" "fmt" "net" "sync" "testing" "time" sha256 "github.com/minio/sha256-simd" "github.com/stretchr/testify/require" "github.com/chihaya/chihaya/pkg/log" ) var golden = []struct { createdAt int64 now int64 ip string key string valid bool }{ {0, 1, "127.0.0.1", "", true}, {0, 420420, "127.0.0.1", "", false}, {0, 0, "[::]", "", true}, } // simpleNewConnectionID generates a new connection ID the explicit way. // This is used to verify correct behaviour of the generator. func simpleNewConnectionID(ip net.IP, now time.Time, key string) []byte { buf := make([]byte, 8) binary.BigEndian.PutUint32(buf, uint32(now.Unix())) mac := hmac.New(sha256.New, []byte(key)) mac.Write(buf[:4]) mac.Write(ip) macBytes := mac.Sum(nil)[:4] copy(buf[4:], macBytes) // this is just in here because logging impacts performance and we benchmark // this version too. log.Debug("manually generated connection ID", log.Fields{"ip": ip, "now": now, "connID": buf}) return buf } func TestVerification(t *testing.T) { for _, tt := range golden { t.Run(fmt.Sprintf("%s created at %d verified at %d", tt.ip, tt.createdAt, tt.now), func(t *testing.T) { cid := NewConnectionID(net.ParseIP(tt.ip), time.Unix(tt.createdAt, 0), tt.key) got := ValidConnectionID(cid, net.ParseIP(tt.ip), time.Unix(tt.now, 0), time.Minute, tt.key) if got != tt.valid { t.Errorf("expected validity: %t got validity: %t", tt.valid, got) } }) } } func TestGeneration(t *testing.T) { for _, tt := range golden { t.Run(fmt.Sprintf("%s created at %d", tt.ip, tt.createdAt), func(t *testing.T) { want := simpleNewConnectionID(net.ParseIP(tt.ip), time.Unix(tt.createdAt, 0), tt.key) got := NewConnectionID(net.ParseIP(tt.ip), time.Unix(tt.createdAt, 0), tt.key) require.Equal(t, want, got) }) } } func TestReuseGeneratorGenerate(t *testing.T) { for _, tt := range golden { t.Run(fmt.Sprintf("%s created at %d", tt.ip, tt.createdAt), func(t *testing.T) { cid := NewConnectionID(net.ParseIP(tt.ip), time.Unix(tt.createdAt, 0), tt.key) require.Len(t, cid, 8) gen := NewConnectionIDGenerator(tt.key) for i := 0; i < 3; i++ { connID := gen.Generate(net.ParseIP(tt.ip), time.Unix(tt.createdAt, 0)) require.Equal(t, cid, connID) } }) } } func TestReuseGeneratorValidate(t *testing.T) { for _, tt := range golden { t.Run(fmt.Sprintf("%s created at %d verified at %d", tt.ip, tt.createdAt, tt.now), func(t *testing.T) { gen := NewConnectionIDGenerator(tt.key) cid := gen.Generate(net.ParseIP(tt.ip), time.Unix(tt.createdAt, 0)) for i := 0; i < 3; i++ { got := gen.Validate(cid, net.ParseIP(tt.ip), time.Unix(tt.now, 0), time.Minute) if got != tt.valid { t.Errorf("expected validity: %t got validity: %t", tt.valid, got) } } }) } } func BenchmarkSimpleNewConnectionID(b *testing.B) { ip := net.ParseIP("127.0.0.1") key := "some random string that is hopefully at least this long" createdAt := time.Now() b.RunParallel(func(pb *testing.PB) { sum := int64(0) for pb.Next() { cid := simpleNewConnectionID(ip, createdAt, key) sum += int64(cid[7]) } _ = sum }) } func BenchmarkNewConnectionID(b *testing.B) { ip := net.ParseIP("127.0.0.1") key := "some random string that is hopefully at least this long" createdAt := time.Now() b.RunParallel(func(pb *testing.PB) { sum := int64(0) for pb.Next() { cid := NewConnectionID(ip, createdAt, key) sum += int64(cid[7]) } _ = sum }) } func BenchmarkConnectionIDGenerator_Generate(b *testing.B) { ip := net.ParseIP("127.0.0.1") key := "some random string that is hopefully at least this long" createdAt := time.Now() pool := &sync.Pool{ New: func() interface{} { return NewConnectionIDGenerator(key) }, } b.RunParallel(func(pb *testing.PB) { sum := int64(0) for pb.Next() { gen := pool.Get().(*ConnectionIDGenerator) cid := gen.Generate(ip, createdAt) sum += int64(cid[7]) pool.Put(gen) } }) } func BenchmarkValidConnectionID(b *testing.B) { ip := net.ParseIP("127.0.0.1") key := "some random string that is hopefully at least this long" createdAt := time.Now() cid := NewConnectionID(ip, createdAt, key) b.RunParallel(func(pb *testing.PB) { for pb.Next() { if !ValidConnectionID(cid, ip, createdAt, 10*time.Second, key) { b.FailNow() } } }) } func BenchmarkConnectionIDGenerator_Validate(b *testing.B) { ip := net.ParseIP("127.0.0.1") key := "some random string that is hopefully at least this long" createdAt := time.Now() cid := NewConnectionID(ip, createdAt, key) pool := &sync.Pool{ New: func() interface{} { return NewConnectionIDGenerator(key) }, } b.RunParallel(func(pb *testing.PB) { for pb.Next() { gen := pool.Get().(*ConnectionIDGenerator) if !gen.Validate(cid, ip, createdAt, 10*time.Second) { b.FailNow() } pool.Put(gen) } }) }