From 14a6278de015966e2a81a60387f402e8981eb945 Mon Sep 17 00:00:00 2001 From: Justin Li Date: Mon, 23 Feb 2015 21:30:45 -0500 Subject: [PATCH] udp: Use net.IP inputs for connection ID generation Add enum for action IDs Remove unnecessary length check --- udp/connection.go | 11 ++++------- udp/protocol.go | 19 +++++++++++-------- udp/udp.go | 2 +- udp/udp_test.go | 8 ++++---- udp/writer.go | 6 +++--- 5 files changed, 23 insertions(+), 23 deletions(-) diff --git a/udp/connection.go b/udp/connection.go index 27b497a..e08c0a6 100644 --- a/udp/connection.go +++ b/udp/connection.go @@ -9,6 +9,7 @@ import ( "crypto/aes" "crypto/cipher" "crypto/rand" + "net" ) // ConnectionIDGenerator represents the logic to generate 64-bit UDP @@ -39,15 +40,11 @@ func NewConnectionIDGenerator() (gen *ConnectionIDGenerator, err error) { } // Generate returns the 64-bit connection ID for an IP -func (g *ConnectionIDGenerator) Generate(ip []byte) []byte { +func (g *ConnectionIDGenerator) Generate(ip net.IP) []byte { return g.generate(ip, g.iv) } -func (g *ConnectionIDGenerator) generate(ip []byte, iv []byte) []byte { - if len(ip) > 16 { - panic("IP larger than 16 bytes") - } - +func (g *ConnectionIDGenerator) generate(ip net.IP, iv []byte) []byte { for len(ip) < 8 { ip = append(ip, ip...) // Not enough bits in output. } @@ -65,7 +62,7 @@ func (g *ConnectionIDGenerator) generate(ip []byte, iv []byte) []byte { // Matches checks if the given connection ID matches an IP with the current or // previous initialization vectors. -func (g *ConnectionIDGenerator) Matches(id []byte, ip []byte) bool { +func (g *ConnectionIDGenerator) Matches(id []byte, ip net.IP) bool { if expected := g.generate(ip, g.iv); bytes.Equal(id, expected) { return true } diff --git a/udp/protocol.go b/udp/protocol.go index bc5aff0..57cf8f3 100644 --- a/udp/protocol.go +++ b/udp/protocol.go @@ -26,6 +26,13 @@ var ( errBadConnectionID = models.ProtocolError("bad connection ID") ) +const ( + connectActionID uint32 = iota + announceActionID + scrapeActionID + errorActionID +) + // handleTorrentError writes err to w if err is a models.ClientError. func handleTorrentError(err error, w *Writer) { if err == nil { @@ -55,11 +62,7 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte transactionID: transactionID, } - defer func() { - if writer.buf.Len() > 0 { - response = writer.buf.Bytes() - } - }() + defer func() { response = writer.buf.Bytes() }() if action != 0 && !s.connIDGen.Matches(connID, addr.IP) { writer.WriteError(errBadConnectionID) @@ -67,7 +70,7 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte } switch action { - case 0: + case connectActionID: actionName = "connect" if !bytes.Equal(connID, initialConnectionID) { return // Malformed packet. @@ -76,7 +79,7 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte writer.writeHeader(0) writer.buf.Write(s.connIDGen.Generate(addr.IP)) - case 1: + case announceActionID: actionName = "announce" ann, err := s.newAnnounce(packet, addr.IP) @@ -86,7 +89,7 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte handleTorrentError(err, writer) - case 2: + case scrapeActionID: actionName = "scrape" scrape, err := s.newScrape(packet) diff --git a/udp/udp.go b/udp/udp.go index 890dc71..60bc152 100644 --- a/udp/udp.go +++ b/udp/udp.go @@ -74,7 +74,7 @@ func (s *Server) serve(listenAddr string) error { response, action := s.handlePacket(buffer[:n], addr) pool.GiveSlice(buffer) - if response != nil { + if len(response) > 0 { sock.WriteToUDP(response, addr) } diff --git a/udp/udp_test.go b/udp/udp_test.go index b9d57e3..8a6441f 100644 --- a/udp/udp_test.go +++ b/udp/udp_test.go @@ -20,10 +20,10 @@ import ( ) var testPort = "34137" -var connectAction = []byte{0, 0, 0, 0} -var announceAction = []byte{0, 0, 0, 1} -var scrapeAction = []byte{0, 0, 0, 2} -var errorAction = []byte{0, 0, 0, 3} +var connectAction = []byte{0, 0, 0, byte(connectActionID)} +var announceAction = []byte{0, 0, 0, byte(announceActionID)} +var scrapeAction = []byte{0, 0, 0, byte(scrapeActionID)} +var errorAction = []byte{0, 0, 0, byte(errorActionID)} func init() { stats.DefaultStats = stats.New(config.StatsConfig{}) diff --git a/udp/writer.go b/udp/writer.go index 2d26f9f..b28fd8a 100644 --- a/udp/writer.go +++ b/udp/writer.go @@ -22,7 +22,7 @@ type Writer struct { // WriteError writes the failure reason as a null-terminated string. func (w *Writer) WriteError(err error) error { - w.writeHeader(3) + w.writeHeader(errorActionID) w.buf.WriteString(err.Error()) w.buf.WriteRune('\000') return nil @@ -30,7 +30,7 @@ func (w *Writer) WriteError(err error) error { // WriteAnnounce encodes an announce response according to the UDP spec. func (w *Writer) WriteAnnounce(res *models.AnnounceResponse) error { - w.writeHeader(1) + w.writeHeader(announceActionID) binary.Write(w.buf, binary.BigEndian, uint32(res.Interval/time.Second)) binary.Write(w.buf, binary.BigEndian, uint32(res.Incomplete)) binary.Write(w.buf, binary.BigEndian, uint32(res.Complete)) @@ -45,7 +45,7 @@ func (w *Writer) WriteAnnounce(res *models.AnnounceResponse) error { // WriteScrape encodes a scrape response according to the UDP spec. func (w *Writer) WriteScrape(res *models.ScrapeResponse) error { - w.writeHeader(2) + w.writeHeader(scrapeActionID) for _, torrent := range res.Files { binary.Write(w.buf, binary.BigEndian, uint32(torrent.Seeders.Len()))