diff --git a/udp/protocol.go b/udp/protocol.go index 2af5799..e0f3a89 100644 --- a/udp/protocol.go +++ b/udp/protocol.go @@ -22,6 +22,7 @@ var ( errMalformedPacket = errors.New("malformed packet") errMalformedIP = errors.New("malformed IP address") errMalformedEvent = errors.New("malformed event ID") + errBadConnectionID = errors.New("bad connection ID") ) func writeHeader(response []byte, action uint32, transactionID []byte) { @@ -50,7 +51,12 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte transactionID := packet[12:16] generatedConnID := GenerateConnectionID(addr.IP) - writer := &Writer{transactionID: transactionID} + writer := &Writer{ + buf: new(bytes.Buffer), + + connectionID: connID, + transactionID: transactionID, + } switch action { case 0: @@ -65,22 +71,34 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte case 1: // Announce request. - writer.buf = new(bytes.Buffer) - ann, err := s.newAnnounce(packet, addr.IP) + if !bytes.Equal(connID, generatedConnID) { + writer.WriteError(errBadConnectionID) + } else { + ann, err := s.newAnnounce(packet, addr.IP) - if err == nil { - err = s.tracker.HandleAnnounce(ann, writer) + if err == nil { + err = s.tracker.HandleAnnounce(ann, writer) + } + + handleTorrentError(err, writer) } - handleTorrentError(err, writer) - case 2: // Scrape request. - writer.buf = new(bytes.Buffer) - // handleTorrentError(s.tracker.HandleScrape(scrape, writer), writer) + if !bytes.Equal(connID, generatedConnID) { + writer.WriteError(errBadConnectionID) + } else { + scrape, err := s.newScrape(packet) + + if err == nil { + err = s.tracker.HandleScrape(scrape, writer) + } + + handleTorrentError(err, writer) + } } - if writer.buf != nil { + if writer.buf.Len() > 0 { response = writer.buf.Bytes() } return @@ -132,3 +150,26 @@ func (s *Server) newAnnounce(packet []byte, ip net.IP) (*models.Announce, error) Uploaded: uploaded, }, nil } + +func (s *Server) newScrape(packet []byte) (*models.Scrape, error) { + if len(packet) < 16 { + return nil, errMalformedPacket + } + + var infohashes []string + packet = packet[16:] + + if len(packet)%20 != 0 { + return nil, errMalformedPacket + } + + for len(packet) >= 20 { + infohash := packet[:20] + infohashes = append(infohashes, string(infohash)) + } + + return &models.Scrape{ + Config: s.config, + Infohashes: infohashes, + }, nil +} diff --git a/udp/writer.go b/udp/writer.go index 1b6f6b1..c6104ff 100644 --- a/udp/writer.go +++ b/udp/writer.go @@ -14,6 +14,7 @@ import ( type Writer struct { buf *bytes.Buffer + connectionID []byte transactionID []byte } @@ -40,6 +41,13 @@ func (w *Writer) WriteAnnounce(res *models.AnnounceResponse) error { func (w *Writer) WriteScrape(res *models.ScrapeResponse) error { w.writeHeader(2) + + for _, torrent := range res.Files { + binary.Write(w.buf, binary.BigEndian, uint32(torrent.Seeders.Len())) + binary.Write(w.buf, binary.BigEndian, uint32(torrent.Snatches)) + binary.Write(w.buf, binary.BigEndian, uint32(torrent.Leechers.Len())) + } + return nil }