aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--p2p/dial.go276
-rw-r--r--p2p/dial_test.go482
-rw-r--r--p2p/handshake.go448
-rw-r--r--p2p/handshake_test.go172
-rw-r--r--p2p/peer.go55
-rw-r--r--p2p/peer_error.go56
-rw-r--r--p2p/peer_test.go20
-rw-r--r--p2p/rlpx.go444
-rw-r--r--p2p/rlpx_test.go244
-rw-r--r--p2p/server.go666
-rw-r--r--p2p/server_test.go584
11 files changed, 2118 insertions, 1329 deletions
diff --git a/p2p/dial.go b/p2p/dial.go
new file mode 100644
index 000000000..71065c5ee
--- /dev/null
+++ b/p2p/dial.go
@@ -0,0 +1,276 @@
+package p2p
+
+import (
+ "container/heap"
+ "crypto/rand"
+ "fmt"
+ "net"
+ "time"
+
+ "github.com/ethereum/go-ethereum/logger"
+ "github.com/ethereum/go-ethereum/logger/glog"
+ "github.com/ethereum/go-ethereum/p2p/discover"
+)
+
+const (
+ // This is the amount of time spent waiting in between
+ // redialing a certain node.
+ dialHistoryExpiration = 30 * time.Second
+
+ // Discovery lookup tasks will wait for this long when
+ // no results are returned. This can happen if the table
+ // becomes empty (i.e. not often).
+ emptyLookupDelay = 10 * time.Second
+)
+
+// dialstate schedules dials and discovery lookups.
+// it get's a chance to compute new tasks on every iteration
+// of the main loop in Server.run.
+type dialstate struct {
+ maxDynDials int
+ ntab discoverTable
+
+ lookupRunning bool
+ bootstrapped bool
+
+ dialing map[discover.NodeID]connFlag
+ lookupBuf []*discover.Node // current discovery lookup results
+ randomNodes []*discover.Node // filled from Table
+ static map[discover.NodeID]*discover.Node
+ hist *dialHistory
+}
+
+type discoverTable interface {
+ Self() *discover.Node
+ Close()
+ Bootstrap([]*discover.Node)
+ Lookup(target discover.NodeID) []*discover.Node
+ ReadRandomNodes([]*discover.Node) int
+}
+
+// the dial history remembers recent dials.
+type dialHistory []pastDial
+
+// pastDial is an entry in the dial history.
+type pastDial struct {
+ id discover.NodeID
+ exp time.Time
+}
+
+type task interface {
+ Do(*Server)
+}
+
+// A dialTask is generated for each node that is dialed.
+type dialTask struct {
+ flags connFlag
+ dest *discover.Node
+}
+
+// discoverTask runs discovery table operations.
+// Only one discoverTask is active at any time.
+//
+// If bootstrap is true, the task runs Table.Bootstrap,
+// otherwise it performs a random lookup and leaves the
+// results in the task.
+type discoverTask struct {
+ bootstrap bool
+ results []*discover.Node
+}
+
+// A waitExpireTask is generated if there are no other tasks
+// to keep the loop in Server.run ticking.
+type waitExpireTask struct {
+ time.Duration
+}
+
+func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int) *dialstate {
+ s := &dialstate{
+ maxDynDials: maxdyn,
+ ntab: ntab,
+ static: make(map[discover.NodeID]*discover.Node),
+ dialing: make(map[discover.NodeID]connFlag),
+ randomNodes: make([]*discover.Node, maxdyn/2),
+ hist: new(dialHistory),
+ }
+ for _, n := range static {
+ s.static[n.ID] = n
+ }
+ return s
+}
+
+func (s *dialstate) addStatic(n *discover.Node) {
+ s.static[n.ID] = n
+}
+
+func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task {
+ var newtasks []task
+ addDial := func(flag connFlag, n *discover.Node) bool {
+ _, dialing := s.dialing[n.ID]
+ if dialing || peers[n.ID] != nil || s.hist.contains(n.ID) {
+ return false
+ }
+ s.dialing[n.ID] = flag
+ newtasks = append(newtasks, &dialTask{flags: flag, dest: n})
+ return true
+ }
+
+ // Compute number of dynamic dials necessary at this point.
+ needDynDials := s.maxDynDials
+ for _, p := range peers {
+ if p.rw.is(dynDialedConn) {
+ needDynDials--
+ }
+ }
+ for _, flag := range s.dialing {
+ if flag&dynDialedConn != 0 {
+ needDynDials--
+ }
+ }
+
+ // Expire the dial history on every invocation.
+ s.hist.expire(now)
+
+ // Create dials for static nodes if they are not connected.
+ for _, n := range s.static {
+ addDial(staticDialedConn, n)
+ }
+
+ // Use random nodes from the table for half of the necessary
+ // dynamic dials.
+ randomCandidates := needDynDials / 2
+ if randomCandidates > 0 && s.bootstrapped {
+ n := s.ntab.ReadRandomNodes(s.randomNodes)
+ for i := 0; i < randomCandidates && i < n; i++ {
+ if addDial(dynDialedConn, s.randomNodes[i]) {
+ needDynDials--
+ }
+ }
+ }
+ // Create dynamic dials from random lookup results, removing tried
+ // items from the result buffer.
+ i := 0
+ for ; i < len(s.lookupBuf) && needDynDials > 0; i++ {
+ if addDial(dynDialedConn, s.lookupBuf[i]) {
+ needDynDials--
+ }
+ }
+ s.lookupBuf = s.lookupBuf[:copy(s.lookupBuf, s.lookupBuf[i:])]
+ // Launch a discovery lookup if more candidates are needed. The
+ // first discoverTask bootstraps the table and won't return any
+ // results.
+ if len(s.lookupBuf) < needDynDials && !s.lookupRunning {
+ s.lookupRunning = true
+ newtasks = append(newtasks, &discoverTask{bootstrap: !s.bootstrapped})
+ }
+
+ // Launch a timer to wait for the next node to expire if all
+ // candidates have been tried and no task is currently active.
+ // This should prevent cases where the dialer logic is not ticked
+ // because there are no pending events.
+ if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 {
+ t := &waitExpireTask{s.hist.min().exp.Sub(now)}
+ newtasks = append(newtasks, t)
+ }
+ return newtasks
+}
+
+func (s *dialstate) taskDone(t task, now time.Time) {
+ switch t := t.(type) {
+ case *dialTask:
+ s.hist.add(t.dest.ID, now.Add(dialHistoryExpiration))
+ delete(s.dialing, t.dest.ID)
+ case *discoverTask:
+ if t.bootstrap {
+ s.bootstrapped = true
+ }
+ s.lookupRunning = false
+ s.lookupBuf = append(s.lookupBuf, t.results...)
+ }
+}
+
+func (t *dialTask) Do(srv *Server) {
+ addr := &net.TCPAddr{IP: t.dest.IP, Port: int(t.dest.TCP)}
+ glog.V(logger.Debug).Infof("dialing %v\n", t.dest)
+ fd, err := srv.Dialer.Dial("tcp", addr.String())
+ if err != nil {
+ glog.V(logger.Detail).Infof("dial error: %v", err)
+ return
+ }
+ srv.setupConn(fd, t.flags, t.dest)
+}
+func (t *dialTask) String() string {
+ return fmt.Sprintf("%v %x %v:%d", t.flags, t.dest.ID[:8], t.dest.IP, t.dest.TCP)
+}
+
+func (t *discoverTask) Do(srv *Server) {
+ if t.bootstrap {
+ srv.ntab.Bootstrap(srv.BootstrapNodes)
+ } else {
+ var target discover.NodeID
+ rand.Read(target[:])
+ t.results = srv.ntab.Lookup(target)
+ // newTasks generates a lookup task whenever dynamic dials are
+ // necessary. Lookups need to take some time, otherwise the
+ // event loop spins too fast. An empty result can only be
+ // returned if the table is empty.
+ if len(t.results) == 0 {
+ time.Sleep(emptyLookupDelay)
+ }
+ }
+}
+
+func (t *discoverTask) String() (s string) {
+ if t.bootstrap {
+ s = "discovery bootstrap"
+ } else {
+ s = "discovery lookup"
+ }
+ if len(t.results) > 0 {
+ s += fmt.Sprintf(" (%d results)", len(t.results))
+ }
+ return s
+}
+
+func (t waitExpireTask) Do(*Server) {
+ time.Sleep(t.Duration)
+}
+func (t waitExpireTask) String() string {
+ return fmt.Sprintf("wait for dial hist expire (%v)", t.Duration)
+}
+
+// Use only these methods to access or modify dialHistory.
+func (h dialHistory) min() pastDial {
+ return h[0]
+}
+func (h *dialHistory) add(id discover.NodeID, exp time.Time) {
+ heap.Push(h, pastDial{id, exp})
+}
+func (h dialHistory) contains(id discover.NodeID) bool {
+ for _, v := range h {
+ if v.id == id {
+ return true
+ }
+ }
+ return false
+}
+func (h *dialHistory) expire(now time.Time) {
+ for h.Len() > 0 && h.min().exp.Before(now) {
+ heap.Pop(h)
+ }
+}
+
+// heap.Interface boilerplate
+func (h dialHistory) Len() int { return len(h) }
+func (h dialHistory) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) }
+func (h dialHistory) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
+func (h *dialHistory) Push(x interface{}) {
+ *h = append(*h, x.(pastDial))
+}
+func (h *dialHistory) Pop() interface{} {
+ old := *h
+ n := len(old)
+ x := old[n-1]
+ *h = old[0 : n-1]
+ return x
+}
diff --git a/p2p/dial_test.go b/p2p/dial_test.go
new file mode 100644
index 000000000..78568c5ed
--- /dev/null
+++ b/p2p/dial_test.go
@@ -0,0 +1,482 @@
+package p2p
+
+import (
+ "encoding/binary"
+ "reflect"
+ "testing"
+ "time"
+
+ "github.com/davecgh/go-spew/spew"
+ "github.com/ethereum/go-ethereum/p2p/discover"
+)
+
+func init() {
+ spew.Config.Indent = "\t"
+}
+
+type dialtest struct {
+ init *dialstate // state before and after the test.
+ rounds []round
+}
+
+type round struct {
+ peers []*Peer // current peer set
+ done []task // tasks that got done this round
+ new []task // the result must match this one
+}
+
+func runDialTest(t *testing.T, test dialtest) {
+ var (
+ vtime time.Time
+ running int
+ )
+ pm := func(ps []*Peer) map[discover.NodeID]*Peer {
+ m := make(map[discover.NodeID]*Peer)
+ for _, p := range ps {
+ m[p.rw.id] = p
+ }
+ return m
+ }
+ for i, round := range test.rounds {
+ for _, task := range round.done {
+ running--
+ if running < 0 {
+ panic("running task counter underflow")
+ }
+ test.init.taskDone(task, vtime)
+ }
+
+ new := test.init.newTasks(running, pm(round.peers), vtime)
+ if !sametasks(new, round.new) {
+ t.Errorf("round %d: new tasks mismatch:\ngot %v\nwant %v\nstate: %v\nrunning: %v\n",
+ i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running))
+ }
+
+ // Time advances by 16 seconds on every round.
+ vtime = vtime.Add(16 * time.Second)
+ running += len(new)
+ }
+}
+
+type fakeTable []*discover.Node
+
+func (t fakeTable) Self() *discover.Node { return new(discover.Node) }
+func (t fakeTable) Close() {}
+func (t fakeTable) Bootstrap([]*discover.Node) {}
+func (t fakeTable) Lookup(target discover.NodeID) []*discover.Node {
+ return nil
+}
+func (t fakeTable) ReadRandomNodes(buf []*discover.Node) int {
+ return copy(buf, t)
+}
+
+// This test checks that dynamic dials are launched from discovery results.
+func TestDialStateDynDial(t *testing.T) {
+ runDialTest(t, dialtest{
+ init: newDialState(nil, fakeTable{}, 5),
+ rounds: []round{
+ // A discovery query is launched.
+ {
+ peers: []*Peer{
+ {rw: &conn{flags: staticDialedConn, id: uintID(0)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
+ },
+ new: []task{&discoverTask{bootstrap: true}},
+ },
+ // Dynamic dials are launched when it completes.
+ {
+ peers: []*Peer{
+ {rw: &conn{flags: staticDialedConn, id: uintID(0)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
+ },
+ done: []task{
+ &discoverTask{bootstrap: true, results: []*discover.Node{
+ {ID: uintID(2)}, // this one is already connected and not dialed.
+ {ID: uintID(3)},
+ {ID: uintID(4)},
+ {ID: uintID(5)},
+ {ID: uintID(6)}, // these are not tried because max dyn dials is 5
+ {ID: uintID(7)}, // ...
+ }},
+ },
+ new: []task{
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}},
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}},
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}},
+ },
+ },
+ // Some of the dials complete but no new ones are launched yet because
+ // the sum of active dial count and dynamic peer count is == maxDynDials.
+ {
+ peers: []*Peer{
+ {rw: &conn{flags: staticDialedConn, id: uintID(0)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(3)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(4)}},
+ },
+ done: []task{
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}},
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}},
+ },
+ },
+ // No new dial tasks are launched in the this round because
+ // maxDynDials has been reached.
+ {
+ peers: []*Peer{
+ {rw: &conn{flags: staticDialedConn, id: uintID(0)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(3)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(4)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(5)}},
+ },
+ done: []task{
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}},
+ },
+ new: []task{
+ &waitExpireTask{Duration: 14 * time.Second},
+ },
+ },
+ // In this round, the peer with id 2 drops off. The query
+ // results from last discovery lookup are reused.
+ {
+ peers: []*Peer{
+ {rw: &conn{flags: staticDialedConn, id: uintID(0)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(3)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(4)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(5)}},
+ },
+ new: []task{
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(6)}},
+ },
+ },
+ // More peers (3,4) drop off and dial for ID 6 completes.
+ // The last query result from the discovery lookup is reused
+ // and a new one is spawned because more candidates are needed.
+ {
+ peers: []*Peer{
+ {rw: &conn{flags: staticDialedConn, id: uintID(0)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(5)}},
+ },
+ done: []task{
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(6)}},
+ },
+ new: []task{
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(7)}},
+ &discoverTask{},
+ },
+ },
+ // Peer 7 is connected, but there still aren't enough dynamic peers
+ // (4 out of 5). However, a discovery is already running, so ensure
+ // no new is started.
+ {
+ peers: []*Peer{
+ {rw: &conn{flags: staticDialedConn, id: uintID(0)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(5)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(7)}},
+ },
+ done: []task{
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(7)}},
+ },
+ },
+ // Finish the running node discovery with an empty set. A new lookup
+ // should be immediately requested.
+ {
+ peers: []*Peer{
+ {rw: &conn{flags: staticDialedConn, id: uintID(0)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(5)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(7)}},
+ },
+ done: []task{
+ &discoverTask{},
+ },
+ new: []task{
+ &discoverTask{},
+ },
+ },
+ },
+ })
+}
+
+func TestDialStateDynDialFromTable(t *testing.T) {
+ // This table always returns the same random nodes
+ // in the order given below.
+ table := fakeTable{
+ {ID: uintID(1)},
+ {ID: uintID(2)},
+ {ID: uintID(3)},
+ {ID: uintID(4)},
+ {ID: uintID(5)},
+ {ID: uintID(6)},
+ {ID: uintID(7)},
+ {ID: uintID(8)},
+ }
+
+ runDialTest(t, dialtest{
+ init: newDialState(nil, table, 10),
+ rounds: []round{
+ // Discovery bootstrap is launched.
+ {
+ new: []task{&discoverTask{bootstrap: true}},
+ },
+ // 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
+ {
+ done: []task{
+ &discoverTask{bootstrap: true},
+ },
+ new: []task{
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(1)}},
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(2)}},
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}},
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}},
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}},
+ &discoverTask{bootstrap: false},
+ },
+ },
+ // Dialing nodes 1,2 succeeds. Dials from the lookup are launched.
+ {
+ peers: []*Peer{
+ {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
+ },
+ done: []task{
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(1)}},
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(2)}},
+ &discoverTask{results: []*discover.Node{
+ {ID: uintID(10)},
+ {ID: uintID(11)},
+ {ID: uintID(12)},
+ }},
+ },
+ new: []task{
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(10)}},
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(11)}},
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(12)}},
+ &discoverTask{bootstrap: false},
+ },
+ },
+ // Dialing nodes 3,4,5 fails. The dials from the lookup succeed.
+ {
+ peers: []*Peer{
+ {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(10)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(11)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(12)}},
+ },
+ done: []task{
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}},
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}},
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}},
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(10)}},
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(11)}},
+ &dialTask{dynDialedConn, &discover.Node{ID: uintID(12)}},
+ },
+ },
+ // Waiting for expiry. No waitExpireTask is launched because the
+ // discovery query is still running.
+ {
+ peers: []*Peer{
+ {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(10)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(11)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(12)}},
+ },
+ },
+ // Nodes 3,4 are not tried again because only the first two
+ // returned random nodes (nodes 1,2) are tried and they're
+ // already connected.
+ {
+ peers: []*Peer{
+ {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(10)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(11)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(12)}},
+ },
+ },
+ },
+ })
+}
+
+// This test checks that static dials are launched.
+func TestDialStateStaticDial(t *testing.T) {
+ wantStatic := []*discover.Node{
+ {ID: uintID(1)},
+ {ID: uintID(2)},
+ {ID: uintID(3)},
+ {ID: uintID(4)},
+ {ID: uintID(5)},
+ }
+
+ runDialTest(t, dialtest{
+ init: newDialState(wantStatic, fakeTable{}, 0),
+ rounds: []round{
+ // Static dials are launched for the nodes that
+ // aren't yet connected.
+ {
+ peers: []*Peer{
+ {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
+ },
+ new: []task{
+ &dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}},
+ &dialTask{staticDialedConn, &discover.Node{ID: uintID(4)}},
+ &dialTask{staticDialedConn, &discover.Node{ID: uintID(5)}},
+ },
+ },
+ // No new tasks are launched in this round because all static
+ // nodes are either connected or still being dialed.
+ {
+ peers: []*Peer{
+ {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
+ {rw: &conn{flags: staticDialedConn, id: uintID(3)}},
+ },
+ done: []task{
+ &dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}},
+ },
+ },
+ // No new dial tasks are launched because all static
+ // nodes are now connected.
+ {
+ peers: []*Peer{
+ {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
+ {rw: &conn{flags: staticDialedConn, id: uintID(3)}},
+ {rw: &conn{flags: staticDialedConn, id: uintID(4)}},
+ {rw: &conn{flags: staticDialedConn, id: uintID(5)}},
+ },
+ done: []task{
+ &dialTask{staticDialedConn, &discover.Node{ID: uintID(4)}},
+ &dialTask{staticDialedConn, &discover.Node{ID: uintID(5)}},
+ },
+ new: []task{
+ &waitExpireTask{Duration: 14 * time.Second},
+ },
+ },
+ // Wait a round for dial history to expire, no new tasks should spawn.
+ {
+ peers: []*Peer{
+ {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
+ {rw: &conn{flags: staticDialedConn, id: uintID(3)}},
+ {rw: &conn{flags: staticDialedConn, id: uintID(4)}},
+ {rw: &conn{flags: staticDialedConn, id: uintID(5)}},
+ },
+ },
+ // If a static node is dropped, it should be immediately redialed,
+ // irrespective whether it was originally static or dynamic.
+ {
+ peers: []*Peer{
+ {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
+ {rw: &conn{flags: staticDialedConn, id: uintID(3)}},
+ {rw: &conn{flags: staticDialedConn, id: uintID(5)}},
+ },
+ new: []task{
+ &dialTask{staticDialedConn, &discover.Node{ID: uintID(2)}},
+ &dialTask{staticDialedConn, &discover.Node{ID: uintID(4)}},
+ },
+ },
+ },
+ })
+}
+
+// This test checks that past dials are not retried for some time.
+func TestDialStateCache(t *testing.T) {
+ wantStatic := []*discover.Node{
+ {ID: uintID(1)},
+ {ID: uintID(2)},
+ {ID: uintID(3)},
+ }
+
+ runDialTest(t, dialtest{
+ init: newDialState(wantStatic, fakeTable{}, 0),
+ rounds: []round{
+ // Static dials are launched for the nodes that
+ // aren't yet connected.
+ {
+ peers: nil,
+ new: []task{
+ &dialTask{staticDialedConn, &discover.Node{ID: uintID(1)}},
+ &dialTask{staticDialedConn, &discover.Node{ID: uintID(2)}},
+ &dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}},
+ },
+ },
+ // No new tasks are launched in this round because all static
+ // nodes are either connected or still being dialed.
+ {
+ peers: []*Peer{
+ {rw: &conn{flags: staticDialedConn, id: uintID(1)}},
+ {rw: &conn{flags: staticDialedConn, id: uintID(2)}},
+ },
+ done: []task{
+ &dialTask{staticDialedConn, &discover.Node{ID: uintID(1)}},
+ &dialTask{staticDialedConn, &discover.Node{ID: uintID(2)}},
+ },
+ },
+ // A salvage task is launched to wait for node 3's history
+ // entry to expire.
+ {
+ peers: []*Peer{
+ {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
+ },
+ done: []task{
+ &dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}},
+ },
+ new: []task{
+ &waitExpireTask{Duration: 14 * time.Second},
+ },
+ },
+ // Still waiting for node 3's entry to expire in the cache.
+ {
+ peers: []*Peer{
+ {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
+ },
+ },
+ // The cache entry for node 3 has expired and is retried.
+ {
+ peers: []*Peer{
+ {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
+ {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
+ },
+ new: []task{
+ &dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}},
+ },
+ },
+ },
+ })
+}
+
+// compares task lists but doesn't care about the order.
+func sametasks(a, b []task) bool {
+ if len(a) != len(b) {
+ return false
+ }
+next:
+ for _, ta := range a {
+ for _, tb := range b {
+ if reflect.DeepEqual(ta, tb) {
+ continue next
+ }
+ }
+ return false
+ }
+ return true
+}
+
+func uintID(i uint32) discover.NodeID {
+ var id discover.NodeID
+ binary.BigEndian.PutUint32(id[:], i)
+ return id
+}
diff --git a/p2p/handshake.go b/p2p/handshake.go
deleted file mode 100644
index 4cdcee6d4..000000000
--- a/p2p/handshake.go
+++ /dev/null
@@ -1,448 +0,0 @@
-package p2p
-
-import (
- "crypto/ecdsa"
- "crypto/elliptic"
- "crypto/rand"
- "errors"
- "fmt"
- "hash"
- "io"
- "net"
-
- "github.com/ethereum/go-ethereum/crypto"
- "github.com/ethereum/go-ethereum/crypto/ecies"
- "github.com/ethereum/go-ethereum/crypto/secp256k1"
- "github.com/ethereum/go-ethereum/crypto/sha3"
- "github.com/ethereum/go-ethereum/p2p/discover"
- "github.com/ethereum/go-ethereum/rlp"
-)
-
-const (
- sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2
- sigLen = 65 // elliptic S256
- pubLen = 64 // 512 bit pubkey in uncompressed representation without format byte
- shaLen = 32 // hash length (for nonce etc)
-
- authMsgLen = sigLen + shaLen + pubLen + shaLen + 1
- authRespLen = pubLen + shaLen + 1
-
- eciesBytes = 65 + 16 + 32
- encAuthMsgLen = authMsgLen + eciesBytes // size of the final ECIES payload sent as initiator's handshake
- encAuthRespLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake
-)
-
-// conn represents a remote connection after encryption handshake
-// and protocol handshake have completed.
-//
-// The MsgReadWriter is usually layered as follows:
-//
-// netWrapper (I/O timeouts, thread-safe ReadMsg, WriteMsg)
-// rlpxFrameRW (message encoding, encryption, authentication)
-// bufio.ReadWriter (buffering)
-// net.Conn (network I/O)
-//
-type conn struct {
- MsgReadWriter
- *protoHandshake
-}
-
-// secrets represents the connection secrets
-// which are negotiated during the encryption handshake.
-type secrets struct {
- RemoteID discover.NodeID
- AES, MAC []byte
- EgressMAC, IngressMAC hash.Hash
- Token []byte
-}
-
-// protoHandshake is the RLP structure of the protocol handshake.
-type protoHandshake struct {
- Version uint64
- Name string
- Caps []Cap
- ListenPort uint64
- ID discover.NodeID
-}
-
-// setupConn starts a protocol session on the given connection. It
-// runs the encryption handshake and the protocol handshake. If dial
-// is non-nil, the connection the local node is the initiator. If
-// keepconn returns false, the connection will be disconnected with
-// DiscTooManyPeers after the key exchange.
-func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, keepconn func(discover.NodeID) bool) (*conn, error) {
- if dial == nil {
- return setupInboundConn(fd, prv, our, keepconn)
- } else {
- return setupOutboundConn(fd, prv, our, dial, keepconn)
- }
-}
-
-func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, keepconn func(discover.NodeID) bool) (*conn, error) {
- secrets, err := receiverEncHandshake(fd, prv, nil)
- if err != nil {
- return nil, fmt.Errorf("encryption handshake failed: %v", err)
- }
- rw := newRlpxFrameRW(fd, secrets)
- if !keepconn(secrets.RemoteID) {
- SendItems(rw, discMsg, DiscTooManyPeers)
- return nil, errors.New("we have too many peers")
- }
- // Run the protocol handshake using authenticated messages.
- rhs, err := readProtocolHandshake(rw, secrets.RemoteID, our)
- if err != nil {
- return nil, err
- }
- if err := Send(rw, handshakeMsg, our); err != nil {
- return nil, fmt.Errorf("protocol handshake write error: %v", err)
- }
- return &conn{rw, rhs}, nil
-}
-
-func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, keepconn func(discover.NodeID) bool) (*conn, error) {
- secrets, err := initiatorEncHandshake(fd, prv, dial.ID, nil)
- if err != nil {
- return nil, fmt.Errorf("encryption handshake failed: %v", err)
- }
- rw := newRlpxFrameRW(fd, secrets)
- if !keepconn(secrets.RemoteID) {
- SendItems(rw, discMsg, DiscTooManyPeers)
- return nil, errors.New("we have too many peers")
- }
- // Run the protocol handshake using authenticated messages.
- //
- // Note that even though writing the handshake is first, we prefer
- // returning the handshake read error. If the remote side
- // disconnects us early with a valid reason, we should return it
- // as the error so it can be tracked elsewhere.
- werr := make(chan error, 1)
- go func() { werr <- Send(rw, handshakeMsg, our) }()
- rhs, err := readProtocolHandshake(rw, secrets.RemoteID, our)
- if err != nil {
- return nil, err
- }
- if err := <-werr; err != nil {
- return nil, fmt.Errorf("protocol handshake write error: %v", err)
- }
- if rhs.ID != dial.ID {
- return nil, errors.New("dialed node id mismatch")
- }
- return &conn{rw, rhs}, nil
-}
-
-// encHandshake contains the state of the encryption handshake.
-type encHandshake struct {
- initiator bool
- remoteID discover.NodeID
-
- remotePub *ecies.PublicKey // remote-pubk
- initNonce, respNonce []byte // nonce
- randomPrivKey *ecies.PrivateKey // ecdhe-random
- remoteRandomPub *ecies.PublicKey // ecdhe-random-pubk
-}
-
-// secrets is called after the handshake is completed.
-// It extracts the connection secrets from the handshake values.
-func (h *encHandshake) secrets(auth, authResp []byte) (secrets, error) {
- ecdheSecret, err := h.randomPrivKey.GenerateShared(h.remoteRandomPub, sskLen, sskLen)
- if err != nil {
- return secrets{}, err
- }
-
- // derive base secrets from ephemeral key agreement
- sharedSecret := crypto.Sha3(ecdheSecret, crypto.Sha3(h.respNonce, h.initNonce))
- aesSecret := crypto.Sha3(ecdheSecret, sharedSecret)
- s := secrets{
- RemoteID: h.remoteID,
- AES: aesSecret,
- MAC: crypto.Sha3(ecdheSecret, aesSecret),
- Token: crypto.Sha3(sharedSecret),
- }
-
- // setup sha3 instances for the MACs
- mac1 := sha3.NewKeccak256()
- mac1.Write(xor(s.MAC, h.respNonce))
- mac1.Write(auth)
- mac2 := sha3.NewKeccak256()
- mac2.Write(xor(s.MAC, h.initNonce))
- mac2.Write(authResp)
- if h.initiator {
- s.EgressMAC, s.IngressMAC = mac1, mac2
- } else {
- s.EgressMAC, s.IngressMAC = mac2, mac1
- }
-
- return s, nil
-}
-
-func (h *encHandshake) ecdhShared(prv *ecdsa.PrivateKey) ([]byte, error) {
- return ecies.ImportECDSA(prv).GenerateShared(h.remotePub, sskLen, sskLen)
-}
-
-// initiatorEncHandshake negotiates a session token on conn.
-// it should be called on the dialing side of the connection.
-//
-// prv is the local client's private key.
-// token is the token from a previous session with this node.
-func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID discover.NodeID, token []byte) (s secrets, err error) {
- h, err := newInitiatorHandshake(remoteID)
- if err != nil {
- return s, err
- }
- auth, err := h.authMsg(prv, token)
- if err != nil {
- return s, err
- }
- if _, err = conn.Write(auth); err != nil {
- return s, err
- }
-
- response := make([]byte, encAuthRespLen)
- if _, err = io.ReadFull(conn, response); err != nil {
- return s, err
- }
- if err := h.decodeAuthResp(response, prv); err != nil {
- return s, err
- }
- return h.secrets(auth, response)
-}
-
-func newInitiatorHandshake(remoteID discover.NodeID) (*encHandshake, error) {
- // generate random initiator nonce
- n := make([]byte, shaLen)
- if _, err := rand.Read(n); err != nil {
- return nil, err
- }
- // generate random keypair to use for signing
- randpriv, err := ecies.GenerateKey(rand.Reader, crypto.S256(), nil)
- if err != nil {
- return nil, err
- }
- rpub, err := remoteID.Pubkey()
- if err != nil {
- return nil, fmt.Errorf("bad remoteID: %v", err)
- }
- h := &encHandshake{
- initiator: true,
- remoteID: remoteID,
- remotePub: ecies.ImportECDSAPublic(rpub),
- initNonce: n,
- randomPrivKey: randpriv,
- }
- return h, nil
-}
-
-// authMsg creates an encrypted initiator handshake message.
-func (h *encHandshake) authMsg(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) {
- var tokenFlag byte
- if token == nil {
- // no session token found means we need to generate shared secret.
- // ecies shared secret is used as initial session token for new peers
- // generate shared key from prv and remote pubkey
- var err error
- if token, err = h.ecdhShared(prv); err != nil {
- return nil, err
- }
- } else {
- // for known peers, we use stored token from the previous session
- tokenFlag = 0x01
- }
-
- // sign known message:
- // ecdh-shared-secret^nonce for new peers
- // token^nonce for old peers
- signed := xor(token, h.initNonce)
- signature, err := crypto.Sign(signed, h.randomPrivKey.ExportECDSA())
- if err != nil {
- return nil, err
- }
-
- // encode auth message
- // signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag
- msg := make([]byte, authMsgLen)
- n := copy(msg, signature)
- n += copy(msg[n:], crypto.Sha3(exportPubkey(&h.randomPrivKey.PublicKey)))
- n += copy(msg[n:], crypto.FromECDSAPub(&prv.PublicKey)[1:])
- n += copy(msg[n:], h.initNonce)
- msg[n] = tokenFlag
-
- // encrypt auth message using remote-pubk
- return ecies.Encrypt(rand.Reader, h.remotePub, msg, nil, nil)
-}
-
-// decodeAuthResp decode an encrypted authentication response message.
-func (h *encHandshake) decodeAuthResp(auth []byte, prv *ecdsa.PrivateKey) error {
- msg, err := crypto.Decrypt(prv, auth)
- if err != nil {
- return fmt.Errorf("could not decrypt auth response (%v)", err)
- }
- h.respNonce = msg[pubLen : pubLen+shaLen]
- h.remoteRandomPub, err = importPublicKey(msg[:pubLen])
- if err != nil {
- return err
- }
- // ignore token flag for now
- return nil
-}
-
-// receiverEncHandshake negotiates a session token on conn.
-// it should be called on the listening side of the connection.
-//
-// prv is the local client's private key.
-// token is the token from a previous session with this node.
-func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, token []byte) (s secrets, err error) {
- // read remote auth sent by initiator.
- auth := make([]byte, encAuthMsgLen)
- if _, err := io.ReadFull(conn, auth); err != nil {
- return s, err
- }
- h, err := decodeAuthMsg(prv, token, auth)
- if err != nil {
- return s, err
- }
-
- // send auth response
- resp, err := h.authResp(prv, token)
- if err != nil {
- return s, err
- }
- if _, err = conn.Write(resp); err != nil {
- return s, err
- }
-
- return h.secrets(auth, resp)
-}
-
-func decodeAuthMsg(prv *ecdsa.PrivateKey, token []byte, auth []byte) (*encHandshake, error) {
- var err error
- h := new(encHandshake)
- // generate random keypair for session
- h.randomPrivKey, err = ecies.GenerateKey(rand.Reader, crypto.S256(), nil)
- if err != nil {
- return nil, err
- }
- // generate random nonce
- h.respNonce = make([]byte, shaLen)
- if _, err = rand.Read(h.respNonce); err != nil {
- return nil, err
- }
-
- msg, err := crypto.Decrypt(prv, auth)
- if err != nil {
- return nil, fmt.Errorf("could not decrypt auth message (%v)", err)
- }
-
- // decode message parameters
- // signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag
- h.initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1]
- copy(h.remoteID[:], msg[sigLen+shaLen:sigLen+shaLen+pubLen])
- rpub, err := h.remoteID.Pubkey()
- if err != nil {
- return nil, fmt.Errorf("bad remoteID: %#v", err)
- }
- h.remotePub = ecies.ImportECDSAPublic(rpub)
-
- // recover remote random pubkey from signed message.
- if token == nil {
- // TODO: it is an error if the initiator has a token and we don't. check that.
-
- // no session token means we need to generate shared secret.
- // ecies shared secret is used as initial session token for new peers.
- // generate shared key from prv and remote pubkey.
- if token, err = h.ecdhShared(prv); err != nil {
- return nil, err
- }
- }
- signedMsg := xor(token, h.initNonce)
- remoteRandomPub, err := secp256k1.RecoverPubkey(signedMsg, msg[:sigLen])
- if err != nil {
- return nil, err
- }
- h.remoteRandomPub, _ = importPublicKey(remoteRandomPub)
- return h, nil
-}
-
-// authResp generates the encrypted authentication response message.
-func (h *encHandshake) authResp(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) {
- // responder auth message
- // E(remote-pubk, ecdhe-random-pubk || nonce || 0x0)
- resp := make([]byte, authRespLen)
- n := copy(resp, exportPubkey(&h.randomPrivKey.PublicKey))
- n += copy(resp[n:], h.respNonce)
- if token == nil {
- resp[n] = 0
- } else {
- resp[n] = 1
- }
- // encrypt using remote-pubk
- return ecies.Encrypt(rand.Reader, h.remotePub, resp, nil, nil)
-}
-
-// importPublicKey unmarshals 512 bit public keys.
-func importPublicKey(pubKey []byte) (*ecies.PublicKey, error) {
- var pubKey65 []byte
- switch len(pubKey) {
- case 64:
- // add 'uncompressed key' flag
- pubKey65 = append([]byte{0x04}, pubKey...)
- case 65:
- pubKey65 = pubKey
- default:
- return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey))
- }
- // TODO: fewer pointless conversions
- return ecies.ImportECDSAPublic(crypto.ToECDSAPub(pubKey65)), nil
-}
-
-func exportPubkey(pub *ecies.PublicKey) []byte {
- if pub == nil {
- panic("nil pubkey")
- }
- return elliptic.Marshal(pub.Curve, pub.X, pub.Y)[1:]
-}
-
-func xor(one, other []byte) (xor []byte) {
- xor = make([]byte, len(one))
- for i := 0; i < len(one); i++ {
- xor[i] = one[i] ^ other[i]
- }
- return xor
-}
-
-func readProtocolHandshake(rw MsgReadWriter, wantID discover.NodeID, our *protoHandshake) (*protoHandshake, error) {
- msg, err := rw.ReadMsg()
- if err != nil {
- return nil, err
- }
- if msg.Code == discMsg {
- // disconnect before protocol handshake is valid according to the
- // spec and we send it ourself if Server.addPeer fails.
- var reason [1]DiscReason
- rlp.Decode(msg.Payload, &reason)
- return nil, reason[0]
- }
- if msg.Code != handshakeMsg {
- return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
- }
- if msg.Size > baseProtocolMaxMsgSize {
- return nil, fmt.Errorf("message too big (%d > %d)", msg.Size, baseProtocolMaxMsgSize)
- }
- var hs protoHandshake
- if err := msg.Decode(&hs); err != nil {
- return nil, err
- }
- // validate handshake info
- if hs.Version != our.Version {
- SendItems(rw, discMsg, DiscIncompatibleVersion)
- return nil, fmt.Errorf("required version %d, received %d\n", baseProtocolVersion, hs.Version)
- }
- if (hs.ID == discover.NodeID{}) {
- SendItems(rw, discMsg, DiscInvalidIdentity)
- return nil, errors.New("invalid public key in handshake")
- }
- if hs.ID != wantID {
- SendItems(rw, discMsg, DiscUnexpectedIdentity)
- return nil, errors.New("handshake node ID does not match encryption handshake")
- }
- return &hs, nil
-}
diff --git a/p2p/handshake_test.go b/p2p/handshake_test.go
deleted file mode 100644
index ab75921a3..000000000
--- a/p2p/handshake_test.go
+++ /dev/null
@@ -1,172 +0,0 @@
-package p2p
-
-import (
- "bytes"
- "crypto/rand"
- "fmt"
- "net"
- "reflect"
- "testing"
- "time"
-
- "github.com/ethereum/go-ethereum/crypto"
- "github.com/ethereum/go-ethereum/crypto/ecies"
- "github.com/ethereum/go-ethereum/p2p/discover"
-)
-
-func TestSharedSecret(t *testing.T) {
- prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
- pub0 := &prv0.PublicKey
- prv1, _ := crypto.GenerateKey()
- pub1 := &prv1.PublicKey
-
- ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
- if err != nil {
- return
- }
- ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
- if err != nil {
- return
- }
- t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
- if !bytes.Equal(ss0, ss1) {
- t.Errorf("dont match :(")
- }
-}
-
-func TestEncHandshake(t *testing.T) {
- for i := 0; i < 20; i++ {
- start := time.Now()
- if err := testEncHandshake(nil); err != nil {
- t.Fatalf("i=%d %v", i, err)
- }
- t.Logf("(without token) %d %v\n", i+1, time.Since(start))
- }
-
- for i := 0; i < 20; i++ {
- tok := make([]byte, shaLen)
- rand.Reader.Read(tok)
- start := time.Now()
- if err := testEncHandshake(tok); err != nil {
- t.Fatalf("i=%d %v", i, err)
- }
- t.Logf("(with token) %d %v\n", i+1, time.Since(start))
- }
-}
-
-func testEncHandshake(token []byte) error {
- type result struct {
- side string
- s secrets
- err error
- }
- var (
- prv0, _ = crypto.GenerateKey()
- prv1, _ = crypto.GenerateKey()
- rw0, rw1 = net.Pipe()
- output = make(chan result)
- )
-
- go func() {
- r := result{side: "initiator"}
- defer func() { output <- r }()
-
- pub1s := discover.PubkeyID(&prv1.PublicKey)
- r.s, r.err = initiatorEncHandshake(rw0, prv0, pub1s, token)
- if r.err != nil {
- return
- }
- id1 := discover.PubkeyID(&prv1.PublicKey)
- if r.s.RemoteID != id1 {
- r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id1)
- }
- }()
- go func() {
- r := result{side: "receiver"}
- defer func() { output <- r }()
-
- r.s, r.err = receiverEncHandshake(rw1, prv1, token)
- if r.err != nil {
- return
- }
- id0 := discover.PubkeyID(&prv0.PublicKey)
- if r.s.RemoteID != id0 {
- r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id0)
- }
- }()
-
- // wait for results from both sides
- r1, r2 := <-output, <-output
-
- if r1.err != nil {
- return fmt.Errorf("%s side error: %v", r1.side, r1.err)
- }
- if r2.err != nil {
- return fmt.Errorf("%s side error: %v", r2.side, r2.err)
- }
-
- // don't compare remote node IDs
- r1.s.RemoteID, r2.s.RemoteID = discover.NodeID{}, discover.NodeID{}
- // flip MACs on one of them so they compare equal
- r1.s.EgressMAC, r1.s.IngressMAC = r1.s.IngressMAC, r1.s.EgressMAC
- if !reflect.DeepEqual(r1.s, r2.s) {
- return fmt.Errorf("secrets mismatch:\n t1: %#v\n t2: %#v", r1.s, r2.s)
- }
- return nil
-}
-
-func TestSetupConn(t *testing.T) {
- prv0, _ := crypto.GenerateKey()
- prv1, _ := crypto.GenerateKey()
- node0 := &discover.Node{
- ID: discover.PubkeyID(&prv0.PublicKey),
- IP: net.IP{1, 2, 3, 4},
- TCP: 33,
- }
- node1 := &discover.Node{
- ID: discover.PubkeyID(&prv1.PublicKey),
- IP: net.IP{5, 6, 7, 8},
- TCP: 44,
- }
- hs0 := &protoHandshake{
- Version: baseProtocolVersion,
- ID: node0.ID,
- Caps: []Cap{{"a", 0}, {"b", 2}},
- }
- hs1 := &protoHandshake{
- Version: baseProtocolVersion,
- ID: node1.ID,
- Caps: []Cap{{"c", 1}, {"d", 3}},
- }
- fd0, fd1 := net.Pipe()
-
- done := make(chan struct{})
- keepalways := func(discover.NodeID) bool { return true }
- go func() {
- defer close(done)
- conn0, err := setupConn(fd0, prv0, hs0, node1, keepalways)
- if err != nil {
- t.Errorf("outbound side error: %v", err)
- return
- }
- if conn0.ID != node1.ID {
- t.Errorf("outbound conn id mismatch: got %v, want %v", conn0.ID, node1.ID)
- }
- if !reflect.DeepEqual(conn0.Caps, hs1.Caps) {
- t.Errorf("outbound caps mismatch: got %v, want %v", conn0.Caps, hs1.Caps)
- }
- }()
-
- conn1, err := setupConn(fd1, prv1, hs1, nil, keepalways)
- if err != nil {
- t.Fatalf("inbound side error: %v", err)
- }
- if conn1.ID != node0.ID {
- t.Errorf("inbound conn id mismatch: got %v, want %v", conn1.ID, node0.ID)
- }
- if !reflect.DeepEqual(conn1.Caps, hs0.Caps) {
- t.Errorf("inbound caps mismatch: got %v, want %v", conn1.Caps, hs0.Caps)
- }
-
- <-done
-}
diff --git a/p2p/peer.go b/p2p/peer.go
index 87a91d406..cbe5ccc84 100644
--- a/p2p/peer.go
+++ b/p2p/peer.go
@@ -33,9 +33,17 @@ const (
peersMsg = 0x05
)
+// protoHandshake is the RLP structure of the protocol handshake.
+type protoHandshake struct {
+ Version uint64
+ Name string
+ Caps []Cap
+ ListenPort uint64
+ ID discover.NodeID
+}
+
// Peer represents a connected remote node.
type Peer struct {
- conn net.Conn
rw *conn
running map[string]*protoRW
@@ -48,37 +56,36 @@ type Peer struct {
// NewPeer returns a peer for testing purposes.
func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
pipe, _ := net.Pipe()
- msgpipe, _ := MsgPipe()
- conn := &conn{msgpipe, &protoHandshake{ID: id, Name: name, Caps: caps}}
- peer := newPeer(pipe, conn, nil)
+ conn := &conn{fd: pipe, transport: nil, id: id, caps: caps, name: name}
+ peer := newPeer(conn, nil)
close(peer.closed) // ensures Disconnect doesn't block
return peer
}
// ID returns the node's public key.
func (p *Peer) ID() discover.NodeID {
- return p.rw.ID
+ return p.rw.id
}
// Name returns the node name that the remote node advertised.
func (p *Peer) Name() string {
- return p.rw.Name
+ return p.rw.name
}
// Caps returns the capabilities (supported subprotocols) of the remote peer.
func (p *Peer) Caps() []Cap {
// TODO: maybe return copy
- return p.rw.Caps
+ return p.rw.caps
}
// RemoteAddr returns the remote address of the network connection.
func (p *Peer) RemoteAddr() net.Addr {
- return p.conn.RemoteAddr()
+ return p.rw.fd.RemoteAddr()
}
// LocalAddr returns the local address of the network connection.
func (p *Peer) LocalAddr() net.Addr {
- return p.conn.LocalAddr()
+ return p.rw.fd.LocalAddr()
}
// Disconnect terminates the peer connection with the given reason.
@@ -92,13 +99,12 @@ func (p *Peer) Disconnect(reason DiscReason) {
// String implements fmt.Stringer.
func (p *Peer) String() string {
- return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr())
+ return fmt.Sprintf("Peer %x %v", p.rw.id[:8], p.RemoteAddr())
}
-func newPeer(fd net.Conn, conn *conn, protocols []Protocol) *Peer {
- protomap := matchProtocols(protocols, conn.Caps, conn)
+func newPeer(conn *conn, protocols []Protocol) *Peer {
+ protomap := matchProtocols(protocols, conn.caps, conn)
p := &Peer{
- conn: fd,
rw: conn,
running: protomap,
disc: make(chan DiscReason),
@@ -117,7 +123,10 @@ func (p *Peer) run() DiscReason {
p.startProtocols()
// Wait for an error or disconnect.
- var reason DiscReason
+ var (
+ reason DiscReason
+ requested bool
+ )
select {
case err := <-readErr:
if r, ok := err.(DiscReason); ok {
@@ -131,21 +140,17 @@ func (p *Peer) run() DiscReason {
case err := <-p.protoErr:
reason = discReasonForError(err)
case reason = <-p.disc:
- p.politeDisconnect(reason)
- reason = DiscRequested
+ requested = true
}
-
close(p.closed)
+ p.rw.close(reason)
p.wg.Wait()
- glog.V(logger.Debug).Infof("%v: Disconnected: %v\n", p, reason)
- return reason
-}
-func (p *Peer) politeDisconnect(reason DiscReason) {
- if reason != DiscNetworkError {
- SendItems(p.rw, discMsg, uint(reason))
+ if requested {
+ reason = DiscRequested
}
- p.conn.Close()
+ glog.V(logger.Debug).Infof("%v: Disconnected: %v\n", p, reason)
+ return reason
}
func (p *Peer) pingLoop() {
@@ -254,7 +259,7 @@ func (p *Peer) startProtocols() {
glog.V(logger.Detail).Infof("%v: Protocol %s/%d returned\n", p, proto.Name, proto.Version)
err = errors.New("protocol returned")
} else if err != io.EOF {
- glog.V(logger.Detail).Infof("%v: Protocol %s/%d error: \n", p, proto.Name, proto.Version, err)
+ glog.V(logger.Detail).Infof("%v: Protocol %s/%d error: %v\n", p, proto.Name, proto.Version, err)
}
p.protoErr <- err
p.wg.Done()
diff --git a/p2p/peer_error.go b/p2p/peer_error.go
index a912f6064..6938a9801 100644
--- a/p2p/peer_error.go
+++ b/p2p/peer_error.go
@@ -5,39 +5,17 @@ import (
)
const (
- errMagicTokenMismatch = iota
- errRead
- errWrite
- errMisc
- errInvalidMsgCode
+ errInvalidMsgCode = iota
errInvalidMsg
- errP2PVersionMismatch
- errPubkeyInvalid
- errPubkeyForbidden
- errProtocolBreach
- errPingTimeout
- errInvalidNetworkId
- errInvalidProtocolVersion
)
var errorToString = map[int]string{
- errMagicTokenMismatch: "magic token mismatch",
- errRead: "read error",
- errWrite: "write error",
- errMisc: "misc error",
- errInvalidMsgCode: "invalid message code",
- errInvalidMsg: "invalid message",
- errP2PVersionMismatch: "P2P Version Mismatch",
- errPubkeyInvalid: "public key invalid",
- errPubkeyForbidden: "public key forbidden",
- errProtocolBreach: "protocol Breach",
- errPingTimeout: "ping timeout",
- errInvalidNetworkId: "invalid network id",
- errInvalidProtocolVersion: "invalid protocol version",
+ errInvalidMsgCode: "invalid message code",
+ errInvalidMsg: "invalid message",
}
type peerError struct {
- Code int
+ code int
message string
}
@@ -107,23 +85,13 @@ func discReasonForError(err error) DiscReason {
return reason
}
peerError, ok := err.(*peerError)
- if !ok {
- return DiscSubprotocolError
- }
- switch peerError.Code {
- case errP2PVersionMismatch:
- return DiscIncompatibleVersion
- case errPubkeyInvalid:
- return DiscInvalidIdentity
- case errPubkeyForbidden:
- return DiscUselessPeer
- case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach:
- return DiscProtocolError
- case errPingTimeout:
- return DiscReadTimeout
- case errRead, errWrite:
- return DiscNetworkError
- default:
- return DiscSubprotocolError
+ if ok {
+ switch peerError.code {
+ case errInvalidMsgCode, errInvalidMsg:
+ return DiscProtocolError
+ default:
+ return DiscSubprotocolError
+ }
}
+ return DiscSubprotocolError
}
diff --git a/p2p/peer_test.go b/p2p/peer_test.go
index 0ac032ab7..7b772e198 100644
--- a/p2p/peer_test.go
+++ b/p2p/peer_test.go
@@ -28,24 +28,20 @@ var discard = Protocol{
}
func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan DiscReason) {
- fd1, _ := net.Pipe()
- hs1 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
- hs2 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
+ fd1, fd2 := net.Pipe()
+ c1 := &conn{fd: fd1, transport: newTestTransport(randomID(), fd1)}
+ c2 := &conn{fd: fd2, transport: newTestTransport(randomID(), fd2)}
for _, p := range protos {
- hs1.Caps = append(hs1.Caps, p.cap())
- hs2.Caps = append(hs2.Caps, p.cap())
+ c1.caps = append(c1.caps, p.cap())
+ c2.caps = append(c2.caps, p.cap())
}
- p1, p2 := MsgPipe()
- peer := newPeer(fd1, &conn{p1, hs1}, protos)
+ peer := newPeer(c1, protos)
errc := make(chan DiscReason, 1)
go func() { errc <- peer.run() }()
- closer := func() {
- p1.Close()
- fd1.Close()
- }
- return closer, &conn{p2, hs2}, peer, errc
+ closer := func() { c2.close(errors.New("close func called")) }
+ return closer, c2, peer, errc
}
func TestPeerProtoReadMsg(t *testing.T) {
diff --git a/p2p/rlpx.go b/p2p/rlpx.go
index 6b533e275..e1cb13aae 100644
--- a/p2p/rlpx.go
+++ b/p2p/rlpx.go
@@ -4,23 +4,459 @@ import (
"bytes"
"crypto/aes"
"crypto/cipher"
+ "crypto/ecdsa"
+ "crypto/elliptic"
"crypto/hmac"
+ "crypto/rand"
"errors"
+ "fmt"
"hash"
"io"
+ "net"
+ "sync"
+ "time"
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/crypto/ecies"
+ "github.com/ethereum/go-ethereum/crypto/secp256k1"
+ "github.com/ethereum/go-ethereum/crypto/sha3"
+ "github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/rlp"
)
+const (
+ maxUint24 = ^uint32(0) >> 8
+
+ sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2
+ sigLen = 65 // elliptic S256
+ pubLen = 64 // 512 bit pubkey in uncompressed representation without format byte
+ shaLen = 32 // hash length (for nonce etc)
+
+ authMsgLen = sigLen + shaLen + pubLen + shaLen + 1
+ authRespLen = pubLen + shaLen + 1
+
+ eciesBytes = 65 + 16 + 32
+ encAuthMsgLen = authMsgLen + eciesBytes // size of the final ECIES payload sent as initiator's handshake
+ encAuthRespLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake
+
+ // total timeout for encryption handshake and protocol
+ // handshake in both directions.
+ handshakeTimeout = 5 * time.Second
+
+ // This is the timeout for sending the disconnect reason.
+ // This is shorter than the usual timeout because we don't want
+ // to wait if the connection is known to be bad anyway.
+ discWriteTimeout = 1 * time.Second
+)
+
+// rlpx is the transport protocol used by actual (non-test) connections.
+// It wraps the frame encoder with locks and read/write deadlines.
+type rlpx struct {
+ fd net.Conn
+
+ rmu, wmu sync.Mutex
+ rw *rlpxFrameRW
+}
+
+func newRLPX(fd net.Conn) transport {
+ fd.SetDeadline(time.Now().Add(handshakeTimeout))
+ return &rlpx{fd: fd}
+}
+
+func (t *rlpx) ReadMsg() (Msg, error) {
+ t.rmu.Lock()
+ defer t.rmu.Unlock()
+ t.fd.SetReadDeadline(time.Now().Add(frameReadTimeout))
+ return t.rw.ReadMsg()
+}
+
+func (t *rlpx) WriteMsg(msg Msg) error {
+ t.wmu.Lock()
+ defer t.wmu.Unlock()
+ t.fd.SetWriteDeadline(time.Now().Add(frameWriteTimeout))
+ return t.rw.WriteMsg(msg)
+}
+
+func (t *rlpx) close(err error) {
+ t.wmu.Lock()
+ defer t.wmu.Unlock()
+ // Tell the remote end why we're disconnecting if possible.
+ if t.rw != nil {
+ if r, ok := err.(DiscReason); ok && r != DiscNetworkError {
+ t.fd.SetWriteDeadline(time.Now().Add(discWriteTimeout))
+ SendItems(t.rw, discMsg, r)
+ }
+ }
+ t.fd.Close()
+}
+
+// doEncHandshake runs the protocol handshake using authenticated
+// messages. the protocol handshake is the first authenticated message
+// and also verifies whether the encryption handshake 'worked' and the
+// remote side actually provided the right public key.
+func (t *rlpx) doProtoHandshake(our *protoHandshake) (their *protoHandshake, err error) {
+ // Writing our handshake happens concurrently, we prefer
+ // returning the handshake read error. If the remote side
+ // disconnects us early with a valid reason, we should return it
+ // as the error so it can be tracked elsewhere.
+ werr := make(chan error, 1)
+ go func() { werr <- Send(t.rw, handshakeMsg, our) }()
+ if their, err = readProtocolHandshake(t.rw, our); err != nil {
+ return nil, err
+ }
+ if err := <-werr; err != nil {
+ return nil, fmt.Errorf("write error: %v", err)
+ }
+ return their, nil
+}
+
+func readProtocolHandshake(rw MsgReader, our *protoHandshake) (*protoHandshake, error) {
+ msg, err := rw.ReadMsg()
+ if err != nil {
+ return nil, err
+ }
+ if msg.Size > baseProtocolMaxMsgSize {
+ return nil, fmt.Errorf("message too big")
+ }
+ if msg.Code == discMsg {
+ // Disconnect before protocol handshake is valid according to the
+ // spec and we send it ourself if the posthanshake checks fail.
+ // We can't return the reason directly, though, because it is echoed
+ // back otherwise. Wrap it in a string instead.
+ var reason [1]DiscReason
+ rlp.Decode(msg.Payload, &reason)
+ return nil, reason[0]
+ }
+ if msg.Code != handshakeMsg {
+ return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
+ }
+ var hs protoHandshake
+ if err := msg.Decode(&hs); err != nil {
+ return nil, err
+ }
+ // validate handshake info
+ if hs.Version != our.Version {
+ return nil, DiscIncompatibleVersion
+ }
+ if (hs.ID == discover.NodeID{}) {
+ return nil, DiscInvalidIdentity
+ }
+ return &hs, nil
+}
+
+func (t *rlpx) doEncHandshake(prv *ecdsa.PrivateKey, dial *discover.Node) (discover.NodeID, error) {
+ var (
+ sec secrets
+ err error
+ )
+ if dial == nil {
+ sec, err = receiverEncHandshake(t.fd, prv, nil)
+ } else {
+ sec, err = initiatorEncHandshake(t.fd, prv, dial.ID, nil)
+ }
+ if err != nil {
+ return discover.NodeID{}, err
+ }
+ t.wmu.Lock()
+ t.rw = newRLPXFrameRW(t.fd, sec)
+ t.wmu.Unlock()
+ return sec.RemoteID, nil
+}
+
+// encHandshake contains the state of the encryption handshake.
+type encHandshake struct {
+ initiator bool
+ remoteID discover.NodeID
+
+ remotePub *ecies.PublicKey // remote-pubk
+ initNonce, respNonce []byte // nonce
+ randomPrivKey *ecies.PrivateKey // ecdhe-random
+ remoteRandomPub *ecies.PublicKey // ecdhe-random-pubk
+}
+
+// secrets represents the connection secrets
+// which are negotiated during the encryption handshake.
+type secrets struct {
+ RemoteID discover.NodeID
+ AES, MAC []byte
+ EgressMAC, IngressMAC hash.Hash
+ Token []byte
+}
+
+// secrets is called after the handshake is completed.
+// It extracts the connection secrets from the handshake values.
+func (h *encHandshake) secrets(auth, authResp []byte) (secrets, error) {
+ ecdheSecret, err := h.randomPrivKey.GenerateShared(h.remoteRandomPub, sskLen, sskLen)
+ if err != nil {
+ return secrets{}, err
+ }
+
+ // derive base secrets from ephemeral key agreement
+ sharedSecret := crypto.Sha3(ecdheSecret, crypto.Sha3(h.respNonce, h.initNonce))
+ aesSecret := crypto.Sha3(ecdheSecret, sharedSecret)
+ s := secrets{
+ RemoteID: h.remoteID,
+ AES: aesSecret,
+ MAC: crypto.Sha3(ecdheSecret, aesSecret),
+ Token: crypto.Sha3(sharedSecret),
+ }
+
+ // setup sha3 instances for the MACs
+ mac1 := sha3.NewKeccak256()
+ mac1.Write(xor(s.MAC, h.respNonce))
+ mac1.Write(auth)
+ mac2 := sha3.NewKeccak256()
+ mac2.Write(xor(s.MAC, h.initNonce))
+ mac2.Write(authResp)
+ if h.initiator {
+ s.EgressMAC, s.IngressMAC = mac1, mac2
+ } else {
+ s.EgressMAC, s.IngressMAC = mac2, mac1
+ }
+
+ return s, nil
+}
+
+func (h *encHandshake) ecdhShared(prv *ecdsa.PrivateKey) ([]byte, error) {
+ return ecies.ImportECDSA(prv).GenerateShared(h.remotePub, sskLen, sskLen)
+}
+
+// initiatorEncHandshake negotiates a session token on conn.
+// it should be called on the dialing side of the connection.
+//
+// prv is the local client's private key.
+// token is the token from a previous session with this node.
+func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID discover.NodeID, token []byte) (s secrets, err error) {
+ h, err := newInitiatorHandshake(remoteID)
+ if err != nil {
+ return s, err
+ }
+ auth, err := h.authMsg(prv, token)
+ if err != nil {
+ return s, err
+ }
+ if _, err = conn.Write(auth); err != nil {
+ return s, err
+ }
+
+ response := make([]byte, encAuthRespLen)
+ if _, err = io.ReadFull(conn, response); err != nil {
+ return s, err
+ }
+ if err := h.decodeAuthResp(response, prv); err != nil {
+ return s, err
+ }
+ return h.secrets(auth, response)
+}
+
+func newInitiatorHandshake(remoteID discover.NodeID) (*encHandshake, error) {
+ // generate random initiator nonce
+ n := make([]byte, shaLen)
+ if _, err := rand.Read(n); err != nil {
+ return nil, err
+ }
+ // generate random keypair to use for signing
+ randpriv, err := ecies.GenerateKey(rand.Reader, crypto.S256(), nil)
+ if err != nil {
+ return nil, err
+ }
+ rpub, err := remoteID.Pubkey()
+ if err != nil {
+ return nil, fmt.Errorf("bad remoteID: %v", err)
+ }
+ h := &encHandshake{
+ initiator: true,
+ remoteID: remoteID,
+ remotePub: ecies.ImportECDSAPublic(rpub),
+ initNonce: n,
+ randomPrivKey: randpriv,
+ }
+ return h, nil
+}
+
+// authMsg creates an encrypted initiator handshake message.
+func (h *encHandshake) authMsg(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) {
+ var tokenFlag byte
+ if token == nil {
+ // no session token found means we need to generate shared secret.
+ // ecies shared secret is used as initial session token for new peers
+ // generate shared key from prv and remote pubkey
+ var err error
+ if token, err = h.ecdhShared(prv); err != nil {
+ return nil, err
+ }
+ } else {
+ // for known peers, we use stored token from the previous session
+ tokenFlag = 0x01
+ }
+
+ // sign known message:
+ // ecdh-shared-secret^nonce for new peers
+ // token^nonce for old peers
+ signed := xor(token, h.initNonce)
+ signature, err := crypto.Sign(signed, h.randomPrivKey.ExportECDSA())
+ if err != nil {
+ return nil, err
+ }
+
+ // encode auth message
+ // signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag
+ msg := make([]byte, authMsgLen)
+ n := copy(msg, signature)
+ n += copy(msg[n:], crypto.Sha3(exportPubkey(&h.randomPrivKey.PublicKey)))
+ n += copy(msg[n:], crypto.FromECDSAPub(&prv.PublicKey)[1:])
+ n += copy(msg[n:], h.initNonce)
+ msg[n] = tokenFlag
+
+ // encrypt auth message using remote-pubk
+ return ecies.Encrypt(rand.Reader, h.remotePub, msg, nil, nil)
+}
+
+// decodeAuthResp decode an encrypted authentication response message.
+func (h *encHandshake) decodeAuthResp(auth []byte, prv *ecdsa.PrivateKey) error {
+ msg, err := crypto.Decrypt(prv, auth)
+ if err != nil {
+ return fmt.Errorf("could not decrypt auth response (%v)", err)
+ }
+ h.respNonce = msg[pubLen : pubLen+shaLen]
+ h.remoteRandomPub, err = importPublicKey(msg[:pubLen])
+ if err != nil {
+ return err
+ }
+ // ignore token flag for now
+ return nil
+}
+
+// receiverEncHandshake negotiates a session token on conn.
+// it should be called on the listening side of the connection.
+//
+// prv is the local client's private key.
+// token is the token from a previous session with this node.
+func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, token []byte) (s secrets, err error) {
+ // read remote auth sent by initiator.
+ auth := make([]byte, encAuthMsgLen)
+ if _, err := io.ReadFull(conn, auth); err != nil {
+ return s, err
+ }
+ h, err := decodeAuthMsg(prv, token, auth)
+ if err != nil {
+ return s, err
+ }
+
+ // send auth response
+ resp, err := h.authResp(prv, token)
+ if err != nil {
+ return s, err
+ }
+ if _, err = conn.Write(resp); err != nil {
+ return s, err
+ }
+
+ return h.secrets(auth, resp)
+}
+
+func decodeAuthMsg(prv *ecdsa.PrivateKey, token []byte, auth []byte) (*encHandshake, error) {
+ var err error
+ h := new(encHandshake)
+ // generate random keypair for session
+ h.randomPrivKey, err = ecies.GenerateKey(rand.Reader, crypto.S256(), nil)
+ if err != nil {
+ return nil, err
+ }
+ // generate random nonce
+ h.respNonce = make([]byte, shaLen)
+ if _, err = rand.Read(h.respNonce); err != nil {
+ return nil, err
+ }
+
+ msg, err := crypto.Decrypt(prv, auth)
+ if err != nil {
+ return nil, fmt.Errorf("could not decrypt auth message (%v)", err)
+ }
+
+ // decode message parameters
+ // signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag
+ h.initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1]
+ copy(h.remoteID[:], msg[sigLen+shaLen:sigLen+shaLen+pubLen])
+ rpub, err := h.remoteID.Pubkey()
+ if err != nil {
+ return nil, fmt.Errorf("bad remoteID: %#v", err)
+ }
+ h.remotePub = ecies.ImportECDSAPublic(rpub)
+
+ // recover remote random pubkey from signed message.
+ if token == nil {
+ // TODO: it is an error if the initiator has a token and we don't. check that.
+
+ // no session token means we need to generate shared secret.
+ // ecies shared secret is used as initial session token for new peers.
+ // generate shared key from prv and remote pubkey.
+ if token, err = h.ecdhShared(prv); err != nil {
+ return nil, err
+ }
+ }
+ signedMsg := xor(token, h.initNonce)
+ remoteRandomPub, err := secp256k1.RecoverPubkey(signedMsg, msg[:sigLen])
+ if err != nil {
+ return nil, err
+ }
+ h.remoteRandomPub, _ = importPublicKey(remoteRandomPub)
+ return h, nil
+}
+
+// authResp generates the encrypted authentication response message.
+func (h *encHandshake) authResp(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) {
+ // responder auth message
+ // E(remote-pubk, ecdhe-random-pubk || nonce || 0x0)
+ resp := make([]byte, authRespLen)
+ n := copy(resp, exportPubkey(&h.randomPrivKey.PublicKey))
+ n += copy(resp[n:], h.respNonce)
+ if token == nil {
+ resp[n] = 0
+ } else {
+ resp[n] = 1
+ }
+ // encrypt using remote-pubk
+ return ecies.Encrypt(rand.Reader, h.remotePub, resp, nil, nil)
+}
+
+// importPublicKey unmarshals 512 bit public keys.
+func importPublicKey(pubKey []byte) (*ecies.PublicKey, error) {
+ var pubKey65 []byte
+ switch len(pubKey) {
+ case 64:
+ // add 'uncompressed key' flag
+ pubKey65 = append([]byte{0x04}, pubKey...)
+ case 65:
+ pubKey65 = pubKey
+ default:
+ return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey))
+ }
+ // TODO: fewer pointless conversions
+ return ecies.ImportECDSAPublic(crypto.ToECDSAPub(pubKey65)), nil
+}
+
+func exportPubkey(pub *ecies.PublicKey) []byte {
+ if pub == nil {
+ panic("nil pubkey")
+ }
+ return elliptic.Marshal(pub.Curve, pub.X, pub.Y)[1:]
+}
+
+func xor(one, other []byte) (xor []byte) {
+ xor = make([]byte, len(one))
+ for i := 0; i < len(one); i++ {
+ xor[i] = one[i] ^ other[i]
+ }
+ return xor
+}
+
var (
// this is used in place of actual frame header data.
// TODO: replace this when Msg contains the protocol type code.
zeroHeader = []byte{0xC2, 0x80, 0x80}
-
// sixteen zero bytes
zero16 = make([]byte, 16)
-
- maxUint24 = ^uint32(0) >> 8
)
// rlpxFrameRW implements a simplified version of RLPx framing.
@@ -38,7 +474,7 @@ type rlpxFrameRW struct {
ingressMAC hash.Hash
}
-func newRlpxFrameRW(conn io.ReadWriter, s secrets) *rlpxFrameRW {
+func newRLPXFrameRW(conn io.ReadWriter, s secrets) *rlpxFrameRW {
macc, err := aes.NewCipher(s.MAC)
if err != nil {
panic("invalid MAC secret: " + err.Error())
diff --git a/p2p/rlpx_test.go b/p2p/rlpx_test.go
index d98f1c2cd..44be46a99 100644
--- a/p2p/rlpx_test.go
+++ b/p2p/rlpx_test.go
@@ -3,19 +3,253 @@ package p2p
import (
"bytes"
"crypto/rand"
+ "errors"
+ "fmt"
"io/ioutil"
+ "net"
+ "reflect"
"strings"
+ "sync"
"testing"
+ "time"
+ "github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/crypto/ecies"
"github.com/ethereum/go-ethereum/crypto/sha3"
+ "github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/rlp"
)
-func TestRlpxFrameFake(t *testing.T) {
+func TestSharedSecret(t *testing.T) {
+ prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
+ pub0 := &prv0.PublicKey
+ prv1, _ := crypto.GenerateKey()
+ pub1 := &prv1.PublicKey
+
+ ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
+ if err != nil {
+ return
+ }
+ ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
+ if err != nil {
+ return
+ }
+ t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
+ if !bytes.Equal(ss0, ss1) {
+ t.Errorf("dont match :(")
+ }
+}
+
+func TestEncHandshake(t *testing.T) {
+ for i := 0; i < 10; i++ {
+ start := time.Now()
+ if err := testEncHandshake(nil); err != nil {
+ t.Fatalf("i=%d %v", i, err)
+ }
+ t.Logf("(without token) %d %v\n", i+1, time.Since(start))
+ }
+ for i := 0; i < 10; i++ {
+ tok := make([]byte, shaLen)
+ rand.Reader.Read(tok)
+ start := time.Now()
+ if err := testEncHandshake(tok); err != nil {
+ t.Fatalf("i=%d %v", i, err)
+ }
+ t.Logf("(with token) %d %v\n", i+1, time.Since(start))
+ }
+}
+
+func testEncHandshake(token []byte) error {
+ type result struct {
+ side string
+ id discover.NodeID
+ err error
+ }
+ var (
+ prv0, _ = crypto.GenerateKey()
+ prv1, _ = crypto.GenerateKey()
+ fd0, fd1 = net.Pipe()
+ c0, c1 = newRLPX(fd0).(*rlpx), newRLPX(fd1).(*rlpx)
+ output = make(chan result)
+ )
+
+ go func() {
+ r := result{side: "initiator"}
+ defer func() { output <- r }()
+
+ dest := &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey)}
+ r.id, r.err = c0.doEncHandshake(prv0, dest)
+ if r.err != nil {
+ return
+ }
+ id1 := discover.PubkeyID(&prv1.PublicKey)
+ if r.id != id1 {
+ r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id1)
+ }
+ }()
+ go func() {
+ r := result{side: "receiver"}
+ defer func() { output <- r }()
+
+ r.id, r.err = c1.doEncHandshake(prv1, nil)
+ if r.err != nil {
+ return
+ }
+ id0 := discover.PubkeyID(&prv0.PublicKey)
+ if r.id != id0 {
+ r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id0)
+ }
+ }()
+
+ // wait for results from both sides
+ r1, r2 := <-output, <-output
+ if r1.err != nil {
+ return fmt.Errorf("%s side error: %v", r1.side, r1.err)
+ }
+ if r2.err != nil {
+ return fmt.Errorf("%s side error: %v", r2.side, r2.err)
+ }
+
+ // compare derived secrets
+ if !reflect.DeepEqual(c0.rw.egressMAC, c1.rw.ingressMAC) {
+ return fmt.Errorf("egress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.egressMAC, c1.rw.ingressMAC)
+ }
+ if !reflect.DeepEqual(c0.rw.ingressMAC, c1.rw.egressMAC) {
+ return fmt.Errorf("ingress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.ingressMAC, c1.rw.egressMAC)
+ }
+ if !reflect.DeepEqual(c0.rw.enc, c1.rw.enc) {
+ return fmt.Errorf("enc cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.enc, c1.rw.enc)
+ }
+ if !reflect.DeepEqual(c0.rw.dec, c1.rw.dec) {
+ return fmt.Errorf("dec cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.dec, c1.rw.dec)
+ }
+ return nil
+}
+
+func TestProtocolHandshake(t *testing.T) {
+ var (
+ prv0, _ = crypto.GenerateKey()
+ node0 = &discover.Node{ID: discover.PubkeyID(&prv0.PublicKey), IP: net.IP{1, 2, 3, 4}, TCP: 33}
+ hs0 = &protoHandshake{Version: 3, ID: node0.ID, Caps: []Cap{{"a", 0}, {"b", 2}}}
+
+ prv1, _ = crypto.GenerateKey()
+ node1 = &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey), IP: net.IP{5, 6, 7, 8}, TCP: 44}
+ hs1 = &protoHandshake{Version: 3, ID: node1.ID, Caps: []Cap{{"c", 1}, {"d", 3}}}
+
+ fd0, fd1 = net.Pipe()
+ wg sync.WaitGroup
+ )
+
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+ rlpx := newRLPX(fd0)
+ remid, err := rlpx.doEncHandshake(prv0, node1)
+ if err != nil {
+ t.Errorf("dial side enc handshake failed: %v", err)
+ return
+ }
+ if remid != node1.ID {
+ t.Errorf("dial side remote id mismatch: got %v, want %v", remid, node1.ID)
+ return
+ }
+
+ phs, err := rlpx.doProtoHandshake(hs0)
+ if err != nil {
+ t.Errorf("dial side proto handshake error: %v", err)
+ return
+ }
+ if !reflect.DeepEqual(phs, hs1) {
+ t.Errorf("dial side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs1))
+ return
+ }
+ rlpx.close(DiscQuitting)
+ }()
+ go func() {
+ defer wg.Done()
+ rlpx := newRLPX(fd1)
+ remid, err := rlpx.doEncHandshake(prv1, nil)
+ if err != nil {
+ t.Errorf("listen side enc handshake failed: %v", err)
+ return
+ }
+ if remid != node0.ID {
+ t.Errorf("listen side remote id mismatch: got %v, want %v", remid, node0.ID)
+ return
+ }
+
+ phs, err := rlpx.doProtoHandshake(hs1)
+ if err != nil {
+ t.Errorf("listen side proto handshake error: %v", err)
+ return
+ }
+ if !reflect.DeepEqual(phs, hs0) {
+ t.Errorf("listen side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs0))
+ return
+ }
+
+ if err := ExpectMsg(rlpx, discMsg, []DiscReason{DiscQuitting}); err != nil {
+ t.Errorf("error receiving disconnect: %v", err)
+ }
+ }()
+ wg.Wait()
+}
+
+func TestProtocolHandshakeErrors(t *testing.T) {
+ our := &protoHandshake{Version: 3, Caps: []Cap{{"foo", 2}, {"bar", 3}}, Name: "quux"}
+ id := randomID()
+ tests := []struct {
+ code uint64
+ msg interface{}
+ err error
+ }{
+ {
+ code: discMsg,
+ msg: []DiscReason{DiscQuitting},
+ err: DiscQuitting,
+ },
+ {
+ code: 0x989898,
+ msg: []byte{1},
+ err: errors.New("expected handshake, got 989898"),
+ },
+ {
+ code: handshakeMsg,
+ msg: make([]byte, baseProtocolMaxMsgSize+2),
+ err: errors.New("message too big"),
+ },
+ {
+ code: handshakeMsg,
+ msg: []byte{1, 2, 3},
+ err: newPeerError(errInvalidMsg, "(code 0) (size 4) rlp: expected input list for p2p.protoHandshake"),
+ },
+ {
+ code: handshakeMsg,
+ msg: &protoHandshake{Version: 9944, ID: id},
+ err: DiscIncompatibleVersion,
+ },
+ {
+ code: handshakeMsg,
+ msg: &protoHandshake{Version: 3},
+ err: DiscInvalidIdentity,
+ },
+ }
+
+ for i, test := range tests {
+ p1, p2 := MsgPipe()
+ go Send(p1, test.code, test.msg)
+ _, err := readProtocolHandshake(p2, our)
+ if !reflect.DeepEqual(err, test.err) {
+ t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err)
+ }
+ }
+}
+
+func TestRLPXFrameFake(t *testing.T) {
buf := new(bytes.Buffer)
hash := fakeHash([]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})
- rw := newRlpxFrameRW(buf, secrets{
+ rw := newRLPXFrameRW(buf, secrets{
AES: crypto.Sha3(),
MAC: crypto.Sha3(),
IngressMAC: hash,
@@ -66,7 +300,7 @@ func (fakeHash) BlockSize() int { return 0 }
func (h fakeHash) Size() int { return len(h) }
func (h fakeHash) Sum(b []byte) []byte { return append(b, h...) }
-func TestRlpxFrameRW(t *testing.T) {
+func TestRLPXFrameRW(t *testing.T) {
var (
aesSecret = make([]byte, 16)
macSecret = make([]byte, 16)
@@ -86,7 +320,7 @@ func TestRlpxFrameRW(t *testing.T) {
}
s1.EgressMAC.Write(egressMACinit)
s1.IngressMAC.Write(ingressMACinit)
- rw1 := newRlpxFrameRW(conn, s1)
+ rw1 := newRLPXFrameRW(conn, s1)
s2 := secrets{
AES: aesSecret,
@@ -96,7 +330,7 @@ func TestRlpxFrameRW(t *testing.T) {
}
s2.EgressMAC.Write(ingressMACinit)
s2.IngressMAC.Write(egressMACinit)
- rw2 := newRlpxFrameRW(conn, s2)
+ rw2 := newRLPXFrameRW(conn, s2)
// send some messages
for i := 0; i < 10; i++ {
diff --git a/p2p/server.go b/p2p/server.go
index 529fedbca..27e617610 100644
--- a/p2p/server.go
+++ b/p2p/server.go
@@ -2,7 +2,6 @@ package p2p
import (
"crypto/ecdsa"
- "crypto/rand"
"errors"
"fmt"
"net"
@@ -24,11 +23,8 @@ const (
maxAcceptConns = 50
// Maximum number of concurrently dialing outbound connections.
- maxDialingConns = 10
+ maxActiveDialTasks = 16
- // total timeout for encryption handshake and protocol
- // handshake in both directions.
- handshakeTimeout = 5 * time.Second
// maximum time allowed for reading a complete message.
// this is effectively the amount of time a connection can be idle.
frameReadTimeout = 1 * time.Minute
@@ -36,6 +32,8 @@ const (
frameWriteTimeout = 5 * time.Second
)
+var errServerStopped = errors.New("server stopped")
+
var srvjslog = logger.NewJsonLogger()
// Server manages all peer connections.
@@ -103,68 +101,173 @@ type Server struct {
// Hooks for testing. These are useful because we can inhibit
// the whole protocol stack.
- setupFunc
- newPeerHook
+ newTransport func(net.Conn) transport
+ newPeerHook func(*Peer)
+
+ lock sync.Mutex // protects running
+ running bool
+ ntab discoverTable
+ listener net.Listener
ourHandshake *protoHandshake
- lock sync.RWMutex // protects running, peers and the trust fields
- running bool
- peers map[discover.NodeID]*Peer
- staticNodes map[discover.NodeID]*discover.Node // Map of currently maintained static remote nodes
- staticDial chan *discover.Node // Dial request channel reserved for the static nodes
- staticCycle time.Duration // Overrides staticPeerCheckInterval, used for testing
- trustedNodes map[discover.NodeID]bool // Set of currently trusted remote nodes
+ // These are for Peers, PeerCount (and nothing else).
+ peerOp chan peerOpFunc
+ peerOpDone chan struct{}
+
+ quit chan struct{}
+ addstatic chan *discover.Node
+ posthandshake chan *conn
+ addpeer chan *conn
+ delpeer chan *Peer
+ loopWG sync.WaitGroup // loop, listenLoop
+}
+
+type peerOpFunc func(map[discover.NodeID]*Peer)
+
+type connFlag int
- ntab *discover.Table
- listener net.Listener
+const (
+ dynDialedConn connFlag = 1 << iota
+ staticDialedConn
+ inboundConn
+ trustedConn
+)
+
+// conn wraps a network connection with information gathered
+// during the two handshakes.
+type conn struct {
+ fd net.Conn
+ transport
+ flags connFlag
+ cont chan error // The run loop uses cont to signal errors to setupConn.
+ id discover.NodeID // valid after the encryption handshake
+ caps []Cap // valid after the protocol handshake
+ name string // valid after the protocol handshake
+}
- quit chan struct{}
- loopWG sync.WaitGroup // {dial,listen,nat}Loop
- peerWG sync.WaitGroup // active peer goroutines
+type transport interface {
+ // The two handshakes.
+ doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error)
+ doProtoHandshake(our *protoHandshake) (*protoHandshake, error)
+ // The MsgReadWriter can only be used after the encryption
+ // handshake has completed. The code uses conn.id to track this
+ // by setting it to a non-nil value after the encryption handshake.
+ MsgReadWriter
+ // transports must provide Close because we use MsgPipe in some of
+ // the tests. Closing the actual network connection doesn't do
+ // anything in those tests because NsgPipe doesn't use it.
+ close(err error)
}
-type setupFunc func(net.Conn, *ecdsa.PrivateKey, *protoHandshake, *discover.Node, func(discover.NodeID) bool) (*conn, error)
-type newPeerHook func(*Peer)
+func (c *conn) String() string {
+ s := c.flags.String() + " conn"
+ if (c.id != discover.NodeID{}) {
+ s += fmt.Sprintf(" %x", c.id[:8])
+ }
+ s += " " + c.fd.RemoteAddr().String()
+ return s
+}
+
+func (f connFlag) String() string {
+ s := ""
+ if f&trustedConn != 0 {
+ s += " trusted"
+ }
+ if f&dynDialedConn != 0 {
+ s += " dyn dial"
+ }
+ if f&staticDialedConn != 0 {
+ s += " static dial"
+ }
+ if f&inboundConn != 0 {
+ s += " inbound"
+ }
+ if s != "" {
+ s = s[1:]
+ }
+ return s
+}
+
+func (c *conn) is(f connFlag) bool {
+ return c.flags&f != 0
+}
// Peers returns all connected peers.
-func (srv *Server) Peers() (peers []*Peer) {
- srv.lock.RLock()
- defer srv.lock.RUnlock()
- for _, peer := range srv.peers {
- if peer != nil {
- peers = append(peers, peer)
+func (srv *Server) Peers() []*Peer {
+ var ps []*Peer
+ select {
+ // Note: We'd love to put this function into a variable but
+ // that seems to cause a weird compiler error in some
+ // environments.
+ case srv.peerOp <- func(peers map[discover.NodeID]*Peer) {
+ for _, p := range peers {
+ ps = append(ps, p)
}
+ }:
+ <-srv.peerOpDone
+ case <-srv.quit:
}
- return
+ return ps
}
// PeerCount returns the number of connected peers.
func (srv *Server) PeerCount() int {
- srv.lock.RLock()
- n := len(srv.peers)
- srv.lock.RUnlock()
- return n
+ var count int
+ select {
+ case srv.peerOp <- func(ps map[discover.NodeID]*Peer) { count = len(ps) }:
+ <-srv.peerOpDone
+ case <-srv.quit:
+ }
+ return count
}
// AddPeer connects to the given node and maintains the connection until the
// server is shut down. If the connection fails for any reason, the server will
// attempt to reconnect the peer.
func (srv *Server) AddPeer(node *discover.Node) {
+ select {
+ case srv.addstatic <- node:
+ case <-srv.quit:
+ }
+}
+
+// Self returns the local node's endpoint information.
+func (srv *Server) Self() *discover.Node {
srv.lock.Lock()
defer srv.lock.Unlock()
+ if !srv.running {
+ return &discover.Node{IP: net.ParseIP("0.0.0.0")}
+ }
+ return srv.ntab.Self()
+}
- srv.staticNodes[node.ID] = node
+// Stop terminates the server and all active peer connections.
+// It blocks until all active connections have been closed.
+func (srv *Server) Stop() {
+ srv.lock.Lock()
+ defer srv.lock.Unlock()
+ if !srv.running {
+ return
+ }
+ srv.running = false
+ if srv.listener != nil {
+ // this unblocks listener Accept
+ srv.listener.Close()
+ }
+ close(srv.quit)
+ srv.loopWG.Wait()
}
// Start starts running the server.
-// Servers can be re-used and started again after stopping.
+// Servers can not be re-used after stopping.
func (srv *Server) Start() (err error) {
srv.lock.Lock()
defer srv.lock.Unlock()
if srv.running {
return errors.New("server already running")
}
+ srv.running = true
glog.V(logger.Info).Infoln("Starting Server")
// static fields
@@ -174,23 +277,19 @@ func (srv *Server) Start() (err error) {
if srv.MaxPeers <= 0 {
return fmt.Errorf("Server.MaxPeers must be > 0")
}
- srv.quit = make(chan struct{})
- srv.peers = make(map[discover.NodeID]*Peer)
-
- // Create the current trust maps, and the associated dialing channel
- srv.trustedNodes = make(map[discover.NodeID]bool)
- for _, node := range srv.TrustedNodes {
- srv.trustedNodes[node.ID] = true
- }
- srv.staticNodes = make(map[discover.NodeID]*discover.Node)
- for _, node := range srv.StaticNodes {
- srv.staticNodes[node.ID] = node
+ if srv.newTransport == nil {
+ srv.newTransport = newRLPX
}
- srv.staticDial = make(chan *discover.Node)
-
- if srv.setupFunc == nil {
- srv.setupFunc = setupConn
+ if srv.Dialer == nil {
+ srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
}
+ srv.quit = make(chan struct{})
+ srv.addpeer = make(chan *conn)
+ srv.delpeer = make(chan *Peer)
+ srv.posthandshake = make(chan *conn)
+ srv.addstatic = make(chan *discover.Node)
+ srv.peerOp = make(chan peerOpFunc)
+ srv.peerOpDone = make(chan struct{})
// node table
ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase)
@@ -198,37 +297,31 @@ func (srv *Server) Start() (err error) {
return err
}
srv.ntab = ntab
+ dialer := newDialState(srv.StaticNodes, srv.ntab, srv.MaxPeers/2)
// handshake
srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: ntab.Self().ID}
for _, p := range srv.Protocols {
srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap())
}
-
// listen/dial
if srv.ListenAddr != "" {
if err := srv.startListening(); err != nil {
return err
}
}
- if srv.Dialer == nil {
- srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
- }
- if !srv.NoDial {
- srv.loopWG.Add(1)
- go srv.dialLoop()
- }
if srv.NoDial && srv.ListenAddr == "" {
glog.V(logger.Warn).Infoln("I will be kind-of useless, neither dialing nor listening.")
}
- // maintain the static peers
- go srv.staticNodesLoop()
+ srv.loopWG.Add(1)
+ go srv.run(dialer)
srv.running = true
return nil
}
func (srv *Server) startListening() error {
+ // Launch the TCP listener.
listener, err := net.Listen("tcp", srv.ListenAddr)
if err != nil {
return err
@@ -238,6 +331,7 @@ func (srv *Server) startListening() error {
srv.listener = listener
srv.loopWG.Add(1)
go srv.listenLoop()
+ // Map the TCP listening port if NAT is configured.
if !laddr.IP.IsLoopback() && srv.NAT != nil {
srv.loopWG.Add(1)
go func() {
@@ -248,50 +342,164 @@ func (srv *Server) startListening() error {
return nil
}
-// Stop terminates the server and all active peer connections.
-// It blocks until all active connections have been closed.
-func (srv *Server) Stop() {
- srv.lock.Lock()
- if !srv.running {
- srv.lock.Unlock()
- return
+type dialer interface {
+ newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task
+ taskDone(task, time.Time)
+ addStatic(*discover.Node)
+}
+
+func (srv *Server) run(dialstate dialer) {
+ defer srv.loopWG.Done()
+ var (
+ peers = make(map[discover.NodeID]*Peer)
+ trusted = make(map[discover.NodeID]bool, len(srv.TrustedNodes))
+
+ tasks []task
+ pendingTasks []task
+ taskdone = make(chan task, maxActiveDialTasks)
+ )
+ // Put trusted nodes into a map to speed up checks.
+ // Trusted peers are loaded on startup and cannot be
+ // modified while the server is running.
+ for _, n := range srv.TrustedNodes {
+ trusted[n.ID] = true
+ }
+
+ // Some task list helpers.
+ delTask := func(t task) {
+ for i := range tasks {
+ if tasks[i] == t {
+ tasks = append(tasks[:i], tasks[i+1:]...)
+ break
+ }
+ }
}
- srv.running = false
- srv.lock.Unlock()
+ scheduleTasks := func(new []task) {
+ pt := append(pendingTasks, new...)
+ start := maxActiveDialTasks - len(tasks)
+ if len(pt) < start {
+ start = len(pt)
+ }
+ if start > 0 {
+ tasks = append(tasks, pt[:start]...)
+ for _, t := range pt[:start] {
+ t := t
+ glog.V(logger.Detail).Infoln("new task:", t)
+ go func() { t.Do(srv); taskdone <- t }()
+ }
+ copy(pt, pt[start:])
+ pendingTasks = pt[:len(pt)-start]
+ }
+ }
+
+running:
+ for {
+ // Query the dialer for new tasks and launch them.
+ now := time.Now()
+ nt := dialstate.newTasks(len(pendingTasks)+len(tasks), peers, now)
+ scheduleTasks(nt)
- glog.V(logger.Info).Infoln("Stopping Server")
+ select {
+ case <-srv.quit:
+ // The server was stopped. Run the cleanup logic.
+ glog.V(logger.Detail).Infoln("<-quit: spinning down")
+ break running
+ case n := <-srv.addstatic:
+ // This channel is used by AddPeer to add to the
+ // ephemeral static peer list. Add it to the dialer,
+ // it will keep the node connected.
+ glog.V(logger.Detail).Infoln("<-addstatic:", n)
+ dialstate.addStatic(n)
+ case op := <-srv.peerOp:
+ // This channel is used by Peers and PeerCount.
+ op(peers)
+ srv.peerOpDone <- struct{}{}
+ case t := <-taskdone:
+ // A task got done. Tell dialstate about it so it
+ // can update its state and remove it from the active
+ // tasks list.
+ glog.V(logger.Detail).Infoln("<-taskdone:", t)
+ dialstate.taskDone(t, now)
+ delTask(t)
+ case c := <-srv.posthandshake:
+ // A connection has passed the encryption handshake so
+ // the remote identity is known (but hasn't been verified yet).
+ if trusted[c.id] {
+ // Ensure that the trusted flag is set before checking against MaxPeers.
+ c.flags |= trustedConn
+ }
+ glog.V(logger.Detail).Infoln("<-posthandshake:", c)
+ // TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them.
+ c.cont <- srv.encHandshakeChecks(peers, c)
+ case c := <-srv.addpeer:
+ // At this point the connection is past the protocol handshake.
+ // Its capabilities are known and the remote identity is verified.
+ glog.V(logger.Detail).Infoln("<-addpeer:", c)
+ err := srv.protoHandshakeChecks(peers, c)
+ if err != nil {
+ glog.V(logger.Detail).Infof("Not adding %v as peer: %v", c, err)
+ } else {
+ // The handshakes are done and it passed all checks.
+ p := newPeer(c, srv.Protocols)
+ peers[c.id] = p
+ go srv.runPeer(p)
+ }
+ // The dialer logic relies on the assumption that
+ // dial tasks complete after the peer has been added or
+ // discarded. Unblock the task last.
+ c.cont <- err
+ case p := <-srv.delpeer:
+ // A peer disconnected.
+ glog.V(logger.Detail).Infoln("<-delpeer:", p)
+ delete(peers, p.ID())
+ }
+ }
+
+ // Terminate discovery. If there is a running lookup it will terminate soon.
srv.ntab.Close()
- if srv.listener != nil {
- // this unblocks listener Accept
- srv.listener.Close()
+ // Disconnect all peers.
+ for _, p := range peers {
+ p.Disconnect(DiscQuitting)
+ }
+ // Wait for peers to shut down. Pending connections and tasks are
+ // not handled here and will terminate soon-ish because srv.quit
+ // is closed.
+ glog.V(logger.Detail).Infof("ignoring %d pending tasks at spindown", len(tasks))
+ for len(peers) > 0 {
+ p := <-srv.delpeer
+ glog.V(logger.Detail).Infoln("<-delpeer (spindown):", p)
+ delete(peers, p.ID())
}
- close(srv.quit)
- srv.loopWG.Wait()
+}
- // No new peers can be added at this point because dialLoop and
- // listenLoop are down. It is safe to call peerWG.Wait because
- // peerWG.Add is not called outside of those loops.
- srv.lock.Lock()
- for _, peer := range srv.peers {
- peer.Disconnect(DiscQuitting)
+func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error {
+ // Drop connections with no matching protocols.
+ if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 {
+ return DiscUselessPeer
}
- srv.lock.Unlock()
- srv.peerWG.Wait()
+ // Repeat the encryption handshake checks because the
+ // peer set might have changed between the handshakes.
+ return srv.encHandshakeChecks(peers, c)
}
-// Self returns the local node's endpoint information.
-func (srv *Server) Self() *discover.Node {
- srv.lock.RLock()
- defer srv.lock.RUnlock()
- if !srv.running {
- return &discover.Node{IP: net.ParseIP("0.0.0.0")}
+func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error {
+ switch {
+ case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers:
+ return DiscTooManyPeers
+ case peers[c.id] != nil:
+ return DiscAlreadyConnected
+ case c.id == srv.ntab.Self().ID:
+ return DiscSelf
+ default:
+ return nil
}
- return srv.ntab.Self()
}
-// main loop for adding connections via listening
+// listenLoop runs in its own goroutine and accepts
+// inbound connections.
func (srv *Server) listenLoop() {
defer srv.loopWG.Done()
+ glog.V(logger.Info).Infoln("Listening on", srv.listener.Addr())
// This channel acts as a semaphore limiting
// active inbound connections that are lingering pre-handshake.
@@ -305,204 +513,92 @@ func (srv *Server) listenLoop() {
slots <- struct{}{}
}
- glog.V(logger.Info).Infoln("Listening on", srv.listener.Addr())
for {
<-slots
- conn, err := srv.listener.Accept()
+ fd, err := srv.listener.Accept()
if err != nil {
return
}
- glog.V(logger.Debug).Infof("Accepted conn %v\n", conn.RemoteAddr())
- srv.peerWG.Add(1)
+ glog.V(logger.Debug).Infof("Accepted conn %v\n", fd.RemoteAddr())
go func() {
- srv.startPeer(conn, nil)
+ srv.setupConn(fd, inboundConn, nil)
slots <- struct{}{}
}()
}
}
-// staticNodesLoop is responsible for periodically checking that static
-// connections are actually live, and requests dialing if not.
-func (srv *Server) staticNodesLoop() {
- // Create a default maintenance ticker, but override it requested
- cycle := staticPeerCheckInterval
- if srv.staticCycle != 0 {
- cycle = srv.staticCycle
- }
- tick := time.NewTicker(cycle)
-
- for {
- select {
- case <-srv.quit:
- return
-
- case <-tick.C:
- // Collect all the non-connected static nodes
- needed := []*discover.Node{}
- srv.lock.RLock()
- for id, node := range srv.staticNodes {
- if _, ok := srv.peers[id]; !ok {
- needed = append(needed, node)
- }
- }
- srv.lock.RUnlock()
-
- // Try to dial each of them (don't hang if server terminates)
- for _, node := range needed {
- glog.V(logger.Debug).Infof("Dialing static peer %v", node)
- select {
- case srv.staticDial <- node:
- case <-srv.quit:
- return
- }
- }
- }
- }
-}
-
-func (srv *Server) dialLoop() {
- var (
- dialed = make(chan *discover.Node)
- dialing = make(map[discover.NodeID]bool)
- findresults = make(chan []*discover.Node)
- refresh = time.NewTimer(0)
- )
- defer srv.loopWG.Done()
- defer refresh.Stop()
-
- // Limit the number of concurrent dials
- tokens := maxDialingConns
- if srv.MaxPendingPeers > 0 {
- tokens = srv.MaxPendingPeers
- }
- slots := make(chan struct{}, tokens)
- for i := 0; i < tokens; i++ {
- slots <- struct{}{}
+// setupConn runs the handshakes and attempts to add the connection
+// as a peer. It returns when the connection has been added as a peer
+// or the handshakes have failed.
+func (srv *Server) setupConn(fd net.Conn, flags connFlag, dialDest *discover.Node) {
+ // Prevent leftover pending conns from entering the handshake.
+ srv.lock.Lock()
+ running := srv.running
+ srv.lock.Unlock()
+ c := &conn{fd: fd, transport: srv.newTransport(fd), flags: flags, cont: make(chan error)}
+ if !running {
+ c.close(errServerStopped)
+ return
}
- dial := func(dest *discover.Node) {
- // Don't dial nodes that would fail the checks in addPeer.
- // This is important because the connection handshake is a lot
- // of work and we'd rather avoid doing that work for peers
- // that can't be added.
- srv.lock.RLock()
- ok, _ := srv.checkPeer(dest.ID)
- srv.lock.RUnlock()
- if !ok || dialing[dest.ID] {
- return
- }
- // Request a dial slot to prevent CPU exhaustion
- <-slots
-
- dialing[dest.ID] = true
- srv.peerWG.Add(1)
- go func() {
- srv.dialNode(dest)
- slots <- struct{}{}
- dialed <- dest
- }()
+ // Run the encryption handshake.
+ var err error
+ if c.id, err = c.doEncHandshake(srv.PrivateKey, dialDest); err != nil {
+ glog.V(logger.Debug).Infof("%v faild enc handshake: %v", c, err)
+ c.close(err)
+ return
}
-
- srv.ntab.Bootstrap(srv.BootstrapNodes)
- for {
- select {
- case <-refresh.C:
- // Grab some nodes to connect to if we're not at capacity.
- srv.lock.RLock()
- needpeers := len(srv.peers) < srv.MaxPeers/2
- srv.lock.RUnlock()
- if needpeers {
- go func() {
- var target discover.NodeID
- rand.Read(target[:])
- findresults <- srv.ntab.Lookup(target)
- }()
- } else {
- // Make sure we check again if the peer count falls
- // below MaxPeers.
- refresh.Reset(refreshPeersInterval)
- }
- case dest := <-srv.staticDial:
- dial(dest)
- case dests := <-findresults:
- for _, dest := range dests {
- dial(dest)
- }
- refresh.Reset(refreshPeersInterval)
- case dest := <-dialed:
- delete(dialing, dest.ID)
- if len(dialing) == 0 {
- // Check again immediately after dialing all current candidates.
- refresh.Reset(0)
- }
- case <-srv.quit:
- // TODO: maybe wait for active dials
- return
- }
+ // For dialed connections, check that the remote public key matches.
+ if dialDest != nil && c.id != dialDest.ID {
+ c.close(DiscUnexpectedIdentity)
+ glog.V(logger.Debug).Infof("%v dialed identity mismatch, want %x", c, dialDest.ID[:8])
+ return
}
-}
-
-func (srv *Server) dialNode(dest *discover.Node) {
- addr := &net.TCPAddr{IP: dest.IP, Port: int(dest.TCP)}
- glog.V(logger.Debug).Infof("Dialing %v\n", dest)
- conn, err := srv.Dialer.Dial("tcp", addr.String())
- if err != nil {
- // dialLoop adds to the wait group counter when launching
- // dialNode, so we need to count it down again. startPeer also
- // does that when an error occurs.
- srv.peerWG.Done()
- glog.V(logger.Detail).Infof("dial error: %v", err)
+ if err := srv.checkpoint(c, srv.posthandshake); err != nil {
+ glog.V(logger.Debug).Infof("%v failed checkpoint posthandshake: %v", c, err)
+ c.close(err)
return
}
- srv.startPeer(conn, dest)
-}
-
-func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
- // TODO: handle/store session token
-
- // Run setupFunc, which should create an authenticated connection
- // and run the capability exchange. Note that any early error
- // returns during that exchange need to call peerWG.Done because
- // the callers of startPeer added the peer to the wait group already.
- fd.SetDeadline(time.Now().Add(handshakeTimeout))
-
- conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest, srv.keepconn)
+ // Run the protocol handshake
+ phs, err := c.doProtoHandshake(srv.ourHandshake)
if err != nil {
- fd.Close()
- glog.V(logger.Debug).Infof("Handshake with %v failed: %v", fd.RemoteAddr(), err)
- srv.peerWG.Done()
+ glog.V(logger.Debug).Infof("%v failed proto handshake: %v", c, err)
+ c.close(err)
return
}
- conn.MsgReadWriter = &netWrapper{
- wrapped: conn.MsgReadWriter,
- conn: fd, rtimeout: frameReadTimeout, wtimeout: frameWriteTimeout,
+ if phs.ID != c.id {
+ glog.V(logger.Debug).Infof("%v wrong proto handshake identity: %x", c, phs.ID[:8])
+ c.close(DiscUnexpectedIdentity)
+ return
}
- p := newPeer(fd, conn, srv.Protocols)
- if ok, reason := srv.addPeer(conn, p); !ok {
- glog.V(logger.Detail).Infof("Not adding %v (%v)\n", p, reason)
- p.politeDisconnect(reason)
- srv.peerWG.Done()
+ c.caps, c.name = phs.Caps, phs.Name
+ if err := srv.checkpoint(c, srv.addpeer); err != nil {
+ glog.V(logger.Debug).Infof("%v failed checkpoint addpeer: %v", c, err)
+ c.close(err)
return
}
- // The handshakes are done and it passed all checks.
- // Spawn the Peer loops.
- go srv.runPeer(p)
+ // If the checks completed successfully, runPeer has now been
+ // launched by run.
}
-// preflight checks whether a connection should be kept. it runs
-// after the encryption handshake, as soon as the remote identity is
-// known.
-func (srv *Server) keepconn(id discover.NodeID) bool {
- srv.lock.RLock()
- defer srv.lock.RUnlock()
- if _, ok := srv.staticNodes[id]; ok {
- return true // static nodes are always allowed
+// checkpoint sends the conn to run, which performs the
+// post-handshake checks for the stage (posthandshake, addpeer).
+func (srv *Server) checkpoint(c *conn, stage chan<- *conn) error {
+ select {
+ case stage <- c:
+ case <-srv.quit:
+ return errServerStopped
}
- if _, ok := srv.trustedNodes[id]; ok {
- return true // trusted nodes are always allowed
+ select {
+ case err := <-c.cont:
+ return err
+ case <-srv.quit:
+ return errServerStopped
}
- return len(srv.peers) < srv.MaxPeers
}
+// runPeer runs in its own goroutine for each peer.
+// it waits until the Peer logic returns and removes
+// the peer.
func (srv *Server) runPeer(p *Peer) {
glog.V(logger.Debug).Infof("Added %v\n", p)
srvjslog.LogJson(&logger.P2PConnected{
@@ -511,58 +607,18 @@ func (srv *Server) runPeer(p *Peer) {
RemoteVersionString: p.Name(),
NumConnections: srv.PeerCount(),
})
+
if srv.newPeerHook != nil {
srv.newPeerHook(p)
}
discreason := p.run()
- srv.removePeer(p)
+ // Note: run waits for existing peers to be sent on srv.delpeer
+ // before returning, so this send should not select on srv.quit.
+ srv.delpeer <- p
+
glog.V(logger.Debug).Infof("Removed %v (%v)\n", p, discreason)
srvjslog.LogJson(&logger.P2PDisconnected{
RemoteId: p.ID().String(),
NumConnections: srv.PeerCount(),
})
}
-
-func (srv *Server) addPeer(conn *conn, p *Peer) (bool, DiscReason) {
- // drop connections with no matching protocols.
- if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, conn.protoHandshake.Caps) == 0 {
- return false, DiscUselessPeer
- }
- // add the peer if it passes the other checks.
- srv.lock.Lock()
- defer srv.lock.Unlock()
- if ok, reason := srv.checkPeer(conn.ID); !ok {
- return false, reason
- }
- srv.peers[conn.ID] = p
- return true, 0
-}
-
-// checkPeer verifies whether a peer looks promising and should be allowed/kept
-// in the pool, or if it's of no use.
-func (srv *Server) checkPeer(id discover.NodeID) (bool, DiscReason) {
- // First up, figure out if the peer is static or trusted
- _, static := srv.staticNodes[id]
- trusted := srv.trustedNodes[id]
-
- // Make sure the peer passes all required checks
- switch {
- case !srv.running:
- return false, DiscQuitting
- case !static && !trusted && len(srv.peers) >= srv.MaxPeers:
- return false, DiscTooManyPeers
- case srv.peers[id] != nil:
- return false, DiscAlreadyConnected
- case id == srv.ntab.Self().ID:
- return false, DiscSelf
- default:
- return true, 0
- }
-}
-
-func (srv *Server) removePeer(p *Peer) {
- srv.lock.Lock()
- delete(srv.peers, p.ID())
- srv.lock.Unlock()
- srv.peerWG.Done()
-}
diff --git a/p2p/server_test.go b/p2p/server_test.go
index 6f7aaf8e1..01448cc7b 100644
--- a/p2p/server_test.go
+++ b/p2p/server_test.go
@@ -2,8 +2,10 @@ package p2p
import (
"crypto/ecdsa"
+ "errors"
"math/rand"
"net"
+ "reflect"
"testing"
"time"
@@ -12,29 +14,50 @@ import (
"github.com/ethereum/go-ethereum/p2p/discover"
)
-func startTestServer(t *testing.T, pf newPeerHook) *Server {
+func init() {
+ // glog.SetV(6)
+ // glog.SetToStderr(true)
+}
+
+type testTransport struct {
+ id discover.NodeID
+ *rlpx
+
+ closeErr error
+}
+
+func newTestTransport(id discover.NodeID, fd net.Conn) transport {
+ wrapped := newRLPX(fd).(*rlpx)
+ wrapped.rw = newRLPXFrameRW(fd, secrets{
+ MAC: zero16,
+ AES: zero16,
+ IngressMAC: sha3.NewKeccak256(),
+ EgressMAC: sha3.NewKeccak256(),
+ })
+ return &testTransport{id: id, rlpx: wrapped}
+}
+
+func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
+ return c.id, nil
+}
+
+func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
+ return &protoHandshake{ID: c.id, Name: "test"}, nil
+}
+
+func (c *testTransport) close(err error) {
+ c.rlpx.fd.Close()
+ c.closeErr = err
+}
+
+func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server {
server := &Server{
- Name: "test",
- MaxPeers: 10,
- ListenAddr: "127.0.0.1:0",
- PrivateKey: newkey(),
- newPeerHook: pf,
- setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, keepconn func(discover.NodeID) bool) (*conn, error) {
- id := randomID()
- if !keepconn(id) {
- return nil, DiscAlreadyConnected
- }
- rw := newRlpxFrameRW(fd, secrets{
- MAC: zero16,
- AES: zero16,
- IngressMAC: sha3.NewKeccak256(),
- EgressMAC: sha3.NewKeccak256(),
- })
- return &conn{
- MsgReadWriter: rw,
- protoHandshake: &protoHandshake{ID: id, Version: baseProtocolVersion},
- }, nil
- },
+ Name: "test",
+ MaxPeers: 10,
+ ListenAddr: "127.0.0.1:0",
+ PrivateKey: newkey(),
+ newPeerHook: pf,
+ newTransport: func(fd net.Conn) transport { return newTestTransport(id, fd) },
}
if err := server.Start(); err != nil {
t.Fatalf("Could not start server: %v", err)
@@ -45,7 +68,11 @@ func startTestServer(t *testing.T, pf newPeerHook) *Server {
func TestServerListen(t *testing.T) {
// start the test server
connected := make(chan *Peer)
- srv := startTestServer(t, func(p *Peer) {
+ remid := randomID()
+ srv := startTestServer(t, remid, func(p *Peer) {
+ if p.ID() != remid {
+ t.Error("peer func called with wrong node id")
+ }
if p == nil {
t.Error("peer func called with nil conn")
}
@@ -67,6 +94,10 @@ func TestServerListen(t *testing.T) {
t.Errorf("peer started with wrong conn: got %v, want %v",
peer.LocalAddr(), conn.RemoteAddr())
}
+ peers := srv.Peers()
+ if !reflect.DeepEqual(peers, []*Peer{peer}) {
+ t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
+ }
case <-time.After(1 * time.Second):
t.Error("server did not accept within one second")
}
@@ -92,23 +123,33 @@ func TestServerDial(t *testing.T) {
// start the server
connected := make(chan *Peer)
- srv := startTestServer(t, func(p *Peer) { connected <- p })
+ remid := randomID()
+ srv := startTestServer(t, remid, func(p *Peer) { connected <- p })
defer close(connected)
defer srv.Stop()
// tell the server to connect
tcpAddr := listener.Addr().(*net.TCPAddr)
- srv.staticDial <- &discover.Node{IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)}
+ srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)})
select {
case conn := <-accepted:
select {
case peer := <-connected:
+ if peer.ID() != remid {
+ t.Errorf("peer has wrong id")
+ }
+ if peer.Name() != "test" {
+ t.Errorf("peer has wrong name")
+ }
if peer.RemoteAddr().String() != conn.LocalAddr().String() {
t.Errorf("peer started with wrong conn: got %v, want %v",
peer.RemoteAddr(), conn.LocalAddr())
}
- // TODO: validate more fields
+ peers := srv.Peers()
+ if !reflect.DeepEqual(peers, []*Peer{peer}) {
+ t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
+ }
case <-time.After(1 * time.Second):
t.Error("server did not launch peer within one second")
}
@@ -118,331 +159,250 @@ func TestServerDial(t *testing.T) {
}
}
-// This test checks that connections are disconnected
-// just after the encryption handshake when the server is
-// at capacity.
-//
-// It also serves as a light-weight integration test.
-func TestServerDisconnectAtCap(t *testing.T) {
- started := make(chan *Peer)
- srv := &Server{
- ListenAddr: "127.0.0.1:0",
- PrivateKey: newkey(),
- MaxPeers: 10,
- NoDial: true,
- // This hook signals that the peer was actually started. We
- // need to wait for the peer to be started before dialing the
- // next connection to get a deterministic peer count.
- newPeerHook: func(p *Peer) { started <- p },
- }
- if err := srv.Start(); err != nil {
- t.Fatal(err)
- }
- defer srv.Stop()
-
- nconns := srv.MaxPeers + 1
- dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)}
- for i := 0; i < nconns; i++ {
- conn, err := dialer.Dial("tcp", srv.ListenAddr)
- if err != nil {
- t.Fatalf("conn %d: dial error: %v", i, err)
- }
- // Close the connection when the test ends, before
- // shutting down the server.
- defer conn.Close()
- // Run the handshakes just like a real peer would.
- key := newkey()
- hs := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)}
- _, err = setupConn(conn, key, hs, srv.Self(), keepalways)
- if i == nconns-1 {
- // When handling the last connection, the server should
- // disconnect immediately instead of running the protocol
- // handshake.
- if err != DiscTooManyPeers {
- t.Errorf("conn %d: got error %q, expected %q", i, err, DiscTooManyPeers)
- }
- } else {
- // For all earlier connections, the handshake should go through.
- if err != nil {
- t.Fatalf("conn %d: unexpected error: %v", i, err)
- }
- // Wait for runPeer to be started.
- <-started
+// This test checks that tasks generated by dialstate are
+// actually executed and taskdone is called for them.
+func TestServerTaskScheduling(t *testing.T) {
+ var (
+ done = make(chan *testTask)
+ quit, returned = make(chan struct{}), make(chan struct{})
+ tc = 0
+ tg = taskgen{
+ newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
+ tc++
+ return []task{&testTask{index: tc - 1}}
+ },
+ doneFunc: func(t task) {
+ select {
+ case done <- t.(*testTask):
+ case <-quit:
+ }
+ },
}
- }
-}
+ )
-// Tests that static peers are (re)connected, and done so even above max peers.
-func TestServerStaticPeers(t *testing.T) {
- // Create a test server with limited connection slots
- started := make(chan *Peer)
- server := &Server{
- ListenAddr: "127.0.0.1:0",
- PrivateKey: newkey(),
- MaxPeers: 3,
- newPeerHook: func(p *Peer) { started <- p },
- staticCycle: time.Second,
- }
- if err := server.Start(); err != nil {
- t.Fatal(err)
+ // The Server in this test isn't actually running
+ // because we're only interested in what run does.
+ srv := &Server{
+ MaxPeers: 10,
+ quit: make(chan struct{}),
+ ntab: fakeTable{},
+ running: true,
}
- defer server.Stop()
-
- // Fill up all the slots on the server
- dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)}
- for i := 0; i < server.MaxPeers; i++ {
- // Establish a new connection
- conn, err := dialer.Dial("tcp", server.ListenAddr)
- if err != nil {
- t.Fatalf("conn %d: dial error: %v", i, err)
- }
- defer conn.Close()
+ srv.loopWG.Add(1)
+ go func() {
+ srv.run(tg)
+ close(returned)
+ }()
- // Run the handshakes just like a real peer would, and wait for completion
- key := newkey()
- shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)}
- if _, err = setupConn(conn, key, shake, server.Self(), keepalways); err != nil {
- t.Fatalf("conn %d: unexpected error: %v", i, err)
- }
- <-started
+ var gotdone []*testTask
+ for i := 0; i < 100; i++ {
+ gotdone = append(gotdone, <-done)
}
- // Open a TCP listener to accept static connections
- listener, err := net.Listen("tcp", "127.0.0.1:0")
- if err != nil {
- t.Fatalf("failed to setup listener: %v", err)
- }
- defer listener.Close()
-
- connected := make(chan net.Conn)
- go func() {
- for i := 0; i < 3; i++ {
- conn, err := listener.Accept()
- if err == nil {
- connected <- conn
- }
+ for i, task := range gotdone {
+ if task.index != i {
+ t.Errorf("task %d has wrong index, got %d", i, task.index)
+ break
+ }
+ if !task.called {
+ t.Errorf("task %d was not called", i)
+ break
}
- }()
- // Inject a static node and wait for a remote dial, then redial, then nothing
- addr := listener.Addr().(*net.TCPAddr)
- static := &discover.Node{
- ID: discover.PubkeyID(&newkey().PublicKey),
- IP: addr.IP,
- TCP: uint16(addr.Port),
}
- server.AddPeer(static)
+ close(quit)
+ srv.Stop()
select {
- case conn := <-connected:
- // Close the first connection, expect redial
- conn.Close()
-
- case <-time.After(2 * server.staticCycle):
- t.Fatalf("remote dial timeout")
+ case <-returned:
+ case <-time.After(500 * time.Millisecond):
+ t.Error("Server.run did not return within 500ms")
}
+}
- select {
- case conn := <-connected:
- // Keep the second connection, don't expect redial
- defer conn.Close()
-
- case <-time.After(2 * server.staticCycle):
- t.Fatalf("remote re-dial timeout")
- }
+type taskgen struct {
+ newFunc func(running int, peers map[discover.NodeID]*Peer) []task
+ doneFunc func(task)
+}
- select {
- case <-time.After(2 * server.staticCycle):
- // Timeout as no dial occurred
+func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task {
+ return tg.newFunc(running, peers)
+}
+func (tg taskgen) taskDone(t task, now time.Time) {
+ tg.doneFunc(t)
+}
+func (tg taskgen) addStatic(*discover.Node) {
+}
- case <-connected:
- t.Fatalf("connected node dialed")
- }
+type testTask struct {
+ index int
+ called bool
}
-// Tests that trusted peers and can connect above max peer caps.
-func TestServerTrustedPeers(t *testing.T) {
+func (t *testTask) Do(srv *Server) {
+ t.called = true
+}
- // Create a trusted peer to accept connections from
- key := newkey()
- trusted := &discover.Node{
- ID: discover.PubkeyID(&key.PublicKey),
- }
- // Create a test server with limited connection slots
- started := make(chan *Peer)
- server := &Server{
- ListenAddr: "127.0.0.1:0",
+// This test checks that connections are disconnected
+// just after the encryption handshake when the server is
+// at capacity. Trusted connections should still be accepted.
+func TestServerAtCap(t *testing.T) {
+ trustedID := randomID()
+ srv := &Server{
PrivateKey: newkey(),
- MaxPeers: 3,
+ MaxPeers: 10,
NoDial: true,
- TrustedNodes: []*discover.Node{trusted},
- newPeerHook: func(p *Peer) { started <- p },
+ TrustedNodes: []*discover.Node{{ID: trustedID}},
}
- if err := server.Start(); err != nil {
- t.Fatal(err)
+ if err := srv.Start(); err != nil {
+ t.Fatalf("could not start: %v", err)
}
- defer server.Stop()
+ defer srv.Stop()
- // Fill up all the slots on the server
- dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)}
- for i := 0; i < server.MaxPeers; i++ {
- // Establish a new connection
- conn, err := dialer.Dial("tcp", server.ListenAddr)
- if err != nil {
- t.Fatalf("conn %d: dial error: %v", i, err)
- }
- defer conn.Close()
+ newconn := func(id discover.NodeID) *conn {
+ fd, _ := net.Pipe()
+ tx := newTestTransport(id, fd)
+ return &conn{fd: fd, transport: tx, flags: inboundConn, id: id, cont: make(chan error)}
+ }
- // Run the handshakes just like a real peer would, and wait for completion
- key := newkey()
- shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)}
- if _, err = setupConn(conn, key, shake, server.Self(), keepalways); err != nil {
- t.Fatalf("conn %d: unexpected error: %v", i, err)
+ // Inject a few connections to fill up the peer set.
+ for i := 0; i < 10; i++ {
+ c := newconn(randomID())
+ if err := srv.checkpoint(c, srv.addpeer); err != nil {
+ t.Fatalf("could not add conn %d: %v", i, err)
}
- <-started
}
- // Dial from the trusted peer, ensure connection is accepted
- conn, err := dialer.Dial("tcp", server.ListenAddr)
- if err != nil {
- t.Fatalf("trusted node: dial error: %v", err)
+ // Try inserting a non-trusted connection.
+ c := newconn(randomID())
+ if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
+ t.Error("wrong error for insert:", err)
}
- defer conn.Close()
-
- shake := &protoHandshake{Version: baseProtocolVersion, ID: trusted.ID}
- if _, err = setupConn(conn, key, shake, server.Self(), keepalways); err != nil {
- t.Fatalf("trusted node: unexpected error: %v", err)
+ // Try inserting a trusted connection.
+ c = newconn(trustedID)
+ if err := srv.checkpoint(c, srv.posthandshake); err != nil {
+ t.Error("unexpected error for trusted conn @posthandshake:", err)
}
- select {
- case <-started:
- // Ok, trusted peer accepted
-
- case <-time.After(100 * time.Millisecond):
- t.Fatalf("trusted node timeout")
+ if !c.is(trustedConn) {
+ t.Error("Server did not set trusted flag")
}
+
}
-// Tests that a failed dial will temporarily throttle a peer.
-func TestServerMaxPendingDials(t *testing.T) {
- // Start a simple test server
- server := &Server{
- ListenAddr: "127.0.0.1:0",
- PrivateKey: newkey(),
- MaxPeers: 10,
- MaxPendingPeers: 1,
- }
- if err := server.Start(); err != nil {
- t.Fatal("failed to start test server: %v", err)
+func TestServerSetupConn(t *testing.T) {
+ id := randomID()
+ srvkey := newkey()
+ srvid := discover.PubkeyID(&srvkey.PublicKey)
+ tests := []struct {
+ dontstart bool
+ tt *setupTransport
+ flags connFlag
+ dialDest *discover.Node
+
+ wantCloseErr error
+ wantCalls string
+ }{
+ {
+ dontstart: true,
+ tt: &setupTransport{id: id},
+ wantCalls: "close,",
+ wantCloseErr: errServerStopped,
+ },
+ {
+ tt: &setupTransport{id: id, encHandshakeErr: errors.New("read error")},
+ flags: inboundConn,
+ wantCalls: "doEncHandshake,close,",
+ wantCloseErr: errors.New("read error"),
+ },
+ {
+ tt: &setupTransport{id: id},
+ dialDest: &discover.Node{ID: randomID()},
+ flags: dynDialedConn,
+ wantCalls: "doEncHandshake,close,",
+ wantCloseErr: DiscUnexpectedIdentity,
+ },
+ {
+ tt: &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}},
+ dialDest: &discover.Node{ID: id},
+ flags: dynDialedConn,
+ wantCalls: "doEncHandshake,doProtoHandshake,close,",
+ wantCloseErr: DiscUnexpectedIdentity,
+ },
+ {
+ tt: &setupTransport{id: id, protoHandshakeErr: errors.New("foo")},
+ dialDest: &discover.Node{ID: id},
+ flags: dynDialedConn,
+ wantCalls: "doEncHandshake,doProtoHandshake,close,",
+ wantCloseErr: errors.New("foo"),
+ },
+ {
+ tt: &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}},
+ flags: inboundConn,
+ wantCalls: "doEncHandshake,close,",
+ wantCloseErr: DiscSelf,
+ },
+ {
+ tt: &setupTransport{id: id, phs: &protoHandshake{ID: id}},
+ flags: inboundConn,
+ wantCalls: "doEncHandshake,doProtoHandshake,close,",
+ wantCloseErr: DiscUselessPeer,
+ },
}
- defer server.Stop()
- // Simulate two separate remote peers
- peers := make(chan *discover.Node, 2)
- conns := make(chan net.Conn, 2)
- for i := 0; i < 2; i++ {
- listener, err := net.Listen("tcp", "127.0.0.1:0")
- if err != nil {
- t.Fatalf("listener %d: failed to setup: %v", i, err)
- }
- defer listener.Close()
-
- addr := listener.Addr().(*net.TCPAddr)
- peers <- &discover.Node{
- ID: discover.PubkeyID(&newkey().PublicKey),
- IP: addr.IP,
- TCP: uint16(addr.Port),
+ for i, test := range tests {
+ srv := &Server{
+ PrivateKey: srvkey,
+ MaxPeers: 10,
+ NoDial: true,
+ Protocols: []Protocol{discard},
+ newTransport: func(fd net.Conn) transport { return test.tt },
}
- go func() {
- conn, err := listener.Accept()
- if err == nil {
- conns <- conn
+ if !test.dontstart {
+ if err := srv.Start(); err != nil {
+ t.Fatalf("couldn't start server: %v", err)
}
- }()
- }
- // Request a dial for both peers
- go func() {
- for i := 0; i < 2; i++ {
- server.staticDial <- <-peers // hack piggybacking the static implementation
}
- }()
-
- // Make sure only one outbound connection goes through
- var conn net.Conn
-
- select {
- case conn = <-conns:
- case <-time.After(100 * time.Millisecond):
- t.Fatalf("first dial timeout")
- }
- select {
- case conn = <-conns:
- t.Fatalf("second dial completed prematurely")
- case <-time.After(100 * time.Millisecond):
- }
- // Finish the first dial, check the second
- conn.Close()
- select {
- case conn = <-conns:
- conn.Close()
-
- case <-time.After(100 * time.Millisecond):
- t.Fatalf("second dial timeout")
+ p1, _ := net.Pipe()
+ srv.setupConn(p1, test.flags, test.dialDest)
+ if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) {
+ t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr)
+ }
+ if test.tt.calls != test.wantCalls {
+ t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
+ }
}
}
-func TestServerMaxPendingAccepts(t *testing.T) {
- // Start a test server and a peer sink for synchronization
- started := make(chan *Peer)
- server := &Server{
- ListenAddr: "127.0.0.1:0",
- PrivateKey: newkey(),
- MaxPeers: 10,
- MaxPendingPeers: 1,
- NoDial: true,
- newPeerHook: func(p *Peer) { started <- p },
- }
- if err := server.Start(); err != nil {
- t.Fatal("failed to start test server: %v", err)
- }
- defer server.Stop()
+type setupTransport struct {
+ id discover.NodeID
+ encHandshakeErr error
- // Try and connect to the server on multiple threads concurrently
- conns := make([]net.Conn, 2)
- for i := 0; i < 2; i++ {
- dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)}
+ phs *protoHandshake
+ protoHandshakeErr error
- conn, err := dialer.Dial("tcp", server.ListenAddr)
- if err != nil {
- t.Fatalf("failed to dial server: %v", err)
- }
- conns[i] = conn
- }
- // Check that a handshake on the second doesn't pass
- go func() {
- key := newkey()
- shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)}
- if _, err := setupConn(conns[1], key, shake, server.Self(), keepalways); err != nil {
- t.Fatalf("failed to run handshake: %v", err)
- }
- }()
- select {
- case <-started:
- t.Fatalf("handshake on second connection accepted")
+ calls string
+ closeErr error
+}
- case <-time.After(time.Second):
- }
- // Shake on first, check that both go through
- go func() {
- key := newkey()
- shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)}
- if _, err := setupConn(conns[0], key, shake, server.Self(), keepalways); err != nil {
- t.Fatalf("failed to run handshake: %v", err)
- }
- }()
- for i := 0; i < 2; i++ {
- select {
- case <-started:
- case <-time.After(time.Second):
- t.Fatalf("peer %d: handshake timeout", i)
- }
+func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
+ c.calls += "doEncHandshake,"
+ return c.id, c.encHandshakeErr
+}
+func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
+ c.calls += "doProtoHandshake,"
+ if c.protoHandshakeErr != nil {
+ return nil, c.protoHandshakeErr
}
+ return c.phs, nil
+}
+func (c *setupTransport) close(err error) {
+ c.calls += "close,"
+ c.closeErr = err
+}
+
+// setupConn shouldn't write to/read from the connection.
+func (c *setupTransport) WriteMsg(Msg) error {
+ panic("WriteMsg called on setupTransport")
+}
+func (c *setupTransport) ReadMsg() (Msg, error) {
+ panic("ReadMsg called on setupTransport")
}
func newkey() *ecdsa.PrivateKey {
@@ -459,7 +419,3 @@ func randomID() (id discover.NodeID) {
}
return id
}
-
-func keepalways(id discover.NodeID) bool {
- return true
-}