From 9a792482e9be48598e86590dc00ff06ccd5eb201 Mon Sep 17 00:00:00 2001 From: Tomasz Pietrek Date: Fri, 31 Mar 2023 15:56:11 +0200 Subject: [PATCH] Improve consumer with multiple filters Signed-off-by: Tomasz Pietrek --- server/consumer.go | 42 ++-- server/jetstream_consumer_test.go | 401 ++++++++++++++++++++++++++++++ 2 files changed, 423 insertions(+), 20 deletions(-) create mode 100644 server/jetstream_consumer_test.go diff --git a/server/consumer.go b/server/consumer.go index dd91b844..02fc273c 100644 --- a/server/consumer.go +++ b/server/consumer.go @@ -3290,12 +3290,10 @@ func (o *consumer) getNextMsg() (*jsPubMsg, uint64, error) { filter.pmsg = pmsg } else { pmsg.returnToPool() + pmsg = nil } if sseq >= filter.nextSeq { filter.nextSeq = sseq + 1 - if err == ErrStoreEOF { - o.updateSkipped(uint64(filter.currentSeq)) - } } // If we're sure that this filter has continuous sequence of messages, skip looking up other filters. @@ -3313,30 +3311,34 @@ func (o *consumer) getNextMsg() (*jsPubMsg, uint64, error) { // to avoid reflection. if o.subjf != nil && len(o.subjf) > 1 { sort.Slice(o.subjf, func(i, j int) bool { + if o.subjf[j].pmsg != nil && o.subjf[i].pmsg == nil { + return false + } + if o.subjf[i].pmsg != nil && o.subjf[j].pmsg == nil { + return true + } return o.subjf[j].nextSeq > o.subjf[i].nextSeq }) } - // Grab next message applicable to us. // Sort sequences first, to grab the first message. - for _, filter := range o.subjf { - // set o.sseq to the first subject sequence - if filter.nextSeq > o.sseq { - o.sseq = filter.nextSeq - } - // This means we got a message in this subject fetched. - if filter.pmsg != nil { - filter.currentSeq = filter.nextSeq - o.sseq = filter.currentSeq - returned := filter.pmsg - filter.pmsg = nil - return returned, 1, filter.err - } - if filter.err != nil { - lastErr = filter.err - } + filter := o.subjf[0] + // This means we got a message in this subject fetched. + if filter.pmsg != nil { + filter.currentSeq = filter.nextSeq + o.sseq = filter.currentSeq + returned := filter.pmsg + filter.pmsg = nil + return returned, 1, filter.err + } + if filter.err == ErrStoreEOF { + o.updateSkipped(filter.nextSeq) } + // set o.sseq to the first subject sequence + if filter.nextSeq > o.sseq { + o.sseq = filter.nextSeq + } return nil, 0, lastErr } diff --git a/server/jetstream_consumer_test.go b/server/jetstream_consumer_test.go new file mode 100644 index 00000000..d13765b5 --- /dev/null +++ b/server/jetstream_consumer_test.go @@ -0,0 +1,401 @@ +// Copyright 2022-2023 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !skip_js_tests +// +build !skip_js_tests + +package server + +import ( + "fmt" + "math/rand" + "sort" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/nats-io/nats.go" +) + +func TestJetStreamConsumerMultipleFiltersRace(t *testing.T) { + s := RunBasicJetStreamServer(t) + defer s.Shutdown() + + nc, js := jsClientConnect(t, s) + defer nc.Close() + acc := s.GlobalAccount() + + mset, err := acc.addStream(&StreamConfig{ + Name: "TEST", + Retention: LimitsPolicy, + Subjects: []string{"one", "two", "three", "four"}, + MaxAge: time.Second * 90, + }) + require_NoError(t, err) + + var seqs []uint64 + var mu sync.Mutex + + for i := 0; i < 10_000; i++ { + sendStreamMsg(t, nc, "one", "data") + sendStreamMsg(t, nc, "two", "data") + sendStreamMsg(t, nc, "three", "data") + sendStreamMsg(t, nc, "four", "data") + } + + mset.addConsumer(&ConsumerConfig{ + Durable: "consumer", + FilterSubjects: []string{"one", "two", "three"}, + AckPolicy: AckExplicit, + }) + + done := make(chan struct{}) + for i := 0; i < 10; i++ { + go func(t *testing.T) { + + c, err := js.PullSubscribe("", "consumer", nats.Bind("TEST", "consumer")) + require_NoError(t, err) + + for { + select { + case <-done: + return + default: + } + msgs, err := c.Fetch(10) + // We don't want to stop before at expected number of messages, as we want + // to also test against getting to many messages. + // Because of that, we ignore timeout and connection closed errors. + if err != nil && err != nats.ErrTimeout && err != nats.ErrConnectionClosed { + t.Errorf("error while fetching messages: %v", err) + } + + for _, msg := range msgs { + info, err := msg.Metadata() + require_NoError(t, err) + mu.Lock() + seqs = append(seqs, info.Sequence.Consumer) + mu.Unlock() + msg.Ack() + } + } + }(t) + } + + checkFor(t, time.Second*30, time.Second*1, func() error { + mu.Lock() + defer mu.Unlock() + if len(seqs) != 30_000 { + return fmt.Errorf("found %d messages instead of %d", len(seqs), 30_000) + } + sort.Slice(seqs, func(i, j int) bool { + return seqs[i] < seqs[j] + }) + for i := 1; i < len(seqs); i++ { + if seqs[i] != seqs[i-1]+1 { + fmt.Printf("seqs: %+v\n", seqs) + return fmt.Errorf("sequence mismatch at %v", i) + } + } + + return nil + }) + + close(done) + +} + +func TestJetStreamConsumerMultipleConsumersSingleFilter(t *testing.T) { + s := RunBasicJetStreamServer(t) + defer s.Shutdown() + + nc, js := jsClientConnect(t, s) + defer nc.Close() + acc := s.GlobalAccount() + + // Setup few subjects with varying messages count. + subjects := []struct { + subject string + messages int + wc bool + }{ + {subject: "one", messages: 5000}, + {subject: "two", messages: 7500}, + {subject: "three", messages: 2500}, + {subject: "four", messages: 1000}, + {subject: "five.>", messages: 3000, wc: true}, + } + + totalMsgs := 0 + for _, subject := range subjects { + totalMsgs += subject.messages + } + + // Setup consumers, filtering some of the messages from the stream. + consumers := []*struct { + name string + subjects []string + expectedMsgs int + delivered atomic.Int32 + }{ + {name: "C1", subjects: []string{"one"}, expectedMsgs: 5000}, + {name: "C2", subjects: []string{"two"}, expectedMsgs: 7500}, + {name: "C3", subjects: []string{"one"}, expectedMsgs: 5000}, + {name: "C4", subjects: []string{"one"}, expectedMsgs: 5000}, + } + + mset, err := acc.addStream(&StreamConfig{ + Name: "TEST", + Retention: LimitsPolicy, + Subjects: []string{"one", "two", "three", "four", "five.>"}, + MaxAge: time.Second * 90, + }) + require_NoError(t, err) + + for c, consumer := range consumers { + _, err := mset.addConsumer(&ConsumerConfig{ + Durable: consumer.name, + FilterSubjects: consumer.subjects, + AckPolicy: AckExplicit, + DeliverPolicy: DeliverAll, + AckWait: time.Second * 30, + DeliverSubject: nc.NewInbox(), + }) + require_NoError(t, err) + go func(c int, name string) { + _, err = js.Subscribe("", func(m *nats.Msg) { + require_NoError(t, m.Ack()) + require_NoError(t, err) + consumers[c].delivered.Add(1) + + }, nats.Bind("TEST", name)) + require_NoError(t, err) + }(c, consumer.name) + } + + // Publish with random intervals, while consumers are active. + rand.Seed(time.Now().UnixNano()) + var wg sync.WaitGroup + for _, subject := range subjects { + wg.Add(subject.messages) + go func(subject string, messages int, wc bool) { + nc, js := jsClientConnect(t, s) + defer nc.Close() + time.Sleep(time.Duration(rand.Int63n(1000)+1) * time.Millisecond) + for i := 0; i < messages; i++ { + time.Sleep(time.Duration(rand.Int63n(1000)+1) * time.Microsecond) + // If subject has wildcard, add random last subject token. + pubSubject := subject + if wc { + pubSubject = fmt.Sprintf("%v.%v", subject, rand.Int63n(10)) + } + _, err := js.PublishAsync(pubSubject, []byte("data")) + require_NoError(t, err) + wg.Done() + } + }(subject.subject, subject.messages, subject.wc) + } + wg.Wait() + + checkFor(t, time.Second*10, time.Millisecond*500, func() error { + for _, consumer := range consumers { + info, err := js.ConsumerInfo("TEST", consumer.name) + require_NoError(t, err) + if info.Delivered.Consumer != uint64(consumer.expectedMsgs) { + return fmt.Errorf("%v:expected consumer delivered seq %v, got %v. actually delivered: %v", consumer.name, consumer.expectedMsgs, info.Delivered.Consumer, consumer.delivered.Load()) + } + if info.AckFloor.Consumer != uint64(consumer.expectedMsgs) { + return fmt.Errorf("%v: expected consumer ack floor %v, got %v", consumer.name, totalMsgs, info.AckFloor.Consumer) + } + if consumer.delivered.Load() != int32(consumer.expectedMsgs) { + + return fmt.Errorf("%v: expected %v, got %v", consumer.name, consumer.expectedMsgs, consumer.delivered.Load()) + } + } + return nil + }) + +} + +func TestJetStreamConsumerMultipleConsumersMultipleFilters(t *testing.T) { + s := RunBasicJetStreamServer(t) + defer s.Shutdown() + + nc, js := jsClientConnect(t, s) + defer nc.Close() + acc := s.GlobalAccount() + + // Setup few subjects with varying messages count. + subjects := []struct { + subject string + messages int + wc bool + }{ + {subject: "one", messages: 50}, + {subject: "two", messages: 75}, + {subject: "three", messages: 250}, + {subject: "four", messages: 10}, + {subject: "five.>", messages: 300, wc: true}, + } + + totalMsgs := 0 + for _, subject := range subjects { + totalMsgs += subject.messages + } + + // Setup consumers, filtering some of the messages from the stream. + consumers := []*struct { + name string + subjects []string + expectedMsgs int + delivered atomic.Int32 + }{ + {name: "C1", subjects: []string{"one", "two"}, expectedMsgs: 125}, + {name: "C2", subjects: []string{"two", "three"}, expectedMsgs: 325}, + {name: "C3", subjects: []string{"one", "three"}, expectedMsgs: 300}, + {name: "C4", subjects: []string{"one", "five.>"}, expectedMsgs: 350}, + } + + mset, err := acc.addStream(&StreamConfig{ + Name: "TEST", + Retention: LimitsPolicy, + Subjects: []string{"one", "two", "three", "four", "five.>"}, + MaxAge: time.Second * 90, + }) + require_NoError(t, err) + + for c, consumer := range consumers { + _, err := mset.addConsumer(&ConsumerConfig{ + Durable: consumer.name, + FilterSubjects: consumer.subjects, + AckPolicy: AckExplicit, + DeliverPolicy: DeliverAll, + AckWait: time.Second * 30, + DeliverSubject: nc.NewInbox(), + }) + require_NoError(t, err) + go func(c int, name string) { + _, err = js.Subscribe("", func(m *nats.Msg) { + require_NoError(t, m.Ack()) + require_NoError(t, err) + consumers[c].delivered.Add(1) + + }, nats.Bind("TEST", name)) + require_NoError(t, err) + }(c, consumer.name) + } + + // Publish with random intervals, while consumers are active. + rand.Seed(time.Now().UnixNano()) + var wg sync.WaitGroup + for _, subject := range subjects { + wg.Add(subject.messages) + go func(subject string, messages int, wc bool) { + nc, js := jsClientConnect(t, s) + defer nc.Close() + time.Sleep(time.Duration(rand.Int63n(1000)+1) * time.Millisecond) + for i := 0; i < messages; i++ { + time.Sleep(time.Duration(rand.Int63n(1000)+1) * time.Microsecond) + // If subject has wildcard, add random last subject token. + pubSubject := subject + if wc { + pubSubject = fmt.Sprintf("%v.%v", subject, rand.Int63n(10)) + } + ack, err := js.PublishAsync(pubSubject, []byte("data")) + require_NoError(t, err) + go func() { + ack.Ok() + wg.Done() + }() + } + }(subject.subject, subject.messages, subject.wc) + } + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-time.After(time.Second * 15): + t.Fatalf("Timed out waiting for acks") + case <-done: + } + wg.Wait() + + checkFor(t, time.Second*15, time.Second*1, func() error { + for _, consumer := range consumers { + info, err := js.ConsumerInfo("TEST", consumer.name) + require_NoError(t, err) + if info.Delivered.Consumer != uint64(consumer.expectedMsgs) { + return fmt.Errorf("%v:expected consumer delivered seq %v, got %v. actually delivered: %v", consumer.name, consumer.expectedMsgs, info.Delivered.Consumer, consumer.delivered.Load()) + } + if info.AckFloor.Consumer != uint64(consumer.expectedMsgs) { + return fmt.Errorf("%v: expected consumer ack floor %v, got %v", consumer.name, totalMsgs, info.AckFloor.Consumer) + } + if consumer.delivered.Load() != int32(consumer.expectedMsgs) { + + return fmt.Errorf("%v: expected %v, got %v", consumer.name, consumer.expectedMsgs, consumer.delivered.Load()) + } + } + return nil + }) + +} + +func TestJetStreamConsumerMultipleFiltersSequence(t *testing.T) { + s := RunBasicJetStreamServer(t) + defer s.Shutdown() + + nc, js := jsClientConnect(t, s) + defer nc.Close() + acc := s.GlobalAccount() + + mset, err := acc.addStream(&StreamConfig{ + Name: "TEST", + Retention: LimitsPolicy, + Subjects: []string{"one", "two", "three", "four", "five.>"}, + MaxAge: time.Second * 90, + }) + require_NoError(t, err) + + _, err = mset.addConsumer(&ConsumerConfig{ + Durable: "DUR", + FilterSubjects: []string{"one", "two"}, + AckPolicy: AckExplicit, + DeliverPolicy: DeliverAll, + AckWait: time.Second * 30, + DeliverSubject: nc.NewInbox(), + }) + require_NoError(t, err) + + for i := 0; i < 20; i++ { + sendStreamMsg(t, nc, "one", fmt.Sprintf("%d", i)) + } + for i := 20; i < 40; i++ { + sendStreamMsg(t, nc, "two", fmt.Sprintf("%d", i)) + } + for i := 40; i < 60; i++ { + sendStreamMsg(t, nc, "one", fmt.Sprintf("%d", i)) + } + + sub, err := js.SubscribeSync("", nats.Bind("TEST", "DUR")) + require_NoError(t, err) + + for i := 0; i < 60; i++ { + msg, err := sub.NextMsg(time.Second * 1) + require_NoError(t, err) + require_True(t, string(msg.Data) == fmt.Sprintf("%d", i)) + } +}