aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/rlpx_test.go
blob: b3c2adf8d28a8fc100eff57db34b060fcfb4ef5a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
package p2p

import (
    "bytes"
    "crypto/rand"
    "encoding/hex"
    "fmt"
    "io/ioutil"
    "strings"
    "testing"

    "github.com/ethereum/go-ethereum/crypto"
    "github.com/ethereum/go-ethereum/crypto/sha3"
    "github.com/ethereum/go-ethereum/rlp"
)

// TestRlpxFrameFake checks the rlpx frame writer and reader against a golden
// byte encoding. Both MAC states are a fakeHash, whose Write discards input and
// whose Sum always returns the same fixed bytes. This makes the framed output
// deterministic, so it can be compared byte-for-byte with the golden value.
func TestRlpxFrameFake(t *testing.T) {
	buf := new(bytes.Buffer)
	hash := fakeHash([]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})
	rw := newRlpxFrameRW(buf, secrets{
		AES:        crypto.Sha3(),
		MAC:        crypto.Sha3(),
		IngressMAC: hash,
		EgressMAC:  hash,
	})

	golden := unhex(`
00828ddae471818bb0bfa6b551d1cb42
01010101010101010101010101010101
ba628a4ba590cb43f7848f41c4382885
01010101010101010101010101010101
01010101010101010101010101010101
`)

	// Check WriteMsg. This puts a message into the buffer.
	if err := EncodeMsg(rw, 8, []interface{}{1, 2, 3, 4}); err != nil {
		t.Fatalf("WriteMsg error: %v", err)
	}
	written := buf.Bytes()
	if !bytes.Equal(written, golden) {
		t.Fatalf("output mismatch:\n  got:  %x\n  want: %x", written, golden)
	}

	// Check ReadMsg. It reads the message encoded by WriteMsg, which
	// is equivalent to the golden message above.
	msg, err := rw.ReadMsg()
	if err != nil {
		t.Fatalf("ReadMsg error: %v", err)
	}
	if msg.Size != 5 {
		t.Errorf("msg size mismatch: got %d, want %d", msg.Size, 5)
	}
	if msg.Code != 8 {
		t.Errorf("msg code mismatch: got %d, want %d", msg.Code, 8)
	}
	// Check the read error explicitly. Otherwise a failed read would surface
	// only as a misleading payload mismatch.
	payload, err := ioutil.ReadAll(msg.Payload)
	if err != nil {
		t.Fatalf("payload read error: %v", err)
	}
	wantPayload := unhex("C401020304")
	if !bytes.Equal(payload, wantPayload) {
		t.Errorf("msg payload mismatch:\ngot  %x\nwant %x", payload, wantPayload)
	}
}

// fakeHash is a stand-in for a hash.Hash whose digest never changes. Write
// accepts and discards all input, Reset does nothing, and Sum always appends
// the fixed byte sequence held by the value itself.
type fakeHash []byte

// Write reports every byte of p as consumed but ignores its contents.
func (h fakeHash) Write(p []byte) (int, error) {
	consumed := len(p)
	return consumed, nil
}

// Reset is a no-op; there is no accumulated state to clear.
func (h fakeHash) Reset() {}

// BlockSize always reports 0.
func (h fakeHash) BlockSize() int {
	return 0
}

// Size returns the length of the fixed digest.
func (h fakeHash) Size() int {
	return len(h)
}

// Sum appends the fixed digest to b and returns the resulting slice.
func (h fakeHash) Sum(b []byte) []byte {
	out := append(b, h...)
	return out
}

// unhex decodes a hex string that may be split across several lines for
// readability, such as the golden vectors in this file. All whitespace
// (newlines, carriage returns, spaces, tabs) is removed before decoding, so
// CRLF-terminated or indented literals decode too.
//
// It panics on malformed input. In this file it is only used to build test
// fixtures, where a bad literal is a programming error.
func unhex(str string) []byte {
	b, err := hex.DecodeString(strings.Join(strings.Fields(str), ""))
	if err != nil {
		// Keep the decoder's error; it pinpoints the offending byte/length.
		panic(fmt.Sprintf("invalid hex string: %q: %v", str, err))
	}
	return b
}

// TestRlpxFrameRW round-trips messages through two rlpx frame readers/writers
// that share one buffer. rw1 writes and rw2 reads. Both use the same random
// AES/MAC secrets, and their MAC states are seeded with swapped inputs, so
// rw2's ingress MAC tracks rw1's egress MAC.
func TestRlpxFrameRW(t *testing.T) {
	var (
		aesSecret      = make([]byte, 16)
		macSecret      = make([]byte, 16)
		egressMACinit  = make([]byte, 32)
		ingressMACinit = make([]byte, 32)
	)
	for _, s := range [][]byte{aesSecret, macSecret, egressMACinit, ingressMACinit} {
		// A failed read would leave the secret zeroed and silently weaken the test.
		if _, err := rand.Read(s); err != nil {
			t.Fatalf("can't generate random secret: %v", err)
		}
	}
	conn := new(bytes.Buffer)

	s1 := secrets{
		AES:        aesSecret,
		MAC:        macSecret,
		EgressMAC:  sha3.NewKeccak256(),
		IngressMAC: sha3.NewKeccak256(),
	}
	s1.EgressMAC.Write(egressMACinit)
	s1.IngressMAC.Write(ingressMACinit)
	rw1 := newRlpxFrameRW(conn, s1)

	// The mirror image of s1: the MAC seeds are swapped so that what rw1
	// sends on its egress side verifies on rw2's ingress side.
	s2 := secrets{
		AES:        aesSecret,
		MAC:        macSecret,
		EgressMAC:  sha3.NewKeccak256(),
		IngressMAC: sha3.NewKeccak256(),
	}
	s2.EgressMAC.Write(ingressMACinit)
	s2.IngressMAC.Write(egressMACinit)
	rw2 := newRlpxFrameRW(conn, s2)

	// send some messages
	for i := 0; i < 10; i++ {
		// write message into conn buffer
		wmsg := []interface{}{"foo", "bar", strings.Repeat("test", i)}
		err := EncodeMsg(rw1, uint64(i), wmsg)
		if err != nil {
			t.Fatalf("WriteMsg error (i=%d): %v", i, err)
		}

		// read message that rw1 just wrote
		msg, err := rw2.ReadMsg()
		if err != nil {
			t.Fatalf("ReadMsg error (i=%d): %v", i, err)
		}
		if msg.Code != uint64(i) {
			t.Fatalf("msg code mismatch: got %d, want %d", msg.Code, i)
		}
		payload, err := ioutil.ReadAll(msg.Payload)
		if err != nil {
			t.Fatalf("payload read error (i=%d): %v", i, err)
		}
		wantPayload, err := rlp.EncodeToBytes(wmsg)
		if err != nil {
			t.Fatalf("can't encode expected payload (i=%d): %v", i, err)
		}
		if !bytes.Equal(payload, wantPayload) {
			t.Fatalf("msg payload mismatch:\ngot  %x\nwant %x", payload, wantPayload)
		}
	}
}