From 3bb473c67d5f95b7c4acf5cb08ad185770867620 Mon Sep 17 00:00:00 2001 From: "R.I.Pienaar" Date: Thu, 27 Oct 2022 12:57:30 +0200 Subject: [PATCH] adds the notion of a connection deadline to User This will be used mainly by CustomClientAuthentication implementations to indicate that the user connection should be disconnected at some point in future - like when a certificate or token expires Signed-off-by: R.I.Pienaar --- server/auth.go | 9 +++++++++ server/auth_test.go | 37 +++++++++++++++++++++++++++++++++++++ server/client.go | 5 +++++ server/server_test.go | 20 ++++++++++++++++++-- 4 files changed, 69 insertions(+), 2 deletions(-) diff --git a/server/auth.go b/server/auth.go index fbadce75..db9791a1 100644 --- a/server/auth.go +++ b/server/auth.go @@ -71,6 +71,7 @@ type User struct { Password string `json:"password"` Permissions *Permissions `json:"permissions,omitempty"` Account *Account `json:"account,omitempty"` + ConnectionDeadline time.Time `json:"connection_deadline,omitempty"` AllowedConnectionTypes map[string]struct{} `json:"connection_types,omitempty"` } @@ -83,6 +84,14 @@ func (u *User) clone() *User { clone := &User{} *clone = *u clone.Permissions = u.Permissions.clone() + + if len(u.AllowedConnectionTypes) > 0 { + clone.AllowedConnectionTypes = make(map[string]struct{}) + for k, v := range u.AllowedConnectionTypes { + clone.AllowedConnectionTypes[k] = v + } + } + return clone } diff --git a/server/auth_test.go b/server/auth_test.go index dcd87649..28d6d753 100644 --- a/server/auth_test.go +++ b/server/auth_test.go @@ -14,6 +14,7 @@ package server import ( + "context" "fmt" "net" "net/url" @@ -276,6 +277,42 @@ func TestNoAuthUser(t *testing.T) { } } +func TestUserConnectionDeadline(t *testing.T) { + clientAuth := &DummyAuth{ + t: t, + register: true, + deadline: time.Now().Add(50 * time.Millisecond), + } + + opts := DefaultOptions() + opts.CustomClientAuthentication = clientAuth + + s := RunServer(opts) + defer s.Shutdown() + + var dcerr error + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + + nc, err := nats.Connect(s.ClientURL(), nats.UserInfo("valid", ""), nats.NoReconnect(), nats.ErrorHandler(func(nc *nats.Conn, _ *nats.Subscription, err error) { + dcerr = err + cancel() + })) + if err != nil { + t.Fatalf("Expected client to connect, got: %s", err) + } + + <-ctx.Done() + + if nc.IsConnected() { + t.Fatalf("Expected to be disconnected") + } + + if dcerr == nil || dcerr.Error() != "nats: authentication expired" { + t.Fatalf("Expected a auth expired error: got: %v", dcerr) + } +} + func TestNoAuthUserNoConnectProto(t *testing.T) { conf := createConfFile(t, []byte(` listen: "127.0.0.1:-1" diff --git a/server/client.go b/server/client.go index e6beae89..156d6c36 100644 --- a/server/client.go +++ b/server/client.go @@ -840,6 +840,11 @@ func (c *client) RegisterUser(user *User) { c.opts.Username = user.Username } + // if a deadline time stamp is set we start a timer to disconnect the user at that time + if !user.ConnectionDeadline.IsZero() { + c.atmr = time.AfterFunc(time.Until(user.ConnectionDeadline), c.authExpired) + } + c.mu.Unlock() } diff --git a/server/server_test.go b/server/server_test.go index 5116c7b4..3a3f374f 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -611,6 +611,8 @@ func TestNilMonitoringPort(t *testing.T) { type DummyAuth struct { t *testing.T needNonce bool + deadline time.Time + register bool } func (d *DummyAuth) Check(c ClientAuthentication) bool { @@ -620,12 +622,26 @@ func (d *DummyAuth) Check(c ClientAuthentication) bool { d.t.Fatalf("Received a nonce when none was expected") } - return c.GetOpts().Username == "valid" + if c.GetOpts().Username != "valid" { + return false + } + + if !d.register { + return true + } + + u := &User{ + Username: c.GetOpts().Username, + ConnectionDeadline: d.deadline, + } + c.RegisterUser(u) + + return true } func TestCustomClientAuthentication(t *testing.T) { testAuth := func(t *testing.T, nonce bool) { - clientAuth := &DummyAuth{t, nonce} + clientAuth := &DummyAuth{t: t, needNonce: nonce} opts := DefaultOptions() opts.CustomClientAuthentication = clientAuth