diff --git a/server/client_test.go b/server/client_test.go index 68e48dcc..64899a9e 100644 --- a/server/client_test.go +++ b/server/client_test.go @@ -85,7 +85,7 @@ func createClientAsync(ch chan *client, s *Server, cli net.Conn) { s.grWG.Add(1) } go func() { - c := s.createClient(cli, nil, nil) + c := s.createClient(cli) // Must be here to suppress +OK c.opts.Verbose = false if startWriteLoop { @@ -2317,7 +2317,7 @@ func TestCloseConnectionVeryEarly(t *testing.T) { // Call again with this closed connection. Alternatively, we // would have to call with a fake connection that implements // net.Conn but returns an error on Write. - s.createClient(c, nil, nil) + s.createClient(c) // This connection should not have been added to the server. checkClientsCount(t, s, 0) diff --git a/server/mqtt.go b/server/mqtt.go index dcc4a957..f47d7675 100644 --- a/server/mqtt.go +++ b/server/mqtt.go @@ -15,6 +15,7 @@ package server import ( "bytes" + "crypto/tls" "encoding/binary" "encoding/json" "errors" @@ -279,10 +280,141 @@ func (s *Server) startMQTT() { scheme = "tls" } s.Noticef("Listening for MQTT clients on %s://%s:%d", scheme, o.Host, o.Port) - go s.acceptConnections(hl, "MQTT", func(conn net.Conn) { s.createClient(conn, nil, &mqtt{}) }, nil) + go s.acceptConnections(hl, "MQTT", func(conn net.Conn) { s.createMQTTClient(conn) }, nil) s.mu.Unlock() } +// This is similar to createClient() but has some modifications specifi to MQTT clients. +// The comments have been kept to minimum to reduce code size. Check createClient() for +// more details. +func (s *Server) createMQTTClient(conn net.Conn) *client { + opts := s.getOpts() + + maxPay := int32(opts.MaxPayload) + maxSubs := int32(opts.MaxSubs) + if maxSubs == 0 { + maxSubs = -1 + } + now := time.Now() + + c := &client{srv: s, nc: conn, mpay: maxPay, msubs: maxSubs, start: now, last: now, mqtt: &mqtt{}} + // MQTT clients don't send NATS CONNECT protocols. So make it an "echo" + // client, but disable verbose and pedantic (by not setting them). + c.opts.Echo = true + + c.registerWithAccount(s.globalAccount()) + + s.mu.Lock() + // Check auth, override if applicable. + authRequired := s.info.AuthRequired || s.mqtt.authOverride + s.totalClients++ + s.mu.Unlock() + + c.mu.Lock() + if authRequired { + c.flags.set(expectConnect) + } + c.initClient() + c.Debugf("Client connection created") + c.mu.Unlock() + + s.mu.Lock() + if !s.running || s.ldm { + if s.shutdown { + conn.Close() + } + s.mu.Unlock() + return c + } + + if opts.MaxConn > 0 && len(s.clients) >= opts.MaxConn { + s.mu.Unlock() + c.maxConnExceeded() + return nil + } + s.clients[c.cid] = c + + tlsRequired := opts.MQTT.TLSConfig != nil + s.mu.Unlock() + + c.mu.Lock() + + isClosed := c.isClosed() + + var pre []byte + if !isClosed && tlsRequired && opts.AllowNonTLS { + pre = make([]byte, 4) + c.nc.SetReadDeadline(time.Now().Add(secondsToDuration(opts.MQTT.TLSTimeout))) + n, _ := io.ReadFull(c.nc, pre[:]) + c.nc.SetReadDeadline(time.Time{}) + pre = pre[:n] + if n > 0 && pre[0] == 0x16 { + tlsRequired = true + } else { + tlsRequired = false + } + } + + if !isClosed && tlsRequired { + c.Debugf("Starting TLS client connection handshake") + if len(pre) > 0 { + c.nc = &tlsMixConn{c.nc, bytes.NewBuffer(pre)} + pre = nil + } + + c.nc = tls.Server(c.nc, opts.MQTT.TLSConfig) + conn := c.nc.(*tls.Conn) + + ttl := secondsToDuration(opts.MQTT.TLSTimeout) + time.AfterFunc(ttl, func() { tlsTimeout(c, conn) }) + conn.SetReadDeadline(time.Now().Add(ttl)) + + c.mu.Unlock() + if err := conn.Handshake(); err != nil { + c.Errorf("TLS handshake error: %v", err) + c.closeConnection(TLSHandshakeError) + return nil + } + conn.SetReadDeadline(time.Time{}) + + c.mu.Lock() + + c.flags.set(handshakeComplete) + + isClosed = c.isClosed() + } + + if isClosed { + c.mu.Unlock() + c.closeConnection(WriteError) + return nil + } + + if authRequired { + timeout := opts.AuthTimeout + // Possibly override with MQTT specific value. + if opts.MQTT.AuthTimeout != 0 { + timeout = opts.MQTT.AuthTimeout + } + c.setAuthTimer(secondsToDuration(timeout)) + } + + // No Ping timer for MQTT clients... + + s.startGoRoutine(func() { c.readLoop(pre) }) + s.startGoRoutine(func() { c.writeLoop() }) + + if tlsRequired { + c.Debugf("TLS handshake complete") + cs := c.nc.(*tls.Conn).ConnectionState() + c.Debugf("TLS version %s, cipher suite %s", tlsVersion(cs.Version), tlsCipher(cs.CipherSuite)) + } + + c.mu.Unlock() + + return c +} + // Given the mqtt options, we check if any auth configuration // has been provided. If so, possibly create users/nkey users and // store them in s.mqtt.users/nkeys. diff --git a/server/server.go b/server/server.go index 59767929..88d03794 100644 --- a/server/server.go +++ b/server/server.go @@ -1765,7 +1765,7 @@ func (s *Server) AcceptLoop(clr chan struct{}) { s.clientConnectURLs = s.getClientConnectURLs() s.listener = l - go s.acceptConnections(l, "Client", func(conn net.Conn) { s.createClient(conn, nil, nil) }, + go s.acceptConnections(l, "Client", func(conn net.Conn) { s.createClient(conn) }, func(_ error) bool { if s.isLameDuckMode() { // Signal that we are not accepting new clients @@ -2091,7 +2091,7 @@ func (c *tlsMixConn) Read(b []byte) (int, error) { return c.Conn.Read(b) } -func (s *Server) createClient(conn net.Conn, ws *websocket, mqtt *mqtt) *client { +func (s *Server) createClient(conn net.Conn) *client { // Snapshot server options. opts := s.getOpts() @@ -2103,14 +2103,7 @@ func (s *Server) createClient(conn net.Conn, ws *websocket, mqtt *mqtt) *client } now := time.Now() - c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: maxPay, msubs: maxSubs, start: now, last: now, ws: ws} - if mqtt != nil { - c.mqtt = mqtt - // Set some of the options here since MQTT clients don't - // send a regular CONNECT (but have their own). - c.opts.Lang = "mqtt" - c.opts.Verbose = false - } + c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: maxPay, msubs: maxSubs, start: now, last: now} c.registerWithAccount(s.globalAccount()) @@ -2118,28 +2111,17 @@ func (s *Server) createClient(conn net.Conn, ws *websocket, mqtt *mqtt) *client var authRequired bool s.mu.Lock() - // We don't need the INFO to mqtt clients. - if mqtt == nil { - // Grab JSON info string - info = s.copyInfo() - // If this is a websocket client and there is no top-level auth specified, - // then we use the websocket's specific boolean that will be set to true - // if there is any auth{} configured in websocket{}. - if ws != nil && !info.AuthRequired { - info.AuthRequired = s.websocket.authOverride - } - if s.nonceRequired() { - // Nonce handling - var raw [nonceLen]byte - nonce := raw[:] - s.generateNonce(nonce) - info.Nonce = string(nonce) - } - c.nonce = []byte(info.Nonce) - authRequired = info.AuthRequired - } else { - authRequired = s.info.AuthRequired || s.mqtt.authOverride + // Grab JSON info string + info = s.copyInfo() + if s.nonceRequired() { + // Nonce handling + var raw [nonceLen]byte + nonce := raw[:] + s.generateNonce(nonce) + info.Nonce = string(nonce) } + c.nonce = []byte(info.Nonce) + authRequired = info.AuthRequired s.totalClients++ s.mu.Unlock() @@ -2155,13 +2137,10 @@ func (s *Server) createClient(conn net.Conn, ws *websocket, mqtt *mqtt) *client c.Debugf("Client connection created") - // We don't send the INFO to mqtt clients. - if mqtt == nil { - // Send our information. - // Need to be sent in place since writeLoop cannot be started until - // TLS handshake is done (if applicable). - c.sendProtoNow(c.generateClientInfoJSON(info)) - } + // Send our information. + // Need to be sent in place since writeLoop cannot be started until + // TLS handshake is done (if applicable). + c.sendProtoNow(c.generateClientInfoJSON(info)) // Unlock to register c.mu.Unlock() @@ -2194,21 +2173,7 @@ func (s *Server) createClient(conn net.Conn, ws *websocket, mqtt *mqtt) *client } s.clients[c.cid] = c - // May be overridden based on type of client. - TLSConfig := opts.TLSConfig - TLSTimeout := opts.TLSTimeout - tlsRequired := info.TLSRequired - // Websocket clients do TLS in the websocket http server. - if ws != nil { - tlsRequired = false - } else if mqtt != nil { - tlsRequired = opts.MQTT.TLSConfig != nil - if tlsRequired { - TLSConfig = opts.MQTT.TLSConfig - TLSTimeout = opts.MQTT.TLSTimeout - } - } s.mu.Unlock() // Re-Grab lock @@ -2222,7 +2187,7 @@ func (s *Server) createClient(conn net.Conn, ws *websocket, mqtt *mqtt) *client // one the client wants. if !isClosed && opts.TLSConfig != nil && opts.AllowNonTLS { pre = make([]byte, 4) - c.nc.SetReadDeadline(time.Now().Add(secondsToDuration(TLSTimeout))) + c.nc.SetReadDeadline(time.Now().Add(secondsToDuration(opts.TLSTimeout))) n, _ := io.ReadFull(c.nc, pre[:]) c.nc.SetReadDeadline(time.Time{}) pre = pre[:n] @@ -2243,11 +2208,11 @@ func (s *Server) createClient(conn net.Conn, ws *websocket, mqtt *mqtt) *client pre = nil } - c.nc = tls.Server(c.nc, TLSConfig) + c.nc = tls.Server(c.nc, opts.TLSConfig) conn := c.nc.(*tls.Conn) // Setup the timeout - ttl := secondsToDuration(TLSTimeout) + ttl := secondsToDuration(opts.TLSTimeout) time.AfterFunc(ttl, func() { tlsTimeout(c, conn) }) conn.SetReadDeadline(time.Now().Add(ttl)) @@ -2285,17 +2250,7 @@ func (s *Server) createClient(conn net.Conn, ws *websocket, mqtt *mqtt) *client // the race where the timer fires during the handshake and causes the // server to write bad data to the socket. See issue #432. if authRequired { - timeout := opts.AuthTimeout - // For websocket, possibly override only if set. We make sure that - // opts.AuthTimeout is set to a default value if not configured, - // but we don't do the same for websocket's one so that we know - // if user has explicitly set or not. - if ws != nil && opts.Websocket.AuthTimeout != 0 { - timeout = opts.Websocket.AuthTimeout - } else if mqtt != nil && opts.MQTT.AuthTimeout != 0 { - timeout = opts.MQTT.AuthTimeout - } - c.setAuthTimer(secondsToDuration(timeout)) + c.setAuthTimer(secondsToDuration(opts.AuthTimeout)) } // Do final client initialization diff --git a/server/websocket.go b/server/websocket.go index 10defdc4..c29722f5 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -871,7 +871,7 @@ func (s *Server) startWebsocketServer() { s.Errorf(err.Error()) return } - s.createClient(res.conn, res.ws, nil) + s.createWSClient(res.conn, res.ws) }) hs := &http.Server{ Addr: hp, @@ -897,6 +897,102 @@ func (s *Server) startWebsocketServer() { s.mu.Unlock() } +// This is similar to createClient() but has some modifications +// specific to handle websocket clients. +// The comments have been kept to minimum to reduce code size. +// Check createClient() for more details. +func (s *Server) createWSClient(conn net.Conn, ws *websocket) *client { + opts := s.getOpts() + + maxPay := int32(opts.MaxPayload) + maxSubs := int32(opts.MaxSubs) + if maxSubs == 0 { + maxSubs = -1 + } + now := time.Now() + + c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: maxPay, msubs: maxSubs, start: now, last: now, ws: ws} + + c.registerWithAccount(s.globalAccount()) + + var info Info + var authRequired bool + + s.mu.Lock() + info = s.copyInfo() + // Check auth, override if applicable. + if !info.AuthRequired { + // Set info.AuthRequired since this is what is sent to the client. + info.AuthRequired = s.websocket.authOverride + } + if s.nonceRequired() { + var raw [nonceLen]byte + nonce := raw[:] + s.generateNonce(nonce) + info.Nonce = string(nonce) + } + c.nonce = []byte(info.Nonce) + authRequired = info.AuthRequired + + s.totalClients++ + s.mu.Unlock() + + c.mu.Lock() + if authRequired { + c.flags.set(expectConnect) + } + c.initClient() + c.Debugf("Client connection created") + c.sendProtoNow(c.generateClientInfoJSON(info)) + c.mu.Unlock() + + s.mu.Lock() + if !s.running || s.ldm { + if s.shutdown { + conn.Close() + } + s.mu.Unlock() + return c + } + + if opts.MaxConn > 0 && len(s.clients) >= opts.MaxConn { + s.mu.Unlock() + c.maxConnExceeded() + return nil + } + s.clients[c.cid] = c + + // Websocket clients do TLS in the websocket http server. + // So no TLS here... + s.mu.Unlock() + + c.mu.Lock() + + if c.isClosed() { + c.mu.Unlock() + c.closeConnection(WriteError) + return nil + } + + if authRequired { + timeout := opts.AuthTimeout + // Possibly override with Websocket specific value. + if opts.Websocket.AuthTimeout != 0 { + timeout = opts.Websocket.AuthTimeout + } + c.setAuthTimer(secondsToDuration(timeout)) + } + + c.setPingTimer() + + s.startGoRoutine(func() { c.readLoop(nil) }) + s.startGoRoutine(func() { c.writeLoop() }) + + c.mu.Unlock() + + return c +} + type wsCaptureHTTPServerLog struct { s *Server }