diff --git a/server/client_test.go b/server/client_test.go index f1fa8966..39468c1a 100644 --- a/server/client_test.go +++ b/server/client_test.go @@ -2331,34 +2331,56 @@ func TestCloseConnectionLogsReason(t *testing.T) { } func TestCloseConnectionVeryEarly(t *testing.T) { - o := DefaultOptions() - s := RunServer(o) - defer s.Shutdown() + for _, test := range []struct { + name string + useTLS bool + }{ + {"no_tls", false}, + {"tls", true}, + } { + t.Run(test.name, func(t *testing.T) { + o := DefaultOptions() + if test.useTLS { + tc := &TLSConfigOpts{ + CertFile: "../test/configs/certs/server-cert.pem", + KeyFile: "../test/configs/certs/server-key.pem", + CaFile: "../test/configs/certs/ca.pem", + } + tlsConfig, err := GenTLSConfig(tc) + if err != nil { + t.Fatalf("Error generating tls config: %v", err) + } + o.TLSConfig = tlsConfig + } + s := RunServer(o) + defer s.Shutdown() - // The issue was with a connection that would break right when - // server was sending the INFO. Creating a bare TCP connection - // and closing it right away won't help reproduce the problem. - // So testing in 2 steps. + // The issue was with a connection that would break right when + // server was sending the INFO. Creating a bare TCP connection + // and closing it right away won't help reproduce the problem. + // So testing in 2 steps. - // Get a normal TCP connection to the server. - c, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", o.Port)) - if err != nil { - s.mu.Unlock() - t.Fatalf("Unable to create tcp connection") + // Get a normal TCP connection to the server. + c, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", o.Port)) + if err != nil { + s.mu.Unlock() + t.Fatalf("Unable to create tcp connection") + } + // Now close it. + c.Close() + + // Wait that num clients falls to 0. + checkClientsCount(t, s, 0) + + // Call again with this closed connection. Alternatively, we + // would have to call with a fake connection that implements + // net.Conn but returns an error on Write. + s.createClient(c, nil) + + // This connection should not have been added to the server. + checkClientsCount(t, s, 0) + }) } - // Now close it. - c.Close() - - // Wait that num clients falls to 0. - checkClientsCount(t, s, 0) - - // Call again with this closed connection. Alternatively, we - // would have to call with a fake connection that implements - // net.Conn but returns an error on Write. - s.createClient(c, nil) - - // This connection should not have been added to the server. - checkClientsCount(t, s, 0) } type connAddrString struct { diff --git a/server/const.go b/server/const.go index b00f5091..2f37ba93 100644 --- a/server/const.go +++ b/server/const.go @@ -40,7 +40,7 @@ var ( const ( // VERSION is the current version for the server. - VERSION = "2.2.0-beta.22" + VERSION = "2.2.0-beta.23" // PROTO is the currently supported protocol. // 0 was the original diff --git a/server/server.go b/server/server.go index 4ae18d40..2d864f1a 100644 --- a/server/server.go +++ b/server/server.go @@ -2074,11 +2074,14 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { // Re-Grab lock c.mu.Lock() + // Connection could have been closed while sending the INFO proto. + isClosed := c.isClosed() + tlsRequired := ws == nil && info.TLSRequired var pre []byte // If we have both TLS and non-TLS allowed we need to see which // one the client wants. - if opts.TLSConfig != nil && opts.AllowNonTLS { + if !isClosed && opts.TLSConfig != nil && opts.AllowNonTLS { pre = make([]byte, 4) c.nc.SetReadDeadline(time.Now().Add(secondsToDuration(opts.TLSTimeout))) n, _ := io.ReadFull(c.nc, pre[:]) @@ -2092,7 +2095,7 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { } // Check for TLS - if tlsRequired { + if !isClosed && tlsRequired { c.Debugf("Starting TLS client connection handshake") // If we have a prebuffer create a multi-reader. if len(pre) > 0 { @@ -2124,17 +2127,18 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { // Indicate that handshake is complete (used in monitoring) c.flags.set(handshakeComplete) + + // The connection may have been closed + isClosed = c.isClosed() } - // The connection may have been closed - if c.isClosed() { + // If connection is marked as closed, bail out. + if isClosed { c.mu.Unlock() - // If it was due to TLS timeout, closeConnection() has already been called. - // Otherwise, if connection was marked as closed while sending the INFO, - // we need to call closeConnection() directly here. - if !info.TLSRequired { - c.closeConnection(WriteError) - } + // Connection could have been closed due to TLS timeout or while trying + // to send the INFO protocol. We need to call closeConnection() to make + // sure that proper cleanup is done. + c.closeConnection(WriteError) return nil }