aboutsummaryrefslogblamecommitdiffstats
path: root/core/test/tcp-transport.go
blob: a3b9abaa189db011ef2c335c54f22092e51c3d64 (plain) (tree)
1
2
3
4
5
6
7
8
9

                                                    
  
                                                                        
                                                                               


                                                                              
                                                                         


                                                                          

                                                                           
                                                      





                                  
                         








                         
                 



                 


                                                                       

 





                                    

                                                              


                                         

 















                                                                             



                                                        


              










                                                                                


                                                                  
                                
                                    
                       
                                                   










                                                       
                                






                                                               


                                                                   








                                                                 






























































                                                                    

                                         
                                                             












                                                     
                                                        












                                                               
                                       
                                 






                                                             
                                  










                                                 
                                                       






                                                     


                                                           



              


























                                                                    




                                                             
                                                        



                                                           
                                


                              
                                     
                                            

                                             






















                                                                      
                          




                                                             
                                                        









                                                                   
                                                 



                                                                                 


                                                                            

                              
                       


















                                                                                      

                              






























                                                                      
                                                            




                                          
                                                                       












                                                                




                                                                  


















                                                                 
                                                                          









                                                                              
                       
                      


                              













                                                                     




                                                                                              









                                                      



                                                                  








                                                                       
                                       
                                 


                                
                                                        








                                                                                      








                                                                


                                                  

                                                                     













                                                                       
                                




                                         


                                                                          

















                                                                               






                                                        
                                

             
                                                                               

                                                 
















                                                                                      
                 



















                                                                                          



                                                                    
 



                                                                   



                                              
                                                       










                                                                      
                              

                                                      


                        
                                              
                            

                                                        

                                                                







                                                   





                                                                          

                                     



                                                              

                                    

                                              


                                                        


                                                                               


                     



                                                                     













                                                                     



                                            
                                   
                                                                    


                                               
                                           











                                                                      
                                                                                







                                                               
                                                                        

                                                              
                                                  
             




                                                                         
                 


                                                                                





                                                       
                                                             
                                                       



                                             


                                                          
                                                     


                                                 
                                                  








                                                                               
                                                                 

                                                                                     
                                                
                                                     






                                                                          

              
// Copyright 2018 The dexon-consensus Authors
// This file is part of the dexon-consensus library.
//
// The dexon-consensus library is free software: you can redistribute it
// and/or modify it under the terms of the GNU Lesser General Public License as
// published by the Free Software Foundation, either version 3 of the License,
// or (at your option) any later version.
//
// The dexon-consensus library is distributed in the hope that it will be
// useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser
// General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the dexon-consensus library. If not, see
// <http://www.gnu.org/licenses/>.

package test

import (
    "context"
    "encoding/base64"
    "encoding/binary"
    "encoding/json"
    "fmt"
    "io"
    "math"
    "math/rand"
    "net"
    "os"
    "strconv"
    "strings"
    "sync"
    "syscall"
    "time"

    "github.com/dexon-foundation/dexon-consensus/core/crypto"
    "github.com/dexon-foundation/dexon-consensus/core/crypto/ecdsa"
    "github.com/dexon-foundation/dexon-consensus/core/types"
)

// tcpPeerRecord keeps what the transport knows about one remote peer.
type tcpPeerRecord struct {
    // conn is the peer's address string; it is what gets passed to
    // net.Dial("tcp", ...) when connections are built.
    conn        string
    // sendChannel is the outgoing queue returned by connWriter. It stays nil
    // until a connection to this peer has been established.
    sendChannel chan<- []byte
    // pubKey is the peer's public key, decoded from its peer info string.
    pubKey      crypto.PublicKey
}

// tcpMessage is the general message between peers and server.
// It is used for the handshake and for transport-level control messages.
type tcpMessage struct {
    // NodeID identifies the sender of the message.
    NodeID types.NodeID `json:"nid"`
    // Type names the message kind. Values used in this file are "handshake",
    // "handshake-ack", "conn", "conn-ready" and "all-ready".
    Type   string       `json:"type"`
    // Info is a free-form string whose meaning depends on Type (the greeting
    // text for handshakes, the peer info string for "conn"). Note that it is
    // serialized under the JSON key "conn".
    Info   string       `json:"conn"`
}

