From 67425d23c8b4edfe5749a0e6b7e5be8a28d35fa5 Mon Sep 17 00:00:00 2001 From: Ivan Kozlovic Date: Wed, 2 Dec 2020 15:52:06 -0700 Subject: [PATCH] Add c.isMqtt() and c.isWebsocket() This hides the check on "c.mqtt != nil" or "c.ws != nil". Added some tests. Signed-off-by: Ivan Kozlovic --- server/auth.go | 9 ++++--- server/client.go | 51 ++++++++++++++++++++-------------------- server/mqtt.go | 8 ++++++- server/mqtt_test.go | 51 ++++++++++++++++++++++++++++++++++++++-- server/parser.go | 2 +- server/route.go | 2 +- server/server.go | 2 +- server/websocket.go | 6 +++++ server/websocket_test.go | 2 +- 9 files changed, 96 insertions(+), 37 deletions(-) diff --git a/server/auth.go b/server/auth.go index da27dc7e..e925b898 100644 --- a/server/auth.go +++ b/server/auth.go @@ -345,11 +345,10 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) boo ) s.mu.Lock() authRequired := s.info.AuthRequired - // c.ws/mqtt is immutable, but may need lock if we get race reports. if !authRequired { - if c.mqtt != nil { + if c.isMqtt() { authRequired = s.mqtt.authOverride - } else if c.ws != nil { + } else if c.isWebsocket() { // If no auth required for regular clients, then check if // we have an override for websocket clients. authRequired = s.websocket.authOverride @@ -367,7 +366,7 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) boo noAuthUser string ) tlsMap := opts.TLSMap - if c.mqtt != nil { + if c.isMqtt() { mo := &opts.MQTT // Always override TLSMap. tlsMap = mo.TLSMap @@ -380,7 +379,7 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) boo token = mo.Token ao = true } - } else if c.ws != nil { + } else if c.isWebsocket() { wo := &opts.Websocket // Always override TLSMap. tlsMap = wo.TLSMap diff --git a/server/client.go b/server/client.go index 485a72df..cd43d58a 100644 --- a/server/client.go +++ b/server/client.go @@ -541,9 +541,9 @@ func (c *client) initClient() { switch c.kind { case CLIENT: name := "cid" - if c.ws != nil { + if c.isWebsocket() { name = "wid" - } else if c.mqtt != nil { + } else if c.isMqtt() { name = "mid" } c.ncs.Store(fmt.Sprintf("%s - %s:%d", conn, name, c.cid)) @@ -968,8 +968,8 @@ func (c *client) readLoop(pre []byte) { return } nc := c.nc - ws := c.ws != nil - if c.mqtt != nil { + ws := c.isWebsocket() + if c.isMqtt() { c.mqtt.r = &mqttReader{reader: nc} } c.in.rsz = startBufSize @@ -991,7 +991,7 @@ func (c *client) readLoop(pre []byte) { c.mu.Unlock() defer func() { - if c.mqtt != nil { + if c.isMqtt() { s.mqttHandleWill(c) } // These are used only in the readloop, so we can set them to nil @@ -1155,7 +1155,7 @@ func closedStateForErr(err error) ClosedState { // collapsePtoNB will place primary onto nb buffer as needed in prep for WriteTo. // This will return a copy on purpose. func (c *client) collapsePtoNB() (net.Buffers, int64) { - if c.ws != nil { + if c.isWebsocket() { return c.wsCollapsePtoNB() } if c.out.p != nil { @@ -1169,7 +1169,7 @@ func (c *client) collapsePtoNB() (net.Buffers, int64) { // This will handle the fixup needed on a partial write. // Assume pending has been already calculated correctly. func (c *client) handlePartialWrite(pnb net.Buffers) { - if c.ws != nil { + if c.isWebsocket() { c.ws.frames = append(pnb, c.ws.frames...) return } @@ -1263,7 +1263,7 @@ func (c *client) flushOutbound() bool { // Subtract from pending bytes and messages. c.out.pb -= n - if c.ws != nil { + if c.isWebsocket() { c.ws.fs -= n } c.out.pm -= apm // FIXME(dlc) - this will not be totally accurate on partials. @@ -1382,7 +1382,7 @@ func (c *client) markConnAsClosed(reason ClosedState) { c.flags.set(connMarkedClosed) // For a websocket client, unless we are told not to flush, enqueue // a websocket CloseMessage based on the reason. - if !skipFlush && c.ws != nil && !c.ws.closeSent { + if !skipFlush && c.isWebsocket() && !c.ws.closeSent { c.wsEnqueueCloseMessage(reason) } // Be consistent with the creation: for routes and gateways, @@ -1942,7 +1942,7 @@ func (c *client) sendRTTPing() bool { // the c.rtt is 0 and wants to force an update by sending a PING. // Client lock held on entry. func (c *client) sendRTTPingLocked() bool { - if c.mqtt != nil { + if c.isMqtt() { return false } // Most client libs send a CONNECT+PING and wait for a PONG from the @@ -1975,7 +1975,7 @@ func (c *client) generateClientInfoJSON(info Info) []byte { info.CID = c.cid info.ClientIP = c.host info.MaxPayload = c.mpay - if c.ws != nil { + if c.isWebsocket() { info.ClientConnectURLs = info.WSConnectURLs } info.WSConnectURLs = nil @@ -1990,7 +1990,7 @@ func (c *client) sendErr(err string) { if c.trace { c.traceOutOp("-ERR", []byte(err)) } - if c.mqtt == nil { + if !c.isMqtt() { c.enqueueProto([]byte(fmt.Sprintf(errProto, err))) } c.mu.Unlock() @@ -3326,7 +3326,7 @@ func (c *client) processInboundClientMsg(msg []byte) bool { } // If MQTT client, check for retain flag now that we have passed permissions check - if c.mqtt != nil { + if c.isMqtt() { c.mqttHandlePubRetain() } @@ -4002,7 +4002,7 @@ func (c *client) processPingTimer() { c.mu.Lock() c.ping.tmr = nil // Check if connection is still opened - if c.isClosed() || c.mqtt != nil { + if c.isClosed() { c.mu.Unlock() return } @@ -4061,7 +4061,7 @@ func adjustPingIntervalForGateway(d time.Duration) time.Duration { // Lock should be held func (c *client) setPingTimer() { - if c.srv == nil || c.mqtt != nil { + if c.srv == nil { return } d := c.srv.getOpts().PingInterval @@ -4637,18 +4637,19 @@ func (c *client) connectionTypeAllowed(acts map[string]struct{}) bool { if len(acts) == 0 { return true } - // Assume standard client, then update based on presence of websocket - // or other type. - want := jwt.ConnectionTypeStandard - if c.kind == LEAF { + var want string + switch c.kind { + case CLIENT: + if c.isWebsocket() { + want = jwt.ConnectionTypeWebsocket + } else if c.isMqtt() { + want = jwt.ConnectionTypeMqtt + } else { + want = jwt.ConnectionTypeStandard + } + case LEAF: want = jwt.ConnectionTypeLeafnode } - if c.ws != nil { - want = jwt.ConnectionTypeWebsocket - } - if c.mqtt != nil { - want = jwt.ConnectionTypeMqtt - } _, ok := acts[want] return ok } diff --git a/server/mqtt.go b/server/mqtt.go index f47d7675..fba4a644 100644 --- a/server/mqtt.go +++ b/server/mqtt.go @@ -456,6 +456,12 @@ func validateMQTTOptions(o *Options) error { return nil } +// Returns true if this connection is from a MQTT client. +// Lock held on entry. +func (c *client) isMqtt() bool { + return c.mqtt != nil +} + // Parse protocols inside the given buffer. // This is invoked from the readLoop. func (c *client) mqttParse(buf []byte) error { @@ -2019,7 +2025,7 @@ func mqttDeliverMsgCb(sub *subscription, pc *client, subject, reply string, msg } // In JS case, we need to use the pc.ca.deliver value as the subject. subject = string(pc.pa.deliver) - } else if pc.mqtt != nil { + } else if pc.isMqtt() { // This is a MQTT publisher... ppFlags = pc.mqtt.pp.flags pQoS = mqttGetQoS(ppFlags) diff --git a/server/mqtt_test.go b/server/mqtt_test.go index 66375287..23ebefe3 100644 --- a/server/mqtt_test.go +++ b/server/mqtt_test.go @@ -570,7 +570,7 @@ func testMQTTGetClient(t testing.TB, s *Server, clientID string) *client { s.mu.Lock() for _, c := range s.clients { c.mu.Lock() - if c.mqtt != nil && c.mqtt.cp != nil && c.mqtt.cp.clientID == clientID { + if c.isMqtt() && c.mqtt.cp != nil && c.mqtt.cp.clientID == clientID { mc = c } c.mu.Unlock() @@ -725,6 +725,30 @@ func testMQTTCheckConnAck(t testing.TB, r *mqttReader, rc byte, sessionPresent b } } +func TestMQTTRequiresJSEnabled(t *testing.T) { + o := testMQTTDefaultOptions() + acc := NewAccount("mqtt") + o.Accounts = []*Account{acc} + o.Users = []*User{&User{Username: "mqtt", Account: acc}} + s := testMQTTRunServer(t, o) + defer s.Shutdown() + + addr := fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port) + c, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Error creating mqtt connection: %v", err) + } + defer c.Close() + + proto := mqttCreateConnectProto(&mqttConnInfo{cleanSess: true, user: "mqtt"}) + if _, err := testMQTTWrite(c, proto); err != nil { + t.Fatalf("Error writing connect: %v", err) + } + if _, err := testMQTTRead(c); err == nil { + t.Fatal("Expected failure, did not get one") + } +} + func testMQTTEnableJSForAccount(t *testing.T, s *Server, accName string) { t.Helper() acc, err := s.LookupAccount(accName) @@ -835,7 +859,7 @@ func TestMQTTTLSVerifyAndMap(t *testing.T) { s.mu.Lock() for _, sc := range s.clients { sc.mu.Lock() - if sc.mqtt != nil { + if sc.isMqtt() { c = sc } sc.mu.Unlock() @@ -1380,6 +1404,29 @@ func TestMQTTConnKeepAlive(t *testing.T) { testMQTTExpectDisconnect(t, mc) } +func TestMQTTDontSetPinger(t *testing.T) { + o := testMQTTDefaultOptions() + o.PingInterval = 15 * time.Millisecond + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + mc, r := testMQTTConnect(t, &mqttConnInfo{clientID: "mqtt", cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer mc.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + + c := testMQTTGetClient(t, s, "mqtt") + c.mu.Lock() + timerSet := c.ping.tmr != nil + c.mu.Unlock() + if timerSet { + t.Fatalf("Ping timer should not be set for MQTT clients") + } + + // Wait a bit and expect nothing (and connection should still be valid) + testMQTTExpectNothing(t, r) + testMQTTPublish(t, mc, r, 0, false, false, "foo", 0, []byte("msg")) +} + func TestMQTTTopicAndSubjectConversion(t *testing.T) { for _, test := range []struct { name string diff --git a/server/parser.go b/server/parser.go index f76337df..ad2a83b3 100644 --- a/server/parser.go +++ b/server/parser.go @@ -132,7 +132,7 @@ const ( func (c *client) parse(buf []byte) error { // Branch out to mqtt clients. c.mqtt is immutable, but should it become // an issue (say data race detection), we could branch outside in readLoop - if c.mqtt != nil { + if c.isMqtt() { return c.mqttParse(buf) } var i int diff --git a/server/route.go b/server/route.go index 56b2f238..9ad52421 100644 --- a/server/route.go +++ b/server/route.go @@ -726,7 +726,7 @@ func (s *Server) sendAsyncInfoToClients(regCli, wsCli bool) { // registered (server has received CONNECT and first PING). For // clients that are not at this stage, this will happen in the // processing of the first PING (see client.processPing) - if ((regCli && c.ws == nil) || (wsCli && c.ws != nil)) && + if ((regCli && !c.isWebsocket()) || (wsCli && c.isWebsocket())) && c.opts.Protocol >= ClientProtoInfo && c.flags.isSet(firstPongSent) { // sendInfo takes care of checking if the connection is still diff --git a/server/server.go b/server/server.go index 88d03794..d64cddd2 100644 --- a/server/server.go +++ b/server/server.go @@ -2431,7 +2431,7 @@ func (s *Server) removeClient(c *client) { if updateProtoInfoCount { s.cproto-- } - mqtt := c.mqtt != nil + mqtt := c.isMqtt() s.mu.Unlock() if mqtt { s.mqttHandleClosedClient(c) diff --git a/server/websocket.go b/server/websocket.go index c29722f5..901b4017 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -152,6 +152,12 @@ func wsGet(r io.Reader, buf []byte, pos, needed int) ([]byte, int, error) { return b, pos + avail, nil } +// Returns true if this connection is from a Websocket client. +// Lock held on entry. +func (c *client) isWebsocket() bool { + return c.ws != nil +} + // Returns a slice of byte slices corresponding to payload of websocket frames. // The byte slice `buf` is filled with bytes from the connection's read loop. // This function will decode the frame headers and unmask the payload(s). diff --git a/server/websocket_test.go b/server/websocket_test.go index e0b075fb..101d1ba1 100644 --- a/server/websocket_test.go +++ b/server/websocket_test.go @@ -2567,7 +2567,7 @@ func TestWSWebrowserClient(t *testing.T) { } c.mu.Lock() - ok := c.ws != nil && c.ws.browser == true + ok := c.isWebsocket() && c.ws.browser == true c.mu.Unlock() if !ok { t.Fatalf("Client is not marked as webrowser client")