diff --git a/server/accounts.go b/server/accounts.go index 437b548c..39ab2e99 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -41,6 +41,7 @@ type Account struct { claimJWT string updated time.Time mu sync.RWMutex + sqmu sync.Mutex sl *Sublist etmr *time.Timer ctmr *time.Timer diff --git a/server/accounts_test.go b/server/accounts_test.go index a0fd945d..a7ecaec5 100644 --- a/server/accounts_test.go +++ b/server/accounts_test.go @@ -20,9 +20,11 @@ import ( "os" "strconv" "strings" + "sync" "testing" "time" + "github.com/nats-io/nats.go" "github.com/nats-io/nkeys" ) @@ -2054,6 +2056,47 @@ func TestAccountCheckStreamImportsEqual(t *testing.T) { } } +func TestAccountNoDeadlockOnQueueSubRouteMapUpdate(t *testing.T) { + opts := DefaultOptions() + s := RunServer(opts) + defer s.Shutdown() + + nc, err := nats.Connect(fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + nc.QueueSubscribeSync("foo", "bar") + + var accs []*Account + for i := 0; i < 10; i++ { + acc, _ := s.RegisterAccount(fmt.Sprintf("acc%d", i)) + acc.mu.Lock() + accs = append(accs, acc) + } + + opts2 := DefaultOptions() + opts2.Routes = RoutesFromStr(fmt.Sprintf("nats://%s:%d", opts.Cluster.Host, opts.Cluster.Port)) + s2 := RunServer(opts2) + defer s2.Shutdown() + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + time.Sleep(100 * time.Millisecond) + for _, acc := range accs { + acc.mu.Unlock() + } + wg.Done() + }() + + nc.QueueSubscribeSync("foo", "bar") + nc.Flush() + + wg.Wait() +} + func BenchmarkNewRouteReply(b *testing.B) { opts := defaultServerOptions s := New(&opts) diff --git a/server/route.go b/server/route.go index f5e15d88..8b589c75 100644 --- a/server/route.go +++ b/server/route.go @@ -1320,18 +1320,36 @@ func (s *Server) updateRouteSubscriptionMap(acc *Account, sub *subscription, del var n int32 var ok bool - acc.mu.Lock() + isq := len(sub.queue) > 0 + + accLock := func() { + // Not required for 
code correctness, but helps reduce the number of + // updates sent to the routes when processing high number of concurrent + // queue subscriptions updates (sub/unsub). + // See https://github.com/nats-io/nats-server/pull/1126 for more details. + if isq { + acc.sqmu.Lock() + } + acc.mu.Lock() + } + accUnlock := func() { + acc.mu.Unlock() + if isq { + acc.sqmu.Unlock() + } + } + + accLock() // This is non-nil when we know we are in cluster mode. rm, lqws := acc.rm, acc.lqws if rm == nil { - acc.mu.Unlock() + accUnlock() return } // Create the fast key which will use the subject or 'subjectqueue' for queue subscribers. key := keyFromSub(sub) - isq := len(sub.queue) > 0 // Decide whether we need to send an update out to all the routes. update := isq @@ -1356,7 +1374,7 @@ func (s *Server) updateRouteSubscriptionMap(acc *Account, sub *subscription, del update = true // Adding a new entry for normal sub means update (0->1) } - acc.mu.Unlock() + accUnlock() if !update { return @@ -1388,17 +1406,23 @@ func (s *Server) updateRouteSubscriptionMap(acc *Account, sub *subscription, del // here but not necessarily all updates need to be sent. We need to block and recheck the // n count with the lock held through sending here. We will suppress duplicate sends of same qw. if isq { + // However, we can't hold the acc.mu lock since we allow client.mu.Lock -> acc.mu.Lock + // but not the opposite. So use a dedicated lock while holding the route's lock. + acc.sqmu.Lock() + defer acc.sqmu.Unlock() + acc.mu.Lock() - defer acc.mu.Unlock() n = rm[key] sub.qw = n // Check the last sent weight here. If same, then someone // beat us to it and we can just return here. 
Otherwise update if ls, ok := lqws[key]; ok && ls == n { + acc.mu.Unlock() return } else { lqws[key] = n } + acc.mu.Unlock() } // Snapshot into array diff --git a/test/norace_test.go b/test/norace_test.go index d194301d..93457699 100644 --- a/test/norace_test.go +++ b/test/norace_test.go @@ -380,6 +380,7 @@ func TestQueueSubWeightOrderMultipleConnections(t *testing.T) { // we just want to make sure we always are increasing and that a previous update to // a lesser queue weight is never delivered for this test. maxExpected := 10000 + updates := 0 for qw := 0; qw < maxExpected; { buf := routeExpect(rsubRe) matches := rsubRe.FindAllSubmatch(buf, -1) @@ -397,6 +398,10 @@ func TestQueueSubWeightOrderMultipleConnections(t *testing.T) { t.Fatalf("Was expecting increasing queue weight after %d, got %d", qw, nqw) } qw = nqw + updates++ } } + if updates >= maxExpected { + t.Fatalf("Was not expecting all %v updates to be received", maxExpected) + } }