// buildPeerInfo packs a peer's connection string and its public key into one
// string of the form "<conn>;<base64 of key bytes>". parsePeerInfo reverses
// the encoding.
func buildPeerInfo(pubKey crypto.PublicKey, conn string) string {
    encodedKey := base64.StdEncoding.EncodeToString(pubKey.Bytes())
    return conn + ";" + encodedKey
}

// parsePeerInfo is the inverse of buildPeerInfo: it splits info on ';', takes
// the first field as the connection string and decodes the second field from
// base64 into a public key.
//
// It panics when the second field is missing, is not valid base64, or does
// not decode into a public key.
func parsePeerInfo(info string) (key crypto.PublicKey, conn string) {
    fields := strings.Split(info, ";")
    conn = fields[0]
    rawKey, decodeErr := base64.StdEncoding.DecodeString(fields[1])
    if decodeErr != nil {
        panic(decodeErr)
    }
    var keyErr error
    if key, keyErr = ecdsa.NewPublicKeyFromByteSlice(rawKey); keyErr != nil {
        panic(keyErr)
    }
    return
}

// Sentinel errors reported by the TCP transport.
var (
    // ErrTCPHandShakeFail is reported if the tcp handshake fails, i.e. the
    // message received during the handshake has an unexpected type or
    // greeting text.
    ErrTCPHandShakeFail = fmt.Errorf("tcp handshake fail")

    // ErrConnectToUnexpectedPeer is reported if the node ID announced by the
    // remote side during the handshake differs from the peer we dialed.
    ErrConnectToUnexpectedPeer = fmt.Errorf("connect to unexpected peer")

    // ErrMessageOverflow is reported if the message is too long to be
    // described by the 4-byte length prefix.
    ErrMessageOverflow = fmt.Errorf("message size overflow")
)

// TCPTransport implements Transport interface via TCP connection.
// It holds the state shared by the client and server flavours below.
type TCPTransport struct {
    // peerType is stamped on every outgoing message.
    peerType    TransportPeerType
    // nID is the node ID derived from pubKey.
    nID         types.NodeID
    // pubKey is the public key of this node.
    pubKey      crypto.PublicKey
    // localPort is the TCP port this node listens on.
    localPort   int
    // peers maps node IDs to what is known about each peer.
    peers       map[types.NodeID]*tcpPeerRecord
    // peersLock is taken by Send, Broadcast, Close and the connection
    // builder when they access peers.
    peersLock   sync.RWMutex
    // recvChannel delivers every decoded incoming message to the user.
    recvChannel chan *TransportEnvelope
    // ctx and cancel are used to tell background routines to stop.
    ctx         context.Context
    cancel      context.CancelFunc
    // latency, when non-nil, supplies a delay applied before each send.
    latency     LatencyModel
    // marshaller, when non-nil, encodes/decodes message types this file does
    // not handle itself.
    marshaller  Marshaller
}

// NewTCPTransport constructs a TCPTransport instance.
//
// The node ID is derived from pubKey, the peer table starts out empty, the
// receive channel is buffered for 1000 envelopes, and a fresh cancellable
// context is created for the background routines.
func NewTCPTransport(
    peerType TransportPeerType,
    pubKey crypto.PublicKey,
    latency LatencyModel,
    marshaller Marshaller,
    localPort int) *TCPTransport {

    transport := &TCPTransport{
        peerType:    peerType,
        pubKey:      pubKey,
        localPort:   localPort,
        latency:     latency,
        marshaller:  marshaller,
        nID:         types.NewNodeID(pubKey),
        peers:       make(map[types.NodeID]*tcpPeerRecord),
        recvChannel: make(chan *TransportEnvelope, 1000),
    }
    transport.ctx, transport.cancel = context.WithCancel(context.Background())
    return transport
}

// handshakeMsg is the greeting text both sides of a new connection must echo
// in the Info field of their handshake messages.
const handshakeMsg = "Welcome to DEXON network for test."

