diff options
Diffstat (limited to 'p2p')
-rw-r--r-- | p2p/dial.go | 13 | ||||
-rw-r--r-- | p2p/dial_test.go | 44 | ||||
-rw-r--r-- | p2p/discover/database.go | 17 | ||||
-rw-r--r-- | p2p/discover/database_test.go | 18 | ||||
-rw-r--r-- | p2p/discover/table.go | 6 | ||||
-rw-r--r-- | p2p/discover/udp.go | 2 | ||||
-rw-r--r-- | p2p/discover/udp_test.go | 8 | ||||
-rw-r--r-- | p2p/message.go | 26 | ||||
-rw-r--r-- | p2p/metrics.go | 8 | ||||
-rw-r--r-- | p2p/protocols/protocol.go | 311 | ||||
-rw-r--r-- | p2p/protocols/protocol_test.go | 389 | ||||
-rw-r--r-- | p2p/rlpx.go | 18 | ||||
-rw-r--r-- | p2p/rlpx_test.go | 38 | ||||
-rw-r--r-- | p2p/server.go | 4 | ||||
-rw-r--r-- | p2p/testing/peerpool.go | 67 | ||||
-rw-r--r-- | p2p/testing/protocolsession.go | 280 | ||||
-rw-r--r-- | p2p/testing/protocoltester.go | 269 |
17 files changed, 1451 insertions, 67 deletions
diff --git a/p2p/dial.go b/p2p/dial.go index f5ff2c211..d8feceb9f 100644 --- a/p2p/dial.go +++ b/p2p/dial.go @@ -154,6 +154,9 @@ func (s *dialstate) addStatic(n *discover.Node) { func (s *dialstate) removeStatic(n *discover.Node) { // This removes a task so future attempts to connect will not be made. delete(s.static, n.ID) + // This removes a previous dial timestamp so that application + // can force a server to reconnect with chosen peer immediately. + s.hist.remove(n.ID) } func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task { @@ -390,6 +393,16 @@ func (h dialHistory) min() pastDial { } func (h *dialHistory) add(id discover.NodeID, exp time.Time) { heap.Push(h, pastDial{id, exp}) + +} +func (h *dialHistory) remove(id discover.NodeID) bool { + for i, v := range *h { + if v.id == id { + heap.Remove(h, i) + return true + } + } + return false } func (h dialHistory) contains(id discover.NodeID) bool { for _, v := range h { diff --git a/p2p/dial_test.go b/p2p/dial_test.go index ad18ef9ab..2a7941fc6 100644 --- a/p2p/dial_test.go +++ b/p2p/dial_test.go @@ -515,6 +515,50 @@ func TestDialStateStaticDial(t *testing.T) { }) } +// This test checks that static peers will be redialed immediately if they were re-added to a static list. +func TestDialStaticAfterReset(t *testing.T) { + wantStatic := []*discover.Node{ + {ID: uintID(1)}, + {ID: uintID(2)}, + } + + rounds := []round{ + // Static dials are launched for the nodes that aren't yet connected. + { + peers: nil, + new: []task{ + &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}}, + &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}}, + }, + }, + // No new dial tasks, all peers are connected. + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, id: uintID(1)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(2)}}, + }, + done: []task{ + &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}}, + &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}}, + }, + new: []task{ + &waitExpireTask{Duration: 30 * time.Second}, + }, + }, + } + dTest := dialtest{ + init: newDialState(wantStatic, nil, fakeTable{}, 0, nil), + rounds: rounds, + } + runDialTest(t, dTest) + for _, n := range wantStatic { + dTest.init.removeStatic(n) + dTest.init.addStatic(n) + } + // without removing peers they will be considered recently dialed + runDialTest(t, dTest) +} + // This test checks that past dials are not retried for some time. func TestDialStateCache(t *testing.T) { wantStatic := []*discover.Node{ diff --git a/p2p/discover/database.go b/p2p/discover/database.go index b136609f2..6f98de9b4 100644 --- a/p2p/discover/database.go +++ b/p2p/discover/database.go @@ -257,7 +257,7 @@ func (db *nodeDB) expireNodes() error { } // Skip the node if not expired yet (and not self) if !bytes.Equal(id[:], db.self[:]) { - if seen := db.lastPong(id); seen.After(threshold) { + if seen := db.bondTime(id); seen.After(threshold) { continue } } @@ -278,13 +278,18 @@ func (db *nodeDB) updateLastPing(id NodeID, instance time.Time) error { return db.storeInt64(makeKey(id, nodeDBDiscoverPing), instance.Unix()) } -// lastPong retrieves the time of the last successful contact from remote node. -func (db *nodeDB) lastPong(id NodeID) time.Time { +// bondTime retrieves the time of the last successful pong from remote node. +func (db *nodeDB) bondTime(id NodeID) time.Time { return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPong)), 0) } -// updateLastPong updates the last time a remote node successfully contacted. -func (db *nodeDB) updateLastPong(id NodeID, instance time.Time) error { +// hasBond reports whether the given node is considered bonded. +func (db *nodeDB) hasBond(id NodeID) bool { + return time.Since(db.bondTime(id)) < nodeDBNodeExpiration +} + +// updateBondTime updates the last pong time of a node. +func (db *nodeDB) updateBondTime(id NodeID, instance time.Time) error { return db.storeInt64(makeKey(id, nodeDBDiscoverPong), instance.Unix()) } @@ -327,7 +332,7 @@ seek: if n.ID == db.self { continue seek } - if now.Sub(db.lastPong(n.ID)) > maxAge { + if now.Sub(db.bondTime(n.ID)) > maxAge { continue seek } for i := range nodes { diff --git a/p2p/discover/database_test.go b/p2p/discover/database_test.go index be972fd2c..c4fa44d09 100644 --- a/p2p/discover/database_test.go +++ b/p2p/discover/database_test.go @@ -125,13 +125,13 @@ func TestNodeDBFetchStore(t *testing.T) { t.Errorf("ping: value mismatch: have %v, want %v", stored, inst) } // Check fetch/store operations on a node pong object - if stored := db.lastPong(node.ID); stored.Unix() != 0 { + if stored := db.bondTime(node.ID); stored.Unix() != 0 { t.Errorf("pong: non-existing object: %v", stored) } - if err := db.updateLastPong(node.ID, inst); err != nil { + if err := db.updateBondTime(node.ID, inst); err != nil { t.Errorf("pong: failed to update: %v", err) } - if stored := db.lastPong(node.ID); stored.Unix() != inst.Unix() { + if stored := db.bondTime(node.ID); stored.Unix() != inst.Unix() { t.Errorf("pong: value mismatch: have %v, want %v", stored, inst) } // Check fetch/store operations on a node findnode-failure object @@ -224,8 +224,8 @@ func TestNodeDBSeedQuery(t *testing.T) { if err := db.updateNode(seed.node); err != nil { t.Fatalf("node %d: failed to insert: %v", i, err) } - if err := db.updateLastPong(seed.node.ID, seed.pong); err != nil { - t.Fatalf("node %d: failed to insert lastPong: %v", i, err) + if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { + t.Fatalf("node %d: failed to insert bondTime: %v", i, err) } } @@ -332,8 +332,8 @@ func TestNodeDBExpiration(t *testing.T) { if err := db.updateNode(seed.node); err != nil { t.Fatalf("node %d: failed to insert: %v", i, err) } - if err := db.updateLastPong(seed.node.ID, seed.pong); err != nil { - t.Fatalf("node %d: failed to update pong: %v", i, err) + if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { + t.Fatalf("node %d: failed to update bondTime: %v", i, err) } } // Expire some of them, and check the rest @@ -365,8 +365,8 @@ func TestNodeDBSelfExpiration(t *testing.T) { if err := db.updateNode(seed.node); err != nil { t.Fatalf("node %d: failed to insert: %v", i, err) } - if err := db.updateLastPong(seed.node.ID, seed.pong); err != nil { - t.Fatalf("node %d: failed to update pong: %v", i, err) + if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { + t.Fatalf("node %d: failed to update bondTime: %v", i, err) } } // Expire the nodes and make sure self has been evacuated too diff --git a/p2p/discover/table.go b/p2p/discover/table.go index 17c9db777..6509326e6 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -455,7 +455,7 @@ func (tab *Table) loadSeedNodes(bond bool) { } for i := range seeds { seed := seeds[i] - age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.lastPong(seed.ID)) }} + age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.bondTime(seed.ID)) }} log.Debug("Found seed node in database", "id", seed.ID, "addr", seed.addr(), "age", age) tab.add(seed) } @@ -596,7 +596,7 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16 } // Start bonding if we haven't seen this node for a while or if it failed findnode too often. node, fails := tab.db.node(id), tab.db.findFails(id) - age := time.Since(tab.db.lastPong(id)) + age := time.Since(tab.db.bondTime(id)) var result error if fails > 0 || age > nodeDBNodeExpiration { log.Trace("Starting bonding ping/pong", "id", id, "known", node != nil, "failcount", fails, "age", age) @@ -663,7 +663,7 @@ func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error { if err := tab.net.ping(id, addr); err != nil { return err } - tab.db.updateLastPong(id, time.Now()) + tab.db.updateBondTime(id, time.Now()) return nil } diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go index e40de2c36..524c6e498 100644 --- a/p2p/discover/udp.go +++ b/p2p/discover/udp.go @@ -613,7 +613,7 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte if expired(req.Expiration) { return errExpired } - if t.db.node(fromID) == nil { + if !t.db.hasBond(fromID) { // No bond exists, we don't process the packet. This prevents // an attack vector where the discovery protocol could be used // to amplify traffic in a DDOS attack. A malicious actor diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go index 3ffa5c4dd..db9804f7b 100644 --- a/p2p/discover/udp_test.go +++ b/p2p/discover/udp_test.go @@ -247,12 +247,8 @@ func TestUDP_findnode(t *testing.T) { // ensure there's a bond with the test node, // findnode won't be accepted otherwise. - test.table.db.updateNode(NewNode( - PubkeyID(&test.remotekey.PublicKey), - test.remoteaddr.IP, - uint16(test.remoteaddr.Port), - 99, - )) + test.table.db.updateBondTime(PubkeyID(&test.remotekey.PublicKey), time.Now()) + // check that closest neighbors are returned. test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp}) expected := test.table.closest(targetHash, bucketSize) diff --git a/p2p/message.go b/p2p/message.go index 5690494bf..50b419970 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -22,8 +22,6 @@ import ( "fmt" "io" "io/ioutil" - "net" - "sync" "sync/atomic" "time" @@ -112,30 +110,6 @@ func SendItems(w MsgWriter, msgcode uint64, elems ...interface{}) error { return Send(w, msgcode, elems) } -// netWrapper wraps a MsgReadWriter with locks around -// ReadMsg/WriteMsg and applies read/write deadlines. -type netWrapper struct { - rmu, wmu sync.Mutex - - rtimeout, wtimeout time.Duration - conn net.Conn - wrapped MsgReadWriter -} - -func (rw *netWrapper) ReadMsg() (Msg, error) { - rw.rmu.Lock() - defer rw.rmu.Unlock() - rw.conn.SetReadDeadline(time.Now().Add(rw.rtimeout)) - return rw.wrapped.ReadMsg() -} - -func (rw *netWrapper) WriteMsg(msg Msg) error { - rw.wmu.Lock() - defer rw.wmu.Unlock() - rw.conn.SetWriteDeadline(time.Now().Add(rw.wtimeout)) - return rw.wrapped.WriteMsg(msg) -} - // eofSignal wraps a reader with eof signaling. the eof channel is // closed when the wrapped reader returns an error or when count bytes // have been read. diff --git a/p2p/metrics.go b/p2p/metrics.go index 98b61901d..4cbff90ac 100644 --- a/p2p/metrics.go +++ b/p2p/metrics.go @@ -25,10 +25,10 @@ import ( ) var ( - ingressConnectMeter = metrics.NewMeter("p2p/InboundConnects") - ingressTrafficMeter = metrics.NewMeter("p2p/InboundTraffic") - egressConnectMeter = metrics.NewMeter("p2p/OutboundConnects") - egressTrafficMeter = metrics.NewMeter("p2p/OutboundTraffic") + ingressConnectMeter = metrics.NewRegisteredMeter("p2p/InboundConnects", nil) + ingressTrafficMeter = metrics.NewRegisteredMeter("p2p/InboundTraffic", nil) + egressConnectMeter = metrics.NewRegisteredMeter("p2p/OutboundConnects", nil) + egressTrafficMeter = metrics.NewRegisteredMeter("p2p/OutboundTraffic", nil) ) // meteredConn is a wrapper around a network TCP connection that meters both the diff --git a/p2p/protocols/protocol.go b/p2p/protocols/protocol.go new file mode 100644 index 000000000..9914c9958 --- /dev/null +++ b/p2p/protocols/protocol.go @@ -0,0 +1,311 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +/* +Package protocols is an extension to p2p. It offers a user friendly simple way to define +devp2p subprotocols by abstracting away code standardly shared by protocols. + +* automate assigments of code indexes to messages +* automate RLP decoding/encoding based on reflecting +* provide the forever loop to read incoming messages +* standardise error handling related to communication +* standardised handshake negotiation +* TODO: automatic generation of wire protocol specification for peers + +*/ +package protocols + +import ( + "context" + "fmt" + "reflect" + "sync" + + "github.com/ethereum/go-ethereum/p2p" +) + +// error codes used by this protocol scheme +const ( + ErrMsgTooLong = iota + ErrDecode + ErrWrite + ErrInvalidMsgCode + ErrInvalidMsgType + ErrHandshake + ErrNoHandler + ErrHandler +) + +// error description strings associated with the codes +var errorToString = map[int]string{ + ErrMsgTooLong: "Message too long", + ErrDecode: "Invalid message (RLP error)", + ErrWrite: "Error sending message", + ErrInvalidMsgCode: "Invalid message code", + ErrInvalidMsgType: "Invalid message type", + ErrHandshake: "Handshake error", + ErrNoHandler: "No handler registered error", + ErrHandler: "Message handler error", +} + +/* +Error implements the standard go error interface. +Use: + + errorf(code, format, params ...interface{}) + +Prints as: + + <description>: <details> + +where description is given by code in errorToString +and details is fmt.Sprintf(format, params...) + +exported field Code can be checked +*/ +type Error struct { + Code int + message string + format string + params []interface{} +} + +func (e Error) Error() (message string) { + if len(e.message) == 0 { + name, ok := errorToString[e.Code] + if !ok { + panic("invalid message code") + } + e.message = name + if e.format != "" { + e.message += ": " + fmt.Sprintf(e.format, e.params...) + } + } + return e.message +} + +func errorf(code int, format string, params ...interface{}) *Error { + return &Error{ + Code: code, + format: format, + params: params, + } +} + +// Spec is a protocol specification including its name and version as well as +// the types of messages which are exchanged +type Spec struct { + // Name is the name of the protocol, often a three-letter word + Name string + + // Version is the version number of the protocol + Version uint + + // MaxMsgSize is the maximum accepted length of the message payload + MaxMsgSize uint32 + + // Messages is a list of message data types which this protocol uses, with + // each message type being sent with its array index as the code (so + // [&foo{}, &bar{}, &baz{}] would send foo, bar and baz with codes + // 0, 1 and 2 respectively) + // each message must have a single unique data type + Messages []interface{} + + initOnce sync.Once + codes map[reflect.Type]uint64 + types map[uint64]reflect.Type +} + +func (s *Spec) init() { + s.initOnce.Do(func() { + s.codes = make(map[reflect.Type]uint64, len(s.Messages)) + s.types = make(map[uint64]reflect.Type, len(s.Messages)) + for i, msg := range s.Messages { + code := uint64(i) + typ := reflect.TypeOf(msg) + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + s.codes[typ] = code + s.types[code] = typ + } + }) +} + +// Length returns the number of message types in the protocol +func (s *Spec) Length() uint64 { + return uint64(len(s.Messages)) +} + +// GetCode returns the message code of a type, and boolean second argument is +// false if the message type is not found +func (s *Spec) GetCode(msg interface{}) (uint64, bool) { + s.init() + typ := reflect.TypeOf(msg) + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + code, ok := s.codes[typ] + return code, ok +} + +// NewMsg construct a new message type given the code +func (s *Spec) NewMsg(code uint64) (interface{}, bool) { + s.init() + typ, ok := s.types[code] + if !ok { + return nil, false + } + return reflect.New(typ).Interface(), true +} + +// Peer represents a remote peer or protocol instance that is running on a peer connection with +// a remote peer +type Peer struct { + *p2p.Peer // the p2p.Peer object representing the remote + rw p2p.MsgReadWriter // p2p.MsgReadWriter to send messages to and read messages from + spec *Spec +} + +// NewPeer constructs a new peer +// this constructor is called by the p2p.Protocol#Run function +// the first two arguments are the arguments passed to p2p.Protocol.Run function +// the third argument is the Spec describing the protocol +func NewPeer(p *p2p.Peer, rw p2p.MsgReadWriter, spec *Spec) *Peer { + return &Peer{ + Peer: p, + rw: rw, + spec: spec, + } +} + +// Run starts the forever loop that handles incoming messages +// called within the p2p.Protocol#Run function +// the handler argument is a function which is called for each message received +// from the remote peer, a returned error causes the loop to exit +// resulting in disconnection +func (p *Peer) Run(handler func(msg interface{}) error) error { + for { + if err := p.handleIncoming(handler); err != nil { + return err + } + } +} + +// Drop disconnects a peer. +// TODO: may need to implement protocol drop only? don't want to kick off the peer +// if they are useful for other protocols +func (p *Peer) Drop(err error) { + p.Disconnect(p2p.DiscSubprotocolError) +} + +// Send takes a message, encodes it in RLP, finds the right message code and sends the +// message off to the peer +// this low level call will be wrapped by libraries providing routed or broadcast sends +// but often just used to forward and push messages to directly connected peers +func (p *Peer) Send(msg interface{}) error { + code, found := p.spec.GetCode(msg) + if !found { + return errorf(ErrInvalidMsgType, "%v", code) + } + return p2p.Send(p.rw, code, msg) +} + +// handleIncoming(code) +// is called each cycle of the main forever loop that dispatches incoming messages +// if this returns an error the loop returns and the peer is disconnected with the error +// this generic handler +// * checks message size, +// * checks for out-of-range message codes, +// * handles decoding with reflection, +// * call handlers as callbacks +func (p *Peer) handleIncoming(handle func(msg interface{}) error) error { + msg, err := p.rw.ReadMsg() + if err != nil { + return err + } + // make sure that the payload has been fully consumed + defer msg.Discard() + + if msg.Size > p.spec.MaxMsgSize { + return errorf(ErrMsgTooLong, "%v > %v", msg.Size, p.spec.MaxMsgSize) + } + + val, ok := p.spec.NewMsg(msg.Code) + if !ok { + return errorf(ErrInvalidMsgCode, "%v", msg.Code) + } + if err := msg.Decode(val); err != nil { + return errorf(ErrDecode, "<= %v: %v", msg, err) + } + + // call the registered handler callbacks + // a registered callback take the decoded message as argument as an interface + // which the handler is supposed to cast to the appropriate type + // it is entirely safe not to check the cast in the handler since the handler is + // chosen based on the proper type in the first place + if err := handle(val); err != nil { + return errorf(ErrHandler, "(msg code %v): %v", msg.Code, err) + } + return nil +} + +// Handshake negotiates a handshake on the peer connection +// * arguments +// * context +// * the local handshake to be sent to the remote peer +// * funcion to be called on the remote handshake (can be nil) +// * expects a remote handshake back of the same type +// * the dialing peer needs to send the handshake first and then waits for remote +// * the listening peer waits for the remote handshake and then sends it +// returns the remote handshake and an error +func (p *Peer) Handshake(ctx context.Context, hs interface{}, verify func(interface{}) error) (rhs interface{}, err error) { + if _, ok := p.spec.GetCode(hs); !ok { + return nil, errorf(ErrHandshake, "unknown handshake message type: %T", hs) + } + errc := make(chan error, 2) + handle := func(msg interface{}) error { + rhs = msg + if verify != nil { + return verify(rhs) + } + return nil + } + send := func() { errc <- p.Send(hs) } + receive := func() { errc <- p.handleIncoming(handle) } + + go func() { + if p.Inbound() { + receive() + send() + } else { + send() + receive() + } + }() + + for i := 0; i < 2; i++ { + select { + case err = <-errc: + case <-ctx.Done(): + err = ctx.Err() + } + if err != nil { + return nil, errorf(ErrHandshake, err.Error()) + } + } + return rhs, nil +} diff --git a/p2p/protocols/protocol_test.go b/p2p/protocols/protocol_test.go new file mode 100644 index 000000000..053f537a6 --- /dev/null +++ b/p2p/protocols/protocol_test.go @@ -0,0 +1,389 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package protocols + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/simulations/adapters" + p2ptest "github.com/ethereum/go-ethereum/p2p/testing" +) + +// handshake message type +type hs0 struct { + C uint +} + +// message to kill/drop the peer with nodeID +type kill struct { + C discover.NodeID +} + +// message to drop connection +type drop struct { +} + +/// protoHandshake represents module-independent aspects of the protocol and is +// the first message peers send and receive as part the initial exchange +type protoHandshake struct { + Version uint // local and remote peer should have identical version + NetworkID string // local and remote peer should have identical network id +} + +// checkProtoHandshake verifies local and remote protoHandshakes match +func checkProtoHandshake(testVersion uint, testNetworkID string) func(interface{}) error { + return func(rhs interface{}) error { + remote := rhs.(*protoHandshake) + if remote.NetworkID != testNetworkID { + return fmt.Errorf("%s (!= %s)", remote.NetworkID, testNetworkID) + } + + if remote.Version != testVersion { + return fmt.Errorf("%d (!= %d)", remote.Version, testVersion) + } + return nil + } +} + +// newProtocol sets up a protocol +// the run function here demonstrates a typical protocol using peerPool, handshake +// and messages registered to handlers +func newProtocol(pp *p2ptest.TestPeerPool) func(*p2p.Peer, p2p.MsgReadWriter) error { + spec := &Spec{ + Name: "test", + Version: 42, + MaxMsgSize: 10 * 1024, + Messages: []interface{}{ + protoHandshake{}, + hs0{}, + kill{}, + drop{}, + }, + } + return func(p *p2p.Peer, rw p2p.MsgReadWriter) error { + peer := NewPeer(p, rw, spec) + + // initiate one-off protohandshake and check validity + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + phs := &protoHandshake{42, "420"} + hsCheck := checkProtoHandshake(phs.Version, phs.NetworkID) + _, err := peer.Handshake(ctx, phs, hsCheck) + if err != nil { + return err + } + + lhs := &hs0{42} + // module handshake demonstrating a simple repeatable exchange of same-type message + hs, err := peer.Handshake(ctx, lhs, nil) + if err != nil { + return err + } + + if rmhs := hs.(*hs0); rmhs.C > lhs.C { + return fmt.Errorf("handshake mismatch remote %v > local %v", rmhs.C, lhs.C) + } + + handle := func(msg interface{}) error { + switch msg := msg.(type) { + + case *protoHandshake: + return errors.New("duplicate handshake") + + case *hs0: + rhs := msg + if rhs.C > lhs.C { + return fmt.Errorf("handshake mismatch remote %v > local %v", rhs.C, lhs.C) + } + lhs.C += rhs.C + return peer.Send(lhs) + + case *kill: + // demonstrates use of peerPool, killing another peer connection as a response to a message + id := msg.C + pp.Get(id).Drop(errors.New("killed")) + return nil + + case *drop: + // for testing we can trigger self induced disconnect upon receiving drop message + return errors.New("dropped") + + default: + return fmt.Errorf("unknown message type: %T", msg) + } + } + + pp.Add(peer) + defer pp.Remove(peer) + return peer.Run(handle) + } +} + +func protocolTester(t *testing.T, pp *p2ptest.TestPeerPool) *p2ptest.ProtocolTester { + conf := adapters.RandomNodeConfig() + return p2ptest.NewProtocolTester(t, conf.ID, 2, newProtocol(pp)) +} + +func protoHandshakeExchange(id discover.NodeID, proto *protoHandshake) []p2ptest.Exchange { + + return []p2ptest.Exchange{ + { + Expects: []p2ptest.Expect{ + { + Code: 0, + Msg: &protoHandshake{42, "420"}, + Peer: id, + }, + }, + }, + { + Triggers: []p2ptest.Trigger{ + { + Code: 0, + Msg: proto, + Peer: id, + }, + }, + }, + } +} + +func runProtoHandshake(t *testing.T, proto *protoHandshake, errs ...error) { + pp := p2ptest.NewTestPeerPool() + s := protocolTester(t, pp) + // TODO: make this more than one handshake + id := s.IDs[0] + if err := s.TestExchanges(protoHandshakeExchange(id, proto)...); err != nil { + t.Fatal(err) + } + var disconnects []*p2ptest.Disconnect + for i, err := range errs { + disconnects = append(disconnects, &p2ptest.Disconnect{Peer: s.IDs[i], Error: err}) + } + if err := s.TestDisconnected(disconnects...); err != nil { + t.Fatal(err) + } +} + +func TestProtoHandshakeVersionMismatch(t *testing.T) { + runProtoHandshake(t, &protoHandshake{41, "420"}, errorf(ErrHandshake, errorf(ErrHandler, "(msg code 0): 41 (!= 42)").Error())) +} + +func TestProtoHandshakeNetworkIDMismatch(t *testing.T) { + runProtoHandshake(t, &protoHandshake{42, "421"}, errorf(ErrHandshake, errorf(ErrHandler, "(msg code 0): 421 (!= 420)").Error())) +} + +func TestProtoHandshakeSuccess(t *testing.T) { + runProtoHandshake(t, &protoHandshake{42, "420"}) +} + +func moduleHandshakeExchange(id discover.NodeID, resp uint) []p2ptest.Exchange { + + return []p2ptest.Exchange{ + { + Expects: []p2ptest.Expect{ + { + Code: 1, + Msg: &hs0{42}, + Peer: id, + }, + }, + }, + { + Triggers: []p2ptest.Trigger{ + { + Code: 1, + Msg: &hs0{resp}, + Peer: id, + }, + }, + }, + } +} + +func runModuleHandshake(t *testing.T, resp uint, errs ...error) { + pp := p2ptest.NewTestPeerPool() + s := protocolTester(t, pp) + id := s.IDs[0] + if err := s.TestExchanges(protoHandshakeExchange(id, &protoHandshake{42, "420"})...); err != nil { + t.Fatal(err) + } + if err := s.TestExchanges(moduleHandshakeExchange(id, resp)...); err != nil { + t.Fatal(err) + } + var disconnects []*p2ptest.Disconnect + for i, err := range errs { + disconnects = append(disconnects, &p2ptest.Disconnect{Peer: s.IDs[i], Error: err}) + } + if err := s.TestDisconnected(disconnects...); err != nil { + t.Fatal(err) + } +} + +func TestModuleHandshakeError(t *testing.T) { + runModuleHandshake(t, 43, fmt.Errorf("handshake mismatch remote 43 > local 42")) +} + +func TestModuleHandshakeSuccess(t *testing.T) { + runModuleHandshake(t, 42) +} + +// testing complex interactions over multiple peers, relaying, dropping +func testMultiPeerSetup(a, b discover.NodeID) []p2ptest.Exchange { + + return []p2ptest.Exchange{ + { + Label: "primary handshake", + Expects: []p2ptest.Expect{ + { + Code: 0, + Msg: &protoHandshake{42, "420"}, + Peer: a, + }, + { + Code: 0, + Msg: &protoHandshake{42, "420"}, + Peer: b, + }, + }, + }, + { + Label: "module handshake", + Triggers: []p2ptest.Trigger{ + { + Code: 0, + Msg: &protoHandshake{42, "420"}, + Peer: a, + }, + { + Code: 0, + Msg: &protoHandshake{42, "420"}, + Peer: b, + }, + }, + Expects: []p2ptest.Expect{ + { + Code: 1, + Msg: &hs0{42}, + Peer: a, + }, + { + Code: 1, + Msg: &hs0{42}, + Peer: b, + }, + }, + }, + + {Label: "alternative module handshake", Triggers: []p2ptest.Trigger{{Code: 1, Msg: &hs0{41}, Peer: a}, + {Code: 1, Msg: &hs0{41}, Peer: b}}}, + {Label: "repeated module handshake", Triggers: []p2ptest.Trigger{{Code: 1, Msg: &hs0{1}, Peer: a}}}, + {Label: "receiving repeated module handshake", Expects: []p2ptest.Expect{{Code: 1, Msg: &hs0{43}, Peer: a}}}} +} + +func runMultiplePeers(t *testing.T, peer int, errs ...error) { + pp := p2ptest.NewTestPeerPool() + s := protocolTester(t, pp) + + if err := s.TestExchanges(testMultiPeerSetup(s.IDs[0], s.IDs[1])...); err != nil { + t.Fatal(err) + } + // after some exchanges of messages, we can test state changes + // here this is simply demonstrated by the peerPool + // after the handshake negotiations peers must be added to the pool + // time.Sleep(1) + tick := time.NewTicker(10 * time.Millisecond) + timeout := time.NewTimer(1 * time.Second) +WAIT: + for { + select { + case <-tick.C: + if pp.Has(s.IDs[0]) { + break WAIT + } + case <-timeout.C: + t.Fatal("timeout") + } + } + if !pp.Has(s.IDs[1]) { + t.Fatalf("missing peer test-1: %v (%v)", pp, s.IDs) + } + + // peer 0 sends kill request for peer with index <peer> + err := s.TestExchanges(p2ptest.Exchange{ + Triggers: []p2ptest.Trigger{ + { + Code: 2, + Msg: &kill{s.IDs[peer]}, + Peer: s.IDs[0], + }, + }, + }) + + if err != nil { + t.Fatal(err) + } + + // the peer not killed sends a drop request + err = s.TestExchanges(p2ptest.Exchange{ + Triggers: []p2ptest.Trigger{ + { + Code: 3, + Msg: &drop{}, + Peer: s.IDs[(peer+1)%2], + }, + }, + }) + + if err != nil { + t.Fatal(err) + } + + // check the actual discconnect errors on the individual peers + var disconnects []*p2ptest.Disconnect + for i, err := range errs { + disconnects = append(disconnects, &p2ptest.Disconnect{Peer: s.IDs[i], Error: err}) + } + if err := s.TestDisconnected(disconnects...); err != nil { + t.Fatal(err) + } + // test if disconnected peers have been removed from peerPool + if pp.Has(s.IDs[peer]) { + t.Fatalf("peer test-%v not dropped: %v (%v)", peer, pp, s.IDs) + } + +} + +func TestMultiplePeersDropSelf(t *testing.T) { + runMultiplePeers(t, 0, + fmt.Errorf("subprotocol error"), + fmt.Errorf("Message handler error: (msg code 3): dropped"), + ) +} + +func TestMultiplePeersDropOther(t *testing.T) { + runMultiplePeers(t, 1, + fmt.Errorf("Message handler error: (msg code 3): dropped"), + fmt.Errorf("subprotocol error"), + ) +} diff --git a/p2p/rlpx.go b/p2p/rlpx.go index 24037ecc1..1889edac9 100644 --- a/p2p/rlpx.go +++ b/p2p/rlpx.go @@ -108,17 +108,19 @@ func (t *rlpx) close(err error) { // Tell the remote end why we're disconnecting if possible. if t.rw != nil { if r, ok := err.(DiscReason); ok && r != DiscNetworkError { - t.fd.SetWriteDeadline(time.Now().Add(discWriteTimeout)) - SendItems(t.rw, discMsg, r) + // rlpx tries to send DiscReason to disconnected peer + // if the connection is net.Pipe (in-memory simulation) + // it hangs forever, since net.Pipe does not implement + // a write deadline. Because of this only try to send + // the disconnect reason message if there is no error. + if err := t.fd.SetWriteDeadline(time.Now().Add(discWriteTimeout)); err == nil { + SendItems(t.rw, discMsg, r) + } } } t.fd.Close() } -// doEncHandshake runs the protocol handshake using authenticated -// messages. the protocol handshake is the first authenticated message -// and also verifies whether the encryption handshake 'worked' and the -// remote side actually provided the right public key. func (t *rlpx) doProtoHandshake(our *protoHandshake) (their *protoHandshake, err error) { // Writing our handshake happens concurrently, we prefer // returning the handshake read error. If the remote side @@ -169,6 +171,10 @@ func readProtocolHandshake(rw MsgReader, our *protoHandshake) (*protoHandshake, return &hs, nil } +// doEncHandshake runs the protocol handshake using authenticated +// messages. the protocol handshake is the first authenticated message +// and also verifies whether the encryption handshake 'worked' and the +// remote side actually provided the right public key. func (t *rlpx) doEncHandshake(prv *ecdsa.PrivateKey, dial *discover.Node) (discover.NodeID, error) { var ( sec secrets diff --git a/p2p/rlpx_test.go b/p2p/rlpx_test.go index f4cefa650..bca460402 100644 --- a/p2p/rlpx_test.go +++ b/p2p/rlpx_test.go @@ -156,14 +156,18 @@ func TestProtocolHandshake(t *testing.T) { node1 = &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey), IP: net.IP{5, 6, 7, 8}, TCP: 44} hs1 = &protoHandshake{Version: 3, ID: node1.ID, Caps: []Cap{{"c", 1}, {"d", 3}}} - fd0, fd1 = net.Pipe() - wg sync.WaitGroup + wg sync.WaitGroup ) + fd0, fd1, err := tcpPipe() + if err != nil { + t.Fatal(err) + } + wg.Add(2) go func() { defer wg.Done() - defer fd1.Close() + defer fd0.Close() rlpx := newRLPX(fd0) remid, err := rlpx.doEncHandshake(prv0, node1) if err != nil { @@ -597,3 +601,31 @@ func TestHandshakeForwardCompatibility(t *testing.T) { t.Errorf("ingress-mac('foo') mismatch:\ngot %x\nwant %x", fooIngressHash, wantFooIngressHash) } } + +// tcpPipe creates an in process full duplex pipe based on a localhost TCP socket +func tcpPipe() (net.Conn, net.Conn, error) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, nil, err + } + defer l.Close() + + var aconn net.Conn + aerr := make(chan error, 1) + go func() { + var err error + aconn, err = l.Accept() + aerr <- err + }() + + dconn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + <-aerr + return nil, nil, err + } + if err := <-aerr; err != nil { + dconn.Close() + return nil, nil, err + } + return aconn, dconn, nil +} diff --git a/p2p/server.go b/p2p/server.go index 90e92dc05..c41d1dc15 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -36,9 +36,7 @@ import ( ) const ( - defaultDialTimeout = 15 * time.Second - refreshPeersInterval = 30 * time.Second - staticPeerCheckInterval = 15 * time.Second + defaultDialTimeout = 15 * time.Second // Connectivity defaults. maxActiveDialTasks = 16 diff --git a/p2p/testing/peerpool.go b/p2p/testing/peerpool.go new file mode 100644 index 000000000..45c6e6142 --- /dev/null +++ b/p2p/testing/peerpool.go @@ -0,0 +1,67 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package testing + +import ( + "fmt" + "sync" + + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/discover" +) + +type TestPeer interface { + ID() discover.NodeID + Drop(error) +} + +// TestPeerPool is an example peerPool to demonstrate registration of peer connections +type TestPeerPool struct { + lock sync.Mutex + peers map[discover.NodeID]TestPeer +} + +func NewTestPeerPool() *TestPeerPool { + return &TestPeerPool{peers: make(map[discover.NodeID]TestPeer)} +} + +func (self *TestPeerPool) Add(p TestPeer) { + self.lock.Lock() + defer self.lock.Unlock() + log.Trace(fmt.Sprintf("pp add peer %v", p.ID())) + self.peers[p.ID()] = p + +} + +func (self *TestPeerPool) Remove(p TestPeer) { + self.lock.Lock() + defer self.lock.Unlock() + delete(self.peers, p.ID()) +} + +func (self *TestPeerPool) Has(id discover.NodeID) bool { + self.lock.Lock() + defer self.lock.Unlock() + _, ok := self.peers[id] + return ok +} + +func (self *TestPeerPool) Get(id discover.NodeID) TestPeer { + self.lock.Lock() + defer self.lock.Unlock() + return self.peers[id] +} diff --git a/p2p/testing/protocolsession.go b/p2p/testing/protocolsession.go new file mode 100644 index 000000000..361285f06 --- /dev/null +++ b/p2p/testing/protocolsession.go @@ -0,0 +1,280 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package testing + +import ( + "errors" + "fmt" + "sync" + "time" + + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/simulations/adapters" +) + +var errTimedOut = errors.New("timed out") + +// ProtocolSession is a quasi simulation of a pivot node running +// a service and a number of dummy peers that can send (trigger) or +// receive (expect) messages +type ProtocolSession struct { + Server *p2p.Server + IDs []discover.NodeID + adapter *adapters.SimAdapter + events chan *p2p.PeerEvent +} + +// Exchange is the basic units of protocol tests +// the triggers and expects in the arrays are run immediately and asynchronously +// thus one cannot have multiple expects for the SAME peer with DIFFERENT message types +// because it's unpredictable which expect will receive which message +// (with expect #1 and #2, messages might be sent #2 and #1, and both expects will complain about wrong message code) +// an exchange is defined on a session +type Exchange struct { + Label string + Triggers []Trigger + Expects []Expect + Timeout time.Duration +} + +// Trigger is part of the exchange, incoming message for the pivot node +// sent by a peer +type Trigger struct { + Msg interface{} // type of message to be sent + Code uint64 // code of message is given + Peer discover.NodeID // the peer to send the message to + Timeout time.Duration // timeout duration for the sending +} + +// Expect is part of an exchange, outgoing message from the pivot node +// received by a peer +type Expect struct { + Msg interface{} // type of message to expect + Code uint64 // code of message is now given + Peer discover.NodeID // the peer that expects the message + Timeout time.Duration // timeout duration for receiving +} + +// Disconnect represents a disconnect event, used and checked by TestDisconnected +type Disconnect struct { + Peer discover.NodeID // discconnected peer + Error error // disconnect reason +} + +// trigger sends messages from peers +func (self *ProtocolSession) trigger(trig Trigger) error { + simNode, ok := self.adapter.GetNode(trig.Peer) + if !ok { + return fmt.Errorf("trigger: peer %v does not exist (1- %v)", trig.Peer, len(self.IDs)) + } + mockNode, ok := simNode.Services()[0].(*mockNode) + if !ok { + return fmt.Errorf("trigger: peer %v is not a mock", trig.Peer) + } + + errc := make(chan error) + + go func() { + errc <- mockNode.Trigger(&trig) + }() + + t := trig.Timeout + if t == time.Duration(0) { + t = 1000 * time.Millisecond + } + select { + case err := <-errc: + return err + case <-time.After(t): + return fmt.Errorf("timout expecting %v to send to peer %v", trig.Msg, trig.Peer) + } +} + +// expect checks an expectation of a message sent out by the pivot node +func (self *ProtocolSession) expect(exps []Expect) error { + // construct a map of expectations for each node + peerExpects := make(map[discover.NodeID][]Expect) + for _, exp := range exps { + if exp.Msg == nil { + return errors.New("no message to expect") + } + peerExpects[exp.Peer] = append(peerExpects[exp.Peer], exp) + } + + // construct a map of mockNodes for each node + mockNodes := make(map[discover.NodeID]*mockNode) + for nodeID := range peerExpects { + simNode, ok := self.adapter.GetNode(nodeID) + if !ok { + return fmt.Errorf("trigger: peer %v does not exist (1- %v)", nodeID, len(self.IDs)) + } + mockNode, ok := simNode.Services()[0].(*mockNode) + if !ok { + return fmt.Errorf("trigger: peer %v is not a mock", nodeID) + } + mockNodes[nodeID] = mockNode + } + + // done chanell cancels all created goroutines when function returns + done := make(chan struct{}) + defer close(done) + // errc catches the first error from + errc := make(chan error) + + wg := &sync.WaitGroup{} + wg.Add(len(mockNodes)) + for nodeID, mockNode := range mockNodes { + nodeID := nodeID + mockNode := mockNode + go func() { + defer wg.Done() + + // Sum all Expect timeouts to give the maximum + // time for all expectations to finish. + // mockNode.Expect checks all received messages against + // a list of expected messages and timeout for each + // of them can not be checked separately. + var t time.Duration + for _, exp := range peerExpects[nodeID] { + if exp.Timeout == time.Duration(0) { + t += 2000 * time.Millisecond + } else { + t += exp.Timeout + } + } + alarm := time.NewTimer(t) + defer alarm.Stop() + + // expectErrc is used to check if error returned + // from mockNode.Expect is not nil and to send it to + // errc only in that case. + // done channel will be closed when function + expectErrc := make(chan error) + go func() { + select { + case expectErrc <- mockNode.Expect(peerExpects[nodeID]...): + case <-done: + case <-alarm.C: + } + }() + + select { + case err := <-expectErrc: + if err != nil { + select { + case errc <- err: + case <-done: + case <-alarm.C: + errc <- errTimedOut + } + } + case <-done: + case <-alarm.C: + errc <- errTimedOut + } + + }() + } + + go func() { + wg.Wait() + // close errc when all goroutines finish to return nill err from errc + close(errc) + }() + + return <-errc +} + +// TestExchanges tests a series of exchanges against the session +func (self *ProtocolSession) TestExchanges(exchanges ...Exchange) error { + for i, e := range exchanges { + if err := self.testExchange(e); err != nil { + return fmt.Errorf("exchange #%d %q: %v", i, e.Label, err) + } + log.Trace(fmt.Sprintf("exchange #%d %q: run successfully", i, e.Label)) + } + return nil +} + +// testExchange tests a single Exchange. +// Default timeout value is 2 seconds. +func (self *ProtocolSession) testExchange(e Exchange) error { + errc := make(chan error) + done := make(chan struct{}) + defer close(done) + + go func() { + for _, trig := range e.Triggers { + err := self.trigger(trig) + if err != nil { + errc <- err + return + } + } + + select { + case errc <- self.expect(e.Expects): + case <-done: + } + }() + + // time out globally or finish when all expectations satisfied + t := e.Timeout + if t == 0 { + t = 2000 * time.Millisecond + } + alarm := time.NewTimer(t) + select { + case err := <-errc: + return err + case <-alarm.C: + return errTimedOut + } +} + +// TestDisconnected tests the disconnections given as arguments +// the disconnect structs describe what disconnect error is expected on which peer +func (self *ProtocolSession) TestDisconnected(disconnects ...*Disconnect) error { + expects := make(map[discover.NodeID]error) + for _, disconnect := range disconnects { + expects[disconnect.Peer] = disconnect.Error + } + + timeout := time.After(time.Second) + for len(expects) > 0 { + select { + case event := <-self.events: + if event.Type != p2p.PeerEventTypeDrop { + continue + } + expectErr, ok := expects[event.Peer] + if !ok { + continue + } + + if !(expectErr == nil && event.Error == "" || expectErr != nil && expectErr.Error() == event.Error) { + return fmt.Errorf("unexpected error on peer %v. expected '%v', got '%v'", event.Peer, expectErr, event.Error) + } + delete(expects, event.Peer) + case <-timeout: + return fmt.Errorf("timed out waiting for peers to disconnect") + } + } + return nil +} diff --git a/p2p/testing/protocoltester.go b/p2p/testing/protocoltester.go new file mode 100644 index 000000000..a797412d6 --- /dev/null +++ b/p2p/testing/protocoltester.go @@ -0,0 +1,269 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +/* +the p2p/testing package provides a unit test scheme to check simple +protocol message exchanges with one pivot node and a number of dummy peers +The pivot test node runs a node.Service, the dummy peers run a mock node +that can be used to send and receive messages +*/ + +package testing + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "strings" + "sync" + "testing" + + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/node" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/simulations" + "github.com/ethereum/go-ethereum/p2p/simulations/adapters" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/rpc" +) + +// ProtocolTester is the tester environment used for unit testing protocol +// message exchanges. It uses p2p/simulations framework +type ProtocolTester struct { + *ProtocolSession + network *simulations.Network +} + +// NewProtocolTester constructs a new ProtocolTester +// it takes as argument the pivot node id, the number of dummy peers and the +// protocol run function called on a peer connection by the p2p server +func NewProtocolTester(t *testing.T, id discover.NodeID, n int, run func(*p2p.Peer, p2p.MsgReadWriter) error) *ProtocolTester { + services := adapters.Services{ + "test": func(ctx *adapters.ServiceContext) (node.Service, error) { + return &testNode{run}, nil + }, + "mock": func(ctx *adapters.ServiceContext) (node.Service, error) { + return newMockNode(), nil + }, + } + adapter := adapters.NewSimAdapter(services) + net := simulations.NewNetwork(adapter, &simulations.NetworkConfig{}) + if _, err := net.NewNodeWithConfig(&adapters.NodeConfig{ + ID: id, + EnableMsgEvents: true, + Services: []string{"test"}, + }); err != nil { + panic(err.Error()) + } + if err := net.Start(id); err != nil { + panic(err.Error()) + } + + node := net.GetNode(id).Node.(*adapters.SimNode) + peers := make([]*adapters.NodeConfig, n) + peerIDs := make([]discover.NodeID, n) + for i := 0; i < n; i++ { + peers[i] = adapters.RandomNodeConfig() + peers[i].Services = []string{"mock"} + peerIDs[i] = peers[i].ID + } + events := make(chan *p2p.PeerEvent, 1000) + node.SubscribeEvents(events) + ps := &ProtocolSession{ + Server: node.Server(), + IDs: peerIDs, + adapter: adapter, + events: events, + } + self := &ProtocolTester{ + ProtocolSession: ps, + network: net, + } + + self.Connect(id, peers...) + + return self +} + +// Stop stops the p2p server +func (self *ProtocolTester) Stop() error { + self.Server.Stop() + return nil +} + +// Connect brings up the remote peer node and connects it using the +// p2p/simulations network connection with the in memory network adapter +func (self *ProtocolTester) Connect(selfID discover.NodeID, peers ...*adapters.NodeConfig) { + for _, peer := range peers { + log.Trace(fmt.Sprintf("start node %v", peer.ID)) + if _, err := self.network.NewNodeWithConfig(peer); err != nil { + panic(fmt.Sprintf("error starting peer %v: %v", peer.ID, err)) + } + if err := self.network.Start(peer.ID); err != nil { + panic(fmt.Sprintf("error starting peer %v: %v", peer.ID, err)) + } + log.Trace(fmt.Sprintf("connect to %v", peer.ID)) + if err := self.network.Connect(selfID, peer.ID); err != nil { + panic(fmt.Sprintf("error connecting to peer %v: %v", peer.ID, err)) + } + } + +} + +// testNode wraps a protocol run function and implements the node.Service +// interface +type testNode struct { + run func(*p2p.Peer, p2p.MsgReadWriter) error +} + +func (t *testNode) Protocols() []p2p.Protocol { + return []p2p.Protocol{{ + Length: 100, + Run: t.run, + }} +} + +func (t *testNode) APIs() []rpc.API { + return nil +} + +func (t *testNode) Start(server *p2p.Server) error { + return nil +} + +func (t *testNode) Stop() error { + return nil +} + +// mockNode is a testNode which doesn't actually run a protocol, instead +// exposing channels so that tests can manually trigger and expect certain +// messages +type mockNode struct { + testNode + + trigger chan *Trigger + expect chan []Expect + err chan error + stop chan struct{} + stopOnce sync.Once +} + +func newMockNode() *mockNode { + mock := &mockNode{ + trigger: make(chan *Trigger), + expect: make(chan []Expect), + err: make(chan error), + stop: make(chan struct{}), + } + mock.testNode.run = mock.Run + return mock +} + +// Run is a protocol run function which just loops waiting for tests to +// instruct it to either trigger or expect a message from the peer +func (m *mockNode) Run(peer *p2p.Peer, rw p2p.MsgReadWriter) error { + for { + select { + case trig := <-m.trigger: + m.err <- p2p.Send(rw, trig.Code, trig.Msg) + case exps := <-m.expect: + m.err <- expectMsgs(rw, exps) + case <-m.stop: + return nil + } + } +} + +func (m *mockNode) Trigger(trig *Trigger) error { + m.trigger <- trig + return <-m.err +} + +func (m *mockNode) Expect(exp ...Expect) error { + m.expect <- exp + return <-m.err +} + +func (m *mockNode) Stop() error { + m.stopOnce.Do(func() { close(m.stop) }) + return nil +} + +func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error { + matched := make([]bool, len(exps)) + for { + msg, err := rw.ReadMsg() + if err != nil { + if err == io.EOF { + break + } + return err + } + actualContent, err := ioutil.ReadAll(msg.Payload) + if err != nil { + return err + } + var found bool + for i, exp := range exps { + if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(exp.Msg)) { + if matched[i] { + return fmt.Errorf("message #%d received two times", i) + } + matched[i] = true + found = true + break + } + } + if !found { + expected := make([]string, 0) + for i, exp := range exps { + if matched[i] { + continue + } + expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(exp.Msg))) + } + return fmt.Errorf("unexpected message code %d payload %x, expected %s", msg.Code, actualContent, strings.Join(expected, " or ")) + } + done := true + for _, m := range matched { + if !m { + done = false + break + } + } + if done { + return nil + } + } + for i, m := range matched { + if !m { + return fmt.Errorf("expected message #%d not received", i) + } + } + return nil +} + +// mustEncodeMsg uses rlp to encode a message. +// In case of error it panics. +func mustEncodeMsg(msg interface{}) []byte { + contentEnc, err := rlp.EncodeToBytes(msg) + if err != nil { + panic("content encode error: " + err.Error()) + } + return contentEnc +} |