diff options
Diffstat (limited to 'rpc/subscription_test.go')
-rw-r--r-- | rpc/subscription_test.go | 103 |
1 files changed, 43 insertions, 60 deletions
diff --git a/rpc/subscription_test.go b/rpc/subscription_test.go index 0ba177e63..24febc919 100644 --- a/rpc/subscription_test.go +++ b/rpc/subscription_test.go @@ -27,9 +27,8 @@ import ( ) type NotificationTestService struct { - mu sync.Mutex - unsubscribed bool - + mu sync.Mutex + unsubscribed chan string gotHangSubscriptionReq chan struct{} unblockHangSubscription chan struct{} } @@ -38,16 +37,10 @@ func (s *NotificationTestService) Echo(i int) int { return i } -func (s *NotificationTestService) wasUnsubCallbackCalled() bool { - s.mu.Lock() - defer s.mu.Unlock() - return s.unsubscribed -} - func (s *NotificationTestService) Unsubscribe(subid string) { - s.mu.Lock() - s.unsubscribed = true - s.mu.Unlock() + if s.unsubscribed != nil { + s.unsubscribed <- subid + } } func (s *NotificationTestService) SomeSubscription(ctx context.Context, n, val int) (*Subscription, error) { @@ -65,7 +58,6 @@ func (s *NotificationTestService) SomeSubscription(ctx context.Context, n, val i // test expects n events, if we begin sending event immediately some events // will probably be dropped since the subscription ID might not be send to // the client. - time.Sleep(5 * time.Second) for i := 0; i < n; i++ { if err := notifier.Notify(subscription.ID, val+i); err != nil { return @@ -74,13 +66,10 @@ func (s *NotificationTestService) SomeSubscription(ctx context.Context, n, val i select { case <-notifier.Closed(): - s.mu.Lock() - s.unsubscribed = true - s.mu.Unlock() case <-subscription.Err(): - s.mu.Lock() - s.unsubscribed = true - s.mu.Unlock() + } + if s.unsubscribed != nil { + s.unsubscribed <- string(subscription.ID) } }() @@ -107,7 +96,7 @@ func (s *NotificationTestService) HangSubscription(ctx context.Context, val int) func TestNotifications(t *testing.T) { server := NewServer() - service := &NotificationTestService{} + service := &NotificationTestService{unsubscribed: make(chan string)} if err := server.RegisterName("eth", service); err != nil { t.Fatalf("unable to register test service %v", err) @@ -157,10 +146,10 @@ func TestNotifications(t *testing.T) { } clientConn.Close() // causes notification unsubscribe callback to be called - time.Sleep(1 * time.Second) - - if !service.wasUnsubCallbackCalled() { - t.Error("unsubscribe callback not called after closing connection") + select { + case <-service.unsubscribed: + case <-time.After(1 * time.Second): + t.Fatal("Unsubscribe not called after one second") } } @@ -227,18 +216,19 @@ func waitForMessages(t *testing.T, in *json.Decoder, successes chan<- jsonSucces // for multiple different namespaces. func TestSubscriptionMultipleNamespaces(t *testing.T) { var ( - namespaces = []string{"eth", "shh", "bzz"} + namespaces = []string{"eth", "shh", "bzz"} + service = NotificationTestService{} + subCount = len(namespaces) * 2 + notificationCount = 3 + server = NewServer() - service = NotificationTestService{} clientConn, serverConn = net.Pipe() - - out = json.NewEncoder(clientConn) - in = json.NewDecoder(clientConn) - successes = make(chan jsonSuccessResponse) - failures = make(chan jsonErrResponse) - notifications = make(chan jsonNotification) - - errors = make(chan error, 10) + out = json.NewEncoder(clientConn) + in = json.NewDecoder(clientConn) + successes = make(chan jsonSuccessResponse) + failures = make(chan jsonErrResponse) + notifications = make(chan jsonNotification) + errors = make(chan error, 10) ) // setup and start server @@ -255,13 +245,12 @@ func TestSubscriptionMultipleNamespaces(t *testing.T) { go waitForMessages(t, in, successes, failures, notifications, errors) // create subscriptions one by one - n := 3 for i, namespace := range namespaces { request := map[string]interface{}{ "id": i, "method": fmt.Sprintf("%s_subscribe", namespace), "version": "2.0", - "params": []interface{}{"someSubscription", n, i}, + "params": []interface{}{"someSubscription", notificationCount, i}, } if err := out.Encode(&request); err != nil { @@ -276,7 +265,7 @@ func TestSubscriptionMultipleNamespaces(t *testing.T) { "id": i, "method": fmt.Sprintf("%s_subscribe", namespace), "version": "2.0", - "params": []interface{}{"someSubscription", n, i}, + "params": []interface{}{"someSubscription", notificationCount, i}, }) } @@ -285,46 +274,40 @@ func TestSubscriptionMultipleNamespaces(t *testing.T) { } timeout := time.After(30 * time.Second) - subids := make(map[string]string, 2*len(namespaces)) - count := make(map[string]int, 2*len(namespaces)) - - for { - done := true - for id := range count { - if count, found := count[id]; !found || count < (2*n) { + subids := make(map[string]string, subCount) + count := make(map[string]int, subCount) + allReceived := func() bool { + done := len(count) == subCount + for _, c := range count { + if c < notificationCount { done = false } } + return done + } - if done && len(count) == len(namespaces) { - break - } - + for !allReceived() { select { - case err := <-errors: - t.Fatal(err) case suc := <-successes: // subscription created subids[namespaces[int(suc.Id.(float64))]] = suc.Result.(string) + case notification := <-notifications: + count[notification.Params.Subscription]++ + case err := <-errors: + t.Fatal(err) case failure := <-failures: t.Errorf("received error: %v", failure.Error) - case notification := <-notifications: - if cnt, found := count[notification.Params.Subscription]; found { - count[notification.Params.Subscription] = cnt + 1 - } else { - count[notification.Params.Subscription] = 1 - } case <-timeout: for _, namespace := range namespaces { subid, found := subids[namespace] if !found { - t.Errorf("Subscription for '%s' not created", namespace) + t.Errorf("subscription for %q not created", namespace) continue } - if count, found := count[subid]; !found || count < n { - t.Errorf("Didn't receive all notifications (%d<%d) in time for namespace '%s'", count, n, namespace) + if count, found := count[subid]; !found || count < notificationCount { + t.Errorf("didn't receive all notifications (%d<%d) in time for namespace %q", count, notificationCount, namespace) } } - return + t.Fatal("timed out") } } } |