aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--p2p/peer.go17
-rw-r--r--p2p/peer_test.go56
2 files changed, 68 insertions, 5 deletions
diff --git a/p2p/peer.go b/p2p/peer.go
index 893ba86d7..86c4d7ab5 100644
--- a/p2p/peer.go
+++ b/p2p/peer.go
@@ -300,7 +300,7 @@ func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error)
proto.in <- msg
} else {
wait = true
- pr := &eofSignal{msg.Payload, protoDone}
+ pr := &eofSignal{msg.Payload, int64(msg.Size), protoDone}
msg.Payload = pr
proto.in <- msg
}
@@ -438,18 +438,25 @@ func (rw *proto) ReadMsg() (Msg, error) {
return msg, nil
}
-// eofSignal wraps a reader with eof signaling.
-// the eof channel is closed when the wrapped reader
-// reaches EOF.
+// eofSignal wraps a reader with eof signaling. the eof channel is
+// closed when the wrapped reader returns an error or when count bytes
+// have been read.
+//
type eofSignal struct {
wrapped io.Reader
+ count int64
eof chan<- struct{}
}
+// note: when using eofSignal to detect whether a message payload
+// has been read, Read might not be called for zero sized messages.
+
func (r *eofSignal) Read(buf []byte) (int, error) {
n, err := r.wrapped.Read(buf)
- if err != nil {
+ r.count -= int64(n)
+ if (err != nil || r.count <= 0) && r.eof != nil {
r.eof <- struct{}{} // tell Peer that msg has been consumed
+ r.eof = nil
}
return n, err
}
diff --git a/p2p/peer_test.go b/p2p/peer_test.go
index d9640292f..f7759786e 100644
--- a/p2p/peer_test.go
+++ b/p2p/peer_test.go
@@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"encoding/hex"
+ "io"
"io/ioutil"
"net"
"reflect"
@@ -237,3 +238,58 @@ func TestNewPeer(t *testing.T) {
// Should not hang.
p.Disconnect(DiscAlreadyConnected)
}
+
+func TestEOFSignal(t *testing.T) {
+ rb := make([]byte, 10)
+
+ // empty reader
+ eof := make(chan struct{}, 1)
+ sig := &eofSignal{new(bytes.Buffer), 0, eof}
+ if n, err := sig.Read(rb); n != 0 || err != io.EOF {
+ t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+ }
+ select {
+ case <-eof:
+ default:
+ t.Error("EOF chan not signaled")
+ }
+
+ // count before error
+ eof = make(chan struct{}, 1)
+ sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
+ if n, err := sig.Read(rb); n != 8 || err != nil {
+ t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+ }
+ select {
+ case <-eof:
+ default:
+ t.Error("EOF chan not signaled")
+ }
+
+ // error before count
+ eof = make(chan struct{}, 1)
+ sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
+ if n, err := sig.Read(rb); n != 4 || err != nil {
+ t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+ }
+ if n, err := sig.Read(rb); n != 0 || err != io.EOF {
+ t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+ }
+ select {
+ case <-eof:
+ default:
+ t.Error("EOF chan not signaled")
+ }
+
+ // no signal if neither occurs
+ eof = make(chan struct{}, 1)
+ sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
+ if n, err := sig.Read(rb); n != 10 || err != nil {
+ t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+ }
+ select {
+ case <-eof:
+ t.Error("unexpected EOF signal")
+ default:
+ }
+}