diff --git a/server/accounts.go b/server/accounts.go index 8c6d1476..9c5598aa 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -291,19 +291,15 @@ func (a *Account) numLocalLeafNodes() int { // MaxTotalConnectionsReached returns if we have reached our limit for number of connections. func (a *Account) MaxTotalConnectionsReached() bool { + var mtc bool a.mu.RLock() - mtc := a.maxTotalConnectionsReached() + if a.mconns != jwt.NoLimit { + mtc = len(a.clients)-int(a.sysclients)+int(a.nrclients) >= int(a.mconns) + } a.mu.RUnlock() return mtc } -func (a *Account) maxTotalConnectionsReached() bool { - if a.mconns != jwt.NoLimit { - return len(a.clients)-int(a.sysclients)+int(a.nrclients) >= int(a.mconns) - } - return false -} - // MaxActiveConnections return the set limit for the account system // wide for total number of active connections. func (a *Account) MaxActiveConnections() int { @@ -1456,8 +1452,9 @@ func (s *Server) SetAccountResolver(ar AccountResolver) { // AccountResolver returns the registered account resolver. func (s *Server) AccountResolver() AccountResolver { s.mu.Lock() - defer s.mu.Unlock() - return s.accResolver + ar := s.accResolver + s.mu.Unlock() + return ar } // UpdateAccountClaims will call updateAccountClaims. @@ -1467,6 +1464,7 @@ func (s *Server) UpdateAccountClaims(a *Account, ac *jwt.AccountClaims) { // updateAccountClaims will update an existing account with new claims. // This will replace any exports or imports previously defined. +// Lock MUST NOT be held upon entry. func (s *Server) updateAccountClaims(a *Account, ac *jwt.AccountClaims) { if a == nil { return @@ -1547,22 +1545,10 @@ func (s *Server) updateAccountClaims(a *Account, ac *jwt.AccountClaims) { } } for _, i := range ac.Imports { - var acc *Account - if v, ok := s.accounts.Load(i.Account); ok { - acc = v.(*Account) - } - if acc == nil { - // Check to see if the account referenced is not one that - // we are currently building (but not yet fully registered). - if v, ok := s.tmpAccounts.Load(i.Account); ok { - acc = v.(*Account) - } - } - if acc == nil { - if acc, _ = s.fetchAccount(i.Account); acc == nil { - s.Debugf("Can't locate account [%s] for import of [%v] %s", i.Account, i.Subject, i.Type) - continue - } + acc, err := s.lookupAccount(i.Account) + if acc == nil || err != nil { + s.Errorf("Can't locate account [%s] for import of [%v] %s (err=%v)", i.Account, i.Subject, i.Type, err) + continue } switch i.Type { case jwt.Stream: @@ -1645,14 +1631,17 @@ func (s *Server) updateAccountClaims(a *Account, ac *jwt.AccountClaims) { clients := gatherClients() // Sort if we are over the limit. - if a.maxTotalConnectionsReached() { + if a.MaxTotalConnectionsReached() { sort.Slice(clients, func(i, j int) bool { return clients[i].start.After(clients[j].start) }) } now := time.Now().Unix() for i, c := range clients { - if a.mconns != jwt.NoLimit && i >= int(a.mconns) { + a.mu.RLock() + exceeded := a.mconns != jwt.NoLimit && i >= int(a.mconns) + a.mu.RUnlock() + if exceeded { c.maxAccountConnExceeded() continue } @@ -1690,6 +1679,7 @@ func (s *Server) updateAccountClaims(a *Account, ac *jwt.AccountClaims) { } // Helper to build an internal account structure from a jwt.AccountClaims. +// Lock MUST NOT be held upon entry. func (s *Server) buildInternalAccount(ac *jwt.AccountClaims) *Account { acc := NewAccount(ac.Subject) acc.Issuer = ac.Issuer diff --git a/server/events.go b/server/events.go index 49705270..e213f497 100644 --- a/server/events.go +++ b/server/events.go @@ -572,9 +572,7 @@ func (s *Server) initEventTracking() { // accountClaimUpdate will receive claim updates for accounts. func (s *Server) accountClaimUpdate(sub *subscription, _ *client, subject, reply string, msg []byte) { - s.mu.Lock() - defer s.mu.Unlock() - if !s.eventsEnabled() { + if !s.EventsEnabled() { return } toks := strings.Split(subject, tsep) @@ -903,8 +901,8 @@ func (s *Server) accountDisconnectEvent(c *client, now time.Time, reason string) RTT: c.getRTT(), }, Sent: DataStats{ - Msgs: c.inMsgs, - Bytes: c.inBytes, + Msgs: atomic.LoadInt64(&c.inMsgs), + Bytes: atomic.LoadInt64(&c.inBytes), }, Received: DataStats{ Msgs: c.outMsgs, diff --git a/server/events_test.go b/server/events_test.go index 2a21083c..e4a2dc20 100644 --- a/server/events_test.go +++ b/server/events_test.go @@ -828,7 +828,7 @@ func TestSystemAccountConnectionLimitsServersStaggered(t *testing.T) { } // Restart server B. - optsB.AccountResolver = sa.accResolver + optsB.AccountResolver = sa.AccountResolver() optsB.SystemAccount = sa.systemAccount().Name sb = RunServer(optsB) defer sb.Shutdown() @@ -1409,10 +1409,8 @@ func TestFetchAccountRace(t *testing.T) { // Replace B's account resolver with one that introduces // delay during the Fetch() - sb.mu.Lock() - sac := &slowAccResolver{AccountResolver: sb.accResolver} - sb.accResolver = sac - sb.mu.Unlock() + sac := &slowAccResolver{AccountResolver: sb.AccountResolver()} + sb.SetAccountResolver(sac) // Add the account in sa and sb addAccountToMemResolver(sa, userAcc, jwt) diff --git a/server/jwt_test.go b/server/jwt_test.go index 84a12d71..5c5721c4 100644 --- a/server/jwt_test.go +++ b/server/jwt_test.go @@ -46,15 +46,11 @@ func opTrustBasicSetup() *Server { func buildMemAccResolver(s *Server) { mr := &MemAccResolver{} - s.mu.Lock() - s.accResolver = mr - s.mu.Unlock() + s.SetAccountResolver(mr) } func addAccountToMemResolver(s *Server, pub, jwtclaim string) { - s.mu.Lock() - s.accResolver.Store(pub, jwtclaim) - s.mu.Unlock() + s.AccountResolver().Store(pub, jwtclaim) } func createClient(t *testing.T, s *Server, akp nkeys.KeyPair) (*client, *bufio.Reader, string) { @@ -2318,3 +2314,81 @@ func TestJWTCircularAccountServiceImport(t *testing.T) { parseAsync("SUB foo 1\r\nPING\r\n") expectPong(cr) } + +// This test ensures that connected clients are properly evicted +// (no deadlock) if the max conns of an account has been lowered +// and the account is being updated (following expiration during +// a lookup). +func TestJWTAccountLimitsMaxConnsAfterExpired(t *testing.T) { + s := opTrustBasicSetup() + defer s.Shutdown() + buildMemAccResolver(s) + + okp, _ := nkeys.FromSeed(oSeed) + + // Create accounts and imports/exports. + fooKP, _ := nkeys.CreateAccount() + fooPub, _ := fooKP.PublicKey() + fooAC := jwt.NewAccountClaims(fooPub) + fooAC.Limits.Conn = 10 + fooJWT, err := fooAC.Encode(okp) + if err != nil { + t.Fatalf("Error generating account JWT: %v", err) + } + addAccountToMemResolver(s, fooPub, fooJWT) + + newClient := func(expPre string) { + t.Helper() + // Create a client. + c, cr, cs := createClient(t, s, fooKP) + go c.parse([]byte(cs)) + l, _ := cr.ReadString('\n') + if !strings.HasPrefix(l, expPre) { + t.Fatalf("Expected a response starting with %q, got %q", expPre, l) + } + go func() { + for { + if _, _, err := cr.ReadLine(); err != nil { + return + } + } + }() + } + + for i := 0; i < 4; i++ { + newClient("PONG") + } + + // We will simulate that the account has expired. When + // a new client will connect, the server will do a lookup + // and find the account expired, which then will cause + // a fetch and a rebuild of the account. Since max conns + // is now lower, some clients should have been removed. + acc, _ := s.LookupAccount(fooPub) + acc.mu.Lock() + acc.expired = true + acc.mu.Unlock() + + // Now update with new expiration and max connections lowered to 2 + fooAC.Limits.Conn = 2 + fooJWT, err = fooAC.Encode(okp) + if err != nil { + t.Fatalf("Error generating account JWT: %v", err) + } + addAccountToMemResolver(s, fooPub, fooJWT) + + // Cause the lookup that will detect that account was expired + // and rebuild it, and kick clients out. + newClient("-ERR ") + + acc, _ = s.LookupAccount(fooPub) + checkFor(t, 2*time.Second, 15*time.Millisecond, func() error { + acc.mu.RLock() + numClients := len(acc.clients) + acc.mu.RUnlock() + if numClients != 2 { + return fmt.Errorf("Should have 2 clients, got %v", numClients) + } + return nil + }) +} diff --git a/server/reload.go b/server/reload.go index 9260d3de..cde8767d 100644 --- a/server/reload.go +++ b/server/reload.go @@ -912,6 +912,8 @@ func (s *Server) reloadAuthorization() { acc.mu.RLock() accName := acc.Name acc.mu.RUnlock() + // Release server lock for following actions + s.mu.Unlock() accClaims, claimJWT, _ := s.fetchAccountClaims(accName) if accClaims != nil { err := s.updateAccountWithClaimJWT(acc, claimJWT) @@ -923,9 +925,10 @@ func (s *Server) reloadAuthorization() { s.Noticef("Reloaded: deleting account [removed]: %q", accName) s.accounts.Delete(k) } + // Regrab server lock. + s.mu.Lock() return true }) - } } diff --git a/server/server.go b/server/server.go index fd9cf56a..e1adec36 100644 --- a/server/server.go +++ b/server/server.go @@ -312,8 +312,8 @@ func NewServer(opts *Options) (*Server, error) { if r.LocalAccount == _EMPTY_ { continue } - if _, err := s.lookupAccount(r.LocalAccount); err != nil { - return nil, fmt.Errorf("no local account %q for leafnode: %v", r.LocalAccount, err) + if _, ok := s.accounts.Load(r.LocalAccount); !ok { + return nil, fmt.Errorf("no local account %q for remote leafnode", r.LocalAccount) } } } @@ -375,6 +375,7 @@ func (s *Server) globalAccount() *Account { } // Used to setup Accounts. +// Lock is held upon entry. func (s *Server) configureAccounts() error { // Create global account. if s.gacc == nil { @@ -437,7 +438,7 @@ func (s *Server) configureAccounts() error { // Set the system account if it was configured. if opts.SystemAccount != _EMPTY_ { - // Lock is held entering this function, so release to call lookupAccount. + // Lock may be acquired in lookupAccount, so release to call lookupAccount. s.mu.Unlock() _, err := s.lookupAccount(opts.SystemAccount) s.mu.Lock() @@ -684,16 +685,15 @@ func (s *Server) SetSystemAccount(accName string) error { return s.setSystemAccount(v.(*Account)) } - s.mu.Lock() // If we are here we do not have local knowledge of this account. // Do this one by hand to return more useful error. ac, jwt, err := s.fetchAccountClaims(accName) if err != nil { - s.mu.Unlock() return err } acc := s.buildInternalAccount(ac) acc.claimJWT = jwt + s.mu.Lock() s.registerAccount(acc) s.mu.Unlock() @@ -842,34 +842,35 @@ func (s *Server) registerAccount(acc *Account) { // lookupAccount is a function to return the account structure // associated with an account name. +// Lock MUST NOT be held upon entry. func (s *Server) lookupAccount(name string) (*Account, error) { + var acc *Account if v, ok := s.accounts.Load(name); ok { - acc := v.(*Account) + acc = v.(*Account) + } else if v, ok := s.tmpAccounts.Load(name); ok { + acc = v.(*Account) + } + if acc != nil { // If we are expired and we have a resolver, then // return the latest information from the resolver. if acc.IsExpired() { s.Debugf("Requested account [%s] has expired", name) - var err error - s.mu.Lock() - if s.accResolver != nil { - err = s.updateAccount(acc) - } - s.mu.Unlock() - if err != nil { - // This error could mask expired, so just return expired here. + if s.AccountResolver() != nil { + if err := s.updateAccount(acc); err != nil { + // This error could mask expired, so just return expired here. + return nil, ErrAccountExpired + } + } else { return nil, ErrAccountExpired } } return acc, nil } // If we have a resolver see if it can fetch the account. - if s.accResolver == nil { + if s.AccountResolver() == nil { return nil, ErrMissingAccount } - s.mu.Lock() - acc, err := s.fetchAccount(name) - s.mu.Unlock() - return acc, err + return s.fetchAccount(name) } // LookupAccount is a public function to return the account structure @@ -879,7 +880,7 @@ func (s *Server) LookupAccount(name string) (*Account, error) { } // This will fetch new claims and if found update the account with new claims. -// Lock should be held upon entry. +// Lock MUST NOT be held upon entry. func (s *Server) updateAccount(acc *Account) error { // TODO(dlc) - Make configurable if time.Since(acc.updated) < time.Second { @@ -894,6 +895,7 @@ func (s *Server) updateAccount(acc *Account) error { } // updateAccountWithClaimJWT will check and apply the claim update. +// Lock MUST NOT be held upon entry. func (s *Server) updateAccountWithClaimJWT(acc *Account, claimJWT string) error { if acc == nil { return ErrMissingAccount @@ -913,18 +915,16 @@ func (s *Server) updateAccountWithClaimJWT(acc *Account, claimJWT string) error } // fetchRawAccountClaims will grab raw account claims iff we have a resolver. -// Lock is held upon entry. +// Lock is NOT held upon entry. func (s *Server) fetchRawAccountClaims(name string) (string, error) { - accResolver := s.accResolver + accResolver := s.AccountResolver() if accResolver == nil { return "", ErrNoAccountResolver } - // Need to do actual Fetch without the lock. - s.mu.Unlock() + // Need to do actual Fetch start := time.Now() claimJWT, err := accResolver.Fetch(name) fetchTime := time.Since(start) - s.mu.Lock() if fetchTime > time.Second { s.Warnf("Account [%s] fetch took %v", name, fetchTime) } else { @@ -938,7 +938,7 @@ func (s *Server) fetchRawAccountClaims(name string) (string, error) { } // fetchAccountClaims will attempt to fetch new claims if a resolver is present. -// Lock is held upon entry. +// Lock is NOT held upon entry. func (s *Server) fetchAccountClaims(name string) (*jwt.AccountClaims, string, error) { claimJWT, err := s.fetchRawAccountClaims(name) if err != nil { @@ -962,7 +962,7 @@ func (s *Server) verifyAccountClaims(claimJWT string) (*jwt.AccountClaims, strin } // This will fetch an account from a resolver if defined. -// Lock should be held upon entry. +// Lock is NOT held upon entry. func (s *Server) fetchAccount(name string) (*Account, error) { accClaims, claimJWT, err := s.fetchAccountClaims(name) if accClaims != nil { @@ -982,7 +982,9 @@ func (s *Server) fetchAccount(name string) (*Account, error) { } acc := s.buildInternalAccount(accClaims) acc.claimJWT = claimJWT + s.mu.Lock() s.registerAccount(acc) + s.mu.Unlock() return acc, nil } return nil, err diff --git a/server/split_test.go b/server/split_test.go index b0029380..5fa4286d 100644 --- a/server/split_test.go +++ b/server/split_test.go @@ -29,7 +29,9 @@ func TestSplitBufferSubOp(t *testing.T) { t.Fatalf("Error creating gateways: %v", err) } s := &Server{gacc: NewAccount(globalAccountName), gateway: gws} + s.mu.Lock() s.registerAccount(s.gacc) + s.mu.Unlock() c := &client{srv: s, acc: s.gacc, msubs: -1, mpay: -1, mcl: 1024, subs: make(map[string]*subscription), nc: cli} subop := []byte("SUB foo 1\r\n") @@ -66,7 +68,9 @@ func TestSplitBufferSubOp(t *testing.T) { func TestSplitBufferUnsubOp(t *testing.T) { s := &Server{gacc: NewAccount(globalAccountName), gateway: &srvGateway{}} + s.mu.Lock() s.registerAccount(s.gacc) + s.mu.Unlock() c := &client{srv: s, acc: s.gacc, msubs: -1, mpay: -1, mcl: 1024, subs: make(map[string]*subscription)} subop := []byte("SUB foo 1024\r\n")