aboutsummaryrefslogblamecommitdiffstats
path: root/p2p/server_test.go
blob: 313d086ec4af8864efbdd1569bd635d1e941525c (plain) (tree)
1
2
3
4
5
6
7
8
9
                                         
                                                
  
                                                                                  


                                                                              
                                                                             
                                                                 
                                                               

                                                                           
                                                                                  
 

           
                      
                
                   
             
                 
                 
                                                
                                                     
                                                      
 



































                                                                                                                 




                                          
                          
                                     
                                                                                              
         
                                                           
         
                     
 
                                     
                                     



                                                                      
                             
                                                                 
                              






                                                                          
         
                          
 
                                 
                                                                            
                                                                                 
                                                                    
                 


                                                                                         
                                                                  

         
                                   
                                                              
                                                         
                                                             




                                              
                                                     
                 

                                
                           
                                     
                                                                          

                              
                                                 
                                                                                         

                                
                                  
                        
                                         




                                                               
                                                                                    
                                                                                         
                                                                            
                         


                                                                                                 
                                                                               
                 
 
                                                                   
         
 
















                                                                                            
                 
         
 

                                                          


                                              
         



                               
 

                                                 
         






                                                                                  
                 
         
 
                   
                

                                                                 
         
 
 
















































                                                                                                  


                                                                          
 






                                                                                                
                                                
 

                      
 

                                    
 




                                                             




                                                                        
         
                                                    
         
                        
 



                                                                                                       
 



                                                                      
                 
         


                                                                                 
         


                                                                                 
         
                                                          
         
 
 
























































                                                                                                         
         
 
                                    




                                                                
                                                                                     
                 

                                                                          
                         
                 






                                                                                                                          

         

                                       
 
                                         
 

                       
 






                                                                                                                  
         











                                                         
 












                                                              
// Copyright 2014 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package p2p

import (
    "crypto/ecdsa"
    "errors"
    "math/rand"
    "net"
    "reflect"
    "testing"
    "time"

    "github.com/ethereum/go-ethereum/crypto"
    "github.com/ethereum/go-ethereum/crypto/sha3"
    "github.com/ethereum/go-ethereum/p2p/discover"
)

func init() {
    // glog.SetV(6)
    // glog.SetToStderr(true)
}

type testTransport struct {
    id discover.NodeID
    *rlpx

    closeErr error
}

func newTestTransport(id discover.NodeID, fd net.Conn) transport {
    wrapped := newRLPX(fd).(*rlpx)
    wrapped.rw = newRLPXFrameRW(fd, secrets{
        MAC:        zero16,
        AES:        zero16,
        IngressMAC: sha3.NewKeccak256(),
        EgressMAC:  sha3.NewKeccak256(),
    })
    return &testTransport{id: id, rlpx: wrapped}
}

func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
    return c.id, nil
}

func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
    return &protoHandshake{ID: c.id, Name: "test"}, nil
}

func (c *testTransport) close(err error) {
    c.rlpx.fd.Close()
    c.closeErr = err
}

func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server {
    config := Config{
        Name:       "test",
        MaxPeers:   10,
        ListenAddr: "127.0.0.1:0",
        PrivateKey: newkey(),
    }
    server := &Server{
        Config:       config,
        newPeerHook:  pf,
        newTransport: func(fd net.Conn) transport { return newTestTransport(id, fd) },
    }
    if err := server.Start(); err != nil {
        t.Fatalf("Could not start server: %v", err)
    }
    return server
}

func TestServerListen(t *testing.T) {
    // start the test server
    connected := make(chan *Peer)
    remid := randomID()
    srv := startTestServer(t, remid, func(p *Peer) {
        if p.ID() != remid {
            t.Error("peer func called with wrong node id")
        }
        if p == nil {
            t.Error("peer func called with nil conn")
        }
        connected <- p
    })
    defer close(connected)
    defer srv.Stop()

    // dial the test server
    conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second)
    if err != nil {
        t.Fatalf("could not dial: %v", err)
    }
    defer conn.Close()

    select {
    case peer := <-connected:
        if peer.LocalAddr().String() != conn.RemoteAddr().String() {
            t.Errorf("peer started with wrong conn: got %v, want %v",
                peer.LocalAddr(), conn.RemoteAddr())
        }
        peers := srv.Peers()
        if !reflect.DeepEqual(peers, []*Peer{peer}) {
            t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
        }
    case <-time.After(1 * time.Second):
        t.Error("server did not accept within one second")
    }
}

