mirror of
https://github.com/gogrlx/nats-server.git
synced 2026-04-02 11:48:43 -07:00
[FIXED] Locking issue around account lookup/updates
Ensure that lookupAccount does not hold the server lock during updateAccount and fetchAccount. Updating the account cannot be done while holding the server lock because it is possible that, during updateAccountClaims(), clients are being removed, which would try to get the server lock (deep down in closeConnection/s.removeClient). Added a test that would have shown the deadlock prior to the changes in this PR. Signed-off-by: Ivan Kozlovic <ivan@synadia.com>
This commit is contained in:
@@ -291,19 +291,15 @@ func (a *Account) numLocalLeafNodes() int {
|
||||
|
||||
// MaxTotalConnectionsReached reports whether we have reached our limit for the number of connections.
|
||||
func (a *Account) MaxTotalConnectionsReached() bool {
|
||||
var mtc bool
|
||||
a.mu.RLock()
|
||||
mtc := a.maxTotalConnectionsReached()
|
||||
if a.mconns != jwt.NoLimit {
|
||||
mtc = len(a.clients)-int(a.sysclients)+int(a.nrclients) >= int(a.mconns)
|
||||
}
|
||||
a.mu.RUnlock()
|
||||
return mtc
|
||||
}
|
||||
|
||||
func (a *Account) maxTotalConnectionsReached() bool {
|
||||
if a.mconns != jwt.NoLimit {
|
||||
return len(a.clients)-int(a.sysclients)+int(a.nrclients) >= int(a.mconns)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MaxActiveConnections returns the set limit for the account system
|
||||
// wide for total number of active connections.
|
||||
func (a *Account) MaxActiveConnections() int {
|
||||
@@ -1456,8 +1452,9 @@ func (s *Server) SetAccountResolver(ar AccountResolver) {
|
||||
// AccountResolver returns the registered account resolver.
|
||||
func (s *Server) AccountResolver() AccountResolver {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.accResolver
|
||||
ar := s.accResolver
|
||||
s.mu.Unlock()
|
||||
return ar
|
||||
}
|
||||
|
||||
// UpdateAccountClaims will call updateAccountClaims.
|
||||
@@ -1467,6 +1464,7 @@ func (s *Server) UpdateAccountClaims(a *Account, ac *jwt.AccountClaims) {
|
||||
|
||||
// updateAccountClaims will update an existing account with new claims.
|
||||
// This will replace any exports or imports previously defined.
|
||||
// Lock MUST NOT be held upon entry.
|
||||
func (s *Server) updateAccountClaims(a *Account, ac *jwt.AccountClaims) {
|
||||
if a == nil {
|
||||
return
|
||||
@@ -1547,22 +1545,10 @@ func (s *Server) updateAccountClaims(a *Account, ac *jwt.AccountClaims) {
|
||||
}
|
||||
}
|
||||
for _, i := range ac.Imports {
|
||||
var acc *Account
|
||||
if v, ok := s.accounts.Load(i.Account); ok {
|
||||
acc = v.(*Account)
|
||||
}
|
||||
if acc == nil {
|
||||
// Check to see if the account referenced is not one that
|
||||
// we are currently building (but not yet fully registered).
|
||||
if v, ok := s.tmpAccounts.Load(i.Account); ok {
|
||||
acc = v.(*Account)
|
||||
}
|
||||
}
|
||||
if acc == nil {
|
||||
if acc, _ = s.fetchAccount(i.Account); acc == nil {
|
||||
s.Debugf("Can't locate account [%s] for import of [%v] %s", i.Account, i.Subject, i.Type)
|
||||
continue
|
||||
}
|
||||
acc, err := s.lookupAccount(i.Account)
|
||||
if acc == nil || err != nil {
|
||||
s.Errorf("Can't locate account [%s] for import of [%v] %s (err=%v)", i.Account, i.Subject, i.Type, err)
|
||||
continue
|
||||
}
|
||||
switch i.Type {
|
||||
case jwt.Stream:
|
||||
@@ -1645,14 +1631,17 @@ func (s *Server) updateAccountClaims(a *Account, ac *jwt.AccountClaims) {
|
||||
|
||||
clients := gatherClients()
|
||||
// Sort if we are over the limit.
|
||||
if a.maxTotalConnectionsReached() {
|
||||
if a.MaxTotalConnectionsReached() {
|
||||
sort.Slice(clients, func(i, j int) bool {
|
||||
return clients[i].start.After(clients[j].start)
|
||||
})
|
||||
}
|
||||
now := time.Now().Unix()
|
||||
for i, c := range clients {
|
||||
if a.mconns != jwt.NoLimit && i >= int(a.mconns) {
|
||||
a.mu.RLock()
|
||||
exceeded := a.mconns != jwt.NoLimit && i >= int(a.mconns)
|
||||
a.mu.RUnlock()
|
||||
if exceeded {
|
||||
c.maxAccountConnExceeded()
|
||||
continue
|
||||
}
|
||||
@@ -1690,6 +1679,7 @@ func (s *Server) updateAccountClaims(a *Account, ac *jwt.AccountClaims) {
|
||||
}
|
||||
|
||||
// Helper to build an internal account structure from a jwt.AccountClaims.
|
||||
// Lock MUST NOT be held upon entry.
|
||||
func (s *Server) buildInternalAccount(ac *jwt.AccountClaims) *Account {
|
||||
acc := NewAccount(ac.Subject)
|
||||
acc.Issuer = ac.Issuer
|
||||
|
||||
@@ -572,9 +572,7 @@ func (s *Server) initEventTracking() {
|
||||
|
||||
// accountClaimUpdate will receive claim updates for accounts.
|
||||
func (s *Server) accountClaimUpdate(sub *subscription, _ *client, subject, reply string, msg []byte) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if !s.eventsEnabled() {
|
||||
if !s.EventsEnabled() {
|
||||
return
|
||||
}
|
||||
toks := strings.Split(subject, tsep)
|
||||
@@ -903,8 +901,8 @@ func (s *Server) accountDisconnectEvent(c *client, now time.Time, reason string)
|
||||
RTT: c.getRTT(),
|
||||
},
|
||||
Sent: DataStats{
|
||||
Msgs: c.inMsgs,
|
||||
Bytes: c.inBytes,
|
||||
Msgs: atomic.LoadInt64(&c.inMsgs),
|
||||
Bytes: atomic.LoadInt64(&c.inBytes),
|
||||
},
|
||||
Received: DataStats{
|
||||
Msgs: c.outMsgs,
|
||||
|
||||
@@ -828,7 +828,7 @@ func TestSystemAccountConnectionLimitsServersStaggered(t *testing.T) {
|
||||
}
|
||||
|
||||
// Restart server B.
|
||||
optsB.AccountResolver = sa.accResolver
|
||||
optsB.AccountResolver = sa.AccountResolver()
|
||||
optsB.SystemAccount = sa.systemAccount().Name
|
||||
sb = RunServer(optsB)
|
||||
defer sb.Shutdown()
|
||||
@@ -1409,10 +1409,8 @@ func TestFetchAccountRace(t *testing.T) {
|
||||
|
||||
// Replace B's account resolver with one that introduces
|
||||
// delay during the Fetch()
|
||||
sb.mu.Lock()
|
||||
sac := &slowAccResolver{AccountResolver: sb.accResolver}
|
||||
sb.accResolver = sac
|
||||
sb.mu.Unlock()
|
||||
sac := &slowAccResolver{AccountResolver: sb.AccountResolver()}
|
||||
sb.SetAccountResolver(sac)
|
||||
|
||||
// Add the account in sa and sb
|
||||
addAccountToMemResolver(sa, userAcc, jwt)
|
||||
|
||||
@@ -46,15 +46,11 @@ func opTrustBasicSetup() *Server {
|
||||
|
||||
func buildMemAccResolver(s *Server) {
|
||||
mr := &MemAccResolver{}
|
||||
s.mu.Lock()
|
||||
s.accResolver = mr
|
||||
s.mu.Unlock()
|
||||
s.SetAccountResolver(mr)
|
||||
}
|
||||
|
||||
func addAccountToMemResolver(s *Server, pub, jwtclaim string) {
|
||||
s.mu.Lock()
|
||||
s.accResolver.Store(pub, jwtclaim)
|
||||
s.mu.Unlock()
|
||||
s.AccountResolver().Store(pub, jwtclaim)
|
||||
}
|
||||
|
||||
func createClient(t *testing.T, s *Server, akp nkeys.KeyPair) (*client, *bufio.Reader, string) {
|
||||
@@ -2318,3 +2314,81 @@ func TestJWTCircularAccountServiceImport(t *testing.T) {
|
||||
parseAsync("SUB foo 1\r\nPING\r\n")
|
||||
expectPong(cr)
|
||||
}
|
||||
|
||||
// This test ensures that connected clients are properly evicted
|
||||
// (no deadlock) if the max conns of an account has been lowered
|
||||
// and the account is being updated (following expiration during
|
||||
// a lookup).
|
||||
func TestJWTAccountLimitsMaxConnsAfterExpired(t *testing.T) {
|
||||
s := opTrustBasicSetup()
|
||||
defer s.Shutdown()
|
||||
buildMemAccResolver(s)
|
||||
|
||||
okp, _ := nkeys.FromSeed(oSeed)
|
||||
|
||||
// Create accounts and imports/exports.
|
||||
fooKP, _ := nkeys.CreateAccount()
|
||||
fooPub, _ := fooKP.PublicKey()
|
||||
fooAC := jwt.NewAccountClaims(fooPub)
|
||||
fooAC.Limits.Conn = 10
|
||||
fooJWT, err := fooAC.Encode(okp)
|
||||
if err != nil {
|
||||
t.Fatalf("Error generating account JWT: %v", err)
|
||||
}
|
||||
addAccountToMemResolver(s, fooPub, fooJWT)
|
||||
|
||||
newClient := func(expPre string) {
|
||||
t.Helper()
|
||||
// Create a client.
|
||||
c, cr, cs := createClient(t, s, fooKP)
|
||||
go c.parse([]byte(cs))
|
||||
l, _ := cr.ReadString('\n')
|
||||
if !strings.HasPrefix(l, expPre) {
|
||||
t.Fatalf("Expected a response starting with %q, got %q", expPre, l)
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
if _, _, err := cr.ReadLine(); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
for i := 0; i < 4; i++ {
|
||||
newClient("PONG")
|
||||
}
|
||||
|
||||
// We will simulate that the account has expired. When
|
||||
// a new client will connect, the server will do a lookup
|
||||
// and find the account expired, which then will cause
|
||||
// a fetch and a rebuild of the account. Since max conns
|
||||
// is now lower, some clients should have been removed.
|
||||
acc, _ := s.LookupAccount(fooPub)
|
||||
acc.mu.Lock()
|
||||
acc.expired = true
|
||||
acc.mu.Unlock()
|
||||
|
||||
// Now update with new expiration and max connections lowered to 2
|
||||
fooAC.Limits.Conn = 2
|
||||
fooJWT, err = fooAC.Encode(okp)
|
||||
if err != nil {
|
||||
t.Fatalf("Error generating account JWT: %v", err)
|
||||
}
|
||||
addAccountToMemResolver(s, fooPub, fooJWT)
|
||||
|
||||
// Cause the lookup that will detect that account was expired
|
||||
// and rebuild it, and kick clients out.
|
||||
newClient("-ERR ")
|
||||
|
||||
acc, _ = s.LookupAccount(fooPub)
|
||||
checkFor(t, 2*time.Second, 15*time.Millisecond, func() error {
|
||||
acc.mu.RLock()
|
||||
numClients := len(acc.clients)
|
||||
acc.mu.RUnlock()
|
||||
if numClients != 2 {
|
||||
return fmt.Errorf("Should have 2 clients, got %v", numClients)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
@@ -912,6 +912,8 @@ func (s *Server) reloadAuthorization() {
|
||||
acc.mu.RLock()
|
||||
accName := acc.Name
|
||||
acc.mu.RUnlock()
|
||||
// Release server lock for following actions
|
||||
s.mu.Unlock()
|
||||
accClaims, claimJWT, _ := s.fetchAccountClaims(accName)
|
||||
if accClaims != nil {
|
||||
err := s.updateAccountWithClaimJWT(acc, claimJWT)
|
||||
@@ -923,9 +925,10 @@ func (s *Server) reloadAuthorization() {
|
||||
s.Noticef("Reloaded: deleting account [removed]: %q", accName)
|
||||
s.accounts.Delete(k)
|
||||
}
|
||||
// Regrab server lock.
|
||||
s.mu.Lock()
|
||||
return true
|
||||
})
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -312,8 +312,8 @@ func NewServer(opts *Options) (*Server, error) {
|
||||
if r.LocalAccount == _EMPTY_ {
|
||||
continue
|
||||
}
|
||||
if _, err := s.lookupAccount(r.LocalAccount); err != nil {
|
||||
return nil, fmt.Errorf("no local account %q for leafnode: %v", r.LocalAccount, err)
|
||||
if _, ok := s.accounts.Load(r.LocalAccount); !ok {
|
||||
return nil, fmt.Errorf("no local account %q for remote leafnode", r.LocalAccount)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -375,6 +375,7 @@ func (s *Server) globalAccount() *Account {
|
||||
}
|
||||
|
||||
// Used to setup Accounts.
|
||||
// Lock is held upon entry.
|
||||
func (s *Server) configureAccounts() error {
|
||||
// Create global account.
|
||||
if s.gacc == nil {
|
||||
@@ -437,7 +438,7 @@ func (s *Server) configureAccounts() error {
|
||||
|
||||
// Set the system account if it was configured.
|
||||
if opts.SystemAccount != _EMPTY_ {
|
||||
// Lock is held entering this function, so release to call lookupAccount.
|
||||
// Lock may be acquired in lookupAccount, so release to call lookupAccount.
|
||||
s.mu.Unlock()
|
||||
_, err := s.lookupAccount(opts.SystemAccount)
|
||||
s.mu.Lock()
|
||||
@@ -684,16 +685,15 @@ func (s *Server) SetSystemAccount(accName string) error {
|
||||
return s.setSystemAccount(v.(*Account))
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
// If we are here we do not have local knowledge of this account.
|
||||
// Do this one by hand to return more useful error.
|
||||
ac, jwt, err := s.fetchAccountClaims(accName)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
acc := s.buildInternalAccount(ac)
|
||||
acc.claimJWT = jwt
|
||||
s.mu.Lock()
|
||||
s.registerAccount(acc)
|
||||
s.mu.Unlock()
|
||||
|
||||
@@ -842,34 +842,35 @@ func (s *Server) registerAccount(acc *Account) {
|
||||
|
||||
// lookupAccount is a function to return the account structure
|
||||
// associated with an account name.
|
||||
// Lock MUST NOT be held upon entry.
|
||||
func (s *Server) lookupAccount(name string) (*Account, error) {
|
||||
var acc *Account
|
||||
if v, ok := s.accounts.Load(name); ok {
|
||||
acc := v.(*Account)
|
||||
acc = v.(*Account)
|
||||
} else if v, ok := s.tmpAccounts.Load(name); ok {
|
||||
acc = v.(*Account)
|
||||
}
|
||||
if acc != nil {
|
||||
// If we are expired and we have a resolver, then
|
||||
// return the latest information from the resolver.
|
||||
if acc.IsExpired() {
|
||||
s.Debugf("Requested account [%s] has expired", name)
|
||||
var err error
|
||||
s.mu.Lock()
|
||||
if s.accResolver != nil {
|
||||
err = s.updateAccount(acc)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
if err != nil {
|
||||
// This error could mask expired, so just return expired here.
|
||||
if s.AccountResolver() != nil {
|
||||
if err := s.updateAccount(acc); err != nil {
|
||||
// This error could mask expired, so just return expired here.
|
||||
return nil, ErrAccountExpired
|
||||
}
|
||||
} else {
|
||||
return nil, ErrAccountExpired
|
||||
}
|
||||
}
|
||||
return acc, nil
|
||||
}
|
||||
// If we have a resolver see if it can fetch the account.
|
||||
if s.accResolver == nil {
|
||||
if s.AccountResolver() == nil {
|
||||
return nil, ErrMissingAccount
|
||||
}
|
||||
s.mu.Lock()
|
||||
acc, err := s.fetchAccount(name)
|
||||
s.mu.Unlock()
|
||||
return acc, err
|
||||
return s.fetchAccount(name)
|
||||
}
|
||||
|
||||
// LookupAccount is a public function to return the account structure
|
||||
@@ -879,7 +880,7 @@ func (s *Server) LookupAccount(name string) (*Account, error) {
|
||||
}
|
||||
|
||||
// This will fetch new claims and if found update the account with new claims.
|
||||
// Lock should be held upon entry.
|
||||
// Lock MUST NOT be held upon entry.
|
||||
func (s *Server) updateAccount(acc *Account) error {
|
||||
// TODO(dlc) - Make configurable
|
||||
if time.Since(acc.updated) < time.Second {
|
||||
@@ -894,6 +895,7 @@ func (s *Server) updateAccount(acc *Account) error {
|
||||
}
|
||||
|
||||
// updateAccountWithClaimJWT will check and apply the claim update.
|
||||
// Lock MUST NOT be held upon entry.
|
||||
func (s *Server) updateAccountWithClaimJWT(acc *Account, claimJWT string) error {
|
||||
if acc == nil {
|
||||
return ErrMissingAccount
|
||||
@@ -913,18 +915,16 @@ func (s *Server) updateAccountWithClaimJWT(acc *Account, claimJWT string) error
|
||||
}
|
||||
|
||||
// fetchRawAccountClaims will grab raw account claims iff we have a resolver.
|
||||
// Lock is held upon entry.
|
||||
// Lock is NOT held upon entry.
|
||||
func (s *Server) fetchRawAccountClaims(name string) (string, error) {
|
||||
accResolver := s.accResolver
|
||||
accResolver := s.AccountResolver()
|
||||
if accResolver == nil {
|
||||
return "", ErrNoAccountResolver
|
||||
}
|
||||
// Need to do actual Fetch without the lock.
|
||||
s.mu.Unlock()
|
||||
// Need to do actual Fetch
|
||||
start := time.Now()
|
||||
claimJWT, err := accResolver.Fetch(name)
|
||||
fetchTime := time.Since(start)
|
||||
s.mu.Lock()
|
||||
if fetchTime > time.Second {
|
||||
s.Warnf("Account [%s] fetch took %v", name, fetchTime)
|
||||
} else {
|
||||
@@ -938,7 +938,7 @@ func (s *Server) fetchRawAccountClaims(name string) (string, error) {
|
||||
}
|
||||
|
||||
// fetchAccountClaims will attempt to fetch new claims if a resolver is present.
|
||||
// Lock is held upon entry.
|
||||
// Lock is NOT held upon entry.
|
||||
func (s *Server) fetchAccountClaims(name string) (*jwt.AccountClaims, string, error) {
|
||||
claimJWT, err := s.fetchRawAccountClaims(name)
|
||||
if err != nil {
|
||||
@@ -962,7 +962,7 @@ func (s *Server) verifyAccountClaims(claimJWT string) (*jwt.AccountClaims, strin
|
||||
}
|
||||
|
||||
// This will fetch an account from a resolver if defined.
|
||||
// Lock should be held upon entry.
|
||||
// Lock is NOT held upon entry.
|
||||
func (s *Server) fetchAccount(name string) (*Account, error) {
|
||||
accClaims, claimJWT, err := s.fetchAccountClaims(name)
|
||||
if accClaims != nil {
|
||||
@@ -982,7 +982,9 @@ func (s *Server) fetchAccount(name string) (*Account, error) {
|
||||
}
|
||||
acc := s.buildInternalAccount(accClaims)
|
||||
acc.claimJWT = claimJWT
|
||||
s.mu.Lock()
|
||||
s.registerAccount(acc)
|
||||
s.mu.Unlock()
|
||||
return acc, nil
|
||||
}
|
||||
return nil, err
|
||||
|
||||
@@ -29,7 +29,9 @@ func TestSplitBufferSubOp(t *testing.T) {
|
||||
t.Fatalf("Error creating gateways: %v", err)
|
||||
}
|
||||
s := &Server{gacc: NewAccount(globalAccountName), gateway: gws}
|
||||
s.mu.Lock()
|
||||
s.registerAccount(s.gacc)
|
||||
s.mu.Unlock()
|
||||
c := &client{srv: s, acc: s.gacc, msubs: -1, mpay: -1, mcl: 1024, subs: make(map[string]*subscription), nc: cli}
|
||||
|
||||
subop := []byte("SUB foo 1\r\n")
|
||||
@@ -66,7 +68,9 @@ func TestSplitBufferSubOp(t *testing.T) {
|
||||
|
||||
func TestSplitBufferUnsubOp(t *testing.T) {
|
||||
s := &Server{gacc: NewAccount(globalAccountName), gateway: &srvGateway{}}
|
||||
s.mu.Lock()
|
||||
s.registerAccount(s.gacc)
|
||||
s.mu.Unlock()
|
||||
c := &client{srv: s, acc: s.gacc, msubs: -1, mpay: -1, mcl: 1024, subs: make(map[string]*subscription)}
|
||||
|
||||
subop := []byte("SUB foo 1024\r\n")
|
||||
|
||||
Reference in New Issue
Block a user