From 2415bc71c6458a5eef731d379d885fd5a8f78fac Mon Sep 17 00:00:00 2001 From: Jimmy Zelinskie Date: Tue, 6 Sep 2016 23:40:41 -0400 Subject: [PATCH] bencode: add missing error handling In addition, this PR attempts to simplify some functions according to the output of `gocyclo -n 15`. --- frontend/http/bencode/bencode.go | 22 ++++ frontend/http/bencode/decoder.go | 92 +++++++------ frontend/http/bencode/encoder.go | 220 ++++++++++++++++++------------- 3 files changed, 198 insertions(+), 136 deletions(-) diff --git a/frontend/http/bencode/bencode.go b/frontend/http/bencode/bencode.go index c08bda1..e9bd05e 100644 --- a/frontend/http/bencode/bencode.go +++ b/frontend/http/bencode/bencode.go @@ -2,6 +2,11 @@ // type assertion over reflection for performance. package bencode +import "bytes" + +// Enforce that Dict implements the Marshaler interface. +var _ Marshaler = Dict{} + // Dict represents a bencode dictionary. type Dict map[string]interface{} @@ -10,9 +15,26 @@ func NewDict() Dict { return make(Dict) } +// MarshalBencode implements the Marshaler interface for Dict. +func (d Dict) MarshalBencode() ([]byte, error) { + var buf bytes.Buffer + err := marshalMap(&buf, map[string]interface{}(d)) + return buf.Bytes(), err +} + +// Enforce that List implements the Marshaler interface. +var _ Marshaler = List{} + // List represents a bencode list. type List []interface{} +// MarshalBencode implements the Marshaler interface for List. +func (l List) MarshalBencode() ([]byte, error) { + var buf bytes.Buffer + err := marshalList(&buf, []interface{}(l)) + return buf.Bytes(), err +} + // NewList allocates the memory for a List. func NewList() List { return make(List, 0) diff --git a/frontend/http/bencode/decoder.go b/frontend/http/bencode/decoder.go index a12f9b6..631641a 100644 --- a/frontend/http/bencode/decoder.go +++ b/frontend/http/bencode/decoder.go @@ -41,49 +41,10 @@ func unmarshal(r *bufio.Reader) (interface{}, error) { return readTerminatedInt(r, 'e') case 'l': - list := NewList() - for { - ok, err := readTerminator(r, 'e') - if err != nil { - return nil, err - } else if ok { - break - } - - v, err := unmarshal(r) - if err != nil { - return nil, err - } - list = append(list, v) - } - return list, nil + return readList(r) case 'd': - dict := NewDict() - for { - ok, err := readTerminator(r, 'e') - if err != nil { - return nil, err - } else if ok { - break - } - - v, err := unmarshal(r) - if err != nil { - return nil, err - } - - key, ok := v.(string) - if !ok { - return nil, errors.New("bencode: non-string map key") - } - - dict[key], err = unmarshal(r) - if err != nil { - return nil, err - } - } - return dict, nil + return readDict(r) default: err = r.UnreadByte() @@ -129,3 +90,52 @@ func readTerminatedInt(r *bufio.Reader, term byte) (int64, error) { return strconv.ParseInt(string(buf[:len(buf)-1]), 10, 64) } + +func readList(r *bufio.Reader) (List, error) { + list := NewList() + for { + ok, err := readTerminator(r, 'e') + if err != nil { + return nil, err + } else if ok { + break + } + + v, err := unmarshal(r) + if err != nil { + return nil, err + } + list = append(list, v) + } + + return list, nil +} + +func readDict(r *bufio.Reader) (Dict, error) { + dict := NewDict() + for { + ok, err := readTerminator(r, 'e') + if err != nil { + return nil, err + } else if ok { + break + } + + v, err := unmarshal(r) + if err != nil { + return nil, err + } + + key, ok := v.(string) + if !ok { + return nil, errors.New("bencode: non-string map key") + } + + dict[key], err = unmarshal(r) + if err != nil { + return nil, err + } + } + + return dict, nil +} diff --git a/frontend/http/bencode/encoder.go b/frontend/http/bencode/encoder.go index f6f1095..41093ae 100644 --- a/frontend/http/bencode/encoder.go +++ b/frontend/http/bencode/encoder.go @@ -25,8 +25,8 @@ func (enc *Encoder) Encode(v interface{}) error { // Marshal returns the bencoding of v. func Marshal(v interface{}) ([]byte, error) { - buf := &bytes.Buffer{} - err := marshal(buf, v) + var buf bytes.Buffer + err := marshal(&buf, v) return buf.Bytes(), err } @@ -36,124 +36,154 @@ type Marshaler interface { MarshalBencode() ([]byte, error) } -// marshal writes types bencoded to an io.Writer -func marshal(w io.Writer, data interface{}) error { +// marshal writes types bencoded to an io.Writer. +func marshal(w io.Writer, data interface{}) (err error) { switch v := data.(type) { case Marshaler: - bencoded, err := v.MarshalBencode() + var bencoded []byte + bencoded, err = v.MarshalBencode() if err != nil { return err } _, err = w.Write(bencoded) - if err != nil { - return err - } - - case string: - marshalString(w, v) - - case int: - marshalInt(w, int64(v)) - - case uint: - marshalUint(w, uint64(v)) - - case int16: - marshalInt(w, int64(v)) - - case uint16: - marshalUint(w, uint64(v)) - - case int32: - marshalInt(w, int64(v)) - - case uint32: - marshalUint(w, uint64(v)) - - case int64: - marshalInt(w, v) - - case uint64: - marshalUint(w, v) case []byte: - marshalBytes(w, v) + err = marshalBytes(w, v) - case time.Duration: // Assume seconds - marshalInt(w, int64(v/time.Second)) - - case Dict: - marshal(w, map[string]interface{}(v)) - - case []Dict: - w.Write([]byte{'l'}) - for _, val := range v { - err := marshal(w, val) - if err != nil { - return err - } - } - w.Write([]byte{'e'}) - - case map[string]interface{}: - w.Write([]byte{'d'}) - for key, val := range v { - marshalString(w, key) - err := marshal(w, val) - if err != nil { - return err - } - } - w.Write([]byte{'e'}) + case string: + err = marshalString(w, v) case []string: - w.Write([]byte{'l'}) - for _, val := range v { - err := marshal(w, val) - if err != nil { - return err - } - } - w.Write([]byte{'e'}) + err = marshalStringSlice(w, v) - case List: - marshal(w, []interface{}(v)) + case int: + err = marshalInt(w, int64(v)) + + case int16: + err = marshalInt(w, int64(v)) + + case int32: + err = marshalInt(w, int64(v)) + + case int64: + err = marshalInt(w, int64(v)) + + case uint: + err = marshalUint(w, uint64(v)) + + case uint16: + err = marshalUint(w, uint64(v)) + + case uint32: + err = marshalUint(w, uint64(v)) + + case uint64: + err = marshalUint(w, uint64(v)) + + case time.Duration: // Assume seconds + err = marshalInt(w, int64(v/time.Second)) + + case map[string]interface{}: + err = marshalMap(w, v) case []interface{}: - w.Write([]byte{'l'}) - for _, val := range v { - err := marshal(w, val) - if err != nil { - return err - } - } - w.Write([]byte{'e'}) + err = marshalList(w, v) default: return fmt.Errorf("attempted to marshal unsupported type:\n%t", v) } - return nil + return err } -func marshalInt(w io.Writer, v int64) { - w.Write([]byte{'i'}) - w.Write([]byte(strconv.FormatInt(v, 10))) - w.Write([]byte{'e'}) +func marshalInt(w io.Writer, v int64) error { + if _, err := w.Write([]byte{'i'}); err != nil { + return err + } + + if _, err := w.Write([]byte(strconv.FormatInt(v, 10))); err != nil { + return err + } + + _, err := w.Write([]byte{'e'}) + return err } -func marshalUint(w io.Writer, v uint64) { - w.Write([]byte{'i'}) - w.Write([]byte(strconv.FormatUint(v, 10))) - w.Write([]byte{'e'}) +func marshalUint(w io.Writer, v uint64) error { + if _, err := w.Write([]byte{'i'}); err != nil { + return err + } + + if _, err := w.Write([]byte(strconv.FormatUint(v, 10))); err != nil { + return err + } + + _, err := w.Write([]byte{'e'}) + return err } -func marshalBytes(w io.Writer, v []byte) { - w.Write([]byte(strconv.Itoa(len(v)))) - w.Write([]byte{':'}) - w.Write(v) +func marshalBytes(w io.Writer, v []byte) error { + if _, err := w.Write([]byte(strconv.Itoa(len(v)))); err != nil { + return err + } + + if _, err := w.Write([]byte{':'}); err != nil { + return err + } + + _, err := w.Write(v) + return err } -func marshalString(w io.Writer, v string) { - marshalBytes(w, []byte(v)) +func marshalString(w io.Writer, v string) error { + return marshalBytes(w, []byte(v)) +} + +func marshalStringSlice(w io.Writer, v []string) error { + if _, err := w.Write([]byte{'l'}); err != nil { + return err + } + + for _, val := range v { + if err := marshal(w, val); err != nil { + return err + } + } + + _, err := w.Write([]byte{'e'}) + return err +} + +func marshalList(w io.Writer, v []interface{}) error { + if _, err := w.Write([]byte{'l'}); err != nil { + return err + } + + for _, val := range v { + if err := marshal(w, val); err != nil { + return err + } + } + + _, err := w.Write([]byte{'e'}) + return err +} + +func marshalMap(w io.Writer, v map[string]interface{}) error { + if _, err := w.Write([]byte{'d'}); err != nil { + return err + } + + for key, val := range v { + if err := marshalString(w, key); err != nil { + return err + } + + if err := marshal(w, val); err != nil { + return err + } + } + + _, err := w.Write([]byte{'e'}) + return err }