// serverHandshake runs the accepting side of the handshake on conn: it sends
// a "handshake" message carrying our node ID and the greeting text, then
// waits for a "handshake-ack" carrying the same greeting.
//
// A 3 second deadline is put on the whole exchange. On success the node ID
// announced by the remote side is returned. ErrTCPHandShakeFail is returned
// when the reply has the wrong type or greeting; I/O and JSON errors are
// returned as they are.
func (t *TCPTransport) serverHandshake(conn net.Conn) (
    nID types.NodeID, err error) {
    if err = conn.SetDeadline(time.Now().Add(3 * time.Second)); err != nil {
        return
    }
    var payload []byte
    payload, err = json.Marshal(&tcpMessage{
        NodeID: t.nID,
        Type:   "handshake",
        Info:   handshakeMsg,
    })
    if err != nil {
        return
    }
    if err = t.write(conn, payload); err != nil {
        return
    }
    if payload, err = t.read(conn); err != nil {
        return
    }
    // Decode into a fresh value instead of on top of the message we just
    // sent: otherwise a reply that omits a field would be validated against
    // our own value, and a JSON "null" would nil the pointer and crash below.
    var ack tcpMessage
    if err = json.Unmarshal(payload, &ack); err != nil {
        return
    }
    if ack.Type != "handshake-ack" || ack.Info != handshakeMsg {
        err = ErrTCPHandShakeFail
        return
    }
    nID = ack.NodeID
    return
}

// clientHandshake runs the dialing side of the handshake on conn: it waits
// for a "handshake" message carrying the greeting text, then answers with a
// "handshake-ack" carrying our node ID and the same greeting.
//
// A 3 second deadline is put on the whole exchange. On success the node ID
// announced by the remote side is returned. ErrTCPHandShakeFail is returned
// when the greeting has the wrong type or text; I/O and JSON errors are
// returned as they are.
func (t *TCPTransport) clientHandshake(conn net.Conn) (
    nID types.NodeID, err error) {
    if err = conn.SetDeadline(time.Now().Add(3 * time.Second)); err != nil {
        return
    }
    var payload []byte
    if payload, err = t.read(conn); err != nil {
        return
    }
    // Decode into a value (not through a pointer-to-pointer) so that a JSON
    // "null" greeting cannot leave us with a nil message to dereference.
    var greeting tcpMessage
    if err = json.Unmarshal(payload, &greeting); err != nil {
        return
    }
    if greeting.Type != "handshake" || greeting.Info != handshakeMsg {
        err = ErrTCPHandShakeFail
        return
    }
    nID = greeting.NodeID
    payload, err = json.Marshal(&tcpMessage{
        NodeID: t.nID,
        Type:   "handshake-ack",
        Info:   handshakeMsg,
    })
    if err != nil {
        return
    }
    err = t.write(conn, payload)
    return
}

// Send implements Transport.Send method.
//
// The message is marshalled synchronously, so encoding errors are returned
// to the caller. Delivery happens in a background goroutine: after the
// optional latency delay the payload is queued on the send channel of
// endpoint. If endpoint is unknown at that moment (never registered, or the
// transport has been closed and its peer table reset) or has no connection
// yet, the payload is dropped instead of crashing the process.
func (t *TCPTransport) Send(
    endpoint types.NodeID, msg interface{}) (err error) {

    payload, err := t.marshalMessage(msg)
    if err != nil {
        return
    }
    go func() {
        if t.latency != nil {
            time.Sleep(t.latency.Delay())
        }

        // Only hold the lock for the lookup; the channel send below may
        // block and must not keep Close (which needs the write lock) waiting.
        var ch chan<- []byte
        t.peersLock.RLock()
        if rec, exists := t.peers[endpoint]; exists && rec != nil {
            ch = rec.sendChannel
        }
        t.peersLock.RUnlock()

        if ch == nil {
            return
        }
        ch <- payload
    }()
    return
}

// Broadcast implements Transport.Broadcast method.
//
// The message is marshalled once and then queued, from one goroutine per
// peer, on the send channel of every known peer except ourselves. Each
// goroutine applies the optional latency delay before queueing.
func (t *TCPTransport) Broadcast(msg interface{}) (err error) {
    payload, err := t.marshalMessage(msg)
    if err != nil {
        return
    }

    // deliver pushes the shared payload to a single peer.
    deliver := func(ch chan<- []byte) {
        if t.latency != nil {
            time.Sleep(t.latency.Delay())
        }
        ch <- payload
    }

    t.peersLock.RLock()
    defer t.peersLock.RUnlock()

    for peerID, rec := range t.peers {
        if peerID != t.nID {
            go deliver(rec.sendChannel)
        }
    }
    return
}

