udp: Use net.IP inputs for connection ID generation
Add enum for action IDs Remove unnecessary length check
This commit is contained in:
parent
3d28f281fb
commit
14a6278de0
5 changed files with 23 additions and 23 deletions
|
@ -9,6 +9,7 @@ import (
|
||||||
"crypto/aes"
|
"crypto/aes"
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionIDGenerator represents the logic to generate 64-bit UDP
|
// ConnectionIDGenerator represents the logic to generate 64-bit UDP
|
||||||
|
@ -39,15 +40,11 @@ func NewConnectionIDGenerator() (gen *ConnectionIDGenerator, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate returns the 64-bit connection ID for an IP
|
// Generate returns the 64-bit connection ID for an IP
|
||||||
func (g *ConnectionIDGenerator) Generate(ip []byte) []byte {
|
func (g *ConnectionIDGenerator) Generate(ip net.IP) []byte {
|
||||||
return g.generate(ip, g.iv)
|
return g.generate(ip, g.iv)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *ConnectionIDGenerator) generate(ip []byte, iv []byte) []byte {
|
func (g *ConnectionIDGenerator) generate(ip net.IP, iv []byte) []byte {
|
||||||
if len(ip) > 16 {
|
|
||||||
panic("IP larger than 16 bytes")
|
|
||||||
}
|
|
||||||
|
|
||||||
for len(ip) < 8 {
|
for len(ip) < 8 {
|
||||||
ip = append(ip, ip...) // Not enough bits in output.
|
ip = append(ip, ip...) // Not enough bits in output.
|
||||||
}
|
}
|
||||||
|
@ -65,7 +62,7 @@ func (g *ConnectionIDGenerator) generate(ip []byte, iv []byte) []byte {
|
||||||
|
|
||||||
// Matches checks if the given connection ID matches an IP with the current or
|
// Matches checks if the given connection ID matches an IP with the current or
|
||||||
// previous initialization vectors.
|
// previous initialization vectors.
|
||||||
func (g *ConnectionIDGenerator) Matches(id []byte, ip []byte) bool {
|
func (g *ConnectionIDGenerator) Matches(id []byte, ip net.IP) bool {
|
||||||
if expected := g.generate(ip, g.iv); bytes.Equal(id, expected) {
|
if expected := g.generate(ip, g.iv); bytes.Equal(id, expected) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,13 @@ var (
|
||||||
errBadConnectionID = models.ProtocolError("bad connection ID")
|
errBadConnectionID = models.ProtocolError("bad connection ID")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
connectActionID uint32 = iota
|
||||||
|
announceActionID
|
||||||
|
scrapeActionID
|
||||||
|
errorActionID
|
||||||
|
)
|
||||||
|
|
||||||
// handleTorrentError writes err to w if err is a models.ClientError.
|
// handleTorrentError writes err to w if err is a models.ClientError.
|
||||||
func handleTorrentError(err error, w *Writer) {
|
func handleTorrentError(err error, w *Writer) {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -55,11 +62,7 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte
|
||||||
transactionID: transactionID,
|
transactionID: transactionID,
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() { response = writer.buf.Bytes() }()
|
||||||
if writer.buf.Len() > 0 {
|
|
||||||
response = writer.buf.Bytes()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if action != 0 && !s.connIDGen.Matches(connID, addr.IP) {
|
if action != 0 && !s.connIDGen.Matches(connID, addr.IP) {
|
||||||
writer.WriteError(errBadConnectionID)
|
writer.WriteError(errBadConnectionID)
|
||||||
|
@ -67,7 +70,7 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
switch action {
|
switch action {
|
||||||
case 0:
|
case connectActionID:
|
||||||
actionName = "connect"
|
actionName = "connect"
|
||||||
if !bytes.Equal(connID, initialConnectionID) {
|
if !bytes.Equal(connID, initialConnectionID) {
|
||||||
return // Malformed packet.
|
return // Malformed packet.
|
||||||
|
@ -76,7 +79,7 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte
|
||||||
writer.writeHeader(0)
|
writer.writeHeader(0)
|
||||||
writer.buf.Write(s.connIDGen.Generate(addr.IP))
|
writer.buf.Write(s.connIDGen.Generate(addr.IP))
|
||||||
|
|
||||||
case 1:
|
case announceActionID:
|
||||||
actionName = "announce"
|
actionName = "announce"
|
||||||
ann, err := s.newAnnounce(packet, addr.IP)
|
ann, err := s.newAnnounce(packet, addr.IP)
|
||||||
|
|
||||||
|
@ -86,7 +89,7 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte
|
||||||
|
|
||||||
handleTorrentError(err, writer)
|
handleTorrentError(err, writer)
|
||||||
|
|
||||||
case 2:
|
case scrapeActionID:
|
||||||
actionName = "scrape"
|
actionName = "scrape"
|
||||||
scrape, err := s.newScrape(packet)
|
scrape, err := s.newScrape(packet)
|
||||||
|
|
||||||
|
|
|
@ -74,7 +74,7 @@ func (s *Server) serve(listenAddr string) error {
|
||||||
response, action := s.handlePacket(buffer[:n], addr)
|
response, action := s.handlePacket(buffer[:n], addr)
|
||||||
pool.GiveSlice(buffer)
|
pool.GiveSlice(buffer)
|
||||||
|
|
||||||
if response != nil {
|
if len(response) > 0 {
|
||||||
sock.WriteToUDP(response, addr)
|
sock.WriteToUDP(response, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,10 +20,10 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var testPort = "34137"
|
var testPort = "34137"
|
||||||
var connectAction = []byte{0, 0, 0, 0}
|
var connectAction = []byte{0, 0, 0, byte(connectActionID)}
|
||||||
var announceAction = []byte{0, 0, 0, 1}
|
var announceAction = []byte{0, 0, 0, byte(announceActionID)}
|
||||||
var scrapeAction = []byte{0, 0, 0, 2}
|
var scrapeAction = []byte{0, 0, 0, byte(scrapeActionID)}
|
||||||
var errorAction = []byte{0, 0, 0, 3}
|
var errorAction = []byte{0, 0, 0, byte(errorActionID)}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
stats.DefaultStats = stats.New(config.StatsConfig{})
|
stats.DefaultStats = stats.New(config.StatsConfig{})
|
||||||
|
|
|
@ -22,7 +22,7 @@ type Writer struct {
|
||||||
|
|
||||||
// WriteError writes the failure reason as a null-terminated string.
|
// WriteError writes the failure reason as a null-terminated string.
|
||||||
func (w *Writer) WriteError(err error) error {
|
func (w *Writer) WriteError(err error) error {
|
||||||
w.writeHeader(3)
|
w.writeHeader(errorActionID)
|
||||||
w.buf.WriteString(err.Error())
|
w.buf.WriteString(err.Error())
|
||||||
w.buf.WriteRune('\000')
|
w.buf.WriteRune('\000')
|
||||||
return nil
|
return nil
|
||||||
|
@ -30,7 +30,7 @@ func (w *Writer) WriteError(err error) error {
|
||||||
|
|
||||||
// WriteAnnounce encodes an announce response according to the UDP spec.
|
// WriteAnnounce encodes an announce response according to the UDP spec.
|
||||||
func (w *Writer) WriteAnnounce(res *models.AnnounceResponse) error {
|
func (w *Writer) WriteAnnounce(res *models.AnnounceResponse) error {
|
||||||
w.writeHeader(1)
|
w.writeHeader(announceActionID)
|
||||||
binary.Write(w.buf, binary.BigEndian, uint32(res.Interval/time.Second))
|
binary.Write(w.buf, binary.BigEndian, uint32(res.Interval/time.Second))
|
||||||
binary.Write(w.buf, binary.BigEndian, uint32(res.Incomplete))
|
binary.Write(w.buf, binary.BigEndian, uint32(res.Incomplete))
|
||||||
binary.Write(w.buf, binary.BigEndian, uint32(res.Complete))
|
binary.Write(w.buf, binary.BigEndian, uint32(res.Complete))
|
||||||
|
@ -45,7 +45,7 @@ func (w *Writer) WriteAnnounce(res *models.AnnounceResponse) error {
|
||||||
|
|
||||||
// WriteScrape encodes a scrape response according to the UDP spec.
|
// WriteScrape encodes a scrape response according to the UDP spec.
|
||||||
func (w *Writer) WriteScrape(res *models.ScrapeResponse) error {
|
func (w *Writer) WriteScrape(res *models.ScrapeResponse) error {
|
||||||
w.writeHeader(2)
|
w.writeHeader(scrapeActionID)
|
||||||
|
|
||||||
for _, torrent := range res.Files {
|
for _, torrent := range res.Files {
|
||||||
binary.Write(w.buf, binary.BigEndian, uint32(torrent.Seeders.Len()))
|
binary.Write(w.buf, binary.BigEndian, uint32(torrent.Seeders.Len()))
|
||||||
|
|
Loading…
Add table
Reference in a new issue