Files
nats-server/server/websocket_test.go
Ivan Kozlovic d2a8282a0d [FIXED] LeafNode TLSMap and websocket auth override
We added authentication override block for websocket configuration
in PR #1463 and #1465 which somehow introduced a drop in perf as
reported by the bench tests.
This PR refactors a bit to restore the performance numbers.

This change also fixes the override behavior for websocket auth:
- If websocket's NoAuthUser is configured, the websocket's auth
  block MUST define Users, and the user be present.
- If there is any override (username/pwd,token,etc..) then the
  whole block config will be used when authenticating a websocket
  client, which means that if websocket NoAuthUser is empty we
  are not falling back to the regular client's NoAuthUser config.
- TLSMap always override the regular client's config. That is,
  whatever TLSMap value specified in the websocket's tls{} block
  will be used.

The TLSMap configuration was not used for LeafNodes. The behavior
now will be:
- If LeafNode's auth block contains users and TLSMap is true,
  the user is looked up based on the cert's info. If not found,
  authentication will fail. If found, it will be authenticated
  and bound to associated account.
- If no user is specified in LeafNode's auth block and TLSMap
  is true, then the cert's info will be used against the global
  users map.

Signed-off-by: Ivan Kozlovic <ivan@synadia.com>
2020-06-11 17:06:54 -06:00

3548 lines
98 KiB
Go

