diff --git a/server/websocket.go b/server/websocket.go index a89d7b1c..23de9f22 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -955,6 +955,7 @@ func (s *Server) startWebsocketServer() { if o.TLSConfig != nil { proto = wsSchemePrefixTLS config := o.TLSConfig.Clone() + config.GetConfigForClient = s.wsGetTLSConfig hl, err = tls.Listen("tcp", hp, config) } else { proto = wsSchemePrefix @@ -1028,6 +1029,17 @@ func (s *Server) startWebsocketServer() { s.mu.Unlock() } +// The TLS configuration is passed to the listener when the websocket +// "server" is setup. That prevents TLS configuration updates on reload +// from being used. By setting this function in tls.Config.GetConfigForClient +// we instruct the TLS handshake to ask for the tls configuration to be +// used for a specific client. We don't care which client, we always use +// the same TLS configuration. +func (s *Server) wsGetTLSConfig(_ *tls.ClientHelloInfo) (*tls.Config, error) { + opts := s.getOpts() + return opts.Websocket.TLSConfig, nil +} + // This is similar to createClient() but has some modifications // specific to handle websocket clients. // The comments have been kept to minimum to reduce code size. diff --git a/server/websocket_test.go b/server/websocket_test.go index c2498f5b..c3434f72 100644 --- a/server/websocket_test.go +++ b/server/websocket_test.go @@ -3643,7 +3643,63 @@ func TestWSJWTCookieUser(t *testing.T) { }) } s.Shutdown() +} +func TestWSReloadTLSConfig(t *testing.T) { + template := ` + listen: "127.0.0.1:-1" + websocket { + listen: "127.0.0.1:-1" + tls { + cert_file: '%s' + key_file: '%s' + ca_file: '../test/configs/certs/ca.pem' + } + } + ` + conf := createConfFile(t, []byte(fmt.Sprintf(template, + "../test/configs/certs/server-noip.pem", + "../test/configs/certs/server-key-noip.pem"))) + defer removeFile(t, conf) + + s, o := RunServerWithConfig(conf) + defer s.Shutdown() + + addr := fmt.Sprintf("127.0.0.1:%d", o.Websocket.Port) + wsc, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Error creating ws connection: %v", err) + } + defer wsc.Close() + + tc := &TLSConfigOpts{CaFile: "../test/configs/certs/ca.pem"} + tlsConfig, err := GenTLSConfig(tc) + if err != nil { + t.Fatalf("Error generating TLS config: %v", err) + } + tlsConfig.ServerName = "127.0.0.1" + tlsConfig.RootCAs = tlsConfig.ClientCAs + tlsConfig.ClientCAs = nil + wsc = tls.Client(wsc, tlsConfig.Clone()) + if err := wsc.(*tls.Conn).Handshake(); err == nil || !strings.Contains(err.Error(), "SAN") { + t.Fatalf("Unexpected error: %v", err) + } + wsc.Close() + + reloadUpdateConfig(t, s, conf, fmt.Sprintf(template, + "../test/configs/certs/server-cert.pem", + "../test/configs/certs/server-key.pem")) + + wsc, err = net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Error creating ws connection: %v", err) + } + defer wsc.Close() + + wsc = tls.Client(wsc, tlsConfig.Clone()) + if err := wsc.(*tls.Conn).Handshake(); err != nil { + t.Fatalf("Error on TLS handshake: %v", err) + } } // ==================================================================