mirror of
https://github.com/gogrlx/nats-server.git
synced 2026-04-16 11:04:42 -07:00
Merge pull request #1477 from pas2k/ws_cookie_auth
[ADDED] Cookie JWT auth for WebSocket
This commit is contained in:
@@ -1497,7 +1497,13 @@ func (c *client) processConnect(arg []byte) error {
|
||||
lang := c.opts.Lang
|
||||
account := c.opts.Account
|
||||
accountNew := c.opts.AccountNew
|
||||
|
||||
// If websocket client and JWT not in the CONNECT, use the cookie JWT (possibly empty).
|
||||
if ws := c.ws; ws != nil && c.opts.JWT == "" {
|
||||
c.opts.JWT = ws.cookieJwt
|
||||
}
|
||||
ujwt := c.opts.JWT
|
||||
|
||||
// For headers both client and server need to support.
|
||||
c.headers = supportsHeaders && c.opts.Headers
|
||||
c.mu.Unlock()
|
||||
|
||||
@@ -265,6 +265,11 @@ type WebsocketOpts struct {
|
||||
// Users defined here or in the global options.
|
||||
NoAuthUser string
|
||||
|
||||
// Name of the cookie, which if present in WebSocket upgrade headers,
|
||||
// will be treated as JWT during CONNECT phase as long as
|
||||
// "jwt" specified in the CONNECT options is missing or empty.
|
||||
JWTCookie string
|
||||
|
||||
// Authentication section. If anything is configured in this section,
|
||||
// it will override the authorization configuration for regular clients.
|
||||
Username string
|
||||
@@ -3152,6 +3157,8 @@ func parseWebsocket(v interface{}, o *Options, errors *[]error, warnings *[]erro
|
||||
if auth.nkeys != nil {
|
||||
o.Websocket.Nkeys = append(o.Websocket.Nkeys, auth.nkeys...)
|
||||
}
|
||||
case "jwt_cookie":
|
||||
o.Websocket.JWTCookie = mv.(string)
|
||||
case "no_auth_user":
|
||||
o.Websocket.NoAuthUser = mv.(string)
|
||||
default:
|
||||
|
||||
@@ -90,6 +90,7 @@ type websocket struct {
|
||||
closeSent bool
|
||||
browser bool
|
||||
compressor *flate.Writer
|
||||
cookieJwt string
|
||||
}
|
||||
|
||||
type srvWebsocket struct {
|
||||
@@ -597,6 +598,11 @@ func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeRe
|
||||
if ua := r.Header.Get("User-Agent"); ua != "" && strings.HasPrefix(ua, "Mozilla/") {
|
||||
ws.browser = true
|
||||
}
|
||||
if opts.Websocket.JWTCookie != "" {
|
||||
if c, err := r.Cookie(opts.Websocket.JWTCookie); err == nil && c != nil {
|
||||
ws.cookieJwt = c.Value
|
||||
}
|
||||
}
|
||||
return &wsUpgradeResult{conn: conn, ws: ws}, nil
|
||||
}
|
||||
|
||||
@@ -748,6 +754,12 @@ func validateWebsocketOptions(o *Options) error {
|
||||
}
|
||||
return fmt.Errorf("websocket no_auth_user %q not found in users configuration", wo.NoAuthUser)
|
||||
}
|
||||
// Using JWT requires Trusted Keys
|
||||
if wo.JWTCookie != "" {
|
||||
if len(o.TrustedOperators) == 0 && len(o.TrustedKeys) == 0 {
|
||||
return fmt.Errorf("trusted operators or trusted keys configuration is required for JWT authentication via cookie %q", wo.JWTCookie)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/jwt/v2"
|
||||
"github.com/nats-io/nkeys"
|
||||
)
|
||||
|
||||
@@ -1454,6 +1455,11 @@ func TestWSValidateOptions(t *testing.T) {
|
||||
o.Websocket.AllowedOrigins = []string{"http://this:is:bad:url"}
|
||||
return o
|
||||
}, "unable to parse"},
|
||||
{"missing trusted configuration", func() *Options {
|
||||
o := wso.Clone()
|
||||
o.Websocket.JWTCookie = "jwt"
|
||||
return o
|
||||
}, "keys configuration is required"},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
err := validateWebsocketOptions(test.getOpts())
|
||||
@@ -1584,9 +1590,16 @@ func TestWSAbnormalFailureOfWebServer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader, []byte) {
|
||||
type testWSClientOptions struct {
|
||||
compress, web bool
|
||||
host string
|
||||
port int
|
||||
extraHeaders map[string]string
|
||||
}
|
||||
|
||||
func testNewWSClient(t testing.TB, o testWSClientOptions) (net.Conn, *bufio.Reader, []byte) {
|
||||
t.Helper()
|
||||
addr := fmt.Sprintf("%s:%d", host, port)
|
||||
addr := fmt.Sprintf("%s:%d", o.host, o.port)
|
||||
wsc, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating ws connection: %v", err)
|
||||
@@ -1596,12 +1609,17 @@ func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, po
|
||||
t.Fatalf("Error during handshake: %v", err)
|
||||
}
|
||||
req := testWSCreateValidReq()
|
||||
if compress {
|
||||
if o.compress {
|
||||
req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate")
|
||||
}
|
||||
if web {
|
||||
if o.web {
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0")
|
||||
}
|
||||
if o.extraHeaders != nil {
|
||||
for hdr, val := range o.extraHeaders {
|
||||
req.Header.Add(hdr, val)
|
||||
}
|
||||
}
|
||||
req.URL, _ = url.Parse("wss://" + addr)
|
||||
if err := req.Write(wsc); err != nil {
|
||||
t.Fatalf("Error sending request: %v", err)
|
||||
@@ -1623,6 +1641,102 @@ func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, po
|
||||
return wsc, br, info
|
||||
}
|
||||
|
||||
type testClaimsOptions struct {
|
||||
nac *jwt.AccountClaims
|
||||
nuc *jwt.UserClaims
|
||||
connectRequest interface{}
|
||||
dontSign bool
|
||||
expectAnswer string
|
||||
}
|
||||
|
||||
func testWSWithClaims(t *testing.T, s *Server, o testWSClientOptions, tclm testClaimsOptions) (kp nkeys.KeyPair, conn net.Conn, rdr *bufio.Reader, auth_was_required bool) {
|
||||
t.Helper()
|
||||
|
||||
okp, _ := nkeys.FromSeed(oSeed)
|
||||
|
||||
akp, _ := nkeys.CreateAccount()
|
||||
apub, _ := akp.PublicKey()
|
||||
if tclm.nac == nil {
|
||||
tclm.nac = jwt.NewAccountClaims(apub)
|
||||
} else {
|
||||
tclm.nac.Subject = apub
|
||||
}
|
||||
ajwt, err := tclm.nac.Encode(okp)
|
||||
if err != nil {
|
||||
t.Fatalf("Error generating account JWT: %v", err)
|
||||
}
|
||||
|
||||
nkp, _ := nkeys.CreateUser()
|
||||
pub, _ := nkp.PublicKey()
|
||||
if tclm.nuc == nil {
|
||||
tclm.nuc = jwt.NewUserClaims(pub)
|
||||
} else {
|
||||
tclm.nuc.Subject = pub
|
||||
}
|
||||
jwt, err := tclm.nuc.Encode(akp)
|
||||
if err != nil {
|
||||
t.Fatalf("Error generating user JWT: %v", err)
|
||||
}
|
||||
|
||||
addAccountToMemResolver(s, apub, ajwt)
|
||||
|
||||
c, cr, l := testNewWSClient(t, o)
|
||||
|
||||
var info struct {
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
AuthRequired bool `json:"auth_required,omitempty"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(l[5:]), &info); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if info.AuthRequired {
|
||||
cs := ""
|
||||
if tclm.connectRequest != nil {
|
||||
customReq, err := json.Marshal(tclm.connectRequest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// PING needed to flush the +OK/-ERR to us.
|
||||
cs = fmt.Sprintf("CONNECT %v\r\nPING\r\n", string(customReq))
|
||||
} else if !tclm.dontSign {
|
||||
// Sign Nonce
|
||||
sigraw, _ := nkp.Sign([]byte(info.Nonce))
|
||||
sig := base64.RawURLEncoding.EncodeToString(sigraw)
|
||||
cs = fmt.Sprintf("CONNECT {\"jwt\":%q,\"sig\":\"%s\",\"verbose\":true,\"pedantic\":true}\r\nPING\r\n", jwt, sig)
|
||||
} else {
|
||||
cs = fmt.Sprintf("CONNECT {\"jwt\":%q,\"verbose\":true,\"pedantic\":true}\r\nPING\r\n", jwt)
|
||||
}
|
||||
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(cs))
|
||||
c.Write(wsmsg)
|
||||
l = testWSReadFrame(t, cr)
|
||||
if !strings.HasPrefix(string(l), tclm.expectAnswer) {
|
||||
t.Fatalf("Expected %q, got %q", tclm.expectAnswer, l)
|
||||
}
|
||||
}
|
||||
return akp, c, cr, info.AuthRequired
|
||||
}
|
||||
|
||||
func setupAddTrusted(o *Options) {
|
||||
kp, _ := nkeys.FromSeed(oSeed)
|
||||
pub, _ := kp.PublicKey()
|
||||
o.TrustedKeys = []string{pub}
|
||||
}
|
||||
|
||||
func setupAddCookie(o *Options) {
|
||||
o.Websocket.JWTCookie = "jwt"
|
||||
}
|
||||
|
||||
func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader, []byte) {
|
||||
t.Helper()
|
||||
return testNewWSClient(t, testWSClientOptions{
|
||||
compress: compress,
|
||||
web: web,
|
||||
host: host,
|
||||
port: port,
|
||||
})
|
||||
}
|
||||
|
||||
func testWSCreateClient(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader) {
|
||||
wsc, br, _ := testWSCreateClientGetInfo(t, compress, web, host, port)
|
||||
// Send CONNECT and PING
|
||||
@@ -3157,6 +3271,124 @@ func TestWSNkeyAuth(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTWSCookieUser(t *testing.T) {
|
||||
|
||||
nucSigFunc := func() *jwt.UserClaims { return newJWTTestUserClaims() }
|
||||
nucBearerFunc := func() *jwt.UserClaims {
|
||||
ret := newJWTTestUserClaims()
|
||||
ret.BearerToken = true
|
||||
return ret
|
||||
}
|
||||
|
||||
o := testWSOptions()
|
||||
setupAddTrusted(o)
|
||||
setupAddCookie(o)
|
||||
s := RunServer(o)
|
||||
buildMemAccResolver(s)
|
||||
defer s.Shutdown()
|
||||
|
||||
genJwt := func(t *testing.T, nuc *jwt.UserClaims) string {
|
||||
okp, _ := nkeys.FromSeed(oSeed)
|
||||
|
||||
akp, _ := nkeys.CreateAccount()
|
||||
apub, _ := akp.PublicKey()
|
||||
|
||||
nac := jwt.NewAccountClaims(apub)
|
||||
ajwt, err := nac.Encode(okp)
|
||||
if err != nil {
|
||||
t.Fatalf("Error generating account JWT: %v", err)
|
||||
}
|
||||
|
||||
nkp, _ := nkeys.CreateUser()
|
||||
pub, _ := nkp.PublicKey()
|
||||
nuc.Subject = pub
|
||||
jwt, err := nuc.Encode(akp)
|
||||
if err != nil {
|
||||
t.Fatalf("Error generating user JWT: %v", err)
|
||||
}
|
||||
addAccountToMemResolver(s, apub, ajwt)
|
||||
return jwt
|
||||
}
|
||||
|
||||
cliOpts := testWSClientOptions{
|
||||
host: o.Websocket.Host,
|
||||
port: o.Websocket.Port,
|
||||
}
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
nuc *jwt.UserClaims
|
||||
opts func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions)
|
||||
expectAnswer string
|
||||
}{
|
||||
{
|
||||
name: "protocol auth, non-bearer key, with signature",
|
||||
nuc: nucSigFunc(),
|
||||
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
|
||||
return cliOpts, testClaimsOptions{nuc: claims}
|
||||
},
|
||||
expectAnswer: "+OK",
|
||||
},
|
||||
{
|
||||
name: "protocol auth, non-bearer key, w/o required signature",
|
||||
nuc: nucSigFunc(),
|
||||
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
|
||||
return cliOpts, testClaimsOptions{nuc: claims, dontSign: true}
|
||||
},
|
||||
expectAnswer: "-ERR",
|
||||
},
|
||||
{
|
||||
name: "protocol auth, bearer key, w/o signature",
|
||||
nuc: nucBearerFunc(),
|
||||
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
|
||||
return cliOpts, testClaimsOptions{nuc: claims, dontSign: true}
|
||||
},
|
||||
expectAnswer: "+OK",
|
||||
},
|
||||
{
|
||||
name: "cookie auth, non-bearer key, protocol auth fail",
|
||||
nuc: nucSigFunc(),
|
||||
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
|
||||
co := cliOpts
|
||||
co.extraHeaders = map[string]string{}
|
||||
co.extraHeaders["Cookie"] = o.Websocket.JWTCookie + "=" + genJwt(t, claims)
|
||||
return co, testClaimsOptions{connectRequest: struct{}{}}
|
||||
},
|
||||
expectAnswer: "-ERR",
|
||||
},
|
||||
{
|
||||
name: "cookie auth, bearer key, protocol auth success with implied cookie jwt",
|
||||
nuc: nucBearerFunc(),
|
||||
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
|
||||
co := cliOpts
|
||||
co.extraHeaders = map[string]string{}
|
||||
co.extraHeaders["Cookie"] = o.Websocket.JWTCookie + "=" + genJwt(t, claims)
|
||||
return co, testClaimsOptions{connectRequest: struct{}{}}
|
||||
},
|
||||
expectAnswer: "+OK",
|
||||
},
|
||||
{
|
||||
name: "cookie auth, non-bearer key, protocol auth success via override jwt in CONNECT opts",
|
||||
nuc: nucSigFunc(),
|
||||
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
|
||||
co := cliOpts
|
||||
co.extraHeaders = map[string]string{}
|
||||
co.extraHeaders["Cookie"] = o.Websocket.JWTCookie + "=" + genJwt(t, claims)
|
||||
return co, testClaimsOptions{nuc: nucBearerFunc()}
|
||||
},
|
||||
expectAnswer: "+OK",
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
cliOpt, claimOpt := test.opts(t, test.nuc)
|
||||
claimOpt.expectAnswer = test.expectAnswer
|
||||
_, c, _, _ := testWSWithClaims(t, s, cliOpt, claimOpt)
|
||||
c.Close()
|
||||
})
|
||||
}
|
||||
s.Shutdown()
|
||||
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = Benchmark tests
|
||||
// ==================================================================
|
||||
|
||||
Reference in New Issue
Block a user