diff options
Diffstat (limited to 'rlp')
-rw-r--r-- | rlp/decode.go | 39 | ||||
-rw-r--r-- | rlp/decode_test.go | 36 | ||||
-rw-r--r-- | rlp/encode.go | 58 | ||||
-rw-r--r-- | rlp/encode_test.go | 7 |
4 files changed, 119 insertions, 21 deletions
diff --git a/rlp/decode.go b/rlp/decode.go index 55f7187a3..0fde0a947 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -2,6 +2,7 @@ package rlp import ( "bufio" + "bytes" "encoding/binary" "errors" "fmt" @@ -73,6 +74,12 @@ func Decode(r io.Reader, val interface{}) error { return NewStream(r).Decode(val) } +// DecodeBytes parses RLP data from b into val. +// Please see the documentation of Decode for the decoding rules. +func DecodeBytes(b []byte, val interface{}) error { + return NewStream(bytes.NewReader(b)).Decode(val) +} + type decodeError struct { msg string typ reflect.Type @@ -360,7 +367,12 @@ func makePtrDecoder(typ reflect.Type) (decoder, error) { dec := func(s *Stream, val reflect.Value) (err error) { _, size, err := s.Kind() if err != nil || size == 0 && s.byteval == 0 { - val.Set(reflect.Zero(typ)) // set to nil + // rearm s.Kind. This is important because the input + // position must advance to the next value even though + // we don't read anything. + s.kind = -1 + // set the pointer to nil. + val.Set(reflect.Zero(typ)) return err } newval := val @@ -528,6 +540,31 @@ func (s *Stream) Bytes() ([]byte, error) { } } +// Raw reads a raw encoded value including RLP type information. +func (s *Stream) Raw() ([]byte, error) { + kind, size, err := s.Kind() + if err != nil { + return nil, err + } + if kind == Byte { + s.kind = -1 // rearm Kind + return []byte{s.byteval}, nil + } + // the original header has already been read and is no longer + // available. read content and put a new header in front of it. + start := headsize(size) + buf := make([]byte, uint64(start)+size) + if err := s.readFull(buf[start:]); err != nil { + return nil, err + } + if kind == String { + puthead(buf, 0x80, 0xB8, size) + } else { + puthead(buf, 0xC0, 0xF7, size) + } + return buf, nil +} + var errUintOverflow = errors.New("rlp: uint overflow") // Uint reads an RLP string of up to 8 bytes and returns its contents diff --git a/rlp/decode_test.go b/rlp/decode_test.go index 9f66840b1..a18ff1d08 100644 --- a/rlp/decode_test.go +++ b/rlp/decode_test.go @@ -39,7 +39,7 @@ func TestStreamKind(t *testing.T) { s := NewStream(bytes.NewReader(unhex(test.input))) kind, len, err := s.Kind() if err != nil { - t.Errorf("test %d: Type returned error: %v", i, err) + t.Errorf("test %d: Kind returned error: %v", i, err) continue } if kind != test.wantKind { @@ -93,6 +93,23 @@ func TestStreamErrors(t *testing.T) { {"C3C2010201", calls{"List", "List", "Uint", "Uint", "ListEnd", "Uint"}, EOL}, {"00", calls{"ListEnd"}, errNotInList}, {"C40102", calls{"List", "Uint", "ListEnd"}, errNotAtEOL}, + + // This test verifies that the input position is advanced + // correctly when calling Bytes for empty strings. Kind can be called + // any number of times in between and doesn't advance. + {"C3808080", calls{ + "List", // enter the list + "Bytes", // past first element + + "Kind", "Kind", "Kind", // this shouldn't advance + + "Bytes", // past second element + + "Kind", "Kind", // can't hurt to try + + "Bytes", // past final element + "Bytes", // this one should fail + }, EOL}, } testfor: @@ -148,6 +165,20 @@ func TestStreamList(t *testing.T) { } } +func TestStreamRaw(t *testing.T) { + s := NewStream(bytes.NewReader(unhex("C58401010101"))) + s.List() + + want := unhex("8401010101") + raw, err := s.Raw() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(want, raw) { + t.Errorf("raw mismatch: got %x, want %x", raw, want) + } +} + func TestDecodeErrors(t *testing.T) { r := bytes.NewReader(nil) @@ -314,6 +345,9 @@ var decodeTests = []decodeTest{ {input: "C109", ptr: new(*[]uint), value: &[]uint{9}}, {input: "C58403030303", ptr: new(*[][]byte), value: &[][]byte{{3, 3, 3, 3}}}, + // check that input position is advanced also for empty values. + {input: "C3808005", ptr: new([]*uint), value: []*uint{nil, nil, uintp(5)}}, + // pointer should be reset to nil {input: "05", ptr: sharedPtr, value: uintp(5)}, {input: "80", ptr: sharedPtr, value: (*uint)(nil)}, diff --git a/rlp/encode.go b/rlp/encode.go index 9d11d66bf..289bc4eaa 100644 --- a/rlp/encode.go +++ b/rlp/encode.go @@ -70,7 +70,7 @@ func (e flatenc) EncodeRLP(out io.Writer) error { newhead := eb.lheads[prevnheads] copy(eb.lheads[prevnheads:], eb.lheads[prevnheads+1:]) eb.lheads = eb.lheads[:len(eb.lheads)-1] - eb.lhsize -= newhead.tagsize() + eb.lhsize -= headsize(uint64(newhead.size)) return nil } @@ -155,21 +155,29 @@ type listhead struct { // encode writes head to the given buffer, which must be at least // 9 bytes long. It returns the encoded bytes. func (head *listhead) encode(buf []byte) []byte { - if head.size < 56 { - buf[0] = 0xC0 + byte(head.size) - return buf[:1] - } else { - sizesize := putint(buf[1:], uint64(head.size)) - buf[0] = 0xF7 + byte(sizesize) - return buf[:sizesize+1] + return buf[:puthead(buf, 0xC0, 0xF7, uint64(head.size))] +} + +// headsize returns the size of a list or string header +// for a value of the given size. +func headsize(size uint64) int { + if size < 56 { + return 1 } + return 1 + intsize(size) } -func (head *listhead) tagsize() int { - if head.size < 56 { +// puthead writes a list or string header to buf. +// buf must be at least 9 bytes long. +func puthead(buf []byte, smalltag, largetag byte, size uint64) int { + if size < 56 { + buf[0] = smalltag + byte(size) return 1 + } else { + sizesize := putint(buf[1:], size) + buf[0] = largetag + byte(sizesize) + return sizesize + 1 } - return 1 + intsize(uint64(head.size)) } func newencbuf() *encbuf { @@ -203,8 +211,13 @@ func (w *encbuf) encodeStringHeader(size int) { } func (w *encbuf) encodeString(b []byte) { - w.encodeStringHeader(len(b)) - w.str = append(w.str, b...) + if len(b) == 1 && b[0] <= 0x7F { + // fits single byte, no string header + w.str = append(w.str, b[0]) + } else { + w.encodeStringHeader(len(b)) + w.str = append(w.str, b...) + } } func (w *encbuf) list() *listhead { @@ -386,7 +399,12 @@ func writeUint(val reflect.Value, w *encbuf) error { } func writeBigIntPtr(val reflect.Value, w *encbuf) error { - return writeBigInt(val.Interface().(*big.Int), w) + ptr := val.Interface().(*big.Int) + if ptr == nil { + w.str = append(w.str, 0x80) + return nil + } + return writeBigInt(ptr, w) } func writeBigIntNoPtr(val reflect.Value, w *encbuf) error { @@ -399,9 +417,6 @@ func writeBigInt(i *big.Int, w *encbuf) error { return fmt.Errorf("rlp: cannot encode negative *big.Int") } else if cmp == 0 { w.str = append(w.str, 0x80) - } else if bits := i.BitLen(); bits < 8 { - // fits single byte - w.str = append(w.str, byte(i.Uint64())) } else { w.encodeString(i.Bytes()) } @@ -429,8 +444,13 @@ func writeByteArray(val reflect.Value, w *encbuf) error { func writeString(val reflect.Value, w *encbuf) error { s := val.String() - w.encodeStringHeader(len(s)) - w.str = append(w.str, s...) + if len(s) == 1 && s[0] <= 0x7f { + // fits single byte, no string header + w.str = append(w.str, s[0]) + } else { + w.encodeStringHeader(len(s)) + w.str = append(w.str, s...) + } return nil } diff --git a/rlp/encode_test.go b/rlp/encode_test.go index c283fbd57..611514bda 100644 --- a/rlp/encode_test.go +++ b/rlp/encode_test.go @@ -103,12 +103,18 @@ var encTests = []encTest{ // byte slices, strings {val: []byte{}, output: "80"}, + {val: []byte{0x7E}, output: "7E"}, + {val: []byte{0x7F}, output: "7F"}, + {val: []byte{0x80}, output: "8180"}, {val: []byte{1, 2, 3}, output: "83010203"}, {val: []namedByteType{1, 2, 3}, output: "83010203"}, {val: [...]namedByteType{1, 2, 3}, output: "83010203"}, {val: "", output: "80"}, + {val: "\x7E", output: "7E"}, + {val: "\x7F", output: "7F"}, + {val: "\x80", output: "8180"}, {val: "dog", output: "83646F67"}, { val: "Lorem ipsum dolor sit amet, consectetur adipisicing eli", @@ -196,6 +202,7 @@ var encTests = []encTest{ {val: (*uint)(nil), output: "80"}, {val: (*string)(nil), output: "80"}, {val: (*[]byte)(nil), output: "80"}, + {val: (*big.Int)(nil), output: "80"}, {val: (*[]string)(nil), output: "C0"}, {val: (*[]interface{})(nil), output: "C0"}, {val: (*[]struct{ uint })(nil), output: "C0"}, |