aboutsummaryrefslogblamecommitdiffstats
path: root/p2p/server_test.go
blob: 01448cc7bb5559ad6efa7d93550d011a65cdca27 (plain) (tree)
1
2
3
4
5
6
7
8
9
10


           
                      
                
                   
             
                 

                 

                                                
                                                     
                                                      

 




































                                                                                                                 
                          





                                                                                              
         

                                                           
         
                     

 
                                     

                                     




                                                                      
                             

                                                                 
                              







                                                                          
         
                          
 

                                 
                                                                            
                                                                                 
                                                                    
                 



                                                                                         

                                                                  


         
                                   
                                                              








                                                         

                                                     




                                
                           
                                     

                                                                          


                              

                                                 
                                                                                         


                                
                        
                                         





                                                               
                                                                                    
                                                                                         
                                                                            
                         



                                                                                                 

                                                                               
                 
 

                                                                   
         

 

















                                                                                            
                 
         
 






                                                          
         




                               
 


                                                 
         







                                                                                  
                 
         
 

                   
                


                                                                 
         
 
 



                                                                          
 







                                                                                                
 


                      

 


                                    
 





                                                             
                                       
                                 
                                   
                                                                
         

                                                    
         
                        
 




                                                                                                       
 




                                                                      
                 
         



                                                                                 
         



                                                                                 
         

                                                          
         
 

 

























































                                                                                                         
         
 






                                                                                     
                 


                                                                          
                         
                 







                                                                                                                          


         


                                       
 

                                         
 


                       
 







                                                                                                                  
         












                                                         

 













                                                              
package p2p

import (
    "crypto/ecdsa"
    "errors"
    "math/rand"
    "net"
    "reflect"
    "testing"
    "time"

    "github.com/ethereum/go-ethereum/crypto"
    "github.com/ethereum/go-ethereum/crypto/sha3"
    "github.com/ethereum/go-ethereum/p2p/discover"
)

// init is intentionally empty. Uncommenting the glog calls below raises the
// glog verbosity and sends log output to stderr, which can help when
// debugging these tests.
func init() {
    // glog.SetV(6)
    // glog.SetToStderr(true)
}

// testTransport is a transport for tests that skips real handshakes: both
// doEncHandshake and doProtoHandshake answer with the fixed id, and close
// records the error it was given. Methods not defined on testTransport
// itself (e.g. message reads/writes) come from the embedded *rlpx.
type testTransport struct {
    // id is the remote identity reported by both handshakes.
    id discover.NodeID
    // Embedded rlpx connection that supplies the remaining transport methods.
    *rlpx

    // closeErr is the error most recently passed to close.
    closeErr error
}

// newTestTransport wraps fd in an rlpx connection whose frame reader/writer
// is keyed with the all-zero zero16 secrets and fresh Keccak-256 MAC states,
// so no key exchange is needed. The returned transport reports id as the
// remote identity.
func newTestTransport(id discover.NodeID, fd net.Conn) transport {
    base := newRLPX(fd).(*rlpx)
    keys := secrets{
        AES:        zero16,
        MAC:        zero16,
        IngressMAC: sha3.NewKeccak256(),
        EgressMAC:  sha3.NewKeccak256(),
    }
    base.rw = newRLPXFrameRW(fd, keys)
    return &testTransport{rlpx: base, id: id}
}

// doEncHandshake performs no encryption handshake. It ignores the private
// key and dial destination and reports the preconfigured id, never failing.
func (c *testTransport) doEncHandshake(_ *ecdsa.PrivateKey, _ *discover.Node) (discover.NodeID, error) {
    remote := c.id
    return remote, nil
}

// doProtoHandshake ignores our handshake and returns a canned remote
// handshake carrying the configured id and the name "test".
func (c *testTransport) doProtoHandshake(_ *protoHandshake) (*protoHandshake, error) {
    theirs := &protoHandshake{Name: "test", ID: c.id}
    return theirs, nil
}

