diff --git a/server/accounts.go b/server/accounts.go index 9b464227..e329541e 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -2937,26 +2937,34 @@ func (dr *DirAccResolver) Start(s *Server) error { } } packRespIb := s.newRespInbox() - // subscribe to account jwt update requests - if _, err := s.sysSubscribe(fmt.Sprintf(accUpdateEventSubj, "*"), func(_ *subscription, _ *client, subj, resp string, msg []byte) { - tk := strings.Split(subj, tsep) - if len(tk) != accUpdateTokens { - return + for _, reqSub := range []string{accUpdateEventSubjOld, accUpdateEventSubjNew} { + // subscribe to account jwt update requests + if _, err := s.sysSubscribe(fmt.Sprintf(reqSub, "*"), func(_ *subscription, _ *client, subj, resp string, msg []byte) { + pubKey := "" + tk := strings.Split(subj, tsep) + if len(tk) == accUpdateTokensNew { + pubKey = tk[accReqAccIndex] + } else if len(tk) == accUpdateTokensOld { + pubKey = tk[accUpdateAccIdxOld] + } else { + s.Debugf("jwt update skipped due to bad subject %q", subj) + return + } + if claim, err := jwt.DecodeAccountClaims(string(msg)); err != nil { + respondToUpdate(s, resp, pubKey, "jwt update resulted in error", err) + } else if claim.Subject != pubKey { + err := errors.New("subject does not match jwt content") + respondToUpdate(s, resp, pubKey, "jwt update resulted in error", err) + } else if err := dr.save(pubKey, string(msg)); err != nil { + respondToUpdate(s, resp, pubKey, "jwt update resulted in error", err) + } else { + respondToUpdate(s, resp, pubKey, "jwt updated", nil) + } + }); err != nil { + return fmt.Errorf("error setting up update handling: %v", err) } - pubKey := tk[accUpdateAccIndex] - if claim, err := jwt.DecodeAccountClaims(string(msg)); err != nil { - respondToUpdate(s, resp, pubKey, "jwt update resulted in error", err) - } else if claim.Subject != pubKey { - err := errors.New("subject does not match jwt content") - respondToUpdate(s, resp, pubKey, "jwt update resulted in error", err) - } else if err := dr.save(pubKey, string(msg)); err != nil { - respondToUpdate(s, resp, pubKey, "jwt update resulted in error", err) - } else { - respondToUpdate(s, resp, pubKey, "jwt updated", nil) - } - }); err != nil { - return fmt.Errorf("error setting up update handling: %v", err) - } else if _, err := s.sysSubscribe(fmt.Sprintf(accLookupReqSubj, "*"), func(_ *subscription, _ *client, subj, reply string, msg []byte) { + } + if _, err := s.sysSubscribe(fmt.Sprintf(accLookupReqSubj, "*"), func(_ *subscription, _ *client, subj, reply string, msg []byte) { // respond to lookups with our version if reply == "" { return @@ -3131,27 +3139,34 @@ func (dr *CacheDirAccResolver) Start(s *Server) error { s.Errorf("update resulted in error %v", err) } } - // subscribe to account jwt update requests - if _, err := s.sysSubscribe(fmt.Sprintf(accUpdateEventSubj, "*"), func(_ *subscription, _ *client, subj, resp string, msg []byte) { - tk := strings.Split(subj, tsep) - if len(tk) != accUpdateTokens { - return + for _, reqSub := range []string{accUpdateEventSubjOld, accUpdateEventSubjNew} { + // subscribe to account jwt update requests + if _, err := s.sysSubscribe(fmt.Sprintf(reqSub, "*"), func(_ *subscription, _ *client, subj, resp string, msg []byte) { + pubKey := "" + tk := strings.Split(subj, tsep) + if len(tk) == accUpdateTokensNew { + pubKey = tk[accReqAccIndex] + } else if len(tk) == accUpdateTokensOld { + pubKey = tk[accUpdateAccIdxOld] + } else { + s.Debugf("jwt update cache skipped due to bad subject %q", subj) + return + } + if claim, err := jwt.DecodeAccountClaims(string(msg)); err != nil { + respondToUpdate(s, resp, pubKey, "jwt update cache resulted in error", err) + } else if claim.Subject != pubKey { + err := errors.New("subject does not match jwt content") + respondToUpdate(s, resp, pubKey, "jwt update cache resulted in error", err) + } else if _, ok := s.accounts.Load(pubKey); !ok { + respondToUpdate(s, resp, pubKey, "jwt update cache skipped", nil) + } else if err := dr.save(pubKey, string(msg)); err != nil { + respondToUpdate(s, resp, pubKey, "jwt update cache resulted in error", err) + } else { + respondToUpdate(s, resp, pubKey, "jwt updated cache", nil) + } + }); err != nil { + return fmt.Errorf("error setting up update handling: %v", err) } - pubKey := tk[accUpdateAccIndex] - if claim, err := jwt.DecodeAccountClaims(string(msg)); err != nil { - respondToUpdate(s, resp, pubKey, "jwt update cache resulted in error", err) - } else if claim.Subject != pubKey { - err := errors.New("subject does not match jwt content") - respondToUpdate(s, resp, pubKey, "jwt update cache resulted in error", err) - } else if _, ok := s.accounts.Load(pubKey); !ok { - respondToUpdate(s, resp, pubKey, "jwt update cache skipped", nil) - } else if err := dr.save(pubKey, string(msg)); err != nil { - respondToUpdate(s, resp, pubKey, "jwt update cache resulted in error", err) - } else { - respondToUpdate(s, resp, pubKey, "jwt updated cache", nil) - } - }); err != nil { - return fmt.Errorf("error setting up update handling: %v", err) } s.Noticef("Managing some jwt in exclusive directory %s", dr.directory) return nil diff --git a/server/events.go b/server/events.go index 7974161b..e8aace92 100644 --- a/server/events.go +++ b/server/events.go @@ -18,6 +18,7 @@ import ( "crypto/sha256" "encoding/base64" "encoding/json" + "errors" "fmt" "math/rand" "net/http" @@ -34,12 +35,15 @@ import ( const ( accLookupReqTokens = 6 accLookupReqSubj = "$SYS.REQ.ACCOUNT.%s.CLAIMS.LOOKUP" - accPackReqSubj = "$SYS.REQ.ACCOUNT.CLAIMS.PACK" + accPackReqSubj = "$SYS.REQ.CLAIMS.PACK" - connectEventSubj = "$SYS.ACCOUNT.%s.CONNECT" - disconnectEventSubj = "$SYS.ACCOUNT.%s.DISCONNECT" - accConnsReqSubj = "$SYS.REQ.ACCOUNT.%s.CONNS" - accUpdateEventSubj = "$SYS.ACCOUNT.%s.CLAIMS.UPDATE" + connectEventSubj = "$SYS.ACCOUNT.%s.CONNECT" + disconnectEventSubj = "$SYS.ACCOUNT.%s.DISCONNECT" + accConnsReqSubj = "$SYS.REQ.ACCOUNT.%s.CONNS" + // kept for backward compatibility when using http resolver + // this overlaps with the names for events but you'd have to have the operator private key in order to succeed. + accUpdateEventSubjOld = "$SYS.ACCOUNT.%s.CLAIMS.UPDATE" + accUpdateEventSubjNew = "$SYS.REQ.ACCOUNT.%s.CLAIMS.UPDATE" connsRespSubj = "$SYS._INBOX_.%s" accConnsEventSubjNew = "$SYS.ACCOUNT.%s.SERVER.CONNS" accConnsEventSubjOld = "$SYS.SERVER.ACCOUNT.%s.CONNS" // kept for backward compatibility @@ -62,8 +66,9 @@ const ( shutdownEventTokens = 4 serverSubjectIndex = 2 - accUpdateTokens = 5 - accUpdateAccIndex = 2 + accUpdateTokensNew = 6 + accUpdateTokensOld = 5 + accUpdateAccIdxOld = 2 accReqTokens = 5 accReqAccIndex = 3 @@ -602,9 +607,10 @@ func (s *Server) initEventTracking() { subscribeToUpdate = !s.accResolver.IsTrackingUpdate() } if subscribeToUpdate { - subject = fmt.Sprintf(accUpdateEventSubj, "*") - if _, err := s.sysSubscribe(subject, s.accountClaimUpdate); err != nil { - s.Errorf("Error setting up internal tracking: %v", err) + for _, sub := range []string{accUpdateEventSubjOld, accUpdateEventSubjNew} { + if _, err := s.sysSubscribe(fmt.Sprintf(sub, "*"), s.accountClaimUpdate); err != nil { + s.Errorf("Error setting up internal tracking: %v", err) + } } } // Listen for ping messages that will be sent to all servers for statsz. @@ -678,22 +684,31 @@ func (s *Server) addSystemAccountExports(sacc *Account) { } // accountClaimUpdate will receive claim updates for accounts. -func (s *Server) accountClaimUpdate(sub *subscription, _ *client, subject, reply string, msg []byte) { +func (s *Server) accountClaimUpdate(sub *subscription, _ *client, subject, resp string, msg []byte) { if !s.EventsEnabled() { return } + pubKey := "" toks := strings.Split(subject, tsep) - if len(toks) < accUpdateTokens { + if len(toks) == accUpdateTokensNew { + pubKey = toks[accReqAccIndex] + } else if len(toks) == accUpdateTokensOld { + pubKey = toks[accUpdateAccIdxOld] + } else { s.Debugf("Received account claims update on bad subject %q", subject) return } - pubKey := toks[accUpdateAccIndex] if claim, err := jwt.DecodeAccountClaims(string(msg)); err != nil { - s.Debugf("Received account claims update with bad jwt: %v", err) + respondToUpdate(s, resp, pubKey, "jwt update resulted in error", err) } else if claim.Subject != pubKey { - s.Debugf("Received account claims update where jwt content does not match subject") - } else if v, ok := s.accounts.Load(pubKey); ok { - s.updateAccountWithClaimJWT(v.(*Account), string(msg)) + err := errors.New("subject does not match jwt content") + respondToUpdate(s, resp, pubKey, "jwt update resulted in error", err) + } else if v, ok := s.accounts.Load(pubKey); !ok { + respondToUpdate(s, resp, pubKey, "jwt update skipped", nil) + } else if err := s.updateAccountWithClaimJWT(v.(*Account), string(msg)); err != nil { + respondToUpdate(s, resp, pubKey, "jwt update resulted in error", err) + } else { + respondToUpdate(s, resp, pubKey, "jwt updated", nil) } } diff --git a/server/events_test.go b/server/events_test.go index 01180acd..0606261e 100644 --- a/server/events_test.go +++ b/server/events_test.go @@ -1146,53 +1146,61 @@ func TestSystemAccountFromConfig(t *testing.T) { } func TestAccountClaimsUpdates(t *testing.T) { - s, opts := runTrustedServer(t) - defer s.Shutdown() + test := func(subj string) { + s, opts := runTrustedServer(t) + defer s.Shutdown() - sacc, sakp := createAccount(s) - s.setSystemAccount(sacc) + sacc, sakp := createAccount(s) + s.setSystemAccount(sacc) - // Let's create a normal account with limits we can update. - okp, _ := nkeys.FromSeed(oSeed) - akp, _ := nkeys.CreateAccount() - pub, _ := akp.PublicKey() - nac := jwt.NewAccountClaims(pub) - nac.Limits.Conn = 4 - ajwt, _ := nac.Encode(okp) + // Let's create a normal account with limits we can update. + okp, _ := nkeys.FromSeed(oSeed) + akp, _ := nkeys.CreateAccount() + pub, _ := akp.PublicKey() + nac := jwt.NewAccountClaims(pub) + nac.Limits.Conn = 4 + ajwt, _ := nac.Encode(okp) - addAccountToMemResolver(s, pub, ajwt) + addAccountToMemResolver(s, pub, ajwt) - acc, _ := s.LookupAccount(pub) - if acc.MaxActiveConnections() != 4 { - t.Fatalf("Expected to see a limit of 4 connections") - } - - // Simulate a systems publisher so we can do an account claims update. - url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) - nc, err := nats.Connect(url, createUserCreds(t, s, sakp)) - if err != nil { - t.Fatalf("Error on connect: %v", err) - } - defer nc.Close() - - // Update the account - nac = jwt.NewAccountClaims(pub) - nac.Limits.Conn = 8 - issAt := time.Now().Add(-30 * time.Second).Unix() - nac.IssuedAt = issAt - expires := time.Now().Add(2 * time.Second).Unix() - nac.Expires = expires - ajwt, _ = nac.Encode(okp) - - // Publish to the system update subject. - claimUpdateSubj := fmt.Sprintf(accUpdateEventSubj, pub) - nc.Publish(claimUpdateSubj, []byte(ajwt)) - nc.Flush() - - acc, _ = s.LookupAccount(pub) - if acc.MaxActiveConnections() != 8 { - t.Fatalf("Account was not updated") + acc, _ := s.LookupAccount(pub) + if acc.MaxActiveConnections() != 4 { + t.Fatalf("Expected to see a limit of 4 connections") + } + + // Simulate a systems publisher so we can do an account claims update. + url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + nc, err := nats.Connect(url, createUserCreds(t, s, sakp)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + // Update the account + nac = jwt.NewAccountClaims(pub) + nac.Limits.Conn = 8 + issAt := time.Now().Add(-30 * time.Second).Unix() + nac.IssuedAt = issAt + expires := time.Now().Add(2 * time.Second).Unix() + nac.Expires = expires + ajwt, _ = nac.Encode(okp) + + // Publish to the system update subject. + claimUpdateSubj := fmt.Sprintf(subj, pub) + nc.Publish(claimUpdateSubj, []byte(ajwt)) + nc.Flush() + + acc, _ = s.LookupAccount(pub) + if acc.MaxActiveConnections() != 8 { + t.Fatalf("Account was not updated") + } } + t.Run("new", func(t *testing.T) { + test(accUpdateEventSubjNew) + }) + t.Run("old", func(t *testing.T) { + test(accUpdateEventSubjOld) + }) } func TestAccountClaimsUpdatesWithServiceImports(t *testing.T) { @@ -1241,7 +1249,7 @@ func TestAccountClaimsUpdatesWithServiceImports(t *testing.T) { ajwt2, _ = nac2.Encode(okp) // Publish to the system update subject. - claimUpdateSubj := fmt.Sprintf(accUpdateEventSubj, pub2) + claimUpdateSubj := fmt.Sprintf(accUpdateEventSubjNew, pub2) nc.Publish(claimUpdateSubj, []byte(ajwt2)) } nc.Flush() @@ -1390,7 +1398,7 @@ func TestSystemAccountWithGateways(t *testing.T) { // If this tests fails with wrong number after 10 seconds we may have // added a new inititial subscription for the eventing system. - checkExpectedSubs(t, 27, sa) + checkExpectedSubs(t, 28, sa) // Create a client on B and see if we receive the event urlb := fmt.Sprintf("nats://%s:%d", ob.Host, ob.Port) diff --git a/server/jwt_test.go b/server/jwt_test.go index d9686846..8424095d 100644 --- a/server/jwt_test.go +++ b/server/jwt_test.go @@ -1904,7 +1904,7 @@ func TestAccountURLResolverFetchFailurePushReorder(t *testing.T) { // update expjwt2, this will correct the import issue sysc := natsConnect(t, sA.ClientURL(), createUserCreds(t, nil, syskp)) defer sysc.Close() - natsPub(t, sysc, fmt.Sprintf(accUpdateEventSubj, exppub), []byte(expjwt2)) + natsPub(t, sysc, fmt.Sprintf(accUpdateEventSubjNew, exppub), []byte(expjwt2)) sysc.Flush() // updating expjwt should cause this to pass checkSubInterest(t, sA, imppub, subj, 10*time.Second) @@ -1995,8 +1995,8 @@ func TestAccountURLResolverPermanentFetchFailure(t *testing.T) { sysc := natsConnect(t, sA.ClientURL(), createUserCreds(t, nil, syskp)) defer sysc.Close() // push accounts - natsPub(t, sysc, fmt.Sprintf(accUpdateEventSubj, imppub), []byte(impjwt)) - natsPub(t, sysc, fmt.Sprintf(accUpdateEventSubj, exppub), []byte(expjwt)) + natsPub(t, sysc, fmt.Sprintf(accUpdateEventSubjNew, imppub), []byte(impjwt)) + natsPub(t, sysc, fmt.Sprintf(accUpdateEventSubjNew, exppub), []byte(expjwt)) sysc.Flush() importErrCnt := 0 tmr := time.NewTimer(500 * time.Millisecond) @@ -3293,7 +3293,7 @@ func TestAccountNATSResolverFetch(t *testing.T) { sub := natsSubSync(t, c, resp) err := sub.AutoUnsubscribe(3) require_NoError(t, err) - require_NoError(t, c.PublishRequest(fmt.Sprintf(accUpdateEventSubj, pubKey), resp, []byte(jwt))) + require_NoError(t, c.PublishRequest(fmt.Sprintf(accUpdateEventSubjNew, pubKey), resp, []byte(jwt))) passCnt := 0 if require_NextMsg(sub) { passCnt++ @@ -3842,7 +3842,7 @@ func TestJWTJetStreamLimits(t *testing.T) { t.Helper() c := natsConnect(t, url, nats.UserCredentials(creds)) defer c.Close() - if msg, err := c.Request(fmt.Sprintf(accUpdateEventSubj, pubKey), []byte(jwt), time.Second); err != nil { + if msg, err := c.Request(fmt.Sprintf(accUpdateEventSubjNew, pubKey), []byte(jwt), time.Second); err != nil { t.Fatal("error not expected in this test", err) } else { content := make(map[string]interface{})