diff --git a/frontend/udp/connection_id.go b/frontend/udp/connection_id.go index 3753371..991fea1 100644 --- a/frontend/udp/connection_id.go +++ b/frontend/udp/connection_id.go @@ -3,6 +3,7 @@ package udp import ( "crypto/hmac" "encoding/binary" + "hash" "net" "time" @@ -11,12 +12,67 @@ import ( "github.com/chihaya/chihaya/pkg/log" ) -// ttl is the number of seconds a connection ID should be valid according to -// BEP 15. +// ttl is the duration a connection ID should be valid according to BEP 15. const ttl = 2 * time.Minute -// NewConnectionID creates a new 8 byte connection identifier for UDP packets -// as described by BEP 15. +// NewConnectionID creates an 8-byte connection identifier for UDP packets as +// described by BEP 15. +// This is a wrapper around creating a new ConnectionIDGenerator and generating +// an ID. It is recommended to use the generator for performance. +func NewConnectionID(ip net.IP, now time.Time, key string) []byte { + return NewConnectionIDGenerator(key).Generate(ip, now) +} + +// ValidConnectionID determines whether a connection identifier is legitimate. +// This is a wrapper around creating a new ConnectionIDGenerator and validating +// the ID. It is recommended to use the generator for performance. +func ValidConnectionID(connectionID []byte, ip net.IP, now time.Time, maxClockSkew time.Duration, key string) bool { + return NewConnectionIDGenerator(key).Validate(connectionID, ip, now, maxClockSkew) +} + +// A ConnectionIDGenerator is a reusable generator and validator for connection +// IDs as described in BEP 15. +// It is not thread safe, but is safe to be pooled and reused by other +// goroutines. It manages its state itself, so it can be taken from and returned +// to a pool without any cleanup. +// After initial creation, it can generate connection IDs without allocating. +// See Generate and Validate for usage notes and guarantees. +type ConnectionIDGenerator struct { + // mac is a keyed HMAC that can be reused for subsequent connection ID + // generations. + mac hash.Hash + + // connID is an 8-byte slice that holds the generated connection ID after a + // call to Generate. + // It must not be referenced after the generator is returned to a pool. + // It will be overwritten by subsequent calls to Generate. + connID []byte + + // scratch is a 32-byte slice that is used as a scratchpad for the generated + // HMACs. + scratch []byte +} + +// NewConnectionIDGenerator creates a new connection ID generator. +func NewConnectionIDGenerator(key string) *ConnectionIDGenerator { + return &ConnectionIDGenerator{ + mac: hmac.New(sha256.New, []byte(key)), + connID: make([]byte, 8), + scratch: make([]byte, 32), + } +} + +// reset resets the generator. +// This is called by other methods of the generator, it's not necessary to call +// it after getting a generator from a pool. +func (g *ConnectionIDGenerator) reset() { + g.mac.Reset() + g.connID = g.connID[:8] + g.scratch = g.scratch[:0] +} + +// Generate generates an 8-byte connection ID as described in BEP 15 for the +// given IP and the current time. // // The first 4 bytes of the connection identifier is a unix timestamp and the // last 4 bytes are a truncated HMAC token created from the aforementioned @@ -25,31 +81,36 @@ const ttl = 2 * time.Minute // Truncated HMAC is known to be safe for 2^(-n) where n is the size in bits // of the truncated HMAC token. In this use case we have 32 bits, thus a // forgery probability of approximately 1 in 4 billion. -func NewConnectionID(ip net.IP, now time.Time, key string) []byte { - buf := make([]byte, 8) - binary.BigEndian.PutUint32(buf, uint32(now.Unix())) +// +// The generated ID is written to g.connID, which is also returned. g.connID +// will be reused, so it must not be referenced after returning the generator +// to a pool and will be overwritten be subsequent calls to Generate! +func (g *ConnectionIDGenerator) Generate(ip net.IP, now time.Time) []byte { + g.reset() - mac := hmac.New(sha256.New, []byte(key)) - mac.Write(buf[:4]) - mac.Write(ip) - macBytes := mac.Sum(nil)[:4] - copy(buf[4:], macBytes) + binary.BigEndian.PutUint32(g.connID, uint32(now.Unix())) - log.Debug("generated connection ID", log.Fields{"ip": ip, "now": now, "key": key, "connID": buf}) - return buf + g.mac.Write(g.connID[:4]) + g.mac.Write(ip) + g.scratch = g.mac.Sum(g.scratch) + copy(g.connID[4:8], g.scratch[:4]) + + log.Debug("generated connection ID", log.Fields{"ip": ip, "now": now, "connID": g.connID}) + return g.connID } -// ValidConnectionID determines whether a connection identifier is legitimate. -func ValidConnectionID(connectionID []byte, ip net.IP, now time.Time, maxClockSkew time.Duration, key string) bool { +// Validate validates the given connection ID for an IP and the current time. +func (g *ConnectionIDGenerator) Validate(connectionID []byte, ip net.IP, now time.Time, maxClockSkew time.Duration) bool { ts := time.Unix(int64(binary.BigEndian.Uint32(connectionID[:4])), 0) - log.Debug("validating connection ID", log.Fields{"connID": connectionID, "ip": ip, "ts": ts, "now": now, "key": key}) + log.Debug("validating connection ID", log.Fields{"connID": connectionID, "ip": ip, "ts": ts, "now": now}) if now.After(ts.Add(ttl)) || ts.After(now.Add(maxClockSkew)) { return false } - mac := hmac.New(sha256.New, []byte(key)) - mac.Write(connectionID[:4]) - mac.Write(ip) - expectedMAC := mac.Sum(nil)[:4] - return hmac.Equal(expectedMAC, connectionID[4:]) + g.reset() + + g.mac.Write(connectionID[:4]) + g.mac.Write(ip) + g.scratch = g.mac.Sum(g.scratch) + return hmac.Equal(g.scratch[:4], connectionID[4:]) } diff --git a/frontend/udp/connection_id_test.go b/frontend/udp/connection_id_test.go index ab86b80..044d7db 100644 --- a/frontend/udp/connection_id_test.go +++ b/frontend/udp/connection_id_test.go @@ -1,10 +1,19 @@ package udp import ( + "crypto/hmac" + "encoding/binary" "fmt" "net" + "sync" "testing" "time" + + "github.com/minio/sha256-simd" + + "github.com/stretchr/testify/require" + + "github.com/chihaya/chihaya/pkg/log" ) var golden = []struct { @@ -19,6 +28,24 @@ var golden = []struct { {0, 0, "[::]", "", true}, } +// simpleNewConnectionID generates a new connection ID the explicit way. +// This is used to verify correct behaviour of the generator. +func simpleNewConnectionID(ip net.IP, now time.Time, key string) []byte { + buf := make([]byte, 8) + binary.BigEndian.PutUint32(buf, uint32(now.Unix())) + + mac := hmac.New(sha256.New, []byte(key)) + mac.Write(buf[:4]) + mac.Write(ip) + macBytes := mac.Sum(nil)[:4] + copy(buf[4:], macBytes) + + // this is just in here because logging impacts performance and we benchmark + // this version too. + log.Debug("manually generated connection ID", log.Fields{"ip": ip, "now": now, "connID": buf}) + return buf +} + func TestVerification(t *testing.T) { for _, tt := range golden { t.Run(fmt.Sprintf("%s created at %d verified at %d", tt.ip, tt.createdAt, tt.now), func(t *testing.T) { @@ -31,18 +58,101 @@ func TestVerification(t *testing.T) { } } +func TestGeneration(t *testing.T) { + for _, tt := range golden { + t.Run(fmt.Sprintf("%s created at %d", tt.ip, tt.createdAt), func(t *testing.T) { + want := simpleNewConnectionID(net.ParseIP(tt.ip), time.Unix(tt.createdAt, 0), tt.key) + got := NewConnectionID(net.ParseIP(tt.ip), time.Unix(tt.createdAt, 0), tt.key) + require.Equal(t, want, got) + }) + } +} + +func TestReuseGeneratorGenerate(t *testing.T) { + for _, tt := range golden { + t.Run(fmt.Sprintf("%s created at %d", tt.ip, tt.createdAt), func(t *testing.T) { + cid := NewConnectionID(net.ParseIP(tt.ip), time.Unix(tt.createdAt, 0), tt.key) + require.Len(t, cid, 8) + + gen := NewConnectionIDGenerator(tt.key) + + for i := 0; i < 3; i++ { + connID := gen.Generate(net.ParseIP(tt.ip), time.Unix(tt.createdAt, 0)) + require.Equal(t, cid, connID) + } + }) + } +} + +func TestReuseGeneratorValidate(t *testing.T) { + for _, tt := range golden { + t.Run(fmt.Sprintf("%s created at %d verified at %d", tt.ip, tt.createdAt, tt.now), func(t *testing.T) { + gen := NewConnectionIDGenerator(tt.key) + cid := gen.Generate(net.ParseIP(tt.ip), time.Unix(tt.createdAt, 0)) + for i := 0; i < 3; i++ { + got := gen.Validate(cid, net.ParseIP(tt.ip), time.Unix(tt.now, 0), time.Minute) + if got != tt.valid { + t.Errorf("expected validity: %t got validity: %t", tt.valid, got) + } + } + }) + } +} + +func BenchmarkSimpleNewConnectionID(b *testing.B) { + ip := net.ParseIP("127.0.0.1") + key := "some random string that is hopefully at least this long" + createdAt := time.Now() + + b.RunParallel(func(pb *testing.PB) { + sum := int64(0) + + for pb.Next() { + cid := simpleNewConnectionID(ip, createdAt, key) + sum += int64(cid[7]) + } + + _ = sum + }) +} + func BenchmarkNewConnectionID(b *testing.B) { ip := net.ParseIP("127.0.0.1") key := "some random string that is hopefully at least this long" createdAt := time.Now() - sum := int64(0) - for i := 0; i < b.N; i++ { - cid := NewConnectionID(ip, createdAt, key) - sum += int64(cid[7]) + b.RunParallel(func(pb *testing.PB) { + sum := int64(0) + + for pb.Next() { + cid := NewConnectionID(ip, createdAt, key) + sum += int64(cid[7]) + } + + _ = sum + }) +} + +func BenchmarkConnectionIDGenerator_Generate(b *testing.B) { + ip := net.ParseIP("127.0.0.1") + key := "some random string that is hopefully at least this long" + createdAt := time.Now() + + pool := &sync.Pool{ + New: func() interface{} { + return NewConnectionIDGenerator(key) + }, } - _ = sum + b.RunParallel(func(pb *testing.PB) { + sum := int64(0) + for pb.Next() { + gen := pool.Get().(*ConnectionIDGenerator) + cid := gen.Generate(ip, createdAt) + sum += int64(cid[7]) + pool.Put(gen) + } + }) } func BenchmarkValidConnectionID(b *testing.B) { @@ -51,9 +161,34 @@ func BenchmarkValidConnectionID(b *testing.B) { createdAt := time.Now() cid := NewConnectionID(ip, createdAt, key) - for i := 0; i < b.N; i++ { - if !ValidConnectionID(cid, ip, createdAt, 10*time.Second, key) { - b.FailNow() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if !ValidConnectionID(cid, ip, createdAt, 10*time.Second, key) { + b.FailNow() + } } - } + }) +} + +func BenchmarkConnectionIDGenerator_Validate(b *testing.B) { + ip := net.ParseIP("127.0.0.1") + key := "some random string that is hopefully at least this long" + createdAt := time.Now() + cid := NewConnectionID(ip, createdAt, key) + + pool := &sync.Pool{ + New: func() interface{} { + return NewConnectionIDGenerator(key) + }, + } + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + gen := pool.Get().(*ConnectionIDGenerator) + if !gen.Validate(cid, ip, createdAt, 10*time.Second) { + b.FailNow() + } + pool.Put(gen) + } + }) } diff --git a/frontend/udp/frontend.go b/frontend/udp/frontend.go index 1681adb..bc5b421 100644 --- a/frontend/udp/frontend.go +++ b/frontend/udp/frontend.go @@ -52,6 +52,8 @@ type Frontend struct { closing chan struct{} wg sync.WaitGroup + genPool *sync.Pool + logic frontend.TrackerLogic Config } @@ -75,6 +77,11 @@ func NewFrontend(logic frontend.TrackerLogic, cfg Config) (*Frontend, error) { closing: make(chan struct{}), logic: logic, Config: cfg, + genPool: &sync.Pool{ + New: func() interface{} { + return NewConnectionIDGenerator(cfg.PrivateKey) + }, + }, } go func() { @@ -211,9 +218,13 @@ func (t *Frontend) handleRequest(r Request, w ResponseWriter) (actionName string actionID := binary.BigEndian.Uint32(r.Packet[8:12]) txID := r.Packet[12:16] + // get a connection ID generator/validator from the pool. + gen := t.genPool.Get().(*ConnectionIDGenerator) + defer t.genPool.Put(gen) + // If this isn't requesting a new connection ID and the connection ID is // invalid, then fail. - if actionID != connectActionID && !ValidConnectionID(connID, r.IP, timecache.Now(), t.MaxClockSkew, t.PrivateKey) { + if actionID != connectActionID && !gen.Validate(connID, r.IP, timecache.Now(), t.MaxClockSkew) { err = errBadConnectionID WriteError(w, txID, err) return @@ -239,7 +250,7 @@ func (t *Frontend) handleRequest(r Request, w ResponseWriter) (actionName string panic(fmt.Sprintf("udp: invalid IP: neither v4 nor v6, IP: %#v", r.IP)) } - WriteConnectionID(w, txID, NewConnectionID(r.IP, timecache.Now(), t.PrivateKey)) + WriteConnectionID(w, txID, gen.Generate(r.IP, timecache.Now())) case announceActionID, announceV6ActionID: actionName = "announce"