aboutsummaryrefslogtreecommitdiffstats
path: root/p2p
diff options
context:
space:
mode:
Diffstat (limited to 'p2p')
-rw-r--r--p2p/client_identity.go63
-rw-r--r--p2p/client_identity_test.go30
-rw-r--r--p2p/message.go232
-rw-r--r--p2p/message_test.go133
-rw-r--r--p2p/natpmp.go55
-rw-r--r--p2p/natupnp.go341
-rw-r--r--p2p/peer.go462
-rw-r--r--p2p/peer_error.go133
-rw-r--r--p2p/peer_test.go295
-rw-r--r--p2p/protocol.go294
-rw-r--r--p2p/protocol_test.go58
-rw-r--r--p2p/server.go467
-rw-r--r--p2p/server_test.go161
-rw-r--r--p2p/testlog_test.go28
-rw-r--r--p2p/testpoc7.go40
15 files changed, 2792 insertions, 0 deletions
diff --git a/p2p/client_identity.go b/p2p/client_identity.go
new file mode 100644
index 000000000..bc865b63b
--- /dev/null
+++ b/p2p/client_identity.go
@@ -0,0 +1,63 @@
+package p2p
+
+import (
+ "fmt"
+ "runtime"
+)
+
+// ClientIdentity represents the identity of a peer.
+type ClientIdentity interface {
+ String() string // human readable identity
+ Pubkey() []byte // 512-bit public key
+}
+
+type SimpleClientIdentity struct {
+ clientIdentifier string
+ version string
+ customIdentifier string
+ os string
+ implementation string
+ pubkey string
+}
+
+func NewSimpleClientIdentity(clientIdentifier string, version string, customIdentifier string, pubkey string) *SimpleClientIdentity {
+ clientIdentity := &SimpleClientIdentity{
+ clientIdentifier: clientIdentifier,
+ version: version,
+ customIdentifier: customIdentifier,
+ os: runtime.GOOS,
+ implementation: runtime.Version(),
+ pubkey: pubkey,
+ }
+
+ return clientIdentity
+}
+
+func (c *SimpleClientIdentity) init() {
+}
+
+func (c *SimpleClientIdentity) String() string {
+ var id string
+ if len(c.customIdentifier) > 0 {
+ id = "/" + c.customIdentifier
+ }
+
+ return fmt.Sprintf("%s/v%s%s/%s/%s",
+ c.clientIdentifier,
+ c.version,
+ id,
+ c.os,
+ c.implementation)
+}
+
+func (c *SimpleClientIdentity) Pubkey() []byte {
+ return []byte(c.pubkey)
+}
+
+func (c *SimpleClientIdentity) SetCustomIdentifier(customIdentifier string) {
+ c.customIdentifier = customIdentifier
+}
+
+func (c *SimpleClientIdentity) GetCustomIdentifier() string {
+ return c.customIdentifier
+}
diff --git a/p2p/client_identity_test.go b/p2p/client_identity_test.go
new file mode 100644
index 000000000..40b0e6f5e
--- /dev/null
+++ b/p2p/client_identity_test.go
@@ -0,0 +1,30 @@
+package p2p
+
+import (
+ "fmt"
+ "runtime"
+ "testing"
+)
+
+func TestClientIdentity(t *testing.T) {
+ clientIdentity := NewSimpleClientIdentity("Ethereum(G)", "0.5.16", "test", "pubkey")
+ clientString := clientIdentity.String()
+ expected := fmt.Sprintf("Ethereum(G)/v0.5.16/test/%s/%s", runtime.GOOS, runtime.Version())
+ if clientString != expected {
+ t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString)
+ }
+ customIdentifier := clientIdentity.GetCustomIdentifier()
+ if customIdentifier != "test" {
+ t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test', got %v", customIdentifier)
+ }
+ clientIdentity.SetCustomIdentifier("test2")
+ customIdentifier = clientIdentity.GetCustomIdentifier()
+ if customIdentifier != "test2" {
+ t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test2', got %v", customIdentifier)
+ }
+ clientString = clientIdentity.String()
+ expected = fmt.Sprintf("Ethereum(G)/v0.5.16/test2/%s/%s", runtime.GOOS, runtime.Version())
+ if clientString != expected {
+ t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString)
+ }
+}
diff --git a/p2p/message.go b/p2p/message.go
new file mode 100644
index 000000000..f5418ff47
--- /dev/null
+++ b/p2p/message.go
@@ -0,0 +1,232 @@
+package p2p
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "io"
+ "io/ioutil"
+ "math/big"
+ "sync/atomic"
+
+ "github.com/ethereum/go-ethereum/ethutil"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+// Msg defines the structure of a p2p message.
+//
+// Note that a Msg can only be sent once since the Payload reader is
+// consumed during sending. It is not possible to create a Msg and
+// send it any number of times. If you want to reuse an encoded
+// structure, encode the payload into a byte array and create a
+// separate Msg with a bytes.Reader as Payload for each send.
+type Msg struct {
+ Code uint64
+ Size uint32 // size of the paylod
+ Payload io.Reader
+}
+
+// NewMsg creates an RLP-encoded message with the given code.
+func NewMsg(code uint64, params ...interface{}) Msg {
+ buf := new(bytes.Buffer)
+ for _, p := range params {
+ buf.Write(ethutil.Encode(p))
+ }
+ return Msg{Code: code, Size: uint32(buf.Len()), Payload: buf}
+}
+
+func encodePayload(params ...interface{}) []byte {
+ buf := new(bytes.Buffer)
+ for _, p := range params {
+ buf.Write(ethutil.Encode(p))
+ }
+ return buf.Bytes()
+}
+
+// Decode parse the RLP content of a message into
+// the given value, which must be a pointer.
+//
+// For the decoding rules, please see package rlp.
+func (msg Msg) Decode(val interface{}) error {
+ s := rlp.NewListStream(msg.Payload, uint64(msg.Size))
+ return s.Decode(val)
+}
+
+// Discard reads any remaining payload data into a black hole.
+func (msg Msg) Discard() error {
+ _, err := io.Copy(ioutil.Discard, msg.Payload)
+ return err
+}
+
+type MsgReader interface {
+ ReadMsg() (Msg, error)
+}
+
+type MsgWriter interface {
+ // WriteMsg sends an existing message.
+ // The Payload reader of the message is consumed.
+ // Note that messages can be sent only once.
+ WriteMsg(Msg) error
+
+ // EncodeMsg writes an RLP-encoded message with the given
+ // code and data elements.
+ EncodeMsg(code uint64, data ...interface{}) error
+}
+
+// MsgReadWriter provides reading and writing of encoded messages.
+type MsgReadWriter interface {
+ MsgReader
+ MsgWriter
+}
+
+var magicToken = []byte{34, 64, 8, 145}
+
+func writeMsg(w io.Writer, msg Msg) error {
+ // TODO: handle case when Size + len(code) + len(listhdr) overflows uint32
+ code := ethutil.Encode(uint32(msg.Code))
+ listhdr := makeListHeader(msg.Size + uint32(len(code)))
+ payloadLen := uint32(len(listhdr)) + uint32(len(code)) + msg.Size
+
+ start := make([]byte, 8)
+ copy(start, magicToken)
+ binary.BigEndian.PutUint32(start[4:], payloadLen)
+
+ for _, b := range [][]byte{start, listhdr, code} {
+ if _, err := w.Write(b); err != nil {
+ return err
+ }
+ }
+ _, err := io.CopyN(w, msg.Payload, int64(msg.Size))
+ return err
+}
+
+func makeListHeader(length uint32) []byte {
+ if length < 56 {
+ return []byte{byte(length + 0xc0)}
+ }
+ enc := big.NewInt(int64(length)).Bytes()
+ lenb := byte(len(enc)) + 0xf7
+ return append([]byte{lenb}, enc...)
+}
+
+// readMsg reads a message header from r.
+// It takes an rlp.ByteReader to ensure that the decoding doesn't buffer.
+func readMsg(r rlp.ByteReader) (msg Msg, err error) {
+ // read magic and payload size
+ start := make([]byte, 8)
+ if _, err = io.ReadFull(r, start); err != nil {
+ return msg, newPeerError(errRead, "%v", err)
+ }
+ if !bytes.HasPrefix(start, magicToken) {
+ return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken)
+ }
+ size := binary.BigEndian.Uint32(start[4:])
+
+ // decode start of RLP message to get the message code
+ posr := &postrack{r, 0}
+ s := rlp.NewStream(posr)
+ if _, err := s.List(); err != nil {
+ return msg, err
+ }
+ code, err := s.Uint()
+ if err != nil {
+ return msg, err
+ }
+ payloadsize := size - posr.p
+ return Msg{code, payloadsize, io.LimitReader(r, int64(payloadsize))}, nil
+}
+
+// postrack wraps an rlp.ByteReader with a position counter.
+type postrack struct {
+ r rlp.ByteReader
+ p uint32
+}
+
+func (r *postrack) Read(buf []byte) (int, error) {
+ n, err := r.r.Read(buf)
+ r.p += uint32(n)
+ return n, err
+}
+
+func (r *postrack) ReadByte() (byte, error) {
+ b, err := r.r.ReadByte()
+ if err == nil {
+ r.p++
+ }
+ return b, err
+}
+
+// MsgPipe creates a message pipe. Reads on one end are matched
+// with writes on the other. The pipe is full-duplex, both ends
+// implement MsgReadWriter.
+func MsgPipe() (*MsgPipeRW, *MsgPipeRW) {
+ var (
+ c1, c2 = make(chan Msg), make(chan Msg)
+ closing = make(chan struct{})
+ closed = new(int32)
+ rw1 = &MsgPipeRW{c1, c2, closing, closed}
+ rw2 = &MsgPipeRW{c2, c1, closing, closed}
+ )
+ return rw1, rw2
+}
+
+// ErrPipeClosed is returned from pipe operations after the
+// pipe has been closed.
+var ErrPipeClosed = errors.New("p2p: read or write on closed message pipe")
+
+// MsgPipeRW is an endpoint of a MsgReadWriter pipe.
+type MsgPipeRW struct {
+ w chan<- Msg
+ r <-chan Msg
+ closing chan struct{}
+ closed *int32
+}
+
+// WriteMsg sends a messsage on the pipe.
+// It blocks until the receiver has consumed the message payload.
+func (p *MsgPipeRW) WriteMsg(msg Msg) error {
+ if atomic.LoadInt32(p.closed) == 0 {
+ consumed := make(chan struct{}, 1)
+ msg.Payload = &eofSignal{msg.Payload, int64(msg.Size), consumed}
+ select {
+ case p.w <- msg:
+ if msg.Size > 0 {
+ // wait for payload read or discard
+ <-consumed
+ }
+ return nil
+ case <-p.closing:
+ }
+ }
+ return ErrPipeClosed
+}
+
+// EncodeMsg is a convenient shorthand for sending an RLP-encoded message.
+func (p *MsgPipeRW) EncodeMsg(code uint64, data ...interface{}) error {
+ return p.WriteMsg(NewMsg(code, data...))
+}
+
+// ReadMsg returns a message sent on the other end of the pipe.
+func (p *MsgPipeRW) ReadMsg() (Msg, error) {
+ if atomic.LoadInt32(p.closed) == 0 {
+ select {
+ case msg := <-p.r:
+ return msg, nil
+ case <-p.closing:
+ }
+ }
+ return Msg{}, ErrPipeClosed
+}
+
+// Close unblocks any pending ReadMsg and WriteMsg calls on both ends
+// of the pipe. They will return ErrPipeClosed. Note that Close does
+// not interrupt any reads from a message payload.
+func (p *MsgPipeRW) Close() error {
+ if atomic.AddInt32(p.closed, 1) != 1 {
+ // someone else is already closing
+ atomic.StoreInt32(p.closed, 1) // avoid overflow
+ return nil
+ }
+ close(p.closing)
+ return nil
+}
diff --git a/p2p/message_test.go b/p2p/message_test.go
new file mode 100644
index 000000000..066d2516d
--- /dev/null
+++ b/p2p/message_test.go
@@ -0,0 +1,133 @@
+package p2p
+
+import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "runtime"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/ethutil"
+)
+
+func TestNewMsg(t *testing.T) {
+ msg := NewMsg(3, 1, "000")
+ if msg.Code != 3 {
+ t.Errorf("incorrect code %d, want %d", msg.Code)
+ }
+ if msg.Size != 5 {
+ t.Errorf("incorrect size %d, want %d", msg.Size, 5)
+ }
+ pl, _ := ioutil.ReadAll(msg.Payload)
+ expect := []byte{0x01, 0x83, 0x30, 0x30, 0x30}
+ if !bytes.Equal(pl, expect) {
+ t.Errorf("incorrect payload content, got %x, want %x", pl, expect)
+ }
+}
+
+func TestEncodeDecodeMsg(t *testing.T) {
+ msg := NewMsg(3, 1, "000")
+ buf := new(bytes.Buffer)
+ if err := writeMsg(buf, msg); err != nil {
+ t.Fatalf("encodeMsg error: %v", err)
+ }
+ // t.Logf("encoded: %x", buf.Bytes())
+
+ decmsg, err := readMsg(buf)
+ if err != nil {
+ t.Fatalf("readMsg error: %v", err)
+ }
+ if decmsg.Code != 3 {
+ t.Errorf("incorrect code %d, want %d", decmsg.Code, 3)
+ }
+ if decmsg.Size != 5 {
+ t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
+ }
+
+ var data struct {
+ I uint
+ S string
+ }
+ if err := decmsg.Decode(&data); err != nil {
+ t.Fatalf("Decode error: %v", err)
+ }
+ if data.I != 1 {
+ t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1)
+ }
+ if data.S != "000" {
+ t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000")
+ }
+}
+
+func TestDecodeRealMsg(t *testing.T) {
+ data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb")
+ msg, err := readMsg(bytes.NewReader(data))
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if msg.Code != 0 {
+ t.Errorf("incorrect code %d, want %d", msg.Code, 0)
+ }
+}
+
+func ExampleMsgPipe() {
+ rw1, rw2 := MsgPipe()
+ go func() {
+ rw1.EncodeMsg(8, []byte{0, 0})
+ rw1.EncodeMsg(5, []byte{1, 1})
+ rw1.Close()
+ }()
+
+ for {
+ msg, err := rw2.ReadMsg()
+ if err != nil {
+ break
+ }
+ var data [1][]byte
+ msg.Decode(&data)
+ fmt.Printf("msg: %d, %x\n", msg.Code, data[0])
+ }
+ // Output:
+ // msg: 8, 0000
+ // msg: 5, 0101
+}
+
+func TestMsgPipeUnblockWrite(t *testing.T) {
+loop:
+ for i := 0; i < 100; i++ {
+ rw1, rw2 := MsgPipe()
+ done := make(chan struct{})
+ go func() {
+ if err := rw1.EncodeMsg(1); err == nil {
+ t.Error("EncodeMsg returned nil error")
+ } else if err != ErrPipeClosed {
+ t.Error("EncodeMsg returned wrong error: got %v, want %v", err, ErrPipeClosed)
+ }
+ close(done)
+ }()
+
+ // this call should ensure that EncodeMsg is waiting to
+ // deliver sometimes. if this isn't done, Close is likely to
+ // be executed before EncodeMsg starts and then we won't test
+ // all the cases.
+ runtime.Gosched()
+
+ rw2.Close()
+ select {
+ case <-done:
+ case <-time.After(200 * time.Millisecond):
+ t.Errorf("write didn't unblock")
+ break loop
+ }
+ }
+}
+
+// This test should panic if concurrent close isn't implemented correctly.
+func TestMsgPipeConcurrentClose(t *testing.T) {
+ rw1, _ := MsgPipe()
+ for i := 0; i < 10; i++ {
+ go rw1.Close()
+ }
+}
diff --git a/p2p/natpmp.go b/p2p/natpmp.go
new file mode 100644
index 000000000..6714678c4
--- /dev/null
+++ b/p2p/natpmp.go
@@ -0,0 +1,55 @@
+package p2p
+
+import (
+ "fmt"
+ "net"
+ "time"
+
+ natpmp "github.com/jackpal/go-nat-pmp"
+)
+
+// Adapt the NAT-PMP protocol to the NAT interface
+
+// TODO:
+// + Register for changes to the external address.
+// + Re-register port mapping when router reboots.
+// + A mechanism for keeping a port mapping registered.
+// + Discover gateway address automatically.
+
+type natPMPClient struct {
+ client *natpmp.Client
+}
+
+// PMP returns a NAT traverser that uses NAT-PMP. The provided gateway
+// address should be the IP of your router.
+func PMP(gateway net.IP) (nat NAT) {
+ return &natPMPClient{natpmp.NewClient(gateway)}
+}
+
+func (*natPMPClient) String() string {
+ return "NAT-PMP"
+}
+
+func (n *natPMPClient) GetExternalAddress() (net.IP, error) {
+ response, err := n.client.GetExternalAddress()
+ if err != nil {
+ return nil, err
+ }
+ return response.ExternalIPAddress[:], nil
+}
+
+func (n *natPMPClient) AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
+ if lifetime <= 0 {
+ return fmt.Errorf("lifetime must not be <= 0")
+ }
+ // Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping.
+ _, err := n.client.AddPortMapping(protocol, intport, extport, int(lifetime/time.Second))
+ return err
+}
+
+func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
+ // To destroy a mapping, send an add-port with
+ // an internalPort of the internal port to destroy, an external port of zero and a time of zero.
+ _, err = n.client.AddPortMapping(protocol, internalPort, 0, 0)
+ return
+}
diff --git a/p2p/natupnp.go b/p2p/natupnp.go
new file mode 100644
index 000000000..2e0d8ce8d
--- /dev/null
+++ b/p2p/natupnp.go
@@ -0,0 +1,341 @@
+package p2p
+
+// Just enough UPnP to be able to forward ports
+//
+
+import (
+ "bytes"
+ "encoding/xml"
+ "errors"
+ "fmt"
+ "net"
+ "net/http"
+ "os"
+ "strconv"
+ "strings"
+ "time"
+)
+
+const (
+ upnpDiscoverAttempts = 3
+ upnpDiscoverTimeout = 5 * time.Second
+)
+
+// UPNP returns a NAT port mapper that uses UPnP. It will attempt to
+// discover the address of your router using UDP broadcasts.
+func UPNP() NAT {
+ return &upnpNAT{}
+}
+
+type upnpNAT struct {
+ serviceURL string
+ ourIP string
+}
+
+func (n *upnpNAT) String() string {
+ return "UPNP"
+}
+
+func (n *upnpNAT) discover() error {
+ if n.serviceURL != "" {
+ // already discovered
+ return nil
+ }
+
+ ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900")
+ if err != nil {
+ return err
+ }
+ // TODO: try on all network interfaces simultaneously.
+ // Broadcasting on 0.0.0.0 could select a random interface
+ // to send on (platform specific).
+ conn, err := net.ListenPacket("udp4", ":0")
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ conn.SetDeadline(time.Now().Add(10 * time.Second))
+ st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n"
+ buf := bytes.NewBufferString(
+ "M-SEARCH * HTTP/1.1\r\n" +
+ "HOST: 239.255.255.250:1900\r\n" +
+ st +
+ "MAN: \"ssdp:discover\"\r\n" +
+ "MX: 2\r\n\r\n")
+ message := buf.Bytes()
+ answerBytes := make([]byte, 1024)
+ for i := 0; i < upnpDiscoverAttempts; i++ {
+ _, err = conn.WriteTo(message, ssdp)
+ if err != nil {
+ return err
+ }
+ nn, _, err := conn.ReadFrom(answerBytes)
+ if err != nil {
+ continue
+ }
+ answer := string(answerBytes[0:nn])
+ if strings.Index(answer, "\r\n"+st) < 0 {
+ continue
+ }
+ // HTTP header field names are case-insensitive.
+ // http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
+ locString := "\r\nlocation: "
+ answer = strings.ToLower(answer)
+ locIndex := strings.Index(answer, locString)
+ if locIndex < 0 {
+ continue
+ }
+ loc := answer[locIndex+len(locString):]
+ endIndex := strings.Index(loc, "\r\n")
+ if endIndex < 0 {
+ continue
+ }
+ locURL := loc[0:endIndex]
+ var serviceURL string
+ serviceURL, err = getServiceURL(locURL)
+ if err != nil {
+ return err
+ }
+ var ourIP string
+ ourIP, err = getOurIP()
+ if err != nil {
+ return err
+ }
+ n.serviceURL = serviceURL
+ n.ourIP = ourIP
+ return nil
+ }
+ return errors.New("UPnP port discovery failed.")
+}
+
+func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
+ if err := n.discover(); err != nil {
+ return nil, err
+ }
+ info, err := n.getStatusInfo()
+ return net.ParseIP(info.externalIpAddress), err
+}
+
+func (n *upnpNAT) AddPortMapping(protocol string, extport, intport int, description string, lifetime time.Duration) error {
+ if err := n.discover(); err != nil {
+ return err
+ }
+
+ // A single concatenation would break ARM compilation.
+ message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
+ "<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(extport)
+ message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
+ message += "<NewInternalPort>" + strconv.Itoa(extport) + "</NewInternalPort>" +
+ "<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
+ "<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
+ message += description +
+ "</NewPortMappingDescription><NewLeaseDuration>" + fmt.Sprint(lifetime/time.Second) +
+ "</NewLeaseDuration></u:AddPortMapping>"
+
+ // TODO: check response to see if the port was forwarded
+ _, err := soapRequest(n.serviceURL, "AddPortMapping", message)
+ return err
+}
+
+func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) error {
+ if err := n.discover(); err != nil {
+ return err
+ }
+
+ message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
+ "<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
+ "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
+ "</u:DeletePortMapping>"
+
+ // TODO: check response to see if the port was deleted
+ _, err := soapRequest(n.serviceURL, "DeletePortMapping", message)
+ return err
+}
+
+type statusInfo struct {
+ externalIpAddress string
+}
+
+func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
+ message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
+ "</u:GetStatusInfo>"
+
+ var response *http.Response
+ response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
+ if err != nil {
+ return
+ }
+
+ // TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
+
+ response.Body.Close()
+ return
+}
+
+// service represents the Service type in an UPnP xml description.
+// Only the parts we care about are present and thus the xml may have more
+// fields than present in the structure.
+type service struct {
+ ServiceType string `xml:"serviceType"`
+ ControlURL string `xml:"controlURL"`
+}
+
+// deviceList represents the deviceList type in an UPnP xml description.
+// Only the parts we care about are present and thus the xml may have more
+// fields than present in the structure.
+type deviceList struct {
+ XMLName xml.Name `xml:"deviceList"`
+ Device []device `xml:"device"`
+}
+
+// serviceList represents the serviceList type in an UPnP xml description.
+// Only the parts we care about are present and thus the xml may have more
+// fields than present in the structure.
+type serviceList struct {
+ XMLName xml.Name `xml:"serviceList"`
+ Service []service `xml:"service"`
+}
+
+// device represents the device type in an UPnP xml description.
+// Only the parts we care about are present and thus the xml may have more
+// fields than present in the structure.
+type device struct {
+ XMLName xml.Name `xml:"device"`
+ DeviceType string `xml:"deviceType"`
+ DeviceList deviceList `xml:"deviceList"`
+ ServiceList serviceList `xml:"serviceList"`
+}
+
+// specVersion represents the specVersion in a UPnP xml description.
+// Only the parts we care about are present and thus the xml may have more
+// fields than present in the structure.
+type specVersion struct {
+ XMLName xml.Name `xml:"specVersion"`
+ Major int `xml:"major"`
+ Minor int `xml:"minor"`
+}
+
+// root represents the Root document for a UPnP xml description.
+// Only the parts we care about are present and thus the xml may have more
+// fields than present in the structure.
+type root struct {
+ XMLName xml.Name `xml:"root"`
+ SpecVersion specVersion
+ Device device
+}
+
+func getChildDevice(d *device, deviceType string) *device {
+ dl := d.DeviceList.Device
+ for i := 0; i < len(dl); i++ {
+ if dl[i].DeviceType == deviceType {
+ return &dl[i]
+ }
+ }
+ return nil
+}
+
+func getChildService(d *device, serviceType string) *service {
+ sl := d.ServiceList.Service
+ for i := 0; i < len(sl); i++ {
+ if sl[i].ServiceType == serviceType {
+ return &sl[i]
+ }
+ }
+ return nil
+}
+
+func getOurIP() (ip string, err error) {
+ hostname, err := os.Hostname()
+ if err != nil {
+ return
+ }
+ p, err := net.LookupIP(hostname)
+ if err != nil && len(p) > 0 {
+ return
+ }
+ return p[0].String(), nil
+}
+
+func getServiceURL(rootURL string) (url string, err error) {
+ r, err := http.Get(rootURL)
+ if err != nil {
+ return
+ }
+ defer r.Body.Close()
+ if r.StatusCode >= 400 {
+ err = errors.New(string(r.StatusCode))
+ return
+ }
+ var root root
+ err = xml.NewDecoder(r.Body).Decode(&root)
+
+ if err != nil {
+ return
+ }
+ a := &root.Device
+ if a.DeviceType != "urn:schemas-upnp-org:device:InternetGatewayDevice:1" {
+ err = errors.New("No InternetGatewayDevice")
+ return
+ }
+ b := getChildDevice(a, "urn:schemas-upnp-org:device:WANDevice:1")
+ if b == nil {
+ err = errors.New("No WANDevice")
+ return
+ }
+ c := getChildDevice(b, "urn:schemas-upnp-org:device:WANConnectionDevice:1")
+ if c == nil {
+ err = errors.New("No WANConnectionDevice")
+ return
+ }
+ d := getChildService(c, "urn:schemas-upnp-org:service:WANIPConnection:1")
+ if d == nil {
+ err = errors.New("No WANIPConnection")
+ return
+ }
+ url = combineURL(rootURL, d.ControlURL)
+ return
+}
+
+func combineURL(rootURL, subURL string) string {
+ protocolEnd := "://"
+ protoEndIndex := strings.Index(rootURL, protocolEnd)
+ a := rootURL[protoEndIndex+len(protocolEnd):]
+ rootIndex := strings.Index(a, "/")
+ return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL
+}
+
+func soapRequest(url, function, message string) (r *http.Response, err error) {
+ fullMessage := "<?xml version=\"1.0\" ?>" +
+ "<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\r\n" +
+ "<s:Body>" + message + "</s:Body></s:Envelope>"
+
+ req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage))
+ if err != nil {
+ return
+ }
+ req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"")
+ req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3")
+ //req.Header.Set("Transfer-Encoding", "chunked")
+ req.Header.Set("SOAPAction", "\"urn:schemas-upnp-org:service:WANIPConnection:1#"+function+"\"")
+ req.Header.Set("Connection", "Close")
+ req.Header.Set("Cache-Control", "no-cache")
+ req.Header.Set("Pragma", "no-cache")
+
+ r, err = http.DefaultClient.Do(req)
+ if err != nil {
+ return
+ }
+
+ if r.Body != nil {
+ defer r.Body.Close()
+ }
+
+ if r.StatusCode >= 400 {
+ // log.Stderr(function, r.StatusCode)
+ err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function)
+ r = nil
+ return
+ }
+ return
+}
diff --git a/p2p/peer.go b/p2p/peer.go
new file mode 100644
index 000000000..86c4d7ab5
--- /dev/null
+++ b/p2p/peer.go
@@ -0,0 +1,462 @@
+package p2p
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net"
+ "sort"
+ "sync"
+ "time"
+
+ "github.com/ethereum/go-ethereum/event"
+ "github.com/ethereum/go-ethereum/logger"
+)
+
+// peerAddr is the structure of a peer list element.
+// It is also a valid net.Addr.
+type peerAddr struct {
+ IP net.IP
+ Port uint64
+ Pubkey []byte // optional
+}
+
+func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr {
+ n := addr.Network()
+ if n != "tcp" && n != "tcp4" && n != "tcp6" {
+ // for testing with non-TCP
+ return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey}
+ }
+ ta := addr.(*net.TCPAddr)
+ return &peerAddr{ta.IP, uint64(ta.Port), pubkey}
+}
+
+func (d peerAddr) Network() string {
+ if d.IP.To4() != nil {
+ return "tcp4"
+ } else {
+ return "tcp6"
+ }
+}
+
+func (d peerAddr) String() string {
+ return fmt.Sprintf("%v:%d", d.IP, d.Port)
+}
+
+func (d peerAddr) RlpData() interface{} {
+ return []interface{}{d.IP, d.Port, d.Pubkey}
+}
+
+// Peer represents a remote peer.
+type Peer struct {
+ // Peers have all the log methods.
+ // Use them to display messages related to the peer.
+ *logger.Logger
+
+ infolock sync.Mutex
+ identity ClientIdentity
+ caps []Cap
+ listenAddr *peerAddr // what remote peer is listening on
+ dialAddr *peerAddr // non-nil if dialing
+
+ // The mutex protects the connection
+ // so only one protocol can write at a time.
+ writeMu sync.Mutex
+ conn net.Conn
+ bufconn *bufio.ReadWriter
+
+ // These fields maintain the running protocols.
+ protocols []Protocol
+ runBaseProtocol bool // for testing
+
+ runlock sync.RWMutex // protects running
+ running map[string]*proto
+
+ protoWG sync.WaitGroup
+ protoErr chan error
+ closed chan struct{}
+ disc chan DiscReason
+
+ activity event.TypeMux // for activity events
+
+ slot int // index into Server peer list
+
+ // These fields are kept so base protocol can access them.
+ // TODO: this should be one or more interfaces
+ ourID ClientIdentity // client id of the Server
+ ourListenAddr *peerAddr // listen addr of Server, nil if not listening
+ newPeerAddr chan<- *peerAddr // tell server about received peers
+ otherPeers func() []*Peer // should return the list of all peers
+ pubkeyHook func(*peerAddr) error // called at end of handshake to validate pubkey
+}
+
+// NewPeer returns a peer for testing purposes.
+func NewPeer(id ClientIdentity, caps []Cap) *Peer {
+ conn, _ := net.Pipe()
+ peer := newPeer(conn, nil, nil)
+ peer.setHandshakeInfo(id, nil, caps)
+ close(peer.closed)
+ return peer
+}
+
+func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
+ p := newPeer(conn, server.Protocols, dialAddr)
+ p.ourID = server.Identity
+ p.newPeerAddr = server.peerConnect
+ p.otherPeers = server.Peers
+ p.pubkeyHook = server.verifyPeer
+ p.runBaseProtocol = true
+
+ // laddr can be updated concurrently by NAT traversal.
+ // newServerPeer must be called with the server lock held.
+ if server.laddr != nil {
+ p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey())
+ }
+ return p
+}
+
+func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer {
+ p := &Peer{
+ Logger: logger.NewLogger("P2P " + conn.RemoteAddr().String()),
+ conn: conn,
+ dialAddr: dialAddr,
+ bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
+ protocols: protocols,
+ running: make(map[string]*proto),
+ disc: make(chan DiscReason),
+ protoErr: make(chan error),
+ closed: make(chan struct{}),
+ }
+ return p
+}
+
+// Identity returns the client identity of the remote peer. The
+// identity can be nil if the peer has not yet completed the
+// handshake.
+func (p *Peer) Identity() ClientIdentity {
+ p.infolock.Lock()
+ defer p.infolock.Unlock()
+ return p.identity
+}
+
+// Caps returns the capabilities (supported subprotocols) of the remote peer.
+func (p *Peer) Caps() []Cap {
+ p.infolock.Lock()
+ defer p.infolock.Unlock()
+ return p.caps
+}
+
+func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) {
+ p.infolock.Lock()
+ p.identity = id
+ p.listenAddr = laddr
+ p.caps = caps
+ p.infolock.Unlock()
+}
+
+// RemoteAddr returns the remote address of the network connection.
+func (p *Peer) RemoteAddr() net.Addr {
+ return p.conn.RemoteAddr()
+}
+
+// LocalAddr returns the local address of the network connection.
+func (p *Peer) LocalAddr() net.Addr {
+ return p.conn.LocalAddr()
+}
+
+// Disconnect terminates the peer connection with the given reason.
+// It returns immediately and does not wait until the connection is closed.
+func (p *Peer) Disconnect(reason DiscReason) {
+ select {
+ case p.disc <- reason:
+ case <-p.closed:
+ }
+}
+
+// String implements fmt.Stringer.
+func (p *Peer) String() string {
+ kind := "inbound"
+ p.infolock.Lock()
+ if p.dialAddr != nil {
+ kind = "outbound"
+ }
+ p.infolock.Unlock()
+ return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind)
+}
+
+const (
+ // maximum amount of time allowed for reading a message
+ msgReadTimeout = 5 * time.Second
+ // maximum amount of time allowed for writing a message
+ msgWriteTimeout = 5 * time.Second
+ // messages smaller than this many bytes will be read at
+ // once before passing them to a protocol.
+ wholePayloadSize = 64 * 1024
+)
+
+var (
+ inactivityTimeout = 2 * time.Second
+ disconnectGracePeriod = 2 * time.Second
+)
+
+func (p *Peer) loop() (reason DiscReason, err error) {
+ defer p.activity.Stop()
+ defer p.closeProtocols()
+ defer close(p.closed)
+ defer p.conn.Close()
+
+ // read loop
+ readMsg := make(chan Msg)
+ readErr := make(chan error)
+ readNext := make(chan bool, 1)
+ protoDone := make(chan struct{}, 1)
+ go p.readLoop(readMsg, readErr, readNext)
+ readNext <- true
+
+ if p.runBaseProtocol {
+ p.startBaseProtocol()
+ }
+
+loop:
+ for {
+ select {
+ case msg := <-readMsg:
+ // a new message has arrived.
+ var wait bool
+ if wait, err = p.dispatch(msg, protoDone); err != nil {
+ p.Errorf("msg dispatch error: %v\n", err)
+ reason = discReasonForError(err)
+ break loop
+ }
+ if !wait {
+ // Msg has already been read completely, continue with next message.
+ readNext <- true
+ }
+ p.activity.Post(time.Now())
+ case <-protoDone:
+ // protocol has consumed the message payload,
+ // we can continue reading from the socket.
+ readNext <- true
+
+ case err := <-readErr:
+ // read failed. there is no need to run the
+ // polite disconnect sequence because the connection
+ // is probably dead anyway.
+ // TODO: handle write errors as well
+ return DiscNetworkError, err
+ case err = <-p.protoErr:
+ reason = discReasonForError(err)
+ break loop
+ case reason = <-p.disc:
+ break loop
+ }
+ }
+
+ // wait for read loop to return.
+ close(readNext)
+ <-readErr
+ // tell the remote end to disconnect
+ done := make(chan struct{})
+ go func() {
+ p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod))
+ p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod)
+ io.Copy(ioutil.Discard, p.conn)
+ close(done)
+ }()
+ select {
+ case <-done:
+ case <-time.After(disconnectGracePeriod):
+ }
+ return reason, err
+}
+
+func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) {
+ for _ = range unblock {
+ p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout))
+ if msg, err := readMsg(p.bufconn); err != nil {
+ errc <- err
+ } else {
+ msgc <- msg
+ }
+ }
+ close(errc)
+}
+
+func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) {
+ proto, err := p.getProto(msg.Code)
+ if err != nil {
+ return false, err
+ }
+ if msg.Size <= wholePayloadSize {
+ // optimization: msg is small enough, read all
+ // of it and move on to the next message
+ buf, err := ioutil.ReadAll(msg.Payload)
+ if err != nil {
+ return false, err
+ }
+ msg.Payload = bytes.NewReader(buf)
+ proto.in <- msg
+ } else {
+ wait = true
+ pr := &eofSignal{msg.Payload, int64(msg.Size), protoDone}
+ msg.Payload = pr
+ proto.in <- msg
+ }
+ return wait, nil
+}
+
+func (p *Peer) startBaseProtocol() {
+ p.runlock.Lock()
+ defer p.runlock.Unlock()
+ p.running[""] = p.startProto(0, Protocol{
+ Length: baseProtocolLength,
+ Run: runBaseProtocol,
+ })
+}
+
+// startProtocols starts matching named subprotocols.
+func (p *Peer) startSubprotocols(caps []Cap) {
+ sort.Sort(capsByName(caps))
+
+ p.runlock.Lock()
+ defer p.runlock.Unlock()
+ offset := baseProtocolLength
+outer:
+ for _, cap := range caps {
+ for _, proto := range p.protocols {
+ if proto.Name == cap.Name &&
+ proto.Version == cap.Version &&
+ p.running[cap.Name] == nil {
+ p.running[cap.Name] = p.startProto(offset, proto)
+ offset += proto.Length
+ continue outer
+ }
+ }
+ }
+}
+
+func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
+ rw := &proto{
+ in: make(chan Msg),
+ offset: offset,
+ maxcode: impl.Length,
+ peer: p,
+ }
+ p.protoWG.Add(1)
+ go func() {
+ err := impl.Run(p, rw)
+ if err == nil {
+ p.Infof("protocol %q returned", impl.Name)
+ err = newPeerError(errMisc, "protocol returned")
+ } else {
+ p.Errorf("protocol %q error: %v\n", impl.Name, err)
+ }
+ select {
+ case p.protoErr <- err:
+ case <-p.closed:
+ }
+ p.protoWG.Done()
+ }()
+ return rw
+}
+
+// getProto finds the protocol responsible for handling
+// the given message code.
+func (p *Peer) getProto(code uint64) (*proto, error) {
+ p.runlock.RLock()
+ defer p.runlock.RUnlock()
+ for _, proto := range p.running {
+ if code >= proto.offset && code < proto.offset+proto.maxcode {
+ return proto, nil
+ }
+ }
+ return nil, newPeerError(errInvalidMsgCode, "%d", code)
+}
+
+func (p *Peer) closeProtocols() {
+ p.runlock.RLock()
+ for _, p := range p.running {
+ close(p.in)
+ }
+ p.runlock.RUnlock()
+ p.protoWG.Wait()
+}
+
+// writeProtoMsg sends the given message on behalf of the given named protocol.
+func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
+ p.runlock.RLock()
+ proto, ok := p.running[protoName]
+ p.runlock.RUnlock()
+ if !ok {
+ return fmt.Errorf("protocol %s not handled by peer", protoName)
+ }
+ if msg.Code >= proto.maxcode {
+ return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
+ }
+ msg.Code += proto.offset
+ return p.writeMsg(msg, msgWriteTimeout)
+}
+
+// writeMsg writes a message to the connection.
+func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error {
+ p.writeMu.Lock()
+ defer p.writeMu.Unlock()
+ p.conn.SetWriteDeadline(time.Now().Add(timeout))
+ if err := writeMsg(p.bufconn, msg); err != nil {
+ return newPeerError(errWrite, "%v", err)
+ }
+ return p.bufconn.Flush()
+}
+
+type proto struct {
+ name string
+ in chan Msg
+ maxcode, offset uint64
+ peer *Peer
+}
+
+func (rw *proto) WriteMsg(msg Msg) error {
+ if msg.Code >= rw.maxcode {
+ return newPeerError(errInvalidMsgCode, "not handled")
+ }
+ msg.Code += rw.offset
+ return rw.peer.writeMsg(msg, msgWriteTimeout)
+}
+
+func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error {
+ return rw.WriteMsg(NewMsg(code, data))
+}
+
+func (rw *proto) ReadMsg() (Msg, error) {
+ msg, ok := <-rw.in
+ if !ok {
+ return msg, io.EOF
+ }
+ msg.Code -= rw.offset
+ return msg, nil
+}
+
+// eofSignal wraps a reader with eof signaling. the eof channel is
+// closed when the wrapped reader returns an error or when count bytes
+// have been read.
+//
+type eofSignal struct {
+ wrapped io.Reader
+ count int64
+ eof chan<- struct{}
+}
+
+// note: when using eofSignal to detect whether a message payload
+// has been read, Read might not be called for zero sized messages.
+
+func (r *eofSignal) Read(buf []byte) (int, error) {
+ n, err := r.wrapped.Read(buf)
+ r.count -= int64(n)
+ if (err != nil || r.count <= 0) && r.eof != nil {
+ r.eof <- struct{}{} // tell Peer that msg has been consumed
+ r.eof = nil
+ }
+ return n, err
+}
diff --git a/p2p/peer_error.go b/p2p/peer_error.go
new file mode 100644
index 000000000..0eb7ec838
--- /dev/null
+++ b/p2p/peer_error.go
@@ -0,0 +1,133 @@
+package p2p
+
+import (
+ "fmt"
+)
+
+const (
+ errMagicTokenMismatch = iota
+ errRead
+ errWrite
+ errMisc
+ errInvalidMsgCode
+ errInvalidMsg
+ errP2PVersionMismatch
+ errPubkeyMissing
+ errPubkeyInvalid
+ errPubkeyForbidden
+ errProtocolBreach
+ errPingTimeout
+ errInvalidNetworkId
+ errInvalidProtocolVersion
+)
+
+var errorToString = map[int]string{
+ errMagicTokenMismatch: "Magic token mismatch",
+ errRead: "Read error",
+ errWrite: "Write error",
+ errMisc: "Misc error",
+ errInvalidMsgCode: "Invalid message code",
+ errInvalidMsg: "Invalid message",
+ errP2PVersionMismatch: "P2P Version Mismatch",
+ errPubkeyMissing: "Public key missing",
+ errPubkeyInvalid: "Public key invalid",
+ errPubkeyForbidden: "Public key forbidden",
+ errProtocolBreach: "Protocol Breach",
+ errPingTimeout: "Ping timeout",
+ errInvalidNetworkId: "Invalid network id",
+ errInvalidProtocolVersion: "Invalid protocol version",
+}
+
+type peerError struct {
+ Code int
+ message string
+}
+
+func newPeerError(code int, format string, v ...interface{}) *peerError {
+ desc, ok := errorToString[code]
+ if !ok {
+ panic("invalid error code")
+ }
+ err := &peerError{code, desc}
+ if format != "" {
+ err.message += ": " + fmt.Sprintf(format, v...)
+ }
+ return err
+}
+
+func (self *peerError) Error() string {
+ return self.message
+}
+
+type DiscReason byte
+
+const (
+ DiscRequested DiscReason = 0x00
+ DiscNetworkError = 0x01
+ DiscProtocolError = 0x02
+ DiscUselessPeer = 0x03
+ DiscTooManyPeers = 0x04
+ DiscAlreadyConnected = 0x05
+ DiscIncompatibleVersion = 0x06
+ DiscInvalidIdentity = 0x07
+ DiscQuitting = 0x08
+ DiscUnexpectedIdentity = 0x09
+ DiscSelf = 0x0a
+ DiscReadTimeout = 0x0b
+ DiscSubprotocolError = 0x10
+)
+
+var discReasonToString = [DiscSubprotocolError + 1]string{
+ DiscRequested: "Disconnect requested",
+ DiscNetworkError: "Network error",
+ DiscProtocolError: "Breach of protocol",
+ DiscUselessPeer: "Useless peer",
+ DiscTooManyPeers: "Too many peers",
+ DiscAlreadyConnected: "Already connected",
+ DiscIncompatibleVersion: "Incompatible P2P protocol version",
+ DiscInvalidIdentity: "Invalid node identity",
+ DiscQuitting: "Client quitting",
+ DiscUnexpectedIdentity: "Unexpected identity",
+ DiscSelf: "Connected to self",
+ DiscReadTimeout: "Read timeout",
+ DiscSubprotocolError: "Subprotocol error",
+}
+
+func (d DiscReason) String() string {
+ if len(discReasonToString) < int(d) {
+ return fmt.Sprintf("Unknown Reason(%d)", d)
+ }
+ return discReasonToString[d]
+}
+
+type discRequestedError DiscReason
+
+func (err discRequestedError) Error() string {
+ return fmt.Sprintf("disconnect requested: %v", DiscReason(err))
+}
+
+func discReasonForError(err error) DiscReason {
+ if reason, ok := err.(discRequestedError); ok {
+ return DiscReason(reason)
+ }
+ peerError, ok := err.(*peerError)
+ if !ok {
+ return DiscSubprotocolError
+ }
+ switch peerError.Code {
+ case errP2PVersionMismatch:
+ return DiscIncompatibleVersion
+ case errPubkeyMissing, errPubkeyInvalid:
+ return DiscInvalidIdentity
+ case errPubkeyForbidden:
+ return DiscUselessPeer
+ case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach:
+ return DiscProtocolError
+ case errPingTimeout:
+ return DiscReadTimeout
+ case errRead, errWrite, errMisc:
+ return DiscNetworkError
+ default:
+ return DiscSubprotocolError
+ }
+}
diff --git a/p2p/peer_test.go b/p2p/peer_test.go
new file mode 100644
index 000000000..f7759786e
--- /dev/null
+++ b/p2p/peer_test.go
@@ -0,0 +1,295 @@
+package p2p
+
+import (
+ "bufio"
+ "bytes"
+ "encoding/hex"
+ "io"
+ "io/ioutil"
+ "net"
+ "reflect"
+ "testing"
+ "time"
+)
+
+var discard = Protocol{
+ Name: "discard",
+ Length: 1,
+ Run: func(p *Peer, rw MsgReadWriter) error {
+ for {
+ msg, err := rw.ReadMsg()
+ if err != nil {
+ return err
+ }
+ if err = msg.Discard(); err != nil {
+ return err
+ }
+ }
+ },
+}
+
+func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) {
+ conn1, conn2 := net.Pipe()
+ id := NewSimpleClientIdentity("test", "0", "0", "public key")
+ peer := newPeer(conn1, protos, nil)
+ peer.ourID = id
+ peer.pubkeyHook = func(*peerAddr) error { return nil }
+ errc := make(chan error, 1)
+ go func() {
+ _, err := peer.loop()
+ errc <- err
+ }()
+ return conn2, peer, errc
+}
+
+func TestPeerProtoReadMsg(t *testing.T) {
+ defer testlog(t).detach()
+
+ done := make(chan struct{})
+ proto := Protocol{
+ Name: "a",
+ Length: 5,
+ Run: func(peer *Peer, rw MsgReadWriter) error {
+ msg, err := rw.ReadMsg()
+ if err != nil {
+ t.Errorf("read error: %v", err)
+ }
+ if msg.Code != 2 {
+ t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
+ }
+ data, err := ioutil.ReadAll(msg.Payload)
+ if err != nil {
+ t.Errorf("payload read error: %v", err)
+ }
+ expdata, _ := hex.DecodeString("0183303030")
+ if !bytes.Equal(expdata, data) {
+ t.Errorf("incorrect msg data %x", data)
+ }
+ close(done)
+ return nil
+ },
+ }
+
+ net, peer, errc := testPeer([]Protocol{proto})
+ defer net.Close()
+ peer.startSubprotocols([]Cap{proto.cap()})
+
+ writeMsg(net, NewMsg(18, 1, "000"))
+ select {
+ case <-done:
+ case err := <-errc:
+ t.Errorf("peer returned: %v", err)
+ case <-time.After(2 * time.Second):
+ t.Errorf("receive timeout")
+ }
+}
+
+func TestPeerProtoReadLargeMsg(t *testing.T) {
+ defer testlog(t).detach()
+
+ msgsize := uint32(10 * 1024 * 1024)
+ done := make(chan struct{})
+ proto := Protocol{
+ Name: "a",
+ Length: 5,
+ Run: func(peer *Peer, rw MsgReadWriter) error {
+ msg, err := rw.ReadMsg()
+ if err != nil {
+ t.Errorf("read error: %v", err)
+ }
+ if msg.Size != msgsize+4 {
+ t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize)
+ }
+ msg.Discard()
+ close(done)
+ return nil
+ },
+ }
+
+ net, peer, errc := testPeer([]Protocol{proto})
+ defer net.Close()
+ peer.startSubprotocols([]Cap{proto.cap()})
+
+ writeMsg(net, NewMsg(18, make([]byte, msgsize)))
+ select {
+ case <-done:
+ case err := <-errc:
+ t.Errorf("peer returned: %v", err)
+ case <-time.After(2 * time.Second):
+ t.Errorf("receive timeout")
+ }
+}
+
+func TestPeerProtoEncodeMsg(t *testing.T) {
+ defer testlog(t).detach()
+
+ proto := Protocol{
+ Name: "a",
+ Length: 2,
+ Run: func(peer *Peer, rw MsgReadWriter) error {
+ if err := rw.EncodeMsg(2); err == nil {
+ t.Error("expected error for out-of-range msg code, got nil")
+ }
+ if err := rw.EncodeMsg(1); err != nil {
+ t.Errorf("write error: %v", err)
+ }
+ return nil
+ },
+ }
+ net, peer, _ := testPeer([]Protocol{proto})
+ defer net.Close()
+ peer.startSubprotocols([]Cap{proto.cap()})
+
+ bufr := bufio.NewReader(net)
+ msg, err := readMsg(bufr)
+ if err != nil {
+ t.Errorf("read error: %v", err)
+ }
+ if msg.Code != 17 {
+ t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
+ }
+}
+
+func TestPeerWrite(t *testing.T) {
+ defer testlog(t).detach()
+
+ net, peer, peerErr := testPeer([]Protocol{discard})
+ defer net.Close()
+ peer.startSubprotocols([]Cap{discard.cap()})
+
+ // test write errors
+ if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
+ t.Errorf("expected error for unknown protocol, got nil")
+ }
+ if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil {
+ t.Errorf("expected error for out-of-range msg code, got nil")
+ } else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
+ t.Errorf("wrong error for out-of-range msg code, got %#v", err)
+ }
+
+ // setup for reading the message on the other end
+ read := make(chan struct{})
+ go func() {
+ bufr := bufio.NewReader(net)
+ msg, err := readMsg(bufr)
+ if err != nil {
+ t.Errorf("read error: %v", err)
+ } else if msg.Code != 16 {
+ t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
+ }
+ msg.Discard()
+ close(read)
+ }()
+
+ // test succcessful write
+ if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
+ t.Errorf("expect no error for known protocol: %v", err)
+ }
+ select {
+ case <-read:
+ case err := <-peerErr:
+ t.Fatalf("peer stopped: %v", err)
+ }
+}
+
+func TestPeerActivity(t *testing.T) {
+ // shorten inactivityTimeout while this test is running
+ oldT := inactivityTimeout
+ defer func() { inactivityTimeout = oldT }()
+ inactivityTimeout = 20 * time.Millisecond
+
+ net, peer, peerErr := testPeer([]Protocol{discard})
+ defer net.Close()
+ peer.startSubprotocols([]Cap{discard.cap()})
+
+ sub := peer.activity.Subscribe(time.Time{})
+ defer sub.Unsubscribe()
+
+ for i := 0; i < 6; i++ {
+ writeMsg(net, NewMsg(16))
+ select {
+ case <-sub.Chan():
+ case <-time.After(inactivityTimeout / 2):
+ t.Fatal("no event within ", inactivityTimeout/2)
+ case err := <-peerErr:
+ t.Fatal("peer error", err)
+ }
+ }
+
+ select {
+ case <-time.After(inactivityTimeout * 2):
+ case <-sub.Chan():
+ t.Fatal("got activity event while connection was inactive")
+ case err := <-peerErr:
+ t.Fatal("peer error", err)
+ }
+}
+
+func TestNewPeer(t *testing.T) {
+ id := NewSimpleClientIdentity("clientid", "version", "customid", "pubkey")
+ caps := []Cap{{"foo", 2}, {"bar", 3}}
+ p := NewPeer(id, caps)
+ if !reflect.DeepEqual(p.Caps(), caps) {
+ t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps)
+ }
+ if p.Identity() != id {
+ t.Errorf("Identity mismatch: got %v, expected %v", p.Identity(), id)
+ }
+ // Should not hang.
+ p.Disconnect(DiscAlreadyConnected)
+}
+
+func TestEOFSignal(t *testing.T) {
+ rb := make([]byte, 10)
+
+ // empty reader
+ eof := make(chan struct{}, 1)
+ sig := &eofSignal{new(bytes.Buffer), 0, eof}
+ if n, err := sig.Read(rb); n != 0 || err != io.EOF {
+ t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+ }
+ select {
+ case <-eof:
+ default:
+ t.Error("EOF chan not signaled")
+ }
+
+ // count before error
+ eof = make(chan struct{}, 1)
+ sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
+ if n, err := sig.Read(rb); n != 8 || err != nil {
+ t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+ }
+ select {
+ case <-eof:
+ default:
+ t.Error("EOF chan not signaled")
+ }
+
+ // error before count
+ eof = make(chan struct{}, 1)
+ sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
+ if n, err := sig.Read(rb); n != 4 || err != nil {
+ t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+ }
+ if n, err := sig.Read(rb); n != 0 || err != io.EOF {
+ t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+ }
+ select {
+ case <-eof:
+ default:
+ t.Error("EOF chan not signaled")
+ }
+
+ // no signal if neither occurs
+ eof = make(chan struct{}, 1)
+ sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
+ if n, err := sig.Read(rb); n != 10 || err != nil {
+ t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+ }
+ select {
+ case <-eof:
+ t.Error("unexpected EOF signal")
+ default:
+ }
+}
diff --git a/p2p/protocol.go b/p2p/protocol.go
new file mode 100644
index 000000000..3f52205f5
--- /dev/null
+++ b/p2p/protocol.go
@@ -0,0 +1,294 @@
+package p2p
+
+import (
+ "bytes"
+ "time"
+
+ "github.com/ethereum/go-ethereum/ethutil"
+)
+
+// Protocol represents a P2P subprotocol implementation.
+type Protocol struct {
+ // Name should contain the official protocol name,
+ // often a three-letter word.
+ Name string
+
+ // Version should contain the version number of the protocol.
+ Version uint
+
+ // Length should contain the number of message codes used
+ // by the protocol.
+ Length uint64
+
+ // Run is called in a new groutine when the protocol has been
+ // negotiated with a peer. It should read and write messages from
+ // rw. The Payload for each message must be fully consumed.
+ //
+ // The peer connection is closed when Start returns. It should return
+ // any protocol-level error (such as an I/O error) that is
+ // encountered.
+ Run func(peer *Peer, rw MsgReadWriter) error
+}
+
+func (p Protocol) cap() Cap {
+ return Cap{p.Name, p.Version}
+}
+
+const (
+ baseProtocolVersion = 2
+ baseProtocolLength = uint64(16)
+ baseProtocolMaxMsgSize = 10 * 1024 * 1024
+)
+
+const (
+ // devp2p message codes
+ handshakeMsg = 0x00
+ discMsg = 0x01
+ pingMsg = 0x02
+ pongMsg = 0x03
+ getPeersMsg = 0x04
+ peersMsg = 0x05
+)
+
+// handshake is the structure of a handshake list.
+type handshake struct {
+ Version uint64
+ ID string
+ Caps []Cap
+ ListenPort uint64
+ NodeID []byte
+}
+
+func (h *handshake) String() string {
+ return h.ID
+}
+func (h *handshake) Pubkey() []byte {
+ return h.NodeID
+}
+
+// Cap is the structure of a peer capability.
+type Cap struct {
+ Name string
+ Version uint
+}
+
+func (cap Cap) RlpData() interface{} {
+ return []interface{}{cap.Name, cap.Version}
+}
+
+type capsByName []Cap
+
+func (cs capsByName) Len() int { return len(cs) }
+func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name }
+func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] }
+
+type baseProtocol struct {
+ rw MsgReadWriter
+ peer *Peer
+}
+
+func runBaseProtocol(peer *Peer, rw MsgReadWriter) error {
+ bp := &baseProtocol{rw, peer}
+ if err := bp.doHandshake(rw); err != nil {
+ return err
+ }
+ // run main loop
+ quit := make(chan error, 1)
+ go func() {
+ for {
+ if err := bp.handle(rw); err != nil {
+ quit <- err
+ break
+ }
+ }
+ }()
+ return bp.loop(quit)
+}
+
+var pingTimeout = 2 * time.Second
+
+func (bp *baseProtocol) loop(quit <-chan error) error {
+ ping := time.NewTimer(pingTimeout)
+ activity := bp.peer.activity.Subscribe(time.Time{})
+ lastActive := time.Time{}
+ defer ping.Stop()
+ defer activity.Unsubscribe()
+
+ getPeersTick := time.NewTicker(10 * time.Second)
+ defer getPeersTick.Stop()
+ err := bp.rw.EncodeMsg(getPeersMsg)
+
+ for err == nil {
+ select {
+ case err = <-quit:
+ return err
+ case <-getPeersTick.C:
+ err = bp.rw.EncodeMsg(getPeersMsg)
+ case event := <-activity.Chan():
+ ping.Reset(pingTimeout)
+ lastActive = event.(time.Time)
+ case t := <-ping.C:
+ if lastActive.Add(pingTimeout * 2).Before(t) {
+ err = newPeerError(errPingTimeout, "")
+ } else if lastActive.Add(pingTimeout).Before(t) {
+ err = bp.rw.EncodeMsg(pingMsg)
+ }
+ }
+ }
+ return err
+}
+
+func (bp *baseProtocol) handle(rw MsgReadWriter) error {
+ msg, err := rw.ReadMsg()
+ if err != nil {
+ return err
+ }
+ if msg.Size > baseProtocolMaxMsgSize {
+ return newPeerError(errMisc, "message too big")
+ }
+ // make sure that the payload has been fully consumed
+ defer msg.Discard()
+
+ switch msg.Code {
+ case handshakeMsg:
+ return newPeerError(errProtocolBreach, "extra handshake received")
+
+ case discMsg:
+ var reason [1]DiscReason
+ if err := msg.Decode(&reason); err != nil {
+ return err
+ }
+ return discRequestedError(reason[0])
+
+ case pingMsg:
+ return bp.rw.EncodeMsg(pongMsg)
+
+ case pongMsg:
+
+ case getPeersMsg:
+ peers := bp.peerList()
+ // this is dangerous. the spec says that we should _delay_
+ // sending the response if no new information is available.
+ // this means that would need to send a response later when
+ // new peers become available.
+ //
+ // TODO: add event mechanism to notify baseProtocol for new peers
+ if len(peers) > 0 {
+ return bp.rw.EncodeMsg(peersMsg, peers)
+ }
+
+ case peersMsg:
+ var peers []*peerAddr
+ if err := msg.Decode(&peers); err != nil {
+ return err
+ }
+ for _, addr := range peers {
+ bp.peer.Debugf("received peer suggestion: %v", addr)
+ bp.peer.newPeerAddr <- addr
+ }
+
+ default:
+ return newPeerError(errInvalidMsgCode, "unknown message code %v", msg.Code)
+ }
+ return nil
+}
+
+func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error {
+ // send our handshake
+ if err := rw.WriteMsg(bp.handshakeMsg()); err != nil {
+ return err
+ }
+
+ // read and handle remote handshake
+ msg, err := rw.ReadMsg()
+ if err != nil {
+ return err
+ }
+ if msg.Code != handshakeMsg {
+ return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code)
+ }
+ if msg.Size > baseProtocolMaxMsgSize {
+ return newPeerError(errMisc, "message too big")
+ }
+
+ var hs handshake
+ if err := msg.Decode(&hs); err != nil {
+ return err
+ }
+
+ // validate handshake info
+ if hs.Version != baseProtocolVersion {
+ return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n",
+ baseProtocolVersion, hs.Version)
+ }
+ if len(hs.NodeID) == 0 {
+ return newPeerError(errPubkeyMissing, "")
+ }
+ if len(hs.NodeID) != 64 {
+ return newPeerError(errPubkeyInvalid, "require 512 bit, got %v", len(hs.NodeID)*8)
+ }
+ if da := bp.peer.dialAddr; da != nil {
+ // verify that the peer we wanted to connect to
+ // actually holds the target public key.
+ if da.Pubkey != nil && !bytes.Equal(da.Pubkey, hs.NodeID) {
+ return newPeerError(errPubkeyForbidden, "dial address pubkey mismatch")
+ }
+ }
+ pa := newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
+ if err := bp.peer.pubkeyHook(pa); err != nil {
+ return newPeerError(errPubkeyForbidden, "%v", err)
+ }
+
+ // TODO: remove Caps with empty name
+
+ var addr *peerAddr
+ if hs.ListenPort != 0 {
+ addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
+ addr.Port = hs.ListenPort
+ }
+ bp.peer.setHandshakeInfo(&hs, addr, hs.Caps)
+ bp.peer.startSubprotocols(hs.Caps)
+ return nil
+}
+
+func (bp *baseProtocol) handshakeMsg() Msg {
+ var (
+ port uint64
+ caps []interface{}
+ )
+ if bp.peer.ourListenAddr != nil {
+ port = bp.peer.ourListenAddr.Port
+ }
+ for _, proto := range bp.peer.protocols {
+ caps = append(caps, proto.cap())
+ }
+ return NewMsg(handshakeMsg,
+ baseProtocolVersion,
+ bp.peer.ourID.String(),
+ caps,
+ port,
+ bp.peer.ourID.Pubkey()[1:],
+ )
+}
+
+func (bp *baseProtocol) peerList() []ethutil.RlpEncodable {
+ peers := bp.peer.otherPeers()
+ ds := make([]ethutil.RlpEncodable, 0, len(peers))
+ for _, p := range peers {
+ p.infolock.Lock()
+ addr := p.listenAddr
+ p.infolock.Unlock()
+ // filter out this peer and peers that are not listening or
+ // have not completed the handshake.
+ // TODO: track previously sent peers and exclude them as well.
+ if p == bp.peer || addr == nil {
+ continue
+ }
+ ds = append(ds, addr)
+ }
+ ourAddr := bp.peer.ourListenAddr
+ if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() {
+ ds = append(ds, ourAddr)
+ }
+ return ds
+}
diff --git a/p2p/protocol_test.go b/p2p/protocol_test.go
new file mode 100644
index 000000000..65f26fb12
--- /dev/null
+++ b/p2p/protocol_test.go
@@ -0,0 +1,58 @@
+package p2p
+
+import (
+ "fmt"
+ "testing"
+)
+
+func TestBaseProtocolDisconnect(t *testing.T) {
+ peer := NewPeer(NewSimpleClientIdentity("p1", "", "", "foo"), nil)
+ peer.ourID = NewSimpleClientIdentity("p2", "", "", "bar")
+ peer.pubkeyHook = func(*peerAddr) error { return nil }
+
+ rw1, rw2 := MsgPipe()
+ done := make(chan struct{})
+ go func() {
+ if err := expectMsg(rw2, handshakeMsg); err != nil {
+ t.Error(err)
+ }
+ err := rw2.EncodeMsg(handshakeMsg,
+ baseProtocolVersion,
+ "",
+ []interface{}{},
+ 0,
+ make([]byte, 64),
+ )
+ if err != nil {
+ t.Error(err)
+ }
+ if err := expectMsg(rw2, getPeersMsg); err != nil {
+ t.Error(err)
+ }
+ if err := rw2.EncodeMsg(discMsg, DiscQuitting); err != nil {
+ t.Error(err)
+ }
+ close(done)
+ }()
+
+ if err := runBaseProtocol(peer, rw1); err == nil {
+ t.Errorf("base protocol returned without error")
+ } else if reason, ok := err.(discRequestedError); !ok || reason != DiscQuitting {
+ t.Errorf("base protocol returned wrong error: %v", err)
+ }
+ <-done
+}
+
+func expectMsg(r MsgReader, code uint64) error {
+ msg, err := r.ReadMsg()
+ if err != nil {
+ return err
+ }
+ if err := msg.Discard(); err != nil {
+ return err
+ }
+ if msg.Code != code {
+ return fmt.Errorf("wrong message code: got %d, expected %d", msg.Code, code)
+ }
+ return nil
+}
diff --git a/p2p/server.go b/p2p/server.go
new file mode 100644
index 000000000..8a6087566
--- /dev/null
+++ b/p2p/server.go
@@ -0,0 +1,467 @@
+package p2p
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "net"
+ "sync"
+ "time"
+
+ "github.com/ethereum/go-ethereum/logger"
+)
+
+const (
+ outboundAddressPoolSize = 500
+ defaultDialTimeout = 10 * time.Second
+ portMappingUpdateInterval = 15 * time.Minute
+ portMappingTimeout = 20 * time.Minute
+)
+
+var srvlog = logger.NewLogger("P2P Server")
+
+// Server manages all peer connections.
+//
+// The fields of Server are used as configuration parameters.
+// You should set them before starting the Server. Fields may not be
+// modified while the server is running.
+type Server struct {
+ // This field must be set to a valid client identity.
+ Identity ClientIdentity
+
+ // MaxPeers is the maximum number of peers that can be
+ // connected. It must be greater than zero.
+ MaxPeers int
+
+ // Protocols should contain the protocols supported
+ // by the server. Matching protocols are launched for
+ // each peer.
+ Protocols []Protocol
+
+ // If Blacklist is set to a non-nil value, the given Blacklist
+ // is used to verify peer connections.
+ Blacklist Blacklist
+
+ // If ListenAddr is set to a non-nil address, the server
+ // will listen for incoming connections.
+ //
+ // If the port is zero, the operating system will pick a port. The
+ // ListenAddr field will be updated with the actual address when
+ // the server is started.
+ ListenAddr string
+
+ // If set to a non-nil value, the given NAT port mapper
+ // is used to make the listening port available to the
+ // Internet.
+ NAT NAT
+
+ // If Dialer is set to a non-nil value, the given Dialer
+ // is used to dial outbound peer connections.
+ Dialer *net.Dialer
+
+ // If NoDial is true, the server will not dial any peers.
+ NoDial bool
+
+ // Hook for testing. This is useful because we can inhibit
+ // the whole protocol stack.
+ newPeerFunc peerFunc
+
+ lock sync.RWMutex
+ running bool
+ listener net.Listener
+ laddr *net.TCPAddr // real listen addr
+ peers []*Peer
+ peerSlots chan int
+ peerCount int
+
+ quit chan struct{}
+ wg sync.WaitGroup
+ peerConnect chan *peerAddr
+ peerDisconnect chan *Peer
+}
+
+// NAT is implemented by NAT traversal methods.
+type NAT interface {
+ GetExternalAddress() (net.IP, error)
+ AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error
+ DeletePortMapping(protocol string, extport, intport int) error
+
+ // Should return name of the method.
+ String() string
+}
+
+type peerFunc func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer
+
+// Peers returns all connected peers.
+func (srv *Server) Peers() (peers []*Peer) {
+ srv.lock.RLock()
+ defer srv.lock.RUnlock()
+ for _, peer := range srv.peers {
+ if peer != nil {
+ peers = append(peers, peer)
+ }
+ }
+ return
+}
+
+// PeerCount returns the number of connected peers.
+func (srv *Server) PeerCount() int {
+ srv.lock.RLock()
+ defer srv.lock.RUnlock()
+ return srv.peerCount
+}
+
+// SuggestPeer injects an address into the outbound address pool.
+func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) {
+ select {
+ case srv.peerConnect <- &peerAddr{ip, uint64(port), nodeID}:
+ default: // don't block
+ }
+}
+
+// Broadcast sends an RLP-encoded message to all connected peers.
+// This method is deprecated and will be removed later.
+func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{}) {
+ var payload []byte
+ if data != nil {
+ payload = encodePayload(data...)
+ }
+ srv.lock.RLock()
+ defer srv.lock.RUnlock()
+ for _, peer := range srv.peers {
+ if peer != nil {
+ var msg = Msg{Code: code}
+ if data != nil {
+ msg.Payload = bytes.NewReader(payload)
+ msg.Size = uint32(len(payload))
+ }
+ peer.writeProtoMsg(protocol, msg)
+ }
+ }
+}
+
+// Start starts running the server.
+// Servers can be re-used and started again after stopping.
+func (srv *Server) Start() (err error) {
+ srv.lock.Lock()
+ defer srv.lock.Unlock()
+ if srv.running {
+ return errors.New("server already running")
+ }
+ srvlog.Infoln("Starting Server")
+
+ // initialize fields
+ if srv.Identity == nil {
+ return fmt.Errorf("Server.Identity must be set to a non-nil identity")
+ }
+ if srv.MaxPeers <= 0 {
+ return fmt.Errorf("Server.MaxPeers must be > 0")
+ }
+ srv.quit = make(chan struct{})
+ srv.peers = make([]*Peer, srv.MaxPeers)
+ srv.peerSlots = make(chan int, srv.MaxPeers)
+ srv.peerConnect = make(chan *peerAddr, outboundAddressPoolSize)
+ srv.peerDisconnect = make(chan *Peer)
+ if srv.newPeerFunc == nil {
+ srv.newPeerFunc = newServerPeer
+ }
+ if srv.Blacklist == nil {
+ srv.Blacklist = NewBlacklist()
+ }
+ if srv.Dialer == nil {
+ srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
+ }
+
+ if srv.ListenAddr != "" {
+ if err := srv.startListening(); err != nil {
+ return err
+ }
+ }
+ if !srv.NoDial {
+ srv.wg.Add(1)
+ go srv.dialLoop()
+ }
+ if srv.NoDial && srv.ListenAddr == "" {
+ srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.")
+ }
+
+ // make all slots available
+ for i := range srv.peers {
+ srv.peerSlots <- i
+ }
+ // note: discLoop is not part of WaitGroup
+ go srv.discLoop()
+ srv.running = true
+ return nil
+}
+
+func (srv *Server) startListening() error {
+ listener, err := net.Listen("tcp", srv.ListenAddr)
+ if err != nil {
+ return err
+ }
+ srv.ListenAddr = listener.Addr().String()
+ srv.laddr = listener.Addr().(*net.TCPAddr)
+ srv.listener = listener
+ srv.wg.Add(1)
+ go srv.listenLoop()
+ if !srv.laddr.IP.IsLoopback() && srv.NAT != nil {
+ srv.wg.Add(1)
+ go srv.natLoop(srv.laddr.Port)
+ }
+ return nil
+}
+
+// Stop terminates the server and all active peer connections.
+// It blocks until all active connections have been closed.
+func (srv *Server) Stop() {
+ srv.lock.Lock()
+ if !srv.running {
+ srv.lock.Unlock()
+ return
+ }
+ srv.running = false
+ srv.lock.Unlock()
+
+ srvlog.Infoln("Stopping server")
+ if srv.listener != nil {
+ // this unblocks listener Accept
+ srv.listener.Close()
+ }
+ close(srv.quit)
+ for _, peer := range srv.Peers() {
+ peer.Disconnect(DiscQuitting)
+ }
+ srv.wg.Wait()
+
+ // wait till they actually disconnect
+ // this is checked by claiming all peerSlots.
+ // slots become available as the peers disconnect.
+ for i := 0; i < cap(srv.peerSlots); i++ {
+ <-srv.peerSlots
+ }
+ // terminate discLoop
+ close(srv.peerDisconnect)
+}
+
+func (srv *Server) discLoop() {
+ for peer := range srv.peerDisconnect {
+ // peer has just disconnected. free up its slot.
+ srvlog.Infof("%v is gone", peer)
+ srv.peerSlots <- peer.slot
+ srv.lock.Lock()
+ srv.peers[peer.slot] = nil
+ srv.lock.Unlock()
+ }
+}
+
+// main loop for adding connections via listening
+func (srv *Server) listenLoop() {
+ defer srv.wg.Done()
+
+ srvlog.Infoln("Listening on", srv.listener.Addr())
+ for {
+ select {
+ case slot := <-srv.peerSlots:
+ conn, err := srv.listener.Accept()
+ if err != nil {
+ srv.peerSlots <- slot
+ return
+ }
+ srvlog.Debugf("Accepted conn %v (slot %d)\n", conn.RemoteAddr(), slot)
+ srv.addPeer(conn, nil, slot)
+ case <-srv.quit:
+ return
+ }
+ }
+}
+
+func (srv *Server) natLoop(port int) {
+ defer srv.wg.Done()
+ for {
+ srv.updatePortMapping(port)
+ select {
+ case <-time.After(portMappingUpdateInterval):
+ // one more round
+ case <-srv.quit:
+ srv.removePortMapping(port)
+ return
+ }
+ }
+}
+
+func (srv *Server) updatePortMapping(port int) {
+ srvlog.Infoln("Attempting to map port", port, "with", srv.NAT)
+ err := srv.NAT.AddPortMapping("tcp", port, port, "ethereum p2p", portMappingTimeout)
+ if err != nil {
+ srvlog.Errorln("Port mapping error:", err)
+ return
+ }
+ extip, err := srv.NAT.GetExternalAddress()
+ if err != nil {
+ srvlog.Errorln("Error getting external IP:", err)
+ return
+ }
+ srv.lock.Lock()
+ extaddr := *(srv.listener.Addr().(*net.TCPAddr))
+ extaddr.IP = extip
+ srvlog.Infoln("Mapped port, external addr is", &extaddr)
+ srv.laddr = &extaddr
+ srv.lock.Unlock()
+}
+
+func (srv *Server) removePortMapping(port int) {
+ srvlog.Infoln("Removing port mapping for", port, "with", srv.NAT)
+ srv.NAT.DeletePortMapping("tcp", port, port)
+}
+
+func (srv *Server) dialLoop() {
+ defer srv.wg.Done()
+ var (
+ suggest chan *peerAddr
+ slot *int
+ slots = srv.peerSlots
+ )
+ for {
+ select {
+ case i := <-slots:
+ // we need a peer in slot i, slot reserved
+ slot = &i
+ // now we can watch for candidate peers in the next loop
+ suggest = srv.peerConnect
+ // do not consume more until candidate peer is found
+ slots = nil
+
+ case desc := <-suggest:
+ // candidate peer found, will dial out asyncronously
+ // if connection fails slot will be released
+ go srv.dialPeer(desc, *slot)
+ // we can watch if more peers needed in the next loop
+ slots = srv.peerSlots
+ // until then we dont care about candidate peers
+ suggest = nil
+
+ case <-srv.quit:
+ // give back the currently reserved slot
+ if slot != nil {
+ srv.peerSlots <- *slot
+ }
+ return
+ }
+ }
+}
+
+// connect to peer via dial out
+func (srv *Server) dialPeer(desc *peerAddr, slot int) {
+ srvlog.Debugf("Dialing %v (slot %d)\n", desc, slot)
+ conn, err := srv.Dialer.Dial(desc.Network(), desc.String())
+ if err != nil {
+ srvlog.Errorf("Dial error: %v", err)
+ srv.peerSlots <- slot
+ return
+ }
+ go srv.addPeer(conn, desc, slot)
+}
+
+// creates the new peer object and inserts it into its slot
+func (srv *Server) addPeer(conn net.Conn, desc *peerAddr, slot int) *Peer {
+ srv.lock.Lock()
+ defer srv.lock.Unlock()
+ if !srv.running {
+ conn.Close()
+ srv.peerSlots <- slot // release slot
+ return nil
+ }
+ peer := srv.newPeerFunc(srv, conn, desc)
+ peer.slot = slot
+ srv.peers[slot] = peer
+ srv.peerCount++
+ go func() { peer.loop(); srv.peerDisconnect <- peer }()
+ return peer
+}
+
+// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot
+func (srv *Server) removePeer(peer *Peer) {
+ srv.lock.Lock()
+ defer srv.lock.Unlock()
+ srvlog.Debugf("Removing peer %v %v (slot %v)\n", peer, peer.slot)
+ if srv.peers[peer.slot] != peer {
+ srvlog.Warnln("Invalid peer to remove:", peer)
+ return
+ }
+ // remove from list and index
+ srv.peerCount--
+ srv.peers[peer.slot] = nil
+ // release slot to signal need for a new peer, last!
+ srv.peerSlots <- peer.slot
+}
+
+func (srv *Server) verifyPeer(addr *peerAddr) error {
+ if srv.Blacklist.Exists(addr.Pubkey) {
+ return errors.New("blacklisted")
+ }
+ if bytes.Equal(srv.Identity.Pubkey()[1:], addr.Pubkey) {
+ return newPeerError(errPubkeyForbidden, "not allowed to connect to srv")
+ }
+ srv.lock.RLock()
+ defer srv.lock.RUnlock()
+ for _, peer := range srv.peers {
+ if peer != nil {
+ id := peer.Identity()
+ if id != nil && bytes.Equal(id.Pubkey(), addr.Pubkey) {
+ return errors.New("already connected")
+ }
+ }
+ }
+ return nil
+}
+
+type Blacklist interface {
+ Get([]byte) (bool, error)
+ Put([]byte) error
+ Delete([]byte) error
+ Exists(pubkey []byte) (ok bool)
+}
+
+type BlacklistMap struct {
+ blacklist map[string]bool
+ lock sync.RWMutex
+}
+
+func NewBlacklist() *BlacklistMap {
+ return &BlacklistMap{
+ blacklist: make(map[string]bool),
+ }
+}
+
+func (self *BlacklistMap) Get(pubkey []byte) (bool, error) {
+ self.lock.RLock()
+ defer self.lock.RUnlock()
+ v, ok := self.blacklist[string(pubkey)]
+ var err error
+ if !ok {
+ err = fmt.Errorf("not found")
+ }
+ return v, err
+}
+
+func (self *BlacklistMap) Exists(pubkey []byte) (ok bool) {
+ self.lock.RLock()
+ defer self.lock.RUnlock()
+ _, ok = self.blacklist[string(pubkey)]
+ return
+}
+
+func (self *BlacklistMap) Put(pubkey []byte) error {
+ self.lock.RLock()
+ defer self.lock.RUnlock()
+ self.blacklist[string(pubkey)] = true
+ return nil
+}
+
+func (self *BlacklistMap) Delete(pubkey []byte) error {
+ self.lock.RLock()
+ defer self.lock.RUnlock()
+ delete(self.blacklist, string(pubkey))
+ return nil
+}
diff --git a/p2p/server_test.go b/p2p/server_test.go
new file mode 100644
index 000000000..5c0d08d39
--- /dev/null
+++ b/p2p/server_test.go
@@ -0,0 +1,161 @@
+package p2p
+
+import (
+ "bytes"
+ "io"
+ "net"
+ "sync"
+ "testing"
+ "time"
+)
+
+func startTestServer(t *testing.T, pf peerFunc) *Server {
+ server := &Server{
+ Identity: NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey"),
+ MaxPeers: 10,
+ ListenAddr: "127.0.0.1:0",
+ newPeerFunc: pf,
+ }
+ if err := server.Start(); err != nil {
+ t.Fatalf("Could not start server: %v", err)
+ }
+ return server
+}
+
+func TestServerListen(t *testing.T) {
+ defer testlog(t).detach()
+
+ // start the test server
+ connected := make(chan *Peer)
+ srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
+ if conn == nil {
+ t.Error("peer func called with nil conn")
+ }
+ if dialAddr != nil {
+ t.Error("peer func called with non-nil dialAddr")
+ }
+ peer := newPeer(conn, nil, dialAddr)
+ connected <- peer
+ return peer
+ })
+ defer close(connected)
+ defer srv.Stop()
+
+ // dial the test server
+ conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second)
+ if err != nil {
+ t.Fatalf("could not dial: %v", err)
+ }
+ defer conn.Close()
+
+ select {
+ case peer := <-connected:
+ if peer.conn.LocalAddr().String() != conn.RemoteAddr().String() {
+ t.Errorf("peer started with wrong conn: got %v, want %v",
+ peer.conn.LocalAddr(), conn.RemoteAddr())
+ }
+ case <-time.After(1 * time.Second):
+ t.Error("server did not accept within one second")
+ }
+}
+
+func TestServerDial(t *testing.T) {
+ defer testlog(t).detach()
+
+ // run a fake TCP server to handle the connection.
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("could not setup listener: %v")
+ }
+ defer listener.Close()
+ accepted := make(chan net.Conn)
+ go func() {
+ conn, err := listener.Accept()
+ if err != nil {
+ t.Error("acccept error:", err)
+ }
+ conn.Close()
+ accepted <- conn
+ }()
+
+ // start the test server
+ connected := make(chan *Peer)
+ srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
+ if conn == nil {
+ t.Error("peer func called with nil conn")
+ }
+ peer := newPeer(conn, nil, dialAddr)
+ connected <- peer
+ return peer
+ })
+ defer close(connected)
+ defer srv.Stop()
+
+ // tell the server to connect.
+ connAddr := newPeerAddr(listener.Addr(), nil)
+ srv.peerConnect <- connAddr
+
+ select {
+ case conn := <-accepted:
+ select {
+ case peer := <-connected:
+ if peer.conn.RemoteAddr().String() != conn.LocalAddr().String() {
+ t.Errorf("peer started with wrong conn: got %v, want %v",
+ peer.conn.RemoteAddr(), conn.LocalAddr())
+ }
+ if peer.dialAddr != connAddr {
+ t.Errorf("peer started with wrong dialAddr: got %v, want %v",
+ peer.dialAddr, connAddr)
+ }
+ case <-time.After(1 * time.Second):
+ t.Error("server did not launch peer within one second")
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Error("server did not connect within one second")
+ }
+}
+
+func TestServerBroadcast(t *testing.T) {
+ defer testlog(t).detach()
+ var connected sync.WaitGroup
+ srv := startTestServer(t, func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer {
+ peer := newPeer(c, []Protocol{discard}, dialAddr)
+ peer.startSubprotocols([]Cap{discard.cap()})
+ connected.Done()
+ return peer
+ })
+ defer srv.Stop()
+
+ // dial a bunch of conns
+ var conns = make([]net.Conn, 8)
+ connected.Add(len(conns))
+ deadline := time.Now().Add(3 * time.Second)
+ dialer := &net.Dialer{Deadline: deadline}
+ for i := range conns {
+ conn, err := dialer.Dial("tcp", srv.ListenAddr)
+ if err != nil {
+ t.Fatalf("conn %d: dial error: %v", i, err)
+ }
+ defer conn.Close()
+ conn.SetDeadline(deadline)
+ conns[i] = conn
+ }
+ connected.Wait()
+
+ // broadcast one message
+ srv.Broadcast("discard", 0, "foo")
+ goldbuf := new(bytes.Buffer)
+ writeMsg(goldbuf, NewMsg(16, "foo"))
+ golden := goldbuf.Bytes()
+
+ // check that the message has been written everywhere
+ for i, conn := range conns {
+ buf := make([]byte, len(golden))
+ if _, err := io.ReadFull(conn, buf); err != nil {
+ t.Errorf("conn %d: read error: %v", i, err)
+ } else if !bytes.Equal(buf, golden) {
+ t.Errorf("conn %d: msg mismatch\ngot: %x\nwant: %x", i, buf, golden)
+ }
+ }
+}
diff --git a/p2p/testlog_test.go b/p2p/testlog_test.go
new file mode 100644
index 000000000..951d43243
--- /dev/null
+++ b/p2p/testlog_test.go
@@ -0,0 +1,28 @@
+package p2p
+
+import (
+ "testing"
+
+ "github.com/ethereum/go-ethereum/logger"
+)
+
+type testLogger struct{ t *testing.T }
+
+func testlog(t *testing.T) testLogger {
+ logger.Reset()
+ l := testLogger{t}
+ logger.AddLogSystem(l)
+ return l
+}
+
+func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel }
+func (testLogger) SetLogLevel(logger.LogLevel) {}
+
+func (l testLogger) LogPrint(level logger.LogLevel, msg string) {
+ l.t.Logf("%s", msg)
+}
+
+func (testLogger) detach() {
+ logger.Flush()
+ logger.Reset()
+}
diff --git a/p2p/testpoc7.go b/p2p/testpoc7.go
new file mode 100644
index 000000000..c0cc5c544
--- /dev/null
+++ b/p2p/testpoc7.go
@@ -0,0 +1,40 @@
+// +build none
+
+package main
+
+import (
+ "fmt"
+ "log"
+ "net"
+ "os"
+
+ "github.com/ethereum/go-ethereum/logger"
+ "github.com/ethereum/go-ethereum/p2p"
+ "github.com/obscuren/secp256k1-go"
+)
+
+func main() {
+ logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel))
+
+ pub, _ := secp256k1.GenerateKeyPair()
+ srv := p2p.Server{
+ MaxPeers: 10,
+ Identity: p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)),
+ ListenAddr: ":30303",
+ NAT: p2p.PMP(net.ParseIP("10.0.0.1")),
+ }
+ if err := srv.Start(); err != nil {
+ fmt.Println("could not start server:", err)
+ os.Exit(1)
+ }
+
+ // add seed peers
+ seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303")
+ if err != nil {
+ fmt.Println("couldn't resolve:", err)
+ os.Exit(1)
+ }
+ srv.SuggestPeer(seed.IP, seed.Port, nil)
+
+ select {}
+}