diff options
Diffstat (limited to 'p2p')
-rw-r--r-- | p2p/client_identity.go | 63 | ||||
-rw-r--r-- | p2p/client_identity_test.go | 30 | ||||
-rw-r--r-- | p2p/crypto.go | 363 | ||||
-rw-r--r-- | p2p/crypto_test.go | 167 | ||||
-rw-r--r-- | p2p/discover/node.go | 291 | ||||
-rw-r--r-- | p2p/discover/node_test.go | 201 | ||||
-rw-r--r-- | p2p/discover/table.go | 280 | ||||
-rw-r--r-- | p2p/discover/table_test.go | 311 | ||||
-rw-r--r-- | p2p/discover/udp.go | 431 | ||||
-rw-r--r-- | p2p/discover/udp_test.go | 211 | ||||
-rw-r--r-- | p2p/message.go | 143 | ||||
-rw-r--r-- | p2p/message_test.go | 140 | ||||
-rw-r--r-- | p2p/nat.go | 23 | ||||
-rw-r--r-- | p2p/nat/nat.go | 235 | ||||
-rw-r--r-- | p2p/nat/natpmp.go | 115 | ||||
-rw-r--r-- | p2p/nat/natupnp.go | 149 | ||||
-rw-r--r-- | p2p/natpmp.go | 55 | ||||
-rw-r--r-- | p2p/natupnp.go | 341 | ||||
-rw-r--r-- | p2p/peer.go | 472 | ||||
-rw-r--r-- | p2p/peer_error.go | 58 | ||||
-rw-r--r-- | p2p/peer_test.go | 297 | ||||
-rw-r--r-- | p2p/protocol.go | 244 | ||||
-rw-r--r-- | p2p/protocol_test.go | 158 | ||||
-rw-r--r-- | p2p/server.go | 402 | ||||
-rw-r--r-- | p2p/server_test.go | 87 | ||||
-rw-r--r-- | p2p/testlog_test.go | 2 | ||||
-rw-r--r-- | p2p/testpoc7.go | 40 |
27 files changed, 3620 insertions, 1689 deletions
diff --git a/p2p/client_identity.go b/p2p/client_identity.go deleted file mode 100644 index f15fd01bf..000000000 --- a/p2p/client_identity.go +++ /dev/null @@ -1,63 +0,0 @@ -package p2p - -import ( - "fmt" - "runtime" -) - -// ClientIdentity represents the identity of a peer. -type ClientIdentity interface { - String() string // human readable identity - Pubkey() []byte // 512-bit public key -} - -type SimpleClientIdentity struct { - clientIdentifier string - version string - customIdentifier string - os string - implementation string - pubkey []byte -} - -func NewSimpleClientIdentity(clientIdentifier string, version string, customIdentifier string, pubkey []byte) *SimpleClientIdentity { - clientIdentity := &SimpleClientIdentity{ - clientIdentifier: clientIdentifier, - version: version, - customIdentifier: customIdentifier, - os: runtime.GOOS, - implementation: runtime.Version(), - pubkey: pubkey, - } - - return clientIdentity -} - -func (c *SimpleClientIdentity) init() { -} - -func (c *SimpleClientIdentity) String() string { - var id string - if len(c.customIdentifier) > 0 { - id = "/" + c.customIdentifier - } - - return fmt.Sprintf("%s/v%s%s/%s/%s", - c.clientIdentifier, - c.version, - id, - c.os, - c.implementation) -} - -func (c *SimpleClientIdentity) Pubkey() []byte { - return []byte(c.pubkey) -} - -func (c *SimpleClientIdentity) SetCustomIdentifier(customIdentifier string) { - c.customIdentifier = customIdentifier -} - -func (c *SimpleClientIdentity) GetCustomIdentifier() string { - return c.customIdentifier -} diff --git a/p2p/client_identity_test.go b/p2p/client_identity_test.go deleted file mode 100644 index 7248a7b1a..000000000 --- a/p2p/client_identity_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package p2p - -import ( - "fmt" - "runtime" - "testing" -) - -func TestClientIdentity(t *testing.T) { - clientIdentity := NewSimpleClientIdentity("Ethereum(G)", "0.5.16", "test", []byte("pubkey")) - clientString := clientIdentity.String() - expected := fmt.Sprintf("Ethereum(G)/v0.5.16/test/%s/%s", runtime.GOOS, runtime.Version()) - if clientString != expected { - t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString) - } - customIdentifier := clientIdentity.GetCustomIdentifier() - if customIdentifier != "test" { - t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test', got %v", customIdentifier) - } - clientIdentity.SetCustomIdentifier("test2") - customIdentifier = clientIdentity.GetCustomIdentifier() - if customIdentifier != "test2" { - t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test2', got %v", customIdentifier) - } - clientString = clientIdentity.String() - expected = fmt.Sprintf("Ethereum(G)/v0.5.16/test2/%s/%s", runtime.GOOS, runtime.Version()) - if clientString != expected { - t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString) - } -} diff --git a/p2p/crypto.go b/p2p/crypto.go new file mode 100644 index 000000000..2692d708c --- /dev/null +++ b/p2p/crypto.go @@ -0,0 +1,363 @@ +package p2p + +import ( + // "binary" + "crypto/ecdsa" + "crypto/rand" + "fmt" + "io" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/secp256k1" + ethlogger "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/obscuren/ecies" +) + +var clogger = ethlogger.NewLogger("CRYPTOID") + +const ( + sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2 + sigLen = 65 // elliptic S256 + pubLen = 64 // 512 bit pubkey in uncompressed representation without format byte + shaLen = 32 // hash length (for nonce etc) + + authMsgLen = sigLen + shaLen + pubLen + shaLen + 1 + authRespLen = pubLen + shaLen + 1 + + eciesBytes = 65 + 16 + 32 + iHSLen = authMsgLen + eciesBytes // size of the final ECIES payload sent as initiator's handshake + rHSLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake +) + +type hexkey []byte + +func (self hexkey) String() string { + return fmt.Sprintf("(%d) %x", len(self), []byte(self)) +} + +func encHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, dial *discover.Node) ( + remoteID discover.NodeID, + sessionToken []byte, + err error, +) { + if dial == nil { + var remotePubkey []byte + sessionToken, remotePubkey, err = inboundEncHandshake(conn, prv, nil) + copy(remoteID[:], remotePubkey) + } else { + remoteID = dial.ID + sessionToken, err = outboundEncHandshake(conn, prv, remoteID[:], nil) + } + return remoteID, sessionToken, err +} + +// outboundEncHandshake negotiates a session token on conn. +// it should be called on the dialing side of the connection. +// +// privateKey is the local client's private key +// remotePublicKey is the remote peer's node ID +// sessionToken is the token from a previous session with this node. +func outboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, remotePublicKey []byte, sessionToken []byte) ( + newSessionToken []byte, + err error, +) { + auth, initNonce, randomPrivKey, err := authMsg(prvKey, remotePublicKey, sessionToken) + if err != nil { + return nil, err + } + if sessionToken != nil { + clogger.Debugf("session-token: %v", hexkey(sessionToken)) + } + + clogger.Debugf("initiator-nonce: %v", hexkey(initNonce)) + clogger.Debugf("initiator-random-private-key: %v", hexkey(crypto.FromECDSA(randomPrivKey))) + randomPublicKeyS, _ := exportPublicKey(&randomPrivKey.PublicKey) + clogger.Debugf("initiator-random-public-key: %v", hexkey(randomPublicKeyS)) + if _, err = conn.Write(auth); err != nil { + return nil, err + } + clogger.Debugf("initiator handshake: %v", hexkey(auth)) + + response := make([]byte, rHSLen) + if _, err = io.ReadFull(conn, response); err != nil { + return nil, err + } + recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prvKey) + if err != nil { + return nil, err + } + + clogger.Debugf("receiver-nonce: %v", hexkey(recNonce)) + remoteRandomPubKeyS, _ := exportPublicKey(remoteRandomPubKey) + clogger.Debugf("receiver-random-public-key: %v", hexkey(remoteRandomPubKeyS)) + return newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey) +} + +// authMsg creates the initiator handshake. +func authMsg(prvKey *ecdsa.PrivateKey, remotePubKeyS, sessionToken []byte) ( + auth, initNonce []byte, + randomPrvKey *ecdsa.PrivateKey, + err error, +) { + // session init, common to both parties + remotePubKey, err := importPublicKey(remotePubKeyS) + if err != nil { + return + } + + var tokenFlag byte // = 0x00 + if sessionToken == nil { + // no session token found means we need to generate shared secret. + // ecies shared secret is used as initial session token for new peers + // generate shared key from prv and remote pubkey + if sessionToken, err = ecies.ImportECDSA(prvKey).GenerateShared(ecies.ImportECDSAPublic(remotePubKey), sskLen, sskLen); err != nil { + return + } + // tokenFlag = 0x00 // redundant + } else { + // for known peers, we use stored token from the previous session + tokenFlag = 0x01 + } + + //E(remote-pubk, S(ecdhe-random, ecdh-shared-secret^nonce) || H(ecdhe-random-pubk) || pubk || nonce || 0x0) + // E(remote-pubk, S(ecdhe-random, token^nonce) || H(ecdhe-random-pubk) || pubk || nonce || 0x1) + // allocate msgLen long message, + var msg []byte = make([]byte, authMsgLen) + initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1] + if _, err = rand.Read(initNonce); err != nil { + return + } + // create known message + // ecdh-shared-secret^nonce for new peers + // token^nonce for old peers + var sharedSecret = xor(sessionToken, initNonce) + + // generate random keypair to use for signing + if randomPrvKey, err = crypto.GenerateKey(); err != nil { + return + } + // sign shared secret (message known to both parties): shared-secret + var signature []byte + // signature = sign(ecdhe-random, shared-secret) + // uses secp256k1.Sign + if signature, err = crypto.Sign(sharedSecret, randomPrvKey); err != nil { + return + } + + // message + // signed-shared-secret || H(ecdhe-random-pubk) || pubk || nonce || 0x0 + copy(msg, signature) // copy signed-shared-secret + // H(ecdhe-random-pubk) + var randomPubKey64 []byte + if randomPubKey64, err = exportPublicKey(&randomPrvKey.PublicKey); err != nil { + return + } + var pubKey64 []byte + if pubKey64, err = exportPublicKey(&prvKey.PublicKey); err != nil { + return + } + copy(msg[sigLen:sigLen+shaLen], crypto.Sha3(randomPubKey64)) + // pubkey copied to the correct segment. + copy(msg[sigLen+shaLen:sigLen+shaLen+pubLen], pubKey64) + // nonce is already in the slice + // stick tokenFlag byte to the end + msg[authMsgLen-1] = tokenFlag + + // encrypt using remote-pubk + // auth = eciesEncrypt(remote-pubk, msg) + if auth, err = crypto.Encrypt(remotePubKey, msg); err != nil { + return + } + return +} + +// completeHandshake is called when the initiator receives an +// authentication response (aka receiver handshake). It completes the +// handshake by reading off parameters the remote peer provides needed +// to set up the secure session. +func completeHandshake(auth []byte, prvKey *ecdsa.PrivateKey) ( + respNonce []byte, + remoteRandomPubKey *ecdsa.PublicKey, + tokenFlag bool, + err error, +) { + var msg []byte + // they prove that msg is meant for me, + // I prove I possess private key if i can read it + if msg, err = crypto.Decrypt(prvKey, auth); err != nil { + return + } + + respNonce = msg[pubLen : pubLen+shaLen] + var remoteRandomPubKeyS = msg[:pubLen] + if remoteRandomPubKey, err = importPublicKey(remoteRandomPubKeyS); err != nil { + return + } + if msg[authRespLen-1] == 0x01 { + tokenFlag = true + } + return +} + +// inboundEncHandshake negotiates a session token on conn. +// it should be called on the listening side of the connection. +// +// privateKey is the local client's private key +// sessionToken is the token from a previous session with this node. +func inboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, sessionToken []byte) ( + token, remotePubKey []byte, + err error, +) { + // we are listening connection. we are responders in the + // handshake. Extract info from the authentication. The initiator + // starts by sending us a handshake that we need to respond to. so + // we read auth message first, then respond. + auth := make([]byte, iHSLen) + if _, err := io.ReadFull(conn, auth); err != nil { + return nil, nil, err + } + response, recNonce, initNonce, remotePubKey, randomPrivKey, remoteRandomPubKey, err := authResp(auth, sessionToken, prvKey) + if err != nil { + return nil, nil, err + } + clogger.Debugf("receiver-nonce: %v", hexkey(recNonce)) + clogger.Debugf("receiver-random-priv-key: %v", hexkey(crypto.FromECDSA(randomPrivKey))) + if _, err = conn.Write(response); err != nil { + return nil, nil, err + } + clogger.Debugf("receiver handshake:\n%v", hexkey(response)) + token, err = newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey) + return token, remotePubKey, err +} + +// authResp is called by peer if it accepted (but not +// initiated) the connection from the remote. It is passed the initiator +// handshake received and the session token belonging to the +// remote initiator. +// +// The first return value is the authentication response (aka receiver +// handshake) that is to be sent to the remote initiator. +func authResp(auth, sessionToken []byte, prvKey *ecdsa.PrivateKey) ( + authResp, respNonce, initNonce, remotePubKeyS []byte, + randomPrivKey *ecdsa.PrivateKey, + remoteRandomPubKey *ecdsa.PublicKey, + err error, +) { + // they prove that msg is meant for me, + // I prove I possess private key if i can read it + msg, err := crypto.Decrypt(prvKey, auth) + if err != nil { + return + } + + remotePubKeyS = msg[sigLen+shaLen : sigLen+shaLen+pubLen] + remotePubKey, _ := importPublicKey(remotePubKeyS) + + var tokenFlag byte + if sessionToken == nil { + // no session token found means we need to generate shared secret. + // ecies shared secret is used as initial session token for new peers + // generate shared key from prv and remote pubkey + if sessionToken, err = ecies.ImportECDSA(prvKey).GenerateShared(ecies.ImportECDSAPublic(remotePubKey), sskLen, sskLen); err != nil { + return + } + // tokenFlag = 0x00 // redundant + } else { + // for known peers, we use stored token from the previous session + tokenFlag = 0x01 + } + + // the initiator nonce is read off the end of the message + initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1] + // I prove that i own prv key (to derive shared secret, and read + // nonce off encrypted msg) and that I own shared secret they + // prove they own the private key belonging to ecdhe-random-pubk + // we can now reconstruct the signed message and recover the peers + // pubkey + var signedMsg = xor(sessionToken, initNonce) + var remoteRandomPubKeyS []byte + if remoteRandomPubKeyS, err = secp256k1.RecoverPubkey(signedMsg, msg[:sigLen]); err != nil { + return + } + // convert to ECDSA standard + if remoteRandomPubKey, err = importPublicKey(remoteRandomPubKeyS); err != nil { + return + } + + // now we find ourselves a long task too, fill it random + var resp = make([]byte, authRespLen) + // generate shaLen long nonce + respNonce = resp[pubLen : pubLen+shaLen] + if _, err = rand.Read(respNonce); err != nil { + return + } + // generate random keypair for session + if randomPrivKey, err = crypto.GenerateKey(); err != nil { + return + } + // responder auth message + // E(remote-pubk, ecdhe-random-pubk || nonce || 0x0) + var randomPubKeyS []byte + if randomPubKeyS, err = exportPublicKey(&randomPrivKey.PublicKey); err != nil { + return + } + copy(resp[:pubLen], randomPubKeyS) + // nonce is already in the slice + resp[authRespLen-1] = tokenFlag + + // encrypt using remote-pubk + // auth = eciesEncrypt(remote-pubk, msg) + // why not encrypt with ecdhe-random-remote + if authResp, err = crypto.Encrypt(remotePubKey, resp); err != nil { + return + } + return +} + +// newSession is called after the handshake is completed. The +// arguments are values negotiated in the handshake. The return value +// is a new session Token to be remembered for the next time we +// connect with this peer. +func newSession(initNonce, respNonce []byte, privKey *ecdsa.PrivateKey, remoteRandomPubKey *ecdsa.PublicKey) ([]byte, error) { + // 3) Now we can trust ecdhe-random-pubk to derive new keys + //ecdhe-shared-secret = ecdh.agree(ecdhe-random, remote-ecdhe-random-pubk) + pubKey := ecies.ImportECDSAPublic(remoteRandomPubKey) + dhSharedSecret, err := ecies.ImportECDSA(privKey).GenerateShared(pubKey, sskLen, sskLen) + if err != nil { + return nil, err + } + sharedSecret := crypto.Sha3(dhSharedSecret, crypto.Sha3(respNonce, initNonce)) + sessionToken := crypto.Sha3(sharedSecret) + return sessionToken, nil +} + +// importPublicKey unmarshals 512 bit public keys. +func importPublicKey(pubKey []byte) (pubKeyEC *ecdsa.PublicKey, err error) { + var pubKey65 []byte + switch len(pubKey) { + case 64: + // add 'uncompressed key' flag + pubKey65 = append([]byte{0x04}, pubKey...) + case 65: + pubKey65 = pubKey + default: + return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey)) + } + return crypto.ToECDSAPub(pubKey65), nil +} + +func exportPublicKey(pubKeyEC *ecdsa.PublicKey) (pubKey []byte, err error) { + if pubKeyEC == nil { + return nil, fmt.Errorf("no ECDSA public key given") + } + return crypto.FromECDSAPub(pubKeyEC)[1:], nil +} + +func xor(one, other []byte) (xor []byte) { + xor = make([]byte, len(one)) + for i := 0; i < len(one); i++ { + xor[i] = one[i] ^ other[i] + } + return xor +} diff --git a/p2p/crypto_test.go b/p2p/crypto_test.go new file mode 100644 index 000000000..0a9d49f96 --- /dev/null +++ b/p2p/crypto_test.go @@ -0,0 +1,167 @@ +package p2p + +import ( + "bytes" + "crypto/ecdsa" + "crypto/rand" + "net" + "testing" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/obscuren/ecies" +) + +func TestPublicKeyEncoding(t *testing.T) { + prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader) + pub0 := &prv0.PublicKey + pub0s := crypto.FromECDSAPub(pub0) + pub1, err := importPublicKey(pub0s) + if err != nil { + t.Errorf("%v", err) + } + eciesPub1 := ecies.ImportECDSAPublic(pub1) + if eciesPub1 == nil { + t.Errorf("invalid ecdsa public key") + } + pub1s, err := exportPublicKey(pub1) + if err != nil { + t.Errorf("%v", err) + } + if len(pub1s) != 64 { + t.Errorf("wrong length expect 64, got", len(pub1s)) + } + pub2, err := importPublicKey(pub1s) + if err != nil { + t.Errorf("%v", err) + } + pub2s, err := exportPublicKey(pub2) + if err != nil { + t.Errorf("%v", err) + } + if !bytes.Equal(pub1s, pub2s) { + t.Errorf("exports dont match") + } + pub2sEC := crypto.FromECDSAPub(pub2) + if !bytes.Equal(pub0s, pub2sEC) { + t.Errorf("exports dont match") + } +} + +func TestSharedSecret(t *testing.T) { + prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader) + pub0 := &prv0.PublicKey + prv1, _ := crypto.GenerateKey() + pub1 := &prv1.PublicKey + + ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen) + if err != nil { + return + } + ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen) + if err != nil { + return + } + t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1) + if !bytes.Equal(ss0, ss1) { + t.Errorf("dont match :(") + } +} + +func TestCryptoHandshake(t *testing.T) { + testCryptoHandshake(newkey(), newkey(), nil, t) +} + +func TestCryptoHandshakeWithToken(t *testing.T) { + sessionToken := make([]byte, shaLen) + rand.Read(sessionToken) + testCryptoHandshake(newkey(), newkey(), sessionToken, t) +} + +func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *testing.T) { + var err error + // pub0 := &prv0.PublicKey + pub1 := &prv1.PublicKey + + // pub0s := crypto.FromECDSAPub(pub0) + pub1s := crypto.FromECDSAPub(pub1) + + // simulate handshake by feeding output to input + // initiator sends handshake 'auth' + auth, initNonce, randomPrivKey, err := authMsg(prv0, pub1s, sessionToken) + if err != nil { + t.Errorf("%v", err) + } + t.Logf("-> %v", hexkey(auth)) + + // receiver reads auth and responds with response + response, remoteRecNonce, remoteInitNonce, _, remoteRandomPrivKey, remoteInitRandomPubKey, err := authResp(auth, sessionToken, prv1) + if err != nil { + t.Errorf("%v", err) + } + t.Logf("<- %v\n", hexkey(response)) + + // initiator reads receiver's response and the key exchange completes + recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prv0) + if err != nil { + t.Errorf("completeHandshake error: %v", err) + } + + // now both parties should have the same session parameters + initSessionToken, err := newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey) + if err != nil { + t.Errorf("newSession error: %v", err) + } + + recSessionToken, err := newSession(remoteInitNonce, remoteRecNonce, remoteRandomPrivKey, remoteInitRandomPubKey) + if err != nil { + t.Errorf("newSession error: %v", err) + } + + // fmt.Printf("\nauth (%v) %x\n\nresp (%v) %x\n\n", len(auth), auth, len(response), response) + + // fmt.Printf("\nauth %x\ninitNonce %x\nresponse%x\nremoteRecNonce %x\nremoteInitNonce %x\nremoteRandomPubKey %x\nrecNonce %x\nremoteInitRandomPubKey %x\ninitSessionToken %x\n\n", auth, initNonce, response, remoteRecNonce, remoteInitNonce, remoteRandomPubKey, recNonce, remoteInitRandomPubKey, initSessionToken) + + if !bytes.Equal(initNonce, remoteInitNonce) { + t.Errorf("nonces do not match") + } + if !bytes.Equal(recNonce, remoteRecNonce) { + t.Errorf("receiver nonces do not match") + } + if !bytes.Equal(initSessionToken, recSessionToken) { + t.Errorf("session tokens do not match") + } +} + +func TestHandshake(t *testing.T) { + defer testlog(t).detach() + + prv0, _ := crypto.GenerateKey() + prv1, _ := crypto.GenerateKey() + pub0s, _ := exportPublicKey(&prv0.PublicKey) + pub1s, _ := exportPublicKey(&prv1.PublicKey) + rw0, rw1 := net.Pipe() + tokens := make(chan []byte) + + go func() { + token, err := outboundEncHandshake(rw0, prv0, pub1s, nil) + if err != nil { + t.Errorf("outbound side error: %v", err) + } + tokens <- token + }() + go func() { + token, remotePubkey, err := inboundEncHandshake(rw1, prv1, nil) + if err != nil { + t.Errorf("inbound side error: %v", err) + } + if !bytes.Equal(remotePubkey, pub0s) { + t.Errorf("inbound side returned wrong remote pubkey\n got: %x\n want: %x", remotePubkey, pub0s) + } + tokens <- token + }() + + t1, t2 := <-tokens, <-tokens + if !bytes.Equal(t1, t2) { + t.Error("session token mismatch") + } +} diff --git a/p2p/discover/node.go b/p2p/discover/node.go new file mode 100644 index 000000000..c6d2e9766 --- /dev/null +++ b/p2p/discover/node.go @@ -0,0 +1,291 @@ +package discover + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "encoding/hex" + "errors" + "fmt" + "io" + "math/rand" + "net" + "net/url" + "strconv" + "strings" + "time" + + "github.com/ethereum/go-ethereum/crypto/secp256k1" + "github.com/ethereum/go-ethereum/rlp" +) + +const nodeIDBits = 512 + +// Node represents a host on the network. +type Node struct { + ID NodeID + IP net.IP + + DiscPort int // UDP listening port for discovery protocol + TCPPort int // TCP listening port for RLPx + + active time.Time +} + +func newNode(id NodeID, addr *net.UDPAddr) *Node { + return &Node{ + ID: id, + IP: addr.IP, + DiscPort: addr.Port, + TCPPort: addr.Port, + active: time.Now(), + } +} + +func (n *Node) isValid() bool { + // TODO: don't accept localhost, LAN addresses from internet hosts + return !n.IP.IsMulticast() && !n.IP.IsUnspecified() && n.TCPPort != 0 && n.DiscPort != 0 +} + +// The string representation of a Node is a URL. +// Please see ParseNode for a description of the format. +func (n *Node) String() string { + addr := net.TCPAddr{IP: n.IP, Port: n.TCPPort} + u := url.URL{ + Scheme: "enode", + User: url.User(fmt.Sprintf("%x", n.ID[:])), + Host: addr.String(), + } + if n.DiscPort != n.TCPPort { + u.RawQuery = "discport=" + strconv.Itoa(n.DiscPort) + } + return u.String() +} + +// ParseNode parses a node URL. +// +// A node URL has scheme "enode". +// +// The hexadecimal node ID is encoded in the username portion of the +// URL, separated from the host by an @ sign. The hostname can only be +// given as an IP address, DNS domain names are not allowed. The port +// in the host name section is the TCP listening port. If the TCP and +// UDP (discovery) ports differ, the UDP port is specified as query +// parameter "discport". +// +// In the following example, the node URL describes +// a node with IP address 10.3.58.6, TCP listening port 30303 +// and UDP discovery port 30301. +// +// enode://<hex node id>@10.3.58.6:30303?discport=30301 +func ParseNode(rawurl string) (*Node, error) { + var n Node + u, err := url.Parse(rawurl) + if u.Scheme != "enode" { + return nil, errors.New("invalid URL scheme, want \"enode\"") + } + if u.User == nil { + return nil, errors.New("does not contain node ID") + } + if n.ID, err = HexID(u.User.String()); err != nil { + return nil, fmt.Errorf("invalid node ID (%v)", err) + } + ip, port, err := net.SplitHostPort(u.Host) + if err != nil { + return nil, fmt.Errorf("invalid host: %v", err) + } + if n.IP = net.ParseIP(ip); n.IP == nil { + return nil, errors.New("invalid IP address") + } + if n.TCPPort, err = strconv.Atoi(port); err != nil { + return nil, errors.New("invalid port") + } + qv := u.Query() + if qv.Get("discport") == "" { + n.DiscPort = n.TCPPort + } else { + if n.DiscPort, err = strconv.Atoi(qv.Get("discport")); err != nil { + return nil, errors.New("invalid discport in query") + } + } + return &n, nil +} + +// MustParseNode parses a node URL. It panics if the URL is not valid. +func MustParseNode(rawurl string) *Node { + n, err := ParseNode(rawurl) + if err != nil { + panic("invalid node URL: " + err.Error()) + } + return n +} + +func (n Node) EncodeRLP(w io.Writer) error { + return rlp.Encode(w, rpcNode{IP: n.IP.String(), Port: uint16(n.TCPPort), ID: n.ID}) +} +func (n *Node) DecodeRLP(s *rlp.Stream) (err error) { + var ext rpcNode + if err = s.Decode(&ext); err == nil { + n.TCPPort = int(ext.Port) + n.DiscPort = int(ext.Port) + n.ID = ext.ID + if n.IP = net.ParseIP(ext.IP); n.IP == nil { + return errors.New("invalid IP string") + } + } + return err +} + +// NodeID is a unique identifier for each node. +// The node identifier is a marshaled elliptic curve public key. +type NodeID [nodeIDBits / 8]byte + +// NodeID prints as a long hexadecimal number. +func (n NodeID) String() string { + return fmt.Sprintf("%#x", n[:]) +} + +// The Go syntax representation of a NodeID is a call to HexID. +func (n NodeID) GoString() string { + return fmt.Sprintf("discover.HexID(\"%#x\")", n[:]) +} + +// HexID converts a hex string to a NodeID. +// The string may be prefixed with 0x. +func HexID(in string) (NodeID, error) { + if strings.HasPrefix(in, "0x") { + in = in[2:] + } + var id NodeID + b, err := hex.DecodeString(in) + if err != nil { + return id, err + } else if len(b) != len(id) { + return id, fmt.Errorf("wrong length, need %d hex bytes", len(id)) + } + copy(id[:], b) + return id, nil +} + +// MustHexID converts a hex string to a NodeID. +// It panics if the string is not a valid NodeID. +func MustHexID(in string) NodeID { + id, err := HexID(in) + if err != nil { + panic(err) + } + return id +} + +// PubkeyID returns a marshaled representation of the given public key. +func PubkeyID(pub *ecdsa.PublicKey) NodeID { + var id NodeID + pbytes := elliptic.Marshal(pub.Curve, pub.X, pub.Y) + if len(pbytes)-1 != len(id) { + panic(fmt.Errorf("need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pbytes))) + } + copy(id[:], pbytes[1:]) + return id +} + +// recoverNodeID computes the public key used to sign the +// given hash from the signature. +func recoverNodeID(hash, sig []byte) (id NodeID, err error) { + pubkey, err := secp256k1.RecoverPubkey(hash, sig) + if err != nil { + return id, err + } + if len(pubkey)-1 != len(id) { + return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8) + } + for i := range id { + id[i] = pubkey[i+1] + } + return id, nil +} + +// distcmp compares the distances a->target and b->target. +// Returns -1 if a is closer to target, 1 if b is closer to target +// and 0 if they are equal. +func distcmp(target, a, b NodeID) int { + for i := range target { + da := a[i] ^ target[i] + db := b[i] ^ target[i] + if da > db { + return 1 + } else if da < db { + return -1 + } + } + return 0 +} + +// table of leading zero counts for bytes [0..255] +var lzcount = [256]int{ + 8, 7, 6, 6, 5, 5, 5, 5, + 4, 4, 4, 4, 4, 4, 4, 4, + 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, + 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, +} + +// logdist returns the logarithmic distance between a and b, log2(a ^ b). +func logdist(a, b NodeID) int { + lz := 0 + for i := range a { + x := a[i] ^ b[i] + if x == 0 { + lz += 8 + } else { + lz += lzcount[x] + break + } + } + return len(a)*8 - lz +} + +// randomID returns a random NodeID such that logdist(a, b) == n +func randomID(a NodeID, n int) (b NodeID) { + if n == 0 { + return a + } + // flip bit at position n, fill the rest with random bits + b = a + pos := len(a) - n/8 - 1 + bit := byte(0x01) << (byte(n%8) - 1) + if bit == 0 { + pos++ + bit = 0x80 + } + b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits + for i := pos + 1; i < len(a); i++ { + b[i] = byte(rand.Intn(255)) + } + return b +} diff --git a/p2p/discover/node_test.go b/p2p/discover/node_test.go new file mode 100644 index 000000000..ae82ae4f1 --- /dev/null +++ b/p2p/discover/node_test.go @@ -0,0 +1,201 @@ +package discover + +import ( + "math/big" + "math/rand" + "net" + "reflect" + "testing" + "testing/quick" + "time" + + "github.com/ethereum/go-ethereum/crypto" +) + +var ( + quickrand = rand.New(rand.NewSource(time.Now().Unix())) + quickcfg = &quick.Config{MaxCount: 5000, Rand: quickrand} +) + +var parseNodeTests = []struct { + rawurl string + wantError string + wantResult *Node +}{ + { + rawurl: "http://foobar", + wantError: `invalid URL scheme, want "enode"`, + }, + { + rawurl: "enode://foobar", + wantError: `does not contain node ID`, + }, + { + rawurl: "enode://01010101@123.124.125.126:3", + wantError: `invalid node ID (wrong length, need 64 hex bytes)`, + }, + { + rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@hostname:3", + wantError: `invalid IP address`, + }, + { + rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:foo", + wantError: `invalid port`, + }, + { + rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:3?discport=foo", + wantError: `invalid discport in query`, + }, + { + rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150", + wantResult: &Node{ + ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + IP: net.ParseIP("127.0.0.1"), + DiscPort: 52150, + TCPPort: 52150, + }, + }, + { + rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@[::]:52150", + wantResult: &Node{ + ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + IP: net.ParseIP("::"), + DiscPort: 52150, + TCPPort: 52150, + }, + }, + { + rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150?discport=223344", + wantResult: &Node{ + ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + IP: net.ParseIP("127.0.0.1"), + DiscPort: 223344, + TCPPort: 52150, + }, + }, +} + +func TestParseNode(t *testing.T) { + for i, test := range parseNodeTests { + n, err := ParseNode(test.rawurl) + if err == nil && test.wantError != "" { + t.Errorf("test %d: got nil error, expected %#q", i, test.wantError) + continue + } + if err != nil && err.Error() != test.wantError { + t.Errorf("test %d: got error %#q, expected %#q", i, err.Error(), test.wantError) + continue + } + if !reflect.DeepEqual(n, test.wantResult) { + t.Errorf("test %d: result mismatch:\ngot: %#v, want: %#v", i, n, test.wantResult) + } + } +} + +func TestNodeString(t *testing.T) { + for i, test := range parseNodeTests { + if test.wantError != "" { + continue + } + str := test.wantResult.String() + if str != test.rawurl { + t.Errorf("test %d: Node.String() mismatch:\ngot: %s\nwant: %s", i, str, test.rawurl) + } + } +} + +func TestHexID(t *testing.T) { + ref := NodeID{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 106, 217, 182, 31, 165, 174, 1, 67, 7, 235, 220, 150, 66, 83, 173, 205, 159, 44, 10, 57, 42, 161, 26, 188} + id1 := MustHexID("0x000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc") + id2 := MustHexID("000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc") + + if id1 != ref { + t.Errorf("wrong id1\ngot %v\nwant %v", id1[:], ref[:]) + } + if id2 != ref { + t.Errorf("wrong id2\ngot %v\nwant %v", id2[:], ref[:]) + } +} + +func TestNodeID_recover(t *testing.T) { + prv := newkey() + hash := make([]byte, 32) + sig, err := crypto.Sign(hash, prv) + if err != nil { + t.Fatalf("signing error: %v", err) + } + + pub := PubkeyID(&prv.PublicKey) + recpub, err := recoverNodeID(hash, sig) + if err != nil { + t.Fatalf("recovery error: %v", err) + } + if pub != recpub { + t.Errorf("recovered wrong pubkey:\ngot: %v\nwant: %v", recpub, pub) + } +} + +func TestNodeID_distcmp(t *testing.T) { + distcmpBig := func(target, a, b NodeID) int { + tbig := new(big.Int).SetBytes(target[:]) + abig := new(big.Int).SetBytes(a[:]) + bbig := new(big.Int).SetBytes(b[:]) + return new(big.Int).Xor(tbig, abig).Cmp(new(big.Int).Xor(tbig, bbig)) + } + if err := quick.CheckEqual(distcmp, distcmpBig, quickcfg); err != nil { + t.Error(err) + } +} + +// the random tests is likely to miss the case where they're equal. +func TestNodeID_distcmpEqual(t *testing.T) { + base := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + x := NodeID{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0} + if distcmp(base, x, x) != 0 { + t.Errorf("distcmp(base, x, x) != 0") + } +} + +func TestNodeID_logdist(t *testing.T) { + logdistBig := func(a, b NodeID) int { + abig, bbig := new(big.Int).SetBytes(a[:]), new(big.Int).SetBytes(b[:]) + return new(big.Int).Xor(abig, bbig).BitLen() + } + if err := quick.CheckEqual(logdist, logdistBig, quickcfg); err != nil { + t.Error(err) + } +} + +// the random tests is likely to miss the case where they're equal. +func TestNodeID_logdistEqual(t *testing.T) { + x := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + if logdist(x, x) != 0 { + t.Errorf("logdist(x, x) != 0") + } +} + +func TestNodeID_randomID(t *testing.T) { + // we don't use quick.Check here because its output isn't + // very helpful when the test fails. + for i := 0; i < quickcfg.MaxCount; i++ { + a := gen(NodeID{}, quickrand).(NodeID) + dist := quickrand.Intn(len(NodeID{}) * 8) + result := randomID(a, dist) + actualdist := logdist(result, a) + + if dist != actualdist { + t.Log("a: ", a) + t.Log("result:", result) + t.Fatalf("#%d: distance of result is %d, want %d", i, actualdist, dist) + } + } +} + +func (NodeID) Generate(rand *rand.Rand, size int) reflect.Value { + var id NodeID + m := rand.Intn(len(id)) + for i := len(id) - 1; i > m; i-- { + id[i] = byte(rand.Uint32()) + } + return reflect.ValueOf(id) +} diff --git a/p2p/discover/table.go b/p2p/discover/table.go new file mode 100644 index 000000000..e3bec9328 --- /dev/null +++ b/p2p/discover/table.go @@ -0,0 +1,280 @@ +// Package discover implements the Node Discovery Protocol. +// +// The Node Discovery protocol provides a way to find RLPx nodes that +// can be connected to. It uses a Kademlia-like protocol to maintain a +// distributed database of the IDs and endpoints of all listening +// nodes. +package discover + +import ( + "net" + "sort" + "sync" + "time" +) + +const ( + alpha = 3 // Kademlia concurrency factor + bucketSize = 16 // Kademlia bucket size + nBuckets = nodeIDBits + 1 // Number of buckets +) + +type Table struct { + mutex sync.Mutex // protects buckets, their content, and nursery + buckets [nBuckets]*bucket // index of known nodes by distance + nursery []*Node // bootstrap nodes + + net transport + self *Node // metadata of the local node +} + +// transport is implemented by the UDP transport. +// it is an interface so we can test without opening lots of UDP +// sockets and without generating a private key. +type transport interface { + ping(*Node) error + findnode(e *Node, target NodeID) ([]*Node, error) + close() +} + +// bucket contains nodes, ordered by their last activity. +type bucket struct { + lastLookup time.Time + entries []*Node +} + +func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr) *Table { + tab := &Table{net: t, self: newNode(ourID, ourAddr)} + for i := range tab.buckets { + tab.buckets[i] = new(bucket) + } + return tab +} + +// Self returns the local node ID. +func (tab *Table) Self() NodeID { + return tab.self.ID +} + +// Close terminates the network listener. +func (tab *Table) Close() { + tab.net.close() +} + +// Bootstrap sets the bootstrap nodes. These nodes are used to connect +// to the network if the table is empty. Bootstrap will also attempt to +// fill the table by performing random lookup operations on the +// network. +func (tab *Table) Bootstrap(nodes []*Node) { + tab.mutex.Lock() + // TODO: maybe filter nodes with bad fields (nil, etc.) to avoid strange crashes + tab.nursery = make([]*Node, 0, len(nodes)) + for _, n := range nodes { + cpy := *n + tab.nursery = append(tab.nursery, &cpy) + } + tab.mutex.Unlock() + tab.refresh() +} + +// Lookup performs a network search for nodes close +// to the given target. It approaches the target by querying +// nodes that are closer to it on each iteration. +func (tab *Table) Lookup(target NodeID) []*Node { + var ( + asked = make(map[NodeID]bool) + seen = make(map[NodeID]bool) + reply = make(chan []*Node, alpha) + pendingQueries = 0 + ) + // don't query further if we hit the target or ourself. + // unlikely to happen often in practice. + asked[target] = true + asked[tab.self.ID] = true + + tab.mutex.Lock() + // update last lookup stamp (for refresh logic) + tab.buckets[logdist(tab.self.ID, target)].lastLookup = time.Now() + // generate initial result set + result := tab.closest(target, bucketSize) + tab.mutex.Unlock() + + for { + // ask the alpha closest nodes that we haven't asked yet + for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ { + n := result.entries[i] + if !asked[n.ID] { + asked[n.ID] = true + pendingQueries++ + go func() { + result, _ := tab.net.findnode(n, target) + reply <- result + }() + } + } + if pendingQueries == 0 { + // we have asked all closest nodes, stop the search + break + } + + // wait for the next reply + for _, n := range <-reply { + cn := n + if !seen[n.ID] { + seen[n.ID] = true + result.push(cn, bucketSize) + } + } + pendingQueries-- + } + return result.entries +} + +// refresh performs a lookup for a random target to keep buckets full. +func (tab *Table) refresh() { + ld := -1 // logdist of chosen bucket + tab.mutex.Lock() + for i, b := range tab.buckets { + if i > 0 && b.lastLookup.Before(time.Now().Add(-1*time.Hour)) { + ld = i + break + } + } + tab.mutex.Unlock() + + result := tab.Lookup(randomID(tab.self.ID, ld)) + if len(result) == 0 { + // bootstrap the table with a self lookup + tab.mutex.Lock() + tab.add(tab.nursery) + tab.mutex.Unlock() + tab.Lookup(tab.self.ID) + // TODO: the Kademlia paper says that we're supposed to perform + // random lookups in all buckets further away than our closest neighbor. + } +} + +// closest returns the n nodes in the table that are closest to the +// given id. The caller must hold tab.mutex. +func (tab *Table) closest(target NodeID, nresults int) *nodesByDistance { + // This is a very wasteful way to find the closest nodes but + // obviously correct. I believe that tree-based buckets would make + // this easier to implement efficiently. + close := &nodesByDistance{target: target} + for _, b := range tab.buckets { + for _, n := range b.entries { + close.push(n, nresults) + } + } + return close +} + +func (tab *Table) len() (n int) { + for _, b := range tab.buckets { + n += len(b.entries) + } + return n +} + +// bumpOrAdd updates the activity timestamp for the given node and +// attempts to insert the node into a bucket. The returned Node might +// not be part of the table. The caller must hold tab.mutex. +func (tab *Table) bumpOrAdd(node NodeID, from *net.UDPAddr) (n *Node) { + b := tab.buckets[logdist(tab.self.ID, node)] + if n = b.bump(node); n == nil { + n = newNode(node, from) + if len(b.entries) == bucketSize { + tab.pingReplace(n, b) + } else { + b.entries = append(b.entries, n) + } + } + return n +} + +func (tab *Table) pingReplace(n *Node, b *bucket) { + old := b.entries[bucketSize-1] + go func() { + if err := tab.net.ping(old); err == nil { + // it responded, we don't need to replace it. + return + } + // it didn't respond, replace the node if it is still the oldest node. + tab.mutex.Lock() + if len(b.entries) > 0 && b.entries[len(b.entries)-1] == old { + // slide down other entries and put the new one in front. + // TODO: insert in correct position to keep the order + copy(b.entries[1:], b.entries) + b.entries[0] = n + } + tab.mutex.Unlock() + }() +} + +// bump updates the activity timestamp for the given node. +// The caller must hold tab.mutex. +func (tab *Table) bump(node NodeID) { + tab.buckets[logdist(tab.self.ID, node)].bump(node) +} + +// add puts the entries into the table if their corresponding +// bucket is not full. The caller must hold tab.mutex. +func (tab *Table) add(entries []*Node) { +outer: + for _, n := range entries { + if n == nil || n.ID == tab.self.ID { + // skip bad entries. The RLP decoder returns nil for empty + // input lists. + continue + } + bucket := tab.buckets[logdist(tab.self.ID, n.ID)] + for i := range bucket.entries { + if bucket.entries[i].ID == n.ID { + // already in bucket + continue outer + } + } + if len(bucket.entries) < bucketSize { + bucket.entries = append(bucket.entries, n) + } + } +} + +func (b *bucket) bump(id NodeID) *Node { + for i, n := range b.entries { + if n.ID == id { + n.active = time.Now() + // move it to the front + copy(b.entries[1:], b.entries[:i+1]) + b.entries[0] = n + return n + } + } + return nil +} + +// nodesByDistance is a list of nodes, ordered by +// distance to target. +type nodesByDistance struct { + entries []*Node + target NodeID +} + +// push adds the given node to the list, keeping the total size below maxElems. +func (h *nodesByDistance) push(n *Node, maxElems int) { + ix := sort.Search(len(h.entries), func(i int) bool { + return distcmp(h.target, h.entries[i].ID, n.ID) > 0 + }) + if len(h.entries) < maxElems { + h.entries = append(h.entries, n) + } + if ix == len(h.entries) { + // farther away than all nodes we already have. + // if there was room for it, the node is now the last element. + } else { + // slide existing entries down to make room + // this will overwrite the entry we just appended. + copy(h.entries[ix+1:], h.entries[ix:]) + h.entries[ix] = n + } +} diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go new file mode 100644 index 000000000..08faea68e --- /dev/null +++ b/p2p/discover/table_test.go @@ -0,0 +1,311 @@ +package discover + +import ( + "crypto/ecdsa" + "errors" + "fmt" + "math/rand" + "net" + "reflect" + "testing" + "testing/quick" + "time" + + "github.com/ethereum/go-ethereum/crypto" +) + +func TestTable_bumpOrAddBucketAssign(t *testing.T) { + tab := newTable(nil, NodeID{}, &net.UDPAddr{}) + for i := 1; i < len(tab.buckets); i++ { + tab.bumpOrAdd(randomID(tab.self.ID, i), &net.UDPAddr{}) + } + for i, b := range tab.buckets { + if i > 0 && len(b.entries) != 1 { + t.Errorf("bucket %d has %d entries, want 1", i, len(b.entries)) + } + } +} + +func TestTable_bumpOrAddPingReplace(t *testing.T) { + pingC := make(pingC) + tab := newTable(pingC, NodeID{}, &net.UDPAddr{}) + last := fillBucket(tab, 200) + + // this bumpOrAdd should not replace the last node + // because the node replies to ping. + new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{}) + + pinged := <-pingC + if pinged != last.ID { + t.Fatalf("pinged wrong node: %v\nwant %v", pinged, last.ID) + } + + tab.mutex.Lock() + defer tab.mutex.Unlock() + if l := len(tab.buckets[200].entries); l != bucketSize { + t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l) + } + if !contains(tab.buckets[200].entries, last.ID) { + t.Error("last entry was removed") + } + if contains(tab.buckets[200].entries, new.ID) { + t.Error("new entry was added") + } +} + +func TestTable_bumpOrAddPingTimeout(t *testing.T) { + tab := newTable(pingC(nil), NodeID{}, &net.UDPAddr{}) + last := fillBucket(tab, 200) + + // this bumpOrAdd should replace the last node + // because the node does not reply to ping. + new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{}) + + // wait for async bucket update. damn. this needs to go away. + time.Sleep(2 * time.Millisecond) + + tab.mutex.Lock() + defer tab.mutex.Unlock() + if l := len(tab.buckets[200].entries); l != bucketSize { + t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l) + } + if contains(tab.buckets[200].entries, last.ID) { + t.Error("last entry was not removed") + } + if !contains(tab.buckets[200].entries, new.ID) { + t.Error("new entry was not added") + } +} + +func fillBucket(tab *Table, ld int) (last *Node) { + b := tab.buckets[ld] + for len(b.entries) < bucketSize { + b.entries = append(b.entries, &Node{ID: randomID(tab.self.ID, ld)}) + } + return b.entries[bucketSize-1] +} + +type pingC chan NodeID + +func (t pingC) findnode(n *Node, target NodeID) ([]*Node, error) { + panic("findnode called on pingRecorder") +} +func (t pingC) close() { + panic("close called on pingRecorder") +} +func (t pingC) ping(n *Node) error { + if t == nil { + return errTimeout + } + t <- n.ID + return nil +} + +func TestTable_bump(t *testing.T) { + tab := newTable(nil, NodeID{}, &net.UDPAddr{}) + + // add an old entry and two recent ones + oldactive := time.Now().Add(-2 * time.Minute) + old := &Node{ID: randomID(tab.self.ID, 200), active: oldactive} + others := []*Node{ + &Node{ID: randomID(tab.self.ID, 200), active: time.Now()}, + &Node{ID: randomID(tab.self.ID, 200), active: time.Now()}, + } + tab.add(append(others, old)) + if tab.buckets[200].entries[0] == old { + t.Fatal("old entry is at front of bucket") + } + + // bumping the old entry should move it to the front + tab.bump(old.ID) + if old.active == oldactive { + t.Error("activity timestamp not updated") + } + if tab.buckets[200].entries[0] != old { + t.Errorf("bumped entry did not move to the front of bucket") + } +} + +func TestTable_closest(t *testing.T) { + t.Parallel() + + test := func(test *closeTest) bool { + // for any node table, Target and N + tab := newTable(nil, test.Self, &net.UDPAddr{}) + tab.add(test.All) + + // check that doClosest(Target, N) returns nodes + result := tab.closest(test.Target, test.N).entries + if hasDuplicates(result) { + t.Errorf("result contains duplicates") + return false + } + if !sortedByDistanceTo(test.Target, result) { + t.Errorf("result is not sorted by distance to target") + return false + } + + // check that the number of results is min(N, tablen) + wantN := test.N + if tlen := tab.len(); tlen < test.N { + wantN = tlen + } + if len(result) != wantN { + t.Errorf("wrong number of nodes: got %d, want %d", len(result), wantN) + return false + } else if len(result) == 0 { + return true // no need to check distance + } + + // check that the result nodes have minimum distance to target. + for _, b := range tab.buckets { + for _, n := range b.entries { + if contains(result, n.ID) { + continue // don't run the check below for nodes in result + } + farthestResult := result[len(result)-1].ID + if distcmp(test.Target, n.ID, farthestResult) < 0 { + t.Errorf("table contains node that is closer to target but it's not in result") + t.Logf(" Target: %v", test.Target) + t.Logf(" Farthest Result: %v", farthestResult) + t.Logf(" ID: %v", n.ID) + return false + } + } + } + return true + } + if err := quick.Check(test, quickcfg); err != nil { + t.Error(err) + } +} + +type closeTest struct { + Self NodeID + Target NodeID + All []*Node + N int +} + +func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value { + t := &closeTest{ + Self: gen(NodeID{}, rand).(NodeID), + Target: gen(NodeID{}, rand).(NodeID), + N: rand.Intn(bucketSize), + } + for _, id := range gen([]NodeID{}, rand).([]NodeID) { + t.All = append(t.All, &Node{ID: id}) + } + return reflect.ValueOf(t) +} + +func TestTable_Lookup(t *testing.T) { + self := gen(NodeID{}, quickrand).(NodeID) + target := randomID(self, 200) + transport := findnodeOracle{t, target} + tab := newTable(transport, self, &net.UDPAddr{}) + + // lookup on empty table returns no nodes + if results := tab.Lookup(target); len(results) > 0 { + t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results) + } + // seed table with initial node (otherwise lookup will terminate immediately) + tab.bumpOrAdd(randomID(target, 200), &net.UDPAddr{Port: 200}) + + results := tab.Lookup(target) + t.Logf("results:") + for _, e := range results { + t.Logf(" ld=%d, %v", logdist(target, e.ID), e.ID) + } + if len(results) != bucketSize { + t.Errorf("wrong number of results: got %d, want %d", len(results), bucketSize) + } + if hasDuplicates(results) { + t.Errorf("result set contains duplicate entries") + } + if !sortedByDistanceTo(target, results) { + t.Errorf("result set not sorted by distance to target") + } + if !contains(results, target) { + t.Errorf("result set does not contain target") + } +} + +// findnode on this transport always returns at least one node +// that is one bucket closer to the target. +type findnodeOracle struct { + t *testing.T + target NodeID +} + +func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) { + t.t.Logf("findnode query at dist %d", n.DiscPort) + // current log distance is encoded in port number + var result []*Node + switch n.DiscPort { + case 0: + panic("query to node at distance 0") + default: + // TODO: add more randomness to distances + next := n.DiscPort - 1 + for i := 0; i < bucketSize; i++ { + result = append(result, &Node{ID: randomID(t.target, next), DiscPort: next}) + } + } + return result, nil +} + +func (t findnodeOracle) close() {} + +func (t findnodeOracle) ping(n *Node) error { + return errors.New("ping is not supported by this transport") +} + +func hasDuplicates(slice []*Node) bool { + seen := make(map[NodeID]bool) + for _, e := range slice { + if seen[e.ID] { + return true + } + seen[e.ID] = true + } + return false +} + +func sortedByDistanceTo(distbase NodeID, slice []*Node) bool { + var last NodeID + for i, e := range slice { + if i > 0 && distcmp(distbase, e.ID, last) < 0 { + return false + } + last = e.ID + } + return true +} + +func contains(ns []*Node, id NodeID) bool { + for _, n := range ns { + if n.ID == id { + return true + } + } + return false +} + +// gen wraps quick.Value so it's easier to use. +// it generates a random value of the given value's type. +func gen(typ interface{}, rand *rand.Rand) interface{} { + v, ok := quick.Value(reflect.TypeOf(typ), rand) + if !ok { + panic(fmt.Sprintf("couldn't generate random value of type %T", typ)) + } + return v.Interface() +} + +func newkey() *ecdsa.PrivateKey { + key, err := crypto.GenerateKey() + if err != nil { + panic("couldn't generate key: " + err.Error()) + } + return key +} diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go new file mode 100644 index 000000000..b2a895442 --- /dev/null +++ b/p2p/discover/udp.go @@ -0,0 +1,431 @@ +package discover + +import ( + "bytes" + "crypto/ecdsa" + "errors" + "fmt" + "net" + "time" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/p2p/nat" + "github.com/ethereum/go-ethereum/rlp" +) + +var log = logger.NewLogger("P2P Discovery") + +// Errors +var ( + errPacketTooSmall = errors.New("too small") + errBadHash = errors.New("bad hash") + errExpired = errors.New("expired") + errTimeout = errors.New("RPC timeout") + errClosed = errors.New("socket closed") +) + +// Timeouts +const ( + respTimeout = 300 * time.Millisecond + sendTimeout = 300 * time.Millisecond + expiration = 20 * time.Second + + refreshInterval = 1 * time.Hour +) + +// RPC packet types +const ( + pingPacket = iota + 1 // zero is 'reserved' + pongPacket + findnodePacket + neighborsPacket +) + +// RPC request structures +type ( + ping struct { + IP string // our IP + Port uint16 // our port + Expiration uint64 + } + + // reply to Ping + pong struct { + ReplyTok []byte + Expiration uint64 + } + + findnode struct { + // Id to look up. The responding node will send back nodes + // closest to the target. + Target NodeID + Expiration uint64 + } + + // reply to findnode + neighbors struct { + Nodes []*Node + Expiration uint64 + } +) + +type rpcNode struct { + IP string + Port uint16 + ID NodeID +} + +// udp implements the RPC protocol. +type udp struct { + conn *net.UDPConn + priv *ecdsa.PrivateKey + addpending chan *pending + replies chan reply + closing chan struct{} + nat nat.Interface + + *Table +} + +// pending represents a pending reply. +// +// some implementations of the protocol wish to send more than one +// reply packet to findnode. in general, any neighbors packet cannot +// be matched up with a specific findnode packet. +// +// our implementation handles this by storing a callback function for +// each pending reply. incoming packets from a node are dispatched +// to all the callback functions for that node. +type pending struct { + // these fields must match in the reply. + from NodeID + ptype byte + + // time when the request must complete + deadline time.Time + + // callback is called when a matching reply arrives. if it returns + // true, the callback is removed from the pending reply queue. + // if it returns false, the reply is considered incomplete and + // the callback will be invoked again for the next matching reply. + callback func(resp interface{}) (done bool) + + // errc receives nil when the callback indicates completion or an + // error if no further reply is received within the timeout. + errc chan<- error +} + +type reply struct { + from NodeID + ptype byte + data interface{} +} + +// ListenUDP returns a new table that listens for UDP packets on laddr. +func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table, error) { + addr, err := net.ResolveUDPAddr("udp", laddr) + if err != nil { + return nil, err + } + conn, err := net.ListenUDP("udp", addr) + if err != nil { + return nil, err + } + udp := &udp{ + conn: conn, + priv: priv, + closing: make(chan struct{}), + addpending: make(chan *pending), + replies: make(chan reply), + } + + realaddr := conn.LocalAddr().(*net.UDPAddr) + if natm != nil { + if !realaddr.IP.IsLoopback() { + go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery") + } + // TODO: react to external IP changes over time. + if ext, err := natm.ExternalIP(); err == nil { + realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port} + } + } + udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr) + + go udp.loop() + go udp.readLoop() + log.Infoln("Listening, ", udp.self) + return udp.Table, nil +} + +func (t *udp) close() { + close(t.closing) + t.conn.Close() + // TODO: wait for the loops to end. +} + +// ping sends a ping message to the given node and waits for a reply. +func (t *udp) ping(e *Node) error { + // TODO: maybe check for ReplyTo field in callback to measure RTT + errc := t.pending(e.ID, pongPacket, func(interface{}) bool { return true }) + t.send(e, pingPacket, ping{ + IP: t.self.IP.String(), + Port: uint16(t.self.TCPPort), + Expiration: uint64(time.Now().Add(expiration).Unix()), + }) + return <-errc +} + +// findnode sends a findnode request to the given node and waits until +// the node has sent up to k neighbors. +func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) { + nodes := make([]*Node, 0, bucketSize) + nreceived := 0 + errc := t.pending(to.ID, neighborsPacket, func(r interface{}) bool { + reply := r.(*neighbors) + for _, n := range reply.Nodes { + nreceived++ + if n.isValid() { + nodes = append(nodes, n) + } + } + return nreceived >= bucketSize + }) + + t.send(to, findnodePacket, findnode{ + Target: target, + Expiration: uint64(time.Now().Add(expiration).Unix()), + }) + err := <-errc + return nodes, err +} + +// pending adds a reply callback to the pending reply queue. +// see the documentation of type pending for a detailed explanation. +func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error { + ch := make(chan error, 1) + p := &pending{from: id, ptype: ptype, callback: callback, errc: ch} + select { + case t.addpending <- p: + // loop will handle it + case <-t.closing: + ch <- errClosed + } + return ch +} + +// loop runs in its own goroutin. it keeps track of +// the refresh timer and the pending reply queue. +func (t *udp) loop() { + var ( + pending []*pending + nextDeadline time.Time + timeout = time.NewTimer(0) + refresh = time.NewTicker(refreshInterval) + ) + <-timeout.C // ignore first timeout + defer refresh.Stop() + defer timeout.Stop() + + rearmTimeout := func() { + if len(pending) == 0 || nextDeadline == pending[0].deadline { + return + } + nextDeadline = pending[0].deadline + timeout.Reset(nextDeadline.Sub(time.Now())) + } + + for { + select { + case <-refresh.C: + go t.refresh() + + case <-t.closing: + for _, p := range pending { + p.errc <- errClosed + } + return + + case p := <-t.addpending: + p.deadline = time.Now().Add(respTimeout) + pending = append(pending, p) + rearmTimeout() + + case reply := <-t.replies: + // run matching callbacks, remove if they return false. + for i, p := range pending { + if reply.from == p.from && reply.ptype == p.ptype && p.callback(reply.data) { + p.errc <- nil + copy(pending[i:], pending[i+1:]) + pending = pending[:len(pending)-1] + i-- + } + } + rearmTimeout() + + case now := <-timeout.C: + // notify and remove callbacks whose deadline is in the past. + i := 0 + for ; i < len(pending) && now.After(pending[i].deadline); i++ { + pending[i].errc <- errTimeout + } + if i > 0 { + copy(pending, pending[i:]) + pending = pending[:len(pending)-i] + } + rearmTimeout() + } + } +} + +const ( + macSize = 256 / 8 + sigSize = 520 / 8 + headSize = macSize + sigSize // space of packet frame data +) + +var headSpace = make([]byte, headSize) + +func (t *udp) send(to *Node, ptype byte, req interface{}) error { + b := new(bytes.Buffer) + b.Write(headSpace) + b.WriteByte(ptype) + if err := rlp.Encode(b, req); err != nil { + log.Errorln("error encoding packet:", err) + return err + } + + packet := b.Bytes() + sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), t.priv) + if err != nil { + log.Errorln("could not sign packet:", err) + return err + } + copy(packet[macSize:], sig) + // add the hash to the front. Note: this doesn't protect the + // packet in any way. Our public key will be part of this hash in + // the future. + copy(packet, crypto.Sha3(packet[macSize:])) + + toaddr := &net.UDPAddr{IP: to.IP, Port: to.DiscPort} + log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req) + if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil { + log.DebugDetailln("UDP send failed:", err) + } + return err +} + +// readLoop runs in its own goroutine. it handles incoming UDP packets. +func (t *udp) readLoop() { + defer t.conn.Close() + buf := make([]byte, 4096) // TODO: good buffer size + for { + nbytes, from, err := t.conn.ReadFromUDP(buf) + if err != nil { + return + } + if err := t.packetIn(from, buf[:nbytes]); err != nil { + log.Debugf("Bad packet from %v: %v\n", from, err) + } + } +} + +func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error { + if len(buf) < headSize+1 { + return errPacketTooSmall + } + hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:] + shouldhash := crypto.Sha3(buf[macSize:]) + if !bytes.Equal(hash, shouldhash) { + return errBadHash + } + fromID, err := recoverNodeID(crypto.Sha3(buf[headSize:]), sig) + if err != nil { + return err + } + + var req interface { + handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error + } + switch ptype := sigdata[0]; ptype { + case pingPacket: + req = new(ping) + case pongPacket: + req = new(pong) + case findnodePacket: + req = new(findnode) + case neighborsPacket: + req = new(neighbors) + default: + return fmt.Errorf("unknown type: %d", ptype) + } + if err := rlp.Decode(bytes.NewReader(sigdata[1:]), req); err != nil { + return err + } + log.DebugDetailf("<<< %v %T %v\n", from, req, req) + return req.handle(t, from, fromID, hash) +} + +func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { + if expired(req.Expiration) { + return errExpired + } + t.mutex.Lock() + // Note: we're ignoring the provided IP address right now + n := t.bumpOrAdd(fromID, from) + if req.Port != 0 { + n.TCPPort = int(req.Port) + } + t.mutex.Unlock() + + t.send(n, pongPacket, pong{ + ReplyTok: mac, + Expiration: uint64(time.Now().Add(expiration).Unix()), + }) + return nil +} + +func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { + if expired(req.Expiration) { + return errExpired + } + t.mutex.Lock() + t.bump(fromID) + t.mutex.Unlock() + + t.replies <- reply{fromID, pongPacket, req} + return nil +} + +func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { + if expired(req.Expiration) { + return errExpired + } + t.mutex.Lock() + e := t.bumpOrAdd(fromID, from) + closest := t.closest(req.Target, bucketSize).entries + t.mutex.Unlock() + + t.send(e, neighborsPacket, neighbors{ + Nodes: closest, + Expiration: uint64(time.Now().Add(expiration).Unix()), + }) + return nil +} + +func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { + if expired(req.Expiration) { + return errExpired + } + t.mutex.Lock() + t.bump(fromID) + t.add(req.Nodes) + t.mutex.Unlock() + + t.replies <- reply{fromID, neighborsPacket, req} + return nil +} + +func expired(ts uint64) bool { + return time.Unix(int64(ts), 0).Before(time.Now()) +} diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go new file mode 100644 index 000000000..0a8ff6358 --- /dev/null +++ b/p2p/discover/udp_test.go @@ -0,0 +1,211 @@ +package discover + +import ( + "fmt" + logpkg "log" + "net" + "os" + "testing" + "time" + + "github.com/ethereum/go-ethereum/logger" +) + +func init() { + logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, logpkg.LstdFlags, logger.ErrorLevel)) +} + +func TestUDP_ping(t *testing.T) { + t.Parallel() + + n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil) + n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil) + defer n1.Close() + defer n2.Close() + + if err := n1.net.ping(n2.self); err != nil { + t.Fatalf("ping error: %v", err) + } + if find(n2, n1.self.ID) == nil { + t.Errorf("node 2 does not contain id of node 1") + } + if e := find(n1, n2.self.ID); e != nil { + t.Errorf("node 1 does contains id of node 2: %v", e) + } +} + +func find(tab *Table, id NodeID) *Node { + for _, b := range tab.buckets { + for _, e := range b.entries { + if e.ID == id { + return e + } + } + } + return nil +} + +func TestUDP_findnode(t *testing.T) { + t.Parallel() + + n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil) + n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil) + defer n1.Close() + defer n2.Close() + + // put a few nodes into n2. the exact distribution shouldn't + // matter much, altough we need to take care not to overflow + // any bucket. + target := randomID(n1.self.ID, 100) + nodes := &nodesByDistance{target: target} + for i := 0; i < bucketSize; i++ { + n2.add([]*Node{&Node{ + IP: net.IP{1, 2, 3, byte(i)}, + DiscPort: i + 2, + TCPPort: i + 2, + ID: randomID(n2.self.ID, i+2), + }}) + } + n2.add(nodes.entries) + n2.bumpOrAdd(n1.self.ID, &net.UDPAddr{IP: n1.self.IP, Port: n1.self.DiscPort}) + expected := n2.closest(target, bucketSize) + + err := runUDP(10, func() error { + result, _ := n1.net.findnode(n2.self, target) + if len(result) != bucketSize { + return fmt.Errorf("wrong number of results: got %d, want %d", len(result), bucketSize) + } + for i := range result { + if result[i].ID != expected.entries[i].ID { + return fmt.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, result[i], expected.entries[i]) + } + } + return nil + }) + if err != nil { + t.Error(err) + } +} + +func TestUDP_replytimeout(t *testing.T) { + t.Parallel() + + // reserve a port so we don't talk to an existing service by accident + addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") + fd, err := net.ListenUDP("udp", addr) + if err != nil { + t.Fatal(err) + } + defer fd.Close() + + n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil) + defer n1.Close() + n2 := n1.bumpOrAdd(randomID(n1.self.ID, 10), fd.LocalAddr().(*net.UDPAddr)) + + if err := n1.net.ping(n2); err != errTimeout { + t.Error("expected timeout error, got", err) + } + + if result, err := n1.net.findnode(n2, n1.self.ID); err != errTimeout { + t.Error("expected timeout error, got", err) + } else if len(result) > 0 { + t.Error("expected empty result, got", result) + } +} + +func TestUDP_findnodeMultiReply(t *testing.T) { + t.Parallel() + + n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil) + n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil) + udp2 := n2.net.(*udp) + defer n1.Close() + defer n2.Close() + + err := runUDP(10, func() error { + nodes := make([]*Node, bucketSize) + for i := range nodes { + nodes[i] = &Node{ + IP: net.IP{1, 2, 3, 4}, + DiscPort: i + 1, + TCPPort: i + 1, + ID: randomID(n2.self.ID, i+1), + } + } + + // ask N2 for neighbors. it will send an empty reply back. + // the request will wait for up to bucketSize replies. + resultc := make(chan []*Node) + errc := make(chan error) + go func() { + ns, err := n1.net.findnode(n2.self, n1.self.ID) + if err != nil { + errc <- err + } else { + resultc <- ns + } + }() + + // send a few more neighbors packets to N1. + // it should collect those. + for end := 0; end < len(nodes); { + off := end + if end = end + 5; end > len(nodes) { + end = len(nodes) + } + udp2.send(n1.self, neighborsPacket, neighbors{ + Nodes: nodes[off:end], + Expiration: uint64(time.Now().Add(10 * time.Second).Unix()), + }) + } + + // check that they are all returned. we cannot just check for + // equality because they might not be returned in the order they + // were sent. + var result []*Node + select { + case result = <-resultc: + case err := <-errc: + return err + } + if hasDuplicates(result) { + return fmt.Errorf("result slice contains duplicates") + } + if len(result) != len(nodes) { + return fmt.Errorf("wrong number of nodes returned: got %d, want %d", len(result), len(nodes)) + } + matched := make(map[NodeID]bool) + for _, n := range result { + for _, expn := range nodes { + if n.ID == expn.ID { // && bytes.Equal(n.Addr.IP, expn.Addr.IP) && n.Addr.Port == expn.Addr.Port { + matched[n.ID] = true + } + } + } + if len(matched) != len(nodes) { + return fmt.Errorf("wrong number of matching nodes: got %d, want %d", len(matched), len(nodes)) + } + return nil + }) + if err != nil { + t.Error(err) + } +} + +// runUDP runs a test n times and returns an error if the test failed +// in all n runs. This is necessary because UDP is unreliable even for +// connections on the local machine, causing test failures. +func runUDP(n int, test func() error) error { + errcount := 0 + errors := "" + for i := 0; i < n; i++ { + if err := test(); err != nil { + errors += fmt.Sprintf("\n#%d: %v", i, err) + errcount++ + } + } + if errcount == n { + return fmt.Errorf("failed on all %d iterations:%s", n, errors) + } + return nil +} diff --git a/p2p/message.go b/p2p/message.go index daf2bf05c..07916f7b3 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -1,6 +1,7 @@ package p2p import ( + "bufio" "bytes" "encoding/binary" "errors" @@ -8,12 +9,37 @@ import ( "io" "io/ioutil" "math/big" + "net" + "sync" "sync/atomic" + "time" "github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/rlp" ) +// parameters for frameRW +const ( + // maximum time allowed for reading a message header. + // this is effectively the amount of time a connection can be idle. + frameReadTimeout = 1 * time.Minute + + // maximum time allowed for reading the payload data of a message. + // this is shorter than (and distinct from) frameReadTimeout because + // the connection is not considered idle while a message is transferred. + // this also limits the payload size of messages to how much the connection + // can transfer within the timeout. + payloadReadTimeout = 5 * time.Second + + // maximum amount of time allowed for writing a complete message. + msgWriteTimeout = 5 * time.Second + + // messages smaller than this many bytes will be read at + // once before passing them to a protocol. this increases + // concurrency in the processing. + wholePayloadSize = 64 * 1024 +) + // Msg defines the structure of a p2p message. // // Note that a Msg can only be sent once since the Payload reader is @@ -74,11 +100,14 @@ type MsgWriter interface { // WriteMsg sends a message. It will block until the message's // Payload has been consumed by the other end. // - // Note that messages can be sent only once. + // Note that messages can be sent only once because their + // payload reader is drained. WriteMsg(Msg) error } // MsgReadWriter provides reading and writing of encoded messages. +// Implementations should ensure that ReadMsg and WriteMsg can be +// called simultaneously from multiple goroutines. type MsgReadWriter interface { MsgReader MsgWriter @@ -90,8 +119,45 @@ func EncodeMsg(w MsgWriter, code uint64, data ...interface{}) error { return w.WriteMsg(NewMsg(code, data...)) } +// frameRW is a MsgReadWriter that reads and writes devp2p message frames. +// As required by the interface, ReadMsg and WriteMsg can be called from +// multiple goroutines. +type frameRW struct { + net.Conn // make Conn methods available. be careful. + bufconn *bufio.ReadWriter + + // this channel is used to 'lend' bufconn to a caller of ReadMsg + // until the message payload has been consumed. the channel + // receives a value when EOF is reached on the payload, unblocking + // a pending call to ReadMsg. + rsync chan struct{} + + // this mutex guards writes to bufconn. + writeMu sync.Mutex +} + +func newFrameRW(conn net.Conn, timeout time.Duration) *frameRW { + rsync := make(chan struct{}, 1) + rsync <- struct{}{} + return &frameRW{ + Conn: conn, + bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), + rsync: rsync, + } +} + var magicToken = []byte{34, 64, 8, 145} +func (rw *frameRW) WriteMsg(msg Msg) error { + rw.writeMu.Lock() + defer rw.writeMu.Unlock() + rw.SetWriteDeadline(time.Now().Add(msgWriteTimeout)) + if err := writeMsg(rw.bufconn, msg); err != nil { + return err + } + return rw.bufconn.Flush() +} + func writeMsg(w io.Writer, msg Msg) error { // TODO: handle case when Size + len(code) + len(listhdr) overflows uint32 code := ethutil.Encode(uint32(msg.Code)) @@ -120,31 +186,51 @@ func makeListHeader(length uint32) []byte { return append([]byte{lenb}, enc...) } -// readMsg reads a message header from r. -// It takes an rlp.ByteReader to ensure that the decoding doesn't buffer. -func readMsg(r rlp.ByteReader) (msg Msg, err error) { +func (rw *frameRW) ReadMsg() (msg Msg, err error) { + <-rw.rsync // wait until bufconn is ours + + rw.SetReadDeadline(time.Now().Add(frameReadTimeout)) + // read magic and payload size start := make([]byte, 8) - if _, err = io.ReadFull(r, start); err != nil { - return msg, newPeerError(errRead, "%v", err) + if _, err = io.ReadFull(rw.bufconn, start); err != nil { + return msg, err } if !bytes.HasPrefix(start, magicToken) { - return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken) + return msg, fmt.Errorf("bad magic token %x", start[:4], magicToken) } size := binary.BigEndian.Uint32(start[4:]) // decode start of RLP message to get the message code - posr := &postrack{r, 0} + posr := &postrack{rw.bufconn, 0} s := rlp.NewStream(posr) if _, err := s.List(); err != nil { return msg, err } - code, err := s.Uint() + msg.Code, err = s.Uint() if err != nil { return msg, err } - payloadsize := size - posr.p - return Msg{code, payloadsize, io.LimitReader(r, int64(payloadsize))}, nil + msg.Size = size - posr.p + + rw.SetReadDeadline(time.Now().Add(payloadReadTimeout)) + + if msg.Size <= wholePayloadSize { + // msg is small, read all of it and move on to the next message. + pbuf := make([]byte, msg.Size) + if _, err := io.ReadFull(rw.bufconn, pbuf); err != nil { + return msg, err + } + rw.rsync <- struct{}{} // bufconn is available again + msg.Payload = bytes.NewReader(pbuf) + } else { + // lend bufconn to the caller until it has + // consumed the payload. eofSignal will send a value + // on rw.rsync when EOF is reached. + pr := &eofSignal{rw.bufconn, msg.Size, rw.rsync} + msg.Payload = pr + } + return msg, nil } // postrack wraps an rlp.ByteReader with a position counter. @@ -167,6 +253,39 @@ func (r *postrack) ReadByte() (byte, error) { return b, err } +// eofSignal wraps a reader with eof signaling. the eof channel is +// closed when the wrapped reader returns an error or when count bytes +// have been read. +type eofSignal struct { + wrapped io.Reader + count uint32 // number of bytes left + eof chan<- struct{} +} + +// note: when using eofSignal to detect whether a message payload +// has been read, Read might not be called for zero sized messages. +func (r *eofSignal) Read(buf []byte) (int, error) { + if r.count == 0 { + if r.eof != nil { + r.eof <- struct{}{} + r.eof = nil + } + return 0, io.EOF + } + + max := len(buf) + if int(r.count) < len(buf) { + max = int(r.count) + } + n, err := r.wrapped.Read(buf[:max]) + r.count -= uint32(n) + if (err != nil || r.count == 0) && r.eof != nil { + r.eof <- struct{}{} // tell Peer that msg has been consumed + r.eof = nil + } + return n, err +} + // MsgPipe creates a message pipe. Reads on one end are matched // with writes on the other. The pipe is full-duplex, both ends // implement MsgReadWriter. @@ -198,7 +317,7 @@ type MsgPipeRW struct { func (p *MsgPipeRW) WriteMsg(msg Msg) error { if atomic.LoadInt32(p.closed) == 0 { consumed := make(chan struct{}, 1) - msg.Payload = &eofSignal{msg.Payload, int64(msg.Size), consumed} + msg.Payload = &eofSignal{msg.Payload, msg.Size, consumed} select { case p.w <- msg: if msg.Size > 0 { diff --git a/p2p/message_test.go b/p2p/message_test.go index 5cde9abf5..4b94ebb5f 100644 --- a/p2p/message_test.go +++ b/p2p/message_test.go @@ -3,12 +3,11 @@ package p2p import ( "bytes" "fmt" + "io" "io/ioutil" "runtime" "testing" "time" - - "github.com/ethereum/go-ethereum/ethutil" ) func TestNewMsg(t *testing.T) { @@ -26,51 +25,51 @@ func TestNewMsg(t *testing.T) { } } -func TestEncodeDecodeMsg(t *testing.T) { - msg := NewMsg(3, 1, "000") - buf := new(bytes.Buffer) - if err := writeMsg(buf, msg); err != nil { - t.Fatalf("encodeMsg error: %v", err) - } - // t.Logf("encoded: %x", buf.Bytes()) +// func TestEncodeDecodeMsg(t *testing.T) { +// msg := NewMsg(3, 1, "000") +// buf := new(bytes.Buffer) +// if err := writeMsg(buf, msg); err != nil { +// t.Fatalf("encodeMsg error: %v", err) +// } +// // t.Logf("encoded: %x", buf.Bytes()) - decmsg, err := readMsg(buf) - if err != nil { - t.Fatalf("readMsg error: %v", err) - } - if decmsg.Code != 3 { - t.Errorf("incorrect code %d, want %d", decmsg.Code, 3) - } - if decmsg.Size != 5 { - t.Errorf("incorrect size %d, want %d", decmsg.Size, 5) - } +// decmsg, err := readMsg(buf) +// if err != nil { +// t.Fatalf("readMsg error: %v", err) +// } +// if decmsg.Code != 3 { +// t.Errorf("incorrect code %d, want %d", decmsg.Code, 3) +// } +// if decmsg.Size != 5 { +// t.Errorf("incorrect size %d, want %d", decmsg.Size, 5) +// } - var data struct { - I uint - S string - } - if err := decmsg.Decode(&data); err != nil { - t.Fatalf("Decode error: %v", err) - } - if data.I != 1 { - t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1) - } - if data.S != "000" { - t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000") - } -} +// var data struct { +// I uint +// S string +// } +// if err := decmsg.Decode(&data); err != nil { +// t.Fatalf("Decode error: %v", err) +// } +// if data.I != 1 { +// t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1) +// } +// if data.S != "000" { +// t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000") +// } +// } -func TestDecodeRealMsg(t *testing.T) { - data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb") - msg, err := readMsg(bytes.NewReader(data)) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } +// func TestDecodeRealMsg(t *testing.T) { +// data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb") +// msg, err := readMsg(bytes.NewReader(data)) +// if err != nil { +// t.Fatalf("unexpected error: %v", err) +// } - if msg.Code != 0 { - t.Errorf("incorrect code %d, want %d", msg.Code, 0) - } -} +// if msg.Code != 0 { +// t.Errorf("incorrect code %d, want %d", msg.Code, 0) +// } +// } func ExampleMsgPipe() { rw1, rw2 := MsgPipe() @@ -131,3 +130,58 @@ func TestMsgPipeConcurrentClose(t *testing.T) { go rw1.Close() } } + +func TestEOFSignal(t *testing.T) { + rb := make([]byte, 10) + + // empty reader + eof := make(chan struct{}, 1) + sig := &eofSignal{new(bytes.Buffer), 0, eof} + if n, err := sig.Read(rb); n != 0 || err != io.EOF { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + default: + t.Error("EOF chan not signaled") + } + + // count before error + eof = make(chan struct{}, 1) + sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof} + if n, err := sig.Read(rb); n != 4 || err != nil { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + default: + t.Error("EOF chan not signaled") + } + + // error before count + eof = make(chan struct{}, 1) + sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof} + if n, err := sig.Read(rb); n != 4 || err != nil { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + if n, err := sig.Read(rb); n != 0 || err != io.EOF { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + default: + t.Error("EOF chan not signaled") + } + + // no signal if neither occurs + eof = make(chan struct{}, 1) + sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof} + if n, err := sig.Read(rb); n != 10 || err != nil { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + t.Error("unexpected EOF signal") + default: + } +} diff --git a/p2p/nat.go b/p2p/nat.go deleted file mode 100644 index 9b771c3e8..000000000 --- a/p2p/nat.go +++ /dev/null @@ -1,23 +0,0 @@ -package p2p - -import ( - "fmt" - "net" -) - -func ParseNAT(natType string, gateway string) (nat NAT, err error) { - switch natType { - case "UPNP": - nat = UPNP() - case "PMP": - ip := net.ParseIP(gateway) - if ip == nil { - return nil, fmt.Errorf("cannot resolve PMP gateway IP %s", gateway) - } - nat = PMP(ip) - case "": - default: - return nil, fmt.Errorf("unrecognised NAT type '%s'", natType) - } - return -} diff --git a/p2p/nat/nat.go b/p2p/nat/nat.go new file mode 100644 index 000000000..12d355ba1 --- /dev/null +++ b/p2p/nat/nat.go @@ -0,0 +1,235 @@ +// Package nat provides access to common port mapping protocols. +package nat + +import ( + "errors" + "fmt" + "net" + "strings" + "sync" + "time" + + "github.com/ethereum/go-ethereum/logger" + "github.com/jackpal/go-nat-pmp" +) + +var log = logger.NewLogger("P2P NAT") + +// An implementation of nat.Interface can map local ports to ports +// accessible from the Internet. +type Interface interface { + // These methods manage a mapping between a port on the local + // machine to a port that can be connected to from the internet. + // + // protocol is "UDP" or "TCP". Some implementations allow setting + // a display name for the mapping. The mapping may be removed by + // the gateway when its lifetime ends. + AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error + DeleteMapping(protocol string, extport, intport int) error + + // This method should return the external (Internet-facing) + // address of the gateway device. + ExternalIP() (net.IP, error) + + // Should return name of the method. This is used for logging. + String() string +} + +// Parse parses a NAT interface description. +// The following formats are currently accepted. +// Note that mechanism names are not case-sensitive. +// +// "" or "none" return nil +// "extip:77.12.33.4" will assume the local machine is reachable on the given IP +// "any" uses the first auto-detected mechanism +// "upnp" uses the Universal Plug and Play protocol +// "pmp" uses NAT-PMP with an auto-detected gateway address +// "pmp:192.168.0.1" uses NAT-PMP with the given gateway address +func Parse(spec string) (Interface, error) { + var ( + parts = strings.SplitN(spec, ":", 2) + mech = strings.ToLower(parts[0]) + ip net.IP + ) + if len(parts) > 1 { + ip = net.ParseIP(parts[1]) + if ip == nil { + return nil, errors.New("invalid IP address") + } + } + switch mech { + case "", "none", "off": + return nil, nil + case "any", "auto", "on": + return Any(), nil + case "extip", "ip": + if ip == nil { + return nil, errors.New("missing IP address") + } + return ExtIP(ip), nil + case "upnp": + return UPnP(), nil + case "pmp", "natpmp", "nat-pmp": + return PMP(ip), nil + default: + return nil, fmt.Errorf("unknown mechanism %q", parts[0]) + } +} + +const ( + mapTimeout = 20 * time.Minute + mapUpdateInterval = 15 * time.Minute +) + +// Map adds a port mapping on m and keeps it alive until c is closed. +// This function is typically invoked in its own goroutine. +func Map(m Interface, c chan struct{}, protocol string, extport, intport int, name string) { + refresh := time.NewTimer(mapUpdateInterval) + defer func() { + refresh.Stop() + log.Debugf("Deleting port mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m) + m.DeleteMapping(protocol, extport, intport) + }() + log.Debugf("add mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m) + if err := m.AddMapping(protocol, intport, extport, name, mapTimeout); err != nil { + log.Errorf("mapping error: %v\n", err) + } + for { + select { + case _, ok := <-c: + if !ok { + return + } + case <-refresh.C: + log.DebugDetailf("refresh mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m) + if err := m.AddMapping(protocol, intport, extport, name, mapTimeout); err != nil { + log.Errorf("mapping error: %v\n", err) + } + refresh.Reset(mapUpdateInterval) + } + } +} + +// ExtIP assumes that the local machine is reachable on the given +// external IP address, and that any required ports were mapped manually. +// Mapping operations will not return an error but won't actually do anything. +func ExtIP(ip net.IP) Interface { + if ip == nil { + panic("IP must not be nil") + } + return extIP(ip) +} + +type extIP net.IP + +func (n extIP) ExternalIP() (net.IP, error) { return net.IP(n), nil } +func (n extIP) String() string { return fmt.Sprintf("ExtIP(%v)", net.IP(n)) } + +// These do nothing. +func (extIP) AddMapping(string, int, int, string, time.Duration) error { return nil } +func (extIP) DeleteMapping(string, int, int) error { return nil } + +// Any returns a port mapper that tries to discover any supported +// mechanism on the local network. +func Any() Interface { + // TODO: attempt to discover whether the local machine has an + // Internet-class address. Return ExtIP in this case. + return startautodisc("UPnP or NAT-PMP", func() Interface { + found := make(chan Interface, 2) + go func() { found <- discoverUPnP() }() + go func() { found <- discoverPMP() }() + for i := 0; i < cap(found); i++ { + if c := <-found; c != nil { + return c + } + } + return nil + }) +} + +// UPnP returns a port mapper that uses UPnP. It will attempt to +// discover the address of your router using UDP broadcasts. +func UPnP() Interface { + return startautodisc("UPnP", discoverUPnP) +} + +// PMP returns a port mapper that uses NAT-PMP. The provided gateway +// address should be the IP of your router. If the given gateway +// address is nil, PMP will attempt to auto-discover the router. +func PMP(gateway net.IP) Interface { + if gateway != nil { + return &pmp{gw: gateway, c: natpmp.NewClient(gateway)} + } + return startautodisc("NAT-PMP", discoverPMP) +} + +// autodisc represents a port mapping mechanism that is still being +// auto-discovered. Calls to the Interface methods on this type will +// wait until the discovery is done and then call the method on the +// discovered mechanism. +// +// This type is useful because discovery can take a while but we +// want return an Interface value from UPnP, PMP and Auto immediately. +type autodisc struct { + what string + done <-chan Interface + + mu sync.Mutex + found Interface +} + +func startautodisc(what string, doit func() Interface) Interface { + // TODO: monitor network configuration and rerun doit when it changes. + done := make(chan Interface) + ad := &autodisc{what: what, done: done} + go func() { done <- doit(); close(done) }() + return ad +} + +func (n *autodisc) AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error { + if err := n.wait(); err != nil { + return err + } + return n.found.AddMapping(protocol, extport, intport, name, lifetime) +} + +func (n *autodisc) DeleteMapping(protocol string, extport, intport int) error { + if err := n.wait(); err != nil { + return err + } + return n.found.DeleteMapping(protocol, extport, intport) +} + +func (n *autodisc) ExternalIP() (net.IP, error) { + if err := n.wait(); err != nil { + return nil, err + } + return n.found.ExternalIP() +} + +func (n *autodisc) String() string { + n.mu.Lock() + defer n.mu.Unlock() + if n.found == nil { + return n.what + } else { + return n.found.String() + } +} + +func (n *autodisc) wait() error { + n.mu.Lock() + found := n.found + n.mu.Unlock() + if found != nil { + // already discovered + return nil + } + if found = <-n.done; found == nil { + return errors.New("no devices discovered") + } + n.mu.Lock() + n.found = found + n.mu.Unlock() + return nil +} diff --git a/p2p/nat/natpmp.go b/p2p/nat/natpmp.go new file mode 100644 index 000000000..f249c6073 --- /dev/null +++ b/p2p/nat/natpmp.go @@ -0,0 +1,115 @@ +package nat + +import ( + "fmt" + "net" + "strings" + "time" + + "github.com/jackpal/go-nat-pmp" +) + +// natPMPClient adapts the NAT-PMP protocol implementation so it conforms to +// the common interface. +type pmp struct { + gw net.IP + c *natpmp.Client +} + +func (n *pmp) String() string { + return fmt.Sprintf("NAT-PMP(%v)", n.gw) +} + +func (n *pmp) ExternalIP() (net.IP, error) { + response, err := n.c.GetExternalAddress() + if err != nil { + return nil, err + } + return response.ExternalIPAddress[:], nil +} + +func (n *pmp) AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error { + if lifetime <= 0 { + return fmt.Errorf("lifetime must not be <= 0") + } + // Note order of port arguments is switched between our + // AddMapping and the client's AddPortMapping. + _, err := n.c.AddPortMapping(strings.ToLower(protocol), intport, extport, int(lifetime/time.Second)) + return err +} + +func (n *pmp) DeleteMapping(protocol string, extport, intport int) (err error) { + // To destroy a mapping, send an add-port with an internalPort of + // the internal port to destroy, an external port of zero and a + // time of zero. + _, err = n.c.AddPortMapping(strings.ToLower(protocol), intport, 0, 0) + return err +} + +func discoverPMP() Interface { + // run external address lookups on all potential gateways + gws := potentialGateways() + found := make(chan *pmp, len(gws)) + for i := range gws { + gw := gws[i] + go func() { + c := natpmp.NewClient(gw) + if _, err := c.GetExternalAddress(); err != nil { + found <- nil + } else { + found <- &pmp{gw, c} + } + }() + } + // return the one that responds first. + // discovery needs to be quick, so we stop caring about + // any responses after a very short timeout. + timeout := time.NewTimer(1 * time.Second) + defer timeout.Stop() + for _ = range gws { + select { + case c := <-found: + if c != nil { + return c + } + case <-timeout.C: + return nil + } + } + return nil +} + +var ( + // LAN IP ranges + _, lan10, _ = net.ParseCIDR("10.0.0.0/8") + _, lan176, _ = net.ParseCIDR("172.16.0.0/12") + _, lan192, _ = net.ParseCIDR("192.168.0.0/16") +) + +// TODO: improve this. We currently assume that (on most networks) +// the router is X.X.X.1 in a local LAN range. +func potentialGateways() (gws []net.IP) { + ifaces, err := net.Interfaces() + if err != nil { + return nil + } + for _, iface := range ifaces { + ifaddrs, err := iface.Addrs() + if err != nil { + return gws + } + for _, addr := range ifaddrs { + switch x := addr.(type) { + case *net.IPNet: + if lan10.Contains(x.IP) || lan176.Contains(x.IP) || lan192.Contains(x.IP) { + ip := x.IP.Mask(x.Mask).To4() + if ip != nil { + ip[3] = ip[3] | 0x01 + gws = append(gws, ip) + } + } + } + } + } + return gws +} diff --git a/p2p/nat/natupnp.go b/p2p/nat/natupnp.go new file mode 100644 index 000000000..25cbc0d71 --- /dev/null +++ b/p2p/nat/natupnp.go @@ -0,0 +1,149 @@ +package nat + +import ( + "errors" + "fmt" + "net" + "strings" + "time" + + "github.com/fjl/goupnp" + "github.com/fjl/goupnp/dcps/internetgateway1" + "github.com/fjl/goupnp/dcps/internetgateway2" +) + +type upnp struct { + dev *goupnp.RootDevice + service string + client upnpClient +} + +type upnpClient interface { + GetExternalIPAddress() (string, error) + AddPortMapping(string, uint16, string, uint16, string, bool, string, uint32) error + DeletePortMapping(string, uint16, string) error + GetNATRSIPStatus() (sip bool, nat bool, err error) +} + +func (n *upnp) ExternalIP() (addr net.IP, err error) { + ipString, err := n.client.GetExternalIPAddress() + if err != nil { + return nil, err + } + ip := net.ParseIP(ipString) + if ip == nil { + return nil, errors.New("bad IP in response") + } + return ip, nil +} + +func (n *upnp) AddMapping(protocol string, extport, intport int, desc string, lifetime time.Duration) error { + ip, err := n.internalAddress() + if err != nil { + return nil + } + protocol = strings.ToUpper(protocol) + lifetimeS := uint32(lifetime / time.Second) + return n.client.AddPortMapping("", uint16(extport), protocol, uint16(intport), ip.String(), true, desc, lifetimeS) +} + +func (n *upnp) internalAddress() (net.IP, error) { + devaddr, err := net.ResolveUDPAddr("udp4", n.dev.URLBase.Host) + if err != nil { + return nil, err + } + ifaces, err := net.Interfaces() + if err != nil { + return nil, err + } + for _, iface := range ifaces { + addrs, err := iface.Addrs() + if err != nil { + return nil, err + } + for _, addr := range addrs { + switch x := addr.(type) { + case *net.IPNet: + if x.Contains(devaddr.IP) { + return x.IP, nil + } + } + } + } + return nil, fmt.Errorf("could not find local address in same net as %v", devaddr) +} + +func (n *upnp) DeleteMapping(protocol string, extport, intport int) error { + return n.client.DeletePortMapping("", uint16(extport), strings.ToUpper(protocol)) +} + +func (n *upnp) String() string { + return "UPNP " + n.service +} + +// discoverUPnP searches for Internet Gateway Devices +// and returns the first one it can find on the local network. +func discoverUPnP() Interface { + found := make(chan *upnp, 2) + // IGDv1 + go discover(found, internetgateway1.URN_WANConnectionDevice_1, func(dev *goupnp.RootDevice, sc goupnp.ServiceClient) *upnp { + switch sc.Service.ServiceType { + case internetgateway1.URN_WANIPConnection_1: + return &upnp{dev, "IGDv1-IP1", &internetgateway1.WANIPConnection1{sc}} + case internetgateway1.URN_WANPPPConnection_1: + return &upnp{dev, "IGDv1-PPP1", &internetgateway1.WANPPPConnection1{sc}} + } + return nil + }) + // IGDv2 + go discover(found, internetgateway2.URN_WANConnectionDevice_2, func(dev *goupnp.RootDevice, sc goupnp.ServiceClient) *upnp { + switch sc.Service.ServiceType { + case internetgateway2.URN_WANIPConnection_1: + return &upnp{dev, "IGDv2-IP1", &internetgateway2.WANIPConnection1{sc}} + case internetgateway2.URN_WANIPConnection_2: + return &upnp{dev, "IGDv2-IP2", &internetgateway2.WANIPConnection2{sc}} + case internetgateway2.URN_WANPPPConnection_1: + return &upnp{dev, "IGDv2-PPP1", &internetgateway2.WANPPPConnection1{sc}} + } + return nil + }) + for i := 0; i < cap(found); i++ { + if c := <-found; c != nil { + return c + } + } + return nil +} + +func discover(out chan<- *upnp, target string, matcher func(*goupnp.RootDevice, goupnp.ServiceClient) *upnp) { + devs, err := goupnp.DiscoverDevices(target) + if err != nil { + return + } + found := false + for i := 0; i < len(devs) && !found; i++ { + if devs[i].Root == nil { + continue + } + devs[i].Root.Device.VisitServices(func(service *goupnp.Service) { + if found { + return + } + // check for a matching IGD service + sc := goupnp.ServiceClient{service.NewSOAPClient(), devs[i].Root, service} + upnp := matcher(devs[i].Root, sc) + if upnp == nil { + return + } + // check whether port mapping is enabled + if _, nat, err := upnp.client.GetNATRSIPStatus(); err != nil || !nat { + return + } + out <- upnp + found = true + }) + } + if !found { + out <- nil + } +} diff --git a/p2p/natpmp.go b/p2p/natpmp.go deleted file mode 100644 index 6714678c4..000000000 --- a/p2p/natpmp.go +++ /dev/null @@ -1,55 +0,0 @@ -package p2p - -import ( - "fmt" - "net" - "time" - - natpmp "github.com/jackpal/go-nat-pmp" -) - -// Adapt the NAT-PMP protocol to the NAT interface - -// TODO: -// + Register for changes to the external address. -// + Re-register port mapping when router reboots. -// + A mechanism for keeping a port mapping registered. -// + Discover gateway address automatically. - -type natPMPClient struct { - client *natpmp.Client -} - -// PMP returns a NAT traverser that uses NAT-PMP. The provided gateway -// address should be the IP of your router. -func PMP(gateway net.IP) (nat NAT) { - return &natPMPClient{natpmp.NewClient(gateway)} -} - -func (*natPMPClient) String() string { - return "NAT-PMP" -} - -func (n *natPMPClient) GetExternalAddress() (net.IP, error) { - response, err := n.client.GetExternalAddress() - if err != nil { - return nil, err - } - return response.ExternalIPAddress[:], nil -} - -func (n *natPMPClient) AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error { - if lifetime <= 0 { - return fmt.Errorf("lifetime must not be <= 0") - } - // Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping. - _, err := n.client.AddPortMapping(protocol, intport, extport, int(lifetime/time.Second)) - return err -} - -func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { - // To destroy a mapping, send an add-port with - // an internalPort of the internal port to destroy, an external port of zero and a time of zero. - _, err = n.client.AddPortMapping(protocol, internalPort, 0, 0) - return -} diff --git a/p2p/natupnp.go b/p2p/natupnp.go deleted file mode 100644 index 2e0d8ce8d..000000000 --- a/p2p/natupnp.go +++ /dev/null @@ -1,341 +0,0 @@ -package p2p - -// Just enough UPnP to be able to forward ports -// - -import ( - "bytes" - "encoding/xml" - "errors" - "fmt" - "net" - "net/http" - "os" - "strconv" - "strings" - "time" -) - -const ( - upnpDiscoverAttempts = 3 - upnpDiscoverTimeout = 5 * time.Second -) - -// UPNP returns a NAT port mapper that uses UPnP. It will attempt to -// discover the address of your router using UDP broadcasts. -func UPNP() NAT { - return &upnpNAT{} -} - -type upnpNAT struct { - serviceURL string - ourIP string -} - -func (n *upnpNAT) String() string { - return "UPNP" -} - -func (n *upnpNAT) discover() error { - if n.serviceURL != "" { - // already discovered - return nil - } - - ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") - if err != nil { - return err - } - // TODO: try on all network interfaces simultaneously. - // Broadcasting on 0.0.0.0 could select a random interface - // to send on (platform specific). - conn, err := net.ListenPacket("udp4", ":0") - if err != nil { - return err - } - defer conn.Close() - - conn.SetDeadline(time.Now().Add(10 * time.Second)) - st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n" - buf := bytes.NewBufferString( - "M-SEARCH * HTTP/1.1\r\n" + - "HOST: 239.255.255.250:1900\r\n" + - st + - "MAN: \"ssdp:discover\"\r\n" + - "MX: 2\r\n\r\n") - message := buf.Bytes() - answerBytes := make([]byte, 1024) - for i := 0; i < upnpDiscoverAttempts; i++ { - _, err = conn.WriteTo(message, ssdp) - if err != nil { - return err - } - nn, _, err := conn.ReadFrom(answerBytes) - if err != nil { - continue - } - answer := string(answerBytes[0:nn]) - if strings.Index(answer, "\r\n"+st) < 0 { - continue - } - // HTTP header field names are case-insensitive. - // http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 - locString := "\r\nlocation: " - answer = strings.ToLower(answer) - locIndex := strings.Index(answer, locString) - if locIndex < 0 { - continue - } - loc := answer[locIndex+len(locString):] - endIndex := strings.Index(loc, "\r\n") - if endIndex < 0 { - continue - } - locURL := loc[0:endIndex] - var serviceURL string - serviceURL, err = getServiceURL(locURL) - if err != nil { - return err - } - var ourIP string - ourIP, err = getOurIP() - if err != nil { - return err - } - n.serviceURL = serviceURL - n.ourIP = ourIP - return nil - } - return errors.New("UPnP port discovery failed.") -} - -func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) { - if err := n.discover(); err != nil { - return nil, err - } - info, err := n.getStatusInfo() - return net.ParseIP(info.externalIpAddress), err -} - -func (n *upnpNAT) AddPortMapping(protocol string, extport, intport int, description string, lifetime time.Duration) error { - if err := n.discover(); err != nil { - return err - } - - // A single concatenation would break ARM compilation. - message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + - "<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(extport) - message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" - message += "<NewInternalPort>" + strconv.Itoa(extport) + "</NewInternalPort>" + - "<NewInternalClient>" + n.ourIP + "</NewInternalClient>" + - "<NewEnabled>1</NewEnabled><NewPortMappingDescription>" - message += description + - "</NewPortMappingDescription><NewLeaseDuration>" + fmt.Sprint(lifetime/time.Second) + - "</NewLeaseDuration></u:AddPortMapping>" - - // TODO: check response to see if the port was forwarded - _, err := soapRequest(n.serviceURL, "AddPortMapping", message) - return err -} - -func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) error { - if err := n.discover(); err != nil { - return err - } - - message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + - "<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) + - "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" + - "</u:DeletePortMapping>" - - // TODO: check response to see if the port was deleted - _, err := soapRequest(n.serviceURL, "DeletePortMapping", message) - return err -} - -type statusInfo struct { - externalIpAddress string -} - -func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) { - message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + - "</u:GetStatusInfo>" - - var response *http.Response - response, err = soapRequest(n.serviceURL, "GetStatusInfo", message) - if err != nil { - return - } - - // TODO: Write a soap reply parser. It has to eat the Body and envelope tags... - - response.Body.Close() - return -} - -// service represents the Service type in an UPnP xml description. -// Only the parts we care about are present and thus the xml may have more -// fields than present in the structure. -type service struct { - ServiceType string `xml:"serviceType"` - ControlURL string `xml:"controlURL"` -} - -// deviceList represents the deviceList type in an UPnP xml description. -// Only the parts we care about are present and thus the xml may have more -// fields than present in the structure. -type deviceList struct { - XMLName xml.Name `xml:"deviceList"` - Device []device `xml:"device"` -} - -// serviceList represents the serviceList type in an UPnP xml description. -// Only the parts we care about are present and thus the xml may have more -// fields than present in the structure. -type serviceList struct { - XMLName xml.Name `xml:"serviceList"` - Service []service `xml:"service"` -} - -// device represents the device type in an UPnP xml description. -// Only the parts we care about are present and thus the xml may have more -// fields than present in the structure. -type device struct { - XMLName xml.Name `xml:"device"` - DeviceType string `xml:"deviceType"` - DeviceList deviceList `xml:"deviceList"` - ServiceList serviceList `xml:"serviceList"` -} - -// specVersion represents the specVersion in a UPnP xml description. -// Only the parts we care about are present and thus the xml may have more -// fields than present in the structure. -type specVersion struct { - XMLName xml.Name `xml:"specVersion"` - Major int `xml:"major"` - Minor int `xml:"minor"` -} - -// root represents the Root document for a UPnP xml description. -// Only the parts we care about are present and thus the xml may have more -// fields than present in the structure. -type root struct { - XMLName xml.Name `xml:"root"` - SpecVersion specVersion - Device device -} - -func getChildDevice(d *device, deviceType string) *device { - dl := d.DeviceList.Device - for i := 0; i < len(dl); i++ { - if dl[i].DeviceType == deviceType { - return &dl[i] - } - } - return nil -} - -func getChildService(d *device, serviceType string) *service { - sl := d.ServiceList.Service - for i := 0; i < len(sl); i++ { - if sl[i].ServiceType == serviceType { - return &sl[i] - } - } - return nil -} - -func getOurIP() (ip string, err error) { - hostname, err := os.Hostname() - if err != nil { - return - } - p, err := net.LookupIP(hostname) - if err != nil && len(p) > 0 { - return - } - return p[0].String(), nil -} - -func getServiceURL(rootURL string) (url string, err error) { - r, err := http.Get(rootURL) - if err != nil { - return - } - defer r.Body.Close() - if r.StatusCode >= 400 { - err = errors.New(string(r.StatusCode)) - return - } - var root root - err = xml.NewDecoder(r.Body).Decode(&root) - - if err != nil { - return - } - a := &root.Device - if a.DeviceType != "urn:schemas-upnp-org:device:InternetGatewayDevice:1" { - err = errors.New("No InternetGatewayDevice") - return - } - b := getChildDevice(a, "urn:schemas-upnp-org:device:WANDevice:1") - if b == nil { - err = errors.New("No WANDevice") - return - } - c := getChildDevice(b, "urn:schemas-upnp-org:device:WANConnectionDevice:1") - if c == nil { - err = errors.New("No WANConnectionDevice") - return - } - d := getChildService(c, "urn:schemas-upnp-org:service:WANIPConnection:1") - if d == nil { - err = errors.New("No WANIPConnection") - return - } - url = combineURL(rootURL, d.ControlURL) - return -} - -func combineURL(rootURL, subURL string) string { - protocolEnd := "://" - protoEndIndex := strings.Index(rootURL, protocolEnd) - a := rootURL[protoEndIndex+len(protocolEnd):] - rootIndex := strings.Index(a, "/") - return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL -} - -func soapRequest(url, function, message string) (r *http.Response, err error) { - fullMessage := "<?xml version=\"1.0\" ?>" + - "<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\r\n" + - "<s:Body>" + message + "</s:Body></s:Envelope>" - - req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage)) - if err != nil { - return - } - req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"") - req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3") - //req.Header.Set("Transfer-Encoding", "chunked") - req.Header.Set("SOAPAction", "\"urn:schemas-upnp-org:service:WANIPConnection:1#"+function+"\"") - req.Header.Set("Connection", "Close") - req.Header.Set("Cache-Control", "no-cache") - req.Header.Set("Pragma", "no-cache") - - r, err = http.DefaultClient.Do(req) - if err != nil { - return - } - - if r.Body != nil { - defer r.Body.Close() - } - - if r.StatusCode >= 400 { - // log.Stderr(function, r.StatusCode) - err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function) - r = nil - return - } - return -} diff --git a/p2p/peer.go b/p2p/peer.go index 2380a3285..fd5bec7d5 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -1,8 +1,7 @@ package p2p import ( - "bufio" - "bytes" + "errors" "fmt" "io" "io/ioutil" @@ -11,159 +10,109 @@ import ( "sync" "time" - "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/rlp" ) -// peerAddr is the structure of a peer list element. -// It is also a valid net.Addr. -type peerAddr struct { - IP net.IP - Port uint64 - Pubkey []byte // optional -} - -func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr { - n := addr.Network() - if n != "tcp" && n != "tcp4" && n != "tcp6" { - // for testing with non-TCP - return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey} - } - ta := addr.(*net.TCPAddr) - return &peerAddr{ta.IP, uint64(ta.Port), pubkey} -} +const ( + baseProtocolVersion = 3 + baseProtocolLength = uint64(16) + baseProtocolMaxMsgSize = 10 * 1024 * 1024 -func (d peerAddr) Network() string { - if d.IP.To4() != nil { - return "tcp4" - } else { - return "tcp6" - } -} + disconnectGracePeriod = 2 * time.Second +) -func (d peerAddr) String() string { - return fmt.Sprintf("%v:%d", d.IP, d.Port) -} +const ( + // devp2p message codes + handshakeMsg = 0x00 + discMsg = 0x01 + pingMsg = 0x02 + pongMsg = 0x03 + getPeersMsg = 0x04 + peersMsg = 0x05 +) -func (d *peerAddr) RlpData() interface{} { - return []interface{}{string(d.IP), d.Port, d.Pubkey} +// handshake is the RLP structure of the protocol handshake. +type handshake struct { + Version uint64 + Name string + Caps []Cap + ListenPort uint64 + NodeID discover.NodeID } -// Peer represents a remote peer. +// Peer represents a connected remote node. type Peer struct { // Peers have all the log methods. // Use them to display messages related to the peer. *logger.Logger - infolock sync.Mutex - identity ClientIdentity - caps []Cap - listenAddr *peerAddr // what remote peer is listening on - dialAddr *peerAddr // non-nil if dialing + infoMu sync.Mutex + name string + caps []Cap + + ourID, remoteID *discover.NodeID + ourName string - // The mutex protects the connection - // so only one protocol can write at a time. - writeMu sync.Mutex - conn net.Conn - bufconn *bufio.ReadWriter + rw *frameRW // These fields maintain the running protocols. - protocols []Protocol - runBaseProtocol bool // for testing + protocols []Protocol + runlock sync.RWMutex // protects running + running map[string]*proto - runlock sync.RWMutex // protects running - running map[string]*proto + // disables protocol handshake, for testing + noHandshake bool protoWG sync.WaitGroup protoErr chan error closed chan struct{} disc chan DiscReason - - activity event.TypeMux // for activity events - - slot int // index into Server peer list - - // These fields are kept so base protocol can access them. - // TODO: this should be one or more interfaces - ourID ClientIdentity // client id of the Server - ourListenAddr *peerAddr // listen addr of Server, nil if not listening - newPeerAddr chan<- *peerAddr // tell server about received peers - otherPeers func() []*Peer // should return the list of all peers - pubkeyHook func(*peerAddr) error // called at end of handshake to validate pubkey } // NewPeer returns a peer for testing purposes. -func NewPeer(id ClientIdentity, caps []Cap) *Peer { +func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer { conn, _ := net.Pipe() - peer := newPeer(conn, nil, nil) - peer.setHandshakeInfo(id, nil, caps) - close(peer.closed) + peer := newPeer(conn, nil, "", nil, &id) + peer.setHandshakeInfo(name, caps) + close(peer.closed) // ensures Disconnect doesn't block return peer } -func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer { - p := newPeer(conn, server.Protocols, dialAddr) - p.ourID = server.Identity - p.newPeerAddr = server.peerConnect - p.otherPeers = server.Peers - p.pubkeyHook = server.verifyPeer - p.runBaseProtocol = true - - // laddr can be updated concurrently by NAT traversal. - // newServerPeer must be called with the server lock held. - if server.laddr != nil { - p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey()) - } - return p -} - -func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer { - p := &Peer{ - Logger: logger.NewLogger("P2P " + conn.RemoteAddr().String()), - conn: conn, - dialAddr: dialAddr, - bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), - protocols: protocols, - running: make(map[string]*proto), - disc: make(chan DiscReason), - protoErr: make(chan error), - closed: make(chan struct{}), - } - return p +// ID returns the node's public key. +func (p *Peer) ID() discover.NodeID { + return *p.remoteID } -// Identity returns the client identity of the remote peer. The -// identity can be nil if the peer has not yet completed the -// handshake. -func (p *Peer) Identity() ClientIdentity { - p.infolock.Lock() - defer p.infolock.Unlock() - return p.identity +// Name returns the node name that the remote node advertised. +func (p *Peer) Name() string { + // this needs a lock because the information is part of the + // protocol handshake. + p.infoMu.Lock() + name := p.name + p.infoMu.Unlock() + return name } // Caps returns the capabilities (supported subprotocols) of the remote peer. func (p *Peer) Caps() []Cap { - p.infolock.Lock() - defer p.infolock.Unlock() - return p.caps -} - -func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) { - p.infolock.Lock() - p.identity = id - p.listenAddr = laddr - p.caps = caps - p.infolock.Unlock() + // this needs a lock because the information is part of the + // protocol handshake. + p.infoMu.Lock() + caps := p.caps + p.infoMu.Unlock() + return caps } // RemoteAddr returns the remote address of the network connection. func (p *Peer) RemoteAddr() net.Addr { - return p.conn.RemoteAddr() + return p.rw.RemoteAddr() } // LocalAddr returns the local address of the network connection. func (p *Peer) LocalAddr() net.Addr { - return p.conn.LocalAddr() + return p.rw.LocalAddr() } // Disconnect terminates the peer connection with the given reason. @@ -177,149 +126,177 @@ func (p *Peer) Disconnect(reason DiscReason) { // String implements fmt.Stringer. func (p *Peer) String() string { - kind := "inbound" - p.infolock.Lock() - if p.dialAddr != nil { - kind = "outbound" - } - p.infolock.Unlock() - return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind) + return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr()) } -const ( - // maximum amount of time allowed for reading a message - msgReadTimeout = 5 * time.Second - // maximum amount of time allowed for writing a message - msgWriteTimeout = 5 * time.Second - // messages smaller than this many bytes will be read at - // once before passing them to a protocol. - wholePayloadSize = 64 * 1024 -) +func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer { + logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr()) + return &Peer{ + Logger: logger.NewLogger(logtag), + rw: newFrameRW(conn, msgWriteTimeout), + ourID: ourID, + ourName: ourName, + remoteID: remoteID, + protocols: protocols, + running: make(map[string]*proto), + disc: make(chan DiscReason), + protoErr: make(chan error), + closed: make(chan struct{}), + } +} -var ( - inactivityTimeout = 2 * time.Second - disconnectGracePeriod = 2 * time.Second -) +func (p *Peer) setHandshakeInfo(name string, caps []Cap) { + p.infoMu.Lock() + p.name = name + p.caps = caps + p.infoMu.Unlock() +} -func (p *Peer) loop() (reason DiscReason, err error) { - defer p.activity.Stop() +func (p *Peer) run() DiscReason { + var readErr = make(chan error, 1) defer p.closeProtocols() defer close(p.closed) - defer p.conn.Close() - - // read loop - readMsg := make(chan Msg) - readErr := make(chan error) - readNext := make(chan bool, 1) - protoDone := make(chan struct{}, 1) - go p.readLoop(readMsg, readErr, readNext) - readNext <- true - - if p.runBaseProtocol { - p.startBaseProtocol() - } -loop: - for { - select { - case msg := <-readMsg: - // a new message has arrived. - var wait bool - if wait, err = p.dispatch(msg, protoDone); err != nil { - p.Errorf("msg dispatch error: %v\n", err) - reason = discReasonForError(err) - break loop - } - if !wait { - // Msg has already been read completely, continue with next message. - readNext <- true - } - p.activity.Post(time.Now()) - case <-protoDone: - // protocol has consumed the message payload, - // we can continue reading from the socket. - readNext <- true - - case err := <-readErr: - // read failed. there is no need to run the - // polite disconnect sequence because the connection - // is probably dead anyway. - // TODO: handle write errors as well - return DiscNetworkError, err - case err = <-p.protoErr: - reason = discReasonForError(err) - break loop - case reason = <-p.disc: - break loop + go func() { readErr <- p.readLoop() }() + + if !p.noHandshake { + if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil { + p.DebugDetailf("Protocol handshake error: %v\n", err) + p.rw.Close() + return DiscProtocolError } } - // wait for read loop to return. - close(readNext) + // Wait for an error or disconnect. + var reason DiscReason + select { + case err := <-readErr: + // We rely on protocols to abort if there is a write error. It + // might be more robust to handle them here as well. + p.DebugDetailf("Read error: %v\n", err) + p.rw.Close() + return DiscNetworkError + + case err := <-p.protoErr: + reason = discReasonForError(err) + case reason = <-p.disc: + } + p.politeDisconnect(reason) + + // Wait for readLoop. It will end because conn is now closed. <-readErr - // tell the remote end to disconnect + p.Debugf("Disconnected: %v\n", reason) + return reason +} + +func (p *Peer) politeDisconnect(reason DiscReason) { done := make(chan struct{}) go func() { - p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod)) - p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod) - io.Copy(ioutil.Discard, p.conn) + EncodeMsg(p.rw, discMsg, uint(reason)) + // Wait for the other side to close the connection. + // Discard any data that they send until then. + io.Copy(ioutil.Discard, p.rw) close(done) }() select { case <-done: case <-time.After(disconnectGracePeriod): } - return reason, err + p.rw.Close() } -func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) { - for _ = range unblock { - p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout)) - if msg, err := readMsg(p.bufconn); err != nil { - errc <- err - } else { - msgc <- msg +func (p *Peer) readLoop() error { + if !p.noHandshake { + if err := readProtocolHandshake(p, p.rw); err != nil { + return err } } - close(errc) -} - -func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) { - proto, err := p.getProto(msg.Code) - if err != nil { - return false, err + for { + msg, err := p.rw.ReadMsg() + if err != nil { + return err + } + if err = p.handle(msg); err != nil { + return err + } } - if msg.Size <= wholePayloadSize { - // optimization: msg is small enough, read all - // of it and move on to the next message - buf, err := ioutil.ReadAll(msg.Payload) + return nil +} + +func (p *Peer) handle(msg Msg) error { + switch { + case msg.Code == pingMsg: + msg.Discard() + go EncodeMsg(p.rw, pongMsg) + case msg.Code == discMsg: + var reason DiscReason + // no need to discard or for error checking, we'll close the + // connection after this. + rlp.Decode(msg.Payload, &reason) + p.Disconnect(DiscRequested) + return discRequestedError(reason) + case msg.Code < baseProtocolLength: + // ignore other base protocol messages + return msg.Discard() + default: + // it's a subprotocol message + proto, err := p.getProto(msg.Code) if err != nil { - return false, err + return fmt.Errorf("msg code out of range: %v", msg.Code) } - msg.Payload = bytes.NewReader(buf) - proto.in <- msg - } else { - wait = true - pr := &eofSignal{msg.Payload, int64(msg.Size), protoDone} - msg.Payload = pr proto.in <- msg } - return wait, nil + return nil } -func (p *Peer) startBaseProtocol() { - p.runlock.Lock() - defer p.runlock.Unlock() - p.running[""] = p.startProto(0, Protocol{ - Length: baseProtocolLength, - Run: runBaseProtocol, - }) +func readProtocolHandshake(p *Peer, rw MsgReadWriter) error { + // read and handle remote handshake + msg, err := rw.ReadMsg() + if err != nil { + return err + } + if msg.Code == discMsg { + // disconnect before protocol handshake is valid according to the + // spec and we send it ourself if Server.addPeer fails. + var reason DiscReason + rlp.Decode(msg.Payload, &reason) + return discRequestedError(reason) + } + if msg.Code != handshakeMsg { + return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code) + } + if msg.Size > baseProtocolMaxMsgSize { + return newPeerError(errInvalidMsg, "message too big") + } + var hs handshake + if err := msg.Decode(&hs); err != nil { + return err + } + // validate handshake info + if hs.Version != baseProtocolVersion { + return newPeerError(errP2PVersionMismatch, "required version %d, received %d\n", + baseProtocolVersion, hs.Version) + } + if hs.NodeID == *p.remoteID { + return newPeerError(errPubkeyForbidden, "node ID mismatch") + } + // TODO: remove Caps with empty name + p.setHandshakeInfo(hs.Name, hs.Caps) + p.startSubprotocols(hs.Caps) + return nil +} + +func writeProtocolHandshake(w MsgWriter, name string, id discover.NodeID, ps []Protocol) error { + var caps []interface{} + for _, proto := range ps { + caps = append(caps, proto.cap()) + } + return EncodeMsg(w, handshakeMsg, baseProtocolVersion, name, caps, 0, id) } // startProtocols starts matching named subprotocols. func (p *Peer) startSubprotocols(caps []Cap) { sort.Sort(capsByName(caps)) - p.runlock.Lock() defer p.runlock.Unlock() offset := baseProtocolLength @@ -338,20 +315,22 @@ outer: } func (p *Peer) startProto(offset uint64, impl Protocol) *proto { + p.DebugDetailf("Starting protocol %s/%d\n", impl.Name, impl.Version) rw := &proto{ + name: impl.Name, in: make(chan Msg), offset: offset, maxcode: impl.Length, - peer: p, + w: p.rw, } p.protoWG.Add(1) go func() { err := impl.Run(p, rw) if err == nil { - p.Infof("protocol %q returned", impl.Name) - err = newPeerError(errMisc, "protocol returned") + p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version) + err = errors.New("protocol returned") } else { - p.Errorf("protocol %q error: %v\n", impl.Name, err) + p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err) } select { case p.protoErr <- err: @@ -385,6 +364,7 @@ func (p *Peer) closeProtocols() { } // writeProtoMsg sends the given message on behalf of the given named protocol. +// this exists because of Server.Broadcast. func (p *Peer) writeProtoMsg(protoName string, msg Msg) error { p.runlock.RLock() proto, ok := p.running[protoName] @@ -396,25 +376,14 @@ func (p *Peer) writeProtoMsg(protoName string, msg Msg) error { return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName) } msg.Code += proto.offset - return p.writeMsg(msg, msgWriteTimeout) -} - -// writeMsg writes a message to the connection. -func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error { - p.writeMu.Lock() - defer p.writeMu.Unlock() - p.conn.SetWriteDeadline(time.Now().Add(timeout)) - if err := writeMsg(p.bufconn, msg); err != nil { - return newPeerError(errWrite, "%v", err) - } - return p.bufconn.Flush() + return p.rw.WriteMsg(msg) } type proto struct { name string in chan Msg maxcode, offset uint64 - peer *Peer + w MsgWriter } func (rw *proto) WriteMsg(msg Msg) error { @@ -422,11 +391,7 @@ func (rw *proto) WriteMsg(msg Msg) error { return newPeerError(errInvalidMsgCode, "not handled") } msg.Code += rw.offset - return rw.peer.writeMsg(msg, msgWriteTimeout) -} - -func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error { - return rw.WriteMsg(NewMsg(code, data...)) + return rw.w.WriteMsg(msg) } func (rw *proto) ReadMsg() (Msg, error) { @@ -437,26 +402,3 @@ func (rw *proto) ReadMsg() (Msg, error) { msg.Code -= rw.offset return msg, nil } - -// eofSignal wraps a reader with eof signaling. the eof channel is -// closed when the wrapped reader returns an error or when count bytes -// have been read. -// -type eofSignal struct { - wrapped io.Reader - count int64 - eof chan<- struct{} -} - -// note: when using eofSignal to detect whether a message payload -// has been read, Read might not be called for zero sized messages. - -func (r *eofSignal) Read(buf []byte) (int, error) { - n, err := r.wrapped.Read(buf) - r.count -= int64(n) - if (err != nil || r.count <= 0) && r.eof != nil { - r.eof <- struct{}{} // tell Peer that msg has been consumed - r.eof = nil - } - return n, err -} diff --git a/p2p/peer_error.go b/p2p/peer_error.go index 0eb7ec838..0ff4f4b43 100644 --- a/p2p/peer_error.go +++ b/p2p/peer_error.go @@ -12,7 +12,6 @@ const ( errInvalidMsgCode errInvalidMsg errP2PVersionMismatch - errPubkeyMissing errPubkeyInvalid errPubkeyForbidden errProtocolBreach @@ -22,20 +21,19 @@ const ( ) var errorToString = map[int]string{ - errMagicTokenMismatch: "Magic token mismatch", - errRead: "Read error", - errWrite: "Write error", - errMisc: "Misc error", - errInvalidMsgCode: "Invalid message code", - errInvalidMsg: "Invalid message", + errMagicTokenMismatch: "magic token mismatch", + errRead: "read error", + errWrite: "write error", + errMisc: "misc error", + errInvalidMsgCode: "invalid message code", + errInvalidMsg: "invalid message", errP2PVersionMismatch: "P2P Version Mismatch", - errPubkeyMissing: "Public key missing", - errPubkeyInvalid: "Public key invalid", - errPubkeyForbidden: "Public key forbidden", - errProtocolBreach: "Protocol Breach", - errPingTimeout: "Ping timeout", - errInvalidNetworkId: "Invalid network id", - errInvalidProtocolVersion: "Invalid protocol version", + errPubkeyInvalid: "public key invalid", + errPubkeyForbidden: "public key forbidden", + errProtocolBreach: "protocol Breach", + errPingTimeout: "ping timeout", + errInvalidNetworkId: "invalid network id", + errInvalidProtocolVersion: "invalid protocol version", } type peerError struct { @@ -62,22 +60,22 @@ func (self *peerError) Error() string { type DiscReason byte const ( - DiscRequested DiscReason = 0x00 - DiscNetworkError = 0x01 - DiscProtocolError = 0x02 - DiscUselessPeer = 0x03 - DiscTooManyPeers = 0x04 - DiscAlreadyConnected = 0x05 - DiscIncompatibleVersion = 0x06 - DiscInvalidIdentity = 0x07 - DiscQuitting = 0x08 - DiscUnexpectedIdentity = 0x09 - DiscSelf = 0x0a - DiscReadTimeout = 0x0b - DiscSubprotocolError = 0x10 + DiscRequested DiscReason = iota + DiscNetworkError + DiscProtocolError + DiscUselessPeer + DiscTooManyPeers + DiscAlreadyConnected + DiscIncompatibleVersion + DiscInvalidIdentity + DiscQuitting + DiscUnexpectedIdentity + DiscSelf + DiscReadTimeout + DiscSubprotocolError ) -var discReasonToString = [DiscSubprotocolError + 1]string{ +var discReasonToString = [...]string{ DiscRequested: "Disconnect requested", DiscNetworkError: "Network error", DiscProtocolError: "Breach of protocol", @@ -117,7 +115,7 @@ func discReasonForError(err error) DiscReason { switch peerError.Code { case errP2PVersionMismatch: return DiscIncompatibleVersion - case errPubkeyMissing, errPubkeyInvalid: + case errPubkeyInvalid: return DiscInvalidIdentity case errPubkeyForbidden: return DiscUselessPeer @@ -125,7 +123,7 @@ func discReasonForError(err error) DiscReason { return DiscProtocolError case errPingTimeout: return DiscReadTimeout - case errRead, errWrite, errMisc: + case errRead, errWrite: return DiscNetworkError default: return DiscSubprotocolError diff --git a/p2p/peer_test.go b/p2p/peer_test.go index 4ee88f112..68c9910a2 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -1,15 +1,17 @@ package p2p import ( - "bufio" "bytes" - "encoding/hex" - "io" + "fmt" "io/ioutil" "net" "reflect" + "sort" "testing" "time" + + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/rlp" ) var discard = Protocol{ @@ -28,17 +30,13 @@ var discard = Protocol{ }, } -func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) { +func testPeer(noHandshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) { conn1, conn2 := net.Pipe() - peer := newPeer(conn1, protos, nil) - peer.ourID = &peerId{} - peer.pubkeyHook = func(*peerAddr) error { return nil } - errc := make(chan error, 1) - go func() { - _, err := peer.loop() - errc <- err - }() - return conn2, peer, errc + peer := newPeer(conn1, protos, "name", &discover.NodeID{}, &discover.NodeID{}) + peer.noHandshake = noHandshake + errc := make(chan DiscReason, 1) + go func() { errc <- peer.run() }() + return newFrameRW(conn2, msgWriteTimeout), peer, errc } func TestPeerProtoReadMsg(t *testing.T) { @@ -49,31 +47,28 @@ func TestPeerProtoReadMsg(t *testing.T) { Name: "a", Length: 5, Run: func(peer *Peer, rw MsgReadWriter) error { - msg, err := rw.ReadMsg() - if err != nil { - t.Errorf("read error: %v", err) + if err := expectMsg(rw, 2, []uint{1}); err != nil { + t.Error(err) } - if msg.Code != 2 { - t.Errorf("incorrect msg code %d relayed to protocol", msg.Code) - } - data, err := ioutil.ReadAll(msg.Payload) - if err != nil { - t.Errorf("payload read error: %v", err) + if err := expectMsg(rw, 3, []uint{2}); err != nil { + t.Error(err) } - expdata, _ := hex.DecodeString("0183303030") - if !bytes.Equal(expdata, data) { - t.Errorf("incorrect msg data %x", data) + if err := expectMsg(rw, 4, []uint{3}); err != nil { + t.Error(err) } close(done) return nil }, } - net, peer, errc := testPeer([]Protocol{proto}) - defer net.Close() + rw, peer, errc := testPeer(true, []Protocol{proto}) + defer rw.Close() peer.startSubprotocols([]Cap{proto.cap()}) - writeMsg(net, NewMsg(18, 1, "000")) + EncodeMsg(rw, baseProtocolLength+2, 1) + EncodeMsg(rw, baseProtocolLength+3, 2) + EncodeMsg(rw, baseProtocolLength+4, 3) + select { case <-done: case err := <-errc: @@ -105,11 +100,11 @@ func TestPeerProtoReadLargeMsg(t *testing.T) { }, } - net, peer, errc := testPeer([]Protocol{proto}) - defer net.Close() + rw, peer, errc := testPeer(true, []Protocol{proto}) + defer rw.Close() peer.startSubprotocols([]Cap{proto.cap()}) - writeMsg(net, NewMsg(18, make([]byte, msgsize))) + EncodeMsg(rw, 18, make([]byte, msgsize)) select { case <-done: case err := <-errc: @@ -135,32 +130,20 @@ func TestPeerProtoEncodeMsg(t *testing.T) { return nil }, } - net, peer, _ := testPeer([]Protocol{proto}) - defer net.Close() + rw, peer, _ := testPeer(true, []Protocol{proto}) + defer rw.Close() peer.startSubprotocols([]Cap{proto.cap()}) - bufr := bufio.NewReader(net) - msg, err := readMsg(bufr) - if err != nil { - t.Errorf("read error: %v", err) - } - if msg.Code != 17 { - t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17) - } - var data []string - if err := msg.Decode(&data); err != nil { - t.Errorf("payload decode error: %v", err) - } - if !reflect.DeepEqual(data, []string{"foo", "bar"}) { - t.Errorf("payload RLP mismatch, got %#v, want %#v", data, []string{"foo", "bar"}) + if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil { + t.Error(err) } } -func TestPeerWrite(t *testing.T) { +func TestPeerWriteForBroadcast(t *testing.T) { defer testlog(t).detach() - net, peer, peerErr := testPeer([]Protocol{discard}) - defer net.Close() + rw, peer, peerErr := testPeer(true, []Protocol{discard}) + defer rw.Close() peer.startSubprotocols([]Cap{discard.cap()}) // test write errors @@ -176,18 +159,13 @@ func TestPeerWrite(t *testing.T) { // setup for reading the message on the other end read := make(chan struct{}) go func() { - bufr := bufio.NewReader(net) - msg, err := readMsg(bufr) - if err != nil { - t.Errorf("read error: %v", err) - } else if msg.Code != 16 { - t.Errorf("wrong code, got %d, expected %d", msg.Code, 16) + if err := expectMsg(rw, 16, nil); err != nil { + t.Error() } - msg.Discard() close(read) }() - // test succcessful write + // test successful write if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil { t.Errorf("expect no error for known protocol: %v", err) } @@ -198,104 +176,153 @@ func TestPeerWrite(t *testing.T) { } } -func TestPeerActivity(t *testing.T) { - // shorten inactivityTimeout while this test is running - oldT := inactivityTimeout - defer func() { inactivityTimeout = oldT }() - inactivityTimeout = 20 * time.Millisecond +func TestPeerPing(t *testing.T) { + defer testlog(t).detach() - net, peer, peerErr := testPeer([]Protocol{discard}) - defer net.Close() - peer.startSubprotocols([]Cap{discard.cap()}) + rw, _, _ := testPeer(true, nil) + defer rw.Close() + if err := EncodeMsg(rw, pingMsg); err != nil { + t.Fatal(err) + } + if err := expectMsg(rw, pongMsg, nil); err != nil { + t.Error(err) + } +} - sub := peer.activity.Subscribe(time.Time{}) - defer sub.Unsubscribe() +func TestPeerDisconnect(t *testing.T) { + defer testlog(t).detach() - for i := 0; i < 6; i++ { - writeMsg(net, NewMsg(16)) - select { - case <-sub.Chan(): - case <-time.After(inactivityTimeout / 2): - t.Fatal("no event within ", inactivityTimeout/2) - case err := <-peerErr: - t.Fatal("peer error", err) - } + rw, _, disc := testPeer(true, nil) + defer rw.Close() + if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil { + t.Fatal(err) } - - select { - case <-time.After(inactivityTimeout * 2): - case <-sub.Chan(): - t.Fatal("got activity event while connection was inactive") - case err := <-peerErr: - t.Fatal("peer error", err) + if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil { + t.Error(err) + } + rw.Close() // make test end faster + if reason := <-disc; reason != DiscRequested { + t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested) } } -func TestNewPeer(t *testing.T) { - caps := []Cap{{"foo", 2}, {"bar", 3}} - id := &peerId{} - p := NewPeer(id, caps) - if !reflect.DeepEqual(p.Caps(), caps) { - t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps) +func TestPeerHandshake(t *testing.T) { + defer testlog(t).detach() + + // remote has two matching protocols: a and c + remote := NewPeer(randomID(), "", []Cap{{"a", 1}, {"b", 999}, {"c", 3}}) + remoteID := randomID() + remote.ourID = &remoteID + remote.ourName = "remote peer" + + start := make(chan string) + stop := make(chan struct{}) + run := func(p *Peer, rw MsgReadWriter) error { + name := rw.(*proto).name + if name != "a" && name != "c" { + t.Errorf("protocol %q should not be started", name) + } else { + start <- name + } + <-stop + return nil } - if p.Identity() != id { - t.Errorf("Identity mismatch: got %v, expected %v", p.Identity(), id) + protocols := []Protocol{ + {Name: "a", Version: 1, Length: 1, Run: run}, + {Name: "b", Version: 2, Length: 1, Run: run}, + {Name: "c", Version: 3, Length: 1, Run: run}, + {Name: "d", Version: 4, Length: 1, Run: run}, } - // Should not hang. - p.Disconnect(DiscAlreadyConnected) -} + rw, p, disc := testPeer(false, protocols) + p.remoteID = remote.ourID + defer rw.Close() -func TestEOFSignal(t *testing.T) { - rb := make([]byte, 10) + // run the handshake + remoteProtocols := []Protocol{protocols[0], protocols[2]} + if err := writeProtocolHandshake(rw, "remote peer", remoteID, remoteProtocols); err != nil { + t.Fatalf("handshake write error: %v", err) + } + if err := readProtocolHandshake(remote, rw); err != nil { + t.Fatalf("handshake read error: %v", err) + } - // empty reader - eof := make(chan struct{}, 1) - sig := &eofSignal{new(bytes.Buffer), 0, eof} - if n, err := sig.Read(rb); n != 0 || err != io.EOF { - t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + // check that all protocols have been started + var started []string + for i := 0; i < 2; i++ { + select { + case name := <-start: + started = append(started, name) + case <-time.After(100 * time.Millisecond): + } } - select { - case <-eof: - default: - t.Error("EOF chan not signaled") + sort.Strings(started) + if !reflect.DeepEqual(started, []string{"a", "c"}) { + t.Errorf("wrong protocols started: %v", started) } - // count before error - eof = make(chan struct{}, 1) - sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof} - if n, err := sig.Read(rb); n != 8 || err != nil { - t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + // check that metadata has been set + if p.ID() != remoteID { + t.Errorf("peer has wrong node ID: got %v, want %v", p.ID(), remoteID) } - select { - case <-eof: - default: - t.Error("EOF chan not signaled") + if p.Name() != remote.ourName { + t.Errorf("peer has wrong node name: got %q, want %q", p.Name(), remote.ourName) } - // error before count - eof = make(chan struct{}, 1) - sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof} - if n, err := sig.Read(rb); n != 4 || err != nil { - t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + close(stop) + expectMsg(rw, discMsg, nil) + t.Logf("disc reason: %v", <-disc) +} + +func TestNewPeer(t *testing.T) { + name := "nodename" + caps := []Cap{{"foo", 2}, {"bar", 3}} + id := randomID() + p := NewPeer(id, name, caps) + if p.ID() != id { + t.Errorf("ID mismatch: got %v, expected %v", p.ID(), id) } - if n, err := sig.Read(rb); n != 0 || err != io.EOF { - t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + if p.Name() != name { + t.Errorf("Name mismatch: got %v, expected %v", p.Name(), name) } - select { - case <-eof: - default: - t.Error("EOF chan not signaled") + if !reflect.DeepEqual(p.Caps(), caps) { + t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps) } - // no signal if neither occurs - eof = make(chan struct{}, 1) - sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof} - if n, err := sig.Read(rb); n != 10 || err != nil { - t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + p.Disconnect(DiscAlreadyConnected) // Should not hang +} + +// expectMsg reads a message from r and verifies that its +// code and encoded RLP content match the provided values. +// If content is nil, the payload is discarded and not verified. +func expectMsg(r MsgReader, code uint64, content interface{}) error { + msg, err := r.ReadMsg() + if err != nil { + return err } - select { - case <-eof: - t.Error("unexpected EOF signal") - default: + if msg.Code != code { + return fmt.Errorf("message code mismatch: got %d, expected %d", msg.Code, code) + } + if content == nil { + return msg.Discard() + } else { + contentEnc, err := rlp.EncodeToBytes(content) + if err != nil { + panic("content encode error: " + err.Error()) + } + // skip over list header in encoded value. this is temporary. + contentEncR := bytes.NewReader(contentEnc) + if k, _, err := rlp.NewStream(contentEncR).Kind(); k != rlp.List || err != nil { + panic("content must encode as RLP list") + } + contentEnc = contentEnc[len(contentEnc)-contentEncR.Len():] + + actualContent, err := ioutil.ReadAll(msg.Payload) + if err != nil { + return err + } + if !bytes.Equal(actualContent, contentEnc) { + return fmt.Errorf("message payload mismatch:\ngot: %x\nwant: %x", actualContent, contentEnc) + } } + return nil } diff --git a/p2p/protocol.go b/p2p/protocol.go index 1d121a885..fe359fc54 100644 --- a/p2p/protocol.go +++ b/p2p/protocol.go @@ -1,10 +1,5 @@ package p2p -import ( - "bytes" - "time" -) - // Protocol represents a P2P subprotocol implementation. type Protocol struct { // Name should contain the official protocol name, @@ -32,38 +27,6 @@ func (p Protocol) cap() Cap { return Cap{p.Name, p.Version} } -const ( - baseProtocolVersion = 2 - baseProtocolLength = uint64(16) - baseProtocolMaxMsgSize = 10 * 1024 * 1024 -) - -const ( - // devp2p message codes - handshakeMsg = 0x00 - discMsg = 0x01 - pingMsg = 0x02 - pongMsg = 0x03 - getPeersMsg = 0x04 - peersMsg = 0x05 -) - -// handshake is the structure of a handshake list. -type handshake struct { - Version uint64 - ID string - Caps []Cap - ListenPort uint64 - NodeID []byte -} - -func (h *handshake) String() string { - return h.ID -} -func (h *handshake) Pubkey() []byte { - return h.NodeID -} - // Cap is the structure of a peer capability. type Cap struct { Name string @@ -79,210 +42,3 @@ type capsByName []Cap func (cs capsByName) Len() int { return len(cs) } func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name } func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] } - -type baseProtocol struct { - rw MsgReadWriter - peer *Peer -} - -func runBaseProtocol(peer *Peer, rw MsgReadWriter) error { - bp := &baseProtocol{rw, peer} - errc := make(chan error, 1) - go func() { errc <- rw.WriteMsg(bp.handshakeMsg()) }() - if err := bp.readHandshake(); err != nil { - return err - } - // handle write error - if err := <-errc; err != nil { - return err - } - // run main loop - go func() { - for { - if err := bp.handle(rw); err != nil { - errc <- err - break - } - } - }() - return bp.loop(errc) -} - -var pingTimeout = 2 * time.Second - -func (bp *baseProtocol) loop(quit <-chan error) error { - ping := time.NewTimer(pingTimeout) - activity := bp.peer.activity.Subscribe(time.Time{}) - lastActive := time.Time{} - defer ping.Stop() - defer activity.Unsubscribe() - - getPeersTick := time.NewTicker(10 * time.Second) - defer getPeersTick.Stop() - err := EncodeMsg(bp.rw, getPeersMsg) - - for err == nil { - select { - case err = <-quit: - return err - case <-getPeersTick.C: - err = EncodeMsg(bp.rw, getPeersMsg) - case event := <-activity.Chan(): - ping.Reset(pingTimeout) - lastActive = event.(time.Time) - case t := <-ping.C: - if lastActive.Add(pingTimeout * 2).Before(t) { - err = newPeerError(errPingTimeout, "") - } else if lastActive.Add(pingTimeout).Before(t) { - err = EncodeMsg(bp.rw, pingMsg) - } - } - } - return err -} - -func (bp *baseProtocol) handle(rw MsgReadWriter) error { - msg, err := rw.ReadMsg() - if err != nil { - return err - } - if msg.Size > baseProtocolMaxMsgSize { - return newPeerError(errMisc, "message too big") - } - // make sure that the payload has been fully consumed - defer msg.Discard() - - switch msg.Code { - case handshakeMsg: - return newPeerError(errProtocolBreach, "extra handshake received") - - case discMsg: - var reason [1]DiscReason - if err := msg.Decode(&reason); err != nil { - return err - } - return discRequestedError(reason[0]) - - case pingMsg: - return EncodeMsg(bp.rw, pongMsg) - - case pongMsg: - - case getPeersMsg: - peers := bp.peerList() - // this is dangerous. the spec says that we should _delay_ - // sending the response if no new information is available. - // this means that would need to send a response later when - // new peers become available. - // - // TODO: add event mechanism to notify baseProtocol for new peers - if len(peers) > 0 { - return EncodeMsg(bp.rw, peersMsg, peers...) - } - - case peersMsg: - var peers []*peerAddr - if err := msg.Decode(&peers); err != nil { - return err - } - for _, addr := range peers { - bp.peer.Debugf("received peer suggestion: %v", addr) - bp.peer.newPeerAddr <- addr - } - - default: - return newPeerError(errInvalidMsgCode, "unknown message code %v", msg.Code) - } - return nil -} - -func (bp *baseProtocol) readHandshake() error { - // read and handle remote handshake - msg, err := bp.rw.ReadMsg() - if err != nil { - return err - } - if msg.Code != handshakeMsg { - return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code) - } - if msg.Size > baseProtocolMaxMsgSize { - return newPeerError(errMisc, "message too big") - } - var hs handshake - if err := msg.Decode(&hs); err != nil { - return err - } - // validate handshake info - if hs.Version != baseProtocolVersion { - return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n", - baseProtocolVersion, hs.Version) - } - if len(hs.NodeID) == 0 { - return newPeerError(errPubkeyMissing, "") - } - if len(hs.NodeID) != 64 { - return newPeerError(errPubkeyInvalid, "require 512 bit, got %v", len(hs.NodeID)*8) - } - if da := bp.peer.dialAddr; da != nil { - // verify that the peer we wanted to connect to - // actually holds the target public key. - if da.Pubkey != nil && !bytes.Equal(da.Pubkey, hs.NodeID) { - return newPeerError(errPubkeyForbidden, "dial address pubkey mismatch") - } - } - pa := newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID) - if err := bp.peer.pubkeyHook(pa); err != nil { - return newPeerError(errPubkeyForbidden, "%v", err) - } - // TODO: remove Caps with empty name - var addr *peerAddr - if hs.ListenPort != 0 { - addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID) - addr.Port = hs.ListenPort - } - bp.peer.setHandshakeInfo(&hs, addr, hs.Caps) - bp.peer.startSubprotocols(hs.Caps) - return nil -} - -func (bp *baseProtocol) handshakeMsg() Msg { - var ( - port uint64 - caps []interface{} - ) - if bp.peer.ourListenAddr != nil { - port = bp.peer.ourListenAddr.Port - } - for _, proto := range bp.peer.protocols { - caps = append(caps, proto.cap()) - } - return NewMsg(handshakeMsg, - baseProtocolVersion, - bp.peer.ourID.String(), - caps, - port, - bp.peer.ourID.Pubkey()[1:], - ) -} - -func (bp *baseProtocol) peerList() []interface{} { - peers := bp.peer.otherPeers() - ds := make([]interface{}, 0, len(peers)) - for _, p := range peers { - p.infolock.Lock() - addr := p.listenAddr - p.infolock.Unlock() - // filter out this peer and peers that are not listening or - // have not completed the handshake. - // TODO: track previously sent peers and exclude them as well. - if p == bp.peer || addr == nil { - continue - } - ds = append(ds, addr) - } - ourAddr := bp.peer.ourListenAddr - if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() { - ds = append(ds, ourAddr) - } - return ds -} diff --git a/p2p/protocol_test.go b/p2p/protocol_test.go deleted file mode 100644 index b1d10ac53..000000000 --- a/p2p/protocol_test.go +++ /dev/null @@ -1,158 +0,0 @@ -package p2p - -import ( - "fmt" - "net" - "reflect" - "sync" - "testing" - - "github.com/ethereum/go-ethereum/crypto" -) - -type peerId struct { - pubkey []byte -} - -func (self *peerId) String() string { - return fmt.Sprintf("test peer %x", self.Pubkey()[:4]) -} - -func (self *peerId) Pubkey() (pubkey []byte) { - pubkey = self.pubkey - if len(pubkey) == 0 { - pubkey = crypto.GenerateNewKeyPair().PublicKey - self.pubkey = pubkey - } - return -} - -func newTestPeer() (peer *Peer) { - peer = NewPeer(&peerId{}, []Cap{}) - peer.pubkeyHook = func(*peerAddr) error { return nil } - peer.ourID = &peerId{} - peer.listenAddr = &peerAddr{} - peer.otherPeers = func() []*Peer { return nil } - return -} - -func TestBaseProtocolPeers(t *testing.T) { - peerList := []*peerAddr{ - {IP: net.ParseIP("1.2.3.4"), Port: 2222, Pubkey: []byte{}}, - {IP: net.ParseIP("5.6.7.8"), Port: 3333, Pubkey: []byte{}}, - } - listenAddr := &peerAddr{IP: net.ParseIP("1.3.5.7"), Port: 1111, Pubkey: []byte{}} - rw1, rw2 := MsgPipe() - defer rw1.Close() - wg := new(sync.WaitGroup) - - // run matcher, close pipe when addresses have arrived - numPeers := len(peerList) + 1 - addrChan := make(chan *peerAddr) - wg.Add(1) - go func() { - i := 0 - for got := range addrChan { - var want *peerAddr - switch { - case i < len(peerList): - want = peerList[i] - case i == len(peerList): - want = listenAddr // listenAddr should be the last thing sent - } - t.Logf("got peer %d/%d: %v", i+1, numPeers, got) - if !reflect.DeepEqual(want, got) { - t.Errorf("mismatch: got %+v, want %+v", got, want) - } - i++ - if i == numPeers { - break - } - } - if i != numPeers { - t.Errorf("wrong number of peers received: got %d, want %d", i, numPeers) - } - rw1.Close() - wg.Done() - }() - - // run first peer (in background) - peer1 := newTestPeer() - peer1.ourListenAddr = listenAddr - peer1.otherPeers = func() []*Peer { - pl := make([]*Peer, len(peerList)) - for i, addr := range peerList { - pl[i] = &Peer{listenAddr: addr} - } - return pl - } - wg.Add(1) - go func() { - runBaseProtocol(peer1, rw1) - wg.Done() - }() - - // run second peer - peer2 := newTestPeer() - peer2.newPeerAddr = addrChan // feed peer suggestions into matcher - if err := runBaseProtocol(peer2, rw2); err != ErrPipeClosed { - t.Errorf("peer2 terminated with unexpected error: %v", err) - } - - // terminate matcher - close(addrChan) - wg.Wait() -} - -func TestBaseProtocolDisconnect(t *testing.T) { - peer := NewPeer(&peerId{}, nil) - peer.ourID = &peerId{} - peer.pubkeyHook = func(*peerAddr) error { return nil } - - rw1, rw2 := MsgPipe() - done := make(chan struct{}) - go func() { - if err := expectMsg(rw2, handshakeMsg); err != nil { - t.Error(err) - } - err := EncodeMsg(rw2, handshakeMsg, - baseProtocolVersion, - "", - []interface{}{}, - 0, - make([]byte, 64), - ) - if err != nil { - t.Error(err) - } - if err := expectMsg(rw2, getPeersMsg); err != nil { - t.Error(err) - } - if err := EncodeMsg(rw2, discMsg, DiscQuitting); err != nil { - t.Error(err) - } - - close(done) - }() - - if err := runBaseProtocol(peer, rw1); err == nil { - t.Errorf("base protocol returned without error") - } else if reason, ok := err.(discRequestedError); !ok || reason != DiscQuitting { - t.Errorf("base protocol returned wrong error: %v", err) - } - <-done -} - -func expectMsg(r MsgReader, code uint64) error { - msg, err := r.ReadMsg() - if err != nil { - return err - } - if err := msg.Discard(); err != nil { - return err - } - if msg.Code != code { - return fmt.Errorf("wrong message code: got %d, expected %d", msg.Code, code) - } - return nil -} diff --git a/p2p/server.go b/p2p/server.go index e0d9f18a5..e510be521 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -2,37 +2,56 @@ package p2p import ( "bytes" + "crypto/ecdsa" "errors" "fmt" + "io" "net" + "runtime" "sync" "time" "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/nat" ) const ( - outboundAddressPoolSize = 500 - defaultDialTimeout = 10 * time.Second - portMappingUpdateInterval = 15 * time.Minute - portMappingTimeout = 20 * time.Minute + handshakeTimeout = 5 * time.Second + defaultDialTimeout = 10 * time.Second + refreshPeersInterval = 30 * time.Second ) var srvlog = logger.NewLogger("P2P Server") +// MakeName creates a node name that follows the ethereum convention +// for such names. It adds the operation system name and Go runtime version +// the name. +func MakeName(name, version string) string { + return fmt.Sprintf("%s/v%s/%s/%s", name, version, runtime.GOOS, runtime.Version()) +} + // Server manages all peer connections. // // The fields of Server are used as configuration parameters. // You should set them before starting the Server. Fields may not be // modified while the server is running. type Server struct { - // This field must be set to a valid client identity. - Identity ClientIdentity + // This field must be set to a valid secp256k1 private key. + PrivateKey *ecdsa.PrivateKey // MaxPeers is the maximum number of peers that can be // connected. It must be greater than zero. MaxPeers int + // Name sets the node name of this server. + // Use MakeName to create a name that follows existing conventions. + Name string + + // Bootstrap nodes are used to establish connectivity + // with the rest of the network. + BootstrapNodes []*discover.Node + // Protocols should contain the protocols supported // by the server. Matching protocols are launched for // each peer. @@ -53,7 +72,7 @@ type Server struct { // If set to a non-nil value, the given NAT port mapper // is used to make the listening port available to the // Internet. - NAT NAT + NAT nat.Interface // If Dialer is set to a non-nil value, the given Dialer // is used to dial outbound peer connections. @@ -62,35 +81,26 @@ type Server struct { // If NoDial is true, the server will not dial any peers. NoDial bool - // Hook for testing. This is useful because we can inhibit + // Hooks for testing. These are useful because we can inhibit // the whole protocol stack. - newPeerFunc peerFunc + handshakeFunc + newPeerHook - lock sync.RWMutex - running bool - listener net.Listener - laddr *net.TCPAddr // real listen addr - peers []*Peer - peerSlots chan int - peerCount int - - quit chan struct{} - wg sync.WaitGroup - peerConnect chan *peerAddr - peerDisconnect chan *Peer -} + lock sync.RWMutex + running bool + listener net.Listener + peers map[discover.NodeID]*Peer -// NAT is implemented by NAT traversal methods. -type NAT interface { - GetExternalAddress() (net.IP, error) - AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error - DeletePortMapping(protocol string, extport, intport int) error + ntab *discover.Table - // Should return name of the method. - String() string + quit chan struct{} + loopWG sync.WaitGroup // {dial,listen,nat}Loop + peerWG sync.WaitGroup // active peer goroutines + peerConnect chan *discover.Node } -type peerFunc func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer +type handshakeFunc func(io.ReadWriter, *ecdsa.PrivateKey, *discover.Node) (discover.NodeID, []byte, error) +type newPeerHook func(*Peer) // Peers returns all connected peers. func (srv *Server) Peers() (peers []*Peer) { @@ -107,18 +117,15 @@ func (srv *Server) Peers() (peers []*Peer) { // PeerCount returns the number of connected peers. func (srv *Server) PeerCount() int { srv.lock.RLock() - defer srv.lock.RUnlock() - return srv.peerCount + n := len(srv.peers) + srv.lock.RUnlock() + return n } -// SuggestPeer injects an address into the outbound address pool. -func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) { - addr := &peerAddr{ip, uint64(port), nodeID} - select { - case srv.peerConnect <- addr: - default: // don't block - srvlog.Warnf("peer suggestion %v ignored", addr) - } +// SuggestPeer creates a connection to the given Node if it +// is not already connected. +func (srv *Server) SuggestPeer(n *discover.Node) { + srv.peerConnect <- n } // Broadcast sends an RLP-encoded message to all connected peers. @@ -152,47 +159,46 @@ func (srv *Server) Start() (err error) { } srvlog.Infoln("Starting Server") - // initialize fields - if srv.Identity == nil { - return fmt.Errorf("Server.Identity must be set to a non-nil identity") + // initialize all the fields + if srv.PrivateKey == nil { + return fmt.Errorf("Server.PrivateKey must be set to a non-nil key") } if srv.MaxPeers <= 0 { return fmt.Errorf("Server.MaxPeers must be > 0") } srv.quit = make(chan struct{}) - srv.peers = make([]*Peer, srv.MaxPeers) - srv.peerSlots = make(chan int, srv.MaxPeers) - srv.peerConnect = make(chan *peerAddr, outboundAddressPoolSize) - srv.peerDisconnect = make(chan *Peer) - if srv.newPeerFunc == nil { - srv.newPeerFunc = newServerPeer + srv.peers = make(map[discover.NodeID]*Peer) + srv.peerConnect = make(chan *discover.Node) + + if srv.handshakeFunc == nil { + srv.handshakeFunc = encHandshake } if srv.Blacklist == nil { srv.Blacklist = NewBlacklist() } - if srv.Dialer == nil { - srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout} - } - if srv.ListenAddr != "" { if err := srv.startListening(); err != nil { return err } } + + // dial stuff + dt, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT) + if err != nil { + return err + } + srv.ntab = dt + if srv.Dialer == nil { + srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout} + } if !srv.NoDial { - srv.wg.Add(1) + srv.loopWG.Add(1) go srv.dialLoop() } if srv.NoDial && srv.ListenAddr == "" { srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.") } - // make all slots available - for i := range srv.peers { - srv.peerSlots <- i - } - // note: discLoop is not part of WaitGroup - go srv.discLoop() srv.running = true return nil } @@ -202,14 +208,17 @@ func (srv *Server) startListening() error { if err != nil { return err } - srv.ListenAddr = listener.Addr().String() - srv.laddr = listener.Addr().(*net.TCPAddr) + laddr := listener.Addr().(*net.TCPAddr) + srv.ListenAddr = laddr.String() srv.listener = listener - srv.wg.Add(1) + srv.loopWG.Add(1) go srv.listenLoop() - if !srv.laddr.IP.IsLoopback() && srv.NAT != nil { - srv.wg.Add(1) - go srv.natLoop(srv.laddr.Port) + if !laddr.IP.IsLoopback() && srv.NAT != nil { + srv.loopWG.Add(1) + go func() { + nat.Map(srv.NAT, srv.quit, "tcp", laddr.Port, laddr.Port, "ethereum p2p") + srv.loopWG.Done() + }() } return nil } @@ -225,200 +234,171 @@ func (srv *Server) Stop() { srv.running = false srv.lock.Unlock() - srvlog.Infoln("Stopping server") + srvlog.Infoln("Stopping Server") + srv.ntab.Close() if srv.listener != nil { // this unblocks listener Accept srv.listener.Close() } close(srv.quit) - for _, peer := range srv.Peers() { - peer.Disconnect(DiscQuitting) - } - srv.wg.Wait() - - // wait till they actually disconnect - // this is checked by claiming all peerSlots. - // slots become available as the peers disconnect. - for i := 0; i < cap(srv.peerSlots); i++ { - <-srv.peerSlots - } - // terminate discLoop - close(srv.peerDisconnect) -} + srv.loopWG.Wait() -func (srv *Server) discLoop() { - for peer := range srv.peerDisconnect { - srv.removePeer(peer) + // No new peers can be added at this point because dialLoop and + // listenLoop are down. It is safe to call peerWG.Wait because + // peerWG.Add is not called outside of those loops. + for _, peer := range srv.peers { + peer.Disconnect(DiscQuitting) } + srv.peerWG.Wait() } // main loop for adding connections via listening func (srv *Server) listenLoop() { - defer srv.wg.Done() - + defer srv.loopWG.Done() srvlog.Infoln("Listening on", srv.listener.Addr()) for { - select { - case slot := <-srv.peerSlots: - srvlog.Debugf("grabbed slot %v for listening", slot) - conn, err := srv.listener.Accept() - if err != nil { - srv.peerSlots <- slot - return - } - srvlog.Debugf("Accepted conn %v (slot %d)\n", conn.RemoteAddr(), slot) - srv.addPeer(conn, nil, slot) - case <-srv.quit: + conn, err := srv.listener.Accept() + if err != nil { return } + srvlog.Debugf("Accepted conn %v\n", conn.RemoteAddr()) + srv.peerWG.Add(1) + go srv.startPeer(conn, nil) } } -func (srv *Server) natLoop(port int) { - defer srv.wg.Done() +func (srv *Server) dialLoop() { + defer srv.loopWG.Done() + refresh := time.NewTicker(refreshPeersInterval) + defer refresh.Stop() + + srv.ntab.Bootstrap(srv.BootstrapNodes) + go srv.findPeers() + + dialed := make(chan *discover.Node) + dialing := make(map[discover.NodeID]bool) + + // TODO: limit number of active dials + // TODO: ensure only one findPeers goroutine is running + // TODO: pause findPeers when we're at capacity + for { - srv.updatePortMapping(port) select { - case <-time.After(portMappingUpdateInterval): - // one more round + case <-refresh.C: + + go srv.findPeers() + + case dest := <-srv.peerConnect: + // avoid dialing nodes that are already connected. + // there is another check for this in addPeer, + // which runs after the handshake. + srv.lock.Lock() + _, isconnected := srv.peers[dest.ID] + srv.lock.Unlock() + if isconnected || dialing[dest.ID] || dest.ID == srv.ntab.Self() { + continue + } + + dialing[dest.ID] = true + srv.peerWG.Add(1) + go func() { + srv.dialNode(dest) + // at this point, the peer has been added + // or discarded. either way, we're not dialing it anymore. + dialed <- dest + }() + + case dest := <-dialed: + delete(dialing, dest.ID) + case <-srv.quit: - srv.removePortMapping(port) + // TODO: maybe wait for active dials return } } } -func (srv *Server) updatePortMapping(port int) { - srvlog.Infoln("Attempting to map port", port, "with", srv.NAT) - err := srv.NAT.AddPortMapping("tcp", port, port, "ethereum p2p", portMappingTimeout) +func (srv *Server) dialNode(dest *discover.Node) { + addr := &net.TCPAddr{IP: dest.IP, Port: dest.TCPPort} + srvlog.Debugf("Dialing %v\n", dest) + conn, err := srv.Dialer.Dial("tcp", addr.String()) if err != nil { - srvlog.Errorln("Port mapping error:", err) - return - } - extip, err := srv.NAT.GetExternalAddress() - if err != nil { - srvlog.Errorln("Error getting external IP:", err) + srvlog.DebugDetailf("dial error: %v", err) return } - srv.lock.Lock() - extaddr := *(srv.listener.Addr().(*net.TCPAddr)) - extaddr.IP = extip - srvlog.Infoln("Mapped port, external addr is", &extaddr) - srv.laddr = &extaddr - srv.lock.Unlock() + srv.startPeer(conn, dest) } -func (srv *Server) removePortMapping(port int) { - srvlog.Infoln("Removing port mapping for", port, "with", srv.NAT) - srv.NAT.DeletePortMapping("tcp", port, port) -} - -func (srv *Server) dialLoop() { - defer srv.wg.Done() - var ( - suggest chan *peerAddr - slot *int - slots = srv.peerSlots - ) - for { - select { - case i := <-slots: - // we need a peer in slot i, slot reserved - slot = &i - // now we can watch for candidate peers in the next loop - suggest = srv.peerConnect - // do not consume more until candidate peer is found - slots = nil - - case desc := <-suggest: - // candidate peer found, will dial out asyncronously - // if connection fails slot will be released - srvlog.DebugDetailf("dial %v (%v)", desc, *slot) - go srv.dialPeer(desc, *slot) - // we can watch if more peers needed in the next loop - slots = srv.peerSlots - // until then we dont care about candidate peers - suggest = nil +func (srv *Server) findPeers() { + far := srv.ntab.Self() + for i := range far { + far[i] = ^far[i] + } + closeToSelf := srv.ntab.Lookup(srv.ntab.Self()) + farFromSelf := srv.ntab.Lookup(far) - case <-srv.quit: - // give back the currently reserved slot - if slot != nil { - srv.peerSlots <- *slot - } - return + for i := 0; i < len(closeToSelf) || i < len(farFromSelf); i++ { + if i < len(closeToSelf) { + srv.peerConnect <- closeToSelf[i] + } + if i < len(farFromSelf) { + srv.peerConnect <- farFromSelf[i] } } } -// connect to peer via dial out -func (srv *Server) dialPeer(desc *peerAddr, slot int) { - srvlog.Debugf("Dialing %v (slot %d)\n", desc, slot) - conn, err := srv.Dialer.Dial(desc.Network(), desc.String()) +func (srv *Server) startPeer(conn net.Conn, dest *discover.Node) { + // TODO: handle/store session token + conn.SetDeadline(time.Now().Add(handshakeTimeout)) + remoteID, _, err := srv.handshakeFunc(conn, srv.PrivateKey, dest) if err != nil { - srvlog.DebugDetailf("dial error: %v", err) - srv.peerSlots <- slot + conn.Close() + srvlog.Debugf("Encryption Handshake with %v failed: %v", conn.RemoteAddr(), err) return } - go srv.addPeer(conn, desc, slot) -} + ourID := srv.ntab.Self() + p := newPeer(conn, srv.Protocols, srv.Name, &ourID, &remoteID) + if ok, reason := srv.addPeer(remoteID, p); !ok { + srvlog.DebugDetailf("Not adding %v (%v)\n", p, reason) + p.politeDisconnect(reason) + return + } + srvlog.Debugf("Added %v\n", p) -// creates the new peer object and inserts it into its slot -func (srv *Server) addPeer(conn net.Conn, desc *peerAddr, slot int) *Peer { - srv.lock.Lock() - defer srv.lock.Unlock() - if !srv.running { - conn.Close() - srv.peerSlots <- slot // release slot - return nil + if srv.newPeerHook != nil { + srv.newPeerHook(p) } - peer := srv.newPeerFunc(srv, conn, desc) - peer.slot = slot - srv.peers[slot] = peer - srv.peerCount++ - go func() { - peer.loop() - srv.peerDisconnect <- peer - }() - return peer + discreason := p.run() + srv.removePeer(p) + srvlog.Debugf("Removed %v (%v)\n", p, discreason) } -// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot -func (srv *Server) removePeer(peer *Peer) { +func (srv *Server) addPeer(id discover.NodeID, p *Peer) (bool, DiscReason) { srv.lock.Lock() defer srv.lock.Unlock() - srvlog.Debugf("Removing %v (slot %v)\n", peer, peer.slot) - if srv.peers[peer.slot] != peer { - srvlog.Warnln("Invalid peer to remove:", peer) - return - } - // remove from list and index - srv.peerCount-- - srv.peers[peer.slot] = nil - // release slot to signal need for a new peer, last! - srv.peerSlots <- peer.slot + switch { + case !srv.running: + return false, DiscQuitting + case len(srv.peers) >= srv.MaxPeers: + return false, DiscTooManyPeers + case srv.peers[id] != nil: + return false, DiscAlreadyConnected + case srv.Blacklist.Exists(id[:]): + return false, DiscUselessPeer + case id == srv.ntab.Self(): + return false, DiscSelf + } + srv.peers[id] = p + return true, 0 } -func (srv *Server) verifyPeer(addr *peerAddr) error { - if srv.Blacklist.Exists(addr.Pubkey) { - return errors.New("blacklisted") - } - if bytes.Equal(srv.Identity.Pubkey()[1:], addr.Pubkey) { - return newPeerError(errPubkeyForbidden, "not allowed to connect to srv") - } - srv.lock.RLock() - defer srv.lock.RUnlock() - for _, peer := range srv.peers { - if peer != nil { - id := peer.Identity() - if id != nil && bytes.Equal(id.Pubkey(), addr.Pubkey) { - return errors.New("already connected") - } - } - } - return nil +func (srv *Server) removePeer(p *Peer) { + srv.lock.Lock() + delete(srv.peers, *p.remoteID) + srv.lock.Unlock() + srv.peerWG.Done() } -// TODO replace with "Set" type Blacklist interface { Get([]byte) (bool, error) Put([]byte) error diff --git a/p2p/server_test.go b/p2p/server_test.go index ceb89e3f7..aa2b3d243 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -2,19 +2,28 @@ package p2p import ( "bytes" + "crypto/ecdsa" "io" + "math/rand" "net" "sync" "testing" "time" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/p2p/discover" ) -func startTestServer(t *testing.T, pf peerFunc) *Server { +func startTestServer(t *testing.T, pf newPeerHook) *Server { server := &Server{ - Identity: &peerId{}, + Name: "test", MaxPeers: 10, ListenAddr: "127.0.0.1:0", - newPeerFunc: pf, + PrivateKey: newkey(), + newPeerHook: pf, + handshakeFunc: func(io.ReadWriter, *ecdsa.PrivateKey, *discover.Node) (id discover.NodeID, st []byte, err error) { + return randomID(), nil, err + }, } if err := server.Start(); err != nil { t.Fatalf("Could not start server: %v", err) @@ -27,16 +36,11 @@ func TestServerListen(t *testing.T) { // start the test server connected := make(chan *Peer) - srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer { - if conn == nil { + srv := startTestServer(t, func(p *Peer) { + if p == nil { t.Error("peer func called with nil conn") } - if dialAddr != nil { - t.Error("peer func called with non-nil dialAddr") - } - peer := newPeer(conn, nil, dialAddr) - connected <- peer - return peer + connected <- p }) defer close(connected) defer srv.Stop() @@ -50,9 +54,9 @@ func TestServerListen(t *testing.T) { select { case peer := <-connected: - if peer.conn.LocalAddr().String() != conn.RemoteAddr().String() { + if peer.LocalAddr().String() != conn.RemoteAddr().String() { t.Errorf("peer started with wrong conn: got %v, want %v", - peer.conn.LocalAddr(), conn.RemoteAddr()) + peer.LocalAddr(), conn.RemoteAddr()) } case <-time.After(1 * time.Second): t.Error("server did not accept within one second") @@ -62,7 +66,7 @@ func TestServerListen(t *testing.T) { func TestServerDial(t *testing.T) { defer testlog(t).detach() - // run a fake TCP server to handle the connection. + // run a one-shot TCP server to handle the connection. listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("could not setup listener: %v") @@ -72,41 +76,32 @@ func TestServerDial(t *testing.T) { go func() { conn, err := listener.Accept() if err != nil { - t.Error("acccept error:", err) + t.Error("accept error:", err) + return } conn.Close() accepted <- conn }() - // start the test server + // start the server connected := make(chan *Peer) - srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer { - if conn == nil { - t.Error("peer func called with nil conn") - } - peer := newPeer(conn, nil, dialAddr) - connected <- peer - return peer - }) + srv := startTestServer(t, func(p *Peer) { connected <- p }) defer close(connected) defer srv.Stop() - // tell the server to connect. - connAddr := newPeerAddr(listener.Addr(), nil) - srv.peerConnect <- connAddr + // tell the server to connect + tcpAddr := listener.Addr().(*net.TCPAddr) + srv.SuggestPeer(&discover.Node{IP: tcpAddr.IP, TCPPort: tcpAddr.Port}) select { case conn := <-accepted: select { case peer := <-connected: - if peer.conn.RemoteAddr().String() != conn.LocalAddr().String() { + if peer.RemoteAddr().String() != conn.LocalAddr().String() { t.Errorf("peer started with wrong conn: got %v, want %v", - peer.conn.RemoteAddr(), conn.LocalAddr()) - } - if peer.dialAddr != connAddr { - t.Errorf("peer started with wrong dialAddr: got %v, want %v", - peer.dialAddr, connAddr) + peer.RemoteAddr(), conn.LocalAddr()) } + // TODO: validate more fields case <-time.After(1 * time.Second): t.Error("server did not launch peer within one second") } @@ -118,16 +113,17 @@ func TestServerDial(t *testing.T) { func TestServerBroadcast(t *testing.T) { defer testlog(t).detach() + var connected sync.WaitGroup - srv := startTestServer(t, func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer { - peer := newPeer(c, []Protocol{discard}, dialAddr) - peer.startSubprotocols([]Cap{discard.cap()}) + srv := startTestServer(t, func(p *Peer) { + p.protocols = []Protocol{discard} + p.startSubprotocols([]Cap{discard.cap()}) + p.noHandshake = true connected.Done() - return peer }) defer srv.Stop() - // dial a bunch of conns + // create a few peers var conns = make([]net.Conn, 8) connected.Add(len(conns)) deadline := time.Now().Add(3 * time.Second) @@ -159,3 +155,18 @@ func TestServerBroadcast(t *testing.T) { } } } + +func newkey() *ecdsa.PrivateKey { + key, err := crypto.GenerateKey() + if err != nil { + panic("couldn't generate key: " + err.Error()) + } + return key +} + +func randomID() (id discover.NodeID) { + for i := range id { + id[i] = byte(rand.Intn(255)) + } + return id +} diff --git a/p2p/testlog_test.go b/p2p/testlog_test.go index 951d43243..c524c154c 100644 --- a/p2p/testlog_test.go +++ b/p2p/testlog_test.go @@ -15,7 +15,7 @@ func testlog(t *testing.T) testLogger { return l } -func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel } +func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugDetailLevel } func (testLogger) SetLogLevel(logger.LogLevel) {} func (l testLogger) LogPrint(level logger.LogLevel, msg string) { diff --git a/p2p/testpoc7.go b/p2p/testpoc7.go deleted file mode 100644 index 63b996f79..000000000 --- a/p2p/testpoc7.go +++ /dev/null @@ -1,40 +0,0 @@ -// +build none - -package main - -import ( - "fmt" - "log" - "net" - "os" - - "github.com/ethereum/go-ethereum/crypto/secp256k1" - "github.com/ethereum/go-ethereum/logger" - "github.com/ethereum/go-ethereum/p2p" -) - -func main() { - logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel)) - - pub, _ := secp256k1.GenerateKeyPair() - srv := p2p.Server{ - MaxPeers: 10, - Identity: p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)), - ListenAddr: ":30303", - NAT: p2p.PMP(net.ParseIP("10.0.0.1")), - } - if err := srv.Start(); err != nil { - fmt.Println("could not start server:", err) - os.Exit(1) - } - - // add seed peers - seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303") - if err != nil { - fmt.Println("couldn't resolve:", err) - os.Exit(1) - } - srv.SuggestPeer(seed.IP, seed.Port, nil) - - select {} -} |