diff --git a/server/mqtt.go b/server/mqtt.go index 1c903b56..e86558a2 100644 --- a/server/mqtt.go +++ b/server/mqtt.go @@ -2708,11 +2708,17 @@ CHECK: // Will hold this client for a second and then close it. We // do this so that if the client has a reconnect feature we // don't end-up with very rapid flapping between apps. - time.AfterFunc(mqttSessJailDur, func() { - c.closeConnection(DuplicateClientID) - }) + // We need to wait in place and not schedule the connection + // close because if this is a misbehaved client that does + // not wait for the CONNACK and sends other protocols, the + // server would not have a fully setup client and may panic. asm.mu.Unlock() - return nil + select { + case <-s.quitCh: + case <-time.After(mqttSessJailDur): + } + c.closeConnection(DuplicateClientID) + return ErrConnectionClosed } } // If an existing session is in the process of processing some packet, we can't diff --git a/server/mqtt_test.go b/server/mqtt_test.go index 1105d591..b2a1d6bd 100644 --- a/server/mqtt_test.go +++ b/server/mqtt_test.go @@ -4361,8 +4361,22 @@ func TestMQTTFlappingSession(t *testing.T) { defer c.Close() proto := mqttCreateConnectProto(ci) if _, err := testMQTTWrite(c, proto); err != nil { - t.Fatalf("Error writing connect: %v", err) + t.Fatalf("Error writing protocols: %v", err) } + // Misbehave and send a SUB protocol without waiting for the CONNACK + w := &mqttWriter{} + pkLen := 2 // for pi + // Topic "foo" + pkLen += 2 + 3 + 1 + w.WriteByte(mqttPacketSub | mqttSubscribeFlags) + w.WriteVarInt(pkLen) + w.WriteUint16(1) + w.WriteBytes([]byte("foo")) + w.WriteByte(1) + if _, err := testMQTTWrite(c, w.Bytes()); err != nil { + t.Fatalf("Error writing protocols: %v", err) + } + // Now read the CONNACK and we should have been disconnected. if _, err := testMQTTRead(c); err == nil { t.Fatal("Expected connection to fail") }