diff options
Diffstat (limited to 'core/test/tcp-transport.go')
-rw-r--r-- | core/test/tcp-transport.go | 101 |
1 files changed, 72 insertions, 29 deletions
diff --git a/core/test/tcp-transport.go b/core/test/tcp-transport.go index 0f9bd73..6e3ddfb 100644 --- a/core/test/tcp-transport.go +++ b/core/test/tcp-transport.go @@ -19,6 +19,7 @@ package test import ( "context" + "encoding/base64" "encoding/binary" "encoding/json" "fmt" @@ -28,13 +29,22 @@ import ( "net" "os" "strconv" + "strings" "sync" "syscall" "time" + "github.com/dexon-foundation/dexon-consensus-core/core/crypto" + "github.com/dexon-foundation/dexon-consensus-core/core/crypto/ecdsa" "github.com/dexon-foundation/dexon-consensus-core/core/types" ) +type tcpPeerRecord struct { + conn string + sendChannel chan<- []byte + pubKey crypto.PublicKey +} + // tcpMessage is the general message between peers and server. type tcpMessage struct { NodeID types.NodeID `json:"nid"` @@ -42,13 +52,33 @@ type tcpMessage struct { Info string `json:"conn"` } +// buildPeerInfo is a tricky way to combine connection string and +// base64 encoded byte slice for public key into a single string, +// separated by ';'. +func buildPeerInfo(pubKey crypto.PublicKey, conn string) string { + return conn + ";" + base64.StdEncoding.EncodeToString(pubKey.Bytes()) +} + +// parsePeerInfo parse connection string and base64 encoded public key built +// via buildPeerInfo. +func parsePeerInfo(info string) (key crypto.PublicKey, conn string) { + tokens := strings.Split(info, ";") + conn = tokens[0] + data, err := base64.StdEncoding.DecodeString(tokens[1]) + if err != nil { + panic(err) + } + key = ecdsa.NewPublicKeyFromByteSlice(data) + return +} + // TCPTransport implements Transport interface via TCP connection. type TCPTransport struct { peerType TransportPeerType nID types.NodeID + pubKey crypto.PublicKey localPort int - peersInfo map[types.NodeID]string - peers map[types.NodeID]chan<- []byte + peers map[types.NodeID]*tcpPeerRecord peersLock sync.RWMutex recvChannel chan *TransportEnvelope ctx context.Context @@ -60,7 +90,7 @@ type TCPTransport struct { // NewTCPTransport constructs an TCPTransport instance. func NewTCPTransport( peerType TransportPeerType, - nID types.NodeID, + pubKey crypto.PublicKey, latency LatencyModel, marshaller Marshaller, localPort int) *TCPTransport { @@ -68,9 +98,9 @@ func NewTCPTransport( ctx, cancel := context.WithCancel(context.Background()) return &TCPTransport{ peerType: peerType, - nID: nID, - peersInfo: make(map[types.NodeID]string), - peers: make(map[types.NodeID]chan<- []byte), + nID: types.NewNodeID(pubKey), + pubKey: pubKey, + peers: make(map[types.NodeID]*tcpPeerRecord), recvChannel: make(chan *TransportEnvelope, 1000), ctx: ctx, cancel: cancel, @@ -96,7 +126,7 @@ func (t *TCPTransport) Send( t.peersLock.RLock() defer t.peersLock.RUnlock() - t.peers[endpoint] <- payload + t.peers[endpoint].sendChannel <- payload }() return } @@ -110,7 +140,7 @@ func (t *TCPTransport) Broadcast(msg interface{}) (err error) { t.peersLock.RLock() defer t.peersLock.RUnlock() - for nID, ch := range t.peers { + for nID, rec := range t.peers { if nID == t.nID { continue } @@ -119,7 +149,7 @@ func (t *TCPTransport) Broadcast(msg interface{}) (err error) { time.Sleep(t.latency.Delay()) } ch <- payload - }(ch) + }(rec.sendChannel) } return } @@ -131,7 +161,7 @@ func (t *TCPTransport) Close() (err error) { // Reset peers. t.peersLock.Lock() defer t.peersLock.Unlock() - t.peers = make(map[types.NodeID]chan<- []byte) + t.peers = make(map[types.NodeID]*tcpPeerRecord) // Tell our user that this channel is closed. close(t.recvChannel) t.recvChannel = nil @@ -139,10 +169,9 @@ func (t *TCPTransport) Close() (err error) { } // Peers implements Transport.Peers method. -func (t *TCPTransport) Peers() (peers map[types.NodeID]struct{}) { - peers = make(map[types.NodeID]struct{}) - for nID := range t.peersInfo { - peers[nID] = struct{}{} +func (t *TCPTransport) Peers() (peers []crypto.PublicKey) { + for _, rec := range t.peers { + peers = append(peers, rec.pubKey) } return } @@ -376,7 +405,7 @@ func (t *TCPTransport) listenerRoutine(listener *net.TCPListener) { // we only utilize the write part for simplicity. func (t *TCPTransport) buildConnectionsToPeers() (err error) { var wg sync.WaitGroup - for nID, addr := range t.peersInfo { + for nID, rec := range t.peers { if nID == t.nID { continue } @@ -394,8 +423,8 @@ func (t *TCPTransport) buildConnectionsToPeers() (err error) { t.peersLock.Lock() defer t.peersLock.Unlock() - t.peers[nID] = t.connWriter(conn) - }(nID, addr) + t.peers[nID].sendChannel = t.connWriter(conn) + }(nID, rec.conn) } wg.Wait() return @@ -410,14 +439,15 @@ type TCPTransportClient struct { // NewTCPTransportClient constructs a TCPTransportClient instance. func NewTCPTransportClient( - nID types.NodeID, + pubKey crypto.PublicKey, latency LatencyModel, marshaller Marshaller, local bool) *TCPTransportClient { return &TCPTransportClient{ - TCPTransport: *NewTCPTransport(TransportPeer, nID, latency, marshaller, 8080), - local: local, + TCPTransport: *NewTCPTransport( + TransportPeer, pubKey, latency, marshaller, 8080), + local: local, } } @@ -436,7 +466,6 @@ func (t *TCPTransportClient) Report(msg interface{}) (err error) { // Join implements TransportClient.Join method. func (t *TCPTransportClient) Join( serverEndpoint interface{}) (ch <-chan *TransportEnvelope, err error) { - // Initiate a TCP server. // TODO(mission): config initial listening port. var ( @@ -475,7 +504,6 @@ func (t *TCPTransportClient) Join( t.localPort = 1024 + rand.Int()%1024 } go t.listenerRoutine(ln.(*net.TCPListener)) - serverConn, err := net.Dial("tcp", serverEndpoint.(string)) if err != nil { return @@ -492,17 +520,26 @@ func (t *TCPTransportClient) Join( conn = net.JoinHostPort(ip, strconv.Itoa(t.localPort)) } if err = t.Report(&tcpMessage{ - Type: "conn", NodeID: t.nID, - Info: conn, + Type: "conn", + Info: buildPeerInfo(t.pubKey, conn), }); err != nil { return } // Wait for peers list sent by server. e := <-t.recvChannel - if t.peersInfo, ok = e.Msg.(map[types.NodeID]string); !ok { + peersInfo, ok := e.Msg.(map[types.NodeID]string) + if !ok { panic(fmt.Errorf("expect peer list, not %v", e)) } + // Setup peers information. + for nID, info := range peersInfo { + pubKey, conn := parsePeerInfo(info) + t.peers[nID] = &tcpPeerRecord{ + conn: conn, + pubKey: pubKey, + } + } // Setup connections to other peers. if err = t.buildConnectionsToPeers(); err != nil { return @@ -551,7 +588,7 @@ func NewTCPTransportServer( // won't be zero. TCPTransport: *NewTCPTransport( TransportPeerServer, - types.NodeID{}, + ecdsa.NewPublicKeyFromByteSlice(nil), nil, marshaller, serverPort), @@ -576,6 +613,7 @@ func (t *TCPTransportServer) Host() (chan *TransportEnvelope, error) { func (t *TCPTransportServer) WaitForPeers(numPeers int) (err error) { // Collect peers info. Packets other than peer info is // unexpected. + peersInfo := make(map[types.NodeID]string) for { // Wait for connection info reported by peers. e := <-t.recvChannel @@ -586,9 +624,14 @@ func (t *TCPTransportServer) WaitForPeers(numPeers int) (err error) { if msg.Type != "conn" { panic(fmt.Errorf("expect connection report, not %v", e)) } - t.peersInfo[msg.NodeID] = msg.Info + pubKey, conn := parsePeerInfo(msg.Info) + t.peers[msg.NodeID] = &tcpPeerRecord{ + conn: conn, + pubKey: pubKey, + } + peersInfo[msg.NodeID] = msg.Info // Check if we already collect enought peers. - if len(t.peersInfo) == numPeers { + if len(peersInfo) == numPeers { break } } @@ -596,7 +639,7 @@ func (t *TCPTransportServer) WaitForPeers(numPeers int) (err error) { if err = t.buildConnectionsToPeers(); err != nil { return } - if err = t.Broadcast(t.peersInfo); err != nil { + if err = t.Broadcast(peersInfo); err != nil { return } // Wait for peers to send 'ready' report. |