mirror of
https://github.com/gogrlx/nats-server.git
synced 2026-04-02 11:48:43 -07:00
Fix TLS issue where server started to receive TLS data on non TLS connection.
Without the server fix, tls_test.go would likely report an error. The server would show a parser error with protocol snippet containing "random" bytes, likely encrypted data.
This commit is contained in:
@@ -92,7 +92,7 @@ func init() {
|
||||
}
|
||||
|
||||
// Lock should be held
|
||||
func (c *client) initClient() {
|
||||
func (c *client) initClient(tlsConn bool) {
|
||||
s := c.srv
|
||||
c.cid = atomic.AddUint64(&s.gcid, 1)
|
||||
c.bw = bufio.NewWriterSize(c.nc, s.opts.BufSize)
|
||||
@@ -130,11 +130,13 @@ func (c *client) initClient() {
|
||||
// ip.SetWriteBuffer(2 * s.opts.BufSize)
|
||||
// }
|
||||
|
||||
// Set the Ping timer
|
||||
c.setPingTimer()
|
||||
if !tlsConn {
|
||||
// Set the Ping timer
|
||||
c.setPingTimer()
|
||||
|
||||
// Spin up the read loop.
|
||||
go c.readLoop()
|
||||
// Spin up the read loop.
|
||||
go c.readLoop()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) readLoop() {
|
||||
|
||||
@@ -138,7 +138,7 @@ func (s *Server) createRoute(conn net.Conn, rURL *url.URL) *client {
|
||||
c.mu.Lock()
|
||||
|
||||
// Initialize
|
||||
c.initClient()
|
||||
c.initClient(tlsRequired)
|
||||
|
||||
c.Debugf("Route connection created")
|
||||
|
||||
@@ -182,6 +182,14 @@ func (s *Server) createRoute(conn net.Conn, rURL *url.URL) *client {
|
||||
// Rewrap bw
|
||||
c.bw = bufio.NewWriterSize(c.nc, s.opts.BufSize)
|
||||
|
||||
// Do final client initialization
|
||||
|
||||
// Set the Ping timer
|
||||
c.setPingTimer()
|
||||
|
||||
// Spin up the read loop.
|
||||
go c.readLoop()
|
||||
|
||||
c.Debugf("TLS handshake complete")
|
||||
cs := conn.ConnectionState()
|
||||
c.Debugf("TLS version %s, cipher suite %s", tlsVersion(cs.Version), tlsCipher(cs.CipherSuite))
|
||||
|
||||
@@ -439,7 +439,7 @@ func (s *Server) createClient(conn net.Conn) *client {
|
||||
c.mu.Lock()
|
||||
|
||||
// Initialize
|
||||
c.initClient()
|
||||
c.initClient(tlsRequired)
|
||||
|
||||
c.Debugf("Client connection created")
|
||||
|
||||
@@ -491,6 +491,14 @@ func (s *Server) createClient(conn net.Conn) *client {
|
||||
// Rewrap bw
|
||||
c.bw = bufio.NewWriterSize(c.nc, s.opts.BufSize)
|
||||
|
||||
// Do final client initialization
|
||||
|
||||
// Set the Ping timer
|
||||
c.setPingTimer()
|
||||
|
||||
// Spin up the read loop.
|
||||
go c.readLoop()
|
||||
|
||||
c.Debugf("TLS handshake complete")
|
||||
cs := conn.ConnectionState()
|
||||
c.Debugf("TLS version %s, cipher suite %s", tlsVersion(cs.Version), tlsCipher(cs.CipherSuite))
|
||||
|
||||
@@ -10,9 +10,11 @@ import (
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/gnatsd/server"
|
||||
"github.com/nats-io/nats"
|
||||
)
|
||||
|
||||
@@ -161,3 +163,81 @@ func TestTLSConnectionTimeout(t *testing.T) {
|
||||
t.Fatalf("TLS Timeout response incorrect: %q\n", tlsErr)
|
||||
}
|
||||
}
|
||||
|
||||
func stressConnect(t *testing.T, wg *sync.WaitGroup, errCh chan error, url string, index int) {
|
||||
defer wg.Done()
|
||||
|
||||
subName := fmt.Sprintf("foo.%d", index)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
nc, err := nats.SecureConnect(url)
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("Unable to create TLS connection: %v\n", err)
|
||||
return
|
||||
}
|
||||
defer nc.Close()
|
||||
|
||||
sub, err := nc.SubscribeSync(subName)
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("Unable to subscribe on '%s': %v\n", subName, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := nc.Publish(subName, []byte("secure data")); err != nil {
|
||||
errCh <- fmt.Errorf("Unable to send on '%s': %v\n", subName, err)
|
||||
}
|
||||
|
||||
if _, err := sub.NextMsg(2 * time.Second); err != nil {
|
||||
errCh <- fmt.Errorf("Unable to get next message: %v\n", err)
|
||||
}
|
||||
|
||||
nc.Close()
|
||||
}
|
||||
|
||||
errCh <- nil
|
||||
}
|
||||
|
||||
func TestTLSStressConnect(t *testing.T) {
|
||||
opts, err := server.ProcessConfigFile("./configs/tls.conf")
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Error processing configuration file: %v", err))
|
||||
}
|
||||
opts.NoSigs, opts.NoLog = true, true
|
||||
|
||||
// For this test, remove the authorization
|
||||
opts.Authorization = ""
|
||||
|
||||
// Increase ssl timeout
|
||||
opts.TLSTimeout = 2.0
|
||||
|
||||
srv := RunServer(opts)
|
||||
defer srv.Shutdown()
|
||||
|
||||
nurl := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port)
|
||||
|
||||
threadCount := 3
|
||||
|
||||
errCh := make(chan error, threadCount)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(threadCount)
|
||||
|
||||
for i := 0; i < threadCount; i++ {
|
||||
go stressConnect(t, &wg, errCh, nurl, i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
var lastError error
|
||||
lastError = nil
|
||||
for i := 0; i < threadCount; i++ {
|
||||
err := <-errCh
|
||||
if err != nil {
|
||||
lastError = err
|
||||
}
|
||||
}
|
||||
|
||||
if lastError != nil {
|
||||
t.Fatalf("%v\n", lastError)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user