From fb79a982ef0aea0b4b30a039218d46ae43dc1557 Mon Sep 17 00:00:00 2001 From: Derek Collison Date: Tue, 18 Dec 2012 16:56:49 -0800 Subject: [PATCH] Added ability to run server as a Go routine --- gnatsd.go | 2 +- server/client.go | 4 ++++ server/client_test.go | 2 +- server/server.go | 54 +++++++++++++++++++++++++++++++++++++------ test/gosrv_test.go | 38 ++++++++++++++++++++++++++++++ test/test.go | 43 ++++++++++++++++++++++++++++++++++ 6 files changed, 134 insertions(+), 9 deletions(-) create mode 100644 test/gosrv_test.go diff --git a/gnatsd.go b/gnatsd.go index 110d26b5..c2b0eb57 100644 --- a/gnatsd.go +++ b/gnatsd.go @@ -46,7 +46,7 @@ func main() { // TBD: Parse config if given - s := server.New(opts) + s := server.New(&opts) s.AcceptLoop() } diff --git a/server/client.go b/server/client.go index 1d7bf28f..9a444ada 100644 --- a/server/client.go +++ b/server/client.go @@ -504,6 +504,10 @@ func (c *client) closeConnection() { c.mu.Unlock() if c.srv != nil { + // Unregister + c.srv.removeClient(c) + + // Remove subscriptions. for _, s := range subs { if sub, ok := s.(*subscription); ok { c.srv.sl.Remove(sub.subject, sub) diff --git a/server/client_test.go b/server/client_test.go index 1c1662ec..3d16f553 100644 --- a/server/client_test.go +++ b/server/client_test.go @@ -35,7 +35,7 @@ func createClientAsync(ch chan *client, s *Server, cli net.Conn) { func rawSetup() (*Server, *client, *bufio.Reader, string) { cli, srv := net.Pipe() cr := bufio.NewReaderSize(cli, defaultBufSize) - s := New(Options{}) + s := New(&Options{}) ch := make(chan *client) createClientAsync(ch, s, srv) l, _ := cr.ReadString('\n') diff --git a/server/server.go b/server/server.go index 464093e9..54c517de 100644 --- a/server/server.go +++ b/server/server.go @@ -22,6 +22,7 @@ type Options struct { Trace bool Debug bool NoLog bool + NoSigs bool Logtime bool MaxConn int Username string @@ -44,9 +45,12 @@ type Server struct { infoJson []byte sl *sublist.Sublist gcid uint64 - opts Options + opts *Options trace bool debug bool + running bool + listener net.Listener + clients map[uint64]*client } func processOptions(opt *Options) { @@ -62,8 +66,8 @@ func processOptions(opt *Options) { } } -func New(opts Options) *Server { - processOptions(&opts) +func New(opts *Options) *Server { + processOptions(opts) info := Info{ Id: genId(), Version: VERSION, @@ -87,6 +91,9 @@ func New(opts Options) *Server { // Setup logging with flags s.LogInit() + // For tracing clients + s.clients = make(map[uint64]*client) + // Generate the info json b, err := json.Marshal(s.info) if err != nil { @@ -101,17 +108,38 @@ func New(opts Options) *Server { // Signal Handling func (s *Server) handleSignals() { + if s.opts.NoSigs { + return + } c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt) - go func(){ + go func() { for sig := range c { Debugf("Trapped Signal; %v", sig) + // FIXME, trip running? Log("Server Exiting..") os.Exit(0) } }() } +// Shutdown will shutdown the server instance by kicking out the AcceptLoop +// and closing all associated clients. +func (s *Server) Shutdown() { + s.running = false + + // Close client connections + for _, c := range s.clients { + c.closeConnection() + } + + // Kick AcceptLoop() + if s.listener != nil { + s.listener.Close() + s.listener = nil + } +} + func (s *Server) AcceptLoop() { Logf("Starting nats-server version %s on port %d", VERSION, s.opts.Port) @@ -121,8 +149,13 @@ func (s *Server) AcceptLoop() { Fatalf("Error listening on port: %d - %v", s.opts.Port, e) return } - for { - conn, err := l.Accept() + + // Setup state that can enable shutdown + s.listener = l + s.running = true + + for s.running { + conn, err := s.listener.Accept() if err != nil { if ne, ok := err.(net.Error); ok && ne.Temporary() { Logf("Accept error: %v", err) @@ -131,6 +164,8 @@ func (s *Server) AcceptLoop() { } s.createClient(conn) } + Log("Server Exiting..") + os.Exit(0) } func clientConnStr(conn net.Conn) interface{} { @@ -145,7 +180,6 @@ func (s *Server) createClient(conn net.Conn) *client { c := &client{srv: s, conn: conn, opts: defaultOpts} c.cid = atomic.AddUint64(&s.gcid, 1) - // FIXME, should write be double? c.bw = bufio.NewWriterSize(c.conn, defaultBufSize) c.subs = hashmap.New() @@ -164,6 +198,8 @@ func (s *Server) createClient(conn net.Conn) *client { if s.info.AuthRequired { c.atmr = time.AfterFunc(AUTH_TIMEOUT, func() { c.authViolation() }) } + // Register with the server. + s.clients[c.cid] = c return c } @@ -187,3 +223,7 @@ func (s *Server) checkAuth(c *client) bool { } return true } + +func (s *Server) removeClient(c *client) { + delete(s.clients, c.cid) +} \ No newline at end of file diff --git a/test/gosrv_test.go b/test/gosrv_test.go new file mode 100644 index 00000000..b1e05c2c --- /dev/null +++ b/test/gosrv_test.go @@ -0,0 +1,38 @@ +// Copyright 2012 Apcera Inc. All rights reserved. + +package test + +import ( + "fmt" + "testing" + "runtime" + "time" +) + +func TestSimpleGoServerShutdown(t *testing.T) { + s := runDefaultServer() + base := runtime.NumGoroutine() + s.Shutdown() + time.Sleep(10 * time.Millisecond) + delta := (runtime.NumGoroutine() - base) + if delta > 0 { + t.Fatalf("%d Go routines still exist post Shutdown()", delta) + } +} + +func TestGoServerShutdownWithClients(t *testing.T) { + fmt.Printf("before: %d go routines\n", runtime.NumGoroutine()) + s := runDefaultServer() + for i := 0 ; i < 10 ; i++ { + createClientConn(t, "localhost", 4222) + } + base := runtime.NumGoroutine() + s.Shutdown() + time.Sleep(10 * time.Millisecond) + delta := (runtime.NumGoroutine() - base) + if delta > 0 { + t.Fatalf("%d Go routines still exist post Shutdown()", delta) + } +} + + diff --git a/test/test.go b/test/test.go index 9c4ef215..8056cfa4 100644 --- a/test/test.go +++ b/test/test.go @@ -28,6 +28,49 @@ type tLogger interface { Errorf(format string, args ...interface{}) } +var defaultServerOptions = server.Options{ + Host: "localhost", + Port: 4222, + Trace: false, + Debug: false, + NoLog: true, + NoSigs: true, +} + +func runDefaultServer() *server.Server { + return runServer(&defaultServerOptions) +} + +// New Go Routine based server +func runServer(opts *server.Options) *server.Server { + if opts == nil { + opts = &defaultServerOptions + } + s := server.New(opts) + if s == nil { + panic("No nats server object returned.") + } + + go s.AcceptLoop() + + // Make sure we are running and can bind before returning. + addr := fmt.Sprintf("%s:%d", opts.Host, opts.Port) + end := time.Now().Add(time.Second * 10) + for time.Now().Before(end) { + conn, err := net.Dial("tcp", addr) + if err != nil { + time.Sleep(50 * time.Millisecond) + // Retry + continue + } + conn.Close() + return s + } + + panic("Unable to start NATs") + return nil +} + func startServer(t tLogger, port int, other string) *natsServer { var s natsServer args := fmt.Sprintf("-p %d %s", port, other)