diff options
-rw-r--r-- | p2p/client_identity.go | 63 | ||||
-rw-r--r-- | p2p/client_identity_test.go | 30 | ||||
-rw-r--r-- | p2p/connection.go | 275 | ||||
-rw-r--r-- | p2p/connection_test.go | 222 | ||||
-rw-r--r-- | p2p/message.go | 75 | ||||
-rw-r--r-- | p2p/message_test.go | 38 | ||||
-rw-r--r-- | p2p/messenger.go | 220 | ||||
-rw-r--r-- | p2p/messenger_test.go | 146 | ||||
-rw-r--r-- | p2p/natpmp.go | 55 | ||||
-rw-r--r-- | p2p/natupnp.go | 335 | ||||
-rw-r--r-- | p2p/network.go | 196 | ||||
-rw-r--r-- | p2p/peer.go | 83 | ||||
-rw-r--r-- | p2p/peer_error.go | 76 | ||||
-rw-r--r-- | p2p/peer_error_handler.go | 101 | ||||
-rw-r--r-- | p2p/peer_error_handler_test.go | 34 | ||||
-rw-r--r-- | p2p/peer_test.go | 96 | ||||
-rw-r--r-- | p2p/protocol.go | 278 | ||||
-rw-r--r-- | p2p/server.go | 484 | ||||
-rw-r--r-- | p2p/server_test.go | 208 |
19 files changed, 3015 insertions, 0 deletions
diff --git a/p2p/client_identity.go b/p2p/client_identity.go new file mode 100644 index 000000000..236b23106 --- /dev/null +++ b/p2p/client_identity.go @@ -0,0 +1,63 @@ +package p2p + +import ( + "fmt" + "runtime" +) + +// should be used in Peer handleHandshake, incorporate Caps, ProtocolVersion, Pubkey etc. +type ClientIdentity interface { + String() string + Pubkey() []byte +} + +type SimpleClientIdentity struct { + clientIdentifier string + version string + customIdentifier string + os string + implementation string + pubkey string +} + +func NewSimpleClientIdentity(clientIdentifier string, version string, customIdentifier string, pubkey string) *SimpleClientIdentity { + clientIdentity := &SimpleClientIdentity{ + clientIdentifier: clientIdentifier, + version: version, + customIdentifier: customIdentifier, + os: runtime.GOOS, + implementation: runtime.Version(), + pubkey: pubkey, + } + + return clientIdentity +} + +func (c *SimpleClientIdentity) init() { +} + +func (c *SimpleClientIdentity) String() string { + var id string + if len(c.customIdentifier) > 0 { + id = "/" + c.customIdentifier + } + + return fmt.Sprintf("%s/v%s%s/%s/%s", + c.clientIdentifier, + c.version, + id, + c.os, + c.implementation) +} + +func (c *SimpleClientIdentity) Pubkey() []byte { + return []byte(c.pubkey) +} + +func (c *SimpleClientIdentity) SetCustomIdentifier(customIdentifier string) { + c.customIdentifier = customIdentifier +} + +func (c *SimpleClientIdentity) GetCustomIdentifier() string { + return c.customIdentifier +} diff --git a/p2p/client_identity_test.go b/p2p/client_identity_test.go new file mode 100644 index 000000000..40b0e6f5e --- /dev/null +++ b/p2p/client_identity_test.go @@ -0,0 +1,30 @@ +package p2p + +import ( + "fmt" + "runtime" + "testing" +) + +func TestClientIdentity(t *testing.T) { + clientIdentity := NewSimpleClientIdentity("Ethereum(G)", "0.5.16", "test", "pubkey") + clientString := clientIdentity.String() + expected := fmt.Sprintf("Ethereum(G)/v0.5.16/test/%s/%s", runtime.GOOS, runtime.Version()) + if clientString != expected { + t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString) + } + customIdentifier := clientIdentity.GetCustomIdentifier() + if customIdentifier != "test" { + t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test', got %v", customIdentifier) + } + clientIdentity.SetCustomIdentifier("test2") + customIdentifier = clientIdentity.GetCustomIdentifier() + if customIdentifier != "test2" { + t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test2', got %v", customIdentifier) + } + clientString = clientIdentity.String() + expected = fmt.Sprintf("Ethereum(G)/v0.5.16/test2/%s/%s", runtime.GOOS, runtime.Version()) + if clientString != expected { + t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString) + } +} diff --git a/p2p/connection.go b/p2p/connection.go new file mode 100644 index 000000000..e999cbe55 --- /dev/null +++ b/p2p/connection.go @@ -0,0 +1,275 @@ +package p2p + +import ( + "bytes" + // "fmt" + "net" + "time" + + "github.com/ethereum/eth-go/ethutil" +) + +type Connection struct { + conn net.Conn + // conn NetworkConnection + timeout time.Duration + in chan []byte + out chan []byte + err chan *PeerError + closingIn chan chan bool + closingOut chan chan bool +} + +// const readBufferLength = 2 //for testing + +const readBufferLength = 1440 +const partialsQueueSize = 10 +const maxPendingQueueSize = 1 +const defaultTimeout = 500 + +var magicToken = []byte{34, 64, 8, 145} + +func (self *Connection) Open() { + go self.startRead() + go self.startWrite() +} + +func (self *Connection) Close() { + self.closeIn() + self.closeOut() +} + +func (self *Connection) closeIn() { + errc := make(chan bool) + self.closingIn <- errc + <-errc +} + +func (self *Connection) closeOut() { + errc := make(chan bool) + self.closingOut <- errc + <-errc +} + +func NewConnection(conn net.Conn, errchan chan *PeerError) *Connection { + return &Connection{ + conn: conn, + timeout: defaultTimeout, + in: make(chan []byte), + out: make(chan []byte), + err: errchan, + closingIn: make(chan chan bool, 1), + closingOut: make(chan chan bool, 1), + } +} + +func (self *Connection) Read() <-chan []byte { + return self.in +} + +func (self *Connection) Write() chan<- []byte { + return self.out +} + +func (self *Connection) Error() <-chan *PeerError { + return self.err +} + +func (self *Connection) startRead() { + payloads := make(chan []byte) + done := make(chan *PeerError) + pending := [][]byte{} + var head []byte + var wait time.Duration // initally 0 (no delay) + read := time.After(wait * time.Millisecond) + + for { + // if pending empty, nil channel blocks + var in chan []byte + if len(pending) > 0 { + in = self.in // enable send case + head = pending[0] + } else { + in = nil + } + + select { + case <-read: + go self.read(payloads, done) + case err := <-done: + if err == nil { // no error but nothing to read + if len(pending) < maxPendingQueueSize { + wait = 100 + } else if wait == 0 { + wait = 100 + } else { + wait = 2 * wait + } + } else { + self.err <- err // report error + wait = 100 + } + read = time.After(wait * time.Millisecond) + case payload := <-payloads: + pending = append(pending, payload) + if len(pending) < maxPendingQueueSize { + wait = 0 + } else { + wait = 100 + } + read = time.After(wait * time.Millisecond) + case in <- head: + pending = pending[1:] + case errc := <-self.closingIn: + errc <- true + close(self.in) + return + } + + } +} + +func (self *Connection) startWrite() { + pending := [][]byte{} + done := make(chan *PeerError) + writing := false + for { + if len(pending) > 0 && !writing { + writing = true + go self.write(pending[0], done) + } + select { + case payload := <-self.out: + pending = append(pending, payload) + case err := <-done: + if err == nil { + pending = pending[1:] + writing = false + } else { + self.err <- err // report error + } + case errc := <-self.closingOut: + errc <- true + close(self.out) + return + } + } +} + +func pack(payload []byte) (packet []byte) { + length := ethutil.NumberToBytes(uint32(len(payload)), 32) + // return error if too long? + // Write magic token and payload length (first 8 bytes) + packet = append(magicToken, length...) + packet = append(packet, payload...) + return +} + +func avoidPanic(done chan *PeerError) { + if rec := recover(); rec != nil { + err := NewPeerError(MiscError, " %v", rec) + logger.Debugln(err) + done <- err + } +} + +func (self *Connection) write(payload []byte, done chan *PeerError) { + defer avoidPanic(done) + var err *PeerError + _, ok := self.conn.Write(pack(payload)) + if ok != nil { + err = NewPeerError(WriteError, " %v", ok) + logger.Debugln(err) + } + done <- err +} + +func (self *Connection) read(payloads chan []byte, done chan *PeerError) { + //defer avoidPanic(done) + + partials := make(chan []byte, partialsQueueSize) + errc := make(chan *PeerError) + go self.readPartials(partials, errc) + + packet := []byte{} + length := 8 + start := true + var err *PeerError +out: + for { + // appends partials read via connection until packet is + // - either parseable (>=8bytes) + // - or complete (payload fully consumed) + for len(packet) < length { + partial, ok := <-partials + if !ok { // partials channel is closed + err = <-errc + if err == nil && len(packet) > 0 { + if start { + err = NewPeerError(PacketTooShort, "%v", packet) + } else { + err = NewPeerError(PayloadTooShort, "%d < %d", len(packet), length) + } + } + break out + } + packet = append(packet, partial...) + } + if start { + // at least 8 bytes read, can validate packet + if bytes.Compare(magicToken, packet[:4]) != 0 { + err = NewPeerError(MagicTokenMismatch, " received %v", packet[:4]) + break + } + length = int(ethutil.BytesToNumber(packet[4:8])) + packet = packet[8:] + + if length > 0 { + start = false // now consuming payload + } else { //penalize peer but read on + self.err <- NewPeerError(EmptyPayload, "") + length = 8 + } + } else { + // packet complete (payload fully consumed) + payloads <- packet[:length] + packet = packet[length:] // resclice packet + start = true + length = 8 + } + } + + // this stops partials read via the connection, should we? + //if err != nil { + // select { + // case errc <- err + // default: + //} + done <- err +} + +func (self *Connection) readPartials(partials chan []byte, errc chan *PeerError) { + defer close(partials) + for { + // Give buffering some time + self.conn.SetReadDeadline(time.Now().Add(self.timeout * time.Millisecond)) + buffer := make([]byte, readBufferLength) + // read partial from connection + bytesRead, err := self.conn.Read(buffer) + if err == nil || err.Error() == "EOF" { + if bytesRead > 0 { + partials <- buffer[:bytesRead] + } + if err != nil && err.Error() == "EOF" { + break + } + } else { + // unexpected error, report to errc + err := NewPeerError(ReadError, " %v", err) + logger.Debugln(err) + errc <- err + return // will close partials channel + } + } + close(errc) +} diff --git a/p2p/connection_test.go b/p2p/connection_test.go new file mode 100644 index 000000000..76ee8021c --- /dev/null +++ b/p2p/connection_test.go @@ -0,0 +1,222 @@ +package p2p + +import ( + "bytes" + "fmt" + "io" + "net" + "testing" + "time" +) + +type TestNetworkConnection struct { + in chan []byte + current []byte + Out [][]byte + addr net.Addr +} + +func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection { + return &TestNetworkConnection{ + in: make(chan []byte), + current: []byte{}, + Out: [][]byte{}, + addr: addr, + } +} + +func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) { + time.Sleep(latency) + for _, s := range packets { + self.in <- s + } +} + +func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) { + if len(self.current) == 0 { + select { + case self.current = <-self.in: + default: + return 0, io.EOF + } + } + length := len(self.current) + if length > len(buff) { + copy(buff[:], self.current[:len(buff)]) + self.current = self.current[len(buff):] + return len(buff), nil + } else { + copy(buff[:length], self.current[:]) + self.current = []byte{} + return length, io.EOF + } +} + +func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) { + self.Out = append(self.Out, buff) + fmt.Printf("net write %v\n%v\n", len(self.Out), buff) + return len(buff), nil +} + +func (self *TestNetworkConnection) Close() (err error) { + return +} + +func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) { + return +} + +func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) { + return self.addr +} + +func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) { + return +} + +func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) { + return +} + +func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) { + return +} + +func setupConnection() (*Connection, *TestNetworkConnection) { + addr := &TestAddr{"test:30303"} + net := NewTestNetworkConnection(addr) + conn := NewConnection(net, NewPeerErrorChannel()) + conn.Open() + return conn, net +} + +func TestReadingNilPacket(t *testing.T) { + conn, net := setupConnection() + go net.In(0, []byte{}) + // time.Sleep(10 * time.Millisecond) + select { + case packet := <-conn.Read(): + t.Errorf("read %v", packet) + case err := <-conn.Error(): + t.Errorf("incorrect error %v", err) + default: + } + conn.Close() +} + +func TestReadingShortPacket(t *testing.T) { + conn, net := setupConnection() + go net.In(0, []byte{0}) + select { + case packet := <-conn.Read(): + t.Errorf("read %v", packet) + case err := <-conn.Error(): + if err.Code != PacketTooShort { + t.Errorf("incorrect error %v, expected %v", err.Code, PacketTooShort) + } + } + conn.Close() +} + +func TestReadingInvalidPacket(t *testing.T) { + conn, net := setupConnection() + go net.In(0, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + select { + case packet := <-conn.Read(): + t.Errorf("read %v", packet) + case err := <-conn.Error(): + if err.Code != MagicTokenMismatch { + t.Errorf("incorrect error %v, expected %v", err.Code, MagicTokenMismatch) + } + } + conn.Close() +} + +func TestReadingInvalidPayload(t *testing.T) { + conn, net := setupConnection() + go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 2, 0}) + select { + case packet := <-conn.Read(): + t.Errorf("read %v", packet) + case err := <-conn.Error(): + if err.Code != PayloadTooShort { + t.Errorf("incorrect error %v, expected %v", err.Code, PayloadTooShort) + } + } + conn.Close() +} + +func TestReadingEmptyPayload(t *testing.T) { + conn, net := setupConnection() + go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 0}) + time.Sleep(10 * time.Millisecond) + select { + case packet := <-conn.Read(): + t.Errorf("read %v", packet) + default: + } + select { + case err := <-conn.Error(): + code := err.Code + if code != EmptyPayload { + t.Errorf("incorrect error, expected EmptyPayload, got %v", code) + } + default: + t.Errorf("no error, expected EmptyPayload") + } + conn.Close() +} + +func TestReadingCompletePacket(t *testing.T) { + conn, net := setupConnection() + go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 1}) + time.Sleep(10 * time.Millisecond) + select { + case packet := <-conn.Read(): + if bytes.Compare(packet, []byte{1}) != 0 { + t.Errorf("incorrect payload read") + } + case err := <-conn.Error(): + t.Errorf("incorrect error %v", err) + default: + t.Errorf("nothing read") + } + conn.Close() +} + +func TestReadingTwoCompletePackets(t *testing.T) { + conn, net := setupConnection() + go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0, 34, 64, 8, 145, 0, 0, 0, 1, 1}) + + for i := 0; i < 2; i++ { + time.Sleep(10 * time.Millisecond) + select { + case packet := <-conn.Read(): + if bytes.Compare(packet, []byte{byte(i)}) != 0 { + t.Errorf("incorrect payload read") + } + case err := <-conn.Error(): + t.Errorf("incorrect error %v", err) + default: + t.Errorf("nothing read") + } + } + conn.Close() +} + +func TestWriting(t *testing.T) { + conn, net := setupConnection() + conn.Write() <- []byte{0} + time.Sleep(10 * time.Millisecond) + if len(net.Out) == 0 { + t.Errorf("no output") + } else { + out := net.Out[0] + if bytes.Compare(out, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0}) != 0 { + t.Errorf("incorrect packet %v", out) + } + } + conn.Close() +} + +// hello packet with client id ABC: 0x22 40 08 91 00 00 00 08 84 00 00 00 43414243 diff --git a/p2p/message.go b/p2p/message.go new file mode 100644 index 000000000..4886eaa1f --- /dev/null +++ b/p2p/message.go @@ -0,0 +1,75 @@ +package p2p + +import ( + // "fmt" + "github.com/ethereum/eth-go/ethutil" +) + +type MsgCode uint8 + +type Msg struct { + code MsgCode // this is the raw code as per adaptive msg code scheme + data *ethutil.Value + encoded []byte +} + +func (self *Msg) Code() MsgCode { + return self.code +} + +func (self *Msg) Data() *ethutil.Value { + return self.data +} + +func NewMsg(code MsgCode, params ...interface{}) (msg *Msg, err error) { + + // // data := [][]interface{}{} + // data := []interface{}{} + // for _, value := range params { + // if encodable, ok := value.(ethutil.RlpEncodeDecode); ok { + // data = append(data, encodable.RlpValue()) + // } else if raw, ok := value.([]interface{}); ok { + // data = append(data, raw) + // } else { + // // data = append(data, interface{}(raw)) + // err = fmt.Errorf("Unable to encode object of type %T", value) + // return + // } + // } + return &Msg{ + code: code, + data: ethutil.NewValue(interface{}(params)), + }, nil +} + +func NewMsgFromBytes(encoded []byte) (msg *Msg, err error) { + value := ethutil.NewValueFromBytes(encoded) + // Type of message + code := value.Get(0).Uint() + // Actual data + data := value.SliceFrom(1) + + msg = &Msg{ + code: MsgCode(code), + data: data, + // data: ethutil.NewValue(data), + encoded: encoded, + } + return +} + +func (self *Msg) Decode(offset MsgCode) { + self.code = self.code - offset +} + +// encode takes an offset argument to implement adaptive message coding +// the encoded message is memoized to make msgs relayed to several peers more efficient +func (self *Msg) Encode(offset MsgCode) (res []byte) { + if len(self.encoded) == 0 { + res = ethutil.NewValue(append([]interface{}{byte(self.code + offset)}, self.data.Slice()...)).Encode() + self.encoded = res + } else { + res = self.encoded + } + return +} diff --git a/p2p/message_test.go b/p2p/message_test.go new file mode 100644 index 000000000..e9d46f2c3 --- /dev/null +++ b/p2p/message_test.go @@ -0,0 +1,38 @@ +package p2p + +import ( + "testing" +) + +func TestNewMsg(t *testing.T) { + msg, _ := NewMsg(3, 1, "000") + if msg.Code() != 3 { + t.Errorf("incorrect code %v", msg.Code()) + } + data0 := msg.Data().Get(0).Uint() + data1 := string(msg.Data().Get(1).Bytes()) + if data0 != 1 { + t.Errorf("incorrect data %v", data0) + } + if data1 != "000" { + t.Errorf("incorrect data %v", data1) + } +} + +func TestEncodeDecodeMsg(t *testing.T) { + msg, _ := NewMsg(3, 1, "000") + encoded := msg.Encode(3) + msg, _ = NewMsgFromBytes(encoded) + msg.Decode(3) + if msg.Code() != 3 { + t.Errorf("incorrect code %v", msg.Code()) + } + data0 := msg.Data().Get(0).Uint() + data1 := msg.Data().Get(1).Str() + if data0 != 1 { + t.Errorf("incorrect data %v", data0) + } + if data1 != "000" { + t.Errorf("incorrect data %v", data1) + } +} diff --git a/p2p/messenger.go b/p2p/messenger.go new file mode 100644 index 000000000..d42ba1720 --- /dev/null +++ b/p2p/messenger.go @@ -0,0 +1,220 @@ +package p2p + +import ( + "fmt" + "sync" + "time" +) + +const ( + handlerTimeout = 1000 +) + +type Handlers map[string](func(p *Peer) Protocol) + +type Messenger struct { + conn *Connection + peer *Peer + handlers Handlers + protocolLock sync.RWMutex + protocols []Protocol + offsets []MsgCode // offsets for adaptive message idss + protocolTable map[string]int + quit chan chan bool + err chan *PeerError + pulse chan bool +} + +func NewMessenger(peer *Peer, conn *Connection, errchan chan *PeerError, handlers Handlers) *Messenger { + baseProtocol := NewBaseProtocol(peer) + return &Messenger{ + conn: conn, + peer: peer, + offsets: []MsgCode{baseProtocol.Offset()}, + handlers: handlers, + protocols: []Protocol{baseProtocol}, + protocolTable: make(map[string]int), + err: errchan, + pulse: make(chan bool, 1), + quit: make(chan chan bool, 1), + } +} + +func (self *Messenger) Start() { + self.conn.Open() + go self.messenger() + self.protocolLock.RLock() + defer self.protocolLock.RUnlock() + self.protocols[0].Start() +} + +func (self *Messenger) Stop() { + // close pulse to stop ping pong monitoring + close(self.pulse) + self.protocolLock.RLock() + defer self.protocolLock.RUnlock() + for _, protocol := range self.protocols { + protocol.Stop() // could be parallel + } + q := make(chan bool) + self.quit <- q + <-q + self.conn.Close() +} + +func (self *Messenger) messenger() { + in := self.conn.Read() + for { + select { + case payload, ok := <-in: + //dispatches message to the protocol asynchronously + if ok { + go self.handle(payload) + } else { + return + } + case q := <-self.quit: + q <- true + return + } + } +} + +// handles each message by dispatching to the appropriate protocol +// using adaptive message codes +// this function is started as a separate go routine for each message +// it waits for the protocol response +// then encodes and sends outgoing messages to the connection's write channel +func (self *Messenger) handle(payload []byte) { + // send ping to heartbeat channel signalling time of last message + // select { + // case self.pulse <- true: + // default: + // } + self.pulse <- true + // initialise message from payload + msg, err := NewMsgFromBytes(payload) + if err != nil { + self.err <- NewPeerError(MiscError, " %v", err) + return + } + // retrieves protocol based on message Code + protocol, offset, peerErr := self.getProtocol(msg.Code()) + if err != nil { + self.err <- peerErr + return + } + // reset message code based on adaptive offset + msg.Decode(offset) + // dispatches + response := make(chan *Msg) + go protocol.HandleIn(msg, response) + // protocol reponse timeout to prevent leaks + timer := time.After(handlerTimeout * time.Millisecond) + for { + select { + case outgoing, ok := <-response: + // we check if response channel is not closed + if ok { + self.conn.Write() <- outgoing.Encode(offset) + } else { + return + } + case <-timer: + return + } + } +} + +// negotiated protocols +// stores offsets needed for adaptive message id scheme + +// based on offsets set at handshake +// get the right protocol to handle the message +func (self *Messenger) getProtocol(code MsgCode) (Protocol, MsgCode, *PeerError) { + self.protocolLock.RLock() + defer self.protocolLock.RUnlock() + base := MsgCode(0) + for index, offset := range self.offsets { + if code < offset { + return self.protocols[index], base, nil + } + base = offset + } + return nil, MsgCode(0), NewPeerError(InvalidMsgCode, " %v", code) +} + +func (self *Messenger) PingPong(timeout time.Duration, gracePeriod time.Duration, pingCallback func(), timeoutCallback func()) { + fmt.Printf("pingpong keepalive started at %v", time.Now()) + + timer := time.After(timeout) + pinged := false + for { + select { + case _, ok := <-self.pulse: + if ok { + pinged = false + timer = time.After(timeout) + } else { + // pulse is closed, stop monitoring + return + } + case <-timer: + if pinged { + fmt.Printf("timeout at %v", time.Now()) + timeoutCallback() + return + } else { + fmt.Printf("pinged at %v", time.Now()) + pingCallback() + timer = time.After(gracePeriod) + pinged = true + } + } + } +} + +func (self *Messenger) AddProtocols(protocols []string) { + self.protocolLock.Lock() + defer self.protocolLock.Unlock() + i := len(self.offsets) + offset := self.offsets[i-1] + for _, name := range protocols { + protocolFunc, ok := self.handlers[name] + if ok { + protocol := protocolFunc(self.peer) + self.protocolTable[name] = i + i++ + offset += protocol.Offset() + fmt.Println("offset ", name, offset) + + self.offsets = append(self.offsets, offset) + self.protocols = append(self.protocols, protocol) + protocol.Start() + } else { + fmt.Println("no ", name) + // protocol not handled + } + } +} + +func (self *Messenger) Write(protocol string, msg *Msg) error { + self.protocolLock.RLock() + defer self.protocolLock.RUnlock() + i := 0 + offset := MsgCode(0) + if len(protocol) > 0 { + var ok bool + i, ok = self.protocolTable[protocol] + if !ok { + return fmt.Errorf("protocol %v not handled by peer", protocol) + } + offset = self.offsets[i-1] + } + handler := self.protocols[i] + // checking if protocol status/caps allows the message to be sent out + if handler.HandleOut(msg) { + self.conn.Write() <- msg.Encode(offset) + } + return nil +} diff --git a/p2p/messenger_test.go b/p2p/messenger_test.go new file mode 100644 index 000000000..bc21d34ba --- /dev/null +++ b/p2p/messenger_test.go @@ -0,0 +1,146 @@ +package p2p + +import ( + // "fmt" + "bytes" + "github.com/ethereum/eth-go/ethutil" + "testing" + "time" +) + +func setupMessenger(handlers Handlers) (*TestNetworkConnection, chan *PeerError, *Messenger) { + errchan := NewPeerErrorChannel() + addr := &TestAddr{"test:30303"} + net := NewTestNetworkConnection(addr) + conn := NewConnection(net, errchan) + mess := NewMessenger(nil, conn, errchan, handlers) + mess.Start() + return net, errchan, mess +} + +type TestProtocol struct { + Msgs []*Msg +} + +func (self *TestProtocol) Start() { +} + +func (self *TestProtocol) Stop() { +} + +func (self *TestProtocol) Offset() MsgCode { + return MsgCode(5) +} + +func (self *TestProtocol) HandleIn(msg *Msg, response chan *Msg) { + self.Msgs = append(self.Msgs, msg) + close(response) +} + +func (self *TestProtocol) HandleOut(msg *Msg) bool { + if msg.Code() > 3 { + return false + } else { + return true + } +} + +func (self *TestProtocol) Name() string { + return "a" +} + +func Packet(offset MsgCode, code MsgCode, params ...interface{}) []byte { + msg, _ := NewMsg(code, params...) + encoded := msg.Encode(offset) + packet := []byte{34, 64, 8, 145} + packet = append(packet, ethutil.NumberToBytes(uint32(len(encoded)), 32)...) + return append(packet, encoded...) +} + +func TestRead(t *testing.T) { + handlers := make(Handlers) + testProtocol := &TestProtocol{Msgs: []*Msg{}} + handlers["a"] = func(p *Peer) Protocol { return testProtocol } + net, _, mess := setupMessenger(handlers) + mess.AddProtocols([]string{"a"}) + defer mess.Stop() + wait := 1 * time.Millisecond + packet := Packet(16, 1, uint32(1), "000") + go net.In(0, packet) + time.Sleep(wait) + if len(testProtocol.Msgs) != 1 { + t.Errorf("msg not relayed to correct protocol") + } else { + if testProtocol.Msgs[0].Code() != 1 { + t.Errorf("incorrect msg code relayed to protocol") + } + } +} + +func TestWrite(t *testing.T) { + handlers := make(Handlers) + testProtocol := &TestProtocol{Msgs: []*Msg{}} + handlers["a"] = func(p *Peer) Protocol { return testProtocol } + net, _, mess := setupMessenger(handlers) + mess.AddProtocols([]string{"a"}) + defer mess.Stop() + wait := 1 * time.Millisecond + msg, _ := NewMsg(3, uint32(1), "000") + err := mess.Write("b", msg) + if err == nil { + t.Errorf("expect error for unknown protocol") + } + err = mess.Write("a", msg) + if err != nil { + t.Errorf("expect no error for known protocol: %v", err) + } else { + time.Sleep(wait) + if len(net.Out) != 1 { + t.Errorf("msg not written") + } else { + out := net.Out[0] + packet := Packet(16, 3, uint32(1), "000") + if bytes.Compare(out, packet) != 0 { + t.Errorf("incorrect packet %v", out) + } + } + } +} + +func TestPulse(t *testing.T) { + net, _, mess := setupMessenger(make(Handlers)) + defer mess.Stop() + ping := false + timeout := false + pingTimeout := 10 * time.Millisecond + gracePeriod := 200 * time.Millisecond + go mess.PingPong(pingTimeout, gracePeriod, func() { ping = true }, func() { timeout = true }) + net.In(0, Packet(0, 1)) + if ping { + t.Errorf("ping sent too early") + } + time.Sleep(pingTimeout + 100*time.Millisecond) + if !ping { + t.Errorf("no ping sent after timeout") + } + if timeout { + t.Errorf("timeout too early") + } + ping = false + net.In(0, Packet(0, 1)) + time.Sleep(pingTimeout + 100*time.Millisecond) + if !ping { + t.Errorf("no ping sent after timeout") + } + if timeout { + t.Errorf("timeout too early") + } + ping = false + time.Sleep(gracePeriod) + if ping { + t.Errorf("ping called twice") + } + if !timeout { + t.Errorf("no timeout after grace period") + } +} diff --git a/p2p/natpmp.go b/p2p/natpmp.go new file mode 100644 index 000000000..ff966d070 --- /dev/null +++ b/p2p/natpmp.go @@ -0,0 +1,55 @@ +package p2p + +import ( + "fmt" + "net" + + natpmp "github.com/jackpal/go-nat-pmp" +) + +// Adapt the NAT-PMP protocol to the NAT interface + +// TODO: +// + Register for changes to the external address. +// + Re-register port mapping when router reboots. +// + A mechanism for keeping a port mapping registered. + +type natPMPClient struct { + client *natpmp.Client +} + +func NewNatPMP(gateway net.IP) (nat NAT) { + return &natPMPClient{natpmp.NewClient(gateway)} +} + +func (n *natPMPClient) GetExternalAddress() (addr net.IP, err error) { + response, err := n.client.GetExternalAddress() + if err != nil { + return + } + ip := response.ExternalIPAddress + addr = net.IPv4(ip[0], ip[1], ip[2], ip[3]) + return +} + +func (n *natPMPClient) AddPortMapping(protocol string, externalPort, internalPort int, + description string, timeout int) (mappedExternalPort int, err error) { + if timeout <= 0 { + err = fmt.Errorf("timeout must not be <= 0") + return + } + // Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping. + response, err := n.client.AddPortMapping(protocol, internalPort, externalPort, timeout) + if err != nil { + return + } + mappedExternalPort = int(response.MappedExternalPort) + return +} + +func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { + // To destroy a mapping, send an add-port with + // an internalPort of the internal port to destroy, an external port of zero and a time of zero. + _, err = n.client.AddPortMapping(protocol, internalPort, 0, 0) + return +} diff --git a/p2p/natupnp.go b/p2p/natupnp.go new file mode 100644 index 000000000..fa9798d4d --- /dev/null +++ b/p2p/natupnp.go @@ -0,0 +1,335 @@ +package p2p + +// Just enough UPnP to be able to forward ports +// + +import ( + "bytes" + "encoding/xml" + "errors" + "net" + "net/http" + "os" + "strconv" + "strings" + "time" +) + +type upnpNAT struct { + serviceURL string + ourIP string +} + +func upnpDiscover(attempts int) (nat NAT, err error) { + ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") + if err != nil { + return + } + conn, err := net.ListenPacket("udp4", ":0") + if err != nil { + return + } + socket := conn.(*net.UDPConn) + defer socket.Close() + + err = socket.SetDeadline(time.Now().Add(10 * time.Second)) + if err != nil { + return + } + + st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n" + buf := bytes.NewBufferString( + "M-SEARCH * HTTP/1.1\r\n" + + "HOST: 239.255.255.250:1900\r\n" + + st + + "MAN: \"ssdp:discover\"\r\n" + + "MX: 2\r\n\r\n") + message := buf.Bytes() + answerBytes := make([]byte, 1024) + for i := 0; i < attempts; i++ { + _, err = socket.WriteToUDP(message, ssdp) + if err != nil { + return + } + var n int + n, _, err = socket.ReadFromUDP(answerBytes) + if err != nil { + continue + // socket.Close() + // return + } + answer := string(answerBytes[0:n]) + if strings.Index(answer, "\r\n"+st) < 0 { + continue + } + // HTTP header field names are case-insensitive. + // http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 + locString := "\r\nlocation: " + answer = strings.ToLower(answer) + locIndex := strings.Index(answer, locString) + if locIndex < 0 { + continue + } + loc := answer[locIndex+len(locString):] + endIndex := strings.Index(loc, "\r\n") + if endIndex < 0 { + continue + } + locURL := loc[0:endIndex] + var serviceURL string + serviceURL, err = getServiceURL(locURL) + if err != nil { + return + } + var ourIP string + ourIP, err = getOurIP() + if err != nil { + return + } + nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP} + return + } + err = errors.New("UPnP port discovery failed.") + return +} + +// service represents the Service type in an UPnP xml description. +// Only the parts we care about are present and thus the xml may have more +// fields than present in the structure. +type service struct { + ServiceType string `xml:"serviceType"` + ControlURL string `xml:"controlURL"` +} + +// deviceList represents the deviceList type in an UPnP xml description. +// Only the parts we care about are present and thus the xml may have more +// fields than present in the structure. +type deviceList struct { + XMLName xml.Name `xml:"deviceList"` + Device []device `xml:"device"` +} + +// serviceList represents the serviceList type in an UPnP xml description. +// Only the parts we care about are present and thus the xml may have more +// fields than present in the structure. +type serviceList struct { + XMLName xml.Name `xml:"serviceList"` + Service []service `xml:"service"` +} + +// device represents the device type in an UPnP xml description. +// Only the parts we care about are present and thus the xml may have more +// fields than present in the structure. +type device struct { + XMLName xml.Name `xml:"device"` + DeviceType string `xml:"deviceType"` + DeviceList deviceList `xml:"deviceList"` + ServiceList serviceList `xml:"serviceList"` +} + +// specVersion represents the specVersion in a UPnP xml description. +// Only the parts we care about are present and thus the xml may have more +// fields than present in the structure. +type specVersion struct { + XMLName xml.Name `xml:"specVersion"` + Major int `xml:"major"` + Minor int `xml:"minor"` +} + +// root represents the Root document for a UPnP xml description. +// Only the parts we care about are present and thus the xml may have more +// fields than present in the structure. +type root struct { + XMLName xml.Name `xml:"root"` + SpecVersion specVersion + Device device +} + +func getChildDevice(d *device, deviceType string) *device { + dl := d.DeviceList.Device + for i := 0; i < len(dl); i++ { + if dl[i].DeviceType == deviceType { + return &dl[i] + } + } + return nil +} + +func getChildService(d *device, serviceType string) *service { + sl := d.ServiceList.Service + for i := 0; i < len(sl); i++ { + if sl[i].ServiceType == serviceType { + return &sl[i] + } + } + return nil +} + +func getOurIP() (ip string, err error) { + hostname, err := os.Hostname() + if err != nil { + return + } + p, err := net.LookupIP(hostname) + if err != nil && len(p) > 0 { + return + } + return p[0].String(), nil +} + +func getServiceURL(rootURL string) (url string, err error) { + r, err := http.Get(rootURL) + if err != nil { + return + } + defer r.Body.Close() + if r.StatusCode >= 400 { + err = errors.New(string(r.StatusCode)) + return + } + var root root + err = xml.NewDecoder(r.Body).Decode(&root) + + if err != nil { + return + } + a := &root.Device + if a.DeviceType != "urn:schemas-upnp-org:device:InternetGatewayDevice:1" { + err = errors.New("No InternetGatewayDevice") + return + } + b := getChildDevice(a, "urn:schemas-upnp-org:device:WANDevice:1") + if b == nil { + err = errors.New("No WANDevice") + return + } + c := getChildDevice(b, "urn:schemas-upnp-org:device:WANConnectionDevice:1") + if c == nil { + err = errors.New("No WANConnectionDevice") + return + } + d := getChildService(c, "urn:schemas-upnp-org:service:WANIPConnection:1") + if d == nil { + err = errors.New("No WANIPConnection") + return + } + url = combineURL(rootURL, d.ControlURL) + return +} + +func combineURL(rootURL, subURL string) string { + protocolEnd := "://" + protoEndIndex := strings.Index(rootURL, protocolEnd) + a := rootURL[protoEndIndex+len(protocolEnd):] + rootIndex := strings.Index(a, "/") + return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL +} + +func soapRequest(url, function, message string) (r *http.Response, err error) { + fullMessage := "<?xml version=\"1.0\" ?>" + + "<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\r\n" + + "<s:Body>" + message + "</s:Body></s:Envelope>" + + req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage)) + if err != nil { + return + } + req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"") + req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3") + //req.Header.Set("Transfer-Encoding", "chunked") + req.Header.Set("SOAPAction", "\"urn:schemas-upnp-org:service:WANIPConnection:1#"+function+"\"") + req.Header.Set("Connection", "Close") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Pragma", "no-cache") + + r, err = http.DefaultClient.Do(req) + if err != nil { + return + } + + if r.Body != nil { + defer r.Body.Close() + } + + if r.StatusCode >= 400 { + // log.Stderr(function, r.StatusCode) + err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function) + r = nil + return + } + return +} + +type statusInfo struct { + externalIpAddress string +} + +func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) { + + message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + + "</u:GetStatusInfo>" + + var response *http.Response + response, err = soapRequest(n.serviceURL, "GetStatusInfo", message) + if err != nil { + return + } + + // TODO: Write a soap reply parser. It has to eat the Body and envelope tags... + + response.Body.Close() + return +} + +func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) { + info, err := n.getStatusInfo() + if err != nil { + return + } + addr = net.ParseIP(info.externalIpAddress) + return +} + +func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) { + // A single concatenation would break ARM compilation. + message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + + "<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) + message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" + message += "<NewInternalPort>" + strconv.Itoa(internalPort) + "</NewInternalPort>" + + "<NewInternalClient>" + n.ourIP + "</NewInternalClient>" + + "<NewEnabled>1</NewEnabled><NewPortMappingDescription>" + message += description + + "</NewPortMappingDescription><NewLeaseDuration>" + strconv.Itoa(timeout) + + "</NewLeaseDuration></u:AddPortMapping>" + + var response *http.Response + response, err = soapRequest(n.serviceURL, "AddPortMapping", message) + if err != nil { + return + } + + // TODO: check response to see if the port was forwarded + // log.Println(message, response) + mappedExternalPort = externalPort + _ = response + return +} + +func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { + + message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + + "<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) + + "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" + + "</u:DeletePortMapping>" + + var response *http.Response + response, err = soapRequest(n.serviceURL, "DeletePortMapping", message) + if err != nil { + return + } + + // TODO: check response to see if the port was deleted + // log.Println(message, response) + _ = response + return +} diff --git a/p2p/network.go b/p2p/network.go new file mode 100644 index 000000000..820cef1a9 --- /dev/null +++ b/p2p/network.go @@ -0,0 +1,196 @@ +package p2p + +import ( + "fmt" + "math/rand" + "net" + "strconv" + "time" +) + +const ( + DialerTimeout = 180 //seconds + KeepAlivePeriod = 60 //minutes + portMappingUpdateInterval = 900 // seconds = 15 mins + upnpDiscoverAttempts = 3 +) + +// Dialer is not an interface in net, so we define one +// *net.Dialer conforms to this +type Dialer interface { + Dial(network, address string) (net.Conn, error) +} + +type Network interface { + Start() error + Listener(net.Addr) (net.Listener, error) + Dialer(net.Addr) (Dialer, error) + NewAddr(string, int) (addr net.Addr, err error) + ParseAddr(string) (addr net.Addr, err error) +} + +type NAT interface { + GetExternalAddress() (addr net.IP, err error) + AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) + DeletePortMapping(protocol string, externalPort, internalPort int) (err error) +} + +type TCPNetwork struct { + nat NAT + natType NATType + quit chan chan bool + ports chan string +} + +type NATType int + +const ( + NONE = iota + UPNP + PMP +) + +const ( + portMappingTimeout = 1200 // 20 mins +) + +func NewTCPNetwork(natType NATType) (net *TCPNetwork) { + return &TCPNetwork{ + natType: natType, + ports: make(chan string), + } +} + +func (self *TCPNetwork) Dialer(addr net.Addr) (Dialer, error) { + return &net.Dialer{ + Timeout: DialerTimeout * time.Second, + // KeepAlive: KeepAlivePeriod * time.Minute, + LocalAddr: addr, + }, nil +} + +func (self *TCPNetwork) Listener(addr net.Addr) (net.Listener, error) { + if self.natType == UPNP { + _, port, _ := net.SplitHostPort(addr.String()) + if self.quit == nil { + self.quit = make(chan chan bool) + go self.updatePortMappings() + } + self.ports <- port + } + return net.Listen(addr.Network(), addr.String()) +} + +func (self *TCPNetwork) Start() (err error) { + switch self.natType { + case NONE: + case UPNP: + nat, uerr := upnpDiscover(upnpDiscoverAttempts) + if uerr != nil { + err = fmt.Errorf("UPNP failed: ", uerr) + } else { + self.nat = nat + } + case PMP: + err = fmt.Errorf("PMP not implemented") + default: + err = fmt.Errorf("Invalid NAT type: %v", self.natType) + } + return +} + +func (self *TCPNetwork) Stop() { + q := make(chan bool) + self.quit <- q + <-q +} + +func (self *TCPNetwork) addPortMapping(lport int) (err error) { + _, err = self.nat.AddPortMapping("TCP", lport, lport, "p2p listen port", portMappingTimeout) + if err != nil { + logger.Errorf("unable to add port mapping on %v: %v", lport, err) + } else { + logger.Debugf("succesfully added port mapping on %v", lport) + } + return +} + +func (self *TCPNetwork) updatePortMappings() { + timer := time.NewTimer(portMappingUpdateInterval * time.Second) + lports := []int{} +out: + for { + select { + case port := <-self.ports: + int64lport, _ := strconv.ParseInt(port, 10, 16) + lport := int(int64lport) + if err := self.addPortMapping(lport); err != nil { + lports = append(lports, lport) + } + case <-timer.C: + for lport := range lports { + if err := self.addPortMapping(lport); err != nil { + } + } + case errc := <-self.quit: + errc <- true + break out + } + } + + timer.Stop() + for lport := range lports { + if err := self.nat.DeletePortMapping("TCP", lport, lport); err != nil { + logger.Debugf("unable to remove port mapping on %v: %v", lport, err) + } else { + logger.Debugf("succesfully removed port mapping on %v", lport) + } + } +} + +func (self *TCPNetwork) NewAddr(host string, port int) (net.Addr, error) { + ip, err := self.lookupIP(host) + if err == nil { + return &net.TCPAddr{ + IP: ip, + Port: port, + }, nil + } + return nil, err +} + +func (self *TCPNetwork) ParseAddr(address string) (net.Addr, error) { + host, port, err := net.SplitHostPort(address) + if err == nil { + iport, _ := strconv.Atoi(port) + addr, e := self.NewAddr(host, iport) + return addr, e + } + return nil, err +} + +func (*TCPNetwork) lookupIP(host string) (ip net.IP, err error) { + if ip = net.ParseIP(host); ip != nil { + return + } + + var ips []net.IP + ips, err = net.LookupIP(host) + if err != nil { + logger.Warnln(err) + return + } + if len(ips) == 0 { + err = fmt.Errorf("No IP addresses available for %v", host) + logger.Warnln(err) + return + } + if len(ips) > 1 { + // Pick a random IP address, simulating round-robin DNS. + rand.Seed(time.Now().UTC().UnixNano()) + ip = ips[rand.Intn(len(ips))] + } else { + ip = ips[0] + } + return +} diff --git a/p2p/peer.go b/p2p/peer.go new file mode 100644 index 000000000..f4b68a007 --- /dev/null +++ b/p2p/peer.go @@ -0,0 +1,83 @@ +package p2p + +import ( + "fmt" + "net" + "strconv" +) + +type Peer struct { + // quit chan chan bool + Inbound bool // inbound (via listener) or outbound (via dialout) + Address net.Addr + Host []byte + Port uint16 + Pubkey []byte + Id string + Caps []string + peerErrorChan chan *PeerError + messenger *Messenger + peerErrorHandler *PeerErrorHandler + server *Server +} + +func (self *Peer) Messenger() *Messenger { + return self.messenger +} + +func (self *Peer) PeerErrorChan() chan *PeerError { + return self.peerErrorChan +} + +func (self *Peer) Server() *Server { + return self.server +} + +func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Peer { + peerErrorChan := NewPeerErrorChannel() + host, port, _ := net.SplitHostPort(address.String()) + intport, _ := strconv.Atoi(port) + peer := &Peer{ + Inbound: inbound, + Address: address, + Port: uint16(intport), + Host: net.ParseIP(host), + peerErrorChan: peerErrorChan, + server: server, + } + connection := NewConnection(conn, peerErrorChan) + peer.messenger = NewMessenger(peer, connection, peerErrorChan, server.Handlers()) + peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan, server.Blacklist()) + return peer +} + +func (self *Peer) String() string { + var kind string + if self.Inbound { + kind = "inbound" + } else { + kind = "outbound" + } + return fmt.Sprintf("%v:%v (%s) v%v %v", self.Host, self.Port, kind, self.Id, self.Caps) +} + +func (self *Peer) Write(protocol string, msg *Msg) error { + return self.messenger.Write(protocol, msg) +} + +func (self *Peer) Start() { + self.peerErrorHandler.Start() + self.messenger.Start() +} + +func (self *Peer) Stop() { + self.peerErrorHandler.Stop() + self.messenger.Stop() + // q := make(chan bool) + // self.quit <- q + // <-q +} + +func (p *Peer) Encode() []interface{} { + return []interface{}{p.Host, p.Port, p.Pubkey} +} diff --git a/p2p/peer_error.go b/p2p/peer_error.go new file mode 100644 index 000000000..de921878a --- /dev/null +++ b/p2p/peer_error.go @@ -0,0 +1,76 @@ +package p2p + +import ( + "fmt" +) + +type ErrorCode int + +const errorChanCapacity = 10 + +const ( + PacketTooShort = iota + PayloadTooShort + MagicTokenMismatch + EmptyPayload + ReadError + WriteError + MiscError + InvalidMsgCode + InvalidMsg + P2PVersionMismatch + PubkeyMissing + PubkeyInvalid + PubkeyForbidden + ProtocolBreach + PortMismatch + PingTimeout + InvalidGenesis + InvalidNetworkId + InvalidProtocolVersion +) + +var errorToString = map[ErrorCode]string{ + PacketTooShort: "Packet too short", + PayloadTooShort: "Payload too short", + MagicTokenMismatch: "Magic token mismatch", + EmptyPayload: "Empty payload", + ReadError: "Read error", + WriteError: "Write error", + MiscError: "Misc error", + InvalidMsgCode: "Invalid message code", + InvalidMsg: "Invalid message", + P2PVersionMismatch: "P2P Version Mismatch", + PubkeyMissing: "Public key missing", + PubkeyInvalid: "Public key invalid", + PubkeyForbidden: "Public key forbidden", + ProtocolBreach: "Protocol Breach", + PortMismatch: "Port mismatch", + PingTimeout: "Ping timeout", + InvalidGenesis: "Invalid genesis block", + InvalidNetworkId: "Invalid network id", + InvalidProtocolVersion: "Invalid protocol version", +} + +type PeerError struct { + Code ErrorCode + message string +} + +func NewPeerError(code ErrorCode, format string, v ...interface{}) *PeerError { + desc, ok := errorToString[code] + if !ok { + panic("invalid error code") + } + format = desc + ": " + format + message := fmt.Sprintf(format, v...) + return &PeerError{code, message} +} + +func (self *PeerError) Error() string { + return self.message +} + +func NewPeerErrorChannel() chan *PeerError { + return make(chan *PeerError, errorChanCapacity) +} diff --git a/p2p/peer_error_handler.go b/p2p/peer_error_handler.go new file mode 100644 index 000000000..ca6cae4db --- /dev/null +++ b/p2p/peer_error_handler.go @@ -0,0 +1,101 @@ +package p2p + +import ( + "net" +) + +const ( + severityThreshold = 10 +) + +type DisconnectRequest struct { + addr net.Addr + reason DiscReason +} + +type PeerErrorHandler struct { + quit chan chan bool + address net.Addr + peerDisconnect chan DisconnectRequest + severity int + peerErrorChan chan *PeerError + blacklist Blacklist +} + +func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, peerErrorChan chan *PeerError, blacklist Blacklist) *PeerErrorHandler { + return &PeerErrorHandler{ + quit: make(chan chan bool), + address: address, + peerDisconnect: peerDisconnect, + peerErrorChan: peerErrorChan, + blacklist: blacklist, + } +} + +func (self *PeerErrorHandler) Start() { + go self.listen() +} + +func (self *PeerErrorHandler) Stop() { + q := make(chan bool) + self.quit <- q + <-q +} + +func (self *PeerErrorHandler) listen() { + for { + select { + case peerError, ok := <-self.peerErrorChan: + if ok { + logger.Debugf("error %v\n", peerError) + go self.handle(peerError) + } else { + return + } + case q := <-self.quit: + q <- true + return + } + } +} + +func (self *PeerErrorHandler) handle(peerError *PeerError) { + reason := DiscReason(' ') + switch peerError.Code { + case P2PVersionMismatch: + reason = DiscIncompatibleVersion + case PubkeyMissing, PubkeyInvalid: + reason = DiscInvalidIdentity + case PubkeyForbidden: + reason = DiscUselessPeer + case InvalidMsgCode, PacketTooShort, PayloadTooShort, MagicTokenMismatch, EmptyPayload, ProtocolBreach: + reason = DiscProtocolError + case PingTimeout: + reason = DiscReadTimeout + case WriteError, MiscError: + reason = DiscNetworkError + case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion: + reason = DiscSubprotocolError + default: + self.severity += self.getSeverity(peerError) + } + + if self.severity >= severityThreshold { + reason = DiscSubprotocolError + } + if reason != DiscReason(' ') { + self.peerDisconnect <- DisconnectRequest{ + addr: self.address, + reason: reason, + } + } +} + +func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int { + switch peerError.Code { + case ReadError: + return 4 //tolerate 3 :) + default: + return 1 + } +} diff --git a/p2p/peer_error_handler_test.go b/p2p/peer_error_handler_test.go new file mode 100644 index 000000000..790a7443b --- /dev/null +++ b/p2p/peer_error_handler_test.go @@ -0,0 +1,34 @@ +package p2p + +import ( + // "fmt" + "net" + "testing" + "time" +) + +func TestPeerErrorHandler(t *testing.T) { + address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303} + peerDisconnect := make(chan DisconnectRequest) + peerErrorChan := NewPeerErrorChannel() + peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan, NewBlacklist()) + peh.Start() + defer peh.Stop() + for i := 0; i < 11; i++ { + select { + case <-peerDisconnect: + t.Errorf("expected no disconnect request") + default: + } + peerErrorChan <- NewPeerError(MiscError, "") + } + time.Sleep(1 * time.Millisecond) + select { + case request := <-peerDisconnect: + if request.addr.String() != address.String() { + t.Errorf("incorrect address %v != %v", request.addr, address) + } + default: + t.Errorf("expected disconnect request") + } +} diff --git a/p2p/peer_test.go b/p2p/peer_test.go new file mode 100644 index 000000000..c37540bef --- /dev/null +++ b/p2p/peer_test.go @@ -0,0 +1,96 @@ +package p2p + +import ( + "bytes" + "fmt" + // "net" + "testing" + "time" +) + +func TestPeer(t *testing.T) { + handlers := make(Handlers) + testProtocol := &TestProtocol{Msgs: []*Msg{}} + handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } + handlers["ccc"] = func(p *Peer) Protocol { return testProtocol } + addr := &TestAddr{"test:30"} + conn := NewTestNetworkConnection(addr) + _, server := SetupTestServer(handlers) + server.Handshake() + peer := NewPeer(conn, addr, true, server) + // peer.Messenger().AddProtocols([]string{"aaa", "ccc"}) + peer.Start() + defer peer.Stop() + time.Sleep(2 * time.Millisecond) + if len(conn.Out) != 1 { + t.Errorf("handshake not sent") + } else { + out := conn.Out[0] + packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:]) + if bytes.Compare(out, packet) != 0 { + t.Errorf("incorrect handshake packet %v != %v", out, packet) + } + } + + packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000")) + conn.In(0, packet) + time.Sleep(10 * time.Millisecond) + + pro, _ := peer.Messenger().protocols[0].(*BaseProtocol) + if pro.state != handshakeReceived { + t.Errorf("handshake not received") + } + if peer.Port != 30 { + t.Errorf("port incorrectly set") + } + if peer.Id != "peer" { + t.Errorf("id incorrectly set") + } + if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" { + t.Errorf("pubkey incorrectly set") + } + fmt.Println(peer.Caps) + if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" { + t.Errorf("protocols incorrectly set") + } + + msg, _ := NewMsg(3) + err := peer.Write("aaa", msg) + if err != nil { + t.Errorf("expect no error for known protocol: %v", err) + } else { + time.Sleep(1 * time.Millisecond) + if len(conn.Out) != 2 { + t.Errorf("msg not written") + } else { + out := conn.Out[1] + packet := Packet(16, 3) + if bytes.Compare(out, packet) != 0 { + t.Errorf("incorrect packet %v != %v", out, packet) + } + } + } + + msg, _ = NewMsg(2) + err = peer.Write("ccc", msg) + if err != nil { + t.Errorf("expect no error for known protocol: %v", err) + } else { + time.Sleep(1 * time.Millisecond) + if len(conn.Out) != 3 { + t.Errorf("msg not written") + } else { + out := conn.Out[2] + packet := Packet(21, 2) + if bytes.Compare(out, packet) != 0 { + t.Errorf("incorrect packet %v != %v", out, packet) + } + } + } + + err = peer.Write("bbb", msg) + time.Sleep(1 * time.Millisecond) + if err == nil { + t.Errorf("expect error for unknown protocol") + } +} diff --git a/p2p/protocol.go b/p2p/protocol.go new file mode 100644 index 000000000..5d05ced7d --- /dev/null +++ b/p2p/protocol.go @@ -0,0 +1,278 @@ +package p2p + +import ( + "bytes" + "fmt" + "net" + "sort" + "sync" + "time" +) + +type Protocol interface { + Start() + Stop() + HandleIn(*Msg, chan *Msg) + HandleOut(*Msg) bool + Offset() MsgCode + Name() string +} + +const ( + P2PVersion = 0 + pingTimeout = 2 + pingGracePeriod = 2 +) + +const ( + HandshakeMsg = iota + DiscMsg + PingMsg + PongMsg + GetPeersMsg + PeersMsg + offset = 16 +) + +type ProtocolState uint8 + +const ( + nullState = iota + handshakeReceived +) + +type DiscReason byte + +const ( + // Values are given explicitly instead of by iota because these values are + // defined by the wire protocol spec; it is easier for humans to ensure + // correctness when values are explicit. + DiscRequested = 0x00 + DiscNetworkError = 0x01 + DiscProtocolError = 0x02 + DiscUselessPeer = 0x03 + DiscTooManyPeers = 0x04 + DiscAlreadyConnected = 0x05 + DiscIncompatibleVersion = 0x06 + DiscInvalidIdentity = 0x07 + DiscQuitting = 0x08 + DiscUnexpectedIdentity = 0x09 + DiscSelf = 0x0a + DiscReadTimeout = 0x0b + DiscSubprotocolError = 0x10 +) + +var discReasonToString = map[DiscReason]string{ + DiscRequested: "Disconnect requested", + DiscNetworkError: "Network error", + DiscProtocolError: "Breach of protocol", + DiscUselessPeer: "Useless peer", + DiscTooManyPeers: "Too many peers", + DiscAlreadyConnected: "Already connected", + DiscIncompatibleVersion: "Incompatible P2P protocol version", + DiscInvalidIdentity: "Invalid node identity", + DiscQuitting: "Client quitting", + DiscUnexpectedIdentity: "Unexpected identity", + DiscSelf: "Connected to self", + DiscReadTimeout: "Read timeout", + DiscSubprotocolError: "Subprotocol error", +} + +func (d DiscReason) String() string { + if len(discReasonToString) < int(d) { + return "Unknown" + } + + return discReasonToString[d] +} + +type BaseProtocol struct { + peer *Peer + state ProtocolState + stateLock sync.RWMutex +} + +func NewBaseProtocol(peer *Peer) *BaseProtocol { + self := &BaseProtocol{ + peer: peer, + } + + return self +} + +func (self *BaseProtocol) Start() { + if self.peer != nil { + self.peer.Write("", self.peer.Server().Handshake()) + go self.peer.Messenger().PingPong( + pingTimeout*time.Second, + pingGracePeriod*time.Second, + self.Ping, + self.Timeout, + ) + } +} + +func (self *BaseProtocol) Stop() { +} + +func (self *BaseProtocol) Ping() { + msg, _ := NewMsg(PingMsg) + self.peer.Write("", msg) +} + +func (self *BaseProtocol) Timeout() { + self.peerError(PingTimeout, "") +} + +func (self *BaseProtocol) Name() string { + return "" +} + +func (self *BaseProtocol) Offset() MsgCode { + return offset +} + +func (self *BaseProtocol) CheckState(state ProtocolState) bool { + self.stateLock.RLock() + self.stateLock.RUnlock() + if self.state != state { + return false + } else { + return true + } +} + +func (self *BaseProtocol) HandleIn(msg *Msg, response chan *Msg) { + if msg.Code() == HandshakeMsg { + self.handleHandshake(msg) + } else { + if !self.CheckState(handshakeReceived) { + self.peerError(ProtocolBreach, "message code %v not allowed", msg.Code()) + close(response) + return + } + switch msg.Code() { + case DiscMsg: + logger.Infof("Disconnect requested from peer %v, reason", DiscReason(msg.Data().Get(0).Uint())) + self.peer.Server().PeerDisconnect() <- DisconnectRequest{ + addr: self.peer.Address, + reason: DiscRequested, + } + case PingMsg: + out, _ := NewMsg(PongMsg) + response <- out + case PongMsg: + case GetPeersMsg: + // Peer asked for list of connected peers + if out, err := self.peer.Server().PeersMessage(); err != nil { + response <- out + } + case PeersMsg: + self.handlePeers(msg) + default: + self.peerError(InvalidMsgCode, "unknown message code %v", msg.Code()) + } + } + close(response) +} + +func (self *BaseProtocol) HandleOut(msg *Msg) (allowed bool) { + // somewhat overly paranoid + allowed = msg.Code() == HandshakeMsg || msg.Code() == DiscMsg || msg.Code() < self.Offset() && self.CheckState(handshakeReceived) + return +} + +func (self *BaseProtocol) peerError(errorCode ErrorCode, format string, v ...interface{}) { + err := NewPeerError(errorCode, format, v...) + logger.Warnln(err) + fmt.Println(self.peer, err) + if self.peer != nil { + self.peer.PeerErrorChan() <- err + } +} + +func (self *BaseProtocol) handlePeers(msg *Msg) { + it := msg.Data().NewIterator() + for it.Next() { + ip := net.IP(it.Value().Get(0).Bytes()) + port := it.Value().Get(1).Uint() + address := &net.TCPAddr{IP: ip, Port: int(port)} + go self.peer.Server().PeerConnect(address) + } +} + +func (self *BaseProtocol) handleHandshake(msg *Msg) { + self.stateLock.Lock() + defer self.stateLock.Unlock() + if self.state != nullState { + self.peerError(ProtocolBreach, "extra handshake") + return + } + + c := msg.Data() + + var ( + p2pVersion = c.Get(0).Uint() + id = c.Get(1).Str() + caps = c.Get(2) + port = c.Get(3).Uint() + pubkey = c.Get(4).Bytes() + ) + fmt.Printf("handshake received %v, %v, %v, %v, %v ", p2pVersion, id, caps, port, pubkey) + + // Check correctness of p2p protocol version + if p2pVersion != P2PVersion { + self.peerError(P2PVersionMismatch, "Require protocol %d, received %d\n", P2PVersion, p2pVersion) + return + } + + // Handle the pub key (validation, uniqueness) + if len(pubkey) == 0 { + self.peerError(PubkeyMissing, "not supplied in handshake.") + return + } + + if len(pubkey) != 64 { + self.peerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8) + return + } + + // Self connect detection + if bytes.Compare(self.peer.Server().ClientIdentity().Pubkey()[1:], pubkey) == 0 { + self.peerError(PubkeyForbidden, "not allowed to connect to self") + return + } + + // register pubkey on server. this also sets the pubkey on the peer (need lock) + if err := self.peer.Server().RegisterPubkey(self.peer, pubkey); err != nil { + self.peerError(PubkeyForbidden, err.Error()) + return + } + + // check port + if self.peer.Inbound { + uint16port := uint16(port) + if self.peer.Port > 0 && self.peer.Port != uint16port { + self.peerError(PortMismatch, "port mismatch: %v != %v", self.peer.Port, port) + return + } else { + self.peer.Port = uint16port + } + } + + capsIt := caps.NewIterator() + for capsIt.Next() { + cap := capsIt.Value().Str() + self.peer.Caps = append(self.peer.Caps, cap) + } + sort.Strings(self.peer.Caps) + self.peer.Messenger().AddProtocols(self.peer.Caps) + + self.peer.Id = id + + self.state = handshakeReceived + + //p.ethereum.PushPeer(p) + // p.ethereum.reactor.Post("peerList", p.ethereum.Peers()) + return +} diff --git a/p2p/server.go b/p2p/server.go new file mode 100644 index 000000000..a6bbd9260 --- /dev/null +++ b/p2p/server.go @@ -0,0 +1,484 @@ +package p2p + +import ( + "bytes" + "fmt" + "net" + "sort" + "strconv" + "sync" + "time" + + "github.com/ethereum/eth-go/ethlog" +) + +const ( + outboundAddressPoolSize = 10 + disconnectGracePeriod = 2 +) + +type Blacklist interface { + Get([]byte) (bool, error) + Put([]byte) error + Delete([]byte) error + Exists(pubkey []byte) (ok bool) +} + +type BlacklistMap struct { + blacklist map[string]bool + lock sync.RWMutex +} + +func NewBlacklist() *BlacklistMap { + return &BlacklistMap{ + blacklist: make(map[string]bool), + } +} + +func (self *BlacklistMap) Get(pubkey []byte) (bool, error) { + self.lock.RLock() + defer self.lock.RUnlock() + v, ok := self.blacklist[string(pubkey)] + var err error + if !ok { + err = fmt.Errorf("not found") + } + return v, err +} + +func (self *BlacklistMap) Exists(pubkey []byte) (ok bool) { + self.lock.RLock() + defer self.lock.RUnlock() + _, ok = self.blacklist[string(pubkey)] + return +} + +func (self *BlacklistMap) Put(pubkey []byte) error { + self.lock.RLock() + defer self.lock.RUnlock() + self.blacklist[string(pubkey)] = true + return nil +} + +func (self *BlacklistMap) Delete(pubkey []byte) error { + self.lock.RLock() + defer self.lock.RUnlock() + delete(self.blacklist, string(pubkey)) + return nil +} + +type Server struct { + network Network + listening bool //needed? + dialing bool //needed? + closed bool + identity ClientIdentity + addr net.Addr + port uint16 + protocols []string + + quit chan chan bool + peersLock sync.RWMutex + + maxPeers int + peers []*Peer + peerSlots chan int + peersTable map[string]int + peersMsg *Msg + peerCount int + + peerConnect chan net.Addr + peerDisconnect chan DisconnectRequest + blacklist Blacklist + handlers Handlers +} + +var logger = ethlog.NewLogger("P2P") + +func New(network Network, addr net.Addr, identity ClientIdentity, handlers Handlers, maxPeers int, blacklist Blacklist) *Server { + // get alphabetical list of protocol names from handlers map + protocols := []string{} + for protocol := range handlers { + protocols = append(protocols, protocol) + } + sort.Strings(protocols) + + _, port, _ := net.SplitHostPort(addr.String()) + intport, _ := strconv.Atoi(port) + + self := &Server{ + // NewSimpleClientIdentity(clientIdentifier, version, customIdentifier) + network: network, + identity: identity, + addr: addr, + port: uint16(intport), + protocols: protocols, + + quit: make(chan chan bool), + + maxPeers: maxPeers, + peers: make([]*Peer, maxPeers), + peerSlots: make(chan int, maxPeers), + peersTable: make(map[string]int), + + peerConnect: make(chan net.Addr, outboundAddressPoolSize), + peerDisconnect: make(chan DisconnectRequest), + blacklist: blacklist, + + handlers: handlers, + } + for i := 0; i < maxPeers; i++ { + self.peerSlots <- i // fill up with indexes + } + return self +} + +func (self *Server) NewAddr(host string, port int) (addr net.Addr, err error) { + addr, err = self.network.NewAddr(host, port) + return +} + +func (self *Server) ParseAddr(address string) (addr net.Addr, err error) { + addr, err = self.network.ParseAddr(address) + return +} + +func (self *Server) ClientIdentity() ClientIdentity { + return self.identity +} + +func (self *Server) PeersMessage() (msg *Msg, err error) { + // TODO: memoize and reset when peers change + self.peersLock.RLock() + defer self.peersLock.RUnlock() + msg = self.peersMsg + if msg == nil { + var peerData []interface{} + for _, i := range self.peersTable { + peer := self.peers[i] + peerData = append(peerData, peer.Encode()) + } + if len(peerData) == 0 { + err = fmt.Errorf("no peers") + } else { + msg, err = NewMsg(PeersMsg, peerData...) + self.peersMsg = msg //memoize + } + } + return +} + +func (self *Server) Peers() (peers []*Peer) { + self.peersLock.RLock() + defer self.peersLock.RUnlock() + for _, peer := range self.peers { + if peer != nil { + peers = append(peers, peer) + } + } + return +} + +func (self *Server) PeerCount() int { + self.peersLock.RLock() + defer self.peersLock.RUnlock() + return self.peerCount +} + +var getPeersMsg, _ = NewMsg(GetPeersMsg) + +func (self *Server) PeerConnect(addr net.Addr) { + // TODO: should buffer, filter and uniq + // send GetPeersMsg if not blocking + select { + case self.peerConnect <- addr: // not enough peers + self.Broadcast("", getPeersMsg) + default: // we dont care + } +} + +func (self *Server) PeerDisconnect() chan DisconnectRequest { + return self.peerDisconnect +} + +func (self *Server) Blacklist() Blacklist { + return self.blacklist +} + +func (self *Server) Handlers() Handlers { + return self.handlers +} + +func (self *Server) Broadcast(protocol string, msg *Msg) { + self.peersLock.RLock() + defer self.peersLock.RUnlock() + for _, peer := range self.peers { + if peer != nil { + peer.Write(protocol, msg) + } + } +} + +// Start the server +func (self *Server) Start(listen bool, dial bool) { + self.network.Start() + if listen { + listener, err := self.network.Listener(self.addr) + if err != nil { + logger.Warnf("Error initializing listener: %v", err) + logger.Warnf("Connection listening disabled") + self.listening = false + } else { + self.listening = true + logger.Infoln("Listen on %v: ready and accepting connections", listener.Addr()) + go self.inboundPeerHandler(listener) + } + } + if dial { + dialer, err := self.network.Dialer(self.addr) + if err != nil { + logger.Warnf("Error initializing dialer: %v", err) + logger.Warnf("Connection dialout disabled") + self.dialing = false + } else { + self.dialing = true + logger.Infoln("Dial peers watching outbound address pool") + go self.outboundPeerHandler(dialer) + } + } + logger.Infoln("server started") +} + +func (self *Server) Stop() { + logger.Infoln("server stopping...") + // // quit one loop if dialing + if self.dialing { + logger.Infoln("stop dialout...") + dialq := make(chan bool) + self.quit <- dialq + <-dialq + fmt.Println("quit another") + } + // quit the other loop if listening + if self.listening { + logger.Infoln("stop listening...") + listenq := make(chan bool) + self.quit <- listenq + <-listenq + fmt.Println("quit one") + } + + fmt.Println("quit waited") + + logger.Infoln("stopping peers...") + peers := []net.Addr{} + self.peersLock.RLock() + self.closed = true + for _, peer := range self.peers { + if peer != nil { + peers = append(peers, peer.Address) + } + } + self.peersLock.RUnlock() + for _, address := range peers { + go self.removePeer(DisconnectRequest{ + addr: address, + reason: DiscQuitting, + }) + } + // wait till they actually disconnect + // this is checked by draining the peerSlots (slots are released back if a peer is removed) + i := 0 + fmt.Println("draining peers") + +FOR: + for { + select { + case slot := <-self.peerSlots: + i++ + fmt.Printf("%v: found slot %v", i, slot) + if i == self.maxPeers { + break FOR + } + } + } + logger.Infoln("server stopped") +} + +// main loop for adding connections via listening +func (self *Server) inboundPeerHandler(listener net.Listener) { + for { + select { + case slot := <-self.peerSlots: + go self.connectInboundPeer(listener, slot) + case errc := <-self.quit: + listener.Close() + fmt.Println("quit listenloop") + errc <- true + return + } + } +} + +// main loop for adding outbound peers based on peerConnect address pool +// this same loop handles peer disconnect requests as well +func (self *Server) outboundPeerHandler(dialer Dialer) { + // addressChan initially set to nil (only watches peerConnect if we need more peers) + var addressChan chan net.Addr + slots := self.peerSlots + var slot *int + for { + select { + case i := <-slots: + // we need a peer in slot i, slot reserved + slot = &i + // now we can watch for candidate peers in the next loop + addressChan = self.peerConnect + // do not consume more until candidate peer is found + slots = nil + case address := <-addressChan: + // candidate peer found, will dial out asyncronously + // if connection fails slot will be released + go self.connectOutboundPeer(dialer, address, *slot) + // we can watch if more peers needed in the next loop + slots = self.peerSlots + // until then we dont care about candidate peers + addressChan = nil + case request := <-self.peerDisconnect: + go self.removePeer(request) + case errc := <-self.quit: + if addressChan != nil && slot != nil { + self.peerSlots <- *slot + } + fmt.Println("quit dialloop") + errc <- true + return + } + } +} + +// check if peer address already connected +func (self *Server) connected(address net.Addr) (err error) { + self.peersLock.RLock() + defer self.peersLock.RUnlock() + // fmt.Printf("address: %v\n", address) + slot, found := self.peersTable[address.String()] + if found { + err = fmt.Errorf("already connected as peer %v (%v)", slot, address) + } + return +} + +// connect to peer via listener.Accept() +func (self *Server) connectInboundPeer(listener net.Listener, slot int) { + var address net.Addr + conn, err := listener.Accept() + if err == nil { + address = conn.RemoteAddr() + err = self.connected(address) + if err != nil { + conn.Close() + } + } + if err != nil { + logger.Debugln(err) + self.peerSlots <- slot + } else { + fmt.Printf("adding %v\n", address) + go self.addPeer(conn, address, true, slot) + } +} + +// connect to peer via dial out +func (self *Server) connectOutboundPeer(dialer Dialer, address net.Addr, slot int) { + var conn net.Conn + err := self.connected(address) + if err == nil { + conn, err = dialer.Dial(address.Network(), address.String()) + } + if err != nil { + logger.Debugln(err) + self.peerSlots <- slot + } else { + go self.addPeer(conn, address, false, slot) + } +} + +// creates the new peer object and inserts it into its slot +func (self *Server) addPeer(conn net.Conn, address net.Addr, inbound bool, slot int) { + self.peersLock.Lock() + defer self.peersLock.Unlock() + if self.closed { + fmt.Println("oopsy, not no longer need peer") + conn.Close() //oopsy our bad + self.peerSlots <- slot // release slot + } else { + peer := NewPeer(conn, address, inbound, self) + self.peers[slot] = peer + self.peersTable[address.String()] = slot + self.peerCount++ + // reset peersmsg + self.peersMsg = nil + fmt.Printf("added peer %v %v (slot %v)\n", address, peer, slot) + peer.Start() + } +} + +// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot +func (self *Server) removePeer(request DisconnectRequest) { + self.peersLock.Lock() + + address := request.addr + slot := self.peersTable[address.String()] + peer := self.peers[slot] + fmt.Printf("removing peer %v %v (slot %v)\n", address, peer, slot) + if peer == nil { + logger.Debugf("already removed peer on %v", address) + self.peersLock.Unlock() + return + } + // remove from list and index + self.peerCount-- + self.peers[slot] = nil + delete(self.peersTable, address.String()) + // reset peersmsg + self.peersMsg = nil + fmt.Printf("removed peer %v (slot %v)\n", peer, slot) + self.peersLock.Unlock() + + // sending disconnect message + disconnectMsg, _ := NewMsg(DiscMsg, request.reason) + peer.Write("", disconnectMsg) + // be nice and wait + time.Sleep(disconnectGracePeriod * time.Second) + // switch off peer and close connections etc. + fmt.Println("stopping peer") + peer.Stop() + fmt.Println("stopped peer") + // release slot to signal need for a new peer, last! + self.peerSlots <- slot +} + +// fix handshake message to push to peers +func (self *Server) Handshake() *Msg { + fmt.Println(self.identity.Pubkey()[1:]) + msg, _ := NewMsg(HandshakeMsg, P2PVersion, []byte(self.identity.String()), []interface{}{self.protocols}, self.port, self.identity.Pubkey()[1:]) + return msg +} + +func (self *Server) RegisterPubkey(candidate *Peer, pubkey []byte) error { + // Check for blacklisting + if self.blacklist.Exists(pubkey) { + return fmt.Errorf("blacklisted") + } + + self.peersLock.RLock() + defer self.peersLock.RUnlock() + for _, peer := range self.peers { + if peer != nil && peer != candidate && bytes.Compare(peer.Pubkey, pubkey) == 0 { + return fmt.Errorf("already connected") + } + } + candidate.Pubkey = pubkey + return nil +} diff --git a/p2p/server_test.go b/p2p/server_test.go new file mode 100644 index 000000000..f749cc490 --- /dev/null +++ b/p2p/server_test.go @@ -0,0 +1,208 @@ +package p2p + +import ( + "bytes" + "fmt" + "net" + "testing" + "time" +) + +type TestNetwork struct { + connections map[string]*TestNetworkConnection + dialer Dialer + maxinbound int +} + +func NewTestNetwork(maxinbound int) *TestNetwork { + connections := make(map[string]*TestNetworkConnection) + return &TestNetwork{ + connections: connections, + dialer: &TestDialer{connections}, + maxinbound: maxinbound, + } +} + +func (self *TestNetwork) Dialer(addr net.Addr) (Dialer, error) { + return self.dialer, nil +} + +func (self *TestNetwork) Listener(addr net.Addr) (net.Listener, error) { + return &TestListener{ + connections: self.connections, + addr: addr, + max: self.maxinbound, + }, nil +} + +func (self *TestNetwork) Start() error { + return nil +} + +func (self *TestNetwork) NewAddr(string, int) (addr net.Addr, err error) { + return +} + +func (self *TestNetwork) ParseAddr(string) (addr net.Addr, err error) { + return +} + +type TestAddr struct { + name string +} + +func (self *TestAddr) String() string { + return self.name +} + +func (*TestAddr) Network() string { + return "test" +} + +type TestDialer struct { + connections map[string]*TestNetworkConnection +} + +func (self *TestDialer) Dial(network string, addr string) (conn net.Conn, err error) { + address := &TestAddr{addr} + tconn := NewTestNetworkConnection(address) + self.connections[addr] = tconn + conn = net.Conn(tconn) + return +} + +type TestListener struct { + connections map[string]*TestNetworkConnection + addr net.Addr + max int + i int +} + +func (self *TestListener) Accept() (conn net.Conn, err error) { + self.i++ + if self.i > self.max { + err = fmt.Errorf("no more") + } else { + addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)} + tconn := NewTestNetworkConnection(addr) + key := tconn.RemoteAddr().String() + self.connections[key] = tconn + conn = net.Conn(tconn) + fmt.Printf("accepted connection from: %v \n", addr) + } + return +} + +func (self *TestListener) Close() error { + return nil +} + +func (self *TestListener) Addr() net.Addr { + return self.addr +} + +func SetupTestServer(handlers Handlers) (network *TestNetwork, server *Server) { + network = NewTestNetwork(1) + addr := &TestAddr{"test:30303"} + identity := NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey") + maxPeers := 2 + if handlers == nil { + handlers = make(Handlers) + } + blackist := NewBlacklist() + server = New(network, addr, identity, handlers, maxPeers, blackist) + fmt.Println(server.identity.Pubkey()) + return +} + +func TestServerListener(t *testing.T) { + network, server := SetupTestServer(nil) + server.Start(true, false) + time.Sleep(10 * time.Millisecond) + server.Stop() + peer1, ok := network.connections["inboundpeer-1"] + if !ok { + t.Error("not found inbound peer 1") + } else { + fmt.Printf("out: %v\n", peer1.Out) + if len(peer1.Out) != 2 { + t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out)) + } + } + +} + +func TestServerDialer(t *testing.T) { + network, server := SetupTestServer(nil) + server.Start(false, true) + server.peerConnect <- &TestAddr{"outboundpeer-1"} + time.Sleep(10 * time.Millisecond) + server.Stop() + peer1, ok := network.connections["outboundpeer-1"] + if !ok { + t.Error("not found outbound peer 1") + } else { + fmt.Printf("out: %v\n", peer1.Out) + if len(peer1.Out) != 2 { + t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out)) + } + } +} + +func TestServerBroadcast(t *testing.T) { + handlers := make(Handlers) + testProtocol := &TestProtocol{Msgs: []*Msg{}} + handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } + network, server := SetupTestServer(handlers) + server.Start(true, true) + server.peerConnect <- &TestAddr{"outboundpeer-1"} + time.Sleep(10 * time.Millisecond) + msg, _ := NewMsg(0) + server.Broadcast("", msg) + packet := Packet(0, 0) + time.Sleep(10 * time.Millisecond) + server.Stop() + peer1, ok := network.connections["outboundpeer-1"] + if !ok { + t.Error("not found outbound peer 1") + } else { + fmt.Printf("out: %v\n", peer1.Out) + if len(peer1.Out) != 3 { + t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out)) + } else { + if bytes.Compare(peer1.Out[1], packet) != 0 { + t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet) + } + } + } + peer2, ok := network.connections["inboundpeer-1"] + if !ok { + t.Error("not found inbound peer 2") + } else { + fmt.Printf("out: %v\n", peer2.Out) + if len(peer1.Out) != 3 { + t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out)) + } else { + if bytes.Compare(peer2.Out[1], packet) != 0 { + t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet) + } + } + } +} + +func TestServerPeersMessage(t *testing.T) { + handlers := make(Handlers) + _, server := SetupTestServer(handlers) + server.Start(true, true) + defer server.Stop() + server.peerConnect <- &TestAddr{"outboundpeer-1"} + time.Sleep(10 * time.Millisecond) + peersMsg, err := server.PeersMessage() + fmt.Println(peersMsg) + if err != nil { + t.Errorf("expect no error, got %v", err) + } + if c := server.PeerCount(); c != 2 { + t.Errorf("expect 2 peers, got %v", c) + } +} |