diff --git a/server/client.go b/server/client.go index de1a0c4c..e67ef5b6 100644 --- a/server/client.go +++ b/server/client.go @@ -92,7 +92,7 @@ func init() { } // Lock should be held -func (c *client) initClient() { +func (c *client) initClient(tlsConn bool) { s := c.srv c.cid = atomic.AddUint64(&s.gcid, 1) c.bw = bufio.NewWriterSize(c.nc, s.opts.BufSize) @@ -130,11 +130,13 @@ func (c *client) initClient() { // ip.SetWriteBuffer(2 * s.opts.BufSize) // } - // Set the Ping timer - c.setPingTimer() + if !tlsConn { + // Set the Ping timer + c.setPingTimer() - // Spin up the read loop. - go c.readLoop() + // Spin up the read loop. + go c.readLoop() + } } func (c *client) readLoop() { diff --git a/server/route.go b/server/route.go index d79390c7..fe279057 100644 --- a/server/route.go +++ b/server/route.go @@ -138,7 +138,7 @@ func (s *Server) createRoute(conn net.Conn, rURL *url.URL) *client { c.mu.Lock() // Initialize - c.initClient() + c.initClient(tlsRequired) c.Debugf("Route connection created") @@ -182,6 +182,14 @@ func (s *Server) createRoute(conn net.Conn, rURL *url.URL) *client { // Rewrap bw c.bw = bufio.NewWriterSize(c.nc, s.opts.BufSize) + // Do final client initialization + + // Set the Ping timer + c.setPingTimer() + + // Spin up the read loop. + go c.readLoop() + c.Debugf("TLS handshake complete") cs := conn.ConnectionState() c.Debugf("TLS version %s, cipher suite %s", tlsVersion(cs.Version), tlsCipher(cs.CipherSuite)) diff --git a/server/server.go b/server/server.go index 397a2bae..106e35e7 100644 --- a/server/server.go +++ b/server/server.go @@ -439,7 +439,7 @@ func (s *Server) createClient(conn net.Conn) *client { c.mu.Lock() // Initialize - c.initClient() + c.initClient(tlsRequired) c.Debugf("Client connection created") @@ -491,6 +491,14 @@ func (s *Server) createClient(conn net.Conn) *client { // Rewrap bw c.bw = bufio.NewWriterSize(c.nc, s.opts.BufSize) + // Do final client initialization + + // Set the Ping timer + c.setPingTimer() + + // Spin up the read loop. + go c.readLoop() + c.Debugf("TLS handshake complete") cs := conn.ConnectionState() c.Debugf("TLS version %s, cipher suite %s", tlsVersion(cs.Version), tlsCipher(cs.CipherSuite)) diff --git a/test/tls_test.go b/test/tls_test.go index cf840692..a484d2be 100644 --- a/test/tls_test.go +++ b/test/tls_test.go @@ -10,9 +10,11 @@ import ( "io/ioutil" "net" "strings" + "sync" "testing" "time" + "github.com/nats-io/gnatsd/server" "github.com/nats-io/nats" ) @@ -161,3 +163,81 @@ func TestTLSConnectionTimeout(t *testing.T) { t.Fatalf("TLS Timeout response incorrect: %q\n", tlsErr) } } + +func stressConnect(t *testing.T, wg *sync.WaitGroup, errCh chan error, url string, index int) { + defer wg.Done() + + subName := fmt.Sprintf("foo.%d", index) + + for i := 0; i < 100; i++ { + nc, err := nats.SecureConnect(url) + if err != nil { + errCh <- fmt.Errorf("Unable to create TLS connection: %v\n", err) + return + } + defer nc.Close() + + sub, err := nc.SubscribeSync(subName) + if err != nil { + errCh <- fmt.Errorf("Unable to subscribe on '%s': %v\n", subName, err) + return + } + + if err := nc.Publish(subName, []byte("secure data")); err != nil { + errCh <- fmt.Errorf("Unable to send on '%s': %v\n", subName, err) + } + + if _, err := sub.NextMsg(2 * time.Second); err != nil { + errCh <- fmt.Errorf("Unable to get next message: %v\n", err) + } + + nc.Close() + } + + errCh <- nil +} + +func TestTLSStressConnect(t *testing.T) { + opts, err := server.ProcessConfigFile("./configs/tls.conf") + if err != nil { + panic(fmt.Sprintf("Error processing configuration file: %v", err)) + } + opts.NoSigs, opts.NoLog = true, true + + // For this test, remove the authorization + opts.Authorization = "" + + // Increase ssl timeout + opts.TLSTimeout = 2.0 + + srv := RunServer(opts) + defer srv.Shutdown() + + nurl := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + + threadCount := 3 + + errCh := make(chan error, threadCount) + + var wg sync.WaitGroup + wg.Add(threadCount) + + for i := 0; i < threadCount; i++ { + go stressConnect(t, &wg, errCh, nurl, i) + } + + wg.Wait() + + var lastError error + lastError = nil + for i := 0; i < threadCount; i++ { + err := <-errCh + if err != nil { + lastError = err + } + } + + if lastError != nil { + t.Fatalf("%v\n", lastError) + } +}