From 407339085fd6fffd15032572bc6ceb472504a451 Mon Sep 17 00:00:00 2001 From: zelig Date: Tue, 2 Jan 2018 23:30:09 +0100 Subject: p2p/protocols, p2p/testing: protocol abstraction and testing --- p2p/protocols/protocol.go | 311 ++++++++++++++++++++++++++++++++ p2p/protocols/protocol_test.go | 389 +++++++++++++++++++++++++++++++++++++++++ p2p/testing/peerpool.go | 67 +++++++ p2p/testing/protocolsession.go | 206 ++++++++++++++++++++++ p2p/testing/protocoltester.go | 200 +++++++++++++++++++++ 5 files changed, 1173 insertions(+) create mode 100644 p2p/protocols/protocol.go create mode 100644 p2p/protocols/protocol_test.go create mode 100644 p2p/testing/peerpool.go create mode 100644 p2p/testing/protocolsession.go create mode 100644 p2p/testing/protocoltester.go (limited to 'p2p') diff --git a/p2p/protocols/protocol.go b/p2p/protocols/protocol.go new file mode 100644 index 000000000..9914c9958 --- /dev/null +++ b/p2p/protocols/protocol.go @@ -0,0 +1,311 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +/* +Package protocols is an extension to p2p. It offers a simple, user-friendly way to define +devp2p subprotocols by abstracting away code commonly shared by protocols. + +* automate assignments of code indexes to messages +* automate RLP decoding/encoding based on reflection +* provide the forever loop to read incoming messages +* standardise error handling related to communication +* standardise handshake negotiation +* TODO: automatic generation of wire protocol specification for peers + +*/ +package protocols + +import ( + "context" + "fmt" + "reflect" + "sync" + + "github.com/ethereum/go-ethereum/p2p" +) + +// error codes used by this protocol scheme +const ( + ErrMsgTooLong = iota + ErrDecode + ErrWrite + ErrInvalidMsgCode + ErrInvalidMsgType + ErrHandshake + ErrNoHandler + ErrHandler +) + +// error description strings associated with the codes +var errorToString = map[int]string{ + ErrMsgTooLong: "Message too long", + ErrDecode: "Invalid message (RLP error)", + ErrWrite: "Error sending message", + ErrInvalidMsgCode: "Invalid message code", + ErrInvalidMsgType: "Invalid message type", + ErrHandshake: "Handshake error", + ErrNoHandler: "No handler registered error", + ErrHandler: "Message handler error", +} + +/* +Error implements the standard Go error interface. +Use: + + errorf(code, format, params ...interface{}) + +Prints as: + + <description>: <details>
+ +where description is given by the code in errorToString +and details is fmt.Sprintf(format, params...) + +The exported field Code can be checked. +*/ +type Error struct { + Code int + message string + format string + params []interface{} +} + +func (e Error) Error() (message string) { + if len(e.message) == 0 { + name, ok := errorToString[e.Code] + if !ok { + panic("invalid message code") + } + e.message = name + if e.format != "" { + e.message += ": " + fmt.Sprintf(e.format, e.params...) + } + } + return e.message +} + +func errorf(code int, format string, params ...interface{}) *Error { + return &Error{ + Code: code, + format: format, + params: params, + } +} + +// Spec is a protocol specification including its name and version as well as +// the types of messages which are exchanged +type Spec struct { + // Name is the name of the protocol, often a three-letter word + Name string + + // Version is the version number of the protocol + Version uint + + // MaxMsgSize is the maximum accepted length of the message payload + MaxMsgSize uint32 + + // Messages is a list of message data types which this protocol uses, with + // each message type being sent with its array index as the code (so + // [&foo{}, &bar{}, &baz{}] would send foo, bar and baz with codes + // 0, 1 and 2 respectively) + // each message must have a single unique data type + Messages []interface{} + + initOnce sync.Once + codes map[reflect.Type]uint64 + types map[uint64]reflect.Type +} + +func (s *Spec) init() { + s.initOnce.Do(func() { + s.codes = make(map[reflect.Type]uint64, len(s.Messages)) + s.types = make(map[uint64]reflect.Type, len(s.Messages)) + for i, msg := range s.Messages { + code := uint64(i) + typ := reflect.TypeOf(msg) + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + s.codes[typ] = code + s.types[code] = typ + } + }) +} + +// Length returns the number of message types in the protocol +func (s *Spec) Length() uint64 { + return uint64(len(s.Messages)) +} + +// GetCode returns the message code of a type, and the boolean second argument is +// false if the message type is not found +func (s *Spec) GetCode(msg interface{}) (uint64, bool) { + s.init() + typ := reflect.TypeOf(msg) + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + code, ok := s.codes[typ] + return code, ok +} + +// NewMsg constructs a new message type given the code +func (s *Spec) NewMsg(code uint64) (interface{}, bool) { + s.init() + typ, ok := s.types[code] + if !ok { + return nil, false + } + return reflect.New(typ).Interface(), true +} + +// Peer represents a remote peer or protocol instance that is running on a peer connection with +// a remote peer +type Peer struct { + *p2p.Peer // the p2p.Peer object representing the remote + rw p2p.MsgReadWriter // p2p.MsgReadWriter to send messages to and read messages from + spec *Spec +} + +// NewPeer constructs a new peer +// this constructor is called by the p2p.Protocol#Run function +// the first two arguments are the arguments passed to the p2p.Protocol.Run function +// the third argument is the Spec describing the protocol +func NewPeer(p *p2p.Peer, rw p2p.MsgReadWriter, spec *Spec) *Peer { + return &Peer{ + Peer: p, + rw: rw, + spec: spec, + } +} + +// Run starts the forever loop that handles incoming messages +// called within the p2p.Protocol#Run function +// the handler argument is a function which is called for each message received +// from the remote peer, a returned error causes the loop to exit +// resulting in disconnection +func (p *Peer) Run(handler func(msg interface{}) error) error {
+ for { + if err := p.handleIncoming(handler); err != nil { + return err + } + } +} + +// Drop disconnects a peer. +// TODO: may need to implement protocol drop only? don't want to kick off the peer +// if they are useful for other protocols +func (p *Peer) Drop(err error) { + p.Disconnect(p2p.DiscSubprotocolError) +} + +// Send takes a message, encodes it in RLP, finds the right message code and sends the +// message off to the peer +// this low-level call will be wrapped by libraries providing routed or broadcast sends +// but often just used to forward and push messages to directly connected peers +func (p *Peer) Send(msg interface{}) error { + code, found := p.spec.GetCode(msg) + if !found { + return errorf(ErrInvalidMsgType, "%v", code) + } + return p2p.Send(p.rw, code, msg) +} + +// handleIncoming(code) +// is called each cycle of the main forever loop that dispatches incoming messages +// if this returns an error the loop returns and the peer is disconnected with the error +// this generic handler +// * checks message size, +// * checks for out-of-range message codes, +// * handles decoding with reflection, +// * calls handlers as callbacks +func (p *Peer) handleIncoming(handle func(msg interface{}) error) error { + msg, err := p.rw.ReadMsg() + if err != nil { + return err + } + // make sure that the payload has been fully consumed + defer msg.Discard() + + if msg.Size > p.spec.MaxMsgSize { + return errorf(ErrMsgTooLong, "%v > %v", msg.Size, p.spec.MaxMsgSize) + } + + val, ok := p.spec.NewMsg(msg.Code) + if !ok { + return errorf(ErrInvalidMsgCode, "%v", msg.Code) + } + if err := msg.Decode(val); err != nil { + return errorf(ErrDecode, "<= %v: %v", msg, err) + } + + // call the registered handler callbacks + // a registered callback takes the decoded message as an argument, as an interface + // which the handler is supposed to cast to the appropriate type + // it is entirely safe not to check the cast in the handler since the handler is + // chosen based on the proper type in the first place + if err := handle(val); err != nil { + return errorf(ErrHandler, "(msg code %v): %v", msg.Code, err) + } + return nil +} + +// Handshake negotiates a handshake on the peer connection +// * arguments +// * context +// * the local handshake to be sent to the remote peer +// * function to be called on the remote handshake (can be nil) +// * expects a remote handshake back of the same type +// * the dialing peer needs to send the handshake first and then waits for the remote +// * the listening peer waits for the remote handshake and then sends it +// returns the remote handshake and an error +func (p *Peer) Handshake(ctx context.Context, hs interface{}, verify func(interface{}) error) (rhs interface{}, err error) { + if _, ok := p.spec.GetCode(hs); !ok { + return nil, errorf(ErrHandshake, "unknown handshake message type: %T", hs) + } + errc := make(chan error, 2) + handle := func(msg interface{}) error { + rhs = msg + if verify != nil { + return verify(rhs) + } + return nil + } + send := func() { errc <- p.Send(hs) } + receive := func() { errc <- p.handleIncoming(handle) } + + go func() { + if p.Inbound() { + receive() + send() + } else { + send() + receive() + } + }() + + for i := 0; i < 2; i++ { + select { + case err = <-errc: + case <-ctx.Done(): + err = ctx.Err() + } + if err != nil { + return nil, errorf(ErrHandshake, err.Error()) + } + } + return rhs, nil +} diff --git a/p2p/protocols/protocol_test.go b/p2p/protocols/protocol_test.go new file mode 100644 index 000000000..b041a650a --- /dev/null +++
b/p2p/protocols/protocol_test.go @@ -0,0 +1,389 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package protocols + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/simulations/adapters" + p2ptest "github.com/ethereum/go-ethereum/p2p/testing" +) + +// handshake message type +type hs0 struct { + C uint +} + +// message to kill/drop the peer with nodeID +type kill struct { + C discover.NodeID +} + +// message to drop connection +type drop struct { +} + +// protoHandshake represents module-independent aspects of the protocol and is +// the first message peers send and receive as part of the initial exchange +type protoHandshake struct { + Version uint // local and remote peer should have identical version + NetworkID string // local and remote peer should have identical network id +} + +// checkProtoHandshake verifies that local and remote protoHandshakes match +func checkProtoHandshake(testVersion uint, testNetworkID string) func(interface{}) error { + return func(rhs interface{}) error { + remote := rhs.(*protoHandshake) + if remote.NetworkID != testNetworkID { + return fmt.Errorf("%s (!= %s)", remote.NetworkID, testNetworkID) + } + + if remote.Version != testVersion { + return fmt.Errorf("%d (!= %d)", remote.Version, testVersion) + } + return nil + } +} + +// newProtocol sets up a protocol +// the run function here demonstrates a typical protocol using peerPool, handshake +// and messages registered to handlers +func newProtocol(pp *p2ptest.TestPeerPool) func(*p2p.Peer, p2p.MsgReadWriter) error { + spec := &Spec{ + Name: "test", + Version: 42, + MaxMsgSize: 10 * 1024, + Messages: []interface{}{ + protoHandshake{}, + hs0{}, + kill{}, + drop{}, + }, + } + return func(p *p2p.Peer, rw p2p.MsgReadWriter) error { + peer := NewPeer(p, rw, spec) + + // initiate one-off protohandshake and check validity + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + phs := &protoHandshake{42, "420"} + hsCheck := checkProtoHandshake(phs.Version, phs.NetworkID) + _, err := peer.Handshake(ctx, phs, hsCheck) + if err != nil { + return err + } + + lhs := &hs0{42} + // module handshake demonstrating a simple repeatable exchange of same-type message + hs, err := peer.Handshake(ctx, lhs, nil) + if err != nil { + return err + } + + if rmhs := hs.(*hs0); rmhs.C > lhs.C { + return fmt.Errorf("handshake mismatch remote %v > local %v", rmhs.C, lhs.C) + } + + handle := func(msg interface{}) error { + switch msg := msg.(type) { + + case *protoHandshake: + return errors.New("duplicate handshake") + + case *hs0: + rhs := msg + if rhs.C > lhs.C { + return fmt.Errorf("handshake mismatch remote %v > local %v", rhs.C, lhs.C) + } + lhs.C
+= rhs.C + return peer.Send(lhs) + + case *kill: + // demonstrates use of peerPool, killing another peer connection as a response to a message + id := msg.C + pp.Get(id).Drop(errors.New("killed")) + return nil + + case *drop: + // for testing we can trigger self induced disconnect upon receiving drop message + return errors.New("dropped") + + default: + return fmt.Errorf("unknown message type: %T", msg) + } + } + + pp.Add(peer) + defer pp.Remove(peer) + return peer.Run(handle) + } +} + +func protocolTester(t *testing.T, pp *p2ptest.TestPeerPool) *p2ptest.ProtocolTester { + conf := adapters.RandomNodeConfig() + return p2ptest.NewProtocolTester(t, conf.ID, 2, newProtocol(pp)) +} + +func protoHandshakeExchange(id discover.NodeID, proto *protoHandshake) []p2ptest.Exchange { + + return []p2ptest.Exchange{ + p2ptest.Exchange{ + Expects: []p2ptest.Expect{ + p2ptest.Expect{ + Code: 0, + Msg: &protoHandshake{42, "420"}, + Peer: id, + }, + }, + }, + p2ptest.Exchange{ + Triggers: []p2ptest.Trigger{ + p2ptest.Trigger{ + Code: 0, + Msg: proto, + Peer: id, + }, + }, + }, + } +} + +func runProtoHandshake(t *testing.T, proto *protoHandshake, errs ...error) { + pp := p2ptest.NewTestPeerPool() + s := protocolTester(t, pp) + // TODO: make this more than one handshake + id := s.IDs[0] + if err := s.TestExchanges(protoHandshakeExchange(id, proto)...); err != nil { + t.Fatal(err) + } + var disconnects []*p2ptest.Disconnect + for i, err := range errs { + disconnects = append(disconnects, &p2ptest.Disconnect{Peer: s.IDs[i], Error: err}) + } + if err := s.TestDisconnected(disconnects...); err != nil { + t.Fatal(err) + } +} + +func TestProtoHandshakeVersionMismatch(t *testing.T) { + runProtoHandshake(t, &protoHandshake{41, "420"}, errorf(ErrHandshake, errorf(ErrHandler, "(msg code 0): 41 (!= 42)").Error())) +} + +func TestProtoHandshakeNetworkIDMismatch(t *testing.T) { + runProtoHandshake(t, &protoHandshake{42, "421"}, errorf(ErrHandshake, errorf(ErrHandler, "(msg code 0): 421 (!= 420)").Error())) +} + +func TestProtoHandshakeSuccess(t *testing.T) { + runProtoHandshake(t, &protoHandshake{42, "420"}) +} + +func moduleHandshakeExchange(id discover.NodeID, resp uint) []p2ptest.Exchange { + + return []p2ptest.Exchange{ + p2ptest.Exchange{ + Expects: []p2ptest.Expect{ + p2ptest.Expect{ + Code: 1, + Msg: &hs0{42}, + Peer: id, + }, + }, + }, + p2ptest.Exchange{ + Triggers: []p2ptest.Trigger{ + p2ptest.Trigger{ + Code: 1, + Msg: &hs0{resp}, + Peer: id, + }, + }, + }, + } +} + +func runModuleHandshake(t *testing.T, resp uint, errs ...error) { + pp := p2ptest.NewTestPeerPool() + s := protocolTester(t, pp) + id := s.IDs[0] + if err := s.TestExchanges(protoHandshakeExchange(id, &protoHandshake{42, "420"})...); err != nil { + t.Fatal(err) + } + if err := s.TestExchanges(moduleHandshakeExchange(id, resp)...); err != nil { + t.Fatal(err) + } + var disconnects []*p2ptest.Disconnect + for i, err := range errs { + disconnects = append(disconnects, &p2ptest.Disconnect{Peer: s.IDs[i], Error: err}) + } + if err := s.TestDisconnected(disconnects...); err != nil { + t.Fatal(err) + } +} + +func TestModuleHandshakeError(t *testing.T) { + runModuleHandshake(t, 43, fmt.Errorf("handshake mismatch remote 43 > local 42")) +} + +func TestModuleHandshakeSuccess(t *testing.T) { + runModuleHandshake(t, 42) +} + +// testing complex interactions over multiple peers, relaying, dropping +func testMultiPeerSetup(a, b discover.NodeID) []p2ptest.Exchange { + + return []p2ptest.Exchange{ + p2ptest.Exchange{ + Label: "primary handshake", + Expects: 
[]p2ptest.Expect{ + p2ptest.Expect{ + Code: 0, + Msg: &protoHandshake{42, "420"}, + Peer: a, + }, + p2ptest.Expect{ + Code: 0, + Msg: &protoHandshake{42, "420"}, + Peer: b, + }, + }, + }, + p2ptest.Exchange{ + Label: "module handshake", + Triggers: []p2ptest.Trigger{ + p2ptest.Trigger{ + Code: 0, + Msg: &protoHandshake{42, "420"}, + Peer: a, + }, + p2ptest.Trigger{ + Code: 0, + Msg: &protoHandshake{42, "420"}, + Peer: b, + }, + }, + Expects: []p2ptest.Expect{ + p2ptest.Expect{ + Code: 1, + Msg: &hs0{42}, + Peer: a, + }, + p2ptest.Expect{ + Code: 1, + Msg: &hs0{42}, + Peer: b, + }, + }, + }, + + p2ptest.Exchange{Label: "alternative module handshake", Triggers: []p2ptest.Trigger{p2ptest.Trigger{Code: 1, Msg: &hs0{41}, Peer: a}, + p2ptest.Trigger{Code: 1, Msg: &hs0{41}, Peer: b}}}, + p2ptest.Exchange{Label: "repeated module handshake", Triggers: []p2ptest.Trigger{p2ptest.Trigger{Code: 1, Msg: &hs0{1}, Peer: a}}}, + p2ptest.Exchange{Label: "receiving repeated module handshake", Expects: []p2ptest.Expect{p2ptest.Expect{Code: 1, Msg: &hs0{43}, Peer: a}}}} +} + +func runMultiplePeers(t *testing.T, peer int, errs ...error) { + pp := p2ptest.NewTestPeerPool() + s := protocolTester(t, pp) + + if err := s.TestExchanges(testMultiPeerSetup(s.IDs[0], s.IDs[1])...); err != nil { + t.Fatal(err) + } + // after some exchanges of messages, we can test state changes + // here this is simply demonstrated by the peerPool + // after the handshake negotiations peers must be added to the pool + // time.Sleep(1) + tick := time.NewTicker(10 * time.Millisecond) + timeout := time.NewTimer(1 * time.Second) +WAIT: + for { + select { + case <-tick.C: + if pp.Has(s.IDs[0]) { + break WAIT + } + case <-timeout.C: + t.Fatal("timeout") + } + } + if !pp.Has(s.IDs[1]) { + t.Fatalf("missing peer test-1: %v (%v)", pp, s.IDs) + } + + // peer 0 sends a kill request for the peer with the given index + err := s.TestExchanges(p2ptest.Exchange{ + Triggers: []p2ptest.Trigger{ + p2ptest.Trigger{ + Code: 2, + Msg: &kill{s.IDs[peer]}, + Peer: s.IDs[0], + }, + }, + }) + + if err != nil { + t.Fatal(err) + } + + // the peer not killed sends a drop request + err = s.TestExchanges(p2ptest.Exchange{ + Triggers: []p2ptest.Trigger{ + p2ptest.Trigger{ + Code: 3, + Msg: &drop{}, + Peer: s.IDs[(peer+1)%2], + }, + }, + }) + + if err != nil { + t.Fatal(err) + } + + // check the actual disconnect errors on the individual peers + var disconnects []*p2ptest.Disconnect + for i, err := range errs { + disconnects = append(disconnects, &p2ptest.Disconnect{Peer: s.IDs[i], Error: err}) + } + if err := s.TestDisconnected(disconnects...); err != nil { + t.Fatal(err) + } + // test if disconnected peers have been removed from peerPool + if pp.Has(s.IDs[peer]) { + t.Fatalf("peer test-%v not dropped: %v (%v)", peer, pp, s.IDs) + } + +} + +func TestMultiplePeersDropSelf(t *testing.T) { + runMultiplePeers(t, 0, + fmt.Errorf("subprotocol error"), + fmt.Errorf("Message handler error: (msg code 3): dropped"), + ) +} + +func TestMultiplePeersDropOther(t *testing.T) { + runMultiplePeers(t, 1, + fmt.Errorf("Message handler error: (msg code 3): dropped"), + fmt.Errorf("subprotocol error"), + ) +} diff --git a/p2p/testing/peerpool.go b/p2p/testing/peerpool.go new file mode 100644 index 000000000..45c6e6142 --- /dev/null +++ b/p2p/testing/peerpool.go @@ -0,0 +1,67 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library.
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package testing + +import ( + "fmt" + "sync" + + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/discover" +) + +type TestPeer interface { + ID() discover.NodeID + Drop(error) +} + +// TestPeerPool is an example peerPool to demonstrate registration of peer connections +type TestPeerPool struct { + lock sync.Mutex + peers map[discover.NodeID]TestPeer +} + +func NewTestPeerPool() *TestPeerPool { + return &TestPeerPool{peers: make(map[discover.NodeID]TestPeer)} +} + +func (self *TestPeerPool) Add(p TestPeer) { + self.lock.Lock() + defer self.lock.Unlock() + log.Trace(fmt.Sprintf("pp add peer %v", p.ID())) + self.peers[p.ID()] = p + +} + +func (self *TestPeerPool) Remove(p TestPeer) { + self.lock.Lock() + defer self.lock.Unlock() + delete(self.peers, p.ID()) +} + +func (self *TestPeerPool) Has(id discover.NodeID) bool { + self.lock.Lock() + defer self.lock.Unlock() + _, ok := self.peers[id] + return ok +} + +func (self *TestPeerPool) Get(id discover.NodeID) TestPeer { + self.lock.Lock() + defer self.lock.Unlock() + return self.peers[id] +} diff --git a/p2p/testing/protocolsession.go b/p2p/testing/protocolsession.go new file mode 100644 index 000000000..a779aeebb --- /dev/null +++ b/p2p/testing/protocolsession.go @@ -0,0 +1,206 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
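As a usage sketch (not part of the patch), the TestPeerPool above can be driven as follows; dummyPeer and the node ID value are hypothetical stand-ins for a real peer connection:

package main

import (
	"errors"
	"fmt"

	"github.com/ethereum/go-ethereum/p2p/discover"
	p2ptest "github.com/ethereum/go-ethereum/p2p/testing"
)

// dummyPeer is a hypothetical minimal TestPeer implementation, for illustration only.
type dummyPeer struct{ id discover.NodeID }

func (d *dummyPeer) ID() discover.NodeID { return d.id }
func (d *dummyPeer) Drop(err error)      { fmt.Println("dropped:", err) }

func main() {
	pp := p2ptest.NewTestPeerPool()
	peer := &dummyPeer{id: discover.NodeID{1}}

	pp.Add(peer)                   // register the live connection under its node ID
	fmt.Println(pp.Has(peer.ID())) // true while the peer is registered
	pp.Get(peer.ID()).Drop(errors.New("killed")) // look a peer up and disconnect it
	pp.Remove(peer)                // deregister on teardown
}

This mirrors how the test protocol above uses the pool: Add after a successful handshake, Get to kill another connection, Remove on exit.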
+ +package testing + +import ( + "errors" + "fmt" + "time" + + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/simulations/adapters" +) + +// ProtocolSession is a quasi-simulation of a pivot node running +// a service and a number of dummy peers that can send (trigger) or +// receive (expect) messages +type ProtocolSession struct { + Server *p2p.Server + IDs []discover.NodeID + adapter *adapters.SimAdapter + events chan *p2p.PeerEvent +} + +// Exchange is the basic unit of protocol tests +// the triggers and expects in the arrays are run immediately and asynchronously +// thus one cannot have multiple expects for the SAME peer with DIFFERENT message types +// because it's unpredictable which expect will receive which message +// (with expect #1 and #2, messages might be sent #2 and #1, and both expects will complain about wrong message code) +// an exchange is defined on a session +type Exchange struct { + Label string + Triggers []Trigger + Expects []Expect +} + +// Trigger is part of the exchange, incoming message for the pivot node +// sent by a peer +type Trigger struct { + Msg interface{} // type of message to be sent + Code uint64 // code of the message + Peer discover.NodeID // the peer to send the message to + Timeout time.Duration // timeout duration for the sending +} + +// Expect is part of an exchange, outgoing message from the pivot node +// received by a peer +type Expect struct { + Msg interface{} // type of message to expect + Code uint64 // code of the expected message + Peer discover.NodeID // the peer that expects the message + Timeout time.Duration // timeout duration for receiving +} + +// Disconnect represents a disconnect event, used and checked by TestDisconnected +type Disconnect struct { + Peer discover.NodeID // disconnected peer + Error error // disconnect reason +} + +// trigger sends messages from peers +func (self *ProtocolSession) trigger(trig Trigger) error { + simNode, ok := self.adapter.GetNode(trig.Peer) + if !ok { + return fmt.Errorf("trigger: peer %v does not exist (1- %v)", trig.Peer, len(self.IDs)) + } + mockNode, ok := simNode.Services()[0].(*mockNode) + if !ok { + return fmt.Errorf("trigger: peer %v is not a mock", trig.Peer) + } + + errc := make(chan error) + + go func() { + errc <- mockNode.Trigger(&trig) + }() + + t := trig.Timeout + if t == time.Duration(0) { + t = 1000 * time.Millisecond + } + select { + case err := <-errc: + return err + case <-time.After(t): + return fmt.Errorf("timeout expecting %v to send to peer %v", trig.Msg, trig.Peer) + } +} + +// expect checks an expectation of a message sent out by the pivot node +func (self *ProtocolSession) expect(exp Expect) error { + if exp.Msg == nil { + return errors.New("no message to expect") + } + simNode, ok := self.adapter.GetNode(exp.Peer) + if !ok { + return fmt.Errorf("expect: peer %v does not exist (1- %v)", exp.Peer, len(self.IDs)) + } + mockNode, ok := simNode.Services()[0].(*mockNode) + if !ok { + return fmt.Errorf("expect: peer %v is not a mock", exp.Peer) + } + + errc := make(chan error) + go func() { + errc <- mockNode.Expect(&exp) + }() + + t := exp.Timeout + if t == time.Duration(0) { + t = 2000 * time.Millisecond + } + select { + case err := <-errc: + return err + case <-time.After(t): + return fmt.Errorf("timeout expecting %v sent to peer %v", exp.Msg, exp.Peer) + } +} + +// TestExchanges tests a series of exchanges against the session +func (self *ProtocolSession) TestExchanges(exchanges
...Exchange) error { + // launch all triggers of this exchange + + for _, e := range exchanges { + errc := make(chan error, len(e.Triggers)+len(e.Expects)) + for _, trig := range e.Triggers { + errc <- self.trigger(trig) + } + + // each expectation is spawned in a separate goroutine + // expectations of an exchange are conjunctive but unordered, i.e., + // only all of them arriving constitutes a pass + // each expectation is meant to be for a different peer, otherwise they are expected to panic + // testing of an exchange blocks until all expectations are decided + // an expectation is decided if + // the expected message arrives OR + // an unexpected message arrives (panic) + // it times out on its individual timeout + for _, ex := range e.Expects { + // each expect is spawned in a separate goroutine + go func(exp Expect) { + errc <- self.expect(exp) + }(ex) + } + + // time out globally or finish when all expectations satisfied + timeout := time.After(5 * time.Second) + for i := 0; i < len(e.Triggers)+len(e.Expects); i++ { + select { + case err := <-errc: + if err != nil { + return fmt.Errorf("exchange failed with: %v", err) + } + case <-timeout: + return fmt.Errorf("exchange %v: '%v' timed out", i, e.Label) + } + } + } + return nil +} + +// TestDisconnected tests the disconnections given as arguments +// the disconnect structs describe what disconnect error is expected on which peer +func (self *ProtocolSession) TestDisconnected(disconnects ...*Disconnect) error { + expects := make(map[discover.NodeID]error) + for _, disconnect := range disconnects { + expects[disconnect.Peer] = disconnect.Error + } + + timeout := time.After(time.Second) + for len(expects) > 0 { + select { + case event := <-self.events: + if event.Type != p2p.PeerEventTypeDrop { + continue + } + expectErr, ok := expects[event.Peer] + if !ok { + continue + } + + if !(expectErr == nil && event.Error == "" || expectErr != nil && expectErr.Error() == event.Error) { + return fmt.Errorf("unexpected error on peer %v. expected '%v', got '%v'", event.Peer, expectErr, event.Error) + } + delete(expects, event.Peer) + case <-timeout: + return fmt.Errorf("timed out waiting for peers to disconnect") + } + } + return nil +} diff --git a/p2p/testing/protocoltester.go b/p2p/testing/protocoltester.go new file mode 100644 index 000000000..ea5b106ff --- /dev/null +++ b/p2p/testing/protocoltester.go @@ -0,0 +1,200 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
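To make the Exchange/Trigger/Expect API above concrete, here is a hedged sketch of a test driving a session; it reuses the hs0 message type and message code 1 from the protocol_test.go exchanges above, and assumes s is the *ProtocolTester built in protocoltester.go below (which embeds ProtocolSession):

// a minimal sketch, assuming s comes from NewProtocolTester and the protocol
// under test registers hs0 (see protocol_test.go above) at message code 1
func exampleExchange(t *testing.T, s *ProtocolTester) {
	id := s.IDs[0] // one of the dummy peers
	err := s.TestExchanges(Exchange{
		Label:    "repeated module handshake",
		Triggers: []Trigger{{Code: 1, Msg: &hs0{C: 1}, Peer: id}},  // dummy peer -> pivot
		Expects:  []Expect{{Code: 1, Msg: &hs0{C: 43}, Peer: id}}, // pivot -> dummy peer
	})
	if err != nil {
		t.Fatal(err)
	}
	// assert that the pivot subsequently dropped the peer with the given reason
	if err := s.TestDisconnected(&Disconnect{Peer: id, Error: errors.New("dropped")}); err != nil {
		t.Fatal(err)
	}
}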
+ +/* +the p2p/testing package provides a unit test scheme to check simple +protocol message exchanges with one pivot node and a number of dummy peers. +The pivot test node runs a node.Service; the dummy peers run a mock node +that can be used to send and receive messages. +*/ + +package testing + +import ( + "fmt" + "sync" + "testing" + + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/node" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/simulations" + "github.com/ethereum/go-ethereum/p2p/simulations/adapters" + "github.com/ethereum/go-ethereum/rpc" +) + +// ProtocolTester is the tester environment used for unit testing protocol +// message exchanges. It uses the p2p/simulations framework +type ProtocolTester struct { + *ProtocolSession + network *simulations.Network +} + +// NewProtocolTester constructs a new ProtocolTester +// it takes as argument the pivot node id, the number of dummy peers and the +// protocol run function called on a peer connection by the p2p server +func NewProtocolTester(t *testing.T, id discover.NodeID, n int, run func(*p2p.Peer, p2p.MsgReadWriter) error) *ProtocolTester { + services := adapters.Services{ + "test": func(ctx *adapters.ServiceContext) (node.Service, error) { + return &testNode{run}, nil + }, + "mock": func(ctx *adapters.ServiceContext) (node.Service, error) { + return newMockNode(), nil + }, + } + adapter := adapters.NewSimAdapter(services) + net := simulations.NewNetwork(adapter, &simulations.NetworkConfig{}) + if _, err := net.NewNodeWithConfig(&adapters.NodeConfig{ + ID: id, + EnableMsgEvents: true, + Services: []string{"test"}, + }); err != nil { + panic(err.Error()) + } + if err := net.Start(id); err != nil { + panic(err.Error()) + } + + node := net.GetNode(id).Node.(*adapters.SimNode) + peers := make([]*adapters.NodeConfig, n) + peerIDs := make([]discover.NodeID, n) + for i := 0; i < n; i++ { + peers[i] = adapters.RandomNodeConfig() + peers[i].Services = []string{"mock"} + peerIDs[i] = peers[i].ID + } + events := make(chan *p2p.PeerEvent, 1000) + node.SubscribeEvents(events) + ps := &ProtocolSession{ + Server: node.Server(), + IDs: peerIDs, + adapter: adapter, + events: events, + } + self := &ProtocolTester{ + ProtocolSession: ps, + network: net, + } + + self.Connect(id, peers...)
+ + return self +} + +// Stop stops the p2p server +func (self *ProtocolTester) Stop() error { + self.Server.Stop() + return nil +} + +// Connect brings up the remote peer node and connects it using the +// p2p/simulations network connection with the in memory network adapter +func (self *ProtocolTester) Connect(selfID discover.NodeID, peers ...*adapters.NodeConfig) { + for _, peer := range peers { + log.Trace(fmt.Sprintf("start node %v", peer.ID)) + if _, err := self.network.NewNodeWithConfig(peer); err != nil { + panic(fmt.Sprintf("error starting peer %v: %v", peer.ID, err)) + } + if err := self.network.Start(peer.ID); err != nil { + panic(fmt.Sprintf("error starting peer %v: %v", peer.ID, err)) + } + log.Trace(fmt.Sprintf("connect to %v", peer.ID)) + if err := self.network.Connect(selfID, peer.ID); err != nil { + panic(fmt.Sprintf("error connecting to peer %v: %v", peer.ID, err)) + } + } + +} + +// testNode wraps a protocol run function and implements the node.Service +// interface +type testNode struct { + run func(*p2p.Peer, p2p.MsgReadWriter) error +} + +func (t *testNode) Protocols() []p2p.Protocol { + return []p2p.Protocol{{ + Length: 100, + Run: t.run, + }} +} + +func (t *testNode) APIs() []rpc.API { + return nil +} + +func (t *testNode) Start(server *p2p.Server) error { + return nil +} + +func (t *testNode) Stop() error { + return nil +} + +// mockNode is a testNode which doesn't actually run a protocol, instead +// exposing channels so that tests can manually trigger and expect certain +// messages +type mockNode struct { + testNode + + trigger chan *Trigger + expect chan *Expect + err chan error + stop chan struct{} + stopOnce sync.Once +} + +func newMockNode() *mockNode { + mock := &mockNode{ + trigger: make(chan *Trigger), + expect: make(chan *Expect), + err: make(chan error), + stop: make(chan struct{}), + } + mock.testNode.run = mock.Run + return mock +} + +// Run is a protocol run function which just loops waiting for tests to +// instruct it to either trigger or expect a message from the peer +func (m *mockNode) Run(peer *p2p.Peer, rw p2p.MsgReadWriter) error { + for { + select { + case trig := <-m.trigger: + m.err <- p2p.Send(rw, trig.Code, trig.Msg) + case exp := <-m.expect: + m.err <- p2p.ExpectMsg(rw, exp.Code, exp.Msg) + case <-m.stop: + return nil + } + } +} + +func (m *mockNode) Trigger(trig *Trigger) error { + m.trigger <- trig + return <-m.err +} + +func (m *mockNode) Expect(exp *Expect) error { + m.expect <- exp + return <-m.err +} + +func (m *mockNode) Stop() error { + m.stopOnce.Do(func() { close(m.stop) }) + return nil +} -- cgit v1.2.3 From c4712bf96bc1bae4a5ad4600e9719e4a74bde7d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felf=C3=B6ldi=20Zsolt?= Date: Thu, 8 Feb 2018 18:06:31 +0100 Subject: p2p/discv5: fix multiple discovery issues (#16036) * p2p/discv5: add query delay, fix node address update logic, retry refresh if empty * p2p/discv5: remove unnecessary ping before topic query * p2p/discv5: do not filter local address from topicNodes * p2p/discv5: remove canQuery() * p2p/discv5: gofmt --- p2p/discv5/net.go | 40 +++++++++++++++++++++++++--------------- p2p/discv5/ticket.go | 4 ++-- p2p/discv5/udp.go | 20 ++++++++++---------- 3 files changed, 37 insertions(+), 27 deletions(-) (limited to 'p2p') diff --git a/p2p/discv5/net.go b/p2p/discv5/net.go index f9baf126f..52c677b62 100644 --- a/p2p/discv5/net.go +++ b/p2p/discv5/net.go @@ -565,11 +565,8 @@ loop: if lookupChn := searchInfo[res.target.topic].lookupChn; lookupChn != nil { lookupChn <- 
net.ticketStore.radius[res.target.topic].converged } - net.ticketStore.searchLookupDone(res.target, res.nodes, func(n *Node) []byte { - net.ping(n, n.addr()) - return n.pingEcho - }, func(n *Node, topic Topic) []byte { - if n.state == known { + net.ticketStore.searchLookupDone(res.target, res.nodes, func(n *Node, topic Topic) []byte { + if n.state != nil && n.state.canQuery { return net.conn.send(n, topicQueryPacket, topicQuery{Topic: topic}) // TODO: set expiration } else { if n.state == unknown { @@ -633,15 +630,20 @@ loop: } net.refreshResp <- refreshDone case <-refreshDone: - log.Trace("<-net.refreshDone") - refreshDone = nil - list := searchReqWhenRefreshDone - searchReqWhenRefreshDone = nil - go func() { - for _, req := range list { - net.topicSearchReq <- req - } - }() + log.Trace("<-net.refreshDone", "table size", net.tab.count) + if net.tab.count != 0 { + refreshDone = nil + list := searchReqWhenRefreshDone + searchReqWhenRefreshDone = nil + go func() { + for _, req := range list { + net.topicSearchReq <- req + } + }() + } else { + refreshDone = make(chan struct{}) + net.refresh(refreshDone) + } } } log.Trace("loop stopped") @@ -751,7 +753,15 @@ func (net *Network) internNodeFromNeighbours(sender *net.UDPAddr, rn rpcNode) (n return n, err } if !n.IP.Equal(rn.IP) || n.UDP != rn.UDP || n.TCP != rn.TCP { - err = fmt.Errorf("metadata mismatch: got %v, want %v", rn, n) + if n.state == known { + // reject address change if node is known by us + err = fmt.Errorf("metadata mismatch: got %v, want %v", rn, n) + } else { + // accept otherwise; this will be handled nicer with signed ENRs + n.IP = rn.IP + n.UDP = rn.UDP + n.TCP = rn.TCP + } } return n, err } diff --git a/p2p/discv5/ticket.go b/p2p/discv5/ticket.go index 37ce8d23c..b3d1ac4ba 100644 --- a/p2p/discv5/ticket.go +++ b/p2p/discv5/ticket.go @@ -494,13 +494,13 @@ func (s *ticketStore) registerLookupDone(lookup lookupInfo, nodes []*Node, ping } } -func (s *ticketStore) searchLookupDone(lookup lookupInfo, nodes []*Node, ping func(n *Node) []byte, query func(n *Node, topic Topic) []byte) { +func (s *ticketStore) searchLookupDone(lookup lookupInfo, nodes []*Node, query func(n *Node, topic Topic) []byte) { now := mclock.Now() for i, n := range nodes { if i == 0 || (binary.BigEndian.Uint64(n.sha[:8])^binary.BigEndian.Uint64(lookup.target[:8])) < s.radius[lookup.topic].minRadius { if lookup.radiusLookup { if lastReq, ok := s.nodeLastReq[n]; !ok || time.Duration(now-lastReq.time) > radiusTC { - s.nodeLastReq[n] = reqInfo{pingHash: ping(n), lookup: lookup, time: now} + s.nodeLastReq[n] = reqInfo{pingHash: nil, lookup: lookup, time: now} } } // else { if s.canQueryTopic(n, lookup.topic) { diff --git a/p2p/discv5/udp.go b/p2p/discv5/udp.go index 543771817..6ce72d2c1 100644 --- a/p2p/discv5/udp.go +++ b/p2p/discv5/udp.go @@ -49,7 +49,7 @@ var ( // Timeouts const ( respTimeout = 500 * time.Millisecond - sendTimeout = 500 * time.Millisecond + queryDelay = 1000 * time.Millisecond expiration = 20 * time.Second ntpFailureThreshold = 32 // Continuous timeouts after which to check NTP @@ -318,20 +318,20 @@ func (t *udp) sendTopicRegister(remote *Node, topics []Topic, idx int, pong []by func (t *udp) sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node) { p := topicNodes{Echo: queryHash} - if len(nodes) == 0 { - t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p) - return - } - for i, result := range nodes { - if netutil.CheckRelayIP(remote.IP, result.IP) != nil { - continue + var sent bool + for _, result := range nodes { + if 
result.IP.Equal(t.net.tab.self.IP) || netutil.CheckRelayIP(remote.IP, result.IP) == nil { + p.Nodes = append(p.Nodes, nodeToRPC(result)) } - p.Nodes = append(p.Nodes, nodeToRPC(result)) - if len(p.Nodes) == maxTopicNodes || i == len(nodes)-1 { + if len(p.Nodes) == maxTopicNodes { t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p) p.Nodes = p.Nodes[:0] + sent = true } } + if !sent || len(p.Nodes) > 0 { + t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p) + } } func (t *udp) sendPacket(toid NodeID, toaddr *net.UDPAddr, ptype byte, req interface{}) (hash []byte, err error) { -- cgit v1.2.3 From 9123eceb0f78f69e88d909a56ad7fadb75570198 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Mon, 12 Feb 2018 13:36:09 +0100 Subject: p2p, p2p/discover: misc connectivity improvements (#16069) * p2p: add DialRatio for configuration of inbound vs. dialed connections * p2p: add connection flags to PeerInfo * p2p/netutil: add SameNet, DistinctNetSet * p2p/discover: improve revalidation and seeding This changes node revalidation to be periodic instead of on-demand. This should prevent issues where dead nodes get stuck in closer buckets because no other node will ever come along to replace them. Every 5 seconds (on average), the last node in a random bucket is checked and moved to the front of the bucket if it is still responding. If revalidation fails, the last node is replaced by an entry of the 'replacement list' containing recently-seen nodes. Most close buckets are removed because it's very unlikely we'll ever encounter a node that would fall into any of those buckets. Table seeding is also improved: we now require a few minutes of table membership before considering a node as a potential seed node. This should make it less likely to store short-lived nodes as potential seeds. * p2p/discover: fix nits in UDP transport We would skip sending neighbors replies if there were fewer than maxNeighbors results and CheckRelayIP returned an error for the last one. While here, also resolve a TODO about pong reply tokens. --- p2p/discover/node.go | 6 +- p2p/discover/table.go | 484 +++++++++++++++++++++++++++++++-------------- p2p/discover/table_test.go | 171 ++++++++++------ p2p/discover/udp.go | 86 +++++--- p2p/discover/udp_test.go | 21 +- p2p/netutil/net.go | 131 ++++++++++++ p2p/netutil/net_test.go | 89 +++++++++ p2p/peer.go | 6 + p2p/server.go | 77 +++++--- 9 files changed, 795 insertions(+), 276 deletions(-) (limited to 'p2p') diff --git a/p2p/discover/node.go b/p2p/discover/node.go index fc928a91a..3b0c84115 100644 --- a/p2p/discover/node.go +++ b/p2p/discover/node.go @@ -29,6 +29,7 @@ import ( "regexp" "strconv" "strings" + "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" @@ -51,9 +52,8 @@ type Node struct { // with ID. sha common.Hash - // whether this node is currently being pinged in order to replace - // it in a bucket - contested bool + // Time when the node was added to the table. + addedAt time.Time } // NewNode creates a new node. 
It is mostly meant to be used for diff --git a/p2p/discover/table.go b/p2p/discover/table.go index ec4eb94ad..84c54dac1 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -23,10 +23,11 @@ package discover import ( - "crypto/rand" + crand "crypto/rand" "encoding/binary" "errors" "fmt" + mrand "math/rand" "net" "sort" "sync" @@ -35,29 +36,45 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/netutil" ) const ( - alpha = 3 // Kademlia concurrency factor - bucketSize = 16 // Kademlia bucket size - hashBits = len(common.Hash{}) * 8 - nBuckets = hashBits + 1 // Number of buckets - - maxBondingPingPongs = 16 - maxFindnodeFailures = 5 - - autoRefreshInterval = 1 * time.Hour - seedCount = 30 - seedMaxAge = 5 * 24 * time.Hour + alpha = 3 // Kademlia concurrency factor + bucketSize = 16 // Kademlia bucket size + maxReplacements = 10 // Size of per-bucket replacement list + + // We keep buckets for the upper 1/15 of distances because + // it's very unlikely we'll ever encounter a node that's closer. + hashBits = len(common.Hash{}) * 8 + nBuckets = hashBits / 15 // Number of buckets + bucketMinDistance = hashBits - nBuckets // Log distance of closest bucket + + // IP address limits. + bucketIPLimit, bucketSubnet = 2, 24 // at most 2 addresses from the same /24 + tableIPLimit, tableSubnet = 10, 24 + + maxBondingPingPongs = 16 // Limit on the number of concurrent ping/pong interactions + maxFindnodeFailures = 5 // Nodes exceeding this limit are dropped + + refreshInterval = 30 * time.Minute + revalidateInterval = 10 * time.Second + copyNodesInterval = 30 * time.Second + seedMinTableTime = 5 * time.Minute + seedCount = 30 + seedMaxAge = 5 * 24 * time.Hour ) type Table struct { - mutex sync.Mutex // protects buckets, their content, and nursery + mutex sync.Mutex // protects buckets, bucket content, nursery, rand buckets [nBuckets]*bucket // index of known nodes by distance nursery []*Node // bootstrap nodes - db *nodeDB // database of known nodes + rand *mrand.Rand // source of randomness, periodically reseeded + ips netutil.DistinctNetSet + db *nodeDB // database of known nodes refreshReq chan chan struct{} + initDone chan struct{} closeReq chan struct{} closed chan struct{} @@ -89,9 +106,13 @@ type transport interface { // bucket contains nodes, ordered by their last activity. the entry // that was most recently active is the first element in entries. 
-type bucket struct{ entries []*Node } +type bucket struct { + entries []*Node // live entries, sorted by time of last contact + replacements []*Node // recently seen nodes to be used if revalidation fails + ips netutil.DistinctNetSet +} -func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string) (*Table, error) { +func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string, bootnodes []*Node) (*Table, error) { // If no node database was given, use an in-memory one db, err := newNodeDB(nodeDBPath, Version, ourID) if err != nil { @@ -104,19 +125,42 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string bonding: make(map[NodeID]*bondproc), bondslots: make(chan struct{}, maxBondingPingPongs), refreshReq: make(chan chan struct{}), + initDone: make(chan struct{}), closeReq: make(chan struct{}), closed: make(chan struct{}), + rand: mrand.New(mrand.NewSource(0)), + ips: netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit}, + } + if err := tab.setFallbackNodes(bootnodes); err != nil { + return nil, err } for i := 0; i < cap(tab.bondslots); i++ { tab.bondslots <- struct{}{} } for i := range tab.buckets { - tab.buckets[i] = new(bucket) + tab.buckets[i] = &bucket{ + ips: netutil.DistinctNetSet{Subnet: bucketSubnet, Limit: bucketIPLimit}, + } } - go tab.refreshLoop() + tab.seedRand() + tab.loadSeedNodes(false) + // Start the background expiration goroutine after loading seeds so that the search for + // seed nodes also considers older nodes that would otherwise be removed by the + // expiration. + tab.db.ensureExpirer() + go tab.loop() return tab, nil } +func (tab *Table) seedRand() { + var b [8]byte + crand.Read(b[:]) + + tab.mutex.Lock() + tab.rand.Seed(int64(binary.BigEndian.Uint64(b[:]))) + tab.mutex.Unlock() +} + // Self returns the local node. // The returned node should not be modified by the caller. func (tab *Table) Self() *Node { @@ -127,9 +171,12 @@ func (tab *Table) Self() *Node { // table. It will not write the same node more than once. The nodes in // the slice are copies and can be modified by the caller. func (tab *Table) ReadRandomNodes(buf []*Node) (n int) { + if !tab.isInitDone() { + return 0 + } tab.mutex.Lock() defer tab.mutex.Unlock() - // TODO: tree-based buckets would help here + // Find all non-empty buckets and get a fresh slice of their entries. var buckets [][]*Node for _, b := range tab.buckets { @@ -141,8 +188,8 @@ func (tab *Table) ReadRandomNodes(buf []*Node) (n int) { return 0 } // Shuffle the buckets. - for i := uint32(len(buckets)) - 1; i > 0; i-- { - j := randUint(i) + for i := len(buckets) - 1; i > 0; i-- { + j := tab.rand.Intn(len(buckets)) buckets[i], buckets[j] = buckets[j], buckets[i] } // Move head of each bucket into buf, removing buckets that become empty. @@ -161,15 +208,6 @@ func (tab *Table) ReadRandomNodes(buf []*Node) (n int) { return i + 1 } -func randUint(max uint32) uint32 { - if max == 0 { - return 0 - } - var b [4]byte - rand.Read(b[:]) - return binary.BigEndian.Uint32(b[:]) % max -} - // Close terminates the network listener and flushes the node database. func (tab *Table) Close() { select { @@ -180,16 +218,15 @@ func (tab *Table) Close() { } } -// SetFallbackNodes sets the initial points of contact. These nodes +// setFallbackNodes sets the initial points of contact. These nodes // are used to connect to the network if the table is empty and there // are no known nodes in the database. 
-func (tab *Table) SetFallbackNodes(nodes []*Node) error { +func (tab *Table) setFallbackNodes(nodes []*Node) error { for _, n := range nodes { if err := n.validateComplete(); err != nil { return fmt.Errorf("bad bootstrap/fallback node %q (%v)", n, err) } } - tab.mutex.Lock() tab.nursery = make([]*Node, 0, len(nodes)) for _, n := range nodes { cpy := *n @@ -198,11 +235,19 @@ func (tab *Table) SetFallbackNodes(nodes []*Node) error { cpy.sha = crypto.Keccak256Hash(n.ID[:]) tab.nursery = append(tab.nursery, &cpy) } - tab.mutex.Unlock() - tab.refresh() return nil } +// isInitDone returns whether the table's initial seeding procedure has completed. +func (tab *Table) isInitDone() bool { + select { + case <-tab.initDone: + return true + default: + return false + } +} + // Resolve searches for a specific node with the given ID. // It returns nil if the node could not be found. func (tab *Table) Resolve(targetID NodeID) *Node { @@ -314,33 +359,49 @@ func (tab *Table) refresh() <-chan struct{} { return done } -// refreshLoop schedules doRefresh runs and coordinates shutdown. -func (tab *Table) refreshLoop() { +// loop schedules refresh, revalidate runs and coordinates shutdown. +func (tab *Table) loop() { var ( - timer = time.NewTicker(autoRefreshInterval) - waiting []chan struct{} // accumulates waiting callers while doRefresh runs - done chan struct{} // where doRefresh reports completion + revalidate = time.NewTimer(tab.nextRevalidateTime()) + refresh = time.NewTicker(refreshInterval) + copyNodes = time.NewTicker(copyNodesInterval) + revalidateDone = make(chan struct{}) + refreshDone = make(chan struct{}) // where doRefresh reports completion + waiting = []chan struct{}{tab.initDone} // holds waiting callers while doRefresh runs ) + defer refresh.Stop() + defer revalidate.Stop() + defer copyNodes.Stop() + + // Start initial refresh. + go tab.doRefresh(refreshDone) + loop: for { select { - case <-timer.C: - if done == nil { - done = make(chan struct{}) - go tab.doRefresh(done) + case <-refresh.C: + tab.seedRand() + if refreshDone == nil { + refreshDone = make(chan struct{}) + go tab.doRefresh(refreshDone) } case req := <-tab.refreshReq: waiting = append(waiting, req) - if done == nil { - done = make(chan struct{}) - go tab.doRefresh(done) + if refreshDone == nil { + refreshDone = make(chan struct{}) + go tab.doRefresh(refreshDone) } - case <-done: + case <-refreshDone: for _, ch := range waiting { close(ch) } - waiting = nil - done = nil + waiting, refreshDone = nil, nil + case <-revalidate.C: + go tab.doRevalidate(revalidateDone) + case <-revalidateDone: + revalidate.Reset(tab.nextRevalidateTime()) + case <-copyNodes.C: + go tab.copyBondedNodes() case <-tab.closeReq: break loop } @@ -349,8 +410,8 @@ loop: if tab.net != nil { tab.net.close() } - if done != nil { - <-done + if refreshDone != nil { + <-refreshDone } for _, ch := range waiting { close(ch) @@ -365,38 +426,109 @@ loop: func (tab *Table) doRefresh(done chan struct{}) { defer close(done) + // Load nodes from the database and insert + // them. This should yield a few previously seen nodes that are + // (hopefully) still alive. + tab.loadSeedNodes(true) + + // Run self lookup to discover new neighbor nodes. + tab.lookup(tab.self.ID, false) + // The Kademlia paper specifies that the bucket refresh should // perform a lookup in the least recently used bucket. 
We cannot +// adhere to this because the findnode target is a 512bit value +// (not hash-sized) and it is not easily possible to generate a +// sha3 preimage that falls into a chosen bucket. - // We perform a lookup with a random target instead. - var target NodeID - rand.Read(target[:]) - result := tab.lookup(target, false) - if len(result) > 0 { - return + // We perform a few lookups with a random target instead. + for i := 0; i < 3; i++ { + var target NodeID + crand.Read(target[:]) + tab.lookup(target, false) } +} - // The table is empty. Load nodes from the database and insert - // them. This should yield a few previously seen nodes that are - // (hopefully) still alive. +func (tab *Table) loadSeedNodes(bond bool) { seeds := tab.db.querySeeds(seedCount, seedMaxAge) - seeds = tab.bondall(append(seeds, tab.nursery...)) + seeds = append(seeds, tab.nursery...) + if bond { + seeds = tab.bondall(seeds) + } + for i := range seeds { + seed := seeds[i] + age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.lastPong(seed.ID)) }} + log.Debug("Found seed node in database", "id", seed.ID, "addr", seed.addr(), "age", age) + tab.add(seed) + } +} + +// doRevalidate checks that the last node in a random bucket is still live +// and replaces or deletes the node if it isn't. +func (tab *Table) doRevalidate(done chan<- struct{}) { + defer func() { done <- struct{}{} }() - if len(seeds) == 0 { - log.Debug("No discv4 seed nodes found") + last, bi := tab.nodeToRevalidate() + if last == nil { + // No non-empty bucket found. + return + } + + // Ping the selected node and wait for a pong. + err := tab.ping(last.ID, last.addr()) + + tab.mutex.Lock() + defer tab.mutex.Unlock() + b := tab.buckets[bi] + if err == nil { + // The node responded, move it to the front. + log.Debug("Revalidated node", "b", bi, "id", last.ID) + b.bump(last) + return + } + // No reply received, pick a replacement or delete the node if there aren't + // any replacements. + if r := tab.replace(b, last); r != nil { + log.Debug("Replaced dead node", "b", bi, "id", last.ID, "ip", last.IP, "r", r.ID, "rip", r.IP) + } else { + log.Debug("Removed dead node", "b", bi, "id", last.ID, "ip", last.IP) } - for _, n := range seeds { - age := log.Lazy{Fn: func() time.Duration { return time.Since(tab.db.lastPong(n.ID)) }} - log.Trace("Found seed node in database", "id", n.ID, "addr", n.addr(), "age", age) +} + +// nodeToRevalidate returns the last node in a random, non-empty bucket. +func (tab *Table) nodeToRevalidate() (n *Node, bi int) { + tab.mutex.Lock() + defer tab.mutex.Unlock() + + for _, bi = range tab.rand.Perm(len(tab.buckets)) { + b := tab.buckets[bi] + if len(b.entries) > 0 { + last := b.entries[len(b.entries)-1] + return last, bi + } } + return nil, 0 +} + +func (tab *Table) nextRevalidateTime() time.Duration { tab.mutex.Lock() - tab.stuff(seeds) - tab.mutex.Unlock() + defer tab.mutex.Unlock() - // Finally, do a self lookup to fill up the buckets. - tab.lookup(tab.self.ID, false) + return time.Duration(tab.rand.Int63n(int64(revalidateInterval))) +} + +// copyBondedNodes adds nodes from the table to the database if they have been in the table +// longer than minTableTime.
+func (tab *Table) copyBondedNodes() { + tab.mutex.Lock() + defer tab.mutex.Unlock() + + now := time.Now() + for _, b := range tab.buckets { + for _, n := range b.entries { + if now.Sub(n.addedAt) >= seedMinTableTime { + tab.db.updateNode(n) + } + } + } } // closest returns the n nodes in the table that are closest to the @@ -459,15 +591,14 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16 if id == tab.self.ID { return nil, errors.New("is self") } - // Retrieve a previously known node and any recent findnode failures - node, fails := tab.db.node(id), 0 - if node != nil { - fails = tab.db.findFails(id) + if pinged && !tab.isInitDone() { + return nil, errors.New("still initializing") } - // If the node is unknown (non-bonded) or failed (remotely unknown), bond from scratch - var result error + // Start bonding if we haven't seen this node for a while or if it failed findnode too often. + node, fails := tab.db.node(id), tab.db.findFails(id) age := time.Since(tab.db.lastPong(id)) - if node == nil || fails > 0 || age > nodeDBNodeExpiration { + var result error + if fails > 0 || age > nodeDBNodeExpiration { log.Trace("Starting bonding ping/pong", "id", id, "known", node != nil, "failcount", fails, "age", age) tab.bondmu.Lock() @@ -494,10 +625,10 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16 node = w.n } } + // Add the node to the table even if the bonding ping/pong + // fails. It will be replaced quickly if it continues to be + // unresponsive. if node != nil { - // Add the node to the table even if the bonding ping/pong - // fails. It will be relaced quickly if it continues to be - // unresponsive. tab.add(node) tab.db.updateFindFails(id, 0) } @@ -522,7 +653,6 @@ func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAdd } // Bonding succeeded, update the node database. w.n = NewNode(id, addr.IP, uint16(addr.Port), tcpPort) - tab.db.updateNode(w.n) close(w.done) } @@ -534,16 +664,18 @@ func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error { return err } tab.db.updateLastPong(id, time.Now()) - - // Start the background expiration goroutine after the first - // successful communication. Subsequent calls have no effect if it - // is already running. We do this here instead of somewhere else - // so that the search for seed nodes also considers older nodes - // that would otherwise be removed by the expiration. - tab.db.ensureExpirer() return nil } + +// bucket returns the bucket for the given node ID hash. +func (tab *Table) bucket(sha common.Hash) *bucket { + d := logdist(tab.self.sha, sha) + if d <= bucketMinDistance { + return tab.buckets[0] + } + return tab.buckets[d-bucketMinDistance-1] +} + // add attempts to add the given node its corresponding bucket. If the // bucket has space available, adding the node succeeds immediately. // Otherwise, the node is added if the least recently active node in // // The caller must not hold tab.mutex. func (tab *Table) add(new *Node) { - b := tab.buckets[logdist(tab.self.sha, new.sha)] tab.mutex.Lock() defer tab.mutex.Unlock() - if b.bump(new) { - return - } - var oldest *Node - if len(b.entries) == bucketSize { - oldest = b.entries[bucketSize-1] - if oldest.contested { - // The node is already being replaced, don't attempt - // to replace it.
- return - } - oldest.contested = true - // Let go of the mutex so other goroutines can access - // the table while we ping the least recently active node. - tab.mutex.Unlock() - err := tab.ping(oldest.ID, oldest.addr()) - tab.mutex.Lock() - oldest.contested = false - if err == nil { - // The node responded, don't replace it. - return - } - } - added := b.replace(new, oldest) - if added && tab.nodeAddedHook != nil { - tab.nodeAddedHook(new) + + b := tab.bucket(new.sha) + if !tab.bumpOrAdd(b, new) { + // Node is not in table. Add it to the replacement list. + tab.addReplacement(b, new) } } // stuff adds nodes the table to the end of their corresponding bucket -// if the bucket is not full. The caller must hold tab.mutex. +// if the bucket is not full. The caller must not hold tab.mutex. func (tab *Table) stuff(nodes []*Node) { -outer: + tab.mutex.Lock() + defer tab.mutex.Unlock() + for _, n := range nodes { if n.ID == tab.self.ID { continue // don't add self } - bucket := tab.buckets[logdist(tab.self.sha, n.sha)] - for i := range bucket.entries { - if bucket.entries[i].ID == n.ID { - continue outer // already in bucket - } - } - if len(bucket.entries) < bucketSize { - bucket.entries = append(bucket.entries, n) - if tab.nodeAddedHook != nil { - tab.nodeAddedHook(n) - } + b := tab.bucket(n.sha) + if len(b.entries) < bucketSize { + tab.bumpOrAdd(b, n) } } } @@ -611,36 +715,72 @@ outer: func (tab *Table) delete(node *Node) { tab.mutex.Lock() defer tab.mutex.Unlock() - bucket := tab.buckets[logdist(tab.self.sha, node.sha)] - for i := range bucket.entries { - if bucket.entries[i].ID == node.ID { - bucket.entries = append(bucket.entries[:i], bucket.entries[i+1:]...) - return - } - } + + tab.deleteInBucket(tab.bucket(node.sha), node) } -func (b *bucket) replace(n *Node, last *Node) bool { - // Don't add if b already contains n. - for i := range b.entries { - if b.entries[i].ID == n.ID { - return false - } +func (tab *Table) addIP(b *bucket, ip net.IP) bool { + if netutil.IsLAN(ip) { + return true } - // Replace last if it is still the last entry or just add n if b - // isn't full. If is no longer the last entry, it has either been - // replaced with someone else or became active. - if len(b.entries) == bucketSize && (last == nil || b.entries[bucketSize-1].ID != last.ID) { + if !tab.ips.Add(ip) { + log.Debug("IP exceeds table limit", "ip", ip) return false } - if len(b.entries) < bucketSize { - b.entries = append(b.entries, nil) + if !b.ips.Add(ip) { + log.Debug("IP exceeds bucket limit", "ip", ip) + tab.ips.Remove(ip) + return false } - copy(b.entries[1:], b.entries) - b.entries[0] = n return true } +func (tab *Table) removeIP(b *bucket, ip net.IP) { + if netutil.IsLAN(ip) { + return + } + tab.ips.Remove(ip) + b.ips.Remove(ip) +} + +func (tab *Table) addReplacement(b *bucket, n *Node) { + for _, e := range b.replacements { + if e.ID == n.ID { + return // already in list + } + } + if !tab.addIP(b, n.IP) { + return + } + var removed *Node + b.replacements, removed = pushNode(b.replacements, n, maxReplacements) + if removed != nil { + tab.removeIP(b, removed.IP) + } +} + +// replace removes n from the replacement list and replaces 'last' with it if it is the +// last entry in the bucket. If 'last' isn't the last entry, it has either been replaced +// with someone else or became active. +func (tab *Table) replace(b *bucket, last *Node) *Node { + if len(b.entries) >= 0 && b.entries[len(b.entries)-1].ID != last.ID { + // Entry has moved, don't replace it. + return nil + } + // Still the last entry. 
+ if len(b.replacements) == 0 { + tab.deleteInBucket(b, last) + return nil + } + r := b.replacements[tab.rand.Intn(len(b.replacements))] + b.replacements = deleteNode(b.replacements, r) + b.entries[len(b.entries)-1] = r + tab.removeIP(b, last.IP) + return r +} + +// bump moves the given node to the front of the bucket entry list +// if it is contained in that list. func (b *bucket) bump(n *Node) bool { for i := range b.entries { if b.entries[i].ID == n.ID { @@ -653,6 +793,50 @@ func (b *bucket) bump(n *Node) bool { return false } +// bumpOrAdd moves n to the front of the bucket entry list or adds it if the list isn't +// full. The return value is true if n is in the bucket. +func (tab *Table) bumpOrAdd(b *bucket, n *Node) bool { + if b.bump(n) { + return true + } + if len(b.entries) >= bucketSize || !tab.addIP(b, n.IP) { + return false + } + b.entries, _ = pushNode(b.entries, n, bucketSize) + b.replacements = deleteNode(b.replacements, n) + n.addedAt = time.Now() + if tab.nodeAddedHook != nil { + tab.nodeAddedHook(n) + } + return true +} + +func (tab *Table) deleteInBucket(b *bucket, n *Node) { + b.entries = deleteNode(b.entries, n) + tab.removeIP(b, n.IP) +} + +// pushNode adds n to the front of list, keeping at most max items. +func pushNode(list []*Node, n *Node, max int) ([]*Node, *Node) { + if len(list) < max { + list = append(list, nil) + } + removed := list[len(list)-1] + copy(list[1:], list) + list[0] = n + return list, removed +} + +// deleteNode removes n from list. +func deleteNode(list []*Node, n *Node) []*Node { + for i := range list { + if list[i].ID == n.ID { + return append(list[:i], list[i+1:]...) + } + } + return list +} + // nodesByDistance is a list of nodes, ordered by // distance to target. type nodesByDistance struct { diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go index 1037cc609..3ce48d299 100644 --- a/p2p/discover/table_test.go +++ b/p2p/discover/table_test.go @@ -20,6 +20,7 @@ import ( "crypto/ecdsa" "fmt" "math/rand" + "sync" "net" "reflect" @@ -32,60 +33,65 @@ import ( ) func TestTable_pingReplace(t *testing.T) { - doit := func(newNodeIsResponding, lastInBucketIsResponding bool) { - transport := newPingRecorder() - tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "") - defer tab.Close() - pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99) + run := func(newNodeResponding, lastInBucketResponding bool) { + name := fmt.Sprintf("newNodeResponding=%t/lastInBucketResponding=%t", newNodeResponding, lastInBucketResponding) + t.Run(name, func(t *testing.T) { + t.Parallel() + testPingReplace(t, newNodeResponding, lastInBucketResponding) + }) + } - // fill up the sender's bucket. - last := fillBucket(tab, 253) + run(true, true) + run(false, true) + run(true, false) + run(false, false) +} - // this call to bond should replace the last node - // in its bucket if the node is not responding. 
- transport.responding[last.ID] = lastInBucketIsResponding - transport.responding[pingSender.ID] = newNodeIsResponding - tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0) +func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding bool) { + transport := newPingRecorder() + tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil) + defer tab.Close() - // first ping goes to sender (bonding pingback) - if !transport.pinged[pingSender.ID] { - t.Error("table did not ping back sender") - } - if newNodeIsResponding { - // second ping goes to oldest node in bucket - // to see whether it is still alive. - if !transport.pinged[last.ID] { - t.Error("table did not ping last node in bucket") - } - } + // Wait for init so bond is accepted. + <-tab.initDone - tab.mutex.Lock() - defer tab.mutex.Unlock() - if l := len(tab.buckets[253].entries); l != bucketSize { - t.Errorf("wrong bucket size after bond: got %d, want %d", l, bucketSize) - } + // fill up the sender's bucket. + pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99) + last := fillBucket(tab, pingSender) - if lastInBucketIsResponding || !newNodeIsResponding { - if !contains(tab.buckets[253].entries, last.ID) { - t.Error("last entry was removed") - } - if contains(tab.buckets[253].entries, pingSender.ID) { - t.Error("new entry was added") - } - } else { - if contains(tab.buckets[253].entries, last.ID) { - t.Error("last entry was not removed") - } - if !contains(tab.buckets[253].entries, pingSender.ID) { - t.Error("new entry was not added") - } - } + // this call to bond should replace the last node + // in its bucket if the node is not responding. + transport.dead[last.ID] = !lastInBucketIsResponding + transport.dead[pingSender.ID] = !newNodeIsResponding + tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0) + tab.doRevalidate(make(chan struct{}, 1)) + + // first ping goes to sender (bonding pingback) + if !transport.pinged[pingSender.ID] { + t.Error("table did not ping back sender") + } + if !transport.pinged[last.ID] { + // second ping goes to oldest node in bucket + // to see whether it is still alive. + t.Error("table did not ping last node in bucket") } - doit(true, true) - doit(false, true) - doit(true, false) - doit(false, false) + tab.mutex.Lock() + defer tab.mutex.Unlock() + wantSize := bucketSize + if !lastInBucketIsResponding && !newNodeIsResponding { + wantSize-- + } + if l := len(tab.bucket(pingSender.sha).entries); l != wantSize { + t.Errorf("wrong bucket size after bond: got %d, want %d", l, wantSize) + } + if found := contains(tab.bucket(pingSender.sha).entries, last.ID); found != lastInBucketIsResponding { + t.Errorf("last entry found: %t, want: %t", found, lastInBucketIsResponding) + } + wantNewEntry := newNodeIsResponding && !lastInBucketIsResponding + if found := contains(tab.bucket(pingSender.sha).entries, pingSender.ID); found != wantNewEntry { + t.Errorf("new entry found: %t, want: %t", found, wantNewEntry) + } } func TestBucket_bumpNoDuplicates(t *testing.T) { @@ -130,11 +136,45 @@ func TestBucket_bumpNoDuplicates(t *testing.T) { } } +// This checks that the table-wide IP limit is applied correctly. 
+func TestTable_IPLimit(t *testing.T) { + transport := newPingRecorder() + tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil) + defer tab.Close() + + for i := 0; i < tableIPLimit+1; i++ { + n := nodeAtDistance(tab.self.sha, i) + n.IP = net.IP{172, 0, 1, byte(i)} + tab.add(n) + } + if tab.len() > tableIPLimit { + t.Errorf("too many nodes in table") + } +} + +// This checks that the per-bucket IP limit is applied correctly. +func TestTable_BucketIPLimit(t *testing.T) { + transport := newPingRecorder() + tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil) + defer tab.Close() + + d := 3 + for i := 0; i < bucketIPLimit+1; i++ { + n := nodeAtDistance(tab.self.sha, d) + n.IP = net.IP{172, 0, 1, byte(i)} + tab.add(n) + } + if tab.len() > bucketIPLimit { + t.Errorf("too many nodes in table") + } +} + // fillBucket inserts nodes into the given bucket until // it is full. The node's IDs dont correspond to their // hashes. -func fillBucket(tab *Table, ld int) (last *Node) { - b := tab.buckets[ld] +func fillBucket(tab *Table, n *Node) (last *Node) { + ld := logdist(tab.self.sha, n.sha) + b := tab.bucket(n.sha) for len(b.entries) < bucketSize { b.entries = append(b.entries, nodeAtDistance(tab.self.sha, ld)) } @@ -146,30 +186,39 @@ func nodeAtDistance(base common.Hash, ld int) (n *Node) { n = new(Node) n.sha = hashAtDistance(base, ld) - n.IP = net.IP{10, 0, 2, byte(ld)} + n.IP = net.IP{byte(ld), 0, 2, byte(ld)} copy(n.ID[:], n.sha[:]) // ensure the node still has a unique ID return n } -type pingRecorder struct{ responding, pinged map[NodeID]bool } +type pingRecorder struct { + mu sync.Mutex + dead, pinged map[NodeID]bool +} func newPingRecorder() *pingRecorder { - return &pingRecorder{make(map[NodeID]bool), make(map[NodeID]bool)} + return &pingRecorder{ + dead: make(map[NodeID]bool), + pinged: make(map[NodeID]bool), + } } func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) { - panic("findnode called on pingRecorder") + return nil, nil } func (t *pingRecorder) close() {} func (t *pingRecorder) waitping(from NodeID) error { return nil // remote always pings } func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error { + t.mu.Lock() + defer t.mu.Unlock() + t.pinged[toid] = true - if t.responding[toid] { - return nil - } else { + if t.dead[toid] { return errTimeout + } else { + return nil } } @@ -178,7 +227,8 @@ func TestTable_closest(t *testing.T) { test := func(test *closeTest) bool { // for any node table, Target and N - tab, _ := newTable(nil, test.Self, &net.UDPAddr{}, "") + transport := newPingRecorder() + tab, _ := newTable(transport, test.Self, &net.UDPAddr{}, "", nil) defer tab.Close() tab.stuff(test.All) @@ -237,8 +287,11 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) { }, } test := func(buf []*Node) bool { - tab, _ := newTable(nil, NodeID{}, &net.UDPAddr{}, "") + transport := newPingRecorder() + tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil) defer tab.Close() + <-tab.initDone + for i := 0; i < len(buf); i++ { ld := cfg.Rand.Intn(len(tab.buckets)) tab.stuff([]*Node{nodeAtDistance(tab.self.sha, ld)}) @@ -280,7 +333,7 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value { func TestTable_Lookup(t *testing.T) { self := nodeAtDistance(common.Hash{}, 0) - tab, _ := newTable(lookupTestnet, self.ID, &net.UDPAddr{}, "") + tab, _ := newTable(lookupTestnet, self.ID, &net.UDPAddr{}, "", nil) defer tab.Close() // lookup on empty table
returns no nodes diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go index 60436952d..e40de2c36 100644 --- a/p2p/discover/udp.go +++ b/p2p/discover/udp.go @@ -216,9 +216,22 @@ type ReadPacket struct { Addr *net.UDPAddr } +// Config holds Table-related settings. +type Config struct { + // These settings are required and configure the UDP listener: + PrivateKey *ecdsa.PrivateKey + + // These settings are optional: + AnnounceAddr *net.UDPAddr // local address announced in the DHT + NodeDBPath string // if set, the node database is stored at this filesystem location + NetRestrict *netutil.Netlist // network whitelist + Bootnodes []*Node // list of bootstrap nodes + Unhandled chan<- ReadPacket // unhandled packets are sent on this channel +} + // ListenUDP returns a new table that listens for UDP packets on laddr. -func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, unhandled chan ReadPacket, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) { - tab, _, err := newUDP(priv, conn, realaddr, unhandled, nodeDBPath, netrestrict) +func ListenUDP(c conn, cfg Config) (*Table, error) { + tab, _, err := newUDP(c, cfg) if err != nil { return nil, err } @@ -226,25 +239,29 @@ func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, unhandl return tab, nil } -func newUDP(priv *ecdsa.PrivateKey, c conn, realaddr *net.UDPAddr, unhandled chan ReadPacket, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) { +func newUDP(c conn, cfg Config) (*Table, *udp, error) { udp := &udp{ conn: c, - priv: priv, - netrestrict: netrestrict, + priv: cfg.PrivateKey, + netrestrict: cfg.NetRestrict, closing: make(chan struct{}), gotreply: make(chan reply), addpending: make(chan *pending), } + realaddr := c.LocalAddr().(*net.UDPAddr) + if cfg.AnnounceAddr != nil { + realaddr = cfg.AnnounceAddr + } // TODO: separate TCP port udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port)) - tab, err := newTable(udp, PubkeyID(&priv.PublicKey), realaddr, nodeDBPath) + tab, err := newTable(udp, PubkeyID(&cfg.PrivateKey.PublicKey), realaddr, cfg.NodeDBPath, cfg.Bootnodes) if err != nil { return nil, nil, err } udp.Table = tab go udp.loop() - go udp.readLoop(unhandled) + go udp.readLoop(cfg.Unhandled) return udp.Table, udp, nil } @@ -256,14 +273,20 @@ func (t *udp) close() { // ping sends a ping message to the given node and waits for a reply. 
func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error { - // TODO: maybe check for ReplyTo field in callback to measure RTT - errc := t.pending(toid, pongPacket, func(interface{}) bool { return true }) - t.send(toaddr, pingPacket, &ping{ + req := &ping{ Version: Version, From: t.ourEndpoint, To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB Expiration: uint64(time.Now().Add(expiration).Unix()), + } + packet, hash, err := encodePacket(t.priv, pingPacket, req) + if err != nil { + return err + } + errc := t.pending(toid, pongPacket, func(p interface{}) bool { + return bytes.Equal(p.(*pong).ReplyTok, hash) }) + t.write(toaddr, req.name(), packet) return <-errc } @@ -447,40 +470,45 @@ func init() { } } -func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) error { - packet, err := encodePacket(t.priv, ptype, req) +func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) { + packet, hash, err := encodePacket(t.priv, ptype, req) if err != nil { - return err + return hash, err } - _, err = t.conn.WriteToUDP(packet, toaddr) - log.Trace(">> "+req.name(), "addr", toaddr, "err", err) + return hash, t.write(toaddr, req.name(), packet) +} + +func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error { + _, err := t.conn.WriteToUDP(packet, toaddr) + log.Trace(">> "+what, "addr", toaddr, "err", err) return err } -func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) { +func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (packet, hash []byte, err error) { b := new(bytes.Buffer) b.Write(headSpace) b.WriteByte(ptype) if err := rlp.Encode(b, req); err != nil { log.Error("Can't encode discv4 packet", "err", err) - return nil, err + return nil, nil, err } - packet := b.Bytes() + packet = b.Bytes() sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv) if err != nil { log.Error("Can't sign discv4 packet", "err", err) - return nil, err + return nil, nil, err } copy(packet[macSize:], sig) // add the hash to the front. Note: this doesn't protect the // packet in any way. Our public key will be part of this hash in // The future. - copy(packet, crypto.Keccak256(packet[macSize:])) - return packet, nil + hash = crypto.Keccak256(packet[macSize:]) + copy(packet, hash) + return packet, hash, nil } // readLoop runs in its own goroutine. it handles incoming UDP packets. -func (t *udp) readLoop(unhandled chan ReadPacket) { +func (t *udp) readLoop(unhandled chan<- ReadPacket) { defer t.conn.Close() if unhandled != nil { defer close(unhandled) @@ -601,18 +629,22 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte t.mutex.Unlock() p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} + var sent bool // Send neighbors in chunks with at most maxNeighbors per packet // to stay below the 1280 byte limit. 
- for i, n := range closest { - if netutil.CheckRelayIP(from.IP, n.IP) != nil { - continue + for _, n := range closest { + if netutil.CheckRelayIP(from.IP, n.IP) == nil { + p.Nodes = append(p.Nodes, nodeToRPC(n)) } - p.Nodes = append(p.Nodes, nodeToRPC(n)) - if len(p.Nodes) == maxNeighbors || i == len(closest)-1 { + if len(p.Nodes) == maxNeighbors { t.send(from, neighborsPacket, &p) p.Nodes = p.Nodes[:0] + sent = true } } + if len(p.Nodes) > 0 || !sent { + t.send(from, neighborsPacket, &p) + } return nil } diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go index b81caf839..3ffa5c4dd 100644 --- a/p2p/discover/udp_test.go +++ b/p2p/discover/udp_test.go @@ -70,14 +70,15 @@ func newUDPTest(t *testing.T) *udpTest { remotekey: newkey(), remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, } - realaddr := test.pipe.LocalAddr().(*net.UDPAddr) - test.table, test.udp, _ = newUDP(test.localkey, test.pipe, realaddr, nil, "", nil) + test.table, test.udp, _ = newUDP(test.pipe, Config{PrivateKey: test.localkey}) + // Wait for initial refresh so the table doesn't send unexpected findnode. + <-test.table.initDone return test } // handles a packet as if it had been sent to the transport. func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error { - enc, err := encodePacket(test.remotekey, ptype, data) + enc, _, err := encodePacket(test.remotekey, ptype, data) if err != nil { return test.errorf("packet (%d) encode error: %v", ptype, err) } @@ -90,19 +91,19 @@ func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error { // waits for a packet to be sent by the transport. // validate should have type func(*udpTest, X) error, where X is a packet type. -func (test *udpTest) waitPacketOut(validate interface{}) error { +func (test *udpTest) waitPacketOut(validate interface{}) ([]byte, error) { dgram := test.pipe.waitPacketOut() - p, _, _, err := decodePacket(dgram) + p, _, hash, err := decodePacket(dgram) if err != nil { - return test.errorf("sent packet decode error: %v", err) + return hash, test.errorf("sent packet decode error: %v", err) } fn := reflect.ValueOf(validate) exptype := fn.Type().In(0) if reflect.TypeOf(p) != exptype { - return test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype) + return hash, test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype) } fn.Call([]reflect.Value{reflect.ValueOf(p)}) - return nil + return hash, nil } func (test *udpTest) errorf(format string, args ...interface{}) error { @@ -351,7 +352,7 @@ func TestUDP_successfulPing(t *testing.T) { }) // remote is unknown, the table pings back. - test.waitPacketOut(func(p *ping) error { + hash, _ := test.waitPacketOut(func(p *ping) error { if !reflect.DeepEqual(p.From, test.udp.ourEndpoint) { t.Errorf("got ping.From %v, want %v", p.From, test.udp.ourEndpoint) } @@ -365,7 +366,7 @@ func TestUDP_successfulPing(t *testing.T) { } return nil }) - test.packetIn(nil, pongPacket, &pong{Expiration: futureExp}) + test.packetIn(nil, pongPacket, &pong{ReplyTok: hash, Expiration: futureExp}) // the node should be added to the table shortly after getting the // pong packet. 
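The rewritten findnode handler above flushes a neighbors packet each time maxNeighbors entries have accumulated and always sends a final, possibly empty, packet, so the querying peer is guaranteed at least one reply. A standalone sketch of the same chunk-and-flush pattern, with made-up names and independent of the actual udp code:

package main

import "fmt"

// chunk emits items in groups of at most max and always emits at least
// one (possibly empty) group, mirroring the neighbors-sending loop above.
func chunk(items []int, max int, send func([]int)) {
	var buf []int
	sent := false
	for _, it := range items {
		buf = append(buf, it)
		if len(buf) == max {
			send(buf)
			buf = nil
			sent = true
		}
	}
	if len(buf) > 0 || !sent {
		send(buf)
	}
}

func main() {
	// Prints [1 2], then [3 4], then [5].
	chunk([]int{1, 2, 3, 4, 5}, 2, func(c []int) { fmt.Println(c) })
}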
diff --git a/p2p/netutil/net.go b/p2p/netutil/net.go index f6005afd2..656abb682 100644 --- a/p2p/netutil/net.go +++ b/p2p/netutil/net.go @@ -18,8 +18,11 @@ package netutil import ( + "bytes" "errors" + "fmt" "net" + "sort" "strings" ) @@ -189,3 +192,131 @@ func CheckRelayIP(sender, addr net.IP) error { } return nil } + +// SameNet reports whether two IP addresses have an equal prefix of the given bit length. +func SameNet(bits uint, ip, other net.IP) bool { + ip4, other4 := ip.To4(), other.To4() + switch { + case (ip4 == nil) != (other4 == nil): + return false + case ip4 != nil: + return sameNet(bits, ip4, other4) + default: + return sameNet(bits, ip.To16(), other.To16()) + } +} + +func sameNet(bits uint, ip, other net.IP) bool { + nb := int(bits / 8) + mask := ^byte(0xFF >> (bits % 8)) + if mask != 0 && nb < len(ip) && ip[nb]&mask != other[nb]&mask { + return false + } + return nb <= len(ip) && bytes.Equal(ip[:nb], other[:nb]) +} + +// DistinctNetSet tracks IPs, ensuring that at most N of them +// fall into the same network range. +type DistinctNetSet struct { + Subnet uint // number of common prefix bits + Limit uint // maximum number of IPs in each subnet + + members map[string]uint + buf net.IP +} + +// Add adds an IP address to the set. It returns false (and doesn't add the IP) if the +// number of existing IPs in the defined range exceeds the limit. +func (s *DistinctNetSet) Add(ip net.IP) bool { + key := s.key(ip) + n := s.members[string(key)] + if n < s.Limit { + s.members[string(key)] = n + 1 + return true + } + return false +} + +// Remove removes an IP from the set. +func (s *DistinctNetSet) Remove(ip net.IP) { + key := s.key(ip) + if n, ok := s.members[string(key)]; ok { + if n == 1 { + delete(s.members, string(key)) + } else { + s.members[string(key)] = n - 1 + } + } +} + +// Contains reports whether the given IP is contained in the set. +func (s DistinctNetSet) Contains(ip net.IP) bool { + key := s.key(ip) + _, ok := s.members[string(key)] + return ok +} + +// Len returns the number of tracked IPs. +func (s DistinctNetSet) Len() int { + n := uint(0) + for _, i := range s.members { + n += i + } + return int(n) +} + +// key encodes the map key for an address into a temporary buffer. +// +// The first byte of key is '4' or '6' to distinguish IPv4/IPv6 address types. +// The remainder of the key is the IP, truncated to the number of bits. +func (s *DistinctNetSet) key(ip net.IP) net.IP { + // Lazily initialize storage. + if s.members == nil { + s.members = make(map[string]uint) + s.buf = make(net.IP, 17) + } + // Canonicalize ip and bits. + typ := byte('6') + if ip4 := ip.To4(); ip4 != nil { + typ, ip = '4', ip4 + } + bits := s.Subnet + if bits > uint(len(ip)*8) { + bits = uint(len(ip) * 8) + } + // Encode the prefix into s.buf. + nb := int(bits / 8) + mask := ^byte(0xFF >> (bits % 8)) + s.buf[0] = typ + buf := append(s.buf[:1], ip[:nb]...)
+ if nb < len(ip) && mask != 0 { + buf = append(buf, ip[nb]&mask) + } + return buf +} + +// String implements fmt.Stringer +func (s DistinctNetSet) String() string { + var buf bytes.Buffer + buf.WriteString("{") + keys := make([]string, 0, len(s.members)) + for k := range s.members { + keys = append(keys, k) + } + sort.Strings(keys) + for i, k := range keys { + var ip net.IP + if k[0] == '4' { + ip = make(net.IP, 4) + } else { + ip = make(net.IP, 16) + } + copy(ip, k[1:]) + fmt.Fprintf(&buf, "%v×%d", ip, s.members[k]) + if i != len(keys)-1 { + buf.WriteString(" ") + } + } + buf.WriteString("}") + return buf.String() +} diff --git a/p2p/netutil/net_test.go b/p2p/netutil/net_test.go index 1ee1fcb4d..3a6aa081f 100644 --- a/p2p/netutil/net_test.go +++ b/p2p/netutil/net_test.go @@ -17,9 +17,11 @@ package netutil import ( + "fmt" "net" "reflect" "testing" + "testing/quick" "github.com/davecgh/go-spew/spew" ) @@ -171,3 +173,90 @@ func BenchmarkCheckRelayIP(b *testing.B) { CheckRelayIP(sender, addr) } } + +func TestSameNet(t *testing.T) { + tests := []struct { + ip, other string + bits uint + want bool + }{ + {"0.0.0.0", "0.0.0.0", 32, true}, + {"0.0.0.0", "0.0.0.1", 0, true}, + {"0.0.0.0", "0.0.0.1", 31, true}, + {"0.0.0.0", "0.0.0.1", 32, false}, + {"0.33.0.1", "0.34.0.2", 8, true}, + {"0.33.0.1", "0.34.0.2", 13, true}, + {"0.33.0.1", "0.34.0.2", 15, false}, + } + + for _, test := range tests { + if ok := SameNet(test.bits, parseIP(test.ip), parseIP(test.other)); ok != test.want { + t.Errorf("SameNet(%d, %s, %s) == %t, want %t", test.bits, test.ip, test.other, ok, test.want) + } + } +} + +func ExampleSameNet() { + // This returns true because the IPs are in the same /24 network: + fmt.Println(SameNet(24, net.IP{127, 0, 0, 1}, net.IP{127, 0, 0, 3})) + // This call returns false: + fmt.Println(SameNet(24, net.IP{127, 3, 0, 1}, net.IP{127, 5, 0, 3})) + // Output: + // true + // false +} + +func TestDistinctNetSet(t *testing.T) { + ops := []struct { + add, remove string + fails bool + }{ + {add: "127.0.0.1"}, + {add: "127.0.0.2"}, + {add: "127.0.0.3", fails: true}, + {add: "127.32.0.1"}, + {add: "127.32.0.2"}, + {add: "127.32.0.3", fails: true}, + {add: "127.33.0.1", fails: true}, + {add: "127.34.0.1"}, + {add: "127.34.0.2"}, + {add: "127.34.0.3", fails: true}, + // Make room for an address, then add again.
+ {remove: "127.0.0.1"}, + {add: "127.0.0.3"}, + {add: "127.0.0.3", fails: true}, + } + + set := DistinctNetSet{Subnet: 15, Limit: 2} + for _, op := range ops { + var desc string + if op.add != "" { + desc = fmt.Sprintf("Add(%s)", op.add) + if ok := set.Add(parseIP(op.add)); ok != !op.fails { + t.Errorf("%s == %t, want %t", desc, ok, !op.fails) + } + } else { + desc = fmt.Sprintf("Remove(%s)", op.remove) + set.Remove(parseIP(op.remove)) + } + t.Logf("%s: %v", desc, set) + } +} + +func TestDistinctNetSetAddRemove(t *testing.T) { + cfg := &quick.Config{} + fn := func(ips []net.IP) bool { + s := DistinctNetSet{Limit: 3, Subnet: 2} + for _, ip := range ips { + s.Add(ip) + } + for _, ip := range ips { + s.Remove(ip) + } + return s.Len() == 0 + } + + if err := quick.Check(fn, cfg); err != nil { + t.Fatal(err) + } +} diff --git a/p2p/peer.go b/p2p/peer.go index bad1c8c8b..477d8c219 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -419,6 +419,9 @@ type PeerInfo struct { Network struct { LocalAddress string `json:"localAddress"` // Local endpoint of the TCP data connection RemoteAddress string `json:"remoteAddress"` // Remote endpoint of the TCP data connection + Inbound bool `json:"inbound"` + Trusted bool `json:"trusted"` + Static bool `json:"static"` } `json:"network"` Protocols map[string]interface{} `json:"protocols"` // Sub-protocol specific metadata fields } @@ -439,6 +442,9 @@ func (p *Peer) Info() *PeerInfo { } info.Network.LocalAddress = p.LocalAddr().String() info.Network.RemoteAddress = p.RemoteAddr().String() + info.Network.Inbound = p.rw.is(inboundConn) + info.Network.Trusted = p.rw.is(trustedConn) + info.Network.Static = p.rw.is(staticDialedConn) // Gather all the running protocol infos for _, proto := range p.running { diff --git a/p2p/server.go b/p2p/server.go index 2cff94ea5..edc1d9d21 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -40,11 +40,10 @@ const ( refreshPeersInterval = 30 * time.Second staticPeerCheckInterval = 15 * time.Second - // Maximum number of concurrently handshaking inbound connections. - maxAcceptConns = 50 - - // Maximum number of concurrently dialing outbound connections. - maxActiveDialTasks = 16 + // Connectivity defaults. + maxActiveDialTasks = 16 + defaultMaxPendingPeers = 50 + defaultDialRatio = 3 // Maximum time allowed for reading a complete message. // This is effectively the amount of time a connection can be idle. @@ -70,6 +69,11 @@ type Config struct { // Zero defaults to preset values. MaxPendingPeers int `toml:",omitempty"` + // DialRatio controls the ratio of inbound to dialed connections. + // Example: a DialRatio of 2 allows 1/2 of connections to be dialed. + // Setting DialRatio to zero defaults it to 3. + DialRatio int `toml:",omitempty"` + // NoDiscovery can be used to disable the peer discovery mechanism. // Disabling is useful for protocol debugging (manual topology). 
NoDiscovery bool @@ -427,7 +431,6 @@ func (srv *Server) Start() (err error) { if err != nil { return err } - realaddr = conn.LocalAddr().(*net.UDPAddr) if srv.NAT != nil { if !realaddr.IP.IsLoopback() { @@ -447,11 +450,16 @@ func (srv *Server) Start() (err error) { // node table if !srv.NoDiscovery { - ntab, err := discover.ListenUDP(srv.PrivateKey, conn, realaddr, unhandled, srv.NodeDatabase, srv.NetRestrict) - if err != nil { - return err + cfg := discover.Config{ + PrivateKey: srv.PrivateKey, + AnnounceAddr: realaddr, + NodeDBPath: srv.NodeDatabase, + NetRestrict: srv.NetRestrict, + Bootnodes: srv.BootstrapNodes, + Unhandled: unhandled, } - if err := ntab.SetFallbackNodes(srv.BootstrapNodes); err != nil { + ntab, err := discover.ListenUDP(conn, cfg) + if err != nil { return err } srv.ntab = ntab @@ -476,10 +484,7 @@ func (srv *Server) Start() (err error) { srv.DiscV5 = ntab } - dynPeers := (srv.MaxPeers + 1) / 2 - if srv.NoDiscovery { - dynPeers = 0 - } + dynPeers := srv.maxDialedConns() dialer := newDialState(srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict) // handshake @@ -536,6 +541,7 @@ func (srv *Server) run(dialstate dialer) { defer srv.loopWG.Done() var ( peers = make(map[discover.NodeID]*Peer) + inboundCount = 0 trusted = make(map[discover.NodeID]bool, len(srv.TrustedNodes)) taskdone = make(chan task, maxActiveDialTasks) runningTasks []task @@ -621,14 +627,14 @@ running: } // TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them. select { - case c.cont <- srv.encHandshakeChecks(peers, c): + case c.cont <- srv.encHandshakeChecks(peers, inboundCount, c): case <-srv.quit: break running } case c := <-srv.addpeer: // At this point the connection is past the protocol handshake. // Its capabilities are known and the remote identity is verified. - err := srv.protoHandshakeChecks(peers, c) + err := srv.protoHandshakeChecks(peers, inboundCount, c) if err == nil { // The handshakes are done and it passed all checks. p := newPeer(c, srv.Protocols) @@ -639,8 +645,11 @@ running: } name := truncateName(c.name) srv.log.Debug("Adding p2p peer", "name", name, "addr", c.fd.RemoteAddr(), "peers", len(peers)+1) - peers[c.id] = p go srv.runPeer(p) + peers[c.id] = p + if p.Inbound() { + inboundCount++ + } } // The dialer logic relies on the assumption that // dial tasks complete after the peer has been added or @@ -655,6 +664,9 @@ running: d := common.PrettyDuration(mclock.Now() - pd.created) pd.log.Debug("Removing p2p peer", "duration", d, "peers", len(peers)-1, "req", pd.requested, "err", pd.err) delete(peers, pd.ID()) + if pd.Inbound() { + inboundCount-- + } } } @@ -681,20 +693,22 @@ running: } } -func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error { +func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error { // Drop connections with no matching protocols. if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 { return DiscUselessPeer } // Repeat the encryption handshake checks because the // peer set might have changed between the handshakes. 
- return srv.encHandshakeChecks(peers, c) + return srv.encHandshakeChecks(peers, inboundCount, c) } -func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error { +func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error { switch { case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers: return DiscTooManyPeers + case !c.is(trustedConn) && c.is(inboundConn) && inboundCount >= srv.maxInboundConns(): + return DiscTooManyPeers case peers[c.id] != nil: return DiscAlreadyConnected case c.id == srv.Self().ID: @@ -704,6 +718,21 @@ func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) } } +func (srv *Server) maxInboundConns() int { + return srv.MaxPeers - srv.maxDialedConns() +} + +func (srv *Server) maxDialedConns() int { + if srv.NoDiscovery || srv.NoDial { + return 0 + } + r := srv.DialRatio + if r == 0 { + r = defaultDialRatio + } + return srv.MaxPeers / r +} + type tempError interface { Temporary() bool } @@ -714,10 +743,7 @@ func (srv *Server) listenLoop() { defer srv.loopWG.Done() srv.log.Info("RLPx listener up", "self", srv.makeSelf(srv.listener, srv.ntab)) - // This channel acts as a semaphore limiting - // active inbound connections that are lingering pre-handshake. - // If all slots are taken, no further connections are accepted. - tokens := maxAcceptConns + tokens := defaultMaxPendingPeers if srv.MaxPendingPeers > 0 { tokens = srv.MaxPendingPeers } @@ -758,9 +784,6 @@ func (srv *Server) listenLoop() { fd = newMeteredConn(fd, true) srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr()) - - // Spawn the handler. It will give the slot back when the connection - // has been established. go func() { srv.SetupConn(fd, inboundConn, nil) slots <- struct{}{} -- cgit v1.2.3 From 589b603a9b1e17930d1e83ca64ce7cdc4c3d5c85 Mon Sep 17 00:00:00 2001 From: Martin Holst Swende Date: Mon, 12 Feb 2018 13:52:07 +0100 Subject: rpc: dns rebind protection (#15962) * cmd,node,rpc: add allowedHosts to prevent dns rebinding attacks * p2p,node: Fix bug with dumpconfig introduced in r54aeb8e4c0bb9f0e7a6c67258af67df3b266af3d * rpc: add wildcard support for rpcallowedhosts + go fmt * cmd/geth, cmd/utils, node, rpc: ignore direct ip(v4/6) addresses in rpc virtual hostnames check * http, rpc, utils: make vhosts into map, address review concerns * node: change log messages to use geth standard (not sprintf) * rpc: fix spelling --- p2p/server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'p2p') diff --git a/p2p/server.go b/p2p/server.go index edc1d9d21..90e92dc05 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -142,7 +142,7 @@ type Config struct { EnableMsgEvents bool // Logger is a custom logger to use with the p2p.Server. - Logger log.Logger + Logger log.Logger `toml:",omitempty"` } // Server manages all peer connections. -- cgit v1.2.3 From 20797348ca9d037dc5b7a830bafdfe1ea703eac0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Szil=C3=A1gyi?= Date: Tue, 13 Feb 2018 20:59:43 +0200 Subject: p2p/discover: fix out-of-bounds issue --- p2p/discover/table.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'p2p') diff --git a/p2p/discover/table.go b/p2p/discover/table.go index 84c54dac1..17c9db777 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -763,7 +763,7 @@ func (tab *Table) addReplacement(b *bucket, n *Node) { // last entry in the bucket. If 'last' isn't the last entry, it has either been replaced // with someone else or became active. 
func (tab *Table) replace(b *bucket, last *Node) *Node { - if len(b.entries) >= 0 && b.entries[len(b.entries)-1].ID != last.ID { + if len(b.entries) == 0 || b.entries[len(b.entries)-1].ID != last.ID { // Entry has moved, don't replace it. return nil } -- cgit v1.2.3 From a5c0bbb4f4c321c355637ef57fff807857128c6b Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Wed, 14 Feb 2018 13:49:11 +0100 Subject: all: update license information (#16089) --- p2p/simulations/adapters/state.go | 1 + 1 file changed, 1 insertion(+) (limited to 'p2p') diff --git a/p2p/simulations/adapters/state.go b/p2p/simulations/adapters/state.go index 8b1dfef90..0d4ecfb0f 100644 --- a/p2p/simulations/adapters/state.go +++ b/p2p/simulations/adapters/state.go @@ -13,6 +13,7 @@ // // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . + package adapters type SimStateStore struct { -- cgit v1.2.3 From 32301a4d6b3a9684e954057e7cdb15998764122b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Szil=C3=A1gyi?= Date: Fri, 16 Feb 2018 17:05:08 +0200 Subject: p2p/discover: validate bond against lastpong, not db presence --- p2p/discover/udp.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'p2p') diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go index e40de2c36..5cc0b3d74 100644 --- a/p2p/discover/udp.go +++ b/p2p/discover/udp.go @@ -613,7 +613,7 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte if expired(req.Expiration) { return errExpired } - if t.db.node(fromID) == nil { + if age := time.Since(t.db.lastPong(fromID)); age > nodeDBNodeExpiration { // No bond exists, we don't process the packet. This prevents // an attack vector where the discovery protocol could be used // to amplify traffic in a DDOS attack. A malicious actor -- cgit v1.2.3 From aeedec4078647c22244552803c88521391224ab1 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Fri, 16 Feb 2018 20:16:23 +0100 Subject: p2p/discover: s/lastPong/bondTime/, update TestUDP_findnode I forgot to change the check in udp.go when I changed Table.bond to be based on lastPong instead of node presence in db. Rename lastPong to bondTime and add hasBond so it's clearer what this DB key is used for now. --- p2p/discover/database.go | 17 +++++++++++------ p2p/discover/database_test.go | 18 +++++++++--------- p2p/discover/table.go | 6 +++--- p2p/discover/udp.go | 2 +- p2p/discover/udp_test.go | 8 ++------ 5 files changed, 26 insertions(+), 25 deletions(-) (limited to 'p2p') diff --git a/p2p/discover/database.go b/p2p/discover/database.go index b136609f2..6f98de9b4 100644 --- a/p2p/discover/database.go +++ b/p2p/discover/database.go @@ -257,7 +257,7 @@ func (db *nodeDB) expireNodes() error { } // Skip the node if not expired yet (and not self) if !bytes.Equal(id[:], db.self[:]) { - if seen := db.lastPong(id); seen.After(threshold) { + if seen := db.bondTime(id); seen.After(threshold) { continue } } @@ -278,13 +278,18 @@ func (db *nodeDB) updateLastPing(id NodeID, instance time.Time) error { return db.storeInt64(makeKey(id, nodeDBDiscoverPing), instance.Unix()) } -// lastPong retrieves the time of the last successful contact from remote node. -func (db *nodeDB) lastPong(id NodeID) time.Time { +// bondTime retrieves the time of the last successful pong from remote node. 
+func (db *nodeDB) bondTime(id NodeID) time.Time { return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPong)), 0) } -// updateLastPong updates the last time a remote node successfully contacted. -func (db *nodeDB) updateLastPong(id NodeID, instance time.Time) error { +// hasBond reports whether the given node is considered bonded. +func (db *nodeDB) hasBond(id NodeID) bool { + return time.Since(db.bondTime(id)) < nodeDBNodeExpiration +} + +// updateBondTime updates the last pong time of a node. +func (db *nodeDB) updateBondTime(id NodeID, instance time.Time) error { return db.storeInt64(makeKey(id, nodeDBDiscoverPong), instance.Unix()) } @@ -327,7 +332,7 @@ seek: if n.ID == db.self { continue seek } - if now.Sub(db.lastPong(n.ID)) > maxAge { + if now.Sub(db.bondTime(n.ID)) > maxAge { continue seek } for i := range nodes { diff --git a/p2p/discover/database_test.go b/p2p/discover/database_test.go index be972fd2c..c4fa44d09 100644 --- a/p2p/discover/database_test.go +++ b/p2p/discover/database_test.go @@ -125,13 +125,13 @@ func TestNodeDBFetchStore(t *testing.T) { t.Errorf("ping: value mismatch: have %v, want %v", stored, inst) } // Check fetch/store operations on a node pong object - if stored := db.lastPong(node.ID); stored.Unix() != 0 { + if stored := db.bondTime(node.ID); stored.Unix() != 0 { t.Errorf("pong: non-existing object: %v", stored) } - if err := db.updateLastPong(node.ID, inst); err != nil { + if err := db.updateBondTime(node.ID, inst); err != nil { t.Errorf("pong: failed to update: %v", err) } - if stored := db.lastPong(node.ID); stored.Unix() != inst.Unix() { + if stored := db.bondTime(node.ID); stored.Unix() != inst.Unix() { t.Errorf("pong: value mismatch: have %v, want %v", stored, inst) } // Check fetch/store operations on a node findnode-failure object @@ -224,8 +224,8 @@ func TestNodeDBSeedQuery(t *testing.T) { if err := db.updateNode(seed.node); err != nil { t.Fatalf("node %d: failed to insert: %v", i, err) } - if err := db.updateLastPong(seed.node.ID, seed.pong); err != nil { - t.Fatalf("node %d: failed to insert lastPong: %v", i, err) + if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { + t.Fatalf("node %d: failed to insert bondTime: %v", i, err) } } @@ -332,8 +332,8 @@ func TestNodeDBExpiration(t *testing.T) { if err := db.updateNode(seed.node); err != nil { t.Fatalf("node %d: failed to insert: %v", i, err) } - if err := db.updateLastPong(seed.node.ID, seed.pong); err != nil { - t.Fatalf("node %d: failed to update pong: %v", i, err) + if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { + t.Fatalf("node %d: failed to update bondTime: %v", i, err) } } // Expire some of them, and check the rest @@ -365,8 +365,8 @@ func TestNodeDBSelfExpiration(t *testing.T) { if err := db.updateNode(seed.node); err != nil { t.Fatalf("node %d: failed to insert: %v", i, err) } - if err := db.updateLastPong(seed.node.ID, seed.pong); err != nil { - t.Fatalf("node %d: failed to update pong: %v", i, err) + if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { + t.Fatalf("node %d: failed to update bondTime: %v", i, err) } } // Expire the nodes and make sure self has been evacuated too diff --git a/p2p/discover/table.go b/p2p/discover/table.go index 17c9db777..6509326e6 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -455,7 +455,7 @@ func (tab *Table) loadSeedNodes(bond bool) { } for i := range seeds { seed := seeds[i] - age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.lastPong(seed.ID)) }} + age := log.Lazy{Fn: 
func() interface{} { return time.Since(tab.db.bondTime(seed.ID)) }} log.Debug("Found seed node in database", "id", seed.ID, "addr", seed.addr(), "age", age) tab.add(seed) } @@ -596,7 +596,7 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16 } // Start bonding if we haven't seen this node for a while or if it failed findnode too often. node, fails := tab.db.node(id), tab.db.findFails(id) - age := time.Since(tab.db.lastPong(id)) + age := time.Since(tab.db.bondTime(id)) var result error if fails > 0 || age > nodeDBNodeExpiration { log.Trace("Starting bonding ping/pong", "id", id, "known", node != nil, "failcount", fails, "age", age) @@ -663,7 +663,7 @@ func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error { if err := tab.net.ping(id, addr); err != nil { return err } - tab.db.updateLastPong(id, time.Now()) + tab.db.updateBondTime(id, time.Now()) return nil } diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go index 5cc0b3d74..524c6e498 100644 --- a/p2p/discover/udp.go +++ b/p2p/discover/udp.go @@ -613,7 +613,7 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte if expired(req.Expiration) { return errExpired } - if age := time.Since(t.db.lastPong(fromID)); age > nodeDBNodeExpiration { + if !t.db.hasBond(fromID) { // No bond exists, we don't process the packet. This prevents // an attack vector where the discovery protocol could be used // to amplify traffic in a DDOS attack. A malicious actor diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go index 3ffa5c4dd..db9804f7b 100644 --- a/p2p/discover/udp_test.go +++ b/p2p/discover/udp_test.go @@ -247,12 +247,8 @@ func TestUDP_findnode(t *testing.T) { // ensure there's a bond with the test node, // findnode won't be accepted otherwise. - test.table.db.updateNode(NewNode( - PubkeyID(&test.remotekey.PublicKey), - test.remoteaddr.IP, - uint16(test.remoteaddr.Port), - 99, - )) + test.table.db.updateBondTime(PubkeyID(&test.remotekey.PublicKey), time.Now()) + // check that closest neighbors are returned. test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp}) expected := test.table.closest(targetHash, bucketSize) -- cgit v1.2.3 From e07603bbc4d474667a0e724e9baefd8b7c330a61 Mon Sep 17 00:00:00 2001 From: Janos Guljas Date: Mon, 12 Feb 2018 18:40:25 +0100 Subject: p2p/testing: check for all expectations in TestExchanges Handle all expectations in ProtocolSession.TestExchanges in whatever order they are received.
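With this change a single Exchange can carry several expectations, including several for the same peer, and they are matched in whatever order the messages arrive. A minimal sketch of what a caller can now write; session, query, replyA and replyB are hypothetical stand-ins, not names from this patch:

err := session.TestExchanges(p2ptest.Exchange{
	Label: "unordered replies",
	Triggers: []p2ptest.Trigger{
		// One incoming message for the pivot node.
		{Code: 0, Msg: &query{}, Peer: session.IDs[0]},
	},
	Expects: []p2ptest.Expect{
		// Both replies must arrive, but in either order.
		{Code: 1, Msg: &replyA{}, Peer: session.IDs[0]},
		{Code: 2, Msg: &replyB{}, Peer: session.IDs[0]},
	},
	// Per-exchange timeout introduced by this change; defaults to 2s.
	Timeout: 5 * time.Second,
})
if err != nil {
	t.Fatal(err)
}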
--- p2p/testing/protocolsession.go | 178 +++++++++++++++++++++++++++++------------ p2p/testing/protocoltester.go | 79 ++++++++++++++++-- 2 files changed, 200 insertions(+), 57 deletions(-) (limited to 'p2p') diff --git a/p2p/testing/protocolsession.go b/p2p/testing/protocolsession.go index a779aeebb..361285f06 100644 --- a/p2p/testing/protocolsession.go +++ b/p2p/testing/protocolsession.go @@ -19,13 +19,17 @@ package testing import ( "errors" "fmt" + "sync" "time" + "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/simulations/adapters" ) +var errTimedOut = errors.New("timed out") + // ProtocolSession is a quasi simulation of a pivot node running // a service and a number of dummy peers that can send (trigger) or // receive (expect) messages @@ -46,6 +50,7 @@ type Exchange struct { Label string Triggers []Trigger Expects []Expect + Timeout time.Duration } // Trigger is part of the exchange, incoming message for the pivot node @@ -102,76 +107,145 @@ func (self *ProtocolSession) trigger(trig Trigger) error { } // expect checks an expectation of a message sent out by the pivot node -func (self *ProtocolSession) expect(exp Expect) error { - if exp.Msg == nil { - return errors.New("no message to expect") - } - simNode, ok := self.adapter.GetNode(exp.Peer) - if !ok { - return fmt.Errorf("trigger: peer %v does not exist (1- %v)", exp.Peer, len(self.IDs)) +func (self *ProtocolSession) expect(exps []Expect) error { + // construct a map of expectations for each node + peerExpects := make(map[discover.NodeID][]Expect) + for _, exp := range exps { + if exp.Msg == nil { + return errors.New("no message to expect") + } + peerExpects[exp.Peer] = append(peerExpects[exp.Peer], exp) } - mockNode, ok := simNode.Services()[0].(*mockNode) - if !ok { - return fmt.Errorf("trigger: peer %v is not a mock", exp.Peer) + + // construct a map of mockNodes for each node + mockNodes := make(map[discover.NodeID]*mockNode) + for nodeID := range peerExpects { + simNode, ok := self.adapter.GetNode(nodeID) + if !ok { + return fmt.Errorf("trigger: peer %v does not exist (1- %v)", nodeID, len(self.IDs)) + } + mockNode, ok := simNode.Services()[0].(*mockNode) + if !ok { + return fmt.Errorf("trigger: peer %v is not a mock", nodeID) + } + mockNodes[nodeID] = mockNode } + // done channel cancels all created goroutines when function returns + done := make(chan struct{}) + defer close(done) + // errc catches the first error from the goroutines below errc := make(chan error) + + wg := &sync.WaitGroup{} + wg.Add(len(mockNodes)) + for nodeID, mockNode := range mockNodes { + nodeID := nodeID + mockNode := mockNode + go func() { + defer wg.Done() + + // Sum all Expect timeouts to give the maximum + // time for all expectations to finish. + // mockNode.Expect checks all received messages against + // a list of expected messages and the timeout for each + // of them cannot be checked separately. + var t time.Duration + for _, exp := range peerExpects[nodeID] { + if exp.Timeout == time.Duration(0) { + t += 2000 * time.Millisecond + } else { + t += exp.Timeout + } + } + alarm := time.NewTimer(t) + defer alarm.Stop() + + // expectErrc is used to check if error returned + // from mockNode.Expect is not nil and to send it to + // errc only in that case.
+ // done channel will be closed when the function returns + expectErrc := make(chan error) + go func() { + select { + case expectErrc <- mockNode.Expect(peerExpects[nodeID]...): + case <-done: + case <-alarm.C: + } + }() + + select { + case err := <-expectErrc: + if err != nil { + select { + case errc <- err: + case <-done: + case <-alarm.C: + errc <- errTimedOut + } + } + case <-done: + case <-alarm.C: + errc <- errTimedOut + } + + }() + } + go func() { + wg.Wait() + // close errc when all goroutines finish to return a nil error from errc + close(errc) + }() - t := exp.Timeout - if t == time.Duration(0) { - t = 2000 * time.Millisecond - } - select { - case err := <-errc: - return err - case <-time.After(t): - return fmt.Errorf("timout expecting %v sent to peer %v", exp.Msg, exp.Peer) - } + return <-errc } // TestExchanges tests a series of exchanges against the session func (self *ProtocolSession) TestExchanges(exchanges ...Exchange) error { - // launch all triggers of this exchanges + for i, e := range exchanges { + if err := self.testExchange(e); err != nil { + return fmt.Errorf("exchange #%d %q: %v", i, e.Label, err) + } + log.Trace(fmt.Sprintf("exchange #%d %q: run successfully", i, e.Label)) + } + return nil +} + +// testExchange tests a single Exchange. +// Default timeout value is 2 seconds. +func (self *ProtocolSession) testExchange(e Exchange) error { + errc := make(chan error) + done := make(chan struct{}) + defer close(done) - for _, e := range exchanges { - errc := make(chan error, len(e.Triggers)+len(e.Expects)) + go func() { for _, trig := range e.Triggers { - errc <- self.trigger(trig) + err := self.trigger(trig) + if err != nil { + errc <- err + return + } } - // each expectation is spawned in separate go-routine - // expectations of an exchange are conjunctive but unordered, i.e., - // only all of them arriving constitutes a pass - // each expectation is meant to be for a different peer, otherwise they are expected to panic - // testing of an exchange blocks until all expectations are decided - // an expectation is decided if - // expected message arrives OR - // an unexpected message arrives (panic) - // times out on their individual timeout - for _, ex := range e.Expects { - // expect msg spawned to separate go routine - go func(exp Expect) { - errc <- self.expect(exp) - }(ex) + select { + case errc <- self.expect(e.Expects): + case <-done: } + }() - // time out globally or finish when all expectations satisfied - timeout := time.After(5 * time.Second) - for i := 0; i < len(e.Triggers)+len(e.Expects); i++ { - select { - case err := <-errc: - if err != nil { - return fmt.Errorf("exchange failed with: %v", err) - } - case <-timeout: - return fmt.Errorf("exchange %v: '%v' timed out", i, e.Label) - } - } + // time out globally or finish when all expectations satisfied + t := e.Timeout + if t == 0 { + t = 2000 * time.Millisecond + } + alarm := time.NewTimer(t) + select { + case err := <-errc: + return err + case <-alarm.C: + return errTimedOut } - return nil } // TestDisconnected tests the disconnections given as arguments diff --git a/p2p/testing/protocoltester.go b/p2p/testing/protocoltester.go index ea5b106ff..a797412d6 100644 --- a/p2p/testing/protocoltester.go +++ b/p2p/testing/protocoltester.go @@ -24,7 +24,11 @@ that can be used to send and receive messages package testing import ( + "bytes" "fmt" + "io" + "io/ioutil" + "strings" "sync" "testing" @@ -34,6 +38,7 @@ import ( "github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/simulations" "github.com/ethereum/go-ethereum/p2p/simulations/adapters" + "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rpc" ) @@ -152,7 +157,7 @@ type mockNode struct { testNode trigger chan *Trigger - expect chan *Expect + expect chan []Expect err chan error stop chan struct{} stopOnce sync.Once @@ -161,7 +166,7 @@ type mockNode struct { func newMockNode() *mockNode { mock := &mockNode{ trigger: make(chan *Trigger), - expect: make(chan *Expect), + expect: make(chan []Expect), err: make(chan error), stop: make(chan struct{}), } @@ -176,8 +181,8 @@ func (m *mockNode) Run(peer *p2p.Peer, rw p2p.MsgReadWriter) error { select { case trig := <-m.trigger: m.err <- p2p.Send(rw, trig.Code, trig.Msg) - case exp := <-m.expect: - m.err <- p2p.ExpectMsg(rw, exp.Code, exp.Msg) + case exps := <-m.expect: + m.err <- expectMsgs(rw, exps) case <-m.stop: return nil } @@ -189,7 +194,7 @@ func (m *mockNode) Trigger(trig *Trigger) error { return <-m.err } -func (m *mockNode) Expect(exp *Expect) error { +func (m *mockNode) Expect(exp ...Expect) error { m.expect <- exp return <-m.err } @@ -198,3 +203,67 @@ func (m *mockNode) Stop() error { m.stopOnce.Do(func() { close(m.stop) }) return nil } + +func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error { + matched := make([]bool, len(exps)) + for { + msg, err := rw.ReadMsg() + if err != nil { + if err == io.EOF { + break + } + return err + } + actualContent, err := ioutil.ReadAll(msg.Payload) + if err != nil { + return err + } + var found bool + for i, exp := range exps { + if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(exp.Msg)) { + if matched[i] { + return fmt.Errorf("message #%d received two times", i) + } + matched[i] = true + found = true + break + } + } + if !found { + expected := make([]string, 0) + for i, exp := range exps { + if matched[i] { + continue + } + expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(exp.Msg))) + } + return fmt.Errorf("unexpected message code %d payload %x, expected %s", msg.Code, actualContent, strings.Join(expected, " or ")) + } + done := true + for _, m := range matched { + if !m { + done = false + break + } + } + if done { + return nil + } + } + for i, m := range matched { + if !m { + return fmt.Errorf("expected message #%d not received", i) + } + } + return nil +} + +// mustEncodeMsg uses rlp to encode a message. +// In case of error it panics. +func mustEncodeMsg(msg interface{}) []byte { + contentEnc, err := rlp.EncodeToBytes(msg) + if err != nil { + panic("content encode error: " + err.Error()) + } + return contentEnc +} -- cgit v1.2.3 From 14c76371bab105a99c796a652df86df3e3e7d958 Mon Sep 17 00:00:00 2001 From: Dmitry Shulyak Date: Wed, 21 Feb 2018 16:03:26 +0200 Subject: p2p: when peer is removed remove it also from dial history (#16060) This change removes a peer information from dialing history when peer is removed from static list. It allows to force a server to re-dial concrete peer if it is needed. In our case we are running geth node on mobile devices, and it is common for a network connection to flap on mobile. Almost every time it flaps or network connection is changed from cellular to wifi peers are disconnected with read timeout. And usually it takes 30 seconds (default expiration timeout) to recover connection with static peers after connectivity is restored. This change allows us to reconnect with peers almost immediately and it seems harmless enough. 
--- p2p/dial.go | 13 +++++++++++++ p2p/dial_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) (limited to 'p2p') diff --git a/p2p/dial.go b/p2p/dial.go index f5ff2c211..d8feceb9f 100644 --- a/p2p/dial.go +++ b/p2p/dial.go @@ -154,6 +154,9 @@ func (s *dialstate) addStatic(n *discover.Node) { func (s *dialstate) removeStatic(n *discover.Node) { // This removes a task so future attempts to connect will not be made. delete(s.static, n.ID) + // This removes a previous dial timestamp so that the application + // can force the server to reconnect with a chosen peer immediately. + s.hist.remove(n.ID) } func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task { @@ -390,6 +393,16 @@ func (h dialHistory) min() pastDial { } func (h *dialHistory) add(id discover.NodeID, exp time.Time) { heap.Push(h, pastDial{id, exp}) + +} +func (h *dialHistory) remove(id discover.NodeID) bool { + for i, v := range *h { + if v.id == id { + heap.Remove(h, i) + return true + } + } + return false } func (h dialHistory) contains(id discover.NodeID) bool { for _, v := range h { diff --git a/p2p/dial_test.go b/p2p/dial_test.go index ad18ef9ab..2a7941fc6 100644 --- a/p2p/dial_test.go +++ b/p2p/dial_test.go @@ -515,6 +515,50 @@ func TestDialStateStaticDial(t *testing.T) { }) } +// This test checks that static peers will be redialed immediately if they were re-added to the static list. +func TestDialStaticAfterReset(t *testing.T) { + wantStatic := []*discover.Node{ + {ID: uintID(1)}, + {ID: uintID(2)}, + } + + rounds := []round{ + // Static dials are launched for the nodes that aren't yet connected. + { + peers: nil, + new: []task{ + &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}}, + &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}}, + }, + }, + // No new dial tasks, all peers are connected. + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, id: uintID(1)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(2)}}, + }, + done: []task{ + &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}}, + &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}}, + }, + new: []task{ + &waitExpireTask{Duration: 30 * time.Second}, + }, + }, + } + dTest := dialtest{ + init: newDialState(wantStatic, nil, fakeTable{}, 0, nil), + rounds: rounds, + } + runDialTest(t, dTest) + for _, n := range wantStatic { + dTest.init.removeStatic(n) + dTest.init.addStatic(n) + } + // without removing peers they will be considered recently dialed + runDialTest(t, dTest) +} + // This test checks that past dials are not retried for some time.
 func TestDialStateCache(t *testing.T) {
 	wantStatic := []*discover.Node{
-- cgit v1.2.3

From 28b20cff4b76595b695f9d784c98d50d31bb269b Mon Sep 17 00:00:00 2001
From: Felix Lange
Date: Thu, 22 Feb 2018 11:37:57 +0100
Subject: p2p/protocols: gofmt -w -s

---
 p2p/protocols/protocol_test.go | 44 +++++++++++++++++++++---------------------
 1 file changed, 22 insertions(+), 22 deletions(-)

(limited to 'p2p')

diff --git a/p2p/protocols/protocol_test.go b/p2p/protocols/protocol_test.go
index b041a650a..053f537a6 100644
--- a/p2p/protocols/protocol_test.go
+++ b/p2p/protocols/protocol_test.go
@@ -147,18 +147,18 @@ func protocolTester(t *testing.T, pp *p2ptest.TestPeerPool) *p2ptest.ProtocolTes
 
 func protoHandshakeExchange(id discover.NodeID, proto *protoHandshake) []p2ptest.Exchange {
 	return []p2ptest.Exchange{
-		p2ptest.Exchange{
+		{
 			Expects: []p2ptest.Expect{
-				p2ptest.Expect{
+				{
 					Code: 0,
 					Msg:  &protoHandshake{42, "420"},
 					Peer: id,
 				},
 			},
 		},
-		p2ptest.Exchange{
+		{
 			Triggers: []p2ptest.Trigger{
-				p2ptest.Trigger{
+				{
 					Code: 0,
 					Msg:  proto,
 					Peer: id,
@@ -200,18 +200,18 @@ func TestProtoHandshakeSuccess(t *testing.T) {
 
 func moduleHandshakeExchange(id discover.NodeID, resp uint) []p2ptest.Exchange {
 	return []p2ptest.Exchange{
-		p2ptest.Exchange{
+		{
 			Expects: []p2ptest.Expect{
-				p2ptest.Expect{
+				{
 					Code: 1,
 					Msg:  &hs0{42},
 					Peer: id,
 				},
 			},
 		},
-		p2ptest.Exchange{
+		{
 			Triggers: []p2ptest.Trigger{
-				p2ptest.Trigger{
+				{
 					Code: 1,
 					Msg:  &hs0{resp},
 					Peer: id,
@@ -252,42 +252,42 @@ func TestModuleHandshakeSuccess(t *testing.T) {
 
 func testMultiPeerSetup(a, b discover.NodeID) []p2ptest.Exchange {
 	return []p2ptest.Exchange{
-		p2ptest.Exchange{
+		{
 			Label: "primary handshake",
 			Expects: []p2ptest.Expect{
-				p2ptest.Expect{
+				{
 					Code: 0,
 					Msg:  &protoHandshake{42, "420"},
 					Peer: a,
 				},
-				p2ptest.Expect{
+				{
 					Code: 0,
 					Msg:  &protoHandshake{42, "420"},
 					Peer: b,
 				},
 			},
 		},
-		p2ptest.Exchange{
+		{
 			Label: "module handshake",
 			Triggers: []p2ptest.Trigger{
-				p2ptest.Trigger{
+				{
 					Code: 0,
 					Msg:  &protoHandshake{42, "420"},
 					Peer: a,
 				},
-				p2ptest.Trigger{
+				{
 					Code: 0,
 					Msg:  &protoHandshake{42, "420"},
 					Peer: b,
 				},
 			},
 			Expects: []p2ptest.Expect{
-				p2ptest.Expect{
+				{
 					Code: 1,
 					Msg:  &hs0{42},
 					Peer: a,
 				},
-				p2ptest.Expect{
+				{
 					Code: 1,
 					Msg:  &hs0{42},
 					Peer: b,
@@ -295,10 +295,10 @@ func testMultiPeerSetup(a, b discover.NodeID) []p2ptest.Exchange {
 			},
 		},
 
-		p2ptest.Exchange{Label: "alternative module handshake", Triggers: []p2ptest.Trigger{p2ptest.Trigger{Code: 1, Msg: &hs0{41}, Peer: a},
-			p2ptest.Trigger{Code: 1, Msg: &hs0{41}, Peer: b}}},
-		p2ptest.Exchange{Label: "repeated module handshake", Triggers: []p2ptest.Trigger{p2ptest.Trigger{Code: 1, Msg: &hs0{1}, Peer: a}}},
-		p2ptest.Exchange{Label: "receiving repeated module handshake", Expects: []p2ptest.Expect{p2ptest.Expect{Code: 1, Msg: &hs0{43}, Peer: a}}}}
+		{Label: "alternative module handshake", Triggers: []p2ptest.Trigger{{Code: 1, Msg: &hs0{41}, Peer: a},
+			{Code: 1, Msg: &hs0{41}, Peer: b}}},
+		{Label: "repeated module handshake", Triggers: []p2ptest.Trigger{{Code: 1, Msg: &hs0{1}, Peer: a}}},
+		{Label: "receiving repeated module handshake", Expects: []p2ptest.Expect{{Code: 1, Msg: &hs0{43}, Peer: a}}}}
 }
 
 func runMultiplePeers(t *testing.T, peer int, errs ...error) {
@@ -332,7 +332,7 @@ WAIT:
 	// peer 0 sends kill request for peer with index
 	err := s.TestExchanges(p2ptest.Exchange{
 		Triggers: []p2ptest.Trigger{
-			p2ptest.Trigger{
+			{
 				Code: 2,
 				Msg:  &kill{s.IDs[peer]},
 				Peer: s.IDs[0],
@@ -347,7 +347,7 @@ WAIT:
 	// the peer not killed sends a drop request
 	err = s.TestExchanges(p2ptest.Exchange{
 		Triggers: []p2ptest.Trigger{
-			p2ptest.Trigger{
+			{
 				Code: 3,
 				Msg:  &drop{},
 				Peer: s.IDs[(peer+1)%2],
-- cgit v1.2.3

From 1e457b659989478004329c2f3a82b5ac67b32cbf Mon Sep 17 00:00:00 2001
From: Anton Evangelatov
Date: Thu, 22 Feb 2018 11:41:06 +0100
Subject: p2p: don't send DiscReason when using net.Pipe (#16004)

---
 p2p/rlpx.go      | 10 ++++++++--
 p2p/rlpx_test.go | 38 +++++++++++++++++++++++++++++++++++---
 2 files changed, 43 insertions(+), 5 deletions(-)

(limited to 'p2p')

diff --git a/p2p/rlpx.go b/p2p/rlpx.go
index 24037ecc1..e65a0b604 100644
--- a/p2p/rlpx.go
+++ b/p2p/rlpx.go
@@ -108,8 +108,14 @@ func (t *rlpx) close(err error) {
 	// Tell the remote end why we're disconnecting if possible.
 	if t.rw != nil {
 		if r, ok := err.(DiscReason); ok && r != DiscNetworkError {
-			t.fd.SetWriteDeadline(time.Now().Add(discWriteTimeout))
-			SendItems(t.rw, discMsg, r)
+			// rlpx tries to send a DiscReason to the disconnected peer.
+			// If the connection is a net.Pipe (in-memory simulation) this
+			// hangs forever, since net.Pipe does not implement a write
+			// deadline. Because of this, only try to send the disconnect
+			// reason message if setting the write deadline succeeded.
+			if err := t.fd.SetWriteDeadline(time.Now().Add(discWriteTimeout)); err == nil {
+				SendItems(t.rw, discMsg, r)
+			}
 		}
 	}
 	t.fd.Close()
diff --git a/p2p/rlpx_test.go b/p2p/rlpx_test.go
index f4cefa650..bca460402 100644
--- a/p2p/rlpx_test.go
+++ b/p2p/rlpx_test.go
@@ -156,14 +156,18 @@ func TestProtocolHandshake(t *testing.T) {
 		node1 = &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey), IP: net.IP{5, 6, 7, 8}, TCP: 44}
 		hs1   = &protoHandshake{Version: 3, ID: node1.ID, Caps: []Cap{{"c", 1}, {"d", 3}}}
 
-		fd0, fd1 = net.Pipe()
-
-		wg sync.WaitGroup
+		wg sync.WaitGroup
 	)
+
+	fd0, fd1, err := tcpPipe()
+	if err != nil {
+		t.Fatal(err)
+	}
+
 	wg.Add(2)
 	go func() {
 		defer wg.Done()
-		defer fd1.Close()
+		defer fd0.Close()
 		rlpx := newRLPX(fd0)
 		remid, err := rlpx.doEncHandshake(prv0, node1)
 		if err != nil {
@@ -597,3 +601,31 @@ func TestHandshakeForwardCompatibility(t *testing.T) {
 		t.Errorf("ingress-mac('foo') mismatch:\ngot %x\nwant %x", fooIngressHash, wantFooIngressHash)
 	}
 }
+
+// tcpPipe creates an in-process, full-duplex pipe based on a localhost TCP socket.
+func tcpPipe() (net.Conn, net.Conn, error) {
+	l, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		return nil, nil, err
+	}
+	defer l.Close()
+
+	var aconn net.Conn
+	aerr := make(chan error, 1)
+	go func() {
+		var err error
+		aconn, err = l.Accept()
+		aerr <- err
+	}()
+
+	dconn, err := net.Dial("tcp", l.Addr().String())
+	if err != nil {
+		<-aerr
+		return nil, nil, err
+	}
+	if err := <-aerr; err != nil {
+		dconn.Close()
+		return nil, nil, err
+	}
+	return aconn, dconn, nil
+}
-- cgit v1.2.3
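The guard in rlpx.close generalises: when a transport may not support deadlines, check the error returned by SetWriteDeadline before relying on it for a best-effort write. A minimal standalone sketch of the pattern (the helper trySend is hypothetical, not from this patch):

package main

import (
	"net"
	"time"
)

// trySend shows the guard used in rlpx.close: attempt the best-effort
// write only if the connection accepts a write deadline. On connections
// that reject deadlines (as net.Pipe did in the Go versions targeted
// here), the write is skipped instead of potentially blocking forever.
func trySend(c net.Conn, payload []byte, timeout time.Duration) {
	if err := c.SetWriteDeadline(time.Now().Add(timeout)); err == nil {
		c.Write(payload) // best effort; write errors are intentionally ignored
	}
}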