// Close implements Transport.Close method.
//
// It cancels the transport's context, replaces the peer table with an empty
// one and closes the receive channel. It always returns nil.
//
// NOTE(review): recvChannel is closed and set to nil here, so a second call
// would close a nil channel; reader routines that are still running may also
// try to deliver to it — confirm callers only close once, after traffic has
// stopped.
func (t *TCPTransport) Close() (err error) {
    // Tell all routines raised by us to die.
    t.cancel()
    // Reset peers.
    t.peersLock.Lock()
    defer t.peersLock.Unlock()
    t.peers = make(map[types.NodeID]*tcpPeerRecord)
    // Tell our user that this channel is closed.
    close(t.recvChannel)
    t.recvChannel = nil
    return
}

// Peers implements Transport.Peers method.
//
// It returns the public keys of all peers currently in the peer table, in no
// particular order; the result is nil when the table is empty. The table is
// read under peersLock, like every other accessor of it, so the call is safe
// while connections are being built or the transport is being closed.
func (t *TCPTransport) Peers() (peers []crypto.PublicKey) {
    t.peersLock.RLock()
    defer t.peersLock.RUnlock()

    for _, rec := range t.peers {
        peers = append(peers, rec.pubKey)
    }
    return
}

// write sends b on conn as one frame: a 4-byte little-endian length prefix
// followed by the payload itself.
//
// ErrMessageOverflow is returned, and nothing is written, when the payload
// length does not fit in the prefix. Errors from conn.Write are returned
// unchanged.
func (t *TCPTransport) write(conn net.Conn, b []byte) (err error) {
    // Compare as uint64: on platforms where int is 32 bits the untyped
    // constant math.MaxUint32 cannot be converted to int, and the direct
    // comparison against len(b) would not compile.
    if uint64(len(b)) > math.MaxUint32 {
        return ErrMessageOverflow
    }
    var header [4]byte
    binary.LittleEndian.PutUint32(header[:], uint32(len(b)))
    if _, err = conn.Write(header[:]); err != nil {
        return
    }
    _, err = conn.Write(b)
    return
}

// read receives one frame from conn: a 4-byte little-endian length prefix
// followed by that many payload bytes, which are returned in b.
//
// Any error from reading the prefix or the payload is returned unchanged; if
// the payload read fails, b still holds the buffer allocated for it.
func (t *TCPTransport) read(conn net.Conn) (b []byte, err error) {
    var header [4]byte
    if _, err = io.ReadFull(conn, header[:]); err != nil {
        return
    }
    size := binary.LittleEndian.Uint32(header[:])
    b = make([]byte, int(size))
    _, err = io.ReadFull(conn, b)
    return
}

// marshalMessage wraps msg in a JSON carrier that records our peer type, our
// node ID and a type tag, and returns the encoded carrier.
//
// Peer lists (map[types.NodeID]string) are tagged "peerlist" and *tcpMessage
// values "trans-msg"; both are encoded directly. Every other message is
// handed to the user supplied marshaller, which provides the tag and the raw
// JSON payload. Without a marshaller such messages yield an error.
func (t *TCPTransport) marshalMessage(
    msg interface{}) (payload []byte, err error) {

    carrier := struct {
        PeerType TransportPeerType `json:"peer_type"`
        From     types.NodeID      `json:"from"`
        Type     string            `json:"type"`
        Payload  interface{}       `json:"payload"`
    }{
        PeerType: t.peerType,
        From:     t.nID,
        Payload:  msg,
    }
    switch msg.(type) {
    case map[types.NodeID]string:
        carrier.Type = "peerlist"
    case *tcpMessage:
        carrier.Type = "trans-msg"
    default:
        if t.marshaller == nil {
            return nil, fmt.Errorf("unknown msg type: %v", msg)
        }
        // Delegate to user defined marshaller.
        typeName, raw, marshalErr := t.marshaller.Marshal(msg)
        if marshalErr != nil {
            return nil, marshalErr
        }
        carrier.Type = typeName
        carrier.Payload = json.RawMessage(raw)
    }
    return json.Marshal(carrier)
}

