aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJimmy Hu <jimmy.hu@dexon.org>2018-09-30 00:01:12 +0800
committerGitHub <noreply@github.com>2018-09-30 00:01:12 +0800
commit7ee55d0963555a1dfb212f0fb5c2ee59bedfb221 (patch)
tree1cfde11678b27a319970b02fd2e829e5432225a3
parent34c7476e060e6f537a5af15498cc79bb287865d2 (diff)
downloaddexon-consensus-7ee55d0963555a1dfb212f0fb5c2ee59bedfb221.tar
dexon-consensus-7ee55d0963555a1dfb212f0fb5c2ee59bedfb221.tar.gz
dexon-consensus-7ee55d0963555a1dfb212f0fb5c2ee59bedfb221.tar.bz2
dexon-consensus-7ee55d0963555a1dfb212f0fb5c2ee59bedfb221.tar.lz
dexon-consensus-7ee55d0963555a1dfb212f0fb5c2ee59bedfb221.tar.xz
dexon-consensus-7ee55d0963555a1dfb212f0fb5c2ee59bedfb221.tar.zst
dexon-consensus-7ee55d0963555a1dfb212f0fb5c2ee59bedfb221.zip
test: tcp handshake (#151)
-rw-r--r--core/test/tcp-transport.go224
1 files changed, 175 insertions, 49 deletions
diff --git a/core/test/tcp-transport.go b/core/test/tcp-transport.go
index 6e3ddfb..e1b73f0 100644
--- a/core/test/tcp-transport.go
+++ b/core/test/tcp-transport.go
@@ -72,6 +72,17 @@ func parsePeerInfo(info string) (key crypto.PublicKey, conn string) {
return
}
+var (
+ // ErrTCPHandShakeFail is reported if the tcp handshake fails.
+ ErrTCPHandShakeFail = fmt.Errorf("tcp handshake fail")
+
+ // ErrConnectToUnexpectedPeer is reported if connect to unexpected peer.
+ ErrConnectToUnexpectedPeer = fmt.Errorf("connect to unexpected peer")
+
+ // ErrMessageOverflow is reported if the message is too long.
+ ErrMessageOverflow = fmt.Errorf("message size overflow")
+)
+
// TCPTransport implements Transport interface via TCP connection.
type TCPTransport struct {
peerType TransportPeerType
@@ -110,6 +121,69 @@ func NewTCPTransport(
}
}
+const handshakeMsg = "Welcome to DEXON network for test."
+
+func (t *TCPTransport) serverHandshake(conn net.Conn) (
+ nID types.NodeID, err error) {
+ conn.SetDeadline(time.Now().Add(3 * time.Second))
+ msg := &tcpMessage{
+ NodeID: t.nID,
+ Type: "handshake",
+ Info: handshakeMsg,
+ }
+ var payload []byte
+ payload, err = json.Marshal(msg)
+ if err != nil {
+ return
+ }
+ if err = t.write(conn, payload); err != nil {
+ return
+ }
+ if payload, err = t.read(conn); err != nil {
+ return
+ }
+ if err = json.Unmarshal(payload, &msg); err != nil {
+ return
+ }
+ if msg.Type != "handshake-ack" || msg.Info != handshakeMsg {
+ err = ErrTCPHandShakeFail
+ return
+ }
+ nID = msg.NodeID
+ return
+}
+
+func (t *TCPTransport) clientHandshake(conn net.Conn) (
+ nID types.NodeID, err error) {
+ conn.SetDeadline(time.Now().Add(3 * time.Second))
+ var payload []byte
+ if payload, err = t.read(conn); err != nil {
+ return
+ }
+ msg := &tcpMessage{}
+ if err = json.Unmarshal(payload, &msg); err != nil {
+ return
+ }
+ if msg.Type != "handshake" || msg.Info != handshakeMsg {
+ err = ErrTCPHandShakeFail
+ return
+ }
+ nID = msg.NodeID
+ msg = &tcpMessage{
+ NodeID: t.nID,
+ Type: "handshake-ack",
+ Info: handshakeMsg,
+ }
+ payload, err = json.Marshal(msg)
+ if err != nil {
+ return
+ }
+ if err = t.write(conn, payload); err != nil {
+ return
+ }
+ return
+}
+
// Send implements Transport.Send method.
func (t *TCPTransport) Send(
endpoint types.NodeID, msg interface{}) (err error) {
@@ -176,6 +250,33 @@ func (t *TCPTransport) Peers() (peers []crypto.PublicKey) {
return
}
+func (t *TCPTransport) write(conn net.Conn, b []byte) (err error) {
+ if len(b) > math.MaxUint32 {
+ return ErrMessageOverflow
+ }
+ msgLength := make([]byte, 4)
+ binary.LittleEndian.PutUint32(msgLength, uint32(len(b)))
+ if _, err = conn.Write(msgLength); err != nil {
+ return
+ }
+ if _, err = conn.Write(b); err != nil {
+ return
+ }
+ return
+}
+
+func (t *TCPTransport) read(conn net.Conn) (b []byte, err error) {
+ msgLength := make([]byte, 4)
+ if _, err = io.ReadFull(conn, msgLength); err != nil {
+ return
+ }
+ b = make([]byte, int(binary.LittleEndian.Uint32(msgLength)))
+ if _, err = io.ReadFull(conn, b); err != nil {
+ return
+ }
+ return
+}
+
func (t *TCPTransport) marshalMessage(
msg interface{}) (payload []byte, err error) {
@@ -264,10 +365,8 @@ func (t *TCPTransport) connReader(conn net.Conn) {
}()
var (
- msgLengthInByte [4]byte
- msgLength uint32
- err error
- payload = make([]byte, 4096)
+ err error
+ payload []byte
)
checkErr := func(err error) (toBreak bool) {
@@ -299,26 +398,13 @@ Loop:
panic(err)
}
// Read message length.
- if _, err = io.ReadFull(conn, msgLengthInByte[:]); err != nil {
- if checkErr(err) {
- break
- }
- continue
- }
- msgLength = binary.LittleEndian.Uint32(msgLengthInByte[:])
- // Resize buffer
- if msgLength > uint32(len(payload)) {
- payload = make([]byte, msgLength)
- }
- buff := payload[:msgLength]
- // Read the message in bytes.
- if _, err = io.ReadFull(conn, buff); err != nil {
+ if payload, err = t.read(conn); err != nil {
if checkErr(err) {
break
}
continue
}
- peerType, from, msg, err := t.unmarshalMessage(buff)
+ peerType, from, msg, err := t.unmarshalMessage(payload)
if err != nil {
panic(err)
}
@@ -332,6 +418,11 @@ Loop:
// connWriter is a writer routine to write to TCP connection.
func (t *TCPTransport) connWriter(conn net.Conn) chan<- []byte {
+ // Disable write deadline.
+ if err := conn.SetWriteDeadline(time.Time{}); err != nil {
+ panic(err)
+ }
+
ch := make(chan []byte, 1000)
go func() {
defer func() {
@@ -351,16 +442,7 @@ func (t *TCPTransport) connWriter(conn net.Conn) chan<- []byte {
return
case msg := <-ch:
// Send message length in uint32.
- var msgLength [4]byte
- if len(msg) > math.MaxUint32 {
- panic(fmt.Errorf("message size overflow"))
- }
- binary.LittleEndian.PutUint32(msgLength[:], uint32(len(msg)))
- if _, err := conn.Write(msgLength[:]); err != nil {
- panic(err)
- }
- // Send the payload.
- if _, err := conn.Write(msg); err != nil {
+ if err := t.write(conn, msg); err != nil {
panic(err)
}
}
@@ -371,7 +453,11 @@ func (t *TCPTransport) connWriter(conn net.Conn) chan<- []byte {
// listenerRoutine is a routine to accept incoming request for TCP connection.
func (t *TCPTransport) listenerRoutine(listener *net.TCPListener) {
+ closed := false
defer func() {
+ if closed {
+ return
+ }
if err := listener.Close(); err != nil {
panic(err)
}
@@ -386,6 +472,11 @@ func (t *TCPTransport) listenerRoutine(listener *net.TCPListener) {
listener.SetDeadline(time.Now().Add(5 * time.Second))
conn, err := listener.Accept()
if err != nil {
+ // Check if the connection is closed.
+ if strings.Contains(err.Error(), "use of closed network connection") {
+ closed = true
+ return
+ }
// Check if timeout error.
nErr, ok := err.(*net.OpError)
if !ok {
@@ -396,6 +487,10 @@ func (t *TCPTransport) listenerRoutine(listener *net.TCPListener) {
}
continue
}
+ if _, err := t.serverHandshake(conn); err != nil {
+ fmt.Println(err)
+ continue
+ }
go t.connReader(conn)
}
}
@@ -420,6 +515,15 @@ func (t *TCPTransport) buildConnectionsToPeers() (err error) {
err = localErr
return
}
+ serverID, e := t.clientHandshake(conn)
+ if e != nil {
+ err = e
+ return
+ }
+ if nID != serverID {
+ err = ErrConnectToUnexpectedPeer
+ return
+ }
t.peersLock.Lock()
defer t.peersLock.Unlock()
@@ -479,35 +583,57 @@ func (t *TCPTransportClient) Join(
addr = net.JoinHostPort("127.0.0.1", strconv.Itoa(t.localPort))
ln, err = net.Listen("tcp", addr)
if err == nil {
- break
- }
- if !t.local {
- return
- }
- // In local-tcp, retry with other port when the address is in use.
- operr, ok := err.(*net.OpError)
- if !ok {
- panic(err)
- }
- oserr, ok := operr.Err.(*os.SyscallError)
- if !ok {
- panic(operr)
- }
- errno, ok := oserr.Err.(syscall.Errno)
- if !ok {
- panic(oserr)
+ go t.listenerRoutine(ln.(*net.TCPListener))
+ // It is possible to listen on the same port in some platform.
+ // Check if this one is actually listening.
+ testConn, e := net.Dial("tcp", addr)
+ if e != nil {
+ err = e
+ return
+ }
+ nID, e := t.clientHandshake(testConn)
+ if e != nil {
+ err = e
+ return
+ }
+ if nID == t.nID {
+ break
+ }
+ ln.Close()
}
- if errno != syscall.EADDRINUSE {
- panic(errno)
+ if err != nil {
+ if !t.local {
+ return
+ }
+ // In local-tcp, retry with other port when the address is in use.
+ operr, ok := err.(*net.OpError)
+ if !ok {
+ panic(err)
+ }
+ oserr, ok := operr.Err.(*os.SyscallError)
+ if !ok {
+ panic(operr)
+ }
+ errno, ok := oserr.Err.(syscall.Errno)
+ if !ok {
+ panic(oserr)
+ }
+ if errno != syscall.EADDRINUSE {
+ panic(errno)
+ }
}
// The port is used, generate another port randomly.
t.localPort = 1024 + rand.Int()%1024
}
- go t.listenerRoutine(ln.(*net.TCPListener))
+
serverConn, err := net.Dial("tcp", serverEndpoint.(string))
if err != nil {
return
}
+ _, err = t.clientHandshake(serverConn)
+ if err != nil {
+ return
+ }
t.serverWriteChannel = t.connWriter(serverConn)
if t.local {
conn = addr