mirror of
https://github.com/gogrlx/nats-server.git
synced 2026-04-02 11:48:43 -07:00
Nats based resolver & avoiding nats account server in smaller deployments (#1550)
* Adding nats based resolver and bootstrap system account
These resolver operate on an exclusive directory
Two types:
full: managing all jwt in the directory
Will synchronize with other full resolver
nats-account-server will also run such a resolver
cache: lru cache managing only a subset of all jwt in the directory
Will lookup jwt from full resolver
Can overwrite expiration with a ttl for the file
Both:
track expiration of jwt and clean up
Support reload
Notify the server of changed jwt
Bootstrapping system account allows users signed with the system account
jwt to connect, without the server knowing the jwt.
This allows uploading jwt (including system account) using nats by
publishing to $SYS.ACCOUNT.<name>.CLAIMS.UPDATE
Sending a request, server will respond with the result of the operation.
Receive all jwt stored in one server by sending a
request to $SYS.ACCOUNT.CLAIMS.PACK
One server will respond with a message per stored jwt.
The end of the responses is indicated by an empty message.
The content of dirstore.go and dirstore_test.go was moved from
nats-account-server
Signed-off-by: Matthias Hanel <mh@synadia.com>
This commit is contained in:
@@ -14,8 +14,11 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -2581,16 +2584,50 @@ func buildInternalNkeyUser(uc *jwt.UserClaims, acc *Account) *NkeyUser {
|
||||
return nu
|
||||
}
|
||||
|
||||
const fetchTimeout = 2 * time.Second
|
||||
|
||||
// AccountResolver interface. This is to fetch Account JWTs by public nkeys
|
||||
type AccountResolver interface {
|
||||
Fetch(name string) (string, error)
|
||||
Store(name, jwt string) error
|
||||
IsReadOnly() bool
|
||||
Start(server *Server) error
|
||||
IsTrackingUpdate() bool
|
||||
Reload() error
|
||||
Close()
|
||||
}
|
||||
|
||||
// Default implementations of IsReadOnly/Start so only need to be written when changed
|
||||
type resolverDefaultsOpsImpl struct{}
|
||||
|
||||
func (*resolverDefaultsOpsImpl) IsReadOnly() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (*resolverDefaultsOpsImpl) IsTrackingUpdate() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (*resolverDefaultsOpsImpl) Start(*Server) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*resolverDefaultsOpsImpl) Reload() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*resolverDefaultsOpsImpl) Close() {
|
||||
}
|
||||
|
||||
func (*resolverDefaultsOpsImpl) Store(_, _ string) error {
|
||||
return fmt.Errorf("Store operation not supported for URL Resolver")
|
||||
}
|
||||
|
||||
// MemAccResolver is a memory only resolver.
|
||||
// Mostly for testing.
|
||||
type MemAccResolver struct {
|
||||
sm sync.Map
|
||||
resolverDefaultsOpsImpl
|
||||
}
|
||||
|
||||
// Fetch will fetch the account jwt claims from the internal sync.Map.
|
||||
@@ -2607,10 +2644,15 @@ func (m *MemAccResolver) Store(name, jwt string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ur *MemAccResolver) IsReadOnly() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// URLAccResolver implements an http fetcher.
|
||||
type URLAccResolver struct {
|
||||
url string
|
||||
c *http.Client
|
||||
resolverDefaultsOpsImpl
|
||||
}
|
||||
|
||||
// NewURLAccResolver returns a new resolver for the given base URL.
|
||||
@@ -2626,7 +2668,7 @@ func NewURLAccResolver(url string) (*URLAccResolver, error) {
|
||||
}
|
||||
ur := &URLAccResolver{
|
||||
url: url,
|
||||
c: &http.Client{Timeout: 2 * time.Second, Transport: tr},
|
||||
c: &http.Client{Timeout: fetchTimeout, Transport: tr},
|
||||
}
|
||||
return ur, nil
|
||||
}
|
||||
@@ -2651,7 +2693,277 @@ func (ur *URLAccResolver) Fetch(name string) (string, error) {
|
||||
return string(body), nil
|
||||
}
|
||||
|
||||
// Store is not implemented for URL Resolver.
|
||||
func (ur *URLAccResolver) Store(name, jwt string) error {
|
||||
return fmt.Errorf("Store operation not supported for URL Resolver")
|
||||
// Resolver based on nats for synchronization and backing directory for storage.
|
||||
type DirAccResolver struct {
|
||||
*DirJWTStore
|
||||
syncInterval time.Duration
|
||||
}
|
||||
|
||||
func (dr *DirAccResolver) IsTrackingUpdate() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func respondToUpdate(s *Server, respSubj string, acc string, message string, err error) {
|
||||
if err == nil {
|
||||
s.Debugf("%s - %s", message, acc)
|
||||
} else {
|
||||
s.Errorf("%s - %s - %s", message, acc, err)
|
||||
}
|
||||
if respSubj == "" {
|
||||
return
|
||||
}
|
||||
server := &ServerInfo{}
|
||||
response := map[string]interface{}{"server": server}
|
||||
if err == nil {
|
||||
response["data"] = map[string]interface{}{
|
||||
"code": http.StatusOK,
|
||||
"account": acc,
|
||||
"message": message,
|
||||
}
|
||||
} else {
|
||||
response["error"] = map[string]interface{}{
|
||||
"code": http.StatusInternalServerError,
|
||||
"account": acc,
|
||||
"description": fmt.Sprintf("%s - %v", message, err),
|
||||
}
|
||||
}
|
||||
s.sendInternalMsgLocked(respSubj, _EMPTY_, server, response)
|
||||
}
|
||||
|
||||
func (dr *DirAccResolver) Start(s *Server) error {
|
||||
dr.Lock()
|
||||
defer dr.Unlock()
|
||||
dr.DirJWTStore.changed = func(pubKey string) {
|
||||
if v, ok := s.accounts.Load(pubKey); !ok {
|
||||
} else if jwt, err := dr.LoadAcc(pubKey); err != nil {
|
||||
s.Errorf("update got error on load: %v", err)
|
||||
} else if err := s.updateAccountWithClaimJWT(v.(*Account), jwt); err != nil {
|
||||
s.Errorf("update resulted in error %v", err)
|
||||
}
|
||||
}
|
||||
const accountPackRequest = "$SYS.ACCOUNT.CLAIMS.PACK"
|
||||
const accountLookupRequest = "$SYS.ACCOUNT.*.CLAIMS.LOOKUP"
|
||||
packRespIb := s.newRespInbox()
|
||||
// subscribe to account jwt update requests
|
||||
if _, err := s.sysSubscribe(fmt.Sprintf(accUpdateEventSubj, "*"), func(_ *subscription, _ *client, subj, resp string, msg []byte) {
|
||||
tk := strings.Split(subj, tsep)
|
||||
if len(tk) != accUpdateTokens {
|
||||
return
|
||||
}
|
||||
pubKey := tk[accUpdateAccIndex]
|
||||
if claim, err := jwt.DecodeAccountClaims(string(msg)); err != nil {
|
||||
respondToUpdate(s, resp, pubKey, "jwt update resulted in error", err)
|
||||
} else if claim.Subject != pubKey {
|
||||
err := errors.New("subject does not match jwt content")
|
||||
respondToUpdate(s, resp, pubKey, "jwt update resulted in error", err)
|
||||
} else if err := dr.save(pubKey, string(msg)); err != nil {
|
||||
respondToUpdate(s, resp, pubKey, "jwt update resulted in error", err)
|
||||
} else {
|
||||
respondToUpdate(s, resp, pubKey, "jwt updated", nil)
|
||||
}
|
||||
}); err != nil {
|
||||
return fmt.Errorf("error setting up update handling: %v", err)
|
||||
} else if _, err := s.sysSubscribe(accountLookupRequest, func(_ *subscription, _ *client, subj, reply string, msg []byte) {
|
||||
// respond to lookups with our version
|
||||
if reply == "" {
|
||||
return
|
||||
}
|
||||
tk := strings.Split(subj, tsep)
|
||||
if len(tk) != accUpdateTokens {
|
||||
return
|
||||
}
|
||||
if theJWT, err := dr.DirJWTStore.LoadAcc(tk[accUpdateAccIndex]); err != nil {
|
||||
s.Errorf("Merging resulted in error: %v", err)
|
||||
} else {
|
||||
s.sendInternalMsgLocked(reply, "", nil, []byte(theJWT))
|
||||
}
|
||||
}); err != nil {
|
||||
return fmt.Errorf("error setting up lookup request handling: %v", err)
|
||||
} else if _, err = s.sysSubscribeQ(accountPackRequest, "responder",
|
||||
// respond to pack requests with one or more pack messages
|
||||
// an empty message signifies the end of the response responder
|
||||
func(_ *subscription, _ *client, _, reply string, theirHash []byte) {
|
||||
if reply == "" {
|
||||
return
|
||||
}
|
||||
ourHash := dr.DirJWTStore.Hash()
|
||||
if bytes.Equal(theirHash, ourHash[:]) {
|
||||
s.sendInternalMsgLocked(reply, "", nil, []byte{})
|
||||
s.Debugf("pack request matches hash %x", ourHash[:])
|
||||
} else if err := dr.DirJWTStore.PackWalk(1, func(partialPackMsg string) {
|
||||
s.sendInternalMsgLocked(reply, "", nil, []byte(partialPackMsg))
|
||||
}); err != nil {
|
||||
// let them timeout
|
||||
s.Errorf("pack request error: %v", err)
|
||||
} else {
|
||||
s.Debugf("pack request hash %x - finished responding with hash %x")
|
||||
s.sendInternalMsgLocked(reply, "", nil, []byte{})
|
||||
}
|
||||
}); err != nil {
|
||||
return fmt.Errorf("error setting up pack request handling: %v", err)
|
||||
} else if _, err = s.sysSubscribe(packRespIb, func(_ *subscription, _ *client, _, _ string, msg []byte) {
|
||||
// embed pack responses into store
|
||||
hash := dr.DirJWTStore.Hash()
|
||||
if len(msg) == 0 { // end of response stream
|
||||
s.Debugf("Merging Finished and resulting in: %x", dr.DirJWTStore.Hash())
|
||||
return
|
||||
} else if err := dr.DirJWTStore.Merge(string(msg)); err != nil {
|
||||
s.Errorf("Merging resulted in error: %v", err)
|
||||
} else {
|
||||
s.Debugf("Merging succeeded and changed %x to %x", hash, dr.DirJWTStore.Hash())
|
||||
}
|
||||
}); err != nil {
|
||||
return fmt.Errorf("error setting up pack response handling: %v", err)
|
||||
}
|
||||
// periodically send out pack message
|
||||
quit := s.quitCh
|
||||
s.startGoRoutine(func() {
|
||||
defer s.grWG.Done()
|
||||
ticker := time.NewTicker(dr.syncInterval)
|
||||
for {
|
||||
select {
|
||||
case <-quit:
|
||||
ticker.Stop()
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
ourHash := dr.DirJWTStore.Hash()
|
||||
s.Debugf("Checking store state: %x", ourHash)
|
||||
s.sendInternalMsgLocked(accountPackRequest, packRespIb, nil, ourHash[:])
|
||||
}
|
||||
})
|
||||
s.Noticef("Managing all jwt in exclusive directory %s", dr.directory)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dr *DirAccResolver) Fetch(name string) (string, error) {
|
||||
return dr.LoadAcc(name)
|
||||
}
|
||||
|
||||
func (dr *DirAccResolver) Store(name, jwt string) error {
|
||||
return dr.saveIfNewer(name, jwt)
|
||||
}
|
||||
|
||||
func NewDirAccResolver(path string, limit int64, syncInterval time.Duration) (*DirAccResolver, error) {
|
||||
if limit == 0 {
|
||||
limit = math.MaxInt64
|
||||
}
|
||||
if syncInterval <= 0 {
|
||||
syncInterval = time.Minute
|
||||
}
|
||||
store, err := NewExpiringDirJWTStore(path, false, true, 0, limit, false, 0, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &DirAccResolver{store, syncInterval}, nil
|
||||
}
|
||||
|
||||
// Caching resolver using nats for lookups and making use of a directory for storage
|
||||
type CacheDirAccResolver struct {
|
||||
DirAccResolver
|
||||
*Server
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
func (dr *CacheDirAccResolver) Fetch(name string) (string, error) {
|
||||
if theJWT, _ := dr.LoadAcc(name); theJWT != "" {
|
||||
return theJWT, nil
|
||||
}
|
||||
// lookup from other server
|
||||
s := dr.Server
|
||||
if s == nil {
|
||||
return "", ErrNoAccountResolver
|
||||
}
|
||||
respC := make(chan []byte, 1)
|
||||
accountLookupRequest := fmt.Sprintf("$SYS.ACCOUNT.%s.CLAIMS.LOOKUP", name)
|
||||
s.mu.Lock()
|
||||
replySubj := s.newRespInbox()
|
||||
if s.sys == nil || s.sys.replies == nil {
|
||||
s.mu.Unlock()
|
||||
return "", fmt.Errorf("eventing shut down")
|
||||
}
|
||||
replies := s.sys.replies
|
||||
// Store our handler.
|
||||
replies[replySubj] = func(sub *subscription, _ *client, subject, _ string, msg []byte) {
|
||||
clone := make([]byte, len(msg))
|
||||
copy(clone, msg)
|
||||
s.mu.Lock()
|
||||
if _, ok := replies[replySubj]; ok {
|
||||
respC <- clone // only send if there is still interest
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
s.sendInternalMsg(accountLookupRequest, replySubj, nil, []byte{})
|
||||
quit := s.quitCh
|
||||
s.mu.Unlock()
|
||||
var err error
|
||||
var theJWT string
|
||||
select {
|
||||
case <-quit:
|
||||
err = errors.New("fetching jwt failed due to shutdown")
|
||||
case <-time.After(fetchTimeout):
|
||||
err = errors.New("fetching jwt timed out")
|
||||
case m := <-respC:
|
||||
if err = dr.Store(name, string(m)); err == nil {
|
||||
theJWT = string(m)
|
||||
}
|
||||
}
|
||||
s.mu.Lock()
|
||||
delete(replies, replySubj)
|
||||
s.mu.Unlock()
|
||||
close(respC)
|
||||
return theJWT, err
|
||||
}
|
||||
|
||||
func NewCacheDirAccResolver(path string, limit int64, ttl time.Duration) (*CacheDirAccResolver, error) {
|
||||
if limit <= 0 {
|
||||
limit = 1_000
|
||||
}
|
||||
store, err := NewExpiringDirJWTStore(path, false, true, 0, limit, true, ttl, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &CacheDirAccResolver{DirAccResolver{store, 0}, nil, ttl}, nil
|
||||
}
|
||||
|
||||
func (dr *CacheDirAccResolver) Start(s *Server) error {
|
||||
dr.Lock()
|
||||
defer dr.Unlock()
|
||||
dr.Server = s
|
||||
dr.DirJWTStore.changed = func(pubKey string) {
|
||||
if v, ok := s.accounts.Load(pubKey); !ok {
|
||||
} else if jwt, err := dr.LoadAcc(pubKey); err != nil {
|
||||
s.Errorf("update got error on load: %v", err)
|
||||
} else if err := s.updateAccountWithClaimJWT(v.(*Account), jwt); err != nil {
|
||||
s.Errorf("update resulted in error %v", err)
|
||||
}
|
||||
}
|
||||
// subscribe to account jwt update requests
|
||||
if _, err := s.sysSubscribe(fmt.Sprintf(accUpdateEventSubj, "*"), func(_ *subscription, _ *client, subj, resp string, msg []byte) {
|
||||
tk := strings.Split(subj, tsep)
|
||||
if len(tk) != accUpdateTokens {
|
||||
return
|
||||
}
|
||||
pubKey := tk[accUpdateAccIndex]
|
||||
if claim, err := jwt.DecodeAccountClaims(string(msg)); err != nil {
|
||||
respondToUpdate(s, resp, pubKey, "jwt update cache resulted in error", err)
|
||||
} else if claim.Subject != pubKey {
|
||||
err := errors.New("subject does not match jwt content")
|
||||
respondToUpdate(s, resp, pubKey, "jwt update cache resulted in error", err)
|
||||
} else if _, ok := s.accounts.Load(pubKey); !ok {
|
||||
respondToUpdate(s, resp, pubKey, "jwt update cache skipped", nil)
|
||||
} else if err := dr.save(pubKey, string(msg)); err != nil {
|
||||
respondToUpdate(s, resp, pubKey, "jwt update cache resulted in error", err)
|
||||
} else {
|
||||
respondToUpdate(s, resp, pubKey, "jwt updated cache", nil)
|
||||
}
|
||||
}); err != nil {
|
||||
return fmt.Errorf("error setting up update handling: %v", err)
|
||||
}
|
||||
s.Noticef("Managing some jwt in exclusive directory %s", dr.directory)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dr *CacheDirAccResolver) Reload() error {
|
||||
return dr.DirAccResolver.Reload()
|
||||
}
|
||||
|
||||
670
server/dirstore.go
Normal file
670
server/dirstore.go
Normal file
@@ -0,0 +1,670 @@
|
||||
/*
|
||||
* Copyright 2020 The NATS Authors
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"container/heap"
|
||||
"container/list"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/jwt/v2" // only used to decode, not for storage
|
||||
)
|
||||
|
||||
const (
|
||||
fileExtension = ".jwt"
|
||||
)
|
||||
|
||||
// validatePathExists checks that the provided path exists and is a dir if requested
|
||||
func validatePathExists(path string, dir bool) (string, error) {
|
||||
if path == "" {
|
||||
return "", errors.New("path is not specified")
|
||||
}
|
||||
|
||||
abs, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error parsing path [%s]: %v", abs, err)
|
||||
}
|
||||
|
||||
var finfo os.FileInfo
|
||||
if finfo, err = os.Stat(abs); os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("the path [%s] doesn't exist", abs)
|
||||
}
|
||||
|
||||
mode := finfo.Mode()
|
||||
if dir && mode.IsRegular() {
|
||||
return "", fmt.Errorf("the path [%s] is not a directory", abs)
|
||||
}
|
||||
|
||||
if !dir && mode.IsDir() {
|
||||
return "", fmt.Errorf("the path [%s] is not a file", abs)
|
||||
}
|
||||
|
||||
return abs, nil
|
||||
}
|
||||
|
||||
// ValidateDirPath checks that the provided path exists and is a dir
|
||||
func validateDirPath(path string) (string, error) {
|
||||
return validatePathExists(path, true)
|
||||
}
|
||||
|
||||
// JWTChanged functions are called when the store file watcher notices a JWT changed
|
||||
type JWTChanged func(publicKey string)
|
||||
|
||||
// DirJWTStore implements the JWT Store interface, keeping JWTs in an optionally sharded
|
||||
// directory structure
|
||||
type DirJWTStore struct {
|
||||
sync.Mutex
|
||||
directory string
|
||||
shard bool
|
||||
readonly bool
|
||||
expiration *expirationTracker
|
||||
changed JWTChanged
|
||||
}
|
||||
|
||||
func newDir(dirPath string, create bool) (string, error) {
|
||||
fullPath, err := validateDirPath(dirPath)
|
||||
if err != nil {
|
||||
if !create {
|
||||
return "", err
|
||||
}
|
||||
if err = os.MkdirAll(dirPath, 0755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if fullPath, err = validateDirPath(dirPath); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return fullPath, nil
|
||||
}
|
||||
|
||||
// Creates a directory based jwt store.
|
||||
// Reads files only, does NOT watch directories and files.
|
||||
func NewImmutableDirJWTStore(dirPath string, shard bool) (*DirJWTStore, error) {
|
||||
theStore, err := NewDirJWTStore(dirPath, shard, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
theStore.readonly = true
|
||||
return theStore, nil
|
||||
}
|
||||
|
||||
// Creates a directory based jwt store.
|
||||
// Operates on files only, does NOT watch directories and files.
|
||||
func NewDirJWTStore(dirPath string, shard bool, create bool) (*DirJWTStore, error) {
|
||||
fullPath, err := newDir(dirPath, create)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
theStore := &DirJWTStore{
|
||||
directory: fullPath,
|
||||
shard: shard,
|
||||
}
|
||||
return theStore, nil
|
||||
}
|
||||
|
||||
// Creates a directory based jwt store.
|
||||
//
|
||||
// When ttl is set deletion of file is based on it and not on the jwt expiration
|
||||
// To completely disable expiration (including expiration in jwt) set ttl to max duration time.Duration(math.MaxInt64)
|
||||
//
|
||||
// limit defines how many files are allowed at any given time. Set to math.MaxInt64 to disable.
|
||||
// evictOnLimit determines the behavior once limit is reached.
|
||||
// true - Evict based on lru strategy
|
||||
// false - return an error
|
||||
func NewExpiringDirJWTStore(dirPath string, shard bool, create bool, expireCheck time.Duration, limit int64,
|
||||
evictOnLimit bool, ttl time.Duration, changeNotification JWTChanged) (*DirJWTStore, error) {
|
||||
fullPath, err := newDir(dirPath, create)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
theStore := &DirJWTStore{
|
||||
directory: fullPath,
|
||||
shard: shard,
|
||||
changed: changeNotification,
|
||||
}
|
||||
if expireCheck <= 0 {
|
||||
if ttl != 0 {
|
||||
expireCheck = ttl / 2
|
||||
}
|
||||
if expireCheck == 0 || expireCheck > time.Minute {
|
||||
expireCheck = time.Minute
|
||||
}
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = math.MaxInt64
|
||||
}
|
||||
theStore.startExpiring(expireCheck, limit, evictOnLimit, ttl)
|
||||
theStore.Lock()
|
||||
err = filepath.Walk(dirPath, func(path string, info os.FileInfo, err error) error {
|
||||
if strings.HasSuffix(path, fileExtension) {
|
||||
if theJwt, err := ioutil.ReadFile(path); err == nil {
|
||||
hash := sha256.Sum256(theJwt)
|
||||
_, file := filepath.Split(path)
|
||||
theStore.expiration.track(strings.TrimSuffix(file, fileExtension), &hash, string(theJwt))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
theStore.Unlock()
|
||||
if err != nil {
|
||||
theStore.Close()
|
||||
return nil, err
|
||||
}
|
||||
return theStore, err
|
||||
}
|
||||
|
||||
func (store *DirJWTStore) IsReadOnly() bool {
|
||||
return store.readonly
|
||||
}
|
||||
|
||||
func (store *DirJWTStore) LoadAcc(publicKey string) (string, error) {
|
||||
return store.load(publicKey)
|
||||
}
|
||||
|
||||
func (store *DirJWTStore) SaveAcc(publicKey string, theJWT string) error {
|
||||
return store.save(publicKey, theJWT)
|
||||
}
|
||||
|
||||
func (store *DirJWTStore) LoadAct(hash string) (string, error) {
|
||||
return store.load(hash)
|
||||
}
|
||||
|
||||
func (store *DirJWTStore) SaveAct(hash string, theJWT string) error {
|
||||
return store.save(hash, theJWT)
|
||||
}
|
||||
|
||||
func (store *DirJWTStore) Close() {
|
||||
store.Lock()
|
||||
defer store.Unlock()
|
||||
if store.expiration != nil {
|
||||
store.expiration.close()
|
||||
store.expiration = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Pack up to maxJWTs into a package
|
||||
func (store *DirJWTStore) Pack(maxJWTs int) (string, error) {
|
||||
count := 0
|
||||
var pack []string
|
||||
if maxJWTs > 0 {
|
||||
pack = make([]string, 0, maxJWTs)
|
||||
} else {
|
||||
pack = []string{}
|
||||
}
|
||||
store.Lock()
|
||||
err := filepath.Walk(store.directory, func(path string, info os.FileInfo, err error) error {
|
||||
if !info.IsDir() && strings.HasSuffix(path, fileExtension) { // this is a JWT
|
||||
if count == maxJWTs { // won't match negative
|
||||
return nil
|
||||
}
|
||||
pubKey := strings.TrimSuffix(filepath.Base(path), fileExtension)
|
||||
if store.expiration != nil {
|
||||
if _, ok := store.expiration.idx[pubKey]; !ok {
|
||||
return nil // only include indexed files
|
||||
}
|
||||
}
|
||||
jwtBytes, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if store.expiration != nil {
|
||||
claim, err := jwt.DecodeGeneric(string(jwtBytes))
|
||||
if err == nil && claim.Expires > 0 && claim.Expires < time.Now().Unix() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
pack = append(pack, fmt.Sprintf("%s|%s", pubKey, string(jwtBytes)))
|
||||
count++
|
||||
}
|
||||
return nil
|
||||
})
|
||||
store.Unlock()
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else {
|
||||
return strings.Join(pack, "\n"), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Pack up to maxJWTs into a message and invoke callback with it
|
||||
func (store *DirJWTStore) PackWalk(maxJWTs int, cb func(partialPackMsg string)) error {
|
||||
if maxJWTs <= 0 || cb == nil {
|
||||
return errors.New("bad arguments to PackWalk")
|
||||
}
|
||||
var packMsg []string
|
||||
store.Lock()
|
||||
dir := store.directory
|
||||
exp := store.expiration
|
||||
store.Unlock()
|
||||
err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
|
||||
if !info.IsDir() && strings.HasSuffix(path, fileExtension) { // this is a JWT
|
||||
pubKey := strings.TrimSuffix(filepath.Base(path), fileExtension)
|
||||
store.Lock()
|
||||
if exp != nil {
|
||||
if _, ok := store.expiration.idx[pubKey]; !ok {
|
||||
store.Unlock()
|
||||
return nil // only include indexed files
|
||||
}
|
||||
}
|
||||
store.Unlock()
|
||||
jwtBytes, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exp != nil {
|
||||
claim, err := jwt.DecodeGeneric(string(jwtBytes))
|
||||
if err == nil && claim.Expires > 0 && claim.Expires < time.Now().Unix() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
packMsg = append(packMsg, fmt.Sprintf("%s|%s", pubKey, string(jwtBytes)))
|
||||
if len(packMsg) == maxJWTs { // won't match negative
|
||||
cb(strings.Join(packMsg, "\n"))
|
||||
packMsg = nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if packMsg != nil {
|
||||
cb(strings.Join(packMsg, "\n"))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Merge takes the JWTs from package and adds them to the store
|
||||
// Merge is destructive in the sense that it doesn't check if the JWT
|
||||
// is newer or anything like that.
|
||||
func (store *DirJWTStore) Merge(pack string) error {
|
||||
newJWTs := strings.Split(pack, "\n")
|
||||
for _, line := range newJWTs {
|
||||
if line == "" { // ignore blank lines
|
||||
continue
|
||||
}
|
||||
split := strings.Split(line, "|")
|
||||
if len(split) != 2 {
|
||||
return fmt.Errorf("line in package didn't contain 2 entries: %q", line)
|
||||
}
|
||||
pubKey := split[0]
|
||||
if err := store.saveIfNewer(pubKey, split[1]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (store *DirJWTStore) Reload() error {
|
||||
store.Lock()
|
||||
exp := store.expiration
|
||||
if exp == nil || store.readonly {
|
||||
store.Unlock()
|
||||
return nil
|
||||
}
|
||||
idx := exp.idx
|
||||
changed := store.changed
|
||||
isCache := store.expiration.evictOnLimit
|
||||
// clear out indexing data structures
|
||||
exp.heap = make([]*jwtItem, 0, len(exp.heap))
|
||||
exp.idx = make(map[string]*list.Element)
|
||||
exp.lru = list.New()
|
||||
exp.hash = [sha256.Size]byte{}
|
||||
store.Unlock()
|
||||
return filepath.Walk(store.directory, func(path string, info os.FileInfo, err error) error {
|
||||
if strings.HasSuffix(path, fileExtension) {
|
||||
if theJwt, err := ioutil.ReadFile(path); err == nil {
|
||||
hash := sha256.Sum256(theJwt)
|
||||
_, file := filepath.Split(path)
|
||||
pkey := strings.TrimSuffix(file, fileExtension)
|
||||
notify := isCache // for cache, issue cb even when file not present (may have been evicted)
|
||||
if i, ok := idx[pkey]; ok {
|
||||
notify = !bytes.Equal(i.Value.(*jwtItem).hash[:], hash[:])
|
||||
}
|
||||
store.Lock()
|
||||
exp.track(pkey, &hash, string(theJwt))
|
||||
store.Unlock()
|
||||
if notify && changed != nil {
|
||||
changed(pkey)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (store *DirJWTStore) pathForKey(publicKey string) string {
|
||||
if len(publicKey) < 2 {
|
||||
return ""
|
||||
}
|
||||
fileName := fmt.Sprintf("%s%s", publicKey, fileExtension)
|
||||
if store.shard {
|
||||
last := publicKey[len(publicKey)-2:]
|
||||
return filepath.Join(store.directory, last, fileName)
|
||||
} else {
|
||||
return filepath.Join(store.directory, fileName)
|
||||
}
|
||||
}
|
||||
|
||||
// Load checks the memory store and returns the matching JWT or an error
|
||||
// Assumes lock is NOT held
|
||||
func (store *DirJWTStore) load(publicKey string) (string, error) {
|
||||
store.Lock()
|
||||
defer store.Unlock()
|
||||
if path := store.pathForKey(publicKey); path == "" {
|
||||
return "", fmt.Errorf("invalid public key")
|
||||
} else if data, err := ioutil.ReadFile(path); err != nil {
|
||||
return "", err
|
||||
} else {
|
||||
if store.expiration != nil {
|
||||
store.expiration.updateTrack(publicKey)
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
}
|
||||
|
||||
// write that keeps hash of all jwt in sync
|
||||
// Assumes the lock is held. Does return true or an error never both.
|
||||
func (store *DirJWTStore) write(path string, publicKey string, theJWT string) (bool, error) {
|
||||
var newHash *[sha256.Size]byte
|
||||
if store.expiration != nil {
|
||||
h := sha256.Sum256([]byte(theJWT))
|
||||
newHash = &h
|
||||
if v, ok := store.expiration.idx[publicKey]; ok {
|
||||
store.expiration.updateTrack(publicKey)
|
||||
// this write is an update, move to back
|
||||
it := v.Value.(*jwtItem)
|
||||
oldHash := it.hash[:]
|
||||
if bytes.Equal(oldHash, newHash[:]) {
|
||||
return false, nil
|
||||
}
|
||||
} else if int64(store.expiration.Len()) >= store.expiration.limit {
|
||||
if !store.expiration.evictOnLimit {
|
||||
return false, errors.New("jwt store is full")
|
||||
}
|
||||
// this write is an add, pick the least recently used value for removal
|
||||
i := store.expiration.lru.Front().Value.(*jwtItem)
|
||||
if err := os.Remove(store.pathForKey(i.publicKey)); err != nil {
|
||||
return false, err
|
||||
} else {
|
||||
store.expiration.unTrack(i.publicKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := ioutil.WriteFile(path, []byte(theJWT), 0644); err != nil {
|
||||
return false, err
|
||||
} else if store.expiration != nil {
|
||||
store.expiration.track(publicKey, newHash, theJWT)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Save puts the JWT in a map by public key and performs update callbacks
|
||||
// Assumes lock is NOT held
|
||||
func (store *DirJWTStore) save(publicKey string, theJWT string) error {
|
||||
if store.readonly {
|
||||
return fmt.Errorf("store is read-only")
|
||||
}
|
||||
store.Lock()
|
||||
path := store.pathForKey(publicKey)
|
||||
if path == "" {
|
||||
store.Unlock()
|
||||
return fmt.Errorf("invalid public key")
|
||||
}
|
||||
dirPath := filepath.Dir(path)
|
||||
if _, err := validateDirPath(dirPath); err != nil {
|
||||
if err := os.MkdirAll(dirPath, 0755); err != nil {
|
||||
store.Unlock()
|
||||
return err
|
||||
}
|
||||
}
|
||||
changed, err := store.write(path, publicKey, theJWT)
|
||||
cb := store.changed
|
||||
store.Unlock()
|
||||
if changed && cb != nil {
|
||||
cb(publicKey)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Assumes the lock is NOT held, and only updates if the jwt is new, or the one on disk is older
|
||||
// returns true when the jwt changed
|
||||
func (store *DirJWTStore) saveIfNewer(publicKey string, theJWT string) error {
|
||||
if store.readonly {
|
||||
return fmt.Errorf("store is read-only")
|
||||
}
|
||||
path := store.pathForKey(publicKey)
|
||||
if path == "" {
|
||||
return fmt.Errorf("invalid public key")
|
||||
}
|
||||
dirPath := filepath.Dir(path)
|
||||
if _, err := validateDirPath(dirPath); err != nil {
|
||||
if err := os.MkdirAll(dirPath, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
if newJWT, err := jwt.DecodeGeneric(theJWT); err != nil {
|
||||
// skip if it can't be decoded
|
||||
} else if existing, err := ioutil.ReadFile(path); err != nil {
|
||||
return err
|
||||
} else if existingJWT, err := jwt.DecodeGeneric(string(existing)); err != nil {
|
||||
// skip if it can't be decoded
|
||||
} else if existingJWT.ID == newJWT.ID {
|
||||
return nil
|
||||
} else if existingJWT.IssuedAt > newJWT.IssuedAt {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
store.Lock()
|
||||
cb := store.changed
|
||||
changed, err := store.write(path, publicKey, theJWT)
|
||||
store.Unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
} else if changed && cb != nil {
|
||||
cb(publicKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func xorAssign(lVal *[sha256.Size]byte, rVal [sha256.Size]byte) {
|
||||
for i := range rVal {
|
||||
(*lVal)[i] ^= rVal[i]
|
||||
}
|
||||
}
|
||||
|
||||
// returns a hash representing all indexed jwt
|
||||
func (store *DirJWTStore) Hash() [sha256.Size]byte {
|
||||
store.Lock()
|
||||
defer store.Unlock()
|
||||
if store.expiration == nil {
|
||||
return [sha256.Size]byte{}
|
||||
} else {
|
||||
return store.expiration.hash
|
||||
}
|
||||
}
|
||||
|
||||
// An jwtItem is something managed by the priority queue
|
||||
type jwtItem struct {
|
||||
index int
|
||||
publicKey string
|
||||
expiration int64 // consists of unix time of expiration (ttl when set or jwt expiration) in seconds
|
||||
hash [sha256.Size]byte
|
||||
}
|
||||
|
||||
// A expirationTracker implements heap.Interface and holds Items.
|
||||
type expirationTracker struct {
|
||||
heap []*jwtItem // sorted by jwtItem.expiration
|
||||
idx map[string]*list.Element
|
||||
lru *list.List // keep which jwt are least used
|
||||
limit int64 // limit how many jwt are being tracked
|
||||
evictOnLimit bool // when limit is hit, error or evict using lru
|
||||
ttl time.Duration
|
||||
hash [sha256.Size]byte // xor of all jwtItem.hash in idx
|
||||
quit chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func (pq *expirationTracker) Len() int { return len(pq.heap) }
|
||||
|
||||
func (q *expirationTracker) Less(i, j int) bool {
|
||||
pq := q.heap
|
||||
return pq[i].expiration < pq[j].expiration
|
||||
}
|
||||
|
||||
func (q *expirationTracker) Swap(i, j int) {
|
||||
pq := q.heap
|
||||
pq[i], pq[j] = pq[j], pq[i]
|
||||
pq[i].index = i
|
||||
pq[j].index = j
|
||||
}
|
||||
|
||||
func (q *expirationTracker) Push(x interface{}) {
|
||||
n := len(q.heap)
|
||||
item := x.(*jwtItem)
|
||||
item.index = n
|
||||
q.heap = append(q.heap, item)
|
||||
q.idx[item.publicKey] = q.lru.PushBack(item)
|
||||
}
|
||||
|
||||
func (q *expirationTracker) Pop() interface{} {
|
||||
old := q.heap
|
||||
n := len(old)
|
||||
item := old[n-1]
|
||||
old[n-1] = nil // avoid memory leak
|
||||
item.index = -1
|
||||
q.heap = old[0 : n-1]
|
||||
q.lru.Remove(q.idx[item.publicKey])
|
||||
delete(q.idx, item.publicKey)
|
||||
return item
|
||||
}
|
||||
|
||||
func (pq *expirationTracker) updateTrack(publicKey string) {
|
||||
if e, ok := pq.idx[publicKey]; ok {
|
||||
i := e.Value.(*jwtItem)
|
||||
if pq.ttl != 0 {
|
||||
// only update expiration when set
|
||||
i.expiration = time.Now().Add(pq.ttl).Unix()
|
||||
heap.Fix(pq, i.index)
|
||||
}
|
||||
if pq.evictOnLimit {
|
||||
pq.lru.MoveToBack(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pq *expirationTracker) unTrack(publicKey string) {
|
||||
if it, ok := pq.idx[publicKey]; ok {
|
||||
xorAssign(&pq.hash, it.Value.(*jwtItem).hash)
|
||||
heap.Remove(pq, it.Value.(*jwtItem).index)
|
||||
delete(pq.idx, publicKey)
|
||||
}
|
||||
}
|
||||
|
||||
func (pq *expirationTracker) track(publicKey string, hash *[sha256.Size]byte, theJWT string) {
|
||||
var exp int64
|
||||
// prioritize ttl over expiration
|
||||
if pq.ttl != 0 {
|
||||
if pq.ttl == time.Duration(math.MaxInt64) {
|
||||
exp = math.MaxInt64
|
||||
} else {
|
||||
exp = time.Now().Add(pq.ttl).Unix()
|
||||
}
|
||||
} else {
|
||||
if g, err := jwt.DecodeGeneric(theJWT); err == nil {
|
||||
exp = g.Expires
|
||||
}
|
||||
if exp == 0 {
|
||||
exp = math.MaxInt64 // default to indefinite
|
||||
}
|
||||
}
|
||||
if e, ok := pq.idx[publicKey]; ok {
|
||||
i := e.Value.(*jwtItem)
|
||||
xorAssign(&pq.hash, i.hash) // remove old hash
|
||||
i.expiration = exp
|
||||
i.hash = *hash
|
||||
heap.Fix(pq, i.index)
|
||||
} else {
|
||||
heap.Push(pq, &jwtItem{-1, publicKey, exp, *hash})
|
||||
}
|
||||
xorAssign(&pq.hash, *hash) // add in new hash
|
||||
}
|
||||
|
||||
func (pq *expirationTracker) close() {
|
||||
if pq == nil || pq.quit == nil {
|
||||
return
|
||||
}
|
||||
close(pq.quit)
|
||||
pq.quit = nil
|
||||
}
|
||||
|
||||
func (store *DirJWTStore) startExpiring(reCheck time.Duration, limit int64, evictOnLimit bool, ttl time.Duration) {
|
||||
store.Lock()
|
||||
defer store.Unlock()
|
||||
quit := make(chan struct{})
|
||||
pq := &expirationTracker{
|
||||
make([]*jwtItem, 0, 10),
|
||||
make(map[string]*list.Element),
|
||||
list.New(),
|
||||
limit,
|
||||
evictOnLimit,
|
||||
ttl,
|
||||
[sha256.Size]byte{},
|
||||
quit,
|
||||
sync.WaitGroup{},
|
||||
}
|
||||
store.expiration = pq
|
||||
pq.wg.Add(1)
|
||||
go func() {
|
||||
t := time.NewTicker(reCheck)
|
||||
defer t.Stop()
|
||||
defer pq.wg.Done()
|
||||
for {
|
||||
now := time.Now()
|
||||
store.Lock()
|
||||
if pq.Len() > 0 {
|
||||
if it := heap.Pop(pq).(*jwtItem); it.expiration <= now.Unix() {
|
||||
path := store.pathForKey(it.publicKey)
|
||||
if err := os.Remove(path); err != nil {
|
||||
heap.Push(pq, it) // retry later
|
||||
} else {
|
||||
pq.unTrack(it.publicKey)
|
||||
xorAssign(&pq.hash, it.hash)
|
||||
store.Unlock()
|
||||
continue // we removed an entry, check next one
|
||||
}
|
||||
} else {
|
||||
heap.Push(pq, it)
|
||||
}
|
||||
}
|
||||
store.Unlock()
|
||||
select {
|
||||
case <-t.C:
|
||||
case <-quit:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
898
server/dirstore_test.go
Normal file
898
server/dirstore_test.go
Normal file
@@ -0,0 +1,898 @@
|
||||
/*
|
||||
* Copyright 2020 The NATS Authors
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/jwt/v2"
|
||||
"github.com/nats-io/nkeys"
|
||||
)
|
||||
|
||||
func require_True(t *testing.T, b bool) {
|
||||
t.Helper()
|
||||
if !b {
|
||||
t.Errorf("require true, but got false")
|
||||
}
|
||||
}
|
||||
|
||||
func require_False(t *testing.T, b bool) {
|
||||
t.Helper()
|
||||
if b {
|
||||
t.Errorf("require no false, but got true")
|
||||
}
|
||||
}
|
||||
|
||||
func require_NoError(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
t.Errorf("require no error, but got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func require_Error(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if err == nil {
|
||||
t.Errorf("require no error, but got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func require_Equal(t *testing.T, a, b string) {
|
||||
t.Helper()
|
||||
if strings.Compare(a, b) != 0 {
|
||||
t.Errorf("require equal, but got: %v != %v", a, b)
|
||||
}
|
||||
}
|
||||
|
||||
func require_NotEqual(t *testing.T, a, b [32]byte) {
|
||||
t.Helper()
|
||||
if bytes.Equal(a[:], b[:]) {
|
||||
t.Errorf("require not equal, but got: %v != %v", a, b)
|
||||
}
|
||||
}
|
||||
|
||||
func require_Len(t *testing.T, a, b int) {
|
||||
t.Helper()
|
||||
if a != b {
|
||||
t.Errorf("require len, but got: %v != %v", a, b)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShardedDirStoreWriteAndReadonly(t *testing.T) {
|
||||
dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
|
||||
require_NoError(t, err)
|
||||
|
||||
store, err := NewDirJWTStore(dir, true, false)
|
||||
require_NoError(t, err)
|
||||
|
||||
expected := map[string]string{
|
||||
"one": "alpha",
|
||||
"two": "beta",
|
||||
"three": "gamma",
|
||||
"four": "delta",
|
||||
}
|
||||
|
||||
for k, v := range expected {
|
||||
store.SaveAcc(k, v)
|
||||
}
|
||||
|
||||
for k, v := range expected {
|
||||
got, err := store.LoadAcc(k)
|
||||
require_NoError(t, err)
|
||||
require_Equal(t, v, got)
|
||||
}
|
||||
|
||||
got, err := store.LoadAcc("random")
|
||||
require_Error(t, err)
|
||||
require_Equal(t, "", got)
|
||||
|
||||
got, err = store.LoadAcc("")
|
||||
require_Error(t, err)
|
||||
require_Equal(t, "", got)
|
||||
|
||||
err = store.SaveAcc("", "onetwothree")
|
||||
require_Error(t, err)
|
||||
store.Close()
|
||||
|
||||
// re-use the folder for readonly mode
|
||||
store, err = NewImmutableDirJWTStore(dir, true)
|
||||
require_NoError(t, err)
|
||||
|
||||
require_True(t, store.IsReadOnly())
|
||||
|
||||
err = store.SaveAcc("five", "omega")
|
||||
require_Error(t, err)
|
||||
|
||||
for k, v := range expected {
|
||||
got, err := store.LoadAcc(k)
|
||||
require_NoError(t, err)
|
||||
require_Equal(t, v, got)
|
||||
}
|
||||
store.Close()
|
||||
}
|
||||
|
||||
func TestUnshardedDirStoreWriteAndReadonly(t *testing.T) {
|
||||
dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
|
||||
require_NoError(t, err)
|
||||
|
||||
store, err := NewDirJWTStore(dir, false, false)
|
||||
require_NoError(t, err)
|
||||
|
||||
expected := map[string]string{
|
||||
"one": "alpha",
|
||||
"two": "beta",
|
||||
"three": "gamma",
|
||||
"four": "delta",
|
||||
}
|
||||
|
||||
require_False(t, store.IsReadOnly())
|
||||
|
||||
for k, v := range expected {
|
||||
store.SaveAcc(k, v)
|
||||
}
|
||||
|
||||
for k, v := range expected {
|
||||
got, err := store.LoadAcc(k)
|
||||
require_NoError(t, err)
|
||||
require_Equal(t, v, got)
|
||||
}
|
||||
|
||||
got, err := store.LoadAcc("random")
|
||||
require_Error(t, err)
|
||||
require_Equal(t, "", got)
|
||||
|
||||
got, err = store.LoadAcc("")
|
||||
require_Error(t, err)
|
||||
require_Equal(t, "", got)
|
||||
|
||||
err = store.SaveAcc("", "onetwothree")
|
||||
require_Error(t, err)
|
||||
store.Close()
|
||||
|
||||
// re-use the folder for readonly mode
|
||||
store, err = NewImmutableDirJWTStore(dir, false)
|
||||
require_NoError(t, err)
|
||||
|
||||
require_True(t, store.IsReadOnly())
|
||||
|
||||
err = store.SaveAcc("five", "omega")
|
||||
require_Error(t, err)
|
||||
|
||||
for k, v := range expected {
|
||||
got, err := store.LoadAcc(k)
|
||||
require_NoError(t, err)
|
||||
require_Equal(t, v, got)
|
||||
}
|
||||
store.Close()
|
||||
}
|
||||
|
||||
func TestNoCreateRequiresDir(t *testing.T) {
|
||||
_, err := NewDirJWTStore("/a/b/c", true, false)
|
||||
require_Error(t, err)
|
||||
}
|
||||
|
||||
func TestCreateMakesDir(t *testing.T) {
|
||||
dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
|
||||
require_NoError(t, err)
|
||||
|
||||
fullPath := filepath.Join(dir, "a/b")
|
||||
|
||||
_, err = os.Stat(fullPath)
|
||||
require_Error(t, err)
|
||||
require_True(t, os.IsNotExist(err))
|
||||
|
||||
s, err := NewDirJWTStore(fullPath, false, true)
|
||||
require_NoError(t, err)
|
||||
s.Close()
|
||||
|
||||
_, err = os.Stat(fullPath)
|
||||
require_NoError(t, err)
|
||||
}
|
||||
|
||||
func TestShardedDirStorePackMerge(t *testing.T) {
|
||||
dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
|
||||
require_NoError(t, err)
|
||||
dir2, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
|
||||
require_NoError(t, err)
|
||||
dir3, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
|
||||
require_NoError(t, err)
|
||||
|
||||
store, err := NewDirJWTStore(dir, true, false)
|
||||
require_NoError(t, err)
|
||||
|
||||
expected := map[string]string{
|
||||
"one": "alpha",
|
||||
"two": "beta",
|
||||
"three": "gamma",
|
||||
"four": "delta",
|
||||
}
|
||||
|
||||
require_False(t, store.IsReadOnly())
|
||||
|
||||
for k, v := range expected {
|
||||
store.SaveAcc(k, v)
|
||||
}
|
||||
|
||||
for k, v := range expected {
|
||||
got, err := store.LoadAcc(k)
|
||||
require_NoError(t, err)
|
||||
require_Equal(t, v, got)
|
||||
}
|
||||
|
||||
got, err := store.LoadAcc("random")
|
||||
require_Error(t, err)
|
||||
require_Equal(t, "", got)
|
||||
|
||||
pack, err := store.Pack(-1)
|
||||
require_NoError(t, err)
|
||||
|
||||
inc, err := NewDirJWTStore(dir2, true, false)
|
||||
require_NoError(t, err)
|
||||
|
||||
inc.Merge(pack)
|
||||
|
||||
for k, v := range expected {
|
||||
got, err := inc.LoadAcc(k)
|
||||
require_NoError(t, err)
|
||||
require_Equal(t, v, got)
|
||||
}
|
||||
|
||||
got, err = inc.LoadAcc("random")
|
||||
require_Error(t, err)
|
||||
require_Equal(t, "", got)
|
||||
|
||||
limitedPack, err := inc.Pack(1)
|
||||
require_NoError(t, err)
|
||||
|
||||
limited, err := NewDirJWTStore(dir3, true, false)
|
||||
|
||||
require_NoError(t, err)
|
||||
|
||||
limited.Merge(limitedPack)
|
||||
|
||||
count := 0
|
||||
for k, v := range expected {
|
||||
got, err := limited.LoadAcc(k)
|
||||
if err == nil {
|
||||
count++
|
||||
require_Equal(t, v, got)
|
||||
}
|
||||
}
|
||||
|
||||
require_Len(t, 1, count)
|
||||
|
||||
got, err = inc.LoadAcc("random")
|
||||
require_Error(t, err)
|
||||
require_Equal(t, "", got)
|
||||
}
|
||||
|
||||
func TestShardedToUnsharedDirStorePackMerge(t *testing.T) {
|
||||
dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
|
||||
require_NoError(t, err)
|
||||
dir2, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
|
||||
require_NoError(t, err)
|
||||
|
||||
store, err := NewDirJWTStore(dir, true, false)
|
||||
require_NoError(t, err)
|
||||
|
||||
expected := map[string]string{
|
||||
"one": "alpha",
|
||||
"two": "beta",
|
||||
"three": "gamma",
|
||||
"four": "delta",
|
||||
}
|
||||
|
||||
require_False(t, store.IsReadOnly())
|
||||
|
||||
for k, v := range expected {
|
||||
store.SaveAcc(k, v)
|
||||
}
|
||||
|
||||
for k, v := range expected {
|
||||
got, err := store.LoadAcc(k)
|
||||
require_NoError(t, err)
|
||||
require_Equal(t, v, got)
|
||||
}
|
||||
|
||||
got, err := store.LoadAcc("random")
|
||||
require_Error(t, err)
|
||||
require_Equal(t, "", got)
|
||||
|
||||
pack, err := store.Pack(-1)
|
||||
require_NoError(t, err)
|
||||
|
||||
inc, err := NewDirJWTStore(dir2, false, false)
|
||||
require_NoError(t, err)
|
||||
|
||||
inc.Merge(pack)
|
||||
|
||||
for k, v := range expected {
|
||||
got, err := inc.LoadAcc(k)
|
||||
require_NoError(t, err)
|
||||
require_Equal(t, v, got)
|
||||
}
|
||||
|
||||
got, err = inc.LoadAcc("random")
|
||||
require_Error(t, err)
|
||||
require_Equal(t, "", got)
|
||||
|
||||
err = store.Merge("foo")
|
||||
require_Error(t, err)
|
||||
|
||||
err = store.Merge("") // will skip it
|
||||
require_NoError(t, err)
|
||||
|
||||
err = store.Merge("a|something") // should fail on a for sharding
|
||||
require_Error(t, err)
|
||||
}
|
||||
|
||||
func TestMergeOnlyOnNewer(t *testing.T) {
|
||||
dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
|
||||
require_NoError(t, err)
|
||||
|
||||
dirStore, err := NewDirJWTStore(dir, true, false)
|
||||
require_NoError(t, err)
|
||||
|
||||
accountKey, err := nkeys.CreateAccount()
|
||||
require_NoError(t, err)
|
||||
|
||||
pubKey, err := accountKey.PublicKey()
|
||||
require_NoError(t, err)
|
||||
|
||||
account := jwt.NewAccountClaims(pubKey)
|
||||
account.Name = "old"
|
||||
olderJWT, err := account.Encode(accountKey)
|
||||
require_NoError(t, err)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
account.Name = "new"
|
||||
newerJWT, err := account.Encode(accountKey)
|
||||
require_NoError(t, err)
|
||||
|
||||
// Should work
|
||||
err = dirStore.SaveAcc(pubKey, olderJWT)
|
||||
require_NoError(t, err)
|
||||
fromStore, err := dirStore.LoadAcc(pubKey)
|
||||
require_NoError(t, err)
|
||||
require_Equal(t, olderJWT, fromStore)
|
||||
|
||||
// should replace
|
||||
err = dirStore.saveIfNewer(pubKey, newerJWT)
|
||||
require_NoError(t, err)
|
||||
fromStore, err = dirStore.LoadAcc(pubKey)
|
||||
require_NoError(t, err)
|
||||
require_Equal(t, newerJWT, fromStore)
|
||||
|
||||
// should fail
|
||||
err = dirStore.saveIfNewer(pubKey, olderJWT)
|
||||
require_NoError(t, err)
|
||||
fromStore, err = dirStore.LoadAcc(pubKey)
|
||||
require_NoError(t, err)
|
||||
require_Equal(t, newerJWT, fromStore)
|
||||
}
|
||||
|
||||
func createTestAccount(t *testing.T, dirStore *DirJWTStore, expSec int, accKey nkeys.KeyPair) string {
|
||||
t.Helper()
|
||||
pubKey, err := accKey.PublicKey()
|
||||
require_NoError(t, err)
|
||||
account := jwt.NewAccountClaims(pubKey)
|
||||
if expSec > 0 {
|
||||
account.Expires = time.Now().Add(time.Second * time.Duration(expSec)).Unix()
|
||||
}
|
||||
jwt, err := account.Encode(accKey)
|
||||
require_NoError(t, err)
|
||||
err = dirStore.SaveAcc(pubKey, jwt)
|
||||
require_NoError(t, err)
|
||||
return jwt
|
||||
}
|
||||
|
||||
func assertStoreSize(t *testing.T, dirStore *DirJWTStore, length int) {
|
||||
t.Helper()
|
||||
f, err := ioutil.ReadDir(dirStore.directory)
|
||||
require_NoError(t, err)
|
||||
require_Len(t, len(f), length)
|
||||
dirStore.Lock()
|
||||
require_Len(t, len(dirStore.expiration.idx), length)
|
||||
require_Len(t, dirStore.expiration.lru.Len(), length)
|
||||
require_Len(t, len(dirStore.expiration.heap), length)
|
||||
dirStore.Unlock()
|
||||
}
|
||||
|
||||
func TestExpiration(t *testing.T) {
|
||||
dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
|
||||
require_NoError(t, err)
|
||||
|
||||
dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 10, true, 0, nil)
|
||||
require_NoError(t, err)
|
||||
defer dirStore.Close()
|
||||
|
||||
account := func(expSec int) {
|
||||
accountKey, err := nkeys.CreateAccount()
|
||||
require_NoError(t, err)
|
||||
createTestAccount(t, dirStore, expSec, accountKey)
|
||||
}
|
||||
|
||||
h := dirStore.Hash()
|
||||
|
||||
for i := 1; i <= 5; i++ {
|
||||
account(i * 2)
|
||||
nh := dirStore.Hash()
|
||||
require_NotEqual(t, h, nh)
|
||||
h = nh
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
for i := 5; i > 0; i-- {
|
||||
f, err := ioutil.ReadDir(dir)
|
||||
require_NoError(t, err)
|
||||
require_Len(t, len(f), i)
|
||||
assertStoreSize(t, dirStore, i)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
nh := dirStore.Hash()
|
||||
require_NotEqual(t, h, nh)
|
||||
h = nh
|
||||
}
|
||||
assertStoreSize(t, dirStore, 0)
|
||||
}
|
||||
|
||||
func TestLimit(t *testing.T) {
|
||||
dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
|
||||
require_NoError(t, err)
|
||||
|
||||
dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 5, true, 0, nil)
|
||||
require_NoError(t, err)
|
||||
defer dirStore.Close()
|
||||
|
||||
account := func(expSec int) {
|
||||
accountKey, err := nkeys.CreateAccount()
|
||||
require_NoError(t, err)
|
||||
createTestAccount(t, dirStore, expSec, accountKey)
|
||||
}
|
||||
|
||||
h := dirStore.Hash()
|
||||
|
||||
accountKey, err := nkeys.CreateAccount()
|
||||
require_NoError(t, err)
|
||||
// update first account
|
||||
for i := 0; i < 10; i++ {
|
||||
createTestAccount(t, dirStore, 50, accountKey)
|
||||
assertStoreSize(t, dirStore, 1)
|
||||
}
|
||||
// new accounts
|
||||
for i := 0; i < 10; i++ {
|
||||
account(i)
|
||||
nh := dirStore.Hash()
|
||||
require_NotEqual(t, h, nh)
|
||||
h = nh
|
||||
}
|
||||
// first account should be gone now accountKey.PublicKey()
|
||||
key, _ := accountKey.PublicKey()
|
||||
_, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, key))
|
||||
require_True(t, os.IsNotExist(err))
|
||||
|
||||
// update first account
|
||||
for i := 0; i < 10; i++ {
|
||||
createTestAccount(t, dirStore, 50, accountKey)
|
||||
assertStoreSize(t, dirStore, 5)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimitNoEvict(t *testing.T) {
|
||||
dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
|
||||
require_NoError(t, err)
|
||||
|
||||
dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 2, false, 0, nil)
|
||||
require_NoError(t, err)
|
||||
defer dirStore.Close()
|
||||
|
||||
accountKey1, err := nkeys.CreateAccount()
|
||||
require_NoError(t, err)
|
||||
pKey1, err := accountKey1.PublicKey()
|
||||
require_NoError(t, err)
|
||||
accountKey2, err := nkeys.CreateAccount()
|
||||
require_NoError(t, err)
|
||||
accountKey3, err := nkeys.CreateAccount()
|
||||
require_NoError(t, err)
|
||||
pKey3, err := accountKey3.PublicKey()
|
||||
require_NoError(t, err)
|
||||
|
||||
createTestAccount(t, dirStore, 100, accountKey1)
|
||||
assertStoreSize(t, dirStore, 1)
|
||||
createTestAccount(t, dirStore, 2, accountKey2)
|
||||
assertStoreSize(t, dirStore, 2)
|
||||
|
||||
hBefore := dirStore.Hash()
|
||||
// 2 jwt are already stored. third must result in an error
|
||||
pubKey, err := accountKey3.PublicKey()
|
||||
require_NoError(t, err)
|
||||
account := jwt.NewAccountClaims(pubKey)
|
||||
jwt, err := account.Encode(accountKey3)
|
||||
require_NoError(t, err)
|
||||
err = dirStore.SaveAcc(pubKey, jwt)
|
||||
require_Error(t, err)
|
||||
assertStoreSize(t, dirStore, 2)
|
||||
_, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey1))
|
||||
require_NoError(t, err)
|
||||
_, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey3))
|
||||
require_True(t, os.IsNotExist(err))
|
||||
// check that the hash did not change
|
||||
hAfter := dirStore.Hash()
|
||||
require_True(t, bytes.Equal(hBefore[:], hAfter[:]))
|
||||
// wait for expiration of account2
|
||||
time.Sleep(3 * time.Second)
|
||||
err = dirStore.SaveAcc(pubKey, jwt)
|
||||
require_NoError(t, err)
|
||||
assertStoreSize(t, dirStore, 2)
|
||||
_, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey1))
|
||||
require_NoError(t, err)
|
||||
_, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey3))
|
||||
require_NoError(t, err)
|
||||
}
|
||||
|
||||
func TestLruLoad(t *testing.T) {
|
||||
dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
|
||||
require_NoError(t, err)
|
||||
dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 2, true, 0, nil)
|
||||
require_NoError(t, err)
|
||||
defer dirStore.Close()
|
||||
|
||||
accountKey1, err := nkeys.CreateAccount()
|
||||
require_NoError(t, err)
|
||||
pKey1, err := accountKey1.PublicKey()
|
||||
require_NoError(t, err)
|
||||
accountKey2, err := nkeys.CreateAccount()
|
||||
require_NoError(t, err)
|
||||
accountKey3, err := nkeys.CreateAccount()
|
||||
require_NoError(t, err)
|
||||
pKey3, err := accountKey3.PublicKey()
|
||||
require_NoError(t, err)
|
||||
|
||||
createTestAccount(t, dirStore, 10, accountKey1)
|
||||
assertStoreSize(t, dirStore, 1)
|
||||
createTestAccount(t, dirStore, 10, accountKey2)
|
||||
assertStoreSize(t, dirStore, 2)
|
||||
dirStore.LoadAcc(pKey1) // will reorder 1/2
|
||||
createTestAccount(t, dirStore, 10, accountKey3)
|
||||
assertStoreSize(t, dirStore, 2)
|
||||
|
||||
_, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey1))
|
||||
require_NoError(t, err)
|
||||
_, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey3))
|
||||
require_NoError(t, err)
|
||||
}
|
||||
|
||||
func TestLru(t *testing.T) {
|
||||
dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
|
||||
require_NoError(t, err)
|
||||
|
||||
dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 2, true, 0, nil)
|
||||
require_NoError(t, err)
|
||||
defer dirStore.Close()
|
||||
|
||||
accountKey1, err := nkeys.CreateAccount()
|
||||
require_NoError(t, err)
|
||||
pKey1, err := accountKey1.PublicKey()
|
||||
require_NoError(t, err)
|
||||
accountKey2, err := nkeys.CreateAccount()
|
||||
require_NoError(t, err)
|
||||
accountKey3, err := nkeys.CreateAccount()
|
||||
require_NoError(t, err)
|
||||
pKey3, err := accountKey3.PublicKey()
|
||||
require_NoError(t, err)
|
||||
|
||||
createTestAccount(t, dirStore, 10, accountKey1)
|
||||
assertStoreSize(t, dirStore, 1)
|
||||
createTestAccount(t, dirStore, 10, accountKey2)
|
||||
assertStoreSize(t, dirStore, 2)
|
||||
createTestAccount(t, dirStore, 10, accountKey3)
|
||||
assertStoreSize(t, dirStore, 2)
|
||||
_, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey1))
|
||||
require_True(t, os.IsNotExist(err))
|
||||
_, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey3))
|
||||
require_NoError(t, err)
|
||||
|
||||
// update -> will change this keys position for eviction
|
||||
createTestAccount(t, dirStore, 10, accountKey2)
|
||||
assertStoreSize(t, dirStore, 2)
|
||||
// recreate -> will evict 3
|
||||
createTestAccount(t, dirStore, 1, accountKey1)
|
||||
assertStoreSize(t, dirStore, 2)
|
||||
_, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey3))
|
||||
require_True(t, os.IsNotExist(err))
|
||||
// let key1 expire
|
||||
time.Sleep(2 * time.Second)
|
||||
assertStoreSize(t, dirStore, 1)
|
||||
_, err = os.Stat(fmt.Sprintf("%s/%s.jwt", dir, pKey1))
|
||||
require_True(t, os.IsNotExist(err))
|
||||
// recreate key3 - no eviction
|
||||
createTestAccount(t, dirStore, 10, accountKey3)
|
||||
assertStoreSize(t, dirStore, 2)
|
||||
}
|
||||
|
||||
func TestReload(t *testing.T) {
|
||||
dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
|
||||
require_NoError(t, err)
|
||||
notificationChan := make(chan struct{}, 5)
|
||||
dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 2, true, 0, func(publicKey string) {
|
||||
notificationChan <- struct{}{}
|
||||
})
|
||||
require_NoError(t, err)
|
||||
defer dirStore.Close()
|
||||
newAccount := func() string {
|
||||
t.Helper()
|
||||
accKey, err := nkeys.CreateAccount()
|
||||
require_NoError(t, err)
|
||||
pKey, err := accKey.PublicKey()
|
||||
require_NoError(t, err)
|
||||
pubKey, err := accKey.PublicKey()
|
||||
require_NoError(t, err)
|
||||
account := jwt.NewAccountClaims(pubKey)
|
||||
jwt, err := account.Encode(accKey)
|
||||
require_NoError(t, err)
|
||||
file := fmt.Sprintf("%s/%s.jwt", dir, pKey)
|
||||
err = ioutil.WriteFile(file, []byte(jwt), 0644)
|
||||
require_NoError(t, err)
|
||||
return file
|
||||
}
|
||||
files := make(map[string]struct{})
|
||||
assertStoreSize(t, dirStore, 0)
|
||||
hash := dirStore.Hash()
|
||||
emptyHash := [sha256.Size]byte{}
|
||||
require_True(t, bytes.Equal(hash[:], emptyHash[:]))
|
||||
for i := 0; i < 5; i++ {
|
||||
files[newAccount()] = struct{}{}
|
||||
err = dirStore.Reload()
|
||||
require_NoError(t, err)
|
||||
<-notificationChan
|
||||
assertStoreSize(t, dirStore, i+1)
|
||||
hash = dirStore.Hash()
|
||||
require_False(t, bytes.Equal(hash[:], emptyHash[:]))
|
||||
msg, err := dirStore.Pack(-1)
|
||||
require_NoError(t, err)
|
||||
require_Len(t, len(strings.Split(msg, "\n")), len(files))
|
||||
}
|
||||
for k := range files {
|
||||
hash = dirStore.Hash()
|
||||
require_False(t, bytes.Equal(hash[:], emptyHash[:]))
|
||||
os.Remove(k)
|
||||
err = dirStore.Reload()
|
||||
require_NoError(t, err)
|
||||
assertStoreSize(t, dirStore, len(files)-1)
|
||||
delete(files, k)
|
||||
msg, err := dirStore.Pack(-1)
|
||||
require_NoError(t, err)
|
||||
if len(files) != 0 { // when len is 0, we have an empty line
|
||||
require_Len(t, len(strings.Split(msg, "\n")), len(files))
|
||||
}
|
||||
}
|
||||
require_True(t, len(notificationChan) == 0)
|
||||
hash = dirStore.Hash()
|
||||
require_True(t, bytes.Equal(hash[:], emptyHash[:]))
|
||||
}
|
||||
|
||||
func TestExpirationUpdate(t *testing.T) {
|
||||
dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
|
||||
require_NoError(t, err)
|
||||
|
||||
dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 10, true, 0, nil)
|
||||
require_NoError(t, err)
|
||||
defer dirStore.Close()
|
||||
|
||||
accountKey, err := nkeys.CreateAccount()
|
||||
require_NoError(t, err)
|
||||
|
||||
h := dirStore.Hash()
|
||||
|
||||
createTestAccount(t, dirStore, 0, accountKey)
|
||||
nh := dirStore.Hash()
|
||||
require_NotEqual(t, h, nh)
|
||||
h = nh
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
f, err := ioutil.ReadDir(dir)
|
||||
require_NoError(t, err)
|
||||
require_Len(t, len(f), 1)
|
||||
|
||||
createTestAccount(t, dirStore, 5, accountKey)
|
||||
nh = dirStore.Hash()
|
||||
require_NotEqual(t, h, nh)
|
||||
h = nh
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
f, err = ioutil.ReadDir(dir)
|
||||
require_NoError(t, err)
|
||||
require_Len(t, len(f), 1)
|
||||
|
||||
createTestAccount(t, dirStore, 0, accountKey)
|
||||
nh = dirStore.Hash()
|
||||
require_NotEqual(t, h, nh)
|
||||
h = nh
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
f, err = ioutil.ReadDir(dir)
|
||||
require_NoError(t, err)
|
||||
require_Len(t, len(f), 1)
|
||||
|
||||
createTestAccount(t, dirStore, 1, accountKey)
|
||||
nh = dirStore.Hash()
|
||||
require_NotEqual(t, h, nh)
|
||||
h = nh
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
f, err = ioutil.ReadDir(dir)
|
||||
require_NoError(t, err)
|
||||
require_Len(t, len(f), 0)
|
||||
|
||||
empty := [32]byte{}
|
||||
h = dirStore.Hash()
|
||||
require_Equal(t, string(h[:]), string(empty[:]))
|
||||
}
|
||||
|
||||
func TestTTL(t *testing.T) {
|
||||
dir, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
|
||||
require_NoError(t, err)
|
||||
require_OneJWT := func() {
|
||||
t.Helper()
|
||||
f, err := ioutil.ReadDir(dir)
|
||||
require_NoError(t, err)
|
||||
require_Len(t, len(f), 1)
|
||||
}
|
||||
dirStore, err := NewExpiringDirJWTStore(dir, false, false, time.Millisecond*100, 10, true, 2*time.Second, nil)
|
||||
require_NoError(t, err)
|
||||
defer dirStore.Close()
|
||||
|
||||
accountKey, err := nkeys.CreateAccount()
|
||||
require_NoError(t, err)
|
||||
pubKey, err := accountKey.PublicKey()
|
||||
require_NoError(t, err)
|
||||
jwt := createTestAccount(t, dirStore, 0, accountKey)
|
||||
require_OneJWT()
|
||||
for i := 0; i < 6; i++ {
|
||||
time.Sleep(time.Second)
|
||||
dirStore.LoadAcc(pubKey)
|
||||
require_OneJWT()
|
||||
}
|
||||
for i := 0; i < 6; i++ {
|
||||
time.Sleep(time.Second)
|
||||
dirStore.SaveAcc(pubKey, jwt)
|
||||
require_OneJWT()
|
||||
}
|
||||
for i := 0; i < 6; i++ {
|
||||
time.Sleep(time.Second)
|
||||
createTestAccount(t, dirStore, 0, accountKey)
|
||||
require_OneJWT()
|
||||
}
|
||||
time.Sleep(3 * time.Second)
|
||||
f, err := ioutil.ReadDir(dir)
|
||||
require_NoError(t, err)
|
||||
require_Len(t, len(f), 0)
|
||||
}
|
||||
|
||||
const infDur = time.Duration(math.MaxInt64)
|
||||
|
||||
func TestNotificationOnPack(t *testing.T) {
|
||||
jwts := map[string]string{
|
||||
"key1": "value",
|
||||
"key2": "value",
|
||||
"key3": "value",
|
||||
"key4": "value",
|
||||
}
|
||||
notificationChan := make(chan struct{}, len(jwts)) // set to same len so all extra will block
|
||||
notification := func(pubKey string) {
|
||||
if _, ok := jwts[pubKey]; !ok {
|
||||
t.Errorf("Key not found: %s", pubKey)
|
||||
}
|
||||
notificationChan <- struct{}{}
|
||||
}
|
||||
dirPack, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
|
||||
require_NoError(t, err)
|
||||
packStore, err := NewExpiringDirJWTStore(dirPack, false, false, infDur, 0, true, 0, notification)
|
||||
require_NoError(t, err)
|
||||
// prefill the store with data
|
||||
for k, v := range jwts {
|
||||
require_NoError(t, packStore.SaveAcc(k, v))
|
||||
}
|
||||
for i := 0; i < len(jwts); i++ {
|
||||
<-notificationChan
|
||||
}
|
||||
msg, err := packStore.Pack(-1)
|
||||
require_NoError(t, err)
|
||||
packStore.Close()
|
||||
hash := packStore.Hash()
|
||||
for _, shard := range []bool{true, false, true, false} {
|
||||
dirMerge, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
|
||||
require_NoError(t, err)
|
||||
mergeStore, err := NewExpiringDirJWTStore(dirMerge, shard, false, infDur, 0, true, 0, notification)
|
||||
require_NoError(t, err)
|
||||
// set
|
||||
err = mergeStore.Merge(msg)
|
||||
require_NoError(t, err)
|
||||
assertStoreSize(t, mergeStore, len(jwts))
|
||||
hash1 := packStore.Hash()
|
||||
require_True(t, bytes.Equal(hash[:], hash1[:]))
|
||||
for i := 0; i < len(jwts); i++ {
|
||||
<-notificationChan
|
||||
}
|
||||
// overwrite - assure
|
||||
err = mergeStore.Merge(msg)
|
||||
require_NoError(t, err)
|
||||
assertStoreSize(t, mergeStore, len(jwts))
|
||||
hash2 := packStore.Hash()
|
||||
require_True(t, bytes.Equal(hash1[:], hash2[:]))
|
||||
|
||||
hash = hash1
|
||||
msg, err = mergeStore.Pack(-1)
|
||||
require_NoError(t, err)
|
||||
mergeStore.Close()
|
||||
require_True(t, len(notificationChan) == 0)
|
||||
|
||||
for k, v := range jwts {
|
||||
j, err := packStore.LoadAcc(k)
|
||||
require_NoError(t, err)
|
||||
require_Equal(t, j, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotificationOnPackWalk(t *testing.T) {
|
||||
const storeCnt = 5
|
||||
const keyCnt = 50
|
||||
const iterCnt = 8
|
||||
store := [storeCnt]*DirJWTStore{}
|
||||
for i := 0; i < storeCnt; i++ {
|
||||
dirMerge, err := ioutil.TempDir(os.TempDir(), "jwtstore_test")
|
||||
require_NoError(t, err)
|
||||
mergeStore, err := NewExpiringDirJWTStore(dirMerge, true, false, infDur, 0, true, 0, nil)
|
||||
require_NoError(t, err)
|
||||
store[i] = mergeStore
|
||||
}
|
||||
for i := 0; i < iterCnt; i++ { //iterations
|
||||
jwt := make(map[string]string)
|
||||
for j := 0; j < keyCnt; j++ {
|
||||
key := fmt.Sprintf("key%d-%d", i, j)
|
||||
value := "value"
|
||||
jwt[key] = value
|
||||
store[0].SaveAcc(key, value)
|
||||
}
|
||||
for j := 0; j < storeCnt-1; j++ { // stores
|
||||
err := store[j].PackWalk(3, func(partialPackMsg string) {
|
||||
err := store[j+1].Merge(partialPackMsg)
|
||||
require_NoError(t, err)
|
||||
})
|
||||
require_NoError(t, err)
|
||||
}
|
||||
for i := 0; i < storeCnt-1; i++ {
|
||||
h1 := store[i].Hash()
|
||||
h2 := store[i+1].Hash()
|
||||
require_True(t, bytes.Equal(h1[:], h2[:]))
|
||||
}
|
||||
}
|
||||
for i := 0; i < storeCnt; i++ {
|
||||
store[i].Close()
|
||||
}
|
||||
}
|
||||
@@ -584,9 +584,15 @@ func (s *Server) initEventTracking() {
|
||||
s.Errorf("Error setting up internal tracking: %v", err)
|
||||
}
|
||||
// Listen for account claims updates.
|
||||
subject = fmt.Sprintf(accUpdateEventSubj, "*")
|
||||
if _, err := s.sysSubscribe(subject, s.accountClaimUpdate); err != nil {
|
||||
s.Errorf("Error setting up internal tracking: %v", err)
|
||||
subscribeToUpdate := true
|
||||
if s.accResolver != nil {
|
||||
subscribeToUpdate = !s.accResolver.IsTrackingUpdate()
|
||||
}
|
||||
if subscribeToUpdate {
|
||||
subject = fmt.Sprintf(accUpdateEventSubj, "*")
|
||||
if _, err := s.sysSubscribe(subject, s.accountClaimUpdate); err != nil {
|
||||
s.Errorf("Error setting up internal tracking: %v", err)
|
||||
}
|
||||
}
|
||||
// Listen for requests for our statsz.
|
||||
subject = fmt.Sprintf(serverStatsReqSubj, s.info.ID)
|
||||
@@ -647,7 +653,6 @@ func (s *Server) initEventTracking() {
|
||||
if _, err := s.sysSubscribe(subject, s.remoteLatencyUpdate); err != nil {
|
||||
s.Errorf("Error setting up internal latency tracking: %v", err)
|
||||
}
|
||||
|
||||
// This is for simple debugging of number of subscribers that exist in the system.
|
||||
if _, err := s.sysSubscribeInternal(accSubsSubj, s.debugSubscribers); err != nil {
|
||||
s.Errorf("Error setting up internal debug service for subscribers: %v", err)
|
||||
@@ -1239,17 +1244,22 @@ func (s *Server) sendAuthErrorEvent(c *client) {
|
||||
// required to be copied.
|
||||
type msgHandler func(sub *subscription, client *client, subject, reply string, msg []byte)
|
||||
|
||||
// Create an internal subscription. No support for queue groups atm.
|
||||
// Create an internal subscription. sysSubscribeQ for queue groups.
|
||||
func (s *Server) sysSubscribe(subject string, cb msgHandler) (*subscription, error) {
|
||||
return s.systemSubscribe(subject, false, cb)
|
||||
return s.systemSubscribe(subject, "", false, cb)
|
||||
}
|
||||
|
||||
// Create an internal subscription with queue
|
||||
func (s *Server) sysSubscribeQ(subject, queue string, cb msgHandler) (*subscription, error) {
|
||||
return s.systemSubscribe(subject, queue, false, cb)
|
||||
}
|
||||
|
||||
// Create an internal subscription but do not forward interest.
|
||||
func (s *Server) sysSubscribeInternal(subject string, cb msgHandler) (*subscription, error) {
|
||||
return s.systemSubscribe(subject, true, cb)
|
||||
return s.systemSubscribe(subject, "", true, cb)
|
||||
}
|
||||
|
||||
func (s *Server) systemSubscribe(subject string, internalOnly bool, cb msgHandler) (*subscription, error) {
|
||||
func (s *Server) systemSubscribe(subject, queue string, internalOnly bool, cb msgHandler) (*subscription, error) {
|
||||
if !s.eventsEnabled() {
|
||||
return nil, ErrNoSysAccount
|
||||
}
|
||||
@@ -1263,12 +1273,16 @@ func (s *Server) systemSubscribe(subject string, internalOnly bool, cb msgHandle
|
||||
sid := strconv.Itoa(s.sys.sid)
|
||||
s.mu.Unlock()
|
||||
|
||||
// Now create the subscription
|
||||
if trace {
|
||||
c.traceInOp("SUB", []byte(subject+" "+sid))
|
||||
c.traceInOp("SUB", []byte(subject+" "+queue+" "+sid))
|
||||
}
|
||||
|
||||
// Now create the subscription
|
||||
return c.processSub([]byte(subject), nil, []byte(sid), cb, internalOnly)
|
||||
var q []byte
|
||||
if queue != "" {
|
||||
q = []byte(queue)
|
||||
}
|
||||
return c.processSub([]byte(subject), q, []byte(sid), cb, internalOnly)
|
||||
}
|
||||
|
||||
func (s *Server) sysUnsubscribe(sub *subscription) {
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -2868,3 +2869,319 @@ func TestExpiredUserCredentialsRenewal(t *testing.T) {
|
||||
t.Fatalf("Expected lastErr to be cleared, got %q", nc.LastError())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountNATSResolverFetch(t *testing.T) {
|
||||
require_NextMsg := func(sub *nats.Subscription) bool {
|
||||
msg := natsNexMsg(t, sub, 10*time.Second)
|
||||
content := make(map[string]interface{})
|
||||
json.Unmarshal(msg.Data, &content)
|
||||
if _, ok := content["data"]; ok {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
require_FileAbsent := func(dir string, pub string) {
|
||||
t.Helper()
|
||||
_, err := os.Stat(filepath.Join(dir, pub+".jwt"))
|
||||
require_Error(t, err)
|
||||
require_True(t, os.IsNotExist(err))
|
||||
}
|
||||
require_FilePresent := func(dir string, pub string) {
|
||||
t.Helper()
|
||||
_, err := os.Stat(filepath.Join(dir, pub+".jwt"))
|
||||
require_NoError(t, err)
|
||||
}
|
||||
require_FileEqual := func(dir string, pub string, jwt string) {
|
||||
t.Helper()
|
||||
content, err := ioutil.ReadFile(filepath.Join(dir, pub+".jwt"))
|
||||
require_NoError(t, err)
|
||||
require_Equal(t, string(content), jwt)
|
||||
}
|
||||
require_1Connection := func(url string, creds string) {
|
||||
t.Helper()
|
||||
c := natsConnect(t, url, nats.UserCredentials(creds))
|
||||
defer c.Close()
|
||||
if _, err := nats.Connect(url, nats.UserCredentials(creds)); err == nil {
|
||||
t.Fatal("Second connection was supposed to fail due to limits")
|
||||
} else if !strings.Contains(err.Error(), ErrTooManyAccountConnections.Error()) {
|
||||
t.Fatal("Second connection was supposed to fail with too many conns")
|
||||
}
|
||||
}
|
||||
require_2Connection := func(url string, creds string) {
|
||||
t.Helper()
|
||||
c1 := natsConnect(t, url, nats.UserCredentials(creds))
|
||||
defer c1.Close()
|
||||
c2 := natsConnect(t, url, nats.UserCredentials(creds))
|
||||
defer c2.Close()
|
||||
if _, err := nats.Connect(url, nats.UserCredentials(creds)); err == nil {
|
||||
t.Fatal("Third connection was supposed to fail due to limits")
|
||||
} else if !strings.Contains(err.Error(), ErrTooManyAccountConnections.Error()) {
|
||||
t.Fatal("Third connection was supposed to fail with too many conns")
|
||||
}
|
||||
}
|
||||
writeFile := func(dir string, pub string, jwt string) {
|
||||
t.Helper()
|
||||
err := ioutil.WriteFile(filepath.Join(dir, pub+".jwt"), []byte(jwt), 0644)
|
||||
require_NoError(t, err)
|
||||
}
|
||||
createDir := func(prefix string) string {
|
||||
t.Helper()
|
||||
dir, err := ioutil.TempDir("", prefix)
|
||||
require_NoError(t, err)
|
||||
return dir
|
||||
}
|
||||
connect := func(url string, credsfile string) {
|
||||
t.Helper()
|
||||
nc := natsConnect(t, url, nats.UserCredentials(credsfile))
|
||||
nc.Close()
|
||||
}
|
||||
createAccountAndUser := func(pair nkeys.KeyPair, limit bool) (string, string, string, string) {
|
||||
t.Helper()
|
||||
kp, _ := nkeys.CreateAccount()
|
||||
pub, _ := kp.PublicKey()
|
||||
claim := jwt.NewAccountClaims(pub)
|
||||
if limit {
|
||||
claim.Limits.Conn = 1
|
||||
}
|
||||
jwt1, err := claim.Encode(pair)
|
||||
require_NoError(t, err)
|
||||
time.Sleep(2 * time.Second)
|
||||
// create updated claim allowing more connections
|
||||
if limit {
|
||||
claim.Limits.Conn = 2
|
||||
}
|
||||
jwt2, err := claim.Encode(pair)
|
||||
require_NoError(t, err)
|
||||
ukp, _ := nkeys.CreateUser()
|
||||
seed, _ := ukp.Seed()
|
||||
upub, _ := ukp.PublicKey()
|
||||
uclaim := newJWTTestUserClaims()
|
||||
uclaim.Subject = upub
|
||||
ujwt, err := uclaim.Encode(kp)
|
||||
require_NoError(t, err)
|
||||
return pub, jwt1, jwt2, genCredsFile(t, ujwt, seed)
|
||||
}
|
||||
updateJwt := func(url string, creds string, pubKey string, jwt string) int {
|
||||
t.Helper()
|
||||
c := natsConnect(t, url, nats.UserCredentials(creds),
|
||||
nats.DisconnectErrHandler(func(_ *nats.Conn, err error) {
|
||||
if err != nil {
|
||||
t.Fatal("error not expected in this test", err)
|
||||
}
|
||||
}),
|
||||
nats.ErrorHandler(func(_ *nats.Conn, _ *nats.Subscription, err error) {
|
||||
t.Fatal("error not expected in this test", err)
|
||||
}),
|
||||
)
|
||||
defer c.Close()
|
||||
resp := c.NewRespInbox()
|
||||
sub := natsSubSync(t, c, resp)
|
||||
err := sub.AutoUnsubscribe(3)
|
||||
require_NoError(t, err)
|
||||
require_NoError(t, c.PublishRequest(fmt.Sprintf(accUpdateEventSubj, pubKey), resp, []byte(jwt)))
|
||||
passCnt := 0
|
||||
if require_NextMsg(sub) {
|
||||
passCnt++
|
||||
}
|
||||
if require_NextMsg(sub) {
|
||||
passCnt++
|
||||
}
|
||||
if require_NextMsg(sub) {
|
||||
passCnt++
|
||||
}
|
||||
return passCnt
|
||||
}
|
||||
// Create Operator
|
||||
op, _ := nkeys.CreateOperator()
|
||||
opub, _ := op.PublicKey()
|
||||
oc := jwt.NewOperatorClaims(opub)
|
||||
oc.Subject = opub
|
||||
ojwt, err := oc.Encode(op)
|
||||
require_NoError(t, err)
|
||||
// Create Accounts and corresponding user creds
|
||||
syspub, sysjwt, _, sysCreds := createAccountAndUser(op, false)
|
||||
defer os.Remove(sysCreds)
|
||||
apub, ajwt1, ajwt2, aCreds := createAccountAndUser(op, true)
|
||||
defer os.Remove(aCreds)
|
||||
bpub, bjwt1, bjwt2, bCreds := createAccountAndUser(op, true)
|
||||
defer os.Remove(bCreds)
|
||||
cpub, cjwt1, cjwt2, cCreds := createAccountAndUser(op, true)
|
||||
defer os.Remove(cCreds)
|
||||
// Create one directory for each server
|
||||
dirA := createDir("srv-a")
|
||||
defer os.RemoveAll(dirA)
|
||||
dirB := createDir("srv-b")
|
||||
defer os.RemoveAll(dirB)
|
||||
dirC := createDir("srv-c")
|
||||
defer os.RemoveAll(dirC)
|
||||
// simulate a restart of the server by storing files in them
|
||||
// Server A/B will completely sync, so after startup each server
|
||||
// will contain the union off all stored/configured jwt
|
||||
// Server C will send out lookup requests for jwt it does not store itself
|
||||
writeFile(dirA, apub, ajwt1)
|
||||
writeFile(dirB, bpub, bjwt1)
|
||||
writeFile(dirC, cpub, cjwt1)
|
||||
// Create seed server A (using no_advertise to prevent fail over)
|
||||
confA := createConfFile(t, []byte(fmt.Sprintf(`
|
||||
listen: -1
|
||||
server_name: srv-A
|
||||
operator: %s
|
||||
system_account: %s
|
||||
resolver: {
|
||||
type: full
|
||||
dir: %s
|
||||
interval: "1s"
|
||||
limit: 4
|
||||
}
|
||||
resolver_preload: {
|
||||
%s: %s
|
||||
}
|
||||
cluster {
|
||||
name: clust
|
||||
listen: -1
|
||||
no_advertise: true
|
||||
}
|
||||
`, ojwt, syspub, dirA, cpub, cjwt1)))
|
||||
defer os.Remove(confA)
|
||||
sA, _ := RunServerWithConfig(confA)
|
||||
defer sA.Shutdown()
|
||||
// during startup resolver_preload causes the directory to contain data
|
||||
require_FilePresent(dirA, cpub)
|
||||
// Create Server B (using no_advertise to prevent fail over)
|
||||
confB := createConfFile(t, []byte(fmt.Sprintf(`
|
||||
listen: -1
|
||||
server_name: srv-B
|
||||
operator: %s
|
||||
system_account: %s
|
||||
resolver: {
|
||||
type: full
|
||||
dir: %s
|
||||
interval: "1s"
|
||||
limit: 4
|
||||
}
|
||||
cluster {
|
||||
name: clust
|
||||
listen: -1
|
||||
no_advertise: true
|
||||
routes [
|
||||
nats-route://localhost:%d
|
||||
]
|
||||
}
|
||||
`, ojwt, syspub, dirB, sA.opts.Cluster.Port)))
|
||||
defer os.Remove(confB)
|
||||
sB, _ := RunServerWithConfig(confB)
|
||||
// Create Server C (using no_advertise to prevent fail over)
|
||||
confC := createConfFile(t, []byte(fmt.Sprintf(`
|
||||
listen: -1
|
||||
server_name: srv-C
|
||||
operator: %s
|
||||
system_account: %s
|
||||
resolver: {
|
||||
type: cache
|
||||
dir: %s
|
||||
ttl: "10s"
|
||||
limit: 4
|
||||
}
|
||||
cluster {
|
||||
name: clust
|
||||
listen: -1
|
||||
no_advertise: true
|
||||
routes [
|
||||
nats-route://localhost:%d
|
||||
]
|
||||
}
|
||||
`, ojwt, syspub, dirC, sA.opts.Cluster.Port)))
|
||||
defer os.Remove(confC)
|
||||
sC, _ := RunServerWithConfig(confC)
|
||||
// startup cluster
|
||||
checkClusterFormed(t, sA, sB, sC)
|
||||
time.Sleep(2 * time.Second) // wait for the protocol to converge
|
||||
|
||||
// Check all accounts
|
||||
require_FilePresent(dirA, apub) // was already present on startup
|
||||
require_FilePresent(dirB, apub) // was copied from server A
|
||||
require_FileAbsent(dirC, apub)
|
||||
require_FilePresent(dirA, bpub) // was copied from server B
|
||||
require_FilePresent(dirB, bpub) // was already present on startup
|
||||
require_FileAbsent(dirC, bpub)
|
||||
require_FilePresent(dirA, cpub) // was present in preload
|
||||
require_FilePresent(dirB, cpub) // was copied from server A
|
||||
require_FilePresent(dirC, cpub) // was already present on startup
|
||||
// This is to test that connecting to it still works
|
||||
require_FileAbsent(dirA, syspub)
|
||||
require_FileAbsent(dirB, syspub)
|
||||
require_FileAbsent(dirC, syspub)
|
||||
// system account client can connect to every server
|
||||
connect(sA.ClientURL(), sysCreds)
|
||||
connect(sB.ClientURL(), sysCreds)
|
||||
connect(sC.ClientURL(), sysCreds)
|
||||
checkClusterFormed(t, sA, sB, sC)
|
||||
// upload system account and require a response from each server
|
||||
passCnt := updateJwt(sA.ClientURL(), sysCreds, syspub, sysjwt)
|
||||
require_True(t, passCnt == 3)
|
||||
require_FilePresent(dirA, syspub) // was just received
|
||||
require_FilePresent(dirB, syspub) // was just received
|
||||
require_FilePresent(dirC, syspub) // was just received
|
||||
// Only files missing are in C, which is only caching
|
||||
connect(sC.ClientURL(), aCreds)
|
||||
connect(sC.ClientURL(), bCreds)
|
||||
require_FilePresent(dirC, apub) // was looked up form A or B
|
||||
require_FilePresent(dirC, bpub) // was looked up from A or B
|
||||
|
||||
// Check limits and update jwt B connecting to server A
|
||||
for port, v := range map[string]struct{ pub, jwt, creds string }{
|
||||
sB.ClientURL(): {bpub, bjwt2, bCreds},
|
||||
sC.ClientURL(): {cpub, cjwt2, cCreds},
|
||||
} {
|
||||
require_1Connection(sA.ClientURL(), v.creds)
|
||||
require_1Connection(sB.ClientURL(), v.creds)
|
||||
require_1Connection(sC.ClientURL(), v.creds)
|
||||
checkClientsCount(t, sA, 0)
|
||||
checkClientsCount(t, sB, 0)
|
||||
checkClientsCount(t, sC, 0)
|
||||
passCnt := updateJwt(port, sysCreds, v.pub, v.jwt)
|
||||
require_True(t, passCnt == 3)
|
||||
require_2Connection(sA.ClientURL(), v.creds)
|
||||
require_2Connection(sB.ClientURL(), v.creds)
|
||||
require_2Connection(sC.ClientURL(), v.creds)
|
||||
require_FileEqual(dirA, v.pub, v.jwt)
|
||||
require_FileEqual(dirB, v.pub, v.jwt)
|
||||
require_FileEqual(dirC, v.pub, v.jwt)
|
||||
}
|
||||
|
||||
// Simulates A having missed an update
|
||||
// shutting B down as it has it will directly connect to A and connect right away
|
||||
sB.Shutdown()
|
||||
writeFile(dirB, apub, ajwt2) // this will be copied to server A
|
||||
sB, _ = RunServerWithConfig(confB)
|
||||
defer sB.Shutdown()
|
||||
checkClusterFormed(t, sA, sB, sC)
|
||||
time.Sleep(2 * time.Second) // wait for the protocol to converge, will also assure that C's cache becomes invalid
|
||||
// Restart server C. this is a workaround to force C to do a lookup in the absence of account cleanup
|
||||
sC.Shutdown()
|
||||
sC, _ = RunServerWithConfig(confC) //TODO remove this once we clean up accounts
|
||||
|
||||
require_FileEqual(dirA, apub, ajwt2) // was copied from server B
|
||||
require_FileEqual(dirB, apub, ajwt2) // was restarted with this
|
||||
require_FileEqual(dirC, apub, ajwt1) // still contains old cached value
|
||||
require_2Connection(sA.ClientURL(), aCreds)
|
||||
require_2Connection(sB.ClientURL(), aCreds)
|
||||
require_1Connection(sC.ClientURL(), aCreds)
|
||||
|
||||
// Restart server C. this is a workaround to force C to do a lookup in the absence of account cleanup
|
||||
sC.Shutdown()
|
||||
sC, _ = RunServerWithConfig(confC) //TODO remove this once we clean up accounts
|
||||
defer sC.Shutdown()
|
||||
require_FileEqual(dirC, apub, ajwt1) // still contains old cached value
|
||||
time.Sleep(time.Second * 12) // Force next connect to do a lookup
|
||||
connect(sC.ClientURL(), aCreds) // When lookup happens
|
||||
require_FileEqual(dirC, apub, ajwt2) // was looked up form A or B
|
||||
require_2Connection(sC.ClientURL(), aCreds)
|
||||
|
||||
// Test exceeding limit. For the exclusive directory resolver, limit is a stop gap measure.
|
||||
// It is not expected to be hit. When hit the administrator is supposed to take action.
|
||||
dpub, djwt1, _, dCreds := createAccountAndUser(op, true)
|
||||
defer os.Remove(dCreds)
|
||||
passCnt = updateJwt(sA.ClientURL(), sysCreds, dpub, djwt1)
|
||||
require_True(t, passCnt == 1) // Only Server C updated
|
||||
}
|
||||
|
||||
@@ -816,22 +816,16 @@ func (o *Options) processConfigFileLine(k string, v interface{}, errors *[]error
|
||||
}
|
||||
}
|
||||
case "resolver", "account_resolver", "accounts_resolver":
|
||||
// "resolver" takes precedence over value obtained from "operator".
|
||||
// Clear so that parsing errors are not silently ignored.
|
||||
o.AccountResolver = nil
|
||||
var memResolverRe = regexp.MustCompile(`(MEM|MEMORY|mem|memory)\s*`)
|
||||
var resolverRe = regexp.MustCompile(`(?:URL|url){1}(?:\({1}\s*"?([^\s"]*)"?\s*\){1})?\s*`)
|
||||
str, ok := v.(string)
|
||||
if !ok {
|
||||
err := &configErr{tk, fmt.Sprintf("error parsing operator resolver, wrong type %T", v)}
|
||||
*errors = append(*errors, err)
|
||||
return
|
||||
}
|
||||
if memResolverRe.MatchString(str) {
|
||||
o.AccountResolver = &MemAccResolver{}
|
||||
} else {
|
||||
items := resolverRe.FindStringSubmatch(str)
|
||||
if len(items) == 2 {
|
||||
switch v := v.(type) {
|
||||
case string:
|
||||
// "resolver" takes precedence over value obtained from "operator".
|
||||
// Clear so that parsing errors are not silently ignored.
|
||||
o.AccountResolver = nil
|
||||
memResolverRe := regexp.MustCompile(`(?i)(MEM|MEMORY)\s*`)
|
||||
resolverRe := regexp.MustCompile(`(?i)(?:URL){1}(?:\({1}\s*"?([^\s"]*)"?\s*\){1})?\s*`)
|
||||
if memResolverRe.MatchString(v) {
|
||||
o.AccountResolver = &MemAccResolver{}
|
||||
} else if items := resolverRe.FindStringSubmatch(v); len(items) == 2 {
|
||||
url := items[1]
|
||||
_, err := parseURL(url, "account resolver")
|
||||
if err != nil {
|
||||
@@ -846,9 +840,70 @@ func (o *Options) processConfigFileLine(k string, v interface{}, errors *[]error
|
||||
o.AccountResolver = ur
|
||||
}
|
||||
}
|
||||
case map[string]interface{}:
|
||||
dir := ""
|
||||
dirType := ""
|
||||
limit := int64(0)
|
||||
ttl := time.Duration(0)
|
||||
sync := time.Duration(0)
|
||||
var err error
|
||||
if v, ok := v["dir"]; ok {
|
||||
_, v := unwrapValue(v, <)
|
||||
dir = v.(string)
|
||||
}
|
||||
if v, ok := v["type"]; ok {
|
||||
_, v := unwrapValue(v, <)
|
||||
dirType = v.(string)
|
||||
}
|
||||
if v, ok := v["limit"]; ok {
|
||||
_, v := unwrapValue(v, <)
|
||||
limit = v.(int64)
|
||||
}
|
||||
if v, ok := v["ttl"]; ok {
|
||||
_, v := unwrapValue(v, <)
|
||||
ttl, err = time.ParseDuration(v.(string))
|
||||
}
|
||||
if v, ok := v["interval"]; err == nil && ok {
|
||||
_, v := unwrapValue(v, <)
|
||||
sync, err = time.ParseDuration(v.(string))
|
||||
}
|
||||
if err != nil {
|
||||
*errors = append(*errors, &configErr{tk, err.Error()})
|
||||
return
|
||||
}
|
||||
if dir == "" {
|
||||
*errors = append(*errors, &configErr{tk, "dir needs to point to a directory"})
|
||||
return
|
||||
}
|
||||
if info, err := os.Stat(dir); err != nil || !info.IsDir() || info.Mode().Perm()&(1<<(uint(7))) == 0 {
|
||||
info.IsDir()
|
||||
}
|
||||
var res AccountResolver
|
||||
switch strings.ToUpper(dirType) {
|
||||
case "CACHE":
|
||||
if sync != 0 {
|
||||
*errors = append(*errors, &configErr{tk, "CACHE does not accept sync"})
|
||||
}
|
||||
res, err = NewCacheDirAccResolver(dir, limit, ttl)
|
||||
case "FULL":
|
||||
if ttl != 0 {
|
||||
*errors = append(*errors, &configErr{tk, "FULL does not accept ttl"})
|
||||
}
|
||||
res, err = NewDirAccResolver(dir, limit, sync)
|
||||
}
|
||||
if err != nil {
|
||||
*errors = append(*errors, &configErr{tk, err.Error()})
|
||||
return
|
||||
}
|
||||
o.AccountResolver = res
|
||||
default:
|
||||
err := &configErr{tk, fmt.Sprintf("error parsing operator resolver, wrong type %T", v)}
|
||||
*errors = append(*errors, err)
|
||||
return
|
||||
}
|
||||
if o.AccountResolver == nil {
|
||||
err := &configErr{tk, "error parsing account resolver, should be MEM or URL(\"url\")"}
|
||||
err := &configErr{tk, "error parsing account resolver, should be MEM or " +
|
||||
" URL(\"url\") or a map containing dir and type state=[FULL|CACHE])"}
|
||||
*errors = append(*errors, err)
|
||||
}
|
||||
case "resolver_tls":
|
||||
|
||||
@@ -770,7 +770,7 @@ func imposeOrder(value interface{}) error {
|
||||
return value.AllowedOrigins[i] < value.AllowedOrigins[j]
|
||||
})
|
||||
case string, bool, int, int32, int64, time.Duration, float64, nil,
|
||||
LeafNodeOpts, ClusterOpts, *tls.Config, *URLAccResolver, *MemAccResolver, Authentication:
|
||||
LeafNodeOpts, ClusterOpts, *tls.Config, *URLAccResolver, *MemAccResolver, *DirAccResolver, *CacheDirAccResolver, Authentication:
|
||||
// explicitly skipped types
|
||||
default:
|
||||
// this will fail during unit tests
|
||||
@@ -1275,6 +1275,10 @@ func (s *Server) reloadAuthorization() {
|
||||
if checkJetStream {
|
||||
s.configAllJetStreamAccounts()
|
||||
}
|
||||
|
||||
if res := s.AccountResolver(); res != nil {
|
||||
res.Reload()
|
||||
}
|
||||
}
|
||||
|
||||
// Returns true if given client current account has changed (or user
|
||||
|
||||
@@ -360,14 +360,11 @@ func NewServer(opts *Options) (*Server, error) {
|
||||
// waiting for complete shutdown.
|
||||
s.shutdownComplete = make(chan struct{})
|
||||
|
||||
// For tracking accounts
|
||||
if err := s.configureAccounts(); err != nil {
|
||||
// Check for configured account resolvers.
|
||||
if err := s.configureResolver(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If there is an URL account resolver, do basic test to see if anyone is home.
|
||||
// Do this after configureAccounts() which calls configureResolver(), which will
|
||||
// set TLSConfig if specified.
|
||||
if ar := opts.AccountResolver; ar != nil {
|
||||
if ur, ok := ar.(*URLAccResolver); ok {
|
||||
if _, err := ur.Fetch(""); err != nil {
|
||||
@@ -375,6 +372,31 @@ func NewServer(opts *Options) (*Server, error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
// For other resolver:
|
||||
// In operator mode, when the account resolver depends on an external system and
|
||||
// the system account can't fetched, inject a temporary one.
|
||||
if ar := s.accResolver; len(opts.TrustedOperators) == 1 && ar != nil &&
|
||||
opts.SystemAccount != _EMPTY_ && opts.SystemAccount != DEFAULT_SYSTEM_ACCOUNT {
|
||||
if _, ok := ar.(*MemAccResolver); !ok {
|
||||
s.mu.Unlock()
|
||||
var a *Account
|
||||
// perform direct lookup to avoid warning trace
|
||||
if _, err := ar.Fetch(s.opts.SystemAccount); err == nil {
|
||||
a, _ = s.fetchAccount(s.opts.SystemAccount)
|
||||
}
|
||||
s.mu.Lock()
|
||||
if a == nil {
|
||||
sac := NewAccount(s.opts.SystemAccount)
|
||||
sac.Issuer = opts.TrustedOperators[0].Issuer
|
||||
s.registerAccountNoLock(sac)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For tracking accounts
|
||||
if err := s.configureAccounts(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// In local config mode, check that leafnode configuration
|
||||
// refers to account that exist.
|
||||
@@ -610,11 +632,6 @@ func (s *Server) configureAccounts() error {
|
||||
return true
|
||||
})
|
||||
|
||||
// Check for configured account resolvers.
|
||||
if err := s.configureResolver(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the system account if it was configured.
|
||||
// Otherwise create a default one.
|
||||
if opts.SystemAccount != _EMPTY_ {
|
||||
@@ -654,8 +671,8 @@ func (s *Server) configureResolver() error {
|
||||
}
|
||||
}
|
||||
if len(opts.resolverPreloads) > 0 {
|
||||
if _, ok := s.accResolver.(*MemAccResolver); !ok {
|
||||
return fmt.Errorf("resolver preloads only available for resolver type MEM")
|
||||
if s.accResolver.IsReadOnly() {
|
||||
return fmt.Errorf("resolver preloads only available for writeable resolver types MEM/DIR/CACHE_DIR")
|
||||
}
|
||||
for k, v := range opts.resolverPreloads {
|
||||
_, err := jwt.DecodeAccountClaims(v)
|
||||
@@ -1346,7 +1363,8 @@ func (s *Server) Start() {
|
||||
// Log the pid to a file
|
||||
if opts.PidFile != _EMPTY_ {
|
||||
if err := s.logPid(); err != nil {
|
||||
PrintAndDie(fmt.Sprintf("Could not write pidfile: %v\n", err))
|
||||
s.Fatalf("Could not write pidfile: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1361,6 +1379,42 @@ func (s *Server) Start() {
|
||||
s.SetDefaultSystemAccount()
|
||||
}
|
||||
|
||||
// start up resolver machinery
|
||||
if ar := s.AccountResolver(); ar != nil {
|
||||
if err := ar.Start(s); err != nil {
|
||||
s.Fatalf("Could not start resolver: %v", err)
|
||||
return
|
||||
}
|
||||
// In operator mode, when the account resolver depends on an external system and
|
||||
// the system account is the bootstrapping account, start fetching it
|
||||
if len(opts.TrustedOperators) == 1 && opts.SystemAccount != _EMPTY_ && opts.SystemAccount != DEFAULT_SYSTEM_ACCOUNT {
|
||||
_, isMemResolver := ar.(*MemAccResolver)
|
||||
if v, ok := s.accounts.Load(s.opts.SystemAccount); !isMemResolver && ok && v.(*Account).claimJWT == "" {
|
||||
s.Noticef("Using bootstrapping system account")
|
||||
s.startGoRoutine(func() {
|
||||
defer s.grWG.Done()
|
||||
t := time.NewTicker(time.Second)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.quitCh:
|
||||
return
|
||||
case <-t.C:
|
||||
if _, err := ar.Fetch(s.opts.SystemAccount); err != nil {
|
||||
continue
|
||||
}
|
||||
if _, err := s.fetchAccount(s.opts.SystemAccount); err != nil {
|
||||
continue
|
||||
}
|
||||
s.Noticef("System account fetched and updated")
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Start expiration of mapped GW replies, regardless if
|
||||
// this server is configured with gateway or not.
|
||||
s.startGWReplyMapExpiration()
|
||||
@@ -1483,6 +1537,10 @@ func (s *Server) Shutdown() {
|
||||
}
|
||||
s.Noticef("Initiating Shutdown...")
|
||||
|
||||
if s.accResolver != nil {
|
||||
s.accResolver.Close()
|
||||
}
|
||||
|
||||
opts := s.getOpts()
|
||||
|
||||
s.shutdown = true
|
||||
|
||||
Reference in New Issue
Block a user