[IMPROVED] MQTT error message when client connects with websocket

Websocket is currently not supported for MQTT clients. When a
client tries to connect with websocket protocol to the MQTT port,
the error message: `mid:9 - not connected` would be logged, which
is not really telling.

The server will now guess if the connection was websocket and report
a more appropriate error message, such as:
```
invalid connection, websocket currently not supported
```

Resolves #2126

Signed-off-by: Ivan Kozlovic <ivan@synadia.com>
This commit is contained in:
Ivan Kozlovic
2021-04-22 08:59:33 -06:00
parent eb3af67031
commit d0fd62c83b
2 changed files with 70 additions and 6 deletions

View File

@@ -170,6 +170,13 @@ var (
mqttFlapCleanItvl = mqttSessFlappingCleanupInterval
)
var (
errMQTTWebsocketNotSupported = errors.New("invalid connection, websocket currently not supported")
errMQTTTopicFilterCannotBeEmpty = errors.New("topic filter cannot be empty")
errMQTTMalformedVarInt = errors.New("malformed variable int")
errMQTTSecondConnectPacket = errors.New("received a second CONNECT packet")
)
type srvMQTT struct {
listener net.Listener
listenerErr error
@@ -565,7 +572,13 @@ func (c *client) mqttParse(buf []byte) error {
// If client was not connected yet, the first packet must be
// a mqttPacketConnect otherwise we fail the connection.
if !connected && pt != mqttPacketConnect {
err = errors.New("not connected")
// Try to guess if the client is trying to connect using Websocket,
// which is currently not supported
if bytes.HasPrefix(buf, []byte("GET ")) {
err = errMQTTWebsocketNotSupported
} else {
err = fmt.Errorf("the first packet should be a CONNECT (%v), got %v", mqttPacketConnect, pt)
}
break
}
@@ -647,7 +660,7 @@ func (c *client) mqttParse(buf []byte) error {
case mqttPacketConnect:
// It is an error to receive a second connect packet
if connected {
err = errors.New("second connect packet")
err = errMQTTSecondConnectPacket
break
}
var rc byte
@@ -2913,7 +2926,7 @@ func (c *client) mqttParseSubsOrUnsubs(r *mqttReader, b byte, pl int, sub bool)
return 0, nil, err
}
if len(topic) == 0 {
return 0, nil, errors.New("topic filter cannot be empty")
return 0, nil, errMQTTTopicFilterCannotBeEmpty
}
// Spec [MQTT-3.8.3-1], [MQTT-3.10.3-1]
if !utf8.Valid(topic) {
@@ -3648,7 +3661,7 @@ func (r *mqttReader) readPacketLen() (int, error) {
}
m *= 0x80
if m > 0x200000 {
return 0, errors.New("malformed variable int")
return 0, errMQTTMalformedVarInt
}
}
}

View File

@@ -22,6 +22,8 @@ import (
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"strings"
"sync"
@@ -1381,6 +1383,9 @@ func TestMQTTConnectNotFirstPacket(t *testing.T) {
s := testMQTTRunServer(t, o)
defer testMQTTShutdownServer(s)
l := &captureErrorLogger{errCh: make(chan string, 10)}
s.SetLogger(l, false, false)
c, err := net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port))
if err != nil {
t.Fatalf("Error on dial: %v", err)
@@ -1393,6 +1398,15 @@ func TestMQTTConnectNotFirstPacket(t *testing.T) {
t.Fatalf("Error publishing: %v", err)
}
testMQTTExpectDisconnect(t, c)
select {
case err := <-l.errCh:
if !strings.Contains(err, "should be a CONNECT") {
t.Fatalf("Expected error about first packet being a CONNECT, got %v", err)
}
case <-time.After(time.Second):
t.Fatal("Did not log any error")
}
}
func TestMQTTSecondConnect(t *testing.T) {
@@ -1691,7 +1705,7 @@ func TestMQTTParseSub(t *testing.T) {
{"error reading packet id", []byte{1}, mqttSubscribeFlags, 1, eofr, "reading packet identifier"},
{"missing filters", []byte{0, 1}, mqttSubscribeFlags, 2, nil, "subscribe protocol must contain at least 1 topic filter"},
{"error reading topic", []byte{0, 1, 0, 2, 'a'}, mqttSubscribeFlags, 5, eofr, "topic filter"},
{"empty topic", []byte{0, 1, 0, 0}, mqttSubscribeFlags, 4, nil, "topic filter cannot be empty"},
{"empty topic", []byte{0, 1, 0, 0}, mqttSubscribeFlags, 4, nil, errMQTTTopicFilterCannotBeEmpty.Error()},
{"invalid utf8 topic", []byte{0, 1, 0, 1, 241}, mqttSubscribeFlags, 5, nil, "invalid utf8 for topic filter"},
{"missing qos", []byte{0, 1, 0, 1, 'a'}, mqttSubscribeFlags, 5, nil, "QoS"},
{"invalid qos", []byte{0, 1, 0, 1, 'a', 3}, mqttSubscribeFlags, 6, nil, "subscribe QoS value must be 0, 1 or 2"},
@@ -2903,7 +2917,7 @@ func TestMQTTParseUnsub(t *testing.T) {
{"error reading packet id", []byte{1}, mqttUnsubscribeFlags, 1, eofr, "reading packet identifier"},
{"missing filters", []byte{0, 1}, mqttUnsubscribeFlags, 2, nil, "subscribe protocol must contain at least 1 topic filter"},
{"error reading topic", []byte{0, 1, 0, 2, 'a'}, mqttUnsubscribeFlags, 5, eofr, "topic filter"},
{"empty topic", []byte{0, 1, 0, 0}, mqttUnsubscribeFlags, 4, nil, "topic filter cannot be empty"},
{"empty topic", []byte{0, 1, 0, 0}, mqttUnsubscribeFlags, 4, nil, errMQTTTopicFilterCannotBeEmpty.Error()},
{"invalid utf8 topic", []byte{0, 1, 0, 1, 241}, mqttUnsubscribeFlags, 5, nil, "invalid utf8 for topic filter"},
} {
t.Run(test.name, func(t *testing.T) {
@@ -4526,6 +4540,43 @@ func TestMQTTStreamInfoReturnsNonEmptySubject(t *testing.T) {
}
}
func TestMQTTWebsocketNotSupported(t *testing.T) {
o := testMQTTDefaultOptions()
s := testMQTTRunServer(t, o)
defer testMQTTShutdownServer(s)
l := &captureErrorLogger{errCh: make(chan string, 10)}
s.SetLogger(l, false, false)
addr := fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)
wsc, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("Error creating connection: %v", err)
}
req := testWSCreateValidReq()
req.URL, _ = url.Parse("ws://" + addr)
if err := req.Write(wsc); err != nil {
t.Fatalf("Error sending request: %v", err)
}
br := bufio.NewReader(wsc)
resp, err := http.ReadResponse(br, req)
if err == nil {
if resp != nil {
defer resp.Body.Close()
}
t.Fatalf("Expected error, got resp=%+v", resp)
}
select {
case err := <-l.errCh:
if !strings.Contains(err, "not supported") {
t.Fatalf("Expected error about websocket not supported, got %v", err)
}
case <-time.After(time.Second):
t.Fatal("Did not log any error")
}
}
//////////////////////////////////////////////////////////////////////////
//
// Benchmarks