Files
nats-server/server/mqtt_test.go
Ivan Kozlovic 552cc737f1 [FIXED] MQTT: asset placement in origin cluster
In a setup with shared system account and a cluster of leaf nodes,
the JS requests did not contain the origin cluster, which caused
assets to possibly be created in the HUB. With this change, the
assets will be created in the origin cluster.

Also, removed the use of acc.JetStreamEnabled(); instead, the server
now fails to start if MQTT is enabled in standalone mode and JetStream
is not enabled. If JetStream is enabled, a proper error is returned if
the account does not have JetStream enabled.

Signed-off-by: Ivan Kozlovic <ivan@synadia.com>
2021-04-28 19:28:00 -06:00

5038 lines
150 KiB
Go

// Copyright 2020 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"bufio"
"bytes"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"strings"
"sync"
"testing"
"time"
"github.com/nats-io/jwt/v2"
"github.com/nats-io/nats.go"
"github.com/nats-io/nkeys"
"github.com/nats-io/nuid"
)
var (
	// testMQTTTimeout bounds how long the read/write helpers in this file
	// wait on an MQTT connection before giving up.
	testMQTTTimeout = 4 * time.Second

	// jsClusterTemplWithMQTT is a server configuration template with
	// JetStream, a cluster and an MQTT listener. Its fmt placeholders are, in
	// order: server name, JetStream store directory, cluster name, cluster
	// port and the routes list. The $SYS account gets an "admin" user.
	jsClusterTemplWithMQTT = `
listen: 127.0.0.1:-1
server_name: %s
jetstream: {max_mem_store: 256MB, max_file_store: 2GB, store_dir: "%s"}
cluster {
name: %s
listen: 127.0.0.1:%d
routes = [%s]
}
mqtt {
listen: 127.0.0.1:-1
}
# For access to system account.
accounts { $SYS { users = [ { user: "admin", pass: "s3cr3t!" } ] } }
`
)
// mqttErrorReader is a test stand-in for the reader used by mqttReader:
// every Read reports zero bytes together with the configured error, and
// SetReadDeadline always succeeds without doing anything.
type mqttErrorReader struct {
	err error // returned by every Read
}

// Read ignores b and returns (0, r.err).
func (r *mqttErrorReader) Read(b []byte) (int, error) {
	return 0, r.err
}

// SetReadDeadline is a no-op that always returns nil.
func (r *mqttErrorReader) SetReadDeadline(time.Time) error {
	return nil
}

