aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--cmd/rlpdump/main.go2
-rw-r--r--cmd/utils/cmd.go2
-rw-r--r--core/types/transaction.go2
-rw-r--r--eth/handler.go12
-rw-r--r--p2p/discover/udp.go2
-rw-r--r--p2p/message.go3
-rw-r--r--p2p/peer_error.go2
-rw-r--r--rlp/decode.go332
-rw-r--r--rlp/decode_test.go287
-rw-r--r--rlp/encode.go8
-rw-r--r--rlp/typecache.go51
-rw-r--r--whisper/envelope.go9
-rw-r--r--whisper/peer.go2
13 files changed, 481 insertions, 233 deletions
diff --git a/cmd/rlpdump/main.go b/cmd/rlpdump/main.go
index 8567dcff8..528ccc6bd 100644
--- a/cmd/rlpdump/main.go
+++ b/cmd/rlpdump/main.go
@@ -78,7 +78,7 @@ func main() {
os.Exit(2)
}
- s := rlp.NewStream(r)
+ s := rlp.NewStream(r, 0)
for {
if err := dump(s, 0); err != nil {
if err != io.EOF {
diff --git a/cmd/utils/cmd.go b/cmd/utils/cmd.go
index 7286f5c5e..64faf6ad1 100644
--- a/cmd/utils/cmd.go
+++ b/cmd/utils/cmd.go
@@ -154,7 +154,7 @@ func ImportChain(chainmgr *core.ChainManager, fn string) error {
defer fh.Close()
chainmgr.Reset()
- stream := rlp.NewStream(fh)
+ stream := rlp.NewStream(fh, 0)
var i, n int
batchSize := 2500
diff --git a/core/types/transaction.go b/core/types/transaction.go
index 6646bdf29..d8dcd7424 100644
--- a/core/types/transaction.go
+++ b/core/types/transaction.go
@@ -22,7 +22,7 @@ type Transaction struct {
AccountNonce uint64
Price *big.Int
GasLimit *big.Int
- Recipient *common.Address // nil means contract creation
+ Recipient *common.Address `rlp:"nil"` // nil means contract creation
Amount *big.Int
Payload []byte
V byte
diff --git a/eth/handler.go b/eth/handler.go
index 780ec3931..5c0660d84 100644
--- a/eth/handler.go
+++ b/eth/handler.go
@@ -197,7 +197,7 @@ func (self *ProtocolManager) handleMsg(p *peer) error {
// returns either requested hashes or nothing (i.e. not found)
return p.sendBlockHashes(hashes)
case BlockHashesMsg:
- msgStream := rlp.NewStream(msg.Payload)
+ msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size))
var hashes []common.Hash
if err := msgStream.Decode(&hashes); err != nil {
@@ -209,12 +209,12 @@ func (self *ProtocolManager) handleMsg(p *peer) error {
}
case GetBlocksMsg:
- msgStream := rlp.NewStream(msg.Payload)
+ var blocks []*types.Block
+
+ msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size))
if _, err := msgStream.List(); err != nil {
return err
}
-
- var blocks []*types.Block
var i int
for {
i++
@@ -236,9 +236,9 @@ func (self *ProtocolManager) handleMsg(p *peer) error {
}
return p.sendBlocks(blocks)
case BlocksMsg:
- msgStream := rlp.NewStream(msg.Payload)
-
var blocks []*types.Block
+
+ msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size))
if err := msgStream.Decode(&blocks); err != nil {
glog.V(logger.Detail).Infoln("Decode error", err)
blocks = nil
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
index 61a0abed9..07a1a739c 100644
--- a/p2p/discover/udp.go
+++ b/p2p/discover/udp.go
@@ -413,7 +413,7 @@ func decodePacket(buf []byte) (packet, NodeID, []byte, error) {
default:
return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype)
}
- err = rlp.Decode(bytes.NewReader(sigdata[1:]), req)
+ err = rlp.DecodeBytes(sigdata[1:], req)
return req, fromID, hash, err
}
diff --git a/p2p/message.go b/p2p/message.go
index b42acbe3c..be6405d6f 100644
--- a/p2p/message.go
+++ b/p2p/message.go
@@ -32,7 +32,8 @@ type Msg struct {
//
// For the decoding rules, please see package rlp.
func (msg Msg) Decode(val interface{}) error {
- if err := rlp.Decode(msg.Payload, val); err != nil {
+ s := rlp.NewStream(msg.Payload, uint64(msg.Size))
+ if err := s.Decode(val); err != nil {
return newPeerError(errInvalidMsg, "(code %x) (size %d) %v", msg.Code, msg.Size, err)
}
return nil
diff --git a/p2p/peer_error.go b/p2p/peer_error.go
index 402131630..a912f6064 100644
--- a/p2p/peer_error.go
+++ b/p2p/peer_error.go
@@ -57,7 +57,7 @@ func (self *peerError) Error() string {
return self.message
}
-type DiscReason byte
+type DiscReason uint
const (
DiscRequested DiscReason = iota
diff --git a/rlp/decode.go b/rlp/decode.go
index 3b5617475..6952ecaea 100644
--- a/rlp/decode.go
+++ b/rlp/decode.go
@@ -9,6 +9,7 @@ import (
"io"
"math/big"
"reflect"
+ "strings"
)
var (
@@ -35,25 +36,35 @@ type Decoder interface {
// If the type implements the Decoder interface, decode calls
// DecodeRLP.
//
-// To decode into a pointer, Decode will set the pointer to nil if the
-// input has size zero or the input is a single byte with value zero.
-// If the input has nonzero size, Decode will allocate a new value of
-// the type being pointed to.
+// To decode into a pointer, Decode will decode into the value pointed
+// to. If the pointer is nil, a new value of the pointer's element
+// type is allocated. If the pointer is non-nil, the existing value
+// will reused.
//
// To decode into a struct, Decode expects the input to be an RLP
// list. The decoded elements of the list are assigned to each public
-// field in the order given by the struct's definition. If the input
-// list has too few elements, no error is returned and the remaining
-// fields will have the zero value.
-// Recursive struct types are supported.
+// field in the order given by the struct's definition. The input list
+// must contain an element for each decoded field. Decode returns an
+// error if there are too few or too many elements.
+//
+// The decoding of struct fields honours one particular struct tag,
+// "nil". This tag applies to pointer-typed fields and changes the
+// decoding rules for the field such that input values of size zero
+// decode as a nil pointer. This tag can be useful when decoding recursive
+// types.
+//
+// type StructWithEmptyOK struct {
+// Foo *[20]byte `rlp:"nil"`
+// }
//
// To decode into a slice, the input must be a list and the resulting
-// slice will contain the input elements in order.
-// As a special case, if the slice has a byte-size element type, the input
-// can also be an RLP string.
+// slice will contain the input elements in order. For byte slices,
+// the input must be an RLP string. Array types decode similarly, with
+// the additional restriction that the number of input elements (or
+// bytes) must match the array's length.
//
// To decode into a Go string, the input must be an RLP string. The
-// bytes are taken as-is and will not necessarily be valid UTF-8.
+// input bytes are taken as-is and will not necessarily be valid UTF-8.
//
// To decode into an unsigned integer type, the input must also be an RLP
// string. The bytes are interpreted as a big endian representation of
@@ -64,20 +75,28 @@ type Decoder interface {
// To decode into an interface value, Decode stores one of these
// in the value:
//
-// []interface{}, for RLP lists
-// []byte, for RLP strings
+// []interface{}, for RLP lists
+// []byte, for RLP strings
//
// Non-empty interface types are not supported, nor are booleans,
// signed integers, floating point numbers, maps, channels and
// functions.
+//
+// Note that Decode does not set an input limit for all readers
+// and may be vulnerable to panics cause by huge value sizes. If
+// you need an input limit, use
+//
+// NewStream(r, limit).Decode(val)
func Decode(r io.Reader, val interface{}) error {
- return NewStream(r).Decode(val)
+ // TODO: this could use a Stream from a pool.
+ return NewStream(r, 0).Decode(val)
}
// DecodeBytes parses RLP data from b into val.
// Please see the documentation of Decode for the decoding rules.
func DecodeBytes(b []byte, val interface{}) error {
- return NewStream(bytes.NewReader(b)).Decode(val)
+ // TODO: this could use a Stream from a pool.
+ return NewStream(bytes.NewReader(b), uint64(len(b))).Decode(val)
}
type decodeError struct {
@@ -100,7 +119,9 @@ func (err *decodeError) Error() string {
func wrapStreamError(err error, typ reflect.Type) error {
switch err {
case ErrCanonInt:
- return &decodeError{msg: "canon int error appends zero's", typ: typ}
+ return &decodeError{msg: "non-canonical integer (leading zero bytes)", typ: typ}
+ case ErrCanonSize:
+ return &decodeError{msg: "non-canonical size information", typ: typ}
case ErrExpectedList:
return &decodeError{msg: "expected input list", typ: typ}
case ErrExpectedString:
@@ -125,7 +146,7 @@ var (
bigInt = reflect.TypeOf(big.Int{})
)
-func makeDecoder(typ reflect.Type) (dec decoder, err error) {
+func makeDecoder(typ reflect.Type, tags tags) (dec decoder, err error) {
kind := typ.Kind()
switch {
case typ.Implements(decoderInterface):
@@ -145,6 +166,9 @@ func makeDecoder(typ reflect.Type) (dec decoder, err error) {
case kind == reflect.Struct:
return makeStructDecoder(typ)
case kind == reflect.Ptr:
+ if tags.nilOK {
+ return makeOptionalPtrDecoder(typ)
+ }
return makePtrDecoder(typ)
case kind == reflect.Interface:
return decodeInterface, nil
@@ -186,12 +210,10 @@ func decodeBigInt(s *Stream, val reflect.Value) error {
i = new(big.Int)
val.Set(reflect.ValueOf(i))
}
-
- // Reject big integers which are zero appended
+ // Reject leading zero bytes
if len(b) > 0 && b[0] == 0 {
return wrapStreamError(ErrCanonInt, val.Type())
}
-
i.SetBytes(b)
return nil
}
@@ -205,7 +227,7 @@ func makeListDecoder(typ reflect.Type) (decoder, error) {
return decodeByteSlice, nil
}
}
- etypeinfo, err := cachedTypeInfo1(etype)
+ etypeinfo, err := cachedTypeInfo1(etype, tags{})
if err != nil {
return nil, err
}
@@ -259,19 +281,10 @@ func decodeListSlice(s *Stream, val reflect.Value, elemdec decoder) error {
}
func decodeListArray(s *Stream, val reflect.Value, elemdec decoder) error {
- size, err := s.List()
+ _, err := s.List()
if err != nil {
- return err
- }
- if size == 0 {
- zero(val, 0)
- return s.ListEnd()
+ return wrapStreamError(err, val.Type())
}
-
- // The approach here is stolen from package json, although we differ
- // in the semantics for arrays. package json discards remaining
- // elements that would not fit into the array. We generate an error in
- // this case because we'd be losing information.
vlen := val.Len()
i := 0
for ; i < vlen; i++ {
@@ -282,24 +295,18 @@ func decodeListArray(s *Stream, val reflect.Value, elemdec decoder) error {
}
}
if i < vlen {
- zero(val, i)
+ return &decodeError{msg: "input list has too few elements", typ: val.Type()}
}
return wrapStreamError(s.ListEnd(), val.Type())
}
func decodeByteSlice(s *Stream, val reflect.Value) error {
- kind, _, err := s.Kind()
- if err != nil {
- return err
- }
- if kind == List {
- return decodeListSlice(s, val, decodeUint)
- }
b, err := s.Bytes()
- if err == nil {
- val.SetBytes(b)
+ if err != nil {
+ return wrapStreamError(err, val.Type())
}
- return err
+ val.SetBytes(b)
+ return nil
}
func decodeByteArray(s *Stream, val reflect.Value) error {
@@ -307,42 +314,38 @@ func decodeByteArray(s *Stream, val reflect.Value) error {
if err != nil {
return err
}
+ vlen := val.Len()
switch kind {
case Byte:
- if val.Len() == 0 {
+ if vlen == 0 {
return &decodeError{msg: "input string too long", typ: val.Type()}
}
+ if vlen > 1 {
+ return &decodeError{msg: "input string too short", typ: val.Type()}
+ }
bv, _ := s.Uint()
val.Index(0).SetUint(bv)
- zero(val, 1)
case String:
- if uint64(val.Len()) < size {
+ if uint64(vlen) < size {
return &decodeError{msg: "input string too long", typ: val.Type()}
}
- slice := val.Slice(0, int(size)).Interface().([]byte)
+ if uint64(vlen) > size {
+ return &decodeError{msg: "input string too short", typ: val.Type()}
+ }
+ slice := val.Slice(0, vlen).Interface().([]byte)
if err := s.readFull(slice); err != nil {
return err
}
- zero(val, int(size))
+ // Reject cases where single byte encoding should have been used.
+ if size == 1 && slice[0] < 56 {
+ return wrapStreamError(ErrCanonSize, val.Type())
+ }
case List:
- return decodeListArray(s, val, decodeUint)
+ return wrapStreamError(ErrExpectedString, val.Type())
}
return nil
}
-func zero(val reflect.Value, start int) {
- z := reflect.Zero(val.Type().Elem())
- end := val.Len()
- for i := start; i < end; i++ {
- val.Index(i).Set(z)
- }
-}
-
-type field struct {
- index int
- info *typeinfo
-}
-
func makeStructDecoder(typ reflect.Type) (decoder, error) {
fields, err := structFields(typ)
if err != nil {
@@ -355,8 +358,7 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
for _, f := range fields {
err = f.info.decoder(s, val.Field(f.index))
if err == EOL {
- // too few elements. leave the rest at their zero value.
- break
+ return &decodeError{msg: "too few elements", typ: typ}
} else if err != nil {
return addErrorContext(err, "."+typ.Field(f.index).Name)
}
@@ -366,15 +368,41 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
return dec, nil
}
+// makePtrDecoder creates a decoder that decodes into
+// the pointer's element type.
func makePtrDecoder(typ reflect.Type) (decoder, error) {
etype := typ.Elem()
- etypeinfo, err := cachedTypeInfo1(etype)
+ etypeinfo, err := cachedTypeInfo1(etype, tags{})
if err != nil {
return nil, err
}
dec := func(s *Stream, val reflect.Value) (err error) {
- _, size, err := s.Kind()
- if err != nil || size == 0 && s.byteval == 0 {
+ newval := val
+ if val.IsNil() {
+ newval = reflect.New(etype)
+ }
+ if err = etypeinfo.decoder(s, newval.Elem()); err == nil {
+ val.Set(newval)
+ }
+ return err
+ }
+ return dec, nil
+}
+
+// makeOptionalPtrDecoder creates a decoder that decodes empty values
+// as nil. Non-empty values are decoded into a value of the element type,
+// just like makePtrDecoder does.
+//
+// This decoder is used for pointer-typed struct fields with struct tag "nil".
+func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) {
+ etype := typ.Elem()
+ etypeinfo, err := cachedTypeInfo1(etype, tags{})
+ if err != nil {
+ return nil, err
+ }
+ dec := func(s *Stream, val reflect.Value) (err error) {
+ kind, size, err := s.Kind()
+ if err != nil || size == 0 && kind != Byte {
// rearm s.Kind. This is important because the input
// position must advance to the next value even though
// we don't read anything.
@@ -465,15 +493,18 @@ var (
// has been reached during streaming.
EOL = errors.New("rlp: end of list")
- // Other errors
+ // Actual Errors
ErrExpectedString = errors.New("rlp: expected String or Byte")
ErrExpectedList = errors.New("rlp: expected List")
- ErrCanonInt = errors.New("rlp: expected Int")
+ ErrCanonInt = errors.New("rlp: non-canonical integer format")
+ ErrCanonSize = errors.New("rlp: non-canonical size information")
ErrElemTooLarge = errors.New("rlp: element is larger than containing list")
+ ErrValueTooLarge = errors.New("rlp: value size exceeds available input length")
// internal errors
- errNotInList = errors.New("rlp: call of ListEnd outside of any list")
- errNotAtEOL = errors.New("rlp: call of ListEnd not positioned at EOL")
+ errNotInList = errors.New("rlp: call of ListEnd outside of any list")
+ errNotAtEOL = errors.New("rlp: call of ListEnd not positioned at EOL")
+ errUintOverflow = errors.New("rlp: uint overflow")
)
// ByteReader must be implemented by any input reader for a Stream. It
@@ -496,23 +527,44 @@ type ByteReader interface {
//
// Stream is not safe for concurrent use.
type Stream struct {
- r ByteReader
+ r ByteReader
+
+ // number of bytes remaining to be read from r.
+ remaining uint64
+ limited bool
+
+ // auxiliary buffer for integer decoding
uintbuf []byte
kind Kind // kind of value ahead
size uint64 // size of value ahead
byteval byte // value of single byte in type tag
+ kinderr error // error from last readKind
stack []listpos
}
type listpos struct{ pos, size uint64 }
-// NewStream creates a new stream reading from r.
-// If r does not implement ByteReader, the Stream will
-// introduce its own buffering.
-func NewStream(r io.Reader) *Stream {
+// NewStream creates a new decoding stream reading from r.
+//
+// If r implements the ByteReader interface, Stream will
+// not introduce any buffering.
+//
+// For non-toplevel values, Stream returns ErrElemTooLarge
+// for values that do not fit into the enclosing list.
+//
+// Stream supports an optional input limit. If a limit is set, the
+// size of any toplevel value will be checked against the remaining
+// input length. Stream operations that encounter a value exceeding
+// the remaining input length will return ErrValueTooLarge. The limit
+// can be set by passing a non-zero value for inputLimit.
+//
+// If r is a bytes.Reader or strings.Reader, the input limit is set to
+// the length of r's underlying data unless an explicit limit is
+// provided.
+func NewStream(r io.Reader, inputLimit uint64) *Stream {
s := new(Stream)
- s.Reset(r)
+ s.Reset(r, inputLimit)
return s
}
@@ -520,7 +572,7 @@ func NewStream(r io.Reader) *Stream {
// at an encoded list of the given length.
func NewListStream(r io.Reader, len uint64) *Stream {
s := new(Stream)
- s.Reset(r)
+ s.Reset(r, len)
s.kind = List
s.size = len
return s
@@ -543,6 +595,9 @@ func (s *Stream) Bytes() ([]byte, error) {
if err = s.readFull(b); err != nil {
return nil, err
}
+ if size == 1 && b[0] < 56 {
+ return nil, ErrCanonSize
+ }
return b, nil
default:
return nil, ErrExpectedString
@@ -574,8 +629,6 @@ func (s *Stream) Raw() ([]byte, error) {
return buf, nil
}
-var errUintOverflow = errors.New("rlp: uint overflow")
-
// Uint reads an RLP string of up to 8 bytes and returns its contents
// as an unsigned integer. If the input does not contain an RLP string, the
// returned error will be ErrExpectedString.
@@ -590,13 +643,27 @@ func (s *Stream) uint(maxbits int) (uint64, error) {
}
switch kind {
case Byte:
+ if s.byteval == 0 {
+ return 0, ErrCanonInt
+ }
s.kind = -1 // rearm Kind
return uint64(s.byteval), nil
case String:
if size > uint64(maxbits/8) {
return 0, errUintOverflow
}
- return s.readUint(byte(size))
+ v, err := s.readUint(byte(size))
+ switch {
+ case err == ErrCanonSize:
+ // Adjust error because we're not reading a size right now.
+ return 0, ErrCanonInt
+ case err != nil:
+ return 0, err
+ case size > 0 && v < 56:
+ return 0, ErrCanonSize
+ default:
+ return v, nil
+ }
default:
return 0, ErrExpectedString
}
@@ -653,7 +720,7 @@ func (s *Stream) Decode(val interface{}) error {
if rval.IsNil() {
return errDecodeIntoNil
}
- info, err := cachedTypeInfo(rtyp.Elem())
+ info, err := cachedTypeInfo(rtyp.Elem(), tags{})
if err != nil {
return err
}
@@ -667,17 +734,40 @@ func (s *Stream) Decode(val interface{}) error {
}
// Reset discards any information about the current decoding context
-// and starts reading from r. If r does not also implement ByteReader,
-// Stream will do its own buffering.
-func (s *Stream) Reset(r io.Reader) {
+// and starts reading from r. This method is meant to facilitate reuse
+// of a preallocated Stream across many decoding operations.
+//
+// If r does not also implement ByteReader, Stream will do its own
+// buffering.
+func (s *Stream) Reset(r io.Reader, inputLimit uint64) {
+ if inputLimit > 0 {
+ s.remaining = inputLimit
+ s.limited = true
+ } else {
+ // Attempt to automatically discover
+ // the limit when reading from a byte slice.
+ switch br := r.(type) {
+ case *bytes.Reader:
+ s.remaining = uint64(br.Len())
+ s.limited = true
+ case *strings.Reader:
+ s.remaining = uint64(br.Len())
+ s.limited = true
+ default:
+ s.limited = false
+ }
+ }
+ // Wrap r with a buffer if it doesn't have one.
bufr, ok := r.(ByteReader)
if !ok {
bufr = bufio.NewReader(r)
}
s.r = bufr
+ // Reset the decoding context.
s.stack = s.stack[:0]
s.size = 0
s.kind = -1
+ s.kinderr = nil
if s.uintbuf == nil {
s.uintbuf = make([]byte, 8)
}
@@ -700,19 +790,31 @@ func (s *Stream) Kind() (kind Kind, size uint64, err error) {
tos = &s.stack[len(s.stack)-1]
}
if s.kind < 0 {
+ s.kinderr = nil
+ // Don't read further if we're at the end of the
+ // innermost list.
if tos != nil && tos.pos == tos.size {
return 0, 0, EOL
}
- kind, size, err = s.readKind()
- if err != nil {
- return 0, 0, err
+ s.kind, s.size, s.kinderr = s.readKind()
+ if s.kinderr == nil {
+ if tos == nil {
+ // At toplevel, check that the value is smaller
+ // than the remaining input length.
+ if s.limited && s.size > s.remaining {
+ s.kinderr = ErrValueTooLarge
+ }
+ } else {
+ // Inside a list, check that the value doesn't overflow the list.
+ if s.size > tos.size-tos.pos {
+ s.kinderr = ErrElemTooLarge
+ }
+ }
}
- s.kind, s.size = kind, size
- }
- if tos != nil && tos.pos+s.size > tos.size {
- return 0, 0, ErrElemTooLarge
}
- return s.kind, s.size, nil
+ // Note: this might return a sticky error generated
+ // by an earlier call to readKind.
+ return s.kind, s.size, s.kinderr
}
func (s *Stream) readKind() (kind Kind, size uint64, err error) {
@@ -741,6 +843,9 @@ func (s *Stream) readKind() (kind Kind, size uint64, err error) {
// would be encoded as 0xB90400 followed by the string. The range of
// the first byte is thus [0xB8, 0xBF].
size, err = s.readUint(b - 0xB7)
+ if err == nil && size < 56 {
+ err = ErrCanonSize
+ }
return String, size, err
case b < 0xF8:
// If the total payload of a list
@@ -757,27 +862,46 @@ func (s *Stream) readKind() (kind Kind, size uint64, err error) {
// the concatenation of the RLP encodings of the items. The
// range of the first byte is thus [0xF8, 0xFF].
size, err = s.readUint(b - 0xF7)
+ if err == nil && size < 56 {
+ err = ErrCanonSize
+ }
return List, size, err
}
}
func (s *Stream) readUint(size byte) (uint64, error) {
- if size == 1 {
+ switch size {
+ case 0:
+ s.kind = -1 // rearm Kind
+ return 0, nil
+ case 1:
b, err := s.readByte()
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return uint64(b), err
+ default:
+ start := int(8 - size)
+ for i := 0; i < start; i++ {
+ s.uintbuf[i] = 0
+ }
+ if err := s.readFull(s.uintbuf[start:]); err != nil {
+ return 0, err
+ }
+ if s.uintbuf[start] == 0 {
+ // Note: readUint is also used to decode integer
+ // values. The error needs to be adjusted to become
+ // ErrCanonInt in this case.
+ return 0, ErrCanonSize
+ }
+ return binary.BigEndian.Uint64(s.uintbuf), nil
}
- start := int(8 - size)
- for i := 0; i < start; i++ {
- s.uintbuf[i] = 0
- }
- err := s.readFull(s.uintbuf[start:])
- return binary.BigEndian.Uint64(s.uintbuf), err
}
func (s *Stream) readFull(buf []byte) (err error) {
+ if s.limited && s.remaining < uint64(len(buf)) {
+ return ErrValueTooLarge
+ }
s.willRead(uint64(len(buf)))
var nn, n int
for n < len(buf) && err == nil {
@@ -791,6 +915,9 @@ func (s *Stream) readFull(buf []byte) (err error) {
}
func (s *Stream) readByte() (byte, error) {
+ if s.limited && s.remaining == 0 {
+ return 0, io.EOF
+ }
s.willRead(1)
b, err := s.r.ReadByte()
if len(s.stack) > 0 && err == io.EOF {
@@ -801,6 +928,9 @@ func (s *Stream) readByte() (byte, error) {
func (s *Stream) willRead(n uint64) {
s.kind = -1 // rearm Kind
+ if s.limited {
+ s.remaining -= n
+ }
if len(s.stack) > 0 {
s.stack[len(s.stack)-1].pos += n
}
diff --git a/rlp/decode_test.go b/rlp/decode_test.go
index 73a31c67f..d07520bd0 100644
--- a/rlp/decode_test.go
+++ b/rlp/decode_test.go
@@ -21,22 +21,18 @@ func TestStreamKind(t *testing.T) {
{"7F", Byte, 0},
{"80", String, 0},
{"B7", String, 55},
- {"B800", String, 0},
{"B90400", String, 1024},
- {"BA000400", String, 1024},
- {"BB00000400", String, 1024},
{"BFFFFFFFFFFFFFFFFF", String, ^uint64(0)},
{"C0", List, 0},
{"C8", List, 8},
{"F7", List, 55},
- {"F800", List, 0},
- {"F804", List, 4},
{"F90400", List, 1024},
{"FFFFFFFFFFFFFFFFFF", List, ^uint64(0)},
}
for i, test := range tests {
- s := NewStream(bytes.NewReader(unhex(test.input)))
+ // using plainReader to inhibit input limit errors.
+ s := NewStream(newPlainReader(unhex(test.input)), 0)
kind, len, err := s.Kind()
if err != nil {
t.Errorf("test %d: Kind returned error: %v", i, err)
@@ -70,29 +66,85 @@ func TestNewListStream(t *testing.T) {
}
func TestStreamErrors(t *testing.T) {
+ withoutInputLimit := func(b []byte) *Stream {
+ return NewStream(newPlainReader(b), 0)
+ }
+ withCustomInputLimit := func(limit uint64) func([]byte) *Stream {
+ return func(b []byte) *Stream {
+ return NewStream(bytes.NewReader(b), limit)
+ }
+ }
+
type calls []string
tests := []struct {
string
calls
- error
+ newStream func([]byte) *Stream // uses bytes.Reader if nil
+ error error
}{
- {"", calls{"Kind"}, io.EOF},
- {"", calls{"List"}, io.EOF},
- {"", calls{"Uint"}, io.EOF},
- {"C0", calls{"Bytes"}, ErrExpectedString},
- {"C0", calls{"Uint"}, ErrExpectedString},
- {"81", calls{"Bytes"}, io.ErrUnexpectedEOF},
- {"81", calls{"Uint"}, io.ErrUnexpectedEOF},
- {"BFFFFFFFFFFFFFFF", calls{"Bytes"}, io.ErrUnexpectedEOF},
- {"89000000000000000001", calls{"Uint"}, errUintOverflow},
- {"00", calls{"List"}, ErrExpectedList},
- {"80", calls{"List"}, ErrExpectedList},
- {"C0", calls{"List", "Uint"}, EOL},
- {"C801", calls{"List", "Uint", "Uint"}, io.ErrUnexpectedEOF},
- {"C8C9", calls{"List", "Kind"}, ErrElemTooLarge},
- {"C3C2010201", calls{"List", "List", "Uint", "Uint", "ListEnd", "Uint"}, EOL},
- {"00", calls{"ListEnd"}, errNotInList},
- {"C40102", calls{"List", "Uint", "ListEnd"}, errNotAtEOL},
+ {"C0", calls{"Bytes"}, nil, ErrExpectedString},
+ {"C0", calls{"Uint"}, nil, ErrExpectedString},
+ {"89000000000000000001", calls{"Uint"}, nil, errUintOverflow},
+ {"00", calls{"List"}, nil, ErrExpectedList},
+ {"80", calls{"List"}, nil, ErrExpectedList},
+ {"C0", calls{"List", "Uint"}, nil, EOL},
+ {"C8C9010101010101010101", calls{"List", "Kind"}, nil, ErrElemTooLarge},
+ {"C3C2010201", calls{"List", "List", "Uint", "Uint", "ListEnd", "Uint"}, nil, EOL},
+ {"00", calls{"ListEnd"}, nil, errNotInList},
+ {"C401020304", calls{"List", "Uint", "ListEnd"}, nil, errNotAtEOL},
+
+ // Non-canonical integers (e.g. leading zero bytes).
+ {"00", calls{"Uint"}, nil, ErrCanonInt},
+ {"820002", calls{"Uint"}, nil, ErrCanonInt},
+ {"8133", calls{"Uint"}, nil, ErrCanonSize},
+ {"8156", calls{"Uint"}, nil, nil},
+
+ // Size tags must use the smallest possible encoding.
+ // Leading zero bytes in the size tag are also rejected.
+ {"8100", calls{"Uint"}, nil, ErrCanonSize},
+ {"8100", calls{"Bytes"}, nil, ErrCanonSize},
+ {"B800", calls{"Kind"}, withoutInputLimit, ErrCanonSize},
+ {"B90000", calls{"Kind"}, withoutInputLimit, ErrCanonSize},
+ {"B90055", calls{"Kind"}, withoutInputLimit, ErrCanonSize},
+ {"BA0002FFFF", calls{"Bytes"}, withoutInputLimit, ErrCanonSize},
+ {"F800", calls{"Kind"}, withoutInputLimit, ErrCanonSize},
+ {"F90000", calls{"Kind"}, withoutInputLimit, ErrCanonSize},
+ {"F90055", calls{"Kind"}, withoutInputLimit, ErrCanonSize},
+ {"FA0002FFFF", calls{"List"}, withoutInputLimit, ErrCanonSize},
+
+ // Expected EOF
+ {"", calls{"Kind"}, nil, io.EOF},
+ {"", calls{"Uint"}, nil, io.EOF},
+ {"", calls{"List"}, nil, io.EOF},
+ {"8158", calls{"Uint", "Uint"}, nil, io.EOF},
+ {"C0", calls{"List", "ListEnd", "List"}, nil, io.EOF},
+
+ // Input limit errors.
+ {"81", calls{"Bytes"}, nil, ErrValueTooLarge},
+ {"81", calls{"Uint"}, nil, ErrValueTooLarge},
+ {"81", calls{"Raw"}, nil, ErrValueTooLarge},
+ {"BFFFFFFFFFFFFFFFFFFF", calls{"Bytes"}, nil, ErrValueTooLarge},
+ {"C801", calls{"List"}, nil, ErrValueTooLarge},
+
+ // Test for list element size check overflow.
+ {"CD04040404FFFFFFFFFFFFFFFFFF0303", calls{"List", "Uint", "Uint", "Uint", "Uint", "List"}, nil, ErrElemTooLarge},
+
+ // Test for input limit overflow. Since we are counting the limit
+ // down toward zero in Stream.remaining, reading too far can overflow
+ // remaining to a large value, effectively disabling the limit.
+ {"C40102030401", calls{"Raw", "Uint"}, withCustomInputLimit(5), io.EOF},
+ {"C4010203048158", calls{"Raw", "Uint"}, withCustomInputLimit(6), ErrValueTooLarge},
+
+ // Check that the same calls are fine without a limit.
+ {"C40102030401", calls{"Raw", "Uint"}, withoutInputLimit, nil},
+ {"C4010203048158", calls{"Raw", "Uint"}, withoutInputLimit, nil},
+
+ // Unexpected EOF. This only happens when there is
+ // no input limit, so the reader needs to be 'dumbed down'.
+ {"81", calls{"Bytes"}, withoutInputLimit, io.ErrUnexpectedEOF},
+ {"81", calls{"Uint"}, withoutInputLimit, io.ErrUnexpectedEOF},
+ {"BFFFFFFFFFFFFFFF", calls{"Bytes"}, withoutInputLimit, io.ErrUnexpectedEOF},
+ {"C801", calls{"List", "Uint", "Uint"}, withoutInputLimit, io.ErrUnexpectedEOF},
// This test verifies that the input position is advanced
// correctly when calling Bytes for empty strings. Kind can be called
@@ -109,12 +161,15 @@ func TestStreamErrors(t *testing.T) {
"Bytes", // past final element
"Bytes", // this one should fail
- }, EOL},
+ }, nil, EOL},
}
testfor:
for i, test := range tests {
- s := NewStream(bytes.NewReader(unhex(test.string)))
+ if test.newStream == nil {
+ test.newStream = func(b []byte) *Stream { return NewStream(bytes.NewReader(b), 0) }
+ }
+ s := test.newStream(unhex(test.string))
rs := reflect.ValueOf(s)
for j, call := range test.calls {
fval := rs.MethodByName(call)
@@ -124,11 +179,17 @@ testfor:
err = lastret.(error).Error()
}
if j == len(test.calls)-1 {
- if err != test.error.Error() {
- t.Errorf("test %d: last call (%s) error mismatch\ngot: %s\nwant: %v",
+ want := "<nil>"
+ if test.error != nil {
+ want = test.error.Error()
+ }
+ if err != want {
+ t.Log(test)
+ t.Errorf("test %d: last call (%s) error mismatch\ngot: %s\nwant: %s",
i, call, err, test.error)
}
} else if err != "<nil>" {
+ t.Log(test)
t.Errorf("test %d: call %d (%s) unexpected error: %q", i, j, call, err)
continue testfor
}
@@ -137,7 +198,7 @@ testfor:
}
func TestStreamList(t *testing.T) {
- s := NewStream(bytes.NewReader(unhex("C80102030405060708")))
+ s := NewStream(bytes.NewReader(unhex("C80102030405060708")), 0)
len, err := s.List()
if err != nil {
@@ -166,7 +227,7 @@ func TestStreamList(t *testing.T) {
}
func TestStreamRaw(t *testing.T) {
- s := NewStream(bytes.NewReader(unhex("C58401010101")))
+ s := NewStream(bytes.NewReader(unhex("C58401010101")), 0)
s.List()
want := unhex("8401010101")
@@ -219,7 +280,7 @@ type simplestruct struct {
type recstruct struct {
I uint
- Child *recstruct
+ Child *recstruct `rlp:"nil"`
}
var (
@@ -229,78 +290,58 @@ var (
)
)
-var (
- sharedByteArray [5]byte
- sharedPtr = new(*uint)
-)
-
var decodeTests = []decodeTest{
// integers
{input: "05", ptr: new(uint32), value: uint32(5)},
{input: "80", ptr: new(uint32), value: uint32(0)},
- {input: "8105", ptr: new(uint32), value: uint32(5)},
{input: "820505", ptr: new(uint32), value: uint32(0x0505)},
{input: "83050505", ptr: new(uint32), value: uint32(0x050505)},
{input: "8405050505", ptr: new(uint32), value: uint32(0x05050505)},
{input: "850505050505", ptr: new(uint32), error: "rlp: input string too long for uint32"},
{input: "C0", ptr: new(uint32), error: "rlp: expected input string or byte for uint32"},
+ {input: "00", ptr: new(uint32), error: "rlp: non-canonical integer (leading zero bytes) for uint32"},
+ {input: "8105", ptr: new(uint32), error: "rlp: non-canonical size information for uint32"},
+ {input: "820004", ptr: new(uint32), error: "rlp: non-canonical integer (leading zero bytes) for uint32"},
+ {input: "B8020004", ptr: new(uint32), error: "rlp: non-canonical size information for uint32"},
// slices
{input: "C0", ptr: new([]uint), value: []uint{}},
{input: "C80102030405060708", ptr: new([]uint), value: []uint{1, 2, 3, 4, 5, 6, 7, 8}},
+ {input: "F8020004", ptr: new([]uint), error: "rlp: non-canonical size information for []uint"},
// arrays
- {input: "C0", ptr: new([5]uint), value: [5]uint{}},
{input: "C50102030405", ptr: new([5]uint), value: [5]uint{1, 2, 3, 4, 5}},
+ {input: "C0", ptr: new([5]uint), error: "rlp: input list has too few elements for [5]uint"},
+ {input: "C102", ptr: new([5]uint), error: "rlp: input list has too few elements for [5]uint"},
{input: "C6010203040506", ptr: new([5]uint), error: "rlp: input list has too many elements for [5]uint"},
+ {input: "F8020004", ptr: new([5]uint), error: "rlp: non-canonical size information for [5]uint"},
+
+ // zero sized arrays
+ {input: "C0", ptr: new([0]uint), value: [0]uint{}},
+ {input: "C101", ptr: new([0]uint), error: "rlp: input list has too many elements for [0]uint"},
// byte slices
{input: "01", ptr: new([]byte), value: []byte{1}},
{input: "80", ptr: new([]byte), value: []byte{}},
{input: "8D6162636465666768696A6B6C6D", ptr: new([]byte), value: []byte("abcdefghijklm")},
- {input: "C0", ptr: new([]byte), value: []byte{}},
- {input: "C3010203", ptr: new([]byte), value: []byte{1, 2, 3}},
-
- {
- input: "C3820102",
- ptr: new([]byte),
- error: "rlp: input string too long for uint8, decoding into ([]uint8)[0]",
- },
+ {input: "C0", ptr: new([]byte), error: "rlp: expected input string or byte for []uint8"},
+ {input: "8105", ptr: new([]byte), error: "rlp: non-canonical size information for []uint8"},
// byte arrays
- {input: "01", ptr: new([5]byte), value: [5]byte{1}},
- {input: "80", ptr: new([5]byte), value: [5]byte{}},
+ {input: "02", ptr: new([1]byte), value: [1]byte{2}},
{input: "850102030405", ptr: new([5]byte), value: [5]byte{1, 2, 3, 4, 5}},
- {input: "C0", ptr: new([5]byte), value: [5]byte{}},
- {input: "C3010203", ptr: new([5]byte), value: [5]byte{1, 2, 3, 0, 0}},
- {
- input: "C3820102",
- ptr: new([5]byte),
- error: "rlp: input string too long for uint8, decoding into ([5]uint8)[0]",
- },
- {
- input: "86010203040506",
- ptr: new([5]byte),
- error: "rlp: input string too long for [5]uint8",
- },
- {
- input: "850101",
- ptr: new([5]byte),
- error: io.ErrUnexpectedEOF.Error(),
- },
-
- // byte array reuse (should be zeroed)
- {input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}},
- {input: "8101", ptr: &sharedByteArray, value: [5]byte{1}}, // kind: String
- {input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}},
- {input: "01", ptr: &sharedByteArray, value: [5]byte{1}}, // kind: Byte
- {input: "C3010203", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 0, 0}},
- {input: "C101", ptr: &sharedByteArray, value: [5]byte{1}}, // kind: List
+ // byte array errors
+ {input: "02", ptr: new([5]byte), error: "rlp: input string too short for [5]uint8"},
+ {input: "80", ptr: new([5]byte), error: "rlp: input string too short for [5]uint8"},
+ {input: "820000", ptr: new([5]byte), error: "rlp: input string too short for [5]uint8"},
+ {input: "C0", ptr: new([5]byte), error: "rlp: expected input string or byte for [5]uint8"},
+ {input: "C3010203", ptr: new([5]byte), error: "rlp: expected input string or byte for [5]uint8"},
+ {input: "86010203040506", ptr: new([5]byte), error: "rlp: input string too long for [5]uint8"},
+ {input: "8105", ptr: new([1]byte), error: "rlp: non-canonical size information for [1]uint8"},
// zero sized byte arrays
{input: "80", ptr: new([0]byte), value: [0]byte{}},
- {input: "C0", ptr: new([0]byte), value: [0]byte{}},
{input: "01", ptr: new([0]byte), error: "rlp: input string too long for [0]uint8"},
{input: "8101", ptr: new([0]byte), error: "rlp: input string too long for [0]uint8"},
@@ -312,20 +353,44 @@ var decodeTests = []decodeTest{
// big ints
{input: "01", ptr: new(*big.Int), value: big.NewInt(1)},
{input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*big.Int), value: veryBigInt},
- {input: "820001", ptr: new(big.Int), error: "rlp: canon int error appends zero's for *big.Int"},
{input: "10", ptr: new(big.Int), value: *big.NewInt(16)}, // non-pointer also works
{input: "C0", ptr: new(*big.Int), error: "rlp: expected input string or byte for *big.Int"},
+ {input: "820001", ptr: new(big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"},
+ {input: "8105", ptr: new(big.Int), error: "rlp: non-canonical size information for *big.Int"},
// structs
- {input: "C0", ptr: new(simplestruct), value: simplestruct{0, ""}},
- {input: "C105", ptr: new(simplestruct), value: simplestruct{5, ""}},
- {input: "C50583343434", ptr: new(simplestruct), value: simplestruct{5, "444"}},
{
- input: "C501C302C103",
+ input: "C50583343434",
+ ptr: new(simplestruct),
+ value: simplestruct{5, "444"},
+ },
+ {
+ input: "C601C402C203C0",
ptr: new(recstruct),
value: recstruct{1, &recstruct{2, &recstruct{3, nil}}},
},
+ // struct errors
+ {
+ input: "C0",
+ ptr: new(simplestruct),
+ error: "rlp: too few elements for rlp.simplestruct",
+ },
+ {
+ input: "C105",
+ ptr: new(simplestruct),
+ error: "rlp: too few elements for rlp.simplestruct",
+ },
+ {
+ input: "C7C50583343434C0",
+ ptr: new([]*simplestruct),
+ error: "rlp: too few elements for rlp.simplestruct, decoding into ([]*rlp.simplestruct)[1]",
+ },
+ {
+ input: "83222222",
+ ptr: new(simplestruct),
+ error: "rlp: expected input list for rlp.simplestruct",
+ },
{
input: "C3010101",
ptr: new(simplestruct),
@@ -338,20 +403,16 @@ var decodeTests = []decodeTest{
},
// pointers
- {input: "00", ptr: new(*uint), value: (*uint)(nil)},
- {input: "80", ptr: new(*uint), value: (*uint)(nil)},
- {input: "C0", ptr: new(*uint), value: (*uint)(nil)},
+ {input: "00", ptr: new(*[]byte), value: &[]byte{0}},
+ {input: "80", ptr: new(*uint), value: uintp(0)},
+ {input: "C0", ptr: new(*uint), error: "rlp: expected input string or byte for uint"},
{input: "07", ptr: new(*uint), value: uintp(7)},
- {input: "8108", ptr: new(*uint), value: uintp(8)},
+ {input: "8158", ptr: new(*uint), value: uintp(0x58)},
{input: "C109", ptr: new(*[]uint), value: &[]uint{9}},
{input: "C58403030303", ptr: new(*[][]byte), value: &[][]byte{{3, 3, 3, 3}}},
// check that input position is advanced also for empty values.
- {input: "C3808005", ptr: new([]*uint), value: []*uint{nil, nil, uintp(5)}},
-
- // pointer should be reset to nil
- {input: "05", ptr: sharedPtr, value: uintp(5)},
- {input: "80", ptr: sharedPtr, value: (*uint)(nil)},
+ {input: "C3808005", ptr: new([]*uint), value: []*uint{uintp(0), uintp(0), uintp(5)}},
// interface{}
{input: "00", ptr: new(interface{}), value: []byte{0}},
@@ -401,11 +462,17 @@ func TestDecodeWithByteReader(t *testing.T) {
})
}
-// dumbReader reads from a byte slice but does not
-// implement ReadByte.
-type dumbReader []byte
+// plainReader reads from a byte slice but does not
+// implement ReadByte. It is also not recognized by the
+// size validation. This is useful to test how the decoder
+// behaves on a non-buffered input stream.
+type plainReader []byte
+
+func newPlainReader(b []byte) io.Reader {
+ return (*plainReader)(&b)
+}
-func (r *dumbReader) Read(buf []byte) (n int, err error) {
+func (r *plainReader) Read(buf []byte) (n int, err error) {
if len(*r) == 0 {
return 0, io.EOF
}
@@ -416,15 +483,14 @@ func (r *dumbReader) Read(buf []byte) (n int, err error) {
func TestDecodeWithNonByteReader(t *testing.T) {
runTests(t, func(input []byte, into interface{}) error {
- r := dumbReader(input)
- return Decode(&r, into)
+ return Decode(newPlainReader(input), into)
})
}
func TestDecodeStreamReset(t *testing.T) {
- s := NewStream(nil)
+ s := NewStream(nil, 0)
runTests(t, func(input []byte, into interface{}) error {
- s.Reset(bytes.NewReader(input))
+ s.Reset(bytes.NewReader(input), 0)
return s.Decode(into)
})
}
@@ -516,9 +582,36 @@ func ExampleDecode() {
// Decoded value: rlp.example{A:0xa, B:0x14, private:0x0, String:"foobar"}
}
+func ExampleDecode_structTagNil() {
+ // In this example, we'll use the "nil" struct tag to change
+ // how a pointer-typed field is decoded. The input contains an RLP
+ // list of one element, an empty string.
+ input := []byte{0xC1, 0x80}
+
+ // This type uses the normal rules.
+ // The empty input string is decoded as a pointer to an empty Go string.
+ var normalRules struct {
+ String *string
+ }
+ Decode(bytes.NewReader(input), &normalRules)
+ fmt.Printf("normal: String = %q\n", *normalRules.String)
+
+ // This type uses the struct tag.
+ // The empty input string is decoded as a nil pointer.
+ var withEmptyOK struct {
+ String *string `rlp:"nil"`
+ }
+ Decode(bytes.NewReader(input), &withEmptyOK)
+ fmt.Printf("with nil tag: String = %v\n", withEmptyOK.String)
+
+ // Output:
+ // normal: String = ""
+ // with nil tag: String = <nil>
+}
+
func ExampleStream() {
input, _ := hex.DecodeString("C90A1486666F6F626172")
- s := NewStream(bytes.NewReader(input))
+ s := NewStream(bytes.NewReader(input), 0)
// Check what kind of value lies ahead
kind, size, _ := s.Kind()
diff --git a/rlp/encode.go b/rlp/encode.go
index 6cf6776d6..10ff0ae79 100644
--- a/rlp/encode.go
+++ b/rlp/encode.go
@@ -194,7 +194,7 @@ func (w *encbuf) Write(b []byte) (int, error) {
func (w *encbuf) encode(val interface{}) error {
rval := reflect.ValueOf(val)
- ti, err := cachedTypeInfo(rval.Type())
+ ti, err := cachedTypeInfo(rval.Type(), tags{})
if err != nil {
return err
}
@@ -485,7 +485,7 @@ func writeInterface(val reflect.Value, w *encbuf) error {
return nil
}
eval := val.Elem()
- ti, err := cachedTypeInfo(eval.Type())
+ ti, err := cachedTypeInfo(eval.Type(), tags{})
if err != nil {
return err
}
@@ -493,7 +493,7 @@ func writeInterface(val reflect.Value, w *encbuf) error {
}
func makeSliceWriter(typ reflect.Type) (writer, error) {
- etypeinfo, err := cachedTypeInfo1(typ.Elem())
+ etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{})
if err != nil {
return nil, err
}
@@ -530,7 +530,7 @@ func makeStructWriter(typ reflect.Type) (writer, error) {
}
func makePtrWriter(typ reflect.Type) (writer, error) {
- etypeinfo, err := cachedTypeInfo1(typ.Elem())
+ etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{})
if err != nil {
return nil, err
}
diff --git a/rlp/typecache.go b/rlp/typecache.go
index 398f25d90..d512012e9 100644
--- a/rlp/typecache.go
+++ b/rlp/typecache.go
@@ -7,7 +7,7 @@ import (
var (
typeCacheMutex sync.RWMutex
- typeCache = make(map[reflect.Type]*typeinfo)
+ typeCache = make(map[typekey]*typeinfo)
)
type typeinfo struct {
@@ -15,13 +15,25 @@ type typeinfo struct {
writer
}
+// represents struct tags
+type tags struct {
+ nilOK bool
+}
+
+type typekey struct {
+ reflect.Type
+ // the key must include the struct tags because they
+ // might generate a different decoder.
+ tags
+}
+
type decoder func(*Stream, reflect.Value) error
type writer func(reflect.Value, *encbuf) error
-func cachedTypeInfo(typ reflect.Type) (*typeinfo, error) {
+func cachedTypeInfo(typ reflect.Type, tags tags) (*typeinfo, error) {
typeCacheMutex.RLock()
- info := typeCache[typ]
+ info := typeCache[typekey{typ, tags}]
typeCacheMutex.RUnlock()
if info != nil {
return info, nil
@@ -29,11 +41,12 @@ func cachedTypeInfo(typ reflect.Type) (*typeinfo, error) {
// not in the cache, need to generate info for this type.
typeCacheMutex.Lock()
defer typeCacheMutex.Unlock()
- return cachedTypeInfo1(typ)
+ return cachedTypeInfo1(typ, tags)
}
-func cachedTypeInfo1(typ reflect.Type) (*typeinfo, error) {
- info := typeCache[typ]
+func cachedTypeInfo1(typ reflect.Type, tags tags) (*typeinfo, error) {
+ key := typekey{typ, tags}
+ info := typeCache[key]
if info != nil {
// another goroutine got the write lock first
return info, nil
@@ -41,21 +54,27 @@ func cachedTypeInfo1(typ reflect.Type) (*typeinfo, error) {
// put a dummmy value into the cache before generating.
// if the generator tries to lookup itself, it will get
// the dummy value and won't call itself recursively.
- typeCache[typ] = new(typeinfo)
- info, err := genTypeInfo(typ)
+ typeCache[key] = new(typeinfo)
+ info, err := genTypeInfo(typ, tags)
if err != nil {
// remove the dummy value if the generator fails
- delete(typeCache, typ)
+ delete(typeCache, key)
return nil, err
}
- *typeCache[typ] = *info
- return typeCache[typ], err
+ *typeCache[key] = *info
+ return typeCache[key], err
+}
+
+type field struct {
+ index int
+ info *typeinfo
}
func structFields(typ reflect.Type) (fields []field, err error) {
for i := 0; i < typ.NumField(); i++ {
if f := typ.Field(i); f.PkgPath == "" { // exported
- info, err := cachedTypeInfo1(f.Type)
+ tags := parseStructTag(f.Tag.Get("rlp"))
+ info, err := cachedTypeInfo1(f.Type, tags)
if err != nil {
return nil, err
}
@@ -65,9 +84,13 @@ func structFields(typ reflect.Type) (fields []field, err error) {
return fields, nil
}
-func genTypeInfo(typ reflect.Type) (info *typeinfo, err error) {
+func parseStructTag(tag string) tags {
+ return tags{nilOK: tag == "nil"}
+}
+
+func genTypeInfo(typ reflect.Type, tags tags) (info *typeinfo, err error) {
info = new(typeinfo)
- if info.decoder, err = makeDecoder(typ); err != nil {
+ if info.decoder, err = makeDecoder(typ, tags); err != nil {
return nil, err
}
if info.writer, err = makeWriter(typ); err != nil {
diff --git a/whisper/envelope.go b/whisper/envelope.go
index 0a817e26e..07762c300 100644
--- a/whisper/envelope.go
+++ b/whisper/envelope.go
@@ -109,16 +109,17 @@ func (self *Envelope) Hash() common.Hash {
return self.hash
}
-// rlpenv is an Envelope but is not an rlp.Decoder.
-// It is used for decoding because we need to
-type rlpenv Envelope
-
// DecodeRLP decodes an Envelope from an RLP data stream.
func (self *Envelope) DecodeRLP(s *rlp.Stream) error {
raw, err := s.Raw()
if err != nil {
return err
}
+ // The decoding of Envelope uses the struct fields but also needs
+ // to compute the hash of the whole RLP-encoded envelope. This
+ // type has the same structure as Envelope but is not an
+ // rlp.Decoder so we can reuse the Envelope struct definition.
+ type rlpenv Envelope
if err := rlp.DecodeBytes(raw, (*rlpenv)(self)); err != nil {
return err
}
diff --git a/whisper/peer.go b/whisper/peer.go
index e4301f37c..28abf4260 100644
--- a/whisper/peer.go
+++ b/whisper/peer.go
@@ -66,7 +66,7 @@ func (self *peer) handshake() error {
if packet.Code != statusCode {
return fmt.Errorf("peer sent %x before status packet", packet.Code)
}
- s := rlp.NewStream(packet.Payload)
+ s := rlp.NewStream(packet.Payload, uint64(packet.Size))
if _, err := s.List(); err != nil {
return fmt.Errorf("bad status message: %v", err)
}