mirror of
https://github.com/gogrlx/nats-server.git
synced 2026-04-02 03:38:42 -07:00
Added dedicated auth block for websocket
Websocket can now override - Username/password - Token - Users - NKeys - no_auth_user - auth_timeout For TLS, support for verify and verify_and_map. We used to set tls config's ClientAuth to NoClientCert. It will now depend if the config requires client certificate verification, which is needed if TLSMap is enabled. Signed-off-by: Ivan Kozlovic <ivan@synadia.com>
This commit is contained in:
@@ -278,6 +278,9 @@ func (s *Server) configureAuthorization() {
|
||||
s.users = nil
|
||||
s.info.AuthRequired = false
|
||||
}
|
||||
|
||||
// Do similar for websocket config
|
||||
s.wsConfigAuth(&opts.Websocket)
|
||||
}
|
||||
|
||||
// checkAuthentication will check based on client type and
|
||||
@@ -309,10 +312,64 @@ func (s *Server) isClientAuthorized(c *client) bool {
|
||||
return opts.CustomClientAuthentication.Check(c)
|
||||
}
|
||||
|
||||
return s.processClientOrLeafAuthentication(c)
|
||||
return s.processClientOrLeafAuthentication(c, opts)
|
||||
}
|
||||
|
||||
func (s *Server) processClientOrLeafAuthentication(c *client) bool {
|
||||
type authOpts struct {
|
||||
username string
|
||||
password string
|
||||
token string
|
||||
noAuthUser string
|
||||
tlsMap bool
|
||||
users map[string]*User
|
||||
nkeys map[string]*NkeyUser
|
||||
}
|
||||
|
||||
func (s *Server) getAuthOpts(c *client, o *Options, auth *authOpts) bool {
|
||||
// c.ws is immutable, but may need lock if we get race reports.
|
||||
wsClient := c.ws != nil
|
||||
|
||||
authRequired := s.info.AuthRequired
|
||||
// For websocket clients, if there is no top-level auth, then we
|
||||
// check for websocket specifically.
|
||||
if !authRequired && wsClient {
|
||||
authRequired = s.websocket.authRequired
|
||||
}
|
||||
if !authRequired {
|
||||
return false
|
||||
}
|
||||
auth.noAuthUser = o.NoAuthUser
|
||||
auth.tlsMap = o.TLSMap
|
||||
if wsClient {
|
||||
wo := &o.Websocket
|
||||
// If those are specified, override, regardless if there is
|
||||
// auth configuration (like user/pwd, etc..) in websocket section.
|
||||
if wo.NoAuthUser != _EMPTY_ {
|
||||
auth.noAuthUser = wo.NoAuthUser
|
||||
}
|
||||
if wo.TLSMap {
|
||||
auth.tlsMap = true
|
||||
}
|
||||
// Now check for websocket auth specific override
|
||||
if s.websocket.authRequired {
|
||||
auth.username = wo.Username
|
||||
auth.password = wo.Password
|
||||
auth.token = wo.Token
|
||||
auth.users = s.websocket.users
|
||||
auth.nkeys = s.websocket.nkeys
|
||||
return true
|
||||
}
|
||||
// else fallback to regular auth config
|
||||
}
|
||||
auth.username = o.Username
|
||||
auth.password = o.Password
|
||||
auth.token = o.Authorization
|
||||
auth.users = s.users
|
||||
auth.nkeys = s.nkeys
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) bool {
|
||||
var (
|
||||
nkey *NkeyUser
|
||||
juc *jwt.UserClaims
|
||||
@@ -320,11 +377,11 @@ func (s *Server) processClientOrLeafAuthentication(c *client) bool {
|
||||
user *User
|
||||
ok bool
|
||||
err error
|
||||
opts = s.getOpts()
|
||||
auth authOpts
|
||||
)
|
||||
|
||||
s.mu.Lock()
|
||||
authRequired := s.info.AuthRequired
|
||||
authRequired := s.getAuthOpts(c, opts, &auth)
|
||||
if !authRequired {
|
||||
// TODO(dlc) - If they send us credentials should we fail?
|
||||
s.mu.Unlock()
|
||||
@@ -355,21 +412,21 @@ func (s *Server) processClientOrLeafAuthentication(c *client) bool {
|
||||
}
|
||||
|
||||
// Check if we have nkeys or users for client.
|
||||
hasNkeys := s.nkeys != nil
|
||||
hasUsers := s.users != nil
|
||||
hasNkeys := auth.nkeys != nil
|
||||
hasUsers := auth.users != nil
|
||||
if hasNkeys && c.opts.Nkey != "" {
|
||||
nkey, ok = s.nkeys[c.opts.Nkey]
|
||||
nkey, ok = auth.nkeys[c.opts.Nkey]
|
||||
if !ok {
|
||||
s.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
} else if hasUsers {
|
||||
// Check if we are tls verify and are mapping users from the client_certificate
|
||||
if opts.TLSMap {
|
||||
if auth.tlsMap {
|
||||
var euser string
|
||||
authorized := checkClientTLSCertSubject(c, func(u string) bool {
|
||||
var ok bool
|
||||
user, ok = s.users[u]
|
||||
user, ok = auth.users[u]
|
||||
if !ok {
|
||||
c.Debugf("User in cert [%q], not found", u)
|
||||
return false
|
||||
@@ -388,14 +445,14 @@ func (s *Server) processClientOrLeafAuthentication(c *client) bool {
|
||||
// but we set it here to be able to identify it in the logs.
|
||||
c.opts.Username = euser
|
||||
} else {
|
||||
if c.kind == CLIENT && c.opts.Username == "" && s.opts.NoAuthUser != "" {
|
||||
if u, exists := s.users[s.opts.NoAuthUser]; exists {
|
||||
if c.kind == CLIENT && c.opts.Username == "" && auth.noAuthUser != "" {
|
||||
if u, exists := auth.users[auth.noAuthUser]; exists {
|
||||
c.opts.Username = u.Username
|
||||
c.opts.Password = u.Password
|
||||
}
|
||||
}
|
||||
if c.opts.Username != "" {
|
||||
user, ok = s.users[c.opts.Username]
|
||||
user, ok = auth.users[c.opts.Username]
|
||||
if !ok {
|
||||
s.mu.Unlock()
|
||||
return false
|
||||
@@ -517,13 +574,13 @@ func (s *Server) processClientOrLeafAuthentication(c *client) bool {
|
||||
}
|
||||
|
||||
if c.kind == CLIENT {
|
||||
if opts.Authorization != "" {
|
||||
return comparePasswords(opts.Authorization, c.opts.Token)
|
||||
} else if opts.Username != "" {
|
||||
if opts.Username != c.opts.Username {
|
||||
if auth.token != "" {
|
||||
return comparePasswords(auth.token, c.opts.Token)
|
||||
} else if auth.username != "" {
|
||||
if auth.username != c.opts.Username {
|
||||
return false
|
||||
}
|
||||
return comparePasswords(opts.Password, c.opts.Password)
|
||||
return comparePasswords(auth.password, c.opts.Password)
|
||||
}
|
||||
} else if c.kind == LEAF {
|
||||
// There is no required username/password to connect and
|
||||
@@ -733,7 +790,7 @@ func (s *Server) isLeafNodeAuthorized(c *client) bool {
|
||||
// Still, if the CONNECT has some user info, we will bind to the
|
||||
// user's account or to the specified default account (if provided)
|
||||
// or to the global account.
|
||||
return s.processClientOrLeafAuthentication(c)
|
||||
return s.processClientOrLeafAuthentication(c, opts)
|
||||
}
|
||||
|
||||
// Support for bcrypt stored passwords and tokens.
|
||||
|
||||
@@ -257,8 +257,26 @@ type WebsocketOpts struct {
|
||||
// The host:port to advertise to websocket clients in the cluster.
|
||||
Advertise string
|
||||
|
||||
// If no user is provided when a client connects, will default to this
|
||||
// user and associated account. This user has to exist either in the
|
||||
// Users defined here or in the global options.
|
||||
NoAuthUser string
|
||||
|
||||
// Authentication section. If anything is configured in this section,
|
||||
// it will override the authorization configuration for regular clients.
|
||||
Username string
|
||||
Password string
|
||||
Token string
|
||||
Users []*User
|
||||
Nkeys []*NkeyUser
|
||||
|
||||
// Timeout for the authentication process.
|
||||
AuthTimeout float64
|
||||
|
||||
// TLS configuration is required.
|
||||
TLSConfig *tls.Config
|
||||
// If true, map certificate values for authentication purposes.
|
||||
TLSMap bool
|
||||
|
||||
// If true, the Origin header must match the request's host.
|
||||
SameOrigin bool
|
||||
@@ -3025,12 +3043,17 @@ func parseWebsocket(v interface{}, o *Options, errors *[]error, warnings *[]erro
|
||||
case "advertise":
|
||||
o.Websocket.Advertise = mv.(string)
|
||||
case "tls":
|
||||
config, _, err := getTLSConfig(tk)
|
||||
tc, err := parseTLS(tk)
|
||||
if err != nil {
|
||||
*errors = append(*errors, err)
|
||||
continue
|
||||
}
|
||||
o.Websocket.TLSConfig = config
|
||||
if o.Websocket.TLSConfig, err = GenTLSConfig(tc); err != nil {
|
||||
err := &configErr{tk, err.Error()}
|
||||
*errors = append(*errors, err)
|
||||
continue
|
||||
}
|
||||
o.Websocket.TLSMap = tc.Map
|
||||
case "same_origin":
|
||||
o.Websocket.SameOrigin = mv.(bool)
|
||||
case "allowed_origins", "allowed_origin", "allow_origins", "allow_origin", "origins", "origin":
|
||||
@@ -3074,6 +3097,42 @@ func parseWebsocket(v interface{}, o *Options, errors *[]error, warnings *[]erro
|
||||
o.Websocket.HandshakeTimeout = ht
|
||||
case "compression":
|
||||
o.Websocket.Compression = mv.(bool)
|
||||
case "authorization", "authentication":
|
||||
auth, err := parseAuthorization(tk, o, errors, warnings)
|
||||
if err != nil {
|
||||
*errors = append(*errors, err)
|
||||
continue
|
||||
}
|
||||
o.Websocket.Username = auth.user
|
||||
o.Websocket.Password = auth.pass
|
||||
o.Websocket.Token = auth.token
|
||||
if (auth.user != "" || auth.pass != "") && auth.token != "" {
|
||||
err := &configErr{tk, "Cannot have a user/pass and token"}
|
||||
*errors = append(*errors, err)
|
||||
continue
|
||||
}
|
||||
o.Websocket.AuthTimeout = auth.timeout
|
||||
// Check for multiple users defined
|
||||
if auth.users != nil {
|
||||
if auth.user != "" {
|
||||
err := &configErr{tk, "Can not have a single user/pass and a users array"}
|
||||
*errors = append(*errors, err)
|
||||
continue
|
||||
}
|
||||
if auth.token != "" {
|
||||
err := &configErr{tk, "Can not have a token and a users array"}
|
||||
*errors = append(*errors, err)
|
||||
continue
|
||||
}
|
||||
// Users may have been added from Accounts parsing, so do an append here
|
||||
o.Websocket.Users = append(o.Websocket.Users, auth.users...)
|
||||
}
|
||||
// Check for nkeys
|
||||
if auth.nkeys != nil {
|
||||
o.Websocket.Nkeys = append(o.Websocket.Nkeys, auth.nkeys...)
|
||||
}
|
||||
case "no_auth_user":
|
||||
o.Websocket.NoAuthUser = mv.(string)
|
||||
default:
|
||||
if !tk.IsUsedVariable() {
|
||||
err := &unknownConfigFieldErr{
|
||||
|
||||
@@ -1875,6 +1875,12 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client {
|
||||
// Grab JSON info string
|
||||
s.mu.Lock()
|
||||
info := s.copyInfo()
|
||||
// If this is a websocket client and there is no top-level auth specified,
|
||||
// then we use the websocket's specific boolean that will be set to true
|
||||
// if there is any auth{} configured in websocket{}.
|
||||
if ws != nil && !info.AuthRequired {
|
||||
info.AuthRequired = s.websocket.authRequired
|
||||
}
|
||||
if s.nonceRequired() {
|
||||
// Nonce handling
|
||||
var raw [nonceLen]byte
|
||||
@@ -1974,7 +1980,15 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client {
|
||||
// the race where the timer fires during the handshake and causes the
|
||||
// server to write bad data to the socket. See issue #432.
|
||||
if info.AuthRequired {
|
||||
c.setAuthTimer(secondsToDuration(opts.AuthTimeout))
|
||||
timeout := opts.AuthTimeout
|
||||
// For websocket, possibly override only if set. We make sure that
|
||||
// opts.AuthTimeout is set to a default value if not configured,
|
||||
// but we don't do the same for websocket's one so that we know
|
||||
// if user has explicitly set or not.
|
||||
if ws != nil && opts.Websocket.AuthTimeout != 0 {
|
||||
timeout = opts.Websocket.AuthTimeout
|
||||
}
|
||||
c.setAuthTimer(secondsToDuration(timeout))
|
||||
}
|
||||
|
||||
// Do final client initialization
|
||||
|
||||
@@ -101,6 +101,9 @@ type srvWebsocket struct {
|
||||
sameOrigin bool
|
||||
connectURLs []string
|
||||
connectURLsMap map[string]struct{}
|
||||
users map[string]*User
|
||||
nkeys map[string]*NkeyUser
|
||||
authRequired bool // indicate if there is auth override in websocket config
|
||||
}
|
||||
|
||||
type allowedOrigin struct {
|
||||
@@ -763,6 +766,55 @@ func (s *Server) wsSetOriginOptions(o *WebsocketOpts) {
|
||||
}
|
||||
}
|
||||
|
||||
// Given the websocket options, we check if any auth configuration
|
||||
// has been provided. If so, possibly create users/nkey users and
|
||||
// store them in s.websocket.users/nkeys.
|
||||
// Also update a boolean that indicates if auth is required for
|
||||
// websocket clients.
|
||||
func (s *Server) wsConfigAuth(opts *WebsocketOpts) {
|
||||
if opts.Nkeys != nil || opts.Users != nil {
|
||||
// Support both at the same time.
|
||||
if opts.Nkeys != nil {
|
||||
s.websocket.nkeys = make(map[string]*NkeyUser)
|
||||
for _, u := range opts.Nkeys {
|
||||
copy := u.clone()
|
||||
if u.Account != nil {
|
||||
if v, ok := s.accounts.Load(u.Account.Name); ok {
|
||||
copy.Account = v.(*Account)
|
||||
}
|
||||
}
|
||||
if copy.Permissions != nil {
|
||||
validateResponsePermissions(copy.Permissions)
|
||||
}
|
||||
s.websocket.nkeys[u.Nkey] = copy
|
||||
}
|
||||
}
|
||||
if opts.Users != nil {
|
||||
s.websocket.users = make(map[string]*User)
|
||||
for _, u := range opts.Users {
|
||||
copy := u.clone()
|
||||
if u.Account != nil {
|
||||
if v, ok := s.accounts.Load(u.Account.Name); ok {
|
||||
copy.Account = v.(*Account)
|
||||
}
|
||||
}
|
||||
if copy.Permissions != nil {
|
||||
validateResponsePermissions(copy.Permissions)
|
||||
}
|
||||
s.websocket.users[u.Username] = copy
|
||||
}
|
||||
}
|
||||
s.assignGlobalAccountToOrphanUsers()
|
||||
s.websocket.authRequired = true
|
||||
} else if opts.Username != "" || opts.Token != "" {
|
||||
s.websocket.authRequired = true
|
||||
} else {
|
||||
s.websocket.users = nil
|
||||
s.websocket.nkeys = nil
|
||||
s.websocket.authRequired = false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) startWebsocketServer() {
|
||||
sopts := s.getOpts()
|
||||
o := &sopts.Websocket
|
||||
@@ -787,7 +839,6 @@ func (s *Server) startWebsocketServer() {
|
||||
if o.TLSConfig != nil {
|
||||
proto = "wss"
|
||||
config := o.TLSConfig.Clone()
|
||||
config.ClientAuth = tls.NoClientCert
|
||||
hl, err = tls.Listen("tcp", hp, config)
|
||||
} else {
|
||||
proto = "ws"
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@@ -34,6 +35,8 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/nkeys"
|
||||
)
|
||||
|
||||
type testReader struct {
|
||||
@@ -1581,7 +1584,7 @@ func TestWSAbnormalFailureOfWebServer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func testWSCreateClient(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader) {
|
||||
func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader, []byte) {
|
||||
t.Helper()
|
||||
addr := fmt.Sprintf("%s:%d", host, port)
|
||||
wsc, err := net.Dial("tcp", addr)
|
||||
@@ -1613,9 +1616,15 @@ func testWSCreateClient(t testing.TB, compress, web bool, host string, port int)
|
||||
t.Fatalf("Expected response status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode)
|
||||
}
|
||||
// Wait for the INFO
|
||||
if msg := testWSReadFrame(t, br); !bytes.HasPrefix(msg, []byte("INFO {")) {
|
||||
t.Fatalf("Expected INFO, got %s", msg)
|
||||
info := testWSReadFrame(t, br)
|
||||
if !bytes.HasPrefix(info, []byte("INFO {")) {
|
||||
t.Fatalf("Expected INFO, got %s", info)
|
||||
}
|
||||
return wsc, br, info
|
||||
}
|
||||
|
||||
func testWSCreateClient(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader) {
|
||||
wsc, br, _ := testWSCreateClientGetInfo(t, compress, web, host, port)
|
||||
// Send CONNECT and PING
|
||||
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, []byte("CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n"))
|
||||
if _, err := wsc.Write(wsmsg); err != nil {
|
||||
@@ -1792,6 +1801,187 @@ func TestWSTLSConnection(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWSTLSVerifyClientCert(t *testing.T) {
|
||||
o := testWSOptions()
|
||||
tc := &TLSConfigOpts{
|
||||
CertFile: "../test/configs/certs/server-cert.pem",
|
||||
KeyFile: "../test/configs/certs/server-key.pem",
|
||||
CaFile: "../test/configs/certs/ca.pem",
|
||||
Verify: true,
|
||||
}
|
||||
tlsc, err := GenTLSConfig(tc)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating tls config: %v", err)
|
||||
}
|
||||
o.Websocket.TLSConfig = tlsc
|
||||
s := RunServer(o)
|
||||
defer s.Shutdown()
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port)
|
||||
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
provideCert bool
|
||||
}{
|
||||
{"client provides cert", true},
|
||||
{"client does not provide cert", false},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
wsc, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating ws connection: %v", err)
|
||||
}
|
||||
defer wsc.Close()
|
||||
tlsc := &tls.Config{}
|
||||
if test.provideCert {
|
||||
tc := &TLSConfigOpts{
|
||||
CertFile: "../test/configs/certs/client-cert.pem",
|
||||
KeyFile: "../test/configs/certs/client-key.pem",
|
||||
}
|
||||
var err error
|
||||
tlsc, err = GenTLSConfig(tc)
|
||||
if err != nil {
|
||||
t.Fatalf("Error generating tls config: %v", err)
|
||||
}
|
||||
}
|
||||
tlsc.InsecureSkipVerify = true
|
||||
wsc = tls.Client(wsc, tlsc)
|
||||
if err := wsc.(*tls.Conn).Handshake(); err != nil {
|
||||
t.Fatalf("Error during handshake: %v", err)
|
||||
}
|
||||
req := testWSCreateValidReq()
|
||||
req.URL, _ = url.Parse("wss://" + addr)
|
||||
if err := req.Write(wsc); err != nil {
|
||||
t.Fatalf("Error sending request: %v", err)
|
||||
}
|
||||
br := bufio.NewReader(wsc)
|
||||
resp, err := http.ReadResponse(br, req)
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
if !test.provideCert {
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, did not get one")
|
||||
} else if !strings.Contains(err.Error(), "bad certificate") {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
t.Fatalf("Expected status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWSTLSVerifyAndMap(t *testing.T) {
|
||||
o := testWSOptions()
|
||||
certUserName := "CN=example.com,OU=NATS.io"
|
||||
o.Users = []*User{&User{Username: certUserName}}
|
||||
tc := &TLSConfigOpts{
|
||||
CertFile: "../test/configs/certs/tlsauth/server.pem",
|
||||
KeyFile: "../test/configs/certs/tlsauth/server-key.pem",
|
||||
CaFile: "../test/configs/certs/tlsauth/ca.pem",
|
||||
Verify: true,
|
||||
}
|
||||
tlsc, err := GenTLSConfig(tc)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating tls config: %v", err)
|
||||
}
|
||||
o.Websocket.TLSConfig = tlsc
|
||||
o.Websocket.TLSMap = true
|
||||
s := RunServer(o)
|
||||
defer s.Shutdown()
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port)
|
||||
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
provideCert bool
|
||||
}{
|
||||
{"client provides cert", true},
|
||||
{"client does not provide cert", false},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
wsc, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating ws connection: %v", err)
|
||||
}
|
||||
defer wsc.Close()
|
||||
tlsc := &tls.Config{}
|
||||
if test.provideCert {
|
||||
tc := &TLSConfigOpts{
|
||||
CertFile: "../test/configs/certs/tlsauth/client.pem",
|
||||
KeyFile: "../test/configs/certs/tlsauth/client-key.pem",
|
||||
}
|
||||
var err error
|
||||
tlsc, err = GenTLSConfig(tc)
|
||||
if err != nil {
|
||||
t.Fatalf("Error generating tls config: %v", err)
|
||||
}
|
||||
}
|
||||
tlsc.InsecureSkipVerify = true
|
||||
wsc = tls.Client(wsc, tlsc)
|
||||
if err := wsc.(*tls.Conn).Handshake(); err != nil {
|
||||
t.Fatalf("Error during handshake: %v", err)
|
||||
}
|
||||
req := testWSCreateValidReq()
|
||||
req.URL, _ = url.Parse("wss://" + addr)
|
||||
if err := req.Write(wsc); err != nil {
|
||||
t.Fatalf("Error sending request: %v", err)
|
||||
}
|
||||
br := bufio.NewReader(wsc)
|
||||
resp, err := http.ReadResponse(br, req)
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
if !test.provideCert {
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, did not get one")
|
||||
} else if !strings.Contains(err.Error(), "bad certificate") {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
t.Fatalf("Expected status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode)
|
||||
}
|
||||
// Wait for the INFO
|
||||
l := testWSReadFrame(t, br)
|
||||
if !bytes.HasPrefix(l, []byte("INFO {")) {
|
||||
t.Fatalf("Expected INFO, got %s", l)
|
||||
}
|
||||
var info serverInfo
|
||||
if err := json.Unmarshal(l[5:], &info); err != nil {
|
||||
t.Fatalf("Unable to unmarshal info: %v", err)
|
||||
}
|
||||
// Send CONNECT and PING
|
||||
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n"))
|
||||
if _, err := wsc.Write(wsmsg); err != nil {
|
||||
t.Fatalf("Error sending message: %v", err)
|
||||
}
|
||||
// Wait for the PONG
|
||||
if msg := testWSReadFrame(t, br); !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
|
||||
t.Fatalf("Expected PONG, got %s", msg)
|
||||
}
|
||||
|
||||
c := s.getClient(info.CID)
|
||||
c.mu.Lock()
|
||||
un := c.opts.Username
|
||||
c.mu.Unlock()
|
||||
if un != certUserName {
|
||||
t.Fatalf("Expected client's assigned username to be %q, got %q", certUserName, un)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWSHandshakeTimeout(t *testing.T) {
|
||||
o := testWSOptions()
|
||||
o.Websocket.HandshakeTimeout = time.Millisecond
|
||||
@@ -2406,6 +2596,521 @@ func TestWSCompressionFrameSizeLimit(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWSBasicAuth(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
opts func() *Options
|
||||
user string
|
||||
pass string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
"top level auth, no override, wrong u/p",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Username = "normal"
|
||||
o.Password = "client"
|
||||
return o
|
||||
},
|
||||
"websocket", "client", "-ERR 'Authorization Violation'",
|
||||
},
|
||||
{
|
||||
"top level auth, no override, correct u/p",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Username = "normal"
|
||||
o.Password = "client"
|
||||
return o
|
||||
},
|
||||
"normal", "client", "",
|
||||
},
|
||||
{
|
||||
"no top level auth, ws auth, wrong u/p",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Websocket.Username = "websocket"
|
||||
o.Websocket.Password = "client"
|
||||
return o
|
||||
},
|
||||
"normal", "client", "-ERR 'Authorization Violation'",
|
||||
},
|
||||
{
|
||||
"no top level auth, ws auth, correct u/p",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Websocket.Username = "websocket"
|
||||
o.Websocket.Password = "client"
|
||||
return o
|
||||
},
|
||||
"websocket", "client", "",
|
||||
},
|
||||
{
|
||||
"top level auth, ws override, wrong u/p",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Username = "normal"
|
||||
o.Password = "client"
|
||||
o.Websocket.Username = "websocket"
|
||||
o.Websocket.Password = "client"
|
||||
return o
|
||||
},
|
||||
"normal", "client", "-ERR 'Authorization Violation'",
|
||||
},
|
||||
{
|
||||
"top level auth, ws override, correct u/p",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Username = "normal"
|
||||
o.Password = "client"
|
||||
o.Websocket.Username = "websocket"
|
||||
o.Websocket.Password = "client"
|
||||
return o
|
||||
},
|
||||
"websocket", "client", "",
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
o := test.opts()
|
||||
s := RunServer(o)
|
||||
defer s.Shutdown()
|
||||
|
||||
wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
|
||||
defer wsc.Close()
|
||||
|
||||
connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n",
|
||||
test.user, test.pass)
|
||||
|
||||
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
|
||||
if _, err := wsc.Write(wsmsg); err != nil {
|
||||
t.Fatalf("Error sending message: %v", err)
|
||||
}
|
||||
msg := testWSReadFrame(t, br)
|
||||
if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
|
||||
t.Fatalf("Expected to receive PONG, got %q", msg)
|
||||
} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
|
||||
t.Fatalf("Expected to receive %q, got %q", test.err, msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWSAuthTimeout(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
at float64
|
||||
wat float64
|
||||
err string
|
||||
}{
|
||||
{"use top-level auth timeout", 10.0, 0.0, ""},
|
||||
{"use websocket auth timeout", 10.0, 0.05, "-ERR 'Authentication Timeout'"},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
o := testWSOptions()
|
||||
o.AuthTimeout = test.at
|
||||
o.Websocket.Username = "websocket"
|
||||
o.Websocket.Password = "client"
|
||||
o.Websocket.AuthTimeout = test.wat
|
||||
s := RunServer(o)
|
||||
defer s.Shutdown()
|
||||
|
||||
wsc, br, l := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
|
||||
defer wsc.Close()
|
||||
|
||||
var info serverInfo
|
||||
json.Unmarshal([]byte(l[5:]), &info)
|
||||
// Make sure that we are told that auth is required.
|
||||
if !info.AuthRequired {
|
||||
t.Fatalf("Expected auth required, was not: %q", l)
|
||||
}
|
||||
start := time.Now()
|
||||
// Wait before sending connect
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
connectProto := "CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"websocket\",\"pass\":\"client\"}\r\nPING\r\n"
|
||||
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
|
||||
if _, err := wsc.Write(wsmsg); err != nil {
|
||||
t.Fatalf("Error sending message: %v", err)
|
||||
}
|
||||
msg := testWSReadFrame(t, br)
|
||||
if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
|
||||
t.Fatalf("Expected to receive %q error, got %q", test.err, msg)
|
||||
} else if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
|
||||
t.Fatalf("Unexpected error: %q", msg)
|
||||
}
|
||||
if dur := time.Since(start); dur > time.Second {
|
||||
t.Fatalf("Too long to get timeout error: %v", dur)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWSTokenAuth(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
opts func() *Options
|
||||
token string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
"top level auth, no override, wrong token",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Authorization = "goodtoken"
|
||||
return o
|
||||
},
|
||||
"badtoken", "-ERR 'Authorization Violation'",
|
||||
},
|
||||
{
|
||||
"top level auth, no override, correct token",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Authorization = "goodtoken"
|
||||
return o
|
||||
},
|
||||
"goodtoken", "",
|
||||
},
|
||||
{
|
||||
"no top level auth, ws auth, wrong token",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Websocket.Token = "goodtoken"
|
||||
return o
|
||||
},
|
||||
"badtoken", "-ERR 'Authorization Violation'",
|
||||
},
|
||||
{
|
||||
"no top level auth, ws auth, correct token",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Websocket.Token = "goodtoken"
|
||||
return o
|
||||
},
|
||||
"goodtoken", "",
|
||||
},
|
||||
{
|
||||
"top level auth, ws override, wrong token",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Authorization = "clienttoken"
|
||||
o.Websocket.Token = "websockettoken"
|
||||
return o
|
||||
},
|
||||
"clienttoken", "-ERR 'Authorization Violation'",
|
||||
},
|
||||
{
|
||||
"top level auth, ws override, correct token",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Authorization = "clienttoken"
|
||||
o.Websocket.Token = "websockettoken"
|
||||
return o
|
||||
},
|
||||
"websockettoken", "",
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
o := test.opts()
|
||||
s := RunServer(o)
|
||||
defer s.Shutdown()
|
||||
|
||||
wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
|
||||
defer wsc.Close()
|
||||
|
||||
connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"auth_token\":\"%s\"}\r\nPING\r\n",
|
||||
test.token)
|
||||
|
||||
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
|
||||
if _, err := wsc.Write(wsmsg); err != nil {
|
||||
t.Fatalf("Error sending message: %v", err)
|
||||
}
|
||||
msg := testWSReadFrame(t, br)
|
||||
if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
|
||||
t.Fatalf("Expected to receive PONG, got %q", msg)
|
||||
} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
|
||||
t.Fatalf("Expected to receive %q, got %q", test.err, msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWSUsersAuth(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
opts func() *Options
|
||||
user string
|
||||
pass string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
"top level auth, no override, wrong user",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Users = []*User{
|
||||
&User{Username: "normal1", Password: "client1"},
|
||||
&User{Username: "normal2", Password: "client2"},
|
||||
}
|
||||
return o
|
||||
},
|
||||
"websocket", "client", "-ERR 'Authorization Violation'",
|
||||
},
|
||||
{
|
||||
"top level auth, no override, correct user",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Users = []*User{
|
||||
&User{Username: "normal1", Password: "client1"},
|
||||
&User{Username: "normal2", Password: "client2"},
|
||||
}
|
||||
return o
|
||||
},
|
||||
"normal2", "client2", "",
|
||||
},
|
||||
{
|
||||
"no top level auth, ws auth, wrong user",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Websocket.Users = []*User{
|
||||
&User{Username: "websocket1", Password: "client1"},
|
||||
&User{Username: "websocket2", Password: "client2"},
|
||||
}
|
||||
return o
|
||||
},
|
||||
"websocket", "client", "-ERR 'Authorization Violation'",
|
||||
},
|
||||
{
|
||||
"no top level auth, ws auth, correct token",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Websocket.Users = []*User{
|
||||
&User{Username: "websocket1", Password: "client1"},
|
||||
&User{Username: "websocket2", Password: "client2"},
|
||||
}
|
||||
return o
|
||||
},
|
||||
"websocket1", "client1", "",
|
||||
},
|
||||
{
|
||||
"top level auth, ws override, wrong user",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Users = []*User{
|
||||
&User{Username: "normal1", Password: "client1"},
|
||||
&User{Username: "normal2", Password: "client2"},
|
||||
}
|
||||
o.Websocket.Users = []*User{
|
||||
&User{Username: "websocket1", Password: "client1"},
|
||||
&User{Username: "websocket2", Password: "client2"},
|
||||
}
|
||||
return o
|
||||
},
|
||||
"normal2", "client2", "-ERR 'Authorization Violation'",
|
||||
},
|
||||
{
|
||||
"top level auth, ws override, correct token",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Users = []*User{
|
||||
&User{Username: "normal1", Password: "client1"},
|
||||
&User{Username: "normal2", Password: "client2"},
|
||||
}
|
||||
o.Websocket.Users = []*User{
|
||||
&User{Username: "websocket1", Password: "client1"},
|
||||
&User{Username: "websocket2", Password: "client2"},
|
||||
}
|
||||
return o
|
||||
},
|
||||
"websocket2", "client2", "",
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
o := test.opts()
|
||||
s := RunServer(o)
|
||||
defer s.Shutdown()
|
||||
|
||||
wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
|
||||
defer wsc.Close()
|
||||
|
||||
connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n",
|
||||
test.user, test.pass)
|
||||
|
||||
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
|
||||
if _, err := wsc.Write(wsmsg); err != nil {
|
||||
t.Fatalf("Error sending message: %v", err)
|
||||
}
|
||||
msg := testWSReadFrame(t, br)
|
||||
if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
|
||||
t.Fatalf("Expected to receive PONG, got %q", msg)
|
||||
} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
|
||||
t.Fatalf("Expected to receive %q, got %q", test.err, msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWSNoAuthUser(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
noAuthUser string
|
||||
wsNoAuthUser string
|
||||
user string
|
||||
acc string
|
||||
createWSUsers bool
|
||||
}{
|
||||
{"use top-level no_auth_user", "user1", "", "user1", "normal", false},
|
||||
{"use websocket no_auth_user no ws users", "user1", "user2", "user2", "normal", false},
|
||||
{"use websocket no_auth_user with ws users", "user1", "wsuser1", "wsuser1", "websocket", true},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
o := testWSOptions()
|
||||
normalAcc := NewAccount("normal")
|
||||
websocketAcc := NewAccount("websocket")
|
||||
o.Accounts = []*Account{normalAcc, websocketAcc}
|
||||
o.Users = []*User{
|
||||
&User{Username: "user1", Password: "pwd1", Account: normalAcc},
|
||||
&User{Username: "user2", Password: "pwd2", Account: normalAcc},
|
||||
}
|
||||
o.NoAuthUser = test.noAuthUser
|
||||
o.Websocket.NoAuthUser = test.wsNoAuthUser
|
||||
if test.createWSUsers {
|
||||
o.Websocket.Users = []*User{
|
||||
&User{Username: "wsuser1", Password: "pwd1", Account: websocketAcc},
|
||||
&User{Username: "wsuser2", Password: "pwd2", Account: websocketAcc},
|
||||
}
|
||||
}
|
||||
s := RunServer(o)
|
||||
defer s.Shutdown()
|
||||
|
||||
wsc, br, l := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
|
||||
defer wsc.Close()
|
||||
|
||||
var info serverInfo
|
||||
json.Unmarshal([]byte(l[5:]), &info)
|
||||
|
||||
connectProto := "CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n"
|
||||
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
|
||||
if _, err := wsc.Write(wsmsg); err != nil {
|
||||
t.Fatalf("Error sending message: %v", err)
|
||||
}
|
||||
msg := testWSReadFrame(t, br)
|
||||
if !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
|
||||
t.Fatalf("Unexpected error: %q", msg)
|
||||
}
|
||||
|
||||
c := s.getClient(info.CID)
|
||||
c.mu.Lock()
|
||||
uname := c.opts.Username
|
||||
aname := c.acc.GetName()
|
||||
c.mu.Unlock()
|
||||
if uname != test.user {
|
||||
t.Fatalf("Expected selected user to be %q, got %q", test.user, uname)
|
||||
}
|
||||
if aname != test.acc {
|
||||
t.Fatalf("Expected selected account to be %q, got %q", test.acc, aname)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWSNkeyAuth(t *testing.T) {
|
||||
nkp, _ := nkeys.CreateUser()
|
||||
pub, _ := nkp.PublicKey()
|
||||
|
||||
wsnkp, _ := nkeys.CreateUser()
|
||||
wspub, _ := wsnkp.PublicKey()
|
||||
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
opts func() *Options
|
||||
nkey string
|
||||
kp nkeys.KeyPair
|
||||
err string
|
||||
}{
|
||||
{
|
||||
"top level auth, no override, wrong nkey",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Nkeys = []*NkeyUser{&NkeyUser{Nkey: pub}}
|
||||
return o
|
||||
},
|
||||
wspub, wsnkp, "-ERR 'Authorization Violation'",
|
||||
},
|
||||
{
|
||||
"top level auth, no override, correct nkey",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Nkeys = []*NkeyUser{&NkeyUser{Nkey: pub}}
|
||||
return o
|
||||
},
|
||||
pub, nkp, "",
|
||||
},
|
||||
{
|
||||
"no top level auth, ws auth, wrong nkey",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Websocket.Nkeys = []*NkeyUser{&NkeyUser{Nkey: wspub}}
|
||||
return o
|
||||
},
|
||||
pub, nkp, "-ERR 'Authorization Violation'",
|
||||
},
|
||||
{
|
||||
"no top level auth, ws auth, correct nkey",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Websocket.Nkeys = []*NkeyUser{&NkeyUser{Nkey: wspub}}
|
||||
return o
|
||||
},
|
||||
wspub, wsnkp, "",
|
||||
},
|
||||
{
|
||||
"top level auth, ws override, wrong nkey",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Nkeys = []*NkeyUser{&NkeyUser{Nkey: pub}}
|
||||
o.Websocket.Nkeys = []*NkeyUser{&NkeyUser{Nkey: wspub}}
|
||||
return o
|
||||
},
|
||||
pub, nkp, "-ERR 'Authorization Violation'",
|
||||
},
|
||||
{
|
||||
"top level auth, ws override, correct nkey",
|
||||
func() *Options {
|
||||
o := testWSOptions()
|
||||
o.Nkeys = []*NkeyUser{&NkeyUser{Nkey: pub}}
|
||||
o.Websocket.Nkeys = []*NkeyUser{&NkeyUser{Nkey: wspub}}
|
||||
return o
|
||||
},
|
||||
wspub, wsnkp, "",
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
o := test.opts()
|
||||
s := RunServer(o)
|
||||
defer s.Shutdown()
|
||||
|
||||
wsc, br, infoMsg := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
|
||||
defer wsc.Close()
|
||||
|
||||
// Sign Nonce
|
||||
var info nonceInfo
|
||||
json.Unmarshal([]byte(infoMsg[5:]), &info)
|
||||
sigraw, _ := test.kp.Sign([]byte(info.Nonce))
|
||||
sig := base64.RawURLEncoding.EncodeToString(sigraw)
|
||||
|
||||
connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"nkey\":\"%s\",\"sig\":\"%s\"}\r\nPING\r\n", test.nkey, sig)
|
||||
|
||||
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
|
||||
if _, err := wsc.Write(wsmsg); err != nil {
|
||||
t.Fatalf("Error sending message: %v", err)
|
||||
}
|
||||
msg := testWSReadFrame(t, br)
|
||||
if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
|
||||
t.Fatalf("Expected to receive PONG, got %q", msg)
|
||||
} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
|
||||
t.Fatalf("Expected to receive %q, got %q", test.err, msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = Benchmark tests
|
||||
// ==================================================================
|
||||
|
||||
Reference in New Issue
Block a user