func TestServerDial(t *testing.T) {
    // run a one-shot TCP server to handle the connection.
    listener, err := net.Listen("tcp", "127.0.0.1:0")
    if err != nil {
        t.Fatalf("could not setup listener: %v", err)
    }
    defer listener.Close()
    accepted := make(chan net.Conn)
    go func() {
        conn, err := listener.Accept()
        if err != nil {
            t.Error("accept error:", err)
            return
        }
        accepted <- conn
    }()

    // start the server
    connected := make(chan *Peer)
    remid := randomID()
    srv := startTestServer(t, remid, func(p *Peer) { connected <- p })
    defer close(connected)
    defer srv.Stop()

    // tell the server to connect
    tcpAddr := listener.Addr().(*net.TCPAddr)
    srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)})

    select {
    case conn := <-accepted:
        defer conn.Close()

        select {
        case peer := <-connected:
            if peer.ID() != remid {
                t.Errorf("peer has wrong id")
            }
            if peer.Name() != "test" {
                t.Errorf("peer has wrong name")
            }
            if peer.RemoteAddr().String() != conn.LocalAddr().String() {
                t.Errorf("peer started with wrong conn: got %v, want %v",
                    peer.RemoteAddr(), conn.LocalAddr())
            }
            peers := srv.Peers()
            if !reflect.DeepEqual(peers, []*Peer{peer}) {
                t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
            }
        case <-time.After(1 * time.Second):
            t.Error("server did not launch peer within one second")
        }

    case <-time.After(1 * time.Second):
        t.Error("server did not connect within one second")
    }
}

// This test checks that tasks generated by dialstate are
// actually executed and taskdone is called for them.
func TestServerTaskScheduling(t *testing.T) {
    var (
        done           = make(chan *testTask)
        quit, returned = make(chan struct{}), make(chan struct{})
        tc             = 0
        tg             = taskgen{
            newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
                tc++
                return []task{&testTask{index: tc - 1}}
            },
            doneFunc: func(t task) {
                select {
                case done <- t.(*testTask):
                case <-quit:
                }
            },
        }
    )

    // The Server in this test isn't actually running
    // because we're only interested in what run does.
    srv := &Server{
        Config:  Config{MaxPeers: 10},
        quit:    make(chan struct{}),
        ntab:    fakeTable{},
        running: true,
    }
    srv.loopWG.Add(1)
    go func() {
        srv.run(tg)
        close(returned)
    }()

    var gotdone []*testTask
    for i := 0; i < 100; i++ {
        gotdone = append(gotdone, <-done)
    }
    for i, task := range gotdone {
        if task.index != i {
            t.Errorf("task %d has wrong index, got %d", i, task.index)
            break
        }
        if !task.called {
            t.Errorf("task %d was not called", i)
            break
        }
    }

    close(quit)
    srv.Stop()
    select {
    case <-returned:
    case <-time.After(500 * time.Millisecond):
        t.Error("Server.run did not return within 500ms")
    }
}

// This test checks that Server doesn't drop tasks,
// even if newTasks returns more than the maximum number of tasks.
func TestServerManyTasks(t *testing.T) {
    alltasks := make([]task, 300)
    for i := range alltasks {
        alltasks[i] = &testTask{index: i}
    }

    var (
        srv        = &Server{quit: make(chan struct{}), ntab: fakeTable{}, running: true}
        done       = make(chan *testTask)
        start, end = 0, 0
    )
    defer srv.Stop()
    srv.loopWG.Add(1)
    go srv.run(taskgen{
        newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
            start, end = end, end+maxActiveDialTasks+10
            if end > len(alltasks) {
                end = len(alltasks)
            }
            return alltasks[start:end]
        },
        doneFunc: func(tt task) {
            done <- tt.(*testTask)
        },
    })

    doneset := make(map[int]bool)
    timeout := time.After(2 * time.Second)
    for len(doneset) < len(alltasks) {
        select {
        case tt := <-done:
            if doneset[tt.index] {
                t.Errorf("task %d got done more than once", tt.index)
            } else {
                doneset[tt.index] = true
            }
        case <-timeout:
            t.Errorf("%d of %d tasks got done within 2s", len(doneset), len(alltasks))
            for i := 0; i < len(alltasks); i++ {
                if !doneset[i] {
                    t.Logf("task %d not done", i)
                }
            }
            return
        }
    }
}

type taskgen struct {
    newFunc  func(running int, peers map[discover.NodeID]*Peer) []task
    doneFunc func(task)
}

func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task {
    return tg.newFunc(running, peers)
}
func (tg taskgen) taskDone(t task, now time.Time) {
    tg.doneFunc(t)
}
func (tg taskgen) addStatic(*discover.Node) {
}
func (tg taskgen) removeStatic(*discover.Node) {
}

type testTask struct {
    index  int
    called bool
}

func (t *testTask) Do(srv *Server) {
    t.called = true
}

