diff --git a/go.mod b/go.mod index 62917a58..f8fab6b3 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/klauspost/compress v1.17.0 github.com/minio/highwayhash v1.0.2 github.com/nats-io/jwt/v2 v2.5.2 - github.com/nats-io/nats.go v1.30.2 + github.com/nats-io/nats.go v1.30.3-0.20231009181226-1941a1a4f14f github.com/nats-io/nkeys v0.4.5 github.com/nats-io/nuid v1.0.1 go.uber.org/automaxprocs v1.5.3 diff --git a/go.sum b/go.sum index 72369f26..cc87607c 100644 --- a/go.sum +++ b/go.sum @@ -15,8 +15,8 @@ github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= github.com/nats-io/jwt/v2 v2.5.2 h1:DhGH+nKt+wIkDxM6qnVSKjokq5t59AZV5HRcFW0zJwU= github.com/nats-io/jwt/v2 v2.5.2/go.mod h1:24BeQtRwxRV8ruvC4CojXlx/WQ/VjuwlYiH+vu/+ibI= -github.com/nats-io/nats.go v1.30.2 h1:aloM0TGpPorZKQhbAkdCzYDj+ZmsJDyeo3Gkbr72NuY= -github.com/nats-io/nats.go v1.30.2/go.mod h1:dcfhUgmQNN4GJEfIb2f9R7Fow+gzBF4emzDHrVBd5qM= +github.com/nats-io/nats.go v1.30.3-0.20231009181226-1941a1a4f14f h1:1OBmQ3HJsJAX4vemhoCQjonLBaQ7yx/7PUe6oF1kzvE= +github.com/nats-io/nats.go v1.30.3-0.20231009181226-1941a1a4f14f/go.mod h1:dcfhUgmQNN4GJEfIb2f9R7Fow+gzBF4emzDHrVBd5qM= github.com/nats-io/nkeys v0.4.5 h1:Zdz2BUlFm4fJlierwvGK+yl20IAKUm7eV6AAZXEhkPk= github.com/nats-io/nkeys v0.4.5/go.mod h1:XUkxdLPTufzlihbamfzQ7mw/VGx6ObUs+0bN5sNvt64= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= diff --git a/server/client.go b/server/client.go index e3364c8a..afa849b8 100644 --- a/server/client.go +++ b/server/client.go @@ -141,6 +141,7 @@ const ( expectConnect // Marks if this connection is expected to send a CONNECT connectProcessFinished // Marks if this connection has finished the connect process. compressionNegotiated // Marks if this connection has negotiated compression level with remote. 
+ didTLSFirst // Marks if this connection requested and was accepted doing the TLS handshake first (prior to INFO). ) // set the flag (would be equivalent to set the boolean to true) diff --git a/server/client_test.go b/server/client_test.go index 95051a4f..55fef6ff 100644 --- a/server/client_test.go +++ b/server/client_test.go @@ -2602,3 +2602,345 @@ func TestClientUserInfoReq(t *testing.T) { t.Fatalf("User info for %q did not match", "admin") } } + +func TestTLSClientHandshakeFirst(t *testing.T) { + tmpl := ` + listen: "127.0.0.1:-1" + tls { + cert_file: "../test/configs/certs/server-cert.pem" + key_file: "../test/configs/certs/server-key.pem" + timeout: 1 + first: %s + } + ` + conf := createConfFile(t, []byte(fmt.Sprintf(tmpl, "true"))) + s, o := RunServerWithConfig(conf) + defer s.Shutdown() + + connect := func(tlsfirst, expectedOk bool) { + opts := []nats.Option{nats.RootCAs("../test/configs/certs/ca.pem")} + if tlsfirst { + opts = append(opts, nats.TLSHandshakeFirst()) + } + nc, err := nats.Connect(fmt.Sprintf("tls://localhost:%d", o.Port), opts...) + if expectedOk { + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if tlsfirst { + cz, err := s.Connz(nil) + if err != nil { + t.Fatalf("Error getting connz: %v", err) + } + if !cz.Conns[0].TLSFirst { + t.Fatal("Expected TLSFirst boolean to be set, it was not") + } + } + } else if !expectedOk && err == nil { + nc.Close() + t.Fatal("Expected error, got none") + } + } + + // Server is TLS first, but client is not, so should fail. + connect(false, false) + + // Now client is TLS first too, so should work. + connect(true, true) + + // Config reload the server and disable tls first + reloadUpdateConfig(t, s, conf, fmt.Sprintf(tmpl, "false")) + + // Now if client wants TLS first, connection should fail. + connect(true, false) + + // But if it does not, should be ok. 
+ connect(false, true) + + // Config reload the server again and enable tls first + reloadUpdateConfig(t, s, conf, fmt.Sprintf(tmpl, "true")) + + // If both client and server are TLS first, this should work. + connect(true, true) +} + +func TestTLSClientHandshakeFirstFallbackDelayConfigValues(t *testing.T) { + tmpl := ` + listen: "127.0.0.1:-1" + tls { + cert_file: "../test/configs/certs/server-cert.pem" + key_file: "../test/configs/certs/server-key.pem" + timeout: 1 + first: %s + } + ` + for _, test := range []struct { + name string + val string + first bool + delay time.Duration + }{ + {"first as boolean true", "true", true, 0}, + {"first as boolean false", "false", false, 0}, + {"first as string true", "\"true\"", true, 0}, + {"first as string false", "\"false\"", false, 0}, + {"first as string on", "on", true, 0}, + {"first as string off", "off", false, 0}, + {"first as string auto", "auto", true, DEFAULT_TLS_HANDSHAKE_FIRST_FALLBACK_DELAY}, + {"first as string auto_fallback", "auto_fallback", true, DEFAULT_TLS_HANDSHAKE_FIRST_FALLBACK_DELAY}, + {"first as fallback duration", "300ms", true, 300 * time.Millisecond}, + } { + t.Run(test.name, func(t *testing.T) { + conf := createConfFile(t, []byte(fmt.Sprintf(tmpl, test.val))) + s, o := RunServerWithConfig(conf) + defer s.Shutdown() + + if test.first { + if !o.TLSHandshakeFirst { + t.Fatal("Expected tls first to be true, was not") + } + if test.delay != o.TLSHandshakeFirstFallback { + t.Fatalf("Expected fallback delay to be %v, got %v", test.delay, o.TLSHandshakeFirstFallback) + } + } else { + if o.TLSHandshakeFirst { + t.Fatal("Expected tls first to be false, was not") + } + if o.TLSHandshakeFirstFallback != 0 { + t.Fatalf("Expected fallback delay to be 0, got %v", o.TLSHandshakeFirstFallback) + } + } + }) + } +} + +type pauseAfterDial struct { + delay time.Duration +} + +func (d *pauseAfterDial) Dial(network, address string) (net.Conn, error) { + c, err := net.Dial(network, address) + if err != nil { + return 
nil, err + } + time.Sleep(d.delay) + return c, nil +} + +func TestTLSClientHandshakeFirstFallbackDelay(t *testing.T) { + // Using certificates with RSA 4K to make sure that the fallback does + // not prevent a client with TLS first to successfully connect. + tmpl := ` + listen: "127.0.0.1:-1" + tls { + cert_file: "./configs/certs/tls/benchmark-server-cert-rsa-4096.pem" + key_file: "./configs/certs/tls/benchmark-server-key-rsa-4096.pem" + timeout: 1 + first: %s + } + ` + conf := createConfFile(t, []byte(fmt.Sprintf(tmpl, "auto"))) + s, o := RunServerWithConfig(conf) + defer s.Shutdown() + + url := fmt.Sprintf("tls://localhost:%d", o.Port) + d := &pauseAfterDial{delay: DEFAULT_TLS_HANDSHAKE_FIRST_FALLBACK_DELAY + 100*time.Millisecond} + + // Connect a client without "TLS first" and it should be accepted. + nc, err := nats.Connect(url, + nats.SetCustomDialer(d), + nats.Secure(&tls.Config{ + ServerName: "reuben.nats.io", + MinVersion: tls.VersionTLS12, + }), + nats.RootCAs("./configs/certs/tls/benchmark-ca-cert.pem")) + require_NoError(t, err) + defer nc.Close() + // Check that the TLS first in monitoring is set to false + cs, err := s.Connz(nil) + require_NoError(t, err) + if cs.Conns[0].TLSFirst { + t.Fatal("Expected monitoring ConnInfo.TLSFirst to be false, it was not") + } + nc.Close() + + // Wait for the client to be removed + checkClientsCount(t, s, 0) + + // Increase the fallback delay with config reload. + reloadUpdateConfig(t, s, conf, fmt.Sprintf(tmpl, "\"1s\"")) + + // This time, start the client with "TLS first". + // We will also make sure that we did not wait for the fallback delay + // in order to connect. 
+ start := time.Now() + nc, err = nats.Connect(url, + nats.SetCustomDialer(d), + nats.Secure(&tls.Config{ + ServerName: "reuben.nats.io", + MinVersion: tls.VersionTLS12, + }), + nats.RootCAs("./configs/certs/tls/benchmark-ca-cert.pem"), + nats.TLSHandshakeFirst()) + require_NoError(t, err) + require_True(t, time.Since(start) < 500*time.Millisecond) + defer nc.Close() + + // Check that the TLS first in monitoring is set to true. + cs, err = s.Connz(nil) + require_NoError(t, err) + if !cs.Conns[0].TLSFirst { + t.Fatal("Expected monitoring ConnInfo.TLSFirst to be true, it was not") + } + nc.Close() +} + +func TestTLSClientHandshakeFirstFallbackDelayAndAllowNonTLS(t *testing.T) { + tmpl := ` + listen: "127.0.0.1:-1" + tls { + cert_file: "../test/configs/certs/server-cert.pem" + key_file: "../test/configs/certs/server-key.pem" + timeout: 1 + first: %s + } + allow_non_tls: true + ` + conf := createConfFile(t, []byte(fmt.Sprintf(tmpl, "true"))) + s, o := RunServerWithConfig(conf) + defer s.Shutdown() + + // We first start with a server that has handshake first set to true + // and allow_non_tls. In that case, only "TLS first" clients should be + // accepted. + url := fmt.Sprintf("tls://localhost:%d", o.Port) + nc, err := nats.Connect(url, + nats.RootCAs("../test/configs/certs/ca.pem"), + nats.TLSHandshakeFirst()) + require_NoError(t, err) + defer nc.Close() + // Check that the TLS first in monitoring is set to true + cs, err := s.Connz(nil) + require_NoError(t, err) + if !cs.Conns[0].TLSFirst { + t.Fatal("Expected monitoring ConnInfo.TLSFirst to be true, it was not") + } + nc.Close() + + // Client not using "TLS First" should fail. + nc, err = nats.Connect(url, nats.RootCAs("../test/configs/certs/ca.pem")) + if err == nil { + nc.Close() + t.Fatal("Expected connection to fail, it did not") + } + + // And non TLS clients should also fail to connect. 
+ nc, err = nats.Connect(fmt.Sprintf("nats://127.0.0.1:%d", o.Port)) + if err == nil { + nc.Close() + t.Fatal("Expected connection to fail, it did not") + } + + // Now we will replace TLS first in server with a fallback delay. + reloadUpdateConfig(t, s, conf, fmt.Sprintf(tmpl, "\"25ms\"")) + + // Clients with "TLS first" should still be able to connect + nc, err = nats.Connect(url, + nats.RootCAs("../test/configs/certs/ca.pem"), + nats.TLSHandshakeFirst()) + require_NoError(t, err) + defer nc.Close() + + checkConnInfo := func(isTLS, isTLSFirst bool) { + t.Helper() + cs, err = s.Connz(nil) + require_NoError(t, err) + conn := cs.Conns[0] + if !isTLS { + if conn.TLSVersion != _EMPTY_ { + t.Fatalf("Being a non TLS client, there should not be TLSVersion set, got %v", conn.TLSVersion) + } + if conn.TLSFirst { + t.Fatal("Being a non TLS client, TLSFirst should not be set, but it was") + } + return + } + if isTLSFirst && !conn.TLSFirst { + t.Fatal("Expected monitoring ConnInfo.TLSFirst to be true, it was not") + } else if !isTLSFirst && conn.TLSFirst { + t.Fatal("Expected monitoring ConnInfo.TLSFirst to be false, it was not") + } + nc.Close() + + checkClientsCount(t, s, 0) + } + checkConnInfo(true, true) + + // Clients with TLS but not "TLS first" should also be able to connect. + nc, err = nats.Connect(url, nats.RootCAs("../test/configs/certs/ca.pem")) + require_NoError(t, err) + defer nc.Close() + checkConnInfo(true, false) + + // And non TLS clients should also be able to connect. 
+ nc, err = nats.Connect(fmt.Sprintf("nats://127.0.0.1:%d", o.Port)) + require_NoError(t, err) + defer nc.Close() + checkConnInfo(false, false) +} + +func TestTLSClientHandshakeFirstAndInProcessConnection(t *testing.T) { + conf := createConfFile(t, []byte(` + listen: "127.0.0.1:-1" + tls { + cert_file: "../test/configs/certs/server-cert.pem" + key_file: "../test/configs/certs/server-key.pem" + timeout: 1 + first: true + } + `)) + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + + // Check that we can create an in process connection that does not use TLS + nc, err := nats.Connect(_EMPTY_, nats.InProcessServer(s)) + require_NoError(t, err) + defer nc.Close() + if nc.TLSRequired() { + t.Fatalf("Shouldn't have required TLS for in-process connection") + } + if _, err = nc.TLSConnectionState(); err == nil { + t.Fatal("Should have got an error retrieving TLS connection state") + } + nc.Close() + + // If the client wants TLS, it should get a TLS connection. + nc, err = nats.Connect(_EMPTY_, + nats.InProcessServer(s), + nats.RootCAs("../test/configs/certs/ca.pem")) + require_NoError(t, err) + defer nc.Close() + if _, err = nc.TLSConnectionState(); err != nil { + t.Fatal("Should have not got an error retrieving TLS connection state") + } + // However, the server would not have sent that TLS was required, + // but instead it is available. + if nc.TLSRequired() { + t.Fatalf("Shouldn't have required TLS for in-process connection") + } + nc.Close() + + // The in-process connection with TLS and "TLS first" should also be working. 
+ nc, err = nats.Connect(_EMPTY_, + nats.InProcessServer(s), + nats.RootCAs("../test/configs/certs/ca.pem"), + nats.TLSHandshakeFirst()) + require_NoError(t, err) + defer nc.Close() + if !nc.TLSRequired() { + t.Fatalf("The server should have sent that TLS is required") + } + if _, err = nc.TLSConnectionState(); err != nil { + t.Fatal("Should have not got an error retrieving TLS connection state") + } +} diff --git a/server/config_check_test.go b/server/config_check_test.go index 34c95b17..a07962b7 100644 --- a/server/config_check_test.go +++ b/server/config_check_test.go @@ -1808,6 +1808,30 @@ func TestConfigCheck(t *testing.T) { errorPos: 0, reason: "", }, + { + name: "TLS handshake first, wrong type", + config: ` + port: -1 + tls { + first: 123 + } + `, + err: fmt.Errorf("field %q should be a boolean or a string, got int64", "first"), + errorLine: 4, + errorPos: 6, + }, + { + name: "TLS handshake first, wrong value", + config: ` + port: -1 + tls { + first: "123" + } + `, + err: fmt.Errorf("field %q's value %q is invalid", "first", "123"), + errorLine: 4, + errorPos: 6, + }, } checkConfig := func(config string) error { diff --git a/server/const.go b/server/const.go index 64ec6b62..91b5c76e 100644 --- a/server/const.go +++ b/server/const.go @@ -82,6 +82,12 @@ const ( // TLS_TIMEOUT is the TLS wait time. TLS_TIMEOUT = 2 * time.Second + // DEFAULT_TLS_HANDSHAKE_FIRST_FALLBACK_DELAY is the default amount of + // time for the server to wait for the TLS handshake with a client to + // be initiated before falling back to sending the INFO protocol first. + // See TLSHandshakeFirst and TLSHandshakeFirstFallback options. + DEFAULT_TLS_HANDSHAKE_FIRST_FALLBACK_DELAY = 50 * time.Millisecond + // AUTH_TIMEOUT is the authorization wait time. 
AUTH_TIMEOUT = 2 * time.Second diff --git a/server/monitor.go b/server/monitor.go index 66f5e81a..073c468e 100644 --- a/server/monitor.go +++ b/server/monitor.go @@ -130,6 +130,7 @@ type ConnInfo struct { TLSVersion string `json:"tls_version,omitempty"` TLSCipher string `json:"tls_cipher_suite,omitempty"` TLSPeerCerts []*TLSPeerCert `json:"tls_peer_certs,omitempty"` + TLSFirst bool `json:"tls_first,omitempty"` AuthorizedUser string `json:"authorized_user,omitempty"` Account string `json:"account,omitempty"` Subs []string `json:"subscriptions_list,omitempty"` @@ -568,6 +569,7 @@ func (ci *ConnInfo) fill(client *client, nc net.Conn, now time.Time, auth bool) if auth && len(cs.PeerCertificates) > 0 { ci.TLSPeerCerts = makePeerCerts(cs.PeerCertificates) } + ci.TLSFirst = client.flags.isSet(didTLSFirst) } } diff --git a/server/opts.go b/server/opts.go index 039f982c..38b5799e 100644 --- a/server/opts.go +++ b/server/opts.go @@ -327,11 +327,23 @@ type Options struct { TLSConfig *tls.Config `json:"-"` TLSPinnedCerts PinnedCertSet `json:"-"` TLSRateLimit int64 `json:"-"` - AllowNonTLS bool `json:"-"` - WriteDeadline time.Duration `json:"-"` - MaxClosedClients int `json:"-"` - LameDuckDuration time.Duration `json:"-"` - LameDuckGracePeriod time.Duration `json:"-"` + // When set to true, the server will perform the TLS handshake before + // sending the INFO protocol. For clients that are not configured + // with a similar option, their connection will fail with some sort + // of timeout or EOF error since they are expecting to receive an + // INFO protocol first. + TLSHandshakeFirst bool `json:"-"` + // If TLSHandshakeFirst is true and this value is strictly positive, + // the server will wait for that amount of time for the TLS handshake + // to start before falling back to previous behavior of sending the + // INFO protocol first. It allows for a mix of newer clients that can + // require a TLS handshake first, and older clients that can't. 
+ TLSHandshakeFirstFallback time.Duration `json:"-"` + AllowNonTLS bool `json:"-"` + WriteDeadline time.Duration `json:"-"` + MaxClosedClients int `json:"-"` + LameDuckDuration time.Duration `json:"-"` + LameDuckGracePeriod time.Duration `json:"-"` // MaxTracedMsgLen is the maximum printable length for traced messages. MaxTracedMsgLen int `json:"-"` @@ -638,7 +650,8 @@ type TLSConfigOpts struct { Insecure bool Map bool TLSCheckKnownURLs bool - HandshakeFirst bool // Indicate that the TLS handshake should occur first, before sending the INFO protocol + HandshakeFirst bool // Indicate that the TLS handshake should occur first, before sending the INFO protocol. + FallbackDelay time.Duration // Where supported, indicates how long to wait for the handshake before falling back to sending the INFO protocol first. Timeout float64 RateLimit int64 Ciphers []uint16 @@ -1072,6 +1085,8 @@ func (o *Options) processConfigFileLine(k string, v interface{}, errors *[]error o.TLSMap = tc.Map o.TLSPinnedCerts = tc.PinnedCerts o.TLSRateLimit = tc.RateLimit + o.TLSHandshakeFirst = tc.HandshakeFirst + o.TLSHandshakeFirstFallback = tc.FallbackDelay // Need to keep track of path of the original TLS config // and certs path for OCSP Stapling monitoring. @@ -4312,7 +4327,30 @@ func parseTLS(v interface{}, isClientCtx bool) (t *TLSConfigOpts, retErr error) } tc.CertMatch = certMatch case "handshake_first", "first", "immediate": - tc.HandshakeFirst = mv.(bool) + switch mv := mv.(type) { + case bool: + tc.HandshakeFirst = mv + case string: + switch strings.ToLower(mv) { + case "true", "on": + tc.HandshakeFirst = true + case "false", "off": + tc.HandshakeFirst = false + case "auto", "auto_fallback": + tc.HandshakeFirst = true + tc.FallbackDelay = DEFAULT_TLS_HANDSHAKE_FIRST_FALLBACK_DELAY + default: + // Check to see if this is a duration. 
+ if dur, err := time.ParseDuration(mv); err == nil { + tc.HandshakeFirst = true + tc.FallbackDelay = dur + break + } + return nil, &configErr{tk, fmt.Sprintf("field %q's value %q is invalid", mk, mv)} + } + default: + return nil, &configErr{tk, fmt.Sprintf("field %q should be a boolean or a string, got %T", mk, mv)} + } case "ocsp_peer": switch vv := mv.(type) { case bool: diff --git a/server/reload.go b/server/reload.go index 23988171..4e1c2f71 100644 --- a/server/reload.go +++ b/server/reload.go @@ -266,6 +266,28 @@ func (t *tlsPinnedCertOption) Apply(server *Server) { server.Noticef("Reloaded: %d pinned_certs", len(t.newValue)) } +// tlsHandshakeFirst implements the option interface for the tls `handshake first` setting. +type tlsHandshakeFirst struct { + noopOption + newValue bool +} + +// Apply only logs the new value; no other server state is updated here. +func (t *tlsHandshakeFirst) Apply(server *Server) { + server.Noticef("Reloaded: Client TLS handshake first: %v", t.newValue) +} + +// tlsHandshakeFirstFallback implements the option interface for the tls `handshake first fallback delay` setting. +type tlsHandshakeFirstFallback struct { + noopOption + newValue time.Duration +} + +// Apply only logs the new value; no other server state is updated here. +func (t *tlsHandshakeFirstFallback) Apply(server *Server) { + server.Noticef("Reloaded: Client TLS handshake first fallback delay: %v", t.newValue) +} + // authOption is a base struct that provides default option behaviors. 
type authOption struct { noopOption @@ -1222,6 +1244,10 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) { diffOpts = append(diffOpts, &tlsTimeoutOption{newValue: newValue.(float64)}) case "tlspinnedcerts": diffOpts = append(diffOpts, &tlsPinnedCertOption{newValue: newValue.(PinnedCertSet)}) + case "tlshandshakefirst": + diffOpts = append(diffOpts, &tlsHandshakeFirst{newValue: newValue.(bool)}) + case "tlshandshakefirstfallback": + diffOpts = append(diffOpts, &tlsHandshakeFirstFallback{newValue: newValue.(time.Duration)}) case "username": diffOpts = append(diffOpts, &usernameOption{}) case "password": diff --git a/server/server.go b/server/server.go index d1d0d109..b1138903 100644 --- a/server/server.go +++ b/server/server.go @@ -2573,6 +2573,9 @@ func (s *Server) AcceptLoop(clr chan struct{}) { // Alert of TLS enabled. if opts.TLSConfig != nil { s.Noticef("TLS required for client connections") + if opts.TLSHandshakeFirst && opts.TLSHandshakeFirstFallback == 0 { + s.Warnf("Clients that are not using \"TLS Handshake First\" option will fail to connect") + } } // If server was started with RANDOM_PORT (-1), opts.Port would be equal @@ -3041,10 +3044,37 @@ func (s *Server) createClientEx(conn net.Conn, inProcess bool) *client { c.Debugf("Client connection created") - // Send our information. - // Need to be sent in place since writeLoop cannot be started until - // TLS handshake is done (if applicable). - c.sendProtoNow(c.generateClientInfoJSON(info)) + // Save info.TLSRequired value since we may need to change it back and forth. + orgInfoTLSReq := info.TLSRequired + + var tlsFirstFallback time.Duration + // Check if we should do TLS first. + tlsFirst := opts.TLSConfig != nil && opts.TLSHandshakeFirst + if tlsFirst { + // Make sure info.TLSRequired is set to true (it could be false + // if AllowNonTLS is enabled). + info.TLSRequired = true + // Get the fallback delay value if applicable. 
+ if f := opts.TLSHandshakeFirstFallback; f > 0 { + tlsFirstFallback = f + } else if inProcess { + // For in-process connection, we will always have a fallback + // delay. It allows support for non-TLS, TLS and "TLS First" + // in-process clients to successfully connect. + tlsFirstFallback = DEFAULT_TLS_HANDSHAKE_FIRST_FALLBACK_DELAY + } + } + + // Decide if we are going to require TLS or not and generate INFO json. + tlsRequired := info.TLSRequired + infoBytes := c.generateClientInfoJSON(info) + + // Send our information, except if TLS and TLSHandshakeFirst is requested. + if !tlsFirst { + // Need to be sent in place since writeLoop cannot be started until + // TLS handshake is done (if applicable). + c.sendProtoNow(infoBytes) + } // Unlock to register c.mu.Unlock() @@ -3077,20 +3107,50 @@ func (s *Server) createClientEx(conn net.Conn, inProcess bool) *client { } s.clients[c.cid] = c - tlsRequired := info.TLSRequired s.mu.Unlock() // Re-Grab lock c.mu.Lock() - // Connection could have been closed while sending the INFO proto. isClosed := c.isClosed() - var pre []byte + // We need first to check for "TLS First" fallback delay. + if !isClosed && tlsFirstFallback > 0 { + // We wait and see if we are getting any data. Since we did not send + // the INFO protocol yet, only clients that use TLS first should be + // sending data (the TLS handshake). We don't really check the content: + // if it is a rogue agent and not an actual client performing the + // TLS handshake, the error will be detected when performing the + // handshake on our side. + pre = make([]byte, 4) + c.nc.SetReadDeadline(time.Now().Add(tlsFirstFallback)) + n, _ := io.ReadFull(c.nc, pre[:]) + c.nc.SetReadDeadline(time.Time{}) + // If we get any data (regardless of possible timeout), we will proceed + // with the TLS handshake. + if n > 0 { + pre = pre[:n] + } else { + // We did not get anything so we will send the INFO protocol. 
+ pre = nil + + // Restore the original info.TLSRequired value if it is + // different than the current value and regenerate infoBytes. + if orgInfoTLSReq != info.TLSRequired { + info.TLSRequired = orgInfoTLSReq + infoBytes = c.generateClientInfoJSON(info) + } + c.sendProtoNow(infoBytes) + // Set the boolean to false for the rest of the function. + tlsFirst = false + // Check closed status again + isClosed = c.isClosed() + } + } // If we have both TLS and non-TLS allowed we need to see which // one the client wants. We'll always allow this for in-process // connections. - if !isClosed && opts.TLSConfig != nil && (inProcess || opts.AllowNonTLS) { + if !isClosed && !tlsFirst && opts.TLSConfig != nil && (inProcess || opts.AllowNonTLS) { pre = make([]byte, 4) c.nc.SetReadDeadline(time.Now().Add(secondsToDuration(opts.TLSTimeout))) n, _ := io.ReadFull(c.nc, pre[:]) @@ -3125,12 +3185,18 @@ func (s *Server) createClientEx(conn net.Conn, inProcess bool) *client { } } - // If connection is marked as closed, bail out. + // Now, send the INFO if it was delayed + if !isClosed && tlsFirst { + c.flags.set(didTLSFirst) + c.sendProtoNow(infoBytes) + // Check closed status + isClosed = c.isClosed() + } + + // Connection could have been closed while sending the INFO proto. if isClosed { c.mu.Unlock() - // Connection could have been closed due to TLS timeout or while trying - // to send the INFO protocol. We need to call closeConnection() to make - // sure that proper cleanup is done. + // We need to call closeConnection() to make sure that proper cleanup is done. c.closeConnection(WriteError) return nil } diff --git a/test/tls_test.go b/test/tls_test.go index ad0c91af..0fc391cb 100644 --- a/test/tls_test.go +++ b/test/tls_test.go @@ -82,6 +82,7 @@ func TestTLSInProcessConnection(t *testing.T) { if err != nil { t.Fatal(err) } + defer nc.Close() if nc.TLSRequired() { t.Fatalf("Shouldn't have required TLS for in-process connection")