diff --git a/server/dirstore.go b/server/dirstore.go index b0d82ea6..9bb499fd 100644 --- a/server/dirstore.go +++ b/server/dirstore.go @@ -28,6 +28,8 @@ import ( "sync" "time" + "github.com/nats-io/nkeys" + "github.com/nats-io/jwt/v2" // only used to decode, not for storage ) @@ -321,6 +323,9 @@ func (store *DirJWTStore) Merge(pack string) error { return fmt.Errorf("line in package didn't contain 2 entries: %q", line) } pubKey := split[0] + if !nkeys.IsValidPublicAccountKey(pubKey) { + return fmt.Errorf("key to merge is not a valid public account key") + } if err := store.saveIfNewer(pubKey, split[1]); err != nil { return err } @@ -370,6 +375,9 @@ func (store *DirJWTStore) pathForKey(publicKey string) string { if len(publicKey) < 2 { return _EMPTY_ } + if !nkeys.IsValidPublicKey(publicKey) { + return _EMPTY_ + } fileName := fmt.Sprintf("%s%s", publicKey, fileExtension) if store.shard { last := publicKey[len(publicKey)-2:] @@ -488,7 +496,7 @@ func (store *DirJWTStore) save(publicKey string, theJWT string) error { } // Assumes the lock is NOT held, and only updates if the jwt is new, or the one on disk is older -// returns true when the jwt changed +// When changed, invokes jwt changed callback func (store *DirJWTStore) saveIfNewer(publicKey string, theJWT string) error { if store.readonly { return fmt.Errorf("store is read-only") @@ -505,7 +513,7 @@ func (store *DirJWTStore) saveIfNewer(publicKey string, theJWT string) error { } if _, err := os.Stat(path); err == nil { if newJWT, err := jwt.DecodeGeneric(theJWT); err != nil { - // skip if it can't be decoded + return err } else if existing, err := ioutil.ReadFile(path); err != nil { return err } else if existingJWT, err := jwt.DecodeGeneric(string(existing)); err != nil { @@ -514,6 +522,10 @@ func (store *DirJWTStore) saveIfNewer(publicKey string, theJWT string) error { return nil } else if existingJWT.IssuedAt > newJWT.IssuedAt { return nil + } else if newJWT.Subject != publicKey { + return fmt.Errorf("jwt subject nkey and provided nkey do not match") + } else if existingJWT.Subject != newJWT.Subject { + return fmt.Errorf("subject of existing and new jwt do not match") } } store.Lock() diff --git a/server/dirstore_test.go b/server/dirstore_test.go index 6c4171cb..e5bceb57 100644 --- a/server/dirstore_test.go +++ b/server/dirstore_test.go @@ -30,6 +30,40 @@ import ( "github.com/nats-io/nkeys" ) +var ( + one, two, three, four = "", "", "", "" + jwt1, jwt2, jwt3, jwt4 = "", "", "", "" + op nkeys.KeyPair +) + +func init() { + op, _ = nkeys.CreateOperator() + + nkone, _ := nkeys.CreateAccount() + pub, _ := nkone.PublicKey() + one = pub + ac := jwt.NewAccountClaims(pub) + jwt1, _ = ac.Encode(op) + + nktwo, _ := nkeys.CreateAccount() + pub, _ = nktwo.PublicKey() + two = pub + ac = jwt.NewAccountClaims(pub) + jwt2, _ = ac.Encode(op) + + nkthree, _ := nkeys.CreateAccount() + pub, _ = nkthree.PublicKey() + three = pub + ac = jwt.NewAccountClaims(pub) + jwt3, _ = ac.Encode(op) + + nkfour, _ := nkeys.CreateAccount() + pub, _ = nkfour.PublicKey() + four = pub + ac = jwt.NewAccountClaims(pub) + jwt4, _ = ac.Encode(op) +} + func TestShardedDirStoreWriteAndReadonly(t *testing.T) { t.Parallel() dir := createDir(t, "jwtstore_test") @@ -38,10 +72,10 @@ func TestShardedDirStoreWriteAndReadonly(t *testing.T) { require_NoError(t, err) expected := map[string]string{ - "one": "alpha", - "two": "beta", - "three": "gamma", - "four": "delta", + one: "alpha", + two: "beta", + three: "gamma", + four: "delta", } for k, v := range expected { @@ -91,10 +125,10 @@ func TestUnshardedDirStoreWriteAndReadonly(t *testing.T) { require_NoError(t, err) expected := map[string]string{ - "one": "alpha", - "two": "beta", - "three": "gamma", - "four": "delta", + one: "alpha", + two: "beta", + three: "gamma", + four: "delta", } require_False(t, store.IsReadOnly()) @@ -172,10 +206,10 @@ func TestShardedDirStorePackMerge(t *testing.T) { require_NoError(t, err) expected := map[string]string{ - "one": "alpha", - "two": "beta", - "three": "gamma", - "four": "delta", + one: "alpha", + two: "beta", + three: "gamma", + four: "delta", } require_False(t, store.IsReadOnly()) @@ -246,10 +280,10 @@ func TestShardedToUnsharedDirStorePackMerge(t *testing.T) { require_NoError(t, err) expected := map[string]string{ - "one": "alpha", - "two": "beta", - "three": "gamma", - "four": "delta", + one: "alpha", + two: "beta", + three: "gamma", + four: "delta", } require_False(t, store.IsReadOnly()) @@ -848,10 +882,10 @@ const infDur = time.Duration(math.MaxInt64) func TestNotificationOnPack(t *testing.T) { t.Parallel() jwts := map[string]string{ - "key1": "value", - "key2": "value", - "key3": "value", - "key4": "value", + one: jwt1, + two: jwt2, + three: jwt3, + four: jwt4, } notificationChan := make(chan struct{}, len(jwts)) // set to same len so all extra will block notification := func(pubKey string) { @@ -921,12 +955,13 @@ func TestNotificationOnPackWalk(t *testing.T) { store[i] = mergeStore } for i := 0; i < iterCnt; i++ { //iterations - jwt := make(map[string]string) + jwts := make(map[string]string) for j := 0; j < keyCnt; j++ { - key := fmt.Sprintf("key%d-%d", i, j) - value := "value" - jwt[key] = value - store[0].SaveAcc(key, value) + kp, _ := nkeys.CreateAccount() + key, _ := kp.PublicKey() + ac := jwt.NewAccountClaims(key) + jwts[key], _ = ac.Encode(op) + require_NoError(t, store[0].SaveAcc(key, jwts[key])) } for j := 0; j < storeCnt-1; j++ { // stores err := store[j].PackWalk(3, func(partialPackMsg string) {