[added] disconnect of all clients and disable account on remove

Error sent to the client: Account Authentication Expired

Signed-off-by: Matthias Hanel <mh@synadia.com>
This commit is contained in:
Matthias Hanel
2021-03-30 02:24:02 -04:00
parent c3479d339e
commit 6ffe9adf97
4 changed files with 57 additions and 4 deletions

View File

@@ -3590,6 +3590,28 @@ func claimValidate(claim *jwt.AccountClaims) error {
return nil
}
func removeCb(s *Server, pubKey string) {
v, ok := s.accounts.Load(pubKey)
if !ok {
return
}
a := v.(*Account)
s.Debugf("Disable account %s due to remove", pubKey)
a.mu.Lock()
// lock out new clients
a.msubs = 0
a.mpay = 0
a.mconns = 0
a.mleafs = 0
a.updated = time.Now().UTC()
a.mu.Unlock()
// set the account to be expired and disconnect clients
a.expiredTimeout()
a.mu.Lock()
a.clearExpirationTimer()
a.mu.Unlock()
}
func (dr *DirAccResolver) Start(s *Server) error {
op, opKeys, strict, err := getOperatorKeys(s)
if err != nil {
@@ -3607,6 +3629,9 @@ func (dr *DirAccResolver) Start(s *Server) error {
s.Errorf("update resulted in error %v", err)
}
}
dr.DirJWTStore.deleted = func(pubKey string) {
removeCb(s, pubKey)
}
packRespIb := s.newRespInbox()
for _, reqSub := range []string{accUpdateEventSubjOld, accUpdateEventSubjNew} {
// subscribe to account jwt update requests
@@ -3859,6 +3884,9 @@ func (dr *CacheDirAccResolver) Start(s *Server) error {
s.Errorf("update resulted in error %v", err)
}
}
dr.DirJWTStore.deleted = func(pubKey string) {
removeCb(s, pubKey)
}
for _, reqSub := range []string{accUpdateEventSubjOld, accUpdateEventSubjNew} {
// subscribe to account jwt update requests
if _, err := s.sysSubscribe(fmt.Sprintf(reqSub, "*"), func(_ *subscription, _ *client, subj, resp string, msg []byte) {

View File

@@ -82,6 +82,7 @@ type DirJWTStore struct {
operator map[string]struct{}
expiration *expirationTracker
changed JWTChanged
deleted JWTChanged
}
func newDir(dirPath string, create bool) (string, error) {
@@ -454,7 +455,7 @@ func (store *DirJWTStore) delete(publicKey string) error {
return err
}
store.expiration.unTrack(publicKey)
// TODO do cb so server can evict the account and associated clients
store.deleted(publicKey)
return nil
}

View File

@@ -892,6 +892,10 @@ func TestRemove(t *testing.T) {
require_Len(t, len(f), 1)
}
dirStore, err := NewExpiringDirJWTStore(dir, false, false, deleteType, 0, 10, true, 0, nil)
delPubKey := ""
dirStore.deleted = func(publicKey string) {
delPubKey = publicKey
}
require_NoError(t, err)
defer dirStore.Close()
accountKey, err := nkeys.CreateAccount()
@@ -901,6 +905,11 @@ func TestRemove(t *testing.T) {
createTestAccount(t, dirStore, 0, accountKey)
require_OneJWT()
dirStore.delete(pubKey)
if deleteType == NoDelete {
require_True(t, delPubKey == "")
} else {
require_True(t, delPubKey == pubKey)
}
f, err := filepath.Glob(dir + string(os.PathSeparator) + "/*.jwt")
require_NoError(t, err)
require_Len(t, len(f), test.expected)

View File

@@ -4499,6 +4499,8 @@ func TestJWTAccountOps(t *testing.T) {
dir: %s
}
`, opJwt, syspub, cfg, dirSrv)))
disconnectErrChan := make(chan struct{}, 1)
defer close(disconnectErrChan)
defer os.Remove(conf)
srv, _ := RunServerWithConfig(conf)
defer srv.Shutdown()
@@ -4511,9 +4513,14 @@ func TestJWTAccountOps(t *testing.T) {
nc.Subscribe(fmt.Sprintf(accLookupReqSubj, apub), func(msg *nats.Msg) {
msg.Respond([]byte(ajwt1))
})
// connect so there is a reason to cache the request
ncA := natsConnect(t, srv.ClientURL(), nats.UserCredentials(aCreds1))
ncA.Close()
// connect so there is a reason to cache the request and so disconnect can be observed
ncA := natsConnect(t, srv.ClientURL(), nats.UserCredentials(aCreds1), nats.NoReconnect(),
nats.DisconnectErrHandler(func(conn *nats.Conn, err error) {
if lErr := conn.LastError(); strings.Contains(lErr.Error(), "Account Authentication Expired") {
disconnectErrChan <- struct{}{}
}
}))
defer ncA.Close()
resp, err := nc.Request(accListReqSubj, nil, time.Second)
require_NoError(t, err)
require_True(t, strings.Contains(string(resp.Data), apub))
@@ -4532,6 +4539,14 @@ func TestJWTAccountOps(t *testing.T) {
require_False(t, strings.Contains(string(resp.Data), apub))
require_True(t, strings.Contains(string(resp.Data), syspub))
require_NoError(t, err)
if i > 0 {
continue
}
select {
case <-disconnectErrChan:
case <-time.After(time.Second):
t.Fatal("Callback not executed")
}
}
})
}