diff options
Diffstat (limited to 'p2p/server_test.go')
-rw-r--r-- | p2p/server_test.go | 179 |
1 files changed, 179 insertions, 0 deletions
diff --git a/p2p/server_test.go b/p2p/server_test.go new file mode 100644 index 000000000..30447050c --- /dev/null +++ b/p2p/server_test.go @@ -0,0 +1,179 @@ +package p2p + +import ( + "bytes" + "crypto/ecdsa" + "io" + "math/rand" + "net" + "sync" + "testing" + "time" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/p2p/discover" +) + +func startTestServer(t *testing.T, pf newPeerHook) *Server { + server := &Server{ + Name: "test", + MaxPeers: 10, + ListenAddr: "127.0.0.1:0", + PrivateKey: newkey(), + newPeerHook: pf, + setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) { + id := randomID() + rw := newRlpxFrameRW(fd, secrets{ + MAC: zero16, + AES: zero16, + IngressMAC: sha3.NewKeccak256(), + EgressMAC: sha3.NewKeccak256(), + }) + return &conn{ + MsgReadWriter: rw, + protoHandshake: &protoHandshake{ID: id, Version: baseProtocolVersion}, + }, nil + }, + } + if err := server.Start(); err != nil { + t.Fatalf("Could not start server: %v", err) + } + return server +} + +func TestServerListen(t *testing.T) { + defer testlog(t).detach() + + // start the test server + connected := make(chan *Peer) + srv := startTestServer(t, func(p *Peer) { + if p == nil { + t.Error("peer func called with nil conn") + } + connected <- p + }) + defer close(connected) + defer srv.Stop() + + // dial the test server + conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second) + if err != nil { + t.Fatalf("could not dial: %v", err) + } + defer conn.Close() + + select { + case peer := <-connected: + if peer.LocalAddr().String() != conn.RemoteAddr().String() { + t.Errorf("peer started with wrong conn: got %v, want %v", + peer.LocalAddr(), conn.RemoteAddr()) + } + case <-time.After(1 * time.Second): + t.Error("server did not accept within one second") + } +} + +func TestServerDial(t *testing.T) { + defer testlog(t).detach() + + // run a one-shot TCP server to handle the connection. + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("could not setup listener: %v") + } + defer listener.Close() + accepted := make(chan net.Conn) + go func() { + conn, err := listener.Accept() + if err != nil { + t.Error("accept error:", err) + return + } + conn.Close() + accepted <- conn + }() + + // start the server + connected := make(chan *Peer) + srv := startTestServer(t, func(p *Peer) { connected <- p }) + defer close(connected) + defer srv.Stop() + + // tell the server to connect + tcpAddr := listener.Addr().(*net.TCPAddr) + srv.SuggestPeer(&discover.Node{IP: tcpAddr.IP, TCPPort: tcpAddr.Port}) + + select { + case conn := <-accepted: + select { + case peer := <-connected: + if peer.RemoteAddr().String() != conn.LocalAddr().String() { + t.Errorf("peer started with wrong conn: got %v, want %v", + peer.RemoteAddr(), conn.LocalAddr()) + } + // TODO: validate more fields + case <-time.After(1 * time.Second): + t.Error("server did not launch peer within one second") + } + + case <-time.After(1 * time.Second): + t.Error("server did not connect within one second") + } +} + +func TestServerBroadcast(t *testing.T) { + defer testlog(t).detach() + + var connected sync.WaitGroup + srv := startTestServer(t, func(p *Peer) { + p.running = matchProtocols([]Protocol{discard}, []Cap{discard.cap()}, p.rw) + connected.Done() + }) + defer srv.Stop() + + // create a few peers + var conns = make([]net.Conn, 8) + connected.Add(len(conns)) + deadline := time.Now().Add(3 * time.Second) + dialer := &net.Dialer{Deadline: deadline} + for i := range conns { + conn, err := dialer.Dial("tcp", srv.ListenAddr) + if err != nil { + t.Fatalf("conn %d: dial error: %v", i, err) + } + defer conn.Close() + conn.SetDeadline(deadline) + conns[i] = conn + } + connected.Wait() + + // broadcast one message + srv.Broadcast("discard", 0, "foo") + golden := unhex("66e94d166f0a2c3b884cfa59ca34") + + // check that the message has been written everywhere + for i, conn := range conns { + buf := make([]byte, len(golden)) + if _, err := io.ReadFull(conn, buf); err != nil { + t.Errorf("conn %d: read error: %v", i, err) + } else if !bytes.Equal(buf, golden) { + t.Errorf("conn %d: msg mismatch\ngot: %x\nwant: %x", i, buf, golden) + } + } +} + +func newkey() *ecdsa.PrivateKey { + key, err := crypto.GenerateKey() + if err != nil { + panic("couldn't generate key: " + err.Error()) + } + return key +} + +func randomID() (id discover.NodeID) { + for i := range id { + id[i] = byte(rand.Intn(255)) + } + return id +} |