diff --git a/server/filestore.go b/server/filestore.go
index 618ffc0f..1f4479e5 100644
--- a/server/filestore.go
+++ b/server/filestore.go
@@ -1916,8 +1916,9 @@ func (fs *fileStore) storeRawMsg(subj string, hdr, msg []byte, seq uint64, ts in
 	}
 
 	// Per subject max check needed.
+	mmp := uint64(fs.cfg.MaxMsgsPer)
 	var psmc uint64
-	psmax := fs.cfg.MaxMsgsPer > 0 && len(subj) > 0
+	psmax := mmp > 0 && len(subj) > 0
 	if psmax {
 		if info, ok := fs.psim[subj]; ok {
 			psmc = info.total
@@ -1928,13 +1929,12 @@ func (fs *fileStore) storeRawMsg(subj string, hdr, msg []byte, seq uint64, ts in
 	// Check if we are discarding new messages when we reach the limit.
 	if fs.cfg.Discard == DiscardNew {
 		var asl bool
-		if psmax && psmc >= uint64(fs.cfg.MaxMsgsPer) {
+		if psmax && psmc >= mmp {
 			// If we are instructed to discard new per subject, this is an error.
 			if fs.cfg.DiscardNewPer {
 				return ErrMaxMsgsPerSubject
 			}
-			fseq, err = fs.firstSeqForSubj(subj)
-			if err != nil {
+			if fseq, err = fs.firstSeqForSubj(subj); err != nil {
 				return err
 			}
 			asl = true
@@ -1990,12 +1990,25 @@ func (fs *fileStore) storeRawMsg(subj string, hdr, msg []byte, seq uint64, ts in
 
 	// Enforce per message limits.
 	// We snapshotted psmc before our actual write, so >= comparison needed.
-	if psmax && psmc >= uint64(fs.cfg.MaxMsgsPer) {
+	if psmax && psmc >= mmp {
 		// We may have done this above.
 		if fseq == 0 {
 			fseq, _ = fs.firstSeqForSubj(subj)
 		}
-		fs.removeMsg(fseq, false, false)
+		if ok, _ := fs.removeMsg(fseq, false, false); ok {
+			// Make sure we are below the limit.
+			if psmc--; psmc >= mmp {
+				for info, ok := fs.psim[subj]; ok && info.total > mmp; info, ok = fs.psim[subj] {
+					if seq, _ := fs.firstSeqForSubj(subj); seq > 0 {
+						if ok, _ := fs.removeMsg(seq, false, false); !ok {
+							break
+						}
+					} else {
+						break
+					}
+				}
+			}
+		}
 	}
 
 	// Limits checks and enforcement.
@@ -5317,23 +5330,29 @@ func (mb *msgBlock) removeSeqPerSubject(subj string, seq uint64, smp *StoreMsg)
 		return
 	}
 
-	// Here what we are removing is the first message.
-	// If we only have one message left we can simply assign it to last.
+	// Only one left.
 	if ss.Msgs == 1 {
-		ss.First = ss.Last
+		if seq != ss.First {
+			ss.Last = ss.First
+		} else {
+			ss.First = ss.Last
+		}
 		return
 	}
 
+	// Recalculate first.
 	// TODO(dlc) - Might want to optimize this.
-	var smv StoreMsg
-	if smp == nil {
-		smp = &smv
-	}
-	for tseq := seq + 1; tseq <= ss.Last; tseq++ {
-		if sm, _ := mb.cacheLookup(tseq, smp); sm != nil {
-			if sm.subj == subj {
-				ss.First = tseq
-				return
+	if seq == ss.First {
+		var smv StoreMsg
+		if smp == nil {
+			smp = &smv
+		}
+		for tseq := seq + 1; tseq <= ss.Last; tseq++ {
+			if sm, _ := mb.cacheLookup(tseq, smp); sm != nil {
+				if sm.subj == subj {
+					ss.First = tseq
+					return
+				}
 			}
 		}
 	}
@@ -5443,9 +5462,6 @@ func (mb *msgBlock) loadPerSubjectInfo() ([]byte, error) {
 // Helper to make sure fss loaded if we are tracking.
 // Lock should be held
 func (mb *msgBlock) ensurePerSubjectInfoLoaded() error {
-	// Clear
-	mb.fssNeedsWrite = false
-
 	if mb.fss != nil || mb.noTrack {
 		return nil
 	}
@@ -5536,6 +5552,7 @@ func (mb *msgBlock) readPerSubjectInfo(hasLock bool) error {
 		fss[subj] = &SimpleState{Msgs: msgs, First: first, Last: last}
 	}
 	mb.fss = fss
+	mb.fssNeedsWrite = false
 
 	// Make sure we run the cache expire timer.
 	if len(mb.fss) > 0 {
@@ -5613,10 +5630,11 @@ func (mb *msgBlock) close(sync bool) {
 	}
 
 	// Check if we are tracking by subject.
-	if mb.fss != nil {
+	if len(mb.fss) > 0 && mb.fssNeedsWrite {
 		mb.writePerSubjectInfo()
-		mb.fss = nil
 	}
+	mb.fss = nil
+	mb.fssNeedsWrite = false
 
 	// Close cache
 	mb.clearCacheAndOffset()
diff --git a/server/norace_test.go b/server/norace_test.go
index 104af425..9db3c5d0 100644
--- a/server/norace_test.go
+++ b/server/norace_test.go
@@ -23,6 +23,7 @@ import (
 	"context"
 	"encoding/binary"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"math/rand"
@@ -6273,3 +6274,80 @@ func TestNoRaceJetStreamClusterConsumerInfoSpeed(t *testing.T) {
 
 	checkNumPending(toKeep)
 }
+
+func TestNoRaceJetStreamKVAccountWithServerRestarts(t *testing.T) {
+	c := createJetStreamClusterExplicit(t, "R3S", 3)
+	defer c.shutdown()
+
+	nc, js := jsClientConnect(t, c.randomServer())
+	defer nc.Close()
+
+	_, err := js.CreateKeyValue(&nats.KeyValueConfig{
+		Bucket:   "TEST",
+		Replicas: 3,
+	})
+	require_NoError(t, err)
+
+	npubs := 10_000
+	par := 8
+	iter := 2
+	nsubjs := 250
+
+	wg := sync.WaitGroup{}
+	putKeys := func() {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			nc, js := jsClientConnect(t, c.randomServer())
+			defer nc.Close()
+			kv, err := js.KeyValue("TEST")
+			require_NoError(t, err)
+
+			for i := 0; i < npubs; i++ {
+				subj := fmt.Sprintf("KEY-%d", rand.Intn(nsubjs))
+				if _, err := kv.PutString(subj, "hello"); err != nil {
+					nc, js := jsClientConnect(t, c.randomServer())
+					defer nc.Close()
+					kv, err = js.KeyValue("TEST")
+					require_NoError(t, err)
+				}
+			}
+		}()
+	}
+
+	restartServers := func() {
+		time.Sleep(2 * time.Second)
+		// Rotate through and restart the servers.
+		for _, server := range c.servers {
+			server.Shutdown()
+			restarted := c.restartServer(server)
+			checkFor(t, time.Second, 200*time.Millisecond, func() error {
+				hs := restarted.healthz(&HealthzOptions{
+					JSEnabled:    true,
+					JSServerOnly: true,
+				})
+				if hs.Error != _EMPTY_ {
+					return errors.New(hs.Error)
+				}
+				return nil
+			})
+		}
+		c.waitOnLeader()
+		c.waitOnStreamLeader(globalAccountName, "KV_TEST")
+	}
+
+	for n := 0; n < iter; n++ {
+		for i := 0; i < par; i++ {
+			putKeys()
+		}
+		restartServers()
+	}
+	wg.Wait()
+
+	nc, js = jsClientConnect(t, c.randomServer())
+	defer nc.Close()
+
+	si, err := js.StreamInfo("KV_TEST")
+	require_NoError(t, err)
+	require_True(t, si.State.NumSubjects == uint64(nsubjs))
+}
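
For reference, the core pattern the storeRawMsg hunks implement (snapshot the per-subject count, do the write, then delete the oldest messages for that subject until the count is back at the limit) can be sketched in isolation. This is a minimal toy model, not the nats-server API: toyStore and everything on it are hypothetical stand-ins for fs.psim, firstSeqForSubj, and removeMsg, under assumed semantics.

package main

import "fmt"

// toyStore is a hypothetical stand-in for the filestore: msgs plays the role
// of the stored stream, perSubj the role of fs.psim's per-subject totals.
type toyStore struct {
	maxMsgsPer uint64            // configured per-subject limit; 0 disables it
	seq        uint64            // last assigned sequence
	msgs       map[uint64]string // seq -> subject
	perSubj    map[string]uint64 // subject -> live message count
}

// firstSeqForSubj returns the lowest live sequence for subj, or 0 if none.
func (s *toyStore) firstSeqForSubj(subj string) uint64 {
	for seq := uint64(1); seq <= s.seq; seq++ {
		if s.msgs[seq] == subj {
			return seq
		}
	}
	return 0
}

// removeMsg deletes one message and keeps the per-subject count in sync.
func (s *toyStore) removeMsg(seq uint64) bool {
	subj, ok := s.msgs[seq]
	if !ok {
		return false
	}
	delete(s.msgs, seq)
	if s.perSubj[subj] <= 1 {
		delete(s.perSubj, subj)
	} else {
		s.perSubj[subj]--
	}
	return true
}

// storeMsg appends a message, then drains the oldest copies of subj while the
// subject is still over the limit: the loop shape the diff adds to
// storeRawMsg. The snapshot psmc is taken before the write, hence >=.
func (s *toyStore) storeMsg(subj string) {
	psmc := s.perSubj[subj] // snapshot before the write
	s.seq++
	s.msgs[s.seq] = subj
	s.perSubj[subj]++

	if mmp := s.maxMsgsPer; mmp > 0 && psmc >= mmp {
		for s.perSubj[subj] > mmp {
			fseq := s.firstSeqForSubj(subj)
			if fseq == 0 || !s.removeMsg(fseq) {
				break
			}
		}
	}
}

func main() {
	s := &toyStore{maxMsgsPer: 1, msgs: map[uint64]string{}, perSubj: map[string]uint64{}}
	for i := 0; i < 5; i++ {
		s.storeMsg("KEY-1")
	}
	fmt.Println(len(s.msgs), s.perSubj["KEY-1"]) // 1 1: only the newest copy survives
}

The detail the diff adds is the loop itself: removing the first sequence once is no longer assumed to be enough, because after restarts the tracked total can be more than one over the limit, so removal repeats while the total still exceeds mmp. That is exactly what the new KV test exercises by restarting servers mid-publish and then asserting NumSubjects.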
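
The removeSeqPerSubject hunk is subtler: the old code assumed the removed sequence was always ss.First. A sketch of the corrected fix-up, where SimpleState mirrors the struct in the diff but removeSeq and hasMsg are hypothetical stand-ins for the method and the mb.cacheLookup scan:

package main

import "fmt"

// SimpleState mirrors the per-subject state kept by a message block.
type SimpleState struct {
	Msgs, First, Last uint64
}

// removeSeq drops seq from ss. hasMsg reports whether a live message at tseq
// still belongs to this subject (standing in for the cacheLookup scan).
func removeSeq(ss *SimpleState, seq uint64, hasMsg func(uint64) bool) {
	ss.Msgs--

	// Only one left: point First and Last at the survivor, which is the
	// old Last if we removed First, and the old First otherwise.
	if ss.Msgs == 1 {
		if seq != ss.First {
			ss.Last = ss.First
		} else {
			ss.First = ss.Last
		}
		return
	}

	// Recalculate First, but only if the removed sequence was First.
	if seq == ss.First {
		for tseq := seq + 1; tseq <= ss.Last; tseq++ {
			if hasMsg(tseq) {
				ss.First = tseq
				return
			}
		}
	}
}

func main() {
	// Two live messages for a subject, at seqs 3 and 7; remove the newer one.
	live := map[uint64]bool{3: true, 7: true}
	delete(live, 7)
	ss := &SimpleState{Msgs: 2, First: 3, Last: 7}
	removeSeq(ss, 7, func(tseq uint64) bool { return live[tseq] })
	fmt.Println(*ss) // {1 3 3}: both pointers on the surviving seq 3
}

With the old logic, removing seq 7 here would have executed ss.First = ss.Last, leaving First and Last both pointing at the deleted sequence 7; the new seq != ss.First branch keeps them on the surviving sequence 3.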