Nats based resolver & avoiding nats account server in smaller deployments (#1550)

* Adding nats based resolver and bootstrap system account

These resolvers operate on an exclusive directory.
Two types:
full: managing all jwt in the directory
    Will synchronize with other full resolver
    nats-account-server will also run such a resolver
cache: lru cache managing only a subset of all jwt in the directory
    Will look up jwt from a full resolver
    Can overwrite expiration with a ttl for the file

Both:
    track expiration of jwt and clean up
    Support reload
    Notify the server of changed jwt

Bootstrapping system account allows users signed with the system account
jwt to connect, without the server knowing the jwt.
This allows uploading jwt (including system account) using nats by
publishing to $SYS.ACCOUNT.<name>.CLAIMS.UPDATE
When sending a request, the server will respond with the result of the operation.

Receive all jwt stored in one server by sending a
request to $SYS.ACCOUNT.CLAIMS.PACK
One server will respond with a message per stored jwt.
The end of the responses is indicated by an empty message.

The content of dirstore.go and dirstore_test.go was moved from
nats-account-server

Signed-off-by: Matthias Hanel <mh@synadia.com>
This commit is contained in:
Matthias Hanel
2020-08-18 15:58:41 -04:00
committed by GitHub
parent fe549a1979
commit 48c87c1447
8 changed files with 2374 additions and 46 deletions

View File

@@ -14,8 +14,11 @@
package server
import (
"bytes"
"errors"
"fmt"
"io/ioutil"
"math"
"math/rand"
"net/http"
"net/url"
@@ -2581,16 +2584,50 @@ func buildInternalNkeyUser(uc *jwt.UserClaims, acc *Account) *NkeyUser {
return nu
}
// fetchTimeout bounds how long a remote JWT fetch (HTTP or NATS lookup) may take.
const fetchTimeout = 2 * time.Second

// AccountResolver interface. This is to fetch Account JWTs by public nkeys
type AccountResolver interface {
	// Fetch returns the JWT for the account with the given public nkey.
	Fetch(name string) (string, error)
	// Store persists the JWT under the given public nkey.
	Store(name, jwt string) error
	// IsReadOnly reports whether Store is unsupported for this resolver.
	IsReadOnly() bool
	// Start hooks the resolver up to a running server (subscriptions, timers).
	Start(server *Server) error
	// IsTrackingUpdate reports whether the resolver subscribes to JWT update
	// events itself.
	IsTrackingUpdate() bool
	// Reload re-reads the resolver's backing storage.
	Reload() error
	// Close releases any resources held by the resolver.
	Close()
}
// Default implementations of IsReadOnly/Start so only need to be written when changed
type resolverDefaultsOpsImpl struct{}

// IsReadOnly defaults to true; writable resolvers override this.
func (*resolverDefaultsOpsImpl) IsReadOnly() bool {
	return true
}

// IsTrackingUpdate defaults to false; resolvers that subscribe to claim
// update events themselves override this.
func (*resolverDefaultsOpsImpl) IsTrackingUpdate() bool {
	return false
}

// Start is a no-op by default.
func (*resolverDefaultsOpsImpl) Start(*Server) error {
	return nil
}

// Reload is a no-op by default.
func (*resolverDefaultsOpsImpl) Reload() error {
	return nil
}

// Close is a no-op by default.
func (*resolverDefaultsOpsImpl) Close() {
}

// Store fails by default; only writable resolvers support it.
func (*resolverDefaultsOpsImpl) Store(_, _ string) error {
	// Fixed message: this default is embedded by several resolver types, so
	// it must not claim to belong to the URL resolver specifically.
	return fmt.Errorf("store operation not supported")
}
// MemAccResolver is a memory only resolver.
// Mostly for testing.
type MemAccResolver struct {
	sm sync.Map // account public nkey -> JWT string
	resolverDefaultsOpsImpl
}
// Fetch will fetch the account jwt claims from the internal sync.Map.
@@ -2607,10 +2644,15 @@ func (m *MemAccResolver) Store(name, jwt string) error {
return nil
}
// IsReadOnly reports false: the memory resolver supports Store.
// Receiver renamed from ur (copied from URLAccResolver) to m for consistency
// with the other MemAccResolver methods.
func (m *MemAccResolver) IsReadOnly() bool {
	return false
}
// URLAccResolver implements an http fetcher.
type URLAccResolver struct {
	url string       // base URL; the account public key is appended for lookups
	c   *http.Client // client configured with fetchTimeout
	resolverDefaultsOpsImpl
}
// NewURLAccResolver returns a new resolver for the given base URL.
@@ -2626,7 +2668,7 @@ func NewURLAccResolver(url string) (*URLAccResolver, error) {
}
ur := &URLAccResolver{
url: url,
c: &http.Client{Timeout: 2 * time.Second, Transport: tr},
c: &http.Client{Timeout: fetchTimeout, Transport: tr},
}
return ur, nil
}
@@ -2651,7 +2693,277 @@ func (ur *URLAccResolver) Fetch(name string) (string, error) {
return string(body), nil
}
// Store is not implemented for URL Resolver.
func (ur *URLAccResolver) Store(name, jwt string) error {
return fmt.Errorf("Store operation not supported for URL Resolver")
// Resolver based on nats for synchronization and backing directory for storage.
type DirAccResolver struct {
	*DirJWTStore               // JWT storage, one file per account
	syncInterval time.Duration // how often pack requests are sent to peers
}

// IsTrackingUpdate reports that this resolver subscribes to JWT update
// events itself.
func (dr *DirAccResolver) IsTrackingUpdate() bool {
	return true
}
// respondToUpdate logs the outcome of a JWT update and, when a reply subject
// was supplied, sends a success/error response message back to the requestor.
func respondToUpdate(s *Server, respSubj string, acc string, message string, err error) {
	if err != nil {
		s.Errorf("%s - %s - %s", message, acc, err)
	} else {
		s.Debugf("%s - %s", message, acc)
	}
	if respSubj == "" {
		return
	}
	server := &ServerInfo{}
	response := map[string]interface{}{"server": server}
	if err != nil {
		response["error"] = map[string]interface{}{
			"code":        http.StatusInternalServerError,
			"account":     acc,
			"description": fmt.Sprintf("%s - %v", message, err),
		}
	} else {
		response["data"] = map[string]interface{}{
			"code":    http.StatusOK,
			"account": acc,
			"message": message,
		}
	}
	s.sendInternalMsgLocked(respSubj, _EMPTY_, server, response)
}
// Start wires the directory resolver into the server: it registers the
// change callback, subscribes to account claim update/lookup/pack subjects,
// and starts periodic pack synchronization with other full resolvers.
func (dr *DirAccResolver) Start(s *Server) error {
	dr.Lock()
	defer dr.Unlock()
	// When a stored JWT changes, refresh the corresponding in-memory account.
	dr.DirJWTStore.changed = func(pubKey string) {
		if v, ok := s.accounts.Load(pubKey); !ok {
			// account not in use by this server; nothing to refresh
		} else if jwt, err := dr.LoadAcc(pubKey); err != nil {
			s.Errorf("update got error on load: %v", err)
		} else if err := s.updateAccountWithClaimJWT(v.(*Account), jwt); err != nil {
			s.Errorf("update resulted in error %v", err)
		}
	}
	const accountPackRequest = "$SYS.ACCOUNT.CLAIMS.PACK"
	const accountLookupRequest = "$SYS.ACCOUNT.*.CLAIMS.LOOKUP"
	packRespIb := s.newRespInbox()
	// subscribe to account jwt update requests
	if _, err := s.sysSubscribe(fmt.Sprintf(accUpdateEventSubj, "*"), func(_ *subscription, _ *client, subj, resp string, msg []byte) {
		tk := strings.Split(subj, tsep)
		if len(tk) != accUpdateTokens {
			return
		}
		pubKey := tk[accUpdateAccIndex]
		if claim, err := jwt.DecodeAccountClaims(string(msg)); err != nil {
			respondToUpdate(s, resp, pubKey, "jwt update resulted in error", err)
		} else if claim.Subject != pubKey {
			// the jwt must be filed under its own subject
			err := errors.New("subject does not match jwt content")
			respondToUpdate(s, resp, pubKey, "jwt update resulted in error", err)
		} else if err := dr.save(pubKey, string(msg)); err != nil {
			respondToUpdate(s, resp, pubKey, "jwt update resulted in error", err)
		} else {
			respondToUpdate(s, resp, pubKey, "jwt updated", nil)
		}
	}); err != nil {
		return fmt.Errorf("error setting up update handling: %v", err)
	} else if _, err := s.sysSubscribe(accountLookupRequest, func(_ *subscription, _ *client, subj, reply string, msg []byte) {
		// respond to lookups with our version
		if reply == "" {
			return
		}
		tk := strings.Split(subj, tsep)
		if len(tk) != accUpdateTokens {
			return
		}
		if theJWT, err := dr.DirJWTStore.LoadAcc(tk[accUpdateAccIndex]); err != nil {
			// fixed log text: this is the lookup handler, not the merge handler
			s.Errorf("lookup request error: %v", err)
		} else {
			s.sendInternalMsgLocked(reply, "", nil, []byte(theJWT))
		}
	}); err != nil {
		return fmt.Errorf("error setting up lookup request handling: %v", err)
	} else if _, err = s.sysSubscribeQ(accountPackRequest, "responder",
		// respond to pack requests with one or more pack messages
		// an empty message signifies the end of the response responder
		func(_ *subscription, _ *client, _, reply string, theirHash []byte) {
			if reply == "" {
				return
			}
			ourHash := dr.DirJWTStore.Hash()
			if bytes.Equal(theirHash, ourHash[:]) {
				s.sendInternalMsgLocked(reply, "", nil, []byte{})
				s.Debugf("pack request matches hash %x", ourHash[:])
			} else if err := dr.DirJWTStore.PackWalk(1, func(partialPackMsg string) {
				s.sendInternalMsgLocked(reply, "", nil, []byte(partialPackMsg))
			}); err != nil {
				// let them timeout
				s.Errorf("pack request error: %v", err)
			} else {
				// fixed: the format string had two %x verbs but no arguments
				s.Debugf("pack request hash %x - finished responding with hash %x", theirHash, ourHash[:])
				s.sendInternalMsgLocked(reply, "", nil, []byte{})
			}
		}); err != nil {
		return fmt.Errorf("error setting up pack request handling: %v", err)
	} else if _, err = s.sysSubscribe(packRespIb, func(_ *subscription, _ *client, _, _ string, msg []byte) {
		// embed pack responses into store
		hash := dr.DirJWTStore.Hash()
		if len(msg) == 0 { // end of response stream
			s.Debugf("Merging Finished and resulting in: %x", dr.DirJWTStore.Hash())
			return
		} else if err := dr.DirJWTStore.Merge(string(msg)); err != nil {
			s.Errorf("Merging resulted in error: %v", err)
		} else {
			s.Debugf("Merging succeeded and changed %x to %x", hash, dr.DirJWTStore.Hash())
		}
	}); err != nil {
		return fmt.Errorf("error setting up pack response handling: %v", err)
	}
	// periodically send out pack message
	quit := s.quitCh
	s.startGoRoutine(func() {
		defer s.grWG.Done()
		ticker := time.NewTicker(dr.syncInterval)
		for {
			select {
			case <-quit:
				ticker.Stop()
				return
			case <-ticker.C:
			}
			ourHash := dr.DirJWTStore.Hash()
			s.Debugf("Checking store state: %x", ourHash)
			s.sendInternalMsgLocked(accountPackRequest, packRespIb, nil, ourHash[:])
		}
	})
	s.Noticef("Managing all jwt in exclusive directory %s", dr.directory)
	return nil
}
// Fetch returns the JWT for the given account public key from the directory store.
func (dr *DirAccResolver) Fetch(name string) (string, error) {
	return dr.LoadAcc(name)
}

// Store saves the JWT, but only when it is new or newer than the stored one.
func (dr *DirAccResolver) Store(name, jwt string) error {
	return dr.saveIfNewer(name, jwt)
}
// NewDirAccResolver creates a full resolver storing JWTs in path.
// limit <= 0 means unlimited; syncInterval <= 0 defaults to one minute.
func NewDirAccResolver(path string, limit int64, syncInterval time.Duration) (*DirAccResolver, error) {
	if limit <= 0 { // was == 0; treat negative limits as unlimited too, consistent with NewCacheDirAccResolver
		limit = math.MaxInt64
	}
	if syncInterval <= 0 {
		syncInterval = time.Minute
	}
	store, err := NewExpiringDirJWTStore(path, false, true, 0, limit, false, 0, nil)
	if err != nil {
		return nil, err
	}
	return &DirAccResolver{store, syncInterval}, nil
}
// Caching resolver using nats for lookups and making use of a directory for storage
type CacheDirAccResolver struct {
	DirAccResolver
	*Server               // server used for remote lookups; set in Start
	ttl time.Duration     // expiration applied to cached JWT files
}
// Fetch returns the JWT for the given account, first from the local cache
// directory and, on a miss, by asking a full resolver over NATS and caching
// the response.
func (dr *CacheDirAccResolver) Fetch(name string) (string, error) {
	if theJWT, _ := dr.LoadAcc(name); theJWT != "" {
		return theJWT, nil
	}
	// lookup from other server
	s := dr.Server
	if s == nil {
		return "", ErrNoAccountResolver
	}
	respC := make(chan []byte, 1)
	accountLookupRequest := fmt.Sprintf("$SYS.ACCOUNT.%s.CLAIMS.LOOKUP", name)
	s.mu.Lock()
	replySubj := s.newRespInbox()
	if s.sys == nil || s.sys.replies == nil {
		s.mu.Unlock()
		return "", fmt.Errorf("eventing shut down")
	}
	replies := s.sys.replies
	// Store our handler.
	replies[replySubj] = func(sub *subscription, _ *client, subject, _ string, msg []byte) {
		// copy msg: the buffer may be reused by the caller after we return
		clone := make([]byte, len(msg))
		copy(clone, msg)
		s.mu.Lock()
		if _, ok := replies[replySubj]; ok {
			respC <- clone // only send if there is still interest
		}
		s.mu.Unlock()
	}
	s.sendInternalMsg(accountLookupRequest, replySubj, nil, []byte{})
	quit := s.quitCh
	s.mu.Unlock()
	var err error
	var theJWT string
	select {
	case <-quit:
		err = errors.New("fetching jwt failed due to shutdown")
	case <-time.After(fetchTimeout):
		err = errors.New("fetching jwt timed out")
	case m := <-respC:
		// cache the response; only return the JWT when the save succeeded
		if err = dr.Store(name, string(m)); err == nil {
			theJWT = string(m)
		}
	}
	s.mu.Lock()
	// unregister under s.mu; the handler checks membership under the same
	// lock before sending, so no send to the closed channel can occur
	delete(replies, replySubj)
	s.mu.Unlock()
	close(respC)
	return theJWT, err
}
// NewCacheDirAccResolver creates a caching resolver storing looked-up JWTs in
// path as an LRU cache of at most limit entries, each expiring after ttl.
func NewCacheDirAccResolver(path string, limit int64, ttl time.Duration) (*CacheDirAccResolver, error) {
	if limit <= 0 {
		limit = 1_000 // default cache size
	}
	store, err := NewExpiringDirJWTStore(path, false, true, 0, limit, true, ttl, nil)
	if err != nil {
		return nil, err
	}
	res := &CacheDirAccResolver{DirAccResolver{store, 0}, nil, ttl}
	return res, nil
}
// Start wires the cache resolver into the server: it registers the change
// callback and subscribes to claim update events, caching updates only for
// accounts this server already uses (others are fetched on demand).
func (dr *CacheDirAccResolver) Start(s *Server) error {
	dr.Lock()
	defer dr.Unlock()
	dr.Server = s
	// When a cached JWT changes, refresh the corresponding in-memory account.
	dr.DirJWTStore.changed = func(pubKey string) {
		if v, ok := s.accounts.Load(pubKey); !ok {
			// account not in use by this server; nothing to refresh
		} else if jwt, err := dr.LoadAcc(pubKey); err != nil {
			s.Errorf("update got error on load: %v", err)
		} else if err := s.updateAccountWithClaimJWT(v.(*Account), jwt); err != nil {
			s.Errorf("update resulted in error %v", err)
		}
	}
	// subscribe to account jwt update requests
	if _, err := s.sysSubscribe(fmt.Sprintf(accUpdateEventSubj, "*"), func(_ *subscription, _ *client, subj, resp string, msg []byte) {
		tk := strings.Split(subj, tsep)
		if len(tk) != accUpdateTokens {
			return
		}
		pubKey := tk[accUpdateAccIndex]
		if claim, err := jwt.DecodeAccountClaims(string(msg)); err != nil {
			respondToUpdate(s, resp, pubKey, "jwt update cache resulted in error", err)
		} else if claim.Subject != pubKey {
			err := errors.New("subject does not match jwt content")
			respondToUpdate(s, resp, pubKey, "jwt update cache resulted in error", err)
		} else if _, ok := s.accounts.Load(pubKey); !ok {
			// only cache JWTs for accounts already in use by this server
			respondToUpdate(s, resp, pubKey, "jwt update cache skipped", nil)
		} else if err := dr.save(pubKey, string(msg)); err != nil {
			respondToUpdate(s, resp, pubKey, "jwt update cache resulted in error", err)
		} else {
			respondToUpdate(s, resp, pubKey, "jwt updated cache", nil)
		}
	}); err != nil {
		return fmt.Errorf("error setting up update handling: %v", err)
	}
	s.Noticef("Managing some jwt in exclusive directory %s", dr.directory)
	return nil
}
// Reload re-reads the backing directory. Explicit forward to the embedded
// DirAccResolver (which delegates to DirJWTStore.Reload).
func (dr *CacheDirAccResolver) Reload() error {
	return dr.DirAccResolver.Reload()
}

670
server/dirstore.go Normal file
View File

@@ -0,0 +1,670 @@
/*
* Copyright 2020 The NATS Authors
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package server
import (
"bytes"
"container/heap"
"container/list"
"crypto/sha256"
"errors"
"fmt"
"io/ioutil"
"math"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/nats-io/jwt/v2" // only used to decode, not for storage
)
const (
	// fileExtension is the suffix of stored JWT files (<publicKey>.jwt).
	fileExtension = ".jwt"
)
// validatePathExists checks that the provided path exists and is a dir if requested
func validatePathExists(path string, dir bool) (string, error) {
if path == "" {
return "", errors.New("path is not specified")
}
abs, err := filepath.Abs(path)
if err != nil {
return "", fmt.Errorf("error parsing path [%s]: %v", abs, err)
}
var finfo os.FileInfo
if finfo, err = os.Stat(abs); os.IsNotExist(err) {
return "", fmt.Errorf("the path [%s] doesn't exist", abs)
}
mode := finfo.Mode()
if dir && mode.IsRegular() {
return "", fmt.Errorf("the path [%s] is not a directory", abs)
}
if !dir && mode.IsDir() {
return "", fmt.Errorf("the path [%s] is not a file", abs)
}
return abs, nil
}
// validateDirPath checks that the provided path exists and is a directory.
func validateDirPath(path string) (string, error) {
	return validatePathExists(path, true)
}
// JWTChanged functions are called when the store file watcher notices a JWT changed
type JWTChanged func(publicKey string)

// DirJWTStore implements the JWT Store interface, keeping JWTs in an optionally sharded
// directory structure
type DirJWTStore struct {
	sync.Mutex
	directory  string             // absolute path of the exclusive storage directory
	shard      bool               // when true, files live in <dir>/<last 2 key chars>/<key>.jwt
	readonly   bool               // when true, save operations fail
	expiration *expirationTracker // expiry/LRU/hash index; nil when expiration is disabled
	changed    JWTChanged         // callback invoked (outside the lock) when a stored JWT changes
}
// newDir resolves dirPath to a validated absolute directory path, optionally
// creating the directory (and any parents) when it does not exist yet.
func newDir(dirPath string, create bool) (string, error) {
	fullPath, err := validateDirPath(dirPath)
	if err == nil {
		return fullPath, nil
	}
	if !create {
		return "", err
	}
	if err := os.MkdirAll(dirPath, 0755); err != nil {
		return "", err
	}
	return validateDirPath(dirPath)
}
// Creates a directory based jwt store.
// Reads files only, does NOT watch directories and files.
func NewImmutableDirJWTStore(dirPath string, shard bool) (*DirJWTStore, error) {
	store, err := NewDirJWTStore(dirPath, shard, false)
	if err != nil {
		return nil, err
	}
	store.readonly = true
	return store, nil
}

// Creates a directory based jwt store.
// Operates on files only, does NOT watch directories and files.
func NewDirJWTStore(dirPath string, shard bool, create bool) (*DirJWTStore, error) {
	fullPath, err := newDir(dirPath, create)
	if err != nil {
		return nil, err
	}
	return &DirJWTStore{
		directory: fullPath,
		shard:     shard,
	}, nil
}
// Creates a directory based jwt store.
//
// When ttl is set deletion of file is based on it and not on the jwt expiration
// To completely disable expiration (including expiration in jwt) set ttl to max duration time.Duration(math.MaxInt64)
//
// limit defines how many files are allowed at any given time. Set to math.MaxInt64 to disable.
// evictOnLimit determines the behavior once limit is reached.
// true - Evict based on lru strategy
// false - return an error
//
// expireCheck is how often the expiration goroutine wakes up; when <= 0 it
// defaults to ttl/2, capped at one minute.
func NewExpiringDirJWTStore(dirPath string, shard bool, create bool, expireCheck time.Duration, limit int64,
	evictOnLimit bool, ttl time.Duration, changeNotification JWTChanged) (*DirJWTStore, error) {
	fullPath, err := newDir(dirPath, create)
	if err != nil {
		return nil, err
	}
	theStore := &DirJWTStore{
		directory: fullPath,
		shard:     shard,
		changed:   changeNotification,
	}
	if expireCheck <= 0 {
		if ttl != 0 {
			expireCheck = ttl / 2
		}
		if expireCheck == 0 || expireCheck > time.Minute {
			expireCheck = time.Minute
		}
	}
	if limit <= 0 {
		limit = math.MaxInt64 // disable the limit
	}
	theStore.startExpiring(expireCheck, limit, evictOnLimit, ttl)
	theStore.Lock()
	// index all pre-existing JWT files so expiration and hash tracking cover them
	// NOTE(review): this walks dirPath rather than the validated fullPath, and
	// ignores the walk error passed to the callback - confirm both are intended.
	err = filepath.Walk(dirPath, func(path string, info os.FileInfo, err error) error {
		if strings.HasSuffix(path, fileExtension) {
			if theJwt, err := ioutil.ReadFile(path); err == nil {
				hash := sha256.Sum256(theJwt)
				_, file := filepath.Split(path)
				theStore.expiration.track(strings.TrimSuffix(file, fileExtension), &hash, string(theJwt))
			}
		}
		return nil
	})
	theStore.Unlock()
	if err != nil {
		theStore.Close()
		return nil, err
	}
	return theStore, err
}
// IsReadOnly reports whether save operations are disabled for this store.
func (store *DirJWTStore) IsReadOnly() bool {
	return store.readonly
}

// LoadAcc returns the stored JWT for the given account public key.
func (store *DirJWTStore) LoadAcc(publicKey string) (string, error) {
	return store.load(publicKey)
}

// SaveAcc stores the JWT under the given account public key.
func (store *DirJWTStore) SaveAcc(publicKey string, theJWT string) error {
	return store.save(publicKey, theJWT)
}

// LoadAct returns the stored activation JWT for the given hash key.
func (store *DirJWTStore) LoadAct(hash string) (string, error) {
	return store.load(hash)
}

// SaveAct stores the activation JWT under the given hash key.
func (store *DirJWTStore) SaveAct(hash string, theJWT string) error {
	return store.save(hash, theJWT)
}

// Close stops the expiration goroutine, if any. Safe to call more than once.
func (store *DirJWTStore) Close() {
	store.Lock()
	defer store.Unlock()
	if store.expiration != nil {
		store.expiration.close()
		store.expiration = nil
	}
}
// Pack up to maxJWTs into a package: one "<publicKey>|<jwt>" entry per line.
// maxJWTs <= 0 means no limit. Expired JWTs and files not present in the
// expiration index are skipped.
func (store *DirJWTStore) Pack(maxJWTs int) (string, error) {
	count := 0
	var pack []string
	if maxJWTs > 0 {
		pack = make([]string, 0, maxJWTs)
	} else {
		pack = []string{}
	}
	store.Lock()
	err := filepath.Walk(store.directory, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			// guard: on walk errors info may be nil; propagate instead of
			// dereferencing it below
			return err
		}
		if !info.IsDir() && strings.HasSuffix(path, fileExtension) { // this is a JWT
			if count == maxJWTs { // won't match negative
				return nil
			}
			pubKey := strings.TrimSuffix(filepath.Base(path), fileExtension)
			if store.expiration != nil {
				if _, ok := store.expiration.idx[pubKey]; !ok {
					return nil // only include indexed files
				}
			}
			jwtBytes, err := ioutil.ReadFile(path)
			if err != nil {
				return err
			}
			if store.expiration != nil {
				// skip JWTs that expired by their own claim
				claim, err := jwt.DecodeGeneric(string(jwtBytes))
				if err == nil && claim.Expires > 0 && claim.Expires < time.Now().Unix() {
					return nil
				}
			}
			pack = append(pack, fmt.Sprintf("%s|%s", pubKey, string(jwtBytes)))
			count++
		}
		return nil
	})
	store.Unlock()
	if err != nil {
		return "", err
	}
	return strings.Join(pack, "\n"), nil
}
// Pack up to maxJWTs into a message and invoke callback with it
// Entries are "<publicKey>|<jwt>" joined by newlines; expired JWTs and files
// absent from the expiration index are skipped.
func (store *DirJWTStore) PackWalk(maxJWTs int, cb func(partialPackMsg string)) error {
	if maxJWTs <= 0 || cb == nil {
		return errors.New("bad arguments to PackWalk")
	}
	var packMsg []string
	store.Lock()
	dir := store.directory
	// NOTE(review): exp is a snapshot, but the loop below re-reads
	// store.expiration under the lock; if Close runs concurrently that field
	// becomes nil - confirm the lifecycle rules prevent this.
	exp := store.expiration
	store.Unlock()
	err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
		// NOTE(review): the incoming walk err is ignored and info may be nil
		// when lstat fails - confirm this cannot occur for this directory.
		if !info.IsDir() && strings.HasSuffix(path, fileExtension) { // this is a JWT
			pubKey := strings.TrimSuffix(filepath.Base(path), fileExtension)
			store.Lock()
			if exp != nil {
				if _, ok := store.expiration.idx[pubKey]; !ok {
					store.Unlock()
					return nil // only include indexed files
				}
			}
			store.Unlock()
			jwtBytes, err := ioutil.ReadFile(path)
			if err != nil {
				return err
			}
			if exp != nil {
				// skip JWTs that expired by their own claim
				claim, err := jwt.DecodeGeneric(string(jwtBytes))
				if err == nil && claim.Expires > 0 && claim.Expires < time.Now().Unix() {
					return nil
				}
			}
			packMsg = append(packMsg, fmt.Sprintf("%s|%s", pubKey, string(jwtBytes)))
			if len(packMsg) == maxJWTs { // won't match negative
				cb(strings.Join(packMsg, "\n"))
				packMsg = nil
			}
		}
		return nil
	})
	if packMsg != nil {
		// flush the final, partially filled chunk
		cb(strings.Join(packMsg, "\n"))
	}
	return err
}
// Merge takes the JWTs from package and adds them to the store
// Merge is destructive in the sense that it doesn't check if the JWT
// is newer or anything like that.
func (store *DirJWTStore) Merge(pack string) error {
	for _, line := range strings.Split(pack, "\n") {
		if line == "" { // ignore blank lines
			continue
		}
		split := strings.Split(line, "|")
		if len(split) != 2 {
			return fmt.Errorf("line in package didn't contain 2 entries: %q", line)
		}
		if err := store.saveIfNewer(split[0], split[1]); err != nil {
			return err
		}
	}
	return nil
}
// Reload re-indexes the backing directory: it clears the in-memory
// heap/LRU/hash state and re-tracks every JWT file found on disk, invoking
// the change callback for entries whose content differs from what was
// indexed before (for caches, also for entries that were not indexed).
// No-op for read-only stores and stores without expiration tracking.
func (store *DirJWTStore) Reload() error {
	store.Lock()
	exp := store.expiration
	if exp == nil || store.readonly {
		store.Unlock()
		return nil
	}
	idx := exp.idx // keep the old index to detect changed content
	changed := store.changed
	isCache := store.expiration.evictOnLimit
	// clear out indexing data structures
	exp.heap = make([]*jwtItem, 0, len(exp.heap))
	exp.idx = make(map[string]*list.Element)
	exp.lru = list.New()
	exp.hash = [sha256.Size]byte{}
	store.Unlock()
	return filepath.Walk(store.directory, func(path string, info os.FileInfo, err error) error {
		if strings.HasSuffix(path, fileExtension) {
			if theJwt, err := ioutil.ReadFile(path); err == nil {
				hash := sha256.Sum256(theJwt)
				_, file := filepath.Split(path)
				pkey := strings.TrimSuffix(file, fileExtension)
				notify := isCache // for cache, issue cb even when file not present (may have been evicted)
				if i, ok := idx[pkey]; ok {
					notify = !bytes.Equal(i.Value.(*jwtItem).hash[:], hash[:])
				}
				store.Lock()
				exp.track(pkey, &hash, string(theJwt))
				store.Unlock()
				if notify && changed != nil {
					changed(pkey)
				}
			}
		}
		return nil
	})
}
// pathForKey returns the file path for publicKey's JWT, honoring sharding
// (shard dir is the key's last two characters). Returns "" for keys shorter
// than two characters.
func (store *DirJWTStore) pathForKey(publicKey string) string {
	if len(publicKey) < 2 {
		return ""
	}
	fileName := publicKey + fileExtension
	if !store.shard {
		return filepath.Join(store.directory, fileName)
	}
	shardDir := publicKey[len(publicKey)-2:]
	return filepath.Join(store.directory, shardDir, fileName)
}
// Load checks the memory store and returns the matching JWT or an error
// Assumes lock is NOT held
func (store *DirJWTStore) load(publicKey string) (string, error) {
	store.Lock()
	defer store.Unlock()
	if path := store.pathForKey(publicKey); path == "" {
		return "", fmt.Errorf("invalid public key")
	} else if data, err := ioutil.ReadFile(path); err != nil {
		return "", err
	} else {
		if store.expiration != nil {
			// a successful read counts as usage: refresh ttl and LRU position
			store.expiration.updateTrack(publicKey)
		}
		return string(data), nil
	}
}
// write that keeps hash of all jwt in sync
// Assumes the lock is held. Does return true or an error never both.
// Returns false,nil when the new content is identical to the stored content.
// When the store is at its limit and evictOnLimit is set, the least recently
// used entry (file and index) is removed to make room.
func (store *DirJWTStore) write(path string, publicKey string, theJWT string) (bool, error) {
	var newHash *[sha256.Size]byte
	if store.expiration != nil {
		h := sha256.Sum256([]byte(theJWT))
		newHash = &h
		if v, ok := store.expiration.idx[publicKey]; ok {
			store.expiration.updateTrack(publicKey)
			// this write is an update, move to back
			it := v.Value.(*jwtItem)
			oldHash := it.hash[:]
			if bytes.Equal(oldHash, newHash[:]) {
				return false, nil // identical content; skip the file write
			}
		} else if int64(store.expiration.Len()) >= store.expiration.limit {
			if !store.expiration.evictOnLimit {
				return false, errors.New("jwt store is full")
			}
			// this write is an add, pick the least recently used value for removal
			i := store.expiration.lru.Front().Value.(*jwtItem)
			if err := os.Remove(store.pathForKey(i.publicKey)); err != nil {
				return false, err
			} else {
				store.expiration.unTrack(i.publicKey)
			}
		}
	}
	if err := ioutil.WriteFile(path, []byte(theJWT), 0644); err != nil {
		return false, err
	} else if store.expiration != nil {
		store.expiration.track(publicKey, newHash, theJWT)
	}
	return true, nil
}
// Save puts the JWT in a map by public key and performs update callbacks
// Assumes lock is NOT held
func (store *DirJWTStore) save(publicKey string, theJWT string) error {
	if store.readonly {
		return fmt.Errorf("store is read-only")
	}
	store.Lock()
	path := store.pathForKey(publicKey)
	if path == "" {
		store.Unlock()
		return fmt.Errorf("invalid public key")
	}
	// create the (shard) directory on demand
	dirPath := filepath.Dir(path)
	if _, err := validateDirPath(dirPath); err != nil {
		if err := os.MkdirAll(dirPath, 0755); err != nil {
			store.Unlock()
			return err
		}
	}
	changed, err := store.write(path, publicKey, theJWT)
	cb := store.changed
	store.Unlock()
	// invoke the change callback outside the lock
	if changed && cb != nil {
		cb(publicKey)
	}
	return err
}
// Assumes the lock is NOT held, and only updates if the jwt is new, or the one on disk is older
// returns true when the jwt changed
// NOTE(review): the signature only returns error; change notification happens
// via the store's changed callback, not a return value - confirm the comment.
func (store *DirJWTStore) saveIfNewer(publicKey string, theJWT string) error {
	if store.readonly {
		return fmt.Errorf("store is read-only")
	}
	path := store.pathForKey(publicKey)
	if path == "" {
		return fmt.Errorf("invalid public key")
	}
	// create the (shard) directory on demand
	dirPath := filepath.Dir(path)
	if _, err := validateDirPath(dirPath); err != nil {
		if err := os.MkdirAll(dirPath, 0755); err != nil {
			return err
		}
	}
	if _, err := os.Stat(path); err == nil {
		// a file already exists; keep it when it is the same or newer
		if newJWT, err := jwt.DecodeGeneric(theJWT); err != nil {
			// skip if it can't be decoded
		} else if existing, err := ioutil.ReadFile(path); err != nil {
			return err
		} else if existingJWT, err := jwt.DecodeGeneric(string(existing)); err != nil {
			// skip if it can't be decoded
		} else if existingJWT.ID == newJWT.ID {
			return nil // identical claim id, nothing to do
		} else if existingJWT.IssuedAt > newJWT.IssuedAt {
			return nil // stored jwt is newer, keep it
		}
	}
	store.Lock()
	cb := store.changed
	changed, err := store.write(path, publicKey, theJWT)
	store.Unlock()
	if err != nil {
		return err
	} else if changed && cb != nil {
		cb(publicKey)
	}
	return nil
}
func xorAssign(lVal *[sha256.Size]byte, rVal [sha256.Size]byte) {
for i := range rVal {
(*lVal)[i] ^= rVal[i]
}
}
// returns a hash representing all indexed jwt
// (the zero value when expiration tracking is disabled)
func (store *DirJWTStore) Hash() [sha256.Size]byte {
	store.Lock()
	defer store.Unlock()
	if store.expiration == nil {
		return [sha256.Size]byte{}
	} else {
		return store.expiration.hash
	}
}
// An jwtItem is something managed by the priority queue
type jwtItem struct {
index int
publicKey string
expiration int64 // consists of unix time of expiration (ttl when set or jwt expiration) in seconds
hash [sha256.Size]byte
}
// A expirationTracker implements heap.Interface and holds Items.
type expirationTracker struct {
heap []*jwtItem // sorted by jwtItem.expiration
idx map[string]*list.Element
lru *list.List // keep which jwt are least used
limit int64 // limit how many jwt are being tracked
evictOnLimit bool // when limit is hit, error or evict using lru
ttl time.Duration
hash [sha256.Size]byte // xor of all jwtItem.hash in idx
quit chan struct{}
wg sync.WaitGroup
}
func (pq *expirationTracker) Len() int { return len(pq.heap) }
func (q *expirationTracker) Less(i, j int) bool {
pq := q.heap
return pq[i].expiration < pq[j].expiration
}
func (q *expirationTracker) Swap(i, j int) {
pq := q.heap
pq[i], pq[j] = pq[j], pq[i]
pq[i].index = i
pq[j].index = j
}
func (q *expirationTracker) Push(x interface{}) {
n := len(q.heap)
item := x.(*jwtItem)
item.index = n
q.heap = append(q.heap, item)
q.idx[item.publicKey] = q.lru.PushBack(item)
}
func (q *expirationTracker) Pop() interface{} {
old := q.heap
n := len(old)
item := old[n-1]
old[n-1] = nil // avoid memory leak
item.index = -1
q.heap = old[0 : n-1]
q.lru.Remove(q.idx[item.publicKey])
delete(q.idx, item.publicKey)
return item
}
func (pq *expirationTracker) updateTrack(publicKey string) {
if e, ok := pq.idx[publicKey]; ok {
i := e.Value.(*jwtItem)
if pq.ttl != 0 {
// only update expiration when set
i.expiration = time.Now().Add(pq.ttl).Unix()
heap.Fix(pq, i.index)
}
if pq.evictOnLimit {
pq.lru.MoveToBack(e)
}
}
}
// unTrack removes publicKey from the heap, idx and lru, and removes its hash
// contribution from the aggregate hash. No-op when the key is not tracked.
func (pq *expirationTracker) unTrack(publicKey string) {
	if it, ok := pq.idx[publicKey]; ok {
		xorAssign(&pq.hash, it.Value.(*jwtItem).hash)
		heap.Remove(pq, it.Value.(*jwtItem).index)
		delete(pq.idx, publicKey)
	}
}
// track inserts or refreshes the tracking entry for publicKey, computing its
// expiration from the configured ttl (preferred) or from the jwt's own
// expiration claim, and folds hash into the aggregate store hash.
func (pq *expirationTracker) track(publicKey string, hash *[sha256.Size]byte, theJWT string) {
	var exp int64
	// prioritize ttl over expiration
	if pq.ttl != 0 {
		if pq.ttl == time.Duration(math.MaxInt64) {
			exp = math.MaxInt64 // expiration explicitly disabled
		} else {
			exp = time.Now().Add(pq.ttl).Unix()
		}
	} else {
		if g, err := jwt.DecodeGeneric(theJWT); err == nil {
			exp = g.Expires
		}
		if exp == 0 {
			exp = math.MaxInt64 // default to indefinite
		}
	}
	if e, ok := pq.idx[publicKey]; ok {
		// already tracked: refresh the existing item in place
		i := e.Value.(*jwtItem)
		xorAssign(&pq.hash, i.hash) // remove old hash
		i.expiration = exp
		i.hash = *hash
		heap.Fix(pq, i.index)
	} else {
		heap.Push(pq, &jwtItem{-1, publicKey, exp, *hash})
	}
	xorAssign(&pq.hash, *hash) // add in new hash
}
// close signals the expiration goroutine to stop. Safe to call repeatedly.
// NOTE(review): wg is not waited on here, so the goroutine may still be
// finishing its current sweep when close returns - confirm this is intended.
func (pq *expirationTracker) close() {
	if pq == nil || pq.quit == nil {
		return
	}
	close(pq.quit)
	pq.quit = nil
}
// startExpiring initializes expiration/LRU tracking for the store and starts
// a goroutine that, every reCheck, removes the JWT file with the earliest
// expiration once it is due, until stopped via the tracker's close().
func (store *DirJWTStore) startExpiring(reCheck time.Duration, limit int64, evictOnLimit bool, ttl time.Duration) {
	store.Lock()
	defer store.Unlock()
	quit := make(chan struct{})
	pq := &expirationTracker{
		make([]*jwtItem, 0, 10),
		make(map[string]*list.Element),
		list.New(),
		limit,
		evictOnLimit,
		ttl,
		[sha256.Size]byte{},
		quit,
		sync.WaitGroup{},
	}
	store.expiration = pq
	pq.wg.Add(1)
	go func() {
		t := time.NewTicker(reCheck)
		defer t.Stop()
		defer pq.wg.Done()
		for {
			now := time.Now()
			store.Lock()
			if pq.Len() > 0 {
				// peek by popping; the item is pushed back below when it has
				// not expired yet or could not be removed
				if it := heap.Pop(pq).(*jwtItem); it.expiration <= now.Unix() {
					path := store.pathForKey(it.publicKey)
					if err := os.Remove(path); err != nil {
						heap.Push(pq, it) // retry later
					} else {
						pq.unTrack(it.publicKey)
						// Pop already removed the index entry, so the hash
						// contribution is removed here
						xorAssign(&pq.hash, it.hash)
						store.Unlock()
						continue // we removed an entry, check next one
					}
				} else {
					heap.Push(pq, it)
				}
			}
			store.Unlock()
			select {
			case <-t.C:
			case <-quit:
				return
			}
		}
	}()
}

898
server/dirstore_test.go Normal file
View File

@@ -0,0 +1,898 @@
/*
* Copyright 2020 The NATS Authors
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package server
import (
"bytes"
"crypto/sha256"
"fmt"
"io/ioutil"
"math"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/nats-io/jwt/v2"
"github.com/nats-io/nkeys"
)
func require_True(t *testing.T, b bool) {
t.Helper()
if !b {
t.Errorf("require true, but got false")
}
}
func require_False(t *testing.T, b bool) {
t.Helper()
if b {
t.Errorf("require no false, but got true")
}
}
func require_NoError(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Errorf("require no error, but got: %v", err)
}
}
func require_Error(t *testing.T, err error) {
t.Helper()
if err == nil {
t.Errorf("require no error, but got: %v", err)
}
}
func require_Equal(t *testing.T, a, b string) {
t.Helper()
if strings.Compare(a, b) != 0 {
t.Errorf("require equal, but got: %v != %v", a, b)
}
}
func require_NotEqual(t *testing.T, a, b [32]byte) {
t.Helper()
if bytes.Equal(a[:], b[:]) {
t.Errorf("require not equal, but got: %v != %v", a, b)
}
}
func require_Len(t *testing.T, a, b int) {
t.Helper()
if a != b {
t.Errorf("require len, but got: %v != %v", a, b)
}
}
// TestShardedDirStoreWriteAndReadonly verifies save/load round-trips in a
// sharded store and that a readonly re-open of the same directory rejects
// writes but still serves the stored JWTs.
func TestShardedDirStoreWriteAndReadonly(t *testing.T) {
	dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
	require_NoError(t, err)
	defer os.RemoveAll(dir) // clean up the temp dir (was leaked before)
	store, err := NewDirJWTStore(dir, true, false)
	require_NoError(t, err)
	expected := map[string]string{
		"one":   "alpha",
		"two":   "beta",
		"three": "gamma",
		"four":  "delta",
	}
	for k, v := range expected {
		// check the save error instead of silently discarding it
		require_NoError(t, store.SaveAcc(k, v))
	}
	for k, v := range expected {
		got, err := store.LoadAcc(k)
		require_NoError(t, err)
		require_Equal(t, v, got)
	}
	got, err := store.LoadAcc("random")
	require_Error(t, err)
	require_Equal(t, "", got)
	got, err = store.LoadAcc("")
	require_Error(t, err)
	require_Equal(t, "", got)
	err = store.SaveAcc("", "onetwothree")
	require_Error(t, err)
	store.Close()
	// re-use the folder for readonly mode
	store, err = NewImmutableDirJWTStore(dir, true)
	require_NoError(t, err)
	require_True(t, store.IsReadOnly())
	err = store.SaveAcc("five", "omega")
	require_Error(t, err)
	for k, v := range expected {
		got, err := store.LoadAcc(k)
		require_NoError(t, err)
		require_Equal(t, v, got)
	}
	store.Close()
}
// TestUnshardedDirStoreWriteAndReadonly is the unsharded counterpart of
// TestShardedDirStoreWriteAndReadonly: store, read back, re-open readonly.
func TestUnshardedDirStoreWriteAndReadonly(t *testing.T) {
	dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
	require_NoError(t, err)
	defer os.RemoveAll(dir) // added: the temp dir used to be leaked
	store, err := NewDirJWTStore(dir, false, false)
	require_NoError(t, err)
	expected := map[string]string{
		"one":   "alpha",
		"two":   "beta",
		"three": "gamma",
		"four":  "delta",
	}
	require_False(t, store.IsReadOnly())
	for k, v := range expected {
		// previously the save error was silently ignored
		require_NoError(t, store.SaveAcc(k, v))
	}
	for k, v := range expected {
		got, err := store.LoadAcc(k)
		require_NoError(t, err)
		require_Equal(t, v, got)
	}
	// unknown and empty keys must fail to load
	got, err := store.LoadAcc("random")
	require_Error(t, err)
	require_Equal(t, "", got)
	got, err = store.LoadAcc("")
	require_Error(t, err)
	require_Equal(t, "", got)
	// an empty key must be rejected on save
	err = store.SaveAcc("", "onetwothree")
	require_Error(t, err)
	store.Close()
	// re-use the folder for readonly mode
	store, err = NewImmutableDirJWTStore(dir, false)
	require_NoError(t, err)
	require_True(t, store.IsReadOnly())
	err = store.SaveAcc("five", "omega")
	require_Error(t, err)
	for k, v := range expected {
		got, err := store.LoadAcc(k)
		require_NoError(t, err)
		require_Equal(t, v, got)
	}
	store.Close()
}
// TestNoCreateRequiresDir verifies that opening a store on a directory that
// does not exist fails when directory creation is disabled (create=false).
func TestNoCreateRequiresDir(t *testing.T) {
	_, err := NewDirJWTStore("/a/b/c", true, false)
	require_Error(t, err)
}
// TestCreateMakesDir verifies that the store creates its target directory
// when directory creation is enabled (create=true).
func TestCreateMakesDir(t *testing.T) {
	dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
	require_NoError(t, err)
	defer os.RemoveAll(dir) // added: removes the temp dir and the created subdir
	fullPath := filepath.Join(dir, "a/b")
	// the nested path must not exist yet
	_, err = os.Stat(fullPath)
	require_Error(t, err)
	require_True(t, os.IsNotExist(err))
	s, err := NewDirJWTStore(fullPath, false, true)
	require_NoError(t, err)
	s.Close()
	// the store must have created the directory
	_, err = os.Stat(fullPath)
	require_NoError(t, err)
}
// TestShardedDirStorePackMerge packs the content of one sharded store and
// merges it into two other sharded stores: once without a limit, once with
// a pack limit of a single JWT.
// NOTE(review): the three temp directories are never removed; consider
// defer os.RemoveAll for each.
func TestShardedDirStorePackMerge(t *testing.T) {
	dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
	require_NoError(t, err)
	dir2, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
	require_NoError(t, err)
	dir3, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
	require_NoError(t, err)
	store, err := NewDirJWTStore(dir, true, false)
	require_NoError(t, err)
	expected := map[string]string{
		"one":   "alpha",
		"two":   "beta",
		"three": "gamma",
		"four":  "delta",
	}
	require_False(t, store.IsReadOnly())
	for k, v := range expected {
		store.SaveAcc(k, v)
	}
	for k, v := range expected {
		got, err := store.LoadAcc(k)
		require_NoError(t, err)
		require_Equal(t, v, got)
	}
	got, err := store.LoadAcc("random")
	require_Error(t, err)
	require_Equal(t, "", got)
	// pack everything (-1 means no limit) and merge into a second store
	pack, err := store.Pack(-1)
	require_NoError(t, err)
	inc, err := NewDirJWTStore(dir2, true, false)
	require_NoError(t, err)
	inc.Merge(pack)
	for k, v := range expected {
		got, err := inc.LoadAcc(k)
		require_NoError(t, err)
		require_Equal(t, v, got)
	}
	got, err = inc.LoadAcc("random")
	require_Error(t, err)
	require_Equal(t, "", got)
	// pack at most one JWT and merge into a third store
	limitedPack, err := inc.Pack(1)
	require_NoError(t, err)
	limited, err := NewDirJWTStore(dir3, true, false)
	require_NoError(t, err)
	limited.Merge(limitedPack)
	// exactly one of the four JWT must have made it into the limited store
	count := 0
	for k, v := range expected {
		got, err := limited.LoadAcc(k)
		if err == nil {
			count++
			require_Equal(t, v, got)
		}
	}
	require_Len(t, 1, count)
	// NOTE(review): this re-checks inc (already verified above);
	// limited.LoadAcc may have been intended here — confirm.
	got, err = inc.LoadAcc("random")
	require_Error(t, err)
	require_Equal(t, "", got)
}
// TestShardedToUnsharedDirStorePackMerge packs a sharded store and merges it
// into an unsharded one, then exercises Merge error handling for malformed
// pack input.
// NOTE(review): the temp directories are never removed.
func TestShardedToUnsharedDirStorePackMerge(t *testing.T) {
	dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
	require_NoError(t, err)
	dir2, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
	require_NoError(t, err)
	store, err := NewDirJWTStore(dir, true, false)
	require_NoError(t, err)
	expected := map[string]string{
		"one":   "alpha",
		"two":   "beta",
		"three": "gamma",
		"four":  "delta",
	}
	require_False(t, store.IsReadOnly())
	for k, v := range expected {
		store.SaveAcc(k, v)
	}
	for k, v := range expected {
		got, err := store.LoadAcc(k)
		require_NoError(t, err)
		require_Equal(t, v, got)
	}
	got, err := store.LoadAcc("random")
	require_Error(t, err)
	require_Equal(t, "", got)
	// pack the sharded store and merge into the unsharded one
	pack, err := store.Pack(-1)
	require_NoError(t, err)
	inc, err := NewDirJWTStore(dir2, false, false)
	require_NoError(t, err)
	inc.Merge(pack)
	for k, v := range expected {
		got, err := inc.LoadAcc(k)
		require_NoError(t, err)
		require_Equal(t, v, got)
	}
	got, err = inc.LoadAcc("random")
	require_Error(t, err)
	require_Equal(t, "", got)
	// malformed pack entries must be rejected or skipped
	err = store.Merge("foo")
	require_Error(t, err)
	err = store.Merge("") // will skip it
	require_NoError(t, err)
	err = store.Merge("a|something") // should fail on a for sharding
	require_Error(t, err)
}
// TestMergeOnlyOnNewer verifies that saveIfNewer replaces a stored JWT only
// when the incoming JWT is newer than the stored one.
func TestMergeOnlyOnNewer(t *testing.T) {
	dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
	require_NoError(t, err)
	dirStore, err := NewDirJWTStore(dir, true, false)
	require_NoError(t, err)
	accountKey, err := nkeys.CreateAccount()
	require_NoError(t, err)
	pubKey, err := accountKey.PublicKey()
	require_NoError(t, err)
	account := jwt.NewAccountClaims(pubKey)
	account.Name = "old"
	olderJWT, err := account.Encode(accountKey)
	require_NoError(t, err)
	// the sleep ensures the two JWT get different issue times
	// (presumably second resolution — confirm against jwt.Encode)
	time.Sleep(2 * time.Second)
	account.Name = "new"
	newerJWT, err := account.Encode(accountKey)
	require_NoError(t, err)
	// Should work
	err = dirStore.SaveAcc(pubKey, olderJWT)
	require_NoError(t, err)
	fromStore, err := dirStore.LoadAcc(pubKey)
	require_NoError(t, err)
	require_Equal(t, olderJWT, fromStore)
	// should replace
	err = dirStore.saveIfNewer(pubKey, newerJWT)
	require_NoError(t, err)
	fromStore, err = dirStore.LoadAcc(pubKey)
	require_NoError(t, err)
	require_Equal(t, newerJWT, fromStore)
	// saving an older JWT is not an error, but must not overwrite the
	// newer one (the assertions below show newerJWT is still stored)
	err = dirStore.saveIfNewer(pubKey, olderJWT)
	require_NoError(t, err)
	fromStore, err = dirStore.LoadAcc(pubKey)
	require_NoError(t, err)
	require_Equal(t, newerJWT, fromStore)
}
// createTestAccount generates an account claim signed by accKey, optionally
// expiring expSec seconds from now (expSec <= 0 means no expiration), stores
// it in dirStore under the account's public key, and returns the encoded JWT.
func createTestAccount(t *testing.T, dirStore *DirJWTStore, expSec int, accKey nkeys.KeyPair) string {
	t.Helper()
	pubKey, err := accKey.PublicKey()
	require_NoError(t, err)
	account := jwt.NewAccountClaims(pubKey)
	if expSec > 0 {
		account.Expires = time.Now().Add(time.Second * time.Duration(expSec)).Unix()
	}
	// renamed from "jwt" so the local no longer shadows the jwt package
	theJWT, err := account.Encode(accKey)
	require_NoError(t, err)
	err = dirStore.SaveAcc(pubKey, theJWT)
	require_NoError(t, err)
	return theJWT
}
// assertStoreSize checks that the number of files in the store directory and
// the internal expiration tracking structures (index, lru list and heap) all
// agree on the given length.
func assertStoreSize(t *testing.T, dirStore *DirJWTStore, length int) {
	t.Helper()
	f, err := ioutil.ReadDir(dirStore.directory)
	require_NoError(t, err)
	require_Len(t, len(f), length)
	// internal structures are only consistent under the store lock
	dirStore.Lock()
	require_Len(t, len(dirStore.expiration.idx), length)
	require_Len(t, dirStore.expiration.lru.Len(), length)
	require_Len(t, len(dirStore.expiration.heap), length)
	dirStore.Unlock()
}
// TestExpiration creates five accounts expiring 2, 4, ..., 10 seconds from
// now and verifies that every two seconds exactly one more JWT disappears
// from the store, with the store hash changing on each removal.
// NOTE(review): the temp directory is never removed.
func TestExpiration(t *testing.T) {
	dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
	require_NoError(t, err)
	// expiration check runs every 100ms, capacity 10, eviction on limit
	dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 10, true, 0, nil)
	require_NoError(t, err)
	defer dirStore.Close()
	account := func(expSec int) {
		accountKey, err := nkeys.CreateAccount()
		require_NoError(t, err)
		createTestAccount(t, dirStore, expSec, accountKey)
	}
	h := dirStore.Hash()
	// each added account must change the store hash
	for i := 1; i <= 5; i++ {
		account(i * 2)
		nh := dirStore.Hash()
		require_NotEqual(t, h, nh)
		h = nh
	}
	time.Sleep(1 * time.Second)
	// one account expires every 2 seconds; count down from 5 to 1
	for i := 5; i > 0; i-- {
		f, err := ioutil.ReadDir(dir)
		require_NoError(t, err)
		require_Len(t, len(f), i)
		assertStoreSize(t, dirStore, i)
		time.Sleep(2 * time.Second)
		nh := dirStore.Hash()
		require_NotEqual(t, h, nh)
		h = nh
	}
	assertStoreSize(t, dirStore, 0)
}
// TestLimit verifies the store's capacity limit of 5 with eviction enabled:
// repeated updates of one account do not grow the store, adding accounts
// past the limit evicts the oldest, and subsequent updates keep the store
// at capacity.
// NOTE(review): the temp directory is never removed.
func TestLimit(t *testing.T) {
	dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
	require_NoError(t, err)
	// capacity 5, eviction on limit enabled
	dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 5, true, 0, nil)
	require_NoError(t, err)
	defer dirStore.Close()
	account := func(expSec int) {
		accountKey, err := nkeys.CreateAccount()
		require_NoError(t, err)
		createTestAccount(t, dirStore, expSec, accountKey)
	}
	h := dirStore.Hash()
	accountKey, err := nkeys.CreateAccount()
	require_NoError(t, err)
	// update first account repeatedly; an update must not add entries
	for i := 0; i < 10; i++ {
		createTestAccount(t, dirStore, 50, accountKey)
		assertStoreSize(t, dirStore, 1)
	}
	// add 10 new accounts; each must change the hash
	for i := 0; i < 10; i++ {
		account(i)
		nh := dirStore.Hash()
		require_NotEqual(t, h, nh)
		h = nh
	}
	// the first account must have been evicted by now
	key, _ := accountKey.PublicKey()
	_, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, key))
	require_True(t, os.IsNotExist(err))
	// re-adding the first account keeps the store at its capacity of 5
	for i := 0; i < 10; i++ {
		createTestAccount(t, dirStore, 50, accountKey)
		assertStoreSize(t, dirStore, 5)
	}
}
// TestLimitNoEvict verifies that with eviction disabled a store at capacity
// rejects new JWT until an existing entry expires and makes room.
// NOTE(review): the temp directory is never removed; the local variable
// "jwt" below shadows the imported jwt package.
func TestLimitNoEvict(t *testing.T) {
	dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
	require_NoError(t, err)
	// capacity 2, eviction on limit disabled
	dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 2, false, 0, nil)
	require_NoError(t, err)
	defer dirStore.Close()
	accountKey1, err := nkeys.CreateAccount()
	require_NoError(t, err)
	pKey1, err := accountKey1.PublicKey()
	require_NoError(t, err)
	accountKey2, err := nkeys.CreateAccount()
	require_NoError(t, err)
	accountKey3, err := nkeys.CreateAccount()
	require_NoError(t, err)
	pKey3, err := accountKey3.PublicKey()
	require_NoError(t, err)
	createTestAccount(t, dirStore, 100, accountKey1)
	assertStoreSize(t, dirStore, 1)
	// account2 expires after 2 seconds, freeing a slot later in the test
	createTestAccount(t, dirStore, 2, accountKey2)
	assertStoreSize(t, dirStore, 2)
	hBefore := dirStore.Hash()
	// 2 jwt are already stored. third must result in an error
	pubKey, err := accountKey3.PublicKey()
	require_NoError(t, err)
	account := jwt.NewAccountClaims(pubKey)
	jwt, err := account.Encode(accountKey3)
	require_NoError(t, err)
	err = dirStore.SaveAcc(pubKey, jwt)
	require_Error(t, err)
	assertStoreSize(t, dirStore, 2)
	_, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey1))
	require_NoError(t, err)
	_, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey3))
	require_True(t, os.IsNotExist(err))
	// check that the hash did not change
	hAfter := dirStore.Hash()
	require_True(t, bytes.Equal(hBefore[:], hAfter[:]))
	// wait for expiration of account2
	time.Sleep(3 * time.Second)
	// with a slot free, the third JWT must now be accepted
	err = dirStore.SaveAcc(pubKey, jwt)
	require_NoError(t, err)
	assertStoreSize(t, dirStore, 2)
	_, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey1))
	require_NoError(t, err)
	_, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey3))
	require_NoError(t, err)
}
// TestLruLoad verifies that LoadAcc refreshes an entry's position in the
// LRU order, so a recently-read JWT survives the eviction caused by adding
// a third account to a store with capacity 2.
// NOTE(review): the temp directory is never removed.
func TestLruLoad(t *testing.T) {
	dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
	require_NoError(t, err)
	// capacity 2, eviction on limit enabled
	dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 2, true, 0, nil)
	require_NoError(t, err)
	defer dirStore.Close()
	accountKey1, err := nkeys.CreateAccount()
	require_NoError(t, err)
	pKey1, err := accountKey1.PublicKey()
	require_NoError(t, err)
	accountKey2, err := nkeys.CreateAccount()
	require_NoError(t, err)
	accountKey3, err := nkeys.CreateAccount()
	require_NoError(t, err)
	pKey3, err := accountKey3.PublicKey()
	require_NoError(t, err)
	createTestAccount(t, dirStore, 10, accountKey1)
	assertStoreSize(t, dirStore, 1)
	createTestAccount(t, dirStore, 10, accountKey2)
	assertStoreSize(t, dirStore, 2)
	dirStore.LoadAcc(pKey1) // will reorder 1/2
	// adding account3 evicts account2 (now least recently used)
	createTestAccount(t, dirStore, 10, accountKey3)
	assertStoreSize(t, dirStore, 2)
	_, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey1))
	require_NoError(t, err)
	_, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey3))
	require_NoError(t, err)
}
// TestLru exercises LRU eviction in a capacity-2 store: saves reorder
// entries, evicted files disappear from disk, and expired entries free a
// slot without triggering eviction.
// NOTE(review): the temp directory is never removed.
func TestLru(t *testing.T) {
	dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
	require_NoError(t, err)
	// capacity 2, eviction on limit enabled
	dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 2, true, 0, nil)
	require_NoError(t, err)
	defer dirStore.Close()
	accountKey1, err := nkeys.CreateAccount()
	require_NoError(t, err)
	pKey1, err := accountKey1.PublicKey()
	require_NoError(t, err)
	accountKey2, err := nkeys.CreateAccount()
	require_NoError(t, err)
	accountKey3, err := nkeys.CreateAccount()
	require_NoError(t, err)
	pKey3, err := accountKey3.PublicKey()
	require_NoError(t, err)
	createTestAccount(t, dirStore, 10, accountKey1)
	assertStoreSize(t, dirStore, 1)
	createTestAccount(t, dirStore, 10, accountKey2)
	assertStoreSize(t, dirStore, 2)
	// adding account3 evicts account1 (least recently used)
	createTestAccount(t, dirStore, 10, accountKey3)
	assertStoreSize(t, dirStore, 2)
	_, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey1))
	require_True(t, os.IsNotExist(err))
	_, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey3))
	require_NoError(t, err)
	// update -> will change this keys position for eviction
	createTestAccount(t, dirStore, 10, accountKey2)
	assertStoreSize(t, dirStore, 2)
	// recreate -> will evict 3
	createTestAccount(t, dirStore, 1, accountKey1)
	assertStoreSize(t, dirStore, 2)
	_, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey3))
	require_True(t, os.IsNotExist(err))
	// let key1 expire
	time.Sleep(2 * time.Second)
	assertStoreSize(t, dirStore, 1)
	_, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey1))
	require_True(t, os.IsNotExist(err))
	// recreate key3 - no eviction
	createTestAccount(t, dirStore, 10, accountKey3)
	assertStoreSize(t, dirStore, 2)
}
// TestReload writes JWT files directly into the store directory and verifies
// that Reload picks them up (firing the change notification callback and
// updating the store hash), and that removing files and reloading shrinks
// the store back to empty.
// NOTE(review): the temp directory is never removed; the local variable
// "jwt" in newAccount shadows the imported jwt package.
func TestReload(t *testing.T) {
	dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
	require_NoError(t, err)
	notificationChan := make(chan struct{}, 5)
	dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 2, true, 0, func(publicKey string) {
		notificationChan <- struct{}{}
	})
	require_NoError(t, err)
	defer dirStore.Close()
	// newAccount writes a fresh account JWT file directly into dir,
	// bypassing the store, and returns the file path.
	newAccount := func() string {
		t.Helper()
		accKey, err := nkeys.CreateAccount()
		require_NoError(t, err)
		pKey, err := accKey.PublicKey()
		require_NoError(t, err)
		// NOTE(review): pubKey duplicates pKey (same key pair) — one of
		// the two PublicKey() calls could be removed.
		pubKey, err := accKey.PublicKey()
		require_NoError(t, err)
		account := jwt.NewAccountClaims(pubKey)
		jwt, err := account.Encode(accKey)
		require_NoError(t, err)
		file := fmt.Sprintf("%s/%s.jwt", dir, pKey)
		err = ioutil.WriteFile(file, []byte(jwt), 0644)
		require_NoError(t, err)
		return file
	}
	files := make(map[string]struct{})
	assertStoreSize(t, dirStore, 0)
	// an empty store must hash to the all-zero value
	hash := dirStore.Hash()
	emptyHash := [sha256.Size]byte{}
	require_True(t, bytes.Equal(hash[:], emptyHash[:]))
	// add files one by one; each Reload must notify and grow the store
	for i := 0; i < 5; i++ {
		files[newAccount()] = struct{}{}
		err = dirStore.Reload()
		require_NoError(t, err)
		<-notificationChan
		assertStoreSize(t, dirStore, i+1)
		hash = dirStore.Hash()
		require_False(t, bytes.Equal(hash[:], emptyHash[:]))
		msg, err := dirStore.Pack(-1)
		require_NoError(t, err)
		require_Len(t, len(strings.Split(msg, "\n")), len(files))
	}
	// remove files one by one; each Reload must shrink the store
	for k := range files {
		hash = dirStore.Hash()
		require_False(t, bytes.Equal(hash[:], emptyHash[:]))
		os.Remove(k)
		err = dirStore.Reload()
		require_NoError(t, err)
		assertStoreSize(t, dirStore, len(files)-1)
		delete(files, k)
		msg, err := dirStore.Pack(-1)
		require_NoError(t, err)
		if len(files) != 0 { // when len is 0, we have an empty line
			require_Len(t, len(strings.Split(msg, "\n")), len(files))
		}
	}
	// removals must not fire notifications
	require_True(t, len(notificationChan) == 0)
	hash = dirStore.Hash()
	require_True(t, bytes.Equal(hash[:], emptyHash[:]))
}
// TestExpirationUpdate re-saves the same account with different expiration
// values and verifies that each save updates the tracked expiration: a JWT
// without expiration never disappears, while one expiring in 1 second is
// removed after the sleep.
// NOTE(review): the temp directory is never removed.
func TestExpirationUpdate(t *testing.T) {
	dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
	require_NoError(t, err)
	dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 10, true, 0, nil)
	require_NoError(t, err)
	defer dirStore.Close()
	accountKey, err := nkeys.CreateAccount()
	require_NoError(t, err)
	h := dirStore.Hash()
	// no expiration: the JWT must still be present after the sleep
	createTestAccount(t, dirStore, 0, accountKey)
	nh := dirStore.Hash()
	require_NotEqual(t, h, nh)
	h = nh
	time.Sleep(2 * time.Second)
	f, err := ioutil.ReadDir(dir)
	require_NoError(t, err)
	require_Len(t, len(f), 1)
	// expires in 5s: still present after a 2s sleep
	createTestAccount(t, dirStore, 5, accountKey)
	nh = dirStore.Hash()
	require_NotEqual(t, h, nh)
	h = nh
	time.Sleep(2 * time.Second)
	f, err = ioutil.ReadDir(dir)
	require_NoError(t, err)
	require_Len(t, len(f), 1)
	// back to no expiration: still present after the sleep
	createTestAccount(t, dirStore, 0, accountKey)
	nh = dirStore.Hash()
	require_NotEqual(t, h, nh)
	h = nh
	time.Sleep(2 * time.Second)
	f, err = ioutil.ReadDir(dir)
	require_NoError(t, err)
	require_Len(t, len(f), 1)
	// expires in 1s: gone after a 2s sleep
	createTestAccount(t, dirStore, 1, accountKey)
	nh = dirStore.Hash()
	require_NotEqual(t, h, nh)
	h = nh
	time.Sleep(2 * time.Second)
	f, err = ioutil.ReadDir(dir)
	require_NoError(t, err)
	require_Len(t, len(f), 0)
	// the now-empty store must hash to the all-zero value
	empty := [32]byte{}
	h = dirStore.Hash()
	require_Equal(t, string(h[:]), string(empty[:]))
}
// TestTTL verifies the per-file TTL of 2 seconds: Load and Save operations
// keep refreshing the TTL so the JWT survives repeated 1-second waits, and
// once accesses stop the file is removed after the TTL elapses.
// NOTE(review): the temp directory is never removed; the local variable
// "jwt" shadows the imported jwt package.
func TestTTL(t *testing.T) {
	dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
	require_NoError(t, err)
	// require_OneJWT asserts that exactly one JWT file is present
	require_OneJWT := func() {
		t.Helper()
		f, err := ioutil.ReadDir(dir)
		require_NoError(t, err)
		require_Len(t, len(f), 1)
	}
	// ttl of 2 seconds; expiration check runs every 100ms
	dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 10, true, 2*time.Second, nil)
	require_NoError(t, err)
	defer dirStore.Close()
	accountKey, err := nkeys.CreateAccount()
	require_NoError(t, err)
	pubKey, err := accountKey.PublicKey()
	require_NoError(t, err)
	jwt := createTestAccount(t, dirStore, 0, accountKey)
	require_OneJWT()
	// loads refresh the ttl
	for i := 0; i < 6; i++ {
		time.Sleep(time.Second)
		dirStore.LoadAcc(pubKey)
		require_OneJWT()
	}
	// saves refresh the ttl
	for i := 0; i < 6; i++ {
		time.Sleep(time.Second)
		dirStore.SaveAcc(pubKey, jwt)
		require_OneJWT()
	}
	// re-creating the account refreshes the ttl as well
	for i := 0; i < 6; i++ {
		time.Sleep(time.Second)
		createTestAccount(t, dirStore, 0, accountKey)
		require_OneJWT()
	}
	// without further access the file must be gone once the ttl passed
	time.Sleep(3 * time.Second)
	f, err := ioutil.ReadDir(dir)
	require_NoError(t, err)
	require_Len(t, len(f), 0)
}
// infDur is the longest representable duration, used as an effectively
// infinite expiration-check interval to disable periodic expiration.
const infDur = time.Duration(math.MaxInt64)
// TestNotificationOnPack verifies that merging a pack message into a store
// fires the change-notification callback once per JWT, and that packing and
// re-merging through a chain of sharded/unsharded stores preserves content.
// NOTE(review): the temp directories are never removed.
func TestNotificationOnPack(t *testing.T) {
	jwts := map[string]string{
		"key1": "value",
		"key2": "value",
		"key3": "value",
		"key4": "value",
	}
	notificationChan := make(chan struct{}, len(jwts)) // set to same len so all extra will block
	// notification fails the test when called for an unknown key
	notification := func(pubKey string) {
		if _, ok := jwts[pubKey]; !ok {
			t.Errorf("Key not found: %s", pubKey)
		}
		notificationChan <- struct{}{}
	}
	dirPack, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
	require_NoError(t, err)
	packStore, err := NewExpiringDirJWTStore(dirPack, false, false, infDur, 0, true, 0, notification)
	require_NoError(t, err)
	// prefill the store with data
	for k, v := range jwts {
		require_NoError(t, packStore.SaveAcc(k, v))
	}
	// drain one notification per saved JWT
	for i := 0; i < len(jwts); i++ {
		<-notificationChan
	}
	msg, err := packStore.Pack(-1)
	require_NoError(t, err)
	packStore.Close()
	// NOTE(review): Hash/LoadAcc are used on packStore below even though it
	// was just closed — confirm that is valid for this store type.
	hash := packStore.Hash()
	for _, shard := range []bool{true, false, true, false} {
		dirMerge, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
		require_NoError(t, err)
		mergeStore, err := NewExpiringDirJWTStore(dirMerge, shard, false, infDur, 0, true, 0, notification)
		require_NoError(t, err)
		// set
		err = mergeStore.Merge(msg)
		require_NoError(t, err)
		assertStoreSize(t, mergeStore, len(jwts))
		// NOTE(review): this compares packStore's hash with itself;
		// mergeStore.Hash() was likely intended, as written the
		// assertion is vacuous — confirm and fix upstream.
		hash1 := packStore.Hash()
		require_True(t, bytes.Equal(hash[:], hash1[:]))
		// merge must notify once per JWT
		for i := 0; i < len(jwts); i++ {
			<-notificationChan
		}
		// overwrite - assure
		err = mergeStore.Merge(msg)
		require_NoError(t, err)
		assertStoreSize(t, mergeStore, len(jwts))
		// NOTE(review): same concern as above — packStore vs mergeStore.
		hash2 := packStore.Hash()
		require_True(t, bytes.Equal(hash1[:], hash2[:]))
		hash = hash1
		// pack this store for the next loop iteration
		msg, err = mergeStore.Pack(-1)
		require_NoError(t, err)
		mergeStore.Close()
		// merging identical content again must not notify
		require_True(t, len(notificationChan) == 0)
		for k, v := range jwts {
			j, err := packStore.LoadAcc(k)
			require_NoError(t, err)
			require_Equal(t, j, v)
		}
	}
}
// TestNotificationOnPackWalk chains five stores: in each iteration new keys
// are saved into the first store and then propagated down the chain via
// PackWalk (pack in batches of 3) and Merge, after which all neighboring
// stores must have identical hashes.
// NOTE(review): the temp directories are never removed; the map named "jwt"
// shadows the imported jwt package, and the inner hash-check loop shadows
// the outer loop variable i.
func TestNotificationOnPackWalk(t *testing.T) {
	const storeCnt = 5
	const keyCnt = 50
	const iterCnt = 8
	store := [storeCnt]*DirJWTStore{}
	for i := 0; i < storeCnt; i++ {
		dirMerge, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
		require_NoError(t, err)
		mergeStore, err := NewExpiringDirJWTStore(dirMerge, true, false, infDur, 0, true, 0, nil)
		require_NoError(t, err)
		store[i] = mergeStore
	}
	for i := 0; i < iterCnt; i++ { //iterations
		jwt := make(map[string]string)
		for j := 0; j < keyCnt; j++ {
			key := fmt.Sprintf("key%d-%d", i, j)
			value := "value"
			jwt[key] = value
			store[0].SaveAcc(key, value)
		}
		// propagate store[j] -> store[j+1] in pack batches of 3
		for j := 0; j < storeCnt-1; j++ { // stores
			err := store[j].PackWalk(3, func(partialPackMsg string) {
				err := store[j+1].Merge(partialPackMsg)
				require_NoError(t, err)
			})
			require_NoError(t, err)
		}
		// after propagation all neighboring stores must agree
		for i := 0; i < storeCnt-1; i++ {
			h1 := store[i].Hash()
			h2 := store[i+1].Hash()
			require_True(t, bytes.Equal(h1[:], h2[:]))
		}
	}
	for i := 0; i < storeCnt; i++ {
		store[i].Close()
	}
}

View File

@@ -584,9 +584,15 @@ func (s *Server) initEventTracking() {
s.Errorf("Error setting up internal tracking: %v", err)
}
// Listen for account claims updates.
subject = fmt.Sprintf(accUpdateEventSubj, "*")
if _, err := s.sysSubscribe(subject, s.accountClaimUpdate); err != nil {
s.Errorf("Error setting up internal tracking: %v", err)
subscribeToUpdate := true
if s.accResolver != nil {
subscribeToUpdate = !s.accResolver.IsTrackingUpdate()
}
if subscribeToUpdate {
subject = fmt.Sprintf(accUpdateEventSubj, "*")
if _, err := s.sysSubscribe(subject, s.accountClaimUpdate); err != nil {
s.Errorf("Error setting up internal tracking: %v", err)
}
}
// Listen for requests for our statsz.
subject = fmt.Sprintf(serverStatsReqSubj, s.info.ID)
@@ -647,7 +653,6 @@ func (s *Server) initEventTracking() {
if _, err := s.sysSubscribe(subject, s.remoteLatencyUpdate); err != nil {
s.Errorf("Error setting up internal latency tracking: %v", err)
}
// This is for simple debugging of number of subscribers that exist in the system.
if _, err := s.sysSubscribeInternal(accSubsSubj, s.debugSubscribers); err != nil {
s.Errorf("Error setting up internal debug service for subscribers: %v", err)
@@ -1239,17 +1244,22 @@ func (s *Server) sendAuthErrorEvent(c *client) {
// required to be copied.
type msgHandler func(sub *subscription, client *client, subject, reply string, msg []byte)
// Create an internal subscription. No support for queue groups atm.
// Create an internal subscription. sysSubscribeQ for queue groups.
func (s *Server) sysSubscribe(subject string, cb msgHandler) (*subscription, error) {
return s.systemSubscribe(subject, false, cb)
return s.systemSubscribe(subject, "", false, cb)
}
// Create an internal subscription with queue
func (s *Server) sysSubscribeQ(subject, queue string, cb msgHandler) (*subscription, error) {
return s.systemSubscribe(subject, queue, false, cb)
}
// Create an internal subscription but do not forward interest.
func (s *Server) sysSubscribeInternal(subject string, cb msgHandler) (*subscription, error) {
return s.systemSubscribe(subject, true, cb)
return s.systemSubscribe(subject, "", true, cb)
}
func (s *Server) systemSubscribe(subject string, internalOnly bool, cb msgHandler) (*subscription, error) {
func (s *Server) systemSubscribe(subject, queue string, internalOnly bool, cb msgHandler) (*subscription, error) {
if !s.eventsEnabled() {
return nil, ErrNoSysAccount
}
@@ -1263,12 +1273,16 @@ func (s *Server) systemSubscribe(subject string, internalOnly bool, cb msgHandle
sid := strconv.Itoa(s.sys.sid)
s.mu.Unlock()
// Now create the subscription
if trace {
c.traceInOp("SUB", []byte(subject+" "+sid))
c.traceInOp("SUB", []byte(subject+" "+queue+" "+sid))
}
// Now create the subscription
return c.processSub([]byte(subject), nil, []byte(sid), cb, internalOnly)
var q []byte
if queue != "" {
q = []byte(queue)
}
return c.processSub([]byte(subject), q, []byte(sid), cb, internalOnly)
}
func (s *Server) sysUnsubscribe(sub *subscription) {

View File

@@ -23,6 +23,7 @@ import (
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
@@ -2868,3 +2869,319 @@ func TestExpiredUserCredentialsRenewal(t *testing.T) {
t.Fatalf("Expected lastErr to be cleared, got %q", nc.LastError())
}
}
func TestAccountNATSResolverFetch(t *testing.T) {
require_NextMsg := func(sub *nats.Subscription) bool {
msg := natsNexMsg(t, sub, 10*time.Second)
content := make(map[string]interface{})
json.Unmarshal(msg.Data, &content)
if _, ok := content["data"]; ok {
return true
}
return false
}
require_FileAbsent := func(dir string, pub string) {
t.Helper()
_, err := os.Stat(filepath.Join(dir, pub+".jwt"))
require_Error(t, err)
require_True(t, os.IsNotExist(err))
}
require_FilePresent := func(dir string, pub string) {
t.Helper()
_, err := os.Stat(filepath.Join(dir, pub+".jwt"))
require_NoError(t, err)
}
require_FileEqual := func(dir string, pub string, jwt string) {
t.Helper()
content, err := ioutil.ReadFile(filepath.Join(dir, pub+".jwt"))
require_NoError(t, err)
require_Equal(t, string(content), jwt)
}
require_1Connection := func(url string, creds string) {
t.Helper()
c := natsConnect(t, url, nats.UserCredentials(creds))
defer c.Close()
if _, err := nats.Connect(url, nats.UserCredentials(creds)); err == nil {
t.Fatal("Second connection was supposed to fail due to limits")
} else if !strings.Contains(err.Error(), ErrTooManyAccountConnections.Error()) {
t.Fatal("Second connection was supposed to fail with too many conns")
}
}
require_2Connection := func(url string, creds string) {
t.Helper()
c1 := natsConnect(t, url, nats.UserCredentials(creds))
defer c1.Close()
c2 := natsConnect(t, url, nats.UserCredentials(creds))
defer c2.Close()
if _, err := nats.Connect(url, nats.UserCredentials(creds)); err == nil {
t.Fatal("Third connection was supposed to fail due to limits")
} else if !strings.Contains(err.Error(), ErrTooManyAccountConnections.Error()) {
t.Fatal("Third connection was supposed to fail with too many conns")
}
}
writeFile := func(dir string, pub string, jwt string) {
t.Helper()
err := ioutil.WriteFile(filepath.Join(dir, pub+".jwt"), []byte(jwt), 0644)
require_NoError(t, err)
}
createDir := func(prefix string) string {
t.Helper()
dir, err := ioutil.TempDir("", prefix)
require_NoError(t, err)
return dir
}
connect := func(url string, credsfile string) {
t.Helper()
nc := natsConnect(t, url, nats.UserCredentials(credsfile))
nc.Close()
}
createAccountAndUser := func(pair nkeys.KeyPair, limit bool) (string, string, string, string) {
t.Helper()
kp, _ := nkeys.CreateAccount()
pub, _ := kp.PublicKey()
claim := jwt.NewAccountClaims(pub)
if limit {
claim.Limits.Conn = 1
}
jwt1, err := claim.Encode(pair)
require_NoError(t, err)
time.Sleep(2 * time.Second)
// create updated claim allowing more connections
if limit {
claim.Limits.Conn = 2
}
jwt2, err := claim.Encode(pair)
require_NoError(t, err)
ukp, _ := nkeys.CreateUser()
seed, _ := ukp.Seed()
upub, _ := ukp.PublicKey()
uclaim := newJWTTestUserClaims()
uclaim.Subject = upub
ujwt, err := uclaim.Encode(kp)
require_NoError(t, err)
return pub, jwt1, jwt2, genCredsFile(t, ujwt, seed)
}
updateJwt := func(url string, creds string, pubKey string, jwt string) int {
t.Helper()
c := natsConnect(t, url, nats.UserCredentials(creds),
nats.DisconnectErrHandler(func(_ *nats.Conn, err error) {
if err != nil {
t.Fatal("error not expected in this test", err)
}
}),
nats.ErrorHandler(func(_ *nats.Conn, _ *nats.Subscription, err error) {
t.Fatal("error not expected in this test", err)
}),
)
defer c.Close()
resp := c.NewRespInbox()
sub := natsSubSync(t, c, resp)
err := sub.AutoUnsubscribe(3)
require_NoError(t, err)
require_NoError(t, c.PublishRequest(fmt.Sprintf(accUpdateEventSubj, pubKey), resp, []byte(jwt)))
passCnt := 0
if require_NextMsg(sub) {
passCnt++
}
if require_NextMsg(sub) {
passCnt++
}
if require_NextMsg(sub) {
passCnt++
}
return passCnt
}
// Create Operator
op, _ := nkeys.CreateOperator()
opub, _ := op.PublicKey()
oc := jwt.NewOperatorClaims(opub)
oc.Subject = opub
ojwt, err := oc.Encode(op)
require_NoError(t, err)
// Create Accounts and corresponding user creds
syspub, sysjwt, _, sysCreds := createAccountAndUser(op, false)
defer os.Remove(sysCreds)
apub, ajwt1, ajwt2, aCreds := createAccountAndUser(op, true)
defer os.Remove(aCreds)
bpub, bjwt1, bjwt2, bCreds := createAccountAndUser(op, true)
defer os.Remove(bCreds)
cpub, cjwt1, cjwt2, cCreds := createAccountAndUser(op, true)
defer os.Remove(cCreds)
// Create one directory for each server
dirA := createDir("srv-a")
defer os.RemoveAll(dirA)
dirB := createDir("srv-b")
defer os.RemoveAll(dirB)
dirC := createDir("srv-c")
defer os.RemoveAll(dirC)
// simulate a restart of the server by storing files in them
// Server A/B will completely sync, so after startup each server
// will contain the union off all stored/configured jwt
// Server C will send out lookup requests for jwt it does not store itself
writeFile(dirA, apub, ajwt1)
writeFile(dirB, bpub, bjwt1)
writeFile(dirC, cpub, cjwt1)
// Create seed server A (using no_advertise to prevent fail over)
confA := createConfFile(t, []byte(fmt.Sprintf(`
listen: -1
server_name: srv-A
operator: %s
system_account: %s
resolver: {
type: full
dir: %s
interval: "1s"
limit: 4
}
resolver_preload: {
%s: %s
}
cluster {
name: clust
listen: -1
no_advertise: true
}
`, ojwt, syspub, dirA, cpub, cjwt1)))
defer os.Remove(confA)
sA, _ := RunServerWithConfig(confA)
defer sA.Shutdown()
// during startup resolver_preload causes the directory to contain data
require_FilePresent(dirA, cpub)
// Create Server B (using no_advertise to prevent fail over)
confB := createConfFile(t, []byte(fmt.Sprintf(`
listen: -1
server_name: srv-B
operator: %s
system_account: %s
resolver: {
type: full
dir: %s
interval: "1s"
limit: 4
}
cluster {
name: clust
listen: -1
no_advertise: true
routes [
nats-route://localhost:%d
]
}
`, ojwt, syspub, dirB, sA.opts.Cluster.Port)))
defer os.Remove(confB)
sB, _ := RunServerWithConfig(confB)
// Create Server C (using no_advertise to prevent fail over)
confC := createConfFile(t, []byte(fmt.Sprintf(`
listen: -1
server_name: srv-C
operator: %s
system_account: %s
resolver: {
type: cache
dir: %s
ttl: "10s"
limit: 4
}
cluster {
name: clust
listen: -1
no_advertise: true
routes [
nats-route://localhost:%d
]
}
`, ojwt, syspub, dirC, sA.opts.Cluster.Port)))
defer os.Remove(confC)
sC, _ := RunServerWithConfig(confC)
// startup cluster
checkClusterFormed(t, sA, sB, sC)
time.Sleep(2 * time.Second) // wait for the protocol to converge
// Check all accounts
require_FilePresent(dirA, apub) // was already present on startup
require_FilePresent(dirB, apub) // was copied from server A
require_FileAbsent(dirC, apub)
require_FilePresent(dirA, bpub) // was copied from server B
require_FilePresent(dirB, bpub) // was already present on startup
require_FileAbsent(dirC, bpub)
require_FilePresent(dirA, cpub) // was present in preload
require_FilePresent(dirB, cpub) // was copied from server A
require_FilePresent(dirC, cpub) // was already present on startup
// This is to test that connecting to it still works
require_FileAbsent(dirA, syspub)
require_FileAbsent(dirB, syspub)
require_FileAbsent(dirC, syspub)
// system account client can connect to every server
connect(sA.ClientURL(), sysCreds)
connect(sB.ClientURL(), sysCreds)
connect(sC.ClientURL(), sysCreds)
checkClusterFormed(t, sA, sB, sC)
// upload system account and require a response from each server
passCnt := updateJwt(sA.ClientURL(), sysCreds, syspub, sysjwt)
require_True(t, passCnt == 3)
require_FilePresent(dirA, syspub) // was just received
require_FilePresent(dirB, syspub) // was just received
require_FilePresent(dirC, syspub) // was just received
// Only files missing are in C, which is only caching
connect(sC.ClientURL(), aCreds)
connect(sC.ClientURL(), bCreds)
require_FilePresent(dirC, apub) // was looked up form A or B
require_FilePresent(dirC, bpub) // was looked up from A or B
// Check limits and update jwt B connecting to server A
for port, v := range map[string]struct{ pub, jwt, creds string }{
sB.ClientURL(): {bpub, bjwt2, bCreds},
sC.ClientURL(): {cpub, cjwt2, cCreds},
} {
require_1Connection(sA.ClientURL(), v.creds)
require_1Connection(sB.ClientURL(), v.creds)
require_1Connection(sC.ClientURL(), v.creds)
checkClientsCount(t, sA, 0)
checkClientsCount(t, sB, 0)
checkClientsCount(t, sC, 0)
passCnt := updateJwt(port, sysCreds, v.pub, v.jwt)
require_True(t, passCnt == 3)
require_2Connection(sA.ClientURL(), v.creds)
require_2Connection(sB.ClientURL(), v.creds)
require_2Connection(sC.ClientURL(), v.creds)
require_FileEqual(dirA, v.pub, v.jwt)
require_FileEqual(dirB, v.pub, v.jwt)
require_FileEqual(dirC, v.pub, v.jwt)
}
// Simulates A having missed an update
// shutting B down as it has it will directly connect to A and connect right away
sB.Shutdown()
writeFile(dirB, apub, ajwt2) // this will be copied to server A
sB, _ = RunServerWithConfig(confB)
defer sB.Shutdown()
checkClusterFormed(t, sA, sB, sC)
time.Sleep(2 * time.Second) // wait for the protocol to converge, will also assure that C's cache becomes invalid
// Restart server C. this is a workaround to force C to do a lookup in the absence of account cleanup
sC.Shutdown()
sC, _ = RunServerWithConfig(confC) //TODO remove this once we clean up accounts
require_FileEqual(dirA, apub, ajwt2) // was copied from server B
require_FileEqual(dirB, apub, ajwt2) // was restarted with this
require_FileEqual(dirC, apub, ajwt1) // still contains old cached value
require_2Connection(sA.ClientURL(), aCreds)
require_2Connection(sB.ClientURL(), aCreds)
require_1Connection(sC.ClientURL(), aCreds)
// Restart server C. this is a workaround to force C to do a lookup in the absence of account cleanup
sC.Shutdown()
sC, _ = RunServerWithConfig(confC) //TODO remove this once we clean up accounts
defer sC.Shutdown()
require_FileEqual(dirC, apub, ajwt1) // still contains old cached value
time.Sleep(time.Second * 12) // Force next connect to do a lookup
connect(sC.ClientURL(), aCreds) // When lookup happens
require_FileEqual(dirC, apub, ajwt2) // was looked up from A or B
require_2Connection(sC.ClientURL(), aCreds)
// Test exceeding limit. For the exclusive directory resolver, limit is a stop gap measure.
// It is not expected to be hit. When hit the administrator is supposed to take action.
dpub, djwt1, _, dCreds := createAccountAndUser(op, true)
defer os.Remove(dCreds)
passCnt = updateJwt(sA.ClientURL(), sysCreds, dpub, djwt1)
require_True(t, passCnt == 1) // Only Server C updated
}

View File

@@ -816,22 +816,16 @@ func (o *Options) processConfigFileLine(k string, v interface{}, errors *[]error
}
}
case "resolver", "account_resolver", "accounts_resolver":
// "resolver" takes precedence over value obtained from "operator".
// Clear so that parsing errors are not silently ignored.
o.AccountResolver = nil
var memResolverRe = regexp.MustCompile(`(MEM|MEMORY|mem|memory)\s*`)
var resolverRe = regexp.MustCompile(`(?:URL|url){1}(?:\({1}\s*"?([^\s"]*)"?\s*\){1})?\s*`)
str, ok := v.(string)
if !ok {
err := &configErr{tk, fmt.Sprintf("error parsing operator resolver, wrong type %T", v)}
*errors = append(*errors, err)
return
}
if memResolverRe.MatchString(str) {
o.AccountResolver = &MemAccResolver{}
} else {
items := resolverRe.FindStringSubmatch(str)
if len(items) == 2 {
switch v := v.(type) {
case string:
// "resolver" takes precedence over value obtained from "operator".
// Clear so that parsing errors are not silently ignored.
o.AccountResolver = nil
memResolverRe := regexp.MustCompile(`(?i)(MEM|MEMORY)\s*`)
resolverRe := regexp.MustCompile(`(?i)(?:URL){1}(?:\({1}\s*"?([^\s"]*)"?\s*\){1})?\s*`)
if memResolverRe.MatchString(v) {
o.AccountResolver = &MemAccResolver{}
} else if items := resolverRe.FindStringSubmatch(v); len(items) == 2 {
url := items[1]
_, err := parseURL(url, "account resolver")
if err != nil {
@@ -846,9 +840,70 @@ func (o *Options) processConfigFileLine(k string, v interface{}, errors *[]error
o.AccountResolver = ur
}
}
case map[string]interface{}:
dir := ""
dirType := ""
limit := int64(0)
ttl := time.Duration(0)
sync := time.Duration(0)
var err error
if v, ok := v["dir"]; ok {
_, v := unwrapValue(v, &lt)
dir = v.(string)
}
if v, ok := v["type"]; ok {
_, v := unwrapValue(v, &lt)
dirType = v.(string)
}
if v, ok := v["limit"]; ok {
_, v := unwrapValue(v, &lt)
limit = v.(int64)
}
if v, ok := v["ttl"]; ok {
_, v := unwrapValue(v, &lt)
ttl, err = time.ParseDuration(v.(string))
}
if v, ok := v["interval"]; err == nil && ok {
_, v := unwrapValue(v, &lt)
sync, err = time.ParseDuration(v.(string))
}
if err != nil {
*errors = append(*errors, &configErr{tk, err.Error()})
return
}
if dir == "" {
*errors = append(*errors, &configErr{tk, "dir needs to point to a directory"})
return
}
if info, err := os.Stat(dir); err != nil || !info.IsDir() || info.Mode().Perm()&(1<<(uint(7))) == 0 {
info.IsDir()
}
var res AccountResolver
switch strings.ToUpper(dirType) {
case "CACHE":
if sync != 0 {
*errors = append(*errors, &configErr{tk, "CACHE does not accept sync"})
}
res, err = NewCacheDirAccResolver(dir, limit, ttl)
case "FULL":
if ttl != 0 {
*errors = append(*errors, &configErr{tk, "FULL does not accept ttl"})
}
res, err = NewDirAccResolver(dir, limit, sync)
}
if err != nil {
*errors = append(*errors, &configErr{tk, err.Error()})
return
}
o.AccountResolver = res
default:
err := &configErr{tk, fmt.Sprintf("error parsing operator resolver, wrong type %T", v)}
*errors = append(*errors, err)
return
}
if o.AccountResolver == nil {
err := &configErr{tk, "error parsing account resolver, should be MEM or URL(\"url\")"}
err := &configErr{tk, "error parsing account resolver, should be MEM or " +
" URL(\"url\") or a map containing dir and type state=[FULL|CACHE])"}
*errors = append(*errors, err)
}
case "resolver_tls":

View File

@@ -770,7 +770,7 @@ func imposeOrder(value interface{}) error {
return value.AllowedOrigins[i] < value.AllowedOrigins[j]
})
case string, bool, int, int32, int64, time.Duration, float64, nil,
LeafNodeOpts, ClusterOpts, *tls.Config, *URLAccResolver, *MemAccResolver, Authentication:
LeafNodeOpts, ClusterOpts, *tls.Config, *URLAccResolver, *MemAccResolver, *DirAccResolver, *CacheDirAccResolver, Authentication:
// explicitly skipped types
default:
// this will fail during unit tests
@@ -1275,6 +1275,10 @@ func (s *Server) reloadAuthorization() {
if checkJetStream {
s.configAllJetStreamAccounts()
}
if res := s.AccountResolver(); res != nil {
res.Reload()
}
}
// Returns true if given client current account has changed (or user

View File

@@ -360,14 +360,11 @@ func NewServer(opts *Options) (*Server, error) {
// waiting for complete shutdown.
s.shutdownComplete = make(chan struct{})
// For tracking accounts
if err := s.configureAccounts(); err != nil {
// Check for configured account resolvers.
if err := s.configureResolver(); err != nil {
return nil, err
}
// If there is an URL account resolver, do basic test to see if anyone is home.
// Do this after configureAccounts() which calls configureResolver(), which will
// set TLSConfig if specified.
if ar := opts.AccountResolver; ar != nil {
if ur, ok := ar.(*URLAccResolver); ok {
if _, err := ur.Fetch(""); err != nil {
@@ -375,6 +372,31 @@ func NewServer(opts *Options) (*Server, error) {
}
}
}
// For other resolver:
// In operator mode, when the account resolver depends on an external system and
// the system account can't be fetched, inject a temporary one.
if ar := s.accResolver; len(opts.TrustedOperators) == 1 && ar != nil &&
opts.SystemAccount != _EMPTY_ && opts.SystemAccount != DEFAULT_SYSTEM_ACCOUNT {
if _, ok := ar.(*MemAccResolver); !ok {
s.mu.Unlock()
var a *Account
// perform direct lookup to avoid warning trace
if _, err := ar.Fetch(s.opts.SystemAccount); err == nil {
a, _ = s.fetchAccount(s.opts.SystemAccount)
}
s.mu.Lock()
if a == nil {
sac := NewAccount(s.opts.SystemAccount)
sac.Issuer = opts.TrustedOperators[0].Issuer
s.registerAccountNoLock(sac)
}
}
}
// For tracking accounts
if err := s.configureAccounts(); err != nil {
return nil, err
}
// In local config mode, check that leafnode configuration
// refers to account that exist.
@@ -610,11 +632,6 @@ func (s *Server) configureAccounts() error {
return true
})
// Check for configured account resolvers.
if err := s.configureResolver(); err != nil {
return err
}
// Set the system account if it was configured.
// Otherwise create a default one.
if opts.SystemAccount != _EMPTY_ {
@@ -654,8 +671,8 @@ func (s *Server) configureResolver() error {
}
}
if len(opts.resolverPreloads) > 0 {
if _, ok := s.accResolver.(*MemAccResolver); !ok {
return fmt.Errorf("resolver preloads only available for resolver type MEM")
if s.accResolver.IsReadOnly() {
return fmt.Errorf("resolver preloads only available for writeable resolver types MEM/DIR/CACHE_DIR")
}
for k, v := range opts.resolverPreloads {
_, err := jwt.DecodeAccountClaims(v)
@@ -1346,7 +1363,8 @@ func (s *Server) Start() {
// Log the pid to a file
if opts.PidFile != _EMPTY_ {
if err := s.logPid(); err != nil {
PrintAndDie(fmt.Sprintf("Could not write pidfile: %v\n", err))
s.Fatalf("Could not write pidfile: %v", err)
return
}
}
@@ -1361,6 +1379,42 @@ func (s *Server) Start() {
s.SetDefaultSystemAccount()
}
// start up resolver machinery
if ar := s.AccountResolver(); ar != nil {
if err := ar.Start(s); err != nil {
s.Fatalf("Could not start resolver: %v", err)
return
}
// In operator mode, when the account resolver depends on an external system and
// the system account is the bootstrapping account, start fetching it
if len(opts.TrustedOperators) == 1 && opts.SystemAccount != _EMPTY_ && opts.SystemAccount != DEFAULT_SYSTEM_ACCOUNT {
_, isMemResolver := ar.(*MemAccResolver)
if v, ok := s.accounts.Load(s.opts.SystemAccount); !isMemResolver && ok && v.(*Account).claimJWT == "" {
s.Noticef("Using bootstrapping system account")
s.startGoRoutine(func() {
defer s.grWG.Done()
t := time.NewTicker(time.Second)
defer t.Stop()
for {
select {
case <-s.quitCh:
return
case <-t.C:
if _, err := ar.Fetch(s.opts.SystemAccount); err != nil {
continue
}
if _, err := s.fetchAccount(s.opts.SystemAccount); err != nil {
continue
}
s.Noticef("System account fetched and updated")
return
}
}
})
}
}
}
// Start expiration of mapped GW replies, regardless if
// this server is configured with gateway or not.
s.startGWReplyMapExpiration()
@@ -1483,6 +1537,10 @@ func (s *Server) Shutdown() {
}
s.Noticef("Initiating Shutdown...")
if s.accResolver != nil {
s.accResolver.Close()
}
opts := s.getOpts()
s.shutdown = true