From ffee747a6673e213ac1a40910bfd962b0cc2c6db Mon Sep 17 00:00:00 2001 From: "R.I.Pienaar" Date: Thu, 18 Nov 2021 13:48:44 +0100 Subject: [PATCH] expose the nonce to custom authentication Signed-off-by: R.I.Pienaar --- server/auth.go | 8 +++++--- server/client.go | 8 ++++++++ server/server_test.go | 44 ++++++++++++++++++++++++++++--------------- 3 files changed, 42 insertions(+), 18 deletions(-) diff --git a/server/auth.go b/server/auth.go index 3ee51576..5f5c9298 100644 --- a/server/auth.go +++ b/server/auth.go @@ -42,14 +42,16 @@ type Authentication interface { // ClientAuthentication is an interface for client authentication type ClientAuthentication interface { - // Get options associated with a client + // GetOpts gets options associated with a client GetOpts() *ClientOpts - // If TLS is enabled, TLS ConnectionState, nil otherwise + // GetTLSConnectionState if TLS is enabled, TLS ConnectionState, nil otherwise GetTLSConnectionState() *tls.ConnectionState - // Optionally map a user after auth. + // RegisterUser optionally map a user after auth. RegisterUser(*User) // RemoteAddress expose the connection information of the client RemoteAddress() net.Addr + // GetNonce is the nonce presented to the user in the INFO line + GetNonce() []byte // Kind indicates what type of connection this is matching defined constants like CLIENT, ROUTER, GATEWAY, LEAF etc Kind() int } diff --git a/server/client.go b/server/client.go index 68a26a9a..11fe8759 100644 --- a/server/client.go +++ b/server/client.go @@ -427,6 +427,14 @@ func (c *client) String() (id string) { return _EMPTY_ } +// GetNonce returns the nonce that was presented to the user on connection +func (c *client) GetNonce() []byte { + c.mu.Lock() + defer c.mu.Unlock() + + return c.nonce +} + // GetName returns the application supplied name for the connection. func (c *client) GetName() string { c.mu.Lock() diff --git a/server/server_test.go b/server/server_test.go index 3ed4b2ad..70756e82 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -603,32 +603,46 @@ func TestNilMonitoringPort(t *testing.T) { } } -type DummyAuth struct{} +type DummyAuth struct { + t *testing.T + needNonce bool +} func (d *DummyAuth) Check(c ClientAuthentication) bool { + if d.needNonce && len(c.GetNonce()) == 0 { + d.t.Fatalf("Expected a nonce but received none") + } else if !d.needNonce && len(c.GetNonce()) > 0 { + d.t.Fatalf("Received a nonce when none was expected") + } + return c.GetOpts().Username == "valid" } func TestCustomClientAuthentication(t *testing.T) { - var clientAuth DummyAuth + testAuth := func(t *testing.T, nonce bool) { + clientAuth := &DummyAuth{t, nonce} - opts := DefaultOptions() - opts.CustomClientAuthentication = &clientAuth + opts := DefaultOptions() + opts.CustomClientAuthentication = clientAuth + opts.AlwaysEnableNonce = nonce - s := RunServer(opts) + s := RunServer(opts) + defer s.Shutdown() - defer s.Shutdown() + addr := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) - addr := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) - - nc, err := nats.Connect(addr, nats.UserInfo("valid", "")) - if err != nil { - t.Fatalf("Expected client to connect, got: %s", err) - } - nc.Close() - if _, err := nats.Connect(addr, nats.UserInfo("invalid", "")); err == nil { - t.Fatal("Expected client to fail to connect") + nc, err := nats.Connect(addr, nats.UserInfo("valid", "")) + if err != nil { + t.Fatalf("Expected client to connect, got: %s", err) + } + nc.Close() + if _, err := nats.Connect(addr, nats.UserInfo("invalid", "")); err == nil { + t.Fatal("Expected client to fail to connect") + } } + + t.Run("with nonce", func(t *testing.T) { testAuth(t, true) }) + t.Run("without nonce", func(t *testing.T) { testAuth(t, false) }) } func TestCustomRouterAuthentication(t *testing.T) {