[ADDED] TLS connection rate limiter

This commit is contained in:
Julius Žaromskis
2021-09-27 10:07:38 +03:00
parent 08ff14a24e
commit a47e5e045c
5 changed files with 231 additions and 0 deletions

View File

@@ -249,6 +249,7 @@ type Options struct {
TLSCaCert string `json:"-"`
TLSConfig *tls.Config `json:"-"`
TLSPinnedCerts PinnedCertSet `json:"-"`
TLSRateLimit int64 `json:"-"`
AllowNonTLS bool `json:"-"`
WriteDeadline time.Duration `json:"-"`
MaxClosedClients int `json:"-"`
@@ -513,6 +514,7 @@ type TLSConfigOpts struct {
Map bool
TLSCheckKnownURLs bool
Timeout float64
RateLimit int64
Ciphers []uint16
CurvePreferences []tls.CurveID
PinnedCerts PinnedCertSet
@@ -902,6 +904,7 @@ func (o *Options) processConfigFileLine(k string, v interface{}, errors *[]error
o.TLSTimeout = tc.Timeout
o.TLSMap = tc.Map
o.TLSPinnedCerts = tc.PinnedCerts
o.TLSRateLimit = tc.RateLimit
// Need to keep track of path of the original TLS config
// and certs path for OCSP Stapling monitoring.
@@ -3695,6 +3698,15 @@ func parseTLS(v interface{}, isClientCtx bool) (t *TLSConfigOpts, retErr error)
return nil, &configErr{tk, "error parsing tls config, 'timeout' wrong type"}
}
tc.Timeout = at
case "connection_rate_limit":
at := int64(0)
switch mv := mv.(type) {
case int64:
at = mv
default:
return nil, &configErr{tk, "error parsing tls config, 'connection_rate_limit' wrong type"}
}
tc.RateLimit = at
case "pinned_certs":
ra, ok := mv.([]interface{})
if !ok {

65
server/rate_counter.go Normal file
View File

@@ -0,0 +1,65 @@
// Copyright 2021-2022 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"sync"
"time"
)
type rateCounter struct {
limit int64
count int64
blocked uint64
end time.Time
interval time.Duration
mu sync.Mutex
}
func newRateCounter(limit int64) *rateCounter {
return &rateCounter{
limit: limit,
interval: time.Second,
}
}
func (r *rateCounter) allow() bool {
now := time.Now()
r.mu.Lock()
if now.After(r.end) {
r.count = 0
r.end = now.Add(r.interval)
} else {
r.count++
}
allow := r.count < r.limit
if !allow {
r.blocked++
}
r.mu.Unlock()
return allow
}
func (r *rateCounter) countBlocked() uint64 {
r.mu.Lock()
blocked := r.blocked
r.blocked = 0
r.mu.Unlock()
return blocked
}

View File

@@ -0,0 +1,52 @@
// Copyright 2021-2022 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"testing"
"time"
)
func TestRateCounter(t *testing.T) {
counter := newRateCounter(10)
counter.interval = 100 * time.Millisecond
var i int
for i = 0; i < 10; i++ {
if !counter.allow() {
t.Errorf("counter should allow (iteration %d)", i)
}
}
for i = 0; i < 5; i++ {
if counter.allow() {
t.Errorf("counter should not allow (iteration %d)", i)
}
}
blocked := counter.countBlocked()
if blocked != 5 {
t.Errorf("Expected blocked = 5, got %d", blocked)
}
blocked = counter.countBlocked()
if blocked != 0 {
t.Errorf("Expected blocked = 0, got %d", blocked)
}
time.Sleep(150 * time.Millisecond)
if !counter.allow() {
t.Errorf("Expected true after current time window expired")
}
}

View File

@@ -259,6 +259,8 @@ type Server struct {
rerrMu sync.Mutex
rerrLast time.Time
connRateCounter *rateCounter
// If there is a system account configured, to still support the $G account,
// the server will create a fake user and add it to the list of users.
// Keep track of what that user name is for config reload purposes.
@@ -365,6 +367,10 @@ func NewServer(opts *Options) (*Server, error) {
httpReqStats: make(map[string]uint64), // Used to track HTTP requests
}
if opts.TLSRateLimit > 0 {
s.connRateCounter = newRateCounter(opts.tlsConfigOpts.RateLimit)
}
// Trusted root operator keys.
if !s.processTrustedKeys() {
return nil, fmt.Errorf("Error processing trusted operator keys")
@@ -513,6 +519,23 @@ func NewServer(opts *Options) (*Server, error) {
return s, nil
}
func (s *Server) logRejectedTLSConns() {
defer s.grWG.Done()
t := time.NewTicker(time.Second)
defer t.Stop()
for {
select {
case <-s.quitCh:
return
case <-t.C:
blocked := s.connRateCounter.countBlocked()
if blocked > 0 {
s.Warnf("Rejected %d connections due to TLS rate limiting", blocked)
}
}
}
}
// clusterName returns our cluster name which could be dynamic.
func (s *Server) ClusterName() string {
s.mu.Lock()
@@ -1787,6 +1810,10 @@ func (s *Server) Start() {
s.logPorts()
}
if opts.TLSRateLimit > 0 {
s.startGoRoutine(s.logRejectedTLSConns)
}
// Wait for clients.
s.AcceptLoop(clientListenReady)
}
@@ -2480,6 +2507,13 @@ func (s *Server) createClient(conn net.Conn) *client {
// Check for TLS
if !isClosed && tlsRequired {
if s.connRateCounter != nil && !s.connRateCounter.allow() {
c.mu.Unlock()
c.sendErr("Connection throttling is active. Please try again later.")
c.closeConnection(MaxConnectionsExceeded)
return nil
}
// If we have a prebuffer create a multi-reader.
if len(pre) > 0 {
c.nc = &tlsMixConn{c.nc, bytes.NewBuffer(pre)}

View File

@@ -1852,6 +1852,74 @@ func TestTLSPinnedCertsClient(t *testing.T) {
nc.Close()
}
type captureWarnLogger struct {
dummyLogger
receive chan string
}
func newCaptureWarnLogger() *captureWarnLogger {
return &captureWarnLogger{
receive: make(chan string, 100),
}
}
func (l *captureWarnLogger) Warnf(format string, v ...interface{}) {
l.receive <- fmt.Sprintf(format, v...)
}
func (l *captureWarnLogger) waitFor(expect string, timeout time.Duration) bool {
for {
select {
case msg := <-l.receive:
if strings.Contains(msg, expect) {
return true
}
case <-time.After(timeout):
return false
}
}
}
func TestTLSConnectionRate(t *testing.T) {
config := `
listen: "127.0.0.1:-1"
tls {
cert_file: "./configs/certs/server-cert.pem"
key_file: "./configs/certs/server-key.pem"
connection_rate_limit: 3
}
`
confFileName := createConfFile(t, []byte(config))
defer removeFile(t, confFileName)
srv, _ := RunServerWithConfig(confFileName)
logger := newCaptureWarnLogger()
srv.SetLogger(logger, false, false)
defer srv.Shutdown()
var err error
count := 0
for count < 10 {
var nc *nats.Conn
nc, err = nats.Connect(srv.ClientURL(), nats.RootCAs("./configs/certs/ca.pem"))
if err != nil {
break
}
nc.Close()
count++
}
if count != 3 {
t.Fatalf("Expected 3 connections per second, got %d (%v)", count, err)
}
if !logger.waitFor("connections due to TLS rate limiting", time.Second) {
t.Fatalf("did not log 'TLS rate limiting' warning")
}
}
func TestTLSPinnedCertsRoute(t *testing.T) {
tmplSeed := `
host: localhost