// unmarshalMessage is the inverse of marshalMessage: it decodes the JSON
// carrier in payload and returns the sender's peer type, the sender's node
// ID and the decoded message.
//
// "peerlist" payloads decode to map[types.NodeID]string and "trans-msg"
// payloads to *tcpMessage. Any other type tag is passed to the user supplied
// marshaller; without one an error is returned. When only the inner payload
// fails to decode, peerType and from are still populated.
func (t *TCPTransport) unmarshalMessage(
    payload []byte) (
    peerType TransportPeerType,
    from types.NodeID,
    msg interface{},
    err error) {

    var carrier struct {
        PeerType TransportPeerType `json:"peer_type"`
        From     types.NodeID      `json:"from"`
        Type     string            `json:"type"`
        Payload  json.RawMessage   `json:"payload"`
    }
    if err = json.Unmarshal(payload, &carrier); err != nil {
        return
    }
    peerType, from = carrier.PeerType, carrier.From
    switch carrier.Type {
    case "peerlist":
        var peerList map[types.NodeID]string
        if err = json.Unmarshal(carrier.Payload, &peerList); err == nil {
            msg = peerList
        }
    case "trans-msg":
        transMsg := &tcpMessage{}
        if err = json.Unmarshal(carrier.Payload, transMsg); err == nil {
            msg = transMsg
        }
    default:
        if t.marshaller != nil {
            msg, err = t.marshaller.Unmarshal(carrier.Type, carrier.Payload)
        } else {
            err = fmt.Errorf("unknown msg type: %v", carrier.Type)
        }
    }
    return
}

// connReader is a reader routine to read from a TCP connection.
//
// It loops until the transport's context is cancelled or the peer closes the
// connection (io.EOF), decoding each received frame and forwarding it to
// recvChannel. Reads use a 2 second deadline so the cancellation check runs
// regularly; a read timeout simply restarts the loop. Any other read error,
// and any decoding error, panics. The connection is closed on exit.
func (t *TCPTransport) connReader(conn net.Conn) {
    defer func() {
        if closeErr := conn.Close(); closeErr != nil {
            panic(closeErr)
        }
    }()

    // shouldStop reports whether a read error means the connection is done.
    // Timeouts are tolerated; anything that is neither EOF nor a timeout is
    // treated as fatal.
    shouldStop := func(readErr error) bool {
        if readErr == io.EOF {
            return true
        }
        opErr, isOpErr := readErr.(*net.OpError)
        if !isOpErr || !opErr.Timeout() {
            panic(readErr)
        }
        return false
    }

    for {
        select {
        case <-t.ctx.Done():
            return
        default:
        }
        // Add timeout when reading to check if shutdown.
        deadline := time.Now().Add(2 * time.Second)
        if err := conn.SetReadDeadline(deadline); err != nil {
            panic(err)
        }
        // Read one length-prefixed frame.
        payload, err := t.read(conn)
        if err != nil {
            if shouldStop(err) {
                return
            }
            continue
        }
        peerType, from, msg, err := t.unmarshalMessage(payload)
        if err != nil {
            panic(err)
        }
        t.recvChannel <- &TransportEnvelope{
            PeerType: peerType,
            From:     from,
            Msg:      msg,
        }
    }
}

// connWriter is a writer routine to write to TCP connection.
//
// It clears the write deadline on conn, starts a goroutine that writes every
// payload queued on the returned channel as one frame, and returns that
// channel (buffered for 1000 payloads). The goroutine stops when the
// transport's context is cancelled; on exit it closes both the channel and
// the connection. Write and close failures panic.
func (t *TCPTransport) connWriter(conn net.Conn) chan<- []byte {
    // The zero time disables the write deadline.
    if err := conn.SetWriteDeadline(time.Time{}); err != nil {
        panic(err)
    }

    outbox := make(chan []byte, 1000)

    // shutdown releases the queue and the connection when the pump exits.
    shutdown := func() {
        close(outbox)
        if err := conn.Close(); err != nil {
            panic(err)
        }
    }

    // pump drains the queue onto the connection until cancelled.
    pump := func() {
        defer shutdown()
        for {
            // Check cancellation first so it wins over pending payloads.
            select {
            case <-t.ctx.Done():
                return
            default:
            }
            select {
            case <-t.ctx.Done():
                return
            case frame := <-outbox:
                // Sent as a uint32 length followed by the payload.
                if err := t.write(conn, frame); err != nil {
                    panic(err)
                }
            }
        }
    }
    go pump()
    return outbox
}

