mirror of
https://github.com/gogrlx/nats-server.git
synced 2026-04-02 03:38:42 -07:00
Properly handle and enforce max payload
This commit is contained in:
@@ -39,6 +39,7 @@ type client struct {
|
||||
lang string
|
||||
opts clientOpts
|
||||
nc net.Conn
|
||||
mpay int
|
||||
ncs string
|
||||
bw *bufio.Writer
|
||||
srv *Server
|
||||
@@ -153,9 +154,9 @@ func (c *client) readLoop() {
|
||||
return
|
||||
}
|
||||
if err := c.parse(b[:n]); err != nil {
|
||||
c.Errorf("Error reading from client: %s", err.Error())
|
||||
// Auth was handled inline
|
||||
if err != ErrAuthorization {
|
||||
// handled inline
|
||||
if err != ErrMaxPayload && err != ErrAuthorization {
|
||||
c.Errorf("Error reading from client: %s", err.Error())
|
||||
c.sendErr("Parser Error")
|
||||
c.closeConnection()
|
||||
}
|
||||
@@ -297,10 +298,17 @@ func (c *client) authTimeout() {
|
||||
}
|
||||
|
||||
func (c *client) authViolation() {
|
||||
c.Errorf(ErrAuthorization.Error())
|
||||
c.sendErr("Authorization Violation")
|
||||
c.closeConnection()
|
||||
}
|
||||
|
||||
func (c *client) maxPayloadViolation(sz int) {
|
||||
c.Errorf("%s: %d vs %d", ErrMaxPayload.Error(), sz, c.mpay)
|
||||
c.sendErr("Maximum Payload Violation")
|
||||
c.closeConnection()
|
||||
}
|
||||
|
||||
func (c *client) sendErr(err string) {
|
||||
c.mu.Lock()
|
||||
if c.bw != nil {
|
||||
@@ -430,6 +438,11 @@ func (c *client) processPub(arg []byte) error {
|
||||
if c.pa.size < 0 {
|
||||
return fmt.Errorf("processPub Bad or Missing Size: '%s'", arg)
|
||||
}
|
||||
if c.mpay > 0 && c.pa.size > c.mpay {
|
||||
c.maxPayloadViolation(c.pa.size)
|
||||
return ErrMaxPayload
|
||||
}
|
||||
|
||||
if c.opts.Pedantic && !sublist.IsValidLiteralSubject(c.pa.subject) {
|
||||
c.sendErr("Invalid Subject")
|
||||
}
|
||||
|
||||
@@ -20,8 +20,21 @@ log_file: "/tmp/gnatsd.log"
|
||||
syslog: true
|
||||
remote_syslog: "udp://foo.com:33"
|
||||
|
||||
#pid file
|
||||
# pid file
|
||||
pid_file: "/tmp/gnatsd.pid"
|
||||
|
||||
# prof_port
|
||||
prof_port: 6543
|
||||
|
||||
# max_connections
|
||||
max_connections: 100
|
||||
|
||||
# maximum control line
|
||||
max_control_line: 2048
|
||||
|
||||
# maximum payload
|
||||
max_payload: 65536
|
||||
|
||||
# slow consumer threshold
|
||||
max_pending_size: 10000000
|
||||
|
||||
|
||||
@@ -6,8 +6,11 @@ import "errors"
|
||||
|
||||
var (
|
||||
// ErrConnectionClosed represents error condition on a closed connection.
|
||||
ErrConnectionClosed = errors.New("Connection closed")
|
||||
ErrConnectionClosed = errors.New("Connection Closed")
|
||||
|
||||
// ErrAuthorization represents error condition on failed authorization.
|
||||
ErrAuthorization = errors.New("Authorization Error")
|
||||
|
||||
// ErrMaxPayload represents error condition when the payload is too big.
|
||||
ErrMaxPayload = errors.New("Maximum Payload Exceeded")
|
||||
)
|
||||
|
||||
@@ -34,6 +34,7 @@ type Options struct {
|
||||
AuthTimeout float64 `json:"auth_timeout"`
|
||||
MaxControlLine int `json:"max_control_line"`
|
||||
MaxPayload int `json:"max_payload"`
|
||||
MaxPending int `json:"max_pending_size"`
|
||||
ClusterHost string `json:"addr"`
|
||||
ClusterPort int `json:"port"`
|
||||
ClusterUsername string `json:"-"`
|
||||
@@ -107,6 +108,14 @@ func ProcessConfigFile(configFile string) (*Options, error) {
|
||||
opts.PidFile = v.(string)
|
||||
case "prof_port":
|
||||
opts.ProfPort = int(v.(int64))
|
||||
case "max_control_line":
|
||||
opts.MaxControlLine = int(v.(int64))
|
||||
case "max_payload":
|
||||
opts.MaxPayload = int(v.(int64))
|
||||
case "max_pending_size", "max_pending":
|
||||
opts.MaxPending = int(v.(int64))
|
||||
case "max_connections", "max_conn":
|
||||
opts.MaxConn = int(v.(int64))
|
||||
}
|
||||
}
|
||||
return opts, nil
|
||||
|
||||
@@ -44,20 +44,24 @@ func TestOptions_RandomPort(t *testing.T) {
|
||||
|
||||
func TestConfigFile(t *testing.T) {
|
||||
golden := &Options{
|
||||
Host: "apcera.me",
|
||||
Port: 4242,
|
||||
Username: "derek",
|
||||
Password: "bella",
|
||||
AuthTimeout: 1.0,
|
||||
Debug: false,
|
||||
Trace: true,
|
||||
Logtime: false,
|
||||
HTTPPort: 8222,
|
||||
LogFile: "/tmp/gnatsd.log",
|
||||
PidFile: "/tmp/gnatsd.pid",
|
||||
ProfPort: 6543,
|
||||
Syslog: true,
|
||||
RemoteSyslog: "udp://foo.com:33",
|
||||
Host: "apcera.me",
|
||||
Port: 4242,
|
||||
Username: "derek",
|
||||
Password: "bella",
|
||||
AuthTimeout: 1.0,
|
||||
Debug: false,
|
||||
Trace: true,
|
||||
Logtime: false,
|
||||
HTTPPort: 8222,
|
||||
LogFile: "/tmp/gnatsd.log",
|
||||
PidFile: "/tmp/gnatsd.pid",
|
||||
ProfPort: 6543,
|
||||
Syslog: true,
|
||||
RemoteSyslog: "udp://foo.com:33",
|
||||
MaxControlLine: 2048,
|
||||
MaxPayload: 65536,
|
||||
MaxConn: 100,
|
||||
MaxPending: 10000000,
|
||||
}
|
||||
|
||||
opts, err := ProcessConfigFile("./configs/test.conf")
|
||||
@@ -73,20 +77,24 @@ func TestConfigFile(t *testing.T) {
|
||||
|
||||
func TestMergeOverrides(t *testing.T) {
|
||||
golden := &Options{
|
||||
Host: "apcera.me",
|
||||
Port: 2222,
|
||||
Username: "derek",
|
||||
Password: "spooky",
|
||||
AuthTimeout: 1.0,
|
||||
Debug: true,
|
||||
Trace: true,
|
||||
Logtime: false,
|
||||
HTTPPort: DEFAULT_HTTP_PORT,
|
||||
LogFile: "/tmp/gnatsd.log",
|
||||
PidFile: "/tmp/gnatsd.pid",
|
||||
ProfPort: 6789,
|
||||
Syslog: true,
|
||||
RemoteSyslog: "udp://foo.com:33",
|
||||
Host: "apcera.me",
|
||||
Port: 2222,
|
||||
Username: "derek",
|
||||
Password: "spooky",
|
||||
AuthTimeout: 1.0,
|
||||
Debug: true,
|
||||
Trace: true,
|
||||
Logtime: false,
|
||||
HTTPPort: DEFAULT_HTTP_PORT,
|
||||
LogFile: "/tmp/gnatsd.log",
|
||||
PidFile: "/tmp/gnatsd.pid",
|
||||
ProfPort: 6789,
|
||||
Syslog: true,
|
||||
RemoteSyslog: "udp://foo.com:33",
|
||||
MaxControlLine: 2048,
|
||||
MaxPayload: 65536,
|
||||
MaxConn: 100,
|
||||
MaxPending: 10000000,
|
||||
}
|
||||
fopts, err := ProcessConfigFile("./configs/test.conf")
|
||||
if err != nil {
|
||||
|
||||
@@ -223,6 +223,15 @@ func TestParsePubArg(t *testing.T) {
|
||||
testPubArg(c, t)
|
||||
}
|
||||
|
||||
func TestParsePubBadSize(t *testing.T) {
|
||||
c := dummyClient()
|
||||
// Setup localized max payload
|
||||
c.mpay = 32768
|
||||
if err := c.processPub([]byte("foo 2222222222222222\r")); err == nil {
|
||||
t.Fatalf("Expected parse error for size too large")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMsg(t *testing.T) {
|
||||
c := dummyClient()
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ func New(opts *Options) *Server {
|
||||
Port: opts.Port,
|
||||
AuthRequired: false,
|
||||
SslRequired: false,
|
||||
MaxPayload: MAX_PAYLOAD_SIZE,
|
||||
MaxPayload: opts.MaxPayload,
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
@@ -380,7 +380,7 @@ func (s *Server) StartHTTPMonitoring() {
|
||||
}
|
||||
|
||||
func (s *Server) createClient(conn net.Conn) *client {
|
||||
c := &client{srv: s, nc: conn, opts: defaultOpts}
|
||||
c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: s.info.MaxPayload}
|
||||
|
||||
// Grab lock
|
||||
c.mu.Lock()
|
||||
|
||||
@@ -60,12 +60,16 @@ func benchPub(b *testing.B, subject, payload string) {
|
||||
|
||||
var ch = []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@$#%^&*()")
|
||||
|
||||
func sizedString(sz int) string {
|
||||
func sizedBytes(sz int) []byte {
|
||||
b := make([]byte, sz)
|
||||
for i := range b {
|
||||
b[i] = ch[rand.Intn(len(ch))]
|
||||
}
|
||||
return string(b)
|
||||
return b
|
||||
}
|
||||
|
||||
func sizedString(sz int) string {
|
||||
return string(sizedBytes(sz))
|
||||
}
|
||||
|
||||
func Benchmark___PubNo_Payload(b *testing.B) {
|
||||
|
||||
@@ -12,14 +12,8 @@ import (
|
||||
)
|
||||
|
||||
func runServers(t *testing.T) (srvA, srvB *server.Server, optsA, optsB *server.Options) {
|
||||
optsA, _ = server.ProcessConfigFile("./configs/srv_a.conf")
|
||||
optsB, _ = server.ProcessConfigFile("./configs/srv_b.conf")
|
||||
|
||||
optsA.NoSigs, optsA.NoLog = true, true
|
||||
optsB.NoSigs, optsB.NoLog = true, true
|
||||
|
||||
srvA = RunServer(optsA)
|
||||
srvB = RunServer(optsB)
|
||||
srvA, optsA = RunServerWithConfig("./configs/srv_a.conf")
|
||||
srvB, optsB = RunServerWithConfig("./configs/srv_b.conf")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
9
test/configs/override.conf
Normal file
9
test/configs/override.conf
Normal file
@@ -0,0 +1,9 @@
|
||||
# Copyright 2015 Apcera Inc. All rights reserved.
|
||||
|
||||
# Config file to test overrides to client
|
||||
|
||||
port: 4224
|
||||
|
||||
# maximum payload
|
||||
max_payload: 2222
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
func TestSimpleGoServerShutdown(t *testing.T) {
|
||||
base := runtime.NumGoroutine()
|
||||
s := runDefaultServer()
|
||||
s := RunDefaultServer()
|
||||
s.Shutdown()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
delta := (runtime.NumGoroutine() - base)
|
||||
@@ -21,7 +21,7 @@ func TestSimpleGoServerShutdown(t *testing.T) {
|
||||
|
||||
func TestGoServerShutdownWithClients(t *testing.T) {
|
||||
base := runtime.NumGoroutine()
|
||||
s := runDefaultServer()
|
||||
s := RunDefaultServer()
|
||||
for i := 0; i < 50; i++ {
|
||||
createClientConn(t, "localhost", 4222)
|
||||
}
|
||||
@@ -37,7 +37,7 @@ func TestGoServerShutdownWithClients(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGoServerMultiShutdown(t *testing.T) {
|
||||
s := runDefaultServer()
|
||||
s := RunDefaultServer()
|
||||
s.Shutdown()
|
||||
s.Shutdown()
|
||||
}
|
||||
|
||||
36
test/maxpayload_test.go
Normal file
36
test/maxpayload_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
// Copyright 2015 Apcera Inc. All rights reserved.
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/nats"
|
||||
)
|
||||
|
||||
func TestMaxPayload(t *testing.T) {
|
||||
srv, opts := RunServerWithConfig("./configs/override.conf")
|
||||
defer srv.Shutdown()
|
||||
|
||||
nc, err := nats.Connect(fmt.Sprintf("nats://%s:%d/", opts.Host, opts.Port))
|
||||
if err != nil {
|
||||
t.Fatalf("Could not connect to server: %v", err)
|
||||
}
|
||||
defer nc.Close()
|
||||
|
||||
big := sizedBytes(4 * 1024 * 1024)
|
||||
nc.Publish("foo", big)
|
||||
err = nc.FlushTimeout(1 * time.Second)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected an error from flush")
|
||||
}
|
||||
if strings.Contains(err.Error(), "Maximum Payload Violation") != true {
|
||||
t.Fatalf("Received wrong error message (%v)\n", err)
|
||||
}
|
||||
if !nc.IsClosed() {
|
||||
t.Fatalf("Expected connection to be closed")
|
||||
}
|
||||
}
|
||||
21
test/opts_test.go
Normal file
21
test/opts_test.go
Normal file
@@ -0,0 +1,21 @@
|
||||
// Copyright 2015 Apcera Inc. All rights reserved.
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestServerConfig(t *testing.T) {
|
||||
srv, opts := RunServerWithConfig("./configs/override.conf")
|
||||
defer srv.Shutdown()
|
||||
|
||||
c := createClientConn(t, opts.Host, opts.Port)
|
||||
defer c.Close()
|
||||
|
||||
sinfo := checkInfoMsg(t, c)
|
||||
if sinfo.MaxPayload != opts.MaxPayload {
|
||||
t.Fatalf("Expected max_payload from server, got %d vs %d",
|
||||
opts.MaxPayload, sinfo.MaxPayload)
|
||||
}
|
||||
}
|
||||
15
test/test.go
15
test/test.go
@@ -38,7 +38,7 @@ var DefaultTestOptions = server.Options{
|
||||
NoSigs: true,
|
||||
}
|
||||
|
||||
func runDefaultServer() *server.Server {
|
||||
func RunDefaultServer() *server.Server {
|
||||
return RunServer(&DefaultTestOptions)
|
||||
}
|
||||
|
||||
@@ -47,6 +47,16 @@ func RunServer(opts *server.Options) *server.Server {
|
||||
return RunServerWithAuth(opts, nil)
|
||||
}
|
||||
|
||||
func RunServerWithConfig(configFile string) (srv *server.Server, opts *server.Options) {
|
||||
opts, err := server.ProcessConfigFile(configFile)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Error processing configuration file: %v", err))
|
||||
}
|
||||
opts.NoSigs, opts.NoLog = true, true
|
||||
srv = RunServer(opts)
|
||||
return
|
||||
}
|
||||
|
||||
// New Go Routine based server with auth
|
||||
func RunServerWithAuth(opts *server.Options, auth server.Auth) *server.Server {
|
||||
if opts == nil {
|
||||
@@ -193,7 +203,7 @@ func checkSocket(t tLogger, addr string, wait time.Duration) {
|
||||
t.Fatalf("Failed to connect to the socket: %q", addr)
|
||||
}
|
||||
|
||||
func checkInfoMsg(t tLogger, c net.Conn) {
|
||||
func checkInfoMsg(t tLogger, c net.Conn) server.Info {
|
||||
buf := expectResult(t, c, infoRe)
|
||||
js := infoRe.FindAllSubmatch(buf, 1)[0][1]
|
||||
var sinfo server.Info
|
||||
@@ -201,6 +211,7 @@ func checkInfoMsg(t tLogger, c net.Conn) {
|
||||
if err != nil {
|
||||
stackFatalf(t, "Could not unmarshal INFO json: %v\n", err)
|
||||
}
|
||||
return sinfo
|
||||
}
|
||||
|
||||
func doConnect(t tLogger, c net.Conn, verbose, pedantic, ssl bool) {
|
||||
|
||||
Reference in New Issue
Block a user