diff --git a/server/accounts.go b/server/accounts.go index b4d054b5..2ab4f539 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -2857,14 +2857,13 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim } if len(a.signingKeys) != len(old.signingKeys) { signersChanged = true - } else { - for k, scope := range a.signingKeys { - if oldScope, ok := old.signingKeys[k]; !ok { - signersChanged = true - } else if !reflect.DeepEqual(scope, oldScope) { - signersChanged = true - alteredScope[k] = struct{}{} - } + } + for k, scope := range a.signingKeys { + if oldScope, ok := old.signingKeys[k]; !ok { + signersChanged = true + } else if !reflect.DeepEqual(scope, oldScope) { + signersChanged = true + alteredScope[k] = struct{}{} } } a.mu.Unlock() diff --git a/server/auth.go b/server/auth.go index b12fa165..f6f92405 100644 --- a/server/auth.go +++ b/server/auth.go @@ -567,7 +567,7 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) boo return false } else if scope != nil { if err := scope.ValidateScopedSigner(juc); err != nil { - c.Debugf("User JWT is not valid") + c.Debugf("User JWT is not valid: %v", err) return false } else if uSc, ok := scope.(*jwt.UserScope); !ok { c.Debugf("User JWT is not valid") diff --git a/server/jwt_test.go b/server/jwt_test.go index 40a71777..ce65f932 100644 --- a/server/jwt_test.go +++ b/server/jwt_test.go @@ -5035,6 +5035,7 @@ func TestJWScopedSigningKeys(t *testing.T) { url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) errChan := make(chan error, 1) + defer close(errChan) awaitError := func(expected bool) { t.Helper() select { @@ -5048,7 +5049,7 @@ func TestJWScopedSigningKeys(t *testing.T) { } } } - errHdlr := nats.DisconnectErrHandler(func(conn *nats.Conn, err error) { + errHdlr := nats.ErrorHandler(func(conn *nats.Conn, s *nats.Subscription, err error) { errChan <- err }) if updateJwt(t, url, sysCreds, sysJwt, 1) != 1 { @@ -5063,22 +5064,27 @@ func TestJWScopedSigningKeys(t *testing.T) { t.Run("regular-signing-key", func(t *testing.T) { nc := natsConnect(t, url, nats.UserCredentials(aNonScopedCreds), errHdlr) defer nc.Close() + nc.Flush() err := nc.Publish("denied", nil) require_NoError(t, err) + nc.Flush() awaitError(false) }) - t.Run("scoped-signing-key-1", func(t *testing.T) { + t.Run("scoped-signing-key-client-side", func(t *testing.T) { nc := natsConnect(t, url, nats.UserCredentials(aScopedCreds), errHdlr) defer nc.Close() + nc.Flush() err := nc.Publish("too-long", []byte("way.too.long.for.payload.limit")) - require_NoError(t, err) - awaitError(true) + require_Error(t, err) + require_True(t, strings.Contains(err.Error(), ErrMaxPayload.Error())) }) - t.Run("scoped-signing-key-2", func(t *testing.T) { + t.Run("scoped-signing-key-server-side", func(t *testing.T) { nc := natsConnect(t, url, nats.UserCredentials(aScopedCreds), errHdlr) defer nc.Close() + nc.Flush() err := nc.Publish("denied", nil) require_NoError(t, err) + nc.Flush() awaitError(true) }) t.Run("scoped-signing-key-reload", func(t *testing.T) { @@ -5087,6 +5093,11 @@ func TestJWScopedSigningKeys(t *testing.T) { msgChan := make(chan *nats.Msg, 2) defer close(msgChan) nc := natsConnect(t, url, nats.UserCredentials(aScopedCreds), errHdlr, + nats.DisconnectErrHandler(func(conn *nats.Conn, err error) { + if err != nil { + errChan <- err + } + }), nats.ReconnectHandler(func(conn *nats.Conn) { reconChan <- struct{}{} }), @@ -5094,6 +5105,7 @@ func TestJWScopedSigningKeys(t *testing.T) { defer nc.Close() _, err := nc.ChanSubscribe("denied", msgChan) require_NoError(t, err) + nc.Flush() err = nc.Publish("denied", nil) require_NoError(t, err) awaitError(true) @@ -5109,12 +5121,13 @@ func TestJWScopedSigningKeys(t *testing.T) { // disconnect triggered by update awaitError(true) <-reconChan - if err := nc.Publish("denied", []byte("way.too.long.for.old.payload.limit")); err != nil { - t.Fatalf("Expected no error %v", err) - } + nc.Flush() + err = nc.Publish("denied", []byte("way.too.long.for.old.payload.limit")) + require_NoError(t, err) awaitError(false) msg := <-msgChan require_Equal(t, string(msg.Data), "way.too.long.for.old.payload.limit") require_Len(t, len(msgChan), 0) }) + require_Len(t, len(errChan), 0) }