// testNewEOFReader returns an mqttErrorReader whose reads fail with io.EOF.
func testNewEOFReader() *mqttErrorReader {
	r := new(mqttErrorReader)
	r.err = io.EOF
	return r
}
// TestMQTTReader exercises mqttReader's decoding helpers and their error
// paths: length-prefixed bytes (with and without copy), single bytes,
// length-prefixed strings, uint16 values, refilling the buffer through the
// underlying reader, and variable-length packet lengths.
func TestMQTTReader(t *testing.T) {
	r := &mqttReader{}

	// readBytes returns the content of a 2-byte-length-prefixed field.
	r.reset([]byte{0, 2, 'a', 'b'})
	bs, err := r.readBytes("", false)
	if err != nil {
		t.Fatal(err)
	}
	sbs := string(bs)
	if sbs != "ab" {
		t.Fatalf(`expected "ab", got %q`, sbs)
	}

	// With the copy flag set, mutating the result must leave r.buf untouched.
	r.reset([]byte{0, 2, 'a', 'b'})
	bs, err = r.readBytes("", true)
	if err != nil {
		t.Fatal(err)
	}
	bs[0], bs[1] = 'c', 'd'
	if bytes.Equal(bs, r.buf[2:]) {
		t.Fatal("readBytes should have returned a copy")
	}

	// readByte consumes one byte at a time and hasMore reports whether any
	// remain; reading past the end names the field in the error.
	r.reset([]byte{'a', 'b'})
	if b, err := r.readByte(""); err != nil || b != 'a' {
		t.Fatalf("Error reading byte: b=%v err=%v", b, err)
	}
	if !r.hasMore() {
		t.Fatal("expected to have more, did not")
	}
	if b, err := r.readByte(""); err != nil || b != 'b' {
		t.Fatalf("Error reading byte: b=%v err=%v", b, err)
	}
	if r.hasMore() {
		t.Fatal("expected to not have more")
	}
	if _, err := r.readByte("test"); err == nil || !strings.Contains(err.Error(), "error reading test") {
		t.Fatalf("unexpected error: %v", err)
	}

	r.reset([]byte{0, 2, 'a', 'b'})
	if s, err := r.readString(""); err != nil || s != "ab" {
		t.Fatalf("Error reading string: s=%q err=%v", s, err)
	}

	// A single byte is not enough for a uint16.
	r.reset([]byte{10})
	if _, err := r.readUint16("uint16"); err == nil || !strings.Contains(err.Error(), "error reading uint16") {
		t.Fatalf("unexpected error: %v", err)
	}

	// Requiring more bytes than are buffered must fail when the underlying
	// reader returns io.EOF.
	r.reset([]byte{1, 2, 3})
	r.reader = testNewEOFReader()
	if err := r.ensurePacketInBuffer(10); err == nil || !strings.Contains(err.Error(), "error ensuring protocol is loaded") {
		t.Fatalf("unexpected error: %v", err)
	}

	// 0x82 0xff 0x03 decodes to 0x02 + 0x7f<<7 + 0x03<<14 = 0xff82.
	r.reset([]byte{0x82, 0xff, 0x3})
	l, err := r.readPacketLen()
	if err != nil {
		t.Fatalf("error getting packet len: %v", err)
	}
	if l != 0xff82 {
		t.Fatalf("expected length 0xff82 got 0x%x", l)
	}

	// A length spread over five continuation bytes is rejected as malformed.
	r.reset([]byte{0xff, 0xff, 0xff, 0xff, 0xff})
	if _, err := r.readPacketLen(); err == nil || !strings.Contains(err.Error(), "malformed") {
		t.Fatalf("unexpected error: %v", err)
	}

	// A lone continuation byte yields io.ErrUnexpectedEOF...
	r.reset([]byte{0x80})
	if _, err := r.readPacketLen(); err != io.ErrUnexpectedEOF {
		t.Fatalf("unexpected error: %v", err)
	}

	// ...and an error from the underlying reader is surfaced to the caller.
	r.reset([]byte{0x80})
	r.reader = &mqttErrorReader{err: errors.New("on purpose")}
	if _, err := r.readPacketLen(); err == nil || !strings.Contains(err.Error(), "on purpose") {
		t.Fatalf("unexpected error: %v", err)
	}
}
// TestMQTTWriter checks that mqttWriter's encoders round-trip through
// mqttReader: uint16 values, 2-byte-length-prefixed strings and bytes, and
// variable-length integers (both their encoded size and decoded value).
func TestMQTTWriter(t *testing.T) {
	w := &mqttWriter{}
	w.WriteUint16(1234)
	r := &mqttReader{}
	r.reset(w.Bytes())
	if v, err := r.readUint16(""); err != nil || v != 1234 {
		t.Fatalf("unexpected value: v=%v err=%v", v, err)
	}

	// "test" (4 bytes) plus its 2-byte length prefix.
	w.Reset()
	w.WriteString("test")
	r.reset(w.Bytes())
	if len(r.buf) != 6 {
		t.Fatalf("Expected 2 bytes size before string, got %v", r.buf)
	}

	w.Reset()
	w.WriteBytes([]byte("test"))
	r.reset(w.Bytes())
	if len(r.buf) != 6 {
		t.Fatalf("Expected 2 bytes size before bytes, got %v", r.buf)
	}

	// Each value paired with the number of bytes its variable-length
	// encoding is expected to take (boundaries of the 1..4 byte ranges).
	varInts := []struct{ val, size int }{
		{0, 1}, {1, 1}, {127, 1},
		{128, 2}, {16383, 2},
		{16384, 3}, {2097151, 3},
		{2097152, 4}, {268435455, 4},
	}
	tl := 0
	w.Reset()
	for _, vi := range varInts {
		w.WriteVarInt(vi.val)
		tl += vi.size
		if tl != w.Len() {
			t.Fatalf("expected len %d, got %d", tl, w.Len())
		}
	}
	r.reset(w.Bytes())
	for _, vi := range varInts {
		// Check the error too: on failure x would be 0 and the first entry
		// (value 0) would otherwise pass silently.
		x, err := r.readPacketLen()
		if err != nil {
			t.Fatalf("error reading back %d: %v", vi.val, err)
		}
		if vi.val != x {
			t.Fatalf("expected %d, got %d", vi.val, x)
		}
	}
}
// testMQTTDefaultOptions returns DefaultOptions() with the cluster, gateway,
// leafnode and websocket listeners turned off and JetStream enabled. MQTT
// listens on 127.0.0.1 with port -1, which lets the server pick the port;
// tests then dial o.MQTT.Port once the server is running.
func testMQTTDefaultOptions() *Options {
	opts := DefaultOptions()

	// Listeners this file does not exercise.
	opts.Cluster.Port = 0
	opts.Gateway.Name, opts.Gateway.Port = "", 0
	opts.LeafNode.Port = 0
	opts.Websocket.Port = 0

	opts.MQTT.Host, opts.MQTT.Port = "127.0.0.1", -1
	opts.JetStream = true
	return opts
}
// testMQTTRunServer creates a server from o, installs a DummyLogger, starts
// the server in a goroutine and waits up to 3 seconds for it to be ready for
// connections. When o.StoreDir is empty, a new "mqtt_js" directory is
// created and used. Any failure aborts the test.
func testMQTTRunServer(t testing.TB, o *Options) *Server {
	t.Helper()

	o.NoLog = false
	if o.StoreDir == _EMPTY_ {
		o.StoreDir = createDir(t, "mqtt_js")
	}

	srv, err := NewServer(o)
	if err != nil {
		t.Fatalf("Error creating server: %v", err)
	}
	srv.SetLogger(&DummyLogger{}, true, true)

	go srv.Start()
	if err = srv.readyForConnections(3 * time.Second); err != nil {
		t.Fatal(err)
	}
	return srv
}
// testMQTTShutdownRestartedServer shuts down the server *s points to (via
// testMQTTShutdownServer) and then sets the caller's variable to nil.
func testMQTTShutdownRestartedServer(s **Server) {
	testMQTTShutdownServer(*s)
	*s = nil
}
// testMQTTShutdownServer shuts s down. If the server has a JetStream config,
// the directory obtained by trimming the JetStreamStoreDir suffix from its
// StoreDir is removed after the shutdown completes.
func testMQTTShutdownServer(s *Server) {
	cfg := s.JetStreamConfig()
	if cfg == nil {
		s.Shutdown()
		return
	}
	storeRoot := strings.TrimSuffix(cfg.StoreDir, JetStreamStoreDir)
	defer os.RemoveAll(storeRoot)
	s.Shutdown()
}
// testMQTTDefaultTLSOptions returns testMQTTDefaultOptions with a TLS config
// for the MQTT listener built from the test server cert/key and CA. verify
// controls whether client certificates are required. The TLS timeout is set
// to 2.0.
func testMQTTDefaultTLSOptions(t *testing.T, verify bool) *Options {
	t.Helper()
	tlsConfig, err := GenTLSConfig(&TLSConfigOpts{
		CertFile: "../test/configs/certs/server-cert.pem",
		KeyFile:  "../test/configs/certs/server-key.pem",
		CaFile:   "../test/configs/certs/ca.pem",
		Verify:   verify,
	})
	if err != nil {
		t.Fatalf("Error creating tls config: %v", err)
	}
	opts := testMQTTDefaultOptions()
	opts.MQTT.TLSConfig = tlsConfig
	opts.MQTT.TLSTimeout = 2.0
	return opts
}
// TestMQTTStandaloneRequiresJetStream checks that NewServer rejects a
// standalone configuration that enables MQTT without enabling JetStream,
// with an error mentioning "standalone".
func TestMQTTStandaloneRequiresJetStream(t *testing.T) {
	confFile := createConfFile(t, []byte(`
mqtt {
port: -1
tls {
cert_file: "./configs/certs/server.pem"
key_file: "./configs/certs/key.pem"
}
}
`))
	defer removeFile(t, confFile)

	opts, err := ProcessConfigFile(confFile)
	if err != nil {
		t.Fatalf("Error processing config file: %v", err)
	}
	_, err = NewServer(opts)
	if err != nil && strings.Contains(err.Error(), "standalone") {
		return
	}
	t.Fatalf("Expected error about requiring JetStream in standalone mode, got %v", err)
}
// TestMQTTConfig checks that a config file with JetStream enabled and an
// mqtt block containing a tls section starts a server whose MQTT options
// carry a TLS config.
func TestMQTTConfig(t *testing.T) {
	confFile := createConfFile(t, []byte(`
jetstream: enabled
mqtt {
port: -1
tls {
cert_file: "./configs/certs/server.pem"
key_file: "./configs/certs/key.pem"
}
}
`))
	defer removeFile(t, confFile)

	srv, opts := RunServerWithConfig(confFile)
	defer testMQTTShutdownServer(srv)

	if tlsCfg := opts.MQTT.TLSConfig; tlsCfg == nil {
		t.Fatal("expected TLS config to be set")
	}
}
// TestMQTTValidateOptions checks validateMQTTOptions: options without MQTT
// are accepted, an MQTT username or token cannot be combined with
// users/nkeys, and a negative ack wait is rejected.
func TestMQTTValidateOptions(t *testing.T) {
	plain := DefaultOptions()
	withMQTT := testMQTTDefaultOptions()

	cases := []struct {
		name    string
		getOpts func() *Options
		err     string // expected error substring; empty means no error
	}{
		{"mqtt disabled", func() *Options { return plain.Clone() }, ""},
		{"mqtt username not allowed if users specified", func() *Options {
			o := withMQTT.Clone()
			o.Users = []*User{{Username: "abc", Password: "pwd"}}
			o.MQTT.Username = "b"
			o.MQTT.Password = "pwd"
			return o
		}, "mqtt authentication username not compatible with presence of users/nkeys"},
		{"mqtt token not allowed if users specified", func() *Options {
			o := withMQTT.Clone()
			o.Nkeys = []*NkeyUser{{Nkey: "abc"}}
			o.MQTT.Token = "mytoken"
			return o
		}, "mqtt authentication token not compatible with presence of users/nkeys"},
		{"ack wait should be >=0", func() *Options {
			o := withMQTT.Clone()
			o.MQTT.AckWait = -10 * time.Second
			return o
		}, "ack wait must be a positive value"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := validateMQTTOptions(tc.getOpts())
			switch {
			case tc.err == "" && err != nil:
				t.Fatalf("Unexpected error: %v", err)
			case tc.err != "" && (err == nil || !strings.Contains(err.Error(), tc.err)):
				t.Fatalf("Expected error to contain %q, got %v", tc.err, err)
			}
		})
	}
}
// TestMQTTParseOptions checks how the "mqtt" configuration block is parsed
// into MQTTOpts. Entries with a non-empty err expect ProcessConfigFile to
// fail with an error containing that text; the other entries pass the
// parsed o.MQTT to checkOpt, which returns an error describing any mismatch.
func TestMQTTParseOptions(t *testing.T) {
for _, test := range []struct {
name string
content string
checkOpt func(*MQTTOpts) error
err string
}{
// Negative tests
{"bad type", "mqtt: []", nil, "to be a map"},
{"bad listen", "mqtt: { listen: [] }", nil, "port or host:port"},
{"bad port", `mqtt: { port: "abc" }`, nil, "not int64"},
{"bad host", `mqtt: { host: 123 }`, nil, "not string"},
{"bad tls", `mqtt: { tls: 123 }`, nil, "not map[string]interface {}"},
{"unknown field", `mqtt: { this_does_not_exist: 123 }`, nil, "unknown"},
{"ack wait", `mqtt: {ack_wait: abc}`, nil, "invalid duration"},
{"max ack pending", `mqtt: {max_ack_pending: abc}`, nil, "not int64"},
{"max ack pending too high", `mqtt: {max_ack_pending: 12345678}`, nil, "invalid value"},
// Negative test: generating the TLS config fails because key_file is missing.
{"tls gen fails", `
mqtt {
tls {
cert_file: "./configs/certs/server.pem"
}
}`, nil, "missing 'key_file'"},
// Positive tests
{"listen port only", `mqtt { listen: 1234 }`, func(o *MQTTOpts) error {
if o.Port != 1234 {
return fmt.Errorf("expected 1234, got %v", o.Port)
}
return nil
}, ""},
{"listen host and port", `mqtt { listen: "localhost:1234" }`, func(o *MQTTOpts) error {
if o.Host != "localhost" || o.Port != 1234 {
return fmt.Errorf("expected localhost:1234, got %v:%v", o.Host, o.Port)
}
return nil
}, ""},
{"host", `mqtt { host: "localhost" }`, func(o *MQTTOpts) error {
if o.Host != "localhost" {
return fmt.Errorf("expected localhost, got %v", o.Host)
}
return nil
}, ""},
{"port", `mqtt { port: 1234 }`, func(o *MQTTOpts) error {
if o.Port != 1234 {
return fmt.Errorf("expected 1234, got %v", o.Port)
}
return nil
}, ""},
{"tls config",
`
mqtt {
tls {
cert_file: "./configs/certs/server.pem"
key_file: "./configs/certs/key.pem"
}
}
`, func(o *MQTTOpts) error {
if o.TLSConfig == nil {
return fmt.Errorf("TLSConfig should have been set")
}
return nil
}, ""},
{"no auth user",
`
mqtt {
no_auth_user: "noauthuser"
}
`, func(o *MQTTOpts) error {
if o.NoAuthUser != "noauthuser" {
return fmt.Errorf("Invalid NoAuthUser value: %q", o.NoAuthUser)
}
return nil
}, ""},
// The authorization block fills Username, Password, Token and AuthTimeout.
{"auth block",
`
mqtt {
authorization {
user: "mqttuser"
password: "pwd"
token: "token"
timeout: 2.0
}
}
`, func(o *MQTTOpts) error {
if o.Username != "mqttuser" || o.Password != "pwd" || o.Token != "token" || o.AuthTimeout != 2.0 {
return fmt.Errorf("Invalid auth block: %+v", o)
}
return nil
}, ""},
// An integer timeout must also be accepted for AuthTimeout.
{"auth timeout as int",
`
mqtt {
authorization {
timeout: 2
}
}
`, func(o *MQTTOpts) error {
if o.AuthTimeout != 2.0 {
return fmt.Errorf("Invalid auth timeout: %v", o.AuthTimeout)
}
return nil
}, ""},
{"ack wait",
`
mqtt {
ack_wait: "10s"
}
`, func(o *MQTTOpts) error {
if o.AckWait != 10*time.Second {
return fmt.Errorf("Invalid ack wait: %v", o.AckWait)
}
return nil
}, ""},
{"max ack pending",
`
mqtt {
max_ack_pending: 123
}
`, func(o *MQTTOpts) error {
if o.MaxAckPending != 123 {
return fmt.Errorf("Invalid max ack pending: %v", o.MaxAckPending)
}
return nil
}, ""},
} {
t.Run(test.name, func(t *testing.T) {
conf := createConfFile(t, []byte(test.content))
defer removeFile(t, conf)
o, err := ProcessConfigFile(conf)
if test.err != _EMPTY_ {
if err == nil || !strings.Contains(err.Error(), test.err) {
t.Fatalf("For content: %q, expected error about %q, got %v", test.content, test.err, err)
}
return
} else if err != nil {
t.Fatalf("Unexpected error for content %q: %v", test.content, err)
}
// Every positive entry provides checkOpt; negative entries returned above.
if err := test.checkOpt(&o.MQTT); err != nil {
t.Fatalf("Incorrect option for content %q: %v", test.content, err.Error())
}
})
}
}
// TestMQTTStart checks that a server with MQTT enabled accepts TCP
// connections on its MQTT port, and that a second server configured with
// the same MQTT port reports the fatal error "Unable to listen for MQTT
// connections".
func TestMQTTStart(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)

	nc, err := net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port))
	if err != nil {
		t.Fatalf("Unable to create tcp connection to mqtt port: %v", err)
	}
	nc.Close()

	// Check failure to start due to port in use
	o2 := testMQTTDefaultOptions()
	o2.MQTT.Port = o.MQTT.Port
	// Like testMQTTRunServer, give the second server its own JetStream store
	// directory so that testMQTTShutdownServer can clean it up, instead of
	// relying on whatever default an empty StoreDir implies.
	o2.StoreDir = createDir(t, "mqtt_js")
	s2, err := NewServer(o2)
	if err != nil {
		t.Fatalf("Error creating server: %v", err)
	}
	l := &captureFatalLogger{fatalCh: make(chan string, 1)}
	s2.SetLogger(l, false, false)

	startDone := make(chan struct{})
	go func() {
		defer close(startDone)
		s2.Start()
	}()
	// On exit, shut s2 down and give the Start goroutine a bounded amount of
	// time to return so it does not outlive the test.
	defer func() {
		testMQTTShutdownServer(s2)
		select {
		case <-startDone:
		case <-time.After(5 * time.Second):
		}
	}()

	select {
	case e := <-l.fatalCh:
		if !strings.Contains(e, "Unable to listen for MQTT connections") {
			t.Fatalf("Unexpected error: %q", e)
		}
	case <-time.After(time.Second):
		t.Fatal("Should have gotten a fatal error")
	}
}
// TestMQTTTLS checks TLS on the MQTT listener in three phases:
//  1. TLS without client verification: a client handshake succeeds.
//  2. TLS with client verification: the handshake fails when the client
//     presents no certificate and succeeds when it presents one.
//  3. The phase-2 options with TLSTimeout lowered to 0.001 (presumably
//     seconds): a handshake started after a 100ms pause must fail.
//
// Each phase closes its connection and shuts its server down explicitly
// before the next phase starts; the deferred calls run those closes and
// shutdowns again when the test returns.
func TestMQTTTLS(t *testing.T) {
// Phase 1: no client certificate verification.
o := testMQTTDefaultTLSOptions(t, false)
s := testMQTTRunServer(t, o)
defer testMQTTShutdownServer(s)
nc, err := net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port))
if err != nil {
t.Fatalf("Unable to create tcp connection to mqtt port: %v", err)
}
defer nc.Close()
// Set MaxVersion to TLSv1.2 so that we fail on handshake if there is
// a disagreement between server and client.
tlsc := &tls.Config{
MaxVersion: tls.VersionTLS12,
InsecureSkipVerify: true,
}
tlsConn := tls.Client(nc, tlsc)
tlsConn.SetDeadline(time.Now().Add(time.Second))
if err := tlsConn.Handshake(); err != nil {
t.Fatalf("Error doing tls handshake: %v", err)
}
nc.Close()
testMQTTShutdownServer(s)
// Phase 2: force client cert verification
o = testMQTTDefaultTLSOptions(t, true)
s = testMQTTRunServer(t, o)
defer testMQTTShutdownServer(s)
nc, err = net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port))
if err != nil {
t.Fatalf("Unable to create tcp connection to mqtt port: %v", err)
}
defer nc.Close()
// Set MaxVersion to TLSv1.2 so that we fail on handshake if there is
// a disagreement between server and client.
tlsc = &tls.Config{
MaxVersion: tls.VersionTLS12,
InsecureSkipVerify: true,
}
tlsConn = tls.Client(nc, tlsc)
tlsConn.SetDeadline(time.Now().Add(time.Second))
if err := tlsConn.Handshake(); err == nil {
t.Fatal("Handshake expected to fail since client did not provide cert")
}
nc.Close()
// Add client cert.
nc, err = net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port))
if err != nil {
t.Fatalf("Unable to create tcp connection to mqtt port: %v", err)
}
defer nc.Close()
tc := &TLSConfigOpts{
CertFile: "../test/configs/certs/client-cert.pem",
KeyFile: "../test/configs/certs/client-key.pem",
}
tlsc, err = GenTLSConfig(tc)
if err != nil {
t.Fatalf("Error generating tls config: %v", err)
}
tlsc.InsecureSkipVerify = true
tlsConn = tls.Client(nc, tlsc)
tlsConn.SetDeadline(time.Now().Add(time.Second))
if err := tlsConn.Handshake(); err != nil {
t.Fatalf("Handshake error: %v", err)
}
nc.Close()
testMQTTShutdownServer(s)
// Phase 3: lower TLS timeout so low that we should fail.
// The verifying options o and the client-cert config tlsc are reused.
o.MQTT.TLSTimeout = 0.001
s = testMQTTRunServer(t, o)
defer testMQTTShutdownServer(s)
nc, err = net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port))
if err != nil {
t.Fatalf("Unable to create tcp connection to mqtt port: %v", err)
}
defer nc.Close()
// Wait past the configured TLS timeout before starting the handshake.
time.Sleep(100 * time.Millisecond)
tlsConn = tls.Client(nc, tlsc)
tlsConn.SetDeadline(time.Now().Add(time.Second))
if err := tlsConn.Handshake(); err == nil {
t.Fatal("Expected failure, did not get one")
}
}
// mqttConnInfo describes the CONNECT packet that mqttCreateConnectProto
// builds for the test helpers in this file.
type mqttConnInfo struct {
	clientID   string    // client identifier, always written to the payload
	cleanSess  bool      // sets the clean session connect flag
	keepAlive  uint16    // keep-alive value written as-is
	will       *mqttWill // optional will; nil leaves the will flag unset
	user, pass string    // optional credentials; empty leaves the flag unset
}
// testMQTTGetClient returns the MQTT client on s whose CONNECT carried
// clientID, failing the test if there is none. The server lock is held for
// the scan, and each client's lock is held while its fields are inspected.
func testMQTTGetClient(t testing.TB, s *Server, clientID string) *client {
	t.Helper()

	// matches must be called with c.mu held.
	matches := func(c *client) bool {
		return c.isMqtt() && c.mqtt.cp != nil && c.mqtt.cp.clientID == clientID
	}

	var found *client
	s.mu.Lock()
	for _, c := range s.clients {
		c.mu.Lock()
		ok := matches(c)
		c.mu.Unlock()
		if ok {
			found = c
			break
		}
	}
	s.mu.Unlock()

	if found == nil {
		t.Fatalf("Did not find client %q", clientID)
	}
	return found
}
// testMQTTRead performs a single Read of up to 512 bytes from c, bounded by
// testMQTTTimeout, and returns a copy of what was read. The read deadline is
// cleared only after a successful read; on error it is left in place.
func testMQTTRead(c net.Conn) ([]byte, error) {
	var scratch [512]byte
	// Bound the read so the test cannot block forever.
	c.SetReadDeadline(time.Now().Add(testMQTTTimeout))
	n, err := c.Read(scratch[:])
	if err == nil {
		c.SetReadDeadline(time.Time{})
		return copyBytes(scratch[:n]), nil
	}
	return nil, err
}
// testMQTTWrite writes buf to c with a write deadline of testMQTTTimeout,
// clears the deadline afterwards, and returns Write's results unchanged.
func testMQTTWrite(c net.Conn, buf []byte) (int, error) {
	c.SetWriteDeadline(time.Now().Add(testMQTTTimeout))
	defer c.SetWriteDeadline(time.Time{})
	return c.Write(buf)
}
// testMQTTConnect dials host:port over TCP, sends the CONNECT packet built
// from ci, and reads the first response. It returns the connection and an
// mqttReader over that connection preloaded with the bytes read. The caller
// is expected to check the CONNACK, e.g. with testMQTTCheckConnAck.
func testMQTTConnect(t testing.TB, ci *mqttConnInfo, host string, port int) (net.Conn, *mqttReader) {
	t.Helper()

	conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", host, port))
	if err != nil {
		t.Fatalf("Error creating mqtt connection: %v", err)
	}
	if _, err = testMQTTWrite(conn, mqttCreateConnectProto(ci)); err != nil {
		t.Fatalf("Error writing connect: %v", err)
	}
	resp, err := testMQTTRead(conn)
	if err != nil {
		t.Fatalf("Error reading: %v", err)
	}

	reader := &mqttReader{reader: conn}
	reader.reset(resp)
	return conn, reader
}
// mqttCreateConnectProto encodes ci as an MQTT CONNECT packet. The layout is:
// the packet type byte and the variable-length remaining length; then the
// protocol name, protocol level byte 0x4, connect flags and keep-alive; then
// the client ID followed by the optional will topic/message, username and
// password, each written with a 2-byte length prefix.
func mqttCreateConnectProto(ci *mqttConnInfo) []byte {
	var flags byte
	// Protocol name (length prefix + bytes), level (1), flags (1),
	// keep-alive (2) and the length-prefixed client ID.
	remaining := 2 + len(mqttProtoName) + 1 + 1 + 2 + 2 + len(ci.clientID)

	if ci.cleanSess {
		flags |= mqttConnFlagCleanSession
	}
	if will := ci.will; will != nil {
		// The will QoS is shifted into bits 3-4 of the connect flags.
		flags |= mqttConnFlagWillFlag | (will.qos << 3)
		if will.retain {
			flags |= mqttConnFlagWillRetain
		}
		remaining += 2 + len(will.topic) + 2 + len(will.message)
	}
	if ci.user != _EMPTY_ {
		flags |= mqttConnFlagUsernameFlag
		remaining += 2 + len(ci.user)
	}
	if ci.pass != _EMPTY_ {
		flags |= mqttConnFlagPasswordFlag
		remaining += 2 + len(ci.pass)
	}

	pkt := &mqttWriter{}
	pkt.WriteByte(mqttPacketConnect)
	pkt.WriteVarInt(remaining)
	pkt.WriteString(string(mqttProtoName))
	pkt.WriteByte(0x4)
	pkt.WriteByte(flags)
	pkt.WriteUint16(ci.keepAlive)
	pkt.WriteString(ci.clientID)
	if will := ci.will; will != nil {
		pkt.WriteBytes(will.topic)
		pkt.WriteBytes(will.message)
	}
	if ci.user != _EMPTY_ {
		pkt.WriteString(ci.user)
	}
	if ci.pass != _EMPTY_ {
		pkt.WriteBytes([]byte(ci.pass))
	}
	return pkt.Bytes()
}
// testMQTTCheckConnAck reads a CONNACK from r and fails the test unless it is
// well formed (packet type CONNACK, remaining length 2, flag bits 7-1 zero)
// and carries the expected return code rc and session-present flag.
func testMQTTCheckConnAck(t testing.TB, r *mqttReader, rc byte, sessionPresent bool) {
	t.Helper()
	// A CONNACK is 4 bytes: type, length (2), flags, return code. Bound the
	// wait for them with testMQTTTimeout.
	r.reader.SetReadDeadline(time.Now().Add(testMQTTTimeout))
	if err := r.ensurePacketInBuffer(4); err != nil {
		t.Fatalf("Error ensuring packet in buffer: %v", err)
	}
	r.reader.SetReadDeadline(time.Time{})

	b, err := r.readByte("connack packet type")
	if err != nil {
		t.Fatalf("Error reading packet type: %v", err)
	}
	pt := b & mqttPacketMask
	if pt != mqttPacketConnectAck {
		t.Fatalf("Expected ConnAck (%x), got %x", mqttPacketConnectAck, pt)
	}

	pl, err := r.readByte("connack packet len")
	if err != nil {
		t.Fatalf("Error reading packet length: %v", err)
	}
	if pl != 2 {
		t.Fatalf("ConnAck packet length should be 2, got %v", pl)
	}

	caf, err := r.readByte("connack flags")
	if err != nil {
		t.Fatalf("Error reading connack flags: %v", err)
	}
	if caf&0xfe != 0 {
		t.Fatalf("ConnAck flag bits 7-1 should all be 0, got %x", caf>>1)
	}
	// With bits 7-1 verified zero, bit 0 is the session-present flag.
	if sp := caf == 1; sp != sessionPresent {
		t.Fatalf("Expected session present flag=%v got %v", sessionPresent, sp)
	}

	carc, err := r.readByte("connack return code")
	if err != nil {
		t.Fatalf("Error reading returned code: %v", err)
	}
	if carc != rc {
		t.Fatalf("Expected return code to be %v, got %v", rc, carc)
	}
}
// TestMQTTRequiresJSEnabled checks that an MQTT client authenticated into an
// account without JetStream enabled gets no response to its CONNECT: reading
// after the CONNECT must return an error (connection closed or read timeout).
func TestMQTTRequiresJSEnabled(t *testing.T) {
	o := testMQTTDefaultOptions()
	mqttAcc := NewAccount("mqtt")
	o.Accounts = []*Account{mqttAcc}
	o.Users = []*User{{Username: "mqtt", Account: mqttAcc}}
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)

	conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port))
	if err != nil {
		t.Fatalf("Error creating mqtt connection: %v", err)
	}
	defer conn.Close()

	connect := mqttCreateConnectProto(&mqttConnInfo{cleanSess: true, user: "mqtt"})
	if _, err := testMQTTWrite(conn, connect); err != nil {
		t.Fatalf("Error writing connect: %v", err)
	}
	if _, err := testMQTTRead(conn); err == nil {
		t.Fatal("Expected failure, did not get one")
	}
}
// testMQTTEnableJSForAccount enables JetStream for the account named accName
// on s, with MaxStreams/MaxConsumers -1 (presumably no limit) and 1MB of
// memory storage. It fails the test if the account cannot be found or
// enabled.
func testMQTTEnableJSForAccount(t *testing.T, s *Server, accName string) {
	t.Helper()
	acc, err := s.LookupAccount(accName)
	if err != nil {
		t.Fatalf("Error looking up account: %v", err)
	}
	err = acc.EnableJetStream(&JetStreamAccountLimits{
		MaxMemory:    1024 * 1024,
		MaxStreams:   -1,
		MaxConsumers: -1,
	})
	if err != nil {
		t.Fatalf("Error enabling JS: %v", err)
	}
}
// TestMQTTTLSVerifyAndMap checks MQTT TLS with client verification and
// certificate-to-user mapping (TLSMap). When the client presents a cert, the
// CONNECT must be accepted and the client bound to the user named after the
// cert subject and to that user's account. Without a cert, the client-side
// handshake still completes but the first read must fail with "bad
// certificate". The filtering cases also restrict the user to the STANDARD
// and MQTT connection types.
func TestMQTTTLSVerifyAndMap(t *testing.T) {
	accName := "MyAccount"
	certUserName := "CN=example.com,OU=NATS.io"
	for _, test := range []struct {
		name        string
		filtering   bool
		provideCert bool
	}{
		{"no filtering, client provides cert", false, true},
		{"no filtering, client does not provide cert", false, false},
		{"filtering, client provides cert", true, true},
		{"filtering, client does not provide cert", true, false},
	} {
		t.Run(test.name, func(t *testing.T) {
			// Build the account and user per subtest. The filtering cases set
			// AllowedConnectionTypes on the user; sharing one *User across
			// subtests would leak that setting into later cases.
			acc := NewAccount(accName)
			user := &User{Username: certUserName, Account: acc}
			if test.filtering {
				user.AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeMqtt})
			}
			o := testMQTTDefaultOptions()
			o.Host = "localhost"
			o.Accounts = []*Account{acc}
			o.Users = []*User{user}

			tlsc, err := GenTLSConfig(&TLSConfigOpts{
				CertFile: "../test/configs/certs/tlsauth/server.pem",
				KeyFile:  "../test/configs/certs/tlsauth/server-key.pem",
				CaFile:   "../test/configs/certs/tlsauth/ca.pem",
				Verify:   true,
			})
			if err != nil {
				t.Fatalf("Error creating tls config: %v", err)
			}
			o.MQTT.TLSConfig = tlsc
			o.MQTT.TLSTimeout = 2.0
			o.MQTT.TLSMap = true
			s := testMQTTRunServer(t, o)
			defer testMQTTShutdownServer(s)
			testMQTTEnableJSForAccount(t, s, accName)

			addr := fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)
			mc, err := net.Dial("tcp", addr)
			if err != nil {
				t.Fatalf("Error creating mqtt connection: %v", err)
			}
			defer mc.Close()

			tlscc := &tls.Config{}
			if test.provideCert {
				tlscc, err = GenTLSConfig(&TLSConfigOpts{
					CertFile: "../test/configs/certs/tlsauth/client.pem",
					KeyFile:  "../test/configs/certs/tlsauth/client-key.pem",
				})
				if err != nil {
					t.Fatalf("Error generating tls config: %v", err)
				}
				tlscc.MinVersion = tls.VersionTLS13
			}
			tlscc.InsecureSkipVerify = true
			mc = tls.Client(mc, tlscc)
			if err := mc.(*tls.Conn).Handshake(); err != nil {
				t.Fatalf("Error during handshake: %v", err)
			}

			ci := &mqttConnInfo{cleanSess: true}
			if _, err := testMQTTWrite(mc, mqttCreateConnectProto(ci)); err != nil {
				t.Fatalf("Error sending proto: %v", err)
			}
			buf, err := testMQTTRead(mc)
			if !test.provideCert {
				if err == nil {
					t.Fatal("Expected error, did not get one")
				} else if !strings.Contains(err.Error(), "bad certificate") {
					t.Fatalf("Unexpected error: %v", err)
				}
				return
			}
			if err != nil {
				t.Fatalf("Error reading: %v", err)
			}
			r := &mqttReader{reader: mc}
			r.reset(buf)
			testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)

			// The CONNECT carried no client ID, so locate the (only) MQTT client.
			var c *client
			s.mu.Lock()
			for _, sc := range s.clients {
				sc.mu.Lock()
				if sc.isMqtt() {
					c = sc
				}
				sc.mu.Unlock()
				if c != nil {
					break
				}
			}
			s.mu.Unlock()
			if c == nil {
				t.Fatal("Client not found")
			}

			var uname, accname string
			c.mu.Lock()
			uname = c.opts.Username
			if c.acc != nil {
				accname = c.acc.GetName()
			}
			c.mu.Unlock()
			if uname != certUserName {
				t.Fatalf("Expected username %q, got %q", certUserName, uname)
			}
			if accname != accName {
				t.Fatalf("Expected account %q, got %v", accName, accname)
			}
		})
	}
}
func TestMQTTBasicAuth(t *testing.T) {
for _, test := range []struct {
name string
opts func() *Options
user string
pass string
rc byte
}{
{
"top level auth, no override, wrong u/p",
func() *Options {
o := testMQTTDefaultOptions()
o.Username = "normal"
o.Password = "client"
return o
},
"mqtt", "client", mqttConnAckRCNotAuthorized,
},
{
"top level auth, no override, correct u/p",
func() *Options {
o := testMQTTDefaultOptions()
o.Username = "normal"
o.Password = "client"
return o
},
"normal", "client", mqttConnAckRCConnectionAccepted,
},
{
"no top level auth, mqtt auth, wrong u/p",
func() *Options {
o := testMQTTDefaultOptions()
o.MQTT.Username = "mqtt"
o.MQTT.Password = "client"
return o
},
"normal", "client", mqttConnAckRCNotAuthorized,
},
{
"no top level auth, mqtt auth, correct u/p",
func() *Options {
o := testMQTTDefaultOptions()
o.MQTT.Username = "mqtt"
o.MQTT.Password = "client"
return o
},
"mqtt", "client", mqttConnAckRCConnectionAccepted,
},
{
"top level auth, mqtt override, wrong u/p",
func() *Options {
o := testMQTTDefaultOptions()
o.Username = "normal"
o.Password = "client"
o.MQTT.Username = "mqtt"
o.MQTT.Password = "client"
return o
},
"normal", "client", mqttConnAckRCNotAuthorized,
},
{
"top level auth, mqtt override, correct u/p",
func() *Options {
o := testMQTTDefaultOptions()
o.Username = "normal"
o.Password = "client"
o.MQTT.Username = "mqtt"
o.MQTT.Password = "client"
return o
},
"mqtt", "client", mqttConnAckRCConnectionAccepted,
},
} {
t.Run(test.name, func(t *testing.T) {
o := test.opts()
s := testMQTTRunServer(t, o)
defer testMQTTShutdownServer(s)
ci := &mqttConnInfo{
cleanSess: true,
user: test.user,
pass: test.pass,
}
mc, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port)
defer mc.Close()
testMQTTCheckConnAck(t, r, test.rc, false)
})
}
}
// TestMQTTAuthTimeout checks which auth timeout applies to MQTT clients. The
// client waits 100ms after dialing before sending CONNECT. With only the
// top-level timeout (0.5) the connection must be accepted and still usable
// for a publish 500ms later. With an MQTT-specific timeout of 0.05, the
// server is expected to drop the connection first, so write/read errors are
// tolerated. (Assumes AuthTimeout values are seconds.)
func TestMQTTAuthTimeout(t *testing.T) {
	for _, test := range []struct {
		name string
		at   float64 // top-level AuthTimeout
		mat  float64 // MQTT AuthTimeout (0 = not set)
		ok   bool    // whether the connection must survive
	}{
		{"use top-level auth timeout", 0.5, 0.0, true},
		{"use mqtt auth timeout", 0.5, 0.05, false},
	} {
		t.Run(test.name, func(t *testing.T) {
			o := testMQTTDefaultOptions()
			o.AuthTimeout = test.at
			o.MQTT.Username = "mqtt"
			o.MQTT.Password = "client"
			o.MQTT.AuthTimeout = test.mat
			s := testMQTTRunServer(t, o)
			defer testMQTTShutdownServer(s)

			mc, err := net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port))
			if err != nil {
				t.Fatalf("Error connecting: %v", err)
			}
			defer mc.Close()

			time.Sleep(100 * time.Millisecond)

			// stop reports whether the subtest must end because of err. An
			// error fails the test only when the connection should survive;
			// otherwise it is the expected result of the auth timeout.
			stop := func(format string, err error) bool {
				if err == nil {
					return false
				}
				if test.ok {
					t.Fatalf(format, err)
				}
				return true
			}

			proto := mqttCreateConnectProto(&mqttConnInfo{cleanSess: true, user: "mqtt", pass: "client"})
			if _, err := testMQTTWrite(mc, proto); stop("Error sending connect: %v", err) {
				return
			}
			buf, err := testMQTTRead(mc)
			if stop("Error reading: %v", err) {
				return
			}

			r := &mqttReader{reader: mc}
			r.reset(buf)
			testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
			time.Sleep(500 * time.Millisecond)
			testMQTTPublish(t, mc, r, 1, false, false, "foo", 1, []byte("msg"))
		})
	}
}
func TestMQTTTokenAuth(t *testing.T) {
for _, test := range []struct {
name string
opts func() *Options
token string
rc byte
}{
{
"top level auth, no override, wrong token",
func() *Options {
o := testMQTTDefaultOptions()
o.Authorization = "goodtoken"
return o
},
"badtoken", mqttConnAckRCNotAuthorized,
},
{
"top level auth, no override, correct token",
func() *Options {
o := testMQTTDefaultOptions()
o.Authorization = "goodtoken"
return o
},
"goodtoken", mqttConnAckRCConnectionAccepted,
},
{
"no top level auth, mqtt auth, wrong token",
func() *Options {
o := testMQTTDefaultOptions()
o.MQTT.Token = "goodtoken"
return o
},
"badtoken", mqttConnAckRCNotAuthorized,
},
{
"no top level auth, mqtt auth, correct token",
func() *Options {
o := testMQTTDefaultOptions()
o.MQTT.Token = "goodtoken"
return o
},
"goodtoken", mqttConnAckRCConnectionAccepted,
},
{
"top level auth, mqtt override, wrong token",
func() *Options {
o := testMQTTDefaultOptions()
o.Authorization = "clienttoken"
o.MQTT.Token = "mqtttoken"
return o
},
"clienttoken", mqttConnAckRCNotAuthorized,
},
{
"top level auth, mqtt override, correct token",
func() *Options {
o := testMQTTDefaultOptions()
o.Authorization = "clienttoken"
o.MQTT.Token = "mqtttoken"
return o
},
"mqtttoken", mqttConnAckRCConnectionAccepted,
},
} {
t.Run(test.name, func(t *testing.T) {
o := test.opts()
s := testMQTTRunServer(t, o)
defer testMQTTShutdownServer(s)
ci := &mqttConnInfo{
cleanSess: true,
user: "ignore_use_token",
pass: test.token,
}
mc, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port)
defer mc.Close()
testMQTTCheckConnAck(t, r, test.rc, false)
})
}
}
// TestMQTTJWTWithAllowedConnectionTypes checks that, with an operator-signed
// memory resolver, the AllowedConnectionTypes of a user JWT decide whether an
// MQTT client may connect. MQTT must be listed (a lowercase "mqtt" is also
// accepted), and unknown types neither grant nor deny access on their own.
// The user JWT is a bearer token sent as the CONNECT password.
func TestMQTTJWTWithAllowedConnectionTypes(t *testing.T) {
	o := testMQTTDefaultOptions()
	// Create System Account
	syskp, _ := nkeys.CreateAccount()
	syspub, _ := syskp.PublicKey()
	sysAc := jwt.NewAccountClaims(syspub)
	sysjwt, err := sysAc.Encode(oKp)
	if err != nil {
		t.Fatalf("Error generating account JWT: %v", err)
	}
	// Create memory resolver and store system account. Check Store's own
	// error rather than re-checking the one from Encode above.
	mr := &MemAccResolver{}
	if err := mr.Store(syspub, sysjwt); err != nil {
		t.Fatalf("Error saving system account JWT to memory resolver: %v", err)
	}
	// Add system account and memory resolver to server options
	o.SystemAccount = syspub
	o.AccountResolver = mr
	setupAddTrusted(o)
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)

	for _, test := range []struct {
		name            string
		connectionTypes []string
		rc              byte
	}{
		{"not allowed", []string{jwt.ConnectionTypeStandard}, mqttConnAckRCNotAuthorized},
		{"allowed", []string{jwt.ConnectionTypeStandard, strings.ToLower(jwt.ConnectionTypeMqtt)}, mqttConnAckRCConnectionAccepted},
		{"allowed with unknown", []string{jwt.ConnectionTypeMqtt, "SomeNewType"}, mqttConnAckRCConnectionAccepted},
		{"not allowed with unknown", []string{"SomeNewType"}, mqttConnAckRCNotAuthorized},
	} {
		t.Run(test.name, func(t *testing.T) {
			nuc := newJWTTestUserClaims()
			nuc.AllowedConnectionTypes = test.connectionTypes
			nuc.BearerToken = true

			okp, _ := nkeys.FromSeed(oSeed)
			akp, _ := nkeys.CreateAccount()
			apub, _ := akp.PublicKey()
			nac := jwt.NewAccountClaims(apub)
			// Enable Jetstream on account with lax limitations
			nac.Limits.JetStreamLimits.Consumer = -1
			nac.Limits.JetStreamLimits.Streams = -1
			nac.Limits.JetStreamLimits.MemoryStorage = 1024 * 1024
			ajwt, err := nac.Encode(okp)
			if err != nil {
				t.Fatalf("Error generating account JWT: %v", err)
			}

			nkp, _ := nkeys.CreateUser()
			pub, _ := nkp.PublicKey()
			nuc.Subject = pub
			// Named ujwt so it does not shadow the jwt package.
			ujwt, err := nuc.Encode(akp)
			if err != nil {
				t.Fatalf("Error generating user JWT: %v", err)
			}
			addAccountToMemResolver(s, apub, ajwt)

			ci := &mqttConnInfo{
				cleanSess: true,
				user:      "ignore_use_token",
				pass:      ujwt,
			}
			mc, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port)
			defer mc.Close()
			testMQTTCheckConnAck(t, r, test.rc, false)
		})
	}
}
func TestMQTTUsersAuth(t *testing.T) {
users := []*User{{Username: "user", Password: "pwd"}}
for _, test := range []struct {
name string
opts func() *Options
user string
pass string
rc byte
}{
{
"no filtering, wrong user",
func() *Options {
o := testMQTTDefaultOptions()
o.Users = users
return o
},
"wronguser", "pwd", mqttConnAckRCNotAuthorized,
},
{
"no filtering, correct user",
func() *Options {
o := testMQTTDefaultOptions()
o.Users = users
return o
},
"user", "pwd", mqttConnAckRCConnectionAccepted,
},
{
"filtering, user not allowed",
func() *Options {
o := testMQTTDefaultOptions()
o.Users = users
// Only allowed for regular clients
o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard})
return o
},
"user", "pwd", mqttConnAckRCNotAuthorized,
},
{
"filtering, user allowed",
func() *Options {
o := testMQTTDefaultOptions()
o.Users = users
o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeMqtt})
return o
},
"user", "pwd", mqttConnAckRCConnectionAccepted,
},
{
"filtering, wrong password",
func() *Options {
o := testMQTTDefaultOptions()
o.Users = users
o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeMqtt})
return o
},
"user", "badpassword", mqttConnAckRCNotAuthorized,
},
} {
t.Run(test.name, func(t *testing.T) {
o := test.opts()
s := testMQTTRunServer(t, o)
defer testMQTTShutdownServer(s)
ci := &mqttConnInfo{
cleanSess: true,
user: test.user,
pass: test.pass,
}
mc, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port)
defer mc.Close()
testMQTTCheckConnAck(t, r, test.rc, false)
})
}
}
// TestMQTTNoAuthUserValidation checks that server creation is refused when
// mqtt.no_auth_user names a user that is not configured, even if the global
// no_auth_user is itself valid.
func TestMQTTNoAuthUserValidation(t *testing.T) {
	o := testMQTTDefaultOptions()
	o.Users = []*User{{Username: "user", Password: "pwd"}}

	expectNotPresent := func() {
		t.Helper()
		_, err := NewServer(o)
		if err == nil || !strings.Contains(err.Error(), "not present as user") {
			t.Fatalf("Expected error saying not present as user, got %v", err)
		}
	}

	// "notfound" is not part of o.Users.
	o.MQTT.NoAuthUser = "notfound"
	expectNotPresent()

	// A valid global no-auth user must not hide the invalid MQTT override.
	o.NoAuthUser = "user"
	o.MQTT.NoAuthUser = "notfound"
	expectNotPresent()
}
// TestMQTTNoAuthUser checks which user and account an MQTT client ends up
// bound to, depending on whether it provides credentials and whether
// o.MQTT.NoAuthUser overrides the global o.NoAuthUser.
func TestMQTTNoAuthUser(t *testing.T) {
	for _, test := range []struct {
		name         string
		override     bool // when true, o.MQTT.NoAuthUser is set to "mqttnoauth"
		useAuth      bool // when true, connect as expectedUser with password "pwd"
		expectedUser string
		expectedAcc  string
	}{
		{"no override, no user provided", false, false, "noauth", "normal"},
		{"no override, user provided", false, true, "user", "normal"},
		{"override, no user provided", true, false, "mqttnoauth", "mqtt"},
		{"override, user provided", true, true, "mqttuser", "mqtt"},
	} {
		t.Run(test.name, func(t *testing.T) {
			o := testMQTTDefaultOptions()
			normalAcc := NewAccount("normal")
			mqttAcc := NewAccount("mqtt")
			o.Accounts = []*Account{normalAcc, mqttAcc}
			o.Users = []*User{
				{Username: "noauth", Password: "pwd", Account: normalAcc},
				{Username: "user", Password: "pwd", Account: normalAcc},
				{Username: "mqttnoauth", Password: "pwd", Account: mqttAcc},
				{Username: "mqttuser", Password: "pwd", Account: mqttAcc},
			}
			o.NoAuthUser = "noauth"
			if test.override {
				o.MQTT.NoAuthUser = "mqttnoauth"
			}
			s := testMQTTRunServer(t, o)
			defer testMQTTShutdownServer(s)
			testMQTTEnableJSForAccount(t, s, "normal")
			testMQTTEnableJSForAccount(t, s, "mqtt")

			ci := &mqttConnInfo{clientID: "mqtt", cleanSess: true}
			if test.useAuth {
				ci.user = test.expectedUser
				ci.pass = "pwd"
			}
			mc, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port)
			defer mc.Close()
			testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)

			// Read the bound user/account under the client lock.
			c := testMQTTGetClient(t, s, "mqtt")
			c.mu.Lock()
			uname := c.opts.Username
			aname := c.acc.GetName()
			c.mu.Unlock()
			if uname != test.expectedUser {
				t.Fatalf("Expected selected user to be %q, got %q", test.expectedUser, uname)
			}
			if aname != test.expectedAcc {
				t.Fatalf("Expected selected account to be %q, got %q", test.expectedAcc, aname)
			}
		})
	}
}
// TestMQTTConnectNotFirstPacket sends a PUBLISH before any CONNECT. It checks
// that the server drops the connection and logs an error saying the first
// packet should have been a CONNECT.
func TestMQTTConnectNotFirstPacket(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)

	logger := &captureErrorLogger{errCh: make(chan string, 10)}
	s.SetLogger(logger, false, false)

	addr := fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)
	conn, err := net.Dial("tcp", addr)
	if err != nil {
		t.Fatalf("Error on dial: %v", err)
	}
	defer conn.Close()

	// Skip the CONNECT entirely and go straight to a QoS 0 PUBLISH.
	pub := &mqttWriter{}
	mqttWritePublish(pub, 0, false, false, "foo", 0, []byte("hello"))
	if _, err := testMQTTWrite(conn, pub.Bytes()); err != nil {
		t.Fatalf("Error publishing: %v", err)
	}
	testMQTTExpectDisconnect(t, conn)

	timeout := time.After(time.Second)
	select {
	case msg := <-logger.errCh:
		if !strings.Contains(msg, "should be a CONNECT") {
			t.Fatalf("Expected error about first packet being a CONNECT, got %v", msg)
		}
	case <-timeout:
		t.Fatal("Did not log any error")
	}
}
// TestMQTTSecondConnect checks that the server closes the connection when an
// already-connected MQTT client sends another CONNECT packet.
func TestMQTTSecondConnect(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)

	mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
	defer mc.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)

	// Second CONNECT on the same, already accepted, connection.
	again := mqttCreateConnectProto(&mqttConnInfo{cleanSess: true})
	_, err := testMQTTWrite(mc, again)
	if err != nil {
		t.Fatalf("Error writing connect: %v", err)
	}
	testMQTTExpectDisconnect(t, mc)
}
// TestMQTTParseConnect feeds malformed or truncated CONNECT bodies to
// mqttParseConnect and checks that each one is rejected with an error
// containing the expected text. Cases that use eofr make the underlying
// reader return EOF once the provided bytes are exhausted.
func TestMQTTParseConnect(t *testing.T) {
	eofr := testNewEOFReader()
	type parseConnectCase struct {
		name   string
		proto  []byte       // bytes preloaded in the reader
		pl     int          // packet length handed to the parser
		reader mqttIOReader // fallback reader when more bytes are needed
		err    string       // expected error substring
	}
	// NOTE(review): "invalide utf8 for Will topic" below is presumably spelled
	// to match the server's error text; confirm before correcting it.
	cases := []parseConnectCase{
		{"packet in buffer error", nil, 10, eofr, "error ensuring protocol is loaded"},
		{"bad proto name", []byte{0, 4, 'B', 'A', 'D'}, 5, nil, "protocol name"},
		{"invalid proto name", []byte{0, 3, 'B', 'A', 'D'}, 5, nil, "expected connect packet with protocol name"},
		{"old proto not supported", []byte{0, 6, 'M', 'Q', 'I', 's', 'd', 'p'}, 8, nil, "older protocol"},
		{"error on protocol level", []byte{0, 4, 'M', 'Q', 'T', 'T'}, 6, eofr, "protocol level"},
		{"unacceptable protocol version", []byte{0, 4, 'M', 'Q', 'T', 'T', 10}, 7, nil, "unacceptable protocol version"},
		{"error on flags", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel}, 7, eofr, "flags"},
		{"reserved flag", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 1}, 8, nil, "connect flags reserved bit not set to 0"},
		{"will qos without will flag", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 1 << 3}, 8, nil, "if Will flag is set to 0, Will QoS must be 0 too"},
		{"will retain without will flag", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 1 << 5}, 8, nil, "if Will flag is set to 0, Will Retain flag must be 0 too"},
		{"will qos", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 3<<3 | 1<<2}, 8, nil, "if Will flag is set to 1, Will QoS can be 0, 1 or 2"},
		{"no user but password", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagPasswordFlag}, 8, nil, "password flag set but username flag is not"},
		{"missing keep alive", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 0}, 8, nil, "keep alive"},
		{"missing client ID", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 0, 0, 1}, 10, nil, "client ID"},
		{"empty client ID", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 0, 0, 1, 0, 0}, 12, nil, "when client ID is empty, clean session flag must be set to 1"},
		{"invalid utf8 client ID", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, 0, 0, 1, 0, 1, 241}, 13, nil, "invalid utf8 for client ID"},
		{"missing will topic", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagWillFlag | mqttConnFlagCleanSession, 0, 0, 0, 0}, 12, nil, "Will topic"},
		{"empty will topic", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagWillFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 0}, 14, nil, "empty Will topic not allowed"},
		{"invalid utf8 will topic", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagWillFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 1, 241}, 15, nil, "invalide utf8 for Will topic"},
		{"invalid wildcard will topic", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagWillFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 1, '#'}, 15, nil, "wildcards not allowed"},
		{"error on will message", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagWillFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 1, 'a', 0, 3}, 17, eofr, "Will message"},
		{"error on username", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagUsernameFlag | mqttConnFlagCleanSession, 0, 0, 0, 0}, 12, eofr, "user name"},
		{"empty username", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagUsernameFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 0}, 14, nil, "empty user name not allowed"},
		{"invalid utf8 username", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagUsernameFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 1, 241}, 15, nil, "invalid utf8 for user name"},
		{"error on password", []byte{0, 4, 'M', 'Q', 'T', 'T', mqttProtoLevel, mqttConnFlagUsernameFlag | mqttConnFlagPasswordFlag | mqttConnFlagCleanSession, 0, 0, 0, 0, 0, 1, 'a'}, 15, eofr, "password"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			rd := &mqttReader{reader: tc.reader}
			rd.reset(tc.proto)
			cli := &client{mqtt: &mqtt{r: rd}}
			_, _, err := cli.mqttParseConnect(rd, tc.pl)
			if err == nil || !strings.Contains(err.Error(), tc.err) {
				t.Fatalf("Expected error %q, got %v", tc.err, err)
			}
		})
	}
}
// TestMQTTConnectFailsOnParse hand-builds a CONNECT packet carrying protocol
// level 0x7 and expects the server to answer with a CONNACK whose return code
// is mqttConnAckRCUnacceptableProtocolVersion.
func TestMQTTConnectFailsOnParse(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)

	addr := fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)
	c, err := net.Dial("tcp", addr)
	if err != nil {
		t.Fatalf("Error creating mqtt connection: %v", err)
	}
	// The raw connection was previously never closed.
	defer c.Close()

	// Remaining length of the CONNECT packet (everything after the fixed header).
	pkLen := 2 + len(mqttProtoName) +
		1 + // proto level
		1 + // flags
		2 + // keepAlive
		2 + len("mqtt")
	w := &mqttWriter{}
	w.WriteByte(mqttPacketConnect)
	w.WriteVarInt(pkLen)
	w.WriteString(string(mqttProtoName))
	w.WriteByte(0x7) // protocol level the server is expected to reject
	w.WriteByte(mqttConnFlagCleanSession)
	w.WriteUint16(0)
	w.WriteString("mqtt")
	if _, err := testMQTTWrite(c, w.Bytes()); err != nil {
		t.Fatalf("Error writing CONNECT: %v", err)
	}

	buf, err := testMQTTRead(c)
	if err != nil {
		t.Fatalf("Error reading: %v", err)
	}
	r := &mqttReader{reader: c}
	r.reset(buf)
	testMQTTCheckConnAck(t, r, mqttConnAckRCUnacceptableProtocolVersion, false)
}
// TestMQTTConnKeepAlive connects with a keep-alive of 1, publishes once, then
// stays silent for 2 seconds. It expects the server to have closed the
// connection.
func TestMQTTConnKeepAlive(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)

	ci := &mqttConnInfo{cleanSess: true, keepAlive: 1}
	mc, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port)
	defer mc.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)

	// Some traffic first, then an idle period longer than the keep-alive.
	testMQTTPublish(t, mc, r, 0, false, false, "foo", 0, []byte("msg"))
	idle := 2 * time.Second
	time.Sleep(idle)
	testMQTTExpectDisconnect(t, mc)
}
// TestMQTTDontSetPinger checks that MQTT clients do not get the server ping
// timer, even with a very short PingInterval. The connection must stay usable.
func TestMQTTDontSetPinger(t *testing.T) {
	o := testMQTTDefaultOptions()
	o.PingInterval = 15 * time.Millisecond
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)

	mc, r := testMQTTConnect(t, &mqttConnInfo{clientID: "mqtt", cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
	defer mc.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)

	c := testMQTTGetClient(t, s, "mqtt")
	hasPingTimer := func() bool {
		c.mu.Lock()
		defer c.mu.Unlock()
		return c.ping.tmr != nil
	}
	if hasPingTimer() {
		t.Fatalf("Ping timer should not be set for MQTT clients")
	}

	// No PINGREQ should show up, and the connection should still accept traffic.
	testMQTTExpectNothing(t, r)
	testMQTTPublish(t, mc, r, 0, false, false, "foo", 0, []byte("msg"))
}
// TestMQTTUnsupportedPackets sends each QoS 2 flow packet (PUBREC, PUBREL,
// PUBCOMP) on a fresh connection. It expects the server, which does not
// support QoS 2, to disconnect the client.
func TestMQTTUnsupportedPackets(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)

	for _, test := range []struct {
		name       string
		packetType byte
	}{
		{"pubrec", mqttPacketPubRec},
		{"pubrel", mqttPacketPubRel},
		{"pubcomp", mqttPacketPubComp},
	} {
		t.Run(test.name, func(t *testing.T) {
			mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
			defer mc.Close()
			testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)

			w := &mqttWriter{}
			pt := test.packetType
			// PUBREL's fixed header must carry flags 0010 (MQTT 3.1.1
			// section 3.6.1), so the packet is otherwise well-formed.
			if test.packetType == mqttPacketPubRel {
				pt |= byte(0x2)
			}
			w.WriteByte(pt)
			w.WriteVarInt(2) // remaining length: the 2-byte packet identifier
			w.WriteUint16(1)
			// Previously the write error was silently ignored.
			if _, err := testMQTTWrite(mc, w.Bytes()); err != nil {
				t.Fatalf("Error writing packet: %v", err)
			}
			testMQTTExpectDisconnect(t, mc)
		})
	}
}
// TestMQTTTopicAndSubjectConversion checks MQTT publish topic -> NATS subject
// conversion and the reverse conversion back to the original topic. It also
// checks that wildcards, spaces and '.' in publish topics are rejected.
func TestMQTTTopicAndSubjectConversion(t *testing.T) {
	for _, test := range []struct {
		name        string
		mqttTopic   string
		natsSubject string
		err         string
	}{
		{"/", "/", "/./", ""},
		{"//", "//", "/././", ""},
		{"///", "///", "/./././", ""},
		{"////", "////", "/././././", ""},
		{"foo", "foo", "foo", ""},
		{"/foo", "/foo", "/.foo", ""},
		{"//foo", "//foo", "/./.foo", ""},
		{"///foo", "///foo", "/././.foo", ""},
		{"///foo/", "///foo/", "/././.foo./", ""},
		{"///foo//", "///foo//", "/././.foo././", ""},
		{"///foo///", "///foo///", "/././.foo./././", ""},
		{"foo/bar", "foo/bar", "foo.bar", ""},
		{"/foo/bar", "/foo/bar", "/.foo.bar", ""},
		{"/foo/bar/", "/foo/bar/", "/.foo.bar./", ""},
		{"foo/bar/baz", "foo/bar/baz", "foo.bar.baz", ""},
		{"/foo/bar/baz", "/foo/bar/baz", "/.foo.bar.baz", ""},
		{"/foo/bar/baz/", "/foo/bar/baz/", "/.foo.bar.baz./", ""},
		// Name matches the topic under test (was "bar").
		{"bar/", "bar/", "bar./", ""},
		{"bar//", "bar//", "bar././", ""},
		{"bar///", "bar///", "bar./././", ""},
		{"foo//bar", "foo//bar", "foo./.bar", ""},
		{"foo///bar", "foo///bar", "foo././.bar", ""},
		{"foo////bar", "foo////bar", "foo./././.bar", ""},
		// These should produce errors
		{"foo/+", "foo/+", "", "wildcards not allowed in publish"},
		{"foo/#", "foo/#", "", "wildcards not allowed in publish"},
		{"foo bar", "foo bar", "", "not supported"},
		{"foo.bar", "foo.bar", "", "not supported"},
	} {
		t.Run(test.name, func(t *testing.T) {
			res, err := mqttTopicToNATSPubSubject([]byte(test.mqttTopic))
			if test.err != _EMPTY_ {
				// Use %v: calling err.Error() here would panic when err is nil,
				// which is exactly the failure being reported.
				if err == nil || !strings.Contains(err.Error(), test.err) {
					t.Fatalf("Expected error %q, got %v", test.err, err)
				}
				return
			}
			if err != nil {
				t.Fatalf("Unexpected error converting %q: %v", test.mqttTopic, err)
			}
			toNATS := string(res)
			if toNATS != test.natsSubject {
				t.Fatalf("Expected subject %q got %q", test.natsSubject, toNATS)
			}

			res = natsSubjectToMQTTTopic(string(res))
			backToMQTT := string(res)
			if backToMQTT != test.mqttTopic {
				t.Fatalf("Expected topic %q got %q (NATS conversion was %q)", test.mqttTopic, backToMQTT, toNATS)
			}
		})
	}
}
// TestMQTTFilterConversion checks MQTT filter -> NATS subject conversion for
// filters using the '+' (single level) and '#' (multi level) wildcards. The
// non-wildcard forms are covered by TestMQTTTopicAndSubjectConversion.
func TestMQTTFilterConversion(t *testing.T) {
	type filterCase struct {
		name        string
		mqttTopic   string
		natsSubject string
	}
	cases := []filterCase{
		{"single level wildcard", "+", "*"},
		{"single level wildcard", "/+", "/.*"},
		{"single level wildcard", "+/", "*./"},
		{"single level wildcard", "/+/", "/.*./"},
		{"single level wildcard", "foo/+", "foo.*"},
		{"single level wildcard", "foo/+/", "foo.*./"},
		{"single level wildcard", "foo/+/bar", "foo.*.bar"},
		{"single level wildcard", "foo/+/+", "foo.*.*"},
		{"single level wildcard", "foo/+/+/", "foo.*.*./"},
		{"single level wildcard", "foo/+/+/bar", "foo.*.*.bar"},
		{"single level wildcard", "foo//+", "foo./.*"},
		{"single level wildcard", "foo//+/", "foo./.*./"},
		{"single level wildcard", "foo//+//", "foo./.*././"},
		{"single level wildcard", "foo//+//bar", "foo./.*./.bar"},
		{"single level wildcard", "foo///+///bar", "foo././.*././.bar"},

		{"multi level wildcard", "#", ">"},
		{"multi level wildcard", "/#", "/.>"},
		{"multi level wildcard", "/foo/#", "/.foo.>"},
		{"multi level wildcard", "foo/#", "foo.>"},
		{"multi level wildcard", "foo//#", "foo./.>"},
		{"multi level wildcard", "foo///#", "foo././.>"},
		{"multi level wildcard", "foo/bar/#", "foo.bar.>"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := mqttFilterToNATSSubject([]byte(tc.mqttTopic))
			switch {
			case err != nil:
				t.Fatalf("Error: %v", err)
			case string(got) != tc.natsSubject:
				t.Fatalf("Expected subject %q got %q", tc.natsSubject, got)
			}
		})
	}
}
// testMQTTReaderHasAtLeastOne waits up to testMQTTTimeout until at least one
// byte is buffered in r, reading from the underlying connection if needed.
// It fails the test on error. On success the read deadline is cleared.
func testMQTTReaderHasAtLeastOne(t testing.TB, r *mqttReader) {
	t.Helper()
	deadline := time.Now().Add(testMQTTTimeout)
	r.reader.SetReadDeadline(deadline)
	err := r.ensurePacketInBuffer(1)
	if err != nil {
		t.Fatal(err)
	}
	r.reader.SetReadDeadline(time.Time{})
}
// TestMQTTParseSub feeds malformed or truncated SUBSCRIBE bodies to
// mqttParseSubsOrUnsubs (in subscribe mode) and checks the error text.
func TestMQTTParseSub(t *testing.T) {
	eofr := testNewEOFReader()
	type parseSubCase struct {
		name   string
		proto  []byte       // bytes preloaded in the reader
		b      byte         // fixed-header flags passed to the parser
		pl     int          // packet length handed to the parser
		reader mqttIOReader // fallback reader when more bytes are needed
		err    string       // expected error substring
	}
	cases := []parseSubCase{
		{"reserved flag", nil, 3, 0, nil, "wrong subscribe reserved flags"},
		{"ensure packet loaded", []byte{1, 2}, mqttSubscribeFlags, 10, eofr, "error ensuring protocol is loaded"},
		{"error reading packet id", []byte{1}, mqttSubscribeFlags, 1, eofr, "reading packet identifier"},
		{"missing filters", []byte{0, 1}, mqttSubscribeFlags, 2, nil, "subscribe protocol must contain at least 1 topic filter"},
		{"error reading topic", []byte{0, 1, 0, 2, 'a'}, mqttSubscribeFlags, 5, eofr, "topic filter"},
		{"empty topic", []byte{0, 1, 0, 0}, mqttSubscribeFlags, 4, nil, errMQTTTopicFilterCannotBeEmpty.Error()},
		{"invalid utf8 topic", []byte{0, 1, 0, 1, 241}, mqttSubscribeFlags, 5, nil, "invalid utf8 for topic filter"},
		{"missing qos", []byte{0, 1, 0, 1, 'a'}, mqttSubscribeFlags, 5, nil, "QoS"},
		{"invalid qos", []byte{0, 1, 0, 1, 'a', 3}, mqttSubscribeFlags, 6, nil, "subscribe QoS value must be 0, 1 or 2"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			rd := &mqttReader{reader: tc.reader}
			rd.reset(tc.proto)
			cli := &client{mqtt: &mqtt{r: rd}}
			_, _, err := cli.mqttParseSubsOrUnsubs(rd, tc.b, tc.pl, true)
			if err == nil || !strings.Contains(err.Error(), tc.err) {
				t.Fatalf("Expected error %q, got %v", tc.err, err)
			}
		})
	}
}
// testMQTTSub sends a SUBSCRIBE with packet identifier pi for the given
// filters on c, then reads the SUBACK from r. It checks the echoed packet
// identifier and that the SUBACK holds exactly len(expected) return codes,
// where expected[i] is the granted QoS (or mqttSubAckFailure) for filters[i].
func testMQTTSub(t testing.TB, pi uint16, c net.Conn, r *mqttReader, filters []*mqttFilter, expected []byte) {
	t.Helper()
	w := &mqttWriter{}
	pkLen := 2 // for pi
	for _, f := range filters {
		// 2-byte length prefix + filter bytes + 1 byte of requested QoS
		pkLen += 2 + len(f.filter) + 1
	}
	w.WriteByte(mqttPacketSub | mqttSubscribeFlags)
	w.WriteVarInt(pkLen)
	w.WriteUint16(pi)
	for _, f := range filters {
		w.WriteBytes([]byte(f.filter))
		w.WriteByte(f.qos)
	}
	if _, err := testMQTTWrite(c, w.Bytes()); err != nil {
		t.Fatalf("Error writing SUBSCRIBE protocol: %v", err)
	}
	// Make sure we have at least 1 byte in buffer (if not will read)
	testMQTTReaderHasAtLeastOne(t, r)
	// Parse SUBACK
	b, err := r.readByte("packet type")
	if err != nil {
		t.Fatal(err)
	}
	if pt := b & mqttPacketMask; pt != mqttPacketSubAck {
		t.Fatalf("Expected SUBACK packet %x, got %x", mqttPacketSubAck, pt)
	}
	pl, err := r.readPacketLen()
	if err != nil {
		t.Fatal(err)
	}
	if err := r.ensurePacketInBuffer(pl); err != nil {
		t.Fatal(err)
	}
	rpi, err := r.readUint16("packet identifier")
	if err != nil || rpi != pi {
		t.Fatalf("Error with packet identifier expected=%v got: %v err=%v", pi, rpi, err)
	}
	// One return code byte follows the 2-byte packet identifier for each
	// filter. Check the count first: a mismatch used to either panic on
	// expected[i] or silently skip some of the expected codes.
	if n := pl - 2; n != len(expected) {
		t.Fatalf("Expected %v return codes in SUBACK, got %v", len(expected), n)
	}
	for i, want := range expected {
		qos, err := r.readByte("filter qos")
		if err != nil {
			t.Fatal(err)
		}
		if qos != want {
			t.Fatalf("For topic filter %q expected qos of %v, got %v",
				filters[i].filter, want, qos)
		}
	}
}
// TestMQTTSubAck checks the return codes the server puts in a SUBACK for
// accepted, downgraded and invalid filters sent in a single SUBSCRIBE.
func TestMQTTSubAck(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)

	mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
	defer mc.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)

	// Each requested filter paired with the return code expected back.
	pairs := []struct {
		filter *mqttFilter
		rc     byte
	}{
		{&mqttFilter{filter: "foo", qos: 0}, 0},
		{&mqttFilter{filter: "bar", qos: 1}, 1},
		// QoS 2 is not supported, so it is granted as QoS 1.
		{&mqttFilter{filter: "baz", qos: 2}, 1},
		// Invalid filter ('#' not in last position): expect a failure code.
		{&mqttFilter{filter: "foo/#/bar", qos: 0}, mqttSubAckFailure},
	}
	subs := make([]*mqttFilter, 0, len(pairs))
	expected := make([]byte, 0, len(pairs))
	for _, p := range pairs {
		subs = append(subs, p.filter)
		expected = append(expected, p.rc)
	}
	testMQTTSub(t, 1, mc, r, subs, expected)
}
// testMQTTFlush sends a PINGREQ and waits for the matching PINGRESP. This
// acts as a round-trip barrier: everything sent before it has been processed
// by the server. When bw is non-nil, the PINGREQ goes through bw (and bw is
// flushed); otherwise it is written directly to c. The PINGRESP is read from
// r, bounded by testMQTTTimeout so a missing reply fails instead of hanging.
func testMQTTFlush(t testing.TB, c net.Conn, bw *bufio.Writer, r *mqttReader) {
	t.Helper()
	w := &mqttWriter{}
	w.WriteByte(mqttPacketPing)
	w.WriteByte(0)
	if bw != nil {
		if _, err := bw.Write(w.Bytes()); err != nil {
			t.Fatalf("Error writing PINGREQ: %v", err)
		}
		if err := bw.Flush(); err != nil {
			t.Fatalf("Error flushing PINGREQ: %v", err)
		}
	} else if _, err := c.Write(w.Bytes()); err != nil {
		t.Fatalf("Error writing PINGREQ: %v", err)
	}
	// PINGRESP is 2 bytes: packet type + zero remaining length.
	r.reader.SetReadDeadline(time.Now().Add(testMQTTTimeout))
	if err := r.ensurePacketInBuffer(2); err != nil {
		t.Fatalf("Error reading ping response: %v", err)
	}
	r.reader.SetReadDeadline(time.Time{})
	ab, err := r.readByte("pingresp")
	if err != nil {
		t.Fatalf("Error reading ping response: %v", err)
	}
	if pt := ab & mqttPacketMask; pt != mqttPacketPingResp {
		t.Fatalf("Expected ping response got %x", pt)
	}
	l, err := r.readPacketLen()
	if err != nil {
		t.Fatal(err)
	}
	if l != 0 {
		t.Fatalf("Expected PINGRESP length to be 0, got %v", l)
	}
}
// testMQTTExpectNothing fails the test if any byte can be buffered in r
// within 100ms. On success the read deadline is cleared.
func testMQTTExpectNothing(t testing.TB, r *mqttReader) {
	t.Helper()
	quiet := 100 * time.Millisecond
	r.reader.SetReadDeadline(time.Now().Add(quiet))
	err := r.ensurePacketInBuffer(1)
	if err == nil {
		t.Fatalf("Expected nothing, got %v", r.buf[r.pos:])
	}
	r.reader.SetReadDeadline(time.Time{})
}
// testMQTTCheckPubMsg reads the next PUBLISH from r and checks its topic,
// payload and flags. If the message has a packet identifier (QoS > 0), it is
// acknowledged with a PUBACK on c.
func testMQTTCheckPubMsg(t testing.TB, c net.Conn, r *mqttReader, topic string, flags byte, payload []byte) {
	t.Helper()
	if pi := testMQTTCheckPubMsgNoAck(t, c, r, topic, flags, payload); pi > 0 {
		testMQTTSendPubAck(t, c, pi)
	}
}
// testMQTTCheckPubMsgNoAck is like testMQTTCheckPubMsg but does not send a
// PUBACK. It returns the packet identifier (0 for QoS 0) so the caller can
// decide when to acknowledge.
func testMQTTCheckPubMsgNoAck(t testing.TB, c net.Conn, r *mqttReader, topic string, flags byte, payload []byte) uint16 {
	t.Helper()
	gotFlags, packetID := testMQTTGetPubMsg(t, c, r, topic, payload)
	if gotFlags != flags {
		t.Fatalf("Expected flags to be %x, got %x", flags, gotFlags)
	}
	return packetID
}
// testMQTTGetPubMsg reads the next packet from r and requires it to be a
// PUBLISH on topic carrying exactly payload. It returns the fixed-header
// flags and the packet identifier (left at 0 when the QoS bits are 0).
// The c parameter is not used here; it is kept for signature symmetry with
// the other helpers.
func testMQTTGetPubMsg(t testing.TB, c net.Conn, r *mqttReader, topic string, payload []byte) (byte, uint16) {
	t.Helper()
	testMQTTReaderHasAtLeastOne(t, r)
	b, err := r.readByte("packet type")
	if err != nil {
		t.Fatal(err)
	}
	if pt := b & mqttPacketMask; pt != mqttPacketPub {
		t.Fatalf("Expected PUBLISH packet %x, got %x", mqttPacketPub, pt)
	}
	pflags := b & mqttPacketFlagMask
	qos := (pflags & mqttPubFlagQoS) >> 1
	pl, err := r.readPacketLen()
	if err != nil {
		t.Fatal(err)
	}
	if err := r.ensurePacketInBuffer(pl); err != nil {
		t.Fatal(err)
	}
	start := r.pos
	ptopic, err := r.readString("topic name")
	if err != nil {
		t.Fatal(err)
	}
	if ptopic != topic {
		t.Fatalf("Expected topic %q, got %q", topic, ptopic)
	}
	var pi uint16
	if qos > 0 {
		pi, err = r.readUint16("packet identifier")
		if err != nil {
			t.Fatal(err)
		}
	}
	// The payload is whatever remains of the packet after the variable header.
	msgLen := pl - (r.pos - start)
	// A packet length shorter than the header just consumed would make the
	// slice below panic; report it as a test failure instead.
	if msgLen < 0 {
		t.Fatalf("packet length %v is smaller than the variable header (%v bytes)",
			pl, r.pos-start)
	}
	if r.pos+msgLen > len(r.buf) {
		t.Fatalf("computed message length goes beyond buffer: ml=%v pos=%v lenBuf=%v",
			msgLen, r.pos, len(r.buf))
	}
	ppayload := r.buf[r.pos : r.pos+msgLen]
	if !bytes.Equal(payload, ppayload) {
		t.Fatalf("Expected payload %q, got %q", payload, ppayload)
	}
	r.pos += msgLen
	return pflags, pi
}
// testMQTTSendPubAck writes a PUBACK for packet identifier pi on c and fails
// the test if the write fails.
func testMQTTSendPubAck(t testing.TB, c net.Conn, pi uint16) {
	t.Helper()
	ack := &mqttWriter{}
	ack.WriteByte(mqttPacketPubAck)
	ack.WriteVarInt(2) // remaining length: only the 2-byte packet identifier
	ack.WriteUint16(pi)
	_, err := testMQTTWrite(c, ack.Bytes())
	if err != nil {
		t.Fatalf("Error writing PUBACK: %v", err)
	}
}
// testMQTTPublish writes a PUBLISH on c. The behavior after the write depends
// on qos:
//   - QoS 0: nothing further is expected.
//   - QoS 2: not supported, so the connection is expected to be closed.
//   - Otherwise: a PUBACK echoing pi is read from r.
func testMQTTPublish(t testing.TB, c net.Conn, r *mqttReader, qos byte, dup, retain bool, topic string, pi uint16, payload []byte) {
	t.Helper()
	pub := &mqttWriter{}
	mqttWritePublish(pub, qos, dup, retain, topic, pi, payload)
	if _, err := testMQTTWrite(c, pub.Bytes()); err != nil {
		t.Fatalf("Error writing PUBLISH proto: %v", err)
	}

	switch qos {
	case 0:
		return
	case 2:
		testMQTTExpectDisconnect(t, c)
		return
	}

	// Parse PUBACK
	testMQTTReaderHasAtLeastOne(t, r)
	packetType, err := r.readByte("packet type")
	if err != nil {
		t.Fatal(err)
	}
	if pt := packetType & mqttPacketMask; pt != mqttPacketPubAck {
		t.Fatalf("Expected PUBACK packet %x, got %x", mqttPacketPubAck, pt)
	}
	remaining, err := r.readPacketLen()
	if err != nil {
		t.Fatal(err)
	}
	if err := r.ensurePacketInBuffer(remaining); err != nil {
		t.Fatal(err)
	}
	ackedID, err := r.readUint16("packet identifier")
	if err != nil || ackedID != pi {
		t.Fatalf("Error with packet identifier expected=%v got: %v err=%v", pi, ackedID, err)
	}
}
// TestMQTTParsePub feeds malformed or truncated PUBLISH bodies to
// mqttParsePub and checks that each is rejected with the expected error text.
func TestMQTTParsePub(t *testing.T) {
	eofr := testNewEOFReader()
	type parsePubCase struct {
		name   string
		flags  byte         // PUBLISH fixed-header flags
		proto  []byte       // bytes preloaded in the reader
		pl     int          // packet length handed to the parser
		reader mqttIOReader // fallback reader when more bytes are needed
		err    string       // expected error substring
	}
	cases := []parsePubCase{
		{"qos not supported", 0x4, nil, 0, nil, "not supported"},
		{"packet in buffer error", 0, nil, 10, eofr, "error ensuring protocol is loaded"},
		{"error on topic", 0, []byte{0, 3, 'f', 'o'}, 4, eofr, "topic"},
		{"empty topic", 0, []byte{0, 0}, 2, nil, "topic cannot be empty"},
		{"wildcards topic", 0, []byte{0, 1, '#'}, 3, nil, "wildcards not allowed"},
		{"error on packet identifier", mqttPubQos1, []byte{0, 3, 'f', 'o', 'o'}, 5, eofr, "packet identifier"},
		{"invalid packet identifier", mqttPubQos1, []byte{0, 3, 'f', 'o', 'o', 0, 0}, 7, nil, "packet identifier cannot be 0"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			rd := &mqttReader{reader: tc.reader}
			rd.reset(tc.proto)
			cli := &client{mqtt: &mqtt{r: rd}}
			err := cli.mqttParsePub(rd, tc.pl, &mqttPublish{flags: tc.flags})
			if err == nil || !strings.Contains(err.Error(), tc.err) {
				t.Fatalf("Expected error %q, got %v", tc.err, err)
			}
		})
	}
}
// TestMQTTParsePubAck feeds malformed or truncated PUBACK bodies to
// mqttParsePubAck and checks the error text.
func TestMQTTParsePubAck(t *testing.T) {
	eofr := testNewEOFReader()
	type parsePubAckCase struct {
		name   string
		proto  []byte       // bytes preloaded in the reader
		pl     int          // packet length handed to the parser
		reader mqttIOReader // fallback reader when more bytes are needed
		err    string       // expected error substring
	}
	cases := []parsePubAckCase{
		{"packet in buffer error", nil, 10, eofr, "error ensuring protocol is loaded"},
		{"error reading packet identifier", []byte{0}, 1, eofr, "packet identifier"},
		{"invalid packet identifier", []byte{0, 0}, 2, nil, "packet identifier cannot be 0"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			rd := &mqttReader{reader: tc.reader}
			rd.reset(tc.proto)
			_, err := mqttParsePubAck(rd, tc.pl)
			if err == nil || !strings.Contains(err.Error(), tc.err) {
				t.Fatalf("Expected error %q, got %v", tc.err, err)
			}
		})
	}
}
// TestMQTTPublish publishes on "foo" at QoS 0, 1 and 2, using the QoS value
// as packet identifier. The QoS 2 publish is expected to get the client
// disconnected (see testMQTTPublish).
func TestMQTTPublish(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)

	nc := natsConnect(t, s.ClientURL())
	defer nc.Close()

	mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
	defer mcp.Close()
	testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false)

	for _, qos := range []byte{0, 1, 2} {
		testMQTTPublish(t, mcp, mpr, qos, false, false, "foo", uint16(qos), []byte("msg"))
	}
}
// TestMQTTSub subscribes an MQTT client to a filter, then publishes once from
// NATS and once from another MQTT client. For each case it checks whether the
// message is delivered (ok) or not.
func TestMQTTSub(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)

	nc := natsConnect(t, s.ClientURL())
	defer nc.Close()

	mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
	defer mcp.Close()
	testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false)

	type subCase struct {
		name           string
		mqttSubTopic   string
		natsPubSubject string
		mqttPubTopic   string
		ok             bool
	}
	cases := []subCase{
		{"1 level match", "foo", "foo", "foo", true},
		{"1 level no match", "foo", "bar", "bar", false},
		{"2 levels match", "foo/bar", "foo.bar", "foo/bar", true},
		{"2 levels no match", "foo/bar", "foo.baz", "foo/baz", false},
		{"3 levels match", "/foo/bar", "/.foo.bar", "/foo/bar", true},
		{"3 levels no match", "/foo/bar", "/.foo.baz", "/foo/baz", false},

		{"single level wc", "foo/+", "foo.bar.baz", "foo/bar/baz", false},
		{"single level wc", "foo/+", "foo.bar./", "foo/bar/", false},
		{"single level wc", "foo/+", "foo.bar", "foo/bar", true},
		{"single level wc", "foo/+", "foo./", "foo/", true},
		{"single level wc", "foo/+", "foo", "foo", false},
		{"single level wc", "foo/+", "/.foo", "/foo", false},

		{"multiple level wc", "foo/#", "foo.bar.baz./", "foo/bar/baz/", true},
		{"multiple level wc", "foo/#", "foo.bar.baz", "foo/bar/baz", true},
		{"multiple level wc", "foo/#", "foo.bar./", "foo/bar/", true},
		{"multiple level wc", "foo/#", "foo.bar", "foo/bar", true},
		{"multiple level wc", "foo/#", "foo./", "foo/", true},
		{"multiple level wc", "foo/#", "foo", "foo", true},
		{"multiple level wc", "foo/#", "/.foo", "/foo", false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
			defer mc.Close()
			testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)

			testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: tc.mqttSubTopic, qos: 0}}, []byte{0})
			testMQTTFlush(t, mc, nil, r)

			checkDelivery := func() {
				if !tc.ok {
					testMQTTExpectNothing(t, r)
					return
				}
				testMQTTCheckPubMsg(t, mc, r, tc.mqttPubTopic, 0, []byte("msg"))
			}

			// First from NATS, then from the MQTT publisher.
			natsPub(t, nc, tc.natsPubSubject, []byte("msg"))
			checkDelivery()

			testMQTTPublish(t, mcp, mpr, 0, false, false, tc.mqttPubTopic, 0, []byte("msg"))
			checkDelivery()
		})
	}
}
// TestMQTTSubQoS subscribes at QoS 1 with two overlapping filters. It checks
// that NATS and QoS 0 MQTT publishes are delivered at QoS 0, while a QoS 1
// MQTT publish is delivered at QoS 1 with distinct packet identifiers.
func TestMQTTSubQoS(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)

	nc := natsConnect(t, s.ClientURL())
	defer nc.Close()

	mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
	defer mcp.Close()
	testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false)

	mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
	defer mc.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)

	mqttTopic := "foo/bar"

	// Two QoS 1 subscriptions that both match mqttTopic.
	testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: "foo/#", qos: 1}}, []byte{1})
	testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: mqttTopic, qos: 1}}, []byte{1})
	testMQTTFlush(t, mc, nil, r)

	// Each publish is delivered once per matching subscription.
	expectTwiceAtQoS0 := func(payload []byte) {
		for n := 0; n < 2; n++ {
			testMQTTCheckPubMsg(t, mc, r, mqttTopic, 0, payload)
		}
	}

	// NATS publishes carry no QoS, so they arrive as QoS 0.
	natsPub(t, nc, "foo.bar", []byte("NATS"))
	expectTwiceAtQoS0([]byte("NATS"))

	// A QoS 0 MQTT publish also arrives as QoS 0.
	testMQTTPublish(t, mcp, mpr, 0, false, false, mqttTopic, 0, []byte("msg"))
	expectTwiceAtQoS0([]byte("msg"))

	// A QoS 1 MQTT publish arrives twice at QoS 1 (flags 0x2).
	testMQTTPublish(t, mcp, mpr, 1, false, false, mqttTopic, 1, []byte("msg"))
	var ids [2]uint16
	for n := range ids {
		flags, pi := testMQTTGetPubMsg(t, mc, r, mqttTopic, []byte("msg"))
		if flags != 0x2 {
			t.Fatalf("Expected flags to be 0x2, got %v", flags)
		}
		ids[n] = pi
	}
	if ids[0] == ids[1] {
		t.Fatalf("packet identifier for message 1: %v should be different from message 2", ids[0])
	}
	for _, pi := range ids {
		testMQTTSendPubAck(t, mc, pi)
	}
}
// getSubQoS returns the MQTT QoS recorded on sub. It returns -1 when sub is
// nil or has no MQTT state attached. Being nil-safe matters because callers
// format it into failure messages right after a map lookup that may have
// returned nil; dereferencing nil there turned a test failure into a panic.
func getSubQoS(sub *subscription) int {
	if sub == nil || sub.mqtt == nil {
		return -1
	}
	return int(sub.mqtt.qos)
}
// TestMQTTSubDups checks duplicate subscriptions on the same filter. They are
// delivered once, and the QoS is updated to that of the latest filter. It also
// checks how the overlapping "foo/#" filter maps to NATS subscriptions when
// its QoS changes.
func TestMQTTSubDups(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)

	mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
	defer mcp.Close()
	testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false)

	mc, r := testMQTTConnect(t, &mqttConnInfo{clientID: "sub", user: "sub", cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
	defer mc.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)

	// Test with single SUBSCRIBE protocol but multiple filters
	filters := []*mqttFilter{
		{filter: "foo", qos: 1},
		{filter: "foo", qos: 0},
	}
	testMQTTSub(t, 1, mc, r, filters, []byte{1, 0})
	testMQTTFlush(t, mc, nil, r)

	// And also with separate SUBSCRIBE protocols
	testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: "bar", qos: 0}}, []byte{0})
	// Ask for QoS 2 but server will downgrade to 1
	testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: "bar", qos: 2}}, []byte{1})
	testMQTTFlush(t, mc, nil, r)

	// Publish and test msg received only once. Publishes pass the publisher's
	// own reader (mpr), not the subscriber's, so any PUBACK would be read
	// from the right connection.
	testMQTTPublish(t, mcp, mpr, 0, false, false, "foo", 0, []byte("msg"))
	testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg"))
	testMQTTExpectNothing(t, r)

	testMQTTPublish(t, mcp, mpr, 0, false, false, "bar", 0, []byte("msg"))
	testMQTTCheckPubMsg(t, mc, r, "bar", 0, []byte("msg"))
	testMQTTExpectNothing(t, r)

	// Check that the QoS for subscriptions have been updated to the latest received filter.
	// A missing subscription is reported explicitly instead of passing a nil
	// *subscription to getSubQoS in the error message.
	var err error
	subc := testMQTTGetClient(t, s, "sub")
	subc.mu.Lock()
	if subc.opts.Username != "sub" {
		err = fmt.Errorf("wrong user name")
	}
	if err == nil {
		if sub := subc.subs["foo"]; sub == nil {
			err = fmt.Errorf("subscription foo not found")
		} else if qos := getSubQoS(sub); qos != 0 {
			err = fmt.Errorf("subscription foo QoS should be 0, got %v", qos)
		}
	}
	if err == nil {
		if sub := subc.subs["bar"]; sub == nil {
			err = fmt.Errorf("subscription bar not found")
		} else if qos := getSubQoS(sub); qos != 1 {
			err = fmt.Errorf("subscription bar QoS should be 1, got %v", qos)
		}
	}
	subc.mu.Unlock()
	if err != nil {
		t.Fatal(err)
	}

	// Now subscribe on "foo/#" which means that a PUBLISH on "foo" will be received
	// by this subscription and also the one on "foo".
	testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: "foo/#", qos: 1}}, []byte{1})
	testMQTTFlush(t, mc, nil, r)

	// Publish and test msg received twice
	testMQTTPublish(t, mcp, mpr, 0, false, false, "foo", 0, []byte("msg"))
	testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg"))
	testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg"))

	checkWCSub := func(expectedQoS int) {
		t.Helper()

		subc.mu.Lock()
		defer subc.mu.Unlock()

		// When invoked with expectedQoS==1, we have the following subs:
		// foo (QoS-0), bar (QoS-1), foo.> (QoS-1)
		// which means (since QoS-1 have a JS consumer + sub for delivery
		// and foo.> causes a "foo fwc") that we should have the following
		// number of NATS subs: foo (1), bar (2), foo.> (2) and "foo fwc" (2),
		// so total=7.
		// When invoked with expectedQoS==0, it means that we have replaced
		// foo/# QoS-1 to QoS-0, so we should have 2 less NATS subs,
		// so total=5
		expected := 7
		if expectedQoS == 0 {
			expected = 5
		}
		if lenmap := len(subc.subs); lenmap != expected {
			t.Fatalf("Subs map should have %v entries, got %v", expected, lenmap)
		}
		if sub, ok := subc.subs["foo.>"]; !ok {
			t.Fatal("Expected sub foo.> to be present but was not")
		} else if getSubQoS(sub) != expectedQoS {
			t.Fatalf("Expected sub foo.> QoS to be %v, got %v", expectedQoS, getSubQoS(sub))
		}
		if sub, ok := subc.subs["foo fwc"]; !ok {
			t.Fatal("Expected sub foo fwc to be present but was not")
		} else if getSubQoS(sub) != expectedQoS {
			t.Fatalf("Expected sub foo fwc QoS to be %v, got %v", expectedQoS, getSubQoS(sub))
		}
		// Make sure existing sub on "foo" qos was not changed.
		if sub, ok := subc.subs["foo"]; !ok {
			t.Fatal("Expected sub foo to be present but was not")
		} else if getSubQoS(sub) != 0 {
			t.Fatalf("Expected sub foo QoS to be 0, got %v", getSubQoS(sub))
		}
	}
	checkWCSub(1)

	// Sub again on same subject with lower QoS
	testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: "foo/#", qos: 0}}, []byte{0})
	testMQTTFlush(t, mc, nil, r)

	// Publish and test msg received twice
	testMQTTPublish(t, mcp, mpr, 0, false, false, "foo", 0, []byte("msg"))
	testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg"))
	testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg"))
	checkWCSub(0)
}
// TestMQTTSubWithSpaces verifies that a SUBSCRIBE whose topic filter contains
// a space is refused: the SUBACK must carry the failure return code.
func TestMQTTSubWithSpaces(t *testing.T) {
	opts := testMQTTDefaultOptions()
	srv := testMQTTRunServer(t, opts)
	defer testMQTTShutdownServer(srv)

	// A plain client connected first, before the subscribing one.
	pubConn, pubReader := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, opts.MQTT.Host, opts.MQTT.Port)
	defer pubConn.Close()
	testMQTTCheckConnAck(t, pubReader, mqttConnAckRCConnectionAccepted, false)

	subConn, subReader := testMQTTConnect(t, &mqttConnInfo{user: "sub", cleanSess: true}, opts.MQTT.Host, opts.MQTT.Port)
	defer subConn.Close()
	testMQTTCheckConnAck(t, subReader, mqttConnAckRCConnectionAccepted, false)

	badFilters := []*mqttFilter{{filter: "foo bar", qos: 0}}
	testMQTTSub(t, 1, subConn, subReader, badFilters, []byte{mqttSubAckFailure})
}
// TestMQTTSubCaseSensitive checks that topic matching is case sensitive: a
// subscription on "Foo/Bar" receives messages published on "Foo/Bar" (from
// MQTT) or "Foo.Bar" (from NATS), but not those on "foo/bar" or "foo.bar".
func TestMQTTSubCaseSensitive(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)
	mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
	defer mcp.Close()
	testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false)
	mc, r := testMQTTConnect(t, &mqttConnInfo{user: "sub", cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
	defer mc.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
	testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: "Foo/Bar", qos: 0}}, []byte{0})
	testMQTTFlush(t, mc, nil, r)
	// Publishes are written on mcp, so they are paired with mcp's own reader
	// (mpr). The subscriber's reader r is only used to check what mc receives.
	testMQTTPublish(t, mcp, mpr, 0, false, false, "Foo/Bar", 0, []byte("msg"))
	testMQTTCheckPubMsg(t, mc, r, "Foo/Bar", 0, []byte("msg"))
	testMQTTPublish(t, mcp, mpr, 0, false, false, "foo/bar", 0, []byte("msg"))
	testMQTTExpectNothing(t, r)
	// Same checks for messages coming from a NATS publisher.
	nc := natsConnect(t, s.ClientURL())
	defer nc.Close()
	natsPub(t, nc, "Foo.Bar", []byte("nats"))
	testMQTTCheckPubMsg(t, mc, r, "Foo/Bar", 0, []byte("nats"))
	natsPub(t, nc, "foo.bar", []byte("nats"))
	testMQTTExpectNothing(t, r)
}
// TestMQTTPubSubMatrix publishes a single message on "foo" from either NATS or
// MQTT (QoS 0 or 1). It checks that every subscriber enabled for the case
// receives exactly one copy: a NATS subscription, an MQTT QoS-0 subscription
// and/or an MQTT QoS-1 subscription.
func TestMQTTPubSubMatrix(t *testing.T) {
	type matrixCase struct {
		name        string
		natsPub     bool
		mqttPub     bool
		mqttPubQoS  byte
		natsSub     bool
		mqttSubQoS0 bool
		mqttSubQoS1 bool
	}
	cases := []matrixCase{
		{"NATS to MQTT sub QoS-0", true, false, 0, false, true, false},
		{"NATS to MQTT sub QoS-1", true, false, 0, false, false, true},
		{"NATS to MQTT sub QoS-0 and QoS-1", true, false, 0, false, true, true},
		{"MQTT QoS-0 to NATS sub", false, true, 0, true, false, false},
		{"MQTT QoS-0 to MQTT sub QoS-0", false, true, 0, false, true, false},
		{"MQTT QoS-0 to MQTT sub QoS-1", false, true, 0, false, false, true},
		{"MQTT QoS-0 to NATS sub and MQTT sub QoS-0", false, true, 0, true, true, false},
		{"MQTT QoS-0 to NATS sub and MQTT sub QoS-1", false, true, 0, true, false, true},
		{"MQTT QoS-0 to all subs", false, true, 0, true, true, true},
		{"MQTT QoS-1 to NATS sub", false, true, 1, true, false, false},
		{"MQTT QoS-1 to MQTT sub QoS-0", false, true, 1, false, true, false},
		{"MQTT QoS-1 to MQTT sub QoS-1", false, true, 1, false, false, true},
		{"MQTT QoS-1 to NATS sub and MQTT sub QoS-0", false, true, 1, true, true, false},
		{"MQTT QoS-1 to NATS sub and MQTT sub QoS-1", false, true, 1, true, false, true},
		{"MQTT QoS-1 to all subs", false, true, 1, true, true, true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			opts := testMQTTDefaultOptions()
			srv := testMQTTRunServer(t, opts)
			defer testMQTTShutdownServer(srv)

			nc := natsConnect(t, srv.ClientURL())
			defer nc.Close()

			// One MQTT publisher, plus one MQTT client per possible subscription QoS.
			pubConn, pubR := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, opts.MQTT.Host, opts.MQTT.Port)
			defer pubConn.Close()
			testMQTTCheckConnAck(t, pubR, mqttConnAckRCConnectionAccepted, false)
			qos0Conn, qos0R := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, opts.MQTT.Host, opts.MQTT.Port)
			defer qos0Conn.Close()
			testMQTTCheckConnAck(t, qos0R, mqttConnAckRCConnectionAccepted, false)
			qos1Conn, qos1R := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, opts.MQTT.Host, opts.MQTT.Port)
			defer qos1Conn.Close()
			testMQTTCheckConnAck(t, qos1R, mqttConnAckRCConnectionAccepted, false)

			// Create only the subscriptions requested by this case.
			var natsSub *nats.Subscription
			if tc.natsSub {
				natsSub = natsSubSync(t, nc, "foo")
			}
			if tc.mqttSubQoS0 {
				testMQTTSub(t, 1, qos0Conn, qos0R, []*mqttFilter{{filter: "foo", qos: 0}}, []byte{0})
				testMQTTFlush(t, qos0Conn, nil, qos0R)
			}
			if tc.mqttSubQoS1 {
				testMQTTSub(t, 1, qos1Conn, qos1R, []*mqttFilter{{filter: "foo", qos: 1}}, []byte{1})
				testMQTTFlush(t, qos1Conn, nil, qos1R)
			}
			// Just as a barrier.
			natsFlush(t, nc)

			if tc.natsPub {
				natsPubReq(t, nc, "foo", "", []byte("msg"))
			} else {
				testMQTTPublish(t, pubConn, pubR, tc.mqttPubQoS, false, false, "foo", 1, []byte("msg"))
			}

			if tc.natsSub {
				natsNexMsg(t, natsSub, time.Second)
				// A second copy must not show up.
				if dup, err := natsSub.NextMsg(50 * time.Millisecond); err == nil {
					t.Fatalf("Should not have gotten a second message, got %v", dup)
				}
			}
			if tc.mqttSubQoS0 {
				testMQTTCheckPubMsg(t, qos0Conn, qos0R, "foo", 0, []byte("msg"))
				testMQTTExpectNothing(t, qos0R)
			}
			if tc.mqttSubQoS1 {
				// Expected at the publisher's QoS (0 for every NATS publish in
				// this table), shifted into the PUBLISH flags position.
				expectedFlag := tc.mqttPubQoS << 1
				testMQTTCheckPubMsg(t, qos1Conn, qos1R, "foo", expectedFlag, []byte("msg"))
				testMQTTExpectNothing(t, qos1R)
			}
		})
	}
}
// TestMQTTPreventSubWithMQTTSubPrefix checks that an MQTT subscription whose
// topic falls under mqttSubPrefix is refused with a SUBACK failure code.
func TestMQTTPreventSubWithMQTTSubPrefix(t *testing.T) {
	opts := testMQTTDefaultOptions()
	srv := testMQTTRunServer(t, opts)
	defer testMQTTShutdownServer(srv)

	conn, reader := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, opts.MQTT.Host, opts.MQTT.Port)
	defer conn.Close()
	testMQTTCheckConnAck(t, reader, mqttConnAckRCConnectionAccepted, false)

	// Turn the dot-separated prefix into a slash-separated MQTT topic so the
	// filter lands under that prefix.
	prefixedTopic := strings.ReplaceAll(mqttSubPrefix, ".", "/") + "foo/bar"
	filters := []*mqttFilter{{filter: prefixedTopic, qos: 1}}
	testMQTTSub(t, 1, conn, reader, filters, []byte{mqttSubAckFailure})
}
// TestMQTTSubWithNATSStream checks message delivery when a user-created
// JetStream stream also captures "foo.>". A message on "foo.bar" must reach
// both an MQTT QoS-1 subscriber on "foo/bar" and a durable push consumer of
// that stream. This is checked for a NATS publish, an MQTT QoS 0 publish and
// an MQTT QoS 1 publish.
func TestMQTTSubWithNATSStream(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)
	// MQTT subscriber with QoS 1 on "foo/bar".
	mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
	defer mc.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
	testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: "foo/bar", qos: 1}}, []byte{1})
	testMQTTFlush(t, mc, nil, r)
	// MQTT publisher.
	mcp, rp := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
	defer mcp.Close()
	testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false)
	testMQTTFlush(t, mcp, nil, rp)
	nc := natsConnect(t, s.ClientURL())
	defer nc.Close()
	// User stream with interest retention on "foo.>". This overlaps the
	// subjects that the MQTT "foo/..." topics map to.
	sc := &StreamConfig{
		Name:      "test",
		Storage:   FileStorage,
		Retention: InterestPolicy,
		Subjects:  []string{"foo.>"},
	}
	mset, err := s.GlobalAccount().addStream(sc)
	if err != nil {
		t.Fatalf("Unable to create stream: %v", err)
	}
	// Durable push consumer that delivers to the plain NATS subscription on "bar".
	sub := natsSubSync(t, nc, "bar")
	cc := &ConsumerConfig{
		Durable:        "dur",
		AckPolicy:      AckExplicit,
		DeliverSubject: "bar",
	}
	if _, err := mset.addConsumer(cc); err != nil {
		t.Fatalf("Unable to add consumer: %v", err)
	}
	// Now send message from NATS. A request is used so the reply can be
	// decoded as an ApiResponse and checked for an error.
	resp, err := nc.Request("foo.bar", []byte("nats"), time.Second)
	if err != nil {
		t.Fatalf("Error publishing: %v", err)
	}
	ar := &ApiResponse{}
	if err := json.Unmarshal(resp.Data, ar); err != nil || ar.Error != nil {
		t.Fatalf("Unexpected response: err=%v resp=%+v", err, ar.Error)
	}
	// Check that message is received by both. checkRecv expects content on the
	// stream consumer's NATS subscription, and the same content on the MQTT
	// subscriber with the given PUBLISH flags.
	checkRecv := func(content string, flags byte) {
		t.Helper()
		if msg := natsNexMsg(t, sub, time.Second); string(msg.Data) != content {
			t.Fatalf("Expected %q, got %q", content, msg.Data)
		}
		testMQTTCheckPubMsg(t, mc, r, "foo/bar", flags, []byte(content))
	}
	// A plain NATS publish reaches the QoS-1 MQTT subscriber with flags 0.
	checkRecv("nats", 0)
	// Send from MQTT as a QoS0
	testMQTTPublish(t, mcp, rp, 0, false, false, "foo/bar", 0, []byte("qos0"))
	checkRecv("qos0", 0)
	// Send from MQTT as a QoS1
	testMQTTPublish(t, mcp, rp, 1, false, false, "foo/bar", 1, []byte("qos1"))
	checkRecv("qos1", mqttPubQos1)
}
// TestMQTTTrackPendingOverrun exercises packet-identifier allocation in
// mqttSession.trackPending at the edges of the uint16 range. The next id
// after 0xFFFF is 1. When every non-zero id is pending, 0 is returned. Once a
// single id is freed, that id is handed out.
func TestMQTTTrackPendingOverrun(t *testing.T) {
	sess := &mqttSession{pending: make(map[uint16]*mqttPending)}
	sub := &subscription{mqtt: &mqttSub{qos: 1}}

	// expectPI asks the session for the next packet identifier and checks it.
	expectPI := func(want uint16) {
		t.Helper()
		if got, _ := sess.trackPending(1, _EMPTY_, sub); got != want {
			t.Fatalf("Expected %v, got %v", want, got)
		}
	}

	// Seed ppi at the top of the range: the next id must wrap to 1.
	sess.ppi = 0xFFFF
	expectPI(1)

	// Mark every non-zero id as pending: no id is left, so 0 is expected.
	filler := &mqttPending{}
	for id := 1; id <= 0xFFFF; id++ {
		sess.pending[uint16(id)] = filler
	}
	expectPI(0)

	// Free a single id and make sure it is the one handed out.
	delete(sess.pending, 1234)
	expectPI(1234)
}
// TestMQTTSubRestart checks three things for a client with a persistent
// session (cleanSess: false). It can reconnect and re-create its QoS 1
// subscription on "foo" while a NATS subscription on ">" exists. It then
// receives QoS 1 messages. Finally, re-subscribing to "foo" with QoS 0 makes
// later messages arrive at QoS 0.
func TestMQTTSubRestart(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)
	nc := natsConnect(t, s.ClientURL())
	defer nc.Close()
	mc, r := testMQTTConnect(t, &mqttConnInfo{clientID: "sub", cleanSess: false}, o.MQTT.Host, o.MQTT.Port)
	defer mc.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
	// Start an MQTT subscription QoS=1 on "foo"
	testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: "foo", qos: 1}}, []byte{1})
	testMQTTFlush(t, mc, nil, r)
	// Now start a NATS subscription on ">" (anything that would match the JS consumer delivery subject)
	natsSubSync(t, nc, ">")
	natsFlush(t, nc)
	// Restart the MQTT client
	testMQTTDisconnect(t, mc, nil)
	mc, r = testMQTTConnect(t, &mqttConnInfo{clientID: "sub", cleanSess: false}, o.MQTT.Host, o.MQTT.Port)
	defer mc.Close()
	// Unlike the first CONNACK, this one is checked with the last argument set
	// to true (presumably session-present, since the session was kept).
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true)
	// Restart an MQTT subscription QoS=1 on "foo"
	testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: "foo", qos: 1}}, []byte{1})
	testMQTTFlush(t, mc, nil, r)
	pc, pr := testMQTTConnect(t, &mqttConnInfo{clientID: "pub", cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
	defer pc.Close()
	testMQTTCheckConnAck(t, pr, mqttConnAckRCConnectionAccepted, false)
	// Publish a message QoS1
	testMQTTPublish(t, pc, pr, 1, false, false, "foo", 1, []byte("msg1"))
	// Make sure we receive it
	testMQTTCheckPubMsg(t, mc, r, "foo", mqttPubQos1, []byte("msg1"))
	// Now "restart" the subscription but as a Qos0
	testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: "foo", qos: 0}}, []byte{0})
	testMQTTFlush(t, mc, nil, r)
	// Publish another message with QoS1
	testMQTTPublish(t, pc, pr, 1, false, false, "foo", 1, []byte("msg2"))
	// Make sure we receive but as a QoS0
	testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg2"))
}
// TestMQTTSubPropagation checks subscription propagation in a 2-node cluster.
// An MQTT "foo/#" subscription on one server must create interest on the
// other server for the bare "foo" subject as well as for deeper levels.
func TestMQTTSubPropagation(t *testing.T) {
	c := createJetStreamClusterWithTemplate(t, jsClusterTemplWithMQTT, "MQTT", 2)
	defer c.shutdown()

	mqttOpts := c.opts[0]
	remote := c.servers[1]

	nc := natsConnect(t, remote.ClientURL())
	defer nc.Close()

	mc, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, mqttOpts.MQTT.Host, mqttOpts.MQTT.Port)
	defer mc.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
	testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: "foo/#", qos: 0}}, []byte{0})
	testMQTTFlush(t, mc, nil, r)

	// In MQTT "foo/#" also matches "foo" itself, so interest on "foo" must
	// have reached the remote server.
	checkSubInterest(t, remote, globalAccountName, "foo", time.Second)

	// Each NATS publish on the remote server must reach the MQTT subscriber
	// under the corresponding MQTT topic.
	for _, m := range []struct{ subj, topic, data string }{
		{"foo.bar", "foo/bar", "hello"},
		{"foo./", "foo/", "from"},
		{"foo", "foo", "NATS"},
	} {
		natsPub(t, nc, m.subj, []byte(m.data))
		testMQTTCheckPubMsg(t, mc, r, m.topic, 0, []byte(m.data))
	}
}
// TestMQTTCluster runs on a 2-node JetStream cluster with MQTT enabled. It
// creates "foo/#" subscriptions (QoS 0 and QoS 1) on a persistent session
// and checks that messages published from NATS and MQTT on either server
// reach the subscriber. It then takes over the session from the other server,
// which must close the original connection, and repeats the checks. The
// whole sequence runs twice: first on the fresh cluster, then after the
// cluster has been stopped and restarted.
func TestMQTTCluster(t *testing.T) {
	cl := createJetStreamClusterWithTemplate(t, jsClusterTemplWithMQTT, "MQTT", 2)
	defer cl.shutdown()
	for _, topTest := range []struct {
		name    string
		restart bool
	}{
		// restart tells whether the cluster is stopped and restarted after this
		// iteration's subtests, so the next iteration runs on the restarted cluster.
		{"first_start", true},
		{"restart", false},
	} {
		t.Run(topTest.name, func(t *testing.T) {
			for _, test := range []struct {
				name   string
				subQos byte
			}{
				{"qos_0", 0},
				{"qos_1", 1},
			} {
				t.Run(test.name, func(t *testing.T) {
					clientID := nuid.Next()
					o := cl.opts[0]
					mc, r := testMQTTConnect(t, &mqttConnInfo{clientID: clientID, cleanSess: false}, o.MQTT.Host, o.MQTT.Port)
					defer mc.Close()
					testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
					testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: "foo/#", qos: test.subQos}}, []byte{test.subQos})
					testMQTTFlush(t, mc, nil, r)
					// check publishes three messages on "foo/..." and verifies that
					// the subscriber (mc, r) receives each one: a NATS message via
					// server s, then MQTT QoS 0 and QoS 1 messages from a temporary
					// MQTT client connected using options o.
					check := func(mc net.Conn, r *mqttReader, o *Options, s *Server) {
						t.Helper()
						nc := natsConnect(t, s.ClientURL())
						defer nc.Close()
						natsPub(t, nc, "foo.bar", []byte("fromNats"))
						testMQTTCheckPubMsg(t, mc, r, "foo/bar", 0, []byte("fromNats"))
						mpc, pr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
						defer mpc.Close()
						testMQTTCheckConnAck(t, pr, mqttConnAckRCConnectionAccepted, false)
						testMQTTPublish(t, mpc, pr, 0, false, false, "foo/baz", 0, []byte("mqtt_qos0"))
						testMQTTCheckPubMsg(t, mc, r, "foo/baz", 0, []byte("mqtt_qos0"))
						testMQTTPublish(t, mpc, pr, 1, false, false, "foo/bat", 1, []byte("mqtt_qos1"))
						// The QoS 1 publish arrives with QoS 1 only when the
						// subscription itself is QoS 1.
						expectedQoS := byte(0)
						if test.subQos == 1 {
							expectedQoS = mqttPubQos1
						}
						testMQTTCheckPubMsg(t, mc, r, "foo/bat", expectedQoS, []byte("mqtt_qos1"))
						testMQTTDisconnect(t, mpc, nil)
					}
					check(mc, r, cl.opts[0], cl.servers[0])
					check(mc, r, cl.opts[1], cl.servers[1])
					// Start the same subscription from the other server. It should disconnect
					// the one connected in the first server.
					o = cl.opts[1]
					mc2, r2 := testMQTTConnect(t, &mqttConnInfo{clientID: clientID, cleanSess: false}, o.MQTT.Host, o.MQTT.Port)
					defer mc2.Close()
					testMQTTCheckConnAck(t, r2, mqttConnAckRCConnectionAccepted, true)
					// Expect first connection to be closed.
					testMQTTExpectDisconnect(t, mc)
					// Now re-run the checks
					check(mc2, r2, cl.opts[0], cl.servers[0])
					check(mc2, r2, cl.opts[1], cl.servers[1])
					// Disconnect our sub and restart with clean session then disconnect again to clear the state.
					testMQTTDisconnect(t, mc2, nil)
					mc2, r2 = testMQTTConnect(t, &mqttConnInfo{clientID: clientID, cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
					defer mc2.Close()
					testMQTTCheckConnAck(t, r2, mqttConnAckRCConnectionAccepted, false)
					testMQTTFlush(t, mc2, nil, r2)
					testMQTTDisconnect(t, mc2, nil)
					// Remove the client ID from the flappers of the global account's
					// session manager. Without this, the next run could fail or would
					// have to wait 1sec before being able to reconnect.
					// NOTE(review): only server 0's session manager is cleaned here.
					s := cl.servers[0]
					sm := &s.mqtt.sessmgr
					sm.mu.Lock()
					asm := sm.sessions[globalAccountName]
					sm.mu.Unlock()
					if asm != nil {
						asm.mu.Lock()
						delete(asm.flappers, clientID)
						asm.mu.Unlock()
					}
				})
			}
			if topTest.restart {
				// Restart the whole cluster. Wait for the MQTT streams and the
				// retained-messages consumers to have leaders before the next
				// top-level iteration.
				cl.stopAll()
				cl.restartAll()
				streams := []string{mqttStreamName, mqttRetainedMsgsStreamName}
				for _, sn := range streams {
					cl.waitOnStreamLeader(globalAccountName, sn)
				}
				// NOTE(review): these consumer names are hard-coded; presumably
				// derived from the server names in the cluster template. Confirm
				// if the template changes.
				cl.waitOnConsumerLeader(globalAccountName, mqttRetainedMsgsStreamName, "$MQTT_rmsgs_esFhDys3")
				cl.waitOnConsumerLeader(globalAccountName, mqttRetainedMsgsStreamName, "$MQTT_rmsgs_z3WIzPtj")
			}
		})
	}
}
// TestMQTTClusterRetainedMsg checks retained-message handling across a
// 2-node cluster. A retained message published on server 2 is delivered, with
// the retain flag, to a new subscriber on server 1. After an empty retained
// message is published for the same topic, re-subscribing from either server
// delivers nothing.
func TestMQTTClusterRetainedMsg(t *testing.T) {
	cl := createJetStreamClusterWithTemplate(t, jsClusterTemplWithMQTT, "MQTT", 2)
	defer cl.shutdown()
	srv1Opts := cl.opts[0]
	srv2Opts := cl.opts[1]
	// Connect subscription on server 1.
	mc, rc := testMQTTConnect(t, &mqttConnInfo{clientID: "sub", cleanSess: false}, srv1Opts.MQTT.Host, srv1Opts.MQTT.Port)
	defer mc.Close()
	testMQTTCheckConnAck(t, rc, mqttConnAckRCConnectionAccepted, false)
	testMQTTSub(t, 1, mc, rc, []*mqttFilter{{filter: "foo/#", qos: 1}}, []byte{1})
	testMQTTFlush(t, mc, nil, rc)
	// Create a publisher from server 2.
	mp, rp := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, srv2Opts.MQTT.Host, srv2Opts.MQTT.Port)
	defer mp.Close()
	testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false)
	// Send retained message.
	testMQTTPublish(t, mp, rp, 1, false, true, "foo/bar", 1, []byte("retained"))
	// Check it is received. The already-existing subscription gets it without
	// the retain flag.
	testMQTTCheckPubMsg(t, mc, rc, "foo/bar", mqttPubQos1, []byte("retained"))
	// Start a new subscription on server 1 and make sure we receive the retained
	// message, this time with the retain flag set.
	mc2, rc2 := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, srv1Opts.MQTT.Host, srv1Opts.MQTT.Port)
	defer mc2.Close()
	testMQTTCheckConnAck(t, rc2, mqttConnAckRCConnectionAccepted, false)
	testMQTTSub(t, 1, mc2, rc2, []*mqttFilter{{filter: "foo/#", qos: 1}}, []byte{1})
	testMQTTCheckPubMsg(t, mc2, rc2, "foo/bar", mqttPubQos1|mqttPubFlagRetain, []byte("retained"))
	testMQTTDisconnect(t, mc2, nil)
	// Send an empty retained message which should remove it from storage, but still be delivered.
	testMQTTPublish(t, mp, rp, 1, false, true, "foo/bar", 1, []byte(""))
	testMQTTCheckPubMsg(t, mc, rc, "foo/bar", mqttPubQos1, []byte(""))
	// Now shutdown the consumer connection
	testMQTTDisconnect(t, mc, nil)
	mc.Close()
	// Reconnect to server where the retained message was published (server 2)
	mc, rc = testMQTTConnect(t, &mqttConnInfo{clientID: "sub", cleanSess: false}, srv2Opts.MQTT.Host, srv2Opts.MQTT.Port)
	defer mc.Close()
	testMQTTCheckConnAck(t, rc, mqttConnAckRCConnectionAccepted, true)
	testMQTTSub(t, 1, mc, rc, []*mqttFilter{{filter: "foo/#", qos: 1}}, []byte{1})
	// The retained message should not be delivered, since the empty retained
	// publish above removed it.
	testMQTTExpectNothing(t, rc)
	// Now disconnect and reconnect back to first server
	testMQTTDisconnect(t, mc, nil)
	mc.Close()
	// Now reconnect to the server 1, which is not where the messages were published, and check
	// that we don't receive the message.
	mc, rc = testMQTTConnect(t, &mqttConnInfo{clientID: "sub", cleanSess: false}, srv1Opts.MQTT.Host, srv1Opts.MQTT.Port)
	defer mc.Close()
	testMQTTCheckConnAck(t, rc, mqttConnAckRCConnectionAccepted, true)
	testMQTTSub(t, 1, mc, rc, []*mqttFilter{{filter: "foo/#", qos: 1}}, []byte{1})
	testMQTTExpectNothing(t, rc)
}
// TestMQTTRetainedMsgNetworkUpdates applies retained-message additions and
// deletions for a subject to the client's asm, in various orders. It then
// checks the resulting entry in asm.retmsgs: its current stream sequence
// (sseq) and its floor.
func TestMQTTRetainedMsgNetworkUpdates(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)
	mc, rc := testMQTTConnect(t, &mqttConnInfo{clientID: "sub", cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
	defer mc.Close()
	testMQTTCheckConnAck(t, rc, mqttConnAckRCConnectionAccepted, false)
	// Grab the asm of the server-side "sub" client; all updates below go to it.
	c := testMQTTGetClient(t, s, "sub")
	asm := c.mqtt.asm
	// For this test, we are going to simulate updates arriving in a
	// mixed order and verify that we have the expected outcome.
	// check reads asm.retmsgs under asm's read lock. It asserts whether
	// subject has an entry and, when it should, that the entry's sseq and
	// floor equal current and floor.
	check := func(t *testing.T, subject string, present bool, current, floor uint64) {
		t.Helper()
		asm.mu.RLock()
		defer asm.mu.RUnlock()
		erm, ok := asm.retmsgs[subject]
		if present && !ok {
			t.Fatalf("Subject %q not present", subject)
		} else if !present && ok {
			t.Fatalf("Subject %q should not be present", subject)
		} else if !present {
			return
		}
		if floor != erm.floor {
			t.Fatalf("Expected floor to be %v, got %v", floor, erm.floor)
		}
		if erm.sseq != current {
			t.Fatalf("Expected current sequence to be %v, got %v", current, erm.sseq)
		}
	}
	// action is one update. add=true calls handleRetainedMsg with sseq=seq;
	// add=false calls handleRetainedMsgDel with seq.
	type action struct {
		add bool
		seq uint64
	}
	// Each case applies order to subject, then expects the entry to be
	// present with sseq == seq and the given floor.
	for _, test := range []struct {
		subject string
		order   []action
		seq     uint64
		floor   uint64
	}{
		{"foo.1", []action{{true, 1}, {true, 2}, {true, 3}}, 3, 0},
		{"foo.2", []action{{true, 3}, {true, 1}, {true, 2}}, 3, 0},
		{"foo.3", []action{{true, 1}, {false, 1}, {true, 2}}, 2, 0},
		{"foo.4", []action{{false, 2}, {true, 1}, {true, 3}, {true, 2}}, 3, 0},
		{"foo.5", []action{{false, 2}, {true, 1}, {true, 2}}, 0, 2},
		{"foo.6", []action{{true, 1}, {true, 2}, {false, 2}}, 0, 2},
	} {
		t.Run(test.subject, func(t *testing.T) {
			for _, a := range test.order {
				if a.add {
					rm := &mqttRetainedMsg{sseq: a.seq}
					asm.handleRetainedMsg(test.subject, rm)
				} else {
					asm.handleRetainedMsgDel(test.subject, a.seq)
				}
			}
			check(t, test.subject, true, test.seq, test.floor)
		})
	}
	// These subtests depend on the state left by the "foo.5" and "foo.6"
	// cases above (asm is shared), where the entry ended with a non-zero floor.
	for _, subject := range []string{"foo.5", "foo.6"} {
		t.Run("clear_"+subject, func(t *testing.T) {
			// Now add a new message, which should clear the floor.
			rm := &mqttRetainedMsg{sseq: 3}
			asm.handleRetainedMsg(subject, rm)
			check(t, subject, true, 3, 0)
			// Now do a non network delete (seq 0) and make sure it is gone.
			asm.handleRetainedMsgDel(subject, 0)
			check(t, subject, false, 0, 0)
		})
	}
}
// TestMQTTClusterReplicasCount checks the replica count of the MQTT streams
// for several cluster sizes: the messages stream, the retained-messages
// stream and the sessions stream for client ID "sub". The expected count
// follows the cluster size but is capped at 3 (see the {5, 3} case).
func TestMQTTClusterReplicasCount(t *testing.T) {
	for _, test := range []struct {
		size     int
		replicas int
	}{
		{1, 1},
		{2, 2},
		{3, 3},
		{5, 3},
	} {
		t.Run(fmt.Sprintf("size %v", test.size), func(t *testing.T) {
			var s *Server
			var o *Options
			if test.size == 1 {
				// Size 1: a standalone server rather than a cluster.
				o = testMQTTDefaultOptions()
				s = testMQTTRunServer(t, o)
				defer testMQTTShutdownServer(s)
			} else {
				cl := createJetStreamClusterWithTemplate(t, jsClusterTemplWithMQTT, "MQTT", test.size)
				defer cl.shutdown()
				o = cl.opts[0]
				s = cl.randomServer()
			}
			// Connect an MQTT client with a persistent session and a QoS-1
			// subscription so that the streams inspected below exist.
			mc, rc := testMQTTConnect(t, &mqttConnInfo{clientID: "sub", cleanSess: false}, o.MQTT.Host, o.MQTT.Port)
			defer mc.Close()
			testMQTTCheckConnAck(t, rc, mqttConnAckRCConnectionAccepted, false)
			testMQTTSub(t, 1, mc, rc, []*mqttFilter{{filter: "foo/#", qos: 1}}, []byte{1})
			testMQTTFlush(t, mc, nil, rc)

			nc := natsConnect(t, s.ClientURL())
			defer nc.Close()

			// Check the replicas of all MQTT streams
			js, err := nc.JetStream()
			if err != nil {
				t.Fatalf("Error getting js: %v", err)
			}
			for _, sname := range []string{
				mqttStreamName,
				mqttRetainedMsgsStreamName,
				mqttSessionsStreamNamePrefix + string(getHash("sub")),
			} {
				t.Run(sname, func(t *testing.T) {
					si, err := js.StreamInfo(sname)
					if err != nil {
						t.Fatalf("Error getting stream info: %v", err)
					}
					if si.Config.Replicas != test.replicas {
						t.Fatalf("Expected %v replicas, got %v", test.replicas, si.Config.Replicas)
					}
				})
			}
		})
	}
}
// TestMQTTClusterPlacement uses a "HUB" JetStream cluster and a "SPOKE"
// cluster of leaf nodes that run MQTT. It checks that the streams and
// consumers created when MQTT clients connect to the leaf nodes are placed
// in the SPOKE cluster, with all replicas there, and not in the hub.
func TestMQTTClusterPlacement(t *testing.T) {
	c := createJetStreamClusterExplicit(t, "HUB", 3)
	defer c.shutdown()
	lnc := c.createLeafNodesWithStartPortAndMQTT("SPOKE", 3, 22111, `mqtt { listen: 127.0.0.1:-1 }`)
	defer lnc.shutdown()
	c.waitOnPeerCount(6)
	c.waitOnLeader()

	// Spread MQTT clients over the leaf nodes so the MQTT assets get created.
	// The connections stay open until the test returns.
	for i := 0; i < 10; i++ {
		mc, rc := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, lnc.opts[i%3].MQTT.Host, lnc.opts[i%3].MQTT.Port)
		defer mc.Close()
		testMQTTCheckConnAck(t, rc, mqttConnAckRCConnectionAccepted, false)
	}

	// Now check that MQTT assets have been created in the LEAF node's side, not the Hub.
	nc := natsConnect(t, lnc.servers[0].ClientURL())
	defer nc.Close()
	js, err := nc.JetStream()
	if err != nil {
		t.Fatalf("Unable to get JetStream: %v", err)
	}
	count := 0
	for si := range js.StreamsInfo() {
		if si.Cluster == nil || si.Cluster.Name != "SPOKE" {
			t.Fatalf("Expected asset %q to be placed on spoke cluster, was placed on %+v", si.Config.Name, si.Cluster)
		}
		for _, repl := range si.Cluster.Replicas {
			if !strings.HasPrefix(repl.Name, "SPOKE-") {
				t.Fatalf("Replica on the wrong cluster: %+v", repl)
			}
		}
		if si.State.Consumers > 0 {
			for ci := range js.ConsumersInfo(si.Config.Name) {
				// Report the consumer's own placement, not its stream's.
				if ci.Cluster == nil || ci.Cluster.Name != "SPOKE" {
					t.Fatalf("Expected asset %q to be placed on spoke cluster, was placed on %+v", ci.Name, ci.Cluster)
				}
				for _, repl := range ci.Cluster.Replicas {
					if !strings.HasPrefix(repl.Name, "SPOKE-") {
						t.Fatalf("Replica on the wrong cluster: %+v", repl)
					}
				}
			}
		}
		count++
	}
	if count == 0 {
		t.Fatal("No stream found!")
	}
}
// TestMQTTParseUnsub feeds malformed UNSUBSCRIBE packets to
// mqttParseSubsOrUnsubs, called with its final argument set to false as for
// unsubscribe. It checks that each packet is rejected with an error
// containing the expected text.
func TestMQTTParseUnsub(t *testing.T) {
	eofr := testNewEOFReader()
	type parseCase struct {
		name   string
		proto  []byte
		b      byte
		pl     int
		reader mqttIOReader
		err    string
	}
	cases := []parseCase{
		{"reserved flag", nil, 3, 0, nil, "wrong unsubscribe reserved flags"},
		{"ensure packet loaded", []byte{1, 2}, mqttUnsubscribeFlags, 10, eofr, "error ensuring protocol is loaded"},
		{"error reading packet id", []byte{1}, mqttUnsubscribeFlags, 1, eofr, "reading packet identifier"},
		{"missing filters", []byte{0, 1}, mqttUnsubscribeFlags, 2, nil, "subscribe protocol must contain at least 1 topic filter"},
		{"error reading topic", []byte{0, 1, 0, 2, 'a'}, mqttUnsubscribeFlags, 5, eofr, "topic filter"},
		{"empty topic", []byte{0, 1, 0, 0}, mqttUnsubscribeFlags, 4, nil, errMQTTTopicFilterCannotBeEmpty.Error()},
		{"invalid utf8 topic", []byte{0, 1, 0, 1, 241}, mqttUnsubscribeFlags, 5, nil, "invalid utf8 for topic filter"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			rd := &mqttReader{reader: tc.reader}
			rd.reset(tc.proto)
			cli := &client{mqtt: &mqtt{r: rd}}
			_, _, err := cli.mqttParseSubsOrUnsubs(rd, tc.b, tc.pl, false)
			if err == nil || !strings.Contains(err.Error(), tc.err) {
				t.Fatalf("Expected error %q, got %v", tc.err, err)
			}
		})
	}
}
// testMQTTUnsub writes an UNSUBSCRIBE packet with identifier pi for filters
// to c. It then reads the reply from r and fails the test unless it is an
// UNSUBACK carrying the same packet identifier. Only each filter's topic
// string is written; its qos field is not used.
func testMQTTUnsub(t *testing.T, pi uint16, c net.Conn, r *mqttReader, filters []*mqttFilter) {
	t.Helper()

	// Remaining length: 2 bytes for the packet identifier, plus, per filter,
	// a 2-byte length prefix followed by the topic bytes.
	remaining := 2
	for _, f := range filters {
		remaining += 2 + len(f.filter)
	}

	pkt := &mqttWriter{}
	pkt.WriteByte(mqttPacketUnsub | mqttUnsubscribeFlags)
	pkt.WriteVarInt(remaining)
	pkt.WriteUint16(pi)
	for _, f := range filters {
		pkt.WriteBytes([]byte(f.filter))
	}
	if _, err := testMQTTWrite(c, pkt.Bytes()); err != nil {
		t.Fatalf("Error writing UNSUBSCRIBE protocol: %v", err)
	}

	// Make sure at least 1 byte is buffered, reading from the connection if needed.
	testMQTTReaderHasAtLeastOne(t, r)

	ptype, err := r.readByte("packet type")
	if err != nil {
		t.Fatal(err)
	}
	if pt := ptype & mqttPacketMask; pt != mqttPacketUnsubAck {
		t.Fatalf("Expected UNSUBACK packet %x, got %x", mqttPacketUnsubAck, pt)
	}
	plen, err := r.readPacketLen()
	if err != nil {
		t.Fatal(err)
	}
	if err := r.ensurePacketInBuffer(plen); err != nil {
		t.Fatal(err)
	}
	if rpi, err := r.readUint16("packet identifier"); err != nil || rpi != pi {
		t.Fatalf("Error with packet identifier expected=%v got: %v err=%v", pi, rpi, err)
	}
}
// TestMQTTUnsub checks that UNSUBSCRIBE removes exactly the given filters.
// After unsubscribing "foo", nothing more is received on it. With both
// "foo/bar" and "foo/#" subscribed, a message is received twice; after
// removing "foo/#" it is received once, and after removing "foo/bar" not at all.
func TestMQTTUnsub(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)
	mcp, mpr := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
	defer mcp.Close()
	testMQTTCheckConnAck(t, mpr, mqttConnAckRCConnectionAccepted, false)
	mc, r := testMQTTConnect(t, &mqttConnInfo{user: "sub", cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
	defer mc.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
	testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: "foo", qos: 0}}, []byte{0})
	testMQTTFlush(t, mc, nil, r)
	// All publishes below are written on mcp, so they are paired with mcp's own
	// reader (mpr). The subscriber's reader r is only used to check what mc receives.
	// Publish and test msg received
	testMQTTPublish(t, mcp, mpr, 0, false, false, "foo", 0, []byte("msg"))
	testMQTTCheckPubMsg(t, mc, r, "foo", 0, []byte("msg"))
	// Unsubscribe
	testMQTTUnsub(t, 1, mc, r, []*mqttFilter{{filter: "foo"}})
	// Publish and test msg not received
	testMQTTPublish(t, mcp, mpr, 0, false, false, "foo", 0, []byte("msg"))
	testMQTTExpectNothing(t, r)
	// Use of wildcards subs
	filters := []*mqttFilter{
		{filter: "foo/bar", qos: 0},
		{filter: "foo/#", qos: 0},
	}
	testMQTTSub(t, 1, mc, r, filters, []byte{0, 0})
	testMQTTFlush(t, mc, nil, r)
	// Publish and check that message received twice
	testMQTTPublish(t, mcp, mpr, 0, false, false, "foo/bar", 0, []byte("msg"))
	testMQTTCheckPubMsg(t, mc, r, "foo/bar", 0, []byte("msg"))
	testMQTTCheckPubMsg(t, mc, r, "foo/bar", 0, []byte("msg"))
	// Unsub the wildcard one
	testMQTTUnsub(t, 1, mc, r, []*mqttFilter{{filter: "foo/#"}})
	// Publish and check that message received once
	testMQTTPublish(t, mcp, mpr, 0, false, false, "foo/bar", 0, []byte("msg"))
	testMQTTCheckPubMsg(t, mc, r, "foo/bar", 0, []byte("msg"))
	testMQTTExpectNothing(t, r)
	// Unsub last
	testMQTTUnsub(t, 1, mc, r, []*mqttFilter{{filter: "foo/bar"}})
	// Publish and test msg not received
	testMQTTPublish(t, mcp, mpr, 0, false, false, "foo/bar", 0, []byte("msg"))
	testMQTTExpectNothing(t, r)
}
// testMQTTExpectDisconnect fails the test unless reading from c returns an
// error. The tests use this to detect that the server closed the connection;
// any data read instead is included in the failure message.
func testMQTTExpectDisconnect(t testing.TB, c net.Conn) {
	// Mark as helper, like the other testMQTT* helpers, so failures are reported
	// at the caller (including when reached through testMQTTDisconnect).
	t.Helper()
	if buf, err := testMQTTRead(c); err == nil {
		t.Fatalf("Expected connection to be disconnected, got %s", buf)
	}
}
// TestMQTTPublishTopicErrors checks that the server closes the connection of a
// client that publishes on an invalid topic: an empty one, or one containing a
// "+" or "#" wildcard.
func TestMQTTPublishTopicErrors(t *testing.T) {
	opts := testMQTTDefaultOptions()
	srv := testMQTTRunServer(t, opts)
	defer testMQTTShutdownServer(srv)

	type topicCase struct{ name, topic string }
	cases := []topicCase{
		{name: "empty", topic: ""},
		{name: "with single level wildcard", topic: "foo/+"},
		{name: "with multiple level wildcard", topic: "foo/#"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			conn, reader := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, opts.MQTT.Host, opts.MQTT.Port)
			defer conn.Close()
			testMQTTCheckConnAck(t, reader, mqttConnAckRCConnectionAccepted, false)

			testMQTTPublish(t, conn, reader, 0, false, false, tc.topic, 0, []byte("msg"))
			testMQTTExpectDisconnect(t, conn)
		})
	}
}
// testMQTTDisconnect sends an MQTT DISCONNECT packet and then expects the
// server to close the connection. When bw is non-nil the packet goes through
// it (and it is flushed); otherwise it is written directly to c. Write
// errors are not checked, and the test then waits for the disconnect on c.
func testMQTTDisconnect(t testing.TB, c net.Conn, bw *bufio.Writer) {
	t.Helper()
	pkt := &mqttWriter{}
	pkt.WriteByte(mqttPacketDisconnect)
	// Remaining length of 0: the packet has nothing after the fixed header.
	pkt.WriteByte(0)
	if bw == nil {
		c.Write(pkt.Bytes())
	} else {
		bw.Write(pkt.Bytes())
		bw.Flush()
	}
	testMQTTExpectDisconnect(t, c)
}
// TestMQTTWill checks that a client's Will message is published to both MQTT
// and NATS subscribers when its connection is closed without a DISCONNECT
// packet. It also checks that the Will is not published after a proper
// DISCONNECT.
func TestMQTTWill(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)
	nc := natsConnect(t, s.ClientURL())
	defer nc.Close()
	// NATS subscription on the subject the Will topic "will/topic" maps to;
	// shared by all subtests.
	sub := natsSubSync(t, nc, "will.topic")
	willMsg := []byte("bye")
	for _, test := range []struct {
		name         string
		willExpected bool
		willQoS      byte
	}{
		{"will qos 0", true, 0},
		{"will qos 1", true, 1},
		{"proper disconnect no will", false, 0},
	} {
		t.Run(test.name, func(t *testing.T) {
			mcs, rs := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
			defer mcs.Close()
			testMQTTCheckConnAck(t, rs, mqttConnAckRCConnectionAccepted, false)
			testMQTTSub(t, 1, mcs, rs, []*mqttFilter{{filter: "will/#", qos: 1}}, []byte{1})
			testMQTTFlush(t, mcs, nil, rs)
			mc, r := testMQTTConnect(t,
				&mqttConnInfo{
					cleanSess: true,
					will: &mqttWill{
						topic:   []byte("will/topic"),
						message: willMsg,
						qos:     test.willQoS,
					},
				}, o.MQTT.Host, o.MQTT.Port)
			defer mc.Close()
			testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
			if test.willExpected {
				// Drop the connection without DISCONNECT so the Will is sent.
				mc.Close()
				// The QoS-1 subscriber gets the Will at the Will's QoS,
				// shifted into the PUBLISH flags position.
				testMQTTCheckPubMsg(t, mcs, rs, "will/topic", test.willQoS<<1, willMsg)
				wm := natsNexMsg(t, sub, time.Second)
				if !bytes.Equal(wm.Data, willMsg) {
					t.Fatalf("Expected will message to be %q, got %q", willMsg, wm.Data)
				}
			} else {
				testMQTTDisconnect(t, mc, nil)
				testMQTTExpectNothing(t, rs)
				if wm, err := sub.NextMsg(100 * time.Millisecond); err == nil {
					t.Fatalf("Should not have received a message, got subj=%q data=%q",
						wm.Subject, wm.Data)
				}
			}
		})
	}
}
// TestMQTTWillRetain checks retained Will messages. A Will with the retain
// flag is published when the client connection is dropped. A later subscriber
// must then get it with the retain flag set, at the lesser of the Will
// (publish) QoS and the subscription QoS.
func TestMQTTWillRetain(t *testing.T) {
	cases := []struct {
		name   string
		pubQoS byte
		subQoS byte
	}{
		{"pub QoS0 sub QoS0", 0, 0},
		{"pub QoS0 sub QoS1", 0, 1},
		{"pub QoS1 sub QoS0", 1, 0},
		{"pub QoS1 sub QoS1", 1, 1},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			opts := testMQTTDefaultOptions()
			srv := testMQTTRunServer(t, opts)
			defer testMQTTShutdownServer(srv)

			willTopic := []byte("will/topic")
			willMsg := []byte("bye")
			will := &mqttWill{
				topic:   willTopic,
				message: willMsg,
				qos:     tc.pubQoS,
				retain:  true,
			}
			conn, reader := testMQTTConnect(t, &mqttConnInfo{cleanSess: true, will: will}, opts.MQTT.Host, opts.MQTT.Port)
			defer conn.Close()
			testMQTTCheckConnAck(t, reader, mqttConnAckRCConnectionAccepted, false)

			// Drop the connection, then wait for the server to process the
			// close, which causes the Will to be published (and retained).
			conn.Close()
			checkClientsCount(t, srv, 0)

			// A new subscriber on the Will topic must get the retained Will.
			subConn, subReader := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, opts.MQTT.Host, opts.MQTT.Port)
			defer subConn.Close()
			testMQTTCheckConnAck(t, subReader, mqttConnAckRCConnectionAccepted, false)
			testMQTTSub(t, 1, subConn, subReader, []*mqttFilter{{filter: "will/#", qos: tc.subQoS}}, []byte{tc.subQoS})

			pflags, _ := testMQTTGetPubMsg(t, subConn, subReader, "will/topic", willMsg)
			if pflags&mqttPubFlagRetain == 0 {
				t.Fatalf("expected retain flag to be set, it was not: %v", pflags)
			}

			// Lesser of the publish and subscription QoS.
			expectedQoS := tc.pubQoS
			if tc.subQoS < expectedQoS {
				expectedQoS = tc.subQoS
			}
			if qos := mqttGetQoS(pflags); qos != expectedQoS {
				t.Fatalf("expected qos to be %v, got %v", expectedQoS, qos)
			}
		})
	}
}
// TestMQTTWillRetainPermViolation checks that retained Will messages respect
// publish permissions. A retained Will on an allowed subject ("foo") is
// stored and delivered. A retained Will on a forbidden subject ("bar") is
// never delivered. After a config reload removes the publish permission on
// "foo", the previously stored Will is no longer delivered either.
func TestMQTTWillRetainPermViolation(t *testing.T) {
	// Server config; %s is the only subject the "mqtt" user may publish to.
	template := `
port: -1
jetstream: enabled
authorization {
mqtt_perms = {
publish = ["%s"]
subscribe = ["foo", "bar", "$MQTT.sub.>"]
}
users = [
{user: mqtt, password: pass, permissions: $mqtt_perms}
]
}
mqtt {
port: -1
}
`
	conf := createConfFile(t, []byte(fmt.Sprintf(template, "foo")))
	defer removeFile(t, conf)
	s, o := RunServerWithConfig(conf)
	defer testMQTTShutdownServer(s)
	ci := &mqttConnInfo{
		cleanSess: true,
		user:      "mqtt",
		pass:      "pass",
	}
	// We create first a connection with the Will topic that the publisher
	// is allowed to publish to.
	ci.will = &mqttWill{
		topic:   []byte("foo"),
		message: []byte("bye"),
		qos:     1,
		retain:  true,
	}
	mc, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port)
	defer mc.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
	// Disconnect, which will cause the Will to be sent with retain flag.
	mc.Close()
	// Wait for the server to process the connection close, which will
	// cause the "will" message to be published (and retained).
	checkClientsCount(t, s, 0)
	// Create a subscription on the Will subject and we should receive it.
	ci.will = nil
	mcs, rs := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port)
	defer mcs.Close()
	testMQTTCheckConnAck(t, rs, mqttConnAckRCConnectionAccepted, false)
	testMQTTSub(t, 1, mcs, rs, []*mqttFilter{{filter: "foo", qos: 1}}, []byte{1})
	pflags, _ := testMQTTGetPubMsg(t, mcs, rs, "foo", []byte("bye"))
	if pflags&mqttPubFlagRetain == 0 {
		t.Fatalf("expected retain flag to be set, it was not: %v", pflags)
	}
	if qos := mqttGetQoS(pflags); qos != 1 {
		t.Fatalf("expected qos to be 1, got %v", qos)
	}
	testMQTTDisconnect(t, mcs, nil)
	// Now create another connection with a Will that client is not allowed to publish to.
	ci.will = &mqttWill{
		topic:   []byte("bar"),
		message: []byte("bye"),
		qos:     1,
		retain:  true,
	}
	mc, r = testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port)
	defer mc.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
	// Disconnect, to cause Will to be produced, but in that case should not be stored
	// since user not allowed to publish on "bar".
	mc.Close()
	// Wait for the server to process the connection close. This is when the
	// "will" message is produced; it must NOT be retained, because the user
	// is not allowed to publish on "bar".
	checkClientsCount(t, s, 0)
	// Create sub on "bar" which user is allowed to subscribe to.
	ci.will = nil
	mcs, rs = testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port)
	defer mcs.Close()
	testMQTTCheckConnAck(t, rs, mqttConnAckRCConnectionAccepted, false)
	testMQTTSub(t, 1, mcs, rs, []*mqttFilter{{filter: "bar", qos: 1}}, []byte{1})
	// No Will should be published since it should not have been stored in the first place.
	testMQTTExpectNothing(t, rs)
	testMQTTDisconnect(t, mcs, nil)
	// Now remove permission to publish on "foo" and check that a new subscription
	// on "foo" is now not getting the will message because the original user no
	// longer has permission to do so.
	reloadUpdateConfig(t, s, conf, fmt.Sprintf(template, "baz"))
	mcs, rs = testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port)
	defer mcs.Close()
	testMQTTCheckConnAck(t, rs, mqttConnAckRCConnectionAccepted, false)
	testMQTTSub(t, 1, mcs, rs, []*mqttFilter{{filter: "foo", qos: 1}}, []byte{1})
	testMQTTExpectNothing(t, rs)
	testMQTTDisconnect(t, mcs, nil)
}
// TestMQTTPublishRetain checks what a subscriber joining after a publish sees
// from the retained-message store: a retained publish is stored, a
// non-retained publish leaves the stored value untouched, and a retained
// publish with an empty payload clears it.
func TestMQTTPublishRetain(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)

	cases := []struct {
		name     string
		retain   bool
		sent     string
		want     string
		delivers bool
	}{
		{name: "publish retained", retain: true, sent: "retained", want: "retained", delivers: true},
		// The previously retained value must still be what a new subscriber gets.
		{name: "publish not retained", retain: false, sent: "not retained", want: "retained", delivers: true},
		// An empty retained payload removes the retained message.
		{name: "remove retained", retain: true, sent: "", want: "", delivers: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			pub, pubR := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
			defer pub.Close()
			testMQTTCheckConnAck(t, pubR, mqttConnAckRCConnectionAccepted, false)
			testMQTTPublish(t, pub, pubR, 0, false, tc.retain, "foo", 0, []byte(tc.sent))
			testMQTTFlush(t, pub, nil, pubR)

			sub, subR := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
			defer sub.Close()
			testMQTTCheckConnAck(t, subR, mqttConnAckRCConnectionAccepted, false)
			testMQTTSub(t, 1, sub, subR, []*mqttFilter{{filter: "foo/#", qos: 1}}, []byte{1})

			if !tc.delivers {
				testMQTTExpectNothing(t, subR)
			} else {
				flags, _ := testMQTTGetPubMsg(t, sub, subR, "foo", []byte(tc.want))
				if flags&mqttPubFlagRetain == 0 {
					t.Fatalf("retain flag should have been set, it was not: flags=%v", flags)
				}
			}

			testMQTTDisconnect(t, pub, nil)
			testMQTTDisconnect(t, sub, nil)
		})
	}
}
func TestMQTTPublishRetainPermViolation(t *testing.T) {
o := testMQTTDefaultOptions()
o.Users = []*User{
{
Username: "mqtt",
Password: "pass",
Permissions: &Permissions{
Publish: &SubjectPermission{Allow: []string{"foo"}},
Subscribe: &SubjectPermission{Allow: []string{"bar", "$MQTT.sub.>"}},
},
},
}
s := testMQTTRunServer(t, o)
defer testMQTTShutdownServer(s)
ci := &mqttConnInfo{
cleanSess: true,
user: "mqtt",
pass: "pass",
}
mc1, rs1 := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port)
defer mc1.Close()
testMQTTCheckConnAck(t, rs1, mqttConnAckRCConnectionAccepted, false)
testMQTTPublish(t, mc1, rs1, 0, false, true, "bar", 0, []byte("retained"))
testMQTTFlush(t, mc1, nil, rs1)
mc2, rs2 := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port)
defer mc2.Close()
testMQTTCheckConnAck(t, rs2, mqttConnAckRCConnectionAccepted, false)
testMQTTSub(t, 1, mc2, rs2, []*mqttFilter{{filter: "bar", qos: 1}}, []byte{1})
testMQTTExpectNothing(t, rs2)
testMQTTDisconnect(t, mc1, nil)
testMQTTDisconnect(t, mc2, nil)
}
func TestMQTTPublishViolation(t *testing.T) {
o := testMQTTDefaultOptions()
o.Users = []*User{
{
Username: "mqtt",
Password: "pass",
Permissions: &Permissions{
Publish: &SubjectPermission{Allow: []string{"foo.bar"}},
Subscribe: &SubjectPermission{Allow: []string{"foo.*", "$MQTT.sub.>"}},
},
},
}
s := testMQTTRunServer(t, o)
defer testMQTTShutdownServer(s)
ci := &mqttConnInfo{
user: "mqtt",
pass: "pass",
}
ci.clientID = "sub"
mc, rc := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port)
defer mc.Close()
testMQTTCheckConnAck(t, rc, mqttConnAckRCConnectionAccepted, false)
testMQTTSub(t, 1, mc, rc, []*mqttFilter{{filter: "foo/+", qos: 1}}, []byte{1})
testMQTTFlush(t, mc, nil, rc)
ci.clientID = "pub"
ci.cleanSess = true
mp, rp := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port)
defer mp.Close()
testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false)
// These should be received since publisher has the right to publish on foo.bar
testMQTTPublish(t, mp, rp, 0, false, false, "foo/bar", 0, []byte("msg1"))
testMQTTCheckPubMsg(t, mc, rc, "foo/bar", 0, []byte("msg1"))
testMQTTPublish(t, mp, rp, 1, false, false, "foo/bar", 1, []byte("msg2"))
testMQTTCheckPubMsg(t, mc, rc, "foo/bar", mqttPubQos1, []byte("msg2"))
// But these should not be cause pub has no permission to publish on foo.baz
testMQTTPublish(t, mp, rp, 0, false, false, "foo/baz", 0, []byte("msg3"))
testMQTTExpectNothing(t, rc)
testMQTTPublish(t, mp, rp, 1, false, false, "foo/baz", 1, []byte("msg4"))
testMQTTExpectNothing(t, rc)
// Disconnect publisher
testMQTTDisconnect(t, mp, nil)
mp.Close()
// Disconnect subscriber and restart it to make sure that it does not receive msg3/msg4
testMQTTDisconnect(t, mc, nil)
mc.Close()
ci.cleanSess = false
ci.clientID = "sub"
mc, rc = testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port)
defer mc.Close()
testMQTTCheckConnAck(t, rc, mqttConnAckRCConnectionAccepted, true)
testMQTTSub(t, 1, mc, rc, []*mqttFilter{{filter: "foo/+", qos: 1}}, []byte{1})
testMQTTExpectNothing(t, rc)
}
// TestMQTTCleanSession reconnects the same client ID three times and checks
// the session-present expectation passed to testMQTTCheckConnAck. It is false
// on the first persistent connect, true when resuming that session, and false
// again when connecting with cleanSess=true.
func TestMQTTCleanSession(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)

	ci := &mqttConnInfo{clientID: "me", cleanSess: false}

	steps := []struct {
		cleanSess      bool
		sessionPresent bool
	}{
		{cleanSess: false, sessionPresent: false},
		{cleanSess: false, sessionPresent: true},
		{cleanSess: true, sessionPresent: false},
	}
	for _, step := range steps {
		ci.cleanSess = step.cleanSess
		c, r := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port)
		defer c.Close()
		testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, step.sessionPresent)
		testMQTTDisconnect(t, c, nil)
	}
}
// TestMQTTDuplicateClientID checks that a second connection reusing an active
// client ID is accepted (resuming the existing session) and that the server
// then disconnects the original connection.
func TestMQTTDuplicateClientID(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)

	ci := &mqttConnInfo{clientID: "me", cleanSess: false}

	first, firstR := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port)
	defer first.Close()
	testMQTTCheckConnAck(t, firstR, mqttConnAckRCConnectionAccepted, false)

	// Same client ID while the first connection is still open.
	second, secondR := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port)
	defer second.Close()
	testMQTTCheckConnAck(t, secondR, mqttConnAckRCConnectionAccepted, true)

	// The server must have kicked out the first connection.
	testMQTTExpectDisconnect(t, first)
}
// TestMQTTPersistedSession verifies that a session created with
// cleanSess=false survives server restarts. Its subscriptions (with their
// QoS) are restored without the client re-subscribing. A QoS1 message
// published while the subscriber is offline is delivered when it reconnects.
// An unsubscribe and a QoS downgrade are both persisted. Reconnecting with
// cleanSess=true discards the session.
func TestMQTTPersistedSession(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	// s is passed by pointer, presumably so the deferred shutdown targets the
	// last restarted server (s is reassigned several times below).
	defer testMQTTShutdownRestartedServer(&s)

	cisub := &mqttConnInfo{clientID: "sub", cleanSess: false}
	cipub := &mqttConnInfo{clientID: "pub", cleanSess: true}

	c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port)
	defer c.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
	// Two QoS1 subscriptions and one QoS0 subscription in a single SUBSCRIBE.
	testMQTTSub(t, 1, c, r,
		[]*mqttFilter{
			{filter: "foo/#", qos: 1},
			{filter: "bar", qos: 1},
			{filter: "baz", qos: 0},
		},
		[]byte{1, 1, 0})
	testMQTTFlush(t, c, nil, r)

	// Shutdown server, close connection and restart server. It should
	// have restored the session and consumers.
	// The same store directory is reused so the persisted state is found.
	dir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir)
	s.Shutdown()
	c.Close()
	o.Port = -1
	o.MQTT.Port = -1
	o.StoreDir = dir
	s = testMQTTRunServer(t, o)
	// There is already the defer for shutdown at top of function

	// Create a publisher that will send qos1 so we verify that messages
	// are stored for the persisted sessions.
	c, r = testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port)
	defer c.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
	testMQTTPublish(t, c, r, 1, false, false, "foo/bar", 1, []byte("msg0"))
	testMQTTFlush(t, c, nil, r)
	testMQTTDisconnect(t, c, nil)
	c.Close()

	// Recreate session
	c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port)
	defer c.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true)
	// Since consumers have been recovered, messages should be received
	// (MQTT does not need client to recreate consumers for a recovered
	// session)
	// Check that qos1 publish message is received.
	testMQTTCheckPubMsg(t, c, r, "foo/bar", mqttPubQos1, []byte("msg0"))
	// Flush to prevent publishes to be done too soon since we are
	// receiving the CONNACK before the subscriptions are restored.
	testMQTTFlush(t, c, nil, r)

	// Now publish some messages to all subscriptions.
	// Plain NATS publishes are delivered to the MQTT subscriber with QoS 0
	// (NATS subject "foo.bar" maps to MQTT topic "foo/bar").
	nc := natsConnect(t, s.ClientURL())
	defer nc.Close()
	natsPub(t, nc, "foo.bar", []byte("msg1"))
	testMQTTCheckPubMsg(t, c, r, "foo/bar", 0, []byte("msg1"))
	natsPub(t, nc, "foo", []byte("msg2"))
	testMQTTCheckPubMsg(t, c, r, "foo", 0, []byte("msg2"))
	natsPub(t, nc, "bar", []byte("msg3"))
	testMQTTCheckPubMsg(t, c, r, "bar", 0, []byte("msg3"))
	natsPub(t, nc, "baz", []byte("msg4"))
	testMQTTCheckPubMsg(t, c, r, "baz", 0, []byte("msg4"))

	// Now unsub "bar" and verify that message published on this topic
	// is not received.
	testMQTTUnsub(t, 1, c, r, []*mqttFilter{{filter: "bar"}})
	natsPub(t, nc, "bar", []byte("msg5"))
	testMQTTExpectNothing(t, r)

	// Second restart: the unsubscribe of "bar" must have been persisted.
	nc.Close()
	s.Shutdown()
	c.Close()
	o.Port = -1
	o.MQTT.Port = -1
	o.StoreDir = dir
	s = testMQTTRunServer(t, o)
	// There is already the defer for shutdown at top of function

	// Recreate a client
	c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port)
	defer c.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true)
	nc = natsConnect(t, s.ClientURL())
	defer nc.Close()
	natsPub(t, nc, "foo.bar", []byte("msg6"))
	testMQTTCheckPubMsg(t, c, r, "foo/bar", 0, []byte("msg6"))
	natsPub(t, nc, "foo", []byte("msg7"))
	testMQTTCheckPubMsg(t, c, r, "foo", 0, []byte("msg7"))
	// Make sure that we did not recover bar.
	natsPub(t, nc, "bar", []byte("msg8"))
	testMQTTExpectNothing(t, r)
	natsPub(t, nc, "baz", []byte("msg9"))
	testMQTTCheckPubMsg(t, c, r, "baz", 0, []byte("msg9"))

	// Have the sub client send a subscription downgrading the qos1 subscription.
	testMQTTSub(t, 1, c, r, []*mqttFilter{{filter: "foo/#", qos: 0}}, []byte{0})
	testMQTTFlush(t, c, nil, r)

	// Third restart: the QoS downgrade of "foo/#" must have been persisted.
	nc.Close()
	s.Shutdown()
	c.Close()
	o.Port = -1
	o.MQTT.Port = -1
	o.StoreDir = dir
	s = testMQTTRunServer(t, o)
	// There is already the defer for shutdown at top of function

	// Recreate the sub client
	c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port)
	defer c.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true)
	// Publish as a qos1
	c2, r2 := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port)
	defer c2.Close()
	testMQTTCheckConnAck(t, r2, mqttConnAckRCConnectionAccepted, false)
	testMQTTPublish(t, c2, r2, 1, false, false, "foo/bar", 1, []byte("msg10"))
	// Verify that it is received as qos0 which is the qos of the subscription.
	testMQTTCheckPubMsg(t, c, r, "foo/bar", 0, []byte("msg10"))
	testMQTTDisconnect(t, c, nil)
	c.Close()
	testMQTTDisconnect(t, c2, nil)
	c2.Close()

	// Finally, recreate the sub with clean session and ensure that all is gone
	// Done twice: the second iteration runs after another restart, checking
	// that the discarded session was not recovered from storage either.
	cisub.cleanSess = true
	for i := 0; i < 2; i++ {
		c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port)
		defer c.Close()
		testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
		nc = natsConnect(t, s.ClientURL())
		defer nc.Close()
		natsPub(t, nc, "foo.bar", []byte("msg11"))
		testMQTTExpectNothing(t, r)
		natsPub(t, nc, "foo", []byte("msg12"))
		testMQTTExpectNothing(t, r)
		// Make sure that we did not recover bar.
		natsPub(t, nc, "bar", []byte("msg13"))
		testMQTTExpectNothing(t, r)
		natsPub(t, nc, "baz", []byte("msg14"))
		testMQTTExpectNothing(t, r)
		testMQTTDisconnect(t, c, nil)
		c.Close()
		nc.Close()
		s.Shutdown()
		o.Port = -1
		o.MQTT.Port = -1
		o.StoreDir = dir
		s = testMQTTRunServer(t, o)
		// There is already the defer for shutdown at top of function
	}
}
// TestMQTTRecoverSessionAndAddNewSub verifies that a client can add a new
// subscription to a persisted session, both when the session was recovered
// from storage after a server restart and when it was resumed without one.
func TestMQTTRecoverSessionAndAddNewSub(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownRestartedServer(&s)

	cisub := &mqttConnInfo{clientID: "sub1", cleanSess: false}
	c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port)
	defer c.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
	testMQTTDisconnect(t, c, nil)
	c.Close()

	// The client connection is already closed above. Shut the server down
	// and restart it on the same store directory so that the session is
	// recovered from storage.
	dir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir)
	s.Shutdown()
	o.Port = -1
	o.MQTT.Port = -1
	o.StoreDir = dir
	s = testMQTTRunServer(t, o)
	// No need for defer since it is done top of function

	c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port)
	defer c.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true)
	// Now add sub and make sure it does not crash
	testMQTTSub(t, 1, c, r, []*mqttFilter{{filter: "foo", qos: 1}}, []byte{1})
	testMQTTFlush(t, c, nil, r)

	// Now repeat with a new client but without server restart.
	cisub2 := &mqttConnInfo{clientID: "sub2", cleanSess: false}
	c2, r2 := testMQTTConnect(t, cisub2, o.MQTT.Host, o.MQTT.Port)
	defer c2.Close()
	testMQTTCheckConnAck(t, r2, mqttConnAckRCConnectionAccepted, false)
	testMQTTDisconnect(t, c2, nil)
	c2.Close()

	c2, r2 = testMQTTConnect(t, cisub2, o.MQTT.Host, o.MQTT.Port)
	defer c2.Close()
	testMQTTCheckConnAck(t, r2, mqttConnAckRCConnectionAccepted, true)
	testMQTTSub(t, 1, c2, r2, []*mqttFilter{{filter: "bar", qos: 1}}, []byte{1})
	testMQTTFlush(t, c2, nil, r2)
}
// TestMQTTRecoverSessionWithSubAndClientResendSub checks that when a client
// resumes a persisted session and re-sends a SUBSCRIBE the session already
// holds, the server does not create duplicate NATS subscriptions. This is
// checked once across a server restart and once without one.
func TestMQTTRecoverSessionWithSubAndClientResendSub(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownRestartedServer(&s)

	// checkNumSub asserts that the MQTT client with the given ID holds exactly
	// one NATS subscription on "foo" plus exactly one other subscription (the
	// JS durable consumer's delivery subject). The closure reads the current
	// value of s, so it keeps working after s is reassigned by a restart.
	checkNumSub := func(clientID string) {
		t.Helper()
		mc := testMQTTGetClient(t, s, clientID)
		fooSub, otherSub := 0, 0
		mc.mu.Lock()
		for _, sub := range mc.subs {
			if string(sub.subject) == "foo" {
				fooSub++
			} else {
				otherSub++
			}
		}
		mc.mu.Unlock()
		if fooSub != 1 {
			t.Fatalf("Expected 1 sub on 'foo', got %v", fooSub)
		}
		if otherSub != 1 {
			t.Fatalf("Expected 1 subscription for JS durable, got %v", otherSub)
		}
	}

	// Phase 1: subscribe, restart the server, then resume and re-subscribe.
	first := &mqttConnInfo{clientID: "sub1", cleanSess: false}
	c, r := testMQTTConnect(t, first, o.MQTT.Host, o.MQTT.Port)
	defer c.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
	testMQTTSub(t, 1, c, r, []*mqttFilter{{filter: "foo", qos: 1}}, []byte{1})
	testMQTTDisconnect(t, c, nil)
	c.Close()

	storeDir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir)
	s.Shutdown()
	o.Port = -1
	o.MQTT.Port = -1
	o.StoreDir = storeDir
	s = testMQTTRunServer(t, o)
	// Shutdown of this server is covered by the defer at the top.

	// Because cleanSess is false, the recovered session already has the
	// "foo" subscription. Re-send it, as a restarting application likely would.
	c, r = testMQTTConnect(t, first, o.MQTT.Host, o.MQTT.Port)
	defer c.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true)
	testMQTTSub(t, 1, c, r, []*mqttFilter{{filter: "foo", qos: 1}}, []byte{1})
	testMQTTFlush(t, c, nil, r)
	checkNumSub("sub1")
	c.Close()

	// Phase 2: same scenario, but without restarting the server.
	second := &mqttConnInfo{clientID: "sub2", cleanSess: false}
	c, r = testMQTTConnect(t, second, o.MQTT.Host, o.MQTT.Port)
	defer c.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
	testMQTTSub(t, 1, c, r, []*mqttFilter{{filter: "foo", qos: 1}}, []byte{1})
	testMQTTDisconnect(t, c, nil)
	c.Close()

	c, r = testMQTTConnect(t, second, o.MQTT.Host, o.MQTT.Port)
	defer c.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true)
	testMQTTSub(t, 1, c, r, []*mqttFilter{{filter: "foo", qos: 1}}, []byte{1})
	testMQTTFlush(t, c, nil, r)
	checkNumSub("sub2")
}
// TestMQTTFlappingSession checks that a client ID whose session was taken
// over by another connection is put in the account's flappers map and
// rejected when it immediately tries to reconnect. It also checks that the
// entry has been removed once the (shortened) cleanup interval has elapsed.
func TestMQTTFlappingSession(t *testing.T) {
	// Shorten the jail duration and cleanup interval for this test and
	// restore the package defaults when done.
	mqttSessJailDur = 250 * time.Millisecond
	mqttFlapCleanItvl = 350 * time.Millisecond
	defer func() {
		mqttSessJailDur = mqttSessFlappingJailDur
		mqttFlapCleanItvl = mqttSessFlappingCleanupInterval
	}()

	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)

	ci := &mqttConnInfo{clientID: "flapper", cleanSess: false}
	first, firstR := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port)
	defer first.Close()
	testMQTTCheckConnAck(t, firstR, mqttConnAckRCConnectionAccepted, false)

	// Keep the account session manager around to inspect its flappers map.
	asm := testMQTTGetClient(t, s, "flapper").mqtt.asm

	// Taking over the client ID kicks out the first connection.
	second, secondR := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port)
	defer second.Close()
	testMQTTCheckConnAck(t, secondR, mqttConnAckRCConnectionAccepted, true)
	testMQTTExpectDisconnect(t, first)

	// Reconnecting the kicked-out client must fail. This is done by hand,
	// since testMQTTConnect would treat the failure as a test error.
	raw, err := net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port))
	if err != nil {
		t.Fatalf("Error creating mqtt connection: %v", err)
	}
	defer raw.Close()
	if _, err := testMQTTWrite(raw, mqttCreateConnectProto(ci)); err != nil {
		t.Fatalf("Error writing connect: %v", err)
	}
	if _, err := testMQTTRead(raw); err == nil {
		t.Fatal("Expected connection to fail")
	}

	isFlapping := func() bool {
		asm.mu.RLock()
		_, found := asm.flappers["flapper"]
		asm.mu.RUnlock()
		return found
	}

	if !isFlapping() {
		t.Fatal("Did not find the client ID in the flappers map")
	}
	// Give the cleanup routine time to run: one cleanup interval plus margin.
	time.Sleep(mqttFlapCleanItvl + 100*time.Millisecond)
	if isFlapping() {
		t.Fatal("The client ID should have been cleared from the map")
	}
}
// TestMQTTLockedSession checks that a connection for a client ID whose
// session is locked is rejected, and that a connection attempt made while
// the lock is held succeeds once the lock is released shortly afterwards.
func TestMQTTLockedSession(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)

	ci := &mqttConnInfo{clientID: "sub", cleanSess: false}
	owner, ownerR := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port)
	defer owner.Close()
	testMQTTCheckConnAck(t, ownerR, mqttConnAckRCConnectionAccepted, false)

	// Look up the session manager of the global account.
	sm := &s.mqtt.sessmgr
	sm.mu.Lock()
	asm := sm.sessions[globalAccountName]
	sm.mu.Unlock()
	if asm == nil {
		t.Fatalf("account session manager not found")
	}

	// Simulate the "sub" session being locked.
	cli := testMQTTGetClient(t, s, "sub")
	sess := cli.mqtt.sess
	if err := asm.lockSession(sess, cli); err != nil {
		t.Fatalf("Unable to lock session: %v", err)
	}
	defer asm.unlockSession(sess)

	// A new connection for "sub" must fail while the session is locked. This
	// is done by hand because testMQTTConnect() would report the failure.
	raw, err := net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port))
	if err != nil {
		t.Fatalf("Error creating mqtt connection: %v", err)
	}
	defer raw.Close()
	if _, err := testMQTTWrite(raw, mqttCreateConnectProto(ci)); err != nil {
		t.Fatalf("Error writing connect: %v", err)
	}
	if _, err := testMQTTRead(raw); err == nil {
		t.Fatal("Expected connection to fail")
	}

	// Release the lock while the next connect is in progress: it should succeed.
	time.AfterFunc(250*time.Millisecond, func() { asm.unlockSession(sess) })
	waiter, waiterR := testMQTTConnect(t, ci, o.MQTT.Host, o.MQTT.Port)
	defer waiter.Close()
	testMQTTCheckConnAck(t, waiterR, mqttConnAckRCConnectionAccepted, true)
}
// TestMQTTPersistRetainedMsg checks that retained messages survive a server
// restart. Only the latest retained value of a topic is kept, a retained
// publish with an empty payload removes the topic's retained message, and
// after the restart the remaining retained messages are delivered with the
// retain flag to a new subscriber.
func TestMQTTPersistRetainedMsg(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownRestartedServer(&s)
	storeDir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir)

	pubCI := &mqttConnInfo{clientID: "pub", cleanSess: true}
	pub, pubR := testMQTTConnect(t, pubCI, o.MQTT.Host, o.MQTT.Port)
	defer pub.Close()
	testMQTTCheckConnAck(t, pubR, mqttConnAckRCConnectionAccepted, false)
	// Two retained values on "foo": only "foo2" is expected after the restart.
	testMQTTPublish(t, pub, pubR, 1, false, true, "foo", 1, []byte("foo1"))
	testMQTTPublish(t, pub, pubR, 1, false, true, "foo", 1, []byte("foo2"))
	testMQTTPublish(t, pub, pubR, 1, false, true, "bar", 1, []byte("bar1"))
	// Retained at QoS 0.
	testMQTTPublish(t, pub, pubR, 0, false, true, "baz", 1, []byte("baz1"))
	// A retained publish without payload removes the retained message on "bar".
	testMQTTPublish(t, pub, pubR, 1, false, true, "bar", 1, nil)
	testMQTTFlush(t, pub, nil, pubR)
	testMQTTDisconnect(t, pub, nil)
	pub.Close()

	// Restart on the same store directory.
	s.Shutdown()
	o.Port = -1
	o.MQTT.Port = -1
	o.StoreDir = storeDir
	s = testMQTTRunServer(t, o)
	// Shutdown of this server is covered by the defer at the top.

	subCI := &mqttConnInfo{clientID: "sub", cleanSess: false}
	sub, subR := testMQTTConnect(t, subCI, o.MQTT.Host, o.MQTT.Port)
	defer sub.Close()
	testMQTTCheckConnAck(t, subR, mqttConnAckRCConnectionAccepted, false)

	testMQTTSub(t, 1, sub, subR, []*mqttFilter{{filter: "foo", qos: 1}}, []byte{1})
	testMQTTCheckPubMsg(t, sub, subR, "foo", mqttPubFlagRetain|mqttPubQos1, []byte("foo2"))

	testMQTTSub(t, 1, sub, subR, []*mqttFilter{{filter: "baz", qos: 1}}, []byte{1})
	testMQTTCheckPubMsg(t, sub, subR, "baz", mqttPubFlagRetain, []byte("baz1"))

	testMQTTSub(t, 1, sub, subR, []*mqttFilter{{filter: "bar", qos: 1}}, []byte{1})
	testMQTTExpectNothing(t, subR)

	testMQTTDisconnect(t, sub, nil)
	sub.Close()
}
// TestMQTTConnAckFirstPacket repeatedly resumes a persisted session that is
// subscribed to "foo" while a background NATS publisher keeps sending on
// "foo". On every reconnect it checks the CONNACK with
// testMQTTCheckConnAck. Per the test name, the intent is that the CONNACK
// must come first even though messages are pending for the session.
func TestMQTTConnAckFirstPacket(t *testing.T) {
	o := testMQTTDefaultOptions()
	o.NoLog, o.Debug, o.Trace = true, false, false
	s := RunServer(o)
	defer testMQTTShutdownServer(s)

	cisub := &mqttConnInfo{clientID: "sub", cleanSess: false}
	c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port)
	defer c.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
	testMQTTSub(t, 1, c, r, []*mqttFilter{{filter: "foo", qos: 0}}, []byte{0})
	testMQTTDisconnect(t, c, nil)
	c.Close()

	nc := natsConnect(t, s.ClientURL())
	defer nc.Close()

	// Background publisher: publishes on "foo" in a tight loop until stop
	// is closed.
	var wg sync.WaitGroup
	wg.Add(1)
	stop := make(chan struct{})
	ready := make(chan struct{})
	go func() {
		defer wg.Done()
		close(ready)
		for {
			// Best effort: errors are ignored, the goal is only to keep
			// messages flowing to the session.
			nc.Publish("foo", []byte("msg"))
			select {
			case <-stop:
				return
			default:
			}
		}
	}()
	// Stop the publisher and wait for it on every exit path, so that a
	// t.Fatal in the loop below cannot leave the goroutine spinning after the
	// test. This defer was registered after nc.Close's, so it runs first.
	defer func() {
		close(stop)
		wg.Wait()
	}()
	<-ready

	for i := 0; i < 100; i++ {
		c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port)
		defer c.Close()
		testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true)
		// Send a DISCONNECT packet (packet type followed by a zero remaining length).
		w := &mqttWriter{}
		w.WriteByte(mqttPacketDisconnect)
		w.WriteByte(0)
		c.Write(w.Bytes())
		// Wait to be disconnected, we can't use testMQTTDisconnect() because
		// it would fail because we may still receive some NATS messages.
		var b [10]byte
		for {
			if _, err := c.Read(b[:]); err != nil {
				break
			}
		}
		c.Close()
	}
}
// TestMQTTRedeliveryAckWait verifies QoS1 redelivery driven by AckWait:
//   - Unacknowledged messages are redelivered with the DUP flag and their
//     original packet identifiers.
//   - Acking one message stops only that message's redelivery.
//   - A still-pending message is redelivered when the session is resumed.
//   - The session's pending maps are empty once all messages are acked.
func TestMQTTRedeliveryAckWait(t *testing.T) {
	o := testMQTTDefaultOptions()
	o.MQTT.AckWait = 250 * time.Millisecond
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)

	cisub := &mqttConnInfo{clientID: "sub", cleanSess: false}
	c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port)
	defer c.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
	testMQTTSub(t, 1, c, r, []*mqttFilter{{filter: "foo", qos: 1}}, []byte{1})

	cipub := &mqttConnInfo{clientID: "pub", cleanSess: true}
	cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port)
	defer cp.Close()
	testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false)
	testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("foo1"))
	testMQTTPublish(t, cp, rp, 1, false, false, "foo", 2, []byte("foo2"))
	testMQTTDisconnect(t, cp, nil)
	cp.Close()

	// First pass is the initial delivery, second pass is the redelivery
	// (DUP set) after AckWait. Packet identifiers must be stable.
	for i := 0; i < 2; i++ {
		flags := mqttPubQos1
		if i > 0 {
			flags |= mqttPubFlagDup
		}
		pi1 := testMQTTCheckPubMsgNoAck(t, c, r, "foo", flags, []byte("foo1"))
		pi2 := testMQTTCheckPubMsgNoAck(t, c, r, "foo", flags, []byte("foo2"))
		if pi1 != 1 || pi2 != 2 {
			t.Fatalf("Unexpected pi values: %v, %v", pi1, pi2)
		}
	}
	// Ack first message
	testMQTTSendPubAck(t, c, 1)
	// Redelivery should only be for second message now
	for i := 0; i < 2; i++ {
		flags := mqttPubQos1 | mqttPubFlagDup
		pi := testMQTTCheckPubMsgNoAck(t, c, r, "foo", flags, []byte("foo2"))
		if pi != 2 {
			t.Fatalf("Expected pi to be 2, got %v", pi)
		}
	}

	// Restart client, should receive second message with pi==2
	c.Close()
	c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port)
	defer c.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true)
	// Check that message is received with proper pi
	pi := testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1|mqttPubFlagDup, []byte("foo2"))
	if pi != 2 {
		t.Fatalf("Expected pi to be 2, got %v", pi)
	}
	// Now ack second message
	testMQTTSendPubAck(t, c, 2)
	// Flush to make sure it is processed before checking client's maps
	testMQTTFlush(t, c, nil, r)

	// Both the packet-id map and the per-consumer stream-seq maps must be empty.
	mc := testMQTTGetClient(t, s, "sub")
	mc.mu.Lock()
	sess := mc.mqtt.sess
	sess.mu.Lock()
	lpi := len(sess.pending)
	var lsseq int
	for _, sseqToPi := range sess.cpending {
		lsseq += len(sseqToPi)
	}
	sess.mu.Unlock()
	mc.mu.Unlock()
	if lpi != 0 || lsseq != 0 {
		t.Fatalf("Maps should be empty, got %v, %v", lpi, lsseq)
	}
}
// TestMQTTAckWaitConfigChange checks what happens when the server is
// restarted with a different AckWait. An existing subscription keeps the
// AckWait it was created with, and a subscription created after the restart
// uses the new value.
func TestMQTTAckWaitConfigChange(t *testing.T) {
	o := testMQTTDefaultOptions()
	o.MQTT.AckWait = 250 * time.Millisecond
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownRestartedServer(&s)
	storeDir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir)

	subCI := &mqttConnInfo{clientID: "sub", cleanSess: false}
	sub, subR := testMQTTConnect(t, subCI, o.MQTT.Host, o.MQTT.Port)
	defer sub.Close()
	testMQTTCheckConnAck(t, subR, mqttConnAckRCConnectionAccepted, false)
	testMQTTSub(t, 1, sub, subR, []*mqttFilter{{filter: "foo", qos: 1}}, []byte{1})

	// publishQoS1 sends one QoS1 message from a short-lived publisher.
	publishQoS1 := func(topic, payload string) {
		t.Helper()
		pubCI := &mqttConnInfo{clientID: "pub", cleanSess: true}
		pub, pubR := testMQTTConnect(t, pubCI, o.MQTT.Host, o.MQTT.Port)
		defer pub.Close()
		testMQTTCheckConnAck(t, pubR, mqttConnAckRCConnectionAccepted, false)
		testMQTTPublish(t, pub, pubR, 1, false, false, topic, 1, []byte(payload))
		testMQTTDisconnect(t, pub, nil)
		pub.Close()
	}

	publishQoS1("foo", "msg1")
	// Initial delivery, then one redelivery (DUP), never acked.
	testMQTTCheckPubMsgNoAck(t, sub, subR, "foo", mqttPubQos1, []byte("msg1"))
	testMQTTCheckPubMsgNoAck(t, sub, subR, "foo", mqttPubQos1|mqttPubFlagDup, []byte("msg1"))

	// Restart the server with a much smaller AckWait. The subscription must
	// still be restored and keep its original AckWait.
	sub.Close()
	s.Shutdown()
	o.Port = -1
	o.MQTT.Port = -1
	o.MQTT.AckWait = 10 * time.Millisecond
	o.StoreDir = storeDir
	s = testMQTTRunServer(t, o)
	// Shutdown of this server is covered by the defer at the top.

	sub, subR = testMQTTConnect(t, subCI, o.MQTT.Host, o.MQTT.Port)
	defer sub.Close()
	testMQTTCheckConnAck(t, subR, mqttConnAckRCConnectionAccepted, true)
	testMQTTCheckPubMsgNoAck(t, sub, subR, "foo", mqttPubQos1|mqttPubFlagDup, []byte("msg1"))
	began := time.Now()
	testMQTTCheckPubMsgNoAck(t, sub, subR, "foo", mqttPubQos1|mqttPubFlagDup, []byte("msg1"))
	if elapsed := time.Since(began); elapsed < 200*time.Millisecond {
		t.Fatalf("AckWait seem to have changed for existing subscription: %v", elapsed)
	}

	// A brand new subscription must pick up the new (10ms) AckWait.
	testMQTTSub(t, 1, sub, subR, []*mqttFilter{{filter: "bar", qos: 1}}, []byte{1})
	publishQoS1("bar", "msg2")
	testMQTTCheckPubMsgNoAck(t, sub, subR, "bar", mqttPubQos1, []byte("msg2"))
	began = time.Now()
	testMQTTCheckPubMsgNoAck(t, sub, subR, "bar", mqttPubQos1|mqttPubFlagDup, []byte("msg2"))
	if elapsed := time.Since(began); elapsed > 50*time.Millisecond {
		t.Fatalf("AckWait new value not used by new sub: %v", elapsed)
	}
	sub.Close()
}
// TestMQTTUnsubscribeWithPendingAcks checks that unsubscribing while a QoS1
// message is still awaiting acknowledgment removes it from the session's
// pending map.
func TestMQTTUnsubscribeWithPendingAcks(t *testing.T) {
	o := testMQTTDefaultOptions()
	o.MQTT.AckWait = 250 * time.Millisecond
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)

	subCI := &mqttConnInfo{clientID: "sub", cleanSess: false}
	sub, subR := testMQTTConnect(t, subCI, o.MQTT.Host, o.MQTT.Port)
	defer sub.Close()
	testMQTTCheckConnAck(t, subR, mqttConnAckRCConnectionAccepted, false)
	testMQTTSub(t, 1, sub, subR, []*mqttFilter{{filter: "foo", qos: 1}}, []byte{1})

	pubCI := &mqttConnInfo{clientID: "pub", cleanSess: true}
	pub, pubR := testMQTTConnect(t, pubCI, o.MQTT.Host, o.MQTT.Port)
	defer pub.Close()
	testMQTTCheckConnAck(t, pubR, mqttConnAckRCConnectionAccepted, false)
	testMQTTPublish(t, pub, pubR, 1, false, false, "foo", 1, []byte("msg"))
	testMQTTDisconnect(t, pub, nil)
	pub.Close()

	// Receive the message and its first redelivery without ever acking it.
	testMQTTCheckPubMsgNoAck(t, sub, subR, "foo", mqttPubQos1, []byte("msg"))
	testMQTTCheckPubMsgNoAck(t, sub, subR, "foo", mqttPubQos1|mqttPubFlagDup, []byte("msg"))

	testMQTTUnsub(t, 1, sub, subR, []*mqttFilter{{filter: "foo"}})
	testMQTTFlush(t, sub, nil, subR)

	cli := testMQTTGetClient(t, s, "sub")
	cli.mu.Lock()
	sess := cli.mqtt.sess
	sess.mu.Lock()
	pending := len(sess.pending)
	sess.mu.Unlock()
	cli.mu.Unlock()
	if pending != 0 {
		t.Fatalf("Expected pending ack map to be empty, got %v", pending)
	}
}
// TestMQTTMaxAckPending verifies that with MaxAckPending=1 only one QoS1
// message is outstanding at a time, for both live and resumed sessions. It
// also checks that restarting the server with a different MaxAckPending
// (2) does not prevent the existing subscription from being restored. After
// that restart, only one message is still delivered before an ack.
func TestMQTTMaxAckPending(t *testing.T) {
	o := testMQTTDefaultOptions()
	o.MQTT.MaxAckPending = 1
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownRestartedServer(&s)
	dir := strings.TrimSuffix(s.JetStreamConfig().StoreDir, JetStreamStoreDir)

	cisub := &mqttConnInfo{clientID: "sub", cleanSess: false}
	c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port)
	defer c.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
	testMQTTSub(t, 1, c, r, []*mqttFilter{{filter: "foo", qos: 1}}, []byte{1})

	cipub := &mqttConnInfo{clientID: "pub", cleanSess: true}
	cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port)
	defer cp.Close()
	testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false)
	testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg1"))
	testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg2"))

	pi := testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg1"))
	// Check that we don't receive the second one due to max ack pending
	testMQTTExpectNothing(t, r)
	// Now ack first message
	testMQTTSendPubAck(t, c, pi)
	// Now we should receive message 2
	testMQTTCheckPubMsg(t, c, r, "foo", mqttPubQos1, []byte("msg2"))
	testMQTTDisconnect(t, c, nil)

	// Give a chance to the server to register that this client is gone.
	// (Only the publisher connection remains.)
	checkClientsCount(t, s, 1)

	// Send 2 messages while sub is offline
	testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg3"))
	testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg4"))

	// Restart consumer
	c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port)
	defer c.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true)
	// Should receive only message 3
	pi = testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg3"))
	testMQTTExpectNothing(t, r)
	// Ack and get the next
	testMQTTSendPubAck(t, c, pi)
	testMQTTCheckPubMsg(t, c, r, "foo", mqttPubQos1, []byte("msg4"))

	// Make sure this message gets ack'ed
	// Poll until the session's pending map is empty.
	mcli := testMQTTGetClient(t, s, cisub.clientID)
	checkFor(t, time.Second, 15*time.Millisecond, func() error {
		mcli.mu.Lock()
		sess := mcli.mqtt.sess
		sess.mu.Lock()
		np := len(sess.pending)
		sess.mu.Unlock()
		mcli.mu.Unlock()
		if np != 0 {
			return fmt.Errorf("Still %v pending messages", np)
		}
		return nil
	})

	// Check that change to config does not prevent restart of sub.
	// Restart on the same store directory with MaxAckPending raised to 2.
	cp.Close()
	c.Close()
	s.Shutdown()
	o.Port = -1
	o.MQTT.Port = -1
	o.MQTT.MaxAckPending = 2
	o.StoreDir = dir
	s = testMQTTRunServer(t, o)
	// There is already the defer for shutdown at top of function

	cp, rp = testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port)
	defer cp.Close()
	testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false)
	testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg5"))
	testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg6"))

	// Restart consumer
	c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port)
	defer c.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, true)
	// Should receive only message 5
	// (the restored subscription still behaves as if MaxAckPending were 1)
	pi = testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg5"))
	testMQTTExpectNothing(t, r)
	// Ack and get the next
	testMQTTSendPubAck(t, c, pi)
	testMQTTCheckPubMsg(t, c, r, "foo", mqttPubQos1, []byte("msg6"))
}
// TestMQTTMaxAckPendingForMultipleSubs checks that MaxAckPending is applied to
// the MQTT session as a whole rather than to each subscription: with a limit
// of 1, an unacknowledged message on "foo" holds back a message on "bar"
// until the first one is acked.
func TestMQTTMaxAckPendingForMultipleSubs(t *testing.T) {
	opts := testMQTTDefaultOptions()
	opts.MQTT.AckWait = 500 * time.Millisecond
	opts.MQTT.MaxAckPending = 1
	srv := testMQTTRunServer(t, opts)
	defer testMQTTShutdownServer(srv)

	// Persistent subscriber session with two QoS 1 subscriptions.
	subInfo := &mqttConnInfo{clientID: "sub", cleanSess: false}
	subConn, subRd := testMQTTConnect(t, subInfo, opts.MQTT.Host, opts.MQTT.Port)
	defer subConn.Close()
	testMQTTCheckConnAck(t, subRd, mqttConnAckRCConnectionAccepted, false)
	for _, topic := range []string{"foo", "bar"} {
		testMQTTSub(t, 1, subConn, subRd, []*mqttFilter{{filter: topic, qos: 1}}, []byte{1})
	}

	pubInfo := &mqttConnInfo{clientID: "pub", cleanSess: true}
	pubConn, pubRd := testMQTTConnect(t, pubInfo, opts.MQTT.Host, opts.MQTT.Port)
	defer pubConn.Close()
	testMQTTCheckConnAck(t, pubRd, mqttConnAckRCConnectionAccepted, false)

	// The message on "foo" is delivered and deliberately left unacked,
	// which uses up the session's single pending slot.
	testMQTTPublish(t, pubConn, pubRd, 1, false, false, "foo", 1, []byte("msg1"))
	firstID := testMQTTCheckPubMsgNoAck(t, subConn, subRd, "foo", mqttPubQos1, []byte("msg1"))

	// JetStream limits pending acks per consumer, but the limit is enforced
	// per session: JetStream will try to deliver the "bar" message, and the
	// MQTT layer suppresses it while the "foo" message is unacked.
	testMQTTPublish(t, pubConn, pubRd, 1, false, false, "bar", 1, []byte("msg2"))
	testMQTTExpectNothing(t, subRd)

	// Acking the first message frees the slot, so the second one arrives.
	// It carries the DUP flag, presumably because it is a redelivery after
	// AckWait expired.
	testMQTTSendPubAck(t, subConn, firstID)
	testMQTTCheckPubMsg(t, subConn, subRd, "bar", mqttPubQos1|mqttPubFlagDup, []byte("msg2"))
}
// TestMQTTMaxAckPendingOverLimit checks how a session's total max-ack-pending
// (tmaxack) is accounted for as subscriptions are added and removed.
// Each plain subscription adds MaxAckPending (20000). A "bar/#" subscription
// adds twice that. A subscription that would push the total too high is
// rejected with a SUBACK failure code, and the total is left unchanged.
func TestMQTTMaxAckPendingOverLimit(t *testing.T) {
	opts := testMQTTDefaultOptions()
	opts.MQTT.MaxAckPending = 20000
	srv := testMQTTRunServer(t, opts)
	defer testMQTTShutdownServer(srv)

	// expectTMax fails the test if the session's tmaxack differs from want.
	expectTMax := func(sess *mqttSession, want int) {
		t.Helper()
		sess.mu.Lock()
		got := sess.tmaxack
		sess.mu.Unlock()
		if got != want {
			t.Fatalf("Expected current tmax to be %v, got %v", want, got)
		}
	}

	subInfo := &mqttConnInfo{clientID: "sub", cleanSess: false}
	conn, rd := testMQTTConnect(t, subInfo, opts.MQTT.Host, opts.MQTT.Port)
	defer conn.Close()
	testMQTTCheckConnAck(t, rd, mqttConnAckRCConnectionAccepted, false)

	client := testMQTTGetClient(t, srv, "sub")
	client.mu.Lock()
	sess := client.mqtt.sess
	client.mu.Unlock()

	// subscribe adds a QoS 1 subscription, expects the given SUBACK return
	// code, then checks the resulting session total.
	subscribe := func(filter string, rc byte, want int) {
		t.Helper()
		testMQTTSub(t, 1, conn, rd, []*mqttFilter{{filter: filter, qos: 1}}, []byte{rc})
		expectTMax(sess, want)
	}
	// unsubscribe removes a subscription and checks the resulting total.
	unsubscribe := func(filter string, want int) {
		t.Helper()
		testMQTTUnsub(t, 1, conn, rd, []*mqttFilter{{filter: filter}})
		expectTMax(sess, want)
	}

	subscribe("foo", 1, 20000)
	// The wildcard subscription counts double.
	subscribe("bar/#", 1, 60000)
	// Another 20000 would presumably exceed the server's cap, so it is refused.
	subscribe("bar", mqttSubAckFailure, 60000)
	unsubscribe("bar/#", 20000)
	// Freeing the wildcard leaves room for two plain subscriptions.
	subscribe("bar", 1, 40000)
	subscribe("baz", 1, 60000)
	subscribe("bat", mqttSubAckFailure, 60000)
	// Removing everything brings the total back to zero.
	unsubscribe("foo", 40000)
	unsubscribe("bar", 20000)
	unsubscribe("baz", 0)
}
// TestMQTTConfigReload verifies that mqtt.ack_wait and mqtt.max_ack_pending
// are read from the configuration file, and that changes applied through
// Server.Reload take effect:
//   - the reloaded ack_wait drives redelivery of unacked messages;
//   - a raised max_ack_pending applies to subscriptions created after the
//     reload.
func TestMQTTConfigReload(t *testing.T) {
	template := `
jetstream: true
mqtt {
port: -1
ack_wait: %s
max_ack_pending: %s
}
`
	conf := createConfFile(t, []byte(fmt.Sprintf(template, `"5s"`, `10000`)))
	defer removeFile(t, conf)
	s, o := RunServerWithConfig(conf)
	defer testMQTTShutdownServer(s)
	if val := o.MQTT.AckWait; val != 5*time.Second {
		t.Fatalf("Invalid ackwait: %v", val)
	}
	if val := o.MQTT.MaxAckPending; val != 10000 {
		t.Fatalf("Invalid max ack pending: %v", val)
	}

	// Shrink ack_wait to 250ms and max_ack_pending to 1, then reload.
	changeCurrentConfigContentWithNewContent(t, conf, []byte(fmt.Sprintf(template, `"250ms"`, `1`)))
	if err := s.Reload(); err != nil {
		t.Fatalf("Error on reload: %v", err)
	}

	cisub := &mqttConnInfo{clientID: "sub", cleanSess: false}
	c, r := testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port)
	defer c.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
	testMQTTSub(t, 1, c, r, []*mqttFilter{{filter: "foo", qos: 1}}, []byte{1})

	cipub := &mqttConnInfo{clientID: "pub", cleanSess: true}
	cp, rp := testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port)
	defer cp.Close()
	testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false)
	testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg1"))
	testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg2"))

	// With max_ack_pending=1 only msg1 is delivered. Leaving it unacked should
	// produce a redelivery (DUP) well within 500ms given the 250ms ack_wait.
	testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg1"))
	start := time.Now()
	testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1|mqttPubFlagDup, []byte("msg1"))
	if dur := time.Since(start); dur > 500*time.Millisecond {
		t.Fatalf("AckWait not applied? dur=%v", dur)
	}

	// Restart the server with a long ack_wait so redeliveries do not interfere.
	c.Close()
	cp.Close()
	testMQTTShutdownServer(s)

	changeCurrentConfigContentWithNewContent(t, conf, []byte(fmt.Sprintf(template, `"30s"`, `1`)))
	s, o = RunServerWithConfig(conf)
	defer testMQTTShutdownServer(s)

	cisub.cleanSess = true
	c, r = testMQTTConnect(t, cisub, o.MQTT.Host, o.MQTT.Port)
	defer c.Close()
	testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
	testMQTTSub(t, 1, c, r, []*mqttFilter{{filter: "foo", qos: 1}}, []byte{1})

	cipub = &mqttConnInfo{clientID: "pub", cleanSess: true}
	cp, rp = testMQTTConnect(t, cipub, o.MQTT.Host, o.MQTT.Port)
	defer cp.Close()
	testMQTTCheckConnAck(t, rp, mqttConnAckRCConnectionAccepted, false)
	testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg1"))
	testMQTTPublish(t, cp, rp, 1, false, false, "foo", 1, []byte("msg2"))

	// Limit of 1 pending: msg2 is held back while msg1 is unacked.
	testMQTTCheckPubMsgNoAck(t, c, r, "foo", mqttPubQos1, []byte("msg1"))
	testMQTTExpectNothing(t, r)

	// Increase the max ack pending
	changeCurrentConfigContentWithNewContent(t, conf, []byte(fmt.Sprintf(template, `"30s"`, `10`)))
	// Reload now
	if err := s.Reload(); err != nil {
		t.Fatalf("Error on reload: %v", err)
	}

	// Reload will have effect only on new subscriptions.
	// Create a new subscription, and we should now be able to get the 2 messages.
	testMQTTSub(t, 1, c, r, []*mqttFilter{{filter: "bar", qos: 1}}, []byte{1})
	testMQTTPublish(t, cp, rp, 1, false, false, "bar", 1, []byte("msg3"))
	testMQTTPublish(t, cp, rp, 1, false, false, "bar", 1, []byte("msg4"))
	testMQTTCheckPubMsg(t, c, r, "bar", mqttPubQos1, []byte("msg3"))
	testMQTTCheckPubMsg(t, c, r, "bar", mqttPubQos1, []byte("msg4"))
}
// TestMQTTStreamInfoReturnsNonEmptySubject queries the JetStream stream-info
// API for each MQTT stream. It checks that the returned config lists at least
// one subject. The MQTT streams are created without a subject filter, and
// returning an empty subject list breaks the 'nats' tooling's validation of
// the response.
func TestMQTTStreamInfoReturnsNonEmptySubject(t *testing.T) {
	opts := testMQTTDefaultOptions()
	srv := testMQTTRunServer(t, opts)
	defer testMQTTShutdownServer(srv)

	// Connect an MQTT client so the MQTT streams exist.
	subInfo := &mqttConnInfo{clientID: "sub", cleanSess: false}
	mc, mr := testMQTTConnect(t, subInfo, opts.MQTT.Host, opts.MQTT.Port)
	defer mc.Close()
	testMQTTCheckConnAck(t, mr, mqttConnAckRCConnectionAccepted, false)

	nc := natsConnect(t, srv.ClientURL())
	defer nc.Close()

	streams := []string{mqttStreamName, mqttRetainedMsgsStreamName}
	for _, stream := range streams {
		t.Run(stream, func(t *testing.T) {
			reply, err := nc.Request(fmt.Sprintf(JSApiStreamInfoT, stream), nil, time.Second)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}
			var info JSApiStreamInfoResponse
			if err := json.Unmarshal(reply.Data, &info); err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}
			if len(info.Config.Subjects) == 0 {
				t.Fatalf("No subject returned, which will cause nats tooling to fail: %+v", info.Config)
			}
		})
	}
}
// TestMQTTWebsocketNotSupported sends a WebSocket upgrade request to the MQTT
// listener. It checks that no valid HTTP response comes back and that the
// server logs an error containing "not supported".
func TestMQTTWebsocketNotSupported(t *testing.T) {
	o := testMQTTDefaultOptions()
	s := testMQTTRunServer(t, o)
	defer testMQTTShutdownServer(s)

	// Capture errors logged by the server so the rejection reason can be checked.
	l := &captureErrorLogger{errCh: make(chan string, 10)}
	s.SetLogger(l, false, false)

	addr := fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)
	wsc, err := net.Dial("tcp", addr)
	if err != nil {
		t.Fatalf("Error creating connection: %v", err)
	}
	// Always release the raw TCP connection when the test ends.
	defer wsc.Close()

	req := testWSCreateValidReq()
	req.URL, _ = url.Parse("ws://" + addr)
	if err := req.Write(wsc); err != nil {
		t.Fatalf("Error sending request: %v", err)
	}

	// The MQTT listener must not answer with a valid HTTP upgrade response.
	br := bufio.NewReader(wsc)
	resp, err := http.ReadResponse(br, req)
	if err == nil {
		if resp != nil {
			defer resp.Body.Close()
		}
		t.Fatalf("Expected error, got resp=%+v", resp)
	}

	select {
	case err := <-l.errCh:
		if !strings.Contains(err, "not supported") {
			t.Fatalf("Expected error about websocket not supported, got %v", err)
		}
	case <-time.After(time.Second):
		t.Fatal("Did not log any error")
	}
}
//////////////////////////////////////////////////////////////////////////
//
// Benchmarks
//
//////////////////////////////////////////////////////////////////////////

