diff --git a/udp/connection.go b/udp/connection.go index 8b8c964..ce4d063 100644 --- a/udp/connection.go +++ b/udp/connection.go @@ -5,31 +5,37 @@ package udp import ( + "bytes" "crypto/aes" "crypto/cipher" "crypto/rand" ) -var connectionKey, connectionIV []byte +type ConnectionIDGenerator struct { + iv, iv2 []byte + block cipher.Block +} -func InitConnectionIDEncryption() error { - connectionKey = make([]byte, 16) - _, err := rand.Read(connectionKey) +func (g *ConnectionIDGenerator) Init() error { + key := make([]byte, 16) + _, err := rand.Read(key) if err != nil { return err } - connectionIV = make([]byte, 16) - _, err = rand.Read(connectionIV) - return err -} - -func GenerateConnectionID(ip []byte) []byte { - block, err := aes.NewCipher(connectionKey) + g.block, err = aes.NewCipher(key) if err != nil { - panic(err) + return err } + return g.NewIV() +} + +func (g *ConnectionIDGenerator) Generate(ip []byte) []byte { + return g.generate(ip, g.iv) +} + +func (g *ConnectionIDGenerator) generate(ip []byte, iv []byte) []byte { if len(ip) > 16 { panic("IP larger than 16 bytes") } @@ -39,7 +45,7 @@ func GenerateConnectionID(ip []byte) []byte { } ct := make([]byte, 16) - stream := cipher.NewCFBDecrypter(block, connectionIV) + stream := cipher.NewCFBDecrypter(g.block, iv) stream.XORKeyStream(ct, ip) for i := len(ip) - 1; i >= 8; i-- { @@ -49,6 +55,28 @@ func GenerateConnectionID(ip []byte) []byte { return ct[:8] } -func init() { - InitConnectionIDEncryption() +func (g *ConnectionIDGenerator) Matches(id []byte, ip []byte) bool { + if expected := g.generate(ip, g.iv); bytes.Equal(id, expected) { + return true + } + + if iv2 := g.iv2; iv2 != nil { + if expected := g.generate(ip, iv2); bytes.Equal(id, expected) { + return true + } + } + + return false +} + +func (g *ConnectionIDGenerator) NewIV() error { + newiv := make([]byte, 16) + if _, err := rand.Read(newiv); err != nil { + return err + } + + g.iv2 = g.iv + g.iv = newiv + + return nil } diff --git a/udp/connection_test.go b/udp/connection_test.go index 7c4aa39..aebccbc 100644 --- a/udp/connection_test.go +++ b/udp/connection_test.go @@ -11,16 +11,18 @@ import ( ) func TestInitReturnsNoError(t *testing.T) { - if err := InitConnectionIDEncryption(); err != nil { - t.Error("InitConnectionIDEncryption returned", err) + gen := &ConnectionIDGenerator{} + if err := gen.Init(); err != nil { + t.Error("Init returned", err) } } func testGenerateConnectionID(t *testing.T, ip net.IP) { - InitConnectionIDEncryption() + gen := &ConnectionIDGenerator{} + gen.Init() - id1 := GenerateConnectionID(ip) - id2 := GenerateConnectionID(ip) + id1 := gen.Generate(ip) + id2 := gen.Generate(ip) if !bytes.Equal(id1, id2) { t.Errorf("Connection ID mismatch: %x != %x", id1, id2) @@ -42,3 +44,30 @@ func TestGenerateConnectionIDIPv4(t *testing.T) { func TestGenerateConnectionIDIPv6(t *testing.T) { testGenerateConnectionID(t, net.ParseIP("1:2:3:4::5:6")) } + +func TestMatchesWorksWithPreviousIV(t *testing.T) { + gen := &ConnectionIDGenerator{} + gen.Init() + ip := net.ParseIP("192.168.1.123").To4() + + id1 := gen.Generate(ip) + if !gen.Matches(id1, ip) { + t.Errorf("Connection ID mismatch for current IV") + } + + gen.NewIV() + if !gen.Matches(id1, ip) { + t.Errorf("Connection ID mismatch for previous IV") + } + + id2 := gen.Generate(ip) + gen.NewIV() + + if gen.Matches(id1, ip) { + t.Errorf("Connection ID matched for discarded IV") + } + + if !gen.Matches(id2, ip) { + t.Errorf("Connection ID mismatch for previous IV") + } +} diff --git a/udp/protocol.go b/udp/protocol.go index 9958838..6948ad7 100644 --- a/udp/protocol.go +++ b/udp/protocol.go @@ -47,7 +47,6 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte connID := packet[0:8] action := binary.BigEndian.Uint32(packet[8:12]) transactionID := packet[12:16] - generatedConnID := GenerateConnectionID(addr.IP) writer := &Writer{ buf: new(bytes.Buffer), @@ -62,7 +61,7 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte } }() - if action != 0 && !bytes.Equal(connID, generatedConnID) { + if action != 0 && !s.connIDGen.Matches(connID, addr.IP) { writer.WriteError(errBadConnectionID) return } @@ -75,7 +74,7 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte } writer.writeHeader(0) - writer.buf.Write(generatedConnID) + writer.buf.Write(s.connIDGen.Generate(addr.IP)) case 1: actionName = "announce" @@ -120,13 +119,13 @@ func (s *Server) newAnnounce(packet []byte, ip net.IP) (*models.Announce, error) } ipbuf := packet[84:88] - if !bytes.Equal(ipbuf, []byte{0, 0, 0, 0}) { + if s.config.AllowIPSpoofing && !bytes.Equal(ipbuf, []byte{0, 0, 0, 0}) { ip = net.ParseIP(string(ipbuf)) } + if ip == nil { return nil, errMalformedIP - } - if ipv4 := ip.To4(); ipv4 != nil { + } else if ipv4 := ip.To4(); ipv4 != nil { ip = ipv4 } diff --git a/udp/udp.go b/udp/udp.go index f84b3b0..b838171 100644 --- a/udp/udp.go +++ b/udp/udp.go @@ -19,10 +19,10 @@ import ( // Server represents a UDP torrent tracker. type Server struct { - config *config.Config - tracker *tracker.Tracker - - done bool + config *config.Config + tracker *tracker.Tracker + connIDGen *ConnectionIDGenerator + done bool } func (s *Server) serve() error { @@ -66,7 +66,7 @@ func (s *Server) serve() error { if glog.V(2) { duration := time.Since(start) - glog.Infof("[UDP - %9s] %s", duration, action) + glog.Infof("[UDP - %9s] %s %s", duration, action, addr) } }() } @@ -78,6 +78,13 @@ func (s *Server) serve() error { func (s *Server) Serve() { glog.V(0).Info("Starting UDP on ", s.config.UDPListenAddr) + go func() { + // Generate a new IV every hour. + for range time.Tick(time.Hour) { + s.connIDGen.NewIV() + } + }() + if err := s.serve(); err != nil { glog.Errorf("Failed to run UDP server: %s", err.Error()) } else { @@ -92,8 +99,14 @@ func (s *Server) Stop() { // NewServer returns a new UDP server for a given configuration and tracker. func NewServer(cfg *config.Config, tkr *tracker.Tracker) *Server { + gen := &ConnectionIDGenerator{} + if err := gen.Init(); err != nil { + panic(err) + } + return &Server{ - config: cfg, - tracker: tkr, + config: cfg, + tracker: tkr, + connIDGen: gen, } }