From 1dba6418ed73d8b65e6d021966beee525febcfd4 Mon Sep 17 00:00:00 2001 From: Ivan Kozlovic Date: Thu, 28 May 2020 18:12:54 -0600 Subject: [PATCH] [ADDED] MQTT Support This PR introduces native support for MQTT clients. It requires use of accounts with JetStream enabled. Since as of now clustering is not available, MQTT will be limited to single instance. Only QoS 0 and 1 are supported at the moment. MQTT clients can exchange messages with NATS clients and vice-versa. Since JetStream is required, accounts with JetStream enabled must exist in order for an MQTT client to connect to the NATS Server. The administrator can limit the users that can use MQTT with the allowed_connection_types option in the user section. For instance: ``` accounts { mqtt { users [ {user: all, password: pwd, allowed_connection_types: ["STANDARD", "WEBSOCKET", "MQTT"]} {user: mqtt_only, password: pwd, allowed_connection_types: "MQTT"} ] jetstream: enabled } } ``` The "mqtt_only" can only be used for MQTT connections, which the user "all" accepts standard, websocket and MQTT clients. Here is what a configuration to enable MQTT looks like: ``` mqtt { # Specify a host and port to listen for websocket connections # # listen: "host:port" # It can also be configured with individual parameters, # namely host and port. # # host: "hostname" port: 1883 # TLS configuration section # # tls { # cert_file: "/path/to/cert.pem" # key_file: "/path/to/key.pem" # ca_file: "/path/to/ca.pem" # # # Time allowed for the TLS handshake to complete # timeout: 2.0 # # # Takes the user name from the certificate # # # # verify_an_map: true #} # Authentication override. Here are possible options. # # authorization { # # Simple username/password # # # user: "some_user_name" # password: "some_password" # # # Token. The server will check the MQTT's password in the connect # # protocol against this token. # # # # token: "some_token" # # # Time allowed for the client to send the MQTT connect protocol # # after the TCP connection is established. # # # timeout: 2.0 #} # If an MQTT client connects and does not provide a username/password and # this option is set, the server will use this client (and therefore account). # # no_auth_user: "some_user_name" # This is the time after which the server will redeliver a QoS 1 message # sent to a subscription that has not acknowledged (PUBACK) the message. # The default is 30 seconds. # # ack_wait: "1m" # This limits the number of QoS1 messages sent to a session without receiving # acknowledgement (PUBACK) from that session. MQTT specification defines # a packet identifier as an unsigned int 16, which means that the maximum # value is 65535. The default value is 1024. # # max_ack_pending: 100 } ``` Signed-off-by: Ivan Kozlovic --- server/accounts.go | 6 +- server/auth.go | 33 +- server/client.go | 67 +- server/client_test.go | 19 +- server/config_check_test.go | 78 +- server/consumer.go | 26 +- server/events.go | 3 +- server/jetstream.go | 13 +- server/monitor.go | 2 + server/mqtt.go | 2515 +++++++++++++++++++++ server/mqtt_test.go | 4111 +++++++++++++++++++++++++++++++++++ server/opts.go | 136 +- server/parser.go | 5 + server/reload.go | 41 +- server/server.go | 113 +- server/server_test.go | 5 + server/stream.go | 28 +- server/sublist.go | 62 + server/sublist_test.go | 68 + server/websocket.go | 3 +- server/websocket_test.go | 4 + 21 files changed, 7242 insertions(+), 96 deletions(-) create mode 100644 server/mqtt.go create mode 100644 server/mqtt_test.go diff --git a/server/accounts.go b/server/accounts.go index 9a68c032..02d32028 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -1726,7 +1726,7 @@ func (a *Account) subscribeInternal(subject string, cb msgHandler) (*subscriptio return nil, fmt.Errorf("no internal account client") } - return c.processSub([]byte(subject), nil, []byte(sid), cb, false) + return c.processSub(c.createSub([]byte(subject), nil, []byte(sid), cb), false) } // This will add an account subscription that matches the "from" from a service import entry. @@ -1751,7 +1751,7 @@ func (a *Account) addServiceImportSub(si *serviceImport) error { cb := func(sub *subscription, c *client, subject, reply string, msg []byte) { c.processServiceImport(si, a, msg) } - _, err := c.processSub([]byte(subject), nil, []byte(sid), cb, true) + _, err := c.processSub(c.createSub([]byte(subject), nil, []byte(sid), cb), true) return err } @@ -1951,7 +1951,7 @@ func (a *Account) createRespWildcard() []byte { a.mu.Unlock() // Create subscription and internal callback for all the wildcard response subjects. - c.processSub(wcsub, nil, []byte(sid), a.processServiceImportResponse, false) + c.processSub(c.createSub(wcsub, nil, []byte(sid), a.processServiceImportResponse), false) return pre } diff --git a/server/auth.go b/server/auth.go index 7043b2cc..da27dc7e 100644 --- a/server/auth.go +++ b/server/auth.go @@ -253,6 +253,8 @@ func (s *Server) configureAuthorization() { // Do similar for websocket config s.wsConfigAuth(&opts.Websocket) + // And for mqtt config + s.mqttConfigAuth(&opts.MQTT) } // Takes the given slices of NkeyUser and User options and build @@ -343,11 +345,15 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) boo ) s.mu.Lock() authRequired := s.info.AuthRequired - // c.ws is immutable, but may need lock if we get race reports. - if !authRequired && c.ws != nil { - // If no auth required for regular clients, then check if - // we have an override for websocket clients. - authRequired = s.websocket.authOverride + // c.ws/mqtt is immutable, but may need lock if we get race reports. + if !authRequired { + if c.mqtt != nil { + authRequired = s.mqtt.authOverride + } else if c.ws != nil { + // If no auth required for regular clients, then check if + // we have an override for websocket clients. + authRequired = s.websocket.authOverride + } } if !authRequired { // TODO(dlc) - If they send us credentials should we fail? @@ -361,7 +367,20 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) boo noAuthUser string ) tlsMap := opts.TLSMap - if c.ws != nil { + if c.mqtt != nil { + mo := &opts.MQTT + // Always override TLSMap. + tlsMap = mo.TLSMap + // The rest depends on if there was any auth override in + // the mqtt's config. + if s.mqtt.authOverride { + noAuthUser = mo.NoAuthUser + username = mo.Username + password = mo.Password + token = mo.Token + ao = true + } + } else if c.ws != nil { wo := &opts.Websocket // Always override TLSMap. tlsMap = wo.TLSMap @@ -998,7 +1017,7 @@ func validateAllowedConnectionTypes(m map[string]struct{}) error { for ct := range m { ctuc := strings.ToUpper(ct) switch ctuc { - case jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket, jwt.ConnectionTypeLeafnode: + case jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket, jwt.ConnectionTypeLeafnode, jwt.ConnectionTypeMqtt: default: return fmt.Errorf("unknown connection type %q", ct) } diff --git a/server/client.go b/server/client.go index 6b9bb532..3a68b9d2 100644 --- a/server/client.go +++ b/server/client.go @@ -178,6 +178,7 @@ const ( NoRespondersRequiresHeaders ClusterNameConflict DuplicateRemoteLeafnodeConnection + DuplicateClientID ) // Some flags passed to processMsgResults @@ -237,6 +238,7 @@ type client struct { gw *gateway leaf *leaf ws *websocket + mqtt *mqtt // To keep track of gateway replies mapping gwrm map[string]*gwReplyMap @@ -442,6 +444,7 @@ type subscription struct { max int64 qw int32 closed int32 + mqtt *mqttSub } // Indicate that this subscription is closed. @@ -540,6 +543,8 @@ func (c *client) initClient() { name := "cid" if c.ws != nil { name = "wid" + } else if c.mqtt != nil { + name = "mid" } c.ncs.Store(fmt.Sprintf("%s - %s:%d", conn, name, c.cid)) case ROUTER: @@ -964,6 +969,9 @@ func (c *client) readLoop(pre []byte) { } nc := c.nc ws := c.ws != nil + if c.mqtt != nil { + c.mqtt.r = &mqttReader{reader: nc} + } c.in.rsz = startBufSize // Snapshot max control line since currently can not be changed on reload and we // were checking it on each call to parse. If this changes and we allow MaxControlLine @@ -983,6 +991,9 @@ func (c *client) readLoop(pre []byte) { c.mu.Unlock() defer func() { + if c.mqtt != nil { + s.mqttHandleWill(c) + } // These are used only in the readloop, so we can set them to nil // on exit of the readLoop. c.in.results, c.in.pacache = nil, nil @@ -1683,7 +1694,6 @@ func (c *client) processConnect(arg []byte) error { // By default register with the global account. c.registerWithAccount(srv.globalAccount()) } - } switch kind { @@ -1703,7 +1713,6 @@ func (c *client) processConnect(arg []byte) error { c.sendErr(ErrNoRespondersRequiresHeaders.Error()) c.closeConnection(NoRespondersRequiresHeaders) return ErrNoRespondersRequiresHeaders - } if verbose { c.sendOK() @@ -1933,6 +1942,9 @@ func (c *client) sendRTTPing() bool { // the c.rtt is 0 and wants to force an update by sending a PING. // Client lock held on entry. func (c *client) sendRTTPingLocked() bool { + if c.mqtt != nil { + return false + } // Most client libs send a CONNECT+PING and wait for a PONG from the // server. So if firstPongSent flag is set, it is ok for server to // send the PING. But in case we have client libs that don't do that, @@ -1978,7 +1990,9 @@ func (c *client) sendErr(err string) { if c.trace { c.traceOutOp("-ERR", []byte(err)) } - c.enqueueProto([]byte(fmt.Sprintf(errProto, err))) + if c.mqtt == nil { + c.enqueueProto([]byte(fmt.Sprintf(errProto, err))) + } c.mu.Unlock() } @@ -2212,32 +2226,30 @@ func (c *client) parseSub(argo []byte, noForward bool) error { arg := make([]byte, len(argo)) copy(arg, argo) args := splitArg(arg) - var ( - subject []byte - queue []byte - sid []byte - ) + sub := &subscription{client: c} switch len(args) { case 2: - subject = args[0] - queue = nil - sid = args[1] + sub.subject = args[0] + sub.queue = nil + sub.sid = args[1] case 3: - subject = args[0] - queue = args[1] - sid = args[2] + sub.subject = args[0] + sub.queue = args[1] + sub.sid = args[2] default: return fmt.Errorf("processSub Parse Error: '%s'", arg) } // If there was an error, it has been sent to the client. We don't return an // error here to not close the connection as a parsing error. - c.processSub(subject, queue, sid, nil, noForward) + c.processSub(sub, noForward) return nil } -func (c *client) processSub(subject, queue, bsid []byte, cb msgHandler, noForward bool) (*subscription, error) { - // Create the subscription - sub := &subscription{client: c, subject: subject, queue: queue, sid: bsid, icb: cb} +func (c *client) createSub(subject, queue, sid []byte, cb msgHandler) *subscription { + return &subscription{client: c, subject: subject, queue: queue, sid: sid, icb: cb} +} + +func (c *client) processSub(sub *subscription, noForward bool) (*subscription, error) { c.mu.Lock() @@ -2254,7 +2266,7 @@ func (c *client) processSub(subject, queue, bsid []byte, cb msgHandler, noForwar // This check does not apply to SYSTEM or JETSTREAM or ACCOUNT clients (because they don't have a `nc`...) if c.isClosed() && (kind != SYSTEM && kind != JETSTREAM && kind != ACCOUNT) { c.mu.Unlock() - return sub, nil + return nil, nil } // Check permissions if applicable. @@ -2301,6 +2313,9 @@ func (c *client) processSub(subject, queue, bsid []byte, cb msgHandler, noForwar updateGWs = c.srv.gateway.enabled } } + } else if es.mqtt != nil && sub.mqtt != nil { + es.mqtt.prm = sub.mqtt.prm + es.mqtt.qos = sub.mqtt.qos } // Unlocked from here onward c.mu.Unlock() @@ -3311,6 +3326,11 @@ func (c *client) processInboundClientMsg(msg []byte) bool { c.sendOK() } + // If MQTT client, check for retain flag now that we have passed permissions check + if c.mqtt != nil { + c.mqttHandlePubRetain() + } + // Check if this client's gateway replies map is not empty if atomic.LoadInt32(&c.cgwrt) > 0 && c.handleGWReplyMap(msg) { return true @@ -3983,7 +4003,7 @@ func (c *client) processPingTimer() { c.mu.Lock() c.ping.tmr = nil // Check if connection is still opened - if c.isClosed() { + if c.isClosed() || c.mqtt != nil { c.mu.Unlock() return } @@ -4042,7 +4062,7 @@ func adjustPingIntervalForGateway(d time.Duration) time.Duration { // Lock should be held func (c *client) setPingTimer() { - if c.srv == nil { + if c.srv == nil || c.mqtt != nil { return } d := c.srv.getOpts().PingInterval @@ -4596,7 +4616,7 @@ func convertAllowedConnectionTypes(cts []string) (map[string]struct{}, error) { for _, i := range cts { i = strings.ToUpper(i) switch i { - case jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket, jwt.ConnectionTypeLeafnode: + case jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket, jwt.ConnectionTypeLeafnode, jwt.ConnectionTypeMqtt: m[i] = struct{}{} default: unknown = append(unknown, i) @@ -4627,6 +4647,9 @@ func (c *client) connectionTypeAllowed(acts map[string]struct{}) bool { if c.ws != nil { want = jwt.ConnectionTypeWebsocket } + if c.mqtt != nil { + want = jwt.ConnectionTypeMqtt + } _, ok := acts[want] return ok } diff --git a/server/client_test.go b/server/client_test.go index f207854e..68e48dcc 100644 --- a/server/client_test.go +++ b/server/client_test.go @@ -85,7 +85,7 @@ func createClientAsync(ch chan *client, s *Server, cli net.Conn) { s.grWG.Add(1) } go func() { - c := s.createClient(cli, nil) + c := s.createClient(cli, nil, nil) // Must be here to suppress +OK c.opts.Verbose = false if startWriteLoop { @@ -2317,7 +2317,7 @@ func TestCloseConnectionVeryEarly(t *testing.T) { // Call again with this closed connection. Alternatively, we // would have to call with a fake connection that implements // net.Conn but returns an error on Write. - s.createClient(c, nil) + s.createClient(c, nil, nil) // This connection should not have been added to the server. checkClientsCount(t, s, 0) @@ -2354,18 +2354,23 @@ func TestClientConnectionName(t *testing.T) { kind int kindStr string ws bool + mqtt bool }{ - {"client", CLIENT, "cid:", false}, - {"ws client", CLIENT, "wid:", true}, - {"route", ROUTER, "rid:", false}, - {"gateway", GATEWAY, "gid:", false}, - {"leafnode", LEAF, "lid:", false}, + {"client", CLIENT, "cid:", false, false}, + {"ws client", CLIENT, "wid:", true, false}, + {"mqtt client", CLIENT, "mid:", false, true}, + {"route", ROUTER, "rid:", false, false}, + {"gateway", GATEWAY, "gid:", false, false}, + {"leafnode", LEAF, "lid:", false, false}, } { t.Run(test.name, func(t *testing.T) { c := &client{srv: s, nc: &connString{}, kind: test.kind} if test.ws { c.ws = &websocket{} } + if test.mqtt { + c.mqtt = &mqtt{} + } c.initClient() if host := "fe80::abc:def:ghi:123%utun0"; host != c.host { diff --git a/server/config_check_test.go b/server/config_check_test.go index 166aa761..45d726e9 100644 --- a/server/config_check_test.go +++ b/server/config_check_test.go @@ -1301,8 +1301,8 @@ func TestConfigCheck(t *testing.T) { user: user1 password: pwd users = [{user: user2, password: pwd}] - } - } + } + } `, err: errors.New("can not have a single user/pass and a users array"), errorLine: 3, @@ -1312,17 +1312,75 @@ func TestConfigCheck(t *testing.T) { name: "duplicate usernames in leafnode authorization", config: ` leafnodes { - authorization { - users = [ - {user: user, password: pwd} - {user: user, password: pwd} - ] - } - } + authorization { + users = [ + {user: user, password: pwd} + {user: user, password: pwd} + ] + } + } `, err: errors.New(`duplicate user "user" detected in leafnode authorization`), errorLine: 3, - errorPos: 20, + errorPos: 21, + }, + { + name: "mqtt bad type", + config: ` + mqtt [ + "wrong" + ] + `, + err: errors.New(`Expected mqtt to be a map, got []interface {}`), + errorLine: 2, + errorPos: 17, + }, + { + name: "mqtt bad listen", + config: ` + mqtt { + listen: "xxxxxxxx" + } + `, + err: errors.New(`could not parse address string "xxxxxxxx"`), + errorLine: 3, + errorPos: 21, + }, + { + name: "mqtt bad host", + config: ` + mqtt { + host: 1234 + } + `, + err: errors.New(`interface conversion: interface {} is int64, not string`), + errorLine: 3, + errorPos: 21, + }, + { + name: "mqtt bad port", + config: ` + mqtt { + port: "abc" + } + `, + err: errors.New(`interface conversion: interface {} is string, not int64`), + errorLine: 3, + errorPos: 21, + }, + { + name: "mqtt bad TLS", + config: ` + mqtt { + port: -1 + tls { + cert_file: "./configs/certs/server.pem" + } + } + `, + err: errors.New(`missing 'key_file' in TLS configuration`), + errorLine: 4, + errorPos: 21, }, { name: "connection types wrong type", diff --git a/server/consumer.go b/server/consumer.go index edd618a4..ef0e71a4 100644 --- a/server/consumer.go +++ b/server/consumer.go @@ -213,6 +213,13 @@ const ( ) func (mset *Stream) AddConsumer(config *ConsumerConfig) (*Consumer, error) { + if name := mset.Name(); strings.HasPrefix(name, mqttStreamNamePrefix) { + return nil, fmt.Errorf("stream prefix %q is reserved for MQTT, unable to create consumer on %q", mqttStreamNamePrefix, name) + } + return mset.addConsumerCheckInterest(config, true) +} + +func (mset *Stream) addConsumerCheckInterest(config *ConsumerConfig, checkInterest bool) (*Consumer, error) { if config == nil { return nil, fmt.Errorf("consumer config required") } @@ -350,7 +357,7 @@ func (mset *Stream) AddConsumer(config *ConsumerConfig) (*Consumer, error) { } else { // If we are a push mode and not active and the only difference // is deliver subject then update and return. - if configsEqualSansDelivery(ocfg, *config) && eo.hasNoLocalInterest() { + if configsEqualSansDelivery(ocfg, *config) && (!checkInterest || eo.hasNoLocalInterest()) { eo.updateDeliverSubject(config.DeliverSubject) return eo, nil } else { @@ -2177,9 +2184,22 @@ func (mset *Stream) deliveryFormsCycle(deliverySubject string) bool { return false } -// This is same as check for delivery cycle. +// Check that the subject is a subset of the stream's configured subjects, +// or returns true if the stream has been created with no subject. func (mset *Stream) validSubject(partitionSubject string) bool { - return mset.deliveryFormsCycle(partitionSubject) + mset.mu.RLock() + defer mset.mu.RUnlock() + + if mset.nosubj && len(mset.config.Subjects) == 0 { + return true + } + + for _, subject := range mset.config.Subjects { + if subjectIsSubsetMatch(partitionSubject, subject) { + return true + } + } + return false } // SetInActiveDeleteThreshold sets the delete threshold for how long to wait diff --git a/server/events.go b/server/events.go index 449198cd..3e9c84dc 100644 --- a/server/events.go +++ b/server/events.go @@ -1410,7 +1410,8 @@ func (s *Server) systemSubscribe(subject, queue string, internalOnly bool, cb ms if queue != "" { q = []byte(queue) } - return c.processSub([]byte(subject), q, []byte(sid), cb, internalOnly) + // Now create the subscription + return c.processSub(c.createSub([]byte(subject), q, []byte(sid), cb), internalOnly) } func (s *Server) sysUnsubscribe(sub *subscription) { diff --git a/server/jetstream.go b/server/jetstream.go index 9a591471..0c651a01 100644 --- a/server/jetstream.go +++ b/server/jetstream.go @@ -533,7 +533,12 @@ func (a *Account) EnableJetStream(limits *JetStreamAccountLimits) error { s.Warnf(" Error adding Stream %q to Template %q: %v", cfg.Name, cfg.Template, err) } } - mset, err := a.AddStream(&cfg.StreamConfig) + // TODO: We should not rely on the stream name. + // However, having a StreamConfig property, such as AllowNoSubject, + // was not accepted because it does not make sense outside of the + // MQTT use-case. So need to revisit this. + mqtt := cfg.StreamConfig.Name == mqttStreamName + mset, err := a.addStreamWithStore(&cfg.StreamConfig, nil, mqtt) if err != nil { s.Warnf(" Error recreating Stream %q: %v", cfg.Name, err) continue @@ -578,7 +583,7 @@ func (a *Account) EnableJetStream(limits *JetStreamAccountLimits) error { // the consumer can reconnect. We will create it as a durable and switch it. cfg.ConsumerConfig.Durable = ofi.Name() } - obs, err := mset.AddConsumer(&cfg.ConsumerConfig) + obs, err := mset.addConsumerCheckInterest(&cfg.ConsumerConfig, !mqtt) if err != nil { s.Warnf(" Error adding Consumer: %v", err) continue @@ -1060,7 +1065,7 @@ func (a *Account) AddStreamTemplate(tc *StreamTemplateConfig) (*StreamTemplate, // FIXME(dlc) - Hacky tcopy := tc.deepCopy() tcopy.Config.Name = "_" - cfg, err := checkStreamCfg(tcopy.Config) + cfg, err := checkStreamCfg(tcopy.Config, false) if err != nil { return nil, err } @@ -1115,7 +1120,7 @@ func (t *StreamTemplate) createTemplateSubscriptions() error { sid := 1 for _, subject := range t.Config.Subjects { // Now create the subscription - if _, err := c.processSub([]byte(subject), nil, []byte(strconv.Itoa(sid)), t.processInboundTemplateMsg, false); err != nil { + if _, err := c.processSub(c.createSub([]byte(subject), nil, []byte(strconv.Itoa(sid)), t.processInboundTemplateMsg), false); err != nil { c.acc.DeleteStreamTemplate(t.Name) return err } diff --git a/server/monitor.go b/server/monitor.go index 4bae0fda..6c941f2a 100644 --- a/server/monitor.go +++ b/server/monitor.go @@ -1915,6 +1915,8 @@ func (reason ClosedState) String() string { return "Cluster Name Conflict" case DuplicateRemoteLeafnodeConnection: return "Duplicate Remote LeafNode Connection" + case DuplicateClientID: + return "Duplicate Client ID" } return "Unknown State" diff --git a/server/mqtt.go b/server/mqtt.go new file mode 100644 index 00000000..4c508b18 --- /dev/null +++ b/server/mqtt.go @@ -0,0 +1,2515 @@ +// Copyright 2020 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "strconv" + "strings" + "sync" + "time" + "unicode/utf8" + + "github.com/nats-io/nuid" +) + +// References to "spec" here is from https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.pdf + +const ( + mqttPacketConnect = byte(0x10) + mqttPacketConnectAck = byte(0x20) + mqttPacketPub = byte(0x30) + mqttPacketPubAck = byte(0x40) + mqttPacketPubRec = byte(0x50) + mqttPacketPubRel = byte(0x60) + mqttPacketPubComp = byte(0x70) + mqttPacketSub = byte(0x80) + mqttPacketSubAck = byte(0x90) + mqttPacketUnsub = byte(0xa0) + mqttPacketUnsubAck = byte(0xb0) + mqttPacketPing = byte(0xc0) + mqttPacketPingResp = byte(0xd0) + mqttPacketDisconnect = byte(0xe0) + mqttPacketMask = byte(0xf0) + mqttPacketFlagMask = byte(0x0f) + + mqttProtoLevel = byte(0x4) + + // Connect flags + mqttConnFlagReserved = byte(0x1) + mqttConnFlagCleanSession = byte(0x2) + mqttConnFlagWillFlag = byte(0x04) + mqttConnFlagWillQoS = byte(0x18) + mqttConnFlagWillRetain = byte(0x20) + mqttConnFlagPasswordFlag = byte(0x40) + mqttConnFlagUsernameFlag = byte(0x80) + + // Publish flags + mqttPubFlagRetain = byte(0x01) + mqttPubFlagQoS = byte(0x06) + mqttPubFlagDup = byte(0x08) + mqttPubQos1 = byte(0x2) // 1 << 1 + + // Subscribe flags + mqttSubscribeFlags = byte(0x2) + mqttSubAckFailure = byte(0x80) + + // Unsubscribe flags + mqttUnsubscribeFlags = byte(0x2) + + // ConnAck returned codes + mqttConnAckRCConnectionAccepted = byte(0x0) + mqttConnAckRCUnacceptableProtocolVersion = byte(0x1) + mqttConnAckRCIdentifierRejected = byte(0x2) + mqttConnAckRCServerUnavailable = byte(0x3) + mqttConnAckRCBadUserOrPassword = byte(0x4) + mqttConnAckRCNotAuthorized = byte(0x5) + + // Topic/Filter characters + mqttTopicLevelSep = '/' + mqttSingleLevelWC = '+' + mqttMultiLevelWC = '#' + + // This is appended to the sid of a subscription that is + // created on the upper level subject because of the MQTT + // wildcard '#' semantic. + mqttMultiLevelSidSuffix = " fwc" + + // This is the prefix for NATS subscriptions subjects associated as delivery + // subject of JS consumer. We want to make them unique so will prevent users + // MQTT subscriptions to start with this. + mqttSubPrefix = "$MQTT.sub." + + // MQTT Stream names prefix. Will be used to prevent users from creating + // JS consumers on those. + mqttStreamNamePrefix = "$MQTT_" + + // Stream name for MQTT messages on a given account + mqttStreamName = mqttStreamNamePrefix + "msgs" + + // Stream name for MQTT retained messages on a given account + mqttRetainedMsgsStreamName = mqttStreamNamePrefix + "rmsgs" + + // Stream name for MQTT sessions on a given account + mqttSessionsStreamName = mqttStreamNamePrefix + "sessions" + + // Normally, MQTT server should not redeliver QoS 1 messages to clients, + // except after client reconnects. However, NATS Server will redeliver + // unacknowledged messages after this default interval. This can be + // changed with the server.Options.MQTT.AckWait option. + mqttDefaultAckWait = 30 * time.Second + + // This is the default for the outstanding number of pending QoS 1 + // messages sent to a session with QoS 1 subscriptions. + mqttDefaultMaxAckPending = 1024 +) + +var ( + mqttPingResponse = []byte{mqttPacketPingResp, 0x0} + mqttProtoName = []byte("MQTT") + mqttOldProtoName = []byte("MQIsdp") +) + +type srvMQTT struct { + listener net.Listener + authOverride bool + sessmgr mqttSessionManager +} + +type mqttSessionManager struct { + mu sync.RWMutex + sessions map[string]*mqttAccountSessionManager // key is account name +} + +type mqttAccountSessionManager struct { + mu sync.RWMutex + sstream *Stream // stream where sessions are recorded + mstream *Stream // messages stream + rstream *Stream // retained messages stream + sessions map[string]*mqttSession // key is MQTT client ID + sl *Sublist // sublist allowing to find retained messages for given subscription + retmsgs map[string]*mqttRetainedMsg // retained messages +} + +type mqttSession struct { + mu sync.Mutex + c *client + subs map[string]byte + cons map[string]*Consumer + stream *Stream + sseq uint64 // stream sequence where this sesion is recorded + pending map[uint16]*mqttPending // Key is the PUBLISH packet identifier sent to client and maps to a mqttPending record + cpending map[*Consumer]map[uint64]uint16 // For each JS consumer, the key is the stream sequence and maps to the PUBLISH packet identifier + ppi uint16 // publish packet identifier + maxp uint16 + stalled bool + clean bool +} + +type mqttPersistedSession struct { + ID string `json:"id,omitempty"` + Clean bool `json:"clean,omitempty"` + Subs map[string]byte `json:"subs,omitempty"` + Cons map[string]string `json:"cons,omitempty"` +} + +type mqttRetainedMsg struct { + Msg []byte `json:"msg,omitempty"` + Flags byte `json:"flags,omitempty"` + Source string `json:"source,omitempty"` + + // non exported + sseq uint64 + sub *subscription +} + +type mqttSub struct { + qos byte + // Pending serialization of retained messages to be sent when subscription is registered + prm *mqttWriter + // This is the corresponding JS consumer. This is applicable to a subscription that is + // done for QoS > 0 (the subscription attached to a JS consumer's delivery subject). + jsCons *Consumer +} + +type mqtt struct { + r *mqttReader + cp *mqttConnectProto + pp *mqttPublish + asm *mqttAccountSessionManager // quick reference to account session manager, immutable after processConnect() + sess *mqttSession // quick reference to session, immutable after processConnect() +} + +type mqttPending struct { + sseq uint64 // stream sequence + dseq uint64 // consumer delivery sequence + dcount uint64 // consumer delivery count + jsCons *Consumer // pointer to JS consumer (to which we will call ackMsg() on with above info) +} + +type mqttConnectProto struct { + clientID string + rd time.Duration + will *mqttWill + flags byte +} + +type mqttIOReader interface { + io.Reader + SetReadDeadline(time.Time) error +} + +type mqttReader struct { + reader mqttIOReader + buf []byte + pos int +} + +type mqttWriter struct { + bytes.Buffer +} + +type mqttWill struct { + topic []byte + message []byte + qos byte + retain bool +} + +type mqttFilter struct { + filter string + qos byte +} + +type mqttPublish struct { + subject []byte + msg []byte + sz int + pi uint16 + flags byte +} + +func (s *Server) startMQTT() { + sopts := s.getOpts() + o := &sopts.MQTT + + var hl net.Listener + var err error + + port := o.Port + if port == -1 { + port = 0 + } + hp := net.JoinHostPort(o.Host, strconv.Itoa(port)) + s.mu.Lock() + if s.shutdown { + s.mu.Unlock() + return + } + s.mqtt.sessmgr.sessions = make(map[string]*mqttAccountSessionManager) + hl, err = net.Listen("tcp", hp) + if err != nil { + s.mu.Unlock() + s.Fatalf("Unable to listen for MQTT connections: %v", err) + return + } + if port == 0 { + o.Port = hl.Addr().(*net.TCPAddr).Port + } + s.mqtt.listener = hl + scheme := "mqtt" + if o.TLSConfig != nil { + scheme = "tls" + } + s.Noticef("Listening for MQTT clients on %s://%s:%d", scheme, o.Host, o.Port) + go s.acceptConnections(hl, "MQTT", func(conn net.Conn) { s.createClient(conn, nil, &mqtt{}) }, nil) + s.mu.Unlock() +} + +// Given the mqtt options, we check if any auth configuration +// has been provided. If so, possibly create users/nkey users and +// store them in s.mqtt.users/nkeys. +// Also update a boolean that indicates if auth is required for +// mqtt clients. +// Server lock is held on entry. +func (s *Server) mqttConfigAuth(opts *MQTTOpts) { + mqtt := &s.mqtt + // If any of those is specified, we consider that there is an override. + mqtt.authOverride = opts.Username != _EMPTY_ || opts.Token != _EMPTY_ || opts.NoAuthUser != _EMPTY_ +} + +// Validate the mqtt related options. +func validateMQTTOptions(o *Options) error { + mo := &o.MQTT + // If no port is defined, we don't care about other options + if mo.Port == 0 { + return nil + } + // If there is a NoAuthUser, we need to have Users defined and + // the user to be present. + if mo.NoAuthUser != _EMPTY_ { + if err := validateNoAuthUser(o, mo.NoAuthUser); err != nil { + return err + } + } + // Token/Username not possible if there are users/nkeys + if len(o.Users) > 0 || len(o.Nkeys) > 0 { + if mo.Username != _EMPTY_ { + return fmt.Errorf("mqtt authentication username not compatible with presence of users/nkeys") + } + if mo.Token != _EMPTY_ { + return fmt.Errorf("mqtt authentication token not compatible with presence of users/nkeys") + } + } + if mo.AckWait < 0 { + return fmt.Errorf("ack wait must be a positive value") + } + return nil +} + +// Parse protocols inside the given buffer. +// This is invoked from the readLoop. +func (c *client) mqttParse(buf []byte) error { + c.mu.Lock() + s := c.srv + trace := c.trace + connected := c.flags.isSet(connectReceived) + mqtt := c.mqtt + r := mqtt.r + var rd time.Duration + if mqtt.cp != nil { + rd = mqtt.cp.rd + if rd > 0 { + r.reader.SetReadDeadline(time.Time{}) + } + } + c.mu.Unlock() + + r.reset(buf) + + var err error + var b byte + var pl int + + for err == nil && r.hasMore() { + + // Read packet type and flags + if b, err = r.readByte("packet type"); err != nil { + break + } + + // Packet type + pt := b & mqttPacketMask + + // If client was not connected yet, the first packet must be + // a mqttPacketConnect otherwise we fail the connection. + if !connected && pt != mqttPacketConnect { + err = errors.New("not connected") + break + } + + if pl, err = r.readPacketLen(); err != nil { + break + } + + switch pt { + case mqttPacketPub: + pp := mqttPublish{flags: b & mqttPacketFlagMask} + err = c.mqttParsePub(r, pl, &pp) + if trace { + c.traceInOp("PUBLISH", errOrTrace(err, mqttPubTrace(&pp))) + if err == nil { + c.traceMsg(pp.msg) + } + } + if err == nil { + s.mqttProcessPub(c, &pp) + if pp.pi > 0 { + c.mqttEnqueuePubAck(pp.pi) + if trace { + c.traceOutOp("PUBACK", []byte(fmt.Sprintf("pi=%v", pp.pi))) + } + } + } + case mqttPacketPubAck: + var pi uint16 + pi, err = mqttParsePubAck(r, pl) + if trace { + c.traceInOp("PUBACK", errOrTrace(err, fmt.Sprintf("pi=%v", pi))) + } + if err == nil { + c.mqttProcessPubAck(pi) + } + case mqttPacketSub: + var pi uint16 // packet identifier + var filters []*mqttFilter + var subs []*subscription + pi, filters, err = c.mqttParseSubs(r, b, pl) + if trace { + c.traceInOp("SUBSCRIBE", errOrTrace(err, mqttSubscribeTrace(filters))) + } + if err == nil { + subs, err = c.mqttProcessSubs(filters) + if err == nil && trace { + c.traceOutOp("SUBACK", []byte(mqttSubscribeTrace(filters))) + } + } + if err == nil { + c.mqttEnqueueSubAck(pi, filters) + c.mqttSendRetainedMsgsToNewSubs(subs) + } + case mqttPacketUnsub: + var pi uint16 // packet identifier + var filters []*mqttFilter + pi, filters, err = c.mqttParseUnsubs(r, b, pl) + if trace { + c.traceInOp("UNSUBSCRIBE", errOrTrace(err, mqttUnsubscribeTrace(filters))) + } + if err == nil { + err = c.mqttProcessUnsubs(filters) + if err == nil && trace { + c.traceOutOp("UNSUBACK", []byte(strconv.FormatInt(int64(pi), 10))) + } + } + if err == nil { + c.mqttEnqueueUnsubAck(pi) + } + case mqttPacketPing: + if trace { + c.traceInOp("PINGREQ", nil) + } + c.mqttEnqueuePingResp() + if trace { + c.traceOutOp("PINGRESP", nil) + } + case mqttPacketConnect: + // It is an error to receive a second connect packet + if connected { + err = errors.New("second connect packet") + break + } + var rc byte + var cp *mqttConnectProto + var sessp bool + rc, cp, err = c.mqttParseConnect(r, pl) + if trace && cp != nil { + c.traceInOp("CONNECT", errOrTrace(err, c.mqttConnectTrace(cp))) + } + if rc != 0 { + c.mqttEnqueueConnAck(rc, sessp) + if trace { + c.traceOutOp("CONNACK", []byte(fmt.Sprintf("sp=%v rc=%v", sessp, rc))) + } + } else if err == nil { + if err = s.mqttProcessConnect(c, cp, trace); err == nil { + connected = true + rd = cp.rd + } + } + case mqttPacketDisconnect: + if trace { + c.traceInOp("DISCONNECT", nil) + } + // Normal disconnect, we need to discard the will. + // Spec [MQTT-3.1.2-8] + c.mu.Lock() + if c.mqtt.cp != nil { + c.mqtt.cp.will = nil + } + c.mu.Unlock() + c.closeConnection(ClientClosed) + return nil + case mqttPacketPubRec: + fallthrough + case mqttPacketPubRel: + fallthrough + case mqttPacketPubComp: + err = fmt.Errorf("protocol %d not supported", pt>>4) + default: + err = fmt.Errorf("received unknown packet type %d", pt>>4) + } + } + if err == nil && rd > 0 { + r.reader.SetReadDeadline(time.Now().Add(rd)) + } + return err +} + +// Update the session (possibly remove it) of this disconnected client. +func (s *Server) mqttHandleClosedClient(c *client) { + c.mu.Lock() + cp := c.mqtt.cp + accName := c.acc.Name + c.mu.Unlock() + if cp == nil { + return + } + sm := &s.mqtt.sessmgr + sm.mu.RLock() + asm, ok := sm.sessions[accName] + sm.mu.RUnlock() + if !ok { + return + } + + asm.mu.Lock() + defer asm.mu.Unlock() + es, ok := asm.sessions[cp.clientID] + // If not found, ignore. + if !ok { + return + } + es.mu.Lock() + defer es.mu.Unlock() + // If this client is not the currently registered client, ignore. + if es.c != c { + return + } + // It the session was created with "clean session" flag, we cleanup + // when the client disconnects. + if es.clean { + es.clear() + delete(asm.sessions, cp.clientID) + } else { + // Clear the client from the session, but session stays. + es.c = nil + } +} + +// Updates the MaxAckPending for all MQTT sessions, updating the +// JetStream consumers and updating their max ack pending and forcing +// a expiration of pending messages. +func (s *Server) mqttUpdateMaxAckPending(newmaxp uint16) { + msm := &s.mqtt.sessmgr + s.accounts.Range(func(k, _ interface{}) bool { + accName := k.(string) + msm.mu.RLock() + asm := msm.sessions[accName] + msm.mu.RUnlock() + if asm == nil { + // Move to next account + return true + } + asm.mu.RLock() + for _, sess := range asm.sessions { + sess.mu.Lock() + sess.maxp = newmaxp + // Do not check for sess.stalled here because due to original + // consumer's maxp it is possible that the consumer was stalled + // and MQTT code did not have to stall the session. + for _, cons := range sess.cons { + cons.mu.Lock() + cons.maxp = int(newmaxp) + cons.forceExpirePending() + cons.mu.Unlock() + } + sess.mu.Unlock() + } + asm.mu.RUnlock() + return true + }) +} + +////////////////////////////////////////////////////////////////////////////// +// +// Sessions Manager related functions +// +////////////////////////////////////////////////////////////////////////////// + +// Returns the MQTT sessions manager for a given account. +// If new, creates the required JetStream streams/consumers +// for handling of sessions and messages. +func (sm *mqttSessionManager) getOrCreateAccountSessionManager(clientID string, c *client) (*mqttAccountSessionManager, error) { + c.mu.Lock() + acc := c.acc + accName := acc.GetName() + c.mu.Unlock() + + sm.mu.RLock() + asm, ok := sm.sessions[accName] + sm.mu.RUnlock() + + if ok { + return asm, nil + } + + // Not found, now take the write lock and check again + sm.mu.Lock() + defer sm.mu.Unlock() + asm, ok = sm.sessions[accName] + if ok { + return asm, nil + } + // First check that we have JS enabled for this account. + // TODO: Since we check only when creating a session manager for this + // account, would probably need to do some cleanup if JS can be disabled + // on config reload. + if !acc.JetStreamEnabled() { + return nil, fmt.Errorf("JetStream must be enabled for account %q used by MQTT client ID %q", + accName, clientID) + } + // Need to create one here. + asm = &mqttAccountSessionManager{sessions: make(map[string]*mqttSession)} + if err := asm.init(acc, c); err != nil { + return nil, err + } + sm.sessions[accName] = asm + return asm, nil +} + +////////////////////////////////////////////////////////////////////////////// +// +// Account Sessions Manager related functions +// +////////////////////////////////////////////////////////////////////////////// + +// Creates JS streams/consumers for handling of sessions and messages for this +// account. Note that lookup are performed in case we are in a restart +// situation and the account is loaded for the first time but had state on disk. +// +// Global session manager lock is held on entry. +func (as *mqttAccountSessionManager) init(acc *Account, c *client) error { + opts := c.srv.getOpts() + var err error + // Start with sessions stream + as.sstream, err = acc.LookupStream(mqttSessionsStreamName) + if err != nil { + as.sstream, err = acc.addStreamWithStore(&StreamConfig{ + Subjects: []string{}, + Name: mqttSessionsStreamName, + Storage: FileStorage, + Retention: InterestPolicy, + }, nil, true) + if err != nil { + return fmt.Errorf("unable to create sessions stream for MQTT account %q: %v", acc.GetName(), err) + } + } + // Create the stream for the messages. + as.mstream, err = acc.LookupStream(mqttStreamName) + if err != nil { + as.mstream, err = acc.addStreamWithStore(&StreamConfig{ + Subjects: []string{}, + Name: mqttStreamName, + Storage: FileStorage, + Retention: InterestPolicy, + }, nil, true) + if err != nil { + return fmt.Errorf("unable to create messages stream for MQTT account %q: %v", acc.GetName(), err) + } + } + // Create the stream for retained messages. + as.rstream, err = acc.LookupStream(mqttRetainedMsgsStreamName) + if err != nil { + as.rstream, err = acc.addStreamWithStore(&StreamConfig{ + Subjects: []string{}, + Name: mqttRetainedMsgsStreamName, + Storage: FileStorage, + Retention: InterestPolicy, + }, nil, true) + if err != nil { + return fmt.Errorf("unable to create retained messages stream for MQTT account %q: %v", acc.GetName(), err) + } + } + // Now recover all sessions (in case it did already exist) + if state := as.sstream.State(); state.Msgs > 0 { + for seq := state.FirstSeq; seq <= state.LastSeq; seq++ { + _, _, content, _, err := as.sstream.store.LoadMsg(seq) + if err != nil { + if err != errDeletedMsg { + c.Errorf("Error loading session record at sequence %v: %v", seq, err) + } + continue + } + ps := &mqttPersistedSession{} + if err := json.Unmarshal(content, ps); err != nil { + c.Errorf("Error unmarshaling session record at sequence %v: %v", seq, err) + continue + } + if as.sessions == nil { + as.sessions = make(map[string]*mqttSession) + } + es, ok := as.sessions[ps.ID] + if ok && es.sseq != 0 { + as.sstream.DeleteMsg(es.sseq) + } else if !ok { + es = mqttSessionCreate(opts) + es.stream = as.sstream + as.sessions[ps.ID] = es + } + es.sseq = seq + es.clean = ps.Clean + es.subs = ps.Subs + if l := len(ps.Cons); l > 0 { + if es.cons == nil { + es.cons = make(map[string]*Consumer, l) + } + for sid, name := range ps.Cons { + if cons := as.mstream.LookupConsumer(name); cons != nil { + es.cons[sid] = cons + } + } + } + } + } + // Finally, recover retained messages. + if state := as.rstream.State(); state.Msgs > 0 { + for seq := state.FirstSeq; seq <= state.LastSeq; seq++ { + subject, _, content, _, err := as.rstream.store.LoadMsg(seq) + if err != nil { + if err != errDeletedMsg { + c.Errorf("Error loading retained message at sequence %v: %v", seq, err) + } + continue + } + rm := &mqttRetainedMsg{} + if err := json.Unmarshal(content, &rm); err != nil { + c.Errorf("Error unmarshaling retained message on subject %q, sequence %v: %v", subject, seq, err) + continue + } + rm = as.handleRetainedMsg(subject, rm) + rm.sseq = seq + } + } + return nil +} + +// Add/Replace this message from the retained messages map. +// Returns the retained message actually stored in the map, which means that +// it may be different from the given `rm`. +// +// Account session manager lock held on entry. +func (as *mqttAccountSessionManager) handleRetainedMsg(key string, rm *mqttRetainedMsg) *mqttRetainedMsg { + if as.retmsgs == nil { + as.retmsgs = make(map[string]*mqttRetainedMsg) + as.sl = NewSublistWithCache() + } else { + // Check if we already had one. If so, update the existing one. + if erm, exists := as.retmsgs[key]; exists { + erm.Msg = rm.Msg + erm.Flags = rm.Flags + erm.Source = rm.Source + return erm + } + } + rm.sub = &subscription{subject: []byte(key)} + as.retmsgs[key] = rm + as.sl.Insert(rm.sub) + return rm +} + +// Process subscriptions for the given session/client. +// +// When `fromSubProto` is false, it means that this is invoked from the CONNECT +// protocol, when restoring subscriptions that were saved for this session. +// In that case, there is no need to update the session record. +// +// When `fromSubProto` is true, it means that this call is invoked from the +// processing of the SUBSCRIBE protocol, which means that the session needs to +// be updated. It also means that if a subscription on same subject with same +// QoS already exist, we should not be recreating the subscription/JS durable, +// since it was already done when processing the CONNECT protocol. +// +// Account session manager lock held on entry. +// Session lock held when `fromSubProto` is false. +func (as *mqttAccountSessionManager) processSubs(sess *mqttSession, clientID string, c *client, + filters []*mqttFilter, fromSubProto, trace bool) ([]*subscription, error) { + + if fromSubProto { + sess.mu.Lock() + defer sess.mu.Unlock() + if sess.c != c { + return nil, fmt.Errorf("client %q no longer registered with MQTT session", clientID) + } + } + + addJSConsToSess := func(sid string, cons *Consumer) { + if cons == nil { + return + } + if sess.cons == nil { + sess.cons = make(map[string]*Consumer) + } + sess.cons[sid] = cons + } + + subs := make([]*subscription, 0, len(filters)) + for _, f := range filters { + if f.qos > 1 { + f.qos = 1 + } + subject := f.filter + sid := subject + + if strings.HasPrefix(subject, mqttSubPrefix) { + f.qos = mqttSubAckFailure + continue + } + + var jscons *Consumer + var jssub *subscription + var err error + + sub := c.mqttCreateSub(subject, sid, mqttDeliverMsgCb, f.qos) + if fromSubProto { + as.serializeRetainedMsgsForSub(sess, c, sub, trace) + } + // Note that if a subscription already exists on this subject, + // the sub is updated with the new qos/prm and the pointer to + // the existing subscription is returned. + sub, err = c.processSub(sub, false) + if err == nil { + // This will create (if not already exist) a JS consumer for subscriptions + // of QoS >= 1. But if a JS consumer already exists and the subscription + // for same subject is now a QoS==0, then the JS consumer will be deleted. + jscons, jssub, err = c.mqttProcessJSConsumer(sess, as.mstream, + subject, sid, f.qos, fromSubProto) + } + if err != nil { + c.Errorf("error subscribing to %q: err=%v", subject, err) + f.qos = mqttSubAckFailure + c.mqttCleanupFailedSub(sub, jscons, jssub) + continue + } + if mqttNeedSubForLevelUp(subject) { + var fwjscons *Consumer + var fwjssub *subscription + + // Say subject is "foo.>", remove the ".>" so that it becomes "foo" + fwcsubject := subject[:len(subject)-2] + // Change the sid to "foo fwc" + fwcsid := fwcsubject + mqttMultiLevelSidSuffix + fwcsub := c.mqttCreateSub(fwcsubject, fwcsid, mqttDeliverMsgCb, f.qos) + if fromSubProto { + as.serializeRetainedMsgsForSub(sess, c, fwcsub, trace) + } + // See note above about existing subscription. + fwcsub, err = c.processSub(fwcsub, false) + if err == nil { + fwjscons, fwjssub, err = c.mqttProcessJSConsumer(sess, as.mstream, + fwcsubject, fwcsid, f.qos, fromSubProto) + } + if err != nil { + c.Errorf("error subscribing to %q: err=%v", fwcsubject, err) + f.qos = mqttSubAckFailure + c.mqttCleanupFailedSub(sub, jscons, jssub) + c.mqttCleanupFailedSub(fwcsub, fwjscons, fwjssub) + continue + } + subs = append(subs, fwcsub) + addJSConsToSess(fwcsid, fwjscons) + } + subs = append(subs, sub) + addJSConsToSess(sid, jscons) + } + var err error + if fromSubProto { + err = sess.update(clientID, filters, true) + } + return subs, err +} + +// Retained publish messages matching this subscription are serialized in the +// subscription's `prm` mqtt writer. This buffer will be queued for outbound +// after the subscription is processed and SUBACK is sent or possibly when +// server processes an incoming published message matching the newly +// registered subscription. +// +// Account session manager lock held on entry. +// Session lock held on entry +func (as *mqttAccountSessionManager) serializeRetainedMsgsForSub(sess *mqttSession, c *client, sub *subscription, trace bool) { + if len(as.retmsgs) > 0 { + var rmsa [64]*mqttRetainedMsg + rms := rmsa[:0] + + as.getRetainedPublishMsgs(string(sub.subject), &rms) + for _, rm := range rms { + if sub.mqtt.prm == nil { + sub.mqtt.prm = &mqttWriter{} + } + prm := sub.mqtt.prm + pi := sess.getPubAckIdentifier(mqttGetQoS(rm.Flags), sub) + // Need to use the subject for the retained message, not the `sub` subject. + // We can find the published retained message in rm.sub.subject. + flags := mqttSerializePublishMsg(prm, pi, false, true, string(rm.sub.subject), rm.Msg[:len(rm.Msg)-LEN_CR_LF]) + if trace { + pp := mqttPublish{ + flags: flags, + pi: pi, + subject: rm.sub.subject, + sz: len(rm.Msg) - LEN_CR_LF, + } + c.traceOutOp("PUBLISH", []byte(mqttPubTrace(&pp))) + } + } + } +} + +// Returns in the provided slice all publish retained message records that +// match the given subscription's `subject` (which could have wildcards). +// +// Account session manager lock held on entry. +func (as *mqttAccountSessionManager) getRetainedPublishMsgs(subject string, rms *[]*mqttRetainedMsg) { + result := as.sl.ReverseMatch(subject) + if len(result.psubs) == 0 { + return + } + for _, sub := range result.psubs { + // Since this is a reverse match, the subscription objects here + // contain literals corresponding to the published subjects. + if rm, ok := as.retmsgs[string(sub.subject)]; ok { + *rms = append(*rms, rm) + } + } +} + +////////////////////////////////////////////////////////////////////////////// +// +// MQTT session related functions +// +////////////////////////////////////////////////////////////////////////////// + +// Returns a new mqttSession object with max ack pending set based on +// option or use mqttDefaultMaxAckPending if no option set. +func mqttSessionCreate(opts *Options) *mqttSession { + maxp := opts.MQTT.MaxAckPending + if maxp == 0 { + maxp = mqttDefaultMaxAckPending + } + return &mqttSession{maxp: maxp} +} + +// Persists a session. Note that if the session's current client does not match +// the given client, nothing is done. +// +// Lock held on entry. +func (sess *mqttSession) save(clientID string) error { + ps := mqttPersistedSession{ + ID: clientID, + Clean: sess.clean, + Subs: sess.subs, + } + if l := len(sess.cons); l > 0 { + cons := make(map[string]string, l) + for sid, jscons := range sess.cons { + cons[sid] = jscons.Name() + } + ps.Cons = cons + } + sessBytes, _ := json.Marshal(&ps) + newSeq, _, err := sess.stream.store.StoreMsg("sessions", nil, sessBytes) + if err != nil { + return err + } + if sess.sseq != 0 { + sess.stream.DeleteMsg(sess.sseq) + } + sess.sseq = newSeq + return nil +} + +// Delete JS consumers for this session and delete the persisted session from +// the stream. +// +// Lock held on entry. +func (sess *mqttSession) clear() { + for consName, cons := range sess.cons { + delete(sess.cons, consName) + cons.Delete() + } + if sess.stream != nil && sess.sseq != 0 { + sess.stream.DeleteMsg(sess.sseq) + sess.sseq = 0 + } + sess.subs, sess.pending, sess.cpending = nil, nil, nil +} + +// This will update the session record for this client in the account's MQTT +// sessions stream if the session had any change in the subscriptions. +// +// Lock held on entry. +func (sess *mqttSession) update(clientID string, filters []*mqttFilter, add bool) error { + // Evaluate if we need to persist anything. + var needUpdate bool + for _, f := range filters { + if add { + if f.qos == mqttSubAckFailure { + continue + } + if qos, ok := sess.subs[f.filter]; !ok || qos != f.qos { + if sess.subs == nil { + sess.subs = make(map[string]byte) + } + sess.subs[f.filter] = f.qos + needUpdate = true + } + } else { + if _, ok := sess.subs[f.filter]; ok { + delete(sess.subs, f.filter) + needUpdate = true + } + } + } + var err error + if needUpdate { + err = sess.save(clientID) + } + return err +} + +// If both pQos and sub.mqtt.qos are > 0, then this will return the next +// packet identifier to use for a published message. +// +// Lock held on entry +func (sess *mqttSession) getPubAckIdentifier(pQos byte, sub *subscription) uint16 { + pi, _ := sess.trackPending(pQos, _EMPTY_, sub) + return pi +} + +// If publish message QoS (pQos) and the subscription's QoS are both at least 1, +// this function will assign a packet identifier (pi) and will keep track of +// the pending ack. If the message has already been redelivered (reply != ""), +// the returned boolean will be `true`. +// +// Lock held on entry +func (sess *mqttSession) trackPending(pQos byte, reply string, sub *subscription) (uint16, bool) { + if pQos == 0 || sub.mqtt.qos == 0 { + return 0, false + } + var dup bool + var pi uint16 + + bumpPI := func() uint16 { + var avail bool + next := sess.ppi + for i := 0; i < 0xFFFF; i++ { + next++ + if next == 0 { + next = 1 + } + if _, used := sess.pending[next]; !used { + sess.ppi = next + avail = true + break + } + } + if !avail { + return 0 + } + return sess.ppi + } + + if reply == _EMPTY_ || sub.mqtt.jsCons == nil { + return bumpPI(), false + } + + // Here, we have an ACK subject and a JS consumer... + jsCons := sub.mqtt.jsCons + if sess.pending == nil { + sess.pending = make(map[uint16]*mqttPending) + sess.cpending = make(map[*Consumer]map[uint64]uint16) + } + // Get the stream sequence and other from the ack reply subject + sseq, dseq, dcount := ackReplyInfo(reply) + + var pending *mqttPending + // For this JS consumer, check to see if we already have sseq->pi + sseqToPi, ok := sess.cpending[jsCons] + if !ok { + sseqToPi = make(map[uint64]uint16) + sess.cpending[jsCons] = sseqToPi + } else if pi, ok = sseqToPi[sseq]; ok { + // If we already have a pi, get the ack so we update it. + // We will reuse the save packet identifier (pi). + pending = sess.pending[pi] + } + if pi == 0 { + // sess.maxp will always have a value > 0. + if len(sess.pending) >= int(sess.maxp) { + // Indicate that we did not assign a packet identifier. + // The caller will not send the message to the subscription + // and JS will redeliver later, based on consumer's AckWait. + sess.stalled = true + return 0, false + } + pi = bumpPI() + sseqToPi[sseq] = pi + } + if pending == nil { + pending = &mqttPending{jsCons: jsCons, sseq: sseq} + sess.pending[pi] = pending + } + // Update pending with consumer delivery sequence and count + pending.dseq, pending.dcount = dseq, dcount + // If redelivery, return DUP flag + if dcount > 1 { + dup = true + } + return pi, dup +} + +////////////////////////////////////////////////////////////////////////////// +// +// CONNECT protocol related functions +// +////////////////////////////////////////////////////////////////////////////// + +// Parse the MQTT connect protocol +func (c *client) mqttParseConnect(r *mqttReader, pl int) (byte, *mqttConnectProto, error) { + + // Make sure that we have the expected length in the buffer, + // and if not, this will read it from the underlying reader. + if err := r.ensurePacketInBuffer(pl); err != nil { + return 0, nil, err + } + + // Protocol name + proto, err := r.readBytes("protocol name", false) + if err != nil { + return 0, nil, err + } + + // Spec [MQTT-3.1.2-1] + if !bytes.Equal(proto, mqttProtoName) { + // Check proto name against v3.1 to report better error + if bytes.Equal(proto, mqttOldProtoName) { + return 0, nil, fmt.Errorf("older protocol %q not supported", proto) + } + return 0, nil, fmt.Errorf("expected connect packet with protocol name %q, got %q", mqttProtoName, proto) + } + + // Protocol level + level, err := r.readByte("protocol level") + if err != nil { + return 0, nil, err + } + // Spec [MQTT-3.1.2-2] + if level != mqttProtoLevel { + return mqttConnAckRCUnacceptableProtocolVersion, nil, fmt.Errorf("unacceptable protocol version of %v", level) + } + + cp := &mqttConnectProto{} + // Connect flags + cp.flags, err = r.readByte("flags") + if err != nil { + return 0, nil, err + } + + // Spec [MQTT-3.1.2-3] + if cp.flags&mqttConnFlagReserved != 0 { + return 0, nil, fmt.Errorf("connect flags reserved bit not set to 0") + } + + var hasWill bool + wqos := (cp.flags & mqttConnFlagWillQoS) >> 3 + wretain := cp.flags&mqttConnFlagWillRetain != 0 + // Spec [MQTT-3.1.2-11] + if cp.flags&mqttConnFlagWillFlag == 0 { + // Spec [MQTT-3.1.2-13] + if wqos != 0 { + return 0, nil, fmt.Errorf("if Will flag is set to 0, Will QoS must be 0 too, got %v", wqos) + } + // Spec [MQTT-3.1.2-15] + if wretain { + return 0, nil, fmt.Errorf("if Will flag is set to 0, Will Retain flag must be 0 too") + } + } else { + // Spec [MQTT-3.1.2-14] + if wqos == 3 { + return 0, nil, fmt.Errorf("if Will flag is set to 1, Will QoS can be 0, 1 or 2, got %v", wqos) + } + hasWill = true + } + + // Spec [MQTT-3.1.2-19] + hasUser := cp.flags&mqttConnFlagUsernameFlag != 0 + // Spec [MQTT-3.1.2-21] + hasPassword := cp.flags&mqttConnFlagPasswordFlag != 0 + // Spec [MQTT-3.1.2-22] + if !hasUser && hasPassword { + return 0, nil, fmt.Errorf("password flag set but username flag is not") + } + + // Keep alive + var ka uint16 + ka, err = r.readUint16("keep alive") + if err != nil { + return 0, nil, err + } + // Spec [MQTT-3.1.2-24] + if ka > 0 { + cp.rd = time.Duration(float64(ka)*1.5) * time.Second + } + + // Payload starts here and order is mandated by: + // Spec [MQTT-3.1.3-1]: client ID, will topic, will message, username, password + + // Client ID + cp.clientID, err = r.readString("client ID") + if err != nil { + return 0, nil, err + } + // Spec [MQTT-3.1.3-7] + if cp.clientID == _EMPTY_ { + if cp.flags&mqttConnFlagCleanSession == 0 { + return mqttConnAckRCIdentifierRejected, nil, fmt.Errorf("when client ID is empty, clean session flag must be set to 1") + } + // Spec [MQTT-3.1.3-6] + cp.clientID = nuid.Next() + } + // Spec [MQTT-3.1.3-4] and [MQTT-3.1.3-9] + if !utf8.ValidString(cp.clientID) { + return mqttConnAckRCIdentifierRejected, nil, fmt.Errorf("invalid utf8 for client ID: %q", cp.clientID) + } + + if hasWill { + cp.will = &mqttWill{ + qos: wqos, + retain: wretain, + } + var topic []byte + topic, err = r.readBytes("Will topic", false) + if err != nil { + return 0, nil, err + } + if len(topic) == 0 { + return 0, nil, fmt.Errorf("empty Will topic not allowed") + } + if !utf8.Valid(topic) { + return 0, nil, fmt.Errorf("invalide utf8 for Will topic %q", topic) + } + // Convert MQTT topic to NATS subject + var copied bool + copied, topic, err = mqttTopicToNATSPubSubject(topic) + if err != nil { + return 0, nil, err + } + if !copied { + topic = copyBytes(topic) + } + cp.will.topic = topic + // Now will message + var msg []byte + msg, err = r.readBytes("Will message", false) + if err != nil { + return 0, nil, err + } + cp.will.message = make([]byte, 0, len(msg)+2) + cp.will.message = append(cp.will.message, msg...) + cp.will.message = append(cp.will.message, CR_LF...) + } + + if hasUser { + c.opts.Username, err = r.readString("user name") + if err != nil { + return 0, nil, err + } + if c.opts.Username == _EMPTY_ { + return mqttConnAckRCBadUserOrPassword, nil, fmt.Errorf("empty user name not allowed") + } + // Spec [MQTT-3.1.3-11] + if !utf8.ValidString(c.opts.Username) { + return mqttConnAckRCBadUserOrPassword, nil, fmt.Errorf("invalid utf8 for user name %q", c.opts.Username) + } + } + + if hasPassword { + c.opts.Password, err = r.readString("password") + if err != nil { + return 0, nil, err + } + c.opts.Token = c.opts.Password + } + return 0, cp, nil +} + +func (c *client) mqttConnectTrace(cp *mqttConnectProto) string { + trace := fmt.Sprintf("clientID=%s", cp.clientID) + if cp.rd > 0 { + trace += fmt.Sprintf(" keepAlive=%v", cp.rd) + } + if cp.will != nil { + trace += fmt.Sprintf(" will=(topic=%s QoS=%v retain=%v)", + cp.will.topic, cp.will.qos, cp.will.retain) + } + if c.opts.Username != _EMPTY_ { + trace += fmt.Sprintf(" username=%s", c.opts.Username) + } + if c.opts.Password != _EMPTY_ { + trace += " password=****" + } + return trace +} + +func (s *Server) mqttProcessConnect(c *client, cp *mqttConnectProto, trace bool) error { + sendConnAck := func(rc byte, sessp bool) { + c.mqttEnqueueConnAck(rc, sessp) + if trace { + c.traceOutOp("CONNACK", []byte(fmt.Sprintf("sp=%v rc=%v", sessp, rc))) + } + } + + c.mu.Lock() + c.clearAuthTimer() + c.mu.Unlock() + if !s.isClientAuthorized(c) { + sendConnAck(mqttConnAckRCNotAuthorized, false) + c.closeConnection(AuthenticationViolation) + return ErrAuthentication + } + // Now that we are are authenticated, we have the client bound to the account. + // Get the account's level MQTT sessions manager. If it does not exists yet, + // this will create it along with the streams where sessions and messages + // are stored. + sm := &s.mqtt.sessmgr + asm, err := sm.getOrCreateAccountSessionManager(cp.clientID, c) + if err != nil { + return err + } + + // Rest of code runs under the account's sessions manager write lock. + asm.mu.Lock() + defer asm.mu.Unlock() + + // Is the client requesting a clean session or not. + cleanSess := cp.flags&mqttConnFlagCleanSession != 0 + // Session present? Assume false, will be set to true only when applicable. + sessp := false + // Do we have an existing session for this client ID + es, ok := asm.sessions[cp.clientID] + if ok { + es.mu.Lock() + defer es.mu.Unlock() + // Clear the session if client wants a clean session. + // Also, Spec [MQTT-3.2.2-1]: don't report session present + if cleanSess || es.clean { + // Spec [MQTT-3.1.2-6]: If CleanSession is set to 1, the Client and + // Server MUST discard any previous Session and start a new one. + // This Session lasts as long as the Network Connection. State data + // associated with this Session MUST NOT be reused in any subsequent + // Session. + es.clear() + } else { + // Report to the client that the session was present + sessp = true + } + ec := es.c + // Is there an actual client associated with this session. + if ec != nil { + // Spec [MQTT-3.1.4-2]. If the ClientId represents a Client already + // connected to the Server then the Server MUST disconnect the existing + // client. + ec := es.c + ec.mu.Lock() + // Remove will before closing + ec.mqtt.cp.will = nil + ec.mu.Unlock() + // Close old client in separate go routine + go ec.closeConnection(DuplicateClientID) + } + // Bind with the new client + es.c = c + es.clean = cleanSess + } else { + // Spec [MQTT-3.2.2-3]: if the Server does not have stored Session state, + // it MUST set Session Present to 0 in the CONNACK packet. + es = mqttSessionCreate(s.getOpts()) + es.c, es.clean, es.stream = c, cleanSess, asm.sstream + es.mu.Lock() + defer es.mu.Unlock() + asm.sessions[cp.clientID] = es + es.save(cp.clientID) + } + c.mu.Lock() + c.flags.set(connectReceived) + c.mqtt.cp = cp + c.mqtt.asm = asm + c.mqtt.sess = es + c.mu.Unlock() + // Spec [MQTT-3.2.0-1]: At this point we need to send the CONNACK before + // restoring subscriptions, because CONNACK must be the first packet sent + // to the client. + sendConnAck(mqttConnAckRCConnectionAccepted, sessp) + // Now process possible saved subscriptions. + if l := len(es.subs); l > 0 { + filters := make([]*mqttFilter, 0, l) + for subject, qos := range es.subs { + filters = append(filters, &mqttFilter{filter: subject, qos: qos}) + } + if _, err := asm.processSubs(es, cp.clientID, c, filters, false, trace); err != nil { + return err + } + } + return nil +} + +func (c *client) mqttEnqueueConnAck(rc byte, sessionPresent bool) { + proto := [4]byte{mqttPacketConnectAck, 2, 0, rc} + c.mu.Lock() + // Spec [MQTT-3.2.2-4]. If return code is different from 0, then + // session present flag must be set to 0. + if rc == 0 { + if sessionPresent { + proto[2] = 1 + } + } + c.enqueueProto(proto[:]) + c.mu.Unlock() +} + +func (s *Server) mqttHandleWill(c *client) { + c.mu.Lock() + if c.mqtt.cp == nil { + c.mu.Unlock() + return + } + will := c.mqtt.cp.will + if will == nil { + c.mu.Unlock() + return + } + pp := &mqttPublish{ + subject: will.topic, + msg: will.message, + sz: len(will.message) - LEN_CR_LF, + flags: will.qos << 1, + } + if will.retain { + pp.flags |= mqttPubFlagRetain + } + c.mu.Unlock() + s.mqttProcessPub(c, pp) + c.flushClients(0) +} + +////////////////////////////////////////////////////////////////////////////// +// +// PUBLISH protocol related functions +// +////////////////////////////////////////////////////////////////////////////// + +func (c *client) mqttParsePub(r *mqttReader, pl int, pp *mqttPublish) error { + qos := (pp.flags & mqttPubFlagQoS) >> 1 + if qos > 1 { + return fmt.Errorf("publish QoS=%v not supported", qos) + } + if err := r.ensurePacketInBuffer(pl); err != nil { + return err + } + // Keep track of where we are when starting to read the variable header + start := r.pos + + var err error + pp.subject, err = r.readBytes("topic", false) + if err != nil { + return err + } + if len(pp.subject) == 0 { + return fmt.Errorf("topic cannot be empty") + } + // Convert the topic to a NATS subject. This call will also check that + // there is no MQTT wildcards (Spec [MQTT-3.3.2-2] and [MQTT-4.7.1-1]) + // Note that this may not result in a copy if there is no special + // conversion. It is good because after the message is processed we + // won't have a reference to the buffer and we save a copy. + _, pp.subject, err = mqttTopicToNATSPubSubject(pp.subject) + if err != nil { + return err + } + + if qos > 0 { + pp.pi, err = r.readUint16("packet identifier") + if err != nil { + return err + } + if pp.pi == 0 { + return fmt.Errorf("with QoS=%v, packet identifier cannot be 0", qos) + } + } + + // The message payload will be the total packet length minus + // what we have consumed for the variable header + pp.sz = pl - (r.pos - start) + pp.msg = make([]byte, 0, pp.sz+2) + if pp.sz > 0 { + start = r.pos + r.pos += pp.sz + pp.msg = append(pp.msg, r.buf[start:r.pos]...) + } + pp.msg = append(pp.msg, _CRLF_...) + return nil +} + +func mqttPubTrace(pp *mqttPublish) string { + dup := pp.flags&mqttPubFlagDup != 0 + qos := mqttGetQoS(pp.flags) + retain := mqttIsRetained(pp.flags) + var piStr string + if pp.pi > 0 { + piStr = fmt.Sprintf(" pi=%v", pp.pi) + } + return fmt.Sprintf("%s dup=%v QoS=%v retain=%v size=%v%s", + pp.subject, dup, qos, retain, pp.sz, piStr) +} + +func (s *Server) mqttProcessPub(c *client, pp *mqttPublish) { + c.mqtt.pp = pp + c.pa.subject, c.pa.hdr, c.pa.size, c.pa.szb = pp.subject, -1, pp.sz, []byte(strconv.FormatInt(int64(pp.sz), 10)) + // This will work for QoS 0 but mqtt msg delivery callback will ignore + // delivery for QoS > 0 published messages (since it is handled specifically + // with call to directProcessInboundJetStreamMsg). + // However, this needs to be invoked before directProcessInboundJetStreamMsg() + // in case we are dealing with publish retained messages. + c.processInboundClientMsg(pp.msg) + if mqttGetQoS(pp.flags) > 0 { + // Since this is the fast path, we access the messages stream directly here + // without locking. All the fields mqtt.asm.mstream are immutable. + c.mqtt.asm.mstream.processInboundJetStreamMsg(nil, c, string(c.pa.subject), "", pp.msg[:len(pp.msg)-LEN_CR_LF]) + } + c.pa.subject, c.pa.hdr, c.pa.size, c.pa.szb = nil, -1, 0, nil + c.mqtt.pp = nil +} + +// Invoked when processing an inbound client message. If the "retain" flag is +// set, the message is stored so it can be later resent to (re)starting +// subscriptions that match the subject. +// +// Invoked from the publisher's readLoop. No client lock is held on entry. +func (c *client) mqttHandlePubRetain() { + pp := c.mqtt.pp + if mqttIsRetained(pp.flags) { + key := string(pp.subject) + asm := c.mqtt.asm + asm.mu.Lock() + // Spec [MQTT-3.3.1-11]. Payload of size 0 removes the retained message, + // but should still be delivered as a normal message. + if pp.sz == 0 { + if asm.retmsgs != nil { + if erm, ok := asm.retmsgs[key]; ok { + delete(asm.retmsgs, key) + asm.sl.Remove(erm.sub) + if erm.sseq != 0 { + asm.rstream.DeleteMsg(erm.sseq) + } + } + } + } else { + // Spec [MQTT-3.3.1-5]. Store the retained message with its QoS. + // When coming from a publish protocol, `pp` is referencing a stack + // variable that itself possibly references the read buffer. + rm := &mqttRetainedMsg{ + Msg: copyBytes(pp.msg), + Flags: pp.flags, + Source: c.opts.Username, + } + rm = asm.handleRetainedMsg(key, rm) + rmBytes, _ := json.Marshal(rm) + // TODO: For now we will report the error but continue... + seq, _, err := asm.rstream.store.StoreMsg(key, nil, rmBytes) + if err != nil { + c.mu.Lock() + acc := c.acc + c.mu.Unlock() + c.Errorf("unable to store retained message for account %q, subject %q: %v", + acc.GetName(), key, err) + } + // If it has been replaced, rm.sseq will be != 0 + if rm.sseq != 0 { + asm.rstream.DeleteMsg(rm.sseq) + } + // Keep track of current stream sequence (possibly 0 if failed to store) + rm.sseq = seq + } + + asm.mu.Unlock() + + // Clear the retain flag for a normal published message. + pp.flags &= ^mqttPubFlagRetain + } +} + +// After a config reload, it is possible that the source of a publish retained +// message is no longer allowed to publish on the given topic. If that is the +// case, the retained message is removed from the map and will no longer be +// sent to (re)starting subscriptions. +// +// Server lock is held on entry +func (s *Server) mqttCheckPubRetainedPerms() { + sm := &s.mqtt.sessmgr + sm.mu.RLock() + defer sm.mu.RUnlock() + + for _, asm := range sm.sessions { + perms := map[string]*perm{} + asm.mu.Lock() + for subject, rm := range asm.retmsgs { + if rm.Source == _EMPTY_ { + continue + } + // Lookup source from global users. + u := s.users[rm.Source] + if u != nil { + p, ok := perms[rm.Source] + if !ok { + p = generatePubPerms(u.Permissions) + perms[rm.Source] = p + } + // If there is permission and no longer allowed to publish in + // the subject, remove the publish retained message from the map. + if p != nil && !pubAllowed(p, subject) { + u = nil + } + } + + // Not present or permissions have changed such that the source can't + // publish on that subject anymore: remove it from the map. + if u == nil { + delete(asm.retmsgs, subject) + asm.rstream.DeleteMsg(rm.sseq) + asm.sl.Remove(rm.sub) + } + } + asm.mu.Unlock() + } +} + +// Helper to generate only pub permissions from a Permissions object +func generatePubPerms(perms *Permissions) *perm { + var p *perm + if perms.Publish.Allow != nil { + p = &perm{} + p.allow = NewSublistWithCache() + } + for _, pubSubject := range perms.Publish.Allow { + sub := &subscription{subject: []byte(pubSubject)} + p.allow.Insert(sub) + } + if len(perms.Publish.Deny) > 0 { + if p == nil { + p = &perm{} + } + p.deny = NewSublistWithCache() + } + for _, pubSubject := range perms.Publish.Deny { + sub := &subscription{subject: []byte(pubSubject)} + p.deny.Insert(sub) + } + return p +} + +// Helper that checks if given `perms` allow to publish on the given `subject` +func pubAllowed(perms *perm, subject string) bool { + allowed := true + if perms.allow != nil { + r := perms.allow.Match(subject) + allowed = len(r.psubs) != 0 + } + // If we have a deny list and are currently allowed, check that as well. + if allowed && perms.deny != nil { + r := perms.deny.Match(subject) + allowed = len(r.psubs) == 0 + } + return allowed +} + +func mqttWritePublish(w *mqttWriter, qos byte, dup, retain bool, subject string, pi uint16, payload []byte) { + flags := qos << 1 + if dup { + flags |= mqttPubFlagDup + } + if retain { + flags |= mqttPubFlagRetain + } + w.WriteByte(mqttPacketPub | flags) + pkLen := 2 + len(subject) + len(payload) + if qos > 0 { + pkLen += 2 + } + w.WriteVarInt(pkLen) + w.WriteString(subject) + if qos > 0 { + w.WriteUint16(pi) + } + w.Write([]byte(payload)) +} + +func (c *client) mqttEnqueuePubAck(pi uint16) { + proto := [4]byte{mqttPacketPubAck, 0x2, 0, 0} + proto[2] = byte(pi >> 8) + proto[3] = byte(pi) + c.mu.Lock() + c.enqueueProto(proto[:4]) + c.mu.Unlock() +} + +func mqttParsePubAck(r *mqttReader, pl int) (uint16, error) { + if err := r.ensurePacketInBuffer(pl); err != nil { + return 0, err + } + pi, err := r.readUint16("packet identifier") + if err != nil { + return 0, err + } + if pi == 0 { + return 0, fmt.Errorf("packet identifier cannot be 0") + } + return pi, nil +} + +func (c *client) mqttProcessPubAck(pi uint16) { + sess := c.mqtt.sess + if sess == nil { + return + } + sess.mu.Lock() + defer sess.mu.Unlock() + if sess.c != c { + return + } + if ack, ok := sess.pending[pi]; ok { + delete(sess.pending, pi) + jsCons := ack.jsCons + if sseqToPi, ok := sess.cpending[jsCons]; ok { + delete(sseqToPi, ack.sseq) + } + jsCons.ackMsg(ack.sseq, ack.dseq, ack.dcount) + if len(sess.pending) == 0 { + sess.ppi = 0 + } + if sess.stalled && len(sess.pending) < int(sess.maxp) { + sess.stalled = false + for _, cons := range sess.cons { + cons.mu.Lock() + cons.forceExpirePending() + cons.mu.Unlock() + } + } + } +} + +// Return the QoS from the given PUBLISH protocol's flags +func mqttGetQoS(flags byte) byte { + return flags & mqttPubFlagQoS >> 1 +} + +func mqttIsRetained(flags byte) bool { + return flags&mqttPubFlagRetain != 0 +} + +////////////////////////////////////////////////////////////////////////////// +// +// SUBSCRIBE related functions +// +////////////////////////////////////////////////////////////////////////////// + +func (c *client) mqttParseSubs(r *mqttReader, b byte, pl int) (uint16, []*mqttFilter, error) { + return c.mqttParseSubsOrUnsubs(r, b, pl, true) +} + +func (c *client) mqttParseSubsOrUnsubs(r *mqttReader, b byte, pl int, sub bool) (uint16, []*mqttFilter, error) { + var expectedFlag byte + var action string + if sub { + expectedFlag = mqttSubscribeFlags + } else { + expectedFlag = mqttUnsubscribeFlags + action = "un" + } + // Spec [MQTT-3.8.1-1], [MQTT-3.10.1-1] + if rf := b & 0xf; rf != expectedFlag { + return 0, nil, fmt.Errorf("wrong %ssubscribe reserved flags: %x", action, rf) + } + if err := r.ensurePacketInBuffer(pl); err != nil { + return 0, nil, err + } + pi, err := r.readUint16("packet identifier") + if err != nil { + return 0, nil, fmt.Errorf("reading packet identifier: %v", err) + } + end := r.pos + (pl - 2) + var filters []*mqttFilter + for r.pos < end { + // Don't make a copy now because, this will happen during conversion + // or when processing the sub. + filter, err := r.readBytes("topic filter", false) + if err != nil { + return 0, nil, err + } + if len(filter) == 0 { + return 0, nil, errors.New("topic filter cannot be empty") + } + // Spec [MQTT-3.8.3-1], [MQTT-3.10.3-1] + if !utf8.Valid(filter) { + return 0, nil, fmt.Errorf("invalid utf8 for topic filter %q", filter) + } + var qos byte + // This won't return an error. We will find out if the subject + // is valid or not when trying to create the subscription. + _, filter, _ = mqttFilterToNATSSubject(filter) + if sub { + qos, err = r.readByte("QoS") + if err != nil { + return 0, nil, err + } + // Spec [MQTT-3-8.3-4]. + if qos > 2 { + return 0, nil, fmt.Errorf("subscribe QoS value must be 0, 1 or 2, got %v", qos) + } + } + filters = append(filters, &mqttFilter{string(filter), qos}) + } + // Spec [MQTT-3.8.3-3], [MQTT-3.10.3-2] + if len(filters) == 0 { + return 0, nil, fmt.Errorf("%ssubscribe protocol must contain at least 1 topic filter", action) + } + return pi, filters, nil +} + +func mqttSubscribeTrace(filters []*mqttFilter) string { + var sep string + trace := "[" + for i, f := range filters { + trace += sep + fmt.Sprintf("%s QoS=%v", f.filter, f.qos) + if i == 0 { + sep = ", " + } + } + trace += "]" + return trace +} + +func mqttDeliverMsgCb(sub *subscription, pc *client, subject, reply string, msg []byte) { + if sub.mqtt == nil { + return + } + + var ppFlags byte + var pQoS byte + var pi uint16 + var dup bool + var retained bool + + // This is the client associated with the subscription. + cc := sub.client + + // This is immutable + sess := cc.mqtt.sess + // We lock to check some of the subscription's fields and if we need to + // keep track of pending acks, etc.. + sess.mu.Lock() + if sess.c != cc { + sess.mu.Unlock() + return + } + + // Check the publisher's kind. If JETSTREAM it means that this is a persisted message + // that is being delivered. + if pc.kind == JETSTREAM { + // If there is no JS consumer attached to this subscription, it means that we are + // dealing with a bare NATS subscription, in which case we simply return to avoid + // duplicate delivery. + if sub.mqtt.jsCons == nil { + sess.mu.Unlock() + return + } + ppFlags = mqttPubQos1 + pQoS = 1 + // This is a QoS1 message for a QoS1 subscription, so get the pi and keep + // track of ack subject. + pi, dup = sess.trackPending(pQoS, reply, sub) + if pi == 0 { + // We have reached max pending, don't send the message now. + // JS will cause a redelivery and if by then the number of pending + // messages has fallen below threshold, the message will be resent. + sess.mu.Unlock() + return + } + // In JS case, we need to use the pc.ca.deliver value as the subject. + subject = string(pc.pa.deliver) + } else if pc.mqtt != nil { + // This is a MQTT publisher... + ppFlags = pc.mqtt.pp.flags + pQoS = mqttGetQoS(ppFlags) + // If the QoS of published message and subscription is 1, then we return here to + // avoid duplicate delivery. The JetStream publisher will handle that case. + if pQoS > 0 && sub.mqtt.qos > 0 { + sess.mu.Unlock() + return + } + retained = mqttIsRetained(ppFlags) + } + // else this is coming from a non MQTT publisher, so Qos 0, no dup nor retain flag, etc.. + sess.mu.Unlock() + + sw := mqttWriter{} + w := &sw + + flags := mqttSerializePublishMsg(w, pi, dup, retained, subject, msg) + + cc.mu.Lock() + if sub.mqtt.prm != nil { + cc.queueOutbound(sub.mqtt.prm.Bytes()) + sub.mqtt.prm = nil + } + cc.queueOutbound(w.Bytes()) + pc.addToPCD(cc) + if cc.trace { + pp := mqttPublish{ + flags: flags, + pi: pi, + subject: []byte(subject), + sz: len(msg), + } + cc.traceOutOp("PUBLISH", []byte(mqttPubTrace(&pp))) + } + cc.mu.Unlock() +} + +// Serializes to the given writer the message for the given subject. +func mqttSerializePublishMsg(w *mqttWriter, pi uint16, dup, retained bool, subject string, msg []byte) byte { + topic := natsSubjectToMQTTTopic(subject) + + // Compute len (will have to add packet id if message is sent as QoS>=1) + pkLen := 2 + len(topic) + len(msg) + + var flags byte + + // Set flags for dup/retained/qos1 + if dup { + flags |= mqttPubFlagDup + } + if retained { + flags |= mqttPubFlagRetain + } + // For now, we have only QoS 1 + if pi > 0 { + pkLen += 2 + flags |= mqttPubQos1 + } + + w.WriteByte(mqttPacketPub | flags) + w.WriteVarInt(pkLen) + w.WriteBytes(topic) + if pi > 0 { + w.WriteUint16(pi) + } + w.Write(msg) + + return flags +} + +// Helper to create an MQTT subscription. +func (c *client) mqttCreateSub(subject, sid string, cb msgHandler, qos byte) *subscription { + sub := c.createSub([]byte(subject), nil, []byte(sid), cb) + sub.mqtt = &mqttSub{qos: qos} + return sub +} + +// Process the list of subscriptions and update the given filter +// with the QoS that has been accepted (or failure). +// +// Spec [MQTT-3.8.4-3] says that if an exact same subscription is +// found, it needs to be replaced with the new one (possibly updating +// the qos) and that the flow of publications must not be interrupted, +// which I read as the replacement cannot be a "remove then add" if there +// is a chance that in between the 2 actions, published messages +// would be "lost" because there would not be any matching subscription. +func (c *client) mqttProcessSubs(filters []*mqttFilter) ([]*subscription, error) { + // Those things are immutable, but since processing subs is not + // really in the fast path, let's get them under the client lock. + c.mu.Lock() + asm := c.mqtt.asm + sess := c.mqtt.sess + clientID := c.mqtt.cp.clientID + trace := c.trace + c.mu.Unlock() + + asm.mu.RLock() + defer asm.mu.RUnlock() + return asm.processSubs(sess, clientID, c, filters, true, trace) +} + +func (c *client) mqttCleanupFailedSub(sub *subscription, jscons *Consumer, jssub *subscription) { + c.mu.Lock() + acc := c.acc + c.mu.Unlock() + + if sub != nil { + c.unsubscribe(acc, sub, true, true) + } + if jssub != nil { + c.unsubscribe(acc, jssub, true, true) + } + if jscons != nil { + jscons.Delete() + } +} + +// When invoked with a QoS of 0, looks for an existing JS durable consumer for +// the given sid and if one is found, delete the JS durable consumer and unsub +// the NATS subscription on the delivery subject. +// With a QoS > 0, creates or update the existing JS durable consumer along with +// its NATS subscription on a delivery subject. +// +// Account session manager lock held on entry. +func (c *client) mqttProcessJSConsumer(sess *mqttSession, stream *Stream, subject, + sid string, qos byte, fromSubProto bool) (*Consumer, *subscription, error) { + + // Check if we are already a JS consumer for this SID. + cons, exists := sess.cons[sid] + if exists { + // If current QoS is 0, it means that we need to delete the existing + // one (that was QoS > 0) + if qos == 0 { + // The JS durable consumer's delivery subject is on a NUID of + // the form: mqttSubPrefix + . It is also used as the sid + // for the NATS subscription, so use that for the lookup. + sub := c.subs[cons.Config().DeliverSubject] + delete(sess.cons, sid) + cons.Delete() + if sub != nil { + c.mu.Lock() + acc := c.acc + c.mu.Unlock() + c.unsubscribe(acc, sub, true, true) + } + return nil, nil, nil + } + // If this is called when processing SUBSCRIBE protocol, then if + // the JS consumer already exists, we are done (it was created + // during the processing of CONNECT). + if fromSubProto { + return nil, nil, nil + } + } + // Here it means we don't have a JS consumer and if we are QoS 0, + // we have nothing to do. + if qos == 0 { + return nil, nil, nil + } + var err error + inbox := mqttSubPrefix + nuid.Next() + if exists { + cons.updateDeliverSubject(inbox) + } else { + durName := nuid.Next() + opts := c.srv.getOpts() + ackWait := opts.MQTT.AckWait + if ackWait == 0 { + ackWait = mqttDefaultAckWait + } + maxAckPending := opts.MQTT.MaxAckPending + if maxAckPending == 0 { + maxAckPending = mqttDefaultMaxAckPending + } + cc := &ConsumerConfig{ + DeliverSubject: inbox, + Durable: durName, + AckPolicy: AckExplicit, + DeliverPolicy: DeliverNew, + FilterSubject: subject, + AckWait: ackWait, + MaxAckPending: int(maxAckPending), + } + cons, err = stream.addConsumerCheckInterest(cc, false) + if err != nil { + c.Errorf("Unable to add JetStream consumer for subscription on %q: err=%v", subject, err) + return nil, nil, err + } + } + sub := c.mqttCreateSub(inbox, inbox, mqttDeliverMsgCb, qos) + sub.mqtt.jsCons = cons + // This is an internal subscription on subject like "$MQTT.sub." that is setup + // for the JS durable's deliver subject. I don't think that there is any need to + // forward this subscription in the cluster/super cluster. + sub, err = c.processSub(sub, true) + if err != nil { + if !exists { + cons.Delete() + } + c.Errorf("Unable to create subscription for JetStream consumer on %q: %v", subject, err) + return nil, nil, err + } + return cons, sub, nil +} + +// Queues the published retained messages for each subscription and signals +// the writeLoop. +func (c *client) mqttSendRetainedMsgsToNewSubs(subs []*subscription) { + c.mu.Lock() + for _, sub := range subs { + if sub.mqtt != nil && sub.mqtt.prm != nil { + c.queueOutbound(sub.mqtt.prm.Bytes()) + sub.mqtt.prm = nil + } + } + c.flushSignal() + c.mu.Unlock() +} + +func (c *client) mqttEnqueueSubAck(pi uint16, filters []*mqttFilter) { + w := &mqttWriter{} + w.WriteByte(mqttPacketSubAck) + // packet length is 2 (for packet identifier) and 1 byte per filter. + w.WriteVarInt(2 + len(filters)) + w.WriteUint16(pi) + for _, f := range filters { + w.WriteByte(f.qos) + } + c.mu.Lock() + c.enqueueProto(w.Bytes()) + c.mu.Unlock() +} + +////////////////////////////////////////////////////////////////////////////// +// +// UNSUBSCRIBE related functions +// +////////////////////////////////////////////////////////////////////////////// + +func (c *client) mqttParseUnsubs(r *mqttReader, b byte, pl int) (uint16, []*mqttFilter, error) { + return c.mqttParseSubsOrUnsubs(r, b, pl, false) +} + +func (c *client) mqttProcessUnsubs(filters []*mqttFilter) error { + // Those things are immutable, but since processing unsubs is not + // really in the fast path, let's get them under the client lock. + c.mu.Lock() + sess := c.mqtt.sess + clientID := c.mqtt.cp.clientID + c.mu.Unlock() + + sess.mu.Lock() + defer sess.mu.Unlock() + if sess.c != c { + return fmt.Errorf("client %q no longer registered with MQTT session", clientID) + } + + removeJSCons := func(sid string) { + if jscons, ok := sess.cons[sid]; ok { + delete(sess.cons, sid) + jscons.Delete() + if seqPis, ok := sess.cpending[jscons]; ok { + delete(sess.cpending, jscons) + for _, pi := range seqPis { + delete(sess.pending, pi) + } + if len(sess.pending) == 0 { + sess.ppi = 0 + } + } + } + } + for _, f := range filters { + sid := f.filter + // Remove JS Consumer if one exists for this sid + removeJSCons(sid) + if err := c.processUnsub([]byte(sid)); err != nil { + c.Errorf("error unsubscribing from %q: %v", sid, err) + } + if mqttNeedSubForLevelUp(sid) { + subject := sid[:len(sid)-2] + sid = subject + mqttMultiLevelSidSuffix + removeJSCons(sid) + if err := c.processUnsub([]byte(sid)); err != nil { + c.Errorf("error unsubscribing from %q: %v", subject, err) + } + } + } + return sess.update(clientID, filters, false) +} + +func (c *client) mqttEnqueueUnsubAck(pi uint16) { + w := &mqttWriter{} + w.WriteByte(mqttPacketUnsubAck) + w.WriteVarInt(2) + w.WriteUint16(pi) + c.mu.Lock() + c.enqueueProto(w.Bytes()) + c.mu.Unlock() +} + +func mqttUnsubscribeTrace(filters []*mqttFilter) string { + var sep string + trace := "[" + for i, f := range filters { + trace += sep + f.filter + if i == 0 { + sep = ", " + } + } + trace += "]" + return trace +} + +////////////////////////////////////////////////////////////////////////////// +// +// PINGREQ/PINGRESP related functions +// +////////////////////////////////////////////////////////////////////////////// + +func (c *client) mqttEnqueuePingResp() { + c.mu.Lock() + c.enqueueProto(mqttPingResponse) + c.mu.Unlock() +} + +////////////////////////////////////////////////////////////////////////////// +// +// Trace functions +// +////////////////////////////////////////////////////////////////////////////// + +func errOrTrace(err error, trace string) []byte { + if err != nil { + return []byte(err.Error()) + } + return []byte(trace) +} + +////////////////////////////////////////////////////////////////////////////// +// +// Subject/Topic conversion functions +// +////////////////////////////////////////////////////////////////////////////// + +// Converts an MQTT Topic Name to a NATS Subject (used by PUBLISH) +// See mqttToNATSSubjectConversion() for details. +func mqttTopicToNATSPubSubject(mt []byte) (bool, []byte, error) { + return mqttToNATSSubjectConversion(mt, false) +} + +// Converts an MQTT Topic Filter to a NATS Subject (used by SUBSCRIBE) +// See mqttToNATSSubjectConversion() for details. +func mqttFilterToNATSSubject(filter []byte) (bool, []byte, error) { + return mqttToNATSSubjectConversion(filter, true) +} + +// Converts an MQTT Topic Name or Filter to a NATS Subject +// In MQTT: +// - a Topic Name does not have wildcard (PUBLISH uses only topic names). +// - a Topic Filter can include wildcards (SUBSCRIBE uses those). +// - '+' and '#' are wildcard characters (single and multiple levels respectively) +// - '/' is the topic level separator. +// +// Conversion that occurs: +// - '/' is replaced with '/.' if it is the first character in mt +// - '/' is replaced with './' if the last or next character in mt is '/' +// For instance, foo//bar would become foo./.bar +// - '/' is replaced with '.' for all other conditions (foo/bar -> foo.bar) +// - '.' and ' ' cause an error to be returned. +// +// If a copy occurred, the returned boolean will indicate this condition. +func mqttToNATSSubjectConversion(mt []byte, wcOk bool) (bool, []byte, error) { + var res = mt + var newSlice bool + + copyTopic := func(pos int) []byte { + if newSlice && cap(res) > pos+2 { + return res + } + newSlice = true + b := make([]byte, len(res)+10) + copy(b, res[:pos]) + res = b + return res + } + + var j int + end := len(mt) - 1 + for i := 0; i < len(mt); i++ { + switch mt[i] { + case mqttTopicLevelSep: + if i == 0 || res[j-1] == btsep { + res = copyTopic(0) + res[j] = mqttTopicLevelSep + j++ + res[j] = btsep + } else if i == end || mt[i+1] == mqttTopicLevelSep { + res = copyTopic(j) + res[j] = btsep + j++ + res[j] = mqttTopicLevelSep + } else { + res[j] = btsep + } + case btsep, ' ': + // As of now, we cannot support '.' or ' ' in the MQTT topic/filter. + return false, nil, fmt.Errorf("characters ' ' and '.' not supported for MQTT topics") + case mqttSingleLevelWC, mqttMultiLevelWC: + if !wcOk { + // Spec [MQTT-3.3.2-2] and [MQTT-4.7.1-1] + // The wildcard characters can be used in Topic Filters, but MUST NOT be used within a Topic Name + return false, nil, fmt.Errorf("wildcards not allowed in publish's topic: %q", mt) + } + if mt[i] == mqttSingleLevelWC { + res[j] = pwc + } else { + res[j] = fwc + } + default: + if newSlice { + res[j] = mt[i] + } + } + j++ + } + if newSlice && res[j-1] == btsep { + res = copyTopic(j) + res[j] = mqttTopicLevelSep + j++ + } + return newSlice, res[:j], nil +} + +// Converts a NATS subject to MQTT topic. This is for publish +// messages only, so there is no checking for wildcards. +// Rules are reversed of mqttToNATSSubjectConversion. +func natsSubjectToMQTTTopic(subject string) []byte { + topic := []byte(subject) + end := len(subject) - 1 + var j int + for i := 0; i < len(subject); i++ { + switch subject[i] { + case mqttTopicLevelSep: + if !(i == 0 && i < end && subject[i+1] == btsep) { + topic[j] = mqttTopicLevelSep + j++ + } + case btsep: + topic[j] = mqttTopicLevelSep + j++ + if i < end && subject[i+1] == mqttTopicLevelSep { + i++ + } + default: + topic[j] = subject[i] + j++ + } + } + return topic[:j] +} + +// Returns true if the subject has more than 1 token and ends with ".>" +func mqttNeedSubForLevelUp(subject string) bool { + if len(subject) < 3 { + return false + } + end := len(subject) + if subject[end-2] == '.' && subject[end-1] == fwc { + return true + } + return false +} + +////////////////////////////////////////////////////////////////////////////// +// +// MQTT Reader functions +// +////////////////////////////////////////////////////////////////////////////// + +func copyBytes(b []byte) []byte { + if b == nil { + return nil + } + cbuf := make([]byte, len(b)) + copy(cbuf, b) + return cbuf +} + +func (r *mqttReader) reset(buf []byte) { + r.buf = buf + r.pos = 0 +} + +func (r *mqttReader) hasMore() bool { + return r.pos != len(r.buf) +} + +func (r *mqttReader) readByte(field string) (byte, error) { + if r.pos == len(r.buf) { + return 0, fmt.Errorf("error reading %s: %v", field, io.EOF) + } + b := r.buf[r.pos] + r.pos++ + return b, nil +} + +func (r *mqttReader) readPacketLen() (int, error) { + m := 1 + v := 0 + for { + var b byte + if r.pos != len(r.buf) { + b = r.buf[r.pos] + r.pos++ + } else { + var buf [1]byte + if _, err := r.reader.Read(buf[:1]); err != nil { + if err == io.EOF { + return 0, io.ErrUnexpectedEOF + } + return 0, fmt.Errorf("error reading packet length: %v", err) + } + b = buf[0] + } + v += int(b&0x7f) * m + if (b & 0x80) == 0 { + return v, nil + } + m *= 0x80 + if m > 0x200000 { + return 0, errors.New("malformed variable int") + } + } +} + +func (r *mqttReader) ensurePacketInBuffer(pl int) error { + rem := len(r.buf) - r.pos + if rem >= pl { + return nil + } + b := make([]byte, pl) + start := copy(b, r.buf[r.pos:]) + for start != pl { + n, err := r.reader.Read(b[start:cap(b)]) + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return fmt.Errorf("error ensuring protocol is loaded: %v", err) + } + start += n + } + r.reset(b) + return nil +} + +func (r *mqttReader) readString(field string) (string, error) { + var s string + bs, err := r.readBytes(field, false) + if err == nil { + s = string(bs) + } + return s, err +} + +func (r *mqttReader) readBytes(field string, cp bool) ([]byte, error) { + luint, err := r.readUint16(field) + if err != nil { + return nil, err + } + l := int(luint) + if l == 0 { + return nil, nil + } + start := r.pos + if start+l > len(r.buf) { + return nil, fmt.Errorf("error reading %s: %v", field, io.ErrUnexpectedEOF) + } + r.pos += l + b := r.buf[start:r.pos] + if cp { + b = copyBytes(b) + } + return b, nil +} + +func (r *mqttReader) readUint16(field string) (uint16, error) { + if len(r.buf)-r.pos < 2 { + return 0, fmt.Errorf("error reading %s: %v", field, io.ErrUnexpectedEOF) + } + start := r.pos + r.pos += 2 + return binary.BigEndian.Uint16(r.buf[start:r.pos]), nil +} + +////////////////////////////////////////////////////////////////////////////// +// +// MQTT Writer functions +// +////////////////////////////////////////////////////////////////////////////// + +func (w *mqttWriter) WriteUint16(i uint16) { + w.WriteByte(byte(i >> 8)) + w.WriteByte(byte(i)) +} + +func (w *mqttWriter) WriteString(s string) { + w.WriteBytes([]byte(s)) +} + +func (w *mqttWriter) WriteBytes(bs []byte) { + w.WriteUint16(uint16(len(bs))) + w.Write(bs) +} + +func (w *mqttWriter) WriteVarInt(value int) { + for { + b := byte(value & 0x7f) + value >>= 7 + if value > 0 { + b |= 0x80 + } + w.WriteByte(b) + if value == 0 { + break + } + } +} diff --git a/server/mqtt_test.go b/server/mqtt_test.go new file mode 100644 index 00000000..d5ad36b3 --- /dev/null +++ b/server/mqtt_test.go @@ -0,0 +1,4111 @@ +// Copyright 2020 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bufio" + "bytes" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/nats-io/jwt/v2" + "github.com/nats-io/nats.go" +) + +type mqttErrorReader struct { + err error +} + +func (r *mqttErrorReader) Read(b []byte) (int, error) { return 0, r.err } +func (r *mqttErrorReader) SetReadDeadline(time.Time) error { return nil } + +func testNewEOFReader() *mqttErrorReader { + return &mqttErrorReader{err: io.EOF} +} + +func TestMQTTReader(t *testing.T) { + r := &mqttReader{} + r.reset([]byte{0, 2, 'a', 'b'}) + bs, err := r.readBytes("", false) + if err != nil { + t.Fatal(err) + } + sbs := string(bs) + if sbs != "ab" { + t.Fatalf(`expected "ab", got %q`, sbs) + } + + r.reset([]byte{0, 2, 'a', 'b'}) + bs, err = r.readBytes("", true) + if err != nil { + t.Fatal(err) + } + bs[0], bs[1] = 'c', 'd' + if bytes.Equal(bs, r.buf[2:]) { + t.Fatal("readBytes should have returned a copy") + } + + r.reset([]byte{'a', 'b'}) + if b, err := r.readByte(""); err != nil || b != 'a' { + t.Fatalf("Error reading byte: b=%v err=%v", b, err) + } + if !r.hasMore() { + t.Fatal("expected to have more, did not") + } + if b, err := r.readByte(""); err != nil || b != 'b' { + t.Fatalf("Error reading byte: b=%v err=%v", b, err) + } + if r.hasMore() { + t.Fatal("expected to not have more") + } + if _, err := r.readByte("test"); err == nil || !strings.Contains(err.Error(), "error reading test") { + t.Fatalf("unexpected error: %v", err) + } + + r.reset([]byte{0, 2, 'a', 'b'}) + if s, err := r.readString(""); err != nil || s != "ab" { + t.Fatalf("Error reading string: s=%q err=%v", s, err) + } + + r.reset([]byte{10}) + if _, err := r.readUint16("uint16"); err == nil || !strings.Contains(err.Error(), "error reading uint16") { + t.Fatalf("unexpected error: %v", err) + } + + r.reset([]byte{1, 2, 3}) + r.reader = testNewEOFReader() + if err := r.ensurePacketInBuffer(10); err == nil || !strings.Contains(err.Error(), "error ensuring protocol is loaded") { + t.Fatalf("unexpected error: %v", err) + } + + r.reset([]byte{0x82, 0xff, 0x3}) + l, err := r.readPacketLen() + if err != nil { + t.Fatal("error getting packet len") + } + if l != 0xff82 { + t.Fatalf("expected length 0xff82 got 0x%x", l) + } + r.reset([]byte{0xff, 0xff, 0xff, 0xff, 0xff}) + if _, err := r.readPacketLen(); err == nil || !strings.Contains(err.Error(), "malformed") { + t.Fatalf("unexpected error: %v", err) + } + r.reset([]byte{0x80}) + if _, err := r.readPacketLen(); err != io.ErrUnexpectedEOF { + t.Fatalf("unexpected error: %v", err) + } + + r.reset([]byte{0x80}) + r.reader = &mqttErrorReader{err: errors.New("on purpose")} + if _, err := r.readPacketLen(); err == nil || !strings.Contains(err.Error(), "on purpose") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestMQTTWriter(t *testing.T) { + w := &mqttWriter{} + w.WriteUint16(1234) + + r := &mqttReader{} + r.reset(w.Bytes()) + if v, err := r.readUint16(""); err != nil || v != 1234 { + t.Fatalf("unexpected value: v=%v err=%v", v, err) + } + + w.Reset() + w.WriteString("test") + r.reset(w.Bytes()) + if len(r.buf) != 6 { + t.Fatalf("Expected 2 bytes size before string, got %v", r.buf) + } + + w.Reset() + w.WriteBytes([]byte("test")) + r.reset(w.Bytes()) + if len(r.buf) != 6 { + t.Fatalf("Expected 2 bytes size before bytes, got %v", r.buf) + } + + ints := []int{ + 0, 1, 127, 128, 16383, 16384, 2097151, 2097152, 268435455, + } + lens := []int{ + 1, 1, 1, 2, 2, 3, 3, 4, 4, + } + + tl := 0 + w.Reset() + for i, v := range ints { + w.WriteVarInt(v) + tl += lens[i] + if tl != w.Len() { + t.Fatalf("expected len %d, got %d", tl, w.Len()) + } + } + + r.reset(w.Bytes()) + for _, v := range ints { + x, _ := r.readPacketLen() + if v != x { + t.Fatalf("expected %d, got %d", v, x) + } + } +} + +func testMQTTDefaultOptions() *Options { + o := DefaultOptions() + o.Cluster.Port = 0 + o.Gateway.Name = "" + o.Gateway.Port = 0 + o.LeafNode.Port = 0 + o.Websocket.Port = 0 + o.MQTT.Host = "127.0.0.1" + o.MQTT.Port = -1 + o.JetStream = true + return o +} + +func testMQTTRunServer(t testing.TB, o *Options) *Server { + o.NoLog = false + s, err := NewServer(o) + if err != nil { + t.Fatalf("Error creating server: %v", err) + } + l := &DummyLogger{} + s.SetLogger(l, true, true) + go s.Start() + if !s.ReadyForConnections(3 * time.Second) { + t.Fatal("Unable to start server") + } + return s +} + +func testMQTTShutdownServer(s *Server) { + if c := s.JetStreamConfig(); c != nil { + dir := strings.TrimSuffix(c.StoreDir, JetStreamStoreDir) + defer os.RemoveAll(dir) + } + s.Shutdown() +} + +func testMQTTDefaultTLSOptions(t *testing.T, verify bool) *Options { + t.Helper() + o := testMQTTDefaultOptions() + tc := &TLSConfigOpts{ + CertFile: "../test/configs/certs/server-cert.pem", + KeyFile: "../test/configs/certs/server-key.pem", + CaFile: "../test/configs/certs/ca.pem", + Verify: verify, + } + var err error + o.MQTT.TLSConfig, err = GenTLSConfig(tc) + o.MQTT.TLSTimeout = 2.0 + if err != nil { + t.Fatalf("Error creating tls config: %v", err) + } + return o +} + +func TestMQTTConfig(t *testing.T) { + conf := createConfFile(t, []byte(` + mqtt { + port: -1 + tls { + cert_file: "./configs/certs/server.pem" + key_file: "./configs/certs/key.pem" + } + } + `)) + defer os.Remove(conf) + s, o := RunServerWithConfig(conf) + defer testMQTTShutdownServer(s) + if o.MQTT.TLSConfig == nil { + t.Fatal("expected TLS config to be set") + } +} + +func TestMQTTValidateOptions(t *testing.T) { + nmqtto := DefaultOptions() + mqtto := testMQTTDefaultOptions() + for _, test := range []struct { + name string + getOpts func() *Options + err string + }{ + {"mqtt disabled", func() *Options { return nmqtto.Clone() }, ""}, + {"mqtt username not allowed if users specified", func() *Options { + o := mqtto.Clone() + o.Users = []*User{&User{Username: "abc", Password: "pwd"}} + o.MQTT.Username = "b" + o.MQTT.Password = "pwd" + return o + }, "mqtt authentication username not compatible with presence of users/nkeys"}, + {"mqtt token not allowed if users specified", func() *Options { + o := mqtto.Clone() + o.Nkeys = []*NkeyUser{&NkeyUser{Nkey: "abc"}} + o.MQTT.Token = "mytoken" + return o + }, "mqtt authentication token not compatible with presence of users/nkeys"}, + {"ack wait should be >=0", func() *Options { + o := mqtto.Clone() + o.MQTT.AckWait = -10 * time.Second + return o + }, "ack wait must be a positive value"}, + } { + t.Run(test.name, func(t *testing.T) { + err := validateMQTTOptions(test.getOpts()) + if test.err == "" && err != nil { + t.Fatalf("Unexpected error: %v", err) + } else if test.err != "" && (err == nil || !strings.Contains(err.Error(), test.err)) { + t.Fatalf("Expected error to contain %q, got %v", test.err, err) + } + }) + } +} + +func TestMQTTParseOptions(t *testing.T) { + for _, test := range []struct { + name string + content string + checkOpt func(*MQTTOpts) error + err string + }{ + // Negative tests + {"bad type", "mqtt: []", nil, "to be a map"}, + {"bad listen", "mqtt: { listen: [] }", nil, "port or host:port"}, + {"bad port", `mqtt: { port: "abc" }`, nil, "not int64"}, + {"bad host", `mqtt: { host: 123 }`, nil, "not string"}, + {"bad tls", `mqtt: { tls: 123 }`, nil, "not map[string]interface {}"}, + {"unknown field", `mqtt: { this_does_not_exist: 123 }`, nil, "unknown"}, + {"ack wait", `mqtt: {ack_wait: abc}`, nil, "invalid duration"}, + {"max ack pending", `mqtt: {max_ack_pending: abc}`, nil, "not int64"}, + {"max ack pending too high", `mqtt: {max_ack_pending: 12345678}`, nil, "invalid value"}, + // Positive tests + {"tls gen fails", ` + mqtt { + tls { + cert_file: "./configs/certs/server.pem" + } + }`, nil, "missing 'key_file'"}, + {"listen port only", `mqtt { listen: 1234 }`, func(o *MQTTOpts) error { + if o.Port != 1234 { + return fmt.Errorf("expected 1234, got %v", o.Port) + } + return nil + }, ""}, + {"listen host and port", `mqtt { listen: "localhost:1234" }`, func(o *MQTTOpts) error { + if o.Host != "localhost" || o.Port != 1234 { + return fmt.Errorf("expected localhost:1234, got %v:%v", o.Host, o.Port) + } + return nil + }, ""}, + {"host", `mqtt { host: "localhost" }`, func(o *MQTTOpts) error { + if o.Host != "localhost" { + return fmt.Errorf("expected localhost, got %v", o.Host) + } + return nil + }, ""}, + {"port", `mqtt { port: 1234 }`, func(o *MQTTOpts) error { + if o.Port != 1234 { + return fmt.Errorf("expected 1234, got %v", o.Port) + } + return nil + }, ""}, + {"tls config", + ` + mqtt { + tls { + cert_file: "./configs/certs/server.pem" + key_file: "./configs/certs/key.pem" + } + } + `, func(o *MQTTOpts) error { + if o.TLSConfig == nil { + return fmt.Errorf("TLSConfig should have been set") + } + return nil + }, ""}, + {"no auth user", + ` + mqtt { + no_auth_user: "noauthuser" + } + `, func(o *MQTTOpts) error { + if o.NoAuthUser != "noauthuser" { + return fmt.Errorf("Invalid NoAuthUser value: %q", o.NoAuthUser) + } + return nil + }, ""}, + {"auth block", + ` + mqtt { + authorization { + user: "mqttuser" + password: "pwd" + token: "token" + timeout: 2.0 + } + } + `, func(o *MQTTOpts) error { + if o.Username != "mqttuser" || o.Password != "pwd" || o.Token != "token" || o.AuthTimeout != 2.0 { + return fmt.Errorf("Invalid auth block: %+v", o) + } + return nil + }, ""}, + {"auth timeout as int", + ` + mqtt { + authorization { + timeout: 2 + } + } + `, func(o *MQTTOpts) error { + if o.AuthTimeout != 2.0 { + return fmt.Errorf("Invalid auth timeout: %v", o.AuthTimeout) + } + return nil + }, ""}, + {"ack wait", + ` + mqtt { + ack_wait: "10s" + } + `, func(o *MQTTOpts) error { + if o.AckWait != 10*time.Second { + return fmt.Errorf("Invalid ack wait: %v", o.AckWait) + } + return nil + }, ""}, + {"max ack pending", + ` + mqtt { + max_ack_pending: 123 + } + `, func(o *MQTTOpts) error { + if o.MaxAckPending != 123 { + return fmt.Errorf("Invalid max ack pending: %v", o.MaxAckPending) + } + return nil + }, ""}, + } { + t.Run(test.name, func(t *testing.T) { + conf := createConfFile(t, []byte(test.content)) + defer os.Remove(conf) + o, err := ProcessConfigFile(conf) + if test.err != _EMPTY_ { + if err == nil || !strings.Contains(err.Error(), test.err) { + t.Fatalf("For content: %q, expected error about %q, got %v", test.content, test.err, err) + } + return + } else if err != nil { + t.Fatalf("Unexpected error for content %q: %v", test.content, err) + } + if err := test.checkOpt(&o.MQTT); err != nil { + t.Fatalf("Incorrect option for content %q: %v", test.content, err.Error()) + } + }) + } +} + +func TestMQTTStart(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc, err := net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)) + if err != nil { + t.Fatalf("Unable to create tcp connection to mqtt port: %v", err) + } + nc.Close() + + // Check failure to start due to port in use + o2 := testMQTTDefaultOptions() + o2.MQTT.Port = o.MQTT.Port + s2, err := NewServer(o2) + if err != nil { + t.Fatalf("Error creating server: %v", err) + } + defer s2.Shutdown() + l := &captureFatalLogger{fatalCh: make(chan string, 1)} + s2.SetLogger(l, false, false) + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + s2.Start() + wg.Done() + }() + + select { + case e := <-l.fatalCh: + if !strings.Contains(e, "Unable to listen for MQTT connections") { + t.Fatalf("Unexpected error: %q", e) + } + case <-time.After(time.Second): + t.Fatal("Should have gotten a fatal error") + } +} + +func TestMQTTTLS(t *testing.T) { + o := testMQTTDefaultTLSOptions(t, false) + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc, err := net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)) + if err != nil { + t.Fatalf("Unable to create tcp connection to mqtt port: %v", err) + } + defer nc.Close() + // Set MaxVersion to TLSv1.2 so that we fail on handshake if there is + // a disagreement between server and client. + tlsc := &tls.Config{ + MaxVersion: tls.VersionTLS12, + InsecureSkipVerify: true, + } + tlsConn := tls.Client(nc, tlsc) + tlsConn.SetDeadline(time.Now().Add(time.Second)) + if err := tlsConn.Handshake(); err != nil { + t.Fatalf("Error doing tls handshake: %v", err) + } + nc.Close() + testMQTTShutdownServer(s) + + // Force client cert verification + o = testMQTTDefaultTLSOptions(t, true) + s = testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc, err = net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)) + if err != nil { + t.Fatalf("Unable to create tcp connection to mqtt port: %v", err) + } + defer nc.Close() + // Set MaxVersion to TLSv1.2 so that we fail on handshake if there is + // a disagreement between server and client. + tlsc = &tls.Config{ + MaxVersion: tls.VersionTLS12, + InsecureSkipVerify: true, + } + tlsConn = tls.Client(nc, tlsc) + tlsConn.SetDeadline(time.Now().Add(time.Second)) + if err := tlsConn.Handshake(); err == nil { + t.Fatal("Handshake expected to fail since client did not provide cert") + } + nc.Close() + + // Add client cert. + nc, err = net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)) + if err != nil { + t.Fatalf("Unable to create tcp connection to mqtt port: %v", err) + } + defer nc.Close() + + tc := &TLSConfigOpts{ + CertFile: "../test/configs/certs/client-cert.pem", + KeyFile: "../test/configs/certs/client-key.pem", + } + tlsc, err = GenTLSConfig(tc) + if err != nil { + t.Fatalf("Error generating tls config: %v", err) + } + tlsc.InsecureSkipVerify = true + tlsConn = tls.Client(nc, tlsc) + tlsConn.SetDeadline(time.Now().Add(time.Second)) + if err := tlsConn.Handshake(); err != nil { + t.Fatalf("Handshake error: %v", err) + } + nc.Close() + testMQTTShutdownServer(s) + + // Lower TLS timeout so low that we should fail + o.MQTT.TLSTimeout = 0.001 + s = testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc, err = net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)) + if err != nil { + t.Fatalf("Unable to create tcp connection to mqtt port: %v", err) + } + defer nc.Close() + time.Sleep(100 * time.Millisecond) + tlsConn = tls.Client(nc, tlsc) + tlsConn.SetDeadline(time.Now().Add(time.Second)) + if err := tlsConn.Handshake(); err == nil { + t.Fatal("Expected failure, did not get one") + } +} + +type mqttConnInfo struct { + clientID string + cleanSess bool + keepAlive uint16 + will *mqttWill + user string + pass string +} + +func testMQTTGetClient(t testing.TB, s *Server, clientID string) *client { + t.Helper() + var mc *client + s.mu.Lock() + for _, c := range s.clients { + c.mu.Lock() + if c.mqtt != nil && c.mqtt.cp != nil && c.mqtt.cp.clientID == clientID { + mc = c + } + c.mu.Unlock() + if mc != nil { + break + } + } + s.mu.Unlock() + if mc == nil { + t.Fatalf("Did not find client %q", clientID) + } + return mc +} + +func testMQTTRead(c net.Conn) ([]byte, error) { + var buf [512]byte + // Make sure that test does not block + c.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, err := c.Read(buf[:]) + if err != nil { + return nil, err + } + c.SetReadDeadline(time.Time{}) + return copyBytes(buf[:n]), nil +} + +func testMQTTWrite(c net.Conn, buf []byte) (int, error) { + c.SetWriteDeadline(time.Now().Add(2 * time.Second)) + n, err := c.Write(buf) + c.SetWriteDeadline(time.Time{}) + return n, err +} + +func testMQTTConnect(t testing.TB, ci *mqttConnInfo, host string, port int) (net.Conn, *mqttReader) { + t.Helper() + + addr := fmt.Sprintf("%s:%d", host, port) + c, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Error creating mqtt connection: %v", err) + } + + proto := mqttCreateConnectProto(ci) + if _, err := testMQTTWrite(c, proto); err != nil { + t.Fatalf("Error writing connect: %v", err) + } + + buf, err := testMQTTRead(c) + if err != nil { + t.Fatalf("Error reading: %v", err) + } + br := &mqttReader{reader: c} + br.reset(buf) + + return c, br +} + +func mqttCreateConnectProto(ci *mqttConnInfo) []byte { + flags := byte(0) + if ci.cleanSess { + flags |= mqttConnFlagCleanSession + } + if ci.will != nil { + flags |= mqttConnFlagWillFlag | (ci.will.qos << 3) + if ci.will.retain { + flags |= mqttConnFlagWillRetain + } + } + if ci.user != _EMPTY_ { + flags |= mqttConnFlagUsernameFlag + } + if ci.pass != _EMPTY_ { + flags |= mqttConnFlagPasswordFlag + } + + pkLen := 2 + len(mqttProtoName) + + 1 + // proto level + 1 + // flags + 2 + // keepAlive + 2 + len(ci.clientID) + + if ci.will != nil { + pkLen += 2 + len(ci.will.topic) + pkLen += 2 + len(ci.will.message) + } + if ci.user != _EMPTY_ { + pkLen += 2 + len(ci.user) + } + if ci.pass != _EMPTY_ { + pkLen += 2 + len(ci.pass) + } + + w := &mqttWriter{} + w.WriteByte(mqttPacketConnect) + w.WriteVarInt(pkLen) + w.WriteString(string(mqttProtoName)) + w.WriteByte(0x4) + w.WriteByte(flags) + w.WriteUint16(ci.keepAlive) + w.WriteString(ci.clientID) + if ci.will != nil { + w.WriteBytes(ci.will.topic) + w.WriteBytes(ci.will.message) + } + if ci.user != _EMPTY_ { + w.WriteString(ci.user) + } + if ci.pass != _EMPTY_ { + w.WriteBytes([]byte(ci.pass)) + } + return w.Bytes() +} + +func testMQTTCheckConnAck(t testing.TB, r *mqttReader, rc byte, sessionPresent bool) { + t.Helper() + r.reader.SetReadDeadline(time.Now().Add(2 * time.Second)) + if err := r.ensurePacketInBuffer(4); err != nil { + t.Fatalf("Error ensuring packet in buffer: %v", err) + } + r.reader.SetReadDeadline(time.Time{}) + b, err := r.readByte("connack packet type") + if err != nil { + t.Fatalf("Error reading packet type: %v", err) + } + pt := b & mqttPacketMask + if pt != mqttPacketConnectAck { + t.Fatalf("Expected ConnAck (%x), got %x", mqttPacketConnectAck, pt) + } + pl, err := r.readByte("connack packet len") + if err != nil { + t.Fatalf("Error reading packet length: %v", err) + } + if pl != 2 { + t.Fatalf("ConnAck packet length should be 2, got %v", pl) + } + caf, err := r.readByte("connack flags") + if err != nil { + t.Fatalf("Error reading packet length: %v", err) + } + if caf&0xfe != 0 { + t.Fatalf("ConnAck flag bits 7-1 should all be 0, got %x", caf>>1) + } + if sp := caf == 1; sp != sessionPresent { + t.Fatalf("Expected session present flag=%v got %v", sessionPresent, sp) + } + carc, err := r.readByte("connack return code") + if err != nil { + t.Fatalf("Error reading returned code: %v", err) + } + if carc != rc { + t.Fatalf("Expected return code to be %v, got %v", rc, carc) + } +} + +func testMQTTEnableJSForAccount(t *testing.T, s *Server, accName string) { + t.Helper() + acc, err := s.LookupAccount(accName) + if err != nil { + t.Fatalf("Error looking up account: %v", err) + } + limits := &JetStreamAccountLimits{ + MaxConsumers: -1, + MaxStreams: -1, + MaxMemory: 1024 * 1024, + } + if err := acc.EnableJetStream(limits); err != nil { + t.Fatalf("Error enabling JS: %v", err) + } +} + +func TestMQTTTLSVerifyAndMap(t *testing.T) { + accName := "MyAccount" + acc := NewAccount(accName) + certUserName := "CN=example.com,OU=NATS.io" + users := []*User{&User{Username: certUserName, Account: acc}} + + for _, test := range []struct { + name string + filtering bool + provideCert bool + }{ + {"no filtering, client provides cert", false, true}, + {"no filtering, client does not provide cert", false, false}, + {"filtering, client provides cert", true, true}, + {"filtering, client does not provide cert", true, false}, + } { + t.Run(test.name, func(t *testing.T) { + o := testMQTTDefaultOptions() + o.Host = "localhost" + o.Accounts = []*Account{acc} + o.Users = users + if test.filtering { + o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeMqtt}) + } + tc := &TLSConfigOpts{ + CertFile: "../test/configs/certs/tlsauth/server.pem", + KeyFile: "../test/configs/certs/tlsauth/server-key.pem", + CaFile: "../test/configs/certs/tlsauth/ca.pem", + Verify: true, + } + tlsc, err := GenTLSConfig(tc) + if err != nil { + t.Fatalf("Error creating tls config: %v", err) + } + o.MQTT.TLSConfig = tlsc + o.MQTT.TLSTimeout = 2.0 + o.MQTT.TLSMap = true + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + testMQTTEnableJSForAccount(t, s, accName) + + addr := fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port) + mc, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Error creating ws connection: %v", err) + } + defer mc.Close() + tlscc := &tls.Config{} + if test.provideCert { + tc := &TLSConfigOpts{ + CertFile: "../test/configs/certs/tlsauth/client.pem", + KeyFile: "../test/configs/certs/tlsauth/client-key.pem", + } + var err error + tlscc, err = GenTLSConfig(tc) + if err != nil { + t.Fatalf("Error generating tls config: %v", err) + } + } + tlscc.InsecureSkipVerify = true + if test.provideCert { + tlscc.MinVersion = tls.VersionTLS13 + } + mc = tls.Client(mc, tlscc) + if err := mc.(*tls.Conn).Handshake(); err != nil { + t.Fatalf("Error during handshake: %v", err) + } + + ci := &mqttConnInfo{cleanSess: true} + proto := mqttCreateConnectProto(ci) + if _, err := testMQTTWrite(mc, proto); err != nil { + t.Fatalf("Error sending proto: %v", err) + } + buf, err := testMQTTRead(mc) + if !test.provideCert { + if err == nil { + t.Fatal("Expected error, did not get one") + } else if !strings.Contains(err.Error(), "bad certificate") { + t.Fatalf("Unexpected error: %v", err) + } + return + } + if err != nil { + t.Fatalf("Error reading: %v", err) + } + r := &mqttReader{reader: mc} + r.reset(buf) + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + var c *client + s.mu.Lock() + for _, sc := range s.clients { + sc.mu.Lock() + if sc.mqtt != nil { + c = sc + } + sc.mu.Unlock() + if c != nil { + break + } + } + s.mu.Unlock() + if c == nil { + t.Fatal("Client not found") + } + + var uname string + var accname string + c.mu.Lock() + uname = c.opts.Username + if c.acc != nil { + accname = c.acc.GetName() + } + c.mu.Unlock() + if uname != certUserName { + t.Fatalf("Expected username %q, got %q", certUserName, uname) + } + if accname != accName { + t.Fatalf("Expected account %q, got %v", accName, accname) + } + }) + } +} + +func TestMQTTBasicAuth(t *testing.T) { + for _, test := range []struct { + name string + opts func() *Options + user string + pass string + rc byte + }{ + { + "top level auth, no override, wrong u/p", + func() *Options { + o := testMQTTDefaultOptions() + o.Username = "normal" + o.Password = "client" + return o + }, + "mqtt", "client", mqttConnAckRCNotAuthorized, + }, + { + "top level auth, no override, correct u/p", + func() *Options { + o := testMQTTDefaultOptions() + o.Username = "normal" + o.Password = "client" + return o + }, + "normal", "client", mqttConnAckRCConnectionAccepted, + }, + { + "no top level auth, mqtt auth, wrong u/p", + func() *Options { + o := testMQTTDefaultOptions() + o.MQTT.Username = "mqtt" + o.MQTT.Password = "client" + return o + }, + "normal", "client", mqttConnAckRCNotAuthorized, + }, + { + "no top level auth, mqtt auth, correct u/p", + func() *Options { + o := testMQTTDefaultOptions() + o.MQTT.Username = "mqtt" + o.MQTT.Password = "client" + return o + }, + "mqtt", "client", mqttConnAckRCConnectionAccepted, + }, + { + "top level auth, mqtt override, wrong u/p", + func() *Options { + o := testMQTTDefaultOptions() + o.Username = "normal" + o.Password = "client" + o.MQTT.Username = "mqtt" + o.MQTT.Password = "client" + return o + }, + "normal", "client", mqttConnAckRCNotAuthorized, + }, + { + "top level auth, mqtt override, correct u/p", + func() *Options { + o := testMQTTDefaultOptions() + o.Username = "normal" + o.Password = "client" + o.MQTT.Username = "mqtt" + o.MQTT.Password = "client" + return o + }, + "mqtt", "client", mqttConnAckRCConnectionAccepted, + }, + } { + t.Run(test.name, func(t *testing.T) { + o := test.opts() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{ + cleanSess: true, + user: test.user, + pass: test.pass, + } + mc, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, test.rc, false) + }) + } +} + +func TestMQTTAuthTimeout(t *testing.T) { + for _, test := range []struct { + name string + at float64 + mat float64 + ok bool + }{ + {"use top-level auth timeout", 0.5, 0.0, true}, + {"use mqtt auth timeout", 0.5, 0.05, false}, + } { + t.Run(test.name, func(t *testing.T) { + o := testMQTTDefaultOptions() + o.AuthTimeout = test.at + o.MQTT.Username = "mqtt" + o.MQTT.Password = "client" + o.MQTT.AuthTimeout = test.mat + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mc, err := net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)) + if err != nil { + t.Fatalf("Error connecting: %v", err) + } + defer mc.Close() + + time.Sleep(100 * time.Millisecond) + + ci := &mqttConnInfo{ + cleanSess: true, + user: "mqtt", + pass: "client", + } + proto := mqttCreateConnectProto(ci) + if _, err := testMQTTWrite(mc, proto); err != nil { + if test.ok { + t.Fatalf("Error sending connect: %v", err) + } + // else it is ok since we got disconnected due to auth timeout + return + } + buf, err := testMQTTRead(mc) + if err != nil { + if test.ok { + t.Fatalf("Error reading: %v", err) + } + // else it is ok since we got disconnected due to auth timeout + return + } + r := &mqttReader{reader: mc} + r.reset(buf) + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + time.Sleep(500 * time.Millisecond) + testMQTTPublish(t, mc, r, 1, false, false, "foo", 1, []byte("msg")) + }) + } +} + +func TestMQTTTokenAuth(t *testing.T) { + for _, test := range []struct { + name string + opts func() *Options + token string + rc byte + }{ + { + "top level auth, no override, wrong token", + func() *Options { + o := testMQTTDefaultOptions() + o.Authorization = "goodtoken" + return o + }, + "badtoken", mqttConnAckRCNotAuthorized, + }, + { + "top level auth, no override, correct token", + func() *Options { + o := testMQTTDefaultOptions() + o.Authorization = "goodtoken" + return o + }, + "goodtoken", mqttConnAckRCConnectionAccepted, + }, + { + "no top level auth, mqtt auth, wrong token", + func() *Options { + o := testMQTTDefaultOptions() + o.MQTT.Token = "goodtoken" + return o + }, + "badtoken", mqttConnAckRCNotAuthorized, + }, + { + "no top level auth, mqtt auth, correct token", + func() *Options { + o := testMQTTDefaultOptions() + o.MQTT.Token = "goodtoken" + return o + }, + "goodtoken", mqttConnAckRCConnectionAccepted, + }, + { + "top level auth, mqtt override, wrong token", + func() *Options { + o := testMQTTDefaultOptions() + o.Authorization = "clienttoken" + o.MQTT.Token = "mqtttoken" + return o + }, + "clienttoken", mqttConnAckRCNotAuthorized, + }, + { + "top level auth, mqtt override, correct token", + func() *Options { + o := testMQTTDefaultOptions() + o.Authorization = "clienttoken" + o.MQTT.Token = "mqtttoken" + return o + }, + "mqtttoken", mqttConnAckRCConnectionAccepted, + }, + } { + t.Run(test.name, func(t *testing.T) { + o := test.opts() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{ + cleanSess: true, + user: "ignore_use_token", + pass: test.token, + } + mc, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, test.rc, false) + }) + } +} + +func TestMQTTUsersAuth(t *testing.T) { + users := []*User{&User{Username: "user", Password: "pwd"}} + for _, test := range []struct { + name string + opts func() *Options + user string + pass string + rc byte + }{ + { + "no filtering, wrong user", + func() *Options { + o := testMQTTDefaultOptions() + o.Users = users + return o + }, + "wronguser", "pwd", mqttConnAckRCNotAuthorized, + }, + { + "no filtering, correct user", + func() *Options { + o := testMQTTDefaultOptions() + o.Users = users + return o + }, + "user", "pwd", mqttConnAckRCConnectionAccepted, + }, + { + "filtering, user not allowed", + func() *Options { + o := testMQTTDefaultOptions() + o.Users = users + // Only allowed for regular clients + o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard}) + return o + }, + "user", "pwd", mqttConnAckRCNotAuthorized, + }, + { + "filtering, user allowed", + func() *Options { + o := testMQTTDefaultOptions() + o.Users = users + o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeMqtt}) + return o + }, + "user", "pwd", mqttConnAckRCConnectionAccepted, + }, + { + "filtering, wrong password", + func() *Options { + o := testMQTTDefaultOptions() + o.Users = users + o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeMqtt}) + return o + }, + "user", "badpassword", mqttConnAckRCNotAuthorized, + }, + } { + t.Run(test.name, func(t *testing.T) { + o := test.opts() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{ + cleanSess: true, + user: test.user, + pass: test.pass, + } + mc, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, test.rc, false) + }) + } +} + +func TestMQTTNoAuthUserValidation(t *testing.T) { + o := testMQTTDefaultOptions() + o.Users = []*User{&User{Username: "user", Password: "pwd"}} + // Should fail because it is not part of o.Users. + o.MQTT.NoAuthUser = "notfound" + if _, err := NewServer(o); err == nil || !strings.Contains(err.Error(), "not present as user") { + t.Fatalf("Expected error saying not present as user, got %v", err) + } + + // Set a valid no auth user for global options, but still should fail because + // of o.MQTT.NoAuthUser + o.NoAuthUser = "user" + o.MQTT.NoAuthUser = "notfound" + if _, err := NewServer(o); err == nil || !strings.Contains(err.Error(), "not present as user") { + t.Fatalf("Expected error saying not present as user, got %v", err) + } +} + +func TestMQTTNoAuthUser(t *testing.T) { + for _, test := range []struct { + name string + override bool + useAuth bool + expectedUser string + expectedAcc string + }{ + {"no override, no user provided", false, false, "noauth", "normal"}, + {"no override, user povided", false, true, "user", "normal"}, + {"override, no user provided", true, false, "mqttnoauth", "mqtt"}, + {"override, user provided", true, true, "mqttuser", "mqtt"}, + } { + t.Run(test.name, func(t *testing.T) { + o := testMQTTDefaultOptions() + normalAcc := NewAccount("normal") + mqttAcc := NewAccount("mqtt") + o.Accounts = []*Account{normalAcc, mqttAcc} + o.Users = []*User{ + &User{Username: "noauth", Password: "pwd", Account: normalAcc}, + &User{Username: "user", Password: "pwd", Account: normalAcc}, + &User{Username: "mqttnoauth", Password: "pwd", Account: mqttAcc}, + &User{Username: "mqttuser", Password: "pwd", Account: mqttAcc}, + } + o.NoAuthUser = "noauth" + if test.override { + o.MQTT.NoAuthUser = "mqttnoauth" + } + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + testMQTTEnableJSForAccount(t, s, "normal") + testMQTTEnableJSForAccount(t, s, "mqtt") + + ci := &mqttConnInfo{clientID: "mqtt", cleanSess: true} + if test.useAuth { + ci.user = test.expectedUser + ci.pass = "pwd" + } + mc, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + c := testMQTTGetClient(t, s, "mqtt") + c.mu.Lock() + uname := c.opts.Username + aname := c.acc.GetName() + c.mu.Unlock() + if uname != test.expectedUser { + t.Fatalf("Expected selected user to be %q, got %q", test.expectedUser, uname) + } + if aname != test.expectedAcc { + t.Fatalf("Expected selected account to be %q, got %q", test.expectedAcc, aname) + } + }) + } +} + +func TestMQTTConnectNotFirstProto(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + c, err := net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)) + if err != nil { + t.Fatalf("Error on dial: %v", err) + } + defer c.Close() + + w := &mqttWriter{} + mqttWritePublish(w, 0, false, false, "foo", 0, []byte("hello")) + if _, err := testMQTTWrite(c, w.Bytes()); err != nil { + t.Fatalf("Error publishing: %v", err) + } + testMQTTExpectDisconnect(t, c) +} + +func TestMQTTSecondConnect(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + proto := mqttCreateConnectProto(&mqttConnInfo{cleanSess: true}) + if _, err := testMQTTWrite(mc, proto); err != nil { + t.Fatalf("Error writing connect: %v", err) + } + testMQTTExpectDisconnect(t, mc) +} + +func TestMQTTParseConnect(t *testing.T) { + eofr := testNewEOFReader() + for _, test := range []struct { + name string + proto []byte + pl int + reader mqttIOReader + err string + }{ + {"packet in buffer error", nil, 10, eofr, "error ensuring protocol is loaded"}, + {"bad proto name", []byte{0, 4, 'B', 'A', 'D'}, 5, nil, "protocol name"}, + {"invalid proto name", []byte{0, 3, 'B', 'A', 'D'}, 5, nil, "expected connect packet with protocol name"}, + {"old proto not supported", []byte{0, 6, 'M', 'Q', 'I', 's', 'd', 'p'}, 8, nil, "older protocol"}, + {"error on protocol level", []byte{0, 4, 'M', 'Q', 'T', 'T'}, 6, eofr, "protocol level"}, + {"unacceptable protocol version", []byte{0, 4, 'M', 'Q', 'T', 'T', 10}, 7, nil, "unacceptable protocol version"}, + {"error on flags", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel}, 7, eofr, "flags"}, + {"reserved flag", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 1}, 8, nil, "connect flags reserved bit not set to 0"}, + {"will qos without will flag", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 1 << 3}, 8, nil, "if Will flag is set to 0, Will QoS must be 0 too"}, + {"will retain without will flag", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 1 << 5}, 8, nil, "if Will flag is set to 0, Will Retain flag must be 0 too"}, + {"will qos", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 3<<3 | 1<<2}, 8, nil, "if Will flag is set to 1, Will QoS can be 0, 1 or 2"}, + {"no user but password", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagPasswordFlag}, 8, nil, "password flag set but username flag is not"}, + {"missing keep alive", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 0}, 8, nil, "keep alive"}, + {"missing client ID", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 0, 0, 1}, 10, nil, "client ID"}, + {"empty client ID", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 0, 0, 1, 0, 0}, 12, nil, "when client ID is empty, clean session flag must be set to 1"}, + {"invalid utf8 client ID", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 0, 0, 1, 0, 1, 241}, 13, nil, "invalid utf8 for client ID"}, + {"missing will topic", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagWillFlag | mqttConnFlagCleanSession, 0, 0, 0, 0}, 12, nil, "Will topic"}, + {"empty will topic", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagWillFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 0}, 14, nil, "empty Will topic not allowed"}, + {"invalid utf8 will topic", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagWillFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 1, 241}, 15, nil, "invalide utf8 for Will topic"}, + {"invalid wildcard will topic", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagWillFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 1, '#'}, 15, nil, "wildcards not allowed"}, + {"error on will message", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagWillFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 1, 'a', 0, 3}, 17, eofr, "Will message"}, + {"error on username", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagUsernameFlag | mqttConnFlagCleanSession, 0, 0, 0, 0}, 12, eofr, "user name"}, + {"empty username", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagUsernameFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 0}, 14, nil, "empty user name not allowed"}, + {"invalid utf8 username", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagUsernameFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 1, 241}, 15, nil, "invalid utf8 for user name"}, + {"error on password", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagUsernameFlag | mqttConnFlagPasswordFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 1, 'a'}, 15, eofr, "password"}, + } { + t.Run(test.name, func(t *testing.T) { + r := &mqttReader{reader: test.reader} + r.reset(test.proto) + mqtt := &mqtt{r: r} + c := &client{mqtt: mqtt} + if _, _, err := c.mqttParseConnect(r, test.pl); err == nil || !strings.Contains(err.Error(), test.err) { + t.Fatalf("Expected error %q, got %v", test.err, err) + } + }) + } +} + +func TestMQTTConnectFailsOnParse(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + addr := fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port) + c, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Error creating mqtt connection: %v", err) + } + + pkLen := 2 + len(mqttProtoName) + + 1 + // proto level + 1 + // flags + 2 + // keepAlive + 2 + len("mqtt") + + w := &mqttWriter{} + w.WriteByte(mqttPacketConnect) + w.WriteVarInt(pkLen) + w.WriteString(string(mqttProtoName)) + w.WriteByte(0x7) + w.WriteByte(mqttConnFlagCleanSession) + w.WriteUint16(0) + w.WriteString("mqtt") + c.Write(w.Bytes()) + + buf, err := testMQTTRead(c) + if err != nil { + t.Fatalf("Error reading: %v", err) + } + r := &mqttReader{reader: c} + r.reset(buf) + testMQTTCheckConnAck(t, r, mqttConnAckRCUnacceptableProtocolVersion, false) +} + +func TestMQTTConnKeepAlive(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true, keepAlive: 1}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, mc, r, 0, false, false, "foo", 0, []byte("msg")) + + time.Sleep(2 * time.Second) + testMQTTExpectDisconnect(t, mc) +} + +func TestMQTTTopicAndSubjectConversion(t *testing.T) { + for _, test := range []struct { + name string + mqttTopic string + natsSubject string + err string + }{ + {"/", "/", "/./", ""}, + {"//", "//", "/././", ""}, + {"///", "///", "/./././", ""}, + {"////", "////", "/././././", ""}, + {"foo", "foo", "foo", ""}, + {"/foo", "/foo", "/.foo", ""}, + {"//foo", "//foo", "/./.foo", ""}, + {"///foo", "///foo", "/././.foo", ""}, + {"///foo/", "///foo/", "/././.foo./", ""}, + {"///foo//", "///foo//", "/././.foo././", ""}, + {"///foo///", "///foo///", "/././.foo./././", ""}, + {"foo/bar", "foo/bar", "foo.bar", ""}, + {"/foo/bar", "/foo/bar", "/.foo.bar", ""}, + {"/foo/bar/", "/foo/bar/", "/.foo.bar./", ""}, + {"foo/bar/baz", "foo/bar/baz", "foo.bar.baz", ""}, + {"/foo/bar/baz", "/foo/bar/baz", "/.foo.bar.baz", ""}, + {"/foo/bar/baz/", "/foo/bar/baz/", "/.foo.bar.baz./", ""}, + {"bar", "bar/", "bar./", ""}, + {"bar//", "bar//", "bar././", ""}, + {"bar///", "bar///", "bar./././", ""}, + {"foo//bar", "foo//bar", "foo./.bar", ""}, + {"foo///bar", "foo///bar", "foo././.bar", ""}, + {"foo////bar", "foo////bar", "foo./././.bar", ""}, + // These should produce errors + {"foo/+", "foo/+", "", "wildcards not allowed in publish"}, + {"foo/#", "foo/#", "", "wildcards not allowed in publish"}, + {"foo bar", "foo bar", "", "not supported"}, + {"foo.bar", "foo.bar", "", "not supported"}, + } { + t.Run(test.name, func(t *testing.T) { + _, res, err := mqttTopicToNATSPubSubject([]byte(test.mqttTopic)) + if test.err != _EMPTY_ { + if err == nil || !strings.Contains(err.Error(), test.err) { + t.Fatalf("Expected error %q, got %q", test.err, err.Error()) + } + return + } + toNATS := string(res) + if toNATS != test.natsSubject { + t.Fatalf("Expected subject %q got %q", test.natsSubject, toNATS) + } + + res = natsSubjectToMQTTTopic(toNATS) + backToMQTT := string(res) + if backToMQTT != test.mqttTopic { + t.Fatalf("Expected topic %q got %q (NATS conversion was %q)", test.mqttTopic, backToMQTT, toNATS) + } + }) + } +} + +func TestMQTTFilterConversion(t *testing.T) { + // Similar to TopicConversion test except that wildcards are OK here. + // So testing only those. + for _, test := range []struct { + name string + mqttTopic string + natsSubject string + }{ + {"single level wildcard", "+", "*"}, + {"single level wildcard", "/+", "/.*"}, + {"single level wildcard", "+/", "*./"}, + {"single level wildcard", "/+/", "/.*./"}, + {"single level wildcard", "foo/+", "foo.*"}, + {"single level wildcard", "foo/+/", "foo.*./"}, + {"single level wildcard", "foo/+/bar", "foo.*.bar"}, + {"single level wildcard", "foo/+/+", "foo.*.*"}, + {"single level wildcard", "foo/+/+/", "foo.*.*./"}, + {"single level wildcard", "foo/+/+/bar", "foo.*.*.bar"}, + + {"multi level wildcard", "#", ">"}, + {"multi level wildcard", "/#", "/.>"}, + {"multi level wildcard", "/foo/#", "/.foo.>"}, + {"multi level wildcard", "foo/#", "foo.>"}, + {"multi level wildcard", "foo/bar/#", "foo.bar.>"}, + } { + t.Run(test.name, func(t *testing.T) { + _, res, err := mqttFilterToNATSSubject([]byte(test.mqttTopic)) + if err != nil { + t.Fatalf("Error: %v", err) + } + if string(res) != test.natsSubject { + t.Fatalf("Expected subject %q got %q", test.natsSubject, res) + } + }) + } +} + +func testMQTTReaderHasAtLeastOne(t testing.TB, r *mqttReader) { + t.Helper() + r.reader.SetReadDeadline(time.Now().Add(2 * time.Second)) + if err := r.ensurePacketInBuffer(1); err != nil { + t.Fatal(err) + } + r.reader.SetReadDeadline(time.Time{}) +} + +func TestMQTTParseSub(t *testing.T) { + eofr := testNewEOFReader() + for _, test := range []struct { + name string + proto []byte + b byte + pl int + reader mqttIOReader + err string + }{ + {"reserved flag", nil, 3, 0, nil, "wrong subscribe reserved flags"}, + {"ensure packet loaded", []byte{1, 2}, mqttSubscribeFlags, 10, eofr, "error ensuring protocol is loaded"}, + {"error reading packet id", []byte{1}, mqttSubscribeFlags, 1, eofr, "reading packet identifier"}, + {"missing filters", []byte{0, 1}, mqttSubscribeFlags, 2, nil, "subscribe protocol must contain at least 1 topic filter"}, + {"error reading topic", []byte{0, 1, 0, 2, 'a'}, mqttSubscribeFlags, 5, eofr, "topic filter"}, + {"empty topic", []byte{0, 1, 0, 0}, mqttSubscribeFlags, 4, nil, "topic filter cannot be empty"}, + {"invalid utf8 topic", []byte{0, 1, 0, 1, 241}, mqttSubscribeFlags, 5, nil, "invalid utf8 for topic filter"}, + {"missing qos", []byte{0, 1, 0, 1, 'a'}, mqttSubscribeFlags, 5, nil, "QoS"}, + {"invalid qos", []byte{0, 1, 0, 1, 'a', 3}, mqttSubscribeFlags, 6, nil, "subscribe QoS value must be 0, 1 or 2"}, + } { + t.Run(test.name, func(t *testing.T) { + r := &mqttReader{reader: test.reader} + r.reset(test.proto) + mqtt := &mqtt{r: r} + c := &client{mqtt: mqtt} + if _, _, err := c.mqttParseSubsOrUnsubs(r, test.b, test.pl, true); err == nil || !strings.Contains(err.Error(), test.err) { + t.Fatalf("Expected error %q, got %v", test.err, err) + } + }) + } +} + +func testMQTTSub(t testing.TB, pi uint16, c net.Conn, r *mqttReader, filters []*mqttFilter, expected []byte) { + t.Helper() + w := &mqttWriter{} + pkLen := 2 // for pi + for i := 0; i < len(filters); i++ { + f := filters[i] + pkLen += 2 + len(f.filter) + 1 + } + w.WriteByte(mqttPacketSub | mqttSubscribeFlags) + w.WriteVarInt(pkLen) + w.WriteUint16(pi) + for i := 0; i < len(filters); i++ { + f := filters[i] + w.WriteBytes([]byte(f.filter)) + w.WriteByte(f.qos) + } + if _, err := testMQTTWrite(c, w.Bytes()); err != nil { + t.Fatalf("Error writing SUBSCRIBE protocol: %v", err) + } + // Make sure we have at least 1 byte in buffer (if not will read) + testMQTTReaderHasAtLeastOne(t, r) + // Parse SUBACK + b, err := r.readByte("packet type") + if err != nil { + t.Fatal(err) + } + if pt := b & mqttPacketMask; pt != mqttPacketSubAck { + t.Fatalf("Expected SUBACK packet %x, got %x", mqttPacketSubAck, pt) + } + pl, err := r.readPacketLen() + if err != nil { + t.Fatal(err) + } + if err := r.ensurePacketInBuffer(pl); err != nil { + t.Fatal(err) + } + rpi, err := r.readUint16("packet identifier") + if err != nil || rpi != pi { + t.Fatalf("Error with packet identifier expected=%v got: %v err=%v", pi, rpi, err) + } + for i, rem := 0, pl-2; rem > 0; rem-- { + qos, err := r.readByte("filter qos") + if err != nil { + t.Fatal(err) + } + if qos != expected[i] { + t.Fatalf("For topic filter %q expected qos of %v, got %v", + filters[i].filter, expected[i], qos) + } + i++ + } +} + +func TestMQTTSubAck(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + subs := []*mqttFilter{ + {filter: "foo", qos: 0}, + {filter: "bar", qos: 1}, + {filter: "baz", qos: 2}, // Since we don't support, we should receive a result of 1 + {filter: "foo/#/bar", qos: 0}, // Invalid sub, so we should receive a result of mqttSubAckFailure + } + expected := []byte{ + 0, + 1, + 1, + mqttSubAckFailure, + } + testMQTTSub(t, 1, mc, r, subs, expected) +} + +func testMQTTFlush(t testing.TB, c net.Conn, bw *bufio.Writer, r *mqttReader) { + t.Helper() + w := &mqttWriter{} + w.WriteByte(mqttPacketPing) + w.WriteByte(0) + if bw != nil { + bw.Write(w.Bytes()) + bw.Flush() + } else { + c.Write(w.Bytes()) + } + r.ensurePacketInBuffer(2) + ab, err := r.readByte("pingresp") + if err != nil { + t.Fatalf("Error reading ping response: %v", err) + } + if pt := ab & mqttPacketMask; pt != mqttPacketPingResp { + t.Fatalf("Expected ping response got %x", pt) + } + l, err := r.readPacketLen() + if err != nil { + t.Fatal(err) + } + if l != 0 { + t.Fatalf("Expected PINGRESP length to be 0, got %v", l) + } +} + +func testMQTTExpectNothing(t testing.TB, r *mqttReader) { + t.Helper() + r.reader.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + if err := r.ensurePacketInBuffer(1); err == nil { + t.Fatalf("Expected nothing, got %v", r.buf[r.pos:]) + } + r.reader.SetReadDeadline(time.Time{}) +} + +func testMQTTCheckPubMsg(t testing.TB, c net.Conn, r *mqttReader, topic string, flags byte, payload []byte) { + t.Helper() + pflags, pi := testMQTTGetPubMsg(t, c, r, topic, payload) + if pflags != flags { + t.Fatalf("Expected flags to be %x, got %x", flags, pflags) + } + if pi > 0 { + testMQTTSendPubAck(t, c, pi) + } +} + +func testMQTTCheckPubMsgNoAck(t testing.TB, c net.Conn, r *mqttReader, topic string, flags byte, payload []byte) uint16 { + t.Helper() + pflags, pi := testMQTTGetPubMsg(t, c, r, topic, payload) + if pflags != flags { + t.Fatalf("Expected flags to be %x, got %x", flags, pflags) + } + return pi +} + +func testMQTTGetPubMsg(t testing.TB, c net.Conn, r *mqttReader, topic string, payload []byte) (byte, uint16) { + t.Helper() + testMQTTReaderHasAtLeastOne(t, r) + b, err := r.readByte("packet type") + if err != nil { + t.Fatal(err) + } + if pt := b & mqttPacketMask; pt != mqttPacketPub { + t.Fatalf("Expected PUBLISH packet %x, got %x", mqttPacketPub, pt) + } + pflags := b & mqttPacketFlagMask + qos := (pflags & mqttPubFlagQoS) >> 1 + pl, err := r.readPacketLen() + if err != nil { + t.Fatal(err) + } + if err := r.ensurePacketInBuffer(pl); err != nil { + t.Fatal(err) + } + start := r.pos + ptopic, err := r.readString("topic name") + if err != nil { + t.Fatal(err) + } + if ptopic != topic { + t.Fatalf("Expected topic %q, got %q", topic, ptopic) + } + var pi uint16 + if qos > 0 { + pi, err = r.readUint16("packet identifier") + if err != nil { + t.Fatal(err) + } + } + msgLen := pl - (r.pos - start) + if r.pos+msgLen > len(r.buf) { + t.Fatalf("computed message length goes beyond buffer: ml=%v pos=%v lenBuf=%v", + msgLen, r.pos, len(r.buf)) + } + ppayload := r.buf[r.pos : r.pos+msgLen] + if !bytes.Equal(payload, ppayload) { + t.Fatalf("Expected payload %q, got %q", payload, ppayload) + } + r.pos += msgLen + return pflags, pi +} + +func testMQTTSendPubAck(t testing.TB, c net.Conn, pi uint16) { + t.Helper() + w := &mqttWriter{} + w.WriteByte(mqttPacketPubAck) + w.WriteVarInt(2) + w.WriteUint16(pi) + if _, err := testMQTTWrite(c, w.Bytes()); err != nil { + t.Fatalf("Error writing PUBACK: %v", err) + } +} + +func testMQTTPublish(t testing.TB, c net.Conn, r *mqttReader, qos byte, dup, retain bool, topic string, pi uint16, payload []byte) { + t.Helper() + w := &mqttWriter{} + mqttWritePublish(w, qos, dup, retain, topic, pi, payload) + if _, err := testMQTTWrite(c, w.Bytes()); err != nil { + t.Fatalf("Error writing PUBLISH proto: %v", err) + } + if qos > 0 { + // Since we don't support QoS 2, we should get disconnected + if qos == 2 { + testMQTTExpectDisconnect(t, c) + return + } + testMQTTReaderHasAtLeastOne(t, r) + // Parse PUBACK + b, err := r.readByte("packet type") + if err != nil { + t.Fatal(err) + } + if pt := b & mqttPacketMask; pt != mqttPacketPubAck { + t.Fatalf("Expected PUBACK packet %x, got %x", mqttPacketPubAck, pt) + } + pl, err := r.readPacketLen() + if err != nil { + t.Fatal(err) + } + if err := r.ensurePacketInBuffer(pl); err != nil { + t.Fatal(err) + } + rpi, err := r.readUint16("packet identifier") + if err != nil || rpi != pi { + t.Fatalf("Error with packet identifier expected=%v got: %v err=%v", pi, rpi, err) + } + } +} + +func TestMQTTParsePub(t *testing.T) { + eofr := testNewEOFReader() + for _, test := range []struct { + name string + flags byte + proto []byte + pl int + reader mqttIOReader + err string + }{ + {"qos not supported", 0x4, nil, 0, nil, "not supported"}, + {"packet in buffer error", 0, nil, 10, eofr, "error ensuring protocol is loaded"}, + {"error on topic", 0, []byte{0, 3, 'f', 'o'}, 4, eofr, "topic"}, + {"empty topic", 0, []byte{0, 0}, 2, nil, "topic cannot be empty"}, + {"wildcards topic", 0, []byte{0, 1, '#'}, 3, nil, "wildcards not allowed"}, + {"error on packet identifier", mqttPubQos1, []byte{0, 3, 'f', 'o', 'o'}, 5, eofr, "packet identifier"}, + {"invalid packet identifier", mqttPubQos1, []byte{0, 3, 'f', 'o', 'o', 0, 0}, 7, nil, "packet identifier cannot be 0"}, + } { + t.Run(test.name, func(t *testing.T) { + r := &mqttReader{reader: test.reader} + r.reset(test.proto) + mqtt := &mqtt{r: r} + c := &client{mqtt: mqtt} + pp := &mqttPublish{flags: test.flags} + if err := c.mqttParsePub(r, test.pl, pp); err == nil || !strings.Contains(err.Error(), test.err) { + t.Fatalf("Expected error %q, got %v", test.err, err) + } + }) + } +} + +func TestMQTTParsePubAck(t *testing.T) { + eofr := testNewEOFReader() + for _, test := range []struct { + name string + proto []byte + pl int + reader mqttIOReader + err string + }{ + {"packet in buffer error", nil, 10, eofr, "error ensuring protocol is loaded"}, + {"error reading packet identifier", []byte{0}, 1, eofr, "packet identifier"}, + {"invalid packet identifier", []byte{0, 0}, 2, nil, "packet identifier cannot be 0"}, + } { + t.Run(test.name, func(t *testing.T) { + r := &mqttReader{reader: test.reader} + r.reset(test.proto) + if _, err := mqttParsePubAck(r, test.pl); err == nil || !strings.Contains(err.Error(), test.err) { + t.Fatalf("Expected error %q, got %v", test.err, err) + } + }) + } +} + +func TestMQTTPublish(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcp.Close() + testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, mcp, mpr, 0, false, false, "foo", 0, []byte("msg")) + testMQTTPublish(t, mcp, mpr, 1, false, false, "foo", 1, []byte("msg")) + testMQTTPublish(t, mcp, mpr, 2, false, false, "foo", 2, []byte("msg")) +} + +func TestMQTTSub(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcp.Close() + testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false) + + for _, test := range []struct { + name string + mqttSubTopic string + natsPubSubject string + mqttPubTopic string + ok bool + }{ + {"1 level match", "foo", "foo", "foo", true}, + {"1 level no match", "foo", "bar", "bar", false}, + {"2 levels match", "foo/bar", "foo.bar", "foo/bar", true}, + {"2 levels no match", "foo/bar", "foo.baz", "foo/baz", false}, + {"3 levels match", "/foo/bar", "/.foo.bar", "/foo/bar", true}, + {"3 levels no match", "/foo/bar", "/.foo.baz", "/foo/baz", false}, + + {"single level wc", "foo/+", "foo.bar.baz", "foo/bar/baz", false}, + {"single level wc", "foo/+", "foo.bar./", "foo/bar/", false}, + {"single level wc", "foo/+", "foo.bar", "foo/bar", true}, + {"single level wc", "foo/+", "foo./", "foo/", true}, + {"single level wc", "foo/+", "foo", "foo", false}, + {"single level wc", "foo/+", "/.foo", "/foo", false}, + + {"multiple level wc", "foo/#", "foo.bar.baz./", "foo/bar/baz/", true}, + {"multiple level wc", "foo/#", "foo.bar.baz", "foo/bar/baz", true}, + {"multiple level wc", "foo/#", "foo.bar./", "foo/bar/", true}, + {"multiple level wc", "foo/#", "foo.bar", "foo/bar", true}, + {"multiple level wc", "foo/#", "foo./", "foo/", true}, + {"multiple level wc", "foo/#", "foo", "foo", true}, + {"multiple level wc", "foo/#", "/.foo", "/foo", false}, + } { + t.Run(test.name, func(t *testing.T) { + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: test.mqttSubTopic, qos: 0}}, []byte{0}) + testMQTTFlush(t, mc, nil, r) + + natsPub(t, nc, test.natsPubSubject, []byte("msg")) + if test.ok { + testMQTTCheckPubMsg(t, mc, r, test.mqttPubTopic, 0, []byte("msg")) + } else { + testMQTTExpectNothing(t, r) + } + + testMQTTPublish(t, mcp, mpr, 0, false, false, test.mqttPubTopic, 0, []byte("msg")) + if test.ok { + testMQTTCheckPubMsg(t, mc, r, test.mqttPubTopic, 0, []byte("msg")) + } else { + testMQTTExpectNothing(t, r) + } + }) + } +} + +func TestMQTTSubQoS(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcp.Close() + testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false) + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + mqttTopic := "foo/bar" + + // Subscribe with QoS 1 + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/#", qos: 1}}, []byte{1}) + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: mqttTopic, qos: 1}}, []byte{1}) + testMQTTFlush(t, mc, nil, r) + + // Publish from NATS, which means QoS 0 + natsPub(t, nc, "foo.bar", []byte("NATS")) + // Will receive as QoS 0 + testMQTTCheckPubMsg(t, mc, r, mqttTopic, 0, []byte("NATS")) + testMQTTCheckPubMsg(t, mc, r, mqttTopic, 0, []byte("NATS")) + + // Publish from MQTT with QoS 0 + testMQTTPublish(t, mcp, mpr, 0, false, false, mqttTopic, 0, []byte("msg")) + // Will receive as QoS 0 + testMQTTCheckPubMsg(t, mc, r, mqttTopic, 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, mqttTopic, 0, []byte("msg")) + + // Publish from MQTT with QoS 1 + testMQTTPublish(t, mcp, mpr, 1, false, false, mqttTopic, 1, []byte("msg")) + pflags1, pi1 := testMQTTGetPubMsg(t, mc, r, mqttTopic, []byte("msg")) + if pflags1 != 0x2 { + t.Fatalf("Expected flags to be 0x2, got %v", pflags1) + } + pflags2, pi2 := testMQTTGetPubMsg(t, mc, r, mqttTopic, []byte("msg")) + if pflags2 != 0x2 { + t.Fatalf("Expected flags to be 0x2, got %v", pflags2) + } + if pi1 == pi2 { + t.Fatalf("packet identifier for message 1: %v should be different from message 2", pi1) + } + testMQTTSendPubAck(t, mc, pi1) + testMQTTSendPubAck(t, mc, pi2) +} + +func getSubQoS(sub *subscription) int { + if sub.mqtt != nil { + return int(sub.mqtt.qos) + } + return -1 +} + +func TestMQTTSubDups(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcp.Close() + testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false) + + mc, r := testMQTTConnect(t, &mqttConnInfo{clientID: "sub", user: "sub", cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + // Test with single SUBSCRIBE protocol but multiple filters + filters := []*mqttFilter{ + &mqttFilter{filter: "foo", qos: 1}, + &mqttFilter{filter: "foo", qos: 0}, + } + testMQTTSub(t, 1, mc, r, filters, []byte{1, 0}) + testMQTTFlush(t, mc, nil, r) + + // And also with separate SUBSCRIBE protocols + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "bar", qos: 0}}, []byte{0}) + // Ask for QoS 2 but server will downgrade to 1 + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "bar", qos: 2}}, []byte{1}) + testMQTTFlush(t, mc, nil, r) + + // Publish and test msg received only once + testMQTTPublish(t, mcp, r, 0, false, false, "foo", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg")) + testMQTTExpectNothing(t, r) + + testMQTTPublish(t, mcp, r, 0, false, false, "bar", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "bar", 0, []byte("msg")) + testMQTTExpectNothing(t, r) + + // Check that the QoS for subscriptions have been updated to the latest received filter + var err error + subc := testMQTTGetClient(t, s, "sub") + subc.mu.Lock() + if subc.opts.Username != "sub" { + err = fmt.Errorf("wrong user name") + } + if err == nil { + if sub := subc.subs["foo"]; sub == nil || getSubQoS(sub) != 0 { + err = fmt.Errorf("subscription foo QoS should be 0, got %v", getSubQoS(sub)) + } + } + if err == nil { + if sub := subc.subs["bar"]; sub == nil || getSubQoS(sub) != 1 { + err = fmt.Errorf("subscription bar QoS should be 1, got %v", getSubQoS(sub)) + } + } + subc.mu.Unlock() + if err != nil { + t.Fatal(err) + } + + // Now subscribe on "foo/#" which means that a PUBLISH on "foo" will be received + // by this subscription and also the one on "foo". + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/#", qos: 1}}, []byte{1}) + testMQTTFlush(t, mc, nil, r) + + // Publish and test msg received twice + testMQTTPublish(t, mcp, r, 0, false, false, "foo", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg")) + + checkWCSub := func(expectedQoS int) { + t.Helper() + + subc.mu.Lock() + defer subc.mu.Unlock() + + // When invoked with expectedQoS==1, we have the following subs: + // foo (QoS-0), bar (QoS-1), foo.> (QoS-1) + // which means (since QoS-1 have a JS consumer + sub for delivery + // and foo.> causes a "foo fwc") that we should have the following + // number of NATS subs: foo (1), bar (2), foo.> (2) and "foo fwc" (2), + // so total=7. + // When invoked with expectedQoS==0, it means that we have replaced + // foo/# QoS-1 to QoS-0, so we should have 2 less NATS subs, + // so total=5 + expected := 7 + if expectedQoS == 0 { + expected = 5 + } + if lenmap := len(subc.subs); lenmap != expected { + t.Fatalf("Subs map should have %v entries, got %v", expected, lenmap) + } + if sub, ok := subc.subs["foo.>"]; !ok { + t.Fatal("Expected sub foo.> to be present but was not") + } else if getSubQoS(sub) != expectedQoS { + t.Fatalf("Expected sub foo.> QoS to be %v, got %v", expectedQoS, getSubQoS(sub)) + } + if sub, ok := subc.subs["foo fwc"]; !ok { + t.Fatal("Expected sub foo fwc to be present but was not") + } else if getSubQoS(sub) != expectedQoS { + t.Fatalf("Expected sub foo fwc QoS to be %v, got %v", expectedQoS, getSubQoS(sub)) + } + // Make sure existing sub on "foo" qos was not changed. + if sub, ok := subc.subs["foo"]; !ok { + t.Fatal("Expected sub foo to be present but was not") + } else if getSubQoS(sub) != 0 { + t.Fatalf("Expected sub foo QoS to be 0, got %v", getSubQoS(sub)) + } + } + checkWCSub(1) + + // Sub again on same subject with lower QoS + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/#", qos: 0}}, []byte{0}) + testMQTTFlush(t, mc, nil, r) + + // Publish and test msg received twice + testMQTTPublish(t, mcp, r, 0, false, false, "foo", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg")) + checkWCSub(0) +} + +func TestMQTTSubWithSpaces(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcp.Close() + testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false) + + mc, r := testMQTTConnect(t, &mqttConnInfo{user: "sub", cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo bar", qos: 0}}, []byte{mqttSubAckFailure}) +} + +func TestMQTTSubCaseSensitive(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcp.Close() + testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false) + + mc, r := testMQTTConnect(t, &mqttConnInfo{user: "sub", cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "Foo/Bar", qos: 0}}, []byte{0}) + testMQTTFlush(t, mc, nil, r) + + testMQTTPublish(t, mcp, r, 0, false, false, "Foo/Bar", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "Foo/Bar", 0, []byte("msg")) + + testMQTTPublish(t, mcp, r, 0, false, false, "foo/bar", 0, []byte("msg")) + testMQTTExpectNothing(t, r) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + natsPub(t, nc, "Foo.Bar", []byte("nats")) + testMQTTCheckPubMsg(t, mc, r, "Foo/Bar", 0, []byte("nats")) + + natsPub(t, nc, "foo.bar", []byte("nats")) + testMQTTExpectNothing(t, r) +} + +func TestMQTTPubSubMatrix(t *testing.T) { + for _, test := range []struct { + name string + natsPub bool + mqttPub bool + mqttPubQoS byte + natsSub bool + mqttSubQoS0 bool + mqttSubQoS1 bool + }{ + {"NATS to MQTT sub QoS-0", true, false, 0, false, true, false}, + {"NATS to MQTT sub QoS-1", true, false, 0, false, false, true}, + {"NATS to MQTT sub QoS-0 and QoS-1", true, false, 0, false, true, true}, + + {"MQTT QoS-0 to NATS sub", false, true, 0, true, false, false}, + {"MQTT QoS-0 to MQTT sub QoS-0", false, true, 0, false, true, false}, + {"MQTT QoS-0 to MQTT sub QoS-1", false, true, 0, false, false, true}, + {"MQTT QoS-0 to NATS sub and MQTT sub QoS-0", false, true, 0, true, true, false}, + {"MQTT QoS-0 to NATS sub and MQTT sub QoS-1", false, true, 0, true, false, true}, + {"MQTT QoS-0 to all subs", false, true, 0, true, true, true}, + + {"MQTT QoS-1 to NATS sub", false, true, 1, true, false, false}, + {"MQTT QoS-1 to MQTT sub QoS-0", false, true, 1, false, true, false}, + {"MQTT QoS-1 to MQTT sub QoS-1", false, true, 1, false, false, true}, + {"MQTT QoS-1 to NATS sub and MQTT sub QoS-0", false, true, 1, true, true, false}, + {"MQTT QoS-1 to NATS sub and MQTT sub QoS-1", false, true, 1, true, false, true}, + {"MQTT QoS-1 to all subs", false, true, 1, true, true, true}, + } { + t.Run(test.name, func(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + mc1, r1 := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc1.Close() + testMQTTCheckConnAck(t, r1, mqttConnAckRCConnectionAccepted, false) + + mc2, r2 := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc2.Close() + testMQTTCheckConnAck(t, r2, mqttConnAckRCConnectionAccepted, false) + + // First setup subscriptions based on test options. + var ns *nats.Subscription + if test.natsSub { + ns = natsSubSync(t, nc, "foo") + } + if test.mqttSubQoS0 { + testMQTTSub(t, 1, mc1, r1, []*mqttFilter{&mqttFilter{filter: "foo", qos: 0}}, []byte{0}) + testMQTTFlush(t, mc1, nil, r1) + } + if test.mqttSubQoS1 { + testMQTTSub(t, 1, mc2, r2, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTFlush(t, mc2, nil, r2) + } + + // Just as a barrier + natsFlush(t, nc) + + // Now publish + if test.natsPub { + natsPubReq(t, nc, "foo", "", []byte("msg")) + } else { + testMQTTPublish(t, mc, r, test.mqttPubQoS, false, false, "foo", 1, []byte("msg")) + } + + // Check message received + if test.natsSub { + natsNexMsg(t, ns, time.Second) + // Make sure no other is received + if msg, err := ns.NextMsg(50 * time.Millisecond); err == nil { + t.Fatalf("Should not have gotten a second message, got %v", msg) + } + } + if test.mqttSubQoS0 { + testMQTTCheckPubMsg(t, mc1, r1, "foo", 0, []byte("msg")) + testMQTTExpectNothing(t, r1) + } + if test.mqttSubQoS1 { + var expectedFlag byte + if test.mqttPubQoS > 0 { + expectedFlag = test.mqttPubQoS << 1 + } + testMQTTCheckPubMsg(t, mc2, r2, "foo", expectedFlag, []byte("msg")) + testMQTTExpectNothing(t, r2) + } + }) + } +} + +func TestMQTTPreventSubWithMQTTSubPrefix(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, mc, r, + []*mqttFilter{&mqttFilter{filter: strings.ReplaceAll(mqttSubPrefix, ".", "/") + "foo/bar", qos: 1}}, + []byte{mqttSubAckFailure}) +} + +func TestMQTTSubWithNATSStream(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/bar", qos: 1}}, []byte{1}) + testMQTTFlush(t, mc, nil, r) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + sc := &StreamConfig{ + Name: "test", + Storage: FileStorage, + Retention: InterestPolicy, + Subjects: []string{"foo.>"}, + } + mset, err := s.GlobalAccount().AddStream(sc) + if err != nil { + t.Fatalf("Unable to create stream: %v", err) + } + + sub := natsSubSync(t, nc, "bar") + cc := &ConsumerConfig{ + Durable: "dur", + AckPolicy: AckExplicit, + DeliverSubject: "bar", + } + if _, err := mset.AddConsumer(cc); err != nil { + t.Fatalf("Unable to add consumer: %v", err) + } + + // Now send message from NATS + resp, err := nc.Request("foo.bar", []byte("nats"), time.Second) + if err != nil { + t.Fatalf("Error publishing: %v", err) + } + ar := &ApiResponse{} + if err := json.Unmarshal(resp.Data, ar); err != nil || ar.Error != nil { + t.Fatalf("Unexpected response: err=%v resp=%+v", err, ar.Error) + } + + // Check that message is received by both + checkRecv := func(content string, flags byte) { + t.Helper() + if msg := natsNexMsg(t, sub, time.Second); string(msg.Data) != content { + t.Fatalf("Expected %q, got %q", content, msg.Data) + } + testMQTTCheckPubMsg(t, mc, r, "foo/bar", flags, []byte(content)) + } + checkRecv("nats", 0) + + // Send from MQTT as a QoS0 + testMQTTPublish(t, mc, r, 0, false, false, "foo/bar", 0, []byte("qos0")) + checkRecv("qos0", 0) + + // Send from MQTT as a QoS1 + testMQTTPublish(t, mc, r, 1, false, false, "foo/bar", 1, []byte("qos1")) + checkRecv("qos1", mqttPubQos1) +} + +func TestMQTTTrackPendingOverrun(t *testing.T) { + sess := &mqttSession{pending: make(map[uint16]*mqttPending)} + sub := &subscription{mqtt: &mqttSub{qos: 1}} + + sess.ppi = 0xFFFF + pi, _ := sess.trackPending(1, _EMPTY_, sub) + if pi != 1 { + t.Fatalf("Expected 1, got %v", pi) + } + + p := &mqttPending{} + for i := 1; i <= 0xFFFF; i++ { + sess.pending[uint16(i)] = p + } + pi, _ = sess.trackPending(1, _EMPTY_, sub) + if pi != 0 { + t.Fatalf("Expected 0, got %v", pi) + } + + delete(sess.pending, 1234) + pi, _ = sess.trackPending(1, _EMPTY_, sub) + if pi != 1234 { + t.Fatalf("Expected 1234, got %v", pi) + } +} + +func TestMQTTPreventStreamAndConsumerWithMQTTPrefix(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + sc := &StreamConfig{ + Name: mqttStreamNamePrefix + "test", + Storage: FileStorage, + Retention: InterestPolicy, + Subjects: []string{"foo.>"}, + } + if _, err := s.GlobalAccount().AddStream(sc); err == nil { + t.Fatal("Expected error") + } + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/bar", qos: 1}}, []byte{1}) + testMQTTFlush(t, mc, nil, r) + + mset, err := s.GlobalAccount().LookupStream(mqttStreamName) + if err != nil { + t.Fatalf("Error looking up MQTT Stream: %v", err) + } + cc := &ConsumerConfig{ + Durable: "dur", + AckPolicy: AckExplicit, + DeliverSubject: "bar", + } + if _, err := mset.AddConsumer(cc); err == nil { + t.Fatal("Expected error") + } +} + +func TestMQTTSubRestart(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + mc, r := testMQTTConnect(t, &mqttConnInfo{clientID: "sub", cleanSess: false}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + // Start an MQTT subscription QoS=1 on "foo" + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTFlush(t, mc, nil, r) + + // Now start a NATS subscription on ">" (anything that would match the JS consumer delivery subject) + natsSubSync(t, nc, ">") + natsFlush(t, nc) + + // Restart the MQTT client + testMQTTDisconnect(t, mc, nil) + + mc, r = testMQTTConnect(t, &mqttConnInfo{clientID: "sub", cleanSess: false}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + + // Restart an MQTT subscription QoS=1 on "foo" + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTFlush(t, mc, nil, r) +} + +func TestMQTTSubPropagation(t *testing.T) { + t.Skip("Skipping until JS clustering is supported") + o := testMQTTDefaultOptions() + o.Cluster.Host = "127.0.0.1" + o.Cluster.Port = -1 + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + o2 := DefaultOptions() + o2.Routes = RoutesFromStr(fmt.Sprintf("nats://127.0.0.1:%d", o.Cluster.Port)) + s2 := RunServer(o2) + defer s2.Shutdown() + + checkClusterFormed(t, s, s2) + + nc := natsConnect(t, s2.ClientURL()) + defer nc.Close() + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/#", qos: 0}}, []byte{0}) + testMQTTFlush(t, mc, nil, r) + + // Because in MQTT foo/# means foo.> but also foo, check that this is propagated + checkSubInterest(t, s2, globalAccountName, "foo", time.Second) + + // Publish on foo.bar, foo./ and foo and we should receive them + natsPub(t, nc, "foo.bar", []byte("hello")) + testMQTTCheckPubMsg(t, mc, r, "foo/bar", 0, []byte("hello")) + + natsPub(t, nc, "foo./", []byte("from")) + testMQTTCheckPubMsg(t, mc, r, "foo/", 0, []byte("from")) + + natsPub(t, nc, "foo", []byte("NATS")) + testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("NATS")) +} + +func TestMQTTParseUnsub(t *testing.T) { + eofr := testNewEOFReader() + for _, test := range []struct { + name string + proto []byte + b byte + pl int + reader mqttIOReader + err string + }{ + {"reserved flag", nil, 3, 0, nil, "wrong unsubscribe reserved flags"}, + {"ensure packet loaded", []byte{1, 2}, mqttUnsubscribeFlags, 10, eofr, "error ensuring protocol is loaded"}, + {"error reading packet id", []byte{1}, mqttUnsubscribeFlags, 1, eofr, "reading packet identifier"}, + {"missing filters", []byte{0, 1}, mqttUnsubscribeFlags, 2, nil, "subscribe protocol must contain at least 1 topic filter"}, + {"error reading topic", []byte{0, 1, 0, 2, 'a'}, mqttUnsubscribeFlags, 5, eofr, "topic filter"}, + {"empty topic", []byte{0, 1, 0, 0}, mqttUnsubscribeFlags, 4, nil, "topic filter cannot be empty"}, + {"invalid utf8 topic", []byte{0, 1, 0, 1, 241}, mqttUnsubscribeFlags, 5, nil, "invalid utf8 for topic filter"}, + } { + t.Run(test.name, func(t *testing.T) { + r := &mqttReader{reader: test.reader} + r.reset(test.proto) + mqtt := &mqtt{r: r} + c := &client{mqtt: mqtt} + if _, _, err := c.mqttParseSubsOrUnsubs(r, test.b, test.pl, false); err == nil || !strings.Contains(err.Error(), test.err) { + t.Fatalf("Expected error %q, got %v", test.err, err) + } + }) + } +} + +func testMQTTUnsub(t *testing.T, pi uint16, c net.Conn, r *mqttReader, filters []*mqttFilter) { + t.Helper() + w := &mqttWriter{} + pkLen := 2 // for pi + for i := 0; i < len(filters); i++ { + f := filters[i] + pkLen += 2 + len(f.filter) + } + w.WriteByte(mqttPacketUnsub | mqttUnsubscribeFlags) + w.WriteVarInt(pkLen) + w.WriteUint16(pi) + for i := 0; i < len(filters); i++ { + f := filters[i] + w.WriteBytes([]byte(f.filter)) + } + if _, err := testMQTTWrite(c, w.Bytes()); err != nil { + t.Fatalf("Error writing UNSUBSCRIBE protocol: %v", err) + } + // Make sure we have at least 1 byte in buffer (if not will read) + testMQTTReaderHasAtLeastOne(t, r) + // Parse UNSUBACK + b, err := r.readByte("packet type") + if err != nil { + t.Fatal(err) + } + if pt := b & mqttPacketMask; pt != mqttPacketUnsubAck { + t.Fatalf("Expected UNSUBACK packet %x, got %x", mqttPacketUnsubAck, pt) + } + pl, err := r.readPacketLen() + if err != nil { + t.Fatal(err) + } + if err := r.ensurePacketInBuffer(pl); err != nil { + t.Fatal(err) + } + rpi, err := r.readUint16("packet identifier") + if err != nil || rpi != pi { + t.Fatalf("Error with packet identifier expected=%v got: %v err=%v", pi, rpi, err) + } +} + +func TestMQTTUnsub(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcp.Close() + testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false) + + mc, r := testMQTTConnect(t, &mqttConnInfo{user: "sub", cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 0}}, []byte{0}) + testMQTTFlush(t, mc, nil, r) + + // Publish and test msg received + testMQTTPublish(t, mcp, r, 0, false, false, "foo", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg")) + + // Unsubscribe + testMQTTUnsub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo"}}) + + // Publish and test msg not received + testMQTTPublish(t, mcp, r, 0, false, false, "foo", 0, []byte("msg")) + testMQTTExpectNothing(t, r) + + // Use of wildcards subs + filters := []*mqttFilter{ + &mqttFilter{filter: "foo/bar", qos: 0}, + &mqttFilter{filter: "foo/#", qos: 0}, + } + testMQTTSub(t, 1, mc, r, filters, []byte{0, 0}) + testMQTTFlush(t, mc, nil, r) + + // Publish and check that message received twice + testMQTTPublish(t, mcp, r, 0, false, false, "foo/bar", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo/bar", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo/bar", 0, []byte("msg")) + + // Unsub the wildcard one + testMQTTUnsub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/#"}}) + // Publish and check that message received once + testMQTTPublish(t, mcp, r, 0, false, false, "foo/bar", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo/bar", 0, []byte("msg")) + testMQTTExpectNothing(t, r) + + // Unsub last + testMQTTUnsub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/bar"}}) + // Publish and test msg not received + testMQTTPublish(t, mcp, r, 0, false, false, "foo/bar", 0, []byte("msg")) + testMQTTExpectNothing(t, r) +} + +func testMQTTExpectDisconnect(t testing.TB, c net.Conn) { + if buf, err := testMQTTRead(c); err == nil { + t.Fatalf("Expected connection to be disconnected, got %s", buf) + } +} + +func TestMQTTPublishTopicErrors(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + for _, test := range []struct { + name string + topic string + }{ + {"empty", ""}, + {"with single level wildcard", "foo/+"}, + {"with multiple level wildcard", "foo/#"}, + } { + t.Run(test.name, func(t *testing.T) { + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, mc, r, 0, false, false, test.topic, 0, []byte("msg")) + testMQTTExpectDisconnect(t, mc) + }) + } +} + +func testMQTTDisconnect(t testing.TB, c net.Conn, bw *bufio.Writer) { + t.Helper() + w := &mqttWriter{} + w.WriteByte(mqttPacketDisconnect) + w.WriteByte(0) + if bw != nil { + bw.Write(w.Bytes()) + bw.Flush() + } else { + c.Write(w.Bytes()) + } + testMQTTExpectDisconnect(t, c) +} + +func TestMQTTWill(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + sub := natsSubSync(t, nc, "will.topic") + + willMsg := []byte("bye") + + for _, test := range []struct { + name string + willExpected bool + willQoS byte + }{ + {"will qos 0", true, 0}, + {"will qos 1", true, 1}, + {"proper disconnect no will", false, 0}, + } { + t.Run(test.name, func(t *testing.T) { + mcs, rs := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcs.Close() + testMQTTCheckConnAck(t, rs, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mcs, rs, []*mqttFilter{&mqttFilter{filter: "will/#", qos: 1}}, []byte{1}) + testMQTTFlush(t, mcs, nil, rs) + + mc, r := testMQTTConnect(t, + &mqttConnInfo{ + cleanSess: true, + will: &mqttWill{ + topic: []byte("will/topic"), + message: willMsg, + qos: test.willQoS, + }, + }, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + if test.willExpected { + mc.Close() + testMQTTCheckPubMsg(t, mcs, rs, "will/topic", test.willQoS<<1, willMsg) + wm := natsNexMsg(t, sub, time.Second) + if !bytes.Equal(wm.Data, willMsg) { + t.Fatalf("Expected will message to be %q, got %q", willMsg, wm.Data) + } + } else { + testMQTTDisconnect(t, mc, nil) + testMQTTExpectNothing(t, rs) + if wm, err := sub.NextMsg(100 * time.Millisecond); err == nil { + t.Fatalf("Should not have receive a message, got %v", wm) + } + } + }) + } +} + +func TestMQTTWillRetain(t *testing.T) { + for _, test := range []struct { + name string + pubQoS byte + subQoS byte + }{ + {"pub QoS0 sub QoS0", 0, 0}, + {"pub QoS0 sub QoS1", 0, 1}, + {"pub QoS1 sub QoS0", 1, 0}, + {"pub QoS1 sub QoS1", 1, 1}, + } { + t.Run(test.name, func(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + willTopic := []byte("will/topic") + willMsg := []byte("bye") + + mc, r := testMQTTConnect(t, + &mqttConnInfo{ + cleanSess: true, + will: &mqttWill{ + topic: willTopic, + message: willMsg, + qos: test.pubQoS, + retain: true, + }, + }, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + // Disconnect, which will cause will to be produced with retain flag. + mc.Close() + + // Create subscription on will topic and expect will message. + mcs, rs := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcs.Close() + testMQTTCheckConnAck(t, rs, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mcs, rs, []*mqttFilter{&mqttFilter{filter: "will/#", qos: test.subQoS}}, []byte{test.subQoS}) + pflags, _ := testMQTTGetPubMsg(t, mcs, rs, "will/topic", willMsg) + if pflags&mqttPubFlagRetain == 0 { + t.Fatalf("expected retain flag to be set, it was not: %v", pflags) + } + // Expected QoS will be the lesser of the pub/sub QoS. + expectedQoS := test.pubQoS + if test.subQoS == 0 { + expectedQoS = 0 + } + if qos := mqttGetQoS(pflags); qos != expectedQoS { + t.Fatalf("expected qos to be %v, got %v", expectedQoS, qos) + } + }) + } +} + +func TestMQTTWillRetainPermViolation(t *testing.T) { + template := ` + port: -1 + jetstream: enabled + authorization { + mqtt_perms = { + publish = ["%s"] + subscribe = ["foo", "bar", "$MQTT.sub.>"] + } + users = [ + {user: mqtt, password: pass, permissions: $mqtt_perms} + ] + } + mqtt { + port: -1 + } + ` + conf := createConfFile(t, []byte(fmt.Sprintf(template, "foo"))) + defer os.Remove(conf) + + s, o := RunServerWithConfig(conf) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{ + cleanSess: true, + user: "mqtt", + pass: "pass", + } + + // We create first a connection with the Will topic that the publisher + // is allowed to publish to. + ci.will = &mqttWill{ + topic: []byte("foo"), + message: []byte("bye"), + qos: 1, + retain: true, + } + mc, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + // Disconnect, which will cause the Will to be sent with retain flag. + mc.Close() + + // Create a subscription on the Will subject and we should receive it. + ci.will = nil + mcs, rs := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mcs.Close() + testMQTTCheckConnAck(t, rs, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mcs, rs, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + pflags, _ := testMQTTGetPubMsg(t, mcs, rs, "foo", []byte("bye")) + if pflags&mqttPubFlagRetain == 0 { + t.Fatalf("expected retain flag to be set, it was not: %v", pflags) + } + if qos := mqttGetQoS(pflags); qos != 1 { + t.Fatalf("expected qos to be 1, got %v", qos) + } + testMQTTDisconnect(t, mcs, nil) + + // Now create another connection with a Will that client is not allowed to publish to. + ci.will = &mqttWill{ + topic: []byte("bar"), + message: []byte("bye"), + qos: 1, + retain: true, + } + mc, r = testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + // Disconnect, to cause Will to be produced, but in that case should not be stored + // since user not allowed to publish on "bar". + mc.Close() + + // Create sub on "bar" which user is allowed to subscribe to. + ci.will = nil + mcs, rs = testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mcs.Close() + testMQTTCheckConnAck(t, rs, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mcs, rs, []*mqttFilter{&mqttFilter{filter: "bar", qos: 1}}, []byte{1}) + // No Will should be published since it should not have been stored in the first place. + testMQTTExpectNothing(t, rs) + testMQTTDisconnect(t, mcs, nil) + + // Now remove permission to publish on "foo" and check that a new subscription + // on "foo" is now not getting the will message because the original user no + // longer has permission to do so. + reloadUpdateConfig(t, s, conf, fmt.Sprintf(template, "baz")) + + mcs, rs = testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mcs.Close() + testMQTTCheckConnAck(t, rs, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mcs, rs, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTExpectNothing(t, rs) + testMQTTDisconnect(t, mcs, nil) +} + +func TestMQTTPublishRetain(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + for _, test := range []struct { + name string + retained bool + sentValue string + expectedValue string + subGetsIt bool + }{ + {"publish retained", true, "retained", "retained", true}, + {"publish not retained", false, "not retained", "retained", true}, + {"remove retained", true, "", "", false}, + } { + t.Run(test.name, func(t *testing.T) { + mc1, rs1 := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc1.Close() + testMQTTCheckConnAck(t, rs1, mqttConnAckRCConnectionAccepted, false) + testMQTTPublish(t, mc1, rs1, 0, false, test.retained, "foo", 0, []byte(test.sentValue)) + + mc2, rs2 := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc2.Close() + testMQTTCheckConnAck(t, rs2, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mc2, rs2, []*mqttFilter{&mqttFilter{filter: "foo/#", qos: 1}}, []byte{1}) + + if test.subGetsIt { + pflags, _ := testMQTTGetPubMsg(t, mc2, rs2, "foo", []byte(test.expectedValue)) + if pflags&mqttPubFlagRetain == 0 { + t.Fatalf("retain flag should have been set, it was not: flags=%v", pflags) + } + } else { + testMQTTExpectNothing(t, rs2) + } + + testMQTTDisconnect(t, mc1, nil) + testMQTTDisconnect(t, mc2, nil) + }) + } +} + +func TestMQTTPublishRetainPermViolation(t *testing.T) { + o := testMQTTDefaultOptions() + o.Users = []*User{ + { + Username: "mqtt", + Password: "pass", + Permissions: &Permissions{ + Publish: &SubjectPermission{Allow: []string{"foo"}}, + Subscribe: &SubjectPermission{Allow: []string{"bar", "$MQTT.sub.>"}}, + }, + }, + } + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{ + cleanSess: true, + user: "mqtt", + pass: "pass", + } + + mc1, rs1 := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc1.Close() + testMQTTCheckConnAck(t, rs1, mqttConnAckRCConnectionAccepted, false) + testMQTTPublish(t, mc1, rs1, 0, false, true, "bar", 0, []byte("retained")) + + mc2, rs2 := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc2.Close() + testMQTTCheckConnAck(t, rs2, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mc2, rs2, []*mqttFilter{&mqttFilter{filter: "bar", qos: 1}}, []byte{1}) + testMQTTExpectNothing(t, rs2) + + testMQTTDisconnect(t, mc1, nil) + testMQTTDisconnect(t, mc2, nil) +} + +func TestMQTTCleanSession(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{ + clientID: "me", + cleanSess: false, + } + c, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTDisconnect(t, c, nil) + + c, r = testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + testMQTTDisconnect(t, c, nil) + + ci.cleanSess = true + c, r = testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTDisconnect(t, c, nil) +} + +func TestMQTTDuplicateClientID(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{ + clientID: "me", + cleanSess: false, + } + c1, r1 := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer c1.Close() + testMQTTCheckConnAck(t, r1, mqttConnAckRCConnectionAccepted, false) + + c2, r2 := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer c2.Close() + testMQTTCheckConnAck(t, r2, mqttConnAckRCConnectionAccepted, true) + + // The old client should be disconnected. + testMQTTExpectDisconnect(t, c1) +} + +func TestMQTTPersistedSession(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer func() { + testMQTTShutdownServer(s) + }() + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, c, r, + []*mqttFilter{ + &mqttFilter{filter: "foo/#", qos: 1}, + &mqttFilter{filter: "bar", qos: 1}, + &mqttFilter{filter: "baz", qos: 0}, + }, + []byte{1, 1, 0}) + testMQTTFlush(t, c, nil, r) + + // Shutdown server, close connection and restart server. It should + // have restored the session and consumers. + dir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir) + s.Shutdown() + c.Close() + + o.Port = -1 + o.MQTT.Port = -1 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // There is already the defer for shutdown at top of function + + // Create a publisher that will send qos1 so we verify that messages + // are stored for the persisted sessions. + c, r = testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, c, r, 1, false, false, "foo/bar", 1, []byte("msg0")) + testMQTTFlush(t, c, nil, r) + testMQTTDisconnect(t, c, nil) + c.Close() + + // Recreate consumer session + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + + // Since consumers have been recovered, messages should be received + // (MQTT does not need client to recreate consumers for a recovered + // session) + + // Check that qos1 publish message is received. + testMQTTCheckPubMsg(t, c, r, "foo/bar", mqttPubQos1, []byte("msg0")) + + // Now publish some messages to all subscriptions. + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + natsPub(t, nc, "foo.bar", []byte("msg1")) + testMQTTCheckPubMsg(t, c, r, "foo/bar", 0, []byte("msg1")) + + natsPub(t, nc, "foo", []byte("msg2")) + testMQTTCheckPubMsg(t, c, r, "foo", 0, []byte("msg2")) + + natsPub(t, nc, "bar", []byte("msg3")) + testMQTTCheckPubMsg(t, c, r, "bar", 0, []byte("msg3")) + + natsPub(t, nc, "baz", []byte("msg4")) + testMQTTCheckPubMsg(t, c, r, "baz", 0, []byte("msg4")) + + // Now unsub "bar" and verify that message published on this topic + // is not received. + testMQTTUnsub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "bar"}}) + natsPub(t, nc, "bar", []byte("msg5")) + testMQTTExpectNothing(t, r) + + nc.Close() + s.Shutdown() + c.Close() + + o.Port = -1 + o.MQTT.Port = -1 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // There is already the defer for shutdown at top of function + + // Recreate a client + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + + nc = natsConnect(t, s.ClientURL()) + defer nc.Close() + + natsPub(t, nc, "foo.bar", []byte("msg6")) + testMQTTCheckPubMsg(t, c, r, "foo/bar", 0, []byte("msg6")) + + natsPub(t, nc, "foo", []byte("msg7")) + testMQTTCheckPubMsg(t, c, r, "foo", 0, []byte("msg7")) + + // Make sure that we did not recover bar. + natsPub(t, nc, "bar", []byte("msg8")) + testMQTTExpectNothing(t, r) + + natsPub(t, nc, "baz", []byte("msg9")) + testMQTTCheckPubMsg(t, c, r, "baz", 0, []byte("msg9")) + + // Have the sub client send a subscription downgrading the qos1 subscription. + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo/#", qos: 0}}, []byte{0}) + testMQTTFlush(t, c, nil, r) + + nc.Close() + s.Shutdown() + c.Close() + + o.Port = -1 + o.MQTT.Port = -1 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // There is already the defer for shutdown at top of function + + // Recreate the sub client + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + + // Publish as a qos1 + c2, r2 := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer c2.Close() + testMQTTCheckConnAck(t, r2, mqttConnAckRCConnectionAccepted, false) + testMQTTPublish(t, c2, r2, 1, false, false, "foo/bar", 1, []byte("msg10")) + + // Verify that it is received as qos0 which is the qos of the subscription. + testMQTTCheckPubMsg(t, c, r, "foo/bar", 0, []byte("msg10")) + + testMQTTDisconnect(t, c, nil) + c.Close() + testMQTTDisconnect(t, c2, nil) + c2.Close() + + // Finally, recreate the sub with clean session and ensure that all is gone + cisub.cleanSess = true + for i := 0; i < 2; i++ { + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + nc = natsConnect(t, s.ClientURL()) + defer nc.Close() + + natsPub(t, nc, "foo.bar", []byte("msg11")) + testMQTTExpectNothing(t, r) + + natsPub(t, nc, "foo", []byte("msg12")) + testMQTTExpectNothing(t, r) + + // Make sure that we did not recover bar. + natsPub(t, nc, "bar", []byte("msg13")) + testMQTTExpectNothing(t, r) + + natsPub(t, nc, "baz", []byte("msg14")) + testMQTTExpectNothing(t, r) + + testMQTTDisconnect(t, c, nil) + c.Close() + nc.Close() + + s.Shutdown() + o.Port = -1 + o.MQTT.Port = -1 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // There is already the defer for shutdown at top of function + } +} + +func TestMQTTRecoverSessionAndAddNewSub(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer func() { + testMQTTShutdownServer(s) + }() + + cisub := &mqttConnInfo{clientID: "sub1", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTDisconnect(t, c, nil) + c.Close() + + // Shutdown server, close connection and restart server. It should + // have restored the session and consumers. + dir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir) + s.Shutdown() + c.Close() + + o.Port = -1 + o.MQTT.Port = -1 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // No need for defer since it is done top of function + + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + // Now add sub and make sure it does not crash + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTFlush(t, c, nil, r) + + // Now repeat with a new client but without server restart. + cisub2 := &mqttConnInfo{clientID: "sub2", cleanSess: false} + c2, r2 := testMQTTConnect(t, cisub2, o.MQTT.Host, o.MQTT.Port) + defer c2.Close() + testMQTTCheckConnAck(t, r2, mqttConnAckRCConnectionAccepted, false) + testMQTTDisconnect(t, c2, nil) + c2.Close() + + c2, r2 = testMQTTConnect(t, cisub2, o.MQTT.Host, o.MQTT.Port) + defer c2.Close() + testMQTTCheckConnAck(t, r2, mqttConnAckRCConnectionAccepted, true) + testMQTTSub(t, 1, c2, r2, []*mqttFilter{&mqttFilter{filter: "bar", qos: 1}}, []byte{1}) + testMQTTFlush(t, c2, nil, r2) +} + +func TestMQTTRecoverSessionWithSubAndClientResendSub(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer func() { + testMQTTShutdownServer(s) + }() + + cisub1 := &mqttConnInfo{clientID: "sub1", cleanSess: false} + c, r := testMQTTConnect(t, cisub1, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + // Have a client send a SUBSCRIBE protocol for foo, QoS1 + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTDisconnect(t, c, nil) + c.Close() + + // Restart the server now. + dir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir) + s.Shutdown() + + o.Port = -1 + o.MQTT.Port = -1 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // No need for defer since it is done top of function + + // Now restart the client. Since the client was created with cleanSess==false, + // the server will have recorded the subscriptions for this client. + c, r = testMQTTConnect(t, cisub1, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + // At this point, the server has recreated the subscription on foo, QoS1. + + // For applications that restart, it is possible (likely) that they + // will resend their SUBSCRIBE protocols, so do so now: + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTFlush(t, c, nil, r) + + checkNumSub := func(clientID string) { + t.Helper() + + // Find the MQTT client... + mc := testMQTTGetClient(t, s, clientID) + + // Check how many NATS subscriptions are registered. + var fooSub int + var otherSub int + mc.mu.Lock() + for _, sub := range mc.subs { + switch string(sub.subject) { + case "foo": + fooSub++ + default: + otherSub++ + } + } + mc.mu.Unlock() + + // We should have 2 subscriptions, one on "foo", and one for the JS durable + // consumer's delivery subject. + if fooSub != 1 { + t.Fatalf("Expected 1 sub on 'foo', got %v", fooSub) + } + if otherSub != 1 { + t.Fatalf("Expected 1 subscription for JS durable, got %v", otherSub) + } + } + checkNumSub("sub1") + + c.Close() + + // Now same but without the server restart in-between. + cisub2 := &mqttConnInfo{clientID: "sub2", cleanSess: false} + c, r = testMQTTConnect(t, cisub2, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTDisconnect(t, c, nil) + c.Close() + // Restart client + c, r = testMQTTConnect(t, cisub2, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTFlush(t, c, nil, r) + // Check client subs + checkNumSub("sub2") +} + +func TestMQTTPersistRetainedMsg(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer func() { + testMQTTShutdownServer(s) + }() + + dir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir) + + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + c, r := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, c, r, 1, false, true, "foo", 1, []byte("foo1")) + testMQTTPublish(t, c, r, 1, false, true, "foo", 1, []byte("foo2")) + testMQTTPublish(t, c, r, 1, false, true, "bar", 1, []byte("bar1")) + testMQTTPublish(t, c, r, 0, false, true, "baz", 1, []byte("baz1")) + // Remove bar + testMQTTPublish(t, c, r, 1, false, true, "bar", 1, nil) + testMQTTDisconnect(t, c, nil) + c.Close() + + s.Shutdown() + + o.Port = -1 + o.MQTT.Port = -1 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // There is already the defer for shutdown at top of function + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTCheckPubMsg(t, c, r, "foo", mqttPubFlagRetain|mqttPubQos1, []byte("foo2")) + + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "baz", qos: 1}}, []byte{1}) + testMQTTCheckPubMsg(t, c, r, "baz", mqttPubFlagRetain, []byte("baz1")) + + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "bar", qos: 1}}, []byte{1}) + testMQTTExpectNothing(t, r) + + testMQTTDisconnect(t, c, nil) + c.Close() +} + +func TestMQTTConnAckFirstProto(t *testing.T) { + o := testMQTTDefaultOptions() + o.NoLog, o.Debug, o.Trace = true, false, false + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 0}}, []byte{0}) + testMQTTDisconnect(t, c, nil) + c.Close() + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + wg := sync.WaitGroup{} + wg.Add(1) + ch := make(chan struct{}, 1) + ready := make(chan struct{}) + go func() { + defer wg.Done() + + close(ready) + for { + nc.Publish("foo", []byte("msg")) + select { + case <-ch: + return + default: + } + } + }() + + <-ready + for i := 0; i < 100; i++ { + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + c.Close() + } + close(ch) + wg.Wait() +} + +func TestMQTTRedeliveryAckWait(t *testing.T) { + o := testMQTTDefaultOptions() + o.MQTT.AckWait = 250 * time.Millisecond + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("foo1")) + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 2, []byte("foo2")) + testMQTTDisconnect(t, cp, nil) + cp.Close() + + for i := 0; i < 2; i++ { + flags := mqttPubQos1 + if i > 0 { + flags |= mqttPubFlagDup + } + pi1 := testMQTTCheckPubMsgNoAck(t, c, r, "foo", flags, []byte("foo1")) + pi2 := testMQTTCheckPubMsgNoAck(t, c, r, "foo", flags, []byte("foo2")) + + if pi1 != 1 || pi2 != 2 { + t.Fatalf("Unexpected pi values: %v, %v", pi1, pi2) + } + } + // Ack first message + testMQTTSendPubAck(t, c, 1) + // Redelivery should only be for second message now + for i := 0; i < 2; i++ { + flags := mqttPubQos1 | mqttPubFlagDup + pi := testMQTTCheckPubMsgNoAck(t, c, r, "foo", flags, []byte("foo2")) + if pi != 2 { + t.Fatalf("Unexpected pi to be 2, got %v", pi) + } + } + + // Restart client, should receive second message with pi==2 + c.Close() + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + // Check that message is received with proper pi + pi := testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1|mqttPubFlagDup, []byte("foo2")) + if pi != 2 { + t.Fatalf("Unexpected pi to be 2, got %v", pi) + } + // Now ack second message + testMQTTSendPubAck(t, c, 2) + // Flush to make sure it is processed before checking client's maps + testMQTTFlush(t, c, nil, r) + + // Look for the sub client + mc := testMQTTGetClient(t, s, "sub") + mc.mu.Lock() + sess := mc.mqtt.sess + sess.mu.Lock() + lpi := len(sess.pending) + var lsseq int + for _, sseqToPi := range sess.cpending { + lsseq += len(sseqToPi) + } + sess.mu.Unlock() + mc.mu.Unlock() + if lpi != 0 || lsseq != 0 { + t.Fatalf("Maps should be empty, got %v, %v", lpi, lsseq) + } +} + +func TestMQTTAckWaitConfigChange(t *testing.T) { + o := testMQTTDefaultOptions() + o.MQTT.AckWait = 250 * time.Millisecond + s := testMQTTRunServer(t, o) + defer func() { + testMQTTShutdownServer(s) + }() + + dir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir) + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + + sendMsg := func(topic, payload string) { + t.Helper() + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, topic, 1, []byte(payload)) + testMQTTDisconnect(t, cp, nil) + cp.Close() + } + sendMsg("foo", "msg1") + + for i := 0; i < 2; i++ { + flags := mqttPubQos1 + if i > 0 { + flags |= mqttPubFlagDup + } + testMQTTCheckPubMsgNoAck(t, c, r, "foo", flags, []byte("msg1")) + } + + // Restart the server with a different AckWait option value. + // Verify that MQTT sub restart succeeds. It will keep the + // original value. + c.Close() + s.Shutdown() + + o.Port = -1 + o.MQTT.Port = -1 + o.MQTT.AckWait = 10 * time.Millisecond + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // There is already the defer for shutdown at top of function + + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1|mqttPubFlagDup, []byte("msg1")) + start := time.Now() + testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1|mqttPubFlagDup, []byte("msg1")) + if dur := time.Since(start); dur < 200*time.Millisecond { + t.Fatalf("AckWait seem to have changed for existing subscription: %v", dur) + } + + // Create new subscription + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "bar", qos: 1}}, []byte{1}) + sendMsg("bar", "msg2") + testMQTTCheckPubMsgNoAck(t, c, r, "bar", mqttPubQos1, []byte("msg2")) + start = time.Now() + testMQTTCheckPubMsgNoAck(t, c, r, "bar", mqttPubQos1|mqttPubFlagDup, []byte("msg2")) + if dur := time.Since(start); dur > 50*time.Millisecond { + t.Fatalf("AckWait new value not used by new sub: %v", dur) + } + c.Close() +} + +func TestMQTTUnsubscribeWithPendingAcks(t *testing.T) { + o := testMQTTDefaultOptions() + o.MQTT.AckWait = 250 * time.Millisecond + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg")) + testMQTTDisconnect(t, cp, nil) + cp.Close() + + for i := 0; i < 2; i++ { + flags := mqttPubQos1 + if i > 0 { + flags |= mqttPubFlagDup + } + testMQTTCheckPubMsgNoAck(t, c, r, "foo", flags, []byte("msg")) + } + + testMQTTUnsub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo"}}) + testMQTTFlush(t, c, nil, r) + + mc := testMQTTGetClient(t, s, "sub") + mc.mu.Lock() + sess := mc.mqtt.sess + sess.mu.Lock() + pal := len(sess.pending) + sess.mu.Unlock() + mc.mu.Unlock() + if pal != 0 { + t.Fatalf("Expected pending ack map to be empty, got %v", pal) + } +} + +func TestMQTTMaxAckPending(t *testing.T) { + o := testMQTTDefaultOptions() + o.MQTT.MaxAckPending = 1 + s := testMQTTRunServer(t, o) + defer func() { + testMQTTShutdownServer(s) + }() + + dir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir) + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg1")) + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg2")) + + pi := testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg1")) + // Check that we don't receive the second one due to max ack pending + testMQTTExpectNothing(t, r) + + // Now ack first message + testMQTTSendPubAck(t, c, pi) + // Now we should receive message 2 + testMQTTCheckPubMsg(t, c, r, "foo", mqttPubQos1, []byte("msg2")) + testMQTTDisconnect(t, c, nil) + + // Send 2 messages while sub is offline + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg3")) + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg4")) + + // Restart consumer + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + + // Should receive only message 3 + pi = testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg3")) + testMQTTExpectNothing(t, r) + + // Ack and get the next + testMQTTSendPubAck(t, c, pi) + testMQTTCheckPubMsg(t, c, r, "foo", mqttPubQos1, []byte("msg4")) + + // Check that change to config does not prevent restart of sub. + cp.Close() + c.Close() + s.Shutdown() + + o.Port = -1 + o.MQTT.Port = -1 + o.MQTT.MaxAckPending = 2 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // There is already the defer for shutdown at top of function + + cp, rp = testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg5")) + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg6")) + + // Restart consumer + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + + // Should receive only message 5 + pi = testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg5")) + testMQTTExpectNothing(t, r) + + // Ack and get the next + testMQTTSendPubAck(t, c, pi) + testMQTTCheckPubMsg(t, c, r, "foo", mqttPubQos1, []byte("msg6")) +} + +func TestMQTTMaxAckPendingForMultipleSubs(t *testing.T) { + o := testMQTTDefaultOptions() + o.MQTT.MaxAckPending = 1 + s := testMQTTRunServer(t, o) + defer s.Shutdown() + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "bar", qos: 1}}, []byte{1}) + + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg1")) + pi := testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg1")) + + // Now send a second message but on topic bar + testMQTTPublish(t, cp, rp, 1, false, false, "bar", 1, []byte("msg2")) + + // JS allows us to limit per consumer, but we apply the limit to the + // session, so although JS will attempt to delivery this message, + // the MQTT code will suppress it. + testMQTTExpectNothing(t, r) + + // Ack the first message. + testMQTTSendPubAck(t, c, pi) + + // Now we should get the second message + testMQTTCheckPubMsg(t, c, r, "bar", mqttPubQos1|mqttPubFlagDup, []byte("msg2")) +} + +func TestMQTTConfigReload(t *testing.T) { + template := ` + jetstream: true + mqtt { + port: -1 + ack_wait: %s + max_ack_pending: %s + } + ` + conf := createConfFile(t, []byte(fmt.Sprintf(template, `"5s"`, `10000`))) + defer os.Remove(conf) + + s, o := RunServerWithConfig(conf) + defer s.Shutdown() + + if val := o.MQTT.AckWait; val != 5*time.Second { + t.Fatalf("Invalid ackwait: %v", val) + } + if val := o.MQTT.MaxAckPending; val != 10000 { + t.Fatalf("Invalid ackwait: %v", val) + } + + changeCurrentConfigContentWithNewContent(t, conf, []byte(fmt.Sprintf(template, `"250ms"`, `1`))) + if err := s.Reload(); err != nil { + t.Fatalf("Error on reload: %v", err) + } + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg1")) + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg2")) + + testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg1")) + start := time.Now() + testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1|mqttPubFlagDup, []byte("msg1")) + if dur := time.Since(start); dur > 500*time.Millisecond { + t.Fatalf("AckWait not applied? dur=%v", dur) + } + c.Close() + cp.Close() + s.Shutdown() + + changeCurrentConfigContentWithNewContent(t, conf, []byte(fmt.Sprintf(template, `"30s"`, `1`))) + s, o = RunServerWithConfig(conf) + defer s.Shutdown() + + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + + cipub = &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp = testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg1")) + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg2")) + + testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg1")) + testMQTTExpectNothing(t, r) + + // Increate the max ack pending + changeCurrentConfigContentWithNewContent(t, conf, []byte(fmt.Sprintf(template, `"30s"`, `10`))) + // Reload now + if err := s.Reload(); err != nil { + t.Fatalf("Error on reload: %v", err) + } + // See that message 2 can now be received (1 will be redelivered too) + testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1|mqttPubFlagDup, []byte("msg1")) + testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg2")) +} + +// Benchmarks + +const ( + mqttPubSubj = "a" + mqttBenchBufLen = 32768 +) + +func mqttBenchPubQoS0(b *testing.B, subject, payload string, numSubs int) { + b.StopTimer() + o := testMQTTDefaultOptions() + s := RunServer(o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{clientID: "pub", cleanSess: true} + c, br := testMQTTConnect(b, ci, o.MQTT.Host, o.MQTT.Port) + testMQTTCheckConnAck(b, br, mqttConnAckRCConnectionAccepted, false) + w := &mqttWriter{} + mqttWritePublish(w, 0, false, false, subject, 0, []byte(payload)) + sendOp := w.Bytes() + + dch := make(chan error, 1) + totalSize := int64(len(sendOp)) + cdch := 0 + + createSub := func(i int) { + ci := &mqttConnInfo{clientID: fmt.Sprintf("sub%d", i), cleanSess: true} + cs, brs := testMQTTConnect(b, ci, o.MQTT.Host, o.MQTT.Port) + testMQTTCheckConnAck(b, brs, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(b, 1, cs, brs, []*mqttFilter{&mqttFilter{filter: subject, qos: 0}}, []byte{0}) + testMQTTFlush(b, cs, nil, brs) + + w := &mqttWriter{} + varHeaderAndPayload := 2 + len(subject) + len(payload) + w.WriteVarInt(varHeaderAndPayload) + size := 1 + w.Len() + varHeaderAndPayload + totalSize += int64(size) + + go func() { + mqttBenchConsumeMsgQoS0(cs, int64(b.N)*int64(size), dch) + cs.Close() + }() + } + for i := 0; i < numSubs; i++ { + createSub(i + 1) + cdch++ + } + + bw := bufio.NewWriterSize(c, mqttBenchBufLen) + b.SetBytes(totalSize) + b.StartTimer() + for i := 0; i < b.N; i++ { + bw.Write(sendOp) + } + testMQTTFlush(b, c, bw, br) + for i := 0; i < cdch; i++ { + if e := <-dch; e != nil { + b.Fatal(e.Error()) + } + } + b.StopTimer() + c.Close() + s.Shutdown() +} + +func mqttBenchConsumeMsgQoS0(c net.Conn, total int64, dch chan<- error) { + var buf [mqttBenchBufLen]byte + var err error + var n int + for size := int64(0); size < total; { + n, err = c.Read(buf[:]) + if err != nil { + break + } + size += int64(n) + } + dch <- err +} + +func mqttBenchPubQoS1(b *testing.B, subject, payload string, numSubs int) { + b.StopTimer() + o := testMQTTDefaultOptions() + o.MQTT.MaxAckPending = 0xFFFF + s := RunServer(o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{cleanSess: true} + c, br := testMQTTConnect(b, ci, o.MQTT.Host, o.MQTT.Port) + testMQTTCheckConnAck(b, br, mqttConnAckRCConnectionAccepted, false) + + w := &mqttWriter{} + mqttWritePublish(w, 1, false, false, subject, 1, []byte(payload)) + // For reported bytes we will count the PUBLISH + PUBACK (4 bytes) + totalSize := int64(len(w.Bytes()) + 4) + w.Reset() + + pi := uint16(1) + maxpi := uint16(60000) + ppich := make(chan error, 10) + dch := make(chan error, 1+numSubs) + cdch := 1 + // Start go routine to consume PUBACK for published QoS 1 messages. + go mqttBenchConsumePubAck(c, b.N, dch, ppich) + + createSub := func(i int) { + ci := &mqttConnInfo{clientID: fmt.Sprintf("sub%d", i), cleanSess: true} + cs, brs := testMQTTConnect(b, ci, o.MQTT.Host, o.MQTT.Port) + testMQTTCheckConnAck(b, brs, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(b, 1, cs, brs, []*mqttFilter{&mqttFilter{filter: subject, qos: 1}}, []byte{1}) + testMQTTFlush(b, cs, nil, brs) + + w := &mqttWriter{} + varHeaderAndPayload := 2 + len(subject) + 2 + len(payload) + w.WriteVarInt(varHeaderAndPayload) + size := 1 + w.Len() + varHeaderAndPayload + // Add to the bytes reported the size of message sent to subscriber + PUBACK (4 bytes) + totalSize += int64(size + 4) + + go func() { + mqttBenchConsumeMsgQos1(cs, b.N, size, dch) + cs.Close() + }() + } + for i := 0; i < numSubs; i++ { + createSub(i + 1) + cdch++ + } + + flush := func() { + b.Helper() + if _, err := c.Write(w.Bytes()); err != nil { + b.Fatalf("Error on write: %v", err) + } + w.Reset() + } + + b.SetBytes(totalSize) + b.StartTimer() + for i := 0; i < b.N; i++ { + if pi <= maxpi { + mqttWritePublish(w, 1, false, false, subject, pi, []byte(payload)) + pi++ + if w.Len() >= mqttBenchBufLen { + flush() + } + } else { + if w.Len() > 0 { + flush() + } + if pi > 60000 { + pi = 1 + maxpi = 0 + } + if e := <-ppich; e != nil { + b.Fatal(e.Error()) + } + maxpi += 10000 + i-- + } + } + if w.Len() > 0 { + flush() + } + for i := 0; i < cdch; i++ { + if e := <-dch; e != nil { + b.Fatal(e.Error()) + } + } + b.StopTimer() + c.Close() + s.Shutdown() +} + +func mqttBenchConsumeMsgQos1(c net.Conn, total, size int, dch chan<- error) { + var buf [mqttBenchBufLen]byte + pubAck := [4]byte{mqttPacketPubAck, 0x2, 0, 0} + var err error + var n int + var pi uint16 + var prev int + for i := 0; i < total; { + n, err = c.Read(buf[:]) + if err != nil { + break + } + n += prev + for ; n >= size; n -= size { + i++ + pi++ + pubAck[2] = byte(pi >> 8) + pubAck[3] = byte(pi) + if _, err = c.Write(pubAck[:4]); err != nil { + dch <- err + return + } + if pi == 60000 { + pi = 0 + } + } + prev = n + } + dch <- err +} + +func mqttBenchConsumePubAck(c net.Conn, total int, dch, ppich chan<- error) { + var buf [mqttBenchBufLen]byte + var err error + var n int + var pi uint16 + var prev int + for i := 0; i < total; { + n, err = c.Read(buf[:]) + if err != nil { + break + } + n += prev + for ; n >= 4; n -= 4 { + i++ + pi++ + if pi%10000 == 0 { + ppich <- nil + } + if pi == 60001 { + pi = 0 + } + } + prev = n + } + ppich <- err + dch <- err +} + +func BenchmarkMQTT_QoS0_Pub_______0b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, "", 0) +} + +func BenchmarkMQTT_QoS0_Pub_______8b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(8), 0) +} + +func BenchmarkMQTT_QoS0_Pub______32b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(32), 0) +} + +func BenchmarkMQTT_QoS0_Pub_____128b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(128), 0) +} + +func BenchmarkMQTT_QoS0_Pub_____256b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(256), 0) +} + +func BenchmarkMQTT_QoS0_Pub_______1K_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(1024), 0) +} + +func BenchmarkMQTT_QoS0_PubSub1___0b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, "", 1) +} + +func BenchmarkMQTT_QoS0_PubSub1___8b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(8), 1) +} + +func BenchmarkMQTT_QoS0_PubSub1__32b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(32), 1) +} + +func BenchmarkMQTT_QoS0_PubSub1_128b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(128), 1) +} + +func BenchmarkMQTT_QoS0_PubSub1_256b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(256), 1) +} + +func BenchmarkMQTT_QoS0_PubSub1___1K_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(1024), 1) +} + +func BenchmarkMQTT_QoS0_PubSub2___0b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, "", 2) +} + +func BenchmarkMQTT_QoS0_PubSub2___8b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(8), 2) +} + +func BenchmarkMQTT_QoS0_PubSub2__32b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(32), 2) +} + +func BenchmarkMQTT_QoS0_PubSub2_128b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(128), 2) +} + +func BenchmarkMQTT_QoS0_PubSub2_256b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(256), 2) +} + +func BenchmarkMQTT_QoS0_PubSub2___1K_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(1024), 2) +} + +func BenchmarkMQTT_QoS1_Pub_______0b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, "", 0) +} + +func BenchmarkMQTT_QoS1_Pub_______8b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(8), 0) +} + +func BenchmarkMQTT_QoS1_Pub______32b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(32), 0) +} + +func BenchmarkMQTT_QoS1_Pub_____128b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(128), 0) +} + +func BenchmarkMQTT_QoS1_Pub_____256b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(256), 0) +} + +func BenchmarkMQTT_QoS1_Pub_______1K_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(1024), 0) +} + +func BenchmarkMQTT_QoS1_PubSub1___0b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, "", 1) +} + +func BenchmarkMQTT_QoS1_PubSub1___8b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(8), 1) +} + +func BenchmarkMQTT_QoS1_PubSub1__32b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(32), 1) +} + +func BenchmarkMQTT_QoS1_PubSub1_128b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(128), 1) +} + +func BenchmarkMQTT_QoS1_PubSub1_256b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(256), 1) +} + +func BenchmarkMQTT_QoS1_PubSub1___1K_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(1024), 1) +} + +func BenchmarkMQTT_QoS1_PubSub2___0b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, "", 2) +} + +func BenchmarkMQTT_QoS1_PubSub2___8b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(8), 2) +} + +func BenchmarkMQTT_QoS1_PubSub2__32b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(32), 2) +} + +func BenchmarkMQTT_QoS1_PubSub2_128b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(128), 2) +} + +func BenchmarkMQTT_QoS1_PubSub2_256b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(256), 2) +} + +func BenchmarkMQTT_QoS1_PubSub2___1K_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(1024), 2) +} diff --git a/server/opts.go b/server/opts.go index 32802133..3f44538a 100644 --- a/server/opts.go +++ b/server/opts.go @@ -197,6 +197,7 @@ type Options struct { JetStreamMaxStore int64 `json:"-"` StoreDir string `json:"-"` Websocket WebsocketOpts `json:"-"` + MQTT MQTTOpts `json:"-"` ProfPort int `json:"-"` PidFile string `json:"-"` PortsFileDir string `json:"-"` @@ -256,7 +257,7 @@ type Options struct { routeProto int } -// WebsocketOpts ... +// WebsocketOpts are options for websocket type WebsocketOpts struct { // The server will accept websocket client connections on this hostname/IP. Host string @@ -316,6 +317,49 @@ type WebsocketOpts struct { HandshakeTimeout time.Duration } +// MQTTOpts are options for MQTT +type MQTTOpts struct { + // The server will accept MQTT client connections on this hostname/IP. + Host string + // The server will accept MQTT client connections on this port. + Port int + + // If no user name is provided when a client connects, will default to the + // matching user from the global list of users in `Options.Users`. + NoAuthUser string + + // Authentication section. If anything is configured in this section, + // it will override the authorization configuration of regular clients. + Username string + Password string + Token string + + // Timeout for the authentication process. + AuthTimeout float64 + + // TLS configuration is required. + TLSConfig *tls.Config + // If true, map certificate values for authentication purposes. + TLSMap bool + // Timeout for the TLS handshake + TLSTimeout float64 + + // AckWait is the amount of time after which a QoS 1 message sent to + // a client is redelivered as a DUPLICATE if the server has not + // received the PUBACK on the original Packet Identifier. + // The value has to be positive. + // Zero will cause the server to use the default value (1 hour). + // Note that changes to this option is applied only to new MQTT subscriptions. + AckWait time.Duration + + // MaxAckPending is the amount of QoS 1 messages the server can send to + // a session without receiving any PUBACK for those messages. + // The valid range is [0..65535]. + // Zero will cause the server to use the default value (1024). + // Note that changes to this option is applied only to new MQTT sessions. + MaxAckPending uint16 +} + type netResolver interface { LookupHost(ctx context.Context, host string) ([]string, error) } @@ -1011,6 +1055,11 @@ func (o *Options) processConfigFileLine(k string, v interface{}, errors *[]error *errors = append(*errors, err) return } + case "mqtt": + if err := parseMQTT(tk, o, errors, warnings); err != nil { + *errors = append(*errors, err) + return + } default: if au := atomic.LoadInt32(&allowUnknownTopLevelField); au == 0 && !tk.IsUsedVariable() { err := &unknownConfigFieldErr{ @@ -3327,7 +3376,7 @@ func parseTLS(v interface{}, isClientCtx bool) (t *TLSConfigOpts, retErr error) return &tc, nil } -func parseAuthForWS(v interface{}, errors *[]error, warnings *[]error) *authorization { +func parseSimpleAuth(v interface{}, errors *[]error, warnings *[]error) *authorization { var ( am map[string]interface{} tk token @@ -3463,7 +3512,7 @@ func parseWebsocket(v interface{}, o *Options, errors *[]error, warnings *[]erro case "compression": o.Websocket.Compression = mv.(bool) case "authorization", "authentication": - auth := parseAuthForWS(tk, errors, warnings) + auth := parseSimpleAuth(tk, errors, warnings) o.Websocket.Username = auth.user o.Websocket.Password = auth.pass o.Websocket.Token = auth.token @@ -3488,6 +3537,79 @@ func parseWebsocket(v interface{}, o *Options, errors *[]error, warnings *[]erro return nil } +func parseMQTT(v interface{}, o *Options, errors *[]error, warnings *[]error) error { + var lt token + defer convertPanicToErrorList(<, errors) + + tk, v := unwrapValue(v, <) + gm, ok := v.(map[string]interface{}) + if !ok { + return &configErr{tk, fmt.Sprintf("Expected mqtt to be a map, got %T", v)} + } + for mk, mv := range gm { + // Again, unwrap token value if line check is required. + tk, mv = unwrapValue(mv, <) + switch strings.ToLower(mk) { + case "listen": + hp, err := parseListen(mv) + if err != nil { + err := &configErr{tk, err.Error()} + *errors = append(*errors, err) + continue + } + o.MQTT.Host = hp.host + o.MQTT.Port = hp.port + case "port": + o.MQTT.Port = int(mv.(int64)) + case "host", "net": + o.MQTT.Host = mv.(string) + case "tls": + tc, err := parseTLS(tk, true) + if err != nil { + *errors = append(*errors, err) + continue + } + if o.MQTT.TLSConfig, err = GenTLSConfig(tc); err != nil { + err := &configErr{tk, err.Error()} + *errors = append(*errors, err) + continue + } + o.MQTT.TLSTimeout = tc.Timeout + o.MQTT.TLSMap = tc.Map + case "authorization", "authentication": + auth := parseSimpleAuth(tk, errors, warnings) + o.MQTT.Username = auth.user + o.MQTT.Password = auth.pass + o.MQTT.Token = auth.token + o.MQTT.AuthTimeout = auth.timeout + case "no_auth_user": + o.MQTT.NoAuthUser = mv.(string) + case "ack_wait", "ackwait": + o.MQTT.AckWait = parseDuration("ack_wait", tk, mv, errors, warnings) + case "max_ack_pending", "max_pending", "max_inflight": + tmp := int(mv.(int64)) + if tmp < 0 || tmp > 0xFFFF { + err := &configErr{tk, fmt.Sprintf("invalid value %v, should in [0..%d] range", tmp, 0xFFFF)} + *errors = append(*errors, err) + } else { + o.MQTT.MaxAckPending = uint16(tmp) + } + default: + if !tk.IsUsedVariable() { + err := &unknownConfigFieldErr{ + field: mk, + configErr: configErr{ + token: tk, + }, + } + *errors = append(*errors, err) + continue + } + } + } + return nil +} + // GenTLSConfig loads TLS related configuration parameters. func GenTLSConfig(tc *TLSConfigOpts) (*tls.Config, error) { // Create the tls.Config from our options before including the certs. @@ -3831,6 +3953,14 @@ func setBaselineOptions(opts *Options) { opts.Websocket.Host = DEFAULT_HOST } } + if opts.MQTT.Port != 0 { + if opts.MQTT.Host == "" { + opts.MQTT.Host = DEFAULT_HOST + } + if opts.MQTT.TLSTimeout == 0 { + opts.MQTT.TLSTimeout = float64(TLS_TIMEOUT) / float64(time.Second) + } + } // JetStream if opts.JetStreamMaxMemory == 0 { opts.JetStreamMaxMemory = -1 diff --git a/server/parser.go b/server/parser.go index e9492b2f..f76337df 100644 --- a/server/parser.go +++ b/server/parser.go @@ -130,6 +130,11 @@ const ( ) func (c *client) parse(buf []byte) error { + // Branch out to mqtt clients. c.mqtt is immutable, but should it become + // an issue (say data race detection), we could branch outside in readLoop + if c.mqtt != nil { + return c.mqttParse(buf) + } var i int var b byte var lmsg bool diff --git a/server/reload.go b/server/reload.go index 14fa386e..790bb80a 100644 --- a/server/reload.go +++ b/server/reload.go @@ -596,6 +596,25 @@ func (m *maxTracedMsgLenOption) Apply(server *Server) { server.Noticef("Reloaded: max_traced_msg_len = %d", m.newValue) } +type mqttAckWaitReload struct { + noopOption + newValue time.Duration +} + +func (o *mqttAckWaitReload) Apply(s *Server) { + s.Noticef("Reloaded: MQTT ack_wait = %v", o.newValue) +} + +type mqttMaxAckPendingReload struct { + noopOption + newValue uint16 +} + +func (o *mqttMaxAckPendingReload) Apply(s *Server) { + s.mqttUpdateMaxAckPending(o.newValue) + s.Noticef("Reloaded: MQTT max_ack_pending = %v", o.newValue) +} + // Reload reads the current configuration file and applies any supported // changes. This returns an error if the server was not started with a config // file or an option which doesn't support hot-swapping was changed. @@ -633,6 +652,7 @@ func (s *Server) Reload() error { gatewayOrgPort := curOpts.Gateway.Port leafnodesOrgPort := curOpts.LeafNode.Port websocketOrgPort := curOpts.Websocket.Port + mqttOrgPort := curOpts.MQTT.Port s.mu.Unlock() @@ -665,6 +685,9 @@ func (s *Server) Reload() error { if newOpts.Websocket.Port == -1 { newOpts.Websocket.Port = websocketOrgPort } + if newOpts.MQTT.Port == -1 { + newOpts.MQTT.Port = mqttOrgPort + } if err := s.reloadOptions(curOpts, newOpts); err != nil { return err @@ -766,7 +789,7 @@ func imposeOrder(value interface{}) error { case WebsocketOpts: sort.Strings(value.AllowedOrigins) case string, bool, int, int32, int64, time.Duration, float64, nil, - LeafNodeOpts, ClusterOpts, *tls.Config, *URLAccResolver, *MemAccResolver, *DirAccResolver, *CacheDirAccResolver, Authentication: + LeafNodeOpts, ClusterOpts, *tls.Config, *URLAccResolver, *MemAccResolver, *DirAccResolver, *CacheDirAccResolver, Authentication, MQTTOpts: // explicitly skipped types default: // this will fail during unit tests @@ -973,6 +996,20 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) { return nil, fmt.Errorf("config reload not supported for %s: old=%v, new=%v", field.Name, oldValue, newValue) } + case "mqtt": + diffOpts = append(diffOpts, &mqttAckWaitReload{newValue: newValue.(MQTTOpts).AckWait}) + diffOpts = append(diffOpts, &mqttMaxAckPendingReload{newValue: newValue.(MQTTOpts).MaxAckPending}) + // Nil out/set to 0 the options that we allow to be reloaded so that + // we only fail reload if some that we don't support are changed. + tmpOld := oldValue.(MQTTOpts) + tmpNew := newValue.(MQTTOpts) + tmpOld.TLSConfig, tmpOld.AckWait, tmpOld.MaxAckPending = nil, 0, 0 + tmpNew.TLSConfig, tmpNew.AckWait, tmpNew.MaxAckPending = nil, 0, 0 + if !reflect.DeepEqual(tmpOld, tmpNew) { + // See TODO(ik) note below about printing old/new values. + return nil, fmt.Errorf("config reload not supported for %s: old=%v, new=%v", + field.Name, oldValue, newValue) + } case "connecterrorreports": diffOpts = append(diffOpts, &connectErrorReports{newValue: newValue.(int)}) case "reconnecterrorreports": @@ -1233,6 +1270,8 @@ func (s *Server) reloadAuthorization() { // can't hold the lock as go routine reading it may be waiting for lock as well resetCh = s.sys.resetCh } + // Check that publish retained messages sources are still allowed to publish. + s.mqttCheckPubRetainedPerms() s.mu.Unlock() if resetCh != nil { diff --git a/server/server.go b/server/server.go index 455457a7..59767929 100644 --- a/server/server.go +++ b/server/server.go @@ -222,6 +222,9 @@ type Server struct { // Websocket structure websocket srvWebsocket + // MQTT structure + mqtt srvMQTT + // exporting account name the importer experienced issues with incompleteAccExporterMap sync.Map @@ -533,6 +536,9 @@ func validateOptions(o *Options) error { if err := validateClusterName(o); err != nil { return err } + if err := validateMQTTOptions(o); err != nil { + return err + } // Finally check websocket options. return validateWebsocketOptions(o) } @@ -1516,6 +1522,11 @@ func (s *Server) Start() { s.startWebsocketServer() } + // MQTT + if opts.MQTT.Port != 0 { + s.startMQTT() + } + // Start up routing as well if needed. if opts.Cluster.Port != 0 { s.startGoRoutine(func() { @@ -1611,6 +1622,13 @@ func (s *Server) Shutdown() { s.websocket.listener = nil } + // Kick MQTT accept loop + if s.mqtt.listener != nil { + doneExpected++ + s.mqtt.listener.Close() + s.mqtt.listener = nil + } + // Kick leafnodes AcceptLoop() if s.leafNodeListener != nil { doneExpected++ @@ -1747,7 +1765,7 @@ func (s *Server) AcceptLoop(clr chan struct{}) { s.clientConnectURLs = s.getClientConnectURLs() s.listener = l - go s.acceptConnections(l, "Client", func(conn net.Conn) { s.createClient(conn, nil) }, + go s.acceptConnections(l, "Client", func(conn net.Conn) { s.createClient(conn, nil, nil) }, func(_ error) bool { if s.isLameDuckMode() { // Signal that we are not accepting new clients @@ -2073,7 +2091,7 @@ func (c *tlsMixConn) Read(b []byte) (int, error) { return c.Conn.Read(b) } -func (s *Server) createClient(conn net.Conn, ws *websocket) *client { +func (s *Server) createClient(conn net.Conn, ws *websocket, mqtt *mqtt) *client { // Snapshot server options. opts := s.getOpts() @@ -2086,32 +2104,49 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { now := time.Now() c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: maxPay, msubs: maxSubs, start: now, last: now, ws: ws} + if mqtt != nil { + c.mqtt = mqtt + // Set some of the options here since MQTT clients don't + // send a regular CONNECT (but have their own). + c.opts.Lang = "mqtt" + c.opts.Verbose = false + } c.registerWithAccount(s.globalAccount()) - // Grab JSON info string + var info Info + var authRequired bool + s.mu.Lock() - info := s.copyInfo() - // If this is a websocket client and there is no top-level auth specified, - // then we use the websocket's specific boolean that will be set to true - // if there is any auth{} configured in websocket{}. - if ws != nil && !info.AuthRequired { - info.AuthRequired = s.websocket.authOverride + // We don't need the INFO to mqtt clients. + if mqtt == nil { + // Grab JSON info string + info = s.copyInfo() + // If this is a websocket client and there is no top-level auth specified, + // then we use the websocket's specific boolean that will be set to true + // if there is any auth{} configured in websocket{}. + if ws != nil && !info.AuthRequired { + info.AuthRequired = s.websocket.authOverride + } + if s.nonceRequired() { + // Nonce handling + var raw [nonceLen]byte + nonce := raw[:] + s.generateNonce(nonce) + info.Nonce = string(nonce) + } + c.nonce = []byte(info.Nonce) + authRequired = info.AuthRequired + } else { + authRequired = s.info.AuthRequired || s.mqtt.authOverride } - if s.nonceRequired() { - // Nonce handling - var raw [nonceLen]byte - nonce := raw[:] - s.generateNonce(nonce) - info.Nonce = string(nonce) - } - c.nonce = []byte(info.Nonce) + s.totalClients++ s.mu.Unlock() // Grab lock c.mu.Lock() - if info.AuthRequired { + if authRequired { c.flags.set(expectConnect) } @@ -2120,10 +2155,13 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { c.Debugf("Client connection created") - // Send our information. - // Need to be sent in place since writeLoop cannot be started until - // TLS handshake is done (if applicable). - c.sendProtoNow(c.generateClientInfoJSON(info)) + // We don't send the INFO to mqtt clients. + if mqtt == nil { + // Send our information. + // Need to be sent in place since writeLoop cannot be started until + // TLS handshake is done (if applicable). + c.sendProtoNow(c.generateClientInfoJSON(info)) + } // Unlock to register c.mu.Unlock() @@ -2155,6 +2193,22 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { return nil } s.clients[c.cid] = c + + // May be overridden based on type of client. + TLSConfig := opts.TLSConfig + TLSTimeout := opts.TLSTimeout + + tlsRequired := info.TLSRequired + // Websocket clients do TLS in the websocket http server. + if ws != nil { + tlsRequired = false + } else if mqtt != nil { + tlsRequired = opts.MQTT.TLSConfig != nil + if tlsRequired { + TLSConfig = opts.MQTT.TLSConfig + TLSTimeout = opts.MQTT.TLSTimeout + } + } s.mu.Unlock() // Re-Grab lock @@ -2163,13 +2217,12 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { // Connection could have been closed while sending the INFO proto. isClosed := c.isClosed() - tlsRequired := ws == nil && info.TLSRequired var pre []byte // If we have both TLS and non-TLS allowed we need to see which // one the client wants. if !isClosed && opts.TLSConfig != nil && opts.AllowNonTLS { pre = make([]byte, 4) - c.nc.SetReadDeadline(time.Now().Add(secondsToDuration(opts.TLSTimeout))) + c.nc.SetReadDeadline(time.Now().Add(secondsToDuration(TLSTimeout))) n, _ := io.ReadFull(c.nc, pre[:]) c.nc.SetReadDeadline(time.Time{}) pre = pre[:n] @@ -2190,11 +2243,11 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { pre = nil } - c.nc = tls.Server(c.nc, opts.TLSConfig) + c.nc = tls.Server(c.nc, TLSConfig) conn := c.nc.(*tls.Conn) // Setup the timeout - ttl := secondsToDuration(opts.TLSTimeout) + ttl := secondsToDuration(TLSTimeout) time.AfterFunc(ttl, func() { tlsTimeout(c, conn) }) conn.SetReadDeadline(time.Now().Add(ttl)) @@ -2231,7 +2284,7 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { // Check for Auth. We schedule this timer after the TLS handshake to avoid // the race where the timer fires during the handshake and causes the // server to write bad data to the socket. See issue #432. - if info.AuthRequired { + if authRequired { timeout := opts.AuthTimeout // For websocket, possibly override only if set. We make sure that // opts.AuthTimeout is set to a default value if not configured, @@ -2239,6 +2292,8 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { // if user has explicitly set or not. if ws != nil && opts.Websocket.AuthTimeout != 0 { timeout = opts.Websocket.AuthTimeout + } else if mqtt != nil && opts.MQTT.AuthTimeout != 0 { + timeout = opts.MQTT.AuthTimeout } c.setAuthTimer(secondsToDuration(timeout)) } @@ -2421,7 +2476,11 @@ func (s *Server) removeClient(c *client) { if updateProtoInfoCount { s.cproto-- } + mqtt := c.mqtt != nil s.mu.Unlock() + if mqtt { + s.mqttHandleClosedClient(c) + } case ROUTER: s.removeRoute(c) case GATEWAY: diff --git a/server/server_test.go b/server/server_test.go index 1114315b..b546905f 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1245,6 +1245,8 @@ func TestServerShutdownDuringStart(t *testing.T) { o.Websocket.Port = -1 o.Websocket.HandshakeTimeout = 1 o.Websocket.NoTLS = true + o.MQTT.Host = "127.0.0.1" + o.MQTT.Port = -1 // We are going to test that if the server is shutdown // while Start() runs (in this case, before), we don't @@ -1286,6 +1288,9 @@ func TestServerShutdownDuringStart(t *testing.T) { if s.websocket.listener != nil { listeners = append(listeners, "websocket") } + if s.mqtt.listener != nil { + listeners = append(listeners, "mqtt") + } s.mu.Unlock() if len(listeners) > 0 { lst := "" diff --git a/server/stream.go b/server/stream.go index 729c1db4..d709767a 100644 --- a/server/stream.go +++ b/server/stream.go @@ -27,6 +27,7 @@ import ( "path/filepath" "reflect" "strconv" + "strings" "sync" "time" @@ -95,6 +96,7 @@ type Stream struct { ddarr []*ddentry ddindex int ddtmr *time.Timer + nosubj bool } // Headers for published messages. @@ -125,13 +127,20 @@ func (a *Account) AddStream(config *StreamConfig) (*Stream, error) { // AddStreamWithStore adds a stream for the given account with custome store config options. func (a *Account) AddStreamWithStore(config *StreamConfig, fsConfig *FileStoreConfig) (*Stream, error) { + if strings.HasPrefix(config.Name, mqttStreamNamePrefix) { + return nil, fmt.Errorf("prefix %q is reserved for MQTT, unable to create stream %q", mqttStreamNamePrefix, config.Name) + } + return a.addStreamWithStore(config, fsConfig, false) +} + +func (a *Account) addStreamWithStore(config *StreamConfig, fsConfig *FileStoreConfig, noSubjectsOK bool) (*Stream, error) { s, jsa, err := a.checkForJetStream() if err != nil { return nil, err } // Sensible defaults. - cfg, err := checkStreamCfg(config) + cfg, err := checkStreamCfg(config, noSubjectsOK) if err != nil { return nil, err } @@ -168,7 +177,7 @@ func (a *Account) AddStreamWithStore(config *StreamConfig, fsConfig *FileStoreCo // Setup the internal client. c := s.createInternalJetStreamClient() - mset := &Stream{jsa: jsa, config: cfg, client: c, consumers: make(map[string]*Consumer)} + mset := &Stream{jsa: jsa, config: cfg, client: c, consumers: make(map[string]*Consumer), nosubj: noSubjectsOK} jsa.streams[cfg.Name] = mset storeDir := path.Join(jsa.storeDir, streamsDir, cfg.Name) @@ -416,7 +425,7 @@ func (jsa *jsAccount) subjectsOverlap(subjects []string) bool { // Default duplicates window. const StreamDefaultDuplicatesWindow = 2 * time.Minute -func checkStreamCfg(config *StreamConfig) (StreamConfig, error) { +func checkStreamCfg(config *StreamConfig, noSubjectOk bool) (StreamConfig, error) { if config == nil { return StreamConfig{}, fmt.Errorf("stream configuration invalid") } @@ -466,7 +475,9 @@ func checkStreamCfg(config *StreamConfig) (StreamConfig, error) { } if len(cfg.Subjects) == 0 { - cfg.Subjects = append(cfg.Subjects, cfg.Name) + if !noSubjectOk { + cfg.Subjects = append(cfg.Subjects, cfg.Name) + } } else { // We can allow overlaps, but don't allow direct duplicates. dset := make(map[string]struct{}, len(cfg.Subjects)) @@ -519,7 +530,10 @@ func (mset *Stream) Delete() error { // Update will allow certain configuration properties of an existing stream to be updated. func (mset *Stream) Update(config *StreamConfig) error { - cfg, err := checkStreamCfg(config) + mset.mu.RLock() + nosubj := mset.nosubj + mset.mu.RUnlock() + cfg, err := checkStreamCfg(config, nosubj) if err != nil { return err } @@ -691,7 +705,7 @@ func (mset *Stream) subscribeInternal(subject string, cb msgHandler) (*subscript mset.sid++ // Now create the subscription - return c.processSub([]byte(subject), nil, []byte(strconv.Itoa(mset.sid)), cb, false) + return c.processSub(c.createSub([]byte(subject), nil, []byte(strconv.Itoa(mset.sid)), cb), false) } // Helper for unlocked stream. @@ -901,7 +915,7 @@ func getExpectedLastSeq(hdr []byte) uint64 { } // processInboundJetStreamMsg handles processing messages bound for a stream. -func (mset *Stream) processInboundJetStreamMsg(_ *subscription, pc *client, subject, reply string, msg []byte) { +func (mset *Stream) processInboundJetStreamMsg(sub *subscription, pc *client, subject, reply string, msg []byte) { mset.mu.Lock() store := mset.store c := mset.client diff --git a/server/sublist.go b/server/sublist.go index 1c860a83..fcb0c28d 100644 --- a/server/sublist.go +++ b/server/sublist.go @@ -1392,3 +1392,65 @@ func (s *Sublist) collectAllSubs(l *level, subs *[]*subscription) { s.collectAllSubs(l.fwc.next, subs) } } + +func (s *Sublist) ReverseMatch(subject string) *SublistResult { + tsa := [32]string{} + tokens := tsa[:0] + start := 0 + for i := 0; i < len(subject); i++ { + if subject[i] == btsep { + tokens = append(tokens, subject[start:i]) + start = i + 1 + } + } + tokens = append(tokens, subject[start:]) + + result := &SublistResult{} + + s.Lock() + reverseMatchLevel(s.root, tokens, nil, result) + // Check for empty result. + if len(result.psubs) == 0 && len(result.qsubs) == 0 { + result = emptyResult + } + s.Unlock() + + return result +} + +func reverseMatchLevel(l *level, toks []string, n *node, results *SublistResult) { + for i, t := range toks { + if l == nil { + return + } + if len(t) == 1 { + if t[0] == fwc { + getAllNodes(l, results) + return + } else if t[0] == pwc { + for _, n := range l.nodes { + reverseMatchLevel(n.next, toks[i+1:], n, results) + } + return + } + } + n = l.nodes[t] + if n == nil { + break + } + l = n.next + } + if n != nil { + addNodeToResults(n, results) + } +} + +func getAllNodes(l *level, results *SublistResult) { + if l == nil { + return + } + for _, n := range l.nodes { + addNodeToResults(n, results) + getAllNodes(n.next, results) + } +} diff --git a/server/sublist_test.go b/server/sublist_test.go index 43a2ae84..eb35e1ea 100644 --- a/server/sublist_test.go +++ b/server/sublist_test.go @@ -1258,6 +1258,74 @@ func TestSublistRegisterInterestNotification(t *testing.T) { expectFalse() } +func TestSublistReverseMatch(t *testing.T) { + s := NewSublistWithCache() + fooSub := newSub("foo") + barSub := newSub("bar") + fooBarSub := newSub("foo.bar") + fooBazSub := newSub("foo.baz") + fooBarBazSub := newSub("foo.bar.baz") + s.Insert(fooSub) + s.Insert(barSub) + s.Insert(fooBarSub) + s.Insert(fooBazSub) + s.Insert(fooBarBazSub) + + r := s.ReverseMatch("foo") + verifyLen(r.psubs, 1, t) + verifyMember(r.psubs, fooSub, t) + + r = s.ReverseMatch("bar") + verifyLen(r.psubs, 1, t) + verifyMember(r.psubs, barSub, t) + + r = s.ReverseMatch("*") + verifyLen(r.psubs, 2, t) + verifyMember(r.psubs, fooSub, t) + verifyMember(r.psubs, barSub, t) + + r = s.ReverseMatch("baz") + verifyLen(r.psubs, 0, t) + + r = s.ReverseMatch("foo.*") + verifyLen(r.psubs, 2, t) + verifyMember(r.psubs, fooBarSub, t) + verifyMember(r.psubs, fooBazSub, t) + + r = s.ReverseMatch("*.*") + verifyLen(r.psubs, 2, t) + verifyMember(r.psubs, fooBarSub, t) + verifyMember(r.psubs, fooBazSub, t) + + r = s.ReverseMatch("*.bar") + verifyLen(r.psubs, 1, t) + verifyMember(r.psubs, fooBarSub, t) + + r = s.ReverseMatch("*.baz") + verifyLen(r.psubs, 1, t) + verifyMember(r.psubs, fooBazSub, t) + + r = s.ReverseMatch("bar.*") + verifyLen(r.psubs, 0, t) + + r = s.ReverseMatch("*.bat") + verifyLen(r.psubs, 0, t) + + r = s.ReverseMatch("foo.>") + verifyLen(r.psubs, 3, t) + verifyMember(r.psubs, fooBarSub, t) + verifyMember(r.psubs, fooBazSub, t) + verifyMember(r.psubs, fooBarBazSub, t) + + r = s.ReverseMatch(">") + verifyLen(r.psubs, 5, t) + verifyMember(r.psubs, fooSub, t) + verifyMember(r.psubs, barSub, t) + verifyMember(r.psubs, fooBarSub, t) + verifyMember(r.psubs, fooBazSub, t) + verifyMember(r.psubs, fooBarBazSub, t) +} + // -- Benchmarks Setup -- var benchSublistSubs []*subscription diff --git a/server/websocket.go b/server/websocket.go index 7b10877d..10defdc4 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -856,6 +856,7 @@ func (s *Server) startWebsocketServer() { if port == 0 { s.opts.Websocket.Port = hl.Addr().(*net.TCPAddr).Port } + s.Noticef("Listening for websocket clients on %s://%s:%d", proto, o.Host, port) s.websocket.connectURLs, err = s.getConnectURLs(o.Advertise, o.Host, o.Port) if err != nil { s.Fatalf("Unable to get websocket connect URLs: %v", err) @@ -870,7 +871,7 @@ func (s *Server) startWebsocketServer() { s.Errorf(err.Error()) return } - s.createClient(res.conn, res.ws) + s.createClient(res.conn, res.ws, nil) }) hs := &http.Server{ Addr: hp, diff --git a/server/websocket_test.go b/server/websocket_test.go index fa5e235b..e0b075fb 100644 --- a/server/websocket_test.go +++ b/server/websocket_test.go @@ -2095,6 +2095,10 @@ func TestWSTLSVerifyAndMap(t *testing.T) { {"no filtering, client does not provide cert", false, false}, {"filtering, client provides cert", true, true}, {"filtering, client does not provide cert", true, false}, + {"no users override, client provides cert", false, true}, + {"no users override, client does not provide cert", false, false}, + {"users override, client provides cert", true, true}, + {"users override, client does not provide cert", true, false}, } { t.Run(test.name, func(t *testing.T) { o := testWSOptions()