diff --git a/server/accounts.go b/server/accounts.go index b9606f4f..42e19389 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -14,8 +14,11 @@ package server import ( + "bytes" + "errors" "fmt" "io/ioutil" + "math" "math/rand" "net/http" "net/url" @@ -2581,16 +2584,50 @@ func buildInternalNkeyUser(uc *jwt.UserClaims, acc *Account) *NkeyUser { return nu } +const fetchTimeout = 2 * time.Second + // AccountResolver interface. This is to fetch Account JWTs by public nkeys type AccountResolver interface { Fetch(name string) (string, error) Store(name, jwt string) error + IsReadOnly() bool + Start(server *Server) error + IsTrackingUpdate() bool + Reload() error + Close() +} + +// Default implementations of IsReadOnly/Start so only need to be written when changed +type resolverDefaultsOpsImpl struct{} + +func (*resolverDefaultsOpsImpl) IsReadOnly() bool { + return true +} + +func (*resolverDefaultsOpsImpl) IsTrackingUpdate() bool { + return false +} + +func (*resolverDefaultsOpsImpl) Start(*Server) error { + return nil +} + +func (*resolverDefaultsOpsImpl) Reload() error { + return nil +} + +func (*resolverDefaultsOpsImpl) Close() { +} + +func (*resolverDefaultsOpsImpl) Store(_, _ string) error { + return fmt.Errorf("Store operation not supported for URL Resolver") } // MemAccResolver is a memory only resolver. // Mostly for testing. type MemAccResolver struct { sm sync.Map + resolverDefaultsOpsImpl } // Fetch will fetch the account jwt claims from the internal sync.Map. @@ -2607,10 +2644,15 @@ func (m *MemAccResolver) Store(name, jwt string) error { return nil } +func (ur *MemAccResolver) IsReadOnly() bool { + return false +} + // URLAccResolver implements an http fetcher. type URLAccResolver struct { url string c *http.Client + resolverDefaultsOpsImpl } // NewURLAccResolver returns a new resolver for the given base URL. @@ -2626,7 +2668,7 @@ func NewURLAccResolver(url string) (*URLAccResolver, error) { } ur := &URLAccResolver{ url: url, - c: &http.Client{Timeout: 2 * time.Second, Transport: tr}, + c: &http.Client{Timeout: fetchTimeout, Transport: tr}, } return ur, nil } @@ -2651,7 +2693,277 @@ func (ur *URLAccResolver) Fetch(name string) (string, error) { return string(body), nil } -// Store is not implemented for URL Resolver. -func (ur *URLAccResolver) Store(name, jwt string) error { - return fmt.Errorf("Store operation not supported for URL Resolver") +// Resolver based on nats for synchronization and backing directory for storage. +type DirAccResolver struct { + *DirJWTStore + syncInterval time.Duration +} + +func (dr *DirAccResolver) IsTrackingUpdate() bool { + return true +} + +func respondToUpdate(s *Server, respSubj string, acc string, message string, err error) { + if err == nil { + s.Debugf("%s - %s", message, acc) + } else { + s.Errorf("%s - %s - %s", message, acc, err) + } + if respSubj == "" { + return + } + server := &ServerInfo{} + response := map[string]interface{}{"server": server} + if err == nil { + response["data"] = map[string]interface{}{ + "code": http.StatusOK, + "account": acc, + "message": message, + } + } else { + response["error"] = map[string]interface{}{ + "code": http.StatusInternalServerError, + "account": acc, + "description": fmt.Sprintf("%s - %v", message, err), + } + } + s.sendInternalMsgLocked(respSubj, _EMPTY_, server, response) +} + +func (dr *DirAccResolver) Start(s *Server) error { + dr.Lock() + defer dr.Unlock() + dr.DirJWTStore.changed = func(pubKey string) { + if v, ok := s.accounts.Load(pubKey); !ok { + } else if jwt, err := dr.LoadAcc(pubKey); err != nil { + s.Errorf("update got error on load: %v", err) + } else if err := s.updateAccountWithClaimJWT(v.(*Account), jwt); err != nil { + s.Errorf("update resulted in error %v", err) + } + } + const accountPackRequest = "$SYS.ACCOUNT.CLAIMS.PACK" + const accountLookupRequest = "$SYS.ACCOUNT.*.CLAIMS.LOOKUP" + packRespIb := s.newRespInbox() + // subscribe to account jwt update requests + if _, err := s.sysSubscribe(fmt.Sprintf(accUpdateEventSubj, "*"), func(_ *subscription, _ *client, subj, resp string, msg []byte) { + tk := strings.Split(subj, tsep) + if len(tk) != accUpdateTokens { + return + } + pubKey := tk[accUpdateAccIndex] + if claim, err := jwt.DecodeAccountClaims(string(msg)); err != nil { + respondToUpdate(s, resp, pubKey, "jwt update resulted in error", err) + } else if claim.Subject != pubKey { + err := errors.New("subject does not match jwt content") + respondToUpdate(s, resp, pubKey, "jwt update resulted in error", err) + } else if err := dr.save(pubKey, string(msg)); err != nil { + respondToUpdate(s, resp, pubKey, "jwt update resulted in error", err) + } else { + respondToUpdate(s, resp, pubKey, "jwt updated", nil) + } + }); err != nil { + return fmt.Errorf("error setting up update handling: %v", err) + } else if _, err := s.sysSubscribe(accountLookupRequest, func(_ *subscription, _ *client, subj, reply string, msg []byte) { + // respond to lookups with our version + if reply == "" { + return + } + tk := strings.Split(subj, tsep) + if len(tk) != accUpdateTokens { + return + } + if theJWT, err := dr.DirJWTStore.LoadAcc(tk[accUpdateAccIndex]); err != nil { + s.Errorf("Merging resulted in error: %v", err) + } else { + s.sendInternalMsgLocked(reply, "", nil, []byte(theJWT)) + } + }); err != nil { + return fmt.Errorf("error setting up lookup request handling: %v", err) + } else if _, err = s.sysSubscribeQ(accountPackRequest, "responder", + // respond to pack requests with one or more pack messages + // an empty message signifies the end of the response responder + func(_ *subscription, _ *client, _, reply string, theirHash []byte) { + if reply == "" { + return + } + ourHash := dr.DirJWTStore.Hash() + if bytes.Equal(theirHash, ourHash[:]) { + s.sendInternalMsgLocked(reply, "", nil, []byte{}) + s.Debugf("pack request matches hash %x", ourHash[:]) + } else if err := dr.DirJWTStore.PackWalk(1, func(partialPackMsg string) { + s.sendInternalMsgLocked(reply, "", nil, []byte(partialPackMsg)) + }); err != nil { + // let them timeout + s.Errorf("pack request error: %v", err) + } else { + s.Debugf("pack request hash %x - finished responding with hash %x") + s.sendInternalMsgLocked(reply, "", nil, []byte{}) + } + }); err != nil { + return fmt.Errorf("error setting up pack request handling: %v", err) + } else if _, err = s.sysSubscribe(packRespIb, func(_ *subscription, _ *client, _, _ string, msg []byte) { + // embed pack responses into store + hash := dr.DirJWTStore.Hash() + if len(msg) == 0 { // end of response stream + s.Debugf("Merging Finished and resulting in: %x", dr.DirJWTStore.Hash()) + return + } else if err := dr.DirJWTStore.Merge(string(msg)); err != nil { + s.Errorf("Merging resulted in error: %v", err) + } else { + s.Debugf("Merging succeeded and changed %x to %x", hash, dr.DirJWTStore.Hash()) + } + }); err != nil { + return fmt.Errorf("error setting up pack response handling: %v", err) + } + // periodically send out pack message + quit := s.quitCh + s.startGoRoutine(func() { + defer s.grWG.Done() + ticker := time.NewTicker(dr.syncInterval) + for { + select { + case <-quit: + ticker.Stop() + return + case <-ticker.C: + } + ourHash := dr.DirJWTStore.Hash() + s.Debugf("Checking store state: %x", ourHash) + s.sendInternalMsgLocked(accountPackRequest, packRespIb, nil, ourHash[:]) + } + }) + s.Noticef("Managing all jwt in exclusive directory %s", dr.directory) + return nil +} + +func (dr *DirAccResolver) Fetch(name string) (string, error) { + return dr.LoadAcc(name) +} + +func (dr *DirAccResolver) Store(name, jwt string) error { + return dr.saveIfNewer(name, jwt) +} + +func NewDirAccResolver(path string, limit int64, syncInterval time.Duration) (*DirAccResolver, error) { + if limit == 0 { + limit = math.MaxInt64 + } + if syncInterval <= 0 { + syncInterval = time.Minute + } + store, err := NewExpiringDirJWTStore(path, false, true, 0, limit, false, 0, nil) + if err != nil { + return nil, err + } + return &DirAccResolver{store, syncInterval}, nil +} + +// Caching resolver using nats for lookups and making use of a directory for storage +type CacheDirAccResolver struct { + DirAccResolver + *Server + ttl time.Duration +} + +func (dr *CacheDirAccResolver) Fetch(name string) (string, error) { + if theJWT, _ := dr.LoadAcc(name); theJWT != "" { + return theJWT, nil + } + // lookup from other server + s := dr.Server + if s == nil { + return "", ErrNoAccountResolver + } + respC := make(chan []byte, 1) + accountLookupRequest := fmt.Sprintf("$SYS.ACCOUNT.%s.CLAIMS.LOOKUP", name) + s.mu.Lock() + replySubj := s.newRespInbox() + if s.sys == nil || s.sys.replies == nil { + s.mu.Unlock() + return "", fmt.Errorf("eventing shut down") + } + replies := s.sys.replies + // Store our handler. + replies[replySubj] = func(sub *subscription, _ *client, subject, _ string, msg []byte) { + clone := make([]byte, len(msg)) + copy(clone, msg) + s.mu.Lock() + if _, ok := replies[replySubj]; ok { + respC <- clone // only send if there is still interest + } + s.mu.Unlock() + } + s.sendInternalMsg(accountLookupRequest, replySubj, nil, []byte{}) + quit := s.quitCh + s.mu.Unlock() + var err error + var theJWT string + select { + case <-quit: + err = errors.New("fetching jwt failed due to shutdown") + case <-time.After(fetchTimeout): + err = errors.New("fetching jwt timed out") + case m := <-respC: + if err = dr.Store(name, string(m)); err == nil { + theJWT = string(m) + } + } + s.mu.Lock() + delete(replies, replySubj) + s.mu.Unlock() + close(respC) + return theJWT, err +} + +func NewCacheDirAccResolver(path string, limit int64, ttl time.Duration) (*CacheDirAccResolver, error) { + if limit <= 0 { + limit = 1_000 + } + store, err := NewExpiringDirJWTStore(path, false, true, 0, limit, true, ttl, nil) + if err != nil { + return nil, err + } + return &CacheDirAccResolver{DirAccResolver{store, 0}, nil, ttl}, nil +} + +func (dr *CacheDirAccResolver) Start(s *Server) error { + dr.Lock() + defer dr.Unlock() + dr.Server = s + dr.DirJWTStore.changed = func(pubKey string) { + if v, ok := s.accounts.Load(pubKey); !ok { + } else if jwt, err := dr.LoadAcc(pubKey); err != nil { + s.Errorf("update got error on load: %v", err) + } else if err := s.updateAccountWithClaimJWT(v.(*Account), jwt); err != nil { + s.Errorf("update resulted in error %v", err) + } + } + // subscribe to account jwt update requests + if _, err := s.sysSubscribe(fmt.Sprintf(accUpdateEventSubj, "*"), func(_ *subscription, _ *client, subj, resp string, msg []byte) { + tk := strings.Split(subj, tsep) + if len(tk) != accUpdateTokens { + return + } + pubKey := tk[accUpdateAccIndex] + if claim, err := jwt.DecodeAccountClaims(string(msg)); err != nil { + respondToUpdate(s, resp, pubKey, "jwt update cache resulted in error", err) + } else if claim.Subject != pubKey { + err := errors.New("subject does not match jwt content") + respondToUpdate(s, resp, pubKey, "jwt update cache resulted in error", err) + } else if _, ok := s.accounts.Load(pubKey); !ok { + respondToUpdate(s, resp, pubKey, "jwt update cache skipped", nil) + } else if err := dr.save(pubKey, string(msg)); err != nil { + respondToUpdate(s, resp, pubKey, "jwt update cache resulted in error", err) + } else { + respondToUpdate(s, resp, pubKey, "jwt updated cache", nil) + } + }); err != nil { + return fmt.Errorf("error setting up update handling: %v", err) + } + s.Noticef("Managing some jwt in exclusive directory %s", dr.directory) + return nil +} + +func (dr *CacheDirAccResolver) Reload() error { + return dr.DirAccResolver.Reload() } diff --git a/server/dirstore.go b/server/dirstore.go new file mode 100644 index 00000000..a1e0bc5f --- /dev/null +++ b/server/dirstore.go @@ -0,0 +1,670 @@ +/* + * Copyright 2020 The NATS Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package server + +import ( + "bytes" + "container/heap" + "container/list" + "crypto/sha256" + "errors" + "fmt" + "io/ioutil" + "math" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/nats-io/jwt/v2" // only used to decode, not for storage +) + +const ( + fileExtension = ".jwt" +) + +// validatePathExists checks that the provided path exists and is a dir if requested +func validatePathExists(path string, dir bool) (string, error) { + if path == "" { + return "", errors.New("path is not specified") + } + + abs, err := filepath.Abs(path) + if err != nil { + return "", fmt.Errorf("error parsing path [%s]: %v", abs, err) + } + + var finfo os.FileInfo + if finfo, err = os.Stat(abs); os.IsNotExist(err) { + return "", fmt.Errorf("the path [%s] doesn't exist", abs) + } + + mode := finfo.Mode() + if dir && mode.IsRegular() { + return "", fmt.Errorf("the path [%s] is not a directory", abs) + } + + if !dir && mode.IsDir() { + return "", fmt.Errorf("the path [%s] is not a file", abs) + } + + return abs, nil +} + +// ValidateDirPath checks that the provided path exists and is a dir +func validateDirPath(path string) (string, error) { + return validatePathExists(path, true) +} + +// JWTChanged functions are called when the store file watcher notices a JWT changed +type JWTChanged func(publicKey string) + +// DirJWTStore implements the JWT Store interface, keeping JWTs in an optionally sharded +// directory structure +type DirJWTStore struct { + sync.Mutex + directory string + shard bool + readonly bool + expiration *expirationTracker + changed JWTChanged +} + +func newDir(dirPath string, create bool) (string, error) { + fullPath, err := validateDirPath(dirPath) + if err != nil { + if !create { + return "", err + } + if err = os.MkdirAll(dirPath, 0755); err != nil { + return "", err + } + if fullPath, err = validateDirPath(dirPath); err != nil { + return "", err + } + } + return fullPath, nil +} + +// Creates a directory based jwt store. +// Reads files only, does NOT watch directories and files. +func NewImmutableDirJWTStore(dirPath string, shard bool) (*DirJWTStore, error) { + theStore, err := NewDirJWTStore(dirPath, shard, false) + if err != nil { + return nil, err + } + theStore.readonly = true + return theStore, nil +} + +// Creates a directory based jwt store. +// Operates on files only, does NOT watch directories and files. +func NewDirJWTStore(dirPath string, shard bool, create bool) (*DirJWTStore, error) { + fullPath, err := newDir(dirPath, create) + if err != nil { + return nil, err + } + theStore := &DirJWTStore{ + directory: fullPath, + shard: shard, + } + return theStore, nil +} + +// Creates a directory based jwt store. +// +// When ttl is set deletion of file is based on it and not on the jwt expiration +// To completely disable expiration (including expiration in jwt) set ttl to max duration time.Duration(math.MaxInt64) +// +// limit defines how many files are allowed at any given time. Set to math.MaxInt64 to disable. +// evictOnLimit determines the behavior once limit is reached. +// true - Evict based on lru strategy +// false - return an error +func NewExpiringDirJWTStore(dirPath string, shard bool, create bool, expireCheck time.Duration, limit int64, + evictOnLimit bool, ttl time.Duration, changeNotification JWTChanged) (*DirJWTStore, error) { + fullPath, err := newDir(dirPath, create) + if err != nil { + return nil, err + } + theStore := &DirJWTStore{ + directory: fullPath, + shard: shard, + changed: changeNotification, + } + if expireCheck <= 0 { + if ttl != 0 { + expireCheck = ttl / 2 + } + if expireCheck == 0 || expireCheck > time.Minute { + expireCheck = time.Minute + } + } + if limit <= 0 { + limit = math.MaxInt64 + } + theStore.startExpiring(expireCheck, limit, evictOnLimit, ttl) + theStore.Lock() + err = filepath.Walk(dirPath, func(path string, info os.FileInfo, err error) error { + if strings.HasSuffix(path, fileExtension) { + if theJwt, err := ioutil.ReadFile(path); err == nil { + hash := sha256.Sum256(theJwt) + _, file := filepath.Split(path) + theStore.expiration.track(strings.TrimSuffix(file, fileExtension), &hash, string(theJwt)) + } + } + return nil + }) + theStore.Unlock() + if err != nil { + theStore.Close() + return nil, err + } + return theStore, err +} + +func (store *DirJWTStore) IsReadOnly() bool { + return store.readonly +} + +func (store *DirJWTStore) LoadAcc(publicKey string) (string, error) { + return store.load(publicKey) +} + +func (store *DirJWTStore) SaveAcc(publicKey string, theJWT string) error { + return store.save(publicKey, theJWT) +} + +func (store *DirJWTStore) LoadAct(hash string) (string, error) { + return store.load(hash) +} + +func (store *DirJWTStore) SaveAct(hash string, theJWT string) error { + return store.save(hash, theJWT) +} + +func (store *DirJWTStore) Close() { + store.Lock() + defer store.Unlock() + if store.expiration != nil { + store.expiration.close() + store.expiration = nil + } +} + +// Pack up to maxJWTs into a package +func (store *DirJWTStore) Pack(maxJWTs int) (string, error) { + count := 0 + var pack []string + if maxJWTs > 0 { + pack = make([]string, 0, maxJWTs) + } else { + pack = []string{} + } + store.Lock() + err := filepath.Walk(store.directory, func(path string, info os.FileInfo, err error) error { + if !info.IsDir() && strings.HasSuffix(path, fileExtension) { // this is a JWT + if count == maxJWTs { // won't match negative + return nil + } + pubKey := strings.TrimSuffix(filepath.Base(path), fileExtension) + if store.expiration != nil { + if _, ok := store.expiration.idx[pubKey]; !ok { + return nil // only include indexed files + } + } + jwtBytes, err := ioutil.ReadFile(path) + if err != nil { + return err + } + if store.expiration != nil { + claim, err := jwt.DecodeGeneric(string(jwtBytes)) + if err == nil && claim.Expires > 0 && claim.Expires < time.Now().Unix() { + return nil + } + } + pack = append(pack, fmt.Sprintf("%s|%s", pubKey, string(jwtBytes))) + count++ + } + return nil + }) + store.Unlock() + if err != nil { + return "", err + } else { + return strings.Join(pack, "\n"), nil + } +} + +// Pack up to maxJWTs into a message and invoke callback with it +func (store *DirJWTStore) PackWalk(maxJWTs int, cb func(partialPackMsg string)) error { + if maxJWTs <= 0 || cb == nil { + return errors.New("bad arguments to PackWalk") + } + var packMsg []string + store.Lock() + dir := store.directory + exp := store.expiration + store.Unlock() + err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if !info.IsDir() && strings.HasSuffix(path, fileExtension) { // this is a JWT + pubKey := strings.TrimSuffix(filepath.Base(path), fileExtension) + store.Lock() + if exp != nil { + if _, ok := store.expiration.idx[pubKey]; !ok { + store.Unlock() + return nil // only include indexed files + } + } + store.Unlock() + jwtBytes, err := ioutil.ReadFile(path) + if err != nil { + return err + } + if exp != nil { + claim, err := jwt.DecodeGeneric(string(jwtBytes)) + if err == nil && claim.Expires > 0 && claim.Expires < time.Now().Unix() { + return nil + } + } + packMsg = append(packMsg, fmt.Sprintf("%s|%s", pubKey, string(jwtBytes))) + if len(packMsg) == maxJWTs { // won't match negative + cb(strings.Join(packMsg, "\n")) + packMsg = nil + } + } + return nil + }) + if packMsg != nil { + cb(strings.Join(packMsg, "\n")) + } + return err +} + +// Merge takes the JWTs from package and adds them to the store +// Merge is destructive in the sense that it doesn't check if the JWT +// is newer or anything like that. +func (store *DirJWTStore) Merge(pack string) error { + newJWTs := strings.Split(pack, "\n") + for _, line := range newJWTs { + if line == "" { // ignore blank lines + continue + } + split := strings.Split(line, "|") + if len(split) != 2 { + return fmt.Errorf("line in package didn't contain 2 entries: %q", line) + } + pubKey := split[0] + if err := store.saveIfNewer(pubKey, split[1]); err != nil { + return err + } + } + return nil +} + +func (store *DirJWTStore) Reload() error { + store.Lock() + exp := store.expiration + if exp == nil || store.readonly { + store.Unlock() + return nil + } + idx := exp.idx + changed := store.changed + isCache := store.expiration.evictOnLimit + // clear out indexing data structures + exp.heap = make([]*jwtItem, 0, len(exp.heap)) + exp.idx = make(map[string]*list.Element) + exp.lru = list.New() + exp.hash = [sha256.Size]byte{} + store.Unlock() + return filepath.Walk(store.directory, func(path string, info os.FileInfo, err error) error { + if strings.HasSuffix(path, fileExtension) { + if theJwt, err := ioutil.ReadFile(path); err == nil { + hash := sha256.Sum256(theJwt) + _, file := filepath.Split(path) + pkey := strings.TrimSuffix(file, fileExtension) + notify := isCache // for cache, issue cb even when file not present (may have been evicted) + if i, ok := idx[pkey]; ok { + notify = !bytes.Equal(i.Value.(*jwtItem).hash[:], hash[:]) + } + store.Lock() + exp.track(pkey, &hash, string(theJwt)) + store.Unlock() + if notify && changed != nil { + changed(pkey) + } + } + } + return nil + }) +} + +func (store *DirJWTStore) pathForKey(publicKey string) string { + if len(publicKey) < 2 { + return "" + } + fileName := fmt.Sprintf("%s%s", publicKey, fileExtension) + if store.shard { + last := publicKey[len(publicKey)-2:] + return filepath.Join(store.directory, last, fileName) + } else { + return filepath.Join(store.directory, fileName) + } +} + +// Load checks the memory store and returns the matching JWT or an error +// Assumes lock is NOT held +func (store *DirJWTStore) load(publicKey string) (string, error) { + store.Lock() + defer store.Unlock() + if path := store.pathForKey(publicKey); path == "" { + return "", fmt.Errorf("invalid public key") + } else if data, err := ioutil.ReadFile(path); err != nil { + return "", err + } else { + if store.expiration != nil { + store.expiration.updateTrack(publicKey) + } + return string(data), nil + } +} + +// write that keeps hash of all jwt in sync +// Assumes the lock is held. Does return true or an error never both. +func (store *DirJWTStore) write(path string, publicKey string, theJWT string) (bool, error) { + var newHash *[sha256.Size]byte + if store.expiration != nil { + h := sha256.Sum256([]byte(theJWT)) + newHash = &h + if v, ok := store.expiration.idx[publicKey]; ok { + store.expiration.updateTrack(publicKey) + // this write is an update, move to back + it := v.Value.(*jwtItem) + oldHash := it.hash[:] + if bytes.Equal(oldHash, newHash[:]) { + return false, nil + } + } else if int64(store.expiration.Len()) >= store.expiration.limit { + if !store.expiration.evictOnLimit { + return false, errors.New("jwt store is full") + } + // this write is an add, pick the least recently used value for removal + i := store.expiration.lru.Front().Value.(*jwtItem) + if err := os.Remove(store.pathForKey(i.publicKey)); err != nil { + return false, err + } else { + store.expiration.unTrack(i.publicKey) + } + } + } + if err := ioutil.WriteFile(path, []byte(theJWT), 0644); err != nil { + return false, err + } else if store.expiration != nil { + store.expiration.track(publicKey, newHash, theJWT) + } + return true, nil +} + +// Save puts the JWT in a map by public key and performs update callbacks +// Assumes lock is NOT held +func (store *DirJWTStore) save(publicKey string, theJWT string) error { + if store.readonly { + return fmt.Errorf("store is read-only") + } + store.Lock() + path := store.pathForKey(publicKey) + if path == "" { + store.Unlock() + return fmt.Errorf("invalid public key") + } + dirPath := filepath.Dir(path) + if _, err := validateDirPath(dirPath); err != nil { + if err := os.MkdirAll(dirPath, 0755); err != nil { + store.Unlock() + return err + } + } + changed, err := store.write(path, publicKey, theJWT) + cb := store.changed + store.Unlock() + if changed && cb != nil { + cb(publicKey) + } + return err +} + +// Assumes the lock is NOT held, and only updates if the jwt is new, or the one on disk is older +// returns true when the jwt changed +func (store *DirJWTStore) saveIfNewer(publicKey string, theJWT string) error { + if store.readonly { + return fmt.Errorf("store is read-only") + } + path := store.pathForKey(publicKey) + if path == "" { + return fmt.Errorf("invalid public key") + } + dirPath := filepath.Dir(path) + if _, err := validateDirPath(dirPath); err != nil { + if err := os.MkdirAll(dirPath, 0755); err != nil { + return err + } + } + if _, err := os.Stat(path); err == nil { + if newJWT, err := jwt.DecodeGeneric(theJWT); err != nil { + // skip if it can't be decoded + } else if existing, err := ioutil.ReadFile(path); err != nil { + return err + } else if existingJWT, err := jwt.DecodeGeneric(string(existing)); err != nil { + // skip if it can't be decoded + } else if existingJWT.ID == newJWT.ID { + return nil + } else if existingJWT.IssuedAt > newJWT.IssuedAt { + return nil + } + } + store.Lock() + cb := store.changed + changed, err := store.write(path, publicKey, theJWT) + store.Unlock() + if err != nil { + return err + } else if changed && cb != nil { + cb(publicKey) + } + return nil +} + +func xorAssign(lVal *[sha256.Size]byte, rVal [sha256.Size]byte) { + for i := range rVal { + (*lVal)[i] ^= rVal[i] + } +} + +// returns a hash representing all indexed jwt +func (store *DirJWTStore) Hash() [sha256.Size]byte { + store.Lock() + defer store.Unlock() + if store.expiration == nil { + return [sha256.Size]byte{} + } else { + return store.expiration.hash + } +} + +// An jwtItem is something managed by the priority queue +type jwtItem struct { + index int + publicKey string + expiration int64 // consists of unix time of expiration (ttl when set or jwt expiration) in seconds + hash [sha256.Size]byte +} + +// A expirationTracker implements heap.Interface and holds Items. +type expirationTracker struct { + heap []*jwtItem // sorted by jwtItem.expiration + idx map[string]*list.Element + lru *list.List // keep which jwt are least used + limit int64 // limit how many jwt are being tracked + evictOnLimit bool // when limit is hit, error or evict using lru + ttl time.Duration + hash [sha256.Size]byte // xor of all jwtItem.hash in idx + quit chan struct{} + wg sync.WaitGroup +} + +func (pq *expirationTracker) Len() int { return len(pq.heap) } + +func (q *expirationTracker) Less(i, j int) bool { + pq := q.heap + return pq[i].expiration < pq[j].expiration +} + +func (q *expirationTracker) Swap(i, j int) { + pq := q.heap + pq[i], pq[j] = pq[j], pq[i] + pq[i].index = i + pq[j].index = j +} + +func (q *expirationTracker) Push(x interface{}) { + n := len(q.heap) + item := x.(*jwtItem) + item.index = n + q.heap = append(q.heap, item) + q.idx[item.publicKey] = q.lru.PushBack(item) +} + +func (q *expirationTracker) Pop() interface{} { + old := q.heap + n := len(old) + item := old[n-1] + old[n-1] = nil // avoid memory leak + item.index = -1 + q.heap = old[0 : n-1] + q.lru.Remove(q.idx[item.publicKey]) + delete(q.idx, item.publicKey) + return item +} + +func (pq *expirationTracker) updateTrack(publicKey string) { + if e, ok := pq.idx[publicKey]; ok { + i := e.Value.(*jwtItem) + if pq.ttl != 0 { + // only update expiration when set + i.expiration = time.Now().Add(pq.ttl).Unix() + heap.Fix(pq, i.index) + } + if pq.evictOnLimit { + pq.lru.MoveToBack(e) + } + } +} + +func (pq *expirationTracker) unTrack(publicKey string) { + if it, ok := pq.idx[publicKey]; ok { + xorAssign(&pq.hash, it.Value.(*jwtItem).hash) + heap.Remove(pq, it.Value.(*jwtItem).index) + delete(pq.idx, publicKey) + } +} + +func (pq *expirationTracker) track(publicKey string, hash *[sha256.Size]byte, theJWT string) { + var exp int64 + // prioritize ttl over expiration + if pq.ttl != 0 { + if pq.ttl == time.Duration(math.MaxInt64) { + exp = math.MaxInt64 + } else { + exp = time.Now().Add(pq.ttl).Unix() + } + } else { + if g, err := jwt.DecodeGeneric(theJWT); err == nil { + exp = g.Expires + } + if exp == 0 { + exp = math.MaxInt64 // default to indefinite + } + } + if e, ok := pq.idx[publicKey]; ok { + i := e.Value.(*jwtItem) + xorAssign(&pq.hash, i.hash) // remove old hash + i.expiration = exp + i.hash = *hash + heap.Fix(pq, i.index) + } else { + heap.Push(pq, &jwtItem{-1, publicKey, exp, *hash}) + } + xorAssign(&pq.hash, *hash) // add in new hash +} + +func (pq *expirationTracker) close() { + if pq == nil || pq.quit == nil { + return + } + close(pq.quit) + pq.quit = nil +} + +func (store *DirJWTStore) startExpiring(reCheck time.Duration, limit int64, evictOnLimit bool, ttl time.Duration) { + store.Lock() + defer store.Unlock() + quit := make(chan struct{}) + pq := &expirationTracker{ + make([]*jwtItem, 0, 10), + make(map[string]*list.Element), + list.New(), + limit, + evictOnLimit, + ttl, + [sha256.Size]byte{}, + quit, + sync.WaitGroup{}, + } + store.expiration = pq + pq.wg.Add(1) + go func() { + t := time.NewTicker(reCheck) + defer t.Stop() + defer pq.wg.Done() + for { + now := time.Now() + store.Lock() + if pq.Len() > 0 { + if it := heap.Pop(pq).(*jwtItem); it.expiration <= now.Unix() { + path := store.pathForKey(it.publicKey) + if err := os.Remove(path); err != nil { + heap.Push(pq, it) // retry later + } else { + pq.unTrack(it.publicKey) + xorAssign(&pq.hash, it.hash) + store.Unlock() + continue // we removed an entry, check next one + } + } else { + heap.Push(pq, it) + } + } + store.Unlock() + select { + case <-t.C: + case <-quit: + return + } + } + }() +} diff --git a/server/dirstore_test.go b/server/dirstore_test.go new file mode 100644 index 00000000..661d38a1 --- /dev/null +++ b/server/dirstore_test.go @@ -0,0 +1,898 @@ +/* + * Copyright 2020 The NATS Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package server + +import ( + "bytes" + "crypto/sha256" + "fmt" + "io/ioutil" + "math" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/nats-io/jwt/v2" + "github.com/nats-io/nkeys" +) + +func require_True(t *testing.T, b bool) { + t.Helper() + if !b { + t.Errorf("require true, but got false") + } +} + +func require_False(t *testing.T, b bool) { + t.Helper() + if b { + t.Errorf("require no false, but got true") + } +} + +func require_NoError(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Errorf("require no error, but got: %v", err) + } +} + +func require_Error(t *testing.T, err error) { + t.Helper() + if err == nil { + t.Errorf("require no error, but got: %v", err) + } +} + +func require_Equal(t *testing.T, a, b string) { + t.Helper() + if strings.Compare(a, b) != 0 { + t.Errorf("require equal, but got: %v != %v", a, b) + } +} + +func require_NotEqual(t *testing.T, a, b [32]byte) { + t.Helper() + if bytes.Equal(a[:], b[:]) { + t.Errorf("require not equal, but got: %v != %v", a, b) + } +} + +func require_Len(t *testing.T, a, b int) { + t.Helper() + if a != b { + t.Errorf("require len, but got: %v != %v", a, b) + } +} + +func TestShardedDirStoreWriteAndReadonly(t *testing.T) { + dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test") + require_NoError(t, err) + + store, err := NewDirJWTStore(dir, true, false) + require_NoError(t, err) + + expected := map[string]string{ + "one": "alpha", + "two": "beta", + "three": "gamma", + "four": "delta", + } + + for k, v := range expected { + store.SaveAcc(k, v) + } + + for k, v := range expected { + got, err := store.LoadAcc(k) + require_NoError(t, err) + require_Equal(t, v, got) + } + + got, err := store.LoadAcc("random") + require_Error(t, err) + require_Equal(t, "", got) + + got, err = store.LoadAcc("") + require_Error(t, err) + require_Equal(t, "", got) + + err = store.SaveAcc("", "onetwothree") + require_Error(t, err) + store.Close() + + // re-use the folder for readonly mode + store, err = NewImmutableDirJWTStore(dir, true) + require_NoError(t, err) + + require_True(t, store.IsReadOnly()) + + err = store.SaveAcc("five", "omega") + require_Error(t, err) + + for k, v := range expected { + got, err := store.LoadAcc(k) + require_NoError(t, err) + require_Equal(t, v, got) + } + store.Close() +} + +func TestUnshardedDirStoreWriteAndReadonly(t *testing.T) { + dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test") + require_NoError(t, err) + + store, err := NewDirJWTStore(dir, false, false) + require_NoError(t, err) + + expected := map[string]string{ + "one": "alpha", + "two": "beta", + "three": "gamma", + "four": "delta", + } + + require_False(t, store.IsReadOnly()) + + for k, v := range expected { + store.SaveAcc(k, v) + } + + for k, v := range expected { + got, err := store.LoadAcc(k) + require_NoError(t, err) + require_Equal(t, v, got) + } + + got, err := store.LoadAcc("random") + require_Error(t, err) + require_Equal(t, "", got) + + got, err = store.LoadAcc("") + require_Error(t, err) + require_Equal(t, "", got) + + err = store.SaveAcc("", "onetwothree") + require_Error(t, err) + store.Close() + + // re-use the folder for readonly mode + store, err = NewImmutableDirJWTStore(dir, false) + require_NoError(t, err) + + require_True(t, store.IsReadOnly()) + + err = store.SaveAcc("five", "omega") + require_Error(t, err) + + for k, v := range expected { + got, err := store.LoadAcc(k) + require_NoError(t, err) + require_Equal(t, v, got) + } + store.Close() +} + +func TestNoCreateRequiresDir(t *testing.T) { + _, err := NewDirJWTStore("/a/b/c", true, false) + require_Error(t, err) +} + +func TestCreateMakesDir(t *testing.T) { + dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test") + require_NoError(t, err) + + fullPath := filepath.Join(dir, "a/b") + + _, err = os.Stat(fullPath) + require_Error(t, err) + require_True(t, os.IsNotExist(err)) + + s, err := NewDirJWTStore(fullPath, false, true) + require_NoError(t, err) + s.Close() + + _, err = os.Stat(fullPath) + require_NoError(t, err) +} + +func TestShardedDirStorePackMerge(t *testing.T) { + dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test") + require_NoError(t, err) + dir2, err := ioutil.TempDir(os.TempDir(), "jwtstore_test") + require_NoError(t, err) + dir3, err := ioutil.TempDir(os.TempDir(), "jwtstore_test") + require_NoError(t, err) + + store, err := NewDirJWTStore(dir, true, false) + require_NoError(t, err) + + expected := map[string]string{ + "one": "alpha", + "two": "beta", + "three": "gamma", + "four": "delta", + } + + require_False(t, store.IsReadOnly()) + + for k, v := range expected { + store.SaveAcc(k, v) + } + + for k, v := range expected { + got, err := store.LoadAcc(k) + require_NoError(t, err) + require_Equal(t, v, got) + } + + got, err := store.LoadAcc("random") + require_Error(t, err) + require_Equal(t, "", got) + + pack, err := store.Pack(-1) + require_NoError(t, err) + + inc, err := NewDirJWTStore(dir2, true, false) + require_NoError(t, err) + + inc.Merge(pack) + + for k, v := range expected { + got, err := inc.LoadAcc(k) + require_NoError(t, err) + require_Equal(t, v, got) + } + + got, err = inc.LoadAcc("random") + require_Error(t, err) + require_Equal(t, "", got) + + limitedPack, err := inc.Pack(1) + require_NoError(t, err) + + limited, err := NewDirJWTStore(dir3, true, false) + + require_NoError(t, err) + + limited.Merge(limitedPack) + + count := 0 + for k, v := range expected { + got, err := limited.LoadAcc(k) + if err == nil { + count++ + require_Equal(t, v, got) + } + } + + require_Len(t, 1, count) + + got, err = inc.LoadAcc("random") + require_Error(t, err) + require_Equal(t, "", got) +} + +func TestShardedToUnsharedDirStorePackMerge(t *testing.T) { + dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test") + require_NoError(t, err) + dir2, err := ioutil.TempDir(os.TempDir(), "jwtstore_test") + require_NoError(t, err) + + store, err := NewDirJWTStore(dir, true, false) + require_NoError(t, err) + + expected := map[string]string{ + "one": "alpha", + "two": "beta", + "three": "gamma", + "four": "delta", + } + + require_False(t, store.IsReadOnly()) + + for k, v := range expected { + store.SaveAcc(k, v) + } + + for k, v := range expected { + got, err := store.LoadAcc(k) + require_NoError(t, err) + require_Equal(t, v, got) + } + + got, err := store.LoadAcc("random") + require_Error(t, err) + require_Equal(t, "", got) + + pack, err := store.Pack(-1) + require_NoError(t, err) + + inc, err := NewDirJWTStore(dir2, false, false) + require_NoError(t, err) + + inc.Merge(pack) + + for k, v := range expected { + got, err := inc.LoadAcc(k) + require_NoError(t, err) + require_Equal(t, v, got) + } + + got, err = inc.LoadAcc("random") + require_Error(t, err) + require_Equal(t, "", got) + + err = store.Merge("foo") + require_Error(t, err) + + err = store.Merge("") // will skip it + require_NoError(t, err) + + err = store.Merge("a|something") // should fail on a for sharding + require_Error(t, err) +} + +func TestMergeOnlyOnNewer(t *testing.T) { + dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test") + require_NoError(t, err) + + dirStore, err := NewDirJWTStore(dir, true, false) + require_NoError(t, err) + + accountKey, err := nkeys.CreateAccount() + require_NoError(t, err) + + pubKey, err := accountKey.PublicKey() + require_NoError(t, err) + + account := jwt.NewAccountClaims(pubKey) + account.Name = "old" + olderJWT, err := account.Encode(accountKey) + require_NoError(t, err) + + time.Sleep(2 * time.Second) + + account.Name = "new" + newerJWT, err := account.Encode(accountKey) + require_NoError(t, err) + + // Should work + err = dirStore.SaveAcc(pubKey, olderJWT) + require_NoError(t, err) + fromStore, err := dirStore.LoadAcc(pubKey) + require_NoError(t, err) + require_Equal(t, olderJWT, fromStore) + + // should replace + err = dirStore.saveIfNewer(pubKey, newerJWT) + require_NoError(t, err) + fromStore, err = dirStore.LoadAcc(pubKey) + require_NoError(t, err) + require_Equal(t, newerJWT, fromStore) + + // should fail + err = dirStore.saveIfNewer(pubKey, olderJWT) + require_NoError(t, err) + fromStore, err = dirStore.LoadAcc(pubKey) + require_NoError(t, err) + require_Equal(t, newerJWT, fromStore) +} + +func createTestAccount(t *testing.T, dirStore *DirJWTStore, expSec int, accKey nkeys.KeyPair) string { + t.Helper() + pubKey, err := accKey.PublicKey() + require_NoError(t, err) + account := jwt.NewAccountClaims(pubKey) + if expSec > 0 { + account.Expires = time.Now().Add(time.Second * time.Duration(expSec)).Unix() + } + jwt, err := account.Encode(accKey) + require_NoError(t, err) + err = dirStore.SaveAcc(pubKey, jwt) + require_NoError(t, err) + return jwt +} + +func assertStoreSize(t *testing.T, dirStore *DirJWTStore, length int) { + t.Helper() + f, err := ioutil.ReadDir(dirStore.directory) + require_NoError(t, err) + require_Len(t, len(f), length) + dirStore.Lock() + require_Len(t, len(dirStore.expiration.idx), length) + require_Len(t, dirStore.expiration.lru.Len(), length) + require_Len(t, len(dirStore.expiration.heap), length) + dirStore.Unlock() +} + +func TestExpiration(t *testing.T) { + dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test") + require_NoError(t, err) + + dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 10, true, 0, nil) + require_NoError(t, err) + defer dirStore.Close() + + account := func(expSec int) { + accountKey, err := nkeys.CreateAccount() + require_NoError(t, err) + createTestAccount(t, dirStore, expSec, accountKey) + } + + h := dirStore.Hash() + + for i := 1; i <= 5; i++ { + account(i * 2) + nh := dirStore.Hash() + require_NotEqual(t, h, nh) + h = nh + } + time.Sleep(1 * time.Second) + for i := 5; i > 0; i-- { + f, err := ioutil.ReadDir(dir) + require_NoError(t, err) + require_Len(t, len(f), i) + assertStoreSize(t, dirStore, i) + + time.Sleep(2 * time.Second) + + nh := dirStore.Hash() + require_NotEqual(t, h, nh) + h = nh + } + assertStoreSize(t, dirStore, 0) +} + +func TestLimit(t *testing.T) { + dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test") + require_NoError(t, err) + + dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 5, true, 0, nil) + require_NoError(t, err) + defer dirStore.Close() + + account := func(expSec int) { + accountKey, err := nkeys.CreateAccount() + require_NoError(t, err) + createTestAccount(t, dirStore, expSec, accountKey) + } + + h := dirStore.Hash() + + accountKey, err := nkeys.CreateAccount() + require_NoError(t, err) + // update first account + for i := 0; i < 10; i++ { + createTestAccount(t, dirStore, 50, accountKey) + assertStoreSize(t, dirStore, 1) + } + // new accounts + for i := 0; i < 10; i++ { + account(i) + nh := dirStore.Hash() + require_NotEqual(t, h, nh) + h = nh + } + // first account should be gone now accountKey.PublicKey() + key, _ := accountKey.PublicKey() + _, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, key)) + require_True(t, os.IsNotExist(err)) + + // update first account + for i := 0; i < 10; i++ { + createTestAccount(t, dirStore, 50, accountKey) + assertStoreSize(t, dirStore, 5) + } +} + +func TestLimitNoEvict(t *testing.T) { + dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test") + require_NoError(t, err) + + dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 2, false, 0, nil) + require_NoError(t, err) + defer dirStore.Close() + + accountKey1, err := nkeys.CreateAccount() + require_NoError(t, err) + pKey1, err := accountKey1.PublicKey() + require_NoError(t, err) + accountKey2, err := nkeys.CreateAccount() + require_NoError(t, err) + accountKey3, err := nkeys.CreateAccount() + require_NoError(t, err) + pKey3, err := accountKey3.PublicKey() + require_NoError(t, err) + + createTestAccount(t, dirStore, 100, accountKey1) + assertStoreSize(t, dirStore, 1) + createTestAccount(t, dirStore, 2, accountKey2) + assertStoreSize(t, dirStore, 2) + + hBefore := dirStore.Hash() + // 2 jwt are already stored. third must result in an error + pubKey, err := accountKey3.PublicKey() + require_NoError(t, err) + account := jwt.NewAccountClaims(pubKey) + jwt, err := account.Encode(accountKey3) + require_NoError(t, err) + err = dirStore.SaveAcc(pubKey, jwt) + require_Error(t, err) + assertStoreSize(t, dirStore, 2) + _, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey1)) + require_NoError(t, err) + _, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey3)) + require_True(t, os.IsNotExist(err)) + // check that the hash did not change + hAfter := dirStore.Hash() + require_True(t, bytes.Equal(hBefore[:], hAfter[:])) + // wait for expiration of account2 + time.Sleep(3 * time.Second) + err = dirStore.SaveAcc(pubKey, jwt) + require_NoError(t, err) + assertStoreSize(t, dirStore, 2) + _, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey1)) + require_NoError(t, err) + _, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey3)) + require_NoError(t, err) +} + +func TestLruLoad(t *testing.T) { + dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test") + require_NoError(t, err) + dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 2, true, 0, nil) + require_NoError(t, err) + defer dirStore.Close() + + accountKey1, err := nkeys.CreateAccount() + require_NoError(t, err) + pKey1, err := accountKey1.PublicKey() + require_NoError(t, err) + accountKey2, err := nkeys.CreateAccount() + require_NoError(t, err) + accountKey3, err := nkeys.CreateAccount() + require_NoError(t, err) + pKey3, err := accountKey3.PublicKey() + require_NoError(t, err) + + createTestAccount(t, dirStore, 10, accountKey1) + assertStoreSize(t, dirStore, 1) + createTestAccount(t, dirStore, 10, accountKey2) + assertStoreSize(t, dirStore, 2) + dirStore.LoadAcc(pKey1) // will reorder 1/2 + createTestAccount(t, dirStore, 10, accountKey3) + assertStoreSize(t, dirStore, 2) + + _, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey1)) + require_NoError(t, err) + _, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey3)) + require_NoError(t, err) +} + +func TestLru(t *testing.T) { + dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test") + require_NoError(t, err) + + dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 2, true, 0, nil) + require_NoError(t, err) + defer dirStore.Close() + + accountKey1, err := nkeys.CreateAccount() + require_NoError(t, err) + pKey1, err := accountKey1.PublicKey() + require_NoError(t, err) + accountKey2, err := nkeys.CreateAccount() + require_NoError(t, err) + accountKey3, err := nkeys.CreateAccount() + require_NoError(t, err) + pKey3, err := accountKey3.PublicKey() + require_NoError(t, err) + + createTestAccount(t, dirStore, 10, accountKey1) + assertStoreSize(t, dirStore, 1) + createTestAccount(t, dirStore, 10, accountKey2) + assertStoreSize(t, dirStore, 2) + createTestAccount(t, dirStore, 10, accountKey3) + assertStoreSize(t, dirStore, 2) + _, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey1)) + require_True(t, os.IsNotExist(err)) + _, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey3)) + require_NoError(t, err) + + // update -> will change this keys position for eviction + createTestAccount(t, dirStore, 10, accountKey2) + assertStoreSize(t, dirStore, 2) + // recreate -> will evict 3 + createTestAccount(t, dirStore, 1, accountKey1) + assertStoreSize(t, dirStore, 2) + _, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey3)) + require_True(t, os.IsNotExist(err)) + // let key1 expire + time.Sleep(2 * time.Second) + assertStoreSize(t, dirStore, 1) + _, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey1)) + require_True(t, os.IsNotExist(err)) + // recreate key3 - no eviction + createTestAccount(t, dirStore, 10, accountKey3) + assertStoreSize(t, dirStore, 2) +} + +func TestReload(t *testing.T) { + dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test") + require_NoError(t, err) + notificationChan := make(chan struct{}, 5) + dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 2, true, 0, func(publicKey string) { + notificationChan <- struct{}{} + }) + require_NoError(t, err) + defer dirStore.Close() + newAccount := func() string { + t.Helper() + accKey, err := nkeys.CreateAccount() + require_NoError(t, err) + pKey, err := accKey.PublicKey() + require_NoError(t, err) + pubKey, err := accKey.PublicKey() + require_NoError(t, err) + account := jwt.NewAccountClaims(pubKey) + jwt, err := account.Encode(accKey) + require_NoError(t, err) + file := fmt.Sprintf("%s/%s.jwt", dir, pKey) + err = ioutil.WriteFile(file, []byte(jwt), 0644) + require_NoError(t, err) + return file + } + files := make(map[string]struct{}) + assertStoreSize(t, dirStore, 0) + hash := dirStore.Hash() + emptyHash := [sha256.Size]byte{} + require_True(t, bytes.Equal(hash[:], emptyHash[:])) + for i := 0; i < 5; i++ { + files[newAccount()] = struct{}{} + err = dirStore.Reload() + require_NoError(t, err) + <-notificationChan + assertStoreSize(t, dirStore, i+1) + hash = dirStore.Hash() + require_False(t, bytes.Equal(hash[:], emptyHash[:])) + msg, err := dirStore.Pack(-1) + require_NoError(t, err) + require_Len(t, len(strings.Split(msg, "\n")), len(files)) + } + for k := range files { + hash = dirStore.Hash() + require_False(t, bytes.Equal(hash[:], emptyHash[:])) + os.Remove(k) + err = dirStore.Reload() + require_NoError(t, err) + assertStoreSize(t, dirStore, len(files)-1) + delete(files, k) + msg, err := dirStore.Pack(-1) + require_NoError(t, err) + if len(files) != 0 { // when len is 0, we have an empty line + require_Len(t, len(strings.Split(msg, "\n")), len(files)) + } + } + require_True(t, len(notificationChan) == 0) + hash = dirStore.Hash() + require_True(t, bytes.Equal(hash[:], emptyHash[:])) +} + +func TestExpirationUpdate(t *testing.T) { + dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test") + require_NoError(t, err) + + dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 10, true, 0, nil) + require_NoError(t, err) + defer dirStore.Close() + + accountKey, err := nkeys.CreateAccount() + require_NoError(t, err) + + h := dirStore.Hash() + + createTestAccount(t, dirStore, 0, accountKey) + nh := dirStore.Hash() + require_NotEqual(t, h, nh) + h = nh + + time.Sleep(2 * time.Second) + f, err := ioutil.ReadDir(dir) + require_NoError(t, err) + require_Len(t, len(f), 1) + + createTestAccount(t, dirStore, 5, accountKey) + nh = dirStore.Hash() + require_NotEqual(t, h, nh) + h = nh + + time.Sleep(2 * time.Second) + f, err = ioutil.ReadDir(dir) + require_NoError(t, err) + require_Len(t, len(f), 1) + + createTestAccount(t, dirStore, 0, accountKey) + nh = dirStore.Hash() + require_NotEqual(t, h, nh) + h = nh + + time.Sleep(2 * time.Second) + f, err = ioutil.ReadDir(dir) + require_NoError(t, err) + require_Len(t, len(f), 1) + + createTestAccount(t, dirStore, 1, accountKey) + nh = dirStore.Hash() + require_NotEqual(t, h, nh) + h = nh + + time.Sleep(2 * time.Second) + f, err = ioutil.ReadDir(dir) + require_NoError(t, err) + require_Len(t, len(f), 0) + + empty := [32]byte{} + h = dirStore.Hash() + require_Equal(t, string(h[:]), string(empty[:])) +} + +func TestTTL(t *testing.T) { + dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test") + require_NoError(t, err) + require_OneJWT := func() { + t.Helper() + f, err := ioutil.ReadDir(dir) + require_NoError(t, err) + require_Len(t, len(f), 1) + } + dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 10, true, 2*time.Second, nil) + require_NoError(t, err) + defer dirStore.Close() + + accountKey, err := nkeys.CreateAccount() + require_NoError(t, err) + pubKey, err := accountKey.PublicKey() + require_NoError(t, err) + jwt := createTestAccount(t, dirStore, 0, accountKey) + require_OneJWT() + for i := 0; i < 6; i++ { + time.Sleep(time.Second) + dirStore.LoadAcc(pubKey) + require_OneJWT() + } + for i := 0; i < 6; i++ { + time.Sleep(time.Second) + dirStore.SaveAcc(pubKey, jwt) + require_OneJWT() + } + for i := 0; i < 6; i++ { + time.Sleep(time.Second) + createTestAccount(t, dirStore, 0, accountKey) + require_OneJWT() + } + time.Sleep(3 * time.Second) + f, err := ioutil.ReadDir(dir) + require_NoError(t, err) + require_Len(t, len(f), 0) +} + +const infDur = time.Duration(math.MaxInt64) + +func TestNotificationOnPack(t *testing.T) { + jwts := map[string]string{ + "key1": "value", + "key2": "value", + "key3": "value", + "key4": "value", + } + notificationChan := make(chan struct{}, len(jwts)) // set to same len so all extra will block + notification := func(pubKey string) { + if _, ok := jwts[pubKey]; !ok { + t.Errorf("Key not found: %s", pubKey) + } + notificationChan <- struct{}{} + } + dirPack, err := ioutil.TempDir(os.TempDir(), "jwtstore_test") + require_NoError(t, err) + packStore, err := NewExpiringDirJWTStore(dirPack, false, false, infDur, 0, true, 0, notification) + require_NoError(t, err) + // prefill the store with data + for k, v := range jwts { + require_NoError(t, packStore.SaveAcc(k, v)) + } + for i := 0; i < len(jwts); i++ { + <-notificationChan + } + msg, err := packStore.Pack(-1) + require_NoError(t, err) + packStore.Close() + hash := packStore.Hash() + for _, shard := range []bool{true, false, true, false} { + dirMerge, err := ioutil.TempDir(os.TempDir(), "jwtstore_test") + require_NoError(t, err) + mergeStore, err := NewExpiringDirJWTStore(dirMerge, shard, false, infDur, 0, true, 0, notification) + require_NoError(t, err) + // set + err = mergeStore.Merge(msg) + require_NoError(t, err) + assertStoreSize(t, mergeStore, len(jwts)) + hash1 := packStore.Hash() + require_True(t, bytes.Equal(hash[:], hash1[:])) + for i := 0; i < len(jwts); i++ { + <-notificationChan + } + // overwrite - assure + err = mergeStore.Merge(msg) + require_NoError(t, err) + assertStoreSize(t, mergeStore, len(jwts)) + hash2 := packStore.Hash() + require_True(t, bytes.Equal(hash1[:], hash2[:])) + + hash = hash1 + msg, err = mergeStore.Pack(-1) + require_NoError(t, err) + mergeStore.Close() + require_True(t, len(notificationChan) == 0) + + for k, v := range jwts { + j, err := packStore.LoadAcc(k) + require_NoError(t, err) + require_Equal(t, j, v) + } + } +} + +func TestNotificationOnPackWalk(t *testing.T) { + const storeCnt = 5 + const keyCnt = 50 + const iterCnt = 8 + store := [storeCnt]*DirJWTStore{} + for i := 0; i < storeCnt; i++ { + dirMerge, err := ioutil.TempDir(os.TempDir(), "jwtstore_test") + require_NoError(t, err) + mergeStore, err := NewExpiringDirJWTStore(dirMerge, true, false, infDur, 0, true, 0, nil) + require_NoError(t, err) + store[i] = mergeStore + } + for i := 0; i < iterCnt; i++ { //iterations + jwt := make(map[string]string) + for j := 0; j < keyCnt; j++ { + key := fmt.Sprintf("key%d-%d", i, j) + value := "value" + jwt[key] = value + store[0].SaveAcc(key, value) + } + for j := 0; j < storeCnt-1; j++ { // stores + err := store[j].PackWalk(3, func(partialPackMsg string) { + err := store[j+1].Merge(partialPackMsg) + require_NoError(t, err) + }) + require_NoError(t, err) + } + for i := 0; i < storeCnt-1; i++ { + h1 := store[i].Hash() + h2 := store[i+1].Hash() + require_True(t, bytes.Equal(h1[:], h2[:])) + } + } + for i := 0; i < storeCnt; i++ { + store[i].Close() + } +} diff --git a/server/events.go b/server/events.go index 212e12a0..9d56ba1a 100644 --- a/server/events.go +++ b/server/events.go @@ -584,9 +584,15 @@ func (s *Server) initEventTracking() { s.Errorf("Error setting up internal tracking: %v", err) } // Listen for account claims updates. - subject = fmt.Sprintf(accUpdateEventSubj, "*") - if _, err := s.sysSubscribe(subject, s.accountClaimUpdate); err != nil { - s.Errorf("Error setting up internal tracking: %v", err) + subscribeToUpdate := true + if s.accResolver != nil { + subscribeToUpdate = !s.accResolver.IsTrackingUpdate() + } + if subscribeToUpdate { + subject = fmt.Sprintf(accUpdateEventSubj, "*") + if _, err := s.sysSubscribe(subject, s.accountClaimUpdate); err != nil { + s.Errorf("Error setting up internal tracking: %v", err) + } } // Listen for requests for our statsz. subject = fmt.Sprintf(serverStatsReqSubj, s.info.ID) @@ -647,7 +653,6 @@ func (s *Server) initEventTracking() { if _, err := s.sysSubscribe(subject, s.remoteLatencyUpdate); err != nil { s.Errorf("Error setting up internal latency tracking: %v", err) } - // This is for simple debugging of number of subscribers that exist in the system. if _, err := s.sysSubscribeInternal(accSubsSubj, s.debugSubscribers); err != nil { s.Errorf("Error setting up internal debug service for subscribers: %v", err) @@ -1239,17 +1244,22 @@ func (s *Server) sendAuthErrorEvent(c *client) { // required to be copied. type msgHandler func(sub *subscription, client *client, subject, reply string, msg []byte) -// Create an internal subscription. No support for queue groups atm. +// Create an internal subscription. sysSubscribeQ for queue groups. func (s *Server) sysSubscribe(subject string, cb msgHandler) (*subscription, error) { - return s.systemSubscribe(subject, false, cb) + return s.systemSubscribe(subject, "", false, cb) +} + +// Create an internal subscription with queue +func (s *Server) sysSubscribeQ(subject, queue string, cb msgHandler) (*subscription, error) { + return s.systemSubscribe(subject, queue, false, cb) } // Create an internal subscription but do not forward interest. func (s *Server) sysSubscribeInternal(subject string, cb msgHandler) (*subscription, error) { - return s.systemSubscribe(subject, true, cb) + return s.systemSubscribe(subject, "", true, cb) } -func (s *Server) systemSubscribe(subject string, internalOnly bool, cb msgHandler) (*subscription, error) { +func (s *Server) systemSubscribe(subject, queue string, internalOnly bool, cb msgHandler) (*subscription, error) { if !s.eventsEnabled() { return nil, ErrNoSysAccount } @@ -1263,12 +1273,16 @@ func (s *Server) systemSubscribe(subject string, internalOnly bool, cb msgHandle sid := strconv.Itoa(s.sys.sid) s.mu.Unlock() + // Now create the subscription if trace { - c.traceInOp("SUB", []byte(subject+" "+sid)) + c.traceInOp("SUB", []byte(subject+" "+queue+" "+sid)) } - // Now create the subscription - return c.processSub([]byte(subject), nil, []byte(sid), cb, internalOnly) + var q []byte + if queue != "" { + q = []byte(queue) + } + return c.processSub([]byte(subject), q, []byte(sid), cb, internalOnly) } func (s *Server) sysUnsubscribe(sub *subscription) { diff --git a/server/jwt_test.go b/server/jwt_test.go index bba85a40..f13bb23d 100644 --- a/server/jwt_test.go +++ b/server/jwt_test.go @@ -23,6 +23,7 @@ import ( "net/http" "net/http/httptest" "os" + "path/filepath" "strings" "sync" "sync/atomic" @@ -2868,3 +2869,319 @@ func TestExpiredUserCredentialsRenewal(t *testing.T) { t.Fatalf("Expected lastErr to be cleared, got %q", nc.LastError()) } } + +func TestAccountNATSResolverFetch(t *testing.T) { + require_NextMsg := func(sub *nats.Subscription) bool { + msg := natsNexMsg(t, sub, 10*time.Second) + content := make(map[string]interface{}) + json.Unmarshal(msg.Data, &content) + if _, ok := content["data"]; ok { + return true + } + return false + } + require_FileAbsent := func(dir string, pub string) { + t.Helper() + _, err := os.Stat(filepath.Join(dir, pub+".jwt")) + require_Error(t, err) + require_True(t, os.IsNotExist(err)) + } + require_FilePresent := func(dir string, pub string) { + t.Helper() + _, err := os.Stat(filepath.Join(dir, pub+".jwt")) + require_NoError(t, err) + } + require_FileEqual := func(dir string, pub string, jwt string) { + t.Helper() + content, err := ioutil.ReadFile(filepath.Join(dir, pub+".jwt")) + require_NoError(t, err) + require_Equal(t, string(content), jwt) + } + require_1Connection := func(url string, creds string) { + t.Helper() + c := natsConnect(t, url, nats.UserCredentials(creds)) + defer c.Close() + if _, err := nats.Connect(url, nats.UserCredentials(creds)); err == nil { + t.Fatal("Second connection was supposed to fail due to limits") + } else if !strings.Contains(err.Error(), ErrTooManyAccountConnections.Error()) { + t.Fatal("Second connection was supposed to fail with too many conns") + } + } + require_2Connection := func(url string, creds string) { + t.Helper() + c1 := natsConnect(t, url, nats.UserCredentials(creds)) + defer c1.Close() + c2 := natsConnect(t, url, nats.UserCredentials(creds)) + defer c2.Close() + if _, err := nats.Connect(url, nats.UserCredentials(creds)); err == nil { + t.Fatal("Third connection was supposed to fail due to limits") + } else if !strings.Contains(err.Error(), ErrTooManyAccountConnections.Error()) { + t.Fatal("Third connection was supposed to fail with too many conns") + } + } + writeFile := func(dir string, pub string, jwt string) { + t.Helper() + err := ioutil.WriteFile(filepath.Join(dir, pub+".jwt"), []byte(jwt), 0644) + require_NoError(t, err) + } + createDir := func(prefix string) string { + t.Helper() + dir, err := ioutil.TempDir("", prefix) + require_NoError(t, err) + return dir + } + connect := func(url string, credsfile string) { + t.Helper() + nc := natsConnect(t, url, nats.UserCredentials(credsfile)) + nc.Close() + } + createAccountAndUser := func(pair nkeys.KeyPair, limit bool) (string, string, string, string) { + t.Helper() + kp, _ := nkeys.CreateAccount() + pub, _ := kp.PublicKey() + claim := jwt.NewAccountClaims(pub) + if limit { + claim.Limits.Conn = 1 + } + jwt1, err := claim.Encode(pair) + require_NoError(t, err) + time.Sleep(2 * time.Second) + // create updated claim allowing more connections + if limit { + claim.Limits.Conn = 2 + } + jwt2, err := claim.Encode(pair) + require_NoError(t, err) + ukp, _ := nkeys.CreateUser() + seed, _ := ukp.Seed() + upub, _ := ukp.PublicKey() + uclaim := newJWTTestUserClaims() + uclaim.Subject = upub + ujwt, err := uclaim.Encode(kp) + require_NoError(t, err) + return pub, jwt1, jwt2, genCredsFile(t, ujwt, seed) + } + updateJwt := func(url string, creds string, pubKey string, jwt string) int { + t.Helper() + c := natsConnect(t, url, nats.UserCredentials(creds), + nats.DisconnectErrHandler(func(_ *nats.Conn, err error) { + if err != nil { + t.Fatal("error not expected in this test", err) + } + }), + nats.ErrorHandler(func(_ *nats.Conn, _ *nats.Subscription, err error) { + t.Fatal("error not expected in this test", err) + }), + ) + defer c.Close() + resp := c.NewRespInbox() + sub := natsSubSync(t, c, resp) + err := sub.AutoUnsubscribe(3) + require_NoError(t, err) + require_NoError(t, c.PublishRequest(fmt.Sprintf(accUpdateEventSubj, pubKey), resp, []byte(jwt))) + passCnt := 0 + if require_NextMsg(sub) { + passCnt++ + } + if require_NextMsg(sub) { + passCnt++ + } + if require_NextMsg(sub) { + passCnt++ + } + return passCnt + } + // Create Operator + op, _ := nkeys.CreateOperator() + opub, _ := op.PublicKey() + oc := jwt.NewOperatorClaims(opub) + oc.Subject = opub + ojwt, err := oc.Encode(op) + require_NoError(t, err) + // Create Accounts and corresponding user creds + syspub, sysjwt, _, sysCreds := createAccountAndUser(op, false) + defer os.Remove(sysCreds) + apub, ajwt1, ajwt2, aCreds := createAccountAndUser(op, true) + defer os.Remove(aCreds) + bpub, bjwt1, bjwt2, bCreds := createAccountAndUser(op, true) + defer os.Remove(bCreds) + cpub, cjwt1, cjwt2, cCreds := createAccountAndUser(op, true) + defer os.Remove(cCreds) + // Create one directory for each server + dirA := createDir("srv-a") + defer os.RemoveAll(dirA) + dirB := createDir("srv-b") + defer os.RemoveAll(dirB) + dirC := createDir("srv-c") + defer os.RemoveAll(dirC) + // simulate a restart of the server by storing files in them + // Server A/B will completely sync, so after startup each server + // will contain the union off all stored/configured jwt + // Server C will send out lookup requests for jwt it does not store itself + writeFile(dirA, apub, ajwt1) + writeFile(dirB, bpub, bjwt1) + writeFile(dirC, cpub, cjwt1) + // Create seed server A (using no_advertise to prevent fail over) + confA := createConfFile(t, []byte(fmt.Sprintf(` + listen: -1 + server_name: srv-A + operator: %s + system_account: %s + resolver: { + type: full + dir: %s + interval: "1s" + limit: 4 + } + resolver_preload: { + %s: %s + } + cluster { + name: clust + listen: -1 + no_advertise: true + } + `, ojwt, syspub, dirA, cpub, cjwt1))) + defer os.Remove(confA) + sA, _ := RunServerWithConfig(confA) + defer sA.Shutdown() + // during startup resolver_preload causes the directory to contain data + require_FilePresent(dirA, cpub) + // Create Server B (using no_advertise to prevent fail over) + confB := createConfFile(t, []byte(fmt.Sprintf(` + listen: -1 + server_name: srv-B + operator: %s + system_account: %s + resolver: { + type: full + dir: %s + interval: "1s" + limit: 4 + } + cluster { + name: clust + listen: -1 + no_advertise: true + routes [ + nats-route://localhost:%d + ] + } + `, ojwt, syspub, dirB, sA.opts.Cluster.Port))) + defer os.Remove(confB) + sB, _ := RunServerWithConfig(confB) + // Create Server C (using no_advertise to prevent fail over) + confC := createConfFile(t, []byte(fmt.Sprintf(` + listen: -1 + server_name: srv-C + operator: %s + system_account: %s + resolver: { + type: cache + dir: %s + ttl: "10s" + limit: 4 + } + cluster { + name: clust + listen: -1 + no_advertise: true + routes [ + nats-route://localhost:%d + ] + } + `, ojwt, syspub, dirC, sA.opts.Cluster.Port))) + defer os.Remove(confC) + sC, _ := RunServerWithConfig(confC) + // startup cluster + checkClusterFormed(t, sA, sB, sC) + time.Sleep(2 * time.Second) // wait for the protocol to converge + + // Check all accounts + require_FilePresent(dirA, apub) // was already present on startup + require_FilePresent(dirB, apub) // was copied from server A + require_FileAbsent(dirC, apub) + require_FilePresent(dirA, bpub) // was copied from server B + require_FilePresent(dirB, bpub) // was already present on startup + require_FileAbsent(dirC, bpub) + require_FilePresent(dirA, cpub) // was present in preload + require_FilePresent(dirB, cpub) // was copied from server A + require_FilePresent(dirC, cpub) // was already present on startup + // This is to test that connecting to it still works + require_FileAbsent(dirA, syspub) + require_FileAbsent(dirB, syspub) + require_FileAbsent(dirC, syspub) + // system account client can connect to every server + connect(sA.ClientURL(), sysCreds) + connect(sB.ClientURL(), sysCreds) + connect(sC.ClientURL(), sysCreds) + checkClusterFormed(t, sA, sB, sC) + // upload system account and require a response from each server + passCnt := updateJwt(sA.ClientURL(), sysCreds, syspub, sysjwt) + require_True(t, passCnt == 3) + require_FilePresent(dirA, syspub) // was just received + require_FilePresent(dirB, syspub) // was just received + require_FilePresent(dirC, syspub) // was just received + // Only files missing are in C, which is only caching + connect(sC.ClientURL(), aCreds) + connect(sC.ClientURL(), bCreds) + require_FilePresent(dirC, apub) // was looked up form A or B + require_FilePresent(dirC, bpub) // was looked up from A or B + + // Check limits and update jwt B connecting to server A + for port, v := range map[string]struct{ pub, jwt, creds string }{ + sB.ClientURL(): {bpub, bjwt2, bCreds}, + sC.ClientURL(): {cpub, cjwt2, cCreds}, + } { + require_1Connection(sA.ClientURL(), v.creds) + require_1Connection(sB.ClientURL(), v.creds) + require_1Connection(sC.ClientURL(), v.creds) + checkClientsCount(t, sA, 0) + checkClientsCount(t, sB, 0) + checkClientsCount(t, sC, 0) + passCnt := updateJwt(port, sysCreds, v.pub, v.jwt) + require_True(t, passCnt == 3) + require_2Connection(sA.ClientURL(), v.creds) + require_2Connection(sB.ClientURL(), v.creds) + require_2Connection(sC.ClientURL(), v.creds) + require_FileEqual(dirA, v.pub, v.jwt) + require_FileEqual(dirB, v.pub, v.jwt) + require_FileEqual(dirC, v.pub, v.jwt) + } + + // Simulates A having missed an update + // shutting B down as it has it will directly connect to A and connect right away + sB.Shutdown() + writeFile(dirB, apub, ajwt2) // this will be copied to server A + sB, _ = RunServerWithConfig(confB) + defer sB.Shutdown() + checkClusterFormed(t, sA, sB, sC) + time.Sleep(2 * time.Second) // wait for the protocol to converge, will also assure that C's cache becomes invalid + // Restart server C. this is a workaround to force C to do a lookup in the absence of account cleanup + sC.Shutdown() + sC, _ = RunServerWithConfig(confC) //TODO remove this once we clean up accounts + + require_FileEqual(dirA, apub, ajwt2) // was copied from server B + require_FileEqual(dirB, apub, ajwt2) // was restarted with this + require_FileEqual(dirC, apub, ajwt1) // still contains old cached value + require_2Connection(sA.ClientURL(), aCreds) + require_2Connection(sB.ClientURL(), aCreds) + require_1Connection(sC.ClientURL(), aCreds) + + // Restart server C. this is a workaround to force C to do a lookup in the absence of account cleanup + sC.Shutdown() + sC, _ = RunServerWithConfig(confC) //TODO remove this once we clean up accounts + defer sC.Shutdown() + require_FileEqual(dirC, apub, ajwt1) // still contains old cached value + time.Sleep(time.Second * 12) // Force next connect to do a lookup + connect(sC.ClientURL(), aCreds) // When lookup happens + require_FileEqual(dirC, apub, ajwt2) // was looked up form A or B + require_2Connection(sC.ClientURL(), aCreds) + + // Test exceeding limit. For the exclusive directory resolver, limit is a stop gap measure. + // It is not expected to be hit. When hit the administrator is supposed to take action. + dpub, djwt1, _, dCreds := createAccountAndUser(op, true) + defer os.Remove(dCreds) + passCnt = updateJwt(sA.ClientURL(), sysCreds, dpub, djwt1) + require_True(t, passCnt == 1) // Only Server C updated +} diff --git a/server/opts.go b/server/opts.go index bed56494..ec6a413a 100644 --- a/server/opts.go +++ b/server/opts.go @@ -816,22 +816,16 @@ func (o *Options) processConfigFileLine(k string, v interface{}, errors *[]error } } case "resolver", "account_resolver", "accounts_resolver": - // "resolver" takes precedence over value obtained from "operator". - // Clear so that parsing errors are not silently ignored. - o.AccountResolver = nil - var memResolverRe = regexp.MustCompile(`(MEM|MEMORY|mem|memory)\s*`) - var resolverRe = regexp.MustCompile(`(?:URL|url){1}(?:\({1}\s*"?([^\s"]*)"?\s*\){1})?\s*`) - str, ok := v.(string) - if !ok { - err := &configErr{tk, fmt.Sprintf("error parsing operator resolver, wrong type %T", v)} - *errors = append(*errors, err) - return - } - if memResolverRe.MatchString(str) { - o.AccountResolver = &MemAccResolver{} - } else { - items := resolverRe.FindStringSubmatch(str) - if len(items) == 2 { + switch v := v.(type) { + case string: + // "resolver" takes precedence over value obtained from "operator". + // Clear so that parsing errors are not silently ignored. + o.AccountResolver = nil + memResolverRe := regexp.MustCompile(`(?i)(MEM|MEMORY)\s*`) + resolverRe := regexp.MustCompile(`(?i)(?:URL){1}(?:\({1}\s*"?([^\s"]*)"?\s*\){1})?\s*`) + if memResolverRe.MatchString(v) { + o.AccountResolver = &MemAccResolver{} + } else if items := resolverRe.FindStringSubmatch(v); len(items) == 2 { url := items[1] _, err := parseURL(url, "account resolver") if err != nil { @@ -846,9 +840,70 @@ func (o *Options) processConfigFileLine(k string, v interface{}, errors *[]error o.AccountResolver = ur } } + case map[string]interface{}: + dir := "" + dirType := "" + limit := int64(0) + ttl := time.Duration(0) + sync := time.Duration(0) + var err error + if v, ok := v["dir"]; ok { + _, v := unwrapValue(v, <) + dir = v.(string) + } + if v, ok := v["type"]; ok { + _, v := unwrapValue(v, <) + dirType = v.(string) + } + if v, ok := v["limit"]; ok { + _, v := unwrapValue(v, <) + limit = v.(int64) + } + if v, ok := v["ttl"]; ok { + _, v := unwrapValue(v, <) + ttl, err = time.ParseDuration(v.(string)) + } + if v, ok := v["interval"]; err == nil && ok { + _, v := unwrapValue(v, <) + sync, err = time.ParseDuration(v.(string)) + } + if err != nil { + *errors = append(*errors, &configErr{tk, err.Error()}) + return + } + if dir == "" { + *errors = append(*errors, &configErr{tk, "dir needs to point to a directory"}) + return + } + if info, err := os.Stat(dir); err != nil || !info.IsDir() || info.Mode().Perm()&(1<<(uint(7))) == 0 { + info.IsDir() + } + var res AccountResolver + switch strings.ToUpper(dirType) { + case "CACHE": + if sync != 0 { + *errors = append(*errors, &configErr{tk, "CACHE does not accept sync"}) + } + res, err = NewCacheDirAccResolver(dir, limit, ttl) + case "FULL": + if ttl != 0 { + *errors = append(*errors, &configErr{tk, "FULL does not accept ttl"}) + } + res, err = NewDirAccResolver(dir, limit, sync) + } + if err != nil { + *errors = append(*errors, &configErr{tk, err.Error()}) + return + } + o.AccountResolver = res + default: + err := &configErr{tk, fmt.Sprintf("error parsing operator resolver, wrong type %T", v)} + *errors = append(*errors, err) + return } if o.AccountResolver == nil { - err := &configErr{tk, "error parsing account resolver, should be MEM or URL(\"url\")"} + err := &configErr{tk, "error parsing account resolver, should be MEM or " + + " URL(\"url\") or a map containing dir and type state=[FULL|CACHE])"} *errors = append(*errors, err) } case "resolver_tls": diff --git a/server/reload.go b/server/reload.go index 3df1c10c..acd89e69 100644 --- a/server/reload.go +++ b/server/reload.go @@ -770,7 +770,7 @@ func imposeOrder(value interface{}) error { return value.AllowedOrigins[i] < value.AllowedOrigins[j] }) case string, bool, int, int32, int64, time.Duration, float64, nil, - LeafNodeOpts, ClusterOpts, *tls.Config, *URLAccResolver, *MemAccResolver, Authentication: + LeafNodeOpts, ClusterOpts, *tls.Config, *URLAccResolver, *MemAccResolver, *DirAccResolver, *CacheDirAccResolver, Authentication: // explicitly skipped types default: // this will fail during unit tests @@ -1275,6 +1275,10 @@ func (s *Server) reloadAuthorization() { if checkJetStream { s.configAllJetStreamAccounts() } + + if res := s.AccountResolver(); res != nil { + res.Reload() + } } // Returns true if given client current account has changed (or user diff --git a/server/server.go b/server/server.go index 4ae18d40..b8c1dab9 100644 --- a/server/server.go +++ b/server/server.go @@ -360,14 +360,11 @@ func NewServer(opts *Options) (*Server, error) { // waiting for complete shutdown. s.shutdownComplete = make(chan struct{}) - // For tracking accounts - if err := s.configureAccounts(); err != nil { + // Check for configured account resolvers. + if err := s.configureResolver(); err != nil { return nil, err } - // If there is an URL account resolver, do basic test to see if anyone is home. - // Do this after configureAccounts() which calls configureResolver(), which will - // set TLSConfig if specified. if ar := opts.AccountResolver; ar != nil { if ur, ok := ar.(*URLAccResolver); ok { if _, err := ur.Fetch(""); err != nil { @@ -375,6 +372,31 @@ func NewServer(opts *Options) (*Server, error) { } } } + // For other resolver: + // In operator mode, when the account resolver depends on an external system and + // the system account can't fetched, inject a temporary one. + if ar := s.accResolver; len(opts.TrustedOperators) == 1 && ar != nil && + opts.SystemAccount != _EMPTY_ && opts.SystemAccount != DEFAULT_SYSTEM_ACCOUNT { + if _, ok := ar.(*MemAccResolver); !ok { + s.mu.Unlock() + var a *Account + // perform direct lookup to avoid warning trace + if _, err := ar.Fetch(s.opts.SystemAccount); err == nil { + a, _ = s.fetchAccount(s.opts.SystemAccount) + } + s.mu.Lock() + if a == nil { + sac := NewAccount(s.opts.SystemAccount) + sac.Issuer = opts.TrustedOperators[0].Issuer + s.registerAccountNoLock(sac) + } + } + } + + // For tracking accounts + if err := s.configureAccounts(); err != nil { + return nil, err + } // In local config mode, check that leafnode configuration // refers to account that exist. @@ -610,11 +632,6 @@ func (s *Server) configureAccounts() error { return true }) - // Check for configured account resolvers. - if err := s.configureResolver(); err != nil { - return err - } - // Set the system account if it was configured. // Otherwise create a default one. if opts.SystemAccount != _EMPTY_ { @@ -654,8 +671,8 @@ func (s *Server) configureResolver() error { } } if len(opts.resolverPreloads) > 0 { - if _, ok := s.accResolver.(*MemAccResolver); !ok { - return fmt.Errorf("resolver preloads only available for resolver type MEM") + if s.accResolver.IsReadOnly() { + return fmt.Errorf("resolver preloads only available for writeable resolver types MEM/DIR/CACHE_DIR") } for k, v := range opts.resolverPreloads { _, err := jwt.DecodeAccountClaims(v) @@ -1346,7 +1363,8 @@ func (s *Server) Start() { // Log the pid to a file if opts.PidFile != _EMPTY_ { if err := s.logPid(); err != nil { - PrintAndDie(fmt.Sprintf("Could not write pidfile: %v\n", err)) + s.Fatalf("Could not write pidfile: %v", err) + return } } @@ -1361,6 +1379,42 @@ func (s *Server) Start() { s.SetDefaultSystemAccount() } + // start up resolver machinery + if ar := s.AccountResolver(); ar != nil { + if err := ar.Start(s); err != nil { + s.Fatalf("Could not start resolver: %v", err) + return + } + // In operator mode, when the account resolver depends on an external system and + // the system account is the bootstrapping account, start fetching it + if len(opts.TrustedOperators) == 1 && opts.SystemAccount != _EMPTY_ && opts.SystemAccount != DEFAULT_SYSTEM_ACCOUNT { + _, isMemResolver := ar.(*MemAccResolver) + if v, ok := s.accounts.Load(s.opts.SystemAccount); !isMemResolver && ok && v.(*Account).claimJWT == "" { + s.Noticef("Using bootstrapping system account") + s.startGoRoutine(func() { + defer s.grWG.Done() + t := time.NewTicker(time.Second) + defer t.Stop() + for { + select { + case <-s.quitCh: + return + case <-t.C: + if _, err := ar.Fetch(s.opts.SystemAccount); err != nil { + continue + } + if _, err := s.fetchAccount(s.opts.SystemAccount); err != nil { + continue + } + s.Noticef("System account fetched and updated") + return + } + } + }) + } + } + } + // Start expiration of mapped GW replies, regardless if // this server is configured with gateway or not. s.startGWReplyMapExpiration() @@ -1483,6 +1537,10 @@ func (s *Server) Shutdown() { } s.Noticef("Initiating Shutdown...") + if s.accResolver != nil { + s.accResolver.Close() + } + opts := s.getOpts() s.shutdown = true