From 154d4303a90ef90ca71ecb2b7e55e8d95190472d Mon Sep 17 00:00:00 2001 From: Derek Collison Date: Fri, 24 Jul 2020 10:11:32 -0700 Subject: [PATCH] Add in consumer rate limits Signed-off-by: Derek Collison --- server/consumer.go | 42 +++++++++++++++++++++++ test/jetstream_test.go | 78 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+) diff --git a/server/consumer.go b/server/consumer.go index bf785bb3..163b6d0f 100644 --- a/server/consumer.go +++ b/server/consumer.go @@ -28,6 +28,7 @@ import ( "time" "github.com/nats-io/nuid" + "golang.org/x/time/rate" ) type ConsumerInfo struct { @@ -52,6 +53,7 @@ type ConsumerConfig struct { MaxDeliver int `json:"max_deliver,omitempty"` FilterSubject string `json:"filter_subject,omitempty"` ReplayPolicy ReplayPolicy `json:"replay_policy"` + RateLimit uint64 `json:"rate_limit_bps,omitempty"` // Bits per sec SampleFrequency string `json:"sample_freq,omitempty"` } @@ -166,6 +168,7 @@ type Consumer struct { adflr uint64 asflr uint64 dsubj string + rlimit *rate.Limiter reqSub *subscription ackSub *subscription ackReplyT string @@ -224,6 +227,9 @@ func (mset *Stream) AddConsumer(config *ConsumerConfig) (*Consumer, error) { if config.Durable == _EMPTY_ { return nil, fmt.Errorf("consumer in pull mode requires a durable name") } + if config.RateLimit > 0 { + return nil, fmt.Errorf("consumer in pull mode can not have rate limit set") + } } // Setup proper default for ack wait if we are in explicit ack mode. @@ -374,6 +380,24 @@ func (mset *Stream) AddConsumer(config *ConsumerConfig) (*Consumer, error) { } } + // Check if we have a rate limit set. + if config.RateLimit != 0 { + // TODO(dlc) - Make sane values or error if not sane? + // We are configured in bits per sec so adjust to bytes. + rl := rate.Limit(config.RateLimit / 8) + // Burst should be set to maximum msg size for this account, etc. + var burst int + if mset.config.MaxMsgSize > 0 { + burst = int(mset.config.MaxMsgSize) + } else if mset.jsa.account.limits.mpay > 0 { + burst = int(mset.jsa.account.limits.mpay) + } else { + s := mset.jsa.account.srv + burst = int(s.getOpts().MaxPayload) + } + o.rlimit = rate.NewLimiter(rl, burst) + } + // Check if we have filtered subject that is a wildcard. if config.FilterSubject != _EMPTY_ && !subjectIsLiteral(config.FilterSubject) { o.filterWC = true @@ -1218,9 +1242,27 @@ func (o *Consumer) loopAndDeliverMsgs(s *Server, a *Account) { o.mu.Lock() } } + // Track this regardless. lts = ts + // If we have a rate limit set make sure we check that here. + if o.rlimit != nil { + now := time.Now() + r := o.rlimit.ReserveN(now, len(msg)+len(hdr)+len(subj)+len(dsubj)+len(o.ackReplyT)) + delay := r.DelayFrom(now) + if delay > 0 { + qch := o.qch + o.mu.Unlock() + select { + case <-qch: + return + case <-time.After(delay): + } + o.mu.Lock() + } + } + o.deliverMsg(dsubj, subj, hdr, msg, seq, dcnt, ts) o.mu.Unlock() diff --git a/test/jetstream_test.go b/test/jetstream_test.go index 3507babb..f63cf97f 100644 --- a/test/jetstream_test.go +++ b/test/jetstream_test.go @@ -2617,6 +2617,84 @@ func TestJetStreamDeleteStreamManyConsumers(t *testing.T) { mset.Delete() } +func TestJetStreamConsumerRateLimit(t *testing.T) { + s := RunBasicJetStreamServer() + defer s.Shutdown() + + if config := s.JetStreamConfig(); config != nil { + defer os.RemoveAll(config.StoreDir) + } + + mname := "RATELIMIT" + mset, err := s.GlobalAccount().AddStream(&server.StreamConfig{Name: mname, Storage: server.FileStorage}) + if err != nil { + t.Fatalf("Unexpected error adding stream: %v", err) + } + + nc := clientConnectToServer(t, s) + defer nc.Close() + + msgSize := 128 * 1024 + msg := make([]byte, msgSize) + rand.Read(msg) + + // 10MB + totalSize := 10 * 1024 * 1024 + toSend := totalSize / msgSize + for i := 0; i < toSend; i++ { + nc.Publish(mname, msg) + } + nc.Flush() + state := mset.State() + if state.Msgs != uint64(toSend) { + t.Fatalf("Expected %d messages, got %d", toSend, state.Msgs) + } + + // 100Mbit + rateLimit := uint64(100 * 1024 * 1024) + // Make sure if you set a rate with a pull based consumer it errors. + _, err = mset.AddConsumer(&server.ConsumerConfig{Durable: "to", AckPolicy: server.AckExplicit, RateLimit: rateLimit}) + if err == nil { + t.Fatalf("Expected an error, got none") + } + + // Now create one and measure the rate delivered. + o, err := mset.AddConsumer(&server.ConsumerConfig{ + Durable: "rate", + DeliverSubject: "to", + RateLimit: rateLimit, + AckPolicy: server.AckNone}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer o.Delete() + + var received int + done := make(chan bool) + + start := time.Now() + + nc.Subscribe("to", func(m *nats.Msg) { + received++ + if received >= toSend { + done <- true + } + }) + nc.Flush() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatalf("Did not receive all the messages in time") + } + + tt := time.Since(start) + rate := float64(8*toSend*msgSize) / tt.Seconds() + if rate > float64(rateLimit)*1.25 { + t.Fatalf("Exceeded desired rate of %d mbps, got %0.f mbps", rateLimit/(1024*1024), rate/(1024*1024)) + } +} + func TestJetStreamEphemeralConsumerRecoveryAfterServerRestart(t *testing.T) { s := RunBasicJetStreamServer() defer s.Shutdown()