diff --git a/server/client.go b/server/client.go index e867ec62..6de72ab7 100644 --- a/server/client.go +++ b/server/client.go @@ -212,6 +212,8 @@ func (c *client) processRouteInfo(info *Info) { s.remotes[info.ID] = c s.mu.Unlock() Debug("Registering remote route", info.ID) + // Send our local subscriptions to this route. + s.sendLocalSubsToRoute(c) } } @@ -656,17 +658,19 @@ func (c *client) processMsg(msg []byte) { var rmap map[string]struct{} // If we are a route and we have a queue subscription, deliver direct - // since they are sent direct via L2 semantics. + // since they are sent direct via L2 semantics. If the match is a queue + // subscription, we will return from here regardless if we find a sub. if isRoute { - if sub := srv.routeSidQueueSubscriber(c.pa.sid); sub != nil { - mh := c.msgHeader(msgh[:si], sub) - c.deliverMsg(sub, mh, msg) + if sub, ok := srv.routeSidQueueSubscriber(c.pa.sid); ok { + if sub != nil { + mh := c.msgHeader(msgh[:si], sub) + c.deliverMsg(sub, mh, msg) + } return } } // Loop over all subscriptions that match. - for _, v := range r { sub := v.(*subscription) @@ -856,6 +860,11 @@ func (c *client) closeConnection() { for _, s := range subs { if sub, ok := s.(*subscription); ok { srv.sl.Remove(sub.subject, sub) + // Forward on unsubscribes if we are not + // a router ourselves. + if c.typ != ROUTER { + srv.broadcastUnSubscribe(sub) + } } } } diff --git a/server/const.go b/server/const.go index 158e6e2a..c5c84b61 100644 --- a/server/const.go +++ b/server/const.go @@ -8,7 +8,7 @@ import ( const ( // VERSION is the current version for the server. - VERSION = "0.4.6" + VERSION = "0.5.0" // DEFAULT_PORT is the deault port for client connections. DEFAULT_PORT = 4222 diff --git a/server/route.go b/server/route.go index 656cbf10..d20dccd8 100644 --- a/server/route.go +++ b/server/route.go @@ -119,9 +119,6 @@ func (s *Server) createRoute(conn net.Conn, rURL *url.URL) *client { s.routes[c.cid] = c s.mu.Unlock() - // Send our local subscriptions to this route. - s.sendLocalSubsToRoute(c) - return c } @@ -151,24 +148,24 @@ const ( // FIXME(dlc) - This may be too slow, check at later date. var qrsidRe = regexp.MustCompile(`QRSID:(\d+):([^\s]+)`) -func (s *Server) routeSidQueueSubscriber(rsid []byte) *subscription { +func (s *Server) routeSidQueueSubscriber(rsid []byte) (*subscription, bool) { if !bytes.HasPrefix(rsid, []byte(QRSID)) { - return nil + return nil, false } matches := qrsidRe.FindSubmatch(rsid) if matches == nil || len(matches) != EXPECTED_MATCHES { - return nil + return nil, false } cid := uint64(parseInt64(matches[RSID_CID_INDEX])) client := s.clients[cid] if client == nil { - return nil + return nil, true } sid := matches[RSID_SID_INDEX] if sub, ok := (client.subs.Get(sid)).(*subscription); ok { - return sub + return sub, true } - return nil + return nil, true } func routeSid(sub *subscription) string { diff --git a/server/server.go b/server/server.go index b7c4e90e..88bacff3 100644 --- a/server/server.go +++ b/server/server.go @@ -438,7 +438,11 @@ func (s *Server) removeClient(c *client) { case ROUTER: delete(s.routes, cid) if c.route != nil { - delete(s.remotes, c.route.remoteID) + rc, ok := s.remotes[c.route.remoteID] + // Only delete it if it is us.. + if ok && c == rc { + delete(s.remotes, c.route.remoteID) + } } } s.mu.Unlock() @@ -468,3 +472,11 @@ func (s *Server) NumClients() int { defer s.mu.Unlock() return len(s.clients) } + +// NumSubscriptions will report how many subscriptions are active. +func (s *Server) NumSubscriptions() uint32 { + s.mu.Lock() + defer s.mu.Unlock() + stats := s.sl.Stats() + return stats.NumSubs +} diff --git a/test/cluster_test.go b/test/cluster_test.go index 34e80a5d..312fc483 100644 --- a/test/cluster_test.go +++ b/test/cluster_test.go @@ -3,6 +3,7 @@ package test import ( +"fmt" "testing" "time" @@ -63,3 +64,257 @@ func TestBasicClusterPubSub(t *testing.T) { matches := expectMsgs(1) checkMsg(t, matches[0], "foo", "22", "", "2", "ok") } + +func TestClusterQueueSubs(t *testing.T) { + srvA, srvB, optsA, optsB := runServers(t) + defer srvA.Shutdown() + defer srvB.Shutdown() + + clientA := createClientConn(t, optsA.Host, optsA.Port) + defer clientA.Close() + + clientB := createClientConn(t, optsB.Host, optsB.Port) + defer clientB.Close() + + sendA, expectA := setupConn(t, clientA) + sendB, expectB := setupConn(t, clientB) + + expectMsgsA := expectMsgsCommand(t, expectA) + expectMsgsB := expectMsgsCommand(t, expectB) + + // Capture sids for checking later. + qg1Sids_a := []string{"1", "2", "3"} + + // Three queue subscribers + for _, sid := range qg1Sids_a { + sendA(fmt.Sprintf("SUB foo qg1 %s\r\n", sid)) + } + sendA("PING\r\n") + expectA(pongRe) + + sendB("PUB foo 2\r\nok\r\n") + sendB("PING\r\n") + expectB(pongRe) + + // Make sure we get only 1. + matches := expectMsgsA(1) + checkMsg(t, matches[0], "foo", "", "", "2", "ok") + + // Capture sids for checking later. + pSids := []string{"4", "5", "6"} + + // Create 3 normal subscribers + for _, sid := range pSids { + sendA(fmt.Sprintf("SUB foo %s\r\n", sid)) + } + + // Create a FWC Subscriber + pSids = append(pSids, "7") + sendA("SUB > 7\r\n") + sendA("PING\r\n") + expectA(pongRe) + + // Send to B + sendB("PUB foo 2\r\nok\r\n") + sendB("PING\r\n") + expectB(pongRe) + + // Should receive 5. + matches = expectMsgsA(5) + checkForQueueSid(t, matches, qg1Sids_a) + checkForPubSids(t, matches, pSids) + + // Send to A + sendA("PUB foo 2\r\nok\r\n") + + // Should receive 5. + matches = expectMsgsA(5) + checkForQueueSid(t, matches, qg1Sids_a) + checkForPubSids(t, matches, pSids) + + // Now add queue subscribers to B + qg2Sids_b := []string{"1", "2", "3"} + for _, sid := range qg2Sids_b { + sendB(fmt.Sprintf("SUB foo qg2 %s\r\n", sid)) + } + sendB("PING\r\n") + expectB(pongRe) + + // Send to B + sendB("PUB foo 2\r\nok\r\n") + + // Should receive 1 from B. + matches = expectMsgsB(1) + checkForQueueSid(t, matches, qg2Sids_b) + + // Should receive 5 still from A. + matches = expectMsgsA(5) + checkForQueueSid(t, matches, qg1Sids_a) + checkForPubSids(t, matches, pSids) + + // Now drop queue subscribers from A + for _, sid := range qg1Sids_a { + sendA(fmt.Sprintf("UNSUB %s\r\n", sid)) + } + sendA("PING\r\n") + expectA(pongRe) + + // Send to B + sendB("PUB foo 2\r\nok\r\n") + sendB("PING\r\n") + expectB(pongRe) + + // Should receive 4 now. + matches = expectMsgsA(4) + checkForPubSids(t, matches, pSids) + + // Send to A + sendA("PUB foo 2\r\nok\r\n") + + // Should receive 4 now. + matches = expectMsgsA(4) + checkForPubSids(t, matches, pSids) +} + +// Issue #22 +func TestClusterDoubleMsgs(t *testing.T) { + srvA, srvB, optsA, optsB := runServers(t) + defer srvA.Shutdown() + defer srvB.Shutdown() + + clientA1 := createClientConn(t, optsA.Host, optsA.Port) + defer clientA1.Close() + + clientA2 := createClientConn(t, optsA.Host, optsA.Port) + defer clientA2.Close() + + clientB := createClientConn(t, optsB.Host, optsB.Port) + defer clientB.Close() + + sendA1, expectA1 := setupConn(t, clientA1) + sendA2, expectA2 := setupConn(t, clientA2) + sendB, expectB := setupConn(t, clientB) + + expectMsgsA1 := expectMsgsCommand(t, expectA1) + expectMsgsA2 := expectMsgsCommand(t, expectA2) + + // Capture sids for checking later. + qg1Sids_a := []string{"1", "2", "3"} + + // Three queue subscribers + for _, sid := range qg1Sids_a { + sendA1(fmt.Sprintf("SUB foo qg1 %s\r\n", sid)) + } + sendA1("PING\r\n") + expectA1(pongRe) + + sendB("PUB foo 2\r\nok\r\n") + sendB("PING\r\n") + expectB(pongRe) + + // Make sure we get only 1. + matches := expectMsgsA1(1) + checkMsg(t, matches[0], "foo", "", "", "2", "ok") + checkForQueueSid(t, matches, qg1Sids_a) + + // Add a FWC subscriber on A2 + sendA2("SUB > 1\r\n") + sendA2("SUB foo 2\r\n") + sendA2("PING\r\n") + expectA2(pongRe) + pSids := []string{"1", "2"} + + sendB("PUB foo 2\r\nok\r\n") + sendB("PING\r\n") + expectB(pongRe) + + matches = expectMsgsA1(1) + checkMsg(t, matches[0], "foo", "", "", "2", "ok") + checkForQueueSid(t, matches, qg1Sids_a) + + matches = expectMsgsA2(2) + checkMsg(t, matches[0], "foo", "", "", "2", "ok") + checkForPubSids(t, matches, pSids) + + // Close ClientA1 + clientA1.Close() + +// time.Sleep(time.Second) + + sendB("PUB foo 2\r\nok\r\n") + sendB("PING\r\n") + expectB(pongRe) + + matches = expectMsgsA2(2) + checkMsg(t, matches[0], "foo", "", "", "2", "ok") + checkForPubSids(t, matches, pSids) +} + +// This will test that we drop remote sids correctly. +func TestClusterDropsRemoteSids(t *testing.T) { + srvA, srvB, optsA, _ := runServers(t) + defer srvA.Shutdown() + defer srvB.Shutdown() + + clientA := createClientConn(t, optsA.Host, optsA.Port) + defer clientA.Close() + + sendA, expectA := setupConn(t, clientA) + + // Add a subscription + sendA("SUB foo 1\r\n") + sendA("PING\r\n") + expectA(pongRe) + + // Wait for propogation. + time.Sleep(50 * time.Millisecond) + + if sc := srvA.NumSubscriptions(); sc != 1 { + t.Fatalf("Expected one subscription for srvA, got %d\n", sc) + } + if sc := srvB.NumSubscriptions(); sc != 1 { + t.Fatalf("Expected one subscription for srvB, got %d\n", sc) + } + + // Add another subscription + sendA("SUB bar 2\r\n") + sendA("PING\r\n") + expectA(pongRe) + + // Wait for propogation. + time.Sleep(50 * time.Millisecond) + + if sc := srvA.NumSubscriptions(); sc != 2 { + t.Fatalf("Expected two subscriptions for srvA, got %d\n", sc) + } + if sc := srvB.NumSubscriptions(); sc != 2 { + t.Fatalf("Expected two subscriptions for srvB, got %d\n", sc) + } + + // unsubscription + sendA("UNSUB 1\r\n") + sendA("PING\r\n") + expectA(pongRe) + + // Wait for propogation. + time.Sleep(50 * time.Millisecond) + + if sc := srvA.NumSubscriptions(); sc != 1 { + t.Fatalf("Expected one subscription for srvA, got %d\n", sc) + } + if sc := srvB.NumSubscriptions(); sc != 1 { + t.Fatalf("Expected one subscription for srvB, got %d\n", sc) + } + + // Close the client and make sure we remove subscription state. + clientA.Close() + + // Wait for propogation. + time.Sleep(50 * time.Millisecond) + if sc := srvA.NumSubscriptions(); sc != 0 { + t.Fatalf("Expected no subscriptions for srvA, got %d\n", sc) + } + if sc := srvB.NumSubscriptions(); sc != 0 { + t.Fatalf("Expected no subscriptions for srvB, got %d\n", sc) + } +} diff --git a/test/routes_test.go b/test/routes_test.go index 3e78c9ca..c369a3ec 100644 --- a/test/routes_test.go +++ b/test/routes_test.go @@ -415,17 +415,34 @@ func TestRouteResendsLocalSubsOnReconnect(t *testing.T) { clientExpect(pongRe) route := createRouteConn(t, opts.ClusterHost, opts.ClusterPort) - _, routeExpect := setupRouteEx(t, route, opts, "ROUTE:4222") + routeSend, routeExpect := setupRouteEx(t, route, opts, "ROUTE:4222") // Expect to see the local sub echoed through. + buf := routeExpect(infoRe) + + // Generate our own so we can send one to trigger the local subs. + info := server.Info{} + if err := json.Unmarshal(buf[4:], &info); err != nil { + t.Fatalf("Could not unmarshal route info: %v", err) + } + info.ID = "ROUTE:4222" + b, err := json.Marshal(info) + if err != nil { + t.Fatalf("Could not marshal test route info: %v", err) + } + infoJson := fmt.Sprintf("INFO %s\r\n", b) + + routeSend(infoJson) routeExpect(subRe) // Close and re-open route.Close() route = createRouteConn(t, opts.ClusterHost, opts.ClusterPort) - _, routeExpect = setupRouteEx(t, route, opts, "ROUTE:4222") + routeSend, routeExpect = setupRouteEx(t, route, opts, "ROUTE:4222") - // Expect to see the local sub echoed through. + // Expect to see the local sub echoed through after info. + routeExpect(infoRe) + routeSend(infoJson) routeExpect(subRe) } diff --git a/test/test.go b/test/test.go index b46fb9ce..51dbc329 100644 --- a/test/test.go +++ b/test/test.go @@ -337,3 +337,48 @@ func expectMsgsCommand(t tLogger, ef expectFun) func(int) [][][]byte { return matches } } + +// This will check that the matches include at least one of the sids. Useful for checking +// that we received messages on a certain queue group. +func checkForQueueSid(t tLogger, matches [][][]byte, sids []string) { + seen := make(map[string]int, len(sids)) + for _, sid := range sids { + seen[sid] = 0 + } + for _, m := range matches { + sid := string(m[SID_INDEX]) + if _, ok := seen[sid]; ok { + seen[sid] += 1 + } + } + // Make sure we only see one and exactly one. + total := 0 + for _, n := range seen { + total += n + } + if total != 1 { + stackFatalf(t, "Did not get a msg for queue sids group: expected 1 got %d\n", total) + } +} + +// This will check that the matches include all of the sids. Useful for checking +// that we received messages on all subscribers. +func checkForPubSids(t tLogger, matches [][][]byte, sids []string) { + seen := make(map[string]int, len(sids)) + for _, sid := range sids { + seen[sid] = 0 + } + for _, m := range matches { + sid := string(m[SID_INDEX]) + if _, ok := seen[sid]; ok { + seen[sid] += 1 + } + } + // Make sure we only see one and exactly one for each sid. + for sid, n := range seen { + if n != 1 { + stackFatalf(t, "Did not get a msg for sid[%s]: expected 1 got %d\n", sid, n) + + } + } +}