From fa8a4a59bc7978be7fe6aec44644bbdd82499a03 Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Wed, 16 Aug 2017 11:59:03 -0400 Subject: [PATCH] parameterize node id length --- dht/dht.go | 2 +- dht/krpc.go | 14 +++++++------- dht/peerwire.go | 15 ++++++++------- dht/routingtable.go | 19 ++++++++++--------- 4 files changed, 26 insertions(+), 24 deletions(-) diff --git a/dht/dht.go b/dht/dht.go index f9b9606..c97a75f 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -103,7 +103,7 @@ func New(config *Config) *DHT { config = NewStandardConfig() } - node, err := newNode(randomString(20), config.Network, config.Address) + node, err := newNode(randomString(nodeIDLength), config.Network, config.Address) if err != nil { panic(err) } diff --git a/dht/krpc.go b/dht/krpc.go index a0ef08f..75fcb64 100644 --- a/dht/krpc.go +++ b/dht/krpc.go @@ -446,7 +446,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, response map[string]interface{}) return } - if len(id) != 20 { + if len(id) != nodeIDLength { send(dht, addr, makeError(t, protocolError, "invalid id")) return } @@ -473,7 +473,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, response map[string]interface{}) } target := a["target"].(string) - if len(target) != 20 { + if len(target) != nodeIDLength { send(dht, addr, makeError(t, protocolError, "invalid target")) return } @@ -503,7 +503,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, response map[string]interface{}) infoHash := a["info_hash"].(string) - if len(infoHash) != 20 { + if len(infoHash) != nodeIDLength { send(dht, addr, makeError(t, protocolError, "invalid info_hash")) return } @@ -587,14 +587,14 @@ func findOn(dht *DHT, r map[string]interface{}, target *bitmap, queryType string } nodes := r["nodes"].(string) - if len(nodes)%26 != 0 { - return errors.New("the length of nodes should can be divided by 26") + if len(nodes)%compactNodeInfoLength != 0 { + return fmt.Errorf("the length of nodes should can be divided by %d", compactNodeInfoLength) } hasNew, found := false, false - for i := 0; i < len(nodes)/26; i++ { + for i := 0; i < len(nodes)/compactNodeInfoLength; i++ { no, _ := newNodeFromCompactInfo( - string(nodes[i*26:(i+1)*26]), dht.Network) + string(nodes[i*compactNodeInfoLength:(i+1)*compactNodeInfoLength]), dht.Network) if no.id.RawString() == target.RawString() { found = true diff --git a/dht/peerwire.go b/dht/peerwire.go index eb2d448..50522af 100644 --- a/dht/peerwire.go +++ b/dht/peerwire.go @@ -36,6 +36,7 @@ var handshakePrefix = []byte{ 19, 66, 105, 116, 84, 111, 114, 114, 101, 110, 116, 32, 112, 114, 111, 116, 111, 99, 111, 108, 0, 0, 0, 0, 0, 16, 0, 1, } +var handshakePrefixLength = len(handshakePrefix) // read reads size-length bytes from conn to data. func read(conn *net.TCPConn, size int, data *bytes.Buffer) error { @@ -81,10 +82,10 @@ func sendMessage(conn *net.TCPConn, data []byte) error { // sendHandshake sends handshake message to conn. func sendHandshake(conn *net.TCPConn, infoHash, peerID []byte) error { - data := make([]byte, 68) - copy(data[:28], handshakePrefix) - copy(data[28:48], infoHash) - copy(data[48:], peerID) + data := make([]byte, handshakePrefixLength+nodeIDLength+len(peerID)) + copy(data[:handshakePrefixLength], handshakePrefix) + copy(data[handshakePrefixLength:handshakePrefixLength+nodeIDLength], infoHash) + copy(data[handshakePrefixLength+nodeIDLength:], peerID) conn.SetWriteDeadline(time.Now().Add(time.Second * 10)) _, err := conn.Write(data) @@ -93,7 +94,7 @@ func sendHandshake(conn *net.TCPConn, infoHash, peerID []byte) error { // onHandshake handles the handshake response. func onHandshake(data []byte) (err error) { - if !(bytes.Equal(handshakePrefix[:20], data[:20]) && data[25]&0x10 != 0) { + if !(bytes.Equal(handshakePrefix[:handshakePrefixLength], data[:handshakePrefixLength]) && data[25]&0x10 != 0) { err = errors.New("invalid handshake response") } return @@ -253,7 +254,7 @@ func (wire *Wire) fetchMetadata(r Request) { data := bytes.NewBuffer(nil) data.Grow(BLOCK) - if sendHandshake(conn, infoHash, []byte(randomString(20))) != nil || + if sendHandshake(conn, infoHash, []byte(randomString(nodeIDLength))) != nil || read(conn, 68, data) != nil || onHandshake(data.Next(68)) != nil || sendExtHandshake(conn) != nil { @@ -374,7 +375,7 @@ func (wire *Wire) Run() { string(r.InfoHash), genAddress(r.IP, r.Port), }, ":") - if len(r.InfoHash) != 20 || wire.blackList.in(r.IP, r.Port) || + if len(r.InfoHash) != nodeIDLength || wire.blackList.in(r.IP, r.Port) || wire.queue.Has(key) { return } diff --git a/dht/routingtable.go b/dht/routingtable.go index 893772c..f744078 100644 --- a/dht/routingtable.go +++ b/dht/routingtable.go @@ -11,6 +11,8 @@ import ( // maxPrefixLength is the length of DHT node. const maxPrefixLength = 160 +const nodeIDLength = 20 +const compactNodeInfoLength = nodeIDLength + 6 // node represents a DHT node. type node struct { @@ -21,8 +23,8 @@ type node struct { // newNode returns a node pointer. func newNode(id, network, address string) (*node, error) { - if len(id) != 20 { - return nil, errors.New("node id should be a 20-length string") + if len(id) != nodeIDLength { + return nil, fmt.Errorf("node id should be a %d-length string", nodeIDLength) } addr, err := net.ResolveUDPAddr(network, address) @@ -34,15 +36,14 @@ func newNode(id, network, address string) (*node, error) { } // newNodeFromCompactInfo parses compactNodeInfo and returns a node pointer. -func newNodeFromCompactInfo( - compactNodeInfo string, network string) (*node, error) { +func newNodeFromCompactInfo(compactNodeInfo string, network string) (*node, error) { - if len(compactNodeInfo) != 26 { - return nil, errors.New("compactNodeInfo should be a 26-length string") + if len(compactNodeInfo) != compactNodeInfoLength { + return nil, fmt.Errorf("compactNodeInfo should be a %d-length string", compactNodeInfoLength) } - id := compactNodeInfo[:20] - ip, port, _ := decodeCompactIPPortInfo(compactNodeInfo[20:]) + id := compactNodeInfo[:nodeIDLength] + ip, port, _ := decodeCompactIPPortInfo(compactNodeInfo[nodeIDLength:]) return newNode(id, network, genAddress(ip.String(), port)) } @@ -179,7 +180,7 @@ func (bucket *kbucket) RandomChildID() string { return strings.Join([]string{ bucket.prefix.RawString()[:prefixLen], - randomString(20 - prefixLen), + randomString(nodeIDLength - prefixLen), }, "") }