diff --git a/auth/plain.go b/auth/plain.go new file mode 100644 index 00000000..49e26db2 --- /dev/null +++ b/auth/plain.go @@ -0,0 +1,19 @@ +package auth + +import ( + "github.com/apcera/gnatsd/server" +) + +type Plain struct { + Username string + Password string +} + +func (p *Plain) Check(c server.ClientAuth) bool { + opts := c.GetOpts() + if p.Username != opts.Username || p.Password != opts.Password { + return false + } + + return true +} diff --git a/auth/token.go b/auth/token.go new file mode 100644 index 00000000..62a4d227 --- /dev/null +++ b/auth/token.go @@ -0,0 +1,18 @@ +package auth + +import ( + "github.com/apcera/gnatsd/server" +) + +type Token struct { + Token string +} + +func (p *Token) Check(c server.ClientAuth) bool { + opts := c.GetOpts() + if p.Token != opts.Authorization { + return false + } + + return true +} diff --git a/gnatsd.go b/gnatsd.go index 31f96fa4..03d1769c 100644 --- a/gnatsd.go +++ b/gnatsd.go @@ -6,6 +6,7 @@ import ( "flag" "strings" + "github.com/apcera/gnatsd/auth" "github.com/apcera/gnatsd/logger" "github.com/apcera/gnatsd/server" ) @@ -94,25 +95,45 @@ func main() { // Create the server with appropriate options. s := server.New(&opts) - // Builds and set the logger based on the flags - s.SetLogger(buildLogger(&opts), opts.Debug, opts.Trace) + // Configure the authentication mechanism + configureAuth(s, &opts) + + // Configure the logger based on the flags + configureLogger(s, &opts) // Start things up. Block here until done. s.Start() } -func buildLogger(opts *server.Options) server.Logger { - if opts.LogFile != "" { - return logger.NewFileLogger(opts.LogFile, opts.Logtime, opts.Debug, opts.Trace) - } +func configureAuth(s *server.Server, opts *server.Options) { + if opts.Username != "" { + auth := &auth.Plain{ + Username: opts.Username, + Password: opts.Password, + } - if opts.RemoteSyslog != "" { - return logger.NewRemoteSysLogger(opts.RemoteSyslog, opts.Debug, opts.Trace) - } + s.SetAuthMethod(auth) + } else if opts.Authorization != "" { + auth := &auth.Token{ + Token: opts.Authorization, + } - if opts.Syslog { - return logger.NewSysLogger(opts.Debug, opts.Trace) + s.SetAuthMethod(auth) } - - return logger.NewStdLogger(opts.Logtime, opts.Debug, opts.Trace, true) +} + +func configureLogger(s *server.Server, opts *server.Options) { + var log server.Logger + + if opts.LogFile != "" { + log = logger.NewFileLogger(opts.LogFile, opts.Logtime, opts.Debug, opts.Trace) + } else if opts.RemoteSyslog != "" { + log = logger.NewRemoteSysLogger(opts.RemoteSyslog, opts.Debug, opts.Trace) + } else if opts.Syslog { + log = logger.NewSysLogger(opts.Debug, opts.Trace) + } else { + log = logger.NewStdLogger(opts.Logtime, opts.Debug, opts.Trace, true) + } + + s.SetLogger(log, opts.Debug, opts.Trace) } diff --git a/server/auth.go b/server/auth.go new file mode 100644 index 00000000..f80f5319 --- /dev/null +++ b/server/auth.go @@ -0,0 +1,11 @@ +// Copyright 2012-2014 Apcera Inc. All rights reserved. + +package server + +type Auth interface { + Check(c ClientAuth) bool +} + +type ClientAuth interface { + GetOpts() *clientOpts +} diff --git a/server/client.go b/server/client.go index 30f418ea..c2dfe384 100644 --- a/server/client.go +++ b/server/client.go @@ -69,6 +69,10 @@ func (c *client) String() (id string) { return id } +func (c *client) GetOpts() *clientOpts { + return &c.opts +} + type subscription struct { client *client subject []byte @@ -223,7 +227,7 @@ func (c *client) processRouteInfo(info *Info) { } } -// Process the information messages from clients and other routes. +// Process the information messages from Clients and other routes. func (c *client) processInfo(arg []byte) error { info := Info{} if err := json.Unmarshal(arg, &info); err != nil { diff --git a/server/client_test.go b/server/client_test.go index 4e3f4f6a..46b95413 100644 --- a/server/client_test.go +++ b/server/client_test.go @@ -22,6 +22,12 @@ type serverInfo struct { MaxPayload int64 `json:"max_payload"` } +type mockAuth struct{} + +func (m *mockAuth) Check(c ClientAuth) bool { + return true +} + func createClientAsync(ch chan *client, s *Server, cli net.Conn) { go func() { c := s.createClient(cli) @@ -38,12 +44,17 @@ var defaultServerOptions = Options{ NoSigs: true, } -func rawSetup(serverOption Options) (*Server, *client, *bufio.Reader, string) { +func rawSetup(serverOptions Options) (*Server, *client, *bufio.Reader, string) { cli, srv := net.Pipe() cr := bufio.NewReaderSize(cli, defaultBufSize) - s := New(&serverOption) + s := New(&serverOptions) + if serverOptions.Authorization != "" { + s.SetAuthMethod(&mockAuth{}) + } + ch := make(chan *client) createClientAsync(ch, s, srv) + l, _ := cr.ReadString('\n') // Grab client @@ -519,7 +530,8 @@ func TestAuthorizationTimeout(t *testing.T) { serverOptions := defaultServerOptions serverOptions.Authorization = "my_token" serverOptions.AuthTimeout = 1 - _, _, cr, _ := rawSetup(serverOptions) + s, _, cr, _ := rawSetup(serverOptions) + s.SetAuthMethod(&mockAuth{}) time.Sleep(secondsToDuration(serverOptions.AuthTimeout)) l, err := cr.ReadString('\n') diff --git a/server/log.go b/server/log.go index 3260c55c..e1dc1266 100644 --- a/server/log.go +++ b/server/log.go @@ -80,5 +80,6 @@ func executeLogCall(f func(logger Logger, format string, v ...interface{}), form if log.logger == nil { return } + f(log.logger, format, args...) } diff --git a/server/server.go b/server/server.go index 4675b11d..ee371546 100644 --- a/server/server.go +++ b/server/server.go @@ -40,6 +40,7 @@ type Server struct { sl *sublist.Sublist gcid uint64 opts *Options + auth Auth trace bool debug bool running bool @@ -78,10 +79,6 @@ func New(opts *Options) *Server { SslRequired: false, MaxPayload: MAX_PAYLOAD_SIZE, } - // Check for Auth items - if opts.Username != "" || opts.Authorization != "" { - info.AuthRequired = true - } s := &Server{ info: info, @@ -106,17 +103,27 @@ func New(opts *Options) *Server { // Used to kick out all of the route // connect Go routines. s.rcQuit = make(chan bool) + s.generateServerInfoJSON() s.handleSignals() + return s +} + +// Sets the authentication method +func (s *Server) SetAuthMethod(authMethod Auth) { + s.info.AuthRequired = true + s.auth = authMethod + + s.generateServerInfoJSON() +} + +func (s *Server) generateServerInfoJSON() { // Generate the info json b, err := json.Marshal(s.info) if err != nil { Fatalf("Error marshalling INFO JSON: %+v\n", err) } - s.infoJSON = []byte(fmt.Sprintf("INFO %s %s", b, CR_LF)) - - return s } // PrintAndDie is exported for access in other packages. @@ -407,18 +414,11 @@ func (s *Server) sendInfo(c *client) { } func (s *Server) checkClientAuth(c *client) bool { - if !s.info.AuthRequired { + if s.auth == nil { return true } - // We require auth here, check the client - // Authorization tokens trump username/password - if s.opts.Authorization != "" { - return s.opts.Authorization == c.opts.Authorization - } else if s.opts.Username != c.opts.Username || - s.opts.Password != c.opts.Password { - return false - } - return true + + return s.auth.Check(c) } func (s *Server) checkRouterAuth(c *client) bool { diff --git a/test/auth_test.go b/test/auth_test.go index bc623d46..540959db 100644 --- a/test/auth_test.go +++ b/test/auth_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/apcera/gnatsd/auth" "github.com/apcera/gnatsd/server" ) @@ -45,7 +46,7 @@ func runAuthServerWithToken() *server.Server { opts := DefaultTestOptions opts.Port = AUTH_PORT opts.Authorization = AUTH_TOKEN - return RunServer(&opts) + return RunServerWithAuth(&opts, &auth.Token{Token: AUTH_TOKEN}) } func TestNoAuthClient(t *testing.T) { @@ -111,7 +112,9 @@ func runAuthServerWithUserPass() *server.Server { opts.Port = AUTH_PORT opts.Username = AUTH_USER opts.Password = AUTH_PASS - return RunServer(&opts) + + auth := &auth.Plain{Username: AUTH_USER, Password: AUTH_PASS} + return RunServerWithAuth(&opts, auth) } func TestNoUserOrPasswordClient(t *testing.T) { diff --git a/test/test.go b/test/test.go index f92affa2..26626bba 100644 --- a/test/test.go +++ b/test/test.go @@ -44,6 +44,11 @@ func runDefaultServer() *server.Server { // New Go Routine based server func RunServer(opts *server.Options) *server.Server { + return RunServerWithAuth(opts, nil) +} + +// New Go Routine based server with auth +func RunServerWithAuth(opts *server.Options, auth server.Auth) *server.Server { if opts == nil { opts = &DefaultTestOptions } @@ -52,6 +57,10 @@ func RunServer(opts *server.Options) *server.Server { panic("No NATS Server object returned.") } + if auth != nil { + s.SetAuthMethod(auth) + } + // Run server in Go routine. go s.Start()