mirror of
https://github.com/gogrlx/nats-server.git
synced 2026-04-02 11:48:43 -07:00
tls flags, proper timeouts
This commit is contained in:
35
gnatsd.go
35
gnatsd.go
@@ -53,6 +53,12 @@ func main() {
|
||||
flag.IntVar(&opts.ProfPort, "profile", 0, "Profiling HTTP port")
|
||||
flag.StringVar(&opts.RoutesStr, "routes", "", "Routes to actively solicit a connection.")
|
||||
|
||||
flag.BoolVar(&opts.TLS, "tls", false, "Enable TLS.")
|
||||
flag.BoolVar(&opts.TLSVerify, "tlsverify", false, "Enable TLS with client verification.")
|
||||
flag.StringVar(&opts.TLSCert, "tlscert", "", "Server certificate file.")
|
||||
flag.StringVar(&opts.TLSKey, "tlskey", "", "Private key for server certificate.")
|
||||
flag.StringVar(&opts.TLSCaCert, "tlscacert", "", "Client certificate CA for verification.")
|
||||
|
||||
// Not public per se, will be replaced with dynamic system, but can be used to lower memory footprint when
|
||||
// lots of connections present.
|
||||
flag.IntVar(&opts.BufSize, "bs", 0, "Read/Write buffer size per client connection.")
|
||||
@@ -98,6 +104,9 @@ func main() {
|
||||
}
|
||||
opts.Routes = newroutes
|
||||
|
||||
// Configure TLS based on any present flags
|
||||
configureTLS(&opts)
|
||||
|
||||
// Create the server with appropriate options.
|
||||
s := server.New(&opts)
|
||||
|
||||
@@ -148,3 +157,29 @@ func configureLogger(s *server.Server, opts *server.Options) {
|
||||
|
||||
s.SetLogger(log, opts.Debug, opts.Trace)
|
||||
}
|
||||
|
||||
func configureTLS(opts *server.Options) {
|
||||
// If no trigger flags, ignore the others
|
||||
if !opts.TLS && !opts.TLSVerify {
|
||||
return
|
||||
}
|
||||
if opts.TLSCert == "" {
|
||||
server.PrintAndDie("TLS Server certificate must be present and valid.")
|
||||
}
|
||||
if opts.TLSKey == "" {
|
||||
server.PrintAndDie("TLS Server private key must be present and valid.")
|
||||
}
|
||||
|
||||
tc := server.TLSConfigOpts{}
|
||||
tc.CertFile = opts.TLSCert
|
||||
tc.KeyFile = opts.TLSKey
|
||||
tc.CaFile = opts.TLSCaCert
|
||||
|
||||
if opts.TLSVerify {
|
||||
tc.Verify = true
|
||||
}
|
||||
var err error
|
||||
if opts.TLSConfig, err = server.GenTLSConfig(&tc); err != nil {
|
||||
server.PrintAndDie(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ net: apcera.me # net interface
|
||||
tls {
|
||||
cert_file: "./configs/certs/server.pem"
|
||||
key_file: "./configs/certs/key.pem"
|
||||
timeout: 0.5
|
||||
}
|
||||
|
||||
authorization {
|
||||
|
||||
@@ -35,11 +35,11 @@ const (
|
||||
// DEFAULT_MAX_CONNECTIONS is the default maximum connections allowed.
|
||||
DEFAULT_MAX_CONNECTIONS = (64 * 1024)
|
||||
|
||||
// SSL_TIMEOUT is the TLS/SSL wait time.
|
||||
SSL_TIMEOUT = 500 * time.Millisecond
|
||||
// TLS_TIMEOUT is the TLS wait time.
|
||||
TLS_TIMEOUT = 500 * time.Millisecond
|
||||
|
||||
// AUTH_TIMEOUT is the authorization wait time.
|
||||
AUTH_TIMEOUT = 2 * SSL_TIMEOUT
|
||||
AUTH_TIMEOUT = 2 * TLS_TIMEOUT
|
||||
|
||||
// DEFAULT_PING_INTERVAL is how often pings are sent to clients and routes.
|
||||
DEFAULT_PING_INTERVAL = 2 * time.Minute
|
||||
|
||||
@@ -41,6 +41,7 @@ type Options struct {
|
||||
ClusterUsername string `json:"-"`
|
||||
ClusterPassword string `json:"-"`
|
||||
ClusterAuthTimeout float64 `json:"auth_timeout"`
|
||||
ClusterTLSTimeout float64 `json:"-"`
|
||||
ClusterTLSConfig *tls.Config `json:"-"`
|
||||
ProfPort int `json:"-"`
|
||||
PidFile string `json:"-"`
|
||||
@@ -51,6 +52,11 @@ type Options struct {
|
||||
RoutesStr string `json:"-"`
|
||||
BufSize int `json:"-"`
|
||||
TLSTimeout float64 `json:"tls_timeout"`
|
||||
TLS bool `json:"-"`
|
||||
TLSVerify bool `json:"-"`
|
||||
TLSCert string `json:"-"`
|
||||
TLSKey string `json:"-"`
|
||||
TLSCaCert string `json:"-"`
|
||||
TLSConfig *tls.Config `json:"-"`
|
||||
}
|
||||
|
||||
@@ -61,11 +67,13 @@ type authorization struct {
|
||||
}
|
||||
|
||||
// This struct holds the parsed tls config information.
|
||||
type tlsConfig struct {
|
||||
certFile string
|
||||
keyFile string
|
||||
caFile string
|
||||
verify bool
|
||||
// It spublic so we can use ot for flag parsing
|
||||
type TLSConfigOpts struct {
|
||||
CertFile string
|
||||
KeyFile string
|
||||
CaFile string
|
||||
Verify bool
|
||||
Timeout float64
|
||||
}
|
||||
|
||||
// ProcessConfigFile processes a configuration file.
|
||||
@@ -132,9 +140,14 @@ func ProcessConfigFile(configFile string) (*Options, error) {
|
||||
opts.MaxConn = int(v.(int64))
|
||||
case "tls":
|
||||
tlsm := v.(map[string]interface{})
|
||||
if opts.TLSConfig, err = parseTLS(tlsm); err != nil {
|
||||
tc, err := parseTLS(tlsm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if opts.TLSConfig, err = GenTLSConfig(tc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
opts.TLSTimeout = tc.Timeout
|
||||
}
|
||||
}
|
||||
return opts, nil
|
||||
@@ -168,9 +181,14 @@ func parseCluster(cm map[string]interface{}, opts *Options) error {
|
||||
case "tls":
|
||||
var err error
|
||||
tlsm := mv.(map[string]interface{})
|
||||
if opts.ClusterTLSConfig, err = parseTLS(tlsm); err != nil {
|
||||
tc, err := parseTLS(tlsm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if opts.ClusterTLSConfig, err = GenTLSConfig(tc); err != nil {
|
||||
return err
|
||||
}
|
||||
opts.ClusterTLSTimeout = tc.Timeout
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -200,8 +218,9 @@ func parseAuthorization(am map[string]interface{}) authorization {
|
||||
}
|
||||
|
||||
// Helper function to parse TLS configs.
|
||||
func parseTLS(tlsm map[string]interface{}) (*tls.Config, error) {
|
||||
tc := tlsConfig{}
|
||||
//func parseTLS(tlsm map[string]interface{}) (*tls.Config, error) {
|
||||
func parseTLS(tlsm map[string]interface{}) (*TLSConfigOpts, error) {
|
||||
tc := TLSConfigOpts{}
|
||||
for mk, mv := range tlsm {
|
||||
switch strings.ToLower(mk) {
|
||||
case "cert_file":
|
||||
@@ -209,32 +228,44 @@ func parseTLS(tlsm map[string]interface{}) (*tls.Config, error) {
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("error parsing tls config, expected 'cert_file' to be filename")
|
||||
}
|
||||
tc.certFile = certFile
|
||||
tc.CertFile = certFile
|
||||
case "key_file":
|
||||
keyFile, ok := mv.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("error parsing tls config, expected 'key_file' to be filename")
|
||||
}
|
||||
tc.keyFile = keyFile
|
||||
tc.KeyFile = keyFile
|
||||
case "ca_file":
|
||||
caFile, ok := mv.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("error parsing tls config, expected 'ca_file' to be filename")
|
||||
}
|
||||
tc.caFile = caFile
|
||||
tc.CaFile = caFile
|
||||
case "verify":
|
||||
verify, ok := mv.(bool)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("error parsing tls config, expected 'verify' to be a boolean")
|
||||
}
|
||||
tc.verify = verify
|
||||
|
||||
tc.Verify = verify
|
||||
case "timeout":
|
||||
at := float64(0)
|
||||
switch mv.(type) {
|
||||
case int64:
|
||||
at = float64(mv.(int64))
|
||||
case float64:
|
||||
at = mv.(float64)
|
||||
}
|
||||
tc.Timeout = at
|
||||
default:
|
||||
return nil, fmt.Errorf("error parsing tls config, unknown field [%q]", mk)
|
||||
}
|
||||
}
|
||||
return &tc, nil
|
||||
}
|
||||
|
||||
func GenTLSConfig(tc *TLSConfigOpts) (*tls.Config, error) {
|
||||
// Now load in cert and private key
|
||||
cert, err := tls.LoadX509KeyPair(tc.certFile, tc.keyFile)
|
||||
cert, err := tls.LoadX509KeyPair(tc.CertFile, tc.KeyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing X509 certificate/key pair: %v", err)
|
||||
}
|
||||
@@ -244,7 +275,6 @@ func parseTLS(tlsm map[string]interface{}) (*tls.Config, error) {
|
||||
}
|
||||
// Create TLSConfig
|
||||
// We will determine the cipher suites that we prefer.
|
||||
|
||||
config := tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
PreferServerCipherSuites: true,
|
||||
@@ -252,12 +282,12 @@ func parseTLS(tlsm map[string]interface{}) (*tls.Config, error) {
|
||||
CipherSuites: defaultCipherSuites(),
|
||||
}
|
||||
// Require client certificates as needed
|
||||
if tc.verify == true {
|
||||
if tc.Verify == true {
|
||||
config.ClientAuth = tls.RequireAnyClientCert
|
||||
}
|
||||
// Add in CAs if applicable.
|
||||
if tc.caFile != "" {
|
||||
rootPEM, err := ioutil.ReadFile(tc.caFile)
|
||||
if tc.CaFile != "" {
|
||||
rootPEM, err := ioutil.ReadFile(tc.CaFile)
|
||||
if err != nil || rootPEM == nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -446,7 +476,7 @@ func processOptions(opts *Options) {
|
||||
opts.MaxPingsOut = DEFAULT_PING_MAX_OUT
|
||||
}
|
||||
if opts.TLSTimeout == 0 {
|
||||
opts.TLSTimeout = float64(SSL_TIMEOUT) / float64(time.Second)
|
||||
opts.TLSTimeout = float64(TLS_TIMEOUT) / float64(time.Second)
|
||||
}
|
||||
if opts.AuthTimeout == 0 {
|
||||
opts.AuthTimeout = float64(AUTH_TIMEOUT) / float64(time.Second)
|
||||
|
||||
@@ -17,7 +17,7 @@ func TestDefaultOptions(t *testing.T) {
|
||||
MaxConn: DEFAULT_MAX_CONNECTIONS,
|
||||
PingInterval: DEFAULT_PING_INTERVAL,
|
||||
MaxPingsOut: DEFAULT_PING_MAX_OUT,
|
||||
TLSTimeout: float64(SSL_TIMEOUT) / float64(time.Second),
|
||||
TLSTimeout: float64(TLS_TIMEOUT) / float64(time.Second),
|
||||
AuthTimeout: float64(AUTH_TIMEOUT) / float64(time.Second),
|
||||
MaxControlLine: MAX_CONTROL_LINE_SIZE,
|
||||
MaxPayload: MAX_PAYLOAD_SIZE,
|
||||
@@ -85,6 +85,7 @@ func TestTLSConfigFile(t *testing.T) {
|
||||
Username: "derek",
|
||||
Password: "buckley",
|
||||
AuthTimeout: 1.0,
|
||||
TLSTimeout: 0.5,
|
||||
}
|
||||
opts, err := ProcessConfigFile("./configs/tls.conf")
|
||||
if err != nil {
|
||||
|
||||
123
server/server.go
123
server/server.go
@@ -294,6 +294,11 @@ func (s *Server) AcceptLoop() {
|
||||
return
|
||||
}
|
||||
|
||||
// Alert of TLS enabled.
|
||||
if s.opts.TLSConfig != nil {
|
||||
Noticef("TLS required for client connections")
|
||||
}
|
||||
|
||||
Noticef("gnatsd is ready")
|
||||
|
||||
// Setup state that can enable shutdown
|
||||
@@ -333,7 +338,7 @@ func (s *Server) AcceptLoop() {
|
||||
continue
|
||||
}
|
||||
tmpDelay = ACCEPT_MIN_SLEEP
|
||||
s.createClient(conn)
|
||||
go s.createClient(conn)
|
||||
}
|
||||
Noticef("Server Exiting..")
|
||||
s.done <- true
|
||||
@@ -424,19 +429,6 @@ func (s *Server) createClient(conn net.Conn) *client {
|
||||
// Send our information.
|
||||
s.sendInfo(c, info)
|
||||
|
||||
// Check for TLS
|
||||
if tlsRequired {
|
||||
c.Debugf("Starting TLS client connection handshake")
|
||||
c.nc = tls.Server(c.nc, s.opts.TLSConfig)
|
||||
conn := c.nc.(*tls.Conn)
|
||||
err := conn.Handshake()
|
||||
if err != nil {
|
||||
c.Debugf("TLS handshake error: %v", err)
|
||||
}
|
||||
// Rewrap bw
|
||||
c.bw = bufio.NewWriterSize(c.nc, s.opts.BufSize)
|
||||
}
|
||||
|
||||
// Unlock to register
|
||||
c.mu.Unlock()
|
||||
|
||||
@@ -445,9 +437,112 @@ func (s *Server) createClient(conn net.Conn) *client {
|
||||
s.clients[c.cid] = c
|
||||
s.mu.Unlock()
|
||||
|
||||
// Check for TLS
|
||||
if tlsRequired {
|
||||
// Re-Grab lock
|
||||
c.mu.Lock()
|
||||
|
||||
c.Debugf("Starting TLS client connection handshake")
|
||||
c.nc = tls.Server(c.nc, s.opts.TLSConfig)
|
||||
conn := c.nc.(*tls.Conn)
|
||||
|
||||
// Setup the timeout
|
||||
ttl := secondsToDuration(s.opts.TLSTimeout)
|
||||
time.AfterFunc(ttl, func() { tlsTimeout(c, conn) })
|
||||
conn.SetReadDeadline(time.Now().Add(ttl))
|
||||
|
||||
// Force handshake
|
||||
c.mu.Unlock()
|
||||
if err := conn.Handshake(); err != nil {
|
||||
c.Debugf("TLS handshake error: %v", err)
|
||||
c.sendErr("Secure Connection - TLS Required")
|
||||
c.closeConnection()
|
||||
return nil
|
||||
}
|
||||
// Reset the read deadline
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// Re-Grab lock
|
||||
c.mu.Lock()
|
||||
|
||||
// Rewrap bw
|
||||
c.bw = bufio.NewWriterSize(c.nc, s.opts.BufSize)
|
||||
c.Debugf("TLS handshake complete")
|
||||
cs := conn.ConnectionState()
|
||||
c.Debugf("TLS version %s, cipher suite %s", tlsVersion(cs.Version), tlsCipher(cs.CipherSuite))
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// Handle closing down a connection when the handshake has timedout.
|
||||
func tlsTimeout(c *client, conn *tls.Conn) {
|
||||
c.mu.Lock()
|
||||
nc := c.nc
|
||||
c.mu.Unlock()
|
||||
// Check if already closed
|
||||
if nc == nil {
|
||||
return
|
||||
}
|
||||
cs := conn.ConnectionState()
|
||||
if cs.HandshakeComplete == false {
|
||||
c.Debugf("TLS handshake timeout")
|
||||
c.sendErr("Secure Connection - TLS Required")
|
||||
c.closeConnection()
|
||||
}
|
||||
}
|
||||
|
||||
// Seems silly we have to write these
|
||||
func tlsVersion(ver uint16) string {
|
||||
switch ver {
|
||||
case tls.VersionTLS10:
|
||||
return "1.0"
|
||||
case tls.VersionTLS11:
|
||||
return "1.1"
|
||||
case tls.VersionTLS12:
|
||||
return "1.2"
|
||||
}
|
||||
return fmt.Sprintf("Unknown [%x]", ver)
|
||||
}
|
||||
|
||||
// We use hex here so we don't need multiple versions
|
||||
func tlsCipher(cs uint16) string {
|
||||
switch cs {
|
||||
case 0x0005:
|
||||
return "TLS_RSA_WITH_RC4_128_SHA"
|
||||
case 0x000a:
|
||||
return "TLS_RSA_WITH_3DES_EDE_CBC_SHA"
|
||||
case 0x002f:
|
||||
return "TLS_RSA_WITH_AES_128_CBC_SHA"
|
||||
case 0x0035:
|
||||
return "TLS_RSA_WITH_AES_256_CBC_SHA"
|
||||
case 0xc007:
|
||||
return "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA"
|
||||
case 0xc009:
|
||||
return "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA"
|
||||
case 0xc00a:
|
||||
return "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA"
|
||||
case 0xc011:
|
||||
return "TLS_ECDHE_RSA_WITH_RC4_128_SHA"
|
||||
case 0xc012:
|
||||
return "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA"
|
||||
case 0xc013:
|
||||
return "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA"
|
||||
case 0xc014:
|
||||
return "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA"
|
||||
case 0xc02f:
|
||||
return "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"
|
||||
case 0xc02b:
|
||||
return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"
|
||||
case 0xc030:
|
||||
return "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"
|
||||
case 0xc02c:
|
||||
return "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"
|
||||
}
|
||||
return fmt.Sprintf("Unknown [%x]", cs)
|
||||
}
|
||||
|
||||
// Assume the lock is held upon entry.
|
||||
func (s *Server) sendInfo(c *client, info []byte) {
|
||||
c.nc.Write(info)
|
||||
|
||||
@@ -18,8 +18,8 @@ Server Options:
|
||||
Logging Options:
|
||||
-l, --log FILE File to redirect log output
|
||||
-T, --logtime Timestamp log entries (default: true)
|
||||
-s, --syslog Enable syslog as log method.
|
||||
-r, --remote_syslog Syslog server addr (udp://localhost:514).
|
||||
-s, --syslog Enable syslog as log method
|
||||
-r, --remote_syslog Syslog server addr (udp://localhost:514)
|
||||
-D, --debug Enable debugging output
|
||||
-V, --trace Trace the raw protocol
|
||||
-DV Debug and Trace
|
||||
@@ -28,6 +28,13 @@ Authorization Options:
|
||||
--user user User required for connections
|
||||
--pass password Password required for connections
|
||||
|
||||
TLS Options:
|
||||
--tls Enable TLS, do not verify clients (default: false)
|
||||
--tlscert FILE Server certificate file
|
||||
--tlskey FILE Private key for server certificate
|
||||
--tlsverify Enable TLS, very client certificates
|
||||
--tlscacert client Client certificate CA for verification
|
||||
|
||||
Cluster Options:
|
||||
--routes [rurl-1, rurl-2] Routes to solicit and connect
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@ tls {
|
||||
cert_file: "./configs/certs/server-cert.pem"
|
||||
# Server private key
|
||||
key_file: "./configs/certs/server-key.pem"
|
||||
# Specified time for handshake to complete
|
||||
timeout: 0.25
|
||||
}
|
||||
|
||||
authorization {
|
||||
|
||||
@@ -3,11 +3,15 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/nats"
|
||||
)
|
||||
@@ -129,3 +133,40 @@ func TestTLSClientCertificate(t *testing.T) {
|
||||
nc.Flush()
|
||||
defer nc.Close()
|
||||
}
|
||||
|
||||
func TestTLSConnectionTimeout(t *testing.T) {
|
||||
srv, opts := RunServerWithConfig("./configs/tls.conf")
|
||||
defer srv.Shutdown()
|
||||
|
||||
// Dial with normal TCP
|
||||
endpoint := fmt.Sprintf("%s:%d", opts.Host, opts.Port)
|
||||
conn, err := net.Dial("tcp", endpoint)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not connect to %q", endpoint)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Read deadlines
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
|
||||
// Read the INFO string.
|
||||
br := bufio.NewReader(conn)
|
||||
info, err := br.ReadString('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read INFO - %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(info, "INFO ") {
|
||||
t.Fatalf("INFO response incorrect: %s\n", info)
|
||||
}
|
||||
wait := time.Duration(opts.TLSTimeout * float64(time.Second))
|
||||
time.Sleep(wait)
|
||||
// Read deadlines
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
tlsErr, err := br.ReadString('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("Error reading error response - %v\n", err)
|
||||
}
|
||||
if !strings.Contains(tlsErr, "-ERR 'Secure Connection - TLS Required") {
|
||||
t.Fatalf("TLS Timeout response incorrect: %q\n", tlsErr)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user