// close shuts the underlying network connection and then remembers err so
// a test can inspect why the transport was closed. The error returned by
// the connection's Close is deliberately ignored.
func (c *testTransport) close(err error) {
    conn := c.rlpx.fd
    _ = conn.Close()
    c.closeErr = err
}

// startTestServer starts a Server named "test" on an ephemeral loopback
// port with a fresh key. Its newTransport factory wraps every connection
// with newTestTransport (so handshakes report id), and pf is installed as
// the new-peer hook. The test fails fatally if the server cannot start.
func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server {
    mkTransport := func(fd net.Conn) transport {
        return newTestTransport(id, fd)
    }
    srv := &Server{
        Name:         "test",
        MaxPeers:     10,
        ListenAddr:   "127.0.0.1:0",
        PrivateKey:   newkey(),
        newPeerHook:  pf,
        newTransport: mkTransport,
    }
    err := srv.Start()
    if err != nil {
        t.Fatalf("Could not start server: %v", err)
    }
    return srv
}

// TestServerListen checks that the server accepts an inbound TCP connection,
// hands the resulting peer to the new-peer hook with the expected node ID
// and address, and lists exactly that peer in Peers.
func TestServerListen(t *testing.T) {
    // start the test server
    connected := make(chan *Peer)
    remid := randomID()
    srv := startTestServer(t, remid, func(p *Peer) {
        // Check for nil before calling any method on p: p.ID() on a nil
        // peer would most likely panic before the problem is reported.
        if p == nil {
            t.Error("peer func called with nil conn")
            return
        }
        if p.ID() != remid {
            t.Error("peer func called with wrong node id")
        }
        connected <- p
    })
    // Deferred calls run in reverse, so srv.Stop runs before the channel
    // the hook sends on is closed.
    defer close(connected)
    defer srv.Stop()

    // dial the test server
    conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second)
    if err != nil {
        t.Fatalf("could not dial: %v", err)
    }
    defer conn.Close()

    select {
    case peer := <-connected:
        if peer.LocalAddr().String() != conn.RemoteAddr().String() {
            t.Errorf("peer started with wrong conn: got %v, want %v",
                peer.LocalAddr(), conn.RemoteAddr())
        }
        peers := srv.Peers()
        if !reflect.DeepEqual(peers, []*Peer{peer}) {
            t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
        }
    case <-time.After(1 * time.Second):
        t.Error("server did not accept within one second")
    }
}

// TestServerDial checks that AddPeer makes the server dial the given node,
// and that the resulting peer has the expected ID, name and remote address
// and is the only entry in Peers.
func TestServerDial(t *testing.T) {
    // run a one-shot TCP server to handle the connection.
    listener, err := net.Listen("tcp", "127.0.0.1:0")
    if err != nil {
        t.Fatalf("could not setup listener: %v", err)
    }
    defer listener.Close()
    // Buffered so the accept goroutine can always hand over its conn and
    // exit, even if the test has already stopped waiting for it.
    accepted := make(chan net.Conn, 1)
    go func() {
        conn, err := listener.Accept()
        if err != nil {
            t.Error("accept error:", err)
            return
        }
        conn.Close()
        accepted <- conn
    }()

    // start the server
    connected := make(chan *Peer)
    remid := randomID()
    srv := startTestServer(t, remid, func(p *Peer) { connected <- p })
    // Deferred calls run in reverse: srv.Stop runs before close(connected).
    defer close(connected)
    defer srv.Stop()

    // tell the server to connect
    tcpAddr := listener.Addr().(*net.TCPAddr)
    srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)})

    select {
    case conn := <-accepted:
        select {
        case peer := <-connected:
            if peer.ID() != remid {
                t.Errorf("peer has wrong id")
            }
            if peer.Name() != "test" {
                t.Errorf("peer has wrong name")
            }
            if peer.RemoteAddr().String() != conn.LocalAddr().String() {
                t.Errorf("peer started with wrong conn: got %v, want %v",
                    peer.RemoteAddr(), conn.LocalAddr())
            }
            peers := srv.Peers()
            if !reflect.DeepEqual(peers, []*Peer{peer}) {
                t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
            }
        case <-time.After(1 * time.Second):
            t.Error("server did not launch peer within one second")
        }

    case <-time.After(1 * time.Second):
        t.Error("server did not connect within one second")
    }
}

