aboutsummaryrefslogtreecommitdiffstats
path: root/rlp
diff options
context:
space:
mode:
Diffstat (limited to 'rlp')
-rw-r--r--rlp/decode.go39
-rw-r--r--rlp/decode_test.go36
-rw-r--r--rlp/encode.go58
-rw-r--r--rlp/encode_test.go7
4 files changed, 119 insertions, 21 deletions
diff --git a/rlp/decode.go b/rlp/decode.go
index 55f7187a3..0fde0a947 100644
--- a/rlp/decode.go
+++ b/rlp/decode.go
@@ -2,6 +2,7 @@ package rlp
import (
"bufio"
+ "bytes"
"encoding/binary"
"errors"
"fmt"
@@ -73,6 +74,12 @@ func Decode(r io.Reader, val interface{}) error {
return NewStream(r).Decode(val)
}
+// DecodeBytes parses RLP data from b into val.
+// Please see the documentation of Decode for the decoding rules.
+func DecodeBytes(b []byte, val interface{}) error {
+ return NewStream(bytes.NewReader(b)).Decode(val)
+}
+
type decodeError struct {
msg string
typ reflect.Type
@@ -360,7 +367,12 @@ func makePtrDecoder(typ reflect.Type) (decoder, error) {
dec := func(s *Stream, val reflect.Value) (err error) {
_, size, err := s.Kind()
if err != nil || size == 0 && s.byteval == 0 {
- val.Set(reflect.Zero(typ)) // set to nil
+ // rearm s.Kind. This is important because the input
+ // position must advance to the next value even though
+ // we don't read anything.
+ s.kind = -1
+ // set the pointer to nil.
+ val.Set(reflect.Zero(typ))
return err
}
newval := val
@@ -528,6 +540,31 @@ func (s *Stream) Bytes() ([]byte, error) {
}
}
+// Raw reads a raw encoded value including RLP type information.
+func (s *Stream) Raw() ([]byte, error) {
+ kind, size, err := s.Kind()
+ if err != nil {
+ return nil, err
+ }
+ if kind == Byte {
+ s.kind = -1 // rearm Kind
+ return []byte{s.byteval}, nil
+ }
+ // the original header has already been read and is no longer
+ // available. read content and put a new header in front of it.
+ start := headsize(size)
+ buf := make([]byte, uint64(start)+size)
+ if err := s.readFull(buf[start:]); err != nil {
+ return nil, err
+ }
+ if kind == String {
+ puthead(buf, 0x80, 0xB8, size)
+ } else {
+ puthead(buf, 0xC0, 0xF7, size)
+ }
+ return buf, nil
+}
+
var errUintOverflow = errors.New("rlp: uint overflow")
// Uint reads an RLP string of up to 8 bytes and returns its contents
diff --git a/rlp/decode_test.go b/rlp/decode_test.go
index 9f66840b1..a18ff1d08 100644
--- a/rlp/decode_test.go
+++ b/rlp/decode_test.go
@@ -39,7 +39,7 @@ func TestStreamKind(t *testing.T) {
s := NewStream(bytes.NewReader(unhex(test.input)))
kind, len, err := s.Kind()
if err != nil {
- t.Errorf("test %d: Type returned error: %v", i, err)
+ t.Errorf("test %d: Kind returned error: %v", i, err)
continue
}
if kind != test.wantKind {
@@ -93,6 +93,23 @@ func TestStreamErrors(t *testing.T) {
{"C3C2010201", calls{"List", "List", "Uint", "Uint", "ListEnd", "Uint"}, EOL},
{"00", calls{"ListEnd"}, errNotInList},
{"C40102", calls{"List", "Uint", "ListEnd"}, errNotAtEOL},
+
+ // This test verifies that the input position is advanced
+ // correctly when calling Bytes for empty strings. Kind can be called
+ // any number of times in between and doesn't advance.
+ {"C3808080", calls{
+ "List", // enter the list
+ "Bytes", // past first element
+
+ "Kind", "Kind", "Kind", // this shouldn't advance
+
+ "Bytes", // past second element
+
+ "Kind", "Kind", // can't hurt to try
+
+ "Bytes", // past final element
+ "Bytes", // this one should fail
+ }, EOL},
}
testfor:
@@ -148,6 +165,20 @@ func TestStreamList(t *testing.T) {
}
}
+func TestStreamRaw(t *testing.T) {
+ s := NewStream(bytes.NewReader(unhex("C58401010101")))
+ s.List()
+
+ want := unhex("8401010101")
+ raw, err := s.Raw()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(want, raw) {
+ t.Errorf("raw mismatch: got %x, want %x", raw, want)
+ }
+}
+
func TestDecodeErrors(t *testing.T) {
r := bytes.NewReader(nil)
@@ -314,6 +345,9 @@ var decodeTests = []decodeTest{
{input: "C109", ptr: new(*[]uint), value: &[]uint{9}},
{input: "C58403030303", ptr: new(*[][]byte), value: &[][]byte{{3, 3, 3, 3}}},
+ // check that input position is advanced also for empty values.
+ {input: "C3808005", ptr: new([]*uint), value: []*uint{nil, nil, uintp(5)}},
+
// pointer should be reset to nil
{input: "05", ptr: sharedPtr, value: uintp(5)},
{input: "80", ptr: sharedPtr, value: (*uint)(nil)},
diff --git a/rlp/encode.go b/rlp/encode.go
index 9d11d66bf..289bc4eaa 100644
--- a/rlp/encode.go
+++ b/rlp/encode.go
@@ -70,7 +70,7 @@ func (e flatenc) EncodeRLP(out io.Writer) error {
newhead := eb.lheads[prevnheads]
copy(eb.lheads[prevnheads:], eb.lheads[prevnheads+1:])
eb.lheads = eb.lheads[:len(eb.lheads)-1]
- eb.lhsize -= newhead.tagsize()
+ eb.lhsize -= headsize(uint64(newhead.size))
return nil
}
@@ -155,21 +155,29 @@ type listhead struct {
// encode writes head to the given buffer, which must be at least
// 9 bytes long. It returns the encoded bytes.
func (head *listhead) encode(buf []byte) []byte {
- if head.size < 56 {
- buf[0] = 0xC0 + byte(head.size)
- return buf[:1]
- } else {
- sizesize := putint(buf[1:], uint64(head.size))
- buf[0] = 0xF7 + byte(sizesize)
- return buf[:sizesize+1]
+ return buf[:puthead(buf, 0xC0, 0xF7, uint64(head.size))]
+}
+
+// headsize returns the size of a list or string header
+// for a value of the given size.
+func headsize(size uint64) int {
+ if size < 56 {
+ return 1
}
+ return 1 + intsize(size)
}
-func (head *listhead) tagsize() int {
- if head.size < 56 {
+// puthead writes a list or string header to buf.
+// buf must be at least 9 bytes long.
+func puthead(buf []byte, smalltag, largetag byte, size uint64) int {
+ if size < 56 {
+ buf[0] = smalltag + byte(size)
return 1
+ } else {
+ sizesize := putint(buf[1:], size)
+ buf[0] = largetag + byte(sizesize)
+ return sizesize + 1
}
- return 1 + intsize(uint64(head.size))
}
func newencbuf() *encbuf {
@@ -203,8 +211,13 @@ func (w *encbuf) encodeStringHeader(size int) {
}
func (w *encbuf) encodeString(b []byte) {
- w.encodeStringHeader(len(b))
- w.str = append(w.str, b...)
+ if len(b) == 1 && b[0] <= 0x7F {
+ // fits single byte, no string header
+ w.str = append(w.str, b[0])
+ } else {
+ w.encodeStringHeader(len(b))
+ w.str = append(w.str, b...)
+ }
}
func (w *encbuf) list() *listhead {
@@ -386,7 +399,12 @@ func writeUint(val reflect.Value, w *encbuf) error {
}
func writeBigIntPtr(val reflect.Value, w *encbuf) error {
- return writeBigInt(val.Interface().(*big.Int), w)
+ ptr := val.Interface().(*big.Int)
+ if ptr == nil {
+ w.str = append(w.str, 0x80)
+ return nil
+ }
+ return writeBigInt(ptr, w)
}
func writeBigIntNoPtr(val reflect.Value, w *encbuf) error {
@@ -399,9 +417,6 @@ func writeBigInt(i *big.Int, w *encbuf) error {
return fmt.Errorf("rlp: cannot encode negative *big.Int")
} else if cmp == 0 {
w.str = append(w.str, 0x80)
- } else if bits := i.BitLen(); bits < 8 {
- // fits single byte
- w.str = append(w.str, byte(i.Uint64()))
} else {
w.encodeString(i.Bytes())
}
@@ -429,8 +444,13 @@ func writeByteArray(val reflect.Value, w *encbuf) error {
func writeString(val reflect.Value, w *encbuf) error {
s := val.String()
- w.encodeStringHeader(len(s))
- w.str = append(w.str, s...)
+ if len(s) == 1 && s[0] <= 0x7f {
+ // fits single byte, no string header
+ w.str = append(w.str, s[0])
+ } else {
+ w.encodeStringHeader(len(s))
+ w.str = append(w.str, s...)
+ }
return nil
}
diff --git a/rlp/encode_test.go b/rlp/encode_test.go
index c283fbd57..611514bda 100644
--- a/rlp/encode_test.go
+++ b/rlp/encode_test.go
@@ -103,12 +103,18 @@ var encTests = []encTest{
// byte slices, strings
{val: []byte{}, output: "80"},
+ {val: []byte{0x7E}, output: "7E"},
+ {val: []byte{0x7F}, output: "7F"},
+ {val: []byte{0x80}, output: "8180"},
{val: []byte{1, 2, 3}, output: "83010203"},
{val: []namedByteType{1, 2, 3}, output: "83010203"},
{val: [...]namedByteType{1, 2, 3}, output: "83010203"},
{val: "", output: "80"},
+ {val: "\x7E", output: "7E"},
+ {val: "\x7F", output: "7F"},
+ {val: "\x80", output: "8180"},
{val: "dog", output: "83646F67"},
{
val: "Lorem ipsum dolor sit amet, consectetur adipisicing eli",
@@ -196,6 +202,7 @@ var encTests = []encTest{
{val: (*uint)(nil), output: "80"},
{val: (*string)(nil), output: "80"},
{val: (*[]byte)(nil), output: "80"},
+ {val: (*big.Int)(nil), output: "80"},
{val: (*[]string)(nil), output: "C0"},
{val: (*[]interface{})(nil), output: "C0"},
{val: (*[]struct{ uint })(nil), output: "C0"},