diff options
Diffstat (limited to 'rlp/decode.go')
-rw-r--r-- | rlp/decode.go | 43 |
1 files changed, 29 insertions, 14 deletions
diff --git a/rlp/decode.go b/rlp/decode.go index 6952ecaea..0c660426f 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -820,6 +820,16 @@ func (s *Stream) Kind() (kind Kind, size uint64, err error) { func (s *Stream) readKind() (kind Kind, size uint64, err error) { b, err := s.readByte() if err != nil { + if len(s.stack) == 0 { + // At toplevel, Adjust the error to actual EOF. io.EOF is + // used by callers to determine when to stop decoding. + switch err { + case io.ErrUnexpectedEOF: + err = io.EOF + case ErrValueTooLarge: + err = io.EOF + } + } return 0, 0, err } s.byteval = 0 @@ -876,9 +886,6 @@ func (s *Stream) readUint(size byte) (uint64, error) { return 0, nil case 1: b, err := s.readByte() - if err == io.EOF { - err = io.ErrUnexpectedEOF - } return uint64(b), err default: start := int(8 - size) @@ -899,10 +906,9 @@ func (s *Stream) readUint(size byte) (uint64, error) { } func (s *Stream) readFull(buf []byte) (err error) { - if s.limited && s.remaining < uint64(len(buf)) { - return ErrValueTooLarge + if err := s.willRead(uint64(len(buf))); err != nil { + return err } - s.willRead(uint64(len(buf))) var nn, n int for n < len(buf) && err == nil { nn, err = s.r.Read(buf[n:]) @@ -915,23 +921,32 @@ func (s *Stream) readFull(buf []byte) (err error) { } func (s *Stream) readByte() (byte, error) { - if s.limited && s.remaining == 0 { - return 0, io.EOF + if err := s.willRead(1); err != nil { + return 0, err } - s.willRead(1) b, err := s.r.ReadByte() - if len(s.stack) > 0 && err == io.EOF { + if err == io.EOF { err = io.ErrUnexpectedEOF } return b, err } -func (s *Stream) willRead(n uint64) { +func (s *Stream) willRead(n uint64) error { s.kind = -1 // rearm Kind - if s.limited { - s.remaining -= n - } + if len(s.stack) > 0 { + // check list overflow + tos := s.stack[len(s.stack)-1] + if n > tos.size-tos.pos { + return ErrElemTooLarge + } s.stack[len(s.stack)-1].pos += n } + if s.limited { + if n > s.remaining { + return ErrValueTooLarge + } + s.remaining -= n + } + return nil } |