const (
	// mqttPubSubj is the topic every benchmark publishes to (and that
	// subscribers, when present, subscribe on).
	mqttPubSubj = "a"
	// mqttBenchBufLen is, in bytes:
	//   - the size of the consumers' read buffers;
	//   - the size of the QoS 0 publisher's bufio.Writer;
	//   - the batch size at which the QoS 1 publisher writes out its
	//     buffered PUBLISH packets.
	mqttBenchBufLen = 32768
)
// mqttBenchPubQoS0 benchmarks publishing b.N QoS 0 messages carrying payload
// on subject. The messages are fanned out to numSubs QoS 0 subscribers, each
// drained by its own goroutine. Bytes reported per op are the PUBLISH packet
// sent plus the PUBLISH packet delivered to each subscriber.
func mqttBenchPubQoS0(b *testing.B, subject, payload string, numSubs int) {
	b.StopTimer()
	opts := testMQTTDefaultOptions()
	srv := RunServer(opts)
	defer testMQTTShutdownServer(srv)

	pubInfo := &mqttConnInfo{clientID: "pub", cleanSess: true}
	pubConn, pubRd := testMQTTConnect(b, pubInfo, opts.MQTT.Host, opts.MQTT.Port)
	testMQTTCheckConnAck(b, pubRd, mqttConnAckRCConnectionAccepted, false)

	// Encode the PUBLISH packet once; every iteration sends the same bytes.
	pw := &mqttWriter{}
	mqttWritePublish(pw, 0, false, false, subject, 0, []byte(payload))
	pubPacket := pw.Bytes()

	doneCh := make(chan error, 1)
	bytesPerOp := int64(len(pubPacket))
	pending := 0
	for idx := 1; idx <= numSubs; idx++ {
		subInfo := &mqttConnInfo{clientID: fmt.Sprintf("sub%d", idx), cleanSess: true}
		subConn, subRd := testMQTTConnect(b, subInfo, opts.MQTT.Host, opts.MQTT.Port)
		testMQTTCheckConnAck(b, subRd, mqttConnAckRCConnectionAccepted, false)
		testMQTTSub(b, 1, subConn, subRd, []*mqttFilter{{filter: subject, qos: 0}}, []byte{0})
		testMQTTFlush(b, subConn, nil, subRd)

		// Delivered PUBLISH: 1 byte fixed header + encoded remaining length +
		// 2-byte topic length + topic + payload (QoS 0 has no packet id).
		remaining := 2 + len(subject) + len(payload)
		lw := &mqttWriter{}
		lw.WriteVarInt(remaining)
		pktSize := 1 + lw.Len() + remaining
		bytesPerOp += int64(pktSize)

		expected := int64(b.N) * int64(pktSize)
		go func() {
			mqttBenchConsumeMsgQoS0(subConn, expected, doneCh)
			subConn.Close()
		}()
		pending++
	}

	bw := bufio.NewWriterSize(pubConn, mqttBenchBufLen)
	b.SetBytes(bytesPerOp)
	b.StartTimer()
	for n := 0; n < b.N; n++ {
		bw.Write(pubPacket)
	}
	testMQTTFlush(b, pubConn, bw, pubRd)
	// Wait until every subscriber has received its full share of bytes.
	for ; pending > 0; pending-- {
		if err := <-doneCh; err != nil {
			b.Fatal(err.Error())
		}
	}
	b.StopTimer()
	pubConn.Close()
	srv.Shutdown()
}
// mqttBenchConsumeMsgQoS0 reads from c until at least total bytes have arrived
// or a read fails, discarding the data. It then reports the outcome on dch:
// nil when the expected volume was received, otherwise the read error.
func mqttBenchConsumeMsgQoS0(c net.Conn, total int64, dch chan<- error) {
	var scratch [mqttBenchBufLen]byte
	var readErr error
	remaining := total
	for remaining > 0 {
		var n int
		if n, readErr = c.Read(scratch[:]); readErr != nil {
			break
		}
		remaining -= int64(n)
	}
	dch <- readErr
}
// mqttBenchPubQoS1 benchmarks publishing b.N QoS 1 messages carrying payload
// on subject, with numSubs QoS 1 subscribers receiving (and acking) them.
//
// Bytes reported per op:
//   - the PUBLISH sent by the publisher plus its 4-byte PUBACK;
//   - for every subscriber, the PUBLISH it receives plus the 4-byte PUBACK
//     it returns.
//
// Packet identifiers cycle through 1..60000. Once that range has been used,
// the publisher waits for a signal on ppich before reusing the next block of
// 10000 identifiers. mqttBenchConsumePubAck sends that signal for every
// 10000 PUBACKs it counts.
func mqttBenchPubQoS1(b *testing.B, subject, payload string, numSubs int) {
	b.StopTimer()
	o := testMQTTDefaultOptions()
	// NOTE(review): presumably raised so the server does not hold back
	// deliveries while many QoS 1 messages are unacked — confirm.
	o.MQTT.MaxAckPending = 0xFFFF
	s := RunServer(o)
	defer testMQTTShutdownServer(s)

	// Publisher connection (clean session, no client ID).
	ci := &mqttConnInfo{cleanSess: true}
	c, br := testMQTTConnect(b, ci, o.MQTT.Host, o.MQTT.Port)
	testMQTTCheckConnAck(b, br, mqttConnAckRCConnectionAccepted, false)

	// Encode one PUBLISH only to measure its size, then reset the writer so
	// it can be reused as the batching buffer below.
	w := &mqttWriter{}
	mqttWritePublish(w, 1, false, false, subject, 1, []byte(payload))
	// For reported bytes we will count the PUBLISH + PUBACK (4 bytes)
	totalSize := int64(len(w.Bytes()) + 4)
	w.Reset()

	// pi is the next packet identifier to use. maxpi is the highest
	// identifier that may be used before waiting for more PUBACKs.
	pi := uint16(1)
	maxpi := uint16(60000)
	// ppich receives a nil for every 10000 PUBACKs counted, plus the final
	// read result, from the PUBACK consumer goroutine.
	ppich := make(chan error, 10)
	// dch collects the final result of the PUBACK consumer and of each
	// subscriber consumer. cdch counts how many results to wait for.
	dch := make(chan error, 1+numSubs)
	cdch := 1

	// Start go routine to consume PUBACK for published QoS 1 messages.
	go mqttBenchConsumePubAck(c, b.N, dch, ppich)

	// createSub connects subscriber i, subscribes it at QoS 1, and starts a
	// goroutine that reads and acks b.N messages on that connection.
	createSub := func(i int) {
		ci := &mqttConnInfo{clientID: fmt.Sprintf("sub%d", i), cleanSess: true}
		cs, brs := testMQTTConnect(b, ci, o.MQTT.Host, o.MQTT.Port)
		testMQTTCheckConnAck(b, brs, mqttConnAckRCConnectionAccepted, false)
		testMQTTSub(b, 1, cs, brs, []*mqttFilter{{filter: subject, qos: 1}}, []byte{1})
		testMQTTFlush(b, cs, nil, brs)
		// Size of the delivered PUBLISH: 1 byte fixed header + encoded
		// remaining length + 2-byte topic length + topic + 2-byte packet
		// identifier + payload.
		w := &mqttWriter{}
		varHeaderAndPayload := 2 + len(subject) + 2 + len(payload)
		w.WriteVarInt(varHeaderAndPayload)
		size := 1 + w.Len() + varHeaderAndPayload
		// Add to the bytes reported the size of message sent to subscriber + PUBACK (4 bytes)
		totalSize += int64(size + 4)
		go func() {
			mqttBenchConsumeMsgQos1(cs, b.N, size, dch)
			cs.Close()
		}()
	}
	for i := 0; i < numSubs; i++ {
		createSub(i + 1)
		cdch++
	}

	// flush writes the batched PUBLISH packets and empties the batch.
	flush := func() {
		b.Helper()
		if _, err := c.Write(w.Bytes()); err != nil {
			b.Fatalf("Error on write: %v", err)
		}
		w.Reset()
	}

	b.SetBytes(totalSize)
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		if pi <= maxpi {
			// Batch the PUBLISH and write the batch once it reaches
			// mqttBenchBufLen bytes.
			mqttWritePublish(w, 1, false, false, subject, pi, []byte(payload))
			pi++
			if w.Len() >= mqttBenchBufLen {
				flush()
			}
		} else {
			// The identifier window is exhausted. Send what is batched,
			// wrap the identifiers after 60000, wait for the next 10000
			// PUBACKs, then open the window by another 10000 identifiers.
			if w.Len() > 0 {
				flush()
			}
			if pi > 60000 {
				pi = 1
				maxpi = 0
			}
			if e := <-ppich; e != nil {
				b.Fatal(e.Error())
			}
			maxpi += 10000
			// Nothing was published in this iteration, so redo it.
			i--
		}
	}
	if w.Len() > 0 {
		flush()
	}
	// Wait for the PUBACK consumer and every subscriber to finish.
	for i := 0; i < cdch; i++ {
		if e := <-dch; e != nil {
			b.Fatal(e.Error())
		}
	}
	b.StopTimer()
	c.Close()
	s.Shutdown()
}
// mqttBenchConsumeMsgQos1 acts as a QoS 1 subscriber. It reads from c until
// total messages of exactly size bytes each have been counted. For each one it
// writes a PUBACK whose packet identifier follows the sequence 1..60000,
// wrapping back to 1. Message contents are not inspected; bytes left over
// from a partial message are carried into the next read. The outcome is sent
// on dch: nil on success, or the first read/write error.
func mqttBenchConsumeMsgQos1(c net.Conn, total, size int, dch chan<- error) {
	var scratch [mqttBenchBufLen]byte
	ack := [4]byte{mqttPacketPubAck, 0x2, 0, 0}
	var (
		err      error
		received int
		packetID uint16
		carry    int
	)
	for received < total {
		var n int
		n, err = c.Read(scratch[:])
		if err != nil {
			break
		}
		avail := carry + n
		for avail >= size {
			avail -= size
			received++
			packetID++
			ack[2], ack[3] = byte(packetID>>8), byte(packetID)
			if _, err = c.Write(ack[:]); err != nil {
				dch <- err
				return
			}
			if packetID == 60000 {
				packetID = 0
			}
		}
		carry = avail
	}
	dch <- err
}
// mqttBenchConsumePubAck reads the PUBACK packets the server sends to the QoS
// 1 publisher until total acks have been counted or a read fails. Only byte
// counts matter (4 bytes per PUBACK); packet contents are not inspected.
//
// The counter mirrors the publisher's packet identifiers, which cycle through
// 1..60000. A nil is sent on ppich for every 10000th identifier, telling the
// publisher that the next block of 10000 identifiers may be reused.
//
// On exit the final read error (nil on success) is sent on both ppich and
// dch. ppich therefore needs enough buffer to absorb signals the publisher
// does not consume.
func mqttBenchConsumePubAck(c net.Conn, total int, dch, ppich chan<- error) {
	var buf [mqttBenchBufLen]byte
	var err error
	var n int
	var pi uint16
	var prev int
	for i := 0; i < total; {
		n, err = c.Read(buf[:])
		if err != nil {
			break
		}
		// Include bytes of a partial PUBACK left over from the previous read.
		n += prev
		for ; n >= 4; n -= 4 {
			i++
			pi++
			if pi%10000 == 0 {
				ppich <- nil
			}
			// Wrap after 60000 so the next ack counts as identifier 1, matching
			// the publisher (and mqttBenchConsumeMsgQos1). Wrapping at 60001
			// left the counter one ack behind after every cycle, so later
			// signals drifted.
			if pi == 60000 {
				pi = 0
			}
		}
		prev = n
	}
	ppich <- err
	dch <- err
}
// QoS 0 publish benchmarks, all driven by mqttBenchPubQoS0 on mqttPubSubj.
// In the names, "Pub" means no subscribers and "PubSubN" means N
// subscribers; the suffix is the payload size.

