From 7ee55d0963555a1dfb212f0fb5c2ee59bedfb221 Mon Sep 17 00:00:00 2001 From: Jimmy Hu Date: Sun, 30 Sep 2018 00:01:12 +0800 Subject: test: tcp handshake (#151) --- core/test/tcp-transport.go | 224 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 175 insertions(+), 49 deletions(-) (limited to 'core') diff --git a/core/test/tcp-transport.go b/core/test/tcp-transport.go index 6e3ddfb..e1b73f0 100644 --- a/core/test/tcp-transport.go +++ b/core/test/tcp-transport.go @@ -72,6 +72,17 @@ func parsePeerInfo(info string) (key crypto.PublicKey, conn string) { return } +var ( + // ErrTCPHandShakeFail is reported if the tcp handshake fails. + ErrTCPHandShakeFail = fmt.Errorf("tcp handshake fail") + + // ErrConnectToUnexpectedPeer is reported if connect to unexpected peer. + ErrConnectToUnexpectedPeer = fmt.Errorf("connect to unexpected peer") + + // ErrMessageOverflow is reported if the message is too long. + ErrMessageOverflow = fmt.Errorf("message size overflow") +) + // TCPTransport implements Transport interface via TCP connection. type TCPTransport struct { peerType TransportPeerType @@ -110,6 +121,69 @@ func NewTCPTransport( } } +const handshakeMsg = "Welcome to DEXON network for test." + +func (t *TCPTransport) serverHandshake(conn net.Conn) ( + nID types.NodeID, err error) { + conn.SetDeadline(time.Now().Add(3 * time.Second)) + msg := &tcpMessage{ + NodeID: t.nID, + Type: "handshake", + Info: handshakeMsg, + } + var payload []byte + payload, err = json.Marshal(msg) + if err != nil { + return + } + if err = t.write(conn, payload); err != nil { + return + } + if payload, err = t.read(conn); err != nil { + return + } + if err = json.Unmarshal(payload, &msg); err != nil { + return + } + if msg.Type != "handshake-ack" || msg.Info != handshakeMsg { + err = ErrTCPHandShakeFail + return + } + nID = msg.NodeID + return +} + +func (t *TCPTransport) clientHandshake(conn net.Conn) ( + nID types.NodeID, err error) { + conn.SetDeadline(time.Now().Add(3 * time.Second)) + var payload []byte + if payload, err = t.read(conn); err != nil { + return + } + msg := &tcpMessage{} + if err = json.Unmarshal(payload, &msg); err != nil { + return + } + if msg.Type != "handshake" || msg.Info != handshakeMsg { + err = ErrTCPHandShakeFail + return + } + nID = msg.NodeID + msg = &tcpMessage{ + NodeID: t.nID, + Type: "handshake-ack", + Info: handshakeMsg, + } + payload, err = json.Marshal(msg) + if err != nil { + return + } + if err = t.write(conn, payload); err != nil { + return + } + return +} + // Send implements Transport.Send method. func (t *TCPTransport) Send( endpoint types.NodeID, msg interface{}) (err error) { @@ -176,6 +250,33 @@ func (t *TCPTransport) Peers() (peers []crypto.PublicKey) { return } +func (t *TCPTransport) write(conn net.Conn, b []byte) (err error) { + if len(b) > math.MaxUint32 { + return ErrMessageOverflow + } + msgLength := make([]byte, 4) + binary.LittleEndian.PutUint32(msgLength, uint32(len(b))) + if _, err = conn.Write(msgLength); err != nil { + return + } + if _, err = conn.Write(b); err != nil { + return + } + return +} + +func (t *TCPTransport) read(conn net.Conn) (b []byte, err error) { + msgLength := make([]byte, 4) + if _, err = io.ReadFull(conn, msgLength); err != nil { + return + } + b = make([]byte, int(binary.LittleEndian.Uint32(msgLength))) + if _, err = io.ReadFull(conn, b); err != nil { + return + } + return +} + func (t *TCPTransport) marshalMessage( msg interface{}) (payload []byte, err error) { @@ -264,10 +365,8 @@ func (t *TCPTransport) connReader(conn net.Conn) { }() var ( - msgLengthInByte [4]byte - msgLength uint32 - err error - payload = make([]byte, 4096) + err error + payload []byte ) checkErr := func(err error) (toBreak bool) { @@ -299,26 +398,13 @@ Loop: panic(err) } // Read message length. - if _, err = io.ReadFull(conn, msgLengthInByte[:]); err != nil { - if checkErr(err) { - break - } - continue - } - msgLength = binary.LittleEndian.Uint32(msgLengthInByte[:]) - // Resize buffer - if msgLength > uint32(len(payload)) { - payload = make([]byte, msgLength) - } - buff := payload[:msgLength] - // Read the message in bytes. - if _, err = io.ReadFull(conn, buff); err != nil { + if payload, err = t.read(conn); err != nil { if checkErr(err) { break } continue } - peerType, from, msg, err := t.unmarshalMessage(buff) + peerType, from, msg, err := t.unmarshalMessage(payload) if err != nil { panic(err) } @@ -332,6 +418,11 @@ Loop: // connWriter is a writer routine to write to TCP connection. func (t *TCPTransport) connWriter(conn net.Conn) chan<- []byte { + // Disable write deadline. + if err := conn.SetWriteDeadline(time.Time{}); err != nil { + panic(err) + } + ch := make(chan []byte, 1000) go func() { defer func() { @@ -351,16 +442,7 @@ func (t *TCPTransport) connWriter(conn net.Conn) chan<- []byte { return case msg := <-ch: // Send message length in uint32. - var msgLength [4]byte - if len(msg) > math.MaxUint32 { - panic(fmt.Errorf("message size overflow")) - } - binary.LittleEndian.PutUint32(msgLength[:], uint32(len(msg))) - if _, err := conn.Write(msgLength[:]); err != nil { - panic(err) - } - // Send the payload. - if _, err := conn.Write(msg); err != nil { + if err := t.write(conn, msg); err != nil { panic(err) } } @@ -371,7 +453,11 @@ func (t *TCPTransport) connWriter(conn net.Conn) chan<- []byte { // listenerRoutine is a routine to accept incoming request for TCP connection. func (t *TCPTransport) listenerRoutine(listener *net.TCPListener) { + closed := false defer func() { + if closed { + return + } if err := listener.Close(); err != nil { panic(err) } @@ -386,6 +472,11 @@ func (t *TCPTransport) listenerRoutine(listener *net.TCPListener) { listener.SetDeadline(time.Now().Add(5 * time.Second)) conn, err := listener.Accept() if err != nil { + // Check if the connection is closed. + if strings.Contains(err.Error(), "use of closed network connection") { + closed = true + return + } // Check if timeout error. nErr, ok := err.(*net.OpError) if !ok { @@ -396,6 +487,10 @@ func (t *TCPTransport) listenerRoutine(listener *net.TCPListener) { } continue } + if _, err := t.serverHandshake(conn); err != nil { + fmt.Println(err) + continue + } go t.connReader(conn) } } @@ -420,6 +515,15 @@ func (t *TCPTransport) buildConnectionsToPeers() (err error) { err = localErr return } + serverID, e := t.clientHandshake(conn) + if e != nil { + err = e + return + } + if nID != serverID { + err = ErrConnectToUnexpectedPeer + return + } t.peersLock.Lock() defer t.peersLock.Unlock() @@ -479,35 +583,57 @@ func (t *TCPTransportClient) Join( addr = net.JoinHostPort("127.0.0.1", strconv.Itoa(t.localPort)) ln, err = net.Listen("tcp", addr) if err == nil { - break - } - if !t.local { - return - } - // In local-tcp, retry with other port when the address is in use. - operr, ok := err.(*net.OpError) - if !ok { - panic(err) - } - oserr, ok := operr.Err.(*os.SyscallError) - if !ok { - panic(operr) - } - errno, ok := oserr.Err.(syscall.Errno) - if !ok { - panic(oserr) + go t.listenerRoutine(ln.(*net.TCPListener)) + // It is possible to listen on the same port in some platform. + // Check if this one is actually listening. + testConn, e := net.Dial("tcp", addr) + if e != nil { + err = e + return + } + nID, e := t.clientHandshake(testConn) + if e != nil { + err = e + return + } + if nID == t.nID { + break + } + ln.Close() } - if errno != syscall.EADDRINUSE { - panic(errno) + if err != nil { + if !t.local { + return + } + // In local-tcp, retry with other port when the address is in use. + operr, ok := err.(*net.OpError) + if !ok { + panic(err) + } + oserr, ok := operr.Err.(*os.SyscallError) + if !ok { + panic(operr) + } + errno, ok := oserr.Err.(syscall.Errno) + if !ok { + panic(oserr) + } + if errno != syscall.EADDRINUSE { + panic(errno) + } } // The port is used, generate another port randomly. t.localPort = 1024 + rand.Int()%1024 } - go t.listenerRoutine(ln.(*net.TCPListener)) + serverConn, err := net.Dial("tcp", serverEndpoint.(string)) if err != nil { return } + _, err = t.clientHandshake(serverConn) + if err != nil { + return + } t.serverWriteChannel = t.connWriter(serverConn) if t.local { conn = addr -- cgit v1.2.3