// TestServerTaskScheduling checks that tasks generated by the dial
// scheduler (stubbed here by taskgen) are actually executed via Do and that
// taskDone is called for each of them, in the order they were generated.
func TestServerTaskScheduling(t *testing.T) {
    // tc counts newTasks calls; every call yields exactly one task whose
    // index is the zero-based call number. doneFunc forwards each finished
    // task to done, but gives up once quit is closed so that run cannot
    // block on done after the test has stopped reading from it. returned is
    // closed when srv.run returns.
    var (
        done           = make(chan *testTask)
        quit, returned = make(chan struct{}), make(chan struct{})
        tc             = 0
        tg             = taskgen{
            newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
                tc++
                return []task{&testTask{index: tc - 1}}
            },
            doneFunc: func(t task) {
                select {
                case done <- t.(*testTask):
                case <-quit:
                }
            },
        }
    )

    // The Server in this test isn't actually running
    // because we're only interested in what run does.
    // Start is never called, so the fields below are set by hand;
    // NOTE(review): running: true presumably makes srv.Stop below treat the
    // server as started — confirm against Stop.
    srv := &Server{
        MaxPeers: 10,
        quit:     make(chan struct{}),
        ntab:     fakeTable{},
        running:  true,
    }
    // NOTE(review): presumably balances a loopWG.Done inside run, mirroring
    // what Start would do — confirm. Must happen before run is launched.
    srv.loopWG.Add(1)
    go func() {
        srv.run(tg)
        close(returned)
    }()

    // Collect the first 100 completed tasks.
    var gotdone []*testTask
    for i := 0; i < 100; i++ {
        gotdone = append(gotdone, <-done)
    }
    // Tasks must complete in generation order, and each must have run.
    for i, task := range gotdone {
        if task.index != i {
            t.Errorf("task %d has wrong index, got %d", i, task.index)
            break
        }
        if !task.called {
            t.Errorf("task %d was not called", i)
            break
        }
    }

    // Unblock any doneFunc still waiting on done, then stop the server and
    // give run a bounded amount of time to return.
    close(quit)
    srv.Stop()
    select {
    case <-returned:
    case <-time.After(500 * time.Millisecond):
        t.Error("Server.run did not return within 500ms")
    }
}

// taskgen is a scriptable stand-in for the dial scheduler: newTasks and
// taskDone simply delegate to these two function fields, letting a test
// decide which tasks are produced and observe when they finish.
type taskgen struct {
    // newFunc receives the running task count and the current peer set.
    newFunc func(int, map[discover.NodeID]*Peer) []task
    // doneFunc receives each task reported as finished.
    doneFunc func(task)
}

// newTasks ignores the timestamp and lets newFunc choose the next tasks
// from the running count and the connected peers.
func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, _ time.Time) []task {
    generated := tg.newFunc(running, peers)
    return generated
}
// taskDone ignores the timestamp and hands the finished task to doneFunc.
func (tg taskgen) taskDone(t task, _ time.Time) {
    notify := tg.doneFunc
    notify(t)
}
// addStatic is a no-op: taskgen does not track static nodes.
func (tg taskgen) addStatic(_ *discover.Node) {}

// testTask is a task stub. index records the position assigned by the
// generating test, and called is set to true once Do has run.
type testTask struct {
    index  int
    called bool
}

// Do ignores the server and only marks the task as executed.
func (tt *testTask) Do(_ *Server) {
    tt.called = true
}

