diff --git a/server/mqtt.go b/server/mqtt.go index 5584b27a..018d3577 100644 --- a/server/mqtt.go +++ b/server/mqtt.go @@ -112,7 +112,7 @@ const ( // Stream name for MQTT retained messages on a given account mqttRetainedMsgsStreamName = mqttStreamNamePrefix + "rmsgs" - mqttRetainedMsgsStreamSubject = "$MQTT.rmsgs" + mqttRetainedMsgsStreamSubject = "$MQTT.rmsgs." // Stream name for MQTT sessions on a given account mqttSessStreamName = mqttStreamNamePrefix + "sess" @@ -145,6 +145,7 @@ const ( mqttJSAIdTokenPos = 3 mqttJSATokenPos = 4 mqttJSAStreamCreate = "SC" + mqttJSAStreamUpdate = "SU" mqttJSAStreamLookup = "SL" mqttJSAStreamDel = "SD" mqttJSAConsumerCreate = "CC" @@ -1150,11 +1151,12 @@ func (s *Server) mqttCreateAccountSessionManager(acc *Account, quitCh chan struc } else if si == nil { // Create the stream for retained messages. cfg := &StreamConfig{ - Name: mqttRetainedMsgsStreamName, - Subjects: []string{mqttRetainedMsgsStreamSubject}, - Storage: FileStorage, - Retention: LimitsPolicy, - Replicas: replicas, + Name: mqttRetainedMsgsStreamName, + Subjects: []string{mqttRetainedMsgsStreamSubject + as.domainTk + ">"}, + Storage: FileStorage, + Retention: LimitsPolicy, + Replicas: replicas, + MaxMsgsPer: 1, } // We will need "si" outside of this block. si, _, err = jsa.createStream(cfg) @@ -1170,6 +1172,39 @@ func (s *Server) mqttCreateAccountSessionManager(acc *Account, quitCh chan struc } } } + // Doing this check outside of above if/else due to possible race when + // creating the stream. + wantedSubj := mqttRetainedMsgsStreamSubject + as.domainTk + ">" + if len(si.Config.Subjects) != 1 || si.Config.Subjects[0] != wantedSubj { + // Update only the Subjects at this stage, not MaxMsgsPer yet. + si.Config.Subjects = []string{wantedSubj} + if si, err = jsa.updateStream(&si.Config); err != nil { + return nil, fmt.Errorf("failed to update stream config: %w", err) + } + } + // Try to transfer regardless if we have already updated the stream or not + // in case not all messages were transferred and the server was restarted. + if as.transferRetainedToPerKeySubjectStream(s) { + // We need another lookup to have up-to-date si.State values in order + // to load all retained messages. + si, err = lookupStream(mqttRetainedMsgsStreamName, "retained messages") + if err != nil { + return nil, err + } + } + // Now, if the stream does not have MaxMsgsPer set to 1, and there are no + // more messages on the single $MQTT.rmsgs subject, update the stream again. + if si.Config.MaxMsgsPer != 1 { + _, err := jsa.loadNextMsgFor(mqttRetainedMsgsStreamName, "$MQTT.rmsgs") + // Looking for an error indicated that there is no such message. + if err != nil && IsNatsErr(err, JSNoMessageFoundErr) { + si.Config.MaxMsgsPer = 1 + // We will need an up-to-date si, so don't use local variable here. + if si, err = jsa.updateStream(&si.Config); err != nil { + return nil, fmt.Errorf("failed to update stream config: %w", err) + } + } + } var lastSeq uint64 var rmDoneCh chan struct{} @@ -1199,7 +1234,7 @@ func (s *Server) mqttCreateAccountSessionManager(acc *Account, quitCh chan struc Stream: mqttRetainedMsgsStreamName, Config: ConsumerConfig{ Durable: rmDurName, - FilterSubject: mqttRetainedMsgsStreamSubject, + FilterSubject: mqttRetainedMsgsStreamSubject + as.domainTk + ">", DeliverSubject: rmsubj, ReplayPolicy: ReplayInstant, AckPolicy: AckNone, @@ -1353,6 +1388,19 @@ func (jsa *mqttJSA) createStream(cfg *StreamConfig) (*StreamInfo, bool, error) { return scr.StreamInfo, scr.DidCreate, scr.ToError() } +func (jsa *mqttJSA) updateStream(cfg *StreamConfig) (*StreamInfo, error) { + cfgb, err := json.Marshal(cfg) + if err != nil { + return nil, err + } + scri, err := jsa.newRequest(mqttJSAStreamUpdate, fmt.Sprintf(JSApiStreamUpdateT, cfg.Name), 0, cfgb) + if err != nil { + return nil, err + } + scr := scri.(*JSApiStreamUpdateResponse) + return scr.StreamInfo, scr.ToError() +} + func (jsa *mqttJSA) lookupStream(name string) (*StreamInfo, error) { slri, err := jsa.newRequest(mqttJSAStreamLookup, fmt.Sprintf(JSApiStreamInfoT, name), 0, nil) if err != nil { @@ -1385,6 +1433,20 @@ func (jsa *mqttJSA) loadLastMsgFor(streamName string, subject string) (*StoredMs return lmr.Message, lmr.ToError() } +func (jsa *mqttJSA) loadNextMsgFor(streamName string, subject string) (*StoredMsg, error) { + mreq := &JSApiMsgGetRequest{NextFor: subject} + req, err := json.Marshal(mreq) + if err != nil { + return nil, err + } + lmri, err := jsa.newRequest(mqttJSAMsgLoad, fmt.Sprintf(JSApiMsgGetT, streamName), 0, req) + if err != nil { + return nil, err + } + lmr := lmri.(*JSApiMsgGetResponse) + return lmr.Message, lmr.ToError() +} + func (jsa *mqttJSA) loadMsg(streamName string, seq uint64) (*StoredMsg, error) { mreq := &JSApiMsgGetRequest{Seq: seq} req, err := json.Marshal(mreq) @@ -1464,6 +1526,12 @@ func (as *mqttAccountSessionManager) processJSAPIReplies(_ *subscription, pc *cl resp.Error = NewJSInvalidJSONError() } ch <- resp + case mqttJSAStreamUpdate: + var resp = &JSApiStreamUpdateResponse{} + if err := json.Unmarshal(msg, resp); err != nil { + resp.Error = NewJSInvalidJSONError() + } + ch <- resp case mqttJSAStreamLookup: var resp = &JSApiStreamInfoResponse{} if err := json.Unmarshal(msg, &resp); err != nil { @@ -2260,6 +2328,56 @@ func (as *mqttAccountSessionManager) transferUniqueSessStreamsToMuxed(log *Serve retry = false } +func (as *mqttAccountSessionManager) transferRetainedToPerKeySubjectStream(log *Server) bool { + jsa := &as.jsa + var count, errors int + + for { + // Try and look up messages on the original undivided "$MQTT.rmsgs" subject. + // If nothing is returned here, we assume to have migrated all old messages. + smsg, err := jsa.loadNextMsgFor(mqttRetainedMsgsStreamName, "$MQTT.rmsgs") + if err != nil { + if IsNatsErr(err, JSNoMessageFoundErr) { + // We've ran out of messages to transfer so give up. + break + } + log.Warnf(" Unable to load retained message with sequence %d: %s", smsg.Sequence, err) + errors++ + break + } + // Unmarshal the message so that we can obtain the subject name. + var rmsg mqttRetainedMsg + if err := json.Unmarshal(smsg.Data, &rmsg); err != nil { + log.Warnf(" Unable to unmarshal retained message with sequence %d, skipping", smsg.Sequence) + errors++ + continue + } + // Store the message again, this time with the new per-key subject. + subject := mqttRetainedMsgsStreamSubject + as.domainTk + rmsg.Subject + if _, err := jsa.storeMsgWithKind(mqttJSASessPersist, subject, 0, smsg.Data); err != nil { + log.Errorf(" Unable to transfer the retained message with sequence %d: %v", smsg.Sequence, err) + errors++ + continue + } + // Delete the original message. + if err := jsa.deleteMsg(mqttRetainedMsgsStreamName, smsg.Sequence, true); err != nil { + log.Errorf(" Unable to clean up the retained message with sequence %d: %v", smsg.Sequence, err) + errors++ + continue + } + count++ + } + if errors > 0 { + next := mqttDefaultTransferRetry + log.Warnf("Failed to transfer %d MQTT retained messages, will try again in %v", errors, next) + time.AfterFunc(next, func() { as.transferRetainedToPerKeySubjectStream(log) }) + } else if count > 0 { + log.Noticef("Transfer of %d MQTT retained messages done!", count) + } + // Signal if there was any activity (either some transferred or some errors) + return errors > 0 || count > 0 +} + ////////////////////////////////////////////////////////////////////////////// // // MQTT session related functions @@ -3092,7 +3210,7 @@ func (c *client) mqttHandlePubRetain() { Source: c.opts.Username, } rmBytes, _ := json.Marshal(rm) - smr, err := asm.jsa.storeMsg(mqttRetainedMsgsStreamSubject, -1, rmBytes) + smr, err := asm.jsa.storeMsg(mqttRetainedMsgsStreamSubject+asm.domainTk+key, -1, rmBytes) if err == nil { // Update the new sequence rm.sseq = smr.Sequence diff --git a/server/mqtt_test.go b/server/mqtt_test.go index 9c9b0b8f..31f59da9 100644 --- a/server/mqtt_test.go +++ b/server/mqtt_test.go @@ -1979,6 +1979,11 @@ func testMQTTCheckPubMsgNoAck(t testing.TB, c net.Conn, r *mqttReader, topic str } func testMQTTGetPubMsg(t testing.TB, c net.Conn, r *mqttReader, topic string, payload []byte) (byte, uint16) { + flags, pi, _ := testMQTTGetPubMsgEx(t, c, r, topic, payload) + return flags, pi +} + +func testMQTTGetPubMsgEx(t testing.TB, c net.Conn, r *mqttReader, topic string, payload []byte) (byte, uint16, string) { t.Helper() b, pl := testMQTTReadPacket(t, r) if pt := b & mqttPacketMask; pt != mqttPacketPub { @@ -1991,7 +1996,7 @@ func testMQTTGetPubMsg(t testing.TB, c net.Conn, r *mqttReader, topic string, pa if err != nil { t.Fatal(err) } - if ptopic != topic { + if topic != _EMPTY_ && ptopic != topic { t.Fatalf("Expected topic %q, got %q", topic, ptopic) } var pi uint16 @@ -2011,7 +2016,7 @@ func testMQTTGetPubMsg(t testing.TB, c net.Conn, r *mqttReader, topic string, pa t.Fatalf("Expected payload %q, got %q", payload, ppayload) } r.pos += msgLen - return pflags, pi + return pflags, pi, ptopic } func testMQTTSendPubAck(t testing.TB, c net.Conn, pi uint16) { @@ -2993,6 +2998,88 @@ func TestMQTTRetainedMsgNetworkUpdates(t *testing.T) { } } +func TestMQTTRetainedMsgMigration(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc, js := jsClientConnect(t, s) + defer nc.Close() + + // Create the retained messages stream to listen on the old subject first. + // The server will correct this when the migration takes place. + _, err := js.AddStream(&nats.StreamConfig{ + Name: mqttRetainedMsgsStreamName, + Subjects: []string{`$MQTT.rmsgs`}, + Storage: nats.FileStorage, + Retention: nats.LimitsPolicy, + Replicas: 1, + }) + require_NoError(t, err) + + // Publish some retained messages on the old "$MQTT.rmsgs" subject. + for i := 0; i < 100; i++ { + msg := fmt.Sprintf( + `{"origin":"b5IQZNtG","subject":"test%d","topic":"test%d","msg":"YmFy","flags":1}`, i, i, + ) + _, err := js.Publish(`$MQTT.rmsgs`, []byte(msg)) + require_NoError(t, err) + } + + // Check that the old subject looks right. + si, err := js.StreamInfo(mqttRetainedMsgsStreamName, &nats.StreamInfoRequest{ + SubjectsFilter: `$MQTT.>`, + }) + require_NoError(t, err) + if si.State.NumSubjects != 1 { + t.Fatalf("expected 1 subject, got %d", si.State.NumSubjects) + } + if n := si.State.Subjects[`$MQTT.rmsgs`]; n != 100 { + t.Fatalf("expected to find 100 messages on the original subject but found %d", n) + } + + // Create an MQTT client, this will cause a migration to take place. + mc, rc := testMQTTConnect(t, &mqttConnInfo{clientID: "sub", cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, rc, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mc, rc, []*mqttFilter{{filter: "+", qos: 0}}, []byte{0}) + topics := map[string]struct{}{} + for i := 0; i < 100; i++ { + _, _, topic := testMQTTGetPubMsgEx(t, mc, rc, _EMPTY_, []byte("bar")) + topics[topic] = struct{}{} + } + if len(topics) != 100 { + t.Fatalf("Unexpected topics: %v", topics) + } + + // Now look at the stream, there should be 100 messages on the new + // divided subjects and none on the old undivided subject. + si, err = js.StreamInfo(mqttRetainedMsgsStreamName, &nats.StreamInfoRequest{ + SubjectsFilter: `$MQTT.>`, + }) + require_NoError(t, err) + if si.State.NumSubjects != 100 { + t.Fatalf("expected 100 subjects, got %d", si.State.NumSubjects) + } + if n := si.State.Subjects[`$MQTT.rmsgs`]; n > 0 { + t.Fatalf("expected to find no messages on the original subject but found %d", n) + } + + // Check that the message counts look right. There should be one + // retained message per key. + for i := 0; i < 100; i++ { + expected := fmt.Sprintf(`$MQTT.rmsgs.test%d`, i) + n, ok := si.State.Subjects[expected] + if !ok { + t.Fatalf("expected to find %q but didn't", expected) + } + if n != 1 { + t.Fatalf("expected %q to have 1 message but had %d", expected, n) + } + } +} + func TestMQTTClusterReplicasCount(t *testing.T) { for _, test := range []struct { size int