diff options
author | obscuren <geffobscura@gmail.com> | 2015-04-19 23:07:59 +0800 |
---|---|---|
committer | obscuren <geffobscura@gmail.com> | 2015-04-19 23:07:59 +0800 |
commit | 8eff550e8b9bf121c27a4c2469ec9878d803a60e (patch) | |
tree | ddc62daeec7a44c2baacb986f8c9d76df25a5976 | |
parent | 4683f9c0a71fd42e749da46ac56c6ba76f379931 (diff) | |
parent | 8f3a7e41deff4084b166aca1337258077bd2a3e6 (diff) | |
download | dexon-8eff550e8b9bf121c27a4c2469ec9878d803a60e.tar dexon-8eff550e8b9bf121c27a4c2469ec9878d803a60e.tar.gz dexon-8eff550e8b9bf121c27a4c2469ec9878d803a60e.tar.bz2 dexon-8eff550e8b9bf121c27a4c2469ec9878d803a60e.tar.lz dexon-8eff550e8b9bf121c27a4c2469ec9878d803a60e.tar.xz dexon-8eff550e8b9bf121c27a4c2469ec9878d803a60e.tar.zst dexon-8eff550e8b9bf121c27a4c2469ec9878d803a60e.zip |
Merge branch 'fjl-rlp-size-validation' into develop
-rw-r--r-- | cmd/rlpdump/main.go | 2 | ||||
-rw-r--r-- | cmd/utils/cmd.go | 2 | ||||
-rw-r--r-- | core/types/transaction.go | 2 | ||||
-rw-r--r-- | eth/handler.go | 12 | ||||
-rw-r--r-- | p2p/discover/udp.go | 2 | ||||
-rw-r--r-- | p2p/message.go | 3 | ||||
-rw-r--r-- | p2p/peer_error.go | 2 | ||||
-rw-r--r-- | rlp/decode.go | 332 | ||||
-rw-r--r-- | rlp/decode_test.go | 287 | ||||
-rw-r--r-- | rlp/encode.go | 8 | ||||
-rw-r--r-- | rlp/typecache.go | 51 | ||||
-rw-r--r-- | whisper/envelope.go | 9 | ||||
-rw-r--r-- | whisper/peer.go | 2 |
13 files changed, 481 insertions, 233 deletions
diff --git a/cmd/rlpdump/main.go b/cmd/rlpdump/main.go index 8567dcff8..528ccc6bd 100644 --- a/cmd/rlpdump/main.go +++ b/cmd/rlpdump/main.go @@ -78,7 +78,7 @@ func main() { os.Exit(2) } - s := rlp.NewStream(r) + s := rlp.NewStream(r, 0) for { if err := dump(s, 0); err != nil { if err != io.EOF { diff --git a/cmd/utils/cmd.go b/cmd/utils/cmd.go index 7286f5c5e..64faf6ad1 100644 --- a/cmd/utils/cmd.go +++ b/cmd/utils/cmd.go @@ -154,7 +154,7 @@ func ImportChain(chainmgr *core.ChainManager, fn string) error { defer fh.Close() chainmgr.Reset() - stream := rlp.NewStream(fh) + stream := rlp.NewStream(fh, 0) var i, n int batchSize := 2500 diff --git a/core/types/transaction.go b/core/types/transaction.go index 6646bdf29..d8dcd7424 100644 --- a/core/types/transaction.go +++ b/core/types/transaction.go @@ -22,7 +22,7 @@ type Transaction struct { AccountNonce uint64 Price *big.Int GasLimit *big.Int - Recipient *common.Address // nil means contract creation + Recipient *common.Address `rlp:"nil"` // nil means contract creation Amount *big.Int Payload []byte V byte diff --git a/eth/handler.go b/eth/handler.go index 780ec3931..5c0660d84 100644 --- a/eth/handler.go +++ b/eth/handler.go @@ -197,7 +197,7 @@ func (self *ProtocolManager) handleMsg(p *peer) error { // returns either requested hashes or nothing (i.e. not found) return p.sendBlockHashes(hashes) case BlockHashesMsg: - msgStream := rlp.NewStream(msg.Payload) + msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size)) var hashes []common.Hash if err := msgStream.Decode(&hashes); err != nil { @@ -209,12 +209,12 @@ func (self *ProtocolManager) handleMsg(p *peer) error { } case GetBlocksMsg: - msgStream := rlp.NewStream(msg.Payload) + var blocks []*types.Block + + msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size)) if _, err := msgStream.List(); err != nil { return err } - - var blocks []*types.Block var i int for { i++ @@ -236,9 +236,9 @@ func (self *ProtocolManager) handleMsg(p *peer) error { } return p.sendBlocks(blocks) case BlocksMsg: - msgStream := rlp.NewStream(msg.Payload) - var blocks []*types.Block + + msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size)) if err := msgStream.Decode(&blocks); err != nil { glog.V(logger.Detail).Infoln("Decode error", err) blocks = nil diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go index 61a0abed9..07a1a739c 100644 --- a/p2p/discover/udp.go +++ b/p2p/discover/udp.go @@ -413,7 +413,7 @@ func decodePacket(buf []byte) (packet, NodeID, []byte, error) { default: return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype) } - err = rlp.Decode(bytes.NewReader(sigdata[1:]), req) + err = rlp.DecodeBytes(sigdata[1:], req) return req, fromID, hash, err } diff --git a/p2p/message.go b/p2p/message.go index b42acbe3c..be6405d6f 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -32,7 +32,8 @@ type Msg struct { // // For the decoding rules, please see package rlp. func (msg Msg) Decode(val interface{}) error { - if err := rlp.Decode(msg.Payload, val); err != nil { + s := rlp.NewStream(msg.Payload, uint64(msg.Size)) + if err := s.Decode(val); err != nil { return newPeerError(errInvalidMsg, "(code %x) (size %d) %v", msg.Code, msg.Size, err) } return nil diff --git a/p2p/peer_error.go b/p2p/peer_error.go index 402131630..a912f6064 100644 --- a/p2p/peer_error.go +++ b/p2p/peer_error.go @@ -57,7 +57,7 @@ func (self *peerError) Error() string { return self.message } -type DiscReason byte +type DiscReason uint const ( DiscRequested DiscReason = iota diff --git a/rlp/decode.go b/rlp/decode.go index 3b5617475..6952ecaea 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -9,6 +9,7 @@ import ( "io" "math/big" "reflect" + "strings" ) var ( @@ -35,25 +36,35 @@ type Decoder interface { // If the type implements the Decoder interface, decode calls // DecodeRLP. // -// To decode into a pointer, Decode will set the pointer to nil if the -// input has size zero or the input is a single byte with value zero. -// If the input has nonzero size, Decode will allocate a new value of -// the type being pointed to. +// To decode into a pointer, Decode will decode into the value pointed +// to. If the pointer is nil, a new value of the pointer's element +// type is allocated. If the pointer is non-nil, the existing value +// will reused. // // To decode into a struct, Decode expects the input to be an RLP // list. The decoded elements of the list are assigned to each public -// field in the order given by the struct's definition. If the input -// list has too few elements, no error is returned and the remaining -// fields will have the zero value. -// Recursive struct types are supported. +// field in the order given by the struct's definition. The input list +// must contain an element for each decoded field. Decode returns an +// error if there are too few or too many elements. +// +// The decoding of struct fields honours one particular struct tag, +// "nil". This tag applies to pointer-typed fields and changes the +// decoding rules for the field such that input values of size zero +// decode as a nil pointer. This tag can be useful when decoding recursive +// types. +// +// type StructWithEmptyOK struct { +// Foo *[20]byte `rlp:"nil"` +// } // // To decode into a slice, the input must be a list and the resulting -// slice will contain the input elements in order. -// As a special case, if the slice has a byte-size element type, the input -// can also be an RLP string. +// slice will contain the input elements in order. For byte slices, +// the input must be an RLP string. Array types decode similarly, with +// the additional restriction that the number of input elements (or +// bytes) must match the array's length. // // To decode into a Go string, the input must be an RLP string. The -// bytes are taken as-is and will not necessarily be valid UTF-8. +// input bytes are taken as-is and will not necessarily be valid UTF-8. // // To decode into an unsigned integer type, the input must also be an RLP // string. The bytes are interpreted as a big endian representation of @@ -64,20 +75,28 @@ type Decoder interface { // To decode into an interface value, Decode stores one of these // in the value: // -// []interface{}, for RLP lists -// []byte, for RLP strings +// []interface{}, for RLP lists +// []byte, for RLP strings // // Non-empty interface types are not supported, nor are booleans, // signed integers, floating point numbers, maps, channels and // functions. +// +// Note that Decode does not set an input limit for all readers +// and may be vulnerable to panics cause by huge value sizes. If +// you need an input limit, use +// +// NewStream(r, limit).Decode(val) func Decode(r io.Reader, val interface{}) error { - return NewStream(r).Decode(val) + // TODO: this could use a Stream from a pool. + return NewStream(r, 0).Decode(val) } // DecodeBytes parses RLP data from b into val. // Please see the documentation of Decode for the decoding rules. func DecodeBytes(b []byte, val interface{}) error { - return NewStream(bytes.NewReader(b)).Decode(val) + // TODO: this could use a Stream from a pool. + return NewStream(bytes.NewReader(b), uint64(len(b))).Decode(val) } type decodeError struct { @@ -100,7 +119,9 @@ func (err *decodeError) Error() string { func wrapStreamError(err error, typ reflect.Type) error { switch err { case ErrCanonInt: - return &decodeError{msg: "canon int error appends zero's", typ: typ} + return &decodeError{msg: "non-canonical integer (leading zero bytes)", typ: typ} + case ErrCanonSize: + return &decodeError{msg: "non-canonical size information", typ: typ} case ErrExpectedList: return &decodeError{msg: "expected input list", typ: typ} case ErrExpectedString: @@ -125,7 +146,7 @@ var ( bigInt = reflect.TypeOf(big.Int{}) ) -func makeDecoder(typ reflect.Type) (dec decoder, err error) { +func makeDecoder(typ reflect.Type, tags tags) (dec decoder, err error) { kind := typ.Kind() switch { case typ.Implements(decoderInterface): @@ -145,6 +166,9 @@ func makeDecoder(typ reflect.Type) (dec decoder, err error) { case kind == reflect.Struct: return makeStructDecoder(typ) case kind == reflect.Ptr: + if tags.nilOK { + return makeOptionalPtrDecoder(typ) + } return makePtrDecoder(typ) case kind == reflect.Interface: return decodeInterface, nil @@ -186,12 +210,10 @@ func decodeBigInt(s *Stream, val reflect.Value) error { i = new(big.Int) val.Set(reflect.ValueOf(i)) } - - // Reject big integers which are zero appended + // Reject leading zero bytes if len(b) > 0 && b[0] == 0 { return wrapStreamError(ErrCanonInt, val.Type()) } - i.SetBytes(b) return nil } @@ -205,7 +227,7 @@ func makeListDecoder(typ reflect.Type) (decoder, error) { return decodeByteSlice, nil } } - etypeinfo, err := cachedTypeInfo1(etype) + etypeinfo, err := cachedTypeInfo1(etype, tags{}) if err != nil { return nil, err } @@ -259,19 +281,10 @@ func decodeListSlice(s *Stream, val reflect.Value, elemdec decoder) error { } func decodeListArray(s *Stream, val reflect.Value, elemdec decoder) error { - size, err := s.List() + _, err := s.List() if err != nil { - return err - } - if size == 0 { - zero(val, 0) - return s.ListEnd() + return wrapStreamError(err, val.Type()) } - - // The approach here is stolen from package json, although we differ - // in the semantics for arrays. package json discards remaining - // elements that would not fit into the array. We generate an error in - // this case because we'd be losing information. vlen := val.Len() i := 0 for ; i < vlen; i++ { @@ -282,24 +295,18 @@ func decodeListArray(s *Stream, val reflect.Value, elemdec decoder) error { } } if i < vlen { - zero(val, i) + return &decodeError{msg: "input list has too few elements", typ: val.Type()} } return wrapStreamError(s.ListEnd(), val.Type()) } func decodeByteSlice(s *Stream, val reflect.Value) error { - kind, _, err := s.Kind() - if err != nil { - return err - } - if kind == List { - return decodeListSlice(s, val, decodeUint) - } b, err := s.Bytes() - if err == nil { - val.SetBytes(b) + if err != nil { + return wrapStreamError(err, val.Type()) } - return err + val.SetBytes(b) + return nil } func decodeByteArray(s *Stream, val reflect.Value) error { @@ -307,42 +314,38 @@ func decodeByteArray(s *Stream, val reflect.Value) error { if err != nil { return err } + vlen := val.Len() switch kind { case Byte: - if val.Len() == 0 { + if vlen == 0 { return &decodeError{msg: "input string too long", typ: val.Type()} } + if vlen > 1 { + return &decodeError{msg: "input string too short", typ: val.Type()} + } bv, _ := s.Uint() val.Index(0).SetUint(bv) - zero(val, 1) case String: - if uint64(val.Len()) < size { + if uint64(vlen) < size { return &decodeError{msg: "input string too long", typ: val.Type()} } - slice := val.Slice(0, int(size)).Interface().([]byte) + if uint64(vlen) > size { + return &decodeError{msg: "input string too short", typ: val.Type()} + } + slice := val.Slice(0, vlen).Interface().([]byte) if err := s.readFull(slice); err != nil { return err } - zero(val, int(size)) + // Reject cases where single byte encoding should have been used. + if size == 1 && slice[0] < 56 { + return wrapStreamError(ErrCanonSize, val.Type()) + } case List: - return decodeListArray(s, val, decodeUint) + return wrapStreamError(ErrExpectedString, val.Type()) } return nil } -func zero(val reflect.Value, start int) { - z := reflect.Zero(val.Type().Elem()) - end := val.Len() - for i := start; i < end; i++ { - val.Index(i).Set(z) - } -} - -type field struct { - index int - info *typeinfo -} - func makeStructDecoder(typ reflect.Type) (decoder, error) { fields, err := structFields(typ) if err != nil { @@ -355,8 +358,7 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { for _, f := range fields { err = f.info.decoder(s, val.Field(f.index)) if err == EOL { - // too few elements. leave the rest at their zero value. - break + return &decodeError{msg: "too few elements", typ: typ} } else if err != nil { return addErrorContext(err, "."+typ.Field(f.index).Name) } @@ -366,15 +368,41 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { return dec, nil } +// makePtrDecoder creates a decoder that decodes into +// the pointer's element type. func makePtrDecoder(typ reflect.Type) (decoder, error) { etype := typ.Elem() - etypeinfo, err := cachedTypeInfo1(etype) + etypeinfo, err := cachedTypeInfo1(etype, tags{}) if err != nil { return nil, err } dec := func(s *Stream, val reflect.Value) (err error) { - _, size, err := s.Kind() - if err != nil || size == 0 && s.byteval == 0 { + newval := val + if val.IsNil() { + newval = reflect.New(etype) + } + if err = etypeinfo.decoder(s, newval.Elem()); err == nil { + val.Set(newval) + } + return err + } + return dec, nil +} + +// makeOptionalPtrDecoder creates a decoder that decodes empty values +// as nil. Non-empty values are decoded into a value of the element type, +// just like makePtrDecoder does. +// +// This decoder is used for pointer-typed struct fields with struct tag "nil". +func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) { + etype := typ.Elem() + etypeinfo, err := cachedTypeInfo1(etype, tags{}) + if err != nil { + return nil, err + } + dec := func(s *Stream, val reflect.Value) (err error) { + kind, size, err := s.Kind() + if err != nil || size == 0 && kind != Byte { // rearm s.Kind. This is important because the input // position must advance to the next value even though // we don't read anything. @@ -465,15 +493,18 @@ var ( // has been reached during streaming. EOL = errors.New("rlp: end of list") - // Other errors + // Actual Errors ErrExpectedString = errors.New("rlp: expected String or Byte") ErrExpectedList = errors.New("rlp: expected List") - ErrCanonInt = errors.New("rlp: expected Int") + ErrCanonInt = errors.New("rlp: non-canonical integer format") + ErrCanonSize = errors.New("rlp: non-canonical size information") ErrElemTooLarge = errors.New("rlp: element is larger than containing list") + ErrValueTooLarge = errors.New("rlp: value size exceeds available input length") // internal errors - errNotInList = errors.New("rlp: call of ListEnd outside of any list") - errNotAtEOL = errors.New("rlp: call of ListEnd not positioned at EOL") + errNotInList = errors.New("rlp: call of ListEnd outside of any list") + errNotAtEOL = errors.New("rlp: call of ListEnd not positioned at EOL") + errUintOverflow = errors.New("rlp: uint overflow") ) // ByteReader must be implemented by any input reader for a Stream. It @@ -496,23 +527,44 @@ type ByteReader interface { // // Stream is not safe for concurrent use. type Stream struct { - r ByteReader + r ByteReader + + // number of bytes remaining to be read from r. + remaining uint64 + limited bool + + // auxiliary buffer for integer decoding uintbuf []byte kind Kind // kind of value ahead size uint64 // size of value ahead byteval byte // value of single byte in type tag + kinderr error // error from last readKind stack []listpos } type listpos struct{ pos, size uint64 } -// NewStream creates a new stream reading from r. -// If r does not implement ByteReader, the Stream will -// introduce its own buffering. -func NewStream(r io.Reader) *Stream { +// NewStream creates a new decoding stream reading from r. +// +// If r implements the ByteReader interface, Stream will +// not introduce any buffering. +// +// For non-toplevel values, Stream returns ErrElemTooLarge +// for values that do not fit into the enclosing list. +// +// Stream supports an optional input limit. If a limit is set, the +// size of any toplevel value will be checked against the remaining +// input length. Stream operations that encounter a value exceeding +// the remaining input length will return ErrValueTooLarge. The limit +// can be set by passing a non-zero value for inputLimit. +// +// If r is a bytes.Reader or strings.Reader, the input limit is set to +// the length of r's underlying data unless an explicit limit is +// provided. +func NewStream(r io.Reader, inputLimit uint64) *Stream { s := new(Stream) - s.Reset(r) + s.Reset(r, inputLimit) return s } @@ -520,7 +572,7 @@ func NewStream(r io.Reader) *Stream { // at an encoded list of the given length. func NewListStream(r io.Reader, len uint64) *Stream { s := new(Stream) - s.Reset(r) + s.Reset(r, len) s.kind = List s.size = len return s @@ -543,6 +595,9 @@ func (s *Stream) Bytes() ([]byte, error) { if err = s.readFull(b); err != nil { return nil, err } + if size == 1 && b[0] < 56 { + return nil, ErrCanonSize + } return b, nil default: return nil, ErrExpectedString @@ -574,8 +629,6 @@ func (s *Stream) Raw() ([]byte, error) { return buf, nil } -var errUintOverflow = errors.New("rlp: uint overflow") - // Uint reads an RLP string of up to 8 bytes and returns its contents // as an unsigned integer. If the input does not contain an RLP string, the // returned error will be ErrExpectedString. @@ -590,13 +643,27 @@ func (s *Stream) uint(maxbits int) (uint64, error) { } switch kind { case Byte: + if s.byteval == 0 { + return 0, ErrCanonInt + } s.kind = -1 // rearm Kind return uint64(s.byteval), nil case String: if size > uint64(maxbits/8) { return 0, errUintOverflow } - return s.readUint(byte(size)) + v, err := s.readUint(byte(size)) + switch { + case err == ErrCanonSize: + // Adjust error because we're not reading a size right now. + return 0, ErrCanonInt + case err != nil: + return 0, err + case size > 0 && v < 56: + return 0, ErrCanonSize + default: + return v, nil + } default: return 0, ErrExpectedString } @@ -653,7 +720,7 @@ func (s *Stream) Decode(val interface{}) error { if rval.IsNil() { return errDecodeIntoNil } - info, err := cachedTypeInfo(rtyp.Elem()) + info, err := cachedTypeInfo(rtyp.Elem(), tags{}) if err != nil { return err } @@ -667,17 +734,40 @@ func (s *Stream) Decode(val interface{}) error { } // Reset discards any information about the current decoding context -// and starts reading from r. If r does not also implement ByteReader, -// Stream will do its own buffering. -func (s *Stream) Reset(r io.Reader) { +// and starts reading from r. This method is meant to facilitate reuse +// of a preallocated Stream across many decoding operations. +// +// If r does not also implement ByteReader, Stream will do its own +// buffering. +func (s *Stream) Reset(r io.Reader, inputLimit uint64) { + if inputLimit > 0 { + s.remaining = inputLimit + s.limited = true + } else { + // Attempt to automatically discover + // the limit when reading from a byte slice. + switch br := r.(type) { + case *bytes.Reader: + s.remaining = uint64(br.Len()) + s.limited = true + case *strings.Reader: + s.remaining = uint64(br.Len()) + s.limited = true + default: + s.limited = false + } + } + // Wrap r with a buffer if it doesn't have one. bufr, ok := r.(ByteReader) if !ok { bufr = bufio.NewReader(r) } s.r = bufr + // Reset the decoding context. s.stack = s.stack[:0] s.size = 0 s.kind = -1 + s.kinderr = nil if s.uintbuf == nil { s.uintbuf = make([]byte, 8) } @@ -700,19 +790,31 @@ func (s *Stream) Kind() (kind Kind, size uint64, err error) { tos = &s.stack[len(s.stack)-1] } if s.kind < 0 { + s.kinderr = nil + // Don't read further if we're at the end of the + // innermost list. if tos != nil && tos.pos == tos.size { return 0, 0, EOL } - kind, size, err = s.readKind() - if err != nil { - return 0, 0, err + s.kind, s.size, s.kinderr = s.readKind() + if s.kinderr == nil { + if tos == nil { + // At toplevel, check that the value is smaller + // than the remaining input length. + if s.limited && s.size > s.remaining { + s.kinderr = ErrValueTooLarge + } + } else { + // Inside a list, check that the value doesn't overflow the list. + if s.size > tos.size-tos.pos { + s.kinderr = ErrElemTooLarge + } + } } - s.kind, s.size = kind, size - } - if tos != nil && tos.pos+s.size > tos.size { - return 0, 0, ErrElemTooLarge } - return s.kind, s.size, nil + // Note: this might return a sticky error generated + // by an earlier call to readKind. + return s.kind, s.size, s.kinderr } func (s *Stream) readKind() (kind Kind, size uint64, err error) { @@ -741,6 +843,9 @@ func (s *Stream) readKind() (kind Kind, size uint64, err error) { // would be encoded as 0xB90400 followed by the string. The range of // the first byte is thus [0xB8, 0xBF]. size, err = s.readUint(b - 0xB7) + if err == nil && size < 56 { + err = ErrCanonSize + } return String, size, err case b < 0xF8: // If the total payload of a list @@ -757,27 +862,46 @@ func (s *Stream) readKind() (kind Kind, size uint64, err error) { // the concatenation of the RLP encodings of the items. The // range of the first byte is thus [0xF8, 0xFF]. size, err = s.readUint(b - 0xF7) + if err == nil && size < 56 { + err = ErrCanonSize + } return List, size, err } } func (s *Stream) readUint(size byte) (uint64, error) { - if size == 1 { + switch size { + case 0: + s.kind = -1 // rearm Kind + return 0, nil + case 1: b, err := s.readByte() if err == io.EOF { err = io.ErrUnexpectedEOF } return uint64(b), err + default: + start := int(8 - size) + for i := 0; i < start; i++ { + s.uintbuf[i] = 0 + } + if err := s.readFull(s.uintbuf[start:]); err != nil { + return 0, err + } + if s.uintbuf[start] == 0 { + // Note: readUint is also used to decode integer + // values. The error needs to be adjusted to become + // ErrCanonInt in this case. + return 0, ErrCanonSize + } + return binary.BigEndian.Uint64(s.uintbuf), nil } - start := int(8 - size) - for i := 0; i < start; i++ { - s.uintbuf[i] = 0 - } - err := s.readFull(s.uintbuf[start:]) - return binary.BigEndian.Uint64(s.uintbuf), err } func (s *Stream) readFull(buf []byte) (err error) { + if s.limited && s.remaining < uint64(len(buf)) { + return ErrValueTooLarge + } s.willRead(uint64(len(buf))) var nn, n int for n < len(buf) && err == nil { @@ -791,6 +915,9 @@ func (s *Stream) readFull(buf []byte) (err error) { } func (s *Stream) readByte() (byte, error) { + if s.limited && s.remaining == 0 { + return 0, io.EOF + } s.willRead(1) b, err := s.r.ReadByte() if len(s.stack) > 0 && err == io.EOF { @@ -801,6 +928,9 @@ func (s *Stream) readByte() (byte, error) { func (s *Stream) willRead(n uint64) { s.kind = -1 // rearm Kind + if s.limited { + s.remaining -= n + } if len(s.stack) > 0 { s.stack[len(s.stack)-1].pos += n } diff --git a/rlp/decode_test.go b/rlp/decode_test.go index 73a31c67f..d07520bd0 100644 --- a/rlp/decode_test.go +++ b/rlp/decode_test.go @@ -21,22 +21,18 @@ func TestStreamKind(t *testing.T) { {"7F", Byte, 0}, {"80", String, 0}, {"B7", String, 55}, - {"B800", String, 0}, {"B90400", String, 1024}, - {"BA000400", String, 1024}, - {"BB00000400", String, 1024}, {"BFFFFFFFFFFFFFFFFF", String, ^uint64(0)}, {"C0", List, 0}, {"C8", List, 8}, {"F7", List, 55}, - {"F800", List, 0}, - {"F804", List, 4}, {"F90400", List, 1024}, {"FFFFFFFFFFFFFFFFFF", List, ^uint64(0)}, } for i, test := range tests { - s := NewStream(bytes.NewReader(unhex(test.input))) + // using plainReader to inhibit input limit errors. + s := NewStream(newPlainReader(unhex(test.input)), 0) kind, len, err := s.Kind() if err != nil { t.Errorf("test %d: Kind returned error: %v", i, err) @@ -70,29 +66,85 @@ func TestNewListStream(t *testing.T) { } func TestStreamErrors(t *testing.T) { + withoutInputLimit := func(b []byte) *Stream { + return NewStream(newPlainReader(b), 0) + } + withCustomInputLimit := func(limit uint64) func([]byte) *Stream { + return func(b []byte) *Stream { + return NewStream(bytes.NewReader(b), limit) + } + } + type calls []string tests := []struct { string calls - error + newStream func([]byte) *Stream // uses bytes.Reader if nil + error error }{ - {"", calls{"Kind"}, io.EOF}, - {"", calls{"List"}, io.EOF}, - {"", calls{"Uint"}, io.EOF}, - {"C0", calls{"Bytes"}, ErrExpectedString}, - {"C0", calls{"Uint"}, ErrExpectedString}, - {"81", calls{"Bytes"}, io.ErrUnexpectedEOF}, - {"81", calls{"Uint"}, io.ErrUnexpectedEOF}, - {"BFFFFFFFFFFFFFFF", calls{"Bytes"}, io.ErrUnexpectedEOF}, - {"89000000000000000001", calls{"Uint"}, errUintOverflow}, - {"00", calls{"List"}, ErrExpectedList}, - {"80", calls{"List"}, ErrExpectedList}, - {"C0", calls{"List", "Uint"}, EOL}, - {"C801", calls{"List", "Uint", "Uint"}, io.ErrUnexpectedEOF}, - {"C8C9", calls{"List", "Kind"}, ErrElemTooLarge}, - {"C3C2010201", calls{"List", "List", "Uint", "Uint", "ListEnd", "Uint"}, EOL}, - {"00", calls{"ListEnd"}, errNotInList}, - {"C40102", calls{"List", "Uint", "ListEnd"}, errNotAtEOL}, + {"C0", calls{"Bytes"}, nil, ErrExpectedString}, + {"C0", calls{"Uint"}, nil, ErrExpectedString}, + {"89000000000000000001", calls{"Uint"}, nil, errUintOverflow}, + {"00", calls{"List"}, nil, ErrExpectedList}, + {"80", calls{"List"}, nil, ErrExpectedList}, + {"C0", calls{"List", "Uint"}, nil, EOL}, + {"C8C9010101010101010101", calls{"List", "Kind"}, nil, ErrElemTooLarge}, + {"C3C2010201", calls{"List", "List", "Uint", "Uint", "ListEnd", "Uint"}, nil, EOL}, + {"00", calls{"ListEnd"}, nil, errNotInList}, + {"C401020304", calls{"List", "Uint", "ListEnd"}, nil, errNotAtEOL}, + + // Non-canonical integers (e.g. leading zero bytes). + {"00", calls{"Uint"}, nil, ErrCanonInt}, + {"820002", calls{"Uint"}, nil, ErrCanonInt}, + {"8133", calls{"Uint"}, nil, ErrCanonSize}, + {"8156", calls{"Uint"}, nil, nil}, + + // Size tags must use the smallest possible encoding. + // Leading zero bytes in the size tag are also rejected. + {"8100", calls{"Uint"}, nil, ErrCanonSize}, + {"8100", calls{"Bytes"}, nil, ErrCanonSize}, + {"B800", calls{"Kind"}, withoutInputLimit, ErrCanonSize}, + {"B90000", calls{"Kind"}, withoutInputLimit, ErrCanonSize}, + {"B90055", calls{"Kind"}, withoutInputLimit, ErrCanonSize}, + {"BA0002FFFF", calls{"Bytes"}, withoutInputLimit, ErrCanonSize}, + {"F800", calls{"Kind"}, withoutInputLimit, ErrCanonSize}, + {"F90000", calls{"Kind"}, withoutInputLimit, ErrCanonSize}, + {"F90055", calls{"Kind"}, withoutInputLimit, ErrCanonSize}, + {"FA0002FFFF", calls{"List"}, withoutInputLimit, ErrCanonSize}, + + // Expected EOF + {"", calls{"Kind"}, nil, io.EOF}, + {"", calls{"Uint"}, nil, io.EOF}, + {"", calls{"List"}, nil, io.EOF}, + {"8158", calls{"Uint", "Uint"}, nil, io.EOF}, + {"C0", calls{"List", "ListEnd", "List"}, nil, io.EOF}, + + // Input limit errors. + {"81", calls{"Bytes"}, nil, ErrValueTooLarge}, + {"81", calls{"Uint"}, nil, ErrValueTooLarge}, + {"81", calls{"Raw"}, nil, ErrValueTooLarge}, + {"BFFFFFFFFFFFFFFFFFFF", calls{"Bytes"}, nil, ErrValueTooLarge}, + {"C801", calls{"List"}, nil, ErrValueTooLarge}, + + // Test for list element size check overflow. + {"CD04040404FFFFFFFFFFFFFFFFFF0303", calls{"List", "Uint", "Uint", "Uint", "Uint", "List"}, nil, ErrElemTooLarge}, + + // Test for input limit overflow. Since we are counting the limit + // down toward zero in Stream.remaining, reading too far can overflow + // remaining to a large value, effectively disabling the limit. + {"C40102030401", calls{"Raw", "Uint"}, withCustomInputLimit(5), io.EOF}, + {"C4010203048158", calls{"Raw", "Uint"}, withCustomInputLimit(6), ErrValueTooLarge}, + + // Check that the same calls are fine without a limit. + {"C40102030401", calls{"Raw", "Uint"}, withoutInputLimit, nil}, + {"C4010203048158", calls{"Raw", "Uint"}, withoutInputLimit, nil}, + + // Unexpected EOF. This only happens when there is + // no input limit, so the reader needs to be 'dumbed down'. + {"81", calls{"Bytes"}, withoutInputLimit, io.ErrUnexpectedEOF}, + {"81", calls{"Uint"}, withoutInputLimit, io.ErrUnexpectedEOF}, + {"BFFFFFFFFFFFFFFF", calls{"Bytes"}, withoutInputLimit, io.ErrUnexpectedEOF}, + {"C801", calls{"List", "Uint", "Uint"}, withoutInputLimit, io.ErrUnexpectedEOF}, // This test verifies that the input position is advanced // correctly when calling Bytes for empty strings. Kind can be called @@ -109,12 +161,15 @@ func TestStreamErrors(t *testing.T) { "Bytes", // past final element "Bytes", // this one should fail - }, EOL}, + }, nil, EOL}, } testfor: for i, test := range tests { - s := NewStream(bytes.NewReader(unhex(test.string))) + if test.newStream == nil { + test.newStream = func(b []byte) *Stream { return NewStream(bytes.NewReader(b), 0) } + } + s := test.newStream(unhex(test.string)) rs := reflect.ValueOf(s) for j, call := range test.calls { fval := rs.MethodByName(call) @@ -124,11 +179,17 @@ testfor: err = lastret.(error).Error() } if j == len(test.calls)-1 { - if err != test.error.Error() { - t.Errorf("test %d: last call (%s) error mismatch\ngot: %s\nwant: %v", + want := "<nil>" + if test.error != nil { + want = test.error.Error() + } + if err != want { + t.Log(test) + t.Errorf("test %d: last call (%s) error mismatch\ngot: %s\nwant: %s", i, call, err, test.error) } } else if err != "<nil>" { + t.Log(test) t.Errorf("test %d: call %d (%s) unexpected error: %q", i, j, call, err) continue testfor } @@ -137,7 +198,7 @@ testfor: } func TestStreamList(t *testing.T) { - s := NewStream(bytes.NewReader(unhex("C80102030405060708"))) + s := NewStream(bytes.NewReader(unhex("C80102030405060708")), 0) len, err := s.List() if err != nil { @@ -166,7 +227,7 @@ func TestStreamList(t *testing.T) { } func TestStreamRaw(t *testing.T) { - s := NewStream(bytes.NewReader(unhex("C58401010101"))) + s := NewStream(bytes.NewReader(unhex("C58401010101")), 0) s.List() want := unhex("8401010101") @@ -219,7 +280,7 @@ type simplestruct struct { type recstruct struct { I uint - Child *recstruct + Child *recstruct `rlp:"nil"` } var ( @@ -229,78 +290,58 @@ var ( ) ) -var ( - sharedByteArray [5]byte - sharedPtr = new(*uint) -) - var decodeTests = []decodeTest{ // integers {input: "05", ptr: new(uint32), value: uint32(5)}, {input: "80", ptr: new(uint32), value: uint32(0)}, - {input: "8105", ptr: new(uint32), value: uint32(5)}, {input: "820505", ptr: new(uint32), value: uint32(0x0505)}, {input: "83050505", ptr: new(uint32), value: uint32(0x050505)}, {input: "8405050505", ptr: new(uint32), value: uint32(0x05050505)}, {input: "850505050505", ptr: new(uint32), error: "rlp: input string too long for uint32"}, {input: "C0", ptr: new(uint32), error: "rlp: expected input string or byte for uint32"}, + {input: "00", ptr: new(uint32), error: "rlp: non-canonical integer (leading zero bytes) for uint32"}, + {input: "8105", ptr: new(uint32), error: "rlp: non-canonical size information for uint32"}, + {input: "820004", ptr: new(uint32), error: "rlp: non-canonical integer (leading zero bytes) for uint32"}, + {input: "B8020004", ptr: new(uint32), error: "rlp: non-canonical size information for uint32"}, // slices {input: "C0", ptr: new([]uint), value: []uint{}}, {input: "C80102030405060708", ptr: new([]uint), value: []uint{1, 2, 3, 4, 5, 6, 7, 8}}, + {input: "F8020004", ptr: new([]uint), error: "rlp: non-canonical size information for []uint"}, // arrays - {input: "C0", ptr: new([5]uint), value: [5]uint{}}, {input: "C50102030405", ptr: new([5]uint), value: [5]uint{1, 2, 3, 4, 5}}, + {input: "C0", ptr: new([5]uint), error: "rlp: input list has too few elements for [5]uint"}, + {input: "C102", ptr: new([5]uint), error: "rlp: input list has too few elements for [5]uint"}, {input: "C6010203040506", ptr: new([5]uint), error: "rlp: input list has too many elements for [5]uint"}, + {input: "F8020004", ptr: new([5]uint), error: "rlp: non-canonical size information for [5]uint"}, + + // zero sized arrays + {input: "C0", ptr: new([0]uint), value: [0]uint{}}, + {input: "C101", ptr: new([0]uint), error: "rlp: input list has too many elements for [0]uint"}, // byte slices {input: "01", ptr: new([]byte), value: []byte{1}}, {input: "80", ptr: new([]byte), value: []byte{}}, {input: "8D6162636465666768696A6B6C6D", ptr: new([]byte), value: []byte("abcdefghijklm")}, - {input: "C0", ptr: new([]byte), value: []byte{}}, - {input: "C3010203", ptr: new([]byte), value: []byte{1, 2, 3}}, - - { - input: "C3820102", - ptr: new([]byte), - error: "rlp: input string too long for uint8, decoding into ([]uint8)[0]", - }, + {input: "C0", ptr: new([]byte), error: "rlp: expected input string or byte for []uint8"}, + {input: "8105", ptr: new([]byte), error: "rlp: non-canonical size information for []uint8"}, // byte arrays - {input: "01", ptr: new([5]byte), value: [5]byte{1}}, - {input: "80", ptr: new([5]byte), value: [5]byte{}}, + {input: "02", ptr: new([1]byte), value: [1]byte{2}}, {input: "850102030405", ptr: new([5]byte), value: [5]byte{1, 2, 3, 4, 5}}, - {input: "C0", ptr: new([5]byte), value: [5]byte{}}, - {input: "C3010203", ptr: new([5]byte), value: [5]byte{1, 2, 3, 0, 0}}, - { - input: "C3820102", - ptr: new([5]byte), - error: "rlp: input string too long for uint8, decoding into ([5]uint8)[0]", - }, - { - input: "86010203040506", - ptr: new([5]byte), - error: "rlp: input string too long for [5]uint8", - }, - { - input: "850101", - ptr: new([5]byte), - error: io.ErrUnexpectedEOF.Error(), - }, - - // byte array reuse (should be zeroed) - {input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}}, - {input: "8101", ptr: &sharedByteArray, value: [5]byte{1}}, // kind: String - {input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}}, - {input: "01", ptr: &sharedByteArray, value: [5]byte{1}}, // kind: Byte - {input: "C3010203", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 0, 0}}, - {input: "C101", ptr: &sharedByteArray, value: [5]byte{1}}, // kind: List + // byte array errors + {input: "02", ptr: new([5]byte), error: "rlp: input string too short for [5]uint8"}, + {input: "80", ptr: new([5]byte), error: "rlp: input string too short for [5]uint8"}, + {input: "820000", ptr: new([5]byte), error: "rlp: input string too short for [5]uint8"}, + {input: "C0", ptr: new([5]byte), error: "rlp: expected input string or byte for [5]uint8"}, + {input: "C3010203", ptr: new([5]byte), error: "rlp: expected input string or byte for [5]uint8"}, + {input: "86010203040506", ptr: new([5]byte), error: "rlp: input string too long for [5]uint8"}, + {input: "8105", ptr: new([1]byte), error: "rlp: non-canonical size information for [1]uint8"}, // zero sized byte arrays {input: "80", ptr: new([0]byte), value: [0]byte{}}, - {input: "C0", ptr: new([0]byte), value: [0]byte{}}, {input: "01", ptr: new([0]byte), error: "rlp: input string too long for [0]uint8"}, {input: "8101", ptr: new([0]byte), error: "rlp: input string too long for [0]uint8"}, @@ -312,20 +353,44 @@ var decodeTests = []decodeTest{ // big ints {input: "01", ptr: new(*big.Int), value: big.NewInt(1)}, {input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*big.Int), value: veryBigInt}, - {input: "820001", ptr: new(big.Int), error: "rlp: canon int error appends zero's for *big.Int"}, {input: "10", ptr: new(big.Int), value: *big.NewInt(16)}, // non-pointer also works {input: "C0", ptr: new(*big.Int), error: "rlp: expected input string or byte for *big.Int"}, + {input: "820001", ptr: new(big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"}, + {input: "8105", ptr: new(big.Int), error: "rlp: non-canonical size information for *big.Int"}, // structs - {input: "C0", ptr: new(simplestruct), value: simplestruct{0, ""}}, - {input: "C105", ptr: new(simplestruct), value: simplestruct{5, ""}}, - {input: "C50583343434", ptr: new(simplestruct), value: simplestruct{5, "444"}}, { - input: "C501C302C103", + input: "C50583343434", + ptr: new(simplestruct), + value: simplestruct{5, "444"}, + }, + { + input: "C601C402C203C0", ptr: new(recstruct), value: recstruct{1, &recstruct{2, &recstruct{3, nil}}}, }, + // struct errors + { + input: "C0", + ptr: new(simplestruct), + error: "rlp: too few elements for rlp.simplestruct", + }, + { + input: "C105", + ptr: new(simplestruct), + error: "rlp: too few elements for rlp.simplestruct", + }, + { + input: "C7C50583343434C0", + ptr: new([]*simplestruct), + error: "rlp: too few elements for rlp.simplestruct, decoding into ([]*rlp.simplestruct)[1]", + }, + { + input: "83222222", + ptr: new(simplestruct), + error: "rlp: expected input list for rlp.simplestruct", + }, { input: "C3010101", ptr: new(simplestruct), @@ -338,20 +403,16 @@ var decodeTests = []decodeTest{ }, // pointers - {input: "00", ptr: new(*uint), value: (*uint)(nil)}, - {input: "80", ptr: new(*uint), value: (*uint)(nil)}, - {input: "C0", ptr: new(*uint), value: (*uint)(nil)}, + {input: "00", ptr: new(*[]byte), value: &[]byte{0}}, + {input: "80", ptr: new(*uint), value: uintp(0)}, + {input: "C0", ptr: new(*uint), error: "rlp: expected input string or byte for uint"}, {input: "07", ptr: new(*uint), value: uintp(7)}, - {input: "8108", ptr: new(*uint), value: uintp(8)}, + {input: "8158", ptr: new(*uint), value: uintp(0x58)}, {input: "C109", ptr: new(*[]uint), value: &[]uint{9}}, {input: "C58403030303", ptr: new(*[][]byte), value: &[][]byte{{3, 3, 3, 3}}}, // check that input position is advanced also for empty values. - {input: "C3808005", ptr: new([]*uint), value: []*uint{nil, nil, uintp(5)}}, - - // pointer should be reset to nil - {input: "05", ptr: sharedPtr, value: uintp(5)}, - {input: "80", ptr: sharedPtr, value: (*uint)(nil)}, + {input: "C3808005", ptr: new([]*uint), value: []*uint{uintp(0), uintp(0), uintp(5)}}, // interface{} {input: "00", ptr: new(interface{}), value: []byte{0}}, @@ -401,11 +462,17 @@ func TestDecodeWithByteReader(t *testing.T) { }) } -// dumbReader reads from a byte slice but does not -// implement ReadByte. -type dumbReader []byte +// plainReader reads from a byte slice but does not +// implement ReadByte. It is also not recognized by the +// size validation. This is useful to test how the decoder +// behaves on a non-buffered input stream. +type plainReader []byte + +func newPlainReader(b []byte) io.Reader { + return (*plainReader)(&b) +} -func (r *dumbReader) Read(buf []byte) (n int, err error) { +func (r *plainReader) Read(buf []byte) (n int, err error) { if len(*r) == 0 { return 0, io.EOF } @@ -416,15 +483,14 @@ func (r *dumbReader) Read(buf []byte) (n int, err error) { func TestDecodeWithNonByteReader(t *testing.T) { runTests(t, func(input []byte, into interface{}) error { - r := dumbReader(input) - return Decode(&r, into) + return Decode(newPlainReader(input), into) }) } func TestDecodeStreamReset(t *testing.T) { - s := NewStream(nil) + s := NewStream(nil, 0) runTests(t, func(input []byte, into interface{}) error { - s.Reset(bytes.NewReader(input)) + s.Reset(bytes.NewReader(input), 0) return s.Decode(into) }) } @@ -516,9 +582,36 @@ func ExampleDecode() { // Decoded value: rlp.example{A:0xa, B:0x14, private:0x0, String:"foobar"} } +func ExampleDecode_structTagNil() { + // In this example, we'll use the "nil" struct tag to change + // how a pointer-typed field is decoded. The input contains an RLP + // list of one element, an empty string. + input := []byte{0xC1, 0x80} + + // This type uses the normal rules. + // The empty input string is decoded as a pointer to an empty Go string. + var normalRules struct { + String *string + } + Decode(bytes.NewReader(input), &normalRules) + fmt.Printf("normal: String = %q\n", *normalRules.String) + + // This type uses the struct tag. + // The empty input string is decoded as a nil pointer. + var withEmptyOK struct { + String *string `rlp:"nil"` + } + Decode(bytes.NewReader(input), &withEmptyOK) + fmt.Printf("with nil tag: String = %v\n", withEmptyOK.String) + + // Output: + // normal: String = "" + // with nil tag: String = <nil> +} + func ExampleStream() { input, _ := hex.DecodeString("C90A1486666F6F626172") - s := NewStream(bytes.NewReader(input)) + s := NewStream(bytes.NewReader(input), 0) // Check what kind of value lies ahead kind, size, _ := s.Kind() diff --git a/rlp/encode.go b/rlp/encode.go index 6cf6776d6..10ff0ae79 100644 --- a/rlp/encode.go +++ b/rlp/encode.go @@ -194,7 +194,7 @@ func (w *encbuf) Write(b []byte) (int, error) { func (w *encbuf) encode(val interface{}) error { rval := reflect.ValueOf(val) - ti, err := cachedTypeInfo(rval.Type()) + ti, err := cachedTypeInfo(rval.Type(), tags{}) if err != nil { return err } @@ -485,7 +485,7 @@ func writeInterface(val reflect.Value, w *encbuf) error { return nil } eval := val.Elem() - ti, err := cachedTypeInfo(eval.Type()) + ti, err := cachedTypeInfo(eval.Type(), tags{}) if err != nil { return err } @@ -493,7 +493,7 @@ func writeInterface(val reflect.Value, w *encbuf) error { } func makeSliceWriter(typ reflect.Type) (writer, error) { - etypeinfo, err := cachedTypeInfo1(typ.Elem()) + etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{}) if err != nil { return nil, err } @@ -530,7 +530,7 @@ func makeStructWriter(typ reflect.Type) (writer, error) { } func makePtrWriter(typ reflect.Type) (writer, error) { - etypeinfo, err := cachedTypeInfo1(typ.Elem()) + etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{}) if err != nil { return nil, err } diff --git a/rlp/typecache.go b/rlp/typecache.go index 398f25d90..d512012e9 100644 --- a/rlp/typecache.go +++ b/rlp/typecache.go @@ -7,7 +7,7 @@ import ( var ( typeCacheMutex sync.RWMutex - typeCache = make(map[reflect.Type]*typeinfo) + typeCache = make(map[typekey]*typeinfo) ) type typeinfo struct { @@ -15,13 +15,25 @@ type typeinfo struct { writer } +// represents struct tags +type tags struct { + nilOK bool +} + +type typekey struct { + reflect.Type + // the key must include the struct tags because they + // might generate a different decoder. + tags +} + type decoder func(*Stream, reflect.Value) error type writer func(reflect.Value, *encbuf) error -func cachedTypeInfo(typ reflect.Type) (*typeinfo, error) { +func cachedTypeInfo(typ reflect.Type, tags tags) (*typeinfo, error) { typeCacheMutex.RLock() - info := typeCache[typ] + info := typeCache[typekey{typ, tags}] typeCacheMutex.RUnlock() if info != nil { return info, nil @@ -29,11 +41,12 @@ func cachedTypeInfo(typ reflect.Type) (*typeinfo, error) { // not in the cache, need to generate info for this type. typeCacheMutex.Lock() defer typeCacheMutex.Unlock() - return cachedTypeInfo1(typ) + return cachedTypeInfo1(typ, tags) } -func cachedTypeInfo1(typ reflect.Type) (*typeinfo, error) { - info := typeCache[typ] +func cachedTypeInfo1(typ reflect.Type, tags tags) (*typeinfo, error) { + key := typekey{typ, tags} + info := typeCache[key] if info != nil { // another goroutine got the write lock first return info, nil @@ -41,21 +54,27 @@ func cachedTypeInfo1(typ reflect.Type) (*typeinfo, error) { // put a dummmy value into the cache before generating. // if the generator tries to lookup itself, it will get // the dummy value and won't call itself recursively. - typeCache[typ] = new(typeinfo) - info, err := genTypeInfo(typ) + typeCache[key] = new(typeinfo) + info, err := genTypeInfo(typ, tags) if err != nil { // remove the dummy value if the generator fails - delete(typeCache, typ) + delete(typeCache, key) return nil, err } - *typeCache[typ] = *info - return typeCache[typ], err + *typeCache[key] = *info + return typeCache[key], err +} + +type field struct { + index int + info *typeinfo } func structFields(typ reflect.Type) (fields []field, err error) { for i := 0; i < typ.NumField(); i++ { if f := typ.Field(i); f.PkgPath == "" { // exported - info, err := cachedTypeInfo1(f.Type) + tags := parseStructTag(f.Tag.Get("rlp")) + info, err := cachedTypeInfo1(f.Type, tags) if err != nil { return nil, err } @@ -65,9 +84,13 @@ func structFields(typ reflect.Type) (fields []field, err error) { return fields, nil } -func genTypeInfo(typ reflect.Type) (info *typeinfo, err error) { +func parseStructTag(tag string) tags { + return tags{nilOK: tag == "nil"} +} + +func genTypeInfo(typ reflect.Type, tags tags) (info *typeinfo, err error) { info = new(typeinfo) - if info.decoder, err = makeDecoder(typ); err != nil { + if info.decoder, err = makeDecoder(typ, tags); err != nil { return nil, err } if info.writer, err = makeWriter(typ); err != nil { diff --git a/whisper/envelope.go b/whisper/envelope.go index 0a817e26e..07762c300 100644 --- a/whisper/envelope.go +++ b/whisper/envelope.go @@ -109,16 +109,17 @@ func (self *Envelope) Hash() common.Hash { return self.hash } -// rlpenv is an Envelope but is not an rlp.Decoder. -// It is used for decoding because we need to -type rlpenv Envelope - // DecodeRLP decodes an Envelope from an RLP data stream. func (self *Envelope) DecodeRLP(s *rlp.Stream) error { raw, err := s.Raw() if err != nil { return err } + // The decoding of Envelope uses the struct fields but also needs + // to compute the hash of the whole RLP-encoded envelope. This + // type has the same structure as Envelope but is not an + // rlp.Decoder so we can reuse the Envelope struct definition. + type rlpenv Envelope if err := rlp.DecodeBytes(raw, (*rlpenv)(self)); err != nil { return err } diff --git a/whisper/peer.go b/whisper/peer.go index e4301f37c..28abf4260 100644 --- a/whisper/peer.go +++ b/whisper/peer.go @@ -66,7 +66,7 @@ func (self *peer) handshake() error { if packet.Code != statusCode { return fmt.Errorf("peer sent %x before status packet", packet.Code) } - s := rlp.NewStream(packet.Payload) + s := rlp.NewStream(packet.Payload, uint64(packet.Size)) if _, err := s.List(); err != nil { return fmt.Errorf("bad status message: %v", err) } |