From a47e5e045c44d573cf74277e800fdde1b238ee61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Julius=20=C5=BDaromskis?= <89458604+julius-welink@users.noreply.github.com> Date: Mon, 27 Sep 2021 10:07:38 +0300 Subject: [PATCH] [ADDED] TLS connection rate limiter --- server/opts.go | 12 +++++++ server/rate_counter.go | 65 +++++++++++++++++++++++++++++++++++ server/rate_counter_test.go | 52 ++++++++++++++++++++++++++++ server/server.go | 34 +++++++++++++++++++ test/tls_test.go | 68 +++++++++++++++++++++++++++++++++++++ 5 files changed, 231 insertions(+) create mode 100644 server/rate_counter.go create mode 100644 server/rate_counter_test.go diff --git a/server/opts.go b/server/opts.go index 044ce294..608caa00 100644 --- a/server/opts.go +++ b/server/opts.go @@ -249,6 +249,7 @@ type Options struct { TLSCaCert string `json:"-"` TLSConfig *tls.Config `json:"-"` TLSPinnedCerts PinnedCertSet `json:"-"` + TLSRateLimit int64 `json:"-"` AllowNonTLS bool `json:"-"` WriteDeadline time.Duration `json:"-"` MaxClosedClients int `json:"-"` @@ -513,6 +514,7 @@ type TLSConfigOpts struct { Map bool TLSCheckKnownURLs bool Timeout float64 + RateLimit int64 Ciphers []uint16 CurvePreferences []tls.CurveID PinnedCerts PinnedCertSet @@ -902,6 +904,7 @@ func (o *Options) processConfigFileLine(k string, v interface{}, errors *[]error o.TLSTimeout = tc.Timeout o.TLSMap = tc.Map o.TLSPinnedCerts = tc.PinnedCerts + o.TLSRateLimit = tc.RateLimit // Need to keep track of path of the original TLS config // and certs path for OCSP Stapling monitoring. @@ -3695,6 +3698,15 @@ func parseTLS(v interface{}, isClientCtx bool) (t *TLSConfigOpts, retErr error) return nil, &configErr{tk, "error parsing tls config, 'timeout' wrong type"} } tc.Timeout = at + case "connection_rate_limit": + at := int64(0) + switch mv := mv.(type) { + case int64: + at = mv + default: + return nil, &configErr{tk, "error parsing tls config, 'connection_rate_limit' wrong type"} + } + tc.RateLimit = at case "pinned_certs": ra, ok := mv.([]interface{}) if !ok { diff --git a/server/rate_counter.go b/server/rate_counter.go new file mode 100644 index 00000000..37b47dc7 --- /dev/null +++ b/server/rate_counter.go @@ -0,0 +1,65 @@ +// Copyright 2021-2022 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "sync" + "time" +) + +type rateCounter struct { + limit int64 + count int64 + blocked uint64 + end time.Time + interval time.Duration + mu sync.Mutex +} + +func newRateCounter(limit int64) *rateCounter { + return &rateCounter{ + limit: limit, + interval: time.Second, + } +} + +func (r *rateCounter) allow() bool { + now := time.Now() + + r.mu.Lock() + + if now.After(r.end) { + r.count = 0 + r.end = now.Add(r.interval) + } else { + r.count++ + } + allow := r.count < r.limit + if !allow { + r.blocked++ + } + + r.mu.Unlock() + + return allow +} + +func (r *rateCounter) countBlocked() uint64 { + r.mu.Lock() + blocked := r.blocked + r.blocked = 0 + r.mu.Unlock() + + return blocked +} diff --git a/server/rate_counter_test.go b/server/rate_counter_test.go new file mode 100644 index 00000000..80f4a14e --- /dev/null +++ b/server/rate_counter_test.go @@ -0,0 +1,52 @@ +// Copyright 2021-2022 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "testing" + "time" +) + +func TestRateCounter(t *testing.T) { + counter := newRateCounter(10) + counter.interval = 100 * time.Millisecond + + var i int + for i = 0; i < 10; i++ { + if !counter.allow() { + t.Errorf("counter should allow (iteration %d)", i) + } + } + for i = 0; i < 5; i++ { + if counter.allow() { + t.Errorf("counter should not allow (iteration %d)", i) + } + } + + blocked := counter.countBlocked() + if blocked != 5 { + t.Errorf("Expected blocked = 5, got %d", blocked) + } + + blocked = counter.countBlocked() + if blocked != 0 { + t.Errorf("Expected blocked = 0, got %d", blocked) + } + + time.Sleep(150 * time.Millisecond) + + if !counter.allow() { + t.Errorf("Expected true after current time window expired") + } +} diff --git a/server/server.go b/server/server.go index 6e3ded5e..5fc20022 100644 --- a/server/server.go +++ b/server/server.go @@ -259,6 +259,8 @@ type Server struct { rerrMu sync.Mutex rerrLast time.Time + connRateCounter *rateCounter + // If there is a system account configured, to still support the $G account, // the server will create a fake user and add it to the list of users. // Keep track of what that user name is for config reload purposes. @@ -365,6 +367,10 @@ func NewServer(opts *Options) (*Server, error) { httpReqStats: make(map[string]uint64), // Used to track HTTP requests } + if opts.TLSRateLimit > 0 { + s.connRateCounter = newRateCounter(opts.tlsConfigOpts.RateLimit) + } + // Trusted root operator keys. if !s.processTrustedKeys() { return nil, fmt.Errorf("Error processing trusted operator keys") @@ -513,6 +519,23 @@ func NewServer(opts *Options) (*Server, error) { return s, nil } +func (s *Server) logRejectedTLSConns() { + defer s.grWG.Done() + t := time.NewTicker(time.Second) + defer t.Stop() + for { + select { + case <-s.quitCh: + return + case <-t.C: + blocked := s.connRateCounter.countBlocked() + if blocked > 0 { + s.Warnf("Rejected %d connections due to TLS rate limiting", blocked) + } + } + } +} + // clusterName returns our cluster name which could be dynamic. func (s *Server) ClusterName() string { s.mu.Lock() @@ -1787,6 +1810,10 @@ func (s *Server) Start() { s.logPorts() } + if opts.TLSRateLimit > 0 { + s.startGoRoutine(s.logRejectedTLSConns) + } + // Wait for clients. s.AcceptLoop(clientListenReady) } @@ -2480,6 +2507,13 @@ func (s *Server) createClient(conn net.Conn) *client { // Check for TLS if !isClosed && tlsRequired { + if s.connRateCounter != nil && !s.connRateCounter.allow() { + c.mu.Unlock() + c.sendErr("Connection throttling is active. Please try again later.") + c.closeConnection(MaxConnectionsExceeded) + return nil + } + // If we have a prebuffer create a multi-reader. if len(pre) > 0 { c.nc = &tlsMixConn{c.nc, bytes.NewBuffer(pre)} diff --git a/test/tls_test.go b/test/tls_test.go index 6e92f827..ac17e722 100644 --- a/test/tls_test.go +++ b/test/tls_test.go @@ -1852,6 +1852,74 @@ func TestTLSPinnedCertsClient(t *testing.T) { nc.Close() } +type captureWarnLogger struct { + dummyLogger + receive chan string +} + +func newCaptureWarnLogger() *captureWarnLogger { + return &captureWarnLogger{ + receive: make(chan string, 100), + } +} + +func (l *captureWarnLogger) Warnf(format string, v ...interface{}) { + l.receive <- fmt.Sprintf(format, v...) +} + +func (l *captureWarnLogger) waitFor(expect string, timeout time.Duration) bool { + for { + select { + case msg := <-l.receive: + if strings.Contains(msg, expect) { + return true + } + case <-time.After(timeout): + return false + } + } +} + +func TestTLSConnectionRate(t *testing.T) { + config := ` + listen: "127.0.0.1:-1" + tls { + cert_file: "./configs/certs/server-cert.pem" + key_file: "./configs/certs/server-key.pem" + connection_rate_limit: 3 + } + ` + + confFileName := createConfFile(t, []byte(config)) + defer removeFile(t, confFileName) + + srv, _ := RunServerWithConfig(confFileName) + logger := newCaptureWarnLogger() + srv.SetLogger(logger, false, false) + defer srv.Shutdown() + + var err error + count := 0 + for count < 10 { + var nc *nats.Conn + nc, err = nats.Connect(srv.ClientURL(), nats.RootCAs("./configs/certs/ca.pem")) + + if err != nil { + break + } + nc.Close() + count++ + } + + if count != 3 { + t.Fatalf("Expected 3 connections per second, got %d (%v)", count, err) + } + + if !logger.waitFor("connections due to TLS rate limiting", time.Second) { + t.Fatalf("did not log 'TLS rate limiting' warning") + } +} + func TestTLSPinnedCertsRoute(t *testing.T) { tmplSeed := ` host: localhost