aboutsummaryrefslogtreecommitdiffstats
path: root/rlp
diff options
context:
space:
mode:
Diffstat (limited to 'rlp')
-rw-r--r--rlp/encode.go71
-rw-r--r--rlp/encode_test.go15
2 files changed, 85 insertions, 1 deletions
diff --git a/rlp/encode.go b/rlp/encode.go
index d80b66315..9d11d66bf 100644
--- a/rlp/encode.go
+++ b/rlp/encode.go
@@ -32,6 +32,48 @@ type Encoder interface {
EncodeRLP(io.Writer) error
}
+// Flat wraps a value (which must encode as a list) so
+// it encodes as the list's elements.
+//
+// Example: suppose you have defined a type
+//
+// type foo struct { A, B uint }
+//
+// Under normal encoding rules,
+//
+// rlp.Encode(foo{1, 2}) --> 0xC20102
+//
+// This function can help you achieve the following encoding:
+//
+// rlp.Encode(rlp.Flat(foo{1, 2})) --> 0x0102
+func Flat(val interface{}) Encoder {
+ return flatenc{val}
+}
+
+type flatenc struct{ val interface{} }
+
+func (e flatenc) EncodeRLP(out io.Writer) error {
+ // record current output position
+ var (
+ eb = out.(*encbuf)
+ prevstrsize = len(eb.str)
+ prevnheads = len(eb.lheads)
+ )
+ if err := eb.encode(e.val); err != nil {
+ return err
+ }
+ // check that a new list header has appeared
+ if len(eb.lheads) == prevnheads || eb.lheads[prevnheads].offset == prevstrsize-1 {
+ return fmt.Errorf("rlp.Flat: %T did not encode as list", e.val)
+ }
+ // remove the new list header
+ newhead := eb.lheads[prevnheads]
+ copy(eb.lheads[prevnheads:], eb.lheads[prevnheads+1:])
+ eb.lheads = eb.lheads[:len(eb.lheads)-1]
+ eb.lhsize -= newhead.tagsize()
+ return nil
+}
+
// Encode writes the RLP encoding of val to w. Note that Encode may
// perform many small writes in some cases. Consider making w
// buffered.
@@ -123,6 +165,13 @@ func (head *listhead) encode(buf []byte) []byte {
}
}
+func (head *listhead) tagsize() int {
+ if head.size < 56 {
+ return 1
+ }
+ return 1 + intsize(uint64(head.size))
+}
+
func newencbuf() *encbuf {
return &encbuf{sizebuf: make([]byte, 9)}
}
@@ -301,8 +350,10 @@ func makeWriter(typ reflect.Type) (writer, error) {
return writeUint, nil
case kind == reflect.String:
return writeString, nil
- case kind == reflect.Slice && typ.Elem().Kind() == reflect.Uint8 && !typ.Elem().Implements(encoderInterface):
+ case kind == reflect.Slice && isByte(typ.Elem()):
return writeBytes, nil
+ case kind == reflect.Array && isByte(typ.Elem()):
+ return writeByteArray, nil
case kind == reflect.Slice || kind == reflect.Array:
return makeSliceWriter(typ)
case kind == reflect.Struct:
@@ -314,6 +365,10 @@ func makeWriter(typ reflect.Type) (writer, error) {
}
}
+func isByte(typ reflect.Type) bool {
+ return typ.Kind() == reflect.Uint8 && !typ.Implements(encoderInterface)
+}
+
func writeUint(val reflect.Value, w *encbuf) error {
i := val.Uint()
if i == 0 {
@@ -358,6 +413,20 @@ func writeBytes(val reflect.Value, w *encbuf) error {
return nil
}
+func writeByteArray(val reflect.Value, w *encbuf) error {
+ if !val.CanAddr() {
+ // Slice requires the value to be addressable.
+ // Make it addressable by copying.
+ copy := reflect.New(val.Type()).Elem()
+ copy.Set(val)
+ val = copy
+ }
+ size := val.Len()
+ slice := val.Slice(0, size).Bytes()
+ w.encodeString(slice)
+ return nil
+}
+
func writeString(val reflect.Value, w *encbuf) error {
s := val.String()
w.encodeStringHeader(len(s))
diff --git a/rlp/encode_test.go b/rlp/encode_test.go
index 18b843737..c283fbd57 100644
--- a/rlp/encode_test.go
+++ b/rlp/encode_test.go
@@ -40,6 +40,8 @@ func (e *encodableReader) Read(b []byte) (int, error) {
panic("called")
}
+type namedByteType byte
+
var (
_ = Encoder(&testEncoder{})
_ = Encoder(byteEncoder(0))
@@ -102,6 +104,10 @@ var encTests = []encTest{
// byte slices, strings
{val: []byte{}, output: "80"},
{val: []byte{1, 2, 3}, output: "83010203"},
+
+ {val: []namedByteType{1, 2, 3}, output: "83010203"},
+ {val: [...]namedByteType{1, 2, 3}, output: "83010203"},
+
{val: "", output: "80"},
{val: "dog", output: "83646F67"},
{
@@ -177,6 +183,15 @@ var encTests = []encTest{
{val: &recstruct{5, nil}, output: "C205C0"},
{val: &recstruct{5, &recstruct{4, &recstruct{3, nil}}}, output: "C605C404C203C0"},
+ // flat
+ {val: Flat(uint(1)), error: "rlp.Flat: uint did not encode as list"},
+ {val: Flat(simplestruct{A: 3, B: "foo"}), output: "0383666F6F"},
+ {
+ // value generates more list headers after the Flat
+ val: []interface{}{"foo", []uint{1, 2}, Flat([]uint{3, 4}), []uint{5, 6}, "bar"},
+ output: "D083666F6FC201020304C2050683626172",
+ },
+
// nil
{val: (*uint)(nil), output: "80"},
{val: (*string)(nil), output: "80"},