mirror of
https://github.com/gogrlx/nats-server.git
synced 2026-04-02 03:38:42 -07:00
extendable auth methods
This commit is contained in:
19
auth/plain.go
Normal file
19
auth/plain.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"github.com/apcera/gnatsd/server"
|
||||
)
|
||||
|
||||
type Plain struct {
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
func (p *Plain) Check(c server.ClientAuth) bool {
|
||||
opts := c.GetOpts()
|
||||
if p.Username != opts.Username || p.Password != opts.Password {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
18
auth/token.go
Normal file
18
auth/token.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"github.com/apcera/gnatsd/server"
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
Token string
|
||||
}
|
||||
|
||||
func (p *Token) Check(c server.ClientAuth) bool {
|
||||
opts := c.GetOpts()
|
||||
if p.Token != opts.Authorization {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
47
gnatsd.go
47
gnatsd.go
@@ -6,6 +6,7 @@ import (
|
||||
"flag"
|
||||
"strings"
|
||||
|
||||
"github.com/apcera/gnatsd/auth"
|
||||
"github.com/apcera/gnatsd/logger"
|
||||
"github.com/apcera/gnatsd/server"
|
||||
)
|
||||
@@ -94,25 +95,45 @@ func main() {
|
||||
// Create the server with appropriate options.
|
||||
s := server.New(&opts)
|
||||
|
||||
// Builds and set the logger based on the flags
|
||||
s.SetLogger(buildLogger(&opts), opts.Debug, opts.Trace)
|
||||
// Configure the authentication mechanism
|
||||
configureAuth(s, &opts)
|
||||
|
||||
// Configure the logger based on the flags
|
||||
configureLogger(s, &opts)
|
||||
|
||||
// Start things up. Block here until done.
|
||||
s.Start()
|
||||
}
|
||||
|
||||
func buildLogger(opts *server.Options) server.Logger {
|
||||
if opts.LogFile != "" {
|
||||
return logger.NewFileLogger(opts.LogFile, opts.Logtime, opts.Debug, opts.Trace)
|
||||
}
|
||||
func configureAuth(s *server.Server, opts *server.Options) {
|
||||
if opts.Username != "" {
|
||||
auth := &auth.Plain{
|
||||
Username: opts.Username,
|
||||
Password: opts.Password,
|
||||
}
|
||||
|
||||
if opts.RemoteSyslog != "" {
|
||||
return logger.NewRemoteSysLogger(opts.RemoteSyslog, opts.Debug, opts.Trace)
|
||||
}
|
||||
s.SetAuthMethod(auth)
|
||||
} else if opts.Authorization != "" {
|
||||
auth := &auth.Token{
|
||||
Token: opts.Authorization,
|
||||
}
|
||||
|
||||
if opts.Syslog {
|
||||
return logger.NewSysLogger(opts.Debug, opts.Trace)
|
||||
s.SetAuthMethod(auth)
|
||||
}
|
||||
|
||||
return logger.NewStdLogger(opts.Logtime, opts.Debug, opts.Trace, true)
|
||||
}
|
||||
|
||||
func configureLogger(s *server.Server, opts *server.Options) {
|
||||
var log server.Logger
|
||||
|
||||
if opts.LogFile != "" {
|
||||
log = logger.NewFileLogger(opts.LogFile, opts.Logtime, opts.Debug, opts.Trace)
|
||||
} else if opts.RemoteSyslog != "" {
|
||||
log = logger.NewRemoteSysLogger(opts.RemoteSyslog, opts.Debug, opts.Trace)
|
||||
} else if opts.Syslog {
|
||||
log = logger.NewSysLogger(opts.Debug, opts.Trace)
|
||||
} else {
|
||||
log = logger.NewStdLogger(opts.Logtime, opts.Debug, opts.Trace, true)
|
||||
}
|
||||
|
||||
s.SetLogger(log, opts.Debug, opts.Trace)
|
||||
}
|
||||
|
||||
11
server/auth.go
Normal file
11
server/auth.go
Normal file
@@ -0,0 +1,11 @@
|
||||
// Copyright 2012-2014 Apcera Inc. All rights reserved.
|
||||
|
||||
package server
|
||||
|
||||
type Auth interface {
|
||||
Check(c ClientAuth) bool
|
||||
}
|
||||
|
||||
type ClientAuth interface {
|
||||
GetOpts() *clientOpts
|
||||
}
|
||||
@@ -69,6 +69,10 @@ func (c *client) String() (id string) {
|
||||
return id
|
||||
}
|
||||
|
||||
func (c *client) GetOpts() *clientOpts {
|
||||
return &c.opts
|
||||
}
|
||||
|
||||
type subscription struct {
|
||||
client *client
|
||||
subject []byte
|
||||
@@ -223,7 +227,7 @@ func (c *client) processRouteInfo(info *Info) {
|
||||
}
|
||||
}
|
||||
|
||||
// Process the information messages from clients and other routes.
|
||||
// Process the information messages from Clients and other routes.
|
||||
func (c *client) processInfo(arg []byte) error {
|
||||
info := Info{}
|
||||
if err := json.Unmarshal(arg, &info); err != nil {
|
||||
|
||||
@@ -22,6 +22,12 @@ type serverInfo struct {
|
||||
MaxPayload int64 `json:"max_payload"`
|
||||
}
|
||||
|
||||
type mockAuth struct{}
|
||||
|
||||
func (m *mockAuth) Check(c ClientAuth) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func createClientAsync(ch chan *client, s *Server, cli net.Conn) {
|
||||
go func() {
|
||||
c := s.createClient(cli)
|
||||
@@ -38,12 +44,17 @@ var defaultServerOptions = Options{
|
||||
NoSigs: true,
|
||||
}
|
||||
|
||||
func rawSetup(serverOption Options) (*Server, *client, *bufio.Reader, string) {
|
||||
func rawSetup(serverOptions Options) (*Server, *client, *bufio.Reader, string) {
|
||||
cli, srv := net.Pipe()
|
||||
cr := bufio.NewReaderSize(cli, defaultBufSize)
|
||||
s := New(&serverOption)
|
||||
s := New(&serverOptions)
|
||||
if serverOptions.Authorization != "" {
|
||||
s.SetAuthMethod(&mockAuth{})
|
||||
}
|
||||
|
||||
ch := make(chan *client)
|
||||
createClientAsync(ch, s, srv)
|
||||
|
||||
l, _ := cr.ReadString('\n')
|
||||
|
||||
// Grab client
|
||||
@@ -519,7 +530,8 @@ func TestAuthorizationTimeout(t *testing.T) {
|
||||
serverOptions := defaultServerOptions
|
||||
serverOptions.Authorization = "my_token"
|
||||
serverOptions.AuthTimeout = 1
|
||||
_, _, cr, _ := rawSetup(serverOptions)
|
||||
s, _, cr, _ := rawSetup(serverOptions)
|
||||
s.SetAuthMethod(&mockAuth{})
|
||||
|
||||
time.Sleep(secondsToDuration(serverOptions.AuthTimeout))
|
||||
l, err := cr.ReadString('\n')
|
||||
|
||||
@@ -80,5 +80,6 @@ func executeLogCall(f func(logger Logger, format string, v ...interface{}), form
|
||||
if log.logger == nil {
|
||||
return
|
||||
}
|
||||
|
||||
f(log.logger, format, args...)
|
||||
}
|
||||
|
||||
@@ -40,6 +40,7 @@ type Server struct {
|
||||
sl *sublist.Sublist
|
||||
gcid uint64
|
||||
opts *Options
|
||||
auth Auth
|
||||
trace bool
|
||||
debug bool
|
||||
running bool
|
||||
@@ -78,10 +79,6 @@ func New(opts *Options) *Server {
|
||||
SslRequired: false,
|
||||
MaxPayload: MAX_PAYLOAD_SIZE,
|
||||
}
|
||||
// Check for Auth items
|
||||
if opts.Username != "" || opts.Authorization != "" {
|
||||
info.AuthRequired = true
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
info: info,
|
||||
@@ -106,17 +103,27 @@ func New(opts *Options) *Server {
|
||||
// Used to kick out all of the route
|
||||
// connect Go routines.
|
||||
s.rcQuit = make(chan bool)
|
||||
s.generateServerInfoJSON()
|
||||
s.handleSignals()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Sets the authentication method
|
||||
func (s *Server) SetAuthMethod(authMethod Auth) {
|
||||
s.info.AuthRequired = true
|
||||
s.auth = authMethod
|
||||
|
||||
s.generateServerInfoJSON()
|
||||
}
|
||||
|
||||
func (s *Server) generateServerInfoJSON() {
|
||||
// Generate the info json
|
||||
b, err := json.Marshal(s.info)
|
||||
if err != nil {
|
||||
Fatalf("Error marshalling INFO JSON: %+v\n", err)
|
||||
}
|
||||
|
||||
s.infoJSON = []byte(fmt.Sprintf("INFO %s %s", b, CR_LF))
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// PrintAndDie is exported for access in other packages.
|
||||
@@ -407,18 +414,11 @@ func (s *Server) sendInfo(c *client) {
|
||||
}
|
||||
|
||||
func (s *Server) checkClientAuth(c *client) bool {
|
||||
if !s.info.AuthRequired {
|
||||
if s.auth == nil {
|
||||
return true
|
||||
}
|
||||
// We require auth here, check the client
|
||||
// Authorization tokens trump username/password
|
||||
if s.opts.Authorization != "" {
|
||||
return s.opts.Authorization == c.opts.Authorization
|
||||
} else if s.opts.Username != c.opts.Username ||
|
||||
s.opts.Password != c.opts.Password {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
||||
return s.auth.Check(c)
|
||||
}
|
||||
|
||||
func (s *Server) checkRouterAuth(c *client) bool {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/apcera/gnatsd/auth"
|
||||
"github.com/apcera/gnatsd/server"
|
||||
)
|
||||
|
||||
@@ -45,7 +46,7 @@ func runAuthServerWithToken() *server.Server {
|
||||
opts := DefaultTestOptions
|
||||
opts.Port = AUTH_PORT
|
||||
opts.Authorization = AUTH_TOKEN
|
||||
return RunServer(&opts)
|
||||
return RunServerWithAuth(&opts, &auth.Token{Token: AUTH_TOKEN})
|
||||
}
|
||||
|
||||
func TestNoAuthClient(t *testing.T) {
|
||||
@@ -111,7 +112,9 @@ func runAuthServerWithUserPass() *server.Server {
|
||||
opts.Port = AUTH_PORT
|
||||
opts.Username = AUTH_USER
|
||||
opts.Password = AUTH_PASS
|
||||
return RunServer(&opts)
|
||||
|
||||
auth := &auth.Plain{Username: AUTH_USER, Password: AUTH_PASS}
|
||||
return RunServerWithAuth(&opts, auth)
|
||||
}
|
||||
|
||||
func TestNoUserOrPasswordClient(t *testing.T) {
|
||||
|
||||
@@ -44,6 +44,11 @@ func runDefaultServer() *server.Server {
|
||||
|
||||
// New Go Routine based server
|
||||
func RunServer(opts *server.Options) *server.Server {
|
||||
return RunServerWithAuth(opts, nil)
|
||||
}
|
||||
|
||||
// New Go Routine based server with auth
|
||||
func RunServerWithAuth(opts *server.Options, auth server.Auth) *server.Server {
|
||||
if opts == nil {
|
||||
opts = &DefaultTestOptions
|
||||
}
|
||||
@@ -52,6 +57,10 @@ func RunServer(opts *server.Options) *server.Server {
|
||||
panic("No NATS Server object returned.")
|
||||
}
|
||||
|
||||
if auth != nil {
|
||||
s.SetAuthMethod(auth)
|
||||
}
|
||||
|
||||
// Run server in Go routine.
|
||||
go s.Start()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user