From ad2e9d7b8d91a42fde5503c75e14f3f2e0fc8cfc Mon Sep 17 00:00:00 2001 From: Lev Brouk Date: Mon, 28 Aug 2023 11:52:01 -0700 Subject: [PATCH] MQTT QoS2 support --- server/client.go | 2 +- server/mqtt.go | 1204 +++++++++++++++++++++++++++++++++---------- server/mqtt_test.go | 487 ++++++++++++++--- server/opts.go | 27 +- 4 files changed, 1355 insertions(+), 365 deletions(-) diff --git a/server/client.go b/server/client.go index 41997db5..a324baa4 100644 --- a/server/client.go +++ b/server/client.go @@ -4223,7 +4223,7 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt c.pa.reply = nrr if changed && c.isMqtt() && c.pa.hdr > 0 { - c.srv.mqttStoreQoS1MsgForAccountOnNewSubject(c.pa.hdr, msg, siAcc.GetName(), to) + c.srv.mqttStoreQoSMsgForAccountOnNewSubject(c.pa.hdr, msg, siAcc.GetName(), to) } // FIXME(dlc) - Do L1 cache trick like normal client? diff --git a/server/mqtt.go b/server/mqtt.go index 9c59b42b..6286fccb 100644 --- a/server/mqtt.go +++ b/server/mqtt.go @@ -68,7 +68,8 @@ const ( mqttPubFlagRetain = byte(0x01) mqttPubFlagQoS = byte(0x06) mqttPubFlagDup = byte(0x08) - mqttPubQos1 = byte(0x2) // 1 << 1 + mqttPubQos1 = byte(0x1 << 1) + mqttPubQoS2 = byte(0x2 << 1) // Subscribe flags mqttSubscribeFlags = byte(0x2) @@ -104,28 +105,34 @@ const ( // MQTT subscriptions to start with this. mqttSubPrefix = "$MQTT.sub." - // MQTT Stream names prefix. - mqttStreamNamePrefix = "$MQTT_" - // Stream name for MQTT messages on a given account - mqttStreamName = mqttStreamNamePrefix + "msgs" + mqttStreamName = "$MQTT_msgs" mqttStreamSubjectPrefix = "$MQTT.msgs." // Stream name for MQTT retained messages on a given account - mqttRetainedMsgsStreamName = mqttStreamNamePrefix + "rmsgs" + mqttRetainedMsgsStreamName = "$MQTT_rmsgs" mqttRetainedMsgsStreamSubject = "$MQTT.rmsgs." 
// Stream name for MQTT sessions on a given account - mqttSessStreamName = mqttStreamNamePrefix + "sess" + mqttSessStreamName = "$MQTT_sess" mqttSessStreamSubjectPrefix = "$MQTT.sess." // Stream name prefix for MQTT sessions on a given account - mqttSessionsStreamNamePrefix = mqttStreamNamePrefix + "sess_" + mqttSessionsStreamNamePrefix = "$MQTT_sess_" - // Normally, MQTT server should not redeliver QoS 1 messages to clients, - // except after client reconnects. However, NATS Server will redeliver - // unacknowledged messages after this default interval. This can be - // changed with the server.Options.MQTT.AckWait option. + // Stream name and subject for incoming MQTT QoS2 messages + mqttQoS2IncomingMsgsStreamName = "$MQTT_qos2in" + mqttQoS2IncomingMsgsStreamSubjectPrefix = "$MQTT.qos2.in." + + // Stream name and subjects for outgoing MQTT QoS2 PUBREL messages + mqttQoS2PubRelStreamName = "$MQTT_qos2out" + mqttQoS2PubRelStoredSubjectPrefix = "$MQTT.qos2.out." + mqttQoS2PubRelDeliverySubjectPrefix = "$MQTT.qos2.delivery." + + // As per spec, MQTT server may not redeliver QoS 1 and 2 messages to + // clients, except after client reconnects. However, NATS Server will + // redeliver unacknowledged messages after this default interval. This can + // be changed with the server.Options.MQTT.AckWait option. mqttDefaultAckWait = 30 * time.Second // This is the default for the outstanding number of pending QoS 1 @@ -159,9 +166,6 @@ const ( mqttJSARetainedMsgDel = "RD" mqttJSAStreamNames = "SN" - // Name of the header key added to NATS message to carry mqtt PUBLISH information - mqttNatsHeader = "Nmqtt-Pub" - // This is how long to keep a client in the flappers map before closing the // connection. This prevent quick reconnect from those clients that keep // wanting to connect with a client ID already in use. 
@@ -181,7 +185,6 @@ var ( mqttPingResponse = []byte{mqttPacketPingResp, 0x0} mqttProtoName = []byte("MQTT") mqttOldProtoName = []byte("MQIsdp") - mqttNatsHeaderB = []byte(mqttNatsHeader) mqttSessJailDur = mqttSessFlappingJailDur mqttFlapCleanItvl = mqttSessFlappingCleanupInterval mqttJSAPITimeout = 4 * time.Second @@ -206,6 +209,7 @@ var ( errMQTTTopicIsEmpty = errors.New("topic cannot be empty") errMQTTPacketIdentifierIsZero = errors.New("packet identifier cannot be 0") errMQTTUnsupportedCharacters = errors.New("characters ' ' and '.' not supported for MQTT topics") + errMQTTInvalidSession = errors.New("invalid MQTT session") ) type srvMQTT struct { @@ -262,17 +266,36 @@ type mqttRetMsgDel struct { } type mqttSession struct { - mu sync.Mutex - id string // client ID - idHash string // client ID hash - c *client - jsa *mqttJSA - subs map[string]byte - cons map[string]*ConsumerConfig - seq uint64 - pending map[uint16]*mqttPending // Key is the PUBLISH packet identifier sent to client and maps to a mqttPending record - cpending map[string]map[uint64]uint16 // For each JS consumer, the key is the stream sequence and maps to the PUBLISH packet identifier - ppi uint16 // publish packet identifier + mu sync.Mutex + id string // client ID + idHash string // client ID hash + c *client + jsa *mqttJSA + subs map[string]byte // Key is MQTT SUBSCRIBE filter, value is the subscription QoS + cons map[string]*ConsumerConfig + pubRelConsumer *ConsumerConfig + pubRelSubscribed bool + pubRelDeliverySubject string + pubRelDeliverySubjectB []byte + pubRelSubject string + seq uint64 + + // pendingPublish maps packet identifiers (PI) to JetStream ACK subjects for + // QoS1 and 2 PUBLISH messages pending delivery to the session's client. + pendingPublish map[uint16]*mqttPending + + // pendingPubRel maps PIs to JetStream ACK subjects for QoS2 PUBREL + // messages pending delivery to the session's client. 
+ pendingPubRel map[uint16]*mqttPending + + // cpending maps delivery attempts (that come with a JS ACK subject) to + // existing PIs. + cpending map[string]map[uint64]uint16 // composite key: jsDur, sseq + + // "Last used" publish packet identifier (PI). starting point searching for the next available. + last_pi uint16 + + // Maximum number of pending acks for this session. maxp uint16 tmaxack int clean bool @@ -285,6 +308,7 @@ type mqttPersistedSession struct { Clean bool `json:"clean,omitempty"` Subs map[string]byte `json:"subs,omitempty"` Cons map[string]*ConsumerConfig `json:"cons,omitempty"` + PubRel *ConsumerConfig `json:"pubrel,omitempty"` } type mqttRetainedMsg struct { @@ -322,9 +346,9 @@ type mqtt struct { } type mqttPending struct { - sseq uint64 // stream sequence - subject string // the ACK subject to send the ack to - jsDur string // JS durable name + sseq uint64 // stream sequence + jsAckSubject string // the ACK subject to send the ack to + jsDur string // JS durable name } type mqttConnectProto struct { @@ -376,6 +400,28 @@ type mqttPublish struct { flags byte } +// When we re-encode incoming MQTT PUBLISH messages for NATS delivery, we add +// the following headers: +// - "Nmqtt-Pub" (*always) indicates that the message originated from MQTT, and +// contains the original message QoS. +// - "Nmqtt-Subject" contains the original MQTT subject from mqttParsePub. +// - "Nmqtt-Mapped" contains the mapping during mqttParsePub. +// +// When we submit a PUBREL for delivery, we add a "Nmqtt-PubRel" header that +// contains the PI. 
+const ( + mqttNatsHeader = "Nmqtt-Pub" + mqttNatsPubRelHeader = "Nmqtt-PubRel" + mqttNatsHeaderSubject = "Nmqtt-Subject" + mqttNatsHeaderMapped = "Nmqtt-Mapped" +) + +type mqttParsedPublishNATSHeader struct { + qos byte + subject []byte + mapped []byte +} + func (s *Server) startMQTT() { sopts := s.getOpts() o := &sopts.MQTT @@ -667,6 +713,39 @@ func (c *client) mqttParse(buf []byte) error { } switch pt { + // Packets that we receive back when we act as the "sender": PUBACK, + // PUBREC, PUBCOMP. + case mqttPacketPubAck: + var pi uint16 + pi, err = mqttParsePIPacket(r, pl) + if trace { + c.traceInOp("PUBACK", errOrTrace(err, fmt.Sprintf("pi=%v", pi))) + } + if err == nil { + err = c.mqttProcessPubAck(pi) + } + + case mqttPacketPubRec: + var pi uint16 + pi, err = mqttParsePIPacket(r, pl) + if trace { + c.traceInOp("PUBREC", errOrTrace(err, fmt.Sprintf("pi=%v", pi))) + } + if err == nil { + err = c.mqttProcessPubRec(pi) + } + + case mqttPacketPubComp: + var pi uint16 + pi, err = mqttParsePIPacket(r, pl) + if trace { + c.traceInOp("PUBCOMP", errOrTrace(err, fmt.Sprintf("pi=%v", pi))) + } + if err == nil { + c.mqttProcessPubComp(pi) + } + + // Packets where we act as the "receiver": PUBLISH, PUBREL, SUBSCRIBE, UNSUBSCRIBE. 
case mqttPacketPub: pp := c.mqtt.pp pp.flags = b & mqttPacketFlagMask @@ -678,23 +757,19 @@ func (c *client) mqttParse(buf []byte) error { } } if err == nil { - err = s.mqttProcessPub(c, pp) + err = s.mqttProcessPub(c, pp, trace) } - if err == nil && pp.pi > 0 { - c.mqttEnqueuePubAck(pp.pi) - if trace { - c.traceOutOp("PUBACK", []byte(fmt.Sprintf("pi=%v", pp.pi))) - } - } - case mqttPacketPubAck: + + case mqttPacketPubRel: var pi uint16 - pi, err = mqttParsePubAck(r, pl) + pi, err = mqttParsePIPacket(r, pl) if trace { - c.traceInOp("PUBACK", errOrTrace(err, fmt.Sprintf("pi=%v", pi))) + c.traceInOp("PUBREL", errOrTrace(err, fmt.Sprintf("pi=%v", pi))) } if err == nil { - c.mqttProcessPubAck(pi) + err = s.mqttProcessPubRel(c, pi, trace) } + case mqttPacketSub: var pi uint16 // packet identifier var filters []*mqttFilter @@ -713,6 +788,7 @@ func (c *client) mqttParse(buf []byte) error { c.mqttEnqueueSubAck(pi, filters) c.mqttSendRetainedMsgsToNewSubs(subs) } + case mqttPacketUnsub: var pi uint16 // packet identifier var filters []*mqttFilter @@ -729,6 +805,8 @@ func (c *client) mqttParse(buf []byte) error { if err == nil { c.mqttEnqueueUnsubAck(pi) } + + // Packets that we get both as a receiver and sender: PING, CONNECT, DISCONNECT case mqttPacketPing: if trace { c.traceInOp("PINGREQ", nil) @@ -737,6 +815,7 @@ func (c *client) mqttParse(buf []byte) error { if trace { c.traceOutOp("PINGRESP", nil) } + case mqttPacketConnect: // It is an error to receive a second connect packet if connected { @@ -770,6 +849,7 @@ func (c *client) mqttParse(buf []byte) error { rd = cp.rd } } + case mqttPacketDisconnect: if trace { c.traceInOp("DISCONNECT", nil) @@ -784,8 +864,7 @@ func (c *client) mqttParse(buf []byte) error { s.mqttHandleClosedClient(c) c.closeConnection(ClientClosed) return nil - case mqttPacketPubRec, mqttPacketPubRel, mqttPacketPubComp: - err = fmt.Errorf("protocol %d not supported", pt>>4) + default: err = fmt.Errorf("received unknown packet type %d", pt>>4) } @@ 
-900,15 +979,12 @@ func (s *Server) mqttGetJSAForAccount(acc string) *mqttJSA { return jsa } -func (s *Server) mqttStoreQoS1MsgForAccountOnNewSubject(hdr int, msg []byte, acc, subject string) { +func (s *Server) mqttStoreQoSMsgForAccountOnNewSubject(hdr int, msg []byte, acc, subject string) { if s == nil || hdr <= 0 { return } - nhv := getHeader(mqttNatsHeader, msg[:hdr]) - if len(nhv) < 1 { - return - } - if qos := nhv[0] - '0'; qos != 1 { + h := mqttParsePublishNATSHeader(msg[:hdr]) + if h == nil || h.qos == 0 { return } jsa := s.mqttGetJSAForAccount(acc) @@ -918,6 +994,35 @@ func (s *Server) mqttStoreQoS1MsgForAccountOnNewSubject(hdr int, msg []byte, acc jsa.storeMsg(mqttStreamSubjectPrefix+subject, hdr, msg) } +func mqttParsePublishNATSHeader(headerBytes []byte) *mqttParsedPublishNATSHeader { + if len(headerBytes) == 0 { + return nil + } + + pubValue := getHeader(mqttNatsHeader, headerBytes) + if len(pubValue) == 0 { + return nil + } + return &mqttParsedPublishNATSHeader{ + qos: pubValue[0] - '0', + subject: getHeader(mqttNatsHeaderSubject, headerBytes), + mapped: getHeader(mqttNatsHeaderMapped, headerBytes), + } +} + +func mqttParsePubRelNATSHeader(headerBytes []byte) uint16 { + if len(headerBytes) == 0 { + return 0 + } + + pubrelValue := getHeader(mqttNatsPubRelHeader, headerBytes) + if len(pubrelValue) == 0 { + return 0 + } + pi, _ := strconv.Atoi(string(pubrelValue)) + return uint16(pi) +} + // Returns the MQTT sessions manager for a given account. // If new, creates the required JetStream streams/consumers // for handling of sessions and messages. @@ -1146,6 +1251,46 @@ func (s *Server) mqttCreateAccountSessionManager(acc *Account, quitCh chan struc } } + if si, err := lookupStream(mqttQoS2IncomingMsgsStreamName, "QoS2 incoming messages"); err != nil { + return nil, err + } else if si == nil { + // Create the stream for the incoming QoS2 messages that have not been + // PUBREL-ed by the sender. 
Subject is + // "$MQTT.qos2..", the .PI is to achieve exactly + // once for each PI. + cfg := &StreamConfig{ + Name: mqttQoS2IncomingMsgsStreamName, + Subjects: []string{mqttQoS2IncomingMsgsStreamSubjectPrefix + ">"}, + Storage: FileStorage, + Retention: LimitsPolicy, + Discard: DiscardNew, + MaxMsgsPer: 1, + DiscardNewPer: true, + Replicas: replicas, + } + if _, _, err := jsa.createStream(cfg); isErrorOtherThan(err, JSStreamNameExistErr) { + return nil, fmt.Errorf("create QoS2 incoming messages stream for account %q: %v", accName, err) + } + } + + if si, err := lookupStream(mqttQoS2PubRelStreamName, "QoS2 outgoing PUBREL"); err != nil { + return nil, err + } else if si == nil { + // Create the stream for the incoming QoS2 messages that have not been + // PUBREL-ed by the sender. NATS messages are submitted as + // "$MQTT.pubrel." + cfg := &StreamConfig{ + Name: mqttQoS2PubRelStreamName, + Subjects: []string{mqttQoS2PubRelStoredSubjectPrefix + ">"}, + Storage: FileStorage, + Retention: InterestPolicy, + Replicas: replicas, + } + if _, _, err := jsa.createStream(cfg); isErrorOtherThan(err, JSStreamNameExistErr) { + return nil, fmt.Errorf("create QoS2 outgoing PUBREL stream for account %q: %v", accName, err) + } + } + // This is the only case where we need "si" after lookup/create si, err := lookupStream(mqttRetainedMsgsStreamName, "retained messages") if err != nil { @@ -1348,6 +1493,16 @@ func (jsa *mqttJSA) newRequestEx(kind, subject string, hdr int, msg []byte, time return i, nil } +func (jsa *mqttJSA) sendAck(ackSubject string) { + if ackSubject == _EMPTY_ { + return + } + + // We pass -1 for the hdr so that the send loop does not need to + // add the "client info" header. This is not a JS API request per se. 
+ jsa.sendq.push(&mqttJSPubMsg{subj: ackSubject, hdr: -1}) +} + func (jsa *mqttJSA) createConsumer(cfg *CreateConsumerRequest) (*JSApiConsumerCreateResponse, error) { cfgb, err := json.Marshal(cfg) if err != nil { @@ -2068,8 +2223,8 @@ func (as *mqttAccountSessionManager) processSubs(sess *mqttSession, c *client, var err error subs := make([]*subscription, 0, len(filters)) for _, f := range filters { - if f.qos > 1 { - f.qos = 1 + if f.qos > 2 { + f.qos = 2 } subject := f.filter sid := subject @@ -2085,11 +2240,15 @@ func (as *mqttAccountSessionManager) processSubs(sess *mqttSession, c *client, // Note that if a subscription already exists on this subject, // the existing sub is returned. Need to update the qos. asAndSessLock() - sub, err := c.processSub([]byte(subject), nil, []byte(sid), mqttDeliverMsgCbQos0, false) + sub, err := c.processSub([]byte(subject), nil, []byte(sid), mqttDeliverMsgCbQoS0, false) if err == nil { setupSub(sub, f.qos) } + if f.qos == 2 { + err = sess.ensurePubRelConsumerSubscription(c) + } asAndSessUnlock() + if err == nil { // This will create (if not already exist) a JS consumer for subscriptions // of QoS >= 1. But if a JS consumer already exists and the subscription @@ -2102,6 +2261,7 @@ func (as *mqttAccountSessionManager) processSubs(sess *mqttSession, c *client, sess.cleanupFailedSub(c, sub, jscons, jssub) continue } + if mqttNeedSubForLevelUp(subject) { var fwjscons *ConsumerConfig var fwjssub *subscription @@ -2113,7 +2273,7 @@ func (as *mqttAccountSessionManager) processSubs(sess *mqttSession, c *client, fwcsid := fwcsubject + mqttMultiLevelSidSuffix // See note above about existing subscription. 
asAndSessLock() - fwcsub, err = c.processSub([]byte(fwcsubject), nil, []byte(fwcsid), mqttDeliverMsgCbQos0, false) + fwcsub, err = c.processSub([]byte(fwcsubject), nil, []byte(fwcsid), mqttDeliverMsgCbQoS0, false) if err == nil { setupSub(fwcsub, f.qos) } @@ -2131,12 +2291,14 @@ func (as *mqttAccountSessionManager) processSubs(sess *mqttSession, c *client, subs = append(subs, fwcsub) addJSConsToSess(fwcsid, fwjscons) } + subs = append(subs, sub) addJSConsToSess(sid, jscons) } if fromSubProto { err = sess.update(filters, true) } + return subs, err } @@ -2162,10 +2324,24 @@ func (as *mqttAccountSessionManager) serializeRetainedMsgsForSub(sess *mqttSessi sub.mqtt.prm = &mqttWriter{} } prm := sub.mqtt.prm - pi := sess.getPubAckIdentifier(mqttGetQoS(rm.Flags), sub) + var pi uint16 + qos := mqttGetQoS(rm.Flags) + if qos > sub.mqtt.qos { + qos = sub.mqtt.qos + } + if qos > 0 { + pi = sess.trackPublishRetained() + + // If we failed to get a PI for this message, send it as a QoS0, the + // best we can do? + if pi == 0 { + qos = 0 + } + } + // Need to use the subject for the retained message, not the `sub` subject. // We can find the published retained message in rm.sub.subject. 
- flags := mqttSerializePublishMsg(prm, pi, false, true, []byte(rm.Topic), rm.Msg) + flags := mqttSerializePublishMsg(prm, pi, qos, false, true, []byte(rm.Topic), rm.Msg) if trace { pp := mqttPublish{ topic: []byte(rm.Topic), @@ -2242,6 +2418,7 @@ func (as *mqttAccountSessionManager) createOrRestoreSession(clientID string, opt sess.clean = ps.Clean sess.subs = ps.Subs sess.cons = ps.Cons + sess.pubRelConsumer = ps.PubRel as.addSession(sess, true) return sess, true, nil } @@ -2391,7 +2568,17 @@ func mqttSessionCreate(jsa *mqttJSA, id, idHash string, seq uint64, opts *Option if maxp == 0 { maxp = mqttDefaultMaxAckPending } - return &mqttSession{jsa: jsa, id: id, idHash: idHash, seq: seq, maxp: maxp} + + return &mqttSession{ + jsa: jsa, + id: id, + idHash: idHash, + seq: seq, + maxp: maxp, + pubRelSubject: mqttQoS2PubRelStoredSubjectPrefix + idHash, + pubRelDeliverySubject: mqttQoS2PubRelDeliverySubjectPrefix + idHash, + pubRelDeliverySubjectB: []byte(mqttQoS2PubRelDeliverySubjectPrefix + idHash), + } } // Persists a session. Note that if the session's current client does not match @@ -2406,6 +2593,7 @@ func (sess *mqttSession) save() error { Clean: sess.clean, Subs: sess.subs, Cons: sess.cons, + PubRel: sess.pubRelConsumer, } b, _ := json.Marshal(&ps) @@ -2443,24 +2631,43 @@ func (sess *mqttSession) save() error { // Lock not held on entry, but session is in the locked map. 
func (sess *mqttSession) clear() error { var durs []string + var pubRelDur string + sess.mu.Lock() id := sess.id seq := sess.seq if l := len(sess.cons); l > 0 { durs = make([]string, 0, l) - for sid, cc := range sess.cons { - delete(sess.cons, sid) - durs = append(durs, cc.Durable) - } } - sess.subs, sess.pending, sess.cpending, sess.seq, sess.tmaxack = nil, nil, nil, 0, 0 + for sid, cc := range sess.cons { + delete(sess.cons, sid) + durs = append(durs, cc.Durable) + } + if sess.pubRelConsumer != nil { + pubRelDur = sess.pubRelConsumer.Durable + } + + sess.subs = nil + sess.pendingPublish = nil + sess.pendingPubRel = nil + sess.cpending = nil + sess.pubRelConsumer = nil + sess.seq = 0 + sess.tmaxack = 0 + sess.mu.Unlock() + for _, dur := range durs { if _, err := sess.jsa.deleteConsumer(mqttStreamName, dur); isErrorOtherThan(err, JSConsumerNotFoundErr) { - sess.mu.Unlock() return fmt.Errorf("unable to delete consumer %q for session %q: %v", dur, sess.id, err) } } - sess.mu.Unlock() + if pubRelDur != "" { + _, err := sess.jsa.deleteConsumer(mqttQoS2PubRelStreamName, pubRelDur) + if isErrorOtherThan(err, JSConsumerNotFoundErr) { + return fmt.Errorf("unable to delete consumer %q for session %q: %v", pubRelDur, sess.id, err) + } + } + if seq > 0 { if err := sess.jsa.deleteMsg(mqttSessStreamName, seq, true); err != nil { return fmt.Errorf("unable to delete session %q record at sequence %v", id, seq) @@ -2503,106 +2710,217 @@ func (sess *mqttSession) update(filters []*mqttFilter, add bool) error { return err } -// If both pQos and sub.mqtt.qos are > 0, then this will return the next -// packet identifier to use for a published message. 
+func (sess *mqttSession) bumpPI() uint16 { + var avail bool + next := sess.last_pi + for i := 0; i < 0xFFFF; i++ { + next++ + if next == 0 { + next = 1 + } + + _, usedInPublish := sess.pendingPublish[next] + _, usedInPubRel := sess.pendingPubRel[next] + if !usedInPublish && !usedInPubRel { + sess.last_pi = next + avail = true + break + } + } + if !avail { + return 0 + } + return sess.last_pi +} + +// trackPublishRetained is invoked when a retained (QoS) message is published. +// It need a new PI to be allocated, so we add it to the pendingPublish map, +// with an empty value. Since cpending (not pending) is used to serialize the PI +// mappings, we need to add this PI there as well. Make a unique key by using +// mqttRetainedMsgsStreamName for the durable name, and PI for sseq. // // Lock held on entry -func (sess *mqttSession) getPubAckIdentifier(pQos byte, sub *subscription) uint16 { - pi, _ := sess.trackPending(pQos, _EMPTY_, sub) +func (sess *mqttSession) trackPublishRetained() uint16 { + // Make sure we initialize the tracking maps. + if sess.pendingPublish == nil { + sess.pendingPublish = make(map[uint16]*mqttPending) + } + if sess.cpending == nil { + sess.cpending = make(map[string]map[uint64]uint16) + } + + pi := sess.bumpPI() + if pi == 0 { + return 0 + } + sess.pendingPublish[pi] = &mqttPending{} + return pi } -// If publish message QoS (pQos) and the subscription's QoS are both at least 1, -// this function will assign a packet identifier (pi) and will keep track of -// the pending ack. If the message has already been redelivered (reply != ""), -// the returned boolean will be `true`. +// trackPublish is invoked when a (QoS) PUBLISH message is to be delivered. It +// detects an untracked (new) message based on its sequence extracted from its +// delivery-time jsAckSubject, and adds it to the tracking maps. Returns a PI to +// use for the message (new, or previously used), and whether this is a +// duplicate delivery attempt. 
// // Lock held on entry -func (sess *mqttSession) trackPending(pQos byte, reply string, sub *subscription) (uint16, bool) { - if pQos == 0 || sub.mqtt.qos == 0 { - return 0, false - } +func (sess *mqttSession) trackPublish(jsDur, jsAckSubject string) (uint16, bool) { var dup bool var pi uint16 - bumpPI := func() uint16 { - var avail bool - next := sess.ppi - for i := 0; i < 0xFFFF; i++ { - next++ - if next == 0 { - next = 1 - } - if _, used := sess.pending[next]; !used { - sess.ppi = next - avail = true - break - } - } - if !avail { - return 0 - } - return sess.ppi + if jsAckSubject == _EMPTY_ || jsDur == _EMPTY_ { + return 0, false } - // This can happen when invoked from getPubAckIdentifier... - if reply == _EMPTY_ || sub.mqtt.jsDur == _EMPTY_ { - return bumpPI(), false + // Make sure we initialize the tracking maps. + if sess.pendingPublish == nil { + sess.pendingPublish = make(map[uint16]*mqttPending) } - - // Here, we have an ACK subject and a JS consumer... - jsDur := sub.mqtt.jsDur - if sess.pending == nil { - sess.pending = make(map[uint16]*mqttPending) + if sess.cpending == nil { sess.cpending = make(map[string]map[uint64]uint16) } - // Get the stream sequence and other from the ack reply subject - sseq, _, dcount := ackReplyInfo(reply) - var pending *mqttPending - // For this JS consumer, check to see if we already have sseq->pi + // Get the stream sequence and duplicate flag from the ack reply subject. + sseq, _, dcount := ackReplyInfo(jsAckSubject) + if dcount > 1 { + dup = true + } + + var ack *mqttPending sseqToPi, ok := sess.cpending[jsDur] if !ok { sseqToPi = make(map[uint64]uint16) sess.cpending[jsDur] = sseqToPi - } else if pi, ok = sseqToPi[sseq]; ok { - // If we already have a pi, get the ack so we update it. - // We will reuse the save packet identifier (pi). 
- pending = sess.pending[pi] + } else { + pi = sseqToPi[sseq] } - if pi == 0 { + + if pi != 0 { + // There is a possible race between a PUBLISH re-delivery calling us, + // and a PUBREC received already having submitting a PUBREL into JS . If + // so, indicate no need for (re-)delivery by returning a PI of 0. + _, usedForPubRel := sess.pendingPubRel[pi] + if /*dup && */ usedForPubRel { + return 0, false + } + + // We should have a pending JS ACK for this PI. + ack = sess.pendingPublish[pi] + } else { // sess.maxp will always have a value > 0. - if len(sess.pending) >= int(sess.maxp) { + if len(sess.pendingPublish) >= int(sess.maxp) { // Indicate that we did not assign a packet identifier. // The caller will not send the message to the subscription // and JS will redeliver later, based on consumer's AckWait. return 0, false } - pi = bumpPI() + + pi = sess.bumpPI() + if pi == 0 { + return 0, false + } + sseqToPi[sseq] = pi } - if pending == nil { - pending = &mqttPending{jsDur: jsDur, sseq: sseq, subject: reply} - sess.pending[pi] = pending - } - // If redelivery, return DUP flag - if dcount > 1 { - dup = true + + if ack == nil { + sess.pendingPublish[pi] = &mqttPending{ + jsDur: jsDur, + sseq: sseq, + jsAckSubject: jsAckSubject, + } + } else { + ack.jsAckSubject = jsAckSubject + ack.sseq = sseq + ack.jsDur = jsDur } + return pi, dup } -// Sends a request to create a JS Durable Consumer based on the given consumer's config. -// This will wait in place for the reply from the server handling the requests. +// Stops a PI from being tracked as a PUBLISH. It can still be in use for a +// pending PUBREL. 
// // Lock held on entry -func (sess *mqttSession) createConsumer(consConfig *ConsumerConfig) error { - cfg := &CreateConsumerRequest{ - Stream: mqttStreamName, - Config: *consConfig, +func (sess *mqttSession) untrackPublish(pi uint16) (jsAckSubject string) { + ack, ok := sess.pendingPublish[pi] + if !ok { + return _EMPTY_ } - _, err := sess.jsa.createConsumer(cfg) - return err + + delete(sess.pendingPublish, pi) + if len(sess.pendingPublish) == 0 { + sess.last_pi = 0 + } + + if len(sess.cpending) != 0 && ack.jsDur != _EMPTY_ { + if sseqToPi := sess.cpending[ack.jsDur]; sseqToPi != nil { + delete(sseqToPi, ack.sseq) + } + } + + return ack.jsAckSubject +} + +// trackPubRel is invoked in 2 cases: (a) when we receive a PUBREC and we need +// to change from tracking the PI as a PUBLISH to a PUBREL; and (b) when we +// attempt to deliver the PUBREL to record the JS ack subject for it. +// +// Lock held on entry +func (sess *mqttSession) trackAsPubRel(pi uint16, jsAckSubject string) { + if sess.pubRelConsumer == nil { + // The cosumer MUST be set up already. + return + } + jsDur := sess.pubRelConsumer.Durable + + if sess.pendingPubRel == nil { + sess.pendingPubRel = make(map[uint16]*mqttPending) + } + + if jsAckSubject == _EMPTY_ { + sess.pendingPubRel[pi] = &mqttPending{ + jsDur: jsDur, + } + return + } + + sseq, _, _ := ackReplyInfo(jsAckSubject) + + var sseqToPi map[uint64]uint16 + if sess.cpending == nil { + sess.cpending = make(map[string]map[uint64]uint16) + } else if sseqToPi = sess.cpending[jsDur]; sseqToPi == nil { + sseqToPi = make(map[uint64]uint16) + sess.cpending[jsDur] = sseqToPi + } + sseqToPi[sseq] = pi + sess.pendingPubRel[pi] = &mqttPending{ + jsDur: sess.pubRelConsumer.Durable, + sseq: sseq, + jsAckSubject: jsAckSubject, + } +} + +// Stops a PI from being tracked as a PUBREL. 
+// +// Lock held on entry +func (sess *mqttSession) untrackPubRel(pi uint16) (jsAckSubject string) { + ack, ok := sess.pendingPubRel[pi] + if !ok { + return _EMPTY_ + } + + delete(sess.pendingPubRel, pi) + + if sess.pubRelConsumer != nil && len(sess.cpending) > 0 { + if sseqToPi := sess.cpending[ack.jsDur]; sseqToPi != nil { + delete(sseqToPi, ack.sseq) + } + } + + return ack.jsAckSubject } // Sends a consumer delete request, but does not wait for response. @@ -3050,7 +3368,7 @@ func (s *Server) mqttHandleWill(c *client) { pp.flags |= mqttPubFlagRetain } c.mu.Unlock() - s.mqttProcessPub(c, pp) + s.mqttInitiateMsgDelivery(c, pp) c.flushClients(0) } @@ -3062,8 +3380,8 @@ func (s *Server) mqttHandleWill(c *client) { func (c *client) mqttParsePub(r *mqttReader, pl int, pp *mqttPublish, hasMappings bool) error { qos := mqttGetQoS(pp.flags) - if qos > 1 { - return fmt.Errorf("publish QoS=%v not supported", qos) + if qos > 2 { + return fmt.Errorf("QoS=%v is invalid in MQTT", qos) } // Keep track of where we are when starting to read the variable header start := r.pos @@ -3140,44 +3458,239 @@ func mqttPubTrace(pp *mqttPublish) string { pp.topic, dup, qos, retain, pp.sz, piStr) } +// Composes a NATS message from a MQTT PUBLISH packet. The message includes an +// internal header containint the original packet's QoS, and for QoS2 packets +// the original subject. 
+// +// Example (QoS2, subject: "foo.bar"): +// +// NATS/1.0\r\n +// Nmqtt-Pub:2foo.bar\r\n +// \r\n +func mqttNewDeliverableMessage(pp *mqttPublish, encodePP bool) (natsMsg []byte, headerLen int) { + size := len(hdrLine) + + len(mqttNatsHeader) + 2 + 2 + // 2 for ':', and 2 for CRLF + 2 + // end-of-header CRLF + len(pp.msg) + if encodePP { + size += len(mqttNatsHeaderSubject) + 1 + // +1 for ':' + len(pp.subject) + 2 // 2 for CRLF + + if len(pp.mapped) > 0 { + size += len(mqttNatsHeaderMapped) + 1 + // +1 for ':' + len(pp.mapped) + 2 // 2 for CRLF + } + } + buf := bytes.NewBuffer(make([]byte, 0, size)) + + qos := mqttGetQoS(pp.flags) + + buf.WriteString(hdrLine) + buf.WriteString(mqttNatsHeader) + buf.WriteByte(':') + buf.WriteByte(qos + '0') + buf.WriteString(_CRLF_) + + if encodePP { + buf.WriteString(mqttNatsHeaderSubject) + buf.WriteByte(':') + buf.Write(pp.subject) + buf.WriteString(_CRLF_) + + if len(pp.mapped) > 0 { + buf.WriteString(mqttNatsHeaderMapped) + buf.WriteByte(':') + buf.Write(pp.mapped) + buf.WriteString(_CRLF_) + } + } + + // End of header + buf.WriteString(_CRLF_) + + headerLen = buf.Len() + + buf.Write(pp.msg) + return buf.Bytes(), headerLen +} + +// Composes a NATS message for a pending PUBREL packet. The message includes an +// internal header containing the PI for PUBREL/PUBCOMP. +// +// Example (PI:123): +// +// NATS/1.0\r\n +// Nmqtt-PubRel:123\r\n +// \r\n +func mqttNewDeliverablePubRel(pi uint16) (natsMsg []byte, headerLen int) { + size := len(hdrLine) + + len(mqttNatsPubRelHeader) + 6 + 2 + // 6 for ':65535', and 2 for CRLF + 2 // end-of-header CRLF + buf := bytes.NewBuffer(make([]byte, 0, size)) + buf.WriteString(hdrLine) + buf.WriteString(mqttNatsPubRelHeader) + buf.WriteByte(':') + buf.WriteString(strconv.FormatInt(int64(pi), 10)) + buf.WriteString(_CRLF_) + buf.WriteString(_CRLF_) + return buf.Bytes(), buf.Len() +} + // Process the PUBLISH packet. // // Runs from the client's readLoop. // No lock held on entry. 
-func (s *Server) mqttProcessPub(c *client, pp *mqttPublish) error { - c.pa.subject, c.pa.mapped, c.pa.hdr, c.pa.size, c.pa.reply = pp.subject, pp.mapped, -1, pp.sz, nil +func (s *Server) mqttProcessPub(c *client, pp *mqttPublish, trace bool) error { + qos := mqttGetQoS(pp.flags) - bb := bytes.Buffer{} - bb.WriteString(hdrLine) - bb.Write(mqttNatsHeaderB) - bb.WriteByte(':') - bb.WriteByte('0' + mqttGetQoS(pp.flags)) - bb.WriteString(_CRLF_) - bb.WriteString(_CRLF_) - c.pa.hdr = bb.Len() - c.pa.hdb = []byte(strconv.FormatInt(int64(c.pa.hdr), 10)) - bb.Write(pp.msg) - c.pa.size = bb.Len() - c.pa.szb = []byte(strconv.FormatInt(int64(c.pa.size), 10)) - msgToSend := bb.Bytes() + switch qos { + case 0: + return s.mqttInitiateMsgDelivery(c, pp) - var err error - // Unless we have a publish permission error, if the message is QoS1, then we - // need to store the message (and deliver it to JS durable consumers). - if _, permIssue := c.processInboundClientMsg(msgToSend); !permIssue && mqttGetQoS(pp.flags) > 0 { - // We need to call flushClients now since this we may have called c.addToPCD - // with destination clients (possibly a route). Without calling flushClients - // the following call may then be stuck waiting for a reply that may never - // come because the destination is not flushed (due to c.out.fsp > 0, - // see addToPCD and writeLoop for details). - c.flushClients(0) - // Store this QoS1 message. - _, err = c.mqtt.sess.jsa.storeMsg(mqttStreamSubjectPrefix+string(c.pa.subject), c.pa.hdr, msgToSend) + case 1: + // [MQTT-4.3.2-2]. Initiate onward delivery of the Application Message, + // Send PUBACK. + // + // The receiver is not required to complete delivery of the Application + // Message before sending the PUBACK. When its original sender receives + // the PUBACK packet, ownership of the Application Message is + // transferred to the receiver. 
+ err := s.mqttInitiateMsgDelivery(c, pp) + if err == nil { + c.mqttEnqueuePubResponse(mqttPacketPubAck, pp.pi, trace) + } + return err + + case 2: + // [MQTT-4.3.3-2]. Method A, Store message, send PUBREC. + // + // The receiver is not required to complete delivery of the Application + // Message before sending the PUBREC or PUBCOMP. When its original + // sender receives the PUBREC packet, ownership of the Application + // Message is transferred to the receiver. + err := s.mqttStoreQoS2MsgOnce(c, pp) + if err == nil { + c.mqttEnqueuePubResponse(mqttPacketPubRec, pp.pi, trace) + } + return err + + default: + return fmt.Errorf("unreachable: invalid QoS in mqttProcessPub: %v", qos) } - c.pa.subject, c.pa.mapped, c.pa.hdr, c.pa.size, c.pa.szb, c.pa.reply = nil, nil, -1, 0, nil, nil +} + +func (s *Server) mqttInitiateMsgDelivery(c *client, pp *mqttPublish) error { + natsMsg, headerLen := mqttNewDeliverableMessage(pp, false) + + // Set the client's pubarg for processing. + c.pa.subject = pp.subject + c.pa.mapped = pp.mapped + c.pa.reply = nil + c.pa.hdr = headerLen + c.pa.hdb = []byte(strconv.FormatInt(int64(c.pa.hdr), 10)) + c.pa.size = len(natsMsg) + c.pa.szb = []byte(strconv.FormatInt(int64(c.pa.size), 10)) + defer func() { + c.pa.subject = nil + c.pa.mapped = nil + c.pa.reply = nil + c.pa.hdr = -1 + c.pa.hdb = nil + c.pa.size = 0 + c.pa.szb = nil + }() + + _, permIssue := c.processInboundClientMsg(natsMsg) + if permIssue { + return nil + } + + // If QoS 0 messages don't need to be stored, other (1 and 2) do. Store them + // JetStream under "$MQTT.msgs." + if qos := mqttGetQoS(pp.flags); qos == 0 { + return nil + } + + // We need to call flushClients now since this we may have called c.addToPCD + // with destination clients (possibly a route). Without calling flushClients + // the following call may then be stuck waiting for a reply that may never + // come because the destination is not flushed (due to c.out.fsp > 0, + // see addToPCD and writeLoop for details). 
+ c.flushClients(0) + + _, err := c.mqtt.sess.jsa.storeMsg(mqttStreamSubjectPrefix+string(c.pa.subject), headerLen, natsMsg) + return err } +var mqttMaxMsgErrPattern = fmt.Sprintf("%s (%v)", ErrMaxMsgsPerSubject.Error(), JSStreamStoreFailedF) + +func (s *Server) mqttStoreQoS2MsgOnce(c *client, pp *mqttPublish) error { + // `true` means encode the MQTT PUBLISH packet in the NATS message header. + natsMsg, headerLen := mqttNewDeliverableMessage(pp, true) + + // Do not broadcast the message until it has been deduplicated and released + // by the sender. Instead store this QoS2 message as + // "$MQTT.qos2..". If the message is a duplicate, we get back + // a ErrMaxMsgsPerSubject, otherwise it does not change the flow, still need + // to send a PUBREC back to the client. The original subject (translated + // from MQTT topic) is included in the NATS header of the stored message to + // use for latter delivery. + _, err := c.mqtt.sess.jsa.storeMsg(c.mqttQoS2InternalSubject(pp.pi), headerLen, natsMsg) + + // TODO: would prefer a more robust and performant way of checking the + // error, but it comes back wrapped as an API result. + if err != nil && + (isErrorOtherThan(err, JSStreamStoreFailedF) || err.Error() != mqttMaxMsgErrPattern) { + return err + } + + return nil +} + +func (c *client) mqttQoS2InternalSubject(pi uint16) string { + return mqttQoS2IncomingMsgsStreamSubjectPrefix + c.mqtt.cid + "." + strconv.FormatUint(uint64(pi), 10) +} + +// Process a PUBREL packet (QoS2, acting as Receiver). +// +// Runs from the client's readLoop. +// No lock held on entry. +func (s *Server) mqttProcessPubRel(c *client, pi uint16, trace bool) error { + // Once done with the processing, send a PUBCOMP back to the client. + defer c.mqttEnqueuePubResponse(mqttPacketPubComp, pi, trace) + + // See if there is a message pending for this pi. All failures are treated + // as "not found". 
+ asm := c.mqtt.asm + stored, _ := asm.jsa.loadLastMsgFor(mqttQoS2IncomingMsgsStreamName, c.mqttQoS2InternalSubject(pi)) + + if stored == nil { + // No message found, nothing to do. + return nil + } + // Best attempt to delete the message from the QoS2 stream. + asm.jsa.deleteMsg(mqttQoS2IncomingMsgsStreamName, stored.Sequence, true) + + // only MQTT QoS2 messages should be here, and they must have a subject. + h := mqttParsePublishNATSHeader(stored.Header) + if h == nil || h.qos != 2 || len(h.subject) == 0 { + return errors.New("invalid message in QoS2 PUBREL stream") + } + + pp := &mqttPublish{ + topic: natsSubjectToMQTTTopic(string(h.subject)), + subject: h.subject, + mapped: h.mapped, + msg: stored.Data, + sz: len(stored.Data), + pi: pi, + flags: h.qos << 1, + } + + return s.mqttInitiateMsgDelivery(c, pp) +} + // Invoked when processing an inbound client message. If the "retain" flag is // set, the message is stored so it can be later resent to (re)starting // subscriptions that match the subject. 
@@ -3395,16 +3908,31 @@ func mqttWritePublish(w *mqttWriter, qos byte, dup, retain bool, subject string, w.Write([]byte(payload)) } -func (c *client) mqttEnqueuePubAck(pi uint16) { - proto := [4]byte{mqttPacketPubAck, 0x2, 0, 0} +func (c *client) mqttEnqueuePubResponse(packetType byte, pi uint16, trace bool) { + proto := [4]byte{packetType, 0x2, 0, 0} proto[2] = byte(pi >> 8) proto[3] = byte(pi) c.mu.Lock() c.enqueueProto(proto[:4]) c.mu.Unlock() + + if trace { + name := "(???)" + switch packetType { + case mqttPacketPubAck: + name = "PUBACK" + case mqttPacketPubRec: + name = "PUBREC" + case mqttPacketPubRel: + name = "PUBREL" + case mqttPacketPubComp: + name = "PUBCOMP" + } + c.traceOutOp(name, []byte(fmt.Sprintf("pi=%v", pi))) + } } -func mqttParsePubAck(r *mqttReader, pl int) (uint16, error) { +func mqttParsePIPacket(r *mqttReader, pl int) (uint16, error) { pi, err := r.readUint16("packet identifier") if err != nil { return 0, err @@ -3415,40 +3943,71 @@ func mqttParsePubAck(r *mqttReader, pl int) (uint16, error) { return pi, nil } -// Process a PUBACK packet. -// Updates the session's pending list and sends an ACK to JS. +// Process a PUBACK (QoS1) or a PUBREC (QoS2) packet, acting as Sender. Set +// isPubRec to false to process as a PUBACK. // -// Runs from the client's readLoop. -// No lock held on entry. -func (c *client) mqttProcessPubAck(pi uint16) { +// Runs from the client's readLoop. No lock held on entry. +func (c *client) mqttProcessPublishReceived(pi uint16, isPubRec bool) (err error) { + sess := c.mqtt.sess + if sess == nil { + return errMQTTInvalidSession + } + + var jsAckSubject string + sess.mu.Lock() + // Must be the same client, and the session must have been setup for QoS2. + if sess.c != c { + sess.mu.Unlock() + return errMQTTInvalidSession + } + if isPubRec { + // The JS ACK subject for the PUBREL will be filled in at the delivery + // attempt. 
+ sess.trackAsPubRel(pi, _EMPTY_) + } + jsAckSubject = sess.untrackPublish(pi) + sess.mu.Unlock() + + if isPubRec { + natsMsg, headerLen := mqttNewDeliverablePubRel(pi) + _, err = sess.jsa.storeMsg(sess.pubRelSubject, headerLen, natsMsg) + if err != nil { + // Failure to send out PUBREL will terminate the connection. + return err + } + } + + // Send the ack to JS to remove the pending message from the consumer. + sess.jsa.sendAck(jsAckSubject) + return nil +} + +func (c *client) mqttProcessPubAck(pi uint16) error { + return c.mqttProcessPublishReceived(pi, false) +} + +func (c *client) mqttProcessPubRec(pi uint16) error { + return c.mqttProcessPublishReceived(pi, true) +} + +// Runs from the client's readLoop. No lock held on entry. +func (c *client) mqttProcessPubComp(pi uint16) { sess := c.mqtt.sess if sess == nil { return } + + var jsAckSubject string sess.mu.Lock() if sess.c != c { sess.mu.Unlock() return } - var ackSubject string - if ack, ok := sess.pending[pi]; ok { - delete(sess.pending, pi) - jsDur := ack.jsDur - if sseqToPi, ok := sess.cpending[jsDur]; ok { - delete(sseqToPi, ack.sseq) - } - if len(sess.pending) == 0 { - sess.ppi = 0 - } - ackSubject = ack.subject - } - // Send the ack if applicable. - if ackSubject != _EMPTY_ { - // We pass -1 for the hdr so that the send loop does not need to - // add the "client info" header. This is not a JS API request per se. - sess.jsa.sendq.push(&mqttJSPubMsg{subj: ackSubject, hdr: -1}) - } + jsAckSubject = sess.untrackPubRel(pi) sess.mu.Unlock() + + // Send the ack to JS to remove the pending message from the consumer. + sess.jsa.sendAck(jsAckSubject) } // Return the QoS from the given PUBLISH protocol's flags @@ -3551,24 +4110,25 @@ func mqttSubscribeTrace(pi uint16, filters []*mqttFilter) string { return sb.String() } -// For a MQTT QoS0 subscription, we create a single NATS subscription -// on the actual subject, for instance "foo.bar". 
-// For a MQTT QoS1 subscription, we create 2 subscriptions, one on -// "foo.bar" (as for QoS0, but sub.mqtt.qos will be 1), and one on -// the subject "$MQTT.sub." which is the delivery subject of -// the JS durable consumer with the filter subject "$MQTT.msgs.foo.bar". +// For a MQTT QoS0 subscription, we create a single NATS subscription on the +// actual subject, for instance "foo.bar". +// +// For a MQTT QoS1+ subscription, we create 2 subscriptions, one on "foo.bar" +// (as for QoS0, but sub.mqtt.qos will be 1 or 2), and one on the subject +// "$MQTT.sub." which is the delivery subject of the JS durable consumer +// with the filter subject "$MQTT.msgs.foo.bar". // // This callback delivers messages to the client as QoS0 messages, either -// because they have been produced as QoS0 messages (and therefore only -// this callback can receive them), they are QoS1 published messages but -// this callback is for a subscription that is QoS0, or the published -// messages come from NATS publishers. +// because: (a) they have been produced as MQTT QoS0 messages (and therefore +// only this callback can receive them); (b) they are MQTT QoS1+ published +// messages but this callback is for a subscription that is QoS0; or (c) the +// published messages come from (other) NATS publishers on the subject. // -// This callback must reject a message if it is known to be a QoS1 published -// message and this is the callback for a QoS1 subscription because in -// that case, it will be handled by the other callback. This avoid getting -// duplicate deliveries. -func mqttDeliverMsgCbQos0(sub *subscription, pc *client, _ *Account, subject, reply string, rmsg []byte) { +// This callback must reject a message if it is known to be a QoS1+ published +// message and this is the callback for a QoS1+ subscription because in that +// case, it will be handled by the other callback. This avoid getting duplicate +// deliveries. 
+func mqttDeliverMsgCbQoS0(sub *subscription, pc *client, _ *Account, subject, reply string, rmsg []byte) { if pc.kind == JETSTREAM && len(reply) > 0 && strings.HasPrefix(reply, jsAckPre) { return } @@ -3584,23 +4144,24 @@ func mqttDeliverMsgCbQos0(sub *subscription, pc *client, _ *Account, subject, re // Check the subscription's QoS. This needs to be protected because // the client may change an existing subscription at any time. sess.mu.Lock() - subQos := sub.mqtt.qos - isReserved := mqttCheckReserved(sub, subject) + subQoS := sub.mqtt.qos + isReservedSub := mqttIsReservedSub(sub, subject) sess.mu.Unlock() // We have a wildcard subscription and this subject starts with '$' so ignore per Spec [MQTT-4.7.2-1]. - if isReserved { + if isReservedSub { return } var retained bool var topic []byte - - // This is an MQTT publisher directly connected to this server. if pc.isMqtt() { - // If the MQTT subscription is QoS1, then we bail out if the published - // message is QoS1 because it will be handled in the other callack. - if subQos == 1 && mqttGetQoS(pc.mqtt.pp.flags) > 0 { + // This is an MQTT publisher directly connected to this server. + + // If the message was published with a QoS > 0 and the sub has the QoS > + // 0 then the message will be delivered by the other callback. + msgQoS := mqttGetQoS(pc.mqtt.pp.flags) + if subQoS > 0 && msgQoS > 0 { return } topic = pc.mqtt.pp.topic @@ -3612,16 +4173,10 @@ func mqttDeliverMsgCbQos0(sub *subscription, pc *client, _ *Account, subject, re } else { // Non MQTT client, could be NATS publisher, or ROUTER, etc.. - - // For QoS1 subs, we need to make sure that if there is a header, it does - // not say that this is a QoS1 published message, because it will be handled - // in the other callback. 
- if subQos != 0 && len(hdr) > 0 { - if nhv := getHeader(mqttNatsHeader, hdr); len(nhv) >= 1 { - if qos := nhv[0] - '0'; qos > 0 { - return - } - } + h := mqttParsePublishNATSHeader(hdr) + if subQoS > 0 && h != nil && h.qos > 0 { + // will be delivered by the JetStream callback + return } // If size is more than what a MQTT client can handle, we should probably reject, @@ -3633,17 +4188,16 @@ func mqttDeliverMsgCbQos0(sub *subscription, pc *client, _ *Account, subject, re } // Message never has a packet identifier nor is marked as duplicate. - pc.mqttDeliver(cc, sub, 0, false, retained, topic, msg) + pc.mqttEnqueuePublishMsgTo(cc, sub, 0, 0, false, retained, topic, msg) } -// This is the callback attached to a JS durable subscription for a MQTT Qos1 sub. -// Only JETSTREAM should be sending a message to this subject (the delivery subject -// associated with the JS durable consumer), but in cluster mode, this can be coming -// from a route, gw, etc... We make sure that if this is the case, the message contains -// a NATS/MQTT header that indicates that this is a published QoS1 message. -func mqttDeliverMsgCbQos1(sub *subscription, pc *client, _ *Account, subject, reply string, rmsg []byte) { - var retained bool - +// This is the callback attached to a JS durable subscription for a MQTT QoS 1+ +// sub. Only JETSTREAM should be sending a message to this subject (the delivery +// subject associated with the JS durable consumer), but in cluster mode, this +// can be coming from a route, gw, etc... We make sure that if this is the case, +// the message contains a NATS/MQTT header that indicates that this is a +// published QoS1+ message. +func mqttDeliverMsgCbQoS12(sub *subscription, pc *client, _ *Account, subject, reply string, rmsg []byte) { // Message on foo.bar is stored under $MQTT.msgs.foo.bar, so the subject has to be // at least as long as the stream subject prefix "$MQTT.msgs.", and after removing // the prefix, has to be at least 1 character long. 
@@ -3652,17 +4206,12 @@ func mqttDeliverMsgCbQos1(sub *subscription, pc *client, _ *Account, subject, re } hdr, msg := pc.msgParts(rmsg) - if pc.kind != JETSTREAM { - if len(hdr) == 0 { - return - } - nhv := getHeader(mqttNatsHeader, hdr) - if len(nhv) < 1 { - return - } - if qos := nhv[0] - '0'; qos != 1 { - return - } + h := mqttParsePublishNATSHeader(hdr) + if pc.kind != JETSTREAM && (h == nil || h.qos == 0) { + // MQTT QoS 0 messages must be ignored, they will be delivered by the + // other callback, the direct NATS subscription. All JETSTREAM messages + // will have the header. + return } // This is the client associated with the subscription. @@ -3679,20 +4228,32 @@ func mqttDeliverMsgCbQos1(sub *subscription, pc *client, _ *Account, subject, re return } - // This is a QoS1 message for a QoS1 subscription, so get the pi and keep - // track of ack subject. - pQoS := byte(1) - pi, dup := sess.trackPending(pQoS, reply, sub) - - // Check for reserved subject violation. - strippedSubj := string(subject[len(mqttStreamSubjectPrefix):]) - if isReserved := mqttCheckReserved(sub, strippedSubj); isReserved { + // In this callback we handle only QoS-published messages to QoS + // subscriptions. Ignore if either is 0, will be delivered by the other + // callback, mqttDeliverMsgCbQos1. + var qos byte + if h != nil { + qos = h.qos + } + if qos > sub.mqtt.qos { + qos = sub.mqtt.qos + } + if qos == 0 { sess.mu.Unlock() - if pi > 0 { - cc.mqttProcessPubAck(pi) - } return } + + // Check for reserved subject violation. If so, we will send the ack to + // remove the message, and do nothing else. 
+ strippedSubj := string(subject[len(mqttStreamSubjectPrefix):]) + isReservedSub := mqttIsReservedSub(sub, strippedSubj) + if isReservedSub { + sess.mu.Unlock() + sess.jsa.sendAck(reply) + return + } + + pi, dup := sess.trackPublish(sub.mqtt.jsDur, reply) sess.mu.Unlock() if pi == 0 { @@ -3702,15 +4263,45 @@ func mqttDeliverMsgCbQos1(sub *subscription, pc *client, _ *Account, subject, re return } - topic := natsSubjectToMQTTTopic(strippedSubj) + originalTopic := natsSubjectToMQTTTopic(strippedSubj) + pc.mqttEnqueuePublishMsgTo(cc, sub, pi, qos, dup, false, originalTopic, msg) +} - pc.mqttDeliver(cc, sub, pi, dup, retained, topic, msg) +func mqttDeliverPubRelCb(sub *subscription, pc *client, _ *Account, subject, reply string, rmsg []byte) { + if sub.client.mqtt == nil || sub.client.mqtt.sess == nil || reply == _EMPTY_ { + return + } + + hdr, _ := pc.msgParts(rmsg) + pi := mqttParsePubRelNATSHeader(hdr) + if pi == 0 { + return + } + + // This is the client associated with the subscription. + cc := sub.client + + // This is immutable + sess := cc.mqtt.sess + + sess.mu.Lock() + if sess.c != cc || sess.pubRelConsumer == nil { + sess.mu.Unlock() + return + } + sess.trackAsPubRel(pi, reply) + trace := cc.trace + sess.mu.Unlock() + + cc.mqttEnqueuePubResponse(mqttPacketPubRel, pi, trace) } // The MQTT Server MUST NOT match Topic Filters starting with a wildcard character (# or +) // with Topic Names beginning with a $ character, Spec [MQTT-4.7.2-1]. // We will return true if there is a violation. -func mqttCheckReserved(sub *subscription, subject string) bool { +// +// Session lock must be held on entry to protect access to sub.mqtt.reserved. +func mqttIsReservedSub(sub *subscription, subject string) bool { // If the subject does not start with $ nothing to do here. 
if !sub.mqtt.reserved || len(subject) == 0 || subject[0] != mqttReservedPre { return false @@ -3732,11 +4323,11 @@ func isMQTTReservedSubscription(subject string) bool { // Common function to mqtt delivery callbacks to serialize and send the message // to the `cc` client. -func (c *client) mqttDeliver(cc *client, sub *subscription, pi uint16, dup, retained bool, topic, msg []byte) { +func (c *client) mqttEnqueuePublishMsgTo(cc *client, sub *subscription, pi uint16, qos byte, dup, retained bool, topic, msg []byte) { sw := mqttWriter{} w := &sw - flags := mqttSerializePublishMsg(w, pi, dup, retained, topic, msg) + flags := mqttSerializePublishMsg(w, pi, qos, dup, retained, topic, msg) cc.mu.Lock() if sub.mqtt.prm != nil { @@ -3745,7 +4336,10 @@ func (c *client) mqttDeliver(cc *client, sub *subscription, pi uint16, dup, reta } cc.queueOutbound(w.Bytes()) c.addToPCD(cc) - if cc.trace { + trace := cc.trace + cc.mu.Unlock() + + if trace { pp := mqttPublish{ topic: topic, flags: flags, @@ -3754,12 +4348,10 @@ func (c *client) mqttDeliver(cc *client, sub *subscription, pi uint16, dup, reta } cc.traceOutOp("PUBLISH", []byte(mqttPubTrace(&pp))) } - cc.mu.Unlock() } // Serializes to the given writer the message for the given subject. 
-func mqttSerializePublishMsg(w *mqttWriter, pi uint16, dup, retained bool, topic, msg []byte) byte { - +func mqttSerializePublishMsg(w *mqttWriter, pi uint16, qos byte, dup, retained bool, topic, msg []byte) byte { // Compute len (will have to add packet id if message is sent as QoS>=1) pkLen := 2 + len(topic) + len(msg) @@ -3775,7 +4367,7 @@ func mqttSerializePublishMsg(w *mqttWriter, pi uint16, dup, retained bool, topic // For now, we have only QoS 1 if pi > 0 { pkLen += 2 - flags |= mqttPubQos1 + flags |= qos << 1 } w.WriteByte(mqttPacketPub | flags) @@ -3835,6 +4427,68 @@ func (sess *mqttSession) cleanupFailedSub(c *client, sub *subscription, cc *Cons } } +// Make sure we are set up to deliver PUBREL messages to this QoS2-subscribed +// session. +// +// Session lock held on entry. Need to make sure no other subscribe packet races +// to do the same. +func (sess *mqttSession) ensurePubRelConsumerSubscription(c *client) error { + opts := c.srv.getOpts() + ackWait := opts.MQTT.AckWait + if ackWait == 0 { + ackWait = mqttDefaultAckWait + } + maxAckPending := int(opts.MQTT.MaxAckPending) + if maxAckPending == 0 { + maxAckPending = mqttDefaultMaxAckPending + } + + // Subscribe before the consumer is created so we don't loose any messages. + if !sess.pubRelSubscribed { + _, err := c.processSub(sess.pubRelDeliverySubjectB, nil, sess.pubRelDeliverySubjectB, + mqttDeliverPubRelCb, false) + if err != nil { + c.Errorf("Unable to create subscription for JetStream consumer on %q: %v", sess.pubRelDeliverySubject, err) + return err + } + sess.pubRelSubscribed = true + } + + // Create the consumer if needed. 
+ if sess.pubRelConsumer == nil { + // Check that the limit of subs' maxAckPending are not going over the limit + if after := sess.tmaxack + maxAckPending; after > mqttMaxAckTotalLimit { + return fmt.Errorf("max_ack_pending for all consumers would be %v which exceeds the limit of %v", + after, mqttMaxAckTotalLimit) + } + + ccr := &CreateConsumerRequest{ + Stream: mqttQoS2PubRelStreamName, + Config: ConsumerConfig{ + DeliverSubject: sess.pubRelDeliverySubject, + Durable: sess.idHash + "_pubrel", + AckPolicy: AckExplicit, + DeliverPolicy: DeliverNew, + FilterSubject: sess.pubRelSubject, + AckWait: ackWait, + MaxAckPending: maxAckPending, + MemoryStorage: opts.MQTT.ConsumerMemoryStorage, + }, + } + if opts.MQTT.ConsumerInactiveThreshold > 0 { + ccr.Config.InactiveThreshold = opts.MQTT.ConsumerInactiveThreshold + } + if _, err := sess.jsa.createConsumer(ccr); err != nil { + c.Errorf("Unable to add JetStream consumer for PUBREL for client %q: err=%v", sess.id, err) + return err + } + sess.pubRelConsumer = &ccr.Config + sess.tmaxack += maxAckPending + } + + return nil +} + // When invoked with a QoS of 0, looks for an existing JS durable consumer for // the given sid and if one is found, delete the JS durable consumer and unsub // the NATS subscription on the delivery subject. 
@@ -3897,29 +4551,33 @@ func (sess *mqttSession) processJSConsumer(c *client, subject, sid string, } durName := sess.idHash + "_" + nuid.Next() - cc = &ConsumerConfig{ - DeliverSubject: inbox, - Durable: durName, - AckPolicy: AckExplicit, - DeliverPolicy: DeliverNew, - FilterSubject: mqttStreamSubjectPrefix + subject, - AckWait: ackWait, - MaxAckPending: maxAckPending, - MemoryStorage: opts.MQTT.ConsumerMemoryStorage, + ccr := &CreateConsumerRequest{ + Stream: mqttStreamName, + Config: ConsumerConfig{ + DeliverSubject: inbox, + Durable: durName, + AckPolicy: AckExplicit, + DeliverPolicy: DeliverNew, + FilterSubject: mqttStreamSubjectPrefix + subject, + AckWait: ackWait, + MaxAckPending: maxAckPending, + MemoryStorage: opts.MQTT.ConsumerMemoryStorage, + }, } if opts.MQTT.ConsumerInactiveThreshold > 0 { - cc.InactiveThreshold = opts.MQTT.ConsumerInactiveThreshold + ccr.Config.InactiveThreshold = opts.MQTT.ConsumerInactiveThreshold } - if err := sess.createConsumer(cc); err != nil { + if _, err := sess.jsa.createConsumer(ccr); err != nil { c.Errorf("Unable to add JetStream consumer for subscription on %q: err=%v", subject, err) return nil, nil, err } + cc = &ccr.Config sess.tmaxack += maxAckPending } // This is an internal subscription on subject like "$MQTT.sub." that is setup // for the JS durable's deliver subject. 
sess.mu.Lock() - sub, err := c.processSub([]byte(inbox), nil, []byte(inbox), mqttDeliverMsgCbQos1, false) + sub, err := c.processSub([]byte(inbox), nil, []byte(inbox), mqttDeliverMsgCbQoS12, false) if err != nil { sess.mu.Unlock() sess.deleteConsumer(cc) @@ -4004,10 +4662,10 @@ func (c *client) mqttProcessUnsubs(filters []*mqttFilter) error { if seqPis, ok := sess.cpending[cc.Durable]; ok { delete(sess.cpending, cc.Durable) for _, pi := range seqPis { - delete(sess.pending, pi) + delete(sess.pendingPublish, pi) } - if len(sess.pending) == 0 { - sess.ppi = 0 + if len(sess.pendingPublish) == 0 { + sess.last_pi = 0 } } sess.mu.Unlock() diff --git a/server/mqtt_test.go b/server/mqtt_test.go index 46564134..00b35953 100644 --- a/server/mqtt_test.go +++ b/server/mqtt_test.go @@ -134,6 +134,18 @@ func testMQTTReadPacket(t testing.TB, r *mqttReader) (byte, int) { return b, pl } +func testMQTTReadPIPacket(expectedType byte, t testing.TB, r *mqttReader, expectedPI uint16) { + t.Helper() + b, _ := testMQTTReadPacket(t, r) + if pt := b & mqttPacketMask; pt != expectedType { + t.Fatalf("Expected packet %x, got %x", expectedType, pt) + } + rpi, err := r.readUint16("packet identifier") + if err != nil || rpi != expectedPI { + t.Fatalf("Expected PI %v got: %v, err=%v", expectedPI, rpi, err) + } +} + func TestMQTTReader(t *testing.T) { r := &mqttReader{} r.reset([]byte{0, 2, 'a', 'b'}) @@ -1699,39 +1711,6 @@ func TestMQTTDontSetPinger(t *testing.T) { testMQTTPublish(t, mc, r, 0, false, false, "foo", 0, []byte("msg")) } -func TestMQTTUnsupportedPackets(t *testing.T) { - o := testMQTTDefaultOptions() - s := testMQTTRunServer(t, o) - defer testMQTTShutdownServer(s) - - for _, test := range []struct { - name string - packetType byte - }{ - {"pubrec", mqttPacketPubRec}, - {"pubrel", mqttPacketPubRel}, - {"pubcomp", mqttPacketPubComp}, - } { - t.Run(test.name, func(t *testing.T) { - mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) - defer mc.Close() - 
testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) - - w := &mqttWriter{} - pt := test.packetType - if test.packetType == mqttPacketPubRel { - pt |= byte(0x2) - } - w.WriteByte(pt) - w.WriteVarInt(2) - w.WriteUint16(1) - mc.Write(w.Bytes()) - - testMQTTExpectDisconnect(t, mc) - }) - } -} - func TestMQTTTopicAndSubjectConversion(t *testing.T) { for _, test := range []struct { name string @@ -1933,13 +1912,13 @@ func TestMQTTSubAck(t *testing.T) { subs := []*mqttFilter{ {filter: "foo", qos: 0}, {filter: "bar", qos: 1}, - {filter: "baz", qos: 2}, // Since we don't support, we should receive a result of 1 + {filter: "baz", qos: 2}, {filter: "foo/#/bar", qos: 0}, // Invalid sub, so we should receive a result of mqttSubAckFailure } expected := []byte{ 0, 1, - 1, + 2, mqttSubAckFailure, } testMQTTSub(t, 1, mc, r, subs, expected) @@ -1975,37 +1954,48 @@ func testMQTTExpectNothing(t testing.TB, r *mqttReader) { r.reader.SetReadDeadline(time.Time{}) } -func testMQTTCheckPubMsg(t testing.TB, c net.Conn, r *mqttReader, topic string, flags byte, payload []byte) { +func testMQTTCheckPubMsg(t testing.TB, c net.Conn, r *mqttReader, topic string, expectedFlags byte, payload []byte) uint16 { t.Helper() - pflags, pi := testMQTTGetPubMsg(t, c, r, topic, payload) - if pflags != flags { - t.Fatalf("Expected flags to be %x, got %x", flags, pflags) + pi := testMQTTCheckPubMsgNoAck(t, c, r, topic, expectedFlags, payload) + if pi == 0 { + return 0 } - if pi > 0 { - testMQTTSendPubAck(t, c, pi) + qos := mqttGetQoS(expectedFlags) + switch qos { + case 1: + testMQTTSendPIPacket(mqttPacketPubAck, t, c, pi) + case 2: + testMQTTSendPIPacket(mqttPacketPubRec, t, c, pi) } + return pi } -func testMQTTCheckPubMsgNoAck(t testing.TB, c net.Conn, r *mqttReader, topic string, flags byte, payload []byte) uint16 { +func testMQTTCheckPubMsgNoAck(t testing.TB, c net.Conn, r *mqttReader, topic string, expectedFlags byte, payload []byte) uint16 { t.Helper() pflags, pi := testMQTTGetPubMsg(t, c, 
r, topic, payload) - if pflags != flags { - t.Fatalf("Expected flags to be %x, got %x", flags, pflags) + if pflags != expectedFlags { + t.Fatalf("Expected flags to be %x, got %x", expectedFlags, pflags) } return pi } func testMQTTGetPubMsg(t testing.TB, c net.Conn, r *mqttReader, topic string, payload []byte) (byte, uint16) { + t.Helper() flags, pi, _ := testMQTTGetPubMsgEx(t, c, r, topic, payload) return flags, pi } -func testMQTTGetPubMsgEx(t testing.TB, c net.Conn, r *mqttReader, topic string, payload []byte) (byte, uint16, string) { +func testMQTTGetPubMsgEx(t testing.TB, _ net.Conn, r *mqttReader, topic string, payload []byte) (byte, uint16, string) { t.Helper() b, pl := testMQTTReadPacket(t, r) if pt := b & mqttPacketMask; pt != mqttPacketPub { t.Fatalf("Expected PUBLISH packet %x, got %x", mqttPacketPub, pt) } + return testMQTTGetPubMsgExEx(t, nil, r, b, pl, topic, payload) +} + +func testMQTTGetPubMsgExEx(t testing.TB, _ net.Conn, r *mqttReader, b byte, pl int, topic string, payload []byte) (byte, uint16, string) { + t.Helper() pflags := b & mqttPacketFlagMask qos := (pflags & mqttPubFlagQoS) >> 1 start := r.pos @@ -2036,14 +2026,23 @@ func testMQTTGetPubMsgEx(t testing.TB, c net.Conn, r *mqttReader, topic string, return pflags, pi, ptopic } -func testMQTTSendPubAck(t testing.TB, c net.Conn, pi uint16) { +func testMQTTSendPIPacket(packetType byte, t testing.TB, c net.Conn, pi uint16) { t.Helper() w := &mqttWriter{} - w.WriteByte(mqttPacketPubAck) + w.WriteByte(packetType) w.WriteVarInt(2) w.WriteUint16(pi) if _, err := testMQTTWrite(c, w.Bytes()); err != nil { - t.Fatalf("Error writing PUBACK: %v", err) + t.Fatalf("Error writing packet type %v: %v", packetType, err) + } +} + +func testMQTTPublishNoAcks(t testing.TB, c net.Conn, qos byte, dup, retain bool, topic string, pi uint16, payload []byte) { + t.Helper() + w := &mqttWriter{} + mqttWritePublish(w, qos, dup, retain, topic, pi, payload) + if _, err := testMQTTWrite(c, w.Bytes()); err != nil { + 
t.Fatalf("Error writing PUBLISH proto: %v", err) } } @@ -2054,12 +2053,8 @@ func testMQTTPublish(t testing.TB, c net.Conn, r *mqttReader, qos byte, dup, ret if _, err := testMQTTWrite(c, w.Bytes()); err != nil { t.Fatalf("Error writing PUBLISH proto: %v", err) } - if qos > 0 { - // Since we don't support QoS 2, we should get disconnected - if qos == 2 { - testMQTTExpectDisconnect(t, c) - return - } + switch qos { + case 1: b, _ := testMQTTReadPacket(t, r) if pt := b & mqttPacketMask; pt != mqttPacketPubAck { t.Fatalf("Expected PUBACK packet %x, got %x", mqttPacketPubAck, pt) @@ -2068,6 +2063,29 @@ func testMQTTPublish(t testing.TB, c net.Conn, r *mqttReader, qos byte, dup, ret if err != nil || rpi != pi { t.Fatalf("Error with packet identifier expected=%v got: %v err=%v", pi, rpi, err) } + + case 2: + b, _ := testMQTTReadPacket(t, r) + if pt := b & mqttPacketMask; pt != mqttPacketPubRec { + t.Fatalf("Expected PUBREC packet %x, got %x", mqttPacketPubRec, pt) + } + rpi, err := r.readUint16("packet identifier") + if err != nil || rpi != pi { + t.Fatalf("Error with packet identifier expected=%v got: %v err=%v", pi, rpi, err) + } + + testMQTTSendPIPacket(mqttPacketPubRel, t, c, pi) + + b, _ = testMQTTReadPacket(t, r) + if pt := b & mqttPacketMask; pt != mqttPacketPubComp { + t.Fatalf("Expected PUBCOMP packet %x, got %x", mqttPacketPubComp, pt) + } + rpi, err = r.readUint16("packet identifier") + if err != nil || rpi != pi { + t.Fatalf("Error with packet identifier expected=%v got: %v err=%v", pi, rpi, err) + } + + testMQTTFlush(t, c, nil, r) } } @@ -2079,7 +2097,7 @@ func TestMQTTParsePub(t *testing.T) { pl int err string }{ - {"qos not supported", 0x4, nil, 0, "not supported"}, + {"qos not supported", (3 << 1), nil, 0, "QoS=3 is invalid in MQTT"}, {"packet in buffer error", 0, nil, 10, io.ErrUnexpectedEOF.Error()}, {"error on topic", 0, []byte{0, 3, 'f', 'o'}, 4, "topic"}, {"empty topic", 0, []byte{0, 0}, 2, errMQTTTopicIsEmpty.Error()}, @@ -2100,7 +2118,7 @@ func 
TestMQTTParsePub(t *testing.T) { } } -func TestMQTTParsePubAck(t *testing.T) { +func TestMQTTParsePIMsg(t *testing.T) { for _, test := range []struct { name string proto []byte @@ -2114,7 +2132,7 @@ func TestMQTTParsePubAck(t *testing.T) { t.Run(test.name, func(t *testing.T) { r := &mqttReader{} r.reset(test.proto) - if _, err := mqttParsePubAck(r, test.pl); err == nil || !strings.Contains(err.Error(), test.err) { + if _, err := mqttParsePIPacket(r, test.pl); err == nil || !strings.Contains(err.Error(), test.err) { t.Fatalf("Expected error %q, got %v", test.err, err) } }) @@ -2204,7 +2222,72 @@ func TestMQTTSub(t *testing.T) { } } -func TestMQTTSubQoS(t *testing.T) { +func TestMQTTSubQoS2(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcp.Close() + testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false) + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + topic := "foo/bar/baz" + mqttTopic0 := "foo/#" + mqttTopic1 := "foo/bar/#" + mqttTopic2 := topic + testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: mqttTopic0, qos: 0}}, []byte{0}) + testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: mqttTopic1, qos: 1}}, []byte{1}) + testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: mqttTopic2, qos: 2}}, []byte{2}) + testMQTTFlush(t, mc, nil, r) + + for pubQoS, expectedCounts := range map[byte]map[byte]int{ + 0: {0: 3}, + 1: {0: 1, 1: 2}, + 2: {0: 1, 1: 1, 2: 1}, + } { + t.Run(fmt.Sprintf("pubQoS %v", pubQoS), func(t *testing.T) { + pubPI := uint16(456) + + testMQTTPublish(t, mcp, mpr, pubQoS, false, false, topic, pubPI, []byte("msg")) + + qosCounts := map[byte]int{} + delivered := map[uint16]byte{} + + // We have 3 
subscriptions, each should receive the message, with the
+			// QoS that may be "trimmed" to that of the subscription.
+			for i := 0; i < 3; i++ {
+				flags, pi := testMQTTGetPubMsg(t, mc, r, topic, []byte("msg"))
+				delivered[pi] = flags
+				qosCounts[mqttGetQoS(flags)]++
+			}
+
+			for pi, flags := range delivered {
+				switch mqttGetQoS(flags) {
+				case 1:
+					testMQTTSendPIPacket(mqttPacketPubAck, t, mc, pi)
+
+				case 2:
+					testMQTTSendPIPacket(mqttPacketPubRec, t, mc, pi)
+					testMQTTReadPIPacket(mqttPacketPubRel, t, r, pi)
+					testMQTTSendPIPacket(mqttPacketPubComp, t, mc, pi)
+				}
+			}
+
+			if !reflect.DeepEqual(qosCounts, expectedCounts) {
+				t.Fatalf("Expected QoS %#v, got %#v", expectedCounts, qosCounts)
+			}
+		})
+	}
+}
+
+func TestMQTTSubQoS1(t *testing.T) {
 	o := testMQTTDefaultOptions()
 	s := testMQTTRunServer(t, o)
 	defer testMQTTShutdownServer(s)
@@ -2252,8 +2335,8 @@ func TestMQTTSubQoS(t *testing.T) {
 	if pi1 == pi2 {
 		t.Fatalf("packet identifier for message 1: %v should be different from message 2", pi1)
 	}
-	testMQTTSendPubAck(t, mc, pi1)
-	testMQTTSendPubAck(t, mc, pi2)
+	testMQTTSendPIPacket(mqttPacketPubAck, t, mc, pi1)
+	testMQTTSendPIPacket(mqttPacketPubAck, t, mc, pi2)
 }
 
 func getSubQoS(sub *subscription) int {
@@ -2286,8 +2369,8 @@ func TestMQTTSubDups(t *testing.T) {
 	// And also with separate SUBSCRIBE protocols
 	testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: "bar", qos: 0}}, []byte{0})
-	// Ask for QoS 2 but server will downgrade to 1
-	testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: "bar", qos: 2}}, []byte{1})
+	// Ask for QoS 1, which the server supports, so no downgrade is expected
+	testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: "bar", qos: 1}}, []byte{1})
 	testMQTTFlush(t, mc, nil, r)
 
 	// Publish and test msg received only once
@@ -2610,26 +2693,25 @@ func TestMQTTSubWithNATSStream(t *testing.T) {
 }
 
 func TestMQTTTrackPendingOverrun(t *testing.T) {
-	sess := &mqttSession{pending: make(map[uint16]*mqttPending)}
-	sub := &subscription{mqtt: &mqttSub{qos: 1}}
+	sess := mqttSession{}
 
-	sess.ppi =
0xFFFF - pi, _ := sess.trackPending(1, _EMPTY_, sub) + sess.last_pi = 0xFFFF + pi := sess.trackPublishRetained() if pi != 1 { t.Fatalf("Expected 1, got %v", pi) } p := &mqttPending{} for i := 1; i <= 0xFFFF; i++ { - sess.pending[uint16(i)] = p + sess.pendingPublish[uint16(i)] = p } - pi, _ = sess.trackPending(1, _EMPTY_, sub) + pi, _ = sess.trackPublish("test", "test") if pi != 0 { t.Fatalf("Expected 0, got %v", pi) } - delete(sess.pending, 1234) - pi, _ = sess.trackPending(1, _EMPTY_, sub) + delete(sess.pendingPublish, 1234) + pi = sess.trackPublishRetained() if pi != 1234 { t.Fatalf("Expected 1234, got %v", pi) } @@ -4694,7 +4776,7 @@ func TestMQTTRedeliveryAckWait(t *testing.T) { } } // Ack first message - testMQTTSendPubAck(t, c, 1) + testMQTTSendPIPacket(mqttPacketPubAck, t, c, 1) // Redelivery should only be for second message now for i := 0; i < 2; i++ { flags := mqttPubQos1 | mqttPubFlagDup @@ -4715,7 +4797,7 @@ func TestMQTTRedeliveryAckWait(t *testing.T) { t.Fatalf("Unexpected pi to be 2, got %v", pi) } // Now ack second message - testMQTTSendPubAck(t, c, 2) + testMQTTSendPIPacket(mqttPacketPubAck, t, c, 2) // Flush to make sure it is processed before checking client's maps testMQTTFlush(t, c, nil, r) @@ -4724,7 +4806,7 @@ func TestMQTTRedeliveryAckWait(t *testing.T) { mc.mu.Lock() sess := mc.mqtt.sess sess.mu.Lock() - lpi := len(sess.pending) + lpi := len(sess.pendingPublish) var lsseq int for _, sseqToPi := range sess.cpending { lsseq += len(sseqToPi) @@ -4736,6 +4818,253 @@ func TestMQTTRedeliveryAckWait(t *testing.T) { } } +// - [MQTT-3.10.4-3] If a Server deletes a Subscription It MUST complete the +// delivery of any QoS 1 or QoS 2 messages which it has started to send to the +// Client. +// +// Test flow: +// - Subscribe to foo, publish 3 QoS2 messages. +// - After one is PUBCOMP-ed, and one is PUBREC-ed, Unsubscribe. +// - See that the remaining 2 are fully delivered. 
+func TestMQTTQoS2InflightMsgsDeliveredAfterUnsubscribe(t *testing.T) { + o := testMQTTDefaultOptions() + o.MQTT.AckWait = 10 * time.Millisecond + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + var qos2 byte = 2 + cisub := &mqttConnInfo{clientID: "sub", cleanSess: true} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{{filter: "foo", qos: qos2}}, []byte{qos2}) + + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + // send 3 messages + testMQTTPublish(t, cp, rp, qos2, false, false, "foo", 441, []byte("data1")) + testMQTTPublish(t, cp, rp, qos2, false, false, "foo", 442, []byte("data2")) + testMQTTPublish(t, cp, rp, qos2, false, false, "foo", 443, []byte("data3")) + + testMQTTDisconnect(t, cp, nil) + cp.Close() + + subPI1 := testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQoS2, []byte("data1")) + subPI2 := testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQoS2, []byte("data2")) + // subPI3 := testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQoS2, []byte("data3")) + _ = testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQoS2, []byte("data3")) + + // fully receive first message + testMQTTSendPIPacket(mqttPacketPubRec, t, c, subPI1) + testMQTTReadPIPacket(mqttPacketPubRel, t, r, subPI1) + testMQTTSendPIPacket(mqttPacketPubComp, t, c, subPI1) + + // Do not PUBCOMP the 2nd message yet. + testMQTTSendPIPacket(mqttPacketPubRec, t, c, subPI2) + testMQTTReadPIPacket(mqttPacketPubRel, t, r, subPI2) + + // Unsubscribe + testMQTTUnsub(t, 1, c, r, []*mqttFilter{{filter: "foo", qos: qos2}}) + + // We expect that PI2 and PI3 will continue to be delivered, from their + // respective states. 
+ gotPI2PubRel := false + + // TODO: Currently, we do not get the unacknowledged PUBLISH re-delivered + // after an UNSUBSCRIBE. Ongoing discussion if we should/must. + // gotPI3Publish := false + // gotPI3PubRel := false + for !gotPI2PubRel /* || !gotPI3Publish || !gotPI3PubRel */ { + b, _ /* len */ := testMQTTReadPacket(t, r) + switch b & mqttPacketMask { + case mqttPacketPubRel: + pi, err := r.readUint16("packet identifier") + if err != nil { + t.Fatalf("got unexpected error: %v", err) + } + switch pi { + case subPI2: + testMQTTSendPIPacket(mqttPacketPubComp, t, c, pi) + gotPI2PubRel = true + // case subPI3: + // testMQTTSendPIPacket(mqttPacketPubComp, t, c, pi) + // gotPI3PubRel = true + default: + t.Fatalf("Expected PI %v got: %v", subPI2, pi) + } + + // case mqttPacketPub: + // _, pi, _ := testMQTTGetPubMsgExEx(t, c, r, b, len, "foo", []byte("data3")) + // if pi != subPI3 { + // t.Fatalf("Expected PI %v got: %v", subPI3, pi) + // } + // gotPI3Publish = true + // testMQTTSendPIPacket(mqttPacketPubRec, t, c, subPI3) + + default: + t.Fatalf("Unexpected packet type: %v", b&mqttPacketMask) + } + } + + testMQTTExpectNothing(t, r) +} + +func TestMQTTQoS2RejectPublishDuplicates(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + var qos2 byte = 2 + cisub := &mqttConnInfo{clientID: "sub", cleanSess: true} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{{filter: "foo", qos: qos2}}, []byte{qos2}) + + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + // Publish 3 different with same PI before we get any PUBREC back, then + // complete the PUBREL/PUBCOMP flow as needed. 
Only one message (first + // payload) should be delivered. PUBRECs, + var pubPI uint16 = 444 + testMQTTPublishNoAcks(t, cp, qos2, false, false, "foo", pubPI, []byte("data1")) + testMQTTPublishNoAcks(t, cp, qos2, true, false, "foo", pubPI, []byte("data2")) + testMQTTPublishNoAcks(t, cp, qos2, false, false, "foo", pubPI, []byte("data3")) + + for i := 0; i < 3; i++ { + // [MQTT-4.3.3-1] The receiver + // + // - MUST respond with a PUBREC containing the Packet Identifier from + // the incoming PUBLISH Packet, having accepted ownership of the + // Application Message. + // + // - Until it has received the corresponding PUBREL packet, the Receiver + // MUST acknowledge any subsequent PUBLISH packet with the same Packet + // Identifier by sending a PUBREC. It MUST NOT cause duplicate messages + // to be delivered to any onward recipients in this case. + testMQTTReadPIPacket(mqttPacketPubRec, t, rp, pubPI) + } + for i := 0; i < 3; i++ { + testMQTTSendPIPacket(mqttPacketPubRel, t, cp, pubPI) + } + for i := 0; i < 3; i++ { + // [MQTT-4.3.3-1] MUST respond to a PUBREL packet by sending a PUBCOMP + // packet containing the same Packet Identifier as the PUBREL. + testMQTTReadPIPacket(mqttPacketPubComp, t, rp, pubPI) + } + + // [MQTT-4.3.3-1] After it has sent a PUBCOMP, the receiver MUST treat any + // subsequent PUBLISH packet that contains that Packet Identifier as being a + // new publication. + // + // Publish another message, identical to the first one. Since the server + // already sent us a PUBCOMP, it will deliver this message, for a total of 2 + // delivered. + testMQTTPublish(t, cp, rp, qos2, false, false, "foo", pubPI, []byte("data5")) + + testMQTTDisconnect(t, cp, nil) + cp.Close() + + // Verify we got a total of 2 messages. 
+ subPI1 := testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQoS2, []byte("data1")) + subPI2 := testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQoS2, []byte("data5")) + for _, pi := range []uint16{subPI1, subPI2} { + testMQTTSendPIPacket(mqttPacketPubRec, t, c, pi) + testMQTTReadPIPacket(mqttPacketPubRel, t, r, pi) + testMQTTSendPIPacket(mqttPacketPubComp, t, c, pi) + } + testMQTTExpectNothing(t, r) +} + +func TestMQTTQoS2RetriesPublish(t *testing.T) { + o := testMQTTDefaultOptions() + o.MQTT.AckWait = 10 * time.Millisecond + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + var qos2 byte = 2 + cisub := &mqttConnInfo{clientID: "sub", cleanSess: true} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{{filter: "foo", qos: qos2}}, []byte{qos2}) + + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + // Publish a message and close the pub connection. + var pubPI uint16 = 444 + testMQTTPublish(t, cp, rp, qos2, false, false, "foo", pubPI, []byte("data1")) + testMQTTDisconnect(t, cp, nil) + cp.Close() + + // See that we got the message delivered to the sub, but don't PUBREC it + // yet. + subPI := testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQoS2, []byte("data1")) + + // See that the message is redelivered again 3 times, with the DUP on, before we PUBREC it. + for i := 0; i < 3; i++ { + expectedFlags := mqttPubQoS2 | mqttPubFlagDup + pi := testMQTTCheckPubMsgNoAck(t, c, r, "foo", expectedFlags, []byte("data1")) + if pi != subPI { + t.Fatalf("Expected pi to be %v, got %v", subPI, pi) + } + } + + // Finish the exchange and make sure there are no more attempts. 
+ testMQTTSendPIPacket(mqttPacketPubRec, t, c, subPI) + testMQTTReadPIPacket(mqttPacketPubRel, t, r, subPI) + testMQTTSendPIPacket(mqttPacketPubComp, t, c, subPI) + testMQTTExpectNothing(t, r) +} + +func TestMQTTQoS2RetriesPubRel(t *testing.T) { + o := testMQTTDefaultOptions() + o.MQTT.AckWait = 10 * time.Millisecond + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + var qos2 byte = 2 + cisub := &mqttConnInfo{clientID: "sub", cleanSess: true} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{{filter: "foo", qos: qos2}}, []byte{qos2}) + + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + // Publish a message and close the pub connection. + var pubPI uint16 = 444 + testMQTTPublish(t, cp, rp, qos2, false, false, "foo", pubPI, []byte("data1")) + testMQTTDisconnect(t, cp, nil) + cp.Close() + + // See that we got the message delivered to the sub, PUBREC it and expect a + // PUBREL from the server. + subPI := testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQoS2, []byte("data1")) + testMQTTSendPIPacket(mqttPacketPubRec, t, c, subPI) + + // See that we get PUBREL redelivered several times, there's no DUP flag to + // check. + testMQTTReadPIPacket(mqttPacketPubRel, t, r, subPI) + testMQTTReadPIPacket(mqttPacketPubRel, t, r, subPI) + testMQTTReadPIPacket(mqttPacketPubRel, t, r, subPI) + + // Finish the exchange and make sure there are no more attempts. 
+ testMQTTSendPIPacket(mqttPacketPubComp, t, c, subPI) + testMQTTExpectNothing(t, r) +} + func TestMQTTAckWaitConfigChange(t *testing.T) { o := testMQTTDefaultOptions() o.MQTT.AckWait = 250 * time.Millisecond @@ -4842,7 +5171,7 @@ func TestMQTTUnsubscribeWithPendingAcks(t *testing.T) { mc.mu.Lock() sess := mc.mqtt.sess sess.mu.Lock() - pal := len(sess.pending) + pal := len(sess.pendingPublish) sess.mu.Unlock() mc.mu.Unlock() if pal != 0 { @@ -4880,7 +5209,7 @@ func TestMQTTMaxAckPending(t *testing.T) { testMQTTExpectNothing(t, r) // Now ack first message - testMQTTSendPubAck(t, c, pi) + testMQTTSendPIPacket(mqttPacketPubAck, t, c, pi) // Now we should receive message 2 testMQTTCheckPubMsg(t, c, r, "foo", mqttPubQos1, []byte("msg2")) testMQTTDisconnect(t, c, nil) @@ -4909,7 +5238,7 @@ func TestMQTTMaxAckPending(t *testing.T) { testMQTTExpectNothing(t, r) // Ack and get the next - testMQTTSendPubAck(t, c, pi) + testMQTTSendPIPacket(mqttPacketPubAck, t, c, pi) testMQTTCheckPubMsg(t, c, r, "foo", mqttPubQos1, []byte("msg4")) // Make sure this message gets ack'ed @@ -4918,7 +5247,7 @@ func TestMQTTMaxAckPending(t *testing.T) { mcli.mu.Lock() sess := mcli.mqtt.sess sess.mu.Lock() - np := len(sess.pending) + np := len(sess.pendingPublish) sess.mu.Unlock() mcli.mu.Unlock() if np != 0 { @@ -4956,7 +5285,7 @@ func TestMQTTMaxAckPending(t *testing.T) { testMQTTExpectNothing(t, r) // Ack and get the next - testMQTTSendPubAck(t, c, pi) + testMQTTSendPIPacket(mqttPacketPubAck, t, c, pi) testMQTTCheckPubMsg(t, c, r, "foo", mqttPubQos1, []byte("msg6")) } @@ -4991,7 +5320,7 @@ func TestMQTTMaxAckPendingForMultipleSubs(t *testing.T) { testMQTTExpectNothing(t, r) // Ack the first message. 
-	testMQTTSendPubAck(t, c, pi)
+	testMQTTSendPIPacket(mqttPacketPubAck, t, c, pi)
 
 	// Now we should get the second message
 	testMQTTCheckPubMsg(t, c, r, "bar", mqttPubQos1|mqttPubFlagDup, []byte("msg2"))
diff --git a/server/opts.go b/server/opts.go
index fd5f3234..234b80d4 100644
--- a/server/opts.go
+++ b/server/opts.go
@@ -517,21 +517,24 @@ type MQTTOpts struct {
 	// Set of allowable certificates
 	TLSPinnedCerts PinnedCertSet
 
-	// AckWait is the amount of time after which a QoS 1 message sent to
-	// a client is redelivered as a DUPLICATE if the server has not
-	// received the PUBACK on the original Packet Identifier.
-	// The value has to be positive.
-	// Zero will cause the server to use the default value (30 seconds).
-	// Note that changes to this option is applied only to new MQTT subscriptions.
+	// AckWait is the amount of time after which a QoS 1 or 2 message sent to a
+	// client is redelivered as a DUPLICATE if the server has not received the
+	// PUBACK (or, for QoS 2, the PUBREC) on the original Packet Identifier.
+	// The same value applies to PubRel redelivery. The value has to be
+	// positive. Zero will cause the server to use the default value (30
+	// seconds). Note that changes to this option are applied only to new MQTT
+	// subscriptions (or sessions for PubRels).
 	AckWait time.Duration
 
-	// MaxAckPending is the amount of QoS 1 messages the server can send to
-	// a subscription without receiving any PUBACK for those messages.
-	// The valid range is [0..65535].
+	// MaxAckPending is the amount of QoS 1 and 2 messages (combined) the
+	// server can send to a subscription without receiving any PUBACK (or
+	// PUBREC) for those messages. The valid range is [0..65535].
+	//
 	// The total of subscriptions' MaxAckPending on a given session cannot
-	// exceed 65535. Attempting to create a subscription that would bring
-	// the total above the limit would result in the server returning 0x80
-	// in the SUBACK for this subscription.
+	// exceed 65535.
Attempting to create a subscription that would bring the + // total above the limit would result in the server returning 0x80 in the + // SUBACK for this subscription. + // // Due to how the NATS Server handles the MQTT "#" wildcard, each // subscription ending with "#" will use 2 times the MaxAckPending value. // Note that changes to this option is applied only to new subscriptions.