From 3b64567f00cc3514e66d4e494fc4e02e40fed4ee Mon Sep 17 00:00:00 2001 From: Derek Collison Date: Sun, 22 Nov 2015 14:43:16 -0800 Subject: [PATCH] tls flags, proper timeouts --- gnatsd.go | 35 ++++++++++++ server/configs/tls.conf | 1 + server/const.go | 6 +- server/opts.go | 70 ++++++++++++++++------- server/opts_test.go | 3 +- server/server.go | 123 +++++++++++++++++++++++++++++++++++----- server/usage.go | 11 +++- test/configs/tls.conf | 2 + test/tls_test.go | 41 ++++++++++++++ 9 files changed, 252 insertions(+), 40 deletions(-) diff --git a/gnatsd.go b/gnatsd.go index 5cbee58f..6fb280ed 100644 --- a/gnatsd.go +++ b/gnatsd.go @@ -53,6 +53,12 @@ func main() { flag.IntVar(&opts.ProfPort, "profile", 0, "Profiling HTTP port") flag.StringVar(&opts.RoutesStr, "routes", "", "Routes to actively solicit a connection.") + flag.BoolVar(&opts.TLS, "tls", false, "Enable TLS.") + flag.BoolVar(&opts.TLSVerify, "tlsverify", false, "Enable TLS with client verification.") + flag.StringVar(&opts.TLSCert, "tlscert", "", "Server certificate file.") + flag.StringVar(&opts.TLSKey, "tlskey", "", "Private key for server certificate.") + flag.StringVar(&opts.TLSCaCert, "tlscacert", "", "Client certificate CA for verification.") + // Not public per se, will be replaced with dynamic system, but can be used to lower memory footprint when // lots of connections present. flag.IntVar(&opts.BufSize, "bs", 0, "Read/Write buffer size per client connection.") @@ -98,6 +104,9 @@ func main() { } opts.Routes = newroutes + // Configure TLS based on any present flags + configureTLS(&opts) + // Create the server with appropriate options. s := server.New(&opts) @@ -148,3 +157,29 @@ func configureLogger(s *server.Server, opts *server.Options) { s.SetLogger(log, opts.Debug, opts.Trace) } + +func configureTLS(opts *server.Options) { + // If no trigger flags, ignore the others + if !opts.TLS && !opts.TLSVerify { + return + } + if opts.TLSCert == "" { + server.PrintAndDie("TLS Server certificate must be present and valid.") + } + if opts.TLSKey == "" { + server.PrintAndDie("TLS Server private key must be present and valid.") + } + + tc := server.TLSConfigOpts{} + tc.CertFile = opts.TLSCert + tc.KeyFile = opts.TLSKey + tc.CaFile = opts.TLSCaCert + + if opts.TLSVerify { + tc.Verify = true + } + var err error + if opts.TLSConfig, err = server.GenTLSConfig(&tc); err != nil { + server.PrintAndDie(err.Error()) + } +} diff --git a/server/configs/tls.conf b/server/configs/tls.conf index 54fe8788..e647f4cf 100644 --- a/server/configs/tls.conf +++ b/server/configs/tls.conf @@ -7,6 +7,7 @@ net: apcera.me # net interface tls { cert_file: "./configs/certs/server.pem" key_file: "./configs/certs/key.pem" + timeout: 0.5 } authorization { diff --git a/server/const.go b/server/const.go index b0b6548f..b15e37fd 100644 --- a/server/const.go +++ b/server/const.go @@ -35,11 +35,11 @@ const ( // DEFAULT_MAX_CONNECTIONS is the default maximum connections allowed. DEFAULT_MAX_CONNECTIONS = (64 * 1024) - // SSL_TIMEOUT is the TLS/SSL wait time. - SSL_TIMEOUT = 500 * time.Millisecond + // TLS_TIMEOUT is the TLS wait time. + TLS_TIMEOUT = 500 * time.Millisecond // AUTH_TIMEOUT is the authorization wait time. - AUTH_TIMEOUT = 2 * SSL_TIMEOUT + AUTH_TIMEOUT = 2 * TLS_TIMEOUT // DEFAULT_PING_INTERVAL is how often pings are sent to clients and routes. DEFAULT_PING_INTERVAL = 2 * time.Minute diff --git a/server/opts.go b/server/opts.go index 0f4c0478..7d96f2b7 100644 --- a/server/opts.go +++ b/server/opts.go @@ -41,6 +41,7 @@ type Options struct { ClusterUsername string `json:"-"` ClusterPassword string `json:"-"` ClusterAuthTimeout float64 `json:"auth_timeout"` + ClusterTLSTimeout float64 `json:"-"` ClusterTLSConfig *tls.Config `json:"-"` ProfPort int `json:"-"` PidFile string `json:"-"` @@ -51,6 +52,11 @@ type Options struct { RoutesStr string `json:"-"` BufSize int `json:"-"` TLSTimeout float64 `json:"tls_timeout"` + TLS bool `json:"-"` + TLSVerify bool `json:"-"` + TLSCert string `json:"-"` + TLSKey string `json:"-"` + TLSCaCert string `json:"-"` TLSConfig *tls.Config `json:"-"` } @@ -61,11 +67,13 @@ type authorization struct { } // This struct holds the parsed tls config information. -type tlsConfig struct { - certFile string - keyFile string - caFile string - verify bool +// It spublic so we can use ot for flag parsing +type TLSConfigOpts struct { + CertFile string + KeyFile string + CaFile string + Verify bool + Timeout float64 } // ProcessConfigFile processes a configuration file. @@ -132,9 +140,14 @@ func ProcessConfigFile(configFile string) (*Options, error) { opts.MaxConn = int(v.(int64)) case "tls": tlsm := v.(map[string]interface{}) - if opts.TLSConfig, err = parseTLS(tlsm); err != nil { + tc, err := parseTLS(tlsm) + if err != nil { return nil, err } + if opts.TLSConfig, err = GenTLSConfig(tc); err != nil { + return nil, err + } + opts.TLSTimeout = tc.Timeout } } return opts, nil @@ -168,9 +181,14 @@ func parseCluster(cm map[string]interface{}, opts *Options) error { case "tls": var err error tlsm := mv.(map[string]interface{}) - if opts.ClusterTLSConfig, err = parseTLS(tlsm); err != nil { + tc, err := parseTLS(tlsm) + if err != nil { return err } + if opts.ClusterTLSConfig, err = GenTLSConfig(tc); err != nil { + return err + } + opts.ClusterTLSTimeout = tc.Timeout } } return nil @@ -200,8 +218,9 @@ func parseAuthorization(am map[string]interface{}) authorization { } // Helper function to parse TLS configs. -func parseTLS(tlsm map[string]interface{}) (*tls.Config, error) { - tc := tlsConfig{} +//func parseTLS(tlsm map[string]interface{}) (*tls.Config, error) { +func parseTLS(tlsm map[string]interface{}) (*TLSConfigOpts, error) { + tc := TLSConfigOpts{} for mk, mv := range tlsm { switch strings.ToLower(mk) { case "cert_file": @@ -209,32 +228,44 @@ func parseTLS(tlsm map[string]interface{}) (*tls.Config, error) { if !ok { return nil, fmt.Errorf("error parsing tls config, expected 'cert_file' to be filename") } - tc.certFile = certFile + tc.CertFile = certFile case "key_file": keyFile, ok := mv.(string) if !ok { return nil, fmt.Errorf("error parsing tls config, expected 'key_file' to be filename") } - tc.keyFile = keyFile + tc.KeyFile = keyFile case "ca_file": caFile, ok := mv.(string) if !ok { return nil, fmt.Errorf("error parsing tls config, expected 'ca_file' to be filename") } - tc.caFile = caFile + tc.CaFile = caFile case "verify": verify, ok := mv.(bool) if !ok { return nil, fmt.Errorf("error parsing tls config, expected 'verify' to be a boolean") } - tc.verify = verify - + tc.Verify = verify + case "timeout": + at := float64(0) + switch mv.(type) { + case int64: + at = float64(mv.(int64)) + case float64: + at = mv.(float64) + } + tc.Timeout = at default: return nil, fmt.Errorf("error parsing tls config, unknown field [%q]", mk) } } + return &tc, nil +} + +func GenTLSConfig(tc *TLSConfigOpts) (*tls.Config, error) { // Now load in cert and private key - cert, err := tls.LoadX509KeyPair(tc.certFile, tc.keyFile) + cert, err := tls.LoadX509KeyPair(tc.CertFile, tc.KeyFile) if err != nil { return nil, fmt.Errorf("error parsing X509 certificate/key pair: %v", err) } @@ -244,7 +275,6 @@ func parseTLS(tlsm map[string]interface{}) (*tls.Config, error) { } // Create TLSConfig // We will determine the cipher suites that we prefer. - config := tls.Config{ Certificates: []tls.Certificate{cert}, PreferServerCipherSuites: true, @@ -252,12 +282,12 @@ func parseTLS(tlsm map[string]interface{}) (*tls.Config, error) { CipherSuites: defaultCipherSuites(), } // Require client certificates as needed - if tc.verify == true { + if tc.Verify == true { config.ClientAuth = tls.RequireAnyClientCert } // Add in CAs if applicable. - if tc.caFile != "" { - rootPEM, err := ioutil.ReadFile(tc.caFile) + if tc.CaFile != "" { + rootPEM, err := ioutil.ReadFile(tc.CaFile) if err != nil || rootPEM == nil { return nil, err } @@ -446,7 +476,7 @@ func processOptions(opts *Options) { opts.MaxPingsOut = DEFAULT_PING_MAX_OUT } if opts.TLSTimeout == 0 { - opts.TLSTimeout = float64(SSL_TIMEOUT) / float64(time.Second) + opts.TLSTimeout = float64(TLS_TIMEOUT) / float64(time.Second) } if opts.AuthTimeout == 0 { opts.AuthTimeout = float64(AUTH_TIMEOUT) / float64(time.Second) diff --git a/server/opts_test.go b/server/opts_test.go index c59e02a0..87ee5d26 100644 --- a/server/opts_test.go +++ b/server/opts_test.go @@ -17,7 +17,7 @@ func TestDefaultOptions(t *testing.T) { MaxConn: DEFAULT_MAX_CONNECTIONS, PingInterval: DEFAULT_PING_INTERVAL, MaxPingsOut: DEFAULT_PING_MAX_OUT, - TLSTimeout: float64(SSL_TIMEOUT) / float64(time.Second), + TLSTimeout: float64(TLS_TIMEOUT) / float64(time.Second), AuthTimeout: float64(AUTH_TIMEOUT) / float64(time.Second), MaxControlLine: MAX_CONTROL_LINE_SIZE, MaxPayload: MAX_PAYLOAD_SIZE, @@ -85,6 +85,7 @@ func TestTLSConfigFile(t *testing.T) { Username: "derek", Password: "buckley", AuthTimeout: 1.0, + TLSTimeout: 0.5, } opts, err := ProcessConfigFile("./configs/tls.conf") if err != nil { diff --git a/server/server.go b/server/server.go index d5aa685a..bf0e8f54 100644 --- a/server/server.go +++ b/server/server.go @@ -294,6 +294,11 @@ func (s *Server) AcceptLoop() { return } + // Alert of TLS enabled. + if s.opts.TLSConfig != nil { + Noticef("TLS required for client connections") + } + Noticef("gnatsd is ready") // Setup state that can enable shutdown @@ -333,7 +338,7 @@ func (s *Server) AcceptLoop() { continue } tmpDelay = ACCEPT_MIN_SLEEP - s.createClient(conn) + go s.createClient(conn) } Noticef("Server Exiting..") s.done <- true @@ -424,19 +429,6 @@ func (s *Server) createClient(conn net.Conn) *client { // Send our information. s.sendInfo(c, info) - // Check for TLS - if tlsRequired { - c.Debugf("Starting TLS client connection handshake") - c.nc = tls.Server(c.nc, s.opts.TLSConfig) - conn := c.nc.(*tls.Conn) - err := conn.Handshake() - if err != nil { - c.Debugf("TLS handshake error: %v", err) - } - // Rewrap bw - c.bw = bufio.NewWriterSize(c.nc, s.opts.BufSize) - } - // Unlock to register c.mu.Unlock() @@ -445,9 +437,112 @@ func (s *Server) createClient(conn net.Conn) *client { s.clients[c.cid] = c s.mu.Unlock() + // Check for TLS + if tlsRequired { + // Re-Grab lock + c.mu.Lock() + + c.Debugf("Starting TLS client connection handshake") + c.nc = tls.Server(c.nc, s.opts.TLSConfig) + conn := c.nc.(*tls.Conn) + + // Setup the timeout + ttl := secondsToDuration(s.opts.TLSTimeout) + time.AfterFunc(ttl, func() { tlsTimeout(c, conn) }) + conn.SetReadDeadline(time.Now().Add(ttl)) + + // Force handshake + c.mu.Unlock() + if err := conn.Handshake(); err != nil { + c.Debugf("TLS handshake error: %v", err) + c.sendErr("Secure Connection - TLS Required") + c.closeConnection() + return nil + } + // Reset the read deadline + conn.SetReadDeadline(time.Time{}) + + // Re-Grab lock + c.mu.Lock() + + // Rewrap bw + c.bw = bufio.NewWriterSize(c.nc, s.opts.BufSize) + c.Debugf("TLS handshake complete") + cs := conn.ConnectionState() + c.Debugf("TLS version %s, cipher suite %s", tlsVersion(cs.Version), tlsCipher(cs.CipherSuite)) + c.mu.Unlock() + } + return c } +// Handle closing down a connection when the handshake has timedout. +func tlsTimeout(c *client, conn *tls.Conn) { + c.mu.Lock() + nc := c.nc + c.mu.Unlock() + // Check if already closed + if nc == nil { + return + } + cs := conn.ConnectionState() + if cs.HandshakeComplete == false { + c.Debugf("TLS handshake timeout") + c.sendErr("Secure Connection - TLS Required") + c.closeConnection() + } +} + +// Seems silly we have to write these +func tlsVersion(ver uint16) string { + switch ver { + case tls.VersionTLS10: + return "1.0" + case tls.VersionTLS11: + return "1.1" + case tls.VersionTLS12: + return "1.2" + } + return fmt.Sprintf("Unknown [%x]", ver) +} + +// We use hex here so we don't need multiple versions +func tlsCipher(cs uint16) string { + switch cs { + case 0x0005: + return "TLS_RSA_WITH_RC4_128_SHA" + case 0x000a: + return "TLS_RSA_WITH_3DES_EDE_CBC_SHA" + case 0x002f: + return "TLS_RSA_WITH_AES_128_CBC_SHA" + case 0x0035: + return "TLS_RSA_WITH_AES_256_CBC_SHA" + case 0xc007: + return "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA" + case 0xc009: + return "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA" + case 0xc00a: + return "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA" + case 0xc011: + return "TLS_ECDHE_RSA_WITH_RC4_128_SHA" + case 0xc012: + return "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA" + case 0xc013: + return "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA" + case 0xc014: + return "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA" + case 0xc02f: + return "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" + case 0xc02b: + return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" + case 0xc030: + return "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384" + case 0xc02c: + return "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" + } + return fmt.Sprintf("Unknown [%x]", cs) +} + // Assume the lock is held upon entry. func (s *Server) sendInfo(c *client, info []byte) { c.nc.Write(info) diff --git a/server/usage.go b/server/usage.go index 97f9be56..b4cebef4 100644 --- a/server/usage.go +++ b/server/usage.go @@ -18,8 +18,8 @@ Server Options: Logging Options: -l, --log FILE File to redirect log output -T, --logtime Timestamp log entries (default: true) - -s, --syslog Enable syslog as log method. - -r, --remote_syslog Syslog server addr (udp://localhost:514). + -s, --syslog Enable syslog as log method + -r, --remote_syslog Syslog server addr (udp://localhost:514) -D, --debug Enable debugging output -V, --trace Trace the raw protocol -DV Debug and Trace @@ -28,6 +28,13 @@ Authorization Options: --user user User required for connections --pass password Password required for connections +TLS Options: + --tls Enable TLS, do not verify clients (default: false) + --tlscert FILE Server certificate file + --tlskey FILE Private key for server certificate + --tlsverify Enable TLS, very client certificates + --tlscacert client Client certificate CA for verification + Cluster Options: --routes [rurl-1, rurl-2] Routes to solicit and connect diff --git a/test/configs/tls.conf b/test/configs/tls.conf index ac6a9aa2..66294347 100644 --- a/test/configs/tls.conf +++ b/test/configs/tls.conf @@ -9,6 +9,8 @@ tls { cert_file: "./configs/certs/server-cert.pem" # Server private key key_file: "./configs/certs/server-key.pem" + # Specified time for handshake to complete + timeout: 0.25 } authorization { diff --git a/test/tls_test.go b/test/tls_test.go index 415e193a..1aef3b60 100644 --- a/test/tls_test.go +++ b/test/tls_test.go @@ -3,11 +3,15 @@ package test import ( + "bufio" "crypto/tls" "crypto/x509" "fmt" "io/ioutil" + "net" + "strings" "testing" + "time" "github.com/nats-io/nats" ) @@ -129,3 +133,40 @@ func TestTLSClientCertificate(t *testing.T) { nc.Flush() defer nc.Close() } + +func TestTLSConnectionTimeout(t *testing.T) { + srv, opts := RunServerWithConfig("./configs/tls.conf") + defer srv.Shutdown() + + // Dial with normal TCP + endpoint := fmt.Sprintf("%s:%d", opts.Host, opts.Port) + conn, err := net.Dial("tcp", endpoint) + if err != nil { + t.Fatalf("Could not connect to %q", endpoint) + } + defer conn.Close() + + // Read deadlines + conn.SetReadDeadline(time.Now().Add(time.Second)) + + // Read the INFO string. + br := bufio.NewReader(conn) + info, err := br.ReadString('\n') + if err != nil { + t.Fatalf("Failed to read INFO - %v", err) + } + if !strings.HasPrefix(info, "INFO ") { + t.Fatalf("INFO response incorrect: %s\n", info) + } + wait := time.Duration(opts.TLSTimeout * float64(time.Second)) + time.Sleep(wait) + // Read deadlines + conn.SetReadDeadline(time.Now().Add(time.Second)) + tlsErr, err := br.ReadString('\n') + if err != nil { + t.Fatalf("Error reading error response - %v\n", err) + } + if !strings.Contains(tlsErr, "-ERR 'Secure Connection - TLS Required") { + t.Fatalf("TLS Timeout response incorrect: %q\n", tlsErr) + } +}