diff options
Diffstat (limited to 'rlp/decode.go')
-rw-r--r-- | rlp/decode.go | 81 |
1 files changed, 65 insertions, 16 deletions
diff --git a/rlp/decode.go b/rlp/decode.go index 96d912f56..7d95af02b 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -1,6 +1,7 @@ package rlp import ( + "bufio" "encoding/binary" "errors" "fmt" @@ -24,8 +25,9 @@ type Decoder interface { DecodeRLP(*Stream) error } -// Decode parses RLP-encoded data from r and stores the result -// in the value pointed to by val. Val must be a non-nil pointer. +// Decode parses RLP-encoded data from r and stores the result in the +// value pointed to by val. Val must be a non-nil pointer. If r does +// not implement ByteReader, Decode will do its own buffering. // // Decode uses the following type-dependent decoding rules: // @@ -66,10 +68,19 @@ type Decoder interface { // // Non-empty interface types are not supported, nor are bool, float32, // float64, maps, channel types and functions. -func Decode(r ByteReader, val interface{}) error { +func Decode(r io.Reader, val interface{}) error { return NewStream(r).Decode(val) } +type decodeError struct { + msg string + typ reflect.Type +} + +func (err decodeError) Error() string { + return fmt.Sprintf("rlp: %s for %v", err.msg, err.typ) +} + func makeNumDecoder(typ reflect.Type) decoder { kind := typ.Kind() switch { @@ -83,8 +94,11 @@ func makeNumDecoder(typ reflect.Type) decoder { } func decodeInt(s *Stream, val reflect.Value) error { - num, err := s.uint(val.Type().Bits()) - if err != nil { + typ := val.Type() + num, err := s.uint(typ.Bits()) + if err == errUintOverflow { + return decodeError{"input string too long", typ} + } else if err != nil { return err } val.SetInt(int64(num)) @@ -92,8 +106,11 @@ func decodeInt(s *Stream, val reflect.Value) error { } func decodeUint(s *Stream, val reflect.Value) error { - num, err := s.uint(val.Type().Bits()) - if err != nil { + typ := val.Type() + num, err := s.uint(typ.Bits()) + if err == errUintOverflow { + return decodeError{"input string too big", typ} + } else if err != nil { return err } val.SetUint(num) @@ -175,7 +192,7 @@ func decodeList(s *Stream, val reflect.Value, elemdec decoder, maxelem int) erro i := 0 for { if i > maxelem { - return fmt.Errorf("rlp: input List has more than %d elements", maxelem) + return decodeError{"input list has too many elements", val.Type()} } if val.Kind() == reflect.Slice { // grow slice if necessary @@ -226,8 +243,6 @@ func decodeByteSlice(s *Stream, val reflect.Value) error { return err } -var errStringDoesntFitArray = errors.New("rlp: string value doesn't fit into target array") - func decodeByteArray(s *Stream, val reflect.Value) error { kind, size, err := s.Kind() if err != nil { @@ -236,14 +251,14 @@ func decodeByteArray(s *Stream, val reflect.Value) error { switch kind { case Byte: if val.Len() == 0 { - return errStringDoesntFitArray + return decodeError{"input string too big", val.Type()} } bv, _ := s.Uint() val.Index(0).SetUint(bv) zero(val, 1) case String: if uint64(val.Len()) < size { - return errStringDoesntFitArray + return decodeError{"input string too big", val.Type()} } slice := val.Slice(0, int(size)).Interface().([]byte) if err := s.readFull(slice); err != nil { @@ -293,7 +308,7 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { } } if err = s.ListEnd(); err == errNotAtEOL { - err = errors.New("rlp: input List has too many elements") + err = decodeError{"input list has too many elements", typ} } return err } @@ -432,8 +447,23 @@ type Stream struct { type listpos struct{ pos, size uint64 } -func NewStream(r ByteReader) *Stream { - return &Stream{r: r, uintbuf: make([]byte, 8), kind: -1} +// NewStream creates a new stream reading from r. +// If r does not implement ByteReader, the Stream will +// introduce its own buffering. +func NewStream(r io.Reader) *Stream { + s := new(Stream) + s.Reset(r) + return s +} + +// NewListStream creates a new stream that pretends to be positioned +// at an encoded list of the given length. +func NewListStream(r io.Reader, len uint64) *Stream { + s := new(Stream) + s.Reset(r) + s.kind = List + s.size = len + return s } // Bytes reads an RLP string and returns its contents as a byte slice. @@ -459,6 +489,8 @@ func (s *Stream) Bytes() ([]byte, error) { } } +var errUintOverflow = errors.New("rlp: uint overflow") + // Uint reads an RLP string of up to 8 bytes and returns its contents // as an unsigned integer. If the input does not contain an RLP string, the // returned error will be ErrExpectedString. @@ -477,7 +509,7 @@ func (s *Stream) uint(maxbits int) (uint64, error) { return uint64(s.byteval), nil case String: if size > uint64(maxbits/8) { - return 0, fmt.Errorf("rlp: string is larger than %d bits", maxbits) + return 0, errUintOverflow } return s.readUint(byte(size)) default: @@ -543,6 +575,23 @@ func (s *Stream) Decode(val interface{}) error { return info.decoder(s, rval.Elem()) } +// Reset discards any information about the current decoding context +// and starts reading from r. If r does not also implement ByteReader, +// Stream will do its own buffering. +func (s *Stream) Reset(r io.Reader) { + bufr, ok := r.(ByteReader) + if !ok { + bufr = bufio.NewReader(r) + } + s.r = bufr + s.stack = s.stack[:0] + s.size = 0 + s.kind = -1 + if s.uintbuf == nil { + s.uintbuf = make([]byte, 8) + } +} + // Kind returns the kind and size of the next value in the // input stream. // |