Merge pull request #1477 from pas2k/ws_cookie_auth

[ADDED] Cookie JWT auth for WebSocket
This commit is contained in:
Ivan Kozlovic
2020-06-18 14:01:14 -06:00
committed by GitHub
4 changed files with 261 additions and 4 deletions

View File

@@ -1497,7 +1497,13 @@ func (c *client) processConnect(arg []byte) error {
lang := c.opts.Lang
account := c.opts.Account
accountNew := c.opts.AccountNew
// If websocket client and JWT not in the CONNECT, use the cookie JWT (possibly empty).
if ws := c.ws; ws != nil && c.opts.JWT == "" {
c.opts.JWT = ws.cookieJwt
}
ujwt := c.opts.JWT
// For headers both client and server need to support.
c.headers = supportsHeaders && c.opts.Headers
c.mu.Unlock()

View File

@@ -265,6 +265,11 @@ type WebsocketOpts struct {
// Users defined here or in the global options.
NoAuthUser string
// Name of the cookie, which if present in WebSocket upgrade headers,
// will be treated as JWT during CONNECT phase as long as
// "jwt" specified in the CONNECT options is missing or empty.
JWTCookie string
// Authentication section. If anything is configured in this section,
// it will override the authorization configuration for regular clients.
Username string
@@ -3152,6 +3157,8 @@ func parseWebsocket(v interface{}, o *Options, errors *[]error, warnings *[]erro
if auth.nkeys != nil {
o.Websocket.Nkeys = append(o.Websocket.Nkeys, auth.nkeys...)
}
case "jwt_cookie":
o.Websocket.JWTCookie = mv.(string)
case "no_auth_user":
o.Websocket.NoAuthUser = mv.(string)
default:

View File

@@ -90,6 +90,7 @@ type websocket struct {
closeSent bool
browser bool
compressor *flate.Writer
cookieJwt string
}
type srvWebsocket struct {
@@ -597,6 +598,11 @@ func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeRe
if ua := r.Header.Get("User-Agent"); ua != "" && strings.HasPrefix(ua, "Mozilla/") {
ws.browser = true
}
if opts.Websocket.JWTCookie != "" {
if c, err := r.Cookie(opts.Websocket.JWTCookie); err == nil && c != nil {
ws.cookieJwt = c.Value
}
}
return &wsUpgradeResult{conn: conn, ws: ws}, nil
}
@@ -748,6 +754,12 @@ func validateWebsocketOptions(o *Options) error {
}
return fmt.Errorf("websocket no_auth_user %q not found in users configuration", wo.NoAuthUser)
}
// Using JWT requires Trusted Keys
if wo.JWTCookie != "" {
if len(o.TrustedOperators) == 0 && len(o.TrustedKeys) == 0 {
return fmt.Errorf("trusted operators or trusted keys configuration is required for JWT authentication via cookie %q", wo.JWTCookie)
}
}
return nil
}

View File

@@ -36,6 +36,7 @@ import (
"testing"
"time"
"github.com/nats-io/jwt/v2"
"github.com/nats-io/nkeys"
)
@@ -1454,6 +1455,11 @@ func TestWSValidateOptions(t *testing.T) {
o.Websocket.AllowedOrigins = []string{"http://this:is:bad:url"}
return o
}, "unable to parse"},
{"missing trusted configuration", func() *Options {
o := wso.Clone()
o.Websocket.JWTCookie = "jwt"
return o
}, "keys configuration is required"},
} {
t.Run(test.name, func(t *testing.T) {
err := validateWebsocketOptions(test.getOpts())
@@ -1584,9 +1590,16 @@ func TestWSAbnormalFailureOfWebServer(t *testing.T) {
}
}
func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader, []byte) {
type testWSClientOptions struct {
compress, web bool
host string
port int
extraHeaders map[string]string
}
func testNewWSClient(t testing.TB, o testWSClientOptions) (net.Conn, *bufio.Reader, []byte) {
t.Helper()
addr := fmt.Sprintf("%s:%d", host, port)
addr := fmt.Sprintf("%s:%d", o.host, o.port)
wsc, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("Error creating ws connection: %v", err)
@@ -1596,12 +1609,17 @@ func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, po
t.Fatalf("Error during handshake: %v", err)
}
req := testWSCreateValidReq()
if compress {
if o.compress {
req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate")
}
if web {
if o.web {
req.Header.Set("User-Agent", "Mozilla/5.0")
}
if o.extraHeaders != nil {
for hdr, val := range o.extraHeaders {
req.Header.Add(hdr, val)
}
}
req.URL, _ = url.Parse("wss://" + addr)
if err := req.Write(wsc); err != nil {
t.Fatalf("Error sending request: %v", err)
@@ -1623,6 +1641,102 @@ func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, po
return wsc, br, info
}
type testClaimsOptions struct {
nac *jwt.AccountClaims
nuc *jwt.UserClaims
connectRequest interface{}
dontSign bool
expectAnswer string
}
func testWSWithClaims(t *testing.T, s *Server, o testWSClientOptions, tclm testClaimsOptions) (kp nkeys.KeyPair, conn net.Conn, rdr *bufio.Reader, auth_was_required bool) {
t.Helper()
okp, _ := nkeys.FromSeed(oSeed)
akp, _ := nkeys.CreateAccount()
apub, _ := akp.PublicKey()
if tclm.nac == nil {
tclm.nac = jwt.NewAccountClaims(apub)
} else {
tclm.nac.Subject = apub
}
ajwt, err := tclm.nac.Encode(okp)
if err != nil {
t.Fatalf("Error generating account JWT: %v", err)
}
nkp, _ := nkeys.CreateUser()
pub, _ := nkp.PublicKey()
if tclm.nuc == nil {
tclm.nuc = jwt.NewUserClaims(pub)
} else {
tclm.nuc.Subject = pub
}
jwt, err := tclm.nuc.Encode(akp)
if err != nil {
t.Fatalf("Error generating user JWT: %v", err)
}
addAccountToMemResolver(s, apub, ajwt)
c, cr, l := testNewWSClient(t, o)
var info struct {
Nonce string `json:"nonce,omitempty"`
AuthRequired bool `json:"auth_required,omitempty"`
}
if err := json.Unmarshal([]byte(l[5:]), &info); err != nil {
t.Fatal(err)
}
if info.AuthRequired {
cs := ""
if tclm.connectRequest != nil {
customReq, err := json.Marshal(tclm.connectRequest)
if err != nil {
t.Fatal(err)
}
// PING needed to flush the +OK/-ERR to us.
cs = fmt.Sprintf("CONNECT %v\r\nPING\r\n", string(customReq))
} else if !tclm.dontSign {
// Sign Nonce
sigraw, _ := nkp.Sign([]byte(info.Nonce))
sig := base64.RawURLEncoding.EncodeToString(sigraw)
cs = fmt.Sprintf("CONNECT {\"jwt\":%q,\"sig\":\"%s\",\"verbose\":true,\"pedantic\":true}\r\nPING\r\n", jwt, sig)
} else {
cs = fmt.Sprintf("CONNECT {\"jwt\":%q,\"verbose\":true,\"pedantic\":true}\r\nPING\r\n", jwt)
}
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(cs))
c.Write(wsmsg)
l = testWSReadFrame(t, cr)
if !strings.HasPrefix(string(l), tclm.expectAnswer) {
t.Fatalf("Expected %q, got %q", tclm.expectAnswer, l)
}
}
return akp, c, cr, info.AuthRequired
}
func setupAddTrusted(o *Options) {
kp, _ := nkeys.FromSeed(oSeed)
pub, _ := kp.PublicKey()
o.TrustedKeys = []string{pub}
}
func setupAddCookie(o *Options) {
o.Websocket.JWTCookie = "jwt"
}
func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader, []byte) {
t.Helper()
return testNewWSClient(t, testWSClientOptions{
compress: compress,
web: web,
host: host,
port: port,
})
}
func testWSCreateClient(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader) {
wsc, br, _ := testWSCreateClientGetInfo(t, compress, web, host, port)
// Send CONNECT and PING
@@ -3157,6 +3271,124 @@ func TestWSNkeyAuth(t *testing.T) {
}
}
func TestJWTWSCookieUser(t *testing.T) {
nucSigFunc := func() *jwt.UserClaims { return newJWTTestUserClaims() }
nucBearerFunc := func() *jwt.UserClaims {
ret := newJWTTestUserClaims()
ret.BearerToken = true
return ret
}
o := testWSOptions()
setupAddTrusted(o)
setupAddCookie(o)
s := RunServer(o)
buildMemAccResolver(s)
defer s.Shutdown()
genJwt := func(t *testing.T, nuc *jwt.UserClaims) string {
okp, _ := nkeys.FromSeed(oSeed)
akp, _ := nkeys.CreateAccount()
apub, _ := akp.PublicKey()
nac := jwt.NewAccountClaims(apub)
ajwt, err := nac.Encode(okp)
if err != nil {
t.Fatalf("Error generating account JWT: %v", err)
}
nkp, _ := nkeys.CreateUser()
pub, _ := nkp.PublicKey()
nuc.Subject = pub
jwt, err := nuc.Encode(akp)
if err != nil {
t.Fatalf("Error generating user JWT: %v", err)
}
addAccountToMemResolver(s, apub, ajwt)
return jwt
}
cliOpts := testWSClientOptions{
host: o.Websocket.Host,
port: o.Websocket.Port,
}
for _, test := range []struct {
name string
nuc *jwt.UserClaims
opts func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions)
expectAnswer string
}{
{
name: "protocol auth, non-bearer key, with signature",
nuc: nucSigFunc(),
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
return cliOpts, testClaimsOptions{nuc: claims}
},
expectAnswer: "+OK",
},
{
name: "protocol auth, non-bearer key, w/o required signature",
nuc: nucSigFunc(),
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
return cliOpts, testClaimsOptions{nuc: claims, dontSign: true}
},
expectAnswer: "-ERR",
},
{
name: "protocol auth, bearer key, w/o signature",
nuc: nucBearerFunc(),
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
return cliOpts, testClaimsOptions{nuc: claims, dontSign: true}
},
expectAnswer: "+OK",
},
{
name: "cookie auth, non-bearer key, protocol auth fail",
nuc: nucSigFunc(),
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
co := cliOpts
co.extraHeaders = map[string]string{}
co.extraHeaders["Cookie"] = o.Websocket.JWTCookie + "=" + genJwt(t, claims)
return co, testClaimsOptions{connectRequest: struct{}{}}
},
expectAnswer: "-ERR",
},
{
name: "cookie auth, bearer key, protocol auth success with implied cookie jwt",
nuc: nucBearerFunc(),
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
co := cliOpts
co.extraHeaders = map[string]string{}
co.extraHeaders["Cookie"] = o.Websocket.JWTCookie + "=" + genJwt(t, claims)
return co, testClaimsOptions{connectRequest: struct{}{}}
},
expectAnswer: "+OK",
},
{
name: "cookie auth, non-bearer key, protocol auth success via override jwt in CONNECT opts",
nuc: nucSigFunc(),
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
co := cliOpts
co.extraHeaders = map[string]string{}
co.extraHeaders["Cookie"] = o.Websocket.JWTCookie + "=" + genJwt(t, claims)
return co, testClaimsOptions{nuc: nucBearerFunc()}
},
expectAnswer: "+OK",
},
} {
t.Run(test.name, func(t *testing.T) {
cliOpt, claimOpt := test.opts(t, test.nuc)
claimOpt.expectAnswer = test.expectAnswer
_, c, _, _ := testWSWithClaims(t, s, cliOpt, claimOpt)
c.Close()
})
}
s.Shutdown()
}
// ==================================================================
// = Benchmark tests
// ==================================================================