diff options
Diffstat (limited to 'tests/rlp_test_util.go')
-rw-r--r-- | tests/rlp_test_util.go | 172 |
1 files changed, 172 insertions, 0 deletions
diff --git a/tests/rlp_test_util.go b/tests/rlp_test_util.go new file mode 100644 index 000000000..c322b78c6 --- /dev/null +++ b/tests/rlp_test_util.go @@ -0,0 +1,172 @@ +package tests + +import ( + "bytes" + "encoding/hex" + "errors" + "fmt" + "io" + "math/big" + "os" + "strings" + + "github.com/ethereum/go-ethereum/rlp" +) + +// RLPTest is the JSON structure of a single RLP test. +type RLPTest struct { + // If the value of In is "INVALID" or "VALID", the test + // checks whether Out can be decoded into a value of + // type interface{}. + // + // For other JSON values, In is treated as a driver for + // calls to rlp.Stream. The test also verifies that encoding + // In produces the bytes in Out. + In interface{} + + // Out is a hex-encoded RLP value. + Out string +} + +// RunRLPTest runs the tests in the given file, skipping tests by name. +func RunRLPTest(file string, skip []string) error { + f, err := os.Open(file) + if err != nil { + return err + } + defer f.Close() + return RunRLPTestWithReader(f, skip) +} + +// RunRLPTest runs the tests encoded in r, skipping tests by name. +func RunRLPTestWithReader(r io.Reader, skip []string) error { + var tests map[string]*RLPTest + if err := readJson(r, &tests); err != nil { + return err + } + for _, s := range skip { + delete(tests, s) + } + for name, test := range tests { + if err := test.Run(); err != nil { + return fmt.Errorf("test %q failed: %v", name, err) + } + } + return nil +} + +// Run executes the test. +func (t *RLPTest) Run() error { + outb, err := hex.DecodeString(t.Out) + if err != nil { + return fmt.Errorf("invalid hex in Out") + } + + // Handle simple decoding tests with no actual In value. + if t.In == "VALID" || t.In == "INVALID" { + return checkDecodeInterface(outb, t.In == "VALID") + } + + // Check whether encoding the value produces the same bytes. + in := translateJSON(t.In) + b, err := rlp.EncodeToBytes(in) + if err != nil { + return fmt.Errorf("encode failed: %v", err) + } + if !bytes.Equal(b, outb) { + return fmt.Errorf("encode produced %x, want %x", b, outb) + } + // Test stream decoding. + s := rlp.NewStream(bytes.NewReader(outb), 0) + return checkDecodeFromJSON(s, in) +} + +func checkDecodeInterface(b []byte, isValid bool) error { + err := rlp.DecodeBytes(b, new(interface{})) + switch { + case isValid && err != nil: + return fmt.Errorf("decoding failed: %v", err) + case !isValid && err == nil: + return fmt.Errorf("decoding of invalid value succeeded") + } + return nil +} + +// translateJSON makes test json values encodable with RLP. +func translateJSON(v interface{}) interface{} { + switch v := v.(type) { + case float64: + return uint64(v) + case string: + if len(v) > 0 && v[0] == '#' { // # starts a faux big int. + big, ok := new(big.Int).SetString(v[1:], 10) + if !ok { + panic(fmt.Errorf("bad test: bad big int: %q", v)) + } + return big + } + return []byte(v) + case []interface{}: + new := make([]interface{}, len(v)) + for i := range v { + new[i] = translateJSON(v[i]) + } + return new + default: + panic(fmt.Errorf("can't handle %T", v)) + } +} + +// checkDecodeFromJSON decodes from s guided by exp. exp drives the +// Stream by invoking decoding operations (Uint, Big, List, ...) based +// on the type of each value. The value decoded from the RLP stream +// must match the JSON value. +func checkDecodeFromJSON(s *rlp.Stream, exp interface{}) error { + switch exp := exp.(type) { + case uint64: + i, err := s.Uint() + if err != nil { + return addStack("Uint", exp, err) + } + if i != exp { + return addStack("Uint", exp, fmt.Errorf("result mismatch: got %d", i)) + } + case *big.Int: + big := new(big.Int) + if err := s.Decode(&big); err != nil { + return addStack("Big", exp, err) + } + if big.Cmp(exp) != 0 { + return addStack("Big", exp, fmt.Errorf("result mismatch: got %d", big)) + } + case []byte: + b, err := s.Bytes() + if err != nil { + return addStack("Bytes", exp, err) + } + if !bytes.Equal(b, exp) { + return addStack("Bytes", exp, fmt.Errorf("result mismatch: got %x", b)) + } + case []interface{}: + if _, err := s.List(); err != nil { + return addStack("List", exp, err) + } + for i, v := range exp { + if err := checkDecodeFromJSON(s, v); err != nil { + return addStack(fmt.Sprintf("[%d]", i), exp, err) + } + } + if err := s.ListEnd(); err != nil { + return addStack("ListEnd", exp, err) + } + default: + panic(fmt.Errorf("unhandled type: %T", exp)) + } + return nil +} + +func addStack(op string, val interface{}, err error) error { + lines := strings.Split(err.Error(), "\n") + lines = append(lines, fmt.Sprintf("\t%s: %v", op, val)) + return errors.New(strings.Join(lines, "\n")) +} |