udp: Use net.IP inputs for connection ID generation

Add enum for action IDs

Remove unnecessary length check
This commit is contained in:
Justin Li 2015-02-23 21:30:45 -05:00
parent 3d28f281fb
commit 14a6278de0
5 changed files with 23 additions and 23 deletions

View file

@ -9,6 +9,7 @@ import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"net"
)
// ConnectionIDGenerator represents the logic to generate 64-bit UDP
@ -39,15 +40,11 @@ func NewConnectionIDGenerator() (gen *ConnectionIDGenerator, err error) {
}
// Generate returns the 64-bit connection ID for an IP
func (g *ConnectionIDGenerator) Generate(ip []byte) []byte {
func (g *ConnectionIDGenerator) Generate(ip net.IP) []byte {
return g.generate(ip, g.iv)
}
func (g *ConnectionIDGenerator) generate(ip []byte, iv []byte) []byte {
if len(ip) > 16 {
panic("IP larger than 16 bytes")
}
func (g *ConnectionIDGenerator) generate(ip net.IP, iv []byte) []byte {
for len(ip) < 8 {
ip = append(ip, ip...) // Not enough bits in output.
}
@ -65,7 +62,7 @@ func (g *ConnectionIDGenerator) generate(ip []byte, iv []byte) []byte {
// Matches checks if the given connection ID matches an IP with the current or
// previous initialization vectors.
func (g *ConnectionIDGenerator) Matches(id []byte, ip []byte) bool {
func (g *ConnectionIDGenerator) Matches(id []byte, ip net.IP) bool {
if expected := g.generate(ip, g.iv); bytes.Equal(id, expected) {
return true
}

View file

@ -26,6 +26,13 @@ var (
errBadConnectionID = models.ProtocolError("bad connection ID")
)
const (
connectActionID uint32 = iota
announceActionID
scrapeActionID
errorActionID
)
// handleTorrentError writes err to w if err is a models.ClientError.
func handleTorrentError(err error, w *Writer) {
if err == nil {
@ -55,11 +62,7 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte
transactionID: transactionID,
}
defer func() {
if writer.buf.Len() > 0 {
response = writer.buf.Bytes()
}
}()
defer func() { response = writer.buf.Bytes() }()
if action != 0 && !s.connIDGen.Matches(connID, addr.IP) {
writer.WriteError(errBadConnectionID)
@ -67,7 +70,7 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte
}
switch action {
case 0:
case connectActionID:
actionName = "connect"
if !bytes.Equal(connID, initialConnectionID) {
return // Malformed packet.
@ -76,7 +79,7 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte
writer.writeHeader(0)
writer.buf.Write(s.connIDGen.Generate(addr.IP))
case 1:
case announceActionID:
actionName = "announce"
ann, err := s.newAnnounce(packet, addr.IP)
@ -86,7 +89,7 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte
handleTorrentError(err, writer)
case 2:
case scrapeActionID:
actionName = "scrape"
scrape, err := s.newScrape(packet)

View file

@ -74,7 +74,7 @@ func (s *Server) serve(listenAddr string) error {
response, action := s.handlePacket(buffer[:n], addr)
pool.GiveSlice(buffer)
if response != nil {
if len(response) > 0 {
sock.WriteToUDP(response, addr)
}

View file

@ -20,10 +20,10 @@ import (
)
var testPort = "34137"
var connectAction = []byte{0, 0, 0, 0}
var announceAction = []byte{0, 0, 0, 1}
var scrapeAction = []byte{0, 0, 0, 2}
var errorAction = []byte{0, 0, 0, 3}
var connectAction = []byte{0, 0, 0, byte(connectActionID)}
var announceAction = []byte{0, 0, 0, byte(announceActionID)}
var scrapeAction = []byte{0, 0, 0, byte(scrapeActionID)}
var errorAction = []byte{0, 0, 0, byte(errorActionID)}
func init() {
stats.DefaultStats = stats.New(config.StatsConfig{})

View file

@ -22,7 +22,7 @@ type Writer struct {
// WriteError writes the failure reason as a null-terminated string.
func (w *Writer) WriteError(err error) error {
w.writeHeader(3)
w.writeHeader(errorActionID)
w.buf.WriteString(err.Error())
w.buf.WriteRune('\000')
return nil
@ -30,7 +30,7 @@ func (w *Writer) WriteError(err error) error {
// WriteAnnounce encodes an announce response according to the UDP spec.
func (w *Writer) WriteAnnounce(res *models.AnnounceResponse) error {
w.writeHeader(1)
w.writeHeader(announceActionID)
binary.Write(w.buf, binary.BigEndian, uint32(res.Interval/time.Second))
binary.Write(w.buf, binary.BigEndian, uint32(res.Incomplete))
binary.Write(w.buf, binary.BigEndian, uint32(res.Complete))
@ -45,7 +45,7 @@ func (w *Writer) WriteAnnounce(res *models.AnnounceResponse) error {
// WriteScrape encodes a scrape response according to the UDP spec.
func (w *Writer) WriteScrape(res *models.ScrapeResponse) error {
w.writeHeader(2)
w.writeHeader(scrapeActionID)
for _, torrent := range res.Files {
binary.Write(w.buf, binary.BigEndian, uint32(torrent.Seeders.Len()))