From 75d2ddb26bcf1de9d5fc8b2940645a03df3e0244 Mon Sep 17 00:00:00 2001 From: Todd Beets Date: Fri, 15 Sep 2023 12:36:34 -0700 Subject: [PATCH] AuthCallout request should include TLS data when client is NATS WS client --- server/auth_callout_test.go | 96 +++++++++++++++++++++++++++++++++++++ server/websocket.go | 8 ++-- 2 files changed, 101 insertions(+), 3 deletions(-) diff --git a/server/auth_callout_test.go b/server/auth_callout_test.go index 6486721a..3d2089b0 100644 --- a/server/auth_callout_test.go +++ b/server/auth_callout_test.go @@ -152,6 +152,23 @@ func (at *authTest) Connect(clientOptions ...nats.Option) *nats.Conn { return conn } +func (at *authTest) WSNewClient(clientOptions ...nats.Option) (*nats.Conn, error) { + pi := at.srv.PortsInfo(10 * time.Millisecond) + require_False(at.t, pi == nil) + conn, err := nats.Connect(strings.Replace(pi.WebSocket[0], "127.0.0.1", "localhost", 1), clientOptions...) + if err != nil { + return nil, err + } + at.clients = append(at.clients, conn) + return conn, nil +} + +func (at *authTest) WSConnect(clientOptions ...nats.Option) *nats.Conn { + conn, err := at.WSNewClient(clientOptions...) + require_NoError(at.t, err) + return conn +} + func (at *authTest) RequireConnectError(clientOptions ...nats.Option) { _, err := at.NewClient(clientOptions...) require_Error(at.t, err) @@ -1423,3 +1440,82 @@ func TestAuthCalloutOperator_AnyAccount(t *testing.T) { userInfo = response.Data.(*UserInfo) require_Equal(t, userInfo.Account, bpk) } + +func TestAuthCalloutWSClientTLSCerts(t *testing.T) { + conf := ` + server_name: T + listen: "localhost:-1" + + tls { + cert_file = "../test/configs/certs/tlsauth/server.pem" + key_file = "../test/configs/certs/tlsauth/server-key.pem" + ca_file = "../test/configs/certs/tlsauth/ca.pem" + verify = true + } + + websocket: { + listen: "localhost:-1" + tls { + cert_file = "../test/configs/certs/tlsauth/server.pem" + key_file = "../test/configs/certs/tlsauth/server-key.pem" + ca_file = "../test/configs/certs/tlsauth/ca.pem" + verify = true + } + } + + accounts { + AUTH { users [ {user: "auth", password: "pwd"} ] } + FOO {} + } + authorization { + timeout: 1s + auth_callout { + # Needs to be a public account nkey, will work for both server config and operator mode. + issuer: "ABJHLOVMPA4CI6R5KLNGOB4GSLNIY7IOUPAJC4YFNDLQVIOBYQGUWVLA" + account: AUTH + auth_users: [ auth ] + } + } + ` + handler := func(m *nats.Msg) { + user, si, ci, _, ctls := decodeAuthRequest(t, m.Data) + require_True(t, si.Name == "T") + require_True(t, ci.Host == "127.0.0.1") + require_True(t, ctls != nil) + // Zero since we are verified and will be under verified chains. + require_True(t, len(ctls.Certs) == 0) + require_True(t, len(ctls.VerifiedChains) == 1) + // Since we have a CA. + require_True(t, len(ctls.VerifiedChains[0]) == 2) + blk, _ := pem.Decode([]byte(ctls.VerifiedChains[0][0])) + cert, err := x509.ParseCertificate(blk.Bytes) + require_NoError(t, err) + if strings.HasPrefix(cert.Subject.String(), "CN=example.com") { + // Override blank name here, server will substitute. + ujwt := createAuthUser(t, user, "dlc", "FOO", "", nil, 0, nil) + m.Respond(serviceResponse(t, user, si.ID, ujwt, "", 0)) + } + } + + ac := NewAuthTest(t, conf, handler, + nats.UserInfo("auth", "pwd"), + nats.ClientCert("../test/configs/certs/tlsauth/client2.pem", "../test/configs/certs/tlsauth/client2-key.pem"), + nats.RootCAs("../test/configs/certs/tlsauth/ca.pem")) + defer ac.Cleanup() + + // Will use client cert to determine user. + nc := ac.WSConnect( + nats.ClientCert("../test/configs/certs/tlsauth/client2.pem", "../test/configs/certs/tlsauth/client2-key.pem"), + nats.RootCAs("../test/configs/certs/tlsauth/ca.pem"), + ) + + resp, err := nc.Request(userDirectInfoSubj, nil, time.Second) + require_NoError(t, err) + response := ServerAPIResponse{Data: &UserInfo{}} + err = json.Unmarshal(resp.Data, &response) + require_NoError(t, err) + userInfo := response.Data.(*UserInfo) + + require_True(t, userInfo.UserID == "dlc") + require_True(t, userInfo.Account == "FOO") +} diff --git a/server/websocket.go b/server/websocket.go index 462a923c..6bf82305 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -1234,12 +1234,14 @@ func (s *Server) createWSClient(conn net.Conn, ws *websocket) *client { return nil } s.clients[c.cid] = c - - // Websocket clients do TLS in the websocket http server. - // So no TLS here... s.mu.Unlock() c.mu.Lock() + // Websocket clients do TLS in the websocket http server. + // So no TLS initiation here... + if _, ok := conn.(*tls.Conn); ok { + c.flags.set(handshakeComplete) + } if c.isClosed() { c.mu.Unlock()