// listenerRoutine is a routine to accept incoming request for TCP connection.
//
// It loops until the transport's context is cancelled or the listener is
// closed from outside. Every accepted connection must pass serverHandshake;
// connections that pass get their own connReader goroutine, connections that
// fail are logged and closed. Accept uses a 5 second deadline so the
// cancellation check runs regularly. Accept errors other than a timeout or a
// closed listener panic. The listener is closed on exit unless it was already
// closed.
func (t *TCPTransport) listenerRoutine(listener *net.TCPListener) {
    closed := false
    defer func() {
        if closed {
            return
        }
        if err := listener.Close(); err != nil {
            panic(err)
        }
    }()
    for {
        select {
        case <-t.ctx.Done():
            return
        default:
        }

        // A failure to set the deadline is not checked here; if the listener
        // is unusable the Accept below reports it.
        listener.SetDeadline(time.Now().Add(5 * time.Second))
        conn, err := listener.Accept()
        if err != nil {
            // Check if the connection is closed.
            if strings.Contains(err.Error(), "use of closed network connection") {
                closed = true
                return
            }
            // Check if timeout error.
            nErr, ok := err.(*net.OpError)
            if !ok {
                panic(err)
            }
            if !nErr.Timeout() {
                panic(err)
            }
            continue
        }
        if _, err := t.serverHandshake(conn); err != nil {
            fmt.Println(err)
            // Nobody will read from this connection, so release it now
            // instead of leaking the socket. The close error is not
            // interesting: the connection is being discarded anyway.
            _ = conn.Close()
            continue
        }
        go t.connReader(conn)
    }
}

// buildConnectionsToPeers constructs TCP connections to each peer.
// Although TCP connection could be used for both read/write operation,
// we only utilize the write part for simplicity.
//
// Every peer in the peer table except ourselves is dialed concurrently. A
// connection is kept only if the client handshake succeeds and the remote
// side announces the node ID we expected; its send channel is then stored in
// the peer record. The call returns after all attempts have finished, with
// the first error that was recorded (nil if every peer was connected).
func (t *TCPTransport) buildConnectionsToPeers() (err error) {
    var (
        wg      sync.WaitGroup
        errLock sync.Mutex
    )
    // fail records e as the result unless an earlier failure is already
    // recorded, and releases conn (if any) since it will not be used. The
    // workers run concurrently, so the shared result is guarded by errLock.
    fail := func(conn net.Conn, e error) {
        if conn != nil {
            // The connection is being discarded; its close error is moot.
            _ = conn.Close()
        }
        errLock.Lock()
        defer errLock.Unlock()
        if err == nil {
            err = e
        }
    }
    for nID, rec := range t.peers {
        if nID == t.nID {
            continue
        }
        wg.Add(1)
        go func(nID types.NodeID, addr string) {
            defer wg.Done()

            conn, dialErr := net.Dial("tcp", addr)
            if dialErr != nil {
                fail(nil, dialErr)
                return
            }
            serverID, handshakeErr := t.clientHandshake(conn)
            if handshakeErr != nil {
                fail(conn, handshakeErr)
                return
            }
            if nID != serverID {
                fail(conn, ErrConnectToUnexpectedPeer)
                return
            }
            t.peersLock.Lock()
            defer t.peersLock.Unlock()

            t.peers[nID].sendChannel = t.connWriter(conn)
        }(nID, rec.conn)
    }
    wg.Wait()
    return
}

// TCPTransportClient implements TransportClient based on TCP connections.
type TCPTransportClient struct {
    TCPTransport
    // local selects local mode: the node advertises its 127.0.0.1 listening
    // address and retries with another port when the address is in use.
    local              bool
    // serverWriteChannel is the outgoing queue to the peer server; it is set
    // by Join and used by Report.
    serverWriteChannel chan<- []byte
}

// NewTCPTransportClient constructs a TCPTransportClient instance.
//
// The embedded transport is created with the TransportPeer type and an
// initial listening port of 8080; local is stored as given.
func NewTCPTransportClient(
    pubKey crypto.PublicKey,
    latency LatencyModel,
    marshaller Marshaller,
    local bool) *TCPTransportClient {

    base := NewTCPTransport(TransportPeer, pubKey, latency, marshaller, 8080)
    client := &TCPTransportClient{
        TCPTransport: *base,
        local:        local,
    }
    return client
}

// Report implements TransportClient.Report method.
//
// The message is marshalled synchronously (encoding errors are returned) and
// then queued for the peer server from a background goroutine, so the call
// itself never blocks on the server's send queue.
func (t *TCPTransportClient) Report(msg interface{}) (err error) {
    var payload []byte
    if payload, err = t.marshalMessage(msg); err != nil {
        return
    }
    enqueue := func() {
        t.serverWriteChannel <- payload
    }
    go enqueue()
    return
}

