From b898b5561a2b54968cd2eae63929bb1feddb142b Mon Sep 17 00:00:00 2001 From: Tyler Treat Date: Thu, 22 Jun 2017 14:48:40 -0500 Subject: [PATCH] Add reload support for pid_file, max_control_line, and max_payload --- server/client.go | 11 ++-- server/configs/reload/max_payload.conf | 6 ++ server/configs/reload/reload.conf | 9 ++- server/reload.go | 57 ++++++++++++++++++ server/reload_test.go | 80 +++++++++++++++++++++++++- server/server.go | 19 +++--- 6 files changed, 163 insertions(+), 19 deletions(-) create mode 100644 server/configs/reload/max_payload.conf diff --git a/server/client.go b/server/client.go index b5a64012..9ccb9d6c 100644 --- a/server/client.go +++ b/server/client.go @@ -524,8 +524,8 @@ func (c *client) maxConnExceeded() { c.closeConnection() } -func (c *client) maxPayloadViolation(sz int) { - c.Errorf("%s: %d vs %d", ErrMaxPayload.Error(), sz, c.mpay) +func (c *client) maxPayloadViolation(sz, max int) { + c.Errorf("%s: %d vs %d", ErrMaxPayload.Error(), sz, max) c.sendErr("Maximum Payload Violation") c.closeConnection() } @@ -712,8 +712,11 @@ func (c *client) processPub(arg []byte) error { if c.pa.size < 0 { return fmt.Errorf("processPub Bad or Missing Size: '%s'", arg) } - if c.mpay > 0 && c.pa.size > c.mpay { - c.maxPayloadViolation(c.pa.size) + c.mu.Lock() + maxPayload := c.mpay + c.mu.Unlock() + if maxPayload > 0 && c.pa.size > maxPayload { + c.maxPayloadViolation(c.pa.size, maxPayload) return ErrMaxPayload } diff --git a/server/configs/reload/max_payload.conf b/server/configs/reload/max_payload.conf new file mode 100644 index 00000000..d1eed424 --- /dev/null +++ b/server/configs/reload/max_payload.conf @@ -0,0 +1,6 @@ +# Copyright 2017 Apcera Inc. All rights reserved. + +listen: localhost:-1 +log_file: "/tmp/gnatsd.log" + +max_payload: 1 diff --git a/server/configs/reload/reload.conf b/server/configs/reload/reload.conf index e2aba027..e80836c2 100644 --- a/server/configs/reload/reload.conf +++ b/server/configs/reload/reload.conf @@ -1,11 +1,14 @@ # Copyright 2017 Apcera Inc. All rights reserved. # logging options -debug: true # enable on reload -trace: true # enable on reload -logtime: true # enable on reload +debug: true # enable on reload +trace: true # enable on reload +logtime: true # enable on reload log_file: "/tmp/gnatsd-2.log" # change on reload +pid_file: "/tmp/gnatsd.pid" # change on reload +max_control_line: 512 # change on reload + # Enable TLS on reload tls { cert_file: "../test/configs/certs/server-cert.pem" diff --git a/server/reload.go b/server/reload.go index d02a8d87..c9dfa936 100644 --- a/server/reload.go +++ b/server/reload.go @@ -295,6 +295,57 @@ func (m *maxConnOption) Apply(server *Server) { server.Noticef("Reloaded: max_connections = %v", m.newValue) } +// pidFileOption implements the option interface for the `pid_file` setting. +type pidFileOption struct { + noopOption + newValue string +} + +// Apply the setting by logging the pid to the new file. +func (p *pidFileOption) Apply(server *Server) { + if p.newValue == "" { + return + } + if err := server.logPid(); err != nil { + server.Errorf("Failed to write pidfile: %v", err) + } + server.Noticef("Reloaded: pid_file = %v", p.newValue) +} + +// maxControlLineOption implements the option interface for the +// `max_control_line` setting. +type maxControlLineOption struct { + noopOption + newValue int +} + +// Apply is a no-op because the max control line will be reloaded after options +// are applied +func (m *maxControlLineOption) Apply(server *Server) { + server.Noticef("Reloaded: max_control_line = %d", m.newValue) +} + +// maxPayloadOption implements the option interface for the `max_payload` +// setting. +type maxPayloadOption struct { + noopOption + newValue int +} + +// Apply the setting by updating the server info and each client. +func (m *maxPayloadOption) Apply(server *Server) { + server.mu.Lock() + server.info.MaxPayload = m.newValue + server.generateServerInfoJSON() + for _, client := range server.clients { + client.mu.Lock() + client.mpay = m.newValue + client.mu.Unlock() + } + server.mu.Unlock() + server.Noticef("Reloaded: max_payload = %d", m.newValue) +} + // Reload reads the current configuration file and applies any supported // changes. This returns an error if the server was not started with a config // file or an option which doesn't support hot-swapping was changed. @@ -384,6 +435,12 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) { diffOpts = append(diffOpts, &routesOption{add: add, remove: remove}) case "maxconn": diffOpts = append(diffOpts, &maxConnOption{newValue: newValue.(int)}) + case "pidfile": + diffOpts = append(diffOpts, &pidFileOption{newValue: newValue.(string)}) + case "maxcontrolline": + diffOpts = append(diffOpts, &maxControlLineOption{newValue: newValue.(int)}) + case "maxpayload": + diffOpts = append(diffOpts, &maxPayloadOption{newValue: newValue.(int)}) case "nolog": // Ignore NoLog option since it's not parsed and only used in // testing. diff --git a/server/reload_test.go b/server/reload_test.go index 62a81d67..6b74ad28 100644 --- a/server/reload_test.go +++ b/server/reload_test.go @@ -185,7 +185,7 @@ func TestConfigReload(t *testing.T) { t.Fatal("Expected Logtime to be true") } if updated.LogFile != "/tmp/gnatsd-2.log" { - t.Fatalf("LogFile is incorrect.\nexpected /tmp/gnatsd-2.log\ngot: %s", updated.LogFile) + t.Fatalf("LogFile is incorrect.\nexpected: /tmp/gnatsd-2.log\ngot: %s", updated.LogFile) } if updated.TLSConfig == nil { t.Fatal("Expected TLSConfig to be non-nil") @@ -211,6 +211,12 @@ func TestConfigReload(t *testing.T) { if !updated.Cluster.NoAdvertise { t.Fatal("Expected NoAdvertise to be true") } + if updated.PidFile != "/tmp/gnatsd.pid" { + t.Fatalf("PidFile is incorrect.\nexpected: /tmp/gnatsd.pid\ngot: %s", updated.PidFile) + } + if updated.MaxControlLine != 512 { + t.Fatalf("MaxControlLine is incorrect.\nexpected: 512\ngot: %d", updated.MaxControlLine) + } } // Ensure Reload supports TLS config changes. Test this by starting a server @@ -1395,7 +1401,7 @@ func TestConfigReloadMaxConnections(t *testing.T) { select { case <-closed: case <-time.After(2 * time.Second): - t.Fatal("Expected error") + t.Fatal("Expected to be disconnected") } if numClients := server.NumClients(); numClients != 1 { @@ -1409,6 +1415,76 @@ func TestConfigReloadMaxConnections(t *testing.T) { } } +// Ensure reload supports changing the max payload size. Test this by starting +// a server with the default size limit, ensuring publishes work, reloading +// with a restrictive limit, and ensuring publishing an oversized message fails +// and disconnects the client. +func TestConfigReloadMaxPayload(t *testing.T) { + server, opts, config := runServerWithSymlinkConfig(t, "tmp.conf", "./configs/reload/basic.conf") + defer os.Remove(config) + defer server.Shutdown() + + addr := fmt.Sprintf("nats://%s:%d", opts.Host, server.Addr().(*net.TCPAddr).Port) + nc, err := nats.Connect(addr) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer nc.Close() + closed := make(chan struct{}) + nc.SetDisconnectHandler(func(*nats.Conn) { + closed <- struct{}{} + }) + + conn, err := nats.Connect(addr) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer conn.Close() + sub, err := conn.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error subscribing: %v", err) + } + conn.Flush() + + // Ensure we can publish as a sanity check. + if err := nc.Publish("foo", []byte("hello")); err != nil { + t.Fatalf("Error publishing: %v", err) + } + nc.Flush() + _, err = sub.NextMsg(2 * time.Second) + if err != nil { + t.Fatalf("Error receiving message: %v", err) + } + + // Set max payload to one. + if err := os.Remove(config); err != nil { + t.Fatalf("Error deleting symlink: %v", err) + } + if err := os.Symlink("./configs/reload/max_payload.conf", config); err != nil { + t.Fatalf("Error creating symlink: %v (ensure you have privileges)", err) + } + if err := server.Reload(); err != nil { + t.Fatalf("Error reloading config: %v", err) + } + + // Ensure oversized messages don't get delivered and the client is + // disconnected. + if err := nc.Publish("foo", []byte("hello")); err != nil { + t.Fatalf("Error publishing: %v", err) + } + nc.Flush() + _, err = sub.NextMsg(20 * time.Millisecond) + if err != nats.ErrTimeout { + t.Fatalf("Expected ErrTimeout, got: %v", err) + } + + select { + case <-closed: + case <-time.After(2 * time.Second): + t.Fatal("Expected to be disconnected") + } +} + func runServerWithSymlinkConfig(t *testing.T, symlinkName, configName string) (*Server, *Options, string) { opts, config := newOptionsWithSymlinkConfig(t, symlinkName, configName) opts.NoLog = true diff --git a/server/server.go b/server/server.go index cec881e3..11d4dea7 100644 --- a/server/server.go +++ b/server/server.go @@ -224,12 +224,9 @@ func (s *Server) isRunning() bool { return s.running } -func (s *Server) logPid() { +func (s *Server) logPid() error { pidStr := strconv.Itoa(os.Getpid()) - err := ioutil.WriteFile(s.getOpts().PidFile, []byte(pidStr), 0660) - if err != nil { - PrintAndDie(fmt.Sprintf("Could not write pidfile: %v\n", err)) - } + return ioutil.WriteFile(s.getOpts().PidFile, []byte(pidStr), 0660) } // Start up the server, this will block. @@ -252,7 +249,9 @@ func (s *Server) Start() { // Log the pid to a file if opts.PidFile != _EMPTY_ { - s.logPid() + if err := s.logPid(); err != nil { + PrintAndDie(fmt.Sprintf("Could not write pidfile: %v\n", err)) + } } // Start monitoring if needed @@ -638,7 +637,10 @@ func (s *Server) HTTPHandler() http.Handler { } func (s *Server) createClient(conn net.Conn) *client { - c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: s.info.MaxPayload, start: time.Now()} + // Snapshot server options. + opts := s.getOpts() + + c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: opts.MaxPayload, start: time.Now()} // Grab JSON info string s.mu.Lock() @@ -673,9 +675,6 @@ func (s *Server) createClient(conn net.Conn) *client { return c } - // Snapshot server options. - opts := s.getOpts() - // If there is a max connections specified, check that adding // this new client would not push us over the max if opts.MaxConn > 0 && len(s.clients) >= opts.MaxConn {