From 1dba6418ed73d8b65e6d021966beee525febcfd4 Mon Sep 17 00:00:00 2001 From: Ivan Kozlovic Date: Thu, 28 May 2020 18:12:54 -0600 Subject: [PATCH 01/11] [ADDED] MQTT Support This PR introduces native support for MQTT clients. It requires use of accounts with JetStream enabled. Since as of now clustering is not available, MQTT will be limited to single instance. Only QoS 0 and 1 are supported at the moment. MQTT clients can exchange messages with NATS clients and vice-versa. Since JetStream is required, accounts with JetStream enabled must exist in order for an MQTT client to connect to the NATS Server. The administrator can limit the users that can use MQTT with the allowed_connection_types option in the user section. For instance: ``` accounts { mqtt { users [ {user: all, password: pwd, allowed_connection_types: ["STANDARD", "WEBSOCKET", "MQTT"]} {user: mqtt_only, password: pwd, allowed_connection_types: "MQTT"} ] jetstream: enabled } } ``` The "mqtt_only" can only be used for MQTT connections, which the user "all" accepts standard, websocket and MQTT clients. Here is what a configuration to enable MQTT looks like: ``` mqtt { # Specify a host and port to listen for websocket connections # # listen: "host:port" # It can also be configured with individual parameters, # namely host and port. # # host: "hostname" port: 1883 # TLS configuration section # # tls { # cert_file: "/path/to/cert.pem" # key_file: "/path/to/key.pem" # ca_file: "/path/to/ca.pem" # # # Time allowed for the TLS handshake to complete # timeout: 2.0 # # # Takes the user name from the certificate # # # # verify_an_map: true #} # Authentication override. Here are possible options. # # authorization { # # Simple username/password # # # user: "some_user_name" # password: "some_password" # # # Token. The server will check the MQTT's password in the connect # # protocol against this token. # # # # token: "some_token" # # # Time allowed for the client to send the MQTT connect protocol # # after the TCP connection is established. # # # timeout: 2.0 #} # If an MQTT client connects and does not provide a username/password and # this option is set, the server will use this client (and therefore account). # # no_auth_user: "some_user_name" # This is the time after which the server will redeliver a QoS 1 message # sent to a subscription that has not acknowledged (PUBACK) the message. # The default is 30 seconds. # # ack_wait: "1m" # This limits the number of QoS1 messages sent to a session without receiving # acknowledgement (PUBACK) from that session. MQTT specification defines # a packet identifier as an unsigned int 16, which means that the maximum # value is 65535. The default value is 1024. # # max_ack_pending: 100 } ``` Signed-off-by: Ivan Kozlovic --- server/accounts.go | 6 +- server/auth.go | 33 +- server/client.go | 67 +- server/client_test.go | 19 +- server/config_check_test.go | 78 +- server/consumer.go | 26 +- server/events.go | 3 +- server/jetstream.go | 13 +- server/monitor.go | 2 + server/mqtt.go | 2515 +++++++++++++++++++++ server/mqtt_test.go | 4111 +++++++++++++++++++++++++++++++++++ server/opts.go | 136 +- server/parser.go | 5 + server/reload.go | 41 +- server/server.go | 113 +- server/server_test.go | 5 + server/stream.go | 28 +- server/sublist.go | 62 + server/sublist_test.go | 68 + server/websocket.go | 3 +- server/websocket_test.go | 4 + 21 files changed, 7242 insertions(+), 96 deletions(-) create mode 100644 server/mqtt.go create mode 100644 server/mqtt_test.go diff --git a/server/accounts.go b/server/accounts.go index 9a68c032..02d32028 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -1726,7 +1726,7 @@ func (a *Account) subscribeInternal(subject string, cb msgHandler) (*subscriptio return nil, fmt.Errorf("no internal account client") } - return c.processSub([]byte(subject), nil, []byte(sid), cb, false) + return c.processSub(c.createSub([]byte(subject), nil, []byte(sid), cb), false) } // This will add an account subscription that matches the "from" from a service import entry. @@ -1751,7 +1751,7 @@ func (a *Account) addServiceImportSub(si *serviceImport) error { cb := func(sub *subscription, c *client, subject, reply string, msg []byte) { c.processServiceImport(si, a, msg) } - _, err := c.processSub([]byte(subject), nil, []byte(sid), cb, true) + _, err := c.processSub(c.createSub([]byte(subject), nil, []byte(sid), cb), true) return err } @@ -1951,7 +1951,7 @@ func (a *Account) createRespWildcard() []byte { a.mu.Unlock() // Create subscription and internal callback for all the wildcard response subjects. - c.processSub(wcsub, nil, []byte(sid), a.processServiceImportResponse, false) + c.processSub(c.createSub(wcsub, nil, []byte(sid), a.processServiceImportResponse), false) return pre } diff --git a/server/auth.go b/server/auth.go index 7043b2cc..da27dc7e 100644 --- a/server/auth.go +++ b/server/auth.go @@ -253,6 +253,8 @@ func (s *Server) configureAuthorization() { // Do similar for websocket config s.wsConfigAuth(&opts.Websocket) + // And for mqtt config + s.mqttConfigAuth(&opts.MQTT) } // Takes the given slices of NkeyUser and User options and build @@ -343,11 +345,15 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) boo ) s.mu.Lock() authRequired := s.info.AuthRequired - // c.ws is immutable, but may need lock if we get race reports. - if !authRequired && c.ws != nil { - // If no auth required for regular clients, then check if - // we have an override for websocket clients. - authRequired = s.websocket.authOverride + // c.ws/mqtt is immutable, but may need lock if we get race reports. + if !authRequired { + if c.mqtt != nil { + authRequired = s.mqtt.authOverride + } else if c.ws != nil { + // If no auth required for regular clients, then check if + // we have an override for websocket clients. + authRequired = s.websocket.authOverride + } } if !authRequired { // TODO(dlc) - If they send us credentials should we fail? @@ -361,7 +367,20 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) boo noAuthUser string ) tlsMap := opts.TLSMap - if c.ws != nil { + if c.mqtt != nil { + mo := &opts.MQTT + // Always override TLSMap. + tlsMap = mo.TLSMap + // The rest depends on if there was any auth override in + // the mqtt's config. + if s.mqtt.authOverride { + noAuthUser = mo.NoAuthUser + username = mo.Username + password = mo.Password + token = mo.Token + ao = true + } + } else if c.ws != nil { wo := &opts.Websocket // Always override TLSMap. tlsMap = wo.TLSMap @@ -998,7 +1017,7 @@ func validateAllowedConnectionTypes(m map[string]struct{}) error { for ct := range m { ctuc := strings.ToUpper(ct) switch ctuc { - case jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket, jwt.ConnectionTypeLeafnode: + case jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket, jwt.ConnectionTypeLeafnode, jwt.ConnectionTypeMqtt: default: return fmt.Errorf("unknown connection type %q", ct) } diff --git a/server/client.go b/server/client.go index 6b9bb532..3a68b9d2 100644 --- a/server/client.go +++ b/server/client.go @@ -178,6 +178,7 @@ const ( NoRespondersRequiresHeaders ClusterNameConflict DuplicateRemoteLeafnodeConnection + DuplicateClientID ) // Some flags passed to processMsgResults @@ -237,6 +238,7 @@ type client struct { gw *gateway leaf *leaf ws *websocket + mqtt *mqtt // To keep track of gateway replies mapping gwrm map[string]*gwReplyMap @@ -442,6 +444,7 @@ type subscription struct { max int64 qw int32 closed int32 + mqtt *mqttSub } // Indicate that this subscription is closed. @@ -540,6 +543,8 @@ func (c *client) initClient() { name := "cid" if c.ws != nil { name = "wid" + } else if c.mqtt != nil { + name = "mid" } c.ncs.Store(fmt.Sprintf("%s - %s:%d", conn, name, c.cid)) case ROUTER: @@ -964,6 +969,9 @@ func (c *client) readLoop(pre []byte) { } nc := c.nc ws := c.ws != nil + if c.mqtt != nil { + c.mqtt.r = &mqttReader{reader: nc} + } c.in.rsz = startBufSize // Snapshot max control line since currently can not be changed on reload and we // were checking it on each call to parse. If this changes and we allow MaxControlLine @@ -983,6 +991,9 @@ func (c *client) readLoop(pre []byte) { c.mu.Unlock() defer func() { + if c.mqtt != nil { + s.mqttHandleWill(c) + } // These are used only in the readloop, so we can set them to nil // on exit of the readLoop. c.in.results, c.in.pacache = nil, nil @@ -1683,7 +1694,6 @@ func (c *client) processConnect(arg []byte) error { // By default register with the global account. c.registerWithAccount(srv.globalAccount()) } - } switch kind { @@ -1703,7 +1713,6 @@ func (c *client) processConnect(arg []byte) error { c.sendErr(ErrNoRespondersRequiresHeaders.Error()) c.closeConnection(NoRespondersRequiresHeaders) return ErrNoRespondersRequiresHeaders - } if verbose { c.sendOK() @@ -1933,6 +1942,9 @@ func (c *client) sendRTTPing() bool { // the c.rtt is 0 and wants to force an update by sending a PING. // Client lock held on entry. func (c *client) sendRTTPingLocked() bool { + if c.mqtt != nil { + return false + } // Most client libs send a CONNECT+PING and wait for a PONG from the // server. So if firstPongSent flag is set, it is ok for server to // send the PING. But in case we have client libs that don't do that, @@ -1978,7 +1990,9 @@ func (c *client) sendErr(err string) { if c.trace { c.traceOutOp("-ERR", []byte(err)) } - c.enqueueProto([]byte(fmt.Sprintf(errProto, err))) + if c.mqtt == nil { + c.enqueueProto([]byte(fmt.Sprintf(errProto, err))) + } c.mu.Unlock() } @@ -2212,32 +2226,30 @@ func (c *client) parseSub(argo []byte, noForward bool) error { arg := make([]byte, len(argo)) copy(arg, argo) args := splitArg(arg) - var ( - subject []byte - queue []byte - sid []byte - ) + sub := &subscription{client: c} switch len(args) { case 2: - subject = args[0] - queue = nil - sid = args[1] + sub.subject = args[0] + sub.queue = nil + sub.sid = args[1] case 3: - subject = args[0] - queue = args[1] - sid = args[2] + sub.subject = args[0] + sub.queue = args[1] + sub.sid = args[2] default: return fmt.Errorf("processSub Parse Error: '%s'", arg) } // If there was an error, it has been sent to the client. We don't return an // error here to not close the connection as a parsing error. - c.processSub(subject, queue, sid, nil, noForward) + c.processSub(sub, noForward) return nil } -func (c *client) processSub(subject, queue, bsid []byte, cb msgHandler, noForward bool) (*subscription, error) { - // Create the subscription - sub := &subscription{client: c, subject: subject, queue: queue, sid: bsid, icb: cb} +func (c *client) createSub(subject, queue, sid []byte, cb msgHandler) *subscription { + return &subscription{client: c, subject: subject, queue: queue, sid: sid, icb: cb} +} + +func (c *client) processSub(sub *subscription, noForward bool) (*subscription, error) { c.mu.Lock() @@ -2254,7 +2266,7 @@ func (c *client) processSub(subject, queue, bsid []byte, cb msgHandler, noForwar // This check does not apply to SYSTEM or JETSTREAM or ACCOUNT clients (because they don't have a `nc`...) if c.isClosed() && (kind != SYSTEM && kind != JETSTREAM && kind != ACCOUNT) { c.mu.Unlock() - return sub, nil + return nil, nil } // Check permissions if applicable. @@ -2301,6 +2313,9 @@ func (c *client) processSub(subject, queue, bsid []byte, cb msgHandler, noForwar updateGWs = c.srv.gateway.enabled } } + } else if es.mqtt != nil && sub.mqtt != nil { + es.mqtt.prm = sub.mqtt.prm + es.mqtt.qos = sub.mqtt.qos } // Unlocked from here onward c.mu.Unlock() @@ -3311,6 +3326,11 @@ func (c *client) processInboundClientMsg(msg []byte) bool { c.sendOK() } + // If MQTT client, check for retain flag now that we have passed permissions check + if c.mqtt != nil { + c.mqttHandlePubRetain() + } + // Check if this client's gateway replies map is not empty if atomic.LoadInt32(&c.cgwrt) > 0 && c.handleGWReplyMap(msg) { return true @@ -3983,7 +4003,7 @@ func (c *client) processPingTimer() { c.mu.Lock() c.ping.tmr = nil // Check if connection is still opened - if c.isClosed() { + if c.isClosed() || c.mqtt != nil { c.mu.Unlock() return } @@ -4042,7 +4062,7 @@ func adjustPingIntervalForGateway(d time.Duration) time.Duration { // Lock should be held func (c *client) setPingTimer() { - if c.srv == nil { + if c.srv == nil || c.mqtt != nil { return } d := c.srv.getOpts().PingInterval @@ -4596,7 +4616,7 @@ func convertAllowedConnectionTypes(cts []string) (map[string]struct{}, error) { for _, i := range cts { i = strings.ToUpper(i) switch i { - case jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket, jwt.ConnectionTypeLeafnode: + case jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket, jwt.ConnectionTypeLeafnode, jwt.ConnectionTypeMqtt: m[i] = struct{}{} default: unknown = append(unknown, i) @@ -4627,6 +4647,9 @@ func (c *client) connectionTypeAllowed(acts map[string]struct{}) bool { if c.ws != nil { want = jwt.ConnectionTypeWebsocket } + if c.mqtt != nil { + want = jwt.ConnectionTypeMqtt + } _, ok := acts[want] return ok } diff --git a/server/client_test.go b/server/client_test.go index f207854e..68e48dcc 100644 --- a/server/client_test.go +++ b/server/client_test.go @@ -85,7 +85,7 @@ func createClientAsync(ch chan *client, s *Server, cli net.Conn) { s.grWG.Add(1) } go func() { - c := s.createClient(cli, nil) + c := s.createClient(cli, nil, nil) // Must be here to suppress +OK c.opts.Verbose = false if startWriteLoop { @@ -2317,7 +2317,7 @@ func TestCloseConnectionVeryEarly(t *testing.T) { // Call again with this closed connection. Alternatively, we // would have to call with a fake connection that implements // net.Conn but returns an error on Write. - s.createClient(c, nil) + s.createClient(c, nil, nil) // This connection should not have been added to the server. checkClientsCount(t, s, 0) @@ -2354,18 +2354,23 @@ func TestClientConnectionName(t *testing.T) { kind int kindStr string ws bool + mqtt bool }{ - {"client", CLIENT, "cid:", false}, - {"ws client", CLIENT, "wid:", true}, - {"route", ROUTER, "rid:", false}, - {"gateway", GATEWAY, "gid:", false}, - {"leafnode", LEAF, "lid:", false}, + {"client", CLIENT, "cid:", false, false}, + {"ws client", CLIENT, "wid:", true, false}, + {"mqtt client", CLIENT, "mid:", false, true}, + {"route", ROUTER, "rid:", false, false}, + {"gateway", GATEWAY, "gid:", false, false}, + {"leafnode", LEAF, "lid:", false, false}, } { t.Run(test.name, func(t *testing.T) { c := &client{srv: s, nc: &connString{}, kind: test.kind} if test.ws { c.ws = &websocket{} } + if test.mqtt { + c.mqtt = &mqtt{} + } c.initClient() if host := "fe80::abc:def:ghi:123%utun0"; host != c.host { diff --git a/server/config_check_test.go b/server/config_check_test.go index 166aa761..45d726e9 100644 --- a/server/config_check_test.go +++ b/server/config_check_test.go @@ -1301,8 +1301,8 @@ func TestConfigCheck(t *testing.T) { user: user1 password: pwd users = [{user: user2, password: pwd}] - } - } + } + } `, err: errors.New("can not have a single user/pass and a users array"), errorLine: 3, @@ -1312,17 +1312,75 @@ func TestConfigCheck(t *testing.T) { name: "duplicate usernames in leafnode authorization", config: ` leafnodes { - authorization { - users = [ - {user: user, password: pwd} - {user: user, password: pwd} - ] - } - } + authorization { + users = [ + {user: user, password: pwd} + {user: user, password: pwd} + ] + } + } `, err: errors.New(`duplicate user "user" detected in leafnode authorization`), errorLine: 3, - errorPos: 20, + errorPos: 21, + }, + { + name: "mqtt bad type", + config: ` + mqtt [ + "wrong" + ] + `, + err: errors.New(`Expected mqtt to be a map, got []interface {}`), + errorLine: 2, + errorPos: 17, + }, + { + name: "mqtt bad listen", + config: ` + mqtt { + listen: "xxxxxxxx" + } + `, + err: errors.New(`could not parse address string "xxxxxxxx"`), + errorLine: 3, + errorPos: 21, + }, + { + name: "mqtt bad host", + config: ` + mqtt { + host: 1234 + } + `, + err: errors.New(`interface conversion: interface {} is int64, not string`), + errorLine: 3, + errorPos: 21, + }, + { + name: "mqtt bad port", + config: ` + mqtt { + port: "abc" + } + `, + err: errors.New(`interface conversion: interface {} is string, not int64`), + errorLine: 3, + errorPos: 21, + }, + { + name: "mqtt bad TLS", + config: ` + mqtt { + port: -1 + tls { + cert_file: "./configs/certs/server.pem" + } + } + `, + err: errors.New(`missing 'key_file' in TLS configuration`), + errorLine: 4, + errorPos: 21, }, { name: "connection types wrong type", diff --git a/server/consumer.go b/server/consumer.go index edd618a4..ef0e71a4 100644 --- a/server/consumer.go +++ b/server/consumer.go @@ -213,6 +213,13 @@ const ( ) func (mset *Stream) AddConsumer(config *ConsumerConfig) (*Consumer, error) { + if name := mset.Name(); strings.HasPrefix(name, mqttStreamNamePrefix) { + return nil, fmt.Errorf("stream prefix %q is reserved for MQTT, unable to create consumer on %q", mqttStreamNamePrefix, name) + } + return mset.addConsumerCheckInterest(config, true) +} + +func (mset *Stream) addConsumerCheckInterest(config *ConsumerConfig, checkInterest bool) (*Consumer, error) { if config == nil { return nil, fmt.Errorf("consumer config required") } @@ -350,7 +357,7 @@ func (mset *Stream) AddConsumer(config *ConsumerConfig) (*Consumer, error) { } else { // If we are a push mode and not active and the only difference // is deliver subject then update and return. - if configsEqualSansDelivery(ocfg, *config) && eo.hasNoLocalInterest() { + if configsEqualSansDelivery(ocfg, *config) && (!checkInterest || eo.hasNoLocalInterest()) { eo.updateDeliverSubject(config.DeliverSubject) return eo, nil } else { @@ -2177,9 +2184,22 @@ func (mset *Stream) deliveryFormsCycle(deliverySubject string) bool { return false } -// This is same as check for delivery cycle. +// Check that the subject is a subset of the stream's configured subjects, +// or returns true if the stream has been created with no subject. func (mset *Stream) validSubject(partitionSubject string) bool { - return mset.deliveryFormsCycle(partitionSubject) + mset.mu.RLock() + defer mset.mu.RUnlock() + + if mset.nosubj && len(mset.config.Subjects) == 0 { + return true + } + + for _, subject := range mset.config.Subjects { + if subjectIsSubsetMatch(partitionSubject, subject) { + return true + } + } + return false } // SetInActiveDeleteThreshold sets the delete threshold for how long to wait diff --git a/server/events.go b/server/events.go index 449198cd..3e9c84dc 100644 --- a/server/events.go +++ b/server/events.go @@ -1410,7 +1410,8 @@ func (s *Server) systemSubscribe(subject, queue string, internalOnly bool, cb ms if queue != "" { q = []byte(queue) } - return c.processSub([]byte(subject), q, []byte(sid), cb, internalOnly) + // Now create the subscription + return c.processSub(c.createSub([]byte(subject), q, []byte(sid), cb), internalOnly) } func (s *Server) sysUnsubscribe(sub *subscription) { diff --git a/server/jetstream.go b/server/jetstream.go index 9a591471..0c651a01 100644 --- a/server/jetstream.go +++ b/server/jetstream.go @@ -533,7 +533,12 @@ func (a *Account) EnableJetStream(limits *JetStreamAccountLimits) error { s.Warnf(" Error adding Stream %q to Template %q: %v", cfg.Name, cfg.Template, err) } } - mset, err := a.AddStream(&cfg.StreamConfig) + // TODO: We should not rely on the stream name. + // However, having a StreamConfig property, such as AllowNoSubject, + // was not accepted because it does not make sense outside of the + // MQTT use-case. So need to revisit this. + mqtt := cfg.StreamConfig.Name == mqttStreamName + mset, err := a.addStreamWithStore(&cfg.StreamConfig, nil, mqtt) if err != nil { s.Warnf(" Error recreating Stream %q: %v", cfg.Name, err) continue @@ -578,7 +583,7 @@ func (a *Account) EnableJetStream(limits *JetStreamAccountLimits) error { // the consumer can reconnect. We will create it as a durable and switch it. cfg.ConsumerConfig.Durable = ofi.Name() } - obs, err := mset.AddConsumer(&cfg.ConsumerConfig) + obs, err := mset.addConsumerCheckInterest(&cfg.ConsumerConfig, !mqtt) if err != nil { s.Warnf(" Error adding Consumer: %v", err) continue @@ -1060,7 +1065,7 @@ func (a *Account) AddStreamTemplate(tc *StreamTemplateConfig) (*StreamTemplate, // FIXME(dlc) - Hacky tcopy := tc.deepCopy() tcopy.Config.Name = "_" - cfg, err := checkStreamCfg(tcopy.Config) + cfg, err := checkStreamCfg(tcopy.Config, false) if err != nil { return nil, err } @@ -1115,7 +1120,7 @@ func (t *StreamTemplate) createTemplateSubscriptions() error { sid := 1 for _, subject := range t.Config.Subjects { // Now create the subscription - if _, err := c.processSub([]byte(subject), nil, []byte(strconv.Itoa(sid)), t.processInboundTemplateMsg, false); err != nil { + if _, err := c.processSub(c.createSub([]byte(subject), nil, []byte(strconv.Itoa(sid)), t.processInboundTemplateMsg), false); err != nil { c.acc.DeleteStreamTemplate(t.Name) return err } diff --git a/server/monitor.go b/server/monitor.go index 4bae0fda..6c941f2a 100644 --- a/server/monitor.go +++ b/server/monitor.go @@ -1915,6 +1915,8 @@ func (reason ClosedState) String() string { return "Cluster Name Conflict" case DuplicateRemoteLeafnodeConnection: return "Duplicate Remote LeafNode Connection" + case DuplicateClientID: + return "Duplicate Client ID" } return "Unknown State" diff --git a/server/mqtt.go b/server/mqtt.go new file mode 100644 index 00000000..4c508b18 --- /dev/null +++ b/server/mqtt.go @@ -0,0 +1,2515 @@ +// Copyright 2020 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "strconv" + "strings" + "sync" + "time" + "unicode/utf8" + + "github.com/nats-io/nuid" +) + +// References to "spec" here is from https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.pdf + +const ( + mqttPacketConnect = byte(0x10) + mqttPacketConnectAck = byte(0x20) + mqttPacketPub = byte(0x30) + mqttPacketPubAck = byte(0x40) + mqttPacketPubRec = byte(0x50) + mqttPacketPubRel = byte(0x60) + mqttPacketPubComp = byte(0x70) + mqttPacketSub = byte(0x80) + mqttPacketSubAck = byte(0x90) + mqttPacketUnsub = byte(0xa0) + mqttPacketUnsubAck = byte(0xb0) + mqttPacketPing = byte(0xc0) + mqttPacketPingResp = byte(0xd0) + mqttPacketDisconnect = byte(0xe0) + mqttPacketMask = byte(0xf0) + mqttPacketFlagMask = byte(0x0f) + + mqttProtoLevel = byte(0x4) + + // Connect flags + mqttConnFlagReserved = byte(0x1) + mqttConnFlagCleanSession = byte(0x2) + mqttConnFlagWillFlag = byte(0x04) + mqttConnFlagWillQoS = byte(0x18) + mqttConnFlagWillRetain = byte(0x20) + mqttConnFlagPasswordFlag = byte(0x40) + mqttConnFlagUsernameFlag = byte(0x80) + + // Publish flags + mqttPubFlagRetain = byte(0x01) + mqttPubFlagQoS = byte(0x06) + mqttPubFlagDup = byte(0x08) + mqttPubQos1 = byte(0x2) // 1 << 1 + + // Subscribe flags + mqttSubscribeFlags = byte(0x2) + mqttSubAckFailure = byte(0x80) + + // Unsubscribe flags + mqttUnsubscribeFlags = byte(0x2) + + // ConnAck returned codes + mqttConnAckRCConnectionAccepted = byte(0x0) + mqttConnAckRCUnacceptableProtocolVersion = byte(0x1) + mqttConnAckRCIdentifierRejected = byte(0x2) + mqttConnAckRCServerUnavailable = byte(0x3) + mqttConnAckRCBadUserOrPassword = byte(0x4) + mqttConnAckRCNotAuthorized = byte(0x5) + + // Topic/Filter characters + mqttTopicLevelSep = '/' + mqttSingleLevelWC = '+' + mqttMultiLevelWC = '#' + + // This is appended to the sid of a subscription that is + // created on the upper level subject because of the MQTT + // wildcard '#' semantic. + mqttMultiLevelSidSuffix = " fwc" + + // This is the prefix for NATS subscriptions subjects associated as delivery + // subject of JS consumer. We want to make them unique so will prevent users + // MQTT subscriptions to start with this. + mqttSubPrefix = "$MQTT.sub." + + // MQTT Stream names prefix. Will be used to prevent users from creating + // JS consumers on those. + mqttStreamNamePrefix = "$MQTT_" + + // Stream name for MQTT messages on a given account + mqttStreamName = mqttStreamNamePrefix + "msgs" + + // Stream name for MQTT retained messages on a given account + mqttRetainedMsgsStreamName = mqttStreamNamePrefix + "rmsgs" + + // Stream name for MQTT sessions on a given account + mqttSessionsStreamName = mqttStreamNamePrefix + "sessions" + + // Normally, MQTT server should not redeliver QoS 1 messages to clients, + // except after client reconnects. However, NATS Server will redeliver + // unacknowledged messages after this default interval. This can be + // changed with the server.Options.MQTT.AckWait option. + mqttDefaultAckWait = 30 * time.Second + + // This is the default for the outstanding number of pending QoS 1 + // messages sent to a session with QoS 1 subscriptions. + mqttDefaultMaxAckPending = 1024 +) + +var ( + mqttPingResponse = []byte{mqttPacketPingResp, 0x0} + mqttProtoName = []byte("MQTT") + mqttOldProtoName = []byte("MQIsdp") +) + +type srvMQTT struct { + listener net.Listener + authOverride bool + sessmgr mqttSessionManager +} + +type mqttSessionManager struct { + mu sync.RWMutex + sessions map[string]*mqttAccountSessionManager // key is account name +} + +type mqttAccountSessionManager struct { + mu sync.RWMutex + sstream *Stream // stream where sessions are recorded + mstream *Stream // messages stream + rstream *Stream // retained messages stream + sessions map[string]*mqttSession // key is MQTT client ID + sl *Sublist // sublist allowing to find retained messages for given subscription + retmsgs map[string]*mqttRetainedMsg // retained messages +} + +type mqttSession struct { + mu sync.Mutex + c *client + subs map[string]byte + cons map[string]*Consumer + stream *Stream + sseq uint64 // stream sequence where this sesion is recorded + pending map[uint16]*mqttPending // Key is the PUBLISH packet identifier sent to client and maps to a mqttPending record + cpending map[*Consumer]map[uint64]uint16 // For each JS consumer, the key is the stream sequence and maps to the PUBLISH packet identifier + ppi uint16 // publish packet identifier + maxp uint16 + stalled bool + clean bool +} + +type mqttPersistedSession struct { + ID string `json:"id,omitempty"` + Clean bool `json:"clean,omitempty"` + Subs map[string]byte `json:"subs,omitempty"` + Cons map[string]string `json:"cons,omitempty"` +} + +type mqttRetainedMsg struct { + Msg []byte `json:"msg,omitempty"` + Flags byte `json:"flags,omitempty"` + Source string `json:"source,omitempty"` + + // non exported + sseq uint64 + sub *subscription +} + +type mqttSub struct { + qos byte + // Pending serialization of retained messages to be sent when subscription is registered + prm *mqttWriter + // This is the corresponding JS consumer. This is applicable to a subscription that is + // done for QoS > 0 (the subscription attached to a JS consumer's delivery subject). + jsCons *Consumer +} + +type mqtt struct { + r *mqttReader + cp *mqttConnectProto + pp *mqttPublish + asm *mqttAccountSessionManager // quick reference to account session manager, immutable after processConnect() + sess *mqttSession // quick reference to session, immutable after processConnect() +} + +type mqttPending struct { + sseq uint64 // stream sequence + dseq uint64 // consumer delivery sequence + dcount uint64 // consumer delivery count + jsCons *Consumer // pointer to JS consumer (to which we will call ackMsg() on with above info) +} + +type mqttConnectProto struct { + clientID string + rd time.Duration + will *mqttWill + flags byte +} + +type mqttIOReader interface { + io.Reader + SetReadDeadline(time.Time) error +} + +type mqttReader struct { + reader mqttIOReader + buf []byte + pos int +} + +type mqttWriter struct { + bytes.Buffer +} + +type mqttWill struct { + topic []byte + message []byte + qos byte + retain bool +} + +type mqttFilter struct { + filter string + qos byte +} + +type mqttPublish struct { + subject []byte + msg []byte + sz int + pi uint16 + flags byte +} + +func (s *Server) startMQTT() { + sopts := s.getOpts() + o := &sopts.MQTT + + var hl net.Listener + var err error + + port := o.Port + if port == -1 { + port = 0 + } + hp := net.JoinHostPort(o.Host, strconv.Itoa(port)) + s.mu.Lock() + if s.shutdown { + s.mu.Unlock() + return + } + s.mqtt.sessmgr.sessions = make(map[string]*mqttAccountSessionManager) + hl, err = net.Listen("tcp", hp) + if err != nil { + s.mu.Unlock() + s.Fatalf("Unable to listen for MQTT connections: %v", err) + return + } + if port == 0 { + o.Port = hl.Addr().(*net.TCPAddr).Port + } + s.mqtt.listener = hl + scheme := "mqtt" + if o.TLSConfig != nil { + scheme = "tls" + } + s.Noticef("Listening for MQTT clients on %s://%s:%d", scheme, o.Host, o.Port) + go s.acceptConnections(hl, "MQTT", func(conn net.Conn) { s.createClient(conn, nil, &mqtt{}) }, nil) + s.mu.Unlock() +} + +// Given the mqtt options, we check if any auth configuration +// has been provided. If so, possibly create users/nkey users and +// store them in s.mqtt.users/nkeys. +// Also update a boolean that indicates if auth is required for +// mqtt clients. +// Server lock is held on entry. +func (s *Server) mqttConfigAuth(opts *MQTTOpts) { + mqtt := &s.mqtt + // If any of those is specified, we consider that there is an override. + mqtt.authOverride = opts.Username != _EMPTY_ || opts.Token != _EMPTY_ || opts.NoAuthUser != _EMPTY_ +} + +// Validate the mqtt related options. +func validateMQTTOptions(o *Options) error { + mo := &o.MQTT + // If no port is defined, we don't care about other options + if mo.Port == 0 { + return nil + } + // If there is a NoAuthUser, we need to have Users defined and + // the user to be present. + if mo.NoAuthUser != _EMPTY_ { + if err := validateNoAuthUser(o, mo.NoAuthUser); err != nil { + return err + } + } + // Token/Username not possible if there are users/nkeys + if len(o.Users) > 0 || len(o.Nkeys) > 0 { + if mo.Username != _EMPTY_ { + return fmt.Errorf("mqtt authentication username not compatible with presence of users/nkeys") + } + if mo.Token != _EMPTY_ { + return fmt.Errorf("mqtt authentication token not compatible with presence of users/nkeys") + } + } + if mo.AckWait < 0 { + return fmt.Errorf("ack wait must be a positive value") + } + return nil +} + +// Parse protocols inside the given buffer. +// This is invoked from the readLoop. +func (c *client) mqttParse(buf []byte) error { + c.mu.Lock() + s := c.srv + trace := c.trace + connected := c.flags.isSet(connectReceived) + mqtt := c.mqtt + r := mqtt.r + var rd time.Duration + if mqtt.cp != nil { + rd = mqtt.cp.rd + if rd > 0 { + r.reader.SetReadDeadline(time.Time{}) + } + } + c.mu.Unlock() + + r.reset(buf) + + var err error + var b byte + var pl int + + for err == nil && r.hasMore() { + + // Read packet type and flags + if b, err = r.readByte("packet type"); err != nil { + break + } + + // Packet type + pt := b & mqttPacketMask + + // If client was not connected yet, the first packet must be + // a mqttPacketConnect otherwise we fail the connection. + if !connected && pt != mqttPacketConnect { + err = errors.New("not connected") + break + } + + if pl, err = r.readPacketLen(); err != nil { + break + } + + switch pt { + case mqttPacketPub: + pp := mqttPublish{flags: b & mqttPacketFlagMask} + err = c.mqttParsePub(r, pl, &pp) + if trace { + c.traceInOp("PUBLISH", errOrTrace(err, mqttPubTrace(&pp))) + if err == nil { + c.traceMsg(pp.msg) + } + } + if err == nil { + s.mqttProcessPub(c, &pp) + if pp.pi > 0 { + c.mqttEnqueuePubAck(pp.pi) + if trace { + c.traceOutOp("PUBACK", []byte(fmt.Sprintf("pi=%v", pp.pi))) + } + } + } + case mqttPacketPubAck: + var pi uint16 + pi, err = mqttParsePubAck(r, pl) + if trace { + c.traceInOp("PUBACK", errOrTrace(err, fmt.Sprintf("pi=%v", pi))) + } + if err == nil { + c.mqttProcessPubAck(pi) + } + case mqttPacketSub: + var pi uint16 // packet identifier + var filters []*mqttFilter + var subs []*subscription + pi, filters, err = c.mqttParseSubs(r, b, pl) + if trace { + c.traceInOp("SUBSCRIBE", errOrTrace(err, mqttSubscribeTrace(filters))) + } + if err == nil { + subs, err = c.mqttProcessSubs(filters) + if err == nil && trace { + c.traceOutOp("SUBACK", []byte(mqttSubscribeTrace(filters))) + } + } + if err == nil { + c.mqttEnqueueSubAck(pi, filters) + c.mqttSendRetainedMsgsToNewSubs(subs) + } + case mqttPacketUnsub: + var pi uint16 // packet identifier + var filters []*mqttFilter + pi, filters, err = c.mqttParseUnsubs(r, b, pl) + if trace { + c.traceInOp("UNSUBSCRIBE", errOrTrace(err, mqttUnsubscribeTrace(filters))) + } + if err == nil { + err = c.mqttProcessUnsubs(filters) + if err == nil && trace { + c.traceOutOp("UNSUBACK", []byte(strconv.FormatInt(int64(pi), 10))) + } + } + if err == nil { + c.mqttEnqueueUnsubAck(pi) + } + case mqttPacketPing: + if trace { + c.traceInOp("PINGREQ", nil) + } + c.mqttEnqueuePingResp() + if trace { + c.traceOutOp("PINGRESP", nil) + } + case mqttPacketConnect: + // It is an error to receive a second connect packet + if connected { + err = errors.New("second connect packet") + break + } + var rc byte + var cp *mqttConnectProto + var sessp bool + rc, cp, err = c.mqttParseConnect(r, pl) + if trace && cp != nil { + c.traceInOp("CONNECT", errOrTrace(err, c.mqttConnectTrace(cp))) + } + if rc != 0 { + c.mqttEnqueueConnAck(rc, sessp) + if trace { + c.traceOutOp("CONNACK", []byte(fmt.Sprintf("sp=%v rc=%v", sessp, rc))) + } + } else if err == nil { + if err = s.mqttProcessConnect(c, cp, trace); err == nil { + connected = true + rd = cp.rd + } + } + case mqttPacketDisconnect: + if trace { + c.traceInOp("DISCONNECT", nil) + } + // Normal disconnect, we need to discard the will. + // Spec [MQTT-3.1.2-8] + c.mu.Lock() + if c.mqtt.cp != nil { + c.mqtt.cp.will = nil + } + c.mu.Unlock() + c.closeConnection(ClientClosed) + return nil + case mqttPacketPubRec: + fallthrough + case mqttPacketPubRel: + fallthrough + case mqttPacketPubComp: + err = fmt.Errorf("protocol %d not supported", pt>>4) + default: + err = fmt.Errorf("received unknown packet type %d", pt>>4) + } + } + if err == nil && rd > 0 { + r.reader.SetReadDeadline(time.Now().Add(rd)) + } + return err +} + +// Update the session (possibly remove it) of this disconnected client. +func (s *Server) mqttHandleClosedClient(c *client) { + c.mu.Lock() + cp := c.mqtt.cp + accName := c.acc.Name + c.mu.Unlock() + if cp == nil { + return + } + sm := &s.mqtt.sessmgr + sm.mu.RLock() + asm, ok := sm.sessions[accName] + sm.mu.RUnlock() + if !ok { + return + } + + asm.mu.Lock() + defer asm.mu.Unlock() + es, ok := asm.sessions[cp.clientID] + // If not found, ignore. + if !ok { + return + } + es.mu.Lock() + defer es.mu.Unlock() + // If this client is not the currently registered client, ignore. + if es.c != c { + return + } + // It the session was created with "clean session" flag, we cleanup + // when the client disconnects. + if es.clean { + es.clear() + delete(asm.sessions, cp.clientID) + } else { + // Clear the client from the session, but session stays. + es.c = nil + } +} + +// Updates the MaxAckPending for all MQTT sessions, updating the +// JetStream consumers and updating their max ack pending and forcing +// a expiration of pending messages. +func (s *Server) mqttUpdateMaxAckPending(newmaxp uint16) { + msm := &s.mqtt.sessmgr + s.accounts.Range(func(k, _ interface{}) bool { + accName := k.(string) + msm.mu.RLock() + asm := msm.sessions[accName] + msm.mu.RUnlock() + if asm == nil { + // Move to next account + return true + } + asm.mu.RLock() + for _, sess := range asm.sessions { + sess.mu.Lock() + sess.maxp = newmaxp + // Do not check for sess.stalled here because due to original + // consumer's maxp it is possible that the consumer was stalled + // and MQTT code did not have to stall the session. + for _, cons := range sess.cons { + cons.mu.Lock() + cons.maxp = int(newmaxp) + cons.forceExpirePending() + cons.mu.Unlock() + } + sess.mu.Unlock() + } + asm.mu.RUnlock() + return true + }) +} + +////////////////////////////////////////////////////////////////////////////// +// +// Sessions Manager related functions +// +////////////////////////////////////////////////////////////////////////////// + +// Returns the MQTT sessions manager for a given account. +// If new, creates the required JetStream streams/consumers +// for handling of sessions and messages. +func (sm *mqttSessionManager) getOrCreateAccountSessionManager(clientID string, c *client) (*mqttAccountSessionManager, error) { + c.mu.Lock() + acc := c.acc + accName := acc.GetName() + c.mu.Unlock() + + sm.mu.RLock() + asm, ok := sm.sessions[accName] + sm.mu.RUnlock() + + if ok { + return asm, nil + } + + // Not found, now take the write lock and check again + sm.mu.Lock() + defer sm.mu.Unlock() + asm, ok = sm.sessions[accName] + if ok { + return asm, nil + } + // First check that we have JS enabled for this account. + // TODO: Since we check only when creating a session manager for this + // account, would probably need to do some cleanup if JS can be disabled + // on config reload. + if !acc.JetStreamEnabled() { + return nil, fmt.Errorf("JetStream must be enabled for account %q used by MQTT client ID %q", + accName, clientID) + } + // Need to create one here. + asm = &mqttAccountSessionManager{sessions: make(map[string]*mqttSession)} + if err := asm.init(acc, c); err != nil { + return nil, err + } + sm.sessions[accName] = asm + return asm, nil +} + +////////////////////////////////////////////////////////////////////////////// +// +// Account Sessions Manager related functions +// +////////////////////////////////////////////////////////////////////////////// + +// Creates JS streams/consumers for handling of sessions and messages for this +// account. Note that lookup are performed in case we are in a restart +// situation and the account is loaded for the first time but had state on disk. +// +// Global session manager lock is held on entry. +func (as *mqttAccountSessionManager) init(acc *Account, c *client) error { + opts := c.srv.getOpts() + var err error + // Start with sessions stream + as.sstream, err = acc.LookupStream(mqttSessionsStreamName) + if err != nil { + as.sstream, err = acc.addStreamWithStore(&StreamConfig{ + Subjects: []string{}, + Name: mqttSessionsStreamName, + Storage: FileStorage, + Retention: InterestPolicy, + }, nil, true) + if err != nil { + return fmt.Errorf("unable to create sessions stream for MQTT account %q: %v", acc.GetName(), err) + } + } + // Create the stream for the messages. + as.mstream, err = acc.LookupStream(mqttStreamName) + if err != nil { + as.mstream, err = acc.addStreamWithStore(&StreamConfig{ + Subjects: []string{}, + Name: mqttStreamName, + Storage: FileStorage, + Retention: InterestPolicy, + }, nil, true) + if err != nil { + return fmt.Errorf("unable to create messages stream for MQTT account %q: %v", acc.GetName(), err) + } + } + // Create the stream for retained messages. + as.rstream, err = acc.LookupStream(mqttRetainedMsgsStreamName) + if err != nil { + as.rstream, err = acc.addStreamWithStore(&StreamConfig{ + Subjects: []string{}, + Name: mqttRetainedMsgsStreamName, + Storage: FileStorage, + Retention: InterestPolicy, + }, nil, true) + if err != nil { + return fmt.Errorf("unable to create retained messages stream for MQTT account %q: %v", acc.GetName(), err) + } + } + // Now recover all sessions (in case it did already exist) + if state := as.sstream.State(); state.Msgs > 0 { + for seq := state.FirstSeq; seq <= state.LastSeq; seq++ { + _, _, content, _, err := as.sstream.store.LoadMsg(seq) + if err != nil { + if err != errDeletedMsg { + c.Errorf("Error loading session record at sequence %v: %v", seq, err) + } + continue + } + ps := &mqttPersistedSession{} + if err := json.Unmarshal(content, ps); err != nil { + c.Errorf("Error unmarshaling session record at sequence %v: %v", seq, err) + continue + } + if as.sessions == nil { + as.sessions = make(map[string]*mqttSession) + } + es, ok := as.sessions[ps.ID] + if ok && es.sseq != 0 { + as.sstream.DeleteMsg(es.sseq) + } else if !ok { + es = mqttSessionCreate(opts) + es.stream = as.sstream + as.sessions[ps.ID] = es + } + es.sseq = seq + es.clean = ps.Clean + es.subs = ps.Subs + if l := len(ps.Cons); l > 0 { + if es.cons == nil { + es.cons = make(map[string]*Consumer, l) + } + for sid, name := range ps.Cons { + if cons := as.mstream.LookupConsumer(name); cons != nil { + es.cons[sid] = cons + } + } + } + } + } + // Finally, recover retained messages. + if state := as.rstream.State(); state.Msgs > 0 { + for seq := state.FirstSeq; seq <= state.LastSeq; seq++ { + subject, _, content, _, err := as.rstream.store.LoadMsg(seq) + if err != nil { + if err != errDeletedMsg { + c.Errorf("Error loading retained message at sequence %v: %v", seq, err) + } + continue + } + rm := &mqttRetainedMsg{} + if err := json.Unmarshal(content, &rm); err != nil { + c.Errorf("Error unmarshaling retained message on subject %q, sequence %v: %v", subject, seq, err) + continue + } + rm = as.handleRetainedMsg(subject, rm) + rm.sseq = seq + } + } + return nil +} + +// Add/Replace this message from the retained messages map. +// Returns the retained message actually stored in the map, which means that +// it may be different from the given `rm`. +// +// Account session manager lock held on entry. +func (as *mqttAccountSessionManager) handleRetainedMsg(key string, rm *mqttRetainedMsg) *mqttRetainedMsg { + if as.retmsgs == nil { + as.retmsgs = make(map[string]*mqttRetainedMsg) + as.sl = NewSublistWithCache() + } else { + // Check if we already had one. If so, update the existing one. + if erm, exists := as.retmsgs[key]; exists { + erm.Msg = rm.Msg + erm.Flags = rm.Flags + erm.Source = rm.Source + return erm + } + } + rm.sub = &subscription{subject: []byte(key)} + as.retmsgs[key] = rm + as.sl.Insert(rm.sub) + return rm +} + +// Process subscriptions for the given session/client. +// +// When `fromSubProto` is false, it means that this is invoked from the CONNECT +// protocol, when restoring subscriptions that were saved for this session. +// In that case, there is no need to update the session record. +// +// When `fromSubProto` is true, it means that this call is invoked from the +// processing of the SUBSCRIBE protocol, which means that the session needs to +// be updated. It also means that if a subscription on same subject with same +// QoS already exist, we should not be recreating the subscription/JS durable, +// since it was already done when processing the CONNECT protocol. +// +// Account session manager lock held on entry. +// Session lock held when `fromSubProto` is false. +func (as *mqttAccountSessionManager) processSubs(sess *mqttSession, clientID string, c *client, + filters []*mqttFilter, fromSubProto, trace bool) ([]*subscription, error) { + + if fromSubProto { + sess.mu.Lock() + defer sess.mu.Unlock() + if sess.c != c { + return nil, fmt.Errorf("client %q no longer registered with MQTT session", clientID) + } + } + + addJSConsToSess := func(sid string, cons *Consumer) { + if cons == nil { + return + } + if sess.cons == nil { + sess.cons = make(map[string]*Consumer) + } + sess.cons[sid] = cons + } + + subs := make([]*subscription, 0, len(filters)) + for _, f := range filters { + if f.qos > 1 { + f.qos = 1 + } + subject := f.filter + sid := subject + + if strings.HasPrefix(subject, mqttSubPrefix) { + f.qos = mqttSubAckFailure + continue + } + + var jscons *Consumer + var jssub *subscription + var err error + + sub := c.mqttCreateSub(subject, sid, mqttDeliverMsgCb, f.qos) + if fromSubProto { + as.serializeRetainedMsgsForSub(sess, c, sub, trace) + } + // Note that if a subscription already exists on this subject, + // the sub is updated with the new qos/prm and the pointer to + // the existing subscription is returned. + sub, err = c.processSub(sub, false) + if err == nil { + // This will create (if not already exist) a JS consumer for subscriptions + // of QoS >= 1. But if a JS consumer already exists and the subscription + // for same subject is now a QoS==0, then the JS consumer will be deleted. + jscons, jssub, err = c.mqttProcessJSConsumer(sess, as.mstream, + subject, sid, f.qos, fromSubProto) + } + if err != nil { + c.Errorf("error subscribing to %q: err=%v", subject, err) + f.qos = mqttSubAckFailure + c.mqttCleanupFailedSub(sub, jscons, jssub) + continue + } + if mqttNeedSubForLevelUp(subject) { + var fwjscons *Consumer + var fwjssub *subscription + + // Say subject is "foo.>", remove the ".>" so that it becomes "foo" + fwcsubject := subject[:len(subject)-2] + // Change the sid to "foo fwc" + fwcsid := fwcsubject + mqttMultiLevelSidSuffix + fwcsub := c.mqttCreateSub(fwcsubject, fwcsid, mqttDeliverMsgCb, f.qos) + if fromSubProto { + as.serializeRetainedMsgsForSub(sess, c, fwcsub, trace) + } + // See note above about existing subscription. + fwcsub, err = c.processSub(fwcsub, false) + if err == nil { + fwjscons, fwjssub, err = c.mqttProcessJSConsumer(sess, as.mstream, + fwcsubject, fwcsid, f.qos, fromSubProto) + } + if err != nil { + c.Errorf("error subscribing to %q: err=%v", fwcsubject, err) + f.qos = mqttSubAckFailure + c.mqttCleanupFailedSub(sub, jscons, jssub) + c.mqttCleanupFailedSub(fwcsub, fwjscons, fwjssub) + continue + } + subs = append(subs, fwcsub) + addJSConsToSess(fwcsid, fwjscons) + } + subs = append(subs, sub) + addJSConsToSess(sid, jscons) + } + var err error + if fromSubProto { + err = sess.update(clientID, filters, true) + } + return subs, err +} + +// Retained publish messages matching this subscription are serialized in the +// subscription's `prm` mqtt writer. This buffer will be queued for outbound +// after the subscription is processed and SUBACK is sent or possibly when +// server processes an incoming published message matching the newly +// registered subscription. +// +// Account session manager lock held on entry. +// Session lock held on entry +func (as *mqttAccountSessionManager) serializeRetainedMsgsForSub(sess *mqttSession, c *client, sub *subscription, trace bool) { + if len(as.retmsgs) > 0 { + var rmsa [64]*mqttRetainedMsg + rms := rmsa[:0] + + as.getRetainedPublishMsgs(string(sub.subject), &rms) + for _, rm := range rms { + if sub.mqtt.prm == nil { + sub.mqtt.prm = &mqttWriter{} + } + prm := sub.mqtt.prm + pi := sess.getPubAckIdentifier(mqttGetQoS(rm.Flags), sub) + // Need to use the subject for the retained message, not the `sub` subject. + // We can find the published retained message in rm.sub.subject. + flags := mqttSerializePublishMsg(prm, pi, false, true, string(rm.sub.subject), rm.Msg[:len(rm.Msg)-LEN_CR_LF]) + if trace { + pp := mqttPublish{ + flags: flags, + pi: pi, + subject: rm.sub.subject, + sz: len(rm.Msg) - LEN_CR_LF, + } + c.traceOutOp("PUBLISH", []byte(mqttPubTrace(&pp))) + } + } + } +} + +// Returns in the provided slice all publish retained message records that +// match the given subscription's `subject` (which could have wildcards). +// +// Account session manager lock held on entry. +func (as *mqttAccountSessionManager) getRetainedPublishMsgs(subject string, rms *[]*mqttRetainedMsg) { + result := as.sl.ReverseMatch(subject) + if len(result.psubs) == 0 { + return + } + for _, sub := range result.psubs { + // Since this is a reverse match, the subscription objects here + // contain literals corresponding to the published subjects. + if rm, ok := as.retmsgs[string(sub.subject)]; ok { + *rms = append(*rms, rm) + } + } +} + +////////////////////////////////////////////////////////////////////////////// +// +// MQTT session related functions +// +////////////////////////////////////////////////////////////////////////////// + +// Returns a new mqttSession object with max ack pending set based on +// option or use mqttDefaultMaxAckPending if no option set. +func mqttSessionCreate(opts *Options) *mqttSession { + maxp := opts.MQTT.MaxAckPending + if maxp == 0 { + maxp = mqttDefaultMaxAckPending + } + return &mqttSession{maxp: maxp} +} + +// Persists a session. Note that if the session's current client does not match +// the given client, nothing is done. +// +// Lock held on entry. +func (sess *mqttSession) save(clientID string) error { + ps := mqttPersistedSession{ + ID: clientID, + Clean: sess.clean, + Subs: sess.subs, + } + if l := len(sess.cons); l > 0 { + cons := make(map[string]string, l) + for sid, jscons := range sess.cons { + cons[sid] = jscons.Name() + } + ps.Cons = cons + } + sessBytes, _ := json.Marshal(&ps) + newSeq, _, err := sess.stream.store.StoreMsg("sessions", nil, sessBytes) + if err != nil { + return err + } + if sess.sseq != 0 { + sess.stream.DeleteMsg(sess.sseq) + } + sess.sseq = newSeq + return nil +} + +// Delete JS consumers for this session and delete the persisted session from +// the stream. +// +// Lock held on entry. +func (sess *mqttSession) clear() { + for consName, cons := range sess.cons { + delete(sess.cons, consName) + cons.Delete() + } + if sess.stream != nil && sess.sseq != 0 { + sess.stream.DeleteMsg(sess.sseq) + sess.sseq = 0 + } + sess.subs, sess.pending, sess.cpending = nil, nil, nil +} + +// This will update the session record for this client in the account's MQTT +// sessions stream if the session had any change in the subscriptions. +// +// Lock held on entry. +func (sess *mqttSession) update(clientID string, filters []*mqttFilter, add bool) error { + // Evaluate if we need to persist anything. + var needUpdate bool + for _, f := range filters { + if add { + if f.qos == mqttSubAckFailure { + continue + } + if qos, ok := sess.subs[f.filter]; !ok || qos != f.qos { + if sess.subs == nil { + sess.subs = make(map[string]byte) + } + sess.subs[f.filter] = f.qos + needUpdate = true + } + } else { + if _, ok := sess.subs[f.filter]; ok { + delete(sess.subs, f.filter) + needUpdate = true + } + } + } + var err error + if needUpdate { + err = sess.save(clientID) + } + return err +} + +// If both pQos and sub.mqtt.qos are > 0, then this will return the next +// packet identifier to use for a published message. +// +// Lock held on entry +func (sess *mqttSession) getPubAckIdentifier(pQos byte, sub *subscription) uint16 { + pi, _ := sess.trackPending(pQos, _EMPTY_, sub) + return pi +} + +// If publish message QoS (pQos) and the subscription's QoS are both at least 1, +// this function will assign a packet identifier (pi) and will keep track of +// the pending ack. If the message has already been redelivered (reply != ""), +// the returned boolean will be `true`. +// +// Lock held on entry +func (sess *mqttSession) trackPending(pQos byte, reply string, sub *subscription) (uint16, bool) { + if pQos == 0 || sub.mqtt.qos == 0 { + return 0, false + } + var dup bool + var pi uint16 + + bumpPI := func() uint16 { + var avail bool + next := sess.ppi + for i := 0; i < 0xFFFF; i++ { + next++ + if next == 0 { + next = 1 + } + if _, used := sess.pending[next]; !used { + sess.ppi = next + avail = true + break + } + } + if !avail { + return 0 + } + return sess.ppi + } + + if reply == _EMPTY_ || sub.mqtt.jsCons == nil { + return bumpPI(), false + } + + // Here, we have an ACK subject and a JS consumer... + jsCons := sub.mqtt.jsCons + if sess.pending == nil { + sess.pending = make(map[uint16]*mqttPending) + sess.cpending = make(map[*Consumer]map[uint64]uint16) + } + // Get the stream sequence and other from the ack reply subject + sseq, dseq, dcount := ackReplyInfo(reply) + + var pending *mqttPending + // For this JS consumer, check to see if we already have sseq->pi + sseqToPi, ok := sess.cpending[jsCons] + if !ok { + sseqToPi = make(map[uint64]uint16) + sess.cpending[jsCons] = sseqToPi + } else if pi, ok = sseqToPi[sseq]; ok { + // If we already have a pi, get the ack so we update it. + // We will reuse the save packet identifier (pi). + pending = sess.pending[pi] + } + if pi == 0 { + // sess.maxp will always have a value > 0. + if len(sess.pending) >= int(sess.maxp) { + // Indicate that we did not assign a packet identifier. + // The caller will not send the message to the subscription + // and JS will redeliver later, based on consumer's AckWait. + sess.stalled = true + return 0, false + } + pi = bumpPI() + sseqToPi[sseq] = pi + } + if pending == nil { + pending = &mqttPending{jsCons: jsCons, sseq: sseq} + sess.pending[pi] = pending + } + // Update pending with consumer delivery sequence and count + pending.dseq, pending.dcount = dseq, dcount + // If redelivery, return DUP flag + if dcount > 1 { + dup = true + } + return pi, dup +} + +////////////////////////////////////////////////////////////////////////////// +// +// CONNECT protocol related functions +// +////////////////////////////////////////////////////////////////////////////// + +// Parse the MQTT connect protocol +func (c *client) mqttParseConnect(r *mqttReader, pl int) (byte, *mqttConnectProto, error) { + + // Make sure that we have the expected length in the buffer, + // and if not, this will read it from the underlying reader. + if err := r.ensurePacketInBuffer(pl); err != nil { + return 0, nil, err + } + + // Protocol name + proto, err := r.readBytes("protocol name", false) + if err != nil { + return 0, nil, err + } + + // Spec [MQTT-3.1.2-1] + if !bytes.Equal(proto, mqttProtoName) { + // Check proto name against v3.1 to report better error + if bytes.Equal(proto, mqttOldProtoName) { + return 0, nil, fmt.Errorf("older protocol %q not supported", proto) + } + return 0, nil, fmt.Errorf("expected connect packet with protocol name %q, got %q", mqttProtoName, proto) + } + + // Protocol level + level, err := r.readByte("protocol level") + if err != nil { + return 0, nil, err + } + // Spec [MQTT-3.1.2-2] + if level != mqttProtoLevel { + return mqttConnAckRCUnacceptableProtocolVersion, nil, fmt.Errorf("unacceptable protocol version of %v", level) + } + + cp := &mqttConnectProto{} + // Connect flags + cp.flags, err = r.readByte("flags") + if err != nil { + return 0, nil, err + } + + // Spec [MQTT-3.1.2-3] + if cp.flags&mqttConnFlagReserved != 0 { + return 0, nil, fmt.Errorf("connect flags reserved bit not set to 0") + } + + var hasWill bool + wqos := (cp.flags & mqttConnFlagWillQoS) >> 3 + wretain := cp.flags&mqttConnFlagWillRetain != 0 + // Spec [MQTT-3.1.2-11] + if cp.flags&mqttConnFlagWillFlag == 0 { + // Spec [MQTT-3.1.2-13] + if wqos != 0 { + return 0, nil, fmt.Errorf("if Will flag is set to 0, Will QoS must be 0 too, got %v", wqos) + } + // Spec [MQTT-3.1.2-15] + if wretain { + return 0, nil, fmt.Errorf("if Will flag is set to 0, Will Retain flag must be 0 too") + } + } else { + // Spec [MQTT-3.1.2-14] + if wqos == 3 { + return 0, nil, fmt.Errorf("if Will flag is set to 1, Will QoS can be 0, 1 or 2, got %v", wqos) + } + hasWill = true + } + + // Spec [MQTT-3.1.2-19] + hasUser := cp.flags&mqttConnFlagUsernameFlag != 0 + // Spec [MQTT-3.1.2-21] + hasPassword := cp.flags&mqttConnFlagPasswordFlag != 0 + // Spec [MQTT-3.1.2-22] + if !hasUser && hasPassword { + return 0, nil, fmt.Errorf("password flag set but username flag is not") + } + + // Keep alive + var ka uint16 + ka, err = r.readUint16("keep alive") + if err != nil { + return 0, nil, err + } + // Spec [MQTT-3.1.2-24] + if ka > 0 { + cp.rd = time.Duration(float64(ka)*1.5) * time.Second + } + + // Payload starts here and order is mandated by: + // Spec [MQTT-3.1.3-1]: client ID, will topic, will message, username, password + + // Client ID + cp.clientID, err = r.readString("client ID") + if err != nil { + return 0, nil, err + } + // Spec [MQTT-3.1.3-7] + if cp.clientID == _EMPTY_ { + if cp.flags&mqttConnFlagCleanSession == 0 { + return mqttConnAckRCIdentifierRejected, nil, fmt.Errorf("when client ID is empty, clean session flag must be set to 1") + } + // Spec [MQTT-3.1.3-6] + cp.clientID = nuid.Next() + } + // Spec [MQTT-3.1.3-4] and [MQTT-3.1.3-9] + if !utf8.ValidString(cp.clientID) { + return mqttConnAckRCIdentifierRejected, nil, fmt.Errorf("invalid utf8 for client ID: %q", cp.clientID) + } + + if hasWill { + cp.will = &mqttWill{ + qos: wqos, + retain: wretain, + } + var topic []byte + topic, err = r.readBytes("Will topic", false) + if err != nil { + return 0, nil, err + } + if len(topic) == 0 { + return 0, nil, fmt.Errorf("empty Will topic not allowed") + } + if !utf8.Valid(topic) { + return 0, nil, fmt.Errorf("invalide utf8 for Will topic %q", topic) + } + // Convert MQTT topic to NATS subject + var copied bool + copied, topic, err = mqttTopicToNATSPubSubject(topic) + if err != nil { + return 0, nil, err + } + if !copied { + topic = copyBytes(topic) + } + cp.will.topic = topic + // Now will message + var msg []byte + msg, err = r.readBytes("Will message", false) + if err != nil { + return 0, nil, err + } + cp.will.message = make([]byte, 0, len(msg)+2) + cp.will.message = append(cp.will.message, msg...) + cp.will.message = append(cp.will.message, CR_LF...) + } + + if hasUser { + c.opts.Username, err = r.readString("user name") + if err != nil { + return 0, nil, err + } + if c.opts.Username == _EMPTY_ { + return mqttConnAckRCBadUserOrPassword, nil, fmt.Errorf("empty user name not allowed") + } + // Spec [MQTT-3.1.3-11] + if !utf8.ValidString(c.opts.Username) { + return mqttConnAckRCBadUserOrPassword, nil, fmt.Errorf("invalid utf8 for user name %q", c.opts.Username) + } + } + + if hasPassword { + c.opts.Password, err = r.readString("password") + if err != nil { + return 0, nil, err + } + c.opts.Token = c.opts.Password + } + return 0, cp, nil +} + +func (c *client) mqttConnectTrace(cp *mqttConnectProto) string { + trace := fmt.Sprintf("clientID=%s", cp.clientID) + if cp.rd > 0 { + trace += fmt.Sprintf(" keepAlive=%v", cp.rd) + } + if cp.will != nil { + trace += fmt.Sprintf(" will=(topic=%s QoS=%v retain=%v)", + cp.will.topic, cp.will.qos, cp.will.retain) + } + if c.opts.Username != _EMPTY_ { + trace += fmt.Sprintf(" username=%s", c.opts.Username) + } + if c.opts.Password != _EMPTY_ { + trace += " password=****" + } + return trace +} + +func (s *Server) mqttProcessConnect(c *client, cp *mqttConnectProto, trace bool) error { + sendConnAck := func(rc byte, sessp bool) { + c.mqttEnqueueConnAck(rc, sessp) + if trace { + c.traceOutOp("CONNACK", []byte(fmt.Sprintf("sp=%v rc=%v", sessp, rc))) + } + } + + c.mu.Lock() + c.clearAuthTimer() + c.mu.Unlock() + if !s.isClientAuthorized(c) { + sendConnAck(mqttConnAckRCNotAuthorized, false) + c.closeConnection(AuthenticationViolation) + return ErrAuthentication + } + // Now that we are are authenticated, we have the client bound to the account. + // Get the account's level MQTT sessions manager. If it does not exists yet, + // this will create it along with the streams where sessions and messages + // are stored. + sm := &s.mqtt.sessmgr + asm, err := sm.getOrCreateAccountSessionManager(cp.clientID, c) + if err != nil { + return err + } + + // Rest of code runs under the account's sessions manager write lock. + asm.mu.Lock() + defer asm.mu.Unlock() + + // Is the client requesting a clean session or not. + cleanSess := cp.flags&mqttConnFlagCleanSession != 0 + // Session present? Assume false, will be set to true only when applicable. + sessp := false + // Do we have an existing session for this client ID + es, ok := asm.sessions[cp.clientID] + if ok { + es.mu.Lock() + defer es.mu.Unlock() + // Clear the session if client wants a clean session. + // Also, Spec [MQTT-3.2.2-1]: don't report session present + if cleanSess || es.clean { + // Spec [MQTT-3.1.2-6]: If CleanSession is set to 1, the Client and + // Server MUST discard any previous Session and start a new one. + // This Session lasts as long as the Network Connection. State data + // associated with this Session MUST NOT be reused in any subsequent + // Session. + es.clear() + } else { + // Report to the client that the session was present + sessp = true + } + ec := es.c + // Is there an actual client associated with this session. + if ec != nil { + // Spec [MQTT-3.1.4-2]. If the ClientId represents a Client already + // connected to the Server then the Server MUST disconnect the existing + // client. + ec := es.c + ec.mu.Lock() + // Remove will before closing + ec.mqtt.cp.will = nil + ec.mu.Unlock() + // Close old client in separate go routine + go ec.closeConnection(DuplicateClientID) + } + // Bind with the new client + es.c = c + es.clean = cleanSess + } else { + // Spec [MQTT-3.2.2-3]: if the Server does not have stored Session state, + // it MUST set Session Present to 0 in the CONNACK packet. + es = mqttSessionCreate(s.getOpts()) + es.c, es.clean, es.stream = c, cleanSess, asm.sstream + es.mu.Lock() + defer es.mu.Unlock() + asm.sessions[cp.clientID] = es + es.save(cp.clientID) + } + c.mu.Lock() + c.flags.set(connectReceived) + c.mqtt.cp = cp + c.mqtt.asm = asm + c.mqtt.sess = es + c.mu.Unlock() + // Spec [MQTT-3.2.0-1]: At this point we need to send the CONNACK before + // restoring subscriptions, because CONNACK must be the first packet sent + // to the client. + sendConnAck(mqttConnAckRCConnectionAccepted, sessp) + // Now process possible saved subscriptions. + if l := len(es.subs); l > 0 { + filters := make([]*mqttFilter, 0, l) + for subject, qos := range es.subs { + filters = append(filters, &mqttFilter{filter: subject, qos: qos}) + } + if _, err := asm.processSubs(es, cp.clientID, c, filters, false, trace); err != nil { + return err + } + } + return nil +} + +func (c *client) mqttEnqueueConnAck(rc byte, sessionPresent bool) { + proto := [4]byte{mqttPacketConnectAck, 2, 0, rc} + c.mu.Lock() + // Spec [MQTT-3.2.2-4]. If return code is different from 0, then + // session present flag must be set to 0. + if rc == 0 { + if sessionPresent { + proto[2] = 1 + } + } + c.enqueueProto(proto[:]) + c.mu.Unlock() +} + +func (s *Server) mqttHandleWill(c *client) { + c.mu.Lock() + if c.mqtt.cp == nil { + c.mu.Unlock() + return + } + will := c.mqtt.cp.will + if will == nil { + c.mu.Unlock() + return + } + pp := &mqttPublish{ + subject: will.topic, + msg: will.message, + sz: len(will.message) - LEN_CR_LF, + flags: will.qos << 1, + } + if will.retain { + pp.flags |= mqttPubFlagRetain + } + c.mu.Unlock() + s.mqttProcessPub(c, pp) + c.flushClients(0) +} + +////////////////////////////////////////////////////////////////////////////// +// +// PUBLISH protocol related functions +// +////////////////////////////////////////////////////////////////////////////// + +func (c *client) mqttParsePub(r *mqttReader, pl int, pp *mqttPublish) error { + qos := (pp.flags & mqttPubFlagQoS) >> 1 + if qos > 1 { + return fmt.Errorf("publish QoS=%v not supported", qos) + } + if err := r.ensurePacketInBuffer(pl); err != nil { + return err + } + // Keep track of where we are when starting to read the variable header + start := r.pos + + var err error + pp.subject, err = r.readBytes("topic", false) + if err != nil { + return err + } + if len(pp.subject) == 0 { + return fmt.Errorf("topic cannot be empty") + } + // Convert the topic to a NATS subject. This call will also check that + // there is no MQTT wildcards (Spec [MQTT-3.3.2-2] and [MQTT-4.7.1-1]) + // Note that this may not result in a copy if there is no special + // conversion. It is good because after the message is processed we + // won't have a reference to the buffer and we save a copy. + _, pp.subject, err = mqttTopicToNATSPubSubject(pp.subject) + if err != nil { + return err + } + + if qos > 0 { + pp.pi, err = r.readUint16("packet identifier") + if err != nil { + return err + } + if pp.pi == 0 { + return fmt.Errorf("with QoS=%v, packet identifier cannot be 0", qos) + } + } + + // The message payload will be the total packet length minus + // what we have consumed for the variable header + pp.sz = pl - (r.pos - start) + pp.msg = make([]byte, 0, pp.sz+2) + if pp.sz > 0 { + start = r.pos + r.pos += pp.sz + pp.msg = append(pp.msg, r.buf[start:r.pos]...) + } + pp.msg = append(pp.msg, _CRLF_...) + return nil +} + +func mqttPubTrace(pp *mqttPublish) string { + dup := pp.flags&mqttPubFlagDup != 0 + qos := mqttGetQoS(pp.flags) + retain := mqttIsRetained(pp.flags) + var piStr string + if pp.pi > 0 { + piStr = fmt.Sprintf(" pi=%v", pp.pi) + } + return fmt.Sprintf("%s dup=%v QoS=%v retain=%v size=%v%s", + pp.subject, dup, qos, retain, pp.sz, piStr) +} + +func (s *Server) mqttProcessPub(c *client, pp *mqttPublish) { + c.mqtt.pp = pp + c.pa.subject, c.pa.hdr, c.pa.size, c.pa.szb = pp.subject, -1, pp.sz, []byte(strconv.FormatInt(int64(pp.sz), 10)) + // This will work for QoS 0 but mqtt msg delivery callback will ignore + // delivery for QoS > 0 published messages (since it is handled specifically + // with call to directProcessInboundJetStreamMsg). + // However, this needs to be invoked before directProcessInboundJetStreamMsg() + // in case we are dealing with publish retained messages. + c.processInboundClientMsg(pp.msg) + if mqttGetQoS(pp.flags) > 0 { + // Since this is the fast path, we access the messages stream directly here + // without locking. All the fields mqtt.asm.mstream are immutable. + c.mqtt.asm.mstream.processInboundJetStreamMsg(nil, c, string(c.pa.subject), "", pp.msg[:len(pp.msg)-LEN_CR_LF]) + } + c.pa.subject, c.pa.hdr, c.pa.size, c.pa.szb = nil, -1, 0, nil + c.mqtt.pp = nil +} + +// Invoked when processing an inbound client message. If the "retain" flag is +// set, the message is stored so it can be later resent to (re)starting +// subscriptions that match the subject. +// +// Invoked from the publisher's readLoop. No client lock is held on entry. +func (c *client) mqttHandlePubRetain() { + pp := c.mqtt.pp + if mqttIsRetained(pp.flags) { + key := string(pp.subject) + asm := c.mqtt.asm + asm.mu.Lock() + // Spec [MQTT-3.3.1-11]. Payload of size 0 removes the retained message, + // but should still be delivered as a normal message. + if pp.sz == 0 { + if asm.retmsgs != nil { + if erm, ok := asm.retmsgs[key]; ok { + delete(asm.retmsgs, key) + asm.sl.Remove(erm.sub) + if erm.sseq != 0 { + asm.rstream.DeleteMsg(erm.sseq) + } + } + } + } else { + // Spec [MQTT-3.3.1-5]. Store the retained message with its QoS. + // When coming from a publish protocol, `pp` is referencing a stack + // variable that itself possibly references the read buffer. + rm := &mqttRetainedMsg{ + Msg: copyBytes(pp.msg), + Flags: pp.flags, + Source: c.opts.Username, + } + rm = asm.handleRetainedMsg(key, rm) + rmBytes, _ := json.Marshal(rm) + // TODO: For now we will report the error but continue... + seq, _, err := asm.rstream.store.StoreMsg(key, nil, rmBytes) + if err != nil { + c.mu.Lock() + acc := c.acc + c.mu.Unlock() + c.Errorf("unable to store retained message for account %q, subject %q: %v", + acc.GetName(), key, err) + } + // If it has been replaced, rm.sseq will be != 0 + if rm.sseq != 0 { + asm.rstream.DeleteMsg(rm.sseq) + } + // Keep track of current stream sequence (possibly 0 if failed to store) + rm.sseq = seq + } + + asm.mu.Unlock() + + // Clear the retain flag for a normal published message. + pp.flags &= ^mqttPubFlagRetain + } +} + +// After a config reload, it is possible that the source of a publish retained +// message is no longer allowed to publish on the given topic. If that is the +// case, the retained message is removed from the map and will no longer be +// sent to (re)starting subscriptions. +// +// Server lock is held on entry +func (s *Server) mqttCheckPubRetainedPerms() { + sm := &s.mqtt.sessmgr + sm.mu.RLock() + defer sm.mu.RUnlock() + + for _, asm := range sm.sessions { + perms := map[string]*perm{} + asm.mu.Lock() + for subject, rm := range asm.retmsgs { + if rm.Source == _EMPTY_ { + continue + } + // Lookup source from global users. + u := s.users[rm.Source] + if u != nil { + p, ok := perms[rm.Source] + if !ok { + p = generatePubPerms(u.Permissions) + perms[rm.Source] = p + } + // If there is permission and no longer allowed to publish in + // the subject, remove the publish retained message from the map. + if p != nil && !pubAllowed(p, subject) { + u = nil + } + } + + // Not present or permissions have changed such that the source can't + // publish on that subject anymore: remove it from the map. + if u == nil { + delete(asm.retmsgs, subject) + asm.rstream.DeleteMsg(rm.sseq) + asm.sl.Remove(rm.sub) + } + } + asm.mu.Unlock() + } +} + +// Helper to generate only pub permissions from a Permissions object +func generatePubPerms(perms *Permissions) *perm { + var p *perm + if perms.Publish.Allow != nil { + p = &perm{} + p.allow = NewSublistWithCache() + } + for _, pubSubject := range perms.Publish.Allow { + sub := &subscription{subject: []byte(pubSubject)} + p.allow.Insert(sub) + } + if len(perms.Publish.Deny) > 0 { + if p == nil { + p = &perm{} + } + p.deny = NewSublistWithCache() + } + for _, pubSubject := range perms.Publish.Deny { + sub := &subscription{subject: []byte(pubSubject)} + p.deny.Insert(sub) + } + return p +} + +// Helper that checks if given `perms` allow to publish on the given `subject` +func pubAllowed(perms *perm, subject string) bool { + allowed := true + if perms.allow != nil { + r := perms.allow.Match(subject) + allowed = len(r.psubs) != 0 + } + // If we have a deny list and are currently allowed, check that as well. + if allowed && perms.deny != nil { + r := perms.deny.Match(subject) + allowed = len(r.psubs) == 0 + } + return allowed +} + +func mqttWritePublish(w *mqttWriter, qos byte, dup, retain bool, subject string, pi uint16, payload []byte) { + flags := qos << 1 + if dup { + flags |= mqttPubFlagDup + } + if retain { + flags |= mqttPubFlagRetain + } + w.WriteByte(mqttPacketPub | flags) + pkLen := 2 + len(subject) + len(payload) + if qos > 0 { + pkLen += 2 + } + w.WriteVarInt(pkLen) + w.WriteString(subject) + if qos > 0 { + w.WriteUint16(pi) + } + w.Write([]byte(payload)) +} + +func (c *client) mqttEnqueuePubAck(pi uint16) { + proto := [4]byte{mqttPacketPubAck, 0x2, 0, 0} + proto[2] = byte(pi >> 8) + proto[3] = byte(pi) + c.mu.Lock() + c.enqueueProto(proto[:4]) + c.mu.Unlock() +} + +func mqttParsePubAck(r *mqttReader, pl int) (uint16, error) { + if err := r.ensurePacketInBuffer(pl); err != nil { + return 0, err + } + pi, err := r.readUint16("packet identifier") + if err != nil { + return 0, err + } + if pi == 0 { + return 0, fmt.Errorf("packet identifier cannot be 0") + } + return pi, nil +} + +func (c *client) mqttProcessPubAck(pi uint16) { + sess := c.mqtt.sess + if sess == nil { + return + } + sess.mu.Lock() + defer sess.mu.Unlock() + if sess.c != c { + return + } + if ack, ok := sess.pending[pi]; ok { + delete(sess.pending, pi) + jsCons := ack.jsCons + if sseqToPi, ok := sess.cpending[jsCons]; ok { + delete(sseqToPi, ack.sseq) + } + jsCons.ackMsg(ack.sseq, ack.dseq, ack.dcount) + if len(sess.pending) == 0 { + sess.ppi = 0 + } + if sess.stalled && len(sess.pending) < int(sess.maxp) { + sess.stalled = false + for _, cons := range sess.cons { + cons.mu.Lock() + cons.forceExpirePending() + cons.mu.Unlock() + } + } + } +} + +// Return the QoS from the given PUBLISH protocol's flags +func mqttGetQoS(flags byte) byte { + return flags & mqttPubFlagQoS >> 1 +} + +func mqttIsRetained(flags byte) bool { + return flags&mqttPubFlagRetain != 0 +} + +////////////////////////////////////////////////////////////////////////////// +// +// SUBSCRIBE related functions +// +////////////////////////////////////////////////////////////////////////////// + +func (c *client) mqttParseSubs(r *mqttReader, b byte, pl int) (uint16, []*mqttFilter, error) { + return c.mqttParseSubsOrUnsubs(r, b, pl, true) +} + +func (c *client) mqttParseSubsOrUnsubs(r *mqttReader, b byte, pl int, sub bool) (uint16, []*mqttFilter, error) { + var expectedFlag byte + var action string + if sub { + expectedFlag = mqttSubscribeFlags + } else { + expectedFlag = mqttUnsubscribeFlags + action = "un" + } + // Spec [MQTT-3.8.1-1], [MQTT-3.10.1-1] + if rf := b & 0xf; rf != expectedFlag { + return 0, nil, fmt.Errorf("wrong %ssubscribe reserved flags: %x", action, rf) + } + if err := r.ensurePacketInBuffer(pl); err != nil { + return 0, nil, err + } + pi, err := r.readUint16("packet identifier") + if err != nil { + return 0, nil, fmt.Errorf("reading packet identifier: %v", err) + } + end := r.pos + (pl - 2) + var filters []*mqttFilter + for r.pos < end { + // Don't make a copy now because, this will happen during conversion + // or when processing the sub. + filter, err := r.readBytes("topic filter", false) + if err != nil { + return 0, nil, err + } + if len(filter) == 0 { + return 0, nil, errors.New("topic filter cannot be empty") + } + // Spec [MQTT-3.8.3-1], [MQTT-3.10.3-1] + if !utf8.Valid(filter) { + return 0, nil, fmt.Errorf("invalid utf8 for topic filter %q", filter) + } + var qos byte + // This won't return an error. We will find out if the subject + // is valid or not when trying to create the subscription. + _, filter, _ = mqttFilterToNATSSubject(filter) + if sub { + qos, err = r.readByte("QoS") + if err != nil { + return 0, nil, err + } + // Spec [MQTT-3-8.3-4]. + if qos > 2 { + return 0, nil, fmt.Errorf("subscribe QoS value must be 0, 1 or 2, got %v", qos) + } + } + filters = append(filters, &mqttFilter{string(filter), qos}) + } + // Spec [MQTT-3.8.3-3], [MQTT-3.10.3-2] + if len(filters) == 0 { + return 0, nil, fmt.Errorf("%ssubscribe protocol must contain at least 1 topic filter", action) + } + return pi, filters, nil +} + +func mqttSubscribeTrace(filters []*mqttFilter) string { + var sep string + trace := "[" + for i, f := range filters { + trace += sep + fmt.Sprintf("%s QoS=%v", f.filter, f.qos) + if i == 0 { + sep = ", " + } + } + trace += "]" + return trace +} + +func mqttDeliverMsgCb(sub *subscription, pc *client, subject, reply string, msg []byte) { + if sub.mqtt == nil { + return + } + + var ppFlags byte + var pQoS byte + var pi uint16 + var dup bool + var retained bool + + // This is the client associated with the subscription. + cc := sub.client + + // This is immutable + sess := cc.mqtt.sess + // We lock to check some of the subscription's fields and if we need to + // keep track of pending acks, etc.. + sess.mu.Lock() + if sess.c != cc { + sess.mu.Unlock() + return + } + + // Check the publisher's kind. If JETSTREAM it means that this is a persisted message + // that is being delivered. + if pc.kind == JETSTREAM { + // If there is no JS consumer attached to this subscription, it means that we are + // dealing with a bare NATS subscription, in which case we simply return to avoid + // duplicate delivery. + if sub.mqtt.jsCons == nil { + sess.mu.Unlock() + return + } + ppFlags = mqttPubQos1 + pQoS = 1 + // This is a QoS1 message for a QoS1 subscription, so get the pi and keep + // track of ack subject. + pi, dup = sess.trackPending(pQoS, reply, sub) + if pi == 0 { + // We have reached max pending, don't send the message now. + // JS will cause a redelivery and if by then the number of pending + // messages has fallen below threshold, the message will be resent. + sess.mu.Unlock() + return + } + // In JS case, we need to use the pc.ca.deliver value as the subject. + subject = string(pc.pa.deliver) + } else if pc.mqtt != nil { + // This is a MQTT publisher... + ppFlags = pc.mqtt.pp.flags + pQoS = mqttGetQoS(ppFlags) + // If the QoS of published message and subscription is 1, then we return here to + // avoid duplicate delivery. The JetStream publisher will handle that case. + if pQoS > 0 && sub.mqtt.qos > 0 { + sess.mu.Unlock() + return + } + retained = mqttIsRetained(ppFlags) + } + // else this is coming from a non MQTT publisher, so Qos 0, no dup nor retain flag, etc.. + sess.mu.Unlock() + + sw := mqttWriter{} + w := &sw + + flags := mqttSerializePublishMsg(w, pi, dup, retained, subject, msg) + + cc.mu.Lock() + if sub.mqtt.prm != nil { + cc.queueOutbound(sub.mqtt.prm.Bytes()) + sub.mqtt.prm = nil + } + cc.queueOutbound(w.Bytes()) + pc.addToPCD(cc) + if cc.trace { + pp := mqttPublish{ + flags: flags, + pi: pi, + subject: []byte(subject), + sz: len(msg), + } + cc.traceOutOp("PUBLISH", []byte(mqttPubTrace(&pp))) + } + cc.mu.Unlock() +} + +// Serializes to the given writer the message for the given subject. +func mqttSerializePublishMsg(w *mqttWriter, pi uint16, dup, retained bool, subject string, msg []byte) byte { + topic := natsSubjectToMQTTTopic(subject) + + // Compute len (will have to add packet id if message is sent as QoS>=1) + pkLen := 2 + len(topic) + len(msg) + + var flags byte + + // Set flags for dup/retained/qos1 + if dup { + flags |= mqttPubFlagDup + } + if retained { + flags |= mqttPubFlagRetain + } + // For now, we have only QoS 1 + if pi > 0 { + pkLen += 2 + flags |= mqttPubQos1 + } + + w.WriteByte(mqttPacketPub | flags) + w.WriteVarInt(pkLen) + w.WriteBytes(topic) + if pi > 0 { + w.WriteUint16(pi) + } + w.Write(msg) + + return flags +} + +// Helper to create an MQTT subscription. +func (c *client) mqttCreateSub(subject, sid string, cb msgHandler, qos byte) *subscription { + sub := c.createSub([]byte(subject), nil, []byte(sid), cb) + sub.mqtt = &mqttSub{qos: qos} + return sub +} + +// Process the list of subscriptions and update the given filter +// with the QoS that has been accepted (or failure). +// +// Spec [MQTT-3.8.4-3] says that if an exact same subscription is +// found, it needs to be replaced with the new one (possibly updating +// the qos) and that the flow of publications must not be interrupted, +// which I read as the replacement cannot be a "remove then add" if there +// is a chance that in between the 2 actions, published messages +// would be "lost" because there would not be any matching subscription. +func (c *client) mqttProcessSubs(filters []*mqttFilter) ([]*subscription, error) { + // Those things are immutable, but since processing subs is not + // really in the fast path, let's get them under the client lock. + c.mu.Lock() + asm := c.mqtt.asm + sess := c.mqtt.sess + clientID := c.mqtt.cp.clientID + trace := c.trace + c.mu.Unlock() + + asm.mu.RLock() + defer asm.mu.RUnlock() + return asm.processSubs(sess, clientID, c, filters, true, trace) +} + +func (c *client) mqttCleanupFailedSub(sub *subscription, jscons *Consumer, jssub *subscription) { + c.mu.Lock() + acc := c.acc + c.mu.Unlock() + + if sub != nil { + c.unsubscribe(acc, sub, true, true) + } + if jssub != nil { + c.unsubscribe(acc, jssub, true, true) + } + if jscons != nil { + jscons.Delete() + } +} + +// When invoked with a QoS of 0, looks for an existing JS durable consumer for +// the given sid and if one is found, delete the JS durable consumer and unsub +// the NATS subscription on the delivery subject. +// With a QoS > 0, creates or update the existing JS durable consumer along with +// its NATS subscription on a delivery subject. +// +// Account session manager lock held on entry. +func (c *client) mqttProcessJSConsumer(sess *mqttSession, stream *Stream, subject, + sid string, qos byte, fromSubProto bool) (*Consumer, *subscription, error) { + + // Check if we are already a JS consumer for this SID. + cons, exists := sess.cons[sid] + if exists { + // If current QoS is 0, it means that we need to delete the existing + // one (that was QoS > 0) + if qos == 0 { + // The JS durable consumer's delivery subject is on a NUID of + // the form: mqttSubPrefix + . It is also used as the sid + // for the NATS subscription, so use that for the lookup. + sub := c.subs[cons.Config().DeliverSubject] + delete(sess.cons, sid) + cons.Delete() + if sub != nil { + c.mu.Lock() + acc := c.acc + c.mu.Unlock() + c.unsubscribe(acc, sub, true, true) + } + return nil, nil, nil + } + // If this is called when processing SUBSCRIBE protocol, then if + // the JS consumer already exists, we are done (it was created + // during the processing of CONNECT). + if fromSubProto { + return nil, nil, nil + } + } + // Here it means we don't have a JS consumer and if we are QoS 0, + // we have nothing to do. + if qos == 0 { + return nil, nil, nil + } + var err error + inbox := mqttSubPrefix + nuid.Next() + if exists { + cons.updateDeliverSubject(inbox) + } else { + durName := nuid.Next() + opts := c.srv.getOpts() + ackWait := opts.MQTT.AckWait + if ackWait == 0 { + ackWait = mqttDefaultAckWait + } + maxAckPending := opts.MQTT.MaxAckPending + if maxAckPending == 0 { + maxAckPending = mqttDefaultMaxAckPending + } + cc := &ConsumerConfig{ + DeliverSubject: inbox, + Durable: durName, + AckPolicy: AckExplicit, + DeliverPolicy: DeliverNew, + FilterSubject: subject, + AckWait: ackWait, + MaxAckPending: int(maxAckPending), + } + cons, err = stream.addConsumerCheckInterest(cc, false) + if err != nil { + c.Errorf("Unable to add JetStream consumer for subscription on %q: err=%v", subject, err) + return nil, nil, err + } + } + sub := c.mqttCreateSub(inbox, inbox, mqttDeliverMsgCb, qos) + sub.mqtt.jsCons = cons + // This is an internal subscription on subject like "$MQTT.sub." that is setup + // for the JS durable's deliver subject. I don't think that there is any need to + // forward this subscription in the cluster/super cluster. + sub, err = c.processSub(sub, true) + if err != nil { + if !exists { + cons.Delete() + } + c.Errorf("Unable to create subscription for JetStream consumer on %q: %v", subject, err) + return nil, nil, err + } + return cons, sub, nil +} + +// Queues the published retained messages for each subscription and signals +// the writeLoop. +func (c *client) mqttSendRetainedMsgsToNewSubs(subs []*subscription) { + c.mu.Lock() + for _, sub := range subs { + if sub.mqtt != nil && sub.mqtt.prm != nil { + c.queueOutbound(sub.mqtt.prm.Bytes()) + sub.mqtt.prm = nil + } + } + c.flushSignal() + c.mu.Unlock() +} + +func (c *client) mqttEnqueueSubAck(pi uint16, filters []*mqttFilter) { + w := &mqttWriter{} + w.WriteByte(mqttPacketSubAck) + // packet length is 2 (for packet identifier) and 1 byte per filter. + w.WriteVarInt(2 + len(filters)) + w.WriteUint16(pi) + for _, f := range filters { + w.WriteByte(f.qos) + } + c.mu.Lock() + c.enqueueProto(w.Bytes()) + c.mu.Unlock() +} + +////////////////////////////////////////////////////////////////////////////// +// +// UNSUBSCRIBE related functions +// +////////////////////////////////////////////////////////////////////////////// + +func (c *client) mqttParseUnsubs(r *mqttReader, b byte, pl int) (uint16, []*mqttFilter, error) { + return c.mqttParseSubsOrUnsubs(r, b, pl, false) +} + +func (c *client) mqttProcessUnsubs(filters []*mqttFilter) error { + // Those things are immutable, but since processing unsubs is not + // really in the fast path, let's get them under the client lock. + c.mu.Lock() + sess := c.mqtt.sess + clientID := c.mqtt.cp.clientID + c.mu.Unlock() + + sess.mu.Lock() + defer sess.mu.Unlock() + if sess.c != c { + return fmt.Errorf("client %q no longer registered with MQTT session", clientID) + } + + removeJSCons := func(sid string) { + if jscons, ok := sess.cons[sid]; ok { + delete(sess.cons, sid) + jscons.Delete() + if seqPis, ok := sess.cpending[jscons]; ok { + delete(sess.cpending, jscons) + for _, pi := range seqPis { + delete(sess.pending, pi) + } + if len(sess.pending) == 0 { + sess.ppi = 0 + } + } + } + } + for _, f := range filters { + sid := f.filter + // Remove JS Consumer if one exists for this sid + removeJSCons(sid) + if err := c.processUnsub([]byte(sid)); err != nil { + c.Errorf("error unsubscribing from %q: %v", sid, err) + } + if mqttNeedSubForLevelUp(sid) { + subject := sid[:len(sid)-2] + sid = subject + mqttMultiLevelSidSuffix + removeJSCons(sid) + if err := c.processUnsub([]byte(sid)); err != nil { + c.Errorf("error unsubscribing from %q: %v", subject, err) + } + } + } + return sess.update(clientID, filters, false) +} + +func (c *client) mqttEnqueueUnsubAck(pi uint16) { + w := &mqttWriter{} + w.WriteByte(mqttPacketUnsubAck) + w.WriteVarInt(2) + w.WriteUint16(pi) + c.mu.Lock() + c.enqueueProto(w.Bytes()) + c.mu.Unlock() +} + +func mqttUnsubscribeTrace(filters []*mqttFilter) string { + var sep string + trace := "[" + for i, f := range filters { + trace += sep + f.filter + if i == 0 { + sep = ", " + } + } + trace += "]" + return trace +} + +////////////////////////////////////////////////////////////////////////////// +// +// PINGREQ/PINGRESP related functions +// +////////////////////////////////////////////////////////////////////////////// + +func (c *client) mqttEnqueuePingResp() { + c.mu.Lock() + c.enqueueProto(mqttPingResponse) + c.mu.Unlock() +} + +////////////////////////////////////////////////////////////////////////////// +// +// Trace functions +// +////////////////////////////////////////////////////////////////////////////// + +func errOrTrace(err error, trace string) []byte { + if err != nil { + return []byte(err.Error()) + } + return []byte(trace) +} + +////////////////////////////////////////////////////////////////////////////// +// +// Subject/Topic conversion functions +// +////////////////////////////////////////////////////////////////////////////// + +// Converts an MQTT Topic Name to a NATS Subject (used by PUBLISH) +// See mqttToNATSSubjectConversion() for details. +func mqttTopicToNATSPubSubject(mt []byte) (bool, []byte, error) { + return mqttToNATSSubjectConversion(mt, false) +} + +// Converts an MQTT Topic Filter to a NATS Subject (used by SUBSCRIBE) +// See mqttToNATSSubjectConversion() for details. +func mqttFilterToNATSSubject(filter []byte) (bool, []byte, error) { + return mqttToNATSSubjectConversion(filter, true) +} + +// Converts an MQTT Topic Name or Filter to a NATS Subject +// In MQTT: +// - a Topic Name does not have wildcard (PUBLISH uses only topic names). +// - a Topic Filter can include wildcards (SUBSCRIBE uses those). +// - '+' and '#' are wildcard characters (single and multiple levels respectively) +// - '/' is the topic level separator. +// +// Conversion that occurs: +// - '/' is replaced with '/.' if it is the first character in mt +// - '/' is replaced with './' if the last or next character in mt is '/' +// For instance, foo//bar would become foo./.bar +// - '/' is replaced with '.' for all other conditions (foo/bar -> foo.bar) +// - '.' and ' ' cause an error to be returned. +// +// If a copy occurred, the returned boolean will indicate this condition. +func mqttToNATSSubjectConversion(mt []byte, wcOk bool) (bool, []byte, error) { + var res = mt + var newSlice bool + + copyTopic := func(pos int) []byte { + if newSlice && cap(res) > pos+2 { + return res + } + newSlice = true + b := make([]byte, len(res)+10) + copy(b, res[:pos]) + res = b + return res + } + + var j int + end := len(mt) - 1 + for i := 0; i < len(mt); i++ { + switch mt[i] { + case mqttTopicLevelSep: + if i == 0 || res[j-1] == btsep { + res = copyTopic(0) + res[j] = mqttTopicLevelSep + j++ + res[j] = btsep + } else if i == end || mt[i+1] == mqttTopicLevelSep { + res = copyTopic(j) + res[j] = btsep + j++ + res[j] = mqttTopicLevelSep + } else { + res[j] = btsep + } + case btsep, ' ': + // As of now, we cannot support '.' or ' ' in the MQTT topic/filter. + return false, nil, fmt.Errorf("characters ' ' and '.' not supported for MQTT topics") + case mqttSingleLevelWC, mqttMultiLevelWC: + if !wcOk { + // Spec [MQTT-3.3.2-2] and [MQTT-4.7.1-1] + // The wildcard characters can be used in Topic Filters, but MUST NOT be used within a Topic Name + return false, nil, fmt.Errorf("wildcards not allowed in publish's topic: %q", mt) + } + if mt[i] == mqttSingleLevelWC { + res[j] = pwc + } else { + res[j] = fwc + } + default: + if newSlice { + res[j] = mt[i] + } + } + j++ + } + if newSlice && res[j-1] == btsep { + res = copyTopic(j) + res[j] = mqttTopicLevelSep + j++ + } + return newSlice, res[:j], nil +} + +// Converts a NATS subject to MQTT topic. This is for publish +// messages only, so there is no checking for wildcards. +// Rules are reversed of mqttToNATSSubjectConversion. +func natsSubjectToMQTTTopic(subject string) []byte { + topic := []byte(subject) + end := len(subject) - 1 + var j int + for i := 0; i < len(subject); i++ { + switch subject[i] { + case mqttTopicLevelSep: + if !(i == 0 && i < end && subject[i+1] == btsep) { + topic[j] = mqttTopicLevelSep + j++ + } + case btsep: + topic[j] = mqttTopicLevelSep + j++ + if i < end && subject[i+1] == mqttTopicLevelSep { + i++ + } + default: + topic[j] = subject[i] + j++ + } + } + return topic[:j] +} + +// Returns true if the subject has more than 1 token and ends with ".>" +func mqttNeedSubForLevelUp(subject string) bool { + if len(subject) < 3 { + return false + } + end := len(subject) + if subject[end-2] == '.' && subject[end-1] == fwc { + return true + } + return false +} + +////////////////////////////////////////////////////////////////////////////// +// +// MQTT Reader functions +// +////////////////////////////////////////////////////////////////////////////// + +func copyBytes(b []byte) []byte { + if b == nil { + return nil + } + cbuf := make([]byte, len(b)) + copy(cbuf, b) + return cbuf +} + +func (r *mqttReader) reset(buf []byte) { + r.buf = buf + r.pos = 0 +} + +func (r *mqttReader) hasMore() bool { + return r.pos != len(r.buf) +} + +func (r *mqttReader) readByte(field string) (byte, error) { + if r.pos == len(r.buf) { + return 0, fmt.Errorf("error reading %s: %v", field, io.EOF) + } + b := r.buf[r.pos] + r.pos++ + return b, nil +} + +func (r *mqttReader) readPacketLen() (int, error) { + m := 1 + v := 0 + for { + var b byte + if r.pos != len(r.buf) { + b = r.buf[r.pos] + r.pos++ + } else { + var buf [1]byte + if _, err := r.reader.Read(buf[:1]); err != nil { + if err == io.EOF { + return 0, io.ErrUnexpectedEOF + } + return 0, fmt.Errorf("error reading packet length: %v", err) + } + b = buf[0] + } + v += int(b&0x7f) * m + if (b & 0x80) == 0 { + return v, nil + } + m *= 0x80 + if m > 0x200000 { + return 0, errors.New("malformed variable int") + } + } +} + +func (r *mqttReader) ensurePacketInBuffer(pl int) error { + rem := len(r.buf) - r.pos + if rem >= pl { + return nil + } + b := make([]byte, pl) + start := copy(b, r.buf[r.pos:]) + for start != pl { + n, err := r.reader.Read(b[start:cap(b)]) + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return fmt.Errorf("error ensuring protocol is loaded: %v", err) + } + start += n + } + r.reset(b) + return nil +} + +func (r *mqttReader) readString(field string) (string, error) { + var s string + bs, err := r.readBytes(field, false) + if err == nil { + s = string(bs) + } + return s, err +} + +func (r *mqttReader) readBytes(field string, cp bool) ([]byte, error) { + luint, err := r.readUint16(field) + if err != nil { + return nil, err + } + l := int(luint) + if l == 0 { + return nil, nil + } + start := r.pos + if start+l > len(r.buf) { + return nil, fmt.Errorf("error reading %s: %v", field, io.ErrUnexpectedEOF) + } + r.pos += l + b := r.buf[start:r.pos] + if cp { + b = copyBytes(b) + } + return b, nil +} + +func (r *mqttReader) readUint16(field string) (uint16, error) { + if len(r.buf)-r.pos < 2 { + return 0, fmt.Errorf("error reading %s: %v", field, io.ErrUnexpectedEOF) + } + start := r.pos + r.pos += 2 + return binary.BigEndian.Uint16(r.buf[start:r.pos]), nil +} + +////////////////////////////////////////////////////////////////////////////// +// +// MQTT Writer functions +// +////////////////////////////////////////////////////////////////////////////// + +func (w *mqttWriter) WriteUint16(i uint16) { + w.WriteByte(byte(i >> 8)) + w.WriteByte(byte(i)) +} + +func (w *mqttWriter) WriteString(s string) { + w.WriteBytes([]byte(s)) +} + +func (w *mqttWriter) WriteBytes(bs []byte) { + w.WriteUint16(uint16(len(bs))) + w.Write(bs) +} + +func (w *mqttWriter) WriteVarInt(value int) { + for { + b := byte(value & 0x7f) + value >>= 7 + if value > 0 { + b |= 0x80 + } + w.WriteByte(b) + if value == 0 { + break + } + } +} diff --git a/server/mqtt_test.go b/server/mqtt_test.go new file mode 100644 index 00000000..d5ad36b3 --- /dev/null +++ b/server/mqtt_test.go @@ -0,0 +1,4111 @@ +// Copyright 2020 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bufio" + "bytes" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/nats-io/jwt/v2" + "github.com/nats-io/nats.go" +) + +type mqttErrorReader struct { + err error +} + +func (r *mqttErrorReader) Read(b []byte) (int, error) { return 0, r.err } +func (r *mqttErrorReader) SetReadDeadline(time.Time) error { return nil } + +func testNewEOFReader() *mqttErrorReader { + return &mqttErrorReader{err: io.EOF} +} + +func TestMQTTReader(t *testing.T) { + r := &mqttReader{} + r.reset([]byte{0, 2, 'a', 'b'}) + bs, err := r.readBytes("", false) + if err != nil { + t.Fatal(err) + } + sbs := string(bs) + if sbs != "ab" { + t.Fatalf(`expected "ab", got %q`, sbs) + } + + r.reset([]byte{0, 2, 'a', 'b'}) + bs, err = r.readBytes("", true) + if err != nil { + t.Fatal(err) + } + bs[0], bs[1] = 'c', 'd' + if bytes.Equal(bs, r.buf[2:]) { + t.Fatal("readBytes should have returned a copy") + } + + r.reset([]byte{'a', 'b'}) + if b, err := r.readByte(""); err != nil || b != 'a' { + t.Fatalf("Error reading byte: b=%v err=%v", b, err) + } + if !r.hasMore() { + t.Fatal("expected to have more, did not") + } + if b, err := r.readByte(""); err != nil || b != 'b' { + t.Fatalf("Error reading byte: b=%v err=%v", b, err) + } + if r.hasMore() { + t.Fatal("expected to not have more") + } + if _, err := r.readByte("test"); err == nil || !strings.Contains(err.Error(), "error reading test") { + t.Fatalf("unexpected error: %v", err) + } + + r.reset([]byte{0, 2, 'a', 'b'}) + if s, err := r.readString(""); err != nil || s != "ab" { + t.Fatalf("Error reading string: s=%q err=%v", s, err) + } + + r.reset([]byte{10}) + if _, err := r.readUint16("uint16"); err == nil || !strings.Contains(err.Error(), "error reading uint16") { + t.Fatalf("unexpected error: %v", err) + } + + r.reset([]byte{1, 2, 3}) + r.reader = testNewEOFReader() + if err := r.ensurePacketInBuffer(10); err == nil || !strings.Contains(err.Error(), "error ensuring protocol is loaded") { + t.Fatalf("unexpected error: %v", err) + } + + r.reset([]byte{0x82, 0xff, 0x3}) + l, err := r.readPacketLen() + if err != nil { + t.Fatal("error getting packet len") + } + if l != 0xff82 { + t.Fatalf("expected length 0xff82 got 0x%x", l) + } + r.reset([]byte{0xff, 0xff, 0xff, 0xff, 0xff}) + if _, err := r.readPacketLen(); err == nil || !strings.Contains(err.Error(), "malformed") { + t.Fatalf("unexpected error: %v", err) + } + r.reset([]byte{0x80}) + if _, err := r.readPacketLen(); err != io.ErrUnexpectedEOF { + t.Fatalf("unexpected error: %v", err) + } + + r.reset([]byte{0x80}) + r.reader = &mqttErrorReader{err: errors.New("on purpose")} + if _, err := r.readPacketLen(); err == nil || !strings.Contains(err.Error(), "on purpose") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestMQTTWriter(t *testing.T) { + w := &mqttWriter{} + w.WriteUint16(1234) + + r := &mqttReader{} + r.reset(w.Bytes()) + if v, err := r.readUint16(""); err != nil || v != 1234 { + t.Fatalf("unexpected value: v=%v err=%v", v, err) + } + + w.Reset() + w.WriteString("test") + r.reset(w.Bytes()) + if len(r.buf) != 6 { + t.Fatalf("Expected 2 bytes size before string, got %v", r.buf) + } + + w.Reset() + w.WriteBytes([]byte("test")) + r.reset(w.Bytes()) + if len(r.buf) != 6 { + t.Fatalf("Expected 2 bytes size before bytes, got %v", r.buf) + } + + ints := []int{ + 0, 1, 127, 128, 16383, 16384, 2097151, 2097152, 268435455, + } + lens := []int{ + 1, 1, 1, 2, 2, 3, 3, 4, 4, + } + + tl := 0 + w.Reset() + for i, v := range ints { + w.WriteVarInt(v) + tl += lens[i] + if tl != w.Len() { + t.Fatalf("expected len %d, got %d", tl, w.Len()) + } + } + + r.reset(w.Bytes()) + for _, v := range ints { + x, _ := r.readPacketLen() + if v != x { + t.Fatalf("expected %d, got %d", v, x) + } + } +} + +func testMQTTDefaultOptions() *Options { + o := DefaultOptions() + o.Cluster.Port = 0 + o.Gateway.Name = "" + o.Gateway.Port = 0 + o.LeafNode.Port = 0 + o.Websocket.Port = 0 + o.MQTT.Host = "127.0.0.1" + o.MQTT.Port = -1 + o.JetStream = true + return o +} + +func testMQTTRunServer(t testing.TB, o *Options) *Server { + o.NoLog = false + s, err := NewServer(o) + if err != nil { + t.Fatalf("Error creating server: %v", err) + } + l := &DummyLogger{} + s.SetLogger(l, true, true) + go s.Start() + if !s.ReadyForConnections(3 * time.Second) { + t.Fatal("Unable to start server") + } + return s +} + +func testMQTTShutdownServer(s *Server) { + if c := s.JetStreamConfig(); c != nil { + dir := strings.TrimSuffix(c.StoreDir, JetStreamStoreDir) + defer os.RemoveAll(dir) + } + s.Shutdown() +} + +func testMQTTDefaultTLSOptions(t *testing.T, verify bool) *Options { + t.Helper() + o := testMQTTDefaultOptions() + tc := &TLSConfigOpts{ + CertFile: "../test/configs/certs/server-cert.pem", + KeyFile: "../test/configs/certs/server-key.pem", + CaFile: "../test/configs/certs/ca.pem", + Verify: verify, + } + var err error + o.MQTT.TLSConfig, err = GenTLSConfig(tc) + o.MQTT.TLSTimeout = 2.0 + if err != nil { + t.Fatalf("Error creating tls config: %v", err) + } + return o +} + +func TestMQTTConfig(t *testing.T) { + conf := createConfFile(t, []byte(` + mqtt { + port: -1 + tls { + cert_file: "./configs/certs/server.pem" + key_file: "./configs/certs/key.pem" + } + } + `)) + defer os.Remove(conf) + s, o := RunServerWithConfig(conf) + defer testMQTTShutdownServer(s) + if o.MQTT.TLSConfig == nil { + t.Fatal("expected TLS config to be set") + } +} + +func TestMQTTValidateOptions(t *testing.T) { + nmqtto := DefaultOptions() + mqtto := testMQTTDefaultOptions() + for _, test := range []struct { + name string + getOpts func() *Options + err string + }{ + {"mqtt disabled", func() *Options { return nmqtto.Clone() }, ""}, + {"mqtt username not allowed if users specified", func() *Options { + o := mqtto.Clone() + o.Users = []*User{&User{Username: "abc", Password: "pwd"}} + o.MQTT.Username = "b" + o.MQTT.Password = "pwd" + return o + }, "mqtt authentication username not compatible with presence of users/nkeys"}, + {"mqtt token not allowed if users specified", func() *Options { + o := mqtto.Clone() + o.Nkeys = []*NkeyUser{&NkeyUser{Nkey: "abc"}} + o.MQTT.Token = "mytoken" + return o + }, "mqtt authentication token not compatible with presence of users/nkeys"}, + {"ack wait should be >=0", func() *Options { + o := mqtto.Clone() + o.MQTT.AckWait = -10 * time.Second + return o + }, "ack wait must be a positive value"}, + } { + t.Run(test.name, func(t *testing.T) { + err := validateMQTTOptions(test.getOpts()) + if test.err == "" && err != nil { + t.Fatalf("Unexpected error: %v", err) + } else if test.err != "" && (err == nil || !strings.Contains(err.Error(), test.err)) { + t.Fatalf("Expected error to contain %q, got %v", test.err, err) + } + }) + } +} + +func TestMQTTParseOptions(t *testing.T) { + for _, test := range []struct { + name string + content string + checkOpt func(*MQTTOpts) error + err string + }{ + // Negative tests + {"bad type", "mqtt: []", nil, "to be a map"}, + {"bad listen", "mqtt: { listen: [] }", nil, "port or host:port"}, + {"bad port", `mqtt: { port: "abc" }`, nil, "not int64"}, + {"bad host", `mqtt: { host: 123 }`, nil, "not string"}, + {"bad tls", `mqtt: { tls: 123 }`, nil, "not map[string]interface {}"}, + {"unknown field", `mqtt: { this_does_not_exist: 123 }`, nil, "unknown"}, + {"ack wait", `mqtt: {ack_wait: abc}`, nil, "invalid duration"}, + {"max ack pending", `mqtt: {max_ack_pending: abc}`, nil, "not int64"}, + {"max ack pending too high", `mqtt: {max_ack_pending: 12345678}`, nil, "invalid value"}, + // Positive tests + {"tls gen fails", ` + mqtt { + tls { + cert_file: "./configs/certs/server.pem" + } + }`, nil, "missing 'key_file'"}, + {"listen port only", `mqtt { listen: 1234 }`, func(o *MQTTOpts) error { + if o.Port != 1234 { + return fmt.Errorf("expected 1234, got %v", o.Port) + } + return nil + }, ""}, + {"listen host and port", `mqtt { listen: "localhost:1234" }`, func(o *MQTTOpts) error { + if o.Host != "localhost" || o.Port != 1234 { + return fmt.Errorf("expected localhost:1234, got %v:%v", o.Host, o.Port) + } + return nil + }, ""}, + {"host", `mqtt { host: "localhost" }`, func(o *MQTTOpts) error { + if o.Host != "localhost" { + return fmt.Errorf("expected localhost, got %v", o.Host) + } + return nil + }, ""}, + {"port", `mqtt { port: 1234 }`, func(o *MQTTOpts) error { + if o.Port != 1234 { + return fmt.Errorf("expected 1234, got %v", o.Port) + } + return nil + }, ""}, + {"tls config", + ` + mqtt { + tls { + cert_file: "./configs/certs/server.pem" + key_file: "./configs/certs/key.pem" + } + } + `, func(o *MQTTOpts) error { + if o.TLSConfig == nil { + return fmt.Errorf("TLSConfig should have been set") + } + return nil + }, ""}, + {"no auth user", + ` + mqtt { + no_auth_user: "noauthuser" + } + `, func(o *MQTTOpts) error { + if o.NoAuthUser != "noauthuser" { + return fmt.Errorf("Invalid NoAuthUser value: %q", o.NoAuthUser) + } + return nil + }, ""}, + {"auth block", + ` + mqtt { + authorization { + user: "mqttuser" + password: "pwd" + token: "token" + timeout: 2.0 + } + } + `, func(o *MQTTOpts) error { + if o.Username != "mqttuser" || o.Password != "pwd" || o.Token != "token" || o.AuthTimeout != 2.0 { + return fmt.Errorf("Invalid auth block: %+v", o) + } + return nil + }, ""}, + {"auth timeout as int", + ` + mqtt { + authorization { + timeout: 2 + } + } + `, func(o *MQTTOpts) error { + if o.AuthTimeout != 2.0 { + return fmt.Errorf("Invalid auth timeout: %v", o.AuthTimeout) + } + return nil + }, ""}, + {"ack wait", + ` + mqtt { + ack_wait: "10s" + } + `, func(o *MQTTOpts) error { + if o.AckWait != 10*time.Second { + return fmt.Errorf("Invalid ack wait: %v", o.AckWait) + } + return nil + }, ""}, + {"max ack pending", + ` + mqtt { + max_ack_pending: 123 + } + `, func(o *MQTTOpts) error { + if o.MaxAckPending != 123 { + return fmt.Errorf("Invalid max ack pending: %v", o.MaxAckPending) + } + return nil + }, ""}, + } { + t.Run(test.name, func(t *testing.T) { + conf := createConfFile(t, []byte(test.content)) + defer os.Remove(conf) + o, err := ProcessConfigFile(conf) + if test.err != _EMPTY_ { + if err == nil || !strings.Contains(err.Error(), test.err) { + t.Fatalf("For content: %q, expected error about %q, got %v", test.content, test.err, err) + } + return + } else if err != nil { + t.Fatalf("Unexpected error for content %q: %v", test.content, err) + } + if err := test.checkOpt(&o.MQTT); err != nil { + t.Fatalf("Incorrect option for content %q: %v", test.content, err.Error()) + } + }) + } +} + +func TestMQTTStart(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc, err := net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)) + if err != nil { + t.Fatalf("Unable to create tcp connection to mqtt port: %v", err) + } + nc.Close() + + // Check failure to start due to port in use + o2 := testMQTTDefaultOptions() + o2.MQTT.Port = o.MQTT.Port + s2, err := NewServer(o2) + if err != nil { + t.Fatalf("Error creating server: %v", err) + } + defer s2.Shutdown() + l := &captureFatalLogger{fatalCh: make(chan string, 1)} + s2.SetLogger(l, false, false) + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + s2.Start() + wg.Done() + }() + + select { + case e := <-l.fatalCh: + if !strings.Contains(e, "Unable to listen for MQTT connections") { + t.Fatalf("Unexpected error: %q", e) + } + case <-time.After(time.Second): + t.Fatal("Should have gotten a fatal error") + } +} + +func TestMQTTTLS(t *testing.T) { + o := testMQTTDefaultTLSOptions(t, false) + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc, err := net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)) + if err != nil { + t.Fatalf("Unable to create tcp connection to mqtt port: %v", err) + } + defer nc.Close() + // Set MaxVersion to TLSv1.2 so that we fail on handshake if there is + // a disagreement between server and client. + tlsc := &tls.Config{ + MaxVersion: tls.VersionTLS12, + InsecureSkipVerify: true, + } + tlsConn := tls.Client(nc, tlsc) + tlsConn.SetDeadline(time.Now().Add(time.Second)) + if err := tlsConn.Handshake(); err != nil { + t.Fatalf("Error doing tls handshake: %v", err) + } + nc.Close() + testMQTTShutdownServer(s) + + // Force client cert verification + o = testMQTTDefaultTLSOptions(t, true) + s = testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc, err = net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)) + if err != nil { + t.Fatalf("Unable to create tcp connection to mqtt port: %v", err) + } + defer nc.Close() + // Set MaxVersion to TLSv1.2 so that we fail on handshake if there is + // a disagreement between server and client. + tlsc = &tls.Config{ + MaxVersion: tls.VersionTLS12, + InsecureSkipVerify: true, + } + tlsConn = tls.Client(nc, tlsc) + tlsConn.SetDeadline(time.Now().Add(time.Second)) + if err := tlsConn.Handshake(); err == nil { + t.Fatal("Handshake expected to fail since client did not provide cert") + } + nc.Close() + + // Add client cert. + nc, err = net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)) + if err != nil { + t.Fatalf("Unable to create tcp connection to mqtt port: %v", err) + } + defer nc.Close() + + tc := &TLSConfigOpts{ + CertFile: "../test/configs/certs/client-cert.pem", + KeyFile: "../test/configs/certs/client-key.pem", + } + tlsc, err = GenTLSConfig(tc) + if err != nil { + t.Fatalf("Error generating tls config: %v", err) + } + tlsc.InsecureSkipVerify = true + tlsConn = tls.Client(nc, tlsc) + tlsConn.SetDeadline(time.Now().Add(time.Second)) + if err := tlsConn.Handshake(); err != nil { + t.Fatalf("Handshake error: %v", err) + } + nc.Close() + testMQTTShutdownServer(s) + + // Lower TLS timeout so low that we should fail + o.MQTT.TLSTimeout = 0.001 + s = testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc, err = net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)) + if err != nil { + t.Fatalf("Unable to create tcp connection to mqtt port: %v", err) + } + defer nc.Close() + time.Sleep(100 * time.Millisecond) + tlsConn = tls.Client(nc, tlsc) + tlsConn.SetDeadline(time.Now().Add(time.Second)) + if err := tlsConn.Handshake(); err == nil { + t.Fatal("Expected failure, did not get one") + } +} + +type mqttConnInfo struct { + clientID string + cleanSess bool + keepAlive uint16 + will *mqttWill + user string + pass string +} + +func testMQTTGetClient(t testing.TB, s *Server, clientID string) *client { + t.Helper() + var mc *client + s.mu.Lock() + for _, c := range s.clients { + c.mu.Lock() + if c.mqtt != nil && c.mqtt.cp != nil && c.mqtt.cp.clientID == clientID { + mc = c + } + c.mu.Unlock() + if mc != nil { + break + } + } + s.mu.Unlock() + if mc == nil { + t.Fatalf("Did not find client %q", clientID) + } + return mc +} + +func testMQTTRead(c net.Conn) ([]byte, error) { + var buf [512]byte + // Make sure that test does not block + c.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, err := c.Read(buf[:]) + if err != nil { + return nil, err + } + c.SetReadDeadline(time.Time{}) + return copyBytes(buf[:n]), nil +} + +func testMQTTWrite(c net.Conn, buf []byte) (int, error) { + c.SetWriteDeadline(time.Now().Add(2 * time.Second)) + n, err := c.Write(buf) + c.SetWriteDeadline(time.Time{}) + return n, err +} + +func testMQTTConnect(t testing.TB, ci *mqttConnInfo, host string, port int) (net.Conn, *mqttReader) { + t.Helper() + + addr := fmt.Sprintf("%s:%d", host, port) + c, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Error creating mqtt connection: %v", err) + } + + proto := mqttCreateConnectProto(ci) + if _, err := testMQTTWrite(c, proto); err != nil { + t.Fatalf("Error writing connect: %v", err) + } + + buf, err := testMQTTRead(c) + if err != nil { + t.Fatalf("Error reading: %v", err) + } + br := &mqttReader{reader: c} + br.reset(buf) + + return c, br +} + +func mqttCreateConnectProto(ci *mqttConnInfo) []byte { + flags := byte(0) + if ci.cleanSess { + flags |= mqttConnFlagCleanSession + } + if ci.will != nil { + flags |= mqttConnFlagWillFlag | (ci.will.qos << 3) + if ci.will.retain { + flags |= mqttConnFlagWillRetain + } + } + if ci.user != _EMPTY_ { + flags |= mqttConnFlagUsernameFlag + } + if ci.pass != _EMPTY_ { + flags |= mqttConnFlagPasswordFlag + } + + pkLen := 2 + len(mqttProtoName) + + 1 + // proto level + 1 + // flags + 2 + // keepAlive + 2 + len(ci.clientID) + + if ci.will != nil { + pkLen += 2 + len(ci.will.topic) + pkLen += 2 + len(ci.will.message) + } + if ci.user != _EMPTY_ { + pkLen += 2 + len(ci.user) + } + if ci.pass != _EMPTY_ { + pkLen += 2 + len(ci.pass) + } + + w := &mqttWriter{} + w.WriteByte(mqttPacketConnect) + w.WriteVarInt(pkLen) + w.WriteString(string(mqttProtoName)) + w.WriteByte(0x4) + w.WriteByte(flags) + w.WriteUint16(ci.keepAlive) + w.WriteString(ci.clientID) + if ci.will != nil { + w.WriteBytes(ci.will.topic) + w.WriteBytes(ci.will.message) + } + if ci.user != _EMPTY_ { + w.WriteString(ci.user) + } + if ci.pass != _EMPTY_ { + w.WriteBytes([]byte(ci.pass)) + } + return w.Bytes() +} + +func testMQTTCheckConnAck(t testing.TB, r *mqttReader, rc byte, sessionPresent bool) { + t.Helper() + r.reader.SetReadDeadline(time.Now().Add(2 * time.Second)) + if err := r.ensurePacketInBuffer(4); err != nil { + t.Fatalf("Error ensuring packet in buffer: %v", err) + } + r.reader.SetReadDeadline(time.Time{}) + b, err := r.readByte("connack packet type") + if err != nil { + t.Fatalf("Error reading packet type: %v", err) + } + pt := b & mqttPacketMask + if pt != mqttPacketConnectAck { + t.Fatalf("Expected ConnAck (%x), got %x", mqttPacketConnectAck, pt) + } + pl, err := r.readByte("connack packet len") + if err != nil { + t.Fatalf("Error reading packet length: %v", err) + } + if pl != 2 { + t.Fatalf("ConnAck packet length should be 2, got %v", pl) + } + caf, err := r.readByte("connack flags") + if err != nil { + t.Fatalf("Error reading packet length: %v", err) + } + if caf&0xfe != 0 { + t.Fatalf("ConnAck flag bits 7-1 should all be 0, got %x", caf>>1) + } + if sp := caf == 1; sp != sessionPresent { + t.Fatalf("Expected session present flag=%v got %v", sessionPresent, sp) + } + carc, err := r.readByte("connack return code") + if err != nil { + t.Fatalf("Error reading returned code: %v", err) + } + if carc != rc { + t.Fatalf("Expected return code to be %v, got %v", rc, carc) + } +} + +func testMQTTEnableJSForAccount(t *testing.T, s *Server, accName string) { + t.Helper() + acc, err := s.LookupAccount(accName) + if err != nil { + t.Fatalf("Error looking up account: %v", err) + } + limits := &JetStreamAccountLimits{ + MaxConsumers: -1, + MaxStreams: -1, + MaxMemory: 1024 * 1024, + } + if err := acc.EnableJetStream(limits); err != nil { + t.Fatalf("Error enabling JS: %v", err) + } +} + +func TestMQTTTLSVerifyAndMap(t *testing.T) { + accName := "MyAccount" + acc := NewAccount(accName) + certUserName := "CN=example.com,OU=NATS.io" + users := []*User{&User{Username: certUserName, Account: acc}} + + for _, test := range []struct { + name string + filtering bool + provideCert bool + }{ + {"no filtering, client provides cert", false, true}, + {"no filtering, client does not provide cert", false, false}, + {"filtering, client provides cert", true, true}, + {"filtering, client does not provide cert", true, false}, + } { + t.Run(test.name, func(t *testing.T) { + o := testMQTTDefaultOptions() + o.Host = "localhost" + o.Accounts = []*Account{acc} + o.Users = users + if test.filtering { + o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeMqtt}) + } + tc := &TLSConfigOpts{ + CertFile: "../test/configs/certs/tlsauth/server.pem", + KeyFile: "../test/configs/certs/tlsauth/server-key.pem", + CaFile: "../test/configs/certs/tlsauth/ca.pem", + Verify: true, + } + tlsc, err := GenTLSConfig(tc) + if err != nil { + t.Fatalf("Error creating tls config: %v", err) + } + o.MQTT.TLSConfig = tlsc + o.MQTT.TLSTimeout = 2.0 + o.MQTT.TLSMap = true + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + testMQTTEnableJSForAccount(t, s, accName) + + addr := fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port) + mc, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Error creating ws connection: %v", err) + } + defer mc.Close() + tlscc := &tls.Config{} + if test.provideCert { + tc := &TLSConfigOpts{ + CertFile: "../test/configs/certs/tlsauth/client.pem", + KeyFile: "../test/configs/certs/tlsauth/client-key.pem", + } + var err error + tlscc, err = GenTLSConfig(tc) + if err != nil { + t.Fatalf("Error generating tls config: %v", err) + } + } + tlscc.InsecureSkipVerify = true + if test.provideCert { + tlscc.MinVersion = tls.VersionTLS13 + } + mc = tls.Client(mc, tlscc) + if err := mc.(*tls.Conn).Handshake(); err != nil { + t.Fatalf("Error during handshake: %v", err) + } + + ci := &mqttConnInfo{cleanSess: true} + proto := mqttCreateConnectProto(ci) + if _, err := testMQTTWrite(mc, proto); err != nil { + t.Fatalf("Error sending proto: %v", err) + } + buf, err := testMQTTRead(mc) + if !test.provideCert { + if err == nil { + t.Fatal("Expected error, did not get one") + } else if !strings.Contains(err.Error(), "bad certificate") { + t.Fatalf("Unexpected error: %v", err) + } + return + } + if err != nil { + t.Fatalf("Error reading: %v", err) + } + r := &mqttReader{reader: mc} + r.reset(buf) + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + var c *client + s.mu.Lock() + for _, sc := range s.clients { + sc.mu.Lock() + if sc.mqtt != nil { + c = sc + } + sc.mu.Unlock() + if c != nil { + break + } + } + s.mu.Unlock() + if c == nil { + t.Fatal("Client not found") + } + + var uname string + var accname string + c.mu.Lock() + uname = c.opts.Username + if c.acc != nil { + accname = c.acc.GetName() + } + c.mu.Unlock() + if uname != certUserName { + t.Fatalf("Expected username %q, got %q", certUserName, uname) + } + if accname != accName { + t.Fatalf("Expected account %q, got %v", accName, accname) + } + }) + } +} + +func TestMQTTBasicAuth(t *testing.T) { + for _, test := range []struct { + name string + opts func() *Options + user string + pass string + rc byte + }{ + { + "top level auth, no override, wrong u/p", + func() *Options { + o := testMQTTDefaultOptions() + o.Username = "normal" + o.Password = "client" + return o + }, + "mqtt", "client", mqttConnAckRCNotAuthorized, + }, + { + "top level auth, no override, correct u/p", + func() *Options { + o := testMQTTDefaultOptions() + o.Username = "normal" + o.Password = "client" + return o + }, + "normal", "client", mqttConnAckRCConnectionAccepted, + }, + { + "no top level auth, mqtt auth, wrong u/p", + func() *Options { + o := testMQTTDefaultOptions() + o.MQTT.Username = "mqtt" + o.MQTT.Password = "client" + return o + }, + "normal", "client", mqttConnAckRCNotAuthorized, + }, + { + "no top level auth, mqtt auth, correct u/p", + func() *Options { + o := testMQTTDefaultOptions() + o.MQTT.Username = "mqtt" + o.MQTT.Password = "client" + return o + }, + "mqtt", "client", mqttConnAckRCConnectionAccepted, + }, + { + "top level auth, mqtt override, wrong u/p", + func() *Options { + o := testMQTTDefaultOptions() + o.Username = "normal" + o.Password = "client" + o.MQTT.Username = "mqtt" + o.MQTT.Password = "client" + return o + }, + "normal", "client", mqttConnAckRCNotAuthorized, + }, + { + "top level auth, mqtt override, correct u/p", + func() *Options { + o := testMQTTDefaultOptions() + o.Username = "normal" + o.Password = "client" + o.MQTT.Username = "mqtt" + o.MQTT.Password = "client" + return o + }, + "mqtt", "client", mqttConnAckRCConnectionAccepted, + }, + } { + t.Run(test.name, func(t *testing.T) { + o := test.opts() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{ + cleanSess: true, + user: test.user, + pass: test.pass, + } + mc, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, test.rc, false) + }) + } +} + +func TestMQTTAuthTimeout(t *testing.T) { + for _, test := range []struct { + name string + at float64 + mat float64 + ok bool + }{ + {"use top-level auth timeout", 0.5, 0.0, true}, + {"use mqtt auth timeout", 0.5, 0.05, false}, + } { + t.Run(test.name, func(t *testing.T) { + o := testMQTTDefaultOptions() + o.AuthTimeout = test.at + o.MQTT.Username = "mqtt" + o.MQTT.Password = "client" + o.MQTT.AuthTimeout = test.mat + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mc, err := net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)) + if err != nil { + t.Fatalf("Error connecting: %v", err) + } + defer mc.Close() + + time.Sleep(100 * time.Millisecond) + + ci := &mqttConnInfo{ + cleanSess: true, + user: "mqtt", + pass: "client", + } + proto := mqttCreateConnectProto(ci) + if _, err := testMQTTWrite(mc, proto); err != nil { + if test.ok { + t.Fatalf("Error sending connect: %v", err) + } + // else it is ok since we got disconnected due to auth timeout + return + } + buf, err := testMQTTRead(mc) + if err != nil { + if test.ok { + t.Fatalf("Error reading: %v", err) + } + // else it is ok since we got disconnected due to auth timeout + return + } + r := &mqttReader{reader: mc} + r.reset(buf) + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + time.Sleep(500 * time.Millisecond) + testMQTTPublish(t, mc, r, 1, false, false, "foo", 1, []byte("msg")) + }) + } +} + +func TestMQTTTokenAuth(t *testing.T) { + for _, test := range []struct { + name string + opts func() *Options + token string + rc byte + }{ + { + "top level auth, no override, wrong token", + func() *Options { + o := testMQTTDefaultOptions() + o.Authorization = "goodtoken" + return o + }, + "badtoken", mqttConnAckRCNotAuthorized, + }, + { + "top level auth, no override, correct token", + func() *Options { + o := testMQTTDefaultOptions() + o.Authorization = "goodtoken" + return o + }, + "goodtoken", mqttConnAckRCConnectionAccepted, + }, + { + "no top level auth, mqtt auth, wrong token", + func() *Options { + o := testMQTTDefaultOptions() + o.MQTT.Token = "goodtoken" + return o + }, + "badtoken", mqttConnAckRCNotAuthorized, + }, + { + "no top level auth, mqtt auth, correct token", + func() *Options { + o := testMQTTDefaultOptions() + o.MQTT.Token = "goodtoken" + return o + }, + "goodtoken", mqttConnAckRCConnectionAccepted, + }, + { + "top level auth, mqtt override, wrong token", + func() *Options { + o := testMQTTDefaultOptions() + o.Authorization = "clienttoken" + o.MQTT.Token = "mqtttoken" + return o + }, + "clienttoken", mqttConnAckRCNotAuthorized, + }, + { + "top level auth, mqtt override, correct token", + func() *Options { + o := testMQTTDefaultOptions() + o.Authorization = "clienttoken" + o.MQTT.Token = "mqtttoken" + return o + }, + "mqtttoken", mqttConnAckRCConnectionAccepted, + }, + } { + t.Run(test.name, func(t *testing.T) { + o := test.opts() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{ + cleanSess: true, + user: "ignore_use_token", + pass: test.token, + } + mc, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, test.rc, false) + }) + } +} + +func TestMQTTUsersAuth(t *testing.T) { + users := []*User{&User{Username: "user", Password: "pwd"}} + for _, test := range []struct { + name string + opts func() *Options + user string + pass string + rc byte + }{ + { + "no filtering, wrong user", + func() *Options { + o := testMQTTDefaultOptions() + o.Users = users + return o + }, + "wronguser", "pwd", mqttConnAckRCNotAuthorized, + }, + { + "no filtering, correct user", + func() *Options { + o := testMQTTDefaultOptions() + o.Users = users + return o + }, + "user", "pwd", mqttConnAckRCConnectionAccepted, + }, + { + "filtering, user not allowed", + func() *Options { + o := testMQTTDefaultOptions() + o.Users = users + // Only allowed for regular clients + o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard}) + return o + }, + "user", "pwd", mqttConnAckRCNotAuthorized, + }, + { + "filtering, user allowed", + func() *Options { + o := testMQTTDefaultOptions() + o.Users = users + o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeMqtt}) + return o + }, + "user", "pwd", mqttConnAckRCConnectionAccepted, + }, + { + "filtering, wrong password", + func() *Options { + o := testMQTTDefaultOptions() + o.Users = users + o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeMqtt}) + return o + }, + "user", "badpassword", mqttConnAckRCNotAuthorized, + }, + } { + t.Run(test.name, func(t *testing.T) { + o := test.opts() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{ + cleanSess: true, + user: test.user, + pass: test.pass, + } + mc, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, test.rc, false) + }) + } +} + +func TestMQTTNoAuthUserValidation(t *testing.T) { + o := testMQTTDefaultOptions() + o.Users = []*User{&User{Username: "user", Password: "pwd"}} + // Should fail because it is not part of o.Users. + o.MQTT.NoAuthUser = "notfound" + if _, err := NewServer(o); err == nil || !strings.Contains(err.Error(), "not present as user") { + t.Fatalf("Expected error saying not present as user, got %v", err) + } + + // Set a valid no auth user for global options, but still should fail because + // of o.MQTT.NoAuthUser + o.NoAuthUser = "user" + o.MQTT.NoAuthUser = "notfound" + if _, err := NewServer(o); err == nil || !strings.Contains(err.Error(), "not present as user") { + t.Fatalf("Expected error saying not present as user, got %v", err) + } +} + +func TestMQTTNoAuthUser(t *testing.T) { + for _, test := range []struct { + name string + override bool + useAuth bool + expectedUser string + expectedAcc string + }{ + {"no override, no user provided", false, false, "noauth", "normal"}, + {"no override, user povided", false, true, "user", "normal"}, + {"override, no user provided", true, false, "mqttnoauth", "mqtt"}, + {"override, user provided", true, true, "mqttuser", "mqtt"}, + } { + t.Run(test.name, func(t *testing.T) { + o := testMQTTDefaultOptions() + normalAcc := NewAccount("normal") + mqttAcc := NewAccount("mqtt") + o.Accounts = []*Account{normalAcc, mqttAcc} + o.Users = []*User{ + &User{Username: "noauth", Password: "pwd", Account: normalAcc}, + &User{Username: "user", Password: "pwd", Account: normalAcc}, + &User{Username: "mqttnoauth", Password: "pwd", Account: mqttAcc}, + &User{Username: "mqttuser", Password: "pwd", Account: mqttAcc}, + } + o.NoAuthUser = "noauth" + if test.override { + o.MQTT.NoAuthUser = "mqttnoauth" + } + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + testMQTTEnableJSForAccount(t, s, "normal") + testMQTTEnableJSForAccount(t, s, "mqtt") + + ci := &mqttConnInfo{clientID: "mqtt", cleanSess: true} + if test.useAuth { + ci.user = test.expectedUser + ci.pass = "pwd" + } + mc, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + c := testMQTTGetClient(t, s, "mqtt") + c.mu.Lock() + uname := c.opts.Username + aname := c.acc.GetName() + c.mu.Unlock() + if uname != test.expectedUser { + t.Fatalf("Expected selected user to be %q, got %q", test.expectedUser, uname) + } + if aname != test.expectedAcc { + t.Fatalf("Expected selected account to be %q, got %q", test.expectedAcc, aname) + } + }) + } +} + +func TestMQTTConnectNotFirstProto(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + c, err := net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)) + if err != nil { + t.Fatalf("Error on dial: %v", err) + } + defer c.Close() + + w := &mqttWriter{} + mqttWritePublish(w, 0, false, false, "foo", 0, []byte("hello")) + if _, err := testMQTTWrite(c, w.Bytes()); err != nil { + t.Fatalf("Error publishing: %v", err) + } + testMQTTExpectDisconnect(t, c) +} + +func TestMQTTSecondConnect(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + proto := mqttCreateConnectProto(&mqttConnInfo{cleanSess: true}) + if _, err := testMQTTWrite(mc, proto); err != nil { + t.Fatalf("Error writing connect: %v", err) + } + testMQTTExpectDisconnect(t, mc) +} + +func TestMQTTParseConnect(t *testing.T) { + eofr := testNewEOFReader() + for _, test := range []struct { + name string + proto []byte + pl int + reader mqttIOReader + err string + }{ + {"packet in buffer error", nil, 10, eofr, "error ensuring protocol is loaded"}, + {"bad proto name", []byte{0, 4, 'B', 'A', 'D'}, 5, nil, "protocol name"}, + {"invalid proto name", []byte{0, 3, 'B', 'A', 'D'}, 5, nil, "expected connect packet with protocol name"}, + {"old proto not supported", []byte{0, 6, 'M', 'Q', 'I', 's', 'd', 'p'}, 8, nil, "older protocol"}, + {"error on protocol level", []byte{0, 4, 'M', 'Q', 'T', 'T'}, 6, eofr, "protocol level"}, + {"unacceptable protocol version", []byte{0, 4, 'M', 'Q', 'T', 'T', 10}, 7, nil, "unacceptable protocol version"}, + {"error on flags", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel}, 7, eofr, "flags"}, + {"reserved flag", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 1}, 8, nil, "connect flags reserved bit not set to 0"}, + {"will qos without will flag", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 1 << 3}, 8, nil, "if Will flag is set to 0, Will QoS must be 0 too"}, + {"will retain without will flag", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 1 << 5}, 8, nil, "if Will flag is set to 0, Will Retain flag must be 0 too"}, + {"will qos", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 3<<3 | 1<<2}, 8, nil, "if Will flag is set to 1, Will QoS can be 0, 1 or 2"}, + {"no user but password", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagPasswordFlag}, 8, nil, "password flag set but username flag is not"}, + {"missing keep alive", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 0}, 8, nil, "keep alive"}, + {"missing client ID", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 0, 0, 1}, 10, nil, "client ID"}, + {"empty client ID", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 0, 0, 1, 0, 0}, 12, nil, "when client ID is empty, clean session flag must be set to 1"}, + {"invalid utf8 client ID", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 0, 0, 1, 0, 1, 241}, 13, nil, "invalid utf8 for client ID"}, + {"missing will topic", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagWillFlag | mqttConnFlagCleanSession, 0, 0, 0, 0}, 12, nil, "Will topic"}, + {"empty will topic", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagWillFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 0}, 14, nil, "empty Will topic not allowed"}, + {"invalid utf8 will topic", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagWillFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 1, 241}, 15, nil, "invalide utf8 for Will topic"}, + {"invalid wildcard will topic", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagWillFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 1, '#'}, 15, nil, "wildcards not allowed"}, + {"error on will message", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagWillFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 1, 'a', 0, 3}, 17, eofr, "Will message"}, + {"error on username", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagUsernameFlag | mqttConnFlagCleanSession, 0, 0, 0, 0}, 12, eofr, "user name"}, + {"empty username", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagUsernameFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 0}, 14, nil, "empty user name not allowed"}, + {"invalid utf8 username", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagUsernameFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 1, 241}, 15, nil, "invalid utf8 for user name"}, + {"error on password", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagUsernameFlag | mqttConnFlagPasswordFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 1, 'a'}, 15, eofr, "password"}, + } { + t.Run(test.name, func(t *testing.T) { + r := &mqttReader{reader: test.reader} + r.reset(test.proto) + mqtt := &mqtt{r: r} + c := &client{mqtt: mqtt} + if _, _, err := c.mqttParseConnect(r, test.pl); err == nil || !strings.Contains(err.Error(), test.err) { + t.Fatalf("Expected error %q, got %v", test.err, err) + } + }) + } +} + +func TestMQTTConnectFailsOnParse(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + addr := fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port) + c, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Error creating mqtt connection: %v", err) + } + + pkLen := 2 + len(mqttProtoName) + + 1 + // proto level + 1 + // flags + 2 + // keepAlive + 2 + len("mqtt") + + w := &mqttWriter{} + w.WriteByte(mqttPacketConnect) + w.WriteVarInt(pkLen) + w.WriteString(string(mqttProtoName)) + w.WriteByte(0x7) + w.WriteByte(mqttConnFlagCleanSession) + w.WriteUint16(0) + w.WriteString("mqtt") + c.Write(w.Bytes()) + + buf, err := testMQTTRead(c) + if err != nil { + t.Fatalf("Error reading: %v", err) + } + r := &mqttReader{reader: c} + r.reset(buf) + testMQTTCheckConnAck(t, r, mqttConnAckRCUnacceptableProtocolVersion, false) +} + +func TestMQTTConnKeepAlive(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true, keepAlive: 1}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, mc, r, 0, false, false, "foo", 0, []byte("msg")) + + time.Sleep(2 * time.Second) + testMQTTExpectDisconnect(t, mc) +} + +func TestMQTTTopicAndSubjectConversion(t *testing.T) { + for _, test := range []struct { + name string + mqttTopic string + natsSubject string + err string + }{ + {"/", "/", "/./", ""}, + {"//", "//", "/././", ""}, + {"///", "///", "/./././", ""}, + {"////", "////", "/././././", ""}, + {"foo", "foo", "foo", ""}, + {"/foo", "/foo", "/.foo", ""}, + {"//foo", "//foo", "/./.foo", ""}, + {"///foo", "///foo", "/././.foo", ""}, + {"///foo/", "///foo/", "/././.foo./", ""}, + {"///foo//", "///foo//", "/././.foo././", ""}, + {"///foo///", "///foo///", "/././.foo./././", ""}, + {"foo/bar", "foo/bar", "foo.bar", ""}, + {"/foo/bar", "/foo/bar", "/.foo.bar", ""}, + {"/foo/bar/", "/foo/bar/", "/.foo.bar./", ""}, + {"foo/bar/baz", "foo/bar/baz", "foo.bar.baz", ""}, + {"/foo/bar/baz", "/foo/bar/baz", "/.foo.bar.baz", ""}, + {"/foo/bar/baz/", "/foo/bar/baz/", "/.foo.bar.baz./", ""}, + {"bar", "bar/", "bar./", ""}, + {"bar//", "bar//", "bar././", ""}, + {"bar///", "bar///", "bar./././", ""}, + {"foo//bar", "foo//bar", "foo./.bar", ""}, + {"foo///bar", "foo///bar", "foo././.bar", ""}, + {"foo////bar", "foo////bar", "foo./././.bar", ""}, + // These should produce errors + {"foo/+", "foo/+", "", "wildcards not allowed in publish"}, + {"foo/#", "foo/#", "", "wildcards not allowed in publish"}, + {"foo bar", "foo bar", "", "not supported"}, + {"foo.bar", "foo.bar", "", "not supported"}, + } { + t.Run(test.name, func(t *testing.T) { + _, res, err := mqttTopicToNATSPubSubject([]byte(test.mqttTopic)) + if test.err != _EMPTY_ { + if err == nil || !strings.Contains(err.Error(), test.err) { + t.Fatalf("Expected error %q, got %q", test.err, err.Error()) + } + return + } + toNATS := string(res) + if toNATS != test.natsSubject { + t.Fatalf("Expected subject %q got %q", test.natsSubject, toNATS) + } + + res = natsSubjectToMQTTTopic(toNATS) + backToMQTT := string(res) + if backToMQTT != test.mqttTopic { + t.Fatalf("Expected topic %q got %q (NATS conversion was %q)", test.mqttTopic, backToMQTT, toNATS) + } + }) + } +} + +func TestMQTTFilterConversion(t *testing.T) { + // Similar to TopicConversion test except that wildcards are OK here. + // So testing only those. + for _, test := range []struct { + name string + mqttTopic string + natsSubject string + }{ + {"single level wildcard", "+", "*"}, + {"single level wildcard", "/+", "/.*"}, + {"single level wildcard", "+/", "*./"}, + {"single level wildcard", "/+/", "/.*./"}, + {"single level wildcard", "foo/+", "foo.*"}, + {"single level wildcard", "foo/+/", "foo.*./"}, + {"single level wildcard", "foo/+/bar", "foo.*.bar"}, + {"single level wildcard", "foo/+/+", "foo.*.*"}, + {"single level wildcard", "foo/+/+/", "foo.*.*./"}, + {"single level wildcard", "foo/+/+/bar", "foo.*.*.bar"}, + + {"multi level wildcard", "#", ">"}, + {"multi level wildcard", "/#", "/.>"}, + {"multi level wildcard", "/foo/#", "/.foo.>"}, + {"multi level wildcard", "foo/#", "foo.>"}, + {"multi level wildcard", "foo/bar/#", "foo.bar.>"}, + } { + t.Run(test.name, func(t *testing.T) { + _, res, err := mqttFilterToNATSSubject([]byte(test.mqttTopic)) + if err != nil { + t.Fatalf("Error: %v", err) + } + if string(res) != test.natsSubject { + t.Fatalf("Expected subject %q got %q", test.natsSubject, res) + } + }) + } +} + +func testMQTTReaderHasAtLeastOne(t testing.TB, r *mqttReader) { + t.Helper() + r.reader.SetReadDeadline(time.Now().Add(2 * time.Second)) + if err := r.ensurePacketInBuffer(1); err != nil { + t.Fatal(err) + } + r.reader.SetReadDeadline(time.Time{}) +} + +func TestMQTTParseSub(t *testing.T) { + eofr := testNewEOFReader() + for _, test := range []struct { + name string + proto []byte + b byte + pl int + reader mqttIOReader + err string + }{ + {"reserved flag", nil, 3, 0, nil, "wrong subscribe reserved flags"}, + {"ensure packet loaded", []byte{1, 2}, mqttSubscribeFlags, 10, eofr, "error ensuring protocol is loaded"}, + {"error reading packet id", []byte{1}, mqttSubscribeFlags, 1, eofr, "reading packet identifier"}, + {"missing filters", []byte{0, 1}, mqttSubscribeFlags, 2, nil, "subscribe protocol must contain at least 1 topic filter"}, + {"error reading topic", []byte{0, 1, 0, 2, 'a'}, mqttSubscribeFlags, 5, eofr, "topic filter"}, + {"empty topic", []byte{0, 1, 0, 0}, mqttSubscribeFlags, 4, nil, "topic filter cannot be empty"}, + {"invalid utf8 topic", []byte{0, 1, 0, 1, 241}, mqttSubscribeFlags, 5, nil, "invalid utf8 for topic filter"}, + {"missing qos", []byte{0, 1, 0, 1, 'a'}, mqttSubscribeFlags, 5, nil, "QoS"}, + {"invalid qos", []byte{0, 1, 0, 1, 'a', 3}, mqttSubscribeFlags, 6, nil, "subscribe QoS value must be 0, 1 or 2"}, + } { + t.Run(test.name, func(t *testing.T) { + r := &mqttReader{reader: test.reader} + r.reset(test.proto) + mqtt := &mqtt{r: r} + c := &client{mqtt: mqtt} + if _, _, err := c.mqttParseSubsOrUnsubs(r, test.b, test.pl, true); err == nil || !strings.Contains(err.Error(), test.err) { + t.Fatalf("Expected error %q, got %v", test.err, err) + } + }) + } +} + +func testMQTTSub(t testing.TB, pi uint16, c net.Conn, r *mqttReader, filters []*mqttFilter, expected []byte) { + t.Helper() + w := &mqttWriter{} + pkLen := 2 // for pi + for i := 0; i < len(filters); i++ { + f := filters[i] + pkLen += 2 + len(f.filter) + 1 + } + w.WriteByte(mqttPacketSub | mqttSubscribeFlags) + w.WriteVarInt(pkLen) + w.WriteUint16(pi) + for i := 0; i < len(filters); i++ { + f := filters[i] + w.WriteBytes([]byte(f.filter)) + w.WriteByte(f.qos) + } + if _, err := testMQTTWrite(c, w.Bytes()); err != nil { + t.Fatalf("Error writing SUBSCRIBE protocol: %v", err) + } + // Make sure we have at least 1 byte in buffer (if not will read) + testMQTTReaderHasAtLeastOne(t, r) + // Parse SUBACK + b, err := r.readByte("packet type") + if err != nil { + t.Fatal(err) + } + if pt := b & mqttPacketMask; pt != mqttPacketSubAck { + t.Fatalf("Expected SUBACK packet %x, got %x", mqttPacketSubAck, pt) + } + pl, err := r.readPacketLen() + if err != nil { + t.Fatal(err) + } + if err := r.ensurePacketInBuffer(pl); err != nil { + t.Fatal(err) + } + rpi, err := r.readUint16("packet identifier") + if err != nil || rpi != pi { + t.Fatalf("Error with packet identifier expected=%v got: %v err=%v", pi, rpi, err) + } + for i, rem := 0, pl-2; rem > 0; rem-- { + qos, err := r.readByte("filter qos") + if err != nil { + t.Fatal(err) + } + if qos != expected[i] { + t.Fatalf("For topic filter %q expected qos of %v, got %v", + filters[i].filter, expected[i], qos) + } + i++ + } +} + +func TestMQTTSubAck(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + subs := []*mqttFilter{ + {filter: "foo", qos: 0}, + {filter: "bar", qos: 1}, + {filter: "baz", qos: 2}, // Since we don't support, we should receive a result of 1 + {filter: "foo/#/bar", qos: 0}, // Invalid sub, so we should receive a result of mqttSubAckFailure + } + expected := []byte{ + 0, + 1, + 1, + mqttSubAckFailure, + } + testMQTTSub(t, 1, mc, r, subs, expected) +} + +func testMQTTFlush(t testing.TB, c net.Conn, bw *bufio.Writer, r *mqttReader) { + t.Helper() + w := &mqttWriter{} + w.WriteByte(mqttPacketPing) + w.WriteByte(0) + if bw != nil { + bw.Write(w.Bytes()) + bw.Flush() + } else { + c.Write(w.Bytes()) + } + r.ensurePacketInBuffer(2) + ab, err := r.readByte("pingresp") + if err != nil { + t.Fatalf("Error reading ping response: %v", err) + } + if pt := ab & mqttPacketMask; pt != mqttPacketPingResp { + t.Fatalf("Expected ping response got %x", pt) + } + l, err := r.readPacketLen() + if err != nil { + t.Fatal(err) + } + if l != 0 { + t.Fatalf("Expected PINGRESP length to be 0, got %v", l) + } +} + +func testMQTTExpectNothing(t testing.TB, r *mqttReader) { + t.Helper() + r.reader.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + if err := r.ensurePacketInBuffer(1); err == nil { + t.Fatalf("Expected nothing, got %v", r.buf[r.pos:]) + } + r.reader.SetReadDeadline(time.Time{}) +} + +func testMQTTCheckPubMsg(t testing.TB, c net.Conn, r *mqttReader, topic string, flags byte, payload []byte) { + t.Helper() + pflags, pi := testMQTTGetPubMsg(t, c, r, topic, payload) + if pflags != flags { + t.Fatalf("Expected flags to be %x, got %x", flags, pflags) + } + if pi > 0 { + testMQTTSendPubAck(t, c, pi) + } +} + +func testMQTTCheckPubMsgNoAck(t testing.TB, c net.Conn, r *mqttReader, topic string, flags byte, payload []byte) uint16 { + t.Helper() + pflags, pi := testMQTTGetPubMsg(t, c, r, topic, payload) + if pflags != flags { + t.Fatalf("Expected flags to be %x, got %x", flags, pflags) + } + return pi +} + +func testMQTTGetPubMsg(t testing.TB, c net.Conn, r *mqttReader, topic string, payload []byte) (byte, uint16) { + t.Helper() + testMQTTReaderHasAtLeastOne(t, r) + b, err := r.readByte("packet type") + if err != nil { + t.Fatal(err) + } + if pt := b & mqttPacketMask; pt != mqttPacketPub { + t.Fatalf("Expected PUBLISH packet %x, got %x", mqttPacketPub, pt) + } + pflags := b & mqttPacketFlagMask + qos := (pflags & mqttPubFlagQoS) >> 1 + pl, err := r.readPacketLen() + if err != nil { + t.Fatal(err) + } + if err := r.ensurePacketInBuffer(pl); err != nil { + t.Fatal(err) + } + start := r.pos + ptopic, err := r.readString("topic name") + if err != nil { + t.Fatal(err) + } + if ptopic != topic { + t.Fatalf("Expected topic %q, got %q", topic, ptopic) + } + var pi uint16 + if qos > 0 { + pi, err = r.readUint16("packet identifier") + if err != nil { + t.Fatal(err) + } + } + msgLen := pl - (r.pos - start) + if r.pos+msgLen > len(r.buf) { + t.Fatalf("computed message length goes beyond buffer: ml=%v pos=%v lenBuf=%v", + msgLen, r.pos, len(r.buf)) + } + ppayload := r.buf[r.pos : r.pos+msgLen] + if !bytes.Equal(payload, ppayload) { + t.Fatalf("Expected payload %q, got %q", payload, ppayload) + } + r.pos += msgLen + return pflags, pi +} + +func testMQTTSendPubAck(t testing.TB, c net.Conn, pi uint16) { + t.Helper() + w := &mqttWriter{} + w.WriteByte(mqttPacketPubAck) + w.WriteVarInt(2) + w.WriteUint16(pi) + if _, err := testMQTTWrite(c, w.Bytes()); err != nil { + t.Fatalf("Error writing PUBACK: %v", err) + } +} + +func testMQTTPublish(t testing.TB, c net.Conn, r *mqttReader, qos byte, dup, retain bool, topic string, pi uint16, payload []byte) { + t.Helper() + w := &mqttWriter{} + mqttWritePublish(w, qos, dup, retain, topic, pi, payload) + if _, err := testMQTTWrite(c, w.Bytes()); err != nil { + t.Fatalf("Error writing PUBLISH proto: %v", err) + } + if qos > 0 { + // Since we don't support QoS 2, we should get disconnected + if qos == 2 { + testMQTTExpectDisconnect(t, c) + return + } + testMQTTReaderHasAtLeastOne(t, r) + // Parse PUBACK + b, err := r.readByte("packet type") + if err != nil { + t.Fatal(err) + } + if pt := b & mqttPacketMask; pt != mqttPacketPubAck { + t.Fatalf("Expected PUBACK packet %x, got %x", mqttPacketPubAck, pt) + } + pl, err := r.readPacketLen() + if err != nil { + t.Fatal(err) + } + if err := r.ensurePacketInBuffer(pl); err != nil { + t.Fatal(err) + } + rpi, err := r.readUint16("packet identifier") + if err != nil || rpi != pi { + t.Fatalf("Error with packet identifier expected=%v got: %v err=%v", pi, rpi, err) + } + } +} + +func TestMQTTParsePub(t *testing.T) { + eofr := testNewEOFReader() + for _, test := range []struct { + name string + flags byte + proto []byte + pl int + reader mqttIOReader + err string + }{ + {"qos not supported", 0x4, nil, 0, nil, "not supported"}, + {"packet in buffer error", 0, nil, 10, eofr, "error ensuring protocol is loaded"}, + {"error on topic", 0, []byte{0, 3, 'f', 'o'}, 4, eofr, "topic"}, + {"empty topic", 0, []byte{0, 0}, 2, nil, "topic cannot be empty"}, + {"wildcards topic", 0, []byte{0, 1, '#'}, 3, nil, "wildcards not allowed"}, + {"error on packet identifier", mqttPubQos1, []byte{0, 3, 'f', 'o', 'o'}, 5, eofr, "packet identifier"}, + {"invalid packet identifier", mqttPubQos1, []byte{0, 3, 'f', 'o', 'o', 0, 0}, 7, nil, "packet identifier cannot be 0"}, + } { + t.Run(test.name, func(t *testing.T) { + r := &mqttReader{reader: test.reader} + r.reset(test.proto) + mqtt := &mqtt{r: r} + c := &client{mqtt: mqtt} + pp := &mqttPublish{flags: test.flags} + if err := c.mqttParsePub(r, test.pl, pp); err == nil || !strings.Contains(err.Error(), test.err) { + t.Fatalf("Expected error %q, got %v", test.err, err) + } + }) + } +} + +func TestMQTTParsePubAck(t *testing.T) { + eofr := testNewEOFReader() + for _, test := range []struct { + name string + proto []byte + pl int + reader mqttIOReader + err string + }{ + {"packet in buffer error", nil, 10, eofr, "error ensuring protocol is loaded"}, + {"error reading packet identifier", []byte{0}, 1, eofr, "packet identifier"}, + {"invalid packet identifier", []byte{0, 0}, 2, nil, "packet identifier cannot be 0"}, + } { + t.Run(test.name, func(t *testing.T) { + r := &mqttReader{reader: test.reader} + r.reset(test.proto) + if _, err := mqttParsePubAck(r, test.pl); err == nil || !strings.Contains(err.Error(), test.err) { + t.Fatalf("Expected error %q, got %v", test.err, err) + } + }) + } +} + +func TestMQTTPublish(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcp.Close() + testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, mcp, mpr, 0, false, false, "foo", 0, []byte("msg")) + testMQTTPublish(t, mcp, mpr, 1, false, false, "foo", 1, []byte("msg")) + testMQTTPublish(t, mcp, mpr, 2, false, false, "foo", 2, []byte("msg")) +} + +func TestMQTTSub(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcp.Close() + testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false) + + for _, test := range []struct { + name string + mqttSubTopic string + natsPubSubject string + mqttPubTopic string + ok bool + }{ + {"1 level match", "foo", "foo", "foo", true}, + {"1 level no match", "foo", "bar", "bar", false}, + {"2 levels match", "foo/bar", "foo.bar", "foo/bar", true}, + {"2 levels no match", "foo/bar", "foo.baz", "foo/baz", false}, + {"3 levels match", "/foo/bar", "/.foo.bar", "/foo/bar", true}, + {"3 levels no match", "/foo/bar", "/.foo.baz", "/foo/baz", false}, + + {"single level wc", "foo/+", "foo.bar.baz", "foo/bar/baz", false}, + {"single level wc", "foo/+", "foo.bar./", "foo/bar/", false}, + {"single level wc", "foo/+", "foo.bar", "foo/bar", true}, + {"single level wc", "foo/+", "foo./", "foo/", true}, + {"single level wc", "foo/+", "foo", "foo", false}, + {"single level wc", "foo/+", "/.foo", "/foo", false}, + + {"multiple level wc", "foo/#", "foo.bar.baz./", "foo/bar/baz/", true}, + {"multiple level wc", "foo/#", "foo.bar.baz", "foo/bar/baz", true}, + {"multiple level wc", "foo/#", "foo.bar./", "foo/bar/", true}, + {"multiple level wc", "foo/#", "foo.bar", "foo/bar", true}, + {"multiple level wc", "foo/#", "foo./", "foo/", true}, + {"multiple level wc", "foo/#", "foo", "foo", true}, + {"multiple level wc", "foo/#", "/.foo", "/foo", false}, + } { + t.Run(test.name, func(t *testing.T) { + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: test.mqttSubTopic, qos: 0}}, []byte{0}) + testMQTTFlush(t, mc, nil, r) + + natsPub(t, nc, test.natsPubSubject, []byte("msg")) + if test.ok { + testMQTTCheckPubMsg(t, mc, r, test.mqttPubTopic, 0, []byte("msg")) + } else { + testMQTTExpectNothing(t, r) + } + + testMQTTPublish(t, mcp, mpr, 0, false, false, test.mqttPubTopic, 0, []byte("msg")) + if test.ok { + testMQTTCheckPubMsg(t, mc, r, test.mqttPubTopic, 0, []byte("msg")) + } else { + testMQTTExpectNothing(t, r) + } + }) + } +} + +func TestMQTTSubQoS(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcp.Close() + testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false) + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + mqttTopic := "foo/bar" + + // Subscribe with QoS 1 + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/#", qos: 1}}, []byte{1}) + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: mqttTopic, qos: 1}}, []byte{1}) + testMQTTFlush(t, mc, nil, r) + + // Publish from NATS, which means QoS 0 + natsPub(t, nc, "foo.bar", []byte("NATS")) + // Will receive as QoS 0 + testMQTTCheckPubMsg(t, mc, r, mqttTopic, 0, []byte("NATS")) + testMQTTCheckPubMsg(t, mc, r, mqttTopic, 0, []byte("NATS")) + + // Publish from MQTT with QoS 0 + testMQTTPublish(t, mcp, mpr, 0, false, false, mqttTopic, 0, []byte("msg")) + // Will receive as QoS 0 + testMQTTCheckPubMsg(t, mc, r, mqttTopic, 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, mqttTopic, 0, []byte("msg")) + + // Publish from MQTT with QoS 1 + testMQTTPublish(t, mcp, mpr, 1, false, false, mqttTopic, 1, []byte("msg")) + pflags1, pi1 := testMQTTGetPubMsg(t, mc, r, mqttTopic, []byte("msg")) + if pflags1 != 0x2 { + t.Fatalf("Expected flags to be 0x2, got %v", pflags1) + } + pflags2, pi2 := testMQTTGetPubMsg(t, mc, r, mqttTopic, []byte("msg")) + if pflags2 != 0x2 { + t.Fatalf("Expected flags to be 0x2, got %v", pflags2) + } + if pi1 == pi2 { + t.Fatalf("packet identifier for message 1: %v should be different from message 2", pi1) + } + testMQTTSendPubAck(t, mc, pi1) + testMQTTSendPubAck(t, mc, pi2) +} + +func getSubQoS(sub *subscription) int { + if sub.mqtt != nil { + return int(sub.mqtt.qos) + } + return -1 +} + +func TestMQTTSubDups(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcp.Close() + testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false) + + mc, r := testMQTTConnect(t, &mqttConnInfo{clientID: "sub", user: "sub", cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + // Test with single SUBSCRIBE protocol but multiple filters + filters := []*mqttFilter{ + &mqttFilter{filter: "foo", qos: 1}, + &mqttFilter{filter: "foo", qos: 0}, + } + testMQTTSub(t, 1, mc, r, filters, []byte{1, 0}) + testMQTTFlush(t, mc, nil, r) + + // And also with separate SUBSCRIBE protocols + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "bar", qos: 0}}, []byte{0}) + // Ask for QoS 2 but server will downgrade to 1 + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "bar", qos: 2}}, []byte{1}) + testMQTTFlush(t, mc, nil, r) + + // Publish and test msg received only once + testMQTTPublish(t, mcp, r, 0, false, false, "foo", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg")) + testMQTTExpectNothing(t, r) + + testMQTTPublish(t, mcp, r, 0, false, false, "bar", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "bar", 0, []byte("msg")) + testMQTTExpectNothing(t, r) + + // Check that the QoS for subscriptions have been updated to the latest received filter + var err error + subc := testMQTTGetClient(t, s, "sub") + subc.mu.Lock() + if subc.opts.Username != "sub" { + err = fmt.Errorf("wrong user name") + } + if err == nil { + if sub := subc.subs["foo"]; sub == nil || getSubQoS(sub) != 0 { + err = fmt.Errorf("subscription foo QoS should be 0, got %v", getSubQoS(sub)) + } + } + if err == nil { + if sub := subc.subs["bar"]; sub == nil || getSubQoS(sub) != 1 { + err = fmt.Errorf("subscription bar QoS should be 1, got %v", getSubQoS(sub)) + } + } + subc.mu.Unlock() + if err != nil { + t.Fatal(err) + } + + // Now subscribe on "foo/#" which means that a PUBLISH on "foo" will be received + // by this subscription and also the one on "foo". + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/#", qos: 1}}, []byte{1}) + testMQTTFlush(t, mc, nil, r) + + // Publish and test msg received twice + testMQTTPublish(t, mcp, r, 0, false, false, "foo", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg")) + + checkWCSub := func(expectedQoS int) { + t.Helper() + + subc.mu.Lock() + defer subc.mu.Unlock() + + // When invoked with expectedQoS==1, we have the following subs: + // foo (QoS-0), bar (QoS-1), foo.> (QoS-1) + // which means (since QoS-1 have a JS consumer + sub for delivery + // and foo.> causes a "foo fwc") that we should have the following + // number of NATS subs: foo (1), bar (2), foo.> (2) and "foo fwc" (2), + // so total=7. + // When invoked with expectedQoS==0, it means that we have replaced + // foo/# QoS-1 to QoS-0, so we should have 2 less NATS subs, + // so total=5 + expected := 7 + if expectedQoS == 0 { + expected = 5 + } + if lenmap := len(subc.subs); lenmap != expected { + t.Fatalf("Subs map should have %v entries, got %v", expected, lenmap) + } + if sub, ok := subc.subs["foo.>"]; !ok { + t.Fatal("Expected sub foo.> to be present but was not") + } else if getSubQoS(sub) != expectedQoS { + t.Fatalf("Expected sub foo.> QoS to be %v, got %v", expectedQoS, getSubQoS(sub)) + } + if sub, ok := subc.subs["foo fwc"]; !ok { + t.Fatal("Expected sub foo fwc to be present but was not") + } else if getSubQoS(sub) != expectedQoS { + t.Fatalf("Expected sub foo fwc QoS to be %v, got %v", expectedQoS, getSubQoS(sub)) + } + // Make sure existing sub on "foo" qos was not changed. + if sub, ok := subc.subs["foo"]; !ok { + t.Fatal("Expected sub foo to be present but was not") + } else if getSubQoS(sub) != 0 { + t.Fatalf("Expected sub foo QoS to be 0, got %v", getSubQoS(sub)) + } + } + checkWCSub(1) + + // Sub again on same subject with lower QoS + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/#", qos: 0}}, []byte{0}) + testMQTTFlush(t, mc, nil, r) + + // Publish and test msg received twice + testMQTTPublish(t, mcp, r, 0, false, false, "foo", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg")) + checkWCSub(0) +} + +func TestMQTTSubWithSpaces(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcp.Close() + testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false) + + mc, r := testMQTTConnect(t, &mqttConnInfo{user: "sub", cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo bar", qos: 0}}, []byte{mqttSubAckFailure}) +} + +func TestMQTTSubCaseSensitive(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcp.Close() + testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false) + + mc, r := testMQTTConnect(t, &mqttConnInfo{user: "sub", cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "Foo/Bar", qos: 0}}, []byte{0}) + testMQTTFlush(t, mc, nil, r) + + testMQTTPublish(t, mcp, r, 0, false, false, "Foo/Bar", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "Foo/Bar", 0, []byte("msg")) + + testMQTTPublish(t, mcp, r, 0, false, false, "foo/bar", 0, []byte("msg")) + testMQTTExpectNothing(t, r) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + natsPub(t, nc, "Foo.Bar", []byte("nats")) + testMQTTCheckPubMsg(t, mc, r, "Foo/Bar", 0, []byte("nats")) + + natsPub(t, nc, "foo.bar", []byte("nats")) + testMQTTExpectNothing(t, r) +} + +func TestMQTTPubSubMatrix(t *testing.T) { + for _, test := range []struct { + name string + natsPub bool + mqttPub bool + mqttPubQoS byte + natsSub bool + mqttSubQoS0 bool + mqttSubQoS1 bool + }{ + {"NATS to MQTT sub QoS-0", true, false, 0, false, true, false}, + {"NATS to MQTT sub QoS-1", true, false, 0, false, false, true}, + {"NATS to MQTT sub QoS-0 and QoS-1", true, false, 0, false, true, true}, + + {"MQTT QoS-0 to NATS sub", false, true, 0, true, false, false}, + {"MQTT QoS-0 to MQTT sub QoS-0", false, true, 0, false, true, false}, + {"MQTT QoS-0 to MQTT sub QoS-1", false, true, 0, false, false, true}, + {"MQTT QoS-0 to NATS sub and MQTT sub QoS-0", false, true, 0, true, true, false}, + {"MQTT QoS-0 to NATS sub and MQTT sub QoS-1", false, true, 0, true, false, true}, + {"MQTT QoS-0 to all subs", false, true, 0, true, true, true}, + + {"MQTT QoS-1 to NATS sub", false, true, 1, true, false, false}, + {"MQTT QoS-1 to MQTT sub QoS-0", false, true, 1, false, true, false}, + {"MQTT QoS-1 to MQTT sub QoS-1", false, true, 1, false, false, true}, + {"MQTT QoS-1 to NATS sub and MQTT sub QoS-0", false, true, 1, true, true, false}, + {"MQTT QoS-1 to NATS sub and MQTT sub QoS-1", false, true, 1, true, false, true}, + {"MQTT QoS-1 to all subs", false, true, 1, true, true, true}, + } { + t.Run(test.name, func(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + mc1, r1 := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc1.Close() + testMQTTCheckConnAck(t, r1, mqttConnAckRCConnectionAccepted, false) + + mc2, r2 := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc2.Close() + testMQTTCheckConnAck(t, r2, mqttConnAckRCConnectionAccepted, false) + + // First setup subscriptions based on test options. + var ns *nats.Subscription + if test.natsSub { + ns = natsSubSync(t, nc, "foo") + } + if test.mqttSubQoS0 { + testMQTTSub(t, 1, mc1, r1, []*mqttFilter{&mqttFilter{filter: "foo", qos: 0}}, []byte{0}) + testMQTTFlush(t, mc1, nil, r1) + } + if test.mqttSubQoS1 { + testMQTTSub(t, 1, mc2, r2, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTFlush(t, mc2, nil, r2) + } + + // Just as a barrier + natsFlush(t, nc) + + // Now publish + if test.natsPub { + natsPubReq(t, nc, "foo", "", []byte("msg")) + } else { + testMQTTPublish(t, mc, r, test.mqttPubQoS, false, false, "foo", 1, []byte("msg")) + } + + // Check message received + if test.natsSub { + natsNexMsg(t, ns, time.Second) + // Make sure no other is received + if msg, err := ns.NextMsg(50 * time.Millisecond); err == nil { + t.Fatalf("Should not have gotten a second message, got %v", msg) + } + } + if test.mqttSubQoS0 { + testMQTTCheckPubMsg(t, mc1, r1, "foo", 0, []byte("msg")) + testMQTTExpectNothing(t, r1) + } + if test.mqttSubQoS1 { + var expectedFlag byte + if test.mqttPubQoS > 0 { + expectedFlag = test.mqttPubQoS << 1 + } + testMQTTCheckPubMsg(t, mc2, r2, "foo", expectedFlag, []byte("msg")) + testMQTTExpectNothing(t, r2) + } + }) + } +} + +func TestMQTTPreventSubWithMQTTSubPrefix(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, mc, r, + []*mqttFilter{&mqttFilter{filter: strings.ReplaceAll(mqttSubPrefix, ".", "/") + "foo/bar", qos: 1}}, + []byte{mqttSubAckFailure}) +} + +func TestMQTTSubWithNATSStream(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/bar", qos: 1}}, []byte{1}) + testMQTTFlush(t, mc, nil, r) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + sc := &StreamConfig{ + Name: "test", + Storage: FileStorage, + Retention: InterestPolicy, + Subjects: []string{"foo.>"}, + } + mset, err := s.GlobalAccount().AddStream(sc) + if err != nil { + t.Fatalf("Unable to create stream: %v", err) + } + + sub := natsSubSync(t, nc, "bar") + cc := &ConsumerConfig{ + Durable: "dur", + AckPolicy: AckExplicit, + DeliverSubject: "bar", + } + if _, err := mset.AddConsumer(cc); err != nil { + t.Fatalf("Unable to add consumer: %v", err) + } + + // Now send message from NATS + resp, err := nc.Request("foo.bar", []byte("nats"), time.Second) + if err != nil { + t.Fatalf("Error publishing: %v", err) + } + ar := &ApiResponse{} + if err := json.Unmarshal(resp.Data, ar); err != nil || ar.Error != nil { + t.Fatalf("Unexpected response: err=%v resp=%+v", err, ar.Error) + } + + // Check that message is received by both + checkRecv := func(content string, flags byte) { + t.Helper() + if msg := natsNexMsg(t, sub, time.Second); string(msg.Data) != content { + t.Fatalf("Expected %q, got %q", content, msg.Data) + } + testMQTTCheckPubMsg(t, mc, r, "foo/bar", flags, []byte(content)) + } + checkRecv("nats", 0) + + // Send from MQTT as a QoS0 + testMQTTPublish(t, mc, r, 0, false, false, "foo/bar", 0, []byte("qos0")) + checkRecv("qos0", 0) + + // Send from MQTT as a QoS1 + testMQTTPublish(t, mc, r, 1, false, false, "foo/bar", 1, []byte("qos1")) + checkRecv("qos1", mqttPubQos1) +} + +func TestMQTTTrackPendingOverrun(t *testing.T) { + sess := &mqttSession{pending: make(map[uint16]*mqttPending)} + sub := &subscription{mqtt: &mqttSub{qos: 1}} + + sess.ppi = 0xFFFF + pi, _ := sess.trackPending(1, _EMPTY_, sub) + if pi != 1 { + t.Fatalf("Expected 1, got %v", pi) + } + + p := &mqttPending{} + for i := 1; i <= 0xFFFF; i++ { + sess.pending[uint16(i)] = p + } + pi, _ = sess.trackPending(1, _EMPTY_, sub) + if pi != 0 { + t.Fatalf("Expected 0, got %v", pi) + } + + delete(sess.pending, 1234) + pi, _ = sess.trackPending(1, _EMPTY_, sub) + if pi != 1234 { + t.Fatalf("Expected 1234, got %v", pi) + } +} + +func TestMQTTPreventStreamAndConsumerWithMQTTPrefix(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + sc := &StreamConfig{ + Name: mqttStreamNamePrefix + "test", + Storage: FileStorage, + Retention: InterestPolicy, + Subjects: []string{"foo.>"}, + } + if _, err := s.GlobalAccount().AddStream(sc); err == nil { + t.Fatal("Expected error") + } + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/bar", qos: 1}}, []byte{1}) + testMQTTFlush(t, mc, nil, r) + + mset, err := s.GlobalAccount().LookupStream(mqttStreamName) + if err != nil { + t.Fatalf("Error looking up MQTT Stream: %v", err) + } + cc := &ConsumerConfig{ + Durable: "dur", + AckPolicy: AckExplicit, + DeliverSubject: "bar", + } + if _, err := mset.AddConsumer(cc); err == nil { + t.Fatal("Expected error") + } +} + +func TestMQTTSubRestart(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + mc, r := testMQTTConnect(t, &mqttConnInfo{clientID: "sub", cleanSess: false}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + // Start an MQTT subscription QoS=1 on "foo" + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTFlush(t, mc, nil, r) + + // Now start a NATS subscription on ">" (anything that would match the JS consumer delivery subject) + natsSubSync(t, nc, ">") + natsFlush(t, nc) + + // Restart the MQTT client + testMQTTDisconnect(t, mc, nil) + + mc, r = testMQTTConnect(t, &mqttConnInfo{clientID: "sub", cleanSess: false}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + + // Restart an MQTT subscription QoS=1 on "foo" + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTFlush(t, mc, nil, r) +} + +func TestMQTTSubPropagation(t *testing.T) { + t.Skip("Skipping until JS clustering is supported") + o := testMQTTDefaultOptions() + o.Cluster.Host = "127.0.0.1" + o.Cluster.Port = -1 + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + o2 := DefaultOptions() + o2.Routes = RoutesFromStr(fmt.Sprintf("nats://127.0.0.1:%d", o.Cluster.Port)) + s2 := RunServer(o2) + defer s2.Shutdown() + + checkClusterFormed(t, s, s2) + + nc := natsConnect(t, s2.ClientURL()) + defer nc.Close() + + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/#", qos: 0}}, []byte{0}) + testMQTTFlush(t, mc, nil, r) + + // Because in MQTT foo/# means foo.> but also foo, check that this is propagated + checkSubInterest(t, s2, globalAccountName, "foo", time.Second) + + // Publish on foo.bar, foo./ and foo and we should receive them + natsPub(t, nc, "foo.bar", []byte("hello")) + testMQTTCheckPubMsg(t, mc, r, "foo/bar", 0, []byte("hello")) + + natsPub(t, nc, "foo./", []byte("from")) + testMQTTCheckPubMsg(t, mc, r, "foo/", 0, []byte("from")) + + natsPub(t, nc, "foo", []byte("NATS")) + testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("NATS")) +} + +func TestMQTTParseUnsub(t *testing.T) { + eofr := testNewEOFReader() + for _, test := range []struct { + name string + proto []byte + b byte + pl int + reader mqttIOReader + err string + }{ + {"reserved flag", nil, 3, 0, nil, "wrong unsubscribe reserved flags"}, + {"ensure packet loaded", []byte{1, 2}, mqttUnsubscribeFlags, 10, eofr, "error ensuring protocol is loaded"}, + {"error reading packet id", []byte{1}, mqttUnsubscribeFlags, 1, eofr, "reading packet identifier"}, + {"missing filters", []byte{0, 1}, mqttUnsubscribeFlags, 2, nil, "subscribe protocol must contain at least 1 topic filter"}, + {"error reading topic", []byte{0, 1, 0, 2, 'a'}, mqttUnsubscribeFlags, 5, eofr, "topic filter"}, + {"empty topic", []byte{0, 1, 0, 0}, mqttUnsubscribeFlags, 4, nil, "topic filter cannot be empty"}, + {"invalid utf8 topic", []byte{0, 1, 0, 1, 241}, mqttUnsubscribeFlags, 5, nil, "invalid utf8 for topic filter"}, + } { + t.Run(test.name, func(t *testing.T) { + r := &mqttReader{reader: test.reader} + r.reset(test.proto) + mqtt := &mqtt{r: r} + c := &client{mqtt: mqtt} + if _, _, err := c.mqttParseSubsOrUnsubs(r, test.b, test.pl, false); err == nil || !strings.Contains(err.Error(), test.err) { + t.Fatalf("Expected error %q, got %v", test.err, err) + } + }) + } +} + +func testMQTTUnsub(t *testing.T, pi uint16, c net.Conn, r *mqttReader, filters []*mqttFilter) { + t.Helper() + w := &mqttWriter{} + pkLen := 2 // for pi + for i := 0; i < len(filters); i++ { + f := filters[i] + pkLen += 2 + len(f.filter) + } + w.WriteByte(mqttPacketUnsub | mqttUnsubscribeFlags) + w.WriteVarInt(pkLen) + w.WriteUint16(pi) + for i := 0; i < len(filters); i++ { + f := filters[i] + w.WriteBytes([]byte(f.filter)) + } + if _, err := testMQTTWrite(c, w.Bytes()); err != nil { + t.Fatalf("Error writing UNSUBSCRIBE protocol: %v", err) + } + // Make sure we have at least 1 byte in buffer (if not will read) + testMQTTReaderHasAtLeastOne(t, r) + // Parse UNSUBACK + b, err := r.readByte("packet type") + if err != nil { + t.Fatal(err) + } + if pt := b & mqttPacketMask; pt != mqttPacketUnsubAck { + t.Fatalf("Expected UNSUBACK packet %x, got %x", mqttPacketUnsubAck, pt) + } + pl, err := r.readPacketLen() + if err != nil { + t.Fatal(err) + } + if err := r.ensurePacketInBuffer(pl); err != nil { + t.Fatal(err) + } + rpi, err := r.readUint16("packet identifier") + if err != nil || rpi != pi { + t.Fatalf("Error with packet identifier expected=%v got: %v err=%v", pi, rpi, err) + } +} + +func TestMQTTUnsub(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcp.Close() + testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false) + + mc, r := testMQTTConnect(t, &mqttConnInfo{user: "sub", cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 0}}, []byte{0}) + testMQTTFlush(t, mc, nil, r) + + // Publish and test msg received + testMQTTPublish(t, mcp, r, 0, false, false, "foo", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg")) + + // Unsubscribe + testMQTTUnsub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo"}}) + + // Publish and test msg not received + testMQTTPublish(t, mcp, r, 0, false, false, "foo", 0, []byte("msg")) + testMQTTExpectNothing(t, r) + + // Use of wildcards subs + filters := []*mqttFilter{ + &mqttFilter{filter: "foo/bar", qos: 0}, + &mqttFilter{filter: "foo/#", qos: 0}, + } + testMQTTSub(t, 1, mc, r, filters, []byte{0, 0}) + testMQTTFlush(t, mc, nil, r) + + // Publish and check that message received twice + testMQTTPublish(t, mcp, r, 0, false, false, "foo/bar", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo/bar", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo/bar", 0, []byte("msg")) + + // Unsub the wildcard one + testMQTTUnsub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/#"}}) + // Publish and check that message received once + testMQTTPublish(t, mcp, r, 0, false, false, "foo/bar", 0, []byte("msg")) + testMQTTCheckPubMsg(t, mc, r, "foo/bar", 0, []byte("msg")) + testMQTTExpectNothing(t, r) + + // Unsub last + testMQTTUnsub(t, 1, mc, r, []*mqttFilter{&mqttFilter{filter: "foo/bar"}}) + // Publish and test msg not received + testMQTTPublish(t, mcp, r, 0, false, false, "foo/bar", 0, []byte("msg")) + testMQTTExpectNothing(t, r) +} + +func testMQTTExpectDisconnect(t testing.TB, c net.Conn) { + if buf, err := testMQTTRead(c); err == nil { + t.Fatalf("Expected connection to be disconnected, got %s", buf) + } +} + +func TestMQTTPublishTopicErrors(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + for _, test := range []struct { + name string + topic string + }{ + {"empty", ""}, + {"with single level wildcard", "foo/+"}, + {"with multiple level wildcard", "foo/#"}, + } { + t.Run(test.name, func(t *testing.T) { + mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, mc, r, 0, false, false, test.topic, 0, []byte("msg")) + testMQTTExpectDisconnect(t, mc) + }) + } +} + +func testMQTTDisconnect(t testing.TB, c net.Conn, bw *bufio.Writer) { + t.Helper() + w := &mqttWriter{} + w.WriteByte(mqttPacketDisconnect) + w.WriteByte(0) + if bw != nil { + bw.Write(w.Bytes()) + bw.Flush() + } else { + c.Write(w.Bytes()) + } + testMQTTExpectDisconnect(t, c) +} + +func TestMQTTWill(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + sub := natsSubSync(t, nc, "will.topic") + + willMsg := []byte("bye") + + for _, test := range []struct { + name string + willExpected bool + willQoS byte + }{ + {"will qos 0", true, 0}, + {"will qos 1", true, 1}, + {"proper disconnect no will", false, 0}, + } { + t.Run(test.name, func(t *testing.T) { + mcs, rs := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcs.Close() + testMQTTCheckConnAck(t, rs, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mcs, rs, []*mqttFilter{&mqttFilter{filter: "will/#", qos: 1}}, []byte{1}) + testMQTTFlush(t, mcs, nil, rs) + + mc, r := testMQTTConnect(t, + &mqttConnInfo{ + cleanSess: true, + will: &mqttWill{ + topic: []byte("will/topic"), + message: willMsg, + qos: test.willQoS, + }, + }, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + if test.willExpected { + mc.Close() + testMQTTCheckPubMsg(t, mcs, rs, "will/topic", test.willQoS<<1, willMsg) + wm := natsNexMsg(t, sub, time.Second) + if !bytes.Equal(wm.Data, willMsg) { + t.Fatalf("Expected will message to be %q, got %q", willMsg, wm.Data) + } + } else { + testMQTTDisconnect(t, mc, nil) + testMQTTExpectNothing(t, rs) + if wm, err := sub.NextMsg(100 * time.Millisecond); err == nil { + t.Fatalf("Should not have receive a message, got %v", wm) + } + } + }) + } +} + +func TestMQTTWillRetain(t *testing.T) { + for _, test := range []struct { + name string + pubQoS byte + subQoS byte + }{ + {"pub QoS0 sub QoS0", 0, 0}, + {"pub QoS0 sub QoS1", 0, 1}, + {"pub QoS1 sub QoS0", 1, 0}, + {"pub QoS1 sub QoS1", 1, 1}, + } { + t.Run(test.name, func(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + willTopic := []byte("will/topic") + willMsg := []byte("bye") + + mc, r := testMQTTConnect(t, + &mqttConnInfo{ + cleanSess: true, + will: &mqttWill{ + topic: willTopic, + message: willMsg, + qos: test.pubQoS, + retain: true, + }, + }, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + // Disconnect, which will cause will to be produced with retain flag. + mc.Close() + + // Create subscription on will topic and expect will message. + mcs, rs := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mcs.Close() + testMQTTCheckConnAck(t, rs, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mcs, rs, []*mqttFilter{&mqttFilter{filter: "will/#", qos: test.subQoS}}, []byte{test.subQoS}) + pflags, _ := testMQTTGetPubMsg(t, mcs, rs, "will/topic", willMsg) + if pflags&mqttPubFlagRetain == 0 { + t.Fatalf("expected retain flag to be set, it was not: %v", pflags) + } + // Expected QoS will be the lesser of the pub/sub QoS. + expectedQoS := test.pubQoS + if test.subQoS == 0 { + expectedQoS = 0 + } + if qos := mqttGetQoS(pflags); qos != expectedQoS { + t.Fatalf("expected qos to be %v, got %v", expectedQoS, qos) + } + }) + } +} + +func TestMQTTWillRetainPermViolation(t *testing.T) { + template := ` + port: -1 + jetstream: enabled + authorization { + mqtt_perms = { + publish = ["%s"] + subscribe = ["foo", "bar", "$MQTT.sub.>"] + } + users = [ + {user: mqtt, password: pass, permissions: $mqtt_perms} + ] + } + mqtt { + port: -1 + } + ` + conf := createConfFile(t, []byte(fmt.Sprintf(template, "foo"))) + defer os.Remove(conf) + + s, o := RunServerWithConfig(conf) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{ + cleanSess: true, + user: "mqtt", + pass: "pass", + } + + // We create first a connection with the Will topic that the publisher + // is allowed to publish to. + ci.will = &mqttWill{ + topic: []byte("foo"), + message: []byte("bye"), + qos: 1, + retain: true, + } + mc, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + // Disconnect, which will cause the Will to be sent with retain flag. + mc.Close() + + // Create a subscription on the Will subject and we should receive it. + ci.will = nil + mcs, rs := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mcs.Close() + testMQTTCheckConnAck(t, rs, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mcs, rs, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + pflags, _ := testMQTTGetPubMsg(t, mcs, rs, "foo", []byte("bye")) + if pflags&mqttPubFlagRetain == 0 { + t.Fatalf("expected retain flag to be set, it was not: %v", pflags) + } + if qos := mqttGetQoS(pflags); qos != 1 { + t.Fatalf("expected qos to be 1, got %v", qos) + } + testMQTTDisconnect(t, mcs, nil) + + // Now create another connection with a Will that client is not allowed to publish to. + ci.will = &mqttWill{ + topic: []byte("bar"), + message: []byte("bye"), + qos: 1, + retain: true, + } + mc, r = testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + // Disconnect, to cause Will to be produced, but in that case should not be stored + // since user not allowed to publish on "bar". + mc.Close() + + // Create sub on "bar" which user is allowed to subscribe to. + ci.will = nil + mcs, rs = testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mcs.Close() + testMQTTCheckConnAck(t, rs, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mcs, rs, []*mqttFilter{&mqttFilter{filter: "bar", qos: 1}}, []byte{1}) + // No Will should be published since it should not have been stored in the first place. + testMQTTExpectNothing(t, rs) + testMQTTDisconnect(t, mcs, nil) + + // Now remove permission to publish on "foo" and check that a new subscription + // on "foo" is now not getting the will message because the original user no + // longer has permission to do so. + reloadUpdateConfig(t, s, conf, fmt.Sprintf(template, "baz")) + + mcs, rs = testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mcs.Close() + testMQTTCheckConnAck(t, rs, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mcs, rs, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTExpectNothing(t, rs) + testMQTTDisconnect(t, mcs, nil) +} + +func TestMQTTPublishRetain(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + for _, test := range []struct { + name string + retained bool + sentValue string + expectedValue string + subGetsIt bool + }{ + {"publish retained", true, "retained", "retained", true}, + {"publish not retained", false, "not retained", "retained", true}, + {"remove retained", true, "", "", false}, + } { + t.Run(test.name, func(t *testing.T) { + mc1, rs1 := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc1.Close() + testMQTTCheckConnAck(t, rs1, mqttConnAckRCConnectionAccepted, false) + testMQTTPublish(t, mc1, rs1, 0, false, test.retained, "foo", 0, []byte(test.sentValue)) + + mc2, rs2 := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc2.Close() + testMQTTCheckConnAck(t, rs2, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mc2, rs2, []*mqttFilter{&mqttFilter{filter: "foo/#", qos: 1}}, []byte{1}) + + if test.subGetsIt { + pflags, _ := testMQTTGetPubMsg(t, mc2, rs2, "foo", []byte(test.expectedValue)) + if pflags&mqttPubFlagRetain == 0 { + t.Fatalf("retain flag should have been set, it was not: flags=%v", pflags) + } + } else { + testMQTTExpectNothing(t, rs2) + } + + testMQTTDisconnect(t, mc1, nil) + testMQTTDisconnect(t, mc2, nil) + }) + } +} + +func TestMQTTPublishRetainPermViolation(t *testing.T) { + o := testMQTTDefaultOptions() + o.Users = []*User{ + { + Username: "mqtt", + Password: "pass", + Permissions: &Permissions{ + Publish: &SubjectPermission{Allow: []string{"foo"}}, + Subscribe: &SubjectPermission{Allow: []string{"bar", "$MQTT.sub.>"}}, + }, + }, + } + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{ + cleanSess: true, + user: "mqtt", + pass: "pass", + } + + mc1, rs1 := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc1.Close() + testMQTTCheckConnAck(t, rs1, mqttConnAckRCConnectionAccepted, false) + testMQTTPublish(t, mc1, rs1, 0, false, true, "bar", 0, []byte("retained")) + + mc2, rs2 := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer mc2.Close() + testMQTTCheckConnAck(t, rs2, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, mc2, rs2, []*mqttFilter{&mqttFilter{filter: "bar", qos: 1}}, []byte{1}) + testMQTTExpectNothing(t, rs2) + + testMQTTDisconnect(t, mc1, nil) + testMQTTDisconnect(t, mc2, nil) +} + +func TestMQTTCleanSession(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{ + clientID: "me", + cleanSess: false, + } + c, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTDisconnect(t, c, nil) + + c, r = testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + testMQTTDisconnect(t, c, nil) + + ci.cleanSess = true + c, r = testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTDisconnect(t, c, nil) +} + +func TestMQTTDuplicateClientID(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{ + clientID: "me", + cleanSess: false, + } + c1, r1 := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer c1.Close() + testMQTTCheckConnAck(t, r1, mqttConnAckRCConnectionAccepted, false) + + c2, r2 := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) + defer c2.Close() + testMQTTCheckConnAck(t, r2, mqttConnAckRCConnectionAccepted, true) + + // The old client should be disconnected. + testMQTTExpectDisconnect(t, c1) +} + +func TestMQTTPersistedSession(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer func() { + testMQTTShutdownServer(s) + }() + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, c, r, + []*mqttFilter{ + &mqttFilter{filter: "foo/#", qos: 1}, + &mqttFilter{filter: "bar", qos: 1}, + &mqttFilter{filter: "baz", qos: 0}, + }, + []byte{1, 1, 0}) + testMQTTFlush(t, c, nil, r) + + // Shutdown server, close connection and restart server. It should + // have restored the session and consumers. + dir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir) + s.Shutdown() + c.Close() + + o.Port = -1 + o.MQTT.Port = -1 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // There is already the defer for shutdown at top of function + + // Create a publisher that will send qos1 so we verify that messages + // are stored for the persisted sessions. + c, r = testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, c, r, 1, false, false, "foo/bar", 1, []byte("msg0")) + testMQTTFlush(t, c, nil, r) + testMQTTDisconnect(t, c, nil) + c.Close() + + // Recreate consumer session + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + + // Since consumers have been recovered, messages should be received + // (MQTT does not need client to recreate consumers for a recovered + // session) + + // Check that qos1 publish message is received. + testMQTTCheckPubMsg(t, c, r, "foo/bar", mqttPubQos1, []byte("msg0")) + + // Now publish some messages to all subscriptions. + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + natsPub(t, nc, "foo.bar", []byte("msg1")) + testMQTTCheckPubMsg(t, c, r, "foo/bar", 0, []byte("msg1")) + + natsPub(t, nc, "foo", []byte("msg2")) + testMQTTCheckPubMsg(t, c, r, "foo", 0, []byte("msg2")) + + natsPub(t, nc, "bar", []byte("msg3")) + testMQTTCheckPubMsg(t, c, r, "bar", 0, []byte("msg3")) + + natsPub(t, nc, "baz", []byte("msg4")) + testMQTTCheckPubMsg(t, c, r, "baz", 0, []byte("msg4")) + + // Now unsub "bar" and verify that message published on this topic + // is not received. + testMQTTUnsub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "bar"}}) + natsPub(t, nc, "bar", []byte("msg5")) + testMQTTExpectNothing(t, r) + + nc.Close() + s.Shutdown() + c.Close() + + o.Port = -1 + o.MQTT.Port = -1 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // There is already the defer for shutdown at top of function + + // Recreate a client + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + + nc = natsConnect(t, s.ClientURL()) + defer nc.Close() + + natsPub(t, nc, "foo.bar", []byte("msg6")) + testMQTTCheckPubMsg(t, c, r, "foo/bar", 0, []byte("msg6")) + + natsPub(t, nc, "foo", []byte("msg7")) + testMQTTCheckPubMsg(t, c, r, "foo", 0, []byte("msg7")) + + // Make sure that we did not recover bar. + natsPub(t, nc, "bar", []byte("msg8")) + testMQTTExpectNothing(t, r) + + natsPub(t, nc, "baz", []byte("msg9")) + testMQTTCheckPubMsg(t, c, r, "baz", 0, []byte("msg9")) + + // Have the sub client send a subscription downgrading the qos1 subscription. + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo/#", qos: 0}}, []byte{0}) + testMQTTFlush(t, c, nil, r) + + nc.Close() + s.Shutdown() + c.Close() + + o.Port = -1 + o.MQTT.Port = -1 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // There is already the defer for shutdown at top of function + + // Recreate the sub client + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + + // Publish as a qos1 + c2, r2 := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer c2.Close() + testMQTTCheckConnAck(t, r2, mqttConnAckRCConnectionAccepted, false) + testMQTTPublish(t, c2, r2, 1, false, false, "foo/bar", 1, []byte("msg10")) + + // Verify that it is received as qos0 which is the qos of the subscription. + testMQTTCheckPubMsg(t, c, r, "foo/bar", 0, []byte("msg10")) + + testMQTTDisconnect(t, c, nil) + c.Close() + testMQTTDisconnect(t, c2, nil) + c2.Close() + + // Finally, recreate the sub with clean session and ensure that all is gone + cisub.cleanSess = true + for i := 0; i < 2; i++ { + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + nc = natsConnect(t, s.ClientURL()) + defer nc.Close() + + natsPub(t, nc, "foo.bar", []byte("msg11")) + testMQTTExpectNothing(t, r) + + natsPub(t, nc, "foo", []byte("msg12")) + testMQTTExpectNothing(t, r) + + // Make sure that we did not recover bar. + natsPub(t, nc, "bar", []byte("msg13")) + testMQTTExpectNothing(t, r) + + natsPub(t, nc, "baz", []byte("msg14")) + testMQTTExpectNothing(t, r) + + testMQTTDisconnect(t, c, nil) + c.Close() + nc.Close() + + s.Shutdown() + o.Port = -1 + o.MQTT.Port = -1 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // There is already the defer for shutdown at top of function + } +} + +func TestMQTTRecoverSessionAndAddNewSub(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer func() { + testMQTTShutdownServer(s) + }() + + cisub := &mqttConnInfo{clientID: "sub1", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTDisconnect(t, c, nil) + c.Close() + + // Shutdown server, close connection and restart server. It should + // have restored the session and consumers. + dir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir) + s.Shutdown() + c.Close() + + o.Port = -1 + o.MQTT.Port = -1 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // No need for defer since it is done top of function + + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + // Now add sub and make sure it does not crash + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTFlush(t, c, nil, r) + + // Now repeat with a new client but without server restart. + cisub2 := &mqttConnInfo{clientID: "sub2", cleanSess: false} + c2, r2 := testMQTTConnect(t, cisub2, o.MQTT.Host, o.MQTT.Port) + defer c2.Close() + testMQTTCheckConnAck(t, r2, mqttConnAckRCConnectionAccepted, false) + testMQTTDisconnect(t, c2, nil) + c2.Close() + + c2, r2 = testMQTTConnect(t, cisub2, o.MQTT.Host, o.MQTT.Port) + defer c2.Close() + testMQTTCheckConnAck(t, r2, mqttConnAckRCConnectionAccepted, true) + testMQTTSub(t, 1, c2, r2, []*mqttFilter{&mqttFilter{filter: "bar", qos: 1}}, []byte{1}) + testMQTTFlush(t, c2, nil, r2) +} + +func TestMQTTRecoverSessionWithSubAndClientResendSub(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer func() { + testMQTTShutdownServer(s) + }() + + cisub1 := &mqttConnInfo{clientID: "sub1", cleanSess: false} + c, r := testMQTTConnect(t, cisub1, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + // Have a client send a SUBSCRIBE protocol for foo, QoS1 + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTDisconnect(t, c, nil) + c.Close() + + // Restart the server now. + dir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir) + s.Shutdown() + + o.Port = -1 + o.MQTT.Port = -1 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // No need for defer since it is done top of function + + // Now restart the client. Since the client was created with cleanSess==false, + // the server will have recorded the subscriptions for this client. + c, r = testMQTTConnect(t, cisub1, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + // At this point, the server has recreated the subscription on foo, QoS1. + + // For applications that restart, it is possible (likely) that they + // will resend their SUBSCRIBE protocols, so do so now: + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTFlush(t, c, nil, r) + + checkNumSub := func(clientID string) { + t.Helper() + + // Find the MQTT client... + mc := testMQTTGetClient(t, s, clientID) + + // Check how many NATS subscriptions are registered. + var fooSub int + var otherSub int + mc.mu.Lock() + for _, sub := range mc.subs { + switch string(sub.subject) { + case "foo": + fooSub++ + default: + otherSub++ + } + } + mc.mu.Unlock() + + // We should have 2 subscriptions, one on "foo", and one for the JS durable + // consumer's delivery subject. + if fooSub != 1 { + t.Fatalf("Expected 1 sub on 'foo', got %v", fooSub) + } + if otherSub != 1 { + t.Fatalf("Expected 1 subscription for JS durable, got %v", otherSub) + } + } + checkNumSub("sub1") + + c.Close() + + // Now same but without the server restart in-between. + cisub2 := &mqttConnInfo{clientID: "sub2", cleanSess: false} + c, r = testMQTTConnect(t, cisub2, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTDisconnect(t, c, nil) + c.Close() + // Restart client + c, r = testMQTTConnect(t, cisub2, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTFlush(t, c, nil, r) + // Check client subs + checkNumSub("sub2") +} + +func TestMQTTPersistRetainedMsg(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer func() { + testMQTTShutdownServer(s) + }() + + dir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir) + + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + c, r := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, c, r, 1, false, true, "foo", 1, []byte("foo1")) + testMQTTPublish(t, c, r, 1, false, true, "foo", 1, []byte("foo2")) + testMQTTPublish(t, c, r, 1, false, true, "bar", 1, []byte("bar1")) + testMQTTPublish(t, c, r, 0, false, true, "baz", 1, []byte("baz1")) + // Remove bar + testMQTTPublish(t, c, r, 1, false, true, "bar", 1, nil) + testMQTTDisconnect(t, c, nil) + c.Close() + + s.Shutdown() + + o.Port = -1 + o.MQTT.Port = -1 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // There is already the defer for shutdown at top of function + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTCheckPubMsg(t, c, r, "foo", mqttPubFlagRetain|mqttPubQos1, []byte("foo2")) + + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "baz", qos: 1}}, []byte{1}) + testMQTTCheckPubMsg(t, c, r, "baz", mqttPubFlagRetain, []byte("baz1")) + + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "bar", qos: 1}}, []byte{1}) + testMQTTExpectNothing(t, r) + + testMQTTDisconnect(t, c, nil) + c.Close() +} + +func TestMQTTConnAckFirstProto(t *testing.T) { + o := testMQTTDefaultOptions() + o.NoLog, o.Debug, o.Trace = true, false, false + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 0}}, []byte{0}) + testMQTTDisconnect(t, c, nil) + c.Close() + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + wg := sync.WaitGroup{} + wg.Add(1) + ch := make(chan struct{}, 1) + ready := make(chan struct{}) + go func() { + defer wg.Done() + + close(ready) + for { + nc.Publish("foo", []byte("msg")) + select { + case <-ch: + return + default: + } + } + }() + + <-ready + for i := 0; i < 100; i++ { + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + c.Close() + } + close(ch) + wg.Wait() +} + +func TestMQTTRedeliveryAckWait(t *testing.T) { + o := testMQTTDefaultOptions() + o.MQTT.AckWait = 250 * time.Millisecond + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("foo1")) + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 2, []byte("foo2")) + testMQTTDisconnect(t, cp, nil) + cp.Close() + + for i := 0; i < 2; i++ { + flags := mqttPubQos1 + if i > 0 { + flags |= mqttPubFlagDup + } + pi1 := testMQTTCheckPubMsgNoAck(t, c, r, "foo", flags, []byte("foo1")) + pi2 := testMQTTCheckPubMsgNoAck(t, c, r, "foo", flags, []byte("foo2")) + + if pi1 != 1 || pi2 != 2 { + t.Fatalf("Unexpected pi values: %v, %v", pi1, pi2) + } + } + // Ack first message + testMQTTSendPubAck(t, c, 1) + // Redelivery should only be for second message now + for i := 0; i < 2; i++ { + flags := mqttPubQos1 | mqttPubFlagDup + pi := testMQTTCheckPubMsgNoAck(t, c, r, "foo", flags, []byte("foo2")) + if pi != 2 { + t.Fatalf("Unexpected pi to be 2, got %v", pi) + } + } + + // Restart client, should receive second message with pi==2 + c.Close() + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + // Check that message is received with proper pi + pi := testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1|mqttPubFlagDup, []byte("foo2")) + if pi != 2 { + t.Fatalf("Unexpected pi to be 2, got %v", pi) + } + // Now ack second message + testMQTTSendPubAck(t, c, 2) + // Flush to make sure it is processed before checking client's maps + testMQTTFlush(t, c, nil, r) + + // Look for the sub client + mc := testMQTTGetClient(t, s, "sub") + mc.mu.Lock() + sess := mc.mqtt.sess + sess.mu.Lock() + lpi := len(sess.pending) + var lsseq int + for _, sseqToPi := range sess.cpending { + lsseq += len(sseqToPi) + } + sess.mu.Unlock() + mc.mu.Unlock() + if lpi != 0 || lsseq != 0 { + t.Fatalf("Maps should be empty, got %v, %v", lpi, lsseq) + } +} + +func TestMQTTAckWaitConfigChange(t *testing.T) { + o := testMQTTDefaultOptions() + o.MQTT.AckWait = 250 * time.Millisecond + s := testMQTTRunServer(t, o) + defer func() { + testMQTTShutdownServer(s) + }() + + dir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir) + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + + sendMsg := func(topic, payload string) { + t.Helper() + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, topic, 1, []byte(payload)) + testMQTTDisconnect(t, cp, nil) + cp.Close() + } + sendMsg("foo", "msg1") + + for i := 0; i < 2; i++ { + flags := mqttPubQos1 + if i > 0 { + flags |= mqttPubFlagDup + } + testMQTTCheckPubMsgNoAck(t, c, r, "foo", flags, []byte("msg1")) + } + + // Restart the server with a different AckWait option value. + // Verify that MQTT sub restart succeeds. It will keep the + // original value. + c.Close() + s.Shutdown() + + o.Port = -1 + o.MQTT.Port = -1 + o.MQTT.AckWait = 10 * time.Millisecond + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // There is already the defer for shutdown at top of function + + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1|mqttPubFlagDup, []byte("msg1")) + start := time.Now() + testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1|mqttPubFlagDup, []byte("msg1")) + if dur := time.Since(start); dur < 200*time.Millisecond { + t.Fatalf("AckWait seem to have changed for existing subscription: %v", dur) + } + + // Create new subscription + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "bar", qos: 1}}, []byte{1}) + sendMsg("bar", "msg2") + testMQTTCheckPubMsgNoAck(t, c, r, "bar", mqttPubQos1, []byte("msg2")) + start = time.Now() + testMQTTCheckPubMsgNoAck(t, c, r, "bar", mqttPubQos1|mqttPubFlagDup, []byte("msg2")) + if dur := time.Since(start); dur > 50*time.Millisecond { + t.Fatalf("AckWait new value not used by new sub: %v", dur) + } + c.Close() +} + +func TestMQTTUnsubscribeWithPendingAcks(t *testing.T) { + o := testMQTTDefaultOptions() + o.MQTT.AckWait = 250 * time.Millisecond + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg")) + testMQTTDisconnect(t, cp, nil) + cp.Close() + + for i := 0; i < 2; i++ { + flags := mqttPubQos1 + if i > 0 { + flags |= mqttPubFlagDup + } + testMQTTCheckPubMsgNoAck(t, c, r, "foo", flags, []byte("msg")) + } + + testMQTTUnsub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo"}}) + testMQTTFlush(t, c, nil, r) + + mc := testMQTTGetClient(t, s, "sub") + mc.mu.Lock() + sess := mc.mqtt.sess + sess.mu.Lock() + pal := len(sess.pending) + sess.mu.Unlock() + mc.mu.Unlock() + if pal != 0 { + t.Fatalf("Expected pending ack map to be empty, got %v", pal) + } +} + +func TestMQTTMaxAckPending(t *testing.T) { + o := testMQTTDefaultOptions() + o.MQTT.MaxAckPending = 1 + s := testMQTTRunServer(t, o) + defer func() { + testMQTTShutdownServer(s) + }() + + dir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir) + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg1")) + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg2")) + + pi := testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg1")) + // Check that we don't receive the second one due to max ack pending + testMQTTExpectNothing(t, r) + + // Now ack first message + testMQTTSendPubAck(t, c, pi) + // Now we should receive message 2 + testMQTTCheckPubMsg(t, c, r, "foo", mqttPubQos1, []byte("msg2")) + testMQTTDisconnect(t, c, nil) + + // Send 2 messages while sub is offline + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg3")) + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg4")) + + // Restart consumer + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + + // Should receive only message 3 + pi = testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg3")) + testMQTTExpectNothing(t, r) + + // Ack and get the next + testMQTTSendPubAck(t, c, pi) + testMQTTCheckPubMsg(t, c, r, "foo", mqttPubQos1, []byte("msg4")) + + // Check that change to config does not prevent restart of sub. + cp.Close() + c.Close() + s.Shutdown() + + o.Port = -1 + o.MQTT.Port = -1 + o.MQTT.MaxAckPending = 2 + o.StoreDir = dir + s = testMQTTRunServer(t, o) + // There is already the defer for shutdown at top of function + + cp, rp = testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg5")) + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg6")) + + // Restart consumer + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true) + + // Should receive only message 5 + pi = testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg5")) + testMQTTExpectNothing(t, r) + + // Ack and get the next + testMQTTSendPubAck(t, c, pi) + testMQTTCheckPubMsg(t, c, r, "foo", mqttPubQos1, []byte("msg6")) +} + +func TestMQTTMaxAckPendingForMultipleSubs(t *testing.T) { + o := testMQTTDefaultOptions() + o.MQTT.MaxAckPending = 1 + s := testMQTTRunServer(t, o) + defer s.Shutdown() + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "bar", qos: 1}}, []byte{1}) + + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg1")) + pi := testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg1")) + + // Now send a second message but on topic bar + testMQTTPublish(t, cp, rp, 1, false, false, "bar", 1, []byte("msg2")) + + // JS allows us to limit per consumer, but we apply the limit to the + // session, so although JS will attempt to delivery this message, + // the MQTT code will suppress it. + testMQTTExpectNothing(t, r) + + // Ack the first message. + testMQTTSendPubAck(t, c, pi) + + // Now we should get the second message + testMQTTCheckPubMsg(t, c, r, "bar", mqttPubQos1|mqttPubFlagDup, []byte("msg2")) +} + +func TestMQTTConfigReload(t *testing.T) { + template := ` + jetstream: true + mqtt { + port: -1 + ack_wait: %s + max_ack_pending: %s + } + ` + conf := createConfFile(t, []byte(fmt.Sprintf(template, `"5s"`, `10000`))) + defer os.Remove(conf) + + s, o := RunServerWithConfig(conf) + defer s.Shutdown() + + if val := o.MQTT.AckWait; val != 5*time.Second { + t.Fatalf("Invalid ackwait: %v", val) + } + if val := o.MQTT.MaxAckPending; val != 10000 { + t.Fatalf("Invalid ackwait: %v", val) + } + + changeCurrentConfigContentWithNewContent(t, conf, []byte(fmt.Sprintf(template, `"250ms"`, `1`))) + if err := s.Reload(); err != nil { + t.Fatalf("Error on reload: %v", err) + } + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + + cipub := &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg1")) + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg2")) + + testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg1")) + start := time.Now() + testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1|mqttPubFlagDup, []byte("msg1")) + if dur := time.Since(start); dur > 500*time.Millisecond { + t.Fatalf("AckWait not applied? dur=%v", dur) + } + c.Close() + cp.Close() + s.Shutdown() + + changeCurrentConfigContentWithNewContent(t, conf, []byte(fmt.Sprintf(template, `"30s"`, `1`))) + s, o = RunServerWithConfig(conf) + defer s.Shutdown() + + c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{&mqttFilter{filter: "foo", qos: 1}}, []byte{1}) + + cipub = &mqttConnInfo{clientID: "pub", cleanSess: true} + cp, rp = testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port) + defer cp.Close() + testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false) + + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg1")) + testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg2")) + + testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg1")) + testMQTTExpectNothing(t, r) + + // Increate the max ack pending + changeCurrentConfigContentWithNewContent(t, conf, []byte(fmt.Sprintf(template, `"30s"`, `10`))) + // Reload now + if err := s.Reload(); err != nil { + t.Fatalf("Error on reload: %v", err) + } + // See that message 2 can now be received (1 will be redelivered too) + testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1|mqttPubFlagDup, []byte("msg1")) + testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg2")) +} + +// Benchmarks + +const ( + mqttPubSubj = "a" + mqttBenchBufLen = 32768 +) + +func mqttBenchPubQoS0(b *testing.B, subject, payload string, numSubs int) { + b.StopTimer() + o := testMQTTDefaultOptions() + s := RunServer(o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{clientID: "pub", cleanSess: true} + c, br := testMQTTConnect(b, ci, o.MQTT.Host, o.MQTT.Port) + testMQTTCheckConnAck(b, br, mqttConnAckRCConnectionAccepted, false) + w := &mqttWriter{} + mqttWritePublish(w, 0, false, false, subject, 0, []byte(payload)) + sendOp := w.Bytes() + + dch := make(chan error, 1) + totalSize := int64(len(sendOp)) + cdch := 0 + + createSub := func(i int) { + ci := &mqttConnInfo{clientID: fmt.Sprintf("sub%d", i), cleanSess: true} + cs, brs := testMQTTConnect(b, ci, o.MQTT.Host, o.MQTT.Port) + testMQTTCheckConnAck(b, brs, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(b, 1, cs, brs, []*mqttFilter{&mqttFilter{filter: subject, qos: 0}}, []byte{0}) + testMQTTFlush(b, cs, nil, brs) + + w := &mqttWriter{} + varHeaderAndPayload := 2 + len(subject) + len(payload) + w.WriteVarInt(varHeaderAndPayload) + size := 1 + w.Len() + varHeaderAndPayload + totalSize += int64(size) + + go func() { + mqttBenchConsumeMsgQoS0(cs, int64(b.N)*int64(size), dch) + cs.Close() + }() + } + for i := 0; i < numSubs; i++ { + createSub(i + 1) + cdch++ + } + + bw := bufio.NewWriterSize(c, mqttBenchBufLen) + b.SetBytes(totalSize) + b.StartTimer() + for i := 0; i < b.N; i++ { + bw.Write(sendOp) + } + testMQTTFlush(b, c, bw, br) + for i := 0; i < cdch; i++ { + if e := <-dch; e != nil { + b.Fatal(e.Error()) + } + } + b.StopTimer() + c.Close() + s.Shutdown() +} + +func mqttBenchConsumeMsgQoS0(c net.Conn, total int64, dch chan<- error) { + var buf [mqttBenchBufLen]byte + var err error + var n int + for size := int64(0); size < total; { + n, err = c.Read(buf[:]) + if err != nil { + break + } + size += int64(n) + } + dch <- err +} + +func mqttBenchPubQoS1(b *testing.B, subject, payload string, numSubs int) { + b.StopTimer() + o := testMQTTDefaultOptions() + o.MQTT.MaxAckPending = 0xFFFF + s := RunServer(o) + defer testMQTTShutdownServer(s) + + ci := &mqttConnInfo{cleanSess: true} + c, br := testMQTTConnect(b, ci, o.MQTT.Host, o.MQTT.Port) + testMQTTCheckConnAck(b, br, mqttConnAckRCConnectionAccepted, false) + + w := &mqttWriter{} + mqttWritePublish(w, 1, false, false, subject, 1, []byte(payload)) + // For reported bytes we will count the PUBLISH + PUBACK (4 bytes) + totalSize := int64(len(w.Bytes()) + 4) + w.Reset() + + pi := uint16(1) + maxpi := uint16(60000) + ppich := make(chan error, 10) + dch := make(chan error, 1+numSubs) + cdch := 1 + // Start go routine to consume PUBACK for published QoS 1 messages. + go mqttBenchConsumePubAck(c, b.N, dch, ppich) + + createSub := func(i int) { + ci := &mqttConnInfo{clientID: fmt.Sprintf("sub%d", i), cleanSess: true} + cs, brs := testMQTTConnect(b, ci, o.MQTT.Host, o.MQTT.Port) + testMQTTCheckConnAck(b, brs, mqttConnAckRCConnectionAccepted, false) + + testMQTTSub(b, 1, cs, brs, []*mqttFilter{&mqttFilter{filter: subject, qos: 1}}, []byte{1}) + testMQTTFlush(b, cs, nil, brs) + + w := &mqttWriter{} + varHeaderAndPayload := 2 + len(subject) + 2 + len(payload) + w.WriteVarInt(varHeaderAndPayload) + size := 1 + w.Len() + varHeaderAndPayload + // Add to the bytes reported the size of message sent to subscriber + PUBACK (4 bytes) + totalSize += int64(size + 4) + + go func() { + mqttBenchConsumeMsgQos1(cs, b.N, size, dch) + cs.Close() + }() + } + for i := 0; i < numSubs; i++ { + createSub(i + 1) + cdch++ + } + + flush := func() { + b.Helper() + if _, err := c.Write(w.Bytes()); err != nil { + b.Fatalf("Error on write: %v", err) + } + w.Reset() + } + + b.SetBytes(totalSize) + b.StartTimer() + for i := 0; i < b.N; i++ { + if pi <= maxpi { + mqttWritePublish(w, 1, false, false, subject, pi, []byte(payload)) + pi++ + if w.Len() >= mqttBenchBufLen { + flush() + } + } else { + if w.Len() > 0 { + flush() + } + if pi > 60000 { + pi = 1 + maxpi = 0 + } + if e := <-ppich; e != nil { + b.Fatal(e.Error()) + } + maxpi += 10000 + i-- + } + } + if w.Len() > 0 { + flush() + } + for i := 0; i < cdch; i++ { + if e := <-dch; e != nil { + b.Fatal(e.Error()) + } + } + b.StopTimer() + c.Close() + s.Shutdown() +} + +func mqttBenchConsumeMsgQos1(c net.Conn, total, size int, dch chan<- error) { + var buf [mqttBenchBufLen]byte + pubAck := [4]byte{mqttPacketPubAck, 0x2, 0, 0} + var err error + var n int + var pi uint16 + var prev int + for i := 0; i < total; { + n, err = c.Read(buf[:]) + if err != nil { + break + } + n += prev + for ; n >= size; n -= size { + i++ + pi++ + pubAck[2] = byte(pi >> 8) + pubAck[3] = byte(pi) + if _, err = c.Write(pubAck[:4]); err != nil { + dch <- err + return + } + if pi == 60000 { + pi = 0 + } + } + prev = n + } + dch <- err +} + +func mqttBenchConsumePubAck(c net.Conn, total int, dch, ppich chan<- error) { + var buf [mqttBenchBufLen]byte + var err error + var n int + var pi uint16 + var prev int + for i := 0; i < total; { + n, err = c.Read(buf[:]) + if err != nil { + break + } + n += prev + for ; n >= 4; n -= 4 { + i++ + pi++ + if pi%10000 == 0 { + ppich <- nil + } + if pi == 60001 { + pi = 0 + } + } + prev = n + } + ppich <- err + dch <- err +} + +func BenchmarkMQTT_QoS0_Pub_______0b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, "", 0) +} + +func BenchmarkMQTT_QoS0_Pub_______8b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(8), 0) +} + +func BenchmarkMQTT_QoS0_Pub______32b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(32), 0) +} + +func BenchmarkMQTT_QoS0_Pub_____128b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(128), 0) +} + +func BenchmarkMQTT_QoS0_Pub_____256b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(256), 0) +} + +func BenchmarkMQTT_QoS0_Pub_______1K_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(1024), 0) +} + +func BenchmarkMQTT_QoS0_PubSub1___0b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, "", 1) +} + +func BenchmarkMQTT_QoS0_PubSub1___8b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(8), 1) +} + +func BenchmarkMQTT_QoS0_PubSub1__32b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(32), 1) +} + +func BenchmarkMQTT_QoS0_PubSub1_128b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(128), 1) +} + +func BenchmarkMQTT_QoS0_PubSub1_256b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(256), 1) +} + +func BenchmarkMQTT_QoS0_PubSub1___1K_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(1024), 1) +} + +func BenchmarkMQTT_QoS0_PubSub2___0b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, "", 2) +} + +func BenchmarkMQTT_QoS0_PubSub2___8b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(8), 2) +} + +func BenchmarkMQTT_QoS0_PubSub2__32b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(32), 2) +} + +func BenchmarkMQTT_QoS0_PubSub2_128b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(128), 2) +} + +func BenchmarkMQTT_QoS0_PubSub2_256b_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(256), 2) +} + +func BenchmarkMQTT_QoS0_PubSub2___1K_Payload(b *testing.B) { + mqttBenchPubQoS0(b, mqttPubSubj, sizedString(1024), 2) +} + +func BenchmarkMQTT_QoS1_Pub_______0b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, "", 0) +} + +func BenchmarkMQTT_QoS1_Pub_______8b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(8), 0) +} + +func BenchmarkMQTT_QoS1_Pub______32b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(32), 0) +} + +func BenchmarkMQTT_QoS1_Pub_____128b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(128), 0) +} + +func BenchmarkMQTT_QoS1_Pub_____256b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(256), 0) +} + +func BenchmarkMQTT_QoS1_Pub_______1K_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(1024), 0) +} + +func BenchmarkMQTT_QoS1_PubSub1___0b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, "", 1) +} + +func BenchmarkMQTT_QoS1_PubSub1___8b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(8), 1) +} + +func BenchmarkMQTT_QoS1_PubSub1__32b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(32), 1) +} + +func BenchmarkMQTT_QoS1_PubSub1_128b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(128), 1) +} + +func BenchmarkMQTT_QoS1_PubSub1_256b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(256), 1) +} + +func BenchmarkMQTT_QoS1_PubSub1___1K_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(1024), 1) +} + +func BenchmarkMQTT_QoS1_PubSub2___0b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, "", 2) +} + +func BenchmarkMQTT_QoS1_PubSub2___8b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(8), 2) +} + +func BenchmarkMQTT_QoS1_PubSub2__32b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(32), 2) +} + +func BenchmarkMQTT_QoS1_PubSub2_128b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(128), 2) +} + +func BenchmarkMQTT_QoS1_PubSub2_256b_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(256), 2) +} + +func BenchmarkMQTT_QoS1_PubSub2___1K_Payload(b *testing.B) { + mqttBenchPubQoS1(b, mqttPubSubj, sizedString(1024), 2) +} diff --git a/server/opts.go b/server/opts.go index 32802133..3f44538a 100644 --- a/server/opts.go +++ b/server/opts.go @@ -197,6 +197,7 @@ type Options struct { JetStreamMaxStore int64 `json:"-"` StoreDir string `json:"-"` Websocket WebsocketOpts `json:"-"` + MQTT MQTTOpts `json:"-"` ProfPort int `json:"-"` PidFile string `json:"-"` PortsFileDir string `json:"-"` @@ -256,7 +257,7 @@ type Options struct { routeProto int } -// WebsocketOpts ... +// WebsocketOpts are options for websocket type WebsocketOpts struct { // The server will accept websocket client connections on this hostname/IP. Host string @@ -316,6 +317,49 @@ type WebsocketOpts struct { HandshakeTimeout time.Duration } +// MQTTOpts are options for MQTT +type MQTTOpts struct { + // The server will accept MQTT client connections on this hostname/IP. + Host string + // The server will accept MQTT client connections on this port. + Port int + + // If no user name is provided when a client connects, will default to the + // matching user from the global list of users in `Options.Users`. + NoAuthUser string + + // Authentication section. If anything is configured in this section, + // it will override the authorization configuration of regular clients. + Username string + Password string + Token string + + // Timeout for the authentication process. + AuthTimeout float64 + + // TLS configuration is required. + TLSConfig *tls.Config + // If true, map certificate values for authentication purposes. + TLSMap bool + // Timeout for the TLS handshake + TLSTimeout float64 + + // AckWait is the amount of time after which a QoS 1 message sent to + // a client is redelivered as a DUPLICATE if the server has not + // received the PUBACK on the original Packet Identifier. + // The value has to be positive. + // Zero will cause the server to use the default value (1 hour). + // Note that changes to this option is applied only to new MQTT subscriptions. + AckWait time.Duration + + // MaxAckPending is the amount of QoS 1 messages the server can send to + // a session without receiving any PUBACK for those messages. + // The valid range is [0..65535]. + // Zero will cause the server to use the default value (1024). + // Note that changes to this option is applied only to new MQTT sessions. + MaxAckPending uint16 +} + type netResolver interface { LookupHost(ctx context.Context, host string) ([]string, error) } @@ -1011,6 +1055,11 @@ func (o *Options) processConfigFileLine(k string, v interface{}, errors *[]error *errors = append(*errors, err) return } + case "mqtt": + if err := parseMQTT(tk, o, errors, warnings); err != nil { + *errors = append(*errors, err) + return + } default: if au := atomic.LoadInt32(&allowUnknownTopLevelField); au == 0 && !tk.IsUsedVariable() { err := &unknownConfigFieldErr{ @@ -3327,7 +3376,7 @@ func parseTLS(v interface{}, isClientCtx bool) (t *TLSConfigOpts, retErr error) return &tc, nil } -func parseAuthForWS(v interface{}, errors *[]error, warnings *[]error) *authorization { +func parseSimpleAuth(v interface{}, errors *[]error, warnings *[]error) *authorization { var ( am map[string]interface{} tk token @@ -3463,7 +3512,7 @@ func parseWebsocket(v interface{}, o *Options, errors *[]error, warnings *[]erro case "compression": o.Websocket.Compression = mv.(bool) case "authorization", "authentication": - auth := parseAuthForWS(tk, errors, warnings) + auth := parseSimpleAuth(tk, errors, warnings) o.Websocket.Username = auth.user o.Websocket.Password = auth.pass o.Websocket.Token = auth.token @@ -3488,6 +3537,79 @@ func parseWebsocket(v interface{}, o *Options, errors *[]error, warnings *[]erro return nil } +func parseMQTT(v interface{}, o *Options, errors *[]error, warnings *[]error) error { + var lt token + defer convertPanicToErrorList(<, errors) + + tk, v := unwrapValue(v, <) + gm, ok := v.(map[string]interface{}) + if !ok { + return &configErr{tk, fmt.Sprintf("Expected mqtt to be a map, got %T", v)} + } + for mk, mv := range gm { + // Again, unwrap token value if line check is required. + tk, mv = unwrapValue(mv, <) + switch strings.ToLower(mk) { + case "listen": + hp, err := parseListen(mv) + if err != nil { + err := &configErr{tk, err.Error()} + *errors = append(*errors, err) + continue + } + o.MQTT.Host = hp.host + o.MQTT.Port = hp.port + case "port": + o.MQTT.Port = int(mv.(int64)) + case "host", "net": + o.MQTT.Host = mv.(string) + case "tls": + tc, err := parseTLS(tk, true) + if err != nil { + *errors = append(*errors, err) + continue + } + if o.MQTT.TLSConfig, err = GenTLSConfig(tc); err != nil { + err := &configErr{tk, err.Error()} + *errors = append(*errors, err) + continue + } + o.MQTT.TLSTimeout = tc.Timeout + o.MQTT.TLSMap = tc.Map + case "authorization", "authentication": + auth := parseSimpleAuth(tk, errors, warnings) + o.MQTT.Username = auth.user + o.MQTT.Password = auth.pass + o.MQTT.Token = auth.token + o.MQTT.AuthTimeout = auth.timeout + case "no_auth_user": + o.MQTT.NoAuthUser = mv.(string) + case "ack_wait", "ackwait": + o.MQTT.AckWait = parseDuration("ack_wait", tk, mv, errors, warnings) + case "max_ack_pending", "max_pending", "max_inflight": + tmp := int(mv.(int64)) + if tmp < 0 || tmp > 0xFFFF { + err := &configErr{tk, fmt.Sprintf("invalid value %v, should in [0..%d] range", tmp, 0xFFFF)} + *errors = append(*errors, err) + } else { + o.MQTT.MaxAckPending = uint16(tmp) + } + default: + if !tk.IsUsedVariable() { + err := &unknownConfigFieldErr{ + field: mk, + configErr: configErr{ + token: tk, + }, + } + *errors = append(*errors, err) + continue + } + } + } + return nil +} + // GenTLSConfig loads TLS related configuration parameters. func GenTLSConfig(tc *TLSConfigOpts) (*tls.Config, error) { // Create the tls.Config from our options before including the certs. @@ -3831,6 +3953,14 @@ func setBaselineOptions(opts *Options) { opts.Websocket.Host = DEFAULT_HOST } } + if opts.MQTT.Port != 0 { + if opts.MQTT.Host == "" { + opts.MQTT.Host = DEFAULT_HOST + } + if opts.MQTT.TLSTimeout == 0 { + opts.MQTT.TLSTimeout = float64(TLS_TIMEOUT) / float64(time.Second) + } + } // JetStream if opts.JetStreamMaxMemory == 0 { opts.JetStreamMaxMemory = -1 diff --git a/server/parser.go b/server/parser.go index e9492b2f..f76337df 100644 --- a/server/parser.go +++ b/server/parser.go @@ -130,6 +130,11 @@ const ( ) func (c *client) parse(buf []byte) error { + // Branch out to mqtt clients. c.mqtt is immutable, but should it become + // an issue (say data race detection), we could branch outside in readLoop + if c.mqtt != nil { + return c.mqttParse(buf) + } var i int var b byte var lmsg bool diff --git a/server/reload.go b/server/reload.go index 14fa386e..790bb80a 100644 --- a/server/reload.go +++ b/server/reload.go @@ -596,6 +596,25 @@ func (m *maxTracedMsgLenOption) Apply(server *Server) { server.Noticef("Reloaded: max_traced_msg_len = %d", m.newValue) } +type mqttAckWaitReload struct { + noopOption + newValue time.Duration +} + +func (o *mqttAckWaitReload) Apply(s *Server) { + s.Noticef("Reloaded: MQTT ack_wait = %v", o.newValue) +} + +type mqttMaxAckPendingReload struct { + noopOption + newValue uint16 +} + +func (o *mqttMaxAckPendingReload) Apply(s *Server) { + s.mqttUpdateMaxAckPending(o.newValue) + s.Noticef("Reloaded: MQTT max_ack_pending = %v", o.newValue) +} + // Reload reads the current configuration file and applies any supported // changes. This returns an error if the server was not started with a config // file or an option which doesn't support hot-swapping was changed. @@ -633,6 +652,7 @@ func (s *Server) Reload() error { gatewayOrgPort := curOpts.Gateway.Port leafnodesOrgPort := curOpts.LeafNode.Port websocketOrgPort := curOpts.Websocket.Port + mqttOrgPort := curOpts.MQTT.Port s.mu.Unlock() @@ -665,6 +685,9 @@ func (s *Server) Reload() error { if newOpts.Websocket.Port == -1 { newOpts.Websocket.Port = websocketOrgPort } + if newOpts.MQTT.Port == -1 { + newOpts.MQTT.Port = mqttOrgPort + } if err := s.reloadOptions(curOpts, newOpts); err != nil { return err @@ -766,7 +789,7 @@ func imposeOrder(value interface{}) error { case WebsocketOpts: sort.Strings(value.AllowedOrigins) case string, bool, int, int32, int64, time.Duration, float64, nil, - LeafNodeOpts, ClusterOpts, *tls.Config, *URLAccResolver, *MemAccResolver, *DirAccResolver, *CacheDirAccResolver, Authentication: + LeafNodeOpts, ClusterOpts, *tls.Config, *URLAccResolver, *MemAccResolver, *DirAccResolver, *CacheDirAccResolver, Authentication, MQTTOpts: // explicitly skipped types default: // this will fail during unit tests @@ -973,6 +996,20 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) { return nil, fmt.Errorf("config reload not supported for %s: old=%v, new=%v", field.Name, oldValue, newValue) } + case "mqtt": + diffOpts = append(diffOpts, &mqttAckWaitReload{newValue: newValue.(MQTTOpts).AckWait}) + diffOpts = append(diffOpts, &mqttMaxAckPendingReload{newValue: newValue.(MQTTOpts).MaxAckPending}) + // Nil out/set to 0 the options that we allow to be reloaded so that + // we only fail reload if some that we don't support are changed. + tmpOld := oldValue.(MQTTOpts) + tmpNew := newValue.(MQTTOpts) + tmpOld.TLSConfig, tmpOld.AckWait, tmpOld.MaxAckPending = nil, 0, 0 + tmpNew.TLSConfig, tmpNew.AckWait, tmpNew.MaxAckPending = nil, 0, 0 + if !reflect.DeepEqual(tmpOld, tmpNew) { + // See TODO(ik) note below about printing old/new values. + return nil, fmt.Errorf("config reload not supported for %s: old=%v, new=%v", + field.Name, oldValue, newValue) + } case "connecterrorreports": diffOpts = append(diffOpts, &connectErrorReports{newValue: newValue.(int)}) case "reconnecterrorreports": @@ -1233,6 +1270,8 @@ func (s *Server) reloadAuthorization() { // can't hold the lock as go routine reading it may be waiting for lock as well resetCh = s.sys.resetCh } + // Check that publish retained messages sources are still allowed to publish. + s.mqttCheckPubRetainedPerms() s.mu.Unlock() if resetCh != nil { diff --git a/server/server.go b/server/server.go index 455457a7..59767929 100644 --- a/server/server.go +++ b/server/server.go @@ -222,6 +222,9 @@ type Server struct { // Websocket structure websocket srvWebsocket + // MQTT structure + mqtt srvMQTT + // exporting account name the importer experienced issues with incompleteAccExporterMap sync.Map @@ -533,6 +536,9 @@ func validateOptions(o *Options) error { if err := validateClusterName(o); err != nil { return err } + if err := validateMQTTOptions(o); err != nil { + return err + } // Finally check websocket options. return validateWebsocketOptions(o) } @@ -1516,6 +1522,11 @@ func (s *Server) Start() { s.startWebsocketServer() } + // MQTT + if opts.MQTT.Port != 0 { + s.startMQTT() + } + // Start up routing as well if needed. if opts.Cluster.Port != 0 { s.startGoRoutine(func() { @@ -1611,6 +1622,13 @@ func (s *Server) Shutdown() { s.websocket.listener = nil } + // Kick MQTT accept loop + if s.mqtt.listener != nil { + doneExpected++ + s.mqtt.listener.Close() + s.mqtt.listener = nil + } + // Kick leafnodes AcceptLoop() if s.leafNodeListener != nil { doneExpected++ @@ -1747,7 +1765,7 @@ func (s *Server) AcceptLoop(clr chan struct{}) { s.clientConnectURLs = s.getClientConnectURLs() s.listener = l - go s.acceptConnections(l, "Client", func(conn net.Conn) { s.createClient(conn, nil) }, + go s.acceptConnections(l, "Client", func(conn net.Conn) { s.createClient(conn, nil, nil) }, func(_ error) bool { if s.isLameDuckMode() { // Signal that we are not accepting new clients @@ -2073,7 +2091,7 @@ func (c *tlsMixConn) Read(b []byte) (int, error) { return c.Conn.Read(b) } -func (s *Server) createClient(conn net.Conn, ws *websocket) *client { +func (s *Server) createClient(conn net.Conn, ws *websocket, mqtt *mqtt) *client { // Snapshot server options. opts := s.getOpts() @@ -2086,32 +2104,49 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { now := time.Now() c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: maxPay, msubs: maxSubs, start: now, last: now, ws: ws} + if mqtt != nil { + c.mqtt = mqtt + // Set some of the options here since MQTT clients don't + // send a regular CONNECT (but have their own). + c.opts.Lang = "mqtt" + c.opts.Verbose = false + } c.registerWithAccount(s.globalAccount()) - // Grab JSON info string + var info Info + var authRequired bool + s.mu.Lock() - info := s.copyInfo() - // If this is a websocket client and there is no top-level auth specified, - // then we use the websocket's specific boolean that will be set to true - // if there is any auth{} configured in websocket{}. - if ws != nil && !info.AuthRequired { - info.AuthRequired = s.websocket.authOverride + // We don't need the INFO to mqtt clients. + if mqtt == nil { + // Grab JSON info string + info = s.copyInfo() + // If this is a websocket client and there is no top-level auth specified, + // then we use the websocket's specific boolean that will be set to true + // if there is any auth{} configured in websocket{}. + if ws != nil && !info.AuthRequired { + info.AuthRequired = s.websocket.authOverride + } + if s.nonceRequired() { + // Nonce handling + var raw [nonceLen]byte + nonce := raw[:] + s.generateNonce(nonce) + info.Nonce = string(nonce) + } + c.nonce = []byte(info.Nonce) + authRequired = info.AuthRequired + } else { + authRequired = s.info.AuthRequired || s.mqtt.authOverride } - if s.nonceRequired() { - // Nonce handling - var raw [nonceLen]byte - nonce := raw[:] - s.generateNonce(nonce) - info.Nonce = string(nonce) - } - c.nonce = []byte(info.Nonce) + s.totalClients++ s.mu.Unlock() // Grab lock c.mu.Lock() - if info.AuthRequired { + if authRequired { c.flags.set(expectConnect) } @@ -2120,10 +2155,13 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { c.Debugf("Client connection created") - // Send our information. - // Need to be sent in place since writeLoop cannot be started until - // TLS handshake is done (if applicable). - c.sendProtoNow(c.generateClientInfoJSON(info)) + // We don't send the INFO to mqtt clients. + if mqtt == nil { + // Send our information. + // Need to be sent in place since writeLoop cannot be started until + // TLS handshake is done (if applicable). + c.sendProtoNow(c.generateClientInfoJSON(info)) + } // Unlock to register c.mu.Unlock() @@ -2155,6 +2193,22 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { return nil } s.clients[c.cid] = c + + // May be overridden based on type of client. + TLSConfig := opts.TLSConfig + TLSTimeout := opts.TLSTimeout + + tlsRequired := info.TLSRequired + // Websocket clients do TLS in the websocket http server. + if ws != nil { + tlsRequired = false + } else if mqtt != nil { + tlsRequired = opts.MQTT.TLSConfig != nil + if tlsRequired { + TLSConfig = opts.MQTT.TLSConfig + TLSTimeout = opts.MQTT.TLSTimeout + } + } s.mu.Unlock() // Re-Grab lock @@ -2163,13 +2217,12 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { // Connection could have been closed while sending the INFO proto. isClosed := c.isClosed() - tlsRequired := ws == nil && info.TLSRequired var pre []byte // If we have both TLS and non-TLS allowed we need to see which // one the client wants. if !isClosed && opts.TLSConfig != nil && opts.AllowNonTLS { pre = make([]byte, 4) - c.nc.SetReadDeadline(time.Now().Add(secondsToDuration(opts.TLSTimeout))) + c.nc.SetReadDeadline(time.Now().Add(secondsToDuration(TLSTimeout))) n, _ := io.ReadFull(c.nc, pre[:]) c.nc.SetReadDeadline(time.Time{}) pre = pre[:n] @@ -2190,11 +2243,11 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { pre = nil } - c.nc = tls.Server(c.nc, opts.TLSConfig) + c.nc = tls.Server(c.nc, TLSConfig) conn := c.nc.(*tls.Conn) // Setup the timeout - ttl := secondsToDuration(opts.TLSTimeout) + ttl := secondsToDuration(TLSTimeout) time.AfterFunc(ttl, func() { tlsTimeout(c, conn) }) conn.SetReadDeadline(time.Now().Add(ttl)) @@ -2231,7 +2284,7 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { // Check for Auth. We schedule this timer after the TLS handshake to avoid // the race where the timer fires during the handshake and causes the // server to write bad data to the socket. See issue #432. - if info.AuthRequired { + if authRequired { timeout := opts.AuthTimeout // For websocket, possibly override only if set. We make sure that // opts.AuthTimeout is set to a default value if not configured, @@ -2239,6 +2292,8 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { // if user has explicitly set or not. if ws != nil && opts.Websocket.AuthTimeout != 0 { timeout = opts.Websocket.AuthTimeout + } else if mqtt != nil && opts.MQTT.AuthTimeout != 0 { + timeout = opts.MQTT.AuthTimeout } c.setAuthTimer(secondsToDuration(timeout)) } @@ -2421,7 +2476,11 @@ func (s *Server) removeClient(c *client) { if updateProtoInfoCount { s.cproto-- } + mqtt := c.mqtt != nil s.mu.Unlock() + if mqtt { + s.mqttHandleClosedClient(c) + } case ROUTER: s.removeRoute(c) case GATEWAY: diff --git a/server/server_test.go b/server/server_test.go index 1114315b..b546905f 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1245,6 +1245,8 @@ func TestServerShutdownDuringStart(t *testing.T) { o.Websocket.Port = -1 o.Websocket.HandshakeTimeout = 1 o.Websocket.NoTLS = true + o.MQTT.Host = "127.0.0.1" + o.MQTT.Port = -1 // We are going to test that if the server is shutdown // while Start() runs (in this case, before), we don't @@ -1286,6 +1288,9 @@ func TestServerShutdownDuringStart(t *testing.T) { if s.websocket.listener != nil { listeners = append(listeners, "websocket") } + if s.mqtt.listener != nil { + listeners = append(listeners, "mqtt") + } s.mu.Unlock() if len(listeners) > 0 { lst := "" diff --git a/server/stream.go b/server/stream.go index 729c1db4..d709767a 100644 --- a/server/stream.go +++ b/server/stream.go @@ -27,6 +27,7 @@ import ( "path/filepath" "reflect" "strconv" + "strings" "sync" "time" @@ -95,6 +96,7 @@ type Stream struct { ddarr []*ddentry ddindex int ddtmr *time.Timer + nosubj bool } // Headers for published messages. @@ -125,13 +127,20 @@ func (a *Account) AddStream(config *StreamConfig) (*Stream, error) { // AddStreamWithStore adds a stream for the given account with custome store config options. func (a *Account) AddStreamWithStore(config *StreamConfig, fsConfig *FileStoreConfig) (*Stream, error) { + if strings.HasPrefix(config.Name, mqttStreamNamePrefix) { + return nil, fmt.Errorf("prefix %q is reserved for MQTT, unable to create stream %q", mqttStreamNamePrefix, config.Name) + } + return a.addStreamWithStore(config, fsConfig, false) +} + +func (a *Account) addStreamWithStore(config *StreamConfig, fsConfig *FileStoreConfig, noSubjectsOK bool) (*Stream, error) { s, jsa, err := a.checkForJetStream() if err != nil { return nil, err } // Sensible defaults. - cfg, err := checkStreamCfg(config) + cfg, err := checkStreamCfg(config, noSubjectsOK) if err != nil { return nil, err } @@ -168,7 +177,7 @@ func (a *Account) AddStreamWithStore(config *StreamConfig, fsConfig *FileStoreCo // Setup the internal client. c := s.createInternalJetStreamClient() - mset := &Stream{jsa: jsa, config: cfg, client: c, consumers: make(map[string]*Consumer)} + mset := &Stream{jsa: jsa, config: cfg, client: c, consumers: make(map[string]*Consumer), nosubj: noSubjectsOK} jsa.streams[cfg.Name] = mset storeDir := path.Join(jsa.storeDir, streamsDir, cfg.Name) @@ -416,7 +425,7 @@ func (jsa *jsAccount) subjectsOverlap(subjects []string) bool { // Default duplicates window. const StreamDefaultDuplicatesWindow = 2 * time.Minute -func checkStreamCfg(config *StreamConfig) (StreamConfig, error) { +func checkStreamCfg(config *StreamConfig, noSubjectOk bool) (StreamConfig, error) { if config == nil { return StreamConfig{}, fmt.Errorf("stream configuration invalid") } @@ -466,7 +475,9 @@ func checkStreamCfg(config *StreamConfig) (StreamConfig, error) { } if len(cfg.Subjects) == 0 { - cfg.Subjects = append(cfg.Subjects, cfg.Name) + if !noSubjectOk { + cfg.Subjects = append(cfg.Subjects, cfg.Name) + } } else { // We can allow overlaps, but don't allow direct duplicates. dset := make(map[string]struct{}, len(cfg.Subjects)) @@ -519,7 +530,10 @@ func (mset *Stream) Delete() error { // Update will allow certain configuration properties of an existing stream to be updated. func (mset *Stream) Update(config *StreamConfig) error { - cfg, err := checkStreamCfg(config) + mset.mu.RLock() + nosubj := mset.nosubj + mset.mu.RUnlock() + cfg, err := checkStreamCfg(config, nosubj) if err != nil { return err } @@ -691,7 +705,7 @@ func (mset *Stream) subscribeInternal(subject string, cb msgHandler) (*subscript mset.sid++ // Now create the subscription - return c.processSub([]byte(subject), nil, []byte(strconv.Itoa(mset.sid)), cb, false) + return c.processSub(c.createSub([]byte(subject), nil, []byte(strconv.Itoa(mset.sid)), cb), false) } // Helper for unlocked stream. @@ -901,7 +915,7 @@ func getExpectedLastSeq(hdr []byte) uint64 { } // processInboundJetStreamMsg handles processing messages bound for a stream. -func (mset *Stream) processInboundJetStreamMsg(_ *subscription, pc *client, subject, reply string, msg []byte) { +func (mset *Stream) processInboundJetStreamMsg(sub *subscription, pc *client, subject, reply string, msg []byte) { mset.mu.Lock() store := mset.store c := mset.client diff --git a/server/sublist.go b/server/sublist.go index 1c860a83..fcb0c28d 100644 --- a/server/sublist.go +++ b/server/sublist.go @@ -1392,3 +1392,65 @@ func (s *Sublist) collectAllSubs(l *level, subs *[]*subscription) { s.collectAllSubs(l.fwc.next, subs) } } + +func (s *Sublist) ReverseMatch(subject string) *SublistResult { + tsa := [32]string{} + tokens := tsa[:0] + start := 0 + for i := 0; i < len(subject); i++ { + if subject[i] == btsep { + tokens = append(tokens, subject[start:i]) + start = i + 1 + } + } + tokens = append(tokens, subject[start:]) + + result := &SublistResult{} + + s.Lock() + reverseMatchLevel(s.root, tokens, nil, result) + // Check for empty result. + if len(result.psubs) == 0 && len(result.qsubs) == 0 { + result = emptyResult + } + s.Unlock() + + return result +} + +func reverseMatchLevel(l *level, toks []string, n *node, results *SublistResult) { + for i, t := range toks { + if l == nil { + return + } + if len(t) == 1 { + if t[0] == fwc { + getAllNodes(l, results) + return + } else if t[0] == pwc { + for _, n := range l.nodes { + reverseMatchLevel(n.next, toks[i+1:], n, results) + } + return + } + } + n = l.nodes[t] + if n == nil { + break + } + l = n.next + } + if n != nil { + addNodeToResults(n, results) + } +} + +func getAllNodes(l *level, results *SublistResult) { + if l == nil { + return + } + for _, n := range l.nodes { + addNodeToResults(n, results) + getAllNodes(n.next, results) + } +} diff --git a/server/sublist_test.go b/server/sublist_test.go index 43a2ae84..eb35e1ea 100644 --- a/server/sublist_test.go +++ b/server/sublist_test.go @@ -1258,6 +1258,74 @@ func TestSublistRegisterInterestNotification(t *testing.T) { expectFalse() } +func TestSublistReverseMatch(t *testing.T) { + s := NewSublistWithCache() + fooSub := newSub("foo") + barSub := newSub("bar") + fooBarSub := newSub("foo.bar") + fooBazSub := newSub("foo.baz") + fooBarBazSub := newSub("foo.bar.baz") + s.Insert(fooSub) + s.Insert(barSub) + s.Insert(fooBarSub) + s.Insert(fooBazSub) + s.Insert(fooBarBazSub) + + r := s.ReverseMatch("foo") + verifyLen(r.psubs, 1, t) + verifyMember(r.psubs, fooSub, t) + + r = s.ReverseMatch("bar") + verifyLen(r.psubs, 1, t) + verifyMember(r.psubs, barSub, t) + + r = s.ReverseMatch("*") + verifyLen(r.psubs, 2, t) + verifyMember(r.psubs, fooSub, t) + verifyMember(r.psubs, barSub, t) + + r = s.ReverseMatch("baz") + verifyLen(r.psubs, 0, t) + + r = s.ReverseMatch("foo.*") + verifyLen(r.psubs, 2, t) + verifyMember(r.psubs, fooBarSub, t) + verifyMember(r.psubs, fooBazSub, t) + + r = s.ReverseMatch("*.*") + verifyLen(r.psubs, 2, t) + verifyMember(r.psubs, fooBarSub, t) + verifyMember(r.psubs, fooBazSub, t) + + r = s.ReverseMatch("*.bar") + verifyLen(r.psubs, 1, t) + verifyMember(r.psubs, fooBarSub, t) + + r = s.ReverseMatch("*.baz") + verifyLen(r.psubs, 1, t) + verifyMember(r.psubs, fooBazSub, t) + + r = s.ReverseMatch("bar.*") + verifyLen(r.psubs, 0, t) + + r = s.ReverseMatch("*.bat") + verifyLen(r.psubs, 0, t) + + r = s.ReverseMatch("foo.>") + verifyLen(r.psubs, 3, t) + verifyMember(r.psubs, fooBarSub, t) + verifyMember(r.psubs, fooBazSub, t) + verifyMember(r.psubs, fooBarBazSub, t) + + r = s.ReverseMatch(">") + verifyLen(r.psubs, 5, t) + verifyMember(r.psubs, fooSub, t) + verifyMember(r.psubs, barSub, t) + verifyMember(r.psubs, fooBarSub, t) + verifyMember(r.psubs, fooBazSub, t) + verifyMember(r.psubs, fooBarBazSub, t) +} + // -- Benchmarks Setup -- var benchSublistSubs []*subscription diff --git a/server/websocket.go b/server/websocket.go index 7b10877d..10defdc4 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -856,6 +856,7 @@ func (s *Server) startWebsocketServer() { if port == 0 { s.opts.Websocket.Port = hl.Addr().(*net.TCPAddr).Port } + s.Noticef("Listening for websocket clients on %s://%s:%d", proto, o.Host, port) s.websocket.connectURLs, err = s.getConnectURLs(o.Advertise, o.Host, o.Port) if err != nil { s.Fatalf("Unable to get websocket connect URLs: %v", err) @@ -870,7 +871,7 @@ func (s *Server) startWebsocketServer() { s.Errorf(err.Error()) return } - s.createClient(res.conn, res.ws) + s.createClient(res.conn, res.ws, nil) }) hs := &http.Server{ Addr: hp, diff --git a/server/websocket_test.go b/server/websocket_test.go index fa5e235b..e0b075fb 100644 --- a/server/websocket_test.go +++ b/server/websocket_test.go @@ -2095,6 +2095,10 @@ func TestWSTLSVerifyAndMap(t *testing.T) { {"no filtering, client does not provide cert", false, false}, {"filtering, client provides cert", true, true}, {"filtering, client does not provide cert", true, false}, + {"no users override, client provides cert", false, true}, + {"no users override, client does not provide cert", false, false}, + {"users override, client provides cert", true, true}, + {"users override, client does not provide cert", true, false}, } { t.Run(test.name, func(t *testing.T) { o := testWSOptions() From ac4890acba29adea4956915ae55a9affd8473245 Mon Sep 17 00:00:00 2001 From: Ivan Kozlovic Date: Mon, 30 Nov 2020 15:42:50 -0700 Subject: [PATCH 02/11] Fixed flapper Tests dealing with MQTT "will" needed to wait for the server to process the MQTT client close of the connection. Only then we have the guarantee that the server produced the "will" message. Signed-off-by: Ivan Kozlovic --- server/mqtt.go | 2 +- server/mqtt_test.go | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/server/mqtt.go b/server/mqtt.go index 4c508b18..3e7fd531 100644 --- a/server/mqtt.go +++ b/server/mqtt.go @@ -1445,7 +1445,7 @@ func (s *Server) mqttHandleWill(c *client) { ////////////////////////////////////////////////////////////////////////////// func (c *client) mqttParsePub(r *mqttReader, pl int, pp *mqttPublish) error { - qos := (pp.flags & mqttPubFlagQoS) >> 1 + qos := mqttGetQoS(pp.flags) if qos > 1 { return fmt.Errorf("publish QoS=%v not supported", qos) } diff --git a/server/mqtt_test.go b/server/mqtt_test.go index d5ad36b3..7427cf9f 100644 --- a/server/mqtt_test.go +++ b/server/mqtt_test.go @@ -2676,6 +2676,10 @@ func TestMQTTWillRetain(t *testing.T) { // Disconnect, which will cause will to be produced with retain flag. mc.Close() + // Wait for the server to process the connection close, which will + // cause the "will" message to be published (and retained). + checkClientsCount(t, s, 0) + // Create subscription on will topic and expect will message. mcs, rs := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) defer mcs.Close() @@ -2742,6 +2746,10 @@ func TestMQTTWillRetainPermViolation(t *testing.T) { // Disconnect, which will cause the Will to be sent with retain flag. mc.Close() + // Wait for the server to process the connection close, which will + // cause the "will" message to be published (and retained). + checkClientsCount(t, s, 0) + // Create a subscription on the Will subject and we should receive it. ci.will = nil mcs, rs := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) @@ -2773,6 +2781,10 @@ func TestMQTTWillRetainPermViolation(t *testing.T) { // since user not allowed to publish on "bar". mc.Close() + // Wait for the server to process the connection close, which will + // cause the "will" message to be published (and retained). + checkClientsCount(t, s, 0) + // Create sub on "bar" which user is allowed to subscribe to. ci.will = nil mcs, rs = testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) From 718c995914fc7906f7d2e52ec93530b52652006c Mon Sep 17 00:00:00 2001 From: Ivan Kozlovic Date: Mon, 30 Nov 2020 17:43:50 -0700 Subject: [PATCH 03/11] Allow "nats" utility to display internal MQTT streams MQTT streams are special in that we do not set subjects in the config since they capture all subjects. Otherwise, we would have been forced to create a stream on say "MQTT.>" but then all publishes would have to be prefixed with "MQTT." in order for them to be captured. However, if one uses the "nats" tool to inspect those streams, the tool would fail with: ``` server response is not a valid "io.nats.jetstream.api.v1.stream_info_response" message: (root): Must validate one and only one schema (oneOf) config: subjects is required config: Must validate all the schemas (allOf) ``` To solve that, if we detect that user asks for the MQTT streams, we artificially set the returned config's subject to ">". Alternatively, we may want to not return those streams at all, although there may be value to see the info for mqtt streams/consumers. Signed-off-by: Ivan Kozlovic --- server/jetstream_api.go | 9 ++++++++- server/mqtt_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/server/jetstream_api.go b/server/jetstream_api.go index 32773443..28170294 100644 --- a/server/jetstream_api.go +++ b/server/jetstream_api.go @@ -952,7 +952,14 @@ func (s *Server) jsStreamInfoRequest(sub *subscription, c *client, subject, repl s.sendAPIResponse(c, subject, reply, string(msg), s.jsonResponse(&resp)) return } - resp.StreamInfo = &StreamInfo{Created: mset.Created(), State: mset.State(), Config: mset.Config()} + config := mset.Config() + // MQTT streams are created without subject, but "nats" tooling would then + // fail to display them since it uses validation and expect the config's + // Subjects to not be empty. + if strings.HasPrefix(name, mqttStreamNamePrefix) && len(config.Subjects) == 0 { + config.Subjects = []string{">"} + } + resp.StreamInfo = &StreamInfo{Created: mset.Created(), State: mset.State(), Config: config} s.sendAPIResponse(c, subject, reply, string(msg), s.jsonResponse(resp)) } diff --git a/server/mqtt_test.go b/server/mqtt_test.go index 7427cf9f..7c76b490 100644 --- a/server/mqtt_test.go +++ b/server/mqtt_test.go @@ -3745,7 +3745,48 @@ func TestMQTTConfigReload(t *testing.T) { testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg2")) } +func TestMQTTStreamInfoReturnsNonEmptySubject(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer s.Shutdown() + + cisub := &mqttConnInfo{clientID: "sub", cleanSess: false} + c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + // Check that we can query all MQTT streams. MQTT streams are + // created without subject filter, however, if we return them like this, + // the 'nats' utility will fail to display them due to some xml validation. + for _, sname := range []string{ + mqttStreamName, + mqttSessionsStreamName, + mqttRetainedMsgsStreamName, + } { + t.Run(sname, func(t *testing.T) { + resp, err := nc.Request(fmt.Sprintf(JSApiStreamInfoT, sname), nil, time.Second) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + var bResp JSApiStreamInfoResponse + if err = json.Unmarshal(resp.Data, &bResp); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if len(bResp.Config.Subjects) == 0 { + t.Fatalf("No subject returned, which will cause nats tooling to fail: %+v", bResp.Config) + } + }) + } +} + +////////////////////////////////////////////////////////////////////////// +// // Benchmarks +// +////////////////////////////////////////////////////////////////////////// const ( mqttPubSubj = "a" From 3e91ef75abd2bd69ee5cf94b7309a9026a090518 Mon Sep 17 00:00:00 2001 From: Ivan Kozlovic Date: Tue, 1 Dec 2020 13:53:20 -0700 Subject: [PATCH 04/11] Some updates based on code review - Added non-public stream and consumer configuration options to achieve the "no subject" and "no interest" capabilities. Had to implement custom FileStreamInfo and FileConsumerInfo marshal/ unmarshal methods so that those non public fields can be persisted/recovered properly. - Restored some of JS original code (since now can use config instead of passing booleans to the functions). - Use RLock for deliveryFormsCycle() check (unrelated to MQTT). - Removed restriction on creating streams with MQTT prefix. - Preventing API deletion of internal streams and their consumers. - Added comment on Sublist's ReverseMatch method. Signed-off-by: Ivan Kozlovic --- server/consumer.go | 56 +++++++++++++------------------- server/filestore.go | 57 +++++++++++++++++++++++++++++++++ server/jetstream.go | 11 ++----- server/jetstream_api.go | 18 ++++++++--- server/mqtt.go | 59 +++++++++++++++++++--------------- server/mqtt_test.go | 71 ++++++++++++++++++++++++++++++----------- server/stream.go | 30 +++++++---------- server/sublist.go | 7 ++++ 8 files changed, 202 insertions(+), 107 deletions(-) diff --git a/server/consumer.go b/server/consumer.go index ef0e71a4..28daf492 100644 --- a/server/consumer.go +++ b/server/consumer.go @@ -60,6 +60,11 @@ type ConsumerConfig struct { SampleFrequency string `json:"sample_freq,omitempty"` MaxWaiting int `json:"max_waiting,omitempty"` MaxAckPending int `json:"max_ack_pending,omitempty"` + + // These are non public configuration options. + // If you add new options, check fileConsumerInfoJSON in order for them to + // be properly persisted/recovered, if needed. + allowNoInterest bool } type CreateConsumerRequest struct { @@ -213,13 +218,6 @@ const ( ) func (mset *Stream) AddConsumer(config *ConsumerConfig) (*Consumer, error) { - if name := mset.Name(); strings.HasPrefix(name, mqttStreamNamePrefix) { - return nil, fmt.Errorf("stream prefix %q is reserved for MQTT, unable to create consumer on %q", mqttStreamNamePrefix, name) - } - return mset.addConsumerCheckInterest(config, true) -} - -func (mset *Stream) addConsumerCheckInterest(config *ConsumerConfig, checkInterest bool) (*Consumer, error) { if config == nil { return nil, fmt.Errorf("consumer config required") } @@ -272,18 +270,23 @@ func (mset *Stream) addConsumerCheckInterest(config *ConsumerConfig, checkIntere // Make sure any partition subject is also a literal. if config.FilterSubject != "" { - // If this is a direct match for the streams only subject clear the filter. + var checkSubject bool + mset.mu.RLock() - if len(mset.config.Subjects) == 1 && mset.config.Subjects[0] == config.FilterSubject { - config.FilterSubject = _EMPTY_ + // If the stream was created with no subject, then skip the checks + if !mset.config.allowNoSubject { + // If this is a direct match for the streams only subject clear the filter. + if len(mset.config.Subjects) == 1 && mset.config.Subjects[0] == config.FilterSubject { + config.FilterSubject = _EMPTY_ + } else { + checkSubject = true + } } mset.mu.RUnlock() - if config.FilterSubject != "" { - // Make sure this is a valid partition of the interest subjects. - if !mset.validSubject(config.FilterSubject) { - return nil, fmt.Errorf("consumer filter subject is not a valid subset of the interest subjects") - } + // Make sure this is a valid partition of the interest subjects. + if checkSubject && !mset.validSubject(config.FilterSubject) { + return nil, fmt.Errorf("consumer filter subject is not a valid subset of the interest subjects") } } @@ -357,7 +360,7 @@ func (mset *Stream) addConsumerCheckInterest(config *ConsumerConfig, checkIntere } else { // If we are a push mode and not active and the only difference // is deliver subject then update and return. - if configsEqualSansDelivery(ocfg, *config) && (!checkInterest || eo.hasNoLocalInterest()) { + if configsEqualSansDelivery(ocfg, *config) && (config.allowNoInterest || eo.hasNoLocalInterest()) { eo.updateDeliverSubject(config.DeliverSubject) return eo, nil } else { @@ -2173,8 +2176,8 @@ func (o *Consumer) stop(dflag, doSignal, advisory bool) error { // Check that we do not form a cycle by delivering to a delivery subject // that is part of the interest group. func (mset *Stream) deliveryFormsCycle(deliverySubject string) bool { - mset.mu.Lock() - defer mset.mu.Unlock() + mset.mu.RLock() + defer mset.mu.RUnlock() for _, subject := range mset.config.Subjects { if subjectIsSubsetMatch(deliverySubject, subject) { @@ -2184,22 +2187,9 @@ func (mset *Stream) deliveryFormsCycle(deliverySubject string) bool { return false } -// Check that the subject is a subset of the stream's configured subjects, -// or returns true if the stream has been created with no subject. +// This is same as check for delivery cycle. func (mset *Stream) validSubject(partitionSubject string) bool { - mset.mu.RLock() - defer mset.mu.RUnlock() - - if mset.nosubj && len(mset.config.Subjects) == 0 { - return true - } - - for _, subject := range mset.config.Subjects { - if subjectIsSubsetMatch(partitionSubject, subject) { - return true - } - } - return false + return mset.deliveryFormsCycle(partitionSubject) } // SetInActiveDeleteThreshold sets the delete threshold for how long to wait diff --git a/server/filestore.go b/server/filestore.go index 3f4a0a86..bd4e4249 100644 --- a/server/filestore.go +++ b/server/filestore.go @@ -56,6 +56,38 @@ type FileStreamInfo struct { StreamConfig } +// Need an alias (which does not have MarshalJSON/UnmarshalJSON) to avoid +// recursive calls which would lead to stack overflow. +type fileStreamInfoAlias FileStreamInfo + +// We will use this struct definition to serialize/deserialize FileStreamInfo +// object. This embeds FileStreamInfo (the alias to prevent recursive calls) +// and makes the non-public options public so they can be persisted/recovered. +type fileStreamInfoJSON struct { + fileStreamInfoAlias + Internal bool `json:"internal,omitempty"` + AllowNoSubject bool `json:"allow_no_subject,omitempty"` +} + +func (fsi FileStreamInfo) MarshalJSON() ([]byte, error) { + return json.Marshal(&fileStreamInfoJSON{ + fileStreamInfoAlias(fsi), + fsi.internal, + fsi.allowNoSubject, + }) +} + +func (fsi *FileStreamInfo) UnmarshalJSON(b []byte) error { + fsiJSON := &fileStreamInfoJSON{} + if err := json.Unmarshal(b, &fsiJSON); err != nil { + return err + } + *fsi = FileStreamInfo(fsiJSON.fileStreamInfoAlias) + fsi.internal = fsiJSON.Internal + fsi.allowNoSubject = fsiJSON.AllowNoSubject + return nil +} + // File ConsumerInfo is used for creating consumer stores. type FileConsumerInfo struct { Created time.Time @@ -63,6 +95,31 @@ type FileConsumerInfo struct { ConsumerConfig } +// See fileStreamInfoAlias, etc.. for details on how this all work. +type fileConsumerInfoAlias FileConsumerInfo + +type fileConsumerInfoJSON struct { + fileConsumerInfoAlias + AllowNoInterest bool `json:"allow_no_interest,omitempty"` +} + +func (fci FileConsumerInfo) MarshalJSON() ([]byte, error) { + return json.Marshal(&fileConsumerInfoJSON{ + fileConsumerInfoAlias(fci), + fci.allowNoInterest, + }) +} + +func (fci *FileConsumerInfo) UnmarshalJSON(b []byte) error { + fciJSON := &fileConsumerInfoJSON{} + if err := json.Unmarshal(b, &fciJSON); err != nil { + return err + } + *fci = FileConsumerInfo(fciJSON.fileConsumerInfoAlias) + fci.allowNoInterest = fciJSON.AllowNoInterest + return nil +} + type fileStore struct { mu sync.RWMutex state StreamState diff --git a/server/jetstream.go b/server/jetstream.go index 0c651a01..66463ab1 100644 --- a/server/jetstream.go +++ b/server/jetstream.go @@ -533,12 +533,7 @@ func (a *Account) EnableJetStream(limits *JetStreamAccountLimits) error { s.Warnf(" Error adding Stream %q to Template %q: %v", cfg.Name, cfg.Template, err) } } - // TODO: We should not rely on the stream name. - // However, having a StreamConfig property, such as AllowNoSubject, - // was not accepted because it does not make sense outside of the - // MQTT use-case. So need to revisit this. - mqtt := cfg.StreamConfig.Name == mqttStreamName - mset, err := a.addStreamWithStore(&cfg.StreamConfig, nil, mqtt) + mset, err := a.AddStream(&cfg.StreamConfig) if err != nil { s.Warnf(" Error recreating Stream %q: %v", cfg.Name, err) continue @@ -583,7 +578,7 @@ func (a *Account) EnableJetStream(limits *JetStreamAccountLimits) error { // the consumer can reconnect. We will create it as a durable and switch it. cfg.ConsumerConfig.Durable = ofi.Name() } - obs, err := mset.addConsumerCheckInterest(&cfg.ConsumerConfig, !mqtt) + obs, err := mset.AddConsumer(&cfg.ConsumerConfig) if err != nil { s.Warnf(" Error adding Consumer: %v", err) continue @@ -1065,7 +1060,7 @@ func (a *Account) AddStreamTemplate(tc *StreamTemplateConfig) (*StreamTemplate, // FIXME(dlc) - Hacky tcopy := tc.deepCopy() tcopy.Config.Name = "_" - cfg, err := checkStreamCfg(tcopy.Config, false) + cfg, err := checkStreamCfg(tcopy.Config) if err != nil { return nil, err } diff --git a/server/jetstream_api.go b/server/jetstream_api.go index 28170294..9cbb6c63 100644 --- a/server/jetstream_api.go +++ b/server/jetstream_api.go @@ -953,10 +953,10 @@ func (s *Server) jsStreamInfoRequest(sub *subscription, c *client, subject, repl return } config := mset.Config() - // MQTT streams are created without subject, but "nats" tooling would then - // fail to display them since it uses validation and expect the config's - // Subjects to not be empty. - if strings.HasPrefix(name, mqttStreamNamePrefix) && len(config.Subjects) == 0 { + // Some streams are created without subject (for instance MQTT streams), + // but "nats" tooling would then fail to display them since it uses + // validation and expect the config's Subjects to not be empty. + if config.allowNoSubject && len(config.Subjects) == 0 { config.Subjects = []string{">"} } resp.StreamInfo = &StreamInfo{Created: mset.Created(), State: mset.State(), Config: config} @@ -1006,6 +1006,11 @@ func (s *Server) jsStreamDeleteRequest(sub *subscription, c *client, subject, re s.sendAPIResponse(c, subject, reply, string(msg), s.jsonResponse(&resp)) return } + if mset.Config().internal { + resp.Error = &ApiError{Code: 403, Description: "not allowed to delete internal stream"} + s.sendAPIResponse(c, subject, reply, string(msg), s.jsonResponse(&resp)) + return + } if err := mset.Delete(); err != nil { resp.Error = jsError(err) s.sendAPIResponse(c, subject, reply, string(msg), s.jsonResponse(&resp)) @@ -1713,6 +1718,11 @@ func (s *Server) jsConsumerDeleteRequest(sub *subscription, c *client, subject, s.sendAPIResponse(c, subject, reply, string(msg), s.jsonResponse(&resp)) return } + if mset.Config().internal { + resp.Error = &ApiError{Code: 403, Description: "not allowed to delete consumer of internal stream"} + s.sendAPIResponse(c, subject, reply, string(msg), s.jsonResponse(&resp)) + return + } consumer := consumerNameFromSubject(subject) obs := mset.LookupConsumer(consumer) if obs == nil { diff --git a/server/mqtt.go b/server/mqtt.go index 3e7fd531..8d64ff8c 100644 --- a/server/mqtt.go +++ b/server/mqtt.go @@ -631,12 +631,14 @@ func (as *mqttAccountSessionManager) init(acc *Account, c *client) error { // Start with sessions stream as.sstream, err = acc.LookupStream(mqttSessionsStreamName) if err != nil { - as.sstream, err = acc.addStreamWithStore(&StreamConfig{ - Subjects: []string{}, - Name: mqttSessionsStreamName, - Storage: FileStorage, - Retention: InterestPolicy, - }, nil, true) + as.sstream, err = acc.AddStream(&StreamConfig{ + Subjects: []string{}, + Name: mqttSessionsStreamName, + Storage: FileStorage, + Retention: InterestPolicy, + internal: true, + allowNoSubject: true, + }) if err != nil { return fmt.Errorf("unable to create sessions stream for MQTT account %q: %v", acc.GetName(), err) } @@ -644,12 +646,14 @@ func (as *mqttAccountSessionManager) init(acc *Account, c *client) error { // Create the stream for the messages. as.mstream, err = acc.LookupStream(mqttStreamName) if err != nil { - as.mstream, err = acc.addStreamWithStore(&StreamConfig{ - Subjects: []string{}, - Name: mqttStreamName, - Storage: FileStorage, - Retention: InterestPolicy, - }, nil, true) + as.mstream, err = acc.AddStream(&StreamConfig{ + Subjects: []string{}, + Name: mqttStreamName, + Storage: FileStorage, + Retention: InterestPolicy, + internal: true, + allowNoSubject: true, + }) if err != nil { return fmt.Errorf("unable to create messages stream for MQTT account %q: %v", acc.GetName(), err) } @@ -657,12 +661,14 @@ func (as *mqttAccountSessionManager) init(acc *Account, c *client) error { // Create the stream for retained messages. as.rstream, err = acc.LookupStream(mqttRetainedMsgsStreamName) if err != nil { - as.rstream, err = acc.addStreamWithStore(&StreamConfig{ - Subjects: []string{}, - Name: mqttRetainedMsgsStreamName, - Storage: FileStorage, - Retention: InterestPolicy, - }, nil, true) + as.rstream, err = acc.AddStream(&StreamConfig{ + Subjects: []string{}, + Name: mqttRetainedMsgsStreamName, + Storage: FileStorage, + Retention: InterestPolicy, + internal: true, + allowNoSubject: true, + }) if err != nil { return fmt.Errorf("unable to create retained messages stream for MQTT account %q: %v", acc.GetName(), err) } @@ -2059,15 +2065,16 @@ func (c *client) mqttProcessJSConsumer(sess *mqttSession, stream *Stream, subjec maxAckPending = mqttDefaultMaxAckPending } cc := &ConsumerConfig{ - DeliverSubject: inbox, - Durable: durName, - AckPolicy: AckExplicit, - DeliverPolicy: DeliverNew, - FilterSubject: subject, - AckWait: ackWait, - MaxAckPending: int(maxAckPending), + DeliverSubject: inbox, + Durable: durName, + AckPolicy: AckExplicit, + DeliverPolicy: DeliverNew, + FilterSubject: subject, + AckWait: ackWait, + MaxAckPending: int(maxAckPending), + allowNoInterest: true, } - cons, err = stream.addConsumerCheckInterest(cc, false) + cons, err = stream.AddConsumer(cc) if err != nil { c.Errorf("Unable to add JetStream consumer for subscription on %q: err=%v", subject, err) return nil, nil, err diff --git a/server/mqtt_test.go b/server/mqtt_test.go index 7c76b490..66375287 100644 --- a/server/mqtt_test.go +++ b/server/mqtt_test.go @@ -2307,21 +2307,11 @@ func TestMQTTTrackPendingOverrun(t *testing.T) { } } -func TestMQTTPreventStreamAndConsumerWithMQTTPrefix(t *testing.T) { +func TestMQTTPreventDeleteMQTTStreamsAndConsumers(t *testing.T) { o := testMQTTDefaultOptions() s := testMQTTRunServer(t, o) defer testMQTTShutdownServer(s) - sc := &StreamConfig{ - Name: mqttStreamNamePrefix + "test", - Storage: FileStorage, - Retention: InterestPolicy, - Subjects: []string{"foo.>"}, - } - if _, err := s.GlobalAccount().AddStream(sc); err == nil { - t.Fatal("Expected error") - } - mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) defer mc.Close() testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) @@ -2330,15 +2320,60 @@ func TestMQTTPreventStreamAndConsumerWithMQTTPrefix(t *testing.T) { mset, err := s.GlobalAccount().LookupStream(mqttStreamName) if err != nil { - t.Fatalf("Error looking up MQTT Stream: %v", err) + t.Fatalf("Error looking up stream: %v", err) } - cc := &ConsumerConfig{ - Durable: "dur", - AckPolicy: AckExplicit, - DeliverSubject: "bar", + var cName string + mset.mu.Lock() + for cname := range mset.consumers { + cName = cname + break } - if _, err := mset.AddConsumer(cc); err == nil { - t.Fatal("Expected error") + mset.mu.Unlock() + + // Try first to delete the consumer with API and it should fail + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + respMsg, err := nc.Request(fmt.Sprintf(JSApiConsumerDeleteT, mqttStreamName, cName), nil, time.Second) + if err != nil { + t.Fatalf("Error sending request: %v", err) + } + var resp JSApiConsumerDeleteResponse + if err = json.Unmarshal(respMsg.Data, &resp); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if resp.Success || resp.Error == nil { + t.Fatalf("Operation should have failed") + } + delErr := resp.Error + if delErr.Code != 403 { + t.Fatalf("Expected forbidden, got %v", delErr.Code) + } + if !strings.Contains(delErr.Description, "not allowed to delete consumer of internal stream") { + t.Fatalf("Unexpected error description: %q", delErr.Description) + } + + // Now try with all MQTT streams + streamNames := []string{mqttStreamName, mqttSessionsStreamName, mqttRetainedMsgsStreamName} + for _, sName := range streamNames { + respMsg, err := nc.Request(fmt.Sprintf(JSApiStreamDeleteT, sName), nil, time.Second) + if err != nil { + t.Fatalf("Error sending request: %v", err) + } + var resp JSApiStreamDeleteResponse + if err = json.Unmarshal(respMsg.Data, &resp); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if resp.Success || resp.Error == nil { + t.Fatalf("Operation should have failed") + } + delErr := resp.Error + if delErr.Code != 403 { + t.Fatalf("Expected forbidden, got %v", delErr.Code) + } + if !strings.Contains(delErr.Description, "not allowed to delete internal stream") { + t.Fatalf("Unexpected error description: %q", delErr.Description) + } } } diff --git a/server/stream.go b/server/stream.go index d709767a..029cf391 100644 --- a/server/stream.go +++ b/server/stream.go @@ -27,7 +27,6 @@ import ( "path/filepath" "reflect" "strconv" - "strings" "sync" "time" @@ -51,6 +50,12 @@ type StreamConfig struct { NoAck bool `json:"no_ack,omitempty"` Template string `json:"template_owner,omitempty"` Duplicates time.Duration `json:"duplicate_window,omitempty"` + + // These are non public configuration options. + // If you add new options, check fileStreamInfoJSON in order for them to + // be properly persisted/recovered, if needed. + internal bool + allowNoSubject bool } const JSApiPubAckResponseType = "io.nats.jetstream.api.v1.pub_ack_response" @@ -96,7 +101,6 @@ type Stream struct { ddarr []*ddentry ddindex int ddtmr *time.Timer - nosubj bool } // Headers for published messages. @@ -127,20 +131,13 @@ func (a *Account) AddStream(config *StreamConfig) (*Stream, error) { // AddStreamWithStore adds a stream for the given account with custome store config options. func (a *Account) AddStreamWithStore(config *StreamConfig, fsConfig *FileStoreConfig) (*Stream, error) { - if strings.HasPrefix(config.Name, mqttStreamNamePrefix) { - return nil, fmt.Errorf("prefix %q is reserved for MQTT, unable to create stream %q", mqttStreamNamePrefix, config.Name) - } - return a.addStreamWithStore(config, fsConfig, false) -} - -func (a *Account) addStreamWithStore(config *StreamConfig, fsConfig *FileStoreConfig, noSubjectsOK bool) (*Stream, error) { s, jsa, err := a.checkForJetStream() if err != nil { return nil, err } // Sensible defaults. - cfg, err := checkStreamCfg(config, noSubjectsOK) + cfg, err := checkStreamCfg(config) if err != nil { return nil, err } @@ -177,7 +174,7 @@ func (a *Account) addStreamWithStore(config *StreamConfig, fsConfig *FileStoreCo // Setup the internal client. c := s.createInternalJetStreamClient() - mset := &Stream{jsa: jsa, config: cfg, client: c, consumers: make(map[string]*Consumer), nosubj: noSubjectsOK} + mset := &Stream{jsa: jsa, config: cfg, client: c, consumers: make(map[string]*Consumer)} jsa.streams[cfg.Name] = mset storeDir := path.Join(jsa.storeDir, streamsDir, cfg.Name) @@ -425,7 +422,7 @@ func (jsa *jsAccount) subjectsOverlap(subjects []string) bool { // Default duplicates window. const StreamDefaultDuplicatesWindow = 2 * time.Minute -func checkStreamCfg(config *StreamConfig, noSubjectOk bool) (StreamConfig, error) { +func checkStreamCfg(config *StreamConfig) (StreamConfig, error) { if config == nil { return StreamConfig{}, fmt.Errorf("stream configuration invalid") } @@ -475,7 +472,7 @@ func checkStreamCfg(config *StreamConfig, noSubjectOk bool) (StreamConfig, error } if len(cfg.Subjects) == 0 { - if !noSubjectOk { + if !cfg.allowNoSubject { cfg.Subjects = append(cfg.Subjects, cfg.Name) } } else { @@ -530,10 +527,7 @@ func (mset *Stream) Delete() error { // Update will allow certain configuration properties of an existing stream to be updated. func (mset *Stream) Update(config *StreamConfig) error { - mset.mu.RLock() - nosubj := mset.nosubj - mset.mu.RUnlock() - cfg, err := checkStreamCfg(config, nosubj) + cfg, err := checkStreamCfg(config) if err != nil { return err } @@ -915,7 +909,7 @@ func getExpectedLastSeq(hdr []byte) uint64 { } // processInboundJetStreamMsg handles processing messages bound for a stream. -func (mset *Stream) processInboundJetStreamMsg(sub *subscription, pc *client, subject, reply string, msg []byte) { +func (mset *Stream) processInboundJetStreamMsg(_ *subscription, pc *client, subject, reply string, msg []byte) { mset.mu.Lock() store := mset.store c := mset.client diff --git a/server/sublist.go b/server/sublist.go index fcb0c28d..16e75ac0 100644 --- a/server/sublist.go +++ b/server/sublist.go @@ -1393,6 +1393,13 @@ func (s *Sublist) collectAllSubs(l *level, subs *[]*subscription) { } } +// For a given subject (which may contain wildcards), this call returns all +// subscriptions that would match that subject. For instance, suppose that +// the sublist contains: foo.bar, foo.bar.baz and foo.baz, ReverseMatch("foo.*") +// would return foo.bar and foo.baz. +// This is used in situations where the sublist is likely to contain only +// literals and one wants to get all the subjects that would have been a match +// to a subscription on `subject`. func (s *Sublist) ReverseMatch(subject string) *SublistResult { tsa := [32]string{} tokens := tsa[:0] From 4fc04d3f55fc9f7795caff2a87f3f1ed3614abe4 Mon Sep 17 00:00:00 2001 From: Ivan Kozlovic Date: Tue, 1 Dec 2020 15:38:47 -0700 Subject: [PATCH 05/11] Revert changes to processSub() Based on how the MQTT callback operates, it is safe to finish setup of the MQTT subscriptions after processSub() returns. So I have reverted the changes to processSub() which will minimize changes to non-MQTT related code. Signed-off-by: Ivan Kozlovic --- server/accounts.go | 6 +++--- server/client.go | 33 ++++++++++++++--------------- server/events.go | 2 +- server/jetstream.go | 2 +- server/mqtt.go | 51 ++++++++++++++++++++------------------------- server/stream.go | 2 +- 6 files changed, 45 insertions(+), 51 deletions(-) diff --git a/server/accounts.go b/server/accounts.go index 02d32028..9a68c032 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -1726,7 +1726,7 @@ func (a *Account) subscribeInternal(subject string, cb msgHandler) (*subscriptio return nil, fmt.Errorf("no internal account client") } - return c.processSub(c.createSub([]byte(subject), nil, []byte(sid), cb), false) + return c.processSub([]byte(subject), nil, []byte(sid), cb, false) } // This will add an account subscription that matches the "from" from a service import entry. @@ -1751,7 +1751,7 @@ func (a *Account) addServiceImportSub(si *serviceImport) error { cb := func(sub *subscription, c *client, subject, reply string, msg []byte) { c.processServiceImport(si, a, msg) } - _, err := c.processSub(c.createSub([]byte(subject), nil, []byte(sid), cb), true) + _, err := c.processSub([]byte(subject), nil, []byte(sid), cb, true) return err } @@ -1951,7 +1951,7 @@ func (a *Account) createRespWildcard() []byte { a.mu.Unlock() // Create subscription and internal callback for all the wildcard response subjects. - c.processSub(c.createSub(wcsub, nil, []byte(sid), a.processServiceImportResponse), false) + c.processSub(wcsub, nil, []byte(sid), a.processServiceImportResponse, false) return pre } diff --git a/server/client.go b/server/client.go index 3a68b9d2..485a72df 100644 --- a/server/client.go +++ b/server/client.go @@ -2226,30 +2226,32 @@ func (c *client) parseSub(argo []byte, noForward bool) error { arg := make([]byte, len(argo)) copy(arg, argo) args := splitArg(arg) - sub := &subscription{client: c} + var ( + subject []byte + queue []byte + sid []byte + ) switch len(args) { case 2: - sub.subject = args[0] - sub.queue = nil - sub.sid = args[1] + subject = args[0] + queue = nil + sid = args[1] case 3: - sub.subject = args[0] - sub.queue = args[1] - sub.sid = args[2] + subject = args[0] + queue = args[1] + sid = args[2] default: return fmt.Errorf("processSub Parse Error: '%s'", arg) } // If there was an error, it has been sent to the client. We don't return an // error here to not close the connection as a parsing error. - c.processSub(sub, noForward) + c.processSub(subject, queue, sid, nil, noForward) return nil } -func (c *client) createSub(subject, queue, sid []byte, cb msgHandler) *subscription { - return &subscription{client: c, subject: subject, queue: queue, sid: sid, icb: cb} -} - -func (c *client) processSub(sub *subscription, noForward bool) (*subscription, error) { +func (c *client) processSub(subject, queue, bsid []byte, cb msgHandler, noForward bool) (*subscription, error) { + // Create the subscription + sub := &subscription{client: c, subject: subject, queue: queue, sid: bsid, icb: cb} c.mu.Lock() @@ -2266,7 +2268,7 @@ func (c *client) processSub(sub *subscription, noForward bool) (*subscription, e // This check does not apply to SYSTEM or JETSTREAM or ACCOUNT clients (because they don't have a `nc`...) if c.isClosed() && (kind != SYSTEM && kind != JETSTREAM && kind != ACCOUNT) { c.mu.Unlock() - return nil, nil + return nil, ErrConnectionClosed } // Check permissions if applicable. @@ -2313,9 +2315,6 @@ func (c *client) processSub(sub *subscription, noForward bool) (*subscription, e updateGWs = c.srv.gateway.enabled } } - } else if es.mqtt != nil && sub.mqtt != nil { - es.mqtt.prm = sub.mqtt.prm - es.mqtt.qos = sub.mqtt.qos } // Unlocked from here onward c.mu.Unlock() diff --git a/server/events.go b/server/events.go index 3e9c84dc..9d3862d3 100644 --- a/server/events.go +++ b/server/events.go @@ -1411,7 +1411,7 @@ func (s *Server) systemSubscribe(subject, queue string, internalOnly bool, cb ms q = []byte(queue) } // Now create the subscription - return c.processSub(c.createSub([]byte(subject), q, []byte(sid), cb), internalOnly) + return c.processSub([]byte(subject), q, []byte(sid), cb, internalOnly) } func (s *Server) sysUnsubscribe(sub *subscription) { diff --git a/server/jetstream.go b/server/jetstream.go index 66463ab1..9a591471 100644 --- a/server/jetstream.go +++ b/server/jetstream.go @@ -1115,7 +1115,7 @@ func (t *StreamTemplate) createTemplateSubscriptions() error { sid := 1 for _, subject := range t.Config.Subjects { // Now create the subscription - if _, err := c.processSub(c.createSub([]byte(subject), nil, []byte(strconv.Itoa(sid)), t.processInboundTemplateMsg), false); err != nil { + if _, err := c.processSub([]byte(subject), nil, []byte(strconv.Itoa(sid)), t.processInboundTemplateMsg, false); err != nil { c.acc.DeleteStreamTemplate(t.Name) return err } diff --git a/server/mqtt.go b/server/mqtt.go index 8d64ff8c..dcc4a957 100644 --- a/server/mqtt.go +++ b/server/mqtt.go @@ -795,6 +795,16 @@ func (as *mqttAccountSessionManager) processSubs(sess *mqttSession, clientID str sess.cons[sid] = cons } + setupSub := func(sub *subscription, qos byte) { + if sub.mqtt == nil { + sub.mqtt = &mqttSub{} + } + sub.mqtt.qos = qos + if fromSubProto { + as.serializeRetainedMsgsForSub(sess, c, sub, trace) + } + } + subs := make([]*subscription, 0, len(filters)) for _, f := range filters { if f.qos > 1 { @@ -810,17 +820,12 @@ func (as *mqttAccountSessionManager) processSubs(sess *mqttSession, clientID str var jscons *Consumer var jssub *subscription - var err error - sub := c.mqttCreateSub(subject, sid, mqttDeliverMsgCb, f.qos) - if fromSubProto { - as.serializeRetainedMsgsForSub(sess, c, sub, trace) - } // Note that if a subscription already exists on this subject, - // the sub is updated with the new qos/prm and the pointer to - // the existing subscription is returned. - sub, err = c.processSub(sub, false) + // the existing sub is returned. Need to update the qos. + sub, err := c.processSub([]byte(subject), nil, []byte(sid), mqttDeliverMsgCb, false) if err == nil { + setupSub(sub, f.qos) // This will create (if not already exist) a JS consumer for subscriptions // of QoS >= 1. But if a JS consumer already exists and the subscription // for same subject is now a QoS==0, then the JS consumer will be deleted. @@ -836,18 +841,16 @@ func (as *mqttAccountSessionManager) processSubs(sess *mqttSession, clientID str if mqttNeedSubForLevelUp(subject) { var fwjscons *Consumer var fwjssub *subscription + var fwcsub *subscription // Say subject is "foo.>", remove the ".>" so that it becomes "foo" fwcsubject := subject[:len(subject)-2] // Change the sid to "foo fwc" fwcsid := fwcsubject + mqttMultiLevelSidSuffix - fwcsub := c.mqttCreateSub(fwcsubject, fwcsid, mqttDeliverMsgCb, f.qos) - if fromSubProto { - as.serializeRetainedMsgsForSub(sess, c, fwcsub, trace) - } // See note above about existing subscription. - fwcsub, err = c.processSub(fwcsub, false) + fwcsub, err = c.processSub([]byte(fwcsubject), nil, []byte(fwcsid), mqttDeliverMsgCb, false) if err == nil { + setupSub(fwcsub, f.qos) fwjscons, fwjssub, err = c.mqttProcessJSConsumer(sess, as.mstream, fwcsubject, fwcsid, f.qos, fromSubProto) } @@ -1841,10 +1844,6 @@ func mqttSubscribeTrace(filters []*mqttFilter) string { } func mqttDeliverMsgCb(sub *subscription, pc *client, subject, reply string, msg []byte) { - if sub.mqtt == nil { - return - } - var ppFlags byte var pQoS byte var pi uint16 @@ -1859,7 +1858,7 @@ func mqttDeliverMsgCb(sub *subscription, pc *client, subject, reply string, msg // We lock to check some of the subscription's fields and if we need to // keep track of pending acks, etc.. sess.mu.Lock() - if sess.c != cc { + if sess.c != cc || sub.mqtt == nil { sess.mu.Unlock() return } @@ -1960,13 +1959,6 @@ func mqttSerializePublishMsg(w *mqttWriter, pi uint16, dup, retained bool, subje return flags } -// Helper to create an MQTT subscription. -func (c *client) mqttCreateSub(subject, sid string, cb msgHandler, qos byte) *subscription { - sub := c.createSub([]byte(subject), nil, []byte(sid), cb) - sub.mqtt = &mqttSub{qos: qos} - return sub -} - // Process the list of subscriptions and update the given filter // with the QoS that has been accepted (or failure). // @@ -2080,12 +2072,10 @@ func (c *client) mqttProcessJSConsumer(sess *mqttSession, stream *Stream, subjec return nil, nil, err } } - sub := c.mqttCreateSub(inbox, inbox, mqttDeliverMsgCb, qos) - sub.mqtt.jsCons = cons // This is an internal subscription on subject like "$MQTT.sub." that is setup // for the JS durable's deliver subject. I don't think that there is any need to // forward this subscription in the cluster/super cluster. - sub, err = c.processSub(sub, true) + sub, err := c.processSub([]byte(inbox), nil, []byte(inbox), mqttDeliverMsgCb, true) if err != nil { if !exists { cons.Delete() @@ -2093,6 +2083,11 @@ func (c *client) mqttProcessJSConsumer(sess *mqttSession, stream *Stream, subjec c.Errorf("Unable to create subscription for JetStream consumer on %q: %v", subject, err) return nil, nil, err } + if sub.mqtt == nil { + sub.mqtt = &mqttSub{} + } + sub.mqtt.qos = qos + sub.mqtt.jsCons = cons return cons, sub, nil } diff --git a/server/stream.go b/server/stream.go index 029cf391..c3cbea4f 100644 --- a/server/stream.go +++ b/server/stream.go @@ -699,7 +699,7 @@ func (mset *Stream) subscribeInternal(subject string, cb msgHandler) (*subscript mset.sid++ // Now create the subscription - return c.processSub(c.createSub([]byte(subject), nil, []byte(strconv.Itoa(mset.sid)), cb), false) + return c.processSub([]byte(subject), nil, []byte(strconv.Itoa(mset.sid)), cb, false) } // Helper for unlocked stream. From 41fac39f8ebf99a318f480d705c6f1138336ac22 Mon Sep 17 00:00:00 2001 From: Ivan Kozlovic Date: Tue, 1 Dec 2020 18:12:07 -0700 Subject: [PATCH 06/11] Split createClient() into versions for normal, WS and MQTT clients. This duplicate quite a bit of code, but reduces the conditionals in the createClient() function. Signed-off-by: Ivan Kozlovic --- server/client_test.go | 4 +- server/mqtt.go | 134 +++++++++++++++++++++++++++++++++++++++++- server/server.go | 87 +++++++-------------------- server/websocket.go | 98 +++++++++++++++++++++++++++++- 4 files changed, 253 insertions(+), 70 deletions(-) diff --git a/server/client_test.go b/server/client_test.go index 68e48dcc..64899a9e 100644 --- a/server/client_test.go +++ b/server/client_test.go @@ -85,7 +85,7 @@ func createClientAsync(ch chan *client, s *Server, cli net.Conn) { s.grWG.Add(1) } go func() { - c := s.createClient(cli, nil, nil) + c := s.createClient(cli) // Must be here to suppress +OK c.opts.Verbose = false if startWriteLoop { @@ -2317,7 +2317,7 @@ func TestCloseConnectionVeryEarly(t *testing.T) { // Call again with this closed connection. Alternatively, we // would have to call with a fake connection that implements // net.Conn but returns an error on Write. - s.createClient(c, nil, nil) + s.createClient(c) // This connection should not have been added to the server. checkClientsCount(t, s, 0) diff --git a/server/mqtt.go b/server/mqtt.go index dcc4a957..f47d7675 100644 --- a/server/mqtt.go +++ b/server/mqtt.go @@ -15,6 +15,7 @@ package server import ( "bytes" + "crypto/tls" "encoding/binary" "encoding/json" "errors" @@ -279,10 +280,141 @@ func (s *Server) startMQTT() { scheme = "tls" } s.Noticef("Listening for MQTT clients on %s://%s:%d", scheme, o.Host, o.Port) - go s.acceptConnections(hl, "MQTT", func(conn net.Conn) { s.createClient(conn, nil, &mqtt{}) }, nil) + go s.acceptConnections(hl, "MQTT", func(conn net.Conn) { s.createMQTTClient(conn) }, nil) s.mu.Unlock() } +// This is similar to createClient() but has some modifications specifi to MQTT clients. +// The comments have been kept to minimum to reduce code size. Check createClient() for +// more details. +func (s *Server) createMQTTClient(conn net.Conn) *client { + opts := s.getOpts() + + maxPay := int32(opts.MaxPayload) + maxSubs := int32(opts.MaxSubs) + if maxSubs == 0 { + maxSubs = -1 + } + now := time.Now() + + c := &client{srv: s, nc: conn, mpay: maxPay, msubs: maxSubs, start: now, last: now, mqtt: &mqtt{}} + // MQTT clients don't send NATS CONNECT protocols. So make it an "echo" + // client, but disable verbose and pedantic (by not setting them). + c.opts.Echo = true + + c.registerWithAccount(s.globalAccount()) + + s.mu.Lock() + // Check auth, override if applicable. + authRequired := s.info.AuthRequired || s.mqtt.authOverride + s.totalClients++ + s.mu.Unlock() + + c.mu.Lock() + if authRequired { + c.flags.set(expectConnect) + } + c.initClient() + c.Debugf("Client connection created") + c.mu.Unlock() + + s.mu.Lock() + if !s.running || s.ldm { + if s.shutdown { + conn.Close() + } + s.mu.Unlock() + return c + } + + if opts.MaxConn > 0 && len(s.clients) >= opts.MaxConn { + s.mu.Unlock() + c.maxConnExceeded() + return nil + } + s.clients[c.cid] = c + + tlsRequired := opts.MQTT.TLSConfig != nil + s.mu.Unlock() + + c.mu.Lock() + + isClosed := c.isClosed() + + var pre []byte + if !isClosed && tlsRequired && opts.AllowNonTLS { + pre = make([]byte, 4) + c.nc.SetReadDeadline(time.Now().Add(secondsToDuration(opts.MQTT.TLSTimeout))) + n, _ := io.ReadFull(c.nc, pre[:]) + c.nc.SetReadDeadline(time.Time{}) + pre = pre[:n] + if n > 0 && pre[0] == 0x16 { + tlsRequired = true + } else { + tlsRequired = false + } + } + + if !isClosed && tlsRequired { + c.Debugf("Starting TLS client connection handshake") + if len(pre) > 0 { + c.nc = &tlsMixConn{c.nc, bytes.NewBuffer(pre)} + pre = nil + } + + c.nc = tls.Server(c.nc, opts.MQTT.TLSConfig) + conn := c.nc.(*tls.Conn) + + ttl := secondsToDuration(opts.MQTT.TLSTimeout) + time.AfterFunc(ttl, func() { tlsTimeout(c, conn) }) + conn.SetReadDeadline(time.Now().Add(ttl)) + + c.mu.Unlock() + if err := conn.Handshake(); err != nil { + c.Errorf("TLS handshake error: %v", err) + c.closeConnection(TLSHandshakeError) + return nil + } + conn.SetReadDeadline(time.Time{}) + + c.mu.Lock() + + c.flags.set(handshakeComplete) + + isClosed = c.isClosed() + } + + if isClosed { + c.mu.Unlock() + c.closeConnection(WriteError) + return nil + } + + if authRequired { + timeout := opts.AuthTimeout + // Possibly override with MQTT specific value. + if opts.MQTT.AuthTimeout != 0 { + timeout = opts.MQTT.AuthTimeout + } + c.setAuthTimer(secondsToDuration(timeout)) + } + + // No Ping timer for MQTT clients... + + s.startGoRoutine(func() { c.readLoop(pre) }) + s.startGoRoutine(func() { c.writeLoop() }) + + if tlsRequired { + c.Debugf("TLS handshake complete") + cs := c.nc.(*tls.Conn).ConnectionState() + c.Debugf("TLS version %s, cipher suite %s", tlsVersion(cs.Version), tlsCipher(cs.CipherSuite)) + } + + c.mu.Unlock() + + return c +} + // Given the mqtt options, we check if any auth configuration // has been provided. If so, possibly create users/nkey users and // store them in s.mqtt.users/nkeys. diff --git a/server/server.go b/server/server.go index 59767929..88d03794 100644 --- a/server/server.go +++ b/server/server.go @@ -1765,7 +1765,7 @@ func (s *Server) AcceptLoop(clr chan struct{}) { s.clientConnectURLs = s.getClientConnectURLs() s.listener = l - go s.acceptConnections(l, "Client", func(conn net.Conn) { s.createClient(conn, nil, nil) }, + go s.acceptConnections(l, "Client", func(conn net.Conn) { s.createClient(conn) }, func(_ error) bool { if s.isLameDuckMode() { // Signal that we are not accepting new clients @@ -2091,7 +2091,7 @@ func (c *tlsMixConn) Read(b []byte) (int, error) { return c.Conn.Read(b) } -func (s *Server) createClient(conn net.Conn, ws *websocket, mqtt *mqtt) *client { +func (s *Server) createClient(conn net.Conn) *client { // Snapshot server options. opts := s.getOpts() @@ -2103,14 +2103,7 @@ func (s *Server) createClient(conn net.Conn, ws *websocket, mqtt *mqtt) *client } now := time.Now() - c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: maxPay, msubs: maxSubs, start: now, last: now, ws: ws} - if mqtt != nil { - c.mqtt = mqtt - // Set some of the options here since MQTT clients don't - // send a regular CONNECT (but have their own). - c.opts.Lang = "mqtt" - c.opts.Verbose = false - } + c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: maxPay, msubs: maxSubs, start: now, last: now} c.registerWithAccount(s.globalAccount()) @@ -2118,28 +2111,17 @@ func (s *Server) createClient(conn net.Conn, ws *websocket, mqtt *mqtt) *client var authRequired bool s.mu.Lock() - // We don't need the INFO to mqtt clients. - if mqtt == nil { - // Grab JSON info string - info = s.copyInfo() - // If this is a websocket client and there is no top-level auth specified, - // then we use the websocket's specific boolean that will be set to true - // if there is any auth{} configured in websocket{}. - if ws != nil && !info.AuthRequired { - info.AuthRequired = s.websocket.authOverride - } - if s.nonceRequired() { - // Nonce handling - var raw [nonceLen]byte - nonce := raw[:] - s.generateNonce(nonce) - info.Nonce = string(nonce) - } - c.nonce = []byte(info.Nonce) - authRequired = info.AuthRequired - } else { - authRequired = s.info.AuthRequired || s.mqtt.authOverride + // Grab JSON info string + info = s.copyInfo() + if s.nonceRequired() { + // Nonce handling + var raw [nonceLen]byte + nonce := raw[:] + s.generateNonce(nonce) + info.Nonce = string(nonce) } + c.nonce = []byte(info.Nonce) + authRequired = info.AuthRequired s.totalClients++ s.mu.Unlock() @@ -2155,13 +2137,10 @@ func (s *Server) createClient(conn net.Conn, ws *websocket, mqtt *mqtt) *client c.Debugf("Client connection created") - // We don't send the INFO to mqtt clients. - if mqtt == nil { - // Send our information. - // Need to be sent in place since writeLoop cannot be started until - // TLS handshake is done (if applicable). - c.sendProtoNow(c.generateClientInfoJSON(info)) - } + // Send our information. + // Need to be sent in place since writeLoop cannot be started until + // TLS handshake is done (if applicable). + c.sendProtoNow(c.generateClientInfoJSON(info)) // Unlock to register c.mu.Unlock() @@ -2194,21 +2173,7 @@ func (s *Server) createClient(conn net.Conn, ws *websocket, mqtt *mqtt) *client } s.clients[c.cid] = c - // May be overridden based on type of client. - TLSConfig := opts.TLSConfig - TLSTimeout := opts.TLSTimeout - tlsRequired := info.TLSRequired - // Websocket clients do TLS in the websocket http server. - if ws != nil { - tlsRequired = false - } else if mqtt != nil { - tlsRequired = opts.MQTT.TLSConfig != nil - if tlsRequired { - TLSConfig = opts.MQTT.TLSConfig - TLSTimeout = opts.MQTT.TLSTimeout - } - } s.mu.Unlock() // Re-Grab lock @@ -2222,7 +2187,7 @@ func (s *Server) createClient(conn net.Conn, ws *websocket, mqtt *mqtt) *client // one the client wants. if !isClosed && opts.TLSConfig != nil && opts.AllowNonTLS { pre = make([]byte, 4) - c.nc.SetReadDeadline(time.Now().Add(secondsToDuration(TLSTimeout))) + c.nc.SetReadDeadline(time.Now().Add(secondsToDuration(opts.TLSTimeout))) n, _ := io.ReadFull(c.nc, pre[:]) c.nc.SetReadDeadline(time.Time{}) pre = pre[:n] @@ -2243,11 +2208,11 @@ func (s *Server) createClient(conn net.Conn, ws *websocket, mqtt *mqtt) *client pre = nil } - c.nc = tls.Server(c.nc, TLSConfig) + c.nc = tls.Server(c.nc, opts.TLSConfig) conn := c.nc.(*tls.Conn) // Setup the timeout - ttl := secondsToDuration(TLSTimeout) + ttl := secondsToDuration(opts.TLSTimeout) time.AfterFunc(ttl, func() { tlsTimeout(c, conn) }) conn.SetReadDeadline(time.Now().Add(ttl)) @@ -2285,17 +2250,7 @@ func (s *Server) createClient(conn net.Conn, ws *websocket, mqtt *mqtt) *client // the race where the timer fires during the handshake and causes the // server to write bad data to the socket. See issue #432. if authRequired { - timeout := opts.AuthTimeout - // For websocket, possibly override only if set. We make sure that - // opts.AuthTimeout is set to a default value if not configured, - // but we don't do the same for websocket's one so that we know - // if user has explicitly set or not. - if ws != nil && opts.Websocket.AuthTimeout != 0 { - timeout = opts.Websocket.AuthTimeout - } else if mqtt != nil && opts.MQTT.AuthTimeout != 0 { - timeout = opts.MQTT.AuthTimeout - } - c.setAuthTimer(secondsToDuration(timeout)) + c.setAuthTimer(secondsToDuration(opts.AuthTimeout)) } // Do final client initialization diff --git a/server/websocket.go b/server/websocket.go index 10defdc4..c29722f5 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -871,7 +871,7 @@ func (s *Server) startWebsocketServer() { s.Errorf(err.Error()) return } - s.createClient(res.conn, res.ws, nil) + s.createWSClient(res.conn, res.ws) }) hs := &http.Server{ Addr: hp, @@ -897,6 +897,102 @@ func (s *Server) startWebsocketServer() { s.mu.Unlock() } +// This is similar to createClient() but has some modifications +// specific to handle websocket clients. +// The comments have been kept to minimum to reduce code size. +// Check createClient() for more details. +func (s *Server) createWSClient(conn net.Conn, ws *websocket) *client { + opts := s.getOpts() + + maxPay := int32(opts.MaxPayload) + maxSubs := int32(opts.MaxSubs) + if maxSubs == 0 { + maxSubs = -1 + } + now := time.Now() + + c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: maxPay, msubs: maxSubs, start: now, last: now, ws: ws} + + c.registerWithAccount(s.globalAccount()) + + var info Info + var authRequired bool + + s.mu.Lock() + info = s.copyInfo() + // Check auth, override if applicable. + if !info.AuthRequired { + // Set info.AuthRequired since this is what is sent to the client. + info.AuthRequired = s.websocket.authOverride + } + if s.nonceRequired() { + var raw [nonceLen]byte + nonce := raw[:] + s.generateNonce(nonce) + info.Nonce = string(nonce) + } + c.nonce = []byte(info.Nonce) + authRequired = info.AuthRequired + + s.totalClients++ + s.mu.Unlock() + + c.mu.Lock() + if authRequired { + c.flags.set(expectConnect) + } + c.initClient() + c.Debugf("Client connection created") + c.sendProtoNow(c.generateClientInfoJSON(info)) + c.mu.Unlock() + + s.mu.Lock() + if !s.running || s.ldm { + if s.shutdown { + conn.Close() + } + s.mu.Unlock() + return c + } + + if opts.MaxConn > 0 && len(s.clients) >= opts.MaxConn { + s.mu.Unlock() + c.maxConnExceeded() + return nil + } + s.clients[c.cid] = c + + // Websocket clients do TLS in the websocket http server. + // So no TLS here... + s.mu.Unlock() + + c.mu.Lock() + + if c.isClosed() { + c.mu.Unlock() + c.closeConnection(WriteError) + return nil + } + + if authRequired { + timeout := opts.AuthTimeout + // Possibly override with Websocket specific value. + if opts.Websocket.AuthTimeout != 0 { + timeout = opts.Websocket.AuthTimeout + } + c.setAuthTimer(secondsToDuration(timeout)) + } + + c.setPingTimer() + + s.startGoRoutine(func() { c.readLoop(nil) }) + s.startGoRoutine(func() { c.writeLoop() }) + + c.mu.Unlock() + + return c +} + type wsCaptureHTTPServerLog struct { s *Server } From 67425d23c8b4edfe5749a0e6b7e5be8a28d35fa5 Mon Sep 17 00:00:00 2001 From: Ivan Kozlovic Date: Wed, 2 Dec 2020 15:52:06 -0700 Subject: [PATCH 07/11] Add c.isMqtt() and c.isWebsocket() This hides the check on "c.mqtt != nil" or "c.ws != nil". Added some tests. Signed-off-by: Ivan Kozlovic --- server/auth.go | 9 ++++--- server/client.go | 51 ++++++++++++++++++++-------------------- server/mqtt.go | 8 ++++++- server/mqtt_test.go | 51 ++++++++++++++++++++++++++++++++++++++-- server/parser.go | 2 +- server/route.go | 2 +- server/server.go | 2 +- server/websocket.go | 6 +++++ server/websocket_test.go | 2 +- 9 files changed, 96 insertions(+), 37 deletions(-) diff --git a/server/auth.go b/server/auth.go index da27dc7e..e925b898 100644 --- a/server/auth.go +++ b/server/auth.go @@ -345,11 +345,10 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) boo ) s.mu.Lock() authRequired := s.info.AuthRequired - // c.ws/mqtt is immutable, but may need lock if we get race reports. if !authRequired { - if c.mqtt != nil { + if c.isMqtt() { authRequired = s.mqtt.authOverride - } else if c.ws != nil { + } else if c.isWebsocket() { // If no auth required for regular clients, then check if // we have an override for websocket clients. authRequired = s.websocket.authOverride @@ -367,7 +366,7 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) boo noAuthUser string ) tlsMap := opts.TLSMap - if c.mqtt != nil { + if c.isMqtt() { mo := &opts.MQTT // Always override TLSMap. tlsMap = mo.TLSMap @@ -380,7 +379,7 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) boo token = mo.Token ao = true } - } else if c.ws != nil { + } else if c.isWebsocket() { wo := &opts.Websocket // Always override TLSMap. tlsMap = wo.TLSMap diff --git a/server/client.go b/server/client.go index 485a72df..cd43d58a 100644 --- a/server/client.go +++ b/server/client.go @@ -541,9 +541,9 @@ func (c *client) initClient() { switch c.kind { case CLIENT: name := "cid" - if c.ws != nil { + if c.isWebsocket() { name = "wid" - } else if c.mqtt != nil { + } else if c.isMqtt() { name = "mid" } c.ncs.Store(fmt.Sprintf("%s - %s:%d", conn, name, c.cid)) @@ -968,8 +968,8 @@ func (c *client) readLoop(pre []byte) { return } nc := c.nc - ws := c.ws != nil - if c.mqtt != nil { + ws := c.isWebsocket() + if c.isMqtt() { c.mqtt.r = &mqttReader{reader: nc} } c.in.rsz = startBufSize @@ -991,7 +991,7 @@ func (c *client) readLoop(pre []byte) { c.mu.Unlock() defer func() { - if c.mqtt != nil { + if c.isMqtt() { s.mqttHandleWill(c) } // These are used only in the readloop, so we can set them to nil @@ -1155,7 +1155,7 @@ func closedStateForErr(err error) ClosedState { // collapsePtoNB will place primary onto nb buffer as needed in prep for WriteTo. // This will return a copy on purpose. func (c *client) collapsePtoNB() (net.Buffers, int64) { - if c.ws != nil { + if c.isWebsocket() { return c.wsCollapsePtoNB() } if c.out.p != nil { @@ -1169,7 +1169,7 @@ func (c *client) collapsePtoNB() (net.Buffers, int64) { // This will handle the fixup needed on a partial write. // Assume pending has been already calculated correctly. func (c *client) handlePartialWrite(pnb net.Buffers) { - if c.ws != nil { + if c.isWebsocket() { c.ws.frames = append(pnb, c.ws.frames...) return } @@ -1263,7 +1263,7 @@ func (c *client) flushOutbound() bool { // Subtract from pending bytes and messages. c.out.pb -= n - if c.ws != nil { + if c.isWebsocket() { c.ws.fs -= n } c.out.pm -= apm // FIXME(dlc) - this will not be totally accurate on partials. @@ -1382,7 +1382,7 @@ func (c *client) markConnAsClosed(reason ClosedState) { c.flags.set(connMarkedClosed) // For a websocket client, unless we are told not to flush, enqueue // a websocket CloseMessage based on the reason. - if !skipFlush && c.ws != nil && !c.ws.closeSent { + if !skipFlush && c.isWebsocket() && !c.ws.closeSent { c.wsEnqueueCloseMessage(reason) } // Be consistent with the creation: for routes and gateways, @@ -1942,7 +1942,7 @@ func (c *client) sendRTTPing() bool { // the c.rtt is 0 and wants to force an update by sending a PING. // Client lock held on entry. func (c *client) sendRTTPingLocked() bool { - if c.mqtt != nil { + if c.isMqtt() { return false } // Most client libs send a CONNECT+PING and wait for a PONG from the @@ -1975,7 +1975,7 @@ func (c *client) generateClientInfoJSON(info Info) []byte { info.CID = c.cid info.ClientIP = c.host info.MaxPayload = c.mpay - if c.ws != nil { + if c.isWebsocket() { info.ClientConnectURLs = info.WSConnectURLs } info.WSConnectURLs = nil @@ -1990,7 +1990,7 @@ func (c *client) sendErr(err string) { if c.trace { c.traceOutOp("-ERR", []byte(err)) } - if c.mqtt == nil { + if !c.isMqtt() { c.enqueueProto([]byte(fmt.Sprintf(errProto, err))) } c.mu.Unlock() @@ -3326,7 +3326,7 @@ func (c *client) processInboundClientMsg(msg []byte) bool { } // If MQTT client, check for retain flag now that we have passed permissions check - if c.mqtt != nil { + if c.isMqtt() { c.mqttHandlePubRetain() } @@ -4002,7 +4002,7 @@ func (c *client) processPingTimer() { c.mu.Lock() c.ping.tmr = nil // Check if connection is still opened - if c.isClosed() || c.mqtt != nil { + if c.isClosed() { c.mu.Unlock() return } @@ -4061,7 +4061,7 @@ func adjustPingIntervalForGateway(d time.Duration) time.Duration { // Lock should be held func (c *client) setPingTimer() { - if c.srv == nil || c.mqtt != nil { + if c.srv == nil { return } d := c.srv.getOpts().PingInterval @@ -4637,18 +4637,19 @@ func (c *client) connectionTypeAllowed(acts map[string]struct{}) bool { if len(acts) == 0 { return true } - // Assume standard client, then update based on presence of websocket - // or other type. - want := jwt.ConnectionTypeStandard - if c.kind == LEAF { + var want string + switch c.kind { + case CLIENT: + if c.isWebsocket() { + want = jwt.ConnectionTypeWebsocket + } else if c.isMqtt() { + want = jwt.ConnectionTypeMqtt + } else { + want = jwt.ConnectionTypeStandard + } + case LEAF: want = jwt.ConnectionTypeLeafnode } - if c.ws != nil { - want = jwt.ConnectionTypeWebsocket - } - if c.mqtt != nil { - want = jwt.ConnectionTypeMqtt - } _, ok := acts[want] return ok } diff --git a/server/mqtt.go b/server/mqtt.go index f47d7675..fba4a644 100644 --- a/server/mqtt.go +++ b/server/mqtt.go @@ -456,6 +456,12 @@ func validateMQTTOptions(o *Options) error { return nil } +// Returns true if this connection is from a MQTT client. +// Lock held on entry. +func (c *client) isMqtt() bool { + return c.mqtt != nil +} + // Parse protocols inside the given buffer. // This is invoked from the readLoop. func (c *client) mqttParse(buf []byte) error { @@ -2019,7 +2025,7 @@ func mqttDeliverMsgCb(sub *subscription, pc *client, subject, reply string, msg } // In JS case, we need to use the pc.ca.deliver value as the subject. subject = string(pc.pa.deliver) - } else if pc.mqtt != nil { + } else if pc.isMqtt() { // This is a MQTT publisher... ppFlags = pc.mqtt.pp.flags pQoS = mqttGetQoS(ppFlags) diff --git a/server/mqtt_test.go b/server/mqtt_test.go index 66375287..23ebefe3 100644 --- a/server/mqtt_test.go +++ b/server/mqtt_test.go @@ -570,7 +570,7 @@ func testMQTTGetClient(t testing.TB, s *Server, clientID string) *client { s.mu.Lock() for _, c := range s.clients { c.mu.Lock() - if c.mqtt != nil && c.mqtt.cp != nil && c.mqtt.cp.clientID == clientID { + if c.isMqtt() && c.mqtt.cp != nil && c.mqtt.cp.clientID == clientID { mc = c } c.mu.Unlock() @@ -725,6 +725,30 @@ func testMQTTCheckConnAck(t testing.TB, r *mqttReader, rc byte, sessionPresent b } } +func TestMQTTRequiresJSEnabled(t *testing.T) { + o := testMQTTDefaultOptions() + acc := NewAccount("mqtt") + o.Accounts = []*Account{acc} + o.Users = []*User{&User{Username: "mqtt", Account: acc}} + s := testMQTTRunServer(t, o) + defer s.Shutdown() + + addr := fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port) + c, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Error creating mqtt connection: %v", err) + } + defer c.Close() + + proto := mqttCreateConnectProto(&mqttConnInfo{cleanSess: true, user: "mqtt"}) + if _, err := testMQTTWrite(c, proto); err != nil { + t.Fatalf("Error writing connect: %v", err) + } + if _, err := testMQTTRead(c); err == nil { + t.Fatal("Expected failure, did not get one") + } +} + func testMQTTEnableJSForAccount(t *testing.T, s *Server, accName string) { t.Helper() acc, err := s.LookupAccount(accName) @@ -835,7 +859,7 @@ func TestMQTTTLSVerifyAndMap(t *testing.T) { s.mu.Lock() for _, sc := range s.clients { sc.mu.Lock() - if sc.mqtt != nil { + if sc.isMqtt() { c = sc } sc.mu.Unlock() @@ -1380,6 +1404,29 @@ func TestMQTTConnKeepAlive(t *testing.T) { testMQTTExpectDisconnect(t, mc) } +func TestMQTTDontSetPinger(t *testing.T) { + o := testMQTTDefaultOptions() + o.PingInterval = 15 * time.Millisecond + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mc, r := testMQTTConnect(t, &mqttConnInfo{clientID: "mqtt", cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + c := testMQTTGetClient(t, s, "mqtt") + c.mu.Lock() + timerSet := c.ping.tmr != nil + c.mu.Unlock() + if timerSet { + t.Fatalf("Ping timer should not be set for MQTT clients") + } + + // Wait a bit and expect nothing (and connection should still be valid) + testMQTTExpectNothing(t, r) + testMQTTPublish(t, mc, r, 0, false, false, "foo", 0, []byte("msg")) +} + func TestMQTTTopicAndSubjectConversion(t *testing.T) { for _, test := range []struct { name string diff --git a/server/parser.go b/server/parser.go index f76337df..ad2a83b3 100644 --- a/server/parser.go +++ b/server/parser.go @@ -132,7 +132,7 @@ const ( func (c *client) parse(buf []byte) error { // Branch out to mqtt clients. c.mqtt is immutable, but should it become // an issue (say data race detection), we could branch outside in readLoop - if c.mqtt != nil { + if c.isMqtt() { return c.mqttParse(buf) } var i int diff --git a/server/route.go b/server/route.go index 56b2f238..9ad52421 100644 --- a/server/route.go +++ b/server/route.go @@ -726,7 +726,7 @@ func (s *Server) sendAsyncInfoToClients(regCli, wsCli bool) { // registered (server has received CONNECT and first PING). For // clients that are not at this stage, this will happen in the // processing of the first PING (see client.processPing) - if ((regCli && c.ws == nil) || (wsCli && c.ws != nil)) && + if ((regCli && !c.isWebsocket()) || (wsCli && c.isWebsocket())) && c.opts.Protocol >= ClientProtoInfo && c.flags.isSet(firstPongSent) { // sendInfo takes care of checking if the connection is still diff --git a/server/server.go b/server/server.go index 88d03794..d64cddd2 100644 --- a/server/server.go +++ b/server/server.go @@ -2431,7 +2431,7 @@ func (s *Server) removeClient(c *client) { if updateProtoInfoCount { s.cproto-- } - mqtt := c.mqtt != nil + mqtt := c.isMqtt() s.mu.Unlock() if mqtt { s.mqttHandleClosedClient(c) diff --git a/server/websocket.go b/server/websocket.go index c29722f5..901b4017 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -152,6 +152,12 @@ func wsGet(r io.Reader, buf []byte, pos, needed int) ([]byte, int, error) { return b, pos + avail, nil } +// Returns true if this connection is from a Websocket client. +// Lock held on entry. +func (c *client) isWebsocket() bool { + return c.ws != nil +} + // Returns a slice of byte slices corresponding to payload of websocket frames. // The byte slice `buf` is filled with bytes from the connection's read loop. // This function will decode the frame headers and unmask the payload(s). diff --git a/server/websocket_test.go b/server/websocket_test.go index e0b075fb..101d1ba1 100644 --- a/server/websocket_test.go +++ b/server/websocket_test.go @@ -2567,7 +2567,7 @@ func TestWSWebrowserClient(t *testing.T) { } c.mu.Lock() - ok := c.ws != nil && c.ws.browser == true + ok := c.isWebsocket() && c.ws.browser == true c.mu.Unlock() if !ok { t.Fatalf("Client is not marked as webrowser client") From cf9ba928ca7c41cf48541d0ea232673b4882a2f9 Mon Sep 17 00:00:00 2001 From: Ivan Kozlovic Date: Wed, 2 Dec 2020 17:00:47 -0700 Subject: [PATCH 08/11] Fixed some MQTT tests Signed-off-by: Ivan Kozlovic --- server/mqtt_test.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/server/mqtt_test.go b/server/mqtt_test.go index 23ebefe3..c01955b0 100644 --- a/server/mqtt_test.go +++ b/server/mqtt_test.go @@ -2913,6 +2913,7 @@ func TestMQTTPublishRetain(t *testing.T) { defer mc1.Close() testMQTTCheckConnAck(t, rs1, mqttConnAckRCConnectionAccepted, false) testMQTTPublish(t, mc1, rs1, 0, false, test.retained, "foo", 0, []byte(test.sentValue)) + testMQTTFlush(t, mc1, nil, rs1) mc2, rs2 := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) defer mc2.Close() @@ -2960,6 +2961,7 @@ func TestMQTTPublishRetainPermViolation(t *testing.T) { defer mc1.Close() testMQTTCheckConnAck(t, rs1, mqttConnAckRCConnectionAccepted, false) testMQTTPublish(t, mc1, rs1, 0, false, true, "bar", 0, []byte("retained")) + testMQTTFlush(t, mc1, nil, rs1) mc2, rs2 := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port) defer mc2.Close() @@ -3357,6 +3359,7 @@ func TestMQTTPersistRetainedMsg(t *testing.T) { testMQTTPublish(t, c, r, 0, false, true, "baz", 1, []byte("baz1")) // Remove bar testMQTTPublish(t, c, r, 1, false, true, "bar", 1, nil) + testMQTTFlush(t, c, nil, r) testMQTTDisconnect(t, c, nil) c.Close() @@ -3676,6 +3679,21 @@ func TestMQTTMaxAckPending(t *testing.T) { testMQTTSendPubAck(t, c, pi) testMQTTCheckPubMsg(t, c, r, "foo", mqttPubQos1, []byte("msg4")) + // Make sure this message gets ack'ed + mcli := testMQTTGetClient(t, s, cisub.clientID) + checkFor(t, time.Second, 15*time.Millisecond, func() error { + mcli.mu.Lock() + sess := mcli.mqtt.sess + sess.mu.Lock() + np := len(sess.pending) + sess.mu.Unlock() + mcli.mu.Unlock() + if np != 0 { + return fmt.Errorf("Still %v pending messages", np) + } + return nil + }) + // Check that change to config does not prevent restart of sub. cp.Close() c.Close() From 035cffae375e34eed0020aa6f37a533823721601 Mon Sep 17 00:00:00 2001 From: Ivan Kozlovic Date: Thu, 3 Dec 2020 14:23:57 -0700 Subject: [PATCH 09/11] Add clientType() which returns NATS/MQTT/WS for CLIENT connections. It returns NON_CLIENT if invoked from a non CLIENT connection. Signed-off-by: Ivan Kozlovic --- server/auth.go | 64 +++++++++++++++++++++++++----------------------- server/client.go | 58 ++++++++++++++++++++++++++++++++++--------- 2 files changed, 81 insertions(+), 41 deletions(-) diff --git a/server/auth.go b/server/auth.go index e925b898..115b675f 100644 --- a/server/auth.go +++ b/server/auth.go @@ -346,11 +346,12 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) boo s.mu.Lock() authRequired := s.info.AuthRequired if !authRequired { - if c.isMqtt() { + // If no auth required for regular clients, then check if + // we have an override for MQTT or Websocket clients. + switch c.clientType() { + case MQTT: authRequired = s.mqtt.authOverride - } else if c.isWebsocket() { - // If no auth required for regular clients, then check if - // we have an override for websocket clients. + case WS: authRequired = s.websocket.authOverride } } @@ -366,33 +367,36 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) boo noAuthUser string ) tlsMap := opts.TLSMap - if c.isMqtt() { - mo := &opts.MQTT - // Always override TLSMap. - tlsMap = mo.TLSMap - // The rest depends on if there was any auth override in - // the mqtt's config. - if s.mqtt.authOverride { - noAuthUser = mo.NoAuthUser - username = mo.Username - password = mo.Password - token = mo.Token - ao = true + if c.kind == CLIENT { + switch c.clientType() { + case MQTT: + mo := &opts.MQTT + // Always override TLSMap. + tlsMap = mo.TLSMap + // The rest depends on if there was any auth override in + // the mqtt's config. + if s.mqtt.authOverride { + noAuthUser = mo.NoAuthUser + username = mo.Username + password = mo.Password + token = mo.Token + ao = true + } + case WS: + wo := &opts.Websocket + // Always override TLSMap. + tlsMap = wo.TLSMap + // The rest depends on if there was any auth override in + // the websocket's config. + if s.websocket.authOverride { + noAuthUser = wo.NoAuthUser + username = wo.Username + password = wo.Password + token = wo.Token + ao = true + } } - } else if c.isWebsocket() { - wo := &opts.Websocket - // Always override TLSMap. - tlsMap = wo.TLSMap - // The rest depends on if there was any auth override in - // the websocket's config. - if s.websocket.authOverride { - noAuthUser = wo.NoAuthUser - username = wo.Username - password = wo.Password - token = wo.Token - ao = true - } - } else if c.kind == LEAF { + } else { tlsMap = opts.LeafNode.TLSMap } if !ao { diff --git a/server/client.go b/server/client.go index cd43d58a..a3b3c832 100644 --- a/server/client.go +++ b/server/client.go @@ -51,6 +51,20 @@ const ( ACCOUNT ) +// Extended type of a CLIENT connection. This is returned by c.clientType() +// and indicate what type of client connection we are dealing with. +// If invoked on a non CLIENT connection, NON_CLIENT type is returned. +const ( + // If the connection is not a CLIENT connection. + NON_CLIENT = iota + // Regular NATS client. + NATS + // MQTT client. + MQTT + // Websocket client. + WS +) + const ( // ClientProtoZero is the original Client protocol from 2009. // http://nats.io/documentation/internals/nats-protocol/ @@ -427,6 +441,26 @@ func (c *client) GetTLSConnectionState() *tls.ConnectionState { return &state } +// For CLIENT connections, this function returns the client type, that is, +// NATS (for regular clients), MQTT or WS for websocket. +// If this is invoked for a non CLIENT connection, NON_CLIENT is returned. +// +// This function does not lock the client and accesses fields that are supposed +// to be immutable and therefore it can be invoked outside of the client's lock. +func (c *client) clientType() int { + switch c.kind { + case CLIENT: + if c.isMqtt() { + return MQTT + } else if c.isWebsocket() { + return WS + } + return NATS + default: + return NON_CLIENT + } +} + // This is the main subscription struct that indicates // interest in published messages. // FIXME(dlc) - This is getting bloated for normal subs, need @@ -540,13 +574,14 @@ func (c *client) initClient() { switch c.kind { case CLIENT: - name := "cid" - if c.isWebsocket() { - name = "wid" - } else if c.isMqtt() { - name = "mid" + switch c.clientType() { + case NATS: + c.ncs.Store(fmt.Sprintf("%s - cid:%d", conn, c.cid)) + case WS: + c.ncs.Store(fmt.Sprintf("%s - wid:%d", conn, c.cid)) + case MQTT: + c.ncs.Store(fmt.Sprintf("%s - mid:%d", conn, c.cid)) } - c.ncs.Store(fmt.Sprintf("%s - %s:%d", conn, name, c.cid)) case ROUTER: c.ncs.Store(fmt.Sprintf("%s - rid:%d", conn, c.cid)) case GATEWAY: @@ -4640,12 +4675,13 @@ func (c *client) connectionTypeAllowed(acts map[string]struct{}) bool { var want string switch c.kind { case CLIENT: - if c.isWebsocket() { - want = jwt.ConnectionTypeWebsocket - } else if c.isMqtt() { - want = jwt.ConnectionTypeMqtt - } else { + switch c.clientType() { + case NATS: want = jwt.ConnectionTypeStandard + case WS: + want = jwt.ConnectionTypeWebsocket + case MQTT: + want = jwt.ConnectionTypeMqtt } case LEAF: want = jwt.ConnectionTypeLeafnode From 415a7071a750d73c281fc0cd13ed5838fa527b75 Mon Sep 17 00:00:00 2001 From: Ivan Kozlovic Date: Thu, 3 Dec 2020 17:57:51 -0700 Subject: [PATCH 10/11] Tweaks to mqttProcessConnect() The test TestMQTTPersistedSession() flapped once on GA. It turns out that when the server was sending CONNACK the test was immediately using a NATS publisher to send a message that was not received by the MQTT subscription for the recovered session. Sending the CONNACK before restoring subscriptions allowed for a window where a different connection could publish and messages would be missed. It is technically ok, I think, and test could have been easily fixed to ensure that we don't NATS publish before the session is fully restored. However, I have changed the order to first restore subscriptions then send the CONNACK. The way locking happens with MQTT subscriptions we are sure that the CONNACK will be sent first because even if there are inflight messages, the MQTT callbacks will have to wait for the session lock under which the subscriptions are restored and the CONNACK sent. Signed-off-by: Ivan Kozlovic --- server/mqtt.go | 46 ++++++++++++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/server/mqtt.go b/server/mqtt.go index fba4a644..0f8760ea 100644 --- a/server/mqtt.go +++ b/server/mqtt.go @@ -1480,9 +1480,12 @@ func (s *Server) mqttProcessConnect(c *client, cp *mqttConnectProto, trace bool) sessp := false // Do we have an existing session for this client ID es, ok := asm.sessions[cp.clientID] + if !ok { + es = mqttSessionCreate(s.getOpts()) + } + es.mu.Lock() + defer es.mu.Unlock() if ok { - es.mu.Lock() - defer es.mu.Unlock() // Clear the session if client wants a clean session. // Also, Spec [MQTT-3.2.2-1]: don't report session present if cleanSess || es.clean { @@ -1496,13 +1499,10 @@ func (s *Server) mqttProcessConnect(c *client, cp *mqttConnectProto, trace bool) // Report to the client that the session was present sessp = true } - ec := es.c - // Is there an actual client associated with this session. - if ec != nil { - // Spec [MQTT-3.1.4-2]. If the ClientId represents a Client already - // connected to the Server then the Server MUST disconnect the existing - // client. - ec := es.c + // Spec [MQTT-3.1.4-2]. If the ClientId represents a Client already + // connected to the Server then the Server MUST disconnect the existing + // client. + if ec := es.c; ec != nil { ec.mu.Lock() // Remove will before closing ec.mqtt.cp.will = nil @@ -1516,12 +1516,10 @@ func (s *Server) mqttProcessConnect(c *client, cp *mqttConnectProto, trace bool) } else { // Spec [MQTT-3.2.2-3]: if the Server does not have stored Session state, // it MUST set Session Present to 0 in the CONNACK packet. - es = mqttSessionCreate(s.getOpts()) es.c, es.clean, es.stream = c, cleanSess, asm.sstream - es.mu.Lock() - defer es.mu.Unlock() - asm.sessions[cp.clientID] = es es.save(cp.clientID) + // Now save this new session into the account sessions + asm.sessions[cp.clientID] = es } c.mu.Lock() c.flags.set(connectReceived) @@ -1529,11 +1527,21 @@ func (s *Server) mqttProcessConnect(c *client, cp *mqttConnectProto, trace bool) c.mqtt.asm = asm c.mqtt.sess = es c.mu.Unlock() - // Spec [MQTT-3.2.0-1]: At this point we need to send the CONNACK before - // restoring subscriptions, because CONNACK must be the first packet sent - // to the client. - sendConnAck(mqttConnAckRCConnectionAccepted, sessp) - // Now process possible saved subscriptions. + // + // Spec [MQTT-3.2.0-1]: CONNACK must be the first protocol sent to the + // session. However, we are going to possibly restore the subscriptions + // first and then send the CONNACK. This will help tests that restore + // a MQTT connection with subs and immediately use NATS to publish. + // In that case, message would not be received because the pub could + // occur before the subscriptions are processed here. It would be + // easy to fix test with doing a PINGREQ/PINGRESP before doing NATS pub, + // but it seems better to ensure that everything is setup before sending + // back the CONNACK. + // Note that since we are under the session lock, the subs callback will + // have to wait to acquire the lock, so we are still guaranteed to enqueue + // the CONNACK before any message. + // + // Process possible saved subscriptions. if l := len(es.subs); l > 0 { filters := make([]*mqttFilter, 0, l) for subject, qos := range es.subs { @@ -1543,6 +1551,8 @@ func (s *Server) mqttProcessConnect(c *client, cp *mqttConnectProto, trace bool) return err } } + // Now send the CONNACK + sendConnAck(mqttConnAckRCConnectionAccepted, sessp) return nil } From 1d7c4712a5cfb4384a1a6498e75382f7be975f6a Mon Sep 17 00:00:00 2001 From: Ivan Kozlovic Date: Fri, 4 Dec 2020 14:42:37 -0700 Subject: [PATCH 11/11] Increase Pub performance Essentially make publish a zero alloc. Use c.mqtt.pp as the parser publish packet structure. Messages were initially copied because MQTT messages don't have CR_LF but was adding it so that it worked for NATS pub/subs and MQTT pub/subs. Now an MQTT producer sending to NATS sub will queue CR_LF after payload. Here is result of benchcmp for MQTT pub runs only: ``` benchmark old ns/op new ns/op delta BenchmarkMQTT_QoS0_Pub_______0b_Payload-8 157 55.6 -64.59% BenchmarkMQTT_QoS0_Pub_______8b_Payload-8 167 61.0 -63.47% BenchmarkMQTT_QoS0_Pub______32b_Payload-8 181 65.3 -63.92% BenchmarkMQTT_QoS0_Pub_____128b_Payload-8 243 78.1 -67.86% BenchmarkMQTT_QoS0_Pub_____256b_Payload-8 298 95.0 -68.12% BenchmarkMQTT_QoS0_Pub_______1K_Payload-8 604 224 -62.91% BenchmarkMQTT_QoS1_Pub_______0b_Payload-8 1713 1314 -23.29% BenchmarkMQTT_QoS1_Pub_______8b_Payload-8 1703 1311 -23.02% BenchmarkMQTT_QoS1_Pub______32b_Payload-8 1722 1345 -21.89% BenchmarkMQTT_QoS1_Pub_____128b_Payload-8 2105 1432 -31.97% BenchmarkMQTT_QoS1_Pub_____256b_Payload-8 2148 1503 -30.03% BenchmarkMQTT_QoS1_Pub_______1K_Payload-8 3024 1889 -37.53% benchmark old MB/s new MB/s speedup BenchmarkMQTT_QoS0_Pub_______0b_Payload-8 31.76 89.91 2.83x BenchmarkMQTT_QoS0_Pub_______8b_Payload-8 77.79 213.01 2.74x BenchmarkMQTT_QoS0_Pub______32b_Payload-8 204.52 566.26 2.77x BenchmarkMQTT_QoS0_Pub_____128b_Payload-8 550.65 1715.96 3.12x BenchmarkMQTT_QoS0_Pub_____256b_Payload-8 877.77 2757.16 3.14x BenchmarkMQTT_QoS0_Pub_______1K_Payload-8 1705.02 4607.72 2.70x BenchmarkMQTT_QoS1_Pub_______0b_Payload-8 6.42 8.37 1.30x BenchmarkMQTT_QoS1_Pub_______8b_Payload-8 11.16 14.49 1.30x BenchmarkMQTT_QoS1_Pub______32b_Payload-8 24.97 31.97 1.28x BenchmarkMQTT_QoS1_Pub_____128b_Payload-8 66.52 97.74 1.47x BenchmarkMQTT_QoS1_Pub_____256b_Payload-8 124.78 178.27 1.43x BenchmarkMQTT_QoS1_Pub_______1K_Payload-8 342.64 548.32 1.60x benchmark old allocs new allocs delta BenchmarkMQTT_QoS0_Pub_______0b_Payload-8 3 0 -100.00% BenchmarkMQTT_QoS0_Pub_______8b_Payload-8 3 0 -100.00% BenchmarkMQTT_QoS0_Pub______32b_Payload-8 3 0 -100.00% BenchmarkMQTT_QoS0_Pub_____128b_Payload-8 4 0 -100.00% BenchmarkMQTT_QoS0_Pub_____256b_Payload-8 4 0 -100.00% BenchmarkMQTT_QoS0_Pub_______1K_Payload-8 4 0 -100.00% BenchmarkMQTT_QoS1_Pub_______0b_Payload-8 5 2 -60.00% BenchmarkMQTT_QoS1_Pub_______8b_Payload-8 5 2 -60.00% BenchmarkMQTT_QoS1_Pub______32b_Payload-8 5 2 -60.00% BenchmarkMQTT_QoS1_Pub_____128b_Payload-8 7 3 -57.14% BenchmarkMQTT_QoS1_Pub_____256b_Payload-8 7 3 -57.14% BenchmarkMQTT_QoS1_Pub_______1K_Payload-8 7 3 -57.14% benchmark old bytes new bytes delta BenchmarkMQTT_QoS0_Pub_______0b_Payload-8 80 0 -100.00% BenchmarkMQTT_QoS0_Pub_______8b_Payload-8 88 0 -100.00% BenchmarkMQTT_QoS0_Pub______32b_Payload-8 120 0 -100.00% BenchmarkMQTT_QoS0_Pub_____128b_Payload-8 224 0 -100.00% BenchmarkMQTT_QoS0_Pub_____256b_Payload-8 369 1 -99.73% BenchmarkMQTT_QoS0_Pub_______1K_Payload-8 1250 31 -97.52% BenchmarkMQTT_QoS1_Pub_______0b_Payload-8 106 28 -73.58% BenchmarkMQTT_QoS1_Pub_______8b_Payload-8 122 28 -77.05% BenchmarkMQTT_QoS1_Pub______32b_Payload-8 154 28 -81.82% BenchmarkMQTT_QoS1_Pub_____128b_Payload-8 381 157 -58.79% BenchmarkMQTT_QoS1_Pub_____256b_Payload-8 655 287 -56.18% BenchmarkMQTT_QoS1_Pub_______1K_Payload-8 2312 1078 -53.37% ``` Signed-off-by: Ivan Kozlovic --- server/client.go | 11 ++++++- server/mqtt.go | 79 ++++++++++++++++++++++++++++++++------------- server/mqtt_test.go | 3 +- 3 files changed, 68 insertions(+), 25 deletions(-) diff --git a/server/client.go b/server/client.go index a3b3c832..806978c3 100644 --- a/server/client.go +++ b/server/client.go @@ -2991,7 +2991,12 @@ func (c *client) deliverMsg(sub *subscription, subject, reply, mh, msg []byte, g // Update statistics // The msg includes the CR_LF, so pull back out for accounting. - msgSize := int64(len(msg) - LEN_CR_LF) + msgSize := int64(len(msg)) + prodIsMQTT := c.isMqtt() + // MQTT producers send messages without CR_LF, so don't remove it for them. + if !prodIsMQTT { + msgSize -= int64(LEN_CR_LF) + } // No atomic needed since accessed under client lock. // Monitor is reading those also under client's lock. @@ -3066,6 +3071,10 @@ func (c *client) deliverMsg(sub *subscription, subject, reply, mh, msg []byte, g // Queue to outbound buffer client.queueOutbound(mh) client.queueOutbound(msg) + if prodIsMQTT { + // Need to add CR_LF since MQTT producers don't send CR_LF + client.queueOutbound([]byte(CR_LF)) + } client.out.pm++ diff --git a/server/mqtt.go b/server/mqtt.go index 0f8760ea..b2470ff8 100644 --- a/server/mqtt.go +++ b/server/mqtt.go @@ -83,6 +83,9 @@ const ( mqttConnAckRCBadUserOrPassword = byte(0x4) mqttConnAckRCNotAuthorized = byte(0x5) + // Maximum payload size of a control packet + mqttMaxPayloadSize = 0xFFFFFFF + // Topic/Filter characters mqttTopicLevelSep = '/' mqttSingleLevelWC = '+' @@ -245,6 +248,7 @@ type mqttPublish struct { sz int pi uint16 flags byte + szb [9]byte // MQTT max payload size is 268,435,455 } func (s *Server) startMQTT() { @@ -298,6 +302,7 @@ func (s *Server) createMQTTClient(conn net.Conn) *client { now := time.Now() c := &client{srv: s, nc: conn, mpay: maxPay, msubs: maxSubs, start: now, last: now, mqtt: &mqtt{}} + c.mqtt.pp = &mqttPublish{} // MQTT clients don't send NATS CONNECT protocols. So make it an "echo" // client, but disable verbose and pedantic (by not setting them). c.opts.Echo = true @@ -509,16 +514,17 @@ func (c *client) mqttParse(buf []byte) error { switch pt { case mqttPacketPub: - pp := mqttPublish{flags: b & mqttPacketFlagMask} - err = c.mqttParsePub(r, pl, &pp) + pp := c.mqtt.pp + pp.flags = b & mqttPacketFlagMask + err = c.mqttParsePub(r, pl, pp) if trace { - c.traceInOp("PUBLISH", errOrTrace(err, mqttPubTrace(&pp))) + c.traceInOp("PUBLISH", errOrTrace(err, mqttPubTrace(pp))) if err == nil { - c.traceMsg(pp.msg) + c.mqttTraceMsg(pp.msg) } } if err == nil { - s.mqttProcessPub(c, &pp) + s.mqttProcessPub(c, pp) if pp.pi > 0 { c.mqttEnqueuePubAck(pp.pi) if trace { @@ -630,6 +636,15 @@ func (c *client) mqttParse(buf []byte) error { return err } +func (c *client) mqttTraceMsg(msg []byte) { + maxTrace := c.srv.getOpts().MaxTracedMsgLen + if maxTrace > 0 && len(msg) > maxTrace { + c.Tracef("<<- MSG_PAYLOAD: [\"%s...\"]", msg[:maxTrace]) + } else { + c.Tracef("<<- MSG_PAYLOAD: [%q]", msg) + } +} + // Update the session (possibly remove it) of this disconnected client. func (s *Server) mqttHandleClosedClient(c *client) { c.mu.Lock() @@ -1034,13 +1049,13 @@ func (as *mqttAccountSessionManager) serializeRetainedMsgsForSub(sess *mqttSessi pi := sess.getPubAckIdentifier(mqttGetQoS(rm.Flags), sub) // Need to use the subject for the retained message, not the `sub` subject. // We can find the published retained message in rm.sub.subject. - flags := mqttSerializePublishMsg(prm, pi, false, true, string(rm.sub.subject), rm.Msg[:len(rm.Msg)-LEN_CR_LF]) + flags := mqttSerializePublishMsg(prm, pi, false, true, string(rm.sub.subject), rm.Msg) if trace { pp := mqttPublish{ flags: flags, pi: pi, subject: rm.sub.subject, - sz: len(rm.Msg) - LEN_CR_LF, + sz: len(rm.Msg), } c.traceOutOp("PUBLISH", []byte(mqttPubTrace(&pp))) } @@ -1397,9 +1412,8 @@ func (c *client) mqttParseConnect(r *mqttReader, pl int) (byte, *mqttConnectProt if err != nil { return 0, nil, err } - cp.will.message = make([]byte, 0, len(msg)+2) + cp.will.message = make([]byte, 0, len(msg)) cp.will.message = append(cp.will.message, msg...) - cp.will.message = append(cp.will.message, CR_LF...) } if hasUser { @@ -1581,12 +1595,12 @@ func (s *Server) mqttHandleWill(c *client) { c.mu.Unlock() return } - pp := &mqttPublish{ - subject: will.topic, - msg: will.message, - sz: len(will.message) - LEN_CR_LF, - flags: will.qos << 1, - } + pp := c.mqtt.pp + pp.subject = will.topic + pp.msg = will.message + pp.sz = len(will.message) + pp.pi = 0 + pp.flags = will.qos << 1 if will.retain { pp.flags |= mqttPubFlagRetain } @@ -1638,18 +1652,20 @@ func (c *client) mqttParsePub(r *mqttReader, pl int, pp *mqttPublish) error { if pp.pi == 0 { return fmt.Errorf("with QoS=%v, packet identifier cannot be 0", qos) } + } else { + pp.pi = 0 } // The message payload will be the total packet length minus // what we have consumed for the variable header pp.sz = pl - (r.pos - start) - pp.msg = make([]byte, 0, pp.sz+2) if pp.sz > 0 { start = r.pos r.pos += pp.sz - pp.msg = append(pp.msg, r.buf[start:r.pos]...) + pp.msg = r.buf[start:r.pos] + } else { + pp.msg = nil } - pp.msg = append(pp.msg, _CRLF_...) return nil } @@ -1666,8 +1682,21 @@ func mqttPubTrace(pp *mqttPublish) string { } func (s *Server) mqttProcessPub(c *client, pp *mqttPublish) { - c.mqtt.pp = pp - c.pa.subject, c.pa.hdr, c.pa.size, c.pa.szb = pp.subject, -1, pp.sz, []byte(strconv.FormatInt(int64(pp.sz), 10)) + c.pa.subject, c.pa.hdr, c.pa.size = pp.subject, -1, pp.sz + + // Convert size into bytes. + i := len(pp.szb) + if pp.sz > 0 { + for l := pp.sz; l > 0; l /= 10 { + i-- + pp.szb[i] = digits[l%10] + } + } else { + i-- + pp.szb[i] = digits[0] + } + c.pa.szb = pp.szb[i:] + // This will work for QoS 0 but mqtt msg delivery callback will ignore // delivery for QoS > 0 published messages (since it is handled specifically // with call to directProcessInboundJetStreamMsg). @@ -1677,10 +1706,9 @@ func (s *Server) mqttProcessPub(c *client, pp *mqttPublish) { if mqttGetQoS(pp.flags) > 0 { // Since this is the fast path, we access the messages stream directly here // without locking. All the fields mqtt.asm.mstream are immutable. - c.mqtt.asm.mstream.processInboundJetStreamMsg(nil, c, string(c.pa.subject), "", pp.msg[:len(pp.msg)-LEN_CR_LF]) + c.mqtt.asm.mstream.processInboundJetStreamMsg(nil, c, string(c.pa.subject), "", pp.msg) } c.pa.subject, c.pa.hdr, c.pa.size, c.pa.szb = nil, -1, 0, nil - c.mqtt.pp = nil } // Invoked when processing an inbound client message. If the "retain" flag is @@ -2046,8 +2074,13 @@ func mqttDeliverMsgCb(sub *subscription, pc *client, subject, reply string, msg return } retained = mqttIsRetained(ppFlags) + } else { + // This is coming from a non MQTT publisher, so Qos 0, no dup nor retain flag, etc.. + // Should probably reject, for now just truncate. + if len(msg) > mqttMaxPayloadSize { + msg = msg[:mqttMaxPayloadSize] + } } - // else this is coming from a non MQTT publisher, so Qos 0, no dup nor retain flag, etc.. sess.mu.Unlock() sw := mqttWriter{} diff --git a/server/mqtt_test.go b/server/mqtt_test.go index c01955b0..bbf04ee4 100644 --- a/server/mqtt_test.go +++ b/server/mqtt_test.go @@ -2716,7 +2716,8 @@ func TestMQTTWill(t *testing.T) { testMQTTDisconnect(t, mc, nil) testMQTTExpectNothing(t, rs) if wm, err := sub.NextMsg(100 * time.Millisecond); err == nil { - t.Fatalf("Should not have receive a message, got %v", wm) + t.Fatalf("Should not have receive a message, got subj=%q data=%q", + wm.Subject, wm.Data) } } })