extendable auth methods

This commit is contained in:
Máximo Cuadros Ortiz
2014-11-26 23:38:46 +01:00
parent 54ac589d82
commit 65ae9c16f2
10 changed files with 134 additions and 36 deletions

19
auth/plain.go Normal file
View File

@@ -0,0 +1,19 @@
package auth
import (
"github.com/apcera/gnatsd/server"
)
type Plain struct {
Username string
Password string
}
func (p *Plain) Check(c server.ClientAuth) bool {
opts := c.GetOpts()
if p.Username != opts.Username || p.Password != opts.Password {
return false
}
return true
}

18
auth/token.go Normal file
View File

@@ -0,0 +1,18 @@
package auth
import (
"github.com/apcera/gnatsd/server"
)
type Token struct {
Token string
}
func (p *Token) Check(c server.ClientAuth) bool {
opts := c.GetOpts()
if p.Token != opts.Authorization {
return false
}
return true
}

View File

@@ -6,6 +6,7 @@ import (
"flag"
"strings"
"github.com/apcera/gnatsd/auth"
"github.com/apcera/gnatsd/logger"
"github.com/apcera/gnatsd/server"
)
@@ -94,25 +95,45 @@ func main() {
// Create the server with appropriate options.
s := server.New(&opts)
// Builds and set the logger based on the flags
s.SetLogger(buildLogger(&opts), opts.Debug, opts.Trace)
// Configure the authentication mechanism
configureAuth(s, &opts)
// Configure the logger based on the flags
configureLogger(s, &opts)
// Start things up. Block here until done.
s.Start()
}
func buildLogger(opts *server.Options) server.Logger {
if opts.LogFile != "" {
return logger.NewFileLogger(opts.LogFile, opts.Logtime, opts.Debug, opts.Trace)
}
func configureAuth(s *server.Server, opts *server.Options) {
if opts.Username != "" {
auth := &auth.Plain{
Username: opts.Username,
Password: opts.Password,
}
if opts.RemoteSyslog != "" {
return logger.NewRemoteSysLogger(opts.RemoteSyslog, opts.Debug, opts.Trace)
}
s.SetAuthMethod(auth)
} else if opts.Authorization != "" {
auth := &auth.Token{
Token: opts.Authorization,
}
if opts.Syslog {
return logger.NewSysLogger(opts.Debug, opts.Trace)
s.SetAuthMethod(auth)
}
return logger.NewStdLogger(opts.Logtime, opts.Debug, opts.Trace, true)
}
func configureLogger(s *server.Server, opts *server.Options) {
var log server.Logger
if opts.LogFile != "" {
log = logger.NewFileLogger(opts.LogFile, opts.Logtime, opts.Debug, opts.Trace)
} else if opts.RemoteSyslog != "" {
log = logger.NewRemoteSysLogger(opts.RemoteSyslog, opts.Debug, opts.Trace)
} else if opts.Syslog {
log = logger.NewSysLogger(opts.Debug, opts.Trace)
} else {
log = logger.NewStdLogger(opts.Logtime, opts.Debug, opts.Trace, true)
}
s.SetLogger(log, opts.Debug, opts.Trace)
}

11
server/auth.go Normal file
View File

@@ -0,0 +1,11 @@
// Copyright 2012-2014 Apcera Inc. All rights reserved.
package server
type Auth interface {
Check(c ClientAuth) bool
}
type ClientAuth interface {
GetOpts() *clientOpts
}

View File

@@ -69,6 +69,10 @@ func (c *client) String() (id string) {
return id
}
func (c *client) GetOpts() *clientOpts {
return &c.opts
}
type subscription struct {
client *client
subject []byte
@@ -223,7 +227,7 @@ func (c *client) processRouteInfo(info *Info) {
}
}
// Process the information messages from clients and other routes.
// Process the information messages from Clients and other routes.
func (c *client) processInfo(arg []byte) error {
info := Info{}
if err := json.Unmarshal(arg, &info); err != nil {

View File

@@ -22,6 +22,12 @@ type serverInfo struct {
MaxPayload int64 `json:"max_payload"`
}
type mockAuth struct{}
func (m *mockAuth) Check(c ClientAuth) bool {
return true
}
func createClientAsync(ch chan *client, s *Server, cli net.Conn) {
go func() {
c := s.createClient(cli)
@@ -38,12 +44,17 @@ var defaultServerOptions = Options{
NoSigs: true,
}
func rawSetup(serverOption Options) (*Server, *client, *bufio.Reader, string) {
func rawSetup(serverOptions Options) (*Server, *client, *bufio.Reader, string) {
cli, srv := net.Pipe()
cr := bufio.NewReaderSize(cli, defaultBufSize)
s := New(&serverOption)
s := New(&serverOptions)
if serverOptions.Authorization != "" {
s.SetAuthMethod(&mockAuth{})
}
ch := make(chan *client)
createClientAsync(ch, s, srv)
l, _ := cr.ReadString('\n')
// Grab client
@@ -519,7 +530,8 @@ func TestAuthorizationTimeout(t *testing.T) {
serverOptions := defaultServerOptions
serverOptions.Authorization = "my_token"
serverOptions.AuthTimeout = 1
_, _, cr, _ := rawSetup(serverOptions)
s, _, cr, _ := rawSetup(serverOptions)
s.SetAuthMethod(&mockAuth{})
time.Sleep(secondsToDuration(serverOptions.AuthTimeout))
l, err := cr.ReadString('\n')

View File

@@ -80,5 +80,6 @@ func executeLogCall(f func(logger Logger, format string, v ...interface{}), form
if log.logger == nil {
return
}
f(log.logger, format, args...)
}

View File

@@ -40,6 +40,7 @@ type Server struct {
sl *sublist.Sublist
gcid uint64
opts *Options
auth Auth
trace bool
debug bool
running bool
@@ -78,10 +79,6 @@ func New(opts *Options) *Server {
SslRequired: false,
MaxPayload: MAX_PAYLOAD_SIZE,
}
// Check for Auth items
if opts.Username != "" || opts.Authorization != "" {
info.AuthRequired = true
}
s := &Server{
info: info,
@@ -106,17 +103,27 @@ func New(opts *Options) *Server {
// Used to kick out all of the route
// connect Go routines.
s.rcQuit = make(chan bool)
s.generateServerInfoJSON()
s.handleSignals()
return s
}
// Sets the authentication method
func (s *Server) SetAuthMethod(authMethod Auth) {
s.info.AuthRequired = true
s.auth = authMethod
s.generateServerInfoJSON()
}
func (s *Server) generateServerInfoJSON() {
// Generate the info json
b, err := json.Marshal(s.info)
if err != nil {
Fatalf("Error marshalling INFO JSON: %+v\n", err)
}
s.infoJSON = []byte(fmt.Sprintf("INFO %s %s", b, CR_LF))
return s
}
// PrintAndDie is exported for access in other packages.
@@ -407,18 +414,11 @@ func (s *Server) sendInfo(c *client) {
}
func (s *Server) checkClientAuth(c *client) bool {
if !s.info.AuthRequired {
if s.auth == nil {
return true
}
// We require auth here, check the client
// Authorization tokens trump username/password
if s.opts.Authorization != "" {
return s.opts.Authorization == c.opts.Authorization
} else if s.opts.Username != c.opts.Username ||
s.opts.Password != c.opts.Password {
return false
}
return true
return s.auth.Check(c)
}
func (s *Server) checkRouterAuth(c *client) bool {

View File

@@ -9,6 +9,7 @@ import (
"testing"
"time"
"github.com/apcera/gnatsd/auth"
"github.com/apcera/gnatsd/server"
)
@@ -45,7 +46,7 @@ func runAuthServerWithToken() *server.Server {
opts := DefaultTestOptions
opts.Port = AUTH_PORT
opts.Authorization = AUTH_TOKEN
return RunServer(&opts)
return RunServerWithAuth(&opts, &auth.Token{Token: AUTH_TOKEN})
}
func TestNoAuthClient(t *testing.T) {
@@ -111,7 +112,9 @@ func runAuthServerWithUserPass() *server.Server {
opts.Port = AUTH_PORT
opts.Username = AUTH_USER
opts.Password = AUTH_PASS
return RunServer(&opts)
auth := &auth.Plain{Username: AUTH_USER, Password: AUTH_PASS}
return RunServerWithAuth(&opts, auth)
}
func TestNoUserOrPasswordClient(t *testing.T) {

View File

@@ -44,6 +44,11 @@ func runDefaultServer() *server.Server {
// New Go Routine based server
func RunServer(opts *server.Options) *server.Server {
return RunServerWithAuth(opts, nil)
}
// New Go Routine based server with auth
func RunServerWithAuth(opts *server.Options, auth server.Auth) *server.Server {
if opts == nil {
opts = &DefaultTestOptions
}
@@ -52,6 +57,10 @@ func RunServer(opts *server.Options) *server.Server {
panic("No NATS Server object returned.")
}
if auth != nil {
s.SetAuthMethod(auth)
}
// Run server in Go routine.
go s.Start()