diff --git a/server/accounts.go b/server/accounts.go index 472141bf..b68059d0 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -3902,7 +3902,10 @@ func (dr *DirAccResolver) Start(s *Server) error { return fmt.Errorf("error setting up update handling: %v", err) } } - if _, err := s.sysSubscribe(accClaimsReqSubj, func(_ *subscription, _ *client, _ *Account, subj, resp string, msg []byte) { + if _, err := s.sysSubscribe(accClaimsReqSubj, func(_ *subscription, c *client, _ *Account, _, resp string, msg []byte) { + // As this is a raw message, we need to extract payload and only decode claims from it, + // in case request is sent with headers. + _, msg = c.msgParts(msg) if claim, err := jwt.DecodeAccountClaims(string(msg)); err != nil { respondToUpdate(s, resp, "n/a", "jwt update resulted in error", err) } else if claim.Issuer == op && strict { @@ -4190,7 +4193,10 @@ func (dr *CacheDirAccResolver) Start(s *Server) error { return fmt.Errorf("error setting up update handling: %v", err) } } - if _, err := s.sysSubscribe(accClaimsReqSubj, func(_ *subscription, _ *client, _ *Account, subj, resp string, msg []byte) { + if _, err := s.sysSubscribe(accClaimsReqSubj, func(_ *subscription, c *client, _ *Account, _, resp string, msg []byte) { + // As this is a raw message, we need to extract payload and only decode claims from it, + // in case request is sent with headers. + _, msg = c.msgParts(msg) if claim, err := jwt.DecodeAccountClaims(string(msg)); err != nil { respondToUpdate(s, resp, "n/a", "jwt update cache resulted in error", err) } else if claim.Issuer == op && strict { diff --git a/server/jwt_test.go b/server/jwt_test.go index 6467e9b3..10858924 100644 --- a/server/jwt_test.go +++ b/server/jwt_test.go @@ -6132,6 +6132,61 @@ func TestJWTAccountProtectedImport(t *testing.T) { }) } +// Headers are ignored in claims update, but passing them should not cause error. +func TestJWTClaimsUpdateWithHeaders(t *testing.T) { + skp, spub := createKey(t) + newUser(t, skp) + + sclaim := jwt.NewAccountClaims(spub) + encodeClaim(t, sclaim, spub) + + akp, apub := createKey(t) + newUser(t, akp) + claim := jwt.NewAccountClaims(apub) + jwtClaim := encodeClaim(t, claim, apub) + + dirSrv := t.TempDir() + + conf := createConfFile(t, []byte(fmt.Sprintf(` + listen: 127.0.0.1:-1 + operator: %s + system_account: %s + resolver: { + type: full + dir: '%s' + } + `, ojwt, spub, dirSrv))) + + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + + type zapi struct { + Server *ServerInfo + Data *Connz + Error *ApiError + } + + sc := natsConnect(t, s.ClientURL(), createUserCreds(t, s, skp)) + defer sc.Close() + // Pass claims update with headers. + msg := &nats.Msg{ + Subject: "$SYS.REQ.CLAIMS.UPDATE", + Data: []byte(jwtClaim), + Header: map[string][]string{"key": {"value"}}, + } + resp, err := sc.RequestMsg(msg, time.Second) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + var cz zapi + if err := json.Unmarshal(resp.Data, &cz); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if cz.Error != nil { + t.Fatalf("Unexpected error: %+v", cz.Error) + } +} + func TestJWTMappings(t *testing.T) { sysKp, syspub := createKey(t) sysJwt := encodeClaim(t, jwt.NewAccountClaims(syspub), syspub)