From 1ddc5bd9f614573c85b66c1c440a10ed54ae6b2e Mon Sep 17 00:00:00 2001 From: Ivan Kozlovic Date: Tue, 17 May 2022 18:06:12 -0600 Subject: [PATCH] Added consumer_replicas (similar to stream_replicas) Signed-off-by: Ivan Kozlovic --- server/mqtt.go | 7 ++ server/mqtt_test.go | 182 +++++++++++++++++++++++++++++++++++++++++++- server/opts.go | 11 +++ server/reload.go | 15 +++- 4 files changed, 212 insertions(+), 3 deletions(-) diff --git a/server/mqtt.go b/server/mqtt.go index e570da45..27fa8efd 100644 --- a/server/mqtt.go +++ b/server/mqtt.go @@ -577,6 +577,10 @@ func validateMQTTOptions(o *Options) error { if err := validatePinnedCerts(mo.TLSPinnedCerts); err != nil { return fmt.Errorf("mqtt: %v", err) } + if mo.ConsumerReplicas > 0 && mo.StreamReplicas > 0 && mo.ConsumerReplicas > mo.StreamReplicas { + return fmt.Errorf("mqtt: consumer_replicas (%v) cannot be higher than stream_replicas (%v)", + mo.ConsumerReplicas, mo.StreamReplicas) + } return nil } @@ -3598,6 +3602,9 @@ func (sess *mqttSession) processJSConsumer(c *client, subject, sid string, AckWait: ackWait, MaxAckPending: maxAckPending, } + if r := c.srv.getOpts().MQTT.ConsumerReplicas; r > 0 { + cc.Replicas = r + } if err := sess.createConsumer(cc); err != nil { c.Errorf("Unable to add JetStream consumer for subscription on %q: err=%v", subject, err) return nil, nil, err diff --git a/server/mqtt_test.go b/server/mqtt_test.go index 5969b227..85ab35d9 100644 --- a/server/mqtt_test.go +++ b/server/mqtt_test.go @@ -5557,6 +5557,8 @@ func TestMQTTStreamReplicasOverride(t *testing.T) { defer cl.shutdown() connectAndCheck := func(restarted bool) { + t.Helper() + o := cl.opts[0] mc, r := testMQTTConnectRetry(t, &mqttConnInfo{clientID: "test", cleanSess: false}, o.MQTT.Host, o.MQTT.Port, 5) defer mc.Close() @@ -5662,9 +5664,11 @@ func TestMQTTStreamReplicasInsufficientResources(t *testing.T) { defer cl.shutdown() l := &captureErrorLogger{errCh: make(chan string, 10)} + for _, s := range cl.servers { + s.SetLogger(l, false, false) + } o := cl.opts[1] - cl.servers[1].SetLogger(l, false, false) _, _, err := testMQTTConnectRetryWithError(t, &mqttConnInfo{clientID: "mqtt", cleanSess: false}, o.MQTT.Host, o.MQTT.Port, 0) if err == nil { t.Fatal("Expected to fail, did not") @@ -5680,6 +5684,182 @@ func TestMQTTStreamReplicasInsufficientResources(t *testing.T) { } } +func TestMQTTConsumerReplicasValidate(t *testing.T) { + o := testMQTTDefaultOptions() + for _, test := range []struct { + name string + sr int + cr int + err bool + }{ + {"stream replicas neg", -1, 3, false}, + {"stream replicas 0", 0, 3, false}, + {"consumer replicas neg", 0, -1, false}, + {"consumer replicas 0", -1, 0, false}, + {"consumer replicas too high", 1, 2, true}, + } { + t.Run(test.name, func(t *testing.T) { + o.MQTT.StreamReplicas = test.sr + o.MQTT.ConsumerReplicas = test.cr + err := validateMQTTOptions(o) + if test.err { + if err == nil { + t.Fatal("Expected error, did not get one") + } + if !strings.Contains(err.Error(), "cannot be higher") { + t.Fatalf("Unexpected error: %v", err) + } + // OK + return + } else if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + }) + } +} + +func TestMQTTConsumerReplicasOverride(t *testing.T) { + conf := ` + listen: 127.0.0.1:-1 + server_name: %s + jetstream: {max_mem_store: 256MB, max_file_store: 2GB, store_dir: '%s'} + + cluster { + name: %s + listen: 127.0.0.1:%d + routes = [%s] + } + + mqtt { + listen: 127.0.0.1:-1 + stream_replicas: 5 + consumer_replicas: 1 + } + + # For access to system account. + accounts { $SYS { users = [ { user: "admin", pass: "s3cr3t!" } ] } } + ` + cl := createJetStreamClusterWithTemplate(t, conf, "MQTT", 5) + defer cl.shutdown() + + connectAndCheck := func(subject string, restarted bool) { + t.Helper() + + o := cl.opts[0] + mc, r := testMQTTConnect(t, &mqttConnInfo{clientID: "test", cleanSess: false}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, restarted) + testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: "foo", qos: 1}}, []byte{1}) + + nc, js := jsClientConnect(t, cl.servers[2]) + defer nc.Close() + + for ci := range js.ConsumersInfo(mqttStreamName) { + if ci.Config.FilterSubject == mqttStreamSubjectPrefix+"foo" { + if len(ci.Cluster.Replicas) != 0 { + t.Fatalf("Expected consumer to be R1, got: %+v", ci.Cluster) + } + } else { + if len(ci.Cluster.Replicas) != 1 { + t.Fatalf("Expected consumer to be R2, got: %+v", ci.Cluster) + } + } + } + } + connectAndCheck("foo", false) + + cl.stopAll() + for _, o := range cl.opts { + o.MQTT.ConsumerReplicas = 2 + } + cl.restartAllSamePorts() + cl.waitOnStreamLeader(globalAccountName, mqttStreamName) + cl.waitOnStreamLeader(globalAccountName, mqttRetainedMsgsStreamName) + cl.waitOnStreamLeader(globalAccountName, mqttSessStreamName) + + connectAndCheck("bar", true) +} + +func TestMQTTConsumerReplicasReload(t *testing.T) { + tmpl := ` + jetstream: enabled + server_name: mqtt + mqtt { + port: -1 + consumer_replicas: %v + } + ` + conf := createConfFile(t, []byte(fmt.Sprintf(tmpl, 3))) + defer removeFile(t, conf) + s, o := RunServerWithConfig(conf) + defer testMQTTShutdownServer(s) + + l := &captureErrorLogger{errCh: make(chan string, 10)} + s.SetLogger(l, false, false) + + c, r := testMQTTConnect(t, &mqttConnInfo{clientID: "sub", cleanSess: false}, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{{filter: "foo", qos: 1}}, []byte{mqttSubAckFailure}) + + select { + case e := <-l.errCh: + if !strings.Contains(e, NewJSStreamReplicasNotSupportedError().Description) { + t.Fatalf("Expected error regarding replicas, got %v", e) + } + case <-time.After(2 * time.Second): + t.Fatalf("Did not get the error regarding replicas count") + } + + reloadUpdateConfig(t, s, conf, fmt.Sprintf(tmpl, 1)) + + testMQTTSub(t, 1, c, r, []*mqttFilter{{filter: "foo", qos: 1}}, []byte{1}) +} + +func TestMQTTConsumerReplicasExceedsParentStream(t *testing.T) { + conf := ` + listen: 127.0.0.1:-1 + server_name: %s + jetstream: {max_mem_store: 256MB, max_file_store: 2GB, store_dir: '%s'} + + cluster { + name: %s + listen: 127.0.0.1:%d + routes = [%s] + } + + mqtt { + listen: 127.0.0.1:-1 + consumer_replicas: 4 + } + + # For access to system account. + accounts { $SYS { users = [ { user: "admin", pass: "s3cr3t!" } ] } } + ` + cl := createJetStreamClusterWithTemplate(t, conf, "MQTT", 3) + defer cl.shutdown() + + l := &captureErrorLogger{errCh: make(chan string, 10)} + for _, s := range cl.servers { + s.SetLogger(l, false, false) + } + + o := cl.opts[0] + mc, r := testMQTTConnect(t, &mqttConnInfo{clientID: "test", cleanSess: false}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: "foo", qos: 1}}, []byte{mqttSubAckFailure}) + + select { + case e := <-l.errCh: + if !strings.Contains(e, NewJSConsumerReplicasExceedsStreamError().Description) { + t.Fatalf("Expected error regarding replicas exceeded parent, got %v", e) + } + case <-time.After(2 * time.Second): + t.Fatalf("Did not get the error regarding replicas count") + } +} + ////////////////////////////////////////////////////////////////////////// // // Benchmarks diff --git a/server/opts.go b/server/opts.go index 585dcb16..f82d6880 100644 --- a/server/opts.go +++ b/server/opts.go @@ -418,6 +418,15 @@ type MQTTOpts struct { // count is not modified. Use the NATS CLI to update the count if desired. StreamReplicas int + // Number of replicas for MQTT consumers. + // Negative or 0 value means that there is no override and the consumer + // will have the same replica factor that the stream it belongs to. + // If a value is specified, it will require to be lower than the stream + // replicas count (lower than StreamReplicas if specified, but also lower + // than the automatic value determined by cluster size). + // Note that existing consumers are not modified. + ConsumerReplicas int + // Timeout for the authentication process. AuthTimeout float64 @@ -4174,6 +4183,8 @@ func parseMQTT(v interface{}, o *Options, errors *[]error, warnings *[]error) er o.MQTT.JsDomain = mv.(string) case "stream_replicas": o.MQTT.StreamReplicas = int(mv.(int64)) + case "consumer_replicas": + o.MQTT.ConsumerReplicas = int(mv.(int64)) default: if !tk.IsUsedVariable() { err := &unknownConfigFieldErr{ diff --git a/server/reload.go b/server/reload.go index 3b6468fa..4f0008ec 100644 --- a/server/reload.go +++ b/server/reload.go @@ -657,6 +657,15 @@ func (o *mqttStreamReplicasReload) Apply(s *Server) { s.Noticef("Reloaded: MQTT stream_replicas = %v", o.newValue) } +type mqttConsumerReplicasReload struct { + noopOption + newValue int +} + +func (o *mqttConsumerReplicasReload) Apply(s *Server) { + s.Noticef("Reloaded: MQTT consumer_replicas = %v", o.newValue) +} + // Compares options and disconnects clients that are no longer listed in pinned certs. Lock must not be held. func (s *Server) recheckPinnedCerts(curOpts *Options, newOpts *Options) { s.mu.Lock() @@ -1188,12 +1197,13 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) { diffOpts = append(diffOpts, &mqttAckWaitReload{newValue: newValue.(MQTTOpts).AckWait}) diffOpts = append(diffOpts, &mqttMaxAckPendingReload{newValue: newValue.(MQTTOpts).MaxAckPending}) diffOpts = append(diffOpts, &mqttStreamReplicasReload{newValue: newValue.(MQTTOpts).StreamReplicas}) + diffOpts = append(diffOpts, &mqttConsumerReplicasReload{newValue: newValue.(MQTTOpts).ConsumerReplicas}) // Nil out/set to 0 the options that we allow to be reloaded so that // we only fail reload if some that we don't support are changed. tmpOld := oldValue.(MQTTOpts) tmpNew := newValue.(MQTTOpts) - tmpOld.TLSConfig, tmpOld.AckWait, tmpOld.MaxAckPending, tmpOld.StreamReplicas = nil, 0, 0, 0 - tmpNew.TLSConfig, tmpNew.AckWait, tmpNew.MaxAckPending, tmpNew.StreamReplicas = nil, 0, 0, 0 + tmpOld.TLSConfig, tmpOld.AckWait, tmpOld.MaxAckPending, tmpOld.StreamReplicas, tmpOld.ConsumerReplicas = nil, 0, 0, 0, 0 + tmpNew.TLSConfig, tmpNew.AckWait, tmpNew.MaxAckPending, tmpNew.StreamReplicas, tmpNew.ConsumerReplicas = nil, 0, 0, 0, 0 if !reflect.DeepEqual(tmpOld, tmpNew) { // See TODO(ik) note below about printing old/new values. return nil, fmt.Errorf("config reload not supported for %s: old=%v, new=%v", @@ -1202,6 +1212,7 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) { tmpNew.AckWait = newValue.(MQTTOpts).AckWait tmpNew.MaxAckPending = newValue.(MQTTOpts).MaxAckPending tmpNew.StreamReplicas = newValue.(MQTTOpts).StreamReplicas + tmpNew.ConsumerReplicas = newValue.(MQTTOpts).ConsumerReplicas case "connecterrorreports": diffOpts = append(diffOpts, &connectErrorReports{newValue: newValue.(int)}) case "reconnecterrorreports":