diff --git a/server/mqtt.go b/server/mqtt.go index 27fa8efd..a371268d 100644 --- a/server/mqtt.go +++ b/server/mqtt.go @@ -3601,8 +3601,9 @@ func (sess *mqttSession) processJSConsumer(c *client, subject, sid string, FilterSubject: mqttStreamSubjectPrefix + subject, AckWait: ackWait, MaxAckPending: maxAckPending, + MemoryStorage: opts.MQTT.ConsumerMemoryStorage, } - if r := c.srv.getOpts().MQTT.ConsumerReplicas; r > 0 { + if r := opts.MQTT.ConsumerReplicas; r > 0 { cc.Replicas = r } if err := sess.createConsumer(cc); err != nil { diff --git a/server/mqtt_test.go b/server/mqtt_test.go index 85ab35d9..ba38c454 100644 --- a/server/mqtt_test.go +++ b/server/mqtt_test.go @@ -5734,6 +5734,7 @@ func TestMQTTConsumerReplicasOverride(t *testing.T) { listen: 127.0.0.1:-1 stream_replicas: 5 consumer_replicas: 1 + consumer_memory_storage: true } # For access to system account. @@ -5771,6 +5772,7 @@ func TestMQTTConsumerReplicasOverride(t *testing.T) { cl.stopAll() for _, o := range cl.opts { o.MQTT.ConsumerReplicas = 2 + o.MQTT.ConsumerMemoryStorage = false } cl.restartAllSamePorts() cl.waitOnStreamLeader(globalAccountName, mqttStreamName) @@ -5787,9 +5789,10 @@ func TestMQTTConsumerReplicasReload(t *testing.T) { mqtt { port: -1 consumer_replicas: %v + consumer_memory_storage: %s } ` - conf := createConfFile(t, []byte(fmt.Sprintf(tmpl, 3))) + conf := createConfFile(t, []byte(fmt.Sprintf(tmpl, 3, "false"))) defer removeFile(t, conf) s, o := RunServerWithConfig(conf) defer testMQTTShutdownServer(s) @@ -5811,9 +5814,27 @@ func TestMQTTConsumerReplicasReload(t *testing.T) { t.Fatalf("Did not get the error regarding replicas count") } - reloadUpdateConfig(t, s, conf, fmt.Sprintf(tmpl, 1)) + reloadUpdateConfig(t, s, conf, fmt.Sprintf(tmpl, 1, "true")) testMQTTSub(t, 1, c, r, []*mqttFilter{{filter: "foo", qos: 1}}, []byte{1}) + + mset, err := s.GlobalAccount().lookupStream(mqttStreamName) + if err != nil { + t.Fatalf("Error looking up stream: %v", err) + } + var cons *consumer + mset.mu.RLock() + for _, c := range mset.consumers { + cons = c + break + } + mset.mu.RUnlock() + cons.mu.RLock() + st := cons.store.Type() + cons.mu.RUnlock() + if st != MemoryStorage { + t.Fatalf("Expected storage %v, got %v", MemoryStorage, st) + } } func TestMQTTConsumerReplicasExceedsParentStream(t *testing.T) { diff --git a/server/opts.go b/server/opts.go index f82d6880..3a32d5ea 100644 --- a/server/opts.go +++ b/server/opts.go @@ -427,6 +427,10 @@ type MQTTOpts struct { // Note that existing consumers are not modified. ConsumerReplicas int + // Indicate if the consumers should be created with memory storage. + // Note that existing consumers are not modified. + ConsumerMemoryStorage bool + // Timeout for the authentication process. AuthTimeout float64 @@ -4185,6 +4189,8 @@ func parseMQTT(v interface{}, o *Options, errors *[]error, warnings *[]error) er o.MQTT.StreamReplicas = int(mv.(int64)) case "consumer_replicas": o.MQTT.ConsumerReplicas = int(mv.(int64)) + case "consumer_memory_storage": + o.MQTT.ConsumerMemoryStorage = mv.(bool) default: if !tk.IsUsedVariable() { err := &unknownConfigFieldErr{ diff --git a/server/reload.go b/server/reload.go index 4f0008ec..bce1e3f7 100644 --- a/server/reload.go +++ b/server/reload.go @@ -666,6 +666,15 @@ func (o *mqttConsumerReplicasReload) Apply(s *Server) { s.Noticef("Reloaded: MQTT consumer_replicas = %v", o.newValue) } +type mqttConsumerMemoryStorageReload struct { + noopOption + newValue bool +} + +func (o *mqttConsumerMemoryStorageReload) Apply(s *Server) { + s.Noticef("Reloaded: MQTT consumer_memory_storage = %v", o.newValue) +} + // Compares options and disconnects clients that are no longer listed in pinned certs. Lock must not be held. func (s *Server) recheckPinnedCerts(curOpts *Options, newOpts *Options) { s.mu.Lock() @@ -1198,12 +1207,13 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) { diffOpts = append(diffOpts, &mqttMaxAckPendingReload{newValue: newValue.(MQTTOpts).MaxAckPending}) diffOpts = append(diffOpts, &mqttStreamReplicasReload{newValue: newValue.(MQTTOpts).StreamReplicas}) diffOpts = append(diffOpts, &mqttConsumerReplicasReload{newValue: newValue.(MQTTOpts).ConsumerReplicas}) + diffOpts = append(diffOpts, &mqttConsumerMemoryStorageReload{newValue: newValue.(MQTTOpts).ConsumerMemoryStorage}) // Nil out/set to 0 the options that we allow to be reloaded so that // we only fail reload if some that we don't support are changed. tmpOld := oldValue.(MQTTOpts) tmpNew := newValue.(MQTTOpts) - tmpOld.TLSConfig, tmpOld.AckWait, tmpOld.MaxAckPending, tmpOld.StreamReplicas, tmpOld.ConsumerReplicas = nil, 0, 0, 0, 0 - tmpNew.TLSConfig, tmpNew.AckWait, tmpNew.MaxAckPending, tmpNew.StreamReplicas, tmpNew.ConsumerReplicas = nil, 0, 0, 0, 0 + tmpOld.TLSConfig, tmpOld.AckWait, tmpOld.MaxAckPending, tmpOld.StreamReplicas, tmpOld.ConsumerReplicas, tmpOld.ConsumerMemoryStorage = nil, 0, 0, 0, 0, false + tmpNew.TLSConfig, tmpNew.AckWait, tmpNew.MaxAckPending, tmpNew.StreamReplicas, tmpNew.ConsumerReplicas, tmpNew.ConsumerMemoryStorage = nil, 0, 0, 0, 0, false if !reflect.DeepEqual(tmpOld, tmpNew) { // See TODO(ik) note below about printing old/new values. return nil, fmt.Errorf("config reload not supported for %s: old=%v, new=%v", @@ -1213,6 +1223,7 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) { tmpNew.MaxAckPending = newValue.(MQTTOpts).MaxAckPending tmpNew.StreamReplicas = newValue.(MQTTOpts).StreamReplicas tmpNew.ConsumerReplicas = newValue.(MQTTOpts).ConsumerReplicas + tmpNew.ConsumerMemoryStorage = newValue.(MQTTOpts).ConsumerMemoryStorage case "connecterrorreports": diffOpts = append(diffOpts, &connectErrorReports{newValue: newValue.(int)}) case "reconnecterrorreports":