tls flags, proper timeouts

This commit is contained in:
Derek Collison
2015-11-22 14:43:16 -08:00
parent d703fd551a
commit 3b64567f00
9 changed files with 252 additions and 40 deletions

View File

@@ -53,6 +53,12 @@ func main() {
flag.IntVar(&opts.ProfPort, "profile", 0, "Profiling HTTP port")
flag.StringVar(&opts.RoutesStr, "routes", "", "Routes to actively solicit a connection.")
flag.BoolVar(&opts.TLS, "tls", false, "Enable TLS.")
flag.BoolVar(&opts.TLSVerify, "tlsverify", false, "Enable TLS with client verification.")
flag.StringVar(&opts.TLSCert, "tlscert", "", "Server certificate file.")
flag.StringVar(&opts.TLSKey, "tlskey", "", "Private key for server certificate.")
flag.StringVar(&opts.TLSCaCert, "tlscacert", "", "Client certificate CA for verification.")
// Not public per se, will be replaced with dynamic system, but can be used to lower memory footprint when
// lots of connections present.
flag.IntVar(&opts.BufSize, "bs", 0, "Read/Write buffer size per client connection.")
@@ -98,6 +104,9 @@ func main() {
}
opts.Routes = newroutes
// Configure TLS based on any present flags
configureTLS(&opts)
// Create the server with appropriate options.
s := server.New(&opts)
@@ -148,3 +157,29 @@ func configureLogger(s *server.Server, opts *server.Options) {
s.SetLogger(log, opts.Debug, opts.Trace)
}
func configureTLS(opts *server.Options) {
// If no trigger flags, ignore the others
if !opts.TLS && !opts.TLSVerify {
return
}
if opts.TLSCert == "" {
server.PrintAndDie("TLS Server certificate must be present and valid.")
}
if opts.TLSKey == "" {
server.PrintAndDie("TLS Server private key must be present and valid.")
}
tc := server.TLSConfigOpts{}
tc.CertFile = opts.TLSCert
tc.KeyFile = opts.TLSKey
tc.CaFile = opts.TLSCaCert
if opts.TLSVerify {
tc.Verify = true
}
var err error
if opts.TLSConfig, err = server.GenTLSConfig(&tc); err != nil {
server.PrintAndDie(err.Error())
}
}

View File

@@ -7,6 +7,7 @@ net: apcera.me # net interface
tls {
cert_file: "./configs/certs/server.pem"
key_file: "./configs/certs/key.pem"
timeout: 0.5
}
authorization {

View File

@@ -35,11 +35,11 @@ const (
// DEFAULT_MAX_CONNECTIONS is the default maximum connections allowed.
DEFAULT_MAX_CONNECTIONS = (64 * 1024)
// SSL_TIMEOUT is the TLS/SSL wait time.
SSL_TIMEOUT = 500 * time.Millisecond
// TLS_TIMEOUT is the TLS wait time.
TLS_TIMEOUT = 500 * time.Millisecond
// AUTH_TIMEOUT is the authorization wait time.
AUTH_TIMEOUT = 2 * SSL_TIMEOUT
AUTH_TIMEOUT = 2 * TLS_TIMEOUT
// DEFAULT_PING_INTERVAL is how often pings are sent to clients and routes.
DEFAULT_PING_INTERVAL = 2 * time.Minute

View File

@@ -41,6 +41,7 @@ type Options struct {
ClusterUsername string `json:"-"`
ClusterPassword string `json:"-"`
ClusterAuthTimeout float64 `json:"auth_timeout"`
ClusterTLSTimeout float64 `json:"-"`
ClusterTLSConfig *tls.Config `json:"-"`
ProfPort int `json:"-"`
PidFile string `json:"-"`
@@ -51,6 +52,11 @@ type Options struct {
RoutesStr string `json:"-"`
BufSize int `json:"-"`
TLSTimeout float64 `json:"tls_timeout"`
TLS bool `json:"-"`
TLSVerify bool `json:"-"`
TLSCert string `json:"-"`
TLSKey string `json:"-"`
TLSCaCert string `json:"-"`
TLSConfig *tls.Config `json:"-"`
}
@@ -61,11 +67,13 @@ type authorization struct {
}
// This struct holds the parsed tls config information.
type tlsConfig struct {
certFile string
keyFile string
caFile string
verify bool
// It spublic so we can use ot for flag parsing
type TLSConfigOpts struct {
CertFile string
KeyFile string
CaFile string
Verify bool
Timeout float64
}
// ProcessConfigFile processes a configuration file.
@@ -132,9 +140,14 @@ func ProcessConfigFile(configFile string) (*Options, error) {
opts.MaxConn = int(v.(int64))
case "tls":
tlsm := v.(map[string]interface{})
if opts.TLSConfig, err = parseTLS(tlsm); err != nil {
tc, err := parseTLS(tlsm)
if err != nil {
return nil, err
}
if opts.TLSConfig, err = GenTLSConfig(tc); err != nil {
return nil, err
}
opts.TLSTimeout = tc.Timeout
}
}
return opts, nil
@@ -168,9 +181,14 @@ func parseCluster(cm map[string]interface{}, opts *Options) error {
case "tls":
var err error
tlsm := mv.(map[string]interface{})
if opts.ClusterTLSConfig, err = parseTLS(tlsm); err != nil {
tc, err := parseTLS(tlsm)
if err != nil {
return err
}
if opts.ClusterTLSConfig, err = GenTLSConfig(tc); err != nil {
return err
}
opts.ClusterTLSTimeout = tc.Timeout
}
}
return nil
@@ -200,8 +218,9 @@ func parseAuthorization(am map[string]interface{}) authorization {
}
// Helper function to parse TLS configs.
func parseTLS(tlsm map[string]interface{}) (*tls.Config, error) {
tc := tlsConfig{}
//func parseTLS(tlsm map[string]interface{}) (*tls.Config, error) {
func parseTLS(tlsm map[string]interface{}) (*TLSConfigOpts, error) {
tc := TLSConfigOpts{}
for mk, mv := range tlsm {
switch strings.ToLower(mk) {
case "cert_file":
@@ -209,32 +228,44 @@ func parseTLS(tlsm map[string]interface{}) (*tls.Config, error) {
if !ok {
return nil, fmt.Errorf("error parsing tls config, expected 'cert_file' to be filename")
}
tc.certFile = certFile
tc.CertFile = certFile
case "key_file":
keyFile, ok := mv.(string)
if !ok {
return nil, fmt.Errorf("error parsing tls config, expected 'key_file' to be filename")
}
tc.keyFile = keyFile
tc.KeyFile = keyFile
case "ca_file":
caFile, ok := mv.(string)
if !ok {
return nil, fmt.Errorf("error parsing tls config, expected 'ca_file' to be filename")
}
tc.caFile = caFile
tc.CaFile = caFile
case "verify":
verify, ok := mv.(bool)
if !ok {
return nil, fmt.Errorf("error parsing tls config, expected 'verify' to be a boolean")
}
tc.verify = verify
tc.Verify = verify
case "timeout":
at := float64(0)
switch mv.(type) {
case int64:
at = float64(mv.(int64))
case float64:
at = mv.(float64)
}
tc.Timeout = at
default:
return nil, fmt.Errorf("error parsing tls config, unknown field [%q]", mk)
}
}
return &tc, nil
}
func GenTLSConfig(tc *TLSConfigOpts) (*tls.Config, error) {
// Now load in cert and private key
cert, err := tls.LoadX509KeyPair(tc.certFile, tc.keyFile)
cert, err := tls.LoadX509KeyPair(tc.CertFile, tc.KeyFile)
if err != nil {
return nil, fmt.Errorf("error parsing X509 certificate/key pair: %v", err)
}
@@ -244,7 +275,6 @@ func parseTLS(tlsm map[string]interface{}) (*tls.Config, error) {
}
// Create TLSConfig
// We will determine the cipher suites that we prefer.
config := tls.Config{
Certificates: []tls.Certificate{cert},
PreferServerCipherSuites: true,
@@ -252,12 +282,12 @@ func parseTLS(tlsm map[string]interface{}) (*tls.Config, error) {
CipherSuites: defaultCipherSuites(),
}
// Require client certificates as needed
if tc.verify == true {
if tc.Verify == true {
config.ClientAuth = tls.RequireAnyClientCert
}
// Add in CAs if applicable.
if tc.caFile != "" {
rootPEM, err := ioutil.ReadFile(tc.caFile)
if tc.CaFile != "" {
rootPEM, err := ioutil.ReadFile(tc.CaFile)
if err != nil || rootPEM == nil {
return nil, err
}
@@ -446,7 +476,7 @@ func processOptions(opts *Options) {
opts.MaxPingsOut = DEFAULT_PING_MAX_OUT
}
if opts.TLSTimeout == 0 {
opts.TLSTimeout = float64(SSL_TIMEOUT) / float64(time.Second)
opts.TLSTimeout = float64(TLS_TIMEOUT) / float64(time.Second)
}
if opts.AuthTimeout == 0 {
opts.AuthTimeout = float64(AUTH_TIMEOUT) / float64(time.Second)

View File

@@ -17,7 +17,7 @@ func TestDefaultOptions(t *testing.T) {
MaxConn: DEFAULT_MAX_CONNECTIONS,
PingInterval: DEFAULT_PING_INTERVAL,
MaxPingsOut: DEFAULT_PING_MAX_OUT,
TLSTimeout: float64(SSL_TIMEOUT) / float64(time.Second),
TLSTimeout: float64(TLS_TIMEOUT) / float64(time.Second),
AuthTimeout: float64(AUTH_TIMEOUT) / float64(time.Second),
MaxControlLine: MAX_CONTROL_LINE_SIZE,
MaxPayload: MAX_PAYLOAD_SIZE,
@@ -85,6 +85,7 @@ func TestTLSConfigFile(t *testing.T) {
Username: "derek",
Password: "buckley",
AuthTimeout: 1.0,
TLSTimeout: 0.5,
}
opts, err := ProcessConfigFile("./configs/tls.conf")
if err != nil {

View File

@@ -294,6 +294,11 @@ func (s *Server) AcceptLoop() {
return
}
// Alert of TLS enabled.
if s.opts.TLSConfig != nil {
Noticef("TLS required for client connections")
}
Noticef("gnatsd is ready")
// Setup state that can enable shutdown
@@ -333,7 +338,7 @@ func (s *Server) AcceptLoop() {
continue
}
tmpDelay = ACCEPT_MIN_SLEEP
s.createClient(conn)
go s.createClient(conn)
}
Noticef("Server Exiting..")
s.done <- true
@@ -424,19 +429,6 @@ func (s *Server) createClient(conn net.Conn) *client {
// Send our information.
s.sendInfo(c, info)
// Check for TLS
if tlsRequired {
c.Debugf("Starting TLS client connection handshake")
c.nc = tls.Server(c.nc, s.opts.TLSConfig)
conn := c.nc.(*tls.Conn)
err := conn.Handshake()
if err != nil {
c.Debugf("TLS handshake error: %v", err)
}
// Rewrap bw
c.bw = bufio.NewWriterSize(c.nc, s.opts.BufSize)
}
// Unlock to register
c.mu.Unlock()
@@ -445,9 +437,112 @@ func (s *Server) createClient(conn net.Conn) *client {
s.clients[c.cid] = c
s.mu.Unlock()
// Check for TLS
if tlsRequired {
// Re-Grab lock
c.mu.Lock()
c.Debugf("Starting TLS client connection handshake")
c.nc = tls.Server(c.nc, s.opts.TLSConfig)
conn := c.nc.(*tls.Conn)
// Setup the timeout
ttl := secondsToDuration(s.opts.TLSTimeout)
time.AfterFunc(ttl, func() { tlsTimeout(c, conn) })
conn.SetReadDeadline(time.Now().Add(ttl))
// Force handshake
c.mu.Unlock()
if err := conn.Handshake(); err != nil {
c.Debugf("TLS handshake error: %v", err)
c.sendErr("Secure Connection - TLS Required")
c.closeConnection()
return nil
}
// Reset the read deadline
conn.SetReadDeadline(time.Time{})
// Re-Grab lock
c.mu.Lock()
// Rewrap bw
c.bw = bufio.NewWriterSize(c.nc, s.opts.BufSize)
c.Debugf("TLS handshake complete")
cs := conn.ConnectionState()
c.Debugf("TLS version %s, cipher suite %s", tlsVersion(cs.Version), tlsCipher(cs.CipherSuite))
c.mu.Unlock()
}
return c
}
// Handle closing down a connection when the handshake has timedout.
func tlsTimeout(c *client, conn *tls.Conn) {
c.mu.Lock()
nc := c.nc
c.mu.Unlock()
// Check if already closed
if nc == nil {
return
}
cs := conn.ConnectionState()
if cs.HandshakeComplete == false {
c.Debugf("TLS handshake timeout")
c.sendErr("Secure Connection - TLS Required")
c.closeConnection()
}
}
// Seems silly we have to write these
func tlsVersion(ver uint16) string {
switch ver {
case tls.VersionTLS10:
return "1.0"
case tls.VersionTLS11:
return "1.1"
case tls.VersionTLS12:
return "1.2"
}
return fmt.Sprintf("Unknown [%x]", ver)
}
// We use hex here so we don't need multiple versions
func tlsCipher(cs uint16) string {
switch cs {
case 0x0005:
return "TLS_RSA_WITH_RC4_128_SHA"
case 0x000a:
return "TLS_RSA_WITH_3DES_EDE_CBC_SHA"
case 0x002f:
return "TLS_RSA_WITH_AES_128_CBC_SHA"
case 0x0035:
return "TLS_RSA_WITH_AES_256_CBC_SHA"
case 0xc007:
return "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA"
case 0xc009:
return "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA"
case 0xc00a:
return "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA"
case 0xc011:
return "TLS_ECDHE_RSA_WITH_RC4_128_SHA"
case 0xc012:
return "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA"
case 0xc013:
return "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA"
case 0xc014:
return "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA"
case 0xc02f:
return "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"
case 0xc02b:
return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"
case 0xc030:
return "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"
case 0xc02c:
return "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"
}
return fmt.Sprintf("Unknown [%x]", cs)
}
// Assume the lock is held upon entry.
func (s *Server) sendInfo(c *client, info []byte) {
c.nc.Write(info)

View File

@@ -18,8 +18,8 @@ Server Options:
Logging Options:
-l, --log FILE File to redirect log output
-T, --logtime Timestamp log entries (default: true)
-s, --syslog Enable syslog as log method.
-r, --remote_syslog Syslog server addr (udp://localhost:514).
-s, --syslog Enable syslog as log method
-r, --remote_syslog Syslog server addr (udp://localhost:514)
-D, --debug Enable debugging output
-V, --trace Trace the raw protocol
-DV Debug and Trace
@@ -28,6 +28,13 @@ Authorization Options:
--user user User required for connections
--pass password Password required for connections
TLS Options:
--tls Enable TLS, do not verify clients (default: false)
--tlscert FILE Server certificate file
--tlskey FILE Private key for server certificate
--tlsverify Enable TLS, very client certificates
--tlscacert client Client certificate CA for verification
Cluster Options:
--routes [rurl-1, rurl-2] Routes to solicit and connect

View File

@@ -9,6 +9,8 @@ tls {
cert_file: "./configs/certs/server-cert.pem"
# Server private key
key_file: "./configs/certs/server-key.pem"
# Specified time for handshake to complete
timeout: 0.25
}
authorization {

View File

@@ -3,11 +3,15 @@
package test
import (
"bufio"
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net"
"strings"
"testing"
"time"
"github.com/nats-io/nats"
)
@@ -129,3 +133,40 @@ func TestTLSClientCertificate(t *testing.T) {
nc.Flush()
defer nc.Close()
}
func TestTLSConnectionTimeout(t *testing.T) {
srv, opts := RunServerWithConfig("./configs/tls.conf")
defer srv.Shutdown()
// Dial with normal TCP
endpoint := fmt.Sprintf("%s:%d", opts.Host, opts.Port)
conn, err := net.Dial("tcp", endpoint)
if err != nil {
t.Fatalf("Could not connect to %q", endpoint)
}
defer conn.Close()
// Read deadlines
conn.SetReadDeadline(time.Now().Add(time.Second))
// Read the INFO string.
br := bufio.NewReader(conn)
info, err := br.ReadString('\n')
if err != nil {
t.Fatalf("Failed to read INFO - %v", err)
}
if !strings.HasPrefix(info, "INFO ") {
t.Fatalf("INFO response incorrect: %s\n", info)
}
wait := time.Duration(opts.TLSTimeout * float64(time.Second))
time.Sleep(wait)
// Read deadlines
conn.SetReadDeadline(time.Now().Add(time.Second))
tlsErr, err := br.ReadString('\n')
if err != nil {
t.Fatalf("Error reading error response - %v\n", err)
}
if !strings.Contains(tlsErr, "-ERR 'Secure Connection - TLS Required") {
t.Fatalf("TLS Timeout response incorrect: %q\n", tlsErr)
}
}