mirror of
https://github.com/gogrlx/nats-server.git
synced 2026-04-17 03:24:40 -07:00
Merge pull request #2048 from nats-io/jwt-remove
Enhance Jwt remove by allowing operator signing keys to sign the remove request. It is reasonable for a customer to expect clients of the removed account to be disconnected. Absent account cleanup, expire the account.
This commit is contained in:
@@ -3513,8 +3513,7 @@ func handleDeleteRequest(store *DirJWTStore, s *Server, msg []byte, reply string
|
||||
if sysAcc := s.SystemAccount(); sysAcc != nil {
|
||||
sysAccName = sysAcc.GetName()
|
||||
}
|
||||
// TODO Can allow keys (issuer) to delete accounts they issued and operator key to delete all accounts.
|
||||
// For now only operator is allowed to delete
|
||||
// Only operator and operator signing key are allowed to delete
|
||||
gk, err := jwt.DecodeGeneric(string(msg))
|
||||
if err == nil {
|
||||
subj = gk.Subject
|
||||
@@ -3522,10 +3521,8 @@ func handleDeleteRequest(store *DirJWTStore, s *Server, msg []byte, reply string
|
||||
err = fmt.Errorf("delete must be enabled in server config")
|
||||
} else if subj != gk.Issuer {
|
||||
err = fmt.Errorf("not self signed")
|
||||
} else if !s.isTrustedIssuer(gk.Issuer) {
|
||||
} else if _, ok := store.operator[gk.Issuer]; !ok {
|
||||
err = fmt.Errorf("not trusted")
|
||||
} else if store.operator != gk.Issuer {
|
||||
err = fmt.Errorf("needs to be the operator operator")
|
||||
} else if list, ok := gk.Data["accounts"]; !ok {
|
||||
err = fmt.Errorf("malformed request")
|
||||
} else if accIds, ok = list.([]interface{}); !ok {
|
||||
@@ -3560,21 +3557,28 @@ func handleDeleteRequest(store *DirJWTStore, s *Server, msg []byte, reply string
|
||||
respondToUpdate(s, reply, "", fmt.Sprintf("deleted %d accounts", passCnt), nil)
|
||||
} else {
|
||||
respondToUpdate(s, reply, "", fmt.Sprintf("deleted %d accounts, failed for %d", passCnt, len(errs)),
|
||||
errors.New(strings.Join(errs, "<\n")))
|
||||
errors.New(strings.Join(errs, "\n")))
|
||||
}
|
||||
}
|
||||
|
||||
func getOperator(s *Server) (string, bool, error) {
|
||||
func getOperatorKeys(s *Server) (string, map[string]struct{}, bool, error) {
|
||||
var op string
|
||||
var strict bool
|
||||
keys := make(map[string]struct{})
|
||||
if opts := s.getOpts(); opts != nil && len(opts.TrustedOperators) > 0 {
|
||||
op = opts.TrustedOperators[0].Subject
|
||||
strict = opts.TrustedOperators[0].StrictSigningKeyUsage
|
||||
if !strict {
|
||||
keys[opts.TrustedOperators[0].Subject] = struct{}{}
|
||||
}
|
||||
for _, key := range opts.TrustedOperators[0].SigningKeys {
|
||||
keys[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
if op == "" {
|
||||
return "", false, fmt.Errorf("no operator found")
|
||||
if len(keys) == 0 {
|
||||
return "", nil, false, fmt.Errorf("no operator key found")
|
||||
}
|
||||
return op, strict, nil
|
||||
return op, keys, strict, nil
|
||||
}
|
||||
|
||||
func claimValidate(claim *jwt.AccountClaims) error {
|
||||
@@ -3586,15 +3590,37 @@ func claimValidate(claim *jwt.AccountClaims) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeCb(s *Server, pubKey string) {
|
||||
v, ok := s.accounts.Load(pubKey)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
a := v.(*Account)
|
||||
s.Debugf("Disable account %s due to remove", pubKey)
|
||||
a.mu.Lock()
|
||||
// lock out new clients
|
||||
a.msubs = 0
|
||||
a.mpay = 0
|
||||
a.mconns = 0
|
||||
a.mleafs = 0
|
||||
a.updated = time.Now().UTC()
|
||||
a.mu.Unlock()
|
||||
// set the account to be expired and disconnect clients
|
||||
a.expiredTimeout()
|
||||
a.mu.Lock()
|
||||
a.clearExpirationTimer()
|
||||
a.mu.Unlock()
|
||||
}
|
||||
|
||||
func (dr *DirAccResolver) Start(s *Server) error {
|
||||
op, strict, err := getOperator(s)
|
||||
op, opKeys, strict, err := getOperatorKeys(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dr.Lock()
|
||||
defer dr.Unlock()
|
||||
dr.Server = s
|
||||
dr.operator = op
|
||||
dr.operator = opKeys
|
||||
dr.DirJWTStore.changed = func(pubKey string) {
|
||||
if v, ok := s.accounts.Load(pubKey); !ok {
|
||||
} else if theJwt, err := dr.LoadAcc(pubKey); err != nil {
|
||||
@@ -3603,6 +3629,9 @@ func (dr *DirAccResolver) Start(s *Server) error {
|
||||
s.Errorf("update resulted in error %v", err)
|
||||
}
|
||||
}
|
||||
dr.DirJWTStore.deleted = func(pubKey string) {
|
||||
removeCb(s, pubKey)
|
||||
}
|
||||
packRespIb := s.newRespInbox()
|
||||
for _, reqSub := range []string{accUpdateEventSubjOld, accUpdateEventSubjNew} {
|
||||
// subscribe to account jwt update requests
|
||||
@@ -3839,14 +3868,14 @@ func NewCacheDirAccResolver(path string, limit int64, ttl time.Duration, _ ...di
|
||||
}
|
||||
|
||||
func (dr *CacheDirAccResolver) Start(s *Server) error {
|
||||
op, strict, err := getOperator(s)
|
||||
op, opKeys, strict, err := getOperatorKeys(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dr.Lock()
|
||||
defer dr.Unlock()
|
||||
dr.Server = s
|
||||
dr.operator = op
|
||||
dr.operator = opKeys
|
||||
dr.DirJWTStore.changed = func(pubKey string) {
|
||||
if v, ok := s.accounts.Load(pubKey); !ok {
|
||||
} else if theJwt, err := dr.LoadAcc(pubKey); err != nil {
|
||||
@@ -3855,6 +3884,9 @@ func (dr *CacheDirAccResolver) Start(s *Server) error {
|
||||
s.Errorf("update resulted in error %v", err)
|
||||
}
|
||||
}
|
||||
dr.DirJWTStore.deleted = func(pubKey string) {
|
||||
removeCb(s, pubKey)
|
||||
}
|
||||
for _, reqSub := range []string{accUpdateEventSubjOld, accUpdateEventSubjNew} {
|
||||
// subscribe to account jwt update requests
|
||||
if _, err := s.sysSubscribe(fmt.Sprintf(reqSub, "*"), func(_ *subscription, _ *client, subj, resp string, msg []byte) {
|
||||
|
||||
@@ -79,9 +79,10 @@ type DirJWTStore struct {
|
||||
shard bool
|
||||
readonly bool
|
||||
deleteType deleteType
|
||||
operator string
|
||||
operator map[string]struct{}
|
||||
expiration *expirationTracker
|
||||
changed JWTChanged
|
||||
deleted JWTChanged
|
||||
}
|
||||
|
||||
func newDir(dirPath string, create bool) (string, error) {
|
||||
@@ -454,7 +455,7 @@ func (store *DirJWTStore) delete(publicKey string) error {
|
||||
return err
|
||||
}
|
||||
store.expiration.unTrack(publicKey)
|
||||
// TODO do cb so server can evict the account and associated clients
|
||||
store.deleted(publicKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -892,6 +892,10 @@ func TestRemove(t *testing.T) {
|
||||
require_Len(t, len(f), 1)
|
||||
}
|
||||
dirStore, err := NewExpiringDirJWTStore(dir, false, false, deleteType, 0, 10, true, 0, nil)
|
||||
delPubKey := ""
|
||||
dirStore.deleted = func(publicKey string) {
|
||||
delPubKey = publicKey
|
||||
}
|
||||
require_NoError(t, err)
|
||||
defer dirStore.Close()
|
||||
accountKey, err := nkeys.CreateAccount()
|
||||
@@ -901,6 +905,11 @@ func TestRemove(t *testing.T) {
|
||||
createTestAccount(t, dirStore, 0, accountKey)
|
||||
require_OneJWT()
|
||||
dirStore.delete(pubKey)
|
||||
if deleteType == NoDelete {
|
||||
require_True(t, delPubKey == "")
|
||||
} else {
|
||||
require_True(t, delPubKey == pubKey)
|
||||
}
|
||||
f, err := filepath.Glob(dir + string(os.PathSeparator) + "/*.jwt")
|
||||
require_NoError(t, err)
|
||||
require_Len(t, len(f), test.expected)
|
||||
|
||||
@@ -4438,13 +4438,21 @@ func TestJWTUserRevocation(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestJWTAccountOps(t *testing.T) {
|
||||
op, _ := nkeys.CreateOperator()
|
||||
opPk, _ := op.PublicKey()
|
||||
sk, _ := nkeys.CreateOperator()
|
||||
skPk, _ := sk.PublicKey()
|
||||
opClaim := jwt.NewOperatorClaims(opPk)
|
||||
opClaim.SigningKeys.Add(skPk)
|
||||
opJwt, err := opClaim.Encode(op)
|
||||
require_NoError(t, err)
|
||||
createAccountAndUser := func(pubKey, jwt1, creds1 *string) {
|
||||
t.Helper()
|
||||
kp, _ := nkeys.CreateAccount()
|
||||
*pubKey, _ = kp.PublicKey()
|
||||
claim := jwt.NewAccountClaims(*pubKey)
|
||||
var err error
|
||||
*jwt1, err = claim.Encode(oKp)
|
||||
*jwt1, err = claim.Encode(sk)
|
||||
require_NoError(t, err)
|
||||
|
||||
ukp, _ := nkeys.CreateUser()
|
||||
@@ -4457,12 +4465,12 @@ func TestJWTAccountOps(t *testing.T) {
|
||||
require_NoError(t, err)
|
||||
*creds1 = genCredsFile(t, ujwt1, seed)
|
||||
}
|
||||
generateRequest := func(accs []string) []byte {
|
||||
generateRequest := func(accs []string, kp nkeys.KeyPair) []byte {
|
||||
t.Helper()
|
||||
opk, _ := oKp.PublicKey()
|
||||
opk, _ := kp.PublicKey()
|
||||
c := jwt.NewGenericClaims(opk)
|
||||
c.Data["accounts"] = accs
|
||||
cJwt, err := c.Encode(oKp)
|
||||
cJwt, err := c.Encode(kp)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error %v", err)
|
||||
}
|
||||
@@ -4490,7 +4498,9 @@ func TestJWTAccountOps(t *testing.T) {
|
||||
%s
|
||||
dir: %s
|
||||
}
|
||||
`, ojwt, syspub, cfg, dirSrv)))
|
||||
`, opJwt, syspub, cfg, dirSrv)))
|
||||
disconnectErrChan := make(chan struct{}, 1)
|
||||
defer close(disconnectErrChan)
|
||||
defer os.Remove(conf)
|
||||
srv, _ := RunServerWithConfig(conf)
|
||||
defer srv.Shutdown()
|
||||
@@ -4503,26 +4513,40 @@ func TestJWTAccountOps(t *testing.T) {
|
||||
nc.Subscribe(fmt.Sprintf(accLookupReqSubj, apub), func(msg *nats.Msg) {
|
||||
msg.Respond([]byte(ajwt1))
|
||||
})
|
||||
// connect so there is a reason to cache the request
|
||||
ncA := natsConnect(t, srv.ClientURL(), nats.UserCredentials(aCreds1))
|
||||
ncA.Close()
|
||||
// connect so there is a reason to cache the request and so disconnect can be observed
|
||||
ncA := natsConnect(t, srv.ClientURL(), nats.UserCredentials(aCreds1), nats.NoReconnect(),
|
||||
nats.DisconnectErrHandler(func(conn *nats.Conn, err error) {
|
||||
if lErr := conn.LastError(); strings.Contains(lErr.Error(), "Account Authentication Expired") {
|
||||
disconnectErrChan <- struct{}{}
|
||||
}
|
||||
}))
|
||||
defer ncA.Close()
|
||||
resp, err := nc.Request(accListReqSubj, nil, time.Second)
|
||||
require_NoError(t, err)
|
||||
require_True(t, strings.Contains(string(resp.Data), apub))
|
||||
require_True(t, strings.Contains(string(resp.Data), syspub))
|
||||
// delete nothing
|
||||
resp, err = nc.Request(accDeleteReqSubj, generateRequest([]string{}), time.Second)
|
||||
resp, err = nc.Request(accDeleteReqSubj, generateRequest([]string{}, op), time.Second)
|
||||
require_NoError(t, err)
|
||||
require_True(t, strings.Contains(string(resp.Data), `"message":"deleted 0 accounts"`))
|
||||
// issue delete, twice to also delete a non existing account
|
||||
// also switch which key used to sign the request
|
||||
for i := 0; i < 2; i++ {
|
||||
resp, err = nc.Request(accDeleteReqSubj, generateRequest([]string{apub}), time.Second)
|
||||
resp, err = nc.Request(accDeleteReqSubj, generateRequest([]string{apub}, sk), time.Second)
|
||||
require_NoError(t, err)
|
||||
require_True(t, strings.Contains(string(resp.Data), `"message":"deleted 1 accounts"`))
|
||||
resp, err = nc.Request(accListReqSubj, nil, time.Second)
|
||||
require_False(t, strings.Contains(string(resp.Data), apub))
|
||||
require_True(t, strings.Contains(string(resp.Data), syspub))
|
||||
require_NoError(t, err)
|
||||
if i > 0 {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-disconnectErrChan:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Callback not executed")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user