Files
nats-server/server/mqtt.go
Ivan Kozlovic 7cd412d08d Don't need to store domain in mqttJSA structure
At the time where we need the domain to construct the session hash,
we have access to server options. So use that instead of storing
the domain in the internal mqtt structure.

Signed-off-by: Ivan Kozlovic <ivan@synadia.com>
2021-05-07 11:29:46 -06:00

3807 lines
109 KiB
Go

// Copyright 2020 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"bytes"
"crypto/tls"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
"unicode/utf8"
"github.com/nats-io/nuid"
)
// References to the "spec" in this file are to https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.pdf
const (
	// MQTT control packet types. The type is carried in the high nibble of
	// the first byte of the fixed header (extracted with mqttPacketMask).
	mqttPacketConnect = byte(0x10)
	mqttPacketConnectAck = byte(0x20)
	mqttPacketPub = byte(0x30)
	mqttPacketPubAck = byte(0x40)
	mqttPacketPubRec = byte(0x50)
	mqttPacketPubRel = byte(0x60)
	mqttPacketPubComp = byte(0x70)
	mqttPacketSub = byte(0x80)
	mqttPacketSubAck = byte(0x90)
	mqttPacketUnsub = byte(0xa0)
	mqttPacketUnsubAck = byte(0xb0)
	mqttPacketPing = byte(0xc0)
	mqttPacketPingResp = byte(0xd0)
	mqttPacketDisconnect = byte(0xe0)
	// Masks splitting the first byte of the fixed header into the packet
	// type (high nibble) and the packet-specific flags (low nibble).
	mqttPacketMask = byte(0xf0)
	mqttPacketFlagMask = byte(0x0f)
	// Protocol level of MQTT v3.1.1 (spec section 3.1.2.2).
	mqttProtoLevel = byte(0x4)
	// Connect flags
	mqttConnFlagReserved = byte(0x1)
	mqttConnFlagCleanSession = byte(0x2)
	mqttConnFlagWillFlag = byte(0x04)
	mqttConnFlagWillQoS = byte(0x18)
	mqttConnFlagWillRetain = byte(0x20)
	mqttConnFlagPasswordFlag = byte(0x40)
	mqttConnFlagUsernameFlag = byte(0x80)
	// Publish flags
	mqttPubFlagRetain = byte(0x01)
	mqttPubFlagQoS = byte(0x06)
	mqttPubFlagDup = byte(0x08)
	mqttPubQos1 = byte(0x2) // 1 << 1
	// Subscribe flags
	mqttSubscribeFlags = byte(0x2)
	mqttSubAckFailure = byte(0x80)
	// Unsubscribe flags
	mqttUnsubscribeFlags = byte(0x2)
	// Return codes sent in CONNACK.
	mqttConnAckRCConnectionAccepted = byte(0x0)
	mqttConnAckRCUnacceptableProtocolVersion = byte(0x1)
	mqttConnAckRCIdentifierRejected = byte(0x2)
	mqttConnAckRCServerUnavailable = byte(0x3)
	mqttConnAckRCBadUserOrPassword = byte(0x4)
	mqttConnAckRCNotAuthorized = byte(0x5)
	// Maximum payload size of a control packet: 268,435,455 bytes, the
	// largest value the 4-byte "Remaining Length" encoding can represent.
	mqttMaxPayloadSize = 0xFFFFFFF
	// Topic/Filter characters
	mqttTopicLevelSep = '/'
	mqttSingleLevelWC = '+'
	mqttMultiLevelWC = '#'
	// This is appended to the sid of a subscription that is
	// created on the upper level subject because of the MQTT
	// wildcard '#' semantic.
	mqttMultiLevelSidSuffix = " fwc"
	// This is the prefix for the NATS subjects used as delivery subjects
	// of JS consumers. We want those subjects to be unique, so users' MQTT
	// subscriptions are prevented from starting with this prefix.
	mqttSubPrefix = "$MQTT.sub."
	// MQTT Stream names prefix.
	mqttStreamNamePrefix = "$MQTT_"
	// Stream name for MQTT messages on a given account
	mqttStreamName = mqttStreamNamePrefix + "msgs"
	mqttStreamSubjectPrefix = "$MQTT.msgs."
	// Stream name for MQTT retained messages on a given account
	mqttRetainedMsgsStreamName = mqttStreamNamePrefix + "rmsgs"
	mqttRetainedMsgsStreamSubject = "$MQTT.rmsgs"
	// Stream name prefix for MQTT sessions on a given account
	mqttSessionsStreamNamePrefix = mqttStreamNamePrefix + "sess_"
	// Normally, an MQTT server should not redeliver QoS 1 messages to clients,
	// except after a client reconnects. However, NATS Server will redeliver
	// unacknowledged messages after this default interval. This can be
	// changed with the server.Options.MQTT.AckWait option.
	mqttDefaultAckWait = 30 * time.Second
	// This is the default for the outstanding number of pending QoS 1
	// messages sent to a session with QoS 1 subscriptions.
	mqttDefaultMaxAckPending = 1024
	// A session's list of subscriptions cannot have a cumulative MaxAckPending
	// of more than this limit.
	mqttMaxAckTotalLimit = 0xFFFF
	// Prefix of the reply subject for JS API requests.
	mqttJSARepliesPrefix = "$MQTT.JSA."
	// These are tokens used in the reply subject of JS API requests.
	// For instance "$MQTT.JSA.<node id>.SC.<number>" is the reply subject
	// for a request to create a stream (where <node id> is the server name hash),
	// while "$MQTT.JSA.<node id>.SL.<number>" is for a stream lookup, etc...
	// The *TokenPos constants are the 0-based token positions of <node id>
	// and of the request kind in such a reply subject.
	mqttJSAIdTokenPos = 3
	mqttJSATokenPos = 4
	mqttJSAStreamCreate = "SC"
	mqttJSAStreamLookup = "SL"
	mqttJSAStreamDel = "SD"
	mqttJSAConsumerCreate = "CC"
	mqttJSAConsumerDel = "CD"
	mqttJSAMsgStore = "MS"
	mqttJSAMsgLoad = "ML"
	mqttJSASessPersist = "SP"
	mqttJSARetainedMsgDel = "RD"
	// Name of the header key added to NATS messages to carry MQTT PUBLISH information
	mqttNatsHeader = "Nmqtt-Pub"
	// This is how long to keep a client in the flappers map before closing the
	// connection. This prevents quick reconnects from clients that keep
	// trying to connect with a client ID that is already in use.
	mqttSessFlappingJailDur = time.Second
	// This is how frequently the timer that cleans up the sessions flappers map fires.
	mqttSessFlappingCleanupInterval = 5 * time.Second
)
// Package-level MQTT values. The durations are variables initialized from the
// constants above; presumably they are variables so that tests can shorten
// them (NOTE(review): confirm against the test files).
var (
	// Pre-encoded PINGRESP packet: packet type followed by a zero remaining length.
	mqttPingResponse = []byte{mqttPacketPingResp, 0x0}
	// Protocol names accepted in CONNECT ("MQIsdp" is the older protocol name).
	mqttProtoName = []byte("MQTT")
	mqttOldProtoName = []byte("MQIsdp")
	mqttNatsHeaderB = []byte(mqttNatsHeader)
	mqttSessJailDur = mqttSessFlappingJailDur
	mqttFlapCleanItvl = mqttSessFlappingCleanupInterval
	// Default timeout for JS API requests (see newRequest) and for waiting on
	// retained messages to be loaded when an account session manager is created.
	mqttJSAPITimeout = 4 * time.Second
)
// Errors returned by MQTT option validation and by MQTT protocol parsing.
var (
	errMQTTWebsocketNotSupported = errors.New("invalid connection, websocket currently not supported")
	errMQTTTopicFilterCannotBeEmpty = errors.New("topic filter cannot be empty")
	errMQTTMalformedVarInt = errors.New("malformed variable int")
	errMQTTSecondConnectPacket = errors.New("received a second CONNECT packet")
	errMQTTServerNameMustBeSet = errors.New("mqtt requires server name to be explicitly set")
	errMQTTUserMixWithUsersNKeys = errors.New("mqtt authentication username not compatible with presence of users/nkeys")
	// NOTE(review): "WIth" is a typo in the identifier; it is kept as is
	// because it may be referenced outside of this file.
	errMQTTTokenMixWIthUsersNKeys = errors.New("mqtt authentication token not compatible with presence of users/nkeys")
	errMQTTAckWaitMustBePositive = errors.New("ack wait must be a positive value")
	errMQTTStandaloneNeedsJetStream = errors.New("mqtt requires JetStream to be enabled if running in standalone mode")
	errMQTTConnFlagReserved = errors.New("connect flags reserved bit not set to 0")
	errMQTTWillAndRetainFlag = errors.New("if Will flag is set to 0, Will Retain flag must be 0 too")
	errMQTTPasswordFlagAndNoUser = errors.New("password flag set but username flag is not")
	errMQTTCIDEmptyNeedsCleanFlag = errors.New("when client ID is empty, clean session flag must be set to 1")
	errMQTTEmptyWillTopic = errors.New("empty Will topic not allowed")
	errMQTTEmptyUsername = errors.New("empty user name not allowed")
	errMQTTTopicIsEmpty = errors.New("topic cannot be empty")
	errMQTTPacketIdentifierIsZero = errors.New("packet identifier cannot be 0")
	errMQTTUnsupportedCharacters = errors.New("characters ' ' and '.' not supported for MQTT topics")
)
// srvMQTT is the server-level MQTT state.
type srvMQTT struct {
	// MQTT listener, and the error (if any) returned when creating it.
	listener net.Listener
	listenerErr error
	// Set when MQTT-specific auth options (username, token or no_auth_user)
	// are configured; forces auth for MQTT clients (see createMQTTClient).
	authOverride bool
	// Per-account MQTT session managers.
	sessmgr mqttSessionManager
}
// mqttSessionManager holds one mqttAccountSessionManager per account.
type mqttSessionManager struct {
	// Protects sessions.
	mu sync.RWMutex
	sessions map[string]*mqttAccountSessionManager // key is account name
}
// mqttAccountSessionManager tracks the MQTT sessions and retained messages of
// a single account, and owns the mqttJSA used to issue JetStream API requests
// for that account.
type mqttAccountSessionManager struct {
	mu sync.RWMutex
	sessions map[string]*mqttSession // key is MQTT client ID
	sessByHash map[string]*mqttSession // key is MQTT client ID hash
	sessLocked map[string]struct{} // key is MQTT client ID and indicate that a session can not be taken by a new client at this time
	flappers map[string]int64 // When connection connects with client ID already in use
	flapTimer *time.Timer // Timer to perform some cleanup of the flappers map
	sl *Sublist // sublist allowing to find retained messages for given subscription
	retmsgs map[string]*mqttRetainedMsg // retained messages
	jsa mqttJSA
	// Number of replicas used for the MQTT streams of this account.
	replicas int
	rrmLastSeq uint64 // Restore retained messages expected last sequence
	rrmDoneCh chan struct{} // To notify the caller that all retained messages have been loaded
}
// mqttJSA issues JetStream API requests on behalf of an account's MQTT
// session manager and routes the replies back to the waiting callers.
type mqttJSA struct {
	// Protects nuid.
	mu sync.Mutex
	// Hash of the server name, used in reply subjects.
	id string
	// Internal client bound to the account.
	c *client
	// Outgoing requests, drained by the sending go routine.
	sendq chan *mqttJSPubMsg
	// Reply subject prefix: "$MQTT.JSA.<id>.".
	rplyr string
	// Maps a reply subject to the channel of the waiting requestor.
	replies sync.Map
	nuid *nuid.NUID
	// Server's quit channel, used to abort waits on shutdown.
	quitCh chan struct{}
}
// mqttJSPubMsg is a message queued on mqttJSA.sendq for publishing.
type mqttJSPubMsg struct {
	subj string
	reply string
	// NOTE(review): presumably the size of the header section at the start
	// of msg (0 when there are no headers); confirm with the sender code.
	hdr int
	msg []byte
}
// mqttRetMsgDel identifies a retained message (by subject and stream sequence)
// in a retained-message delete notification.
// NOTE(review): inferred from the name and from the mqttJSARetainedMsgDel
// subscription; confirm against processRetainedMsgDel.
type mqttRetMsgDel struct {
	Subject string `json:"subject"`
	Seq uint64 `json:"seq"`
}
// mqttSession is the state of an MQTT session, identified by its client ID.
// The session can outlive the client connection bound to it (c is nil when
// no client is currently attached).
type mqttSession struct {
	mu sync.Mutex
	id string // client ID
	idHash string // client ID hash
	c *client
	jsa *mqttJSA
	// MQTT subscriptions: topic filter to QoS.
	subs map[string]byte
	cons map[string]*ConsumerConfig
	seq uint64
	pending map[uint16]*mqttPending // Key is the PUBLISH packet identifier sent to client and maps to a mqttPending record
	cpending map[string]map[uint64]uint16 // For each JS consumer, the key is the stream sequence and maps to the PUBLISH packet identifier
	ppi uint16 // publish packet identifier
	maxp uint16 // max ack pending (updated on config reload, see mqttUpdateMaxAckPending)
	tmaxack int
	clean bool // "clean session" flag: the session is removed when its client goes away
}
// mqttPersistedSession is the JSON representation of a session as persisted.
type mqttPersistedSession struct {
	// NOTE(review): presumably the ID of the server that persisted the record.
	Origin string `json:"origin,omitempty"`
	ID string `json:"id,omitempty"`
	Clean bool `json:"clean,omitempty"`
	Subs map[string]byte `json:"subs,omitempty"`
	Cons map[string]*ConsumerConfig `json:"cons,omitempty"`
}
// mqttRetainedMsg is a retained message. The exported fields are its JSON
// representation; the non-exported ones are local bookkeeping.
type mqttRetainedMsg struct {
	Origin string `json:"origin,omitempty"`
	Subject string `json:"subject,omitempty"`
	Topic string `json:"topic,omitempty"`
	Msg []byte `json:"msg,omitempty"`
	Flags byte `json:"flags,omitempty"`
	Source string `json:"source,omitempty"`
	// non exported
	sseq uint64
	floor uint64
	sub *subscription
}
// mqttSub is the MQTT-specific state attached to a NATS subscription.
type mqttSub struct {
	// QoS granted for this subscription.
	qos byte
	// Pending serialization of retained messages to be sent when subscription is registered
	prm *mqttWriter
	// This is the JS durable name this subscription is attached to.
	jsDur string
}
// mqtt is the per-connection MQTT state stored in client.mqtt (non-nil only
// for MQTT clients, see isMqtt).
type mqtt struct {
	r *mqttReader
	// CONNECT information, set once the CONNECT packet has been parsed.
	cp *mqttConnectProto
	// Reused for each parsed PUBLISH packet.
	pp *mqttPublish
	asm *mqttAccountSessionManager // quick reference to account session manager, immutable after processConnect()
	sess *mqttSession // quick reference to session, immutable after processConnect()
}
// mqttPending records a QoS 1 message sent to a client and not yet acknowledged.
type mqttPending struct {
	sseq uint64 // stream sequence
	subject string // the ACK subject to send the ack to
	jsDur string // JS durable name
}
// mqttConnectProto holds the information parsed from a CONNECT packet.
type mqttConnectProto struct {
	clientID string
	// Read deadline duration applied after each processed buffer (see
	// mqttParse); 0 means no read deadline.
	rd time.Duration
	// Will message, or nil. Discarded on a normal DISCONNECT.
	will *mqttWill
	// CONNECT flags byte (see the mqttConnFlag* constants).
	flags byte
}
// mqttIOReader is the underlying connection used by mqttReader: a reader
// whose read deadline can be set (as net.Conn allows).
type mqttIOReader interface {
	io.Reader
	SetReadDeadline(time.Time) error
}
// mqttReader reads MQTT packet elements from the current buffer (set with
// reset) at position pos.
type mqttReader struct {
	reader mqttIOReader
	buf []byte
	pos int
}
// mqttWriter is a bytes.Buffer used to serialize MQTT packets.
type mqttWriter struct {
	bytes.Buffer
}
// mqttWill is the "will" message provided in CONNECT.
type mqttWill struct {
	topic []byte
	// NATS subject corresponding to topic.
	subject []byte
	message []byte
	qos byte
	retain bool
}
// mqttFilter is one topic filter of a SUBSCRIBE/UNSUBSCRIBE packet.
type mqttFilter struct {
	filter string
	qos byte
	// Used only for tracing and should not be used after parsing of (un)sub protocols.
	ttopic []byte
}
// mqttPublish holds a parsed PUBLISH packet.
type mqttPublish struct {
	topic []byte
	// NATS subject corresponding to topic.
	subject []byte
	msg []byte
	sz int
	// Packet identifier; a PUBACK is sent when it is > 0 (see mqttParse).
	pi uint16
	// Low nibble of the fixed header's first byte (see mqttPubFlag*).
	flags byte
}
// startMQTT creates the MQTT listener and starts accepting MQTT client
// connections in a dedicated go routine.
//
// A configured port of -1 is treated as 0 (an OS-assigned port). When the
// listener is bound to an OS-assigned port, that port is written back into the
// MQTT options. Nothing is done if the server is shutting down. A listen
// failure is recorded in s.mqtt.listenerErr and reported as fatal.
func (s *Server) startMQTT() {
	opts := s.getOpts()
	mo := &opts.MQTT

	listenPort := mo.Port
	if listenPort == -1 {
		listenPort = 0
	}
	addr := net.JoinHostPort(mo.Host, strconv.Itoa(listenPort))

	s.mu.Lock()
	if s.shutdown {
		s.mu.Unlock()
		return
	}
	s.mqtt.sessmgr.sessions = make(map[string]*mqttAccountSessionManager)

	ln, err := net.Listen("tcp", addr)
	s.mqtt.listenerErr = err
	if err != nil {
		s.mu.Unlock()
		s.Fatalf("Unable to listen for MQTT connections: %v", err)
		return
	}
	if listenPort == 0 {
		// Report the port that the OS actually picked.
		mo.Port = ln.Addr().(*net.TCPAddr).Port
	}
	s.mqtt.listener = ln

	scheme := "tls"
	if mo.TLSConfig == nil {
		scheme = "mqtt"
	}
	s.Noticef("Listening for MQTT clients on %s://%s:%d", scheme, mo.Host, mo.Port)

	go s.acceptConnections(ln, "MQTT", func(conn net.Conn) { s.createMQTTClient(conn) }, nil)
	s.mu.Unlock()
}
// This is similar to createClient() but has some modifications specific to MQTT clients.
// The comments have been kept to a minimum to reduce code size. Check createClient()
// for more details.
//
// The new client is registered with the global account. Auth is required if the
// server requires it or if MQTT-specific auth options are set (authOverride).
// The client is added to s.clients, the TLS handshake is performed if needed,
// the auth timer is armed when auth is required, and the read and write loops
// are started.
//
// Returns nil if the connection is rejected (max connections reached, connection
// already closed, TLS handshake failure). If the server is not running or is in
// lame duck mode, the client is returned without having its loops started.
func (s *Server) createMQTTClient(conn net.Conn) *client {
	opts := s.getOpts()
	maxPay := int32(opts.MaxPayload)
	maxSubs := int32(opts.MaxSubs)
	// A MaxSubs of 0 is stored as -1 on the client.
	if maxSubs == 0 {
		maxSubs = -1
	}
	now := time.Now().UTC()
	c := &client{srv: s, nc: conn, mpay: maxPay, msubs: maxSubs, start: now, last: now, mqtt: &mqtt{}}
	c.headers = true
	c.mqtt.pp = &mqttPublish{}
	// MQTT clients don't send NATS CONNECT protocols. So make it an "echo"
	// client, but disable verbose and pedantic (by not setting them).
	c.opts.Echo = true
	c.registerWithAccount(s.globalAccount())
	s.mu.Lock()
	// Check auth, override if applicable.
	authRequired := s.info.AuthRequired || s.mqtt.authOverride
	s.totalClients++
	s.mu.Unlock()
	c.mu.Lock()
	if authRequired {
		c.flags.set(expectConnect)
	}
	c.initClient()
	c.Debugf("Client connection created")
	c.mu.Unlock()
	s.mu.Lock()
	// Server not running or in lame duck mode: do not register the client.
	if !s.running || s.ldm {
		if s.shutdown {
			conn.Close()
		}
		s.mu.Unlock()
		return c
	}
	if opts.MaxConn > 0 && len(s.clients) >= opts.MaxConn {
		s.mu.Unlock()
		c.maxConnExceeded()
		return nil
	}
	s.clients[c.cid] = c
	tlsRequired := opts.MQTT.TLSConfig != nil
	s.mu.Unlock()
	c.mu.Lock()
	// In case connection has already been closed
	if c.isClosed() {
		c.mu.Unlock()
		c.closeConnection(WriteError)
		return nil
	}
	var pre []byte
	// When non-TLS connections are also allowed, peek at up to 4 bytes (within
	// the TLS timeout) to decide whether the client starts a TLS handshake:
	// 0x16 is the TLS "handshake" record content type.
	if tlsRequired && opts.AllowNonTLS {
		pre = make([]byte, 4)
		c.nc.SetReadDeadline(time.Now().Add(secondsToDuration(opts.MQTT.TLSTimeout)))
		n, _ := io.ReadFull(c.nc, pre[:])
		c.nc.SetReadDeadline(time.Time{})
		pre = pre[:n]
		if n > 0 && pre[0] == 0x16 {
			tlsRequired = true
		} else {
			tlsRequired = false
		}
	}
	if tlsRequired {
		// Replay the peeked bytes to the TLS layer.
		if len(pre) > 0 {
			c.nc = &tlsMixConn{c.nc, bytes.NewBuffer(pre)}
			pre = nil
		}
		// Perform server-side TLS handshake.
		if err := c.doTLSServerHandshake("mqtt", opts.MQTT.TLSConfig, opts.MQTT.TLSTimeout); err != nil {
			c.mu.Unlock()
			return nil
		}
	}
	if authRequired {
		timeout := opts.AuthTimeout
		// Possibly override with MQTT specific value.
		if opts.MQTT.AuthTimeout != 0 {
			timeout = opts.MQTT.AuthTimeout
		}
		c.setAuthTimer(secondsToDuration(timeout))
	}
	// No Ping timer for MQTT clients...
	// Any peeked non-TLS bytes are handed to the read loop for processing.
	s.startGoRoutine(func() { c.readLoop(pre) })
	s.startGoRoutine(func() { c.writeLoop() })
	if tlsRequired {
		c.Debugf("TLS handshake complete")
		cs := c.nc.(*tls.Conn).ConnectionState()
		c.Debugf("TLS version %s, cipher suite %s", tlsVersion(cs.Version), tlsCipher(cs.CipherSuite))
	}
	c.mu.Unlock()
	return c
}
// mqttConfigAuth records whether MQTT-specific auth settings are present in
// the given MQTT options. Setting any of username, token or no_auth_user counts
// as an override, which makes authentication required for MQTT clients even if
// the server itself does not require it (see createMQTTClient).
//
// Server lock is held on entry.
func (s *Server) mqttConfigAuth(opts *MQTTOpts) {
	hasOverride := opts.Username != _EMPTY_ ||
		opts.Token != _EMPTY_ ||
		opts.NoAuthUser != _EMPTY_
	s.mqtt.authOverride = hasOverride
}
// validateMQTTOptions checks the MQTT section of the options. Nothing is
// checked when no MQTT port is configured. Checks run in a fixed order and the
// first failing one determines the returned error.
func validateMQTTOptions(o *Options) error {
	mo := &o.MQTT
	if mo.Port == 0 {
		return nil
	}
	// There are conditions where a unique and repeatable name is needed,
	// so the server name has to be explicitly set.
	if o.ServerName == _EMPTY_ {
		return errMQTTServerNameMustBeSet
	}
	// A NoAuthUser must refer to one of the configured users.
	if noAuth := mo.NoAuthUser; noAuth != _EMPTY_ {
		if err := validateNoAuthUser(o, noAuth); err != nil {
			return err
		}
	}
	// MQTT username/token cannot be combined with users/nkeys.
	if hasUsers := len(o.Users) > 0 || len(o.Nkeys) > 0; hasUsers {
		switch {
		case mo.Username != _EMPTY_:
			return errMQTTUserMixWithUsersNKeys
		case mo.Token != _EMPTY_:
			return errMQTTTokenMixWIthUsersNKeys
		}
	}
	if mo.AckWait < 0 {
		return errMQTTAckWaitMustBePositive
	}
	// A strictly standalone server (no cluster, no gateway and no leafnode
	// configuration at all) needs JetStream enabled locally. With leafnodes,
	// JetStream can be provided by a remote, or a remote "hub" may connect
	// to us, so only the "no leafnode config at all" case is rejected.
	noJetStream := !o.JetStream
	standalone := o.Cluster.Port == 0 && o.Gateway.Port == 0
	noLeafNodes := o.LeafNode.Port == 0 && len(o.LeafNode.Remotes) == 0
	if noJetStream && standalone && noLeafNodes {
		return errMQTTStandaloneNeedsJetStream
	}
	return nil
}
// isMqtt reports whether this connection belongs to an MQTT client, which is
// the case exactly when MQTT-specific state is attached to it.
// Lock held on entry.
func (c *client) isMqtt() bool {
	hasMQTTState := c.mqtt != nil
	return hasMQTTState
}
// Parse the MQTT packets contained in the given buffer and process each of them.
// This is invoked from the readLoop.
//
// If a read deadline duration (rd, from CONNECT) is positive, the read deadline
// is cleared while the buffer is processed, and re-armed to now+rd once all
// packets have been handled without error.
//
// Until a CONNECT packet has been successfully processed, any other packet type
// is rejected. A DISCONNECT discards the will, handles the closed client,
// closes the connection and returns nil right away.
//
// Returns the first parse/processing error, which stops the processing of the
// rest of the buffer.
func (c *client) mqttParse(buf []byte) error {
	c.mu.Lock()
	s := c.srv
	trace := c.trace
	connected := c.flags.isSet(connectReceived)
	mqtt := c.mqtt
	r := mqtt.r
	var rd time.Duration
	if mqtt.cp != nil {
		rd = mqtt.cp.rd
		if rd > 0 {
			// Clear the read deadline while this buffer is being processed.
			r.reader.SetReadDeadline(time.Time{})
		}
	}
	c.mu.Unlock()
	r.reset(buf)
	var err error
	var b byte
	var pl int
	for err == nil && r.hasMore() {
		// Read packet type and flags
		if b, err = r.readByte("packet type"); err != nil {
			break
		}
		// Packet type
		pt := b & mqttPacketMask
		// If client was not connected yet, the first packet must be
		// a mqttPacketConnect otherwise we fail the connection.
		if !connected && pt != mqttPacketConnect {
			// Try to guess if the client is trying to connect using Websocket,
			// which is currently not supported
			if bytes.HasPrefix(buf, []byte("GET ")) {
				err = errMQTTWebsocketNotSupported
			} else {
				err = fmt.Errorf("the first packet should be a CONNECT (%v), got %v", mqttPacketConnect, pt)
			}
			break
		}
		// Remaining length of this packet.
		if pl, err = r.readPacketLen(); err != nil {
			break
		}
		switch pt {
		case mqttPacketPub:
			pp := c.mqtt.pp
			pp.flags = b & mqttPacketFlagMask
			err = c.mqttParsePub(r, pl, pp)
			if trace {
				c.traceInOp("PUBLISH", errOrTrace(err, mqttPubTrace(pp)))
				if err == nil {
					c.mqttTraceMsg(pp.msg)
				}
			}
			if err == nil {
				err = s.mqttProcessPub(c, pp)
			}
			// A non-zero packet identifier requires a PUBACK.
			if err == nil && pp.pi > 0 {
				c.mqttEnqueuePubAck(pp.pi)
				if trace {
					c.traceOutOp("PUBACK", []byte(fmt.Sprintf("pi=%v", pp.pi)))
				}
			}
		case mqttPacketPubAck:
			var pi uint16
			pi, err = mqttParsePubAck(r, pl)
			if trace {
				c.traceInOp("PUBACK", errOrTrace(err, fmt.Sprintf("pi=%v", pi)))
			}
			if err == nil {
				c.mqttProcessPubAck(pi)
			}
		case mqttPacketSub:
			var pi uint16 // packet identifier
			var filters []*mqttFilter
			var subs []*subscription
			pi, filters, err = c.mqttParseSubs(r, b, pl)
			if trace {
				c.traceInOp("SUBSCRIBE", errOrTrace(err, mqttSubscribeTrace(pi, filters)))
			}
			if err == nil {
				subs, err = c.mqttProcessSubs(filters)
				if err == nil && trace {
					c.traceOutOp("SUBACK", []byte(fmt.Sprintf("pi=%v", pi)))
				}
			}
			// The SUBACK is enqueued before the retained messages of the new subscriptions.
			if err == nil {
				c.mqttEnqueueSubAck(pi, filters)
				c.mqttSendRetainedMsgsToNewSubs(subs)
			}
		case mqttPacketUnsub:
			var pi uint16 // packet identifier
			var filters []*mqttFilter
			pi, filters, err = c.mqttParseUnsubs(r, b, pl)
			if trace {
				c.traceInOp("UNSUBSCRIBE", errOrTrace(err, mqttUnsubscribeTrace(pi, filters)))
			}
			if err == nil {
				err = c.mqttProcessUnsubs(filters)
				if err == nil && trace {
					c.traceOutOp("UNSUBACK", []byte(fmt.Sprintf("pi=%v", pi)))
				}
			}
			if err == nil {
				c.mqttEnqueueUnsubAck(pi)
			}
		case mqttPacketPing:
			if trace {
				c.traceInOp("PINGREQ", nil)
			}
			c.mqttEnqueuePingResp()
			if trace {
				c.traceOutOp("PINGRESP", nil)
			}
		case mqttPacketConnect:
			// It is an error to receive a second connect packet
			if connected {
				err = errMQTTSecondConnectPacket
				break
			}
			var rc byte
			var cp *mqttConnectProto
			var sessp bool
			rc, cp, err = c.mqttParseConnect(r, pl)
			if trace && cp != nil {
				c.traceInOp("CONNECT", errOrTrace(err, c.mqttConnectTrace(cp)))
			}
			// A non-zero return code is reported to the client in a CONNACK.
			if rc != 0 {
				c.mqttEnqueueConnAck(rc, sessp)
				if trace {
					c.traceOutOp("CONNACK", []byte(fmt.Sprintf("sp=%v rc=%v", sessp, rc)))
				}
			} else if err == nil {
				if err = s.mqttProcessConnect(c, cp, trace); err != nil {
					err = fmt.Errorf("unable to connect: %v", err)
				} else {
					// From now on, use the read deadline requested by this CONNECT.
					connected = true
					rd = cp.rd
				}
			}
		case mqttPacketDisconnect:
			if trace {
				c.traceInOp("DISCONNECT", nil)
			}
			// Normal disconnect, we need to discard the will.
			// Spec [MQTT-3.1.2-8]
			c.mu.Lock()
			if c.mqtt.cp != nil {
				c.mqtt.cp.will = nil
			}
			c.mu.Unlock()
			s.mqttHandleClosedClient(c)
			c.closeConnection(ClientClosed)
			return nil
		case mqttPacketPubRec, mqttPacketPubRel, mqttPacketPubComp:
			// These packets are part of the QoS 2 flow, which is not supported.
			err = fmt.Errorf("protocol %d not supported", pt>>4)
		default:
			err = fmt.Errorf("received unknown packet type %d", pt>>4)
		}
	}
	// Re-arm the read deadline only if everything went fine.
	if err == nil && rd > 0 {
		r.reader.SetReadDeadline(time.Now().Add(rd))
	}
	return err
}
// mqttTraceMsg traces the payload of an inbound PUBLISH packet.
//
// The payload is always rendered as an escaped, quoted string, so binary data
// and control characters (for instance newlines) never reach the log verbatim.
// When MaxTracedMsgLen is positive and the payload is longer than that, only
// the first MaxTracedMsgLen bytes are rendered, followed by "...".
func (c *client) mqttTraceMsg(msg []byte) {
	maxTrace := c.srv.getOpts().MaxTracedMsgLen
	if maxTrace > 0 && len(msg) > maxTrace {
		// Escape the truncated prefix exactly like %q does, then drop the
		// surrounding quotes so that the "..." marker stays inside them.
		// Previously the prefix was printed raw with %s, which let arbitrary
		// payload bytes into the trace output.
		q := strconv.Quote(string(msg[:maxTrace]))
		c.Tracef("<<- MSG_PAYLOAD: [\"%s...\"]", q[1:len(q)-1])
	} else {
		c.Tracef("<<- MSG_PAYLOAD: [%q]", msg)
	}
}
// The MQTT client connection has been closed, or the DISCONNECT packet was received.
// For a "clean" session, we delete the session; otherwise we simply remove the
// binding between the session and this client. We also send the "will" message
// if applicable.
//
// Lock order used here: the session is first marked as locked in the account
// session manager (lockSession), then asm.mu is taken, then sess.mu.
//
// Runs from the client's readLoop.
// No lock held on entry.
func (s *Server) mqttHandleClosedClient(c *client) {
	c.mu.Lock()
	asm := c.mqtt.asm
	sess := c.mqtt.sess
	c.mu.Unlock()
	// If asm or sess are nil, it means that we have failed a client
	// before it was associated with a session, so nothing more to do.
	if asm == nil || sess == nil {
		return
	}
	// Add this session to the locked map for the rest of the execution.
	// If that fails, we return without clearing the session or sending the will.
	if err := asm.lockSession(sess, c); err != nil {
		return
	}
	defer asm.unlockSession(sess)
	asm.mu.Lock()
	// Clear the client from the session, but session may stay.
	sess.mu.Lock()
	sess.c = nil
	doClean := sess.clean
	sess.mu.Unlock()
	// If it was a clean session, then we remove from the account manager,
	// and we will call clear() outside of any lock.
	if doClean {
		asm.removeSession(sess, false)
	}
	// Remove in case it was in the flappers map.
	asm.removeSessFromFlappers(sess.id)
	asm.mu.Unlock()
	// This needs to be done outside of any lock.
	if doClean {
		sess.clear(true)
	}
	// Now handle the "will". This function will be a no-op if there is no "will" to send.
	s.mqttHandleWill(c)
}
// mqttUpdateMaxAckPending sets the max ack pending value (sess.maxp) of every
// MQTT session of every account that has an MQTT session manager.
//
// Only the in-memory session value is changed here; this function does not
// itself issue any JetStream request.
//
// Runs from a server configuration reload routine.
// No lock held on entry.
func (s *Server) mqttUpdateMaxAckPending(newmaxp uint16) {
	sessmgr := &s.mqtt.sessmgr
	s.accounts.Range(func(key, _ interface{}) bool {
		name := key.(string)

		sessmgr.mu.RLock()
		asm := sessmgr.sessions[name]
		sessmgr.mu.RUnlock()

		if asm != nil {
			asm.mu.RLock()
			for _, sess := range asm.sessions {
				sess.mu.Lock()
				sess.maxp = newmaxp
				sess.mu.Unlock()
			}
			asm.mu.RUnlock()
		}
		// Always continue with the next account.
		return true
	})
}
// getOrCreateMQTTAccountSessionManager returns the MQTT session manager of the
// account that client c is bound to, creating it (and the JetStream streams and
// consumers it needs) on first use.
//
// A lookup under the read lock is tried first; on a miss, the write lock is
// taken and the lookup is repeated before creating, so that only one manager is
// ever created per account. The clientID parameter is not used.
func (s *Server) getOrCreateMQTTAccountSessionManager(clientID string, c *client) (*mqttAccountSessionManager, error) {
	sm := &s.mqtt.sessmgr

	c.mu.Lock()
	acc := c.acc
	c.mu.Unlock()
	accName := acc.GetName()

	sm.mu.RLock()
	existing, found := sm.sessions[accName]
	sm.mu.RUnlock()
	if found {
		return existing, nil
	}

	// The server's quit channel is handed to a newly created manager.
	s.mu.Lock()
	quitCh := s.quitCh
	s.mu.Unlock()

	sm.mu.Lock()
	defer sm.mu.Unlock()
	// Someone may have created it while we were not holding the lock.
	if existing, found = sm.sessions[accName]; found {
		return existing, nil
	}
	created, err := s.mqttCreateAccountSessionManager(acc, quitCh)
	if err != nil {
		return nil, err
	}
	sm.sessions[accName] = created
	return created, nil
}
// Creates JS streams/consumers for handling of sessions and messages for this account.
//
// In order, this:
//   - builds the account session manager and its mqttJSA (an internal client
//     bound to acc, with reply prefix "$MQTT.JSA.<server name hash>.");
//   - creates the internal subscriptions (JS API replies, session persist
//     replies, retained messages delivery, retained messages delete);
//   - starts the go routine that sends JS API requests;
//   - creates (or reuses) the messages and retained messages streams;
//   - deletes and recreates this server's durable consumer on the retained
//     messages stream, then waits (up to mqttJSAPITimeout) for the existing
//     retained messages to be loaded.
//
// On failure, the deferred function removes the subscriptions created here and
// closes closeCh, which is the channel passed to the sending go routine.
//
// Global session manager lock is held on entry.
func (s *Server) mqttCreateAccountSessionManager(acc *Account, quitCh chan struct{}) (*mqttAccountSessionManager, error) {
	var err error
	accName := acc.GetName()
	c := s.createInternalAccountClient()
	c.acc = acc
	id := string(getHash(s.Name()))
	replicas := s.mqttDetermineReplicas()
	s.Noticef("Creating MQTT streams/consumers with replicas %v for account %q", replicas, accName)
	as := &mqttAccountSessionManager{
		sessions: make(map[string]*mqttSession),
		sessByHash: make(map[string]*mqttSession),
		sessLocked: make(map[string]struct{}),
		flappers: make(map[string]int64),
		replicas: replicas,
		jsa: mqttJSA{
			id: id,
			c: c,
			rplyr: mqttJSARepliesPrefix + id + ".",
			sendq: make(chan *mqttJSPubMsg, 8192),
			nuid: nuid.New(),
			quitCh: quitCh,
		},
	}
	var subs []*subscription
	var success bool
	closeCh := make(chan struct{})
	// Undo on failure: remove the subscriptions and stop the sending go routine.
	defer func() {
		if success {
			return
		}
		for _, sub := range subs {
			c.processUnsub(sub.sid)
		}
		close(closeCh)
	}()
	// We create all subscriptions before starting the go routine that will do
	// sends otherwise we could get races.
	// Note that using two different clients (one for the subs, one for the
	// sends) would cause other issues such as registration of recent subs in
	// the "sub" client would be invisible to the check for GW routed replies
	// (shouldMapReplyForGatewaySend) since the client there would be the "sender".
	jsa := &as.jsa
	sid := int64(1)
	// This is a subscription that will process all JS API replies. We could split to
	// individual subscriptions if needed, but since there is a bit of common code,
	// that seemed like a good idea to be all in one place.
	if err := as.createSubscription(jsa.rplyr+"*.*",
		as.processJSAPIReplies, &sid, &subs); err != nil {
		return nil, err
	}
	// We will listen for replies to session persist requests so that we can
	// detect the use of a session with the same client ID anywhere in the cluster.
	if err := as.createSubscription(mqttJSARepliesPrefix+"*."+mqttJSASessPersist+".*",
		as.processSessionPersist, &sid, &subs); err != nil {
		return nil, err
	}
	// We create the subscription on "$MQTT.sub.<nuid>" to limit the subjects
	// that a user would allow permissions on.
	rmsubj := mqttSubPrefix + nuid.Next()
	if err := as.createSubscription(rmsubj, as.processRetainedMsg, &sid, &subs); err != nil {
		return nil, err
	}
	// Create a subscription to be notified of retained messages delete requests.
	rmdelsubj := mqttJSARepliesPrefix + "*." + mqttJSARetainedMsgDel
	if err := as.createSubscription(rmdelsubj, as.processRetainedMsgDel, &sid, &subs); err != nil {
		return nil, err
	}
	// No more creation of subscriptions past this point otherwise RACEs may happen.
	// Start the go routine that will send JS API requests.
	s.startGoRoutine(func() {
		defer s.grWG.Done()
		as.sendJSAPIrequests(s, c, accName, closeCh)
	})
	// Create the stream for the messages.
	cfg := &StreamConfig{
		Name: mqttStreamName,
		Subjects: []string{mqttStreamSubjectPrefix + ">"},
		Storage: FileStorage,
		Retention: InterestPolicy,
		Replicas: as.replicas,
	}
	// A stream that already exists is not an error.
	if _, err := jsa.createStream(cfg); isErrorOtherThan(err, ErrJetStreamStreamAlreadyUsed) {
		return nil, fmt.Errorf("create messages stream for account %q: %v", acc.GetName(), err)
	}
	// Create the stream for retained messages.
	cfg = &StreamConfig{
		Name: mqttRetainedMsgsStreamName,
		Subjects: []string{mqttRetainedMsgsStreamSubject},
		Storage: FileStorage,
		Retention: LimitsPolicy,
		Replicas: as.replicas,
	}
	si, err := jsa.createStream(cfg)
	if isErrorOtherThan(err, ErrJetStreamStreamAlreadyUsed) {
		return nil, fmt.Errorf("create retained messages stream for account %q: %v", acc.GetName(), err)
	}
	// The stream already exists: look it up to get its current state.
	if err != nil {
		si, err = jsa.lookupStream(mqttRetainedMsgsStreamName)
		if err != nil {
			return nil, fmt.Errorf("lookup retained messages stream for account %q: %v", acc.GetName(), err)
		}
	}
	// NOTE(review): si is assumed non-nil whenever no error was returned above.
	var lastSeq uint64
	var rmDoneCh chan struct{}
	st := si.State
	// If there are retained messages, record the last sequence so we can
	// be notified (through rmDoneCh) once they have all been loaded.
	if st.Msgs > 0 {
		lastSeq = st.LastSeq
		if lastSeq > 0 {
			rmDoneCh = make(chan struct{})
			as.rrmLastSeq = lastSeq
			as.rrmDoneCh = rmDoneCh
		}
	}
	// Using ephemeral consumer is too risky because if this server were to be
	// disconnected from the rest for few seconds, then the leader would remove
	// the consumer, so even after a reconnect, we would no longer receive
	// retained messages. Delete any existing durable that we have for that
	// and recreate here.
	// The name for the durable is $MQTT_rmsgs_<server name hash> (which is jsa.id)
	rmDurName := mqttRetainedMsgsStreamName + "_" + jsa.id
	resp, err := jsa.deleteConsumer(mqttRetainedMsgsStreamName, rmDurName)
	// If error other than "not found" then fail, otherwise proceed with creating
	// the durable consumer.
	if err != nil && (resp == nil || resp.Error.Code != 404) {
		return nil, err
	}
	ccfg := &CreateConsumerRequest{
		Stream: mqttRetainedMsgsStreamName,
		Config: ConsumerConfig{
			Durable: rmDurName,
			FilterSubject: mqttRetainedMsgsStreamSubject,
			DeliverSubject: rmsubj,
			ReplayPolicy: ReplayInstant,
			AckPolicy: AckNone,
		},
	}
	if _, err := jsa.createConsumer(ccfg); err != nil {
		return nil, fmt.Errorf("create retained messages consumer for account %q: %v", acc.GetName(), err)
	}
	// Wait for existing retained messages to be loaded; a timeout is only logged.
	if lastSeq > 0 {
		select {
		case <-rmDoneCh:
		case <-time.After(mqttJSAPITimeout):
			s.Warnf("Timing out waiting to load %v retained messages", st.Msgs)
		case <-quitCh:
			return nil, ErrServerNotRunning
		}
	}
	// Set this so that on defer we don't cleanup.
	success = true
	return as, nil
}
// mqttDetermineReplicas returns the number of replicas to use for the MQTT
// streams.
//
// When JetStream is not clustered, this is 1. Otherwise it is the number of
// addresses behind the configured routes: an IP counts once, and a host name
// counts once per address it resolves to (resolution errors count as zero).
// The result is clamped to the range [1, 3].
func (s *Server) mqttDetermineReplicas() int {
	if !s.JetStreamIsClustered() {
		return 1
	}
	count := 0
	for _, route := range s.getOpts().Routes {
		h := route.Hostname()
		if ip := net.ParseIP(h); ip != nil {
			count++
			continue
		}
		resolved, _ := net.LookupHost(h)
		count += len(resolved)
	}
	switch {
	case count < 1:
		return 1
	case count > 3:
		return 3
	}
	return count
}
//////////////////////////////////////////////////////////////////////////////
//
// JS APIs related functions
//
//////////////////////////////////////////////////////////////////////////////
// newRequest sends a JS API request and waits for its reply, using the
// default JS API request timeout (mqttJSAPITimeout). See newRequestEx.
func (jsa *mqttJSA) newRequest(kind, subject string, hdr int, msg []byte) (interface{}, error) {
	defaultTimeout := mqttJSAPITimeout
	return jsa.newRequestEx(kind, subject, hdr, msg, defaultTimeout)
}
// newRequestEx sends a JS API request on `subject` and waits up to `timeout`
// for its reply.
//
// The reply subject is <jsa.rplyr><kind>.<unique id>. A buffered channel is
// registered under that reply subject in jsa.replies so that the JS API replies
// handler can deliver the response, and the request is queued on jsa.sendq for
// the sending go routine.
//
// Returns the value delivered on the reply channel, ErrServerNotRunning if
// jsa.quitCh is closed first, or a timeout error. On the error paths the reply
// channel is unregistered from jsa.replies.
func (jsa *mqttJSA) newRequestEx(kind, subject string, hdr int, msg []byte, timeout time.Duration) (interface{}, error) {
	jsa.mu.Lock()
	// Either we use nuid.Next() which uses a global lock, or our own nuid object, but
	// then it needs to be "write" protected. This approach will reduce across account
	// contention since we won't use the global nuid's lock.
	reply := jsa.rplyr + kind + "." + jsa.nuid.Next()
	jsa.mu.Unlock()

	ch := make(chan interface{}, 1)
	jsa.replies.Store(reply, ch)

	// Do not block forever on a full send queue when the server is shutting
	// down. NOTE(review): presumably the sending go routine stops draining
	// sendq at that point; confirm against sendJSAPIrequests.
	select {
	case jsa.sendq <- &mqttJSPubMsg{subj: subject, reply: reply, hdr: hdr, msg: msg}:
	case <-jsa.quitCh:
		jsa.replies.Delete(reply)
		return nil, ErrServerNotRunning
	}

	// We don't want to use time.After() which causes memory growth because the timer
	// can't be stopped and will need to expire to then be garbage collected.
	t := time.NewTimer(timeout)
	// Stop the timer on every return path so it can be quickly garbage collected.
	defer t.Stop()

	select {
	case i := <-ch:
		return i, nil
	case <-jsa.quitCh:
		jsa.replies.Delete(reply)
		return nil, ErrServerNotRunning
	case <-t.C:
		jsa.replies.Delete(reply)
		return nil, fmt.Errorf("timeout for request type %q on %q (reply=%q)", kind, subject, reply)
	}
}
// convertApiErrorToError converts a JS API error into a Go error.
// A nil `e` yields nil. Otherwise the error text is e.Description when it is
// not empty, or "code <e.Code>" when it is.
func convertApiErrorToError(e *ApiError) error {
	switch {
	case e == nil:
		return nil
	case e.Description != _EMPTY_:
		return errors.New(e.Description)
	default:
		return fmt.Errorf("code %d", e.Code)
	}
}
// createConsumer sends a consumer create request for cfg. A durable consumer
// uses the durable create API subject; any other consumer uses the plain
// consumer create API subject. Returns the response along with its API error
// (converted by convertApiErrorToError), if any.
func (jsa *mqttJSA) createConsumer(cfg *CreateConsumerRequest) (*JSApiConsumerCreateResponse, error) {
	body, err := json.Marshal(cfg)
	if err != nil {
		return nil, err
	}
	var subj string
	switch dur := cfg.Config.Durable; dur {
	case _EMPTY_:
		subj = fmt.Sprintf(JSApiConsumerCreateT, cfg.Stream)
	default:
		subj = fmt.Sprintf(JSApiDurableCreateT, cfg.Stream, dur)
	}
	resp, err := jsa.newRequest(mqttJSAConsumerCreate, subj, 0, body)
	if err != nil {
		return nil, err
	}
	created := resp.(*JSApiConsumerCreateResponse)
	return created, convertApiErrorToError(created.Error)
}
// deleteConsumer deletes consumer consName from stream streamName. The response
// is returned even when it carries an API error, so callers can inspect the
// error code (for instance 404).
func (jsa *mqttJSA) deleteConsumer(streamName, consName string) (*JSApiConsumerDeleteResponse, error) {
	resp, err := jsa.newRequest(mqttJSAConsumerDel, fmt.Sprintf(JSApiConsumerDeleteT, streamName, consName), 0, nil)
	if err != nil {
		return nil, err
	}
	deleted := resp.(*JSApiConsumerDeleteResponse)
	apiErr := convertApiErrorToError(deleted.Error)
	return deleted, apiErr
}
// createStream sends a JS API request to create a stream with configuration
// `cfg` and waits for the reply. It returns the StreamInfo carried by the
// reply and the reply's API error, if any, converted to a Go error.
func (jsa *mqttJSA) createStream(cfg *StreamConfig) (*StreamInfo, error) {
	payload, err := json.Marshal(cfg)
	if err != nil {
		return nil, err
	}
	subj := fmt.Sprintf(JSApiStreamCreateT, cfg.Name)
	reply, err := jsa.newRequest(mqttJSAStreamCreate, subj, 0, payload)
	if err != nil {
		return nil, err
	}
	resp := reply.(*JSApiStreamCreateResponse)
	return resp.StreamInfo, convertApiErrorToError(resp.Error)
}
// lookupStream sends a JS API stream info request for stream `name` and waits
// for the reply. It returns the StreamInfo carried by the reply and the
// reply's API error, if any, converted to a Go error.
func (jsa *mqttJSA) lookupStream(name string) (*StreamInfo, error) {
	subj := fmt.Sprintf(JSApiStreamInfoT, name)
	reply, err := jsa.newRequest(mqttJSAStreamLookup, subj, 0, nil)
	if err != nil {
		return nil, err
	}
	info := reply.(*JSApiStreamInfoResponse)
	return info.StreamInfo, convertApiErrorToError(info.Error)
}
// deleteStream sends a JS API request to delete stream `name` and waits for
// the reply. It returns the reply's Success flag and its API error, if any,
// converted to a Go error. If the request itself fails, false is returned.
func (jsa *mqttJSA) deleteStream(name string) (bool, error) {
	subj := fmt.Sprintf(JSApiStreamDeleteT, name)
	reply, err := jsa.newRequest(mqttJSAStreamDel, subj, 0, nil)
	if err != nil {
		return false, err
	}
	res := reply.(*JSApiStreamDeleteResponse)
	return res.Success, convertApiErrorToError(res.Error)
}
// loadMsg fetches the message stored at sequence `seq` in stream `streamName`
// using the JS API message get request, and waits for the reply. It returns
// the stored message and the reply's API error, if any, converted to a Go error.
func (jsa *mqttJSA) loadMsg(streamName string, seq uint64) (*StoredMsg, error) {
	payload, err := json.Marshal(&JSApiMsgGetRequest{Seq: seq})
	if err != nil {
		return nil, err
	}
	subj := fmt.Sprintf(JSApiMsgGetT, streamName)
	reply, err := jsa.newRequest(mqttJSAMsgLoad, subj, 0, payload)
	if err != nil {
		return nil, err
	}
	res := reply.(*JSApiMsgGetResponse)
	return res.Message, convertApiErrorToError(res.Error)
}
// storeMsg publishes `msg` on `subject` and waits for the JetStream publish
// acknowledgment. The request is tracked with the mqttJSAMsgStore kind.
// See storeMsgWithKind for the meaning of `headers`.
func (jsa *mqttJSA) storeMsg(subject string, headers int, msg []byte) (*JSPubAckResponse, error) {
	ack, err := jsa.storeMsgWithKind(mqttJSAMsgStore, subject, headers, msg)
	return ack, err
}
// storeMsgWithKind publishes `msg` on `subject` and waits for the JetStream
// publish acknowledgment. `kind` is embedded in the reply subject of the
// request. A `headers` value > 0 means that `msg` already starts with a
// header block of that length. A value of 0 lets the send loop prepend the
// ClientInfo header.
// The acknowledgment is returned together with its API error, if any,
// converted to a Go error.
func (jsa *mqttJSA) storeMsgWithKind(kind, subject string, headers int, msg []byte) (*JSPubAckResponse, error) {
	reply, err := jsa.newRequest(kind, subject, headers, msg)
	if err != nil {
		return nil, err
	}
	ack := reply.(*JSPubAckResponse)
	return ack, convertApiErrorToError(ack.Error)
}
// deleteMsg queues a JS API request to delete the message at sequence `seq`
// from `stream`, with NoErase set. No reply subject is set, so no response is
// awaited (fire-and-forget).
func (jsa *mqttJSA) deleteMsg(stream string, seq uint64) {
	// The request only carries scalar fields, so the marshal error is ignored.
	payload, _ := json.Marshal(JSApiMsgDeleteRequest{Seq: seq, NoErase: true})
	jsa.sendq <- &mqttJSPubMsg{
		msg:  payload,
		subj: fmt.Sprintf(JSApiMsgDeleteT, stream),
	}
}
//////////////////////////////////////////////////////////////////////////////
//
// Account Sessions Manager related functions
//
//////////////////////////////////////////////////////////////////////////////
// isErrorOtherThan reports whether `err1` is a non-nil error whose message
// differs from the message of `err2`. Errors are compared by their text only.
// `err2` must not be nil.
func isErrorOtherThan(err1, err2 error) bool {
	if err1 == nil {
		return false
	}
	return err1.Error() != err2.Error()
}
// Process JS API replies.
//
// The reply subject of each request embeds the request kind (see newRequest).
// The token at position mqttJSATokenPos selects the response type used to
// decode the reply. The full subject is used to find, and then forget, the
// channel the requester waits on. Replies nobody waits for are dropped; this
// happens, for instance, when the request already timed out.
//
// A reply that cannot be decoded is still delivered, with its Error field set
// to jsInvalidJSONErr. The requester then gets an error instead of waiting for
// its timeout.
//
// Every response is decoded into a pre-allocated pointer, never into the
// address of that pointer. Decoding a JSON `null` into a **T would reset the
// pointer to nil, and the requester would then dereference a nil response.
//
// Can run from various go routines (consumer's loop, system send loop, etc..).
func (as *mqttAccountSessionManager) processJSAPIReplies(_ *subscription, pc *client, subject, _ string, msg []byte) {
	token := tokenAt(subject, mqttJSATokenPos)
	if token == _EMPTY_ {
		return
	}
	jsa := &as.jsa
	chi, ok := jsa.replies.Load(subject)
	if !ok {
		return
	}
	jsa.replies.Delete(subject)
	ch := chi.(chan interface{})
	switch token {
	case mqttJSAStreamCreate:
		var resp = &JSApiStreamCreateResponse{}
		if err := json.Unmarshal(msg, resp); err != nil {
			resp.Error = jsInvalidJSONErr
		}
		ch <- resp
	case mqttJSAStreamLookup:
		var resp = &JSApiStreamInfoResponse{}
		if err := json.Unmarshal(msg, resp); err != nil {
			resp.Error = jsInvalidJSONErr
		}
		ch <- resp
	case mqttJSAStreamDel:
		var resp = &JSApiStreamDeleteResponse{}
		if err := json.Unmarshal(msg, resp); err != nil {
			resp.Error = jsInvalidJSONErr
		}
		ch <- resp
	case mqttJSAConsumerCreate:
		var resp = &JSApiConsumerCreateResponse{}
		if err := json.Unmarshal(msg, resp); err != nil {
			resp.Error = jsInvalidJSONErr
		}
		ch <- resp
	case mqttJSAConsumerDel:
		var resp = &JSApiConsumerDeleteResponse{}
		if err := json.Unmarshal(msg, resp); err != nil {
			resp.Error = jsInvalidJSONErr
		}
		ch <- resp
	case mqttJSAMsgStore, mqttJSASessPersist:
		var resp = &JSPubAckResponse{}
		if err := json.Unmarshal(msg, resp); err != nil {
			resp.Error = jsInvalidJSONErr
		}
		ch <- resp
	case mqttJSAMsgLoad:
		var resp = &JSApiMsgGetResponse{}
		if err := json.Unmarshal(msg, resp); err != nil {
			resp.Error = jsInvalidJSONErr
		}
		ch <- resp
	default:
		pc.Warnf("Unknown reply code %q", token)
	}
}
// This will both load all retained messages and process updates from the cluster.
//
// The retained message record is decoded from the message payload. Its stream
// sequence is taken from the JS ack reply subject.
//
// While a recovery is in progress (as.rrmLastSeq > 0), messages that originated
// from this server are applied too. Outside of a recovery, they are ignored.
// Once a message at or past as.rrmLastSeq is processed, the recovery is marked
// done: as.rrmLastSeq is reset and as.rrmDoneCh is closed.
//
// This handler can run concurrently from several go routines. The recovery
// fields are therefore read and updated under as.mu, which also guarantees
// that as.rrmDoneCh is closed only once. The lock is not held while calling
// handleRetainedMsg, since that function acquires it itself.
//
// Run from various go routines (JS consumer, etc..).
// No lock held on entry.
func (as *mqttAccountSessionManager) processRetainedMsg(_ *subscription, c *client, subject, reply string, rmsg []byte) {
	_, msg := c.msgParts(rmsg)
	rm := &mqttRetainedMsg{}
	if err := json.Unmarshal(msg, rm); err != nil {
		return
	}
	// If lastSeq is 0 (nothing to recover, or done doing it) and this is
	// from our own server, ignore.
	as.mu.RLock()
	recovering := as.rrmLastSeq > 0
	as.mu.RUnlock()
	if !recovering && rm.Origin == as.jsa.id {
		return
	}
	// At this point we either recover from our own server, or process a remote retained message.
	seq, _, _ := ackReplyInfo(reply)

	// Handle this retained message
	rm.sseq = seq
	as.handleRetainedMsg(rm.Subject, rm)

	// If we were recovering (lastSeq > 0), then check if we are done.
	as.mu.Lock()
	if as.rrmLastSeq > 0 && seq >= as.rrmLastSeq {
		as.rrmLastSeq = 0
		if as.rrmDoneCh != nil {
			close(as.rrmDoneCh)
			as.rrmDoneCh = nil
		}
	}
	as.mu.Unlock()
}
// processRetainedMsgDel handles a notification that the retained message on
// a given subject was deleted.
//
// The 4th token of `subject` is compared with this server's as.jsa.id. When
// that token is empty or equal to our id, the notification is ignored
// (presumably because it was sent by this server). Otherwise the payload is
// decoded as an mqttRetMsgDel and passed to handleRetainedMsgDel. Payloads that
// are too short or not decodable are ignored.
//
// No lock held on entry.
func (as *mqttAccountSessionManager) processRetainedMsgDel(_ *subscription, c *client, subject, reply string, rmsg []byte) {
	sender := tokenAt(subject, 3)
	if sender == _EMPTY_ || sender == as.jsa.id {
		return
	}
	_, payload := c.msgParts(rmsg)
	if len(payload) < LEN_CR_LF {
		return
	}
	var del mqttRetMsgDel
	if json.Unmarshal(payload, &del) != nil {
		return
	}
	as.handleRetainedMsgDel(del.Subject, del.Seq)
}
// This will receive all JS API replies for a request to store a session record,
// including the reply for our own server, which we will ignore.
// This allows us to detect that some application somewhere else in the cluster
// is connecting with the same client ID, and therefore we need to close the
// connection that is currently using this client ID.
//
// The following replies are ignored: those whose token at mqttJSAIdTokenPos
// equals our as.jsa.id, empty payloads, undecodable payloads, and acks that
// carry an API error.
//
// The session is identified by the ack's stream name, with
// mqttSessionsStreamNamePrefix stripped. If a matching local session exists
// and the ack is not older than the session's sequence, these steps follow:
//   - the session is removed from the session maps;
//   - if a client is bound to it, the client ID is recorded in the flappers
//     map;
//   - that client is disassociated from the session and closed with
//     DuplicateClientID.
//
// Can run from various go routines (system send loop, etc..).
// No lock held on entry.
func (as *mqttAccountSessionManager) processSessionPersist(_ *subscription, pc *client, subject, _ string, rmsg []byte) {
	// Ignore our own responses here (they are handled elsewhere)
	if tokenAt(subject, mqttJSAIdTokenPos) == as.jsa.id {
		return
	}
	_, msg := pc.msgParts(rmsg)
	if len(msg) < LEN_CR_LF {
		return
	}
	var par = &JSPubAckResponse{}
	if err := json.Unmarshal(msg, par); err != nil {
		return
	}
	if err := convertApiErrorToError(par.Error); err != nil {
		return
	}
	// The session stream name is the prefix followed by the client ID hash.
	cIDHash := strings.TrimPrefix(par.Stream, mqttSessionsStreamNamePrefix)

	// Lock order: account session manager first, then session (same order
	// used by processSubs).
	as.mu.Lock()
	defer as.mu.Unlock()
	sess, ok := as.sessByHash[cIDHash]
	if !ok {
		return
	}
	// If our current session's stream sequence is higher, it means that this
	// update is stale, so we don't do anything here.
	if par.Sequence < sess.seq {
		return
	}
	// as.mu is already held, hence `lock` is false.
	as.removeSession(sess, false)

	sess.mu.Lock()
	if ec := sess.c; ec != nil {
		as.addSessToFlappers(sess.id)
		ec.Warnf("Closing because a remote connection has started with the same client ID: %q", sess.id)
		// Disassociate the client from the session so that on client close,
		// nothing will be done with regards to cleaning up the session,
		// such as deleting stream, etc..
		sess.c = nil
		// Remove in separate go routine.
		// NOTE(review): presumably because closing the connection must not run
		// while as.mu/sess.mu are held here - confirm.
		go ec.closeConnection(DuplicateClientID)
	}
	sess.mu.Unlock()
}
// addSessToFlappers records `clientID` in the flappers map, stamped with the
// current time in nanoseconds.
//
// The first call starts a timer that fires every mqttFlapCleanItvl and evicts
// entries older than mqttSessJailDur. After each run the timer re-arms itself.
// It stops doing so once as.flapTimer has been set to nil, which happens on
// shutdown.
//
// Lock held on entry.
func (as *mqttAccountSessionManager) addSessToFlappers(clientID string) {
	as.flappers[clientID] = time.Now().UnixNano()
	if as.flapTimer != nil {
		return
	}
	evictExpired := func() {
		as.mu.Lock()
		defer as.mu.Unlock()
		// A nil timer means that we are shutting down: do not re-arm.
		if as.flapTimer == nil {
			return
		}
		now := time.Now().UnixNano()
		for id, stamp := range as.flappers {
			if now-stamp > int64(mqttSessJailDur) {
				delete(as.flappers, id)
			}
		}
		as.flapTimer.Reset(mqttFlapCleanItvl)
	}
	as.flapTimer = time.AfterFunc(mqttFlapCleanItvl, evictExpired)
}
// removeSessFromFlappers forgets `clientID` from the flappers map.
//
// The cleanup timer is deliberately left alone. It keeps running at its
// regular interval and finds nothing to evict, and it is stopped on shutdown.
//
// Lock held on entry.
func (as *mqttAccountSessionManager) removeSessFromFlappers(clientID string) {
	flappers := as.flappers
	delete(flappers, clientID)
}
// createSubscription subscribes the internal client as.jsa.c to `subject`
// with handler `cb`. The subscription id is the current value of *sid, in
// base 10. On success, *sid is incremented and the new subscription is
// appended to *subs. On failure, neither is modified and the error is returned.
func (as *mqttAccountSessionManager) createSubscription(subject string, cb msgHandler, sid *int64, subs *[]*subscription) error {
	sidBytes := []byte(strconv.FormatInt(*sid, 10))
	sub, err := as.jsa.c.processSub([]byte(subject), nil, sidBytes, cb, false)
	if err == nil {
		*sid++
		*subs = append(*subs, sub)
	}
	return err
}
// Loop to send JS API requests for a given MQTT account.
// The closeCh is used by the caller to be able to interrupt this routine
// if the rest of the initialization fails, since the quitCh is really
// only used when the server shutdown.
//
// Each request dequeued from sendq is injected through the internal client `c`
// with c.processInboundClientMsg(), then c.flushClients(0) is called.
// How the message's header is handled depends on r.hdr:
//   - r.hdr == -1: the message is sent without any header.
//   - r.hdr > 0: `msg` already starts with a header block of that length.
//   - r.hdr == 0: a header block is prepended. It holds the JSON-encoded
//     ClientInfo: the account name and, if JetStream is enabled, the cluster
//     name.
//
// On exit, the flappers cleanup timer is stopped and set to nil.
//
// No lock held on entry.
func (as *mqttAccountSessionManager) sendJSAPIrequests(s *Server, c *client, accName string, closeCh chan struct{}) {
	var cluster string
	if s.JetStreamEnabled() {
		cluster = s.cachedClusterName()
	}
	as.mu.RLock()
	sendq := as.jsa.sendq
	quitCh := as.jsa.quitCh
	ci := ClientInfo{Account: accName, Cluster: cluster}
	as.mu.RUnlock()

	// The account session manager does not have a shutdown API per-se, instead,
	// we will cleanup things when this go routine exits after detecting that the
	// server is shutdown or the initialization of the account manager failed.
	defer func() {
		as.mu.Lock()
		if as.flapTimer != nil {
			as.flapTimer.Stop()
			as.flapTimer = nil
		}
		as.mu.Unlock()
	}()

	// Build, once, the header block carrying the ClientInfo. It is prepended
	// to requests that do not provide their own header (r.hdr == 0).
	b, _ := json.Marshal(ci)
	hdrStart := bytes.Buffer{}
	hdrStart.WriteString(hdrLine)
	http.Header{ClientInfoHdr: []string{string(b)}}.Write(&hdrStart)
	hdrStart.WriteString(CR_LF)
	hdrStart.WriteString(CR_LF)
	hdrb := hdrStart.Bytes()

	for {
		select {
		case r := <-sendq:
			// nsize is the size of header + payload, excluding the trailing CR_LF.
			var nsize int
			msg := r.msg
			// If r.hdr is set to -1, it means that there is no need for any header.
			if r.hdr != -1 {
				bb := bytes.Buffer{}
				if r.hdr > 0 {
					// This means that the header has been set by the caller and is
					// already part of `msg`, so simply set c.pa.hdr to the given value.
					c.pa.hdr = r.hdr
					nsize = len(msg)
					// NOTE(review): this append may write into spare capacity of
					// the caller's r.msg backing array - confirm callers never reuse it.
					msg = append(msg, _CRLF_...)
				} else {
					// We need the ClientInfo header, so add it here.
					bb.Write(hdrb)
					c.pa.hdr = bb.Len()
					bb.Write(r.msg)
					nsize = bb.Len()
					bb.WriteString(_CRLF_)
					msg = bb.Bytes()
				}
				c.pa.hdb = []byte(strconv.Itoa(c.pa.hdr))
			} else {
				c.pa.hdr = -1
				c.pa.hdb = nil
				nsize = len(msg)
				msg = append(msg, _CRLF_...)
			}

			c.pa.subject = []byte(r.subj)
			c.pa.reply = []byte(r.reply)
			c.pa.size = nsize
			c.pa.szb = []byte(strconv.Itoa(nsize))

			c.processInboundClientMsg(msg)
			c.flushClients(0)

		case <-closeCh:
			return
		case <-quitCh:
			return
		}
	}
}
// handleRetainedMsg adds `rm` to the retained messages map under `key`, or
// updates the record already stored there.
//
// An existing record is only updated when rm.sseq is above both its stream
// sequence and its floor; otherwise `rm` is ignored. On update, the record's
// origin, message, flags, source and sequence are replaced and its floor is
// cleared. It is also re-inserted in the sublist if a network delete had
// removed it from there.
//
// Returns the stream sequence the existing record had before the update. It
// returns 0 when `rm` was added as a new record or was ignored.
//
// Lock not held on entry.
func (as *mqttAccountSessionManager) handleRetainedMsg(key string, rm *mqttRetainedMsg) uint64 {
	as.mu.Lock()
	defer as.mu.Unlock()

	// Lazily create the map and sublist. A freshly created map cannot contain
	// `key`, so the lookup below simply misses in that case.
	if as.retmsgs == nil {
		as.retmsgs = make(map[string]*mqttRetainedMsg)
		as.sl = NewSublistWithCache()
	}

	erm, exists := as.retmsgs[key]
	if !exists {
		rm.sub = &subscription{subject: []byte(key)}
		as.retmsgs[key] = rm
		as.sl.Insert(rm.sub)
		return 0
	}

	// Stale: not newer than what we have, or at/below the delete floor.
	if rm.sseq <= erm.sseq || rm.sseq <= erm.floor {
		return 0
	}

	prevSeq := erm.sseq
	erm.Origin, erm.Msg, erm.Flags, erm.Source = rm.Origin, rm.Msg, rm.Flags, rm.Source
	erm.sseq, erm.floor = rm.sseq, 0

	// A nil sub means a network delete removed it from the sublist.
	if erm.sub == nil {
		erm.sub = &subscription{subject: []byte(key)}
		as.sl.Insert(erm.sub)
	}
	return prevSeq
}
// handleRetainedMsgDel removes the retained message for `subject`, if present,
// and returns the stream sequence it was stored at. It returns 0 when no
// record was removed from the map.
//
// A `seq` of 0 denotes a local delete. The record is removed from the map and
// the sublist, and its stream sequence is returned so that the caller can
// delete it from the stream.
//
// A non-zero `seq` denotes a delete notification from the network. It only
// applies when `seq` is equal to or higher than the stream sequence of the
// retained message we currently hold. In that case, the message is removed
// from the sublist and `seq` is recorded as the floor for this subject.
// handleRetainedMsg then ignores a late retained message at or below the
// floor. A notification with a lower sequence is stale, because a newer
// retained message has already replaced the deleted one, so it is ignored.
// When no record exists, a floor-only record is created. The floor never
// decreases. Network deletes always return 0.
//
// No lock held on entry.
func (as *mqttAccountSessionManager) handleRetainedMsgDel(subject string, seq uint64) uint64 {
	var seqToRemove uint64
	as.mu.Lock()
	if as.retmsgs == nil {
		as.retmsgs = make(map[string]*mqttRetainedMsg)
		as.sl = NewSublistWithCache()
	}
	if erm, ok := as.retmsgs[subject]; ok {
		// Only act on local deletes or on network deletes that are not stale.
		// Removing the sublist entry for a stale delete would hide the newer
		// retained message from new subscriptions.
		if seq == 0 || seq >= erm.sseq {
			if erm.sub != nil {
				as.sl.Remove(erm.sub)
				erm.sub = nil
			}
			if seq == 0 {
				delete(as.retmsgs, subject)
				seqToRemove = erm.sseq
			} else {
				erm.sseq = 0
				// Never lower an existing floor with an older notification.
				if seq > erm.floor {
					erm.floor = seq
				}
			}
		}
	} else if seq != 0 {
		rm := &mqttRetainedMsg{Subject: subject, floor: seq}
		as.retmsgs[subject] = rm
	}
	as.mu.Unlock()
	return seqToRemove
}
// lockSession marks the session as "locked" for client `c`.
//
// It fails in two cases: the session's client ID is already in the locked
// map, meaning another client is now bound to this session, or the session
// is not bound to `c`. Otherwise the session's ID is added to the locked map
// and nil is returned.
//
// No lock held on entry.
func (as *mqttAccountSessionManager) lockSession(sess *mqttSession, c *client) error {
	as.mu.Lock()
	defer as.mu.Unlock()

	_, alreadyLocked := as.sessLocked[sess.id]
	if !alreadyLocked {
		sess.mu.Lock()
		boundToOther := sess.c != c
		sess.mu.Unlock()
		if !boundToOther {
			as.sessLocked[sess.id] = struct{}{}
			return nil
		}
	}
	return fmt.Errorf("another session is in use with client ID %q", sess.id)
}
// unlockSession removes the session's ID from the "locked" map.
//
// No lock held on entry.
func (as *mqttAccountSessionManager) unlockSession(sess *mqttSession) {
	as.mu.Lock()
	defer as.mu.Unlock()
	delete(as.sessLocked, sess.id)
}
// addSession registers the session in both session maps: by client ID and by
// client ID hash.
//
// When `lock` is true, this function acquires as.mu itself. When it is false,
// the caller must already hold as.mu.
func (as *mqttAccountSessionManager) addSession(sess *mqttSession, lock bool) {
	if lock {
		as.mu.Lock()
		defer as.mu.Unlock()
	}
	as.sessions[sess.id] = sess
	as.sessByHash[sess.idHash] = sess
}
// removeSession unregisters the session from both session maps: by client ID
// and by client ID hash.
//
// When `lock` is true, this function acquires as.mu itself. When it is false,
// the caller must already hold as.mu.
func (as *mqttAccountSessionManager) removeSession(sess *mqttSession, lock bool) {
	if lock {
		as.mu.Lock()
		defer as.mu.Unlock()
	}
	delete(as.sessions, sess.id)
	delete(as.sessByHash, sess.idHash)
}
// Process subscriptions for the given session/client.
//
// When `fromSubProto` is false, it means that this is invoked from the CONNECT
// protocol, when restoring subscriptions that were saved for this session.
// In that case, there is no need to update the session record.
//
// When `fromSubProto` is true, it means that this call is invoked from the
// processing of the SUBSCRIBE protocol, which means that the session needs to
// be updated. It also means that if a subscription on same subject with same
// QoS already exist, we should not be recreating the subscription/JS durable,
// since it was already done when processing the CONNECT protocol.
//
// Each filter's QoS is capped at 1. For each filter, a NATS subscription is
// created and sess.processJSConsumer() is invoked for it. A filter gets its
// qos set to mqttSubAckFailure and is skipped in two cases: its subject starts
// with mqttSubPrefix, or one of these steps fails.
//
// When mqttNeedSubForLevelUp() is true for a filter (e.g. "foo.>"), an extra
// subscription is created on the parent subject ("foo"). Presumably this
// honors MQTT's rule that "foo/#" also matches "foo"; confirm with the
// helper's definition.
//
// Returns the created subscriptions. The returned error can only come from
// sess.update(), which is called when fromSubProto is true. Per-filter
// failures are reported through the filter's qos instead.
//
// Runs from the client's readLoop.
// Lock not held on entry, but session is in the locked map.
func (as *mqttAccountSessionManager) processSubs(sess *mqttSession, c *client,
	filters []*mqttFilter, fromSubProto, trace bool) ([]*subscription, error) {

	// Helpers to lock/unlock both account manager and session.
	asAndSessLock := func() {
		as.mu.Lock()
		sess.mu.Lock()
	}
	asAndSessUnlock := func() {
		sess.mu.Unlock()
		as.mu.Unlock()
	}

	// Small helper to add the consumer config to the session.
	addJSConsToSess := func(sid string, cc *ConsumerConfig) {
		if cc == nil {
			return
		}
		if sess.cons == nil {
			sess.cons = make(map[string]*ConsumerConfig)
		}
		sess.cons[sid] = cc
	}

	// Helper that sets the sub's mqtt fields and possibly serialize retained messages.
	// Assumes account manager and session lock held.
	setupSub := func(sub *subscription, qos byte) {
		if sub.mqtt == nil {
			sub.mqtt = &mqttSub{}
		}
		sub.mqtt.qos = qos
		if fromSubProto {
			as.serializeRetainedMsgsForSub(sess, c, sub, trace)
		}
	}

	// Only assigned by sess.update() below; the `err` used inside the loop is
	// a distinct, loop-scoped variable.
	var err error
	subs := make([]*subscription, 0, len(filters))
	for _, f := range filters {
		if f.qos > 1 {
			f.qos = 1
		}
		subject := f.filter
		sid := subject

		if strings.HasPrefix(subject, mqttSubPrefix) {
			f.qos = mqttSubAckFailure
			continue
		}

		var jscons *ConsumerConfig
		var jssub *subscription

		// Note that if a subscription already exists on this subject,
		// the existing sub is returned. Need to update the qos.
		asAndSessLock()
		sub, err := c.processSub([]byte(subject), nil, []byte(sid), mqttDeliverMsgCbQos0, false)
		if err == nil {
			setupSub(sub, f.qos)
		}
		asAndSessUnlock()
		if err == nil {
			// This will create (if not already exist) a JS consumer for subscriptions
			// of QoS >= 1. But if a JS consumer already exists and the subscription
			// for same subject is now a QoS==0, then the JS consumer will be deleted.
			jscons, jssub, err = sess.processJSConsumer(c, subject, sid, f.qos, fromSubProto)
		}
		if err != nil {
			// c.processSub already called c.Errorf(), so no need here.
			f.qos = mqttSubAckFailure
			sess.cleanupFailedSub(c, sub, jscons, jssub)
			continue
		}
		if mqttNeedSubForLevelUp(subject) {
			var fwjscons *ConsumerConfig
			var fwjssub *subscription
			var fwcsub *subscription

			// Say subject is "foo.>", remove the ".>" so that it becomes "foo"
			fwcsubject := subject[:len(subject)-2]
			// Change the sid to "foo fwc"
			fwcsid := fwcsubject + mqttMultiLevelSidSuffix
			// See note above about existing subscription.
			asAndSessLock()
			fwcsub, err = c.processSub([]byte(fwcsubject), nil, []byte(fwcsid), mqttDeliverMsgCbQos0, false)
			if err == nil {
				setupSub(fwcsub, f.qos)
			}
			asAndSessUnlock()
			if err == nil {
				fwjscons, fwjssub, err = sess.processJSConsumer(c, fwcsubject, fwcsid, f.qos, fromSubProto)
			}
			if err != nil {
				// c.processSub already called c.Errorf(), so no need here.
				f.qos = mqttSubAckFailure
				// Undo both the main subscription and the parent-level one.
				sess.cleanupFailedSub(c, sub, jscons, jssub)
				sess.cleanupFailedSub(c, fwcsub, fwjscons, fwjssub)
				continue
			}
			subs = append(subs, fwcsub)
			addJSConsToSess(fwcsid, fwjscons)
		}
		subs = append(subs, sub)
		addJSConsToSess(sid, jscons)
	}
	if fromSubProto {
		err = sess.update(filters, true)
	}
	return subs, err
}
// Retained publish messages matching this subscription are serialized in the
// subscription's `prm` mqtt writer. This buffer will be queued for outbound
// after the subscription is processed and SUBACK is sent or possibly when
// server processes an incoming published message matching the newly
// registered subscription.
//
// Each message is serialized with its own MQTT topic (rm.Topic), not with the
// subscription's subject, which may contain wildcards. A packet identifier is
// obtained from the session for each message. When `trace` is true, each
// serialized PUBLISH is traced.
//
// Runs from the client's readLoop.
// Account session manager lock held on entry.
// Session lock held on entry.
func (as *mqttAccountSessionManager) serializeRetainedMsgsForSub(sess *mqttSession, c *client, sub *subscription, trace bool) {
	if len(as.retmsgs) == 0 {
		return
	}
	// Stack-backed buffer, to avoid an allocation when there are few matches.
	var buf [64]*mqttRetainedMsg
	matches := buf[:0]
	as.getRetainedPublishMsgs(string(sub.subject), &matches)
	if len(matches) == 0 {
		return
	}
	if sub.mqtt.prm == nil {
		sub.mqtt.prm = &mqttWriter{}
	}
	w := sub.mqtt.prm
	for _, rm := range matches {
		pi := sess.getPubAckIdentifier(mqttGetQoS(rm.Flags), sub)
		flags := mqttSerializePublishMsg(w, pi, false, true, []byte(rm.Topic), rm.Msg)
		if !trace {
			continue
		}
		pp := mqttPublish{
			topic: []byte(rm.Topic),
			flags: flags,
			pi:    pi,
			sz:    len(rm.Msg),
		}
		c.traceOutOp("PUBLISH", []byte(mqttPubTrace(&pp)))
	}
}
// getRetainedPublishMsgs appends to *rms every retained message record whose
// subject matches `subject`, which may contain wildcards.
//
// The sublist holds one literal subscription per retained subject. A reverse
// match therefore returns the retained subjects that `subject` matches. Each
// of these subjects is then looked up in as.retmsgs.
//
// Account session manager lock held on entry.
func (as *mqttAccountSessionManager) getRetainedPublishMsgs(subject string, rms *[]*mqttRetainedMsg) {
	matched := as.sl.ReverseMatch(subject).psubs
	for i := range matched {
		if rm, ok := as.retmsgs[string(matched[i].subject)]; ok {
			*rms = append(*rms, rm)
		}
	}
}
// Creates the session stream (limit msgs of 1) for this client ID if it does
// not already exist. If it exists, recover the single record to rebuild the
// state of the session. If there is a session record but this session is not
// registered in the runtime of this server, then a request is made to the
// owner to close the client associated with this session since specification
// [MQTT-3.1.4-2] specifies that if the ClientId represents a Client already
// connected to the Server then the Server MUST disconnect the existing client.
//
// Returns the session and whether a persisted record was restored:
//   - When a record was restored, the session is also added to the account's
//     session maps.
//   - For a brand new session, it is not added; that is left to the caller.
//
// Runs from the client's readLoop.
// Lock not held on entry, but session is in the locked map.
func (as *mqttAccountSessionManager) createOrRestoreSession(clientID string, opts *Options) (*mqttSession, bool, error) {
	// The JS domain (possibly empty) is prepended to the client ID before
	// hashing. Session stream/filter subjects are therefore unique per domain,
	// and an application reusing a client ID in another domain does not update
	// this domain's session stream.
	hash := string(getHash(opts.JetStreamDomain + clientID))
	streamName := mqttSessionsStreamNamePrefix + hash
	jsa := &as.jsa

	fail := func(what string, err error) (*mqttSession, bool, error) {
		return nil, false, fmt.Errorf("%s for account %q, session %q: %v", what, jsa.c.acc.GetName(), clientID, err)
	}

	si, err := jsa.createStream(&StreamConfig{
		Name:      streamName,
		Subjects:  []string{streamName},
		Storage:   FileStorage,
		Retention: LimitsPolicy,
		MaxMsgs:   1,
		Replicas:  as.replicas,
	})
	switch {
	case isErrorOtherThan(err, ErrJetStreamStreamAlreadyUsed):
		return fail("create session stream", err)
	case err != nil:
		// "Already used": the stream exists, get its info to recover the record.
		if si, err = jsa.lookupStream(streamName); err != nil {
			return fail("lookup session stream", err)
		}
	}

	// An empty stream means that it was just created: brand new session.
	if si.State.Msgs == 0 {
		return mqttSessionCreate(jsa, clientID, hash, 0, opts), false, nil
	}

	smsg, err := jsa.loadMsg(streamName, si.State.LastSeq)
	if err != nil {
		return fail("loading session record", err)
	}
	var ps mqttPersistedSession
	if err := json.Unmarshal(smsg.Data, &ps); err != nil {
		return fail(fmt.Sprintf("unmarshal of session record at sequence %v", smsg.Sequence), err)
	}

	// Restore this session even if we don't own it; the caller decides.
	sess := mqttSessionCreate(jsa, clientID, hash, smsg.Sequence, opts)
	sess.clean, sess.subs, sess.cons = ps.Clean, ps.Subs, ps.Cons
	as.addSession(sess, true)
	return sess, true, nil
}
// deleteRetainedMsg queues a request to delete the message at stream sequence
// `seq` from the retained messages stream. The response is not awaited.
//
// No lock held on entry.
func (as *mqttAccountSessionManager) deleteRetainedMsg(seq uint64) {
	jsa := &as.jsa
	jsa.deleteMsg(mqttRetainedMsgsStreamName, seq)
}
// notifyRetainedMsgDeleted queues a message announcing that the retained
// message on `subject`, stored at stream sequence `seq`, is being removed.
// The JSON-encoded mqttRetMsgDel is published on
// jsa.rplyr + mqttJSARetainedMsgDel.
func (as *mqttAccountSessionManager) notifyRetainedMsgDeleted(subject string, seq uint64) {
	jsa := &as.jsa
	payload, _ := json.Marshal(&mqttRetMsgDel{Subject: subject, Seq: seq})
	jsa.sendq <- &mqttJSPubMsg{
		subj: jsa.rplyr + mqttJSARetainedMsgDel,
		msg:  payload,
	}
}
//////////////////////////////////////////////////////////////////////////////
//
// MQTT session related functions
//
//////////////////////////////////////////////////////////////////////////////
// mqttSessionCreate returns a new mqttSession for client `id`, with ID hash
// `idHash` and stream sequence `seq`. The session's max ack pending comes
// from opts.MQTT.MaxAckPending, or mqttDefaultMaxAckPending when that option
// is 0.
func mqttSessionCreate(jsa *mqttJSA, id, idHash string, seq uint64, opts *Options) *mqttSession {
	sess := &mqttSession{jsa: jsa, id: id, idHash: idHash, seq: seq, maxp: opts.MQTT.MaxAckPending}
	if sess.maxp == 0 {
		sess.maxp = mqttDefaultMaxAckPending
	}
	return sess
}
// Persists the session record to the session's stream. That stream is created
// with MaxMsgs 1, so it holds a single record. The record contains the origin
// server id, the client ID, the clean flag, the subscriptions and the JS
// consumer configs.
//
// The publish carries a JSExpectedLastSeq header set to the session's current
// stream sequence. Presumably, JetStream then rejects the store if the
// stream's last sequence no longer matches; confirm this. Any such failure is
// returned to the caller. On success, sess.seq is updated to the sequence of
// the newly stored record.
//
// Lock not held on entry.
func (sess *mqttSession) save() error {
	sess.mu.Lock()
	ps := mqttPersistedSession{
		Origin: sess.jsa.id,
		ID:     sess.id,
		Clean:  sess.clean,
		Subs:   sess.subs,
		Cons:   sess.cons,
	}
	b, _ := json.Marshal(&ps)

	sname := mqttSessionsStreamNamePrefix + sess.idHash
	seq := sess.seq
	sess.mu.Unlock()

	// Build the header block (with the expected last sequence) followed by the
	// JSON record. `hdr` is the length of the header block.
	bb := bytes.Buffer{}
	bb.WriteString(hdrLine)
	bb.WriteString(JSExpectedLastSeq)
	bb.WriteString(":")
	bb.WriteString(strconv.FormatInt(int64(seq), 10))
	bb.WriteString(CR_LF)
	bb.WriteString(CR_LF)
	hdr := bb.Len()
	bb.Write(b)

	resp, err := sess.jsa.storeMsgWithKind(mqttJSASessPersist, sname, hdr, bb.Bytes())
	if err != nil {
		return err
	}
	sess.mu.Lock()
	sess.seq = resp.Sequence
	sess.mu.Unlock()
	return nil
}
// clear resets the session.
//
// Every JS consumer recorded in sess.cons is removed from the map. A delete
// request is queued for each of them, and its response is not awaited.
// When `deleteStream` is true, the session's stream is also deleted. That
// request waits for its reply, but the result is ignored. Finally, these are
// reset: subscriptions, pending acks, per-consumer pending maps, stream
// sequence and total max ack pending.
//
// Runs from the client's readLoop.
// Lock not held on entry, but session is in the locked map.
func (sess *mqttSession) clear(deleteStream bool) {
	for sid := range sess.cons {
		cc := sess.cons[sid]
		delete(sess.cons, sid)
		sess.deleteConsumer(cc)
	}
	if deleteStream {
		streamName := mqttSessionsStreamNamePrefix + sess.idHash
		sess.jsa.deleteStream(streamName)
	}
	sess.mu.Lock()
	sess.subs = nil
	sess.pending = nil
	sess.cpending = nil
	sess.seq = 0
	sess.tmaxack = 0
	sess.mu.Unlock()
}
// update applies the filters to the session's subscriptions map and persists
// the session only if that map changed.
//
// When `add` is true, each filter whose QoS is not mqttSubAckFailure is
// recorded. This counts as a change when the filter was missing or had a
// different QoS. When `add` is false, each filter present in the map is
// removed.
//
// Runs from the client's readLoop.
// Lock not held on entry, but session is in the locked map.
func (sess *mqttSession) update(filters []*mqttFilter, add bool) error {
	changed := false
	for _, f := range filters {
		if !add {
			if _, ok := sess.subs[f.filter]; ok {
				delete(sess.subs, f.filter)
				changed = true
			}
			continue
		}
		if f.qos == mqttSubAckFailure {
			continue
		}
		if qos, ok := sess.subs[f.filter]; ok && qos == f.qos {
			continue
		}
		if sess.subs == nil {
			sess.subs = make(map[string]byte)
		}
		sess.subs[f.filter] = f.qos
		changed = true
	}
	if !changed {
		return nil
	}
	return sess.save()
}
// getPubAckIdentifier returns the next packet identifier to use for a
// published message of QoS `pQos` sent to `sub`. It returns 0 when either QoS
// is 0. Because no reply subject is given, no pending ack is recorded.
//
// Lock held on entry
func (sess *mqttSession) getPubAckIdentifier(pQos byte, sub *subscription) uint16 {
	var pi uint16
	// Without a reply subject, the duplicate flag is always false: ignore it.
	pi, _ = sess.trackPending(pQos, _EMPTY_, sub)
	return pi
}
// trackPending returns the packet identifier (pi) to use when delivering a
// message of QoS `pQos` to `sub`. The boolean result is true when the JS ack
// `reply` indicates a redelivery, i.e. a delivery count > 1.
//
// A pi of 0 ("no packet identifier") is returned when:
//   - `pQos` or the subscription's QoS is 0;
//   - the number of pending acks reached sess.maxp. The caller will then not
//     send the message, and JS will redeliver it later based on the
//     consumer's AckWait;
//   - no packet identifier is available.
//
// When `reply` is empty or the subscription has no JS durable, the next free
// identifier is returned and nothing is tracked.
//
// Otherwise, the identifier is associated with the stream sequence found in
// the JS ack `reply`, and the pending ack, including `reply`, is recorded.
// Presumably that ack subject is used later to ack the JS message; confirm.
//
// When a stream sequence that is already tracked is redelivered, the same
// packet identifier is reused. Its recorded ack subject is refreshed to the
// latest `reply`, so that the eventual ack refers to the latest delivery.
//
// Lock held on entry
func (sess *mqttSession) trackPending(pQos byte, reply string, sub *subscription) (uint16, bool) {
	if pQos == 0 || sub.mqtt.qos == 0 {
		return 0, false
	}
	var dup bool
	var pi uint16

	// Returns the next packet identifier (1..0xFFFF, wrapping around) that is
	// not in sess.pending, or 0 if all are in use.
	bumpPI := func() uint16 {
		var avail bool
		next := sess.ppi
		for i := 0; i < 0xFFFF; i++ {
			next++
			if next == 0 {
				next = 1
			}
			if _, used := sess.pending[next]; !used {
				sess.ppi = next
				avail = true
				break
			}
		}
		if !avail {
			return 0
		}
		return sess.ppi
	}

	// This can happen when invoked from getPubAckIdentifier...
	if reply == _EMPTY_ || sub.mqtt.jsDur == _EMPTY_ {
		return bumpPI(), false
	}

	// Here, we have an ACK subject and a JS consumer...
	jsDur := sub.mqtt.jsDur
	if sess.pending == nil {
		sess.pending = make(map[uint16]*mqttPending)
		sess.cpending = make(map[string]map[uint64]uint16)
	}
	// Get the stream sequence and other from the ack reply subject
	sseq, _, dcount := ackReplyInfo(reply)

	var pending *mqttPending
	// For this JS consumer, check to see if we already have sseq->pi
	sseqToPi, ok := sess.cpending[jsDur]
	if !ok {
		sseqToPi = make(map[uint64]uint16)
		sess.cpending[jsDur] = sseqToPi
	} else if pi, ok = sseqToPi[sseq]; ok {
		// If we already have a pi, get the pending ack so we update it.
		// We will reuse the same packet identifier (pi).
		pending = sess.pending[pi]
	}
	if pi == 0 {
		// sess.maxp will always have a value > 0.
		if len(sess.pending) >= int(sess.maxp) {
			// Indicate that we did not assign a packet identifier.
			// The caller will not send the message to the subscription
			// and JS will redeliver later, based on consumer's AckWait.
			return 0, false
		}
		pi = bumpPI()
		if pi == 0 {
			// No packet identifier available: never track pi 0.
			return 0, false
		}
		sseqToPi[sseq] = pi
	}
	if pending == nil {
		pending = &mqttPending{jsDur: jsDur, sseq: sseq, subject: reply}
		sess.pending[pi] = pending
	} else {
		// Redelivery of a tracked message: the ack reply subject differs for
		// each delivery, so keep the latest one.
		pending.subject = reply
	}
	// If redelivery, return DUP flag
	if dcount > 1 {
		dup = true
	}
	return pi, dup
}
// createConsumer asks JetStream to create a consumer on the MQTT messages
// stream (mqttStreamName) with configuration `consConfig`. It waits in place
// for the reply from the server handling the request and returns its error,
// if any.
//
// Lock held on entry
func (sess *mqttSession) createConsumer(consConfig *ConsumerConfig) error {
	req := &CreateConsumerRequest{Stream: mqttStreamName, Config: *consConfig}
	if _, err := sess.jsa.createConsumer(req); err != nil {
		return err
	}
	return nil
}
// deleteConsumer subtracts the consumer's MaxAckPending from the session's
// total. It then queues a request to delete the durable consumer `cc` from
// the MQTT messages stream. The response is not awaited.
//
// Lock not held on entry.
func (sess *mqttSession) deleteConsumer(cc *ConsumerConfig) {
	sess.mu.Lock()
	sess.tmaxack -= cc.MaxAckPending
	sess.mu.Unlock()

	subj := fmt.Sprintf(JSApiConsumerDeleteT, mqttStreamName, cc.Durable)
	sess.jsa.sendq <- &mqttJSPubMsg{subj: subj}
}
//////////////////////////////////////////////////////////////////////////////
//
// CONNECT protocol related functions
//
//////////////////////////////////////////////////////////////////////////////
// Parse the MQTT connect protocol.
//
// `pl` is the remaining length of the CONNECT packet. The variable header
// and the payload are read from `r` in the order mandated by the spec.
//
// On success, returns 0 and the parsed connect proto. When present, the user
// name and password are stored into c.opts so that the regular authentication
// code can use them. The password is also copied to c.opts.Token and
// c.opts.JWT.
//
// On error, the returned byte is a CONNACK return code describing the failure
// (for instance mqttConnAckRCUnacceptableProtocolVersion), or 0 for errors
// that have no specific return code.
func (c *client) mqttParseConnect(r *mqttReader, pl int) (byte, *mqttConnectProto, error) {
	// Make sure that we have the expected length in the buffer,
	// and if not, this will read it from the underlying reader.
	if err := r.ensurePacketInBuffer(pl); err != nil {
		return 0, nil, err
	}

	// Protocol name
	proto, err := r.readBytes("protocol name", false)
	if err != nil {
		return 0, nil, err
	}

	// Spec [MQTT-3.1.2-1]
	if !bytes.Equal(proto, mqttProtoName) {
		// Check proto name against v3.1 to report better error
		if bytes.Equal(proto, mqttOldProtoName) {
			return 0, nil, fmt.Errorf("older protocol %q not supported", proto)
		}
		return 0, nil, fmt.Errorf("expected connect packet with protocol name %q, got %q", mqttProtoName, proto)
	}

	// Protocol level
	level, err := r.readByte("protocol level")
	if err != nil {
		return 0, nil, err
	}
	// Spec [MQTT-3.1.2-2]
	if level != mqttProtoLevel {
		return mqttConnAckRCUnacceptableProtocolVersion, nil, fmt.Errorf("unacceptable protocol version of %v", level)
	}

	cp := &mqttConnectProto{}
	// Connect flags
	cp.flags, err = r.readByte("flags")
	if err != nil {
		return 0, nil, err
	}

	// Spec [MQTT-3.1.2-3]
	if cp.flags&mqttConnFlagReserved != 0 {
		return 0, nil, errMQTTConnFlagReserved
	}

	var hasWill bool
	wqos := (cp.flags & mqttConnFlagWillQoS) >> 3
	wretain := cp.flags&mqttConnFlagWillRetain != 0
	// Spec [MQTT-3.1.2-11]
	if cp.flags&mqttConnFlagWillFlag == 0 {
		// Spec [MQTT-3.1.2-13]
		if wqos != 0 {
			return 0, nil, fmt.Errorf("if Will flag is set to 0, Will QoS must be 0 too, got %v", wqos)
		}
		// Spec [MQTT-3.1.2-15]
		if wretain {
			return 0, nil, errMQTTWillAndRetainFlag
		}
	} else {
		// Spec [MQTT-3.1.2-14]
		if wqos == 3 {
			return 0, nil, fmt.Errorf("if Will flag is set to 1, Will QoS can be 0, 1 or 2, got %v", wqos)
		}
		hasWill = true
	}

	// Spec [MQTT-3.1.2-19]
	hasUser := cp.flags&mqttConnFlagUsernameFlag != 0
	// Spec [MQTT-3.1.2-21]
	hasPassword := cp.flags&mqttConnFlagPasswordFlag != 0
	// Spec [MQTT-3.1.2-22]
	if !hasUser && hasPassword {
		return 0, nil, errMQTTPasswordFlagAndNoUser
	}

	// Keep alive
	var ka uint16
	ka, err = r.readUint16("keep alive")
	if err != nil {
		return 0, nil, err
	}
	// Spec [MQTT-3.1.2-24]: the client must be disconnected if nothing is
	// received within one and a half times the keep alive period.
	// Use integer arithmetic: converting float64(ka)*1.5 to a Duration
	// before scaling by time.Second would truncate (e.g. 1.5 -> 1) and
	// make the deadline half a second too short for odd keep alive values.
	// The maximum (65535*3*1e9) fits comfortably in an int64.
	if ka > 0 {
		cp.rd = time.Duration(ka) * 3 * time.Second / 2
	}

	// Payload starts here and order is mandated by:
	// Spec [MQTT-3.1.3-1]: client ID, will topic, will message, username, password

	// Client ID
	cp.clientID, err = r.readString("client ID")
	if err != nil {
		return 0, nil, err
	}
	// Spec [MQTT-3.1.3-7]
	if cp.clientID == _EMPTY_ {
		if cp.flags&mqttConnFlagCleanSession == 0 {
			return mqttConnAckRCIdentifierRejected, nil, errMQTTCIDEmptyNeedsCleanFlag
		}
		// Spec [MQTT-3.1.3-6]
		cp.clientID = nuid.Next()
	}
	// Spec [MQTT-3.1.3-4] and [MQTT-3.1.3-9]
	if !utf8.ValidString(cp.clientID) {
		return mqttConnAckRCIdentifierRejected, nil, fmt.Errorf("invalid utf8 for client ID: %q", cp.clientID)
	}

	if hasWill {
		cp.will = &mqttWill{
			qos:    wqos,
			retain: wretain,
		}
		var topic []byte
		// Need to make a copy since we need to hold to this topic after the
		// parsing of this protocol.
		topic, err = r.readBytes("Will topic", true)
		if err != nil {
			return 0, nil, err
		}
		if len(topic) == 0 {
			return 0, nil, errMQTTEmptyWillTopic
		}
		if !utf8.Valid(topic) {
			return 0, nil, fmt.Errorf("invalid utf8 for Will topic %q", topic)
		}
		cp.will.topic = topic
		// Convert MQTT topic to NATS subject
		if cp.will.subject, err = mqttTopicToNATSPubSubject(topic); err != nil {
			return 0, nil, err
		}
		// Now "will" message.
		// Ask for a copy since we need to hold to this after parsing of this protocol.
		cp.will.message, err = r.readBytes("Will message", true)
		if err != nil {
			return 0, nil, err
		}
	}

	if hasUser {
		c.opts.Username, err = r.readString("user name")
		if err != nil {
			return 0, nil, err
		}
		if c.opts.Username == _EMPTY_ {
			return mqttConnAckRCBadUserOrPassword, nil, errMQTTEmptyUsername
		}
		// Spec [MQTT-3.1.3-11]
		if !utf8.ValidString(c.opts.Username) {
			return mqttConnAckRCBadUserOrPassword, nil, fmt.Errorf("invalid utf8 for user name %q", c.opts.Username)
		}
	}

	if hasPassword {
		c.opts.Password, err = r.readString("password")
		if err != nil {
			return 0, nil, err
		}
		c.opts.Token = c.opts.Password
		c.opts.JWT = c.opts.Password
	}
	return 0, cp, nil
}
// mqttConnectTrace builds the trace line for a CONNECT packet: the client ID,
// then, when set, the keep alive, the will details and the user name. The
// password is never printed, only masked.
func (c *client) mqttConnectTrace(cp *mqttConnectProto) string {
	var sb strings.Builder
	fmt.Fprintf(&sb, "clientID=%s", cp.clientID)
	if cp.rd > 0 {
		fmt.Fprintf(&sb, " keepAlive=%v", cp.rd)
	}
	if w := cp.will; w != nil {
		fmt.Fprintf(&sb, " will=(topic=%s QoS=%v retain=%v)",
			w.topic, w.qos, w.retain)
	}
	if un := c.opts.Username; un != _EMPTY_ {
		fmt.Fprintf(&sb, " username=%s", un)
	}
	if c.opts.Password != _EMPTY_ {
		sb.WriteString(" password=****")
	}
	return sb.String()
}
// Process the CONNECT packet.
//
// For the first session on the account, an account session manager will be created,
// along with the JetStream streams/consumer necessary for the working of MQTT.
//
// The session, identified by a client ID, will be registered, or if already existing,
// will be resumed. If the session exists but is associated with an existing client,
// the old client is evicted, as per the specifications.
//
// Due to specific locking requirements around JS API requests, we cannot hold some
// locks for the entire duration of processing of some protocols, therefore, we use
// a map that registers the client ID in a "locked" state. If a different client tries
// to connect and the server detects that the client ID is in that map, it retries
// (up to 10 times, 100ms apart) until it is not, or fails the new client, since we
// can't protect processing of protocols in the original client. This is not expected
// to happen often.
//
// Returns an error when:
//   - the client is not authorized (a CONNACK with mqttConnAckRCNotAuthorized is sent
//     and the connection is closed before returning ErrAuthentication);
//   - the account session manager or the session cannot be created/restored/saved;
//   - the client ID stays in the "locked" map for too long;
//   - the saved subscriptions of a resumed session cannot be processed.
//
// If the client ID is considered "flapping" (it replaced another connection less than
// mqttSessJailDur ago), the new connection is closed after mqttSessJailDur and nil is
// returned.
//
// Runs from the client's readLoop.
// No lock held on entry.
func (s *Server) mqttProcessConnect(c *client, cp *mqttConnectProto, trace bool) error {
	// Helper to queue the CONNACK and trace it if requested.
	sendConnAck := func(rc byte, sessp bool) {
		c.mqttEnqueueConnAck(rc, sessp)
		if trace {
			c.traceOutOp("CONNACK", []byte(fmt.Sprintf("sp=%v rc=%v", sessp, rc)))
		}
	}

	c.mu.Lock()
	c.clearAuthTimer()
	c.mu.Unlock()
	if !s.isClientAuthorized(c) {
		c.Errorf(ErrAuthentication.Error())
		sendConnAck(mqttConnAckRCNotAuthorized, false)
		c.closeConnection(AuthenticationViolation)
		return ErrAuthentication
	}
	// Now that we are authenticated, we have the client bound to the account.
	// Get the account's level MQTT sessions manager. If it does not exist yet,
	// this will create it along with the streams where sessions and messages
	// are stored.
	asm, err := s.getOrCreateMQTTAccountSessionManager(cp.clientID, c)
	if err != nil {
		return err
	}

	// Most of the session state is altered only in the readLoop so does not
	// need locking. For things that can be accessed in the readLoop and in
	// callbacks, we will use explicit locking.
	// To prevent other clients from connecting with the same client ID, we will
	// add the client ID to a "locked" map so that the connect somewhere else
	// is put on hold.
	// This keeps track of how many times this client is detecting that its
	// client ID is in the locked map. After a short amount, the server will
	// fail this inbound client.
	locked := 0

CHECK:
	asm.mu.Lock()
	// Check if different applications keep trying to connect with the same
	// client ID at the same time.
	if tm, ok := asm.flappers[cp.clientID]; ok {
		// If the last time it tried to connect was more than mqttSessJailDur
		// ago, then accept and remove from flappers map.
		if time.Now().UnixNano()-tm > int64(mqttSessJailDur) {
			asm.removeSessFromFlappers(cp.clientID)
		} else {
			// Will hold this client for mqttSessJailDur and then close it. We
			// do this so that if the client has a reconnect feature we
			// don't end-up with very rapid flapping between apps.
			time.AfterFunc(mqttSessJailDur, func() {
				c.closeConnection(DuplicateClientID)
			})
			asm.mu.Unlock()
			return nil
		}
	}
	// If an existing session is in the process of processing some packet, we can't
	// evict the old client just yet. So try again to see if the state clears, but
	// if it does not, then we have no choice but to fail the new client instead of
	// the old one.
	if _, ok := asm.sessLocked[cp.clientID]; ok {
		asm.mu.Unlock()
		if locked++; locked == 10 {
			return fmt.Errorf("other session with client ID %q is in the process of connecting", cp.clientID)
		}
		time.Sleep(100 * time.Millisecond)
		goto CHECK
	}

	// Register this client ID in the "locked" map for the duration of this function.
	asm.sessLocked[cp.clientID] = struct{}{}
	// And remove it on exit, regardless of error or not.
	defer func() {
		asm.mu.Lock()
		delete(asm.sessLocked, cp.clientID)
		asm.mu.Unlock()
	}()

	// Is the client requesting a clean session or not.
	cleanSess := cp.flags&mqttConnFlagCleanSession != 0
	// Session present? Assume false, will be set to true only when applicable.
	sessp := false
	// Do we have an existing session for this client ID
	es, exists := asm.sessions[cp.clientID]
	asm.mu.Unlock()

	// The session is not in the map, but may be on disk, so try to recover
	// or create the stream if not.
	if !exists {
		es, exists, err = asm.createOrRestoreSession(cp.clientID, s.getOpts())
		if err != nil {
			return err
		}
	}
	if exists {
		// Clear the session if client wants a clean session.
		// Also, Spec [MQTT-3.2.2-1]: don't report session present
		if cleanSess || es.clean {
			// Spec [MQTT-3.1.2-6]: If CleanSession is set to 1, the Client and
			// Server MUST discard any previous Session and start a new one.
			// This Session lasts as long as the Network Connection. State data
			// associated with this Session MUST NOT be reused in any subsequent
			// Session.
			es.clear(false)
		} else {
			// Report to the client that the session was present
			sessp = true
		}
		// Spec [MQTT-3.1.4-2]. If the ClientId represents a Client already
		// connected to the Server then the Server MUST disconnect the existing
		// client.
		// Bind with the new client. This needs to be protected because it can be
		// accessed outside of the readLoop.
		es.mu.Lock()
		ec := es.c
		es.c = c
		es.clean = cleanSess
		es.mu.Unlock()
		if ec != nil {
			// Remove "will" of existing client before closing
			ec.mu.Lock()
			ec.mqtt.cp.will = nil
			ec.mu.Unlock()
			// Add to the map of the flappers
			asm.mu.Lock()
			asm.addSessToFlappers(cp.clientID)
			asm.mu.Unlock()
			c.Warnf("Replacing old client %q since both have the same client ID %q", ec.String(), cp.clientID)
			// Close old client in separate go routine
			go ec.closeConnection(DuplicateClientID)
		}
	} else {
		// Spec [MQTT-3.2.2-3]: if the Server does not have stored Session state,
		// it MUST set Session Present to 0 in the CONNACK packet.
		es.mu.Lock()
		es.c, es.clean = c, cleanSess
		es.mu.Unlock()
		// Now add this new session into the account sessions
		asm.addSession(es, true)
	}
	// We would need to save only if it did not exist previously, but we save
	// always in case we are running in cluster mode. This will notify other
	// running servers that this session is being used.
	if err := es.save(); err != nil {
		asm.removeSession(es, true)
		return err
	}
	c.mu.Lock()
	c.flags.set(connectReceived)
	c.mqtt.cp = cp
	c.mqtt.asm = asm
	c.mqtt.sess = es
	c.mu.Unlock()

	// Spec [MQTT-3.2.0-1]: CONNACK must be the first protocol sent to the session.
	sendConnAck(mqttConnAckRCConnectionAccepted, sessp)

	// Process possible saved subscriptions.
	if l := len(es.subs); l > 0 {
		filters := make([]*mqttFilter, 0, l)
		for subject, qos := range es.subs {
			filters = append(filters, &mqttFilter{filter: subject, qos: qos})
		}
		if _, err := asm.processSubs(es, c, filters, false, trace); err != nil {
			return err
		}
	}
	return nil
}
// mqttEnqueueConnAck queues a CONNACK packet carrying the return code `rc`.
// Per Spec [MQTT-3.2.2-4], the session present flag can only be set when
// the connection is accepted (rc == 0).
func (c *client) mqttEnqueueConnAck(rc byte, sessionPresent bool) {
	var ackFlags byte
	if rc == 0 && sessionPresent {
		ackFlags = 1
	}
	pkt := [4]byte{mqttPacketConnectAck, 2, ackFlags, rc}

	c.mu.Lock()
	c.enqueueProto(pkt[:])
	c.mu.Unlock()
}
// mqttHandleWill publishes the "will" message that the client registered in
// its CONNECT packet, if any.
//
// The will's topic, subject, payload, QoS and retain flag are copied into the
// client's publish parsing structure (c.mqtt.pp), and the message is then
// processed as if the client had sent that PUBLISH. Failures to process the
// publish are logged. Pending outbound messages are flushed afterwards.
//
// Client lock must not be held on entry.
func (s *Server) mqttHandleWill(c *client) {
	c.mu.Lock()
	if c.mqtt.cp == nil {
		c.mu.Unlock()
		return
	}
	will := c.mqtt.cp.will
	if will == nil {
		c.mu.Unlock()
		return
	}
	pp := c.mqtt.pp
	pp.topic = will.topic
	pp.subject = will.subject
	pp.msg = will.message
	pp.sz = len(will.message)
	pp.pi = 0
	pp.flags = will.qos << 1
	if will.retain {
		pp.flags |= mqttPubFlagRetain
	}
	c.mu.Unlock()
	// Previously the error was silently dropped, hiding, for instance, a
	// failure to store a QoS 1 will message in JetStream.
	if err := s.mqttProcessPub(c, pp); err != nil {
		c.Errorf("Unable to publish will message on %q: %v", will.topic, err)
	}
	c.flushClients(0)
}
//////////////////////////////////////////////////////////////////////////////
//
// PUBLISH protocol related functions
//
//////////////////////////////////////////////////////////////////////////////
// mqttParsePub parses the variable header and the payload of a PUBLISH
// packet of remaining length `pl` into `pp`. The fixed header flags must
// already be set in pp.flags; they are used to get the QoS.
//
// Only QoS 0 and 1 are supported. The payload is not copied: pp.msg is a
// slice of the reader's buffer, so callers that need to keep it must copy it.
//
// Returns an error if the packet is malformed, including when the variable
// header claims more bytes than the packet contains.
func (c *client) mqttParsePub(r *mqttReader, pl int, pp *mqttPublish) error {
	qos := mqttGetQoS(pp.flags)
	if qos > 1 {
		return fmt.Errorf("publish QoS=%v not supported", qos)
	}
	if err := r.ensurePacketInBuffer(pl); err != nil {
		return err
	}
	// Keep track of where we are when starting to read the variable header
	start := r.pos

	var err error
	pp.topic, err = r.readBytes("topic", false)
	if err != nil {
		return err
	}
	if len(pp.topic) == 0 {
		return errMQTTTopicIsEmpty
	}
	// Convert the topic to a NATS subject. This call will also check that
	// there is no MQTT wildcards (Spec [MQTT-3.3.2-2] and [MQTT-4.7.1-1])
	// Note that this may not result in a copy if there is no conversion.
	// It is good because after the message is processed we won't have a
	// reference to the buffer and we save a copy.
	pp.subject, err = mqttTopicToNATSPubSubject(pp.topic)
	if err != nil {
		return err
	}

	if qos > 0 {
		pp.pi, err = r.readUint16("packet identifier")
		if err != nil {
			return err
		}
		if pp.pi == 0 {
			return fmt.Errorf("with QoS=%v, packet identifier cannot be 0", qos)
		}
	} else {
		pp.pi = 0
	}

	// The message payload will be the total packet length minus
	// what we have consumed for the variable header.
	hdrLen := r.pos - start
	pp.sz = pl - hdrLen
	switch {
	case pp.sz < 0:
		// The topic length (or packet identifier) went past the end of this
		// packet, so we have consumed bytes that belong to the next one.
		return fmt.Errorf("publish variable header length %v exceeds packet length %v", hdrLen, pl)
	case pp.sz > 0:
		start = r.pos
		r.pos += pp.sz
		pp.msg = r.buf[start:r.pos]
	default:
		pp.msg = nil
	}
	return nil
}
// mqttPubTrace describes a PUBLISH packet for tracing: topic, dup flag, QoS,
// retain flag, payload size and, when non-zero, the packet identifier.
func mqttPubTrace(pp *mqttPublish) string {
	var piStr string
	if pp.pi > 0 {
		piStr = fmt.Sprintf(" pi=%v", pp.pi)
	}
	return fmt.Sprintf("%s dup=%v QoS=%v retain=%v size=%v%s",
		pp.topic,
		pp.flags&mqttPubFlagDup != 0,
		mqttGetQoS(pp.flags),
		mqttIsRetained(pp.flags),
		pp.sz,
		piStr)
}
// Process the PUBLISH packet.
//
// The MQTT payload is wrapped in a NATS message whose header records the
// MQTT QoS under mqttNatsHeader. The message is then handed to
// processInboundClientMsg using the client's publish arguments (c.pa).
// Unless that reports a publish permission issue, a QoS 1 message is also
// stored under mqttStreamSubjectPrefix + subject so that it can be delivered
// to the JS durable consumers of QoS 1 subscriptions.
//
// Returns the error, if any, from storing the message.
//
// Runs from the client's readLoop.
// No lock held on entry.
func (s *Server) mqttProcessPub(c *client, pp *mqttPublish) error {
	// Header block: "<hdrLine><mqttNatsHeader>:<qos>\r\n\r\n", then the payload.
	var buf bytes.Buffer
	buf.WriteString(hdrLine)
	buf.Write(mqttNatsHeaderB)
	buf.WriteByte(':')
	buf.WriteByte('0' + mqttGetQoS(pp.flags))
	buf.WriteString(_CRLF_)
	buf.WriteString(_CRLF_)
	hdrLen := buf.Len()
	buf.Write(pp.msg)
	totalLen := buf.Len()
	msgToSend := buf.Bytes()

	c.pa.subject, c.pa.reply = pp.subject, nil
	c.pa.hdr, c.pa.hdb = hdrLen, []byte(strconv.FormatInt(int64(hdrLen), 10))
	c.pa.size, c.pa.szb = totalLen, []byte(strconv.FormatInt(int64(totalLen), 10))

	var err error
	_, permIssue := c.processInboundClientMsg(msgToSend)
	if !permIssue && mqttGetQoS(pp.flags) > 0 {
		_, err = c.mqtt.sess.jsa.storeMsg(mqttStreamSubjectPrefix+string(c.pa.subject), c.pa.hdr, msgToSend)
	}

	// Reset the publish arguments now that this message has been handled.
	c.pa.subject, c.pa.hdr, c.pa.size, c.pa.szb, c.pa.reply = nil, -1, 0, nil, nil
	return err
}
// mqttHandlePubRetain handles the RETAIN flag of the PUBLISH being processed.
// It is invoked when processing an inbound client message.
//
// When the flag is set:
//   - an empty payload removes the retained message for the subject
//     (Spec [MQTT-3.3.1-11]);
//   - otherwise the message, with its QoS, is stored as the retained message
//     for the subject (Spec [MQTT-3.3.1-5]), replacing any previous one.
//
// The retain flag is then cleared so that the message is delivered to current
// subscribers as a normal message. Storage failures are logged.
//
// Invoked from the MQTT publisher's readLoop. No client lock is held on entry.
func (c *client) mqttHandlePubRetain() {
	pp := c.mqtt.pp
	if !mqttIsRetained(pp.flags) {
		return
	}
	key := string(pp.subject)
	asm := c.mqtt.asm

	switch {
	case pp.sz == 0:
		// Delete the retained message, the publish itself is still delivered.
		if seq := asm.handleRetainedMsgDel(key, 0); seq > 0 {
			asm.deleteRetainedMsg(seq)
			asm.notifyRetainedMsgDeleted(key, seq)
		}
	default:
		// `pp` may reference the read buffer, so the payload is copied.
		rm := &mqttRetainedMsg{
			Origin:  asm.jsa.id,
			Subject: key,
			Topic:   string(pp.topic),
			Msg:     copyBytes(pp.msg),
			Flags:   pp.flags,
			Source:  c.opts.Username,
		}
		// Error ignored, as in the original: the exported fields are plain
		// strings, bytes and a byte, which presumably always marshal.
		rmBytes, _ := json.Marshal(rm)
		smr, err := asm.jsa.storeMsg(mqttRetainedMsgsStreamSubject, -1, rmBytes)
		if err != nil {
			c.mu.Lock()
			acc := c.acc
			c.mu.Unlock()
			c.Errorf("unable to store retained message for account %q, subject %q: %v",
				acc.GetName(), key, err)
			break
		}
		rm.sseq = smr.Sequence
		// Register the new message; an older one on the same subject is deleted.
		if oldSeq := asm.handleRetainedMsg(key, rm); oldSeq != 0 {
			asm.deleteRetainedMsg(oldSeq)
		}
	}

	// Clear the retain flag for a normal published message.
	pp.flags &= ^mqttPubFlagRetain
}
// After a config reload, it is possible that the source of a publish retained
// message is no longer allowed to publish on the given topic. If that is the
// case, the retained message is removed from the map and will no longer be
// sent to (re)starting subscriptions.
//
// A retained message is also removed when its source user no longer exists in
// the server's users. Retained messages with an empty source are always kept.
// For each removed message, deleteRetainedMsg and notifyRetainedMsgDeleted are
// called after the account session manager lock has been released.
//
// Server lock MUST NOT be held on entry.
func (s *Server) mqttCheckPubRetainedPerms() {
	// Nothing to do if there is no account session manager.
	sm := &s.mqtt.sessmgr
	sm.mu.RLock()
	done := len(sm.sessions) == 0
	sm.mu.RUnlock()

	if done {
		return
	}

	// Snapshot the users under the server lock so that it is not held
	// while walking the retained messages.
	s.mu.Lock()
	users := make(map[string]*User, len(s.users))
	for un, u := range s.users {
		users[un] = u
	}
	s.mu.Unlock()

	sm.mu.RLock()
	defer sm.mu.RUnlock()
	for _, asm := range sm.sessions {
		// Cache of generated publish permissions, per source user name.
		perms := map[string]*perm{}
		// Retained messages to delete, keyed by subject.
		deletes := map[string]uint64{}
		asm.mu.Lock()
		// Deleting entries from a map while ranging over it is allowed in Go.
		for subject, rm := range asm.retmsgs {
			if rm.Source == _EMPTY_ {
				continue
			}
			// Lookup source from global users.
			u := users[rm.Source]
			if u != nil {
				p, ok := perms[rm.Source]
				if !ok {
					// NOTE(review): u.Permissions may be nil for users defined
					// without permissions - confirm generatePubPerms handles it.
					p = generatePubPerms(u.Permissions)
					perms[rm.Source] = p
				}
				// If there is permission and no longer allowed to publish in
				// the subject, remove the publish retained message from the map.
				// A nil `p` means no publish restriction.
				if p != nil && !pubAllowed(p, subject) {
					u = nil
				}
			}

			// Not present or permissions have changed such that the source can't
			// publish on that subject anymore: remove it from the map.
			if u == nil {
				delete(asm.retmsgs, subject)
				asm.sl.Remove(rm.sub)
				deletes[subject] = rm.sseq
			}
		}
		asm.mu.Unlock()
		for subject, seq := range deletes {
			asm.deleteRetainedMsg(seq)
			asm.notifyRetainedMsgDeleted(subject, seq)
		}
	}
}
// generatePubPerms builds a `perm` object holding only the publish allow and
// deny lists of the given Permissions.
//
// Returns nil when there is no publish restriction: `perms` is nil (a user
// defined without permissions), has no Publish section, or has neither an
// allow nor a deny list. A nil result is treated by mqttCheckPubRetainedPerms
// as "allowed to publish anywhere". Previously a nil `perms` or a nil
// `perms.Publish` caused a nil pointer dereference.
//
// Note that a non-nil but empty allow list produces an empty allow sublist,
// which means that publishing is allowed nowhere.
func generatePubPerms(perms *Permissions) *perm {
	if perms == nil || perms.Publish == nil {
		return nil
	}
	var p *perm
	if perms.Publish.Allow != nil {
		p = &perm{}
		p.allow = NewSublistWithCache()
		for _, pubSubject := range perms.Publish.Allow {
			sub := &subscription{subject: []byte(pubSubject)}
			p.allow.Insert(sub)
		}
	}
	if len(perms.Publish.Deny) > 0 {
		if p == nil {
			p = &perm{}
		}
		p.deny = NewSublistWithCache()
		for _, pubSubject := range perms.Publish.Deny {
			sub := &subscription{subject: []byte(pubSubject)}
			p.deny.Insert(sub)
		}
	}
	return p
}
// pubAllowed reports whether `perms` allows publishing on `subject`. When an
// allow list is present, the subject must match it. When a deny list is
// present, the subject must not match it.
func pubAllowed(perms *perm, subject string) bool {
	if perms.allow != nil && len(perms.allow.Match(subject).psubs) == 0 {
		return false
	}
	if perms.deny != nil && len(perms.deny.Match(subject).psubs) != 0 {
		return false
	}
	return true
}
// mqttWritePublish serializes a PUBLISH packet into `w`. The packet
// identifier `pi` is only written (and counted in the remaining length)
// when qos > 0.
func mqttWritePublish(w *mqttWriter, qos byte, dup, retain bool, subject string, pi uint16, payload []byte) {
	flags := qos << 1
	if dup {
		flags |= mqttPubFlagDup
	}
	if retain {
		flags |= mqttPubFlagRetain
	}

	// Remaining length: 2-byte topic length prefix, topic, optional packet
	// identifier and payload.
	withPI := qos > 0
	remaining := 2 + len(subject) + len(payload)
	if withPI {
		remaining += 2
	}

	w.WriteByte(mqttPacketPub | flags)
	w.WriteVarInt(remaining)
	w.WriteString(subject)
	if withPI {
		w.WriteUint16(pi)
	}
	w.Write(payload)
}
// mqttEnqueuePubAck queues a PUBACK packet for packet identifier `pi`.
// The identifier is encoded big-endian after the fixed header.
func (c *client) mqttEnqueuePubAck(pi uint16) {
	var pkt [4]byte
	pkt[0], pkt[1] = mqttPacketPubAck, 0x2
	binary.BigEndian.PutUint16(pkt[2:], pi)

	c.mu.Lock()
	c.enqueueProto(pkt[:])
	c.mu.Unlock()
}
// mqttParsePubAck parses a PUBACK packet of remaining length `pl` and returns
// its packet identifier. An identifier of 0 is rejected.
func mqttParsePubAck(r *mqttReader, pl int) (uint16, error) {
	if err := r.ensurePacketInBuffer(pl); err != nil {
		return 0, err
	}
	pi, err := r.readUint16("packet identifier")
	switch {
	case err != nil:
		return 0, err
	case pi == 0:
		return 0, errMQTTPacketIdentifierIsZero
	default:
		return pi, nil
	}
}
// Process a PUBACK packet.
//
// If `pi` matches a pending message of the session bound to this client, the
// message is removed from the pending maps, the packet identifier counter is
// reset once nothing is pending, and an ack is sent to JetStream on the
// message's ack subject. Acks from a client that is no longer bound to the
// session are ignored.
//
// Runs from the client's readLoop.
// No lock held on entry.
func (c *client) mqttProcessPubAck(pi uint16) {
	sess := c.mqtt.sess
	if sess == nil {
		return
	}

	var ackSubject string
	sess.mu.Lock()
	if sess.c == c {
		if ack, ok := sess.pending[pi]; ok {
			delete(sess.pending, pi)
			if sseqToPi, ok := sess.cpending[ack.jsDur]; ok {
				delete(sseqToPi, ack.sseq)
			}
			if len(sess.pending) == 0 {
				sess.ppi = 0
			}
			ackSubject = ack.subject
		}
	}
	sess.mu.Unlock()

	if ackSubject == _EMPTY_ {
		return
	}
	// Sent outside of the session lock. A hdr of -1 tells the send loop not
	// to add the "client info" header: this is not a JS API request.
	sess.jsa.sendq <- &mqttJSPubMsg{subj: ackSubject, hdr: -1}
}
// mqttGetQoS extracts the QoS level from the fixed header flags of a PUBLISH
// packet: the bits selected by mqttPubFlagQoS, shifted down by one.
func mqttGetQoS(flags byte) byte {
	qosBits := flags & mqttPubFlagQoS
	return qosBits >> 1
}
// mqttIsRetained reports whether the RETAIN flag is set in the fixed header
// flags of a PUBLISH packet.
func mqttIsRetained(flags byte) bool {
	retainBit := flags & mqttPubFlagRetain
	return retainBit != 0
}
//////////////////////////////////////////////////////////////////////////////
//
// SUBSCRIBE related functions
//
//////////////////////////////////////////////////////////////////////////////
// mqttParseSubs parses a SUBSCRIBE packet of remaining length `pl`, `b` being
// the first byte of the fixed header. It returns the packet identifier and
// the requested topic filters with their QoS.
func (c *client) mqttParseSubs(r *mqttReader, b byte, pl int) (uint16, []*mqttFilter, error) {
	const isSubscribe = true
	return c.mqttParseSubsOrUnsubs(r, b, pl, isSubscribe)
}
// mqttParseSubsOrUnsubs parses a SUBSCRIBE (sub == true) or an UNSUBSCRIBE
// (sub == false) packet of remaining length `pl`. `b` is the first byte of
// the fixed header; its lower 4 bits must match the reserved flags required
// by the specification.
//
// Returns the packet identifier and the topic filters. For SUBSCRIBE, each
// filter also carries the requested QoS. A topic filter that cannot be
// converted to a NATS subject does not fail the parsing: the error is logged
// and the filter is kept so that processing can report a failure for it.
//
// An error is returned for malformed packets, including when the topic filters
// (or their QoS bytes) run past the end of the packet.
func (c *client) mqttParseSubsOrUnsubs(r *mqttReader, b byte, pl int, sub bool) (uint16, []*mqttFilter, error) {
	var expectedFlag byte
	var action string
	if sub {
		expectedFlag = mqttSubscribeFlags
	} else {
		expectedFlag = mqttUnsubscribeFlags
		action = "un"
	}
	// Spec [MQTT-3.8.1-1], [MQTT-3.10.1-1]
	if rf := b & 0xf; rf != expectedFlag {
		return 0, nil, fmt.Errorf("wrong %ssubscribe reserved flags: %x", action, rf)
	}
	if err := r.ensurePacketInBuffer(pl); err != nil {
		return 0, nil, err
	}
	pi, err := r.readUint16("packet identifier")
	if err != nil {
		return 0, nil, fmt.Errorf("reading packet identifier: %v", err)
	}
	// Position right after the last byte of this packet.
	end := r.pos + (pl - 2)
	var filters []*mqttFilter
	for r.pos < end {
		// Don't make a copy now because this will happen during conversion
		// or when processing the sub.
		topic, err := r.readBytes("topic filter", false)
		if err != nil {
			return 0, nil, err
		}
		if len(topic) == 0 {
			return 0, nil, errMQTTTopicFilterCannotBeEmpty
		}
		// Spec [MQTT-3.8.3-1], [MQTT-3.10.3-1]
		if !utf8.Valid(topic) {
			return 0, nil, fmt.Errorf("invalid utf8 for topic filter %q", topic)
		}
		// We are going to report if we had an error during the conversion,
		// but we don't fail the parsing. When processing the sub, we will
		// have an error then, and the processing of subs code will send
		// the proper mqttSubAckFailure flag for this given subscription.
		filter, ferr := mqttFilterToNATSSubject(topic)
		if ferr != nil {
			c.Errorf("invalid topic %q: %v", topic, ferr)
		}
		var qos byte
		if sub {
			qos, err = r.readByte("QoS")
			if err != nil {
				return 0, nil, err
			}
			// Spec [MQTT-3.8.3-4].
			if qos > 2 {
				return 0, nil, fmt.Errorf("subscribe QoS value must be 0, 1 or 2, got %v", qos)
			}
		}
		filters = append(filters, &mqttFilter{ttopic: topic, filter: string(filter), qos: qos})
	}
	// If the last filter (or the packet identifier, for a too short packet)
	// went past the end of this packet, we have consumed bytes belonging to
	// the next packet: the packet is malformed.
	if r.pos > end {
		return 0, nil, fmt.Errorf("%ssubscribe protocol content exceeds packet length", action)
	}
	// Spec [MQTT-3.8.3-3], [MQTT-3.10.3-2]
	if len(filters) == 0 {
		return 0, nil, fmt.Errorf("%ssubscribe protocol must contain at least 1 topic filter", action)
	}
	return pi, filters, nil
}
// mqttSubscribeTrace describes a SUBSCRIBE packet for tracing. Each topic
// filter is listed with its NATS subject and QoS, followed by the packet
// identifier.
func mqttSubscribeTrace(pi uint16, filters []*mqttFilter) string {
	var sb strings.Builder
	sb.WriteString("[")
	for i, f := range filters {
		if i > 0 {
			sb.WriteString(", ")
		}
		sb.Write(f.ttopic)
		sb.WriteString(" (")
		sb.WriteString(f.filter)
		sb.WriteString(") QoS=")
		fmt.Fprintf(&sb, "%v", f.qos)
	}
	fmt.Fprintf(&sb, "] pi=%v", pi)
	return sb.String()
}
// mqttDeliverMsgCbQos0 is the callback of the NATS subscription created on the
// actual subject (for instance "foo.bar"). That subscription exists for every
// MQTT subscription, with sub.mqtt.qos set to the MQTT QoS. A QoS 1 MQTT
// subscription has a second NATS subscription on "$MQTT.sub.<uid>". That is
// the delivery subject of the JS durable consumer filtering on
// "$MQTT.msgs.foo.bar", and it is handled by the QoS 1 callback.
//
// This callback delivers messages as QoS 0 messages. That happens when the
// message was published as QoS 0 (only this callback receives those), when
// the subscription is QoS 0, or when the message comes from a NATS publisher.
//
// To avoid duplicate deliveries, a message known to be a QoS 1 publish is
// dropped when the subscription is QoS 1, because the other callback
// delivers it. Messages coming from JetStream are ignored. Payloads from
// non-MQTT publishers are truncated to mqttMaxPayloadSize.
func mqttDeliverMsgCbQos0(sub *subscription, pc *client, subject, _ string, rmsg []byte) {
	if pc.kind == JETSTREAM {
		return
	}
	hdr, msg := pc.msgParts(rmsg)

	// The subscription's owner, and its session (which does not change).
	cc := sub.client
	sess := cc.mqtt.sess

	// The client may change an existing subscription's QoS at any time.
	sess.mu.Lock()
	subQos := sub.mqtt.qos
	sess.mu.Unlock()

	var (
		topic    []byte
		retained bool
	)
	if pc.isMqtt() {
		// MQTT publisher directly connected to this server: a QoS 1 publish to
		// a QoS 1 subscription is handled by the other callback.
		if subQos == 1 && mqttGetQoS(pc.mqtt.pp.flags) > 0 {
			return
		}
		topic, retained = pc.mqtt.pp.topic, mqttIsRetained(pc.mqtt.pp.flags)
	} else {
		// NATS publisher, route, etc. For a QoS 1 subscription, drop the
		// message if its header says it was published as QoS 1.
		if subQos != 0 && len(hdr) > 0 {
			if nhv := getHeader(mqttNatsHeader, hdr); len(nhv) >= 1 && nhv[0]-'0' > 0 {
				return
			}
		}
		// Too big for an MQTT client: truncate for now (rejecting may be
		// preferable).
		if len(msg) > mqttMaxPayloadSize {
			msg = msg[:mqttMaxPayloadSize]
		}
		topic = natsSubjectToMQTTTopic(subject)
	}

	// QoS 0 delivery: no packet identifier, never a duplicate.
	pc.mqttDeliver(cc, sub, 0, false, retained, topic, msg)
}
// mqttDeliverMsgCbQos1 is the callback of the NATS subscription on the
// delivery subject of the JS durable consumer of a QoS 1 MQTT subscription.
//
// Only JETSTREAM should send to this subject. In cluster mode the message may
// arrive through a route, gateway, etc.; it is then accepted only if its
// NATS/MQTT header flags it as a QoS 1 publish.
//
// The message is tracked as pending (to get a packet identifier) and delivered
// as a QoS 1 PUBLISH whose topic is derived from the stream subject, with the
// mqttStreamSubjectPrefix stripped. Nothing is sent if the session is no longer
// bound to this client, or if the max number of pending messages is reached.
func mqttDeliverMsgCbQos1(sub *subscription, pc *client, subject, reply string, rmsg []byte) {
	// "foo.bar" is stored as "$MQTT.msgs.foo.bar": there must be at least
	// one character after the stream subject prefix.
	if len(subject) <= len(mqttStreamSubjectPrefix) {
		return
	}

	hdr, msg := pc.msgParts(rmsg)
	if pc.kind != JETSTREAM {
		if len(hdr) == 0 {
			return
		}
		nhv := getHeader(mqttNatsHeader, hdr)
		if len(nhv) == 0 || nhv[0]-'0' != 1 {
			return
		}
	}

	// The subscription's owner, and its session (which does not change).
	cc := sub.client
	sess := cc.mqtt.sess

	// The subscription's fields and the pending-ack bookkeeping are protected
	// by the session lock.
	const pQoS = byte(1)
	sess.mu.Lock()
	if sess.c != cc || sub.mqtt == nil {
		sess.mu.Unlock()
		return
	}
	pi, dup := sess.trackPending(pQoS, reply, sub)
	sess.mu.Unlock()

	if pi == 0 {
		// Max pending reached: don't send now. JS will redeliver, and the
		// message will be sent if the number of pending messages has fallen
		// below the threshold by then.
		return
	}

	topic := natsSubjectToMQTTTopic(subject[len(mqttStreamSubjectPrefix):])
	// Messages delivered through the JS consumer are not flagged as retained.
	pc.mqttDeliver(cc, sub, pi, dup, false, topic, msg)
}
// mqttDeliver is used by the MQTT delivery callbacks. It serializes the
// PUBLISH packet for `topic`/`msg` and queues it to the outbound buffer of
// `cc`, the client owning `sub`.
//
// A pending retained message recorded on the subscription (sub.mqtt.prm) is
// queued first and then cleared. `cc` is registered with the producing client
// `c` through addToPCD, and the PUBLISH is traced if `cc` has tracing
// enabled. All of this runs under the lock of `cc`.
func (c *client) mqttDeliver(cc *client, sub *subscription, pi uint16, dup, retained bool, topic, msg []byte) {
	var w mqttWriter
	flags := mqttSerializePublishMsg(&w, pi, dup, retained, topic, msg)

	cc.mu.Lock()
	defer cc.mu.Unlock()

	if prm := sub.mqtt.prm; prm != nil {
		cc.queueOutbound(prm.Bytes())
		sub.mqtt.prm = nil
	}
	cc.queueOutbound(w.Bytes())
	c.addToPCD(cc)

	if !cc.trace {
		return
	}
	pp := mqttPublish{topic: topic, flags: flags, pi: pi, sz: len(msg)}
	cc.traceOutOp("PUBLISH", []byte(mqttPubTrace(&pp)))
}
// mqttSerializePublishMsg writes a PUBLISH packet for `topic` and `msg` into
// `w` and returns the fixed header flags that were used.
//
// A non-zero packet identifier `pi` means the message is sent as QoS 1 (the
// only QoS > 0 currently used here). In that case `pi` is written after the
// topic.
func mqttSerializePublishMsg(w *mqttWriter, pi uint16, dup, retained bool, topic, msg []byte) byte {
	var flags byte
	if dup {
		flags |= mqttPubFlagDup
	}
	if retained {
		flags |= mqttPubFlagRetain
	}

	// Remaining length: 2-byte topic length prefix, topic, optional packet
	// identifier and payload.
	withPI := pi > 0
	remaining := 2 + len(topic) + len(msg)
	if withPI {
		flags |= mqttPubQos1
		remaining += 2
	}

	w.WriteByte(mqttPacketPub | flags)
	w.WriteVarInt(remaining)
	w.WriteBytes(topic)
	if withPI {
		w.WriteUint16(pi)
	}
	w.Write(msg)
	return flags
}
// Process the SUBSCRIBE packet.
//
// Processes the list of subscriptions and updates each filter with the QoS
// that was accepted (or the failure).
//
// Spec [MQTT-3.8.4-3] says that an existing identical subscription must be
// replaced by the new one (possibly with a different QoS) without interrupting
// the flow of publications. The replacement therefore cannot be a "remove then
// add", since messages published between the two steps would have no matching
// subscription and would be lost.
//
// The session is held in the "locked" state for the duration of the call.
//
// Run from client's readLoop.
// No lock held on entry.
func (c *client) mqttProcessSubs(filters []*mqttFilter) ([]*subscription, error) {
	// These do not change, but since this is not a fast path, read them
	// under the client lock anyway.
	c.mu.Lock()
	asm, sess, trace := c.mqtt.asm, c.mqtt.sess, c.trace
	c.mu.Unlock()

	if err := asm.lockSession(sess, c); err != nil {
		return nil, err
	}
	defer asm.unlockSession(sess)

	return asm.processSubs(sess, c, filters, true, trace)
}
// cleanupFailedSub undoes what processSubs may have set up for a subscription
// that failed. It removes the NATS subscription `sub` and the JS delivery
// subscription `jssub`, then deletes the JS consumer `cc`. Nil arguments are
// skipped.
//
// Runs from client's readLoop.
// Lock not held on entry, but session is in the locked map.
func (sess *mqttSession) cleanupFailedSub(c *client, sub *subscription, cc *ConsumerConfig, jssub *subscription) {
	for _, s := range [...]*subscription{sub, jssub} {
		if s != nil {
			c.processUnsub(s.sid)
		}
	}
	if cc != nil {
		sess.deleteConsumer(cc)
	}
}
// When invoked with a QoS of 0, looks for an existing JS durable consumer for
// the given sid and if one is found, deletes the JS durable consumer and unsubs
// the NATS subscription on the delivery subject.
// With a QoS > 0, creates the JS durable consumer (or reuses the existing one)
// along with its NATS subscription on a delivery subject.
//
// `subject` is the NATS subject of the MQTT subscription; a new consumer
// filters on mqttStreamSubjectPrefix + subject. When `fromSubProto` is true
// (processing a SUBSCRIBE packet) and a consumer already exists for `sid`,
// nothing is done, since that consumer was set up while processing CONNECT.
// Otherwise an existing consumer is reused, and only the delivery
// subscription is created for it.
//
// Returns the consumer config and the delivery subscription when one was set
// up, or nil values when there was nothing to do. If the delivery subscription
// cannot be created, the consumer (new or reused) is deleted and the error is
// returned. A new consumer is refused if the sum of the MaxAckPending of all
// consumers of the session would exceed mqttMaxAckTotalLimit.
//
// Lock not held on entry, but session is in the locked map.
func (sess *mqttSession) processJSConsumer(c *client, subject, sid string,
	qos byte, fromSubProto bool) (*ConsumerConfig, *subscription, error) {

	// Check if we are already a JS consumer for this SID.
	cc, exists := sess.cons[sid]
	if exists {
		// If current QoS is 0, it means that we need to delete the existing
		// one (that was QoS > 0)
		if qos == 0 {
			// The JS durable consumer's delivery subject is on a NUID of
			// the form: mqttSubPrefix + <nuid>. It is also used as the sid
			// for the NATS subscription, so use that for the lookup.
			// NOTE(review): c.subs is read here without holding c.mu - confirm
			// this is safe from the readLoop.
			sub := c.subs[cc.DeliverSubject]
			delete(sess.cons, sid)
			sess.deleteConsumer(cc)
			if sub != nil {
				c.processUnsub(sub.sid)
			}
			return nil, nil, nil
		}
		// If this is called when processing SUBSCRIBE protocol, then if
		// the JS consumer already exists, we are done (it was created
		// during the processing of CONNECT).
		if fromSubProto {
			return nil, nil, nil
		}
	}

	// At this point, either there is no JS consumer, or it exists and is being
	// reused (not from a SUBSCRIBE). With QoS 0 and no consumer, there is
	// nothing to do.
	if qos == 0 {
		return nil, nil, nil
	}
	var err error
	var inbox string
	if exists {
		inbox = cc.DeliverSubject
	} else {
		inbox = mqttSubPrefix + nuid.Next()
		opts := c.srv.getOpts()
		// Use the configured values, or the MQTT defaults when not set.
		ackWait := opts.MQTT.AckWait
		if ackWait == 0 {
			ackWait = mqttDefaultAckWait
		}
		maxAckPending := int(opts.MQTT.MaxAckPending)
		if maxAckPending == 0 {
			maxAckPending = mqttDefaultMaxAckPending
		}
		// Check that the limit of subs' maxAckPending are not going over the limit
		// NOTE(review): sess.tmaxack is read and updated here without sess.mu,
		// while deleteConsumer updates it under sess.mu - confirm there is no
		// concurrent access.
		if after := sess.tmaxack + maxAckPending; after > mqttMaxAckTotalLimit {
			return nil, nil, fmt.Errorf("max_ack_pending for all consumers would be %v which exceeds the limit of %v",
				after, mqttMaxAckTotalLimit)
		}
		durName := sess.idHash + "_" + nuid.Next()
		cc = &ConsumerConfig{
			DeliverSubject: inbox,
			Durable:        durName,
			AckPolicy:      AckExplicit,
			DeliverPolicy:  DeliverNew,
			FilterSubject:  mqttStreamSubjectPrefix + subject,
			AckWait:        ackWait,
			MaxAckPending:  maxAckPending,
		}
		if err := sess.createConsumer(cc); err != nil {
			c.Errorf("Unable to add JetStream consumer for subscription on %q: err=%v", subject, err)
			return nil, nil, err
		}
		sess.tmaxack += maxAckPending
	}

	// This is an internal subscription on subject like "$MQTT.sub.<nuid>" that is setup
	// for the JS durable's deliver subject.
	sess.mu.Lock()
	sub, err := c.processSub([]byte(inbox), nil, []byte(inbox), mqttDeliverMsgCbQos1, false)
	if err != nil {
		sess.mu.Unlock()
		// deleteConsumer acquires sess.mu, hence the unlock above.
		sess.deleteConsumer(cc)
		c.Errorf("Unable to create subscription for JetStream consumer on %q: %v", subject, err)
		return nil, nil, err
	}
	if sub.mqtt == nil {
		sub.mqtt = &mqttSub{}
	}
	sub.mqtt.qos = qos
	sub.mqtt.jsDur = cc.Durable
	sess.mu.Unlock()
	return cc, sub, nil
}
// mqttSendRetainedMsgsToNewSubs queues, for each subscription in subs that
// has a pending retained-messages buffer (sub.mqtt.prm), the contents of that
// buffer on the client's outbound path and clears the buffer. Once all
// subscriptions have been visited, the writeLoop is signaled to flush.
// The client lock is acquired for the whole operation.
func (c *client) mqttSendRetainedMsgsToNewSubs(subs []*subscription) {
	c.mu.Lock()
	defer c.mu.Unlock()
	for _, s := range subs {
		ms := s.mqtt
		if ms == nil || ms.prm == nil {
			continue
		}
		c.queueOutbound(ms.prm.Bytes())
		ms.prm = nil
	}
	c.flushSignal()
}
// mqttEnqueueSubAck builds a SUBACK packet for packet identifier pi. The
// packet carries one return-code byte per filter, which is that filter's
// qos. It then enqueues the packet for the client under the client lock.
func (c *client) mqttEnqueueSubAck(pi uint16, filters []*mqttFilter) {
	// Remaining length: 2 bytes of packet identifier plus 1 byte per filter.
	remaining := 2 + len(filters)
	w := &mqttWriter{}
	w.WriteByte(mqttPacketSubAck)
	w.WriteVarInt(remaining)
	w.WriteUint16(pi)
	for i := range filters {
		w.WriteByte(filters[i].qos)
	}
	pkt := w.Bytes()
	c.mu.Lock()
	c.enqueueProto(pkt)
	c.mu.Unlock()
}
//////////////////////////////////////////////////////////////////////////////
//
// UNSUBSCRIBE related functions
//
//////////////////////////////////////////////////////////////////////////////
// mqttParseUnsubs parses the payload of an UNSUBSCRIBE packet. It is the
// unsubscribe flavor of mqttParseSubsOrUnsubs and returns the packet
// identifier and the list of filters.
func (c *client) mqttParseUnsubs(r *mqttReader, b byte, pl int) (uint16, []*mqttFilter, error) {
	const isSubscribe = false
	return c.mqttParseSubsOrUnsubs(r, b, pl, isSubscribe)
}
// mqttProcessUnsubs processes the UNSUBSCRIBE packet.
//
// For every filter it removes the JS durable consumer tracked under that
// sid (if any) and unsubscribes the low-level NATS subscription. For filters
// whose sid ends with ".>", it does the same for the companion sid built from
// the subject plus mqttMultiLevelSidSuffix. Finally the session is updated
// with the removed filters.
//
// Runs from the client's readLoop.
// No lock held on entry.
func (c *client) mqttProcessUnsubs(filters []*mqttFilter) error {
	// asm and sess do not change, but this is not a fast path, so read
	// them under the client lock anyway.
	c.mu.Lock()
	asm, sess := c.mqtt.asm, c.mqtt.sess
	c.mu.Unlock()

	if err := asm.lockSession(sess, c); err != nil {
		return err
	}
	defer asm.unlockSession(sess)

	// dropJSConsumer forgets the JS consumer registered for sid and deletes
	// it. It also discards any pending-ack bookkeeping kept for that
	// consumer's durable name.
	dropJSConsumer := func(sid string) {
		cc, found := sess.cons[sid]
		if !found {
			return
		}
		delete(sess.cons, sid)
		sess.deleteConsumer(cc)

		// Need lock here since these are accessed by callbacks
		sess.mu.Lock()
		defer sess.mu.Unlock()
		seqPis, tracked := sess.cpending[cc.Durable]
		if !tracked {
			return
		}
		delete(sess.cpending, cc.Durable)
		for _, pi := range seqPis {
			delete(sess.pending, pi)
		}
		if len(sess.pending) == 0 {
			sess.ppi = 0
		}
	}

	for _, f := range filters {
		sid := f.filter
		dropJSConsumer(sid)
		if err := c.processUnsub([]byte(sid)); err != nil {
			c.Errorf("error unsubscribing from %q: %v", sid, err)
		}
		if !mqttNeedSubForLevelUp(sid) {
			continue
		}
		// "foo.>" also had a companion subscription for the level-up
		// subject; remove that one too.
		subject := sid[:len(sid)-2]
		levelUpSid := subject + mqttMultiLevelSidSuffix
		dropJSConsumer(levelUpSid)
		if err := c.processUnsub([]byte(levelUpSid)); err != nil {
			c.Errorf("error unsubscribing from %q: %v", subject, err)
		}
	}
	return sess.update(filters, false)
}
// mqttEnqueueUnsubAck builds an UNSUBACK packet for packet identifier pi and
// enqueues it for the client under the client lock.
func (c *client) mqttEnqueueUnsubAck(pi uint16) {
	w := &mqttWriter{}
	w.WriteByte(mqttPacketUnsubAck)
	// Remaining length is always 2: the packet identifier only.
	w.WriteVarInt(2)
	w.WriteUint16(pi)
	ack := w.Bytes()

	c.mu.Lock()
	defer c.mu.Unlock()
	c.enqueueProto(ack)
}
// mqttUnsubscribeTrace renders an UNSUBSCRIBE packet for tracing, in the
// form "[<topic> (<filter>), <topic> (<filter>)] pi=<pi>". Here <topic> is
// the original MQTT topic (f.ttopic) and <filter> is the converted NATS
// filter (f.filter).
func mqttUnsubscribeTrace(pi uint16, filters []*mqttFilter) string {
	var b strings.Builder
	b.WriteByte('[')
	for i, f := range filters {
		if i > 0 {
			b.WriteString(", ")
		}
		b.Write(f.ttopic)
		b.WriteString(" (")
		b.WriteString(f.filter)
		b.WriteByte(')')
	}
	fmt.Fprintf(&b, "] pi=%v", pi)
	return b.String()
}
//////////////////////////////////////////////////////////////////////////////
//
// PINGREQ/PINGRESP related functions
//
//////////////////////////////////////////////////////////////////////////////
// mqttEnqueuePingResp enqueues the pre-built PINGRESP packet
// (mqttPingResponse) for the client under the client lock.
func (c *client) mqttEnqueuePingResp() {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.enqueueProto(mqttPingResponse)
}
//////////////////////////////////////////////////////////////////////////////
//
// Trace functions
//
//////////////////////////////////////////////////////////////////////////////
// errOrTrace returns the bytes of err's message when err is non-nil;
// otherwise it returns the bytes of trace.
func errOrTrace(err error, trace string) []byte {
	msg := trace
	if err != nil {
		msg = err.Error()
	}
	return []byte(msg)
}
//////////////////////////////////////////////////////////////////////////////
//
// Subject/Topic conversion functions
//
//////////////////////////////////////////////////////////////////////////////
// mqttTopicToNATSPubSubject converts an MQTT Topic Name to a NATS Subject
// (used by PUBLISH). Wildcards are rejected because topic names may not
// contain them. See mqttToNATSSubjectConversion() for details.
func mqttTopicToNATSPubSubject(mt []byte) ([]byte, error) {
	const wildcardsAllowed = false
	return mqttToNATSSubjectConversion(mt, wildcardsAllowed)
}
// mqttFilterToNATSSubject converts an MQTT Topic Filter to a NATS Subject
// (used by SUBSCRIBE). Wildcards are accepted because topic filters may
// contain them. See mqttToNATSSubjectConversion() for details.
func mqttFilterToNATSSubject(filter []byte) ([]byte, error) {
	const wildcardsAllowed = true
	return mqttToNATSSubjectConversion(filter, wildcardsAllowed)
}
// Converts an MQTT Topic Name or Filter to a NATS Subject.
// In MQTT:
// - a Topic Name does not have wildcard (PUBLISH uses only topic names).
// - a Topic Filter can include wildcards (SUBSCRIBE uses those).
// - '+' and '#' are wildcard characters (single and multiple levels respectively)
// - '/' is the topic level separator.
//
// Conversion that occurs:
// - '/' is replaced with '/.' if it is the first character in mt, or if the
//   previously emitted character is '.' (e.g. the second '/' of a leading "//").
// - '/' is replaced with './' if it is the last character in mt or if the next
//   character in mt is '/'.
//   For instance, foo//bar would become foo./.bar
// - '/' is replaced with '.' for all other conditions (foo/bar -> foo.bar)
// - '+' is replaced with pwc and '#' with fwc (the NATS wildcards), but only
//   when wcOk is true; otherwise an error is returned.
// - '.' and ' ' cause errMQTTUnsupportedCharacters to be returned.
// - If the converted output ends with '.', a '/' is appended to it.
//
// If there is no need to convert anything (say "foo" remains "foo"), then
// no memory is allocated for the result and the returned slice is the
// original `mt`.
func mqttToNATSSubjectConversion(mt []byte, wcOk bool) ([]byte, error) {
	// cp is true once res has been switched from aliasing mt to a private copy.
	var cp bool
	// j is the length of the converted output produced so far.
	var j int
	res := mt
	// makeCopy lazily allocates the output buffer and copies the unchanged
	// prefix mt[:i] into it. The extra capacity presumably leaves room for
	// the expansions ('/' -> "/." or "./") without regrowing. TODO confirm.
	makeCopy := func(i int) {
		cp = true
		res = make([]byte, 0, len(mt)+10)
		if i > 0 {
			res = append(res, mt[:i]...)
		}
	}
	end := len(mt) - 1
	for i := 0; i < len(mt); i++ {
		switch mt[i] {
		case mqttTopicLevelSep:
			// While no copy exists, j == i and res[j-1] is mt[i-1], which
			// can never be '.' (it would have errored). So this test only
			// fires once the output already diverged from mt.
			if i == 0 || res[j-1] == btsep {
				if !cp {
					makeCopy(0)
				}
				res = append(res, mqttTopicLevelSep, btsep)
				// Two bytes emitted: one extra here plus the j++ below.
				j++
			} else if i == end || mt[i+1] == mqttTopicLevelSep {
				if !cp {
					makeCopy(i)
				}
				res = append(res, btsep, mqttTopicLevelSep)
				j++
			} else {
				if !cp {
					makeCopy(i)
				}
				res = append(res, btsep)
			}
		case btsep, ' ':
			// As of now, we cannot support '.' or ' ' in the MQTT topic/filter.
			return nil, errMQTTUnsupportedCharacters
		case mqttSingleLevelWC, mqttMultiLevelWC:
			if !wcOk {
				// Spec [MQTT-3.3.2-2] and [MQTT-4.7.1-1]
				// The wildcard characters can be used in Topic Filters, but MUST NOT be used within a Topic Name
				return nil, fmt.Errorf("wildcards not allowed in publish's topic: %q", mt)
			}
			if !cp {
				makeCopy(i)
			}
			if mt[i] == mqttSingleLevelWC {
				res = append(res, pwc)
			} else {
				res = append(res, fwc)
			}
		default:
			// Unchanged byte: only needs appending once we own a copy;
			// otherwise res still aliases mt and already contains it.
			if cp {
				res = append(res, mt[i])
			}
		}
		j++
	}
	// Do not let the subject end with '.' (e.g. "/" -> "/." -> "/./").
	if cp && res[j-1] == btsep {
		res = append(res, mqttTopicLevelSep)
		j++
	}
	return res[:j], nil
}
// Converts a NATS subject to MQTT topic. This is for publish
// messages only, so there is no checking for wildcards.
// Rules are reversed of mqttToNATSSubjectConversion.
func natsSubjectToMQTTTopic(subject string) []byte {
topic := []byte(subject)
end := len(subject) - 1
var j int
for i := 0; i < len(subject); i++ {
switch subject[i] {
case mqttTopicLevelSep:
if !(i == 0 && i < end && subject[i+1] == btsep) {
topic[j] = mqttTopicLevelSep
j++
}
case btsep:
topic[j] = mqttTopicLevelSep
j++
if i < end && subject[i+1] == mqttTopicLevelSep {
i++
}
default:
topic[j] = subject[i]
j++
}
}
return topic[:j]
}
// mqttNeedSubForLevelUp reports whether subject is at least three bytes long
// and ends with '.' followed by fwc (e.g. "foo.>"), meaning it has more than
// one token and ends with a full wildcard.
func mqttNeedSubForLevelUp(subject string) bool {
	n := len(subject)
	return n >= 3 && subject[n-2] == '.' && subject[n-1] == fwc
}
//////////////////////////////////////////////////////////////////////////////
//
// MQTT Reader functions
//
//////////////////////////////////////////////////////////////////////////////
// copyBytes returns an independent copy of b with the same length and a
// capacity equal to that length. A nil input yields nil, while an empty
// non-nil input yields an empty non-nil slice.
func copyBytes(b []byte) []byte {
	if b == nil {
		return nil
	}
	return append(make([]byte, 0, len(b)), b...)
}
// reset points the reader at buf and rewinds the read position to the start.
func (r *mqttReader) reset(buf []byte) {
	r.buf, r.pos = buf, 0
}
// hasMore reports whether unread bytes remain in the buffer.
func (r *mqttReader) hasMore() bool {
	remaining := len(r.buf) - r.pos
	return remaining != 0
}
// readByte consumes and returns the next buffered byte. When the buffer is
// exhausted it returns an error naming field and wrapping io.EOF in its text.
func (r *mqttReader) readByte(field string) (byte, error) {
	if r.pos != len(r.buf) {
		next := r.buf[r.pos]
		r.pos++
		return next, nil
	}
	return 0, fmt.Errorf("error reading %s: %v", field, io.EOF)
}
// readPacketLen decodes the MQTT "Remaining Length" variable byte integer
// (MQTT 3.1.1 section 2.2.3). Each byte contributes its low 7 bits, least
// significant group first. A set high bit (0x80) means another byte follows.
//
// Bytes are taken from r.buf while any remain. Once the buffer is exhausted,
// they are read one at a time from r.reader.
//
// Returns io.ErrUnexpectedEOF if r.reader reaches EOF before the value is
// complete. Returns errMQTTMalformedVarInt if the encoding uses more than
// four bytes. Any other read failure is reported in an
// "error reading packet length" error.
func (r *mqttReader) readPacketLen() (int, error) {
	m := 1
	v := 0
	for {
		var b byte
		if r.pos != len(r.buf) {
			b = r.buf[r.pos]
			r.pos++
		} else {
			var buf [1]byte
			// io.ReadFull honors the io.Reader contract. It retries a
			// (0, nil) read instead of treating the zeroed byte as data,
			// and it accepts a byte that is delivered together with io.EOF.
			if _, err := io.ReadFull(r.reader, buf[:]); err != nil {
				// For a 1-byte read, ReadFull returns io.EOF only when
				// nothing was read, i.e. the length field is truncated.
				if err == io.EOF {
					return 0, io.ErrUnexpectedEOF
				}
				return 0, fmt.Errorf("error reading packet length: %v", err)
			}
			b = buf[0]
		}
		v += int(b&0x7f) * m
		if (b & 0x80) == 0 {
			return v, nil
		}
		m *= 0x80
		// After the fourth byte m is 0x80^4 (0x10000000). A continuation
		// bit on that byte would require a fifth byte, which is invalid.
		if m > 0x200000 {
			return 0, errMQTTMalformedVarInt
		}
	}
}
// ensurePacketInBuffer makes sure at least pl bytes are available in r.buf
// starting at r.pos. If they already are, nothing changes. Otherwise a new
// pl-byte buffer is built from the unread remainder of r.buf followed by
// bytes read from r.reader, and the reader is reset onto that buffer.
//
// Returns an "error ensuring protocol is loaded" error if r.reader cannot
// supply the missing bytes. An early end of stream is reported as
// io.ErrUnexpectedEOF in the message.
func (r *mqttReader) ensurePacketInBuffer(pl int) error {
	if len(r.buf)-r.pos >= pl {
		return nil
	}
	b := make([]byte, pl)
	start := copy(b, r.buf[r.pos:])
	// io.ReadFull keeps reading until b is full. Per the io.Reader contract,
	// it consumes any bytes returned alongside an error before reporting
	// that error, so a final chunk delivered with io.EOF is not lost.
	if _, err := io.ReadFull(r.reader, b[start:]); err != nil {
		// ReadFull reports io.EOF only when nothing at all was read. The
		// packet is still incomplete, so report it as unexpected.
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		return fmt.Errorf("error ensuring protocol is loaded: %v", err)
	}
	r.reset(b)
	return nil
}
// readString reads a length-prefixed byte sequence (see readBytes) and
// returns it as a string. On error the string is empty.
func (r *mqttReader) readString(field string) (string, error) {
	bs, err := r.readBytes(field, false)
	if err != nil {
		return "", err
	}
	return string(bs), nil
}
// readBytes reads a 2-byte big-endian length (via readUint16), then that
// many bytes from the buffer. A zero length yields (nil, nil). When cp is
// false, the returned slice aliases r.buf. When cp is true, it is an
// independent copy. If fewer bytes remain than announced, an error naming
// field is returned.
func (r *mqttReader) readBytes(field string, cp bool) ([]byte, error) {
	n, err := r.readUint16(field)
	if err != nil {
		return nil, err
	}
	if n == 0 {
		return nil, nil
	}
	end := r.pos + int(n)
	if end > len(r.buf) {
		return nil, fmt.Errorf("error reading %s: %v", field, io.ErrUnexpectedEOF)
	}
	data := r.buf[r.pos:end]
	r.pos = end
	if !cp {
		return data, nil
	}
	return copyBytes(data), nil
}
// readUint16 consumes two buffered bytes and returns them as a big-endian
// uint16. If fewer than two bytes remain, the position is left unchanged
// and an error naming field is returned.
func (r *mqttReader) readUint16(field string) (uint16, error) {
	if avail := len(r.buf) - r.pos; avail < 2 {
		return 0, fmt.Errorf("error reading %s: %v", field, io.ErrUnexpectedEOF)
	}
	v := binary.BigEndian.Uint16(r.buf[r.pos:])
	r.pos += 2
	return v, nil
}
//////////////////////////////////////////////////////////////////////////////
//
// MQTT Writer functions
//
//////////////////////////////////////////////////////////////////////////////
// WriteUint16 writes i in big-endian order: high byte first, then low byte.
func (w *mqttWriter) WriteUint16(i uint16) {
	hi, lo := byte(i>>8), byte(i)
	w.WriteByte(hi)
	w.WriteByte(lo)
}
// WriteString writes s as a 2-byte length prefix followed by its bytes
// (see WriteBytes).
func (w *mqttWriter) WriteString(s string) {
	b := []byte(s)
	w.WriteBytes(b)
}
// WriteBytes writes the length of bs as a big-endian uint16, followed by bs
// itself. Note that the length is converted with uint16(), so a slice longer
// than 65535 bytes would get a truncated prefix while all of its bytes are
// still written.
func (w *mqttWriter) WriteBytes(bs []byte) {
	n := uint16(len(bs))
	w.WriteUint16(n)
	w.Write(bs)
}
// WriteVarInt writes value as an MQTT variable byte integer. Each output
// byte carries 7 bits, least significant group first, with the high bit set
// when more bytes follow. A zero value is written as a single 0x00 byte.
func (w *mqttWriter) WriteVarInt(value int) {
	for done := false; !done; {
		digit := byte(value & 0x7f)
		value >>= 7
		done = value == 0
		if value > 0 {
			digit |= 0x80
		}
		w.WriteByte(digit)
	}
}