udp: Use net.IP inputs for connection ID generation

Add enum for action IDs

Remove unnecessary length check
This commit is contained in:
Justin Li 2015-02-23 21:30:45 -05:00
parent 3d28f281fb
commit 14a6278de0
5 changed files with 23 additions and 23 deletions

View file

@ -9,6 +9,7 @@ import (
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/rand" "crypto/rand"
"net"
) )
// ConnectionIDGenerator represents the logic to generate 64-bit UDP // ConnectionIDGenerator represents the logic to generate 64-bit UDP
@ -39,15 +40,11 @@ func NewConnectionIDGenerator() (gen *ConnectionIDGenerator, err error) {
} }
// Generate returns the 64-bit connection ID for an IP // Generate returns the 64-bit connection ID for an IP
func (g *ConnectionIDGenerator) Generate(ip []byte) []byte { func (g *ConnectionIDGenerator) Generate(ip net.IP) []byte {
return g.generate(ip, g.iv) return g.generate(ip, g.iv)
} }
func (g *ConnectionIDGenerator) generate(ip []byte, iv []byte) []byte { func (g *ConnectionIDGenerator) generate(ip net.IP, iv []byte) []byte {
if len(ip) > 16 {
panic("IP larger than 16 bytes")
}
for len(ip) < 8 { for len(ip) < 8 {
ip = append(ip, ip...) // Not enough bits in output. ip = append(ip, ip...) // Not enough bits in output.
} }
@ -65,7 +62,7 @@ func (g *ConnectionIDGenerator) generate(ip []byte, iv []byte) []byte {
// Matches checks if the given connection ID matches an IP with the current or // Matches checks if the given connection ID matches an IP with the current or
// previous initialization vectors. // previous initialization vectors.
func (g *ConnectionIDGenerator) Matches(id []byte, ip []byte) bool { func (g *ConnectionIDGenerator) Matches(id []byte, ip net.IP) bool {
if expected := g.generate(ip, g.iv); bytes.Equal(id, expected) { if expected := g.generate(ip, g.iv); bytes.Equal(id, expected) {
return true return true
} }

View file

@ -26,6 +26,13 @@ var (
errBadConnectionID = models.ProtocolError("bad connection ID") errBadConnectionID = models.ProtocolError("bad connection ID")
) )
const (
connectActionID uint32 = iota
announceActionID
scrapeActionID
errorActionID
)
// handleTorrentError writes err to w if err is a models.ClientError. // handleTorrentError writes err to w if err is a models.ClientError.
func handleTorrentError(err error, w *Writer) { func handleTorrentError(err error, w *Writer) {
if err == nil { if err == nil {
@ -55,11 +62,7 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte
transactionID: transactionID, transactionID: transactionID,
} }
defer func() { defer func() { response = writer.buf.Bytes() }()
if writer.buf.Len() > 0 {
response = writer.buf.Bytes()
}
}()
if action != 0 && !s.connIDGen.Matches(connID, addr.IP) { if action != 0 && !s.connIDGen.Matches(connID, addr.IP) {
writer.WriteError(errBadConnectionID) writer.WriteError(errBadConnectionID)
@ -67,7 +70,7 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte
} }
switch action { switch action {
case 0: case connectActionID:
actionName = "connect" actionName = "connect"
if !bytes.Equal(connID, initialConnectionID) { if !bytes.Equal(connID, initialConnectionID) {
return // Malformed packet. return // Malformed packet.
@ -76,7 +79,7 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte
writer.writeHeader(0) writer.writeHeader(0)
writer.buf.Write(s.connIDGen.Generate(addr.IP)) writer.buf.Write(s.connIDGen.Generate(addr.IP))
case 1: case announceActionID:
actionName = "announce" actionName = "announce"
ann, err := s.newAnnounce(packet, addr.IP) ann, err := s.newAnnounce(packet, addr.IP)
@ -86,7 +89,7 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte
handleTorrentError(err, writer) handleTorrentError(err, writer)
case 2: case scrapeActionID:
actionName = "scrape" actionName = "scrape"
scrape, err := s.newScrape(packet) scrape, err := s.newScrape(packet)

View file

@ -74,7 +74,7 @@ func (s *Server) serve(listenAddr string) error {
response, action := s.handlePacket(buffer[:n], addr) response, action := s.handlePacket(buffer[:n], addr)
pool.GiveSlice(buffer) pool.GiveSlice(buffer)
if response != nil { if len(response) > 0 {
sock.WriteToUDP(response, addr) sock.WriteToUDP(response, addr)
} }

View file

@ -20,10 +20,10 @@ import (
) )
var testPort = "34137" var testPort = "34137"
var connectAction = []byte{0, 0, 0, 0} var connectAction = []byte{0, 0, 0, byte(connectActionID)}
var announceAction = []byte{0, 0, 0, 1} var announceAction = []byte{0, 0, 0, byte(announceActionID)}
var scrapeAction = []byte{0, 0, 0, 2} var scrapeAction = []byte{0, 0, 0, byte(scrapeActionID)}
var errorAction = []byte{0, 0, 0, 3} var errorAction = []byte{0, 0, 0, byte(errorActionID)}
func init() { func init() {
stats.DefaultStats = stats.New(config.StatsConfig{}) stats.DefaultStats = stats.New(config.StatsConfig{})

View file

@ -22,7 +22,7 @@ type Writer struct {
// WriteError writes the failure reason as a null-terminated string. // WriteError writes the failure reason as a null-terminated string.
func (w *Writer) WriteError(err error) error { func (w *Writer) WriteError(err error) error {
w.writeHeader(3) w.writeHeader(errorActionID)
w.buf.WriteString(err.Error()) w.buf.WriteString(err.Error())
w.buf.WriteRune('\000') w.buf.WriteRune('\000')
return nil return nil
@ -30,7 +30,7 @@ func (w *Writer) WriteError(err error) error {
// WriteAnnounce encodes an announce response according to the UDP spec. // WriteAnnounce encodes an announce response according to the UDP spec.
func (w *Writer) WriteAnnounce(res *models.AnnounceResponse) error { func (w *Writer) WriteAnnounce(res *models.AnnounceResponse) error {
w.writeHeader(1) w.writeHeader(announceActionID)
binary.Write(w.buf, binary.BigEndian, uint32(res.Interval/time.Second)) binary.Write(w.buf, binary.BigEndian, uint32(res.Interval/time.Second))
binary.Write(w.buf, binary.BigEndian, uint32(res.Incomplete)) binary.Write(w.buf, binary.BigEndian, uint32(res.Incomplete))
binary.Write(w.buf, binary.BigEndian, uint32(res.Complete)) binary.Write(w.buf, binary.BigEndian, uint32(res.Complete))
@ -45,7 +45,7 @@ func (w *Writer) WriteAnnounce(res *models.AnnounceResponse) error {
// WriteScrape encodes a scrape response according to the UDP spec. // WriteScrape encodes a scrape response according to the UDP spec.
func (w *Writer) WriteScrape(res *models.ScrapeResponse) error { func (w *Writer) WriteScrape(res *models.ScrapeResponse) error {
w.writeHeader(2) w.writeHeader(scrapeActionID)
for _, torrent := range res.Files { for _, torrent := range res.Files {
binary.Write(w.buf, binary.BigEndian, uint32(torrent.Seeders.Len())) binary.Write(w.buf, binary.BigEndian, uint32(torrent.Seeders.Len()))