diff --git a/server/jwt.go b/server/jwt.go index d0fbf4f2..f4189678 100644 --- a/server/jwt.go +++ b/server/jwt.go @@ -171,16 +171,23 @@ func validateTimes(claims *jwt.UserClaims) (bool, time.Duration) { } else if len(claims.Times) == 0 { return true, time.Duration(0) } + loc := time.Local + if claims.Locale != "" { + var err error + if loc, err = time.LoadLocation(claims.Locale); err != nil { + return false, time.Duration(0) // parsing not expected to fail at this point + } + } now := time.Now() for _, timeRange := range claims.Times { y, m, d := now.Date() m = m - 1 d = d - 1 - start, err := time.ParseInLocation("15:04:05", timeRange.Start, now.Location()) + start, err := time.ParseInLocation("15:04:05", timeRange.Start, loc) if err != nil { return false, time.Duration(0) // parsing not expected to fail at this point } - end, err := time.ParseInLocation("15:04:05", timeRange.End, now.Location()) + end, err := time.ParseInLocation("15:04:05", timeRange.End, loc) if err != nil { return false, time.Duration(0) // parsing not expected to fail at this point } @@ -190,8 +197,9 @@ func validateTimes(claims *jwt.UserClaims) (bool, time.Duration) { } else { start = start.AddDate(y, int(m), d) } + start = start.In(now.Location()) if start.Before(now) { - end = end.AddDate(y, int(m), d) + end = end.AddDate(y, int(m), d).In(now.Location()) if end.After(now) { return true, end.Sub(now) } diff --git a/server/jwt_test.go b/server/jwt_test.go index 7a7fe8ed..19f56f28 100644 --- a/server/jwt_test.go +++ b/server/jwt_test.go @@ -3340,39 +3340,50 @@ func TestJWTTimeExpiration(t *testing.T) { defer os.Remove(conf) sA, _ := RunServerWithConfig(conf) defer sA.Shutdown() - t.Run("simple expiration", func(t *testing.T) { - start := time.Now() - creds := createUserWithLimit(t, kp, doNotExpire, func(j *jwt.Limits) { j.Times = []jwt.TimeRange{newTimeRange(start, validFor)} }) - defer os.Remove(creds) - disconnectChan := make(chan struct{}) - defer close(disconnectChan) - errChan := make(chan struct{}) - defer close(errChan) - c := natsConnect(t, sA.ClientURL(), - nats.UserCredentials(creds), - nats.DisconnectErrHandler(func(conn *nats.Conn, err error) { - if err != io.EOF { - return + for _, l := range []string{"", "Europe/Berlin", "America/New_York"} { + t.Run("simple expiration "+l, func(t *testing.T) { + start := time.Now() + creds := createUserWithLimit(t, kp, doNotExpire, func(j *jwt.Limits) { + if l == "" { + j.Times = []jwt.TimeRange{newTimeRange(start, validFor)} + } else { + loc, err := time.LoadLocation(l) + require_NoError(t, err) + j.Times = []jwt.TimeRange{newTimeRange(start.In(loc), validFor)} + j.Locale = l } - disconnectChan <- struct{}{} - }), - nats.ErrorHandler(func(conn *nats.Conn, s *nats.Subscription, err error) { - if err != nats.ErrAuthExpired { - return - } - now := time.Now() - stop := start.Add(validFor) - // assure event happens within a second of stop - if stop.Add(-validRange).Before(stop) && now.Before(stop.Add(validRange)) { - errChan <- struct{}{} - } - })) - chanRecv(t, errChan, 10*time.Second) - chanRecv(t, disconnectChan, 10*time.Second) - require_True(t, c.IsReconnecting()) - require_False(t, c.IsConnected()) - c.Close() - }) + }) + defer os.Remove(creds) + disconnectChan := make(chan struct{}) + defer close(disconnectChan) + errChan := make(chan struct{}) + defer close(errChan) + c := natsConnect(t, sA.ClientURL(), + nats.UserCredentials(creds), + nats.DisconnectErrHandler(func(conn *nats.Conn, err error) { + if err != io.EOF { + return + } + disconnectChan <- struct{}{} + }), + nats.ErrorHandler(func(conn *nats.Conn, s *nats.Subscription, err error) { + if err != nats.ErrAuthExpired { + return + } + now := time.Now() + stop := start.Add(validFor) + // assure event happens within a second of stop + if stop.Add(-validRange).Before(stop) && now.Before(stop.Add(validRange)) { + errChan <- struct{}{} + } + })) + chanRecv(t, errChan, 10*time.Second) + chanRecv(t, disconnectChan, 10*time.Second) + require_True(t, c.IsReconnecting()) + require_False(t, c.IsConnected()) + c.Close() + }) + } t.Run("double expiration", func(t *testing.T) { start1 := time.Now() start2 := start1.Add(2 * validFor)