// This test checks that connections are disconnected
// just after the encryption handshake when the server is
// at capacity. Trusted connections should still be accepted.
func TestServerAtCap(t *testing.T) {
    trustedID := randomID()
    srv := &Server{
        Config: Config{
            PrivateKey:   newkey(),
            MaxPeers:     10,
            NoDial:       true,
            TrustedNodes: []*discover.Node{{ID: trustedID}},
        },
    }
    if err := srv.Start(); err != nil {
        t.Fatalf("could not start: %v", err)
    }
    defer srv.Stop()

    newconn := func(id discover.NodeID) *conn {
        fd, _ := net.Pipe()
        tx := newTestTransport(id, fd)
        return &conn{fd: fd, transport: tx, flags: inboundConn, id: id, cont: make(chan error)}
    }

    // Inject a few connections to fill up the peer set.
    for i := 0; i < 10; i++ {
        c := newconn(randomID())
        if err := srv.checkpoint(c, srv.addpeer); err != nil {
            t.Fatalf("could not add conn %d: %v", i, err)
        }
    }
    // Try inserting a non-trusted connection.
    c := newconn(randomID())
    if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
        t.Error("wrong error for insert:", err)
    }
    // Try inserting a trusted connection.
    c = newconn(trustedID)
    if err := srv.checkpoint(c, srv.posthandshake); err != nil {
        t.Error("unexpected error for trusted conn @posthandshake:", err)
    }
    if !c.is(trustedConn) {
        t.Error("Server did not set trusted flag")
    }

}

func TestServerSetupConn(t *testing.T) {
    id := randomID()
    srvkey := newkey()
    srvid := discover.PubkeyID(&srvkey.PublicKey)
    tests := []struct {
        dontstart bool
        tt        *setupTransport
        flags     connFlag
        dialDest  *discover.Node

        wantCloseErr error
        wantCalls    string
    }{
        {
            dontstart:    true,
            tt:           &setupTransport{id: id},
            wantCalls:    "close,",
            wantCloseErr: errServerStopped,
        },
        {
            tt:           &setupTransport{id: id, encHandshakeErr: errors.New("read error")},
            flags:        inboundConn,
            wantCalls:    "doEncHandshake,close,",
            wantCloseErr: errors.New("read error"),
        },
        {
            tt:           &setupTransport{id: id},
            dialDest:     &discover.Node{ID: randomID()},
            flags:        dynDialedConn,
            wantCalls:    "doEncHandshake,close,",
            wantCloseErr: DiscUnexpectedIdentity,
        },
        {
            tt:           &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}},
            dialDest:     &discover.Node{ID: id},
            flags:        dynDialedConn,
            wantCalls:    "doEncHandshake,doProtoHandshake,close,",
            wantCloseErr: DiscUnexpectedIdentity,
        },
        {
            tt:           &setupTransport{id: id, protoHandshakeErr: errors.New("foo")},
            dialDest:     &discover.Node{ID: id},
            flags:        dynDialedConn,
            wantCalls:    "doEncHandshake,doProtoHandshake,close,",
            wantCloseErr: errors.New("foo"),
        },
        {
            tt:           &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}},
            flags:        inboundConn,
            wantCalls:    "doEncHandshake,close,",
            wantCloseErr: DiscSelf,
        },
        {
            tt:           &setupTransport{id: id, phs: &protoHandshake{ID: id}},
            flags:        inboundConn,
            wantCalls:    "doEncHandshake,doProtoHandshake,close,",
            wantCloseErr: DiscUselessPeer,
        },
    }

    for i, test := range tests {
        srv := &Server{
            Config: Config{
                PrivateKey: srvkey,
                MaxPeers:   10,
                NoDial:     true,
                Protocols:  []Protocol{discard},
            },
            newTransport: func(fd net.Conn) transport { return test.tt },
        }
        if !test.dontstart {
            if err := srv.Start(); err != nil {
                t.Fatalf("couldn't start server: %v", err)
            }
        }
        p1, _ := net.Pipe()
        srv.setupConn(p1, test.flags, test.dialDest)
        if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) {
            t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr)
        }
        if test.tt.calls != test.wantCalls {
            t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
        }
    }
}

type setupTransport struct {
    id              discover.NodeID
    encHandshakeErr error

    phs               *protoHandshake
    protoHandshakeErr error

    calls    string
    closeErr error
}

func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
    c.calls += "doEncHandshake,"
    return c.id, c.encHandshakeErr
}
func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
    c.calls += "doProtoHandshake,"
    if c.protoHandshakeErr != nil {
        return nil, c.protoHandshakeErr
    }
    return c.phs, nil
}
func (c *setupTransport) close(err error) {
    c.calls += "close,"
    c.closeErr = err
}

// setupConn shouldn't write to/read from the connection.
func (c *setupTransport) WriteMsg(Msg) error {
    panic("WriteMsg called on setupTransport")
}
func (c *setupTransport) ReadMsg() (Msg, error) {
    panic("ReadMsg called on setupTransport")
}

func newkey() *ecdsa.PrivateKey {
    key, err := crypto.GenerateKey()
    if err != nil {
        panic("couldn't generate key: " + err.Error())
    }
    return key
}

func randomID() (id discover.NodeID) {
    for i := range id {
        id[i] = byte(rand.Intn(255))
    }
    return id
}