From 6c61464915fe59e3b23639701860ec91d034443d Mon Sep 17 00:00:00 2001 From: Matthias Hanel Date: Thu, 20 Aug 2020 15:28:16 -0400 Subject: [PATCH] [ADDED] Checks for CIDR blocks and connect time ranges specified in jwt (#1567) because times stored are hh:mm:ss it is possible to end up with start > end where end is actually the next day. jwt.go line 189 Also, ranges are based on the servers location, not the clients. Signed-off-by: Matthias Hanel --- server/auth.go | 11 +- server/client.go | 17 ++- server/jwt.go | 59 ++++++++++ server/jwt_test.go | 268 +++++++++++++++++++++++++++++++++++++++++---- 4 files changed, 329 insertions(+), 26 deletions(-) diff --git a/server/auth.go b/server/auth.go index ad9de95b..0ff19eec 100644 --- a/server/auth.go +++ b/server/auth.go @@ -514,6 +514,15 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) boo c.Debugf("User authentication revoked") return false } + if !validateSrc(juc, c.host) { + c.Errorf("Bad src Ip %s", c.host) + return false + } + allowNow, validFor := validateTimes(juc) + if !allowNow { + c.Errorf("Outside connect times") + return false + } nkey = buildInternalNkeyUser(juc, acc) if err := c.RegisterNkeyUser(nkey); err != nil { @@ -526,7 +535,7 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) boo s.accountConnectEvent(c) // Check if we need to set an auth timer if the user jwt expires. - c.checkExpiration(juc.Claims()) + c.setExpiration(juc.Claims(), validFor) return true } diff --git a/server/client.go b/server/client.go index fcdf08f1..420f4f3d 100644 --- a/server/client.go +++ b/server/client.go @@ -767,16 +767,23 @@ func (c *client) setPermissions(perms *Permissions) { // Check to see if we have an expiration for the user JWT via base claims. // FIXME(dlc) - Clear on connect with new JWT. -func (c *client) checkExpiration(claims *jwt.ClaimsData) { +func (c *client) setExpiration(claims *jwt.ClaimsData, validFor time.Duration) { if claims.Expires == 0 { + if validFor != 0 { + c.setExpirationTimer(validFor) + } return } + expiresAt := time.Duration(0) tn := time.Now().Unix() - if claims.Expires < tn { - return + if claims.Expires > tn { + expiresAt = time.Duration(claims.Expires-tn) * time.Second + } + if validFor != 0 && validFor < expiresAt { + c.setExpirationTimer(validFor) + } else { + c.setExpirationTimer(expiresAt) } - expiresAt := time.Duration(claims.Expires - tn) - c.setExpirationTimer(expiresAt * time.Second) } // This will load up the deny structure used for filtering delivered diff --git a/server/jwt.go b/server/jwt.go index 029a5e46..f66237fa 100644 --- a/server/jwt.go +++ b/server/jwt.go @@ -16,8 +16,10 @@ package server import ( "fmt" "io/ioutil" + "net" "regexp" "strings" + "time" "github.com/nats-io/jwt/v2" "github.com/nats-io/nkeys" @@ -140,3 +142,60 @@ func validateTrustedOperators(o *Options) error { } return nil } + +func validateSrc(claims *jwt.UserClaims, host string) bool { + if claims == nil { + return false + } else if claims.Src == "" { + return true + } else if host == "" { + return false + } + ip := net.ParseIP(host) + if ip == nil { + return false + } + for _, cidr := range strings.Split(claims.Src, ",") { + if _, net, err := net.ParseCIDR(cidr); err != nil { + return false // should not happen as this jwt is invalid + } else if net.Contains(ip) { + return true + } + } + return false +} + +func validateTimes(claims *jwt.UserClaims) (bool, time.Duration) { + if claims == nil { + return false, time.Duration(0) + } else if len(claims.Times) == 0 { + return true, time.Duration(0) + } + now := time.Now() + for _, timeRange := range claims.Times { + y, m, d := now.Date() + m = m - 1 + d = d - 1 + start, err := time.ParseInLocation("15:04:05", timeRange.Start, now.Location()) + if err != nil { + return false, time.Duration(0) // parsing not expected to fail at this point + } + end, err := time.ParseInLocation("15:04:05", timeRange.End, now.Location()) + if err != nil { + return false, time.Duration(0) // parsing not expected to fail at this point + } + if start.After(end) { + start = start.AddDate(y, int(m), d) + d++ // the intent is to be the next day + } else { + start = start.AddDate(y, int(m), d) + } + if start.Before(now) { + end = end.AddDate(y, int(m), d) + if end.After(now) { + return true, end.Sub(now) + } + } + } + return false, time.Duration(0) +} diff --git a/server/jwt_test.go b/server/jwt_test.go index f13bb23d..033a91c7 100644 --- a/server/jwt_test.go +++ b/server/jwt_test.go @@ -19,6 +19,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "io/ioutil" "net/http" "net/http/httptest" @@ -40,8 +41,17 @@ var ( oSeed = []byte("SOAFYNORQLQFJYBYNUGC5D7SH2MXMUX5BFEWWGHN3EK4VGG5TPT5DZP7QU") // This matches ./configs/nkeys/op.jwt ojwt = "eyJ0eXAiOiJqd3QiLCJhbGciOiJlZDI1NTE5In0.eyJhdWQiOiJURVNUUyIsImV4cCI6MTg1OTEyMTI3NSwianRpIjoiWE5MWjZYWVBIVE1ESlFSTlFPSFVPSlFHV0NVN01JNVc1SlhDWk5YQllVS0VRVzY3STI1USIsImlhdCI6MTU0Mzc2MTI3NSwiaXNzIjoiT0NBVDMzTVRWVTJWVU9JTUdOR1VOWEo2NkFIMlJMU0RBRjNNVUJDWUFZNVFNSUw2NU5RTTZYUUciLCJuYW1lIjoiU3luYWRpYSBDb21tdW5pY2F0aW9ucyBJbmMuIiwibmJmIjoxNTQzNzYxMjc1LCJzdWIiOiJPQ0FUMzNNVFZVMlZVT0lNR05HVU5YSjY2QUgyUkxTREFGM01VQkNZQVk1UU1JTDY1TlFNNlhRRyIsInR5cGUiOiJvcGVyYXRvciIsIm5hdHMiOnsic2lnbmluZ19rZXlzIjpbIk9EU0tSN01ZRlFaNU1NQUo2RlBNRUVUQ1RFM1JJSE9GTFRZUEpSTUFWVk40T0xWMllZQU1IQ0FDIiwiT0RTS0FDU1JCV1A1MzdEWkRSVko2NTdKT0lHT1BPUTZLRzdUNEhONk9LNEY2SUVDR1hEQUhOUDIiLCJPRFNLSTM2TFpCNDRPWTVJVkNSNlA1MkZaSlpZTVlXWlZXTlVEVExFWjVUSzJQTjNPRU1SVEFCUiJdfX0.hyfz6E39BMUh0GLzovFfk3wT4OfualftjdJ_eYkLfPvu5tZubYQ_Pn9oFYGCV_6yKy3KMGhWGUCyCdHaPhalBw" + oKp nkeys.KeyPair ) +func init() { + var err error + oKp, err = nkeys.FromSeed(oSeed) + if err != nil { + panic(fmt.Sprintf("Parsing oSeed failed with: %v", err)) + } +} + func opTrustBasicSetup() *Server { kp, _ := nkeys.FromSeed(oSeed) pub, _ := kp.PublicKey() @@ -92,8 +102,6 @@ func createClientWithIssuer(t *testing.T, s *Server, akp nkeys.KeyPair, optIssue func setupJWTTestWithClaims(t *testing.T, nac *jwt.AccountClaims, nuc *jwt.UserClaims, expected string) (*Server, nkeys.KeyPair, *testAsyncClient, *bufio.Reader) { t.Helper() - okp, _ := nkeys.FromSeed(oSeed) - akp, _ := nkeys.CreateAccount() apub, _ := akp.PublicKey() if nac == nil { @@ -101,7 +109,7 @@ func setupJWTTestWithClaims(t *testing.T, nac *jwt.AccountClaims, nuc *jwt.UserC } else { nac.Subject = apub } - ajwt, err := nac.Encode(okp) + ajwt, err := nac.Encode(oKp) if err != nil { t.Fatalf("Error generating account JWT: %v", err) } @@ -288,7 +296,10 @@ func TestJWTUserExpiresAfterConnect(t *testing.T) { s, c, cr := setupJWTTestWithUserClaims(t, nuc, "+OK") defer s.Shutdown() defer c.close() - l, _ := cr.ReadString('\n') + l, err := cr.ReadString('\n') + if err != nil { + t.Fatalf("Received %v", err) + } if !strings.HasPrefix(l, "PONG") { t.Fatalf("Expected a PONG") } @@ -296,7 +307,10 @@ func TestJWTUserExpiresAfterConnect(t *testing.T) { // Now we should expire after 1 second or so. time.Sleep(1250 * time.Millisecond) - l, _ = cr.ReadString('\n') + l, err = cr.ReadString('\n') + if err != nil { + t.Fatalf("Received %v", err) + } if !strings.HasPrefix(l, "-ERR ") { t.Fatalf("Expected an error") } @@ -2935,7 +2949,7 @@ func TestAccountNATSResolverFetch(t *testing.T) { nc := natsConnect(t, url, nats.UserCredentials(credsfile)) nc.Close() } - createAccountAndUser := func(pair nkeys.KeyPair, limit bool) (string, string, string, string) { + createAccountAndUser := func(limit bool) (string, string, string, string) { t.Helper() kp, _ := nkeys.CreateAccount() pub, _ := kp.PublicKey() @@ -2943,14 +2957,14 @@ func TestAccountNATSResolverFetch(t *testing.T) { if limit { claim.Limits.Conn = 1 } - jwt1, err := claim.Encode(pair) + jwt1, err := claim.Encode(oKp) require_NoError(t, err) time.Sleep(2 * time.Second) // create updated claim allowing more connections if limit { claim.Limits.Conn = 2 } - jwt2, err := claim.Encode(pair) + jwt2, err := claim.Encode(oKp) require_NoError(t, err) ukp, _ := nkeys.CreateUser() seed, _ := ukp.Seed() @@ -2991,21 +3005,14 @@ func TestAccountNATSResolverFetch(t *testing.T) { } return passCnt } - // Create Operator - op, _ := nkeys.CreateOperator() - opub, _ := op.PublicKey() - oc := jwt.NewOperatorClaims(opub) - oc.Subject = opub - ojwt, err := oc.Encode(op) - require_NoError(t, err) // Create Accounts and corresponding user creds - syspub, sysjwt, _, sysCreds := createAccountAndUser(op, false) + syspub, sysjwt, _, sysCreds := createAccountAndUser(false) defer os.Remove(sysCreds) - apub, ajwt1, ajwt2, aCreds := createAccountAndUser(op, true) + apub, ajwt1, ajwt2, aCreds := createAccountAndUser(true) defer os.Remove(aCreds) - bpub, bjwt1, bjwt2, bCreds := createAccountAndUser(op, true) + bpub, bjwt1, bjwt2, bCreds := createAccountAndUser(true) defer os.Remove(bCreds) - cpub, cjwt1, cjwt2, cCreds := createAccountAndUser(op, true) + cpub, cjwt1, cjwt2, cCreds := createAccountAndUser(true) defer os.Remove(cCreds) // Create one directory for each server dirA := createDir("srv-a") @@ -3180,8 +3187,229 @@ func TestAccountNATSResolverFetch(t *testing.T) { // Test exceeding limit. For the exclusive directory resolver, limit is a stop gap measure. // It is not expected to be hit. When hit the administrator is supposed to take action. - dpub, djwt1, _, dCreds := createAccountAndUser(op, true) + dpub, djwt1, _, dCreds := createAccountAndUser(true) defer os.Remove(dCreds) passCnt = updateJwt(sA.ClientURL(), sysCreds, dpub, djwt1) require_True(t, passCnt == 1) // Only Server C updated } + +func newTimeRange(start time.Time, dur time.Duration) jwt.TimeRange { + return jwt.TimeRange{Start: start.Format("15:04:05"), End: start.Add(dur).Format("15:04:05")} +} + +func createUserWithLimit(t *testing.T, accKp nkeys.KeyPair, expiration time.Time, limits func(*jwt.Limits)) string { + t.Helper() + ukp, _ := nkeys.CreateUser() + seed, _ := ukp.Seed() + upub, _ := ukp.PublicKey() + uclaim := newJWTTestUserClaims() + uclaim.Subject = upub + if limits != nil { + limits(&uclaim.Limits) + } + if !expiration.IsZero() { + uclaim.Expires = expiration.Unix() + } + vr := jwt.ValidationResults{} + uclaim.Validate(&vr) + require_Len(t, len(vr.Errors()), 0) + ujwt, err := uclaim.Encode(accKp) + require_NoError(t, err) + return genCredsFile(t, ujwt, seed) +} + +func TestJWTUserLimits(t *testing.T) { + // helper for time + inAnHour := time.Now().Add(time.Hour) + inTwoHours := time.Now().Add(2 * time.Hour) + doNotExpire := time.Now().AddDate(1, 0, 0) + // create account + kp, _ := nkeys.CreateAccount() + aPub, _ := kp.PublicKey() + claim := jwt.NewAccountClaims(aPub) + aJwt, err := claim.Encode(oKp) + require_NoError(t, err) + conf := createConfFile(t, []byte(fmt.Sprintf(` + listen: -1 + operator: %s + resolver: MEM + resolver_preload: { + %s: %s + } + `, ojwt, aPub, aJwt))) + defer os.Remove(conf) + sA, _ := RunServerWithConfig(conf) + defer sA.Shutdown() + for _, v := range []struct { + pass bool + f func(*jwt.Limits) + }{ + {true, nil}, + {false, func(j *jwt.Limits) { j.Src = "8.8.8.8/8" }}, + {true, func(j *jwt.Limits) { j.Src = "8.8.8.8/0" }}, + {true, func(j *jwt.Limits) { j.Src = "127.0.0.1/8" }}, + {true, func(j *jwt.Limits) { j.Src = "8.8.8.8/8,127.0.0.1/8" }}, + {false, func(j *jwt.Limits) { j.Src = "8.8.8.8/8,9.9.9.9/8" }}, + {true, func(j *jwt.Limits) { j.Times = append(j.Times, newTimeRange(time.Now(), time.Hour)) }}, + {false, func(j *jwt.Limits) { j.Times = append(j.Times, newTimeRange(time.Now().Add(time.Hour), time.Hour)) }}, + {true, func(j *jwt.Limits) { + j.Times = append(j.Times, newTimeRange(inAnHour, time.Hour), newTimeRange(time.Now(), time.Hour)) + }}, // last one is within range + {false, func(j *jwt.Limits) { + j.Times = append(j.Times, newTimeRange(inAnHour, time.Hour), newTimeRange(inTwoHours, time.Hour)) + }}, // out of range + {false, func(j *jwt.Limits) { + j.Times = append(j.Times, newTimeRange(inAnHour, 3*time.Hour), newTimeRange(inTwoHours, 2*time.Hour)) + }}, // overlapping [a[]b] out of range*/ + {false, func(j *jwt.Limits) { + j.Times = append(j.Times, newTimeRange(inAnHour, 3*time.Hour), newTimeRange(inTwoHours, time.Hour)) + }}, // overlapping [a[b]] out of range + // next day tests where end < begin + {true, func(j *jwt.Limits) { j.Times = append(j.Times, newTimeRange(time.Now(), 25*time.Hour)) }}, + {true, func(j *jwt.Limits) { j.Times = append(j.Times, newTimeRange(time.Now(), -time.Hour)) }}, + } { + t.Run("", func(t *testing.T) { + creds := createUserWithLimit(t, kp, doNotExpire, v.f) + defer os.Remove(creds) + if c, err := nats.Connect(sA.ClientURL(), nats.UserCredentials(creds)); err == nil { + c.Close() + if !v.pass { + t.Fatalf("Expected failure got none") + } + } else if v.pass { + t.Fatalf("Expected success got %v", err) + } else if !strings.Contains(err.Error(), "Authorization Violation") { + t.Fatalf("Expected error other than %v", err) + } + }) + } +} + +func TestJWTTimeExpiration(t *testing.T) { + validFor := 1500 * time.Millisecond + validRange := 500 * time.Millisecond + doNotExpire := time.Now().AddDate(1, 0, 0) + // create account + kp, _ := nkeys.CreateAccount() + aPub, _ := kp.PublicKey() + claim := jwt.NewAccountClaims(aPub) + aJwt, err := claim.Encode(oKp) + require_NoError(t, err) + conf := createConfFile(t, []byte(fmt.Sprintf(` + listen: -1 + operator: %s + resolver: MEM + resolver_preload: { + %s: %s + } + `, ojwt, aPub, aJwt))) + defer os.Remove(conf) + sA, _ := RunServerWithConfig(conf) + defer sA.Shutdown() + t.Run("simple expiration", func(t *testing.T) { + start := time.Now() + creds := createUserWithLimit(t, kp, doNotExpire, func(j *jwt.Limits) { j.Times = []jwt.TimeRange{newTimeRange(start, validFor)} }) + defer os.Remove(creds) + disconnectChan := make(chan struct{}) + defer close(disconnectChan) + errChan := make(chan struct{}) + defer close(errChan) + c := natsConnect(t, sA.ClientURL(), + nats.UserCredentials(creds), + nats.DisconnectErrHandler(func(conn *nats.Conn, err error) { + if err != io.EOF { + return + } + disconnectChan <- struct{}{} + }), + nats.ErrorHandler(func(conn *nats.Conn, s *nats.Subscription, err error) { + if err != nats.ErrAuthExpired { + return + } + now := time.Now() + stop := start.Add(validFor) + // assure event happens within a second of stop + if stop.Add(-validRange).Before(stop) && now.Before(stop.Add(validRange)) { + errChan <- struct{}{} + } + })) + <-errChan + <-disconnectChan + require_True(t, c.IsReconnecting()) + require_False(t, c.IsConnected()) + c.Close() + }) + t.Run("double expiration", func(t *testing.T) { + start1 := time.Now() + start2 := start1.Add(2 * validFor) + creds := createUserWithLimit(t, kp, doNotExpire, func(j *jwt.Limits) { + j.Times = []jwt.TimeRange{newTimeRange(start1, validFor), newTimeRange(start2, validFor)} + }) + defer os.Remove(creds) + errChan := make(chan struct{}) + defer close(errChan) + reConnectChan := make(chan struct{}) + defer close(reConnectChan) + c := natsConnect(t, sA.ClientURL(), + nats.UserCredentials(creds), + nats.ReconnectHandler(func(conn *nats.Conn) { + reConnectChan <- struct{}{} + }), + nats.ErrorHandler(func(conn *nats.Conn, s *nats.Subscription, err error) { + if err != nats.ErrAuthExpired { + return + } + now := time.Now() + stop := start1.Add(validFor) + // assure event happens within a second of stop + if stop.Add(-validRange).Before(stop) && now.Before(stop.Add(validRange)) { + errChan <- struct{}{} + return + } + stop = start2.Add(validFor) + // assure event happens within a second of stop + if stop.Add(-validRange).Before(stop) && now.Before(stop.Add(validRange)) { + errChan <- struct{}{} + } + })) + <-errChan + <-reConnectChan + require_False(t, c.IsReconnecting()) + require_True(t, c.IsConnected()) + <-errChan + c.Close() + }) + t.Run("lower jwt expiration overwrites time", func(t *testing.T) { + start := time.Now() + creds := createUserWithLimit(t, kp, start.Add(validFor), func(j *jwt.Limits) { j.Times = []jwt.TimeRange{newTimeRange(start, 2*validFor)} }) + defer os.Remove(creds) + disconnectChan := make(chan struct{}) + defer close(disconnectChan) + errChan := make(chan struct{}) + defer close(errChan) + c := natsConnect(t, sA.ClientURL(), + nats.UserCredentials(creds), + nats.DisconnectErrHandler(func(conn *nats.Conn, err error) { + if err != io.EOF { + return + } + disconnectChan <- struct{}{} + }), + nats.ErrorHandler(func(conn *nats.Conn, s *nats.Subscription, err error) { + if err != nats.ErrAuthExpired { + return + } + now := time.Now() + stop := start.Add(validFor) + // assure event happens within a second of stop + if stop.Add(-validRange).Before(stop) && now.Before(stop.Add(validRange)) { + errChan <- struct{}{} + } + })) + <-errChan + <-disconnectChan + require_True(t, c.IsReconnecting()) + require_False(t, c.IsConnected()) + c.Close() + }) +}