aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/discover/udp_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/discover/udp_test.go')
-rw-r--r--p2p/discover/udp_test.go422
1 files changed, 275 insertions, 147 deletions
diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go
index 0a8ff6358..c6c4d78e3 100644
--- a/p2p/discover/udp_test.go
+++ b/p2p/discover/udp_test.go
@@ -1,10 +1,18 @@
package discover
import (
+ "bytes"
+ "crypto/ecdsa"
+ "errors"
"fmt"
+ "io"
logpkg "log"
"net"
"os"
+ "path"
+ "reflect"
+ "runtime"
+ "sync"
"testing"
"time"
@@ -15,197 +23,317 @@ func init() {
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, logpkg.LstdFlags, logger.ErrorLevel))
}
-func TestUDP_ping(t *testing.T) {
- t.Parallel()
-
- n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
- n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
- defer n1.Close()
- defer n2.Close()
+type udpTest struct {
+ t *testing.T
+ pipe *dgramPipe
+ table *Table
+ udp *udp
+ sent [][]byte
+ localkey, remotekey *ecdsa.PrivateKey
+ remoteaddr *net.UDPAddr
+}
- if err := n1.net.ping(n2.self); err != nil {
- t.Fatalf("ping error: %v", err)
+func newUDPTest(t *testing.T) *udpTest {
+ test := &udpTest{
+ t: t,
+ pipe: newpipe(),
+ localkey: newkey(),
+ remotekey: newkey(),
+ remoteaddr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 30303},
}
- if find(n2, n1.self.ID) == nil {
- t.Errorf("node 2 does not contain id of node 1")
+ test.table, test.udp = newUDP(test.localkey, test.pipe, nil)
+ return test
+}
+
+// handles a packet as if it had been sent to the transport.
+func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
+ enc, err := encodePacket(test.remotekey, ptype, data)
+ if err != nil {
+ return test.errorf("packet (%d) encode error: %v", err)
}
- if e := find(n1, n2.self.ID); e != nil {
- t.Errorf("node 1 does contains id of node 2: %v", e)
+ test.sent = append(test.sent, enc)
+ err = data.handle(test.udp, test.remoteaddr, PubkeyID(&test.remotekey.PublicKey), enc[:macSize])
+ if err != wantError {
+ return test.errorf("error mismatch: got %q, want %q", err, wantError)
}
+ return nil
}
-func find(tab *Table, id NodeID) *Node {
- for _, b := range tab.buckets {
- for _, e := range b.entries {
- if e.ID == id {
- return e
- }
- }
+// waits for a packet to be sent by the transport.
+// validate should have type func(*udpTest, X) error, where X is a packet type.
+func (test *udpTest) waitPacketOut(validate interface{}) error {
+ dgram := test.pipe.waitPacketOut()
+ p, _, _, err := decodePacket(dgram)
+ if err != nil {
+ return test.errorf("sent packet decode error: %v", err)
}
+ fn := reflect.ValueOf(validate)
+ exptype := fn.Type().In(0)
+ if reflect.TypeOf(p) != exptype {
+ return test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
+ }
+ fn.Call([]reflect.Value{reflect.ValueOf(p)})
return nil
}
-func TestUDP_findnode(t *testing.T) {
+func (test *udpTest) errorf(format string, args ...interface{}) error {
+ _, file, line, ok := runtime.Caller(2) // errorf + waitPacketOut
+ if ok {
+ file = path.Base(file)
+ } else {
+ file = "???"
+ line = 1
+ }
+ err := fmt.Errorf(format, args...)
+ fmt.Printf("\t%s:%d: %v\n", file, line, err)
+ test.t.Fail()
+ return err
+}
+
+// shared test variables
+var (
+ futureExp = uint64(time.Now().Add(10 * time.Hour).Unix())
+ testTarget = MustHexID("01010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101")
+)
+
+func TestUDP_packetErrors(t *testing.T) {
+ test := newUDPTest(t)
+ defer test.table.Close()
+
+ test.packetIn(errExpired, pingPacket, &ping{IP: "foo", Port: 99, Version: Version})
+ test.packetIn(errBadVersion, pingPacket, &ping{IP: "foo", Port: 99, Version: 99, Expiration: futureExp})
+ test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: []byte{}, Expiration: futureExp})
+ test.packetIn(errUnknownNode, findnodePacket, &findnode{Expiration: futureExp})
+ test.packetIn(errUnsolicitedReply, neighborsPacket, &neighbors{Expiration: futureExp})
+}
+
+func TestUDP_pingTimeout(t *testing.T) {
+ t.Parallel()
+ test := newUDPTest(t)
+ defer test.table.Close()
+
+ toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
+ toid := NodeID{1, 2, 3, 4}
+ if err := test.udp.ping(toid, toaddr); err != errTimeout {
+ t.Error("expected timeout error, got", err)
+ }
+}
+
+func TestUDP_findnodeTimeout(t *testing.T) {
t.Parallel()
+ test := newUDPTest(t)
+ defer test.table.Close()
- n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
- n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
- defer n1.Close()
- defer n2.Close()
+ toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
+ toid := NodeID{1, 2, 3, 4}
+ target := NodeID{4, 5, 6, 7}
+ result, err := test.udp.findnode(toid, toaddr, target)
+ if err != errTimeout {
+ t.Error("expected timeout error, got", err)
+ }
+ if len(result) > 0 {
+ t.Error("expected empty result, got", result)
+ }
+}
- // put a few nodes into n2. the exact distribution shouldn't
- // matter much, altough we need to take care not to overflow
- // any bucket.
- target := randomID(n1.self.ID, 100)
+func TestUDP_findnode(t *testing.T) {
+ test := newUDPTest(t)
+ defer test.table.Close()
+
+ // put a few nodes into the table. their exact
+ // distribution shouldn't matter much, altough we need to
+ // take care not to overflow any bucket.
+ target := testTarget
nodes := &nodesByDistance{target: target}
for i := 0; i < bucketSize; i++ {
- n2.add([]*Node{&Node{
+ nodes.push(&Node{
IP: net.IP{1, 2, 3, byte(i)},
DiscPort: i + 2,
TCPPort: i + 2,
- ID: randomID(n2.self.ID, i+2),
- }})
+ ID: randomID(test.table.self.ID, i+2),
+ }, bucketSize)
}
- n2.add(nodes.entries)
- n2.bumpOrAdd(n1.self.ID, &net.UDPAddr{IP: n1.self.IP, Port: n1.self.DiscPort})
- expected := n2.closest(target, bucketSize)
+ test.table.add(nodes.entries)
+
+ // ensure there's a bond with the test node,
+ // findnode won't be accepted otherwise.
+ test.table.db.add(PubkeyID(&test.remotekey.PublicKey), test.remoteaddr, 99)
- err := runUDP(10, func() error {
- result, _ := n1.net.findnode(n2.self, target)
- if len(result) != bucketSize {
- return fmt.Errorf("wrong number of results: got %d, want %d", len(result), bucketSize)
+ // check that closest neighbors are returned.
+ test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp})
+ test.waitPacketOut(func(p *neighbors) {
+ expected := test.table.closest(testTarget, bucketSize)
+ if len(p.Nodes) != bucketSize {
+ t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), bucketSize)
}
- for i := range result {
- if result[i].ID != expected.entries[i].ID {
- return fmt.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, result[i], expected.entries[i])
+ for i := range p.Nodes {
+ if p.Nodes[i].ID != expected.entries[i].ID {
+ t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, p.Nodes[i], expected.entries[i])
}
}
- return nil
})
- if err != nil {
- t.Error(err)
- }
}
-func TestUDP_replytimeout(t *testing.T) {
- t.Parallel()
+func TestUDP_findnodeMultiReply(t *testing.T) {
+ test := newUDPTest(t)
+ defer test.table.Close()
- // reserve a port so we don't talk to an existing service by accident
- addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
- fd, err := net.ListenUDP("udp", addr)
- if err != nil {
- t.Fatal(err)
- }
- defer fd.Close()
+ // queue a pending findnode request
+ resultc, errc := make(chan []*Node), make(chan error)
+ go func() {
+ rid := PubkeyID(&test.remotekey.PublicKey)
+ ns, err := test.udp.findnode(rid, test.remoteaddr, testTarget)
+ if err != nil && len(ns) == 0 {
+ errc <- err
+ } else {
+ resultc <- ns
+ }
+ }()
- n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
- defer n1.Close()
- n2 := n1.bumpOrAdd(randomID(n1.self.ID, 10), fd.LocalAddr().(*net.UDPAddr))
+ // wait for the findnode to be sent.
+ // after it is sent, the transport is waiting for a reply
+ test.waitPacketOut(func(p *findnode) {
+ if p.Target != testTarget {
+ t.Errorf("wrong target: got %v, want %v", p.Target, testTarget)
+ }
+ })
- if err := n1.net.ping(n2); err != errTimeout {
- t.Error("expected timeout error, got", err)
+ // send the reply as two packets.
+ list := []*Node{
+ MustParseNode("enode://ba85011c70bcc5c04d8607d3a0ed29aa6179c092cbdda10d5d32684fb33ed01bd94f588ca8f91ac48318087dcb02eaf36773a7a453f0eedd6742af668097b29c@10.0.1.16:30303"),
+ MustParseNode("enode://81fa361d25f157cd421c60dcc28d8dac5ef6a89476633339c5df30287474520caca09627da18543d9079b5b288698b542d56167aa5c09111e55acdbbdf2ef799@10.0.1.16:30303"),
+ MustParseNode("enode://9bffefd833d53fac8e652415f4973bee289e8b1a5c6c4cbe70abf817ce8a64cee11b823b66a987f51aaa9fba0d6a91b3e6bf0d5a5d1042de8e9eeea057b217f8@10.0.1.36:30301"),
+ MustParseNode("enode://1b5b4aa662d7cb44a7221bfba67302590b643028197a7d5214790f3bac7aaa4a3241be9e83c09cf1f6c69d007c634faae3dc1b1221793e8446c0b3a09de65960@10.0.1.16:30303"),
}
+ test.packetIn(nil, neighborsPacket, &neighbors{Expiration: futureExp, Nodes: list[:2]})
+ test.packetIn(nil, neighborsPacket, &neighbors{Expiration: futureExp, Nodes: list[2:]})
- if result, err := n1.net.findnode(n2, n1.self.ID); err != errTimeout {
- t.Error("expected timeout error, got", err)
- } else if len(result) > 0 {
- t.Error("expected empty result, got", result)
+ // check that the sent neighbors are all returned by findnode
+ select {
+ case result := <-resultc:
+ if !reflect.DeepEqual(result, list) {
+ t.Errorf("neighbors mismatch:\n got: %v\n want: %v", result, list)
+ }
+ case err := <-errc:
+ t.Errorf("findnode error: %v", err)
+ case <-time.After(5 * time.Second):
+ t.Error("findnode did not return within 5 seconds")
}
}
-func TestUDP_findnodeMultiReply(t *testing.T) {
- t.Parallel()
+func TestUDP_successfulPing(t *testing.T) {
+ test := newUDPTest(t)
+ defer test.table.Close()
- n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
- n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
- udp2 := n2.net.(*udp)
- defer n1.Close()
- defer n2.Close()
-
- err := runUDP(10, func() error {
- nodes := make([]*Node, bucketSize)
- for i := range nodes {
- nodes[i] = &Node{
- IP: net.IP{1, 2, 3, 4},
- DiscPort: i + 1,
- TCPPort: i + 1,
- ID: randomID(n2.self.ID, i+1),
- }
- }
+ done := make(chan struct{})
+ go func() {
+ test.packetIn(nil, pingPacket, &ping{IP: "foo", Port: 99, Version: Version, Expiration: futureExp})
+ close(done)
+ }()
- // ask N2 for neighbors. it will send an empty reply back.
- // the request will wait for up to bucketSize replies.
- resultc := make(chan []*Node)
- errc := make(chan error)
- go func() {
- ns, err := n1.net.findnode(n2.self, n1.self.ID)
- if err != nil {
- errc <- err
- } else {
- resultc <- ns
- }
- }()
-
- // send a few more neighbors packets to N1.
- // it should collect those.
- for end := 0; end < len(nodes); {
- off := end
- if end = end + 5; end > len(nodes) {
- end = len(nodes)
- }
- udp2.send(n1.self, neighborsPacket, neighbors{
- Nodes: nodes[off:end],
- Expiration: uint64(time.Now().Add(10 * time.Second).Unix()),
- })
+ // the ping is replied to.
+ test.waitPacketOut(func(p *pong) {
+ pinghash := test.sent[0][:macSize]
+ if !bytes.Equal(p.ReplyTok, pinghash) {
+ t.Errorf("got ReplyTok %x, want %x", p.ReplyTok, pinghash)
}
+ })
- // check that they are all returned. we cannot just check for
- // equality because they might not be returned in the order they
- // were sent.
- var result []*Node
- select {
- case result = <-resultc:
- case err := <-errc:
- return err
- }
- if hasDuplicates(result) {
- return fmt.Errorf("result slice contains duplicates")
- }
- if len(result) != len(nodes) {
- return fmt.Errorf("wrong number of nodes returned: got %d, want %d", len(result), len(nodes))
- }
- matched := make(map[NodeID]bool)
- for _, n := range result {
- for _, expn := range nodes {
- if n.ID == expn.ID { // && bytes.Equal(n.Addr.IP, expn.Addr.IP) && n.Addr.Port == expn.Addr.Port {
- matched[n.ID] = true
- }
+ // remote is unknown, the table pings back.
+ test.waitPacketOut(func(p *ping) error { return nil })
+ test.packetIn(nil, pongPacket, &pong{Expiration: futureExp})
+
+ // ping should return shortly after getting the pong packet.
+ <-done
+
+ // check that the node was added.
+ rid := PubkeyID(&test.remotekey.PublicKey)
+ rnode := find(test.table, rid)
+ if rnode == nil {
+ t.Fatalf("node %v not found in table", rid)
+ }
+ if !bytes.Equal(rnode.IP, test.remoteaddr.IP) {
+ t.Errorf("node has wrong IP: got %v, want: %v", rnode.IP, test.remoteaddr.IP)
+ }
+ if rnode.DiscPort != test.remoteaddr.Port {
+ t.Errorf("node has wrong Port: got %v, want: %v", rnode.DiscPort, test.remoteaddr.Port)
+ }
+ if rnode.TCPPort != 99 {
+ t.Errorf("node has wrong Port: got %v, want: %v", rnode.TCPPort, 99)
+ }
+}
+
+func find(tab *Table, id NodeID) *Node {
+ for _, b := range tab.buckets {
+ for _, e := range b.entries {
+ if e.ID == id {
+ return e
}
}
- if len(matched) != len(nodes) {
- return fmt.Errorf("wrong number of matching nodes: got %d, want %d", len(matched), len(nodes))
- }
- return nil
- })
- if err != nil {
- t.Error(err)
}
+ return nil
}
-// runUDP runs a test n times and returns an error if the test failed
-// in all n runs. This is necessary because UDP is unreliable even for
-// connections on the local machine, causing test failures.
-func runUDP(n int, test func() error) error {
- errcount := 0
- errors := ""
- for i := 0; i < n; i++ {
- if err := test(); err != nil {
- errors += fmt.Sprintf("\n#%d: %v", i, err)
- errcount++
- }
+// dgramPipe is a fake UDP socket. It queues all sent datagrams.
+type dgramPipe struct {
+ mu *sync.Mutex
+ cond *sync.Cond
+ closing chan struct{}
+ closed bool
+ queue [][]byte
+}
+
+func newpipe() *dgramPipe {
+ mu := new(sync.Mutex)
+ return &dgramPipe{
+ closing: make(chan struct{}),
+ cond: &sync.Cond{L: mu},
+ mu: mu,
}
- if errcount == n {
- return fmt.Errorf("failed on all %d iterations:%s", n, errors)
+}
+
+// WriteToUDP queues a datagram.
+func (c *dgramPipe) WriteToUDP(b []byte, to *net.UDPAddr) (n int, err error) {
+ msg := make([]byte, len(b))
+ copy(msg, b)
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.closed {
+ return 0, errors.New("closed")
+ }
+ c.queue = append(c.queue, msg)
+ c.cond.Signal()
+ return len(b), nil
+}
+
+// ReadFromUDP just hangs until the pipe is closed.
+func (c *dgramPipe) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) {
+ <-c.closing
+ return 0, nil, io.EOF
+}
+
+func (c *dgramPipe) Close() error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if !c.closed {
+ close(c.closing)
+ c.closed = true
}
return nil
}
+
+func (c *dgramPipe) LocalAddr() net.Addr {
+ return &net.UDPAddr{}
+}
+
+func (c *dgramPipe) waitPacketOut() []byte {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ for len(c.queue) == 0 {
+ c.cond.Wait()
+ }
+ p := c.queue[0]
+ copy(c.queue, c.queue[1:])
+ c.queue = c.queue[:len(c.queue)-1]
+ return p
+}