aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/protocol_test.go
blob: b1d10ac5360f2ea61a8eb1cb87d6715b2deec944 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
package p2p

import (
    "fmt"
    "net"
    "reflect"
    "sync"
    "testing"

    "github.com/ethereum/go-ethereum/crypto"
)

// peerId is a test stand-in for a peer identity. Its public key is created
// lazily on first use (see Pubkey) and then cached in the struct.
type peerId struct {
	pubkey []byte
}

// String returns a short label containing, in hex, the first four bytes of
// the public key (or the whole key if it is shorter than four bytes).
// Calling String generates the key if none has been set yet.
func (id *peerId) String() string {
	key := id.Pubkey()
	// Guard the slice so a preset key shorter than four bytes cannot panic.
	if len(key) > 4 {
		key = key[:4]
	}
	return fmt.Sprintf("test peer %x", key)
}

// Pubkey returns the stored public key. If none is stored yet, it generates
// a fresh key pair with crypto.GenerateNewKeyPair, stores its public key and
// returns it, so later calls return the same key.
func (id *peerId) Pubkey() []byte {
	if len(id.pubkey) == 0 {
		id.pubkey = crypto.GenerateNewKeyPair().PublicKey
	}
	return id.pubkey
}

// newTestPeer builds a Peer for tests with fresh peerId identities, a
// pubkeyHook that always returns nil, an empty listen address and an
// otherPeers callback that reports no peers.
func newTestPeer() *Peer {
	p := NewPeer(&peerId{}, []Cap{})
	p.ourID = &peerId{}
	p.listenAddr = &peerAddr{}
	p.pubkeyHook = func(*peerAddr) error { return nil }
	p.otherPeers = func() []*Peer { return nil }
	return p
}

// TestBaseProtocolPeers connects two peers over an in-memory pipe and checks
// that peer2 receives, in order, the listen addresses of peer1's other peers
// followed by peer1's own listen address.
func TestBaseProtocolPeers(t *testing.T) {
	known := []*peerAddr{
		{IP: net.ParseIP("1.2.3.4"), Port: 2222, Pubkey: []byte{}},
		{IP: net.ParseIP("5.6.7.8"), Port: 3333, Pubkey: []byte{}},
	}
	ownAddr := &peerAddr{IP: net.ParseIP("1.3.5.7"), Port: 1111, Pubkey: []byte{}}
	rw1, rw2 := MsgPipe()
	defer rw1.Close()

	var wg sync.WaitGroup
	// All known peers are expected first, then peer1's own listen address.
	expected := len(known) + 1
	addrs := make(chan *peerAddr)

	// Matcher: compare each suggested address with the expected sequence and
	// close the pipe once all of them have arrived, which ends peer2's run.
	wg.Add(1)
	go func() {
		defer wg.Done()
		n := 0
		for got := range addrs {
			var want *peerAddr
			if n < len(known) {
				want = known[n]
			} else if n == len(known) {
				want = ownAddr // the own listen address must come last
			}
			t.Logf("got peer %d/%d: %v", n+1, expected, got)
			if !reflect.DeepEqual(want, got) {
				t.Errorf("mismatch: got %+v, want %+v", got, want)
			}
			if n++; n == expected {
				break
			}
		}
		if n != expected {
			t.Errorf("wrong number of peers received: got %d, want %d", n, expected)
		}
		rw1.Close()
	}()

	// peer1 advertises the known peers and its own listen address.
	peer1 := newTestPeer()
	peer1.ourListenAddr = ownAddr
	peer1.otherPeers = func() []*Peer {
		others := make([]*Peer, 0, len(known))
		for _, addr := range known {
			others = append(others, &Peer{listenAddr: addr})
		}
		return others
	}
	wg.Add(1)
	go func() {
		defer wg.Done()
		runBaseProtocol(peer1, rw1)
	}()

	// peer2 forwards every address suggestion into the matcher.
	peer2 := newTestPeer()
	peer2.newPeerAddr = addrs
	err := runBaseProtocol(peer2, rw2)
	if err != ErrPipeClosed {
		t.Errorf("peer2 terminated with unexpected error: %v", err)
	}

	// Stop the matcher if it is still ranging, then wait for both goroutines.
	close(addrs)
	wg.Wait()
}

// TestBaseProtocolDisconnect scripts the remote side of a connection:
// 1. It expects a handshake and answers with its own handshake.
// 2. It expects a getPeers request.
// 3. It sends a disconnect message with DiscQuitting.
// runBaseProtocol must then return a discRequestedError equal to DiscQuitting.
func TestBaseProtocolDisconnect(t *testing.T) {
	peer := NewPeer(&peerId{}, nil)
	peer.ourID = &peerId{}
	peer.pubkeyHook = func(*peerAddr) error { return nil }

	rw1, rw2 := MsgPipe()
	done := make(chan struct{})
	go func() {
		defer close(done)
		// fail records err, then closes the pipe and the caller stops the
		// script. runBaseProtocol therefore sees a closed pipe instead of
		// waiting forever for a message this goroutine will never send.
		fail := func(err error) {
			t.Error(err)
			rw2.Close()
		}
		if err := expectMsg(rw2, handshakeMsg); err != nil {
			fail(err)
			return
		}
		err := EncodeMsg(rw2, handshakeMsg,
			baseProtocolVersion,
			"",
			[]interface{}{},
			0,
			make([]byte, 64),
		)
		if err != nil {
			fail(err)
			return
		}
		if err := expectMsg(rw2, getPeersMsg); err != nil {
			fail(err)
			return
		}
		if err := EncodeMsg(rw2, discMsg, DiscQuitting); err != nil {
			fail(err)
		}
	}()

	err := runBaseProtocol(peer, rw1)
	// The protocol has returned, so close the pipe. If it stopped before the
	// script finished, this unblocks the goroutine rather than deadlocking
	// on <-done below.
	rw1.Close()
	if err == nil {
		t.Errorf("base protocol returned without error")
	} else if reason, ok := err.(discRequestedError); !ok || reason != DiscQuitting {
		t.Errorf("base protocol returned wrong error: %v", err)
	}
	<-done
}

// expectMsg reads the next message from r and discards its payload. It
// returns the error from reading or discarding, if any. Otherwise it returns
// an error when the message code differs from code, and nil when it matches.
// The payload is always discarded before the code is compared.
func expectMsg(r MsgReader, code uint64) error {
	msg, err := r.ReadMsg()
	if err != nil {
		return err
	}
	if err = msg.Discard(); err != nil {
		return err
	}
	if got := msg.Code; got != code {
		return fmt.Errorf("wrong message code: got %d, expected %d", got, code)
	}
	return nil
}