aboutsummaryrefslogtreecommitdiffstats
path: root/p2p
diff options
context:
space:
mode:
Diffstat (limited to 'p2p')
-rw-r--r--p2p/discover/node.go306
-rw-r--r--p2p/discover/node_test.go219
-rw-r--r--p2p/discover/table.go280
-rw-r--r--p2p/discover/table_test.go311
-rw-r--r--p2p/discover/udp.go432
-rw-r--r--p2p/discover/udp_test.go211
-rw-r--r--p2p/handshake.go436
-rw-r--r--p2p/handshake_test.go171
-rw-r--r--p2p/message.go210
-rw-r--r--p2p/message_test.go151
-rw-r--r--p2p/nat/nat.go235
-rw-r--r--p2p/nat/natpmp.go115
-rw-r--r--p2p/nat/natupnp.go149
-rw-r--r--p2p/peer.go312
-rw-r--r--p2p/peer_error.go131
-rw-r--r--p2p/peer_test.go226
-rw-r--r--p2p/protocol.go50
-rw-r--r--p2p/rlpx.go174
-rw-r--r--p2p/rlpx_test.go124
-rw-r--r--p2p/server.go421
-rw-r--r--p2p/server_test.go179
-rw-r--r--p2p/testlog_test.go28
22 files changed, 4871 insertions, 0 deletions
diff --git a/p2p/discover/node.go b/p2p/discover/node.go
new file mode 100644
index 000000000..de2588258
--- /dev/null
+++ b/p2p/discover/node.go
@@ -0,0 +1,306 @@
+package discover
+
+import (
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "io"
+ "math/big"
+ "math/rand"
+ "net"
+ "net/url"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/crypto/secp256k1"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+const nodeIDBits = 512
+
+// Node represents a host on the network.
+type Node struct {
+ ID NodeID
+ IP net.IP
+
+ DiscPort int // UDP listening port for discovery protocol
+ TCPPort int // TCP listening port for RLPx
+
+ active time.Time
+}
+
+func newNode(id NodeID, addr *net.UDPAddr) *Node {
+ return &Node{
+ ID: id,
+ IP: addr.IP,
+ DiscPort: addr.Port,
+ TCPPort: addr.Port,
+ active: time.Now(),
+ }
+}
+
+func (n *Node) isValid() bool {
+ // TODO: don't accept localhost, LAN addresses from internet hosts
+ return !n.IP.IsMulticast() && !n.IP.IsUnspecified() && n.TCPPort != 0 && n.DiscPort != 0
+}
+
+// The string representation of a Node is a URL.
+// Please see ParseNode for a description of the format.
+func (n *Node) String() string {
+ addr := net.TCPAddr{IP: n.IP, Port: n.TCPPort}
+ u := url.URL{
+ Scheme: "enode",
+ User: url.User(fmt.Sprintf("%x", n.ID[:])),
+ Host: addr.String(),
+ }
+ if n.DiscPort != n.TCPPort {
+ u.RawQuery = "discport=" + strconv.Itoa(n.DiscPort)
+ }
+ return u.String()
+}
+
+// ParseNode parses a node URL.
+//
+// A node URL has scheme "enode".
+//
+// The hexadecimal node ID is encoded in the username portion of the
+// URL, separated from the host by an @ sign. The hostname can only be
+// given as an IP address, DNS domain names are not allowed. The port
+// in the host name section is the TCP listening port. If the TCP and
+// UDP (discovery) ports differ, the UDP port is specified as query
+// parameter "discport".
+//
+// In the following example, the node URL describes
+// a node with IP address 10.3.58.6, TCP listening port 30303
+// and UDP discovery port 30301.
+//
+// enode://<hex node id>@10.3.58.6:30303?discport=30301
+func ParseNode(rawurl string) (*Node, error) {
+ var n Node
+ u, err := url.Parse(rawurl)
+ if u.Scheme != "enode" {
+ return nil, errors.New("invalid URL scheme, want \"enode\"")
+ }
+ if u.User == nil {
+ return nil, errors.New("does not contain node ID")
+ }
+ if n.ID, err = HexID(u.User.String()); err != nil {
+ return nil, fmt.Errorf("invalid node ID (%v)", err)
+ }
+ ip, port, err := net.SplitHostPort(u.Host)
+ if err != nil {
+ return nil, fmt.Errorf("invalid host: %v", err)
+ }
+ if n.IP = net.ParseIP(ip); n.IP == nil {
+ return nil, errors.New("invalid IP address")
+ }
+ if n.TCPPort, err = strconv.Atoi(port); err != nil {
+ return nil, errors.New("invalid port")
+ }
+ qv := u.Query()
+ if qv.Get("discport") == "" {
+ n.DiscPort = n.TCPPort
+ } else {
+ if n.DiscPort, err = strconv.Atoi(qv.Get("discport")); err != nil {
+ return nil, errors.New("invalid discport in query")
+ }
+ }
+ return &n, nil
+}
+
+// MustParseNode parses a node URL. It panics if the URL is not valid.
+func MustParseNode(rawurl string) *Node {
+ n, err := ParseNode(rawurl)
+ if err != nil {
+ panic("invalid node URL: " + err.Error())
+ }
+ return n
+}
+
+func (n Node) EncodeRLP(w io.Writer) error {
+ return rlp.Encode(w, rpcNode{IP: n.IP.String(), Port: uint16(n.TCPPort), ID: n.ID})
+}
+func (n *Node) DecodeRLP(s *rlp.Stream) (err error) {
+ var ext rpcNode
+ if err = s.Decode(&ext); err == nil {
+ n.TCPPort = int(ext.Port)
+ n.DiscPort = int(ext.Port)
+ n.ID = ext.ID
+ if n.IP = net.ParseIP(ext.IP); n.IP == nil {
+ return errors.New("invalid IP string")
+ }
+ }
+ return err
+}
+
+// NodeID is a unique identifier for each node.
+// The node identifier is a marshaled elliptic curve public key.
+type NodeID [nodeIDBits / 8]byte
+
+// NodeID prints as a long hexadecimal number.
+func (n NodeID) String() string {
+ return fmt.Sprintf("%#x", n[:])
+}
+
+// The Go syntax representation of a NodeID is a call to HexID.
+func (n NodeID) GoString() string {
+ return fmt.Sprintf("discover.HexID(\"%#x\")", n[:])
+}
+
+// HexID converts a hex string to a NodeID.
+// The string may be prefixed with 0x.
+func HexID(in string) (NodeID, error) {
+ if strings.HasPrefix(in, "0x") {
+ in = in[2:]
+ }
+ var id NodeID
+ b, err := hex.DecodeString(in)
+ if err != nil {
+ return id, err
+ } else if len(b) != len(id) {
+ return id, fmt.Errorf("wrong length, need %d hex bytes", len(id))
+ }
+ copy(id[:], b)
+ return id, nil
+}
+
+// MustHexID converts a hex string to a NodeID.
+// It panics if the string is not a valid NodeID.
+func MustHexID(in string) NodeID {
+ id, err := HexID(in)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// PubkeyID returns a marshaled representation of the given public key.
+func PubkeyID(pub *ecdsa.PublicKey) NodeID {
+ var id NodeID
+ pbytes := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
+ if len(pbytes)-1 != len(id) {
+ panic(fmt.Errorf("need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pbytes)))
+ }
+ copy(id[:], pbytes[1:])
+ return id
+}
+
+// Pubkey returns the public key represented by the node ID.
+// It returns an error if the ID is not a point on the curve.
+func (id NodeID) Pubkey() (*ecdsa.PublicKey, error) {
+ p := &ecdsa.PublicKey{Curve: crypto.S256(), X: new(big.Int), Y: new(big.Int)}
+ half := len(id) / 2
+ p.X.SetBytes(id[:half])
+ p.Y.SetBytes(id[half:])
+ if !p.Curve.IsOnCurve(p.X, p.Y) {
+ return nil, errors.New("not a point on the S256 curve")
+ }
+ return p, nil
+}
+
+// recoverNodeID computes the public key used to sign the
+// given hash from the signature.
+func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
+ pubkey, err := secp256k1.RecoverPubkey(hash, sig)
+ if err != nil {
+ return id, err
+ }
+ if len(pubkey)-1 != len(id) {
+ return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8)
+ }
+ for i := range id {
+ id[i] = pubkey[i+1]
+ }
+ return id, nil
+}
+
+// distcmp compares the distances a->target and b->target.
+// Returns -1 if a is closer to target, 1 if b is closer to target
+// and 0 if they are equal.
+func distcmp(target, a, b NodeID) int {
+ for i := range target {
+ da := a[i] ^ target[i]
+ db := b[i] ^ target[i]
+ if da > db {
+ return 1
+ } else if da < db {
+ return -1
+ }
+ }
+ return 0
+}
+
+// table of leading zero counts for bytes [0..255]
+var lzcount = [256]int{
+ 8, 7, 6, 6, 5, 5, 5, 5,
+ 4, 4, 4, 4, 4, 4, 4, 4,
+ 3, 3, 3, 3, 3, 3, 3, 3,
+ 3, 3, 3, 3, 3, 3, 3, 3,
+ 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 2,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+}
+
+// logdist returns the logarithmic distance between a and b, log2(a ^ b).
+func logdist(a, b NodeID) int {
+ lz := 0
+ for i := range a {
+ x := a[i] ^ b[i]
+ if x == 0 {
+ lz += 8
+ } else {
+ lz += lzcount[x]
+ break
+ }
+ }
+ return len(a)*8 - lz
+}
+
+// randomID returns a random NodeID such that logdist(a, b) == n
+func randomID(a NodeID, n int) (b NodeID) {
+ if n == 0 {
+ return a
+ }
+ // flip bit at position n, fill the rest with random bits
+ b = a
+ pos := len(a) - n/8 - 1
+ bit := byte(0x01) << (byte(n%8) - 1)
+ if bit == 0 {
+ pos++
+ bit = 0x80
+ }
+ b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits
+ for i := pos + 1; i < len(a); i++ {
+ b[i] = byte(rand.Intn(255))
+ }
+ return b
+}
diff --git a/p2p/discover/node_test.go b/p2p/discover/node_test.go
new file mode 100644
index 000000000..60b01b6ca
--- /dev/null
+++ b/p2p/discover/node_test.go
@@ -0,0 +1,219 @@
+package discover
+
+import (
+ "math/big"
+ "math/rand"
+ "net"
+ "reflect"
+ "testing"
+ "testing/quick"
+ "time"
+
+ "github.com/ethereum/go-ethereum/crypto"
+)
+
+var (
+ quickrand = rand.New(rand.NewSource(time.Now().Unix()))
+ quickcfg = &quick.Config{MaxCount: 5000, Rand: quickrand}
+)
+
+var parseNodeTests = []struct {
+ rawurl string
+ wantError string
+ wantResult *Node
+}{
+ {
+ rawurl: "http://foobar",
+ wantError: `invalid URL scheme, want "enode"`,
+ },
+ {
+ rawurl: "enode://foobar",
+ wantError: `does not contain node ID`,
+ },
+ {
+ rawurl: "enode://01010101@123.124.125.126:3",
+ wantError: `invalid node ID (wrong length, need 64 hex bytes)`,
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@hostname:3",
+ wantError: `invalid IP address`,
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:foo",
+ wantError: `invalid port`,
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:3?discport=foo",
+ wantError: `invalid discport in query`,
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150",
+ wantResult: &Node{
+ ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
+ IP: net.ParseIP("127.0.0.1"),
+ DiscPort: 52150,
+ TCPPort: 52150,
+ },
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@[::]:52150",
+ wantResult: &Node{
+ ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
+ IP: net.ParseIP("::"),
+ DiscPort: 52150,
+ TCPPort: 52150,
+ },
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150?discport=223344",
+ wantResult: &Node{
+ ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
+ IP: net.ParseIP("127.0.0.1"),
+ DiscPort: 223344,
+ TCPPort: 52150,
+ },
+ },
+}
+
+func TestParseNode(t *testing.T) {
+ for i, test := range parseNodeTests {
+ n, err := ParseNode(test.rawurl)
+ if err == nil && test.wantError != "" {
+ t.Errorf("test %d: got nil error, expected %#q", i, test.wantError)
+ continue
+ }
+ if err != nil && err.Error() != test.wantError {
+ t.Errorf("test %d: got error %#q, expected %#q", i, err.Error(), test.wantError)
+ continue
+ }
+ if !reflect.DeepEqual(n, test.wantResult) {
+ t.Errorf("test %d: result mismatch:\ngot: %#v, want: %#v", i, n, test.wantResult)
+ }
+ }
+}
+
+func TestNodeString(t *testing.T) {
+ for i, test := range parseNodeTests {
+ if test.wantError != "" {
+ continue
+ }
+ str := test.wantResult.String()
+ if str != test.rawurl {
+ t.Errorf("test %d: Node.String() mismatch:\ngot: %s\nwant: %s", i, str, test.rawurl)
+ }
+ }
+}
+
+func TestHexID(t *testing.T) {
+ ref := NodeID{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 106, 217, 182, 31, 165, 174, 1, 67, 7, 235, 220, 150, 66, 83, 173, 205, 159, 44, 10, 57, 42, 161, 26, 188}
+ id1 := MustHexID("0x000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
+ id2 := MustHexID("000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
+
+ if id1 != ref {
+ t.Errorf("wrong id1\ngot %v\nwant %v", id1[:], ref[:])
+ }
+ if id2 != ref {
+ t.Errorf("wrong id2\ngot %v\nwant %v", id2[:], ref[:])
+ }
+}
+
+func TestNodeID_recover(t *testing.T) {
+ prv := newkey()
+ hash := make([]byte, 32)
+ sig, err := crypto.Sign(hash, prv)
+ if err != nil {
+ t.Fatalf("signing error: %v", err)
+ }
+
+ pub := PubkeyID(&prv.PublicKey)
+ recpub, err := recoverNodeID(hash, sig)
+ if err != nil {
+ t.Fatalf("recovery error: %v", err)
+ }
+ if pub != recpub {
+ t.Errorf("recovered wrong pubkey:\ngot: %v\nwant: %v", recpub, pub)
+ }
+
+ ecdsa, err := pub.Pubkey()
+ if err != nil {
+ t.Errorf("Pubkey error: %v", err)
+ }
+ if !reflect.DeepEqual(ecdsa, &prv.PublicKey) {
+ t.Errorf("Pubkey mismatch:\n got: %#v\n want: %#v", ecdsa, &prv.PublicKey)
+ }
+}
+
+func TestNodeID_pubkeyBad(t *testing.T) {
+ ecdsa, err := NodeID{}.Pubkey()
+ if err == nil {
+ t.Error("expected error for zero ID")
+ }
+ if ecdsa != nil {
+ t.Error("expected nil result")
+ }
+}
+
+func TestNodeID_distcmp(t *testing.T) {
+ distcmpBig := func(target, a, b NodeID) int {
+ tbig := new(big.Int).SetBytes(target[:])
+ abig := new(big.Int).SetBytes(a[:])
+ bbig := new(big.Int).SetBytes(b[:])
+ return new(big.Int).Xor(tbig, abig).Cmp(new(big.Int).Xor(tbig, bbig))
+ }
+ if err := quick.CheckEqual(distcmp, distcmpBig, quickcfg); err != nil {
+ t.Error(err)
+ }
+}
+
+// the random tests is likely to miss the case where they're equal.
+func TestNodeID_distcmpEqual(t *testing.T) {
+ base := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
+ x := NodeID{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}
+ if distcmp(base, x, x) != 0 {
+ t.Errorf("distcmp(base, x, x) != 0")
+ }
+}
+
+func TestNodeID_logdist(t *testing.T) {
+ logdistBig := func(a, b NodeID) int {
+ abig, bbig := new(big.Int).SetBytes(a[:]), new(big.Int).SetBytes(b[:])
+ return new(big.Int).Xor(abig, bbig).BitLen()
+ }
+ if err := quick.CheckEqual(logdist, logdistBig, quickcfg); err != nil {
+ t.Error(err)
+ }
+}
+
+// the random tests is likely to miss the case where they're equal.
+func TestNodeID_logdistEqual(t *testing.T) {
+ x := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
+ if logdist(x, x) != 0 {
+ t.Errorf("logdist(x, x) != 0")
+ }
+}
+
+func TestNodeID_randomID(t *testing.T) {
+ // we don't use quick.Check here because its output isn't
+ // very helpful when the test fails.
+ for i := 0; i < quickcfg.MaxCount; i++ {
+ a := gen(NodeID{}, quickrand).(NodeID)
+ dist := quickrand.Intn(len(NodeID{}) * 8)
+ result := randomID(a, dist)
+ actualdist := logdist(result, a)
+
+ if dist != actualdist {
+ t.Log("a: ", a)
+ t.Log("result:", result)
+ t.Fatalf("#%d: distance of result is %d, want %d", i, actualdist, dist)
+ }
+ }
+}
+
+func (NodeID) Generate(rand *rand.Rand, size int) reflect.Value {
+ var id NodeID
+ m := rand.Intn(len(id))
+ for i := len(id) - 1; i > m; i-- {
+ id[i] = byte(rand.Uint32())
+ }
+ return reflect.ValueOf(id)
+}
diff --git a/p2p/discover/table.go b/p2p/discover/table.go
new file mode 100644
index 000000000..33b705a12
--- /dev/null
+++ b/p2p/discover/table.go
@@ -0,0 +1,280 @@
+// Package discover implements the Node Discovery Protocol.
+//
+// The Node Discovery protocol provides a way to find RLPx nodes that
+// can be connected to. It uses a Kademlia-like protocol to maintain a
+// distributed database of the IDs and endpoints of all listening
+// nodes.
+package discover
+
+import (
+ "net"
+ "sort"
+ "sync"
+ "time"
+)
+
+const (
+ alpha = 3 // Kademlia concurrency factor
+ bucketSize = 16 // Kademlia bucket size
+ nBuckets = nodeIDBits + 1 // Number of buckets
+)
+
+type Table struct {
+ mutex sync.Mutex // protects buckets, their content, and nursery
+ buckets [nBuckets]*bucket // index of known nodes by distance
+ nursery []*Node // bootstrap nodes
+
+ net transport
+ self *Node // metadata of the local node
+}
+
+// transport is implemented by the UDP transport.
+// it is an interface so we can test without opening lots of UDP
+// sockets and without generating a private key.
+type transport interface {
+ ping(*Node) error
+ findnode(e *Node, target NodeID) ([]*Node, error)
+ close()
+}
+
+// bucket contains nodes, ordered by their last activity.
+type bucket struct {
+ lastLookup time.Time
+ entries []*Node
+}
+
+func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr) *Table {
+ tab := &Table{net: t, self: newNode(ourID, ourAddr)}
+ for i := range tab.buckets {
+ tab.buckets[i] = new(bucket)
+ }
+ return tab
+}
+
+// Self returns the local node.
+func (tab *Table) Self() *Node {
+ return tab.self
+}
+
+// Close terminates the network listener.
+func (tab *Table) Close() {
+ tab.net.close()
+}
+
+// Bootstrap sets the bootstrap nodes. These nodes are used to connect
+// to the network if the table is empty. Bootstrap will also attempt to
+// fill the table by performing random lookup operations on the
+// network.
+func (tab *Table) Bootstrap(nodes []*Node) {
+ tab.mutex.Lock()
+ // TODO: maybe filter nodes with bad fields (nil, etc.) to avoid strange crashes
+ tab.nursery = make([]*Node, 0, len(nodes))
+ for _, n := range nodes {
+ cpy := *n
+ tab.nursery = append(tab.nursery, &cpy)
+ }
+ tab.mutex.Unlock()
+ tab.refresh()
+}
+
+// Lookup performs a network search for nodes close
+// to the given target. It approaches the target by querying
+// nodes that are closer to it on each iteration.
+func (tab *Table) Lookup(target NodeID) []*Node {
+ var (
+ asked = make(map[NodeID]bool)
+ seen = make(map[NodeID]bool)
+ reply = make(chan []*Node, alpha)
+ pendingQueries = 0
+ )
+ // don't query further if we hit the target or ourself.
+ // unlikely to happen often in practice.
+ asked[target] = true
+ asked[tab.self.ID] = true
+
+ tab.mutex.Lock()
+ // update last lookup stamp (for refresh logic)
+ tab.buckets[logdist(tab.self.ID, target)].lastLookup = time.Now()
+ // generate initial result set
+ result := tab.closest(target, bucketSize)
+ tab.mutex.Unlock()
+
+ for {
+ // ask the alpha closest nodes that we haven't asked yet
+ for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ {
+ n := result.entries[i]
+ if !asked[n.ID] {
+ asked[n.ID] = true
+ pendingQueries++
+ go func() {
+ result, _ := tab.net.findnode(n, target)
+ reply <- result
+ }()
+ }
+ }
+ if pendingQueries == 0 {
+ // we have asked all closest nodes, stop the search
+ break
+ }
+
+ // wait for the next reply
+ for _, n := range <-reply {
+ cn := n
+ if !seen[n.ID] {
+ seen[n.ID] = true
+ result.push(cn, bucketSize)
+ }
+ }
+ pendingQueries--
+ }
+ return result.entries
+}
+
+// refresh performs a lookup for a random target to keep buckets full.
+func (tab *Table) refresh() {
+ ld := -1 // logdist of chosen bucket
+ tab.mutex.Lock()
+ for i, b := range tab.buckets {
+ if i > 0 && b.lastLookup.Before(time.Now().Add(-1*time.Hour)) {
+ ld = i
+ break
+ }
+ }
+ tab.mutex.Unlock()
+
+ result := tab.Lookup(randomID(tab.self.ID, ld))
+ if len(result) == 0 {
+ // bootstrap the table with a self lookup
+ tab.mutex.Lock()
+ tab.add(tab.nursery)
+ tab.mutex.Unlock()
+ tab.Lookup(tab.self.ID)
+ // TODO: the Kademlia paper says that we're supposed to perform
+ // random lookups in all buckets further away than our closest neighbor.
+ }
+}
+
+// closest returns the n nodes in the table that are closest to the
+// given id. The caller must hold tab.mutex.
+func (tab *Table) closest(target NodeID, nresults int) *nodesByDistance {
+ // This is a very wasteful way to find the closest nodes but
+ // obviously correct. I believe that tree-based buckets would make
+ // this easier to implement efficiently.
+ close := &nodesByDistance{target: target}
+ for _, b := range tab.buckets {
+ for _, n := range b.entries {
+ close.push(n, nresults)
+ }
+ }
+ return close
+}
+
+func (tab *Table) len() (n int) {
+ for _, b := range tab.buckets {
+ n += len(b.entries)
+ }
+ return n
+}
+
+// bumpOrAdd updates the activity timestamp for the given node and
+// attempts to insert the node into a bucket. The returned Node might
+// not be part of the table. The caller must hold tab.mutex.
+func (tab *Table) bumpOrAdd(node NodeID, from *net.UDPAddr) (n *Node) {
+ b := tab.buckets[logdist(tab.self.ID, node)]
+ if n = b.bump(node); n == nil {
+ n = newNode(node, from)
+ if len(b.entries) == bucketSize {
+ tab.pingReplace(n, b)
+ } else {
+ b.entries = append(b.entries, n)
+ }
+ }
+ return n
+}
+
+func (tab *Table) pingReplace(n *Node, b *bucket) {
+ old := b.entries[bucketSize-1]
+ go func() {
+ if err := tab.net.ping(old); err == nil {
+ // it responded, we don't need to replace it.
+ return
+ }
+ // it didn't respond, replace the node if it is still the oldest node.
+ tab.mutex.Lock()
+ if len(b.entries) > 0 && b.entries[len(b.entries)-1] == old {
+ // slide down other entries and put the new one in front.
+ // TODO: insert in correct position to keep the order
+ copy(b.entries[1:], b.entries)
+ b.entries[0] = n
+ }
+ tab.mutex.Unlock()
+ }()
+}
+
+// bump updates the activity timestamp for the given node.
+// The caller must hold tab.mutex.
+func (tab *Table) bump(node NodeID) {
+ tab.buckets[logdist(tab.self.ID, node)].bump(node)
+}
+
+// add puts the entries into the table if their corresponding
+// bucket is not full. The caller must hold tab.mutex.
+func (tab *Table) add(entries []*Node) {
+outer:
+ for _, n := range entries {
+ if n == nil || n.ID == tab.self.ID {
+ // skip bad entries. The RLP decoder returns nil for empty
+ // input lists.
+ continue
+ }
+ bucket := tab.buckets[logdist(tab.self.ID, n.ID)]
+ for i := range bucket.entries {
+ if bucket.entries[i].ID == n.ID {
+ // already in bucket
+ continue outer
+ }
+ }
+ if len(bucket.entries) < bucketSize {
+ bucket.entries = append(bucket.entries, n)
+ }
+ }
+}
+
+func (b *bucket) bump(id NodeID) *Node {
+ for i, n := range b.entries {
+ if n.ID == id {
+ n.active = time.Now()
+ // move it to the front
+ copy(b.entries[1:], b.entries[:i+1])
+ b.entries[0] = n
+ return n
+ }
+ }
+ return nil
+}
+
+// nodesByDistance is a list of nodes, ordered by
+// distance to target.
+type nodesByDistance struct {
+ entries []*Node
+ target NodeID
+}
+
+// push adds the given node to the list, keeping the total size below maxElems.
+func (h *nodesByDistance) push(n *Node, maxElems int) {
+ ix := sort.Search(len(h.entries), func(i int) bool {
+ return distcmp(h.target, h.entries[i].ID, n.ID) > 0
+ })
+ if len(h.entries) < maxElems {
+ h.entries = append(h.entries, n)
+ }
+ if ix == len(h.entries) {
+ // farther away than all nodes we already have.
+ // if there was room for it, the node is now the last element.
+ } else {
+ // slide existing entries down to make room
+ // this will overwrite the entry we just appended.
+ copy(h.entries[ix+1:], h.entries[ix:])
+ h.entries[ix] = n
+ }
+}
diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go
new file mode 100644
index 000000000..08faea68e
--- /dev/null
+++ b/p2p/discover/table_test.go
@@ -0,0 +1,311 @@
+package discover
+
+import (
+ "crypto/ecdsa"
+ "errors"
+ "fmt"
+ "math/rand"
+ "net"
+ "reflect"
+ "testing"
+ "testing/quick"
+ "time"
+
+ "github.com/ethereum/go-ethereum/crypto"
+)
+
+func TestTable_bumpOrAddBucketAssign(t *testing.T) {
+ tab := newTable(nil, NodeID{}, &net.UDPAddr{})
+ for i := 1; i < len(tab.buckets); i++ {
+ tab.bumpOrAdd(randomID(tab.self.ID, i), &net.UDPAddr{})
+ }
+ for i, b := range tab.buckets {
+ if i > 0 && len(b.entries) != 1 {
+ t.Errorf("bucket %d has %d entries, want 1", i, len(b.entries))
+ }
+ }
+}
+
+func TestTable_bumpOrAddPingReplace(t *testing.T) {
+ pingC := make(pingC)
+ tab := newTable(pingC, NodeID{}, &net.UDPAddr{})
+ last := fillBucket(tab, 200)
+
+ // this bumpOrAdd should not replace the last node
+ // because the node replies to ping.
+ new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
+
+ pinged := <-pingC
+ if pinged != last.ID {
+ t.Fatalf("pinged wrong node: %v\nwant %v", pinged, last.ID)
+ }
+
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+ if l := len(tab.buckets[200].entries); l != bucketSize {
+ t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
+ }
+ if !contains(tab.buckets[200].entries, last.ID) {
+ t.Error("last entry was removed")
+ }
+ if contains(tab.buckets[200].entries, new.ID) {
+ t.Error("new entry was added")
+ }
+}
+
+func TestTable_bumpOrAddPingTimeout(t *testing.T) {
+ tab := newTable(pingC(nil), NodeID{}, &net.UDPAddr{})
+ last := fillBucket(tab, 200)
+
+ // this bumpOrAdd should replace the last node
+ // because the node does not reply to ping.
+ new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
+
+ // wait for async bucket update. damn. this needs to go away.
+ time.Sleep(2 * time.Millisecond)
+
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+ if l := len(tab.buckets[200].entries); l != bucketSize {
+ t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
+ }
+ if contains(tab.buckets[200].entries, last.ID) {
+ t.Error("last entry was not removed")
+ }
+ if !contains(tab.buckets[200].entries, new.ID) {
+ t.Error("new entry was not added")
+ }
+}
+
+func fillBucket(tab *Table, ld int) (last *Node) {
+ b := tab.buckets[ld]
+ for len(b.entries) < bucketSize {
+ b.entries = append(b.entries, &Node{ID: randomID(tab.self.ID, ld)})
+ }
+ return b.entries[bucketSize-1]
+}
+
+type pingC chan NodeID
+
+func (t pingC) findnode(n *Node, target NodeID) ([]*Node, error) {
+ panic("findnode called on pingRecorder")
+}
+func (t pingC) close() {
+ panic("close called on pingRecorder")
+}
+func (t pingC) ping(n *Node) error {
+ if t == nil {
+ return errTimeout
+ }
+ t <- n.ID
+ return nil
+}
+
+func TestTable_bump(t *testing.T) {
+ tab := newTable(nil, NodeID{}, &net.UDPAddr{})
+
+ // add an old entry and two recent ones
+ oldactive := time.Now().Add(-2 * time.Minute)
+ old := &Node{ID: randomID(tab.self.ID, 200), active: oldactive}
+ others := []*Node{
+ &Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
+ &Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
+ }
+ tab.add(append(others, old))
+ if tab.buckets[200].entries[0] == old {
+ t.Fatal("old entry is at front of bucket")
+ }
+
+ // bumping the old entry should move it to the front
+ tab.bump(old.ID)
+ if old.active == oldactive {
+ t.Error("activity timestamp not updated")
+ }
+ if tab.buckets[200].entries[0] != old {
+ t.Errorf("bumped entry did not move to the front of bucket")
+ }
+}
+
+func TestTable_closest(t *testing.T) {
+ t.Parallel()
+
+ test := func(test *closeTest) bool {
+ // for any node table, Target and N
+ tab := newTable(nil, test.Self, &net.UDPAddr{})
+ tab.add(test.All)
+
+ // check that doClosest(Target, N) returns nodes
+ result := tab.closest(test.Target, test.N).entries
+ if hasDuplicates(result) {
+ t.Errorf("result contains duplicates")
+ return false
+ }
+ if !sortedByDistanceTo(test.Target, result) {
+ t.Errorf("result is not sorted by distance to target")
+ return false
+ }
+
+ // check that the number of results is min(N, tablen)
+ wantN := test.N
+ if tlen := tab.len(); tlen < test.N {
+ wantN = tlen
+ }
+ if len(result) != wantN {
+ t.Errorf("wrong number of nodes: got %d, want %d", len(result), wantN)
+ return false
+ } else if len(result) == 0 {
+ return true // no need to check distance
+ }
+
+ // check that the result nodes have minimum distance to target.
+ for _, b := range tab.buckets {
+ for _, n := range b.entries {
+ if contains(result, n.ID) {
+ continue // don't run the check below for nodes in result
+ }
+ farthestResult := result[len(result)-1].ID
+ if distcmp(test.Target, n.ID, farthestResult) < 0 {
+ t.Errorf("table contains node that is closer to target but it's not in result")
+ t.Logf(" Target: %v", test.Target)
+ t.Logf(" Farthest Result: %v", farthestResult)
+ t.Logf(" ID: %v", n.ID)
+ return false
+ }
+ }
+ }
+ return true
+ }
+ if err := quick.Check(test, quickcfg); err != nil {
+ t.Error(err)
+ }
+}
+
+type closeTest struct {
+ Self NodeID
+ Target NodeID
+ All []*Node
+ N int
+}
+
+func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
+ t := &closeTest{
+ Self: gen(NodeID{}, rand).(NodeID),
+ Target: gen(NodeID{}, rand).(NodeID),
+ N: rand.Intn(bucketSize),
+ }
+ for _, id := range gen([]NodeID{}, rand).([]NodeID) {
+ t.All = append(t.All, &Node{ID: id})
+ }
+ return reflect.ValueOf(t)
+}
+
+func TestTable_Lookup(t *testing.T) {
+ self := gen(NodeID{}, quickrand).(NodeID)
+ target := randomID(self, 200)
+ transport := findnodeOracle{t, target}
+ tab := newTable(transport, self, &net.UDPAddr{})
+
+ // lookup on empty table returns no nodes
+ if results := tab.Lookup(target); len(results) > 0 {
+ t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
+ }
+ // seed table with initial node (otherwise lookup will terminate immediately)
+ tab.bumpOrAdd(randomID(target, 200), &net.UDPAddr{Port: 200})
+
+ results := tab.Lookup(target)
+ t.Logf("results:")
+ for _, e := range results {
+ t.Logf(" ld=%d, %v", logdist(target, e.ID), e.ID)
+ }
+ if len(results) != bucketSize {
+ t.Errorf("wrong number of results: got %d, want %d", len(results), bucketSize)
+ }
+ if hasDuplicates(results) {
+ t.Errorf("result set contains duplicate entries")
+ }
+ if !sortedByDistanceTo(target, results) {
+ t.Errorf("result set not sorted by distance to target")
+ }
+ if !contains(results, target) {
+ t.Errorf("result set does not contain target")
+ }
+}
+
+// findnode on this transport always returns at least one node
+// that is one bucket closer to the target.
+type findnodeOracle struct {
+ t *testing.T
+ target NodeID
+}
+
+func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) {
+ t.t.Logf("findnode query at dist %d", n.DiscPort)
+ // current log distance is encoded in port number
+ var result []*Node
+ switch n.DiscPort {
+ case 0:
+ panic("query to node at distance 0")
+ default:
+ // TODO: add more randomness to distances
+ next := n.DiscPort - 1
+ for i := 0; i < bucketSize; i++ {
+ result = append(result, &Node{ID: randomID(t.target, next), DiscPort: next})
+ }
+ }
+ return result, nil
+}
+
+func (t findnodeOracle) close() {}
+
+func (t findnodeOracle) ping(n *Node) error {
+ return errors.New("ping is not supported by this transport")
+}
+
+func hasDuplicates(slice []*Node) bool {
+ seen := make(map[NodeID]bool)
+ for _, e := range slice {
+ if seen[e.ID] {
+ return true
+ }
+ seen[e.ID] = true
+ }
+ return false
+}
+
+func sortedByDistanceTo(distbase NodeID, slice []*Node) bool {
+ var last NodeID
+ for i, e := range slice {
+ if i > 0 && distcmp(distbase, e.ID, last) < 0 {
+ return false
+ }
+ last = e.ID
+ }
+ return true
+}
+
+func contains(ns []*Node, id NodeID) bool {
+ for _, n := range ns {
+ if n.ID == id {
+ return true
+ }
+ }
+ return false
+}
+
+// gen wraps quick.Value so it's easier to use.
+// it generates a random value of the given value's type.
+func gen(typ interface{}, rand *rand.Rand) interface{} {
+ v, ok := quick.Value(reflect.TypeOf(typ), rand)
+ if !ok {
+ panic(fmt.Sprintf("couldn't generate random value of type %T", typ))
+ }
+ return v.Interface()
+}
+
+func newkey() *ecdsa.PrivateKey {
+ key, err := crypto.GenerateKey()
+ if err != nil {
+ panic("couldn't generate key: " + err.Error())
+ }
+ return key
+}
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
new file mode 100644
index 000000000..69e9f3c2e
--- /dev/null
+++ b/p2p/discover/udp.go
@@ -0,0 +1,432 @@
+package discover
+
+import (
+ "bytes"
+ "crypto/ecdsa"
+ "errors"
+ "fmt"
+ "net"
+ "time"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/logger"
+ "github.com/ethereum/go-ethereum/p2p/nat"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+var log = logger.NewLogger("P2P Discovery")
+
+// Errors
+var (
+ errPacketTooSmall = errors.New("too small")
+ errBadHash = errors.New("bad hash")
+ errExpired = errors.New("expired")
+ errTimeout = errors.New("RPC timeout")
+ errClosed = errors.New("socket closed")
+)
+
+// Timeouts
+const (
+ respTimeout = 300 * time.Millisecond
+ sendTimeout = 300 * time.Millisecond
+ expiration = 20 * time.Second
+
+ refreshInterval = 1 * time.Hour
+)
+
+// RPC packet types
+const (
+ pingPacket = iota + 1 // zero is 'reserved'
+ pongPacket
+ findnodePacket
+ neighborsPacket
+)
+
+// RPC request structures
+type (
+ ping struct {
+ IP string // our IP
+ Port uint16 // our port
+ Expiration uint64
+ }
+
+ // reply to Ping
+ pong struct {
+ ReplyTok []byte
+ Expiration uint64
+ }
+
+ findnode struct {
+ // Id to look up. The responding node will send back nodes
+ // closest to the target.
+ Target NodeID
+ Expiration uint64
+ }
+
+ // reply to findnode
+ neighbors struct {
+ Nodes []*Node
+ Expiration uint64
+ }
+)
+
+type rpcNode struct {
+ IP string
+ Port uint16
+ ID NodeID
+}
+
+// udp implements the RPC protocol.
+type udp struct {
+ conn *net.UDPConn
+ priv *ecdsa.PrivateKey
+ addpending chan *pending
+ replies chan reply
+ closing chan struct{}
+ nat nat.Interface
+
+ *Table
+}
+
+// pending represents a pending reply.
+//
+// some implementations of the protocol wish to send more than one
+// reply packet to findnode. in general, any neighbors packet cannot
+// be matched up with a specific findnode packet.
+//
+// our implementation handles this by storing a callback function for
+// each pending reply. incoming packets from a node are dispatched
+// to all the callback functions for that node.
+type pending struct {
+ // these fields must match in the reply.
+ from NodeID
+ ptype byte
+
+ // time when the request must complete
+ deadline time.Time
+
+ // callback is called when a matching reply arrives. if it returns
+ // true, the callback is removed from the pending reply queue.
+ // if it returns false, the reply is considered incomplete and
+ // the callback will be invoked again for the next matching reply.
+ callback func(resp interface{}) (done bool)
+
+ // errc receives nil when the callback indicates completion or an
+ // error if no further reply is received within the timeout.
+ errc chan<- error
+}
+
+type reply struct {
+ from NodeID
+ ptype byte
+ data interface{}
+}
+
+// ListenUDP returns a new table that listens for UDP packets on laddr.
+func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table, error) {
+ addr, err := net.ResolveUDPAddr("udp", laddr)
+ if err != nil {
+ return nil, err
+ }
+ conn, err := net.ListenUDP("udp", addr)
+ if err != nil {
+ return nil, err
+ }
+ udp := &udp{
+ conn: conn,
+ priv: priv,
+ closing: make(chan struct{}),
+ addpending: make(chan *pending),
+ replies: make(chan reply),
+ }
+
+ realaddr := conn.LocalAddr().(*net.UDPAddr)
+ if natm != nil {
+ if !realaddr.IP.IsLoopback() {
+ go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
+ }
+ // TODO: react to external IP changes over time.
+ if ext, err := natm.ExternalIP(); err == nil {
+ realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port}
+ }
+ }
+ udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr)
+
+ go udp.loop()
+ go udp.readLoop()
+ log.Infoln("Listening, ", udp.self)
+ return udp.Table, nil
+}
+
+func (t *udp) close() {
+ close(t.closing)
+ t.conn.Close()
+ // TODO: wait for the loops to end.
+}
+
+// ping sends a ping message to the given node and waits for a reply.
+func (t *udp) ping(e *Node) error {
+ // TODO: maybe check for ReplyTo field in callback to measure RTT
+ errc := t.pending(e.ID, pongPacket, func(interface{}) bool { return true })
+ t.send(e, pingPacket, ping{
+ IP: t.self.IP.String(),
+ Port: uint16(t.self.TCPPort),
+ Expiration: uint64(time.Now().Add(expiration).Unix()),
+ })
+ return <-errc
+}
+
+// findnode sends a findnode request to the given node and waits until
+// the node has sent up to k neighbors.
+func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
+ nodes := make([]*Node, 0, bucketSize)
+ nreceived := 0
+ errc := t.pending(to.ID, neighborsPacket, func(r interface{}) bool {
+ reply := r.(*neighbors)
+ for _, n := range reply.Nodes {
+ nreceived++
+ if n.isValid() {
+ nodes = append(nodes, n)
+ }
+ }
+ return nreceived >= bucketSize
+ })
+
+ t.send(to, findnodePacket, findnode{
+ Target: target,
+ Expiration: uint64(time.Now().Add(expiration).Unix()),
+ })
+ err := <-errc
+ return nodes, err
+}
+
+// pending adds a reply callback to the pending reply queue.
+// see the documentation of type pending for a detailed explanation.
+func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error {
+ ch := make(chan error, 1)
+ p := &pending{from: id, ptype: ptype, callback: callback, errc: ch}
+ select {
+ case t.addpending <- p:
+ // loop will handle it
+ case <-t.closing:
+ ch <- errClosed
+ }
+ return ch
+}
+
+// loop runs in its own goroutin. it keeps track of
+// the refresh timer and the pending reply queue.
+func (t *udp) loop() {
+ var (
+ pending []*pending
+ nextDeadline time.Time
+ timeout = time.NewTimer(0)
+ refresh = time.NewTicker(refreshInterval)
+ )
+ <-timeout.C // ignore first timeout
+ defer refresh.Stop()
+ defer timeout.Stop()
+
+ rearmTimeout := func() {
+ if len(pending) == 0 || nextDeadline == pending[0].deadline {
+ return
+ }
+ nextDeadline = pending[0].deadline
+ timeout.Reset(nextDeadline.Sub(time.Now()))
+ }
+
+ for {
+ select {
+ case <-refresh.C:
+ go t.refresh()
+
+ case <-t.closing:
+ for _, p := range pending {
+ p.errc <- errClosed
+ }
+ return
+
+ case p := <-t.addpending:
+ p.deadline = time.Now().Add(respTimeout)
+ pending = append(pending, p)
+ rearmTimeout()
+
+ case reply := <-t.replies:
+ // run matching callbacks, remove if they return false.
+ for i := 0; i < len(pending); i++ {
+ p := pending[i]
+ if reply.from == p.from && reply.ptype == p.ptype && p.callback(reply.data) {
+ p.errc <- nil
+ copy(pending[i:], pending[i+1:])
+ pending = pending[:len(pending)-1]
+ i--
+ }
+ }
+ rearmTimeout()
+
+ case now := <-timeout.C:
+ // notify and remove callbacks whose deadline is in the past.
+ i := 0
+ for ; i < len(pending) && now.After(pending[i].deadline); i++ {
+ pending[i].errc <- errTimeout
+ }
+ if i > 0 {
+ copy(pending, pending[i:])
+ pending = pending[:len(pending)-i]
+ }
+ rearmTimeout()
+ }
+ }
+}
+
+const (
+ macSize = 256 / 8
+ sigSize = 520 / 8
+ headSize = macSize + sigSize // space of packet frame data
+)
+
+var headSpace = make([]byte, headSize)
+
+func (t *udp) send(to *Node, ptype byte, req interface{}) error {
+ b := new(bytes.Buffer)
+ b.Write(headSpace)
+ b.WriteByte(ptype)
+ if err := rlp.Encode(b, req); err != nil {
+ log.Errorln("error encoding packet:", err)
+ return err
+ }
+
+ packet := b.Bytes()
+ sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), t.priv)
+ if err != nil {
+ log.Errorln("could not sign packet:", err)
+ return err
+ }
+ copy(packet[macSize:], sig)
+ // add the hash to the front. Note: this doesn't protect the
+ // packet in any way. Our public key will be part of this hash in
+ // the future.
+ copy(packet, crypto.Sha3(packet[macSize:]))
+
+ toaddr := &net.UDPAddr{IP: to.IP, Port: to.DiscPort}
+ log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req)
+ if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
+ log.DebugDetailln("UDP send failed:", err)
+ }
+ return err
+}
+
+// readLoop runs in its own goroutine. it handles incoming UDP packets.
+func (t *udp) readLoop() {
+ defer t.conn.Close()
+ buf := make([]byte, 4096) // TODO: good buffer size
+ for {
+ nbytes, from, err := t.conn.ReadFromUDP(buf)
+ if err != nil {
+ return
+ }
+ if err := t.packetIn(from, buf[:nbytes]); err != nil {
+ log.Debugf("Bad packet from %v: %v\n", from, err)
+ }
+ }
+}
+
+func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error {
+ if len(buf) < headSize+1 {
+ return errPacketTooSmall
+ }
+ hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
+ shouldhash := crypto.Sha3(buf[macSize:])
+ if !bytes.Equal(hash, shouldhash) {
+ return errBadHash
+ }
+ fromID, err := recoverNodeID(crypto.Sha3(buf[headSize:]), sig)
+ if err != nil {
+ return err
+ }
+
+ var req interface {
+ handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
+ }
+ switch ptype := sigdata[0]; ptype {
+ case pingPacket:
+ req = new(ping)
+ case pongPacket:
+ req = new(pong)
+ case findnodePacket:
+ req = new(findnode)
+ case neighborsPacket:
+ req = new(neighbors)
+ default:
+ return fmt.Errorf("unknown type: %d", ptype)
+ }
+ if err := rlp.Decode(bytes.NewReader(sigdata[1:]), req); err != nil {
+ return err
+ }
+ log.DebugDetailf("<<< %v %T %v\n", from, req, req)
+ return req.handle(t, from, fromID, hash)
+}
+
+func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+ if expired(req.Expiration) {
+ return errExpired
+ }
+ t.mutex.Lock()
+ // Note: we're ignoring the provided IP address right now
+ n := t.bumpOrAdd(fromID, from)
+ if req.Port != 0 {
+ n.TCPPort = int(req.Port)
+ }
+ t.mutex.Unlock()
+
+ t.send(n, pongPacket, pong{
+ ReplyTok: mac,
+ Expiration: uint64(time.Now().Add(expiration).Unix()),
+ })
+ return nil
+}
+
+func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+ if expired(req.Expiration) {
+ return errExpired
+ }
+ t.mutex.Lock()
+ t.bump(fromID)
+ t.mutex.Unlock()
+
+ t.replies <- reply{fromID, pongPacket, req}
+ return nil
+}
+
+func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+ if expired(req.Expiration) {
+ return errExpired
+ }
+ t.mutex.Lock()
+ e := t.bumpOrAdd(fromID, from)
+ closest := t.closest(req.Target, bucketSize).entries
+ t.mutex.Unlock()
+
+ t.send(e, neighborsPacket, neighbors{
+ Nodes: closest,
+ Expiration: uint64(time.Now().Add(expiration).Unix()),
+ })
+ return nil
+}
+
+func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+ if expired(req.Expiration) {
+ return errExpired
+ }
+ t.mutex.Lock()
+ t.bump(fromID)
+ t.add(req.Nodes)
+ t.mutex.Unlock()
+
+ t.replies <- reply{fromID, neighborsPacket, req}
+ return nil
+}
+
+func expired(ts uint64) bool {
+ return time.Unix(int64(ts), 0).Before(time.Now())
+}
diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go
new file mode 100644
index 000000000..0a8ff6358
--- /dev/null
+++ b/p2p/discover/udp_test.go
@@ -0,0 +1,211 @@
+package discover
+
+import (
+ "fmt"
+ logpkg "log"
+ "net"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/logger"
+)
+
+func init() {
+ logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, logpkg.LstdFlags, logger.ErrorLevel))
+}
+
+func TestUDP_ping(t *testing.T) {
+ t.Parallel()
+
+ n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
+ n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
+ defer n1.Close()
+ defer n2.Close()
+
+ if err := n1.net.ping(n2.self); err != nil {
+ t.Fatalf("ping error: %v", err)
+ }
+ if find(n2, n1.self.ID) == nil {
+ t.Errorf("node 2 does not contain id of node 1")
+ }
+ if e := find(n1, n2.self.ID); e != nil {
+ t.Errorf("node 1 does contains id of node 2: %v", e)
+ }
+}
+
+func find(tab *Table, id NodeID) *Node {
+ for _, b := range tab.buckets {
+ for _, e := range b.entries {
+ if e.ID == id {
+ return e
+ }
+ }
+ }
+ return nil
+}
+
+func TestUDP_findnode(t *testing.T) {
+ t.Parallel()
+
+ n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
+ n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
+ defer n1.Close()
+ defer n2.Close()
+
+ // put a few nodes into n2. the exact distribution shouldn't
+ // matter much, altough we need to take care not to overflow
+ // any bucket.
+ target := randomID(n1.self.ID, 100)
+ nodes := &nodesByDistance{target: target}
+ for i := 0; i < bucketSize; i++ {
+ n2.add([]*Node{&Node{
+ IP: net.IP{1, 2, 3, byte(i)},
+ DiscPort: i + 2,
+ TCPPort: i + 2,
+ ID: randomID(n2.self.ID, i+2),
+ }})
+ }
+ n2.add(nodes.entries)
+ n2.bumpOrAdd(n1.self.ID, &net.UDPAddr{IP: n1.self.IP, Port: n1.self.DiscPort})
+ expected := n2.closest(target, bucketSize)
+
+ err := runUDP(10, func() error {
+ result, _ := n1.net.findnode(n2.self, target)
+ if len(result) != bucketSize {
+ return fmt.Errorf("wrong number of results: got %d, want %d", len(result), bucketSize)
+ }
+ for i := range result {
+ if result[i].ID != expected.entries[i].ID {
+ return fmt.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, result[i], expected.entries[i])
+ }
+ }
+ return nil
+ })
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func TestUDP_replytimeout(t *testing.T) {
+ t.Parallel()
+
+ // reserve a port so we don't talk to an existing service by accident
+ addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
+ fd, err := net.ListenUDP("udp", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer fd.Close()
+
+ n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
+ defer n1.Close()
+ n2 := n1.bumpOrAdd(randomID(n1.self.ID, 10), fd.LocalAddr().(*net.UDPAddr))
+
+ if err := n1.net.ping(n2); err != errTimeout {
+ t.Error("expected timeout error, got", err)
+ }
+
+ if result, err := n1.net.findnode(n2, n1.self.ID); err != errTimeout {
+ t.Error("expected timeout error, got", err)
+ } else if len(result) > 0 {
+ t.Error("expected empty result, got", result)
+ }
+}
+
+func TestUDP_findnodeMultiReply(t *testing.T) {
+ t.Parallel()
+
+ n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
+ n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
+ udp2 := n2.net.(*udp)
+ defer n1.Close()
+ defer n2.Close()
+
+ err := runUDP(10, func() error {
+ nodes := make([]*Node, bucketSize)
+ for i := range nodes {
+ nodes[i] = &Node{
+ IP: net.IP{1, 2, 3, 4},
+ DiscPort: i + 1,
+ TCPPort: i + 1,
+ ID: randomID(n2.self.ID, i+1),
+ }
+ }
+
+ // ask N2 for neighbors. it will send an empty reply back.
+ // the request will wait for up to bucketSize replies.
+ resultc := make(chan []*Node)
+ errc := make(chan error)
+ go func() {
+ ns, err := n1.net.findnode(n2.self, n1.self.ID)
+ if err != nil {
+ errc <- err
+ } else {
+ resultc <- ns
+ }
+ }()
+
+ // send a few more neighbors packets to N1.
+ // it should collect those.
+ for end := 0; end < len(nodes); {
+ off := end
+ if end = end + 5; end > len(nodes) {
+ end = len(nodes)
+ }
+ udp2.send(n1.self, neighborsPacket, neighbors{
+ Nodes: nodes[off:end],
+ Expiration: uint64(time.Now().Add(10 * time.Second).Unix()),
+ })
+ }
+
+ // check that they are all returned. we cannot just check for
+ // equality because they might not be returned in the order they
+ // were sent.
+ var result []*Node
+ select {
+ case result = <-resultc:
+ case err := <-errc:
+ return err
+ }
+ if hasDuplicates(result) {
+ return fmt.Errorf("result slice contains duplicates")
+ }
+ if len(result) != len(nodes) {
+ return fmt.Errorf("wrong number of nodes returned: got %d, want %d", len(result), len(nodes))
+ }
+ matched := make(map[NodeID]bool)
+ for _, n := range result {
+ for _, expn := range nodes {
+ if n.ID == expn.ID { // && bytes.Equal(n.Addr.IP, expn.Addr.IP) && n.Addr.Port == expn.Addr.Port {
+ matched[n.ID] = true
+ }
+ }
+ }
+ if len(matched) != len(nodes) {
+ return fmt.Errorf("wrong number of matching nodes: got %d, want %d", len(matched), len(nodes))
+ }
+ return nil
+ })
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+// runUDP runs a test n times and returns an error if the test failed
+// in all n runs. This is necessary because UDP is unreliable even for
+// connections on the local machine, causing test failures.
+func runUDP(n int, test func() error) error {
+ errcount := 0
+ errors := ""
+ for i := 0; i < n; i++ {
+ if err := test(); err != nil {
+ errors += fmt.Sprintf("\n#%d: %v", i, err)
+ errcount++
+ }
+ }
+ if errcount == n {
+ return fmt.Errorf("failed on all %d iterations:%s", n, errors)
+ }
+ return nil
+}
diff --git a/p2p/handshake.go b/p2p/handshake.go
new file mode 100644
index 000000000..7fc497517
--- /dev/null
+++ b/p2p/handshake.go
@@ -0,0 +1,436 @@
+package p2p
+
+import (
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rand"
+ "errors"
+ "fmt"
+ "hash"
+ "io"
+ "net"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/crypto/ecies"
+ "github.com/ethereum/go-ethereum/crypto/secp256k1"
+ "github.com/ethereum/go-ethereum/crypto/sha3"
+ "github.com/ethereum/go-ethereum/p2p/discover"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+const (
+ sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2
+ sigLen = 65 // elliptic S256
+ pubLen = 64 // 512 bit pubkey in uncompressed representation without format byte
+ shaLen = 32 // hash length (for nonce etc)
+
+ authMsgLen = sigLen + shaLen + pubLen + shaLen + 1
+ authRespLen = pubLen + shaLen + 1
+
+ eciesBytes = 65 + 16 + 32
+ encAuthMsgLen = authMsgLen + eciesBytes // size of the final ECIES payload sent as initiator's handshake
+ encAuthRespLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake
+)
+
+// conn represents a remote connection after encryption handshake
+// and protocol handshake have completed.
+//
+// The MsgReadWriter is usually layered as follows:
+//
+// netWrapper (I/O timeouts, thread-safe ReadMsg, WriteMsg)
+// rlpxFrameRW (message encoding, encryption, authentication)
+// bufio.ReadWriter (buffering)
+// net.Conn (network I/O)
+//
+type conn struct {
+ MsgReadWriter
+ *protoHandshake
+}
+
+// secrets represents the connection secrets
+// which are negotiated during the encryption handshake.
+type secrets struct {
+ RemoteID discover.NodeID
+ AES, MAC []byte
+ EgressMAC, IngressMAC hash.Hash
+ Token []byte
+}
+
+// protoHandshake is the RLP structure of the protocol handshake.
+type protoHandshake struct {
+ Version uint64
+ Name string
+ Caps []Cap
+ ListenPort uint64
+ ID discover.NodeID
+}
+
+// setupConn starts a protocol session on the given connection.
+// It runs the encryption handshake and the protocol handshake.
+// If dial is non-nil, the connection the local node is the initiator.
+func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
+ if dial == nil {
+ return setupInboundConn(fd, prv, our)
+ } else {
+ return setupOutboundConn(fd, prv, our, dial)
+ }
+}
+
+func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake) (*conn, error) {
+ secrets, err := receiverEncHandshake(fd, prv, nil)
+ if err != nil {
+ return nil, fmt.Errorf("encryption handshake failed: %v", err)
+ }
+
+ // Run the protocol handshake using authenticated messages.
+ rw := newRlpxFrameRW(fd, secrets)
+ rhs, err := readProtocolHandshake(rw, our)
+ if err != nil {
+ return nil, err
+ }
+ if rhs.ID != secrets.RemoteID {
+ return nil, errors.New("node ID in protocol handshake does not match encryption handshake")
+ }
+ // TODO: validate that handshake node ID matches
+ if err := writeProtocolHandshake(rw, our); err != nil {
+ return nil, fmt.Errorf("protocol write error: %v", err)
+ }
+ return &conn{rw, rhs}, nil
+}
+
+func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
+ secrets, err := initiatorEncHandshake(fd, prv, dial.ID, nil)
+ if err != nil {
+ return nil, fmt.Errorf("encryption handshake failed: %v", err)
+ }
+
+ // Run the protocol handshake using authenticated messages.
+ rw := newRlpxFrameRW(fd, secrets)
+ if err := writeProtocolHandshake(rw, our); err != nil {
+ return nil, fmt.Errorf("protocol write error: %v", err)
+ }
+ rhs, err := readProtocolHandshake(rw, our)
+ if err != nil {
+ return nil, fmt.Errorf("protocol handshake read error: %v", err)
+ }
+ if rhs.ID != dial.ID {
+ return nil, errors.New("dialed node id mismatch")
+ }
+ return &conn{rw, rhs}, nil
+}
+
+// encHandshake contains the state of the encryption handshake.
+type encHandshake struct {
+ initiator bool
+ remoteID discover.NodeID
+
+ remotePub *ecies.PublicKey // remote-pubk
+ initNonce, respNonce []byte // nonce
+ randomPrivKey *ecies.PrivateKey // ecdhe-random
+ remoteRandomPub *ecies.PublicKey // ecdhe-random-pubk
+}
+
+// secrets is called after the handshake is completed.
+// It extracts the connection secrets from the handshake values.
+func (h *encHandshake) secrets(auth, authResp []byte) (secrets, error) {
+ ecdheSecret, err := h.randomPrivKey.GenerateShared(h.remoteRandomPub, sskLen, sskLen)
+ if err != nil {
+ return secrets{}, err
+ }
+
+ // derive base secrets from ephemeral key agreement
+ sharedSecret := crypto.Sha3(ecdheSecret, crypto.Sha3(h.respNonce, h.initNonce))
+ aesSecret := crypto.Sha3(ecdheSecret, sharedSecret)
+ s := secrets{
+ RemoteID: h.remoteID,
+ AES: aesSecret,
+ MAC: crypto.Sha3(ecdheSecret, aesSecret),
+ Token: crypto.Sha3(sharedSecret),
+ }
+
+ // setup sha3 instances for the MACs
+ mac1 := sha3.NewKeccak256()
+ mac1.Write(xor(s.MAC, h.respNonce))
+ mac1.Write(auth)
+ mac2 := sha3.NewKeccak256()
+ mac2.Write(xor(s.MAC, h.initNonce))
+ mac2.Write(authResp)
+ if h.initiator {
+ s.EgressMAC, s.IngressMAC = mac1, mac2
+ } else {
+ s.EgressMAC, s.IngressMAC = mac2, mac1
+ }
+
+ return s, nil
+}
+
+func (h *encHandshake) ecdhShared(prv *ecdsa.PrivateKey) ([]byte, error) {
+ return ecies.ImportECDSA(prv).GenerateShared(h.remotePub, sskLen, sskLen)
+}
+
+// initiatorEncHandshake negotiates a session token on conn.
+// it should be called on the dialing side of the connection.
+//
+// prv is the local client's private key.
+// token is the token from a previous session with this node.
+func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID discover.NodeID, token []byte) (s secrets, err error) {
+ h, err := newInitiatorHandshake(remoteID)
+ if err != nil {
+ return s, err
+ }
+ auth, err := h.authMsg(prv, token)
+ if err != nil {
+ return s, err
+ }
+ if _, err = conn.Write(auth); err != nil {
+ return s, err
+ }
+
+ response := make([]byte, encAuthRespLen)
+ if _, err = io.ReadFull(conn, response); err != nil {
+ return s, err
+ }
+ if err := h.decodeAuthResp(response, prv); err != nil {
+ return s, err
+ }
+ return h.secrets(auth, response)
+}
+
+func newInitiatorHandshake(remoteID discover.NodeID) (*encHandshake, error) {
+ // generate random initiator nonce
+ n := make([]byte, shaLen)
+ if _, err := rand.Read(n); err != nil {
+ return nil, err
+ }
+ // generate random keypair to use for signing
+ randpriv, err := ecies.GenerateKey(rand.Reader, crypto.S256(), nil)
+ if err != nil {
+ return nil, err
+ }
+ rpub, err := remoteID.Pubkey()
+ if err != nil {
+ return nil, fmt.Errorf("bad remoteID: %v", err)
+ }
+ h := &encHandshake{
+ initiator: true,
+ remoteID: remoteID,
+ remotePub: ecies.ImportECDSAPublic(rpub),
+ initNonce: n,
+ randomPrivKey: randpriv,
+ }
+ return h, nil
+}
+
+// authMsg creates an encrypted initiator handshake message.
+func (h *encHandshake) authMsg(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) {
+ var tokenFlag byte
+ if token == nil {
+ // no session token found means we need to generate shared secret.
+ // ecies shared secret is used as initial session token for new peers
+ // generate shared key from prv and remote pubkey
+ var err error
+ if token, err = h.ecdhShared(prv); err != nil {
+ return nil, err
+ }
+ } else {
+ // for known peers, we use stored token from the previous session
+ tokenFlag = 0x01
+ }
+
+ // sign known message:
+ // ecdh-shared-secret^nonce for new peers
+ // token^nonce for old peers
+ signed := xor(token, h.initNonce)
+ signature, err := crypto.Sign(signed, h.randomPrivKey.ExportECDSA())
+ if err != nil {
+ return nil, err
+ }
+
+ // encode auth message
+ // signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag
+ msg := make([]byte, authMsgLen)
+ n := copy(msg, signature)
+ n += copy(msg[n:], crypto.Sha3(exportPubkey(&h.randomPrivKey.PublicKey)))
+ n += copy(msg[n:], crypto.FromECDSAPub(&prv.PublicKey)[1:])
+ n += copy(msg[n:], h.initNonce)
+ msg[n] = tokenFlag
+
+ // encrypt auth message using remote-pubk
+ return ecies.Encrypt(rand.Reader, h.remotePub, msg, nil, nil)
+}
+
+// decodeAuthResp decode an encrypted authentication response message.
+func (h *encHandshake) decodeAuthResp(auth []byte, prv *ecdsa.PrivateKey) error {
+ msg, err := crypto.Decrypt(prv, auth)
+ if err != nil {
+ return fmt.Errorf("could not decrypt auth response (%v)", err)
+ }
+ h.respNonce = msg[pubLen : pubLen+shaLen]
+ h.remoteRandomPub, err = importPublicKey(msg[:pubLen])
+ if err != nil {
+ return err
+ }
+ // ignore token flag for now
+ return nil
+}
+
+// receiverEncHandshake negotiates a session token on conn.
+// it should be called on the listening side of the connection.
+//
+// prv is the local client's private key.
+// token is the token from a previous session with this node.
+func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, token []byte) (s secrets, err error) {
+ // read remote auth sent by initiator.
+ auth := make([]byte, encAuthMsgLen)
+ if _, err := io.ReadFull(conn, auth); err != nil {
+ return s, err
+ }
+ h, err := decodeAuthMsg(prv, token, auth)
+ if err != nil {
+ return s, err
+ }
+
+ // send auth response
+ resp, err := h.authResp(prv, token)
+ if err != nil {
+ return s, err
+ }
+ if _, err = conn.Write(resp); err != nil {
+ return s, err
+ }
+
+ return h.secrets(auth, resp)
+}
+
+func decodeAuthMsg(prv *ecdsa.PrivateKey, token []byte, auth []byte) (*encHandshake, error) {
+ var err error
+ h := new(encHandshake)
+ // generate random keypair for session
+ h.randomPrivKey, err = ecies.GenerateKey(rand.Reader, crypto.S256(), nil)
+ if err != nil {
+ return nil, err
+ }
+ // generate random nonce
+ h.respNonce = make([]byte, shaLen)
+ if _, err = rand.Read(h.respNonce); err != nil {
+ return nil, err
+ }
+
+ msg, err := crypto.Decrypt(prv, auth)
+ if err != nil {
+ return nil, fmt.Errorf("could not decrypt auth message (%v)", err)
+ }
+
+ // decode message parameters
+ // signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag
+ h.initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1]
+ copy(h.remoteID[:], msg[sigLen+shaLen:sigLen+shaLen+pubLen])
+ rpub, err := h.remoteID.Pubkey()
+ if err != nil {
+ return nil, fmt.Errorf("bad remoteID: %#v", err)
+ }
+ h.remotePub = ecies.ImportECDSAPublic(rpub)
+
+ // recover remote random pubkey from signed message.
+ if token == nil {
+ // TODO: it is an error if the initiator has a token and we don't. check that.
+
+ // no session token means we need to generate shared secret.
+ // ecies shared secret is used as initial session token for new peers.
+ // generate shared key from prv and remote pubkey.
+ if token, err = h.ecdhShared(prv); err != nil {
+ return nil, err
+ }
+ }
+ signedMsg := xor(token, h.initNonce)
+ remoteRandomPub, err := secp256k1.RecoverPubkey(signedMsg, msg[:sigLen])
+ if err != nil {
+ return nil, err
+ }
+ h.remoteRandomPub, _ = importPublicKey(remoteRandomPub)
+ return h, nil
+}
+
+// authResp generates the encrypted authentication response message.
+func (h *encHandshake) authResp(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) {
+ // responder auth message
+ // E(remote-pubk, ecdhe-random-pubk || nonce || 0x0)
+ resp := make([]byte, authRespLen)
+ n := copy(resp, exportPubkey(&h.randomPrivKey.PublicKey))
+ n += copy(resp[n:], h.respNonce)
+ if token == nil {
+ resp[n] = 0
+ } else {
+ resp[n] = 1
+ }
+ // encrypt using remote-pubk
+ return ecies.Encrypt(rand.Reader, h.remotePub, resp, nil, nil)
+}
+
+// importPublicKey unmarshals 512 bit public keys.
+func importPublicKey(pubKey []byte) (*ecies.PublicKey, error) {
+ var pubKey65 []byte
+ switch len(pubKey) {
+ case 64:
+ // add 'uncompressed key' flag
+ pubKey65 = append([]byte{0x04}, pubKey...)
+ case 65:
+ pubKey65 = pubKey
+ default:
+ return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey))
+ }
+ // TODO: fewer pointless conversions
+ return ecies.ImportECDSAPublic(crypto.ToECDSAPub(pubKey65)), nil
+}
+
+func exportPubkey(pub *ecies.PublicKey) []byte {
+ if pub == nil {
+ panic("nil pubkey")
+ }
+ return elliptic.Marshal(pub.Curve, pub.X, pub.Y)[1:]
+}
+
+func xor(one, other []byte) (xor []byte) {
+ xor = make([]byte, len(one))
+ for i := 0; i < len(one); i++ {
+ xor[i] = one[i] ^ other[i]
+ }
+ return xor
+}
+
+func writeProtocolHandshake(w MsgWriter, our *protoHandshake) error {
+ return EncodeMsg(w, handshakeMsg, our.Version, our.Name, our.Caps, our.ListenPort, our.ID[:])
+}
+
+func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, error) {
+ // read and handle remote handshake
+ msg, err := r.ReadMsg()
+ if err != nil {
+ return nil, err
+ }
+ if msg.Code == discMsg {
+ // disconnect before protocol handshake is valid according to the
+ // spec and we send it ourself if Server.addPeer fails.
+ var reason DiscReason
+ rlp.Decode(msg.Payload, &reason)
+ return nil, discRequestedError(reason)
+ }
+ if msg.Code != handshakeMsg {
+ return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
+ }
+ if msg.Size > baseProtocolMaxMsgSize {
+ return nil, fmt.Errorf("message too big (%d > %d)", msg.Size, baseProtocolMaxMsgSize)
+ }
+ var hs protoHandshake
+ if err := msg.Decode(&hs); err != nil {
+ return nil, err
+ }
+ // validate handshake info
+ if hs.Version != our.Version {
+ return nil, newPeerError(errP2PVersionMismatch, "required version %d, received %d\n", baseProtocolVersion, hs.Version)
+ }
+ if (hs.ID == discover.NodeID{}) {
+ return nil, newPeerError(errPubkeyInvalid, "missing")
+ }
+ return &hs, nil
+}
diff --git a/p2p/handshake_test.go b/p2p/handshake_test.go
new file mode 100644
index 000000000..19423bb82
--- /dev/null
+++ b/p2p/handshake_test.go
@@ -0,0 +1,171 @@
+package p2p
+
+import (
+ "bytes"
+ "crypto/rand"
+ "fmt"
+ "net"
+ "reflect"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/crypto/ecies"
+ "github.com/ethereum/go-ethereum/p2p/discover"
+)
+
+func TestSharedSecret(t *testing.T) {
+ prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
+ pub0 := &prv0.PublicKey
+ prv1, _ := crypto.GenerateKey()
+ pub1 := &prv1.PublicKey
+
+ ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
+ if err != nil {
+ return
+ }
+ ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
+ if err != nil {
+ return
+ }
+ t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
+ if !bytes.Equal(ss0, ss1) {
+ t.Errorf("dont match :(")
+ }
+}
+
+func TestEncHandshake(t *testing.T) {
+ for i := 0; i < 20; i++ {
+ start := time.Now()
+ if err := testEncHandshake(nil); err != nil {
+ t.Fatalf("i=%d %v", i, err)
+ }
+ t.Logf("(without token) %d %v\n", i+1, time.Since(start))
+ }
+
+ for i := 0; i < 20; i++ {
+ tok := make([]byte, shaLen)
+ rand.Reader.Read(tok)
+ start := time.Now()
+ if err := testEncHandshake(tok); err != nil {
+ t.Fatalf("i=%d %v", i, err)
+ }
+ t.Logf("(with token) %d %v\n", i+1, time.Since(start))
+ }
+}
+
+func testEncHandshake(token []byte) error {
+ type result struct {
+ side string
+ s secrets
+ err error
+ }
+ var (
+ prv0, _ = crypto.GenerateKey()
+ prv1, _ = crypto.GenerateKey()
+ rw0, rw1 = net.Pipe()
+ output = make(chan result)
+ )
+
+ go func() {
+ r := result{side: "initiator"}
+ defer func() { output <- r }()
+
+ pub1s := discover.PubkeyID(&prv1.PublicKey)
+ r.s, r.err = initiatorEncHandshake(rw0, prv0, pub1s, token)
+ if r.err != nil {
+ return
+ }
+ id1 := discover.PubkeyID(&prv1.PublicKey)
+ if r.s.RemoteID != id1 {
+ r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id1)
+ }
+ }()
+ go func() {
+ r := result{side: "receiver"}
+ defer func() { output <- r }()
+
+ r.s, r.err = receiverEncHandshake(rw1, prv1, token)
+ if r.err != nil {
+ return
+ }
+ id0 := discover.PubkeyID(&prv0.PublicKey)
+ if r.s.RemoteID != id0 {
+ r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id0)
+ }
+ }()
+
+ // wait for results from both sides
+ r1, r2 := <-output, <-output
+
+ if r1.err != nil {
+ return fmt.Errorf("%s side error: %v", r1.side, r1.err)
+ }
+ if r2.err != nil {
+ return fmt.Errorf("%s side error: %v", r2.side, r2.err)
+ }
+
+ // don't compare remote node IDs
+ r1.s.RemoteID, r2.s.RemoteID = discover.NodeID{}, discover.NodeID{}
+ // flip MACs on one of them so they compare equal
+ r1.s.EgressMAC, r1.s.IngressMAC = r1.s.IngressMAC, r1.s.EgressMAC
+ if !reflect.DeepEqual(r1.s, r2.s) {
+ return fmt.Errorf("secrets mismatch:\n t1: %#v\n t2: %#v", r1.s, r2.s)
+ }
+ return nil
+}
+
+func TestSetupConn(t *testing.T) {
+ prv0, _ := crypto.GenerateKey()
+ prv1, _ := crypto.GenerateKey()
+ node0 := &discover.Node{
+ ID: discover.PubkeyID(&prv0.PublicKey),
+ IP: net.IP{1, 2, 3, 4},
+ TCPPort: 33,
+ }
+ node1 := &discover.Node{
+ ID: discover.PubkeyID(&prv1.PublicKey),
+ IP: net.IP{5, 6, 7, 8},
+ TCPPort: 44,
+ }
+ hs0 := &protoHandshake{
+ Version: baseProtocolVersion,
+ ID: node0.ID,
+ Caps: []Cap{{"a", 0}, {"b", 2}},
+ }
+ hs1 := &protoHandshake{
+ Version: baseProtocolVersion,
+ ID: node1.ID,
+ Caps: []Cap{{"c", 1}, {"d", 3}},
+ }
+ fd0, fd1 := net.Pipe()
+
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ conn0, err := setupConn(fd0, prv0, hs0, node1)
+ if err != nil {
+ t.Errorf("outbound side error: %v", err)
+ return
+ }
+ if conn0.ID != node1.ID {
+ t.Errorf("outbound conn id mismatch: got %v, want %v", conn0.ID, node1.ID)
+ }
+ if !reflect.DeepEqual(conn0.Caps, hs1.Caps) {
+ t.Errorf("outbound caps mismatch: got %v, want %v", conn0.Caps, hs1.Caps)
+ }
+ }()
+
+ conn1, err := setupConn(fd1, prv1, hs1, nil)
+ if err != nil {
+ t.Fatalf("inbound side error: %v", err)
+ }
+ if conn1.ID != node0.ID {
+ t.Errorf("inbound conn id mismatch: got %v, want %v", conn1.ID, node0.ID)
+ }
+ if !reflect.DeepEqual(conn1.Caps, hs0.Caps) {
+ t.Errorf("inbound caps mismatch: got %v, want %v", conn1.Caps, hs0.Caps)
+ }
+
+ <-done
+}
diff --git a/p2p/message.go b/p2p/message.go
new file mode 100644
index 000000000..14e4404c9
--- /dev/null
+++ b/p2p/message.go
@@ -0,0 +1,210 @@
+package p2p
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+// Msg defines the structure of a p2p message.
+//
+// Note that a Msg can only be sent once since the Payload reader is
+// consumed during sending. It is not possible to create a Msg and
+// send it any number of times. If you want to reuse an encoded
+// structure, encode the payload into a byte array and create a
+// separate Msg with a bytes.Reader as Payload for each send.
+type Msg struct {
+ Code uint64
+ Size uint32 // size of the paylod
+ Payload io.Reader
+}
+
+// NewMsg creates an RLP-encoded message with the given code.
+func NewMsg(code uint64, params ...interface{}) Msg {
+ p := bytes.NewReader(common.Encode(params))
+ return Msg{Code: code, Size: uint32(p.Len()), Payload: p}
+}
+
+// Decode parse the RLP content of a message into
+// the given value, which must be a pointer.
+//
+// For the decoding rules, please see package rlp.
+func (msg Msg) Decode(val interface{}) error {
+ if err := rlp.Decode(msg.Payload, val); err != nil {
+ return newPeerError(errInvalidMsg, "(code %#x) (size %d) %v", msg.Code, msg.Size, err)
+ }
+ return nil
+}
+
+func (msg Msg) String() string {
+ return fmt.Sprintf("msg #%v (%v bytes)", msg.Code, msg.Size)
+}
+
+// Discard reads any remaining payload data into a black hole.
+func (msg Msg) Discard() error {
+ _, err := io.Copy(ioutil.Discard, msg.Payload)
+ return err
+}
+
+type MsgReader interface {
+ ReadMsg() (Msg, error)
+}
+
+type MsgWriter interface {
+ // WriteMsg sends a message. It will block until the message's
+ // Payload has been consumed by the other end.
+ //
+ // Note that messages can be sent only once because their
+ // payload reader is drained.
+ WriteMsg(Msg) error
+}
+
+// MsgReadWriter provides reading and writing of encoded messages.
+// Implementations should ensure that ReadMsg and WriteMsg can be
+// called simultaneously from multiple goroutines.
+type MsgReadWriter interface {
+ MsgReader
+ MsgWriter
+}
+
+// EncodeMsg writes an RLP-encoded message with the given code and
+// data elements.
+func EncodeMsg(w MsgWriter, code uint64, data ...interface{}) error {
+ return w.WriteMsg(NewMsg(code, data...))
+}
+
+// netWrapper wrapsa MsgReadWriter with locks around
+// ReadMsg/WriteMsg and applies read/write deadlines.
+type netWrapper struct {
+ rmu, wmu sync.Mutex
+
+ rtimeout, wtimeout time.Duration
+ conn net.Conn
+ wrapped MsgReadWriter
+}
+
+func (rw *netWrapper) ReadMsg() (Msg, error) {
+ rw.rmu.Lock()
+ defer rw.rmu.Unlock()
+ rw.conn.SetReadDeadline(time.Now().Add(rw.rtimeout))
+ return rw.wrapped.ReadMsg()
+}
+
+func (rw *netWrapper) WriteMsg(msg Msg) error {
+ rw.wmu.Lock()
+ defer rw.wmu.Unlock()
+ rw.conn.SetWriteDeadline(time.Now().Add(rw.wtimeout))
+ return rw.wrapped.WriteMsg(msg)
+}
+
+// eofSignal wraps a reader with eof signaling. the eof channel is
+// closed when the wrapped reader returns an error or when count bytes
+// have been read.
+type eofSignal struct {
+ wrapped io.Reader
+ count uint32 // number of bytes left
+ eof chan<- struct{}
+}
+
+// note: when using eofSignal to detect whether a message payload
+// has been read, Read might not be called for zero sized messages.
+func (r *eofSignal) Read(buf []byte) (int, error) {
+ if r.count == 0 {
+ if r.eof != nil {
+ r.eof <- struct{}{}
+ r.eof = nil
+ }
+ return 0, io.EOF
+ }
+
+ max := len(buf)
+ if int(r.count) < len(buf) {
+ max = int(r.count)
+ }
+ n, err := r.wrapped.Read(buf[:max])
+ r.count -= uint32(n)
+ if (err != nil || r.count == 0) && r.eof != nil {
+ r.eof <- struct{}{} // tell Peer that msg has been consumed
+ r.eof = nil
+ }
+ return n, err
+}
+
+// MsgPipe creates a message pipe. Reads on one end are matched
+// with writes on the other. The pipe is full-duplex, both ends
+// implement MsgReadWriter.
+func MsgPipe() (*MsgPipeRW, *MsgPipeRW) {
+ var (
+ c1, c2 = make(chan Msg), make(chan Msg)
+ closing = make(chan struct{})
+ closed = new(int32)
+ rw1 = &MsgPipeRW{c1, c2, closing, closed}
+ rw2 = &MsgPipeRW{c2, c1, closing, closed}
+ )
+ return rw1, rw2
+}
+
+// ErrPipeClosed is returned from pipe operations after the
+// pipe has been closed.
+var ErrPipeClosed = errors.New("p2p: read or write on closed message pipe")
+
+// MsgPipeRW is an endpoint of a MsgReadWriter pipe.
+type MsgPipeRW struct {
+ w chan<- Msg
+ r <-chan Msg
+ closing chan struct{}
+ closed *int32
+}
+
+// WriteMsg sends a messsage on the pipe.
+// It blocks until the receiver has consumed the message payload.
+func (p *MsgPipeRW) WriteMsg(msg Msg) error {
+ if atomic.LoadInt32(p.closed) == 0 {
+ consumed := make(chan struct{}, 1)
+ msg.Payload = &eofSignal{msg.Payload, msg.Size, consumed}
+ select {
+ case p.w <- msg:
+ if msg.Size > 0 {
+ // wait for payload read or discard
+ <-consumed
+ }
+ return nil
+ case <-p.closing:
+ }
+ }
+ return ErrPipeClosed
+}
+
+// ReadMsg returns a message sent on the other end of the pipe.
+func (p *MsgPipeRW) ReadMsg() (Msg, error) {
+ if atomic.LoadInt32(p.closed) == 0 {
+ select {
+ case msg := <-p.r:
+ return msg, nil
+ case <-p.closing:
+ }
+ }
+ return Msg{}, ErrPipeClosed
+}
+
+// Close unblocks any pending ReadMsg and WriteMsg calls on both ends
+// of the pipe. They will return ErrPipeClosed. Note that Close does
+// not interrupt any reads from a message payload.
+func (p *MsgPipeRW) Close() error {
+ if atomic.AddInt32(p.closed, 1) != 1 {
+ // someone else is already closing
+ atomic.StoreInt32(p.closed, 1) // avoid overflow
+ return nil
+ }
+ close(p.closing)
+ return nil
+}
diff --git a/p2p/message_test.go b/p2p/message_test.go
new file mode 100644
index 000000000..31ed61d87
--- /dev/null
+++ b/p2p/message_test.go
@@ -0,0 +1,151 @@
+package p2p
+
+import (
+ "bytes"
+ "encoding/hex"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "runtime"
+ "strings"
+ "testing"
+ "time"
+)
+
+func TestNewMsg(t *testing.T) {
+ msg := NewMsg(3, 1, "000")
+ if msg.Code != 3 {
+ t.Errorf("incorrect code %d, want %d", msg.Code)
+ }
+ expect := unhex("c50183303030")
+ if msg.Size != uint32(len(expect)) {
+ t.Errorf("incorrect size %d, want %d", msg.Size, len(expect))
+ }
+ pl, _ := ioutil.ReadAll(msg.Payload)
+ if !bytes.Equal(pl, expect) {
+ t.Errorf("incorrect payload content, got %x, want %x", pl, expect)
+ }
+}
+
+func ExampleMsgPipe() {
+ rw1, rw2 := MsgPipe()
+ go func() {
+ EncodeMsg(rw1, 8, []byte{0, 0})
+ EncodeMsg(rw1, 5, []byte{1, 1})
+ rw1.Close()
+ }()
+
+ for {
+ msg, err := rw2.ReadMsg()
+ if err != nil {
+ break
+ }
+ var data [1][]byte
+ msg.Decode(&data)
+ fmt.Printf("msg: %d, %x\n", msg.Code, data[0])
+ }
+ // Output:
+ // msg: 8, 0000
+ // msg: 5, 0101
+}
+
+func TestMsgPipeUnblockWrite(t *testing.T) {
+loop:
+ for i := 0; i < 100; i++ {
+ rw1, rw2 := MsgPipe()
+ done := make(chan struct{})
+ go func() {
+ if err := EncodeMsg(rw1, 1); err == nil {
+ t.Error("EncodeMsg returned nil error")
+ } else if err != ErrPipeClosed {
+ t.Error("EncodeMsg returned wrong error: got %v, want %v", err, ErrPipeClosed)
+ }
+ close(done)
+ }()
+
+ // this call should ensure that EncodeMsg is waiting to
+ // deliver sometimes. if this isn't done, Close is likely to
+ // be executed before EncodeMsg starts and then we won't test
+ // all the cases.
+ runtime.Gosched()
+
+ rw2.Close()
+ select {
+ case <-done:
+ case <-time.After(200 * time.Millisecond):
+ t.Errorf("write didn't unblock")
+ break loop
+ }
+ }
+}
+
+// This test should panic if concurrent close isn't implemented correctly.
+func TestMsgPipeConcurrentClose(t *testing.T) {
+ rw1, _ := MsgPipe()
+ for i := 0; i < 10; i++ {
+ go rw1.Close()
+ }
+}
+
+func TestEOFSignal(t *testing.T) {
+ rb := make([]byte, 10)
+
+ // empty reader
+ eof := make(chan struct{}, 1)
+ sig := &eofSignal{new(bytes.Buffer), 0, eof}
+ if n, err := sig.Read(rb); n != 0 || err != io.EOF {
+ t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+ }
+ select {
+ case <-eof:
+ default:
+ t.Error("EOF chan not signaled")
+ }
+
+ // count before error
+ eof = make(chan struct{}, 1)
+ sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
+ if n, err := sig.Read(rb); n != 4 || err != nil {
+ t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+ }
+ select {
+ case <-eof:
+ default:
+ t.Error("EOF chan not signaled")
+ }
+
+ // error before count
+ eof = make(chan struct{}, 1)
+ sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
+ if n, err := sig.Read(rb); n != 4 || err != nil {
+ t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+ }
+ if n, err := sig.Read(rb); n != 0 || err != io.EOF {
+ t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+ }
+ select {
+ case <-eof:
+ default:
+ t.Error("EOF chan not signaled")
+ }
+
+ // no signal if neither occurs
+ eof = make(chan struct{}, 1)
+ sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
+ if n, err := sig.Read(rb); n != 10 || err != nil {
+ t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+ }
+ select {
+ case <-eof:
+ t.Error("unexpected EOF signal")
+ default:
+ }
+}
+
+func unhex(str string) []byte {
+ b, err := hex.DecodeString(strings.Replace(str, "\n", "", -1))
+ if err != nil {
+ panic(fmt.Sprintf("invalid hex string: %q", str))
+ }
+ return b
+}
diff --git a/p2p/nat/nat.go b/p2p/nat/nat.go
new file mode 100644
index 000000000..12d355ba1
--- /dev/null
+++ b/p2p/nat/nat.go
@@ -0,0 +1,235 @@
+// Package nat provides access to common port mapping protocols.
+package nat
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/ethereum/go-ethereum/logger"
+ "github.com/jackpal/go-nat-pmp"
+)
+
+var log = logger.NewLogger("P2P NAT")
+
+// An implementation of nat.Interface can map local ports to ports
+// accessible from the Internet.
+type Interface interface {
+ // These methods manage a mapping between a port on the local
+ // machine to a port that can be connected to from the internet.
+ //
+ // protocol is "UDP" or "TCP". Some implementations allow setting
+ // a display name for the mapping. The mapping may be removed by
+ // the gateway when its lifetime ends.
+ AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error
+ DeleteMapping(protocol string, extport, intport int) error
+
+ // This method should return the external (Internet-facing)
+ // address of the gateway device.
+ ExternalIP() (net.IP, error)
+
+ // Should return name of the method. This is used for logging.
+ String() string
+}
+
+// Parse parses a NAT interface description.
+// The following formats are currently accepted.
+// Note that mechanism names are not case-sensitive.
+//
+// "" or "none" return nil
+// "extip:77.12.33.4" will assume the local machine is reachable on the given IP
+// "any" uses the first auto-detected mechanism
+// "upnp" uses the Universal Plug and Play protocol
+// "pmp" uses NAT-PMP with an auto-detected gateway address
+// "pmp:192.168.0.1" uses NAT-PMP with the given gateway address
+func Parse(spec string) (Interface, error) {
+ var (
+ parts = strings.SplitN(spec, ":", 2)
+ mech = strings.ToLower(parts[0])
+ ip net.IP
+ )
+ if len(parts) > 1 {
+ ip = net.ParseIP(parts[1])
+ if ip == nil {
+ return nil, errors.New("invalid IP address")
+ }
+ }
+ switch mech {
+ case "", "none", "off":
+ return nil, nil
+ case "any", "auto", "on":
+ return Any(), nil
+ case "extip", "ip":
+ if ip == nil {
+ return nil, errors.New("missing IP address")
+ }
+ return ExtIP(ip), nil
+ case "upnp":
+ return UPnP(), nil
+ case "pmp", "natpmp", "nat-pmp":
+ return PMP(ip), nil
+ default:
+ return nil, fmt.Errorf("unknown mechanism %q", parts[0])
+ }
+}
+
+const (
+ mapTimeout = 20 * time.Minute
+ mapUpdateInterval = 15 * time.Minute
+)
+
+// Map adds a port mapping on m and keeps it alive until c is closed.
+// This function is typically invoked in its own goroutine.
+func Map(m Interface, c chan struct{}, protocol string, extport, intport int, name string) {
+ refresh := time.NewTimer(mapUpdateInterval)
+ defer func() {
+ refresh.Stop()
+ log.Debugf("Deleting port mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m)
+ m.DeleteMapping(protocol, extport, intport)
+ }()
+ log.Debugf("add mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m)
+ if err := m.AddMapping(protocol, intport, extport, name, mapTimeout); err != nil {
+ log.Errorf("mapping error: %v\n", err)
+ }
+ for {
+ select {
+ case _, ok := <-c:
+ if !ok {
+ return
+ }
+ case <-refresh.C:
+ log.DebugDetailf("refresh mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m)
+ if err := m.AddMapping(protocol, intport, extport, name, mapTimeout); err != nil {
+ log.Errorf("mapping error: %v\n", err)
+ }
+ refresh.Reset(mapUpdateInterval)
+ }
+ }
+}
+
+// ExtIP assumes that the local machine is reachable on the given
+// external IP address, and that any required ports were mapped manually.
+// Mapping operations will not return an error but won't actually do anything.
+func ExtIP(ip net.IP) Interface {
+ if ip == nil {
+ panic("IP must not be nil")
+ }
+ return extIP(ip)
+}
+
+type extIP net.IP
+
+func (n extIP) ExternalIP() (net.IP, error) { return net.IP(n), nil }
+func (n extIP) String() string { return fmt.Sprintf("ExtIP(%v)", net.IP(n)) }
+
+// These do nothing.
+func (extIP) AddMapping(string, int, int, string, time.Duration) error { return nil }
+func (extIP) DeleteMapping(string, int, int) error { return nil }
+
+// Any returns a port mapper that tries to discover any supported
+// mechanism on the local network.
+func Any() Interface {
+ // TODO: attempt to discover whether the local machine has an
+ // Internet-class address. Return ExtIP in this case.
+ return startautodisc("UPnP or NAT-PMP", func() Interface {
+ found := make(chan Interface, 2)
+ go func() { found <- discoverUPnP() }()
+ go func() { found <- discoverPMP() }()
+ for i := 0; i < cap(found); i++ {
+ if c := <-found; c != nil {
+ return c
+ }
+ }
+ return nil
+ })
+}
+
+// UPnP returns a port mapper that uses UPnP. It will attempt to
+// discover the address of your router using UDP broadcasts.
+func UPnP() Interface {
+ return startautodisc("UPnP", discoverUPnP)
+}
+
+// PMP returns a port mapper that uses NAT-PMP. The provided gateway
+// address should be the IP of your router. If the given gateway
+// address is nil, PMP will attempt to auto-discover the router.
+func PMP(gateway net.IP) Interface {
+ if gateway != nil {
+ return &pmp{gw: gateway, c: natpmp.NewClient(gateway)}
+ }
+ return startautodisc("NAT-PMP", discoverPMP)
+}
+
+// autodisc represents a port mapping mechanism that is still being
+// auto-discovered. Calls to the Interface methods on this type will
+// wait until the discovery is done and then call the method on the
+// discovered mechanism.
+//
+// This type is useful because discovery can take a while but we
+// want return an Interface value from UPnP, PMP and Auto immediately.
+type autodisc struct {
+ what string
+ done <-chan Interface
+
+ mu sync.Mutex
+ found Interface
+}
+
+func startautodisc(what string, doit func() Interface) Interface {
+ // TODO: monitor network configuration and rerun doit when it changes.
+ done := make(chan Interface)
+ ad := &autodisc{what: what, done: done}
+ go func() { done <- doit(); close(done) }()
+ return ad
+}
+
+func (n *autodisc) AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
+ if err := n.wait(); err != nil {
+ return err
+ }
+ return n.found.AddMapping(protocol, extport, intport, name, lifetime)
+}
+
+func (n *autodisc) DeleteMapping(protocol string, extport, intport int) error {
+ if err := n.wait(); err != nil {
+ return err
+ }
+ return n.found.DeleteMapping(protocol, extport, intport)
+}
+
+func (n *autodisc) ExternalIP() (net.IP, error) {
+ if err := n.wait(); err != nil {
+ return nil, err
+ }
+ return n.found.ExternalIP()
+}
+
+func (n *autodisc) String() string {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+ if n.found == nil {
+ return n.what
+ } else {
+ return n.found.String()
+ }
+}
+
+func (n *autodisc) wait() error {
+ n.mu.Lock()
+ found := n.found
+ n.mu.Unlock()
+ if found != nil {
+ // already discovered
+ return nil
+ }
+ if found = <-n.done; found == nil {
+ return errors.New("no devices discovered")
+ }
+ n.mu.Lock()
+ n.found = found
+ n.mu.Unlock()
+ return nil
+}
diff --git a/p2p/nat/natpmp.go b/p2p/nat/natpmp.go
new file mode 100644
index 000000000..f249c6073
--- /dev/null
+++ b/p2p/nat/natpmp.go
@@ -0,0 +1,115 @@
+package nat
+
+import (
+ "fmt"
+ "net"
+ "strings"
+ "time"
+
+ "github.com/jackpal/go-nat-pmp"
+)
+
+// natPMPClient adapts the NAT-PMP protocol implementation so it conforms to
+// the common interface.
+type pmp struct {
+ gw net.IP
+ c *natpmp.Client
+}
+
+func (n *pmp) String() string {
+ return fmt.Sprintf("NAT-PMP(%v)", n.gw)
+}
+
+func (n *pmp) ExternalIP() (net.IP, error) {
+ response, err := n.c.GetExternalAddress()
+ if err != nil {
+ return nil, err
+ }
+ return response.ExternalIPAddress[:], nil
+}
+
+func (n *pmp) AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
+ if lifetime <= 0 {
+ return fmt.Errorf("lifetime must not be <= 0")
+ }
+ // Note order of port arguments is switched between our
+ // AddMapping and the client's AddPortMapping.
+ _, err := n.c.AddPortMapping(strings.ToLower(protocol), intport, extport, int(lifetime/time.Second))
+ return err
+}
+
+func (n *pmp) DeleteMapping(protocol string, extport, intport int) (err error) {
+ // To destroy a mapping, send an add-port with an internalPort of
+ // the internal port to destroy, an external port of zero and a
+ // time of zero.
+ _, err = n.c.AddPortMapping(strings.ToLower(protocol), intport, 0, 0)
+ return err
+}
+
+func discoverPMP() Interface {
+ // run external address lookups on all potential gateways
+ gws := potentialGateways()
+ found := make(chan *pmp, len(gws))
+ for i := range gws {
+ gw := gws[i]
+ go func() {
+ c := natpmp.NewClient(gw)
+ if _, err := c.GetExternalAddress(); err != nil {
+ found <- nil
+ } else {
+ found <- &pmp{gw, c}
+ }
+ }()
+ }
+ // return the one that responds first.
+ // discovery needs to be quick, so we stop caring about
+ // any responses after a very short timeout.
+ timeout := time.NewTimer(1 * time.Second)
+ defer timeout.Stop()
+ for _ = range gws {
+ select {
+ case c := <-found:
+ if c != nil {
+ return c
+ }
+ case <-timeout.C:
+ return nil
+ }
+ }
+ return nil
+}
+
+var (
+ // LAN IP ranges
+ _, lan10, _ = net.ParseCIDR("10.0.0.0/8")
+ _, lan176, _ = net.ParseCIDR("172.16.0.0/12")
+ _, lan192, _ = net.ParseCIDR("192.168.0.0/16")
+)
+
+// TODO: improve this. We currently assume that (on most networks)
+// the router is X.X.X.1 in a local LAN range.
+func potentialGateways() (gws []net.IP) {
+ ifaces, err := net.Interfaces()
+ if err != nil {
+ return nil
+ }
+ for _, iface := range ifaces {
+ ifaddrs, err := iface.Addrs()
+ if err != nil {
+ return gws
+ }
+ for _, addr := range ifaddrs {
+ switch x := addr.(type) {
+ case *net.IPNet:
+ if lan10.Contains(x.IP) || lan176.Contains(x.IP) || lan192.Contains(x.IP) {
+ ip := x.IP.Mask(x.Mask).To4()
+ if ip != nil {
+ ip[3] = ip[3] | 0x01
+ gws = append(gws, ip)
+ }
+ }
+ }
+ }
+ }
+ return gws
+}
diff --git a/p2p/nat/natupnp.go b/p2p/nat/natupnp.go
new file mode 100644
index 000000000..ef7765e8d
--- /dev/null
+++ b/p2p/nat/natupnp.go
@@ -0,0 +1,149 @@
+package nat
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "strings"
+ "time"
+
+ "github.com/huin/goupnp"
+ "github.com/huin/goupnp/dcps/internetgateway1"
+ "github.com/huin/goupnp/dcps/internetgateway2"
+)
+
+type upnp struct {
+ dev *goupnp.RootDevice
+ service string
+ client upnpClient
+}
+
+type upnpClient interface {
+ GetExternalIPAddress() (string, error)
+ AddPortMapping(string, uint16, string, uint16, string, bool, string, uint32) error
+ DeletePortMapping(string, uint16, string) error
+ GetNATRSIPStatus() (sip bool, nat bool, err error)
+}
+
+func (n *upnp) ExternalIP() (addr net.IP, err error) {
+ ipString, err := n.client.GetExternalIPAddress()
+ if err != nil {
+ return nil, err
+ }
+ ip := net.ParseIP(ipString)
+ if ip == nil {
+ return nil, errors.New("bad IP in response")
+ }
+ return ip, nil
+}
+
+func (n *upnp) AddMapping(protocol string, extport, intport int, desc string, lifetime time.Duration) error {
+ ip, err := n.internalAddress()
+ if err != nil {
+ return nil
+ }
+ protocol = strings.ToUpper(protocol)
+ lifetimeS := uint32(lifetime / time.Second)
+ return n.client.AddPortMapping("", uint16(extport), protocol, uint16(intport), ip.String(), true, desc, lifetimeS)
+}
+
+func (n *upnp) internalAddress() (net.IP, error) {
+ devaddr, err := net.ResolveUDPAddr("udp4", n.dev.URLBase.Host)
+ if err != nil {
+ return nil, err
+ }
+ ifaces, err := net.Interfaces()
+ if err != nil {
+ return nil, err
+ }
+ for _, iface := range ifaces {
+ addrs, err := iface.Addrs()
+ if err != nil {
+ return nil, err
+ }
+ for _, addr := range addrs {
+ switch x := addr.(type) {
+ case *net.IPNet:
+ if x.Contains(devaddr.IP) {
+ return x.IP, nil
+ }
+ }
+ }
+ }
+ return nil, fmt.Errorf("could not find local address in same net as %v", devaddr)
+}
+
+func (n *upnp) DeleteMapping(protocol string, extport, intport int) error {
+ return n.client.DeletePortMapping("", uint16(extport), strings.ToUpper(protocol))
+}
+
+func (n *upnp) String() string {
+ return "UPNP " + n.service
+}
+
+// discoverUPnP searches for Internet Gateway Devices
+// and returns the first one it can find on the local network.
+func discoverUPnP() Interface {
+ found := make(chan *upnp, 2)
+ // IGDv1
+ go discover(found, internetgateway1.URN_WANConnectionDevice_1, func(dev *goupnp.RootDevice, sc goupnp.ServiceClient) *upnp {
+ switch sc.Service.ServiceType {
+ case internetgateway1.URN_WANIPConnection_1:
+ return &upnp{dev, "IGDv1-IP1", &internetgateway1.WANIPConnection1{sc}}
+ case internetgateway1.URN_WANPPPConnection_1:
+ return &upnp{dev, "IGDv1-PPP1", &internetgateway1.WANPPPConnection1{sc}}
+ }
+ return nil
+ })
+ // IGDv2
+ go discover(found, internetgateway2.URN_WANConnectionDevice_2, func(dev *goupnp.RootDevice, sc goupnp.ServiceClient) *upnp {
+ switch sc.Service.ServiceType {
+ case internetgateway2.URN_WANIPConnection_1:
+ return &upnp{dev, "IGDv2-IP1", &internetgateway2.WANIPConnection1{sc}}
+ case internetgateway2.URN_WANIPConnection_2:
+ return &upnp{dev, "IGDv2-IP2", &internetgateway2.WANIPConnection2{sc}}
+ case internetgateway2.URN_WANPPPConnection_1:
+ return &upnp{dev, "IGDv2-PPP1", &internetgateway2.WANPPPConnection1{sc}}
+ }
+ return nil
+ })
+ for i := 0; i < cap(found); i++ {
+ if c := <-found; c != nil {
+ return c
+ }
+ }
+ return nil
+}
+
+func discover(out chan<- *upnp, target string, matcher func(*goupnp.RootDevice, goupnp.ServiceClient) *upnp) {
+ devs, err := goupnp.DiscoverDevices(target)
+ if err != nil {
+ return
+ }
+ found := false
+ for i := 0; i < len(devs) && !found; i++ {
+ if devs[i].Root == nil {
+ continue
+ }
+ devs[i].Root.Device.VisitServices(func(service *goupnp.Service) {
+ if found {
+ return
+ }
+ // check for a matching IGD service
+ sc := goupnp.ServiceClient{service.NewSOAPClient(), devs[i].Root, service}
+ upnp := matcher(devs[i].Root, sc)
+ if upnp == nil {
+ return
+ }
+ // check whether port mapping is enabled
+ if _, nat, err := upnp.client.GetNATRSIPStatus(); err != nil || !nat {
+ return
+ }
+ out <- upnp
+ found = true
+ })
+ }
+ if !found {
+ out <- nil
+ }
+}
diff --git a/p2p/peer.go b/p2p/peer.go
new file mode 100644
index 000000000..c2c83abfc
--- /dev/null
+++ b/p2p/peer.go
@@ -0,0 +1,312 @@
+package p2p
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net"
+ "sort"
+ "sync"
+ "time"
+
+ "github.com/ethereum/go-ethereum/logger"
+ "github.com/ethereum/go-ethereum/p2p/discover"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+const (
+ baseProtocolVersion = 3
+ baseProtocolLength = uint64(16)
+ baseProtocolMaxMsgSize = 10 * 1024 * 1024
+
+ pingInterval = 15 * time.Second
+ disconnectGracePeriod = 2 * time.Second
+)
+
+const (
+ // devp2p message codes
+ handshakeMsg = 0x00
+ discMsg = 0x01
+ pingMsg = 0x02
+ pongMsg = 0x03
+ getPeersMsg = 0x04
+ peersMsg = 0x05
+)
+
+// Peer represents a connected remote node.
+type Peer struct {
+ // Peers have all the log methods.
+ // Use them to display messages related to the peer.
+ *logger.Logger
+
+ conn net.Conn
+ rw *conn
+ running map[string]*protoRW
+
+ protoWG sync.WaitGroup
+ protoErr chan error
+ closed chan struct{}
+ disc chan DiscReason
+}
+
+// NewPeer returns a peer for testing purposes.
+func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
+ pipe, _ := net.Pipe()
+ msgpipe, _ := MsgPipe()
+ conn := &conn{msgpipe, &protoHandshake{ID: id, Name: name, Caps: caps}}
+ peer := newPeer(pipe, conn, nil)
+ close(peer.closed) // ensures Disconnect doesn't block
+ return peer
+}
+
+// ID returns the node's public key.
+func (p *Peer) ID() discover.NodeID {
+ return p.rw.ID
+}
+
+// Name returns the node name that the remote node advertised.
+func (p *Peer) Name() string {
+ return p.rw.Name
+}
+
+// Caps returns the capabilities (supported subprotocols) of the remote peer.
+func (p *Peer) Caps() []Cap {
+ // TODO: maybe return copy
+ return p.rw.Caps
+}
+
+// RemoteAddr returns the remote address of the network connection.
+func (p *Peer) RemoteAddr() net.Addr {
+ return p.conn.RemoteAddr()
+}
+
+// LocalAddr returns the local address of the network connection.
+func (p *Peer) LocalAddr() net.Addr {
+ return p.conn.LocalAddr()
+}
+
+// Disconnect terminates the peer connection with the given reason.
+// It returns immediately and does not wait until the connection is closed.
+func (p *Peer) Disconnect(reason DiscReason) {
+ select {
+ case p.disc <- reason:
+ case <-p.closed:
+ }
+}
+
+// String implements fmt.Stringer.
+func (p *Peer) String() string {
+ return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr())
+}
+
+func newPeer(fd net.Conn, conn *conn, protocols []Protocol) *Peer {
+ logtag := fmt.Sprintf("Peer %.8x %v", conn.ID[:], fd.RemoteAddr())
+ p := &Peer{
+ Logger: logger.NewLogger(logtag),
+ conn: fd,
+ rw: conn,
+ running: matchProtocols(protocols, conn.Caps, conn),
+ disc: make(chan DiscReason),
+ protoErr: make(chan error),
+ closed: make(chan struct{}),
+ }
+ return p
+}
+
+func (p *Peer) run() DiscReason {
+ var readErr = make(chan error, 1)
+ defer p.closeProtocols()
+ defer close(p.closed)
+
+ p.startProtocols()
+ go func() { readErr <- p.readLoop() }()
+
+ ping := time.NewTicker(pingInterval)
+ defer ping.Stop()
+
+ // Wait for an error or disconnect.
+ var reason DiscReason
+loop:
+ for {
+ select {
+ case <-ping.C:
+ go func() {
+ if err := EncodeMsg(p.rw, pingMsg, nil); err != nil {
+ p.protoErr <- err
+ return
+ }
+ }()
+ case err := <-readErr:
+ // We rely on protocols to abort if there is a write error. It
+ // might be more robust to handle them here as well.
+ p.DebugDetailf("Read error: %v\n", err)
+ p.conn.Close()
+ return DiscNetworkError
+ case err := <-p.protoErr:
+ reason = discReasonForError(err)
+ break loop
+ case reason = <-p.disc:
+ break loop
+ }
+ }
+ p.politeDisconnect(reason)
+
+ // Wait for readLoop. It will end because conn is now closed.
+ <-readErr
+ p.Debugf("Disconnected: %v\n", reason)
+ return reason
+}
+
+func (p *Peer) politeDisconnect(reason DiscReason) {
+ done := make(chan struct{})
+ go func() {
+ EncodeMsg(p.rw, discMsg, uint(reason))
+ // Wait for the other side to close the connection.
+ // Discard any data that they send until then.
+ io.Copy(ioutil.Discard, p.conn)
+ close(done)
+ }()
+ select {
+ case <-done:
+ case <-time.After(disconnectGracePeriod):
+ }
+ p.conn.Close()
+}
+
+func (p *Peer) readLoop() error {
+ for {
+ p.conn.SetDeadline(time.Now().Add(frameReadTimeout))
+ msg, err := p.rw.ReadMsg()
+ if err != nil {
+ return err
+ }
+ if err = p.handle(msg); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (p *Peer) handle(msg Msg) error {
+ switch {
+ case msg.Code == pingMsg:
+ msg.Discard()
+ go EncodeMsg(p.rw, pongMsg)
+ case msg.Code == discMsg:
+ var reason [1]DiscReason
+ // no need to discard or for error checking, we'll close the
+ // connection after this.
+ rlp.Decode(msg.Payload, &reason)
+ p.Disconnect(DiscRequested)
+ return discRequestedError(reason[0])
+ case msg.Code < baseProtocolLength:
+ // ignore other base protocol messages
+ return msg.Discard()
+ default:
+ // it's a subprotocol message
+ proto, err := p.getProto(msg.Code)
+ if err != nil {
+ return fmt.Errorf("msg code out of range: %v", msg.Code)
+ }
+ proto.in <- msg
+ }
+ return nil
+}
+
+// matchProtocols creates structures for matching named subprotocols.
+func matchProtocols(protocols []Protocol, caps []Cap, rw MsgReadWriter) map[string]*protoRW {
+ sort.Sort(capsByName(caps))
+ offset := baseProtocolLength
+ result := make(map[string]*protoRW)
+outer:
+ for _, cap := range caps {
+ for _, proto := range protocols {
+ if proto.Name == cap.Name && proto.Version == cap.Version && result[cap.Name] == nil {
+ result[cap.Name] = &protoRW{Protocol: proto, offset: offset, in: make(chan Msg), w: rw}
+ offset += proto.Length
+ continue outer
+ }
+ }
+ }
+ return result
+}
+
+func (p *Peer) startProtocols() {
+ for _, proto := range p.running {
+ proto := proto
+ p.DebugDetailf("Starting protocol %s/%d\n", proto.Name, proto.Version)
+ p.protoWG.Add(1)
+ go func() {
+ err := proto.Run(p, proto)
+ if err == nil {
+ p.DebugDetailf("Protocol %s/%d returned\n", proto.Name, proto.Version)
+ err = errors.New("protocol returned")
+ } else {
+ p.DebugDetailf("Protocol %s/%d error: %v\n", proto.Name, proto.Version, err)
+ }
+ select {
+ case p.protoErr <- err:
+ case <-p.closed:
+ }
+ p.protoWG.Done()
+ }()
+ }
+}
+
+// getProto finds the protocol responsible for handling
+// the given message code.
+func (p *Peer) getProto(code uint64) (*protoRW, error) {
+ for _, proto := range p.running {
+ if code >= proto.offset && code < proto.offset+proto.Length {
+ return proto, nil
+ }
+ }
+ return nil, newPeerError(errInvalidMsgCode, "%d", code)
+}
+
+func (p *Peer) closeProtocols() {
+ for _, p := range p.running {
+ close(p.in)
+ }
+ p.protoWG.Wait()
+}
+
+// writeProtoMsg sends the given message on behalf of the given named protocol.
+// this exists because of Server.Broadcast.
+func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
+ proto, ok := p.running[protoName]
+ if !ok {
+ return fmt.Errorf("protocol %s not handled by peer", protoName)
+ }
+ if msg.Code >= proto.Length {
+ return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
+ }
+ msg.Code += proto.offset
+ return p.rw.WriteMsg(msg)
+}
+
+type protoRW struct {
+ Protocol
+
+ in chan Msg
+ offset uint64
+ w MsgWriter
+}
+
+func (rw *protoRW) WriteMsg(msg Msg) error {
+ if msg.Code >= rw.Length {
+ return newPeerError(errInvalidMsgCode, "not handled")
+ }
+ msg.Code += rw.offset
+ return rw.w.WriteMsg(msg)
+}
+
+func (rw *protoRW) ReadMsg() (Msg, error) {
+ msg, ok := <-rw.in
+ if !ok {
+ return msg, io.EOF
+ }
+ msg.Code -= rw.offset
+ return msg, nil
+}
diff --git a/p2p/peer_error.go b/p2p/peer_error.go
new file mode 100644
index 000000000..0ff4f4b43
--- /dev/null
+++ b/p2p/peer_error.go
@@ -0,0 +1,131 @@
+package p2p
+
+import (
+ "fmt"
+)
+
+const (
+ errMagicTokenMismatch = iota
+ errRead
+ errWrite
+ errMisc
+ errInvalidMsgCode
+ errInvalidMsg
+ errP2PVersionMismatch
+ errPubkeyInvalid
+ errPubkeyForbidden
+ errProtocolBreach
+ errPingTimeout
+ errInvalidNetworkId
+ errInvalidProtocolVersion
+)
+
+var errorToString = map[int]string{
+ errMagicTokenMismatch: "magic token mismatch",
+ errRead: "read error",
+ errWrite: "write error",
+ errMisc: "misc error",
+ errInvalidMsgCode: "invalid message code",
+ errInvalidMsg: "invalid message",
+ errP2PVersionMismatch: "P2P Version Mismatch",
+ errPubkeyInvalid: "public key invalid",
+ errPubkeyForbidden: "public key forbidden",
+ errProtocolBreach: "protocol Breach",
+ errPingTimeout: "ping timeout",
+ errInvalidNetworkId: "invalid network id",
+ errInvalidProtocolVersion: "invalid protocol version",
+}
+
+type peerError struct {
+ Code int
+ message string
+}
+
+func newPeerError(code int, format string, v ...interface{}) *peerError {
+ desc, ok := errorToString[code]
+ if !ok {
+ panic("invalid error code")
+ }
+ err := &peerError{code, desc}
+ if format != "" {
+ err.message += ": " + fmt.Sprintf(format, v...)
+ }
+ return err
+}
+
+func (self *peerError) Error() string {
+ return self.message
+}
+
+type DiscReason byte
+
+const (
+ DiscRequested DiscReason = iota
+ DiscNetworkError
+ DiscProtocolError
+ DiscUselessPeer
+ DiscTooManyPeers
+ DiscAlreadyConnected
+ DiscIncompatibleVersion
+ DiscInvalidIdentity
+ DiscQuitting
+ DiscUnexpectedIdentity
+ DiscSelf
+ DiscReadTimeout
+ DiscSubprotocolError
+)
+
+var discReasonToString = [...]string{
+ DiscRequested: "Disconnect requested",
+ DiscNetworkError: "Network error",
+ DiscProtocolError: "Breach of protocol",
+ DiscUselessPeer: "Useless peer",
+ DiscTooManyPeers: "Too many peers",
+ DiscAlreadyConnected: "Already connected",
+ DiscIncompatibleVersion: "Incompatible P2P protocol version",
+ DiscInvalidIdentity: "Invalid node identity",
+ DiscQuitting: "Client quitting",
+ DiscUnexpectedIdentity: "Unexpected identity",
+ DiscSelf: "Connected to self",
+ DiscReadTimeout: "Read timeout",
+ DiscSubprotocolError: "Subprotocol error",
+}
+
+func (d DiscReason) String() string {
+ if len(discReasonToString) < int(d) {
+ return fmt.Sprintf("Unknown Reason(%d)", d)
+ }
+ return discReasonToString[d]
+}
+
+type discRequestedError DiscReason
+
+func (err discRequestedError) Error() string {
+ return fmt.Sprintf("disconnect requested: %v", DiscReason(err))
+}
+
+func discReasonForError(err error) DiscReason {
+ if reason, ok := err.(discRequestedError); ok {
+ return DiscReason(reason)
+ }
+ peerError, ok := err.(*peerError)
+ if !ok {
+ return DiscSubprotocolError
+ }
+ switch peerError.Code {
+ case errP2PVersionMismatch:
+ return DiscIncompatibleVersion
+ case errPubkeyInvalid:
+ return DiscInvalidIdentity
+ case errPubkeyForbidden:
+ return DiscUselessPeer
+ case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach:
+ return DiscProtocolError
+ case errPingTimeout:
+ return DiscReadTimeout
+ case errRead, errWrite:
+ return DiscNetworkError
+ default:
+ return DiscSubprotocolError
+ }
+}
diff --git a/p2p/peer_test.go b/p2p/peer_test.go
new file mode 100644
index 000000000..cc9f1f0cd
--- /dev/null
+++ b/p2p/peer_test.go
@@ -0,0 +1,226 @@
+package p2p
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net"
+ "reflect"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+var discard = Protocol{
+ Name: "discard",
+ Length: 1,
+ Run: func(p *Peer, rw MsgReadWriter) error {
+ for {
+ msg, err := rw.ReadMsg()
+ if err != nil {
+ return err
+ }
+ fmt.Printf("discarding %d\n", msg.Code)
+ if err = msg.Discard(); err != nil {
+ return err
+ }
+ }
+ },
+}
+
+func testPeer(protos []Protocol) (io.Closer, *conn, *Peer, <-chan DiscReason) {
+ fd1, _ := net.Pipe()
+ hs1 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
+ hs2 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
+ for _, p := range protos {
+ hs1.Caps = append(hs1.Caps, p.cap())
+ hs2.Caps = append(hs2.Caps, p.cap())
+ }
+
+ p1, p2 := MsgPipe()
+ peer := newPeer(fd1, &conn{p1, hs1}, protos)
+ errc := make(chan DiscReason, 1)
+ go func() { errc <- peer.run() }()
+
+ return p1, &conn{p2, hs2}, peer, errc
+}
+
+func TestPeerProtoReadMsg(t *testing.T) {
+ defer testlog(t).detach()
+
+ done := make(chan struct{})
+ proto := Protocol{
+ Name: "a",
+ Length: 5,
+ Run: func(peer *Peer, rw MsgReadWriter) error {
+ if err := expectMsg(rw, 2, []uint{1}); err != nil {
+ t.Error(err)
+ }
+ if err := expectMsg(rw, 3, []uint{2}); err != nil {
+ t.Error(err)
+ }
+ if err := expectMsg(rw, 4, []uint{3}); err != nil {
+ t.Error(err)
+ }
+ close(done)
+ return nil
+ },
+ }
+
+ closer, rw, _, errc := testPeer([]Protocol{proto})
+ defer closer.Close()
+
+ EncodeMsg(rw, baseProtocolLength+2, 1)
+ EncodeMsg(rw, baseProtocolLength+3, 2)
+ EncodeMsg(rw, baseProtocolLength+4, 3)
+
+ select {
+ case <-done:
+ case err := <-errc:
+ t.Errorf("peer returned: %v", err)
+ case <-time.After(2 * time.Second):
+ t.Errorf("receive timeout")
+ }
+}
+
+func TestPeerProtoEncodeMsg(t *testing.T) {
+ defer testlog(t).detach()
+
+ proto := Protocol{
+ Name: "a",
+ Length: 2,
+ Run: func(peer *Peer, rw MsgReadWriter) error {
+ if err := EncodeMsg(rw, 2); err == nil {
+ t.Error("expected error for out-of-range msg code, got nil")
+ }
+ if err := EncodeMsg(rw, 1, "foo", "bar"); err != nil {
+ t.Errorf("write error: %v", err)
+ }
+ return nil
+ },
+ }
+ closer, rw, _, _ := testPeer([]Protocol{proto})
+ defer closer.Close()
+
+ if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
+ t.Error(err)
+ }
+}
+
+func TestPeerWriteForBroadcast(t *testing.T) {
+ defer testlog(t).detach()
+
+ closer, rw, peer, peerErr := testPeer([]Protocol{discard})
+ defer closer.Close()
+
+ // test write errors
+ if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
+ t.Errorf("expected error for unknown protocol, got nil")
+ }
+ if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil {
+ t.Errorf("expected error for out-of-range msg code, got nil")
+ } else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
+ t.Errorf("wrong error for out-of-range msg code, got %#v", err)
+ }
+
+ // setup for reading the message on the other end
+ read := make(chan struct{})
+ go func() {
+ if err := expectMsg(rw, 16, nil); err != nil {
+ t.Error(err)
+ }
+ close(read)
+ }()
+
+ // test successful write
+ if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
+ t.Errorf("expect no error for known protocol: %v", err)
+ }
+ select {
+ case <-read:
+ case err := <-peerErr:
+ t.Fatalf("peer stopped: %v", err)
+ }
+}
+
+func TestPeerPing(t *testing.T) {
+ defer testlog(t).detach()
+
+ closer, rw, _, _ := testPeer(nil)
+ defer closer.Close()
+ if err := EncodeMsg(rw, pingMsg); err != nil {
+ t.Fatal(err)
+ }
+ if err := expectMsg(rw, pongMsg, nil); err != nil {
+ t.Error(err)
+ }
+}
+
+func TestPeerDisconnect(t *testing.T) {
+ defer testlog(t).detach()
+
+ closer, rw, _, disc := testPeer(nil)
+ defer closer.Close()
+ if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
+ t.Fatal(err)
+ }
+ if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil {
+ t.Error(err)
+ }
+ closer.Close() // make test end faster
+ if reason := <-disc; reason != DiscRequested {
+ t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested)
+ }
+}
+
+func TestNewPeer(t *testing.T) {
+ name := "nodename"
+ caps := []Cap{{"foo", 2}, {"bar", 3}}
+ id := randomID()
+ p := NewPeer(id, name, caps)
+ if p.ID() != id {
+ t.Errorf("ID mismatch: got %v, expected %v", p.ID(), id)
+ }
+ if p.Name() != name {
+ t.Errorf("Name mismatch: got %v, expected %v", p.Name(), name)
+ }
+ if !reflect.DeepEqual(p.Caps(), caps) {
+ t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps)
+ }
+
+ p.Disconnect(DiscAlreadyConnected) // Should not hang
+}
+
+// expectMsg reads a message from r and verifies that its
+// code and encoded RLP content match the provided values.
+// If content is nil, the payload is discarded and not verified.
+func expectMsg(r MsgReader, code uint64, content interface{}) error {
+ msg, err := r.ReadMsg()
+ if err != nil {
+ return err
+ }
+ if msg.Code != code {
+ return fmt.Errorf("message code mismatch: got %d, expected %d", msg.Code, code)
+ }
+ if content == nil {
+ return msg.Discard()
+ } else {
+ contentEnc, err := rlp.EncodeToBytes(content)
+ if err != nil {
+ panic("content encode error: " + err.Error())
+ }
+ if int(msg.Size) != len(contentEnc) {
+ return fmt.Errorf("message size mismatch: got %d, want %d", msg.Size, len(contentEnc))
+ }
+ actualContent, err := ioutil.ReadAll(msg.Payload)
+ if err != nil {
+ return err
+ }
+ if !bytes.Equal(actualContent, contentEnc) {
+ return fmt.Errorf("message payload mismatch:\ngot: %x\nwant: %x", actualContent, contentEnc)
+ }
+ }
+ return nil
+}
diff --git a/p2p/protocol.go b/p2p/protocol.go
new file mode 100644
index 000000000..5fa395eda
--- /dev/null
+++ b/p2p/protocol.go
@@ -0,0 +1,50 @@
+package p2p
+
+import "fmt"
+
+// Protocol represents a P2P subprotocol implementation.
+type Protocol struct {
+ // Name should contain the official protocol name,
+ // often a three-letter word.
+ Name string
+
+ // Version should contain the version number of the protocol.
+ Version uint
+
+ // Length should contain the number of message codes used
+ // by the protocol.
+ Length uint64
+
+ // Run is called in a new groutine when the protocol has been
+ // negotiated with a peer. It should read and write messages from
+ // rw. The Payload for each message must be fully consumed.
+ //
+ // The peer connection is closed when Start returns. It should return
+ // any protocol-level error (such as an I/O error) that is
+ // encountered.
+ Run func(peer *Peer, rw MsgReadWriter) error
+}
+
+func (p Protocol) cap() Cap {
+ return Cap{p.Name, p.Version}
+}
+
+// Cap is the structure of a peer capability.
+type Cap struct {
+ Name string
+ Version uint
+}
+
+func (cap Cap) RlpData() interface{} {
+ return []interface{}{cap.Name, cap.Version}
+}
+
+func (cap Cap) String() string {
+ return fmt.Sprintf("%s/%d", cap.Name, cap.Version)
+}
+
+type capsByName []Cap
+
+func (cs capsByName) Len() int { return len(cs) }
+func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name }
+func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] }
diff --git a/p2p/rlpx.go b/p2p/rlpx.go
new file mode 100644
index 000000000..6b533e275
--- /dev/null
+++ b/p2p/rlpx.go
@@ -0,0 +1,174 @@
+package p2p
+
+import (
+ "bytes"
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/hmac"
+ "errors"
+ "hash"
+ "io"
+
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+var (
+ // this is used in place of actual frame header data.
+ // TODO: replace this when Msg contains the protocol type code.
+ zeroHeader = []byte{0xC2, 0x80, 0x80}
+
+ // sixteen zero bytes
+ zero16 = make([]byte, 16)
+
+ maxUint24 = ^uint32(0) >> 8
+)
+
+// rlpxFrameRW implements a simplified version of RLPx framing.
+// chunked messages are not supported and all headers are equal to
+// zeroHeader.
+//
+// rlpxFrameRW is not safe for concurrent use from multiple goroutines.
+type rlpxFrameRW struct {
+ conn io.ReadWriter
+ enc cipher.Stream
+ dec cipher.Stream
+
+ macCipher cipher.Block
+ egressMAC hash.Hash
+ ingressMAC hash.Hash
+}
+
+func newRlpxFrameRW(conn io.ReadWriter, s secrets) *rlpxFrameRW {
+ macc, err := aes.NewCipher(s.MAC)
+ if err != nil {
+ panic("invalid MAC secret: " + err.Error())
+ }
+ encc, err := aes.NewCipher(s.AES)
+ if err != nil {
+ panic("invalid AES secret: " + err.Error())
+ }
+ // we use an all-zeroes IV for AES because the key used
+ // for encryption is ephemeral.
+ iv := make([]byte, encc.BlockSize())
+ return &rlpxFrameRW{
+ conn: conn,
+ enc: cipher.NewCTR(encc, iv),
+ dec: cipher.NewCTR(encc, iv),
+ macCipher: macc,
+ egressMAC: s.EgressMAC,
+ ingressMAC: s.IngressMAC,
+ }
+}
+
+func (rw *rlpxFrameRW) WriteMsg(msg Msg) error {
+ ptype, _ := rlp.EncodeToBytes(msg.Code)
+
+ // write header
+ headbuf := make([]byte, 32)
+ fsize := uint32(len(ptype)) + msg.Size
+ if fsize > maxUint24 {
+ return errors.New("message size overflows uint24")
+ }
+ putInt24(fsize, headbuf) // TODO: check overflow
+ copy(headbuf[3:], zeroHeader)
+ rw.enc.XORKeyStream(headbuf[:16], headbuf[:16]) // first half is now encrypted
+
+ // write header MAC
+ copy(headbuf[16:], updateMAC(rw.egressMAC, rw.macCipher, headbuf[:16]))
+ if _, err := rw.conn.Write(headbuf); err != nil {
+ return err
+ }
+
+ // write encrypted frame, updating the egress MAC hash with
+ // the data written to conn.
+ tee := cipher.StreamWriter{S: rw.enc, W: io.MultiWriter(rw.conn, rw.egressMAC)}
+ if _, err := tee.Write(ptype); err != nil {
+ return err
+ }
+ if _, err := io.Copy(tee, msg.Payload); err != nil {
+ return err
+ }
+ if padding := fsize % 16; padding > 0 {
+ if _, err := tee.Write(zero16[:16-padding]); err != nil {
+ return err
+ }
+ }
+
+ // write frame MAC. egress MAC hash is up to date because
+ // frame content was written to it as well.
+ fmacseed := rw.egressMAC.Sum(nil)
+ mac := updateMAC(rw.egressMAC, rw.macCipher, fmacseed)
+ _, err := rw.conn.Write(mac)
+ return err
+}
+
+func (rw *rlpxFrameRW) ReadMsg() (msg Msg, err error) {
+ // read the header
+ headbuf := make([]byte, 32)
+ if _, err := io.ReadFull(rw.conn, headbuf); err != nil {
+ return msg, err
+ }
+ // verify header mac
+ shouldMAC := updateMAC(rw.ingressMAC, rw.macCipher, headbuf[:16])
+ if !hmac.Equal(shouldMAC, headbuf[16:]) {
+ return msg, errors.New("bad header MAC")
+ }
+ rw.dec.XORKeyStream(headbuf[:16], headbuf[:16]) // first half is now decrypted
+ fsize := readInt24(headbuf)
+ // ignore protocol type for now
+
+ // read the frame content
+ var rsize = fsize // frame size rounded up to 16 byte boundary
+ if padding := fsize % 16; padding > 0 {
+ rsize += 16 - padding
+ }
+ framebuf := make([]byte, rsize)
+ if _, err := io.ReadFull(rw.conn, framebuf); err != nil {
+ return msg, err
+ }
+
+ // read and validate frame MAC. we can re-use headbuf for that.
+ rw.ingressMAC.Write(framebuf)
+ fmacseed := rw.ingressMAC.Sum(nil)
+ if _, err := io.ReadFull(rw.conn, headbuf[:16]); err != nil {
+ return msg, err
+ }
+ shouldMAC = updateMAC(rw.ingressMAC, rw.macCipher, fmacseed)
+ if !hmac.Equal(shouldMAC, headbuf[:16]) {
+ return msg, errors.New("bad frame MAC")
+ }
+
+ // decrypt frame content
+ rw.dec.XORKeyStream(framebuf, framebuf)
+
+ // decode message code
+ content := bytes.NewReader(framebuf[:fsize])
+ if err := rlp.Decode(content, &msg.Code); err != nil {
+ return msg, err
+ }
+ msg.Size = uint32(content.Len())
+ msg.Payload = content
+ return msg, nil
+}
+
+// updateMAC reseeds the given hash with encrypted seed.
+// it returns the first 16 bytes of the hash sum after seeding.
+func updateMAC(mac hash.Hash, block cipher.Block, seed []byte) []byte {
+ aesbuf := make([]byte, aes.BlockSize)
+ block.Encrypt(aesbuf, mac.Sum(nil))
+ for i := range aesbuf {
+ aesbuf[i] ^= seed[i]
+ }
+ mac.Write(aesbuf)
+ return mac.Sum(nil)[:16]
+}
+
+func readInt24(b []byte) uint32 {
+ return uint32(b[2]) | uint32(b[1])<<8 | uint32(b[0])<<16
+}
+
+func putInt24(v uint32, b []byte) {
+ b[0] = byte(v >> 16)
+ b[1] = byte(v >> 8)
+ b[2] = byte(v)
+}
diff --git a/p2p/rlpx_test.go b/p2p/rlpx_test.go
new file mode 100644
index 000000000..49354c7ed
--- /dev/null
+++ b/p2p/rlpx_test.go
@@ -0,0 +1,124 @@
+package p2p
+
+import (
+ "bytes"
+ "crypto/rand"
+ "io/ioutil"
+ "strings"
+ "testing"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/crypto/sha3"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+func TestRlpxFrameFake(t *testing.T) {
+ buf := new(bytes.Buffer)
+ hash := fakeHash([]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})
+ rw := newRlpxFrameRW(buf, secrets{
+ AES: crypto.Sha3(),
+ MAC: crypto.Sha3(),
+ IngressMAC: hash,
+ EgressMAC: hash,
+ })
+
+ golden := unhex(`
+00828ddae471818bb0bfa6b551d1cb42
+01010101010101010101010101010101
+ba628a4ba590cb43f7848f41c4382885
+01010101010101010101010101010101
+`)
+
+ // Check WriteMsg. This puts a message into the buffer.
+ if err := EncodeMsg(rw, 8, 1, 2, 3, 4); err != nil {
+ t.Fatalf("WriteMsg error: %v", err)
+ }
+ written := buf.Bytes()
+ if !bytes.Equal(written, golden) {
+ t.Fatalf("output mismatch:\n got: %x\n want: %x", written, golden)
+ }
+
+ // Check ReadMsg. It reads the message encoded by WriteMsg, which
+ // is equivalent to the golden message above.
+ msg, err := rw.ReadMsg()
+ if err != nil {
+ t.Fatalf("ReadMsg error: %v", err)
+ }
+ if msg.Size != 5 {
+ t.Errorf("msg size mismatch: got %d, want %d", msg.Size, 5)
+ }
+ if msg.Code != 8 {
+ t.Errorf("msg code mismatch: got %d, want %d", msg.Code, 8)
+ }
+ payload, _ := ioutil.ReadAll(msg.Payload)
+ wantPayload := unhex("C401020304")
+ if !bytes.Equal(payload, wantPayload) {
+ t.Errorf("msg payload mismatch:\ngot %x\nwant %x", payload, wantPayload)
+ }
+}
+
+type fakeHash []byte
+
+func (fakeHash) Write(p []byte) (int, error) { return len(p), nil }
+func (fakeHash) Reset() {}
+func (fakeHash) BlockSize() int { return 0 }
+
+func (h fakeHash) Size() int { return len(h) }
+func (h fakeHash) Sum(b []byte) []byte { return append(b, h...) }
+
+func TestRlpxFrameRW(t *testing.T) {
+ var (
+ aesSecret = make([]byte, 16)
+ macSecret = make([]byte, 16)
+ egressMACinit = make([]byte, 32)
+ ingressMACinit = make([]byte, 32)
+ )
+ for _, s := range [][]byte{aesSecret, macSecret, egressMACinit, ingressMACinit} {
+ rand.Read(s)
+ }
+ conn := new(bytes.Buffer)
+
+ s1 := secrets{
+ AES: aesSecret,
+ MAC: macSecret,
+ EgressMAC: sha3.NewKeccak256(),
+ IngressMAC: sha3.NewKeccak256(),
+ }
+ s1.EgressMAC.Write(egressMACinit)
+ s1.IngressMAC.Write(ingressMACinit)
+ rw1 := newRlpxFrameRW(conn, s1)
+
+ s2 := secrets{
+ AES: aesSecret,
+ MAC: macSecret,
+ EgressMAC: sha3.NewKeccak256(),
+ IngressMAC: sha3.NewKeccak256(),
+ }
+ s2.EgressMAC.Write(ingressMACinit)
+ s2.IngressMAC.Write(egressMACinit)
+ rw2 := newRlpxFrameRW(conn, s2)
+
+ // send some messages
+ for i := 0; i < 10; i++ {
+ // write message into conn buffer
+ wmsg := []interface{}{"foo", "bar", strings.Repeat("test", i)}
+ err := EncodeMsg(rw1, uint64(i), wmsg...)
+ if err != nil {
+ t.Fatalf("WriteMsg error (i=%d): %v", i, err)
+ }
+
+ // read message that rw1 just wrote
+ msg, err := rw2.ReadMsg()
+ if err != nil {
+ t.Fatalf("ReadMsg error (i=%d): %v", i, err)
+ }
+ if msg.Code != uint64(i) {
+ t.Fatalf("msg code mismatch: got %d, want %d", msg.Code, i)
+ }
+ payload, _ := ioutil.ReadAll(msg.Payload)
+ wantPayload, _ := rlp.EncodeToBytes(wmsg)
+ if !bytes.Equal(payload, wantPayload) {
+ t.Fatalf("msg payload mismatch:\ngot %x\nwant %x", payload, wantPayload)
+ }
+ }
+}
diff --git a/p2p/server.go b/p2p/server.go
new file mode 100644
index 000000000..9c148ba39
--- /dev/null
+++ b/p2p/server.go
@@ -0,0 +1,421 @@
+package p2p
+
+import (
+ "bytes"
+ "crypto/ecdsa"
+ "errors"
+ "fmt"
+ "net"
+ "sync"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/logger"
+ "github.com/ethereum/go-ethereum/p2p/discover"
+ "github.com/ethereum/go-ethereum/p2p/nat"
+)
+
+const (
+ defaultDialTimeout = 10 * time.Second
+ refreshPeersInterval = 30 * time.Second
+
+ // total timeout for encryption handshake and protocol
+ // handshake in both directions.
+ handshakeTimeout = 5 * time.Second
+ // maximum time allowed for reading a complete message.
+ // this is effectively the amount of time a connection can be idle.
+ frameReadTimeout = 1 * time.Minute
+ // maximum amount of time allowed for writing a complete message.
+ frameWriteTimeout = 5 * time.Second
+)
+
+var srvlog = logger.NewLogger("P2P Server")
+var srvjslog = logger.NewJsonLogger()
+
+// Server manages all peer connections.
+//
+// The fields of Server are used as configuration parameters.
+// You should set them before starting the Server. Fields may not be
+// modified while the server is running.
+type Server struct {
+ // This field must be set to a valid secp256k1 private key.
+ PrivateKey *ecdsa.PrivateKey
+
+ // MaxPeers is the maximum number of peers that can be
+ // connected. It must be greater than zero.
+ MaxPeers int
+
+ // Name sets the node name of this server.
+ // Use common.MakeName to create a name that follows existing conventions.
+ Name string
+
+ // Bootstrap nodes are used to establish connectivity
+ // with the rest of the network.
+ BootstrapNodes []*discover.Node
+
+ // Protocols should contain the protocols supported
+ // by the server. Matching protocols are launched for
+ // each peer.
+ Protocols []Protocol
+
+ // If ListenAddr is set to a non-nil address, the server
+ // will listen for incoming connections.
+ //
+ // If the port is zero, the operating system will pick a port. The
+ // ListenAddr field will be updated with the actual address when
+ // the server is started.
+ ListenAddr string
+
+ // If set to a non-nil value, the given NAT port mapper
+ // is used to make the listening port available to the
+ // Internet.
+ NAT nat.Interface
+
+ // If Dialer is set to a non-nil value, the given Dialer
+ // is used to dial outbound peer connections.
+ Dialer *net.Dialer
+
+ // If NoDial is true, the server will not dial any peers.
+ NoDial bool
+
+ // Hooks for testing. These are useful because we can inhibit
+ // the whole protocol stack.
+ setupFunc
+ newPeerHook
+
+ ourHandshake *protoHandshake
+
+ lock sync.RWMutex
+ running bool
+ listener net.Listener
+ peers map[discover.NodeID]*Peer
+
+ ntab *discover.Table
+
+ quit chan struct{}
+ loopWG sync.WaitGroup // {dial,listen,nat}Loop
+ peerWG sync.WaitGroup // active peer goroutines
+ peerConnect chan *discover.Node
+}
+
+type setupFunc func(net.Conn, *ecdsa.PrivateKey, *protoHandshake, *discover.Node) (*conn, error)
+type newPeerHook func(*Peer)
+
+// Peers returns all connected peers.
+func (srv *Server) Peers() (peers []*Peer) {
+ srv.lock.RLock()
+ defer srv.lock.RUnlock()
+ for _, peer := range srv.peers {
+ if peer != nil {
+ peers = append(peers, peer)
+ }
+ }
+ return
+}
+
+// PeerCount returns the number of connected peers.
+func (srv *Server) PeerCount() int {
+ srv.lock.RLock()
+ n := len(srv.peers)
+ srv.lock.RUnlock()
+ return n
+}
+
+// SuggestPeer creates a connection to the given Node if it
+// is not already connected.
+func (srv *Server) SuggestPeer(n *discover.Node) {
+ srv.peerConnect <- n
+}
+
+// Broadcast sends an RLP-encoded message to all connected peers.
+// This method is deprecated and will be removed later.
+func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{}) {
+ var payload []byte
+ if data != nil {
+ payload = common.Encode(data)
+ }
+ srv.lock.RLock()
+ defer srv.lock.RUnlock()
+ for _, peer := range srv.peers {
+ if peer != nil {
+ var msg = Msg{Code: code}
+ if data != nil {
+ msg.Payload = bytes.NewReader(payload)
+ msg.Size = uint32(len(payload))
+ }
+ peer.writeProtoMsg(protocol, msg)
+ }
+ }
+}
+
+// Start starts running the server.
+// Servers can be re-used and started again after stopping.
+func (srv *Server) Start() (err error) {
+ srv.lock.Lock()
+ defer srv.lock.Unlock()
+ if srv.running {
+ return errors.New("server already running")
+ }
+ srvlog.Infoln("Starting Server")
+
+ // static fields
+ if srv.PrivateKey == nil {
+ return fmt.Errorf("Server.PrivateKey must be set to a non-nil key")
+ }
+ if srv.MaxPeers <= 0 {
+ return fmt.Errorf("Server.MaxPeers must be > 0")
+ }
+ srv.quit = make(chan struct{})
+ srv.peers = make(map[discover.NodeID]*Peer)
+ srv.peerConnect = make(chan *discover.Node)
+ if srv.setupFunc == nil {
+ srv.setupFunc = setupConn
+ }
+
+ // node table
+ ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT)
+ if err != nil {
+ return err
+ }
+ srv.ntab = ntab
+
+ // handshake
+ srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: ntab.Self().ID}
+ for _, p := range srv.Protocols {
+ srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap())
+ }
+
+ // listen/dial
+ if srv.ListenAddr != "" {
+ if err := srv.startListening(); err != nil {
+ return err
+ }
+ }
+ if srv.Dialer == nil {
+ srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
+ }
+ if !srv.NoDial {
+ srv.loopWG.Add(1)
+ go srv.dialLoop()
+ }
+ if srv.NoDial && srv.ListenAddr == "" {
+ srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.")
+ }
+
+ srv.running = true
+ return nil
+}
+
+func (srv *Server) startListening() error {
+ listener, err := net.Listen("tcp", srv.ListenAddr)
+ if err != nil {
+ return err
+ }
+ laddr := listener.Addr().(*net.TCPAddr)
+ srv.ListenAddr = laddr.String()
+ srv.listener = listener
+ srv.loopWG.Add(1)
+ go srv.listenLoop()
+ if !laddr.IP.IsLoopback() && srv.NAT != nil {
+ srv.loopWG.Add(1)
+ go func() {
+ nat.Map(srv.NAT, srv.quit, "tcp", laddr.Port, laddr.Port, "ethereum p2p")
+ srv.loopWG.Done()
+ }()
+ }
+ return nil
+}
+
+// Stop terminates the server and all active peer connections.
+// It blocks until all active connections have been closed.
+func (srv *Server) Stop() {
+ srv.lock.Lock()
+ if !srv.running {
+ srv.lock.Unlock()
+ return
+ }
+ srv.running = false
+ srv.lock.Unlock()
+
+ srvlog.Infoln("Stopping Server")
+ srv.ntab.Close()
+ if srv.listener != nil {
+ // this unblocks listener Accept
+ srv.listener.Close()
+ }
+ close(srv.quit)
+ srv.loopWG.Wait()
+
+ // No new peers can be added at this point because dialLoop and
+ // listenLoop are down. It is safe to call peerWG.Wait because
+ // peerWG.Add is not called outside of those loops.
+ for _, peer := range srv.peers {
+ peer.Disconnect(DiscQuitting)
+ }
+ srv.peerWG.Wait()
+}
+
+// main loop for adding connections via listening
+func (srv *Server) listenLoop() {
+ defer srv.loopWG.Done()
+ srvlog.Infoln("Listening on", srv.listener.Addr())
+ for {
+ conn, err := srv.listener.Accept()
+ if err != nil {
+ return
+ }
+ srvlog.Debugf("Accepted conn %v\n", conn.RemoteAddr())
+ srv.peerWG.Add(1)
+ go srv.startPeer(conn, nil)
+ }
+}
+
+func (srv *Server) dialLoop() {
+ defer srv.loopWG.Done()
+ refresh := time.NewTicker(refreshPeersInterval)
+ defer refresh.Stop()
+
+ srv.ntab.Bootstrap(srv.BootstrapNodes)
+ go srv.findPeers()
+
+ dialed := make(chan *discover.Node)
+ dialing := make(map[discover.NodeID]bool)
+
+ // TODO: limit number of active dials
+ // TODO: ensure only one findPeers goroutine is running
+ // TODO: pause findPeers when we're at capacity
+
+ for {
+ select {
+ case <-refresh.C:
+
+ go srv.findPeers()
+
+ case dest := <-srv.peerConnect:
+ // avoid dialing nodes that are already connected.
+ // there is another check for this in addPeer,
+ // which runs after the handshake.
+ srv.lock.Lock()
+ _, isconnected := srv.peers[dest.ID]
+ srv.lock.Unlock()
+ if isconnected || dialing[dest.ID] || dest.ID == srv.Self().ID {
+ continue
+ }
+
+ dialing[dest.ID] = true
+ srv.peerWG.Add(1)
+ go func() {
+ srv.dialNode(dest)
+ // at this point, the peer has been added
+ // or discarded. either way, we're not dialing it anymore.
+ dialed <- dest
+ }()
+
+ case dest := <-dialed:
+ delete(dialing, dest.ID)
+
+ case <-srv.quit:
+ // TODO: maybe wait for active dials
+ return
+ }
+ }
+}
+
+func (srv *Server) dialNode(dest *discover.Node) {
+ addr := &net.TCPAddr{IP: dest.IP, Port: dest.TCPPort}
+ srvlog.Debugf("Dialing %v\n", dest)
+ conn, err := srv.Dialer.Dial("tcp", addr.String())
+ if err != nil {
+ srvlog.DebugDetailf("dial error: %v", err)
+ return
+ }
+ srv.startPeer(conn, dest)
+}
+
+func (srv *Server) Self() *discover.Node {
+ return srv.ntab.Self()
+}
+
+func (srv *Server) findPeers() {
+ far := srv.Self().ID
+ for i := range far {
+ far[i] = ^far[i]
+ }
+ closeToSelf := srv.ntab.Lookup(srv.Self().ID)
+ farFromSelf := srv.ntab.Lookup(far)
+
+ for i := 0; i < len(closeToSelf) || i < len(farFromSelf); i++ {
+ if i < len(closeToSelf) {
+ srv.peerConnect <- closeToSelf[i]
+ }
+ if i < len(farFromSelf) {
+ srv.peerConnect <- farFromSelf[i]
+ }
+ }
+}
+
+func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
+ // TODO: handle/store session token
+ fd.SetDeadline(time.Now().Add(handshakeTimeout))
+ conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest)
+ if err != nil {
+ fd.Close()
+ srvlog.Debugf("Handshake with %v failed: %v", fd.RemoteAddr(), err)
+ return
+ }
+
+ conn.MsgReadWriter = &netWrapper{
+ wrapped: conn.MsgReadWriter,
+ conn: fd, rtimeout: frameReadTimeout, wtimeout: frameWriteTimeout,
+ }
+ p := newPeer(fd, conn, srv.Protocols)
+ if ok, reason := srv.addPeer(conn.ID, p); !ok {
+ srvlog.DebugDetailf("Not adding %v (%v)\n", p, reason)
+ p.politeDisconnect(reason)
+ return
+ }
+
+ srvlog.Debugf("Added %v\n", p)
+ srvjslog.LogJson(&logger.P2PConnected{
+ RemoteId: fmt.Sprintf("%x", conn.ID[:]),
+ RemoteAddress: fd.RemoteAddr().String(),
+ RemoteVersionString: conn.Name,
+ NumConnections: srv.PeerCount(),
+ })
+
+ if srv.newPeerHook != nil {
+ srv.newPeerHook(p)
+ }
+ discreason := p.run()
+ srv.removePeer(p)
+
+ srvlog.Debugf("Removed %v (%v)\n", p, discreason)
+ srvjslog.LogJson(&logger.P2PDisconnected{
+ RemoteId: fmt.Sprintf("%x", conn.ID[:]),
+ NumConnections: srv.PeerCount(),
+ })
+}
+
+func (srv *Server) addPeer(id discover.NodeID, p *Peer) (bool, DiscReason) {
+ srv.lock.Lock()
+ defer srv.lock.Unlock()
+ switch {
+ case !srv.running:
+ return false, DiscQuitting
+ case len(srv.peers) >= srv.MaxPeers:
+ return false, DiscTooManyPeers
+ case srv.peers[id] != nil:
+ return false, DiscAlreadyConnected
+ case id == srv.Self().ID:
+ return false, DiscSelf
+ }
+ srv.peers[id] = p
+ return true, 0
+}
+
+func (srv *Server) removePeer(p *Peer) {
+ srv.lock.Lock()
+ delete(srv.peers, p.ID())
+ srv.lock.Unlock()
+ srv.peerWG.Done()
+}
diff --git a/p2p/server_test.go b/p2p/server_test.go
new file mode 100644
index 000000000..30447050c
--- /dev/null
+++ b/p2p/server_test.go
@@ -0,0 +1,179 @@
+package p2p
+
+import (
+ "bytes"
+ "crypto/ecdsa"
+ "io"
+ "math/rand"
+ "net"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/crypto/sha3"
+ "github.com/ethereum/go-ethereum/p2p/discover"
+)
+
+func startTestServer(t *testing.T, pf newPeerHook) *Server {
+ server := &Server{
+ Name: "test",
+ MaxPeers: 10,
+ ListenAddr: "127.0.0.1:0",
+ PrivateKey: newkey(),
+ newPeerHook: pf,
+ setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
+ id := randomID()
+ rw := newRlpxFrameRW(fd, secrets{
+ MAC: zero16,
+ AES: zero16,
+ IngressMAC: sha3.NewKeccak256(),
+ EgressMAC: sha3.NewKeccak256(),
+ })
+ return &conn{
+ MsgReadWriter: rw,
+ protoHandshake: &protoHandshake{ID: id, Version: baseProtocolVersion},
+ }, nil
+ },
+ }
+ if err := server.Start(); err != nil {
+ t.Fatalf("Could not start server: %v", err)
+ }
+ return server
+}
+
+func TestServerListen(t *testing.T) {
+ defer testlog(t).detach()
+
+ // start the test server
+ connected := make(chan *Peer)
+ srv := startTestServer(t, func(p *Peer) {
+ if p == nil {
+ t.Error("peer func called with nil conn")
+ }
+ connected <- p
+ })
+ defer close(connected)
+ defer srv.Stop()
+
+ // dial the test server
+ conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second)
+ if err != nil {
+ t.Fatalf("could not dial: %v", err)
+ }
+ defer conn.Close()
+
+ select {
+ case peer := <-connected:
+ if peer.LocalAddr().String() != conn.RemoteAddr().String() {
+ t.Errorf("peer started with wrong conn: got %v, want %v",
+ peer.LocalAddr(), conn.RemoteAddr())
+ }
+ case <-time.After(1 * time.Second):
+ t.Error("server did not accept within one second")
+ }
+}
+
+func TestServerDial(t *testing.T) {
+ defer testlog(t).detach()
+
+ // run a one-shot TCP server to handle the connection.
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("could not setup listener: %v")
+ }
+ defer listener.Close()
+ accepted := make(chan net.Conn)
+ go func() {
+ conn, err := listener.Accept()
+ if err != nil {
+ t.Error("accept error:", err)
+ return
+ }
+ conn.Close()
+ accepted <- conn
+ }()
+
+ // start the server
+ connected := make(chan *Peer)
+ srv := startTestServer(t, func(p *Peer) { connected <- p })
+ defer close(connected)
+ defer srv.Stop()
+
+ // tell the server to connect
+ tcpAddr := listener.Addr().(*net.TCPAddr)
+ srv.SuggestPeer(&discover.Node{IP: tcpAddr.IP, TCPPort: tcpAddr.Port})
+
+ select {
+ case conn := <-accepted:
+ select {
+ case peer := <-connected:
+ if peer.RemoteAddr().String() != conn.LocalAddr().String() {
+ t.Errorf("peer started with wrong conn: got %v, want %v",
+ peer.RemoteAddr(), conn.LocalAddr())
+ }
+ // TODO: validate more fields
+ case <-time.After(1 * time.Second):
+ t.Error("server did not launch peer within one second")
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Error("server did not connect within one second")
+ }
+}
+
+func TestServerBroadcast(t *testing.T) {
+ defer testlog(t).detach()
+
+ var connected sync.WaitGroup
+ srv := startTestServer(t, func(p *Peer) {
+ p.running = matchProtocols([]Protocol{discard}, []Cap{discard.cap()}, p.rw)
+ connected.Done()
+ })
+ defer srv.Stop()
+
+ // create a few peers
+ var conns = make([]net.Conn, 8)
+ connected.Add(len(conns))
+ deadline := time.Now().Add(3 * time.Second)
+ dialer := &net.Dialer{Deadline: deadline}
+ for i := range conns {
+ conn, err := dialer.Dial("tcp", srv.ListenAddr)
+ if err != nil {
+ t.Fatalf("conn %d: dial error: %v", i, err)
+ }
+ defer conn.Close()
+ conn.SetDeadline(deadline)
+ conns[i] = conn
+ }
+ connected.Wait()
+
+ // broadcast one message
+ srv.Broadcast("discard", 0, "foo")
+ golden := unhex("66e94d166f0a2c3b884cfa59ca34")
+
+ // check that the message has been written everywhere
+ for i, conn := range conns {
+ buf := make([]byte, len(golden))
+ if _, err := io.ReadFull(conn, buf); err != nil {
+ t.Errorf("conn %d: read error: %v", i, err)
+ } else if !bytes.Equal(buf, golden) {
+ t.Errorf("conn %d: msg mismatch\ngot: %x\nwant: %x", i, buf, golden)
+ }
+ }
+}
+
+func newkey() *ecdsa.PrivateKey {
+ key, err := crypto.GenerateKey()
+ if err != nil {
+ panic("couldn't generate key: " + err.Error())
+ }
+ return key
+}
+
+func randomID() (id discover.NodeID) {
+ for i := range id {
+ id[i] = byte(rand.Intn(255))
+ }
+ return id
+}
diff --git a/p2p/testlog_test.go b/p2p/testlog_test.go
new file mode 100644
index 000000000..c524c154c
--- /dev/null
+++ b/p2p/testlog_test.go
@@ -0,0 +1,28 @@
+package p2p
+
+import (
+ "testing"
+
+ "github.com/ethereum/go-ethereum/logger"
+)
+
+type testLogger struct{ t *testing.T }
+
+func testlog(t *testing.T) testLogger {
+ logger.Reset()
+ l := testLogger{t}
+ logger.AddLogSystem(l)
+ return l
+}
+
+func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugDetailLevel }
+func (testLogger) SetLogLevel(logger.LogLevel) {}
+
+func (l testLogger) LogPrint(level logger.LogLevel, msg string) {
+ l.t.Logf("%s", msg)
+}
+
+func (testLogger) detach() {
+ logger.Flush()
+ logger.Reset()
+}