mirror of
https://github.com/gogrlx/nats-server.git
synced 2026-04-16 11:04:42 -07:00
Merge pull request #1309 from nats-io/websocket
[ADDED] Websocket support
This commit is contained in:
158
server/client.go
158
server/client.go
@@ -232,6 +232,7 @@ type client struct {
|
||||
route *route
|
||||
gw *gateway
|
||||
leaf *leaf
|
||||
ws *websocket
|
||||
|
||||
// To keep track of gateway replies mapping
|
||||
gwrm map[string]*gwReplyMap
|
||||
@@ -270,7 +271,6 @@ type outbound struct {
|
||||
mp int64 // Snapshot of max pending for client.
|
||||
lft time.Duration // Last flush time for Write.
|
||||
stc chan struct{} // Stall chan we create to slow down producers on overrun, e.g. fan-in.
|
||||
lwb int32 // Last byte size of Write.
|
||||
}
|
||||
|
||||
type perm struct {
|
||||
@@ -484,16 +484,23 @@ func (c *client) initClient() {
|
||||
|
||||
// snapshot the string version of the connection
|
||||
var conn string
|
||||
if ip, ok := c.nc.(*net.TCPConn); ok {
|
||||
conn = ip.RemoteAddr().String()
|
||||
host, port, _ := net.SplitHostPort(conn)
|
||||
iPort, _ := strconv.Atoi(port)
|
||||
c.host, c.port = host, uint16(iPort)
|
||||
if c.nc != nil {
|
||||
if addr := c.nc.RemoteAddr(); addr != nil {
|
||||
if conn = addr.String(); conn != _EMPTY_ {
|
||||
host, port, _ := net.SplitHostPort(conn)
|
||||
iPort, _ := strconv.Atoi(port)
|
||||
c.host, c.port = host, uint16(iPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch c.kind {
|
||||
case CLIENT:
|
||||
c.ncs = fmt.Sprintf("%s - cid:%d", conn, c.cid)
|
||||
name := "cid"
|
||||
if c.ws != nil {
|
||||
name = "wid"
|
||||
}
|
||||
c.ncs = fmt.Sprintf("%s - %s:%d", conn, name, c.cid)
|
||||
case ROUTER:
|
||||
c.ncs = fmt.Sprintf("%s - rid:%d", conn, c.cid)
|
||||
case GATEWAY:
|
||||
@@ -873,6 +880,7 @@ func (c *client) readLoop() {
|
||||
return
|
||||
}
|
||||
nc := c.nc
|
||||
ws := c.ws != nil
|
||||
c.in.rsz = startBufSize
|
||||
// Snapshot max control line since currently can not be changed on reload and we
|
||||
// were checking it on each call to parse. If this changes and we allow MaxControlLine
|
||||
@@ -898,6 +906,19 @@ func (c *client) readLoop() {
|
||||
// Start read buffer.
|
||||
b := make([]byte, c.in.rsz)
|
||||
|
||||
// Websocket clients will return several slices if there are multiple
|
||||
// websocket frames in the blind read. For non WS clients though, we
|
||||
// will always have 1 slice per loop iteration. So we define this here
|
||||
// so non WS clients will use bufs[0] = b[:n].
|
||||
var _bufs [1][]byte
|
||||
bufs := _bufs[:1]
|
||||
|
||||
var wsr *wsReadInfo
|
||||
if ws {
|
||||
wsr = &wsReadInfo{}
|
||||
wsr.init()
|
||||
}
|
||||
|
||||
for {
|
||||
n, err := nc.Read(b)
|
||||
// If we have any data we will try to parse and exit at the end.
|
||||
@@ -905,6 +926,19 @@ func (c *client) readLoop() {
|
||||
c.closeConnection(closedStateForErr(err))
|
||||
return
|
||||
}
|
||||
if ws {
|
||||
bufs, err = c.wsRead(wsr, nc, b[:n])
|
||||
if bufs == nil && err != nil {
|
||||
if err != io.EOF {
|
||||
c.Errorf("read error: %v", err)
|
||||
}
|
||||
c.closeConnection(closedStateForErr(err))
|
||||
} else if bufs == nil {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
bufs[0] = b[:n]
|
||||
}
|
||||
start := time.Now()
|
||||
|
||||
// Clear inbound stats cache
|
||||
@@ -914,20 +948,22 @@ func (c *client) readLoop() {
|
||||
|
||||
// Main call into parser for inbound data. This will generate callouts
|
||||
// to process messages, etc.
|
||||
if err := c.parse(b[:n]); err != nil {
|
||||
if dur := time.Since(start); dur >= readLoopReportThreshold {
|
||||
c.Warnf("Readloop processing time: %v", dur)
|
||||
for i := 0; i < len(bufs); i++ {
|
||||
if err := c.parse(bufs[i]); err != nil {
|
||||
if dur := time.Since(start); dur >= readLoopReportThreshold {
|
||||
c.Warnf("Readloop processing time: %v", dur)
|
||||
}
|
||||
// Need to call flushClients because some of the clients have been
|
||||
// assigned messages and their "fsp" incremented, and need now to be
|
||||
// decremented and their writeLoop signaled.
|
||||
c.flushClients(0)
|
||||
// handled inline
|
||||
if err != ErrMaxPayload && err != ErrAuthentication {
|
||||
c.Error(err)
|
||||
c.closeConnection(ProtocolViolation)
|
||||
}
|
||||
return
|
||||
}
|
||||
// Need to call flushClients because some of the clients have been
|
||||
// assigned messages and their "fsp" incremented, and need now to be
|
||||
// decremented and their writeLoop signaled.
|
||||
c.flushClients(0)
|
||||
// handled inline
|
||||
if err != ErrMaxPayload && err != ErrAuthentication {
|
||||
c.Error(err)
|
||||
c.closeConnection(ProtocolViolation)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Updates stats for client and server that were collected
|
||||
@@ -1011,19 +1047,26 @@ func closedStateForErr(err error) ClosedState {
|
||||
|
||||
// collapsePtoNB will place primary onto nb buffer as needed in prep for WriteTo.
|
||||
// This will return a copy on purpose.
|
||||
func (c *client) collapsePtoNB() net.Buffers {
|
||||
func (c *client) collapsePtoNB() (net.Buffers, int64) {
|
||||
if c.ws != nil {
|
||||
return c.wsCollapsePtoNB()
|
||||
}
|
||||
if c.out.p != nil {
|
||||
p := c.out.p
|
||||
c.out.p = nil
|
||||
return append(c.out.nb, p)
|
||||
return append(c.out.nb, p), c.out.pb
|
||||
}
|
||||
return c.out.nb
|
||||
return c.out.nb, c.out.pb
|
||||
}
|
||||
|
||||
// This will handle the fixup needed on a partial write.
|
||||
// Assume pending has been already calculated correctly.
|
||||
func (c *client) handlePartialWrite(pnb net.Buffers) {
|
||||
nb := c.collapsePtoNB()
|
||||
if c.ws != nil {
|
||||
c.ws.frames = append(pnb, c.ws.frames...)
|
||||
return
|
||||
}
|
||||
nb, _ := c.collapsePtoNB()
|
||||
// The partial needs to be first, so append nb to pnb
|
||||
c.out.nb = append(pnb, nb...)
|
||||
}
|
||||
@@ -1050,8 +1093,11 @@ func (c *client) flushOutbound() bool {
|
||||
}
|
||||
|
||||
// Place primary on nb, assign primary to secondary, nil out nb and secondary.
|
||||
nb := c.collapsePtoNB()
|
||||
nb, attempted := c.collapsePtoNB()
|
||||
c.out.p, c.out.nb, c.out.s = c.out.s, nil, nil
|
||||
if nb == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// For selecting primary replacement.
|
||||
cnb := nb
|
||||
@@ -1062,7 +1108,6 @@ func (c *client) flushOutbound() bool {
|
||||
|
||||
// In case it goes away after releasing the lock.
|
||||
nc := c.nc
|
||||
attempted := c.out.pb
|
||||
apm := c.out.pm
|
||||
|
||||
// Capture this (we change the value in some tests)
|
||||
@@ -1086,7 +1131,8 @@ func (c *client) flushOutbound() bool {
|
||||
// Re-acquire client lock.
|
||||
c.mu.Lock()
|
||||
|
||||
if err != nil {
|
||||
// Ignore ErrShortWrite errors, they will be handled as partials.
|
||||
if err != nil && err != io.ErrShortWrite {
|
||||
// Handle timeout error (slow consumer) differently
|
||||
if ne, ok := err.(net.Error); ok && ne.Timeout() {
|
||||
if closed := c.handleWriteTimeout(n, attempted, len(cnb)); closed {
|
||||
@@ -1100,29 +1146,31 @@ func (c *client) flushOutbound() bool {
|
||||
report = c.Errorf
|
||||
}
|
||||
report("Error flushing: %v", err)
|
||||
c.markConnAsClosed(WriteError, true)
|
||||
c.markConnAsClosed(WriteError)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Update flush time statistics.
|
||||
c.out.lft = lft
|
||||
c.out.lwb = int32(n)
|
||||
|
||||
// Subtract from pending bytes and messages.
|
||||
c.out.pb -= int64(c.out.lwb)
|
||||
c.out.pb -= n
|
||||
if c.ws != nil {
|
||||
c.ws.fs -= n
|
||||
}
|
||||
c.out.pm -= apm // FIXME(dlc) - this will not be totally accurate on partials.
|
||||
|
||||
// Check for partial writes
|
||||
// TODO(dlc) - zero write with no error will cause lost message and the writeloop to spin.
|
||||
if int64(c.out.lwb) != attempted && n > 0 {
|
||||
if n != attempted && n > 0 {
|
||||
c.handlePartialWrite(nb)
|
||||
} else if c.out.lwb >= c.out.sz {
|
||||
} else if int32(n) >= c.out.sz {
|
||||
c.out.sws = 0
|
||||
}
|
||||
|
||||
// Adjust based on what we wrote plus any pending.
|
||||
pt := int64(c.out.lwb) + c.out.pb
|
||||
pt := n + c.out.pb
|
||||
|
||||
// Adjust sz as needed downward, keeping power of 2.
|
||||
// We do this at a slower rate.
|
||||
@@ -1158,7 +1206,7 @@ func (c *client) flushOutbound() bool {
|
||||
|
||||
// Check if we have a stalled gate and if so and we are recovering release
|
||||
// any stalled producers. Only kind==CLIENT will stall.
|
||||
if c.out.stc != nil && (int64(c.out.lwb) == attempted || c.out.pb < c.out.mp/2) {
|
||||
if c.out.stc != nil && (n == attempted || c.out.pb < c.out.mp/2) {
|
||||
close(c.out.stc)
|
||||
c.out.stc = nil
|
||||
}
|
||||
@@ -1173,7 +1221,7 @@ func (c *client) handleWriteTimeout(written, attempted int64, numChunks int) boo
|
||||
if tlsConn, ok := c.nc.(*tls.Conn); ok {
|
||||
if !tlsConn.ConnectionState().HandshakeComplete {
|
||||
// Likely a TLSTimeout error instead...
|
||||
c.markConnAsClosed(TLSHandshakeError, true)
|
||||
c.markConnAsClosed(TLSHandshakeError)
|
||||
// Would need to coordinate with tlstimeout()
|
||||
// to avoid double logging, so skip logging
|
||||
// here, and don't report a slow consumer error.
|
||||
@@ -1184,7 +1232,7 @@ func (c *client) handleWriteTimeout(written, attempted int64, numChunks int) boo
|
||||
// before the authorization timeout. If that is the case, then we handle
|
||||
// as slow consumer though we do not increase the counter as that can be
|
||||
// misleading.
|
||||
c.markConnAsClosed(SlowConsumerWriteDeadline, true)
|
||||
c.markConnAsClosed(SlowConsumerWriteDeadline)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1195,26 +1243,40 @@ func (c *client) handleWriteTimeout(written, attempted int64, numChunks int) boo
|
||||
|
||||
// We always close CLIENT connections, or when nothing was written at all...
|
||||
if c.kind == CLIENT || written == 0 {
|
||||
c.markConnAsClosed(SlowConsumerWriteDeadline, true)
|
||||
c.markConnAsClosed(SlowConsumerWriteDeadline)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Marks this connection has closed with the given reason.
|
||||
// Sets the closeConnection flag and skipFlushOnClose flag if asked.
|
||||
// Sets the closeConnection flag and skipFlushOnClose depending on the reason.
|
||||
// Depending on the kind of connection, the connection will be saved.
|
||||
// If a writeLoop has been started, the final flush/close/teardown will
|
||||
// be done there, otherwise flush and close of TCP connection is done here in place.
|
||||
// Returns true if closed in place, flase otherwise.
|
||||
// Lock is held on entry.
|
||||
func (c *client) markConnAsClosed(reason ClosedState, skipFlush bool) bool {
|
||||
func (c *client) markConnAsClosed(reason ClosedState) bool {
|
||||
// Possibly set skipFlushOnClose flag even if connection has already been
|
||||
// mark as closed. The rationale is that a connection may be closed with
|
||||
// a reason that justifies a flush (say after sending an -ERR), but then
|
||||
// the flushOutbound() gets a write error. If that happens, connection
|
||||
// being lost, there is no reason to attempt to flush again during the
|
||||
// teardown when the writeLoop exits.
|
||||
var skipFlush bool
|
||||
switch reason {
|
||||
case ReadError, WriteError, SlowConsumerPendingBytes, SlowConsumerWriteDeadline, TLSHandshakeError:
|
||||
c.flags.set(skipFlushOnClose)
|
||||
skipFlush = true
|
||||
}
|
||||
if c.flags.isSet(closeConnection) {
|
||||
return false
|
||||
}
|
||||
c.flags.set(closeConnection)
|
||||
if skipFlush {
|
||||
c.flags.set(skipFlushOnClose)
|
||||
// For a websocket client, unless we are told not to flush, enqueue
|
||||
// a websocket CloseMessage based on the reason.
|
||||
if !skipFlush && c.ws != nil && !c.ws.closeSent {
|
||||
c.wsEnqueueCloseMessage(reason)
|
||||
}
|
||||
// Be consistent with the creation: for routes and gateways,
|
||||
// we use Noticef on create, so use that too for delete.
|
||||
@@ -1610,7 +1672,7 @@ func (c *client) queueOutbound(data []byte) bool {
|
||||
c.out.pb -= int64(len(data))
|
||||
atomic.AddInt64(&c.srv.slowConsumers, 1)
|
||||
c.Noticef("Slow Consumer Detected: MaxPending of %d Exceeded", c.out.mp)
|
||||
c.markConnAsClosed(SlowConsumerPendingBytes, true)
|
||||
c.markConnAsClosed(SlowConsumerPendingBytes)
|
||||
return referenced
|
||||
}
|
||||
|
||||
@@ -1755,6 +1817,10 @@ func (c *client) generateClientInfoJSON(info Info) []byte {
|
||||
info.CID = c.cid
|
||||
info.ClientIP = c.host
|
||||
info.MaxPayload = c.mpay
|
||||
if c.ws != nil {
|
||||
info.ClientConnectURLs = info.WSConnectURLs
|
||||
}
|
||||
info.WSConnectURLs = nil
|
||||
// Generate the info json
|
||||
b, _ := json.Marshal(info)
|
||||
pcs := [][]byte{[]byte("INFO"), b, []byte(CR_LF)}
|
||||
@@ -3803,7 +3869,7 @@ func (c *client) closeConnection(reason ClosedState) {
|
||||
// This will set the closeConnection flag and save the connection, etc..
|
||||
// Will return true if no writeLoop was started and TCP connection was
|
||||
// closed in place, in which case we need to do the teardown.
|
||||
teardownNow := c.markConnAsClosed(reason, false)
|
||||
teardownNow := c.markConnAsClosed(reason)
|
||||
c.mu.Unlock()
|
||||
|
||||
if teardownNow {
|
||||
@@ -3841,6 +3907,7 @@ func (c *client) teardownConn() {
|
||||
var (
|
||||
retryImplicit bool
|
||||
connectURLs []string
|
||||
wsConnectURLs []string
|
||||
gwName string
|
||||
gwIsOutbound bool
|
||||
gwCfg *gatewayCfg
|
||||
@@ -3870,6 +3937,7 @@ func (c *client) teardownConn() {
|
||||
retryImplicit = c.route.retry
|
||||
}
|
||||
connectURLs = c.route.connectURLs
|
||||
wsConnectURLs = c.route.wsConnURLs
|
||||
}
|
||||
if kind == GATEWAY {
|
||||
gwName = c.gw.name
|
||||
@@ -3894,11 +3962,11 @@ func (c *client) teardownConn() {
|
||||
|
||||
if srv != nil {
|
||||
// This is a route that disconnected, but we are not in lame duck mode...
|
||||
if len(connectURLs) > 0 && !srv.isLameDuckMode() {
|
||||
if (len(connectURLs) > 0 || len(wsConnectURLs) > 0) && !srv.isLameDuckMode() {
|
||||
// Unless disabled, possibly update the server's INFO protocol
|
||||
// and send to clients that know how to handle async INFOs.
|
||||
if !srv.getOpts().Cluster.NoAdvertise {
|
||||
srv.removeClientConnectURLsAndSendINFOToClients(connectURLs)
|
||||
srv.removeConnectURLsAndSendINFOToClients(connectURLs, wsConnectURLs)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Copyright 2012-2019 The NATS Authors
|
||||
// Copyright 2012-2020 The NATS Authors
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
@@ -72,7 +72,7 @@ func (c *testAsyncClient) parseAndClose(proto []byte) {
|
||||
func createClientAsync(ch chan *client, s *Server, cli net.Conn) {
|
||||
s.grWG.Add(1)
|
||||
go func() {
|
||||
c := s.createClient(cli)
|
||||
c := s.createClient(cli, nil)
|
||||
// Must be here to suppress +OK
|
||||
c.opts.Verbose = false
|
||||
go c.writeLoop()
|
||||
@@ -2163,6 +2163,10 @@ func (c *testConnWritePartial) Write(p []byte) (int, error) {
|
||||
return c.buf.Write(p[:n])
|
||||
}
|
||||
|
||||
func (c *testConnWritePartial) RemoteAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testConnWritePartial) SetWriteDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
@@ -2279,7 +2283,7 @@ func TestCloseConnectionVeryEarly(t *testing.T) {
|
||||
// Call again with this closed connection. Alternatively, we
|
||||
// would have to call with a fake connection that implements
|
||||
// net.Conn but returns an error on Write.
|
||||
s.createClient(c)
|
||||
s.createClient(c, nil)
|
||||
|
||||
// This connection should not have been added to the server.
|
||||
checkClientsCount(t, s, 0)
|
||||
|
||||
@@ -154,7 +154,7 @@ func waitCh(t *testing.T, ch chan bool, errTxt string) {
|
||||
}
|
||||
}
|
||||
|
||||
func natsConnect(t *testing.T, url string, options ...nats.Option) *nats.Conn {
|
||||
func natsConnect(t testing.TB, url string, options ...nats.Option) *nats.Conn {
|
||||
t.Helper()
|
||||
nc, err := nats.Connect(url, options...)
|
||||
if err != nil {
|
||||
@@ -215,7 +215,7 @@ func natsFlush(t *testing.T, nc *nats.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
func natsPub(t *testing.T, nc *nats.Conn, subj string, payload []byte) {
|
||||
func natsPub(t testing.TB, nc *nats.Conn, subj string, payload []byte) {
|
||||
t.Helper()
|
||||
if err := nc.Publish(subj, payload); err != nil {
|
||||
t.Fatalf("Error on publish: %v", err)
|
||||
|
||||
@@ -969,6 +969,7 @@ type Varz struct {
|
||||
TLSVerify bool `json:"tls_verify,omitempty"`
|
||||
IP string `json:"ip,omitempty"`
|
||||
ClientConnectURLs []string `json:"connect_urls,omitempty"`
|
||||
WSConnectURLs []string `json:"ws_connect_urls,omitempty"`
|
||||
MaxConn int `json:"max_connections"`
|
||||
MaxSubs int `json:"max_subscriptions,omitempty"`
|
||||
PingInterval time.Duration `json:"ping_interval"`
|
||||
@@ -1286,8 +1287,10 @@ func (s *Server) updateVarzRuntimeFields(v *Varz, forceUpdate bool, pcpu float64
|
||||
v.Mem = rss
|
||||
v.CPU = pcpu
|
||||
if l := len(s.info.ClientConnectURLs); l > 0 {
|
||||
v.ClientConnectURLs = make([]string, l)
|
||||
copy(v.ClientConnectURLs, s.info.ClientConnectURLs)
|
||||
v.ClientConnectURLs = append([]string(nil), s.info.ClientConnectURLs...)
|
||||
}
|
||||
if l := len(s.info.WSConnectURLs); l > 0 {
|
||||
v.WSConnectURLs = append([]string(nil), s.info.WSConnectURLs...)
|
||||
}
|
||||
v.Connections = len(s.clients)
|
||||
v.TotalConnections = s.totalClients
|
||||
|
||||
138
server/opts.go
138
server/opts.go
@@ -189,6 +189,7 @@ type Options struct {
|
||||
JetStreamMaxMemory int64 `json:"-"`
|
||||
JetStreamMaxStore int64 `json:"-"`
|
||||
StoreDir string `json:"-"`
|
||||
Websocket WebsocketOpts `json:"-"`
|
||||
ProfPort int `json:"-"`
|
||||
PidFile string `json:"-"`
|
||||
PortsFileDir string `json:"-"`
|
||||
@@ -246,6 +247,37 @@ type Options struct {
|
||||
routeProto int
|
||||
}
|
||||
|
||||
// WebsocketOpts ...
|
||||
type WebsocketOpts struct {
|
||||
// The server will accept websocket client connections on this hostname/IP.
|
||||
Host string
|
||||
// The server will accept websocket client connections on this port.
|
||||
Port int
|
||||
// The host:port to advertise to websocket clients in the cluster.
|
||||
Advertise string
|
||||
|
||||
// TLS configuration is required.
|
||||
TLSConfig *tls.Config
|
||||
|
||||
// If true, the Origin header must match the request's host.
|
||||
SameOrigin bool
|
||||
|
||||
// Only origins in this list will be accepted. If empty and
|
||||
// SameOrigin is false, any origin is accepted.
|
||||
AllowedOrigins []string
|
||||
|
||||
// If set to true, the server will negotiate with clients
|
||||
// if compression can be used. If this is false, no compression
|
||||
// will be used (both in server and clients) since it has to
|
||||
// be negotiated between both endpoints
|
||||
Compression bool
|
||||
|
||||
// Total time allowed for the server to read the client request
|
||||
// and write the response back to the client. This include the
|
||||
// time needed for the TLS Handshake.
|
||||
HandshakeTimeout time.Duration
|
||||
}
|
||||
|
||||
type netResolver interface {
|
||||
LookupHost(ctx context.Context, host string) ([]string, error)
|
||||
}
|
||||
@@ -843,6 +875,11 @@ func (o *Options) processConfigFileLine(k string, v interface{}, errors *[]error
|
||||
o.ConnectErrorReports = int(v.(int64))
|
||||
case "reconnect_error_reports":
|
||||
o.ReconnectErrorReports = int(v.(int64))
|
||||
case "websocket", "ws":
|
||||
if err := parseWebsocket(tk, o, errors, warnings); err != nil {
|
||||
*errors = append(*errors, err)
|
||||
return
|
||||
}
|
||||
default:
|
||||
if au := atomic.LoadInt32(&allowUnknownTopLevelField); au == 0 && !tk.IsUsedVariable() {
|
||||
err := &unknownConfigFieldErr{
|
||||
@@ -912,6 +949,8 @@ func parseListen(v interface{}) (*hostPort, error) {
|
||||
return nil, fmt.Errorf("could not parse port %q", port)
|
||||
}
|
||||
hp.host = host
|
||||
default:
|
||||
return nil, fmt.Errorf("expected port or host:port, got %T", vv)
|
||||
}
|
||||
return hp, nil
|
||||
}
|
||||
@@ -2940,6 +2979,100 @@ func parseTLS(v interface{}) (t *TLSConfigOpts, retErr error) {
|
||||
return &tc, nil
|
||||
}
|
||||
|
||||
func parseWebsocket(v interface{}, o *Options, errors *[]error, warnings *[]error) error {
|
||||
var lt token
|
||||
defer convertPanicToErrorList(<, errors)
|
||||
|
||||
tk, v := unwrapValue(v, <)
|
||||
gm, ok := v.(map[string]interface{})
|
||||
if !ok {
|
||||
return &configErr{tk, fmt.Sprintf("Expected websocket to be a map, got %T", v)}
|
||||
}
|
||||
for mk, mv := range gm {
|
||||
// Again, unwrap token value if line check is required.
|
||||
tk, mv = unwrapValue(mv, <)
|
||||
switch strings.ToLower(mk) {
|
||||
case "listen":
|
||||
hp, err := parseListen(mv)
|
||||
if err != nil {
|
||||
err := &configErr{tk, err.Error()}
|
||||
*errors = append(*errors, err)
|
||||
continue
|
||||
}
|
||||
o.Websocket.Host = hp.host
|
||||
o.Websocket.Port = hp.port
|
||||
case "port":
|
||||
o.Websocket.Port = int(mv.(int64))
|
||||
case "host", "net":
|
||||
o.Websocket.Host = mv.(string)
|
||||
case "advertise":
|
||||
o.Websocket.Advertise = mv.(string)
|
||||
case "tls":
|
||||
config, _, err := getTLSConfig(tk)
|
||||
if err != nil {
|
||||
*errors = append(*errors, err)
|
||||
continue
|
||||
}
|
||||
o.Websocket.TLSConfig = config
|
||||
case "same_origin":
|
||||
o.Websocket.SameOrigin = mv.(bool)
|
||||
case "allowed_origins", "allowed_origin", "allow_origins", "allow_origin", "origins", "origin":
|
||||
switch mv := mv.(type) {
|
||||
case string:
|
||||
o.Websocket.AllowedOrigins = []string{mv}
|
||||
case []interface{}:
|
||||
keys := make([]string, 0, len(mv))
|
||||
for _, val := range mv {
|
||||
tk, val = unwrapValue(val, <)
|
||||
if key, ok := val.(string); ok {
|
||||
keys = append(keys, key)
|
||||
} else {
|
||||
err := &configErr{tk, fmt.Sprintf("error parsing allowed origins: unsupported type in array %T", val)}
|
||||
*errors = append(*errors, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
o.Websocket.AllowedOrigins = keys
|
||||
default:
|
||||
err := &configErr{tk, fmt.Sprintf("error parsing allowed origins: unsupported type %T", mv)}
|
||||
*errors = append(*errors, err)
|
||||
}
|
||||
case "handshake_timeout":
|
||||
ht := time.Duration(0)
|
||||
switch mv := mv.(type) {
|
||||
case int64:
|
||||
ht = time.Duration(mv) * time.Second
|
||||
case string:
|
||||
var err error
|
||||
ht, err = time.ParseDuration(mv)
|
||||
if err != nil {
|
||||
err := &configErr{tk, err.Error()}
|
||||
*errors = append(*errors, err)
|
||||
continue
|
||||
}
|
||||
default:
|
||||
err := &configErr{tk, fmt.Sprintf("error parsing handshake timeout: unsupported type %T", mv)}
|
||||
*errors = append(*errors, err)
|
||||
}
|
||||
o.Websocket.HandshakeTimeout = ht
|
||||
case "compression":
|
||||
o.Websocket.Compression = mv.(bool)
|
||||
default:
|
||||
if !tk.IsUsedVariable() {
|
||||
err := &unknownConfigFieldErr{
|
||||
field: mk,
|
||||
configErr: configErr{
|
||||
token: tk,
|
||||
},
|
||||
}
|
||||
*errors = append(*errors, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenTLSConfig loads TLS related configuration parameters.
|
||||
func GenTLSConfig(tc *TLSConfigOpts) (*tls.Config, error) {
|
||||
// Create the tls.Config from our options before including the certs.
|
||||
@@ -3275,6 +3408,11 @@ func setBaselineOptions(opts *Options) {
|
||||
if opts.ReconnectErrorReports == 0 {
|
||||
opts.ReconnectErrorReports = DEFAULT_RECONNECT_ERROR_REPORTS
|
||||
}
|
||||
if opts.Websocket.Port != 0 {
|
||||
if opts.Websocket.Host == "" {
|
||||
opts.Websocket.Host = DEFAULT_HOST
|
||||
}
|
||||
}
|
||||
// JetStream
|
||||
if opts.JetStreamMaxMemory == 0 {
|
||||
opts.JetStreamMaxMemory = -1
|
||||
|
||||
@@ -306,8 +306,10 @@ func (c *clusterOption) Apply(server *Server) {
|
||||
server.routeInfo.AuthRequired = c.newValue.Username != ""
|
||||
if c.newValue.NoAdvertise {
|
||||
server.routeInfo.ClientConnectURLs = nil
|
||||
server.routeInfo.WSConnectURLs = nil
|
||||
} else {
|
||||
server.routeInfo.ClientConnectURLs = server.clientConnectURLs
|
||||
server.routeInfo.WSConnectURLs = server.websocket.connectURLs
|
||||
}
|
||||
server.setRouteInfoHostPortAndIP()
|
||||
server.mu.Unlock()
|
||||
@@ -523,7 +525,7 @@ type clientAdvertiseOption struct {
|
||||
// Apply the setting by updating the server info and regenerate the infoJSON byte array.
|
||||
func (c *clientAdvertiseOption) Apply(server *Server) {
|
||||
server.mu.Lock()
|
||||
server.setInfoHostPortAndGenerateJSON()
|
||||
server.setInfoHostPort()
|
||||
server.mu.Unlock()
|
||||
server.Noticef("Reload: client_advertise = %s", c.newValue)
|
||||
}
|
||||
@@ -627,6 +629,7 @@ func (s *Server) Reload() error {
|
||||
clusterOrgPort := curOpts.Cluster.Port
|
||||
gatewayOrgPort := curOpts.Gateway.Port
|
||||
leafnodesOrgPort := curOpts.LeafNode.Port
|
||||
websocketOrgPort := curOpts.Websocket.Port
|
||||
|
||||
s.mu.Unlock()
|
||||
|
||||
@@ -656,6 +659,9 @@ func (s *Server) Reload() error {
|
||||
if newOpts.LeafNode.Port == -1 {
|
||||
newOpts.LeafNode.Port = leafnodesOrgPort
|
||||
}
|
||||
if newOpts.Websocket.Port == -1 {
|
||||
newOpts.Websocket.Port = websocketOrgPort
|
||||
}
|
||||
|
||||
if err := s.reloadOptions(curOpts, newOpts); err != nil {
|
||||
return err
|
||||
@@ -756,6 +762,10 @@ func imposeOrder(value interface{}) error {
|
||||
sort.Slice(value.Gateways, func(i, j int) bool {
|
||||
return value.Gateways[i].Name < value.Gateways[j].Name
|
||||
})
|
||||
case WebsocketOpts:
|
||||
sort.Slice(value.AllowedOrigins, func(i, j int) bool {
|
||||
return value.AllowedOrigins[i] < value.AllowedOrigins[j]
|
||||
})
|
||||
case string, bool, int, int32, int64, time.Duration, float64, nil,
|
||||
LeafNodeOpts, ClusterOpts, *tls.Config, *URLAccResolver, *MemAccResolver, Authentication:
|
||||
// explicitly skipped types
|
||||
@@ -914,6 +924,18 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) {
|
||||
return nil, fmt.Errorf("config reload not supported for jetstream max memory")
|
||||
case "jetstreammaxstore":
|
||||
return nil, fmt.Errorf("config reload not supported for jetstream max storage")
|
||||
case "websocket":
|
||||
// Similar to gateways
|
||||
tmpOld := oldValue.(WebsocketOpts)
|
||||
tmpNew := newValue.(WebsocketOpts)
|
||||
tmpOld.TLSConfig = nil
|
||||
tmpNew.TLSConfig = nil
|
||||
// If there is really a change prevents reload.
|
||||
if !reflect.DeepEqual(tmpOld, tmpNew) {
|
||||
// See TODO(ik) note below about printing old/new values.
|
||||
return nil, fmt.Errorf("config reload not supported for %s: old=%v, new=%v",
|
||||
field.Name, oldValue, newValue)
|
||||
}
|
||||
case "connecterrorreports":
|
||||
diffOpts = append(diffOpts, &connectErrorReports{newValue: newValue.(int)})
|
||||
case "reconnecterrorreports":
|
||||
|
||||
@@ -73,6 +73,7 @@ type route struct {
|
||||
authRequired bool
|
||||
tlsRequired bool
|
||||
connectURLs []string
|
||||
wsConnURLs []string
|
||||
replySubs map[*subscription]*time.Timer
|
||||
gatewayURL string
|
||||
leafnodeURL string
|
||||
@@ -524,7 +525,7 @@ func (c *client) processRouteInfo(info *Info) {
|
||||
// Unless disabled, possibly update the server's INFO protocol
|
||||
// and send to clients that know how to handle async INFOs.
|
||||
if !s.getOpts().Cluster.NoAdvertise {
|
||||
s.addClientConnectURLsAndSendINFOToClients(info.ClientConnectURLs)
|
||||
s.addConnectURLsAndSendINFOToClients(info.ClientConnectURLs, info.WSConnectURLs)
|
||||
}
|
||||
} else {
|
||||
c.Debugf("Detected duplicate remote route %q", info.ID)
|
||||
@@ -570,7 +571,7 @@ func (c *client) updateRemoteRoutePerms(sl *Sublist, info *Info) {
|
||||
// sendAsyncInfoToClients sends an INFO protocol to all
|
||||
// connected clients that accept async INFO updates.
|
||||
// The server lock is held on entry.
|
||||
func (s *Server) sendAsyncInfoToClients() {
|
||||
func (s *Server) sendAsyncInfoToClients(regCli, wsCli bool) {
|
||||
// If there are no clients supporting async INFO protocols, we are done.
|
||||
// Also don't send if we are shutting down...
|
||||
if s.cproto == 0 || s.shutdown {
|
||||
@@ -583,7 +584,9 @@ func (s *Server) sendAsyncInfoToClients() {
|
||||
// registered (server has received CONNECT and first PING). For
|
||||
// clients that are not at this stage, this will happen in the
|
||||
// processing of the first PING (see client.processPing)
|
||||
if c.opts.Protocol >= ClientProtoInfo && c.flags.isSet(firstPongSent) {
|
||||
if ((regCli && c.ws == nil) || (wsCli && c.ws != nil)) &&
|
||||
c.opts.Protocol >= ClientProtoInfo &&
|
||||
c.flags.isSet(firstPongSent) {
|
||||
// sendInfo takes care of checking if the connection is still
|
||||
// valid or not, so don't duplicate tests here.
|
||||
c.enqueueProto(c.generateClientInfoJSON(s.copyInfo()))
|
||||
@@ -1236,6 +1239,7 @@ func (s *Server) addRoute(c *client, info *Info) (bool, bool) {
|
||||
s.remotes[id] = c
|
||||
c.mu.Lock()
|
||||
c.route.connectURLs = info.ClientConnectURLs
|
||||
c.route.wsConnURLs = info.WSConnectURLs
|
||||
cid := c.cid
|
||||
hash := string(c.route.hash)
|
||||
c.mu.Unlock()
|
||||
@@ -1495,6 +1499,7 @@ func (s *Server) routeAcceptLoop(ch chan struct{}) {
|
||||
// Set this if only if advertise is not disabled
|
||||
if !opts.Cluster.NoAdvertise {
|
||||
info.ClientConnectURLs = s.clientConnectURLs
|
||||
info.WSConnectURLs = s.websocket.connectURLs
|
||||
}
|
||||
// If we have selected a random port...
|
||||
if port == 0 {
|
||||
|
||||
193
server/server.go
193
server/server.go
@@ -80,7 +80,8 @@ type Info struct {
|
||||
ClientIP string `json:"client_ip,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
Cluster string `json:"cluster,omitempty"`
|
||||
ClientConnectURLs []string `json:"connect_urls,omitempty"` // Contains URLs a client can connect to.
|
||||
ClientConnectURLs []string `json:"connect_urls,omitempty"` // Contains URLs a client can connect to.
|
||||
WSConnectURLs []string `json:"ws_connect_urls,omitempty"` // Contains URLs a ws client can connect to.
|
||||
|
||||
// Route Specific
|
||||
Import *SubjectPermission `json:"import,omitempty"`
|
||||
@@ -214,6 +215,9 @@ type Server struct {
|
||||
|
||||
// For eventIDs
|
||||
eventIds *nuid.NUID
|
||||
|
||||
// Websocket structure
|
||||
websocket srvWebsocket
|
||||
}
|
||||
|
||||
// Make sure all are 64bits for atomic use
|
||||
@@ -301,6 +305,9 @@ func NewServer(opts *Options) (*Server, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Used internally for quick look-ups.
|
||||
s.websocket.connectURLsMap = make(map[string]struct{})
|
||||
|
||||
// Ensure that non-exported options (used in tests) are properly set.
|
||||
s.setLeafNodeNonExportedOptions()
|
||||
|
||||
@@ -322,7 +329,7 @@ func NewServer(opts *Options) (*Server, error) {
|
||||
// listener has been created (possibly with random port),
|
||||
// but since some tests may expect the INFO to be properly
|
||||
// set after New(), let's do it now.
|
||||
s.setInfoHostPortAndGenerateJSON()
|
||||
s.setInfoHostPort()
|
||||
|
||||
// For tracking clients
|
||||
s.clients = make(map[uint64]*client)
|
||||
@@ -434,7 +441,10 @@ func validateOptions(o *Options) error {
|
||||
}
|
||||
// Check that gateway is properly configured. Returns no error
|
||||
// if there is no gateway defined.
|
||||
return validateGatewayOptions(o)
|
||||
if err := validateGatewayOptions(o); err != nil {
|
||||
return err
|
||||
}
|
||||
return validateWebsocketOptions(o)
|
||||
}
|
||||
|
||||
func (s *Server) getOpts() *Options {
|
||||
@@ -1319,6 +1329,13 @@ func (s *Server) Start() {
|
||||
// port to be opened and potential ephemeral port selected.
|
||||
clientListenReady := make(chan struct{})
|
||||
|
||||
// Start websocket server if needed. Do this before starting the routes,
|
||||
// because we want to resolve the gateway host:port so that this information
|
||||
// can be sent to other routes.
|
||||
if opts.Websocket.Port != 0 {
|
||||
s.startWebsocketServer()
|
||||
}
|
||||
|
||||
// Start up routing as well if needed.
|
||||
if opts.Cluster.Port != 0 {
|
||||
s.startGoRoutine(func() {
|
||||
@@ -1402,6 +1419,14 @@ func (s *Server) Shutdown() {
|
||||
s.listener = nil
|
||||
}
|
||||
|
||||
// Kick websocket server
|
||||
if s.websocket.server != nil {
|
||||
doneExpected++
|
||||
s.websocket.server.Close()
|
||||
s.websocket.server = nil
|
||||
s.websocket.listener = nil
|
||||
}
|
||||
|
||||
// Kick leafnodes AcceptLoop()
|
||||
if s.leafNodeListener != nil {
|
||||
doneExpected++
|
||||
@@ -1512,7 +1537,6 @@ func (s *Server) AcceptLoop(clr chan struct{}) {
|
||||
|
||||
// Setup state that can enable shutdown
|
||||
s.mu.Lock()
|
||||
s.listener = l
|
||||
|
||||
// If server was started with RANDOM_PORT (-1), opts.Port would be equal
|
||||
// to 0 at the beginning this function. So we need to get the actual port
|
||||
@@ -1523,14 +1547,16 @@ func (s *Server) AcceptLoop(clr chan struct{}) {
|
||||
|
||||
// Now that port has been set (if it was set to RANDOM), set the
|
||||
// server's info Host/Port with either values from Options or
|
||||
// ClientAdvertise. Also generate the JSON byte array.
|
||||
if err := s.setInfoHostPortAndGenerateJSON(); err != nil {
|
||||
// ClientAdvertise.
|
||||
if err := s.setInfoHostPort(); err != nil {
|
||||
s.Fatalf("Error setting server INFO with ClientAdvertise value of %s, err=%v", s.opts.ClientAdvertise, err)
|
||||
l.Close()
|
||||
s.mu.Unlock()
|
||||
return
|
||||
}
|
||||
// Keep track of client connect URLs. We may need them later.
|
||||
s.clientConnectURLs = s.getClientConnectURLs()
|
||||
s.listener = l
|
||||
s.mu.Unlock()
|
||||
|
||||
// Let the caller know that we are ready
|
||||
@@ -1554,7 +1580,7 @@ func (s *Server) AcceptLoop(clr chan struct{}) {
|
||||
}
|
||||
tmpDelay = ACCEPT_MIN_SLEEP
|
||||
s.startGoRoutine(func() {
|
||||
s.createClient(conn)
|
||||
s.createClient(conn, nil)
|
||||
s.grWG.Done()
|
||||
})
|
||||
}
|
||||
@@ -1565,8 +1591,7 @@ func (s *Server) AcceptLoop(clr chan struct{}) {
|
||||
// Note that this function may be called during config reload, this is why
|
||||
// Host/Port may be reset to original Options if the ClientAdvertise option
|
||||
// is not set (since it may have previously been).
|
||||
// The function then generates the server infoJSON.
|
||||
func (s *Server) setInfoHostPortAndGenerateJSON() error {
|
||||
func (s *Server) setInfoHostPort() error {
|
||||
// When this function is called, opts.Port is set to the actual listen
|
||||
// port (if option was originally set to RANDOM), even during a config
|
||||
// reload. So use of s.opts.Port is safe.
|
||||
@@ -1799,14 +1824,16 @@ func (s *Server) HTTPHandler() http.Handler {
|
||||
return s.httpHandler
|
||||
}
|
||||
|
||||
// Perform a conditional deep copy due to reference nature of ClientConnectURLs.
|
||||
// Perform a conditional deep copy due to reference nature of [Client|WS]ConnectURLs.
|
||||
// If updates are made to Info, this function should be consulted and updated.
|
||||
// Assume lock is held.
|
||||
func (s *Server) copyInfo() Info {
|
||||
info := s.info
|
||||
if info.ClientConnectURLs != nil {
|
||||
info.ClientConnectURLs = make([]string, len(s.info.ClientConnectURLs))
|
||||
copy(info.ClientConnectURLs, s.info.ClientConnectURLs)
|
||||
if len(info.ClientConnectURLs) > 0 {
|
||||
info.ClientConnectURLs = append([]string(nil), s.info.ClientConnectURLs...)
|
||||
}
|
||||
if len(info.WSConnectURLs) > 0 {
|
||||
info.WSConnectURLs = append([]string(nil), s.info.WSConnectURLs...)
|
||||
}
|
||||
if s.nonceRequired() {
|
||||
// Nonce handling
|
||||
@@ -1818,7 +1845,7 @@ func (s *Server) copyInfo() Info {
|
||||
return info
|
||||
}
|
||||
|
||||
func (s *Server) createClient(conn net.Conn) *client {
|
||||
func (s *Server) createClient(conn net.Conn, ws *websocket) *client {
|
||||
// Snapshot server options.
|
||||
opts := s.getOpts()
|
||||
|
||||
@@ -1830,7 +1857,7 @@ func (s *Server) createClient(conn net.Conn) *client {
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: maxPay, msubs: maxSubs, start: now, last: now}
|
||||
c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: maxPay, msubs: maxSubs, start: now, last: now, ws: ws}
|
||||
|
||||
c.registerWithAccount(s.globalAccount())
|
||||
|
||||
@@ -1857,6 +1884,7 @@ func (s *Server) createClient(conn net.Conn) *client {
|
||||
// TLS handshake is done (if applicable).
|
||||
c.sendProtoNow(c.generateClientInfoJSON(info))
|
||||
|
||||
tlsRequired := ws == nil && info.TLSRequired
|
||||
// Unlock to register
|
||||
c.mu.Unlock()
|
||||
|
||||
@@ -1885,7 +1913,7 @@ func (s *Server) createClient(conn net.Conn) *client {
|
||||
c.mu.Lock()
|
||||
|
||||
// Check for TLS
|
||||
if info.TLSRequired {
|
||||
if tlsRequired {
|
||||
c.Debugf("Starting TLS client connection handshake")
|
||||
c.nc = tls.Server(c.nc, opts.TLSConfig)
|
||||
conn := c.nc.(*tls.Conn)
|
||||
@@ -1942,7 +1970,7 @@ func (s *Server) createClient(conn net.Conn) *client {
|
||||
// Spin up the write loop.
|
||||
s.startGoRoutine(func() { c.writeLoop() })
|
||||
|
||||
if info.TLSRequired {
|
||||
if tlsRequired {
|
||||
c.Debugf("TLS handshake complete")
|
||||
cs := c.nc.(*tls.Conn).ConnectionState()
|
||||
c.Debugf("TLS version %s, cipher suite %s", tlsVersion(cs.Version), tlsCipher(cs.CipherSuite))
|
||||
@@ -1990,57 +2018,66 @@ func (s *Server) saveClosedClient(c *client, nc net.Conn, reason ClosedState) {
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// Adds the given array of urls to the server's INFO.ClientConnectURLs
|
||||
// array. The server INFO JSON is regenerated.
|
||||
// Note that a check is made to ensure that given URLs are not
|
||||
// already present. So the INFO JSON is regenerated only if new ULRs
|
||||
// were added.
|
||||
// Adds to the list of client and websocket clients connect URLs.
|
||||
// If there was a change, an INFO protocol is sent to registered clients
|
||||
// that support async INFO protocols.
|
||||
func (s *Server) addClientConnectURLsAndSendINFOToClients(urls []string) {
|
||||
s.updateServerINFOAndSendINFOToClients(urls, true)
|
||||
func (s *Server) addConnectURLsAndSendINFOToClients(curls, wsurls []string) {
|
||||
s.updateServerINFOAndSendINFOToClients(curls, wsurls, true)
|
||||
}
|
||||
|
||||
// Removes the given array of urls from the server's INFO.ClientConnectURLs
|
||||
// array. The server INFO JSON is regenerated if needed.
|
||||
// Removes from the list of client and websocket clients connect URLs.
|
||||
// If there was a change, an INFO protocol is sent to registered clients
|
||||
// that support async INFO protocols.
|
||||
func (s *Server) removeClientConnectURLsAndSendINFOToClients(urls []string) {
|
||||
s.updateServerINFOAndSendINFOToClients(urls, false)
|
||||
func (s *Server) removeConnectURLsAndSendINFOToClients(curls, wsurls []string) {
|
||||
s.updateServerINFOAndSendINFOToClients(curls, wsurls, false)
|
||||
}
|
||||
|
||||
// Updates the server's Info object with the given array of URLs and re-generate
|
||||
// the infoJSON byte array, then send an (async) INFO protocol to clients that
|
||||
// support it.
|
||||
func (s *Server) updateServerINFOAndSendINFOToClients(urls []string, add bool) {
|
||||
// Updates the list of client and websocket clients connect URLs and if any change
|
||||
// sends an async INFO update to clients that support it.
|
||||
func (s *Server) updateServerINFOAndSendINFOToClients(curls, wsurls []string, add bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Will be set to true if we alter the server's Info object.
|
||||
wasUpdated := false
|
||||
remove := !add
|
||||
for _, url := range urls {
|
||||
_, present := s.clientConnectURLsMap[url]
|
||||
if add && !present {
|
||||
s.clientConnectURLsMap[url] = struct{}{}
|
||||
wasUpdated = true
|
||||
} else if remove && present {
|
||||
delete(s.clientConnectURLsMap, url)
|
||||
wasUpdated = true
|
||||
checkMap := func(urls []string, m map[string]struct{}) bool {
|
||||
wasUpdated := false
|
||||
for _, url := range urls {
|
||||
_, present := m[url]
|
||||
if add && !present {
|
||||
m[url] = struct{}{}
|
||||
wasUpdated = true
|
||||
} else if remove && present {
|
||||
delete(m, url)
|
||||
wasUpdated = true
|
||||
}
|
||||
}
|
||||
return wasUpdated
|
||||
}
|
||||
cliUpdated := checkMap(curls, s.clientConnectURLsMap)
|
||||
wsUpdated := checkMap(wsurls, s.websocket.connectURLsMap)
|
||||
|
||||
updateInfo := func(infoURLs *[]string, urls []string, m map[string]struct{}) {
|
||||
// Recreate the info's slice from the map
|
||||
*infoURLs = (*infoURLs)[:0]
|
||||
// Add this server client connect ULRs first...
|
||||
*infoURLs = append(*infoURLs, urls...)
|
||||
// Then the ones from the map
|
||||
for url := range m {
|
||||
*infoURLs = append(*infoURLs, url)
|
||||
}
|
||||
}
|
||||
if wasUpdated {
|
||||
// Recreate the info.ClientConnectURL array from the map
|
||||
s.info.ClientConnectURLs = s.info.ClientConnectURLs[:0]
|
||||
// Add this server client connect ULRs first...
|
||||
s.info.ClientConnectURLs = append(s.info.ClientConnectURLs, s.clientConnectURLs...)
|
||||
for url := range s.clientConnectURLsMap {
|
||||
s.info.ClientConnectURLs = append(s.info.ClientConnectURLs, url)
|
||||
}
|
||||
if cliUpdated {
|
||||
updateInfo(&s.info.ClientConnectURLs, s.clientConnectURLs, s.clientConnectURLsMap)
|
||||
}
|
||||
if wsUpdated {
|
||||
updateInfo(&s.info.WSConnectURLs, s.websocket.connectURLs, s.websocket.connectURLsMap)
|
||||
}
|
||||
if cliUpdated || wsUpdated {
|
||||
// Update the time of this update
|
||||
s.lastCURLsUpdate = time.Now().UnixNano()
|
||||
// Send to all registered clients that support async INFO protocols.
|
||||
s.sendAsyncInfoToClients()
|
||||
s.sendAsyncInfoToClients(cliUpdated, wsUpdated)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2265,7 +2302,11 @@ func (s *Server) ReadyForConnections(dur time.Duration) bool {
|
||||
end := time.Now().Add(dur)
|
||||
for time.Now().Before(end) {
|
||||
s.mu.Lock()
|
||||
ok := s.listener != nil && (opts.Cluster.Port == 0 || s.routeListener != nil) && (opts.Gateway.Name == "" || s.gatewayListener != nil)
|
||||
ok := s.listener != nil &&
|
||||
(opts.Cluster.Port == 0 || s.routeListener != nil) &&
|
||||
(opts.Gateway.Name == "" || s.gatewayListener != nil) &&
|
||||
(opts.LeafNode.Port == 0 || s.leafNodeListener != nil) &&
|
||||
(opts.Websocket.Port == 0 || s.websocket.listener != nil)
|
||||
s.mu.Unlock()
|
||||
if ok {
|
||||
return true
|
||||
@@ -2332,16 +2373,28 @@ func (s *Server) closedClients() []*closedClient {
|
||||
func (s *Server) getClientConnectURLs() []string {
|
||||
// Snapshot server options.
|
||||
opts := s.getOpts()
|
||||
// Ignore error here since we know that if there is client advertise, the
|
||||
// parseHostPort is correct because we did it right before calling this
|
||||
// function in Server.New().
|
||||
urls, _ := s.getConnectURLs(opts.ClientAdvertise, opts.Host, opts.Port)
|
||||
return urls
|
||||
}
|
||||
|
||||
// Generic version that will return an array of URLs based on the given
|
||||
// advertise, host and port values.
|
||||
func (s *Server) getConnectURLs(advertise, host string, port int) ([]string, error) {
|
||||
urls := make([]string, 0, 1)
|
||||
|
||||
// short circuit if client advertise is set
|
||||
if opts.ClientAdvertise != "" {
|
||||
// just use the info host/port. This is updated in s.New()
|
||||
urls = append(urls, net.JoinHostPort(s.info.Host, strconv.Itoa(s.info.Port)))
|
||||
// short circuit if advertise is set
|
||||
if advertise != "" {
|
||||
h, p, err := parseHostPort(advertise, port)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
urls = append(urls, net.JoinHostPort(h, strconv.Itoa(p)))
|
||||
} else {
|
||||
sPort := strconv.Itoa(opts.Port)
|
||||
_, ips, err := s.getNonLocalIPsIfHostIsIPAny(opts.Host, true)
|
||||
sPort := strconv.Itoa(port)
|
||||
_, ips, err := s.getNonLocalIPsIfHostIsIPAny(host, true)
|
||||
for _, ip := range ips {
|
||||
urls = append(urls, net.JoinHostPort(ip, sPort))
|
||||
}
|
||||
@@ -2352,14 +2405,14 @@ func (s *Server) getClientConnectURLs() []string {
|
||||
// and not add any address in the array in the loop above, and we
|
||||
// ended-up returning 0.0.0.0, which is problematic for Windows clients.
|
||||
// Check for 0.0.0.0 or :: specifically, and ignore if that's the case.
|
||||
if opts.Host == "0.0.0.0" || opts.Host == "::" {
|
||||
s.Errorf("Address %q can not be resolved properly", opts.Host)
|
||||
if host == "0.0.0.0" || host == "::" {
|
||||
s.Errorf("Address %q can not be resolved properly", host)
|
||||
} else {
|
||||
urls = append(urls, net.JoinHostPort(opts.Host, sPort))
|
||||
urls = append(urls, net.JoinHostPort(host, sPort))
|
||||
}
|
||||
}
|
||||
}
|
||||
return urls
|
||||
return urls, nil
|
||||
}
|
||||
|
||||
// Returns an array of non local IPs if the provided host is
|
||||
@@ -2450,6 +2503,7 @@ type Ports struct {
|
||||
Monitoring []string `json:"monitoring,omitempty"`
|
||||
Cluster []string `json:"cluster,omitempty"`
|
||||
Profile []string `json:"profile,omitempty"`
|
||||
WebSocket []string `json:"websocket,omitempty"`
|
||||
}
|
||||
|
||||
// PortsInfo attempts to resolve all the ports. If after maxWait the ports are not
|
||||
@@ -2460,18 +2514,20 @@ func (s *Server) PortsInfo(maxWait time.Duration) *Ports {
|
||||
opts := s.getOpts()
|
||||
|
||||
s.mu.Lock()
|
||||
info := s.copyInfo()
|
||||
tls := s.info.TLSRequired
|
||||
listener := s.listener
|
||||
httpListener := s.http
|
||||
clusterListener := s.routeListener
|
||||
profileListener := s.profiler
|
||||
wsListener := s.websocket.listener
|
||||
wss := s.websocket.tls
|
||||
s.mu.Unlock()
|
||||
|
||||
ports := Ports{}
|
||||
|
||||
if listener != nil {
|
||||
natsProto := "nats"
|
||||
if info.TLSRequired {
|
||||
if tls {
|
||||
natsProto = "tls"
|
||||
}
|
||||
ports.Nats = formatURL(natsProto, listener)
|
||||
@@ -2497,6 +2553,14 @@ func (s *Server) PortsInfo(maxWait time.Duration) *Ports {
|
||||
ports.Profile = formatURL("http", profileListener)
|
||||
}
|
||||
|
||||
if wsListener != nil {
|
||||
protocol := "ws"
|
||||
if wss {
|
||||
protocol = "wss"
|
||||
}
|
||||
ports.WebSocket = formatURL(protocol, wsListener)
|
||||
}
|
||||
|
||||
return &ports
|
||||
}
|
||||
|
||||
@@ -2600,6 +2664,9 @@ func (s *Server) serviceListeners() []net.Listener {
|
||||
if opts.ProfPort != 0 {
|
||||
listeners = append(listeners, s.profiler)
|
||||
}
|
||||
if opts.Websocket.Port != 0 {
|
||||
listeners = append(listeners, s.websocket.listener)
|
||||
}
|
||||
return listeners
|
||||
}
|
||||
|
||||
|
||||
@@ -86,8 +86,7 @@ func secondsToDuration(seconds float64) time.Duration {
|
||||
func parseHostPort(hostPort string, defaultPort int) (host string, port int, err error) {
|
||||
if hostPort != "" {
|
||||
host, sPort, err := net.SplitHostPort(hostPort)
|
||||
switch err.(type) {
|
||||
case *net.AddrError:
|
||||
if ae, ok := err.(*net.AddrError); ok && strings.Contains(ae.Err, "missing port") {
|
||||
// try appending the current port
|
||||
host, sPort, err = net.SplitHostPort(fmt.Sprintf("%s:%d", hostPort, defaultPort))
|
||||
}
|
||||
|
||||
977
server/websocket.go
Normal file
977
server/websocket.go
Normal file
@@ -0,0 +1,977 @@
|
||||
// Copyright 2020 The NATS Authors
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"crypto/sha1"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
type wsOpCode int
|
||||
|
||||
const (
|
||||
// From https://tools.ietf.org/html/rfc6455#section-5.2
|
||||
wsTextMessage = wsOpCode(1)
|
||||
wsBinaryMessage = wsOpCode(2)
|
||||
wsCloseMessage = wsOpCode(8)
|
||||
wsPingMessage = wsOpCode(9)
|
||||
wsPongMessage = wsOpCode(10)
|
||||
|
||||
wsFinalBit = 1 << 7
|
||||
wsRsv1Bit = 1 << 6 // Used for compression, from https://tools.ietf.org/html/rfc7692#section-6
|
||||
wsRsv2Bit = 1 << 5
|
||||
wsRsv3Bit = 1 << 4
|
||||
|
||||
wsMaskBit = 1 << 7
|
||||
|
||||
wsContinuationFrame = 0
|
||||
wsMaxFrameHeaderSize = 10 // For a server-to-client frame
|
||||
wsMaxControlPayloadSize = 125
|
||||
wsFrameSizeForBrowsers = 4096 // From experiment, webrowsers behave better with limited frame size
|
||||
|
||||
// From https://tools.ietf.org/html/rfc6455#section-11.7
|
||||
wsCloseStatusNormalClosure = 1000
|
||||
wsCloseStatusGoingAway = 1001
|
||||
wsCloseStatusProtocolError = 1002
|
||||
wsCloseStatusUnsupportedData = 1003
|
||||
wsCloseStatusNoStatusReceived = 1005
|
||||
wsCloseStatusAbnormalClosure = 1006
|
||||
wsCloseStatusInvalidPayloadData = 1007
|
||||
wsCloseStatusPolicyViolation = 1008
|
||||
wsCloseStatusMessageTooBig = 1009
|
||||
wsCloseStatusInternalSrvError = 1011
|
||||
wsCloseStatusTLSHandshake = 1015
|
||||
|
||||
wsFirstFrame = true
|
||||
wsContFrame = false
|
||||
wsFinalFrame = true
|
||||
wsCompressedFrame = true
|
||||
wsUncompressedFrame = false
|
||||
)
|
||||
|
||||
var decompressorPool sync.Pool
|
||||
|
||||
// From https://tools.ietf.org/html/rfc6455#section-1.3
|
||||
var wsGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||
|
||||
type websocket struct {
|
||||
frames net.Buffers
|
||||
fs int64
|
||||
closeMsg []byte
|
||||
compress bool
|
||||
closeSent bool
|
||||
browser bool
|
||||
compressor *flate.Writer
|
||||
}
|
||||
|
||||
type srvWebsocket struct {
|
||||
mu sync.RWMutex
|
||||
server *http.Server
|
||||
listener net.Listener
|
||||
tls bool
|
||||
allowedOrigins map[string]*allowedOrigin // host will be the key
|
||||
sameOrigin bool
|
||||
connectURLs []string
|
||||
connectURLsMap map[string]struct{}
|
||||
}
|
||||
|
||||
type allowedOrigin struct {
|
||||
scheme string
|
||||
port string
|
||||
}
|
||||
|
||||
type wsUpgradeResult struct {
|
||||
conn net.Conn
|
||||
ws *websocket
|
||||
}
|
||||
|
||||
type wsReadInfo struct {
|
||||
rem int
|
||||
fs bool
|
||||
ff bool
|
||||
fc bool
|
||||
mkpos byte
|
||||
mkey [4]byte
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func (r *wsReadInfo) init() {
|
||||
r.fs, r.ff = true, true
|
||||
}
|
||||
|
||||
// Returns a slice containing `needed` bytes from the given buffer `buf`
|
||||
// starting at position `pos`, and possibly read from the given reader `r`.
|
||||
// When bytes are present in `buf`, the `pos` is incremented by the number
|
||||
// of bytes found up to `needed` and the new position is returned. If not
|
||||
// enough bytes are found, the bytes found in `buf` are copied to the returned
|
||||
// slice and the remaning bytes are read from `r`.
|
||||
func wsGet(r io.Reader, buf []byte, pos, needed int) ([]byte, int, error) {
|
||||
avail := len(buf) - pos
|
||||
if avail >= needed {
|
||||
return buf[pos : pos+needed], pos + needed, nil
|
||||
}
|
||||
b := make([]byte, needed)
|
||||
start := copy(b, buf[pos:])
|
||||
for start != needed {
|
||||
n, err := r.Read(b[start:cap(b)])
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
start += n
|
||||
}
|
||||
return b, pos + avail, nil
|
||||
}
|
||||
|
||||
// Returns a slice of byte slices corresponding to payload of websocket frames.
|
||||
// The byte slice `buf` is filled with bytes from the connection's read loop.
|
||||
// This function will decode the frame headers and unmask the payload(s).
|
||||
// It is possible that the returned slices point to the given `buf` slice, so
|
||||
// `buf` should not be overwritten until the returned slices have been parsed.
|
||||
//
|
||||
// Client lock MUST NOT be held on entry.
|
||||
func (c *client) wsRead(r *wsReadInfo, ior io.Reader, buf []byte) ([][]byte, error) {
|
||||
var (
|
||||
bufs [][]byte
|
||||
tmpBuf []byte
|
||||
err error
|
||||
pos int
|
||||
max = len(buf)
|
||||
)
|
||||
for pos != max {
|
||||
if r.fs {
|
||||
b0 := buf[pos]
|
||||
frameType := wsOpCode(b0 & 0xF)
|
||||
final := b0&wsFinalBit != 0
|
||||
compressed := b0&wsRsv1Bit != 0
|
||||
pos++
|
||||
|
||||
tmpBuf, pos, err = wsGet(ior, buf, pos, 1)
|
||||
if err != nil {
|
||||
return bufs, err
|
||||
}
|
||||
b1 := tmpBuf[0]
|
||||
|
||||
// Clients MUST set the mask bit. If not set, reject.
|
||||
if b1&wsMaskBit == 0 {
|
||||
return bufs, c.wsHandleProtocolError("mask bit missing")
|
||||
}
|
||||
|
||||
// Store size in case it is < 125
|
||||
r.rem = int(b1 & 0x7F)
|
||||
|
||||
switch frameType {
|
||||
case wsPingMessage, wsPongMessage, wsCloseMessage:
|
||||
if r.rem > wsMaxControlPayloadSize {
|
||||
return bufs, c.wsHandleProtocolError(
|
||||
fmt.Sprintf("control frame length bigger than maximum allowed of %v bytes",
|
||||
wsMaxControlPayloadSize))
|
||||
}
|
||||
if !final {
|
||||
return bufs, c.wsHandleProtocolError("control frame does not have final bit set")
|
||||
}
|
||||
case wsTextMessage, wsBinaryMessage:
|
||||
if !r.ff {
|
||||
return bufs, c.wsHandleProtocolError("new message started before final frame for previous message was received")
|
||||
}
|
||||
r.ff = final
|
||||
r.fc = compressed
|
||||
case wsContinuationFrame:
|
||||
// Compressed bit must be only set in the first frame
|
||||
if r.ff || compressed {
|
||||
return bufs, c.wsHandleProtocolError("invalid continuation frame")
|
||||
}
|
||||
r.ff = final
|
||||
default:
|
||||
return bufs, c.wsHandleProtocolError(fmt.Sprintf("unknown opcode %v", frameType))
|
||||
}
|
||||
|
||||
switch r.rem {
|
||||
case 126:
|
||||
tmpBuf, pos, err = wsGet(ior, buf, pos, 2)
|
||||
if err != nil {
|
||||
return bufs, err
|
||||
}
|
||||
r.rem = int(binary.BigEndian.Uint16(tmpBuf))
|
||||
case 127:
|
||||
tmpBuf, pos, err = wsGet(ior, buf, pos, 8)
|
||||
if err != nil {
|
||||
return bufs, err
|
||||
}
|
||||
r.rem = int(binary.BigEndian.Uint64(tmpBuf))
|
||||
}
|
||||
|
||||
// Read masking key
|
||||
tmpBuf, pos, err = wsGet(ior, buf, pos, 4)
|
||||
if err != nil {
|
||||
return bufs, err
|
||||
}
|
||||
copy(r.mkey[:], tmpBuf)
|
||||
r.mkpos = 0
|
||||
|
||||
// Handle control messages in place...
|
||||
if wsIsControlFrame(frameType) {
|
||||
pos, err = c.wsHandleControlFrame(r, frameType, ior, buf, pos)
|
||||
if err != nil {
|
||||
return bufs, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Done with the frame header
|
||||
r.fs = false
|
||||
}
|
||||
if pos < max {
|
||||
var b []byte
|
||||
var n int
|
||||
|
||||
n = r.rem
|
||||
if pos+n > max {
|
||||
n = max - pos
|
||||
}
|
||||
b = buf[pos : pos+n]
|
||||
pos += n
|
||||
r.rem -= n
|
||||
if r.fc {
|
||||
r.buf = append(r.buf, b...)
|
||||
b = r.buf
|
||||
}
|
||||
if !r.fc || r.rem == 0 {
|
||||
r.unmask(b)
|
||||
if r.fc {
|
||||
// As per https://tools.ietf.org/html/rfc7692#section-7.2.2
|
||||
// add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader
|
||||
// does not report unexpected EOF.
|
||||
b = append(b, 0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff)
|
||||
br := bytes.NewBuffer(b)
|
||||
d, _ := decompressorPool.Get().(io.ReadCloser)
|
||||
if d == nil {
|
||||
d = flate.NewReader(br)
|
||||
} else {
|
||||
d.(flate.Resetter).Reset(br, nil)
|
||||
}
|
||||
b, err = ioutil.ReadAll(d)
|
||||
decompressorPool.Put(d)
|
||||
if err != nil {
|
||||
return bufs, err
|
||||
}
|
||||
}
|
||||
bufs = append(bufs, b)
|
||||
if r.rem == 0 {
|
||||
r.fs, r.fc, r.buf = true, false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return bufs, nil
|
||||
}
|
||||
|
||||
// Handles the PING, PONG and CLOSE websocket control frames.
|
||||
//
|
||||
// Client lock MUST NOT be held on entry.
|
||||
func (c *client) wsHandleControlFrame(r *wsReadInfo, frameType wsOpCode, nc io.Reader, buf []byte, pos int) (int, error) {
|
||||
var payload []byte
|
||||
var err error
|
||||
|
||||
statusPos := pos
|
||||
if r.rem > 0 {
|
||||
payload, pos, err = wsGet(nc, buf, pos, r.rem)
|
||||
if err != nil {
|
||||
return pos, err
|
||||
}
|
||||
r.unmask(payload)
|
||||
r.rem = 0
|
||||
}
|
||||
switch frameType {
|
||||
case wsCloseMessage:
|
||||
status := wsCloseStatusNoStatusReceived
|
||||
body := ""
|
||||
// If there is a payload, it should contain 2 unsigned bytes
|
||||
// that represent the status code and then optional payload.
|
||||
if len(payload) >= 2 {
|
||||
status = int(binary.BigEndian.Uint16(buf[statusPos : statusPos+2]))
|
||||
body = string(buf[statusPos+2 : statusPos+len(payload)])
|
||||
if body != "" && !utf8.ValidString(body) {
|
||||
// https://tools.ietf.org/html/rfc6455#section-5.5.1
|
||||
// If body is present, it must be a valid utf8
|
||||
status = wsCloseStatusInvalidPayloadData
|
||||
body = "invalid utf8 body in close frame"
|
||||
}
|
||||
}
|
||||
c.wsEnqueueControlMessage(wsCloseMessage, wsCreateCloseMessage(status, body))
|
||||
// Return io.EOF so that readLoop will close the connection as ClientClosed
|
||||
// after processing pending buffers.
|
||||
return pos, io.EOF
|
||||
case wsPingMessage:
|
||||
c.wsEnqueueControlMessage(wsPongMessage, payload)
|
||||
case wsPongMessage:
|
||||
// Nothing to do..
|
||||
}
|
||||
return pos, nil
|
||||
}
|
||||
|
||||
// Unmask the given slice.
|
||||
func (r *wsReadInfo) unmask(buf []byte) {
|
||||
p := int(r.mkpos)
|
||||
if len(buf) < 16 {
|
||||
for i := 0; i < len(buf); i++ {
|
||||
buf[i] ^= r.mkey[p&3]
|
||||
p++
|
||||
}
|
||||
r.mkpos = byte(p & 3)
|
||||
return
|
||||
}
|
||||
var k [8]byte
|
||||
for i := 0; i < 8; i++ {
|
||||
k[i] = r.mkey[(p+i)&3]
|
||||
}
|
||||
km := binary.BigEndian.Uint64(k[:])
|
||||
n := (len(buf) / 8) * 8
|
||||
for i := 0; i < n; i += 8 {
|
||||
tmp := binary.BigEndian.Uint64(buf[i : i+8])
|
||||
tmp ^= km
|
||||
binary.BigEndian.PutUint64(buf[i:], tmp)
|
||||
}
|
||||
buf = buf[n:]
|
||||
for i := 0; i < len(buf); i++ {
|
||||
buf[i] ^= r.mkey[p&3]
|
||||
p++
|
||||
}
|
||||
r.mkpos = byte(p & 3)
|
||||
}
|
||||
|
||||
// Returns true if the op code corresponds to a control frame.
|
||||
func wsIsControlFrame(frameType wsOpCode) bool {
|
||||
return frameType >= wsCloseMessage
|
||||
}
|
||||
|
||||
// Create the frame header.
|
||||
// Encodes the frame type and optional compression flag, and the size of the payload.
|
||||
func wsCreateFrameHeader(compressed bool, frameType wsOpCode, l int) []byte {
|
||||
fh := make([]byte, wsMaxFrameHeaderSize)
|
||||
n := wsFillFrameHeader(fh, wsFirstFrame, wsFinalFrame, compressed, frameType, l)
|
||||
return fh[:n]
|
||||
}
|
||||
|
||||
func wsFillFrameHeader(fh []byte, first, final, compressed bool, frameType wsOpCode, l int) int {
|
||||
var n int
|
||||
var b byte
|
||||
if first {
|
||||
b = byte(frameType)
|
||||
}
|
||||
if final {
|
||||
b |= wsFinalBit
|
||||
}
|
||||
if compressed {
|
||||
b |= wsRsv1Bit
|
||||
}
|
||||
switch {
|
||||
case l <= 125:
|
||||
n = 2
|
||||
fh[0] = b
|
||||
fh[1] = byte(l)
|
||||
case l < 65536:
|
||||
n = 4
|
||||
fh[0] = b
|
||||
fh[1] = 126
|
||||
binary.BigEndian.PutUint16(fh[2:], uint16(l))
|
||||
default:
|
||||
n = 10
|
||||
fh[0] = b
|
||||
fh[1] = 127
|
||||
binary.BigEndian.PutUint64(fh[2:], uint64(l))
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// Invokes wsEnqueueControlMessageLocked under client lock.
|
||||
//
|
||||
// Client lock MUST NOT be held on entry
|
||||
func (c *client) wsEnqueueControlMessage(controlMsg wsOpCode, payload []byte) {
|
||||
c.mu.Lock()
|
||||
c.wsEnqueueControlMessageLocked(controlMsg, payload)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// Enqueues a websocket control message.
|
||||
// If the control message is a wsCloseMessage, then marks this client
|
||||
// has having sent the close message (since only one should be sent).
|
||||
// This will prevent the generic closeConnection() to enqueue one.
|
||||
//
|
||||
// Client lock held on entry.
|
||||
func (c *client) wsEnqueueControlMessageLocked(controlMsg wsOpCode, payload []byte) {
|
||||
// Control messages are never compressed and their size will be
|
||||
// less than wsMaxControlPayloadSize, which means the frame header
|
||||
// will be only 2 bytes.
|
||||
cm := make([]byte, 2+len(payload))
|
||||
wsFillFrameHeader(cm, wsFirstFrame, wsFinalFrame, wsUncompressedFrame, controlMsg, len(payload))
|
||||
// Note that payload is optional.
|
||||
if len(payload) > 0 {
|
||||
copy(cm[2:], payload)
|
||||
}
|
||||
c.out.pb += int64(len(cm))
|
||||
if controlMsg == wsCloseMessage {
|
||||
// We can't add the close message to the frames buffers
|
||||
// now. It will be done on a flushOutbound() when there
|
||||
// are no more pending buffers to send.
|
||||
c.ws.closeSent = true
|
||||
c.ws.closeMsg = cm
|
||||
} else {
|
||||
c.ws.frames = append(c.ws.frames, cm)
|
||||
c.ws.fs += int64(len(cm))
|
||||
}
|
||||
c.flushSignal()
|
||||
}
|
||||
|
||||
// Enqueues a websocket close message with a status mapped from the given `reason`.
|
||||
//
|
||||
// Client lock held on entry
|
||||
func (c *client) wsEnqueueCloseMessage(reason ClosedState) {
|
||||
var status int
|
||||
switch reason {
|
||||
case ClientClosed:
|
||||
status = wsCloseStatusNormalClosure
|
||||
case AuthenticationTimeout, AuthenticationViolation, SlowConsumerPendingBytes, SlowConsumerWriteDeadline,
|
||||
MaxAccountConnectionsExceeded, MaxConnectionsExceeded, MaxControlLineExceeded, MaxSubscriptionsExceeded,
|
||||
MissingAccount, AuthenticationExpired, Revocation:
|
||||
status = wsCloseStatusPolicyViolation
|
||||
case TLSHandshakeError:
|
||||
status = wsCloseStatusTLSHandshake
|
||||
case ParseError, ProtocolViolation, BadClientProtocolVersion:
|
||||
status = wsCloseStatusProtocolError
|
||||
case MaxPayloadExceeded:
|
||||
status = wsCloseStatusMessageTooBig
|
||||
case ServerShutdown:
|
||||
status = wsCloseStatusGoingAway
|
||||
case WriteError, ReadError, StaleConnection:
|
||||
status = wsCloseStatusAbnormalClosure
|
||||
default:
|
||||
status = wsCloseStatusInternalSrvError
|
||||
}
|
||||
body := wsCreateCloseMessage(status, reason.String())
|
||||
c.wsEnqueueControlMessageLocked(wsCloseMessage, body)
|
||||
}
|
||||
|
||||
// Create and then enqueue a close message with a protocol error and the
|
||||
// given message. This is invoked when parsing websocket frames.
|
||||
//
|
||||
// Lock MUST NOT be held on entry.
|
||||
func (c *client) wsHandleProtocolError(message string) error {
|
||||
buf := wsCreateCloseMessage(wsCloseStatusProtocolError, message)
|
||||
c.wsEnqueueControlMessage(wsCloseMessage, buf)
|
||||
return fmt.Errorf(message)
|
||||
}
|
||||
|
||||
// Create a close message with the given `status` and `body`.
|
||||
// If the `body` is more than the maximum allows control frame payload size,
|
||||
// it is truncated and "..." is added at the end (as a hint that message
|
||||
// is not complete).
|
||||
func wsCreateCloseMessage(status int, body string) []byte {
|
||||
// Since a control message payload is limited in size, we
|
||||
// will limit the text and add trailing "..." if truncated.
|
||||
// The body of a Close Message must be preceded with 2 bytes,
|
||||
// so take that into account for limiting the body length.
|
||||
if len(body) > wsMaxControlPayloadSize-2 {
|
||||
body = body[:wsMaxControlPayloadSize-5]
|
||||
body += "..."
|
||||
}
|
||||
buf := make([]byte, 2+len(body))
|
||||
// We need to have a 2 byte unsigned int that represents the error status code
|
||||
// https://tools.ietf.org/html/rfc6455#section-5.5.1
|
||||
binary.BigEndian.PutUint16(buf[:2], uint16(status))
|
||||
copy(buf[2:], []byte(body))
|
||||
return buf
|
||||
}
|
||||
|
||||
// Process websocket client handshake. On success, returns the raw net.Conn that
|
||||
// will be used to create a *client object.
|
||||
// Invoked from the HTTP server listening on websocket port.
|
||||
func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeResult, error) {
|
||||
opts := s.getOpts()
|
||||
|
||||
// From https://tools.ietf.org/html/rfc6455#section-4.2.1
|
||||
// Point 1.
|
||||
if r.Method != "GET" {
|
||||
return nil, wsReturnHTTPError(w, http.StatusMethodNotAllowed, "request method must be GET")
|
||||
}
|
||||
// Point 2.
|
||||
if r.Host == "" {
|
||||
return nil, wsReturnHTTPError(w, http.StatusBadRequest, "'Host' missing in request")
|
||||
}
|
||||
// Point 3.
|
||||
if !wsHeaderContains(r.Header, "Upgrade", "websocket") {
|
||||
return nil, wsReturnHTTPError(w, http.StatusBadRequest, "invalid value for header 'Uprade'")
|
||||
}
|
||||
// Point 4.
|
||||
if !wsHeaderContains(r.Header, "Connection", "Upgrade") {
|
||||
return nil, wsReturnHTTPError(w, http.StatusBadRequest, "invalid value for header 'Connection'")
|
||||
}
|
||||
// Point 5.
|
||||
key := r.Header.Get("Sec-Websocket-Key")
|
||||
if key == "" {
|
||||
return nil, wsReturnHTTPError(w, http.StatusBadRequest, "key missing")
|
||||
}
|
||||
// Point 6.
|
||||
if !wsHeaderContains(r.Header, "Sec-Websocket-Version", "13") {
|
||||
return nil, wsReturnHTTPError(w, http.StatusBadRequest, "invalid version")
|
||||
}
|
||||
// Others are optional
|
||||
// Point 7.
|
||||
if err := s.websocket.checkOrigin(r); err != nil {
|
||||
return nil, wsReturnHTTPError(w, http.StatusForbidden, fmt.Sprintf("origin not allowed: %v", err))
|
||||
}
|
||||
// Point 8.
|
||||
// We don't have protocols, so ignore.
|
||||
// Point 9.
|
||||
// Extensions, only support for compression at the moment
|
||||
compress := opts.Websocket.Compression
|
||||
if compress {
|
||||
compress = wsClientSupportsCompression(r.Header)
|
||||
}
|
||||
|
||||
h := w.(http.Hijacker)
|
||||
conn, brw, err := h.Hijack()
|
||||
if err != nil {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
return nil, wsReturnHTTPError(w, http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
if brw.Reader.Buffered() > 0 {
|
||||
conn.Close()
|
||||
return nil, wsReturnHTTPError(w, http.StatusBadRequest, "client sent data before handshake is complete")
|
||||
}
|
||||
|
||||
var buf [1024]byte
|
||||
p := buf[:0]
|
||||
|
||||
// From https://tools.ietf.org/html/rfc6455#section-4.2.2
|
||||
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
|
||||
p = append(p, wsAcceptKey(key)...)
|
||||
p = append(p, _CRLF_...)
|
||||
if compress {
|
||||
p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
|
||||
}
|
||||
p = append(p, _CRLF_...)
|
||||
|
||||
if _, err = conn.Write(p); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
// If there was a deadline set for the handshake, clear it now.
|
||||
if opts.Websocket.HandshakeTimeout > 0 {
|
||||
conn.SetDeadline(time.Time{})
|
||||
}
|
||||
ws := &websocket{compress: compress}
|
||||
// Indicate if this is likely coming from a browser.
|
||||
if ua := r.Header.Get("User-Agent"); ua != "" && strings.HasPrefix(ua, "Mozilla/") {
|
||||
ws.browser = true
|
||||
}
|
||||
return &wsUpgradeResult{conn: conn, ws: ws}, nil
|
||||
}
|
||||
|
||||
// Returns true if the header named `name` contains a token with value `value`.
|
||||
func wsHeaderContains(header http.Header, name string, value string) bool {
|
||||
for _, s := range header[name] {
|
||||
tokens := strings.Split(s, ",")
|
||||
for _, t := range tokens {
|
||||
t = strings.Trim(t, " \t")
|
||||
if strings.EqualFold(t, value) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Return true if the client has "permessage-deflate" in its extensions.
|
||||
func wsClientSupportsCompression(header http.Header) bool {
|
||||
for _, extensionList := range header["Sec-Websocket-Extensions"] {
|
||||
extensions := strings.Split(extensionList, ",")
|
||||
for _, extension := range extensions {
|
||||
extension = strings.Trim(extension, " \t")
|
||||
params := strings.Split(extension, ";")
|
||||
for _, p := range params {
|
||||
p = strings.Trim(p, " \t")
|
||||
if strings.EqualFold(p, "permessage-deflate") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Send an HTTP error with the given `status`` to the given http response writer `w`.
|
||||
// Return an error created based on the `reason` string.
|
||||
func wsReturnHTTPError(w http.ResponseWriter, status int, reason string) error {
|
||||
err := fmt.Errorf("websocket handshake error: %s", reason)
|
||||
w.Header().Set("Sec-Websocket-Version", "13")
|
||||
http.Error(w, http.StatusText(status), status)
|
||||
return err
|
||||
}
|
||||
|
||||
// If the server is configured to accept any origin, then this function returns
|
||||
// `nil` without checking if the Origin is present and valid.
|
||||
// Otherwise, this will check that the Origin matches the same origine or
|
||||
// any origin in the allowed list.
|
||||
func (w *srvWebsocket) checkOrigin(r *http.Request) error {
|
||||
w.mu.RLock()
|
||||
checkSame := w.sameOrigin
|
||||
listEmpty := len(w.allowedOrigins) == 0
|
||||
w.mu.RUnlock()
|
||||
if !checkSame && listEmpty {
|
||||
return nil
|
||||
}
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
origin = r.Header.Get("Sec-Websocket-Origin")
|
||||
}
|
||||
if origin == "" {
|
||||
return errors.New("origin not provided")
|
||||
}
|
||||
u, err := url.ParseRequestURI(origin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
oh, op, err := wsGetHostAndPort(u.Scheme == "https", u.Host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// If checking same origin, compare with the http's request's Host.
|
||||
if checkSame {
|
||||
rh, rp, err := wsGetHostAndPort(r.TLS != nil, r.Host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if oh != rh || op != rp {
|
||||
return errors.New("not same origin")
|
||||
}
|
||||
// I guess it is possible to have cases where one wants to check
|
||||
// same origin, but also that the origin is in the allowed list.
|
||||
// So continue with the next check.
|
||||
}
|
||||
if !listEmpty {
|
||||
w.mu.RLock()
|
||||
ao := w.allowedOrigins[oh]
|
||||
w.mu.RUnlock()
|
||||
if ao == nil || u.Scheme != ao.scheme || op != ao.port {
|
||||
return errors.New("not in the allowed list")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func wsGetHostAndPort(tls bool, hostport string) (string, string, error) {
|
||||
host, port, err := net.SplitHostPort(hostport)
|
||||
if err != nil {
|
||||
// If error is missing port, then use defaults based on the scheme
|
||||
if ae, ok := err.(*net.AddrError); ok && strings.Contains(ae.Err, "missing port") {
|
||||
err = nil
|
||||
host = hostport
|
||||
if tls {
|
||||
port = "443"
|
||||
} else {
|
||||
port = "80"
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.ToLower(host), port, err
|
||||
}
|
||||
|
||||
// Concatenate the key sent by the client with the GUID, then computes the SHA1 hash
|
||||
// and returns it as a based64 encoded string.
|
||||
func wsAcceptKey(key string) string {
|
||||
h := sha1.New()
|
||||
h.Write([]byte(key))
|
||||
h.Write(wsGUID)
|
||||
return base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// Validate the websocket related options.
|
||||
func validateWebsocketOptions(o *Options) error {
|
||||
wo := &o.Websocket
|
||||
// If no port is defined, we don't care about other options
|
||||
if wo.Port == 0 {
|
||||
return nil
|
||||
}
|
||||
// Enforce TLS...
|
||||
if wo.TLSConfig == nil {
|
||||
return errors.New("websocket requires TLS configuration")
|
||||
}
|
||||
// Make sure that allowed origins, if specified, can be parsed.
|
||||
for _, ao := range wo.AllowedOrigins {
|
||||
if _, err := url.Parse(ao); err != nil {
|
||||
return fmt.Errorf("unable to parse allowed origin: %v", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Creates or updates the existing map
|
||||
func (s *Server) wsSetOriginOptions(o *WebsocketOpts) {
|
||||
ws := &s.websocket
|
||||
ws.mu.Lock()
|
||||
defer ws.mu.Unlock()
|
||||
// Copy over the option's same origin boolean
|
||||
ws.sameOrigin = o.SameOrigin
|
||||
// Reset the map. Will help for config reload if/when we support it.
|
||||
ws.allowedOrigins = nil
|
||||
if o.AllowedOrigins == nil {
|
||||
return
|
||||
}
|
||||
for _, ao := range o.AllowedOrigins {
|
||||
// We have previously checked (during options validation) that the urls
|
||||
// are parseable, but if we get an error, report and skip.
|
||||
u, err := url.ParseRequestURI(ao)
|
||||
if err != nil {
|
||||
s.Errorf("error parsing allowed origin: %v", err)
|
||||
continue
|
||||
}
|
||||
h, p, _ := wsGetHostAndPort(u.Scheme == "https", u.Host)
|
||||
if ws.allowedOrigins == nil {
|
||||
ws.allowedOrigins = make(map[string]*allowedOrigin, len(o.AllowedOrigins))
|
||||
}
|
||||
ws.allowedOrigins[h] = &allowedOrigin{scheme: u.Scheme, port: p}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) startWebsocketServer() {
|
||||
sopts := s.getOpts()
|
||||
o := &sopts.Websocket
|
||||
|
||||
s.wsSetOriginOptions(o)
|
||||
|
||||
var hl net.Listener
|
||||
var proto string
|
||||
var err error
|
||||
|
||||
port := o.Port
|
||||
if port == -1 {
|
||||
port = 0
|
||||
}
|
||||
hp := net.JoinHostPort(o.Host, strconv.Itoa(port))
|
||||
|
||||
// We are enforcing (when validating the options) the use of TLS, but the
|
||||
// code was originally supporting both modes. The reason for TLS only is
|
||||
// that we expect users to send JWTs with bearer tokens and we want to
|
||||
// avoid the possibility of it being "intercepted".
|
||||
|
||||
if o.TLSConfig != nil {
|
||||
proto = "wss"
|
||||
config := o.TLSConfig.Clone()
|
||||
config.ClientAuth = tls.NoClientCert
|
||||
hl, err = tls.Listen("tcp", hp, config)
|
||||
} else {
|
||||
proto = "ws"
|
||||
hl, err = net.Listen("tcp", hp)
|
||||
}
|
||||
if err != nil {
|
||||
s.Fatalf("Unable to listen for websocket connections: %v", err)
|
||||
return
|
||||
}
|
||||
s.Noticef("Listening for websocket clients on %s://%s:%d", proto, o.Host, port)
|
||||
|
||||
s.mu.Lock()
|
||||
s.websocket.tls = proto == "wss"
|
||||
if port == 0 {
|
||||
s.opts.Websocket.Port = hl.Addr().(*net.TCPAddr).Port
|
||||
}
|
||||
s.websocket.connectURLs, err = s.getConnectURLs(o.Advertise, o.Host, o.Port)
|
||||
if err != nil {
|
||||
s.Fatalf("Unable to get websocket connect URLs: %v", err)
|
||||
hl.Close()
|
||||
s.mu.Unlock()
|
||||
return
|
||||
}
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
res, err := s.wsUpgrade(w, r)
|
||||
if err != nil {
|
||||
s.Errorf(err.Error())
|
||||
return
|
||||
}
|
||||
s.createClient(res.conn, res.ws)
|
||||
})
|
||||
hs := &http.Server{
|
||||
Addr: hp,
|
||||
Handler: mux,
|
||||
ReadTimeout: o.HandshakeTimeout,
|
||||
ErrorLog: log.New(&wsCaptureHTTPServerLog{s}, "", 0),
|
||||
}
|
||||
s.websocket.server = hs
|
||||
s.websocket.listener = hl
|
||||
s.mu.Unlock()
|
||||
|
||||
s.startGoRoutine(func() {
|
||||
defer s.grWG.Done()
|
||||
|
||||
if err := hs.Serve(hl); err != http.ErrServerClosed {
|
||||
s.Fatalf("websocket listener error: %v", err)
|
||||
}
|
||||
s.done <- true
|
||||
})
|
||||
}
|
||||
|
||||
type wsCaptureHTTPServerLog struct {
|
||||
s *Server
|
||||
}
|
||||
|
||||
func (cl *wsCaptureHTTPServerLog) Write(p []byte) (int, error) {
|
||||
var buf [128]byte
|
||||
var b = buf[:0]
|
||||
|
||||
copy(b, []byte("websocket :"))
|
||||
offset := 0
|
||||
if bytes.HasPrefix(p, []byte("http:")) {
|
||||
offset = 6
|
||||
}
|
||||
b = append(b, p[offset:]...)
|
||||
cl.s.Errorf(string(b))
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (c *client) wsCollapsePtoNB() (net.Buffers, int64) {
|
||||
var nb net.Buffers
|
||||
var total = 0
|
||||
var mfs = 0
|
||||
if c.ws.browser {
|
||||
mfs = wsFrameSizeForBrowsers
|
||||
}
|
||||
if len(c.out.p) > 0 {
|
||||
p := c.out.p
|
||||
c.out.p = nil
|
||||
nb = append(c.out.nb, p)
|
||||
} else if len(c.out.nb) > 0 {
|
||||
nb = c.out.nb
|
||||
}
|
||||
// Start with possible already framed buffers (that we could have
|
||||
// got from partials or control messages such as ws pings or pongs).
|
||||
bufs := c.ws.frames
|
||||
if c.ws.compress && len(nb) > 0 {
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
cp := c.ws.compressor
|
||||
if cp == nil {
|
||||
c.ws.compressor, _ = flate.NewWriter(buf, flate.BestSpeed)
|
||||
cp = c.ws.compressor
|
||||
} else {
|
||||
cp.Reset(buf)
|
||||
}
|
||||
var usz int
|
||||
var csz int
|
||||
for _, b := range nb {
|
||||
usz += len(b)
|
||||
cp.Write(b)
|
||||
}
|
||||
cp.Close()
|
||||
b := buf.Bytes()
|
||||
p := b[:len(b)-4]
|
||||
if mfs > 0 && len(p) > mfs {
|
||||
for first, final := true, false; len(p) > 0; first = false {
|
||||
lp := len(p)
|
||||
if lp > mfs {
|
||||
lp = mfs
|
||||
} else {
|
||||
final = true
|
||||
}
|
||||
fh := make([]byte, wsMaxFrameHeaderSize)
|
||||
n := wsFillFrameHeader(fh, first, final, wsCompressedFrame, wsBinaryMessage, lp)
|
||||
bufs = append(bufs, fh[:n], p[:lp])
|
||||
csz += n + lp
|
||||
p = p[lp:]
|
||||
}
|
||||
} else {
|
||||
h := wsCreateFrameHeader(true, wsBinaryMessage, len(p))
|
||||
bufs = append(bufs, h, p)
|
||||
csz = len(h) + len(p)
|
||||
}
|
||||
// Add to pb the compressed data size (including headers), but
|
||||
// remove the original uncompressed data size that was added
|
||||
// during the queueing.
|
||||
c.out.pb += int64(csz) - int64(usz)
|
||||
c.ws.fs += int64(csz)
|
||||
} else if len(nb) > 0 {
|
||||
if mfs > 0 {
|
||||
// We are limiting the frame size.
|
||||
startFrame := func() int {
|
||||
bufs = append(bufs, make([]byte, wsMaxFrameHeaderSize))
|
||||
return len(bufs) - 1
|
||||
}
|
||||
endFrame := func(idx, size int) {
|
||||
n := wsFillFrameHeader(bufs[idx], wsFirstFrame, wsFinalFrame, wsUncompressedFrame, wsBinaryMessage, size)
|
||||
c.out.pb += int64(n)
|
||||
c.ws.fs += int64(n + size)
|
||||
bufs[idx] = bufs[idx][:n]
|
||||
}
|
||||
|
||||
fhIdx := startFrame()
|
||||
for i := 0; i < len(nb); i++ {
|
||||
b := nb[i]
|
||||
if total+len(b) <= mfs {
|
||||
bufs = append(bufs, b)
|
||||
total += len(b)
|
||||
continue
|
||||
}
|
||||
for len(b) > 0 {
|
||||
endFrame(fhIdx, total)
|
||||
total = len(b)
|
||||
if total >= mfs {
|
||||
total = mfs
|
||||
}
|
||||
fhIdx = startFrame()
|
||||
bufs = append(bufs, b[:total])
|
||||
b = b[total:]
|
||||
}
|
||||
}
|
||||
if total > 0 {
|
||||
endFrame(fhIdx, total)
|
||||
}
|
||||
} else {
|
||||
// If there is no limit on the frame size, create a single frame for
|
||||
// all pending buffers.
|
||||
for _, b := range nb {
|
||||
total += len(b)
|
||||
}
|
||||
wsfh := wsCreateFrameHeader(false, wsBinaryMessage, total)
|
||||
c.out.pb += int64(len(wsfh))
|
||||
bufs = append(bufs, wsfh)
|
||||
bufs = append(bufs, nb...)
|
||||
c.ws.fs += int64(len(wsfh) + total)
|
||||
}
|
||||
}
|
||||
if len(c.ws.closeMsg) > 0 {
|
||||
bufs = append(bufs, c.ws.closeMsg)
|
||||
c.ws.fs += int64(len(c.ws.closeMsg))
|
||||
c.ws.closeMsg = nil
|
||||
}
|
||||
c.ws.frames = nil
|
||||
return bufs, c.ws.fs
|
||||
}
|
||||
2779
server/websocket_test.go
Normal file
2779
server/websocket_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -54,6 +54,12 @@ func TestPortsFile(t *testing.T) {
|
||||
opts.HTTPPort = -1
|
||||
opts.ProfPort = -1
|
||||
opts.Cluster.Port = -1
|
||||
opts.Websocket.Port = -1
|
||||
tc := &server.TLSConfigOpts{
|
||||
CertFile: "./configs/certs/server-cert.pem",
|
||||
KeyFile: "./configs/certs/server-key.pem",
|
||||
}
|
||||
opts.Websocket.TLSConfig, _ = server.GenTLSConfig(tc)
|
||||
|
||||
s := RunServer(&opts)
|
||||
// this for test cleanup in case we fail - will be ignored if server already shutdown
|
||||
@@ -100,6 +106,10 @@ func TestPortsFile(t *testing.T) {
|
||||
t.Fatal("Expected at least one profile listen url")
|
||||
}
|
||||
|
||||
if len(readPorts.WebSocket) == 0 || !strings.HasPrefix(readPorts.WebSocket[0], "wss://") {
|
||||
t.Fatal("Expected at least one ws listen url")
|
||||
}
|
||||
|
||||
// testing cleanup
|
||||
s.Shutdown()
|
||||
// if we called shutdown, the cleanup code should have kicked
|
||||
|
||||
Reference in New Issue
Block a user