mirror of
https://github.com/gogrlx/nats-server.git
synced 2026-04-02 03:38:42 -07:00
[ADDED] TLS connection rate limiter
This commit is contained in:
@@ -249,6 +249,7 @@ type Options struct {
|
||||
TLSCaCert string `json:"-"`
|
||||
TLSConfig *tls.Config `json:"-"`
|
||||
TLSPinnedCerts PinnedCertSet `json:"-"`
|
||||
TLSRateLimit int64 `json:"-"`
|
||||
AllowNonTLS bool `json:"-"`
|
||||
WriteDeadline time.Duration `json:"-"`
|
||||
MaxClosedClients int `json:"-"`
|
||||
@@ -513,6 +514,7 @@ type TLSConfigOpts struct {
|
||||
Map bool
|
||||
TLSCheckKnownURLs bool
|
||||
Timeout float64
|
||||
RateLimit int64
|
||||
Ciphers []uint16
|
||||
CurvePreferences []tls.CurveID
|
||||
PinnedCerts PinnedCertSet
|
||||
@@ -902,6 +904,7 @@ func (o *Options) processConfigFileLine(k string, v interface{}, errors *[]error
|
||||
o.TLSTimeout = tc.Timeout
|
||||
o.TLSMap = tc.Map
|
||||
o.TLSPinnedCerts = tc.PinnedCerts
|
||||
o.TLSRateLimit = tc.RateLimit
|
||||
|
||||
// Need to keep track of path of the original TLS config
|
||||
// and certs path for OCSP Stapling monitoring.
|
||||
@@ -3695,6 +3698,15 @@ func parseTLS(v interface{}, isClientCtx bool) (t *TLSConfigOpts, retErr error)
|
||||
return nil, &configErr{tk, "error parsing tls config, 'timeout' wrong type"}
|
||||
}
|
||||
tc.Timeout = at
|
||||
case "connection_rate_limit":
|
||||
at := int64(0)
|
||||
switch mv := mv.(type) {
|
||||
case int64:
|
||||
at = mv
|
||||
default:
|
||||
return nil, &configErr{tk, "error parsing tls config, 'connection_rate_limit' wrong type"}
|
||||
}
|
||||
tc.RateLimit = at
|
||||
case "pinned_certs":
|
||||
ra, ok := mv.([]interface{})
|
||||
if !ok {
|
||||
|
||||
65
server/rate_counter.go
Normal file
65
server/rate_counter.go
Normal file
@@ -0,0 +1,65 @@
|
||||
// Copyright 2021-2022 The NATS Authors
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type rateCounter struct {
|
||||
limit int64
|
||||
count int64
|
||||
blocked uint64
|
||||
end time.Time
|
||||
interval time.Duration
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newRateCounter(limit int64) *rateCounter {
|
||||
return &rateCounter{
|
||||
limit: limit,
|
||||
interval: time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *rateCounter) allow() bool {
|
||||
now := time.Now()
|
||||
|
||||
r.mu.Lock()
|
||||
|
||||
if now.After(r.end) {
|
||||
r.count = 0
|
||||
r.end = now.Add(r.interval)
|
||||
} else {
|
||||
r.count++
|
||||
}
|
||||
allow := r.count < r.limit
|
||||
if !allow {
|
||||
r.blocked++
|
||||
}
|
||||
|
||||
r.mu.Unlock()
|
||||
|
||||
return allow
|
||||
}
|
||||
|
||||
func (r *rateCounter) countBlocked() uint64 {
|
||||
r.mu.Lock()
|
||||
blocked := r.blocked
|
||||
r.blocked = 0
|
||||
r.mu.Unlock()
|
||||
|
||||
return blocked
|
||||
}
|
||||
52
server/rate_counter_test.go
Normal file
52
server/rate_counter_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright 2021-2022 The NATS Authors
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRateCounter(t *testing.T) {
|
||||
counter := newRateCounter(10)
|
||||
counter.interval = 100 * time.Millisecond
|
||||
|
||||
var i int
|
||||
for i = 0; i < 10; i++ {
|
||||
if !counter.allow() {
|
||||
t.Errorf("counter should allow (iteration %d)", i)
|
||||
}
|
||||
}
|
||||
for i = 0; i < 5; i++ {
|
||||
if counter.allow() {
|
||||
t.Errorf("counter should not allow (iteration %d)", i)
|
||||
}
|
||||
}
|
||||
|
||||
blocked := counter.countBlocked()
|
||||
if blocked != 5 {
|
||||
t.Errorf("Expected blocked = 5, got %d", blocked)
|
||||
}
|
||||
|
||||
blocked = counter.countBlocked()
|
||||
if blocked != 0 {
|
||||
t.Errorf("Expected blocked = 0, got %d", blocked)
|
||||
}
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
if !counter.allow() {
|
||||
t.Errorf("Expected true after current time window expired")
|
||||
}
|
||||
}
|
||||
@@ -259,6 +259,8 @@ type Server struct {
|
||||
rerrMu sync.Mutex
|
||||
rerrLast time.Time
|
||||
|
||||
connRateCounter *rateCounter
|
||||
|
||||
// If there is a system account configured, to still support the $G account,
|
||||
// the server will create a fake user and add it to the list of users.
|
||||
// Keep track of what that user name is for config reload purposes.
|
||||
@@ -365,6 +367,10 @@ func NewServer(opts *Options) (*Server, error) {
|
||||
httpReqStats: make(map[string]uint64), // Used to track HTTP requests
|
||||
}
|
||||
|
||||
if opts.TLSRateLimit > 0 {
|
||||
s.connRateCounter = newRateCounter(opts.tlsConfigOpts.RateLimit)
|
||||
}
|
||||
|
||||
// Trusted root operator keys.
|
||||
if !s.processTrustedKeys() {
|
||||
return nil, fmt.Errorf("Error processing trusted operator keys")
|
||||
@@ -513,6 +519,23 @@ func NewServer(opts *Options) (*Server, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Server) logRejectedTLSConns() {
|
||||
defer s.grWG.Done()
|
||||
t := time.NewTicker(time.Second)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.quitCh:
|
||||
return
|
||||
case <-t.C:
|
||||
blocked := s.connRateCounter.countBlocked()
|
||||
if blocked > 0 {
|
||||
s.Warnf("Rejected %d connections due to TLS rate limiting", blocked)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// clusterName returns our cluster name which could be dynamic.
|
||||
func (s *Server) ClusterName() string {
|
||||
s.mu.Lock()
|
||||
@@ -1787,6 +1810,10 @@ func (s *Server) Start() {
|
||||
s.logPorts()
|
||||
}
|
||||
|
||||
if opts.TLSRateLimit > 0 {
|
||||
s.startGoRoutine(s.logRejectedTLSConns)
|
||||
}
|
||||
|
||||
// Wait for clients.
|
||||
s.AcceptLoop(clientListenReady)
|
||||
}
|
||||
@@ -2480,6 +2507,13 @@ func (s *Server) createClient(conn net.Conn) *client {
|
||||
|
||||
// Check for TLS
|
||||
if !isClosed && tlsRequired {
|
||||
if s.connRateCounter != nil && !s.connRateCounter.allow() {
|
||||
c.mu.Unlock()
|
||||
c.sendErr("Connection throttling is active. Please try again later.")
|
||||
c.closeConnection(MaxConnectionsExceeded)
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we have a prebuffer create a multi-reader.
|
||||
if len(pre) > 0 {
|
||||
c.nc = &tlsMixConn{c.nc, bytes.NewBuffer(pre)}
|
||||
|
||||
@@ -1852,6 +1852,74 @@ func TestTLSPinnedCertsClient(t *testing.T) {
|
||||
nc.Close()
|
||||
}
|
||||
|
||||
type captureWarnLogger struct {
|
||||
dummyLogger
|
||||
receive chan string
|
||||
}
|
||||
|
||||
func newCaptureWarnLogger() *captureWarnLogger {
|
||||
return &captureWarnLogger{
|
||||
receive: make(chan string, 100),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *captureWarnLogger) Warnf(format string, v ...interface{}) {
|
||||
l.receive <- fmt.Sprintf(format, v...)
|
||||
}
|
||||
|
||||
func (l *captureWarnLogger) waitFor(expect string, timeout time.Duration) bool {
|
||||
for {
|
||||
select {
|
||||
case msg := <-l.receive:
|
||||
if strings.Contains(msg, expect) {
|
||||
return true
|
||||
}
|
||||
case <-time.After(timeout):
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSConnectionRate(t *testing.T) {
|
||||
config := `
|
||||
listen: "127.0.0.1:-1"
|
||||
tls {
|
||||
cert_file: "./configs/certs/server-cert.pem"
|
||||
key_file: "./configs/certs/server-key.pem"
|
||||
connection_rate_limit: 3
|
||||
}
|
||||
`
|
||||
|
||||
confFileName := createConfFile(t, []byte(config))
|
||||
defer removeFile(t, confFileName)
|
||||
|
||||
srv, _ := RunServerWithConfig(confFileName)
|
||||
logger := newCaptureWarnLogger()
|
||||
srv.SetLogger(logger, false, false)
|
||||
defer srv.Shutdown()
|
||||
|
||||
var err error
|
||||
count := 0
|
||||
for count < 10 {
|
||||
var nc *nats.Conn
|
||||
nc, err = nats.Connect(srv.ClientURL(), nats.RootCAs("./configs/certs/ca.pem"))
|
||||
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
nc.Close()
|
||||
count++
|
||||
}
|
||||
|
||||
if count != 3 {
|
||||
t.Fatalf("Expected 3 connections per second, got %d (%v)", count, err)
|
||||
}
|
||||
|
||||
if !logger.waitFor("connections due to TLS rate limiting", time.Second) {
|
||||
t.Fatalf("did not log 'TLS rate limiting' warning")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSPinnedCertsRoute(t *testing.T) {
|
||||
tmplSeed := `
|
||||
host: localhost
|
||||
|
||||
Reference in New Issue
Block a user