// TestServerAtCap checks that connections are disconnected right after the
// encryption handshake when the server is at capacity. Once MaxPeers peers
// are registered, an ordinary inbound conn is refused with DiscTooManyPeers
// at the post-handshake checkpoint. A conn from a trusted node is still
// accepted there and is marked with trustedConn.
func TestServerAtCap(t *testing.T) {
    const maxPeers = 10
    trusted := randomID()
    srv := &Server{
        PrivateKey:   newkey(),
        MaxPeers:     maxPeers,
        NoDial:       true,
        TrustedNodes: []*discover.Node{{ID: trusted}},
    }
    if err := srv.Start(); err != nil {
        t.Fatalf("could not start: %v", err)
    }
    defer srv.Stop()

    // inbound wraps one end of an in-memory pipe in a test transport and
    // returns it as an inbound conn claiming identity id. The other pipe end
    // is discarded.
    inbound := func(id discover.NodeID) *conn {
        fd, _ := net.Pipe()
        return &conn{
            fd:        fd,
            transport: newTestTransport(id, fd),
            flags:     inboundConn,
            id:        id,
            cont:      make(chan error),
        }
    }

    // Fill every peer slot.
    for n := 0; n < maxPeers; n++ {
        if err := srv.checkpoint(inbound(randomID()), srv.addpeer); err != nil {
            t.Fatalf("could not add conn %d: %v", n, err)
        }
    }

    // A non-trusted connection must be refused.
    if err := srv.checkpoint(inbound(randomID()), srv.posthandshake); err != DiscTooManyPeers {
        t.Error("wrong error for insert:", err)
    }

    // A trusted connection must still get through and be flagged.
    tc := inbound(trusted)
    if err := srv.checkpoint(tc, srv.posthandshake); err != nil {
        t.Error("unexpected error for trusted conn @posthandshake:", err)
    }
    if !tc.is(trustedConn) {
        t.Error("Server did not set trusted flag")
    }
}

// TestServerSetupConn drives setupConn with scripted setupTransports and
// checks two things for each case:
//   - which transport methods were called, in order;
//   - the error the connection was closed with.
//
// The cases cover:
//   - a stopped server;
//   - a failing encryption handshake;
//   - identity mismatches for dialed connections;
//   - a failing protocol handshake;
//   - a connection to ourselves;
//   - a peer that shares no protocols.
func TestServerSetupConn(t *testing.T) {
    id := randomID()
    srvkey := newkey()
    srvid := discover.PubkeyID(&srvkey.PublicKey)
    tests := []struct {
        dontstart bool
        tt        *setupTransport
        flags     connFlag
        dialDest  *discover.Node

        wantCloseErr error
        wantCalls    string
    }{
        {
            dontstart:    true,
            tt:           &setupTransport{id: id},
            wantCalls:    "close,",
            wantCloseErr: errServerStopped,
        },
        {
            tt:           &setupTransport{id: id, encHandshakeErr: errors.New("read error")},
            flags:        inboundConn,
            wantCalls:    "doEncHandshake,close,",
            wantCloseErr: errors.New("read error"),
        },
        {
            tt:           &setupTransport{id: id},
            dialDest:     &discover.Node{ID: randomID()},
            flags:        dynDialedConn,
            wantCalls:    "doEncHandshake,close,",
            wantCloseErr: DiscUnexpectedIdentity,
        },
        {
            tt:           &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}},
            dialDest:     &discover.Node{ID: id},
            flags:        dynDialedConn,
            wantCalls:    "doEncHandshake,doProtoHandshake,close,",
            wantCloseErr: DiscUnexpectedIdentity,
        },
        {
            tt:           &setupTransport{id: id, protoHandshakeErr: errors.New("foo")},
            dialDest:     &discover.Node{ID: id},
            flags:        dynDialedConn,
            wantCalls:    "doEncHandshake,doProtoHandshake,close,",
            wantCloseErr: errors.New("foo"),
        },
        {
            tt:           &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}},
            flags:        inboundConn,
            wantCalls:    "doEncHandshake,close,",
            wantCloseErr: DiscSelf,
        },
        {
            tt:           &setupTransport{id: id, phs: &protoHandshake{ID: id}},
            flags:        inboundConn,
            wantCalls:    "doEncHandshake,doProtoHandshake,close,",
            wantCloseErr: DiscUselessPeer,
        },
    }

    for i, test := range tests {
        // Capture this case's transport explicitly. Before Go 1.22, the loop
        // variable `test` is shared across iterations, so a closure over it
        // would see later cases.
        tt := test.tt
        srv := &Server{
            PrivateKey:   srvkey,
            MaxPeers:     10,
            NoDial:       true,
            Protocols:    []Protocol{discard},
            newTransport: func(fd net.Conn) transport { return tt },
        }
        if !test.dontstart {
            if err := srv.Start(); err != nil {
                t.Fatalf("couldn't start server: %v", err)
            }
        }
        p1, p2 := net.Pipe()
        srv.setupConn(p1, test.flags, test.dialDest)
        // setupTransport.close only records the error and never touches the
        // pipe, so release both ends and stop the server explicitly. This
        // keeps each case from leaking goroutines into the next.
        p1.Close()
        p2.Close()
        if !test.dontstart {
            srv.Stop()
        }
        if !reflect.DeepEqual(tt.closeErr, test.wantCloseErr) {
            t.Errorf("test %d: close error mismatch: got %q, want %q", i, tt.closeErr, test.wantCloseErr)
        }
        if tt.calls != test.wantCalls {
            t.Errorf("test %d: calls mismatch: got %q, want %q", i, tt.calls, test.wantCalls)
        }
    }
}

