Properly handle and enforce max payload

This commit is contained in:
Derek Collison
2015-08-05 22:05:58 -07:00
parent 9a60bc1364
commit 075529e2fe
14 changed files with 180 additions and 50 deletions

View File

@@ -39,6 +39,7 @@ type client struct {
lang string
opts clientOpts
nc net.Conn
mpay int
ncs string
bw *bufio.Writer
srv *Server
@@ -153,9 +154,9 @@ func (c *client) readLoop() {
return
}
if err := c.parse(b[:n]); err != nil {
c.Errorf("Error reading from client: %s", err.Error())
// Auth was handled inline
if err != ErrAuthorization {
// handled inline
if err != ErrMaxPayload && err != ErrAuthorization {
c.Errorf("Error reading from client: %s", err.Error())
c.sendErr("Parser Error")
c.closeConnection()
}
@@ -297,10 +298,17 @@ func (c *client) authTimeout() {
}
func (c *client) authViolation() {
c.Errorf(ErrAuthorization.Error())
c.sendErr("Authorization Violation")
c.closeConnection()
}
func (c *client) maxPayloadViolation(sz int) {
c.Errorf("%s: %d vs %d", ErrMaxPayload.Error(), sz, c.mpay)
c.sendErr("Maximum Payload Violation")
c.closeConnection()
}
func (c *client) sendErr(err string) {
c.mu.Lock()
if c.bw != nil {
@@ -430,6 +438,11 @@ func (c *client) processPub(arg []byte) error {
if c.pa.size < 0 {
return fmt.Errorf("processPub Bad or Missing Size: '%s'", arg)
}
if c.mpay > 0 && c.pa.size > c.mpay {
c.maxPayloadViolation(c.pa.size)
return ErrMaxPayload
}
if c.opts.Pedantic && !sublist.IsValidLiteralSubject(c.pa.subject) {
c.sendErr("Invalid Subject")
}

View File

@@ -20,8 +20,21 @@ log_file: "/tmp/gnatsd.log"
syslog: true
remote_syslog: "udp://foo.com:33"
#pid file
# pid file
pid_file: "/tmp/gnatsd.pid"
# prof_port
prof_port: 6543
# max_connections
max_connections: 100
# maximum control line
max_control_line: 2048
# maximum payload
max_payload: 65536
# slow consumer threshold
max_pending_size: 10000000

View File

@@ -6,8 +6,11 @@ import "errors"
var (
// ErrConnectionClosed represents error condition on a closed connection.
ErrConnectionClosed = errors.New("Connection closed")
ErrConnectionClosed = errors.New("Connection Closed")
// ErrAuthorization represents error condition on failed authorization.
ErrAuthorization = errors.New("Authorization Error")
// ErrMaxPayload represents error condition when the payload is too big.
ErrMaxPayload = errors.New("Maximum Payload Exceeded")
)

View File

@@ -34,6 +34,7 @@ type Options struct {
AuthTimeout float64 `json:"auth_timeout"`
MaxControlLine int `json:"max_control_line"`
MaxPayload int `json:"max_payload"`
MaxPending int `json:"max_pending_size"`
ClusterHost string `json:"addr"`
ClusterPort int `json:"port"`
ClusterUsername string `json:"-"`
@@ -107,6 +108,14 @@ func ProcessConfigFile(configFile string) (*Options, error) {
opts.PidFile = v.(string)
case "prof_port":
opts.ProfPort = int(v.(int64))
case "max_control_line":
opts.MaxControlLine = int(v.(int64))
case "max_payload":
opts.MaxPayload = int(v.(int64))
case "max_pending_size", "max_pending":
opts.MaxPending = int(v.(int64))
case "max_connections", "max_conn":
opts.MaxConn = int(v.(int64))
}
}
return opts, nil

View File

@@ -44,20 +44,24 @@ func TestOptions_RandomPort(t *testing.T) {
func TestConfigFile(t *testing.T) {
golden := &Options{
Host: "apcera.me",
Port: 4242,
Username: "derek",
Password: "bella",
AuthTimeout: 1.0,
Debug: false,
Trace: true,
Logtime: false,
HTTPPort: 8222,
LogFile: "/tmp/gnatsd.log",
PidFile: "/tmp/gnatsd.pid",
ProfPort: 6543,
Syslog: true,
RemoteSyslog: "udp://foo.com:33",
Host: "apcera.me",
Port: 4242,
Username: "derek",
Password: "bella",
AuthTimeout: 1.0,
Debug: false,
Trace: true,
Logtime: false,
HTTPPort: 8222,
LogFile: "/tmp/gnatsd.log",
PidFile: "/tmp/gnatsd.pid",
ProfPort: 6543,
Syslog: true,
RemoteSyslog: "udp://foo.com:33",
MaxControlLine: 2048,
MaxPayload: 65536,
MaxConn: 100,
MaxPending: 10000000,
}
opts, err := ProcessConfigFile("./configs/test.conf")
@@ -73,20 +77,24 @@ func TestConfigFile(t *testing.T) {
func TestMergeOverrides(t *testing.T) {
golden := &Options{
Host: "apcera.me",
Port: 2222,
Username: "derek",
Password: "spooky",
AuthTimeout: 1.0,
Debug: true,
Trace: true,
Logtime: false,
HTTPPort: DEFAULT_HTTP_PORT,
LogFile: "/tmp/gnatsd.log",
PidFile: "/tmp/gnatsd.pid",
ProfPort: 6789,
Syslog: true,
RemoteSyslog: "udp://foo.com:33",
Host: "apcera.me",
Port: 2222,
Username: "derek",
Password: "spooky",
AuthTimeout: 1.0,
Debug: true,
Trace: true,
Logtime: false,
HTTPPort: DEFAULT_HTTP_PORT,
LogFile: "/tmp/gnatsd.log",
PidFile: "/tmp/gnatsd.pid",
ProfPort: 6789,
Syslog: true,
RemoteSyslog: "udp://foo.com:33",
MaxControlLine: 2048,
MaxPayload: 65536,
MaxConn: 100,
MaxPending: 10000000,
}
fopts, err := ProcessConfigFile("./configs/test.conf")
if err != nil {

View File

@@ -223,6 +223,15 @@ func TestParsePubArg(t *testing.T) {
testPubArg(c, t)
}
func TestParsePubBadSize(t *testing.T) {
c := dummyClient()
// Setup localized max payload
c.mpay = 32768
if err := c.processPub([]byte("foo 2222222222222222\r")); err == nil {
t.Fatalf("Expected parse error for size too large")
}
}
func TestParseMsg(t *testing.T) {
c := dummyClient()

View File

@@ -81,7 +81,7 @@ func New(opts *Options) *Server {
Port: opts.Port,
AuthRequired: false,
SslRequired: false,
MaxPayload: MAX_PAYLOAD_SIZE,
MaxPayload: opts.MaxPayload,
}
s := &Server{
@@ -380,7 +380,7 @@ func (s *Server) StartHTTPMonitoring() {
}
func (s *Server) createClient(conn net.Conn) *client {
c := &client{srv: s, nc: conn, opts: defaultOpts}
c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: s.info.MaxPayload}
// Grab lock
c.mu.Lock()

View File

@@ -60,12 +60,16 @@ func benchPub(b *testing.B, subject, payload string) {
var ch = []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@$#%^&*()")
func sizedString(sz int) string {
func sizedBytes(sz int) []byte {
b := make([]byte, sz)
for i := range b {
b[i] = ch[rand.Intn(len(ch))]
}
return string(b)
return b
}
func sizedString(sz int) string {
return string(sizedBytes(sz))
}
func Benchmark___PubNo_Payload(b *testing.B) {

View File

@@ -12,14 +12,8 @@ import (
)
func runServers(t *testing.T) (srvA, srvB *server.Server, optsA, optsB *server.Options) {
optsA, _ = server.ProcessConfigFile("./configs/srv_a.conf")
optsB, _ = server.ProcessConfigFile("./configs/srv_b.conf")
optsA.NoSigs, optsA.NoLog = true, true
optsB.NoSigs, optsB.NoLog = true, true
srvA = RunServer(optsA)
srvB = RunServer(optsB)
srvA, optsA = RunServerWithConfig("./configs/srv_a.conf")
srvB, optsB = RunServerWithConfig("./configs/srv_b.conf")
return
}

View File

@@ -0,0 +1,9 @@
# Copyright 2015 Apcera Inc. All rights reserved.
# Config file to test overrides to client
port: 4224
# maximum payload
max_payload: 2222

View File

@@ -10,7 +10,7 @@ import (
func TestSimpleGoServerShutdown(t *testing.T) {
base := runtime.NumGoroutine()
s := runDefaultServer()
s := RunDefaultServer()
s.Shutdown()
time.Sleep(10 * time.Millisecond)
delta := (runtime.NumGoroutine() - base)
@@ -21,7 +21,7 @@ func TestSimpleGoServerShutdown(t *testing.T) {
func TestGoServerShutdownWithClients(t *testing.T) {
base := runtime.NumGoroutine()
s := runDefaultServer()
s := RunDefaultServer()
for i := 0; i < 50; i++ {
createClientConn(t, "localhost", 4222)
}
@@ -37,7 +37,7 @@ func TestGoServerShutdownWithClients(t *testing.T) {
}
func TestGoServerMultiShutdown(t *testing.T) {
s := runDefaultServer()
s := RunDefaultServer()
s.Shutdown()
s.Shutdown()
}

36
test/maxpayload_test.go Normal file
View File

@@ -0,0 +1,36 @@
// Copyright 2015 Apcera Inc. All rights reserved.
package test
import (
"fmt"
"strings"
"testing"
"time"
"github.com/nats-io/nats"
)
func TestMaxPayload(t *testing.T) {
srv, opts := RunServerWithConfig("./configs/override.conf")
defer srv.Shutdown()
nc, err := nats.Connect(fmt.Sprintf("nats://%s:%d/", opts.Host, opts.Port))
if err != nil {
t.Fatalf("Could not connect to server: %v", err)
}
defer nc.Close()
big := sizedBytes(4 * 1024 * 1024)
nc.Publish("foo", big)
err = nc.FlushTimeout(1 * time.Second)
if err == nil {
t.Fatalf("Expected an error from flush")
}
if strings.Contains(err.Error(), "Maximum Payload Violation") != true {
t.Fatalf("Received wrong error message (%v)\n", err)
}
if !nc.IsClosed() {
t.Fatalf("Expected connection to be closed")
}
}

21
test/opts_test.go Normal file
View File

@@ -0,0 +1,21 @@
// Copyright 2015 Apcera Inc. All rights reserved.
package test
import (
"testing"
)
func TestServerConfig(t *testing.T) {
srv, opts := RunServerWithConfig("./configs/override.conf")
defer srv.Shutdown()
c := createClientConn(t, opts.Host, opts.Port)
defer c.Close()
sinfo := checkInfoMsg(t, c)
if sinfo.MaxPayload != opts.MaxPayload {
t.Fatalf("Expected max_payload from server, got %d vs %d",
opts.MaxPayload, sinfo.MaxPayload)
}
}

View File

@@ -38,7 +38,7 @@ var DefaultTestOptions = server.Options{
NoSigs: true,
}
func runDefaultServer() *server.Server {
func RunDefaultServer() *server.Server {
return RunServer(&DefaultTestOptions)
}
@@ -47,6 +47,16 @@ func RunServer(opts *server.Options) *server.Server {
return RunServerWithAuth(opts, nil)
}
func RunServerWithConfig(configFile string) (srv *server.Server, opts *server.Options) {
opts, err := server.ProcessConfigFile(configFile)
if err != nil {
panic(fmt.Sprintf("Error processing configuration file: %v", err))
}
opts.NoSigs, opts.NoLog = true, true
srv = RunServer(opts)
return
}
// New Go Routine based server with auth
func RunServerWithAuth(opts *server.Options, auth server.Auth) *server.Server {
if opts == nil {
@@ -193,7 +203,7 @@ func checkSocket(t tLogger, addr string, wait time.Duration) {
t.Fatalf("Failed to connect to the socket: %q", addr)
}
func checkInfoMsg(t tLogger, c net.Conn) {
func checkInfoMsg(t tLogger, c net.Conn) server.Info {
buf := expectResult(t, c, infoRe)
js := infoRe.FindAllSubmatch(buf, 1)[0][1]
var sinfo server.Info
@@ -201,6 +211,7 @@ func checkInfoMsg(t tLogger, c net.Conn) {
if err != nil {
stackFatalf(t, "Could not unmarshal INFO json: %v\n", err)
}
return sinfo
}
func doConnect(t tLogger, c net.Conn, verbose, pedantic, ssl bool) {