// Copyright 2020 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"bufio"
"bytes"
"compress/flate"
"crypto/tls"
"encoding/base64"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net"
"net/http"
"net/url"
"os"
"reflect"
"strings"
"sync"
"testing"
"time"
"github.com/nats-io/nkeys"
)
type testReader struct {
buf []byte
pos int
max int
err error
}
func (tr *testReader) Read(p []byte) (int, error) {
if tr.err != nil {
return 0, tr.err
}
n := len(tr.buf) - tr.pos
if n == 0 {
return 0, nil
}
if n > cap(p) {
n = cap(p)
}
if tr.max > 0 && n > tr.max {
n = tr.max
}
copy(p, tr.buf[tr.pos:tr.pos+n])
tr.pos += n
return n, nil
}
func TestWSGet(t *testing.T) {
rb := []byte("012345")
tr := &testReader{buf: []byte("6789")}
for _, test := range []struct {
name string
pos int
needed int
newpos int
trmax int
result string
reterr bool
}{
{"fromrb1", 0, 3, 3, 4, "012", false}, // Partial from read buffer
{"fromrb2", 3, 2, 5, 4, "34", false}, // Partial from read buffer
{"fromrb3", 5, 1, 6, 4, "5", false}, // Partial from read buffer
{"fromtr1", 4, 4, 6, 4, "4567", false}, // Partial from read buffer + some of ioReader
{"fromtr2", 4, 6, 6, 4, "456789", false}, // Partial from read buffer + all of ioReader
{"fromtr3", 4, 6, 6, 2, "456789", false}, // Partial from read buffer + all of ioReader with several reads
{"fromtr4", 4, 6, 6, 2, "", true}, // ioReader returns error
} {
t.Run(test.name, func(t *testing.T) {
tr.pos = 0
tr.max = test.trmax
if test.reterr {
tr.err = fmt.Errorf("on purpose")
}
res, np, err := wsGet(tr, rb, test.pos, test.needed)
if test.reterr {
if err == nil {
t.Fatalf("Expected error, got none")
}
if err.Error() != "on purpose" {
t.Fatalf("Unexpected error: %v", err)
}
if np != 0 || res != nil {
t.Fatalf("Unexpected returned values: res=%v n=%v", res, np)
}
return
}
if err != nil {
t.Fatalf("Error on get: %v", err)
}
if np != test.newpos {
t.Fatalf("Expected pos=%v, got %v", test.newpos, np)
}
if string(res) != test.result {
t.Fatalf("Invalid returned content: %s", res)
}
})
}
}
func TestWSIsControlFrame(t *testing.T) {
for _, test := range []struct {
name string
code wsOpCode
isControl bool
}{
{"binary", wsBinaryMessage, false},
{"text", wsTextMessage, false},
{"ping", wsPingMessage, true},
{"pong", wsPongMessage, true},
{"close", wsCloseMessage, true},
} {
t.Run(test.name, func(t *testing.T) {
if res := wsIsControlFrame(test.code); res != test.isControl {
t.Fatalf("Expected %q isControl to be %v, got %v", test.name, test.isControl, res)
}
})
}
}
func testWSSimpleMask(key, buf []byte) {
for i := 0; i < len(buf); i++ {
buf[i] ^= key[i&3]
}
}
func TestWSUnmask(t *testing.T) {
key := []byte{1, 2, 3, 4}
orgBuf := []byte("this is a clear text")
mask := func() []byte {
t.Helper()
buf := append([]byte(nil), orgBuf...)
testWSSimpleMask(key, buf)
// First ensure that the content is masked.
if bytes.Equal(buf, orgBuf) {
t.Fatalf("Masking did not do anything: %q", buf)
}
return buf
}
ri := &wsReadInfo{}
ri.init()
copy(ri.mkey[:], key)
buf := mask()
// Unmask in one call
ri.unmask(buf)
if !bytes.Equal(buf, orgBuf) {
t.Fatalf("Unmask error, expected %q, got %q", orgBuf, buf)
}
// Unmask in multiple calls
buf = mask()
ri.mkpos = 0
ri.unmask(buf[:3])
ri.unmask(buf[3:11])
ri.unmask(buf[11:])
if !bytes.Equal(buf, orgBuf) {
t.Fatalf("Unmask error, expected %q, got %q", orgBuf, buf)
}
}
func TestWSCreateCloseMessage(t *testing.T) {
for _, test := range []struct {
name string
status int
psize int
truncated bool
}{
{"fits", wsCloseStatusInternalSrvError, 10, false},
{"truncated", wsCloseStatusProtocolError, wsMaxControlPayloadSize + 10, true},
} {
t.Run(test.name, func(t *testing.T) {
payload := make([]byte, test.psize)
for i := 0; i < len(payload); i++ {
payload[i] = byte('A' + (i % 26))
}
res := wsCreateCloseMessage(test.status, string(payload))
if status := binary.BigEndian.Uint16(res[:2]); int(status) != test.status {
t.Fatalf("Expected status to be %v, got %v", test.status, status)
}
psize := len(res) - 2
if !test.truncated {
if int(psize) != test.psize {
t.Fatalf("Expected size to be %v, got %v", test.psize, psize)
}
if !bytes.Equal(res[2:], payload) {
t.Fatalf("Unexpected result: %q", res[2:])
}
return
}
// Since the payload of a close message contains a 2 byte status, the
// actual max text size will be wsMaxControlPayloadSize-2
if int(psize) != wsMaxControlPayloadSize-2 {
t.Fatalf("Expected size to be capped to %v, got %v", wsMaxControlPayloadSize-2, psize)
}
if string(res[len(res)-3:]) != "..." {
t.Fatalf("Expected res to have `...` at the end, got %q", res[4:])
}
})
}
}
func TestWSCreateFrameHeader(t *testing.T) {
for _, test := range []struct {
name string
frameType wsOpCode
compressed bool
len int
}{
{"uncompressed 10", wsBinaryMessage, false, 10},
{"uncompressed 600", wsTextMessage, false, 600},
{"uncompressed 100000", wsTextMessage, false, 100000},
{"compressed 10", wsBinaryMessage, true, 10},
{"compressed 600", wsBinaryMessage, true, 600},
{"compressed 100000", wsTextMessage, true, 100000},
} {
t.Run(test.name, func(t *testing.T) {
res := wsCreateFrameHeader(test.compressed, test.frameType, test.len)
// The server is always sending the message has a single frame,
// so the "final" bit should be set.
expected := byte(test.frameType) | wsFinalBit
if test.compressed {
expected |= wsRsv1Bit
}
if b := res[0]; b != expected {
t.Fatalf("Expected first byte to be %v, got %v", expected, b)
}
switch {
case test.len <= 125:
if len(res) != 2 {
t.Fatalf("Frame len should be 2, got %v", len(res))
}
if res[1] != byte(test.len) {
t.Fatalf("Expected len to be in second byte and be %v, got %v", test.len, res[1])
}
case test.len < 65536:
// 1+1+2
if len(res) != 4 {
t.Fatalf("Frame len should be 4, got %v", len(res))
}
if res[1] != 126 {
t.Fatalf("Second byte value should be 126, got %v", res[1])
}
if rl := binary.BigEndian.Uint16(res[2:]); int(rl) != test.len {
t.Fatalf("Expected len to be %v, got %v", test.len, rl)
}
default:
// 1+1+8
if len(res) != 10 {
t.Fatalf("Frame len should be 10, got %v", len(res))
}
if res[1] != 127 {
t.Fatalf("Second byte value should be 127, got %v", res[1])
}
if rl := binary.BigEndian.Uint64(res[2:]); int(rl) != test.len {
t.Fatalf("Expected len to be %v, got %v", test.len, rl)
}
}
})
}
}
func testWSCreateClientMsg(frameType wsOpCode, frameNum int, final, compressed bool, payload []byte) []byte {
if compressed {
buf := &bytes.Buffer{}
compressor, _ := flate.NewWriter(buf, 1)
compressor.Write(payload)
compressor.Flush()
payload = buf.Bytes()
// The last 4 bytes are dropped
payload = payload[:len(payload)-4]
}
frame := make([]byte, 14+len(payload))
if frameNum == 1 {
frame[0] = byte(frameType)
}
if final {
frame[0] |= wsFinalBit
}
if compressed {
frame[0] |= wsRsv1Bit
}
pos := 1
lenPayload := len(payload)
switch {
case lenPayload <= 125:
frame[pos] = byte(lenPayload) | wsMaskBit
pos++
case lenPayload < 65536:
frame[pos] = 126 | wsMaskBit
binary.BigEndian.PutUint16(frame[2:], uint16(lenPayload))
pos += 3
default:
frame[1] = 127 | wsMaskBit
binary.BigEndian.PutUint64(frame[2:], uint64(lenPayload))
pos += 9
}
key := []byte{1, 2, 3, 4}
copy(frame[pos:], key)
pos += 4
copy(frame[pos:], payload)
testWSSimpleMask(key, frame[pos:])
pos += lenPayload
return frame[:pos]
}
func testWSSetupForRead() (*client, *wsReadInfo, *testReader) {
ri := &wsReadInfo{}
ri.init()
tr := &testReader{}
opts := DefaultOptions()
opts.MaxPending = MAX_PENDING_SIZE
s := &Server{opts: opts}
c := &client{srv: s, ws: &websocket{}}
c.initClient()
return c, ri, tr
}
func TestWSReadUncompressedFrames(t *testing.T) {
c, ri, tr := testWSSetupForRead()
// Create 2 WS messages
pl1 := []byte("first message")
wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, pl1)
pl2 := []byte("second message")
wsmsg2 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, pl2)
// Add both in single buffer
orgrb := append([]byte(nil), wsmsg1...)
orgrb = append(orgrb, wsmsg2...)
rb := append([]byte(nil), orgrb...)
bufs, err := c.wsRead(ri, tr, rb)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if n := len(bufs); n != 2 {
t.Fatalf("Expected 2 buffers, got %v", n)
}
if !bytes.Equal(bufs[0], pl1) {
t.Fatalf("Unexpected content for buffer 1: %s", bufs[0])
}
if !bytes.Equal(bufs[1], pl2) {
t.Fatalf("Unexpected content for buffer 2: %s", bufs[1])
}
// Now reset and try with the read buffer not containing full ws frame
c, ri, tr = testWSSetupForRead()
rb = append([]byte(nil), orgrb...)
// Frame is 1+1+4+'first message'. So say we pass with rb of 11 bytes,
// then we should get "first"
bufs, err = c.wsRead(ri, tr, rb[:11])
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if n := len(bufs); n != 1 {
t.Fatalf("Unexpected buffer returned: %v", n)
}
if string(bufs[0]) != "first" {
t.Fatalf("Unexpected content: %q", bufs[0])
}
// Call again with more data..
bufs, err = c.wsRead(ri, tr, rb[11:32])
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if n := len(bufs); n != 2 {
t.Fatalf("Unexpected buffer returned: %v", n)
}
if string(bufs[0]) != " message" {
t.Fatalf("Unexpected content: %q", bufs[0])
}
if string(bufs[1]) != "second " {
t.Fatalf("Unexpected content: %q", bufs[1])
}
// Call with the rest
bufs, err = c.wsRead(ri, tr, rb[32:])
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if n := len(bufs); n != 1 {
t.Fatalf("Unexpected buffer returned: %v", n)
}
if string(bufs[0]) != "message" {
t.Fatalf("Unexpected content: %q", bufs[0])
}
}
func TestWSReadCompressedFrames(t *testing.T) {
c, ri, tr := testWSSetupForRead()
uncompressed := []byte("this is the uncompress data")
wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, uncompressed)
rb := append([]byte(nil), wsmsg1...)
// Call with some but not all of the payload
bufs, err := c.wsRead(ri, tr, rb[:10])
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if n := len(bufs); n != 0 {
t.Fatalf("Unexpected buffer returned: %v", n)
}
// Call with the rest, only then should we get the uncompressed data.
bufs, err = c.wsRead(ri, tr, rb[10:])
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if n := len(bufs); n != 1 {
t.Fatalf("Unexpected buffer returned: %v", n)
}
if !bytes.Equal(bufs[0], uncompressed) {
t.Fatalf("Unexpected content: %s", bufs[0])
}
// Stress the fact that we use a pool and want to make sure
// that if we get a decompressor from the pool, it is properly reset
// with the buffer to decompress.
for i := 0; i < 9; i++ {
rb = append(rb, wsmsg1...)
}
bufs, err = c.wsRead(ri, tr, rb)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if n := len(bufs); n != 10 {
t.Fatalf("Unexpected buffer returned: %v", n)
}
}
func TestWSReadCompressedFrameCorrupted(t *testing.T) {
c, ri, tr := testWSSetupForRead()
uncompressed := []byte("this is the uncompress data")
wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, uncompressed)
copy(wsmsg1[10:], []byte{1, 2, 3, 4})
rb := append([]byte(nil), wsmsg1...)
bufs, err := c.wsRead(ri, tr, rb)
if err == nil || !strings.Contains(err.Error(), "corrupt") {
t.Fatalf("Expected error about corrupted data, got %v", err)
}
if n := len(bufs); n != 0 {
t.Fatalf("Expected no buffer, got %v", n)
}
}
func TestWSReadVariousFrameSizes(t *testing.T) {
for _, test := range []struct {
name string
size int
}{
{"tiny", 100},
{"medium", 1000},
{"large", 70000},
} {
t.Run(test.name, func(t *testing.T) {
c, ri, tr := testWSSetupForRead()
uncompressed := make([]byte, test.size)
for i := 0; i < len(uncompressed); i++ {
uncompressed[i] = 'A' + byte(i%26)
}
wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, uncompressed)
rb := append([]byte(nil), wsmsg1...)
bufs, err := c.wsRead(ri, tr, rb)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if n := len(bufs); n != 1 {
t.Fatalf("Unexpected buffer returned: %v", n)
}
if !bytes.Equal(bufs[0], uncompressed) {
t.Fatalf("Unexpected content: %s", bufs[0])
}
})
}
}
func TestWSReadFragmentedFrames(t *testing.T) {
c, ri, tr := testWSSetupForRead()
payloads := []string{"first", "second", "third"}
var rb []byte
for i := 0; i < len(payloads); i++ {
final := i == len(payloads)-1
frag := testWSCreateClientMsg(wsBinaryMessage, i+1, final, false, []byte(payloads[i]))
rb = append(rb, frag...)
}
bufs, err := c.wsRead(ri, tr, rb)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if n := len(bufs); n != 3 {
t.Fatalf("Unexpected buffer returned: %v", n)
}
for i, expected := range payloads {
if string(bufs[i]) != expected {
t.Fatalf("Unexpected content for buf=%v: %s", i, bufs[i])
}
}
}
func TestWSReadPartialFrameHeaderAtEndOfReadBuffer(t *testing.T) {
c, ri, tr := testWSSetupForRead()
msg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg1"))
msg2 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg2"))
rb := append([]byte(nil), msg1...)
rb = append(rb, msg2...)
// We will pass the first frame + the first byte of the next frame.
rbl := rb[:len(msg1)+1]
// Make the io reader return the rest of the frame
tr.buf = rb[len(msg1)+1:]
bufs, err := c.wsRead(ri, tr, rbl)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if n := len(bufs); n != 1 {
t.Fatalf("Unexpected buffer returned: %v", n)
}
// We should not have asked to the io reader more than what is needed for reading
// the frame header. Since we had already the first byte in the read buffer,
// tr.pos should be 1(size)+4(key)=5
if tr.pos != 5 {
t.Fatalf("Expected reader pos to be 5, got %v", tr.pos)
}
}
func TestWSReadPingFrame(t *testing.T) {
for _, test := range []struct {
name string
payload []byte
}{
{"without payload", nil},
{"with payload", []byte("optional payload")},
} {
t.Run(test.name, func(t *testing.T) {
c, ri, tr := testWSSetupForRead()
ping := testWSCreateClientMsg(wsPingMessage, 1, true, false, test.payload)
rb := append([]byte(nil), ping...)
bufs, err := c.wsRead(ri, tr, rb)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if n := len(bufs); n != 0 {
t.Fatalf("Unexpected buffer returned: %v", n)
}
// A PONG should have been queued with the payload of the ping
c.mu.Lock()
nb, _ := c.collapsePtoNB()
c.mu.Unlock()
if n := len(nb); n == 0 {
t.Fatalf("Expected buffers, got %v", n)
}
if expected := 2 + len(test.payload); expected != len(nb[0]) {
t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0]))
}
b := nb[0][0]
if b&wsFinalBit == 0 {
t.Fatalf("Control frame should have been the final flag, it was not set: %v", b)
}
if b&byte(wsPongMessage) == 0 {
t.Fatalf("Should have been a PONG, it wasn't: %v", b)
}
if len(test.payload) > 0 {
if !bytes.Equal(nb[0][2:], test.payload) {
t.Fatalf("Unexpected content: %s", nb[0][2:])
}
}
})
}
}
func TestWSReadPongFrame(t *testing.T) {
for _, test := range []struct {
name string
payload []byte
}{
{"without payload", nil},
{"with payload", []byte("optional payload")},
} {
t.Run(test.name, func(t *testing.T) {
c, ri, tr := testWSSetupForRead()
pong := testWSCreateClientMsg(wsPongMessage, 1, true, false, test.payload)
rb := append([]byte(nil), pong...)
bufs, err := c.wsRead(ri, tr, rb)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if n := len(bufs); n != 0 {
t.Fatalf("Unexpected buffer returned: %v", n)
}
// Nothing should be sent...
c.mu.Lock()
nb, _ := c.collapsePtoNB()
c.mu.Unlock()
if n := len(nb); n != 0 {
t.Fatalf("Expected no buffer, got %v", n)
}
})
}
}
func TestWSReadCloseFrame(t *testing.T) {
for _, test := range []struct {
name string
payload []byte
}{
{"without payload", nil},
{"with payload", []byte("optional payload")},
} {
t.Run(test.name, func(t *testing.T) {
c, ri, tr := testWSSetupForRead()
// a close message has a status in 2 bytes + optional payload
payload := make([]byte, 2+len(test.payload))
binary.BigEndian.PutUint16(payload[:2], wsCloseStatusNormalClosure)
if len(test.payload) > 0 {
copy(payload[2:], test.payload)
}
close := testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload)
// Have a normal frame prior to close to make sure that wsRead returns
// the normal frame along with io.EOF to indicate that wsCloseMessage was received.
msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg"))
rb := append([]byte(nil), msg...)
rb = append(rb, close...)
bufs, err := c.wsRead(ri, tr, rb)
// It is expected that wsRead returns io.EOF on processing a close.
if err != io.EOF {
t.Fatalf("Unexpected error: %v", err)
}
if n := len(bufs); n != 1 {
t.Fatalf("Unexpected buffer returned: %v", n)
}
if string(bufs[0]) != "msg" {
t.Fatalf("Unexpected content: %s", bufs[0])
}
// A CLOSE should have been queued with the payload of the original close message.
c.mu.Lock()
nb, _ := c.collapsePtoNB()
c.mu.Unlock()
if n := len(nb); n == 0 {
t.Fatalf("Expected buffers, got %v", n)
}
if expected := 2 + 2 + len(test.payload); expected != len(nb[0]) {
t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0]))
}
b := nb[0][0]
if b&wsFinalBit == 0 {
t.Fatalf("Control frame should have been the final flag, it was not set: %v", b)
}
if b&byte(wsCloseMessage) == 0 {
t.Fatalf("Should have been a CLOSE, it wasn't: %v", b)
}
if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusNormalClosure {
t.Fatalf("Expected status to be %v, got %v", wsCloseStatusNormalClosure, status)
}
if len(test.payload) > 0 {
if !bytes.Equal(nb[0][4:], test.payload) {
t.Fatalf("Unexpected content: %s", nb[0][4:])
}
}
})
}
}
func TestWSReadControlFrameBetweebFragmentedFrames(t *testing.T) {
c, ri, tr := testWSSetupForRead()
frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, false, []byte("first"))
frag2 := testWSCreateClientMsg(wsBinaryMessage, 2, true, false, []byte("second"))
ctrl := testWSCreateClientMsg(wsPongMessage, 1, true, false, nil)
rb := append([]byte(nil), frag1...)
rb = append(rb, ctrl...)
rb = append(rb, frag2...)
bufs, err := c.wsRead(ri, tr, rb)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if n := len(bufs); n != 2 {
t.Fatalf("Unexpected buffer returned: %v", n)
}
if string(bufs[0]) != "first" {
t.Fatalf("Unexpected content: %s", bufs[0])
}
if string(bufs[1]) != "second" {
t.Fatalf("Unexpected content: %s", bufs[1])
}
}
func TestWSReadGetErrors(t *testing.T) {
tr := &testReader{err: fmt.Errorf("on purpose")}
for _, test := range []struct {
lenPayload int
rbextra int
}{
{10, 1},
{10, 3},
{200, 1},
{200, 2},
{200, 5},
{70000, 1},
{70000, 5},
{70000, 13},
} {
t.Run("", func(t *testing.T) {
c, ri, _ := testWSSetupForRead()
msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg"))
frame := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, make([]byte, test.lenPayload))
rb := append([]byte(nil), msg...)
rb = append(rb, frame...)
bufs, err := c.wsRead(ri, tr, rb[:len(msg)+test.rbextra])
if err == nil || err.Error() != "on purpose" {
t.Fatalf("Expected 'on purpose' error, got %v", err)
}
if n := len(bufs); n != 1 {
t.Fatalf("Unexpected buffer returned: %v", n)
}
if string(bufs[0]) != "msg" {
t.Fatalf("Unexpected content: %s", bufs[0])
}
})
}
}
func TestWSHandleControlFrameErrors(t *testing.T) {
c, ri, tr := testWSSetupForRead()
tr.err = fmt.Errorf("on purpose")
// a close message has a status in 2 bytes + optional payload
text := []byte("this is a close message")
payload := make([]byte, 2+len(text))
binary.BigEndian.PutUint16(payload[:2], wsCloseStatusNormalClosure)
copy(payload[2:], text)
ctrl := testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload)
bufs, err := c.wsRead(ri, tr, ctrl[:len(ctrl)-4])
if err == nil || err.Error() != "on purpose" {
t.Fatalf("Expected 'on purpose' error, got %v", err)
}
if n := len(bufs); n != 0 {
t.Fatalf("Unexpected buffer returned: %v", n)
}
// Alter the content of close message. It is supposed to be valid utf-8.
c, ri, tr = testWSSetupForRead()
cp := append([]byte(nil), payload...)
cp[10] = 0xF1
ctrl = testWSCreateClientMsg(wsCloseMessage, 1, true, false, cp)
bufs, err = c.wsRead(ri, tr, ctrl)
// We should still receive an EOF but the message enqueued to the client
// should contain wsCloseStatusInvalidPayloadData and the error about invalid utf8
if err != io.EOF {
t.Fatalf("Unexpected error: %v", err)
}
if n := len(bufs); n != 0 {
t.Fatalf("Unexpected buffer returned: %v", n)
}
c.mu.Lock()
nb, _ := c.collapsePtoNB()
c.mu.Unlock()
if n := len(nb); n == 0 {
t.Fatalf("Expected buffers, got %v", n)
}
b := nb[0][0]
if b&wsFinalBit == 0 {
t.Fatalf("Control frame should have been the final flag, it was not set: %v", b)
}
if b&byte(wsCloseMessage) == 0 {
t.Fatalf("Should have been a CLOSE, it wasn't: %v", b)
}
if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusInvalidPayloadData {
t.Fatalf("Expected status to be %v, got %v", wsCloseStatusInvalidPayloadData, status)
}
if !bytes.Contains(nb[0][4:], []byte("utf8")) {
t.Fatalf("Unexpected content: %s", nb[0][4:])
}
}
func TestWSReadErrors(t *testing.T) {
for _, test := range []struct {
cframe func() []byte
err string
nbufs int
}{
{
func() []byte {
msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("hello"))
msg[1] &= ^byte(wsMaskBit)
return msg
},
"mask bit missing", 1,
},
{
func() []byte {
return testWSCreateClientMsg(wsPingMessage, 1, true, false, make([]byte, 200))
},
"control frame length bigger than maximum allowed", 1,
},
{
func() []byte {
return testWSCreateClientMsg(wsPingMessage, 1, false, false, []byte("hello"))
},
"control frame does not have final bit set", 1,
},
{
func() []byte {
frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, false, []byte("frag1"))
newMsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("new message"))
all := append([]byte(nil), frag1...)
all = append(all, newMsg...)
return all
},
"new message started before final frame for previous message was received", 2,
},
{
func() []byte {
frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, true, []byte("frag1"))
frag2 := testWSCreateClientMsg(wsBinaryMessage, 2, false, true, []byte("frag2"))
frag2[0] |= wsRsv1Bit
all := append([]byte(nil), frag1...)
all = append(all, frag2...)
return all
},
"invalid continuation frame", 2,
},
{
func() []byte {
return testWSCreateClientMsg(99, 1, false, false, []byte("hello"))
},
"unknown opcode", 1,
},
} {
t.Run(test.err, func(t *testing.T) {
c, ri, tr := testWSSetupForRead()
// Add a valid message first
msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("hello"))
// Then add the bad frame
bad := test.cframe()
// Add them both to a read buffer
rb := append([]byte(nil), msg...)
rb = append(rb, bad...)
bufs, err := c.wsRead(ri, tr, rb)
if err == nil || !strings.Contains(err.Error(), test.err) {
t.Fatalf("Expected error to contain %q, got %q", test.err, err.Error())
}
if n := len(bufs); n != test.nbufs {
t.Fatalf("Unexpected number of buffers: %v", n)
}
if string(bufs[0]) != "hello" {
t.Fatalf("Unexpected content: %s", bufs[0])
}
})
}
}
func TestWSEnqueueCloseMsg(t *testing.T) {
for _, test := range []struct {
reason ClosedState
status int
}{
{ClientClosed, wsCloseStatusNormalClosure},
{AuthenticationTimeout, wsCloseStatusPolicyViolation},
{AuthenticationViolation, wsCloseStatusPolicyViolation},
{SlowConsumerPendingBytes, wsCloseStatusPolicyViolation},
{SlowConsumerWriteDeadline, wsCloseStatusPolicyViolation},
{MaxAccountConnectionsExceeded, wsCloseStatusPolicyViolation},
{MaxConnectionsExceeded, wsCloseStatusPolicyViolation},
{MaxControlLineExceeded, wsCloseStatusPolicyViolation},
{MaxSubscriptionsExceeded, wsCloseStatusPolicyViolation},
{MissingAccount, wsCloseStatusPolicyViolation},
{AuthenticationExpired, wsCloseStatusPolicyViolation},
{Revocation, wsCloseStatusPolicyViolation},
{TLSHandshakeError, wsCloseStatusTLSHandshake},
{ParseError, wsCloseStatusProtocolError},
{ProtocolViolation, wsCloseStatusProtocolError},
{BadClientProtocolVersion, wsCloseStatusProtocolError},
{MaxPayloadExceeded, wsCloseStatusMessageTooBig},
{ServerShutdown, wsCloseStatusGoingAway},
{WriteError, wsCloseStatusAbnormalClosure},
{ReadError, wsCloseStatusAbnormalClosure},
{StaleConnection, wsCloseStatusAbnormalClosure},
{ClosedState(254), wsCloseStatusInternalSrvError},
} {
t.Run(test.reason.String(), func(t *testing.T) {
c, _, _ := testWSSetupForRead()
c.wsEnqueueCloseMessage(test.reason)
c.mu.Lock()
nb, _ := c.collapsePtoNB()
c.mu.Unlock()
if n := len(nb); n != 1 {
t.Fatalf("Expected 1 buffer, got %v", n)
}
b := nb[0][0]
if b&wsFinalBit == 0 {
t.Fatalf("Control frame should have been the final flag, it was not set: %v", b)
}
if b&byte(wsCloseMessage) == 0 {
t.Fatalf("Should have been a CLOSE, it wasn't: %v", b)
}
if status := binary.BigEndian.Uint16(nb[0][2:4]); int(status) != test.status {
t.Fatalf("Expected status to be %v, got %v", test.status, status)
}
if string(nb[0][4:]) != test.reason.String() {
t.Fatalf("Unexpected content: %s", nb[0][4:])
}
})
}
}
type testResponseWriter struct {
http.ResponseWriter
buf bytes.Buffer
headers http.Header
err error
brw *bufio.ReadWriter
conn *testWSFakeNetConn
}
func (trw *testResponseWriter) Write(p []byte) (int, error) {
return trw.buf.Write(p)
}
func (trw *testResponseWriter) WriteHeader(status int) {
trw.buf.WriteString(fmt.Sprintf("%v", status))
}
func (trw *testResponseWriter) Header() http.Header {
if trw.headers == nil {
trw.headers = make(http.Header)
}
return trw.headers
}
type testWSFakeNetConn struct {
net.Conn
wbuf bytes.Buffer
err error
wsOpened bool
isClosed bool
deadlineCleared bool
}
func (c *testWSFakeNetConn) Write(p []byte) (int, error) {
if c.err != nil {
return 0, c.err
}
return c.wbuf.Write(p)
}
func (c *testWSFakeNetConn) SetDeadline(t time.Time) error {
if t.IsZero() {
c.deadlineCleared = true
}
return nil
}
func (c *testWSFakeNetConn) Close() error {
c.isClosed = true
return nil
}
func (trw *testResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if trw.conn == nil {
trw.conn = &testWSFakeNetConn{}
}
trw.conn.wsOpened = true
if trw.brw == nil {
trw.brw = bufio.NewReadWriter(bufio.NewReader(trw.conn), bufio.NewWriter(trw.conn))
}
return trw.conn, trw.brw, trw.err
}
func testWSOptions() *Options {
opts := DefaultOptions()
opts.DisableShortFirstPing = true
opts.Websocket.Host = "127.0.0.1"
opts.Websocket.Port = -1
opts.NoSystemAccount = true
var err error
tc := &TLSConfigOpts{
CertFile: "./configs/certs/server.pem",
KeyFile: "./configs/certs/key.pem",
}
opts.Websocket.TLSConfig, err = GenTLSConfig(tc)
if err != nil {
panic(err)
}
return opts
}
func testWSCreateValidReq() *http.Request {
req := &http.Request{
Method: "GET",
Host: "localhost",
Proto: "HTTP/1.1",
}
req.Header = make(http.Header)
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
req.Header.Set("Sec-Websocket-Version", "13")
return req
}
func TestWSCheckOrigin(t *testing.T) {
notSameOrigin := false
sameOrigin := true
allowedListEmpty := []string{}
someList := []string{"http://host1.com", "http://host2.com:1234"}
for _, test := range []struct {
name string
sameOrigin bool
origins []string
reqHost string
reqTLS bool
origin string
err string
}{
{"any", notSameOrigin, allowedListEmpty, "", false, "http://any.host.com", ""},
{"same origin ok", sameOrigin, allowedListEmpty, "host.com", false, "http://host.com:80", ""},
{"same origin bad host", sameOrigin, allowedListEmpty, "host.com", false, "http://other.host.com", "not same origin"},
{"same origin bad port", sameOrigin, allowedListEmpty, "host.com", false, "http://host.com:81", "not same origin"},
{"same origin bad scheme", sameOrigin, allowedListEmpty, "host.com", true, "http://host.com", "not same origin"},
{"same origin bad uri", sameOrigin, allowedListEmpty, "host.com", false, "@@@://invalid:url:1234", "invalid URI"},
{"same origin bad url", sameOrigin, allowedListEmpty, "host.com", false, "http://invalid:url:1234", "too many colons"},
{"same origin bad req host", sameOrigin, allowedListEmpty, "invalid:url:1234", false, "http://host.com", "too many colons"},
{"no origin", sameOrigin, allowedListEmpty, "", false, "", "origin not provided"},
{"allowed from list", notSameOrigin, someList, "", false, "http://host2.com:1234", ""},
{"allowed with different path", notSameOrigin, someList, "", false, "http://host1.com/some/path", ""},
{"list bad port", notSameOrigin, someList, "", false, "http://host1.com:1234", "not in the allowed list"},
{"list bad scheme", notSameOrigin, someList, "", false, "https://host2.com:1234", "not in the allowed list"},
} {
t.Run(test.name, func(t *testing.T) {
opts := DefaultOptions()
opts.Websocket.SameOrigin = test.sameOrigin
opts.Websocket.AllowedOrigins = test.origins
s := &Server{opts: opts}
s.wsSetOriginOptions(&opts.Websocket)
req := testWSCreateValidReq()
req.Host = test.reqHost
if test.reqTLS {
req.TLS = &tls.ConnectionState{}
}
if test.origin != "" {
req.Header.Set("Origin", test.origin)
}
err := s.websocket.checkOrigin(req)
if test.err == "" && err != nil {
t.Fatalf("Unexpected error: %v", err)
} else if test.err != "" && (err == nil || !strings.Contains(err.Error(), test.err)) {
t.Fatalf("Expected error %q, got %v", test.err, err)
}
})
}
}
func TestWSUpgradeValidationErrors(t *testing.T) {
for _, test := range []struct {
name string
setup func() (*Options, *testResponseWriter, *http.Request)
err string
status int
}{
{
"bad method",
func() (*Options, *testResponseWriter, *http.Request) {
opts := testWSOptions()
req := testWSCreateValidReq()
req.Method = "POST"
return opts, nil, req
},
"must be GET",
http.StatusMethodNotAllowed,
},
{
"no host",
func() (*Options, *testResponseWriter, *http.Request) {
opts := testWSOptions()
req := testWSCreateValidReq()
req.Host = ""
return opts, nil, req
},
"'Host' missing in request",
http.StatusBadRequest,
},
{
"invalid upgrade header",
func() (*Options, *testResponseWriter, *http.Request) {
opts := testWSOptions()
req := testWSCreateValidReq()
req.Header.Del("Upgrade")
return opts, nil, req
},
"invalid value for header 'Uprade'",
http.StatusBadRequest,
},
{
"invalid connection header",
func() (*Options, *testResponseWriter, *http.Request) {
opts := testWSOptions()
req := testWSCreateValidReq()
req.Header.Del("Connection")
return opts, nil, req
},
"invalid value for header 'Connection'",
http.StatusBadRequest,
},
{
"no key",
func() (*Options, *testResponseWriter, *http.Request) {
opts := testWSOptions()
req := testWSCreateValidReq()
req.Header.Del("Sec-Websocket-Key")
return opts, nil, req
},
"key missing",
http.StatusBadRequest,
},
{
"empty key",
func() (*Options, *testResponseWriter, *http.Request) {
opts := testWSOptions()
req := testWSCreateValidReq()
req.Header.Set("Sec-Websocket-Key", "")
return opts, nil, req
},
"key missing",
http.StatusBadRequest,
},
{
"missing version",
func() (*Options, *testResponseWriter, *http.Request) {
opts := testWSOptions()
req := testWSCreateValidReq()
req.Header.Del("Sec-Websocket-Version")
return opts, nil, req
},
"invalid version",
http.StatusBadRequest,
},
{
"wrong version",
func() (*Options, *testResponseWriter, *http.Request) {
opts := testWSOptions()
req := testWSCreateValidReq()
req.Header.Set("Sec-Websocket-Version", "99")
return opts, nil, req
},
"invalid version",
http.StatusBadRequest,
},
{
"origin",
func() (*Options, *testResponseWriter, *http.Request) {
opts := testWSOptions()
opts.Websocket.SameOrigin = true
req := testWSCreateValidReq()
req.Header.Set("Origin", "http://bad.host.com")
return opts, nil, req
},
"origin not allowed",
http.StatusForbidden,
},
{
"hijack error",
func() (*Options, *testResponseWriter, *http.Request) {
opts := testWSOptions()
rw := &testResponseWriter{err: fmt.Errorf("on purpose")}
req := testWSCreateValidReq()
return opts, rw, req
},
"on purpose",
http.StatusInternalServerError,
},
{
"hijack buffered data",
func() (*Options, *testResponseWriter, *http.Request) {
opts := testWSOptions()
buf := &bytes.Buffer{}
buf.WriteString("some data")
rw := &testResponseWriter{
conn: &testWSFakeNetConn{},
brw: bufio.NewReadWriter(bufio.NewReader(buf), bufio.NewWriter(nil)),
}
tmp := [1]byte{}
io.ReadAtLeast(rw.brw, tmp[:1], 1)
req := testWSCreateValidReq()
return opts, rw, req
},
"client sent data before handshake is complete",
http.StatusBadRequest,
},
} {
t.Run(test.name, func(t *testing.T) {
opts, rw, req := test.setup()
if rw == nil {
rw = &testResponseWriter{}
}
s := &Server{opts: opts}
s.wsSetOriginOptions(&opts.Websocket)
res, err := s.wsUpgrade(rw, req)
if err == nil || !strings.Contains(err.Error(), test.err) {
t.Fatalf("Should get error %q, got %v", test.err, err)
}
if res != nil {
t.Fatalf("Should not have returned a result, got %v", res)
}
expected := fmt.Sprintf("%v%s\n", test.status, http.StatusText(test.status))
if got := rw.buf.String(); got != expected {
t.Fatalf("Expected %q got %q", expected, got)
}
// Check that if the connection was opened, it is now closed.
if rw.conn != nil && rw.conn.wsOpened && !rw.conn.isClosed {
t.Fatal("Connection was opened, but has not been closed")
}
})
}
}
func TestWSUpgradeResponseWriteError(t *testing.T) {
opts := testWSOptions()
s := &Server{opts: opts}
expectedErr := errors.New("on purpose")
rw := &testResponseWriter{
conn: &testWSFakeNetConn{err: expectedErr},
}
req := testWSCreateValidReq()
res, err := s.wsUpgrade(rw, req)
if err != expectedErr {
t.Fatalf("Should get error %q, got %v", expectedErr.Error(), err)
}
if res != nil {
t.Fatalf("Should not have returned a result, got %v", res)
}
if !rw.conn.isClosed {
t.Fatal("Connection should have been closed")
}
}
func TestWSUpgradeConnDeadline(t *testing.T) {
opts := testWSOptions()
opts.Websocket.HandshakeTimeout = time.Second
s := &Server{opts: opts}
rw := &testResponseWriter{}
req := testWSCreateValidReq()
res, err := s.wsUpgrade(rw, req)
if res == nil || err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if rw.conn.isClosed {
t.Fatal("Connection should NOT have been closed")
}
if !rw.conn.deadlineCleared {
t.Fatal("Connection deadline should have been cleared after handshake")
}
}
func TestWSCompressNegotiation(t *testing.T) {
// No compression on the server, but client asks
opts := testWSOptions()
s := &Server{opts: opts}
rw := &testResponseWriter{}
req := testWSCreateValidReq()
req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate")
res, err := s.wsUpgrade(rw, req)
if res == nil || err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// The http response should not contain "permessage-deflate"
output := rw.conn.wbuf.String()
if strings.Contains(output, "permessage-deflate") {
t.Fatalf("Compression disabled in server so response to client should not contain extension, got %s", output)
}
// Option in the server and client, so compression should be negotiated.
s.opts.Websocket.Compression = true
rw = &testResponseWriter{}
res, err = s.wsUpgrade(rw, req)
if res == nil || err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// The http response should not contain "permessage-deflate"
output = rw.conn.wbuf.String()
if !strings.Contains(output, "permessage-deflate") {
t.Fatalf("Compression in server and client request, so response should contain extension, got %s", output)
}
// Option in server but not asked by the client, so response should not contain "permessage-deflate"
rw = &testResponseWriter{}
req.Header.Del("Sec-Websocket-Extensions")
res, err = s.wsUpgrade(rw, req)
if res == nil || err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// The http response should not contain "permessage-deflate"
output = rw.conn.wbuf.String()
if strings.Contains(output, "permessage-deflate") {
t.Fatalf("Compression in server but not in client, so response to client should not contain extension, got %s", output)
}
}
func TestWSParseOptions(t *testing.T) {
for _, test := range []struct {
name string
content string
checkOpt func(*WebsocketOpts) error
err string
}{
// Negative tests
{"bad type", "websocket: []", nil, "to be a map"},
{"bad listen", "websocket: { listen: [] }", nil, "port or host:port"},
{"bad port", `websocket: { port: "abc" }`, nil, "not int64"},
{"bad host", `websocket: { host: 123 }`, nil, "not string"},
{"bad advertise type", `websocket: { advertise: 123 }`, nil, "not string"},
{"bad tls", `websocket: { tls: 123 }`, nil, "not map[string]interface {}"},
{"bad same origin", `websocket: { same_origin: "abc" }`, nil, "not bool"},
{"bad allowed origins type", `websocket: { allowed_origins: {} }`, nil, "unsupported type"},
{"bad allowed origins values", `websocket: { allowed_origins: [ {} ] }`, nil, "unsupported type in array"},
{"bad handshake timeout type", `websocket: { handshake_timeout: [] }`, nil, "unsupported type"},
{"bad handshake timeout duration", `websocket: { handshake_timeout: "abc" }`, nil, "invalid duration"},
{"unknown field", `websocket: { this_does_not_exist: 123 }`, nil, "unknown"},
// Positive tests
{"listen port only", `websocket { listen: 1234 }`, func(wo *WebsocketOpts) error {
if wo.Port != 1234 {
return fmt.Errorf("expected 1234, got %v", wo.Port)
}
return nil
}, ""},
{"listen host and port", `websocket { listen: "localhost:1234" }`, func(wo *WebsocketOpts) error {
if wo.Host != "localhost" || wo.Port != 1234 {
return fmt.Errorf("expected localhost:1234, got %v:%v", wo.Host, wo.Port)
}
return nil
}, ""},
{"host", `websocket { host: "localhost" }`, func(wo *WebsocketOpts) error {
if wo.Host != "localhost" {
return fmt.Errorf("expected localhost, got %v", wo.Host)
}
return nil
}, ""},
{"port", `websocket { port: 1234 }`, func(wo *WebsocketOpts) error {
if wo.Port != 1234 {
return fmt.Errorf("expected 1234, got %v", wo.Port)
}
return nil
}, ""},
{"advertise", `websocket { advertise: "host:1234" }`, func(wo *WebsocketOpts) error {
if wo.Advertise != "host:1234" {
return fmt.Errorf("expected %q, got %q", "host:1234", wo.Advertise)
}
return nil
}, ""},
{"same origin", `websocket { same_origin: true }`, func(wo *WebsocketOpts) error {
if !wo.SameOrigin {
return fmt.Errorf("expected same_origin==true, got %v", wo.SameOrigin)
}
return nil
}, ""},
{"allowed origins one only", `websocket { allowed_origins: "https://host.com/" }`, func(wo *WebsocketOpts) error {
expected := []string{"https://host.com/"}
if !reflect.DeepEqual(wo.AllowedOrigins, expected) {
return fmt.Errorf("expected allowed origins to be %q, got %q", expected, wo.AllowedOrigins)
}
return nil
}, ""},
{"allowed origins array",
`
websocket {
allowed_origins: [
"https://host1.com/"
"https://host2.com/"
]
}
`, func(wo *WebsocketOpts) error {
expected := []string{"https://host1.com/", "https://host2.com/"}
if !reflect.DeepEqual(wo.AllowedOrigins, expected) {
return fmt.Errorf("expected allowed origins to be %q, got %q", expected, wo.AllowedOrigins)
}
return nil
}, ""},
{"handshake timeout in whole seconds", `websocket { handshake_timeout: 3 }`, func(wo *WebsocketOpts) error {
if wo.HandshakeTimeout != 3*time.Second {
return fmt.Errorf("expected handshake to be 3s, got %v", wo.HandshakeTimeout)
}
return nil
}, ""},
{"handshake timeout n duration", `websocket { handshake_timeout: "4s" }`, func(wo *WebsocketOpts) error {
if wo.HandshakeTimeout != 4*time.Second {
return fmt.Errorf("expected handshake to be 4s, got %v", wo.HandshakeTimeout)
}
return nil
}, ""},
{"tls config",
`
websocket {
tls {
cert_file: "./configs/certs/server.pem"
key_file: "./configs/certs/key.pem"
}
}
`, func(wo *WebsocketOpts) error {
if wo.TLSConfig == nil {
return fmt.Errorf("TLSConfig should have been set")
}
return nil
}, ""},
} {
t.Run(test.name, func(t *testing.T) {
conf := createConfFile(t, []byte(test.content))
defer os.Remove(conf)
o, err := ProcessConfigFile(conf)
if test.err != _EMPTY_ {
if err == nil || !strings.Contains(err.Error(), test.err) {
t.Fatalf("For content: %q, expected error about %q, got %v", test.content, test.err, err)
}
return
} else if err != nil {
t.Fatalf("Unexpected error for content %q: %v", test.content, err)
}
if err := test.checkOpt(&o.Websocket); err != nil {
t.Fatalf("Incorrect option for content %q: %v", test.content, err.Error())
}
})
}
}
func TestWSValidateOptions(t *testing.T) {
nwso := DefaultOptions()
wso := testWSOptions()
for _, test := range []struct {
name string
getOpts func() *Options
err string
}{
{"websocket disabled", func() *Options { return nwso.Clone() }, ""},
{"no tls", func() *Options { o := wso.Clone(); o.Websocket.TLSConfig = nil; return o }, "requires TLS configuration"},
{"bad url in allowed list", func() *Options {
o := wso.Clone()
o.Websocket.AllowedOrigins = []string{"http://this:is:bad:url"}
return o
}, "unable to parse"},
} {
t.Run(test.name, func(t *testing.T) {
err := validateWebsocketOptions(test.getOpts())
if test.err == "" && err != nil {
t.Fatalf("Unexpected error: %v", err)
} else if test.err != "" && (err == nil || !strings.Contains(err.Error(), test.err)) {
t.Fatalf("Expected error to contain %q, got %v", test.err, err)
}
})
}
}
func TestWSSetOriginOptions(t *testing.T) {
o := testWSOptions()
for _, test := range []struct {
content string
err string
}{
{"@@@://host.com/", "invalid URI"},
{"http://this:is:bad:url/", "invalid port"},
} {
t.Run(test.err, func(t *testing.T) {
o.Websocket.AllowedOrigins = []string{test.content}
s := &Server{}
l := &captureErrorLogger{errCh: make(chan string, 1)}
s.SetLogger(l, false, false)
s.wsSetOriginOptions(&o.Websocket)
select {
case e := <-l.errCh:
if !strings.Contains(e, test.err) {
t.Fatalf("Unexpected error: %v", e)
}
case <-time.After(50 * time.Millisecond):
t.Fatalf("Did not get the error")
}
})
}
}
type captureFatalLogger struct {
DummyLogger
fatalCh chan string
}
func (l *captureFatalLogger) Fatalf(format string, v ...interface{}) {
select {
case l.fatalCh <- fmt.Sprintf(format, v...):
default:
}
}
func TestWSFailureToStartServer(t *testing.T) {
// Create a listener to use a port
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Error listening: %v", err)
}
defer l.Close()
o := testWSOptions()
// Make sure we don't have unnecessary listen ports opened.
o.HTTPPort = 0
o.Cluster.Port = 0
o.Gateway.Name = ""
o.Gateway.Port = 0
o.LeafNode.Port = 0
o.Websocket.Port = l.Addr().(*net.TCPAddr).Port
s, err := NewServer(o)
if err != nil {
t.Fatalf("Error creating server: %v", err)
}
defer s.Shutdown()
logger := &captureFatalLogger{fatalCh: make(chan string, 1)}
s.SetLogger(logger, false, false)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
s.Start()
wg.Done()
}()
select {
case e := <-logger.fatalCh:
if !strings.Contains(e, "Unable to listen") {
t.Fatalf("Unexpected error: %v", e)
}
case <-time.After(2 * time.Second):
t.Fatalf("Should have reported a fatal error")
}
// Since this is a test and the process does not actually
// exit on Fatal error, wait for the client port to be
// ready so when we shutdown we don't leave the accept
// loop hanging.
checkFor(t, time.Second, 15*time.Millisecond, func() error {
s.mu.Lock()
ready := s.listener != nil
s.mu.Unlock()
if !ready {
return fmt.Errorf("client accept loop not started yet")
}
return nil
})
s.Shutdown()
wg.Wait()
}
func TestWSAbnormalFailureOfWebServer(t *testing.T) {
o := testWSOptions()
s := RunServer(o)
defer s.Shutdown()
logger := &captureFatalLogger{fatalCh: make(chan string, 1)}
s.SetLogger(logger, false, false)
// Now close the WS listener to cause a WebServer error
s.mu.Lock()
s.websocket.listener.Close()
s.mu.Unlock()
select {
case e := <-logger.fatalCh:
if !strings.Contains(e, "websocket listener error") {
t.Fatalf("Unexpected error: %v", e)
}
case <-time.After(2 * time.Second):
t.Fatalf("Should have reported a fatal error")
}
}
func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader, []byte) {
t.Helper()
addr := fmt.Sprintf("%s:%d", host, port)
wsc, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("Error creating ws connection: %v", err)
}
wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true})
if err := wsc.(*tls.Conn).Handshake(); err != nil {
t.Fatalf("Error during handshake: %v", err)
}
req := testWSCreateValidReq()
if compress {
req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate")
}
if web {
req.Header.Set("User-Agent", "Mozilla/5.0")
}
req.URL, _ = url.Parse("wss://" + addr)
if err := req.Write(wsc); err != nil {
t.Fatalf("Error sending request: %v", err)
}
br := bufio.NewReader(wsc)
resp, err := http.ReadResponse(br, req)
if err != nil {
t.Fatalf("Error reading response: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusSwitchingProtocols {
t.Fatalf("Expected response status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode)
}
// Wait for the INFO
info := testWSReadFrame(t, br)
if !bytes.HasPrefix(info, []byte("INFO {")) {
t.Fatalf("Expected INFO, got %s", info)
}
return wsc, br, info
}
func testWSCreateClient(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader) {
wsc, br, _ := testWSCreateClientGetInfo(t, compress, web, host, port)
// Send CONNECT and PING
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, []byte("CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n"))
if _, err := wsc.Write(wsmsg); err != nil {
t.Fatalf("Error sending message: %v", err)
}
// Wait for the PONG
if msg := testWSReadFrame(t, br); !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
t.Fatalf("Expected PONG, got %s", msg)
}
return wsc, br
}
func testWSReadFrame(t testing.TB, br *bufio.Reader) []byte {
t.Helper()
fh := [2]byte{}
if _, err := io.ReadAtLeast(br, fh[:2], 2); err != nil {
t.Fatalf("Error reading frame: %v", err)
}
fc := fh[0]&wsRsv1Bit != 0
sb := fh[1]
size := 0
switch {
case sb <= 125:
size = int(sb)
case sb == 126:
tmp := [2]byte{}
if _, err := io.ReadAtLeast(br, tmp[:2], 2); err != nil {
t.Fatalf("Error reading frame: %v", err)
}
size = int(binary.BigEndian.Uint16(tmp[:2]))
case sb == 127:
tmp := [8]byte{}
if _, err := io.ReadAtLeast(br, tmp[:8], 8); err != nil {
t.Fatalf("Error reading frame: %v", err)
}
size = int(binary.BigEndian.Uint64(tmp[:8]))
}
buf := make([]byte, size)
if _, err := io.ReadAtLeast(br, buf, size); err != nil {
t.Fatalf("Error reading frame: %v", err)
}
if !fc {
return buf
}
buf = append(buf, 0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff)
dbr := bytes.NewBuffer(buf)
d := flate.NewReader(dbr)
uncompressed, err := ioutil.ReadAll(d)
if err != nil {
t.Fatalf("Error reading frame: %v", err)
}
return uncompressed
}
func TestWSPubSub(t *testing.T) {
for _, test := range []struct {
name string
compression bool
}{
{"no compression", false},
{"compression", true},
} {
t.Run(test.name, func(t *testing.T) {
o := testWSOptions()
if test.compression {
o.Websocket.Compression = true
}
s := RunServer(o)
defer s.Shutdown()
// Create a regular client to subscribe
nc := natsConnect(t, s.ClientURL())
defer nc.Close()
nsub := natsSubSync(t, nc, "foo")
checkExpectedSubs(t, 1, s)
// Now create a WS client and send a message on "foo"
wsc, br := testWSCreateClient(t, test.compression, false, o.Websocket.Host, o.Websocket.Port)
defer wsc.Close()
// Send a WS message for "PUB foo 2\r\nok\r\n"
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("PUB foo 7\r\nfrom ws\r\n"))
if _, err := wsc.Write(wsmsg); err != nil {
t.Fatalf("Error sending message: %v", err)
}
// Now check that message is received
msg := natsNexMsg(t, nsub, time.Second)
if string(msg.Data) != "from ws" {
t.Fatalf("Expected message to be %q, got %q", "ok", string(msg.Data))
}
// Now do reverse, create a subscription on WS client on bar
wsmsg = testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("SUB bar 1\r\n"))
if _, err := wsc.Write(wsmsg); err != nil {
t.Fatalf("Error sending subscription: %v", err)
}
// Wait for it to be registered on server
checkExpectedSubs(t, 2, s)
// Now publish from NATS connection and verify received on WS client
natsPub(t, nc, "bar", []byte("from nats"))
natsFlush(t, nc)
// Check for the "from nats" message...
// Set some deadline so we are not stuck forever on failure
wsc.SetReadDeadline(time.Now().Add(10 * time.Second))
ok := 0
for {
line, _, err := br.ReadLine()
if err != nil {
t.Fatalf("Error reading: %v", err)
}
// Note that this works even in compression test because those
// texts are likely not to be compressed, but compression code is
// still executed.
if ok == 0 && bytes.Contains(line, []byte("MSG bar 1 9")) {
ok = 1
continue
} else if ok == 1 && bytes.Contains(line, []byte("from nats")) {
ok = 2
break
}
}
})
}
}
func TestWSTLSConnection(t *testing.T) {
o := testWSOptions()
s := RunServer(o)
defer s.Shutdown()
addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port)
for _, test := range []struct {
name string
useTLS bool
status int
}{
{"client uses TLS", true, http.StatusSwitchingProtocols},
{"client does not use TLS", false, http.StatusBadRequest},
} {
t.Run(test.name, func(t *testing.T) {
wsc, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("Error creating ws connection: %v", err)
}
defer wsc.Close()
if test.useTLS {
wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true})
if err := wsc.(*tls.Conn).Handshake(); err != nil {
t.Fatalf("Error during handshake: %v", err)
}
}
req := testWSCreateValidReq()
var scheme string
if test.useTLS {
scheme = "s"
}
req.URL, _ = url.Parse("ws" + scheme + "://" + addr)
if err := req.Write(wsc); err != nil {
t.Fatalf("Error sending request: %v", err)
}
br := bufio.NewReader(wsc)
resp, err := http.ReadResponse(br, req)
if err != nil {
t.Fatalf("Error reading response: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != test.status {
t.Fatalf("Expected status %v, got %v", test.status, resp.StatusCode)
}
})
}
}
func TestWSTLSVerifyClientCert(t *testing.T) {
o := testWSOptions()
tc := &TLSConfigOpts{
CertFile: "../test/configs/certs/server-cert.pem",
KeyFile: "../test/configs/certs/server-key.pem",
CaFile: "../test/configs/certs/ca.pem",
Verify: true,
}
tlsc, err := GenTLSConfig(tc)
if err != nil {
t.Fatalf("Error creating tls config: %v", err)
}
o.Websocket.TLSConfig = tlsc
s := RunServer(o)
defer s.Shutdown()
addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port)
for _, test := range []struct {
name string
provideCert bool
}{
{"client provides cert", true},
{"client does not provide cert", false},
} {
t.Run(test.name, func(t *testing.T) {
wsc, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("Error creating ws connection: %v", err)
}
defer wsc.Close()
tlsc := &tls.Config{}
if test.provideCert {
tc := &TLSConfigOpts{
CertFile: "../test/configs/certs/client-cert.pem",
KeyFile: "../test/configs/certs/client-key.pem",
}
var err error
tlsc, err = GenTLSConfig(tc)
if err != nil {
t.Fatalf("Error generating tls config: %v", err)
}
}
tlsc.InsecureSkipVerify = true
wsc = tls.Client(wsc, tlsc)
if err := wsc.(*tls.Conn).Handshake(); err != nil {
t.Fatalf("Error during handshake: %v", err)
}
req := testWSCreateValidReq()
req.URL, _ = url.Parse("wss://" + addr)
if err := req.Write(wsc); err != nil {
t.Fatalf("Error sending request: %v", err)
}
br := bufio.NewReader(wsc)
resp, err := http.ReadResponse(br, req)
if resp != nil {
resp.Body.Close()
}
if !test.provideCert {
if err == nil {
t.Fatal("Expected error, did not get one")
} else if !strings.Contains(err.Error(), "bad certificate") {
t.Fatalf("Unexpected error: %v", err)
}
return
}
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if resp.StatusCode != http.StatusSwitchingProtocols {
t.Fatalf("Expected status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode)
}
})
}
}
func TestWSTLSVerifyAndMap(t *testing.T) {
accName := "MyAccount"
acc := NewAccount(accName)
certUserName := "CN=example.com,OU=NATS.io"
users := []*User{&User{Username: certUserName, Account: acc}}
for _, test := range []struct {
name string
wsUsers bool
provideCert bool
}{
{"no users override, client provides cert", false, true},
{"no users override, client does not provide cert", false, false},
{"sers override, client provides cert", true, true},
{"users override, client does not provide cert", true, false},
} {
t.Run(test.name, func(t *testing.T) {
o := testWSOptions()
o.Accounts = []*Account{acc}
if test.wsUsers {
o.Websocket.Users = users
} else {
o.Users = users
}
tc := &TLSConfigOpts{
CertFile: "../test/configs/certs/tlsauth/server.pem",
KeyFile: "../test/configs/certs/tlsauth/server-key.pem",
CaFile: "../test/configs/certs/tlsauth/ca.pem",
Verify: true,
}
tlsc, err := GenTLSConfig(tc)
if err != nil {
t.Fatalf("Error creating tls config: %v", err)
}
o.Websocket.TLSConfig = tlsc
o.Websocket.TLSMap = true
s := RunServer(o)
defer s.Shutdown()
addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port)
wsc, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("Error creating ws connection: %v", err)
}
defer wsc.Close()
tlscc := &tls.Config{}
if test.provideCert {
tc := &TLSConfigOpts{
CertFile: "../test/configs/certs/tlsauth/client.pem",
KeyFile: "../test/configs/certs/tlsauth/client-key.pem",
}
var err error
tlscc, err = GenTLSConfig(tc)
if err != nil {
t.Fatalf("Error generating tls config: %v", err)
}
}
tlscc.InsecureSkipVerify = true
wsc = tls.Client(wsc, tlscc)
if err := wsc.(*tls.Conn).Handshake(); err != nil {
t.Fatalf("Error during handshake: %v", err)
}
req := testWSCreateValidReq()
req.URL, _ = url.Parse("wss://" + addr)
if err := req.Write(wsc); err != nil {
t.Fatalf("Error sending request: %v", err)
}
br := bufio.NewReader(wsc)
resp, err := http.ReadResponse(br, req)
if resp != nil {
resp.Body.Close()
}
if !test.provideCert {
if err == nil {
t.Fatal("Expected error, did not get one")
} else if !strings.Contains(err.Error(), "bad certificate") {
t.Fatalf("Unexpected error: %v", err)
}
return
}
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if resp.StatusCode != http.StatusSwitchingProtocols {
t.Fatalf("Expected status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode)
}
// Wait for the INFO
l := testWSReadFrame(t, br)
if !bytes.HasPrefix(l, []byte("INFO {")) {
t.Fatalf("Expected INFO, got %s", l)
}
var info serverInfo
if err := json.Unmarshal(l[5:], &info); err != nil {
t.Fatalf("Unable to unmarshal info: %v", err)
}
// Send CONNECT and PING
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n"))
if _, err := wsc.Write(wsmsg); err != nil {
t.Fatalf("Error sending message: %v", err)
}
// Wait for the PONG
if msg := testWSReadFrame(t, br); !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
t.Fatalf("Expected PONG, got %s", msg)
}
var uname string
var accname string
c := s.getClient(info.CID)
if c != nil {
c.mu.Lock()
uname = c.opts.Username
if c.acc != nil {
accname = c.acc.GetName()
}
c.mu.Unlock()
}
if uname != certUserName {
t.Fatalf("Expected username %q, got %q", certUserName, uname)
}
if accname != accName {
t.Fatalf("Expected account %q, got %v", accName, accname)
}
})
}
}
func TestWSHandshakeTimeout(t *testing.T) {
o := testWSOptions()
o.Websocket.HandshakeTimeout = time.Millisecond
tc := &TLSConfigOpts{
CertFile: "./configs/certs/server.pem",
KeyFile: "./configs/certs/key.pem",
}
o.Websocket.TLSConfig, _ = GenTLSConfig(tc)
s := RunServer(o)
defer s.Shutdown()
logger := &captureErrorLogger{errCh: make(chan string, 1)}
s.SetLogger(logger, false, false)
addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port)
wsc, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("Error creating ws connection: %v", err)
}
defer wsc.Close()
// Delay the handshake
wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true})
time.Sleep(20 * time.Millisecond)
// We expect error since the server should have cut us off
if err := wsc.(*tls.Conn).Handshake(); err == nil {
t.Fatal("Expected error during handshake")
}
// Check that server logs error
select {
case e := <-logger.errCh:
if !strings.Contains(e, "timeout") {
t.Fatalf("Unexpected error: %v", e)
}
case <-time.After(time.Second):
t.Fatalf("Should have timed-out")
}
}
func TestWSServerReportUpgradeFailure(t *testing.T) {
o := testWSOptions()
s := RunServer(o)
defer s.Shutdown()
logger := &captureErrorLogger{errCh: make(chan string, 1)}
s.SetLogger(logger, false, false)
addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port)
req := testWSCreateValidReq()
req.URL, _ = url.Parse("wss://" + addr)
wsc, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("Error creating ws connection: %v", err)
}
defer wsc.Close()
wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true})
if err := wsc.(*tls.Conn).Handshake(); err != nil {
t.Fatalf("Error during handshake: %v", err)
}
// Remove a required field from the request to have it fail
req.Header.Del("Connection")
// Send the request
if err := req.Write(wsc); err != nil {
t.Fatalf("Error sending request: %v", err)
}
br := bufio.NewReader(wsc)
resp, err := http.ReadResponse(br, req)
if err != nil {
t.Fatalf("Error reading response: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("Expected status %v, got %v", http.StatusBadRequest, resp.StatusCode)
}
// Check that server logs error
select {
case e := <-logger.errCh:
if !strings.Contains(e, "invalid value for header 'Connection'") {
t.Fatalf("Unexpected error: %v", e)
}
case <-time.After(time.Second):
t.Fatalf("Should have timed-out")
}
}
func TestWSCloseMsgSendOnConnectionClose(t *testing.T) {
o := testWSOptions()
s := RunServer(o)
defer s.Shutdown()
wsc, br := testWSCreateClient(t, false, false, o.Websocket.Host, o.Websocket.Port)
defer wsc.Close()
checkClientsCount(t, s, 1)
var c *client
s.mu.Lock()
for _, cli := range s.clients {
c = cli
break
}
s.mu.Unlock()
c.closeConnection(ProtocolViolation)
msg := testWSReadFrame(t, br)
if len(msg) < 2 {
t.Fatalf("Should have 2 bytes to represent the status, got %v", msg)
}
if sc := int(binary.BigEndian.Uint16(msg[:2])); sc != wsCloseStatusProtocolError {
t.Fatalf("Expected status to be %v, got %v", wsCloseStatusProtocolError, sc)
}
expectedPayload := ProtocolViolation.String()
if p := string(msg[2:]); p != expectedPayload {
t.Fatalf("Expected payload to be %q, got %q", expectedPayload, p)
}
}
func TestWSAdvertise(t *testing.T) {
o := testWSOptions()
o.Cluster.Port = 0
o.HTTPPort = 0
o.Websocket.Advertise = "xxx:host:yyy"
s, err := NewServer(o)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer s.Shutdown()
l := &captureFatalLogger{fatalCh: make(chan string, 1)}
s.SetLogger(l, false, false)
go s.Start()
select {
case e := <-l.fatalCh:
if !strings.Contains(e, "Unable to get websocket connect URLs") {
t.Fatalf("Unexpected error: %q", e)
}
case <-time.After(time.Second):
t.Fatal("Should have failed to start")
}
s.Shutdown()
o1 := testWSOptions()
o1.Websocket.Advertise = "host1:1234"
s1 := RunServer(o1)
defer s1.Shutdown()
wsc, br := testWSCreateClient(t, false, false, o1.Websocket.Host, o1.Websocket.Port)
defer wsc.Close()
o2 := testWSOptions()
o2.Websocket.Advertise = "host2:5678"
o2.Routes = RoutesFromStr(fmt.Sprintf("nats://%s:%d", o1.Cluster.Host, o1.Cluster.Port))
s2 := RunServer(o2)
defer s2.Shutdown()
checkInfo := func(expected []string) {
t.Helper()
infob := testWSReadFrame(t, br)
info := &Info{}
json.Unmarshal(infob[5:], info)
if n := len(info.ClientConnectURLs); n != len(expected) {
t.Fatalf("Unexpected info: %+v", info)
}
good := 0
for _, u := range info.ClientConnectURLs {
for _, eu := range expected {
if u == eu {
good++
}
}
}
if good != len(expected) {
t.Fatalf("Unexpected connect urls: %q", info.ClientConnectURLs)
}
}
checkInfo([]string{"host1:1234", "host2:5678"})
// Now shutdown s2 and expect another INFO
s2.Shutdown()
checkInfo([]string{"host1:1234"})
// Restart with another advertise and check that it gets updated
o2.Websocket.Advertise = "host3:9012"
s2 = RunServer(o2)
defer s2.Shutdown()
checkInfo([]string{"host1:1234", "host3:9012"})
}
func TestWSFrameOutbound(t *testing.T) {
c, _, _ := testWSSetupForRead()
var bufs net.Buffers
bufs = append(bufs, []byte("this "))
bufs = append(bufs, []byte("is "))
bufs = append(bufs, []byte("a "))
bufs = append(bufs, []byte("set "))
bufs = append(bufs, []byte("of "))
bufs = append(bufs, []byte("buffers"))
en := 2
for _, b := range bufs {
en += len(b)
}
c.mu.Lock()
c.out.nb = bufs
res, n := c.collapsePtoNB()
c.mu.Unlock()
if n != int64(en) {
t.Fatalf("Expected size to be %v, got %v", en, n)
}
if eb := 1 + len(bufs); eb != len(res) {
t.Fatalf("Expected %v buffers, got %v", eb, len(res))
}
var ob []byte
for i := 1; i < len(res); i++ {
ob = append(ob, res[i]...)
}
if !bytes.Equal(ob, []byte("this is a set of buffers")) {
t.Fatalf("Unexpected outbound: %q", ob)
}
bufs = nil
c.out.pb = 0
c.ws.fs = 0
c.ws.frames = nil
c.ws.browser = true
bufs = append(bufs, []byte("some smaller "))
bufs = append(bufs, []byte("buffers"))
bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers+10))
bufs = append(bufs, []byte("then some more"))
en = 2 + len(bufs[0]) + len(bufs[1])
en += 4 + len(bufs[2]) - 10
en += 2 + len(bufs[3]) + 10
c.mu.Lock()
c.out.nb = bufs
res, n = c.collapsePtoNB()
c.mu.Unlock()
if n != int64(en) {
t.Fatalf("Expected size to be %v, got %v", en, n)
}
if len(res) != 8 {
t.Fatalf("Unexpected number of outbound buffers: %v", len(res))
}
if len(res[4]) != wsFrameSizeForBrowsers {
t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4]))
}
if len(res[6]) != 10 {
t.Fatalf("Frame 6 should have the partial of 10 bytes, got %v", len(res[6]))
}
bufs = nil
c.out.pb = 0
c.ws.fs = 0
c.ws.frames = nil
c.ws.browser = true
bufs = append(bufs, []byte("some smaller "))
bufs = append(bufs, []byte("buffers"))
// Have one of the exact max size
bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers))
bufs = append(bufs, []byte("then some more"))
en = 2 + len(bufs[0]) + len(bufs[1])
en += 4 + len(bufs[2])
en += 2 + len(bufs[3])
c.mu.Lock()
c.out.nb = bufs
res, n = c.collapsePtoNB()
c.mu.Unlock()
if n != int64(en) {
t.Fatalf("Expected size to be %v, got %v", en, n)
}
if len(res) != 7 {
t.Fatalf("Unexpected number of outbound buffers: %v", len(res))
}
if len(res[4]) != wsFrameSizeForBrowsers {
t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4]))
}
if string(res[6]) != string(bufs[3]) {
t.Fatalf("Frame 6 should be %q, got %q", bufs[3], res[6])
}
bufs = nil
c.out.pb = 0
c.ws.fs = 0
c.ws.frames = nil
c.ws.browser = true
bufs = append(bufs, []byte("some smaller "))
bufs = append(bufs, []byte("buffers"))
// Have one of the exact max size, and last in the list
bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers))
en = 2 + len(bufs[0]) + len(bufs[1])
en += 4 + len(bufs[2])
c.mu.Lock()
c.out.nb = bufs
res, n = c.collapsePtoNB()
c.mu.Unlock()
if n != int64(en) {
t.Fatalf("Expected size to be %v, got %v", en, n)
}
if len(res) != 5 {
t.Fatalf("Unexpected number of outbound buffers: %v", len(res))
}
if len(res[4]) != wsFrameSizeForBrowsers {
t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4]))
}
bufs = nil
c.out.pb = 0
c.ws.fs = 0
c.ws.frames = nil
c.ws.browser = true
bufs = append(bufs, []byte("some smaller buffer"))
bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers-5))
bufs = append(bufs, []byte("then some more"))
en = 2 + len(bufs[0])
en += 4 + len(bufs[1])
en += 2 + len(bufs[2])
c.mu.Lock()
c.out.nb = bufs
res, n = c.collapsePtoNB()
c.mu.Unlock()
if n != int64(en) {
t.Fatalf("Expected size to be %v, got %v", en, n)
}
if len(res) != 6 {
t.Fatalf("Unexpected number of outbound buffers: %v", len(res))
}
if len(res[3]) != wsFrameSizeForBrowsers-5 {
t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4]))
}
if string(res[5]) != string(bufs[2]) {
t.Fatalf("Frame 6 should be %q, got %q", bufs[2], res[5])
}
}
func TestWSWebrowserClient(t *testing.T) {
o := testWSOptions()
s := RunServer(o)
defer s.Shutdown()
wsc, br := testWSCreateClient(t, false, true, o.Websocket.Host, o.Websocket.Port)
defer wsc.Close()
checkClientsCount(t, s, 1)
var c *client
s.mu.Lock()
for _, cli := range s.clients {
c = cli
break
}
s.mu.Unlock()
proto := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("SUB foo 1\r\nPING\r\n"))
wsc.Write(proto)
if res := testWSReadFrame(t, br); !bytes.Equal(res, []byte(pongProto)) {
t.Fatalf("Expected PONG back")
}
c.mu.Lock()
ok := c.ws != nil && c.ws.browser == true
c.mu.Unlock()
if !ok {
t.Fatalf("Client is not marked as webrowser client")
}
nc := natsConnect(t, s.ClientURL())
defer nc.Close()
// Send a big message and check that it is received in smaller frames
psize := 204813
nc.Publish("foo", make([]byte, psize))
nc.Flush()
rsize := psize + len(fmt.Sprintf("MSG foo %d\r\n\r\n", psize))
nframes := 0
for total := 0; total < rsize; nframes++ {
res := testWSReadFrame(t, br)
total += len(res)
}
if expected := psize / wsFrameSizeForBrowsers; expected > nframes {
t.Fatalf("Expected %v frames, got %v", expected, nframes)
}
}
type testWSWrappedConn struct {
net.Conn
mu sync.RWMutex
buf *bytes.Buffer
partial bool
}
func (wc *testWSWrappedConn) Write(p []byte) (int, error) {
wc.mu.Lock()
defer wc.mu.Unlock()
var err error
n := len(p)
if wc.partial && n > 10 {
n = 10
err = io.ErrShortWrite
}
p = p[:n]
wc.buf.Write(p)
wc.Conn.Write(p)
return n, err
}
func TestWSCompressionBasic(t *testing.T) {
payload := "This is the content of a message that will be compresseddddddddddddddddddddd."
msgProto := fmt.Sprintf("MSG foo 1 %d\r\n%s\r\n", len(payload), payload)
cbuf := &bytes.Buffer{}
compressor, _ := flate.NewWriter(cbuf, flate.BestSpeed)
compressor.Write([]byte(msgProto))
compressor.Close()
compressed := cbuf.Bytes()
// The last 4 bytes are dropped
compressed = compressed[:len(compressed)-4]
o := testWSOptions()
o.Websocket.Compression = true
s := RunServer(o)
defer s.Shutdown()
c, br := testWSCreateClient(t, true, false, o.Websocket.Host, o.Websocket.Port)
defer c.Close()
proto := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, []byte("SUB foo 1\r\nPING\r\n"))
c.Write(proto)
l := testWSReadFrame(t, br)
if !bytes.Equal(l, []byte(pongProto)) {
t.Fatalf("Expected PONG, got %q", l)
}
var wc *testWSWrappedConn
s.mu.Lock()
for _, c := range s.clients {
c.mu.Lock()
wc = &testWSWrappedConn{Conn: c.nc, buf: &bytes.Buffer{}}
c.nc = wc
c.mu.Unlock()
}
s.mu.Unlock()
nc := natsConnect(t, s.ClientURL())
defer nc.Close()
natsPub(t, nc, "foo", []byte(payload))
res := &bytes.Buffer{}
for total := 0; total < len(msgProto); {
l := testWSReadFrame(t, br)
n, _ := res.Write(l)
total += n
}
if !bytes.Equal([]byte(msgProto), res.Bytes()) {
t.Fatalf("Unexpected result: %q", res)
}
// Now check the wrapped connection buffer to check that data was actually compressed.
wc.mu.RLock()
res = wc.buf
wc.mu.RUnlock()
if bytes.Contains(res.Bytes(), []byte(payload)) {
t.Fatalf("Looks like frame was not compressed: %q", res.Bytes())
}
header := res.Bytes()[:2]
body := res.Bytes()[2:]
expectedB0 := byte(wsBinaryMessage) | wsFinalBit | wsRsv1Bit
expectedPS := len(compressed)
expectedB1 := byte(expectedPS)
if b := header[0]; b != expectedB0 {
t.Fatalf("Expected first byte to be %v, got %v", expectedB0, b)
}
if b := header[1]; b != expectedB1 {
t.Fatalf("Expected second byte to be %v, got %v", expectedB1, b)
}
if len(body) != expectedPS {
t.Fatalf("Expected payload length to be %v, got %v", expectedPS, len(body))
}
if !bytes.Equal(body, compressed) {
t.Fatalf("Unexpected compress body: %q", body)
}
}
func TestWSCompressionWithPartialWrite(t *testing.T) {
payload := "This is the content of a message that will be compresseddddddddddddddddddddd."
msgProto := fmt.Sprintf("MSG foo 1 %d\r\n%s\r\n", len(payload), payload)
o := testWSOptions()
o.Websocket.Compression = true
s := RunServer(o)
defer s.Shutdown()
c, br := testWSCreateClient(t, true, false, o.Websocket.Host, o.Websocket.Port)
defer c.Close()
proto := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, []byte("SUB foo 1\r\nPING\r\n"))
c.Write(proto)
l := testWSReadFrame(t, br)
if !bytes.Equal(l, []byte(pongProto)) {
t.Fatalf("Expected PONG, got %q", l)
}
pingPayload := []byte("my ping")
pingFromWSClient := testWSCreateClientMsg(wsPingMessage, 1, true, false, pingPayload)
var wc *testWSWrappedConn
var ws *client
s.mu.Lock()
for _, c := range s.clients {
ws = c
c.mu.Lock()
wc = &testWSWrappedConn{
Conn: c.nc,
buf: &bytes.Buffer{},
}
c.nc = wc
c.mu.Unlock()
break
}
s.mu.Unlock()
wc.mu.Lock()
wc.partial = true
wc.mu.Unlock()
nc := natsConnect(t, s.ClientURL())
defer nc.Close()
expected := &bytes.Buffer{}
for i := 0; i < 10; i++ {
if i > 0 {
time.Sleep(10 * time.Millisecond)
}
expected.Write([]byte(msgProto))
natsPub(t, nc, "foo", []byte(payload))
if i == 1 {
c.Write(pingFromWSClient)
}
}
var gotPingResponse bool
res := &bytes.Buffer{}
for total := 0; total < 10*len(msgProto); {
l := testWSReadFrame(t, br)
if bytes.Equal(l, pingPayload) {
gotPingResponse = true
} else {
n, _ := res.Write(l)
total += n
}
}
if !bytes.Equal(expected.Bytes(), res.Bytes()) {
t.Fatalf("Unexpected result: %q", res)
}
if !gotPingResponse {
t.Fatal("Did not get the ping response")
}
checkFor(t, time.Second, 15*time.Millisecond, func() error {
ws.mu.Lock()
pb := ws.out.pb
wf := ws.ws.frames
fs := ws.ws.fs
ws.mu.Unlock()
if pb != 0 || len(wf) != 0 || fs != 0 {
return fmt.Errorf("Expected pb, wf and fs to be 0, got %v, %v, %v", pb, wf, fs)
}
return nil
})
}
func TestWSCompressionFrameSizeLimit(t *testing.T) {
opts := testWSOptions()
opts.MaxPending = MAX_PENDING_SIZE
s := &Server{opts: opts}
c := &client{srv: s, ws: &websocket{compress: true, browser: true}}
c.initClient()
// uncompressedPayload := []byte("abcdefghijklmnopqrstuvwxyz")
uncompressedPayload := make([]byte, 2*wsFrameSizeForBrowsers)
for i := 0; i < len(uncompressedPayload); i++ {
uncompressedPayload[i] = byte(rand.Intn(256))
}
c.mu.Lock()
c.out.nb = append(net.Buffers(nil), uncompressedPayload)
nb, _ := c.collapsePtoNB()
c.mu.Unlock()
bb := &bytes.Buffer{}
for i, b := range nb {
// frame header buffer are always very small. The payload should not be more
// than 10 bytes since that is what we passed as the limit.
if len(b) > wsFrameSizeForBrowsers {
t.Fatalf("Frame size too big: %v (%q)", len(b), b)
}
// Check frame headers for the proper formatting.
if i%2 == 1 {
bb.Write(b)
}
}
buf := bb.Bytes()
buf = append(buf, 0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff)
dbr := bytes.NewBuffer(buf)
d := flate.NewReader(dbr)
uncompressed, err := ioutil.ReadAll(d)
if err != nil {
t.Fatalf("Error reading frame: %v", err)
}
if !bytes.Equal(uncompressed, uncompressedPayload) {
t.Fatalf("Unexpected uncomressed data: %q", uncompressed)
}
}
func TestWSBasicAuth(t *testing.T) {
for _, test := range []struct {
name string
opts func() *Options
user string
pass string
err string
}{
{
"top level auth, no override, wrong u/p",
func() *Options {
o := testWSOptions()
o.Username = "normal"
o.Password = "client"
return o
},
"websocket", "client", "-ERR 'Authorization Violation'",
},
{
"top level auth, no override, correct u/p",
func() *Options {
o := testWSOptions()
o.Username = "normal"
o.Password = "client"
return o
},
"normal", "client", "",
},
{
"no top level auth, ws auth, wrong u/p",
func() *Options {
o := testWSOptions()
o.Websocket.Username = "websocket"
o.Websocket.Password = "client"
return o
},
"normal", "client", "-ERR 'Authorization Violation'",
},
{
"no top level auth, ws auth, correct u/p",
func() *Options {
o := testWSOptions()
o.Websocket.Username = "websocket"
o.Websocket.Password = "client"
return o
},
"websocket", "client", "",
},
{
"top level auth, ws override, wrong u/p",
func() *Options {
o := testWSOptions()
o.Username = "normal"
o.Password = "client"
o.Websocket.Username = "websocket"
o.Websocket.Password = "client"
return o
},
"normal", "client", "-ERR 'Authorization Violation'",
},
{
"top level auth, ws override, correct u/p",
func() *Options {
o := testWSOptions()
o.Username = "normal"
o.Password = "client"
o.Websocket.Username = "websocket"
o.Websocket.Password = "client"
return o
},
"websocket", "client", "",
},
} {
t.Run(test.name, func(t *testing.T) {
o := test.opts()
s := RunServer(o)
defer s.Shutdown()
wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
defer wsc.Close()
connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n",
test.user, test.pass)
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
if _, err := wsc.Write(wsmsg); err != nil {
t.Fatalf("Error sending message: %v", err)
}
msg := testWSReadFrame(t, br)
if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
t.Fatalf("Expected to receive PONG, got %q", msg)
} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
t.Fatalf("Expected to receive %q, got %q", test.err, msg)
}
})
}
}
func TestWSAuthTimeout(t *testing.T) {
for _, test := range []struct {
name string
at float64
wat float64
err string
}{
{"use top-level auth timeout", 10.0, 0.0, ""},
{"use websocket auth timeout", 10.0, 0.05, "-ERR 'Authentication Timeout'"},
} {
t.Run(test.name, func(t *testing.T) {
o := testWSOptions()
o.AuthTimeout = test.at
o.Websocket.Username = "websocket"
o.Websocket.Password = "client"
o.Websocket.AuthTimeout = test.wat
s := RunServer(o)
defer s.Shutdown()
wsc, br, l := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
defer wsc.Close()
var info serverInfo
json.Unmarshal([]byte(l[5:]), &info)
// Make sure that we are told that auth is required.
if !info.AuthRequired {
t.Fatalf("Expected auth required, was not: %q", l)
}
start := time.Now()
// Wait before sending connect
time.Sleep(100 * time.Millisecond)
connectProto := "CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"websocket\",\"pass\":\"client\"}\r\nPING\r\n"
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
if _, err := wsc.Write(wsmsg); err != nil {
t.Fatalf("Error sending message: %v", err)
}
msg := testWSReadFrame(t, br)
if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
t.Fatalf("Expected to receive %q error, got %q", test.err, msg)
} else if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
t.Fatalf("Unexpected error: %q", msg)
}
if dur := time.Since(start); dur > time.Second {
t.Fatalf("Too long to get timeout error: %v", dur)
}
})
}
}
func TestWSTokenAuth(t *testing.T) {
for _, test := range []struct {
name string
opts func() *Options
token string
err string
}{
{
"top level auth, no override, wrong token",
func() *Options {
o := testWSOptions()
o.Authorization = "goodtoken"
return o
},
"badtoken", "-ERR 'Authorization Violation'",
},
{
"top level auth, no override, correct token",
func() *Options {
o := testWSOptions()
o.Authorization = "goodtoken"
return o
},
"goodtoken", "",
},
{
"no top level auth, ws auth, wrong token",
func() *Options {
o := testWSOptions()
o.Websocket.Token = "goodtoken"
return o
},
"badtoken", "-ERR 'Authorization Violation'",
},
{
"no top level auth, ws auth, correct token",
func() *Options {
o := testWSOptions()
o.Websocket.Token = "goodtoken"
return o
},
"goodtoken", "",
},
{
"top level auth, ws override, wrong token",
func() *Options {
o := testWSOptions()
o.Authorization = "clienttoken"
o.Websocket.Token = "websockettoken"
return o
},
"clienttoken", "-ERR 'Authorization Violation'",
},
{
"top level auth, ws override, correct token",
func() *Options {
o := testWSOptions()
o.Authorization = "clienttoken"
o.Websocket.Token = "websockettoken"
return o
},
"websockettoken", "",
},
} {
t.Run(test.name, func(t *testing.T) {
o := test.opts()
s := RunServer(o)
defer s.Shutdown()
wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
defer wsc.Close()
connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"auth_token\":\"%s\"}\r\nPING\r\n",
test.token)
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
if _, err := wsc.Write(wsmsg); err != nil {
t.Fatalf("Error sending message: %v", err)
}
msg := testWSReadFrame(t, br)
if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
t.Fatalf("Expected to receive PONG, got %q", msg)
} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
t.Fatalf("Expected to receive %q, got %q", test.err, msg)
}
})
}
}
func TestWSUsersAuth(t *testing.T) {
for _, test := range []struct {
name string
opts func() *Options
user string
pass string
err string
}{
{
"top level auth, no override, wrong user",
func() *Options {
o := testWSOptions()
o.Users = []*User{
&User{Username: "normal1", Password: "client1"},
&User{Username: "normal2", Password: "client2"},
}
return o
},
"websocket", "client", "-ERR 'Authorization Violation'",
},
{
"top level auth, no override, correct user",
func() *Options {
o := testWSOptions()
o.Users = []*User{
&User{Username: "normal1", Password: "client1"},
&User{Username: "normal2", Password: "client2"},
}
return o
},
"normal2", "client2", "",
},
{
"no top level auth, ws auth, wrong user",
func() *Options {
o := testWSOptions()
o.Websocket.Users = []*User{
&User{Username: "websocket1", Password: "client1"},
&User{Username: "websocket2", Password: "client2"},
}
return o
},
"websocket", "client", "-ERR 'Authorization Violation'",
},
{
"no top level auth, ws auth, correct token",
func() *Options {
o := testWSOptions()
o.Websocket.Users = []*User{
&User{Username: "websocket1", Password: "client1"},
&User{Username: "websocket2", Password: "client2"},
}
return o
},
"websocket1", "client1", "",
},
{
"top level auth, ws override, wrong user",
func() *Options {
o := testWSOptions()
o.Users = []*User{
&User{Username: "normal1", Password: "client1"},
&User{Username: "normal2", Password: "client2"},
}
o.Websocket.Users = []*User{
&User{Username: "websocket1", Password: "client1"},
&User{Username: "websocket2", Password: "client2"},
}
return o
},
"normal2", "client2", "-ERR 'Authorization Violation'",
},
{
"top level auth, ws override, correct token",
func() *Options {
o := testWSOptions()
o.Users = []*User{
&User{Username: "normal1", Password: "client1"},
&User{Username: "normal2", Password: "client2"},
}
o.Websocket.Users = []*User{
&User{Username: "websocket1", Password: "client1"},
&User{Username: "websocket2", Password: "client2"},
}
return o
},
"websocket2", "client2", "",
},
} {
t.Run(test.name, func(t *testing.T) {
o := test.opts()
s := RunServer(o)
defer s.Shutdown()
wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
defer wsc.Close()
connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n",
test.user, test.pass)
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
if _, err := wsc.Write(wsmsg); err != nil {
t.Fatalf("Error sending message: %v", err)
}
msg := testWSReadFrame(t, br)
if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
t.Fatalf("Expected to receive PONG, got %q", msg)
} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
t.Fatalf("Expected to receive %q, got %q", test.err, msg)
}
})
}
}
func TestWSNoAuthUserValidation(t *testing.T) {
// It is illegal to configure a websocket's NoAuthUser if websocket's
// auth config does not have a matching User.
// Create regular clients's User to make sure that we fail even if
// the websocket's NoAuthUser is found in opts.Users.
o := testWSOptions()
o.Users = []*User{&User{Username: "user", Password: "pwd"}}
o.Websocket.NoAuthUser = "user"
if _, err := NewServer(o); err == nil || !strings.Contains(err.Error(), "users are not") {
t.Fatalf("Expected error saying that users are not configured, got %v", err)
}
o.Websocket.Users = []*User{&User{Username: "wsuser", Password: "pwd"}}
o.Websocket.NoAuthUser = "notfound"
if _, err := NewServer(o); err == nil || !strings.Contains(err.Error(), "not found") {
t.Fatalf("Expected error saying no auth user not found, got %v", err)
}
}
func TestWSNoAuthUser(t *testing.T) {
for _, test := range []struct {
name string
createWSUsers bool
useAuth bool
expectedUser string
expectedAcc string
}{
{"no override, no user provided", false, false, "noauth", "normal"},
{"no override, user povided", false, true, "user", "normal"},
{"override, no user provided", true, false, "wsnoauth", "websocket"},
{"override, user provided", true, true, "wsuser", "websocket"},
} {
t.Run(test.name, func(t *testing.T) {
o := testWSOptions()
normalAcc := NewAccount("normal")
websocketAcc := NewAccount("websocket")
o.Accounts = []*Account{normalAcc, websocketAcc}
o.Users = []*User{
&User{Username: "noauth", Password: "pwd", Account: normalAcc},
&User{Username: "user", Password: "pwd", Account: normalAcc},
}
o.NoAuthUser = "noauth"
if test.createWSUsers {
o.Websocket.Users = []*User{
&User{Username: "wsnoauth", Password: "pwd", Account: websocketAcc},
&User{Username: "wsuser", Password: "pwd", Account: websocketAcc},
}
o.Websocket.NoAuthUser = "wsnoauth"
}
s := RunServer(o)
defer s.Shutdown()
wsc, br, l := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
defer wsc.Close()
var info serverInfo
json.Unmarshal([]byte(l[5:]), &info)
var connectProto string
if test.useAuth {
connectProto = fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"pwd\"}\r\nPING\r\n",
test.expectedUser)
} else {
connectProto = "CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n"
}
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
if _, err := wsc.Write(wsmsg); err != nil {
t.Fatalf("Error sending message: %v", err)
}
msg := testWSReadFrame(t, br)
if !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
t.Fatalf("Unexpected error: %q", msg)
}
c := s.getClient(info.CID)
c.mu.Lock()
uname := c.opts.Username
aname := c.acc.GetName()
c.mu.Unlock()
if uname != test.expectedUser {
t.Fatalf("Expected selected user to be %q, got %q", test.expectedUser, uname)
}
if aname != test.expectedAcc {
t.Fatalf("Expected selected account to be %q, got %q", test.expectedAcc, aname)
}
})
}
}
func TestWSNkeyAuth(t *testing.T) {
nkp, _ := nkeys.CreateUser()
pub, _ := nkp.PublicKey()
wsnkp, _ := nkeys.CreateUser()
wspub, _ := wsnkp.PublicKey()
for _, test := range []struct {
name string
opts func() *Options
nkey string
kp nkeys.KeyPair
err string
}{
{
"top level auth, no override, wrong nkey",
func() *Options {
o := testWSOptions()
o.Nkeys = []*NkeyUser{&NkeyUser{Nkey: pub}}
return o
},
wspub, wsnkp, "-ERR 'Authorization Violation'",
},
{
"top level auth, no override, correct nkey",
func() *Options {
o := testWSOptions()
o.Nkeys = []*NkeyUser{&NkeyUser{Nkey: pub}}
return o
},
pub, nkp, "",
},
{
"no top level auth, ws auth, wrong nkey",
func() *Options {
o := testWSOptions()
o.Websocket.Nkeys = []*NkeyUser{&NkeyUser{Nkey: wspub}}
return o
},
pub, nkp, "-ERR 'Authorization Violation'",
},
{
"no top level auth, ws auth, correct nkey",
func() *Options {
o := testWSOptions()
o.Websocket.Nkeys = []*NkeyUser{&NkeyUser{Nkey: wspub}}
return o
},
wspub, wsnkp, "",
},
{
"top level auth, ws override, wrong nkey",
func() *Options {
o := testWSOptions()
o.Nkeys = []*NkeyUser{&NkeyUser{Nkey: pub}}
o.Websocket.Nkeys = []*NkeyUser{&NkeyUser{Nkey: wspub}}
return o
},
pub, nkp, "-ERR 'Authorization Violation'",
},
{
"top level auth, ws override, correct nkey",
func() *Options {
o := testWSOptions()
o.Nkeys = []*NkeyUser{&NkeyUser{Nkey: pub}}
o.Websocket.Nkeys = []*NkeyUser{&NkeyUser{Nkey: wspub}}
return o
},
wspub, wsnkp, "",
},
} {
t.Run(test.name, func(t *testing.T) {
o := test.opts()
s := RunServer(o)
defer s.Shutdown()
wsc, br, infoMsg := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
defer wsc.Close()
// Sign Nonce
var info nonceInfo
json.Unmarshal([]byte(infoMsg[5:]), &info)
sigraw, _ := test.kp.Sign([]byte(info.Nonce))
sig := base64.RawURLEncoding.EncodeToString(sigraw)
connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"nkey\":\"%s\",\"sig\":\"%s\"}\r\nPING\r\n", test.nkey, sig)
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
if _, err := wsc.Write(wsmsg); err != nil {
t.Fatalf("Error sending message: %v", err)
}
msg := testWSReadFrame(t, br)
if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
t.Fatalf("Expected to receive PONG, got %q", msg)
} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
t.Fatalf("Expected to receive %q, got %q", test.err, msg)
}
})
}
}
// ==================================================================
// = Benchmark tests
// ==================================================================
const testWSBenchSubject = "a"
var ch = []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@$#%^&*()")
func sizedString(sz int) string {
b := make([]byte, sz)
for i := range b {
b[i] = ch[rand.Intn(len(ch))]
}
return string(b)
}
func sizedStringForCompression(sz int) string {
b := make([]byte, sz)
c := byte(0)
s := 0
for i := range b {
if s%20 == 0 {
c = ch[rand.Intn(len(ch))]
}
b[i] = c
}
return string(b)
}
func testWSFlushConn(b *testing.B, compress bool, c net.Conn, br *bufio.Reader) {
buf := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, []byte(pingProto))
c.Write(buf)
c.SetReadDeadline(time.Now().Add(5 * time.Second))
res := testWSReadFrame(b, br)
c.SetReadDeadline(time.Time{})
if !bytes.HasPrefix(res, []byte(pongProto)) {
b.Fatalf("Failed read of PONG: %s\n", res)
}
}
func wsBenchPub(b *testing.B, numPubs int, compress bool, payload string) {
b.StopTimer()
opts := testWSOptions()
opts.Websocket.Compression = compress
s := RunServer(opts)
defer s.Shutdown()
n := b.N
extra := 0
pubProto := []byte(fmt.Sprintf("PUB %s %d\r\n%s\r\n", testWSBenchSubject, len(payload), payload))
singleOpBuf := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, pubProto)
// Simulate client that would buffer messages before framing/sending.
// Figure out how many we can fit in one frame based on b.N and length of pubProto
const bufSize = 32768
tmpa := [bufSize]byte{}
tmp := tmpa[:0]
pb := 0
for i := 0; i < b.N; i++ {
tmp = append(tmp, pubProto...)
pb++
if len(tmp) >= bufSize {
break
}
}
sendBuf := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, tmp)
n = b.N / pb
extra = b.N - (n * pb)
wg := sync.WaitGroup{}
wg.Add(numPubs)
type pub struct {
c net.Conn
br *bufio.Reader
bw *bufio.Writer
}
var pubs []pub
for i := 0; i < numPubs; i++ {
wsc, br := testWSCreateClient(b, compress, false, opts.Websocket.Host, opts.Websocket.Port)
defer wsc.Close()
bw := bufio.NewWriterSize(wsc, bufSize)
pubs = append(pubs, pub{wsc, br, bw})
}
// Average the amount of bytes sent by iteration
avg := len(sendBuf) / pb
if extra > 0 {
avg += len(singleOpBuf)
avg /= 2
}
b.SetBytes(int64(numPubs * avg))
b.StartTimer()
for i := 0; i < numPubs; i++ {
p := pubs[i]
go func(p pub) {
defer wg.Done()
for i := 0; i < n; i++ {
p.bw.Write(sendBuf)
}
for i := 0; i < extra; i++ {
p.bw.Write(singleOpBuf)
}
p.bw.Flush()
testWSFlushConn(b, compress, p.c, p.br)
}(p)
}
wg.Wait()
b.StopTimer()
}
func Benchmark_WS_Pubx1_CN_____0b(b *testing.B) {
wsBenchPub(b, 1, false, "")
}
func Benchmark_WS_Pubx1_CY_____0b(b *testing.B) {
wsBenchPub(b, 1, true, "")
}
func Benchmark_WS_Pubx1_CN___128b(b *testing.B) {
s := sizedString(128)
wsBenchPub(b, 1, false, s)
}
func Benchmark_WS_Pubx1_CY___128b(b *testing.B) {
s := sizedStringForCompression(128)
wsBenchPub(b, 1, true, s)
}
func Benchmark_WS_Pubx1_CN__1024b(b *testing.B) {
s := sizedString(1024)
wsBenchPub(b, 1, false, s)
}
func Benchmark_WS_Pubx1_CY__1024b(b *testing.B) {
s := sizedStringForCompression(1024)
wsBenchPub(b, 1, true, s)
}
func Benchmark_WS_Pubx1_CN__4096b(b *testing.B) {
s := sizedString(4 * 1024)
wsBenchPub(b, 1, false, s)
}
func Benchmark_WS_Pubx1_CY__4096b(b *testing.B) {
s := sizedStringForCompression(4 * 1024)
wsBenchPub(b, 1, true, s)
}
func Benchmark_WS_Pubx1_CN__8192b(b *testing.B) {
s := sizedString(8 * 1024)
wsBenchPub(b, 1, false, s)
}
func Benchmark_WS_Pubx1_CY__8192b(b *testing.B) {
s := sizedStringForCompression(8 * 1024)
wsBenchPub(b, 1, true, s)
}
func Benchmark_WS_Pubx1_CN_32768b(b *testing.B) {
s := sizedString(32 * 1024)
wsBenchPub(b, 1, false, s)
}
func Benchmark_WS_Pubx1_CY_32768b(b *testing.B) {
s := sizedStringForCompression(32 * 1024)
wsBenchPub(b, 1, true, s)
}
func Benchmark_WS_Pubx5_CN_____0b(b *testing.B) {
wsBenchPub(b, 5, false, "")
}
func Benchmark_WS_Pubx5_CY_____0b(b *testing.B) {
wsBenchPub(b, 5, true, "")
}
func Benchmark_WS_Pubx5_CN___128b(b *testing.B) {
s := sizedString(128)
wsBenchPub(b, 5, false, s)
}
func Benchmark_WS_Pubx5_CY___128b(b *testing.B) {
s := sizedStringForCompression(128)
wsBenchPub(b, 5, true, s)
}
func Benchmark_WS_Pubx5_CN__1024b(b *testing.B) {
s := sizedString(1024)
wsBenchPub(b, 5, false, s)
}
func Benchmark_WS_Pubx5_CY__1024b(b *testing.B) {
s := sizedStringForCompression(1024)
wsBenchPub(b, 5, true, s)
}
func Benchmark_WS_Pubx5_CN__4096b(b *testing.B) {
s := sizedString(4 * 1024)
wsBenchPub(b, 5, false, s)
}
func Benchmark_WS_Pubx5_CY__4096b(b *testing.B) {
s := sizedStringForCompression(4 * 1024)
wsBenchPub(b, 5, true, s)
}
func Benchmark_WS_Pubx5_CN__8192b(b *testing.B) {
s := sizedString(8 * 1024)
wsBenchPub(b, 5, false, s)
}
func Benchmark_WS_Pubx5_CY__8192b(b *testing.B) {
s := sizedStringForCompression(8 * 1024)
wsBenchPub(b, 5, true, s)
}
func Benchmark_WS_Pubx5_CN_32768b(b *testing.B) {
s := sizedString(32 * 1024)
wsBenchPub(b, 5, false, s)
}
func Benchmark_WS_Pubx5_CY_32768b(b *testing.B) {
s := sizedStringForCompression(32 * 1024)
wsBenchPub(b, 5, true, s)
}
func wsBenchSub(b *testing.B, numSubs int, compress bool, payload string) {
b.StopTimer()
opts := testWSOptions()
opts.Websocket.Compression = compress
s := RunServer(opts)
defer s.Shutdown()
var subs []*bufio.Reader
for i := 0; i < numSubs; i++ {
wsc, br := testWSCreateClient(b, compress, false, opts.Websocket.Host, opts.Websocket.Port)
defer wsc.Close()
subProto := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress,
[]byte(fmt.Sprintf("SUB %s 1\r\nPING\r\n", testWSBenchSubject)))
wsc.Write(subProto)
// Waiting for PONG
testWSReadFrame(b, br)
subs = append(subs, br)
}
wg := sync.WaitGroup{}
wg.Add(numSubs)
// Use regular NATS client to publish messages
nc := natsConnect(b, s.ClientURL())
defer nc.Close()
b.StartTimer()
for i := 0; i < numSubs; i++ {
br := subs[i]
go func(br *bufio.Reader) {
defer wg.Done()
for count := 0; count < b.N; {
msgs := testWSReadFrame(b, br)
count += bytes.Count(msgs, []byte("MSG "))
}
}(br)
}
for i := 0; i < b.N; i++ {
natsPub(b, nc, testWSBenchSubject, []byte(payload))
}
wg.Wait()
b.StopTimer()
}
func Benchmark_WS_Subx1_CN_____0b(b *testing.B) {
wsBenchSub(b, 1, false, "")
}
func Benchmark_WS_Subx1_CY_____0b(b *testing.B) {
wsBenchSub(b, 1, true, "")
}
func Benchmark_WS_Subx1_CN___128b(b *testing.B) {
s := sizedString(128)
wsBenchSub(b, 1, false, s)
}
func Benchmark_WS_Subx1_CY___128b(b *testing.B) {
s := sizedStringForCompression(128)
wsBenchSub(b, 1, true, s)
}
func Benchmark_WS_Subx1_CN__1024b(b *testing.B) {
s := sizedString(1024)
wsBenchSub(b, 1, false, s)
}
func Benchmark_WS_Subx1_CY__1024b(b *testing.B) {
s := sizedStringForCompression(1024)
wsBenchSub(b, 1, true, s)
}
func Benchmark_WS_Subx1_CN__4096b(b *testing.B) {
s := sizedString(4096)
wsBenchSub(b, 1, false, s)
}
func Benchmark_WS_Subx1_CY__4096b(b *testing.B) {
s := sizedStringForCompression(4096)
wsBenchSub(b, 1, true, s)
}
func Benchmark_WS_Subx1_CN__8192b(b *testing.B) {
s := sizedString(8192)
wsBenchSub(b, 1, false, s)
}
func Benchmark_WS_Subx1_CY__8192b(b *testing.B) {
s := sizedStringForCompression(8192)
wsBenchSub(b, 1, true, s)
}
func Benchmark_WS_Subx1_CN_32768b(b *testing.B) {
s := sizedString(32768)
wsBenchSub(b, 1, false, s)
}
func Benchmark_WS_Subx1_CY_32768b(b *testing.B) {
s := sizedStringForCompression(32768)
wsBenchSub(b, 1, true, s)
}
func Benchmark_WS_Subx5_CN_____0b(b *testing.B) {
wsBenchSub(b, 5, false, "")
}
func Benchmark_WS_Subx5_CY_____0b(b *testing.B) {
wsBenchSub(b, 5, true, "")
}
func Benchmark_WS_Subx5_CN___128b(b *testing.B) {
s := sizedString(128)
wsBenchSub(b, 5, false, s)
}
func Benchmark_WS_Subx5_CY___128b(b *testing.B) {
s := sizedStringForCompression(128)
wsBenchSub(b, 5, true, s)
}
func Benchmark_WS_Subx5_CN__1024b(b *testing.B) {
s := sizedString(1024)
wsBenchSub(b, 5, false, s)
}
func Benchmark_WS_Subx5_CY__1024b(b *testing.B) {
s := sizedStringForCompression(1024)
wsBenchSub(b, 5, true, s)
}
func Benchmark_WS_Subx5_CN__4096b(b *testing.B) {
s := sizedString(4096)
wsBenchSub(b, 5, false, s)
}
func Benchmark_WS_Subx5_CY__4096b(b *testing.B) {
s := sizedStringForCompression(4096)
wsBenchSub(b, 5, true, s)
}
func Benchmark_WS_Subx5_CN__8192b(b *testing.B) {
s := sizedString(8192)
wsBenchSub(b, 5, false, s)
}
func Benchmark_WS_Subx5_CY__8192b(b *testing.B) {
s := sizedStringForCompression(8192)
wsBenchSub(b, 5, true, s)
}
func Benchmark_WS_Subx5_CN_32768b(b *testing.B) {
s := sizedString(32768)
wsBenchSub(b, 5, false, s)
}
func Benchmark_WS_Subx5_CY_32768b(b *testing.B) {
s := sizedStringForCompression(32768)
wsBenchSub(b, 5, true, s)
}