udp: Use net.IP inputs for connection ID generation
Add enum for action IDs Remove unnecessary length check
This commit is contained in:
parent
3d28f281fb
commit
14a6278de0
5 changed files with 23 additions and 23 deletions
|
@ -9,6 +9,7 @@ import (
|
|||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"net"
|
||||
)
|
||||
|
||||
// ConnectionIDGenerator represents the logic to generate 64-bit UDP
|
||||
|
@ -39,15 +40,11 @@ func NewConnectionIDGenerator() (gen *ConnectionIDGenerator, err error) {
|
|||
}
|
||||
|
||||
// Generate returns the 64-bit connection ID for an IP
|
||||
func (g *ConnectionIDGenerator) Generate(ip []byte) []byte {
|
||||
func (g *ConnectionIDGenerator) Generate(ip net.IP) []byte {
|
||||
return g.generate(ip, g.iv)
|
||||
}
|
||||
|
||||
func (g *ConnectionIDGenerator) generate(ip []byte, iv []byte) []byte {
|
||||
if len(ip) > 16 {
|
||||
panic("IP larger than 16 bytes")
|
||||
}
|
||||
|
||||
func (g *ConnectionIDGenerator) generate(ip net.IP, iv []byte) []byte {
|
||||
for len(ip) < 8 {
|
||||
ip = append(ip, ip...) // Not enough bits in output.
|
||||
}
|
||||
|
@ -65,7 +62,7 @@ func (g *ConnectionIDGenerator) generate(ip []byte, iv []byte) []byte {
|
|||
|
||||
// Matches checks if the given connection ID matches an IP with the current or
|
||||
// previous initialization vectors.
|
||||
func (g *ConnectionIDGenerator) Matches(id []byte, ip []byte) bool {
|
||||
func (g *ConnectionIDGenerator) Matches(id []byte, ip net.IP) bool {
|
||||
if expected := g.generate(ip, g.iv); bytes.Equal(id, expected) {
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -26,6 +26,13 @@ var (
|
|||
errBadConnectionID = models.ProtocolError("bad connection ID")
|
||||
)
|
||||
|
||||
const (
|
||||
connectActionID uint32 = iota
|
||||
announceActionID
|
||||
scrapeActionID
|
||||
errorActionID
|
||||
)
|
||||
|
||||
// handleTorrentError writes err to w if err is a models.ClientError.
|
||||
func handleTorrentError(err error, w *Writer) {
|
||||
if err == nil {
|
||||
|
@ -55,11 +62,7 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte
|
|||
transactionID: transactionID,
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if writer.buf.Len() > 0 {
|
||||
response = writer.buf.Bytes()
|
||||
}
|
||||
}()
|
||||
defer func() { response = writer.buf.Bytes() }()
|
||||
|
||||
if action != 0 && !s.connIDGen.Matches(connID, addr.IP) {
|
||||
writer.WriteError(errBadConnectionID)
|
||||
|
@ -67,7 +70,7 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte
|
|||
}
|
||||
|
||||
switch action {
|
||||
case 0:
|
||||
case connectActionID:
|
||||
actionName = "connect"
|
||||
if !bytes.Equal(connID, initialConnectionID) {
|
||||
return // Malformed packet.
|
||||
|
@ -76,7 +79,7 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte
|
|||
writer.writeHeader(0)
|
||||
writer.buf.Write(s.connIDGen.Generate(addr.IP))
|
||||
|
||||
case 1:
|
||||
case announceActionID:
|
||||
actionName = "announce"
|
||||
ann, err := s.newAnnounce(packet, addr.IP)
|
||||
|
||||
|
@ -86,7 +89,7 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte
|
|||
|
||||
handleTorrentError(err, writer)
|
||||
|
||||
case 2:
|
||||
case scrapeActionID:
|
||||
actionName = "scrape"
|
||||
scrape, err := s.newScrape(packet)
|
||||
|
||||
|
|
|
@ -74,7 +74,7 @@ func (s *Server) serve(listenAddr string) error {
|
|||
response, action := s.handlePacket(buffer[:n], addr)
|
||||
pool.GiveSlice(buffer)
|
||||
|
||||
if response != nil {
|
||||
if len(response) > 0 {
|
||||
sock.WriteToUDP(response, addr)
|
||||
}
|
||||
|
||||
|
|
|
@ -20,10 +20,10 @@ import (
|
|||
)
|
||||
|
||||
var testPort = "34137"
|
||||
var connectAction = []byte{0, 0, 0, 0}
|
||||
var announceAction = []byte{0, 0, 0, 1}
|
||||
var scrapeAction = []byte{0, 0, 0, 2}
|
||||
var errorAction = []byte{0, 0, 0, 3}
|
||||
var connectAction = []byte{0, 0, 0, byte(connectActionID)}
|
||||
var announceAction = []byte{0, 0, 0, byte(announceActionID)}
|
||||
var scrapeAction = []byte{0, 0, 0, byte(scrapeActionID)}
|
||||
var errorAction = []byte{0, 0, 0, byte(errorActionID)}
|
||||
|
||||
func init() {
|
||||
stats.DefaultStats = stats.New(config.StatsConfig{})
|
||||
|
|
|
@ -22,7 +22,7 @@ type Writer struct {
|
|||
|
||||
// WriteError writes the failure reason as a null-terminated string.
|
||||
func (w *Writer) WriteError(err error) error {
|
||||
w.writeHeader(3)
|
||||
w.writeHeader(errorActionID)
|
||||
w.buf.WriteString(err.Error())
|
||||
w.buf.WriteRune('\000')
|
||||
return nil
|
||||
|
@ -30,7 +30,7 @@ func (w *Writer) WriteError(err error) error {
|
|||
|
||||
// WriteAnnounce encodes an announce response according to the UDP spec.
|
||||
func (w *Writer) WriteAnnounce(res *models.AnnounceResponse) error {
|
||||
w.writeHeader(1)
|
||||
w.writeHeader(announceActionID)
|
||||
binary.Write(w.buf, binary.BigEndian, uint32(res.Interval/time.Second))
|
||||
binary.Write(w.buf, binary.BigEndian, uint32(res.Incomplete))
|
||||
binary.Write(w.buf, binary.BigEndian, uint32(res.Complete))
|
||||
|
@ -45,7 +45,7 @@ func (w *Writer) WriteAnnounce(res *models.AnnounceResponse) error {
|
|||
|
||||
// WriteScrape encodes a scrape response according to the UDP spec.
|
||||
func (w *Writer) WriteScrape(res *models.ScrapeResponse) error {
|
||||
w.writeHeader(2)
|
||||
w.writeHeader(scrapeActionID)
|
||||
|
||||
for _, torrent := range res.Files {
|
||||
binary.Write(w.buf, binary.BigEndian, uint32(torrent.Seeders.Len()))
|
||||
|
|
Loading…
Add table
Reference in a new issue