diff options
Diffstat (limited to 'rlp/decode.go')
-rw-r--r-- | rlp/decode.go | 193 |
1 files changed, 110 insertions, 83 deletions
diff --git a/rlp/decode.go b/rlp/decode.go index 7d95af02b..712d9fcf1 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -54,7 +54,7 @@ type Decoder interface { // To decode into a Go string, the input must be an RLP string. The // bytes are taken as-is and will not necessarily be valid UTF-8. // -// To decode into an integer type, the input must also be an RLP +// To decode into an unsigned integer type, the input must also be an RLP // string. The bytes are interpreted as a big endian representation of // the integer. If the RLP string is larger than the bit size of the // type, Decode will return an error. Decode also supports *big.Int. @@ -66,8 +66,9 @@ type Decoder interface { // []interface{}, for RLP lists // []byte, for RLP strings // -// Non-empty interface types are not supported, nor are bool, float32, -// float64, maps, channel types and functions. +// Non-empty interface types are not supported, nor are booleans, +// signed integers, floating point numbers, maps, channels and +// functions. func Decode(r io.Reader, val interface{}) error { return NewStream(r).Decode(val) } @@ -81,37 +82,58 @@ func (err decodeError) Error() string { return fmt.Sprintf("rlp: %s for %v", err.msg, err.typ) } -func makeNumDecoder(typ reflect.Type) decoder { - kind := typ.Kind() - switch { - case kind <= reflect.Int64: - return decodeInt - case kind <= reflect.Uint64: - return decodeUint - default: - panic("fallthrough") +func wrapStreamError(err error, typ reflect.Type) error { + switch err { + case ErrExpectedList: + return decodeError{"expected input list", typ} + case ErrExpectedString: + return decodeError{"expected input string or byte", typ} + case errUintOverflow: + return decodeError{"input string too long", typ} + case errNotAtEOL: + return decodeError{"input list has too many elements", typ} } + return err } -func decodeInt(s *Stream, val reflect.Value) error { - typ := val.Type() - num, err := s.uint(typ.Bits()) - if err == errUintOverflow { - return decodeError{"input string too long", typ} - } else if err != nil { - return err +var ( + decoderInterface = reflect.TypeOf(new(Decoder)).Elem() + bigInt = reflect.TypeOf(big.Int{}) +) + +func makeDecoder(typ reflect.Type) (dec decoder, err error) { + kind := typ.Kind() + switch { + case typ.Implements(decoderInterface): + return decodeDecoder, nil + case kind != reflect.Ptr && reflect.PtrTo(typ).Implements(decoderInterface): + return decodeDecoderNoPtr, nil + case typ.AssignableTo(reflect.PtrTo(bigInt)): + return decodeBigInt, nil + case typ.AssignableTo(bigInt): + return decodeBigIntNoPtr, nil + case isUint(kind): + return decodeUint, nil + case kind == reflect.String: + return decodeString, nil + case kind == reflect.Slice || kind == reflect.Array: + return makeListDecoder(typ) + case kind == reflect.Struct: + return makeStructDecoder(typ) + case kind == reflect.Ptr: + return makePtrDecoder(typ) + case kind == reflect.Interface && typ.NumMethod() == 0: + return decodeInterface, nil + default: + return nil, fmt.Errorf("rlp: type %v is not RLP-serializable", typ) } - val.SetInt(int64(num)) - return nil } func decodeUint(s *Stream, val reflect.Value) error { typ := val.Type() num, err := s.uint(typ.Bits()) - if err == errUintOverflow { - return decodeError{"input string too big", typ} - } else if err != nil { - return err + if err != nil { + return wrapStreamError(err, val.Type()) } val.SetUint(num) return nil @@ -120,7 +142,7 @@ func decodeUint(s *Stream, val reflect.Value) error { func decodeString(s *Stream, val reflect.Value) error { b, err := s.Bytes() if err != nil { - return err + return wrapStreamError(err, val.Type()) } val.SetString(string(b)) return nil @@ -133,7 +155,7 @@ func decodeBigIntNoPtr(s *Stream, val reflect.Value) error { func decodeBigInt(s *Stream, val reflect.Value) error { b, err := s.Bytes() if err != nil { - return err + return wrapStreamError(err, val.Type()) } i := val.Interface().(*big.Int) if i == nil { @@ -144,8 +166,6 @@ func decodeBigInt(s *Stream, val reflect.Value) error { return nil } -const maxInt = int(^uint(0) >> 1) - func makeListDecoder(typ reflect.Type) (decoder, error) { etype := typ.Elem() if etype.Kind() == reflect.Uint8 && !reflect.PtrTo(etype).Implements(decoderInterface) { @@ -159,55 +179,41 @@ func makeListDecoder(typ reflect.Type) (decoder, error) { if err != nil { return nil, err } - var maxLen = maxInt + if typ.Kind() == reflect.Array { - maxLen = typ.Len() - } - dec := func(s *Stream, val reflect.Value) error { - return decodeList(s, val, etypeinfo.decoder, maxLen) + return func(s *Stream, val reflect.Value) error { + return decodeListArray(s, val, etypeinfo.decoder) + }, nil } - return dec, nil + return func(s *Stream, val reflect.Value) error { + return decodeListSlice(s, val, etypeinfo.decoder) + }, nil } -// decodeList decodes RLP list elements into slices and arrays. -// -// The approach here is stolen from package json, although we differ -// in the semantics for arrays. package json discards remaining -// elements that would not fit into the array. We generate an error in -// this case because we'd be losing information. -func decodeList(s *Stream, val reflect.Value, elemdec decoder, maxelem int) error { +func decodeListSlice(s *Stream, val reflect.Value, elemdec decoder) error { size, err := s.List() if err != nil { - return err + return wrapStreamError(err, val.Type()) } if size == 0 { - if val.Kind() == reflect.Slice { - val.Set(reflect.MakeSlice(val.Type(), 0, 0)) - } else { - zero(val, 0) - } + val.Set(reflect.MakeSlice(val.Type(), 0, 0)) return s.ListEnd() } i := 0 - for { - if i > maxelem { - return decodeError{"input list has too many elements", val.Type()} - } - if val.Kind() == reflect.Slice { - // grow slice if necessary - if i >= val.Cap() { - newcap := val.Cap() + val.Cap()/2 - if newcap < 4 { - newcap = 4 - } - newv := reflect.MakeSlice(val.Type(), val.Len(), newcap) - reflect.Copy(newv, val) - val.Set(newv) - } - if i >= val.Len() { - val.SetLen(i + 1) + for ; ; i++ { + // grow slice if necessary + if i >= val.Cap() { + newcap := val.Cap() + val.Cap()/2 + if newcap < 4 { + newcap = 4 } + newv := reflect.MakeSlice(val.Type(), val.Len(), newcap) + reflect.Copy(newv, val) + val.Set(newv) + } + if i >= val.Len() { + val.SetLen(i + 1) } // decode into element if err := elemdec(s, val.Index(i)); err == EOL { @@ -215,26 +221,49 @@ func decodeList(s *Stream, val reflect.Value, elemdec decoder, maxelem int) erro } else if err != nil { return err } - i++ } if i < val.Len() { - if val.Kind() == reflect.Array { - // zero the rest of the array. - zero(val, i) - } else { - val.SetLen(i) - } + val.SetLen(i) } return s.ListEnd() } +func decodeListArray(s *Stream, val reflect.Value, elemdec decoder) error { + size, err := s.List() + if err != nil { + return err + } + if size == 0 { + zero(val, 0) + return s.ListEnd() + } + + // The approach here is stolen from package json, although we differ + // in the semantics for arrays. package json discards remaining + // elements that would not fit into the array. We generate an error in + // this case because we'd be losing information. + vlen := val.Len() + i := 0 + for ; i < vlen; i++ { + if err := elemdec(s, val.Index(i)); err == EOL { + break + } else if err != nil { + return err + } + } + if i < vlen { + zero(val, i) + } + return wrapStreamError(s.ListEnd(), val.Type()) +} + func decodeByteSlice(s *Stream, val reflect.Value) error { kind, _, err := s.Kind() if err != nil { return err } if kind == List { - return decodeList(s, val, decodeUint, maxInt) + return decodeListSlice(s, val, decodeUint) } b, err := s.Bytes() if err == nil { @@ -251,14 +280,14 @@ func decodeByteArray(s *Stream, val reflect.Value) error { switch kind { case Byte: if val.Len() == 0 { - return decodeError{"input string too big", val.Type()} + return decodeError{"input string too long", val.Type()} } bv, _ := s.Uint() val.Index(0).SetUint(bv) zero(val, 1) case String: if uint64(val.Len()) < size { - return decodeError{"input string too big", val.Type()} + return decodeError{"input string too long", val.Type()} } slice := val.Slice(0, int(size)).Interface().([]byte) if err := s.readFull(slice); err != nil { @@ -266,14 +295,15 @@ func decodeByteArray(s *Stream, val reflect.Value) error { } zero(val, int(size)) case List: - return decodeList(s, val, decodeUint, val.Len()) + return decodeListArray(s, val, decodeUint) } return nil } func zero(val reflect.Value, start int) { z := reflect.Zero(val.Type().Elem()) - for i := start; i < val.Len(); i++ { + end := val.Len() + for i := start; i < end; i++ { val.Index(i).Set(z) } } @@ -296,7 +326,7 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { } dec := func(s *Stream, val reflect.Value) (err error) { if _, err = s.List(); err != nil { - return err + return wrapStreamError(err, typ) } for _, f := range fields { err = f.info.decoder(s, val.Field(f.index)) @@ -307,10 +337,7 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { return err } } - if err = s.ListEnd(); err == errNotAtEOL { - err = decodeError{"input list has too many elements", typ} - } - return err + return wrapStreamError(s.ListEnd(), typ) } return dec, nil } @@ -348,7 +375,7 @@ func decodeInterface(s *Stream, val reflect.Value) error { } if kind == List { slice := reflect.New(ifsliceType).Elem() - if err := decodeList(s, slice, decodeInterface, maxInt); err != nil { + if err := decodeListSlice(s, slice, decodeInterface); err != nil { return err } val.Set(slice) |