Merge pull request #217 from jzelinskie/bencode

bencode: add missing error handling
This commit is contained in:
mrd0ll4r 2016-09-07 20:51:14 -04:00 committed by GitHub
commit 52d7039a3e
3 changed files with 198 additions and 136 deletions

View file

@ -2,6 +2,11 @@
// type assertion over reflection for performance.
package bencode
import "bytes"
// Enforce that Dict implements the Marshaler interface.
var _ Marshaler = Dict{}
// Dict represents a bencode dictionary.
type Dict map[string]interface{}
@ -10,9 +15,26 @@ func NewDict() Dict {
return make(Dict)
}
// MarshalBencode implements the Marshaler interface for Dict.
func (d Dict) MarshalBencode() ([]byte, error) {
var buf bytes.Buffer
err := marshalMap(&buf, map[string]interface{}(d))
return buf.Bytes(), err
}
// Enforce that List implements the Marshaler interface.
var _ Marshaler = List{}
// List represents a bencode list.
type List []interface{}
// MarshalBencode implements the Marshaler interface for List.
func (l List) MarshalBencode() ([]byte, error) {
var buf bytes.Buffer
err := marshalList(&buf, []interface{}(l))
return buf.Bytes(), err
}
// NewList allocates the memory for a List.
func NewList() List {
return make(List, 0)

View file

@ -41,49 +41,10 @@ func unmarshal(r *bufio.Reader) (interface{}, error) {
return readTerminatedInt(r, 'e')
case 'l':
list := NewList()
for {
ok, err := readTerminator(r, 'e')
if err != nil {
return nil, err
} else if ok {
break
}
v, err := unmarshal(r)
if err != nil {
return nil, err
}
list = append(list, v)
}
return list, nil
return readList(r)
case 'd':
dict := NewDict()
for {
ok, err := readTerminator(r, 'e')
if err != nil {
return nil, err
} else if ok {
break
}
v, err := unmarshal(r)
if err != nil {
return nil, err
}
key, ok := v.(string)
if !ok {
return nil, errors.New("bencode: non-string map key")
}
dict[key], err = unmarshal(r)
if err != nil {
return nil, err
}
}
return dict, nil
return readDict(r)
default:
err = r.UnreadByte()
@ -129,3 +90,52 @@ func readTerminatedInt(r *bufio.Reader, term byte) (int64, error) {
return strconv.ParseInt(string(buf[:len(buf)-1]), 10, 64)
}
func readList(r *bufio.Reader) (List, error) {
list := NewList()
for {
ok, err := readTerminator(r, 'e')
if err != nil {
return nil, err
} else if ok {
break
}
v, err := unmarshal(r)
if err != nil {
return nil, err
}
list = append(list, v)
}
return list, nil
}
func readDict(r *bufio.Reader) (Dict, error) {
dict := NewDict()
for {
ok, err := readTerminator(r, 'e')
if err != nil {
return nil, err
} else if ok {
break
}
v, err := unmarshal(r)
if err != nil {
return nil, err
}
key, ok := v.(string)
if !ok {
return nil, errors.New("bencode: non-string map key")
}
dict[key], err = unmarshal(r)
if err != nil {
return nil, err
}
}
return dict, nil
}

View file

@ -25,8 +25,8 @@ func (enc *Encoder) Encode(v interface{}) error {
// Marshal returns the bencoding of v.
func Marshal(v interface{}) ([]byte, error) {
buf := &bytes.Buffer{}
err := marshal(buf, v)
var buf bytes.Buffer
err := marshal(&buf, v)
return buf.Bytes(), err
}
@ -36,124 +36,154 @@ type Marshaler interface {
MarshalBencode() ([]byte, error)
}
// marshal writes types bencoded to an io.Writer
func marshal(w io.Writer, data interface{}) error {
// marshal writes types bencoded to an io.Writer.
func marshal(w io.Writer, data interface{}) (err error) {
switch v := data.(type) {
case Marshaler:
bencoded, err := v.MarshalBencode()
var bencoded []byte
bencoded, err = v.MarshalBencode()
if err != nil {
return err
}
_, err = w.Write(bencoded)
if err != nil {
return err
}
case string:
marshalString(w, v)
case int:
marshalInt(w, int64(v))
case uint:
marshalUint(w, uint64(v))
case int16:
marshalInt(w, int64(v))
case uint16:
marshalUint(w, uint64(v))
case int32:
marshalInt(w, int64(v))
case uint32:
marshalUint(w, uint64(v))
case int64:
marshalInt(w, v)
case uint64:
marshalUint(w, v)
case []byte:
marshalBytes(w, v)
err = marshalBytes(w, v)
case time.Duration: // Assume seconds
marshalInt(w, int64(v/time.Second))
case Dict:
marshal(w, map[string]interface{}(v))
case []Dict:
w.Write([]byte{'l'})
for _, val := range v {
err := marshal(w, val)
if err != nil {
return err
}
}
w.Write([]byte{'e'})
case map[string]interface{}:
w.Write([]byte{'d'})
for key, val := range v {
marshalString(w, key)
err := marshal(w, val)
if err != nil {
return err
}
}
w.Write([]byte{'e'})
case string:
err = marshalString(w, v)
case []string:
w.Write([]byte{'l'})
for _, val := range v {
err := marshal(w, val)
if err != nil {
return err
}
}
w.Write([]byte{'e'})
err = marshalStringSlice(w, v)
case List:
marshal(w, []interface{}(v))
case int:
err = marshalInt(w, int64(v))
case int16:
err = marshalInt(w, int64(v))
case int32:
err = marshalInt(w, int64(v))
case int64:
err = marshalInt(w, int64(v))
case uint:
err = marshalUint(w, uint64(v))
case uint16:
err = marshalUint(w, uint64(v))
case uint32:
err = marshalUint(w, uint64(v))
case uint64:
err = marshalUint(w, uint64(v))
case time.Duration: // Assume seconds
err = marshalInt(w, int64(v/time.Second))
case map[string]interface{}:
err = marshalMap(w, v)
case []interface{}:
w.Write([]byte{'l'})
for _, val := range v {
err := marshal(w, val)
if err != nil {
return err
}
}
w.Write([]byte{'e'})
err = marshalList(w, v)
default:
return fmt.Errorf("attempted to marshal unsupported type:\n%t", v)
}
return nil
return err
}
func marshalInt(w io.Writer, v int64) {
w.Write([]byte{'i'})
w.Write([]byte(strconv.FormatInt(v, 10)))
w.Write([]byte{'e'})
func marshalInt(w io.Writer, v int64) error {
if _, err := w.Write([]byte{'i'}); err != nil {
return err
}
if _, err := w.Write([]byte(strconv.FormatInt(v, 10))); err != nil {
return err
}
_, err := w.Write([]byte{'e'})
return err
}
func marshalUint(w io.Writer, v uint64) {
w.Write([]byte{'i'})
w.Write([]byte(strconv.FormatUint(v, 10)))
w.Write([]byte{'e'})
func marshalUint(w io.Writer, v uint64) error {
if _, err := w.Write([]byte{'i'}); err != nil {
return err
}
if _, err := w.Write([]byte(strconv.FormatUint(v, 10))); err != nil {
return err
}
_, err := w.Write([]byte{'e'})
return err
}
func marshalBytes(w io.Writer, v []byte) {
w.Write([]byte(strconv.Itoa(len(v))))
w.Write([]byte{':'})
w.Write(v)
func marshalBytes(w io.Writer, v []byte) error {
if _, err := w.Write([]byte(strconv.Itoa(len(v)))); err != nil {
return err
}
if _, err := w.Write([]byte{':'}); err != nil {
return err
}
_, err := w.Write(v)
return err
}
func marshalString(w io.Writer, v string) {
marshalBytes(w, []byte(v))
func marshalString(w io.Writer, v string) error {
return marshalBytes(w, []byte(v))
}
func marshalStringSlice(w io.Writer, v []string) error {
if _, err := w.Write([]byte{'l'}); err != nil {
return err
}
for _, val := range v {
if err := marshal(w, val); err != nil {
return err
}
}
_, err := w.Write([]byte{'e'})
return err
}
func marshalList(w io.Writer, v []interface{}) error {
if _, err := w.Write([]byte{'l'}); err != nil {
return err
}
for _, val := range v {
if err := marshal(w, val); err != nil {
return err
}
}
_, err := w.Write([]byte{'e'})
return err
}
func marshalMap(w io.Writer, v map[string]interface{}) error {
if _, err := w.Write([]byte{'d'}); err != nil {
return err
}
for key, val := range v {
if err := marshalString(w, key); err != nil {
return err
}
if err := marshal(w, val); err != nil {
return err
}
}
_, err := w.Write([]byte{'e'})
return err
}