func BenchmarkMQTT_QoS0_Pub_______0b_Payload(b *testing.B) {
	mqttBenchPubQoS0(b, mqttPubSubj, "", 0)
}
func BenchmarkMQTT_QoS0_Pub_______8b_Payload(b *testing.B) {
	mqttBenchPubQoS0(b, mqttPubSubj, sizedString(8), 0)
}
func BenchmarkMQTT_QoS0_Pub______32b_Payload(b *testing.B) {
	mqttBenchPubQoS0(b, mqttPubSubj, sizedString(32), 0)
}
func BenchmarkMQTT_QoS0_Pub_____128b_Payload(b *testing.B) {
	mqttBenchPubQoS0(b, mqttPubSubj, sizedString(128), 0)
}
func BenchmarkMQTT_QoS0_Pub_____256b_Payload(b *testing.B) {
	mqttBenchPubQoS0(b, mqttPubSubj, sizedString(256), 0)
}
func BenchmarkMQTT_QoS0_Pub_______1K_Payload(b *testing.B) {
	mqttBenchPubQoS0(b, mqttPubSubj, sizedString(1024), 0)
}
func BenchmarkMQTT_QoS0_PubSub1___0b_Payload(b *testing.B) {
	mqttBenchPubQoS0(b, mqttPubSubj, "", 1)
}
func BenchmarkMQTT_QoS0_PubSub1___8b_Payload(b *testing.B) {
	mqttBenchPubQoS0(b, mqttPubSubj, sizedString(8), 1)
}
func BenchmarkMQTT_QoS0_PubSub1__32b_Payload(b *testing.B) {
	mqttBenchPubQoS0(b, mqttPubSubj, sizedString(32), 1)
}
func BenchmarkMQTT_QoS0_PubSub1_128b_Payload(b *testing.B) {
	mqttBenchPubQoS0(b, mqttPubSubj, sizedString(128), 1)
}
func BenchmarkMQTT_QoS0_PubSub1_256b_Payload(b *testing.B) {
	mqttBenchPubQoS0(b, mqttPubSubj, sizedString(256), 1)
}
func BenchmarkMQTT_QoS0_PubSub1___1K_Payload(b *testing.B) {
	mqttBenchPubQoS0(b, mqttPubSubj, sizedString(1024), 1)
}
func BenchmarkMQTT_QoS0_PubSub2___0b_Payload(b *testing.B) {
	mqttBenchPubQoS0(b, mqttPubSubj, "", 2)
}
func BenchmarkMQTT_QoS0_PubSub2___8b_Payload(b *testing.B) {
	mqttBenchPubQoS0(b, mqttPubSubj, sizedString(8), 2)
}
func BenchmarkMQTT_QoS0_PubSub2__32b_Payload(b *testing.B) {
	mqttBenchPubQoS0(b, mqttPubSubj, sizedString(32), 2)
}
func BenchmarkMQTT_QoS0_PubSub2_128b_Payload(b *testing.B) {
	mqttBenchPubQoS0(b, mqttPubSubj, sizedString(128), 2)
}
func BenchmarkMQTT_QoS0_PubSub2_256b_Payload(b *testing.B) {
	mqttBenchPubQoS0(b, mqttPubSubj, sizedString(256), 2)
}
func BenchmarkMQTT_QoS0_PubSub2___1K_Payload(b *testing.B) {
	mqttBenchPubQoS0(b, mqttPubSubj, sizedString(1024), 2)
}
// QoS 1 publish benchmarks, all driven by mqttBenchPubQoS1 on mqttPubSubj.
// In the names, "Pub" means no subscribers and "PubSubN" means N
// subscribers; the suffix is the payload size.

