diff options
Diffstat (limited to 'p2p')
-rw-r--r-- | p2p/client_identity.go | 63 | ||||
-rw-r--r-- | p2p/client_identity_test.go | 30 | ||||
-rw-r--r-- | p2p/message.go | 232 | ||||
-rw-r--r-- | p2p/message_test.go | 133 | ||||
-rw-r--r-- | p2p/natpmp.go | 55 | ||||
-rw-r--r-- | p2p/natupnp.go | 341 | ||||
-rw-r--r-- | p2p/peer.go | 462 | ||||
-rw-r--r-- | p2p/peer_error.go | 133 | ||||
-rw-r--r-- | p2p/peer_test.go | 295 | ||||
-rw-r--r-- | p2p/protocol.go | 294 | ||||
-rw-r--r-- | p2p/protocol_test.go | 58 | ||||
-rw-r--r-- | p2p/server.go | 467 | ||||
-rw-r--r-- | p2p/server_test.go | 161 | ||||
-rw-r--r-- | p2p/testlog_test.go | 28 | ||||
-rw-r--r-- | p2p/testpoc7.go | 40 |
15 files changed, 2792 insertions, 0 deletions
diff --git a/p2p/client_identity.go b/p2p/client_identity.go new file mode 100644 index 000000000..bc865b63b --- /dev/null +++ b/p2p/client_identity.go @@ -0,0 +1,63 @@ +package p2p + +import ( + "fmt" + "runtime" +) + +// ClientIdentity represents the identity of a peer. +type ClientIdentity interface { + String() string // human readable identity + Pubkey() []byte // 512-bit public key +} + +type SimpleClientIdentity struct { + clientIdentifier string + version string + customIdentifier string + os string + implementation string + pubkey string +} + +func NewSimpleClientIdentity(clientIdentifier string, version string, customIdentifier string, pubkey string) *SimpleClientIdentity { + clientIdentity := &SimpleClientIdentity{ + clientIdentifier: clientIdentifier, + version: version, + customIdentifier: customIdentifier, + os: runtime.GOOS, + implementation: runtime.Version(), + pubkey: pubkey, + } + + return clientIdentity +} + +func (c *SimpleClientIdentity) init() { +} + +func (c *SimpleClientIdentity) String() string { + var id string + if len(c.customIdentifier) > 0 { + id = "/" + c.customIdentifier + } + + return fmt.Sprintf("%s/v%s%s/%s/%s", + c.clientIdentifier, + c.version, + id, + c.os, + c.implementation) +} + +func (c *SimpleClientIdentity) Pubkey() []byte { + return []byte(c.pubkey) +} + +func (c *SimpleClientIdentity) SetCustomIdentifier(customIdentifier string) { + c.customIdentifier = customIdentifier +} + +func (c *SimpleClientIdentity) GetCustomIdentifier() string { + return c.customIdentifier +} diff --git a/p2p/client_identity_test.go b/p2p/client_identity_test.go new file mode 100644 index 000000000..40b0e6f5e --- /dev/null +++ b/p2p/client_identity_test.go @@ -0,0 +1,30 @@ +package p2p + +import ( + "fmt" + "runtime" + "testing" +) + +func TestClientIdentity(t *testing.T) { + clientIdentity := NewSimpleClientIdentity("Ethereum(G)", "0.5.16", "test", "pubkey") + clientString := clientIdentity.String() + expected := fmt.Sprintf("Ethereum(G)/v0.5.16/test/%s/%s", runtime.GOOS, runtime.Version()) + if clientString != expected { + t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString) + } + customIdentifier := clientIdentity.GetCustomIdentifier() + if customIdentifier != "test" { + t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test', got %v", customIdentifier) + } + clientIdentity.SetCustomIdentifier("test2") + customIdentifier = clientIdentity.GetCustomIdentifier() + if customIdentifier != "test2" { + t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test2', got %v", customIdentifier) + } + clientString = clientIdentity.String() + expected = fmt.Sprintf("Ethereum(G)/v0.5.16/test2/%s/%s", runtime.GOOS, runtime.Version()) + if clientString != expected { + t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString) + } +} diff --git a/p2p/message.go b/p2p/message.go new file mode 100644 index 000000000..f5418ff47 --- /dev/null +++ b/p2p/message.go @@ -0,0 +1,232 @@ +package p2p + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "io/ioutil" + "math/big" + "sync/atomic" + + "github.com/ethereum/go-ethereum/ethutil" + "github.com/ethereum/go-ethereum/rlp" +) + +// Msg defines the structure of a p2p message. +// +// Note that a Msg can only be sent once since the Payload reader is +// consumed during sending. It is not possible to create a Msg and +// send it any number of times. If you want to reuse an encoded +// structure, encode the payload into a byte array and create a +// separate Msg with a bytes.Reader as Payload for each send. +type Msg struct { + Code uint64 + Size uint32 // size of the paylod + Payload io.Reader +} + +// NewMsg creates an RLP-encoded message with the given code. +func NewMsg(code uint64, params ...interface{}) Msg { + buf := new(bytes.Buffer) + for _, p := range params { + buf.Write(ethutil.Encode(p)) + } + return Msg{Code: code, Size: uint32(buf.Len()), Payload: buf} +} + +func encodePayload(params ...interface{}) []byte { + buf := new(bytes.Buffer) + for _, p := range params { + buf.Write(ethutil.Encode(p)) + } + return buf.Bytes() +} + +// Decode parse the RLP content of a message into +// the given value, which must be a pointer. +// +// For the decoding rules, please see package rlp. +func (msg Msg) Decode(val interface{}) error { + s := rlp.NewListStream(msg.Payload, uint64(msg.Size)) + return s.Decode(val) +} + +// Discard reads any remaining payload data into a black hole. +func (msg Msg) Discard() error { + _, err := io.Copy(ioutil.Discard, msg.Payload) + return err +} + +type MsgReader interface { + ReadMsg() (Msg, error) +} + +type MsgWriter interface { + // WriteMsg sends an existing message. + // The Payload reader of the message is consumed. + // Note that messages can be sent only once. + WriteMsg(Msg) error + + // EncodeMsg writes an RLP-encoded message with the given + // code and data elements. + EncodeMsg(code uint64, data ...interface{}) error +} + +// MsgReadWriter provides reading and writing of encoded messages. +type MsgReadWriter interface { + MsgReader + MsgWriter +} + +var magicToken = []byte{34, 64, 8, 145} + +func writeMsg(w io.Writer, msg Msg) error { + // TODO: handle case when Size + len(code) + len(listhdr) overflows uint32 + code := ethutil.Encode(uint32(msg.Code)) + listhdr := makeListHeader(msg.Size + uint32(len(code))) + payloadLen := uint32(len(listhdr)) + uint32(len(code)) + msg.Size + + start := make([]byte, 8) + copy(start, magicToken) + binary.BigEndian.PutUint32(start[4:], payloadLen) + + for _, b := range [][]byte{start, listhdr, code} { + if _, err := w.Write(b); err != nil { + return err + } + } + _, err := io.CopyN(w, msg.Payload, int64(msg.Size)) + return err +} + +func makeListHeader(length uint32) []byte { + if length < 56 { + return []byte{byte(length + 0xc0)} + } + enc := big.NewInt(int64(length)).Bytes() + lenb := byte(len(enc)) + 0xf7 + return append([]byte{lenb}, enc...) +} + +// readMsg reads a message header from r. +// It takes an rlp.ByteReader to ensure that the decoding doesn't buffer. +func readMsg(r rlp.ByteReader) (msg Msg, err error) { + // read magic and payload size + start := make([]byte, 8) + if _, err = io.ReadFull(r, start); err != nil { + return msg, newPeerError(errRead, "%v", err) + } + if !bytes.HasPrefix(start, magicToken) { + return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken) + } + size := binary.BigEndian.Uint32(start[4:]) + + // decode start of RLP message to get the message code + posr := &postrack{r, 0} + s := rlp.NewStream(posr) + if _, err := s.List(); err != nil { + return msg, err + } + code, err := s.Uint() + if err != nil { + return msg, err + } + payloadsize := size - posr.p + return Msg{code, payloadsize, io.LimitReader(r, int64(payloadsize))}, nil +} + +// postrack wraps an rlp.ByteReader with a position counter. +type postrack struct { + r rlp.ByteReader + p uint32 +} + +func (r *postrack) Read(buf []byte) (int, error) { + n, err := r.r.Read(buf) + r.p += uint32(n) + return n, err +} + +func (r *postrack) ReadByte() (byte, error) { + b, err := r.r.ReadByte() + if err == nil { + r.p++ + } + return b, err +} + +// MsgPipe creates a message pipe. Reads on one end are matched +// with writes on the other. The pipe is full-duplex, both ends +// implement MsgReadWriter. +func MsgPipe() (*MsgPipeRW, *MsgPipeRW) { + var ( + c1, c2 = make(chan Msg), make(chan Msg) + closing = make(chan struct{}) + closed = new(int32) + rw1 = &MsgPipeRW{c1, c2, closing, closed} + rw2 = &MsgPipeRW{c2, c1, closing, closed} + ) + return rw1, rw2 +} + +// ErrPipeClosed is returned from pipe operations after the +// pipe has been closed. +var ErrPipeClosed = errors.New("p2p: read or write on closed message pipe") + +// MsgPipeRW is an endpoint of a MsgReadWriter pipe. +type MsgPipeRW struct { + w chan<- Msg + r <-chan Msg + closing chan struct{} + closed *int32 +} + +// WriteMsg sends a messsage on the pipe. +// It blocks until the receiver has consumed the message payload. +func (p *MsgPipeRW) WriteMsg(msg Msg) error { + if atomic.LoadInt32(p.closed) == 0 { + consumed := make(chan struct{}, 1) + msg.Payload = &eofSignal{msg.Payload, int64(msg.Size), consumed} + select { + case p.w <- msg: + if msg.Size > 0 { + // wait for payload read or discard + <-consumed + } + return nil + case <-p.closing: + } + } + return ErrPipeClosed +} + +// EncodeMsg is a convenient shorthand for sending an RLP-encoded message. +func (p *MsgPipeRW) EncodeMsg(code uint64, data ...interface{}) error { + return p.WriteMsg(NewMsg(code, data...)) +} + +// ReadMsg returns a message sent on the other end of the pipe. +func (p *MsgPipeRW) ReadMsg() (Msg, error) { + if atomic.LoadInt32(p.closed) == 0 { + select { + case msg := <-p.r: + return msg, nil + case <-p.closing: + } + } + return Msg{}, ErrPipeClosed +} + +// Close unblocks any pending ReadMsg and WriteMsg calls on both ends +// of the pipe. They will return ErrPipeClosed. Note that Close does +// not interrupt any reads from a message payload. +func (p *MsgPipeRW) Close() error { + if atomic.AddInt32(p.closed, 1) != 1 { + // someone else is already closing + atomic.StoreInt32(p.closed, 1) // avoid overflow + return nil + } + close(p.closing) + return nil +} diff --git a/p2p/message_test.go b/p2p/message_test.go new file mode 100644 index 000000000..066d2516d --- /dev/null +++ b/p2p/message_test.go @@ -0,0 +1,133 @@ +package p2p + +import ( + "bytes" + "fmt" + "io/ioutil" + "runtime" + "testing" + "time" + + "github.com/ethereum/go-ethereum/ethutil" +) + +func TestNewMsg(t *testing.T) { + msg := NewMsg(3, 1, "000") + if msg.Code != 3 { + t.Errorf("incorrect code %d, want %d", msg.Code) + } + if msg.Size != 5 { + t.Errorf("incorrect size %d, want %d", msg.Size, 5) + } + pl, _ := ioutil.ReadAll(msg.Payload) + expect := []byte{0x01, 0x83, 0x30, 0x30, 0x30} + if !bytes.Equal(pl, expect) { + t.Errorf("incorrect payload content, got %x, want %x", pl, expect) + } +} + +func TestEncodeDecodeMsg(t *testing.T) { + msg := NewMsg(3, 1, "000") + buf := new(bytes.Buffer) + if err := writeMsg(buf, msg); err != nil { + t.Fatalf("encodeMsg error: %v", err) + } + // t.Logf("encoded: %x", buf.Bytes()) + + decmsg, err := readMsg(buf) + if err != nil { + t.Fatalf("readMsg error: %v", err) + } + if decmsg.Code != 3 { + t.Errorf("incorrect code %d, want %d", decmsg.Code, 3) + } + if decmsg.Size != 5 { + t.Errorf("incorrect size %d, want %d", decmsg.Size, 5) + } + + var data struct { + I uint + S string + } + if err := decmsg.Decode(&data); err != nil { + t.Fatalf("Decode error: %v", err) + } + if data.I != 1 { + t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1) + } + if data.S != "000" { + t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000") + } +} + +func TestDecodeRealMsg(t *testing.T) { + data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb") + msg, err := readMsg(bytes.NewReader(data)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if msg.Code != 0 { + t.Errorf("incorrect code %d, want %d", msg.Code, 0) + } +} + +func ExampleMsgPipe() { + rw1, rw2 := MsgPipe() + go func() { + rw1.EncodeMsg(8, []byte{0, 0}) + rw1.EncodeMsg(5, []byte{1, 1}) + rw1.Close() + }() + + for { + msg, err := rw2.ReadMsg() + if err != nil { + break + } + var data [1][]byte + msg.Decode(&data) + fmt.Printf("msg: %d, %x\n", msg.Code, data[0]) + } + // Output: + // msg: 8, 0000 + // msg: 5, 0101 +} + +func TestMsgPipeUnblockWrite(t *testing.T) { +loop: + for i := 0; i < 100; i++ { + rw1, rw2 := MsgPipe() + done := make(chan struct{}) + go func() { + if err := rw1.EncodeMsg(1); err == nil { + t.Error("EncodeMsg returned nil error") + } else if err != ErrPipeClosed { + t.Error("EncodeMsg returned wrong error: got %v, want %v", err, ErrPipeClosed) + } + close(done) + }() + + // this call should ensure that EncodeMsg is waiting to + // deliver sometimes. if this isn't done, Close is likely to + // be executed before EncodeMsg starts and then we won't test + // all the cases. + runtime.Gosched() + + rw2.Close() + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Errorf("write didn't unblock") + break loop + } + } +} + +// This test should panic if concurrent close isn't implemented correctly. +func TestMsgPipeConcurrentClose(t *testing.T) { + rw1, _ := MsgPipe() + for i := 0; i < 10; i++ { + go rw1.Close() + } +} diff --git a/p2p/natpmp.go b/p2p/natpmp.go new file mode 100644 index 000000000..6714678c4 --- /dev/null +++ b/p2p/natpmp.go @@ -0,0 +1,55 @@ +package p2p + +import ( + "fmt" + "net" + "time" + + natpmp "github.com/jackpal/go-nat-pmp" +) + +// Adapt the NAT-PMP protocol to the NAT interface + +// TODO: +// + Register for changes to the external address. +// + Re-register port mapping when router reboots. +// + A mechanism for keeping a port mapping registered. +// + Discover gateway address automatically. + +type natPMPClient struct { + client *natpmp.Client +} + +// PMP returns a NAT traverser that uses NAT-PMP. The provided gateway +// address should be the IP of your router. +func PMP(gateway net.IP) (nat NAT) { + return &natPMPClient{natpmp.NewClient(gateway)} +} + +func (*natPMPClient) String() string { + return "NAT-PMP" +} + +func (n *natPMPClient) GetExternalAddress() (net.IP, error) { + response, err := n.client.GetExternalAddress() + if err != nil { + return nil, err + } + return response.ExternalIPAddress[:], nil +} + +func (n *natPMPClient) AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error { + if lifetime <= 0 { + return fmt.Errorf("lifetime must not be <= 0") + } + // Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping. + _, err := n.client.AddPortMapping(protocol, intport, extport, int(lifetime/time.Second)) + return err +} + +func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { + // To destroy a mapping, send an add-port with + // an internalPort of the internal port to destroy, an external port of zero and a time of zero. + _, err = n.client.AddPortMapping(protocol, internalPort, 0, 0) + return +} diff --git a/p2p/natupnp.go b/p2p/natupnp.go new file mode 100644 index 000000000..2e0d8ce8d --- /dev/null +++ b/p2p/natupnp.go @@ -0,0 +1,341 @@ +package p2p + +// Just enough UPnP to be able to forward ports +// + +import ( + "bytes" + "encoding/xml" + "errors" + "fmt" + "net" + "net/http" + "os" + "strconv" + "strings" + "time" +) + +const ( + upnpDiscoverAttempts = 3 + upnpDiscoverTimeout = 5 * time.Second +) + +// UPNP returns a NAT port mapper that uses UPnP. It will attempt to +// discover the address of your router using UDP broadcasts. +func UPNP() NAT { + return &upnpNAT{} +} + +type upnpNAT struct { + serviceURL string + ourIP string +} + +func (n *upnpNAT) String() string { + return "UPNP" +} + +func (n *upnpNAT) discover() error { + if n.serviceURL != "" { + // already discovered + return nil + } + + ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") + if err != nil { + return err + } + // TODO: try on all network interfaces simultaneously. + // Broadcasting on 0.0.0.0 could select a random interface + // to send on (platform specific). + conn, err := net.ListenPacket("udp4", ":0") + if err != nil { + return err + } + defer conn.Close() + + conn.SetDeadline(time.Now().Add(10 * time.Second)) + st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n" + buf := bytes.NewBufferString( + "M-SEARCH * HTTP/1.1\r\n" + + "HOST: 239.255.255.250:1900\r\n" + + st + + "MAN: \"ssdp:discover\"\r\n" + + "MX: 2\r\n\r\n") + message := buf.Bytes() + answerBytes := make([]byte, 1024) + for i := 0; i < upnpDiscoverAttempts; i++ { + _, err = conn.WriteTo(message, ssdp) + if err != nil { + return err + } + nn, _, err := conn.ReadFrom(answerBytes) + if err != nil { + continue + } + answer := string(answerBytes[0:nn]) + if strings.Index(answer, "\r\n"+st) < 0 { + continue + } + // HTTP header field names are case-insensitive. + // http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 + locString := "\r\nlocation: " + answer = strings.ToLower(answer) + locIndex := strings.Index(answer, locString) + if locIndex < 0 { + continue + } + loc := answer[locIndex+len(locString):] + endIndex := strings.Index(loc, "\r\n") + if endIndex < 0 { + continue + } + locURL := loc[0:endIndex] + var serviceURL string + serviceURL, err = getServiceURL(locURL) + if err != nil { + return err + } + var ourIP string + ourIP, err = getOurIP() + if err != nil { + return err + } + n.serviceURL = serviceURL + n.ourIP = ourIP + return nil + } + return errors.New("UPnP port discovery failed.") +} + +func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) { + if err := n.discover(); err != nil { + return nil, err + } + info, err := n.getStatusInfo() + return net.ParseIP(info.externalIpAddress), err +} + +func (n *upnpNAT) AddPortMapping(protocol string, extport, intport int, description string, lifetime time.Duration) error { + if err := n.discover(); err != nil { + return err + } + + // A single concatenation would break ARM compilation. + message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + + "<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(extport) + message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" + message += "<NewInternalPort>" + strconv.Itoa(extport) + "</NewInternalPort>" + + "<NewInternalClient>" + n.ourIP + "</NewInternalClient>" + + "<NewEnabled>1</NewEnabled><NewPortMappingDescription>" + message += description + + "</NewPortMappingDescription><NewLeaseDuration>" + fmt.Sprint(lifetime/time.Second) + + "</NewLeaseDuration></u:AddPortMapping>" + + // TODO: check response to see if the port was forwarded + _, err := soapRequest(n.serviceURL, "AddPortMapping", message) + return err +} + +func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) error { + if err := n.discover(); err != nil { + return err + } + + message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + + "<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) + + "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" + + "</u:DeletePortMapping>" + + // TODO: check response to see if the port was deleted + _, err := soapRequest(n.serviceURL, "DeletePortMapping", message) + return err +} + +type statusInfo struct { + externalIpAddress string +} + +func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) { + message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + + "</u:GetStatusInfo>" + + var response *http.Response + response, err = soapRequest(n.serviceURL, "GetStatusInfo", message) + if err != nil { + return + } + + // TODO: Write a soap reply parser. It has to eat the Body and envelope tags... + + response.Body.Close() + return +} + +// service represents the Service type in an UPnP xml description. +// Only the parts we care about are present and thus the xml may have more +// fields than present in the structure. +type service struct { + ServiceType string `xml:"serviceType"` + ControlURL string `xml:"controlURL"` +} + +// deviceList represents the deviceList type in an UPnP xml description. +// Only the parts we care about are present and thus the xml may have more +// fields than present in the structure. +type deviceList struct { + XMLName xml.Name `xml:"deviceList"` + Device []device `xml:"device"` +} + +// serviceList represents the serviceList type in an UPnP xml description. +// Only the parts we care about are present and thus the xml may have more +// fields than present in the structure. +type serviceList struct { + XMLName xml.Name `xml:"serviceList"` + Service []service `xml:"service"` +} + +// device represents the device type in an UPnP xml description. +// Only the parts we care about are present and thus the xml may have more +// fields than present in the structure. +type device struct { + XMLName xml.Name `xml:"device"` + DeviceType string `xml:"deviceType"` + DeviceList deviceList `xml:"deviceList"` + ServiceList serviceList `xml:"serviceList"` +} + +// specVersion represents the specVersion in a UPnP xml description. +// Only the parts we care about are present and thus the xml may have more +// fields than present in the structure. +type specVersion struct { + XMLName xml.Name `xml:"specVersion"` + Major int `xml:"major"` + Minor int `xml:"minor"` +} + +// root represents the Root document for a UPnP xml description. +// Only the parts we care about are present and thus the xml may have more +// fields than present in the structure. +type root struct { + XMLName xml.Name `xml:"root"` + SpecVersion specVersion + Device device +} + +func getChildDevice(d *device, deviceType string) *device { + dl := d.DeviceList.Device + for i := 0; i < len(dl); i++ { + if dl[i].DeviceType == deviceType { + return &dl[i] + } + } + return nil +} + +func getChildService(d *device, serviceType string) *service { + sl := d.ServiceList.Service + for i := 0; i < len(sl); i++ { + if sl[i].ServiceType == serviceType { + return &sl[i] + } + } + return nil +} + +func getOurIP() (ip string, err error) { + hostname, err := os.Hostname() + if err != nil { + return + } + p, err := net.LookupIP(hostname) + if err != nil && len(p) > 0 { + return + } + return p[0].String(), nil +} + +func getServiceURL(rootURL string) (url string, err error) { + r, err := http.Get(rootURL) + if err != nil { + return + } + defer r.Body.Close() + if r.StatusCode >= 400 { + err = errors.New(string(r.StatusCode)) + return + } + var root root + err = xml.NewDecoder(r.Body).Decode(&root) + + if err != nil { + return + } + a := &root.Device + if a.DeviceType != "urn:schemas-upnp-org:device:InternetGatewayDevice:1" { + err = errors.New("No InternetGatewayDevice") + return + } + b := getChildDevice(a, "urn:schemas-upnp-org:device:WANDevice:1") + if b == nil { + err = errors.New("No WANDevice") + return + } + c := getChildDevice(b, "urn:schemas-upnp-org:device:WANConnectionDevice:1") + if c == nil { + err = errors.New("No WANConnectionDevice") + return + } + d := getChildService(c, "urn:schemas-upnp-org:service:WANIPConnection:1") + if d == nil { + err = errors.New("No WANIPConnection") + return + } + url = combineURL(rootURL, d.ControlURL) + return +} + +func combineURL(rootURL, subURL string) string { + protocolEnd := "://" + protoEndIndex := strings.Index(rootURL, protocolEnd) + a := rootURL[protoEndIndex+len(protocolEnd):] + rootIndex := strings.Index(a, "/") + return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL +} + +func soapRequest(url, function, message string) (r *http.Response, err error) { + fullMessage := "<?xml version=\"1.0\" ?>" + + "<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\r\n" + + "<s:Body>" + message + "</s:Body></s:Envelope>" + + req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage)) + if err != nil { + return + } + req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"") + req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3") + //req.Header.Set("Transfer-Encoding", "chunked") + req.Header.Set("SOAPAction", "\"urn:schemas-upnp-org:service:WANIPConnection:1#"+function+"\"") + req.Header.Set("Connection", "Close") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Pragma", "no-cache") + + r, err = http.DefaultClient.Do(req) + if err != nil { + return + } + + if r.Body != nil { + defer r.Body.Close() + } + + if r.StatusCode >= 400 { + // log.Stderr(function, r.StatusCode) + err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function) + r = nil + return + } + return +} diff --git a/p2p/peer.go b/p2p/peer.go new file mode 100644 index 000000000..86c4d7ab5 --- /dev/null +++ b/p2p/peer.go @@ -0,0 +1,462 @@ +package p2p + +import ( + "bufio" + "bytes" + "fmt" + "io" + "io/ioutil" + "net" + "sort" + "sync" + "time" + + "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/logger" +) + +// peerAddr is the structure of a peer list element. +// It is also a valid net.Addr. +type peerAddr struct { + IP net.IP + Port uint64 + Pubkey []byte // optional +} + +func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr { + n := addr.Network() + if n != "tcp" && n != "tcp4" && n != "tcp6" { + // for testing with non-TCP + return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey} + } + ta := addr.(*net.TCPAddr) + return &peerAddr{ta.IP, uint64(ta.Port), pubkey} +} + +func (d peerAddr) Network() string { + if d.IP.To4() != nil { + return "tcp4" + } else { + return "tcp6" + } +} + +func (d peerAddr) String() string { + return fmt.Sprintf("%v:%d", d.IP, d.Port) +} + +func (d peerAddr) RlpData() interface{} { + return []interface{}{d.IP, d.Port, d.Pubkey} +} + +// Peer represents a remote peer. +type Peer struct { + // Peers have all the log methods. + // Use them to display messages related to the peer. + *logger.Logger + + infolock sync.Mutex + identity ClientIdentity + caps []Cap + listenAddr *peerAddr // what remote peer is listening on + dialAddr *peerAddr // non-nil if dialing + + // The mutex protects the connection + // so only one protocol can write at a time. + writeMu sync.Mutex + conn net.Conn + bufconn *bufio.ReadWriter + + // These fields maintain the running protocols. + protocols []Protocol + runBaseProtocol bool // for testing + + runlock sync.RWMutex // protects running + running map[string]*proto + + protoWG sync.WaitGroup + protoErr chan error + closed chan struct{} + disc chan DiscReason + + activity event.TypeMux // for activity events + + slot int // index into Server peer list + + // These fields are kept so base protocol can access them. + // TODO: this should be one or more interfaces + ourID ClientIdentity // client id of the Server + ourListenAddr *peerAddr // listen addr of Server, nil if not listening + newPeerAddr chan<- *peerAddr // tell server about received peers + otherPeers func() []*Peer // should return the list of all peers + pubkeyHook func(*peerAddr) error // called at end of handshake to validate pubkey +} + +// NewPeer returns a peer for testing purposes. +func NewPeer(id ClientIdentity, caps []Cap) *Peer { + conn, _ := net.Pipe() + peer := newPeer(conn, nil, nil) + peer.setHandshakeInfo(id, nil, caps) + close(peer.closed) + return peer +} + +func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer { + p := newPeer(conn, server.Protocols, dialAddr) + p.ourID = server.Identity + p.newPeerAddr = server.peerConnect + p.otherPeers = server.Peers + p.pubkeyHook = server.verifyPeer + p.runBaseProtocol = true + + // laddr can be updated concurrently by NAT traversal. + // newServerPeer must be called with the server lock held. + if server.laddr != nil { + p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey()) + } + return p +} + +func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer { + p := &Peer{ + Logger: logger.NewLogger("P2P " + conn.RemoteAddr().String()), + conn: conn, + dialAddr: dialAddr, + bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), + protocols: protocols, + running: make(map[string]*proto), + disc: make(chan DiscReason), + protoErr: make(chan error), + closed: make(chan struct{}), + } + return p +} + +// Identity returns the client identity of the remote peer. The +// identity can be nil if the peer has not yet completed the +// handshake. +func (p *Peer) Identity() ClientIdentity { + p.infolock.Lock() + defer p.infolock.Unlock() + return p.identity +} + +// Caps returns the capabilities (supported subprotocols) of the remote peer. +func (p *Peer) Caps() []Cap { + p.infolock.Lock() + defer p.infolock.Unlock() + return p.caps +} + +func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) { + p.infolock.Lock() + p.identity = id + p.listenAddr = laddr + p.caps = caps + p.infolock.Unlock() +} + +// RemoteAddr returns the remote address of the network connection. +func (p *Peer) RemoteAddr() net.Addr { + return p.conn.RemoteAddr() +} + +// LocalAddr returns the local address of the network connection. +func (p *Peer) LocalAddr() net.Addr { + return p.conn.LocalAddr() +} + +// Disconnect terminates the peer connection with the given reason. +// It returns immediately and does not wait until the connection is closed. +func (p *Peer) Disconnect(reason DiscReason) { + select { + case p.disc <- reason: + case <-p.closed: + } +} + +// String implements fmt.Stringer. +func (p *Peer) String() string { + kind := "inbound" + p.infolock.Lock() + if p.dialAddr != nil { + kind = "outbound" + } + p.infolock.Unlock() + return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind) +} + +const ( + // maximum amount of time allowed for reading a message + msgReadTimeout = 5 * time.Second + // maximum amount of time allowed for writing a message + msgWriteTimeout = 5 * time.Second + // messages smaller than this many bytes will be read at + // once before passing them to a protocol. + wholePayloadSize = 64 * 1024 +) + +var ( + inactivityTimeout = 2 * time.Second + disconnectGracePeriod = 2 * time.Second +) + +func (p *Peer) loop() (reason DiscReason, err error) { + defer p.activity.Stop() + defer p.closeProtocols() + defer close(p.closed) + defer p.conn.Close() + + // read loop + readMsg := make(chan Msg) + readErr := make(chan error) + readNext := make(chan bool, 1) + protoDone := make(chan struct{}, 1) + go p.readLoop(readMsg, readErr, readNext) + readNext <- true + + if p.runBaseProtocol { + p.startBaseProtocol() + } + +loop: + for { + select { + case msg := <-readMsg: + // a new message has arrived. + var wait bool + if wait, err = p.dispatch(msg, protoDone); err != nil { + p.Errorf("msg dispatch error: %v\n", err) + reason = discReasonForError(err) + break loop + } + if !wait { + // Msg has already been read completely, continue with next message. + readNext <- true + } + p.activity.Post(time.Now()) + case <-protoDone: + // protocol has consumed the message payload, + // we can continue reading from the socket. + readNext <- true + + case err := <-readErr: + // read failed. there is no need to run the + // polite disconnect sequence because the connection + // is probably dead anyway. + // TODO: handle write errors as well + return DiscNetworkError, err + case err = <-p.protoErr: + reason = discReasonForError(err) + break loop + case reason = <-p.disc: + break loop + } + } + + // wait for read loop to return. + close(readNext) + <-readErr + // tell the remote end to disconnect + done := make(chan struct{}) + go func() { + p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod)) + p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod) + io.Copy(ioutil.Discard, p.conn) + close(done) + }() + select { + case <-done: + case <-time.After(disconnectGracePeriod): + } + return reason, err +} + +func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) { + for _ = range unblock { + p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout)) + if msg, err := readMsg(p.bufconn); err != nil { + errc <- err + } else { + msgc <- msg + } + } + close(errc) +} + +func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) { + proto, err := p.getProto(msg.Code) + if err != nil { + return false, err + } + if msg.Size <= wholePayloadSize { + // optimization: msg is small enough, read all + // of it and move on to the next message + buf, err := ioutil.ReadAll(msg.Payload) + if err != nil { + return false, err + } + msg.Payload = bytes.NewReader(buf) + proto.in <- msg + } else { + wait = true + pr := &eofSignal{msg.Payload, int64(msg.Size), protoDone} + msg.Payload = pr + proto.in <- msg + } + return wait, nil +} + +func (p *Peer) startBaseProtocol() { + p.runlock.Lock() + defer p.runlock.Unlock() + p.running[""] = p.startProto(0, Protocol{ + Length: baseProtocolLength, + Run: runBaseProtocol, + }) +} + +// startProtocols starts matching named subprotocols. +func (p *Peer) startSubprotocols(caps []Cap) { + sort.Sort(capsByName(caps)) + + p.runlock.Lock() + defer p.runlock.Unlock() + offset := baseProtocolLength +outer: + for _, cap := range caps { + for _, proto := range p.protocols { + if proto.Name == cap.Name && + proto.Version == cap.Version && + p.running[cap.Name] == nil { + p.running[cap.Name] = p.startProto(offset, proto) + offset += proto.Length + continue outer + } + } + } +} + +func (p *Peer) startProto(offset uint64, impl Protocol) *proto { + rw := &proto{ + in: make(chan Msg), + offset: offset, + maxcode: impl.Length, + peer: p, + } + p.protoWG.Add(1) + go func() { + err := impl.Run(p, rw) + if err == nil { + p.Infof("protocol %q returned", impl.Name) + err = newPeerError(errMisc, "protocol returned") + } else { + p.Errorf("protocol %q error: %v\n", impl.Name, err) + } + select { + case p.protoErr <- err: + case <-p.closed: + } + p.protoWG.Done() + }() + return rw +} + +// getProto finds the protocol responsible for handling +// the given message code. +func (p *Peer) getProto(code uint64) (*proto, error) { + p.runlock.RLock() + defer p.runlock.RUnlock() + for _, proto := range p.running { + if code >= proto.offset && code < proto.offset+proto.maxcode { + return proto, nil + } + } + return nil, newPeerError(errInvalidMsgCode, "%d", code) +} + +func (p *Peer) closeProtocols() { + p.runlock.RLock() + for _, p := range p.running { + close(p.in) + } + p.runlock.RUnlock() + p.protoWG.Wait() +} + +// writeProtoMsg sends the given message on behalf of the given named protocol. +func (p *Peer) writeProtoMsg(protoName string, msg Msg) error { + p.runlock.RLock() + proto, ok := p.running[protoName] + p.runlock.RUnlock() + if !ok { + return fmt.Errorf("protocol %s not handled by peer", protoName) + } + if msg.Code >= proto.maxcode { + return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName) + } + msg.Code += proto.offset + return p.writeMsg(msg, msgWriteTimeout) +} + +// writeMsg writes a message to the connection. +func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error { + p.writeMu.Lock() + defer p.writeMu.Unlock() + p.conn.SetWriteDeadline(time.Now().Add(timeout)) + if err := writeMsg(p.bufconn, msg); err != nil { + return newPeerError(errWrite, "%v", err) + } + return p.bufconn.Flush() +} + +type proto struct { + name string + in chan Msg + maxcode, offset uint64 + peer *Peer +} + +func (rw *proto) WriteMsg(msg Msg) error { + if msg.Code >= rw.maxcode { + return newPeerError(errInvalidMsgCode, "not handled") + } + msg.Code += rw.offset + return rw.peer.writeMsg(msg, msgWriteTimeout) +} + +func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error { + return rw.WriteMsg(NewMsg(code, data)) +} + +func (rw *proto) ReadMsg() (Msg, error) { + msg, ok := <-rw.in + if !ok { + return msg, io.EOF + } + msg.Code -= rw.offset + return msg, nil +} + +// eofSignal wraps a reader with eof signaling. the eof channel is +// closed when the wrapped reader returns an error or when count bytes +// have been read. +// +type eofSignal struct { + wrapped io.Reader + count int64 + eof chan<- struct{} +} + +// note: when using eofSignal to detect whether a message payload +// has been read, Read might not be called for zero sized messages. + +func (r *eofSignal) Read(buf []byte) (int, error) { + n, err := r.wrapped.Read(buf) + r.count -= int64(n) + if (err != nil || r.count <= 0) && r.eof != nil { + r.eof <- struct{}{} // tell Peer that msg has been consumed + r.eof = nil + } + return n, err +} diff --git a/p2p/peer_error.go b/p2p/peer_error.go new file mode 100644 index 000000000..0eb7ec838 --- /dev/null +++ b/p2p/peer_error.go @@ -0,0 +1,133 @@ +package p2p + +import ( + "fmt" +) + +const ( + errMagicTokenMismatch = iota + errRead + errWrite + errMisc + errInvalidMsgCode + errInvalidMsg + errP2PVersionMismatch + errPubkeyMissing + errPubkeyInvalid + errPubkeyForbidden + errProtocolBreach + errPingTimeout + errInvalidNetworkId + errInvalidProtocolVersion +) + +var errorToString = map[int]string{ + errMagicTokenMismatch: "Magic token mismatch", + errRead: "Read error", + errWrite: "Write error", + errMisc: "Misc error", + errInvalidMsgCode: "Invalid message code", + errInvalidMsg: "Invalid message", + errP2PVersionMismatch: "P2P Version Mismatch", + errPubkeyMissing: "Public key missing", + errPubkeyInvalid: "Public key invalid", + errPubkeyForbidden: "Public key forbidden", + errProtocolBreach: "Protocol Breach", + errPingTimeout: "Ping timeout", + errInvalidNetworkId: "Invalid network id", + errInvalidProtocolVersion: "Invalid protocol version", +} + +type peerError struct { + Code int + message string +} + +func newPeerError(code int, format string, v ...interface{}) *peerError { + desc, ok := errorToString[code] + if !ok { + panic("invalid error code") + } + err := &peerError{code, desc} + if format != "" { + err.message += ": " + fmt.Sprintf(format, v...) + } + return err +} + +func (self *peerError) Error() string { + return self.message +} + +type DiscReason byte + +const ( + DiscRequested DiscReason = 0x00 + DiscNetworkError = 0x01 + DiscProtocolError = 0x02 + DiscUselessPeer = 0x03 + DiscTooManyPeers = 0x04 + DiscAlreadyConnected = 0x05 + DiscIncompatibleVersion = 0x06 + DiscInvalidIdentity = 0x07 + DiscQuitting = 0x08 + DiscUnexpectedIdentity = 0x09 + DiscSelf = 0x0a + DiscReadTimeout = 0x0b + DiscSubprotocolError = 0x10 +) + +var discReasonToString = [DiscSubprotocolError + 1]string{ + DiscRequested: "Disconnect requested", + DiscNetworkError: "Network error", + DiscProtocolError: "Breach of protocol", + DiscUselessPeer: "Useless peer", + DiscTooManyPeers: "Too many peers", + DiscAlreadyConnected: "Already connected", + DiscIncompatibleVersion: "Incompatible P2P protocol version", + DiscInvalidIdentity: "Invalid node identity", + DiscQuitting: "Client quitting", + DiscUnexpectedIdentity: "Unexpected identity", + DiscSelf: "Connected to self", + DiscReadTimeout: "Read timeout", + DiscSubprotocolError: "Subprotocol error", +} + +func (d DiscReason) String() string { + if len(discReasonToString) < int(d) { + return fmt.Sprintf("Unknown Reason(%d)", d) + } + return discReasonToString[d] +} + +type discRequestedError DiscReason + +func (err discRequestedError) Error() string { + return fmt.Sprintf("disconnect requested: %v", DiscReason(err)) +} + +func discReasonForError(err error) DiscReason { + if reason, ok := err.(discRequestedError); ok { + return DiscReason(reason) + } + peerError, ok := err.(*peerError) + if !ok { + return DiscSubprotocolError + } + switch peerError.Code { + case errP2PVersionMismatch: + return DiscIncompatibleVersion + case errPubkeyMissing, errPubkeyInvalid: + return DiscInvalidIdentity + case errPubkeyForbidden: + return DiscUselessPeer + case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach: + return DiscProtocolError + case errPingTimeout: + return DiscReadTimeout + case errRead, errWrite, errMisc: + return DiscNetworkError + default: + return DiscSubprotocolError + } +} diff --git a/p2p/peer_test.go b/p2p/peer_test.go new file mode 100644 index 000000000..f7759786e --- /dev/null +++ b/p2p/peer_test.go @@ -0,0 +1,295 @@ +package p2p + +import ( + "bufio" + "bytes" + "encoding/hex" + "io" + "io/ioutil" + "net" + "reflect" + "testing" + "time" +) + +var discard = Protocol{ + Name: "discard", + Length: 1, + Run: func(p *Peer, rw MsgReadWriter) error { + for { + msg, err := rw.ReadMsg() + if err != nil { + return err + } + if err = msg.Discard(); err != nil { + return err + } + } + }, +} + +func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) { + conn1, conn2 := net.Pipe() + id := NewSimpleClientIdentity("test", "0", "0", "public key") + peer := newPeer(conn1, protos, nil) + peer.ourID = id + peer.pubkeyHook = func(*peerAddr) error { return nil } + errc := make(chan error, 1) + go func() { + _, err := peer.loop() + errc <- err + }() + return conn2, peer, errc +} + +func TestPeerProtoReadMsg(t *testing.T) { + defer testlog(t).detach() + + done := make(chan struct{}) + proto := Protocol{ + Name: "a", + Length: 5, + Run: func(peer *Peer, rw MsgReadWriter) error { + msg, err := rw.ReadMsg() + if err != nil { + t.Errorf("read error: %v", err) + } + if msg.Code != 2 { + t.Errorf("incorrect msg code %d relayed to protocol", msg.Code) + } + data, err := ioutil.ReadAll(msg.Payload) + if err != nil { + t.Errorf("payload read error: %v", err) + } + expdata, _ := hex.DecodeString("0183303030") + if !bytes.Equal(expdata, data) { + t.Errorf("incorrect msg data %x", data) + } + close(done) + return nil + }, + } + + net, peer, errc := testPeer([]Protocol{proto}) + defer net.Close() + peer.startSubprotocols([]Cap{proto.cap()}) + + writeMsg(net, NewMsg(18, 1, "000")) + select { + case <-done: + case err := <-errc: + t.Errorf("peer returned: %v", err) + case <-time.After(2 * time.Second): + t.Errorf("receive timeout") + } +} + +func TestPeerProtoReadLargeMsg(t *testing.T) { + defer testlog(t).detach() + + msgsize := uint32(10 * 1024 * 1024) + done := make(chan struct{}) + proto := Protocol{ + Name: "a", + Length: 5, + Run: func(peer *Peer, rw MsgReadWriter) error { + msg, err := rw.ReadMsg() + if err != nil { + t.Errorf("read error: %v", err) + } + if msg.Size != msgsize+4 { + t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize) + } + msg.Discard() + close(done) + return nil + }, + } + + net, peer, errc := testPeer([]Protocol{proto}) + defer net.Close() + peer.startSubprotocols([]Cap{proto.cap()}) + + writeMsg(net, NewMsg(18, make([]byte, msgsize))) + select { + case <-done: + case err := <-errc: + t.Errorf("peer returned: %v", err) + case <-time.After(2 * time.Second): + t.Errorf("receive timeout") + } +} + +func TestPeerProtoEncodeMsg(t *testing.T) { + defer testlog(t).detach() + + proto := Protocol{ + Name: "a", + Length: 2, + Run: func(peer *Peer, rw MsgReadWriter) error { + if err := rw.EncodeMsg(2); err == nil { + t.Error("expected error for out-of-range msg code, got nil") + } + if err := rw.EncodeMsg(1); err != nil { + t.Errorf("write error: %v", err) + } + return nil + }, + } + net, peer, _ := testPeer([]Protocol{proto}) + defer net.Close() + peer.startSubprotocols([]Cap{proto.cap()}) + + bufr := bufio.NewReader(net) + msg, err := readMsg(bufr) + if err != nil { + t.Errorf("read error: %v", err) + } + if msg.Code != 17 { + t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17) + } +} + +func TestPeerWrite(t *testing.T) { + defer testlog(t).detach() + + net, peer, peerErr := testPeer([]Protocol{discard}) + defer net.Close() + peer.startSubprotocols([]Cap{discard.cap()}) + + // test write errors + if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil { + t.Errorf("expected error for unknown protocol, got nil") + } + if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil { + t.Errorf("expected error for out-of-range msg code, got nil") + } else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode { + t.Errorf("wrong error for out-of-range msg code, got %#v", err) + } + + // setup for reading the message on the other end + read := make(chan struct{}) + go func() { + bufr := bufio.NewReader(net) + msg, err := readMsg(bufr) + if err != nil { + t.Errorf("read error: %v", err) + } else if msg.Code != 16 { + t.Errorf("wrong code, got %d, expected %d", msg.Code, 16) + } + msg.Discard() + close(read) + }() + + // test succcessful write + if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil { + t.Errorf("expect no error for known protocol: %v", err) + } + select { + case <-read: + case err := <-peerErr: + t.Fatalf("peer stopped: %v", err) + } +} + +func TestPeerActivity(t *testing.T) { + // shorten inactivityTimeout while this test is running + oldT := inactivityTimeout + defer func() { inactivityTimeout = oldT }() + inactivityTimeout = 20 * time.Millisecond + + net, peer, peerErr := testPeer([]Protocol{discard}) + defer net.Close() + peer.startSubprotocols([]Cap{discard.cap()}) + + sub := peer.activity.Subscribe(time.Time{}) + defer sub.Unsubscribe() + + for i := 0; i < 6; i++ { + writeMsg(net, NewMsg(16)) + select { + case <-sub.Chan(): + case <-time.After(inactivityTimeout / 2): + t.Fatal("no event within ", inactivityTimeout/2) + case err := <-peerErr: + t.Fatal("peer error", err) + } + } + + select { + case <-time.After(inactivityTimeout * 2): + case <-sub.Chan(): + t.Fatal("got activity event while connection was inactive") + case err := <-peerErr: + t.Fatal("peer error", err) + } +} + +func TestNewPeer(t *testing.T) { + id := NewSimpleClientIdentity("clientid", "version", "customid", "pubkey") + caps := []Cap{{"foo", 2}, {"bar", 3}} + p := NewPeer(id, caps) + if !reflect.DeepEqual(p.Caps(), caps) { + t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps) + } + if p.Identity() != id { + t.Errorf("Identity mismatch: got %v, expected %v", p.Identity(), id) + } + // Should not hang. + p.Disconnect(DiscAlreadyConnected) +} + +func TestEOFSignal(t *testing.T) { + rb := make([]byte, 10) + + // empty reader + eof := make(chan struct{}, 1) + sig := &eofSignal{new(bytes.Buffer), 0, eof} + if n, err := sig.Read(rb); n != 0 || err != io.EOF { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + default: + t.Error("EOF chan not signaled") + } + + // count before error + eof = make(chan struct{}, 1) + sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof} + if n, err := sig.Read(rb); n != 8 || err != nil { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + default: + t.Error("EOF chan not signaled") + } + + // error before count + eof = make(chan struct{}, 1) + sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof} + if n, err := sig.Read(rb); n != 4 || err != nil { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + if n, err := sig.Read(rb); n != 0 || err != io.EOF { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + default: + t.Error("EOF chan not signaled") + } + + // no signal if neither occurs + eof = make(chan struct{}, 1) + sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof} + if n, err := sig.Read(rb); n != 10 || err != nil { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + t.Error("unexpected EOF signal") + default: + } +} diff --git a/p2p/protocol.go b/p2p/protocol.go new file mode 100644 index 000000000..3f52205f5 --- /dev/null +++ b/p2p/protocol.go @@ -0,0 +1,294 @@ +package p2p + +import ( + "bytes" + "time" + + "github.com/ethereum/go-ethereum/ethutil" +) + +// Protocol represents a P2P subprotocol implementation. +type Protocol struct { + // Name should contain the official protocol name, + // often a three-letter word. + Name string + + // Version should contain the version number of the protocol. + Version uint + + // Length should contain the number of message codes used + // by the protocol. + Length uint64 + + // Run is called in a new groutine when the protocol has been + // negotiated with a peer. It should read and write messages from + // rw. The Payload for each message must be fully consumed. + // + // The peer connection is closed when Start returns. It should return + // any protocol-level error (such as an I/O error) that is + // encountered. + Run func(peer *Peer, rw MsgReadWriter) error +} + +func (p Protocol) cap() Cap { + return Cap{p.Name, p.Version} +} + +const ( + baseProtocolVersion = 2 + baseProtocolLength = uint64(16) + baseProtocolMaxMsgSize = 10 * 1024 * 1024 +) + +const ( + // devp2p message codes + handshakeMsg = 0x00 + discMsg = 0x01 + pingMsg = 0x02 + pongMsg = 0x03 + getPeersMsg = 0x04 + peersMsg = 0x05 +) + +// handshake is the structure of a handshake list. +type handshake struct { + Version uint64 + ID string + Caps []Cap + ListenPort uint64 + NodeID []byte +} + +func (h *handshake) String() string { + return h.ID +} +func (h *handshake) Pubkey() []byte { + return h.NodeID +} + +// Cap is the structure of a peer capability. +type Cap struct { + Name string + Version uint +} + +func (cap Cap) RlpData() interface{} { + return []interface{}{cap.Name, cap.Version} +} + +type capsByName []Cap + +func (cs capsByName) Len() int { return len(cs) } +func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name } +func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] } + +type baseProtocol struct { + rw MsgReadWriter + peer *Peer +} + +func runBaseProtocol(peer *Peer, rw MsgReadWriter) error { + bp := &baseProtocol{rw, peer} + if err := bp.doHandshake(rw); err != nil { + return err + } + // run main loop + quit := make(chan error, 1) + go func() { + for { + if err := bp.handle(rw); err != nil { + quit <- err + break + } + } + }() + return bp.loop(quit) +} + +var pingTimeout = 2 * time.Second + +func (bp *baseProtocol) loop(quit <-chan error) error { + ping := time.NewTimer(pingTimeout) + activity := bp.peer.activity.Subscribe(time.Time{}) + lastActive := time.Time{} + defer ping.Stop() + defer activity.Unsubscribe() + + getPeersTick := time.NewTicker(10 * time.Second) + defer getPeersTick.Stop() + err := bp.rw.EncodeMsg(getPeersMsg) + + for err == nil { + select { + case err = <-quit: + return err + case <-getPeersTick.C: + err = bp.rw.EncodeMsg(getPeersMsg) + case event := <-activity.Chan(): + ping.Reset(pingTimeout) + lastActive = event.(time.Time) + case t := <-ping.C: + if lastActive.Add(pingTimeout * 2).Before(t) { + err = newPeerError(errPingTimeout, "") + } else if lastActive.Add(pingTimeout).Before(t) { + err = bp.rw.EncodeMsg(pingMsg) + } + } + } + return err +} + +func (bp *baseProtocol) handle(rw MsgReadWriter) error { + msg, err := rw.ReadMsg() + if err != nil { + return err + } + if msg.Size > baseProtocolMaxMsgSize { + return newPeerError(errMisc, "message too big") + } + // make sure that the payload has been fully consumed + defer msg.Discard() + + switch msg.Code { + case handshakeMsg: + return newPeerError(errProtocolBreach, "extra handshake received") + + case discMsg: + var reason [1]DiscReason + if err := msg.Decode(&reason); err != nil { + return err + } + return discRequestedError(reason[0]) + + case pingMsg: + return bp.rw.EncodeMsg(pongMsg) + + case pongMsg: + + case getPeersMsg: + peers := bp.peerList() + // this is dangerous. the spec says that we should _delay_ + // sending the response if no new information is available. + // this means that would need to send a response later when + // new peers become available. + // + // TODO: add event mechanism to notify baseProtocol for new peers + if len(peers) > 0 { + return bp.rw.EncodeMsg(peersMsg, peers) + } + + case peersMsg: + var peers []*peerAddr + if err := msg.Decode(&peers); err != nil { + return err + } + for _, addr := range peers { + bp.peer.Debugf("received peer suggestion: %v", addr) + bp.peer.newPeerAddr <- addr + } + + default: + return newPeerError(errInvalidMsgCode, "unknown message code %v", msg.Code) + } + return nil +} + +func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error { + // send our handshake + if err := rw.WriteMsg(bp.handshakeMsg()); err != nil { + return err + } + + // read and handle remote handshake + msg, err := rw.ReadMsg() + if err != nil { + return err + } + if msg.Code != handshakeMsg { + return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code) + } + if msg.Size > baseProtocolMaxMsgSize { + return newPeerError(errMisc, "message too big") + } + + var hs handshake + if err := msg.Decode(&hs); err != nil { + return err + } + + // validate handshake info + if hs.Version != baseProtocolVersion { + return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n", + baseProtocolVersion, hs.Version) + } + if len(hs.NodeID) == 0 { + return newPeerError(errPubkeyMissing, "") + } + if len(hs.NodeID) != 64 { + return newPeerError(errPubkeyInvalid, "require 512 bit, got %v", len(hs.NodeID)*8) + } + if da := bp.peer.dialAddr; da != nil { + // verify that the peer we wanted to connect to + // actually holds the target public key. + if da.Pubkey != nil && !bytes.Equal(da.Pubkey, hs.NodeID) { + return newPeerError(errPubkeyForbidden, "dial address pubkey mismatch") + } + } + pa := newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID) + if err := bp.peer.pubkeyHook(pa); err != nil { + return newPeerError(errPubkeyForbidden, "%v", err) + } + + // TODO: remove Caps with empty name + + var addr *peerAddr + if hs.ListenPort != 0 { + addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID) + addr.Port = hs.ListenPort + } + bp.peer.setHandshakeInfo(&hs, addr, hs.Caps) + bp.peer.startSubprotocols(hs.Caps) + return nil +} + +func (bp *baseProtocol) handshakeMsg() Msg { + var ( + port uint64 + caps []interface{} + ) + if bp.peer.ourListenAddr != nil { + port = bp.peer.ourListenAddr.Port + } + for _, proto := range bp.peer.protocols { + caps = append(caps, proto.cap()) + } + return NewMsg(handshakeMsg, + baseProtocolVersion, + bp.peer.ourID.String(), + caps, + port, + bp.peer.ourID.Pubkey()[1:], + ) +} + +func (bp *baseProtocol) peerList() []ethutil.RlpEncodable { + peers := bp.peer.otherPeers() + ds := make([]ethutil.RlpEncodable, 0, len(peers)) + for _, p := range peers { + p.infolock.Lock() + addr := p.listenAddr + p.infolock.Unlock() + // filter out this peer and peers that are not listening or + // have not completed the handshake. + // TODO: track previously sent peers and exclude them as well. + if p == bp.peer || addr == nil { + continue + } + ds = append(ds, addr) + } + ourAddr := bp.peer.ourListenAddr + if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() { + ds = append(ds, ourAddr) + } + return ds +} diff --git a/p2p/protocol_test.go b/p2p/protocol_test.go new file mode 100644 index 000000000..65f26fb12 --- /dev/null +++ b/p2p/protocol_test.go @@ -0,0 +1,58 @@ +package p2p + +import ( + "fmt" + "testing" +) + +func TestBaseProtocolDisconnect(t *testing.T) { + peer := NewPeer(NewSimpleClientIdentity("p1", "", "", "foo"), nil) + peer.ourID = NewSimpleClientIdentity("p2", "", "", "bar") + peer.pubkeyHook = func(*peerAddr) error { return nil } + + rw1, rw2 := MsgPipe() + done := make(chan struct{}) + go func() { + if err := expectMsg(rw2, handshakeMsg); err != nil { + t.Error(err) + } + err := rw2.EncodeMsg(handshakeMsg, + baseProtocolVersion, + "", + []interface{}{}, + 0, + make([]byte, 64), + ) + if err != nil { + t.Error(err) + } + if err := expectMsg(rw2, getPeersMsg); err != nil { + t.Error(err) + } + if err := rw2.EncodeMsg(discMsg, DiscQuitting); err != nil { + t.Error(err) + } + close(done) + }() + + if err := runBaseProtocol(peer, rw1); err == nil { + t.Errorf("base protocol returned without error") + } else if reason, ok := err.(discRequestedError); !ok || reason != DiscQuitting { + t.Errorf("base protocol returned wrong error: %v", err) + } + <-done +} + +func expectMsg(r MsgReader, code uint64) error { + msg, err := r.ReadMsg() + if err != nil { + return err + } + if err := msg.Discard(); err != nil { + return err + } + if msg.Code != code { + return fmt.Errorf("wrong message code: got %d, expected %d", msg.Code, code) + } + return nil +} diff --git a/p2p/server.go b/p2p/server.go new file mode 100644 index 000000000..8a6087566 --- /dev/null +++ b/p2p/server.go @@ -0,0 +1,467 @@ +package p2p + +import ( + "bytes" + "errors" + "fmt" + "net" + "sync" + "time" + + "github.com/ethereum/go-ethereum/logger" +) + +const ( + outboundAddressPoolSize = 500 + defaultDialTimeout = 10 * time.Second + portMappingUpdateInterval = 15 * time.Minute + portMappingTimeout = 20 * time.Minute +) + +var srvlog = logger.NewLogger("P2P Server") + +// Server manages all peer connections. +// +// The fields of Server are used as configuration parameters. +// You should set them before starting the Server. Fields may not be +// modified while the server is running. +type Server struct { + // This field must be set to a valid client identity. + Identity ClientIdentity + + // MaxPeers is the maximum number of peers that can be + // connected. It must be greater than zero. + MaxPeers int + + // Protocols should contain the protocols supported + // by the server. Matching protocols are launched for + // each peer. + Protocols []Protocol + + // If Blacklist is set to a non-nil value, the given Blacklist + // is used to verify peer connections. + Blacklist Blacklist + + // If ListenAddr is set to a non-nil address, the server + // will listen for incoming connections. + // + // If the port is zero, the operating system will pick a port. The + // ListenAddr field will be updated with the actual address when + // the server is started. + ListenAddr string + + // If set to a non-nil value, the given NAT port mapper + // is used to make the listening port available to the + // Internet. + NAT NAT + + // If Dialer is set to a non-nil value, the given Dialer + // is used to dial outbound peer connections. + Dialer *net.Dialer + + // If NoDial is true, the server will not dial any peers. + NoDial bool + + // Hook for testing. This is useful because we can inhibit + // the whole protocol stack. + newPeerFunc peerFunc + + lock sync.RWMutex + running bool + listener net.Listener + laddr *net.TCPAddr // real listen addr + peers []*Peer + peerSlots chan int + peerCount int + + quit chan struct{} + wg sync.WaitGroup + peerConnect chan *peerAddr + peerDisconnect chan *Peer +} + +// NAT is implemented by NAT traversal methods. +type NAT interface { + GetExternalAddress() (net.IP, error) + AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error + DeletePortMapping(protocol string, extport, intport int) error + + // Should return name of the method. + String() string +} + +type peerFunc func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer + +// Peers returns all connected peers. +func (srv *Server) Peers() (peers []*Peer) { + srv.lock.RLock() + defer srv.lock.RUnlock() + for _, peer := range srv.peers { + if peer != nil { + peers = append(peers, peer) + } + } + return +} + +// PeerCount returns the number of connected peers. +func (srv *Server) PeerCount() int { + srv.lock.RLock() + defer srv.lock.RUnlock() + return srv.peerCount +} + +// SuggestPeer injects an address into the outbound address pool. +func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) { + select { + case srv.peerConnect <- &peerAddr{ip, uint64(port), nodeID}: + default: // don't block + } +} + +// Broadcast sends an RLP-encoded message to all connected peers. +// This method is deprecated and will be removed later. +func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{}) { + var payload []byte + if data != nil { + payload = encodePayload(data...) + } + srv.lock.RLock() + defer srv.lock.RUnlock() + for _, peer := range srv.peers { + if peer != nil { + var msg = Msg{Code: code} + if data != nil { + msg.Payload = bytes.NewReader(payload) + msg.Size = uint32(len(payload)) + } + peer.writeProtoMsg(protocol, msg) + } + } +} + +// Start starts running the server. +// Servers can be re-used and started again after stopping. +func (srv *Server) Start() (err error) { + srv.lock.Lock() + defer srv.lock.Unlock() + if srv.running { + return errors.New("server already running") + } + srvlog.Infoln("Starting Server") + + // initialize fields + if srv.Identity == nil { + return fmt.Errorf("Server.Identity must be set to a non-nil identity") + } + if srv.MaxPeers <= 0 { + return fmt.Errorf("Server.MaxPeers must be > 0") + } + srv.quit = make(chan struct{}) + srv.peers = make([]*Peer, srv.MaxPeers) + srv.peerSlots = make(chan int, srv.MaxPeers) + srv.peerConnect = make(chan *peerAddr, outboundAddressPoolSize) + srv.peerDisconnect = make(chan *Peer) + if srv.newPeerFunc == nil { + srv.newPeerFunc = newServerPeer + } + if srv.Blacklist == nil { + srv.Blacklist = NewBlacklist() + } + if srv.Dialer == nil { + srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout} + } + + if srv.ListenAddr != "" { + if err := srv.startListening(); err != nil { + return err + } + } + if !srv.NoDial { + srv.wg.Add(1) + go srv.dialLoop() + } + if srv.NoDial && srv.ListenAddr == "" { + srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.") + } + + // make all slots available + for i := range srv.peers { + srv.peerSlots <- i + } + // note: discLoop is not part of WaitGroup + go srv.discLoop() + srv.running = true + return nil +} + +func (srv *Server) startListening() error { + listener, err := net.Listen("tcp", srv.ListenAddr) + if err != nil { + return err + } + srv.ListenAddr = listener.Addr().String() + srv.laddr = listener.Addr().(*net.TCPAddr) + srv.listener = listener + srv.wg.Add(1) + go srv.listenLoop() + if !srv.laddr.IP.IsLoopback() && srv.NAT != nil { + srv.wg.Add(1) + go srv.natLoop(srv.laddr.Port) + } + return nil +} + +// Stop terminates the server and all active peer connections. +// It blocks until all active connections have been closed. +func (srv *Server) Stop() { + srv.lock.Lock() + if !srv.running { + srv.lock.Unlock() + return + } + srv.running = false + srv.lock.Unlock() + + srvlog.Infoln("Stopping server") + if srv.listener != nil { + // this unblocks listener Accept + srv.listener.Close() + } + close(srv.quit) + for _, peer := range srv.Peers() { + peer.Disconnect(DiscQuitting) + } + srv.wg.Wait() + + // wait till they actually disconnect + // this is checked by claiming all peerSlots. + // slots become available as the peers disconnect. + for i := 0; i < cap(srv.peerSlots); i++ { + <-srv.peerSlots + } + // terminate discLoop + close(srv.peerDisconnect) +} + +func (srv *Server) discLoop() { + for peer := range srv.peerDisconnect { + // peer has just disconnected. free up its slot. + srvlog.Infof("%v is gone", peer) + srv.peerSlots <- peer.slot + srv.lock.Lock() + srv.peers[peer.slot] = nil + srv.lock.Unlock() + } +} + +// main loop for adding connections via listening +func (srv *Server) listenLoop() { + defer srv.wg.Done() + + srvlog.Infoln("Listening on", srv.listener.Addr()) + for { + select { + case slot := <-srv.peerSlots: + conn, err := srv.listener.Accept() + if err != nil { + srv.peerSlots <- slot + return + } + srvlog.Debugf("Accepted conn %v (slot %d)\n", conn.RemoteAddr(), slot) + srv.addPeer(conn, nil, slot) + case <-srv.quit: + return + } + } +} + +func (srv *Server) natLoop(port int) { + defer srv.wg.Done() + for { + srv.updatePortMapping(port) + select { + case <-time.After(portMappingUpdateInterval): + // one more round + case <-srv.quit: + srv.removePortMapping(port) + return + } + } +} + +func (srv *Server) updatePortMapping(port int) { + srvlog.Infoln("Attempting to map port", port, "with", srv.NAT) + err := srv.NAT.AddPortMapping("tcp", port, port, "ethereum p2p", portMappingTimeout) + if err != nil { + srvlog.Errorln("Port mapping error:", err) + return + } + extip, err := srv.NAT.GetExternalAddress() + if err != nil { + srvlog.Errorln("Error getting external IP:", err) + return + } + srv.lock.Lock() + extaddr := *(srv.listener.Addr().(*net.TCPAddr)) + extaddr.IP = extip + srvlog.Infoln("Mapped port, external addr is", &extaddr) + srv.laddr = &extaddr + srv.lock.Unlock() +} + +func (srv *Server) removePortMapping(port int) { + srvlog.Infoln("Removing port mapping for", port, "with", srv.NAT) + srv.NAT.DeletePortMapping("tcp", port, port) +} + +func (srv *Server) dialLoop() { + defer srv.wg.Done() + var ( + suggest chan *peerAddr + slot *int + slots = srv.peerSlots + ) + for { + select { + case i := <-slots: + // we need a peer in slot i, slot reserved + slot = &i + // now we can watch for candidate peers in the next loop + suggest = srv.peerConnect + // do not consume more until candidate peer is found + slots = nil + + case desc := <-suggest: + // candidate peer found, will dial out asyncronously + // if connection fails slot will be released + go srv.dialPeer(desc, *slot) + // we can watch if more peers needed in the next loop + slots = srv.peerSlots + // until then we dont care about candidate peers + suggest = nil + + case <-srv.quit: + // give back the currently reserved slot + if slot != nil { + srv.peerSlots <- *slot + } + return + } + } +} + +// connect to peer via dial out +func (srv *Server) dialPeer(desc *peerAddr, slot int) { + srvlog.Debugf("Dialing %v (slot %d)\n", desc, slot) + conn, err := srv.Dialer.Dial(desc.Network(), desc.String()) + if err != nil { + srvlog.Errorf("Dial error: %v", err) + srv.peerSlots <- slot + return + } + go srv.addPeer(conn, desc, slot) +} + +// creates the new peer object and inserts it into its slot +func (srv *Server) addPeer(conn net.Conn, desc *peerAddr, slot int) *Peer { + srv.lock.Lock() + defer srv.lock.Unlock() + if !srv.running { + conn.Close() + srv.peerSlots <- slot // release slot + return nil + } + peer := srv.newPeerFunc(srv, conn, desc) + peer.slot = slot + srv.peers[slot] = peer + srv.peerCount++ + go func() { peer.loop(); srv.peerDisconnect <- peer }() + return peer +} + +// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot +func (srv *Server) removePeer(peer *Peer) { + srv.lock.Lock() + defer srv.lock.Unlock() + srvlog.Debugf("Removing peer %v %v (slot %v)\n", peer, peer.slot) + if srv.peers[peer.slot] != peer { + srvlog.Warnln("Invalid peer to remove:", peer) + return + } + // remove from list and index + srv.peerCount-- + srv.peers[peer.slot] = nil + // release slot to signal need for a new peer, last! + srv.peerSlots <- peer.slot +} + +func (srv *Server) verifyPeer(addr *peerAddr) error { + if srv.Blacklist.Exists(addr.Pubkey) { + return errors.New("blacklisted") + } + if bytes.Equal(srv.Identity.Pubkey()[1:], addr.Pubkey) { + return newPeerError(errPubkeyForbidden, "not allowed to connect to srv") + } + srv.lock.RLock() + defer srv.lock.RUnlock() + for _, peer := range srv.peers { + if peer != nil { + id := peer.Identity() + if id != nil && bytes.Equal(id.Pubkey(), addr.Pubkey) { + return errors.New("already connected") + } + } + } + return nil +} + +type Blacklist interface { + Get([]byte) (bool, error) + Put([]byte) error + Delete([]byte) error + Exists(pubkey []byte) (ok bool) +} + +type BlacklistMap struct { + blacklist map[string]bool + lock sync.RWMutex +} + +func NewBlacklist() *BlacklistMap { + return &BlacklistMap{ + blacklist: make(map[string]bool), + } +} + +func (self *BlacklistMap) Get(pubkey []byte) (bool, error) { + self.lock.RLock() + defer self.lock.RUnlock() + v, ok := self.blacklist[string(pubkey)] + var err error + if !ok { + err = fmt.Errorf("not found") + } + return v, err +} + +func (self *BlacklistMap) Exists(pubkey []byte) (ok bool) { + self.lock.RLock() + defer self.lock.RUnlock() + _, ok = self.blacklist[string(pubkey)] + return +} + +func (self *BlacklistMap) Put(pubkey []byte) error { + self.lock.RLock() + defer self.lock.RUnlock() + self.blacklist[string(pubkey)] = true + return nil +} + +func (self *BlacklistMap) Delete(pubkey []byte) error { + self.lock.RLock() + defer self.lock.RUnlock() + delete(self.blacklist, string(pubkey)) + return nil +} diff --git a/p2p/server_test.go b/p2p/server_test.go new file mode 100644 index 000000000..5c0d08d39 --- /dev/null +++ b/p2p/server_test.go @@ -0,0 +1,161 @@ +package p2p + +import ( + "bytes" + "io" + "net" + "sync" + "testing" + "time" +) + +func startTestServer(t *testing.T, pf peerFunc) *Server { + server := &Server{ + Identity: NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey"), + MaxPeers: 10, + ListenAddr: "127.0.0.1:0", + newPeerFunc: pf, + } + if err := server.Start(); err != nil { + t.Fatalf("Could not start server: %v", err) + } + return server +} + +func TestServerListen(t *testing.T) { + defer testlog(t).detach() + + // start the test server + connected := make(chan *Peer) + srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer { + if conn == nil { + t.Error("peer func called with nil conn") + } + if dialAddr != nil { + t.Error("peer func called with non-nil dialAddr") + } + peer := newPeer(conn, nil, dialAddr) + connected <- peer + return peer + }) + defer close(connected) + defer srv.Stop() + + // dial the test server + conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second) + if err != nil { + t.Fatalf("could not dial: %v", err) + } + defer conn.Close() + + select { + case peer := <-connected: + if peer.conn.LocalAddr().String() != conn.RemoteAddr().String() { + t.Errorf("peer started with wrong conn: got %v, want %v", + peer.conn.LocalAddr(), conn.RemoteAddr()) + } + case <-time.After(1 * time.Second): + t.Error("server did not accept within one second") + } +} + +func TestServerDial(t *testing.T) { + defer testlog(t).detach() + + // run a fake TCP server to handle the connection. + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("could not setup listener: %v") + } + defer listener.Close() + accepted := make(chan net.Conn) + go func() { + conn, err := listener.Accept() + if err != nil { + t.Error("acccept error:", err) + } + conn.Close() + accepted <- conn + }() + + // start the test server + connected := make(chan *Peer) + srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer { + if conn == nil { + t.Error("peer func called with nil conn") + } + peer := newPeer(conn, nil, dialAddr) + connected <- peer + return peer + }) + defer close(connected) + defer srv.Stop() + + // tell the server to connect. + connAddr := newPeerAddr(listener.Addr(), nil) + srv.peerConnect <- connAddr + + select { + case conn := <-accepted: + select { + case peer := <-connected: + if peer.conn.RemoteAddr().String() != conn.LocalAddr().String() { + t.Errorf("peer started with wrong conn: got %v, want %v", + peer.conn.RemoteAddr(), conn.LocalAddr()) + } + if peer.dialAddr != connAddr { + t.Errorf("peer started with wrong dialAddr: got %v, want %v", + peer.dialAddr, connAddr) + } + case <-time.After(1 * time.Second): + t.Error("server did not launch peer within one second") + } + + case <-time.After(1 * time.Second): + t.Error("server did not connect within one second") + } +} + +func TestServerBroadcast(t *testing.T) { + defer testlog(t).detach() + var connected sync.WaitGroup + srv := startTestServer(t, func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer { + peer := newPeer(c, []Protocol{discard}, dialAddr) + peer.startSubprotocols([]Cap{discard.cap()}) + connected.Done() + return peer + }) + defer srv.Stop() + + // dial a bunch of conns + var conns = make([]net.Conn, 8) + connected.Add(len(conns)) + deadline := time.Now().Add(3 * time.Second) + dialer := &net.Dialer{Deadline: deadline} + for i := range conns { + conn, err := dialer.Dial("tcp", srv.ListenAddr) + if err != nil { + t.Fatalf("conn %d: dial error: %v", i, err) + } + defer conn.Close() + conn.SetDeadline(deadline) + conns[i] = conn + } + connected.Wait() + + // broadcast one message + srv.Broadcast("discard", 0, "foo") + goldbuf := new(bytes.Buffer) + writeMsg(goldbuf, NewMsg(16, "foo")) + golden := goldbuf.Bytes() + + // check that the message has been written everywhere + for i, conn := range conns { + buf := make([]byte, len(golden)) + if _, err := io.ReadFull(conn, buf); err != nil { + t.Errorf("conn %d: read error: %v", i, err) + } else if !bytes.Equal(buf, golden) { + t.Errorf("conn %d: msg mismatch\ngot: %x\nwant: %x", i, buf, golden) + } + } +} diff --git a/p2p/testlog_test.go b/p2p/testlog_test.go new file mode 100644 index 000000000..951d43243 --- /dev/null +++ b/p2p/testlog_test.go @@ -0,0 +1,28 @@ +package p2p + +import ( + "testing" + + "github.com/ethereum/go-ethereum/logger" +) + +type testLogger struct{ t *testing.T } + +func testlog(t *testing.T) testLogger { + logger.Reset() + l := testLogger{t} + logger.AddLogSystem(l) + return l +} + +func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel } +func (testLogger) SetLogLevel(logger.LogLevel) {} + +func (l testLogger) LogPrint(level logger.LogLevel, msg string) { + l.t.Logf("%s", msg) +} + +func (testLogger) detach() { + logger.Flush() + logger.Reset() +} diff --git a/p2p/testpoc7.go b/p2p/testpoc7.go new file mode 100644 index 000000000..c0cc5c544 --- /dev/null +++ b/p2p/testpoc7.go @@ -0,0 +1,40 @@ +// +build none + +package main + +import ( + "fmt" + "log" + "net" + "os" + + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/p2p" + "github.com/obscuren/secp256k1-go" +) + +func main() { + logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel)) + + pub, _ := secp256k1.GenerateKeyPair() + srv := p2p.Server{ + MaxPeers: 10, + Identity: p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)), + ListenAddr: ":30303", + NAT: p2p.PMP(net.ParseIP("10.0.0.1")), + } + if err := srv.Start(); err != nil { + fmt.Println("could not start server:", err) + os.Exit(1) + } + + // add seed peers + seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303") + if err != nil { + fmt.Println("couldn't resolve:", err) + os.Exit(1) + } + srv.SuggestPeer(seed.IP, seed.Port, nil) + + select {} +} |