diff --git a/server/client.go b/server/client.go index 10a20d45..d183c72b 100644 --- a/server/client.go +++ b/server/client.go @@ -244,7 +244,7 @@ func (c *client) processRouteInfo(info *Info) { } } -// Process the information messages from Clients and other routes. +// Process the information messages from Clients and other Routes. func (c *client) processInfo(arg []byte) error { info := Info{} if err := json.Unmarshal(arg, &info); err != nil { @@ -293,6 +293,7 @@ func (c *client) processConnect(arg []byte) error { func (c *client) authTimeout() { c.sendErr("Authorization Timeout") + c.Debugf("Authorization Timeout") c.closeConnection() } diff --git a/server/client_test.go b/server/client_test.go index 4a787f19..19bcb845 100644 --- a/server/client_test.go +++ b/server/client_test.go @@ -18,7 +18,7 @@ type serverInfo struct { Port uint `json:"port"` Version string `json:"version"` AuthRequired bool `json:"auth_required"` - SslRequired bool `json:"ssl_required"` + TLSRequired bool `json:"ssl_required"` MaxPayload int64 `json:"max_payload"` } @@ -93,7 +93,7 @@ func TestClientCreateAndInfo(t *testing.T) { } // Sanity checks if info.MaxPayload != MAX_PAYLOAD_SIZE || - info.AuthRequired || info.SslRequired || + info.AuthRequired || info.TLSRequired || info.Port != DEFAULT_PORT { t.Fatalf("INFO inconsistent: %+v\n", info) } diff --git a/server/opts.go b/server/opts.go index d47b066d..680ced48 100644 --- a/server/opts.go +++ b/server/opts.go @@ -1,8 +1,10 @@ -// Copyright 2012-2013 Apcera Inc. All rights reserved. +// Copyright 2012-2015 Apcera Inc. All rights reserved. package server import ( + "crypto/tls" + "crypto/x509" "fmt" "io/ioutil" "net" @@ -30,7 +32,6 @@ type Options struct { PingInterval time.Duration `json:"ping_interval"` MaxPingsOut int `json:"ping_max"` HTTPPort int `json:"http_port"` - SslTimeout float64 `json:"ssl_timeout"` AuthTimeout float64 `json:"auth_timeout"` MaxControlLine int `json:"max_control_line"` MaxPayload int `json:"max_payload"` @@ -48,6 +49,8 @@ type Options struct { Routes []*url.URL `json:"-"` RoutesStr string `json:"-"` BufSize int `json:"-"` + TLSTimeout float64 `json:"tls_timeout"` + TLSConfig *tls.Config `json:"-"` } type authorization struct { @@ -56,6 +59,12 @@ type authorization struct { timeout float64 } +// This struct holds the parsed tls config information. +type tlsConfig struct { + certFile string + keyFile string +} + // ProcessConfigFile processes a configuration file. // FIXME(dlc): Hacky func ProcessConfigFile(configFile string) (*Options, error) { @@ -118,6 +127,11 @@ func ProcessConfigFile(configFile string) (*Options, error) { opts.MaxPending = int(v.(int64)) case "max_connections", "max_conn": opts.MaxConn = int(v.(int64)) + case "tls": + tlsm := v.(map[string]interface{}) + if err := parseTLS(tlsm, opts); err != nil { + return nil, err + } } } return opts, nil @@ -176,6 +190,53 @@ func parseAuthorization(am map[string]interface{}) authorization { return auth } +// Helper function to parse TLS configs. +func parseTLS(tlsm map[string]interface{}, opts *Options) error { + tc := tlsConfig{} + for mk, mv := range tlsm { + switch strings.ToLower(mk) { + case "cert_file": + certFile, ok := mv.(string) + if !ok { + return fmt.Errorf("error parsing tls config, expected 'cert_file' to be filename") + } + tc.certFile = certFile + case "key_file": + keyFile, ok := mv.(string) + if !ok { + return fmt.Errorf("error parsing tls config, expected 'key_file' to be filename") + } + tc.keyFile = keyFile + default: + return fmt.Errorf("error parsing tls config, unknown field [%q]", mk) + } + } + // Now load in cert and private key + cert, err := tls.LoadX509KeyPair(tc.certFile, tc.keyFile) + if err != nil { + return fmt.Errorf("error parsing X509 certificate/key pair: %v", err) + } + cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return fmt.Errorf("error parsing certificate: %v", err) + } + // Create TLSConfig + // We will determine the cipher suites that we prefer. + config := tls.Config{ + Certificates: []tls.Certificate{cert}, + PreferServerCipherSuites: true, + MinVersion: tls.VersionTLS12, + CipherSuites: []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + }, + } + opts.TLSConfig = &config + return nil +} + // MergeOptions will merge two options giving preference to the flagOpts // if the item is present. func MergeOptions(fileOpts, flagOpts *Options) *Options { @@ -349,8 +410,8 @@ func processOptions(opts *Options) { if opts.MaxPingsOut == 0 { opts.MaxPingsOut = DEFAULT_PING_MAX_OUT } - if opts.SslTimeout == 0 { - opts.SslTimeout = float64(SSL_TIMEOUT) / float64(time.Second) + if opts.TLSTimeout == 0 { + opts.TLSTimeout = float64(SSL_TIMEOUT) / float64(time.Second) } if opts.AuthTimeout == 0 { opts.AuthTimeout = float64(AUTH_TIMEOUT) / float64(time.Second) diff --git a/server/opts_test.go b/server/opts_test.go index 821dc51a..00eef8dc 100644 --- a/server/opts_test.go +++ b/server/opts_test.go @@ -1,8 +1,9 @@ -// Copyright 2013-2014 Apcera Inc. All rights reserved. +// Copyright 2013-2015 Apcera Inc. All rights reserved. package server import ( + "crypto/tls" "net/url" "reflect" "testing" @@ -16,7 +17,7 @@ func TestDefaultOptions(t *testing.T) { MaxConn: DEFAULT_MAX_CONNECTIONS, PingInterval: DEFAULT_PING_INTERVAL, MaxPingsOut: DEFAULT_PING_MAX_OUT, - SslTimeout: float64(SSL_TIMEOUT) / float64(time.Second), + TLSTimeout: float64(SSL_TIMEOUT) / float64(time.Second), AuthTimeout: float64(AUTH_TIMEOUT) / float64(time.Second), MaxControlLine: MAX_CONTROL_LINE_SIZE, MaxPayload: MAX_PAYLOAD_SIZE, @@ -77,6 +78,57 @@ func TestConfigFile(t *testing.T) { } } +func TestTLSConfigFile(t *testing.T) { + golden := &Options{ + Host: "apcera.me", + Port: 4443, + Username: "derek", + Password: "buckley", + AuthTimeout: 1.0, + } + opts, err := ProcessConfigFile("./configs/tls/test.conf") + if err != nil { + t.Fatalf("Received an error reading config file: %v\n", err) + } + tlsConfig := opts.TLSConfig + if tlsConfig == nil { + t.Fatal("Expected opts.TLSConfig to be non-nil") + } + opts.TLSConfig = nil + if !reflect.DeepEqual(golden, opts) { + t.Fatalf("Options are incorrect.\nexpected: %+v\ngot: %+v", + golden, opts) + } + // Now check TLSConfig a bit more closely + // CipherSuites + ciphers := []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + } + if !reflect.DeepEqual(tlsConfig.CipherSuites, ciphers) { + t.Fatalf("Got incorrect cipher suite list: [%+v]", tlsConfig.CipherSuites) + } + if tlsConfig.MinVersion != tls.VersionTLS12 { + t.Fatalf("Expected MinVersion of 1.2 [%v], got [%v]", tls.VersionTLS12, tlsConfig.MinVersion) + } + if tlsConfig.PreferServerCipherSuites != true { + t.Fatal("Expected PreferServerCipherSuites to be true") + } + // Verify hostname is correct in certificate + if len(tlsConfig.Certificates) != 1 { + t.Fatal("Expected 1 certificate") + } + if len(tlsConfig.Certificates) < 1 { + t.Fatalf("Expected certificates") + } + cert := tlsConfig.Certificates[0].Leaf + if err := cert.VerifyHostname("apcera.me:4443"); err != nil { + t.Fatalf("Could not verify hostname in certificate: %v\n", err) + } +} + func TestMergeOverrides(t *testing.T) { golden := &Options{ Host: "apcera.me", diff --git a/server/route.go b/server/route.go index 9fc5ffea..60ec2c74 100644 --- a/server/route.go +++ b/server/route.go @@ -24,7 +24,7 @@ type connectInfo struct { Pedantic bool `json:"pedantic"` User string `json:"user,omitempty"` Pass string `json:"pass,omitempty"` - Ssl bool `json:"ssl_required"` + TLS bool `json:"ssl_required"` Name string `json:"name"` } @@ -42,7 +42,7 @@ func (c *client) sendConnect() { Pedantic: false, User: user, Pass: pass, - Ssl: false, + TLS: false, Name: c.srv.info.ID, } b, err := json.Marshal(cinfo) @@ -301,7 +301,7 @@ func (s *Server) StartRouting() { Host: s.opts.ClusterHost, Port: s.opts.ClusterPort, AuthRequired: false, - SslRequired: false, + TLSRequired: false, MaxPayload: MAX_PAYLOAD_SIZE, } // Check for Auth items diff --git a/server/server.go b/server/server.go index 82e50e8d..d46ab8d7 100644 --- a/server/server.go +++ b/server/server.go @@ -3,6 +3,8 @@ package server import ( + "bufio" + "crypto/tls" "encoding/json" "fmt" "io/ioutil" @@ -30,7 +32,7 @@ type Info struct { Host string `json:"host"` Port int `json:"port"` AuthRequired bool `json:"auth_required"` - SslRequired bool `json:"ssl_required"` + TLSRequired bool `json:"ssl_required"` // ssl json used for older clients MaxPayload int `json:"max_payload"` } @@ -80,7 +82,7 @@ func New(opts *Options) *Server { Host: opts.Host, Port: opts.Port, AuthRequired: false, - SslRequired: false, + TLSRequired: opts.TLSConfig != nil, MaxPayload: opts.MaxPayload, } @@ -393,6 +395,7 @@ func (s *Server) createClient(conn net.Conn) *client { s.mu.Lock() info := s.infoJSON authRequired := s.info.AuthRequired + tlsRequired := s.info.TLSRequired s.mu.Unlock() // Grab lock @@ -412,6 +415,19 @@ func (s *Server) createClient(conn net.Conn) *client { // Send our information. s.sendInfo(c, info) + // Check for TLS + if tlsRequired { + c.Debugf("Starting TLS client connection handshake") + c.nc = tls.Server(c.nc, s.opts.TLSConfig) + conn := c.nc.(*tls.Conn) + err := conn.Handshake() + if err != nil { + c.Debugf("TLS handshake error: %v", err) + } + // Rewrap bw + c.bw = bufio.NewWriterSize(c.nc, s.opts.BufSize) + } + // Unlock to register c.mu.Unlock() diff --git a/test/opts_test.go b/test/opts_test.go index b94ac7c3..d94ad0fb 100644 --- a/test/opts_test.go +++ b/test/opts_test.go @@ -3,6 +3,7 @@ package test import ( + "fmt" "testing" ) @@ -19,3 +20,17 @@ func TestServerConfig(t *testing.T) { opts.MaxPayload, sinfo.MaxPayload) } } + +func TestTLSConfig(t *testing.T) { + srv, opts := RunServerWithConfig("./configs/tls.conf") + defer srv.Shutdown() + + c := createClientConn(t, opts.Host, opts.Port) + defer c.Close() + + sinfo := checkInfoMsg(t, c) + fmt.Printf("sinfo is %+v\n", sinfo) + if sinfo.TLSRequired != true { + t.Fatal("Expected TLSRequired to be true when configured") + } +} diff --git a/test/test.go b/test/test.go index 59a6857c..619e6b12 100644 --- a/test/test.go +++ b/test/test.go @@ -15,6 +15,7 @@ import ( "strings" "time" + "github.com/nats-io/gnatsd/auth" "github.com/nats-io/gnatsd/server" ) @@ -53,7 +54,16 @@ func RunServerWithConfig(configFile string) (srv *server.Server, opts *server.Op panic(fmt.Sprintf("Error processing configuration file: %v", err)) } opts.NoSigs, opts.NoLog = true, true - srv = RunServer(opts) + + // Check for auth + var a server.Auth + if opts.Authorization != "" { + a = &auth.Token{Token: opts.Authorization} + } + if opts.Username != "" { + a = &auth.Plain{Username: opts.Username, Password: opts.Password} + } + srv = RunServerWithAuth(opts, a) return }