udp: Clean up connection ID checking

This commit is contained in:
Justin Li 2015-02-20 12:52:49 -05:00
parent 9526df74ad
commit 0d33210901

View file

@ -58,6 +58,17 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte
transactionID: transactionID, transactionID: transactionID,
} }
defer func() {
if writer.buf.Len() > 0 {
response = writer.buf.Bytes()
}
}()
if action != 0 && !bytes.Equal(connID, generatedConnID) {
writer.WriteError(errBadConnectionID)
return
}
switch action { switch action {
case 0: case 0:
// Connect request. // Connect request.
@ -71,36 +82,25 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte
case 1: case 1:
// Announce request. // Announce request.
if !bytes.Equal(connID, generatedConnID) { ann, err := s.newAnnounce(packet, addr.IP)
writer.WriteError(errBadConnectionID)
} else {
ann, err := s.newAnnounce(packet, addr.IP)
if err == nil { if err == nil {
err = s.tracker.HandleAnnounce(ann, writer) err = s.tracker.HandleAnnounce(ann, writer)
}
handleTorrentError(err, writer)
} }
handleTorrentError(err, writer)
case 2: case 2:
// Scrape request. // Scrape request.
if !bytes.Equal(connID, generatedConnID) { scrape, err := s.newScrape(packet)
writer.WriteError(errBadConnectionID)
} else {
scrape, err := s.newScrape(packet)
if err == nil { if err == nil {
err = s.tracker.HandleScrape(scrape, writer) err = s.tracker.HandleScrape(scrape, writer)
}
handleTorrentError(err, writer)
} }
handleTorrentError(err, writer)
} }
if writer.buf.Len() > 0 {
response = writer.buf.Bytes()
}
return return
} }
@ -166,6 +166,7 @@ func (s *Server) newScrape(packet []byte) (*models.Scrape, error) {
for len(packet) >= 20 { for len(packet) >= 20 {
infohash := packet[:20] infohash := packet[:20]
infohashes = append(infohashes, string(infohash)) infohashes = append(infohashes, string(infohash))
packet = packet[20:]
} }
return &models.Scrape{ return &models.Scrape{