// Join implements TransportClient.Join method.
//
// It performs the whole client side of the bootstrap protocol:
//  1. start listening on a local port (retrying with random ports in local
//     mode) and verify the listener is really ours,
//  2. connect to the peer server at serverEndpoint (a "host:port" string)
//     and report our own address and public key,
//  3. receive the peer list from the server and connect to every peer,
//  4. report "conn-ready" and wait for the server's "all-ready".
//
// On success the returned channel delivers all further incoming messages;
// messages that arrived while waiting for "all-ready" are put back on it
// first.
func (t *TCPTransportClient) Join(
    serverEndpoint interface{}) (ch <-chan *TransportEnvelope, err error) {
    // Initiate a TCP server.
    // TODO(mission): config initial listening port.
    var (
        ln        net.Listener
        envelopes = []*TransportEnvelope{}
        ok        bool
        addr      string
        conn      string
    )
    for {
        addr = net.JoinHostPort("127.0.0.1", strconv.Itoa(t.localPort))
        ln, err = net.Listen("tcp", addr)
        if err == nil {
            go t.listenerRoutine(ln.(*net.TCPListener))
            // It is possible to listen on the same port in some platform.
            // Check if this one is actually listening.
            // NOTE(review): testConn is not closed in this function — confirm
            // whether it is meant to stay open for the transport's lifetime.
            testConn, e := net.Dial("tcp", addr)
            if e != nil {
                err = e
                return
            }
            nID, e := t.clientHandshake(testConn)
            if e != nil {
                err = e
                return
            }
            // The listener answered with our own node ID, so it is ours.
            if nID == t.nID {
                break
            }
            // Somebody else owns this port; give up our listener and retry.
            ln.Close()
        }
        if err != nil {
            if !t.local {
                return
            }
            // In local-tcp, retry with other port when the address is in use.
            // Any other kind of listen failure is treated as fatal.
            operr, ok := err.(*net.OpError)
            if !ok {
                panic(err)
            }
            oserr, ok := operr.Err.(*os.SyscallError)
            if !ok {
                panic(operr)
            }
            errno, ok := oserr.Err.(syscall.Errno)
            if !ok {
                panic(oserr)
            }
            if errno != syscall.EADDRINUSE {
                panic(errno)
            }
        }
        // The port is used, generate another port randomly.
        // The candidate is drawn from the range [1024, 2047].
        t.localPort = 1024 + rand.Int()%1024
    }

    // Connect to the peer server; serverEndpoint must hold a string address.
    serverConn, err := net.Dial("tcp", serverEndpoint.(string))
    if err != nil {
        return
    }
    _, err = t.clientHandshake(serverConn)
    if err != nil {
        return
    }
    t.serverWriteChannel = t.connWriter(serverConn)
    // Decide which address to advertise to the other peers.
    if t.local {
        conn = addr
    } else {
        // Find my IP.
        var ip string
        if ip, err = FindMyIP(); err != nil {
            return
        }
        conn = net.JoinHostPort(ip, strconv.Itoa(t.localPort))
    }
    // Tell the server how we can be reached and what our public key is.
    if err = t.Report(&tcpMessage{
        NodeID: t.nID,
        Type:   "conn",
        Info:   buildPeerInfo(t.pubKey, conn),
    }); err != nil {
        return
    }
    // Wait for peers list sent by server.
    // The first message received here must be the peer list.
    e := <-t.recvChannel
    peersInfo, ok := e.Msg.(map[types.NodeID]string)
    if !ok {
        panic(fmt.Errorf("expect peer list, not %v", e))
    }
    // Setup peers information.
    for nID, info := range peersInfo {
        pubKey, conn := parsePeerInfo(info)
        t.peers[nID] = &tcpPeerRecord{
            conn:   conn,
            pubKey: pubKey,
        }
    }
    // Setup connections to other peers.
    if err = t.buildConnectionsToPeers(); err != nil {
        return
    }
    // Report to server that the connections to other peers are ready.
    if err = t.Report(&tcpMessage{
        Type:   "conn-ready",
        NodeID: t.nID,
    }); err != nil {
        return
    }
    // Wait for server to ack us that all peers are ready.
    // Anything that is not a tcpMessage is set aside in envelopes.
    for {
        e := <-t.recvChannel
        msg, ok := e.Msg.(*tcpMessage)
        if !ok {
            envelopes = append(envelopes, e)
            continue
        }
        if msg.Type != "all-ready" {
            err = fmt.Errorf("expected ready message, but %v", msg)
            return
        }
        break
    }
    // Replay those messages sent before peer list and ready-ack.
    for _, e := range envelopes {
        t.recvChannel <- e
    }
    ch = t.recvChannel
    return
}