// setupTransport is a scripted transport for TestServerSetupConn. Each
// handshake method appends its name to calls and returns its configured
// result, and close records the error it receives. A test can therefore
// assert exactly which steps setupConn performed and why it closed the
// connection.
type setupTransport struct {
    // Result returned by doEncHandshake.
    id              discover.NodeID
    encHandshakeErr error

    // Result returned by doProtoHandshake; protoHandshakeErr, when non-nil,
    // takes precedence over phs.
    phs               *protoHandshake
    protoHandshakeErr error

    // Observations: a comma-terminated list of the methods called, and the
    // error passed to close.
    calls    string
    closeErr error
}

// doEncHandshake logs the call and returns the scripted identity and error,
// ignoring the key and dial destination.
func (c *setupTransport) doEncHandshake(_ *ecdsa.PrivateKey, _ *discover.Node) (discover.NodeID, error) {
    c.calls = c.calls + "doEncHandshake,"
    remote, err := c.id, c.encHandshakeErr
    return remote, err
}
// doProtoHandshake logs the call and returns the scripted error if one is
// set, otherwise the scripted remote handshake.
func (c *setupTransport) doProtoHandshake(_ *protoHandshake) (*protoHandshake, error) {
    c.calls += "doProtoHandshake,"
    if err := c.protoHandshakeErr; err != nil {
        return nil, err
    }
    return c.phs, nil
}
// close records the error it was given and logs the call. It does not touch
// any underlying connection.
func (c *setupTransport) close(err error) {
    c.closeErr = err
    c.calls = c.calls + "close,"
}

// setupConn must never write to or read from the connection, so both
// message methods panic to make any such access fail the test loudly.
func (c *setupTransport) WriteMsg(Msg) error {
    const reason = "WriteMsg called on setupTransport"
    panic(reason)
}

func (c *setupTransport) ReadMsg() (Msg, error) {
    const reason = "ReadMsg called on setupTransport"
    panic(reason)
}

// newkey returns a freshly generated private key from crypto.GenerateKey.
// Tests cannot proceed without a key, so a generation failure panics.
func newkey() *ecdsa.PrivateKey {
    k, err := crypto.GenerateKey()
    if err == nil {
        return k
    }
    panic("couldn't generate key: " + err.Error())
}

// randomID returns a node ID filled with pseudo-random bytes from the
// default math/rand source. It is meant for tests only, not for anything
// security-relevant.
func randomID() (id discover.NodeID) {
    for i := range id {
        // rand.Intn's upper bound is exclusive: 256 (not 255) is needed so
        // every byte value, including 0xff, can occur.
        id[i] = byte(rand.Intn(256))
    }
    return id
}