diff options
Diffstat (limited to 'p2p/message.go')
-rw-r--r-- | p2p/message.go | 255 |
1 files changed, 255 insertions, 0 deletions
diff --git a/p2p/message.go b/p2p/message.go new file mode 100644 index 000000000..b42acbe3c --- /dev/null +++ b/p2p/message.go @@ -0,0 +1,255 @@ +package p2p + +import ( + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/ethereum/go-ethereum/rlp" +) + +// Msg defines the structure of a p2p message. +// +// Note that a Msg can only be sent once since the Payload reader is +// consumed during sending. It is not possible to create a Msg and +// send it any number of times. If you want to reuse an encoded +// structure, encode the payload into a byte array and create a +// separate Msg with a bytes.Reader as Payload for each send. +type Msg struct { + Code uint64 + Size uint32 // size of the paylod + Payload io.Reader +} + +// Decode parses the RLP content of a message into +// the given value, which must be a pointer. +// +// For the decoding rules, please see package rlp. +func (msg Msg) Decode(val interface{}) error { + if err := rlp.Decode(msg.Payload, val); err != nil { + return newPeerError(errInvalidMsg, "(code %x) (size %d) %v", msg.Code, msg.Size, err) + } + return nil +} + +func (msg Msg) String() string { + return fmt.Sprintf("msg #%v (%v bytes)", msg.Code, msg.Size) +} + +// Discard reads any remaining payload data into a black hole. +func (msg Msg) Discard() error { + _, err := io.Copy(ioutil.Discard, msg.Payload) + return err +} + +type MsgReader interface { + ReadMsg() (Msg, error) +} + +type MsgWriter interface { + // WriteMsg sends a message. It will block until the message's + // Payload has been consumed by the other end. + // + // Note that messages can be sent only once because their + // payload reader is drained. + WriteMsg(Msg) error +} + +// MsgReadWriter provides reading and writing of encoded messages. +// Implementations should ensure that ReadMsg and WriteMsg can be +// called simultaneously from multiple goroutines. +type MsgReadWriter interface { + MsgReader + MsgWriter +} + +// Send writes an RLP-encoded message with the given code. +// data should encode as an RLP list. +func Send(w MsgWriter, msgcode uint64, data interface{}) error { + size, r, err := rlp.EncodeToReader(data) + if err != nil { + return err + } + return w.WriteMsg(Msg{Code: msgcode, Size: uint32(size), Payload: r}) +} + +// SendItems writes an RLP with the given code and data elements. +// For a call such as: +// +// SendItems(w, code, e1, e2, e3) +// +// the message payload will be an RLP list containing the items: +// +// [e1, e2, e3] +// +func SendItems(w MsgWriter, msgcode uint64, elems ...interface{}) error { + return Send(w, msgcode, elems) +} + +// netWrapper wraps a MsgReadWriter with locks around +// ReadMsg/WriteMsg and applies read/write deadlines. +type netWrapper struct { + rmu, wmu sync.Mutex + + rtimeout, wtimeout time.Duration + conn net.Conn + wrapped MsgReadWriter +} + +func (rw *netWrapper) ReadMsg() (Msg, error) { + rw.rmu.Lock() + defer rw.rmu.Unlock() + rw.conn.SetReadDeadline(time.Now().Add(rw.rtimeout)) + return rw.wrapped.ReadMsg() +} + +func (rw *netWrapper) WriteMsg(msg Msg) error { + rw.wmu.Lock() + defer rw.wmu.Unlock() + rw.conn.SetWriteDeadline(time.Now().Add(rw.wtimeout)) + return rw.wrapped.WriteMsg(msg) +} + +// eofSignal wraps a reader with eof signaling. the eof channel is +// closed when the wrapped reader returns an error or when count bytes +// have been read. +type eofSignal struct { + wrapped io.Reader + count uint32 // number of bytes left + eof chan<- struct{} +} + +// note: when using eofSignal to detect whether a message payload +// has been read, Read might not be called for zero sized messages. +func (r *eofSignal) Read(buf []byte) (int, error) { + if r.count == 0 { + if r.eof != nil { + r.eof <- struct{}{} + r.eof = nil + } + return 0, io.EOF + } + + max := len(buf) + if int(r.count) < len(buf) { + max = int(r.count) + } + n, err := r.wrapped.Read(buf[:max]) + r.count -= uint32(n) + if (err != nil || r.count == 0) && r.eof != nil { + r.eof <- struct{}{} // tell Peer that msg has been consumed + r.eof = nil + } + return n, err +} + +// MsgPipe creates a message pipe. Reads on one end are matched +// with writes on the other. The pipe is full-duplex, both ends +// implement MsgReadWriter. +func MsgPipe() (*MsgPipeRW, *MsgPipeRW) { + var ( + c1, c2 = make(chan Msg), make(chan Msg) + closing = make(chan struct{}) + closed = new(int32) + rw1 = &MsgPipeRW{c1, c2, closing, closed} + rw2 = &MsgPipeRW{c2, c1, closing, closed} + ) + return rw1, rw2 +} + +// ErrPipeClosed is returned from pipe operations after the +// pipe has been closed. +var ErrPipeClosed = errors.New("p2p: read or write on closed message pipe") + +// MsgPipeRW is an endpoint of a MsgReadWriter pipe. +type MsgPipeRW struct { + w chan<- Msg + r <-chan Msg + closing chan struct{} + closed *int32 +} + +// WriteMsg sends a messsage on the pipe. +// It blocks until the receiver has consumed the message payload. +func (p *MsgPipeRW) WriteMsg(msg Msg) error { + if atomic.LoadInt32(p.closed) == 0 { + consumed := make(chan struct{}, 1) + msg.Payload = &eofSignal{msg.Payload, msg.Size, consumed} + select { + case p.w <- msg: + if msg.Size > 0 { + // wait for payload read or discard + select { + case <-consumed: + case <-p.closing: + } + } + return nil + case <-p.closing: + } + } + return ErrPipeClosed +} + +// ReadMsg returns a message sent on the other end of the pipe. +func (p *MsgPipeRW) ReadMsg() (Msg, error) { + if atomic.LoadInt32(p.closed) == 0 { + select { + case msg := <-p.r: + return msg, nil + case <-p.closing: + } + } + return Msg{}, ErrPipeClosed +} + +// Close unblocks any pending ReadMsg and WriteMsg calls on both ends +// of the pipe. They will return ErrPipeClosed. Close also +// interrupts any reads from a message payload. +func (p *MsgPipeRW) Close() error { + if atomic.AddInt32(p.closed, 1) != 1 { + // someone else is already closing + atomic.StoreInt32(p.closed, 1) // avoid overflow + return nil + } + close(p.closing) + return nil +} + +// ExpectMsg reads a message from r and verifies that its +// code and encoded RLP content match the provided values. +// If content is nil, the payload is discarded and not verified. +func ExpectMsg(r MsgReader, code uint64, content interface{}) error { + msg, err := r.ReadMsg() + if err != nil { + return err + } + if msg.Code != code { + return fmt.Errorf("message code mismatch: got %d, expected %d", msg.Code, code) + } + if content == nil { + return msg.Discard() + } else { + contentEnc, err := rlp.EncodeToBytes(content) + if err != nil { + panic("content encode error: " + err.Error()) + } + if int(msg.Size) != len(contentEnc) { + return fmt.Errorf("message size mismatch: got %d, want %d", msg.Size, len(contentEnc)) + } + actualContent, err := ioutil.ReadAll(msg.Payload) + if err != nil { + return err + } + if !bytes.Equal(actualContent, contentEnc) { + return fmt.Errorf("message payload mismatch:\ngot: %x\nwant: %x", actualContent, contentEnc) + } + } + return nil +} |