From d0fd62c83b9fda44b7fb26f15d23ef54ffc8e8f0 Mon Sep 17 00:00:00 2001 From: Ivan Kozlovic Date: Thu, 22 Apr 2021 08:59:33 -0600 Subject: [PATCH] [IMPROVED] MQTT error message when client connects with websocket Websocket is currently not supported for MQTT clients. When a client tries to connect with websocket protocol to the MQTT port, the error message: `mid:9 - not connected` would be logged, which is not really telling. The server will now guess if the connection was websocket and report a more appropriate error message, such as: ``` invalid connection, websocket currently not supported ``` Resolves #2126 Signed-off-by: Ivan Kozlovic --- server/mqtt.go | 21 +++++++++++++---- server/mqtt_test.go | 55 +++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 70 insertions(+), 6 deletions(-) diff --git a/server/mqtt.go b/server/mqtt.go index 2a846aef..c9a1ee26 100644 --- a/server/mqtt.go +++ b/server/mqtt.go @@ -170,6 +170,13 @@ var ( mqttFlapCleanItvl = mqttSessFlappingCleanupInterval ) +var ( + errMQTTWebsocketNotSupported = errors.New("invalid connection, websocket currently not supported") + errMQTTTopicFilterCannotBeEmpty = errors.New("topic filter cannot be empty") + errMQTTMalformedVarInt = errors.New("malformed variable int") + errMQTTSecondConnectPacket = errors.New("received a second CONNECT packet") +) + type srvMQTT struct { listener net.Listener listenerErr error @@ -565,7 +572,13 @@ func (c *client) mqttParse(buf []byte) error { // If client was not connected yet, the first packet must be // a mqttPacketConnect otherwise we fail the connection. if !connected && pt != mqttPacketConnect { - err = errors.New("not connected") + // Try to guess if the client is trying to connect using Websocket, + // which is currently not supported + if bytes.HasPrefix(buf, []byte("GET ")) { + err = errMQTTWebsocketNotSupported + } else { + err = fmt.Errorf("the first packet should be a CONNECT (%v), got %v", mqttPacketConnect, pt) + } break } @@ -647,7 +660,7 @@ func (c *client) mqttParse(buf []byte) error { case mqttPacketConnect: // It is an error to receive a second connect packet if connected { - err = errors.New("second connect packet") + err = errMQTTSecondConnectPacket break } var rc byte @@ -2913,7 +2926,7 @@ func (c *client) mqttParseSubsOrUnsubs(r *mqttReader, b byte, pl int, sub bool) return 0, nil, err } if len(topic) == 0 { - return 0, nil, errors.New("topic filter cannot be empty") + return 0, nil, errMQTTTopicFilterCannotBeEmpty } // Spec [MQTT-3.8.3-1], [MQTT-3.10.3-1] if !utf8.Valid(topic) { @@ -3648,7 +3661,7 @@ func (r *mqttReader) readPacketLen() (int, error) { } m *= 0x80 if m > 0x200000 { - return 0, errors.New("malformed variable int") + return 0, errMQTTMalformedVarInt } } } diff --git a/server/mqtt_test.go b/server/mqtt_test.go index 2ede9b1d..609c87d2 100644 --- a/server/mqtt_test.go +++ b/server/mqtt_test.go @@ -22,6 +22,8 @@ import ( "fmt" "io" "net" + "net/http" + "net/url" "os" "strings" "sync" @@ -1381,6 +1383,9 @@ func TestMQTTConnectNotFirstPacket(t *testing.T) { s := testMQTTRunServer(t, o) defer testMQTTShutdownServer(s) + l := &captureErrorLogger{errCh: make(chan string, 10)} + s.SetLogger(l, false, false) + c, err := net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)) if err != nil { t.Fatalf("Error on dial: %v", err) @@ -1393,6 +1398,15 @@ func TestMQTTConnectNotFirstPacket(t *testing.T) { t.Fatalf("Error publishing: %v", err) } testMQTTExpectDisconnect(t, c) + + select { + case err := <-l.errCh: + if !strings.Contains(err, "should be a CONNECT") { + t.Fatalf("Expected error about first packet being a CONNECT, got %v", err) + } + case <-time.After(time.Second): + t.Fatal("Did not log any error") + } } func TestMQTTSecondConnect(t *testing.T) { @@ -1691,7 +1705,7 @@ func TestMQTTParseSub(t *testing.T) { {"error reading packet id", []byte{1}, mqttSubscribeFlags, 1, eofr, "reading packet identifier"}, {"missing filters", []byte{0, 1}, mqttSubscribeFlags, 2, nil, "subscribe protocol must contain at least 1 topic filter"}, {"error reading topic", []byte{0, 1, 0, 2, 'a'}, mqttSubscribeFlags, 5, eofr, "topic filter"}, - {"empty topic", []byte{0, 1, 0, 0}, mqttSubscribeFlags, 4, nil, "topic filter cannot be empty"}, + {"empty topic", []byte{0, 1, 0, 0}, mqttSubscribeFlags, 4, nil, errMQTTTopicFilterCannotBeEmpty.Error()}, {"invalid utf8 topic", []byte{0, 1, 0, 1, 241}, mqttSubscribeFlags, 5, nil, "invalid utf8 for topic filter"}, {"missing qos", []byte{0, 1, 0, 1, 'a'}, mqttSubscribeFlags, 5, nil, "QoS"}, {"invalid qos", []byte{0, 1, 0, 1, 'a', 3}, mqttSubscribeFlags, 6, nil, "subscribe QoS value must be 0, 1 or 2"}, @@ -2903,7 +2917,7 @@ func TestMQTTParseUnsub(t *testing.T) { {"error reading packet id", []byte{1}, mqttUnsubscribeFlags, 1, eofr, "reading packet identifier"}, {"missing filters", []byte{0, 1}, mqttUnsubscribeFlags, 2, nil, "subscribe protocol must contain at least 1 topic filter"}, {"error reading topic", []byte{0, 1, 0, 2, 'a'}, mqttUnsubscribeFlags, 5, eofr, "topic filter"}, - {"empty topic", []byte{0, 1, 0, 0}, mqttUnsubscribeFlags, 4, nil, "topic filter cannot be empty"}, + {"empty topic", []byte{0, 1, 0, 0}, mqttUnsubscribeFlags, 4, nil, errMQTTTopicFilterCannotBeEmpty.Error()}, {"invalid utf8 topic", []byte{0, 1, 0, 1, 241}, mqttUnsubscribeFlags, 5, nil, "invalid utf8 for topic filter"}, } { t.Run(test.name, func(t *testing.T) { @@ -4526,6 +4540,43 @@ func TestMQTTStreamInfoReturnsNonEmptySubject(t *testing.T) { } } +func TestMQTTWebsocketNotSupported(t *testing.T) { + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + l := &captureErrorLogger{errCh: make(chan string, 10)} + s.SetLogger(l, false, false) + + addr := fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port) + wsc, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Error creating connection: %v", err) + } + req := testWSCreateValidReq() + req.URL, _ = url.Parse("ws://" + addr) + if err := req.Write(wsc); err != nil { + t.Fatalf("Error sending request: %v", err) + } + br := bufio.NewReader(wsc) + resp, err := http.ReadResponse(br, req) + if err == nil { + if resp != nil { + defer resp.Body.Close() + } + t.Fatalf("Expected error, got resp=%+v", resp) + } + + select { + case err := <-l.errCh: + if !strings.Contains(err, "not supported") { + t.Fatalf("Expected error about websocket not supported, got %v", err) + } + case <-time.After(time.Second): + t.Fatal("Did not log any error") + } +} + ////////////////////////////////////////////////////////////////////////// // // Benchmarks