// TCPTransportServer implements TransportServer via TCP connections.
// It adds no state of its own; everything lives in the embedded TCPTransport.
type TCPTransportServer struct {
    TCPTransport
}

// NewTCPTransportServer constructs a TCPTransportServer instance.
//
// A fresh private key is generated to give the server an identity of its
// own; the embedded transport uses the TransportPeerServer type, no latency
// model, and listens on serverPort. It panics if key generation fails.
func NewTCPTransportServer(
    marshaller Marshaller,
    serverPort int) *TCPTransportServer {

    serverKey, keyErr := ecdsa.NewPrivateKey()
    if keyErr != nil {
        panic(keyErr)
    }
    // NOTE: the assumption here is the node ID of peers
    //       won't be zero.
    base := NewTCPTransport(
        TransportPeerServer,
        serverKey.PublicKey(),
        nil,
        marshaller,
        serverPort)
    return &TCPTransportServer{TCPTransport: *base}
}

// Host implements TransportServer.Host method.
//
// It starts listening on 127.0.0.1 at the configured port, launches the
// accept routine and returns the channel on which incoming messages are
// delivered. A listen failure is returned as is.
func (t *TCPTransportServer) Host() (chan *TransportEnvelope, error) {
    // The port of the peer server has to be known to the other peers in
    // advance, so we only try the pre-defined port and never fall back to a
    // random one.
    addr := net.JoinHostPort("127.0.0.1", strconv.Itoa(t.localPort))
    listener, err := net.Listen("tcp", addr)
    if err != nil {
        return nil, err
    }
    go t.listenerRoutine(listener.(*net.TCPListener))
    return t.recvChannel, nil
}

// WaitForPeers implements TransportServer.WaitForPeers method.
//
// It drives the server side of the bootstrap protocol in three phases:
// collect a "conn" report from numPeers distinct peers, connect to all of
// them and broadcast the collected peer list, then wait until numPeers
// distinct peers have reported "conn-ready" and broadcast "all-ready".
//
// Any message that does not fit the current phase, and a repeated
// "conn-ready" from the same peer, panics. Connection and broadcast errors
// are returned.
func (t *TCPTransportServer) WaitForPeers(numPeers uint32) (err error) {
    // Phase 1: gather the connection info announced by the peers. Packets
    // other than peer info are unexpected.
    announced := make(map[types.NodeID]string)
    for {
        envelope := <-t.recvChannel
        report, isTCPMsg := envelope.Msg.(*tcpMessage)
        switch {
        case !isTCPMsg:
            panic(fmt.Errorf("expect tcpMessage, not %v", envelope))
        case report.Type != "conn":
            panic(fmt.Errorf("expect connection report, not %v", envelope))
        }
        peerKey, peerAddr := parsePeerInfo(report.Info)
        t.peers[report.NodeID] = &tcpPeerRecord{
            conn:   peerAddr,
            pubKey: peerKey,
        }
        announced[report.NodeID] = report.Info
        // Stop once enough peers have been collected.
        if uint32(len(announced)) == numPeers {
            break
        }
    }

    // Phase 2: connect to every peer and send the collected list back.
    if err = t.buildConnectionsToPeers(); err != nil {
        return
    }
    if err = t.Broadcast(announced); err != nil {
        return
    }

    // Phase 3: wait for each peer to report that its connections are ready.
    confirmed := make(map[types.NodeID]struct{})
    for {
        envelope := <-t.recvChannel
        ready, isTCPMsg := envelope.Msg.(*tcpMessage)
        switch {
        case !isTCPMsg:
            panic(fmt.Errorf("expect tcpMessage, not %v", envelope))
        case ready.Type != "conn-ready":
            panic(fmt.Errorf("expect connection ready, not %v", envelope))
        }
        if _, duplicate := confirmed[ready.NodeID]; duplicate {
            panic(fmt.Errorf("already report conn-ready message: %v", envelope))
        }
        confirmed[ready.NodeID] = struct{}{}
        if uint32(len(confirmed)) == numPeers {
            break
        }
    }

    // Ack all peers ready to go.
    return t.Broadcast(&tcpMessage{Type: "all-ready"})
}