diff --git a/server/auth.go b/server/auth.go index 8f60bf8f..eb4c3b58 100644 --- a/server/auth.go +++ b/server/auth.go @@ -278,6 +278,9 @@ func (s *Server) configureAuthorization() { s.users = nil s.info.AuthRequired = false } + + // Do similar for websocket config + s.wsConfigAuth(&opts.Websocket) } // checkAuthentication will check based on client type and @@ -309,10 +312,64 @@ func (s *Server) isClientAuthorized(c *client) bool { return opts.CustomClientAuthentication.Check(c) } - return s.processClientOrLeafAuthentication(c) + return s.processClientOrLeafAuthentication(c, opts) } -func (s *Server) processClientOrLeafAuthentication(c *client) bool { +type authOpts struct { + username string + password string + token string + noAuthUser string + tlsMap bool + users map[string]*User + nkeys map[string]*NkeyUser +} + +func (s *Server) getAuthOpts(c *client, o *Options, auth *authOpts) bool { + // c.ws is immutable, but may need lock if we get race reports. + wsClient := c.ws != nil + + authRequired := s.info.AuthRequired + // For websocket clients, if there is no top-level auth, then we + // check for websocket specifically. + if !authRequired && wsClient { + authRequired = s.websocket.authRequired + } + if !authRequired { + return false + } + auth.noAuthUser = o.NoAuthUser + auth.tlsMap = o.TLSMap + if wsClient { + wo := &o.Websocket + // If those are specified, override, regardless if there is + // auth configuration (like user/pwd, etc..) in websocket section. + if wo.NoAuthUser != _EMPTY_ { + auth.noAuthUser = wo.NoAuthUser + } + if wo.TLSMap { + auth.tlsMap = true + } + // Now check for websocket auth specific override + if s.websocket.authRequired { + auth.username = wo.Username + auth.password = wo.Password + auth.token = wo.Token + auth.users = s.websocket.users + auth.nkeys = s.websocket.nkeys + return true + } + // else fallback to regular auth config + } + auth.username = o.Username + auth.password = o.Password + auth.token = o.Authorization + auth.users = s.users + auth.nkeys = s.nkeys + return true +} + +func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) bool { var ( nkey *NkeyUser juc *jwt.UserClaims @@ -320,11 +377,11 @@ func (s *Server) processClientOrLeafAuthentication(c *client) bool { user *User ok bool err error - opts = s.getOpts() + auth authOpts ) s.mu.Lock() - authRequired := s.info.AuthRequired + authRequired := s.getAuthOpts(c, opts, &auth) if !authRequired { // TODO(dlc) - If they send us credentials should we fail? s.mu.Unlock() @@ -355,21 +412,21 @@ func (s *Server) processClientOrLeafAuthentication(c *client) bool { } // Check if we have nkeys or users for client. - hasNkeys := s.nkeys != nil - hasUsers := s.users != nil + hasNkeys := auth.nkeys != nil + hasUsers := auth.users != nil if hasNkeys && c.opts.Nkey != "" { - nkey, ok = s.nkeys[c.opts.Nkey] + nkey, ok = auth.nkeys[c.opts.Nkey] if !ok { s.mu.Unlock() return false } } else if hasUsers { // Check if we are tls verify and are mapping users from the client_certificate - if opts.TLSMap { + if auth.tlsMap { var euser string authorized := checkClientTLSCertSubject(c, func(u string) bool { var ok bool - user, ok = s.users[u] + user, ok = auth.users[u] if !ok { c.Debugf("User in cert [%q], not found", u) return false @@ -388,14 +445,14 @@ func (s *Server) processClientOrLeafAuthentication(c *client) bool { // but we set it here to be able to identify it in the logs. c.opts.Username = euser } else { - if c.kind == CLIENT && c.opts.Username == "" && s.opts.NoAuthUser != "" { - if u, exists := s.users[s.opts.NoAuthUser]; exists { + if c.kind == CLIENT && c.opts.Username == "" && auth.noAuthUser != "" { + if u, exists := auth.users[auth.noAuthUser]; exists { c.opts.Username = u.Username c.opts.Password = u.Password } } if c.opts.Username != "" { - user, ok = s.users[c.opts.Username] + user, ok = auth.users[c.opts.Username] if !ok { s.mu.Unlock() return false @@ -517,13 +574,13 @@ func (s *Server) processClientOrLeafAuthentication(c *client) bool { } if c.kind == CLIENT { - if opts.Authorization != "" { - return comparePasswords(opts.Authorization, c.opts.Token) - } else if opts.Username != "" { - if opts.Username != c.opts.Username { + if auth.token != "" { + return comparePasswords(auth.token, c.opts.Token) + } else if auth.username != "" { + if auth.username != c.opts.Username { return false } - return comparePasswords(opts.Password, c.opts.Password) + return comparePasswords(auth.password, c.opts.Password) } } else if c.kind == LEAF { // There is no required username/password to connect and @@ -733,7 +790,7 @@ func (s *Server) isLeafNodeAuthorized(c *client) bool { // Still, if the CONNECT has some user info, we will bind to the // user's account or to the specified default account (if provided) // or to the global account. - return s.processClientOrLeafAuthentication(c) + return s.processClientOrLeafAuthentication(c, opts) } // Support for bcrypt stored passwords and tokens. diff --git a/server/opts.go b/server/opts.go index 0ee18590..2702b484 100644 --- a/server/opts.go +++ b/server/opts.go @@ -257,8 +257,26 @@ type WebsocketOpts struct { // The host:port to advertise to websocket clients in the cluster. Advertise string + // If no user is provided when a client connects, will default to this + // user and associated account. This user has to exist either in the + // Users defined here or in the global options. + NoAuthUser string + + // Authentication section. If anything is configured in this section, + // it will override the authorization configuration for regular clients. + Username string + Password string + Token string + Users []*User + Nkeys []*NkeyUser + + // Timeout for the authentication process. + AuthTimeout float64 + // TLS configuration is required. TLSConfig *tls.Config + // If true, map certificate values for authentication purposes. + TLSMap bool // If true, the Origin header must match the request's host. SameOrigin bool @@ -3025,12 +3043,17 @@ func parseWebsocket(v interface{}, o *Options, errors *[]error, warnings *[]erro case "advertise": o.Websocket.Advertise = mv.(string) case "tls": - config, _, err := getTLSConfig(tk) + tc, err := parseTLS(tk) if err != nil { *errors = append(*errors, err) continue } - o.Websocket.TLSConfig = config + if o.Websocket.TLSConfig, err = GenTLSConfig(tc); err != nil { + err := &configErr{tk, err.Error()} + *errors = append(*errors, err) + continue + } + o.Websocket.TLSMap = tc.Map case "same_origin": o.Websocket.SameOrigin = mv.(bool) case "allowed_origins", "allowed_origin", "allow_origins", "allow_origin", "origins", "origin": @@ -3074,6 +3097,42 @@ func parseWebsocket(v interface{}, o *Options, errors *[]error, warnings *[]erro o.Websocket.HandshakeTimeout = ht case "compression": o.Websocket.Compression = mv.(bool) + case "authorization", "authentication": + auth, err := parseAuthorization(tk, o, errors, warnings) + if err != nil { + *errors = append(*errors, err) + continue + } + o.Websocket.Username = auth.user + o.Websocket.Password = auth.pass + o.Websocket.Token = auth.token + if (auth.user != "" || auth.pass != "") && auth.token != "" { + err := &configErr{tk, "Cannot have a user/pass and token"} + *errors = append(*errors, err) + continue + } + o.Websocket.AuthTimeout = auth.timeout + // Check for multiple users defined + if auth.users != nil { + if auth.user != "" { + err := &configErr{tk, "Can not have a single user/pass and a users array"} + *errors = append(*errors, err) + continue + } + if auth.token != "" { + err := &configErr{tk, "Can not have a token and a users array"} + *errors = append(*errors, err) + continue + } + // Users may have been added from Accounts parsing, so do an append here + o.Websocket.Users = append(o.Websocket.Users, auth.users...) + } + // Check for nkeys + if auth.nkeys != nil { + o.Websocket.Nkeys = append(o.Websocket.Nkeys, auth.nkeys...) + } + case "no_auth_user": + o.Websocket.NoAuthUser = mv.(string) default: if !tk.IsUsedVariable() { err := &unknownConfigFieldErr{ diff --git a/server/server.go b/server/server.go index 50fe1046..53f84839 100644 --- a/server/server.go +++ b/server/server.go @@ -1875,6 +1875,12 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { // Grab JSON info string s.mu.Lock() info := s.copyInfo() + // If this is a websocket client and there is no top-level auth specified, + // then we use the websocket's specific boolean that will be set to true + // if there is any auth{} configured in websocket{}. + if ws != nil && !info.AuthRequired { + info.AuthRequired = s.websocket.authRequired + } if s.nonceRequired() { // Nonce handling var raw [nonceLen]byte @@ -1974,7 +1980,15 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { // the race where the timer fires during the handshake and causes the // server to write bad data to the socket. See issue #432. if info.AuthRequired { - c.setAuthTimer(secondsToDuration(opts.AuthTimeout)) + timeout := opts.AuthTimeout + // For websocket, possibly override only if set. We make sure that + // opts.AuthTimeout is set to a default value if not configured, + // but we don't do the same for websocket's one so that we know + // if user has explicitly set or not. + if ws != nil && opts.Websocket.AuthTimeout != 0 { + timeout = opts.Websocket.AuthTimeout + } + c.setAuthTimer(secondsToDuration(timeout)) } // Do final client initialization diff --git a/server/websocket.go b/server/websocket.go index 18a270c6..3eacbd43 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -101,6 +101,9 @@ type srvWebsocket struct { sameOrigin bool connectURLs []string connectURLsMap map[string]struct{} + users map[string]*User + nkeys map[string]*NkeyUser + authRequired bool // indicate if there is auth override in websocket config } type allowedOrigin struct { @@ -763,6 +766,55 @@ func (s *Server) wsSetOriginOptions(o *WebsocketOpts) { } } +// Given the websocket options, we check if any auth configuration +// has been provided. If so, possibly create users/nkey users and +// store them in s.websocket.users/nkeys. +// Also update a boolean that indicates if auth is required for +// websocket clients. +func (s *Server) wsConfigAuth(opts *WebsocketOpts) { + if opts.Nkeys != nil || opts.Users != nil { + // Support both at the same time. + if opts.Nkeys != nil { + s.websocket.nkeys = make(map[string]*NkeyUser) + for _, u := range opts.Nkeys { + copy := u.clone() + if u.Account != nil { + if v, ok := s.accounts.Load(u.Account.Name); ok { + copy.Account = v.(*Account) + } + } + if copy.Permissions != nil { + validateResponsePermissions(copy.Permissions) + } + s.websocket.nkeys[u.Nkey] = copy + } + } + if opts.Users != nil { + s.websocket.users = make(map[string]*User) + for _, u := range opts.Users { + copy := u.clone() + if u.Account != nil { + if v, ok := s.accounts.Load(u.Account.Name); ok { + copy.Account = v.(*Account) + } + } + if copy.Permissions != nil { + validateResponsePermissions(copy.Permissions) + } + s.websocket.users[u.Username] = copy + } + } + s.assignGlobalAccountToOrphanUsers() + s.websocket.authRequired = true + } else if opts.Username != "" || opts.Token != "" { + s.websocket.authRequired = true + } else { + s.websocket.users = nil + s.websocket.nkeys = nil + s.websocket.authRequired = false + } +} + func (s *Server) startWebsocketServer() { sopts := s.getOpts() o := &sopts.Websocket @@ -787,7 +839,6 @@ func (s *Server) startWebsocketServer() { if o.TLSConfig != nil { proto = "wss" config := o.TLSConfig.Clone() - config.ClientAuth = tls.NoClientCert hl, err = tls.Listen("tcp", hp, config) } else { proto = "ws" diff --git a/server/websocket_test.go b/server/websocket_test.go index 1f81fe67..7e646bed 100644 --- a/server/websocket_test.go +++ b/server/websocket_test.go @@ -18,6 +18,7 @@ import ( "bytes" "compress/flate" "crypto/tls" + "encoding/base64" "encoding/binary" "encoding/json" "errors" @@ -34,6 +35,8 @@ import ( "sync" "testing" "time" + + "github.com/nats-io/nkeys" ) type testReader struct { @@ -1581,7 +1584,7 @@ func TestWSAbnormalFailureOfWebServer(t *testing.T) { } } -func testWSCreateClient(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader) { +func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader, []byte) { t.Helper() addr := fmt.Sprintf("%s:%d", host, port) wsc, err := net.Dial("tcp", addr) @@ -1613,9 +1616,15 @@ func testWSCreateClient(t testing.TB, compress, web bool, host string, port int) t.Fatalf("Expected response status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode) } // Wait for the INFO - if msg := testWSReadFrame(t, br); !bytes.HasPrefix(msg, []byte("INFO {")) { - t.Fatalf("Expected INFO, got %s", msg) + info := testWSReadFrame(t, br) + if !bytes.HasPrefix(info, []byte("INFO {")) { + t.Fatalf("Expected INFO, got %s", info) } + return wsc, br, info +} + +func testWSCreateClient(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader) { + wsc, br, _ := testWSCreateClientGetInfo(t, compress, web, host, port) // Send CONNECT and PING wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, []byte("CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n")) if _, err := wsc.Write(wsmsg); err != nil { @@ -1792,6 +1801,187 @@ func TestWSTLSConnection(t *testing.T) { } } +func TestWSTLSVerifyClientCert(t *testing.T) { + o := testWSOptions() + tc := &TLSConfigOpts{ + CertFile: "../test/configs/certs/server-cert.pem", + KeyFile: "../test/configs/certs/server-key.pem", + CaFile: "../test/configs/certs/ca.pem", + Verify: true, + } + tlsc, err := GenTLSConfig(tc) + if err != nil { + t.Fatalf("Error creating tls config: %v", err) + } + o.Websocket.TLSConfig = tlsc + s := RunServer(o) + defer s.Shutdown() + + addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port) + + for _, test := range []struct { + name string + provideCert bool + }{ + {"client provides cert", true}, + {"client does not provide cert", false}, + } { + t.Run(test.name, func(t *testing.T) { + wsc, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Error creating ws connection: %v", err) + } + defer wsc.Close() + tlsc := &tls.Config{} + if test.provideCert { + tc := &TLSConfigOpts{ + CertFile: "../test/configs/certs/client-cert.pem", + KeyFile: "../test/configs/certs/client-key.pem", + } + var err error + tlsc, err = GenTLSConfig(tc) + if err != nil { + t.Fatalf("Error generating tls config: %v", err) + } + } + tlsc.InsecureSkipVerify = true + wsc = tls.Client(wsc, tlsc) + if err := wsc.(*tls.Conn).Handshake(); err != nil { + t.Fatalf("Error during handshake: %v", err) + } + req := testWSCreateValidReq() + req.URL, _ = url.Parse("wss://" + addr) + if err := req.Write(wsc); err != nil { + t.Fatalf("Error sending request: %v", err) + } + br := bufio.NewReader(wsc) + resp, err := http.ReadResponse(br, req) + if resp != nil { + resp.Body.Close() + } + if !test.provideCert { + if err == nil { + t.Fatal("Expected error, did not get one") + } else if !strings.Contains(err.Error(), "bad certificate") { + t.Fatalf("Unexpected error: %v", err) + } + return + } + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if resp.StatusCode != http.StatusSwitchingProtocols { + t.Fatalf("Expected status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode) + } + }) + } +} + +func TestWSTLSVerifyAndMap(t *testing.T) { + o := testWSOptions() + certUserName := "CN=example.com,OU=NATS.io" + o.Users = []*User{&User{Username: certUserName}} + tc := &TLSConfigOpts{ + CertFile: "../test/configs/certs/tlsauth/server.pem", + KeyFile: "../test/configs/certs/tlsauth/server-key.pem", + CaFile: "../test/configs/certs/tlsauth/ca.pem", + Verify: true, + } + tlsc, err := GenTLSConfig(tc) + if err != nil { + t.Fatalf("Error creating tls config: %v", err) + } + o.Websocket.TLSConfig = tlsc + o.Websocket.TLSMap = true + s := RunServer(o) + defer s.Shutdown() + + addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port) + + for _, test := range []struct { + name string + provideCert bool + }{ + {"client provides cert", true}, + {"client does not provide cert", false}, + } { + t.Run(test.name, func(t *testing.T) { + wsc, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Error creating ws connection: %v", err) + } + defer wsc.Close() + tlsc := &tls.Config{} + if test.provideCert { + tc := &TLSConfigOpts{ + CertFile: "../test/configs/certs/tlsauth/client.pem", + KeyFile: "../test/configs/certs/tlsauth/client-key.pem", + } + var err error + tlsc, err = GenTLSConfig(tc) + if err != nil { + t.Fatalf("Error generating tls config: %v", err) + } + } + tlsc.InsecureSkipVerify = true + wsc = tls.Client(wsc, tlsc) + if err := wsc.(*tls.Conn).Handshake(); err != nil { + t.Fatalf("Error during handshake: %v", err) + } + req := testWSCreateValidReq() + req.URL, _ = url.Parse("wss://" + addr) + if err := req.Write(wsc); err != nil { + t.Fatalf("Error sending request: %v", err) + } + br := bufio.NewReader(wsc) + resp, err := http.ReadResponse(br, req) + if resp != nil { + resp.Body.Close() + } + if !test.provideCert { + if err == nil { + t.Fatal("Expected error, did not get one") + } else if !strings.Contains(err.Error(), "bad certificate") { + t.Fatalf("Unexpected error: %v", err) + } + return + } + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if resp.StatusCode != http.StatusSwitchingProtocols { + t.Fatalf("Expected status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode) + } + // Wait for the INFO + l := testWSReadFrame(t, br) + if !bytes.HasPrefix(l, []byte("INFO {")) { + t.Fatalf("Expected INFO, got %s", l) + } + var info serverInfo + if err := json.Unmarshal(l[5:], &info); err != nil { + t.Fatalf("Unable to unmarshal info: %v", err) + } + // Send CONNECT and PING + wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n")) + if _, err := wsc.Write(wsmsg); err != nil { + t.Fatalf("Error sending message: %v", err) + } + // Wait for the PONG + if msg := testWSReadFrame(t, br); !bytes.HasPrefix(msg, []byte("PONG\r\n")) { + t.Fatalf("Expected PONG, got %s", msg) + } + + c := s.getClient(info.CID) + c.mu.Lock() + un := c.opts.Username + c.mu.Unlock() + if un != certUserName { + t.Fatalf("Expected client's assigned username to be %q, got %q", certUserName, un) + } + }) + } +} + func TestWSHandshakeTimeout(t *testing.T) { o := testWSOptions() o.Websocket.HandshakeTimeout = time.Millisecond @@ -2406,6 +2596,521 @@ func TestWSCompressionFrameSizeLimit(t *testing.T) { } } +func TestWSBasicAuth(t *testing.T) { + for _, test := range []struct { + name string + opts func() *Options + user string + pass string + err string + }{ + { + "top level auth, no override, wrong u/p", + func() *Options { + o := testWSOptions() + o.Username = "normal" + o.Password = "client" + return o + }, + "websocket", "client", "-ERR 'Authorization Violation'", + }, + { + "top level auth, no override, correct u/p", + func() *Options { + o := testWSOptions() + o.Username = "normal" + o.Password = "client" + return o + }, + "normal", "client", "", + }, + { + "no top level auth, ws auth, wrong u/p", + func() *Options { + o := testWSOptions() + o.Websocket.Username = "websocket" + o.Websocket.Password = "client" + return o + }, + "normal", "client", "-ERR 'Authorization Violation'", + }, + { + "no top level auth, ws auth, correct u/p", + func() *Options { + o := testWSOptions() + o.Websocket.Username = "websocket" + o.Websocket.Password = "client" + return o + }, + "websocket", "client", "", + }, + { + "top level auth, ws override, wrong u/p", + func() *Options { + o := testWSOptions() + o.Username = "normal" + o.Password = "client" + o.Websocket.Username = "websocket" + o.Websocket.Password = "client" + return o + }, + "normal", "client", "-ERR 'Authorization Violation'", + }, + { + "top level auth, ws override, correct u/p", + func() *Options { + o := testWSOptions() + o.Username = "normal" + o.Password = "client" + o.Websocket.Username = "websocket" + o.Websocket.Password = "client" + return o + }, + "websocket", "client", "", + }, + } { + t.Run(test.name, func(t *testing.T) { + o := test.opts() + s := RunServer(o) + defer s.Shutdown() + + wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port) + defer wsc.Close() + + connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n", + test.user, test.pass) + + wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto)) + if _, err := wsc.Write(wsmsg); err != nil { + t.Fatalf("Error sending message: %v", err) + } + msg := testWSReadFrame(t, br) + if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) { + t.Fatalf("Expected to receive PONG, got %q", msg) + } else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) { + t.Fatalf("Expected to receive %q, got %q", test.err, msg) + } + }) + } +} + +func TestWSAuthTimeout(t *testing.T) { + for _, test := range []struct { + name string + at float64 + wat float64 + err string + }{ + {"use top-level auth timeout", 10.0, 0.0, ""}, + {"use websocket auth timeout", 10.0, 0.05, "-ERR 'Authentication Timeout'"}, + } { + t.Run(test.name, func(t *testing.T) { + o := testWSOptions() + o.AuthTimeout = test.at + o.Websocket.Username = "websocket" + o.Websocket.Password = "client" + o.Websocket.AuthTimeout = test.wat + s := RunServer(o) + defer s.Shutdown() + + wsc, br, l := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port) + defer wsc.Close() + + var info serverInfo + json.Unmarshal([]byte(l[5:]), &info) + // Make sure that we are told that auth is required. + if !info.AuthRequired { + t.Fatalf("Expected auth required, was not: %q", l) + } + start := time.Now() + // Wait before sending connect + time.Sleep(100 * time.Millisecond) + connectProto := "CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"websocket\",\"pass\":\"client\"}\r\nPING\r\n" + wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto)) + if _, err := wsc.Write(wsmsg); err != nil { + t.Fatalf("Error sending message: %v", err) + } + msg := testWSReadFrame(t, br) + if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) { + t.Fatalf("Expected to receive %q error, got %q", test.err, msg) + } else if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) { + t.Fatalf("Unexpected error: %q", msg) + } + if dur := time.Since(start); dur > time.Second { + t.Fatalf("Too long to get timeout error: %v", dur) + } + }) + } +} + +func TestWSTokenAuth(t *testing.T) { + for _, test := range []struct { + name string + opts func() *Options + token string + err string + }{ + { + "top level auth, no override, wrong token", + func() *Options { + o := testWSOptions() + o.Authorization = "goodtoken" + return o + }, + "badtoken", "-ERR 'Authorization Violation'", + }, + { + "top level auth, no override, correct token", + func() *Options { + o := testWSOptions() + o.Authorization = "goodtoken" + return o + }, + "goodtoken", "", + }, + { + "no top level auth, ws auth, wrong token", + func() *Options { + o := testWSOptions() + o.Websocket.Token = "goodtoken" + return o + }, + "badtoken", "-ERR 'Authorization Violation'", + }, + { + "no top level auth, ws auth, correct token", + func() *Options { + o := testWSOptions() + o.Websocket.Token = "goodtoken" + return o + }, + "goodtoken", "", + }, + { + "top level auth, ws override, wrong token", + func() *Options { + o := testWSOptions() + o.Authorization = "clienttoken" + o.Websocket.Token = "websockettoken" + return o + }, + "clienttoken", "-ERR 'Authorization Violation'", + }, + { + "top level auth, ws override, correct token", + func() *Options { + o := testWSOptions() + o.Authorization = "clienttoken" + o.Websocket.Token = "websockettoken" + return o + }, + "websockettoken", "", + }, + } { + t.Run(test.name, func(t *testing.T) { + o := test.opts() + s := RunServer(o) + defer s.Shutdown() + + wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port) + defer wsc.Close() + + connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"auth_token\":\"%s\"}\r\nPING\r\n", + test.token) + + wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto)) + if _, err := wsc.Write(wsmsg); err != nil { + t.Fatalf("Error sending message: %v", err) + } + msg := testWSReadFrame(t, br) + if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) { + t.Fatalf("Expected to receive PONG, got %q", msg) + } else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) { + t.Fatalf("Expected to receive %q, got %q", test.err, msg) + } + }) + } +} + +func TestWSUsersAuth(t *testing.T) { + for _, test := range []struct { + name string + opts func() *Options + user string + pass string + err string + }{ + { + "top level auth, no override, wrong user", + func() *Options { + o := testWSOptions() + o.Users = []*User{ + &User{Username: "normal1", Password: "client1"}, + &User{Username: "normal2", Password: "client2"}, + } + return o + }, + "websocket", "client", "-ERR 'Authorization Violation'", + }, + { + "top level auth, no override, correct user", + func() *Options { + o := testWSOptions() + o.Users = []*User{ + &User{Username: "normal1", Password: "client1"}, + &User{Username: "normal2", Password: "client2"}, + } + return o + }, + "normal2", "client2", "", + }, + { + "no top level auth, ws auth, wrong user", + func() *Options { + o := testWSOptions() + o.Websocket.Users = []*User{ + &User{Username: "websocket1", Password: "client1"}, + &User{Username: "websocket2", Password: "client2"}, + } + return o + }, + "websocket", "client", "-ERR 'Authorization Violation'", + }, + { + "no top level auth, ws auth, correct token", + func() *Options { + o := testWSOptions() + o.Websocket.Users = []*User{ + &User{Username: "websocket1", Password: "client1"}, + &User{Username: "websocket2", Password: "client2"}, + } + return o + }, + "websocket1", "client1", "", + }, + { + "top level auth, ws override, wrong user", + func() *Options { + o := testWSOptions() + o.Users = []*User{ + &User{Username: "normal1", Password: "client1"}, + &User{Username: "normal2", Password: "client2"}, + } + o.Websocket.Users = []*User{ + &User{Username: "websocket1", Password: "client1"}, + &User{Username: "websocket2", Password: "client2"}, + } + return o + }, + "normal2", "client2", "-ERR 'Authorization Violation'", + }, + { + "top level auth, ws override, correct token", + func() *Options { + o := testWSOptions() + o.Users = []*User{ + &User{Username: "normal1", Password: "client1"}, + &User{Username: "normal2", Password: "client2"}, + } + o.Websocket.Users = []*User{ + &User{Username: "websocket1", Password: "client1"}, + &User{Username: "websocket2", Password: "client2"}, + } + return o + }, + "websocket2", "client2", "", + }, + } { + t.Run(test.name, func(t *testing.T) { + o := test.opts() + s := RunServer(o) + defer s.Shutdown() + + wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port) + defer wsc.Close() + + connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n", + test.user, test.pass) + + wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto)) + if _, err := wsc.Write(wsmsg); err != nil { + t.Fatalf("Error sending message: %v", err) + } + msg := testWSReadFrame(t, br) + if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) { + t.Fatalf("Expected to receive PONG, got %q", msg) + } else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) { + t.Fatalf("Expected to receive %q, got %q", test.err, msg) + } + }) + } +} + +func TestWSNoAuthUser(t *testing.T) { + for _, test := range []struct { + name string + noAuthUser string + wsNoAuthUser string + user string + acc string + createWSUsers bool + }{ + {"use top-level no_auth_user", "user1", "", "user1", "normal", false}, + {"use websocket no_auth_user no ws users", "user1", "user2", "user2", "normal", false}, + {"use websocket no_auth_user with ws users", "user1", "wsuser1", "wsuser1", "websocket", true}, + } { + t.Run(test.name, func(t *testing.T) { + o := testWSOptions() + normalAcc := NewAccount("normal") + websocketAcc := NewAccount("websocket") + o.Accounts = []*Account{normalAcc, websocketAcc} + o.Users = []*User{ + &User{Username: "user1", Password: "pwd1", Account: normalAcc}, + &User{Username: "user2", Password: "pwd2", Account: normalAcc}, + } + o.NoAuthUser = test.noAuthUser + o.Websocket.NoAuthUser = test.wsNoAuthUser + if test.createWSUsers { + o.Websocket.Users = []*User{ + &User{Username: "wsuser1", Password: "pwd1", Account: websocketAcc}, + &User{Username: "wsuser2", Password: "pwd2", Account: websocketAcc}, + } + } + s := RunServer(o) + defer s.Shutdown() + + wsc, br, l := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port) + defer wsc.Close() + + var info serverInfo + json.Unmarshal([]byte(l[5:]), &info) + + connectProto := "CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n" + wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto)) + if _, err := wsc.Write(wsmsg); err != nil { + t.Fatalf("Error sending message: %v", err) + } + msg := testWSReadFrame(t, br) + if !bytes.HasPrefix(msg, []byte("PONG\r\n")) { + t.Fatalf("Unexpected error: %q", msg) + } + + c := s.getClient(info.CID) + c.mu.Lock() + uname := c.opts.Username + aname := c.acc.GetName() + c.mu.Unlock() + if uname != test.user { + t.Fatalf("Expected selected user to be %q, got %q", test.user, uname) + } + if aname != test.acc { + t.Fatalf("Expected selected account to be %q, got %q", test.acc, aname) + } + }) + } +} + +func TestWSNkeyAuth(t *testing.T) { + nkp, _ := nkeys.CreateUser() + pub, _ := nkp.PublicKey() + + wsnkp, _ := nkeys.CreateUser() + wspub, _ := wsnkp.PublicKey() + + for _, test := range []struct { + name string + opts func() *Options + nkey string + kp nkeys.KeyPair + err string + }{ + { + "top level auth, no override, wrong nkey", + func() *Options { + o := testWSOptions() + o.Nkeys = []*NkeyUser{&NkeyUser{Nkey: pub}} + return o + }, + wspub, wsnkp, "-ERR 'Authorization Violation'", + }, + { + "top level auth, no override, correct nkey", + func() *Options { + o := testWSOptions() + o.Nkeys = []*NkeyUser{&NkeyUser{Nkey: pub}} + return o + }, + pub, nkp, "", + }, + { + "no top level auth, ws auth, wrong nkey", + func() *Options { + o := testWSOptions() + o.Websocket.Nkeys = []*NkeyUser{&NkeyUser{Nkey: wspub}} + return o + }, + pub, nkp, "-ERR 'Authorization Violation'", + }, + { + "no top level auth, ws auth, correct nkey", + func() *Options { + o := testWSOptions() + o.Websocket.Nkeys = []*NkeyUser{&NkeyUser{Nkey: wspub}} + return o + }, + wspub, wsnkp, "", + }, + { + "top level auth, ws override, wrong nkey", + func() *Options { + o := testWSOptions() + o.Nkeys = []*NkeyUser{&NkeyUser{Nkey: pub}} + o.Websocket.Nkeys = []*NkeyUser{&NkeyUser{Nkey: wspub}} + return o + }, + pub, nkp, "-ERR 'Authorization Violation'", + }, + { + "top level auth, ws override, correct nkey", + func() *Options { + o := testWSOptions() + o.Nkeys = []*NkeyUser{&NkeyUser{Nkey: pub}} + o.Websocket.Nkeys = []*NkeyUser{&NkeyUser{Nkey: wspub}} + return o + }, + wspub, wsnkp, "", + }, + } { + t.Run(test.name, func(t *testing.T) { + o := test.opts() + s := RunServer(o) + defer s.Shutdown() + + wsc, br, infoMsg := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port) + defer wsc.Close() + + // Sign Nonce + var info nonceInfo + json.Unmarshal([]byte(infoMsg[5:]), &info) + sigraw, _ := test.kp.Sign([]byte(info.Nonce)) + sig := base64.RawURLEncoding.EncodeToString(sigraw) + + connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"nkey\":\"%s\",\"sig\":\"%s\"}\r\nPING\r\n", test.nkey, sig) + + wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto)) + if _, err := wsc.Write(wsmsg); err != nil { + t.Fatalf("Error sending message: %v", err) + } + msg := testWSReadFrame(t, br) + if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) { + t.Fatalf("Expected to receive PONG, got %q", msg) + } else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) { + t.Fatalf("Expected to receive %q, got %q", test.err, msg) + } + }) + } +} + // ================================================================== // = Benchmark tests // ==================================================================