aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--rpc/subscription.go36
-rw-r--r--rpc/subscription_test.go103
2 files changed, 68 insertions, 71 deletions
diff --git a/rpc/subscription.go b/rpc/subscription.go
index 6ce7befa1..6bbb6f75d 100644
--- a/rpc/subscription.go
+++ b/rpc/subscription.go
@@ -52,9 +52,10 @@ type notifierKey struct{}
// Server callbacks use the notifier to send notifications.
type Notifier struct {
codec ServerCodec
- subMu sync.RWMutex // guards active and inactive maps
+ subMu sync.Mutex
active map[ID]*Subscription
inactive map[ID]*Subscription
+ buffer map[ID][]interface{} // unsent notifications of inactive subscriptions
}
// newNotifier creates a new notifier that can be used to send subscription
@@ -64,6 +65,7 @@ func newNotifier(codec ServerCodec) *Notifier {
codec: codec,
active: make(map[ID]*Subscription),
inactive: make(map[ID]*Subscription),
+ buffer: make(map[ID][]interface{}),
}
}
@@ -88,20 +90,26 @@ func (n *Notifier) CreateSubscription() *Subscription {
// Notify sends a notification to the client with the given data as payload.
// If an error occurs the RPC connection is closed and the error is returned.
func (n *Notifier) Notify(id ID, data interface{}) error {
- n.subMu.RLock()
- defer n.subMu.RUnlock()
-
- sub, active := n.active[id]
- if active {
- notification := n.codec.CreateNotification(string(id), sub.namespace, data)
- if err := n.codec.Write(notification); err != nil {
- n.codec.Close()
- return err
- }
+ n.subMu.Lock()
+ defer n.subMu.Unlock()
+
+ if sub, active := n.active[id]; active {
+ n.send(sub, data)
+ } else {
+ n.buffer[id] = append(n.buffer[id], data)
}
return nil
}
+func (n *Notifier) send(sub *Subscription, data interface{}) error {
+ notification := n.codec.CreateNotification(string(sub.ID), sub.namespace, data)
+ err := n.codec.Write(notification)
+ if err != nil {
+ n.codec.Close()
+ }
+ return err
+}
+
// Closed returns a channel that is closed when the RPC connection is closed.
func (n *Notifier) Closed() <-chan interface{} {
return n.codec.Closed()
@@ -127,9 +135,15 @@ func (n *Notifier) unsubscribe(id ID) error {
func (n *Notifier) activate(id ID, namespace string) {
n.subMu.Lock()
defer n.subMu.Unlock()
+
if sub, found := n.inactive[id]; found {
sub.namespace = namespace
n.active[id] = sub
delete(n.inactive, id)
+ // Send buffered notifications.
+ for _, data := range n.buffer[id] {
+ n.send(sub, data)
+ }
+ delete(n.buffer, id)
}
}
diff --git a/rpc/subscription_test.go b/rpc/subscription_test.go
index 0ba177e63..24febc919 100644
--- a/rpc/subscription_test.go
+++ b/rpc/subscription_test.go
@@ -27,9 +27,8 @@ import (
)
type NotificationTestService struct {
- mu sync.Mutex
- unsubscribed bool
-
+ mu sync.Mutex
+ unsubscribed chan string
gotHangSubscriptionReq chan struct{}
unblockHangSubscription chan struct{}
}
@@ -38,16 +37,10 @@ func (s *NotificationTestService) Echo(i int) int {
return i
}
-func (s *NotificationTestService) wasUnsubCallbackCalled() bool {
- s.mu.Lock()
- defer s.mu.Unlock()
- return s.unsubscribed
-}
-
func (s *NotificationTestService) Unsubscribe(subid string) {
- s.mu.Lock()
- s.unsubscribed = true
- s.mu.Unlock()
+ if s.unsubscribed != nil {
+ s.unsubscribed <- subid
+ }
}
func (s *NotificationTestService) SomeSubscription(ctx context.Context, n, val int) (*Subscription, error) {
@@ -65,7 +58,6 @@ func (s *NotificationTestService) SomeSubscription(ctx context.Context, n, val i
// test expects n events, if we begin sending event immediately some events
// will probably be dropped since the subscription ID might not be send to
// the client.
- time.Sleep(5 * time.Second)
for i := 0; i < n; i++ {
if err := notifier.Notify(subscription.ID, val+i); err != nil {
return
@@ -74,13 +66,10 @@ func (s *NotificationTestService) SomeSubscription(ctx context.Context, n, val i
select {
case <-notifier.Closed():
- s.mu.Lock()
- s.unsubscribed = true
- s.mu.Unlock()
case <-subscription.Err():
- s.mu.Lock()
- s.unsubscribed = true
- s.mu.Unlock()
+ }
+ if s.unsubscribed != nil {
+ s.unsubscribed <- string(subscription.ID)
}
}()
@@ -107,7 +96,7 @@ func (s *NotificationTestService) HangSubscription(ctx context.Context, val int)
func TestNotifications(t *testing.T) {
server := NewServer()
- service := &NotificationTestService{}
+ service := &NotificationTestService{unsubscribed: make(chan string)}
if err := server.RegisterName("eth", service); err != nil {
t.Fatalf("unable to register test service %v", err)
@@ -157,10 +146,10 @@ func TestNotifications(t *testing.T) {
}
clientConn.Close() // causes notification unsubscribe callback to be called
- time.Sleep(1 * time.Second)
-
- if !service.wasUnsubCallbackCalled() {
- t.Error("unsubscribe callback not called after closing connection")
+ select {
+ case <-service.unsubscribed:
+ case <-time.After(1 * time.Second):
+ t.Fatal("Unsubscribe not called after one second")
}
}
@@ -227,18 +216,19 @@ func waitForMessages(t *testing.T, in *json.Decoder, successes chan<- jsonSucces
// for multiple different namespaces.
func TestSubscriptionMultipleNamespaces(t *testing.T) {
var (
- namespaces = []string{"eth", "shh", "bzz"}
+ namespaces = []string{"eth", "shh", "bzz"}
+ service = NotificationTestService{}
+ subCount = len(namespaces) * 2
+ notificationCount = 3
+
server = NewServer()
- service = NotificationTestService{}
clientConn, serverConn = net.Pipe()
-
- out = json.NewEncoder(clientConn)
- in = json.NewDecoder(clientConn)
- successes = make(chan jsonSuccessResponse)
- failures = make(chan jsonErrResponse)
- notifications = make(chan jsonNotification)
-
- errors = make(chan error, 10)
+ out = json.NewEncoder(clientConn)
+ in = json.NewDecoder(clientConn)
+ successes = make(chan jsonSuccessResponse)
+ failures = make(chan jsonErrResponse)
+ notifications = make(chan jsonNotification)
+ errors = make(chan error, 10)
)
// setup and start server
@@ -255,13 +245,12 @@ func TestSubscriptionMultipleNamespaces(t *testing.T) {
go waitForMessages(t, in, successes, failures, notifications, errors)
// create subscriptions one by one
- n := 3
for i, namespace := range namespaces {
request := map[string]interface{}{
"id": i,
"method": fmt.Sprintf("%s_subscribe", namespace),
"version": "2.0",
- "params": []interface{}{"someSubscription", n, i},
+ "params": []interface{}{"someSubscription", notificationCount, i},
}
if err := out.Encode(&request); err != nil {
@@ -276,7 +265,7 @@ func TestSubscriptionMultipleNamespaces(t *testing.T) {
"id": i,
"method": fmt.Sprintf("%s_subscribe", namespace),
"version": "2.0",
- "params": []interface{}{"someSubscription", n, i},
+ "params": []interface{}{"someSubscription", notificationCount, i},
})
}
@@ -285,46 +274,40 @@ func TestSubscriptionMultipleNamespaces(t *testing.T) {
}
timeout := time.After(30 * time.Second)
- subids := make(map[string]string, 2*len(namespaces))
- count := make(map[string]int, 2*len(namespaces))
-
- for {
- done := true
- for id := range count {
- if count, found := count[id]; !found || count < (2*n) {
+ subids := make(map[string]string, subCount)
+ count := make(map[string]int, subCount)
+ allReceived := func() bool {
+ done := len(count) == subCount
+ for _, c := range count {
+ if c < notificationCount {
done = false
}
}
+ return done
+ }
- if done && len(count) == len(namespaces) {
- break
- }
-
+ for !allReceived() {
select {
- case err := <-errors:
- t.Fatal(err)
case suc := <-successes: // subscription created
subids[namespaces[int(suc.Id.(float64))]] = suc.Result.(string)
+ case notification := <-notifications:
+ count[notification.Params.Subscription]++
+ case err := <-errors:
+ t.Fatal(err)
case failure := <-failures:
t.Errorf("received error: %v", failure.Error)
- case notification := <-notifications:
- if cnt, found := count[notification.Params.Subscription]; found {
- count[notification.Params.Subscription] = cnt + 1
- } else {
- count[notification.Params.Subscription] = 1
- }
case <-timeout:
for _, namespace := range namespaces {
subid, found := subids[namespace]
if !found {
- t.Errorf("Subscription for '%s' not created", namespace)
+ t.Errorf("subscription for %q not created", namespace)
continue
}
- if count, found := count[subid]; !found || count < n {
- t.Errorf("Didn't receive all notifications (%d<%d) in time for namespace '%s'", count, n, namespace)
+ if count, found := count[subid]; !found || count < notificationCount {
+ t.Errorf("didn't receive all notifications (%d<%d) in time for namespace %q", count, notificationCount, namespace)
}
}
- return
+ t.Fatal("timed out")
}
}
}