mirror of
https://github.com/gogrlx/nats-server.git
synced 2026-04-14 10:10:42 -07:00
First pass at client TLS support
This commit is contained in:
@@ -244,7 +244,7 @@ func (c *client) processRouteInfo(info *Info) {
|
||||
}
|
||||
}
|
||||
|
||||
// Process the information messages from Clients and other routes.
|
||||
// Process the information messages from Clients and other Routes.
|
||||
func (c *client) processInfo(arg []byte) error {
|
||||
info := Info{}
|
||||
if err := json.Unmarshal(arg, &info); err != nil {
|
||||
@@ -293,6 +293,7 @@ func (c *client) processConnect(arg []byte) error {
|
||||
|
||||
func (c *client) authTimeout() {
|
||||
c.sendErr("Authorization Timeout")
|
||||
c.Debugf("Authorization Timeout")
|
||||
c.closeConnection()
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ type serverInfo struct {
|
||||
Port uint `json:"port"`
|
||||
Version string `json:"version"`
|
||||
AuthRequired bool `json:"auth_required"`
|
||||
SslRequired bool `json:"ssl_required"`
|
||||
TLSRequired bool `json:"ssl_required"`
|
||||
MaxPayload int64 `json:"max_payload"`
|
||||
}
|
||||
|
||||
@@ -93,7 +93,7 @@ func TestClientCreateAndInfo(t *testing.T) {
|
||||
}
|
||||
// Sanity checks
|
||||
if info.MaxPayload != MAX_PAYLOAD_SIZE ||
|
||||
info.AuthRequired || info.SslRequired ||
|
||||
info.AuthRequired || info.TLSRequired ||
|
||||
info.Port != DEFAULT_PORT {
|
||||
t.Fatalf("INFO inconsistent: %+v\n", info)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
// Copyright 2012-2013 Apcera Inc. All rights reserved.
|
||||
// Copyright 2012-2015 Apcera Inc. All rights reserved.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
@@ -30,7 +32,6 @@ type Options struct {
|
||||
PingInterval time.Duration `json:"ping_interval"`
|
||||
MaxPingsOut int `json:"ping_max"`
|
||||
HTTPPort int `json:"http_port"`
|
||||
SslTimeout float64 `json:"ssl_timeout"`
|
||||
AuthTimeout float64 `json:"auth_timeout"`
|
||||
MaxControlLine int `json:"max_control_line"`
|
||||
MaxPayload int `json:"max_payload"`
|
||||
@@ -48,6 +49,8 @@ type Options struct {
|
||||
Routes []*url.URL `json:"-"`
|
||||
RoutesStr string `json:"-"`
|
||||
BufSize int `json:"-"`
|
||||
TLSTimeout float64 `json:"tls_timeout"`
|
||||
TLSConfig *tls.Config `json:"-"`
|
||||
}
|
||||
|
||||
type authorization struct {
|
||||
@@ -56,6 +59,12 @@ type authorization struct {
|
||||
timeout float64
|
||||
}
|
||||
|
||||
// This struct holds the parsed tls config information.
|
||||
type tlsConfig struct {
|
||||
certFile string
|
||||
keyFile string
|
||||
}
|
||||
|
||||
// ProcessConfigFile processes a configuration file.
|
||||
// FIXME(dlc): Hacky
|
||||
func ProcessConfigFile(configFile string) (*Options, error) {
|
||||
@@ -118,6 +127,11 @@ func ProcessConfigFile(configFile string) (*Options, error) {
|
||||
opts.MaxPending = int(v.(int64))
|
||||
case "max_connections", "max_conn":
|
||||
opts.MaxConn = int(v.(int64))
|
||||
case "tls":
|
||||
tlsm := v.(map[string]interface{})
|
||||
if err := parseTLS(tlsm, opts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return opts, nil
|
||||
@@ -176,6 +190,53 @@ func parseAuthorization(am map[string]interface{}) authorization {
|
||||
return auth
|
||||
}
|
||||
|
||||
// Helper function to parse TLS configs.
|
||||
func parseTLS(tlsm map[string]interface{}, opts *Options) error {
|
||||
tc := tlsConfig{}
|
||||
for mk, mv := range tlsm {
|
||||
switch strings.ToLower(mk) {
|
||||
case "cert_file":
|
||||
certFile, ok := mv.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("error parsing tls config, expected 'cert_file' to be filename")
|
||||
}
|
||||
tc.certFile = certFile
|
||||
case "key_file":
|
||||
keyFile, ok := mv.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("error parsing tls config, expected 'key_file' to be filename")
|
||||
}
|
||||
tc.keyFile = keyFile
|
||||
default:
|
||||
return fmt.Errorf("error parsing tls config, unknown field [%q]", mk)
|
||||
}
|
||||
}
|
||||
// Now load in cert and private key
|
||||
cert, err := tls.LoadX509KeyPair(tc.certFile, tc.keyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing X509 certificate/key pair: %v", err)
|
||||
}
|
||||
cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing certificate: %v", err)
|
||||
}
|
||||
// Create TLSConfig
|
||||
// We will determine the cipher suites that we prefer.
|
||||
config := tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
PreferServerCipherSuites: true,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
},
|
||||
}
|
||||
opts.TLSConfig = &config
|
||||
return nil
|
||||
}
|
||||
|
||||
// MergeOptions will merge two options giving preference to the flagOpts
|
||||
// if the item is present.
|
||||
func MergeOptions(fileOpts, flagOpts *Options) *Options {
|
||||
@@ -349,8 +410,8 @@ func processOptions(opts *Options) {
|
||||
if opts.MaxPingsOut == 0 {
|
||||
opts.MaxPingsOut = DEFAULT_PING_MAX_OUT
|
||||
}
|
||||
if opts.SslTimeout == 0 {
|
||||
opts.SslTimeout = float64(SSL_TIMEOUT) / float64(time.Second)
|
||||
if opts.TLSTimeout == 0 {
|
||||
opts.TLSTimeout = float64(SSL_TIMEOUT) / float64(time.Second)
|
||||
}
|
||||
if opts.AuthTimeout == 0 {
|
||||
opts.AuthTimeout = float64(AUTH_TIMEOUT) / float64(time.Second)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
// Copyright 2013-2014 Apcera Inc. All rights reserved.
|
||||
// Copyright 2013-2015 Apcera Inc. All rights reserved.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
@@ -16,7 +17,7 @@ func TestDefaultOptions(t *testing.T) {
|
||||
MaxConn: DEFAULT_MAX_CONNECTIONS,
|
||||
PingInterval: DEFAULT_PING_INTERVAL,
|
||||
MaxPingsOut: DEFAULT_PING_MAX_OUT,
|
||||
SslTimeout: float64(SSL_TIMEOUT) / float64(time.Second),
|
||||
TLSTimeout: float64(SSL_TIMEOUT) / float64(time.Second),
|
||||
AuthTimeout: float64(AUTH_TIMEOUT) / float64(time.Second),
|
||||
MaxControlLine: MAX_CONTROL_LINE_SIZE,
|
||||
MaxPayload: MAX_PAYLOAD_SIZE,
|
||||
@@ -77,6 +78,57 @@ func TestConfigFile(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSConfigFile(t *testing.T) {
|
||||
golden := &Options{
|
||||
Host: "apcera.me",
|
||||
Port: 4443,
|
||||
Username: "derek",
|
||||
Password: "buckley",
|
||||
AuthTimeout: 1.0,
|
||||
}
|
||||
opts, err := ProcessConfigFile("./configs/tls/test.conf")
|
||||
if err != nil {
|
||||
t.Fatalf("Received an error reading config file: %v\n", err)
|
||||
}
|
||||
tlsConfig := opts.TLSConfig
|
||||
if tlsConfig == nil {
|
||||
t.Fatal("Expected opts.TLSConfig to be non-nil")
|
||||
}
|
||||
opts.TLSConfig = nil
|
||||
if !reflect.DeepEqual(golden, opts) {
|
||||
t.Fatalf("Options are incorrect.\nexpected: %+v\ngot: %+v",
|
||||
golden, opts)
|
||||
}
|
||||
// Now check TLSConfig a bit more closely
|
||||
// CipherSuites
|
||||
ciphers := []uint16{
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
}
|
||||
if !reflect.DeepEqual(tlsConfig.CipherSuites, ciphers) {
|
||||
t.Fatalf("Got incorrect cipher suite list: [%+v]", tlsConfig.CipherSuites)
|
||||
}
|
||||
if tlsConfig.MinVersion != tls.VersionTLS12 {
|
||||
t.Fatalf("Expected MinVersion of 1.2 [%v], got [%v]", tls.VersionTLS12, tlsConfig.MinVersion)
|
||||
}
|
||||
if tlsConfig.PreferServerCipherSuites != true {
|
||||
t.Fatal("Expected PreferServerCipherSuites to be true")
|
||||
}
|
||||
// Verify hostname is correct in certificate
|
||||
if len(tlsConfig.Certificates) != 1 {
|
||||
t.Fatal("Expected 1 certificate")
|
||||
}
|
||||
if len(tlsConfig.Certificates) < 1 {
|
||||
t.Fatalf("Expected certificates")
|
||||
}
|
||||
cert := tlsConfig.Certificates[0].Leaf
|
||||
if err := cert.VerifyHostname("apcera.me:4443"); err != nil {
|
||||
t.Fatalf("Could not verify hostname in certificate: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeOverrides(t *testing.T) {
|
||||
golden := &Options{
|
||||
Host: "apcera.me",
|
||||
|
||||
@@ -24,7 +24,7 @@ type connectInfo struct {
|
||||
Pedantic bool `json:"pedantic"`
|
||||
User string `json:"user,omitempty"`
|
||||
Pass string `json:"pass,omitempty"`
|
||||
Ssl bool `json:"ssl_required"`
|
||||
TLS bool `json:"ssl_required"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ func (c *client) sendConnect() {
|
||||
Pedantic: false,
|
||||
User: user,
|
||||
Pass: pass,
|
||||
Ssl: false,
|
||||
TLS: false,
|
||||
Name: c.srv.info.ID,
|
||||
}
|
||||
b, err := json.Marshal(cinfo)
|
||||
@@ -301,7 +301,7 @@ func (s *Server) StartRouting() {
|
||||
Host: s.opts.ClusterHost,
|
||||
Port: s.opts.ClusterPort,
|
||||
AuthRequired: false,
|
||||
SslRequired: false,
|
||||
TLSRequired: false,
|
||||
MaxPayload: MAX_PAYLOAD_SIZE,
|
||||
}
|
||||
// Check for Auth items
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
@@ -30,7 +32,7 @@ type Info struct {
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
AuthRequired bool `json:"auth_required"`
|
||||
SslRequired bool `json:"ssl_required"`
|
||||
TLSRequired bool `json:"ssl_required"` // ssl json used for older clients
|
||||
MaxPayload int `json:"max_payload"`
|
||||
}
|
||||
|
||||
@@ -80,7 +82,7 @@ func New(opts *Options) *Server {
|
||||
Host: opts.Host,
|
||||
Port: opts.Port,
|
||||
AuthRequired: false,
|
||||
SslRequired: false,
|
||||
TLSRequired: opts.TLSConfig != nil,
|
||||
MaxPayload: opts.MaxPayload,
|
||||
}
|
||||
|
||||
@@ -393,6 +395,7 @@ func (s *Server) createClient(conn net.Conn) *client {
|
||||
s.mu.Lock()
|
||||
info := s.infoJSON
|
||||
authRequired := s.info.AuthRequired
|
||||
tlsRequired := s.info.TLSRequired
|
||||
s.mu.Unlock()
|
||||
|
||||
// Grab lock
|
||||
@@ -412,6 +415,19 @@ func (s *Server) createClient(conn net.Conn) *client {
|
||||
// Send our information.
|
||||
s.sendInfo(c, info)
|
||||
|
||||
// Check for TLS
|
||||
if tlsRequired {
|
||||
c.Debugf("Starting TLS client connection handshake")
|
||||
c.nc = tls.Server(c.nc, s.opts.TLSConfig)
|
||||
conn := c.nc.(*tls.Conn)
|
||||
err := conn.Handshake()
|
||||
if err != nil {
|
||||
c.Debugf("TLS handshake error: %v", err)
|
||||
}
|
||||
// Rewrap bw
|
||||
c.bw = bufio.NewWriterSize(c.nc, s.opts.BufSize)
|
||||
}
|
||||
|
||||
// Unlock to register
|
||||
c.mu.Unlock()
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -19,3 +20,17 @@ func TestServerConfig(t *testing.T) {
|
||||
opts.MaxPayload, sinfo.MaxPayload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSConfig(t *testing.T) {
|
||||
srv, opts := RunServerWithConfig("./configs/tls.conf")
|
||||
defer srv.Shutdown()
|
||||
|
||||
c := createClientConn(t, opts.Host, opts.Port)
|
||||
defer c.Close()
|
||||
|
||||
sinfo := checkInfoMsg(t, c)
|
||||
fmt.Printf("sinfo is %+v\n", sinfo)
|
||||
if sinfo.TLSRequired != true {
|
||||
t.Fatal("Expected TLSRequired to be true when configured")
|
||||
}
|
||||
}
|
||||
|
||||
12
test/test.go
12
test/test.go
@@ -15,6 +15,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/gnatsd/auth"
|
||||
"github.com/nats-io/gnatsd/server"
|
||||
)
|
||||
|
||||
@@ -53,7 +54,16 @@ func RunServerWithConfig(configFile string) (srv *server.Server, opts *server.Op
|
||||
panic(fmt.Sprintf("Error processing configuration file: %v", err))
|
||||
}
|
||||
opts.NoSigs, opts.NoLog = true, true
|
||||
srv = RunServer(opts)
|
||||
|
||||
// Check for auth
|
||||
var a server.Auth
|
||||
if opts.Authorization != "" {
|
||||
a = &auth.Token{Token: opts.Authorization}
|
||||
}
|
||||
if opts.Username != "" {
|
||||
a = &auth.Plain{Username: opts.Username, Password: opts.Password}
|
||||
}
|
||||
srv = RunServerWithAuth(opts, a)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user