Merge pull request #420 from mrd0ll4r/connid-pool
frontend/udp: pool connection ID generation state
This commit is contained in:
commit
564a54a178
3 changed files with 240 additions and 33 deletions
|
@ -3,6 +3,7 @@ package udp
|
||||||
import (
|
import (
|
||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"hash"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -11,12 +12,67 @@ import (
|
||||||
"github.com/chihaya/chihaya/pkg/log"
|
"github.com/chihaya/chihaya/pkg/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ttl is the number of seconds a connection ID should be valid according to
|
// ttl is the duration a connection ID should be valid according to BEP 15.
|
||||||
// BEP 15.
|
|
||||||
const ttl = 2 * time.Minute
|
const ttl = 2 * time.Minute
|
||||||
|
|
||||||
// NewConnectionID creates a new 8 byte connection identifier for UDP packets
|
// NewConnectionID creates an 8-byte connection identifier for UDP packets as
|
||||||
// as described by BEP 15.
|
// described by BEP 15.
|
||||||
|
// This is a wrapper around creating a new ConnectionIDGenerator and generating
|
||||||
|
// an ID. It is recommended to use the generator for performance.
|
||||||
|
func NewConnectionID(ip net.IP, now time.Time, key string) []byte {
|
||||||
|
return NewConnectionIDGenerator(key).Generate(ip, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidConnectionID determines whether a connection identifier is legitimate.
|
||||||
|
// This is a wrapper around creating a new ConnectionIDGenerator and validating
|
||||||
|
// the ID. It is recommended to use the generator for performance.
|
||||||
|
func ValidConnectionID(connectionID []byte, ip net.IP, now time.Time, maxClockSkew time.Duration, key string) bool {
|
||||||
|
return NewConnectionIDGenerator(key).Validate(connectionID, ip, now, maxClockSkew)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A ConnectionIDGenerator is a reusable generator and validator for connection
|
||||||
|
// IDs as described in BEP 15.
|
||||||
|
// It is not thread safe, but is safe to be pooled and reused by other
|
||||||
|
// goroutines. It manages its state itself, so it can be taken from and returned
|
||||||
|
// to a pool without any cleanup.
|
||||||
|
// After initial creation, it can generate connection IDs without allocating.
|
||||||
|
// See Generate and Validate for usage notes and guarantees.
|
||||||
|
type ConnectionIDGenerator struct {
|
||||||
|
// mac is a keyed HMAC that can be reused for subsequent connection ID
|
||||||
|
// generations.
|
||||||
|
mac hash.Hash
|
||||||
|
|
||||||
|
// connID is an 8-byte slice that holds the generated connection ID after a
|
||||||
|
// call to Generate.
|
||||||
|
// It must not be referenced after the generator is returned to a pool.
|
||||||
|
// It will be overwritten by subsequent calls to Generate.
|
||||||
|
connID []byte
|
||||||
|
|
||||||
|
// scratch is a 32-byte slice that is used as a scratchpad for the generated
|
||||||
|
// HMACs.
|
||||||
|
scratch []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConnectionIDGenerator creates a new connection ID generator.
|
||||||
|
func NewConnectionIDGenerator(key string) *ConnectionIDGenerator {
|
||||||
|
return &ConnectionIDGenerator{
|
||||||
|
mac: hmac.New(sha256.New, []byte(key)),
|
||||||
|
connID: make([]byte, 8),
|
||||||
|
scratch: make([]byte, 32),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// reset resets the generator.
|
||||||
|
// This is called by other methods of the generator, it's not necessary to call
|
||||||
|
// it after getting a generator from a pool.
|
||||||
|
func (g *ConnectionIDGenerator) reset() {
|
||||||
|
g.mac.Reset()
|
||||||
|
g.connID = g.connID[:8]
|
||||||
|
g.scratch = g.scratch[:0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate generates an 8-byte connection ID as described in BEP 15 for the
|
||||||
|
// given IP and the current time.
|
||||||
//
|
//
|
||||||
// The first 4 bytes of the connection identifier is a unix timestamp and the
|
// The first 4 bytes of the connection identifier is a unix timestamp and the
|
||||||
// last 4 bytes are a truncated HMAC token created from the aforementioned
|
// last 4 bytes are a truncated HMAC token created from the aforementioned
|
||||||
|
@ -25,31 +81,36 @@ const ttl = 2 * time.Minute
|
||||||
// Truncated HMAC is known to be safe for 2^(-n) where n is the size in bits
|
// Truncated HMAC is known to be safe for 2^(-n) where n is the size in bits
|
||||||
// of the truncated HMAC token. In this use case we have 32 bits, thus a
|
// of the truncated HMAC token. In this use case we have 32 bits, thus a
|
||||||
// forgery probability of approximately 1 in 4 billion.
|
// forgery probability of approximately 1 in 4 billion.
|
||||||
func NewConnectionID(ip net.IP, now time.Time, key string) []byte {
|
//
|
||||||
buf := make([]byte, 8)
|
// The generated ID is written to g.connID, which is also returned. g.connID
|
||||||
binary.BigEndian.PutUint32(buf, uint32(now.Unix()))
|
// will be reused, so it must not be referenced after returning the generator
|
||||||
|
// to a pool and will be overwritten be subsequent calls to Generate!
|
||||||
|
func (g *ConnectionIDGenerator) Generate(ip net.IP, now time.Time) []byte {
|
||||||
|
g.reset()
|
||||||
|
|
||||||
mac := hmac.New(sha256.New, []byte(key))
|
binary.BigEndian.PutUint32(g.connID, uint32(now.Unix()))
|
||||||
mac.Write(buf[:4])
|
|
||||||
mac.Write(ip)
|
|
||||||
macBytes := mac.Sum(nil)[:4]
|
|
||||||
copy(buf[4:], macBytes)
|
|
||||||
|
|
||||||
log.Debug("generated connection ID", log.Fields{"ip": ip, "now": now, "key": key, "connID": buf})
|
g.mac.Write(g.connID[:4])
|
||||||
return buf
|
g.mac.Write(ip)
|
||||||
|
g.scratch = g.mac.Sum(g.scratch)
|
||||||
|
copy(g.connID[4:8], g.scratch[:4])
|
||||||
|
|
||||||
|
log.Debug("generated connection ID", log.Fields{"ip": ip, "now": now, "connID": g.connID})
|
||||||
|
return g.connID
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidConnectionID determines whether a connection identifier is legitimate.
|
// Validate validates the given connection ID for an IP and the current time.
|
||||||
func ValidConnectionID(connectionID []byte, ip net.IP, now time.Time, maxClockSkew time.Duration, key string) bool {
|
func (g *ConnectionIDGenerator) Validate(connectionID []byte, ip net.IP, now time.Time, maxClockSkew time.Duration) bool {
|
||||||
ts := time.Unix(int64(binary.BigEndian.Uint32(connectionID[:4])), 0)
|
ts := time.Unix(int64(binary.BigEndian.Uint32(connectionID[:4])), 0)
|
||||||
log.Debug("validating connection ID", log.Fields{"connID": connectionID, "ip": ip, "ts": ts, "now": now, "key": key})
|
log.Debug("validating connection ID", log.Fields{"connID": connectionID, "ip": ip, "ts": ts, "now": now})
|
||||||
if now.After(ts.Add(ttl)) || ts.After(now.Add(maxClockSkew)) {
|
if now.After(ts.Add(ttl)) || ts.After(now.Add(maxClockSkew)) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
mac := hmac.New(sha256.New, []byte(key))
|
g.reset()
|
||||||
mac.Write(connectionID[:4])
|
|
||||||
mac.Write(ip)
|
g.mac.Write(connectionID[:4])
|
||||||
expectedMAC := mac.Sum(nil)[:4]
|
g.mac.Write(ip)
|
||||||
return hmac.Equal(expectedMAC, connectionID[4:])
|
g.scratch = g.mac.Sum(g.scratch)
|
||||||
|
return hmac.Equal(g.scratch[:4], connectionID[4:])
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,10 +1,19 @@
|
||||||
package udp
|
package udp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/minio/sha256-simd"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/chihaya/chihaya/pkg/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
var golden = []struct {
|
var golden = []struct {
|
||||||
|
@ -19,6 +28,24 @@ var golden = []struct {
|
||||||
{0, 0, "[::]", "", true},
|
{0, 0, "[::]", "", true},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// simpleNewConnectionID generates a new connection ID the explicit way.
|
||||||
|
// This is used to verify correct behaviour of the generator.
|
||||||
|
func simpleNewConnectionID(ip net.IP, now time.Time, key string) []byte {
|
||||||
|
buf := make([]byte, 8)
|
||||||
|
binary.BigEndian.PutUint32(buf, uint32(now.Unix()))
|
||||||
|
|
||||||
|
mac := hmac.New(sha256.New, []byte(key))
|
||||||
|
mac.Write(buf[:4])
|
||||||
|
mac.Write(ip)
|
||||||
|
macBytes := mac.Sum(nil)[:4]
|
||||||
|
copy(buf[4:], macBytes)
|
||||||
|
|
||||||
|
// this is just in here because logging impacts performance and we benchmark
|
||||||
|
// this version too.
|
||||||
|
log.Debug("manually generated connection ID", log.Fields{"ip": ip, "now": now, "connID": buf})
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
func TestVerification(t *testing.T) {
|
func TestVerification(t *testing.T) {
|
||||||
for _, tt := range golden {
|
for _, tt := range golden {
|
||||||
t.Run(fmt.Sprintf("%s created at %d verified at %d", tt.ip, tt.createdAt, tt.now), func(t *testing.T) {
|
t.Run(fmt.Sprintf("%s created at %d verified at %d", tt.ip, tt.createdAt, tt.now), func(t *testing.T) {
|
||||||
|
@ -31,18 +58,101 @@ func TestVerification(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGeneration(t *testing.T) {
|
||||||
|
for _, tt := range golden {
|
||||||
|
t.Run(fmt.Sprintf("%s created at %d", tt.ip, tt.createdAt), func(t *testing.T) {
|
||||||
|
want := simpleNewConnectionID(net.ParseIP(tt.ip), time.Unix(tt.createdAt, 0), tt.key)
|
||||||
|
got := NewConnectionID(net.ParseIP(tt.ip), time.Unix(tt.createdAt, 0), tt.key)
|
||||||
|
require.Equal(t, want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReuseGeneratorGenerate(t *testing.T) {
|
||||||
|
for _, tt := range golden {
|
||||||
|
t.Run(fmt.Sprintf("%s created at %d", tt.ip, tt.createdAt), func(t *testing.T) {
|
||||||
|
cid := NewConnectionID(net.ParseIP(tt.ip), time.Unix(tt.createdAt, 0), tt.key)
|
||||||
|
require.Len(t, cid, 8)
|
||||||
|
|
||||||
|
gen := NewConnectionIDGenerator(tt.key)
|
||||||
|
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
connID := gen.Generate(net.ParseIP(tt.ip), time.Unix(tt.createdAt, 0))
|
||||||
|
require.Equal(t, cid, connID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReuseGeneratorValidate(t *testing.T) {
|
||||||
|
for _, tt := range golden {
|
||||||
|
t.Run(fmt.Sprintf("%s created at %d verified at %d", tt.ip, tt.createdAt, tt.now), func(t *testing.T) {
|
||||||
|
gen := NewConnectionIDGenerator(tt.key)
|
||||||
|
cid := gen.Generate(net.ParseIP(tt.ip), time.Unix(tt.createdAt, 0))
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
got := gen.Validate(cid, net.ParseIP(tt.ip), time.Unix(tt.now, 0), time.Minute)
|
||||||
|
if got != tt.valid {
|
||||||
|
t.Errorf("expected validity: %t got validity: %t", tt.valid, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkSimpleNewConnectionID(b *testing.B) {
|
||||||
|
ip := net.ParseIP("127.0.0.1")
|
||||||
|
key := "some random string that is hopefully at least this long"
|
||||||
|
createdAt := time.Now()
|
||||||
|
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
sum := int64(0)
|
||||||
|
|
||||||
|
for pb.Next() {
|
||||||
|
cid := simpleNewConnectionID(ip, createdAt, key)
|
||||||
|
sum += int64(cid[7])
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = sum
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkNewConnectionID(b *testing.B) {
|
func BenchmarkNewConnectionID(b *testing.B) {
|
||||||
ip := net.ParseIP("127.0.0.1")
|
ip := net.ParseIP("127.0.0.1")
|
||||||
key := "some random string that is hopefully at least this long"
|
key := "some random string that is hopefully at least this long"
|
||||||
createdAt := time.Now()
|
createdAt := time.Now()
|
||||||
sum := int64(0)
|
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
cid := NewConnectionID(ip, createdAt, key)
|
sum := int64(0)
|
||||||
sum += int64(cid[7])
|
|
||||||
|
for pb.Next() {
|
||||||
|
cid := NewConnectionID(ip, createdAt, key)
|
||||||
|
sum += int64(cid[7])
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = sum
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkConnectionIDGenerator_Generate(b *testing.B) {
|
||||||
|
ip := net.ParseIP("127.0.0.1")
|
||||||
|
key := "some random string that is hopefully at least this long"
|
||||||
|
createdAt := time.Now()
|
||||||
|
|
||||||
|
pool := &sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return NewConnectionIDGenerator(key)
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = sum
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
sum := int64(0)
|
||||||
|
for pb.Next() {
|
||||||
|
gen := pool.Get().(*ConnectionIDGenerator)
|
||||||
|
cid := gen.Generate(ip, createdAt)
|
||||||
|
sum += int64(cid[7])
|
||||||
|
pool.Put(gen)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkValidConnectionID(b *testing.B) {
|
func BenchmarkValidConnectionID(b *testing.B) {
|
||||||
|
@ -51,9 +161,34 @@ func BenchmarkValidConnectionID(b *testing.B) {
|
||||||
createdAt := time.Now()
|
createdAt := time.Now()
|
||||||
cid := NewConnectionID(ip, createdAt, key)
|
cid := NewConnectionID(ip, createdAt, key)
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
if !ValidConnectionID(cid, ip, createdAt, 10*time.Second, key) {
|
for pb.Next() {
|
||||||
b.FailNow()
|
if !ValidConnectionID(cid, ip, createdAt, 10*time.Second, key) {
|
||||||
|
b.FailNow()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkConnectionIDGenerator_Validate(b *testing.B) {
|
||||||
|
ip := net.ParseIP("127.0.0.1")
|
||||||
|
key := "some random string that is hopefully at least this long"
|
||||||
|
createdAt := time.Now()
|
||||||
|
cid := NewConnectionID(ip, createdAt, key)
|
||||||
|
|
||||||
|
pool := &sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return NewConnectionIDGenerator(key)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
gen := pool.Get().(*ConnectionIDGenerator)
|
||||||
|
if !gen.Validate(cid, ip, createdAt, 10*time.Second) {
|
||||||
|
b.FailNow()
|
||||||
|
}
|
||||||
|
pool.Put(gen)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,6 +52,8 @@ type Frontend struct {
|
||||||
closing chan struct{}
|
closing chan struct{}
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
|
|
||||||
|
genPool *sync.Pool
|
||||||
|
|
||||||
logic frontend.TrackerLogic
|
logic frontend.TrackerLogic
|
||||||
Config
|
Config
|
||||||
}
|
}
|
||||||
|
@ -75,6 +77,11 @@ func NewFrontend(logic frontend.TrackerLogic, cfg Config) (*Frontend, error) {
|
||||||
closing: make(chan struct{}),
|
closing: make(chan struct{}),
|
||||||
logic: logic,
|
logic: logic,
|
||||||
Config: cfg,
|
Config: cfg,
|
||||||
|
genPool: &sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return NewConnectionIDGenerator(cfg.PrivateKey)
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -211,9 +218,13 @@ func (t *Frontend) handleRequest(r Request, w ResponseWriter) (actionName string
|
||||||
actionID := binary.BigEndian.Uint32(r.Packet[8:12])
|
actionID := binary.BigEndian.Uint32(r.Packet[8:12])
|
||||||
txID := r.Packet[12:16]
|
txID := r.Packet[12:16]
|
||||||
|
|
||||||
|
// get a connection ID generator/validator from the pool.
|
||||||
|
gen := t.genPool.Get().(*ConnectionIDGenerator)
|
||||||
|
defer t.genPool.Put(gen)
|
||||||
|
|
||||||
// If this isn't requesting a new connection ID and the connection ID is
|
// If this isn't requesting a new connection ID and the connection ID is
|
||||||
// invalid, then fail.
|
// invalid, then fail.
|
||||||
if actionID != connectActionID && !ValidConnectionID(connID, r.IP, timecache.Now(), t.MaxClockSkew, t.PrivateKey) {
|
if actionID != connectActionID && !gen.Validate(connID, r.IP, timecache.Now(), t.MaxClockSkew) {
|
||||||
err = errBadConnectionID
|
err = errBadConnectionID
|
||||||
WriteError(w, txID, err)
|
WriteError(w, txID, err)
|
||||||
return
|
return
|
||||||
|
@ -239,7 +250,7 @@ func (t *Frontend) handleRequest(r Request, w ResponseWriter) (actionName string
|
||||||
panic(fmt.Sprintf("udp: invalid IP: neither v4 nor v6, IP: %#v", r.IP))
|
panic(fmt.Sprintf("udp: invalid IP: neither v4 nor v6, IP: %#v", r.IP))
|
||||||
}
|
}
|
||||||
|
|
||||||
WriteConnectionID(w, txID, NewConnectionID(r.IP, timecache.Now(), t.PrivateKey))
|
WriteConnectionID(w, txID, gen.Generate(r.IP, timecache.Now()))
|
||||||
|
|
||||||
case announceActionID, announceV6ActionID:
|
case announceActionID, announceV6ActionID:
|
||||||
actionName = "announce"
|
actionName = "announce"
|
||||||
|
|
Loading…
Reference in a new issue