diff --git a/common.go b/common.go index bb5b08d3..4822db07 100644 --- a/common.go +++ b/common.go @@ -17,8 +17,112 @@ import ( const maxVarIntPayload = 9 // readElement reads the next sequence of bytes from r using little endian -// depending on the concrete type of element pointed to. +// depending on the concrete type of element pointed to. It also accepts a +// scratch buffer that is used for the primitive values rather than creating +// a new buffer on every call. func readElement(r io.Reader, element interface{}) error { + var scratch [8]byte + + // Attempt to read the element based on the concrete type via fast + // type assertions first. + switch e := element.(type) { + case *int32: + b := scratch[0:4] + _, err := io.ReadFull(r, b) + if err != nil { + return err + } + *e = int32(binary.LittleEndian.Uint32(b)) + return nil + + case *uint32: + b := scratch[0:4] + _, err := io.ReadFull(r, b) + if err != nil { + return err + } + *e = binary.LittleEndian.Uint32(b) + return nil + + case *int64: + b := scratch[0:8] + _, err := io.ReadFull(r, b) + if err != nil { + return err + } + *e = int64(binary.LittleEndian.Uint64(b)) + return nil + + case *uint64: + b := scratch[0:8] + _, err := io.ReadFull(r, b) + if err != nil { + return err + } + *e = binary.LittleEndian.Uint64(b) + return nil + + // Message header checksum. + case *[4]byte: + _, err := io.ReadFull(r, e[:]) + if err != nil { + return err + } + return nil + + // Message header command. + case *[commandSize]uint8: + _, err := io.ReadFull(r, e[:]) + if err != nil { + return err + } + return nil + + // IP address. + case *[16]byte: + _, err := io.ReadFull(r, e[:]) + if err != nil { + return err + } + return nil + + case *ShaHash: + _, err := io.ReadFull(r, e[:]) + if err != nil { + return err + } + return nil + + case *ServiceFlag: + b := scratch[0:8] + _, err := io.ReadFull(r, b) + if err != nil { + return err + } + *e = ServiceFlag(binary.LittleEndian.Uint64(b)) + return nil + + case *InvType: + b := scratch[0:4] + _, err := io.ReadFull(r, b) + if err != nil { + return err + } + *e = InvType(binary.LittleEndian.Uint32(b)) + return nil + + case *BitcoinNet: + b := scratch[0:4] + _, err := io.ReadFull(r, b) + if err != nil { + return err + } + *e = BitcoinNet(binary.LittleEndian.Uint32(b)) + return nil + } + + // Fall back to the slower binary.Read if a fast path was not available + // above. return binary.Read(r, binary.LittleEndian, element) }