diff options
Diffstat (limited to 'p2p')
-rw-r--r-- | p2p/discover/node.go | 306 | ||||
-rw-r--r-- | p2p/discover/node_test.go | 219 | ||||
-rw-r--r-- | p2p/discover/table.go | 280 | ||||
-rw-r--r-- | p2p/discover/table_test.go | 311 | ||||
-rw-r--r-- | p2p/discover/udp.go | 432 | ||||
-rw-r--r-- | p2p/discover/udp_test.go | 211 | ||||
-rw-r--r-- | p2p/handshake.go | 436 | ||||
-rw-r--r-- | p2p/handshake_test.go | 171 | ||||
-rw-r--r-- | p2p/message.go | 210 | ||||
-rw-r--r-- | p2p/message_test.go | 151 | ||||
-rw-r--r-- | p2p/nat/nat.go | 235 | ||||
-rw-r--r-- | p2p/nat/natpmp.go | 115 | ||||
-rw-r--r-- | p2p/nat/natupnp.go | 149 | ||||
-rw-r--r-- | p2p/peer.go | 312 | ||||
-rw-r--r-- | p2p/peer_error.go | 131 | ||||
-rw-r--r-- | p2p/peer_test.go | 226 | ||||
-rw-r--r-- | p2p/protocol.go | 50 | ||||
-rw-r--r-- | p2p/rlpx.go | 174 | ||||
-rw-r--r-- | p2p/rlpx_test.go | 124 | ||||
-rw-r--r-- | p2p/server.go | 425 | ||||
-rw-r--r-- | p2p/server_test.go | 179 | ||||
-rw-r--r-- | p2p/testlog_test.go | 28 |
22 files changed, 4875 insertions, 0 deletions
diff --git a/p2p/discover/node.go b/p2p/discover/node.go new file mode 100644 index 000000000..de2588258 --- /dev/null +++ b/p2p/discover/node.go @@ -0,0 +1,306 @@ +package discover + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "encoding/hex" + "errors" + "fmt" + "io" + "math/big" + "math/rand" + "net" + "net/url" + "strconv" + "strings" + "time" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/secp256k1" + "github.com/ethereum/go-ethereum/rlp" +) + +const nodeIDBits = 512 + +// Node represents a host on the network. +type Node struct { + ID NodeID + IP net.IP + + DiscPort int // UDP listening port for discovery protocol + TCPPort int // TCP listening port for RLPx + + active time.Time +} + +func newNode(id NodeID, addr *net.UDPAddr) *Node { + return &Node{ + ID: id, + IP: addr.IP, + DiscPort: addr.Port, + TCPPort: addr.Port, + active: time.Now(), + } +} + +func (n *Node) isValid() bool { + // TODO: don't accept localhost, LAN addresses from internet hosts + return !n.IP.IsMulticast() && !n.IP.IsUnspecified() && n.TCPPort != 0 && n.DiscPort != 0 +} + +// The string representation of a Node is a URL. +// Please see ParseNode for a description of the format. +func (n *Node) String() string { + addr := net.TCPAddr{IP: n.IP, Port: n.TCPPort} + u := url.URL{ + Scheme: "enode", + User: url.User(fmt.Sprintf("%x", n.ID[:])), + Host: addr.String(), + } + if n.DiscPort != n.TCPPort { + u.RawQuery = "discport=" + strconv.Itoa(n.DiscPort) + } + return u.String() +} + +// ParseNode parses a node URL. +// +// A node URL has scheme "enode". +// +// The hexadecimal node ID is encoded in the username portion of the +// URL, separated from the host by an @ sign. The hostname can only be +// given as an IP address, DNS domain names are not allowed. The port +// in the host name section is the TCP listening port. If the TCP and +// UDP (discovery) ports differ, the UDP port is specified as query +// parameter "discport". +// +// In the following example, the node URL describes +// a node with IP address 10.3.58.6, TCP listening port 30303 +// and UDP discovery port 30301. +// +// enode://<hex node id>@10.3.58.6:30303?discport=30301 +func ParseNode(rawurl string) (*Node, error) { + var n Node + u, err := url.Parse(rawurl) + if u.Scheme != "enode" { + return nil, errors.New("invalid URL scheme, want \"enode\"") + } + if u.User == nil { + return nil, errors.New("does not contain node ID") + } + if n.ID, err = HexID(u.User.String()); err != nil { + return nil, fmt.Errorf("invalid node ID (%v)", err) + } + ip, port, err := net.SplitHostPort(u.Host) + if err != nil { + return nil, fmt.Errorf("invalid host: %v", err) + } + if n.IP = net.ParseIP(ip); n.IP == nil { + return nil, errors.New("invalid IP address") + } + if n.TCPPort, err = strconv.Atoi(port); err != nil { + return nil, errors.New("invalid port") + } + qv := u.Query() + if qv.Get("discport") == "" { + n.DiscPort = n.TCPPort + } else { + if n.DiscPort, err = strconv.Atoi(qv.Get("discport")); err != nil { + return nil, errors.New("invalid discport in query") + } + } + return &n, nil +} + +// MustParseNode parses a node URL. It panics if the URL is not valid. +func MustParseNode(rawurl string) *Node { + n, err := ParseNode(rawurl) + if err != nil { + panic("invalid node URL: " + err.Error()) + } + return n +} + +func (n Node) EncodeRLP(w io.Writer) error { + return rlp.Encode(w, rpcNode{IP: n.IP.String(), Port: uint16(n.TCPPort), ID: n.ID}) +} +func (n *Node) DecodeRLP(s *rlp.Stream) (err error) { + var ext rpcNode + if err = s.Decode(&ext); err == nil { + n.TCPPort = int(ext.Port) + n.DiscPort = int(ext.Port) + n.ID = ext.ID + if n.IP = net.ParseIP(ext.IP); n.IP == nil { + return errors.New("invalid IP string") + } + } + return err +} + +// NodeID is a unique identifier for each node. +// The node identifier is a marshaled elliptic curve public key. +type NodeID [nodeIDBits / 8]byte + +// NodeID prints as a long hexadecimal number. +func (n NodeID) String() string { + return fmt.Sprintf("%#x", n[:]) +} + +// The Go syntax representation of a NodeID is a call to HexID. +func (n NodeID) GoString() string { + return fmt.Sprintf("discover.HexID(\"%#x\")", n[:]) +} + +// HexID converts a hex string to a NodeID. +// The string may be prefixed with 0x. +func HexID(in string) (NodeID, error) { + if strings.HasPrefix(in, "0x") { + in = in[2:] + } + var id NodeID + b, err := hex.DecodeString(in) + if err != nil { + return id, err + } else if len(b) != len(id) { + return id, fmt.Errorf("wrong length, need %d hex bytes", len(id)) + } + copy(id[:], b) + return id, nil +} + +// MustHexID converts a hex string to a NodeID. +// It panics if the string is not a valid NodeID. +func MustHexID(in string) NodeID { + id, err := HexID(in) + if err != nil { + panic(err) + } + return id +} + +// PubkeyID returns a marshaled representation of the given public key. +func PubkeyID(pub *ecdsa.PublicKey) NodeID { + var id NodeID + pbytes := elliptic.Marshal(pub.Curve, pub.X, pub.Y) + if len(pbytes)-1 != len(id) { + panic(fmt.Errorf("need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pbytes))) + } + copy(id[:], pbytes[1:]) + return id +} + +// Pubkey returns the public key represented by the node ID. +// It returns an error if the ID is not a point on the curve. +func (id NodeID) Pubkey() (*ecdsa.PublicKey, error) { + p := &ecdsa.PublicKey{Curve: crypto.S256(), X: new(big.Int), Y: new(big.Int)} + half := len(id) / 2 + p.X.SetBytes(id[:half]) + p.Y.SetBytes(id[half:]) + if !p.Curve.IsOnCurve(p.X, p.Y) { + return nil, errors.New("not a point on the S256 curve") + } + return p, nil +} + +// recoverNodeID computes the public key used to sign the +// given hash from the signature. +func recoverNodeID(hash, sig []byte) (id NodeID, err error) { + pubkey, err := secp256k1.RecoverPubkey(hash, sig) + if err != nil { + return id, err + } + if len(pubkey)-1 != len(id) { + return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8) + } + for i := range id { + id[i] = pubkey[i+1] + } + return id, nil +} + +// distcmp compares the distances a->target and b->target. +// Returns -1 if a is closer to target, 1 if b is closer to target +// and 0 if they are equal. +func distcmp(target, a, b NodeID) int { + for i := range target { + da := a[i] ^ target[i] + db := b[i] ^ target[i] + if da > db { + return 1 + } else if da < db { + return -1 + } + } + return 0 +} + +// table of leading zero counts for bytes [0..255] +var lzcount = [256]int{ + 8, 7, 6, 6, 5, 5, 5, 5, + 4, 4, 4, 4, 4, 4, 4, 4, + 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, + 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, +} + +// logdist returns the logarithmic distance between a and b, log2(a ^ b). +func logdist(a, b NodeID) int { + lz := 0 + for i := range a { + x := a[i] ^ b[i] + if x == 0 { + lz += 8 + } else { + lz += lzcount[x] + break + } + } + return len(a)*8 - lz +} + +// randomID returns a random NodeID such that logdist(a, b) == n +func randomID(a NodeID, n int) (b NodeID) { + if n == 0 { + return a + } + // flip bit at position n, fill the rest with random bits + b = a + pos := len(a) - n/8 - 1 + bit := byte(0x01) << (byte(n%8) - 1) + if bit == 0 { + pos++ + bit = 0x80 + } + b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits + for i := pos + 1; i < len(a); i++ { + b[i] = byte(rand.Intn(255)) + } + return b +} diff --git a/p2p/discover/node_test.go b/p2p/discover/node_test.go new file mode 100644 index 000000000..60b01b6ca --- /dev/null +++ b/p2p/discover/node_test.go @@ -0,0 +1,219 @@ +package discover + +import ( + "math/big" + "math/rand" + "net" + "reflect" + "testing" + "testing/quick" + "time" + + "github.com/ethereum/go-ethereum/crypto" +) + +var ( + quickrand = rand.New(rand.NewSource(time.Now().Unix())) + quickcfg = &quick.Config{MaxCount: 5000, Rand: quickrand} +) + +var parseNodeTests = []struct { + rawurl string + wantError string + wantResult *Node +}{ + { + rawurl: "http://foobar", + wantError: `invalid URL scheme, want "enode"`, + }, + { + rawurl: "enode://foobar", + wantError: `does not contain node ID`, + }, + { + rawurl: "enode://01010101@123.124.125.126:3", + wantError: `invalid node ID (wrong length, need 64 hex bytes)`, + }, + { + rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@hostname:3", + wantError: `invalid IP address`, + }, + { + rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:foo", + wantError: `invalid port`, + }, + { + rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:3?discport=foo", + wantError: `invalid discport in query`, + }, + { + rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150", + wantResult: &Node{ + ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + IP: net.ParseIP("127.0.0.1"), + DiscPort: 52150, + TCPPort: 52150, + }, + }, + { + rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@[::]:52150", + wantResult: &Node{ + ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + IP: net.ParseIP("::"), + DiscPort: 52150, + TCPPort: 52150, + }, + }, + { + rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150?discport=223344", + wantResult: &Node{ + ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + IP: net.ParseIP("127.0.0.1"), + DiscPort: 223344, + TCPPort: 52150, + }, + }, +} + +func TestParseNode(t *testing.T) { + for i, test := range parseNodeTests { + n, err := ParseNode(test.rawurl) + if err == nil && test.wantError != "" { + t.Errorf("test %d: got nil error, expected %#q", i, test.wantError) + continue + } + if err != nil && err.Error() != test.wantError { + t.Errorf("test %d: got error %#q, expected %#q", i, err.Error(), test.wantError) + continue + } + if !reflect.DeepEqual(n, test.wantResult) { + t.Errorf("test %d: result mismatch:\ngot: %#v, want: %#v", i, n, test.wantResult) + } + } +} + +func TestNodeString(t *testing.T) { + for i, test := range parseNodeTests { + if test.wantError != "" { + continue + } + str := test.wantResult.String() + if str != test.rawurl { + t.Errorf("test %d: Node.String() mismatch:\ngot: %s\nwant: %s", i, str, test.rawurl) + } + } +} + +func TestHexID(t *testing.T) { + ref := NodeID{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 106, 217, 182, 31, 165, 174, 1, 67, 7, 235, 220, 150, 66, 83, 173, 205, 159, 44, 10, 57, 42, 161, 26, 188} + id1 := MustHexID("0x000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc") + id2 := MustHexID("000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc") + + if id1 != ref { + t.Errorf("wrong id1\ngot %v\nwant %v", id1[:], ref[:]) + } + if id2 != ref { + t.Errorf("wrong id2\ngot %v\nwant %v", id2[:], ref[:]) + } +} + +func TestNodeID_recover(t *testing.T) { + prv := newkey() + hash := make([]byte, 32) + sig, err := crypto.Sign(hash, prv) + if err != nil { + t.Fatalf("signing error: %v", err) + } + + pub := PubkeyID(&prv.PublicKey) + recpub, err := recoverNodeID(hash, sig) + if err != nil { + t.Fatalf("recovery error: %v", err) + } + if pub != recpub { + t.Errorf("recovered wrong pubkey:\ngot: %v\nwant: %v", recpub, pub) + } + + ecdsa, err := pub.Pubkey() + if err != nil { + t.Errorf("Pubkey error: %v", err) + } + if !reflect.DeepEqual(ecdsa, &prv.PublicKey) { + t.Errorf("Pubkey mismatch:\n got: %#v\n want: %#v", ecdsa, &prv.PublicKey) + } +} + +func TestNodeID_pubkeyBad(t *testing.T) { + ecdsa, err := NodeID{}.Pubkey() + if err == nil { + t.Error("expected error for zero ID") + } + if ecdsa != nil { + t.Error("expected nil result") + } +} + +func TestNodeID_distcmp(t *testing.T) { + distcmpBig := func(target, a, b NodeID) int { + tbig := new(big.Int).SetBytes(target[:]) + abig := new(big.Int).SetBytes(a[:]) + bbig := new(big.Int).SetBytes(b[:]) + return new(big.Int).Xor(tbig, abig).Cmp(new(big.Int).Xor(tbig, bbig)) + } + if err := quick.CheckEqual(distcmp, distcmpBig, quickcfg); err != nil { + t.Error(err) + } +} + +// the random tests is likely to miss the case where they're equal. +func TestNodeID_distcmpEqual(t *testing.T) { + base := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + x := NodeID{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0} + if distcmp(base, x, x) != 0 { + t.Errorf("distcmp(base, x, x) != 0") + } +} + +func TestNodeID_logdist(t *testing.T) { + logdistBig := func(a, b NodeID) int { + abig, bbig := new(big.Int).SetBytes(a[:]), new(big.Int).SetBytes(b[:]) + return new(big.Int).Xor(abig, bbig).BitLen() + } + if err := quick.CheckEqual(logdist, logdistBig, quickcfg); err != nil { + t.Error(err) + } +} + +// the random tests is likely to miss the case where they're equal. +func TestNodeID_logdistEqual(t *testing.T) { + x := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + if logdist(x, x) != 0 { + t.Errorf("logdist(x, x) != 0") + } +} + +func TestNodeID_randomID(t *testing.T) { + // we don't use quick.Check here because its output isn't + // very helpful when the test fails. + for i := 0; i < quickcfg.MaxCount; i++ { + a := gen(NodeID{}, quickrand).(NodeID) + dist := quickrand.Intn(len(NodeID{}) * 8) + result := randomID(a, dist) + actualdist := logdist(result, a) + + if dist != actualdist { + t.Log("a: ", a) + t.Log("result:", result) + t.Fatalf("#%d: distance of result is %d, want %d", i, actualdist, dist) + } + } +} + +func (NodeID) Generate(rand *rand.Rand, size int) reflect.Value { + var id NodeID + m := rand.Intn(len(id)) + for i := len(id) - 1; i > m; i-- { + id[i] = byte(rand.Uint32()) + } + return reflect.ValueOf(id) +} diff --git a/p2p/discover/table.go b/p2p/discover/table.go new file mode 100644 index 000000000..e3bec9328 --- /dev/null +++ b/p2p/discover/table.go @@ -0,0 +1,280 @@ +// Package discover implements the Node Discovery Protocol. +// +// The Node Discovery protocol provides a way to find RLPx nodes that +// can be connected to. It uses a Kademlia-like protocol to maintain a +// distributed database of the IDs and endpoints of all listening +// nodes. +package discover + +import ( + "net" + "sort" + "sync" + "time" +) + +const ( + alpha = 3 // Kademlia concurrency factor + bucketSize = 16 // Kademlia bucket size + nBuckets = nodeIDBits + 1 // Number of buckets +) + +type Table struct { + mutex sync.Mutex // protects buckets, their content, and nursery + buckets [nBuckets]*bucket // index of known nodes by distance + nursery []*Node // bootstrap nodes + + net transport + self *Node // metadata of the local node +} + +// transport is implemented by the UDP transport. +// it is an interface so we can test without opening lots of UDP +// sockets and without generating a private key. +type transport interface { + ping(*Node) error + findnode(e *Node, target NodeID) ([]*Node, error) + close() +} + +// bucket contains nodes, ordered by their last activity. +type bucket struct { + lastLookup time.Time + entries []*Node +} + +func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr) *Table { + tab := &Table{net: t, self: newNode(ourID, ourAddr)} + for i := range tab.buckets { + tab.buckets[i] = new(bucket) + } + return tab +} + +// Self returns the local node ID. +func (tab *Table) Self() NodeID { + return tab.self.ID +} + +// Close terminates the network listener. +func (tab *Table) Close() { + tab.net.close() +} + +// Bootstrap sets the bootstrap nodes. These nodes are used to connect +// to the network if the table is empty. Bootstrap will also attempt to +// fill the table by performing random lookup operations on the +// network. +func (tab *Table) Bootstrap(nodes []*Node) { + tab.mutex.Lock() + // TODO: maybe filter nodes with bad fields (nil, etc.) to avoid strange crashes + tab.nursery = make([]*Node, 0, len(nodes)) + for _, n := range nodes { + cpy := *n + tab.nursery = append(tab.nursery, &cpy) + } + tab.mutex.Unlock() + tab.refresh() +} + +// Lookup performs a network search for nodes close +// to the given target. It approaches the target by querying +// nodes that are closer to it on each iteration. +func (tab *Table) Lookup(target NodeID) []*Node { + var ( + asked = make(map[NodeID]bool) + seen = make(map[NodeID]bool) + reply = make(chan []*Node, alpha) + pendingQueries = 0 + ) + // don't query further if we hit the target or ourself. + // unlikely to happen often in practice. + asked[target] = true + asked[tab.self.ID] = true + + tab.mutex.Lock() + // update last lookup stamp (for refresh logic) + tab.buckets[logdist(tab.self.ID, target)].lastLookup = time.Now() + // generate initial result set + result := tab.closest(target, bucketSize) + tab.mutex.Unlock() + + for { + // ask the alpha closest nodes that we haven't asked yet + for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ { + n := result.entries[i] + if !asked[n.ID] { + asked[n.ID] = true + pendingQueries++ + go func() { + result, _ := tab.net.findnode(n, target) + reply <- result + }() + } + } + if pendingQueries == 0 { + // we have asked all closest nodes, stop the search + break + } + + // wait for the next reply + for _, n := range <-reply { + cn := n + if !seen[n.ID] { + seen[n.ID] = true + result.push(cn, bucketSize) + } + } + pendingQueries-- + } + return result.entries +} + +// refresh performs a lookup for a random target to keep buckets full. +func (tab *Table) refresh() { + ld := -1 // logdist of chosen bucket + tab.mutex.Lock() + for i, b := range tab.buckets { + if i > 0 && b.lastLookup.Before(time.Now().Add(-1*time.Hour)) { + ld = i + break + } + } + tab.mutex.Unlock() + + result := tab.Lookup(randomID(tab.self.ID, ld)) + if len(result) == 0 { + // bootstrap the table with a self lookup + tab.mutex.Lock() + tab.add(tab.nursery) + tab.mutex.Unlock() + tab.Lookup(tab.self.ID) + // TODO: the Kademlia paper says that we're supposed to perform + // random lookups in all buckets further away than our closest neighbor. + } +} + +// closest returns the n nodes in the table that are closest to the +// given id. The caller must hold tab.mutex. +func (tab *Table) closest(target NodeID, nresults int) *nodesByDistance { + // This is a very wasteful way to find the closest nodes but + // obviously correct. I believe that tree-based buckets would make + // this easier to implement efficiently. + close := &nodesByDistance{target: target} + for _, b := range tab.buckets { + for _, n := range b.entries { + close.push(n, nresults) + } + } + return close +} + +func (tab *Table) len() (n int) { + for _, b := range tab.buckets { + n += len(b.entries) + } + return n +} + +// bumpOrAdd updates the activity timestamp for the given node and +// attempts to insert the node into a bucket. The returned Node might +// not be part of the table. The caller must hold tab.mutex. +func (tab *Table) bumpOrAdd(node NodeID, from *net.UDPAddr) (n *Node) { + b := tab.buckets[logdist(tab.self.ID, node)] + if n = b.bump(node); n == nil { + n = newNode(node, from) + if len(b.entries) == bucketSize { + tab.pingReplace(n, b) + } else { + b.entries = append(b.entries, n) + } + } + return n +} + +func (tab *Table) pingReplace(n *Node, b *bucket) { + old := b.entries[bucketSize-1] + go func() { + if err := tab.net.ping(old); err == nil { + // it responded, we don't need to replace it. + return + } + // it didn't respond, replace the node if it is still the oldest node. + tab.mutex.Lock() + if len(b.entries) > 0 && b.entries[len(b.entries)-1] == old { + // slide down other entries and put the new one in front. + // TODO: insert in correct position to keep the order + copy(b.entries[1:], b.entries) + b.entries[0] = n + } + tab.mutex.Unlock() + }() +} + +// bump updates the activity timestamp for the given node. +// The caller must hold tab.mutex. +func (tab *Table) bump(node NodeID) { + tab.buckets[logdist(tab.self.ID, node)].bump(node) +} + +// add puts the entries into the table if their corresponding +// bucket is not full. The caller must hold tab.mutex. +func (tab *Table) add(entries []*Node) { +outer: + for _, n := range entries { + if n == nil || n.ID == tab.self.ID { + // skip bad entries. The RLP decoder returns nil for empty + // input lists. + continue + } + bucket := tab.buckets[logdist(tab.self.ID, n.ID)] + for i := range bucket.entries { + if bucket.entries[i].ID == n.ID { + // already in bucket + continue outer + } + } + if len(bucket.entries) < bucketSize { + bucket.entries = append(bucket.entries, n) + } + } +} + +func (b *bucket) bump(id NodeID) *Node { + for i, n := range b.entries { + if n.ID == id { + n.active = time.Now() + // move it to the front + copy(b.entries[1:], b.entries[:i+1]) + b.entries[0] = n + return n + } + } + return nil +} + +// nodesByDistance is a list of nodes, ordered by +// distance to target. +type nodesByDistance struct { + entries []*Node + target NodeID +} + +// push adds the given node to the list, keeping the total size below maxElems. +func (h *nodesByDistance) push(n *Node, maxElems int) { + ix := sort.Search(len(h.entries), func(i int) bool { + return distcmp(h.target, h.entries[i].ID, n.ID) > 0 + }) + if len(h.entries) < maxElems { + h.entries = append(h.entries, n) + } + if ix == len(h.entries) { + // farther away than all nodes we already have. + // if there was room for it, the node is now the last element. + } else { + // slide existing entries down to make room + // this will overwrite the entry we just appended. + copy(h.entries[ix+1:], h.entries[ix:]) + h.entries[ix] = n + } +} diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go new file mode 100644 index 000000000..08faea68e --- /dev/null +++ b/p2p/discover/table_test.go @@ -0,0 +1,311 @@ +package discover + +import ( + "crypto/ecdsa" + "errors" + "fmt" + "math/rand" + "net" + "reflect" + "testing" + "testing/quick" + "time" + + "github.com/ethereum/go-ethereum/crypto" +) + +func TestTable_bumpOrAddBucketAssign(t *testing.T) { + tab := newTable(nil, NodeID{}, &net.UDPAddr{}) + for i := 1; i < len(tab.buckets); i++ { + tab.bumpOrAdd(randomID(tab.self.ID, i), &net.UDPAddr{}) + } + for i, b := range tab.buckets { + if i > 0 && len(b.entries) != 1 { + t.Errorf("bucket %d has %d entries, want 1", i, len(b.entries)) + } + } +} + +func TestTable_bumpOrAddPingReplace(t *testing.T) { + pingC := make(pingC) + tab := newTable(pingC, NodeID{}, &net.UDPAddr{}) + last := fillBucket(tab, 200) + + // this bumpOrAdd should not replace the last node + // because the node replies to ping. + new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{}) + + pinged := <-pingC + if pinged != last.ID { + t.Fatalf("pinged wrong node: %v\nwant %v", pinged, last.ID) + } + + tab.mutex.Lock() + defer tab.mutex.Unlock() + if l := len(tab.buckets[200].entries); l != bucketSize { + t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l) + } + if !contains(tab.buckets[200].entries, last.ID) { + t.Error("last entry was removed") + } + if contains(tab.buckets[200].entries, new.ID) { + t.Error("new entry was added") + } +} + +func TestTable_bumpOrAddPingTimeout(t *testing.T) { + tab := newTable(pingC(nil), NodeID{}, &net.UDPAddr{}) + last := fillBucket(tab, 200) + + // this bumpOrAdd should replace the last node + // because the node does not reply to ping. + new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{}) + + // wait for async bucket update. damn. this needs to go away. + time.Sleep(2 * time.Millisecond) + + tab.mutex.Lock() + defer tab.mutex.Unlock() + if l := len(tab.buckets[200].entries); l != bucketSize { + t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l) + } + if contains(tab.buckets[200].entries, last.ID) { + t.Error("last entry was not removed") + } + if !contains(tab.buckets[200].entries, new.ID) { + t.Error("new entry was not added") + } +} + +func fillBucket(tab *Table, ld int) (last *Node) { + b := tab.buckets[ld] + for len(b.entries) < bucketSize { + b.entries = append(b.entries, &Node{ID: randomID(tab.self.ID, ld)}) + } + return b.entries[bucketSize-1] +} + +type pingC chan NodeID + +func (t pingC) findnode(n *Node, target NodeID) ([]*Node, error) { + panic("findnode called on pingRecorder") +} +func (t pingC) close() { + panic("close called on pingRecorder") +} +func (t pingC) ping(n *Node) error { + if t == nil { + return errTimeout + } + t <- n.ID + return nil +} + +func TestTable_bump(t *testing.T) { + tab := newTable(nil, NodeID{}, &net.UDPAddr{}) + + // add an old entry and two recent ones + oldactive := time.Now().Add(-2 * time.Minute) + old := &Node{ID: randomID(tab.self.ID, 200), active: oldactive} + others := []*Node{ + &Node{ID: randomID(tab.self.ID, 200), active: time.Now()}, + &Node{ID: randomID(tab.self.ID, 200), active: time.Now()}, + } + tab.add(append(others, old)) + if tab.buckets[200].entries[0] == old { + t.Fatal("old entry is at front of bucket") + } + + // bumping the old entry should move it to the front + tab.bump(old.ID) + if old.active == oldactive { + t.Error("activity timestamp not updated") + } + if tab.buckets[200].entries[0] != old { + t.Errorf("bumped entry did not move to the front of bucket") + } +} + +func TestTable_closest(t *testing.T) { + t.Parallel() + + test := func(test *closeTest) bool { + // for any node table, Target and N + tab := newTable(nil, test.Self, &net.UDPAddr{}) + tab.add(test.All) + + // check that doClosest(Target, N) returns nodes + result := tab.closest(test.Target, test.N).entries + if hasDuplicates(result) { + t.Errorf("result contains duplicates") + return false + } + if !sortedByDistanceTo(test.Target, result) { + t.Errorf("result is not sorted by distance to target") + return false + } + + // check that the number of results is min(N, tablen) + wantN := test.N + if tlen := tab.len(); tlen < test.N { + wantN = tlen + } + if len(result) != wantN { + t.Errorf("wrong number of nodes: got %d, want %d", len(result), wantN) + return false + } else if len(result) == 0 { + return true // no need to check distance + } + + // check that the result nodes have minimum distance to target. + for _, b := range tab.buckets { + for _, n := range b.entries { + if contains(result, n.ID) { + continue // don't run the check below for nodes in result + } + farthestResult := result[len(result)-1].ID + if distcmp(test.Target, n.ID, farthestResult) < 0 { + t.Errorf("table contains node that is closer to target but it's not in result") + t.Logf(" Target: %v", test.Target) + t.Logf(" Farthest Result: %v", farthestResult) + t.Logf(" ID: %v", n.ID) + return false + } + } + } + return true + } + if err := quick.Check(test, quickcfg); err != nil { + t.Error(err) + } +} + +type closeTest struct { + Self NodeID + Target NodeID + All []*Node + N int +} + +func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value { + t := &closeTest{ + Self: gen(NodeID{}, rand).(NodeID), + Target: gen(NodeID{}, rand).(NodeID), + N: rand.Intn(bucketSize), + } + for _, id := range gen([]NodeID{}, rand).([]NodeID) { + t.All = append(t.All, &Node{ID: id}) + } + return reflect.ValueOf(t) +} + +func TestTable_Lookup(t *testing.T) { + self := gen(NodeID{}, quickrand).(NodeID) + target := randomID(self, 200) + transport := findnodeOracle{t, target} + tab := newTable(transport, self, &net.UDPAddr{}) + + // lookup on empty table returns no nodes + if results := tab.Lookup(target); len(results) > 0 { + t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results) + } + // seed table with initial node (otherwise lookup will terminate immediately) + tab.bumpOrAdd(randomID(target, 200), &net.UDPAddr{Port: 200}) + + results := tab.Lookup(target) + t.Logf("results:") + for _, e := range results { + t.Logf(" ld=%d, %v", logdist(target, e.ID), e.ID) + } + if len(results) != bucketSize { + t.Errorf("wrong number of results: got %d, want %d", len(results), bucketSize) + } + if hasDuplicates(results) { + t.Errorf("result set contains duplicate entries") + } + if !sortedByDistanceTo(target, results) { + t.Errorf("result set not sorted by distance to target") + } + if !contains(results, target) { + t.Errorf("result set does not contain target") + } +} + +// findnode on this transport always returns at least one node +// that is one bucket closer to the target. +type findnodeOracle struct { + t *testing.T + target NodeID +} + +func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) { + t.t.Logf("findnode query at dist %d", n.DiscPort) + // current log distance is encoded in port number + var result []*Node + switch n.DiscPort { + case 0: + panic("query to node at distance 0") + default: + // TODO: add more randomness to distances + next := n.DiscPort - 1 + for i := 0; i < bucketSize; i++ { + result = append(result, &Node{ID: randomID(t.target, next), DiscPort: next}) + } + } + return result, nil +} + +func (t findnodeOracle) close() {} + +func (t findnodeOracle) ping(n *Node) error { + return errors.New("ping is not supported by this transport") +} + +func hasDuplicates(slice []*Node) bool { + seen := make(map[NodeID]bool) + for _, e := range slice { + if seen[e.ID] { + return true + } + seen[e.ID] = true + } + return false +} + +func sortedByDistanceTo(distbase NodeID, slice []*Node) bool { + var last NodeID + for i, e := range slice { + if i > 0 && distcmp(distbase, e.ID, last) < 0 { + return false + } + last = e.ID + } + return true +} + +func contains(ns []*Node, id NodeID) bool { + for _, n := range ns { + if n.ID == id { + return true + } + } + return false +} + +// gen wraps quick.Value so it's easier to use. +// it generates a random value of the given value's type. +func gen(typ interface{}, rand *rand.Rand) interface{} { + v, ok := quick.Value(reflect.TypeOf(typ), rand) + if !ok { + panic(fmt.Sprintf("couldn't generate random value of type %T", typ)) + } + return v.Interface() +} + +func newkey() *ecdsa.PrivateKey { + key, err := crypto.GenerateKey() + if err != nil { + panic("couldn't generate key: " + err.Error()) + } + return key +} diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go new file mode 100644 index 000000000..69e9f3c2e --- /dev/null +++ b/p2p/discover/udp.go @@ -0,0 +1,432 @@ +package discover + +import ( + "bytes" + "crypto/ecdsa" + "errors" + "fmt" + "net" + "time" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/p2p/nat" + "github.com/ethereum/go-ethereum/rlp" +) + +var log = logger.NewLogger("P2P Discovery") + +// Errors +var ( + errPacketTooSmall = errors.New("too small") + errBadHash = errors.New("bad hash") + errExpired = errors.New("expired") + errTimeout = errors.New("RPC timeout") + errClosed = errors.New("socket closed") +) + +// Timeouts +const ( + respTimeout = 300 * time.Millisecond + sendTimeout = 300 * time.Millisecond + expiration = 20 * time.Second + + refreshInterval = 1 * time.Hour +) + +// RPC packet types +const ( + pingPacket = iota + 1 // zero is 'reserved' + pongPacket + findnodePacket + neighborsPacket +) + +// RPC request structures +type ( + ping struct { + IP string // our IP + Port uint16 // our port + Expiration uint64 + } + + // reply to Ping + pong struct { + ReplyTok []byte + Expiration uint64 + } + + findnode struct { + // Id to look up. The responding node will send back nodes + // closest to the target. + Target NodeID + Expiration uint64 + } + + // reply to findnode + neighbors struct { + Nodes []*Node + Expiration uint64 + } +) + +type rpcNode struct { + IP string + Port uint16 + ID NodeID +} + +// udp implements the RPC protocol. +type udp struct { + conn *net.UDPConn + priv *ecdsa.PrivateKey + addpending chan *pending + replies chan reply + closing chan struct{} + nat nat.Interface + + *Table +} + +// pending represents a pending reply. +// +// some implementations of the protocol wish to send more than one +// reply packet to findnode. in general, any neighbors packet cannot +// be matched up with a specific findnode packet. +// +// our implementation handles this by storing a callback function for +// each pending reply. incoming packets from a node are dispatched +// to all the callback functions for that node. +type pending struct { + // these fields must match in the reply. + from NodeID + ptype byte + + // time when the request must complete + deadline time.Time + + // callback is called when a matching reply arrives. if it returns + // true, the callback is removed from the pending reply queue. + // if it returns false, the reply is considered incomplete and + // the callback will be invoked again for the next matching reply. + callback func(resp interface{}) (done bool) + + // errc receives nil when the callback indicates completion or an + // error if no further reply is received within the timeout. + errc chan<- error +} + +type reply struct { + from NodeID + ptype byte + data interface{} +} + +// ListenUDP returns a new table that listens for UDP packets on laddr. +func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table, error) { + addr, err := net.ResolveUDPAddr("udp", laddr) + if err != nil { + return nil, err + } + conn, err := net.ListenUDP("udp", addr) + if err != nil { + return nil, err + } + udp := &udp{ + conn: conn, + priv: priv, + closing: make(chan struct{}), + addpending: make(chan *pending), + replies: make(chan reply), + } + + realaddr := conn.LocalAddr().(*net.UDPAddr) + if natm != nil { + if !realaddr.IP.IsLoopback() { + go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery") + } + // TODO: react to external IP changes over time. + if ext, err := natm.ExternalIP(); err == nil { + realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port} + } + } + udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr) + + go udp.loop() + go udp.readLoop() + log.Infoln("Listening, ", udp.self) + return udp.Table, nil +} + +func (t *udp) close() { + close(t.closing) + t.conn.Close() + // TODO: wait for the loops to end. +} + +// ping sends a ping message to the given node and waits for a reply. +func (t *udp) ping(e *Node) error { + // TODO: maybe check for ReplyTo field in callback to measure RTT + errc := t.pending(e.ID, pongPacket, func(interface{}) bool { return true }) + t.send(e, pingPacket, ping{ + IP: t.self.IP.String(), + Port: uint16(t.self.TCPPort), + Expiration: uint64(time.Now().Add(expiration).Unix()), + }) + return <-errc +} + +// findnode sends a findnode request to the given node and waits until +// the node has sent up to k neighbors. +func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) { + nodes := make([]*Node, 0, bucketSize) + nreceived := 0 + errc := t.pending(to.ID, neighborsPacket, func(r interface{}) bool { + reply := r.(*neighbors) + for _, n := range reply.Nodes { + nreceived++ + if n.isValid() { + nodes = append(nodes, n) + } + } + return nreceived >= bucketSize + }) + + t.send(to, findnodePacket, findnode{ + Target: target, + Expiration: uint64(time.Now().Add(expiration).Unix()), + }) + err := <-errc + return nodes, err +} + +// pending adds a reply callback to the pending reply queue. +// see the documentation of type pending for a detailed explanation. +func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error { + ch := make(chan error, 1) + p := &pending{from: id, ptype: ptype, callback: callback, errc: ch} + select { + case t.addpending <- p: + // loop will handle it + case <-t.closing: + ch <- errClosed + } + return ch +} + +// loop runs in its own goroutin. it keeps track of +// the refresh timer and the pending reply queue. +func (t *udp) loop() { + var ( + pending []*pending + nextDeadline time.Time + timeout = time.NewTimer(0) + refresh = time.NewTicker(refreshInterval) + ) + <-timeout.C // ignore first timeout + defer refresh.Stop() + defer timeout.Stop() + + rearmTimeout := func() { + if len(pending) == 0 || nextDeadline == pending[0].deadline { + return + } + nextDeadline = pending[0].deadline + timeout.Reset(nextDeadline.Sub(time.Now())) + } + + for { + select { + case <-refresh.C: + go t.refresh() + + case <-t.closing: + for _, p := range pending { + p.errc <- errClosed + } + return + + case p := <-t.addpending: + p.deadline = time.Now().Add(respTimeout) + pending = append(pending, p) + rearmTimeout() + + case reply := <-t.replies: + // run matching callbacks, remove if they return false. + for i := 0; i < len(pending); i++ { + p := pending[i] + if reply.from == p.from && reply.ptype == p.ptype && p.callback(reply.data) { + p.errc <- nil + copy(pending[i:], pending[i+1:]) + pending = pending[:len(pending)-1] + i-- + } + } + rearmTimeout() + + case now := <-timeout.C: + // notify and remove callbacks whose deadline is in the past. + i := 0 + for ; i < len(pending) && now.After(pending[i].deadline); i++ { + pending[i].errc <- errTimeout + } + if i > 0 { + copy(pending, pending[i:]) + pending = pending[:len(pending)-i] + } + rearmTimeout() + } + } +} + +const ( + macSize = 256 / 8 + sigSize = 520 / 8 + headSize = macSize + sigSize // space of packet frame data +) + +var headSpace = make([]byte, headSize) + +func (t *udp) send(to *Node, ptype byte, req interface{}) error { + b := new(bytes.Buffer) + b.Write(headSpace) + b.WriteByte(ptype) + if err := rlp.Encode(b, req); err != nil { + log.Errorln("error encoding packet:", err) + return err + } + + packet := b.Bytes() + sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), t.priv) + if err != nil { + log.Errorln("could not sign packet:", err) + return err + } + copy(packet[macSize:], sig) + // add the hash to the front. Note: this doesn't protect the + // packet in any way. Our public key will be part of this hash in + // the future. + copy(packet, crypto.Sha3(packet[macSize:])) + + toaddr := &net.UDPAddr{IP: to.IP, Port: to.DiscPort} + log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req) + if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil { + log.DebugDetailln("UDP send failed:", err) + } + return err +} + +// readLoop runs in its own goroutine. it handles incoming UDP packets. +func (t *udp) readLoop() { + defer t.conn.Close() + buf := make([]byte, 4096) // TODO: good buffer size + for { + nbytes, from, err := t.conn.ReadFromUDP(buf) + if err != nil { + return + } + if err := t.packetIn(from, buf[:nbytes]); err != nil { + log.Debugf("Bad packet from %v: %v\n", from, err) + } + } +} + +func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error { + if len(buf) < headSize+1 { + return errPacketTooSmall + } + hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:] + shouldhash := crypto.Sha3(buf[macSize:]) + if !bytes.Equal(hash, shouldhash) { + return errBadHash + } + fromID, err := recoverNodeID(crypto.Sha3(buf[headSize:]), sig) + if err != nil { + return err + } + + var req interface { + handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error + } + switch ptype := sigdata[0]; ptype { + case pingPacket: + req = new(ping) + case pongPacket: + req = new(pong) + case findnodePacket: + req = new(findnode) + case neighborsPacket: + req = new(neighbors) + default: + return fmt.Errorf("unknown type: %d", ptype) + } + if err := rlp.Decode(bytes.NewReader(sigdata[1:]), req); err != nil { + return err + } + log.DebugDetailf("<<< %v %T %v\n", from, req, req) + return req.handle(t, from, fromID, hash) +} + +func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { + if expired(req.Expiration) { + return errExpired + } + t.mutex.Lock() + // Note: we're ignoring the provided IP address right now + n := t.bumpOrAdd(fromID, from) + if req.Port != 0 { + n.TCPPort = int(req.Port) + } + t.mutex.Unlock() + + t.send(n, pongPacket, pong{ + ReplyTok: mac, + Expiration: uint64(time.Now().Add(expiration).Unix()), + }) + return nil +} + +func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { + if expired(req.Expiration) { + return errExpired + } + t.mutex.Lock() + t.bump(fromID) + t.mutex.Unlock() + + t.replies <- reply{fromID, pongPacket, req} + return nil +} + +func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { + if expired(req.Expiration) { + return errExpired + } + t.mutex.Lock() + e := t.bumpOrAdd(fromID, from) + closest := t.closest(req.Target, bucketSize).entries + t.mutex.Unlock() + + t.send(e, neighborsPacket, neighbors{ + Nodes: closest, + Expiration: uint64(time.Now().Add(expiration).Unix()), + }) + return nil +} + +func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { + if expired(req.Expiration) { + return errExpired + } + t.mutex.Lock() + t.bump(fromID) + t.add(req.Nodes) + t.mutex.Unlock() + + t.replies <- reply{fromID, neighborsPacket, req} + return nil +} + +func expired(ts uint64) bool { + return time.Unix(int64(ts), 0).Before(time.Now()) +} diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go new file mode 100644 index 000000000..0a8ff6358 --- /dev/null +++ b/p2p/discover/udp_test.go @@ -0,0 +1,211 @@ +package discover + +import ( + "fmt" + logpkg "log" + "net" + "os" + "testing" + "time" + + "github.com/ethereum/go-ethereum/logger" +) + +func init() { + logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, logpkg.LstdFlags, logger.ErrorLevel)) +} + +func TestUDP_ping(t *testing.T) { + t.Parallel() + + n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil) + n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil) + defer n1.Close() + defer n2.Close() + + if err := n1.net.ping(n2.self); err != nil { + t.Fatalf("ping error: %v", err) + } + if find(n2, n1.self.ID) == nil { + t.Errorf("node 2 does not contain id of node 1") + } + if e := find(n1, n2.self.ID); e != nil { + t.Errorf("node 1 does contains id of node 2: %v", e) + } +} + +func find(tab *Table, id NodeID) *Node { + for _, b := range tab.buckets { + for _, e := range b.entries { + if e.ID == id { + return e + } + } + } + return nil +} + +func TestUDP_findnode(t *testing.T) { + t.Parallel() + + n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil) + n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil) + defer n1.Close() + defer n2.Close() + + // put a few nodes into n2. the exact distribution shouldn't + // matter much, altough we need to take care not to overflow + // any bucket. + target := randomID(n1.self.ID, 100) + nodes := &nodesByDistance{target: target} + for i := 0; i < bucketSize; i++ { + n2.add([]*Node{&Node{ + IP: net.IP{1, 2, 3, byte(i)}, + DiscPort: i + 2, + TCPPort: i + 2, + ID: randomID(n2.self.ID, i+2), + }}) + } + n2.add(nodes.entries) + n2.bumpOrAdd(n1.self.ID, &net.UDPAddr{IP: n1.self.IP, Port: n1.self.DiscPort}) + expected := n2.closest(target, bucketSize) + + err := runUDP(10, func() error { + result, _ := n1.net.findnode(n2.self, target) + if len(result) != bucketSize { + return fmt.Errorf("wrong number of results: got %d, want %d", len(result), bucketSize) + } + for i := range result { + if result[i].ID != expected.entries[i].ID { + return fmt.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, result[i], expected.entries[i]) + } + } + return nil + }) + if err != nil { + t.Error(err) + } +} + +func TestUDP_replytimeout(t *testing.T) { + t.Parallel() + + // reserve a port so we don't talk to an existing service by accident + addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") + fd, err := net.ListenUDP("udp", addr) + if err != nil { + t.Fatal(err) + } + defer fd.Close() + + n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil) + defer n1.Close() + n2 := n1.bumpOrAdd(randomID(n1.self.ID, 10), fd.LocalAddr().(*net.UDPAddr)) + + if err := n1.net.ping(n2); err != errTimeout { + t.Error("expected timeout error, got", err) + } + + if result, err := n1.net.findnode(n2, n1.self.ID); err != errTimeout { + t.Error("expected timeout error, got", err) + } else if len(result) > 0 { + t.Error("expected empty result, got", result) + } +} + +func TestUDP_findnodeMultiReply(t *testing.T) { + t.Parallel() + + n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil) + n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil) + udp2 := n2.net.(*udp) + defer n1.Close() + defer n2.Close() + + err := runUDP(10, func() error { + nodes := make([]*Node, bucketSize) + for i := range nodes { + nodes[i] = &Node{ + IP: net.IP{1, 2, 3, 4}, + DiscPort: i + 1, + TCPPort: i + 1, + ID: randomID(n2.self.ID, i+1), + } + } + + // ask N2 for neighbors. it will send an empty reply back. + // the request will wait for up to bucketSize replies. + resultc := make(chan []*Node) + errc := make(chan error) + go func() { + ns, err := n1.net.findnode(n2.self, n1.self.ID) + if err != nil { + errc <- err + } else { + resultc <- ns + } + }() + + // send a few more neighbors packets to N1. + // it should collect those. + for end := 0; end < len(nodes); { + off := end + if end = end + 5; end > len(nodes) { + end = len(nodes) + } + udp2.send(n1.self, neighborsPacket, neighbors{ + Nodes: nodes[off:end], + Expiration: uint64(time.Now().Add(10 * time.Second).Unix()), + }) + } + + // check that they are all returned. we cannot just check for + // equality because they might not be returned in the order they + // were sent. + var result []*Node + select { + case result = <-resultc: + case err := <-errc: + return err + } + if hasDuplicates(result) { + return fmt.Errorf("result slice contains duplicates") + } + if len(result) != len(nodes) { + return fmt.Errorf("wrong number of nodes returned: got %d, want %d", len(result), len(nodes)) + } + matched := make(map[NodeID]bool) + for _, n := range result { + for _, expn := range nodes { + if n.ID == expn.ID { // && bytes.Equal(n.Addr.IP, expn.Addr.IP) && n.Addr.Port == expn.Addr.Port { + matched[n.ID] = true + } + } + } + if len(matched) != len(nodes) { + return fmt.Errorf("wrong number of matching nodes: got %d, want %d", len(matched), len(nodes)) + } + return nil + }) + if err != nil { + t.Error(err) + } +} + +// runUDP runs a test n times and returns an error if the test failed +// in all n runs. This is necessary because UDP is unreliable even for +// connections on the local machine, causing test failures. +func runUDP(n int, test func() error) error { + errcount := 0 + errors := "" + for i := 0; i < n; i++ { + if err := test(); err != nil { + errors += fmt.Sprintf("\n#%d: %v", i, err) + errcount++ + } + } + if errcount == n { + return fmt.Errorf("failed on all %d iterations:%s", n, errors) + } + return nil +} diff --git a/p2p/handshake.go b/p2p/handshake.go new file mode 100644 index 000000000..7fc497517 --- /dev/null +++ b/p2p/handshake.go @@ -0,0 +1,436 @@ +package p2p + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "errors" + "fmt" + "hash" + "io" + "net" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/ecies" + "github.com/ethereum/go-ethereum/crypto/secp256k1" + "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/rlp" +) + +const ( + sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2 + sigLen = 65 // elliptic S256 + pubLen = 64 // 512 bit pubkey in uncompressed representation without format byte + shaLen = 32 // hash length (for nonce etc) + + authMsgLen = sigLen + shaLen + pubLen + shaLen + 1 + authRespLen = pubLen + shaLen + 1 + + eciesBytes = 65 + 16 + 32 + encAuthMsgLen = authMsgLen + eciesBytes // size of the final ECIES payload sent as initiator's handshake + encAuthRespLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake +) + +// conn represents a remote connection after encryption handshake +// and protocol handshake have completed. +// +// The MsgReadWriter is usually layered as follows: +// +// netWrapper (I/O timeouts, thread-safe ReadMsg, WriteMsg) +// rlpxFrameRW (message encoding, encryption, authentication) +// bufio.ReadWriter (buffering) +// net.Conn (network I/O) +// +type conn struct { + MsgReadWriter + *protoHandshake +} + +// secrets represents the connection secrets +// which are negotiated during the encryption handshake. +type secrets struct { + RemoteID discover.NodeID + AES, MAC []byte + EgressMAC, IngressMAC hash.Hash + Token []byte +} + +// protoHandshake is the RLP structure of the protocol handshake. +type protoHandshake struct { + Version uint64 + Name string + Caps []Cap + ListenPort uint64 + ID discover.NodeID +} + +// setupConn starts a protocol session on the given connection. +// It runs the encryption handshake and the protocol handshake. +// If dial is non-nil, the connection the local node is the initiator. +func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) { + if dial == nil { + return setupInboundConn(fd, prv, our) + } else { + return setupOutboundConn(fd, prv, our, dial) + } +} + +func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake) (*conn, error) { + secrets, err := receiverEncHandshake(fd, prv, nil) + if err != nil { + return nil, fmt.Errorf("encryption handshake failed: %v", err) + } + + // Run the protocol handshake using authenticated messages. + rw := newRlpxFrameRW(fd, secrets) + rhs, err := readProtocolHandshake(rw, our) + if err != nil { + return nil, err + } + if rhs.ID != secrets.RemoteID { + return nil, errors.New("node ID in protocol handshake does not match encryption handshake") + } + // TODO: validate that handshake node ID matches + if err := writeProtocolHandshake(rw, our); err != nil { + return nil, fmt.Errorf("protocol write error: %v", err) + } + return &conn{rw, rhs}, nil +} + +func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) { + secrets, err := initiatorEncHandshake(fd, prv, dial.ID, nil) + if err != nil { + return nil, fmt.Errorf("encryption handshake failed: %v", err) + } + + // Run the protocol handshake using authenticated messages. + rw := newRlpxFrameRW(fd, secrets) + if err := writeProtocolHandshake(rw, our); err != nil { + return nil, fmt.Errorf("protocol write error: %v", err) + } + rhs, err := readProtocolHandshake(rw, our) + if err != nil { + return nil, fmt.Errorf("protocol handshake read error: %v", err) + } + if rhs.ID != dial.ID { + return nil, errors.New("dialed node id mismatch") + } + return &conn{rw, rhs}, nil +} + +// encHandshake contains the state of the encryption handshake. +type encHandshake struct { + initiator bool + remoteID discover.NodeID + + remotePub *ecies.PublicKey // remote-pubk + initNonce, respNonce []byte // nonce + randomPrivKey *ecies.PrivateKey // ecdhe-random + remoteRandomPub *ecies.PublicKey // ecdhe-random-pubk +} + +// secrets is called after the handshake is completed. +// It extracts the connection secrets from the handshake values. +func (h *encHandshake) secrets(auth, authResp []byte) (secrets, error) { + ecdheSecret, err := h.randomPrivKey.GenerateShared(h.remoteRandomPub, sskLen, sskLen) + if err != nil { + return secrets{}, err + } + + // derive base secrets from ephemeral key agreement + sharedSecret := crypto.Sha3(ecdheSecret, crypto.Sha3(h.respNonce, h.initNonce)) + aesSecret := crypto.Sha3(ecdheSecret, sharedSecret) + s := secrets{ + RemoteID: h.remoteID, + AES: aesSecret, + MAC: crypto.Sha3(ecdheSecret, aesSecret), + Token: crypto.Sha3(sharedSecret), + } + + // setup sha3 instances for the MACs + mac1 := sha3.NewKeccak256() + mac1.Write(xor(s.MAC, h.respNonce)) + mac1.Write(auth) + mac2 := sha3.NewKeccak256() + mac2.Write(xor(s.MAC, h.initNonce)) + mac2.Write(authResp) + if h.initiator { + s.EgressMAC, s.IngressMAC = mac1, mac2 + } else { + s.EgressMAC, s.IngressMAC = mac2, mac1 + } + + return s, nil +} + +func (h *encHandshake) ecdhShared(prv *ecdsa.PrivateKey) ([]byte, error) { + return ecies.ImportECDSA(prv).GenerateShared(h.remotePub, sskLen, sskLen) +} + +// initiatorEncHandshake negotiates a session token on conn. +// it should be called on the dialing side of the connection. +// +// prv is the local client's private key. +// token is the token from a previous session with this node. +func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID discover.NodeID, token []byte) (s secrets, err error) { + h, err := newInitiatorHandshake(remoteID) + if err != nil { + return s, err + } + auth, err := h.authMsg(prv, token) + if err != nil { + return s, err + } + if _, err = conn.Write(auth); err != nil { + return s, err + } + + response := make([]byte, encAuthRespLen) + if _, err = io.ReadFull(conn, response); err != nil { + return s, err + } + if err := h.decodeAuthResp(response, prv); err != nil { + return s, err + } + return h.secrets(auth, response) +} + +func newInitiatorHandshake(remoteID discover.NodeID) (*encHandshake, error) { + // generate random initiator nonce + n := make([]byte, shaLen) + if _, err := rand.Read(n); err != nil { + return nil, err + } + // generate random keypair to use for signing + randpriv, err := ecies.GenerateKey(rand.Reader, crypto.S256(), nil) + if err != nil { + return nil, err + } + rpub, err := remoteID.Pubkey() + if err != nil { + return nil, fmt.Errorf("bad remoteID: %v", err) + } + h := &encHandshake{ + initiator: true, + remoteID: remoteID, + remotePub: ecies.ImportECDSAPublic(rpub), + initNonce: n, + randomPrivKey: randpriv, + } + return h, nil +} + +// authMsg creates an encrypted initiator handshake message. +func (h *encHandshake) authMsg(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) { + var tokenFlag byte + if token == nil { + // no session token found means we need to generate shared secret. + // ecies shared secret is used as initial session token for new peers + // generate shared key from prv and remote pubkey + var err error + if token, err = h.ecdhShared(prv); err != nil { + return nil, err + } + } else { + // for known peers, we use stored token from the previous session + tokenFlag = 0x01 + } + + // sign known message: + // ecdh-shared-secret^nonce for new peers + // token^nonce for old peers + signed := xor(token, h.initNonce) + signature, err := crypto.Sign(signed, h.randomPrivKey.ExportECDSA()) + if err != nil { + return nil, err + } + + // encode auth message + // signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag + msg := make([]byte, authMsgLen) + n := copy(msg, signature) + n += copy(msg[n:], crypto.Sha3(exportPubkey(&h.randomPrivKey.PublicKey))) + n += copy(msg[n:], crypto.FromECDSAPub(&prv.PublicKey)[1:]) + n += copy(msg[n:], h.initNonce) + msg[n] = tokenFlag + + // encrypt auth message using remote-pubk + return ecies.Encrypt(rand.Reader, h.remotePub, msg, nil, nil) +} + +// decodeAuthResp decode an encrypted authentication response message. +func (h *encHandshake) decodeAuthResp(auth []byte, prv *ecdsa.PrivateKey) error { + msg, err := crypto.Decrypt(prv, auth) + if err != nil { + return fmt.Errorf("could not decrypt auth response (%v)", err) + } + h.respNonce = msg[pubLen : pubLen+shaLen] + h.remoteRandomPub, err = importPublicKey(msg[:pubLen]) + if err != nil { + return err + } + // ignore token flag for now + return nil +} + +// receiverEncHandshake negotiates a session token on conn. +// it should be called on the listening side of the connection. +// +// prv is the local client's private key. +// token is the token from a previous session with this node. +func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, token []byte) (s secrets, err error) { + // read remote auth sent by initiator. + auth := make([]byte, encAuthMsgLen) + if _, err := io.ReadFull(conn, auth); err != nil { + return s, err + } + h, err := decodeAuthMsg(prv, token, auth) + if err != nil { + return s, err + } + + // send auth response + resp, err := h.authResp(prv, token) + if err != nil { + return s, err + } + if _, err = conn.Write(resp); err != nil { + return s, err + } + + return h.secrets(auth, resp) +} + +func decodeAuthMsg(prv *ecdsa.PrivateKey, token []byte, auth []byte) (*encHandshake, error) { + var err error + h := new(encHandshake) + // generate random keypair for session + h.randomPrivKey, err = ecies.GenerateKey(rand.Reader, crypto.S256(), nil) + if err != nil { + return nil, err + } + // generate random nonce + h.respNonce = make([]byte, shaLen) + if _, err = rand.Read(h.respNonce); err != nil { + return nil, err + } + + msg, err := crypto.Decrypt(prv, auth) + if err != nil { + return nil, fmt.Errorf("could not decrypt auth message (%v)", err) + } + + // decode message parameters + // signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag + h.initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1] + copy(h.remoteID[:], msg[sigLen+shaLen:sigLen+shaLen+pubLen]) + rpub, err := h.remoteID.Pubkey() + if err != nil { + return nil, fmt.Errorf("bad remoteID: %#v", err) + } + h.remotePub = ecies.ImportECDSAPublic(rpub) + + // recover remote random pubkey from signed message. + if token == nil { + // TODO: it is an error if the initiator has a token and we don't. check that. + + // no session token means we need to generate shared secret. + // ecies shared secret is used as initial session token for new peers. + // generate shared key from prv and remote pubkey. + if token, err = h.ecdhShared(prv); err != nil { + return nil, err + } + } + signedMsg := xor(token, h.initNonce) + remoteRandomPub, err := secp256k1.RecoverPubkey(signedMsg, msg[:sigLen]) + if err != nil { + return nil, err + } + h.remoteRandomPub, _ = importPublicKey(remoteRandomPub) + return h, nil +} + +// authResp generates the encrypted authentication response message. +func (h *encHandshake) authResp(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) { + // responder auth message + // E(remote-pubk, ecdhe-random-pubk || nonce || 0x0) + resp := make([]byte, authRespLen) + n := copy(resp, exportPubkey(&h.randomPrivKey.PublicKey)) + n += copy(resp[n:], h.respNonce) + if token == nil { + resp[n] = 0 + } else { + resp[n] = 1 + } + // encrypt using remote-pubk + return ecies.Encrypt(rand.Reader, h.remotePub, resp, nil, nil) +} + +// importPublicKey unmarshals 512 bit public keys. +func importPublicKey(pubKey []byte) (*ecies.PublicKey, error) { + var pubKey65 []byte + switch len(pubKey) { + case 64: + // add 'uncompressed key' flag + pubKey65 = append([]byte{0x04}, pubKey...) + case 65: + pubKey65 = pubKey + default: + return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey)) + } + // TODO: fewer pointless conversions + return ecies.ImportECDSAPublic(crypto.ToECDSAPub(pubKey65)), nil +} + +func exportPubkey(pub *ecies.PublicKey) []byte { + if pub == nil { + panic("nil pubkey") + } + return elliptic.Marshal(pub.Curve, pub.X, pub.Y)[1:] +} + +func xor(one, other []byte) (xor []byte) { + xor = make([]byte, len(one)) + for i := 0; i < len(one); i++ { + xor[i] = one[i] ^ other[i] + } + return xor +} + +func writeProtocolHandshake(w MsgWriter, our *protoHandshake) error { + return EncodeMsg(w, handshakeMsg, our.Version, our.Name, our.Caps, our.ListenPort, our.ID[:]) +} + +func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, error) { + // read and handle remote handshake + msg, err := r.ReadMsg() + if err != nil { + return nil, err + } + if msg.Code == discMsg { + // disconnect before protocol handshake is valid according to the + // spec and we send it ourself if Server.addPeer fails. + var reason DiscReason + rlp.Decode(msg.Payload, &reason) + return nil, discRequestedError(reason) + } + if msg.Code != handshakeMsg { + return nil, fmt.Errorf("expected handshake, got %x", msg.Code) + } + if msg.Size > baseProtocolMaxMsgSize { + return nil, fmt.Errorf("message too big (%d > %d)", msg.Size, baseProtocolMaxMsgSize) + } + var hs protoHandshake + if err := msg.Decode(&hs); err != nil { + return nil, err + } + // validate handshake info + if hs.Version != our.Version { + return nil, newPeerError(errP2PVersionMismatch, "required version %d, received %d\n", baseProtocolVersion, hs.Version) + } + if (hs.ID == discover.NodeID{}) { + return nil, newPeerError(errPubkeyInvalid, "missing") + } + return &hs, nil +} diff --git a/p2p/handshake_test.go b/p2p/handshake_test.go new file mode 100644 index 000000000..19423bb82 --- /dev/null +++ b/p2p/handshake_test.go @@ -0,0 +1,171 @@ +package p2p + +import ( + "bytes" + "crypto/rand" + "fmt" + "net" + "reflect" + "testing" + "time" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/ecies" + "github.com/ethereum/go-ethereum/p2p/discover" +) + +func TestSharedSecret(t *testing.T) { + prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader) + pub0 := &prv0.PublicKey + prv1, _ := crypto.GenerateKey() + pub1 := &prv1.PublicKey + + ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen) + if err != nil { + return + } + ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen) + if err != nil { + return + } + t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1) + if !bytes.Equal(ss0, ss1) { + t.Errorf("dont match :(") + } +} + +func TestEncHandshake(t *testing.T) { + for i := 0; i < 20; i++ { + start := time.Now() + if err := testEncHandshake(nil); err != nil { + t.Fatalf("i=%d %v", i, err) + } + t.Logf("(without token) %d %v\n", i+1, time.Since(start)) + } + + for i := 0; i < 20; i++ { + tok := make([]byte, shaLen) + rand.Reader.Read(tok) + start := time.Now() + if err := testEncHandshake(tok); err != nil { + t.Fatalf("i=%d %v", i, err) + } + t.Logf("(with token) %d %v\n", i+1, time.Since(start)) + } +} + +func testEncHandshake(token []byte) error { + type result struct { + side string + s secrets + err error + } + var ( + prv0, _ = crypto.GenerateKey() + prv1, _ = crypto.GenerateKey() + rw0, rw1 = net.Pipe() + output = make(chan result) + ) + + go func() { + r := result{side: "initiator"} + defer func() { output <- r }() + + pub1s := discover.PubkeyID(&prv1.PublicKey) + r.s, r.err = initiatorEncHandshake(rw0, prv0, pub1s, token) + if r.err != nil { + return + } + id1 := discover.PubkeyID(&prv1.PublicKey) + if r.s.RemoteID != id1 { + r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id1) + } + }() + go func() { + r := result{side: "receiver"} + defer func() { output <- r }() + + r.s, r.err = receiverEncHandshake(rw1, prv1, token) + if r.err != nil { + return + } + id0 := discover.PubkeyID(&prv0.PublicKey) + if r.s.RemoteID != id0 { + r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id0) + } + }() + + // wait for results from both sides + r1, r2 := <-output, <-output + + if r1.err != nil { + return fmt.Errorf("%s side error: %v", r1.side, r1.err) + } + if r2.err != nil { + return fmt.Errorf("%s side error: %v", r2.side, r2.err) + } + + // don't compare remote node IDs + r1.s.RemoteID, r2.s.RemoteID = discover.NodeID{}, discover.NodeID{} + // flip MACs on one of them so they compare equal + r1.s.EgressMAC, r1.s.IngressMAC = r1.s.IngressMAC, r1.s.EgressMAC + if !reflect.DeepEqual(r1.s, r2.s) { + return fmt.Errorf("secrets mismatch:\n t1: %#v\n t2: %#v", r1.s, r2.s) + } + return nil +} + +func TestSetupConn(t *testing.T) { + prv0, _ := crypto.GenerateKey() + prv1, _ := crypto.GenerateKey() + node0 := &discover.Node{ + ID: discover.PubkeyID(&prv0.PublicKey), + IP: net.IP{1, 2, 3, 4}, + TCPPort: 33, + } + node1 := &discover.Node{ + ID: discover.PubkeyID(&prv1.PublicKey), + IP: net.IP{5, 6, 7, 8}, + TCPPort: 44, + } + hs0 := &protoHandshake{ + Version: baseProtocolVersion, + ID: node0.ID, + Caps: []Cap{{"a", 0}, {"b", 2}}, + } + hs1 := &protoHandshake{ + Version: baseProtocolVersion, + ID: node1.ID, + Caps: []Cap{{"c", 1}, {"d", 3}}, + } + fd0, fd1 := net.Pipe() + + done := make(chan struct{}) + go func() { + defer close(done) + conn0, err := setupConn(fd0, prv0, hs0, node1) + if err != nil { + t.Errorf("outbound side error: %v", err) + return + } + if conn0.ID != node1.ID { + t.Errorf("outbound conn id mismatch: got %v, want %v", conn0.ID, node1.ID) + } + if !reflect.DeepEqual(conn0.Caps, hs1.Caps) { + t.Errorf("outbound caps mismatch: got %v, want %v", conn0.Caps, hs1.Caps) + } + }() + + conn1, err := setupConn(fd1, prv1, hs1, nil) + if err != nil { + t.Fatalf("inbound side error: %v", err) + } + if conn1.ID != node0.ID { + t.Errorf("inbound conn id mismatch: got %v, want %v", conn1.ID, node0.ID) + } + if !reflect.DeepEqual(conn1.Caps, hs0.Caps) { + t.Errorf("inbound caps mismatch: got %v, want %v", conn1.Caps, hs0.Caps) + } + + <-done +} diff --git a/p2p/message.go b/p2p/message.go new file mode 100644 index 000000000..f88c31d1d --- /dev/null +++ b/p2p/message.go @@ -0,0 +1,210 @@ +package p2p + +import ( + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/ethereum/go-ethereum/ethutil" + "github.com/ethereum/go-ethereum/rlp" +) + +// Msg defines the structure of a p2p message. +// +// Note that a Msg can only be sent once since the Payload reader is +// consumed during sending. It is not possible to create a Msg and +// send it any number of times. If you want to reuse an encoded +// structure, encode the payload into a byte array and create a +// separate Msg with a bytes.Reader as Payload for each send. +type Msg struct { + Code uint64 + Size uint32 // size of the paylod + Payload io.Reader +} + +// NewMsg creates an RLP-encoded message with the given code. +func NewMsg(code uint64, params ...interface{}) Msg { + p := bytes.NewReader(ethutil.Encode(params)) + return Msg{Code: code, Size: uint32(p.Len()), Payload: p} +} + +// Decode parse the RLP content of a message into +// the given value, which must be a pointer. +// +// For the decoding rules, please see package rlp. +func (msg Msg) Decode(val interface{}) error { + if err := rlp.Decode(msg.Payload, val); err != nil { + return newPeerError(errInvalidMsg, "(code %#x) (size %d) %v", msg.Code, msg.Size, err) + } + return nil +} + +func (msg Msg) String() string { + return fmt.Sprintf("msg #%v (%v bytes)", msg.Code, msg.Size) +} + +// Discard reads any remaining payload data into a black hole. +func (msg Msg) Discard() error { + _, err := io.Copy(ioutil.Discard, msg.Payload) + return err +} + +type MsgReader interface { + ReadMsg() (Msg, error) +} + +type MsgWriter interface { + // WriteMsg sends a message. It will block until the message's + // Payload has been consumed by the other end. + // + // Note that messages can be sent only once because their + // payload reader is drained. + WriteMsg(Msg) error +} + +// MsgReadWriter provides reading and writing of encoded messages. +// Implementations should ensure that ReadMsg and WriteMsg can be +// called simultaneously from multiple goroutines. +type MsgReadWriter interface { + MsgReader + MsgWriter +} + +// EncodeMsg writes an RLP-encoded message with the given code and +// data elements. +func EncodeMsg(w MsgWriter, code uint64, data ...interface{}) error { + return w.WriteMsg(NewMsg(code, data...)) +} + +// netWrapper wrapsa MsgReadWriter with locks around +// ReadMsg/WriteMsg and applies read/write deadlines. +type netWrapper struct { + rmu, wmu sync.Mutex + + rtimeout, wtimeout time.Duration + conn net.Conn + wrapped MsgReadWriter +} + +func (rw *netWrapper) ReadMsg() (Msg, error) { + rw.rmu.Lock() + defer rw.rmu.Unlock() + rw.conn.SetReadDeadline(time.Now().Add(rw.rtimeout)) + return rw.wrapped.ReadMsg() +} + +func (rw *netWrapper) WriteMsg(msg Msg) error { + rw.wmu.Lock() + defer rw.wmu.Unlock() + rw.conn.SetWriteDeadline(time.Now().Add(rw.wtimeout)) + return rw.wrapped.WriteMsg(msg) +} + +// eofSignal wraps a reader with eof signaling. the eof channel is +// closed when the wrapped reader returns an error or when count bytes +// have been read. +type eofSignal struct { + wrapped io.Reader + count uint32 // number of bytes left + eof chan<- struct{} +} + +// note: when using eofSignal to detect whether a message payload +// has been read, Read might not be called for zero sized messages. +func (r *eofSignal) Read(buf []byte) (int, error) { + if r.count == 0 { + if r.eof != nil { + r.eof <- struct{}{} + r.eof = nil + } + return 0, io.EOF + } + + max := len(buf) + if int(r.count) < len(buf) { + max = int(r.count) + } + n, err := r.wrapped.Read(buf[:max]) + r.count -= uint32(n) + if (err != nil || r.count == 0) && r.eof != nil { + r.eof <- struct{}{} // tell Peer that msg has been consumed + r.eof = nil + } + return n, err +} + +// MsgPipe creates a message pipe. Reads on one end are matched +// with writes on the other. The pipe is full-duplex, both ends +// implement MsgReadWriter. +func MsgPipe() (*MsgPipeRW, *MsgPipeRW) { + var ( + c1, c2 = make(chan Msg), make(chan Msg) + closing = make(chan struct{}) + closed = new(int32) + rw1 = &MsgPipeRW{c1, c2, closing, closed} + rw2 = &MsgPipeRW{c2, c1, closing, closed} + ) + return rw1, rw2 +} + +// ErrPipeClosed is returned from pipe operations after the +// pipe has been closed. +var ErrPipeClosed = errors.New("p2p: read or write on closed message pipe") + +// MsgPipeRW is an endpoint of a MsgReadWriter pipe. +type MsgPipeRW struct { + w chan<- Msg + r <-chan Msg + closing chan struct{} + closed *int32 +} + +// WriteMsg sends a messsage on the pipe. +// It blocks until the receiver has consumed the message payload. +func (p *MsgPipeRW) WriteMsg(msg Msg) error { + if atomic.LoadInt32(p.closed) == 0 { + consumed := make(chan struct{}, 1) + msg.Payload = &eofSignal{msg.Payload, msg.Size, consumed} + select { + case p.w <- msg: + if msg.Size > 0 { + // wait for payload read or discard + <-consumed + } + return nil + case <-p.closing: + } + } + return ErrPipeClosed +} + +// ReadMsg returns a message sent on the other end of the pipe. +func (p *MsgPipeRW) ReadMsg() (Msg, error) { + if atomic.LoadInt32(p.closed) == 0 { + select { + case msg := <-p.r: + return msg, nil + case <-p.closing: + } + } + return Msg{}, ErrPipeClosed +} + +// Close unblocks any pending ReadMsg and WriteMsg calls on both ends +// of the pipe. They will return ErrPipeClosed. Note that Close does +// not interrupt any reads from a message payload. +func (p *MsgPipeRW) Close() error { + if atomic.AddInt32(p.closed, 1) != 1 { + // someone else is already closing + atomic.StoreInt32(p.closed, 1) // avoid overflow + return nil + } + close(p.closing) + return nil +} diff --git a/p2p/message_test.go b/p2p/message_test.go new file mode 100644 index 000000000..31ed61d87 --- /dev/null +++ b/p2p/message_test.go @@ -0,0 +1,151 @@ +package p2p + +import ( + "bytes" + "encoding/hex" + "fmt" + "io" + "io/ioutil" + "runtime" + "strings" + "testing" + "time" +) + +func TestNewMsg(t *testing.T) { + msg := NewMsg(3, 1, "000") + if msg.Code != 3 { + t.Errorf("incorrect code %d, want %d", msg.Code) + } + expect := unhex("c50183303030") + if msg.Size != uint32(len(expect)) { + t.Errorf("incorrect size %d, want %d", msg.Size, len(expect)) + } + pl, _ := ioutil.ReadAll(msg.Payload) + if !bytes.Equal(pl, expect) { + t.Errorf("incorrect payload content, got %x, want %x", pl, expect) + } +} + +func ExampleMsgPipe() { + rw1, rw2 := MsgPipe() + go func() { + EncodeMsg(rw1, 8, []byte{0, 0}) + EncodeMsg(rw1, 5, []byte{1, 1}) + rw1.Close() + }() + + for { + msg, err := rw2.ReadMsg() + if err != nil { + break + } + var data [1][]byte + msg.Decode(&data) + fmt.Printf("msg: %d, %x\n", msg.Code, data[0]) + } + // Output: + // msg: 8, 0000 + // msg: 5, 0101 +} + +func TestMsgPipeUnblockWrite(t *testing.T) { +loop: + for i := 0; i < 100; i++ { + rw1, rw2 := MsgPipe() + done := make(chan struct{}) + go func() { + if err := EncodeMsg(rw1, 1); err == nil { + t.Error("EncodeMsg returned nil error") + } else if err != ErrPipeClosed { + t.Error("EncodeMsg returned wrong error: got %v, want %v", err, ErrPipeClosed) + } + close(done) + }() + + // this call should ensure that EncodeMsg is waiting to + // deliver sometimes. if this isn't done, Close is likely to + // be executed before EncodeMsg starts and then we won't test + // all the cases. + runtime.Gosched() + + rw2.Close() + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Errorf("write didn't unblock") + break loop + } + } +} + +// This test should panic if concurrent close isn't implemented correctly. +func TestMsgPipeConcurrentClose(t *testing.T) { + rw1, _ := MsgPipe() + for i := 0; i < 10; i++ { + go rw1.Close() + } +} + +func TestEOFSignal(t *testing.T) { + rb := make([]byte, 10) + + // empty reader + eof := make(chan struct{}, 1) + sig := &eofSignal{new(bytes.Buffer), 0, eof} + if n, err := sig.Read(rb); n != 0 || err != io.EOF { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + default: + t.Error("EOF chan not signaled") + } + + // count before error + eof = make(chan struct{}, 1) + sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof} + if n, err := sig.Read(rb); n != 4 || err != nil { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + default: + t.Error("EOF chan not signaled") + } + + // error before count + eof = make(chan struct{}, 1) + sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof} + if n, err := sig.Read(rb); n != 4 || err != nil { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + if n, err := sig.Read(rb); n != 0 || err != io.EOF { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + default: + t.Error("EOF chan not signaled") + } + + // no signal if neither occurs + eof = make(chan struct{}, 1) + sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof} + if n, err := sig.Read(rb); n != 10 || err != nil { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + t.Error("unexpected EOF signal") + default: + } +} + +func unhex(str string) []byte { + b, err := hex.DecodeString(strings.Replace(str, "\n", "", -1)) + if err != nil { + panic(fmt.Sprintf("invalid hex string: %q", str)) + } + return b +} diff --git a/p2p/nat/nat.go b/p2p/nat/nat.go new file mode 100644 index 000000000..12d355ba1 --- /dev/null +++ b/p2p/nat/nat.go @@ -0,0 +1,235 @@ +// Package nat provides access to common port mapping protocols. +package nat + +import ( + "errors" + "fmt" + "net" + "strings" + "sync" + "time" + + "github.com/ethereum/go-ethereum/logger" + "github.com/jackpal/go-nat-pmp" +) + +var log = logger.NewLogger("P2P NAT") + +// An implementation of nat.Interface can map local ports to ports +// accessible from the Internet. +type Interface interface { + // These methods manage a mapping between a port on the local + // machine to a port that can be connected to from the internet. + // + // protocol is "UDP" or "TCP". Some implementations allow setting + // a display name for the mapping. The mapping may be removed by + // the gateway when its lifetime ends. + AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error + DeleteMapping(protocol string, extport, intport int) error + + // This method should return the external (Internet-facing) + // address of the gateway device. + ExternalIP() (net.IP, error) + + // Should return name of the method. This is used for logging. + String() string +} + +// Parse parses a NAT interface description. +// The following formats are currently accepted. +// Note that mechanism names are not case-sensitive. +// +// "" or "none" return nil +// "extip:77.12.33.4" will assume the local machine is reachable on the given IP +// "any" uses the first auto-detected mechanism +// "upnp" uses the Universal Plug and Play protocol +// "pmp" uses NAT-PMP with an auto-detected gateway address +// "pmp:192.168.0.1" uses NAT-PMP with the given gateway address +func Parse(spec string) (Interface, error) { + var ( + parts = strings.SplitN(spec, ":", 2) + mech = strings.ToLower(parts[0]) + ip net.IP + ) + if len(parts) > 1 { + ip = net.ParseIP(parts[1]) + if ip == nil { + return nil, errors.New("invalid IP address") + } + } + switch mech { + case "", "none", "off": + return nil, nil + case "any", "auto", "on": + return Any(), nil + case "extip", "ip": + if ip == nil { + return nil, errors.New("missing IP address") + } + return ExtIP(ip), nil + case "upnp": + return UPnP(), nil + case "pmp", "natpmp", "nat-pmp": + return PMP(ip), nil + default: + return nil, fmt.Errorf("unknown mechanism %q", parts[0]) + } +} + +const ( + mapTimeout = 20 * time.Minute + mapUpdateInterval = 15 * time.Minute +) + +// Map adds a port mapping on m and keeps it alive until c is closed. +// This function is typically invoked in its own goroutine. +func Map(m Interface, c chan struct{}, protocol string, extport, intport int, name string) { + refresh := time.NewTimer(mapUpdateInterval) + defer func() { + refresh.Stop() + log.Debugf("Deleting port mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m) + m.DeleteMapping(protocol, extport, intport) + }() + log.Debugf("add mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m) + if err := m.AddMapping(protocol, intport, extport, name, mapTimeout); err != nil { + log.Errorf("mapping error: %v\n", err) + } + for { + select { + case _, ok := <-c: + if !ok { + return + } + case <-refresh.C: + log.DebugDetailf("refresh mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m) + if err := m.AddMapping(protocol, intport, extport, name, mapTimeout); err != nil { + log.Errorf("mapping error: %v\n", err) + } + refresh.Reset(mapUpdateInterval) + } + } +} + +// ExtIP assumes that the local machine is reachable on the given +// external IP address, and that any required ports were mapped manually. +// Mapping operations will not return an error but won't actually do anything. +func ExtIP(ip net.IP) Interface { + if ip == nil { + panic("IP must not be nil") + } + return extIP(ip) +} + +type extIP net.IP + +func (n extIP) ExternalIP() (net.IP, error) { return net.IP(n), nil } +func (n extIP) String() string { return fmt.Sprintf("ExtIP(%v)", net.IP(n)) } + +// These do nothing. +func (extIP) AddMapping(string, int, int, string, time.Duration) error { return nil } +func (extIP) DeleteMapping(string, int, int) error { return nil } + +// Any returns a port mapper that tries to discover any supported +// mechanism on the local network. +func Any() Interface { + // TODO: attempt to discover whether the local machine has an + // Internet-class address. Return ExtIP in this case. + return startautodisc("UPnP or NAT-PMP", func() Interface { + found := make(chan Interface, 2) + go func() { found <- discoverUPnP() }() + go func() { found <- discoverPMP() }() + for i := 0; i < cap(found); i++ { + if c := <-found; c != nil { + return c + } + } + return nil + }) +} + +// UPnP returns a port mapper that uses UPnP. It will attempt to +// discover the address of your router using UDP broadcasts. +func UPnP() Interface { + return startautodisc("UPnP", discoverUPnP) +} + +// PMP returns a port mapper that uses NAT-PMP. The provided gateway +// address should be the IP of your router. If the given gateway +// address is nil, PMP will attempt to auto-discover the router. +func PMP(gateway net.IP) Interface { + if gateway != nil { + return &pmp{gw: gateway, c: natpmp.NewClient(gateway)} + } + return startautodisc("NAT-PMP", discoverPMP) +} + +// autodisc represents a port mapping mechanism that is still being +// auto-discovered. Calls to the Interface methods on this type will +// wait until the discovery is done and then call the method on the +// discovered mechanism. +// +// This type is useful because discovery can take a while but we +// want return an Interface value from UPnP, PMP and Auto immediately. +type autodisc struct { + what string + done <-chan Interface + + mu sync.Mutex + found Interface +} + +func startautodisc(what string, doit func() Interface) Interface { + // TODO: monitor network configuration and rerun doit when it changes. + done := make(chan Interface) + ad := &autodisc{what: what, done: done} + go func() { done <- doit(); close(done) }() + return ad +} + +func (n *autodisc) AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error { + if err := n.wait(); err != nil { + return err + } + return n.found.AddMapping(protocol, extport, intport, name, lifetime) +} + +func (n *autodisc) DeleteMapping(protocol string, extport, intport int) error { + if err := n.wait(); err != nil { + return err + } + return n.found.DeleteMapping(protocol, extport, intport) +} + +func (n *autodisc) ExternalIP() (net.IP, error) { + if err := n.wait(); err != nil { + return nil, err + } + return n.found.ExternalIP() +} + +func (n *autodisc) String() string { + n.mu.Lock() + defer n.mu.Unlock() + if n.found == nil { + return n.what + } else { + return n.found.String() + } +} + +func (n *autodisc) wait() error { + n.mu.Lock() + found := n.found + n.mu.Unlock() + if found != nil { + // already discovered + return nil + } + if found = <-n.done; found == nil { + return errors.New("no devices discovered") + } + n.mu.Lock() + n.found = found + n.mu.Unlock() + return nil +} diff --git a/p2p/nat/natpmp.go b/p2p/nat/natpmp.go new file mode 100644 index 000000000..f249c6073 --- /dev/null +++ b/p2p/nat/natpmp.go @@ -0,0 +1,115 @@ +package nat + +import ( + "fmt" + "net" + "strings" + "time" + + "github.com/jackpal/go-nat-pmp" +) + +// natPMPClient adapts the NAT-PMP protocol implementation so it conforms to +// the common interface. +type pmp struct { + gw net.IP + c *natpmp.Client +} + +func (n *pmp) String() string { + return fmt.Sprintf("NAT-PMP(%v)", n.gw) +} + +func (n *pmp) ExternalIP() (net.IP, error) { + response, err := n.c.GetExternalAddress() + if err != nil { + return nil, err + } + return response.ExternalIPAddress[:], nil +} + +func (n *pmp) AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error { + if lifetime <= 0 { + return fmt.Errorf("lifetime must not be <= 0") + } + // Note order of port arguments is switched between our + // AddMapping and the client's AddPortMapping. + _, err := n.c.AddPortMapping(strings.ToLower(protocol), intport, extport, int(lifetime/time.Second)) + return err +} + +func (n *pmp) DeleteMapping(protocol string, extport, intport int) (err error) { + // To destroy a mapping, send an add-port with an internalPort of + // the internal port to destroy, an external port of zero and a + // time of zero. + _, err = n.c.AddPortMapping(strings.ToLower(protocol), intport, 0, 0) + return err +} + +func discoverPMP() Interface { + // run external address lookups on all potential gateways + gws := potentialGateways() + found := make(chan *pmp, len(gws)) + for i := range gws { + gw := gws[i] + go func() { + c := natpmp.NewClient(gw) + if _, err := c.GetExternalAddress(); err != nil { + found <- nil + } else { + found <- &pmp{gw, c} + } + }() + } + // return the one that responds first. + // discovery needs to be quick, so we stop caring about + // any responses after a very short timeout. + timeout := time.NewTimer(1 * time.Second) + defer timeout.Stop() + for _ = range gws { + select { + case c := <-found: + if c != nil { + return c + } + case <-timeout.C: + return nil + } + } + return nil +} + +var ( + // LAN IP ranges + _, lan10, _ = net.ParseCIDR("10.0.0.0/8") + _, lan176, _ = net.ParseCIDR("172.16.0.0/12") + _, lan192, _ = net.ParseCIDR("192.168.0.0/16") +) + +// TODO: improve this. We currently assume that (on most networks) +// the router is X.X.X.1 in a local LAN range. +func potentialGateways() (gws []net.IP) { + ifaces, err := net.Interfaces() + if err != nil { + return nil + } + for _, iface := range ifaces { + ifaddrs, err := iface.Addrs() + if err != nil { + return gws + } + for _, addr := range ifaddrs { + switch x := addr.(type) { + case *net.IPNet: + if lan10.Contains(x.IP) || lan176.Contains(x.IP) || lan192.Contains(x.IP) { + ip := x.IP.Mask(x.Mask).To4() + if ip != nil { + ip[3] = ip[3] | 0x01 + gws = append(gws, ip) + } + } + } + } + } + return gws +} diff --git a/p2p/nat/natupnp.go b/p2p/nat/natupnp.go new file mode 100644 index 000000000..ef7765e8d --- /dev/null +++ b/p2p/nat/natupnp.go @@ -0,0 +1,149 @@ +package nat + +import ( + "errors" + "fmt" + "net" + "strings" + "time" + + "github.com/huin/goupnp" + "github.com/huin/goupnp/dcps/internetgateway1" + "github.com/huin/goupnp/dcps/internetgateway2" +) + +type upnp struct { + dev *goupnp.RootDevice + service string + client upnpClient +} + +type upnpClient interface { + GetExternalIPAddress() (string, error) + AddPortMapping(string, uint16, string, uint16, string, bool, string, uint32) error + DeletePortMapping(string, uint16, string) error + GetNATRSIPStatus() (sip bool, nat bool, err error) +} + +func (n *upnp) ExternalIP() (addr net.IP, err error) { + ipString, err := n.client.GetExternalIPAddress() + if err != nil { + return nil, err + } + ip := net.ParseIP(ipString) + if ip == nil { + return nil, errors.New("bad IP in response") + } + return ip, nil +} + +func (n *upnp) AddMapping(protocol string, extport, intport int, desc string, lifetime time.Duration) error { + ip, err := n.internalAddress() + if err != nil { + return nil + } + protocol = strings.ToUpper(protocol) + lifetimeS := uint32(lifetime / time.Second) + return n.client.AddPortMapping("", uint16(extport), protocol, uint16(intport), ip.String(), true, desc, lifetimeS) +} + +func (n *upnp) internalAddress() (net.IP, error) { + devaddr, err := net.ResolveUDPAddr("udp4", n.dev.URLBase.Host) + if err != nil { + return nil, err + } + ifaces, err := net.Interfaces() + if err != nil { + return nil, err + } + for _, iface := range ifaces { + addrs, err := iface.Addrs() + if err != nil { + return nil, err + } + for _, addr := range addrs { + switch x := addr.(type) { + case *net.IPNet: + if x.Contains(devaddr.IP) { + return x.IP, nil + } + } + } + } + return nil, fmt.Errorf("could not find local address in same net as %v", devaddr) +} + +func (n *upnp) DeleteMapping(protocol string, extport, intport int) error { + return n.client.DeletePortMapping("", uint16(extport), strings.ToUpper(protocol)) +} + +func (n *upnp) String() string { + return "UPNP " + n.service +} + +// discoverUPnP searches for Internet Gateway Devices +// and returns the first one it can find on the local network. +func discoverUPnP() Interface { + found := make(chan *upnp, 2) + // IGDv1 + go discover(found, internetgateway1.URN_WANConnectionDevice_1, func(dev *goupnp.RootDevice, sc goupnp.ServiceClient) *upnp { + switch sc.Service.ServiceType { + case internetgateway1.URN_WANIPConnection_1: + return &upnp{dev, "IGDv1-IP1", &internetgateway1.WANIPConnection1{sc}} + case internetgateway1.URN_WANPPPConnection_1: + return &upnp{dev, "IGDv1-PPP1", &internetgateway1.WANPPPConnection1{sc}} + } + return nil + }) + // IGDv2 + go discover(found, internetgateway2.URN_WANConnectionDevice_2, func(dev *goupnp.RootDevice, sc goupnp.ServiceClient) *upnp { + switch sc.Service.ServiceType { + case internetgateway2.URN_WANIPConnection_1: + return &upnp{dev, "IGDv2-IP1", &internetgateway2.WANIPConnection1{sc}} + case internetgateway2.URN_WANIPConnection_2: + return &upnp{dev, "IGDv2-IP2", &internetgateway2.WANIPConnection2{sc}} + case internetgateway2.URN_WANPPPConnection_1: + return &upnp{dev, "IGDv2-PPP1", &internetgateway2.WANPPPConnection1{sc}} + } + return nil + }) + for i := 0; i < cap(found); i++ { + if c := <-found; c != nil { + return c + } + } + return nil +} + +func discover(out chan<- *upnp, target string, matcher func(*goupnp.RootDevice, goupnp.ServiceClient) *upnp) { + devs, err := goupnp.DiscoverDevices(target) + if err != nil { + return + } + found := false + for i := 0; i < len(devs) && !found; i++ { + if devs[i].Root == nil { + continue + } + devs[i].Root.Device.VisitServices(func(service *goupnp.Service) { + if found { + return + } + // check for a matching IGD service + sc := goupnp.ServiceClient{service.NewSOAPClient(), devs[i].Root, service} + upnp := matcher(devs[i].Root, sc) + if upnp == nil { + return + } + // check whether port mapping is enabled + if _, nat, err := upnp.client.GetNATRSIPStatus(); err != nil || !nat { + return + } + out <- upnp + found = true + }) + } + if !found { + out <- nil + } +} diff --git a/p2p/peer.go b/p2p/peer.go new file mode 100644 index 000000000..c2c83abfc --- /dev/null +++ b/p2p/peer.go @@ -0,0 +1,312 @@ +package p2p + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "sort" + "sync" + "time" + + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/rlp" +) + +const ( + baseProtocolVersion = 3 + baseProtocolLength = uint64(16) + baseProtocolMaxMsgSize = 10 * 1024 * 1024 + + pingInterval = 15 * time.Second + disconnectGracePeriod = 2 * time.Second +) + +const ( + // devp2p message codes + handshakeMsg = 0x00 + discMsg = 0x01 + pingMsg = 0x02 + pongMsg = 0x03 + getPeersMsg = 0x04 + peersMsg = 0x05 +) + +// Peer represents a connected remote node. +type Peer struct { + // Peers have all the log methods. + // Use them to display messages related to the peer. + *logger.Logger + + conn net.Conn + rw *conn + running map[string]*protoRW + + protoWG sync.WaitGroup + protoErr chan error + closed chan struct{} + disc chan DiscReason +} + +// NewPeer returns a peer for testing purposes. +func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer { + pipe, _ := net.Pipe() + msgpipe, _ := MsgPipe() + conn := &conn{msgpipe, &protoHandshake{ID: id, Name: name, Caps: caps}} + peer := newPeer(pipe, conn, nil) + close(peer.closed) // ensures Disconnect doesn't block + return peer +} + +// ID returns the node's public key. +func (p *Peer) ID() discover.NodeID { + return p.rw.ID +} + +// Name returns the node name that the remote node advertised. +func (p *Peer) Name() string { + return p.rw.Name +} + +// Caps returns the capabilities (supported subprotocols) of the remote peer. +func (p *Peer) Caps() []Cap { + // TODO: maybe return copy + return p.rw.Caps +} + +// RemoteAddr returns the remote address of the network connection. +func (p *Peer) RemoteAddr() net.Addr { + return p.conn.RemoteAddr() +} + +// LocalAddr returns the local address of the network connection. +func (p *Peer) LocalAddr() net.Addr { + return p.conn.LocalAddr() +} + +// Disconnect terminates the peer connection with the given reason. +// It returns immediately and does not wait until the connection is closed. +func (p *Peer) Disconnect(reason DiscReason) { + select { + case p.disc <- reason: + case <-p.closed: + } +} + +// String implements fmt.Stringer. +func (p *Peer) String() string { + return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr()) +} + +func newPeer(fd net.Conn, conn *conn, protocols []Protocol) *Peer { + logtag := fmt.Sprintf("Peer %.8x %v", conn.ID[:], fd.RemoteAddr()) + p := &Peer{ + Logger: logger.NewLogger(logtag), + conn: fd, + rw: conn, + running: matchProtocols(protocols, conn.Caps, conn), + disc: make(chan DiscReason), + protoErr: make(chan error), + closed: make(chan struct{}), + } + return p +} + +func (p *Peer) run() DiscReason { + var readErr = make(chan error, 1) + defer p.closeProtocols() + defer close(p.closed) + + p.startProtocols() + go func() { readErr <- p.readLoop() }() + + ping := time.NewTicker(pingInterval) + defer ping.Stop() + + // Wait for an error or disconnect. + var reason DiscReason +loop: + for { + select { + case <-ping.C: + go func() { + if err := EncodeMsg(p.rw, pingMsg, nil); err != nil { + p.protoErr <- err + return + } + }() + case err := <-readErr: + // We rely on protocols to abort if there is a write error. It + // might be more robust to handle them here as well. + p.DebugDetailf("Read error: %v\n", err) + p.conn.Close() + return DiscNetworkError + case err := <-p.protoErr: + reason = discReasonForError(err) + break loop + case reason = <-p.disc: + break loop + } + } + p.politeDisconnect(reason) + + // Wait for readLoop. It will end because conn is now closed. + <-readErr + p.Debugf("Disconnected: %v\n", reason) + return reason +} + +func (p *Peer) politeDisconnect(reason DiscReason) { + done := make(chan struct{}) + go func() { + EncodeMsg(p.rw, discMsg, uint(reason)) + // Wait for the other side to close the connection. + // Discard any data that they send until then. + io.Copy(ioutil.Discard, p.conn) + close(done) + }() + select { + case <-done: + case <-time.After(disconnectGracePeriod): + } + p.conn.Close() +} + +func (p *Peer) readLoop() error { + for { + p.conn.SetDeadline(time.Now().Add(frameReadTimeout)) + msg, err := p.rw.ReadMsg() + if err != nil { + return err + } + if err = p.handle(msg); err != nil { + return err + } + } + return nil +} + +func (p *Peer) handle(msg Msg) error { + switch { + case msg.Code == pingMsg: + msg.Discard() + go EncodeMsg(p.rw, pongMsg) + case msg.Code == discMsg: + var reason [1]DiscReason + // no need to discard or for error checking, we'll close the + // connection after this. + rlp.Decode(msg.Payload, &reason) + p.Disconnect(DiscRequested) + return discRequestedError(reason[0]) + case msg.Code < baseProtocolLength: + // ignore other base protocol messages + return msg.Discard() + default: + // it's a subprotocol message + proto, err := p.getProto(msg.Code) + if err != nil { + return fmt.Errorf("msg code out of range: %v", msg.Code) + } + proto.in <- msg + } + return nil +} + +// matchProtocols creates structures for matching named subprotocols. +func matchProtocols(protocols []Protocol, caps []Cap, rw MsgReadWriter) map[string]*protoRW { + sort.Sort(capsByName(caps)) + offset := baseProtocolLength + result := make(map[string]*protoRW) +outer: + for _, cap := range caps { + for _, proto := range protocols { + if proto.Name == cap.Name && proto.Version == cap.Version && result[cap.Name] == nil { + result[cap.Name] = &protoRW{Protocol: proto, offset: offset, in: make(chan Msg), w: rw} + offset += proto.Length + continue outer + } + } + } + return result +} + +func (p *Peer) startProtocols() { + for _, proto := range p.running { + proto := proto + p.DebugDetailf("Starting protocol %s/%d\n", proto.Name, proto.Version) + p.protoWG.Add(1) + go func() { + err := proto.Run(p, proto) + if err == nil { + p.DebugDetailf("Protocol %s/%d returned\n", proto.Name, proto.Version) + err = errors.New("protocol returned") + } else { + p.DebugDetailf("Protocol %s/%d error: %v\n", proto.Name, proto.Version, err) + } + select { + case p.protoErr <- err: + case <-p.closed: + } + p.protoWG.Done() + }() + } +} + +// getProto finds the protocol responsible for handling +// the given message code. +func (p *Peer) getProto(code uint64) (*protoRW, error) { + for _, proto := range p.running { + if code >= proto.offset && code < proto.offset+proto.Length { + return proto, nil + } + } + return nil, newPeerError(errInvalidMsgCode, "%d", code) +} + +func (p *Peer) closeProtocols() { + for _, p := range p.running { + close(p.in) + } + p.protoWG.Wait() +} + +// writeProtoMsg sends the given message on behalf of the given named protocol. +// this exists because of Server.Broadcast. +func (p *Peer) writeProtoMsg(protoName string, msg Msg) error { + proto, ok := p.running[protoName] + if !ok { + return fmt.Errorf("protocol %s not handled by peer", protoName) + } + if msg.Code >= proto.Length { + return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName) + } + msg.Code += proto.offset + return p.rw.WriteMsg(msg) +} + +type protoRW struct { + Protocol + + in chan Msg + offset uint64 + w MsgWriter +} + +func (rw *protoRW) WriteMsg(msg Msg) error { + if msg.Code >= rw.Length { + return newPeerError(errInvalidMsgCode, "not handled") + } + msg.Code += rw.offset + return rw.w.WriteMsg(msg) +} + +func (rw *protoRW) ReadMsg() (Msg, error) { + msg, ok := <-rw.in + if !ok { + return msg, io.EOF + } + msg.Code -= rw.offset + return msg, nil +} diff --git a/p2p/peer_error.go b/p2p/peer_error.go new file mode 100644 index 000000000..0ff4f4b43 --- /dev/null +++ b/p2p/peer_error.go @@ -0,0 +1,131 @@ +package p2p + +import ( + "fmt" +) + +const ( + errMagicTokenMismatch = iota + errRead + errWrite + errMisc + errInvalidMsgCode + errInvalidMsg + errP2PVersionMismatch + errPubkeyInvalid + errPubkeyForbidden + errProtocolBreach + errPingTimeout + errInvalidNetworkId + errInvalidProtocolVersion +) + +var errorToString = map[int]string{ + errMagicTokenMismatch: "magic token mismatch", + errRead: "read error", + errWrite: "write error", + errMisc: "misc error", + errInvalidMsgCode: "invalid message code", + errInvalidMsg: "invalid message", + errP2PVersionMismatch: "P2P Version Mismatch", + errPubkeyInvalid: "public key invalid", + errPubkeyForbidden: "public key forbidden", + errProtocolBreach: "protocol Breach", + errPingTimeout: "ping timeout", + errInvalidNetworkId: "invalid network id", + errInvalidProtocolVersion: "invalid protocol version", +} + +type peerError struct { + Code int + message string +} + +func newPeerError(code int, format string, v ...interface{}) *peerError { + desc, ok := errorToString[code] + if !ok { + panic("invalid error code") + } + err := &peerError{code, desc} + if format != "" { + err.message += ": " + fmt.Sprintf(format, v...) + } + return err +} + +func (self *peerError) Error() string { + return self.message +} + +type DiscReason byte + +const ( + DiscRequested DiscReason = iota + DiscNetworkError + DiscProtocolError + DiscUselessPeer + DiscTooManyPeers + DiscAlreadyConnected + DiscIncompatibleVersion + DiscInvalidIdentity + DiscQuitting + DiscUnexpectedIdentity + DiscSelf + DiscReadTimeout + DiscSubprotocolError +) + +var discReasonToString = [...]string{ + DiscRequested: "Disconnect requested", + DiscNetworkError: "Network error", + DiscProtocolError: "Breach of protocol", + DiscUselessPeer: "Useless peer", + DiscTooManyPeers: "Too many peers", + DiscAlreadyConnected: "Already connected", + DiscIncompatibleVersion: "Incompatible P2P protocol version", + DiscInvalidIdentity: "Invalid node identity", + DiscQuitting: "Client quitting", + DiscUnexpectedIdentity: "Unexpected identity", + DiscSelf: "Connected to self", + DiscReadTimeout: "Read timeout", + DiscSubprotocolError: "Subprotocol error", +} + +func (d DiscReason) String() string { + if len(discReasonToString) < int(d) { + return fmt.Sprintf("Unknown Reason(%d)", d) + } + return discReasonToString[d] +} + +type discRequestedError DiscReason + +func (err discRequestedError) Error() string { + return fmt.Sprintf("disconnect requested: %v", DiscReason(err)) +} + +func discReasonForError(err error) DiscReason { + if reason, ok := err.(discRequestedError); ok { + return DiscReason(reason) + } + peerError, ok := err.(*peerError) + if !ok { + return DiscSubprotocolError + } + switch peerError.Code { + case errP2PVersionMismatch: + return DiscIncompatibleVersion + case errPubkeyInvalid: + return DiscInvalidIdentity + case errPubkeyForbidden: + return DiscUselessPeer + case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach: + return DiscProtocolError + case errPingTimeout: + return DiscReadTimeout + case errRead, errWrite: + return DiscNetworkError + default: + return DiscSubprotocolError + } +} diff --git a/p2p/peer_test.go b/p2p/peer_test.go new file mode 100644 index 000000000..cc9f1f0cd --- /dev/null +++ b/p2p/peer_test.go @@ -0,0 +1,226 @@ +package p2p + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "net" + "reflect" + "testing" + "time" + + "github.com/ethereum/go-ethereum/rlp" +) + +var discard = Protocol{ + Name: "discard", + Length: 1, + Run: func(p *Peer, rw MsgReadWriter) error { + for { + msg, err := rw.ReadMsg() + if err != nil { + return err + } + fmt.Printf("discarding %d\n", msg.Code) + if err = msg.Discard(); err != nil { + return err + } + } + }, +} + +func testPeer(protos []Protocol) (io.Closer, *conn, *Peer, <-chan DiscReason) { + fd1, _ := net.Pipe() + hs1 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion} + hs2 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion} + for _, p := range protos { + hs1.Caps = append(hs1.Caps, p.cap()) + hs2.Caps = append(hs2.Caps, p.cap()) + } + + p1, p2 := MsgPipe() + peer := newPeer(fd1, &conn{p1, hs1}, protos) + errc := make(chan DiscReason, 1) + go func() { errc <- peer.run() }() + + return p1, &conn{p2, hs2}, peer, errc +} + +func TestPeerProtoReadMsg(t *testing.T) { + defer testlog(t).detach() + + done := make(chan struct{}) + proto := Protocol{ + Name: "a", + Length: 5, + Run: func(peer *Peer, rw MsgReadWriter) error { + if err := expectMsg(rw, 2, []uint{1}); err != nil { + t.Error(err) + } + if err := expectMsg(rw, 3, []uint{2}); err != nil { + t.Error(err) + } + if err := expectMsg(rw, 4, []uint{3}); err != nil { + t.Error(err) + } + close(done) + return nil + }, + } + + closer, rw, _, errc := testPeer([]Protocol{proto}) + defer closer.Close() + + EncodeMsg(rw, baseProtocolLength+2, 1) + EncodeMsg(rw, baseProtocolLength+3, 2) + EncodeMsg(rw, baseProtocolLength+4, 3) + + select { + case <-done: + case err := <-errc: + t.Errorf("peer returned: %v", err) + case <-time.After(2 * time.Second): + t.Errorf("receive timeout") + } +} + +func TestPeerProtoEncodeMsg(t *testing.T) { + defer testlog(t).detach() + + proto := Protocol{ + Name: "a", + Length: 2, + Run: func(peer *Peer, rw MsgReadWriter) error { + if err := EncodeMsg(rw, 2); err == nil { + t.Error("expected error for out-of-range msg code, got nil") + } + if err := EncodeMsg(rw, 1, "foo", "bar"); err != nil { + t.Errorf("write error: %v", err) + } + return nil + }, + } + closer, rw, _, _ := testPeer([]Protocol{proto}) + defer closer.Close() + + if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil { + t.Error(err) + } +} + +func TestPeerWriteForBroadcast(t *testing.T) { + defer testlog(t).detach() + + closer, rw, peer, peerErr := testPeer([]Protocol{discard}) + defer closer.Close() + + // test write errors + if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil { + t.Errorf("expected error for unknown protocol, got nil") + } + if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil { + t.Errorf("expected error for out-of-range msg code, got nil") + } else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode { + t.Errorf("wrong error for out-of-range msg code, got %#v", err) + } + + // setup for reading the message on the other end + read := make(chan struct{}) + go func() { + if err := expectMsg(rw, 16, nil); err != nil { + t.Error(err) + } + close(read) + }() + + // test successful write + if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil { + t.Errorf("expect no error for known protocol: %v", err) + } + select { + case <-read: + case err := <-peerErr: + t.Fatalf("peer stopped: %v", err) + } +} + +func TestPeerPing(t *testing.T) { + defer testlog(t).detach() + + closer, rw, _, _ := testPeer(nil) + defer closer.Close() + if err := EncodeMsg(rw, pingMsg); err != nil { + t.Fatal(err) + } + if err := expectMsg(rw, pongMsg, nil); err != nil { + t.Error(err) + } +} + +func TestPeerDisconnect(t *testing.T) { + defer testlog(t).detach() + + closer, rw, _, disc := testPeer(nil) + defer closer.Close() + if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil { + t.Fatal(err) + } + if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil { + t.Error(err) + } + closer.Close() // make test end faster + if reason := <-disc; reason != DiscRequested { + t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested) + } +} + +func TestNewPeer(t *testing.T) { + name := "nodename" + caps := []Cap{{"foo", 2}, {"bar", 3}} + id := randomID() + p := NewPeer(id, name, caps) + if p.ID() != id { + t.Errorf("ID mismatch: got %v, expected %v", p.ID(), id) + } + if p.Name() != name { + t.Errorf("Name mismatch: got %v, expected %v", p.Name(), name) + } + if !reflect.DeepEqual(p.Caps(), caps) { + t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps) + } + + p.Disconnect(DiscAlreadyConnected) // Should not hang +} + +// expectMsg reads a message from r and verifies that its +// code and encoded RLP content match the provided values. +// If content is nil, the payload is discarded and not verified. +func expectMsg(r MsgReader, code uint64, content interface{}) error { + msg, err := r.ReadMsg() + if err != nil { + return err + } + if msg.Code != code { + return fmt.Errorf("message code mismatch: got %d, expected %d", msg.Code, code) + } + if content == nil { + return msg.Discard() + } else { + contentEnc, err := rlp.EncodeToBytes(content) + if err != nil { + panic("content encode error: " + err.Error()) + } + if int(msg.Size) != len(contentEnc) { + return fmt.Errorf("message size mismatch: got %d, want %d", msg.Size, len(contentEnc)) + } + actualContent, err := ioutil.ReadAll(msg.Payload) + if err != nil { + return err + } + if !bytes.Equal(actualContent, contentEnc) { + return fmt.Errorf("message payload mismatch:\ngot: %x\nwant: %x", actualContent, contentEnc) + } + } + return nil +} diff --git a/p2p/protocol.go b/p2p/protocol.go new file mode 100644 index 000000000..5fa395eda --- /dev/null +++ b/p2p/protocol.go @@ -0,0 +1,50 @@ +package p2p + +import "fmt" + +// Protocol represents a P2P subprotocol implementation. +type Protocol struct { + // Name should contain the official protocol name, + // often a three-letter word. + Name string + + // Version should contain the version number of the protocol. + Version uint + + // Length should contain the number of message codes used + // by the protocol. + Length uint64 + + // Run is called in a new groutine when the protocol has been + // negotiated with a peer. It should read and write messages from + // rw. The Payload for each message must be fully consumed. + // + // The peer connection is closed when Start returns. It should return + // any protocol-level error (such as an I/O error) that is + // encountered. + Run func(peer *Peer, rw MsgReadWriter) error +} + +func (p Protocol) cap() Cap { + return Cap{p.Name, p.Version} +} + +// Cap is the structure of a peer capability. +type Cap struct { + Name string + Version uint +} + +func (cap Cap) RlpData() interface{} { + return []interface{}{cap.Name, cap.Version} +} + +func (cap Cap) String() string { + return fmt.Sprintf("%s/%d", cap.Name, cap.Version) +} + +type capsByName []Cap + +func (cs capsByName) Len() int { return len(cs) } +func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name } +func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] } diff --git a/p2p/rlpx.go b/p2p/rlpx.go new file mode 100644 index 000000000..6b533e275 --- /dev/null +++ b/p2p/rlpx.go @@ -0,0 +1,174 @@ +package p2p + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "errors" + "hash" + "io" + + "github.com/ethereum/go-ethereum/rlp" +) + +var ( + // this is used in place of actual frame header data. + // TODO: replace this when Msg contains the protocol type code. + zeroHeader = []byte{0xC2, 0x80, 0x80} + + // sixteen zero bytes + zero16 = make([]byte, 16) + + maxUint24 = ^uint32(0) >> 8 +) + +// rlpxFrameRW implements a simplified version of RLPx framing. +// chunked messages are not supported and all headers are equal to +// zeroHeader. +// +// rlpxFrameRW is not safe for concurrent use from multiple goroutines. +type rlpxFrameRW struct { + conn io.ReadWriter + enc cipher.Stream + dec cipher.Stream + + macCipher cipher.Block + egressMAC hash.Hash + ingressMAC hash.Hash +} + +func newRlpxFrameRW(conn io.ReadWriter, s secrets) *rlpxFrameRW { + macc, err := aes.NewCipher(s.MAC) + if err != nil { + panic("invalid MAC secret: " + err.Error()) + } + encc, err := aes.NewCipher(s.AES) + if err != nil { + panic("invalid AES secret: " + err.Error()) + } + // we use an all-zeroes IV for AES because the key used + // for encryption is ephemeral. + iv := make([]byte, encc.BlockSize()) + return &rlpxFrameRW{ + conn: conn, + enc: cipher.NewCTR(encc, iv), + dec: cipher.NewCTR(encc, iv), + macCipher: macc, + egressMAC: s.EgressMAC, + ingressMAC: s.IngressMAC, + } +} + +func (rw *rlpxFrameRW) WriteMsg(msg Msg) error { + ptype, _ := rlp.EncodeToBytes(msg.Code) + + // write header + headbuf := make([]byte, 32) + fsize := uint32(len(ptype)) + msg.Size + if fsize > maxUint24 { + return errors.New("message size overflows uint24") + } + putInt24(fsize, headbuf) // TODO: check overflow + copy(headbuf[3:], zeroHeader) + rw.enc.XORKeyStream(headbuf[:16], headbuf[:16]) // first half is now encrypted + + // write header MAC + copy(headbuf[16:], updateMAC(rw.egressMAC, rw.macCipher, headbuf[:16])) + if _, err := rw.conn.Write(headbuf); err != nil { + return err + } + + // write encrypted frame, updating the egress MAC hash with + // the data written to conn. + tee := cipher.StreamWriter{S: rw.enc, W: io.MultiWriter(rw.conn, rw.egressMAC)} + if _, err := tee.Write(ptype); err != nil { + return err + } + if _, err := io.Copy(tee, msg.Payload); err != nil { + return err + } + if padding := fsize % 16; padding > 0 { + if _, err := tee.Write(zero16[:16-padding]); err != nil { + return err + } + } + + // write frame MAC. egress MAC hash is up to date because + // frame content was written to it as well. + fmacseed := rw.egressMAC.Sum(nil) + mac := updateMAC(rw.egressMAC, rw.macCipher, fmacseed) + _, err := rw.conn.Write(mac) + return err +} + +func (rw *rlpxFrameRW) ReadMsg() (msg Msg, err error) { + // read the header + headbuf := make([]byte, 32) + if _, err := io.ReadFull(rw.conn, headbuf); err != nil { + return msg, err + } + // verify header mac + shouldMAC := updateMAC(rw.ingressMAC, rw.macCipher, headbuf[:16]) + if !hmac.Equal(shouldMAC, headbuf[16:]) { + return msg, errors.New("bad header MAC") + } + rw.dec.XORKeyStream(headbuf[:16], headbuf[:16]) // first half is now decrypted + fsize := readInt24(headbuf) + // ignore protocol type for now + + // read the frame content + var rsize = fsize // frame size rounded up to 16 byte boundary + if padding := fsize % 16; padding > 0 { + rsize += 16 - padding + } + framebuf := make([]byte, rsize) + if _, err := io.ReadFull(rw.conn, framebuf); err != nil { + return msg, err + } + + // read and validate frame MAC. we can re-use headbuf for that. + rw.ingressMAC.Write(framebuf) + fmacseed := rw.ingressMAC.Sum(nil) + if _, err := io.ReadFull(rw.conn, headbuf[:16]); err != nil { + return msg, err + } + shouldMAC = updateMAC(rw.ingressMAC, rw.macCipher, fmacseed) + if !hmac.Equal(shouldMAC, headbuf[:16]) { + return msg, errors.New("bad frame MAC") + } + + // decrypt frame content + rw.dec.XORKeyStream(framebuf, framebuf) + + // decode message code + content := bytes.NewReader(framebuf[:fsize]) + if err := rlp.Decode(content, &msg.Code); err != nil { + return msg, err + } + msg.Size = uint32(content.Len()) + msg.Payload = content + return msg, nil +} + +// updateMAC reseeds the given hash with encrypted seed. +// it returns the first 16 bytes of the hash sum after seeding. +func updateMAC(mac hash.Hash, block cipher.Block, seed []byte) []byte { + aesbuf := make([]byte, aes.BlockSize) + block.Encrypt(aesbuf, mac.Sum(nil)) + for i := range aesbuf { + aesbuf[i] ^= seed[i] + } + mac.Write(aesbuf) + return mac.Sum(nil)[:16] +} + +func readInt24(b []byte) uint32 { + return uint32(b[2]) | uint32(b[1])<<8 | uint32(b[0])<<16 +} + +func putInt24(v uint32, b []byte) { + b[0] = byte(v >> 16) + b[1] = byte(v >> 8) + b[2] = byte(v) +} diff --git a/p2p/rlpx_test.go b/p2p/rlpx_test.go new file mode 100644 index 000000000..49354c7ed --- /dev/null +++ b/p2p/rlpx_test.go @@ -0,0 +1,124 @@ +package p2p + +import ( + "bytes" + "crypto/rand" + "io/ioutil" + "strings" + "testing" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/rlp" +) + +func TestRlpxFrameFake(t *testing.T) { + buf := new(bytes.Buffer) + hash := fakeHash([]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}) + rw := newRlpxFrameRW(buf, secrets{ + AES: crypto.Sha3(), + MAC: crypto.Sha3(), + IngressMAC: hash, + EgressMAC: hash, + }) + + golden := unhex(` +00828ddae471818bb0bfa6b551d1cb42 +01010101010101010101010101010101 +ba628a4ba590cb43f7848f41c4382885 +01010101010101010101010101010101 +`) + + // Check WriteMsg. This puts a message into the buffer. + if err := EncodeMsg(rw, 8, 1, 2, 3, 4); err != nil { + t.Fatalf("WriteMsg error: %v", err) + } + written := buf.Bytes() + if !bytes.Equal(written, golden) { + t.Fatalf("output mismatch:\n got: %x\n want: %x", written, golden) + } + + // Check ReadMsg. It reads the message encoded by WriteMsg, which + // is equivalent to the golden message above. + msg, err := rw.ReadMsg() + if err != nil { + t.Fatalf("ReadMsg error: %v", err) + } + if msg.Size != 5 { + t.Errorf("msg size mismatch: got %d, want %d", msg.Size, 5) + } + if msg.Code != 8 { + t.Errorf("msg code mismatch: got %d, want %d", msg.Code, 8) + } + payload, _ := ioutil.ReadAll(msg.Payload) + wantPayload := unhex("C401020304") + if !bytes.Equal(payload, wantPayload) { + t.Errorf("msg payload mismatch:\ngot %x\nwant %x", payload, wantPayload) + } +} + +type fakeHash []byte + +func (fakeHash) Write(p []byte) (int, error) { return len(p), nil } +func (fakeHash) Reset() {} +func (fakeHash) BlockSize() int { return 0 } + +func (h fakeHash) Size() int { return len(h) } +func (h fakeHash) Sum(b []byte) []byte { return append(b, h...) } + +func TestRlpxFrameRW(t *testing.T) { + var ( + aesSecret = make([]byte, 16) + macSecret = make([]byte, 16) + egressMACinit = make([]byte, 32) + ingressMACinit = make([]byte, 32) + ) + for _, s := range [][]byte{aesSecret, macSecret, egressMACinit, ingressMACinit} { + rand.Read(s) + } + conn := new(bytes.Buffer) + + s1 := secrets{ + AES: aesSecret, + MAC: macSecret, + EgressMAC: sha3.NewKeccak256(), + IngressMAC: sha3.NewKeccak256(), + } + s1.EgressMAC.Write(egressMACinit) + s1.IngressMAC.Write(ingressMACinit) + rw1 := newRlpxFrameRW(conn, s1) + + s2 := secrets{ + AES: aesSecret, + MAC: macSecret, + EgressMAC: sha3.NewKeccak256(), + IngressMAC: sha3.NewKeccak256(), + } + s2.EgressMAC.Write(ingressMACinit) + s2.IngressMAC.Write(egressMACinit) + rw2 := newRlpxFrameRW(conn, s2) + + // send some messages + for i := 0; i < 10; i++ { + // write message into conn buffer + wmsg := []interface{}{"foo", "bar", strings.Repeat("test", i)} + err := EncodeMsg(rw1, uint64(i), wmsg...) + if err != nil { + t.Fatalf("WriteMsg error (i=%d): %v", i, err) + } + + // read message that rw1 just wrote + msg, err := rw2.ReadMsg() + if err != nil { + t.Fatalf("ReadMsg error (i=%d): %v", i, err) + } + if msg.Code != uint64(i) { + t.Fatalf("msg code mismatch: got %d, want %d", msg.Code, i) + } + payload, _ := ioutil.ReadAll(msg.Payload) + wantPayload, _ := rlp.EncodeToBytes(wmsg) + if !bytes.Equal(payload, wantPayload) { + t.Fatalf("msg payload mismatch:\ngot %x\nwant %x", payload, wantPayload) + } + } +} diff --git a/p2p/server.go b/p2p/server.go new file mode 100644 index 000000000..34000cb4c --- /dev/null +++ b/p2p/server.go @@ -0,0 +1,425 @@ +package p2p + +import ( + "bytes" + "crypto/ecdsa" + "errors" + "fmt" + "net" + "runtime" + "sync" + "time" + + "github.com/ethereum/go-ethereum/ethutil" + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/nat" +) + +const ( + defaultDialTimeout = 10 * time.Second + refreshPeersInterval = 30 * time.Second + + // total timeout for encryption handshake and protocol + // handshake in both directions. + handshakeTimeout = 5 * time.Second + // maximum time allowed for reading a complete message. + // this is effectively the amount of time a connection can be idle. + frameReadTimeout = 1 * time.Minute + // maximum amount of time allowed for writing a complete message. + frameWriteTimeout = 5 * time.Second +) + +var srvlog = logger.NewLogger("P2P Server") +var srvjslog = logger.NewJsonLogger() + +// MakeName creates a node name that follows the ethereum convention +// for such names. It adds the operation system name and Go runtime version +// the name. +func MakeName(name, version string) string { + return fmt.Sprintf("%s/v%s/%s/%s", name, version, runtime.GOOS, runtime.Version()) +} + +// Server manages all peer connections. +// +// The fields of Server are used as configuration parameters. +// You should set them before starting the Server. Fields may not be +// modified while the server is running. +type Server struct { + // This field must be set to a valid secp256k1 private key. + PrivateKey *ecdsa.PrivateKey + + // MaxPeers is the maximum number of peers that can be + // connected. It must be greater than zero. + MaxPeers int + + // Name sets the node name of this server. + // Use MakeName to create a name that follows existing conventions. + Name string + + // Bootstrap nodes are used to establish connectivity + // with the rest of the network. + BootstrapNodes []*discover.Node + + // Protocols should contain the protocols supported + // by the server. Matching protocols are launched for + // each peer. + Protocols []Protocol + + // If ListenAddr is set to a non-nil address, the server + // will listen for incoming connections. + // + // If the port is zero, the operating system will pick a port. The + // ListenAddr field will be updated with the actual address when + // the server is started. + ListenAddr string + + // If set to a non-nil value, the given NAT port mapper + // is used to make the listening port available to the + // Internet. + NAT nat.Interface + + // If Dialer is set to a non-nil value, the given Dialer + // is used to dial outbound peer connections. + Dialer *net.Dialer + + // If NoDial is true, the server will not dial any peers. + NoDial bool + + // Hooks for testing. These are useful because we can inhibit + // the whole protocol stack. + setupFunc + newPeerHook + + ourHandshake *protoHandshake + + lock sync.RWMutex + running bool + listener net.Listener + peers map[discover.NodeID]*Peer + + ntab *discover.Table + + quit chan struct{} + loopWG sync.WaitGroup // {dial,listen,nat}Loop + peerWG sync.WaitGroup // active peer goroutines + peerConnect chan *discover.Node +} + +type setupFunc func(net.Conn, *ecdsa.PrivateKey, *protoHandshake, *discover.Node) (*conn, error) +type newPeerHook func(*Peer) + +// Peers returns all connected peers. +func (srv *Server) Peers() (peers []*Peer) { + srv.lock.RLock() + defer srv.lock.RUnlock() + for _, peer := range srv.peers { + if peer != nil { + peers = append(peers, peer) + } + } + return +} + +// PeerCount returns the number of connected peers. +func (srv *Server) PeerCount() int { + srv.lock.RLock() + n := len(srv.peers) + srv.lock.RUnlock() + return n +} + +// SuggestPeer creates a connection to the given Node if it +// is not already connected. +func (srv *Server) SuggestPeer(n *discover.Node) { + srv.peerConnect <- n +} + +// Broadcast sends an RLP-encoded message to all connected peers. +// This method is deprecated and will be removed later. +func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{}) { + var payload []byte + if data != nil { + payload = ethutil.Encode(data) + } + srv.lock.RLock() + defer srv.lock.RUnlock() + for _, peer := range srv.peers { + if peer != nil { + var msg = Msg{Code: code} + if data != nil { + msg.Payload = bytes.NewReader(payload) + msg.Size = uint32(len(payload)) + } + peer.writeProtoMsg(protocol, msg) + } + } +} + +// Start starts running the server. +// Servers can be re-used and started again after stopping. +func (srv *Server) Start() (err error) { + srv.lock.Lock() + defer srv.lock.Unlock() + if srv.running { + return errors.New("server already running") + } + srvlog.Infoln("Starting Server") + + // static fields + if srv.PrivateKey == nil { + return fmt.Errorf("Server.PrivateKey must be set to a non-nil key") + } + if srv.MaxPeers <= 0 { + return fmt.Errorf("Server.MaxPeers must be > 0") + } + srv.quit = make(chan struct{}) + srv.peers = make(map[discover.NodeID]*Peer) + srv.peerConnect = make(chan *discover.Node) + if srv.setupFunc == nil { + srv.setupFunc = setupConn + } + + // node table + ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT) + if err != nil { + return err + } + srv.ntab = ntab + + // handshake + srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: ntab.Self()} + for _, p := range srv.Protocols { + srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap()) + } + + // listen/dial + if srv.ListenAddr != "" { + if err := srv.startListening(); err != nil { + return err + } + } + if srv.Dialer == nil { + srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout} + } + if !srv.NoDial { + srv.loopWG.Add(1) + go srv.dialLoop() + } + if srv.NoDial && srv.ListenAddr == "" { + srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.") + } + + srv.running = true + return nil +} + +func (srv *Server) startListening() error { + listener, err := net.Listen("tcp", srv.ListenAddr) + if err != nil { + return err + } + laddr := listener.Addr().(*net.TCPAddr) + srv.ListenAddr = laddr.String() + srv.listener = listener + srv.loopWG.Add(1) + go srv.listenLoop() + if !laddr.IP.IsLoopback() && srv.NAT != nil { + srv.loopWG.Add(1) + go func() { + nat.Map(srv.NAT, srv.quit, "tcp", laddr.Port, laddr.Port, "ethereum p2p") + srv.loopWG.Done() + }() + } + return nil +} + +// Stop terminates the server and all active peer connections. +// It blocks until all active connections have been closed. +func (srv *Server) Stop() { + srv.lock.Lock() + if !srv.running { + srv.lock.Unlock() + return + } + srv.running = false + srv.lock.Unlock() + + srvlog.Infoln("Stopping Server") + srv.ntab.Close() + if srv.listener != nil { + // this unblocks listener Accept + srv.listener.Close() + } + close(srv.quit) + srv.loopWG.Wait() + + // No new peers can be added at this point because dialLoop and + // listenLoop are down. It is safe to call peerWG.Wait because + // peerWG.Add is not called outside of those loops. + for _, peer := range srv.peers { + peer.Disconnect(DiscQuitting) + } + srv.peerWG.Wait() +} + +// main loop for adding connections via listening +func (srv *Server) listenLoop() { + defer srv.loopWG.Done() + srvlog.Infoln("Listening on", srv.listener.Addr()) + for { + conn, err := srv.listener.Accept() + if err != nil { + return + } + srvlog.Debugf("Accepted conn %v\n", conn.RemoteAddr()) + srv.peerWG.Add(1) + go srv.startPeer(conn, nil) + } +} + +func (srv *Server) dialLoop() { + defer srv.loopWG.Done() + refresh := time.NewTicker(refreshPeersInterval) + defer refresh.Stop() + + srv.ntab.Bootstrap(srv.BootstrapNodes) + go srv.findPeers() + + dialed := make(chan *discover.Node) + dialing := make(map[discover.NodeID]bool) + + // TODO: limit number of active dials + // TODO: ensure only one findPeers goroutine is running + // TODO: pause findPeers when we're at capacity + + for { + select { + case <-refresh.C: + + go srv.findPeers() + + case dest := <-srv.peerConnect: + // avoid dialing nodes that are already connected. + // there is another check for this in addPeer, + // which runs after the handshake. + srv.lock.Lock() + _, isconnected := srv.peers[dest.ID] + srv.lock.Unlock() + if isconnected || dialing[dest.ID] || dest.ID == srv.ntab.Self() { + continue + } + + dialing[dest.ID] = true + srv.peerWG.Add(1) + go func() { + srv.dialNode(dest) + // at this point, the peer has been added + // or discarded. either way, we're not dialing it anymore. + dialed <- dest + }() + + case dest := <-dialed: + delete(dialing, dest.ID) + + case <-srv.quit: + // TODO: maybe wait for active dials + return + } + } +} + +func (srv *Server) dialNode(dest *discover.Node) { + addr := &net.TCPAddr{IP: dest.IP, Port: dest.TCPPort} + srvlog.Debugf("Dialing %v\n", dest) + conn, err := srv.Dialer.Dial("tcp", addr.String()) + if err != nil { + srvlog.DebugDetailf("dial error: %v", err) + return + } + srv.startPeer(conn, dest) +} + +func (srv *Server) findPeers() { + far := srv.ntab.Self() + for i := range far { + far[i] = ^far[i] + } + closeToSelf := srv.ntab.Lookup(srv.ntab.Self()) + farFromSelf := srv.ntab.Lookup(far) + + for i := 0; i < len(closeToSelf) || i < len(farFromSelf); i++ { + if i < len(closeToSelf) { + srv.peerConnect <- closeToSelf[i] + } + if i < len(farFromSelf) { + srv.peerConnect <- farFromSelf[i] + } + } +} + +func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) { + // TODO: handle/store session token + fd.SetDeadline(time.Now().Add(handshakeTimeout)) + conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest) + if err != nil { + fd.Close() + srvlog.Debugf("Handshake with %v failed: %v", fd.RemoteAddr(), err) + return + } + + conn.MsgReadWriter = &netWrapper{ + wrapped: conn.MsgReadWriter, + conn: fd, rtimeout: frameReadTimeout, wtimeout: frameWriteTimeout, + } + p := newPeer(fd, conn, srv.Protocols) + if ok, reason := srv.addPeer(conn.ID, p); !ok { + srvlog.DebugDetailf("Not adding %v (%v)\n", p, reason) + p.politeDisconnect(reason) + return + } + + srvlog.Debugf("Added %v\n", p) + srvjslog.LogJson(&logger.P2PConnected{ + RemoteId: fmt.Sprintf("%x", conn.ID[:]), + RemoteAddress: fd.RemoteAddr().String(), + RemoteVersionString: conn.Name, + NumConnections: srv.PeerCount(), + }) + + if srv.newPeerHook != nil { + srv.newPeerHook(p) + } + discreason := p.run() + srv.removePeer(p) + + srvlog.Debugf("Removed %v (%v)\n", p, discreason) + srvjslog.LogJson(&logger.P2PDisconnected{ + RemoteId: fmt.Sprintf("%x", conn.ID[:]), + NumConnections: srv.PeerCount(), + }) +} + +func (srv *Server) addPeer(id discover.NodeID, p *Peer) (bool, DiscReason) { + srv.lock.Lock() + defer srv.lock.Unlock() + switch { + case !srv.running: + return false, DiscQuitting + case len(srv.peers) >= srv.MaxPeers: + return false, DiscTooManyPeers + case srv.peers[id] != nil: + return false, DiscAlreadyConnected + case id == srv.ntab.Self(): + return false, DiscSelf + } + srv.peers[id] = p + return true, 0 +} + +func (srv *Server) removePeer(p *Peer) { + srv.lock.Lock() + delete(srv.peers, p.ID()) + srv.lock.Unlock() + srv.peerWG.Done() +} diff --git a/p2p/server_test.go b/p2p/server_test.go new file mode 100644 index 000000000..30447050c --- /dev/null +++ b/p2p/server_test.go @@ -0,0 +1,179 @@ +package p2p + +import ( + "bytes" + "crypto/ecdsa" + "io" + "math/rand" + "net" + "sync" + "testing" + "time" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/p2p/discover" +) + +func startTestServer(t *testing.T, pf newPeerHook) *Server { + server := &Server{ + Name: "test", + MaxPeers: 10, + ListenAddr: "127.0.0.1:0", + PrivateKey: newkey(), + newPeerHook: pf, + setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) { + id := randomID() + rw := newRlpxFrameRW(fd, secrets{ + MAC: zero16, + AES: zero16, + IngressMAC: sha3.NewKeccak256(), + EgressMAC: sha3.NewKeccak256(), + }) + return &conn{ + MsgReadWriter: rw, + protoHandshake: &protoHandshake{ID: id, Version: baseProtocolVersion}, + }, nil + }, + } + if err := server.Start(); err != nil { + t.Fatalf("Could not start server: %v", err) + } + return server +} + +func TestServerListen(t *testing.T) { + defer testlog(t).detach() + + // start the test server + connected := make(chan *Peer) + srv := startTestServer(t, func(p *Peer) { + if p == nil { + t.Error("peer func called with nil conn") + } + connected <- p + }) + defer close(connected) + defer srv.Stop() + + // dial the test server + conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second) + if err != nil { + t.Fatalf("could not dial: %v", err) + } + defer conn.Close() + + select { + case peer := <-connected: + if peer.LocalAddr().String() != conn.RemoteAddr().String() { + t.Errorf("peer started with wrong conn: got %v, want %v", + peer.LocalAddr(), conn.RemoteAddr()) + } + case <-time.After(1 * time.Second): + t.Error("server did not accept within one second") + } +} + +func TestServerDial(t *testing.T) { + defer testlog(t).detach() + + // run a one-shot TCP server to handle the connection. + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("could not setup listener: %v") + } + defer listener.Close() + accepted := make(chan net.Conn) + go func() { + conn, err := listener.Accept() + if err != nil { + t.Error("accept error:", err) + return + } + conn.Close() + accepted <- conn + }() + + // start the server + connected := make(chan *Peer) + srv := startTestServer(t, func(p *Peer) { connected <- p }) + defer close(connected) + defer srv.Stop() + + // tell the server to connect + tcpAddr := listener.Addr().(*net.TCPAddr) + srv.SuggestPeer(&discover.Node{IP: tcpAddr.IP, TCPPort: tcpAddr.Port}) + + select { + case conn := <-accepted: + select { + case peer := <-connected: + if peer.RemoteAddr().String() != conn.LocalAddr().String() { + t.Errorf("peer started with wrong conn: got %v, want %v", + peer.RemoteAddr(), conn.LocalAddr()) + } + // TODO: validate more fields + case <-time.After(1 * time.Second): + t.Error("server did not launch peer within one second") + } + + case <-time.After(1 * time.Second): + t.Error("server did not connect within one second") + } +} + +func TestServerBroadcast(t *testing.T) { + defer testlog(t).detach() + + var connected sync.WaitGroup + srv := startTestServer(t, func(p *Peer) { + p.running = matchProtocols([]Protocol{discard}, []Cap{discard.cap()}, p.rw) + connected.Done() + }) + defer srv.Stop() + + // create a few peers + var conns = make([]net.Conn, 8) + connected.Add(len(conns)) + deadline := time.Now().Add(3 * time.Second) + dialer := &net.Dialer{Deadline: deadline} + for i := range conns { + conn, err := dialer.Dial("tcp", srv.ListenAddr) + if err != nil { + t.Fatalf("conn %d: dial error: %v", i, err) + } + defer conn.Close() + conn.SetDeadline(deadline) + conns[i] = conn + } + connected.Wait() + + // broadcast one message + srv.Broadcast("discard", 0, "foo") + golden := unhex("66e94d166f0a2c3b884cfa59ca34") + + // check that the message has been written everywhere + for i, conn := range conns { + buf := make([]byte, len(golden)) + if _, err := io.ReadFull(conn, buf); err != nil { + t.Errorf("conn %d: read error: %v", i, err) + } else if !bytes.Equal(buf, golden) { + t.Errorf("conn %d: msg mismatch\ngot: %x\nwant: %x", i, buf, golden) + } + } +} + +func newkey() *ecdsa.PrivateKey { + key, err := crypto.GenerateKey() + if err != nil { + panic("couldn't generate key: " + err.Error()) + } + return key +} + +func randomID() (id discover.NodeID) { + for i := range id { + id[i] = byte(rand.Intn(255)) + } + return id +} diff --git a/p2p/testlog_test.go b/p2p/testlog_test.go new file mode 100644 index 000000000..c524c154c --- /dev/null +++ b/p2p/testlog_test.go @@ -0,0 +1,28 @@ +package p2p + +import ( + "testing" + + "github.com/ethereum/go-ethereum/logger" +) + +type testLogger struct{ t *testing.T } + +func testlog(t *testing.T) testLogger { + logger.Reset() + l := testLogger{t} + logger.AddLogSystem(l) + return l +} + +func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugDetailLevel } +func (testLogger) SetLogLevel(logger.LogLevel) {} + +func (l testLogger) LogPrint(level logger.LogLevel, msg string) { + l.t.Logf("%s", msg) +} + +func (testLogger) detach() { + logger.Flush() + logger.Reset() +} |