diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/init.go | 1 | ||||
-rw-r--r-- | tests/rlp_test.go | 20 | ||||
-rw-r--r-- | tests/rlp_test_util.go | 156 |
3 files changed, 177 insertions, 0 deletions
diff --git a/tests/init.go b/tests/init.go index f149c11f1..30cff6f21 100644 --- a/tests/init.go +++ b/tests/init.go @@ -35,6 +35,7 @@ var ( stateTestDir = filepath.Join(baseDir, "StateTests") transactionTestDir = filepath.Join(baseDir, "TransactionTests") vmTestDir = filepath.Join(baseDir, "VMTests") + rlpTestDir = filepath.Join(baseDir, "RLPTests") BlockSkipTests = []string{ // These tests are not valid, as they are out of scope for RLP and diff --git a/tests/rlp_test.go b/tests/rlp_test.go new file mode 100644 index 000000000..70bd19627 --- /dev/null +++ b/tests/rlp_test.go @@ -0,0 +1,20 @@ +package tests + +import ( + "path/filepath" + "testing" +) + +func TestRLP(t *testing.T) { + err := RunRLPTest(filepath.Join(rlpTestDir, "rlptest.json"), nil) + if err != nil { + t.Fatal(err) + } +} + +func TestRLP_invalid(t *testing.T) { + err := RunRLPTest(filepath.Join(rlpTestDir, "invalidRLPTest.json"), nil) + if err != nil { + t.Fatal(err) + } +} diff --git a/tests/rlp_test_util.go b/tests/rlp_test_util.go new file mode 100644 index 000000000..d7042eef7 --- /dev/null +++ b/tests/rlp_test_util.go @@ -0,0 +1,156 @@ +package tests + +import ( + "bytes" + "encoding/hex" + "errors" + "fmt" + "io" + "math/big" + "os" + "strings" + + "github.com/ethereum/go-ethereum/rlp" +) + +type RLPTest struct { + In interface{} + Out string +} + +func RunRLPTest(file string, skip []string) error { + f, err := os.Open(file) + if err != nil { + return err + } + defer f.Close() + return RunRLPTestWithReader(f, skip) +} + +func RunRLPTestWithReader(r io.Reader, skip []string) error { + var tests map[string]*RLPTest + if err := readJson(r, &tests); err != nil { + return err + } + for _, s := range skip { + delete(tests, s) + } + for name, test := range tests { + if err := test.Run(); err != nil { + return fmt.Errorf("test %q failed: %v", name, err) + } + } + return nil +} + +// Run executes the test. +func (t *RLPTest) Run() error { + outb, err := hex.DecodeString(t.Out) + if err != nil { + return fmt.Errorf("invalid hex in Out") + } + if t.In == "VALID" || t.In == "INVALID" { + return checkDecodeInterface(outb, t.In == "VALID") + } + + // Check whether encoding the value produces the same bytes. + in := translateJSON(t.In) + b, err := rlp.EncodeToBytes(in) + if err != nil { + return fmt.Errorf("encode failed: %v", err) + } + if !bytes.Equal(b, outb) { + return fmt.Errorf("encode produced %x, want %x", b, outb) + } + // Test decoding from a stream. + s := rlp.NewStream(bytes.NewReader(outb), 0) + return checkDecodeFromJSON(s, in) +} + +func checkDecodeInterface(b []byte, isValid bool) error { + err := rlp.DecodeBytes(b, new(interface{})) + switch { + case isValid && err != nil: + return fmt.Errorf("decoding failed: %v", err) + case !isValid && err == nil: + return fmt.Errorf("decoding of invalid value succeeded") + } + return nil +} + +// translateJSON makes test json values encodable with RLP. +func translateJSON(v interface{}) interface{} { + switch v := v.(type) { + case float64: + return uint64(v) + case string: + if len(v) > 0 && v[0] == '#' { // # starts a faux big int. + big, ok := new(big.Int).SetString(v[1:], 10) + if !ok { + panic(fmt.Errorf("bad test: bad big int: %q", v)) + } + return big + } + return []byte(v) + case []interface{}: + new := make([]interface{}, len(v)) + for i := range v { + new[i] = translateJSON(v[i]) + } + return new + default: + panic(fmt.Errorf("can't handle %T", v)) + } +} + +// checkDecodeFromJSON decodes from s guided by exp. For each JSON +// value, the value decoded from the RLP stream must match. +func checkDecodeFromJSON(s *rlp.Stream, exp interface{}) error { + switch exp := exp.(type) { + case uint64: + i, err := s.Uint() + if err != nil { + return addStack("Uint", exp, err) + } + if i != exp { + return addStack("Uint", exp, fmt.Errorf("result mismatch: got %d", i)) + } + case *big.Int: + big := new(big.Int) + if err := s.Decode(&big); err != nil { + return addStack("Big", exp, err) + } + if big.Cmp(exp) != 0 { + return addStack("Big", exp, fmt.Errorf("result mismatch: got %d", big)) + } + case []byte: + b, err := s.Bytes() + if err != nil { + return addStack("Bytes", exp, err) + } + if !bytes.Equal(b, exp) { + return addStack("Bytes", exp, fmt.Errorf("result mismatch: got %x", b)) + } + case []interface{}: + if _, err := s.List(); err != nil { + return addStack("List", exp, err) + } + for i, v := range exp { + if err := checkDecodeFromJSON(s, v); err != nil { + return addStack(fmt.Sprintf("[%d]", i), exp, err) + } + } + if err := s.ListEnd(); err != nil { + return addStack("ListEnd", exp, err) + } + default: + panic(fmt.Errorf("unhandled type: %T", exp)) + } + return nil +} + +func addStack(op string, val interface{}, err error) error { + lines := strings.Split(err.Error(), "\n") + lines = append(lines, fmt.Sprintf("\t%s: %v", op, val)) + return errors.New(strings.Join(lines, "\n")) +} |