diff --git a/server/filestore.go b/server/filestore.go index 07c3e9d4..20a9a4a4 100644 --- a/server/filestore.go +++ b/server/filestore.go @@ -404,6 +404,14 @@ func newFileStoreWithCreated(fcfg FileStoreConfig, cfg StreamConfig, created tim return nil, fmt.Errorf("could not create hash: %v", err) } + keyFile := filepath.Join(fs.fcfg.StoreDir, JetStreamMetaFileKey) + // Make sure we do not have an encrypted store underneath of us but no main key. + if fs.prf == nil { + if _, err := os.Stat(keyFile); err == nil { + return nil, errNoMainKey + } + } + // Recover our message state. if err := fs.recoverMsgs(); err != nil { return nil, err @@ -421,7 +429,6 @@ func newFileStoreWithCreated(fcfg FileStoreConfig, cfg StreamConfig, created tim // If we expect to be encrypted check that what we are restoring is not plaintext. // This can happen on snapshot restores or conversions. if fs.prf != nil { - keyFile := filepath.Join(fs.fcfg.StoreDir, JetStreamMetaFileKey) if _, err := os.Stat(keyFile); err != nil && os.IsNotExist(err) { if err := fs.writeStreamMeta(); err != nil { return nil, err @@ -4638,6 +4645,7 @@ var ( errMsgBlkTooBig = errors.New("message block size exceeded int capacity") errUnknownCipher = errors.New("unknown cipher") errDIOStalled = errors.New("IO is stalled") + errNoMainKey = errors.New("encrypted store encountered with no main key") ) // Used for marking messages that have had their checksums checked. diff --git a/server/filestore_test.go b/server/filestore_test.go index 0bf4328a..50a598fe 100644 --- a/server/filestore_test.go +++ b/server/filestore_test.go @@ -5604,3 +5604,41 @@ func TestFileStoreSkipMsgAndNumBlocks(t *testing.T) { fs.StoreMsg(subj, nil, msg) require_True(t, fs.numMsgBlocks() == 2) } + +func TestFileStoreRestoreEncryptedWithNoKeyFuncFails(t *testing.T) { + // No need for all permutations here. + fcfg := FileStoreConfig{StoreDir: t.TempDir(), Cipher: AES} + scfg := StreamConfig{Name: "zzz", Subjects: []string{"zzz"}, Storage: FileStorage} + + // Create at first with encryption (prf) + prf := func(context []byte) ([]byte, error) { + h := hmac.New(sha256.New, []byte("dlc22")) + if _, err := h.Write(context); err != nil { + return nil, err + } + return h.Sum(nil), nil + } + + fs, err := newFileStoreWithCreated( + fcfg, scfg, + time.Now(), + prf, + ) + require_NoError(t, err) + + subj, msg := "zzz", bytes.Repeat([]byte("X"), 100) + numMsgs := 100 + for i := 0; i < numMsgs; i++ { + fs.StoreMsg(subj, nil, msg) + } + + fs.Stop() + + // Make sure if we try to restore with no prf (key) that it fails. + _, err = newFileStoreWithCreated( + fcfg, scfg, + time.Now(), + nil, + ) + require_Error(t, err, errNoMainKey) +} diff --git a/server/jetstream_cluster.go b/server/jetstream_cluster.go index d7596547..da16a28c 100644 --- a/server/jetstream_cluster.go +++ b/server/jetstream_cluster.go @@ -759,6 +759,7 @@ func (js *jetStream) setupMetaGroup() error { s.Errorf("Error creating filestore: %v", err) return err } + // Register our server. fs.registerServer(s) diff --git a/server/jetstream_cluster_3_test.go b/server/jetstream_cluster_3_test.go index be8a6b08..57be3f03 100644 --- a/server/jetstream_cluster_3_test.go +++ b/server/jetstream_cluster_3_test.go @@ -4541,3 +4541,50 @@ func TestJetStreamBinaryStreamSnapshotCapability(t *testing.T) { t.Fatalf("Expected to signal that we could support binary stream snapshots") } } + +func TestJetStreamClusterBadEncryptKey(t *testing.T) { + c := createJetStreamClusterWithTemplate(t, jsClusterEncryptedTempl, "JSC", 3) + defer c.shutdown() + + nc, js := jsClientConnect(t, c.randomServer()) + defer nc.Close() + + // Create 10 streams. + for i := 0; i < 10; i++ { + _, err := js.AddStream(&nats.StreamConfig{ + Name: fmt.Sprintf("TEST-%d", i), + Replicas: 3, + }) + require_NoError(t, err) + } + + // Grab random server. + s := c.randomServer() + s.Shutdown() + s.WaitForShutdown() + + var opts *Options + for i := 0; i < len(c.servers); i++ { + if c.servers[i] == s { + opts = c.opts[i] + break + } + } + require_NotNil(t, opts) + + // Replace key with an empty key. + buf, err := os.ReadFile(opts.ConfigFile) + require_NoError(t, err) + nbuf := bytes.Replace(buf, []byte("key: \"s3cr3t!\""), []byte("key: \"\""), 1) + err = os.WriteFile(opts.ConfigFile, nbuf, 0640) + require_NoError(t, err) + + // Make sure trying to start the server now fails. + s, err = NewServer(LoadConfig(opts.ConfigFile)) + require_NoError(t, err) + require_NotNil(t, s) + s.Start() + if err := s.readyForConnections(1 * time.Second); err == nil { + t.Fatalf("Expected server not to start") + } +}