// Copyright 2012-2015 Apcera Inc. All rights reserved. package server import ( "bufio" "crypto/tls" "encoding/json" "fmt" "io/ioutil" "net" "net/http" "os" "os/signal" "runtime" "strconv" "sync" "time" // Allow dynamic profiling. _ "net/http/pprof" "github.com/nats-io/gnatsd/sublist" ) // Info is the information sent to clients to help them understand information // about this server. type Info struct { ID string `json:"server_id"` Version string `json:"version"` GoVersion string `json:"go"` Host string `json:"host"` Port int `json:"port"` AuthRequired bool `json:"auth_required"` SSLRequired bool `json:"ssl_required"` // DEPRECATED: ssl json used for older clients TLSRequired bool `json:"tls_required"` TLSVerify bool `json:"tls_verify"` MaxPayload int `json:"max_payload"` Routes []RemoteInfo `json:"routes,omitempty"` } // Server is our main struct. type Server struct { gcid uint64 grid uint64 stats mu sync.Mutex info Info infoJSON []byte sl *sublist.Sublist opts *Options auth Auth trace bool debug bool running bool listener net.Listener clients map[uint64]*client routes map[uint64]*client remotes map[string]*client totalClients uint64 done chan bool start time.Time http net.Listener httpReqStats map[string]uint64 routeListener net.Listener routeInfo Info rcQuit chan bool } // Make sure all are 64bits for atomic use type stats struct { inMsgs int64 outMsgs int64 inBytes int64 outBytes int64 slowConsumers int64 } // New will setup a new server struct after parsing the options. func New(opts *Options) *Server { processOptions(opts) // Process TLS options, including whether we require client certificates. tlsReq := opts.TLSConfig != nil verify := (tlsReq == true && opts.TLSConfig.ClientAuth == tls.RequireAnyClientCert) info := Info{ ID: genID(), Version: VERSION, GoVersion: runtime.Version(), Host: opts.Host, Port: opts.Port, AuthRequired: false, TLSRequired: tlsReq, SSLRequired: tlsReq, TLSVerify: verify, MaxPayload: opts.MaxPayload, } s := &Server{ info: info, sl: sublist.New(), opts: opts, debug: opts.Debug, trace: opts.Trace, done: make(chan bool, 1), start: time.Now(), } s.mu.Lock() defer s.mu.Unlock() // For tracking clients s.clients = make(map[uint64]*client) // For tracking routes and their remote ids s.routes = make(map[uint64]*client) s.remotes = make(map[string]*client) // Used to kick out all of the route // connect Go routines. s.rcQuit = make(chan bool) s.generateServerInfoJSON() s.handleSignals() return s } // Sets the authentication method func (s *Server) SetAuthMethod(authMethod Auth) { s.mu.Lock() defer s.mu.Unlock() s.info.AuthRequired = true s.auth = authMethod s.generateServerInfoJSON() } func (s *Server) generateServerInfoJSON() { // Generate the info json b, err := json.Marshal(s.info) if err != nil { Fatalf("Error marshalling INFO JSON: %+v\n", err) } s.infoJSON = []byte(fmt.Sprintf("INFO %s %s", b, CR_LF)) } // PrintAndDie is exported for access in other packages. func PrintAndDie(msg string) { fmt.Fprintf(os.Stderr, "%s\n", msg) os.Exit(1) } // PrintServerAndExit will print our version and exit. func PrintServerAndExit() { fmt.Printf("gnatsd version %s\n", VERSION) os.Exit(0) } // Signal Handling func (s *Server) handleSignals() { if s.opts.NoSigs { return } c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt) go func() { for sig := range c { Debugf("Trapped Signal; %v", sig) // FIXME, trip running? Noticef("Server Exiting..") os.Exit(0) } }() } // Protected check on running state func (s *Server) isRunning() bool { s.mu.Lock() defer s.mu.Unlock() return s.running } func (s *Server) logPid() { pidStr := strconv.Itoa(os.Getpid()) err := ioutil.WriteFile(s.opts.PidFile, []byte(pidStr), 0660) if err != nil { PrintAndDie(fmt.Sprintf("Could not write pidfile: %v\n", err)) } } // Start up the server, this will block. // Start via a Go routine if needed. func (s *Server) Start() { Noticef("Starting gnatsd version %s", VERSION) s.running = true // Log the pid to a file if s.opts.PidFile != _EMPTY_ { s.logPid() } // Start up the http server if needed. if s.opts.HTTPPort != 0 { s.StartHTTPMonitoring() } // Start up the https server if needed. if s.opts.HTTPSPort != 0 { if s.opts.TLSConfig == nil { Fatalf("TLS cert and key required for HTTPS") } s.StartHTTPSMonitoring() } // Start up routing as well if needed. if s.opts.ClusterPort != 0 { s.StartRouting() } // Pprof http endpoint for the profiler. if s.opts.ProfPort != 0 { s.StartProfiler() } // Wait for clients. s.AcceptLoop() } // Shutdown will shutdown the server instance by kicking out the AcceptLoop // and closing all associated clients. func (s *Server) Shutdown() { s.mu.Lock() // Prevent issues with multiple calls. if !s.running { s.mu.Unlock() return } s.running = false conns := make(map[uint64]*client) // Copy off the clients for i, c := range s.clients { conns[i] = c } // Copy off the routes for i, r := range s.routes { conns[i] = r } // Number of done channel responses we expect. doneExpected := 0 // Kick client AcceptLoop() if s.listener != nil { doneExpected++ s.listener.Close() s.listener = nil } // Kick route AcceptLoop() if s.routeListener != nil { doneExpected++ s.routeListener.Close() s.routeListener = nil } // Kick HTTP monitoring if its running if s.http != nil { doneExpected++ s.http.Close() s.http = nil } // Release the solicited routes connect go routines. close(s.rcQuit) s.mu.Unlock() // Close client and route connections for _, c := range conns { c.closeConnection() } // Block until the accept loops exit for doneExpected > 0 { <-s.done doneExpected-- } } // AcceptLoop is exported for easier testing. func (s *Server) AcceptLoop() { hp := net.JoinHostPort(s.opts.Host, strconv.Itoa(s.opts.Port)) Noticef("Listening for client connections on %s", hp) l, e := net.Listen("tcp", hp) if e != nil { Fatalf("Error listening on port: %s, %q", hp, e) return } // Alert of TLS enabled. if s.opts.TLSConfig != nil { Noticef("TLS required for client connections") } Noticef("gnatsd is ready") // Setup state that can enable shutdown s.mu.Lock() s.listener = l // If server was started with RANDOM_PORT (-1), opts.Port would be equal // to 0 at the beginning this function. So we need to get the actual port if s.opts.Port == 0 { // Write resolved port back to options. _, port, err := net.SplitHostPort(l.Addr().String()) if err != nil { Fatalf("Error parsing server address (%s): %s", l.Addr().String(), e) s.mu.Unlock() return } portNum, err := strconv.Atoi(port) if err != nil { Fatalf("Error parsing server address (%s): %s", l.Addr().String(), e) s.mu.Unlock() return } s.opts.Port = portNum } s.mu.Unlock() tmpDelay := ACCEPT_MIN_SLEEP for s.isRunning() { conn, err := l.Accept() if err != nil { if ne, ok := err.(net.Error); ok && ne.Temporary() { Debugf("Temporary Client Accept Error(%v), sleeping %dms", ne, tmpDelay/time.Millisecond) time.Sleep(tmpDelay) tmpDelay *= 2 if tmpDelay > ACCEPT_MAX_SLEEP { tmpDelay = ACCEPT_MAX_SLEEP } } else if s.isRunning() { Noticef("Accept error: %v", err) } continue } tmpDelay = ACCEPT_MIN_SLEEP go s.createClient(conn) } Noticef("Server Exiting..") s.done <- true } // StartProfiler is called to enable dynamic profiling. func (s *Server) StartProfiler() { Noticef("Starting profiling on http port %d", s.opts.ProfPort) hp := net.JoinHostPort(s.opts.Host, strconv.Itoa(s.opts.ProfPort)) go func() { err := http.ListenAndServe(hp, nil) if err != nil { Fatalf("error starting monitor server: %s", err) } }() } // StartHTTPMonitoring will enable the HTTP monitoring port. func (s *Server) StartHTTPMonitoring() { s.startMonitoring(false) } // StartHTTPMonitoring will enable the HTTPS monitoring port. func (s *Server) StartHTTPSMonitoring() { s.startMonitoring(true) } // HTTP endpoints const ( RootPath = "/" VarzPath = "/varz" ConnzPath = "/connz" RoutezPath = "/routez" SubszPath = "/subsz" ) // Start the monitoring server func (s *Server) startMonitoring(secure bool) { // Used to track HTTP requests s.httpReqStats = map[string]uint64{ RootPath: 0, VarzPath: 0, ConnzPath: 0, RoutezPath: 0, SubszPath: 0, } var hp string var err error if secure { hp = net.JoinHostPort(s.opts.Host, strconv.Itoa(s.opts.HTTPSPort)) Noticef("Starting https monitor on %s", hp) config := *s.opts.TLSConfig config.ClientAuth = tls.NoClientCert s.http, err = tls.Listen("tcp", hp, &config) } else { hp = net.JoinHostPort(s.opts.Host, strconv.Itoa(s.opts.HTTPPort)) Noticef("Starting http monitor on %s", hp) s.http, err = net.Listen("tcp", hp) } if err != nil { Fatalf("Can't listen to the monitor port: %v", err) } mux := http.NewServeMux() // Root mux.HandleFunc(RootPath, s.HandleRoot) // Varz mux.HandleFunc(VarzPath, s.HandleVarz) // Connz mux.HandleFunc(ConnzPath, s.HandleConnz) // Routez mux.HandleFunc(RoutezPath, s.HandleRoutez) // Subz mux.HandleFunc(SubszPath, s.HandleSubsz) // Subz alias for backwards compatibility mux.HandleFunc("/subscriptionsz", s.HandleSubsz) srv := &http.Server{ Addr: hp, Handler: mux, ReadTimeout: 2 * time.Second, WriteTimeout: 2 * time.Second, MaxHeaderBytes: 1 << 20, } go func() { srv.Serve(s.http) srv.Handler = nil s.done <- true }() } func (s *Server) createClient(conn net.Conn) *client { c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: s.info.MaxPayload, start: time.Now()} // Grab JSON info string s.mu.Lock() info := s.infoJSON authRequired := s.info.AuthRequired tlsRequired := s.info.TLSRequired s.totalClients++ s.mu.Unlock() // Grab lock c.mu.Lock() // Initialize c.initClient(tlsRequired) c.Debugf("Client connection created") // Check for Auth if authRequired { ttl := secondsToDuration(s.opts.AuthTimeout) c.setAuthTimer(ttl) } // Send our information. s.sendInfo(c, info) // Unlock to register c.mu.Unlock() // Register with the server. s.mu.Lock() s.clients[c.cid] = c s.mu.Unlock() // Check for TLS if tlsRequired { // Re-Grab lock c.mu.Lock() c.Debugf("Starting TLS client connection handshake") c.nc = tls.Server(c.nc, s.opts.TLSConfig) conn := c.nc.(*tls.Conn) // Setup the timeout ttl := secondsToDuration(s.opts.TLSTimeout) time.AfterFunc(ttl, func() { tlsTimeout(c, conn) }) conn.SetReadDeadline(time.Now().Add(ttl)) // Force handshake c.mu.Unlock() if err := conn.Handshake(); err != nil { c.Debugf("TLS handshake error: %v", err) c.sendErr("Secure Connection - TLS Required") c.closeConnection() return nil } // Reset the read deadline conn.SetReadDeadline(time.Time{}) // Re-Grab lock c.mu.Lock() // Rewrap bw c.bw = bufio.NewWriterSize(c.nc, s.opts.BufSize) // Do final client initialization // Set the Ping timer c.setPingTimer() // Spin up the read loop. go c.readLoop() c.Debugf("TLS handshake complete") cs := conn.ConnectionState() c.Debugf("TLS version %s, cipher suite %s", tlsVersion(cs.Version), tlsCipher(cs.CipherSuite)) c.mu.Unlock() } return c } // Handle closing down a connection when the handshake has timedout. func tlsTimeout(c *client, conn *tls.Conn) { c.mu.Lock() nc := c.nc c.mu.Unlock() // Check if already closed if nc == nil { return } cs := conn.ConnectionState() if cs.HandshakeComplete == false { c.Debugf("TLS handshake timeout") c.sendErr("Secure Connection - TLS Required") c.closeConnection() } } // Seems silly we have to write these func tlsVersion(ver uint16) string { switch ver { case tls.VersionTLS10: return "1.0" case tls.VersionTLS11: return "1.1" case tls.VersionTLS12: return "1.2" } return fmt.Sprintf("Unknown [%x]", ver) } // We use hex here so we don't need multiple versions func tlsCipher(cs uint16) string { switch cs { case 0x0005: return "TLS_RSA_WITH_RC4_128_SHA" case 0x000a: return "TLS_RSA_WITH_3DES_EDE_CBC_SHA" case 0x002f: return "TLS_RSA_WITH_AES_128_CBC_SHA" case 0x0035: return "TLS_RSA_WITH_AES_256_CBC_SHA" case 0xc007: return "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA" case 0xc009: return "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA" case 0xc00a: return "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA" case 0xc011: return "TLS_ECDHE_RSA_WITH_RC4_128_SHA" case 0xc012: return "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA" case 0xc013: return "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA" case 0xc014: return "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA" case 0xc02f: return "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" case 0xc02b: return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" case 0xc030: return "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384" case 0xc02c: return "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" } return fmt.Sprintf("Unknown [%x]", cs) } // Assume the lock is held upon entry. func (s *Server) sendInfo(c *client, info []byte) { c.nc.Write(info) } func (s *Server) checkClientAuth(c *client) bool { if s.auth == nil { return true } return s.auth.Check(c) } func (s *Server) checkRouterAuth(c *client) bool { if !s.routeInfo.AuthRequired { return true } if s.opts.ClusterUsername != c.opts.Username || s.opts.ClusterPassword != c.opts.Password { return false } return true } // Check auth and return boolean indicating if client is ok func (s *Server) checkAuth(c *client) bool { switch c.typ { case CLIENT: return s.checkClientAuth(c) case ROUTER: return s.checkRouterAuth(c) default: return false } } // Remove a client or route from our internal accounting. func (s *Server) removeClient(c *client) { c.mu.Lock() cid := c.cid typ := c.typ c.mu.Unlock() s.mu.Lock() switch typ { case CLIENT: delete(s.clients, cid) case ROUTER: delete(s.routes, cid) if c.route != nil { rc, ok := s.remotes[c.route.remoteID] // Only delete it if it is us.. if ok && c == rc { delete(s.remotes, c.route.remoteID) } } } s.mu.Unlock() } ///////////////////////////////////////////////////////////////// // These are some helpers for accounting in functional tests. ///////////////////////////////////////////////////////////////// // NumRoutes will report the number of registered routes. func (s *Server) NumRoutes() int { s.mu.Lock() defer s.mu.Unlock() return len(s.routes) } // NumRemotes will report number of registered remotes. func (s *Server) NumRemotes() int { s.mu.Lock() defer s.mu.Unlock() return len(s.remotes) } // NumClients will report the number of registered clients. func (s *Server) NumClients() int { s.mu.Lock() defer s.mu.Unlock() return len(s.clients) } // NumSubscriptions will report how many subscriptions are active. func (s *Server) NumSubscriptions() uint32 { s.mu.Lock() defer s.mu.Unlock() stats := s.sl.Stats() return stats.NumSubs } // Addr will return the net.Addr object for the current listener. func (s *Server) Addr() net.Addr { s.mu.Lock() defer s.mu.Unlock() if s.listener == nil { return nil } return s.listener.Addr() } // GetListenEndpoint will return a string of the form host:port suitable for // a connect. Will return nil if the server is not ready to accept client // connections. func (s *Server) GetListenEndpoint() string { s.mu.Lock() defer s.mu.Unlock() // Wait for the listener to be set, see note about RANDOM_PORT below if s.listener == nil { return "" } host := s.opts.Host // On windows, a connect with host "0.0.0.0" (or "::") will fail. // We replace it with "localhost" when that's the case. if host == "0.0.0.0" || host == "::" || host == "[::]" { host = "localhost" } // Return the opts's Host and Port. Note that the Port may be set // when the listener is started, due to the use of RANDOM_PORT return net.JoinHostPort(host, strconv.Itoa(s.opts.Port)) } // Server's ID func (s *Server) Id() string { s.mu.Lock() defer s.mu.Unlock() return s.info.ID }