aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--p2p/client_identity.go6
-rw-r--r--p2p/message.go62
-rw-r--r--p2p/messenger.go221
-rw-r--r--p2p/messenger_test.go203
-rw-r--r--p2p/natpmp.go34
-rw-r--r--p2p/natupnp.go198
-rw-r--r--p2p/network.go196
-rw-r--r--p2p/peer.go476
-rw-r--r--p2p/peer_error.go150
-rw-r--r--p2p/peer_error_handler.go98
-rw-r--r--p2p/peer_error_handler_test.go34
-rw-r--r--p2p/peer_test.go308
-rw-r--r--p2p/protocol.go412
-rw-r--r--p2p/server.go713
-rw-r--r--p2p/server_test.go388
-rw-r--r--p2p/testlog_test.go28
-rw-r--r--p2p/testpoc7.go40
17 files changed, 1665 insertions, 1902 deletions
diff --git a/p2p/client_identity.go b/p2p/client_identity.go
index 236b23106..bc865b63b 100644
--- a/p2p/client_identity.go
+++ b/p2p/client_identity.go
@@ -5,10 +5,10 @@ import (
"runtime"
)
-// should be used in Peer handleHandshake, incorporate Caps, ProtocolVersion, Pubkey etc.
+// ClientIdentity represents the identity of a peer.
type ClientIdentity interface {
- String() string
- Pubkey() []byte
+ String() string // human readable identity
+ Pubkey() []byte // 512-bit public key
}
type SimpleClientIdentity struct {
diff --git a/p2p/message.go b/p2p/message.go
index 97d440a27..89ad189d7 100644
--- a/p2p/message.go
+++ b/p2p/message.go
@@ -11,8 +11,6 @@ import (
"github.com/ethereum/go-ethereum/ethutil"
)
-type MsgCode uint64
-
// Msg defines the structure of a p2p message.
//
// Note that a Msg can only be sent once since the Payload reader is
@@ -21,13 +19,13 @@ type MsgCode uint64
// structure, encode the payload into a byte array and create a
// separate Msg with a bytes.Reader as Payload for each send.
type Msg struct {
- Code MsgCode
+ Code uint64
Size uint32 // size of the paylod
Payload io.Reader
}
// NewMsg creates an RLP-encoded message with the given code.
-func NewMsg(code MsgCode, params ...interface{}) Msg {
+func NewMsg(code uint64, params ...interface{}) Msg {
buf := new(bytes.Buffer)
for _, p := range params {
buf.Write(ethutil.Encode(p))
@@ -63,6 +61,52 @@ func (msg Msg) Discard() error {
return err
}
+type MsgReader interface {
+ ReadMsg() (Msg, error)
+}
+
+type MsgWriter interface {
+ // WriteMsg sends an existing message.
+ // The Payload reader of the message is consumed.
+ // Note that messages can be sent only once.
+ WriteMsg(Msg) error
+
+ // EncodeMsg writes an RLP-encoded message with the given
+ // code and data elements.
+ EncodeMsg(code uint64, data ...interface{}) error
+}
+
+// MsgReadWriter provides reading and writing of encoded messages.
+type MsgReadWriter interface {
+ MsgReader
+ MsgWriter
+}
+
+// MsgLoop reads messages off the given reader and
+// calls the handler function for each decoded message until
+// it returns an error or the peer connection is closed.
+//
+// If a message is larger than the given maximum size,
+// MsgLoop returns an appropriate error.
+func MsgLoop(r MsgReader, maxsize uint32, f func(code uint64, data *ethutil.Value) error) error {
+ for {
+ msg, err := r.ReadMsg()
+ if err != nil {
+ return err
+ }
+ if msg.Size > maxsize {
+ return newPeerError(errInvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize)
+ }
+ value, err := msg.Data()
+ if err != nil {
+ return err
+ }
+ if err := f(msg.Code, value); err != nil {
+ return err
+ }
+ }
+}
+
var magicToken = []byte{34, 64, 8, 145}
func writeMsg(w io.Writer, msg Msg) error {
@@ -103,10 +147,10 @@ func readMsg(r byteReader) (msg Msg, err error) {
// read magic and payload size
start := make([]byte, 8)
if _, err = io.ReadFull(r, start); err != nil {
- return msg, NewPeerError(ReadError, "%v", err)
+ return msg, newPeerError(errRead, "%v", err)
}
if !bytes.HasPrefix(start, magicToken) {
- return msg, NewPeerError(MagicTokenMismatch, "got %x, want %x", start[:4], magicToken)
+ return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken)
}
size := binary.BigEndian.Uint32(start[4:])
@@ -152,13 +196,13 @@ func readListHeader(r byteReader) (len uint64, hdrlen uint32, err error) {
}
// readUint reads an RLP-encoded unsigned integer from r.
-func readMsgCode(r byteReader) (code MsgCode, codelen uint32, err error) {
+func readMsgCode(r byteReader) (code uint64, codelen uint32, err error) {
b, err := r.ReadByte()
if err != nil {
return 0, 0, err
}
if b < 0x80 {
- return MsgCode(b), 1, nil
+ return uint64(b), 1, nil
} else if b < 0x89 { // max length for uint64 is 8 bytes
codelen = uint32(b - 0x80)
if codelen == 0 {
@@ -168,7 +212,7 @@ func readMsgCode(r byteReader) (code MsgCode, codelen uint32, err error) {
if _, err := io.ReadFull(r, buf[8-codelen:]); err != nil {
return 0, 0, err
}
- return MsgCode(binary.BigEndian.Uint64(buf)), codelen, nil
+ return binary.BigEndian.Uint64(buf), codelen, nil
}
return 0, 0, fmt.Errorf("bad RLP type for message code: %x", b)
}
diff --git a/p2p/messenger.go b/p2p/messenger.go
deleted file mode 100644
index c7948a9ac..000000000
--- a/p2p/messenger.go
+++ /dev/null
@@ -1,221 +0,0 @@
-package p2p
-
-import (
- "bufio"
- "bytes"
- "fmt"
- "io"
- "io/ioutil"
- "net"
- "sync"
- "time"
-)
-
-type Handlers map[string]Protocol
-
-type proto struct {
- in chan Msg
- maxcode, offset MsgCode
- messenger *messenger
-}
-
-func (rw *proto) WriteMsg(msg Msg) error {
- if msg.Code >= rw.maxcode {
- return NewPeerError(InvalidMsgCode, "not handled")
- }
- msg.Code += rw.offset
- return rw.messenger.writeMsg(msg)
-}
-
-func (rw *proto) ReadMsg() (Msg, error) {
- msg, ok := <-rw.in
- if !ok {
- return msg, io.EOF
- }
- msg.Code -= rw.offset
- return msg, nil
-}
-
-// eofSignal wraps a reader with eof signaling.
-// the eof channel is closed when the wrapped reader
-// reaches EOF.
-type eofSignal struct {
- wrapped io.Reader
- eof chan struct{}
-}
-
-func (r *eofSignal) Read(buf []byte) (int, error) {
- n, err := r.wrapped.Read(buf)
- if err != nil {
- close(r.eof) // tell messenger that msg has been consumed
- }
- return n, err
-}
-
-// messenger represents a message-oriented peer connection.
-// It keeps track of the set of protocols understood
-// by the remote peer.
-type messenger struct {
- peer *Peer
- handlers Handlers
-
- // the mutex protects the connection
- // so only one protocol can write at a time.
- writeMu sync.Mutex
- conn net.Conn
- bufconn *bufio.ReadWriter
-
- protocolLock sync.RWMutex
- protocols map[string]*proto
- offsets map[MsgCode]*proto
- protoWG sync.WaitGroup
-
- err chan error
- pulse chan bool
-}
-
-func newMessenger(peer *Peer, conn net.Conn, errchan chan error, handlers Handlers) *messenger {
- return &messenger{
- conn: conn,
- bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
- peer: peer,
- handlers: handlers,
- protocols: make(map[string]*proto),
- err: errchan,
- pulse: make(chan bool, 1),
- }
-}
-
-func (m *messenger) Start() {
- m.protocols[""] = m.startProto(0, "", &baseProtocol{})
- go m.readLoop()
-}
-
-func (m *messenger) Stop() {
- m.conn.Close()
- m.protoWG.Wait()
-}
-
-const (
- // maximum amount of time allowed for reading a message
- msgReadTimeout = 5 * time.Second
-
- // messages smaller than this many bytes will be read at
- // once before passing them to a protocol.
- wholePayloadSize = 64 * 1024
-)
-
-func (m *messenger) readLoop() {
- defer m.closeProtocols()
- for {
- m.conn.SetReadDeadline(time.Now().Add(msgReadTimeout))
- msg, err := readMsg(m.bufconn)
- if err != nil {
- m.err <- err
- return
- }
- // send ping to heartbeat channel signalling time of last message
- m.pulse <- true
- proto, err := m.getProto(msg.Code)
- if err != nil {
- m.err <- err
- return
- }
- if msg.Size <= wholePayloadSize {
- // optimization: msg is small enough, read all
- // of it and move on to the next message
- buf, err := ioutil.ReadAll(msg.Payload)
- if err != nil {
- m.err <- err
- return
- }
- msg.Payload = bytes.NewReader(buf)
- proto.in <- msg
- } else {
- pr := &eofSignal{msg.Payload, make(chan struct{})}
- msg.Payload = pr
- proto.in <- msg
- <-pr.eof
- }
- }
-}
-
-func (m *messenger) closeProtocols() {
- m.protocolLock.RLock()
- for _, p := range m.protocols {
- close(p.in)
- }
- m.protocolLock.RUnlock()
-}
-
-func (m *messenger) startProto(offset MsgCode, name string, impl Protocol) *proto {
- proto := &proto{
- in: make(chan Msg),
- offset: offset,
- maxcode: impl.Offset(),
- messenger: m,
- }
- m.protoWG.Add(1)
- go func() {
- if err := impl.Start(m.peer, proto); err != nil && err != io.EOF {
- logger.Errorf("protocol %q error: %v\n", name, err)
- m.err <- err
- }
- m.protoWG.Done()
- }()
- return proto
-}
-
-// getProto finds the protocol responsible for handling
-// the given message code.
-func (m *messenger) getProto(code MsgCode) (*proto, error) {
- m.protocolLock.RLock()
- defer m.protocolLock.RUnlock()
- for _, proto := range m.protocols {
- if code >= proto.offset && code < proto.offset+proto.maxcode {
- return proto, nil
- }
- }
- return nil, NewPeerError(InvalidMsgCode, "%d", code)
-}
-
-// setProtocols starts all subprotocols shared with the
-// remote peer. the protocols must be sorted alphabetically.
-func (m *messenger) setRemoteProtocols(protocols []string) {
- m.protocolLock.Lock()
- defer m.protocolLock.Unlock()
- offset := baseProtocolOffset
- for _, name := range protocols {
- inst, ok := m.handlers[name]
- if !ok {
- continue // not handled
- }
- m.protocols[name] = m.startProto(offset, name, inst)
- offset += inst.Offset()
- }
-}
-
-// writeProtoMsg sends the given message on behalf of the given named protocol.
-func (m *messenger) writeProtoMsg(protoName string, msg Msg) error {
- m.protocolLock.RLock()
- proto, ok := m.protocols[protoName]
- m.protocolLock.RUnlock()
- if !ok {
- return fmt.Errorf("protocol %s not handled by peer", protoName)
- }
- if msg.Code >= proto.maxcode {
- return NewPeerError(InvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
- }
- msg.Code += proto.offset
- return m.writeMsg(msg)
-}
-
-// writeMsg writes a message to the connection.
-func (m *messenger) writeMsg(msg Msg) error {
- m.writeMu.Lock()
- defer m.writeMu.Unlock()
- if err := writeMsg(m.bufconn, msg); err != nil {
- return err
- }
- return m.bufconn.Flush()
-}
diff --git a/p2p/messenger_test.go b/p2p/messenger_test.go
deleted file mode 100644
index 2264e10d3..000000000
--- a/p2p/messenger_test.go
+++ /dev/null
@@ -1,203 +0,0 @@
-package p2p
-
-import (
- "bufio"
- "fmt"
- "io"
- "log"
- "net"
- "os"
- "reflect"
- "testing"
- "time"
-
- logpkg "github.com/ethereum/go-ethereum/logger"
-)
-
-func init() {
- logpkg.AddLogSystem(logpkg.NewStdLogSystem(os.Stdout, log.LstdFlags, logpkg.DebugLevel))
-}
-
-func testMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) {
- conn1, conn2 := net.Pipe()
- id := NewSimpleClientIdentity("test", "0", "0", "public key")
- server := New(nil, conn1.LocalAddr(), id, handlers, 10, NewBlacklist())
- peer := server.addPeer(conn1, conn1.RemoteAddr(), true, 0)
- return conn2, peer, peer.messenger
-}
-
-func performTestHandshake(r *bufio.Reader, w io.Writer) error {
- // read remote handshake
- msg, err := readMsg(r)
- if err != nil {
- return fmt.Errorf("read error: %v", err)
- }
- if msg.Code != handshakeMsg {
- return fmt.Errorf("first message should be handshake, got %d", msg.Code)
- }
- if err := msg.Discard(); err != nil {
- return err
- }
- // send empty handshake
- pubkey := make([]byte, 64)
- msg = NewMsg(handshakeMsg, p2pVersion, "testid", nil, 9999, pubkey)
- return writeMsg(w, msg)
-}
-
-type testProtocol struct {
- offset MsgCode
- f func(MsgReadWriter)
-}
-
-func (p *testProtocol) Offset() MsgCode {
- return p.offset
-}
-
-func (p *testProtocol) Start(peer *Peer, rw MsgReadWriter) error {
- p.f(rw)
- return nil
-}
-
-func TestRead(t *testing.T) {
- done := make(chan struct{})
- handlers := Handlers{
- "a": &testProtocol{5, func(rw MsgReadWriter) {
- msg, err := rw.ReadMsg()
- if err != nil {
- t.Errorf("read error: %v", err)
- }
- if msg.Code != 2 {
- t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
- }
- data, err := msg.Data()
- if err != nil {
- t.Errorf("data decoding error: %v", err)
- }
- expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}}
- if !reflect.DeepEqual(data.Slice(), expdata) {
- t.Errorf("incorrect msg data %#v", data.Slice())
- }
- close(done)
- }},
- }
-
- net, peer, m := testMessenger(handlers)
- defer peer.Stop()
- bufr := bufio.NewReader(net)
- if err := performTestHandshake(bufr, net); err != nil {
- t.Fatalf("handshake failed: %v", err)
- }
- m.setRemoteProtocols([]string{"a"})
-
- writeMsg(net, NewMsg(18, 1, "000"))
- select {
- case <-done:
- case <-time.After(2 * time.Second):
- t.Errorf("receive timeout")
- }
-}
-
-func TestWriteFromProto(t *testing.T) {
- handlers := Handlers{
- "a": &testProtocol{2, func(rw MsgReadWriter) {
- if err := rw.WriteMsg(NewMsg(2)); err == nil {
- t.Error("expected error for out-of-range msg code, got nil")
- }
- if err := rw.WriteMsg(NewMsg(1)); err != nil {
- t.Errorf("write error: %v", err)
- }
- }},
- }
- net, peer, mess := testMessenger(handlers)
- defer peer.Stop()
- bufr := bufio.NewReader(net)
- if err := performTestHandshake(bufr, net); err != nil {
- t.Fatalf("handshake failed: %v", err)
- }
- mess.setRemoteProtocols([]string{"a"})
-
- msg, err := readMsg(bufr)
- if err != nil {
- t.Errorf("read error: %v")
- }
- if msg.Code != 17 {
- t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
- }
-}
-
-var discardProto = &testProtocol{1, func(rw MsgReadWriter) {
- for {
- msg, err := rw.ReadMsg()
- if err != nil {
- return
- }
- if err = msg.Discard(); err != nil {
- return
- }
- }
-}}
-
-func TestMessengerWriteProtoMsg(t *testing.T) {
- handlers := Handlers{"a": discardProto}
- net, peer, mess := testMessenger(handlers)
- defer peer.Stop()
- bufr := bufio.NewReader(net)
- if err := performTestHandshake(bufr, net); err != nil {
- t.Fatalf("handshake failed: %v", err)
- }
- mess.setRemoteProtocols([]string{"a"})
-
- // test write errors
- if err := mess.writeProtoMsg("b", NewMsg(3)); err == nil {
- t.Errorf("expected error for unknown protocol, got nil")
- }
- if err := mess.writeProtoMsg("a", NewMsg(8)); err == nil {
- t.Errorf("expected error for out-of-range msg code, got nil")
- } else if perr, ok := err.(*PeerError); !ok || perr.Code != InvalidMsgCode {
- t.Errorf("wrong error for out-of-range msg code, got %#v")
- }
-
- // test succcessful write
- read, readerr := make(chan Msg), make(chan error)
- go func() {
- if msg, err := readMsg(bufr); err != nil {
- readerr <- err
- } else {
- read <- msg
- }
- }()
- if err := mess.writeProtoMsg("a", NewMsg(0)); err != nil {
- t.Errorf("expect no error for known protocol: %v", err)
- }
- select {
- case msg := <-read:
- if msg.Code != 16 {
- t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
- }
- msg.Discard()
- case err := <-readerr:
- t.Errorf("read error: %v", err)
- }
-}
-
-func TestPulse(t *testing.T) {
- net, peer, _ := testMessenger(nil)
- defer peer.Stop()
- bufr := bufio.NewReader(net)
- if err := performTestHandshake(bufr, net); err != nil {
- t.Fatalf("handshake failed: %v", err)
- }
-
- before := time.Now()
- msg, err := readMsg(bufr)
- if err != nil {
- t.Fatalf("read error: %v", err)
- }
- after := time.Now()
- if msg.Code != pingMsg {
- t.Errorf("expected ping message, got %d", msg.Code)
- }
- if d := after.Sub(before); d < pingTimeout {
- t.Errorf("ping sent too early after %v, expected at least %v", d, pingTimeout)
- }
-}
diff --git a/p2p/natpmp.go b/p2p/natpmp.go
index ff966d070..6714678c4 100644
--- a/p2p/natpmp.go
+++ b/p2p/natpmp.go
@@ -3,6 +3,7 @@ package p2p
import (
"fmt"
"net"
+ "time"
natpmp "github.com/jackpal/go-nat-pmp"
)
@@ -13,38 +14,37 @@ import (
// + Register for changes to the external address.
// + Re-register port mapping when router reboots.
// + A mechanism for keeping a port mapping registered.
+// + Discover gateway address automatically.
type natPMPClient struct {
client *natpmp.Client
}
-func NewNatPMP(gateway net.IP) (nat NAT) {
+// PMP returns a NAT traverser that uses NAT-PMP. The provided gateway
+// address should be the IP of your router.
+func PMP(gateway net.IP) (nat NAT) {
return &natPMPClient{natpmp.NewClient(gateway)}
}
-func (n *natPMPClient) GetExternalAddress() (addr net.IP, err error) {
+func (*natPMPClient) String() string {
+ return "NAT-PMP"
+}
+
+func (n *natPMPClient) GetExternalAddress() (net.IP, error) {
response, err := n.client.GetExternalAddress()
if err != nil {
- return
+ return nil, err
}
- ip := response.ExternalIPAddress
- addr = net.IPv4(ip[0], ip[1], ip[2], ip[3])
- return
+ return response.ExternalIPAddress[:], nil
}
-func (n *natPMPClient) AddPortMapping(protocol string, externalPort, internalPort int,
- description string, timeout int) (mappedExternalPort int, err error) {
- if timeout <= 0 {
- err = fmt.Errorf("timeout must not be <= 0")
- return
+func (n *natPMPClient) AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
+ if lifetime <= 0 {
+ return fmt.Errorf("lifetime must not be <= 0")
}
// Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping.
- response, err := n.client.AddPortMapping(protocol, internalPort, externalPort, timeout)
- if err != nil {
- return
- }
- mappedExternalPort = int(response.MappedExternalPort)
- return
+ _, err := n.client.AddPortMapping(protocol, intport, extport, int(lifetime/time.Second))
+ return err
}
func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
diff --git a/p2p/natupnp.go b/p2p/natupnp.go
index fa9798d4d..2e0d8ce8d 100644
--- a/p2p/natupnp.go
+++ b/p2p/natupnp.go
@@ -7,6 +7,7 @@ import (
"bytes"
"encoding/xml"
"errors"
+ "fmt"
"net"
"net/http"
"os"
@@ -15,28 +16,46 @@ import (
"time"
)
+const (
+ upnpDiscoverAttempts = 3
+ upnpDiscoverTimeout = 5 * time.Second
+)
+
+// UPNP returns a NAT port mapper that uses UPnP. It will attempt to
+// discover the address of your router using UDP broadcasts.
+func UPNP() NAT {
+ return &upnpNAT{}
+}
+
type upnpNAT struct {
serviceURL string
ourIP string
}
-func upnpDiscover(attempts int) (nat NAT, err error) {
+func (n *upnpNAT) String() string {
+ return "UPNP"
+}
+
+func (n *upnpNAT) discover() error {
+ if n.serviceURL != "" {
+ // already discovered
+ return nil
+ }
+
ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900")
if err != nil {
- return
+ return err
}
+ // TODO: try on all network interfaces simultaneously.
+ // Broadcasting on 0.0.0.0 could select a random interface
+ // to send on (platform specific).
conn, err := net.ListenPacket("udp4", ":0")
if err != nil {
- return
- }
- socket := conn.(*net.UDPConn)
- defer socket.Close()
-
- err = socket.SetDeadline(time.Now().Add(10 * time.Second))
- if err != nil {
- return
+ return err
}
+ defer conn.Close()
+ conn.SetDeadline(time.Now().Add(10 * time.Second))
st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n"
buf := bytes.NewBufferString(
"M-SEARCH * HTTP/1.1\r\n" +
@@ -46,19 +65,16 @@ func upnpDiscover(attempts int) (nat NAT, err error) {
"MX: 2\r\n\r\n")
message := buf.Bytes()
answerBytes := make([]byte, 1024)
- for i := 0; i < attempts; i++ {
- _, err = socket.WriteToUDP(message, ssdp)
+ for i := 0; i < upnpDiscoverAttempts; i++ {
+ _, err = conn.WriteTo(message, ssdp)
if err != nil {
- return
+ return err
}
- var n int
- n, _, err = socket.ReadFromUDP(answerBytes)
+ nn, _, err := conn.ReadFrom(answerBytes)
if err != nil {
continue
- // socket.Close()
- // return
}
- answer := string(answerBytes[0:n])
+ answer := string(answerBytes[0:nn])
if strings.Index(answer, "\r\n"+st) < 0 {
continue
}
@@ -79,17 +95,81 @@ func upnpDiscover(attempts int) (nat NAT, err error) {
var serviceURL string
serviceURL, err = getServiceURL(locURL)
if err != nil {
- return
+ return err
}
var ourIP string
ourIP, err = getOurIP()
if err != nil {
- return
+ return err
}
- nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP}
+ n.serviceURL = serviceURL
+ n.ourIP = ourIP
+ return nil
+ }
+ return errors.New("UPnP port discovery failed.")
+}
+
+func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
+ if err := n.discover(); err != nil {
+ return nil, err
+ }
+ info, err := n.getStatusInfo()
+ return net.ParseIP(info.externalIpAddress), err
+}
+
+func (n *upnpNAT) AddPortMapping(protocol string, extport, intport int, description string, lifetime time.Duration) error {
+ if err := n.discover(); err != nil {
+ return err
+ }
+
+ // A single concatenation would break ARM compilation.
+ message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
+ "<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(extport)
+ message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
+ message += "<NewInternalPort>" + strconv.Itoa(extport) + "</NewInternalPort>" +
+ "<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
+ "<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
+ message += description +
+ "</NewPortMappingDescription><NewLeaseDuration>" + fmt.Sprint(lifetime/time.Second) +
+ "</NewLeaseDuration></u:AddPortMapping>"
+
+ // TODO: check response to see if the port was forwarded
+ _, err := soapRequest(n.serviceURL, "AddPortMapping", message)
+ return err
+}
+
+func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) error {
+ if err := n.discover(); err != nil {
+ return err
+ }
+
+ message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
+ "<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
+ "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
+ "</u:DeletePortMapping>"
+
+ // TODO: check response to see if the port was deleted
+ _, err := soapRequest(n.serviceURL, "DeletePortMapping", message)
+ return err
+}
+
+type statusInfo struct {
+ externalIpAddress string
+}
+
+func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
+ message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
+ "</u:GetStatusInfo>"
+
+ var response *http.Response
+ response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
+ if err != nil {
return
}
- err = errors.New("UPnP port discovery failed.")
+
+ // TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
+
+ response.Body.Close()
return
}
@@ -259,77 +339,3 @@ func soapRequest(url, function, message string) (r *http.Response, err error) {
}
return
}
-
-type statusInfo struct {
- externalIpAddress string
-}
-
-func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
-
- message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
- "</u:GetStatusInfo>"
-
- var response *http.Response
- response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
- if err != nil {
- return
- }
-
- // TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
-
- response.Body.Close()
- return
-}
-
-func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
- info, err := n.getStatusInfo()
- if err != nil {
- return
- }
- addr = net.ParseIP(info.externalIpAddress)
- return
-}
-
-func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) {
- // A single concatenation would break ARM compilation.
- message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
- "<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort)
- message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
- message += "<NewInternalPort>" + strconv.Itoa(internalPort) + "</NewInternalPort>" +
- "<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
- "<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
- message += description +
- "</NewPortMappingDescription><NewLeaseDuration>" + strconv.Itoa(timeout) +
- "</NewLeaseDuration></u:AddPortMapping>"
-
- var response *http.Response
- response, err = soapRequest(n.serviceURL, "AddPortMapping", message)
- if err != nil {
- return
- }
-
- // TODO: check response to see if the port was forwarded
- // log.Println(message, response)
- mappedExternalPort = externalPort
- _ = response
- return
-}
-
-func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
-
- message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
- "<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
- "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
- "</u:DeletePortMapping>"
-
- var response *http.Response
- response, err = soapRequest(n.serviceURL, "DeletePortMapping", message)
- if err != nil {
- return
- }
-
- // TODO: check response to see if the port was deleted
- // log.Println(message, response)
- _ = response
- return
-}
diff --git a/p2p/network.go b/p2p/network.go
deleted file mode 100644
index 820cef1a9..000000000
--- a/p2p/network.go
+++ /dev/null
@@ -1,196 +0,0 @@
-package p2p
-
-import (
- "fmt"
- "math/rand"
- "net"
- "strconv"
- "time"
-)
-
-const (
- DialerTimeout = 180 //seconds
- KeepAlivePeriod = 60 //minutes
- portMappingUpdateInterval = 900 // seconds = 15 mins
- upnpDiscoverAttempts = 3
-)
-
-// Dialer is not an interface in net, so we define one
-// *net.Dialer conforms to this
-type Dialer interface {
- Dial(network, address string) (net.Conn, error)
-}
-
-type Network interface {
- Start() error
- Listener(net.Addr) (net.Listener, error)
- Dialer(net.Addr) (Dialer, error)
- NewAddr(string, int) (addr net.Addr, err error)
- ParseAddr(string) (addr net.Addr, err error)
-}
-
-type NAT interface {
- GetExternalAddress() (addr net.IP, err error)
- AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error)
- DeletePortMapping(protocol string, externalPort, internalPort int) (err error)
-}
-
-type TCPNetwork struct {
- nat NAT
- natType NATType
- quit chan chan bool
- ports chan string
-}
-
-type NATType int
-
-const (
- NONE = iota
- UPNP
- PMP
-)
-
-const (
- portMappingTimeout = 1200 // 20 mins
-)
-
-func NewTCPNetwork(natType NATType) (net *TCPNetwork) {
- return &TCPNetwork{
- natType: natType,
- ports: make(chan string),
- }
-}
-
-func (self *TCPNetwork) Dialer(addr net.Addr) (Dialer, error) {
- return &net.Dialer{
- Timeout: DialerTimeout * time.Second,
- // KeepAlive: KeepAlivePeriod * time.Minute,
- LocalAddr: addr,
- }, nil
-}
-
-func (self *TCPNetwork) Listener(addr net.Addr) (net.Listener, error) {
- if self.natType == UPNP {
- _, port, _ := net.SplitHostPort(addr.String())
- if self.quit == nil {
- self.quit = make(chan chan bool)
- go self.updatePortMappings()
- }
- self.ports <- port
- }
- return net.Listen(addr.Network(), addr.String())
-}
-
-func (self *TCPNetwork) Start() (err error) {
- switch self.natType {
- case NONE:
- case UPNP:
- nat, uerr := upnpDiscover(upnpDiscoverAttempts)
- if uerr != nil {
- err = fmt.Errorf("UPNP failed: ", uerr)
- } else {
- self.nat = nat
- }
- case PMP:
- err = fmt.Errorf("PMP not implemented")
- default:
- err = fmt.Errorf("Invalid NAT type: %v", self.natType)
- }
- return
-}
-
-func (self *TCPNetwork) Stop() {
- q := make(chan bool)
- self.quit <- q
- <-q
-}
-
-func (self *TCPNetwork) addPortMapping(lport int) (err error) {
- _, err = self.nat.AddPortMapping("TCP", lport, lport, "p2p listen port", portMappingTimeout)
- if err != nil {
- logger.Errorf("unable to add port mapping on %v: %v", lport, err)
- } else {
- logger.Debugf("succesfully added port mapping on %v", lport)
- }
- return
-}
-
-func (self *TCPNetwork) updatePortMappings() {
- timer := time.NewTimer(portMappingUpdateInterval * time.Second)
- lports := []int{}
-out:
- for {
- select {
- case port := <-self.ports:
- int64lport, _ := strconv.ParseInt(port, 10, 16)
- lport := int(int64lport)
- if err := self.addPortMapping(lport); err != nil {
- lports = append(lports, lport)
- }
- case <-timer.C:
- for lport := range lports {
- if err := self.addPortMapping(lport); err != nil {
- }
- }
- case errc := <-self.quit:
- errc <- true
- break out
- }
- }
-
- timer.Stop()
- for lport := range lports {
- if err := self.nat.DeletePortMapping("TCP", lport, lport); err != nil {
- logger.Debugf("unable to remove port mapping on %v: %v", lport, err)
- } else {
- logger.Debugf("succesfully removed port mapping on %v", lport)
- }
- }
-}
-
-func (self *TCPNetwork) NewAddr(host string, port int) (net.Addr, error) {
- ip, err := self.lookupIP(host)
- if err == nil {
- return &net.TCPAddr{
- IP: ip,
- Port: port,
- }, nil
- }
- return nil, err
-}
-
-func (self *TCPNetwork) ParseAddr(address string) (net.Addr, error) {
- host, port, err := net.SplitHostPort(address)
- if err == nil {
- iport, _ := strconv.Atoi(port)
- addr, e := self.NewAddr(host, iport)
- return addr, e
- }
- return nil, err
-}
-
-func (*TCPNetwork) lookupIP(host string) (ip net.IP, err error) {
- if ip = net.ParseIP(host); ip != nil {
- return
- }
-
- var ips []net.IP
- ips, err = net.LookupIP(host)
- if err != nil {
- logger.Warnln(err)
- return
- }
- if len(ips) == 0 {
- err = fmt.Errorf("No IP addresses available for %v", host)
- logger.Warnln(err)
- return
- }
- if len(ips) > 1 {
- // Pick a random IP address, simulating round-robin DNS.
- rand.Seed(time.Now().UTC().UnixNano())
- ip = ips[rand.Intn(len(ips))]
- } else {
- ip = ips[0]
- }
- return
-}
diff --git a/p2p/peer.go b/p2p/peer.go
index 34b6152a3..238d3d9c9 100644
--- a/p2p/peer.go
+++ b/p2p/peer.go
@@ -1,66 +1,454 @@
package p2p
import (
+ "bufio"
+ "bytes"
"fmt"
+ "io"
+ "io/ioutil"
"net"
- "strconv"
+ "sort"
+ "sync"
+ "time"
+
+ "github.com/ethereum/go-ethereum/event"
+ "github.com/ethereum/go-ethereum/logger"
)
+// peerAddr is the structure of a peer list element.
+// It is also a valid net.Addr.
+type peerAddr struct {
+ IP net.IP
+ Port uint64
+ Pubkey []byte // optional
+}
+
+func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr {
+ n := addr.Network()
+ if n != "tcp" && n != "tcp4" && n != "tcp6" {
+ // for testing with non-TCP
+ return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey}
+ }
+ ta := addr.(*net.TCPAddr)
+ return &peerAddr{ta.IP, uint64(ta.Port), pubkey}
+}
+
+func (d peerAddr) Network() string {
+ if d.IP.To4() != nil {
+ return "tcp4"
+ } else {
+ return "tcp6"
+ }
+}
+
+func (d peerAddr) String() string {
+ return fmt.Sprintf("%v:%d", d.IP, d.Port)
+}
+
+func (d peerAddr) RlpData() interface{} {
+ return []interface{}{d.IP, d.Port, d.Pubkey}
+}
+
+// Peer represents a remote peer.
type Peer struct {
- Inbound bool // inbound (via listener) or outbound (via dialout)
- Address net.Addr
- Host []byte
- Port uint16
- Pubkey []byte
- Id string
- Caps []string
- peerErrorChan chan error
- messenger *messenger
- peerErrorHandler *PeerErrorHandler
- server *Server
-}
-
-func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Peer {
- peerErrorChan := NewPeerErrorChannel()
- host, port, _ := net.SplitHostPort(address.String())
- intport, _ := strconv.Atoi(port)
- peer := &Peer{
- Inbound: inbound,
- Address: address,
- Port: uint16(intport),
- Host: net.ParseIP(host),
- peerErrorChan: peerErrorChan,
- server: server,
- }
- peer.messenger = newMessenger(peer, conn, peerErrorChan, server.Handlers())
- peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan)
+ // Peers have all the log methods.
+ // Use them to display messages related to the peer.
+ *logger.Logger
+
+ infolock sync.Mutex
+ identity ClientIdentity
+ caps []Cap
+ listenAddr *peerAddr // what remote peer is listening on
+ dialAddr *peerAddr // non-nil if dialing
+
+ // The mutex protects the connection
+ // so only one protocol can write at a time.
+ writeMu sync.Mutex
+ conn net.Conn
+ bufconn *bufio.ReadWriter
+
+ // These fields maintain the running protocols.
+ protocols []Protocol
+ runBaseProtocol bool // for testing
+
+ runlock sync.RWMutex // protects running
+ running map[string]*proto
+
+ protoWG sync.WaitGroup
+ protoErr chan error
+ closed chan struct{}
+ disc chan DiscReason
+
+ activity event.TypeMux // for activity events
+
+ slot int // index into Server peer list
+
+ // These fields are kept so base protocol can access them.
+ // TODO: this should be one or more interfaces
+ ourID ClientIdentity // client id of the Server
+ ourListenAddr *peerAddr // listen addr of Server, nil if not listening
+ newPeerAddr chan<- *peerAddr // tell server about received peers
+ otherPeers func() []*Peer // should return the list of all peers
+ pubkeyHook func(*peerAddr) error // called at end of handshake to validate pubkey
+}
+
+// NewPeer returns a peer for testing purposes.
+func NewPeer(id ClientIdentity, caps []Cap) *Peer {
+ conn, _ := net.Pipe()
+ peer := newPeer(conn, nil, nil)
+ peer.setHandshakeInfo(id, nil, caps)
return peer
}
-func (self *Peer) String() string {
- var kind string
- if self.Inbound {
- kind = "inbound"
- } else {
+func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
+ p := newPeer(conn, server.Protocols, dialAddr)
+ p.ourID = server.Identity
+ p.newPeerAddr = server.peerConnect
+ p.otherPeers = server.Peers
+ p.pubkeyHook = server.verifyPeer
+ p.runBaseProtocol = true
+
+ // laddr can be updated concurrently by NAT traversal.
+ // newServerPeer must be called with the server lock held.
+ if server.laddr != nil {
+ p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey())
+ }
+ return p
+}
+
+func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer {
+ p := &Peer{
+ Logger: logger.NewLogger("P2P " + conn.RemoteAddr().String()),
+ conn: conn,
+ dialAddr: dialAddr,
+ bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
+ protocols: protocols,
+ running: make(map[string]*proto),
+ disc: make(chan DiscReason),
+ protoErr: make(chan error),
+ closed: make(chan struct{}),
+ }
+ return p
+}
+
+// Identity returns the client identity of the remote peer. The
+// identity can be nil if the peer has not yet completed the
+// handshake.
+func (p *Peer) Identity() ClientIdentity {
+ p.infolock.Lock()
+ defer p.infolock.Unlock()
+ return p.identity
+}
+
+// Caps returns the capabilities (supported subprotocols) of the remote peer.
+func (p *Peer) Caps() []Cap {
+ p.infolock.Lock()
+ defer p.infolock.Unlock()
+ return p.caps
+}
+
+func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) {
+ p.infolock.Lock()
+ p.identity = id
+ p.listenAddr = laddr
+ p.caps = caps
+ p.infolock.Unlock()
+}
+
+// RemoteAddr returns the remote address of the network connection.
+func (p *Peer) RemoteAddr() net.Addr {
+ return p.conn.RemoteAddr()
+}
+
+// LocalAddr returns the local address of the network connection.
+func (p *Peer) LocalAddr() net.Addr {
+ return p.conn.LocalAddr()
+}
+
+// Disconnect terminates the peer connection with the given reason.
+// It returns immediately and does not wait until the connection is closed.
+func (p *Peer) Disconnect(reason DiscReason) {
+ select {
+ case p.disc <- reason:
+ case <-p.closed:
+ }
+}
+
+// String implements fmt.Stringer.
+func (p *Peer) String() string {
+ kind := "inbound"
+ p.infolock.Lock()
+ if p.dialAddr != nil {
kind = "outbound"
}
- return fmt.Sprintf("%v:%v (%s) v%v %v", self.Host, self.Port, kind, self.Id, self.Caps)
+ p.infolock.Unlock()
+ return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind)
+}
+
+const (
+ // maximum amount of time allowed for reading a message
+ msgReadTimeout = 5 * time.Second
+ // maximum amount of time allowed for writing a message
+ msgWriteTimeout = 5 * time.Second
+ // messages smaller than this many bytes will be read at
+ // once before passing them to a protocol.
+ wholePayloadSize = 64 * 1024
+)
+
+var (
+ inactivityTimeout = 2 * time.Second
+ disconnectGracePeriod = 2 * time.Second
+)
+
+func (p *Peer) loop() (reason DiscReason, err error) {
+ defer p.activity.Stop()
+ defer p.closeProtocols()
+ defer close(p.closed)
+ defer p.conn.Close()
+
+ // read loop
+ readMsg := make(chan Msg)
+ readErr := make(chan error)
+ readNext := make(chan bool, 1)
+ protoDone := make(chan struct{}, 1)
+ go p.readLoop(readMsg, readErr, readNext)
+ readNext <- true
+
+ if p.runBaseProtocol {
+ p.startBaseProtocol()
+ }
+
+loop:
+ for {
+ select {
+ case msg := <-readMsg:
+ // a new message has arrived.
+ var wait bool
+ if wait, err = p.dispatch(msg, protoDone); err != nil {
+ p.Errorf("msg dispatch error: %v\n", err)
+ reason = discReasonForError(err)
+ break loop
+ }
+ if !wait {
+ // Msg has already been read completely, continue with next message.
+ readNext <- true
+ }
+ p.activity.Post(time.Now())
+ case <-protoDone:
+ // protocol has consumed the message payload,
+ // we can continue reading from the socket.
+ readNext <- true
+
+ case err := <-readErr:
+ // read failed. there is no need to run the
+ // polite disconnect sequence because the connection
+ // is probably dead anyway.
+ // TODO: handle write errors as well
+ return DiscNetworkError, err
+ case err = <-p.protoErr:
+ reason = discReasonForError(err)
+ break loop
+ case reason = <-p.disc:
+ break loop
+ }
+ }
+
+ // wait for read loop to return.
+ close(readNext)
+ <-readErr
+ // tell the remote end to disconnect
+ done := make(chan struct{})
+ go func() {
+ p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod))
+ p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod)
+ io.Copy(ioutil.Discard, p.conn)
+ close(done)
+ }()
+ select {
+ case <-done:
+ case <-time.After(disconnectGracePeriod):
+ }
+ return reason, err
+}
+
+func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) {
+ for _ = range unblock {
+ p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout))
+ if msg, err := readMsg(p.bufconn); err != nil {
+ errc <- err
+ } else {
+ msgc <- msg
+ }
+ }
+ close(errc)
+}
+
+func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) {
+ proto, err := p.getProto(msg.Code)
+ if err != nil {
+ return false, err
+ }
+ if msg.Size <= wholePayloadSize {
+ // optimization: msg is small enough, read all
+ // of it and move on to the next message
+ buf, err := ioutil.ReadAll(msg.Payload)
+ if err != nil {
+ return false, err
+ }
+ msg.Payload = bytes.NewReader(buf)
+ proto.in <- msg
+ } else {
+ wait = true
+ pr := &eofSignal{msg.Payload, protoDone}
+ msg.Payload = pr
+ proto.in <- msg
+ }
+ return wait, nil
+}
+
+func (p *Peer) startBaseProtocol() {
+ p.runlock.Lock()
+ defer p.runlock.Unlock()
+ p.running[""] = p.startProto(0, Protocol{
+ Length: baseProtocolLength,
+ Run: runBaseProtocol,
+ })
+}
+
+// startProtocols starts matching named subprotocols.
+func (p *Peer) startSubprotocols(caps []Cap) {
+ sort.Sort(capsByName(caps))
+
+ p.runlock.Lock()
+ defer p.runlock.Unlock()
+ offset := baseProtocolLength
+outer:
+ for _, cap := range caps {
+ for _, proto := range p.protocols {
+ if proto.Name == cap.Name &&
+ proto.Version == cap.Version &&
+ p.running[cap.Name] == nil {
+ p.running[cap.Name] = p.startProto(offset, proto)
+ offset += proto.Length
+ continue outer
+ }
+ }
+ }
+}
+
+func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
+ rw := &proto{
+ in: make(chan Msg),
+ offset: offset,
+ maxcode: impl.Length,
+ peer: p,
+ }
+ p.protoWG.Add(1)
+ go func() {
+ err := impl.Run(p, rw)
+ if err == nil {
+ p.Infof("protocol %q returned", impl.Name)
+ err = newPeerError(errMisc, "protocol returned")
+ } else {
+ p.Errorf("protocol %q error: %v\n", impl.Name, err)
+ }
+ select {
+ case p.protoErr <- err:
+ case <-p.closed:
+ }
+ p.protoWG.Done()
+ }()
+ return rw
+}
+
+// getProto finds the protocol responsible for handling
+// the given message code.
+func (p *Peer) getProto(code uint64) (*proto, error) {
+ p.runlock.RLock()
+ defer p.runlock.RUnlock()
+ for _, proto := range p.running {
+ if code >= proto.offset && code < proto.offset+proto.maxcode {
+ return proto, nil
+ }
+ }
+ return nil, newPeerError(errInvalidMsgCode, "%d", code)
+}
+
+func (p *Peer) closeProtocols() {
+ p.runlock.RLock()
+ for _, p := range p.running {
+ close(p.in)
+ }
+ p.runlock.RUnlock()
+ p.protoWG.Wait()
+}
+
+// writeProtoMsg sends the given message on behalf of the given named protocol.
+func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
+ p.runlock.RLock()
+ proto, ok := p.running[protoName]
+ p.runlock.RUnlock()
+ if !ok {
+ return fmt.Errorf("protocol %s not handled by peer", protoName)
+ }
+ if msg.Code >= proto.maxcode {
+ return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
+ }
+ msg.Code += proto.offset
+ return p.writeMsg(msg, msgWriteTimeout)
+}
+
+// writeMsg writes a message to the connection.
+func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error {
+ p.writeMu.Lock()
+ defer p.writeMu.Unlock()
+ p.conn.SetWriteDeadline(time.Now().Add(timeout))
+ if err := writeMsg(p.bufconn, msg); err != nil {
+ return newPeerError(errWrite, "%v", err)
+ }
+ return p.bufconn.Flush()
+}
+
+type proto struct {
+ name string
+ in chan Msg
+ maxcode, offset uint64
+ peer *Peer
+}
+
+func (rw *proto) WriteMsg(msg Msg) error {
+ if msg.Code >= rw.maxcode {
+ return newPeerError(errInvalidMsgCode, "not handled")
+ }
+ msg.Code += rw.offset
+ return rw.peer.writeMsg(msg, msgWriteTimeout)
}
-func (self *Peer) Write(protocol string, msg Msg) error {
- return self.messenger.writeProtoMsg(protocol, msg)
+func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error {
+ return rw.WriteMsg(NewMsg(code, data))
}
-func (self *Peer) Start() {
- self.peerErrorHandler.Start()
- self.messenger.Start()
+func (rw *proto) ReadMsg() (Msg, error) {
+ msg, ok := <-rw.in
+ if !ok {
+ return msg, io.EOF
+ }
+ msg.Code -= rw.offset
+ return msg, nil
}
-func (self *Peer) Stop() {
- self.peerErrorHandler.Stop()
- self.messenger.Stop()
+// eofSignal wraps a reader with eof signaling.
+// the eof channel is closed when the wrapped reader
+// reaches EOF.
+type eofSignal struct {
+ wrapped io.Reader
+ eof chan<- struct{}
}
-func (p *Peer) Encode() []interface{} {
- return []interface{}{p.Host, p.Port, p.Pubkey}
+func (r *eofSignal) Read(buf []byte) (int, error) {
+ n, err := r.wrapped.Read(buf)
+ if err != nil {
+ r.eof <- struct{}{} // tell Peer that msg has been consumed
+ }
+ return n, err
}
diff --git a/p2p/peer_error.go b/p2p/peer_error.go
index f3ef98d98..88b870fbd 100644
--- a/p2p/peer_error.go
+++ b/p2p/peer_error.go
@@ -4,71 +4,121 @@ import (
"fmt"
)
-type ErrorCode int
-
-const errorChanCapacity = 10
-
const (
- PacketTooLong = iota
- PayloadTooShort
- MagicTokenMismatch
- ReadError
- WriteError
- MiscError
- InvalidMsgCode
- InvalidMsg
- P2PVersionMismatch
- PubkeyMissing
- PubkeyInvalid
- PubkeyForbidden
- ProtocolBreach
- PortMismatch
- PingTimeout
- InvalidGenesis
- InvalidNetworkId
- InvalidProtocolVersion
+ errMagicTokenMismatch = iota
+ errRead
+ errWrite
+ errMisc
+ errInvalidMsgCode
+ errInvalidMsg
+ errP2PVersionMismatch
+ errPubkeyMissing
+ errPubkeyInvalid
+ errPubkeyForbidden
+ errProtocolBreach
+ errPingTimeout
+ errInvalidNetworkId
+ errInvalidProtocolVersion
)
-var errorToString = map[ErrorCode]string{
- PacketTooLong: "Packet too long",
- PayloadTooShort: "Payload too short",
- MagicTokenMismatch: "Magic token mismatch",
- ReadError: "Read error",
- WriteError: "Write error",
- MiscError: "Misc error",
- InvalidMsgCode: "Invalid message code",
- InvalidMsg: "Invalid message",
- P2PVersionMismatch: "P2P Version Mismatch",
- PubkeyMissing: "Public key missing",
- PubkeyInvalid: "Public key invalid",
- PubkeyForbidden: "Public key forbidden",
- ProtocolBreach: "Protocol Breach",
- PortMismatch: "Port mismatch",
- PingTimeout: "Ping timeout",
- InvalidGenesis: "Invalid genesis block",
- InvalidNetworkId: "Invalid network id",
- InvalidProtocolVersion: "Invalid protocol version",
+var errorToString = map[int]string{
+ errMagicTokenMismatch: "Magic token mismatch",
+ errRead: "Read error",
+ errWrite: "Write error",
+ errMisc: "Misc error",
+ errInvalidMsgCode: "Invalid message code",
+ errInvalidMsg: "Invalid message",
+ errP2PVersionMismatch: "P2P Version Mismatch",
+ errPubkeyMissing: "Public key missing",
+ errPubkeyInvalid: "Public key invalid",
+ errPubkeyForbidden: "Public key forbidden",
+ errProtocolBreach: "Protocol Breach",
+ errPingTimeout: "Ping timeout",
+ errInvalidNetworkId: "Invalid network id",
+ errInvalidProtocolVersion: "Invalid protocol version",
}
-type PeerError struct {
- Code ErrorCode
+type peerError struct {
+ Code int
message string
}
-func NewPeerError(code ErrorCode, format string, v ...interface{}) *PeerError {
+func newPeerError(code int, format string, v ...interface{}) *peerError {
desc, ok := errorToString[code]
if !ok {
panic("invalid error code")
}
- format = desc + ": " + format
- message := fmt.Sprintf(format, v...)
- return &PeerError{code, message}
+ err := &peerError{code, desc}
+ if format != "" {
+ err.message += ": " + fmt.Sprintf(format, v...)
+ }
+ return err
}
-func (self *PeerError) Error() string {
+func (self *peerError) Error() string {
return self.message
}
-func NewPeerErrorChannel() chan error {
- return make(chan error, errorChanCapacity)
+type DiscReason byte
+
+const (
+ DiscRequested DiscReason = 0x00
+ DiscNetworkError = 0x01
+ DiscProtocolError = 0x02
+ DiscUselessPeer = 0x03
+ DiscTooManyPeers = 0x04
+ DiscAlreadyConnected = 0x05
+ DiscIncompatibleVersion = 0x06
+ DiscInvalidIdentity = 0x07
+ DiscQuitting = 0x08
+ DiscUnexpectedIdentity = 0x09
+ DiscSelf = 0x0a
+ DiscReadTimeout = 0x0b
+ DiscSubprotocolError = 0x10
+)
+
+var discReasonToString = [DiscSubprotocolError + 1]string{
+ DiscRequested: "Disconnect requested",
+ DiscNetworkError: "Network error",
+ DiscProtocolError: "Breach of protocol",
+ DiscUselessPeer: "Useless peer",
+ DiscTooManyPeers: "Too many peers",
+ DiscAlreadyConnected: "Already connected",
+ DiscIncompatibleVersion: "Incompatible P2P protocol version",
+ DiscInvalidIdentity: "Invalid node identity",
+ DiscQuitting: "Client quitting",
+ DiscUnexpectedIdentity: "Unexpected identity",
+ DiscSelf: "Connected to self",
+ DiscReadTimeout: "Read timeout",
+ DiscSubprotocolError: "Subprotocol error",
+}
+
+func (d DiscReason) String() string {
+ if len(discReasonToString) < int(d) {
+ return fmt.Sprintf("Unknown Reason(%d)", d)
+ }
+ return discReasonToString[d]
+}
+
+func discReasonForError(err error) DiscReason {
+ peerError, ok := err.(*peerError)
+ if !ok {
+ return DiscSubprotocolError
+ }
+ switch peerError.Code {
+ case errP2PVersionMismatch:
+ return DiscIncompatibleVersion
+ case errPubkeyMissing, errPubkeyInvalid:
+ return DiscInvalidIdentity
+ case errPubkeyForbidden:
+ return DiscUselessPeer
+ case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach:
+ return DiscProtocolError
+ case errPingTimeout:
+ return DiscReadTimeout
+ case errRead, errWrite, errMisc:
+ return DiscNetworkError
+ default:
+ return DiscSubprotocolError
+ }
}
diff --git a/p2p/peer_error_handler.go b/p2p/peer_error_handler.go
deleted file mode 100644
index 47dcd14ff..000000000
--- a/p2p/peer_error_handler.go
+++ /dev/null
@@ -1,98 +0,0 @@
-package p2p
-
-import (
- "net"
-)
-
-const (
- severityThreshold = 10
-)
-
-type DisconnectRequest struct {
- addr net.Addr
- reason DiscReason
-}
-
-type PeerErrorHandler struct {
- quit chan chan bool
- address net.Addr
- peerDisconnect chan DisconnectRequest
- severity int
- errc chan error
-}
-
-func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, errc chan error) *PeerErrorHandler {
- return &PeerErrorHandler{
- quit: make(chan chan bool),
- address: address,
- peerDisconnect: peerDisconnect,
- errc: errc,
- }
-}
-
-func (self *PeerErrorHandler) Start() {
- go self.listen()
-}
-
-func (self *PeerErrorHandler) Stop() {
- q := make(chan bool)
- self.quit <- q
- <-q
-}
-
-func (self *PeerErrorHandler) listen() {
- for {
- select {
- case err, ok := <-self.errc:
- if ok {
- logger.Debugf("error %v\n", err)
- go self.handle(err)
- } else {
- return
- }
- case q := <-self.quit:
- q <- true
- return
- }
- }
-}
-
-func (self *PeerErrorHandler) handle(err error) {
- reason := DiscReason(' ')
- peerError, ok := err.(*PeerError)
- if !ok {
- peerError = NewPeerError(MiscError, " %v", err)
- }
- switch peerError.Code {
- case P2PVersionMismatch:
- reason = DiscIncompatibleVersion
- case PubkeyMissing, PubkeyInvalid:
- reason = DiscInvalidIdentity
- case PubkeyForbidden:
- reason = DiscUselessPeer
- case InvalidMsgCode, PacketTooLong, PayloadTooShort, MagicTokenMismatch, ProtocolBreach:
- reason = DiscProtocolError
- case PingTimeout:
- reason = DiscReadTimeout
- case ReadError, WriteError, MiscError:
- reason = DiscNetworkError
- case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion:
- reason = DiscSubprotocolError
- default:
- self.severity += self.getSeverity(peerError)
- }
-
- if self.severity >= severityThreshold {
- reason = DiscSubprotocolError
- }
- if reason != DiscReason(' ') {
- self.peerDisconnect <- DisconnectRequest{
- addr: self.address,
- reason: reason,
- }
- }
-}
-
-func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int {
- return 1
-}
diff --git a/p2p/peer_error_handler_test.go b/p2p/peer_error_handler_test.go
deleted file mode 100644
index b93252f6a..000000000
--- a/p2p/peer_error_handler_test.go
+++ /dev/null
@@ -1,34 +0,0 @@
-package p2p
-
-import (
- // "fmt"
- "net"
- "testing"
- "time"
-)
-
-func TestPeerErrorHandler(t *testing.T) {
- address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303}
- peerDisconnect := make(chan DisconnectRequest)
- peerErrorChan := NewPeerErrorChannel()
- peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan)
- peh.Start()
- defer peh.Stop()
- for i := 0; i < 11; i++ {
- select {
- case <-peerDisconnect:
- t.Errorf("expected no disconnect request")
- default:
- }
- peerErrorChan <- NewPeerError(MiscError, "")
- }
- time.Sleep(1 * time.Millisecond)
- select {
- case request := <-peerDisconnect:
- if request.addr.String() != address.String() {
- t.Errorf("incorrect address %v != %v", request.addr, address)
- }
- default:
- t.Errorf("expected disconnect request")
- }
-}
diff --git a/p2p/peer_test.go b/p2p/peer_test.go
index da62cc380..1afa0ab17 100644
--- a/p2p/peer_test.go
+++ b/p2p/peer_test.go
@@ -1,90 +1,222 @@
package p2p
-// "net"
-
-// func TestPeer(t *testing.T) {
-// handlers := make(Handlers)
-// testProtocol := &TestProtocol{recv: make(chan testMsg)}
-// handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
-// handlers["ccc"] = func(p *Peer) Protocol { return testProtocol }
-// addr := &TestAddr{"test:30"}
-// conn := NewTestNetworkConnection(addr)
-// _, server := SetupTestServer(handlers)
-// server.Handshake()
-// peer := NewPeer(conn, addr, true, server)
-// // peer.Messenger().AddProtocols([]string{"aaa", "ccc"})
-// peer.Start()
-// defer peer.Stop()
-// time.Sleep(2 * time.Millisecond)
-// if len(conn.Out) != 1 {
-// t.Errorf("handshake not sent")
-// } else {
-// out := conn.Out[0]
-// packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:])
-// if bytes.Compare(out, packet) != 0 {
-// t.Errorf("incorrect handshake packet %v != %v", out, packet)
-// }
-// }
-
-// packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000"))
-// conn.In(0, packet)
-// time.Sleep(10 * time.Millisecond)
-
-// pro, _ := peer.Messenger().protocols[0].(*BaseProtocol)
-// if pro.state != handshakeReceived {
-// t.Errorf("handshake not received")
-// }
-// if peer.Port != 30 {
-// t.Errorf("port incorrectly set")
-// }
-// if peer.Id != "peer" {
-// t.Errorf("id incorrectly set")
-// }
-// if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" {
-// t.Errorf("pubkey incorrectly set")
-// }
-// fmt.Println(peer.Caps)
-// if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" {
-// t.Errorf("protocols incorrectly set")
-// }
-
-// msg := NewMsg(3)
-// err := peer.Write("aaa", msg)
-// if err != nil {
-// t.Errorf("expect no error for known protocol: %v", err)
-// } else {
-// time.Sleep(1 * time.Millisecond)
-// if len(conn.Out) != 2 {
-// t.Errorf("msg not written")
-// } else {
-// out := conn.Out[1]
-// packet := Packet(16, 3)
-// if bytes.Compare(out, packet) != 0 {
-// t.Errorf("incorrect packet %v != %v", out, packet)
-// }
-// }
-// }
-
-// msg = NewMsg(2)
-// err = peer.Write("ccc", msg)
-// if err != nil {
-// t.Errorf("expect no error for known protocol: %v", err)
-// } else {
-// time.Sleep(1 * time.Millisecond)
-// if len(conn.Out) != 3 {
-// t.Errorf("msg not written")
-// } else {
-// out := conn.Out[2]
-// packet := Packet(21, 2)
-// if bytes.Compare(out, packet) != 0 {
-// t.Errorf("incorrect packet %v != %v", out, packet)
-// }
-// }
-// }
-
-// err = peer.Write("bbb", msg)
-// time.Sleep(1 * time.Millisecond)
-// if err == nil {
-// t.Errorf("expect error for unknown protocol")
-// }
-// }
+import (
+ "bufio"
+ "net"
+ "reflect"
+ "testing"
+ "time"
+)
+
+var discard = Protocol{
+ Name: "discard",
+ Length: 1,
+ Run: func(p *Peer, rw MsgReadWriter) error {
+ for {
+ msg, err := rw.ReadMsg()
+ if err != nil {
+ return err
+ }
+ if err = msg.Discard(); err != nil {
+ return err
+ }
+ }
+ },
+}
+
+func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) {
+ conn1, conn2 := net.Pipe()
+ id := NewSimpleClientIdentity("test", "0", "0", "public key")
+ peer := newPeer(conn1, protos, nil)
+ peer.ourID = id
+ peer.pubkeyHook = func(*peerAddr) error { return nil }
+ errc := make(chan error, 1)
+ go func() {
+ _, err := peer.loop()
+ errc <- err
+ }()
+ return conn2, peer, errc
+}
+
+func TestPeerProtoReadMsg(t *testing.T) {
+ defer testlog(t).detach()
+
+ done := make(chan struct{})
+ proto := Protocol{
+ Name: "a",
+ Length: 5,
+ Run: func(peer *Peer, rw MsgReadWriter) error {
+ msg, err := rw.ReadMsg()
+ if err != nil {
+ t.Errorf("read error: %v", err)
+ }
+ if msg.Code != 2 {
+ t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
+ }
+ data, err := msg.Data()
+ if err != nil {
+ t.Errorf("data decoding error: %v", err)
+ }
+ expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}}
+ if !reflect.DeepEqual(data.Slice(), expdata) {
+ t.Errorf("incorrect msg data %#v", data.Slice())
+ }
+ close(done)
+ return nil
+ },
+ }
+
+ net, peer, errc := testPeer([]Protocol{proto})
+ defer net.Close()
+ peer.startSubprotocols([]Cap{proto.cap()})
+
+ writeMsg(net, NewMsg(18, 1, "000"))
+ select {
+ case <-done:
+ case err := <-errc:
+ t.Errorf("peer returned: %v", err)
+ case <-time.After(2 * time.Second):
+ t.Errorf("receive timeout")
+ }
+}
+
+func TestPeerProtoReadLargeMsg(t *testing.T) {
+ defer testlog(t).detach()
+
+ msgsize := uint32(10 * 1024 * 1024)
+ done := make(chan struct{})
+ proto := Protocol{
+ Name: "a",
+ Length: 5,
+ Run: func(peer *Peer, rw MsgReadWriter) error {
+ msg, err := rw.ReadMsg()
+ if err != nil {
+ t.Errorf("read error: %v", err)
+ }
+ if msg.Size != msgsize+4 {
+ t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize)
+ }
+ msg.Discard()
+ close(done)
+ return nil
+ },
+ }
+
+ net, peer, errc := testPeer([]Protocol{proto})
+ defer net.Close()
+ peer.startSubprotocols([]Cap{proto.cap()})
+
+ writeMsg(net, NewMsg(18, make([]byte, msgsize)))
+ select {
+ case <-done:
+ case err := <-errc:
+ t.Errorf("peer returned: %v", err)
+ case <-time.After(2 * time.Second):
+ t.Errorf("receive timeout")
+ }
+}
+
+func TestPeerProtoEncodeMsg(t *testing.T) {
+ defer testlog(t).detach()
+
+ proto := Protocol{
+ Name: "a",
+ Length: 2,
+ Run: func(peer *Peer, rw MsgReadWriter) error {
+ if err := rw.EncodeMsg(2); err == nil {
+ t.Error("expected error for out-of-range msg code, got nil")
+ }
+ if err := rw.EncodeMsg(1); err != nil {
+ t.Errorf("write error: %v", err)
+ }
+ return nil
+ },
+ }
+ net, peer, _ := testPeer([]Protocol{proto})
+ defer net.Close()
+ peer.startSubprotocols([]Cap{proto.cap()})
+
+ bufr := bufio.NewReader(net)
+ msg, err := readMsg(bufr)
+ if err != nil {
+ t.Errorf("read error: %v", err)
+ }
+ if msg.Code != 17 {
+ t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
+ }
+}
+
+func TestPeerWrite(t *testing.T) {
+ defer testlog(t).detach()
+
+ net, peer, peerErr := testPeer([]Protocol{discard})
+ defer net.Close()
+ peer.startSubprotocols([]Cap{discard.cap()})
+
+ // test write errors
+ if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
+ t.Errorf("expected error for unknown protocol, got nil")
+ }
+ if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil {
+ t.Errorf("expected error for out-of-range msg code, got nil")
+ } else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
+ t.Errorf("wrong error for out-of-range msg code, got %#v", err)
+ }
+
+ // setup for reading the message on the other end
+ read := make(chan struct{})
+ go func() {
+ bufr := bufio.NewReader(net)
+ msg, err := readMsg(bufr)
+ if err != nil {
+ t.Errorf("read error: %v", err)
+ } else if msg.Code != 16 {
+ t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
+ }
+ msg.Discard()
+ close(read)
+ }()
+
+ // test succcessful write
+ if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
+ t.Errorf("expect no error for known protocol: %v", err)
+ }
+ select {
+ case <-read:
+ case err := <-peerErr:
+ t.Fatalf("peer stopped: %v", err)
+ }
+}
+
+func TestPeerActivity(t *testing.T) {
+ // shorten inactivityTimeout while this test is running
+ oldT := inactivityTimeout
+ defer func() { inactivityTimeout = oldT }()
+ inactivityTimeout = 20 * time.Millisecond
+
+ net, peer, peerErr := testPeer([]Protocol{discard})
+ defer net.Close()
+ peer.startSubprotocols([]Cap{discard.cap()})
+
+ sub := peer.activity.Subscribe(time.Time{})
+ defer sub.Unsubscribe()
+
+ for i := 0; i < 6; i++ {
+ writeMsg(net, NewMsg(16))
+ select {
+ case <-sub.Chan():
+ case <-time.After(inactivityTimeout / 2):
+ t.Fatal("no event within ", inactivityTimeout/2)
+ case err := <-peerErr:
+ t.Fatal("peer error", err)
+ }
+ }
+
+ select {
+ case <-time.After(inactivityTimeout * 2):
+ case <-sub.Chan():
+ t.Fatal("got activity event while connection was inactive")
+ case err := <-peerErr:
+ t.Fatal("peer error", err)
+ }
+}
diff --git a/p2p/protocol.go b/p2p/protocol.go
index d22ba70cb..169dcdb6e 100644
--- a/p2p/protocol.go
+++ b/p2p/protocol.go
@@ -3,249 +3,185 @@ package p2p
import (
"bytes"
"net"
- "sort"
"time"
"github.com/ethereum/go-ethereum/ethutil"
)
-// Protocol is implemented by P2P subprotocols.
-type Protocol interface {
- // Start is called when the protocol becomes active.
- // It should read and write messages from rw.
- // Messages must be fully consumed.
- //
- // The connection is closed when Start returns. It should return
- // any protocol-level error (such as an I/O error) that is
- // encountered.
- Start(peer *Peer, rw MsgReadWriter) error
+// Protocol represents a P2P subprotocol implementation.
+type Protocol struct {
+ // Name should contain the official protocol name,
+ // often a three-letter word.
+ Name string
- // Offset should return the number of message codes
- // used by the protocol.
- Offset() MsgCode
-}
+ // Version should contain the version number of the protocol.
+ Version uint
-type MsgReader interface {
- ReadMsg() (Msg, error)
-}
-
-type MsgWriter interface {
- WriteMsg(Msg) error
-}
-
-// MsgReadWriter is passed to protocols. Protocol implementations can
-// use it to write messages back to a connected peer.
-type MsgReadWriter interface {
- MsgReader
- MsgWriter
-}
+ // Length should contain the number of message codes used
+ // by the protocol.
+ Length uint64
-type MsgHandler func(code MsgCode, data *ethutil.Value) error
-
-// MsgLoop reads messages off the given reader and
-// calls the handler function for each decoded message until
-// it returns an error or the peer connection is closed.
-//
-// If a message is larger than the given maximum size, RunProtocol
-// returns an appropriate error.n
-func MsgLoop(r MsgReader, maxsize uint32, handler MsgHandler) error {
- for {
- msg, err := r.ReadMsg()
- if err != nil {
- return err
- }
- if msg.Size > maxsize {
- return NewPeerError(InvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize)
- }
- value, err := msg.Data()
- if err != nil {
- return err
- }
- if err := handler(msg.Code, value); err != nil {
- return err
- }
- }
-}
-
-// the ÐΞVp2p base protocol
-type baseProtocol struct {
- rw MsgReadWriter
- peer *Peer
+ // Run is called in a new groutine when the protocol has been
+ // negotiated with a peer. It should read and write messages from
+ // rw. The Payload for each message must be fully consumed.
+ //
+ // The peer connection is closed when Start returns. It should return
+ // any protocol-level error (such as an I/O error) that is
+ // encountered.
+ Run func(peer *Peer, rw MsgReadWriter) error
}
-type bpMsg struct {
- code MsgCode
- data *ethutil.Value
+func (p Protocol) cap() Cap {
+ return Cap{p.Name, p.Version}
}
const (
- p2pVersion = 0
- pingTimeout = 2 * time.Second
- pingGracePeriod = 2 * time.Second
+ baseProtocolVersion = 2
+ baseProtocolLength = uint64(16)
+ baseProtocolMaxMsgSize = 10 * 1024 * 1024
)
const (
- // message codes
- handshakeMsg = iota
- discMsg
- pingMsg
- pongMsg
- getPeersMsg
- peersMsg
+ // devp2p message codes
+ handshakeMsg = 0x00
+ discMsg = 0x01
+ pingMsg = 0x02
+ pongMsg = 0x03
+ getPeersMsg = 0x04
+ peersMsg = 0x05
)
-const (
- baseProtocolOffset MsgCode = 16
- baseProtocolMaxMsgSize = 500 * 1024
-)
-
-type DiscReason byte
+// handshake is the structure of a handshake list.
+type handshake struct {
+ Version uint64
+ ID string
+ Caps []Cap
+ ListenPort uint64
+ NodeID []byte
+}
-const (
- // Values are given explicitly instead of by iota because these values are
- // defined by the wire protocol spec; it is easier for humans to ensure
- // correctness when values are explicit.
- DiscRequested = 0x00
- DiscNetworkError = 0x01
- DiscProtocolError = 0x02
- DiscUselessPeer = 0x03
- DiscTooManyPeers = 0x04
- DiscAlreadyConnected = 0x05
- DiscIncompatibleVersion = 0x06
- DiscInvalidIdentity = 0x07
- DiscQuitting = 0x08
- DiscUnexpectedIdentity = 0x09
- DiscSelf = 0x0a
- DiscReadTimeout = 0x0b
- DiscSubprotocolError = 0x10
-)
+func (h *handshake) String() string {
+ return h.ID
+}
+func (h *handshake) Pubkey() []byte {
+ return h.NodeID
+}
-var discReasonToString = [DiscSubprotocolError + 1]string{
- DiscRequested: "Disconnect requested",
- DiscNetworkError: "Network error",
- DiscProtocolError: "Breach of protocol",
- DiscUselessPeer: "Useless peer",
- DiscTooManyPeers: "Too many peers",
- DiscAlreadyConnected: "Already connected",
- DiscIncompatibleVersion: "Incompatible P2P protocol version",
- DiscInvalidIdentity: "Invalid node identity",
- DiscQuitting: "Client quitting",
- DiscUnexpectedIdentity: "Unexpected identity",
- DiscSelf: "Connected to self",
- DiscReadTimeout: "Read timeout",
- DiscSubprotocolError: "Subprotocol error",
+// Cap is the structure of a peer capability.
+type Cap struct {
+ Name string
+ Version uint
}
-func (d DiscReason) String() string {
- if len(discReasonToString) < int(d) {
- return "Unknown"
- }
- return discReasonToString[d]
+func (cap Cap) RlpData() interface{} {
+ return []interface{}{cap.Name, cap.Version}
}
-func (bp *baseProtocol) Offset() MsgCode {
- return baseProtocolOffset
+type capsByName []Cap
+
+func (cs capsByName) Len() int { return len(cs) }
+func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name }
+func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] }
+
+type baseProtocol struct {
+ rw MsgReadWriter
+ peer *Peer
}
-func (bp *baseProtocol) Start(peer *Peer, rw MsgReadWriter) error {
- bp.peer, bp.rw = peer, rw
+func runBaseProtocol(peer *Peer, rw MsgReadWriter) error {
+ bp := &baseProtocol{rw, peer}
- // Do the handshake.
- // TODO: disconnect is valid before handshake, too.
- rw.WriteMsg(bp.peer.server.handshakeMsg())
+ // do handshake
+ if err := rw.WriteMsg(bp.handshakeMsg()); err != nil {
+ return err
+ }
msg, err := rw.ReadMsg()
if err != nil {
return err
}
if msg.Code != handshakeMsg {
- return NewPeerError(ProtocolBreach, " first message must be handshake")
+ return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code)
}
data, err := msg.Data()
if err != nil {
- return NewPeerError(InvalidMsg, "%v", err)
+ return newPeerError(errInvalidMsg, "%v", err)
}
if err := bp.handleHandshake(data); err != nil {
return err
}
- msgin := make(chan bpMsg)
- done := make(chan error, 1)
+ // run main loop
+ quit := make(chan error, 1)
go func() {
- done <- MsgLoop(rw, baseProtocolMaxMsgSize,
- func(code MsgCode, data *ethutil.Value) error {
- msgin <- bpMsg{code, data}
- return nil
- })
+ quit <- MsgLoop(rw, baseProtocolMaxMsgSize, bp.handle)
}()
- return bp.loop(msgin, done)
+ return bp.loop(quit)
}
-func (bp *baseProtocol) loop(msgin <-chan bpMsg, quit <-chan error) error {
- logger.Debugf("pingpong keepalive started at %v\n", time.Now())
- messenger := bp.rw.(*proto).messenger
- pingTimer := time.NewTimer(pingTimeout)
- pinged := true
+var pingTimeout = 2 * time.Second
+
+func (bp *baseProtocol) loop(quit <-chan error) error {
+ ping := time.NewTimer(pingTimeout)
+ activity := bp.peer.activity.Subscribe(time.Time{})
+ lastActive := time.Time{}
+ defer ping.Stop()
+ defer activity.Unsubscribe()
- for {
+ getPeersTick := time.NewTicker(10 * time.Second)
+ defer getPeersTick.Stop()
+ err := bp.rw.EncodeMsg(getPeersMsg)
+
+ for err == nil {
select {
- case msg := <-msgin:
- if err := bp.handle(msg.code, msg.data); err != nil {
- return err
- }
- case err := <-quit:
+ case err = <-quit:
return err
- case <-messenger.pulse:
- pingTimer.Reset(pingTimeout)
- pinged = false
- case <-pingTimer.C:
- if pinged {
- return NewPeerError(PingTimeout, "")
+ case <-getPeersTick.C:
+ err = bp.rw.EncodeMsg(getPeersMsg)
+ case event := <-activity.Chan():
+ ping.Reset(pingTimeout)
+ lastActive = event.(time.Time)
+ case t := <-ping.C:
+ if lastActive.Add(pingTimeout * 2).Before(t) {
+ err = newPeerError(errPingTimeout, "")
+ } else if lastActive.Add(pingTimeout).Before(t) {
+ err = bp.rw.EncodeMsg(pingMsg)
}
- logger.Debugf("pinging at %v\n", time.Now())
- if err := bp.rw.WriteMsg(NewMsg(pingMsg)); err != nil {
- return NewPeerError(WriteError, "%v", err)
- }
- pinged = true
- pingTimer.Reset(pingTimeout)
}
}
+ return err
}
-func (bp *baseProtocol) handle(code MsgCode, data *ethutil.Value) error {
+func (bp *baseProtocol) handle(code uint64, data *ethutil.Value) error {
switch code {
case handshakeMsg:
- return NewPeerError(ProtocolBreach, " extra handshake received")
+ return newPeerError(errProtocolBreach, "extra handshake received")
case discMsg:
- logger.Infof("Disconnect requested from peer %v, reason", DiscReason(data.Get(0).Uint()))
- bp.peer.server.PeerDisconnect() <- DisconnectRequest{
- addr: bp.peer.Address,
- reason: DiscRequested,
- }
+ bp.peer.Disconnect(DiscReason(data.Get(0).Uint()))
+ return nil
case pingMsg:
- return bp.rw.WriteMsg(NewMsg(pongMsg))
+ return bp.rw.EncodeMsg(pongMsg)
case pongMsg:
- // reply for ping
case getPeersMsg:
- // Peer asked for list of connected peers.
- peersRLP := bp.peer.server.encodedPeerList()
- if peersRLP != nil {
- msg := Msg{
- Code: peersMsg,
- Size: uint32(len(peersRLP)),
- Payload: bytes.NewReader(peersRLP),
- }
- return bp.rw.WriteMsg(msg)
+ peers := bp.peerList()
+ // this is dangerous. the spec says that we should _delay_
+ // sending the response if no new information is available.
+ // this means that would need to send a response later when
+ // new peers become available.
+ //
+ // TODO: add event mechanism to notify baseProtocol for new peers
+ if len(peers) > 0 {
+ return bp.rw.EncodeMsg(peersMsg, peers)
}
case peersMsg:
bp.handlePeers(data)
default:
- return NewPeerError(InvalidMsgCode, "unknown message code %v", code)
+ return newPeerError(errInvalidMsgCode, "unknown message code %v", code)
}
return nil
}
@@ -253,62 +189,102 @@ func (bp *baseProtocol) handle(code MsgCode, data *ethutil.Value) error {
func (bp *baseProtocol) handlePeers(data *ethutil.Value) {
it := data.NewIterator()
for it.Next() {
- ip := net.IP(it.Value().Get(0).Bytes())
- port := it.Value().Get(1).Uint()
- address := &net.TCPAddr{IP: ip, Port: int(port)}
- go bp.peer.server.PeerConnect(address)
+ addr := &peerAddr{
+ IP: net.IP(it.Value().Get(0).Bytes()),
+ Port: it.Value().Get(1).Uint(),
+ Pubkey: it.Value().Get(2).Bytes(),
+ }
+ bp.peer.Debugf("received peer suggestion: %v", addr)
+ bp.peer.newPeerAddr <- addr
}
}
func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error {
- var (
- remoteVersion = c.Get(0).Uint()
- id = c.Get(1).Str()
- caps = c.Get(2)
- port = c.Get(3).Uint()
- pubkey = c.Get(4).Bytes()
- )
- // Check correctness of p2p protocol version
- if remoteVersion != p2pVersion {
- return NewPeerError(P2PVersionMismatch, "Require protocol %d, received %d\n", p2pVersion, remoteVersion)
+ hs := handshake{
+ Version: c.Get(0).Uint(),
+ ID: c.Get(1).Str(),
+ Caps: nil, // decoded below
+ ListenPort: c.Get(3).Uint(),
+ NodeID: c.Get(4).Bytes(),
}
-
- // Handle the pub key (validation, uniqueness)
- if len(pubkey) == 0 {
- return NewPeerError(PubkeyMissing, "not supplied in handshake.")
+ if hs.Version != baseProtocolVersion {
+ return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n",
+ baseProtocolVersion, hs.Version)
}
-
- if len(pubkey) != 64 {
- return NewPeerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8)
+ if len(hs.NodeID) == 0 {
+ return newPeerError(errPubkeyMissing, "")
+ }
+ if len(hs.NodeID) != 64 {
+ return newPeerError(errPubkeyInvalid, "require 512 bit, got %v", len(hs.NodeID)*8)
+ }
+ if da := bp.peer.dialAddr; da != nil {
+ // verify that the peer we wanted to connect to
+ // actually holds the target public key.
+ if da.Pubkey != nil && !bytes.Equal(da.Pubkey, hs.NodeID) {
+ return newPeerError(errPubkeyForbidden, "dial address pubkey mismatch")
+ }
+ }
+ pa := newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
+ if err := bp.peer.pubkeyHook(pa); err != nil {
+ return newPeerError(errPubkeyForbidden, "%v", err)
+ }
+ capsIt := c.Get(2).NewIterator()
+ for capsIt.Next() {
+ cap := capsIt.Value()
+ name := cap.Get(0).Str()
+ if name != "" {
+ hs.Caps = append(hs.Caps, Cap{Name: name, Version: uint(cap.Get(1).Uint())})
+ }
}
- // self connect detection
- if bytes.Compare(bp.peer.server.ClientIdentity().Pubkey()[1:], pubkey) == 0 {
- return NewPeerError(PubkeyForbidden, "not allowed to connect to self")
+ var addr *peerAddr
+ if hs.ListenPort != 0 {
+ addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
+ addr.Port = hs.ListenPort
}
+ bp.peer.setHandshakeInfo(&hs, addr, hs.Caps)
+ bp.peer.startSubprotocols(hs.Caps)
+ return nil
+}
- // register pubkey on server. this also sets the pubkey on the peer (need lock)
- if err := bp.peer.server.RegisterPubkey(bp.peer, pubkey); err != nil {
- return NewPeerError(PubkeyForbidden, err.Error())
+func (bp *baseProtocol) handshakeMsg() Msg {
+ var (
+ port uint64
+ caps []interface{}
+ )
+ if bp.peer.ourListenAddr != nil {
+ port = bp.peer.ourListenAddr.Port
}
+ for _, proto := range bp.peer.protocols {
+ caps = append(caps, proto.cap())
+ }
+ return NewMsg(handshakeMsg,
+ baseProtocolVersion,
+ bp.peer.ourID.String(),
+ caps,
+ port,
+ bp.peer.ourID.Pubkey()[1:],
+ )
+}
- // check port
- if bp.peer.Inbound {
- uint16port := uint16(port)
- if bp.peer.Port > 0 && bp.peer.Port != uint16port {
- return NewPeerError(PortMismatch, "port mismatch: %v != %v", bp.peer.Port, port)
- } else {
- bp.peer.Port = uint16port
+func (bp *baseProtocol) peerList() []ethutil.RlpEncodable {
+ peers := bp.peer.otherPeers()
+ ds := make([]ethutil.RlpEncodable, 0, len(peers))
+ for _, p := range peers {
+ p.infolock.Lock()
+ addr := p.listenAddr
+ p.infolock.Unlock()
+ // filter out this peer and peers that are not listening or
+ // have not completed the handshake.
+ // TODO: track previously sent peers and exclude them as well.
+ if p == bp.peer || addr == nil {
+ continue
}
+ ds = append(ds, addr)
}
-
- capsIt := caps.NewIterator()
- for capsIt.Next() {
- cap := capsIt.Value().Str()
- bp.peer.Caps = append(bp.peer.Caps, cap)
+ ourAddr := bp.peer.ourListenAddr
+ if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() {
+ ds = append(ds, ourAddr)
}
- sort.Strings(bp.peer.Caps)
- bp.rw.(*proto).messenger.setRemoteProtocols(bp.peer.Caps)
- bp.peer.Id = id
- return nil
+ return ds
}
diff --git a/p2p/server.go b/p2p/server.go
index 54d2cde30..8a6087566 100644
--- a/p2p/server.go
+++ b/p2p/server.go
@@ -2,155 +2,101 @@ package p2p
import (
"bytes"
+ "errors"
"fmt"
"net"
- "sort"
- "strconv"
"sync"
"time"
- logpkg "github.com/ethereum/go-ethereum/logger"
+ "github.com/ethereum/go-ethereum/logger"
)
const (
- outboundAddressPoolSize = 10
- disconnectGracePeriod = 2
+ outboundAddressPoolSize = 500
+ defaultDialTimeout = 10 * time.Second
+ portMappingUpdateInterval = 15 * time.Minute
+ portMappingTimeout = 20 * time.Minute
)
-type Blacklist interface {
- Get([]byte) (bool, error)
- Put([]byte) error
- Delete([]byte) error
- Exists(pubkey []byte) (ok bool)
-}
-
-type BlacklistMap struct {
- blacklist map[string]bool
- lock sync.RWMutex
-}
-
-func NewBlacklist() *BlacklistMap {
- return &BlacklistMap{
- blacklist: make(map[string]bool),
- }
-}
-
-func (self *BlacklistMap) Get(pubkey []byte) (bool, error) {
- self.lock.RLock()
- defer self.lock.RUnlock()
- v, ok := self.blacklist[string(pubkey)]
- var err error
- if !ok {
- err = fmt.Errorf("not found")
- }
- return v, err
-}
-
-func (self *BlacklistMap) Exists(pubkey []byte) (ok bool) {
- self.lock.RLock()
- defer self.lock.RUnlock()
- _, ok = self.blacklist[string(pubkey)]
- return
-}
-
-func (self *BlacklistMap) Put(pubkey []byte) error {
- self.lock.RLock()
- defer self.lock.RUnlock()
- self.blacklist[string(pubkey)] = true
- return nil
-}
-
-func (self *BlacklistMap) Delete(pubkey []byte) error {
- self.lock.RLock()
- defer self.lock.RUnlock()
- delete(self.blacklist, string(pubkey))
- return nil
-}
+var srvlog = logger.NewLogger("P2P Server")
+// Server manages all peer connections.
+//
+// The fields of Server are used as configuration parameters.
+// You should set them before starting the Server. Fields may not be
+// modified while the server is running.
type Server struct {
- network Network
- listening bool //needed?
- dialing bool //needed?
- closed bool
- identity ClientIdentity
- addr net.Addr
- port uint16
- protocols []string
-
- quit chan chan bool
- peersLock sync.RWMutex
-
- maxPeers int
- peers []*Peer
- peerSlots chan int
- peersTable map[string]int
- peerCount int
- cachedEncodedPeers []byte
-
- peerConnect chan net.Addr
- peerDisconnect chan DisconnectRequest
- blacklist Blacklist
- handlers Handlers
-}
-
-var logger = logpkg.NewLogger("P2P")
-
-func New(network Network, addr net.Addr, identity ClientIdentity, handlers Handlers, maxPeers int, blacklist Blacklist) *Server {
- // get alphabetical list of protocol names from handlers map
- protocols := []string{}
- for protocol := range handlers {
- protocols = append(protocols, protocol)
- }
- sort.Strings(protocols)
-
- _, port, _ := net.SplitHostPort(addr.String())
- intport, _ := strconv.Atoi(port)
-
- self := &Server{
- // NewSimpleClientIdentity(clientIdentifier, version, customIdentifier)
- network: network,
- identity: identity,
- addr: addr,
- port: uint16(intport),
- protocols: protocols,
-
- quit: make(chan chan bool),
-
- maxPeers: maxPeers,
- peers: make([]*Peer, maxPeers),
- peerSlots: make(chan int, maxPeers),
- peersTable: make(map[string]int),
+ // This field must be set to a valid client identity.
+ Identity ClientIdentity
+
+ // MaxPeers is the maximum number of peers that can be
+ // connected. It must be greater than zero.
+ MaxPeers int
+
+ // Protocols should contain the protocols supported
+ // by the server. Matching protocols are launched for
+ // each peer.
+ Protocols []Protocol
+
+ // If Blacklist is set to a non-nil value, the given Blacklist
+ // is used to verify peer connections.
+ Blacklist Blacklist
+
+ // If ListenAddr is set to a non-nil address, the server
+ // will listen for incoming connections.
+ //
+ // If the port is zero, the operating system will pick a port. The
+ // ListenAddr field will be updated with the actual address when
+ // the server is started.
+ ListenAddr string
+
+ // If set to a non-nil value, the given NAT port mapper
+ // is used to make the listening port available to the
+ // Internet.
+ NAT NAT
+
+ // If Dialer is set to a non-nil value, the given Dialer
+ // is used to dial outbound peer connections.
+ Dialer *net.Dialer
+
+ // If NoDial is true, the server will not dial any peers.
+ NoDial bool
+
+ // Hook for testing. This is useful because we can inhibit
+ // the whole protocol stack.
+ newPeerFunc peerFunc
- peerConnect: make(chan net.Addr, outboundAddressPoolSize),
- peerDisconnect: make(chan DisconnectRequest),
- blacklist: blacklist,
-
- handlers: handlers,
- }
- for i := 0; i < maxPeers; i++ {
- self.peerSlots <- i // fill up with indexes
- }
- return self
+ lock sync.RWMutex
+ running bool
+ listener net.Listener
+ laddr *net.TCPAddr // real listen addr
+ peers []*Peer
+ peerSlots chan int
+ peerCount int
+
+ quit chan struct{}
+ wg sync.WaitGroup
+ peerConnect chan *peerAddr
+ peerDisconnect chan *Peer
}
-func (self *Server) NewAddr(host string, port int) (addr net.Addr, err error) {
- addr, err = self.network.NewAddr(host, port)
- return
-}
+// NAT is implemented by NAT traversal methods.
+type NAT interface {
+ GetExternalAddress() (net.IP, error)
+ AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error
+ DeletePortMapping(protocol string, extport, intport int) error
-func (self *Server) ParseAddr(address string) (addr net.Addr, err error) {
- addr, err = self.network.ParseAddr(address)
- return
+ // Should return name of the method.
+ String() string
}
-func (self *Server) ClientIdentity() ClientIdentity {
- return self.identity
-}
+type peerFunc func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer
-func (self *Server) Peers() (peers []*Peer) {
- self.peersLock.RLock()
- defer self.peersLock.RUnlock()
- for _, peer := range self.peers {
+// Peers returns all connected peers.
+func (srv *Server) Peers() (peers []*Peer) {
+ srv.lock.RLock()
+ defer srv.lock.RUnlock()
+ for _, peer := range srv.peers {
if peer != nil {
peers = append(peers, peer)
}
@@ -158,331 +104,364 @@ func (self *Server) Peers() (peers []*Peer) {
return
}
-func (self *Server) PeerCount() int {
- self.peersLock.RLock()
- defer self.peersLock.RUnlock()
- return self.peerCount
+// PeerCount returns the number of connected peers.
+func (srv *Server) PeerCount() int {
+ srv.lock.RLock()
+ defer srv.lock.RUnlock()
+ return srv.peerCount
}
-func (self *Server) PeerConnect(addr net.Addr) {
- // TODO: should buffer, filter and uniq
- // send GetPeersMsg if not blocking
+// SuggestPeer injects an address into the outbound address pool.
+func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) {
select {
- case self.peerConnect <- addr: // not enough peers
- self.Broadcast("", getPeersMsg)
- default: // we dont care
+ case srv.peerConnect <- &peerAddr{ip, uint64(port), nodeID}:
+ default: // don't block
}
}
-func (self *Server) PeerDisconnect() chan DisconnectRequest {
- return self.peerDisconnect
-}
-
-func (self *Server) Blacklist() Blacklist {
- return self.blacklist
-}
-
-func (self *Server) Handlers() Handlers {
- return self.handlers
-}
-
-func (self *Server) Broadcast(protocol string, code MsgCode, data ...interface{}) {
+// Broadcast sends an RLP-encoded message to all connected peers.
+// This method is deprecated and will be removed later.
+func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{}) {
var payload []byte
if data != nil {
payload = encodePayload(data...)
}
- self.peersLock.RLock()
- defer self.peersLock.RUnlock()
- for _, peer := range self.peers {
+ srv.lock.RLock()
+ defer srv.lock.RUnlock()
+ for _, peer := range srv.peers {
if peer != nil {
var msg = Msg{Code: code}
if data != nil {
msg.Payload = bytes.NewReader(payload)
msg.Size = uint32(len(payload))
}
- peer.messenger.writeProtoMsg(protocol, msg)
+ peer.writeProtoMsg(protocol, msg)
}
}
}
-// Start the server
-func (self *Server) Start(listen bool, dial bool) {
- self.network.Start()
- if listen {
- listener, err := self.network.Listener(self.addr)
- if err != nil {
- logger.Warnf("Error initializing listener: %v", err)
- logger.Warnf("Connection listening disabled")
- self.listening = false
- } else {
- self.listening = true
- logger.Infoln("Listen on %v: ready and accepting connections", listener.Addr())
- go self.inboundPeerHandler(listener)
- }
+// Start starts running the server.
+// Servers can be re-used and started again after stopping.
+func (srv *Server) Start() (err error) {
+ srv.lock.Lock()
+ defer srv.lock.Unlock()
+ if srv.running {
+ return errors.New("server already running")
+ }
+ srvlog.Infoln("Starting Server")
+
+ // initialize fields
+ if srv.Identity == nil {
+ return fmt.Errorf("Server.Identity must be set to a non-nil identity")
}
- if dial {
- dialer, err := self.network.Dialer(self.addr)
- if err != nil {
- logger.Warnf("Error initializing dialer: %v", err)
- logger.Warnf("Connection dialout disabled")
- self.dialing = false
- } else {
- self.dialing = true
- logger.Infoln("Dial peers watching outbound address pool")
- go self.outboundPeerHandler(dialer)
+ if srv.MaxPeers <= 0 {
+ return fmt.Errorf("Server.MaxPeers must be > 0")
+ }
+ srv.quit = make(chan struct{})
+ srv.peers = make([]*Peer, srv.MaxPeers)
+ srv.peerSlots = make(chan int, srv.MaxPeers)
+ srv.peerConnect = make(chan *peerAddr, outboundAddressPoolSize)
+ srv.peerDisconnect = make(chan *Peer)
+ if srv.newPeerFunc == nil {
+ srv.newPeerFunc = newServerPeer
+ }
+ if srv.Blacklist == nil {
+ srv.Blacklist = NewBlacklist()
+ }
+ if srv.Dialer == nil {
+ srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
+ }
+
+ if srv.ListenAddr != "" {
+ if err := srv.startListening(); err != nil {
+ return err
}
}
- logger.Infoln("server started")
+ if !srv.NoDial {
+ srv.wg.Add(1)
+ go srv.dialLoop()
+ }
+ if srv.NoDial && srv.ListenAddr == "" {
+ srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.")
+ }
+
+ // make all slots available
+ for i := range srv.peers {
+ srv.peerSlots <- i
+ }
+ // note: discLoop is not part of WaitGroup
+ go srv.discLoop()
+ srv.running = true
+ return nil
}
-func (self *Server) Stop() {
- logger.Infoln("server stopping...")
- // // quit one loop if dialing
- if self.dialing {
- logger.Infoln("stop dialout...")
- dialq := make(chan bool)
- self.quit <- dialq
- <-dialq
- fmt.Println("quit another")
- }
- // quit the other loop if listening
- if self.listening {
- logger.Infoln("stop listening...")
- listenq := make(chan bool)
- self.quit <- listenq
- <-listenq
- fmt.Println("quit one")
- }
-
- fmt.Println("quit waited")
-
- logger.Infoln("stopping peers...")
- peers := []net.Addr{}
- self.peersLock.RLock()
- self.closed = true
- for _, peer := range self.peers {
- if peer != nil {
- peers = append(peers, peer.Address)
- }
+func (srv *Server) startListening() error {
+ listener, err := net.Listen("tcp", srv.ListenAddr)
+ if err != nil {
+ return err
+ }
+ srv.ListenAddr = listener.Addr().String()
+ srv.laddr = listener.Addr().(*net.TCPAddr)
+ srv.listener = listener
+ srv.wg.Add(1)
+ go srv.listenLoop()
+ if !srv.laddr.IP.IsLoopback() && srv.NAT != nil {
+ srv.wg.Add(1)
+ go srv.natLoop(srv.laddr.Port)
+ }
+ return nil
+}
+
+// Stop terminates the server and all active peer connections.
+// It blocks until all active connections have been closed.
+func (srv *Server) Stop() {
+ srv.lock.Lock()
+ if !srv.running {
+ srv.lock.Unlock()
+ return
}
- self.peersLock.RUnlock()
- for _, address := range peers {
- go self.removePeer(DisconnectRequest{
- addr: address,
- reason: DiscQuitting,
- })
+ srv.running = false
+ srv.lock.Unlock()
+
+ srvlog.Infoln("Stopping server")
+ if srv.listener != nil {
+ // this unblocks listener Accept
+ srv.listener.Close()
+ }
+ close(srv.quit)
+ for _, peer := range srv.Peers() {
+ peer.Disconnect(DiscQuitting)
}
+ srv.wg.Wait()
+
// wait till they actually disconnect
- // this is checked by draining the peerSlots (slots are released back if a peer is removed)
- i := 0
- fmt.Println("draining peers")
+ // this is checked by claiming all peerSlots.
+ // slots become available as the peers disconnect.
+ for i := 0; i < cap(srv.peerSlots); i++ {
+ <-srv.peerSlots
+ }
+ // terminate discLoop
+ close(srv.peerDisconnect)
+}
+
+func (srv *Server) discLoop() {
+ for peer := range srv.peerDisconnect {
+ // peer has just disconnected. free up its slot.
+ srvlog.Infof("%v is gone", peer)
+ srv.peerSlots <- peer.slot
+ srv.lock.Lock()
+ srv.peers[peer.slot] = nil
+ srv.lock.Unlock()
+ }
+}
-FOR:
+// main loop for adding connections via listening
+func (srv *Server) listenLoop() {
+ defer srv.wg.Done()
+
+ srvlog.Infoln("Listening on", srv.listener.Addr())
for {
select {
- case slot := <-self.peerSlots:
- i++
- fmt.Printf("%v: found slot %v\n", i, slot)
- if i == self.maxPeers {
- break FOR
+ case slot := <-srv.peerSlots:
+ conn, err := srv.listener.Accept()
+ if err != nil {
+ srv.peerSlots <- slot
+ return
}
+ srvlog.Debugf("Accepted conn %v (slot %d)\n", conn.RemoteAddr(), slot)
+ srv.addPeer(conn, nil, slot)
+ case <-srv.quit:
+ return
}
}
- logger.Infoln("server stopped")
}
-// main loop for adding connections via listening
-func (self *Server) inboundPeerHandler(listener net.Listener) {
+func (srv *Server) natLoop(port int) {
+ defer srv.wg.Done()
for {
+ srv.updatePortMapping(port)
select {
- case slot := <-self.peerSlots:
- go self.connectInboundPeer(listener, slot)
- case errc := <-self.quit:
- listener.Close()
- fmt.Println("quit listenloop")
- errc <- true
+ case <-time.After(portMappingUpdateInterval):
+ // one more round
+ case <-srv.quit:
+ srv.removePortMapping(port)
return
}
}
}
-// main loop for adding outbound peers based on peerConnect address pool
-// this same loop handles peer disconnect requests as well
-func (self *Server) outboundPeerHandler(dialer Dialer) {
- // addressChan initially set to nil (only watches peerConnect if we need more peers)
- var addressChan chan net.Addr
- slots := self.peerSlots
- var slot *int
+func (srv *Server) updatePortMapping(port int) {
+ srvlog.Infoln("Attempting to map port", port, "with", srv.NAT)
+ err := srv.NAT.AddPortMapping("tcp", port, port, "ethereum p2p", portMappingTimeout)
+ if err != nil {
+ srvlog.Errorln("Port mapping error:", err)
+ return
+ }
+ extip, err := srv.NAT.GetExternalAddress()
+ if err != nil {
+ srvlog.Errorln("Error getting external IP:", err)
+ return
+ }
+ srv.lock.Lock()
+ extaddr := *(srv.listener.Addr().(*net.TCPAddr))
+ extaddr.IP = extip
+ srvlog.Infoln("Mapped port, external addr is", &extaddr)
+ srv.laddr = &extaddr
+ srv.lock.Unlock()
+}
+
+func (srv *Server) removePortMapping(port int) {
+ srvlog.Infoln("Removing port mapping for", port, "with", srv.NAT)
+ srv.NAT.DeletePortMapping("tcp", port, port)
+}
+
+func (srv *Server) dialLoop() {
+ defer srv.wg.Done()
+ var (
+ suggest chan *peerAddr
+ slot *int
+ slots = srv.peerSlots
+ )
for {
select {
case i := <-slots:
// we need a peer in slot i, slot reserved
slot = &i
// now we can watch for candidate peers in the next loop
- addressChan = self.peerConnect
+ suggest = srv.peerConnect
// do not consume more until candidate peer is found
slots = nil
- case address := <-addressChan:
+
+ case desc := <-suggest:
// candidate peer found, will dial out asyncronously
// if connection fails slot will be released
- go self.connectOutboundPeer(dialer, address, *slot)
+ go srv.dialPeer(desc, *slot)
// we can watch if more peers needed in the next loop
- slots = self.peerSlots
+ slots = srv.peerSlots
// until then we dont care about candidate peers
- addressChan = nil
- case request := <-self.peerDisconnect:
- go self.removePeer(request)
- case errc := <-self.quit:
- if addressChan != nil && slot != nil {
- self.peerSlots <- *slot
+ suggest = nil
+
+ case <-srv.quit:
+ // give back the currently reserved slot
+ if slot != nil {
+ srv.peerSlots <- *slot
}
- fmt.Println("quit dialloop")
- errc <- true
return
}
}
}
-// check if peer address already connected
-func (self *Server) isConnected(address net.Addr) bool {
- self.peersLock.RLock()
- defer self.peersLock.RUnlock()
- _, found := self.peersTable[address.String()]
- return found
-}
-
-// connect to peer via listener.Accept()
-func (self *Server) connectInboundPeer(listener net.Listener, slot int) {
- var address net.Addr
- conn, err := listener.Accept()
- if err != nil {
- logger.Debugln(err)
- self.peerSlots <- slot
- return
- }
- address = conn.RemoteAddr()
- // XXX: this won't work because the remote socket
- // address does not identify the peer. we should
- // probably get rid of this check and rely on public
- // key detection in the base protocol.
- if self.isConnected(address) {
- conn.Close()
- self.peerSlots <- slot
- return
- }
- fmt.Printf("adding %v\n", address)
- go self.addPeer(conn, address, true, slot)
-}
-
// connect to peer via dial out
-func (self *Server) connectOutboundPeer(dialer Dialer, address net.Addr, slot int) {
- if self.isConnected(address) {
- return
- }
- conn, err := dialer.Dial(address.Network(), address.String())
+func (srv *Server) dialPeer(desc *peerAddr, slot int) {
+ srvlog.Debugf("Dialing %v (slot %d)\n", desc, slot)
+ conn, err := srv.Dialer.Dial(desc.Network(), desc.String())
if err != nil {
- self.peerSlots <- slot
+ srvlog.Errorf("Dial error: %v", err)
+ srv.peerSlots <- slot
return
}
- go self.addPeer(conn, address, false, slot)
+ go srv.addPeer(conn, desc, slot)
}
// creates the new peer object and inserts it into its slot
-func (self *Server) addPeer(conn net.Conn, address net.Addr, inbound bool, slot int) *Peer {
- self.peersLock.Lock()
- defer self.peersLock.Unlock()
- if self.closed {
- fmt.Println("oopsy, not no longer need peer")
- conn.Close() //oopsy our bad
- self.peerSlots <- slot // release slot
+func (srv *Server) addPeer(conn net.Conn, desc *peerAddr, slot int) *Peer {
+ srv.lock.Lock()
+ defer srv.lock.Unlock()
+ if !srv.running {
+ conn.Close()
+ srv.peerSlots <- slot // release slot
return nil
}
- logger.Infoln("adding new peer", address)
- peer := NewPeer(conn, address, inbound, self)
- self.peers[slot] = peer
- self.peersTable[address.String()] = slot
- self.peerCount++
- self.cachedEncodedPeers = nil
- fmt.Printf("added peer %v %v (slot %v)\n", address, peer, slot)
- peer.Start()
+ peer := srv.newPeerFunc(srv, conn, desc)
+ peer.slot = slot
+ srv.peers[slot] = peer
+ srv.peerCount++
+ go func() { peer.loop(); srv.peerDisconnect <- peer }()
return peer
}
// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot
-func (self *Server) removePeer(request DisconnectRequest) {
- self.peersLock.Lock()
-
- address := request.addr
- slot := self.peersTable[address.String()]
- peer := self.peers[slot]
- fmt.Printf("removing peer %v %v (slot %v)\n", address, peer, slot)
- if peer == nil {
- logger.Debugf("already removed peer on %v", address)
- self.peersLock.Unlock()
+func (srv *Server) removePeer(peer *Peer) {
+ srv.lock.Lock()
+ defer srv.lock.Unlock()
+ srvlog.Debugf("Removing peer %v %v (slot %v)\n", peer, peer.slot)
+ if srv.peers[peer.slot] != peer {
+ srvlog.Warnln("Invalid peer to remove:", peer)
return
}
// remove from list and index
- self.peerCount--
- self.peers[slot] = nil
- delete(self.peersTable, address.String())
- self.cachedEncodedPeers = nil
- fmt.Printf("removed peer %v (slot %v)\n", peer, slot)
- self.peersLock.Unlock()
-
- // sending disconnect message
- disconnectMsg := NewMsg(discMsg, request.reason)
- peer.Write("", disconnectMsg)
- // be nice and wait
- time.Sleep(disconnectGracePeriod * time.Second)
- // switch off peer and close connections etc.
- fmt.Println("stopping peer")
- peer.Stop()
- fmt.Println("stopped peer")
+ srv.peerCount--
+ srv.peers[peer.slot] = nil
// release slot to signal need for a new peer, last!
- self.peerSlots <- slot
+ srv.peerSlots <- peer.slot
}
-// encodedPeerList returns an RLP-encoded list of peers.
-// the returned slice will be nil if there are no peers.
-func (self *Server) encodedPeerList() []byte {
- // TODO: memoize and reset when peers change
- self.peersLock.RLock()
- defer self.peersLock.RUnlock()
- if self.cachedEncodedPeers == nil && self.peerCount > 0 {
- var peerData []interface{}
- for _, i := range self.peersTable {
- peer := self.peers[i]
- peerData = append(peerData, peer.Encode())
+func (srv *Server) verifyPeer(addr *peerAddr) error {
+ if srv.Blacklist.Exists(addr.Pubkey) {
+ return errors.New("blacklisted")
+ }
+ if bytes.Equal(srv.Identity.Pubkey()[1:], addr.Pubkey) {
+ return newPeerError(errPubkeyForbidden, "not allowed to connect to srv")
+ }
+ srv.lock.RLock()
+ defer srv.lock.RUnlock()
+ for _, peer := range srv.peers {
+ if peer != nil {
+ id := peer.Identity()
+ if id != nil && bytes.Equal(id.Pubkey(), addr.Pubkey) {
+ return errors.New("already connected")
+ }
}
- self.cachedEncodedPeers = encodePayload(peerData)
}
- return self.cachedEncodedPeers
+ return nil
}
-// fix handshake message to push to peers
-func (self *Server) handshakeMsg() Msg {
- return NewMsg(handshakeMsg,
- p2pVersion,
- []byte(self.identity.String()),
- []interface{}{self.protocols},
- self.port,
- self.identity.Pubkey()[1:],
- )
+type Blacklist interface {
+ Get([]byte) (bool, error)
+ Put([]byte) error
+ Delete([]byte) error
+ Exists(pubkey []byte) (ok bool)
+}
+
+type BlacklistMap struct {
+ blacklist map[string]bool
+ lock sync.RWMutex
}
-func (self *Server) RegisterPubkey(candidate *Peer, pubkey []byte) error {
- // Check for blacklisting
- if self.blacklist.Exists(pubkey) {
- return fmt.Errorf("blacklisted")
+func NewBlacklist() *BlacklistMap {
+ return &BlacklistMap{
+ blacklist: make(map[string]bool),
}
+}
- self.peersLock.RLock()
- defer self.peersLock.RUnlock()
- for _, peer := range self.peers {
- if peer != nil && peer != candidate && bytes.Compare(peer.Pubkey, pubkey) == 0 {
- return fmt.Errorf("already connected")
- }
+func (self *BlacklistMap) Get(pubkey []byte) (bool, error) {
+ self.lock.RLock()
+ defer self.lock.RUnlock()
+ v, ok := self.blacklist[string(pubkey)]
+ var err error
+ if !ok {
+ err = fmt.Errorf("not found")
}
- candidate.Pubkey = pubkey
+ return v, err
+}
+
+func (self *BlacklistMap) Exists(pubkey []byte) (ok bool) {
+ self.lock.RLock()
+ defer self.lock.RUnlock()
+ _, ok = self.blacklist[string(pubkey)]
+ return
+}
+
+func (self *BlacklistMap) Put(pubkey []byte) error {
+ self.lock.RLock()
+ defer self.lock.RUnlock()
+ self.blacklist[string(pubkey)] = true
+ return nil
+}
+
+func (self *BlacklistMap) Delete(pubkey []byte) error {
+ self.lock.RLock()
+ defer self.lock.RUnlock()
+ delete(self.blacklist, string(pubkey))
return nil
}
diff --git a/p2p/server_test.go b/p2p/server_test.go
index a2594acba..5c0d08d39 100644
--- a/p2p/server_test.go
+++ b/p2p/server_test.go
@@ -1,289 +1,161 @@
package p2p
import (
- "fmt"
+ "bytes"
"io"
"net"
+ "sync"
"testing"
"time"
)
-type TestNetwork struct {
- connections map[string]*TestNetworkConnection
- dialer Dialer
- maxinbound int
-}
-
-func NewTestNetwork(maxinbound int) *TestNetwork {
- connections := make(map[string]*TestNetworkConnection)
- return &TestNetwork{
- connections: connections,
- dialer: &TestDialer{connections},
- maxinbound: maxinbound,
+func startTestServer(t *testing.T, pf peerFunc) *Server {
+ server := &Server{
+ Identity: NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey"),
+ MaxPeers: 10,
+ ListenAddr: "127.0.0.1:0",
+ newPeerFunc: pf,
}
-}
-
-func (self *TestNetwork) Dialer(addr net.Addr) (Dialer, error) {
- return self.dialer, nil
-}
-
-func (self *TestNetwork) Listener(addr net.Addr) (net.Listener, error) {
- return &TestListener{
- connections: self.connections,
- addr: addr,
- max: self.maxinbound,
- close: make(chan struct{}),
- }, nil
-}
-
-func (self *TestNetwork) Start() error {
- return nil
-}
-
-func (self *TestNetwork) NewAddr(string, int) (addr net.Addr, err error) {
- return
-}
-
-func (self *TestNetwork) ParseAddr(string) (addr net.Addr, err error) {
- return
-}
-
-type TestAddr struct {
- name string
-}
-
-func (self *TestAddr) String() string {
- return self.name
-}
-
-func (*TestAddr) Network() string {
- return "test"
-}
-
-type TestDialer struct {
- connections map[string]*TestNetworkConnection
-}
-
-func (self *TestDialer) Dial(network string, addr string) (conn net.Conn, err error) {
- address := &TestAddr{addr}
- tconn := NewTestNetworkConnection(address)
- self.connections[addr] = tconn
- conn = net.Conn(tconn)
- return
-}
-
-type TestListener struct {
- connections map[string]*TestNetworkConnection
- addr net.Addr
- max int
- i int
- close chan struct{}
-}
-
-func (self *TestListener) Accept() (net.Conn, error) {
- self.i++
- if self.i > self.max {
- <-self.close
- return nil, io.EOF
+ if err := server.Start(); err != nil {
+ t.Fatalf("Could not start server: %v", err)
}
- addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)}
- tconn := NewTestNetworkConnection(addr)
- key := tconn.RemoteAddr().String()
- self.connections[key] = tconn
- fmt.Printf("accepted connection from: %v \n", addr)
- return tconn, nil
-}
-
-func (self *TestListener) Close() error {
- close(self.close)
- return nil
-}
-
-func (self *TestListener) Addr() net.Addr {
- return self.addr
+ return server
}
-type TestNetworkConnection struct {
- in chan []byte
- close chan struct{}
- current []byte
- Out [][]byte
- addr net.Addr
-}
+func TestServerListen(t *testing.T) {
+ defer testlog(t).detach()
-func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection {
- return &TestNetworkConnection{
- in: make(chan []byte),
- close: make(chan struct{}),
- current: []byte{},
- Out: [][]byte{},
- addr: addr,
+ // start the test server
+ connected := make(chan *Peer)
+ srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
+ if conn == nil {
+ t.Error("peer func called with nil conn")
+ }
+ if dialAddr != nil {
+ t.Error("peer func called with non-nil dialAddr")
+ }
+ peer := newPeer(conn, nil, dialAddr)
+ connected <- peer
+ return peer
+ })
+ defer close(connected)
+ defer srv.Stop()
+
+ // dial the test server
+ conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second)
+ if err != nil {
+ t.Fatalf("could not dial: %v", err)
}
-}
+ defer conn.Close()
-func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) {
- time.Sleep(latency)
- for _, s := range packets {
- self.in <- s
+ select {
+ case peer := <-connected:
+ if peer.conn.LocalAddr().String() != conn.RemoteAddr().String() {
+ t.Errorf("peer started with wrong conn: got %v, want %v",
+ peer.conn.LocalAddr(), conn.RemoteAddr())
+ }
+ case <-time.After(1 * time.Second):
+ t.Error("server did not accept within one second")
}
}
-func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) {
- if len(self.current) == 0 {
- var ok bool
+func TestServerDial(t *testing.T) {
+ defer testlog(t).detach()
+
+ // run a fake TCP server to handle the connection.
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("could not setup listener: %v")
+ }
+ defer listener.Close()
+ accepted := make(chan net.Conn)
+ go func() {
+ conn, err := listener.Accept()
+ if err != nil {
+ t.Error("acccept error:", err)
+ }
+ conn.Close()
+ accepted <- conn
+ }()
+
+ // start the test server
+ connected := make(chan *Peer)
+ srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
+ if conn == nil {
+ t.Error("peer func called with nil conn")
+ }
+ peer := newPeer(conn, nil, dialAddr)
+ connected <- peer
+ return peer
+ })
+ defer close(connected)
+ defer srv.Stop()
+
+ // tell the server to connect.
+ connAddr := newPeerAddr(listener.Addr(), nil)
+ srv.peerConnect <- connAddr
+
+ select {
+ case conn := <-accepted:
select {
- case self.current, ok = <-self.in:
- if !ok {
- return 0, io.EOF
+ case peer := <-connected:
+ if peer.conn.RemoteAddr().String() != conn.LocalAddr().String() {
+ t.Errorf("peer started with wrong conn: got %v, want %v",
+ peer.conn.RemoteAddr(), conn.LocalAddr())
+ }
+ if peer.dialAddr != connAddr {
+ t.Errorf("peer started with wrong dialAddr: got %v, want %v",
+ peer.dialAddr, connAddr)
}
- case <-self.close:
- return 0, io.EOF
+ case <-time.After(1 * time.Second):
+ t.Error("server did not launch peer within one second")
}
- }
- length := len(self.current)
- if length > len(buff) {
- copy(buff[:], self.current[:len(buff)])
- self.current = self.current[len(buff):]
- return len(buff), nil
- } else {
- copy(buff[:length], self.current[:])
- self.current = []byte{}
- return length, io.EOF
- }
-}
-
-func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) {
- self.Out = append(self.Out, buff)
- fmt.Printf("net write(%d): %x\n", len(self.Out), buff)
- return len(buff), nil
-}
-
-func (self *TestNetworkConnection) Close() error {
- close(self.close)
- return nil
-}
-
-func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) {
- return
-}
-func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) {
- return self.addr
-}
-
-func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) {
- return
-}
-
-func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) {
- return
-}
-
-func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) {
- return
-}
-
-func SetupTestServer(handlers Handlers) (network *TestNetwork, server *Server) {
- network = NewTestNetwork(1)
- addr := &TestAddr{"test:30303"}
- identity := NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey")
- maxPeers := 2
- if handlers == nil {
- handlers = make(Handlers)
+ case <-time.After(1 * time.Second):
+ t.Error("server did not connect within one second")
}
- blackist := NewBlacklist()
- server = New(network, addr, identity, handlers, maxPeers, blackist)
- fmt.Println(server.identity.Pubkey())
- return
}
-func TestServerListener(t *testing.T) {
- t.SkipNow()
-
- network, server := SetupTestServer(nil)
- server.Start(true, false)
- time.Sleep(10 * time.Millisecond)
- server.Stop()
- peer1, ok := network.connections["inboundpeer-1"]
- if !ok {
- t.Error("not found inbound peer 1")
- } else {
- if len(peer1.Out) != 2 {
- t.Errorf("wrong number of writes to peer 1: got %d, want %d", len(peer1.Out), 2)
+func TestServerBroadcast(t *testing.T) {
+ defer testlog(t).detach()
+ var connected sync.WaitGroup
+ srv := startTestServer(t, func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer {
+ peer := newPeer(c, []Protocol{discard}, dialAddr)
+ peer.startSubprotocols([]Cap{discard.cap()})
+ connected.Done()
+ return peer
+ })
+ defer srv.Stop()
+
+ // dial a bunch of conns
+ var conns = make([]net.Conn, 8)
+ connected.Add(len(conns))
+ deadline := time.Now().Add(3 * time.Second)
+ dialer := &net.Dialer{Deadline: deadline}
+ for i := range conns {
+ conn, err := dialer.Dial("tcp", srv.ListenAddr)
+ if err != nil {
+ t.Fatalf("conn %d: dial error: %v", i, err)
}
+ defer conn.Close()
+ conn.SetDeadline(deadline)
+ conns[i] = conn
}
-}
-
-func TestServerDialer(t *testing.T) {
- network, server := SetupTestServer(nil)
- server.Start(false, true)
- server.peerConnect <- &TestAddr{"outboundpeer-1"}
- time.Sleep(10 * time.Millisecond)
- server.Stop()
- peer1, ok := network.connections["outboundpeer-1"]
- if !ok {
- t.Error("not found outbound peer 1")
- } else {
- if len(peer1.Out) != 2 {
- t.Errorf("wrong number of writes to peer 1: got %d, want %d", len(peer1.Out), 2)
+ connected.Wait()
+
+ // broadcast one message
+ srv.Broadcast("discard", 0, "foo")
+ goldbuf := new(bytes.Buffer)
+ writeMsg(goldbuf, NewMsg(16, "foo"))
+ golden := goldbuf.Bytes()
+
+ // check that the message has been written everywhere
+ for i, conn := range conns {
+ buf := make([]byte, len(golden))
+ if _, err := io.ReadFull(conn, buf); err != nil {
+ t.Errorf("conn %d: read error: %v", i, err)
+ } else if !bytes.Equal(buf, golden) {
+ t.Errorf("conn %d: msg mismatch\ngot: %x\nwant: %x", i, buf, golden)
}
}
}
-
-// func TestServerBroadcast(t *testing.T) {
-// handlers := make(Handlers)
-// testProtocol := &TestProtocol{Msgs: []*Msg{}}
-// handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
-// network, server := SetupTestServer(handlers)
-// server.Start(true, true)
-// server.peerConnect <- &TestAddr{"outboundpeer-1"}
-// time.Sleep(10 * time.Millisecond)
-// msg := NewMsg(0)
-// server.Broadcast("", msg)
-// packet := Packet(0, 0)
-// time.Sleep(10 * time.Millisecond)
-// server.Stop()
-// peer1, ok := network.connections["outboundpeer-1"]
-// if !ok {
-// t.Error("not found outbound peer 1")
-// } else {
-// fmt.Printf("out: %v\n", peer1.Out)
-// if len(peer1.Out) != 3 {
-// t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
-// } else {
-// if bytes.Compare(peer1.Out[1], packet) != 0 {
-// t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet)
-// }
-// }
-// }
-// peer2, ok := network.connections["inboundpeer-1"]
-// if !ok {
-// t.Error("not found inbound peer 2")
-// } else {
-// fmt.Printf("out: %v\n", peer2.Out)
-// if len(peer1.Out) != 3 {
-// t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out))
-// } else {
-// if bytes.Compare(peer2.Out[1], packet) != 0 {
-// t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet)
-// }
-// }
-// }
-// }
-
-func TestServerPeersMessage(t *testing.T) {
- t.SkipNow()
- _, server := SetupTestServer(nil)
- server.Start(true, true)
- defer server.Stop()
- server.peerConnect <- &TestAddr{"outboundpeer-1"}
- time.Sleep(2000 * time.Millisecond)
-
- pl := server.encodedPeerList()
- if pl == nil {
- t.Errorf("expect non-nil peer list")
- }
- if c := server.PeerCount(); c != 2 {
- t.Errorf("expect 2 peers, got %v", c)
- }
-}
diff --git a/p2p/testlog_test.go b/p2p/testlog_test.go
new file mode 100644
index 000000000..951d43243
--- /dev/null
+++ b/p2p/testlog_test.go
@@ -0,0 +1,28 @@
+package p2p
+
+import (
+ "testing"
+
+ "github.com/ethereum/go-ethereum/logger"
+)
+
+type testLogger struct{ t *testing.T }
+
+func testlog(t *testing.T) testLogger {
+ logger.Reset()
+ l := testLogger{t}
+ logger.AddLogSystem(l)
+ return l
+}
+
+func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel }
+func (testLogger) SetLogLevel(logger.LogLevel) {}
+
+func (l testLogger) LogPrint(level logger.LogLevel, msg string) {
+ l.t.Logf("%s", msg)
+}
+
+func (testLogger) detach() {
+ logger.Flush()
+ logger.Reset()
+}
diff --git a/p2p/testpoc7.go b/p2p/testpoc7.go
new file mode 100644
index 000000000..c0cc5c544
--- /dev/null
+++ b/p2p/testpoc7.go
@@ -0,0 +1,40 @@
+// +build none
+
+package main
+
+import (
+ "fmt"
+ "log"
+ "net"
+ "os"
+
+ "github.com/ethereum/go-ethereum/logger"
+ "github.com/ethereum/go-ethereum/p2p"
+ "github.com/obscuren/secp256k1-go"
+)
+
+func main() {
+ logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel))
+
+ pub, _ := secp256k1.GenerateKeyPair()
+ srv := p2p.Server{
+ MaxPeers: 10,
+ Identity: p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)),
+ ListenAddr: ":30303",
+ NAT: p2p.PMP(net.ParseIP("10.0.0.1")),
+ }
+ if err := srv.Start(); err != nil {
+ fmt.Println("could not start server:", err)
+ os.Exit(1)
+ }
+
+ // add seed peers
+ seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303")
+ if err != nil {
+ fmt.Println("couldn't resolve:", err)
+ os.Exit(1)
+ }
+ srv.SuggestPeer(seed.IP, seed.Port, nil)
+
+ select {}
+}