aboutsummaryrefslogtreecommitdiffstats
path: root/p2p
diff options
context:
space:
mode:
authorPaweł Bylica <pawel.bylica@imapp.pl>2015-02-24 01:39:05 +0800
committerPaweł Bylica <pawel.bylica@imapp.pl>2015-02-24 01:39:05 +0800
commit114c3b4efe7f30ab7be0bec013210e7b4c3d08d7 (patch)
tree5230f6fee87dcbac36e1d71d6ab731b55eab8268 /p2p
parentb9894c1d0979b9f3e8428b1dc230f1ece106f676 (diff)
parentdd086791acf477da7641c168f82de70ed0b2dca6 (diff)
downloaddexon-114c3b4efe7f30ab7be0bec013210e7b4c3d08d7.tar
dexon-114c3b4efe7f30ab7be0bec013210e7b4c3d08d7.tar.gz
dexon-114c3b4efe7f30ab7be0bec013210e7b4c3d08d7.tar.bz2
dexon-114c3b4efe7f30ab7be0bec013210e7b4c3d08d7.tar.lz
dexon-114c3b4efe7f30ab7be0bec013210e7b4c3d08d7.tar.xz
dexon-114c3b4efe7f30ab7be0bec013210e7b4c3d08d7.tar.zst
dexon-114c3b4efe7f30ab7be0bec013210e7b4c3d08d7.zip
Merge remote-tracking branch 'upstream/develop' into evmjit
Diffstat (limited to 'p2p')
-rw-r--r--p2p/client_identity.go63
-rw-r--r--p2p/client_identity_test.go30
-rw-r--r--p2p/discover/node.go291
-rw-r--r--p2p/discover/node_test.go201
-rw-r--r--p2p/discover/table.go280
-rw-r--r--p2p/discover/table_test.go311
-rw-r--r--p2p/discover/udp.go432
-rw-r--r--p2p/discover/udp_test.go211
-rw-r--r--p2p/handshake.go434
-rw-r--r--p2p/handshake_test.go224
-rw-r--r--p2p/message.go143
-rw-r--r--p2p/message_test.go140
-rw-r--r--p2p/nat.go23
-rw-r--r--p2p/nat/nat.go235
-rw-r--r--p2p/nat/natpmp.go115
-rw-r--r--p2p/nat/natupnp.go149
-rw-r--r--p2p/natpmp.go55
-rw-r--r--p2p/natupnp.go341
-rw-r--r--p2p/peer.go466
-rw-r--r--p2p/peer_error.go58
-rw-r--r--p2p/peer_test.go252
-rw-r--r--p2p/protocol.go248
-rw-r--r--p2p/protocol_test.go158
-rw-r--r--p2p/server.go423
-rw-r--r--p2p/server_test.go89
-rw-r--r--p2p/testlog_test.go2
-rw-r--r--p2p/testpoc7.go40
27 files changed, 3672 insertions, 1742 deletions
diff --git a/p2p/client_identity.go b/p2p/client_identity.go
deleted file mode 100644
index f15fd01bf..000000000
--- a/p2p/client_identity.go
+++ /dev/null
@@ -1,63 +0,0 @@
-package p2p
-
-import (
- "fmt"
- "runtime"
-)
-
-// ClientIdentity represents the identity of a peer.
-type ClientIdentity interface {
- String() string // human readable identity
- Pubkey() []byte // 512-bit public key
-}
-
-type SimpleClientIdentity struct {
- clientIdentifier string
- version string
- customIdentifier string
- os string
- implementation string
- pubkey []byte
-}
-
-func NewSimpleClientIdentity(clientIdentifier string, version string, customIdentifier string, pubkey []byte) *SimpleClientIdentity {
- clientIdentity := &SimpleClientIdentity{
- clientIdentifier: clientIdentifier,
- version: version,
- customIdentifier: customIdentifier,
- os: runtime.GOOS,
- implementation: runtime.Version(),
- pubkey: pubkey,
- }
-
- return clientIdentity
-}
-
-func (c *SimpleClientIdentity) init() {
-}
-
-func (c *SimpleClientIdentity) String() string {
- var id string
- if len(c.customIdentifier) > 0 {
- id = "/" + c.customIdentifier
- }
-
- return fmt.Sprintf("%s/v%s%s/%s/%s",
- c.clientIdentifier,
- c.version,
- id,
- c.os,
- c.implementation)
-}
-
-func (c *SimpleClientIdentity) Pubkey() []byte {
- return []byte(c.pubkey)
-}
-
-func (c *SimpleClientIdentity) SetCustomIdentifier(customIdentifier string) {
- c.customIdentifier = customIdentifier
-}
-
-func (c *SimpleClientIdentity) GetCustomIdentifier() string {
- return c.customIdentifier
-}
diff --git a/p2p/client_identity_test.go b/p2p/client_identity_test.go
deleted file mode 100644
index 7248a7b1a..000000000
--- a/p2p/client_identity_test.go
+++ /dev/null
@@ -1,30 +0,0 @@
-package p2p
-
-import (
- "fmt"
- "runtime"
- "testing"
-)
-
-func TestClientIdentity(t *testing.T) {
- clientIdentity := NewSimpleClientIdentity("Ethereum(G)", "0.5.16", "test", []byte("pubkey"))
- clientString := clientIdentity.String()
- expected := fmt.Sprintf("Ethereum(G)/v0.5.16/test/%s/%s", runtime.GOOS, runtime.Version())
- if clientString != expected {
- t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString)
- }
- customIdentifier := clientIdentity.GetCustomIdentifier()
- if customIdentifier != "test" {
- t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test', got %v", customIdentifier)
- }
- clientIdentity.SetCustomIdentifier("test2")
- customIdentifier = clientIdentity.GetCustomIdentifier()
- if customIdentifier != "test2" {
- t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test2', got %v", customIdentifier)
- }
- clientString = clientIdentity.String()
- expected = fmt.Sprintf("Ethereum(G)/v0.5.16/test2/%s/%s", runtime.GOOS, runtime.Version())
- if clientString != expected {
- t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString)
- }
-}
diff --git a/p2p/discover/node.go b/p2p/discover/node.go
new file mode 100644
index 000000000..c6d2e9766
--- /dev/null
+++ b/p2p/discover/node.go
@@ -0,0 +1,291 @@
+package discover
+
+import (
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "io"
+ "math/rand"
+ "net"
+ "net/url"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/ethereum/go-ethereum/crypto/secp256k1"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+const nodeIDBits = 512
+
+// Node represents a host on the network.
+type Node struct {
+ ID NodeID
+ IP net.IP
+
+ DiscPort int // UDP listening port for discovery protocol
+ TCPPort int // TCP listening port for RLPx
+
+ active time.Time
+}
+
+func newNode(id NodeID, addr *net.UDPAddr) *Node {
+ return &Node{
+ ID: id,
+ IP: addr.IP,
+ DiscPort: addr.Port,
+ TCPPort: addr.Port,
+ active: time.Now(),
+ }
+}
+
+func (n *Node) isValid() bool {
+ // TODO: don't accept localhost, LAN addresses from internet hosts
+ return !n.IP.IsMulticast() && !n.IP.IsUnspecified() && n.TCPPort != 0 && n.DiscPort != 0
+}
+
+// The string representation of a Node is a URL.
+// Please see ParseNode for a description of the format.
+func (n *Node) String() string {
+ addr := net.TCPAddr{IP: n.IP, Port: n.TCPPort}
+ u := url.URL{
+ Scheme: "enode",
+ User: url.User(fmt.Sprintf("%x", n.ID[:])),
+ Host: addr.String(),
+ }
+ if n.DiscPort != n.TCPPort {
+ u.RawQuery = "discport=" + strconv.Itoa(n.DiscPort)
+ }
+ return u.String()
+}
+
+// ParseNode parses a node URL.
+//
+// A node URL has scheme "enode".
+//
+// The hexadecimal node ID is encoded in the username portion of the
+// URL, separated from the host by an @ sign. The hostname can only be
+// given as an IP address, DNS domain names are not allowed. The port
+// in the host name section is the TCP listening port. If the TCP and
+// UDP (discovery) ports differ, the UDP port is specified as query
+// parameter "discport".
+//
+// In the following example, the node URL describes
+// a node with IP address 10.3.58.6, TCP listening port 30303
+// and UDP discovery port 30301.
+//
+// enode://<hex node id>@10.3.58.6:30303?discport=30301
+func ParseNode(rawurl string) (*Node, error) {
+ var n Node
+ u, err := url.Parse(rawurl)
+ if u.Scheme != "enode" {
+ return nil, errors.New("invalid URL scheme, want \"enode\"")
+ }
+ if u.User == nil {
+ return nil, errors.New("does not contain node ID")
+ }
+ if n.ID, err = HexID(u.User.String()); err != nil {
+ return nil, fmt.Errorf("invalid node ID (%v)", err)
+ }
+ ip, port, err := net.SplitHostPort(u.Host)
+ if err != nil {
+ return nil, fmt.Errorf("invalid host: %v", err)
+ }
+ if n.IP = net.ParseIP(ip); n.IP == nil {
+ return nil, errors.New("invalid IP address")
+ }
+ if n.TCPPort, err = strconv.Atoi(port); err != nil {
+ return nil, errors.New("invalid port")
+ }
+ qv := u.Query()
+ if qv.Get("discport") == "" {
+ n.DiscPort = n.TCPPort
+ } else {
+ if n.DiscPort, err = strconv.Atoi(qv.Get("discport")); err != nil {
+ return nil, errors.New("invalid discport in query")
+ }
+ }
+ return &n, nil
+}
+
+// MustParseNode parses a node URL. It panics if the URL is not valid.
+func MustParseNode(rawurl string) *Node {
+ n, err := ParseNode(rawurl)
+ if err != nil {
+ panic("invalid node URL: " + err.Error())
+ }
+ return n
+}
+
+func (n Node) EncodeRLP(w io.Writer) error {
+ return rlp.Encode(w, rpcNode{IP: n.IP.String(), Port: uint16(n.TCPPort), ID: n.ID})
+}
+func (n *Node) DecodeRLP(s *rlp.Stream) (err error) {
+ var ext rpcNode
+ if err = s.Decode(&ext); err == nil {
+ n.TCPPort = int(ext.Port)
+ n.DiscPort = int(ext.Port)
+ n.ID = ext.ID
+ if n.IP = net.ParseIP(ext.IP); n.IP == nil {
+ return errors.New("invalid IP string")
+ }
+ }
+ return err
+}
+
+// NodeID is a unique identifier for each node.
+// The node identifier is a marshaled elliptic curve public key.
+type NodeID [nodeIDBits / 8]byte
+
+// NodeID prints as a long hexadecimal number.
+func (n NodeID) String() string {
+ return fmt.Sprintf("%#x", n[:])
+}
+
+// The Go syntax representation of a NodeID is a call to HexID.
+func (n NodeID) GoString() string {
+ return fmt.Sprintf("discover.HexID(\"%#x\")", n[:])
+}
+
+// HexID converts a hex string to a NodeID.
+// The string may be prefixed with 0x.
+func HexID(in string) (NodeID, error) {
+ if strings.HasPrefix(in, "0x") {
+ in = in[2:]
+ }
+ var id NodeID
+ b, err := hex.DecodeString(in)
+ if err != nil {
+ return id, err
+ } else if len(b) != len(id) {
+ return id, fmt.Errorf("wrong length, need %d hex bytes", len(id))
+ }
+ copy(id[:], b)
+ return id, nil
+}
+
+// MustHexID converts a hex string to a NodeID.
+// It panics if the string is not a valid NodeID.
+func MustHexID(in string) NodeID {
+ id, err := HexID(in)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// PubkeyID returns a marshaled representation of the given public key.
+func PubkeyID(pub *ecdsa.PublicKey) NodeID {
+ var id NodeID
+ pbytes := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
+ if len(pbytes)-1 != len(id) {
+ panic(fmt.Errorf("need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pbytes)))
+ }
+ copy(id[:], pbytes[1:])
+ return id
+}
+
+// recoverNodeID computes the public key used to sign the
+// given hash from the signature.
+func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
+ pubkey, err := secp256k1.RecoverPubkey(hash, sig)
+ if err != nil {
+ return id, err
+ }
+ if len(pubkey)-1 != len(id) {
+ return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8)
+ }
+ for i := range id {
+ id[i] = pubkey[i+1]
+ }
+ return id, nil
+}
+
+// distcmp compares the distances a->target and b->target.
+// Returns -1 if a is closer to target, 1 if b is closer to target
+// and 0 if they are equal.
+func distcmp(target, a, b NodeID) int {
+ for i := range target {
+ da := a[i] ^ target[i]
+ db := b[i] ^ target[i]
+ if da > db {
+ return 1
+ } else if da < db {
+ return -1
+ }
+ }
+ return 0
+}
+
+// table of leading zero counts for bytes [0..255]
+var lzcount = [256]int{
+ 8, 7, 6, 6, 5, 5, 5, 5,
+ 4, 4, 4, 4, 4, 4, 4, 4,
+ 3, 3, 3, 3, 3, 3, 3, 3,
+ 3, 3, 3, 3, 3, 3, 3, 3,
+ 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 2,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+}
+
+// logdist returns the logarithmic distance between a and b, log2(a ^ b).
+func logdist(a, b NodeID) int {
+ lz := 0
+ for i := range a {
+ x := a[i] ^ b[i]
+ if x == 0 {
+ lz += 8
+ } else {
+ lz += lzcount[x]
+ break
+ }
+ }
+ return len(a)*8 - lz
+}
+
+// randomID returns a random NodeID such that logdist(a, b) == n
+func randomID(a NodeID, n int) (b NodeID) {
+ if n == 0 {
+ return a
+ }
+ // flip bit at position n, fill the rest with random bits
+ b = a
+ pos := len(a) - n/8 - 1
+ bit := byte(0x01) << (byte(n%8) - 1)
+ if bit == 0 {
+ pos++
+ bit = 0x80
+ }
+ b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits
+ for i := pos + 1; i < len(a); i++ {
+ b[i] = byte(rand.Intn(255))
+ }
+ return b
+}
diff --git a/p2p/discover/node_test.go b/p2p/discover/node_test.go
new file mode 100644
index 000000000..ae82ae4f1
--- /dev/null
+++ b/p2p/discover/node_test.go
@@ -0,0 +1,201 @@
+package discover
+
+import (
+ "math/big"
+ "math/rand"
+ "net"
+ "reflect"
+ "testing"
+ "testing/quick"
+ "time"
+
+ "github.com/ethereum/go-ethereum/crypto"
+)
+
+var (
+ quickrand = rand.New(rand.NewSource(time.Now().Unix()))
+ quickcfg = &quick.Config{MaxCount: 5000, Rand: quickrand}
+)
+
+var parseNodeTests = []struct {
+ rawurl string
+ wantError string
+ wantResult *Node
+}{
+ {
+ rawurl: "http://foobar",
+ wantError: `invalid URL scheme, want "enode"`,
+ },
+ {
+ rawurl: "enode://foobar",
+ wantError: `does not contain node ID`,
+ },
+ {
+ rawurl: "enode://01010101@123.124.125.126:3",
+ wantError: `invalid node ID (wrong length, need 64 hex bytes)`,
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@hostname:3",
+ wantError: `invalid IP address`,
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:foo",
+ wantError: `invalid port`,
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:3?discport=foo",
+ wantError: `invalid discport in query`,
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150",
+ wantResult: &Node{
+ ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
+ IP: net.ParseIP("127.0.0.1"),
+ DiscPort: 52150,
+ TCPPort: 52150,
+ },
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@[::]:52150",
+ wantResult: &Node{
+ ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
+ IP: net.ParseIP("::"),
+ DiscPort: 52150,
+ TCPPort: 52150,
+ },
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150?discport=223344",
+ wantResult: &Node{
+ ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
+ IP: net.ParseIP("127.0.0.1"),
+ DiscPort: 223344,
+ TCPPort: 52150,
+ },
+ },
+}
+
+func TestParseNode(t *testing.T) {
+ for i, test := range parseNodeTests {
+ n, err := ParseNode(test.rawurl)
+ if err == nil && test.wantError != "" {
+ t.Errorf("test %d: got nil error, expected %#q", i, test.wantError)
+ continue
+ }
+ if err != nil && err.Error() != test.wantError {
+ t.Errorf("test %d: got error %#q, expected %#q", i, err.Error(), test.wantError)
+ continue
+ }
+ if !reflect.DeepEqual(n, test.wantResult) {
+ t.Errorf("test %d: result mismatch:\ngot: %#v, want: %#v", i, n, test.wantResult)
+ }
+ }
+}
+
+func TestNodeString(t *testing.T) {
+ for i, test := range parseNodeTests {
+ if test.wantError != "" {
+ continue
+ }
+ str := test.wantResult.String()
+ if str != test.rawurl {
+ t.Errorf("test %d: Node.String() mismatch:\ngot: %s\nwant: %s", i, str, test.rawurl)
+ }
+ }
+}
+
+func TestHexID(t *testing.T) {
+ ref := NodeID{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 106, 217, 182, 31, 165, 174, 1, 67, 7, 235, 220, 150, 66, 83, 173, 205, 159, 44, 10, 57, 42, 161, 26, 188}
+ id1 := MustHexID("0x000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
+ id2 := MustHexID("000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
+
+ if id1 != ref {
+ t.Errorf("wrong id1\ngot %v\nwant %v", id1[:], ref[:])
+ }
+ if id2 != ref {
+ t.Errorf("wrong id2\ngot %v\nwant %v", id2[:], ref[:])
+ }
+}
+
+func TestNodeID_recover(t *testing.T) {
+ prv := newkey()
+ hash := make([]byte, 32)
+ sig, err := crypto.Sign(hash, prv)
+ if err != nil {
+ t.Fatalf("signing error: %v", err)
+ }
+
+ pub := PubkeyID(&prv.PublicKey)
+ recpub, err := recoverNodeID(hash, sig)
+ if err != nil {
+ t.Fatalf("recovery error: %v", err)
+ }
+ if pub != recpub {
+ t.Errorf("recovered wrong pubkey:\ngot: %v\nwant: %v", recpub, pub)
+ }
+}
+
+func TestNodeID_distcmp(t *testing.T) {
+ distcmpBig := func(target, a, b NodeID) int {
+ tbig := new(big.Int).SetBytes(target[:])
+ abig := new(big.Int).SetBytes(a[:])
+ bbig := new(big.Int).SetBytes(b[:])
+ return new(big.Int).Xor(tbig, abig).Cmp(new(big.Int).Xor(tbig, bbig))
+ }
+ if err := quick.CheckEqual(distcmp, distcmpBig, quickcfg); err != nil {
+ t.Error(err)
+ }
+}
+
+// the random tests is likely to miss the case where they're equal.
+func TestNodeID_distcmpEqual(t *testing.T) {
+ base := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
+ x := NodeID{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}
+ if distcmp(base, x, x) != 0 {
+ t.Errorf("distcmp(base, x, x) != 0")
+ }
+}
+
+func TestNodeID_logdist(t *testing.T) {
+ logdistBig := func(a, b NodeID) int {
+ abig, bbig := new(big.Int).SetBytes(a[:]), new(big.Int).SetBytes(b[:])
+ return new(big.Int).Xor(abig, bbig).BitLen()
+ }
+ if err := quick.CheckEqual(logdist, logdistBig, quickcfg); err != nil {
+ t.Error(err)
+ }
+}
+
+// the random tests is likely to miss the case where they're equal.
+func TestNodeID_logdistEqual(t *testing.T) {
+ x := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
+ if logdist(x, x) != 0 {
+ t.Errorf("logdist(x, x) != 0")
+ }
+}
+
+func TestNodeID_randomID(t *testing.T) {
+ // we don't use quick.Check here because its output isn't
+ // very helpful when the test fails.
+ for i := 0; i < quickcfg.MaxCount; i++ {
+ a := gen(NodeID{}, quickrand).(NodeID)
+ dist := quickrand.Intn(len(NodeID{}) * 8)
+ result := randomID(a, dist)
+ actualdist := logdist(result, a)
+
+ if dist != actualdist {
+ t.Log("a: ", a)
+ t.Log("result:", result)
+ t.Fatalf("#%d: distance of result is %d, want %d", i, actualdist, dist)
+ }
+ }
+}
+
+func (NodeID) Generate(rand *rand.Rand, size int) reflect.Value {
+ var id NodeID
+ m := rand.Intn(len(id))
+ for i := len(id) - 1; i > m; i-- {
+ id[i] = byte(rand.Uint32())
+ }
+ return reflect.ValueOf(id)
+}
diff --git a/p2p/discover/table.go b/p2p/discover/table.go
new file mode 100644
index 000000000..e3bec9328
--- /dev/null
+++ b/p2p/discover/table.go
@@ -0,0 +1,280 @@
+// Package discover implements the Node Discovery Protocol.
+//
+// The Node Discovery protocol provides a way to find RLPx nodes that
+// can be connected to. It uses a Kademlia-like protocol to maintain a
+// distributed database of the IDs and endpoints of all listening
+// nodes.
+package discover
+
+import (
+ "net"
+ "sort"
+ "sync"
+ "time"
+)
+
+const (
+ alpha = 3 // Kademlia concurrency factor
+ bucketSize = 16 // Kademlia bucket size
+ nBuckets = nodeIDBits + 1 // Number of buckets
+)
+
+type Table struct {
+ mutex sync.Mutex // protects buckets, their content, and nursery
+ buckets [nBuckets]*bucket // index of known nodes by distance
+ nursery []*Node // bootstrap nodes
+
+ net transport
+ self *Node // metadata of the local node
+}
+
+// transport is implemented by the UDP transport.
+// it is an interface so we can test without opening lots of UDP
+// sockets and without generating a private key.
+type transport interface {
+ ping(*Node) error
+ findnode(e *Node, target NodeID) ([]*Node, error)
+ close()
+}
+
+// bucket contains nodes, ordered by their last activity.
+type bucket struct {
+ lastLookup time.Time
+ entries []*Node
+}
+
+func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr) *Table {
+ tab := &Table{net: t, self: newNode(ourID, ourAddr)}
+ for i := range tab.buckets {
+ tab.buckets[i] = new(bucket)
+ }
+ return tab
+}
+
+// Self returns the local node ID.
+func (tab *Table) Self() NodeID {
+ return tab.self.ID
+}
+
+// Close terminates the network listener.
+func (tab *Table) Close() {
+ tab.net.close()
+}
+
+// Bootstrap sets the bootstrap nodes. These nodes are used to connect
+// to the network if the table is empty. Bootstrap will also attempt to
+// fill the table by performing random lookup operations on the
+// network.
+func (tab *Table) Bootstrap(nodes []*Node) {
+ tab.mutex.Lock()
+ // TODO: maybe filter nodes with bad fields (nil, etc.) to avoid strange crashes
+ tab.nursery = make([]*Node, 0, len(nodes))
+ for _, n := range nodes {
+ cpy := *n
+ tab.nursery = append(tab.nursery, &cpy)
+ }
+ tab.mutex.Unlock()
+ tab.refresh()
+}
+
+// Lookup performs a network search for nodes close
+// to the given target. It approaches the target by querying
+// nodes that are closer to it on each iteration.
+func (tab *Table) Lookup(target NodeID) []*Node {
+ var (
+ asked = make(map[NodeID]bool)
+ seen = make(map[NodeID]bool)
+ reply = make(chan []*Node, alpha)
+ pendingQueries = 0
+ )
+ // don't query further if we hit the target or ourself.
+ // unlikely to happen often in practice.
+ asked[target] = true
+ asked[tab.self.ID] = true
+
+ tab.mutex.Lock()
+ // update last lookup stamp (for refresh logic)
+ tab.buckets[logdist(tab.self.ID, target)].lastLookup = time.Now()
+ // generate initial result set
+ result := tab.closest(target, bucketSize)
+ tab.mutex.Unlock()
+
+ for {
+ // ask the alpha closest nodes that we haven't asked yet
+ for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ {
+ n := result.entries[i]
+ if !asked[n.ID] {
+ asked[n.ID] = true
+ pendingQueries++
+ go func() {
+ result, _ := tab.net.findnode(n, target)
+ reply <- result
+ }()
+ }
+ }
+ if pendingQueries == 0 {
+ // we have asked all closest nodes, stop the search
+ break
+ }
+
+ // wait for the next reply
+ for _, n := range <-reply {
+ cn := n
+ if !seen[n.ID] {
+ seen[n.ID] = true
+ result.push(cn, bucketSize)
+ }
+ }
+ pendingQueries--
+ }
+ return result.entries
+}
+
+// refresh performs a lookup for a random target to keep buckets full.
+func (tab *Table) refresh() {
+ ld := -1 // logdist of chosen bucket
+ tab.mutex.Lock()
+ for i, b := range tab.buckets {
+ if i > 0 && b.lastLookup.Before(time.Now().Add(-1*time.Hour)) {
+ ld = i
+ break
+ }
+ }
+ tab.mutex.Unlock()
+
+ result := tab.Lookup(randomID(tab.self.ID, ld))
+ if len(result) == 0 {
+ // bootstrap the table with a self lookup
+ tab.mutex.Lock()
+ tab.add(tab.nursery)
+ tab.mutex.Unlock()
+ tab.Lookup(tab.self.ID)
+ // TODO: the Kademlia paper says that we're supposed to perform
+ // random lookups in all buckets further away than our closest neighbor.
+ }
+}
+
+// closest returns the n nodes in the table that are closest to the
+// given id. The caller must hold tab.mutex.
+func (tab *Table) closest(target NodeID, nresults int) *nodesByDistance {
+ // This is a very wasteful way to find the closest nodes but
+ // obviously correct. I believe that tree-based buckets would make
+ // this easier to implement efficiently.
+ close := &nodesByDistance{target: target}
+ for _, b := range tab.buckets {
+ for _, n := range b.entries {
+ close.push(n, nresults)
+ }
+ }
+ return close
+}
+
+func (tab *Table) len() (n int) {
+ for _, b := range tab.buckets {
+ n += len(b.entries)
+ }
+ return n
+}
+
+// bumpOrAdd updates the activity timestamp for the given node and
+// attempts to insert the node into a bucket. The returned Node might
+// not be part of the table. The caller must hold tab.mutex.
+func (tab *Table) bumpOrAdd(node NodeID, from *net.UDPAddr) (n *Node) {
+ b := tab.buckets[logdist(tab.self.ID, node)]
+ if n = b.bump(node); n == nil {
+ n = newNode(node, from)
+ if len(b.entries) == bucketSize {
+ tab.pingReplace(n, b)
+ } else {
+ b.entries = append(b.entries, n)
+ }
+ }
+ return n
+}
+
+func (tab *Table) pingReplace(n *Node, b *bucket) {
+ old := b.entries[bucketSize-1]
+ go func() {
+ if err := tab.net.ping(old); err == nil {
+ // it responded, we don't need to replace it.
+ return
+ }
+ // it didn't respond, replace the node if it is still the oldest node.
+ tab.mutex.Lock()
+ if len(b.entries) > 0 && b.entries[len(b.entries)-1] == old {
+ // slide down other entries and put the new one in front.
+ // TODO: insert in correct position to keep the order
+ copy(b.entries[1:], b.entries)
+ b.entries[0] = n
+ }
+ tab.mutex.Unlock()
+ }()
+}
+
+// bump updates the activity timestamp for the given node.
+// The caller must hold tab.mutex.
+func (tab *Table) bump(node NodeID) {
+ tab.buckets[logdist(tab.self.ID, node)].bump(node)
+}
+
+// add puts the entries into the table if their corresponding
+// bucket is not full. The caller must hold tab.mutex.
+func (tab *Table) add(entries []*Node) {
+outer:
+ for _, n := range entries {
+ if n == nil || n.ID == tab.self.ID {
+ // skip bad entries. The RLP decoder returns nil for empty
+ // input lists.
+ continue
+ }
+ bucket := tab.buckets[logdist(tab.self.ID, n.ID)]
+ for i := range bucket.entries {
+ if bucket.entries[i].ID == n.ID {
+ // already in bucket
+ continue outer
+ }
+ }
+ if len(bucket.entries) < bucketSize {
+ bucket.entries = append(bucket.entries, n)
+ }
+ }
+}
+
+func (b *bucket) bump(id NodeID) *Node {
+ for i, n := range b.entries {
+ if n.ID == id {
+ n.active = time.Now()
+ // move it to the front
+ copy(b.entries[1:], b.entries[:i+1])
+ b.entries[0] = n
+ return n
+ }
+ }
+ return nil
+}
+
+// nodesByDistance is a list of nodes, ordered by
+// distance to target.
+type nodesByDistance struct {
+ entries []*Node
+ target NodeID
+}
+
+// push adds the given node to the list, keeping the total size below maxElems.
+func (h *nodesByDistance) push(n *Node, maxElems int) {
+ ix := sort.Search(len(h.entries), func(i int) bool {
+ return distcmp(h.target, h.entries[i].ID, n.ID) > 0
+ })
+ if len(h.entries) < maxElems {
+ h.entries = append(h.entries, n)
+ }
+ if ix == len(h.entries) {
+ // farther away than all nodes we already have.
+ // if there was room for it, the node is now the last element.
+ } else {
+ // slide existing entries down to make room
+ // this will overwrite the entry we just appended.
+ copy(h.entries[ix+1:], h.entries[ix:])
+ h.entries[ix] = n
+ }
+}
diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go
new file mode 100644
index 000000000..08faea68e
--- /dev/null
+++ b/p2p/discover/table_test.go
@@ -0,0 +1,311 @@
+package discover
+
+import (
+ "crypto/ecdsa"
+ "errors"
+ "fmt"
+ "math/rand"
+ "net"
+ "reflect"
+ "testing"
+ "testing/quick"
+ "time"
+
+ "github.com/ethereum/go-ethereum/crypto"
+)
+
+func TestTable_bumpOrAddBucketAssign(t *testing.T) {
+ tab := newTable(nil, NodeID{}, &net.UDPAddr{})
+ for i := 1; i < len(tab.buckets); i++ {
+ tab.bumpOrAdd(randomID(tab.self.ID, i), &net.UDPAddr{})
+ }
+ for i, b := range tab.buckets {
+ if i > 0 && len(b.entries) != 1 {
+ t.Errorf("bucket %d has %d entries, want 1", i, len(b.entries))
+ }
+ }
+}
+
+func TestTable_bumpOrAddPingReplace(t *testing.T) {
+ pingC := make(pingC)
+ tab := newTable(pingC, NodeID{}, &net.UDPAddr{})
+ last := fillBucket(tab, 200)
+
+ // this bumpOrAdd should not replace the last node
+ // because the node replies to ping.
+ new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
+
+ pinged := <-pingC
+ if pinged != last.ID {
+ t.Fatalf("pinged wrong node: %v\nwant %v", pinged, last.ID)
+ }
+
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+ if l := len(tab.buckets[200].entries); l != bucketSize {
+ t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
+ }
+ if !contains(tab.buckets[200].entries, last.ID) {
+ t.Error("last entry was removed")
+ }
+ if contains(tab.buckets[200].entries, new.ID) {
+ t.Error("new entry was added")
+ }
+}
+
+func TestTable_bumpOrAddPingTimeout(t *testing.T) {
+ tab := newTable(pingC(nil), NodeID{}, &net.UDPAddr{})
+ last := fillBucket(tab, 200)
+
+ // this bumpOrAdd should replace the last node
+ // because the node does not reply to ping.
+ new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
+
+ // wait for async bucket update. damn. this needs to go away.
+ time.Sleep(2 * time.Millisecond)
+
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+ if l := len(tab.buckets[200].entries); l != bucketSize {
+ t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
+ }
+ if contains(tab.buckets[200].entries, last.ID) {
+ t.Error("last entry was not removed")
+ }
+ if !contains(tab.buckets[200].entries, new.ID) {
+ t.Error("new entry was not added")
+ }
+}
+
+func fillBucket(tab *Table, ld int) (last *Node) {
+ b := tab.buckets[ld]
+ for len(b.entries) < bucketSize {
+ b.entries = append(b.entries, &Node{ID: randomID(tab.self.ID, ld)})
+ }
+ return b.entries[bucketSize-1]
+}
+
+type pingC chan NodeID
+
+func (t pingC) findnode(n *Node, target NodeID) ([]*Node, error) {
+ panic("findnode called on pingRecorder")
+}
+func (t pingC) close() {
+ panic("close called on pingRecorder")
+}
+func (t pingC) ping(n *Node) error {
+ if t == nil {
+ return errTimeout
+ }
+ t <- n.ID
+ return nil
+}
+
+func TestTable_bump(t *testing.T) {
+ tab := newTable(nil, NodeID{}, &net.UDPAddr{})
+
+ // add an old entry and two recent ones
+ oldactive := time.Now().Add(-2 * time.Minute)
+ old := &Node{ID: randomID(tab.self.ID, 200), active: oldactive}
+ others := []*Node{
+ &Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
+ &Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
+ }
+ tab.add(append(others, old))
+ if tab.buckets[200].entries[0] == old {
+ t.Fatal("old entry is at front of bucket")
+ }
+
+ // bumping the old entry should move it to the front
+ tab.bump(old.ID)
+ if old.active == oldactive {
+ t.Error("activity timestamp not updated")
+ }
+ if tab.buckets[200].entries[0] != old {
+ t.Errorf("bumped entry did not move to the front of bucket")
+ }
+}
+
+func TestTable_closest(t *testing.T) {
+ t.Parallel()
+
+ test := func(test *closeTest) bool {
+ // for any node table, Target and N
+ tab := newTable(nil, test.Self, &net.UDPAddr{})
+ tab.add(test.All)
+
+ // check that doClosest(Target, N) returns nodes
+ result := tab.closest(test.Target, test.N).entries
+ if hasDuplicates(result) {
+ t.Errorf("result contains duplicates")
+ return false
+ }
+ if !sortedByDistanceTo(test.Target, result) {
+ t.Errorf("result is not sorted by distance to target")
+ return false
+ }
+
+ // check that the number of results is min(N, tablen)
+ wantN := test.N
+ if tlen := tab.len(); tlen < test.N {
+ wantN = tlen
+ }
+ if len(result) != wantN {
+ t.Errorf("wrong number of nodes: got %d, want %d", len(result), wantN)
+ return false
+ } else if len(result) == 0 {
+ return true // no need to check distance
+ }
+
+ // check that the result nodes have minimum distance to target.
+ for _, b := range tab.buckets {
+ for _, n := range b.entries {
+ if contains(result, n.ID) {
+ continue // don't run the check below for nodes in result
+ }
+ farthestResult := result[len(result)-1].ID
+ if distcmp(test.Target, n.ID, farthestResult) < 0 {
+ t.Errorf("table contains node that is closer to target but it's not in result")
+ t.Logf(" Target: %v", test.Target)
+ t.Logf(" Farthest Result: %v", farthestResult)
+ t.Logf(" ID: %v", n.ID)
+ return false
+ }
+ }
+ }
+ return true
+ }
+ if err := quick.Check(test, quickcfg); err != nil {
+ t.Error(err)
+ }
+}
+
+type closeTest struct {
+ Self NodeID
+ Target NodeID
+ All []*Node
+ N int
+}
+
+func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
+ t := &closeTest{
+ Self: gen(NodeID{}, rand).(NodeID),
+ Target: gen(NodeID{}, rand).(NodeID),
+ N: rand.Intn(bucketSize),
+ }
+ for _, id := range gen([]NodeID{}, rand).([]NodeID) {
+ t.All = append(t.All, &Node{ID: id})
+ }
+ return reflect.ValueOf(t)
+}
+
+func TestTable_Lookup(t *testing.T) {
+ self := gen(NodeID{}, quickrand).(NodeID)
+ target := randomID(self, 200)
+ transport := findnodeOracle{t, target}
+ tab := newTable(transport, self, &net.UDPAddr{})
+
+ // lookup on empty table returns no nodes
+ if results := tab.Lookup(target); len(results) > 0 {
+ t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
+ }
+ // seed table with initial node (otherwise lookup will terminate immediately)
+ tab.bumpOrAdd(randomID(target, 200), &net.UDPAddr{Port: 200})
+
+ results := tab.Lookup(target)
+ t.Logf("results:")
+ for _, e := range results {
+ t.Logf(" ld=%d, %v", logdist(target, e.ID), e.ID)
+ }
+ if len(results) != bucketSize {
+ t.Errorf("wrong number of results: got %d, want %d", len(results), bucketSize)
+ }
+ if hasDuplicates(results) {
+ t.Errorf("result set contains duplicate entries")
+ }
+ if !sortedByDistanceTo(target, results) {
+ t.Errorf("result set not sorted by distance to target")
+ }
+ if !contains(results, target) {
+ t.Errorf("result set does not contain target")
+ }
+}
+
+// findnode on this transport always returns at least one node
+// that is one bucket closer to the target.
+type findnodeOracle struct {
+ t *testing.T
+ target NodeID
+}
+
+func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) {
+ t.t.Logf("findnode query at dist %d", n.DiscPort)
+ // current log distance is encoded in port number
+ var result []*Node
+ switch n.DiscPort {
+ case 0:
+ panic("query to node at distance 0")
+ default:
+ // TODO: add more randomness to distances
+ next := n.DiscPort - 1
+ for i := 0; i < bucketSize; i++ {
+ result = append(result, &Node{ID: randomID(t.target, next), DiscPort: next})
+ }
+ }
+ return result, nil
+}
+
+func (t findnodeOracle) close() {}
+
+func (t findnodeOracle) ping(n *Node) error {
+ return errors.New("ping is not supported by this transport")
+}
+
+func hasDuplicates(slice []*Node) bool {
+ seen := make(map[NodeID]bool)
+ for _, e := range slice {
+ if seen[e.ID] {
+ return true
+ }
+ seen[e.ID] = true
+ }
+ return false
+}
+
+func sortedByDistanceTo(distbase NodeID, slice []*Node) bool {
+ var last NodeID
+ for i, e := range slice {
+ if i > 0 && distcmp(distbase, e.ID, last) < 0 {
+ return false
+ }
+ last = e.ID
+ }
+ return true
+}
+
+func contains(ns []*Node, id NodeID) bool {
+ for _, n := range ns {
+ if n.ID == id {
+ return true
+ }
+ }
+ return false
+}
+
+// gen wraps quick.Value so it's easier to use.
+// it generates a random value of the given value's type.
+func gen(typ interface{}, rand *rand.Rand) interface{} {
+ v, ok := quick.Value(reflect.TypeOf(typ), rand)
+ if !ok {
+ panic(fmt.Sprintf("couldn't generate random value of type %T", typ))
+ }
+ return v.Interface()
+}
+
+func newkey() *ecdsa.PrivateKey {
+ key, err := crypto.GenerateKey()
+ if err != nil {
+ panic("couldn't generate key: " + err.Error())
+ }
+ return key
+}
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
new file mode 100644
index 000000000..69e9f3c2e
--- /dev/null
+++ b/p2p/discover/udp.go
@@ -0,0 +1,432 @@
+package discover
+
+import (
+ "bytes"
+ "crypto/ecdsa"
+ "errors"
+ "fmt"
+ "net"
+ "time"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/logger"
+ "github.com/ethereum/go-ethereum/p2p/nat"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+var log = logger.NewLogger("P2P Discovery")
+
+// Errors
+var (
+ errPacketTooSmall = errors.New("too small")
+ errBadHash = errors.New("bad hash")
+ errExpired = errors.New("expired")
+ errTimeout = errors.New("RPC timeout")
+ errClosed = errors.New("socket closed")
+)
+
+// Timeouts
+const (
+ respTimeout = 300 * time.Millisecond
+ sendTimeout = 300 * time.Millisecond
+ expiration = 20 * time.Second
+
+ refreshInterval = 1 * time.Hour
+)
+
+// RPC packet types
+const (
+ pingPacket = iota + 1 // zero is 'reserved'
+ pongPacket
+ findnodePacket
+ neighborsPacket
+)
+
+// RPC request structures
+type (
+ ping struct {
+ IP string // our IP
+ Port uint16 // our port
+ Expiration uint64
+ }
+
+ // reply to Ping
+ pong struct {
+ ReplyTok []byte
+ Expiration uint64
+ }
+
+ findnode struct {
+ // Id to look up. The responding node will send back nodes
+ // closest to the target.
+ Target NodeID
+ Expiration uint64
+ }
+
+ // reply to findnode
+ neighbors struct {
+ Nodes []*Node
+ Expiration uint64
+ }
+)
+
+type rpcNode struct {
+ IP string
+ Port uint16
+ ID NodeID
+}
+
+// udp implements the RPC protocol.
+type udp struct {
+ conn *net.UDPConn
+ priv *ecdsa.PrivateKey
+ addpending chan *pending
+ replies chan reply
+ closing chan struct{}
+ nat nat.Interface
+
+ *Table
+}
+
+// pending represents a pending reply.
+//
+// some implementations of the protocol wish to send more than one
+// reply packet to findnode. in general, any neighbors packet cannot
+// be matched up with a specific findnode packet.
+//
+// our implementation handles this by storing a callback function for
+// each pending reply. incoming packets from a node are dispatched
+// to all the callback functions for that node.
+type pending struct {
+ // these fields must match in the reply.
+ from NodeID
+ ptype byte
+
+ // time when the request must complete
+ deadline time.Time
+
+ // callback is called when a matching reply arrives. if it returns
+ // true, the callback is removed from the pending reply queue.
+ // if it returns false, the reply is considered incomplete and
+ // the callback will be invoked again for the next matching reply.
+ callback func(resp interface{}) (done bool)
+
+ // errc receives nil when the callback indicates completion or an
+ // error if no further reply is received within the timeout.
+ errc chan<- error
+}
+
+type reply struct {
+ from NodeID
+ ptype byte
+ data interface{}
+}
+
+// ListenUDP returns a new table that listens for UDP packets on laddr.
+func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table, error) {
+ addr, err := net.ResolveUDPAddr("udp", laddr)
+ if err != nil {
+ return nil, err
+ }
+ conn, err := net.ListenUDP("udp", addr)
+ if err != nil {
+ return nil, err
+ }
+ udp := &udp{
+ conn: conn,
+ priv: priv,
+ closing: make(chan struct{}),
+ addpending: make(chan *pending),
+ replies: make(chan reply),
+ }
+
+ realaddr := conn.LocalAddr().(*net.UDPAddr)
+ if natm != nil {
+ if !realaddr.IP.IsLoopback() {
+ go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
+ }
+ // TODO: react to external IP changes over time.
+ if ext, err := natm.ExternalIP(); err == nil {
+ realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port}
+ }
+ }
+ udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr)
+
+ go udp.loop()
+ go udp.readLoop()
+ log.Infoln("Listening, ", udp.self)
+ return udp.Table, nil
+}
+
+func (t *udp) close() {
+ close(t.closing)
+ t.conn.Close()
+ // TODO: wait for the loops to end.
+}
+
+// ping sends a ping message to the given node and waits for a reply.
+func (t *udp) ping(e *Node) error {
+ // TODO: maybe check for ReplyTo field in callback to measure RTT
+ errc := t.pending(e.ID, pongPacket, func(interface{}) bool { return true })
+ t.send(e, pingPacket, ping{
+ IP: t.self.IP.String(),
+ Port: uint16(t.self.TCPPort),
+ Expiration: uint64(time.Now().Add(expiration).Unix()),
+ })
+ return <-errc
+}
+
+// findnode sends a findnode request to the given node and waits until
+// the node has sent up to k neighbors.
+func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
+ nodes := make([]*Node, 0, bucketSize)
+ nreceived := 0
+ errc := t.pending(to.ID, neighborsPacket, func(r interface{}) bool {
+ reply := r.(*neighbors)
+ for _, n := range reply.Nodes {
+ nreceived++
+ if n.isValid() {
+ nodes = append(nodes, n)
+ }
+ }
+ return nreceived >= bucketSize
+ })
+
+ t.send(to, findnodePacket, findnode{
+ Target: target,
+ Expiration: uint64(time.Now().Add(expiration).Unix()),
+ })
+ err := <-errc
+ return nodes, err
+}
+
+// pending adds a reply callback to the pending reply queue.
+// see the documentation of type pending for a detailed explanation.
+func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error {
+ ch := make(chan error, 1)
+ p := &pending{from: id, ptype: ptype, callback: callback, errc: ch}
+ select {
+ case t.addpending <- p:
+ // loop will handle it
+ case <-t.closing:
+ ch <- errClosed
+ }
+ return ch
+}
+
+// loop runs in its own goroutin. it keeps track of
+// the refresh timer and the pending reply queue.
+func (t *udp) loop() {
+ var (
+ pending []*pending
+ nextDeadline time.Time
+ timeout = time.NewTimer(0)
+ refresh = time.NewTicker(refreshInterval)
+ )
+ <-timeout.C // ignore first timeout
+ defer refresh.Stop()
+ defer timeout.Stop()
+
+ rearmTimeout := func() {
+ if len(pending) == 0 || nextDeadline == pending[0].deadline {
+ return
+ }
+ nextDeadline = pending[0].deadline
+ timeout.Reset(nextDeadline.Sub(time.Now()))
+ }
+
+ for {
+ select {
+ case <-refresh.C:
+ go t.refresh()
+
+ case <-t.closing:
+ for _, p := range pending {
+ p.errc <- errClosed
+ }
+ return
+
+ case p := <-t.addpending:
+ p.deadline = time.Now().Add(respTimeout)
+ pending = append(pending, p)
+ rearmTimeout()
+
+ case reply := <-t.replies:
+ // run matching callbacks, remove if they return false.
+ for i := 0; i < len(pending); i++ {
+ p := pending[i]
+ if reply.from == p.from && reply.ptype == p.ptype && p.callback(reply.data) {
+ p.errc <- nil
+ copy(pending[i:], pending[i+1:])
+ pending = pending[:len(pending)-1]
+ i--
+ }
+ }
+ rearmTimeout()
+
+ case now := <-timeout.C:
+ // notify and remove callbacks whose deadline is in the past.
+ i := 0
+ for ; i < len(pending) && now.After(pending[i].deadline); i++ {
+ pending[i].errc <- errTimeout
+ }
+ if i > 0 {
+ copy(pending, pending[i:])
+ pending = pending[:len(pending)-i]
+ }
+ rearmTimeout()
+ }
+ }
+}
+
+const (
+ macSize = 256 / 8
+ sigSize = 520 / 8
+ headSize = macSize + sigSize // space of packet frame data
+)
+
+var headSpace = make([]byte, headSize)
+
+func (t *udp) send(to *Node, ptype byte, req interface{}) error {
+ b := new(bytes.Buffer)
+ b.Write(headSpace)
+ b.WriteByte(ptype)
+ if err := rlp.Encode(b, req); err != nil {
+ log.Errorln("error encoding packet:", err)
+ return err
+ }
+
+ packet := b.Bytes()
+ sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), t.priv)
+ if err != nil {
+ log.Errorln("could not sign packet:", err)
+ return err
+ }
+ copy(packet[macSize:], sig)
+ // add the hash to the front. Note: this doesn't protect the
+ // packet in any way. Our public key will be part of this hash in
+ // the future.
+ copy(packet, crypto.Sha3(packet[macSize:]))
+
+ toaddr := &net.UDPAddr{IP: to.IP, Port: to.DiscPort}
+ log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req)
+ if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
+ log.DebugDetailln("UDP send failed:", err)
+ }
+ return err
+}
+
+// readLoop runs in its own goroutine. it handles incoming UDP packets.
+func (t *udp) readLoop() {
+ defer t.conn.Close()
+ buf := make([]byte, 4096) // TODO: good buffer size
+ for {
+ nbytes, from, err := t.conn.ReadFromUDP(buf)
+ if err != nil {
+ return
+ }
+ if err := t.packetIn(from, buf[:nbytes]); err != nil {
+ log.Debugf("Bad packet from %v: %v\n", from, err)
+ }
+ }
+}
+
+func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error {
+ if len(buf) < headSize+1 {
+ return errPacketTooSmall
+ }
+ hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
+ shouldhash := crypto.Sha3(buf[macSize:])
+ if !bytes.Equal(hash, shouldhash) {
+ return errBadHash
+ }
+ fromID, err := recoverNodeID(crypto.Sha3(buf[headSize:]), sig)
+ if err != nil {
+ return err
+ }
+
+ var req interface {
+ handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
+ }
+ switch ptype := sigdata[0]; ptype {
+ case pingPacket:
+ req = new(ping)
+ case pongPacket:
+ req = new(pong)
+ case findnodePacket:
+ req = new(findnode)
+ case neighborsPacket:
+ req = new(neighbors)
+ default:
+ return fmt.Errorf("unknown type: %d", ptype)
+ }
+ if err := rlp.Decode(bytes.NewReader(sigdata[1:]), req); err != nil {
+ return err
+ }
+ log.DebugDetailf("<<< %v %T %v\n", from, req, req)
+ return req.handle(t, from, fromID, hash)
+}
+
+func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+ if expired(req.Expiration) {
+ return errExpired
+ }
+ t.mutex.Lock()
+ // Note: we're ignoring the provided IP address right now
+ n := t.bumpOrAdd(fromID, from)
+ if req.Port != 0 {
+ n.TCPPort = int(req.Port)
+ }
+ t.mutex.Unlock()
+
+ t.send(n, pongPacket, pong{
+ ReplyTok: mac,
+ Expiration: uint64(time.Now().Add(expiration).Unix()),
+ })
+ return nil
+}
+
+func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+ if expired(req.Expiration) {
+ return errExpired
+ }
+ t.mutex.Lock()
+ t.bump(fromID)
+ t.mutex.Unlock()
+
+ t.replies <- reply{fromID, pongPacket, req}
+ return nil
+}
+
+func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+ if expired(req.Expiration) {
+ return errExpired
+ }
+ t.mutex.Lock()
+ e := t.bumpOrAdd(fromID, from)
+ closest := t.closest(req.Target, bucketSize).entries
+ t.mutex.Unlock()
+
+ t.send(e, neighborsPacket, neighbors{
+ Nodes: closest,
+ Expiration: uint64(time.Now().Add(expiration).Unix()),
+ })
+ return nil
+}
+
+func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+ if expired(req.Expiration) {
+ return errExpired
+ }
+ t.mutex.Lock()
+ t.bump(fromID)
+ t.add(req.Nodes)
+ t.mutex.Unlock()
+
+ t.replies <- reply{fromID, neighborsPacket, req}
+ return nil
+}
+
+func expired(ts uint64) bool {
+ return time.Unix(int64(ts), 0).Before(time.Now())
+}
diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go
new file mode 100644
index 000000000..0a8ff6358
--- /dev/null
+++ b/p2p/discover/udp_test.go
@@ -0,0 +1,211 @@
+package discover
+
+import (
+ "fmt"
+ logpkg "log"
+ "net"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/logger"
+)
+
+func init() {
+ logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, logpkg.LstdFlags, logger.ErrorLevel))
+}
+
+func TestUDP_ping(t *testing.T) {
+ t.Parallel()
+
+ n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
+ n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
+ defer n1.Close()
+ defer n2.Close()
+
+ if err := n1.net.ping(n2.self); err != nil {
+ t.Fatalf("ping error: %v", err)
+ }
+ if find(n2, n1.self.ID) == nil {
+ t.Errorf("node 2 does not contain id of node 1")
+ }
+ if e := find(n1, n2.self.ID); e != nil {
+ t.Errorf("node 1 does contains id of node 2: %v", e)
+ }
+}
+
+func find(tab *Table, id NodeID) *Node {
+ for _, b := range tab.buckets {
+ for _, e := range b.entries {
+ if e.ID == id {
+ return e
+ }
+ }
+ }
+ return nil
+}
+
+func TestUDP_findnode(t *testing.T) {
+ t.Parallel()
+
+ n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
+ n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
+ defer n1.Close()
+ defer n2.Close()
+
+ // put a few nodes into n2. the exact distribution shouldn't
+ // matter much, altough we need to take care not to overflow
+ // any bucket.
+ target := randomID(n1.self.ID, 100)
+ nodes := &nodesByDistance{target: target}
+ for i := 0; i < bucketSize; i++ {
+ n2.add([]*Node{&Node{
+ IP: net.IP{1, 2, 3, byte(i)},
+ DiscPort: i + 2,
+ TCPPort: i + 2,
+ ID: randomID(n2.self.ID, i+2),
+ }})
+ }
+ n2.add(nodes.entries)
+ n2.bumpOrAdd(n1.self.ID, &net.UDPAddr{IP: n1.self.IP, Port: n1.self.DiscPort})
+ expected := n2.closest(target, bucketSize)
+
+ err := runUDP(10, func() error {
+ result, _ := n1.net.findnode(n2.self, target)
+ if len(result) != bucketSize {
+ return fmt.Errorf("wrong number of results: got %d, want %d", len(result), bucketSize)
+ }
+ for i := range result {
+ if result[i].ID != expected.entries[i].ID {
+ return fmt.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, result[i], expected.entries[i])
+ }
+ }
+ return nil
+ })
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func TestUDP_replytimeout(t *testing.T) {
+ t.Parallel()
+
+ // reserve a port so we don't talk to an existing service by accident
+ addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
+ fd, err := net.ListenUDP("udp", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer fd.Close()
+
+ n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
+ defer n1.Close()
+ n2 := n1.bumpOrAdd(randomID(n1.self.ID, 10), fd.LocalAddr().(*net.UDPAddr))
+
+ if err := n1.net.ping(n2); err != errTimeout {
+ t.Error("expected timeout error, got", err)
+ }
+
+ if result, err := n1.net.findnode(n2, n1.self.ID); err != errTimeout {
+ t.Error("expected timeout error, got", err)
+ } else if len(result) > 0 {
+ t.Error("expected empty result, got", result)
+ }
+}
+
+func TestUDP_findnodeMultiReply(t *testing.T) {
+ t.Parallel()
+
+ n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
+ n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
+ udp2 := n2.net.(*udp)
+ defer n1.Close()
+ defer n2.Close()
+
+ err := runUDP(10, func() error {
+ nodes := make([]*Node, bucketSize)
+ for i := range nodes {
+ nodes[i] = &Node{
+ IP: net.IP{1, 2, 3, 4},
+ DiscPort: i + 1,
+ TCPPort: i + 1,
+ ID: randomID(n2.self.ID, i+1),
+ }
+ }
+
+ // ask N2 for neighbors. it will send an empty reply back.
+ // the request will wait for up to bucketSize replies.
+ resultc := make(chan []*Node)
+ errc := make(chan error)
+ go func() {
+ ns, err := n1.net.findnode(n2.self, n1.self.ID)
+ if err != nil {
+ errc <- err
+ } else {
+ resultc <- ns
+ }
+ }()
+
+ // send a few more neighbors packets to N1.
+ // it should collect those.
+ for end := 0; end < len(nodes); {
+ off := end
+ if end = end + 5; end > len(nodes) {
+ end = len(nodes)
+ }
+ udp2.send(n1.self, neighborsPacket, neighbors{
+ Nodes: nodes[off:end],
+ Expiration: uint64(time.Now().Add(10 * time.Second).Unix()),
+ })
+ }
+
+ // check that they are all returned. we cannot just check for
+ // equality because they might not be returned in the order they
+ // were sent.
+ var result []*Node
+ select {
+ case result = <-resultc:
+ case err := <-errc:
+ return err
+ }
+ if hasDuplicates(result) {
+ return fmt.Errorf("result slice contains duplicates")
+ }
+ if len(result) != len(nodes) {
+ return fmt.Errorf("wrong number of nodes returned: got %d, want %d", len(result), len(nodes))
+ }
+ matched := make(map[NodeID]bool)
+ for _, n := range result {
+ for _, expn := range nodes {
+ if n.ID == expn.ID { // && bytes.Equal(n.Addr.IP, expn.Addr.IP) && n.Addr.Port == expn.Addr.Port {
+ matched[n.ID] = true
+ }
+ }
+ }
+ if len(matched) != len(nodes) {
+ return fmt.Errorf("wrong number of matching nodes: got %d, want %d", len(matched), len(nodes))
+ }
+ return nil
+ })
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+// runUDP runs a test n times and returns an error if the test failed
+// in all n runs. This is necessary because UDP is unreliable even for
+// connections on the local machine, causing test failures.
+func runUDP(n int, test func() error) error {
+ errcount := 0
+ errors := ""
+ for i := 0; i < n; i++ {
+ if err := test(); err != nil {
+ errors += fmt.Sprintf("\n#%d: %v", i, err)
+ errcount++
+ }
+ }
+ if errcount == n {
+ return fmt.Errorf("failed on all %d iterations:%s", n, errors)
+ }
+ return nil
+}
diff --git a/p2p/handshake.go b/p2p/handshake.go
new file mode 100644
index 000000000..614711eaf
--- /dev/null
+++ b/p2p/handshake.go
@@ -0,0 +1,434 @@
+package p2p
+
+import (
+ "crypto/ecdsa"
+ "crypto/rand"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/crypto/ecies"
+ "github.com/ethereum/go-ethereum/crypto/secp256k1"
+ "github.com/ethereum/go-ethereum/p2p/discover"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+const (
+ sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2
+ sigLen = 65 // elliptic S256
+ pubLen = 64 // 512 bit pubkey in uncompressed representation without format byte
+ shaLen = 32 // hash length (for nonce etc)
+
+ authMsgLen = sigLen + shaLen + pubLen + shaLen + 1
+ authRespLen = pubLen + shaLen + 1
+
+ eciesBytes = 65 + 16 + 32
+ iHSLen = authMsgLen + eciesBytes // size of the final ECIES payload sent as initiator's handshake
+ rHSLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake
+)
+
+type conn struct {
+ *frameRW
+ *protoHandshake
+}
+
+func newConn(fd net.Conn, hs *protoHandshake) *conn {
+ return &conn{newFrameRW(fd, msgWriteTimeout), hs}
+}
+
+// encHandshake represents information about the remote end
+// of a connection that is negotiated during the encryption handshake.
+type encHandshake struct {
+ ID discover.NodeID
+ IngressMAC []byte
+ EgressMAC []byte
+ Token []byte
+}
+
+// protoHandshake is the RLP structure of the protocol handshake.
+type protoHandshake struct {
+ Version uint64
+ Name string
+ Caps []Cap
+ ListenPort uint64
+ ID discover.NodeID
+}
+
+// setupConn starts a protocol session on the given connection.
+// It runs the encryption handshake and the protocol handshake.
+// If dial is non-nil, the connection the local node is the initiator.
+func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
+ if dial == nil {
+ return setupInboundConn(fd, prv, our)
+ } else {
+ return setupOutboundConn(fd, prv, our, dial)
+ }
+}
+
+func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake) (*conn, error) {
+ // var remotePubkey []byte
+ // sessionToken, remotePubkey, err = inboundEncHandshake(fd, prv, nil)
+ // copy(remoteID[:], remotePubkey)
+
+ rw := newFrameRW(fd, msgWriteTimeout)
+ rhs, err := readProtocolHandshake(rw, our)
+ if err != nil {
+ return nil, err
+ }
+ if err := writeProtocolHandshake(rw, our); err != nil {
+ return nil, fmt.Errorf("protocol write error: %v", err)
+ }
+ return &conn{rw, rhs}, nil
+}
+
+func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
+ // remoteID = dial.ID
+ // sessionToken, err = outboundEncHandshake(fd, prv, remoteID[:], nil)
+
+ rw := newFrameRW(fd, msgWriteTimeout)
+ if err := writeProtocolHandshake(rw, our); err != nil {
+ return nil, fmt.Errorf("protocol write error: %v", err)
+ }
+ rhs, err := readProtocolHandshake(rw, our)
+ if err != nil {
+ return nil, fmt.Errorf("protocol handshake read error: %v", err)
+ }
+ if rhs.ID != dial.ID {
+ return nil, errors.New("dialed node id mismatch")
+ }
+ return &conn{rw, rhs}, nil
+}
+
+// outboundEncHandshake negotiates a session token on conn.
+// it should be called on the dialing side of the connection.
+//
+// privateKey is the local client's private key
+// remotePublicKey is the remote peer's node ID
+// sessionToken is the token from a previous session with this node.
+func outboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, remotePublicKey []byte, sessionToken []byte) (
+ newSessionToken []byte,
+ err error,
+) {
+ auth, initNonce, randomPrivKey, err := authMsg(prvKey, remotePublicKey, sessionToken)
+ if err != nil {
+ return nil, err
+ }
+ if _, err = conn.Write(auth); err != nil {
+ return nil, err
+ }
+
+ response := make([]byte, rHSLen)
+ if _, err = io.ReadFull(conn, response); err != nil {
+ return nil, err
+ }
+ recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prvKey)
+ if err != nil {
+ return nil, err
+ }
+
+ return newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
+}
+
+// authMsg creates the initiator handshake.
+func authMsg(prvKey *ecdsa.PrivateKey, remotePubKeyS, sessionToken []byte) (
+ auth, initNonce []byte,
+ randomPrvKey *ecdsa.PrivateKey,
+ err error,
+) {
+ // session init, common to both parties
+ remotePubKey, err := importPublicKey(remotePubKeyS)
+ if err != nil {
+ return
+ }
+
+ var tokenFlag byte // = 0x00
+ if sessionToken == nil {
+ // no session token found means we need to generate shared secret.
+ // ecies shared secret is used as initial session token for new peers
+ // generate shared key from prv and remote pubkey
+ if sessionToken, err = ecies.ImportECDSA(prvKey).GenerateShared(ecies.ImportECDSAPublic(remotePubKey), sskLen, sskLen); err != nil {
+ return
+ }
+ // tokenFlag = 0x00 // redundant
+ } else {
+ // for known peers, we use stored token from the previous session
+ tokenFlag = 0x01
+ }
+
+ //E(remote-pubk, S(ecdhe-random, ecdh-shared-secret^nonce) || H(ecdhe-random-pubk) || pubk || nonce || 0x0)
+ // E(remote-pubk, S(ecdhe-random, token^nonce) || H(ecdhe-random-pubk) || pubk || nonce || 0x1)
+ // allocate msgLen long message,
+ var msg []byte = make([]byte, authMsgLen)
+ initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1]
+ if _, err = rand.Read(initNonce); err != nil {
+ return
+ }
+ // create known message
+ // ecdh-shared-secret^nonce for new peers
+ // token^nonce for old peers
+ var sharedSecret = xor(sessionToken, initNonce)
+
+ // generate random keypair to use for signing
+ if randomPrvKey, err = crypto.GenerateKey(); err != nil {
+ return
+ }
+ // sign shared secret (message known to both parties): shared-secret
+ var signature []byte
+ // signature = sign(ecdhe-random, shared-secret)
+ // uses secp256k1.Sign
+ if signature, err = crypto.Sign(sharedSecret, randomPrvKey); err != nil {
+ return
+ }
+
+ // message
+ // signed-shared-secret || H(ecdhe-random-pubk) || pubk || nonce || 0x0
+ copy(msg, signature) // copy signed-shared-secret
+ // H(ecdhe-random-pubk)
+ var randomPubKey64 []byte
+ if randomPubKey64, err = exportPublicKey(&randomPrvKey.PublicKey); err != nil {
+ return
+ }
+ var pubKey64 []byte
+ if pubKey64, err = exportPublicKey(&prvKey.PublicKey); err != nil {
+ return
+ }
+ copy(msg[sigLen:sigLen+shaLen], crypto.Sha3(randomPubKey64))
+ // pubkey copied to the correct segment.
+ copy(msg[sigLen+shaLen:sigLen+shaLen+pubLen], pubKey64)
+ // nonce is already in the slice
+ // stick tokenFlag byte to the end
+ msg[authMsgLen-1] = tokenFlag
+
+ // encrypt using remote-pubk
+ // auth = eciesEncrypt(remote-pubk, msg)
+ if auth, err = crypto.Encrypt(remotePubKey, msg); err != nil {
+ return
+ }
+ return
+}
+
+// completeHandshake is called when the initiator receives an
+// authentication response (aka receiver handshake). It completes the
+// handshake by reading off parameters the remote peer provides needed
+// to set up the secure session.
+func completeHandshake(auth []byte, prvKey *ecdsa.PrivateKey) (
+ respNonce []byte,
+ remoteRandomPubKey *ecdsa.PublicKey,
+ tokenFlag bool,
+ err error,
+) {
+ var msg []byte
+ // they prove that msg is meant for me,
+ // I prove I possess private key if i can read it
+ if msg, err = crypto.Decrypt(prvKey, auth); err != nil {
+ return
+ }
+
+ respNonce = msg[pubLen : pubLen+shaLen]
+ var remoteRandomPubKeyS = msg[:pubLen]
+ if remoteRandomPubKey, err = importPublicKey(remoteRandomPubKeyS); err != nil {
+ return
+ }
+ if msg[authRespLen-1] == 0x01 {
+ tokenFlag = true
+ }
+ return
+}
+
+// inboundEncHandshake negotiates a session token on conn.
+// it should be called on the listening side of the connection.
+//
+// privateKey is the local client's private key
+// sessionToken is the token from a previous session with this node.
+func inboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, sessionToken []byte) (
+ token, remotePubKey []byte,
+ err error,
+) {
+ // we are listening connection. we are responders in the
+ // handshake. Extract info from the authentication. The initiator
+ // starts by sending us a handshake that we need to respond to. so
+ // we read auth message first, then respond.
+ auth := make([]byte, iHSLen)
+ if _, err := io.ReadFull(conn, auth); err != nil {
+ return nil, nil, err
+ }
+ response, recNonce, initNonce, remotePubKey, randomPrivKey, remoteRandomPubKey, err := authResp(auth, sessionToken, prvKey)
+ if err != nil {
+ return nil, nil, err
+ }
+ if _, err = conn.Write(response); err != nil {
+ return nil, nil, err
+ }
+ token, err = newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
+ return token, remotePubKey, err
+}
+
+// authResp is called by peer if it accepted (but not
+// initiated) the connection from the remote. It is passed the initiator
+// handshake received and the session token belonging to the
+// remote initiator.
+//
+// The first return value is the authentication response (aka receiver
+// handshake) that is to be sent to the remote initiator.
+func authResp(auth, sessionToken []byte, prvKey *ecdsa.PrivateKey) (
+ authResp, respNonce, initNonce, remotePubKeyS []byte,
+ randomPrivKey *ecdsa.PrivateKey,
+ remoteRandomPubKey *ecdsa.PublicKey,
+ err error,
+) {
+ // they prove that msg is meant for me,
+ // I prove I possess private key if i can read it
+ msg, err := crypto.Decrypt(prvKey, auth)
+ if err != nil {
+ return
+ }
+
+ remotePubKeyS = msg[sigLen+shaLen : sigLen+shaLen+pubLen]
+ remotePubKey, _ := importPublicKey(remotePubKeyS)
+
+ var tokenFlag byte
+ if sessionToken == nil {
+ // no session token found means we need to generate shared secret.
+ // ecies shared secret is used as initial session token for new peers
+ // generate shared key from prv and remote pubkey
+ if sessionToken, err = ecies.ImportECDSA(prvKey).GenerateShared(ecies.ImportECDSAPublic(remotePubKey), sskLen, sskLen); err != nil {
+ return
+ }
+ // tokenFlag = 0x00 // redundant
+ } else {
+ // for known peers, we use stored token from the previous session
+ tokenFlag = 0x01
+ }
+
+ // the initiator nonce is read off the end of the message
+ initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1]
+ // I prove that i own prv key (to derive shared secret, and read
+ // nonce off encrypted msg) and that I own shared secret they
+ // prove they own the private key belonging to ecdhe-random-pubk
+ // we can now reconstruct the signed message and recover the peers
+ // pubkey
+ var signedMsg = xor(sessionToken, initNonce)
+ var remoteRandomPubKeyS []byte
+ if remoteRandomPubKeyS, err = secp256k1.RecoverPubkey(signedMsg, msg[:sigLen]); err != nil {
+ return
+ }
+ // convert to ECDSA standard
+ if remoteRandomPubKey, err = importPublicKey(remoteRandomPubKeyS); err != nil {
+ return
+ }
+
+ // now we find ourselves a long task too, fill it random
+ var resp = make([]byte, authRespLen)
+ // generate shaLen long nonce
+ respNonce = resp[pubLen : pubLen+shaLen]
+ if _, err = rand.Read(respNonce); err != nil {
+ return
+ }
+ // generate random keypair for session
+ if randomPrivKey, err = crypto.GenerateKey(); err != nil {
+ return
+ }
+ // responder auth message
+ // E(remote-pubk, ecdhe-random-pubk || nonce || 0x0)
+ var randomPubKeyS []byte
+ if randomPubKeyS, err = exportPublicKey(&randomPrivKey.PublicKey); err != nil {
+ return
+ }
+ copy(resp[:pubLen], randomPubKeyS)
+ // nonce is already in the slice
+ resp[authRespLen-1] = tokenFlag
+
+ // encrypt using remote-pubk
+ // auth = eciesEncrypt(remote-pubk, msg)
+ // why not encrypt with ecdhe-random-remote
+ if authResp, err = crypto.Encrypt(remotePubKey, resp); err != nil {
+ return
+ }
+ return
+}
+
+// newSession is called after the handshake is completed. The
+// arguments are values negotiated in the handshake. The return value
+// is a new session Token to be remembered for the next time we
+// connect with this peer.
+func newSession(initNonce, respNonce []byte, privKey *ecdsa.PrivateKey, remoteRandomPubKey *ecdsa.PublicKey) ([]byte, error) {
+ // 3) Now we can trust ecdhe-random-pubk to derive new keys
+ //ecdhe-shared-secret = ecdh.agree(ecdhe-random, remote-ecdhe-random-pubk)
+ pubKey := ecies.ImportECDSAPublic(remoteRandomPubKey)
+ dhSharedSecret, err := ecies.ImportECDSA(privKey).GenerateShared(pubKey, sskLen, sskLen)
+ if err != nil {
+ return nil, err
+ }
+ sharedSecret := crypto.Sha3(dhSharedSecret, crypto.Sha3(respNonce, initNonce))
+ sessionToken := crypto.Sha3(sharedSecret)
+ return sessionToken, nil
+}
+
+// importPublicKey unmarshals 512 bit public keys.
+func importPublicKey(pubKey []byte) (pubKeyEC *ecdsa.PublicKey, err error) {
+ var pubKey65 []byte
+ switch len(pubKey) {
+ case 64:
+ // add 'uncompressed key' flag
+ pubKey65 = append([]byte{0x04}, pubKey...)
+ case 65:
+ pubKey65 = pubKey
+ default:
+ return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey))
+ }
+ return crypto.ToECDSAPub(pubKey65), nil
+}
+
+func exportPublicKey(pubKeyEC *ecdsa.PublicKey) (pubKey []byte, err error) {
+ if pubKeyEC == nil {
+ return nil, fmt.Errorf("no ECDSA public key given")
+ }
+ return crypto.FromECDSAPub(pubKeyEC)[1:], nil
+}
+
+func xor(one, other []byte) (xor []byte) {
+ xor = make([]byte, len(one))
+ for i := 0; i < len(one); i++ {
+ xor[i] = one[i] ^ other[i]
+ }
+ return xor
+}
+
+func writeProtocolHandshake(w MsgWriter, our *protoHandshake) error {
+ return EncodeMsg(w, handshakeMsg, our.Version, our.Name, our.Caps, our.ListenPort, our.ID[:])
+}
+
+func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, error) {
+ // read and handle remote handshake
+ msg, err := r.ReadMsg()
+ if err != nil {
+ return nil, err
+ }
+ if msg.Code == discMsg {
+ // disconnect before protocol handshake is valid according to the
+ // spec and we send it ourself if Server.addPeer fails.
+ var reason DiscReason
+ rlp.Decode(msg.Payload, &reason)
+ return nil, discRequestedError(reason)
+ }
+ if msg.Code != handshakeMsg {
+ return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
+ }
+ if msg.Size > baseProtocolMaxMsgSize {
+ return nil, fmt.Errorf("message too big (%d > %d)", msg.Size, baseProtocolMaxMsgSize)
+ }
+ var hs protoHandshake
+ if err := msg.Decode(&hs); err != nil {
+ return nil, err
+ }
+ // validate handshake info
+ if hs.Version != our.Version {
+ return nil, newPeerError(errP2PVersionMismatch, "required version %d, received %d\n", baseProtocolVersion, hs.Version)
+ }
+ if (hs.ID == discover.NodeID{}) {
+ return nil, newPeerError(errPubkeyInvalid, "missing")
+ }
+ return &hs, nil
+}
diff --git a/p2p/handshake_test.go b/p2p/handshake_test.go
new file mode 100644
index 000000000..06c6a6932
--- /dev/null
+++ b/p2p/handshake_test.go
@@ -0,0 +1,224 @@
+package p2p
+
+import (
+ "bytes"
+ "crypto/ecdsa"
+ "crypto/rand"
+ "net"
+ "reflect"
+ "testing"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/crypto/ecies"
+ "github.com/ethereum/go-ethereum/p2p/discover"
+)
+
+func TestPublicKeyEncoding(t *testing.T) {
+ prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
+ pub0 := &prv0.PublicKey
+ pub0s := crypto.FromECDSAPub(pub0)
+ pub1, err := importPublicKey(pub0s)
+ if err != nil {
+ t.Errorf("%v", err)
+ }
+ eciesPub1 := ecies.ImportECDSAPublic(pub1)
+ if eciesPub1 == nil {
+ t.Errorf("invalid ecdsa public key")
+ }
+ pub1s, err := exportPublicKey(pub1)
+ if err != nil {
+ t.Errorf("%v", err)
+ }
+ if len(pub1s) != 64 {
+ t.Errorf("wrong length expect 64, got", len(pub1s))
+ }
+ pub2, err := importPublicKey(pub1s)
+ if err != nil {
+ t.Errorf("%v", err)
+ }
+ pub2s, err := exportPublicKey(pub2)
+ if err != nil {
+ t.Errorf("%v", err)
+ }
+ if !bytes.Equal(pub1s, pub2s) {
+ t.Errorf("exports dont match")
+ }
+ pub2sEC := crypto.FromECDSAPub(pub2)
+ if !bytes.Equal(pub0s, pub2sEC) {
+ t.Errorf("exports dont match")
+ }
+}
+
+func TestSharedSecret(t *testing.T) {
+ prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
+ pub0 := &prv0.PublicKey
+ prv1, _ := crypto.GenerateKey()
+ pub1 := &prv1.PublicKey
+
+ ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
+ if err != nil {
+ return
+ }
+ ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
+ if err != nil {
+ return
+ }
+ t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
+ if !bytes.Equal(ss0, ss1) {
+ t.Errorf("dont match :(")
+ }
+}
+
+func TestCryptoHandshake(t *testing.T) {
+ testCryptoHandshake(newkey(), newkey(), nil, t)
+}
+
+func TestCryptoHandshakeWithToken(t *testing.T) {
+ sessionToken := make([]byte, shaLen)
+ rand.Read(sessionToken)
+ testCryptoHandshake(newkey(), newkey(), sessionToken, t)
+}
+
+func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *testing.T) {
+ var err error
+ // pub0 := &prv0.PublicKey
+ pub1 := &prv1.PublicKey
+
+ // pub0s := crypto.FromECDSAPub(pub0)
+ pub1s := crypto.FromECDSAPub(pub1)
+
+ // simulate handshake by feeding output to input
+ // initiator sends handshake 'auth'
+ auth, initNonce, randomPrivKey, err := authMsg(prv0, pub1s, sessionToken)
+ if err != nil {
+ t.Errorf("%v", err)
+ }
+ // t.Logf("-> %v", hexkey(auth))
+
+ // receiver reads auth and responds with response
+ response, remoteRecNonce, remoteInitNonce, _, remoteRandomPrivKey, remoteInitRandomPubKey, err := authResp(auth, sessionToken, prv1)
+ if err != nil {
+ t.Errorf("%v", err)
+ }
+ // t.Logf("<- %v\n", hexkey(response))
+
+ // initiator reads receiver's response and the key exchange completes
+ recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prv0)
+ if err != nil {
+ t.Errorf("completeHandshake error: %v", err)
+ }
+
+ // now both parties should have the same session parameters
+ initSessionToken, err := newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
+ if err != nil {
+ t.Errorf("newSession error: %v", err)
+ }
+
+ recSessionToken, err := newSession(remoteInitNonce, remoteRecNonce, remoteRandomPrivKey, remoteInitRandomPubKey)
+ if err != nil {
+ t.Errorf("newSession error: %v", err)
+ }
+
+ // fmt.Printf("\nauth (%v) %x\n\nresp (%v) %x\n\n", len(auth), auth, len(response), response)
+
+ // fmt.Printf("\nauth %x\ninitNonce %x\nresponse%x\nremoteRecNonce %x\nremoteInitNonce %x\nremoteRandomPubKey %x\nrecNonce %x\nremoteInitRandomPubKey %x\ninitSessionToken %x\n\n", auth, initNonce, response, remoteRecNonce, remoteInitNonce, remoteRandomPubKey, recNonce, remoteInitRandomPubKey, initSessionToken)
+
+ if !bytes.Equal(initNonce, remoteInitNonce) {
+ t.Errorf("nonces do not match")
+ }
+ if !bytes.Equal(recNonce, remoteRecNonce) {
+ t.Errorf("receiver nonces do not match")
+ }
+ if !bytes.Equal(initSessionToken, recSessionToken) {
+ t.Errorf("session tokens do not match")
+ }
+}
+
+func TestEncHandshake(t *testing.T) {
+ defer testlog(t).detach()
+
+ prv0, _ := crypto.GenerateKey()
+ prv1, _ := crypto.GenerateKey()
+ pub0s, _ := exportPublicKey(&prv0.PublicKey)
+ pub1s, _ := exportPublicKey(&prv1.PublicKey)
+ rw0, rw1 := net.Pipe()
+ tokens := make(chan []byte)
+
+ go func() {
+ token, err := outboundEncHandshake(rw0, prv0, pub1s, nil)
+ if err != nil {
+ t.Errorf("outbound side error: %v", err)
+ }
+ tokens <- token
+ }()
+ go func() {
+ token, remotePubkey, err := inboundEncHandshake(rw1, prv1, nil)
+ if err != nil {
+ t.Errorf("inbound side error: %v", err)
+ }
+ if !bytes.Equal(remotePubkey, pub0s) {
+ t.Errorf("inbound side returned wrong remote pubkey\n got: %x\n want: %x", remotePubkey, pub0s)
+ }
+ tokens <- token
+ }()
+
+ t1, t2 := <-tokens, <-tokens
+ if !bytes.Equal(t1, t2) {
+ t.Error("session token mismatch")
+ }
+}
+
+func TestSetupConn(t *testing.T) {
+ prv0, _ := crypto.GenerateKey()
+ prv1, _ := crypto.GenerateKey()
+ node0 := &discover.Node{
+ ID: discover.PubkeyID(&prv0.PublicKey),
+ IP: net.IP{1, 2, 3, 4},
+ TCPPort: 33,
+ }
+ node1 := &discover.Node{
+ ID: discover.PubkeyID(&prv1.PublicKey),
+ IP: net.IP{5, 6, 7, 8},
+ TCPPort: 44,
+ }
+ hs0 := &protoHandshake{
+ Version: baseProtocolVersion,
+ ID: node0.ID,
+ Caps: []Cap{{"a", 0}, {"b", 2}},
+ }
+ hs1 := &protoHandshake{
+ Version: baseProtocolVersion,
+ ID: node1.ID,
+ Caps: []Cap{{"c", 1}, {"d", 3}},
+ }
+ fd0, fd1 := net.Pipe()
+
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ conn0, err := setupConn(fd0, prv0, hs0, node1)
+ if err != nil {
+ t.Errorf("outbound side error: %v", err)
+ return
+ }
+ if conn0.ID != node1.ID {
+ t.Errorf("outbound conn id mismatch: got %v, want %v", conn0.ID, node1.ID)
+ }
+ if !reflect.DeepEqual(conn0.Caps, hs1.Caps) {
+ t.Errorf("outbound caps mismatch: got %v, want %v", conn0.Caps, hs1.Caps)
+ }
+ }()
+
+ conn1, err := setupConn(fd1, prv1, hs1, nil)
+ if err != nil {
+ t.Fatalf("inbound side error: %v", err)
+ }
+ if conn1.ID != node0.ID {
+ t.Errorf("inbound conn id mismatch: got %v, want %v", conn1.ID, node0.ID)
+ }
+ if !reflect.DeepEqual(conn1.Caps, hs0.Caps) {
+ t.Errorf("inbound caps mismatch: got %v, want %v", conn1.Caps, hs0.Caps)
+ }
+
+ <-done
+}
diff --git a/p2p/message.go b/p2p/message.go
index daf2bf05c..7adad4b09 100644
--- a/p2p/message.go
+++ b/p2p/message.go
@@ -1,6 +1,7 @@
package p2p
import (
+ "bufio"
"bytes"
"encoding/binary"
"errors"
@@ -8,12 +9,37 @@ import (
"io"
"io/ioutil"
"math/big"
+ "net"
+ "sync"
"sync/atomic"
+ "time"
"github.com/ethereum/go-ethereum/ethutil"
"github.com/ethereum/go-ethereum/rlp"
)
+// parameters for frameRW
+const (
+ // maximum time allowed for reading a message header.
+ // this is effectively the amount of time a connection can be idle.
+ frameReadTimeout = 1 * time.Minute
+
+ // maximum time allowed for reading the payload data of a message.
+ // this is shorter than (and distinct from) frameReadTimeout because
+ // the connection is not considered idle while a message is transferred.
+ // this also limits the payload size of messages to how much the connection
+ // can transfer within the timeout.
+ payloadReadTimeout = 5 * time.Second
+
+ // maximum amount of time allowed for writing a complete message.
+ msgWriteTimeout = 5 * time.Second
+
+ // messages smaller than this many bytes will be read at
+ // once before passing them to a protocol. this increases
+ // concurrency in the processing.
+ wholePayloadSize = 64 * 1024
+)
+
// Msg defines the structure of a p2p message.
//
// Note that a Msg can only be sent once since the Payload reader is
@@ -74,11 +100,14 @@ type MsgWriter interface {
// WriteMsg sends a message. It will block until the message's
// Payload has been consumed by the other end.
//
- // Note that messages can be sent only once.
+ // Note that messages can be sent only once because their
+ // payload reader is drained.
WriteMsg(Msg) error
}
// MsgReadWriter provides reading and writing of encoded messages.
+// Implementations should ensure that ReadMsg and WriteMsg can be
+// called simultaneously from multiple goroutines.
type MsgReadWriter interface {
MsgReader
MsgWriter
@@ -90,8 +119,45 @@ func EncodeMsg(w MsgWriter, code uint64, data ...interface{}) error {
return w.WriteMsg(NewMsg(code, data...))
}
+// frameRW is a MsgReadWriter that reads and writes devp2p message frames.
+// As required by the interface, ReadMsg and WriteMsg can be called from
+// multiple goroutines.
+type frameRW struct {
+ net.Conn // make Conn methods available. be careful.
+ bufconn *bufio.ReadWriter
+
+ // this channel is used to 'lend' bufconn to a caller of ReadMsg
+ // until the message payload has been consumed. the channel
+ // receives a value when EOF is reached on the payload, unblocking
+ // a pending call to ReadMsg.
+ rsync chan struct{}
+
+ // this mutex guards writes to bufconn.
+ writeMu sync.Mutex
+}
+
+func newFrameRW(conn net.Conn, timeout time.Duration) *frameRW {
+ rsync := make(chan struct{}, 1)
+ rsync <- struct{}{}
+ return &frameRW{
+ Conn: conn,
+ bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
+ rsync: rsync,
+ }
+}
+
var magicToken = []byte{34, 64, 8, 145}
+func (rw *frameRW) WriteMsg(msg Msg) error {
+ rw.writeMu.Lock()
+ defer rw.writeMu.Unlock()
+ rw.SetWriteDeadline(time.Now().Add(msgWriteTimeout))
+ if err := writeMsg(rw.bufconn, msg); err != nil {
+ return err
+ }
+ return rw.bufconn.Flush()
+}
+
func writeMsg(w io.Writer, msg Msg) error {
// TODO: handle case when Size + len(code) + len(listhdr) overflows uint32
code := ethutil.Encode(uint32(msg.Code))
@@ -120,31 +186,51 @@ func makeListHeader(length uint32) []byte {
return append([]byte{lenb}, enc...)
}
-// readMsg reads a message header from r.
-// It takes an rlp.ByteReader to ensure that the decoding doesn't buffer.
-func readMsg(r rlp.ByteReader) (msg Msg, err error) {
+func (rw *frameRW) ReadMsg() (msg Msg, err error) {
+ <-rw.rsync // wait until bufconn is ours
+
+ rw.SetReadDeadline(time.Now().Add(frameReadTimeout))
+
// read magic and payload size
start := make([]byte, 8)
- if _, err = io.ReadFull(r, start); err != nil {
- return msg, newPeerError(errRead, "%v", err)
+ if _, err = io.ReadFull(rw.bufconn, start); err != nil {
+ return msg, err
}
if !bytes.HasPrefix(start, magicToken) {
- return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken)
+ return msg, fmt.Errorf("bad magic token %x", start[:4])
}
size := binary.BigEndian.Uint32(start[4:])
// decode start of RLP message to get the message code
- posr := &postrack{r, 0}
+ posr := &postrack{rw.bufconn, 0}
s := rlp.NewStream(posr)
if _, err := s.List(); err != nil {
return msg, err
}
- code, err := s.Uint()
+ msg.Code, err = s.Uint()
if err != nil {
return msg, err
}
- payloadsize := size - posr.p
- return Msg{code, payloadsize, io.LimitReader(r, int64(payloadsize))}, nil
+ msg.Size = size - posr.p
+
+ rw.SetReadDeadline(time.Now().Add(payloadReadTimeout))
+
+ if msg.Size <= wholePayloadSize {
+ // msg is small, read all of it and move on to the next message.
+ pbuf := make([]byte, msg.Size)
+ if _, err := io.ReadFull(rw.bufconn, pbuf); err != nil {
+ return msg, err
+ }
+ rw.rsync <- struct{}{} // bufconn is available again
+ msg.Payload = bytes.NewReader(pbuf)
+ } else {
+ // lend bufconn to the caller until it has
+ // consumed the payload. eofSignal will send a value
+ // on rw.rsync when EOF is reached.
+ pr := &eofSignal{rw.bufconn, msg.Size, rw.rsync}
+ msg.Payload = pr
+ }
+ return msg, nil
}
// postrack wraps an rlp.ByteReader with a position counter.
@@ -167,6 +253,39 @@ func (r *postrack) ReadByte() (byte, error) {
return b, err
}
+// eofSignal wraps a reader with eof signaling. the eof channel is
+// closed when the wrapped reader returns an error or when count bytes
+// have been read.
+type eofSignal struct {
+ wrapped io.Reader
+ count uint32 // number of bytes left
+ eof chan<- struct{}
+}
+
+// note: when using eofSignal to detect whether a message payload
+// has been read, Read might not be called for zero sized messages.
+func (r *eofSignal) Read(buf []byte) (int, error) {
+ if r.count == 0 {
+ if r.eof != nil {
+ r.eof <- struct{}{}
+ r.eof = nil
+ }
+ return 0, io.EOF
+ }
+
+ max := len(buf)
+ if int(r.count) < len(buf) {
+ max = int(r.count)
+ }
+ n, err := r.wrapped.Read(buf[:max])
+ r.count -= uint32(n)
+ if (err != nil || r.count == 0) && r.eof != nil {
+ r.eof <- struct{}{} // tell Peer that msg has been consumed
+ r.eof = nil
+ }
+ return n, err
+}
+
// MsgPipe creates a message pipe. Reads on one end are matched
// with writes on the other. The pipe is full-duplex, both ends
// implement MsgReadWriter.
@@ -198,7 +317,7 @@ type MsgPipeRW struct {
func (p *MsgPipeRW) WriteMsg(msg Msg) error {
if atomic.LoadInt32(p.closed) == 0 {
consumed := make(chan struct{}, 1)
- msg.Payload = &eofSignal{msg.Payload, int64(msg.Size), consumed}
+ msg.Payload = &eofSignal{msg.Payload, msg.Size, consumed}
select {
case p.w <- msg:
if msg.Size > 0 {
diff --git a/p2p/message_test.go b/p2p/message_test.go
index 5cde9abf5..4b94ebb5f 100644
--- a/p2p/message_test.go
+++ b/p2p/message_test.go
@@ -3,12 +3,11 @@ package p2p
import (
"bytes"
"fmt"
+ "io"
"io/ioutil"
"runtime"
"testing"
"time"
-
- "github.com/ethereum/go-ethereum/ethutil"
)
func TestNewMsg(t *testing.T) {
@@ -26,51 +25,51 @@ func TestNewMsg(t *testing.T) {
}
}
-func TestEncodeDecodeMsg(t *testing.T) {
- msg := NewMsg(3, 1, "000")
- buf := new(bytes.Buffer)
- if err := writeMsg(buf, msg); err != nil {
- t.Fatalf("encodeMsg error: %v", err)
- }
- // t.Logf("encoded: %x", buf.Bytes())
+// func TestEncodeDecodeMsg(t *testing.T) {
+// msg := NewMsg(3, 1, "000")
+// buf := new(bytes.Buffer)
+// if err := writeMsg(buf, msg); err != nil {
+// t.Fatalf("encodeMsg error: %v", err)
+// }
+// // t.Logf("encoded: %x", buf.Bytes())
- decmsg, err := readMsg(buf)
- if err != nil {
- t.Fatalf("readMsg error: %v", err)
- }
- if decmsg.Code != 3 {
- t.Errorf("incorrect code %d, want %d", decmsg.Code, 3)
- }
- if decmsg.Size != 5 {
- t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
- }
+// decmsg, err := readMsg(buf)
+// if err != nil {
+// t.Fatalf("readMsg error: %v", err)
+// }
+// if decmsg.Code != 3 {
+// t.Errorf("incorrect code %d, want %d", decmsg.Code, 3)
+// }
+// if decmsg.Size != 5 {
+// t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
+// }
- var data struct {
- I uint
- S string
- }
- if err := decmsg.Decode(&data); err != nil {
- t.Fatalf("Decode error: %v", err)
- }
- if data.I != 1 {
- t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1)
- }
- if data.S != "000" {
- t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000")
- }
-}
+// var data struct {
+// I uint
+// S string
+// }
+// if err := decmsg.Decode(&data); err != nil {
+// t.Fatalf("Decode error: %v", err)
+// }
+// if data.I != 1 {
+// t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1)
+// }
+// if data.S != "000" {
+// t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000")
+// }
+// }
-func TestDecodeRealMsg(t *testing.T) {
- data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb")
- msg, err := readMsg(bytes.NewReader(data))
- if err != nil {
- t.Fatalf("unexpected error: %v", err)
- }
+// func TestDecodeRealMsg(t *testing.T) {
+// data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb")
+// msg, err := readMsg(bytes.NewReader(data))
+// if err != nil {
+// t.Fatalf("unexpected error: %v", err)
+// }
- if msg.Code != 0 {
- t.Errorf("incorrect code %d, want %d", msg.Code, 0)
- }
-}
+// if msg.Code != 0 {
+// t.Errorf("incorrect code %d, want %d", msg.Code, 0)
+// }
+// }
func ExampleMsgPipe() {
rw1, rw2 := MsgPipe()
@@ -131,3 +130,58 @@ func TestMsgPipeConcurrentClose(t *testing.T) {
go rw1.Close()
}
}
+
+func TestEOFSignal(t *testing.T) {
+ rb := make([]byte, 10)
+
+ // empty reader
+ eof := make(chan struct{}, 1)
+ sig := &eofSignal{new(bytes.Buffer), 0, eof}
+ if n, err := sig.Read(rb); n != 0 || err != io.EOF {
+ t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+ }
+ select {
+ case <-eof:
+ default:
+ t.Error("EOF chan not signaled")
+ }
+
+ // count before error
+ eof = make(chan struct{}, 1)
+ sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
+ if n, err := sig.Read(rb); n != 4 || err != nil {
+ t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+ }
+ select {
+ case <-eof:
+ default:
+ t.Error("EOF chan not signaled")
+ }
+
+ // error before count
+ eof = make(chan struct{}, 1)
+ sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
+ if n, err := sig.Read(rb); n != 4 || err != nil {
+ t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+ }
+ if n, err := sig.Read(rb); n != 0 || err != io.EOF {
+ t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+ }
+ select {
+ case <-eof:
+ default:
+ t.Error("EOF chan not signaled")
+ }
+
+ // no signal if neither occurs
+ eof = make(chan struct{}, 1)
+ sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
+ if n, err := sig.Read(rb); n != 10 || err != nil {
+ t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+ }
+ select {
+ case <-eof:
+ t.Error("unexpected EOF signal")
+ default:
+ }
+}
diff --git a/p2p/nat.go b/p2p/nat.go
deleted file mode 100644
index 9b771c3e8..000000000
--- a/p2p/nat.go
+++ /dev/null
@@ -1,23 +0,0 @@
-package p2p
-
-import (
- "fmt"
- "net"
-)
-
-func ParseNAT(natType string, gateway string) (nat NAT, err error) {
- switch natType {
- case "UPNP":
- nat = UPNP()
- case "PMP":
- ip := net.ParseIP(gateway)
- if ip == nil {
- return nil, fmt.Errorf("cannot resolve PMP gateway IP %s", gateway)
- }
- nat = PMP(ip)
- case "":
- default:
- return nil, fmt.Errorf("unrecognised NAT type '%s'", natType)
- }
- return
-}
diff --git a/p2p/nat/nat.go b/p2p/nat/nat.go
new file mode 100644
index 000000000..12d355ba1
--- /dev/null
+++ b/p2p/nat/nat.go
@@ -0,0 +1,235 @@
+// Package nat provides access to common port mapping protocols.
+package nat
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/ethereum/go-ethereum/logger"
+ "github.com/jackpal/go-nat-pmp"
+)
+
+var log = logger.NewLogger("P2P NAT")
+
+// An implementation of nat.Interface can map local ports to ports
+// accessible from the Internet.
+type Interface interface {
+ // These methods manage a mapping between a port on the local
+ // machine to a port that can be connected to from the internet.
+ //
+ // protocol is "UDP" or "TCP". Some implementations allow setting
+ // a display name for the mapping. The mapping may be removed by
+ // the gateway when its lifetime ends.
+ AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error
+ DeleteMapping(protocol string, extport, intport int) error
+
+ // This method should return the external (Internet-facing)
+ // address of the gateway device.
+ ExternalIP() (net.IP, error)
+
+ // Should return name of the method. This is used for logging.
+ String() string
+}
+
+// Parse parses a NAT interface description.
+// The following formats are currently accepted.
+// Note that mechanism names are not case-sensitive.
+//
+// "" or "none" return nil
+// "extip:77.12.33.4" will assume the local machine is reachable on the given IP
+// "any" uses the first auto-detected mechanism
+// "upnp" uses the Universal Plug and Play protocol
+// "pmp" uses NAT-PMP with an auto-detected gateway address
+// "pmp:192.168.0.1" uses NAT-PMP with the given gateway address
+func Parse(spec string) (Interface, error) {
+ var (
+ parts = strings.SplitN(spec, ":", 2)
+ mech = strings.ToLower(parts[0])
+ ip net.IP
+ )
+ if len(parts) > 1 {
+ ip = net.ParseIP(parts[1])
+ if ip == nil {
+ return nil, errors.New("invalid IP address")
+ }
+ }
+ switch mech {
+ case "", "none", "off":
+ return nil, nil
+ case "any", "auto", "on":
+ return Any(), nil
+ case "extip", "ip":
+ if ip == nil {
+ return nil, errors.New("missing IP address")
+ }
+ return ExtIP(ip), nil
+ case "upnp":
+ return UPnP(), nil
+ case "pmp", "natpmp", "nat-pmp":
+ return PMP(ip), nil
+ default:
+ return nil, fmt.Errorf("unknown mechanism %q", parts[0])
+ }
+}
+
+const (
+ mapTimeout = 20 * time.Minute
+ mapUpdateInterval = 15 * time.Minute
+)
+
+// Map adds a port mapping on m and keeps it alive until c is closed.
+// This function is typically invoked in its own goroutine.
+func Map(m Interface, c chan struct{}, protocol string, extport, intport int, name string) {
+ refresh := time.NewTimer(mapUpdateInterval)
+ defer func() {
+ refresh.Stop()
+ log.Debugf("Deleting port mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m)
+ m.DeleteMapping(protocol, extport, intport)
+ }()
+ log.Debugf("add mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m)
+ if err := m.AddMapping(protocol, intport, extport, name, mapTimeout); err != nil {
+ log.Errorf("mapping error: %v\n", err)
+ }
+ for {
+ select {
+ case _, ok := <-c:
+ if !ok {
+ return
+ }
+ case <-refresh.C:
+ log.DebugDetailf("refresh mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m)
+ if err := m.AddMapping(protocol, intport, extport, name, mapTimeout); err != nil {
+ log.Errorf("mapping error: %v\n", err)
+ }
+ refresh.Reset(mapUpdateInterval)
+ }
+ }
+}
+
+// ExtIP assumes that the local machine is reachable on the given
+// external IP address, and that any required ports were mapped manually.
+// Mapping operations will not return an error but won't actually do anything.
+func ExtIP(ip net.IP) Interface {
+ if ip == nil {
+ panic("IP must not be nil")
+ }
+ return extIP(ip)
+}
+
+type extIP net.IP
+
+func (n extIP) ExternalIP() (net.IP, error) { return net.IP(n), nil }
+func (n extIP) String() string { return fmt.Sprintf("ExtIP(%v)", net.IP(n)) }
+
+// These do nothing.
+func (extIP) AddMapping(string, int, int, string, time.Duration) error { return nil }
+func (extIP) DeleteMapping(string, int, int) error { return nil }
+
+// Any returns a port mapper that tries to discover any supported
+// mechanism on the local network.
+func Any() Interface {
+ // TODO: attempt to discover whether the local machine has an
+ // Internet-class address. Return ExtIP in this case.
+ return startautodisc("UPnP or NAT-PMP", func() Interface {
+ found := make(chan Interface, 2)
+ go func() { found <- discoverUPnP() }()
+ go func() { found <- discoverPMP() }()
+ for i := 0; i < cap(found); i++ {
+ if c := <-found; c != nil {
+ return c
+ }
+ }
+ return nil
+ })
+}
+
+// UPnP returns a port mapper that uses UPnP. It will attempt to
+// discover the address of your router using UDP broadcasts.
+func UPnP() Interface {
+ return startautodisc("UPnP", discoverUPnP)
+}
+
+// PMP returns a port mapper that uses NAT-PMP. The provided gateway
+// address should be the IP of your router. If the given gateway
+// address is nil, PMP will attempt to auto-discover the router.
+func PMP(gateway net.IP) Interface {
+ if gateway != nil {
+ return &pmp{gw: gateway, c: natpmp.NewClient(gateway)}
+ }
+ return startautodisc("NAT-PMP", discoverPMP)
+}
+
+// autodisc represents a port mapping mechanism that is still being
+// auto-discovered. Calls to the Interface methods on this type will
+// wait until the discovery is done and then call the method on the
+// discovered mechanism.
+//
+// This type is useful because discovery can take a while but we
+// want return an Interface value from UPnP, PMP and Auto immediately.
+type autodisc struct {
+ what string
+ done <-chan Interface
+
+ mu sync.Mutex
+ found Interface
+}
+
+func startautodisc(what string, doit func() Interface) Interface {
+ // TODO: monitor network configuration and rerun doit when it changes.
+ done := make(chan Interface)
+ ad := &autodisc{what: what, done: done}
+ go func() { done <- doit(); close(done) }()
+ return ad
+}
+
+func (n *autodisc) AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
+ if err := n.wait(); err != nil {
+ return err
+ }
+ return n.found.AddMapping(protocol, extport, intport, name, lifetime)
+}
+
+func (n *autodisc) DeleteMapping(protocol string, extport, intport int) error {
+ if err := n.wait(); err != nil {
+ return err
+ }
+ return n.found.DeleteMapping(protocol, extport, intport)
+}
+
+func (n *autodisc) ExternalIP() (net.IP, error) {
+ if err := n.wait(); err != nil {
+ return nil, err
+ }
+ return n.found.ExternalIP()
+}
+
+func (n *autodisc) String() string {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+ if n.found == nil {
+ return n.what
+ } else {
+ return n.found.String()
+ }
+}
+
+func (n *autodisc) wait() error {
+ n.mu.Lock()
+ found := n.found
+ n.mu.Unlock()
+ if found != nil {
+ // already discovered
+ return nil
+ }
+ if found = <-n.done; found == nil {
+ return errors.New("no devices discovered")
+ }
+ n.mu.Lock()
+ n.found = found
+ n.mu.Unlock()
+ return nil
+}
diff --git a/p2p/nat/natpmp.go b/p2p/nat/natpmp.go
new file mode 100644
index 000000000..f249c6073
--- /dev/null
+++ b/p2p/nat/natpmp.go
@@ -0,0 +1,115 @@
+package nat
+
+import (
+ "fmt"
+ "net"
+ "strings"
+ "time"
+
+ "github.com/jackpal/go-nat-pmp"
+)
+
+// natPMPClient adapts the NAT-PMP protocol implementation so it conforms to
+// the common interface.
+type pmp struct {
+ gw net.IP
+ c *natpmp.Client
+}
+
+func (n *pmp) String() string {
+ return fmt.Sprintf("NAT-PMP(%v)", n.gw)
+}
+
+func (n *pmp) ExternalIP() (net.IP, error) {
+ response, err := n.c.GetExternalAddress()
+ if err != nil {
+ return nil, err
+ }
+ return response.ExternalIPAddress[:], nil
+}
+
+func (n *pmp) AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
+ if lifetime <= 0 {
+ return fmt.Errorf("lifetime must not be <= 0")
+ }
+ // Note order of port arguments is switched between our
+ // AddMapping and the client's AddPortMapping.
+ _, err := n.c.AddPortMapping(strings.ToLower(protocol), intport, extport, int(lifetime/time.Second))
+ return err
+}
+
+func (n *pmp) DeleteMapping(protocol string, extport, intport int) (err error) {
+ // To destroy a mapping, send an add-port with an internalPort of
+ // the internal port to destroy, an external port of zero and a
+ // time of zero.
+ _, err = n.c.AddPortMapping(strings.ToLower(protocol), intport, 0, 0)
+ return err
+}
+
+func discoverPMP() Interface {
+ // run external address lookups on all potential gateways
+ gws := potentialGateways()
+ found := make(chan *pmp, len(gws))
+ for i := range gws {
+ gw := gws[i]
+ go func() {
+ c := natpmp.NewClient(gw)
+ if _, err := c.GetExternalAddress(); err != nil {
+ found <- nil
+ } else {
+ found <- &pmp{gw, c}
+ }
+ }()
+ }
+ // return the one that responds first.
+ // discovery needs to be quick, so we stop caring about
+ // any responses after a very short timeout.
+ timeout := time.NewTimer(1 * time.Second)
+ defer timeout.Stop()
+ for _ = range gws {
+ select {
+ case c := <-found:
+ if c != nil {
+ return c
+ }
+ case <-timeout.C:
+ return nil
+ }
+ }
+ return nil
+}
+
+var (
+ // LAN IP ranges
+ _, lan10, _ = net.ParseCIDR("10.0.0.0/8")
+ _, lan176, _ = net.ParseCIDR("172.16.0.0/12")
+ _, lan192, _ = net.ParseCIDR("192.168.0.0/16")
+)
+
+// TODO: improve this. We currently assume that (on most networks)
+// the router is X.X.X.1 in a local LAN range.
+func potentialGateways() (gws []net.IP) {
+ ifaces, err := net.Interfaces()
+ if err != nil {
+ return nil
+ }
+ for _, iface := range ifaces {
+ ifaddrs, err := iface.Addrs()
+ if err != nil {
+ return gws
+ }
+ for _, addr := range ifaddrs {
+ switch x := addr.(type) {
+ case *net.IPNet:
+ if lan10.Contains(x.IP) || lan176.Contains(x.IP) || lan192.Contains(x.IP) {
+ ip := x.IP.Mask(x.Mask).To4()
+ if ip != nil {
+ ip[3] = ip[3] | 0x01
+ gws = append(gws, ip)
+ }
+ }
+ }
+ }
+ }
+ return gws
+}
diff --git a/p2p/nat/natupnp.go b/p2p/nat/natupnp.go
new file mode 100644
index 000000000..ef7765e8d
--- /dev/null
+++ b/p2p/nat/natupnp.go
@@ -0,0 +1,149 @@
+package nat
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "strings"
+ "time"
+
+ "github.com/huin/goupnp"
+ "github.com/huin/goupnp/dcps/internetgateway1"
+ "github.com/huin/goupnp/dcps/internetgateway2"
+)
+
+type upnp struct {
+ dev *goupnp.RootDevice
+ service string
+ client upnpClient
+}
+
+type upnpClient interface {
+ GetExternalIPAddress() (string, error)
+ AddPortMapping(string, uint16, string, uint16, string, bool, string, uint32) error
+ DeletePortMapping(string, uint16, string) error
+ GetNATRSIPStatus() (sip bool, nat bool, err error)
+}
+
+func (n *upnp) ExternalIP() (addr net.IP, err error) {
+ ipString, err := n.client.GetExternalIPAddress()
+ if err != nil {
+ return nil, err
+ }
+ ip := net.ParseIP(ipString)
+ if ip == nil {
+ return nil, errors.New("bad IP in response")
+ }
+ return ip, nil
+}
+
+func (n *upnp) AddMapping(protocol string, extport, intport int, desc string, lifetime time.Duration) error {
+ ip, err := n.internalAddress()
+ if err != nil {
+ return nil
+ }
+ protocol = strings.ToUpper(protocol)
+ lifetimeS := uint32(lifetime / time.Second)
+ return n.client.AddPortMapping("", uint16(extport), protocol, uint16(intport), ip.String(), true, desc, lifetimeS)
+}
+
+func (n *upnp) internalAddress() (net.IP, error) {
+ devaddr, err := net.ResolveUDPAddr("udp4", n.dev.URLBase.Host)
+ if err != nil {
+ return nil, err
+ }
+ ifaces, err := net.Interfaces()
+ if err != nil {
+ return nil, err
+ }
+ for _, iface := range ifaces {
+ addrs, err := iface.Addrs()
+ if err != nil {
+ return nil, err
+ }
+ for _, addr := range addrs {
+ switch x := addr.(type) {
+ case *net.IPNet:
+ if x.Contains(devaddr.IP) {
+ return x.IP, nil
+ }
+ }
+ }
+ }
+ return nil, fmt.Errorf("could not find local address in same net as %v", devaddr)
+}
+
+func (n *upnp) DeleteMapping(protocol string, extport, intport int) error {
+ return n.client.DeletePortMapping("", uint16(extport), strings.ToUpper(protocol))
+}
+
+func (n *upnp) String() string {
+ return "UPNP " + n.service
+}
+
+// discoverUPnP searches for Internet Gateway Devices
+// and returns the first one it can find on the local network.
+func discoverUPnP() Interface {
+ found := make(chan *upnp, 2)
+ // IGDv1
+ go discover(found, internetgateway1.URN_WANConnectionDevice_1, func(dev *goupnp.RootDevice, sc goupnp.ServiceClient) *upnp {
+ switch sc.Service.ServiceType {
+ case internetgateway1.URN_WANIPConnection_1:
+ return &upnp{dev, "IGDv1-IP1", &internetgateway1.WANIPConnection1{sc}}
+ case internetgateway1.URN_WANPPPConnection_1:
+ return &upnp{dev, "IGDv1-PPP1", &internetgateway1.WANPPPConnection1{sc}}
+ }
+ return nil
+ })
+ // IGDv2
+ go discover(found, internetgateway2.URN_WANConnectionDevice_2, func(dev *goupnp.RootDevice, sc goupnp.ServiceClient) *upnp {
+ switch sc.Service.ServiceType {
+ case internetgateway2.URN_WANIPConnection_1:
+ return &upnp{dev, "IGDv2-IP1", &internetgateway2.WANIPConnection1{sc}}
+ case internetgateway2.URN_WANIPConnection_2:
+ return &upnp{dev, "IGDv2-IP2", &internetgateway2.WANIPConnection2{sc}}
+ case internetgateway2.URN_WANPPPConnection_1:
+ return &upnp{dev, "IGDv2-PPP1", &internetgateway2.WANPPPConnection1{sc}}
+ }
+ return nil
+ })
+ for i := 0; i < cap(found); i++ {
+ if c := <-found; c != nil {
+ return c
+ }
+ }
+ return nil
+}
+
+func discover(out chan<- *upnp, target string, matcher func(*goupnp.RootDevice, goupnp.ServiceClient) *upnp) {
+ devs, err := goupnp.DiscoverDevices(target)
+ if err != nil {
+ return
+ }
+ found := false
+ for i := 0; i < len(devs) && !found; i++ {
+ if devs[i].Root == nil {
+ continue
+ }
+ devs[i].Root.Device.VisitServices(func(service *goupnp.Service) {
+ if found {
+ return
+ }
+ // check for a matching IGD service
+ sc := goupnp.ServiceClient{service.NewSOAPClient(), devs[i].Root, service}
+ upnp := matcher(devs[i].Root, sc)
+ if upnp == nil {
+ return
+ }
+ // check whether port mapping is enabled
+ if _, nat, err := upnp.client.GetNATRSIPStatus(); err != nil || !nat {
+ return
+ }
+ out <- upnp
+ found = true
+ })
+ }
+ if !found {
+ out <- nil
+ }
+}
diff --git a/p2p/natpmp.go b/p2p/natpmp.go
deleted file mode 100644
index 6714678c4..000000000
--- a/p2p/natpmp.go
+++ /dev/null
@@ -1,55 +0,0 @@
-package p2p
-
-import (
- "fmt"
- "net"
- "time"
-
- natpmp "github.com/jackpal/go-nat-pmp"
-)
-
-// Adapt the NAT-PMP protocol to the NAT interface
-
-// TODO:
-// + Register for changes to the external address.
-// + Re-register port mapping when router reboots.
-// + A mechanism for keeping a port mapping registered.
-// + Discover gateway address automatically.
-
-type natPMPClient struct {
- client *natpmp.Client
-}
-
-// PMP returns a NAT traverser that uses NAT-PMP. The provided gateway
-// address should be the IP of your router.
-func PMP(gateway net.IP) (nat NAT) {
- return &natPMPClient{natpmp.NewClient(gateway)}
-}
-
-func (*natPMPClient) String() string {
- return "NAT-PMP"
-}
-
-func (n *natPMPClient) GetExternalAddress() (net.IP, error) {
- response, err := n.client.GetExternalAddress()
- if err != nil {
- return nil, err
- }
- return response.ExternalIPAddress[:], nil
-}
-
-func (n *natPMPClient) AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
- if lifetime <= 0 {
- return fmt.Errorf("lifetime must not be <= 0")
- }
- // Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping.
- _, err := n.client.AddPortMapping(protocol, intport, extport, int(lifetime/time.Second))
- return err
-}
-
-func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
- // To destroy a mapping, send an add-port with
- // an internalPort of the internal port to destroy, an external port of zero and a time of zero.
- _, err = n.client.AddPortMapping(protocol, internalPort, 0, 0)
- return
-}
diff --git a/p2p/natupnp.go b/p2p/natupnp.go
deleted file mode 100644
index 2e0d8ce8d..000000000
--- a/p2p/natupnp.go
+++ /dev/null
@@ -1,341 +0,0 @@
-package p2p
-
-// Just enough UPnP to be able to forward ports
-//
-
-import (
- "bytes"
- "encoding/xml"
- "errors"
- "fmt"
- "net"
- "net/http"
- "os"
- "strconv"
- "strings"
- "time"
-)
-
-const (
- upnpDiscoverAttempts = 3
- upnpDiscoverTimeout = 5 * time.Second
-)
-
-// UPNP returns a NAT port mapper that uses UPnP. It will attempt to
-// discover the address of your router using UDP broadcasts.
-func UPNP() NAT {
- return &upnpNAT{}
-}
-
-type upnpNAT struct {
- serviceURL string
- ourIP string
-}
-
-func (n *upnpNAT) String() string {
- return "UPNP"
-}
-
-func (n *upnpNAT) discover() error {
- if n.serviceURL != "" {
- // already discovered
- return nil
- }
-
- ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900")
- if err != nil {
- return err
- }
- // TODO: try on all network interfaces simultaneously.
- // Broadcasting on 0.0.0.0 could select a random interface
- // to send on (platform specific).
- conn, err := net.ListenPacket("udp4", ":0")
- if err != nil {
- return err
- }
- defer conn.Close()
-
- conn.SetDeadline(time.Now().Add(10 * time.Second))
- st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n"
- buf := bytes.NewBufferString(
- "M-SEARCH * HTTP/1.1\r\n" +
- "HOST: 239.255.255.250:1900\r\n" +
- st +
- "MAN: \"ssdp:discover\"\r\n" +
- "MX: 2\r\n\r\n")
- message := buf.Bytes()
- answerBytes := make([]byte, 1024)
- for i := 0; i < upnpDiscoverAttempts; i++ {
- _, err = conn.WriteTo(message, ssdp)
- if err != nil {
- return err
- }
- nn, _, err := conn.ReadFrom(answerBytes)
- if err != nil {
- continue
- }
- answer := string(answerBytes[0:nn])
- if strings.Index(answer, "\r\n"+st) < 0 {
- continue
- }
- // HTTP header field names are case-insensitive.
- // http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
- locString := "\r\nlocation: "
- answer = strings.ToLower(answer)
- locIndex := strings.Index(answer, locString)
- if locIndex < 0 {
- continue
- }
- loc := answer[locIndex+len(locString):]
- endIndex := strings.Index(loc, "\r\n")
- if endIndex < 0 {
- continue
- }
- locURL := loc[0:endIndex]
- var serviceURL string
- serviceURL, err = getServiceURL(locURL)
- if err != nil {
- return err
- }
- var ourIP string
- ourIP, err = getOurIP()
- if err != nil {
- return err
- }
- n.serviceURL = serviceURL
- n.ourIP = ourIP
- return nil
- }
- return errors.New("UPnP port discovery failed.")
-}
-
-func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
- if err := n.discover(); err != nil {
- return nil, err
- }
- info, err := n.getStatusInfo()
- return net.ParseIP(info.externalIpAddress), err
-}
-
-func (n *upnpNAT) AddPortMapping(protocol string, extport, intport int, description string, lifetime time.Duration) error {
- if err := n.discover(); err != nil {
- return err
- }
-
- // A single concatenation would break ARM compilation.
- message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
- "<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(extport)
- message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
- message += "<NewInternalPort>" + strconv.Itoa(extport) + "</NewInternalPort>" +
- "<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
- "<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
- message += description +
- "</NewPortMappingDescription><NewLeaseDuration>" + fmt.Sprint(lifetime/time.Second) +
- "</NewLeaseDuration></u:AddPortMapping>"
-
- // TODO: check response to see if the port was forwarded
- _, err := soapRequest(n.serviceURL, "AddPortMapping", message)
- return err
-}
-
-func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) error {
- if err := n.discover(); err != nil {
- return err
- }
-
- message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
- "<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
- "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
- "</u:DeletePortMapping>"
-
- // TODO: check response to see if the port was deleted
- _, err := soapRequest(n.serviceURL, "DeletePortMapping", message)
- return err
-}
-
-type statusInfo struct {
- externalIpAddress string
-}
-
-func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
- message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
- "</u:GetStatusInfo>"
-
- var response *http.Response
- response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
- if err != nil {
- return
- }
-
- // TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
-
- response.Body.Close()
- return
-}
-
-// service represents the Service type in an UPnP xml description.
-// Only the parts we care about are present and thus the xml may have more
-// fields than present in the structure.
-type service struct {
- ServiceType string `xml:"serviceType"`
- ControlURL string `xml:"controlURL"`
-}
-
-// deviceList represents the deviceList type in an UPnP xml description.
-// Only the parts we care about are present and thus the xml may have more
-// fields than present in the structure.
-type deviceList struct {
- XMLName xml.Name `xml:"deviceList"`
- Device []device `xml:"device"`
-}
-
-// serviceList represents the serviceList type in an UPnP xml description.
-// Only the parts we care about are present and thus the xml may have more
-// fields than present in the structure.
-type serviceList struct {
- XMLName xml.Name `xml:"serviceList"`
- Service []service `xml:"service"`
-}
-
-// device represents the device type in an UPnP xml description.
-// Only the parts we care about are present and thus the xml may have more
-// fields than present in the structure.
-type device struct {
- XMLName xml.Name `xml:"device"`
- DeviceType string `xml:"deviceType"`
- DeviceList deviceList `xml:"deviceList"`
- ServiceList serviceList `xml:"serviceList"`
-}
-
-// specVersion represents the specVersion in a UPnP xml description.
-// Only the parts we care about are present and thus the xml may have more
-// fields than present in the structure.
-type specVersion struct {
- XMLName xml.Name `xml:"specVersion"`
- Major int `xml:"major"`
- Minor int `xml:"minor"`
-}
-
-// root represents the Root document for a UPnP xml description.
-// Only the parts we care about are present and thus the xml may have more
-// fields than present in the structure.
-type root struct {
- XMLName xml.Name `xml:"root"`
- SpecVersion specVersion
- Device device
-}
-
-func getChildDevice(d *device, deviceType string) *device {
- dl := d.DeviceList.Device
- for i := 0; i < len(dl); i++ {
- if dl[i].DeviceType == deviceType {
- return &dl[i]
- }
- }
- return nil
-}
-
-func getChildService(d *device, serviceType string) *service {
- sl := d.ServiceList.Service
- for i := 0; i < len(sl); i++ {
- if sl[i].ServiceType == serviceType {
- return &sl[i]
- }
- }
- return nil
-}
-
-func getOurIP() (ip string, err error) {
- hostname, err := os.Hostname()
- if err != nil {
- return
- }
- p, err := net.LookupIP(hostname)
- if err != nil && len(p) > 0 {
- return
- }
- return p[0].String(), nil
-}
-
-func getServiceURL(rootURL string) (url string, err error) {
- r, err := http.Get(rootURL)
- if err != nil {
- return
- }
- defer r.Body.Close()
- if r.StatusCode >= 400 {
- err = errors.New(string(r.StatusCode))
- return
- }
- var root root
- err = xml.NewDecoder(r.Body).Decode(&root)
-
- if err != nil {
- return
- }
- a := &root.Device
- if a.DeviceType != "urn:schemas-upnp-org:device:InternetGatewayDevice:1" {
- err = errors.New("No InternetGatewayDevice")
- return
- }
- b := getChildDevice(a, "urn:schemas-upnp-org:device:WANDevice:1")
- if b == nil {
- err = errors.New("No WANDevice")
- return
- }
- c := getChildDevice(b, "urn:schemas-upnp-org:device:WANConnectionDevice:1")
- if c == nil {
- err = errors.New("No WANConnectionDevice")
- return
- }
- d := getChildService(c, "urn:schemas-upnp-org:service:WANIPConnection:1")
- if d == nil {
- err = errors.New("No WANIPConnection")
- return
- }
- url = combineURL(rootURL, d.ControlURL)
- return
-}
-
-func combineURL(rootURL, subURL string) string {
- protocolEnd := "://"
- protoEndIndex := strings.Index(rootURL, protocolEnd)
- a := rootURL[protoEndIndex+len(protocolEnd):]
- rootIndex := strings.Index(a, "/")
- return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL
-}
-
-func soapRequest(url, function, message string) (r *http.Response, err error) {
- fullMessage := "<?xml version=\"1.0\" ?>" +
- "<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\r\n" +
- "<s:Body>" + message + "</s:Body></s:Envelope>"
-
- req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage))
- if err != nil {
- return
- }
- req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"")
- req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3")
- //req.Header.Set("Transfer-Encoding", "chunked")
- req.Header.Set("SOAPAction", "\"urn:schemas-upnp-org:service:WANIPConnection:1#"+function+"\"")
- req.Header.Set("Connection", "Close")
- req.Header.Set("Cache-Control", "no-cache")
- req.Header.Set("Pragma", "no-cache")
-
- r, err = http.DefaultClient.Do(req)
- if err != nil {
- return
- }
-
- if r.Body != nil {
- defer r.Body.Close()
- }
-
- if r.StatusCode >= 400 {
- // log.Stderr(function, r.StatusCode)
- err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function)
- r = nil
- return
- }
- return
-}
diff --git a/p2p/peer.go b/p2p/peer.go
index 2380a3285..fb027c834 100644
--- a/p2p/peer.go
+++ b/p2p/peer.go
@@ -1,8 +1,7 @@
package p2p
import (
- "bufio"
- "bytes"
+ "errors"
"fmt"
"io"
"io/ioutil"
@@ -11,159 +10,78 @@ import (
"sync"
"time"
- "github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/logger"
+ "github.com/ethereum/go-ethereum/p2p/discover"
+ "github.com/ethereum/go-ethereum/rlp"
)
-// peerAddr is the structure of a peer list element.
-// It is also a valid net.Addr.
-type peerAddr struct {
- IP net.IP
- Port uint64
- Pubkey []byte // optional
-}
-
-func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr {
- n := addr.Network()
- if n != "tcp" && n != "tcp4" && n != "tcp6" {
- // for testing with non-TCP
- return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey}
- }
- ta := addr.(*net.TCPAddr)
- return &peerAddr{ta.IP, uint64(ta.Port), pubkey}
-}
-
-func (d peerAddr) Network() string {
- if d.IP.To4() != nil {
- return "tcp4"
- } else {
- return "tcp6"
- }
-}
+const (
+ baseProtocolVersion = 3
+ baseProtocolLength = uint64(16)
+ baseProtocolMaxMsgSize = 10 * 1024 * 1024
-func (d peerAddr) String() string {
- return fmt.Sprintf("%v:%d", d.IP, d.Port)
-}
+ disconnectGracePeriod = 2 * time.Second
+ pingInterval = 15 * time.Second
+)
-func (d *peerAddr) RlpData() interface{} {
- return []interface{}{string(d.IP), d.Port, d.Pubkey}
-}
+const (
+ // devp2p message codes
+ handshakeMsg = 0x00
+ discMsg = 0x01
+ pingMsg = 0x02
+ pongMsg = 0x03
+ getPeersMsg = 0x04
+ peersMsg = 0x05
+)
-// Peer represents a remote peer.
+// Peer represents a connected remote node.
type Peer struct {
// Peers have all the log methods.
// Use them to display messages related to the peer.
*logger.Logger
- infolock sync.Mutex
- identity ClientIdentity
- caps []Cap
- listenAddr *peerAddr // what remote peer is listening on
- dialAddr *peerAddr // non-nil if dialing
-
- // The mutex protects the connection
- // so only one protocol can write at a time.
- writeMu sync.Mutex
- conn net.Conn
- bufconn *bufio.ReadWriter
-
- // These fields maintain the running protocols.
- protocols []Protocol
- runBaseProtocol bool // for testing
-
- runlock sync.RWMutex // protects running
- running map[string]*proto
+ rw *conn
+ running map[string]*protoRW
protoWG sync.WaitGroup
protoErr chan error
closed chan struct{}
disc chan DiscReason
-
- activity event.TypeMux // for activity events
-
- slot int // index into Server peer list
-
- // These fields are kept so base protocol can access them.
- // TODO: this should be one or more interfaces
- ourID ClientIdentity // client id of the Server
- ourListenAddr *peerAddr // listen addr of Server, nil if not listening
- newPeerAddr chan<- *peerAddr // tell server about received peers
- otherPeers func() []*Peer // should return the list of all peers
- pubkeyHook func(*peerAddr) error // called at end of handshake to validate pubkey
}
// NewPeer returns a peer for testing purposes.
-func NewPeer(id ClientIdentity, caps []Cap) *Peer {
- conn, _ := net.Pipe()
- peer := newPeer(conn, nil, nil)
- peer.setHandshakeInfo(id, nil, caps)
- close(peer.closed)
+func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
+ pipe, _ := net.Pipe()
+ conn := newConn(pipe, &protoHandshake{ID: id, Name: name, Caps: caps})
+ peer := newPeer(conn, nil)
+ close(peer.closed) // ensures Disconnect doesn't block
return peer
}
-func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
- p := newPeer(conn, server.Protocols, dialAddr)
- p.ourID = server.Identity
- p.newPeerAddr = server.peerConnect
- p.otherPeers = server.Peers
- p.pubkeyHook = server.verifyPeer
- p.runBaseProtocol = true
-
- // laddr can be updated concurrently by NAT traversal.
- // newServerPeer must be called with the server lock held.
- if server.laddr != nil {
- p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey())
- }
- return p
-}
-
-func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer {
- p := &Peer{
- Logger: logger.NewLogger("P2P " + conn.RemoteAddr().String()),
- conn: conn,
- dialAddr: dialAddr,
- bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
- protocols: protocols,
- running: make(map[string]*proto),
- disc: make(chan DiscReason),
- protoErr: make(chan error),
- closed: make(chan struct{}),
- }
- return p
+// ID returns the node's public key.
+func (p *Peer) ID() discover.NodeID {
+ return p.rw.ID
}
-// Identity returns the client identity of the remote peer. The
-// identity can be nil if the peer has not yet completed the
-// handshake.
-func (p *Peer) Identity() ClientIdentity {
- p.infolock.Lock()
- defer p.infolock.Unlock()
- return p.identity
+// Name returns the node name that the remote node advertised.
+func (p *Peer) Name() string {
+ return p.rw.Name
}
// Caps returns the capabilities (supported subprotocols) of the remote peer.
func (p *Peer) Caps() []Cap {
- p.infolock.Lock()
- defer p.infolock.Unlock()
- return p.caps
-}
-
-func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) {
- p.infolock.Lock()
- p.identity = id
- p.listenAddr = laddr
- p.caps = caps
- p.infolock.Unlock()
+ // TODO: maybe return copy
+ return p.rw.Caps
}
// RemoteAddr returns the remote address of the network connection.
func (p *Peer) RemoteAddr() net.Addr {
- return p.conn.RemoteAddr()
+ return p.rw.RemoteAddr()
}
// LocalAddr returns the local address of the network connection.
func (p *Peer) LocalAddr() net.Addr {
- return p.conn.LocalAddr()
+ return p.rw.LocalAddr()
}
// Disconnect terminates the peer connection with the given reason.
@@ -177,198 +95,166 @@ func (p *Peer) Disconnect(reason DiscReason) {
// String implements fmt.Stringer.
func (p *Peer) String() string {
- kind := "inbound"
- p.infolock.Lock()
- if p.dialAddr != nil {
- kind = "outbound"
- }
- p.infolock.Unlock()
- return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind)
+ return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr())
}
-const (
- // maximum amount of time allowed for reading a message
- msgReadTimeout = 5 * time.Second
- // maximum amount of time allowed for writing a message
- msgWriteTimeout = 5 * time.Second
- // messages smaller than this many bytes will be read at
- // once before passing them to a protocol.
- wholePayloadSize = 64 * 1024
-)
-
-var (
- inactivityTimeout = 2 * time.Second
- disconnectGracePeriod = 2 * time.Second
-)
+func newPeer(conn *conn, protocols []Protocol) *Peer {
+ logtag := fmt.Sprintf("Peer %.8x %v", conn.ID[:], conn.RemoteAddr())
+ p := &Peer{
+ Logger: logger.NewLogger(logtag),
+ rw: conn,
+ running: matchProtocols(protocols, conn.Caps, conn),
+ disc: make(chan DiscReason),
+ protoErr: make(chan error),
+ closed: make(chan struct{}),
+ }
+ return p
+}
-func (p *Peer) loop() (reason DiscReason, err error) {
- defer p.activity.Stop()
+func (p *Peer) run() DiscReason {
+ var readErr = make(chan error, 1)
defer p.closeProtocols()
defer close(p.closed)
- defer p.conn.Close()
-
- // read loop
- readMsg := make(chan Msg)
- readErr := make(chan error)
- readNext := make(chan bool, 1)
- protoDone := make(chan struct{}, 1)
- go p.readLoop(readMsg, readErr, readNext)
- readNext <- true
-
- if p.runBaseProtocol {
- p.startBaseProtocol()
- }
+ p.startProtocols()
+ go func() { readErr <- p.readLoop() }()
+
+ ping := time.NewTicker(pingInterval)
+ defer ping.Stop()
+
+ // Wait for an error or disconnect.
+ var reason DiscReason
loop:
for {
select {
- case msg := <-readMsg:
- // a new message has arrived.
- var wait bool
- if wait, err = p.dispatch(msg, protoDone); err != nil {
- p.Errorf("msg dispatch error: %v\n", err)
- reason = discReasonForError(err)
- break loop
- }
- if !wait {
- // Msg has already been read completely, continue with next message.
- readNext <- true
- }
- p.activity.Post(time.Now())
- case <-protoDone:
- // protocol has consumed the message payload,
- // we can continue reading from the socket.
- readNext <- true
-
+ case <-ping.C:
+ go func() {
+ if err := EncodeMsg(p.rw, pingMsg, nil); err != nil {
+ p.protoErr <- err
+ return
+ }
+ }()
case err := <-readErr:
- // read failed. there is no need to run the
- // polite disconnect sequence because the connection
- // is probably dead anyway.
- // TODO: handle write errors as well
- return DiscNetworkError, err
- case err = <-p.protoErr:
+ // We rely on protocols to abort if there is a write error. It
+ // might be more robust to handle them here as well.
+ p.DebugDetailf("Read error: %v\n", err)
+ p.rw.Close()
+ return DiscNetworkError
+ case err := <-p.protoErr:
reason = discReasonForError(err)
break loop
case reason = <-p.disc:
break loop
}
}
+ p.politeDisconnect(reason)
- // wait for read loop to return.
- close(readNext)
+ // Wait for readLoop. It will end because conn is now closed.
<-readErr
- // tell the remote end to disconnect
+ p.Debugf("Disconnected: %v\n", reason)
+ return reason
+}
+
+func (p *Peer) politeDisconnect(reason DiscReason) {
done := make(chan struct{})
go func() {
- p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod))
- p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod)
- io.Copy(ioutil.Discard, p.conn)
+ EncodeMsg(p.rw, discMsg, uint(reason))
+ // Wait for the other side to close the connection.
+ // Discard any data that they send until then.
+ io.Copy(ioutil.Discard, p.rw)
close(done)
}()
select {
case <-done:
case <-time.After(disconnectGracePeriod):
}
- return reason, err
+ p.rw.Close()
}
-func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) {
- for _ = range unblock {
- p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout))
- if msg, err := readMsg(p.bufconn); err != nil {
- errc <- err
- } else {
- msgc <- msg
+func (p *Peer) readLoop() error {
+ for {
+ msg, err := p.rw.ReadMsg()
+ if err != nil {
+ return err
+ }
+ if err = p.handle(msg); err != nil {
+ return err
}
}
- close(errc)
-}
-
-func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) {
- proto, err := p.getProto(msg.Code)
- if err != nil {
- return false, err
- }
- if msg.Size <= wholePayloadSize {
- // optimization: msg is small enough, read all
- // of it and move on to the next message
- buf, err := ioutil.ReadAll(msg.Payload)
+ return nil
+}
+
+func (p *Peer) handle(msg Msg) error {
+ switch {
+ case msg.Code == pingMsg:
+ msg.Discard()
+ go EncodeMsg(p.rw, pongMsg)
+ case msg.Code == discMsg:
+ var reason DiscReason
+ // no need to discard or for error checking, we'll close the
+ // connection after this.
+ rlp.Decode(msg.Payload, &reason)
+ p.Disconnect(DiscRequested)
+ return discRequestedError(reason)
+ case msg.Code < baseProtocolLength:
+ // ignore other base protocol messages
+ return msg.Discard()
+ default:
+ // it's a subprotocol message
+ proto, err := p.getProto(msg.Code)
if err != nil {
- return false, err
+ return fmt.Errorf("msg code out of range: %v", msg.Code)
}
- msg.Payload = bytes.NewReader(buf)
- proto.in <- msg
- } else {
- wait = true
- pr := &eofSignal{msg.Payload, int64(msg.Size), protoDone}
- msg.Payload = pr
proto.in <- msg
}
- return wait, nil
-}
-
-func (p *Peer) startBaseProtocol() {
- p.runlock.Lock()
- defer p.runlock.Unlock()
- p.running[""] = p.startProto(0, Protocol{
- Length: baseProtocolLength,
- Run: runBaseProtocol,
- })
+ return nil
}
-// startProtocols starts matching named subprotocols.
-func (p *Peer) startSubprotocols(caps []Cap) {
+// matchProtocols creates structures for matching named subprotocols.
+func matchProtocols(protocols []Protocol, caps []Cap, rw MsgReadWriter) map[string]*protoRW {
sort.Sort(capsByName(caps))
-
- p.runlock.Lock()
- defer p.runlock.Unlock()
offset := baseProtocolLength
+ result := make(map[string]*protoRW)
outer:
for _, cap := range caps {
- for _, proto := range p.protocols {
- if proto.Name == cap.Name &&
- proto.Version == cap.Version &&
- p.running[cap.Name] == nil {
- p.running[cap.Name] = p.startProto(offset, proto)
+ for _, proto := range protocols {
+ if proto.Name == cap.Name && proto.Version == cap.Version && result[cap.Name] == nil {
+ result[cap.Name] = &protoRW{Protocol: proto, offset: offset, in: make(chan Msg), w: rw}
offset += proto.Length
continue outer
}
}
}
+ return result
}
-func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
- rw := &proto{
- in: make(chan Msg),
- offset: offset,
- maxcode: impl.Length,
- peer: p,
+func (p *Peer) startProtocols() {
+ for _, proto := range p.running {
+ proto := proto
+ p.DebugDetailf("Starting protocol %s/%d\n", proto.Name, proto.Version)
+ p.protoWG.Add(1)
+ go func() {
+ err := proto.Run(p, proto)
+ if err == nil {
+ p.DebugDetailf("Protocol %s/%d returned\n", proto.Name, proto.Version)
+ err = errors.New("protocol returned")
+ } else {
+ p.DebugDetailf("Protocol %s/%d error: %v\n", proto.Name, proto.Version, err)
+ }
+ select {
+ case p.protoErr <- err:
+ case <-p.closed:
+ }
+ p.protoWG.Done()
+ }()
}
- p.protoWG.Add(1)
- go func() {
- err := impl.Run(p, rw)
- if err == nil {
- p.Infof("protocol %q returned", impl.Name)
- err = newPeerError(errMisc, "protocol returned")
- } else {
- p.Errorf("protocol %q error: %v\n", impl.Name, err)
- }
- select {
- case p.protoErr <- err:
- case <-p.closed:
- }
- p.protoWG.Done()
- }()
- return rw
}
// getProto finds the protocol responsible for handling
// the given message code.
-func (p *Peer) getProto(code uint64) (*proto, error) {
- p.runlock.RLock()
- defer p.runlock.RUnlock()
+func (p *Peer) getProto(code uint64) (*protoRW, error) {
for _, proto := range p.running {
- if code >= proto.offset && code < proto.offset+proto.maxcode {
+ if code >= proto.offset && code < proto.offset+proto.Length {
return proto, nil
}
}
@@ -376,60 +262,43 @@ func (p *Peer) getProto(code uint64) (*proto, error) {
}
func (p *Peer) closeProtocols() {
- p.runlock.RLock()
for _, p := range p.running {
close(p.in)
}
- p.runlock.RUnlock()
p.protoWG.Wait()
}
// writeProtoMsg sends the given message on behalf of the given named protocol.
+// this exists because of Server.Broadcast.
func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
- p.runlock.RLock()
proto, ok := p.running[protoName]
- p.runlock.RUnlock()
if !ok {
return fmt.Errorf("protocol %s not handled by peer", protoName)
}
- if msg.Code >= proto.maxcode {
+ if msg.Code >= proto.Length {
return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
}
msg.Code += proto.offset
- return p.writeMsg(msg, msgWriteTimeout)
+ return p.rw.WriteMsg(msg)
}
-// writeMsg writes a message to the connection.
-func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error {
- p.writeMu.Lock()
- defer p.writeMu.Unlock()
- p.conn.SetWriteDeadline(time.Now().Add(timeout))
- if err := writeMsg(p.bufconn, msg); err != nil {
- return newPeerError(errWrite, "%v", err)
- }
- return p.bufconn.Flush()
-}
+type protoRW struct {
+ Protocol
-type proto struct {
- name string
- in chan Msg
- maxcode, offset uint64
- peer *Peer
+ in chan Msg
+ offset uint64
+ w MsgWriter
}
-func (rw *proto) WriteMsg(msg Msg) error {
- if msg.Code >= rw.maxcode {
+func (rw *protoRW) WriteMsg(msg Msg) error {
+ if msg.Code >= rw.Length {
return newPeerError(errInvalidMsgCode, "not handled")
}
msg.Code += rw.offset
- return rw.peer.writeMsg(msg, msgWriteTimeout)
-}
-
-func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error {
- return rw.WriteMsg(NewMsg(code, data...))
+ return rw.w.WriteMsg(msg)
}
-func (rw *proto) ReadMsg() (Msg, error) {
+func (rw *protoRW) ReadMsg() (Msg, error) {
msg, ok := <-rw.in
if !ok {
return msg, io.EOF
@@ -437,26 +306,3 @@ func (rw *proto) ReadMsg() (Msg, error) {
msg.Code -= rw.offset
return msg, nil
}
-
-// eofSignal wraps a reader with eof signaling. the eof channel is
-// closed when the wrapped reader returns an error or when count bytes
-// have been read.
-//
-type eofSignal struct {
- wrapped io.Reader
- count int64
- eof chan<- struct{}
-}
-
-// note: when using eofSignal to detect whether a message payload
-// has been read, Read might not be called for zero sized messages.
-
-func (r *eofSignal) Read(buf []byte) (int, error) {
- n, err := r.wrapped.Read(buf)
- r.count -= int64(n)
- if (err != nil || r.count <= 0) && r.eof != nil {
- r.eof <- struct{}{} // tell Peer that msg has been consumed
- r.eof = nil
- }
- return n, err
-}
diff --git a/p2p/peer_error.go b/p2p/peer_error.go
index 0eb7ec838..0ff4f4b43 100644
--- a/p2p/peer_error.go
+++ b/p2p/peer_error.go
@@ -12,7 +12,6 @@ const (
errInvalidMsgCode
errInvalidMsg
errP2PVersionMismatch
- errPubkeyMissing
errPubkeyInvalid
errPubkeyForbidden
errProtocolBreach
@@ -22,20 +21,19 @@ const (
)
var errorToString = map[int]string{
- errMagicTokenMismatch: "Magic token mismatch",
- errRead: "Read error",
- errWrite: "Write error",
- errMisc: "Misc error",
- errInvalidMsgCode: "Invalid message code",
- errInvalidMsg: "Invalid message",
+ errMagicTokenMismatch: "magic token mismatch",
+ errRead: "read error",
+ errWrite: "write error",
+ errMisc: "misc error",
+ errInvalidMsgCode: "invalid message code",
+ errInvalidMsg: "invalid message",
errP2PVersionMismatch: "P2P Version Mismatch",
- errPubkeyMissing: "Public key missing",
- errPubkeyInvalid: "Public key invalid",
- errPubkeyForbidden: "Public key forbidden",
- errProtocolBreach: "Protocol Breach",
- errPingTimeout: "Ping timeout",
- errInvalidNetworkId: "Invalid network id",
- errInvalidProtocolVersion: "Invalid protocol version",
+ errPubkeyInvalid: "public key invalid",
+ errPubkeyForbidden: "public key forbidden",
+ errProtocolBreach: "protocol Breach",
+ errPingTimeout: "ping timeout",
+ errInvalidNetworkId: "invalid network id",
+ errInvalidProtocolVersion: "invalid protocol version",
}
type peerError struct {
@@ -62,22 +60,22 @@ func (self *peerError) Error() string {
type DiscReason byte
const (
- DiscRequested DiscReason = 0x00
- DiscNetworkError = 0x01
- DiscProtocolError = 0x02
- DiscUselessPeer = 0x03
- DiscTooManyPeers = 0x04
- DiscAlreadyConnected = 0x05
- DiscIncompatibleVersion = 0x06
- DiscInvalidIdentity = 0x07
- DiscQuitting = 0x08
- DiscUnexpectedIdentity = 0x09
- DiscSelf = 0x0a
- DiscReadTimeout = 0x0b
- DiscSubprotocolError = 0x10
+ DiscRequested DiscReason = iota
+ DiscNetworkError
+ DiscProtocolError
+ DiscUselessPeer
+ DiscTooManyPeers
+ DiscAlreadyConnected
+ DiscIncompatibleVersion
+ DiscInvalidIdentity
+ DiscQuitting
+ DiscUnexpectedIdentity
+ DiscSelf
+ DiscReadTimeout
+ DiscSubprotocolError
)
-var discReasonToString = [DiscSubprotocolError + 1]string{
+var discReasonToString = [...]string{
DiscRequested: "Disconnect requested",
DiscNetworkError: "Network error",
DiscProtocolError: "Breach of protocol",
@@ -117,7 +115,7 @@ func discReasonForError(err error) DiscReason {
switch peerError.Code {
case errP2PVersionMismatch:
return DiscIncompatibleVersion
- case errPubkeyMissing, errPubkeyInvalid:
+ case errPubkeyInvalid:
return DiscInvalidIdentity
case errPubkeyForbidden:
return DiscUselessPeer
@@ -125,7 +123,7 @@ func discReasonForError(err error) DiscReason {
return DiscProtocolError
case errPingTimeout:
return DiscReadTimeout
- case errRead, errWrite, errMisc:
+ case errRead, errWrite:
return DiscNetworkError
default:
return DiscSubprotocolError
diff --git a/p2p/peer_test.go b/p2p/peer_test.go
index 4ee88f112..a1260adbd 100644
--- a/p2p/peer_test.go
+++ b/p2p/peer_test.go
@@ -1,15 +1,15 @@
package p2p
import (
- "bufio"
"bytes"
- "encoding/hex"
- "io"
+ "fmt"
"io/ioutil"
"net"
"reflect"
"testing"
"time"
+
+ "github.com/ethereum/go-ethereum/rlp"
)
var discard = Protocol{
@@ -21,6 +21,7 @@ var discard = Protocol{
if err != nil {
return err
}
+ fmt.Printf("discarding %d\n", msg.Code)
if err = msg.Discard(); err != nil {
return err
}
@@ -28,17 +29,20 @@ var discard = Protocol{
},
}
-func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) {
- conn1, conn2 := net.Pipe()
- peer := newPeer(conn1, protos, nil)
- peer.ourID = &peerId{}
- peer.pubkeyHook = func(*peerAddr) error { return nil }
- errc := make(chan error, 1)
- go func() {
- _, err := peer.loop()
- errc <- err
- }()
- return conn2, peer, errc
+func testPeer(protos []Protocol) (*conn, *Peer, <-chan DiscReason) {
+ fd1, fd2 := net.Pipe()
+ hs1 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
+ hs2 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
+ for _, p := range protos {
+ hs1.Caps = append(hs1.Caps, p.cap())
+ hs2.Caps = append(hs2.Caps, p.cap())
+ }
+
+ peer := newPeer(newConn(fd1, hs1), protos)
+ errc := make(chan DiscReason, 1)
+ go func() { errc <- peer.run() }()
+
+ return newConn(fd2, hs2), peer, errc
}
func TestPeerProtoReadMsg(t *testing.T) {
@@ -49,31 +53,27 @@ func TestPeerProtoReadMsg(t *testing.T) {
Name: "a",
Length: 5,
Run: func(peer *Peer, rw MsgReadWriter) error {
- msg, err := rw.ReadMsg()
- if err != nil {
- t.Errorf("read error: %v", err)
- }
- if msg.Code != 2 {
- t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
+ if err := expectMsg(rw, 2, []uint{1}); err != nil {
+ t.Error(err)
}
- data, err := ioutil.ReadAll(msg.Payload)
- if err != nil {
- t.Errorf("payload read error: %v", err)
+ if err := expectMsg(rw, 3, []uint{2}); err != nil {
+ t.Error(err)
}
- expdata, _ := hex.DecodeString("0183303030")
- if !bytes.Equal(expdata, data) {
- t.Errorf("incorrect msg data %x", data)
+ if err := expectMsg(rw, 4, []uint{3}); err != nil {
+ t.Error(err)
}
close(done)
return nil
},
}
- net, peer, errc := testPeer([]Protocol{proto})
- defer net.Close()
- peer.startSubprotocols([]Cap{proto.cap()})
+ rw, _, errc := testPeer([]Protocol{proto})
+ defer rw.Close()
+
+ EncodeMsg(rw, baseProtocolLength+2, 1)
+ EncodeMsg(rw, baseProtocolLength+3, 2)
+ EncodeMsg(rw, baseProtocolLength+4, 3)
- writeMsg(net, NewMsg(18, 1, "000"))
select {
case <-done:
case err := <-errc:
@@ -105,11 +105,10 @@ func TestPeerProtoReadLargeMsg(t *testing.T) {
},
}
- net, peer, errc := testPeer([]Protocol{proto})
- defer net.Close()
- peer.startSubprotocols([]Cap{proto.cap()})
+ rw, _, errc := testPeer([]Protocol{proto})
+ defer rw.Close()
- writeMsg(net, NewMsg(18, make([]byte, msgsize)))
+ EncodeMsg(rw, 18, make([]byte, msgsize))
select {
case <-done:
case err := <-errc:
@@ -135,33 +134,19 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
return nil
},
}
- net, peer, _ := testPeer([]Protocol{proto})
- defer net.Close()
- peer.startSubprotocols([]Cap{proto.cap()})
+ rw, _, _ := testPeer([]Protocol{proto})
+ defer rw.Close()
- bufr := bufio.NewReader(net)
- msg, err := readMsg(bufr)
- if err != nil {
- t.Errorf("read error: %v", err)
- }
- if msg.Code != 17 {
- t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
- }
- var data []string
- if err := msg.Decode(&data); err != nil {
- t.Errorf("payload decode error: %v", err)
- }
- if !reflect.DeepEqual(data, []string{"foo", "bar"}) {
- t.Errorf("payload RLP mismatch, got %#v, want %#v", data, []string{"foo", "bar"})
+ if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
+ t.Error(err)
}
}
-func TestPeerWrite(t *testing.T) {
+func TestPeerWriteForBroadcast(t *testing.T) {
defer testlog(t).detach()
- net, peer, peerErr := testPeer([]Protocol{discard})
- defer net.Close()
- peer.startSubprotocols([]Cap{discard.cap()})
+ rw, peer, peerErr := testPeer([]Protocol{discard})
+ defer rw.Close()
// test write errors
if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
@@ -176,18 +161,13 @@ func TestPeerWrite(t *testing.T) {
// setup for reading the message on the other end
read := make(chan struct{})
go func() {
- bufr := bufio.NewReader(net)
- msg, err := readMsg(bufr)
- if err != nil {
- t.Errorf("read error: %v", err)
- } else if msg.Code != 16 {
- t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
+ if err := expectMsg(rw, 16, nil); err != nil {
+ t.Error(err)
}
- msg.Discard()
close(read)
}()
- // test succcessful write
+ // test successful write
if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
t.Errorf("expect no error for known protocol: %v", err)
}
@@ -198,104 +178,86 @@ func TestPeerWrite(t *testing.T) {
}
}
-func TestPeerActivity(t *testing.T) {
- // shorten inactivityTimeout while this test is running
- oldT := inactivityTimeout
- defer func() { inactivityTimeout = oldT }()
- inactivityTimeout = 20 * time.Millisecond
+func TestPeerPing(t *testing.T) {
+ defer testlog(t).detach()
- net, peer, peerErr := testPeer([]Protocol{discard})
- defer net.Close()
- peer.startSubprotocols([]Cap{discard.cap()})
+ rw, _, _ := testPeer(nil)
+ defer rw.Close()
+ if err := EncodeMsg(rw, pingMsg); err != nil {
+ t.Fatal(err)
+ }
+ if err := expectMsg(rw, pongMsg, nil); err != nil {
+ t.Error(err)
+ }
+}
- sub := peer.activity.Subscribe(time.Time{})
- defer sub.Unsubscribe()
+func TestPeerDisconnect(t *testing.T) {
+ defer testlog(t).detach()
- for i := 0; i < 6; i++ {
- writeMsg(net, NewMsg(16))
- select {
- case <-sub.Chan():
- case <-time.After(inactivityTimeout / 2):
- t.Fatal("no event within ", inactivityTimeout/2)
- case err := <-peerErr:
- t.Fatal("peer error", err)
- }
+ rw, _, disc := testPeer(nil)
+ defer rw.Close()
+ if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
+ t.Fatal(err)
}
-
- select {
- case <-time.After(inactivityTimeout * 2):
- case <-sub.Chan():
- t.Fatal("got activity event while connection was inactive")
- case err := <-peerErr:
- t.Fatal("peer error", err)
+ if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil {
+ t.Error(err)
+ }
+ rw.Close() // make test end faster
+ if reason := <-disc; reason != DiscRequested {
+ t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested)
}
}
func TestNewPeer(t *testing.T) {
+ name := "nodename"
caps := []Cap{{"foo", 2}, {"bar", 3}}
- id := &peerId{}
- p := NewPeer(id, caps)
- if !reflect.DeepEqual(p.Caps(), caps) {
- t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps)
- }
- if p.Identity() != id {
- t.Errorf("Identity mismatch: got %v, expected %v", p.Identity(), id)
+ id := randomID()
+ p := NewPeer(id, name, caps)
+ if p.ID() != id {
+ t.Errorf("ID mismatch: got %v, expected %v", p.ID(), id)
}
- // Should not hang.
- p.Disconnect(DiscAlreadyConnected)
-}
-
-func TestEOFSignal(t *testing.T) {
- rb := make([]byte, 10)
-
- // empty reader
- eof := make(chan struct{}, 1)
- sig := &eofSignal{new(bytes.Buffer), 0, eof}
- if n, err := sig.Read(rb); n != 0 || err != io.EOF {
- t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+ if p.Name() != name {
+ t.Errorf("Name mismatch: got %v, expected %v", p.Name(), name)
}
- select {
- case <-eof:
- default:
- t.Error("EOF chan not signaled")
+ if !reflect.DeepEqual(p.Caps(), caps) {
+ t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps)
}
- // count before error
- eof = make(chan struct{}, 1)
- sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
- if n, err := sig.Read(rb); n != 8 || err != nil {
- t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
- }
- select {
- case <-eof:
- default:
- t.Error("EOF chan not signaled")
- }
+ p.Disconnect(DiscAlreadyConnected) // Should not hang
+}
- // error before count
- eof = make(chan struct{}, 1)
- sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
- if n, err := sig.Read(rb); n != 4 || err != nil {
- t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
- }
- if n, err := sig.Read(rb); n != 0 || err != io.EOF {
- t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+// expectMsg reads a message from r and verifies that its
+// code and encoded RLP content match the provided values.
+// If content is nil, the payload is discarded and not verified.
+func expectMsg(r MsgReader, code uint64, content interface{}) error {
+ msg, err := r.ReadMsg()
+ if err != nil {
+ return err
}
- select {
- case <-eof:
- default:
- t.Error("EOF chan not signaled")
+ if msg.Code != code {
+ return fmt.Errorf("message code mismatch: got %d, expected %d", msg.Code, code)
}
+ if content == nil {
+ return msg.Discard()
+ } else {
+ contentEnc, err := rlp.EncodeToBytes(content)
+ if err != nil {
+ panic("content encode error: " + err.Error())
+ }
+ // skip over list header in encoded value. this is temporary.
+ contentEncR := bytes.NewReader(contentEnc)
+ if k, _, err := rlp.NewStream(contentEncR).Kind(); k != rlp.List || err != nil {
+ panic("content must encode as RLP list")
+ }
+ contentEnc = contentEnc[len(contentEnc)-contentEncR.Len():]
- // no signal if neither occurs
- eof = make(chan struct{}, 1)
- sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
- if n, err := sig.Read(rb); n != 10 || err != nil {
- t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
- }
- select {
- case <-eof:
- t.Error("unexpected EOF signal")
- default:
+ actualContent, err := ioutil.ReadAll(msg.Payload)
+ if err != nil {
+ return err
+ }
+ if !bytes.Equal(actualContent, contentEnc) {
+ return fmt.Errorf("message payload mismatch:\ngot: %x\nwant: %x", actualContent, contentEnc)
+ }
}
+ return nil
}
diff --git a/p2p/protocol.go b/p2p/protocol.go
index 1d121a885..5fa395eda 100644
--- a/p2p/protocol.go
+++ b/p2p/protocol.go
@@ -1,9 +1,6 @@
package p2p
-import (
- "bytes"
- "time"
-)
+import "fmt"
// Protocol represents a P2P subprotocol implementation.
type Protocol struct {
@@ -32,38 +29,6 @@ func (p Protocol) cap() Cap {
return Cap{p.Name, p.Version}
}
-const (
- baseProtocolVersion = 2
- baseProtocolLength = uint64(16)
- baseProtocolMaxMsgSize = 10 * 1024 * 1024
-)
-
-const (
- // devp2p message codes
- handshakeMsg = 0x00
- discMsg = 0x01
- pingMsg = 0x02
- pongMsg = 0x03
- getPeersMsg = 0x04
- peersMsg = 0x05
-)
-
-// handshake is the structure of a handshake list.
-type handshake struct {
- Version uint64
- ID string
- Caps []Cap
- ListenPort uint64
- NodeID []byte
-}
-
-func (h *handshake) String() string {
- return h.ID
-}
-func (h *handshake) Pubkey() []byte {
- return h.NodeID
-}
-
// Cap is the structure of a peer capability.
type Cap struct {
Name string
@@ -74,215 +39,12 @@ func (cap Cap) RlpData() interface{} {
return []interface{}{cap.Name, cap.Version}
}
+func (cap Cap) String() string {
+ return fmt.Sprintf("%s/%d", cap.Name, cap.Version)
+}
+
type capsByName []Cap
func (cs capsByName) Len() int { return len(cs) }
func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name }
func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] }
-
-type baseProtocol struct {
- rw MsgReadWriter
- peer *Peer
-}
-
-func runBaseProtocol(peer *Peer, rw MsgReadWriter) error {
- bp := &baseProtocol{rw, peer}
- errc := make(chan error, 1)
- go func() { errc <- rw.WriteMsg(bp.handshakeMsg()) }()
- if err := bp.readHandshake(); err != nil {
- return err
- }
- // handle write error
- if err := <-errc; err != nil {
- return err
- }
- // run main loop
- go func() {
- for {
- if err := bp.handle(rw); err != nil {
- errc <- err
- break
- }
- }
- }()
- return bp.loop(errc)
-}
-
-var pingTimeout = 2 * time.Second
-
-func (bp *baseProtocol) loop(quit <-chan error) error {
- ping := time.NewTimer(pingTimeout)
- activity := bp.peer.activity.Subscribe(time.Time{})
- lastActive := time.Time{}
- defer ping.Stop()
- defer activity.Unsubscribe()
-
- getPeersTick := time.NewTicker(10 * time.Second)
- defer getPeersTick.Stop()
- err := EncodeMsg(bp.rw, getPeersMsg)
-
- for err == nil {
- select {
- case err = <-quit:
- return err
- case <-getPeersTick.C:
- err = EncodeMsg(bp.rw, getPeersMsg)
- case event := <-activity.Chan():
- ping.Reset(pingTimeout)
- lastActive = event.(time.Time)
- case t := <-ping.C:
- if lastActive.Add(pingTimeout * 2).Before(t) {
- err = newPeerError(errPingTimeout, "")
- } else if lastActive.Add(pingTimeout).Before(t) {
- err = EncodeMsg(bp.rw, pingMsg)
- }
- }
- }
- return err
-}
-
-func (bp *baseProtocol) handle(rw MsgReadWriter) error {
- msg, err := rw.ReadMsg()
- if err != nil {
- return err
- }
- if msg.Size > baseProtocolMaxMsgSize {
- return newPeerError(errMisc, "message too big")
- }
- // make sure that the payload has been fully consumed
- defer msg.Discard()
-
- switch msg.Code {
- case handshakeMsg:
- return newPeerError(errProtocolBreach, "extra handshake received")
-
- case discMsg:
- var reason [1]DiscReason
- if err := msg.Decode(&reason); err != nil {
- return err
- }
- return discRequestedError(reason[0])
-
- case pingMsg:
- return EncodeMsg(bp.rw, pongMsg)
-
- case pongMsg:
-
- case getPeersMsg:
- peers := bp.peerList()
- // this is dangerous. the spec says that we should _delay_
- // sending the response if no new information is available.
- // this means that would need to send a response later when
- // new peers become available.
- //
- // TODO: add event mechanism to notify baseProtocol for new peers
- if len(peers) > 0 {
- return EncodeMsg(bp.rw, peersMsg, peers...)
- }
-
- case peersMsg:
- var peers []*peerAddr
- if err := msg.Decode(&peers); err != nil {
- return err
- }
- for _, addr := range peers {
- bp.peer.Debugf("received peer suggestion: %v", addr)
- bp.peer.newPeerAddr <- addr
- }
-
- default:
- return newPeerError(errInvalidMsgCode, "unknown message code %v", msg.Code)
- }
- return nil
-}
-
-func (bp *baseProtocol) readHandshake() error {
- // read and handle remote handshake
- msg, err := bp.rw.ReadMsg()
- if err != nil {
- return err
- }
- if msg.Code != handshakeMsg {
- return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code)
- }
- if msg.Size > baseProtocolMaxMsgSize {
- return newPeerError(errMisc, "message too big")
- }
- var hs handshake
- if err := msg.Decode(&hs); err != nil {
- return err
- }
- // validate handshake info
- if hs.Version != baseProtocolVersion {
- return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n",
- baseProtocolVersion, hs.Version)
- }
- if len(hs.NodeID) == 0 {
- return newPeerError(errPubkeyMissing, "")
- }
- if len(hs.NodeID) != 64 {
- return newPeerError(errPubkeyInvalid, "require 512 bit, got %v", len(hs.NodeID)*8)
- }
- if da := bp.peer.dialAddr; da != nil {
- // verify that the peer we wanted to connect to
- // actually holds the target public key.
- if da.Pubkey != nil && !bytes.Equal(da.Pubkey, hs.NodeID) {
- return newPeerError(errPubkeyForbidden, "dial address pubkey mismatch")
- }
- }
- pa := newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
- if err := bp.peer.pubkeyHook(pa); err != nil {
- return newPeerError(errPubkeyForbidden, "%v", err)
- }
- // TODO: remove Caps with empty name
- var addr *peerAddr
- if hs.ListenPort != 0 {
- addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
- addr.Port = hs.ListenPort
- }
- bp.peer.setHandshakeInfo(&hs, addr, hs.Caps)
- bp.peer.startSubprotocols(hs.Caps)
- return nil
-}
-
-func (bp *baseProtocol) handshakeMsg() Msg {
- var (
- port uint64
- caps []interface{}
- )
- if bp.peer.ourListenAddr != nil {
- port = bp.peer.ourListenAddr.Port
- }
- for _, proto := range bp.peer.protocols {
- caps = append(caps, proto.cap())
- }
- return NewMsg(handshakeMsg,
- baseProtocolVersion,
- bp.peer.ourID.String(),
- caps,
- port,
- bp.peer.ourID.Pubkey()[1:],
- )
-}
-
-func (bp *baseProtocol) peerList() []interface{} {
- peers := bp.peer.otherPeers()
- ds := make([]interface{}, 0, len(peers))
- for _, p := range peers {
- p.infolock.Lock()
- addr := p.listenAddr
- p.infolock.Unlock()
- // filter out this peer and peers that are not listening or
- // have not completed the handshake.
- // TODO: track previously sent peers and exclude them as well.
- if p == bp.peer || addr == nil {
- continue
- }
- ds = append(ds, addr)
- }
- ourAddr := bp.peer.ourListenAddr
- if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() {
- ds = append(ds, ourAddr)
- }
- return ds
-}
diff --git a/p2p/protocol_test.go b/p2p/protocol_test.go
deleted file mode 100644
index b1d10ac53..000000000
--- a/p2p/protocol_test.go
+++ /dev/null
@@ -1,158 +0,0 @@
-package p2p
-
-import (
- "fmt"
- "net"
- "reflect"
- "sync"
- "testing"
-
- "github.com/ethereum/go-ethereum/crypto"
-)
-
-type peerId struct {
- pubkey []byte
-}
-
-func (self *peerId) String() string {
- return fmt.Sprintf("test peer %x", self.Pubkey()[:4])
-}
-
-func (self *peerId) Pubkey() (pubkey []byte) {
- pubkey = self.pubkey
- if len(pubkey) == 0 {
- pubkey = crypto.GenerateNewKeyPair().PublicKey
- self.pubkey = pubkey
- }
- return
-}
-
-func newTestPeer() (peer *Peer) {
- peer = NewPeer(&peerId{}, []Cap{})
- peer.pubkeyHook = func(*peerAddr) error { return nil }
- peer.ourID = &peerId{}
- peer.listenAddr = &peerAddr{}
- peer.otherPeers = func() []*Peer { return nil }
- return
-}
-
-func TestBaseProtocolPeers(t *testing.T) {
- peerList := []*peerAddr{
- {IP: net.ParseIP("1.2.3.4"), Port: 2222, Pubkey: []byte{}},
- {IP: net.ParseIP("5.6.7.8"), Port: 3333, Pubkey: []byte{}},
- }
- listenAddr := &peerAddr{IP: net.ParseIP("1.3.5.7"), Port: 1111, Pubkey: []byte{}}
- rw1, rw2 := MsgPipe()
- defer rw1.Close()
- wg := new(sync.WaitGroup)
-
- // run matcher, close pipe when addresses have arrived
- numPeers := len(peerList) + 1
- addrChan := make(chan *peerAddr)
- wg.Add(1)
- go func() {
- i := 0
- for got := range addrChan {
- var want *peerAddr
- switch {
- case i < len(peerList):
- want = peerList[i]
- case i == len(peerList):
- want = listenAddr // listenAddr should be the last thing sent
- }
- t.Logf("got peer %d/%d: %v", i+1, numPeers, got)
- if !reflect.DeepEqual(want, got) {
- t.Errorf("mismatch: got %+v, want %+v", got, want)
- }
- i++
- if i == numPeers {
- break
- }
- }
- if i != numPeers {
- t.Errorf("wrong number of peers received: got %d, want %d", i, numPeers)
- }
- rw1.Close()
- wg.Done()
- }()
-
- // run first peer (in background)
- peer1 := newTestPeer()
- peer1.ourListenAddr = listenAddr
- peer1.otherPeers = func() []*Peer {
- pl := make([]*Peer, len(peerList))
- for i, addr := range peerList {
- pl[i] = &Peer{listenAddr: addr}
- }
- return pl
- }
- wg.Add(1)
- go func() {
- runBaseProtocol(peer1, rw1)
- wg.Done()
- }()
-
- // run second peer
- peer2 := newTestPeer()
- peer2.newPeerAddr = addrChan // feed peer suggestions into matcher
- if err := runBaseProtocol(peer2, rw2); err != ErrPipeClosed {
- t.Errorf("peer2 terminated with unexpected error: %v", err)
- }
-
- // terminate matcher
- close(addrChan)
- wg.Wait()
-}
-
-func TestBaseProtocolDisconnect(t *testing.T) {
- peer := NewPeer(&peerId{}, nil)
- peer.ourID = &peerId{}
- peer.pubkeyHook = func(*peerAddr) error { return nil }
-
- rw1, rw2 := MsgPipe()
- done := make(chan struct{})
- go func() {
- if err := expectMsg(rw2, handshakeMsg); err != nil {
- t.Error(err)
- }
- err := EncodeMsg(rw2, handshakeMsg,
- baseProtocolVersion,
- "",
- []interface{}{},
- 0,
- make([]byte, 64),
- )
- if err != nil {
- t.Error(err)
- }
- if err := expectMsg(rw2, getPeersMsg); err != nil {
- t.Error(err)
- }
- if err := EncodeMsg(rw2, discMsg, DiscQuitting); err != nil {
- t.Error(err)
- }
-
- close(done)
- }()
-
- if err := runBaseProtocol(peer, rw1); err == nil {
- t.Errorf("base protocol returned without error")
- } else if reason, ok := err.(discRequestedError); !ok || reason != DiscQuitting {
- t.Errorf("base protocol returned wrong error: %v", err)
- }
- <-done
-}
-
-func expectMsg(r MsgReader, code uint64) error {
- msg, err := r.ReadMsg()
- if err != nil {
- return err
- }
- if err := msg.Discard(); err != nil {
- return err
- }
- if msg.Code != code {
- return fmt.Errorf("wrong message code: got %d, expected %d", msg.Code, code)
- }
- return nil
-}
diff --git a/p2p/server.go b/p2p/server.go
index 4fd1f7d03..3ea2538d1 100644
--- a/p2p/server.go
+++ b/p2p/server.go
@@ -2,23 +2,34 @@ package p2p
import (
"bytes"
+ "crypto/ecdsa"
"errors"
"fmt"
"net"
+ "runtime"
"sync"
"time"
"github.com/ethereum/go-ethereum/logger"
+ "github.com/ethereum/go-ethereum/p2p/discover"
+ "github.com/ethereum/go-ethereum/p2p/nat"
)
const (
- outboundAddressPoolSize = 500
- defaultDialTimeout = 10 * time.Second
- portMappingUpdateInterval = 15 * time.Minute
- portMappingTimeout = 20 * time.Minute
+ handshakeTimeout = 5 * time.Second
+ defaultDialTimeout = 10 * time.Second
+ refreshPeersInterval = 30 * time.Second
)
var srvlog = logger.NewLogger("P2P Server")
+var srvjslog = logger.NewJsonLogger()
+
+// MakeName creates a node name that follows the ethereum convention
+// for such names. It adds the operation system name and Go runtime version
+// the name.
+func MakeName(name, version string) string {
+ return fmt.Sprintf("%s/v%s/%s/%s", name, version, runtime.GOOS, runtime.Version())
+}
// Server manages all peer connections.
//
@@ -26,13 +37,21 @@ var srvlog = logger.NewLogger("P2P Server")
// You should set them before starting the Server. Fields may not be
// modified while the server is running.
type Server struct {
- // This field must be set to a valid client identity.
- Identity ClientIdentity
+ // This field must be set to a valid secp256k1 private key.
+ PrivateKey *ecdsa.PrivateKey
// MaxPeers is the maximum number of peers that can be
// connected. It must be greater than zero.
MaxPeers int
+ // Name sets the node name of this server.
+ // Use MakeName to create a name that follows existing conventions.
+ Name string
+
+ // Bootstrap nodes are used to establish connectivity
+ // with the rest of the network.
+ BootstrapNodes []*discover.Node
+
// Protocols should contain the protocols supported
// by the server. Matching protocols are launched for
// each peer.
@@ -53,7 +72,7 @@ type Server struct {
// If set to a non-nil value, the given NAT port mapper
// is used to make the listening port available to the
// Internet.
- NAT NAT
+ NAT nat.Interface
// If Dialer is set to a non-nil value, the given Dialer
// is used to dial outbound peer connections.
@@ -62,35 +81,28 @@ type Server struct {
// If NoDial is true, the server will not dial any peers.
NoDial bool
- // Hook for testing. This is useful because we can inhibit
+ // Hooks for testing. These are useful because we can inhibit
// the whole protocol stack.
- newPeerFunc peerFunc
+ setupFunc
+ newPeerHook
- lock sync.RWMutex
- running bool
- listener net.Listener
- laddr *net.TCPAddr // real listen addr
- peers []*Peer
- peerSlots chan int
- peerCount int
-
- quit chan struct{}
- wg sync.WaitGroup
- peerConnect chan *peerAddr
- peerDisconnect chan *Peer
-}
+ ourHandshake *protoHandshake
-// NAT is implemented by NAT traversal methods.
-type NAT interface {
- GetExternalAddress() (net.IP, error)
- AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error
- DeletePortMapping(protocol string, extport, intport int) error
+ lock sync.RWMutex
+ running bool
+ listener net.Listener
+ peers map[discover.NodeID]*Peer
- // Should return name of the method.
- String() string
+ ntab *discover.Table
+
+ quit chan struct{}
+ loopWG sync.WaitGroup // {dial,listen,nat}Loop
+ peerWG sync.WaitGroup // active peer goroutines
+ peerConnect chan *discover.Node
}
-type peerFunc func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer
+type setupFunc func(net.Conn, *ecdsa.PrivateKey, *protoHandshake, *discover.Node) (*conn, error)
+type newPeerHook func(*Peer)
// Peers returns all connected peers.
func (srv *Server) Peers() (peers []*Peer) {
@@ -107,18 +119,15 @@ func (srv *Server) Peers() (peers []*Peer) {
// PeerCount returns the number of connected peers.
func (srv *Server) PeerCount() int {
srv.lock.RLock()
- defer srv.lock.RUnlock()
- return srv.peerCount
+ n := len(srv.peers)
+ srv.lock.RUnlock()
+ return n
}
-// SuggestPeer injects an address into the outbound address pool.
-func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) {
- addr := &peerAddr{ip, uint64(port), nodeID}
- select {
- case srv.peerConnect <- addr:
- default: // don't block
- srvlog.Warnf("peer suggestion %v ignored", addr)
- }
+// SuggestPeer creates a connection to the given Node if it
+// is not already connected.
+func (srv *Server) SuggestPeer(n *discover.Node) {
+ srv.peerConnect <- n
}
// Broadcast sends an RLP-encoded message to all connected peers.
@@ -152,47 +161,53 @@ func (srv *Server) Start() (err error) {
}
srvlog.Infoln("Starting Server")
- // initialize fields
- if srv.Identity == nil {
- return fmt.Errorf("Server.Identity must be set to a non-nil identity")
+ // static fields
+ if srv.PrivateKey == nil {
+ return fmt.Errorf("Server.PrivateKey must be set to a non-nil key")
}
if srv.MaxPeers <= 0 {
return fmt.Errorf("Server.MaxPeers must be > 0")
}
srv.quit = make(chan struct{})
- srv.peers = make([]*Peer, srv.MaxPeers)
- srv.peerSlots = make(chan int, srv.MaxPeers)
- srv.peerConnect = make(chan *peerAddr, outboundAddressPoolSize)
- srv.peerDisconnect = make(chan *Peer)
- if srv.newPeerFunc == nil {
- srv.newPeerFunc = newServerPeer
+ srv.peers = make(map[discover.NodeID]*Peer)
+ srv.peerConnect = make(chan *discover.Node)
+ if srv.setupFunc == nil {
+ srv.setupFunc = setupConn
}
if srv.Blacklist == nil {
srv.Blacklist = NewBlacklist()
}
- if srv.Dialer == nil {
- srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
+
+ // node table
+ ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT)
+ if err != nil {
+ return err
+ }
+ srv.ntab = ntab
+
+ // handshake
+ srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: ntab.Self()}
+ for _, p := range srv.Protocols {
+ srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap())
}
+ // listen/dial
if srv.ListenAddr != "" {
if err := srv.startListening(); err != nil {
return err
}
}
+ if srv.Dialer == nil {
+ srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
+ }
if !srv.NoDial {
- srv.wg.Add(1)
+ srv.loopWG.Add(1)
go srv.dialLoop()
}
if srv.NoDial && srv.ListenAddr == "" {
srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.")
}
- // make all slots available
- for i := range srv.peers {
- srv.peerSlots <- i
- }
- // note: discLoop is not part of WaitGroup
- go srv.discLoop()
srv.running = true
return nil
}
@@ -202,14 +217,17 @@ func (srv *Server) startListening() error {
if err != nil {
return err
}
- srv.ListenAddr = listener.Addr().String()
- srv.laddr = listener.Addr().(*net.TCPAddr)
+ laddr := listener.Addr().(*net.TCPAddr)
+ srv.ListenAddr = laddr.String()
srv.listener = listener
- srv.wg.Add(1)
+ srv.loopWG.Add(1)
go srv.listenLoop()
- if !srv.laddr.IP.IsLoopback() && srv.NAT != nil {
- srv.wg.Add(1)
- go srv.natLoop(srv.laddr.Port)
+ if !laddr.IP.IsLoopback() && srv.NAT != nil {
+ srv.loopWG.Add(1)
+ go func() {
+ nat.Map(srv.NAT, srv.quit, "tcp", laddr.Port, laddr.Port, "ethereum p2p")
+ srv.loopWG.Done()
+ }()
}
return nil
}
@@ -225,197 +243,182 @@ func (srv *Server) Stop() {
srv.running = false
srv.lock.Unlock()
- srvlog.Infoln("Stopping server")
+ srvlog.Infoln("Stopping Server")
+ srv.ntab.Close()
if srv.listener != nil {
// this unblocks listener Accept
srv.listener.Close()
}
close(srv.quit)
- for _, peer := range srv.Peers() {
- peer.Disconnect(DiscQuitting)
- }
- srv.wg.Wait()
-
- // wait till they actually disconnect
- // this is checked by claiming all peerSlots.
- // slots become available as the peers disconnect.
- for i := 0; i < cap(srv.peerSlots); i++ {
- <-srv.peerSlots
- }
- // terminate discLoop
- close(srv.peerDisconnect)
-}
+ srv.loopWG.Wait()
-func (srv *Server) discLoop() {
- for peer := range srv.peerDisconnect {
- srv.removePeer(peer)
+ // No new peers can be added at this point because dialLoop and
+ // listenLoop are down. It is safe to call peerWG.Wait because
+ // peerWG.Add is not called outside of those loops.
+ for _, peer := range srv.peers {
+ peer.Disconnect(DiscQuitting)
}
+ srv.peerWG.Wait()
}
// main loop for adding connections via listening
func (srv *Server) listenLoop() {
- defer srv.wg.Done()
-
+ defer srv.loopWG.Done()
srvlog.Infoln("Listening on", srv.listener.Addr())
for {
- select {
- case slot := <-srv.peerSlots:
- srvlog.Debugf("grabbed slot %v for listening", slot)
- conn, err := srv.listener.Accept()
- if err != nil {
- srv.peerSlots <- slot
- return
- }
- srvlog.Debugf("Accepted conn %v (slot %d)\n", conn.RemoteAddr(), slot)
- srv.addPeer(conn, nil, slot)
- case <-srv.quit:
+ conn, err := srv.listener.Accept()
+ if err != nil {
return
}
+ srvlog.Debugf("Accepted conn %v\n", conn.RemoteAddr())
+ srv.peerWG.Add(1)
+ go srv.startPeer(conn, nil)
}
}
-func (srv *Server) natLoop(port int) {
- defer srv.wg.Done()
+func (srv *Server) dialLoop() {
+ defer srv.loopWG.Done()
+ refresh := time.NewTicker(refreshPeersInterval)
+ defer refresh.Stop()
+
+ srv.ntab.Bootstrap(srv.BootstrapNodes)
+ go srv.findPeers()
+
+ dialed := make(chan *discover.Node)
+ dialing := make(map[discover.NodeID]bool)
+
+ // TODO: limit number of active dials
+ // TODO: ensure only one findPeers goroutine is running
+ // TODO: pause findPeers when we're at capacity
+
for {
- srv.updatePortMapping(port)
select {
- case <-time.After(portMappingUpdateInterval):
- // one more round
+ case <-refresh.C:
+
+ go srv.findPeers()
+
+ case dest := <-srv.peerConnect:
+ // avoid dialing nodes that are already connected.
+ // there is another check for this in addPeer,
+ // which runs after the handshake.
+ srv.lock.Lock()
+ _, isconnected := srv.peers[dest.ID]
+ srv.lock.Unlock()
+ if isconnected || dialing[dest.ID] || dest.ID == srv.ntab.Self() {
+ continue
+ }
+
+ dialing[dest.ID] = true
+ srv.peerWG.Add(1)
+ go func() {
+ srv.dialNode(dest)
+ // at this point, the peer has been added
+ // or discarded. either way, we're not dialing it anymore.
+ dialed <- dest
+ }()
+
+ case dest := <-dialed:
+ delete(dialing, dest.ID)
+
case <-srv.quit:
- srv.removePortMapping(port)
+ // TODO: maybe wait for active dials
return
}
}
}
-func (srv *Server) updatePortMapping(port int) {
- srvlog.Infoln("Attempting to map port", port, "with", srv.NAT)
- err := srv.NAT.AddPortMapping("tcp", port, port, "ethereum p2p", portMappingTimeout)
+func (srv *Server) dialNode(dest *discover.Node) {
+ addr := &net.TCPAddr{IP: dest.IP, Port: dest.TCPPort}
+ srvlog.Debugf("Dialing %v\n", dest)
+ conn, err := srv.Dialer.Dial("tcp", addr.String())
if err != nil {
- srvlog.Errorln("Port mapping error:", err)
- return
- }
- extip, err := srv.NAT.GetExternalAddress()
- if err != nil {
- srvlog.Errorln("Error getting external IP:", err)
+ srvlog.DebugDetailf("dial error: %v", err)
return
}
- srv.lock.Lock()
- extaddr := *(srv.listener.Addr().(*net.TCPAddr))
- extaddr.IP = extip
- srvlog.Infoln("Mapped port, external addr is", &extaddr)
- srv.laddr = &extaddr
- srv.lock.Unlock()
-}
-
-func (srv *Server) removePortMapping(port int) {
- srvlog.Infoln("Removing port mapping for", port, "with", srv.NAT)
- srv.NAT.DeletePortMapping("tcp", port, port)
+ srv.startPeer(conn, dest)
}
-func (srv *Server) dialLoop() {
- defer srv.wg.Done()
- var (
- suggest chan *peerAddr
- slot *int
- slots = srv.peerSlots
- )
- for {
- select {
- case i := <-slots:
- // we need a peer in slot i, slot reserved
- slot = &i
- // now we can watch for candidate peers in the next loop
- suggest = srv.peerConnect
- // do not consume more until candidate peer is found
- slots = nil
-
- case desc := <-suggest:
- // candidate peer found, will dial out asyncronously
- // if connection fails slot will be released
- srvlog.DebugDetailf("dial %v (%v)", desc, *slot)
- go srv.dialPeer(desc, *slot)
- // we can watch if more peers needed in the next loop
- slots = srv.peerSlots
- // until then we dont care about candidate peers
- suggest = nil
+func (srv *Server) findPeers() {
+ far := srv.ntab.Self()
+ for i := range far {
+ far[i] = ^far[i]
+ }
+ closeToSelf := srv.ntab.Lookup(srv.ntab.Self())
+ farFromSelf := srv.ntab.Lookup(far)
- case <-srv.quit:
- // give back the currently reserved slot
- if slot != nil {
- srv.peerSlots <- *slot
- }
- return
+ for i := 0; i < len(closeToSelf) || i < len(farFromSelf); i++ {
+ if i < len(closeToSelf) {
+ srv.peerConnect <- closeToSelf[i]
+ }
+ if i < len(farFromSelf) {
+ srv.peerConnect <- farFromSelf[i]
}
}
}
-// connect to peer via dial out
-func (srv *Server) dialPeer(desc *peerAddr, slot int) {
- srvlog.Debugf("Dialing %v (slot %d)\n", desc, slot)
- conn, err := srv.Dialer.Dial(desc.Network(), desc.String())
+func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
+ // TODO: handle/store session token
+ fd.SetDeadline(time.Now().Add(handshakeTimeout))
+ conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest)
if err != nil {
- srvlog.DebugDetailf("dial error: %v", err)
- srv.peerSlots <- slot
+ fd.Close()
+ srvlog.Debugf("Handshake with %v failed: %v", fd.RemoteAddr(), err)
+ return
+ }
+ p := newPeer(conn, srv.Protocols)
+ if ok, reason := srv.addPeer(conn.ID, p); !ok {
+ srvlog.DebugDetailf("Not adding %v (%v)\n", p, reason)
+ p.politeDisconnect(reason)
return
}
- go srv.addPeer(conn, desc, slot)
-}
-// creates the new peer object and inserts it into its slot
-func (srv *Server) addPeer(conn net.Conn, desc *peerAddr, slot int) *Peer {
- srv.lock.Lock()
- defer srv.lock.Unlock()
- if !srv.running {
- conn.Close()
- srv.peerSlots <- slot // release slot
- return nil
+ srvlog.Debugf("Added %v\n", p)
+ srvjslog.LogJson(&logger.P2PConnected{
+ RemoteId: fmt.Sprintf("%x", conn.ID[:]),
+ RemoteAddress: conn.RemoteAddr().String(),
+ RemoteVersionString: conn.Name,
+ NumConnections: srv.PeerCount(),
+ })
+
+ if srv.newPeerHook != nil {
+ srv.newPeerHook(p)
}
- peer := srv.newPeerFunc(srv, conn, desc)
- peer.slot = slot
- srv.peers[slot] = peer
- srv.peerCount++
- go func() { peer.loop(); srv.peerDisconnect <- peer }()
- return peer
+ discreason := p.run()
+ srv.removePeer(p)
+
+ srvlog.Debugf("Removed %v (%v)\n", p, discreason)
+ srvjslog.LogJson(&logger.P2PDisconnected{
+ RemoteId: fmt.Sprintf("%x", conn.ID[:]),
+ NumConnections: srv.PeerCount(),
+ })
}
-// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot
-func (srv *Server) removePeer(peer *Peer) {
+func (srv *Server) addPeer(id discover.NodeID, p *Peer) (bool, DiscReason) {
srv.lock.Lock()
defer srv.lock.Unlock()
- srvlog.Debugf("Removing %v (slot %v)\n", peer, peer.slot)
- if srv.peers[peer.slot] != peer {
- srvlog.Warnln("Invalid peer to remove:", peer)
- return
- }
- // remove from list and index
- srv.peerCount--
- srv.peers[peer.slot] = nil
- // release slot to signal need for a new peer, last!
- srv.peerSlots <- peer.slot
+ switch {
+ case !srv.running:
+ return false, DiscQuitting
+ case len(srv.peers) >= srv.MaxPeers:
+ return false, DiscTooManyPeers
+ case srv.peers[id] != nil:
+ return false, DiscAlreadyConnected
+ case srv.Blacklist.Exists(id[:]):
+ return false, DiscUselessPeer
+ case id == srv.ntab.Self():
+ return false, DiscSelf
+ }
+ srv.peers[id] = p
+ return true, 0
}
-func (srv *Server) verifyPeer(addr *peerAddr) error {
- if srv.Blacklist.Exists(addr.Pubkey) {
- return errors.New("blacklisted")
- }
- if bytes.Equal(srv.Identity.Pubkey()[1:], addr.Pubkey) {
- return newPeerError(errPubkeyForbidden, "not allowed to connect to srv")
- }
- srv.lock.RLock()
- defer srv.lock.RUnlock()
- for _, peer := range srv.peers {
- if peer != nil {
- id := peer.Identity()
- if id != nil && bytes.Equal(id.Pubkey(), addr.Pubkey) {
- return errors.New("already connected")
- }
- }
- }
- return nil
+func (srv *Server) removePeer(p *Peer) {
+ srv.lock.Lock()
+ delete(srv.peers, p.ID())
+ srv.lock.Unlock()
+ srv.peerWG.Done()
}
-// TODO replace with "Set"
type Blacklist interface {
Get([]byte) (bool, error)
Put([]byte) error
@@ -453,15 +456,15 @@ func (self *BlacklistMap) Exists(pubkey []byte) (ok bool) {
}
func (self *BlacklistMap) Put(pubkey []byte) error {
- self.lock.RLock()
- defer self.lock.RUnlock()
+ self.lock.Lock()
+ defer self.lock.Unlock()
self.blacklist[string(pubkey)] = true
return nil
}
func (self *BlacklistMap) Delete(pubkey []byte) error {
- self.lock.RLock()
- defer self.lock.RUnlock()
+ self.lock.Lock()
+ defer self.lock.Unlock()
delete(self.blacklist, string(pubkey))
return nil
}
diff --git a/p2p/server_test.go b/p2p/server_test.go
index ceb89e3f7..c109fffb9 100644
--- a/p2p/server_test.go
+++ b/p2p/server_test.go
@@ -2,19 +2,32 @@ package p2p
import (
"bytes"
+ "crypto/ecdsa"
"io"
+ "math/rand"
"net"
"sync"
"testing"
"time"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/p2p/discover"
)
-func startTestServer(t *testing.T, pf peerFunc) *Server {
+func startTestServer(t *testing.T, pf newPeerHook) *Server {
server := &Server{
- Identity: &peerId{},
+ Name: "test",
MaxPeers: 10,
ListenAddr: "127.0.0.1:0",
- newPeerFunc: pf,
+ PrivateKey: newkey(),
+ newPeerHook: pf,
+ setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
+ id := randomID()
+ return &conn{
+ frameRW: newFrameRW(fd, msgWriteTimeout),
+ protoHandshake: &protoHandshake{ID: id, Version: baseProtocolVersion},
+ }, nil
+ },
}
if err := server.Start(); err != nil {
t.Fatalf("Could not start server: %v", err)
@@ -27,16 +40,11 @@ func TestServerListen(t *testing.T) {
// start the test server
connected := make(chan *Peer)
- srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
- if conn == nil {
+ srv := startTestServer(t, func(p *Peer) {
+ if p == nil {
t.Error("peer func called with nil conn")
}
- if dialAddr != nil {
- t.Error("peer func called with non-nil dialAddr")
- }
- peer := newPeer(conn, nil, dialAddr)
- connected <- peer
- return peer
+ connected <- p
})
defer close(connected)
defer srv.Stop()
@@ -50,9 +58,9 @@ func TestServerListen(t *testing.T) {
select {
case peer := <-connected:
- if peer.conn.LocalAddr().String() != conn.RemoteAddr().String() {
+ if peer.LocalAddr().String() != conn.RemoteAddr().String() {
t.Errorf("peer started with wrong conn: got %v, want %v",
- peer.conn.LocalAddr(), conn.RemoteAddr())
+ peer.LocalAddr(), conn.RemoteAddr())
}
case <-time.After(1 * time.Second):
t.Error("server did not accept within one second")
@@ -62,7 +70,7 @@ func TestServerListen(t *testing.T) {
func TestServerDial(t *testing.T) {
defer testlog(t).detach()
- // run a fake TCP server to handle the connection.
+ // run a one-shot TCP server to handle the connection.
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("could not setup listener: %v")
@@ -72,41 +80,32 @@ func TestServerDial(t *testing.T) {
go func() {
conn, err := listener.Accept()
if err != nil {
- t.Error("acccept error:", err)
+ t.Error("accept error:", err)
+ return
}
conn.Close()
accepted <- conn
}()
- // start the test server
+ // start the server
connected := make(chan *Peer)
- srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
- if conn == nil {
- t.Error("peer func called with nil conn")
- }
- peer := newPeer(conn, nil, dialAddr)
- connected <- peer
- return peer
- })
+ srv := startTestServer(t, func(p *Peer) { connected <- p })
defer close(connected)
defer srv.Stop()
- // tell the server to connect.
- connAddr := newPeerAddr(listener.Addr(), nil)
- srv.peerConnect <- connAddr
+ // tell the server to connect
+ tcpAddr := listener.Addr().(*net.TCPAddr)
+ srv.SuggestPeer(&discover.Node{IP: tcpAddr.IP, TCPPort: tcpAddr.Port})
select {
case conn := <-accepted:
select {
case peer := <-connected:
- if peer.conn.RemoteAddr().String() != conn.LocalAddr().String() {
+ if peer.RemoteAddr().String() != conn.LocalAddr().String() {
t.Errorf("peer started with wrong conn: got %v, want %v",
- peer.conn.RemoteAddr(), conn.LocalAddr())
- }
- if peer.dialAddr != connAddr {
- t.Errorf("peer started with wrong dialAddr: got %v, want %v",
- peer.dialAddr, connAddr)
+ peer.RemoteAddr(), conn.LocalAddr())
}
+ // TODO: validate more fields
case <-time.After(1 * time.Second):
t.Error("server did not launch peer within one second")
}
@@ -118,16 +117,15 @@ func TestServerDial(t *testing.T) {
func TestServerBroadcast(t *testing.T) {
defer testlog(t).detach()
+
var connected sync.WaitGroup
- srv := startTestServer(t, func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer {
- peer := newPeer(c, []Protocol{discard}, dialAddr)
- peer.startSubprotocols([]Cap{discard.cap()})
+ srv := startTestServer(t, func(p *Peer) {
+ p.running = matchProtocols([]Protocol{discard}, []Cap{discard.cap()}, p.rw)
connected.Done()
- return peer
})
defer srv.Stop()
- // dial a bunch of conns
+ // create a few peers
var conns = make([]net.Conn, 8)
connected.Add(len(conns))
deadline := time.Now().Add(3 * time.Second)
@@ -159,3 +157,18 @@ func TestServerBroadcast(t *testing.T) {
}
}
}
+
+func newkey() *ecdsa.PrivateKey {
+ key, err := crypto.GenerateKey()
+ if err != nil {
+ panic("couldn't generate key: " + err.Error())
+ }
+ return key
+}
+
+func randomID() (id discover.NodeID) {
+ for i := range id {
+ id[i] = byte(rand.Intn(255))
+ }
+ return id
+}
diff --git a/p2p/testlog_test.go b/p2p/testlog_test.go
index 951d43243..c524c154c 100644
--- a/p2p/testlog_test.go
+++ b/p2p/testlog_test.go
@@ -15,7 +15,7 @@ func testlog(t *testing.T) testLogger {
return l
}
-func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel }
+func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugDetailLevel }
func (testLogger) SetLogLevel(logger.LogLevel) {}
func (l testLogger) LogPrint(level logger.LogLevel, msg string) {
diff --git a/p2p/testpoc7.go b/p2p/testpoc7.go
deleted file mode 100644
index 63b996f79..000000000
--- a/p2p/testpoc7.go
+++ /dev/null
@@ -1,40 +0,0 @@
-// +build none
-
-package main
-
-import (
- "fmt"
- "log"
- "net"
- "os"
-
- "github.com/ethereum/go-ethereum/crypto/secp256k1"
- "github.com/ethereum/go-ethereum/logger"
- "github.com/ethereum/go-ethereum/p2p"
-)
-
-func main() {
- logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel))
-
- pub, _ := secp256k1.GenerateKeyPair()
- srv := p2p.Server{
- MaxPeers: 10,
- Identity: p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)),
- ListenAddr: ":30303",
- NAT: p2p.PMP(net.ParseIP("10.0.0.1")),
- }
- if err := srv.Start(); err != nil {
- fmt.Println("could not start server:", err)
- os.Exit(1)
- }
-
- // add seed peers
- seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303")
- if err != nil {
- fmt.Println("couldn't resolve:", err)
- os.Exit(1)
- }
- srv.SuggestPeer(seed.IP, seed.Port, nil)
-
- select {}
-}