aboutsummaryrefslogblamecommitdiffstats
path: root/tests/rlp_test_util.go
blob: c322b78c6a143a19c3c55707548a8afdf9414eae (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15














                                             
                                                      
                     









                                                                    


                  
                                                                       








                                                   
                                                                  





















                                                                          

                                                                












                                                                         
                                







































                                                                                 



                                                                      
















































                                                                                               
package tests

import (
    "bytes"
    "encoding/hex"
    "errors"
    "fmt"
    "io"
    "math/big"
    "os"
    "strings"

    "github.com/ethereum/go-ethereum/rlp"
)

// RLPTest describes one entry of an RLP test fixture, as decoded from JSON.
type RLPTest struct {
	// In selects what the test does. The string values "VALID" and
	// "INVALID" request a plain decodability check: Out must (respectively
	// must not) decode into a value of type interface{}.
	//
	// Any other JSON value is the expected content of Out. It is encoded
	// and compared with Out byte for byte, and it also drives a sequence
	// of rlp.Stream calls that decode Out.
	In interface{}

	// Out holds the RLP encoding under test, written as a hex string.
	Out string
}

// RunRLPTest opens the JSON fixture at path file and runs every test it
// contains, except those whose names are listed in skip. The file is closed
// before RunRLPTest returns.
func RunRLPTest(file string, skip []string) error {
	fd, openErr := os.Open(file)
	if openErr != nil {
		return openErr
	}
	defer fd.Close()

	return RunRLPTestWithReader(fd, skip)
}

// RunRLPTestWithReader reads a JSON object mapping test names to RLPTest
// values from r (via readJson) and runs each test, except those whose names
// appear in skip. It stops at the first failing test and returns its error
// prefixed with the test name.
//
// Tests run in map iteration order, which Go randomizes between runs.
func RunRLPTestWithReader(r io.Reader, skip []string) error {
	var tests map[string]*RLPTest
	if err := readJson(r, &tests); err != nil {
		return err
	}
	for _, s := range skip {
		delete(tests, s)
	}
	for name, test := range tests {
		// Run dereferences its receiver, so a nil entry (for example a JSON
		// null, depending on how readJson decodes) would panic. Report it as
		// an ordinary test failure instead.
		if test == nil {
			return fmt.Errorf("test %q failed: no test definition", name)
		}
		if err := test.Run(); err != nil {
			return fmt.Errorf("test %q failed: %v", name, err)
		}
	}
	return nil
}

// Run executes the test. It returns nil on success, or an error describing
// the first check that failed.
//
// Out is hex-decoded first. If In is the string "VALID" or "INVALID", the
// bytes are only checked for decodability (or its absence) into an
// interface{}. Otherwise In is converted with translateJSON, which panics on
// malformed test data. Its RLP encoding must equal Out exactly, and Out is
// then decoded through an rlp.Stream whose calls are driven by the converted
// value.
func (t *RLPTest) Run() error {
	outb, err := hex.DecodeString(t.Out)
	if err != nil {
		// Keep the decoder's error so the cause (odd length, invalid byte)
		// shows up in the failure message.
		return fmt.Errorf("invalid hex in Out: %v", err)
	}

	// Handle simple decoding tests with no actual In value.
	if t.In == "VALID" || t.In == "INVALID" {
		return checkDecodeInterface(outb, t.In == "VALID")
	}

	// Check whether encoding the value produces the same bytes.
	in := translateJSON(t.In)
	b, err := rlp.EncodeToBytes(in)
	if err != nil {
		return fmt.Errorf("encode failed: %v", err)
	}
	if !bytes.Equal(b, outb) {
		return fmt.Errorf("encode produced %x, want %x", b, outb)
	}
	// Test stream decoding. outb was just shown to be exactly the encoding
	// of in, so the stream holds that single value.
	s := rlp.NewStream(bytes.NewReader(outb), 0)
	return checkDecodeFromJSON(s, in)
}

// checkDecodeInterface decodes b into a fresh interface{} value. It returns
// an error when the outcome contradicts isValid: a decode failure when
// isValid is true, or a decode success when isValid is false.
func checkDecodeInterface(b []byte, isValid bool) error {
	var target interface{}
	decErr := rlp.DecodeBytes(b, &target)

	if isValid {
		if decErr != nil {
			return fmt.Errorf("decoding failed: %v", decErr)
		}
		return nil
	}
	if decErr == nil {
		return fmt.Errorf("decoding of invalid value succeeded")
	}
	return nil
}

// translateJSON converts a value produced by JSON decoding into one that the
// rlp package can encode:
//
//   - float64 (a JSON number) becomes uint64. It must be a non-negative
//     integer below 2^64; anything else is malformed test data and panics.
//   - a string starting with '#' is a decimal big integer and becomes a
//     *big.Int. A malformed big integer panics.
//   - any other string becomes its bytes.
//   - []interface{} (a JSON array) is translated element by element.
//
// Any other type (bool, nil, maps, ...) panics.
func translateJSON(v interface{}) interface{} {
	switch v := v.(type) {
	case float64:
		// Converting a negative, NaN or out-of-range float64 to uint64
		// yields an implementation-dependent value, and a fractional one is
		// silently truncated. Reject such input rather than encoding a bogus
		// expected value. The negated form also rejects NaN.
		if !(v >= 0 && v < 1<<64) {
			panic(fmt.Errorf("bad test: number out of uint64 range: %v", v))
		}
		u := uint64(v)
		if float64(u) != v {
			panic(fmt.Errorf("bad test: non-integer number: %v", v))
		}
		return u
	case string:
		if len(v) > 0 && v[0] == '#' { // # starts a faux big int.
			n, ok := new(big.Int).SetString(v[1:], 10)
			if !ok {
				panic(fmt.Errorf("bad test: bad big int: %q", v))
			}
			return n
		}
		return []byte(v)
	case []interface{}:
		out := make([]interface{}, len(v))
		for i, elem := range v {
			out[i] = translateJSON(elem)
		}
		return out
	default:
		panic(fmt.Errorf("can't handle %T", v))
	}
}

// checkDecodeFromJSON walks the RLP stream s in lockstep with the translated
// JSON value exp. The dynamic type of exp picks the Stream operation:
//   - uint64 reads with Uint.
//   - *big.Int reads with Decode.
//   - []byte reads with Bytes.
//   - []interface{} reads with List, then one recursive check per element,
//     then ListEnd.
//
// Each decoded value must equal the corresponding part of exp. Failures are
// annotated via addStack with the operation and expected value, so nested
// mismatches carry a readable trail. An unsupported type in exp panics.
func checkDecodeFromJSON(s *rlp.Stream, exp interface{}) error {
	switch want := exp.(type) {
	case uint64:
		got, err := s.Uint()
		switch {
		case err != nil:
			return addStack("Uint", want, err)
		case got != want:
			return addStack("Uint", want, fmt.Errorf("result mismatch: got %d", got))
		}
		return nil

	case *big.Int:
		got := new(big.Int)
		if err := s.Decode(&got); err != nil {
			return addStack("Big", want, err)
		}
		if got.Cmp(want) != 0 {
			return addStack("Big", want, fmt.Errorf("result mismatch: got %d", got))
		}
		return nil

	case []byte:
		got, err := s.Bytes()
		if err != nil {
			return addStack("Bytes", want, err)
		}
		if !bytes.Equal(got, want) {
			return addStack("Bytes", want, fmt.Errorf("result mismatch: got %x", got))
		}
		return nil

	case []interface{}:
		if _, err := s.List(); err != nil {
			return addStack("List", want, err)
		}
		for idx := range want {
			if err := checkDecodeFromJSON(s, want[idx]); err != nil {
				return addStack(fmt.Sprintf("[%d]", idx), want, err)
			}
		}
		if err := s.ListEnd(); err != nil {
			return addStack("ListEnd", want, err)
		}
		return nil

	default:
		panic(fmt.Errorf("unhandled type: %T", exp))
	}
}

// addStack returns a new error whose message is err's message followed by
// one more line, "\t<op>: <val>", recording the operation and value at
// which the failure surfaced.
func addStack(op string, val interface{}, err error) error {
	parts := []string{err.Error(), fmt.Sprintf("\t%s: %v", op, val)}
	return errors.New(strings.Join(parts, "\n"))
}