mirror of
https://github.com/gogrlx/nats-server.git
synced 2026-04-02 03:38:42 -07:00
[IMPROVED] MQTT error message when client connects with websocket
Websocket is currently not supported for MQTT clients. When a client tries to connect with websocket protocol to the MQTT port, the error message: `mid:9 - not connected` would be logged, which is not really telling. The server will now guess if the connection was websocket and report a more appropriate error message, such as: ``` invalid connection, websocket currently not supported ``` Resolves #2126 Signed-off-by: Ivan Kozlovic <ivan@synadia.com>
This commit is contained in:
@@ -170,6 +170,13 @@ var (
|
||||
mqttFlapCleanItvl = mqttSessFlappingCleanupInterval
|
||||
)
|
||||
|
||||
var (
|
||||
errMQTTWebsocketNotSupported = errors.New("invalid connection, websocket currently not supported")
|
||||
errMQTTTopicFilterCannotBeEmpty = errors.New("topic filter cannot be empty")
|
||||
errMQTTMalformedVarInt = errors.New("malformed variable int")
|
||||
errMQTTSecondConnectPacket = errors.New("received a second CONNECT packet")
|
||||
)
|
||||
|
||||
type srvMQTT struct {
|
||||
listener net.Listener
|
||||
listenerErr error
|
||||
@@ -565,7 +572,13 @@ func (c *client) mqttParse(buf []byte) error {
|
||||
// If client was not connected yet, the first packet must be
|
||||
// a mqttPacketConnect otherwise we fail the connection.
|
||||
if !connected && pt != mqttPacketConnect {
|
||||
err = errors.New("not connected")
|
||||
// Try to guess if the client is trying to connect using Websocket,
|
||||
// which is currently not supported
|
||||
if bytes.HasPrefix(buf, []byte("GET ")) {
|
||||
err = errMQTTWebsocketNotSupported
|
||||
} else {
|
||||
err = fmt.Errorf("the first packet should be a CONNECT (%v), got %v", mqttPacketConnect, pt)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
@@ -647,7 +660,7 @@ func (c *client) mqttParse(buf []byte) error {
|
||||
case mqttPacketConnect:
|
||||
// It is an error to receive a second connect packet
|
||||
if connected {
|
||||
err = errors.New("second connect packet")
|
||||
err = errMQTTSecondConnectPacket
|
||||
break
|
||||
}
|
||||
var rc byte
|
||||
@@ -2913,7 +2926,7 @@ func (c *client) mqttParseSubsOrUnsubs(r *mqttReader, b byte, pl int, sub bool)
|
||||
return 0, nil, err
|
||||
}
|
||||
if len(topic) == 0 {
|
||||
return 0, nil, errors.New("topic filter cannot be empty")
|
||||
return 0, nil, errMQTTTopicFilterCannotBeEmpty
|
||||
}
|
||||
// Spec [MQTT-3.8.3-1], [MQTT-3.10.3-1]
|
||||
if !utf8.Valid(topic) {
|
||||
@@ -3648,7 +3661,7 @@ func (r *mqttReader) readPacketLen() (int, error) {
|
||||
}
|
||||
m *= 0x80
|
||||
if m > 0x200000 {
|
||||
return 0, errors.New("malformed variable int")
|
||||
return 0, errMQTTMalformedVarInt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -1381,6 +1383,9 @@ func TestMQTTConnectNotFirstPacket(t *testing.T) {
|
||||
s := testMQTTRunServer(t, o)
|
||||
defer testMQTTShutdownServer(s)
|
||||
|
||||
l := &captureErrorLogger{errCh: make(chan string, 10)}
|
||||
s.SetLogger(l, false, false)
|
||||
|
||||
c, err := net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port))
|
||||
if err != nil {
|
||||
t.Fatalf("Error on dial: %v", err)
|
||||
@@ -1393,6 +1398,15 @@ func TestMQTTConnectNotFirstPacket(t *testing.T) {
|
||||
t.Fatalf("Error publishing: %v", err)
|
||||
}
|
||||
testMQTTExpectDisconnect(t, c)
|
||||
|
||||
select {
|
||||
case err := <-l.errCh:
|
||||
if !strings.Contains(err, "should be a CONNECT") {
|
||||
t.Fatalf("Expected error about first packet being a CONNECT, got %v", err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Did not log any error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMQTTSecondConnect(t *testing.T) {
|
||||
@@ -1691,7 +1705,7 @@ func TestMQTTParseSub(t *testing.T) {
|
||||
{"error reading packet id", []byte{1}, mqttSubscribeFlags, 1, eofr, "reading packet identifier"},
|
||||
{"missing filters", []byte{0, 1}, mqttSubscribeFlags, 2, nil, "subscribe protocol must contain at least 1 topic filter"},
|
||||
{"error reading topic", []byte{0, 1, 0, 2, 'a'}, mqttSubscribeFlags, 5, eofr, "topic filter"},
|
||||
{"empty topic", []byte{0, 1, 0, 0}, mqttSubscribeFlags, 4, nil, "topic filter cannot be empty"},
|
||||
{"empty topic", []byte{0, 1, 0, 0}, mqttSubscribeFlags, 4, nil, errMQTTTopicFilterCannotBeEmpty.Error()},
|
||||
{"invalid utf8 topic", []byte{0, 1, 0, 1, 241}, mqttSubscribeFlags, 5, nil, "invalid utf8 for topic filter"},
|
||||
{"missing qos", []byte{0, 1, 0, 1, 'a'}, mqttSubscribeFlags, 5, nil, "QoS"},
|
||||
{"invalid qos", []byte{0, 1, 0, 1, 'a', 3}, mqttSubscribeFlags, 6, nil, "subscribe QoS value must be 0, 1 or 2"},
|
||||
@@ -2903,7 +2917,7 @@ func TestMQTTParseUnsub(t *testing.T) {
|
||||
{"error reading packet id", []byte{1}, mqttUnsubscribeFlags, 1, eofr, "reading packet identifier"},
|
||||
{"missing filters", []byte{0, 1}, mqttUnsubscribeFlags, 2, nil, "subscribe protocol must contain at least 1 topic filter"},
|
||||
{"error reading topic", []byte{0, 1, 0, 2, 'a'}, mqttUnsubscribeFlags, 5, eofr, "topic filter"},
|
||||
{"empty topic", []byte{0, 1, 0, 0}, mqttUnsubscribeFlags, 4, nil, "topic filter cannot be empty"},
|
||||
{"empty topic", []byte{0, 1, 0, 0}, mqttUnsubscribeFlags, 4, nil, errMQTTTopicFilterCannotBeEmpty.Error()},
|
||||
{"invalid utf8 topic", []byte{0, 1, 0, 1, 241}, mqttUnsubscribeFlags, 5, nil, "invalid utf8 for topic filter"},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
@@ -4526,6 +4540,43 @@ func TestMQTTStreamInfoReturnsNonEmptySubject(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMQTTWebsocketNotSupported(t *testing.T) {
|
||||
o := testMQTTDefaultOptions()
|
||||
s := testMQTTRunServer(t, o)
|
||||
defer testMQTTShutdownServer(s)
|
||||
|
||||
l := &captureErrorLogger{errCh: make(chan string, 10)}
|
||||
s.SetLogger(l, false, false)
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)
|
||||
wsc, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating connection: %v", err)
|
||||
}
|
||||
req := testWSCreateValidReq()
|
||||
req.URL, _ = url.Parse("ws://" + addr)
|
||||
if err := req.Write(wsc); err != nil {
|
||||
t.Fatalf("Error sending request: %v", err)
|
||||
}
|
||||
br := bufio.NewReader(wsc)
|
||||
resp, err := http.ReadResponse(br, req)
|
||||
if err == nil {
|
||||
if resp != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
t.Fatalf("Expected error, got resp=%+v", resp)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-l.errCh:
|
||||
if !strings.Contains(err, "not supported") {
|
||||
t.Fatalf("Expected error about websocket not supported, got %v", err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Did not log any error")
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Benchmarks
|
||||
|
||||
Reference in New Issue
Block a user