func BenchmarkMQTT_QoS1_Pub_______0b_Payload(b *testing.B) {
	mqttBenchPubQoS1(b, mqttPubSubj, "", 0)
}
func BenchmarkMQTT_QoS1_Pub_______8b_Payload(b *testing.B) {
	mqttBenchPubQoS1(b, mqttPubSubj, sizedString(8), 0)
}
func BenchmarkMQTT_QoS1_Pub______32b_Payload(b *testing.B) {
	mqttBenchPubQoS1(b, mqttPubSubj, sizedString(32), 0)
}
func BenchmarkMQTT_QoS1_Pub_____128b_Payload(b *testing.B) {
	mqttBenchPubQoS1(b, mqttPubSubj, sizedString(128), 0)
}
func BenchmarkMQTT_QoS1_Pub_____256b_Payload(b *testing.B) {
	mqttBenchPubQoS1(b, mqttPubSubj, sizedString(256), 0)
}
func BenchmarkMQTT_QoS1_Pub_______1K_Payload(b *testing.B) {
	mqttBenchPubQoS1(b, mqttPubSubj, sizedString(1024), 0)
}
func BenchmarkMQTT_QoS1_PubSub1___0b_Payload(b *testing.B) {
	mqttBenchPubQoS1(b, mqttPubSubj, "", 1)
}
func BenchmarkMQTT_QoS1_PubSub1___8b_Payload(b *testing.B) {
	mqttBenchPubQoS1(b, mqttPubSubj, sizedString(8), 1)
}
func BenchmarkMQTT_QoS1_PubSub1__32b_Payload(b *testing.B) {
	mqttBenchPubQoS1(b, mqttPubSubj, sizedString(32), 1)
}
func BenchmarkMQTT_QoS1_PubSub1_128b_Payload(b *testing.B) {
	mqttBenchPubQoS1(b, mqttPubSubj, sizedString(128), 1)
}
func BenchmarkMQTT_QoS1_PubSub1_256b_Payload(b *testing.B) {
	mqttBenchPubQoS1(b, mqttPubSubj, sizedString(256), 1)
}
func BenchmarkMQTT_QoS1_PubSub1___1K_Payload(b *testing.B) {
	mqttBenchPubQoS1(b, mqttPubSubj, sizedString(1024), 1)
}
func BenchmarkMQTT_QoS1_PubSub2___0b_Payload(b *testing.B) {
	mqttBenchPubQoS1(b, mqttPubSubj, "", 2)
}
func BenchmarkMQTT_QoS1_PubSub2___8b_Payload(b *testing.B) {
	mqttBenchPubQoS1(b, mqttPubSubj, sizedString(8), 2)
}
func BenchmarkMQTT_QoS1_PubSub2__32b_Payload(b *testing.B) {
	mqttBenchPubQoS1(b, mqttPubSubj, sizedString(32), 2)
}
func BenchmarkMQTT_QoS1_PubSub2_128b_Payload(b *testing.B) {
	mqttBenchPubQoS1(b, mqttPubSubj, sizedString(128), 2)
}
func BenchmarkMQTT_QoS1_PubSub2_256b_Payload(b *testing.B) {
	mqttBenchPubQoS1(b, mqttPubSubj, sizedString(256), 2)
}
func BenchmarkMQTT_QoS1_PubSub2___1K_Payload(b *testing.B) {
	mqttBenchPubQoS1(b, mqttPubSubj, sizedString(1024), 2)
}