diff --git a/server/accounts.go b/server/accounts.go index 9a68c032..02d32028 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -1726,7 +1726,7 @@ func (a *Account) subscribeInternal(subject string, cb msgHandler) (*subscriptio return nil, fmt.Errorf("no internal account client") } - return c.processSub([]byte(subject), nil, []byte(sid), cb, false) + return c.processSub(c.createSub([]byte(subject), nil, []byte(sid), cb), false) } // This will add an account subscription that matches the "from" from a service import entry. @@ -1751,7 +1751,7 @@ func (a *Account) addServiceImportSub(si *serviceImport) error { cb := func(sub *subscription, c *client, subject, reply string, msg []byte) { c.processServiceImport(si, a, msg) } - _, err := c.processSub([]byte(subject), nil, []byte(sid), cb, true) + _, err := c.processSub(c.createSub([]byte(subject), nil, []byte(sid), cb), true) return err } @@ -1951,7 +1951,7 @@ func (a *Account) createRespWildcard() []byte { a.mu.Unlock() // Create subscription and internal callback for all the wildcard response subjects. - c.processSub(wcsub, nil, []byte(sid), a.processServiceImportResponse, false) + c.processSub(c.createSub(wcsub, nil, []byte(sid), a.processServiceImportResponse), false) return pre } diff --git a/server/auth.go b/server/auth.go index 7043b2cc..da27dc7e 100644 --- a/server/auth.go +++ b/server/auth.go @@ -253,6 +253,8 @@ func (s *Server) configureAuthorization() { // Do similar for websocket config s.wsConfigAuth(&opts.Websocket) + // And for mqtt config + s.mqttConfigAuth(&opts.MQTT) } // Takes the given slices of NkeyUser and User options and build @@ -343,11 +345,15 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) boo ) s.mu.Lock() authRequired := s.info.AuthRequired - // c.ws is immutable, but may need lock if we get race reports. - if !authRequired && c.ws != nil { - // If no auth required for regular clients, then check if - // we have an override for websocket clients. - authRequired = s.websocket.authOverride + // c.ws/mqtt is immutable, but may need lock if we get race reports. + if !authRequired { + if c.mqtt != nil { + authRequired = s.mqtt.authOverride + } else if c.ws != nil { + // If no auth required for regular clients, then check if + // we have an override for websocket clients. + authRequired = s.websocket.authOverride + } } if !authRequired { // TODO(dlc) - If they send us credentials should we fail? @@ -361,7 +367,20 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) boo noAuthUser string ) tlsMap := opts.TLSMap - if c.ws != nil { + if c.mqtt != nil { + mo := &opts.MQTT + // Always override TLSMap. + tlsMap = mo.TLSMap + // The rest depends on if there was any auth override in + // the mqtt's config. + if s.mqtt.authOverride { + noAuthUser = mo.NoAuthUser + username = mo.Username + password = mo.Password + token = mo.Token + ao = true + } + } else if c.ws != nil { wo := &opts.Websocket // Always override TLSMap. tlsMap = wo.TLSMap @@ -998,7 +1017,7 @@ func validateAllowedConnectionTypes(m map[string]struct{}) error { for ct := range m { ctuc := strings.ToUpper(ct) switch ctuc { - case jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket, jwt.ConnectionTypeLeafnode: + case jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket, jwt.ConnectionTypeLeafnode, jwt.ConnectionTypeMqtt: default: return fmt.Errorf("unknown connection type %q", ct) } diff --git a/server/client.go b/server/client.go index 6b9bb532..3a68b9d2 100644 --- a/server/client.go +++ b/server/client.go @@ -178,6 +178,7 @@ const ( NoRespondersRequiresHeaders ClusterNameConflict DuplicateRemoteLeafnodeConnection + DuplicateClientID ) // Some flags passed to processMsgResults @@ -237,6 +238,7 @@ type client struct { gw *gateway leaf *leaf ws *websocket + mqtt *mqtt // To keep track of gateway replies mapping gwrm map[string]*gwReplyMap @@ -442,6 +444,7 @@ type subscription struct { max int64 qw int32 closed int32 + mqtt *mqttSub } // Indicate that this subscription is closed. @@ -540,6 +543,8 @@ func (c *client) initClient() { name := "cid" if c.ws != nil { name = "wid" + } else if c.mqtt != nil { + name = "mid" } c.ncs.Store(fmt.Sprintf("%s - %s:%d", conn, name, c.cid)) case ROUTER: @@ -964,6 +969,9 @@ func (c *client) readLoop(pre []byte) { } nc := c.nc ws := c.ws != nil + if c.mqtt != nil { + c.mqtt.r = &mqttReader{reader: nc} + } c.in.rsz = startBufSize // Snapshot max control line since currently can not be changed on reload and we // were checking it on each call to parse. If this changes and we allow MaxControlLine @@ -983,6 +991,9 @@ func (c *client) readLoop(pre []byte) { c.mu.Unlock() defer func() { + if c.mqtt != nil { + s.mqttHandleWill(c) + } // These are used only in the readloop, so we can set them to nil // on exit of the readLoop. c.in.results, c.in.pacache = nil, nil @@ -1683,7 +1694,6 @@ func (c *client) processConnect(arg []byte) error { // By default register with the global account. c.registerWithAccount(srv.globalAccount()) } - } switch kind { @@ -1703,7 +1713,6 @@ func (c *client) processConnect(arg []byte) error { c.sendErr(ErrNoRespondersRequiresHeaders.Error()) c.closeConnection(NoRespondersRequiresHeaders) return ErrNoRespondersRequiresHeaders - } if verbose { c.sendOK() @@ -1933,6 +1942,9 @@ func (c *client) sendRTTPing() bool { // the c.rtt is 0 and wants to force an update by sending a PING. // Client lock held on entry. func (c *client) sendRTTPingLocked() bool { + if c.mqtt != nil { + return false + } // Most client libs send a CONNECT+PING and wait for a PONG from the // server. So if firstPongSent flag is set, it is ok for server to // send the PING. But in case we have client libs that don't do that, @@ -1978,7 +1990,9 @@ func (c *client) sendErr(err string) { if c.trace { c.traceOutOp("-ERR", []byte(err)) } - c.enqueueProto([]byte(fmt.Sprintf(errProto, err))) + if c.mqtt == nil { + c.enqueueProto([]byte(fmt.Sprintf(errProto, err))) + } c.mu.Unlock() } @@ -2212,32 +2226,30 @@ func (c *client) parseSub(argo []byte, noForward bool) error { arg := make([]byte, len(argo)) copy(arg, argo) args := splitArg(arg) - var ( - subject []byte - queue []byte - sid []byte - ) + sub := &subscription{client: c} switch len(args) { case 2: - subject = args[0] - queue = nil - sid = args[1] + sub.subject = args[0] + sub.queue = nil + sub.sid = args[1] case 3: - subject = args[0] - queue = args[1] - sid = args[2] + sub.subject = args[0] + sub.queue = args[1] + sub.sid = args[2] default: return fmt.Errorf("processSub Parse Error: '%s'", arg) } // If there was an error, it has been sent to the client. We don't return an // error here to not close the connection as a parsing error. - c.processSub(subject, queue, sid, nil, noForward) + c.processSub(sub, noForward) return nil } -func (c *client) processSub(subject, queue, bsid []byte, cb msgHandler, noForward bool) (*subscription, error) { - // Create the subscription - sub := &subscription{client: c, subject: subject, queue: queue, sid: bsid, icb: cb} +func (c *client) createSub(subject, queue, sid []byte, cb msgHandler) *subscription { + return &subscription{client: c, subject: subject, queue: queue, sid: sid, icb: cb} +} + +func (c *client) processSub(sub *subscription, noForward bool) (*subscription, error) { c.mu.Lock() @@ -2254,7 +2266,7 @@ func (c *client) processSub(subject, queue, bsid []byte, cb msgHandler, noForwar // This check does not apply to SYSTEM or JETSTREAM or ACCOUNT clients (because they don't have a `nc`...) if c.isClosed() && (kind != SYSTEM && kind != JETSTREAM && kind != ACCOUNT) { c.mu.Unlock() - return sub, nil + return nil, nil } // Check permissions if applicable. @@ -2301,6 +2313,9 @@ func (c *client) processSub(subject, queue, bsid []byte, cb msgHandler, noForwar updateGWs = c.srv.gateway.enabled } } + } else if es.mqtt != nil && sub.mqtt != nil { + es.mqtt.prm = sub.mqtt.prm + es.mqtt.qos = sub.mqtt.qos } // Unlocked from here onward c.mu.Unlock() @@ -3311,6 +3326,11 @@ func (c *client) processInboundClientMsg(msg []byte) bool { c.sendOK() } + // If MQTT client, check for retain flag now that we have passed permissions check + if c.mqtt != nil { + c.mqttHandlePubRetain() + } + // Check if this client's gateway replies map is not empty if atomic.LoadInt32(&c.cgwrt) > 0 && c.handleGWReplyMap(msg) { return true @@ -3983,7 +4003,7 @@ func (c *client) processPingTimer() { c.mu.Lock() c.ping.tmr = nil // Check if connection is still opened - if c.isClosed() { + if c.isClosed() || c.mqtt != nil { c.mu.Unlock() return } @@ -4042,7 +4062,7 @@ func adjustPingIntervalForGateway(d time.Duration) time.Duration { // Lock should be held func (c *client) setPingTimer() { - if c.srv == nil { + if c.srv == nil || c.mqtt != nil { return } d := c.srv.getOpts().PingInterval @@ -4596,7 +4616,7 @@ func convertAllowedConnectionTypes(cts []string) (map[string]struct{}, error) { for _, i := range cts { i = strings.ToUpper(i) switch i { - case jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket, jwt.ConnectionTypeLeafnode: + case jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket, jwt.ConnectionTypeLeafnode, jwt.ConnectionTypeMqtt: m[i] = struct{}{} default: unknown = append(unknown, i) @@ -4627,6 +4647,9 @@ func (c *client) connectionTypeAllowed(acts map[string]struct{}) bool { if c.ws != nil { want = jwt.ConnectionTypeWebsocket } + if c.mqtt != nil { + want = jwt.ConnectionTypeMqtt + } _, ok := acts[want] return ok } diff --git a/server/client_test.go b/server/client_test.go index f207854e..68e48dcc 100644 --- a/server/client_test.go +++ b/server/client_test.go @@ -85,7 +85,7 @@ func createClientAsync(ch chan *client, s *Server, cli net.Conn) { s.grWG.Add(1) } go func() { - c := s.createClient(cli, nil) + c := s.createClient(cli, nil, nil) // Must be here to suppress +OK c.opts.Verbose = false if startWriteLoop { @@ -2317,7 +2317,7 @@ func TestCloseConnectionVeryEarly(t *testing.T) { // Call again with this closed connection. Alternatively, we // would have to call with a fake connection that implements // net.Conn but returns an error on Write. - s.createClient(c, nil) + s.createClient(c, nil, nil) // This connection should not have been added to the server. checkClientsCount(t, s, 0) @@ -2354,18 +2354,23 @@ func TestClientConnectionName(t *testing.T) { kind int kindStr string ws bool + mqtt bool }{ - {"client", CLIENT, "cid:", false}, - {"ws client", CLIENT, "wid:", true}, - {"route", ROUTER, "rid:", false}, - {"gateway", GATEWAY, "gid:", false}, - {"leafnode", LEAF, "lid:", false}, + {"client", CLIENT, "cid:", false, false}, + {"ws client", CLIENT, "wid:", true, false}, + {"mqtt client", CLIENT, "mid:", false, true}, + {"route", ROUTER, "rid:", false, false}, + {"gateway", GATEWAY, "gid:", false, false}, + {"leafnode", LEAF, "lid:", false, false}, } { t.Run(test.name, func(t *testing.T) { c := &client{srv: s, nc: &connString{}, kind: test.kind} if test.ws { c.ws = &websocket{} } + if test.mqtt { + c.mqtt = &mqtt{} + } c.initClient() if host := "fe80::abc:def:ghi:123%utun0"; host != c.host { diff --git a/server/config_check_test.go b/server/config_check_test.go index 166aa761..45d726e9 100644 --- a/server/config_check_test.go +++ b/server/config_check_test.go @@ -1301,8 +1301,8 @@ func TestConfigCheck(t *testing.T) { user: user1 password: pwd users = [{user: user2, password: pwd}] - } - } + } + } `, err: errors.New("can not have a single user/pass and a users array"), errorLine: 3, @@ -1312,17 +1312,75 @@ func TestConfigCheck(t *testing.T) { name: "duplicate usernames in leafnode authorization", config: ` leafnodes { - authorization { - users = [ - {user: user, password: pwd} - {user: user, password: pwd} - ] - } - } + authorization { + users = [ + {user: user, password: pwd} + {user: user, password: pwd} + ] + } + } `, err: errors.New(`duplicate user "user" detected in leafnode authorization`), errorLine: 3, - errorPos: 20, + errorPos: 21, + }, + { + name: "mqtt bad type", + config: ` + mqtt [ + "wrong" + ] + `, + err: errors.New(`Expected mqtt to be a map, got []interface {}`), + errorLine: 2, + errorPos: 17, + }, + { + name: "mqtt bad listen", + config: ` + mqtt { + listen: "xxxxxxxx" + } + `, + err: errors.New(`could not parse address string "xxxxxxxx"`), + errorLine: 3, + errorPos: 21, + }, + { + name: "mqtt bad host", + config: ` + mqtt { + host: 1234 + } + `, + err: errors.New(`interface conversion: interface {} is int64, not string`), + errorLine: 3, + errorPos: 21, + }, + { + name: "mqtt bad port", + config: ` + mqtt { + port: "abc" + } + `, + err: errors.New(`interface conversion: interface {} is string, not int64`), + errorLine: 3, + errorPos: 21, + }, + { + name: "mqtt bad TLS", + config: ` + mqtt { + port: -1 + tls { + cert_file: "./configs/certs/server.pem" + } + } + `, + err: errors.New(`missing 'key_file' in TLS configuration`), + errorLine: 4, + errorPos: 21, }, { name: "connection types wrong type", diff --git a/server/consumer.go b/server/consumer.go index edd618a4..ef0e71a4 100644 --- a/server/consumer.go +++ b/server/consumer.go @@ -213,6 +213,13 @@ const ( ) func (mset *Stream) AddConsumer(config *ConsumerConfig) (*Consumer, error) { + if name := mset.Name(); strings.HasPrefix(name, mqttStreamNamePrefix) { + return nil, fmt.Errorf("stream prefix %q is reserved for MQTT, unable to create consumer on %q", mqttStreamNamePrefix, name) + } + return mset.addConsumerCheckInterest(config, true) +} + +func (mset *Stream) addConsumerCheckInterest(config *ConsumerConfig, checkInterest bool) (*Consumer, error) { if config == nil { return nil, fmt.Errorf("consumer config required") } @@ -350,7 +357,7 @@ func (mset *Stream) AddConsumer(config *ConsumerConfig) (*Consumer, error) { } else { // If we are a push mode and not active and the only difference // is deliver subject then update and return. - if configsEqualSansDelivery(ocfg, *config) && eo.hasNoLocalInterest() { + if configsEqualSansDelivery(ocfg, *config) && (!checkInterest || eo.hasNoLocalInterest()) { eo.updateDeliverSubject(config.DeliverSubject) return eo, nil } else { @@ -2177,9 +2184,22 @@ func (mset *Stream) deliveryFormsCycle(deliverySubject string) bool { return false } -// This is same as check for delivery cycle. +// Check that the subject is a subset of the stream's configured subjects, +// or returns true if the stream has been created with no subject. func (mset *Stream) validSubject(partitionSubject string) bool { - return mset.deliveryFormsCycle(partitionSubject) + mset.mu.RLock() + defer mset.mu.RUnlock() + + if mset.nosubj && len(mset.config.Subjects) == 0 { + return true + } + + for _, subject := range mset.config.Subjects { + if subjectIsSubsetMatch(partitionSubject, subject) { + return true + } + } + return false } // SetInActiveDeleteThreshold sets the delete threshold for how long to wait diff --git a/server/events.go b/server/events.go index 449198cd..3e9c84dc 100644 --- a/server/events.go +++ b/server/events.go @@ -1410,7 +1410,8 @@ func (s *Server) systemSubscribe(subject, queue string, internalOnly bool, cb ms if queue != "" { q = []byte(queue) } - return c.processSub([]byte(subject), q, []byte(sid), cb, internalOnly) + // Now create the subscription + return c.processSub(c.createSub([]byte(subject), q, []byte(sid), cb), internalOnly) } func (s *Server) sysUnsubscribe(sub *subscription) { diff --git a/server/jetstream.go b/server/jetstream.go index 9a591471..0c651a01 100644 --- a/server/jetstream.go +++ b/server/jetstream.go @@ -533,7 +533,12 @@ func (a *Account) EnableJetStream(limits *JetStreamAccountLimits) error { s.Warnf(" Error adding Stream %q to Template %q: %v", cfg.Name, cfg.Template, err) } } - mset, err := a.AddStream(&cfg.StreamConfig) + // TODO: We should not rely on the stream name. + // However, having a StreamConfig property, such as AllowNoSubject, + // was not accepted because it does not make sense outside of the + // MQTT use-case. So need to revisit this. + mqtt := cfg.StreamConfig.Name == mqttStreamName + mset, err := a.addStreamWithStore(&cfg.StreamConfig, nil, mqtt) if err != nil { s.Warnf(" Error recreating Stream %q: %v", cfg.Name, err) continue @@ -578,7 +583,7 @@ func (a *Account) EnableJetStream(limits *JetStreamAccountLimits) error { // the consumer can reconnect. We will create it as a durable and switch it. cfg.ConsumerConfig.Durable = ofi.Name() } - obs, err := mset.AddConsumer(&cfg.ConsumerConfig) + obs, err := mset.addConsumerCheckInterest(&cfg.ConsumerConfig, !mqtt) if err != nil { s.Warnf(" Error adding Consumer: %v", err) continue @@ -1060,7 +1065,7 @@ func (a *Account) AddStreamTemplate(tc *StreamTemplateConfig) (*StreamTemplate, // FIXME(dlc) - Hacky tcopy := tc.deepCopy() tcopy.Config.Name = "_" - cfg, err := checkStreamCfg(tcopy.Config) + cfg, err := checkStreamCfg(tcopy.Config, false) if err != nil { return nil, err } @@ -1115,7 +1120,7 @@ func (t *StreamTemplate) createTemplateSubscriptions() error { sid := 1 for _, subject := range t.Config.Subjects { // Now create the subscription - if _, err := c.processSub([]byte(subject), nil, []byte(strconv.Itoa(sid)), t.processInboundTemplateMsg, false); err != nil { + if _, err := c.processSub(c.createSub([]byte(subject), nil, []byte(strconv.Itoa(sid)), t.processInboundTemplateMsg), false); err != nil { c.acc.DeleteStreamTemplate(t.Name) return err } diff --git a/server/monitor.go b/server/monitor.go index 4bae0fda..6c941f2a 100644 --- a/server/monitor.go +++ b/server/monitor.go @@ -1915,6 +1915,8 @@ func (reason ClosedState) String() string { return "Cluster Name Conflict" case DuplicateRemoteLeafnodeConnection: return "Duplicate Remote LeafNode Connection" + case DuplicateClientID: + return "Duplicate Client ID" } return "Unknown State" diff --git a/server/mqtt.go b/server/mqtt.go new file mode 100644 index 00000000..4c508b18 --- /dev/null +++ b/server/mqtt.go @@ -0,0 +1,2515 @@ +// Copyright 2020 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "strconv" + "strings" + "sync" + "time" + "unicode/utf8" + + "github.com/nats-io/nuid" +) + +// References to "spec" here is from https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.pdf + +const ( + mqttPacketConnect = byte(0x10) + mqttPacketConnectAck = byte(0x20) + mqttPacketPub = byte(0x30) + mqttPacketPubAck = byte(0x40) + mqttPacketPubRec = byte(0x50) + mqttPacketPubRel = byte(0x60) + mqttPacketPubComp = byte(0x70) + mqttPacketSub = byte(0x80) + mqttPacketSubAck = byte(0x90) + mqttPacketUnsub = byte(0xa0) + mqttPacketUnsubAck = byte(0xb0) + mqttPacketPing = byte(0xc0) + mqttPacketPingResp = byte(0xd0) + mqttPacketDisconnect = byte(0xe0) + mqttPacketMask = byte(0xf0) + mqttPacketFlagMask = byte(0x0f) + + mqttProtoLevel = byte(0x4) + + // Connect flags + mqttConnFlagReserved = byte(0x1) + mqttConnFlagCleanSession = byte(0x2) + mqttConnFlagWillFlag = byte(0x04) + mqttConnFlagWillQoS = byte(0x18) + mqttConnFlagWillRetain = byte(0x20) + mqttConnFlagPasswordFlag = byte(0x40) + mqttConnFlagUsernameFlag = byte(0x80) + + // Publish flags + mqttPubFlagRetain = byte(0x01) + mqttPubFlagQoS = byte(0x06) + mqttPubFlagDup = byte(0x08) + mqttPubQos1 = byte(0x2) // 1 << 1 + + // Subscribe flags + mqttSubscribeFlags = byte(0x2) + mqttSubAckFailure = byte(0x80) + + // Unsubscribe flags + mqttUnsubscribeFlags = byte(0x2) + + // ConnAck returned codes + mqttConnAckRCConnectionAccepted = byte(0x0) + mqttConnAckRCUnacceptableProtocolVersion = byte(0x1) + mqttConnAckRCIdentifierRejected = byte(0x2) + mqttConnAckRCServerUnavailable = byte(0x3) + mqttConnAckRCBadUserOrPassword = byte(0x4) + mqttConnAckRCNotAuthorized = byte(0x5) + + // Topic/Filter characters + mqttTopicLevelSep = '/' + mqttSingleLevelWC = '+' + mqttMultiLevelWC = '#' + + // This is appended to the sid of a subscription that is + // created on the upper level subject because of the MQTT + // wildcard '#' semantic. + mqttMultiLevelSidSuffix = " fwc" + + // This is the prefix for NATS subscriptions subjects associated as delivery + // subject of JS consumer. We want to make them unique so will prevent users + // MQTT subscriptions to start with this. + mqttSubPrefix = "$MQTT.sub." + + // MQTT Stream names prefix. Will be used to prevent users from creating + // JS consumers on those. + mqttStreamNamePrefix = "$MQTT_" + + // Stream name for MQTT messages on a given account + mqttStreamName = mqttStreamNamePrefix + "msgs" + + // Stream name for MQTT retained messages on a given account + mqttRetainedMsgsStreamName = mqttStreamNamePrefix + "rmsgs" + + // Stream name for MQTT sessions on a given account + mqttSessionsStreamName = mqttStreamNamePrefix + "sessions" + + // Normally, MQTT server should not redeliver QoS 1 messages to clients, + // except after client reconnects. However, NATS Server will redeliver + // unacknowledged messages after this default interval. This can be + // changed with the server.Options.MQTT.AckWait option. + mqttDefaultAckWait = 30 * time.Second + + // This is the default for the outstanding number of pending QoS 1 + // messages sent to a session with QoS 1 subscriptions. + mqttDefaultMaxAckPending = 1024 +) + +var ( + mqttPingResponse = []byte{mqttPacketPingResp, 0x0} + mqttProtoName = []byte("MQTT") + mqttOldProtoName = []byte("MQIsdp") +) + +type srvMQTT struct { + listener net.Listener + authOverride bool + sessmgr mqttSessionManager +} + +type mqttSessionManager struct { + mu sync.RWMutex + sessions map[string]*mqttAccountSessionManager // key is account name +} + +type mqttAccountSessionManager struct { + mu sync.RWMutex + sstream *Stream // stream where sessions are recorded + mstream *Stream // messages stream + rstream *Stream // retained messages stream + sessions map[string]*mqttSession // key is MQTT client ID + sl *Sublist // sublist allowing to find retained messages for given subscription + retmsgs map[string]*mqttRetainedMsg // retained messages +} + +type mqttSession struct { + mu sync.Mutex + c *client + subs map[string]byte + cons map[string]*Consumer + stream *Stream + sseq uint64 // stream sequence where this sesion is recorded + pending map[uint16]*mqttPending // Key is the PUBLISH packet identifier sent to client and maps to a mqttPending record + cpending map[*Consumer]map[uint64]uint16 // For each JS consumer, the key is the stream sequence and maps to the PUBLISH packet identifier + ppi uint16 // publish packet identifier + maxp uint16 + stalled bool + clean bool +} + +type mqttPersistedSession struct { + ID string `json:"id,omitempty"` + Clean bool `json:"clean,omitempty"` + Subs map[string]byte `json:"subs,omitempty"` + Cons map[string]string `json:"cons,omitempty"` +} + +type mqttRetainedMsg struct { + Msg []byte `json:"msg,omitempty"` + Flags byte `json:"flags,omitempty"` + Source string `json:"source,omitempty"` + + // non exported + sseq uint64 + sub *subscription +} + +type mqttSub struct { + qos byte + // Pending serialization of retained messages to be sent when subscription is registered + prm *mqttWriter + // This is the corresponding JS consumer. This is applicable to a subscription that is + // done for QoS > 0 (the subscription attached to a JS consumer's delivery subject). + jsCons *Consumer +} + +type mqtt struct { + r *mqttReader + cp *mqttConnectProto + pp *mqttPublish + asm *mqttAccountSessionManager // quick reference to account session manager, immutable after processConnect() + sess *mqttSession // quick reference to session, immutable after processConnect() +} + +type mqttPending struct { + sseq uint64 // stream sequence + dseq uint64 // consumer delivery sequence + dcount uint64 // consumer delivery count + jsCons *Consumer // pointer to JS consumer (to which we will call ackMsg() on with above info) +} + +type mqttConnectProto struct { + clientID string + rd time.Duration + will *mqttWill + flags byte +} + +type mqttIOReader interface { + io.Reader + SetReadDeadline(time.Time) error +} + +type mqttReader struct { + reader mqttIOReader + buf []byte + pos int +} + +type mqttWriter struct { + bytes.Buffer +} + +type mqttWill struct { + topic []byte + message []byte + qos byte + retain bool +} + +type mqttFilter struct { + filter string + qos byte +} + +type mqttPublish struct { + subject []byte + msg []byte + sz int + pi uint16 + flags byte +} + +func (s *Server) startMQTT() { + sopts := s.getOpts() + o := &sopts.MQTT + + var hl net.Listener + var err error + + port := o.Port + if port == -1 { + port = 0 + } + hp := net.JoinHostPort(o.Host, strconv.Itoa(port)) + s.mu.Lock() + if s.shutdown { + s.mu.Unlock() + return + } + s.mqtt.sessmgr.sessions = make(map[string]*mqttAccountSessionManager) + hl, err = net.Listen("tcp", hp) + if err != nil { + s.mu.Unlock() + s.Fatalf("Unable to listen for MQTT connections: %v", err) + return + } + if port == 0 { + o.Port = hl.Addr().(*net.TCPAddr).Port + } + s.mqtt.listener = hl + scheme := "mqtt" + if o.TLSConfig != nil { + scheme = "tls" + } + s.Noticef("Listening for MQTT clients on %s://%s:%d", scheme, o.Host, o.Port) + go s.acceptConnections(hl, "MQTT", func(conn net.Conn) { s.createClient(conn, nil, &mqtt{}) }, nil) + s.mu.Unlock() +} + +// Given the mqtt options, we check if any auth configuration +// has been provided. If so, possibly create users/nkey users and +// store them in s.mqtt.users/nkeys. +// Also update a boolean that indicates if auth is required for +// mqtt clients. +// Server lock is held on entry. +func (s *Server) mqttConfigAuth(opts *MQTTOpts) { + mqtt := &s.mqtt + // If any of those is specified, we consider that there is an override. + mqtt.authOverride = opts.Username != _EMPTY_ || opts.Token != _EMPTY_ || opts.NoAuthUser != _EMPTY_ +} + +// Validate the mqtt related options. +func validateMQTTOptions(o *Options) error { + mo := &o.MQTT + // If no port is defined, we don't care about other options + if mo.Port == 0 { + return nil + } + // If there is a NoAuthUser, we need to have Users defined and + // the user to be present. + if mo.NoAuthUser != _EMPTY_ { + if err := validateNoAuthUser(o, mo.NoAuthUser); err != nil { + return err + } + } + // Token/Username not possible if there are users/nkeys + if len(o.Users) > 0 || len(o.Nkeys) > 0 { + if mo.Username != _EMPTY_ { + return fmt.Errorf("mqtt authentication username not compatible with presence of users/nkeys") + } + if mo.Token != _EMPTY_ { + return fmt.Errorf("mqtt authentication token not compatible with presence of users/nkeys") + } + } + if mo.AckWait < 0 { + return fmt.Errorf("ack wait must be a positive value") + } + return nil +} + +// Parse protocols inside the given buffer. +// This is invoked from the readLoop. +func (c *client) mqttParse(buf []byte) error { + c.mu.Lock() + s := c.srv + trace := c.trace + connected := c.flags.isSet(connectReceived) + mqtt := c.mqtt + r := mqtt.r + var rd time.Duration + if mqtt.cp != nil { + rd = mqtt.cp.rd + if rd > 0 { + r.reader.SetReadDeadline(time.Time{}) + } + } + c.mu.Unlock() + + r.reset(buf) + + var err error + var b byte + var pl int + + for err == nil && r.hasMore() { + + // Read packet type and flags + if b, err = r.readByte("packet type"); err != nil { + break + } + + // Packet type + pt := b & mqttPacketMask + + // If client was not connected yet, the first packet must be + // a mqttPacketConnect otherwise we fail the connection. + if !connected && pt != mqttPacketConnect { + err = errors.New("not connected") + break + } + + if pl, err = r.readPacketLen(); err != nil { + break + } + + switch pt { + case mqttPacketPub: + pp := mqttPublish{flags: b & mqttPacketFlagMask} + err = c.mqttParsePub(r, pl, &pp) + if trace { + c.traceInOp("PUBLISH", errOrTrace(err, mqttPubTrace(&pp))) + if err == nil { + c.traceMsg(pp.msg) + } + } + if err == nil { + s.mqttProcessPub(c, &pp) + if pp.pi > 0 { + c.mqttEnqueuePubAck(pp.pi) + if trace { + c.traceOutOp("PUBACK", []byte(fmt.Sprintf("pi=%v", pp.pi))) + } + } + } + case mqttPacketPubAck: + var pi uint16 + pi, err = mqttParsePubAck(r, pl) + if trace { + c.traceInOp("PUBACK", errOrTrace(err, fmt.Sprintf("pi=%v", pi))) + } + if err == nil { + c.mqttProcessPubAck(pi) + } + case mqttPacketSub: + var pi uint16 // packet identifier + var filters []*mqttFilter + var subs []*subscription + pi, filters, err = c.mqttParseSubs(r, b, pl) + if trace { + c.traceInOp("SUBSCRIBE", errOrTrace(err, mqttSubscribeTrace(filters))) + } + if err == nil { + subs, err = c.mqttProcessSubs(filters) + if err == nil && trace { + c.traceOutOp("SUBACK", []byte(mqttSubscribeTrace(filters))) + } + } + if err == nil { + c.mqttEnqueueSubAck(pi, filters) + c.mqttSendRetainedMsgsToNewSubs(subs) + } + case mqttPacketUnsub: + var pi uint16 // packet identifier + var filters []*mqttFilter + pi, filters, err = c.mqttParseUnsubs(r, b, pl) + if trace { + c.traceInOp("UNSUBSCRIBE", errOrTrace(err, mqttUnsubscribeTrace(filters))) + } + if err == nil { + err = c.mqttProcessUnsubs(filters) + if err == nil && trace { + c.traceOutOp("UNSUBACK", []byte(strconv.FormatInt(int64(pi), 10))) + } + } + if err == nil { + c.mqttEnqueueUnsubAck(pi) + } + case mqttPacketPing: + if trace { + c.traceInOp("PINGREQ", nil) + } + c.mqttEnqueuePingResp() + if trace { + c.traceOutOp("PINGRESP", nil) + } + case mqttPacketConnect: + // It is an error to receive a second connect packet + if connected { + err = errors.New("second connect packet") + break + } + var rc byte + var cp *mqttConnectProto + var sessp bool + rc, cp, err = c.mqttParseConnect(r, pl) + if trace && cp != nil { + c.traceInOp("CONNECT", errOrTrace(err, c.mqttConnectTrace(cp))) + } + if rc != 0 { + c.mqttEnqueueConnAck(rc, sessp) + if trace { + c.traceOutOp("CONNACK", []byte(fmt.Sprintf("sp=%v rc=%v", sessp, rc))) + } + } else if err == nil { + if err = s.mqttProcessConnect(c, cp, trace); err == nil { + connected = true + rd = cp.rd + } + } + case mqttPacketDisconnect: + if trace { + c.traceInOp("DISCONNECT", nil) + } + // Normal disconnect, we need to discard the will. + // Spec [MQTT-3.1.2-8] + c.mu.Lock() + if c.mqtt.cp != nil { + c.mqtt.cp.will = nil + } + c.mu.Unlock() + c.closeConnection(ClientClosed) + return nil + case mqttPacketPubRec: + fallthrough + case mqttPacketPubRel: + fallthrough + case mqttPacketPubComp: + err = fmt.Errorf("protocol %d not supported", pt>>4) + default: + err = fmt.Errorf("received unknown packet type %d", pt>>4) + } + } + if err == nil && rd > 0 { + r.reader.SetReadDeadline(time.Now().Add(rd)) + } + return err +} + +// Update the session (possibly remove it) of this disconnected client. +func (s *Server) mqttHandleClosedClient(c *client) { + c.mu.Lock() + cp := c.mqtt.cp + accName := c.acc.Name + c.mu.Unlock() + if cp == nil { + return + } + sm := &s.mqtt.sessmgr + sm.mu.RLock() + asm, ok := sm.sessions[accName] + sm.mu.RUnlock() + if !ok { + return + } + + asm.mu.Lock() + defer asm.mu.Unlock() + es, ok := asm.sessions[cp.clientID] + // If not found, ignore. + if !ok { + return + } + es.mu.Lock() + defer es.mu.Unlock() + // If this client is not the currently registered client, ignore. + if es.c != c { + return + } + // It the session was created with "clean session" flag, we cleanup + // when the client disconnects. + if es.clean { + es.clear() + delete(asm.sessions, cp.clientID) + } else { + // Clear the client from the session, but session stays. + es.c = nil + } +} + +// Updates the MaxAckPending for all MQTT sessions, updating the +// JetStream consumers and updating their max ack pending and forcing +// a expiration of pending messages. +func (s *Server) mqttUpdateMaxAckPending(newmaxp uint16) { + msm := &s.mqtt.sessmgr + s.accounts.Range(func(k, _ interface{}) bool { + accName := k.(string) + msm.mu.RLock() + asm := msm.sessions[accName] + msm.mu.RUnlock() + if asm == nil { + // Move to next account + return true + } + asm.mu.RLock() + for _, sess := range asm.sessions { + sess.mu.Lock() + sess.maxp = newmaxp + // Do not check for sess.stalled here because due to original + // consumer's maxp it is possible that the consumer was stalled + // and MQTT code did not have to stall the session. + for _, cons := range sess.cons { + cons.mu.Lock() + cons.maxp = int(newmaxp) + cons.forceExpirePending() + cons.mu.Unlock() + } + sess.mu.Unlock() + } + asm.mu.RUnlock() + return true + }) +} + +////////////////////////////////////////////////////////////////////////////// +// +// Sessions Manager related functions +// +////////////////////////////////////////////////////////////////////////////// + +// Returns the MQTT sessions manager for a given account. +// If new, creates the required JetStream streams/consumers +// for handling of sessions and messages. +func (sm *mqttSessionManager) getOrCreateAccountSessionManager(clientID string, c *client) (*mqttAccountSessionManager, error) { + c.mu.Lock() + acc := c.acc + accName := acc.GetName() + c.mu.Unlock() + + sm.mu.RLock() + asm, ok := sm.sessions[accName] + sm.mu.RUnlock() + + if ok { + return asm, nil + } + + // Not found, now take the write lock and check again + sm.mu.Lock() + defer sm.mu.Unlock() + asm, ok = sm.sessions[accName] + if ok { + return asm, nil + } + // First check that we have JS enabled for this account. + // TODO: Since we check only when creating a session manager for this + // account, would probably need to do some cleanup if JS can be disabled + // on config reload. + if !acc.JetStreamEnabled() { + return nil, fmt.Errorf("JetStream must be enabled for account %q used by MQTT client ID %q", + accName, clientID) + } + // Need to create one here. + asm = &mqttAccountSessionManager{sessions: make(map[string]*mqttSession)} + if err := asm.init(acc, c); err != nil { + return nil, err + } + sm.sessions[accName] = asm + return asm, nil +} + +////////////////////////////////////////////////////////////////////////////// +// +// Account Sessions Manager related functions +// +////////////////////////////////////////////////////////////////////////////// + +// Creates JS streams/consumers for handling of sessions and messages for this +// account. Note that lookup are performed in case we are in a restart +// situation and the account is loaded for the first time but had state on disk. +// +// Global session manager lock is held on entry. +func (as *mqttAccountSessionManager) init(acc *Account, c *client) error { + opts := c.srv.getOpts() + var err error + // Start with sessions stream + as.sstream, err = acc.LookupStream(mqttSessionsStreamName) + if err != nil { + as.sstream, err = acc.addStreamWithStore(&StreamConfig{ + Subjects: []string{}, + Name: mqttSessionsStreamName, + Storage: FileStorage, + Retention: InterestPolicy, + }, nil, true) + if err != nil { + return fmt.Errorf("unable to create sessions stream for MQTT account %q: %v", acc.GetName(), err) + } + } + // Create the stream for the messages. + as.mstream, err = acc.LookupStream(mqttStreamName) + if err != nil { + as.mstream, err = acc.addStreamWithStore(&StreamConfig{ + Subjects: []string{}, + Name: mqttStreamName, + Storage: FileStorage, + Retention: InterestPolicy, + }, nil, true) + if err != nil { + return fmt.Errorf("unable to create messages stream for MQTT account %q: %v", acc.GetName(), err) + } + } + // Create the stream for retained messages. + as.rstream, err = acc.LookupStream(mqttRetainedMsgsStreamName) + if err != nil { + as.rstream, err = acc.addStreamWithStore(&StreamConfig{ + Subjects: []string{}, + Name: mqttRetainedMsgsStreamName, + Storage: FileStorage, + Retention: InterestPolicy, + }, nil, true) + if err != nil { + return fmt.Errorf("unable to create retained messages stream for MQTT account %q: %v", acc.GetName(), err) + } + } + // Now recover all sessions (in case it did already exist) + if state := as.sstream.State(); state.Msgs > 0 { + for seq := state.FirstSeq; seq <= state.LastSeq; seq++ { + _, _, content, _, err := as.sstream.store.LoadMsg(seq) + if err != nil { + if err != errDeletedMsg { + c.Errorf("Error loading session record at sequence %v: %v", seq, err) + } + continue + } + ps := &mqttPersistedSession{} + if err := json.Unmarshal(content, ps); err != nil { + c.Errorf("Error unmarshaling session record at sequence %v: %v", seq, err) + continue + } + if as.sessions == nil { + as.sessions = make(map[string]*mqttSession) + } + es, ok := as.sessions[ps.ID] + if ok && es.sseq != 0 { + as.sstream.DeleteMsg(es.sseq) + } else if !ok { + es = mqttSessionCreate(opts) + es.stream = as.sstream + as.sessions[ps.ID] = es + } + es.sseq = seq + es.clean = ps.Clean + es.subs = ps.Subs + if l := len(ps.Cons); l > 0 { + if es.cons == nil { + es.cons = make(map[string]*Consumer, l) + } + for sid, name := range ps.Cons { + if cons := as.mstream.LookupConsumer(name); cons != nil { + es.cons[sid] = cons + } + } + } + } + } + // Finally, recover retained messages. + if state := as.rstream.State(); state.Msgs > 0 { + for seq := state.FirstSeq; seq <= state.LastSeq; seq++ { + subject, _, content, _, err := as.rstream.store.LoadMsg(seq) + if err != nil { + if err != errDeletedMsg { + c.Errorf("Error loading retained message at sequence %v: %v", seq, err) + } + continue + } + rm := &mqttRetainedMsg{} + if err := json.Unmarshal(content, &rm); err != nil { + c.Errorf("Error unmarshaling retained message on subject %q, sequence %v: %v", subject, seq, err) + continue + } + rm = as.handleRetainedMsg(subject, rm) + rm.sseq = seq + } + } + return nil +} + +// Add/Replace this message from the retained messages map. +// Returns the retained message actually stored in the map, which means that +// it may be different from the given `rm`. +// +// Account session manager lock held on entry. +func (as *mqttAccountSessionManager) handleRetainedMsg(key string, rm *mqttRetainedMsg) *mqttRetainedMsg { + if as.retmsgs == nil { + as.retmsgs = make(map[string]*mqttRetainedMsg) + as.sl = NewSublistWithCache() + } else { + // Check if we already had one. If so, update the existing one. + if erm, exists := as.retmsgs[key]; exists { + erm.Msg = rm.Msg + erm.Flags = rm.Flags + erm.Source = rm.Source + return erm + } + } + rm.sub = &subscription{subject: []byte(key)} + as.retmsgs[key] = rm + as.sl.Insert(rm.sub) + return rm +} + +// Process subscriptions for the given session/client. +// +// When `fromSubProto` is false, it means that this is invoked from the CONNECT +// protocol, when restoring subscriptions that were saved for this session. +// In that case, there is no need to update the session record. +// +// When `fromSubProto` is true, it means that this call is invoked from the +// processing of the SUBSCRIBE protocol, which means that the session needs to +// be updated. It also means that if a subscription on same subject with same +// QoS already exist, we should not be recreating the subscription/JS durable, +// since it was already done when processing the CONNECT protocol. +// +// Account session manager lock held on entry. +// Session lock held when `fromSubProto` is false. +func (as *mqttAccountSessionManager) processSubs(sess *mqttSession, clientID string, c *client, + filters []*mqttFilter, fromSubProto, trace bool) ([]*subscription, error) { + + if fromSubProto { + sess.mu.Lock() + defer sess.mu.Unlock() + if sess.c != c { + return nil, fmt.Errorf("client %q no longer registered with MQTT session", clientID) + } + } + + addJSConsToSess := func(sid string, cons *Consumer) { + if cons == nil { + return + } + if sess.cons == nil { + sess.cons = make(map[string]*Consumer) + } + sess.cons[sid] = cons + } + + subs := make([]*subscription, 0, len(filters)) + for _, f := range filters { + if f.qos > 1 { + f.qos = 1 + } + subject := f.filter + sid := subject + + if strings.HasPrefix(subject, mqttSubPrefix) { + f.qos = mqttSubAckFailure + continue + } + + var jscons *Consumer + var jssub *subscription + var err error + + sub := c.mqttCreateSub(subject, sid, mqttDeliverMsgCb, f.qos) + if fromSubProto { + as.serializeRetainedMsgsForSub(sess, c, sub, trace) + } + // Note that if a subscription already exists on this subject, + // the sub is updated with the new qos/prm and the pointer to + // the existing subscription is returned. + sub, err = c.processSub(sub, false) + if err == nil { + // This will create (if not already exist) a JS consumer for subscriptions + // of QoS >= 1. But if a JS consumer already exists and the subscription + // for same subject is now a QoS==0, then the JS consumer will be deleted. + jscons, jssub, err = c.mqttProcessJSConsumer(sess, as.mstream, + subject, sid, f.qos, fromSubProto) + } + if err != nil { + c.Errorf("error subscribing to %q: err=%v", subject, err) + f.qos = mqttSubAckFailure + c.mqttCleanupFailedSub(sub, jscons, jssub) + continue + } + if mqttNeedSubForLevelUp(subject) { + var fwjscons *Consumer + var fwjssub *subscription + + // Say subject is "foo.>", remove the ".>" so that it becomes "foo" + fwcsubject := subject[:len(subject)-2] + // Change the sid to "foo fwc" + fwcsid := fwcsubject + mqttMultiLevelSidSuffix + fwcsub := c.mqttCreateSub(fwcsubject, fwcsid, mqttDeliverMsgCb, f.qos) + if fromSubProto { + as.serializeRetainedMsgsForSub(sess, c, fwcsub, trace) + } + // See note above about existing subscription. + fwcsub, err = c.processSub(fwcsub, false) + if err == nil { + fwjscons, fwjssub, err = c.mqttProcessJSConsumer(sess, as.mstream, + fwcsubject, fwcsid, f.qos, fromSubProto) + } + if err != nil { + c.Errorf("error subscribing to %q: err=%v", fwcsubject, err) + f.qos = mqttSubAckFailure + c.mqttCleanupFailedSub(sub, jscons, jssub) + c.mqttCleanupFailedSub(fwcsub, fwjscons, fwjssub) + continue + } + subs = append(subs, fwcsub) + addJSConsToSess(fwcsid, fwjscons) + } + subs = append(subs, sub) + addJSConsToSess(sid, jscons) + } + var err error + if fromSubProto { + err = sess.update(clientID, filters, true) + } + return subs, err +} + +// Retained publish messages matching this subscription are serialized in the +// subscription's `prm` mqtt writer. This buffer will be queued for outbound +// after the subscription is processed and SUBACK is sent or possibly when +// server processes an incoming published message matching the newly +// registered subscription. +// +// Account session manager lock held on entry. +// Session lock held on entry +func (as *mqttAccountSessionManager) serializeRetainedMsgsForSub(sess *mqttSession, c *client, sub *subscription, trace bool) { + if len(as.retmsgs) > 0 { + var rmsa [64]*mqttRetainedMsg + rms := rmsa[:0] + + as.getRetainedPublishMsgs(string(sub.subject), &rms) + for _, rm := range rms { + if sub.mqtt.prm == nil { + sub.mqtt.prm = &mqttWriter{} + } + prm := sub.mqtt.prm + pi := sess.getPubAckIdentifier(mqttGetQoS(rm.Flags), sub) + // Need to use the subject for the retained message, not the `sub` subject. + // We can find the published retained message in rm.sub.subject. + flags := mqttSerializePublishMsg(prm, pi, false, true, string(rm.sub.subject), rm.Msg[:len(rm.Msg)-LEN_CR_LF]) + if trace { + pp := mqttPublish{ + flags: flags, + pi: pi, + subject: rm.sub.subject, + sz: len(rm.Msg) - LEN_CR_LF, + } + c.traceOutOp("PUBLISH", []byte(mqttPubTrace(&pp))) + } + } + } +} + +// Returns in the provided slice all publish retained message records that +// match the given subscription's `subject` (which could have wildcards). +// +// Account session manager lock held on entry. +func (as *mqttAccountSessionManager) getRetainedPublishMsgs(subject string, rms *[]*mqttRetainedMsg) { + result := as.sl.ReverseMatch(subject) + if len(result.psubs) == 0 { + return + } + for _, sub := range result.psubs { + // Since this is a reverse match, the subscription objects here + // contain literals corresponding to the published subjects. + if rm, ok := as.retmsgs[string(sub.subject)]; ok { + *rms = append(*rms, rm) + } + } +} + +////////////////////////////////////////////////////////////////////////////// +// +// MQTT session related functions +// +////////////////////////////////////////////////////////////////////////////// + +// Returns a new mqttSession object with max ack pending set based on +// option or use mqttDefaultMaxAckPending if no option set. +func mqttSessionCreate(opts *Options) *mqttSession { + maxp := opts.MQTT.MaxAckPending + if maxp == 0 { + maxp = mqttDefaultMaxAckPending + } + return &mqttSession{maxp: maxp} +} + +// Persists a session. Note that if the session's current client does not match +// the given client, nothing is done. +// +// Lock held on entry. +func (sess *mqttSession) save(clientID string) error { + ps := mqttPersistedSession{ + ID: clientID, + Clean: sess.clean, + Subs: sess.subs, + } + if l := len(sess.cons); l > 0 { + cons := make(map[string]string, l) + for sid, jscons := range sess.cons { + cons[sid] = jscons.Name() + } + ps.Cons = cons + } + sessBytes, _ := json.Marshal(&ps) + newSeq, _, err := sess.stream.store.StoreMsg("sessions", nil, sessBytes) + if err != nil { + return err + } + if sess.sseq != 0 { + sess.stream.DeleteMsg(sess.sseq) + } + sess.sseq = newSeq + return nil +} + +// Delete JS consumers for this session and delete the persisted session from +// the stream. +// +// Lock held on entry. +func (sess *mqttSession) clear() { + for consName, cons := range sess.cons { + delete(sess.cons, consName) + cons.Delete() + } + if sess.stream != nil && sess.sseq != 0 { + sess.stream.DeleteMsg(sess.sseq) + sess.sseq = 0 + } + sess.subs, sess.pending, sess.cpending = nil, nil, nil +} + +// This will update the session record for this client in the account's MQTT +// sessions stream if the session had any change in the subscriptions. +// +// Lock held on entry. +func (sess *mqttSession) update(clientID string, filters []*mqttFilter, add bool) error { + // Evaluate if we need to persist anything. + var needUpdate bool + for _, f := range filters { + if add { + if f.qos == mqttSubAckFailure { + continue + } + if qos, ok := sess.subs[f.filter]; !ok || qos != f.qos { + if sess.subs == nil { + sess.subs = make(map[string]byte) + } + sess.subs[f.filter] = f.qos + needUpdate = true + } + } else { + if _, ok := sess.subs[f.filter]; ok { + delete(sess.subs, f.filter) + needUpdate = true + } + } + } + var err error + if needUpdate { + err = sess.save(clientID) + } + return err +} + +// If both pQos and sub.mqtt.qos are > 0, then this will return the next +// packet identifier to use for a published message. +// +// Lock held on entry +func (sess *mqttSession) getPubAckIdentifier(pQos byte, sub *subscription) uint16 { + pi, _ := sess.trackPending(pQos, _EMPTY_, sub) + return pi +} + +// If publish message QoS (pQos) and the subscription's QoS are both at least 1, +// this function will assign a packet identifier (pi) and will keep track of +// the pending ack. If the message has already been redelivered (reply != ""), +// the returned boolean will be `true`. +// +// Lock held on entry +func (sess *mqttSession) trackPending(pQos byte, reply string, sub *subscription) (uint16, bool) { + if pQos == 0 || sub.mqtt.qos == 0 { + return 0, false + } + var dup bool + var pi uint16 + + bumpPI := func() uint16 { + var avail bool + next := sess.ppi + for i := 0; i < 0xFFFF; i++ { + next++ + if next == 0 { + next = 1 + } + if _, used := sess.pending[next]; !used { + sess.ppi = next + avail = true + break + } + } + if !avail { + return 0 + } + return sess.ppi + } + + if reply == _EMPTY_ || sub.mqtt.jsCons == nil { + return bumpPI(), false + } + + // Here, we have an ACK subject and a JS consumer... + jsCons := sub.mqtt.jsCons + if sess.pending == nil { + sess.pending = make(map[uint16]*mqttPending) + sess.cpending = make(map[*Consumer]map[uint64]uint16) + } + // Get the stream sequence and other from the ack reply subject + sseq, dseq, dcount := ackReplyInfo(reply) + + var pending *mqttPending + // For this JS consumer, check to see if we already have sseq->pi + sseqToPi, ok := sess.cpending[jsCons] + if !ok { + sseqToPi = make(map[uint64]uint16) + sess.cpending[jsCons] = sseqToPi + } else if pi, ok = sseqToPi[sseq]; ok { + // If we already have a pi, get the ack so we update it. + // We will reuse the save packet identifier (pi). + pending = sess.pending[pi] + } + if pi == 0 { + // sess.maxp will always have a value > 0. + if len(sess.pending) >= int(sess.maxp) { + // Indicate that we did not assign a packet identifier. + // The caller will not send the message to the subscription + // and JS will redeliver later, based on consumer's AckWait. + sess.stalled = true + return 0, false + } + pi = bumpPI() + sseqToPi[sseq] = pi + } + if pending == nil { + pending = &mqttPending{jsCons: jsCons, sseq: sseq} + sess.pending[pi] = pending + } + // Update pending with consumer delivery sequence and count + pending.dseq, pending.dcount = dseq, dcount + // If redelivery, return DUP flag + if dcount > 1 { + dup = true + } + return pi, dup +} + +////////////////////////////////////////////////////////////////////////////// +// +// CONNECT protocol related functions +// +////////////////////////////////////////////////////////////////////////////// + +// Parse the MQTT connect protocol +func (c *client) mqttParseConnect(r *mqttReader, pl int) (byte, *mqttConnectProto, error) { + + // Make sure that we have the expected length in the buffer, + // and if not, this will read it from the underlying reader. + if err := r.ensurePacketInBuffer(pl); err != nil { + return 0, nil, err + } + + // Protocol name + proto, err := r.readBytes("protocol name", false) + if err != nil { + return 0, nil, err + } + + // Spec [MQTT-3.1.2-1] + if !bytes.Equal(proto, mqttProtoName) { + // Check proto name against v3.1 to report better error + if bytes.Equal(proto, mqttOldProtoName) { + return 0, nil, fmt.Errorf("older protocol %q not supported", proto) + } + return 0, nil, fmt.Errorf("expected connect packet with protocol name %q, got %q", mqttProtoName, proto) + } + + // Protocol level + level, err := r.readByte("protocol level") + if err != nil { + return 0, nil, err + } + // Spec [MQTT-3.1.2-2] + if level != mqttProtoLevel { + return mqttConnAckRCUnacceptableProtocolVersion, nil, fmt.Errorf("unacceptable protocol version of %v", level) + } + + cp := &mqttConnectProto{} + // Connect flags + cp.flags, err = r.readByte("flags") + if err != nil { + return 0, nil, err + } + + // Spec [MQTT-3.1.2-3] + if cp.flags&mqttConnFlagReserved != 0 { + return 0, nil, fmt.Errorf("connect flags reserved bit not set to 0") + } + + var hasWill bool + wqos := (cp.flags & mqttConnFlagWillQoS) >> 3 + wretain := cp.flags&mqttConnFlagWillRetain != 0 + // Spec [MQTT-3.1.2-11] + if cp.flags&mqttConnFlagWillFlag == 0 { + // Spec [MQTT-3.1.2-13] + if wqos != 0 { + return 0, nil, fmt.Errorf("if Will flag is set to 0, Will QoS must be 0 too, got %v", wqos) + } + // Spec [MQTT-3.1.2-15] + if wretain { + return 0, nil, fmt.Errorf("if Will flag is set to 0, Will Retain flag must be 0 too") + } + } else { + // Spec [MQTT-3.1.2-14] + if wqos == 3 { + return 0, nil, fmt.Errorf("if Will flag is set to 1, Will QoS can be 0, 1 or 2, got %v", wqos) + } + hasWill = true + } + + // Spec [MQTT-3.1.2-19] + hasUser := cp.flags&mqttConnFlagUsernameFlag != 0 + // Spec [MQTT-3.1.2-21] + hasPassword := cp.flags&mqttConnFlagPasswordFlag != 0 + // Spec [MQTT-3.1.2-22] + if !hasUser && hasPassword { + return 0, nil, fmt.Errorf("password flag set but username flag is not") + } + + // Keep alive + var ka uint16 + ka, err = r.readUint16("keep alive") + if err != nil { + return 0, nil, err + } + // Spec [MQTT-3.1.2-24] + if ka > 0 { + cp.rd = time.Duration(float64(ka)*1.5) * time.Second + } + + // Payload starts here and order is mandated by: + // Spec [MQTT-3.1.3-1]: client ID, will topic, will message, username, password + + // Client ID + cp.clientID, err = r.readString("client ID") + if err != nil { + return 0, nil, err + } + // Spec [MQTT-3.1.3-7] + if cp.clientID == _EMPTY_ { + if cp.flags&mqttConnFlagCleanSession == 0 { + return mqttConnAckRCIdentifierRejected, nil, fmt.Errorf("when client ID is empty, clean session flag must be set to 1") + } + // Spec [MQTT-3.1.3-6] + cp.clientID = nuid.Next() + } + // Spec [MQTT-3.1.3-4] and [MQTT-3.1.3-9] + if !utf8.ValidString(cp.clientID) { + return mqttConnAckRCIdentifierRejected, nil, fmt.Errorf("invalid utf8 for client ID: %q", cp.clientID) + } + + if hasWill { + cp.will = &mqttWill{ + qos: wqos, + retain: wretain, + } + var topic []byte + topic, err = r.readBytes("Will topic", false) + if err != nil { + return 0, nil, err + } + if len(topic) == 0 { + return 0, nil, fmt.Errorf("empty Will topic not allowed") + } + if !utf8.Valid(topic) { + return 0, nil, fmt.Errorf("invalide utf8 for Will topic %q", topic) + } + // Convert MQTT topic to NATS subject + var copied bool + copied, topic, err = mqttTopicToNATSPubSubject(topic) + if err != nil { + return 0, nil, err + } + if !copied { + topic = copyBytes(topic) + } + cp.will.topic = topic + // Now will message + var msg []byte + msg, err = r.readBytes("Will message", false) + if err != nil { + return 0, nil, err + } + cp.will.message = make([]byte, 0, len(msg)+2) + cp.will.message = append(cp.will.message, msg...) + cp.will.message = append(cp.will.message, CR_LF...) + } + + if hasUser { + c.opts.Username, err = r.readString("user name") + if err != nil { + return 0, nil, err + } + if c.opts.Username == _EMPTY_ { + return mqttConnAckRCBadUserOrPassword, nil, fmt.Errorf("empty user name not allowed") + } + // Spec [MQTT-3.1.3-11] + if !utf8.ValidString(c.opts.Username) { + return mqttConnAckRCBadUserOrPassword, nil, fmt.Errorf("invalid utf8 for user name %q", c.opts.Username) + } + } + + if hasPassword { + c.opts.Password, err = r.readString("password") + if err != nil { + return 0, nil, err + } + c.opts.Token = c.opts.Password + } + return 0, cp, nil +} + +func (c *client) mqttConnectTrace(cp *mqttConnectProto) string { + trace := fmt.Sprintf("clientID=%s", cp.clientID) + if cp.rd > 0 { + trace += fmt.Sprintf(" keepAlive=%v", cp.rd) + } + if cp.will != nil { + trace += fmt.Sprintf(" will=(topic=%s QoS=%v retain=%v)", + cp.will.topic, cp.will.qos, cp.will.retain) + } + if c.opts.Username != _EMPTY_ { + trace += fmt.Sprintf(" username=%s", c.opts.Username) + } + if c.opts.Password != _EMPTY_ { + trace += " password=****" + } + return trace +} + +func (s *Server) mqttProcessConnect(c *client, cp *mqttConnectProto, trace bool) error { + sendConnAck := func(rc byte, sessp bool) { + c.mqttEnqueueConnAck(rc, sessp) + if trace { + c.traceOutOp("CONNACK", []byte(fmt.Sprintf("sp=%v rc=%v", sessp, rc))) + } + } + + c.mu.Lock() + c.clearAuthTimer() + c.mu.Unlock() + if !s.isClientAuthorized(c) { + sendConnAck(mqttConnAckRCNotAuthorized, false) + c.closeConnection(AuthenticationViolation) + return ErrAuthentication + } + // Now that we are are authenticated, we have the client bound to the account. + // Get the account's level MQTT sessions manager. If it does not exists yet, + // this will create it along with the streams where sessions and messages + // are stored. + sm := &s.mqtt.sessmgr + asm, err := sm.getOrCreateAccountSessionManager(cp.clientID, c) + if err != nil { + return err + } + + // Rest of code runs under the account's sessions manager write lock. + asm.mu.Lock() + defer asm.mu.Unlock() + + // Is the client requesting a clean session or not. + cleanSess := cp.flags&mqttConnFlagCleanSession != 0 + // Session present? Assume false, will be set to true only when applicable. + sessp := false + // Do we have an existing session for this client ID + es, ok := asm.sessions[cp.clientID] + if ok { + es.mu.Lock() + defer es.mu.Unlock() + // Clear the session if client wants a clean session. + // Also, Spec [MQTT-3.2.2-1]: don't report session present + if cleanSess || es.clean { + // Spec [MQTT-3.1.2-6]: If CleanSession is set to 1, the Client and + // Server MUST discard any previous Session and start a new one. + // This Session lasts as long as the Network Connection. State data + // associated with this Session MUST NOT be reused in any subsequent + // Session. + es.clear() + } else { + // Report to the client that the session was present + sessp = true + } + ec := es.c + // Is there an actual client associated with this session. + if ec != nil { + // Spec [MQTT-3.1.4-2]. If the ClientId represents a Client already + // connected to the Server then the Server MUST disconnect the existing + // client. + ec := es.c + ec.mu.Lock() + // Remove will before closing + ec.mqtt.cp.will = nil + ec.mu.Unlock() + // Close old client in separate go routine + go ec.closeConnection(DuplicateClientID) + } + // Bind with the new client + es.c = c + es.clean = cleanSess + } else { + // Spec [MQTT-3.2.2-3]: if the Server does not have stored Session state, + // it MUST set Session Present to 0 in the CONNACK packet. + es = mqttSessionCreate(s.getOpts()) + es.c, es.clean, es.stream = c, cleanSess, asm.sstream + es.mu.Lock() + defer es.mu.Unlock() + asm.sessions[cp.clientID] = es + es.save(cp.clientID) + } + c.mu.Lock() + c.flags.set(connectReceived) + c.mqtt.cp = cp + c.mqtt.asm = asm + c.mqtt.sess = es + c.mu.Unlock() + // Spec [MQTT-3.2.0-1]: At this point we need to send the CONNACK before + // restoring subscriptions, because CONNACK must be the first packet sent + // to the client. + sendConnAck(mqttConnAckRCConnectionAccepted, sessp) + // Now process possible saved subscriptions. + if l := len(es.subs); l > 0 { + filters := make([]*mqttFilter, 0, l) + for subject, qos := range es.subs { + filters = append(filters, &mqttFilter{filter: subject, qos: qos}) + } + if _, err := asm.processSubs(es, cp.clientID, c, filters, false, trace); err != nil { + return err + } + } + return nil +} + +func (c *client) mqttEnqueueConnAck(rc byte, sessionPresent bool) { + proto := [4]byte{mqttPacketConnectAck, 2, 0, rc} + c.mu.Lock() + // Spec [MQTT-3.2.2-4]. If return code is different from 0, then + // session present flag must be set to 0. + if rc == 0 { + if sessionPresent { + proto[2] = 1 + } + } + c.enqueueProto(proto[:]) + c.mu.Unlock() +} + +func (s *Server) mqttHandleWill(c *client) { + c.mu.Lock() + if c.mqtt.cp == nil { + c.mu.Unlock() + return + } + will := c.mqtt.cp.will + if will == nil { + c.mu.Unlock() + return + } + pp := &mqttPublish{ + subject: will.topic, + msg: will.message, + sz: len(will.message) - LEN_CR_LF, + flags: will.qos << 1, + } + if will.retain { + pp.flags |= mqttPubFlagRetain + } + c.mu.Unlock() + s.mqttProcessPub(c, pp) + c.flushClients(0) +} + +////////////////////////////////////////////////////////////////////////////// +// +// PUBLISH protocol related functions +// +////////////////////////////////////////////////////////////////////////////// + +func (c *client) mqttParsePub(r *mqttReader, pl int, pp *mqttPublish) error { + qos := (pp.flags & mqttPubFlagQoS) >> 1 + if qos > 1 { + return fmt.Errorf("publish QoS=%v not supported", qos) + } + if err := r.ensurePacketInBuffer(pl); err != nil { + return err + } + // Keep track of where we are when starting to read the variable header + start := r.pos + + var err error + pp.subject, err = r.readBytes("topic", false) + if err != nil { + return err + } + if len(pp.subject) == 0 { + return fmt.Errorf("topic cannot be empty") + } + // Convert the topic to a NATS subject. This call will also check that + // there is no MQTT wildcards (Spec [MQTT-3.3.2-2] and [MQTT-4.7.1-1]) + // Note that this may not result in a copy if there is no special + // conversion. It is good because after the message is processed we + // won't have a reference to the buffer and we save a copy. + _, pp.subject, err = mqttTopicToNATSPubSubject(pp.subject) + if err != nil { + return err + } + + if qos > 0 { + pp.pi, err = r.readUint16("packet identifier") + if err != nil { + return err + } + if pp.pi == 0 { + return fmt.Errorf("with QoS=%v, packet identifier cannot be 0", qos) + } + } + + // The message payload will be the total packet length minus + // what we have consumed for the variable header + pp.sz = pl - (r.pos - start) + pp.msg = make([]byte, 0, pp.sz+2) + if pp.sz > 0 { + start = r.pos + r.pos += pp.sz + pp.msg = append(pp.msg, r.buf[start:r.pos]...) + } + pp.msg = append(pp.msg, _CRLF_...) + return nil +} + +func mqttPubTrace(pp *mqttPublish) string { + dup := pp.flags&mqttPubFlagDup != 0 + qos := mqttGetQoS(pp.flags) + retain := mqttIsRetained(pp.flags) + var piStr string + if pp.pi > 0 { + piStr = fmt.Sprintf(" pi=%v", pp.pi) + } + return fmt.Sprintf("%s dup=%v QoS=%v retain=%v size=%v%s", + pp.subject, dup, qos, retain, pp.sz, piStr) +} + +func (s *Server) mqttProcessPub(c *client, pp *mqttPublish) { + c.mqtt.pp = pp + c.pa.subject, c.pa.hdr, c.pa.size, c.pa.szb = pp.subject, -1, pp.sz, []byte(strconv.FormatInt(int64(pp.sz), 10)) + // This will work for QoS 0 but mqtt msg delivery callback will ignore + // delivery for QoS > 0 published messages (since it is handled specifically + // with call to directProcessInboundJetStreamMsg). + // However, this needs to be invoked before directProcessInboundJetStreamMsg() + // in case we are dealing with publish retained messages. + c.processInboundClientMsg(pp.msg) + if mqttGetQoS(pp.flags) > 0 { + // Since this is the fast path, we access the messages stream directly here + // without locking. All the fields mqtt.asm.mstream are immutable. + c.mqtt.asm.mstream.processInboundJetStreamMsg(nil, c, string(c.pa.subject), "", pp.msg[:len(pp.msg)-LEN_CR_LF]) + } + c.pa.subject, c.pa.hdr, c.pa.size, c.pa.szb = nil, -1, 0, nil + c.mqtt.pp = nil +} + +// Invoked when processing an inbound client message. If the "retain" flag is +// set, the message is stored so it can be later resent to (re)starting +// subscriptions that match the subject. +// +// Invoked from the publisher's readLoop. No client lock is held on entry. +func (c *client) mqttHandlePubRetain() { + pp := c.mqtt.pp + if mqttIsRetained(pp.flags) { + key := string(pp.subject) + asm := c.mqtt.asm + asm.mu.Lock() + // Spec [MQTT-3.3.1-11]. Payload of size 0 removes the retained message, + // but should still be delivered as a normal message. + if pp.sz == 0 { + if asm.retmsgs != nil { + if erm, ok := asm.retmsgs[key]; ok { + delete(asm.retmsgs, key) + asm.sl.Remove(erm.sub) + if erm.sseq != 0 { + asm.rstream.DeleteMsg(erm.sseq) + } + } + } + } else { + // Spec [MQTT-3.3.1-5]. Store the retained message with its QoS. + // When coming from a publish protocol, `pp` is referencing a stack + // variable that itself possibly references the read buffer. + rm := &mqttRetainedMsg{ + Msg: copyBytes(pp.msg), + Flags: pp.flags, + Source: c.opts.Username, + } + rm = asm.handleRetainedMsg(key, rm) + rmBytes, _ := json.Marshal(rm) + // TODO: For now we will report the error but continue... + seq, _, err := asm.rstream.store.StoreMsg(key, nil, rmBytes) + if err != nil { + c.mu.Lock() + acc := c.acc + c.mu.Unlock() + c.Errorf("unable to store retained message for account %q, subject %q: %v", + acc.GetName(), key, err) + } + // If it has been replaced, rm.sseq will be != 0 + if rm.sseq != 0 { + asm.rstream.DeleteMsg(rm.sseq) + } + // Keep track of current stream sequence (possibly 0 if failed to store) + rm.sseq = seq + } + + asm.mu.Unlock() + + // Clear the retain flag for a normal published message. + pp.flags &= ^mqttPubFlagRetain + } +} + +// After a config reload, it is possible that the source of a publish retained +// message is no longer allowed to publish on the given topic. If that is the +// case, the retained message is removed from the map and will no longer be +// sent to (re)starting subscriptions. +// +// Server lock is held on entry +func (s *Server) mqttCheckPubRetainedPerms() { + sm := &s.mqtt.sessmgr + sm.mu.RLock() + defer sm.mu.RUnlock() + + for _, asm := range sm.sessions { + perms := map[string]*perm{} + asm.mu.Lock() + for subject, rm := range asm.retmsgs { + if rm.Source == _EMPTY_ { + continue + } + // Lookup source from global users. + u := s.users[rm.Source] + if u != nil { + p, ok := perms[rm.Source] + if !ok { + p = generatePubPerms(u.Permissions) + perms[rm.Source] = p + } + // If there is permission and no longer allowed to publish in + // the subject, remove the publish retained message from the map. + if p != nil && !pubAllowed(p, subject) { + u = nil + } + } + + // Not present or permissions have changed such that the source can't + // publish on that subject anymore: remove it from the map. + if u == nil { + delete(asm.retmsgs, subject) + asm.rstream.DeleteMsg(rm.sseq) + asm.sl.Remove(rm.sub) + } + } + asm.mu.Unlock() + } +} + +// Helper to generate only pub permissions from a Permissions object +func generatePubPerms(perms *Permissions) *perm { + var p *perm + if perms.Publish.Allow != nil { + p = &perm{} + p.allow = NewSublistWithCache() + } + for _, pubSubject := range perms.Publish.Allow { + sub := &subscription{subject: []byte(pubSubject)} + p.allow.Insert(sub) + } + if len(perms.Publish.Deny) > 0 { + if p == nil { + p = &perm{} + } + p.deny = NewSublistWithCache() + } + for _, pubSubject := range perms.Publish.Deny { + sub := &subscription{subject: []byte(pubSubject)} + p.deny.Insert(sub) + } + return p +} + +// Helper that checks if given `perms` allow to publish on the given `subject` +func pubAllowed(perms *perm, subject string) bool { + allowed := true + if perms.allow != nil { + r := perms.allow.Match(subject) + allowed = len(r.psubs) != 0 + } + // If we have a deny list and are currently allowed, check that as well. + if allowed && perms.deny != nil { + r := perms.deny.Match(subject) + allowed = len(r.psubs) == 0 + } + return allowed +} + +func mqttWritePublish(w *mqttWriter, qos byte, dup, retain bool, subject string, pi uint16, payload []byte) { + flags := qos << 1 + if dup { + flags |= mqttPubFlagDup + } + if retain { + flags |= mqttPubFlagRetain + } + w.WriteByte(mqttPacketPub | flags) + pkLen := 2 + len(subject) + len(payload) + if qos > 0 { + pkLen += 2 + } + w.WriteVarInt(pkLen) + w.WriteString(subject) + if qos > 0 { + w.WriteUint16(pi) + } + w.Write([]byte(payload)) +} + +func (c *client) mqttEnqueuePubAck(pi uint16) { + proto := [4]byte{mqttPacketPubAck, 0x2, 0, 0} + proto[2] = byte(pi >> 8) + proto[3] = byte(pi) + c.mu.Lock() + c.enqueueProto(proto[:4]) + c.mu.Unlock() +} + +func mqttParsePubAck(r *mqttReader, pl int) (uint16, error) { + if err := r.ensurePacketInBuffer(pl); err != nil { + return 0, err + } + pi, err := r.readUint16("packet identifier") + if err != nil { + return 0, err + } + if pi == 0 { + return 0, fmt.Errorf("packet identifier cannot be 0") + } + return pi, nil +} + +func (c *client) mqttProcessPubAck(pi uint16) { + sess := c.mqtt.sess + if sess == nil { + return + } + sess.mu.Lock() + defer sess.mu.Unlock() + if sess.c != c { + return + } + if ack, ok := sess.pending[pi]; ok { + delete(sess.pending, pi) + jsCons := ack.jsCons + if sseqToPi, ok := sess.cpending[jsCons]; ok { + delete(sseqToPi, ack.sseq) + } + jsCons.ackMsg(ack.sseq, ack.dseq, ack.dcount) + if len(sess.pending) == 0 { + sess.ppi = 0 + } + if sess.stalled && len(sess.pending) < int(sess.maxp) { + sess.stalled = false + for _, cons := range sess.cons { + cons.mu.Lock() + cons.forceExpirePending() + cons.mu.Unlock() + } + } + } +} + +// Return the QoS from the given PUBLISH protocol's flags +func mqttGetQoS(flags byte) byte { + return flags & mqttPubFlagQoS >> 1 +} + +func mqttIsRetained(flags byte) bool { + return flags&mqttPubFlagRetain != 0 +} + +////////////////////////////////////////////////////////////////////////////// +// +// SUBSCRIBE related functions +// +////////////////////////////////////////////////////////////////////////////// + +func (c *client) mqttParseSubs(r *mqttReader, b byte, pl int) (uint16, []*mqttFilter, error) { + return c.mqttParseSubsOrUnsubs(r, b, pl, true) +} + +func (c *client) mqttParseSubsOrUnsubs(r *mqttReader, b byte, pl int, sub bool) (uint16, []*mqttFilter, error) { + var expectedFlag byte + var action string + if sub { + expectedFlag = mqttSubscribeFlags + } else { + expectedFlag = mqttUnsubscribeFlags + action = "un" + } + // Spec [MQTT-3.8.1-1], [MQTT-3.10.1-1] + if rf := b & 0xf; rf != expectedFlag { + return 0, nil, fmt.Errorf("wrong %ssubscribe reserved flags: %x", action, rf) + } + if err := r.ensurePacketInBuffer(pl); err != nil { + return 0, nil, err + } + pi, err := r.readUint16("packet identifier") + if err != nil { + return 0, nil, fmt.Errorf("reading packet identifier: %v", err) + } + end := r.pos + (pl - 2) + var filters []*mqttFilter + for r.pos < end { + // Don't make a copy now because, this will happen during conversion + // or when processing the sub. + filter, err := r.readBytes("topic filter", false) + if err != nil { + return 0, nil, err + } + if len(filter) == 0 { + return 0, nil, errors.New("topic filter cannot be empty") + } + // Spec [MQTT-3.8.3-1], [MQTT-3.10.3-1] + if !utf8.Valid(filter) { + return 0, nil, fmt.Errorf("invalid utf8 for topic filter %q", filter) + } + var qos byte + // This won't return an error. We will find out if the subject + // is valid or not when trying to create the subscription. + _, filter, _ = mqttFilterToNATSSubject(filter) + if sub { + qos, err = r.readByte("QoS") + if err != nil { + return 0, nil, err + } + // Spec [MQTT-3-8.3-4]. + if qos > 2 { + return 0, nil, fmt.Errorf("subscribe QoS value must be 0, 1 or 2, got %v", qos) + } + } + filters = append(filters, &mqttFilter{string(filter), qos}) + } + // Spec [MQTT-3.8.3-3], [MQTT-3.10.3-2] + if len(filters) == 0 { + return 0, nil, fmt.Errorf("%ssubscribe protocol must contain at least 1 topic filter", action) + } + return pi, filters, nil +} + +func mqttSubscribeTrace(filters []*mqttFilter) string { + var sep string + trace := "[" + for i, f := range filters { + trace += sep + fmt.Sprintf("%s QoS=%v", f.filter, f.qos) + if i == 0 { + sep = ", " + } + } + trace += "]" + return trace +} + +func mqttDeliverMsgCb(sub *subscription, pc *client, subject, reply string, msg []byte) { + if sub.mqtt == nil { + return + } + + var ppFlags byte + var pQoS byte + var pi uint16 + var dup bool + var retained bool + + // This is the client associated with the subscription. + cc := sub.client + + // This is immutable + sess := cc.mqtt.sess + // We lock to check some of the subscription's fields and if we need to + // keep track of pending acks, etc.. + sess.mu.Lock() + if sess.c != cc { + sess.mu.Unlock() + return + } + + // Check the publisher's kind. If JETSTREAM it means that this is a persisted message + // that is being delivered. + if pc.kind == JETSTREAM { + // If there is no JS consumer attached to this subscription, it means that we are + // dealing with a bare NATS subscription, in which case we simply return to avoid + // duplicate delivery. + if sub.mqtt.jsCons == nil { + sess.mu.Unlock() + return + } + ppFlags = mqttPubQos1 + pQoS = 1 + // This is a QoS1 message for a QoS1 subscription, so get the pi and keep + // track of ack subject. + pi, dup = sess.trackPending(pQoS, reply, sub) + if pi == 0 { + // We have reached max pending, don't send the message now. + // JS will cause a redelivery and if by then the number of pending + // messages has fallen below threshold, the message will be resent. + sess.mu.Unlock() + return + } + // In JS case, we need to use the pc.ca.deliver value as the subject. + subject = string(pc.pa.deliver) + } else if pc.mqtt != nil { + // This is a MQTT publisher... + ppFlags = pc.mqtt.pp.flags + pQoS = mqttGetQoS(ppFlags) + // If the QoS of published message and subscription is 1, then we return here to + // avoid duplicate delivery. The JetStream publisher will handle that case. + if pQoS > 0 && sub.mqtt.qos > 0 { + sess.mu.Unlock() + return + } + retained = mqttIsRetained(ppFlags) + } + // else this is coming from a non MQTT publisher, so Qos 0, no dup nor retain flag, etc.. + sess.mu.Unlock() + + sw := mqttWriter{} + w := &sw + + flags := mqttSerializePublishMsg(w, pi, dup, retained, subject, msg) + + cc.mu.Lock() + if sub.mqtt.prm != nil { + cc.queueOutbound(sub.mqtt.prm.Bytes()) + sub.mqtt.prm = nil + } + cc.queueOutbound(w.Bytes()) + pc.addToPCD(cc) + if cc.trace { + pp := mqttPublish{ + flags: flags, + pi: pi, + subject: []byte(subject), + sz: len(msg), + } + cc.traceOutOp("PUBLISH", []byte(mqttPubTrace(&pp))) + } + cc.mu.Unlock() +} + +// Serializes to the given writer the message for the given subject. +func mqttSerializePublishMsg(w *mqttWriter, pi uint16, dup, retained bool, subject string, msg []byte) byte { + topic := natsSubjectToMQTTTopic(subject) + + // Compute len (will have to add packet id if message is sent as QoS>=1) + pkLen := 2 + len(topic) + len(msg) + + var flags byte + + // Set flags for dup/retained/qos1 + if dup { + flags |= mqttPubFlagDup + } + if retained { + flags |= mqttPubFlagRetain + } + // For now, we have only QoS 1 + if pi > 0 { + pkLen += 2 + flags |= mqttPubQos1 + } + + w.WriteByte(mqttPacketPub | flags) + w.WriteVarInt(pkLen) + w.WriteBytes(topic) + if pi > 0 { + w.WriteUint16(pi) + } + w.Write(msg) + + return flags +} + +// Helper to create an MQTT subscription. +func (c *client) mqttCreateSub(subject, sid string, cb msgHandler, qos byte) *subscription { + sub := c.createSub([]byte(subject), nil, []byte(sid), cb) + sub.mqtt = &mqttSub{qos: qos} + return sub +} + +// Process the list of subscriptions and update the given filter +// with the QoS that has been accepted (or failure). +// +// Spec [MQTT-3.8.4-3] says that if an exact same subscription is +// found, it needs to be replaced with the new one (possibly updating +// the qos) and that the flow of publications must not be interrupted, +// which I read as the replacement cannot be a "remove then add" if there +// is a chance that in between the 2 actions, published messages +// would be "lost" because there would not be any matching subscription. +func (c *client) mqttProcessSubs(filters []*mqttFilter) ([]*subscription, error) { + // Those things are immutable, but since processing subs is not + // really in the fast path, let's get them under the client lock. + c.mu.Lock() + asm := c.mqtt.asm + sess := c.mqtt.sess + clientID := c.mqtt.cp.clientID + trace := c.trace + c.mu.Unlock() + + asm.mu.RLock() + defer asm.mu.RUnlock() + return asm.processSubs(sess, clientID, c, filters, true, trace) +} + +func (c *client) mqttCleanupFailedSub(sub *subscription, jscons *Consumer, jssub *subscription) { + c.mu.Lock() + acc := c.acc + c.mu.Unlock() + + if sub != nil { + c.unsubscribe(acc, sub, true, true) + } + if jssub != nil { + c.unsubscribe(acc, jssub, true, true) + } + if jscons != nil { + jscons.Delete() + } +} + +// When invoked with a QoS of 0, looks for an existing JS durable consumer for +// the given sid and if one is found, delete the JS durable consumer and unsub +// the NATS subscription on the delivery subject. +// With a QoS > 0, creates or update the existing JS durable consumer along with +// its NATS subscription on a delivery subject. +// +// Account session manager lock held on entry. +func (c *client) mqttProcessJSConsumer(sess *mqttSession, stream *Stream, subject, + sid string, qos byte, fromSubProto bool) (*Consumer, *subscription, error) { + + // Check if we are already a JS consumer for this SID. + cons, exists := sess.cons[sid] + if exists { + // If current QoS is 0, it means that we need to delete the existing + // one (that was QoS > 0) + if qos == 0 { + // The JS durable consumer's delivery subject is on a NUID of + // the form: mqttSubPrefix + . It is also used as the sid + // for the NATS subscription, so use that for the lookup. + sub := c.subs[cons.Config().DeliverSubject] + delete(sess.cons, sid) + cons.Delete() + if sub != nil { + c.mu.Lock() + acc := c.acc + c.mu.Unlock() + c.unsubscribe(acc, sub, true, true) + } + return nil, nil, nil + } + // If this is called when processing SUBSCRIBE protocol, then if + // the JS consumer already exists, we are done (it was created + // during the processing of CONNECT). + if fromSubProto { + return nil, nil, nil + } + } + // Here it means we don't have a JS consumer and if we are QoS 0, + // we have nothing to do. + if qos == 0 { + return nil, nil, nil + } + var err error + inbox := mqttSubPrefix + nuid.Next() + if exists { + cons.updateDeliverSubject(inbox) + } else { + durName := nuid.Next() + opts := c.srv.getOpts() + ackWait := opts.MQTT.AckWait + if ackWait == 0 { + ackWait = mqttDefaultAckWait + } + maxAckPending := opts.MQTT.MaxAckPending + if maxAckPending == 0 { + maxAckPending = mqttDefaultMaxAckPending + } + cc := &ConsumerConfig{ + DeliverSubject: inbox, + Durable: durName, + AckPolicy: AckExplicit, + DeliverPolicy: DeliverNew, + FilterSubject: subject, + AckWait: ackWait, + MaxAckPending: int(maxAckPending), + } + cons, err = stream.addConsumerCheckInterest(cc, false) + if err != nil { + c.Errorf("Unable to add JetStream consumer for subscription on %q: err=%v", subject, err) + return nil, nil, err + } + } + sub := c.mqttCreateSub(inbox, inbox, mqttDeliverMsgCb, qos) + sub.mqtt.jsCons = cons + // This is an internal subscription on subject like "$MQTT.sub." that is setup + // for the JS durable's deliver subject. I don't think that there is any need to + // forward this subscription in the cluster/super cluster. + sub, err = c.processSub(sub, true) + if err != nil { + if !exists { + cons.Delete() + } + c.Errorf("Unable to create subscription for JetStream consumer on %q: %v", subject, err) + return nil, nil, err + } + return cons, sub, nil +} + +// Queues the published retained messages for each subscription and signals +// the writeLoop. +func (c *client) mqttSendRetainedMsgsToNewSubs(subs []*subscription) { + c.mu.Lock() + for _, sub := range subs { + if sub.mqtt != nil && sub.mqtt.prm != nil { + c.queueOutbound(sub.mqtt.prm.Bytes()) + sub.mqtt.prm = nil + } + } + c.flushSignal() + c.mu.Unlock() +} + +func (c *client) mqttEnqueueSubAck(pi uint16, filters []*mqttFilter) { + w := &mqttWriter{} + w.WriteByte(mqttPacketSubAck) + // packet length is 2 (for packet identifier) and 1 byte per filter. + w.WriteVarInt(2 + len(filters)) + w.WriteUint16(pi) + for _, f := range filters { + w.WriteByte(f.qos) + } + c.mu.Lock() + c.enqueueProto(w.Bytes()) + c.mu.Unlock() +} + +////////////////////////////////////////////////////////////////////////////// +// +// UNSUBSCRIBE related functions +// +////////////////////////////////////////////////////////////////////////////// + +func (c *client) mqttParseUnsubs(r *mqttReader, b byte, pl int) (uint16, []*mqttFilter, error) { + return c.mqttParseSubsOrUnsubs(r, b, pl, false) +} + +func (c *client) mqttProcessUnsubs(filters []*mqttFilter) error { + // Those things are immutable, but since processing unsubs is not + // really in the fast path, let's get them under the client lock. + c.mu.Lock() + sess := c.mqtt.sess + clientID := c.mqtt.cp.clientID + c.mu.Unlock() + + sess.mu.Lock() + defer sess.mu.Unlock() + if sess.c != c { + return fmt.Errorf("client %q no longer registered with MQTT session", clientID) + } + + removeJSCons := func(sid string) { + if jscons, ok := sess.cons[sid]; ok { + delete(sess.cons, sid) + jscons.Delete() + if seqPis, ok := sess.cpending[jscons]; ok { + delete(sess.cpending, jscons) + for _, pi := range seqPis { + delete(sess.pending, pi) + } + if len(sess.pending) == 0 { + sess.ppi = 0 + } + } + } + } + for _, f := range filters { + sid := f.filter + // Remove JS Consumer if one exists for this sid + removeJSCons(sid) + if err := c.processUnsub([]byte(sid)); err != nil { + c.Errorf("error unsubscribing from %q: %v", sid, err) + } + if mqttNeedSubForLevelUp(sid) { + subject := sid[:len(sid)-2] + sid = subject + mqttMultiLevelSidSuffix + removeJSCons(sid) + if err := c.processUnsub([]byte(sid)); err != nil { + c.Errorf("error unsubscribing from %q: %v", subject, err) + } + } + } + return sess.update(clientID, filters, false) +} + +func (c *client) mqttEnqueueUnsubAck(pi uint16) { + w := &mqttWriter{} + w.WriteByte(mqttPacketUnsubAck) + w.WriteVarInt(2) + w.WriteUint16(pi) + c.mu.Lock() + c.enqueueProto(w.Bytes()) + c.mu.Unlock() +} + +func mqttUnsubscribeTrace(filters []*mqttFilter) string { + var sep string + trace := "[" + for i, f := range filters { + trace += sep + f.filter + if i == 0 { + sep = ", " + } + } + trace += "]" + return trace +} + +////////////////////////////////////////////////////////////////////////////// +// +// PINGREQ/PINGRESP related functions +// +////////////////////////////////////////////////////////////////////////////// + +func (c *client) mqttEnqueuePingResp() { + c.mu.Lock() + c.enqueueProto(mqttPingResponse) + c.mu.Unlock() +} + +////////////////////////////////////////////////////////////////////////////// +// +// Trace functions +// +////////////////////////////////////////////////////////////////////////////// + +func errOrTrace(err error, trace string) []byte { + if err != nil { + return []byte(err.Error()) + } + return []byte(trace) +} + +////////////////////////////////////////////////////////////////////////////// +// +// Subject/Topic conversion functions +// +////////////////////////////////////////////////////////////////////////////// + +// Converts an MQTT Topic Name to a NATS Subject (used by PUBLISH) +// See mqttToNATSSubjectConversion() for details. +func mqttTopicToNATSPubSubject(mt []byte) (bool, []byte, error) { + return mqttToNATSSubjectConversion(mt, false) +} + +// Converts an MQTT Topic Filter to a NATS Subject (used by SUBSCRIBE) +// See mqttToNATSSubjectConversion() for details. +func mqttFilterToNATSSubject(filter []byte) (bool, []byte, error) { + return mqttToNATSSubjectConversion(filter, true) +} + +// Converts an MQTT Topic Name or Filter to a NATS Subject +// In MQTT: +// - a Topic Name does not have wildcard (PUBLISH uses only topic names). +// - a Topic Filter can include wildcards (SUBSCRIBE uses those). +// - '+' and '#' are wildcard characters (single and multiple levels respectively) +// - '/' is the topic level separator. +// +// Conversion that occurs: +// - '/' is replaced with '/.' if it is the first character in mt +// - '/' is replaced with './' if the last or next character in mt is '/' +// For instance, foo//bar would become foo./.bar +// - '/' is replaced with '.' for all other conditions (foo/bar -> foo.bar) +// - '.' and ' ' cause an error to be returned. +// +// If a copy occurred, the returned boolean will indicate this condition. +func mqttToNATSSubjectConversion(mt []byte, wcOk bool) (bool, []byte, error) { + var res = mt + var newSlice bool + + copyTopic := func(pos int) []byte { + if newSlice && cap(res) > pos+2 { + return res + } + newSlice = true + b := make([]byte, len(res)+10) + copy(b, res[:pos]) + res = b + return res + } + + var j int + end := len(mt) - 1 + for i := 0; i < len(mt); i++ { + switch mt[i] { + case mqttTopicLevelSep: + if i == 0 || res[j-1] == btsep { + res = copyTopic(0) + res[j] = mqttTopicLevelSep + j++ + res[j] = btsep + } else if i == end || mt[i+1] == mqttTopicLevelSep { + res = copyTopic(j) + res[j] = btsep + j++ + res[j] = mqttTopicLevelSep + } else { + res[j] = btsep + } + case btsep, ' ': + // As of now, we cannot support '.' or ' ' in the MQTT topic/filter. + return false, nil, fmt.Errorf("characters ' ' and '.' not supported for MQTT topics") + case mqttSingleLevelWC, mqttMultiLevelWC: + if !wcOk { + // Spec [MQTT-3.3.2-2] and [MQTT-4.7.1-1] + // The wildcard characters can be used in Topic Filters, but MUST NOT be used within a Topic Name + return false, nil, fmt.Errorf("wildcards not allowed in publish's topic: %q", mt) + } + if mt[i] == mqttSingleLevelWC { + res[j] = pwc + } else { + res[j] = fwc + } + default: + if newSlice { + res[j] = mt[i] + } + } + j++ + } + if newSlice && res[j-1] == btsep { + res = copyTopic(j) + res[j] = mqttTopicLevelSep + j++ + } + return newSlice, res[:j], nil +} + +// Converts a NATS subject to MQTT topic. This is for publish +// messages only, so there is no checking for wildcards. +// Rules are reversed of mqttToNATSSubjectConversion. +func natsSubjectToMQTTTopic(subject string) []byte { + topic := []byte(subject) + end := len(subject) - 1 + var j int + for i := 0; i < len(subject); i++ { + switch subject[i] { + case mqttTopicLevelSep: + if !(i == 0 && i < end && subject[i+1] == btsep) { + topic[j] = mqttTopicLevelSep + j++ + } + case btsep: + topic[j] = mqttTopicLevelSep + j++ + if i < end && subject[i+1] == mqttTopicLevelSep { + i++ + } + default: + topic[j] = subject[i] + j++ + } + } + return topic[:j] +} + +// Returns true if the subject has more than 1 token and ends with ".>" +func mqttNeedSubForLevelUp(subject string) bool { + if len(subject) < 3 { + return false + } + end := len(subject) + if subject[end-2] == '.' && subject[end-1] == fwc { + return true + } + return false +} + +////////////////////////////////////////////////////////////////////////////// +// +// MQTT Reader functions +// +////////////////////////////////////////////////////////////////////////////// + +func copyBytes(b []byte) []byte { + if b == nil { + return nil + } + cbuf := make([]byte, len(b)) + copy(cbuf, b) + return cbuf +} + +func (r *mqttReader) reset(buf []byte) { + r.buf = buf + r.pos = 0 +} + +func (r *mqttReader) hasMore() bool { + return r.pos != len(r.buf) +} + +func (r *mqttReader) readByte(field string) (byte, error) { + if r.pos == len(r.buf) { + return 0, fmt.Errorf("error reading %s: %v", field, io.EOF) + } + b := r.buf[r.pos] + r.pos++ + return b, nil +} + +func (r *mqttReader) readPacketLen() (int, error) { + m := 1 + v := 0 + for { + var b byte + if r.pos != len(r.buf) { + b = r.buf[r.pos] + r.pos++ + } else { + var buf [1]byte + if _, err := r.reader.Read(buf[:1]); err != nil { + if err == io.EOF { + return 0, io.ErrUnexpectedEOF + } + return 0, fmt.Errorf("error reading packet length: %v", err) + } + b = buf[0] + } + v += int(b&0x7f) * m + if (b & 0x80) == 0 { + return v, nil + } + m *= 0x80 + if m > 0x200000 { + return 0, errors.New("malformed variable int") + } + } +} + +func (r *mqttReader) ensurePacketInBuffer(pl int) error { + rem := len(r.buf) - r.pos + if rem >= pl { + return nil + } + b := make([]byte, pl) + start := copy(b, r.buf[r.pos:]) + for start != pl { + n, err := r.reader.Read(b[start:cap(b)]) + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return fmt.Errorf("error ensuring protocol is loaded: %v", err) + } + start += n + } + r.reset(b) + return nil +} + +func (r *mqttReader) readString(field string) (string, error) { + var s string + bs, err := r.readBytes(field, false) + if err == nil { + s = string(bs) + } + return s, err +} + +func (r *mqttReader) readBytes(field string, cp bool) ([]byte, error) { + luint, err := r.readUint16(field) + if err != nil { + return nil, err + } + l := int(luint) + if l == 0 { + return nil, nil + } + start := r.pos + if start+l > len(r.buf) { + return nil, fmt.Errorf("error reading %s: %v", field, io.ErrUnexpectedEOF) + } + r.pos += l + b := r.buf[start:r.pos] + if cp { + b = copyBytes(b) + } + return b, nil +} + +func (r *mqttReader) readUint16(field string) (uint16, error) { + if len(r.buf)-r.pos < 2 { + return 0, fmt.Errorf("error reading %s: %v", field, io.ErrUnexpectedEOF) + } + start := r.pos + r.pos += 2 + return binary.BigEndian.Uint16(r.buf[start:r.pos]), nil +} + +////////////////////////////////////////////////////////////////////////////// +// +// MQTT Writer functions +// +////////////////////////////////////////////////////////////////////////////// + +func (w *mqttWriter) WriteUint16(i uint16) { + w.WriteByte(byte(i >> 8)) + w.WriteByte(byte(i)) +} + +func (w *mqttWriter) WriteString(s string) { + w.WriteBytes([]byte(s)) +} + +func (w *mqttWriter) WriteBytes(bs []byte) { + w.WriteUint16(uint16(len(bs))) + w.Write(bs) +} + +func (w *mqttWriter) WriteVarInt(value int) { + for { + b := byte(value & 0x7f) + value >>= 7 + if value > 0 { + b |= 0x80 + } + w.WriteByte(b) + if value == 0 { + break + } + } +} diff --git a/server/mqtt_test.go b/server/mqtt_test.go new file mode 100644 index 00000000..d5ad36b3 --- /dev/null +++ b/server/mqtt_test.go @@ -0,0 +1,4111 @@ +// Copyright 2020 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bufio" + "bytes" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/nats-io/jwt/v2" + "github.com/nats-io/nats.go" +) + +type mqttErrorReader struct { + err error +} + +func (r *mqttErrorReader) Read(b []byte) (int, error) { return 0, r.err } +func (r *mqttErrorReader) SetReadDeadline(time.Time) error { return nil } + +func testNewEOFReader() *mqttErrorReader { + return &mqttErrorReader{err: io.EOF} +} + +func TestMQTTReader(t *testing.T) { + r := &mqttReader{} + r.reset([]byte{0, 2, 'a', 'b'}) + bs, err := r.readBytes("", false) + if err != nil { + t.Fatal(err) + } + sbs := string(bs) + if sbs != "ab" { + t.Fatalf(`expected "ab", got %q`, sbs) + } + + r.reset([]byte{0, 2, 'a', 'b'}) + bs, err = r.readBytes("", true) + if err != nil { + t.Fatal(err) + } + bs[0], bs[1] = 'c', 'd' + if bytes.Equal(bs, r.buf[2:]) { + t.Fatal("readBytes should have returned a copy") + } + + r.reset([]byte{'a', 'b'}) + if b, err := r.readByte(""); err != nil || b != 'a' { + t.Fatalf("Error reading byte: b=%v err=%v", b, err) + } + if !r.hasMore() { + t.Fatal("expected to have more, did not") + } + if b, err := r.readByte(""); err != nil || b != 'b' { + t.Fatalf("Error reading byte: b=%v err=%v", b, err) + } + if r.hasMore() { + t.Fatal("expected to not have more") + } + if _, err := r.readByte("test"); err == nil || !strings.Contains(err.Error(), "error reading test") { + t.Fatalf("unexpected error: %v", err) + } + + r.reset([]byte{0, 2, 'a', 'b'}) + if s, err := r.readString(""); err != nil || s != "ab" { + t.Fatalf("Error reading string: s=%q err=%v", s, err) + } + + r.reset([]byte{10}) + if _, err := r.readUint16("uint16"); err == nil || !strings.Contains(err.Error(), "error reading uint16") { + t.Fatalf("unexpected error: %v", err) + } + + r.reset([]byte{1, 2, 3}) + r.reader = testNewEOFReader() + if err := r.ensurePacketInBuffer(10); err == nil || !strings.Contains(err.Error(), "error ensuring protocol is loaded") { + t.Fatalf("unexpected error: %v", err) + } + + r.reset([]byte{0x82, 0xff, 0x3}) + l, err := r.readPacketLen() + if err != nil { + t.Fatal("error getting packet len") + } + if l != 0xff82 { + t.Fatalf("expected length 0xff82 got 0x%x", l) + } + r.reset([]byte{0xff, 0xff, 0xff, 0xff, 0xff}) + if _, err := r.readPacketLen(); err == nil || !strings.Contains(err.Error(), "malformed") { + t.Fatalf("unexpected error: %v", err) + } + r.reset([]byte{0x80}) + if _, err := r.readPacketLen(); err != io.ErrUnexpectedEOF { + t.Fatalf("unexpected error: %v", err) + } + + r.reset([]byte{0x80}) + r.reader = &mqttErrorReader{err: errors.New("on purpose")} + if _, err := r.readPacketLen(); err == nil || !strings.Contains(err.Error(), "on purpose") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestMQTTWriter(t *testing.T) { + w := &mqttWriter{} + w.WriteUint16(1234) + + r := &mqttReader{} + r.reset(w.Bytes()) + if v, err := r.readUint16(""); err != nil || v != 1234 { + t.Fatalf("unexpected value: v=%v err=%v", v, err) + } + + w.Reset() + w.WriteString("test") + r.reset(w.Bytes()) + if len(r.buf) != 6 { + t.Fatalf("Expected 2 bytes size before string, got %v", r.buf) + } + + w.Reset() + w.WriteBytes([]byte("test")) + r.reset(w.Bytes()) + if len(r.buf) != 6 { + t.Fatalf("Expected 2 bytes size before bytes, got %v", r.buf) + } + + ints := []int{ + 0, 1, 127, 128, 16383, 16384, 2097151, 2097152, 268435455, + } + lens := []int{ + 1, 1, 1, 2, 2, 3, 3, 4, 4, + } + + tl := 0 + w.Reset() + for i, v := range ints { + w.WriteVarInt(v) + tl += lens[i] + if tl != w.Len() { + t.Fatalf("expected len %d, got %d", tl, w.Len()) + } + } + + r.reset(w.Bytes()) + for _, v := range ints { + x, _ := r.readPacketLen() + if v != x { + t.Fatalf("expected %d, got %d", v, x) + } + } +} + +func testMQTTDefaultOptions() *Options { + o := DefaultOptions() + o.Cluster.Port = 0 + o.Gateway.Name = "" + o.Gateway.Port = 0 + o.LeafNode.Port = 0 + o.Websocket.Port = 0 + o.MQTT.Host = "127.0.0.1" + o.MQTT.Port = -1 + o.JetStream = true + return o +} + +func testMQTTRunServer(t testing.TB, o *Options) *Server { + o.NoLog = false + s, err := NewServer(o) + if err != nil { + t.Fatalf("Error creating server: %v", err) + } + l := &DummyLogger{} + s.SetLogger(l, true, true) + go s.Start() + if !s.ReadyForConnections(3 * time.Second) { + t.Fatal("Unable to start server") + } + return s +} + +func testMQTTShutdownServer(s *Server) { + if c := s.JetStreamConfig(); c != nil { + dir := strings.TrimSuffix(c.StoreDir, JetStreamStoreDir) + defer os.RemoveAll(dir) + } + s.Shutdown() +} + +func testMQTTDefaultTLSOptions(t *testing.T, verify bool) *Options { + t.Helper() + o := testMQTTDefaultOptions() + tc := &TLSConfigOpts{ + CertFile: "../test/configs/certs/server-cert.pem", + KeyFile: "../test/configs/certs/server-key.pem", + CaFile: "../test/configs/certs/ca.pem", + Verify: verify, + } + var err error + o.MQTT.TLSConfig, err = GenTLSConfig(tc) + o.MQTT.TLSTimeout = 2.0 + if err != nil { + t.Fatalf("Error creating tls config: %v", err) + } + return o +} + +func TestMQTTConfig(t *testing.T) { + conf := createConfFile(t, []byte(` + mqtt { + port: -1 + tls { + cert_file: "./configs/certs/server.pem" + key_file: "./configs/certs/key.pem" + } + } + `)) + defer os.Remove(conf) + s, o := RunServerWithConfig(conf) + defer testMQTTShutdownServer(s) + if o.MQTT.TLSConfig == nil { + t.Fatal("expected TLS config to be set") + } +} + +func TestMQTTValidateOptions(t *testing.T) { + nmqtto := DefaultOptions() + mqtto := testMQTTDefaultOptions() + for _, test := range []struct { + name string + getOpts func() *Options + err string + }{ + {"mqtt disabled", func() *Options { return nmqtto.Clone() }, ""}, + {"mqtt username not allowed if users specified", func() *Options { + o := mqtto.Clone() + o.Users = []*User{&User{Username: "abc", Password: "pwd"}} + o.MQTT.Username = "b" + o.MQTT.Password = "pwd" + return o + }, "mqtt authentication username not compatible with presence of users/nkeys"}, + {"mqtt token not allowed if users specified", func() *Options { + o := mqtto.Clone() + o.Nkeys = []*NkeyUser{&NkeyUser{Nkey: "abc"}} + o.MQTT.Token = "mytoken" + return o + }, "mqtt authentication token not compatible with presence of users/nkeys"}, + {"ack wait should be >=0", func() *Options { + o := mqtto.Clone() + o.MQTT.AckWait = -10 * time.Second + return o + }, "ack wait must be a positive value"}, + } { + t.Run(test.name, func(t *testing.T) { + err := validateMQTTOptions(test.getOpts()) + if test.err == "" && err != nil { + t.Fatalf("Unexpected error: %v", err) + } else if test.err != "" && (err == nil || !strings.Contains(err.Error(), test.err)) { + t.Fatalf("Expected error to contain %q, got %v", test.err, err) + } + }) + } +} + +func TestMQTTParseOptions(t *testing.T) { + for _, test := range []struct { + name string + content string + checkOpt func(*MQTTOpts) error + err string + }{ + // Negative tests + {"bad type", "mqtt: []", nil, "to be a map"}, + {"bad listen", "mqtt: { listen: [] }", nil, "port or host:port"}, + {"bad port", `mqtt: { port: "abc" }`, nil, "not int64"}, + {"bad host", `mqtt: { host: 123 }`, nil, "not string"}, + {"bad tls", `mqtt: { tls: 123 }`, nil, "not map[string]interface {}"}, + {"unknown field", `mqtt: { this_does_not_exist: 123 }`, nil, "unknown"}, + {"ack wait", `mqtt: {ack_wait: abc}`, nil, "invalid duration"}, + {"max ack pending", `mqtt: {max_ack_pending: abc}`, nil, "not int64"}, + {"max ack pending too high", `mqtt: {max_ack_pending: 12345678}`, nil, "invalid value"}, + // Positive tests + {"tls gen fails", ` + mqtt { + tls { + cert_file: "./configs/certs/server.pem" + } + }`, nil, "missing 'key_file'"}, + {"listen port only", `mqtt { listen: 1234 }`, func(o *MQTTOpts) error { + if o.Port != 1234 { + return fmt.Errorf("expected 1234, got %v", o.Port) + } + return nil + }, ""}, + {"listen host and port", `mqtt { listen: "localhost:1234" }`, func(o *MQTTOpts) error { + if o.Host != "localhost" || o.Port != 1234 { + return fmt.Errorf("expected localhost:1234, got %v:%v", o.Host, o.Port) + } + return nil + }, ""}, + {"host", `mqtt { host: "localhost" }`, func(o *MQTTOpts) error { + if o.Host != "localhost" { + return fmt.Errorf("expected localhost, got %v", o.Host) + } + return nil + }, ""}, + {"port", `mqtt { port: 1234 }`, func(o *MQTTOpts) error { + if o.Port != 1234 { + return fmt.Errorf("expected 1234, got %v", o.Port) + } + return nil + }, ""}, + {"tls config", + ` + mqtt { + tls { + cert_file: "./configs/certs/server.pem" + key_file: "./configs/certs/key.pem" + } + } + `, func(o *MQTTOpts) error { + if o.TLSConfig == nil { + return fmt.Errorf("TLSConfig should have been set") + } + return nil + }, ""}, + {"no auth user", + ` + mqtt { + no_auth_user: "noauthuser" + } + `, func(o *MQTTOpts) error { + if o.NoAuthUser != "noauthuser" { + return fmt.Errorf("Invalid NoAuthUser value: %q", o.NoAuthUser) + } + return nil + }, ""}, + {"auth block", + ` + mqtt { + authorization { + user: "mqttuser" + password: "pwd" + token: "token" + timeout: 2.0 + } + } + `, func(o *MQTTOpts) error { + if o.Username != "mqttuser" || o.Password != "pwd" || o.Token != "token" || o.AuthTimeout != 2.0 { + return fmt.Errorf("Invalid auth block: %+v", o) + } + return nil + }, ""}, + {"auth timeout as int", + ` + mqtt { + authorization { + timeout: 2 + } + } + `, func(o *MQTTOpts) error { + if o.AuthTimeout != 2.0 { + return fmt.Errorf("Invalid auth timeout: %v", o.AuthTimeout) + } + return nil + }, ""}, + {"ack wait", + ` + mqtt { + ack_wait: "10s" + } + `, func(o *MQTTOpts) error { + if o.AckWait != 10*time.Second { + return fmt.Errorf("Invalid ack wait: %v", o.AckWait) + } + return nil + }, ""}, + {"max ack pending", + ` + mqtt { + max_ack_pending: 123 + } + `, func(o *MQTTOpts) error { + if o.MaxAckPending != 123 { + return fmt.Errorf("Invalid max ack pending: %v", o.MaxAckPending) + } + return nil + }, ""}, + } { + t.Run(test.name, func(t *testing.T) { + conf := createConfFile(t, []byte(test.content)) + defer os.Remove(conf) + o, err := ProcessConfigFile(conf) + if test.err != _EMPTY_ { + if err == nil || !strings.Contains(err.Error(), test.err) { + t.Fatalf("For content: %q, expected error about %q, got %v", test.content, test.err, err) + } + return + } else if err != nil { + t.Fatalf("Unexpected error for content %q: %v", test.content, err) + } + if err := test.checkOpt(&o.MQTT); err != nil { + t.Fatalf("Incorrect option for content %q: %v", test.content, err.Error()) + } + }) + } +} + +func TestMQTTStart(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc, err := net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)) + if err != nil { + t.Fatalf("Unable to create tcp connection to mqtt port: %v", err) + } + nc.Close() + + // Check failure to start due to port in use + o2 := testMQTTDefaultOptions() + o2.MQTT.Port = o.MQTT.Port + s2, err := NewServer(o2) + if err != nil { + t.Fatalf("Error creating server: %v", err) + } + defer s2.Shutdown() + l := &captureFatalLogger{fatalCh: make(chan string, 1)} + s2.SetLogger(l, false, false) + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + s2.Start() + wg.Done() + }() + + select { + case e := <-l.fatalCh: + if !strings.Contains(e, "Unable to listen for MQTT connections") { + t.Fatalf("Unexpected error: %q", e) + } + case <-time.After(time.Second): + t.Fatal("Should have gotten a fatal error") + } +} + +func TestMQTTTLS(t *testing.T) { + o := testMQTTDefaultTLSOptions(t, false) + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc, err := net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)) + if err != nil { + t.Fatalf("Unable to create tcp connection to mqtt port: %v", err) + } + defer nc.Close() + // Set MaxVersion to TLSv1.2 so that we fail on handshake if there is + // a disagreement between server and client. + tlsc := &tls.Config{ + MaxVersion: tls.VersionTLS12, + InsecureSkipVerify: true, + } + tlsConn := tls.Client(nc, tlsc) + tlsConn.SetDeadline(time.Now().Add(time.Second)) + if err := tlsConn.Handshake(); err != nil { + t.Fatalf("Error doing tls handshake: %v", err) + } + nc.Close() + testMQTTShutdownServer(s) + + // Force client cert verification + o = testMQTTDefaultTLSOptions(t, true) + s = testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc, err = net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)) + if err != nil { + t.Fatalf("Unable to create tcp connection to mqtt port: %v", err) + } + defer nc.Close() + // Set MaxVersion to TLSv1.2 so that we fail on handshake if there is + // a disagreement between server and client. + tlsc = &tls.Config{ + MaxVersion: tls.VersionTLS12, + InsecureSkipVerify: true, + } + tlsConn = tls.Client(nc, tlsc) + tlsConn.SetDeadline(time.Now().Add(time.Second)) + if err := tlsConn.Handshake(); err == nil { + t.Fatal("Handshake expected to fail since client did not provide cert") + } + nc.Close() + + // Add client cert. + nc, err = net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)) + if err != nil { + t.Fatalf("Unable to create tcp connection to mqtt port: %v", err) + } + defer nc.Close() + + tc := &TLSConfigOpts{ + CertFile: "../test/configs/certs/client-cert.pem", + KeyFile: "../test/configs/certs/client-key.pem", + } + tlsc, err = GenTLSConfig(tc) + if err != nil { + t.Fatalf("Error generating tls config: %v", err) + } + tlsc.InsecureSkipVerify = true + tlsConn = tls.Client(nc, tlsc) + tlsConn.SetDeadline(time.Now().Add(time.Second)) + if err := tlsConn.Handshake(); err != nil { + t.Fatalf("Handshake error: %v", err) + } + nc.Close() + testMQTTShutdownServer(s) + + // Lower TLS timeout so low that we should fail + o.MQTT.TLSTimeout = 0.001 + s = testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc, err = net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)) + if err != nil { + t.Fatalf("Unable to create tcp connection to mqtt port: %v", err) + } + defer nc.Close() + time.Sleep(100 * time.Millisecond) + tlsConn = tls.Client(nc, tlsc) + tlsConn.SetDeadline(time.Now().Add(time.Second)) + if err := tlsConn.Handshake(); err == nil { + t.Fatal("Expected failure, did not get one") + } +} + +type mqttConnInfo struct { + clientID string + cleanSess bool + keepAlive uint16 + will *mqttWill + user string + pass string +} + +func testMQTTGetClient(t testing.TB, s *Server, clientID string) *client { + t.Helper() + var mc *client + s.mu.Lock() + for _, c := range s.clients { + c.mu.Lock() + if c.mqtt != nil && c.mqtt.cp != nil && c.mqtt.cp.clientID == clientID { + mc = c + } + c.mu.Unlock() + if mc != nil { + break + } + } + s.mu.Unlock() + if mc == nil { + t.Fatalf("Did not find client %q", clientID) + } + return mc +} + +func testMQTTRead(c net.Conn) ([]byte, error) { + var buf [512]byte + // Make sure that test does not block + c.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, err := c.Read(buf[:]) + if err != nil { + return nil, err + } + c.SetReadDeadline(time.Time{}) + return copyBytes(buf[:n]), nil +} + +func testMQTTWrite(c net.Conn, buf []byte) (int, error) { + c.SetWriteDeadline(time.Now().Add(2 * time.Second)) + n, err := c.Write(buf) + c.SetWriteDeadline(time.Time{}) + return n, err +} + +func testMQTTConnect(t testing.TB, ci *mqttConnInfo, host string, port int) (net.Conn, *mqttReader) { + t.Helper() + + addr := fmt.Sprintf("%s:%d", host, port) + c, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Error creating mqtt connection: %v", err) + } + + proto := mqttCreateConnectProto(ci) + if _, err := testMQTTWrite(c, proto); err != nil { + t.Fatalf("Error writing connect: %v", err) + } + + buf, err := testMQTTRead(c) + if err != nil { + t.Fatalf("Error reading: %v", err) + } + br := &mqttReader{reader: c} + br.reset(buf) + + return c, br +} + +func mqttCreateConnectProto(ci *mqttConnInfo) []byte { + flags := byte(0) + if ci.cleanSess { + flags |= mqttConnFlagCleanSession + } + if ci.will != nil { + flags |= mqttConnFlagWillFlag | (ci.will.qos << 3) + if ci.will.retain { + flags |= mqttConnFlagWillRetain + } + } + if ci.user != _EMPTY_ { + flags |= mqttConnFlagUsernameFlag + } + if ci.pass != _EMPTY_ { + flags |= mqttConnFlagPasswordFlag + } + + pkLen := 2 + len(mqttProtoName) + + 1 + // proto level + 1 + // flags + 2 + // keepAlive + 2 + len(ci.clientID) + + if ci.will != nil { + pkLen += 2 + len(ci.will.topic) + pkLen += 2 + len(ci.will.message) + } + if ci.user != _EMPTY_ { + pkLen += 2 + len(ci.user) + } + if ci.pass != _EMPTY_ { + pkLen += 2 + len(ci.pass) + } + + w := &mqttWriter{} + w.WriteByte(mqttPacketConnect) + w.WriteVarInt(pkLen) + w.WriteString(string(mqttProtoName)) + w.WriteByte(0x4) + w.WriteByte(flags) + w.WriteUint16(ci.keepAlive) + w.WriteString(ci.clientID) + if ci.will != nil { + w.WriteBytes(ci.will.topic) + w.WriteBytes(ci.will.message) + } + if ci.user != _EMPTY_ { + w.WriteString(ci.user) + } + if ci.pass != _EMPTY_ { + w.WriteBytes([]byte(ci.pass)) + } + return w.Bytes() +} + +func testMQTTCheckConnAck(t testing.TB, r *mqttReader, rc byte, sessionPresent bool) { + t.Helper() + r.reader.SetReadDeadline(time.Now().Add(2 * time.Second)) + if err := r.ensurePacketInBuffer(4); err != nil { + t.Fatalf("Error ensuring packet in buffer: %v", err) + } + r.reader.SetReadDeadline(time.Time{}) + b, err := r.readByte("connack packet type") + if err != nil { + t.Fatalf("Error reading packet type: %v", err) + } + pt := b & mqttPacketMask + if pt != mqttPacketConnectAck { + t.Fatalf("Expected ConnAck (%x), got %x", mqttPacketConnectAck, pt) + } + pl, err := r.readByte("connack packet len") + if err != nil { + t.Fatalf("Error reading packet length: %v", err) + } + if pl != 2 { + t.Fatalf("ConnAck packet length should be 2, got %v", pl) + } + caf, err := r.readByte("connack flags") + if err != nil { + t.Fatalf("Error reading packet length: %v", err) + } + if caf&0xfe != 0 { + t.Fatalf("ConnAck flag bits 7-1 should all be 0, got %x", caf>>1) + } + if sp := caf == 1; sp != sessionPresent { + t.Fatalf("Expected session present flag=%v got %v", sessionPresent, sp) + } + carc, err := r.readByte("connack return code") + if err != nil { + t.Fatalf("Error reading returned code: %v", err) + } + if carc != rc { + t.Fatalf("Expected return code to be %v, got %v", rc, carc) + } +} + +func testMQTTEnableJSForAccount(t *testing.T, s *Server, accName string) { + t.Helper() + acc, err := s.LookupAccount(accName) + if err != nil { + t.Fatalf("Error looking up account: %v", err) + } + limits := &JetStreamAccountLimits{ + MaxConsumers: -1, + MaxStreams: -1, + MaxMemory: 1024 * 1024, + } + if err := acc.EnableJetStream(limits); err != nil { + t.Fatalf("Error enabling JS: %v", err) + } +} + +func TestMQTTTLSVerifyAndMap(t *testing.T) { + accName := "MyAccount" + acc := NewAccount(accName) + certUserName := "CN=example.com,OU=NATS.io" + users := []*User{&User{Username: certUserName, Account: acc}} + + for _, test := range []struct { + name string + filtering bool + provideCert bool + }{ + {"no filtering, client provides cert", false, true}, + {"no filtering, client does not provide cert", false, false}, + {"filtering, client provides cert", true, true}, + {"filtering, client does not provide cert", true, false}, + } { + t.Run(test.name, func(t *testing.T) { + o := testMQTTDefaultOptions() + o.Host = "localhost" + o.Accounts = []*Account{acc} + o.Users = users + if test.filtering { + o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeMqtt}) + } + tc := &TLSConfigOpts{ + CertFile: "../test/configs/certs/tlsauth/server.pem", + KeyFile: "../test/configs/certs/tlsauth/server-key.pem", + CaFile: "../test/configs/certs/tlsauth/ca.pem", + Verify: true, + } + tlsc, err := GenTLSConfig(tc) + if err != nil { + t.Fatalf("Error creating tls config: %v", err) + } + o.MQTT.TLSConfig = tlsc + o.MQTT.TLSTimeout = 2.0 + o.MQTT.TLSMap = true + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + testMQTTEnableJSForAccount(t, s, accName) + + addr := fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port) + mc, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Error creating ws connection: %v", err) + } + defer mc.Close() + tlscc := &tls.Config{} + if test.provideCert { + tc := &TLSConfigOpts{ + CertFile: "../test/configs/certs/tlsauth/client.pem", + KeyFile: "../test/configs/certs/tlsauth/client-key.pem", + } + var err error + tlscc, err = GenTLSConfig(tc) + if err != nil { + t.Fatalf("Error generating tls config: %v", err) + } + } + tlscc.InsecureSkipVerify = true + if test.provideCert { + tlscc.MinVersion = tls.VersionTLS13 + } + mc = tls.Client(mc, tlscc) + if err := mc.(*tls.Conn).Handshake(); err != nil { + t.Fatalf("Error during handshake: %v", err) + } + + ci := &mqttConnInfo{cleanSess: true} + proto := mqttCreateConnectProto(ci) + if _, err := testMQTTWrite(mc, proto); err != nil { + t.Fatalf("Error sending proto: %v", err) + } + buf, err := testMQTTRead(mc) + if !test.provideCert { + if err == nil { + t.Fatal("Expected error, did not get one") + } else if !strings.Contains(err.Error(), "bad certificate") { + t.Fatalf("Unexpected error: %v", err) + } + return + } + if err != nil { + t.Fatalf("Error reading: %v", err) + } + r := &mqttReader{reader: mc} + r.reset(buf) + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + var c *client + s.mu.Lock() + for _, sc := range s.clients { + sc.mu.Lock() + if sc.mqtt != nil { + c = sc + } + sc.mu.Unlock() + if c != nil { + break + } + } + s.mu.Unlock() + if c == nil { + t.Fatal("Client not found") + } + + var uname string + var accname string + c.mu.Lock() + uname = c.opts.Username + if c.acc != nil { + accname = c.acc.GetName() + } + c.mu.Unlock() + if uname != certUserName { + t.Fatalf("Expected username %q, got %q", certUserName, uname) + } + if accname != accName { + t.Fatalf("Expected account %q, got %v", accName, accname) + } + }) + } +} + +func TestMQTTBasicAuth(t *testing.T) { + for _, test := range []struct { + name string + opts func() *Options + user string + pass string + rc byte + }{ + { + "top level auth, no override, wrong u/p", + func() *Options { + o := testMQTTDefaultOptions() + o.Username = "normal" + o.Password = "client" + return o + }, + "mqtt", "client", mqttConnAckRCNotAuthorized, + }, + { + "top level auth, no override, correct u/p", + func() *Options { + o := testMQTTDefaultOptions() + o.Username = "normal" + o.Password = "client" + return o + }, + "normal", "client", mqttConnAckRCConnectionAccepted, + }, + { + "no top level auth, mqtt auth, wrong u/p", + func() *Options { + o := testMQTTDefaultOptions() + o.MQTT.Username = "mqtt" + o.MQTT.Password = "client" + return o + }, + "normal", "client", mqttConnAckRCNotAuthorized, + }, + { + "no top level auth, mqtt auth, correct u/p", + func() *Options { + o := testMQTTDefaultOptions() + o.MQTT.Username = "mqtt" + o.MQTT.Password = "client" + return o + }, + "mqtt", "client", mqttConnAckRCConnectionAccepted, + }, + { + "top level auth, mqtt override, wrong u/p", + func() *Options { + o := testMQTTDefaultOptions() + o.Username = "normal" + o.Password = "client" + o.MQTT.Username = "mqtt" + o.MQTT.Password = "client" + return o + }, + "normal", "client", mqttConnAckRCNotAuthorized, + }, + { + "top level auth, mqtt override, correct u/p", + func() *Options { + o := testMQTTDefaultOptions() + o.Username = "normal" + o.Password = "client" + o.MQTT.Username = "mqtt" + o.MQTT.Password = "client" + return o + }, + "mqtt", "client", mqttConnAckRCConnectionAccepted, + }, + } { + t.Run(test.name, func(t *testing.T) { + o := test.opts() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{ + cleanSess: true, + user: test.user, + pass: test.pass, + } + mc, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, test.rc, false) + }) + } +} + +func TestMQTTAuthTimeout(t *testing.T) { + for _, test := range []struct { + name string + at float64 + mat float64 + ok bool + }{ + {"use top-level auth timeout", 0.5, 0.0, true}, + {"use mqtt auth timeout", 0.5, 0.05, false}, + } { + t.Run(test.name, func(t *testing.T) { + o := testMQTTDefaultOptions() + o.AuthTimeout = test.at + o.MQTT.Username = "mqtt" + o.MQTT.Password = "client" + o.MQTT.AuthTimeout = test.mat + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mc, err := net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)) + if err != nil { + t.Fatalf("Error connecting: %v", err) + } + defer mc.Close() + + time.Sleep(100 * time.Millisecond) + + ci := &mqttConnInfo{ + cleanSess: true, + user: "mqtt", + pass: "client", + } + proto := mqttCreateConnectProto(ci) + if _, err := testMQTTWrite(mc, proto); err != nil { + if test.ok { + t.Fatalf("Error sending connect: %v", err) + } + // else it is ok since we got disconnected due to auth timeout + return + } + buf, err := testMQTTRead(mc) + if err != nil { + if test.ok { + t.Fatalf("Error reading: %v", err) + } + // else it is ok since we got disconnected due to auth timeout + return + } + r := &mqttReader{reader: mc} + r.reset(buf) + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + time.Sleep(500 * time.Millisecond) + testMQTTPublish(t, mc, r, 1, false, false, "foo", 1, []byte("msg")) + }) + } +} + +func TestMQTTTokenAuth(t *testing.T) { + for _, test := range []struct { + name string + opts func() *Options + token string + rc byte + }{ + { + "top level auth, no override, wrong token", + func() *Options { + o := testMQTTDefaultOptions() + o.Authorization = "goodtoken" + return o + }, + "badtoken", mqttConnAckRCNotAuthorized, + }, + { + "top level auth, no override, correct token", + func() *Options { + o := testMQTTDefaultOptions() + o.Authorization = "goodtoken" + return o + }, + "goodtoken", mqttConnAckRCConnectionAccepted, + }, + { + "no top level auth, mqtt auth, wrong token", + func() *Options { + o := testMQTTDefaultOptions() + o.MQTT.Token = "goodtoken" + return o + }, + "badtoken", mqttConnAckRCNotAuthorized, + }, + { + "no top level auth, mqtt auth, correct token", + func() *Options { + o := testMQTTDefaultOptions() + o.MQTT.Token = "goodtoken" + return o + }, + "goodtoken", mqttConnAckRCConnectionAccepted, + }, + { + "top level auth, mqtt override, wrong token", + func() *Options { + o := testMQTTDefaultOptions() + o.Authorization = "clienttoken" + o.MQTT.Token = "mqtttoken" + return o + }, + "clienttoken", mqttConnAckRCNotAuthorized, + }, + { + "top level auth, mqtt override, correct token", + func() *Options { + o := testMQTTDefaultOptions() + o.Authorization = "clienttoken" + o.MQTT.Token = "mqtttoken" + return o + }, + "mqtttoken", mqttConnAckRCConnectionAccepted, + }, + } { + t.Run(test.name, func(t *testing.T) { + o := test.opts() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{ + cleanSess: true, + user: "ignore_use_token", + pass: test.token, + } + mc, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, test.rc, false) + }) + } +} + +func TestMQTTUsersAuth(t *testing.T) { + users := []*User{&User{Username: "user", Password: "pwd"}} + for _, test := range []struct { + name string + opts func() *Options + user string + pass string + rc byte + }{ + { + "no filtering, wrong user", + func() *Options { + o := testMQTTDefaultOptions() + o.Users = users + return o + }, + "wronguser", "pwd", mqttConnAckRCNotAuthorized, + }, + { + "no filtering, correct user", + func() *Options { + o := testMQTTDefaultOptions() + o.Users = users + return o + }, + "user", "pwd", mqttConnAckRCConnectionAccepted, + }, + { + "filtering, user not allowed", + func() *Options { + o := testMQTTDefaultOptions() + o.Users = users + // Only allowed for regular clients + o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard}) + return o + }, + "user", "pwd", mqttConnAckRCNotAuthorized, + }, + { + "filtering, user allowed", + func() *Options { + o := testMQTTDefaultOptions() + o.Users = users + o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeMqtt}) + return o + }, + "user", "pwd", mqttConnAckRCConnectionAccepted, + }, + { + "filtering, wrong password", + func() *Options { + o := testMQTTDefaultOptions() + o.Users = users + o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeMqtt}) + return o + }, + "user", "badpassword", mqttConnAckRCNotAuthorized, + }, + } { + t.Run(test.name, func(t *testing.T) { + o := test.opts() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{ + cleanSess: true, + user: test.user, + pass: test.pass, + } + mc, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, test.rc, false) + }) + } +} + +func TestMQTTNoAuthUserValidation(t *testing.T) { + o := testMQTTDefaultOptions() + o.Users = []*User{&User{Username: "user", Password: "pwd"}} + // Should fail because it is not part of o.Users. + o.MQTT.NoAuthUser = "notfound" + if _, err := NewServer(o); err == nil || !strings.Contains(err.Error(), "not present as user") { + t.Fatalf("Expected error saying not present as user, got %v", err) + } + + // Set a valid no auth user for global options, but still should fail because + // of o.MQTT.NoAuthUser + o.NoAuthUser = "user" + o.MQTT.NoAuthUser = "notfound" + if _, err := NewServer(o); err == nil || !strings.Contains(err.Error(), "not present as user") { + t.Fatalf("Expected error saying not present as user, got %v", err) + } +} + +func TestMQTTNoAuthUser(t *testing.T) { + for _, test := range []struct { + name string + override bool + useAuth bool + expectedUser string + expectedAcc string + }{ + {"no override, no user provided", false, false, "noauth", "normal"}, + {"no override, user povided", false, true, "user", "normal"}, + {"override, no user provided", true, false, "mqttnoauth", "mqtt"}, + {"override, user provided", true, true, "mqttuser", "mqtt"}, + } { + t.Run(test.name, func(t *testing.T) { + o := testMQTTDefaultOptions() + normalAcc := NewAccount("normal") + mqttAcc := NewAccount("mqtt") + o.Accounts = []*Account{normalAcc, mqttAcc} + o.Users = []*User{ + &User{Username: "noauth", Password: "pwd", Account: normalAcc}, + &User{Username: "user", Password: "pwd", Account: normalAcc}, + &User{Username: "mqttnoauth", Password: "pwd", Account: mqttAcc}, + &User{Username: "mqttuser", Password: "pwd", Account: mqttAcc}, + } + o.NoAuthUser = "noauth" + if test.override { + o.MQTT.NoAuthUser = "mqttnoauth" + } + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + testMQTTEnableJSForAccount(t, s, "normal") + testMQTTEnableJSForAccount(t, s, "mqtt") + + ci := &mqttConnInfo{clientID: "mqtt", cleanSess: true} + if test.useAuth { + ci.user = test.expectedUser + ci.pass = "pwd" + } + mc, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + c := testMQTTGetClient(t, s, "mqtt") + c.mu.Lock() + uname := c.opts.Username + aname := c.acc.GetName() + c.mu.Unlock() + if uname != test.expectedUser { + t.Fatalf("Expected selected user to be %q, got %q", test.expectedUser, uname) + } + if aname != test.expectedAcc { + t.Fatalf("Expected selected account to be %q, got %q", test.expectedAcc, aname) + } + }) + } +} + +func TestMQTTConnectNotFirstProto(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + c, err := net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)) + if err != nil { + t.Fatalf("Error on dial: %v", err) + } + defer c.Close() + + w := &mqttWriter{} + mqttWritePublish(w, 0, false, false, "foo", 0, []byte("hello")) + if _, err := testMQTTWrite(c, w.Bytes()); err != nil { + t.Fatalf("Error publishing: %v", err) + } + testMQTTExpectDisconnect(t, c) +} + +func TestMQTTSecondConnect(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + proto := mqttCreateConnectProto(&mqttConnInfo{cleanSess: true}) + if _, err := testMQTTWrite(mc, proto); err != nil { + t.Fatalf("Error writing connect: %v", err) + } + testMQTTExpectDisconnect(t, mc) +} + +func TestMQTTParseConnect(t *testing.T) { + eofr := testNewEOFReader() + for _, test := range []struct { + name string + proto []byte + pl int + reader mqttIOReader + err string + }{ + {"packet in buffer error", nil, 10, eofr, "error ensuring protocol is loaded"}, + {"bad proto name", []byte{0, 4, 'B', 'A', 'D'}, 5, nil, "protocol name"}, + {"invalid proto name", []byte{0, 3, 'B', 'A', 'D'}, 5, nil, "expected connect packet with protocol name"}, + {"old proto not supported", []byte{0, 6, 'M', 'Q', 'I', 's', 'd', 'p'}, 8, nil, "older protocol"}, + {"error on protocol level", []byte{0, 4, 'M', 'Q', 'T', 'T'}, 6, eofr, "protocol level"}, + {"unacceptable protocol version", []byte{0, 4, 'M', 'Q', 'T', 'T', 10}, 7, nil, "unacceptable protocol version"}, + {"error on flags", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel}, 7, eofr, "flags"}, + {"reserved flag", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 1}, 8, nil, "connect flags reserved bit not set to 0"}, + {"will qos without will flag", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 1 << 3}, 8, nil, "if Will flag is set to 0, Will QoS must be 0 too"}, + {"will retain without will flag", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 1 << 5}, 8, nil, "if Will flag is set to 0, Will Retain flag must be 0 too"}, + {"will qos", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 3<<3 | 1<<2}, 8, nil, "if Will flag is set to 1, Will QoS can be 0, 1 or 2"}, + {"no user but password", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagPasswordFlag}, 8, nil, "password flag set but username flag is not"}, + {"missing keep alive", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 0}, 8, nil, "keep alive"}, + {"missing client ID", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 0, 0, 1}, 10, nil, "client ID"}, + {"empty client ID", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 0, 0, 1, 0, 0}, 12, nil, "when client ID is empty, clean session flag must be set to 1"}, + {"invalid utf8 client ID", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 0, 0, 1, 0, 1, 241}, 13, nil, "invalid utf8 for client ID"}, + {"missing will topic", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagWillFlag | mqttConnFlagCleanSession, 0, 0, 0, 0}, 12, nil, "Will topic"}, + {"empty will topic", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagWillFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 0}, 14, nil, "empty Will topic not allowed"}, + {"invalid utf8 will topic", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagWillFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 1, 241}, 15, nil, "invalide utf8 for Will topic"}, + {"invalid wildcard will topic", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagWillFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 1, '#'}, 15, nil, "wildcards not allowed"}, + {"error on will message", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagWillFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 1, 'a', 0, 3}, 17, eofr, "Will message"}, + {"error on username", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagUsernameFlag | mqttConnFlagCleanSession, 0, 0, 0, 0}, 12, eofr, "user name"}, + {"empty username", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagUsernameFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 0}, 14, nil, "empty user name not allowed"}, + {"invalid utf8 username", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagUsernameFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 1, 241}, 15, nil, "invalid utf8 for user name"}, + {"error on password", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagUsernameFlag | mqttConnFlagPasswordFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 1, 'a'}, 15, eofr, "password"}, + } { + t.Run(test.name, func(t *testing.T) { + r := &mqttReader{reader: test.reader} + r.reset(test.proto) + mqtt := &mqtt{r: r} + c := &client{mqtt: mqtt} + if _, _, err := c.mqttParseConnect(r, test.pl); err == nil || !strings.Contains(err.Error(), test.err) { + t.Fatalf("Expected error %q, got %v", test.err, err) + } + }) + } +} + +func TestMQTTConnectFailsOnParse(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + addr := fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port) + c, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Error creating mqtt connection: %v", err) + } + + pkLen := 2 + len(mqttProtoName) + + 1 + // proto level + 1 + // flags + 2 + // keepAlive + 2 + len("mqtt") + + w := &mqttWriter{} + w.WriteByte(mqttPacketConnect) + w.WriteVarInt(pkLen) + w.WriteString(string(mqttProtoName)) + w.WriteByte(0x7) + w.WriteByte(mqttConnFlagCleanSession) + w.WriteUint16(0) + w.WriteString("mqtt") + c.Write(w.Bytes()) + + buf, err := testMQTTRead(c) + if err != nil { + t.Fatalf("Error reading: %v", err) + } + r := &mqttReader{reader: c} + r.reset(buf) + testMQTTCheckConnAck(t, r, mqttConnAckRCUnacceptableProtocolVersion, false) +} + +func TestMQTTConnKeepAlive(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true, keepAlive: 1}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, mc, r, 0, false, false, "foo", 0, []byte("msg")) + + time.Sleep(2 * time.Second) + testMQTTExpectDisconnect(t, mc) +} + +func TestMQTTTopicAndSubjectConversion(t *testing.T) { + for _, test := range []struct { + name string + mqttTopic string + natsSubject string + err string + }{ + {"/", "/", "/./", ""}, + {"//", "//", "/././", ""}, + {"///", "///", "/./././", ""}, + {"////", "////", "/././././", ""}, + {"foo", "foo", "foo", ""}, + {"/foo", "/foo", "/.foo", ""}, + {"//foo", "//foo", "/./.foo", ""}, + {"///foo", "///foo", "/././.foo", ""}, + {"///foo/", "///foo/", "/././.foo./", ""}, + {"///foo//", "///foo//", "/././.foo././", ""}, + {"///foo///", "///foo///", "/././.foo./././", ""}, + {"foo/bar", "foo/bar", "foo.bar", ""}, + {"/foo/bar", "/foo/bar", "/.foo.bar", ""}, + {"/foo/bar/", "/foo/bar/", "/.foo.bar./", ""}, + {"foo/bar/baz", "foo/bar/baz", "foo.bar.baz", ""}, + {"/foo/bar/baz", "/foo/bar/baz", "/.foo.bar.baz", ""}, + {"/foo/bar/baz/", "/foo/bar/baz/", "/.foo.bar.baz./", ""}, + {"bar", "bar/", "bar./", ""}, + {"bar//", "bar//", "bar././", ""}, + {"bar///", "bar///", "bar./././", ""}, + {"foo//bar", "foo//bar", "foo./.bar", ""}, + {"foo///bar", "foo///bar", "foo././.bar", ""}, + {"foo////bar", "foo////bar", "foo./././.bar", ""}, + // These should produce errors + {"foo/+", "foo/+", "", "wildcards not allowed in publish"}, + {"foo/#", "foo/#", "", "wildcards not allowed in publish"}, + {"foo bar", "foo bar", "", "not supported"}, + {"foo.bar", "foo.bar", "", "not supported"}, + } { + t.Run(test.name, func(t *testing.T) { + _, res, err := mqttTopicToNATSPubSubject([]byte(test.mqttTopic)) + if test.err != _EMPTY_ { + if err == nil || !strings.Contains(err.Error(), test.err) { + t.Fatalf("Expected error %q, got %q", test.err, err.Error()) + } + return + } + toNATS := string(res) + if toNATS != test.natsSubject { + t.Fatalf("Expected subject %q got %q", test.natsSubject, toNATS) + } + + res = natsSubjectToMQTTTopic(toNATS) + backToMQTT := string(res) + if backToMQTT != test.mqttTopic { + t.Fatalf("Expected topic %q got %q (NATS conversion was %q)", test.mqttTopic, backToMQTT, toNATS) + } + }) + } +} + +func TestMQTTFilterConversion(t *testing.T) { + // Similar to TopicConversion test except that wildcards are OK here. + // So testing only those. + for _, test := range []struct { + name string + mqttTopic string + natsSubject string + }{ + {"single level wildcard", "+", "*"}, + {"single level wildcard", "/+", "/.*"}, + {"single level wildcard", "+/", "*./"}, + {"single level wildcard", "/+/", "/.*./"}, + {"single level wildcard", "foo/+", "foo.*"}, + {"single level wildcard", "foo/+/", "foo.*./"}, + {"single level wildcard", "foo/+/bar", "foo.*.bar"}, + {"single level wildcard", "foo/+/+", "foo.*.*"}, + {"single level wildcard", "foo/+/+/", "foo.*.*./"}, + {"single level wildcard", "foo/+/+/bar", "foo.*.*.bar"}, + + {"multi level wildcard", "#", ">"}, + {"multi level wildcard", "/#", "/.>"}, + {"multi level wildcard", "/foo/#", "/.foo.>"}, + {"multi level wildcard", "foo/#", "foo.>"}, + {"multi level wildcard", "foo/bar/#", "foo.bar.>"}, + } { + t.Run(test.name, func(t *testing.T) { + _, res, err := mqttFilterToNATSSubject([]byte(test.mqttTopic)) + if err != nil { + t.Fatalf("Error: %v", err) + } + if string(res) != test.natsSubject { + t.Fatalf("Expected subject %q got %q", test.natsSubject, res) + } + }) + } +} + +func testMQTTReaderHasAtLeastOne(t testing.TB, r *mqttReader) { + t.Helper() + r.reader.SetReadDeadline(time.Now().Add(2 * time.Second)) + if err := r.ensurePacketInBuffer(1); err != nil { + t.Fatal(err) + } + r.reader.SetReadDeadline(time.Time{}) +} + +func TestMQTTParseSub(t *testing.T) { + eofr := testNewEOFReader() + for _, test := range []struct { + name string + proto []byte + b byte + pl int + reader mqttIOReader + err string + }{ + {"reserved flag", nil, 3, 0, nil, "wrong subscribe reserved flags"}, + {"ensure packet loaded", []byte{1, 2}, mqttSubscribeFlags, 10, eofr, "error ensuring protocol is loaded"}, + {"error reading packet id", []byte{1}, mqttSubscribeFlags, 1, eofr, "reading packet identifier"}, + {"missing filters", []byte{0, 1}, mqttSubscribeFlags, 2, nil, "subscribe protocol must contain at least 1 topic filter"}, + {"error reading topic", []byte{0, 1, 0, 2, 'a'}, mqttSubscribeFlags, 5, eofr, "topic filter"}, + {"empty topic", []byte{0, 1, 0, 0}, mqttSubscribeFlags, 4, nil, "topic filter cannot be empty"}, + {"invalid utf8 topic", []byte{0, 1, 0, 1, 241}, mqttSubscribeFlags, 5, nil, "invalid utf8 for topic filter"}, + {"missing qos", []byte{0, 1, 0, 1, 'a'}, mqttSubscribeFlags, 5, nil, "QoS"}, + {"invalid qos", []byte{0, 1, 0, 1, 'a', 3}, mqttSubscribeFlags, 6, nil, "subscribe QoS value must be 0, 1 or 2"}, + } { + t.Run(test.name, func(t *testing.T) { + r := &mqttReader{reader: test.reader} + r.reset(test.proto) + mqtt := &mqtt{r: r} + c := &client{mqtt: mqtt} + if _, _, err := c.mqttParseSubsOrUnsubs(r, test.b, test.pl, true); err == nil || !strings.Contains(err.Error(), test.err) { + t.Fatalf("Expected error %q, got %v", test.err, err) + } + }) + } +} + +func testMQTTSub(t testing.TB, pi uint16, c net.Conn, r *mqttReader, filters []*mqttFilter, expected []byte) { + t.Helper() + w := &mqttWriter{} + pkLen := 2 // for pi + for i := 0; i < len(filters); i++ { + f := filters[i] + pkLen += 2 + len(f.filter) + 1 + } + w.WriteByte(mqttPacketSub | mqttSubscribeFlags) + w.WriteVarInt(pkLen) + w.WriteUint16(pi) + for i := 0; i < len(filters); i++ { + f := filters[i] + w.WriteBytes([]byte(f.filter)) + w.WriteByte(f.qos) + } + if _, err := testMQTTWrite(c, w.Bytes()); err != nil { + t.Fatalf("Error writing SUBSCRIBE protocol: %v", err) + } + // Make sure we have at least 1 byte in buffer (if not will read) + testMQTTReaderHasAtLeastOne(t, r) + // Parse SUBACK + b, err := r.readByte("packet type") + if err != nil { + t.Fatal(err) + } + if pt := b & mqttPacketMask; pt != mqttPacketSubAck { + t.Fatalf("Expected SUBACK packet %x, got %x", mqttPacketSubAck, pt) + } + pl, err := r.readPacketLen() + if err != nil { + t.Fatal(err) + } + if err := r.ensurePacketInBuffer(pl); err != nil { + t.Fatal(err) + } + rpi, err := r.readUint16("packet identifier") + if err != nil || rpi != pi { + t.Fatalf("Error with packet identifier expected=%v got: %v err=%v", pi, rpi, err) + } + for i, rem := 0, pl-2; rem > 0; rem-- { + qos, err := r.readByte("filter qos") + if err != nil { + t.Fatal(err) + } + if qos != expected[i] { + t.Fatalf("For topic filter %q expected qos of %v, got %v", + filters[i].filter, expected[i], qos) + } + i++ + } +} + +func TestMQTTSubAck(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + subs := []*mqttFilter{ + {filter: "foo", qos: 0}, + {filter: "bar", qos: 1}, + {filter: "baz", qos: 2}, // Since we don't support, we should receive a result of 1 + {filter: "foo/#/bar", qos: 0}, // Invalid sub, so we should receive a result of mqttSubAckFailure + } + expected := []byte{ + 0, + 1, + 1, + mqttSubAckFailure, + } + testMQTTSub(t, 1, mc, r, subs, expected) +} + +func testMQTTFlush(t testing.TB, c net.Conn, bw *bufio.Writer, r *mqttReader) { + t.Helper() + w := &mqttWriter{} + w.WriteByte(mqttPacketPing) + w.WriteByte(0) + if bw != nil { + bw.Write(w.Bytes()) + bw.Flush() + } else { + c.Write(w.Bytes()) + } + r.ensurePacketInBuffer(2) + ab, err := r.readByte("pingresp") + if err != nil { + t.Fatalf("Error reading ping response: %v", err) + } + if pt := ab & mqttPacketMask; pt != mqttPacketPingResp { + t.Fatalf("Expected ping response got %x", pt) + } + l, err := r.readPacketLen() + if err != nil { + t.Fatal(err) + } + if l != 0 { + t.Fatalf("Expected PINGRESP length to be 0, got %v", l) + } +} + +func testMQTTExpectNothing(t testing.TB, r *mqttReader) { + t.Helper() + r.reader.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + if err := r.ensurePacketInBuffer(1); err == nil { + t.Fatalf("Expected nothing, got %v", r.buf[r.pos:]) + } + r.reader.SetReadDeadline(time.Time{}) +} + +func testMQTTCheckPubMsg(t testing.TB, c net.Conn, r *mqttReader, topic string, flags byte, payload []byte) { + t.Helper() + pflags, pi := testMQTTGetPubMsg(t, c, r, topic, payload) + if pflags != flags { + t.Fatalf("Expected flags to be %x, got %x", flags, pflags) + } + if pi > 0 { + testMQTTSendPubAck(t, c, pi) + } +} + +func testMQTTCheckPubMsgNoAck(t testing.TB, c net.Conn, r *mqttReader, topic string, flags byte, payload []byte) uint16 { + t.Helper() + pflags, pi := testMQTTGetPubMsg(t, c, r, topic, payload) + if pflags != flags { + t.Fatalf("Expected flags to be %x, got %x", flags, pflags) + } + return pi +} + +func testMQTTGetPubMsg(t testing.TB, c net.Conn, r *mqttReader, topic string, payload []byte) (byte, uint16) { + t.Helper() + testMQTTReaderHasAtLeastOne(t, r) + b, err := r.readByte("packet type") + if err != nil { + t.Fatal(err) + } + if pt := b & mqttPacketMask; pt != mqttPacketPub { + t.Fatalf("Expected PUBLISH packet %x, got %x", mqttPacketPub, pt) + } + pflags := b & mqttPacketFlagMask + qos := (pflags & mqttPubFlagQoS) >> 1 + pl, err := r.readPacketLen() + if err != nil { + t.Fatal(err) + } + if err := r.ensurePacketInBuffer(pl); err != nil { + t.Fatal(err) + } + start := r.pos + ptopic, err := r.readString("topic name") + if err != nil { + t.Fatal(err) + } + if ptopic != topic { + t.Fatalf("Expected topic %q, got %q", topic, ptopic) + } + var pi uint16 + if qos > 0 { + pi, err = r.readUint16("packet identifier") + if err != nil { + t.Fatal(err) + } + } + msgLen := pl - (r.pos - start) + if r.pos+msgLen > len(r.buf) { + t.Fatalf("computed message length goes beyond buffer: ml=%v pos=%v lenBuf=%v", + msgLen, r.pos, len(r.buf)) + } + ppayload := r.buf[r.pos : r.pos+msgLen] + if !bytes.Equal(payload, ppayload) { + t.Fatalf("Expected payload %q, got %q", payload, ppayload) + } + r.pos += msgLen + return pflags, pi +} + +func testMQTTSendPubAck(t testing.TB, c net.Conn, pi uint16) { + t.Helper() + w := &mqttWriter{} + w.WriteByte(mqttPacketPubAck) + w.WriteVarInt(2) + w.WriteUint16(pi) + if _, err := testMQTTWrite(c, w.Bytes()); err != nil { + t.Fatalf("Error writing PUBACK: %v", err) + } +} + +func testMQTTPublish(t testing.TB, c net.Conn, r *mqttReader, qos byte, dup, retain bool, topic string, pi uint16, payload []byte) { + t.Helper() + w := &mqttWriter{} + mqttWritePublish(w, qos, dup, retain, topic, pi, payload) + if _, err := testMQTTWrite(c, w.Bytes()); err != nil { + t.Fatalf("Error writing PUBLISH proto: %v", err) + } + if qos > 0 { + // Since we don't support QoS 2, we should get disconnected + if qos == 2 { + testMQTTExpectDisconnect(t, c) + return + } + testMQTTReaderHasAtLeastOne(t, r) + // Parse PUBACK + b, err := r.readByte("packet type") + if err != nil { + t.Fatal(err) + } + if pt := b & mqttPacketMask; pt != mqttPacketPubAck { + t.Fatalf("Expected PUBACK packet %x, got %x", mqttPacketPubAck, pt) + } + pl, err := r.readPacketLen() + if err != nil { + t.Fatal(err) + } + if err := r.ensurePacketInBuffer(pl); err != nil { + t.Fatal(err) + } + rpi, err := r.readUint16("packet identifier") + if err != nil || rpi != pi { + t.Fatalf("Error with packet identifier expected=%v got: %v err=%v", pi, rpi, err) + } + } +} + +func TestMQTTParsePub(t *testing.T) { + eofr := testNewEOFReader() + for _, test := range []struct { + name string + flags byte + proto []byte + pl int + reader mqttIOReader + err string + }{ + {"qos not supported", 0x4, nil, 0, nil, "not supported"}, + {"packet in buffer error", 0, nil, 10, eofr, "error ensuring protocol is loaded"}, + {"error on topic", 0, []byte{0, 3, 'f', 'o'}, 4, eofr, "topic"}, + {"empty topic", 0, []byte{0, 0}, 2, nil, "topic cannot be empty"}, + {"wildcards topic", 0, []byte{0, 1, '#'}, 3, nil, "wildcards not allowed"}, + {"error on packet identifier", mqttPubQos1, []byte{0, 3, 'f', 'o', 'o'}, 5, eofr, "packet identifier"}, + {"invalid packet identifier", mqttPubQos1, []byte{0, 3, 'f', 'o', 'o', 0, 0}, 7, nil, "packet identifier cannot be 0"}, + } { + t.Run(test.name, func(t *testing.T) { + r := &mqttReader{reader: test.reader} + r.reset(test.proto) + mqtt := &mqtt{r: r} + c := &client{mqtt: mqtt} + pp := &mqttPublish{flags: test.flags} + if err := c.mqttParsePub(r, test.pl, pp); err == nil || !strings.Contains(err.Error(), test.err) { + t.Fatalf("Expected error %q, got %v", test.err, err) + } + }) + } +} + +func TestMQTTParsePubAck(t *testing.T) { + eofr := testNewEOFReader() + for _, test := range []struct { + name string + proto []byte + pl int + reader mqttIOReader + err string + }{ + {"packet in buffer error", nil, 10, eofr, "error ensuring protocol is loaded"}, + {"error reading packet identifier", []byte{0}, 1, eofr, "packet identifier"}, + {"invalid packet identifier", []byte{0, 0}, 2, nil, "packet identifier cannot be 0"}, + } { + t.Run(test.name, func(t *testing.T) { + r := &mqttReader{reader: test.reader} + r.reset(test.proto) + if _, err := mqttParsePubAck(r, test.pl); err == nil || !strings.Contains(err.Error(), test.err) { + t.Fatalf("Expected error %q, got %v", test.err, err) + } + }) + } +} + +func TestMQTTPublish(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcp.Close() + testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, mcp, mpr, 0, false, false, "foo", 0, []byte("msg")) + testMQTTPublish(t, mcp, mpr, 1, false, false, "foo", 1, []byte("msg")) + testMQTTPublish(t, mcp, mpr, 2, false, false, "foo", 2, []byte("msg")) +} + +func TestMQTTSub(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcp.Close() + testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false) + + for _, test := range []struct { + name string + mqttSubTopic string + natsPubSubject string + mqttPubTopic string + ok bool + }{ + {"1 level match", "foo", "foo", "foo", true}, + {"1 level no match", "foo", "bar", "bar", false}, + {"2 levels match", "foo/bar", "foo.bar", "foo/bar", true}, + {"2 levels no match", "foo/bar", "foo.baz", "foo/baz", false}, + {"3 levels match", "/foo/bar", "/.foo.bar", "/foo/bar", true}, + {"3 levels no match", "/foo/bar", "/.foo.baz", "/foo/baz", false}, + + {"single level wc", "foo/+", "foo.bar.baz", "foo/bar/baz", false}, + {"single level wc", "foo/+", "foo.bar./", "foo/bar/", false}, + {"single level wc", "foo/+", "foo.bar", "foo/bar", true}, + {"single level wc", "foo/+", "foo./", "foo/", true}, + {"single level wc", "foo/+", "foo", "foo", false}, + {"single level wc", "foo/+", "/.foo", "/foo", false}, + + {"multiple level wc", "foo/#", "foo.bar.baz./", "foo/bar/baz/", true}, + {"multiple level wc", "foo/#", "foo.bar.baz", "foo/bar/baz", true}, + {"multiple level wc", "foo/#", "foo.bar./", "foo/bar/", true}, + {"multiple level wc", "foo/#", "foo.bar", "foo/bar", true}, + {"multiple level wc", "foo/#", "foo./", "foo/", true}, + {"multiple level wc", "foo/#", "foo", "foo", true}, + {"multiple level wc", "foo/#", "/.foo", "/foo", false}, + } { + t.Run(test.name, func(t *testing.T) { + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: test.mqttSubTopic, qos: 0}}, []byte{0}) + testMQTTFlush(t, mc, nil, r) + + natsPub(t, nc, test.natsPubSubject, []byte("msg")) + if test.ok { + testMQTTCheckPubMsg(t, mc, r, test.mqttPubTopic, 0, []byte("msg")) + } else { + testMQTTExpectNothing(t, r) + } + + testMQTTPublish(t, mcp, mpr, 0, false, false, test.mqttPubTopic, 0, []byte("msg")) + if test.ok { + testMQTTCheckPubMsg(t, mc, r, test.mqttPubTopic, 0, []byte("msg")) + } else { + testMQTTExpectNothing(t, r) + } + }) + } +} + +func TestMQTTSubQoS(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcp.Close() + testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false) + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + mqttTopic := "foo/bar" + + // Subscribe with QoS 1 + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/#", qos: 1}}, []byte{1}) + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: mqttTopic, qos: 1}}, []byte{1}) + testMQTTFlush(t, mc, nil, r) + + // Publish from NATS, which means QoS 0 + natsPub(t, nc, "foo.bar", []byte("NATS")) + // Will receive as QoS 0 + testMQTTCheckPubMsg(t, mc, r, mqttTopic, 0, []byte("NATS")) + testMQTTCheckPubMsg(t, mc, r, mqttTopic, 0, []byte("NATS")) + + // Publish from MQTT with QoS 0 + testMQTTPublish(t, mcp, mpr, 0, false, false, mqttTopic, 0, []byte("msg")) + // Will receive as QoS 0 + testMQTTCheckPubMsg(t, mc, r, mqttTopic, 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, mqttTopic, 0, []byte("msg")) + + // Publish from MQTT with QoS 1 + testMQTTPublish(t, mcp, mpr, 1, false, false, mqttTopic, 1, []byte("msg")) + pflags1, pi1 := testMQTTGetPubMsg(t, mc, r, mqttTopic, []byte("msg")) + if pflags1 != 0x2 { + t.Fatalf("Expected flags to be 0x2, got %v", pflags1) + } + pflags2, pi2 := testMQTTGetPubMsg(t, mc, r, mqttTopic, []byte("msg")) + if pflags2 != 0x2 { + t.Fatalf("Expected flags to be 0x2, got %v", pflags2) + } + if pi1 == pi2 { + t.Fatalf("packet identifier for message 1: %v should be different from message 2", pi1) + } + testMQTTSendPubAck(t, mc, pi1) + testMQTTSendPubAck(t, mc, pi2) +} + +func getSubQoS(sub *subscription) int { + if sub.mqtt != nil { + return int(sub.mqtt.qos) + } + return -1 +} + +func TestMQTTSubDups(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcp.Close() + testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false) + + mc, r := testMQTTConnect(t, &mqttConnInfo{clientID: "sub", user: "sub", cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + // Test with single SUBSCRIBE protocol but multiple filters + filters := []*mqttFilter{ + &mqttFilter{filter: "foo", qos: 1}, + &mqttFilter{filter: "foo", qos: 0}, + } + testMQTTSub(t, 1, mc, r, filters, []byte{1, 0}) + testMQTTFlush(t, mc, nil, r) + + // And also with separate SUBSCRIBE protocols + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "bar", qos: 0}}, []byte{0}) + // Ask for QoS 2 but server will downgrade to 1 + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "bar", qos: 2}}, []byte{1}) + testMQTTFlush(t, mc, nil, r) + + // Publish and test msg received only once + testMQTTPublish(t, mcp, r, 0, false, false, "foo", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg")) + testMQTTExpectNothing(t, r) + + testMQTTPublish(t, mcp, r, 0, false, false, "bar", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "bar", 0, []byte("msg")) + testMQTTExpectNothing(t, r) + + // Check that the QoS for subscriptions have been updated to the latest received filter + var err error + subc := testMQTTGetClient(t, s, "sub") + subc.mu.Lock() + if subc.opts.Username != "sub" { + err = fmt.Errorf("wrong user name") + } + if err == nil { + if sub := subc.subs["foo"]; sub == nil || getSubQoS(sub) != 0 { + err = fmt.Errorf("subscription foo QoS should be 0, got %v", getSubQoS(sub)) + } + } + if err == nil { + if sub := subc.subs["bar"]; sub == nil || getSubQoS(sub) != 1 { + err = fmt.Errorf("subscription bar QoS should be 1, got %v", getSubQoS(sub)) + } + } + subc.mu.Unlock() + if err != nil { + t.Fatal(err) + } + + // Now subscribe on "foo/#" which means that a PUBLISH on "foo" will be received + // by this subscription and also the one on "foo". + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/#", qos: 1}}, []byte{1}) + testMQTTFlush(t, mc, nil, r) + + // Publish and test msg received twice + testMQTTPublish(t, mcp, r, 0, false, false, "foo", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg")) + + checkWCSub := func(expectedQoS int) { + t.Helper() + + subc.mu.Lock() + defer subc.mu.Unlock() + + // When invoked with expectedQoS==1, we have the following subs: + // foo (QoS-0), bar (QoS-1), foo.> (QoS-1) + // which means (since QoS-1 have a JS consumer + sub for delivery + // and foo.> causes a "foo fwc") that we should have the following + // number of NATS subs: foo (1), bar (2), foo.> (2) and "foo fwc" (2), + // so total=7. + // When invoked with expectedQoS==0, it means that we have replaced + // foo/# QoS-1 to QoS-0, so we should have 2 less NATS subs, + // so total=5 + expected := 7 + if expectedQoS == 0 { + expected = 5 + } + if lenmap := len(subc.subs); lenmap != expected { + t.Fatalf("Subs map should have %v entries, got %v", expected, lenmap) + } + if sub, ok := subc.subs["foo.>"]; !ok { + t.Fatal("Expected sub foo.> to be present but was not") + } else if getSubQoS(sub) != expectedQoS { + t.Fatalf("Expected sub foo.> QoS to be %v, got %v", expectedQoS, getSubQoS(sub)) + } + if sub, ok := subc.subs["foo fwc"]; !ok { + t.Fatal("Expected sub foo fwc to be present but was not") + } else if getSubQoS(sub) != expectedQoS { + t.Fatalf("Expected sub foo fwc QoS to be %v, got %v", expectedQoS, getSubQoS(sub)) + } + // Make sure existing sub on "foo" qos was not changed. + if sub, ok := subc.subs["foo"]; !ok { + t.Fatal("Expected sub foo to be present but was not") + } else if getSubQoS(sub) != 0 { + t.Fatalf("Expected sub foo QoS to be 0, got %v", getSubQoS(sub)) + } + } + checkWCSub(1) + + // Sub again on same subject with lower QoS + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/#", qos: 0}}, []byte{0}) + testMQTTFlush(t, mc, nil, r) + + // Publish and test msg received twice + testMQTTPublish(t, mcp, r, 0, false, false, "foo", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg")) + checkWCSub(0) +} + +func TestMQTTSubWithSpaces(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcp.Close() + testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false) + + mc, r := testMQTTConnect(t, &mqttConnInfo{user: "sub", cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo bar", qos: 0}}, []byte{mqttSubAckFailure}) +} + +func TestMQTTSubCaseSensitive(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcp.Close() + testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false) + + mc, r := testMQTTConnect(t, &mqttConnInfo{user: "sub", cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "Foo/Bar", qos: 0}}, []byte{0}) + testMQTTFlush(t, mc, nil, r) + + testMQTTPublish(t, mcp, r, 0, false, false, "Foo/Bar", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "Foo/Bar", 0, []byte("msg")) + + testMQTTPublish(t, mcp, r, 0, false, false, "foo/bar", 0, []byte("msg")) + testMQTTExpectNothing(t, r) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + natsPub(t, nc, "Foo.Bar", []byte("nats")) + testMQTTCheckPubMsg(t, mc, r, "Foo/Bar", 0, []byte("nats")) + + natsPub(t, nc, "foo.bar", []byte("nats")) + testMQTTExpectNothing(t, r) +} + +func TestMQTTPubSubMatrix(t *testing.T) { + for _, test := range []struct { + name string + natsPub bool + mqttPub bool + mqttPubQoS byte + natsSub bool + mqttSubQoS0 bool + mqttSubQoS1 bool + }{ + {"NATS to MQTT sub QoS-0", true, false, 0, false, true, false}, + {"NATS to MQTT sub QoS-1", true, false, 0, false, false, true}, + {"NATS to MQTT sub QoS-0 and QoS-1", true, false, 0, false, true, true}, + + {"MQTT QoS-0 to NATS sub", false, true, 0, true, false, false}, + {"MQTT QoS-0 to MQTT sub QoS-0", false, true, 0, false, true, false}, + {"MQTT QoS-0 to MQTT sub QoS-1", false, true, 0, false, false, true}, + {"MQTT QoS-0 to NATS sub and MQTT sub QoS-0", false, true, 0, true, true, false}, + {"MQTT QoS-0 to NATS sub and MQTT sub QoS-1", false, true, 0, true, false, true}, + {"MQTT QoS-0 to all subs", false, true, 0, true, true, true}, + + {"MQTT QoS-1 to NATS sub", false, true, 1, true, false, false}, + {"MQTT QoS-1 to MQTT sub QoS-0", false, true, 1, false, true, false}, + {"MQTT QoS-1 to MQTT sub QoS-1", false, true, 1, false, false, true}, + {"MQTT QoS-1 to NATS sub and MQTT sub QoS-0", false, true, 1, true, true, false}, + {"MQTT QoS-1 to NATS sub and MQTT sub QoS-1", false, true, 1, true, false, true}, + {"MQTT QoS-1 to all subs", false, true, 1, true, true, true}, + } { + t.Run(test.name, func(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + mc1, r1 := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc1.Close() + testMQTTCheckConnAck(t, r1, mqttConnAckRCConnectionAccepted, false) + + mc2, r2 := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc2.Close() + testMQTTCheckConnAck(t, r2, mqttConnAckRCConnectionAccepted, false) + + // First setup subscriptions based on test options. + var ns *nats.Subscription + if test.natsSub { + ns = natsSubSync(t, nc, "foo") + } + if test.mqttSubQoS0 { + testMQTTSub(t, 1, mc1, r1, []*mqttFilter{&mqttFilter{filter: "foo", qos: 0}}, []byte{0}) + testMQTTFlush(t, mc1, nil, r1) + } + if test.mqttSubQoS1 { + testMQTTSub(t, 1, mc2, r2, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTFlush(t, mc2, nil, r2) + } + + // Just as a barrier + natsFlush(t, nc) + + // Now publish + if test.natsPub { + natsPubReq(t, nc, "foo", "", []byte("msg")) + } else { + testMQTTPublish(t, mc, r, test.mqttPubQoS, false, false, "foo", 1, []byte("msg")) + } + + // Check message received + if test.natsSub { + natsNexMsg(t, ns, time.Second) + // Make sure no other is received + if msg, err := ns.NextMsg(50 * time.Millisecond); err == nil { + t.Fatalf("Should not have gotten a second message, got %v", msg) + } + } + if test.mqttSubQoS0 { + testMQTTCheckPubMsg(t, mc1, r1, "foo", 0, []byte("msg")) + testMQTTExpectNothing(t, r1) + } + if test.mqttSubQoS1 { + var expectedFlag byte + if test.mqttPubQoS > 0 { + expectedFlag = test.mqttPubQoS << 1 + } + testMQTTCheckPubMsg(t, mc2, r2, "foo", expectedFlag, []byte("msg")) + testMQTTExpectNothing(t, r2) + } + }) + } +} + +func TestMQTTPreventSubWithMQTTSubPrefix(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, mc, r, + []*mqttFilter{&mqttFilter{filter: strings.ReplaceAll(mqttSubPrefix, ".", "/") + "foo/bar", qos: 1}}, + []byte{mqttSubAckFailure}) +} + +func TestMQTTSubWithNATSStream(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/bar", qos: 1}}, []byte{1}) + testMQTTFlush(t, mc, nil, r) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + sc := &StreamConfig{ + Name: "test", + Storage: FileStorage, + Retention: InterestPolicy, + Subjects: []string{"foo.>"}, + } + mset, err := s.GlobalAccount().AddStream(sc) + if err != nil { + t.Fatalf("Unable to create stream: %v", err) + } + + sub := natsSubSync(t, nc, "bar") + cc := &ConsumerConfig{ + Durable: "dur", + AckPolicy: AckExplicit, + DeliverSubject: "bar", + } + if _, err := mset.AddConsumer(cc); err != nil { + t.Fatalf("Unable to add consumer: %v", err) + } + + // Now send message from NATS + resp, err := nc.Request("foo.bar", []byte("nats"), time.Second) + if err != nil { + t.Fatalf("Error publishing: %v", err) + } + ar := &ApiResponse{} + if err := json.Unmarshal(resp.Data, ar); err != nil || ar.Error != nil { + t.Fatalf("Unexpected response: err=%v resp=%+v", err, ar.Error) + } + + // Check that message is received by both + checkRecv := func(content string, flags byte) { + t.Helper() + if msg := natsNexMsg(t, sub, time.Second); string(msg.Data) != content { + t.Fatalf("Expected %q, got %q", content, msg.Data) + } + testMQTTCheckPubMsg(t, mc, r, "foo/bar", flags, []byte(content)) + } + checkRecv("nats", 0) + + // Send from MQTT as a QoS0 + testMQTTPublish(t, mc, r, 0, false, false, "foo/bar", 0, []byte("qos0")) + checkRecv("qos0", 0) + + // Send from MQTT as a QoS1 + testMQTTPublish(t, mc, r, 1, false, false, "foo/bar", 1, []byte("qos1")) + checkRecv("qos1", mqttPubQos1) +} + +func TestMQTTTrackPendingOverrun(t *testing.T) { + sess := &mqttSession{pending: make(map[uint16]*mqttPending)} + sub := &subscription{mqtt: &mqttSub{qos: 1}} + + sess.ppi = 0xFFFF + pi, _ := sess.trackPending(1, _EMPTY_, sub) + if pi != 1 { + t.Fatalf("Expected 1, got %v", pi) + } + + p := &mqttPending{} + for i := 1; i <= 0xFFFF; i++ { + sess.pending[uint16(i)] = p + } + pi, _ = sess.trackPending(1, _EMPTY_, sub) + if pi != 0 { + t.Fatalf("Expected 0, got %v", pi) + } + + delete(sess.pending, 1234) + pi, _ = sess.trackPending(1, _EMPTY_, sub) + if pi != 1234 { + t.Fatalf("Expected 1234, got %v", pi) + } +} + +func TestMQTTPreventStreamAndConsumerWithMQTTPrefix(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + sc := &StreamConfig{ + Name: mqttStreamNamePrefix + "test", + Storage: FileStorage, + Retention: InterestPolicy, + Subjects: []string{"foo.>"}, + } + if _, err := s.GlobalAccount().AddStream(sc); err == nil { + t.Fatal("Expected error") + } + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/bar", qos: 1}}, []byte{1}) + testMQTTFlush(t, mc, nil, r) + + mset, err := s.GlobalAccount().LookupStream(mqttStreamName) + if err != nil { + t.Fatalf("Error looking up MQTT Stream: %v", err) + } + cc := &ConsumerConfig{ + Durable: "dur", + AckPolicy: AckExplicit, + DeliverSubject: "bar", + } + if _, err := mset.AddConsumer(cc); err == nil { + t.Fatal("Expected error") + } +} + +func TestMQTTSubRestart(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + mc, r := testMQTTConnect(t, &mqttConnInfo{clientID: "sub", cleanSess: false}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + // Start an MQTT subscription QoS=1 on "foo" + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTFlush(t, mc, nil, r) + + // Now start a NATS subscription on ">" (anything that would match the JS consumer delivery subject) + natsSubSync(t, nc, ">") + natsFlush(t, nc) + + // Restart the MQTT client + testMQTTDisconnect(t, mc, nil) + + mc, r = testMQTTConnect(t, &mqttConnInfo{clientID: "sub", cleanSess: false}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + + // Restart an MQTT subscription QoS=1 on "foo" + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTFlush(t, mc, nil, r) +} + +func TestMQTTSubPropagation(t *testing.T) { + t.Skip("Skipping until JS clustering is supported") + o := testMQTTDefaultOptions() + o.Cluster.Host = "127.0.0.1" + o.Cluster.Port = -1 + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + o2 := DefaultOptions() + o2.Routes = RoutesFromStr(fmt.Sprintf("nats://127.0.0.1:%d", o.Cluster.Port)) + s2 := RunServer(o2) + defer s2.Shutdown() + + checkClusterFormed(t, s, s2) + + nc := natsConnect(t, s2.ClientURL()) + defer nc.Close() + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/#", qos: 0}}, []byte{0}) + testMQTTFlush(t, mc, nil, r) + + // Because in MQTT foo/# means foo.> but also foo, check that this is propagated + checkSubInterest(t, s2, globalAccountName, "foo", time.Second) + + // Publish on foo.bar, foo./ and foo and we should receive them + natsPub(t, nc, "foo.bar", []byte("hello")) + testMQTTCheckPubMsg(t, mc, r, "foo/bar", 0, []byte("hello")) + + natsPub(t, nc, "foo./", []byte("from")) + testMQTTCheckPubMsg(t, mc, r, "foo/", 0, []byte("from")) + + natsPub(t, nc, "foo", []byte("NATS")) + testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("NATS")) +} + +func TestMQTTParseUnsub(t *testing.T) { + eofr := testNewEOFReader() + for _, test := range []struct { + name string + proto []byte + b byte + pl int + reader mqttIOReader + err string + }{ + {"reserved flag", nil, 3, 0, nil, "wrong unsubscribe reserved flags"}, + {"ensure packet loaded", []byte{1, 2}, mqttUnsubscribeFlags, 10, eofr, "error ensuring protocol is loaded"}, + {"error reading packet id", []byte{1}, mqttUnsubscribeFlags, 1, eofr, "reading packet identifier"}, + {"missing filters", []byte{0, 1}, mqttUnsubscribeFlags, 2, nil, "subscribe protocol must contain at least 1 topic filter"}, + {"error reading topic", []byte{0, 1, 0, 2, 'a'}, mqttUnsubscribeFlags, 5, eofr, "topic filter"}, + {"empty topic", []byte{0, 1, 0, 0}, mqttUnsubscribeFlags, 4, nil, "topic filter cannot be empty"}, + {"invalid utf8 topic", []byte{0, 1, 0, 1, 241}, mqttUnsubscribeFlags, 5, nil, "invalid utf8 for topic filter"}, + } { + t.Run(test.name, func(t *testing.T) { + r := &mqttReader{reader: test.reader} + r.reset(test.proto) + mqtt := &mqtt{r: r} + c := &client{mqtt: mqtt} + if _, _, err := c.mqttParseSubsOrUnsubs(r, test.b, test.pl, false); err == nil || !strings.Contains(err.Error(), test.err) { + t.Fatalf("Expected error %q, got %v", test.err, err) + } + }) + } +} + +func testMQTTUnsub(t *testing.T, pi uint16, c net.Conn, r *mqttReader, filters []*mqttFilter) { + t.Helper() + w := &mqttWriter{} + pkLen := 2 // for pi + for i := 0; i < len(filters); i++ { + f := filters[i] + pkLen += 2 + len(f.filter) + } + w.WriteByte(mqttPacketUnsub | mqttUnsubscribeFlags) + w.WriteVarInt(pkLen) + w.WriteUint16(pi) + for i := 0; i < len(filters); i++ { + f := filters[i] + w.WriteBytes([]byte(f.filter)) + } + if _, err := testMQTTWrite(c, w.Bytes()); err != nil { + t.Fatalf("Error writing UNSUBSCRIBE protocol: %v", err) + } + // Make sure we have at least 1 byte in buffer (if not will read) + testMQTTReaderHasAtLeastOne(t, r) + // Parse UNSUBACK + b, err := r.readByte("packet type") + if err != nil { + t.Fatal(err) + } + if pt := b & mqttPacketMask; pt != mqttPacketUnsubAck { + t.Fatalf("Expected UNSUBACK packet %x, got %x", mqttPacketUnsubAck, pt) + } + pl, err := r.readPacketLen() + if err != nil { + t.Fatal(err) + } + if err := r.ensurePacketInBuffer(pl); err != nil { + t.Fatal(err) + } + rpi, err := r.readUint16("packet identifier") + if err != nil || rpi != pi { + t.Fatalf("Error with packet identifier expected=%v got: %v err=%v", pi, rpi, err) + } +} + +func TestMQTTUnsub(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcp.Close() + testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false) + + mc, r := testMQTTConnect(t, &mqttConnInfo{user: "sub", cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 0}}, []byte{0}) + testMQTTFlush(t, mc, nil, r) + + // Publish and test msg received + testMQTTPublish(t, mcp, r, 0, false, false, "foo", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg")) + + // Unsubscribe + testMQTTUnsub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo"}}) + + // Publish and test msg not received + testMQTTPublish(t, mcp, r, 0, false, false, "foo", 0, []byte("msg")) + testMQTTExpectNothing(t, r) + + // Use of wildcards subs + filters := []*mqttFilter{ + &mqttFilter{filter: "foo/bar", qos: 0}, + &mqttFilter{filter: "foo/#", qos: 0}, + } + testMQTTSub(t, 1, mc, r, filters, []byte{0, 0}) + testMQTTFlush(t, mc, nil, r) + + // Publish and check that message received twice + testMQTTPublish(t, mcp, r, 0, false, false, "foo/bar", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo/bar", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo/bar", 0, []byte("msg")) + + // Unsub the wildcard one + testMQTTUnsub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/#"}}) + // Publish and check that message received once + testMQTTPublish(t, mcp, r, 0, false, false, "foo/bar", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo/bar", 0, []byte("msg")) + testMQTTExpectNothing(t, r) + + // Unsub last + testMQTTUnsub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/bar"}}) + // Publish and test msg not received + testMQTTPublish(t, mcp, r, 0, false, false, "foo/bar", 0, []byte("msg")) + testMQTTExpectNothing(t, r) +} + +func testMQTTExpectDisconnect(t testing.TB, c net.Conn) { + if buf, err := testMQTTRead(c); err == nil { + t.Fatalf("Expected connection to be disconnected, got %s", buf) + } +} + +func TestMQTTPublishTopicErrors(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + for _, test := range []struct { + name string + topic string + }{ + {"empty", ""}, + {"with single level wildcard", "foo/+"}, + {"with multiple level wildcard", "foo/#"}, + } { + t.Run(test.name, func(t *testing.T) { + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, mc, r, 0, false, false, test.topic, 0, []byte("msg")) + testMQTTExpectDisconnect(t, mc) + }) + } +} + +func testMQTTDisconnect(t testing.TB, c net.Conn, bw *bufio.Writer) { + t.Helper() + w := &mqttWriter{} + w.WriteByte(mqttPacketDisconnect) + w.WriteByte(0) + if bw != nil { + bw.Write(w.Bytes()) + bw.Flush() + } else { + c.Write(w.Bytes()) + } + testMQTTExpectDisconnect(t, c) +} + +func TestMQTTWill(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + sub := natsSubSync(t, nc, "will.topic") + + willMsg := []byte("bye") + + for _, test := range []struct { + name string + willExpected bool + willQoS byte + }{ + {"will qos 0", true, 0}, + {"will qos 1", true, 1}, + {"proper disconnect no will", false, 0}, + } { + t.Run(test.name, func(t *testing.T) { + mcs, rs := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcs.Close() + testMQTTCheckConnAck(t, rs, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mcs, rs, []*mqttFilter{&mqttFilter{filter: "will/#", qos: 1}}, []byte{1}) + testMQTTFlush(t, mcs, nil, rs) + + mc, r := testMQTTConnect(t, + &mqttConnInfo{ + cleanSess: true, + will: &mqttWill{ + topic: []byte("will/topic"), + message: willMsg, + qos: test.willQoS, + }, + }, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + if test.willExpected { + mc.Close() + testMQTTCheckPubMsg(t, mcs, rs, "will/topic", test.willQoS<<1, willMsg) + wm := natsNexMsg(t, sub, time.Second) + if !bytes.Equal(wm.Data, willMsg) { + t.Fatalf("Expected will message to be %q, got %q", willMsg, wm.Data) + } + } else { + testMQTTDisconnect(t, mc, nil) + testMQTTExpectNothing(t, rs) + if wm, err := sub.NextMsg(100 * time.Millisecond); err == nil { + t.Fatalf("Should not have receive a message, got %v", wm) + } + } + }) + } +} + +func TestMQTTWillRetain(t *testing.T) { + for _, test := range []struct { + name string + pubQoS byte + subQoS byte + }{ + {"pub QoS0 sub QoS0", 0, 0}, + {"pub QoS0 sub QoS1", 0, 1}, + {"pub QoS1 sub QoS0", 1, 0}, + {"pub QoS1 sub QoS1", 1, 1}, + } { + t.Run(test.name, func(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + willTopic := []byte("will/topic") + willMsg := []byte("bye") + + mc, r := testMQTTConnect(t, + &mqttConnInfo{ + cleanSess: true, + will: &mqttWill{ + topic: willTopic, + message: willMsg, + qos: test.pubQoS, + retain: true, + }, + }, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + // Disconnect, which will cause will to be produced with retain flag. + mc.Close() + + // Create subscription on will topic and expect will message. + mcs, rs := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcs.Close() + testMQTTCheckConnAck(t, rs, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mcs, rs, []*mqttFilter{&mqttFilter{filter: "will/#", qos: test.subQoS}}, []byte{test.subQoS}) + pflags, _ := testMQTTGetPubMsg(t, mcs, rs, "will/topic", willMsg) + if pflags&mqttPubFlagRetain == 0 { + t.Fatalf("expected retain flag to be set, it was not: %v", pflags) + } + // Expected QoS will be the lesser of the pub/sub QoS. + expectedQoS := test.pubQoS + if test.subQoS == 0 { + expectedQoS = 0 + } + if qos := mqttGetQoS(pflags); qos != expectedQoS { + t.Fatalf("expected qos to be %v, got %v", expectedQoS, qos) + } + }) + } +} + +func TestMQTTWillRetainPermViolation(t *testing.T) { + template := ` + port: -1 + jetstream: enabled + authorization { + mqtt_perms = { + publish = ["%s"] + subscribe = ["foo", "bar", "$MQTT.sub.>"] + } + users = [ + {user: mqtt, password: pass, permissions: $mqtt_perms} + ] + } + mqtt { + port: -1 + } + ` + conf := createConfFile(t, []byte(fmt.Sprintf(template, "foo"))) + defer os.Remove(conf) + + s, o := RunServerWithConfig(conf) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{ + cleanSess: true, + user: "mqtt", + pass: "pass", + } + + // We create first a connection with the Will topic that the publisher + // is allowed to publish to. + ci.will = &mqttWill{ + topic: []byte("foo"), + message: []byte("bye"), + qos: 1, + retain: true, + } + mc, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + // Disconnect, which will cause the Will to be sent with retain flag. + mc.Close() + + // Create a subscription on the Will subject and we should receive it. + ci.will = nil + mcs, rs := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mcs.Close() + testMQTTCheckConnAck(t, rs, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mcs, rs, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + pflags, _ := testMQTTGetPubMsg(t, mcs, rs, "foo", []byte("bye")) + if pflags&mqttPubFlagRetain == 0 { + t.Fatalf("expected retain flag to be set, it was not: %v", pflags) + } + if qos := mqttGetQoS(pflags); qos != 1 { + t.Fatalf("expected qos to be 1, got %v", qos) + } + testMQTTDisconnect(t, mcs, nil) + + // Now create another connection with a Will that client is not allowed to publish to. + ci.will = &mqttWill{ + topic: []byte("bar"), + message: []byte("bye"), + qos: 1, + retain: true, + } + mc, r = testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + // Disconnect, to cause Will to be produced, but in that case should not be stored + // since user not allowed to publish on "bar". + mc.Close() + + // Create sub on "bar" which user is allowed to subscribe to. + ci.will = nil + mcs, rs = testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mcs.Close() + testMQTTCheckConnAck(t, rs, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mcs, rs, []*mqttFilter{&mqttFilter{filter: "bar", qos: 1}}, []byte{1}) + // No Will should be published since it should not have been stored in the first place. + testMQTTExpectNothing(t, rs) + testMQTTDisconnect(t, mcs, nil) + + // Now remove permission to publish on "foo" and check that a new subscription + // on "foo" is now not getting the will message because the original user no + // longer has permission to do so. + reloadUpdateConfig(t, s, conf, fmt.Sprintf(template, "baz")) + + mcs, rs = testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mcs.Close() + testMQTTCheckConnAck(t, rs, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mcs, rs, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTExpectNothing(t, rs) + testMQTTDisconnect(t, mcs, nil) +} + +func TestMQTTPublishRetain(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + for _, test := range []struct { + name string + retained bool + sentValue string + expectedValue string + subGetsIt bool + }{ + {"publish retained", true, "retained", "retained", true}, + {"publish not retained", false, "not retained", "retained", true}, + {"remove retained", true, "", "", false}, + } { + t.Run(test.name, func(t *testing.T) { + mc1, rs1 := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc1.Close() + testMQTTCheckConnAck(t, rs1, mqttConnAckRCConnectionAccepted, false) + testMQTTPublish(t, mc1, rs1, 0, false, test.retained, "foo", 0, []byte(test.sentValue)) + + mc2, rs2 := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc2.Close() + testMQTTCheckConnAck(t, rs2, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mc2, rs2, []*mqttFilter{&mqttFilter{filter: "foo/#", qos: 1}}, []byte{1}) + + if test.subGetsIt { + pflags, _ := testMQTTGetPubMsg(t, mc2, rs2, "foo", []byte(test.expectedValue)) + if pflags&mqttPubFlagRetain == 0 { + t.Fatalf("retain flag should have been set, it was not: flags=%v", pflags) + } + } else { + testMQTTExpectNothing(t, rs2) + } + + testMQTTDisconnect(t, mc1, nil) + testMQTTDisconnect(t, mc2, nil) + }) + } +} + +func TestMQTTPublishRetainPermViolation(t *testing.T) { + o := testMQTTDefaultOptions() + o.Users = []*User{ + { + Username: "mqtt", + Password: "pass", + Permissions: &Permissions{ + Publish: &SubjectPermission{Allow: []string{"foo"}}, + Subscribe: &SubjectPermission{Allow: []string{"bar", "$MQTT.sub.>"}}, + }, + }, + } + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{ + cleanSess: true, + user: "mqtt", + pass: "pass", + } + + mc1, rs1 := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc1.Close() + testMQTTCheckConnAck(t, rs1, mqttConnAckRCConnectionAccepted, false) + testMQTTPublish(t, mc1, rs1, 0, false, true, "bar", 0, []byte("retained")) + + mc2, rs2 := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc2.Close() + testMQTTCheckConnAck(t, rs2, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mc2, rs2, []*mqttFilter{&mqttFilter{filter: "bar", qos: 1}}, []byte{1}) + testMQTTExpectNothing(t, rs2) + + testMQTTDisconnect(t, mc1, nil) + testMQTTDisconnect(t, mc2, nil) +} + +func TestMQTTCleanSession(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{ + clientID: "me", + cleanSess: false, + } + c, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTDisconnect(t, c, nil) + + c, r = testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + testMQTTDisconnect(t, c, nil) + + ci.cleanSess = true + c, r = testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTDisconnect(t, c, nil) +} + +func TestMQTTDuplicateClientID(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{ + clientID: "me", + cleanSess: false, + } + c1, r1 := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer c1.Close() + testMQTTCheckConnAck(t, r1, mqttConnAckRCConnectionAccepted, false) + + c2, r2 := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer c2.Close() + testMQTTCheckConnAck(t, r2, mqttConnAckRCConnectionAccepted, true) + + // The old client should be disconnected. + testMQTTExpectDisconnect(t, c1) +} + +func TestMQTTPersistedSession(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer func() { + testMQTTShutdownServer(s) + }() + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, c, r, + []*mqttFilter{ + &mqttFilter{filter: "foo/#", qos: 1}, + &mqttFilter{filter: "bar", qos: 1}, + &mqttFilter{filter: "baz", qos: 0}, + }, + []byte{1, 1, 0}) + testMQTTFlush(t, c, nil, r) + + // Shutdown server, close connection and restart server. It should + // have restored the session and consumers. + dir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir) + s.Shutdown() + c.Close() + + o.Port = -1 + o.MQTT.Port = -1 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // There is already the defer for shutdown at top of function + + // Create a publisher that will send qos1 so we verify that messages + // are stored for the persisted sessions. + c, r = testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, c, r, 1, false, false, "foo/bar", 1, []byte("msg0")) + testMQTTFlush(t, c, nil, r) + testMQTTDisconnect(t, c, nil) + c.Close() + + // Recreate consumer session + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + + // Since consumers have been recovered, messages should be received + // (MQTT does not need client to recreate consumers for a recovered + // session) + + // Check that qos1 publish message is received. + testMQTTCheckPubMsg(t, c, r, "foo/bar", mqttPubQos1, []byte("msg0")) + + // Now publish some messages to all subscriptions. + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + natsPub(t, nc, "foo.bar", []byte("msg1")) + testMQTTCheckPubMsg(t, c, r, "foo/bar", 0, []byte("msg1")) + + natsPub(t, nc, "foo", []byte("msg2")) + testMQTTCheckPubMsg(t, c, r, "foo", 0, []byte("msg2")) + + natsPub(t, nc, "bar", []byte("msg3")) + testMQTTCheckPubMsg(t, c, r, "bar", 0, []byte("msg3")) + + natsPub(t, nc, "baz", []byte("msg4")) + testMQTTCheckPubMsg(t, c, r, "baz", 0, []byte("msg4")) + + // Now unsub "bar" and verify that message published on this topic + // is not received. + testMQTTUnsub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "bar"}}) + natsPub(t, nc, "bar", []byte("msg5")) + testMQTTExpectNothing(t, r) + + nc.Close() + s.Shutdown() + c.Close() + + o.Port = -1 + o.MQTT.Port = -1 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // There is already the defer for shutdown at top of function + + // Recreate a client + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + + nc = natsConnect(t, s.ClientURL()) + defer nc.Close() + + natsPub(t, nc, "foo.bar", []byte("msg6")) + testMQTTCheckPubMsg(t, c, r, "foo/bar", 0, []byte("msg6")) + + natsPub(t, nc, "foo", []byte("msg7")) + testMQTTCheckPubMsg(t, c, r, "foo", 0, []byte("msg7")) + + // Make sure that we did not recover bar. + natsPub(t, nc, "bar", []byte("msg8")) + testMQTTExpectNothing(t, r) + + natsPub(t, nc, "baz", []byte("msg9")) + testMQTTCheckPubMsg(t, c, r, "baz", 0, []byte("msg9")) + + // Have the sub client send a subscription downgrading the qos1 subscription. + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo/#", qos: 0}}, []byte{0}) + testMQTTFlush(t, c, nil, r) + + nc.Close() + s.Shutdown() + c.Close() + + o.Port = -1 + o.MQTT.Port = -1 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // There is already the defer for shutdown at top of function + + // Recreate the sub client + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + + // Publish as a qos1 + c2, r2 := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer c2.Close() + testMQTTCheckConnAck(t, r2, mqttConnAckRCConnectionAccepted, false) + testMQTTPublish(t, c2, r2, 1, false, false, "foo/bar", 1, []byte("msg10")) + + // Verify that it is received as qos0 which is the qos of the subscription. + testMQTTCheckPubMsg(t, c, r, "foo/bar", 0, []byte("msg10")) + + testMQTTDisconnect(t, c, nil) + c.Close() + testMQTTDisconnect(t, c2, nil) + c2.Close() + + // Finally, recreate the sub with clean session and ensure that all is gone + cisub.cleanSess = true + for i := 0; i < 2; i++ { + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + nc = natsConnect(t, s.ClientURL()) + defer nc.Close() + + natsPub(t, nc, "foo.bar", []byte("msg11")) + testMQTTExpectNothing(t, r) + + natsPub(t, nc, "foo", []byte("msg12")) + testMQTTExpectNothing(t, r) + + // Make sure that we did not recover bar. + natsPub(t, nc, "bar", []byte("msg13")) + testMQTTExpectNothing(t, r) + + natsPub(t, nc, "baz", []byte("msg14")) + testMQTTExpectNothing(t, r) + + testMQTTDisconnect(t, c, nil) + c.Close() + nc.Close() + + s.Shutdown() + o.Port = -1 + o.MQTT.Port = -1 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // There is already the defer for shutdown at top of function + } +} + +func TestMQTTRecoverSessionAndAddNewSub(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer func() { + testMQTTShutdownServer(s) + }() + + cisub := &mqttConnInfo{clientID: "sub1", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTDisconnect(t, c, nil) + c.Close() + + // Shutdown server, close connection and restart server. It should + // have restored the session and consumers. + dir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir) + s.Shutdown() + c.Close() + + o.Port = -1 + o.MQTT.Port = -1 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // No need for defer since it is done top of function + + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + // Now add sub and make sure it does not crash + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTFlush(t, c, nil, r) + + // Now repeat with a new client but without server restart. + cisub2 := &mqttConnInfo{clientID: "sub2", cleanSess: false} + c2, r2 := testMQTTConnect(t, cisub2, o.MQTT.Host, o.MQTT.Port) + defer c2.Close() + testMQTTCheckConnAck(t, r2, mqttConnAckRCConnectionAccepted, false) + testMQTTDisconnect(t, c2, nil) + c2.Close() + + c2, r2 = testMQTTConnect(t, cisub2, o.MQTT.Host, o.MQTT.Port) + defer c2.Close() + testMQTTCheckConnAck(t, r2, mqttConnAckRCConnectionAccepted, true) + testMQTTSub(t, 1, c2, r2, []*mqttFilter{&mqttFilter{filter: "bar", qos: 1}}, []byte{1}) + testMQTTFlush(t, c2, nil, r2) +} + +func TestMQTTRecoverSessionWithSubAndClientResendSub(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer func() { + testMQTTShutdownServer(s) + }() + + cisub1 := &mqttConnInfo{clientID: "sub1", cleanSess: false} + c, r := testMQTTConnect(t, cisub1, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + // Have a client send a SUBSCRIBE protocol for foo, QoS1 + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTDisconnect(t, c, nil) + c.Close() + + // Restart the server now. + dir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir) + s.Shutdown() + + o.Port = -1 + o.MQTT.Port = -1 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // No need for defer since it is done top of function + + // Now restart the client. Since the client was created with cleanSess==false, + // the server will have recorded the subscriptions for this client. + c, r = testMQTTConnect(t, cisub1, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + // At this point, the server has recreated the subscription on foo, QoS1. + + // For applications that restart, it is possible (likely) that they + // will resend their SUBSCRIBE protocols, so do so now: + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTFlush(t, c, nil, r) + + checkNumSub := func(clientID string) { + t.Helper() + + // Find the MQTT client... + mc := testMQTTGetClient(t, s, clientID) + + // Check how many NATS subscriptions are registered. + var fooSub int + var otherSub int + mc.mu.Lock() + for _, sub := range mc.subs { + switch string(sub.subject) { + case "foo": + fooSub++ + default: + otherSub++ + } + } + mc.mu.Unlock() + + // We should have 2 subscriptions, one on "foo", and one for the JS durable + // consumer's delivery subject. + if fooSub != 1 { + t.Fatalf("Expected 1 sub on 'foo', got %v", fooSub) + } + if otherSub != 1 { + t.Fatalf("Expected 1 subscription for JS durable, got %v", otherSub) + } + } + checkNumSub("sub1") + + c.Close() + + // Now same but without the server restart in-between. + cisub2 := &mqttConnInfo{clientID: "sub2", cleanSess: false} + c, r = testMQTTConnect(t, cisub2, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTDisconnect(t, c, nil) + c.Close() + // Restart client + c, r = testMQTTConnect(t, cisub2, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTFlush(t, c, nil, r) + // Check client subs + checkNumSub("sub2") +} + +func TestMQTTPersistRetainedMsg(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer func() { + testMQTTShutdownServer(s) + }() + + dir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir) + + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + c, r := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, c, r, 1, false, true, "foo", 1, []byte("foo1")) + testMQTTPublish(t, c, r, 1, false, true, "foo", 1, []byte("foo2")) + testMQTTPublish(t, c, r, 1, false, true, "bar", 1, []byte("bar1")) + testMQTTPublish(t, c, r, 0, false, true, "baz", 1, []byte("baz1")) + // Remove bar + testMQTTPublish(t, c, r, 1, false, true, "bar", 1, nil) + testMQTTDisconnect(t, c, nil) + c.Close() + + s.Shutdown() + + o.Port = -1 + o.MQTT.Port = -1 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // There is already the defer for shutdown at top of function + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTCheckPubMsg(t, c, r, "foo", mqttPubFlagRetain|mqttPubQos1, []byte("foo2")) + + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "baz", qos: 1}}, []byte{1}) + testMQTTCheckPubMsg(t, c, r, "baz", mqttPubFlagRetain, []byte("baz1")) + + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "bar", qos: 1}}, []byte{1}) + testMQTTExpectNothing(t, r) + + testMQTTDisconnect(t, c, nil) + c.Close() +} + +func TestMQTTConnAckFirstProto(t *testing.T) { + o := testMQTTDefaultOptions() + o.NoLog, o.Debug, o.Trace = true, false, false + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 0}}, []byte{0}) + testMQTTDisconnect(t, c, nil) + c.Close() + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + wg := sync.WaitGroup{} + wg.Add(1) + ch := make(chan struct{}, 1) + ready := make(chan struct{}) + go func() { + defer wg.Done() + + close(ready) + for { + nc.Publish("foo", []byte("msg")) + select { + case <-ch: + return + default: + } + } + }() + + <-ready + for i := 0; i < 100; i++ { + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + c.Close() + } + close(ch) + wg.Wait() +} + +func TestMQTTRedeliveryAckWait(t *testing.T) { + o := testMQTTDefaultOptions() + o.MQTT.AckWait = 250 * time.Millisecond + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("foo1")) + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 2, []byte("foo2")) + testMQTTDisconnect(t, cp, nil) + cp.Close() + + for i := 0; i < 2; i++ { + flags := mqttPubQos1 + if i > 0 { + flags |= mqttPubFlagDup + } + pi1 := testMQTTCheckPubMsgNoAck(t, c, r, "foo", flags, []byte("foo1")) + pi2 := testMQTTCheckPubMsgNoAck(t, c, r, "foo", flags, []byte("foo2")) + + if pi1 != 1 || pi2 != 2 { + t.Fatalf("Unexpected pi values: %v, %v", pi1, pi2) + } + } + // Ack first message + testMQTTSendPubAck(t, c, 1) + // Redelivery should only be for second message now + for i := 0; i < 2; i++ { + flags := mqttPubQos1 | mqttPubFlagDup + pi := testMQTTCheckPubMsgNoAck(t, c, r, "foo", flags, []byte("foo2")) + if pi != 2 { + t.Fatalf("Unexpected pi to be 2, got %v", pi) + } + } + + // Restart client, should receive second message with pi==2 + c.Close() + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + // Check that message is received with proper pi + pi := testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1|mqttPubFlagDup, []byte("foo2")) + if pi != 2 { + t.Fatalf("Unexpected pi to be 2, got %v", pi) + } + // Now ack second message + testMQTTSendPubAck(t, c, 2) + // Flush to make sure it is processed before checking client's maps + testMQTTFlush(t, c, nil, r) + + // Look for the sub client + mc := testMQTTGetClient(t, s, "sub") + mc.mu.Lock() + sess := mc.mqtt.sess + sess.mu.Lock() + lpi := len(sess.pending) + var lsseq int + for _, sseqToPi := range sess.cpending { + lsseq += len(sseqToPi) + } + sess.mu.Unlock() + mc.mu.Unlock() + if lpi != 0 || lsseq != 0 { + t.Fatalf("Maps should be empty, got %v, %v", lpi, lsseq) + } +} + +func TestMQTTAckWaitConfigChange(t *testing.T) { + o := testMQTTDefaultOptions() + o.MQTT.AckWait = 250 * time.Millisecond + s := testMQTTRunServer(t, o) + defer func() { + testMQTTShutdownServer(s) + }() + + dir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir) + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + + sendMsg := func(topic, payload string) { + t.Helper() + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, topic, 1, []byte(payload)) + testMQTTDisconnect(t, cp, nil) + cp.Close() + } + sendMsg("foo", "msg1") + + for i := 0; i < 2; i++ { + flags := mqttPubQos1 + if i > 0 { + flags |= mqttPubFlagDup + } + testMQTTCheckPubMsgNoAck(t, c, r, "foo", flags, []byte("msg1")) + } + + // Restart the server with a different AckWait option value. + // Verify that MQTT sub restart succeeds. It will keep the + // original value. + c.Close() + s.Shutdown() + + o.Port = -1 + o.MQTT.Port = -1 + o.MQTT.AckWait = 10 * time.Millisecond + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // There is already the defer for shutdown at top of function + + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1|mqttPubFlagDup, []byte("msg1")) + start := time.Now() + testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1|mqttPubFlagDup, []byte("msg1")) + if dur := time.Since(start); dur < 200*time.Millisecond { + t.Fatalf("AckWait seem to have changed for existing subscription: %v", dur) + } + + // Create new subscription + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "bar", qos: 1}}, []byte{1}) + sendMsg("bar", "msg2") + testMQTTCheckPubMsgNoAck(t, c, r, "bar", mqttPubQos1, []byte("msg2")) + start = time.Now() + testMQTTCheckPubMsgNoAck(t, c, r, "bar", mqttPubQos1|mqttPubFlagDup, []byte("msg2")) + if dur := time.Since(start); dur > 50*time.Millisecond { + t.Fatalf("AckWait new value not used by new sub: %v", dur) + } + c.Close() +} + +func TestMQTTUnsubscribeWithPendingAcks(t *testing.T) { + o := testMQTTDefaultOptions() + o.MQTT.AckWait = 250 * time.Millisecond + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg")) + testMQTTDisconnect(t, cp, nil) + cp.Close() + + for i := 0; i < 2; i++ { + flags := mqttPubQos1 + if i > 0 { + flags |= mqttPubFlagDup + } + testMQTTCheckPubMsgNoAck(t, c, r, "foo", flags, []byte("msg")) + } + + testMQTTUnsub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo"}}) + testMQTTFlush(t, c, nil, r) + + mc := testMQTTGetClient(t, s, "sub") + mc.mu.Lock() + sess := mc.mqtt.sess + sess.mu.Lock() + pal := len(sess.pending) + sess.mu.Unlock() + mc.mu.Unlock() + if pal != 0 { + t.Fatalf("Expected pending ack map to be empty, got %v", pal) + } +} + +func TestMQTTMaxAckPending(t *testing.T) { + o := testMQTTDefaultOptions() + o.MQTT.MaxAckPending = 1 + s := testMQTTRunServer(t, o) + defer func() { + testMQTTShutdownServer(s) + }() + + dir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir) + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg1")) + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg2")) + + pi := testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg1")) + // Check that we don't receive the second one due to max ack pending + testMQTTExpectNothing(t, r) + + // Now ack first message + testMQTTSendPubAck(t, c, pi) + // Now we should receive message 2 + testMQTTCheckPubMsg(t, c, r, "foo", mqttPubQos1, []byte("msg2")) + testMQTTDisconnect(t, c, nil) + + // Send 2 messages while sub is offline + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg3")) + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg4")) + + // Restart consumer + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + + // Should receive only message 3 + pi = testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg3")) + testMQTTExpectNothing(t, r) + + // Ack and get the next + testMQTTSendPubAck(t, c, pi) + testMQTTCheckPubMsg(t, c, r, "foo", mqttPubQos1, []byte("msg4")) + + // Check that change to config does not prevent restart of sub. + cp.Close() + c.Close() + s.Shutdown() + + o.Port = -1 + o.MQTT.Port = -1 + o.MQTT.MaxAckPending = 2 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // There is already the defer for shutdown at top of function + + cp, rp = testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg5")) + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg6")) + + // Restart consumer + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + + // Should receive only message 5 + pi = testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg5")) + testMQTTExpectNothing(t, r) + + // Ack and get the next + testMQTTSendPubAck(t, c, pi) + testMQTTCheckPubMsg(t, c, r, "foo", mqttPubQos1, []byte("msg6")) +} + +func TestMQTTMaxAckPendingForMultipleSubs(t *testing.T) { + o := testMQTTDefaultOptions() + o.MQTT.MaxAckPending = 1 + s := testMQTTRunServer(t, o) + defer s.Shutdown() + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "bar", qos: 1}}, []byte{1}) + + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg1")) + pi := testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg1")) + + // Now send a second message but on topic bar + testMQTTPublish(t, cp, rp, 1, false, false, "bar", 1, []byte("msg2")) + + // JS allows us to limit per consumer, but we apply the limit to the + // session, so although JS will attempt to delivery this message, + // the MQTT code will suppress it. + testMQTTExpectNothing(t, r) + + // Ack the first message. + testMQTTSendPubAck(t, c, pi) + + // Now we should get the second message + testMQTTCheckPubMsg(t, c, r, "bar", mqttPubQos1|mqttPubFlagDup, []byte("msg2")) +} + +func TestMQTTConfigReload(t *testing.T) { + template := ` + jetstream: true + mqtt { + port: -1 + ack_wait: %s + max_ack_pending: %s + } + ` + conf := createConfFile(t, []byte(fmt.Sprintf(template, `"5s"`, `10000`))) + defer os.Remove(conf) + + s, o := RunServerWithConfig(conf) + defer s.Shutdown() + + if val := o.MQTT.AckWait; val != 5*time.Second { + t.Fatalf("Invalid ackwait: %v", val) + } + if val := o.MQTT.MaxAckPending; val != 10000 { + t.Fatalf("Invalid ackwait: %v", val) + } + + changeCurrentConfigContentWithNewContent(t, conf, []byte(fmt.Sprintf(template, `"250ms"`, `1`))) + if err := s.Reload(); err != nil { + t.Fatalf("Error on reload: %v", err) + } + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg1")) + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg2")) + + testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg1")) + start := time.Now() + testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1|mqttPubFlagDup, []byte("msg1")) + if dur := time.Since(start); dur > 500*time.Millisecond { + t.Fatalf("AckWait not applied? dur=%v", dur) + } + c.Close() + cp.Close() + s.Shutdown() + + changeCurrentConfigContentWithNewContent(t, conf, []byte(fmt.Sprintf(template, `"30s"`, `1`))) + s, o = RunServerWithConfig(conf) + defer s.Shutdown() + + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + + cipub = &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp = testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg1")) + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg2")) + + testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg1")) + testMQTTExpectNothing(t, r) + + // Increate the max ack pending + changeCurrentConfigContentWithNewContent(t, conf, []byte(fmt.Sprintf(template, `"30s"`, `10`))) + // Reload now + if err := s.Reload(); err != nil { + t.Fatalf("Error on reload: %v", err) + } + // See that message 2 can now be received (1 will be redelivered too) + testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1|mqttPubFlagDup, []byte("msg1")) + testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg2")) +} + +// Benchmarks + +const ( + mqttPubSubj = "a" + mqttBenchBufLen = 32768 +) + +func mqttBenchPubQoS0(b *testing.B, subject, payload string, numSubs int) { + b.StopTimer() + o := testMQTTDefaultOptions() + s := RunServer(o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{clientID: "pub", cleanSess: true} + c, br := testMQTTConnect(b, ci, o.MQTT.Host, o.MQTT.Port) + testMQTTCheckConnAck(b, br, mqttConnAckRCConnectionAccepted, false) + w := &mqttWriter{} + mqttWritePublish(w, 0, false, false, subject, 0, []byte(payload)) + sendOp := w.Bytes() + + dch := make(chan error, 1) + totalSize := int64(len(sendOp)) + cdch := 0 + + createSub := func(i int) { + ci := &mqttConnInfo{clientID: fmt.Sprintf("sub%d", i), cleanSess: true} + cs, brs := testMQTTConnect(b, ci, o.MQTT.Host, o.MQTT.Port) + testMQTTCheckConnAck(b, brs, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(b, 1, cs, brs, []*mqttFilter{&mqttFilter{filter: subject, qos: 0}}, []byte{0}) + testMQTTFlush(b, cs, nil, brs) + + w := &mqttWriter{} + varHeaderAndPayload := 2 + len(subject) + len(payload) + w.WriteVarInt(varHeaderAndPayload) + size := 1 + w.Len() + varHeaderAndPayload + totalSize += int64(size) + + go func() { + mqttBenchConsumeMsgQoS0(cs, int64(b.N)*int64(size), dch) + cs.Close() + }() + } + for i := 0; i < numSubs; i++ { + createSub(i + 1) + cdch++ + } + + bw := bufio.NewWriterSize(c, mqttBenchBufLen) + b.SetBytes(totalSize) + b.StartTimer() + for i := 0; i < b.N; i++ { + bw.Write(sendOp) + } + testMQTTFlush(b, c, bw, br) + for i := 0; i < cdch; i++ { + if e := <-dch; e != nil { + b.Fatal(e.Error()) + } + } + b.StopTimer() + c.Close() + s.Shutdown() +} + +func mqttBenchConsumeMsgQoS0(c net.Conn, total int64, dch chan<- error) { + var buf [mqttBenchBufLen]byte + var err error + var n int + for size := int64(0); size < total; { + n, err = c.Read(buf[:]) + if err != nil { + break + } + size += int64(n) + } + dch <- err +} + +func mqttBenchPubQoS1(b *testing.B, subject, payload string, numSubs int) { + b.StopTimer() + o := testMQTTDefaultOptions() + o.MQTT.MaxAckPending = 0xFFFF + s := RunServer(o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{cleanSess: true} + c, br := testMQTTConnect(b, ci, o.MQTT.Host, o.MQTT.Port) + testMQTTCheckConnAck(b, br, mqttConnAckRCConnectionAccepted, false) + + w := &mqttWriter{} + mqttWritePublish(w, 1, false, false, subject, 1, []byte(payload)) + // For reported bytes we will count the PUBLISH + PUBACK (4 bytes) + totalSize := int64(len(w.Bytes()) + 4) + w.Reset() + + pi := uint16(1) + maxpi := uint16(60000) + ppich := make(chan error, 10) + dch := make(chan error, 1+numSubs) + cdch := 1 + // Start go routine to consume PUBACK for published QoS 1 messages. + go mqttBenchConsumePubAck(c, b.N, dch, ppich) + + createSub := func(i int) { + ci := &mqttConnInfo{clientID: fmt.Sprintf("sub%d", i), cleanSess: true} + cs, brs := testMQTTConnect(b, ci, o.MQTT.Host, o.MQTT.Port) + testMQTTCheckConnAck(b, brs, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(b, 1, cs, brs, []*mqttFilter{&mqttFilter{filter: subject, qos: 1}}, []byte{1}) + testMQTTFlush(b, cs, nil, brs) + + w := &mqttWriter{} + varHeaderAndPayload := 2 + len(subject) + 2 + len(payload) + w.WriteVarInt(varHeaderAndPayload) + size := 1 + w.Len() + varHeaderAndPayload + // Add to the bytes reported the size of message sent to subscriber + PUBACK (4 bytes) + totalSize += int64(size + 4) + + go func() { + mqttBenchConsumeMsgQos1(cs, b.N, size, dch) + cs.Close() + }() + } + for i := 0; i < numSubs; i++ { + createSub(i + 1) + cdch++ + } + + flush := func() { + b.Helper() + if _, err := c.Write(w.Bytes()); err != nil { + b.Fatalf("Error on write: %v", err) + } + w.Reset() + } + + b.SetBytes(totalSize) + b.StartTimer() + for i := 0; i < b.N; i++ { + if pi <= maxpi { + mqttWritePublish(w, 1, false, false, subject, pi, []byte(payload)) + pi++ + if w.Len() >= mqttBenchBufLen { + flush() + } + } else { + if w.Len() > 0 { + flush() + } + if pi > 60000 { + pi = 1 + maxpi = 0 + } + if e := <-ppich; e != nil { + b.Fatal(e.Error()) + } + maxpi += 10000 + i-- + } + } + if w.Len() > 0 { + flush() + } + for i := 0; i < cdch; i++ { + if e := <-dch; e != nil { + b.Fatal(e.Error()) + } + } + b.StopTimer() + c.Close() + s.Shutdown() +} + +func mqttBenchConsumeMsgQos1(c net.Conn, total, size int, dch chan<- error) { + var buf [mqttBenchBufLen]byte + pubAck := [4]byte{mqttPacketPubAck, 0x2, 0, 0} + var err error + var n int + var pi uint16 + var prev int + for i := 0; i < total; { + n, err = c.Read(buf[:]) + if err != nil { + break + } + n += prev + for ; n >= size; n -= size { + i++ + pi++ + pubAck[2] = byte(pi >> 8) + pubAck[3] = byte(pi) + if _, err = c.Write(pubAck[:4]); err != nil { + dch <- err + return + } + if pi == 60000 { + pi = 0 + } + } + prev = n + } + dch <- err +} + +func mqttBenchConsumePubAck(c net.Conn, total int, dch, ppich chan<- error) { + var buf [mqttBenchBufLen]byte + var err error + var n int + var pi uint16 + var prev int + for i := 0; i < total; { + n, err = c.Read(buf[:]) + if err != nil { + break + } + n += prev + for ; n >= 4; n -= 4 { + i++ + pi++ + if pi%10000 == 0 { + ppich <- nil + } + if pi == 60001 { + pi = 0 + } + } + prev = n + } + ppich <- err + dch <- err +} + +func BenchmarkMQTT_QoS0_Pub_______0b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, "", 0) +} + +func BenchmarkMQTT_QoS0_Pub_______8b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(8), 0) +} + +func BenchmarkMQTT_QoS0_Pub______32b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(32), 0) +} + +func BenchmarkMQTT_QoS0_Pub_____128b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(128), 0) +} + +func BenchmarkMQTT_QoS0_Pub_____256b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(256), 0) +} + +func BenchmarkMQTT_QoS0_Pub_______1K_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(1024), 0) +} + +func BenchmarkMQTT_QoS0_PubSub1___0b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, "", 1) +} + +func BenchmarkMQTT_QoS0_PubSub1___8b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(8), 1) +} + +func BenchmarkMQTT_QoS0_PubSub1__32b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(32), 1) +} + +func BenchmarkMQTT_QoS0_PubSub1_128b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(128), 1) +} + +func BenchmarkMQTT_QoS0_PubSub1_256b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(256), 1) +} + +func BenchmarkMQTT_QoS0_PubSub1___1K_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(1024), 1) +} + +func BenchmarkMQTT_QoS0_PubSub2___0b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, "", 2) +} + +func BenchmarkMQTT_QoS0_PubSub2___8b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(8), 2) +} + +func BenchmarkMQTT_QoS0_PubSub2__32b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(32), 2) +} + +func BenchmarkMQTT_QoS0_PubSub2_128b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(128), 2) +} + +func BenchmarkMQTT_QoS0_PubSub2_256b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(256), 2) +} + +func BenchmarkMQTT_QoS0_PubSub2___1K_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(1024), 2) +} + +func BenchmarkMQTT_QoS1_Pub_______0b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, "", 0) +} + +func BenchmarkMQTT_QoS1_Pub_______8b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(8), 0) +} + +func BenchmarkMQTT_QoS1_Pub______32b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(32), 0) +} + +func BenchmarkMQTT_QoS1_Pub_____128b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(128), 0) +} + +func BenchmarkMQTT_QoS1_Pub_____256b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(256), 0) +} + +func BenchmarkMQTT_QoS1_Pub_______1K_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(1024), 0) +} + +func BenchmarkMQTT_QoS1_PubSub1___0b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, "", 1) +} + +func BenchmarkMQTT_QoS1_PubSub1___8b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(8), 1) +} + +func BenchmarkMQTT_QoS1_PubSub1__32b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(32), 1) +} + +func BenchmarkMQTT_QoS1_PubSub1_128b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(128), 1) +} + +func BenchmarkMQTT_QoS1_PubSub1_256b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(256), 1) +} + +func BenchmarkMQTT_QoS1_PubSub1___1K_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(1024), 1) +} + +func BenchmarkMQTT_QoS1_PubSub2___0b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, "", 2) +} + +func BenchmarkMQTT_QoS1_PubSub2___8b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(8), 2) +} + +func BenchmarkMQTT_QoS1_PubSub2__32b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(32), 2) +} + +func BenchmarkMQTT_QoS1_PubSub2_128b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(128), 2) +} + +func BenchmarkMQTT_QoS1_PubSub2_256b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(256), 2) +} + +func BenchmarkMQTT_QoS1_PubSub2___1K_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(1024), 2) +} diff --git a/server/opts.go b/server/opts.go index 32802133..3f44538a 100644 --- a/server/opts.go +++ b/server/opts.go @@ -197,6 +197,7 @@ type Options struct { JetStreamMaxStore int64 `json:"-"` StoreDir string `json:"-"` Websocket WebsocketOpts `json:"-"` + MQTT MQTTOpts `json:"-"` ProfPort int `json:"-"` PidFile string `json:"-"` PortsFileDir string `json:"-"` @@ -256,7 +257,7 @@ type Options struct { routeProto int } -// WebsocketOpts ... +// WebsocketOpts are options for websocket type WebsocketOpts struct { // The server will accept websocket client connections on this hostname/IP. Host string @@ -316,6 +317,49 @@ type WebsocketOpts struct { HandshakeTimeout time.Duration } +// MQTTOpts are options for MQTT +type MQTTOpts struct { + // The server will accept MQTT client connections on this hostname/IP. + Host string + // The server will accept MQTT client connections on this port. + Port int + + // If no user name is provided when a client connects, will default to the + // matching user from the global list of users in `Options.Users`. + NoAuthUser string + + // Authentication section. If anything is configured in this section, + // it will override the authorization configuration of regular clients. + Username string + Password string + Token string + + // Timeout for the authentication process. + AuthTimeout float64 + + // TLS configuration is required. + TLSConfig *tls.Config + // If true, map certificate values for authentication purposes. + TLSMap bool + // Timeout for the TLS handshake + TLSTimeout float64 + + // AckWait is the amount of time after which a QoS 1 message sent to + // a client is redelivered as a DUPLICATE if the server has not + // received the PUBACK on the original Packet Identifier. + // The value has to be positive. + // Zero will cause the server to use the default value (1 hour). + // Note that changes to this option is applied only to new MQTT subscriptions. + AckWait time.Duration + + // MaxAckPending is the amount of QoS 1 messages the server can send to + // a session without receiving any PUBACK for those messages. + // The valid range is [0..65535]. + // Zero will cause the server to use the default value (1024). + // Note that changes to this option is applied only to new MQTT sessions. + MaxAckPending uint16 +} + type netResolver interface { LookupHost(ctx context.Context, host string) ([]string, error) } @@ -1011,6 +1055,11 @@ func (o *Options) processConfigFileLine(k string, v interface{}, errors *[]error *errors = append(*errors, err) return } + case "mqtt": + if err := parseMQTT(tk, o, errors, warnings); err != nil { + *errors = append(*errors, err) + return + } default: if au := atomic.LoadInt32(&allowUnknownTopLevelField); au == 0 && !tk.IsUsedVariable() { err := &unknownConfigFieldErr{ @@ -3327,7 +3376,7 @@ func parseTLS(v interface{}, isClientCtx bool) (t *TLSConfigOpts, retErr error) return &tc, nil } -func parseAuthForWS(v interface{}, errors *[]error, warnings *[]error) *authorization { +func parseSimpleAuth(v interface{}, errors *[]error, warnings *[]error) *authorization { var ( am map[string]interface{} tk token @@ -3463,7 +3512,7 @@ func parseWebsocket(v interface{}, o *Options, errors *[]error, warnings *[]erro case "compression": o.Websocket.Compression = mv.(bool) case "authorization", "authentication": - auth := parseAuthForWS(tk, errors, warnings) + auth := parseSimpleAuth(tk, errors, warnings) o.Websocket.Username = auth.user o.Websocket.Password = auth.pass o.Websocket.Token = auth.token @@ -3488,6 +3537,79 @@ func parseWebsocket(v interface{}, o *Options, errors *[]error, warnings *[]erro return nil } +func parseMQTT(v interface{}, o *Options, errors *[]error, warnings *[]error) error { + var lt token + defer convertPanicToErrorList(<, errors) + + tk, v := unwrapValue(v, <) + gm, ok := v.(map[string]interface{}) + if !ok { + return &configErr{tk, fmt.Sprintf("Expected mqtt to be a map, got %T", v)} + } + for mk, mv := range gm { + // Again, unwrap token value if line check is required. + tk, mv = unwrapValue(mv, <) + switch strings.ToLower(mk) { + case "listen": + hp, err := parseListen(mv) + if err != nil { + err := &configErr{tk, err.Error()} + *errors = append(*errors, err) + continue + } + o.MQTT.Host = hp.host + o.MQTT.Port = hp.port + case "port": + o.MQTT.Port = int(mv.(int64)) + case "host", "net": + o.MQTT.Host = mv.(string) + case "tls": + tc, err := parseTLS(tk, true) + if err != nil { + *errors = append(*errors, err) + continue + } + if o.MQTT.TLSConfig, err = GenTLSConfig(tc); err != nil { + err := &configErr{tk, err.Error()} + *errors = append(*errors, err) + continue + } + o.MQTT.TLSTimeout = tc.Timeout + o.MQTT.TLSMap = tc.Map + case "authorization", "authentication": + auth := parseSimpleAuth(tk, errors, warnings) + o.MQTT.Username = auth.user + o.MQTT.Password = auth.pass + o.MQTT.Token = auth.token + o.MQTT.AuthTimeout = auth.timeout + case "no_auth_user": + o.MQTT.NoAuthUser = mv.(string) + case "ack_wait", "ackwait": + o.MQTT.AckWait = parseDuration("ack_wait", tk, mv, errors, warnings) + case "max_ack_pending", "max_pending", "max_inflight": + tmp := int(mv.(int64)) + if tmp < 0 || tmp > 0xFFFF { + err := &configErr{tk, fmt.Sprintf("invalid value %v, should in [0..%d] range", tmp, 0xFFFF)} + *errors = append(*errors, err) + } else { + o.MQTT.MaxAckPending = uint16(tmp) + } + default: + if !tk.IsUsedVariable() { + err := &unknownConfigFieldErr{ + field: mk, + configErr: configErr{ + token: tk, + }, + } + *errors = append(*errors, err) + continue + } + } + } + return nil +} + // GenTLSConfig loads TLS related configuration parameters. func GenTLSConfig(tc *TLSConfigOpts) (*tls.Config, error) { // Create the tls.Config from our options before including the certs. @@ -3831,6 +3953,14 @@ func setBaselineOptions(opts *Options) { opts.Websocket.Host = DEFAULT_HOST } } + if opts.MQTT.Port != 0 { + if opts.MQTT.Host == "" { + opts.MQTT.Host = DEFAULT_HOST + } + if opts.MQTT.TLSTimeout == 0 { + opts.MQTT.TLSTimeout = float64(TLS_TIMEOUT) / float64(time.Second) + } + } // JetStream if opts.JetStreamMaxMemory == 0 { opts.JetStreamMaxMemory = -1 diff --git a/server/parser.go b/server/parser.go index e9492b2f..f76337df 100644 --- a/server/parser.go +++ b/server/parser.go @@ -130,6 +130,11 @@ const ( ) func (c *client) parse(buf []byte) error { + // Branch out to mqtt clients. c.mqtt is immutable, but should it become + // an issue (say data race detection), we could branch outside in readLoop + if c.mqtt != nil { + return c.mqttParse(buf) + } var i int var b byte var lmsg bool diff --git a/server/reload.go b/server/reload.go index 14fa386e..790bb80a 100644 --- a/server/reload.go +++ b/server/reload.go @@ -596,6 +596,25 @@ func (m *maxTracedMsgLenOption) Apply(server *Server) { server.Noticef("Reloaded: max_traced_msg_len = %d", m.newValue) } +type mqttAckWaitReload struct { + noopOption + newValue time.Duration +} + +func (o *mqttAckWaitReload) Apply(s *Server) { + s.Noticef("Reloaded: MQTT ack_wait = %v", o.newValue) +} + +type mqttMaxAckPendingReload struct { + noopOption + newValue uint16 +} + +func (o *mqttMaxAckPendingReload) Apply(s *Server) { + s.mqttUpdateMaxAckPending(o.newValue) + s.Noticef("Reloaded: MQTT max_ack_pending = %v", o.newValue) +} + // Reload reads the current configuration file and applies any supported // changes. This returns an error if the server was not started with a config // file or an option which doesn't support hot-swapping was changed. @@ -633,6 +652,7 @@ func (s *Server) Reload() error { gatewayOrgPort := curOpts.Gateway.Port leafnodesOrgPort := curOpts.LeafNode.Port websocketOrgPort := curOpts.Websocket.Port + mqttOrgPort := curOpts.MQTT.Port s.mu.Unlock() @@ -665,6 +685,9 @@ func (s *Server) Reload() error { if newOpts.Websocket.Port == -1 { newOpts.Websocket.Port = websocketOrgPort } + if newOpts.MQTT.Port == -1 { + newOpts.MQTT.Port = mqttOrgPort + } if err := s.reloadOptions(curOpts, newOpts); err != nil { return err @@ -766,7 +789,7 @@ func imposeOrder(value interface{}) error { case WebsocketOpts: sort.Strings(value.AllowedOrigins) case string, bool, int, int32, int64, time.Duration, float64, nil, - LeafNodeOpts, ClusterOpts, *tls.Config, *URLAccResolver, *MemAccResolver, *DirAccResolver, *CacheDirAccResolver, Authentication: + LeafNodeOpts, ClusterOpts, *tls.Config, *URLAccResolver, *MemAccResolver, *DirAccResolver, *CacheDirAccResolver, Authentication, MQTTOpts: // explicitly skipped types default: // this will fail during unit tests @@ -973,6 +996,20 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) { return nil, fmt.Errorf("config reload not supported for %s: old=%v, new=%v", field.Name, oldValue, newValue) } + case "mqtt": + diffOpts = append(diffOpts, &mqttAckWaitReload{newValue: newValue.(MQTTOpts).AckWait}) + diffOpts = append(diffOpts, &mqttMaxAckPendingReload{newValue: newValue.(MQTTOpts).MaxAckPending}) + // Nil out/set to 0 the options that we allow to be reloaded so that + // we only fail reload if some that we don't support are changed. + tmpOld := oldValue.(MQTTOpts) + tmpNew := newValue.(MQTTOpts) + tmpOld.TLSConfig, tmpOld.AckWait, tmpOld.MaxAckPending = nil, 0, 0 + tmpNew.TLSConfig, tmpNew.AckWait, tmpNew.MaxAckPending = nil, 0, 0 + if !reflect.DeepEqual(tmpOld, tmpNew) { + // See TODO(ik) note below about printing old/new values. + return nil, fmt.Errorf("config reload not supported for %s: old=%v, new=%v", + field.Name, oldValue, newValue) + } case "connecterrorreports": diffOpts = append(diffOpts, &connectErrorReports{newValue: newValue.(int)}) case "reconnecterrorreports": @@ -1233,6 +1270,8 @@ func (s *Server) reloadAuthorization() { // can't hold the lock as go routine reading it may be waiting for lock as well resetCh = s.sys.resetCh } + // Check that publish retained messages sources are still allowed to publish. + s.mqttCheckPubRetainedPerms() s.mu.Unlock() if resetCh != nil { diff --git a/server/server.go b/server/server.go index 455457a7..59767929 100644 --- a/server/server.go +++ b/server/server.go @@ -222,6 +222,9 @@ type Server struct { // Websocket structure websocket srvWebsocket + // MQTT structure + mqtt srvMQTT + // exporting account name the importer experienced issues with incompleteAccExporterMap sync.Map @@ -533,6 +536,9 @@ func validateOptions(o *Options) error { if err := validateClusterName(o); err != nil { return err } + if err := validateMQTTOptions(o); err != nil { + return err + } // Finally check websocket options. return validateWebsocketOptions(o) } @@ -1516,6 +1522,11 @@ func (s *Server) Start() { s.startWebsocketServer() } + // MQTT + if opts.MQTT.Port != 0 { + s.startMQTT() + } + // Start up routing as well if needed. if opts.Cluster.Port != 0 { s.startGoRoutine(func() { @@ -1611,6 +1622,13 @@ func (s *Server) Shutdown() { s.websocket.listener = nil } + // Kick MQTT accept loop + if s.mqtt.listener != nil { + doneExpected++ + s.mqtt.listener.Close() + s.mqtt.listener = nil + } + // Kick leafnodes AcceptLoop() if s.leafNodeListener != nil { doneExpected++ @@ -1747,7 +1765,7 @@ func (s *Server) AcceptLoop(clr chan struct{}) { s.clientConnectURLs = s.getClientConnectURLs() s.listener = l - go s.acceptConnections(l, "Client", func(conn net.Conn) { s.createClient(conn, nil) }, + go s.acceptConnections(l, "Client", func(conn net.Conn) { s.createClient(conn, nil, nil) }, func(_ error) bool { if s.isLameDuckMode() { // Signal that we are not accepting new clients @@ -2073,7 +2091,7 @@ func (c *tlsMixConn) Read(b []byte) (int, error) { return c.Conn.Read(b) } -func (s *Server) createClient(conn net.Conn, ws *websocket) *client { +func (s *Server) createClient(conn net.Conn, ws *websocket, mqtt *mqtt) *client { // Snapshot server options. opts := s.getOpts() @@ -2086,32 +2104,49 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { now := time.Now() c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: maxPay, msubs: maxSubs, start: now, last: now, ws: ws} + if mqtt != nil { + c.mqtt = mqtt + // Set some of the options here since MQTT clients don't + // send a regular CONNECT (but have their own). + c.opts.Lang = "mqtt" + c.opts.Verbose = false + } c.registerWithAccount(s.globalAccount()) - // Grab JSON info string + var info Info + var authRequired bool + s.mu.Lock() - info := s.copyInfo() - // If this is a websocket client and there is no top-level auth specified, - // then we use the websocket's specific boolean that will be set to true - // if there is any auth{} configured in websocket{}. - if ws != nil && !info.AuthRequired { - info.AuthRequired = s.websocket.authOverride + // We don't need the INFO to mqtt clients. + if mqtt == nil { + // Grab JSON info string + info = s.copyInfo() + // If this is a websocket client and there is no top-level auth specified, + // then we use the websocket's specific boolean that will be set to true + // if there is any auth{} configured in websocket{}. + if ws != nil && !info.AuthRequired { + info.AuthRequired = s.websocket.authOverride + } + if s.nonceRequired() { + // Nonce handling + var raw [nonceLen]byte + nonce := raw[:] + s.generateNonce(nonce) + info.Nonce = string(nonce) + } + c.nonce = []byte(info.Nonce) + authRequired = info.AuthRequired + } else { + authRequired = s.info.AuthRequired || s.mqtt.authOverride } - if s.nonceRequired() { - // Nonce handling - var raw [nonceLen]byte - nonce := raw[:] - s.generateNonce(nonce) - info.Nonce = string(nonce) - } - c.nonce = []byte(info.Nonce) + s.totalClients++ s.mu.Unlock() // Grab lock c.mu.Lock() - if info.AuthRequired { + if authRequired { c.flags.set(expectConnect) } @@ -2120,10 +2155,13 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { c.Debugf("Client connection created") - // Send our information. - // Need to be sent in place since writeLoop cannot be started until - // TLS handshake is done (if applicable). - c.sendProtoNow(c.generateClientInfoJSON(info)) + // We don't send the INFO to mqtt clients. + if mqtt == nil { + // Send our information. + // Need to be sent in place since writeLoop cannot be started until + // TLS handshake is done (if applicable). + c.sendProtoNow(c.generateClientInfoJSON(info)) + } // Unlock to register c.mu.Unlock() @@ -2155,6 +2193,22 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { return nil } s.clients[c.cid] = c + + // May be overridden based on type of client. + TLSConfig := opts.TLSConfig + TLSTimeout := opts.TLSTimeout + + tlsRequired := info.TLSRequired + // Websocket clients do TLS in the websocket http server. + if ws != nil { + tlsRequired = false + } else if mqtt != nil { + tlsRequired = opts.MQTT.TLSConfig != nil + if tlsRequired { + TLSConfig = opts.MQTT.TLSConfig + TLSTimeout = opts.MQTT.TLSTimeout + } + } s.mu.Unlock() // Re-Grab lock @@ -2163,13 +2217,12 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { // Connection could have been closed while sending the INFO proto. isClosed := c.isClosed() - tlsRequired := ws == nil && info.TLSRequired var pre []byte // If we have both TLS and non-TLS allowed we need to see which // one the client wants. if !isClosed && opts.TLSConfig != nil && opts.AllowNonTLS { pre = make([]byte, 4) - c.nc.SetReadDeadline(time.Now().Add(secondsToDuration(opts.TLSTimeout))) + c.nc.SetReadDeadline(time.Now().Add(secondsToDuration(TLSTimeout))) n, _ := io.ReadFull(c.nc, pre[:]) c.nc.SetReadDeadline(time.Time{}) pre = pre[:n] @@ -2190,11 +2243,11 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { pre = nil } - c.nc = tls.Server(c.nc, opts.TLSConfig) + c.nc = tls.Server(c.nc, TLSConfig) conn := c.nc.(*tls.Conn) // Setup the timeout - ttl := secondsToDuration(opts.TLSTimeout) + ttl := secondsToDuration(TLSTimeout) time.AfterFunc(ttl, func() { tlsTimeout(c, conn) }) conn.SetReadDeadline(time.Now().Add(ttl)) @@ -2231,7 +2284,7 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { // Check for Auth. We schedule this timer after the TLS handshake to avoid // the race where the timer fires during the handshake and causes the // server to write bad data to the socket. See issue #432. - if info.AuthRequired { + if authRequired { timeout := opts.AuthTimeout // For websocket, possibly override only if set. We make sure that // opts.AuthTimeout is set to a default value if not configured, @@ -2239,6 +2292,8 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { // if user has explicitly set or not. if ws != nil && opts.Websocket.AuthTimeout != 0 { timeout = opts.Websocket.AuthTimeout + } else if mqtt != nil && opts.MQTT.AuthTimeout != 0 { + timeout = opts.MQTT.AuthTimeout } c.setAuthTimer(secondsToDuration(timeout)) } @@ -2421,7 +2476,11 @@ func (s *Server) removeClient(c *client) { if updateProtoInfoCount { s.cproto-- } + mqtt := c.mqtt != nil s.mu.Unlock() + if mqtt { + s.mqttHandleClosedClient(c) + } case ROUTER: s.removeRoute(c) case GATEWAY: diff --git a/server/server_test.go b/server/server_test.go index 1114315b..b546905f 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1245,6 +1245,8 @@ func TestServerShutdownDuringStart(t *testing.T) { o.Websocket.Port = -1 o.Websocket.HandshakeTimeout = 1 o.Websocket.NoTLS = true + o.MQTT.Host = "127.0.0.1" + o.MQTT.Port = -1 // We are going to test that if the server is shutdown // while Start() runs (in this case, before), we don't @@ -1286,6 +1288,9 @@ func TestServerShutdownDuringStart(t *testing.T) { if s.websocket.listener != nil { listeners = append(listeners, "websocket") } + if s.mqtt.listener != nil { + listeners = append(listeners, "mqtt") + } s.mu.Unlock() if len(listeners) > 0 { lst := "" diff --git a/server/stream.go b/server/stream.go index 729c1db4..d709767a 100644 --- a/server/stream.go +++ b/server/stream.go @@ -27,6 +27,7 @@ import ( "path/filepath" "reflect" "strconv" + "strings" "sync" "time" @@ -95,6 +96,7 @@ type Stream struct { ddarr []*ddentry ddindex int ddtmr *time.Timer + nosubj bool } // Headers for published messages. @@ -125,13 +127,20 @@ func (a *Account) AddStream(config *StreamConfig) (*Stream, error) { // AddStreamWithStore adds a stream for the given account with custome store config options. func (a *Account) AddStreamWithStore(config *StreamConfig, fsConfig *FileStoreConfig) (*Stream, error) { + if strings.HasPrefix(config.Name, mqttStreamNamePrefix) { + return nil, fmt.Errorf("prefix %q is reserved for MQTT, unable to create stream %q", mqttStreamNamePrefix, config.Name) + } + return a.addStreamWithStore(config, fsConfig, false) +} + +func (a *Account) addStreamWithStore(config *StreamConfig, fsConfig *FileStoreConfig, noSubjectsOK bool) (*Stream, error) { s, jsa, err := a.checkForJetStream() if err != nil { return nil, err } // Sensible defaults. - cfg, err := checkStreamCfg(config) + cfg, err := checkStreamCfg(config, noSubjectsOK) if err != nil { return nil, err } @@ -168,7 +177,7 @@ func (a *Account) AddStreamWithStore(config *StreamConfig, fsConfig *FileStoreCo // Setup the internal client. c := s.createInternalJetStreamClient() - mset := &Stream{jsa: jsa, config: cfg, client: c, consumers: make(map[string]*Consumer)} + mset := &Stream{jsa: jsa, config: cfg, client: c, consumers: make(map[string]*Consumer), nosubj: noSubjectsOK} jsa.streams[cfg.Name] = mset storeDir := path.Join(jsa.storeDir, streamsDir, cfg.Name) @@ -416,7 +425,7 @@ func (jsa *jsAccount) subjectsOverlap(subjects []string) bool { // Default duplicates window. const StreamDefaultDuplicatesWindow = 2 * time.Minute -func checkStreamCfg(config *StreamConfig) (StreamConfig, error) { +func checkStreamCfg(config *StreamConfig, noSubjectOk bool) (StreamConfig, error) { if config == nil { return StreamConfig{}, fmt.Errorf("stream configuration invalid") } @@ -466,7 +475,9 @@ func checkStreamCfg(config *StreamConfig) (StreamConfig, error) { } if len(cfg.Subjects) == 0 { - cfg.Subjects = append(cfg.Subjects, cfg.Name) + if !noSubjectOk { + cfg.Subjects = append(cfg.Subjects, cfg.Name) + } } else { // We can allow overlaps, but don't allow direct duplicates. dset := make(map[string]struct{}, len(cfg.Subjects)) @@ -519,7 +530,10 @@ func (mset *Stream) Delete() error { // Update will allow certain configuration properties of an existing stream to be updated. func (mset *Stream) Update(config *StreamConfig) error { - cfg, err := checkStreamCfg(config) + mset.mu.RLock() + nosubj := mset.nosubj + mset.mu.RUnlock() + cfg, err := checkStreamCfg(config, nosubj) if err != nil { return err } @@ -691,7 +705,7 @@ func (mset *Stream) subscribeInternal(subject string, cb msgHandler) (*subscript mset.sid++ // Now create the subscription - return c.processSub([]byte(subject), nil, []byte(strconv.Itoa(mset.sid)), cb, false) + return c.processSub(c.createSub([]byte(subject), nil, []byte(strconv.Itoa(mset.sid)), cb), false) } // Helper for unlocked stream. @@ -901,7 +915,7 @@ func getExpectedLastSeq(hdr []byte) uint64 { } // processInboundJetStreamMsg handles processing messages bound for a stream. -func (mset *Stream) processInboundJetStreamMsg(_ *subscription, pc *client, subject, reply string, msg []byte) { +func (mset *Stream) processInboundJetStreamMsg(sub *subscription, pc *client, subject, reply string, msg []byte) { mset.mu.Lock() store := mset.store c := mset.client diff --git a/server/sublist.go b/server/sublist.go index 1c860a83..fcb0c28d 100644 --- a/server/sublist.go +++ b/server/sublist.go @@ -1392,3 +1392,65 @@ func (s *Sublist) collectAllSubs(l *level, subs *[]*subscription) { s.collectAllSubs(l.fwc.next, subs) } } + +func (s *Sublist) ReverseMatch(subject string) *SublistResult { + tsa := [32]string{} + tokens := tsa[:0] + start := 0 + for i := 0; i < len(subject); i++ { + if subject[i] == btsep { + tokens = append(tokens, subject[start:i]) + start = i + 1 + } + } + tokens = append(tokens, subject[start:]) + + result := &SublistResult{} + + s.Lock() + reverseMatchLevel(s.root, tokens, nil, result) + // Check for empty result. + if len(result.psubs) == 0 && len(result.qsubs) == 0 { + result = emptyResult + } + s.Unlock() + + return result +} + +func reverseMatchLevel(l *level, toks []string, n *node, results *SublistResult) { + for i, t := range toks { + if l == nil { + return + } + if len(t) == 1 { + if t[0] == fwc { + getAllNodes(l, results) + return + } else if t[0] == pwc { + for _, n := range l.nodes { + reverseMatchLevel(n.next, toks[i+1:], n, results) + } + return + } + } + n = l.nodes[t] + if n == nil { + break + } + l = n.next + } + if n != nil { + addNodeToResults(n, results) + } +} + +func getAllNodes(l *level, results *SublistResult) { + if l == nil { + return + } + for _, n := range l.nodes { + addNodeToResults(n, results) + getAllNodes(n.next, results) + } +} diff --git a/server/sublist_test.go b/server/sublist_test.go index 43a2ae84..eb35e1ea 100644 --- a/server/sublist_test.go +++ b/server/sublist_test.go @@ -1258,6 +1258,74 @@ func TestSublistRegisterInterestNotification(t *testing.T) { expectFalse() } +func TestSublistReverseMatch(t *testing.T) { + s := NewSublistWithCache() + fooSub := newSub("foo") + barSub := newSub("bar") + fooBarSub := newSub("foo.bar") + fooBazSub := newSub("foo.baz") + fooBarBazSub := newSub("foo.bar.baz") + s.Insert(fooSub) + s.Insert(barSub) + s.Insert(fooBarSub) + s.Insert(fooBazSub) + s.Insert(fooBarBazSub) + + r := s.ReverseMatch("foo") + verifyLen(r.psubs, 1, t) + verifyMember(r.psubs, fooSub, t) + + r = s.ReverseMatch("bar") + verifyLen(r.psubs, 1, t) + verifyMember(r.psubs, barSub, t) + + r = s.ReverseMatch("*") + verifyLen(r.psubs, 2, t) + verifyMember(r.psubs, fooSub, t) + verifyMember(r.psubs, barSub, t) + + r = s.ReverseMatch("baz") + verifyLen(r.psubs, 0, t) + + r = s.ReverseMatch("foo.*") + verifyLen(r.psubs, 2, t) + verifyMember(r.psubs, fooBarSub, t) + verifyMember(r.psubs, fooBazSub, t) + + r = s.ReverseMatch("*.*") + verifyLen(r.psubs, 2, t) + verifyMember(r.psubs, fooBarSub, t) + verifyMember(r.psubs, fooBazSub, t) + + r = s.ReverseMatch("*.bar") + verifyLen(r.psubs, 1, t) + verifyMember(r.psubs, fooBarSub, t) + + r = s.ReverseMatch("*.baz") + verifyLen(r.psubs, 1, t) + verifyMember(r.psubs, fooBazSub, t) + + r = s.ReverseMatch("bar.*") + verifyLen(r.psubs, 0, t) + + r = s.ReverseMatch("*.bat") + verifyLen(r.psubs, 0, t) + + r = s.ReverseMatch("foo.>") + verifyLen(r.psubs, 3, t) + verifyMember(r.psubs, fooBarSub, t) + verifyMember(r.psubs, fooBazSub, t) + verifyMember(r.psubs, fooBarBazSub, t) + + r = s.ReverseMatch(">") + verifyLen(r.psubs, 5, t) + verifyMember(r.psubs, fooSub, t) + verifyMember(r.psubs, barSub, t) + verifyMember(r.psubs, fooBarSub, t) + verifyMember(r.psubs, fooBazSub, t) + verifyMember(r.psubs, fooBarBazSub, t) +} + // -- Benchmarks Setup -- var benchSublistSubs []*subscription diff --git a/server/websocket.go b/server/websocket.go index 7b10877d..10defdc4 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -856,6 +856,7 @@ func (s *Server) startWebsocketServer() { if port == 0 { s.opts.Websocket.Port = hl.Addr().(*net.TCPAddr).Port } + s.Noticef("Listening for websocket clients on %s://%s:%d", proto, o.Host, port) s.websocket.connectURLs, err = s.getConnectURLs(o.Advertise, o.Host, o.Port) if err != nil { s.Fatalf("Unable to get websocket connect URLs: %v", err) @@ -870,7 +871,7 @@ func (s *Server) startWebsocketServer() { s.Errorf(err.Error()) return } - s.createClient(res.conn, res.ws) + s.createClient(res.conn, res.ws, nil) }) hs := &http.Server{ Addr: hp, diff --git a/server/websocket_test.go b/server/websocket_test.go index fa5e235b..e0b075fb 100644 --- a/server/websocket_test.go +++ b/server/websocket_test.go @@ -2095,6 +2095,10 @@ func TestWSTLSVerifyAndMap(t *testing.T) { {"no filtering, client does not provide cert", false, false}, {"filtering, client provides cert", true, true}, {"filtering, client does not provide cert", true, false}, + {"no users override, client provides cert", false, true}, + {"no users override, client does not provide cert", false, false}, + {"users override, client provides cert", true, true}, + {"users override, client does not provide cert", true, false}, } { t.Run(test.name, func(t *testing.T) { o := testWSOptions()