mirror of
https://github.com/gogrlx/nats-server.git
synced 2026-04-02 03:38:42 -07:00
This hides the check on "c.mqtt != nil" or "c.ws != nil". Added some tests. Signed-off-by: Ivan Kozlovic <ivan@synadia.com>
3929 lines
109 KiB
Go
3929 lines
109 KiB
Go
// Copyright 2020 The NATS Authors
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package server
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"compress/flate"
|
|
"crypto/tls"
|
|
"encoding/base64"
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"math/rand"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"reflect"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/nats-io/jwt/v2"
|
|
"github.com/nats-io/nkeys"
|
|
)
|
|
|
|
type testReader struct {
|
|
buf []byte
|
|
pos int
|
|
max int
|
|
err error
|
|
}
|
|
|
|
func (tr *testReader) Read(p []byte) (int, error) {
|
|
if tr.err != nil {
|
|
return 0, tr.err
|
|
}
|
|
n := len(tr.buf) - tr.pos
|
|
if n == 0 {
|
|
return 0, nil
|
|
}
|
|
if n > cap(p) {
|
|
n = cap(p)
|
|
}
|
|
if tr.max > 0 && n > tr.max {
|
|
n = tr.max
|
|
}
|
|
copy(p, tr.buf[tr.pos:tr.pos+n])
|
|
tr.pos += n
|
|
return n, nil
|
|
}
|
|
|
|
func TestWSGet(t *testing.T) {
|
|
rb := []byte("012345")
|
|
|
|
tr := &testReader{buf: []byte("6789")}
|
|
|
|
for _, test := range []struct {
|
|
name string
|
|
pos int
|
|
needed int
|
|
newpos int
|
|
trmax int
|
|
result string
|
|
reterr bool
|
|
}{
|
|
{"fromrb1", 0, 3, 3, 4, "012", false}, // Partial from read buffer
|
|
{"fromrb2", 3, 2, 5, 4, "34", false}, // Partial from read buffer
|
|
{"fromrb3", 5, 1, 6, 4, "5", false}, // Partial from read buffer
|
|
{"fromtr1", 4, 4, 6, 4, "4567", false}, // Partial from read buffer + some of ioReader
|
|
{"fromtr2", 4, 6, 6, 4, "456789", false}, // Partial from read buffer + all of ioReader
|
|
{"fromtr3", 4, 6, 6, 2, "456789", false}, // Partial from read buffer + all of ioReader with several reads
|
|
{"fromtr4", 4, 6, 6, 2, "", true}, // ioReader returns error
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
tr.pos = 0
|
|
tr.max = test.trmax
|
|
if test.reterr {
|
|
tr.err = fmt.Errorf("on purpose")
|
|
}
|
|
res, np, err := wsGet(tr, rb, test.pos, test.needed)
|
|
if test.reterr {
|
|
if err == nil {
|
|
t.Fatalf("Expected error, got none")
|
|
}
|
|
if err.Error() != "on purpose" {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
if np != 0 || res != nil {
|
|
t.Fatalf("Unexpected returned values: res=%v n=%v", res, np)
|
|
}
|
|
return
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("Error on get: %v", err)
|
|
}
|
|
if np != test.newpos {
|
|
t.Fatalf("Expected pos=%v, got %v", test.newpos, np)
|
|
}
|
|
if string(res) != test.result {
|
|
t.Fatalf("Invalid returned content: %s", res)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWSIsControlFrame(t *testing.T) {
|
|
for _, test := range []struct {
|
|
name string
|
|
code wsOpCode
|
|
isControl bool
|
|
}{
|
|
{"binary", wsBinaryMessage, false},
|
|
{"text", wsTextMessage, false},
|
|
{"ping", wsPingMessage, true},
|
|
{"pong", wsPongMessage, true},
|
|
{"close", wsCloseMessage, true},
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
if res := wsIsControlFrame(test.code); res != test.isControl {
|
|
t.Fatalf("Expected %q isControl to be %v, got %v", test.name, test.isControl, res)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func testWSSimpleMask(key, buf []byte) {
|
|
for i := 0; i < len(buf); i++ {
|
|
buf[i] ^= key[i&3]
|
|
}
|
|
}
|
|
|
|
func TestWSUnmask(t *testing.T) {
|
|
key := []byte{1, 2, 3, 4}
|
|
orgBuf := []byte("this is a clear text")
|
|
|
|
mask := func() []byte {
|
|
t.Helper()
|
|
buf := append([]byte(nil), orgBuf...)
|
|
testWSSimpleMask(key, buf)
|
|
// First ensure that the content is masked.
|
|
if bytes.Equal(buf, orgBuf) {
|
|
t.Fatalf("Masking did not do anything: %q", buf)
|
|
}
|
|
return buf
|
|
}
|
|
|
|
ri := &wsReadInfo{}
|
|
ri.init()
|
|
copy(ri.mkey[:], key)
|
|
|
|
buf := mask()
|
|
// Unmask in one call
|
|
ri.unmask(buf)
|
|
if !bytes.Equal(buf, orgBuf) {
|
|
t.Fatalf("Unmask error, expected %q, got %q", orgBuf, buf)
|
|
}
|
|
|
|
// Unmask in multiple calls
|
|
buf = mask()
|
|
ri.mkpos = 0
|
|
ri.unmask(buf[:3])
|
|
ri.unmask(buf[3:11])
|
|
ri.unmask(buf[11:])
|
|
if !bytes.Equal(buf, orgBuf) {
|
|
t.Fatalf("Unmask error, expected %q, got %q", orgBuf, buf)
|
|
}
|
|
}
|
|
|
|
func TestWSCreateCloseMessage(t *testing.T) {
|
|
for _, test := range []struct {
|
|
name string
|
|
status int
|
|
psize int
|
|
truncated bool
|
|
}{
|
|
{"fits", wsCloseStatusInternalSrvError, 10, false},
|
|
{"truncated", wsCloseStatusProtocolError, wsMaxControlPayloadSize + 10, true},
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
payload := make([]byte, test.psize)
|
|
for i := 0; i < len(payload); i++ {
|
|
payload[i] = byte('A' + (i % 26))
|
|
}
|
|
res := wsCreateCloseMessage(test.status, string(payload))
|
|
if status := binary.BigEndian.Uint16(res[:2]); int(status) != test.status {
|
|
t.Fatalf("Expected status to be %v, got %v", test.status, status)
|
|
}
|
|
psize := len(res) - 2
|
|
if !test.truncated {
|
|
if int(psize) != test.psize {
|
|
t.Fatalf("Expected size to be %v, got %v", test.psize, psize)
|
|
}
|
|
if !bytes.Equal(res[2:], payload) {
|
|
t.Fatalf("Unexpected result: %q", res[2:])
|
|
}
|
|
return
|
|
}
|
|
// Since the payload of a close message contains a 2 byte status, the
|
|
// actual max text size will be wsMaxControlPayloadSize-2
|
|
if int(psize) != wsMaxControlPayloadSize-2 {
|
|
t.Fatalf("Expected size to be capped to %v, got %v", wsMaxControlPayloadSize-2, psize)
|
|
}
|
|
if string(res[len(res)-3:]) != "..." {
|
|
t.Fatalf("Expected res to have `...` at the end, got %q", res[4:])
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWSCreateFrameHeader(t *testing.T) {
|
|
for _, test := range []struct {
|
|
name string
|
|
frameType wsOpCode
|
|
compressed bool
|
|
len int
|
|
}{
|
|
{"uncompressed 10", wsBinaryMessage, false, 10},
|
|
{"uncompressed 600", wsTextMessage, false, 600},
|
|
{"uncompressed 100000", wsTextMessage, false, 100000},
|
|
{"compressed 10", wsBinaryMessage, true, 10},
|
|
{"compressed 600", wsBinaryMessage, true, 600},
|
|
{"compressed 100000", wsTextMessage, true, 100000},
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
res := wsCreateFrameHeader(test.compressed, test.frameType, test.len)
|
|
// The server is always sending the message has a single frame,
|
|
// so the "final" bit should be set.
|
|
expected := byte(test.frameType) | wsFinalBit
|
|
if test.compressed {
|
|
expected |= wsRsv1Bit
|
|
}
|
|
if b := res[0]; b != expected {
|
|
t.Fatalf("Expected first byte to be %v, got %v", expected, b)
|
|
}
|
|
switch {
|
|
case test.len <= 125:
|
|
if len(res) != 2 {
|
|
t.Fatalf("Frame len should be 2, got %v", len(res))
|
|
}
|
|
if res[1] != byte(test.len) {
|
|
t.Fatalf("Expected len to be in second byte and be %v, got %v", test.len, res[1])
|
|
}
|
|
case test.len < 65536:
|
|
// 1+1+2
|
|
if len(res) != 4 {
|
|
t.Fatalf("Frame len should be 4, got %v", len(res))
|
|
}
|
|
if res[1] != 126 {
|
|
t.Fatalf("Second byte value should be 126, got %v", res[1])
|
|
}
|
|
if rl := binary.BigEndian.Uint16(res[2:]); int(rl) != test.len {
|
|
t.Fatalf("Expected len to be %v, got %v", test.len, rl)
|
|
}
|
|
default:
|
|
// 1+1+8
|
|
if len(res) != 10 {
|
|
t.Fatalf("Frame len should be 10, got %v", len(res))
|
|
}
|
|
if res[1] != 127 {
|
|
t.Fatalf("Second byte value should be 127, got %v", res[1])
|
|
}
|
|
if rl := binary.BigEndian.Uint64(res[2:]); int(rl) != test.len {
|
|
t.Fatalf("Expected len to be %v, got %v", test.len, rl)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func testWSCreateClientMsg(frameType wsOpCode, frameNum int, final, compressed bool, payload []byte) []byte {
|
|
if compressed {
|
|
buf := &bytes.Buffer{}
|
|
compressor, _ := flate.NewWriter(buf, 1)
|
|
compressor.Write(payload)
|
|
compressor.Flush()
|
|
payload = buf.Bytes()
|
|
// The last 4 bytes are dropped
|
|
payload = payload[:len(payload)-4]
|
|
}
|
|
frame := make([]byte, 14+len(payload))
|
|
if frameNum == 1 {
|
|
frame[0] = byte(frameType)
|
|
}
|
|
if final {
|
|
frame[0] |= wsFinalBit
|
|
}
|
|
if compressed {
|
|
frame[0] |= wsRsv1Bit
|
|
}
|
|
pos := 1
|
|
lenPayload := len(payload)
|
|
switch {
|
|
case lenPayload <= 125:
|
|
frame[pos] = byte(lenPayload) | wsMaskBit
|
|
pos++
|
|
case lenPayload < 65536:
|
|
frame[pos] = 126 | wsMaskBit
|
|
binary.BigEndian.PutUint16(frame[2:], uint16(lenPayload))
|
|
pos += 3
|
|
default:
|
|
frame[1] = 127 | wsMaskBit
|
|
binary.BigEndian.PutUint64(frame[2:], uint64(lenPayload))
|
|
pos += 9
|
|
}
|
|
key := []byte{1, 2, 3, 4}
|
|
copy(frame[pos:], key)
|
|
pos += 4
|
|
copy(frame[pos:], payload)
|
|
testWSSimpleMask(key, frame[pos:])
|
|
pos += lenPayload
|
|
return frame[:pos]
|
|
}
|
|
|
|
func testWSSetupForRead() (*client, *wsReadInfo, *testReader) {
|
|
ri := &wsReadInfo{}
|
|
ri.init()
|
|
tr := &testReader{}
|
|
opts := DefaultOptions()
|
|
opts.MaxPending = MAX_PENDING_SIZE
|
|
s := &Server{opts: opts}
|
|
c := &client{srv: s, ws: &websocket{}}
|
|
c.initClient()
|
|
return c, ri, tr
|
|
}
|
|
|
|
func TestWSReadUncompressedFrames(t *testing.T) {
|
|
c, ri, tr := testWSSetupForRead()
|
|
// Create 2 WS messages
|
|
pl1 := []byte("first message")
|
|
wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, pl1)
|
|
pl2 := []byte("second message")
|
|
wsmsg2 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, pl2)
|
|
// Add both in single buffer
|
|
orgrb := append([]byte(nil), wsmsg1...)
|
|
orgrb = append(orgrb, wsmsg2...)
|
|
|
|
rb := append([]byte(nil), orgrb...)
|
|
bufs, err := c.wsRead(ri, tr, rb)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
if n := len(bufs); n != 2 {
|
|
t.Fatalf("Expected 2 buffers, got %v", n)
|
|
}
|
|
if !bytes.Equal(bufs[0], pl1) {
|
|
t.Fatalf("Unexpected content for buffer 1: %s", bufs[0])
|
|
}
|
|
if !bytes.Equal(bufs[1], pl2) {
|
|
t.Fatalf("Unexpected content for buffer 2: %s", bufs[1])
|
|
}
|
|
|
|
// Now reset and try with the read buffer not containing full ws frame
|
|
c, ri, tr = testWSSetupForRead()
|
|
rb = append([]byte(nil), orgrb...)
|
|
// Frame is 1+1+4+'first message'. So say we pass with rb of 11 bytes,
|
|
// then we should get "first"
|
|
bufs, err = c.wsRead(ri, tr, rb[:11])
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
if n := len(bufs); n != 1 {
|
|
t.Fatalf("Unexpected buffer returned: %v", n)
|
|
}
|
|
if string(bufs[0]) != "first" {
|
|
t.Fatalf("Unexpected content: %q", bufs[0])
|
|
}
|
|
// Call again with more data..
|
|
bufs, err = c.wsRead(ri, tr, rb[11:32])
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
if n := len(bufs); n != 2 {
|
|
t.Fatalf("Unexpected buffer returned: %v", n)
|
|
}
|
|
if string(bufs[0]) != " message" {
|
|
t.Fatalf("Unexpected content: %q", bufs[0])
|
|
}
|
|
if string(bufs[1]) != "second " {
|
|
t.Fatalf("Unexpected content: %q", bufs[1])
|
|
}
|
|
// Call with the rest
|
|
bufs, err = c.wsRead(ri, tr, rb[32:])
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
if n := len(bufs); n != 1 {
|
|
t.Fatalf("Unexpected buffer returned: %v", n)
|
|
}
|
|
if string(bufs[0]) != "message" {
|
|
t.Fatalf("Unexpected content: %q", bufs[0])
|
|
}
|
|
}
|
|
|
|
func TestWSReadCompressedFrames(t *testing.T) {
|
|
c, ri, tr := testWSSetupForRead()
|
|
uncompressed := []byte("this is the uncompress data")
|
|
wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, uncompressed)
|
|
rb := append([]byte(nil), wsmsg1...)
|
|
// Call with some but not all of the payload
|
|
bufs, err := c.wsRead(ri, tr, rb[:10])
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
if n := len(bufs); n != 0 {
|
|
t.Fatalf("Unexpected buffer returned: %v", n)
|
|
}
|
|
// Call with the rest, only then should we get the uncompressed data.
|
|
bufs, err = c.wsRead(ri, tr, rb[10:])
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
if n := len(bufs); n != 1 {
|
|
t.Fatalf("Unexpected buffer returned: %v", n)
|
|
}
|
|
if !bytes.Equal(bufs[0], uncompressed) {
|
|
t.Fatalf("Unexpected content: %s", bufs[0])
|
|
}
|
|
// Stress the fact that we use a pool and want to make sure
|
|
// that if we get a decompressor from the pool, it is properly reset
|
|
// with the buffer to decompress.
|
|
for i := 0; i < 9; i++ {
|
|
rb = append(rb, wsmsg1...)
|
|
}
|
|
bufs, err = c.wsRead(ri, tr, rb)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
if n := len(bufs); n != 10 {
|
|
t.Fatalf("Unexpected buffer returned: %v", n)
|
|
}
|
|
}
|
|
|
|
func TestWSReadCompressedFrameCorrupted(t *testing.T) {
|
|
c, ri, tr := testWSSetupForRead()
|
|
uncompressed := []byte("this is the uncompress data")
|
|
wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, uncompressed)
|
|
copy(wsmsg1[10:], []byte{1, 2, 3, 4})
|
|
rb := append([]byte(nil), wsmsg1...)
|
|
bufs, err := c.wsRead(ri, tr, rb)
|
|
if err == nil || !strings.Contains(err.Error(), "corrupt") {
|
|
t.Fatalf("Expected error about corrupted data, got %v", err)
|
|
}
|
|
if n := len(bufs); n != 0 {
|
|
t.Fatalf("Expected no buffer, got %v", n)
|
|
}
|
|
}
|
|
|
|
func TestWSReadVariousFrameSizes(t *testing.T) {
|
|
for _, test := range []struct {
|
|
name string
|
|
size int
|
|
}{
|
|
{"tiny", 100},
|
|
{"medium", 1000},
|
|
{"large", 70000},
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
c, ri, tr := testWSSetupForRead()
|
|
uncompressed := make([]byte, test.size)
|
|
for i := 0; i < len(uncompressed); i++ {
|
|
uncompressed[i] = 'A' + byte(i%26)
|
|
}
|
|
wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, uncompressed)
|
|
rb := append([]byte(nil), wsmsg1...)
|
|
bufs, err := c.wsRead(ri, tr, rb)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
if n := len(bufs); n != 1 {
|
|
t.Fatalf("Unexpected buffer returned: %v", n)
|
|
}
|
|
if !bytes.Equal(bufs[0], uncompressed) {
|
|
t.Fatalf("Unexpected content: %s", bufs[0])
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWSReadFragmentedFrames(t *testing.T) {
|
|
c, ri, tr := testWSSetupForRead()
|
|
payloads := []string{"first", "second", "third"}
|
|
var rb []byte
|
|
for i := 0; i < len(payloads); i++ {
|
|
final := i == len(payloads)-1
|
|
frag := testWSCreateClientMsg(wsBinaryMessage, i+1, final, false, []byte(payloads[i]))
|
|
rb = append(rb, frag...)
|
|
}
|
|
bufs, err := c.wsRead(ri, tr, rb)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
if n := len(bufs); n != 3 {
|
|
t.Fatalf("Unexpected buffer returned: %v", n)
|
|
}
|
|
for i, expected := range payloads {
|
|
if string(bufs[i]) != expected {
|
|
t.Fatalf("Unexpected content for buf=%v: %s", i, bufs[i])
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestWSReadPartialFrameHeaderAtEndOfReadBuffer(t *testing.T) {
|
|
c, ri, tr := testWSSetupForRead()
|
|
msg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg1"))
|
|
msg2 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg2"))
|
|
rb := append([]byte(nil), msg1...)
|
|
rb = append(rb, msg2...)
|
|
// We will pass the first frame + the first byte of the next frame.
|
|
rbl := rb[:len(msg1)+1]
|
|
// Make the io reader return the rest of the frame
|
|
tr.buf = rb[len(msg1)+1:]
|
|
bufs, err := c.wsRead(ri, tr, rbl)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
if n := len(bufs); n != 1 {
|
|
t.Fatalf("Unexpected buffer returned: %v", n)
|
|
}
|
|
// We should not have asked to the io reader more than what is needed for reading
|
|
// the frame header. Since we had already the first byte in the read buffer,
|
|
// tr.pos should be 1(size)+4(key)=5
|
|
if tr.pos != 5 {
|
|
t.Fatalf("Expected reader pos to be 5, got %v", tr.pos)
|
|
}
|
|
}
|
|
|
|
func TestWSReadPingFrame(t *testing.T) {
|
|
for _, test := range []struct {
|
|
name string
|
|
payload []byte
|
|
}{
|
|
{"without payload", nil},
|
|
{"with payload", []byte("optional payload")},
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
c, ri, tr := testWSSetupForRead()
|
|
ping := testWSCreateClientMsg(wsPingMessage, 1, true, false, test.payload)
|
|
rb := append([]byte(nil), ping...)
|
|
bufs, err := c.wsRead(ri, tr, rb)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
if n := len(bufs); n != 0 {
|
|
t.Fatalf("Unexpected buffer returned: %v", n)
|
|
}
|
|
// A PONG should have been queued with the payload of the ping
|
|
c.mu.Lock()
|
|
nb, _ := c.collapsePtoNB()
|
|
c.mu.Unlock()
|
|
if n := len(nb); n == 0 {
|
|
t.Fatalf("Expected buffers, got %v", n)
|
|
}
|
|
if expected := 2 + len(test.payload); expected != len(nb[0]) {
|
|
t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0]))
|
|
}
|
|
b := nb[0][0]
|
|
if b&wsFinalBit == 0 {
|
|
t.Fatalf("Control frame should have been the final flag, it was not set: %v", b)
|
|
}
|
|
if b&byte(wsPongMessage) == 0 {
|
|
t.Fatalf("Should have been a PONG, it wasn't: %v", b)
|
|
}
|
|
if len(test.payload) > 0 {
|
|
if !bytes.Equal(nb[0][2:], test.payload) {
|
|
t.Fatalf("Unexpected content: %s", nb[0][2:])
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWSReadPongFrame(t *testing.T) {
|
|
for _, test := range []struct {
|
|
name string
|
|
payload []byte
|
|
}{
|
|
{"without payload", nil},
|
|
{"with payload", []byte("optional payload")},
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
c, ri, tr := testWSSetupForRead()
|
|
pong := testWSCreateClientMsg(wsPongMessage, 1, true, false, test.payload)
|
|
rb := append([]byte(nil), pong...)
|
|
bufs, err := c.wsRead(ri, tr, rb)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
if n := len(bufs); n != 0 {
|
|
t.Fatalf("Unexpected buffer returned: %v", n)
|
|
}
|
|
// Nothing should be sent...
|
|
c.mu.Lock()
|
|
nb, _ := c.collapsePtoNB()
|
|
c.mu.Unlock()
|
|
if n := len(nb); n != 0 {
|
|
t.Fatalf("Expected no buffer, got %v", n)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWSReadCloseFrame(t *testing.T) {
|
|
for _, test := range []struct {
|
|
name string
|
|
payload []byte
|
|
}{
|
|
{"without payload", nil},
|
|
{"with payload", []byte("optional payload")},
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
c, ri, tr := testWSSetupForRead()
|
|
// a close message has a status in 2 bytes + optional payload
|
|
payload := make([]byte, 2+len(test.payload))
|
|
binary.BigEndian.PutUint16(payload[:2], wsCloseStatusNormalClosure)
|
|
if len(test.payload) > 0 {
|
|
copy(payload[2:], test.payload)
|
|
}
|
|
close := testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload)
|
|
// Have a normal frame prior to close to make sure that wsRead returns
|
|
// the normal frame along with io.EOF to indicate that wsCloseMessage was received.
|
|
msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg"))
|
|
rb := append([]byte(nil), msg...)
|
|
rb = append(rb, close...)
|
|
bufs, err := c.wsRead(ri, tr, rb)
|
|
// It is expected that wsRead returns io.EOF on processing a close.
|
|
if err != io.EOF {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
if n := len(bufs); n != 1 {
|
|
t.Fatalf("Unexpected buffer returned: %v", n)
|
|
}
|
|
if string(bufs[0]) != "msg" {
|
|
t.Fatalf("Unexpected content: %s", bufs[0])
|
|
}
|
|
// A CLOSE should have been queued with the payload of the original close message.
|
|
c.mu.Lock()
|
|
nb, _ := c.collapsePtoNB()
|
|
c.mu.Unlock()
|
|
if n := len(nb); n == 0 {
|
|
t.Fatalf("Expected buffers, got %v", n)
|
|
}
|
|
if expected := 2 + 2 + len(test.payload); expected != len(nb[0]) {
|
|
t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0]))
|
|
}
|
|
b := nb[0][0]
|
|
if b&wsFinalBit == 0 {
|
|
t.Fatalf("Control frame should have been the final flag, it was not set: %v", b)
|
|
}
|
|
if b&byte(wsCloseMessage) == 0 {
|
|
t.Fatalf("Should have been a CLOSE, it wasn't: %v", b)
|
|
}
|
|
if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusNormalClosure {
|
|
t.Fatalf("Expected status to be %v, got %v", wsCloseStatusNormalClosure, status)
|
|
}
|
|
if len(test.payload) > 0 {
|
|
if !bytes.Equal(nb[0][4:], test.payload) {
|
|
t.Fatalf("Unexpected content: %s", nb[0][4:])
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWSReadControlFrameBetweebFragmentedFrames(t *testing.T) {
|
|
c, ri, tr := testWSSetupForRead()
|
|
frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, false, []byte("first"))
|
|
frag2 := testWSCreateClientMsg(wsBinaryMessage, 2, true, false, []byte("second"))
|
|
ctrl := testWSCreateClientMsg(wsPongMessage, 1, true, false, nil)
|
|
rb := append([]byte(nil), frag1...)
|
|
rb = append(rb, ctrl...)
|
|
rb = append(rb, frag2...)
|
|
bufs, err := c.wsRead(ri, tr, rb)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
if n := len(bufs); n != 2 {
|
|
t.Fatalf("Unexpected buffer returned: %v", n)
|
|
}
|
|
if string(bufs[0]) != "first" {
|
|
t.Fatalf("Unexpected content: %s", bufs[0])
|
|
}
|
|
if string(bufs[1]) != "second" {
|
|
t.Fatalf("Unexpected content: %s", bufs[1])
|
|
}
|
|
}
|
|
|
|
func TestWSReadGetErrors(t *testing.T) {
|
|
tr := &testReader{err: fmt.Errorf("on purpose")}
|
|
for _, test := range []struct {
|
|
lenPayload int
|
|
rbextra int
|
|
}{
|
|
{10, 1},
|
|
{10, 3},
|
|
{200, 1},
|
|
{200, 2},
|
|
{200, 5},
|
|
{70000, 1},
|
|
{70000, 5},
|
|
{70000, 13},
|
|
} {
|
|
t.Run("", func(t *testing.T) {
|
|
c, ri, _ := testWSSetupForRead()
|
|
msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg"))
|
|
frame := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, make([]byte, test.lenPayload))
|
|
rb := append([]byte(nil), msg...)
|
|
rb = append(rb, frame...)
|
|
bufs, err := c.wsRead(ri, tr, rb[:len(msg)+test.rbextra])
|
|
if err == nil || err.Error() != "on purpose" {
|
|
t.Fatalf("Expected 'on purpose' error, got %v", err)
|
|
}
|
|
if n := len(bufs); n != 1 {
|
|
t.Fatalf("Unexpected buffer returned: %v", n)
|
|
}
|
|
if string(bufs[0]) != "msg" {
|
|
t.Fatalf("Unexpected content: %s", bufs[0])
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWSHandleControlFrameErrors(t *testing.T) {
|
|
c, ri, tr := testWSSetupForRead()
|
|
tr.err = fmt.Errorf("on purpose")
|
|
|
|
// a close message has a status in 2 bytes + optional payload
|
|
text := []byte("this is a close message")
|
|
payload := make([]byte, 2+len(text))
|
|
binary.BigEndian.PutUint16(payload[:2], wsCloseStatusNormalClosure)
|
|
copy(payload[2:], text)
|
|
ctrl := testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload)
|
|
|
|
bufs, err := c.wsRead(ri, tr, ctrl[:len(ctrl)-4])
|
|
if err == nil || err.Error() != "on purpose" {
|
|
t.Fatalf("Expected 'on purpose' error, got %v", err)
|
|
}
|
|
if n := len(bufs); n != 0 {
|
|
t.Fatalf("Unexpected buffer returned: %v", n)
|
|
}
|
|
|
|
// Alter the content of close message. It is supposed to be valid utf-8.
|
|
c, ri, tr = testWSSetupForRead()
|
|
cp := append([]byte(nil), payload...)
|
|
cp[10] = 0xF1
|
|
ctrl = testWSCreateClientMsg(wsCloseMessage, 1, true, false, cp)
|
|
bufs, err = c.wsRead(ri, tr, ctrl)
|
|
// We should still receive an EOF but the message enqueued to the client
|
|
// should contain wsCloseStatusInvalidPayloadData and the error about invalid utf8
|
|
if err != io.EOF {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
if n := len(bufs); n != 0 {
|
|
t.Fatalf("Unexpected buffer returned: %v", n)
|
|
}
|
|
c.mu.Lock()
|
|
nb, _ := c.collapsePtoNB()
|
|
c.mu.Unlock()
|
|
if n := len(nb); n == 0 {
|
|
t.Fatalf("Expected buffers, got %v", n)
|
|
}
|
|
b := nb[0][0]
|
|
if b&wsFinalBit == 0 {
|
|
t.Fatalf("Control frame should have been the final flag, it was not set: %v", b)
|
|
}
|
|
if b&byte(wsCloseMessage) == 0 {
|
|
t.Fatalf("Should have been a CLOSE, it wasn't: %v", b)
|
|
}
|
|
if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusInvalidPayloadData {
|
|
t.Fatalf("Expected status to be %v, got %v", wsCloseStatusInvalidPayloadData, status)
|
|
}
|
|
if !bytes.Contains(nb[0][4:], []byte("utf8")) {
|
|
t.Fatalf("Unexpected content: %s", nb[0][4:])
|
|
}
|
|
}
|
|
|
|
func TestWSReadErrors(t *testing.T) {
|
|
for _, test := range []struct {
|
|
cframe func() []byte
|
|
err string
|
|
nbufs int
|
|
}{
|
|
{
|
|
func() []byte {
|
|
msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("hello"))
|
|
msg[1] &= ^byte(wsMaskBit)
|
|
return msg
|
|
},
|
|
"mask bit missing", 1,
|
|
},
|
|
{
|
|
func() []byte {
|
|
return testWSCreateClientMsg(wsPingMessage, 1, true, false, make([]byte, 200))
|
|
},
|
|
"control frame length bigger than maximum allowed", 1,
|
|
},
|
|
{
|
|
func() []byte {
|
|
return testWSCreateClientMsg(wsPingMessage, 1, false, false, []byte("hello"))
|
|
},
|
|
"control frame does not have final bit set", 1,
|
|
},
|
|
{
|
|
func() []byte {
|
|
frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, false, []byte("frag1"))
|
|
newMsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("new message"))
|
|
all := append([]byte(nil), frag1...)
|
|
all = append(all, newMsg...)
|
|
return all
|
|
},
|
|
"new message started before final frame for previous message was received", 2,
|
|
},
|
|
{
|
|
func() []byte {
|
|
frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, true, []byte("frag1"))
|
|
frag2 := testWSCreateClientMsg(wsBinaryMessage, 2, false, true, []byte("frag2"))
|
|
frag2[0] |= wsRsv1Bit
|
|
all := append([]byte(nil), frag1...)
|
|
all = append(all, frag2...)
|
|
return all
|
|
},
|
|
"invalid continuation frame", 2,
|
|
},
|
|
{
|
|
func() []byte {
|
|
return testWSCreateClientMsg(99, 1, false, false, []byte("hello"))
|
|
},
|
|
"unknown opcode", 1,
|
|
},
|
|
} {
|
|
t.Run(test.err, func(t *testing.T) {
|
|
c, ri, tr := testWSSetupForRead()
|
|
// Add a valid message first
|
|
msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("hello"))
|
|
// Then add the bad frame
|
|
bad := test.cframe()
|
|
// Add them both to a read buffer
|
|
rb := append([]byte(nil), msg...)
|
|
rb = append(rb, bad...)
|
|
bufs, err := c.wsRead(ri, tr, rb)
|
|
if err == nil || !strings.Contains(err.Error(), test.err) {
|
|
t.Fatalf("Expected error to contain %q, got %q", test.err, err.Error())
|
|
}
|
|
if n := len(bufs); n != test.nbufs {
|
|
t.Fatalf("Unexpected number of buffers: %v", n)
|
|
}
|
|
if string(bufs[0]) != "hello" {
|
|
t.Fatalf("Unexpected content: %s", bufs[0])
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWSEnqueueCloseMsg(t *testing.T) {
|
|
for _, test := range []struct {
|
|
reason ClosedState
|
|
status int
|
|
}{
|
|
{ClientClosed, wsCloseStatusNormalClosure},
|
|
{AuthenticationTimeout, wsCloseStatusPolicyViolation},
|
|
{AuthenticationViolation, wsCloseStatusPolicyViolation},
|
|
{SlowConsumerPendingBytes, wsCloseStatusPolicyViolation},
|
|
{SlowConsumerWriteDeadline, wsCloseStatusPolicyViolation},
|
|
{MaxAccountConnectionsExceeded, wsCloseStatusPolicyViolation},
|
|
{MaxConnectionsExceeded, wsCloseStatusPolicyViolation},
|
|
{MaxControlLineExceeded, wsCloseStatusPolicyViolation},
|
|
{MaxSubscriptionsExceeded, wsCloseStatusPolicyViolation},
|
|
{MissingAccount, wsCloseStatusPolicyViolation},
|
|
{AuthenticationExpired, wsCloseStatusPolicyViolation},
|
|
{Revocation, wsCloseStatusPolicyViolation},
|
|
{TLSHandshakeError, wsCloseStatusTLSHandshake},
|
|
{ParseError, wsCloseStatusProtocolError},
|
|
{ProtocolViolation, wsCloseStatusProtocolError},
|
|
{BadClientProtocolVersion, wsCloseStatusProtocolError},
|
|
{MaxPayloadExceeded, wsCloseStatusMessageTooBig},
|
|
{ServerShutdown, wsCloseStatusGoingAway},
|
|
{WriteError, wsCloseStatusAbnormalClosure},
|
|
{ReadError, wsCloseStatusAbnormalClosure},
|
|
{StaleConnection, wsCloseStatusAbnormalClosure},
|
|
{ClosedState(254), wsCloseStatusInternalSrvError},
|
|
} {
|
|
t.Run(test.reason.String(), func(t *testing.T) {
|
|
c, _, _ := testWSSetupForRead()
|
|
c.wsEnqueueCloseMessage(test.reason)
|
|
c.mu.Lock()
|
|
nb, _ := c.collapsePtoNB()
|
|
c.mu.Unlock()
|
|
if n := len(nb); n != 1 {
|
|
t.Fatalf("Expected 1 buffer, got %v", n)
|
|
}
|
|
b := nb[0][0]
|
|
if b&wsFinalBit == 0 {
|
|
t.Fatalf("Control frame should have been the final flag, it was not set: %v", b)
|
|
}
|
|
if b&byte(wsCloseMessage) == 0 {
|
|
t.Fatalf("Should have been a CLOSE, it wasn't: %v", b)
|
|
}
|
|
if status := binary.BigEndian.Uint16(nb[0][2:4]); int(status) != test.status {
|
|
t.Fatalf("Expected status to be %v, got %v", test.status, status)
|
|
}
|
|
if string(nb[0][4:]) != test.reason.String() {
|
|
t.Fatalf("Unexpected content: %s", nb[0][4:])
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
type testResponseWriter struct {
|
|
http.ResponseWriter
|
|
buf bytes.Buffer
|
|
headers http.Header
|
|
err error
|
|
brw *bufio.ReadWriter
|
|
conn *testWSFakeNetConn
|
|
}
|
|
|
|
func (trw *testResponseWriter) Write(p []byte) (int, error) {
|
|
return trw.buf.Write(p)
|
|
}
|
|
|
|
func (trw *testResponseWriter) WriteHeader(status int) {
|
|
trw.buf.WriteString(fmt.Sprintf("%v", status))
|
|
}
|
|
|
|
func (trw *testResponseWriter) Header() http.Header {
|
|
if trw.headers == nil {
|
|
trw.headers = make(http.Header)
|
|
}
|
|
return trw.headers
|
|
}
|
|
|
|
type testWSFakeNetConn struct {
|
|
net.Conn
|
|
wbuf bytes.Buffer
|
|
err error
|
|
wsOpened bool
|
|
isClosed bool
|
|
deadlineCleared bool
|
|
}
|
|
|
|
func (c *testWSFakeNetConn) Write(p []byte) (int, error) {
|
|
if c.err != nil {
|
|
return 0, c.err
|
|
}
|
|
return c.wbuf.Write(p)
|
|
}
|
|
|
|
func (c *testWSFakeNetConn) SetDeadline(t time.Time) error {
|
|
if t.IsZero() {
|
|
c.deadlineCleared = true
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *testWSFakeNetConn) Close() error {
|
|
c.isClosed = true
|
|
return nil
|
|
}
|
|
|
|
func (trw *testResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|
if trw.conn == nil {
|
|
trw.conn = &testWSFakeNetConn{}
|
|
}
|
|
trw.conn.wsOpened = true
|
|
if trw.brw == nil {
|
|
trw.brw = bufio.NewReadWriter(bufio.NewReader(trw.conn), bufio.NewWriter(trw.conn))
|
|
}
|
|
return trw.conn, trw.brw, trw.err
|
|
}
|
|
|
|
func testWSOptions() *Options {
|
|
opts := DefaultOptions()
|
|
opts.DisableShortFirstPing = true
|
|
opts.Websocket.Host = "127.0.0.1"
|
|
opts.Websocket.Port = -1
|
|
opts.NoSystemAccount = true
|
|
var err error
|
|
tc := &TLSConfigOpts{
|
|
CertFile: "./configs/certs/server.pem",
|
|
KeyFile: "./configs/certs/key.pem",
|
|
}
|
|
opts.Websocket.TLSConfig, err = GenTLSConfig(tc)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return opts
|
|
}
|
|
|
|
func testWSCreateValidReq() *http.Request {
|
|
req := &http.Request{
|
|
Method: "GET",
|
|
Host: "localhost",
|
|
Proto: "HTTP/1.1",
|
|
}
|
|
req.Header = make(http.Header)
|
|
req.Header.Set("Upgrade", "websocket")
|
|
req.Header.Set("Connection", "Upgrade")
|
|
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
|
req.Header.Set("Sec-Websocket-Version", "13")
|
|
return req
|
|
}
|
|
|
|
func TestWSCheckOrigin(t *testing.T) {
|
|
notSameOrigin := false
|
|
sameOrigin := true
|
|
allowedListEmpty := []string{}
|
|
someList := []string{"http://host1.com", "http://host2.com:1234"}
|
|
|
|
for _, test := range []struct {
|
|
name string
|
|
sameOrigin bool
|
|
origins []string
|
|
reqHost string
|
|
reqTLS bool
|
|
origin string
|
|
err string
|
|
}{
|
|
{"any", notSameOrigin, allowedListEmpty, "", false, "http://any.host.com", ""},
|
|
{"same origin ok", sameOrigin, allowedListEmpty, "host.com", false, "http://host.com:80", ""},
|
|
{"same origin bad host", sameOrigin, allowedListEmpty, "host.com", false, "http://other.host.com", "not same origin"},
|
|
{"same origin bad port", sameOrigin, allowedListEmpty, "host.com", false, "http://host.com:81", "not same origin"},
|
|
{"same origin bad scheme", sameOrigin, allowedListEmpty, "host.com", true, "http://host.com", "not same origin"},
|
|
{"same origin bad uri", sameOrigin, allowedListEmpty, "host.com", false, "@@@://invalid:url:1234", "invalid URI"},
|
|
{"same origin bad url", sameOrigin, allowedListEmpty, "host.com", false, "http://invalid:url:1234", "too many colons"},
|
|
{"same origin bad req host", sameOrigin, allowedListEmpty, "invalid:url:1234", false, "http://host.com", "too many colons"},
|
|
{"no origin", sameOrigin, allowedListEmpty, "", false, "", "origin not provided"},
|
|
{"allowed from list", notSameOrigin, someList, "", false, "http://host2.com:1234", ""},
|
|
{"allowed with different path", notSameOrigin, someList, "", false, "http://host1.com/some/path", ""},
|
|
{"list bad port", notSameOrigin, someList, "", false, "http://host1.com:1234", "not in the allowed list"},
|
|
{"list bad scheme", notSameOrigin, someList, "", false, "https://host2.com:1234", "not in the allowed list"},
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
opts := DefaultOptions()
|
|
opts.Websocket.SameOrigin = test.sameOrigin
|
|
opts.Websocket.AllowedOrigins = test.origins
|
|
s := &Server{opts: opts}
|
|
s.wsSetOriginOptions(&opts.Websocket)
|
|
|
|
req := testWSCreateValidReq()
|
|
req.Host = test.reqHost
|
|
if test.reqTLS {
|
|
req.TLS = &tls.ConnectionState{}
|
|
}
|
|
if test.origin != "" {
|
|
req.Header.Set("Origin", test.origin)
|
|
}
|
|
err := s.websocket.checkOrigin(req)
|
|
if test.err == "" && err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
} else if test.err != "" && (err == nil || !strings.Contains(err.Error(), test.err)) {
|
|
t.Fatalf("Expected error %q, got %v", test.err, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWSUpgradeValidationErrors(t *testing.T) {
|
|
for _, test := range []struct {
|
|
name string
|
|
setup func() (*Options, *testResponseWriter, *http.Request)
|
|
err string
|
|
status int
|
|
}{
|
|
{
|
|
"bad method",
|
|
func() (*Options, *testResponseWriter, *http.Request) {
|
|
opts := testWSOptions()
|
|
req := testWSCreateValidReq()
|
|
req.Method = "POST"
|
|
return opts, nil, req
|
|
},
|
|
"must be GET",
|
|
http.StatusMethodNotAllowed,
|
|
},
|
|
{
|
|
"no host",
|
|
func() (*Options, *testResponseWriter, *http.Request) {
|
|
opts := testWSOptions()
|
|
req := testWSCreateValidReq()
|
|
req.Host = ""
|
|
return opts, nil, req
|
|
},
|
|
"'Host' missing in request",
|
|
http.StatusBadRequest,
|
|
},
|
|
{
|
|
"invalid upgrade header",
|
|
func() (*Options, *testResponseWriter, *http.Request) {
|
|
opts := testWSOptions()
|
|
req := testWSCreateValidReq()
|
|
req.Header.Del("Upgrade")
|
|
return opts, nil, req
|
|
},
|
|
"invalid value for header 'Upgrade'",
|
|
http.StatusBadRequest,
|
|
},
|
|
{
|
|
"invalid connection header",
|
|
func() (*Options, *testResponseWriter, *http.Request) {
|
|
opts := testWSOptions()
|
|
req := testWSCreateValidReq()
|
|
req.Header.Del("Connection")
|
|
return opts, nil, req
|
|
},
|
|
"invalid value for header 'Connection'",
|
|
http.StatusBadRequest,
|
|
},
|
|
{
|
|
"no key",
|
|
func() (*Options, *testResponseWriter, *http.Request) {
|
|
opts := testWSOptions()
|
|
req := testWSCreateValidReq()
|
|
req.Header.Del("Sec-Websocket-Key")
|
|
return opts, nil, req
|
|
},
|
|
"key missing",
|
|
http.StatusBadRequest,
|
|
},
|
|
{
|
|
"empty key",
|
|
func() (*Options, *testResponseWriter, *http.Request) {
|
|
opts := testWSOptions()
|
|
req := testWSCreateValidReq()
|
|
req.Header.Set("Sec-Websocket-Key", "")
|
|
return opts, nil, req
|
|
},
|
|
"key missing",
|
|
http.StatusBadRequest,
|
|
},
|
|
{
|
|
"missing version",
|
|
func() (*Options, *testResponseWriter, *http.Request) {
|
|
opts := testWSOptions()
|
|
req := testWSCreateValidReq()
|
|
req.Header.Del("Sec-Websocket-Version")
|
|
return opts, nil, req
|
|
},
|
|
"invalid version",
|
|
http.StatusBadRequest,
|
|
},
|
|
{
|
|
"wrong version",
|
|
func() (*Options, *testResponseWriter, *http.Request) {
|
|
opts := testWSOptions()
|
|
req := testWSCreateValidReq()
|
|
req.Header.Set("Sec-Websocket-Version", "99")
|
|
return opts, nil, req
|
|
},
|
|
"invalid version",
|
|
http.StatusBadRequest,
|
|
},
|
|
{
|
|
"origin",
|
|
func() (*Options, *testResponseWriter, *http.Request) {
|
|
opts := testWSOptions()
|
|
opts.Websocket.SameOrigin = true
|
|
req := testWSCreateValidReq()
|
|
req.Header.Set("Origin", "http://bad.host.com")
|
|
return opts, nil, req
|
|
},
|
|
"origin not allowed",
|
|
http.StatusForbidden,
|
|
},
|
|
{
|
|
"hijack error",
|
|
func() (*Options, *testResponseWriter, *http.Request) {
|
|
opts := testWSOptions()
|
|
rw := &testResponseWriter{err: fmt.Errorf("on purpose")}
|
|
req := testWSCreateValidReq()
|
|
return opts, rw, req
|
|
},
|
|
"on purpose",
|
|
http.StatusInternalServerError,
|
|
},
|
|
{
|
|
"hijack buffered data",
|
|
func() (*Options, *testResponseWriter, *http.Request) {
|
|
opts := testWSOptions()
|
|
buf := &bytes.Buffer{}
|
|
buf.WriteString("some data")
|
|
rw := &testResponseWriter{
|
|
conn: &testWSFakeNetConn{},
|
|
brw: bufio.NewReadWriter(bufio.NewReader(buf), bufio.NewWriter(nil)),
|
|
}
|
|
tmp := [1]byte{}
|
|
io.ReadAtLeast(rw.brw, tmp[:1], 1)
|
|
req := testWSCreateValidReq()
|
|
return opts, rw, req
|
|
},
|
|
"client sent data before handshake is complete",
|
|
http.StatusBadRequest,
|
|
},
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
opts, rw, req := test.setup()
|
|
if rw == nil {
|
|
rw = &testResponseWriter{}
|
|
}
|
|
s := &Server{opts: opts}
|
|
s.wsSetOriginOptions(&opts.Websocket)
|
|
res, err := s.wsUpgrade(rw, req)
|
|
if err == nil || !strings.Contains(err.Error(), test.err) {
|
|
t.Fatalf("Should get error %q, got %v", test.err, err)
|
|
}
|
|
if res != nil {
|
|
t.Fatalf("Should not have returned a result, got %v", res)
|
|
}
|
|
expected := fmt.Sprintf("%v%s\n", test.status, http.StatusText(test.status))
|
|
if got := rw.buf.String(); got != expected {
|
|
t.Fatalf("Expected %q got %q", expected, got)
|
|
}
|
|
// Check that if the connection was opened, it is now closed.
|
|
if rw.conn != nil && rw.conn.wsOpened && !rw.conn.isClosed {
|
|
t.Fatal("Connection was opened, but has not been closed")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWSUpgradeResponseWriteError(t *testing.T) {
|
|
opts := testWSOptions()
|
|
s := &Server{opts: opts}
|
|
expectedErr := errors.New("on purpose")
|
|
rw := &testResponseWriter{
|
|
conn: &testWSFakeNetConn{err: expectedErr},
|
|
}
|
|
req := testWSCreateValidReq()
|
|
res, err := s.wsUpgrade(rw, req)
|
|
if err != expectedErr {
|
|
t.Fatalf("Should get error %q, got %v", expectedErr.Error(), err)
|
|
}
|
|
if res != nil {
|
|
t.Fatalf("Should not have returned a result, got %v", res)
|
|
}
|
|
if !rw.conn.isClosed {
|
|
t.Fatal("Connection should have been closed")
|
|
}
|
|
}
|
|
|
|
func TestWSUpgradeConnDeadline(t *testing.T) {
|
|
opts := testWSOptions()
|
|
opts.Websocket.HandshakeTimeout = time.Second
|
|
s := &Server{opts: opts}
|
|
rw := &testResponseWriter{}
|
|
req := testWSCreateValidReq()
|
|
res, err := s.wsUpgrade(rw, req)
|
|
if res == nil || err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
if rw.conn.isClosed {
|
|
t.Fatal("Connection should NOT have been closed")
|
|
}
|
|
if !rw.conn.deadlineCleared {
|
|
t.Fatal("Connection deadline should have been cleared after handshake")
|
|
}
|
|
}
|
|
|
|
func TestWSCompressNegotiation(t *testing.T) {
|
|
// No compression on the server, but client asks
|
|
opts := testWSOptions()
|
|
s := &Server{opts: opts}
|
|
rw := &testResponseWriter{}
|
|
req := testWSCreateValidReq()
|
|
req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate")
|
|
res, err := s.wsUpgrade(rw, req)
|
|
if res == nil || err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
// The http response should not contain "permessage-deflate"
|
|
output := rw.conn.wbuf.String()
|
|
if strings.Contains(output, "permessage-deflate") {
|
|
t.Fatalf("Compression disabled in server so response to client should not contain extension, got %s", output)
|
|
}
|
|
|
|
// Option in the server and client, so compression should be negotiated.
|
|
s.opts.Websocket.Compression = true
|
|
rw = &testResponseWriter{}
|
|
res, err = s.wsUpgrade(rw, req)
|
|
if res == nil || err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
// The http response should not contain "permessage-deflate"
|
|
output = rw.conn.wbuf.String()
|
|
if !strings.Contains(output, "permessage-deflate") {
|
|
t.Fatalf("Compression in server and client request, so response should contain extension, got %s", output)
|
|
}
|
|
|
|
// Option in server but not asked by the client, so response should not contain "permessage-deflate"
|
|
rw = &testResponseWriter{}
|
|
req.Header.Del("Sec-Websocket-Extensions")
|
|
res, err = s.wsUpgrade(rw, req)
|
|
if res == nil || err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
// The http response should not contain "permessage-deflate"
|
|
output = rw.conn.wbuf.String()
|
|
if strings.Contains(output, "permessage-deflate") {
|
|
t.Fatalf("Compression in server but not in client, so response to client should not contain extension, got %s", output)
|
|
}
|
|
}
|
|
|
|
func TestWSParseOptions(t *testing.T) {
|
|
for _, test := range []struct {
|
|
name string
|
|
content string
|
|
checkOpt func(*WebsocketOpts) error
|
|
err string
|
|
}{
|
|
// Negative tests
|
|
{"bad type", "websocket: []", nil, "to be a map"},
|
|
{"bad listen", "websocket: { listen: [] }", nil, "port or host:port"},
|
|
{"bad port", `websocket: { port: "abc" }`, nil, "not int64"},
|
|
{"bad host", `websocket: { host: 123 }`, nil, "not string"},
|
|
{"bad advertise type", `websocket: { advertise: 123 }`, nil, "not string"},
|
|
{"bad tls", `websocket: { tls: 123 }`, nil, "not map[string]interface {}"},
|
|
{"bad same origin", `websocket: { same_origin: "abc" }`, nil, "not bool"},
|
|
{"bad allowed origins type", `websocket: { allowed_origins: {} }`, nil, "unsupported type"},
|
|
{"bad allowed origins values", `websocket: { allowed_origins: [ {} ] }`, nil, "unsupported type in array"},
|
|
{"bad handshake timeout type", `websocket: { handshake_timeout: [] }`, nil, "unsupported type"},
|
|
{"bad handshake timeout duration", `websocket: { handshake_timeout: "abc" }`, nil, "invalid duration"},
|
|
{"unknown field", `websocket: { this_does_not_exist: 123 }`, nil, "unknown"},
|
|
// Positive tests
|
|
{"listen port only", `websocket { listen: 1234 }`, func(wo *WebsocketOpts) error {
|
|
if wo.Port != 1234 {
|
|
return fmt.Errorf("expected 1234, got %v", wo.Port)
|
|
}
|
|
return nil
|
|
}, ""},
|
|
{"listen host and port", `websocket { listen: "localhost:1234" }`, func(wo *WebsocketOpts) error {
|
|
if wo.Host != "localhost" || wo.Port != 1234 {
|
|
return fmt.Errorf("expected localhost:1234, got %v:%v", wo.Host, wo.Port)
|
|
}
|
|
return nil
|
|
}, ""},
|
|
{"host", `websocket { host: "localhost" }`, func(wo *WebsocketOpts) error {
|
|
if wo.Host != "localhost" {
|
|
return fmt.Errorf("expected localhost, got %v", wo.Host)
|
|
}
|
|
return nil
|
|
}, ""},
|
|
{"port", `websocket { port: 1234 }`, func(wo *WebsocketOpts) error {
|
|
if wo.Port != 1234 {
|
|
return fmt.Errorf("expected 1234, got %v", wo.Port)
|
|
}
|
|
return nil
|
|
}, ""},
|
|
{"advertise", `websocket { advertise: "host:1234" }`, func(wo *WebsocketOpts) error {
|
|
if wo.Advertise != "host:1234" {
|
|
return fmt.Errorf("expected %q, got %q", "host:1234", wo.Advertise)
|
|
}
|
|
return nil
|
|
}, ""},
|
|
{"same origin", `websocket { same_origin: true }`, func(wo *WebsocketOpts) error {
|
|
if !wo.SameOrigin {
|
|
return fmt.Errorf("expected same_origin==true, got %v", wo.SameOrigin)
|
|
}
|
|
return nil
|
|
}, ""},
|
|
{"allowed origins one only", `websocket { allowed_origins: "https://host.com/" }`, func(wo *WebsocketOpts) error {
|
|
expected := []string{"https://host.com/"}
|
|
if !reflect.DeepEqual(wo.AllowedOrigins, expected) {
|
|
return fmt.Errorf("expected allowed origins to be %q, got %q", expected, wo.AllowedOrigins)
|
|
}
|
|
return nil
|
|
}, ""},
|
|
{"allowed origins array",
|
|
`
|
|
websocket {
|
|
allowed_origins: [
|
|
"https://host1.com/"
|
|
"https://host2.com/"
|
|
]
|
|
}
|
|
`, func(wo *WebsocketOpts) error {
|
|
expected := []string{"https://host1.com/", "https://host2.com/"}
|
|
if !reflect.DeepEqual(wo.AllowedOrigins, expected) {
|
|
return fmt.Errorf("expected allowed origins to be %q, got %q", expected, wo.AllowedOrigins)
|
|
}
|
|
return nil
|
|
}, ""},
|
|
{"handshake timeout in whole seconds", `websocket { handshake_timeout: 3 }`, func(wo *WebsocketOpts) error {
|
|
if wo.HandshakeTimeout != 3*time.Second {
|
|
return fmt.Errorf("expected handshake to be 3s, got %v", wo.HandshakeTimeout)
|
|
}
|
|
return nil
|
|
}, ""},
|
|
{"handshake timeout n duration", `websocket { handshake_timeout: "4s" }`, func(wo *WebsocketOpts) error {
|
|
if wo.HandshakeTimeout != 4*time.Second {
|
|
return fmt.Errorf("expected handshake to be 4s, got %v", wo.HandshakeTimeout)
|
|
}
|
|
return nil
|
|
}, ""},
|
|
{"tls config",
|
|
`
|
|
websocket {
|
|
tls {
|
|
cert_file: "./configs/certs/server.pem"
|
|
key_file: "./configs/certs/key.pem"
|
|
}
|
|
}
|
|
`, func(wo *WebsocketOpts) error {
|
|
if wo.TLSConfig == nil {
|
|
return fmt.Errorf("TLSConfig should have been set")
|
|
}
|
|
return nil
|
|
}, ""},
|
|
{"compression",
|
|
`
|
|
websocket {
|
|
compression: true
|
|
}
|
|
`, func(wo *WebsocketOpts) error {
|
|
if !wo.Compression {
|
|
return fmt.Errorf("Compression should have been set")
|
|
}
|
|
return nil
|
|
}, ""},
|
|
{"jwt cookie",
|
|
`
|
|
websocket {
|
|
jwt_cookie: "jwtcookie"
|
|
}
|
|
`, func(wo *WebsocketOpts) error {
|
|
if wo.JWTCookie != "jwtcookie" {
|
|
return fmt.Errorf("Invalid JWTCookie value: %q", wo.JWTCookie)
|
|
}
|
|
return nil
|
|
}, ""},
|
|
{"no auth user",
|
|
`
|
|
websocket {
|
|
no_auth_user: "noauthuser"
|
|
}
|
|
`, func(wo *WebsocketOpts) error {
|
|
if wo.NoAuthUser != "noauthuser" {
|
|
return fmt.Errorf("Invalid NoAuthUser value: %q", wo.NoAuthUser)
|
|
}
|
|
return nil
|
|
}, ""},
|
|
{"auth block",
|
|
`
|
|
websocket {
|
|
authorization {
|
|
user: "webuser"
|
|
password: "pwd"
|
|
token: "token"
|
|
timeout: 2.0
|
|
}
|
|
}
|
|
`, func(wo *WebsocketOpts) error {
|
|
if wo.Username != "webuser" || wo.Password != "pwd" || wo.Token != "token" || wo.AuthTimeout != 2.0 {
|
|
return fmt.Errorf("Invalid auth block: %+v", wo)
|
|
}
|
|
return nil
|
|
}, ""},
|
|
{"auth timeout as int",
|
|
`
|
|
websocket {
|
|
authorization {
|
|
timeout: 2
|
|
}
|
|
}
|
|
`, func(wo *WebsocketOpts) error {
|
|
if wo.AuthTimeout != 2.0 {
|
|
return fmt.Errorf("Invalid auth timeout: %v", wo.AuthTimeout)
|
|
}
|
|
return nil
|
|
}, ""},
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
conf := createConfFile(t, []byte(test.content))
|
|
defer os.Remove(conf)
|
|
o, err := ProcessConfigFile(conf)
|
|
if test.err != _EMPTY_ {
|
|
if err == nil || !strings.Contains(err.Error(), test.err) {
|
|
t.Fatalf("For content: %q, expected error about %q, got %v", test.content, test.err, err)
|
|
}
|
|
return
|
|
} else if err != nil {
|
|
t.Fatalf("Unexpected error for content %q: %v", test.content, err)
|
|
}
|
|
if err := test.checkOpt(&o.Websocket); err != nil {
|
|
t.Fatalf("Incorrect option for content %q: %v", test.content, err.Error())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWSValidateOptions(t *testing.T) {
|
|
nwso := DefaultOptions()
|
|
wso := testWSOptions()
|
|
for _, test := range []struct {
|
|
name string
|
|
getOpts func() *Options
|
|
err string
|
|
}{
|
|
{"websocket disabled", func() *Options { return nwso.Clone() }, ""},
|
|
{"no tls", func() *Options { o := wso.Clone(); o.Websocket.TLSConfig = nil; return o }, "requires TLS configuration"},
|
|
{"bad url in allowed list", func() *Options {
|
|
o := wso.Clone()
|
|
o.Websocket.AllowedOrigins = []string{"http://this:is:bad:url"}
|
|
return o
|
|
}, "unable to parse"},
|
|
{"missing trusted configuration", func() *Options {
|
|
o := wso.Clone()
|
|
o.Websocket.JWTCookie = "jwt"
|
|
return o
|
|
}, "keys configuration is required"},
|
|
{"websocket username not allowed if users specified", func() *Options {
|
|
o := wso.Clone()
|
|
o.Nkeys = []*NkeyUser{&NkeyUser{Nkey: "abc"}}
|
|
o.Websocket.Username = "b"
|
|
o.Websocket.Password = "pwd"
|
|
return o
|
|
}, "websocket authentication username not compatible with presence of users/nkeys"},
|
|
{"websocket token not allowed if users specified", func() *Options {
|
|
o := wso.Clone()
|
|
o.Nkeys = []*NkeyUser{&NkeyUser{Nkey: "abc"}}
|
|
o.Websocket.Token = "mytoken"
|
|
return o
|
|
}, "websocket authentication token not compatible with presence of users/nkeys"},
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
err := validateWebsocketOptions(test.getOpts())
|
|
if test.err == "" && err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
} else if test.err != "" && (err == nil || !strings.Contains(err.Error(), test.err)) {
|
|
t.Fatalf("Expected error to contain %q, got %v", test.err, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWSSetOriginOptions(t *testing.T) {
|
|
o := testWSOptions()
|
|
for _, test := range []struct {
|
|
content string
|
|
err string
|
|
}{
|
|
{"@@@://host.com/", "invalid URI"},
|
|
{"http://this:is:bad:url/", "invalid port"},
|
|
} {
|
|
t.Run(test.err, func(t *testing.T) {
|
|
o.Websocket.AllowedOrigins = []string{test.content}
|
|
s := &Server{}
|
|
l := &captureErrorLogger{errCh: make(chan string, 1)}
|
|
s.SetLogger(l, false, false)
|
|
s.wsSetOriginOptions(&o.Websocket)
|
|
select {
|
|
case e := <-l.errCh:
|
|
if !strings.Contains(e, test.err) {
|
|
t.Fatalf("Unexpected error: %v", e)
|
|
}
|
|
case <-time.After(50 * time.Millisecond):
|
|
t.Fatalf("Did not get the error")
|
|
}
|
|
|
|
})
|
|
}
|
|
}
|
|
|
|
type captureFatalLogger struct {
|
|
DummyLogger
|
|
fatalCh chan string
|
|
}
|
|
|
|
func (l *captureFatalLogger) Fatalf(format string, v ...interface{}) {
|
|
select {
|
|
case l.fatalCh <- fmt.Sprintf(format, v...):
|
|
default:
|
|
}
|
|
}
|
|
|
|
func TestWSFailureToStartServer(t *testing.T) {
|
|
// Create a listener to use a port
|
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("Error listening: %v", err)
|
|
}
|
|
defer l.Close()
|
|
|
|
o := testWSOptions()
|
|
// Make sure we don't have unnecessary listen ports opened.
|
|
o.HTTPPort = 0
|
|
o.Cluster.Port = 0
|
|
o.Gateway.Name = ""
|
|
o.Gateway.Port = 0
|
|
o.LeafNode.Port = 0
|
|
o.Websocket.Port = l.Addr().(*net.TCPAddr).Port
|
|
s, err := NewServer(o)
|
|
if err != nil {
|
|
t.Fatalf("Error creating server: %v", err)
|
|
}
|
|
defer s.Shutdown()
|
|
logger := &captureFatalLogger{fatalCh: make(chan string, 1)}
|
|
s.SetLogger(logger, false, false)
|
|
|
|
wg := sync.WaitGroup{}
|
|
wg.Add(1)
|
|
go func() {
|
|
s.Start()
|
|
wg.Done()
|
|
}()
|
|
|
|
select {
|
|
case e := <-logger.fatalCh:
|
|
if !strings.Contains(e, "Unable to listen") {
|
|
t.Fatalf("Unexpected error: %v", e)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatalf("Should have reported a fatal error")
|
|
}
|
|
// Since this is a test and the process does not actually
|
|
// exit on Fatal error, wait for the client port to be
|
|
// ready so when we shutdown we don't leave the accept
|
|
// loop hanging.
|
|
checkFor(t, time.Second, 15*time.Millisecond, func() error {
|
|
s.mu.Lock()
|
|
ready := s.listener != nil
|
|
s.mu.Unlock()
|
|
if !ready {
|
|
return fmt.Errorf("client accept loop not started yet")
|
|
}
|
|
return nil
|
|
})
|
|
s.Shutdown()
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestWSAbnormalFailureOfWebServer(t *testing.T) {
|
|
o := testWSOptions()
|
|
s := RunServer(o)
|
|
defer s.Shutdown()
|
|
logger := &captureFatalLogger{fatalCh: make(chan string, 1)}
|
|
s.SetLogger(logger, false, false)
|
|
|
|
// Now close the WS listener to cause a WebServer error
|
|
s.mu.Lock()
|
|
s.websocket.listener.Close()
|
|
s.mu.Unlock()
|
|
|
|
select {
|
|
case e := <-logger.fatalCh:
|
|
if !strings.Contains(e, "websocket listener error") {
|
|
t.Fatalf("Unexpected error: %v", e)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatalf("Should have reported a fatal error")
|
|
}
|
|
}
|
|
|
|
type testWSClientOptions struct {
|
|
compress, web bool
|
|
host string
|
|
port int
|
|
extraHeaders map[string]string
|
|
noTLS bool
|
|
}
|
|
|
|
func testNewWSClient(t testing.TB, o testWSClientOptions) (net.Conn, *bufio.Reader, []byte) {
|
|
t.Helper()
|
|
addr := fmt.Sprintf("%s:%d", o.host, o.port)
|
|
wsc, err := net.Dial("tcp", addr)
|
|
if err != nil {
|
|
t.Fatalf("Error creating ws connection: %v", err)
|
|
}
|
|
if !o.noTLS {
|
|
wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true})
|
|
if err := wsc.(*tls.Conn).Handshake(); err != nil {
|
|
t.Fatalf("Error during handshake: %v", err)
|
|
}
|
|
}
|
|
req := testWSCreateValidReq()
|
|
if o.compress {
|
|
req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate")
|
|
}
|
|
if o.web {
|
|
req.Header.Set("User-Agent", "Mozilla/5.0")
|
|
}
|
|
if o.extraHeaders != nil {
|
|
for hdr, val := range o.extraHeaders {
|
|
req.Header.Add(hdr, val)
|
|
}
|
|
}
|
|
req.URL, _ = url.Parse("wss://" + addr)
|
|
if err := req.Write(wsc); err != nil {
|
|
t.Fatalf("Error sending request: %v", err)
|
|
}
|
|
br := bufio.NewReader(wsc)
|
|
resp, err := http.ReadResponse(br, req)
|
|
if err != nil {
|
|
t.Fatalf("Error reading response: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusSwitchingProtocols {
|
|
t.Fatalf("Expected response status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode)
|
|
}
|
|
// Wait for the INFO
|
|
info := testWSReadFrame(t, br)
|
|
if !bytes.HasPrefix(info, []byte("INFO {")) {
|
|
t.Fatalf("Expected INFO, got %s", info)
|
|
}
|
|
return wsc, br, info
|
|
}
|
|
|
|
type testClaimsOptions struct {
|
|
nac *jwt.AccountClaims
|
|
nuc *jwt.UserClaims
|
|
connectRequest interface{}
|
|
dontSign bool
|
|
expectAnswer string
|
|
}
|
|
|
|
func testWSWithClaims(t *testing.T, s *Server, o testWSClientOptions, tclm testClaimsOptions) (kp nkeys.KeyPair, conn net.Conn, rdr *bufio.Reader, auth_was_required bool) {
|
|
t.Helper()
|
|
|
|
okp, _ := nkeys.FromSeed(oSeed)
|
|
|
|
akp, _ := nkeys.CreateAccount()
|
|
apub, _ := akp.PublicKey()
|
|
if tclm.nac == nil {
|
|
tclm.nac = jwt.NewAccountClaims(apub)
|
|
} else {
|
|
tclm.nac.Subject = apub
|
|
}
|
|
ajwt, err := tclm.nac.Encode(okp)
|
|
if err != nil {
|
|
t.Fatalf("Error generating account JWT: %v", err)
|
|
}
|
|
|
|
nkp, _ := nkeys.CreateUser()
|
|
pub, _ := nkp.PublicKey()
|
|
if tclm.nuc == nil {
|
|
tclm.nuc = jwt.NewUserClaims(pub)
|
|
} else {
|
|
tclm.nuc.Subject = pub
|
|
}
|
|
jwt, err := tclm.nuc.Encode(akp)
|
|
if err != nil {
|
|
t.Fatalf("Error generating user JWT: %v", err)
|
|
}
|
|
|
|
addAccountToMemResolver(s, apub, ajwt)
|
|
|
|
c, cr, l := testNewWSClient(t, o)
|
|
|
|
var info struct {
|
|
Nonce string `json:"nonce,omitempty"`
|
|
AuthRequired bool `json:"auth_required,omitempty"`
|
|
}
|
|
|
|
if err := json.Unmarshal([]byte(l[5:]), &info); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if info.AuthRequired {
|
|
cs := ""
|
|
if tclm.connectRequest != nil {
|
|
customReq, err := json.Marshal(tclm.connectRequest)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
// PING needed to flush the +OK/-ERR to us.
|
|
cs = fmt.Sprintf("CONNECT %v\r\nPING\r\n", string(customReq))
|
|
} else if !tclm.dontSign {
|
|
// Sign Nonce
|
|
sigraw, _ := nkp.Sign([]byte(info.Nonce))
|
|
sig := base64.RawURLEncoding.EncodeToString(sigraw)
|
|
cs = fmt.Sprintf("CONNECT {\"jwt\":%q,\"sig\":\"%s\",\"verbose\":true,\"pedantic\":true}\r\nPING\r\n", jwt, sig)
|
|
} else {
|
|
cs = fmt.Sprintf("CONNECT {\"jwt\":%q,\"verbose\":true,\"pedantic\":true}\r\nPING\r\n", jwt)
|
|
}
|
|
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(cs))
|
|
c.Write(wsmsg)
|
|
l = testWSReadFrame(t, cr)
|
|
if !strings.HasPrefix(string(l), tclm.expectAnswer) {
|
|
t.Fatalf("Expected %q, got %q", tclm.expectAnswer, l)
|
|
}
|
|
}
|
|
return akp, c, cr, info.AuthRequired
|
|
}
|
|
|
|
func setupAddTrusted(o *Options) {
|
|
kp, _ := nkeys.FromSeed(oSeed)
|
|
pub, _ := kp.PublicKey()
|
|
o.TrustedKeys = []string{pub}
|
|
}
|
|
|
|
func setupAddCookie(o *Options) {
|
|
o.Websocket.JWTCookie = "jwt"
|
|
}
|
|
|
|
func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader, []byte) {
|
|
t.Helper()
|
|
return testNewWSClient(t, testWSClientOptions{
|
|
compress: compress,
|
|
web: web,
|
|
host: host,
|
|
port: port,
|
|
})
|
|
}
|
|
|
|
func testWSCreateClient(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader) {
|
|
wsc, br, _ := testWSCreateClientGetInfo(t, compress, web, host, port)
|
|
// Send CONNECT and PING
|
|
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, []byte("CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n"))
|
|
if _, err := wsc.Write(wsmsg); err != nil {
|
|
t.Fatalf("Error sending message: %v", err)
|
|
}
|
|
// Wait for the PONG
|
|
if msg := testWSReadFrame(t, br); !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
|
|
t.Fatalf("Expected PONG, got %s", msg)
|
|
}
|
|
return wsc, br
|
|
}
|
|
|
|
func testWSReadFrame(t testing.TB, br *bufio.Reader) []byte {
|
|
t.Helper()
|
|
fh := [2]byte{}
|
|
if _, err := io.ReadAtLeast(br, fh[:2], 2); err != nil {
|
|
t.Fatalf("Error reading frame: %v", err)
|
|
}
|
|
fc := fh[0]&wsRsv1Bit != 0
|
|
sb := fh[1]
|
|
size := 0
|
|
switch {
|
|
case sb <= 125:
|
|
size = int(sb)
|
|
case sb == 126:
|
|
tmp := [2]byte{}
|
|
if _, err := io.ReadAtLeast(br, tmp[:2], 2); err != nil {
|
|
t.Fatalf("Error reading frame: %v", err)
|
|
}
|
|
size = int(binary.BigEndian.Uint16(tmp[:2]))
|
|
case sb == 127:
|
|
tmp := [8]byte{}
|
|
if _, err := io.ReadAtLeast(br, tmp[:8], 8); err != nil {
|
|
t.Fatalf("Error reading frame: %v", err)
|
|
}
|
|
size = int(binary.BigEndian.Uint64(tmp[:8]))
|
|
}
|
|
buf := make([]byte, size)
|
|
if _, err := io.ReadAtLeast(br, buf, size); err != nil {
|
|
t.Fatalf("Error reading frame: %v", err)
|
|
}
|
|
if !fc {
|
|
return buf
|
|
}
|
|
buf = append(buf, 0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff)
|
|
dbr := bytes.NewBuffer(buf)
|
|
d := flate.NewReader(dbr)
|
|
uncompressed, err := ioutil.ReadAll(d)
|
|
if err != nil {
|
|
t.Fatalf("Error reading frame: %v", err)
|
|
}
|
|
return uncompressed
|
|
}
|
|
|
|
func TestWSPubSub(t *testing.T) {
|
|
for _, test := range []struct {
|
|
name string
|
|
compression bool
|
|
}{
|
|
{"no compression", false},
|
|
{"compression", true},
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
o := testWSOptions()
|
|
if test.compression {
|
|
o.Websocket.Compression = true
|
|
}
|
|
s := RunServer(o)
|
|
defer s.Shutdown()
|
|
|
|
// Create a regular client to subscribe
|
|
nc := natsConnect(t, s.ClientURL())
|
|
defer nc.Close()
|
|
nsub := natsSubSync(t, nc, "foo")
|
|
checkExpectedSubs(t, 1, s)
|
|
|
|
// Now create a WS client and send a message on "foo"
|
|
wsc, br := testWSCreateClient(t, test.compression, false, o.Websocket.Host, o.Websocket.Port)
|
|
defer wsc.Close()
|
|
|
|
// Send a WS message for "PUB foo 2\r\nok\r\n"
|
|
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("PUB foo 7\r\nfrom ws\r\n"))
|
|
if _, err := wsc.Write(wsmsg); err != nil {
|
|
t.Fatalf("Error sending message: %v", err)
|
|
}
|
|
|
|
// Now check that message is received
|
|
msg := natsNexMsg(t, nsub, time.Second)
|
|
if string(msg.Data) != "from ws" {
|
|
t.Fatalf("Expected message to be %q, got %q", "ok", string(msg.Data))
|
|
}
|
|
|
|
// Now do reverse, create a subscription on WS client on bar
|
|
wsmsg = testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("SUB bar 1\r\n"))
|
|
if _, err := wsc.Write(wsmsg); err != nil {
|
|
t.Fatalf("Error sending subscription: %v", err)
|
|
}
|
|
// Wait for it to be registered on server
|
|
checkExpectedSubs(t, 2, s)
|
|
// Now publish from NATS connection and verify received on WS client
|
|
natsPub(t, nc, "bar", []byte("from nats"))
|
|
natsFlush(t, nc)
|
|
|
|
// Check for the "from nats" message...
|
|
// Set some deadline so we are not stuck forever on failure
|
|
wsc.SetReadDeadline(time.Now().Add(10 * time.Second))
|
|
ok := 0
|
|
for {
|
|
line, _, err := br.ReadLine()
|
|
if err != nil {
|
|
t.Fatalf("Error reading: %v", err)
|
|
}
|
|
// Note that this works even in compression test because those
|
|
// texts are likely not to be compressed, but compression code is
|
|
// still executed.
|
|
if ok == 0 && bytes.Contains(line, []byte("MSG bar 1 9")) {
|
|
ok = 1
|
|
continue
|
|
} else if ok == 1 && bytes.Contains(line, []byte("from nats")) {
|
|
ok = 2
|
|
break
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWSTLSConnection(t *testing.T) {
|
|
o := testWSOptions()
|
|
s := RunServer(o)
|
|
defer s.Shutdown()
|
|
|
|
addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port)
|
|
|
|
for _, test := range []struct {
|
|
name string
|
|
useTLS bool
|
|
status int
|
|
}{
|
|
{"client uses TLS", true, http.StatusSwitchingProtocols},
|
|
{"client does not use TLS", false, http.StatusBadRequest},
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
wsc, err := net.Dial("tcp", addr)
|
|
if err != nil {
|
|
t.Fatalf("Error creating ws connection: %v", err)
|
|
}
|
|
defer wsc.Close()
|
|
if test.useTLS {
|
|
wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true})
|
|
if err := wsc.(*tls.Conn).Handshake(); err != nil {
|
|
t.Fatalf("Error during handshake: %v", err)
|
|
}
|
|
}
|
|
req := testWSCreateValidReq()
|
|
var scheme string
|
|
if test.useTLS {
|
|
scheme = "s"
|
|
}
|
|
req.URL, _ = url.Parse("ws" + scheme + "://" + addr)
|
|
if err := req.Write(wsc); err != nil {
|
|
t.Fatalf("Error sending request: %v", err)
|
|
}
|
|
br := bufio.NewReader(wsc)
|
|
resp, err := http.ReadResponse(br, req)
|
|
if err != nil {
|
|
t.Fatalf("Error reading response: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != test.status {
|
|
t.Fatalf("Expected status %v, got %v", test.status, resp.StatusCode)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWSTLSVerifyClientCert(t *testing.T) {
|
|
o := testWSOptions()
|
|
tc := &TLSConfigOpts{
|
|
CertFile: "../test/configs/certs/server-cert.pem",
|
|
KeyFile: "../test/configs/certs/server-key.pem",
|
|
CaFile: "../test/configs/certs/ca.pem",
|
|
Verify: true,
|
|
}
|
|
tlsc, err := GenTLSConfig(tc)
|
|
if err != nil {
|
|
t.Fatalf("Error creating tls config: %v", err)
|
|
}
|
|
o.Websocket.TLSConfig = tlsc
|
|
s := RunServer(o)
|
|
defer s.Shutdown()
|
|
|
|
addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port)
|
|
|
|
for _, test := range []struct {
|
|
name string
|
|
provideCert bool
|
|
}{
|
|
{"client provides cert", true},
|
|
{"client does not provide cert", false},
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
wsc, err := net.Dial("tcp", addr)
|
|
if err != nil {
|
|
t.Fatalf("Error creating ws connection: %v", err)
|
|
}
|
|
defer wsc.Close()
|
|
tlsc := &tls.Config{}
|
|
if test.provideCert {
|
|
tc := &TLSConfigOpts{
|
|
CertFile: "../test/configs/certs/client-cert.pem",
|
|
KeyFile: "../test/configs/certs/client-key.pem",
|
|
}
|
|
var err error
|
|
tlsc, err = GenTLSConfig(tc)
|
|
if err != nil {
|
|
t.Fatalf("Error generating tls config: %v", err)
|
|
}
|
|
}
|
|
tlsc.InsecureSkipVerify = true
|
|
wsc = tls.Client(wsc, tlsc)
|
|
if err := wsc.(*tls.Conn).Handshake(); err != nil {
|
|
t.Fatalf("Error during handshake: %v", err)
|
|
}
|
|
req := testWSCreateValidReq()
|
|
req.URL, _ = url.Parse("wss://" + addr)
|
|
if err := req.Write(wsc); err != nil {
|
|
t.Fatalf("Error sending request: %v", err)
|
|
}
|
|
br := bufio.NewReader(wsc)
|
|
resp, err := http.ReadResponse(br, req)
|
|
if resp != nil {
|
|
resp.Body.Close()
|
|
}
|
|
if !test.provideCert {
|
|
if err == nil {
|
|
t.Fatal("Expected error, did not get one")
|
|
} else if !strings.Contains(err.Error(), "bad certificate") {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
return
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
if resp.StatusCode != http.StatusSwitchingProtocols {
|
|
t.Fatalf("Expected status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func testCreateAllowedConnectionTypes(list []string) map[string]struct{} {
|
|
if len(list) == 0 {
|
|
return nil
|
|
}
|
|
m := make(map[string]struct{}, len(list))
|
|
for _, l := range list {
|
|
m[l] = struct{}{}
|
|
}
|
|
return m
|
|
}
|
|
|
|
func TestWSTLSVerifyAndMap(t *testing.T) {
|
|
accName := "MyAccount"
|
|
acc := NewAccount(accName)
|
|
certUserName := "CN=example.com,OU=NATS.io"
|
|
users := []*User{&User{Username: certUserName, Account: acc}}
|
|
|
|
for _, test := range []struct {
|
|
name string
|
|
filtering bool
|
|
provideCert bool
|
|
}{
|
|
{"no filtering, client provides cert", false, true},
|
|
{"no filtering, client does not provide cert", false, false},
|
|
{"filtering, client provides cert", true, true},
|
|
{"filtering, client does not provide cert", true, false},
|
|
{"no users override, client provides cert", false, true},
|
|
{"no users override, client does not provide cert", false, false},
|
|
{"users override, client provides cert", true, true},
|
|
{"users override, client does not provide cert", true, false},
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
o := testWSOptions()
|
|
o.Accounts = []*Account{acc}
|
|
o.Users = users
|
|
if test.filtering {
|
|
o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket})
|
|
}
|
|
tc := &TLSConfigOpts{
|
|
CertFile: "../test/configs/certs/tlsauth/server.pem",
|
|
KeyFile: "../test/configs/certs/tlsauth/server-key.pem",
|
|
CaFile: "../test/configs/certs/tlsauth/ca.pem",
|
|
Verify: true,
|
|
}
|
|
tlsc, err := GenTLSConfig(tc)
|
|
if err != nil {
|
|
t.Fatalf("Error creating tls config: %v", err)
|
|
}
|
|
o.Websocket.TLSConfig = tlsc
|
|
o.Websocket.TLSMap = true
|
|
s := RunServer(o)
|
|
defer s.Shutdown()
|
|
|
|
addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port)
|
|
wsc, err := net.Dial("tcp", addr)
|
|
if err != nil {
|
|
t.Fatalf("Error creating ws connection: %v", err)
|
|
}
|
|
defer wsc.Close()
|
|
tlscc := &tls.Config{}
|
|
if test.provideCert {
|
|
tc := &TLSConfigOpts{
|
|
CertFile: "../test/configs/certs/tlsauth/client.pem",
|
|
KeyFile: "../test/configs/certs/tlsauth/client-key.pem",
|
|
}
|
|
var err error
|
|
tlscc, err = GenTLSConfig(tc)
|
|
if err != nil {
|
|
t.Fatalf("Error generating tls config: %v", err)
|
|
}
|
|
}
|
|
tlscc.InsecureSkipVerify = true
|
|
wsc = tls.Client(wsc, tlscc)
|
|
if err := wsc.(*tls.Conn).Handshake(); err != nil {
|
|
t.Fatalf("Error during handshake: %v", err)
|
|
}
|
|
req := testWSCreateValidReq()
|
|
req.URL, _ = url.Parse("wss://" + addr)
|
|
if err := req.Write(wsc); err != nil {
|
|
t.Fatalf("Error sending request: %v", err)
|
|
}
|
|
br := bufio.NewReader(wsc)
|
|
resp, err := http.ReadResponse(br, req)
|
|
if resp != nil {
|
|
resp.Body.Close()
|
|
}
|
|
if !test.provideCert {
|
|
if err == nil {
|
|
t.Fatal("Expected error, did not get one")
|
|
} else if !strings.Contains(err.Error(), "bad certificate") {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
return
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
if resp.StatusCode != http.StatusSwitchingProtocols {
|
|
t.Fatalf("Expected status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode)
|
|
}
|
|
// Wait for the INFO
|
|
l := testWSReadFrame(t, br)
|
|
if !bytes.HasPrefix(l, []byte("INFO {")) {
|
|
t.Fatalf("Expected INFO, got %s", l)
|
|
}
|
|
var info serverInfo
|
|
if err := json.Unmarshal(l[5:], &info); err != nil {
|
|
t.Fatalf("Unable to unmarshal info: %v", err)
|
|
}
|
|
// Send CONNECT and PING
|
|
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n"))
|
|
if _, err := wsc.Write(wsmsg); err != nil {
|
|
t.Fatalf("Error sending message: %v", err)
|
|
}
|
|
// Wait for the PONG
|
|
if msg := testWSReadFrame(t, br); !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
|
|
t.Fatalf("Expected PONG, got %s", msg)
|
|
}
|
|
|
|
var uname string
|
|
var accname string
|
|
c := s.getClient(info.CID)
|
|
if c != nil {
|
|
c.mu.Lock()
|
|
uname = c.opts.Username
|
|
if c.acc != nil {
|
|
accname = c.acc.GetName()
|
|
}
|
|
c.mu.Unlock()
|
|
}
|
|
if uname != certUserName {
|
|
t.Fatalf("Expected username %q, got %q", certUserName, uname)
|
|
}
|
|
if accname != accName {
|
|
t.Fatalf("Expected account %q, got %v", accName, accname)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWSHandshakeTimeout(t *testing.T) {
|
|
o := testWSOptions()
|
|
o.Websocket.HandshakeTimeout = time.Millisecond
|
|
tc := &TLSConfigOpts{
|
|
CertFile: "./configs/certs/server.pem",
|
|
KeyFile: "./configs/certs/key.pem",
|
|
}
|
|
o.Websocket.TLSConfig, _ = GenTLSConfig(tc)
|
|
s := RunServer(o)
|
|
defer s.Shutdown()
|
|
|
|
logger := &captureErrorLogger{errCh: make(chan string, 1)}
|
|
s.SetLogger(logger, false, false)
|
|
|
|
addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port)
|
|
wsc, err := net.Dial("tcp", addr)
|
|
if err != nil {
|
|
t.Fatalf("Error creating ws connection: %v", err)
|
|
}
|
|
defer wsc.Close()
|
|
|
|
// Delay the handshake
|
|
wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true})
|
|
time.Sleep(20 * time.Millisecond)
|
|
// We expect error since the server should have cut us off
|
|
if err := wsc.(*tls.Conn).Handshake(); err == nil {
|
|
t.Fatal("Expected error during handshake")
|
|
}
|
|
|
|
// Check that server logs error
|
|
select {
|
|
case e := <-logger.errCh:
|
|
if !strings.Contains(e, "timeout") {
|
|
t.Fatalf("Unexpected error: %v", e)
|
|
}
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("Should have timed-out")
|
|
}
|
|
}
|
|
|
|
func TestWSServerReportUpgradeFailure(t *testing.T) {
|
|
o := testWSOptions()
|
|
s := RunServer(o)
|
|
defer s.Shutdown()
|
|
|
|
logger := &captureErrorLogger{errCh: make(chan string, 1)}
|
|
s.SetLogger(logger, false, false)
|
|
|
|
addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port)
|
|
req := testWSCreateValidReq()
|
|
req.URL, _ = url.Parse("wss://" + addr)
|
|
|
|
wsc, err := net.Dial("tcp", addr)
|
|
if err != nil {
|
|
t.Fatalf("Error creating ws connection: %v", err)
|
|
}
|
|
defer wsc.Close()
|
|
wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true})
|
|
if err := wsc.(*tls.Conn).Handshake(); err != nil {
|
|
t.Fatalf("Error during handshake: %v", err)
|
|
}
|
|
// Remove a required field from the request to have it fail
|
|
req.Header.Del("Connection")
|
|
// Send the request
|
|
if err := req.Write(wsc); err != nil {
|
|
t.Fatalf("Error sending request: %v", err)
|
|
}
|
|
br := bufio.NewReader(wsc)
|
|
resp, err := http.ReadResponse(br, req)
|
|
if err != nil {
|
|
t.Fatalf("Error reading response: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Fatalf("Expected status %v, got %v", http.StatusBadRequest, resp.StatusCode)
|
|
}
|
|
|
|
// Check that server logs error
|
|
select {
|
|
case e := <-logger.errCh:
|
|
if !strings.Contains(e, "invalid value for header 'Connection'") {
|
|
t.Fatalf("Unexpected error: %v", e)
|
|
}
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("Should have timed-out")
|
|
}
|
|
}
|
|
|
|
func TestWSCloseMsgSendOnConnectionClose(t *testing.T) {
|
|
o := testWSOptions()
|
|
s := RunServer(o)
|
|
defer s.Shutdown()
|
|
|
|
wsc, br := testWSCreateClient(t, false, false, o.Websocket.Host, o.Websocket.Port)
|
|
defer wsc.Close()
|
|
|
|
checkClientsCount(t, s, 1)
|
|
var c *client
|
|
s.mu.Lock()
|
|
for _, cli := range s.clients {
|
|
c = cli
|
|
break
|
|
}
|
|
s.mu.Unlock()
|
|
|
|
c.closeConnection(ProtocolViolation)
|
|
msg := testWSReadFrame(t, br)
|
|
if len(msg) < 2 {
|
|
t.Fatalf("Should have 2 bytes to represent the status, got %v", msg)
|
|
}
|
|
if sc := int(binary.BigEndian.Uint16(msg[:2])); sc != wsCloseStatusProtocolError {
|
|
t.Fatalf("Expected status to be %v, got %v", wsCloseStatusProtocolError, sc)
|
|
}
|
|
expectedPayload := ProtocolViolation.String()
|
|
if p := string(msg[2:]); p != expectedPayload {
|
|
t.Fatalf("Expected payload to be %q, got %q", expectedPayload, p)
|
|
}
|
|
}
|
|
|
|
func TestWSAdvertise(t *testing.T) {
|
|
o := testWSOptions()
|
|
o.Cluster.Port = 0
|
|
o.HTTPPort = 0
|
|
o.Websocket.Advertise = "xxx:host:yyy"
|
|
s, err := NewServer(o)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
defer s.Shutdown()
|
|
l := &captureFatalLogger{fatalCh: make(chan string, 1)}
|
|
s.SetLogger(l, false, false)
|
|
go s.Start()
|
|
select {
|
|
case e := <-l.fatalCh:
|
|
if !strings.Contains(e, "Unable to get websocket connect URLs") {
|
|
t.Fatalf("Unexpected error: %q", e)
|
|
}
|
|
case <-time.After(time.Second):
|
|
t.Fatal("Should have failed to start")
|
|
}
|
|
s.Shutdown()
|
|
|
|
o1 := testWSOptions()
|
|
o1.Websocket.Advertise = "host1:1234"
|
|
s1 := RunServer(o1)
|
|
defer s1.Shutdown()
|
|
|
|
wsc, br := testWSCreateClient(t, false, false, o1.Websocket.Host, o1.Websocket.Port)
|
|
defer wsc.Close()
|
|
|
|
o2 := testWSOptions()
|
|
o2.Websocket.Advertise = "host2:5678"
|
|
o2.Routes = RoutesFromStr(fmt.Sprintf("nats://%s:%d", o1.Cluster.Host, o1.Cluster.Port))
|
|
s2 := RunServer(o2)
|
|
defer s2.Shutdown()
|
|
|
|
checkInfo := func(expected []string) {
|
|
t.Helper()
|
|
infob := testWSReadFrame(t, br)
|
|
info := &Info{}
|
|
json.Unmarshal(infob[5:], info)
|
|
if n := len(info.ClientConnectURLs); n != len(expected) {
|
|
t.Fatalf("Unexpected info: %+v", info)
|
|
}
|
|
good := 0
|
|
for _, u := range info.ClientConnectURLs {
|
|
for _, eu := range expected {
|
|
if u == eu {
|
|
good++
|
|
}
|
|
}
|
|
}
|
|
if good != len(expected) {
|
|
t.Fatalf("Unexpected connect urls: %q", info.ClientConnectURLs)
|
|
}
|
|
}
|
|
checkInfo([]string{"host1:1234", "host2:5678"})
|
|
|
|
// Now shutdown s2 and expect another INFO
|
|
s2.Shutdown()
|
|
checkInfo([]string{"host1:1234"})
|
|
|
|
// Restart with another advertise and check that it gets updated
|
|
o2.Websocket.Advertise = "host3:9012"
|
|
s2 = RunServer(o2)
|
|
defer s2.Shutdown()
|
|
checkInfo([]string{"host1:1234", "host3:9012"})
|
|
}
|
|
|
|
func TestWSFrameOutbound(t *testing.T) {
|
|
c, _, _ := testWSSetupForRead()
|
|
|
|
var bufs net.Buffers
|
|
bufs = append(bufs, []byte("this "))
|
|
bufs = append(bufs, []byte("is "))
|
|
bufs = append(bufs, []byte("a "))
|
|
bufs = append(bufs, []byte("set "))
|
|
bufs = append(bufs, []byte("of "))
|
|
bufs = append(bufs, []byte("buffers"))
|
|
en := 2
|
|
for _, b := range bufs {
|
|
en += len(b)
|
|
}
|
|
c.mu.Lock()
|
|
c.out.nb = bufs
|
|
res, n := c.collapsePtoNB()
|
|
c.mu.Unlock()
|
|
if n != int64(en) {
|
|
t.Fatalf("Expected size to be %v, got %v", en, n)
|
|
}
|
|
if eb := 1 + len(bufs); eb != len(res) {
|
|
t.Fatalf("Expected %v buffers, got %v", eb, len(res))
|
|
}
|
|
var ob []byte
|
|
for i := 1; i < len(res); i++ {
|
|
ob = append(ob, res[i]...)
|
|
}
|
|
if !bytes.Equal(ob, []byte("this is a set of buffers")) {
|
|
t.Fatalf("Unexpected outbound: %q", ob)
|
|
}
|
|
|
|
bufs = nil
|
|
c.out.pb = 0
|
|
c.ws.fs = 0
|
|
c.ws.frames = nil
|
|
c.ws.browser = true
|
|
bufs = append(bufs, []byte("some smaller "))
|
|
bufs = append(bufs, []byte("buffers"))
|
|
bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers+10))
|
|
bufs = append(bufs, []byte("then some more"))
|
|
en = 2 + len(bufs[0]) + len(bufs[1])
|
|
en += 4 + len(bufs[2]) - 10
|
|
en += 2 + len(bufs[3]) + 10
|
|
c.mu.Lock()
|
|
c.out.nb = bufs
|
|
res, n = c.collapsePtoNB()
|
|
c.mu.Unlock()
|
|
if n != int64(en) {
|
|
t.Fatalf("Expected size to be %v, got %v", en, n)
|
|
}
|
|
if len(res) != 8 {
|
|
t.Fatalf("Unexpected number of outbound buffers: %v", len(res))
|
|
}
|
|
if len(res[4]) != wsFrameSizeForBrowsers {
|
|
t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4]))
|
|
}
|
|
if len(res[6]) != 10 {
|
|
t.Fatalf("Frame 6 should have the partial of 10 bytes, got %v", len(res[6]))
|
|
}
|
|
|
|
bufs = nil
|
|
c.out.pb = 0
|
|
c.ws.fs = 0
|
|
c.ws.frames = nil
|
|
c.ws.browser = true
|
|
bufs = append(bufs, []byte("some smaller "))
|
|
bufs = append(bufs, []byte("buffers"))
|
|
// Have one of the exact max size
|
|
bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers))
|
|
bufs = append(bufs, []byte("then some more"))
|
|
en = 2 + len(bufs[0]) + len(bufs[1])
|
|
en += 4 + len(bufs[2])
|
|
en += 2 + len(bufs[3])
|
|
c.mu.Lock()
|
|
c.out.nb = bufs
|
|
res, n = c.collapsePtoNB()
|
|
c.mu.Unlock()
|
|
if n != int64(en) {
|
|
t.Fatalf("Expected size to be %v, got %v", en, n)
|
|
}
|
|
if len(res) != 7 {
|
|
t.Fatalf("Unexpected number of outbound buffers: %v", len(res))
|
|
}
|
|
if len(res[4]) != wsFrameSizeForBrowsers {
|
|
t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4]))
|
|
}
|
|
if string(res[6]) != string(bufs[3]) {
|
|
t.Fatalf("Frame 6 should be %q, got %q", bufs[3], res[6])
|
|
}
|
|
|
|
bufs = nil
|
|
c.out.pb = 0
|
|
c.ws.fs = 0
|
|
c.ws.frames = nil
|
|
c.ws.browser = true
|
|
bufs = append(bufs, []byte("some smaller "))
|
|
bufs = append(bufs, []byte("buffers"))
|
|
// Have one of the exact max size, and last in the list
|
|
bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers))
|
|
en = 2 + len(bufs[0]) + len(bufs[1])
|
|
en += 4 + len(bufs[2])
|
|
c.mu.Lock()
|
|
c.out.nb = bufs
|
|
res, n = c.collapsePtoNB()
|
|
c.mu.Unlock()
|
|
if n != int64(en) {
|
|
t.Fatalf("Expected size to be %v, got %v", en, n)
|
|
}
|
|
if len(res) != 5 {
|
|
t.Fatalf("Unexpected number of outbound buffers: %v", len(res))
|
|
}
|
|
if len(res[4]) != wsFrameSizeForBrowsers {
|
|
t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4]))
|
|
}
|
|
|
|
bufs = nil
|
|
c.out.pb = 0
|
|
c.ws.fs = 0
|
|
c.ws.frames = nil
|
|
c.ws.browser = true
|
|
bufs = append(bufs, []byte("some smaller buffer"))
|
|
bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers-5))
|
|
bufs = append(bufs, []byte("then some more"))
|
|
en = 2 + len(bufs[0])
|
|
en += 4 + len(bufs[1])
|
|
en += 2 + len(bufs[2])
|
|
c.mu.Lock()
|
|
c.out.nb = bufs
|
|
res, n = c.collapsePtoNB()
|
|
c.mu.Unlock()
|
|
if n != int64(en) {
|
|
t.Fatalf("Expected size to be %v, got %v", en, n)
|
|
}
|
|
if len(res) != 6 {
|
|
t.Fatalf("Unexpected number of outbound buffers: %v", len(res))
|
|
}
|
|
if len(res[3]) != wsFrameSizeForBrowsers-5 {
|
|
t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4]))
|
|
}
|
|
if string(res[5]) != string(bufs[2]) {
|
|
t.Fatalf("Frame 6 should be %q, got %q", bufs[2], res[5])
|
|
}
|
|
}
|
|
|
|
func TestWSWebrowserClient(t *testing.T) {
|
|
o := testWSOptions()
|
|
s := RunServer(o)
|
|
defer s.Shutdown()
|
|
|
|
wsc, br := testWSCreateClient(t, false, true, o.Websocket.Host, o.Websocket.Port)
|
|
defer wsc.Close()
|
|
|
|
checkClientsCount(t, s, 1)
|
|
var c *client
|
|
s.mu.Lock()
|
|
for _, cli := range s.clients {
|
|
c = cli
|
|
break
|
|
}
|
|
s.mu.Unlock()
|
|
|
|
proto := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("SUB foo 1\r\nPING\r\n"))
|
|
wsc.Write(proto)
|
|
if res := testWSReadFrame(t, br); !bytes.Equal(res, []byte(pongProto)) {
|
|
t.Fatalf("Expected PONG back")
|
|
}
|
|
|
|
c.mu.Lock()
|
|
ok := c.isWebsocket() && c.ws.browser == true
|
|
c.mu.Unlock()
|
|
if !ok {
|
|
t.Fatalf("Client is not marked as webrowser client")
|
|
}
|
|
|
|
nc := natsConnect(t, s.ClientURL())
|
|
defer nc.Close()
|
|
|
|
// Send a big message and check that it is received in smaller frames
|
|
psize := 204813
|
|
nc.Publish("foo", make([]byte, psize))
|
|
nc.Flush()
|
|
|
|
rsize := psize + len(fmt.Sprintf("MSG foo %d\r\n\r\n", psize))
|
|
nframes := 0
|
|
for total := 0; total < rsize; nframes++ {
|
|
res := testWSReadFrame(t, br)
|
|
total += len(res)
|
|
}
|
|
if expected := psize / wsFrameSizeForBrowsers; expected > nframes {
|
|
t.Fatalf("Expected %v frames, got %v", expected, nframes)
|
|
}
|
|
}
|
|
|
|
type testWSWrappedConn struct {
|
|
net.Conn
|
|
mu sync.RWMutex
|
|
buf *bytes.Buffer
|
|
partial bool
|
|
}
|
|
|
|
func (wc *testWSWrappedConn) Write(p []byte) (int, error) {
|
|
wc.mu.Lock()
|
|
defer wc.mu.Unlock()
|
|
var err error
|
|
n := len(p)
|
|
if wc.partial && n > 10 {
|
|
n = 10
|
|
err = io.ErrShortWrite
|
|
}
|
|
p = p[:n]
|
|
wc.buf.Write(p)
|
|
wc.Conn.Write(p)
|
|
return n, err
|
|
}
|
|
|
|
func TestWSCompressionBasic(t *testing.T) {
|
|
payload := "This is the content of a message that will be compresseddddddddddddddddddddd."
|
|
msgProto := fmt.Sprintf("MSG foo 1 %d\r\n%s\r\n", len(payload), payload)
|
|
|
|
cbuf := &bytes.Buffer{}
|
|
compressor, _ := flate.NewWriter(cbuf, flate.BestSpeed)
|
|
compressor.Write([]byte(msgProto))
|
|
compressor.Close()
|
|
compressed := cbuf.Bytes()
|
|
// The last 4 bytes are dropped
|
|
compressed = compressed[:len(compressed)-4]
|
|
|
|
o := testWSOptions()
|
|
o.Websocket.Compression = true
|
|
s := RunServer(o)
|
|
defer s.Shutdown()
|
|
|
|
c, br := testWSCreateClient(t, true, false, o.Websocket.Host, o.Websocket.Port)
|
|
defer c.Close()
|
|
|
|
proto := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, []byte("SUB foo 1\r\nPING\r\n"))
|
|
c.Write(proto)
|
|
l := testWSReadFrame(t, br)
|
|
if !bytes.Equal(l, []byte(pongProto)) {
|
|
t.Fatalf("Expected PONG, got %q", l)
|
|
}
|
|
|
|
var wc *testWSWrappedConn
|
|
s.mu.Lock()
|
|
for _, c := range s.clients {
|
|
c.mu.Lock()
|
|
wc = &testWSWrappedConn{Conn: c.nc, buf: &bytes.Buffer{}}
|
|
c.nc = wc
|
|
c.mu.Unlock()
|
|
}
|
|
s.mu.Unlock()
|
|
|
|
nc := natsConnect(t, s.ClientURL())
|
|
defer nc.Close()
|
|
natsPub(t, nc, "foo", []byte(payload))
|
|
|
|
res := &bytes.Buffer{}
|
|
for total := 0; total < len(msgProto); {
|
|
l := testWSReadFrame(t, br)
|
|
n, _ := res.Write(l)
|
|
total += n
|
|
}
|
|
if !bytes.Equal([]byte(msgProto), res.Bytes()) {
|
|
t.Fatalf("Unexpected result: %q", res)
|
|
}
|
|
|
|
// Now check the wrapped connection buffer to check that data was actually compressed.
|
|
wc.mu.RLock()
|
|
res = wc.buf
|
|
wc.mu.RUnlock()
|
|
if bytes.Contains(res.Bytes(), []byte(payload)) {
|
|
t.Fatalf("Looks like frame was not compressed: %q", res.Bytes())
|
|
}
|
|
header := res.Bytes()[:2]
|
|
body := res.Bytes()[2:]
|
|
expectedB0 := byte(wsBinaryMessage) | wsFinalBit | wsRsv1Bit
|
|
expectedPS := len(compressed)
|
|
expectedB1 := byte(expectedPS)
|
|
|
|
if b := header[0]; b != expectedB0 {
|
|
t.Fatalf("Expected first byte to be %v, got %v", expectedB0, b)
|
|
}
|
|
if b := header[1]; b != expectedB1 {
|
|
t.Fatalf("Expected second byte to be %v, got %v", expectedB1, b)
|
|
}
|
|
if len(body) != expectedPS {
|
|
t.Fatalf("Expected payload length to be %v, got %v", expectedPS, len(body))
|
|
}
|
|
if !bytes.Equal(body, compressed) {
|
|
t.Fatalf("Unexpected compress body: %q", body)
|
|
}
|
|
}
|
|
|
|
func TestWSCompressionWithPartialWrite(t *testing.T) {
|
|
payload := "This is the content of a message that will be compresseddddddddddddddddddddd."
|
|
msgProto := fmt.Sprintf("MSG foo 1 %d\r\n%s\r\n", len(payload), payload)
|
|
|
|
o := testWSOptions()
|
|
o.Websocket.Compression = true
|
|
s := RunServer(o)
|
|
defer s.Shutdown()
|
|
|
|
c, br := testWSCreateClient(t, true, false, o.Websocket.Host, o.Websocket.Port)
|
|
defer c.Close()
|
|
|
|
proto := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, []byte("SUB foo 1\r\nPING\r\n"))
|
|
c.Write(proto)
|
|
l := testWSReadFrame(t, br)
|
|
if !bytes.Equal(l, []byte(pongProto)) {
|
|
t.Fatalf("Expected PONG, got %q", l)
|
|
}
|
|
|
|
pingPayload := []byte("my ping")
|
|
pingFromWSClient := testWSCreateClientMsg(wsPingMessage, 1, true, false, pingPayload)
|
|
|
|
var wc *testWSWrappedConn
|
|
var ws *client
|
|
s.mu.Lock()
|
|
for _, c := range s.clients {
|
|
ws = c
|
|
c.mu.Lock()
|
|
wc = &testWSWrappedConn{
|
|
Conn: c.nc,
|
|
buf: &bytes.Buffer{},
|
|
}
|
|
c.nc = wc
|
|
c.mu.Unlock()
|
|
break
|
|
}
|
|
s.mu.Unlock()
|
|
|
|
wc.mu.Lock()
|
|
wc.partial = true
|
|
wc.mu.Unlock()
|
|
|
|
nc := natsConnect(t, s.ClientURL())
|
|
defer nc.Close()
|
|
|
|
expected := &bytes.Buffer{}
|
|
for i := 0; i < 10; i++ {
|
|
if i > 0 {
|
|
time.Sleep(10 * time.Millisecond)
|
|
}
|
|
expected.Write([]byte(msgProto))
|
|
natsPub(t, nc, "foo", []byte(payload))
|
|
if i == 1 {
|
|
c.Write(pingFromWSClient)
|
|
}
|
|
}
|
|
|
|
var gotPingResponse bool
|
|
res := &bytes.Buffer{}
|
|
for total := 0; total < 10*len(msgProto); {
|
|
l := testWSReadFrame(t, br)
|
|
if bytes.Equal(l, pingPayload) {
|
|
gotPingResponse = true
|
|
} else {
|
|
n, _ := res.Write(l)
|
|
total += n
|
|
}
|
|
}
|
|
if !bytes.Equal(expected.Bytes(), res.Bytes()) {
|
|
t.Fatalf("Unexpected result: %q", res)
|
|
}
|
|
if !gotPingResponse {
|
|
t.Fatal("Did not get the ping response")
|
|
}
|
|
|
|
checkFor(t, time.Second, 15*time.Millisecond, func() error {
|
|
ws.mu.Lock()
|
|
pb := ws.out.pb
|
|
wf := ws.ws.frames
|
|
fs := ws.ws.fs
|
|
ws.mu.Unlock()
|
|
if pb != 0 || len(wf) != 0 || fs != 0 {
|
|
return fmt.Errorf("Expected pb, wf and fs to be 0, got %v, %v, %v", pb, wf, fs)
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func TestWSCompressionFrameSizeLimit(t *testing.T) {
|
|
opts := testWSOptions()
|
|
opts.MaxPending = MAX_PENDING_SIZE
|
|
s := &Server{opts: opts}
|
|
c := &client{srv: s, ws: &websocket{compress: true, browser: true}}
|
|
c.initClient()
|
|
|
|
// uncompressedPayload := []byte("abcdefghijklmnopqrstuvwxyz")
|
|
uncompressedPayload := make([]byte, 2*wsFrameSizeForBrowsers)
|
|
for i := 0; i < len(uncompressedPayload); i++ {
|
|
uncompressedPayload[i] = byte(rand.Intn(256))
|
|
}
|
|
|
|
c.mu.Lock()
|
|
c.out.nb = append(net.Buffers(nil), uncompressedPayload)
|
|
nb, _ := c.collapsePtoNB()
|
|
c.mu.Unlock()
|
|
|
|
bb := &bytes.Buffer{}
|
|
for i, b := range nb {
|
|
// frame header buffer are always very small. The payload should not be more
|
|
// than 10 bytes since that is what we passed as the limit.
|
|
if len(b) > wsFrameSizeForBrowsers {
|
|
t.Fatalf("Frame size too big: %v (%q)", len(b), b)
|
|
}
|
|
// Check frame headers for the proper formatting.
|
|
if i%2 == 1 {
|
|
bb.Write(b)
|
|
}
|
|
}
|
|
buf := bb.Bytes()
|
|
buf = append(buf, 0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff)
|
|
dbr := bytes.NewBuffer(buf)
|
|
d := flate.NewReader(dbr)
|
|
uncompressed, err := ioutil.ReadAll(d)
|
|
if err != nil {
|
|
t.Fatalf("Error reading frame: %v", err)
|
|
}
|
|
if !bytes.Equal(uncompressed, uncompressedPayload) {
|
|
t.Fatalf("Unexpected uncomressed data: %q", uncompressed)
|
|
}
|
|
}
|
|
|
|
func TestWSBasicAuth(t *testing.T) {
|
|
for _, test := range []struct {
|
|
name string
|
|
opts func() *Options
|
|
user string
|
|
pass string
|
|
err string
|
|
}{
|
|
{
|
|
"top level auth, no override, wrong u/p",
|
|
func() *Options {
|
|
o := testWSOptions()
|
|
o.Username = "normal"
|
|
o.Password = "client"
|
|
return o
|
|
},
|
|
"websocket", "client", "-ERR 'Authorization Violation'",
|
|
},
|
|
{
|
|
"top level auth, no override, correct u/p",
|
|
func() *Options {
|
|
o := testWSOptions()
|
|
o.Username = "normal"
|
|
o.Password = "client"
|
|
return o
|
|
},
|
|
"normal", "client", "",
|
|
},
|
|
{
|
|
"no top level auth, ws auth, wrong u/p",
|
|
func() *Options {
|
|
o := testWSOptions()
|
|
o.Websocket.Username = "websocket"
|
|
o.Websocket.Password = "client"
|
|
return o
|
|
},
|
|
"normal", "client", "-ERR 'Authorization Violation'",
|
|
},
|
|
{
|
|
"no top level auth, ws auth, correct u/p",
|
|
func() *Options {
|
|
o := testWSOptions()
|
|
o.Websocket.Username = "websocket"
|
|
o.Websocket.Password = "client"
|
|
return o
|
|
},
|
|
"websocket", "client", "",
|
|
},
|
|
{
|
|
"top level auth, ws override, wrong u/p",
|
|
func() *Options {
|
|
o := testWSOptions()
|
|
o.Username = "normal"
|
|
o.Password = "client"
|
|
o.Websocket.Username = "websocket"
|
|
o.Websocket.Password = "client"
|
|
return o
|
|
},
|
|
"normal", "client", "-ERR 'Authorization Violation'",
|
|
},
|
|
{
|
|
"top level auth, ws override, correct u/p",
|
|
func() *Options {
|
|
o := testWSOptions()
|
|
o.Username = "normal"
|
|
o.Password = "client"
|
|
o.Websocket.Username = "websocket"
|
|
o.Websocket.Password = "client"
|
|
return o
|
|
},
|
|
"websocket", "client", "",
|
|
},
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
o := test.opts()
|
|
s := RunServer(o)
|
|
defer s.Shutdown()
|
|
|
|
wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
|
|
defer wsc.Close()
|
|
|
|
connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n",
|
|
test.user, test.pass)
|
|
|
|
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
|
|
if _, err := wsc.Write(wsmsg); err != nil {
|
|
t.Fatalf("Error sending message: %v", err)
|
|
}
|
|
msg := testWSReadFrame(t, br)
|
|
if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
|
|
t.Fatalf("Expected to receive PONG, got %q", msg)
|
|
} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
|
|
t.Fatalf("Expected to receive %q, got %q", test.err, msg)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWSAuthTimeout(t *testing.T) {
|
|
for _, test := range []struct {
|
|
name string
|
|
at float64
|
|
wat float64
|
|
err string
|
|
}{
|
|
{"use top-level auth timeout", 10.0, 0.0, ""},
|
|
{"use websocket auth timeout", 10.0, 0.05, "-ERR 'Authentication Timeout'"},
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
o := testWSOptions()
|
|
o.AuthTimeout = test.at
|
|
o.Websocket.Username = "websocket"
|
|
o.Websocket.Password = "client"
|
|
o.Websocket.AuthTimeout = test.wat
|
|
s := RunServer(o)
|
|
defer s.Shutdown()
|
|
|
|
wsc, br, l := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
|
|
defer wsc.Close()
|
|
|
|
var info serverInfo
|
|
json.Unmarshal([]byte(l[5:]), &info)
|
|
// Make sure that we are told that auth is required.
|
|
if !info.AuthRequired {
|
|
t.Fatalf("Expected auth required, was not: %q", l)
|
|
}
|
|
start := time.Now()
|
|
// Wait before sending connect
|
|
time.Sleep(100 * time.Millisecond)
|
|
connectProto := "CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"websocket\",\"pass\":\"client\"}\r\nPING\r\n"
|
|
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
|
|
if _, err := wsc.Write(wsmsg); err != nil {
|
|
t.Fatalf("Error sending message: %v", err)
|
|
}
|
|
msg := testWSReadFrame(t, br)
|
|
if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
|
|
t.Fatalf("Expected to receive %q error, got %q", test.err, msg)
|
|
} else if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
|
|
t.Fatalf("Unexpected error: %q", msg)
|
|
}
|
|
if dur := time.Since(start); dur > time.Second {
|
|
t.Fatalf("Too long to get timeout error: %v", dur)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWSTokenAuth(t *testing.T) {
|
|
for _, test := range []struct {
|
|
name string
|
|
opts func() *Options
|
|
token string
|
|
err string
|
|
}{
|
|
{
|
|
"top level auth, no override, wrong token",
|
|
func() *Options {
|
|
o := testWSOptions()
|
|
o.Authorization = "goodtoken"
|
|
return o
|
|
},
|
|
"badtoken", "-ERR 'Authorization Violation'",
|
|
},
|
|
{
|
|
"top level auth, no override, correct token",
|
|
func() *Options {
|
|
o := testWSOptions()
|
|
o.Authorization = "goodtoken"
|
|
return o
|
|
},
|
|
"goodtoken", "",
|
|
},
|
|
{
|
|
"no top level auth, ws auth, wrong token",
|
|
func() *Options {
|
|
o := testWSOptions()
|
|
o.Websocket.Token = "goodtoken"
|
|
return o
|
|
},
|
|
"badtoken", "-ERR 'Authorization Violation'",
|
|
},
|
|
{
|
|
"no top level auth, ws auth, correct token",
|
|
func() *Options {
|
|
o := testWSOptions()
|
|
o.Websocket.Token = "goodtoken"
|
|
return o
|
|
},
|
|
"goodtoken", "",
|
|
},
|
|
{
|
|
"top level auth, ws override, wrong token",
|
|
func() *Options {
|
|
o := testWSOptions()
|
|
o.Authorization = "clienttoken"
|
|
o.Websocket.Token = "websockettoken"
|
|
return o
|
|
},
|
|
"clienttoken", "-ERR 'Authorization Violation'",
|
|
},
|
|
{
|
|
"top level auth, ws override, correct token",
|
|
func() *Options {
|
|
o := testWSOptions()
|
|
o.Authorization = "clienttoken"
|
|
o.Websocket.Token = "websockettoken"
|
|
return o
|
|
},
|
|
"websockettoken", "",
|
|
},
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
o := test.opts()
|
|
s := RunServer(o)
|
|
defer s.Shutdown()
|
|
|
|
wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
|
|
defer wsc.Close()
|
|
|
|
connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"auth_token\":\"%s\"}\r\nPING\r\n",
|
|
test.token)
|
|
|
|
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
|
|
if _, err := wsc.Write(wsmsg); err != nil {
|
|
t.Fatalf("Error sending message: %v", err)
|
|
}
|
|
msg := testWSReadFrame(t, br)
|
|
if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
|
|
t.Fatalf("Expected to receive PONG, got %q", msg)
|
|
} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
|
|
t.Fatalf("Expected to receive %q, got %q", test.err, msg)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWSBindToProperAccount(t *testing.T) {
|
|
conf := createConfFile(t, []byte(fmt.Sprintf(`
|
|
listen: "127.0.0.1:-1"
|
|
accounts {
|
|
a {
|
|
users [
|
|
{user: a, password: pwd, allowed_connection_types: ["%s", "%s"]}
|
|
]
|
|
}
|
|
b {
|
|
users [
|
|
{user: b, password: pwd}
|
|
]
|
|
}
|
|
}
|
|
websocket {
|
|
listen: "127.0.0.1:-1"
|
|
no_tls: true
|
|
}
|
|
`, jwt.ConnectionTypeStandard, strings.ToLower(jwt.ConnectionTypeWebsocket)))) // on purpose use lower case to ensure that it is converted.
|
|
defer os.Remove(conf)
|
|
s, o := RunServerWithConfig(conf)
|
|
defer s.Shutdown()
|
|
|
|
nc := natsConnect(t, fmt.Sprintf("nats://a:pwd@127.0.0.1:%d", o.Port))
|
|
defer nc.Close()
|
|
|
|
sub := natsSubSync(t, nc, "foo")
|
|
|
|
wsc, br, _ := testNewWSClient(t, testWSClientOptions{host: o.Websocket.Host, port: o.Websocket.Port, noTLS: true})
|
|
// Send CONNECT and PING
|
|
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false,
|
|
[]byte(fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n", "a", "pwd")))
|
|
if _, err := wsc.Write(wsmsg); err != nil {
|
|
t.Fatalf("Error sending message: %v", err)
|
|
}
|
|
// Wait for the PONG
|
|
if msg := testWSReadFrame(t, br); !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
|
|
t.Fatalf("Expected PONG, got %s", msg)
|
|
}
|
|
|
|
wsmsg = testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("PUB foo 7\r\nfrom ws\r\n"))
|
|
if _, err := wsc.Write(wsmsg); err != nil {
|
|
t.Fatalf("Error sending message: %v", err)
|
|
}
|
|
|
|
natsNexMsg(t, sub, time.Second)
|
|
}
|
|
|
|
func TestWSUsersAuth(t *testing.T) {
|
|
users := []*User{&User{Username: "user", Password: "pwd"}}
|
|
for _, test := range []struct {
|
|
name string
|
|
opts func() *Options
|
|
user string
|
|
pass string
|
|
err string
|
|
}{
|
|
{
|
|
"no filtering, wrong user",
|
|
func() *Options {
|
|
o := testWSOptions()
|
|
o.Users = users
|
|
return o
|
|
},
|
|
"wronguser", "pwd", "-ERR 'Authorization Violation'",
|
|
},
|
|
{
|
|
"no filtering, correct user",
|
|
func() *Options {
|
|
o := testWSOptions()
|
|
o.Users = users
|
|
return o
|
|
},
|
|
"user", "pwd", "",
|
|
},
|
|
{
|
|
"filering, user not allowed",
|
|
func() *Options {
|
|
o := testWSOptions()
|
|
o.Users = users
|
|
// Only allowed for regular clients
|
|
o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard})
|
|
return o
|
|
},
|
|
"user", "pwd", "-ERR 'Authorization Violation'",
|
|
},
|
|
{
|
|
"filtering, user allowed",
|
|
func() *Options {
|
|
o := testWSOptions()
|
|
o.Users = users
|
|
o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket})
|
|
return o
|
|
},
|
|
"user", "pwd", "",
|
|
},
|
|
{
|
|
"filtering, wrong password",
|
|
func() *Options {
|
|
o := testWSOptions()
|
|
o.Users = users
|
|
o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket})
|
|
return o
|
|
},
|
|
"user", "badpassword", "-ERR 'Authorization Violation'",
|
|
},
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
o := test.opts()
|
|
s := RunServer(o)
|
|
defer s.Shutdown()
|
|
|
|
wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
|
|
defer wsc.Close()
|
|
|
|
connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n",
|
|
test.user, test.pass)
|
|
|
|
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
|
|
if _, err := wsc.Write(wsmsg); err != nil {
|
|
t.Fatalf("Error sending message: %v", err)
|
|
}
|
|
msg := testWSReadFrame(t, br)
|
|
if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
|
|
t.Fatalf("Expected to receive PONG, got %q", msg)
|
|
} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
|
|
t.Fatalf("Expected to receive %q, got %q", test.err, msg)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWSNoAuthUserValidation(t *testing.T) {
|
|
o := testWSOptions()
|
|
o.Users = []*User{&User{Username: "user", Password: "pwd"}}
|
|
// Should fail because it is not part of o.Users.
|
|
o.Websocket.NoAuthUser = "notfound"
|
|
if _, err := NewServer(o); err == nil || !strings.Contains(err.Error(), "not present as user") {
|
|
t.Fatalf("Expected error saying not present as user, got %v", err)
|
|
}
|
|
// Set a valid no auth user for global options, but still should fail because
|
|
// of o.Websocket.NoAuthUser
|
|
o.NoAuthUser = "user"
|
|
o.Websocket.NoAuthUser = "notfound"
|
|
if _, err := NewServer(o); err == nil || !strings.Contains(err.Error(), "not present as user") {
|
|
t.Fatalf("Expected error saying not present as user, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestWSNoAuthUser(t *testing.T) {
|
|
for _, test := range []struct {
|
|
name string
|
|
override bool
|
|
useAuth bool
|
|
expectedUser string
|
|
expectedAcc string
|
|
}{
|
|
{"no override, no user provided", false, false, "noauth", "normal"},
|
|
{"no override, user povided", false, true, "user", "normal"},
|
|
{"override, no user provided", true, false, "wsnoauth", "websocket"},
|
|
{"override, user provided", true, true, "wsuser", "websocket"},
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
o := testWSOptions()
|
|
normalAcc := NewAccount("normal")
|
|
websocketAcc := NewAccount("websocket")
|
|
o.Accounts = []*Account{normalAcc, websocketAcc}
|
|
o.Users = []*User{
|
|
&User{Username: "noauth", Password: "pwd", Account: normalAcc},
|
|
&User{Username: "user", Password: "pwd", Account: normalAcc},
|
|
&User{Username: "wsnoauth", Password: "pwd", Account: websocketAcc},
|
|
&User{Username: "wsuser", Password: "pwd", Account: websocketAcc},
|
|
}
|
|
o.NoAuthUser = "noauth"
|
|
if test.override {
|
|
o.Websocket.NoAuthUser = "wsnoauth"
|
|
}
|
|
s := RunServer(o)
|
|
defer s.Shutdown()
|
|
|
|
wsc, br, l := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
|
|
defer wsc.Close()
|
|
|
|
var info serverInfo
|
|
json.Unmarshal([]byte(l[5:]), &info)
|
|
|
|
var connectProto string
|
|
if test.useAuth {
|
|
connectProto = fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"pwd\"}\r\nPING\r\n",
|
|
test.expectedUser)
|
|
} else {
|
|
connectProto = "CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n"
|
|
}
|
|
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
|
|
if _, err := wsc.Write(wsmsg); err != nil {
|
|
t.Fatalf("Error sending message: %v", err)
|
|
}
|
|
msg := testWSReadFrame(t, br)
|
|
if !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
|
|
t.Fatalf("Unexpected error: %q", msg)
|
|
}
|
|
|
|
c := s.getClient(info.CID)
|
|
c.mu.Lock()
|
|
uname := c.opts.Username
|
|
aname := c.acc.GetName()
|
|
c.mu.Unlock()
|
|
if uname != test.expectedUser {
|
|
t.Fatalf("Expected selected user to be %q, got %q", test.expectedUser, uname)
|
|
}
|
|
if aname != test.expectedAcc {
|
|
t.Fatalf("Expected selected account to be %q, got %q", test.expectedAcc, aname)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWSNkeyAuth(t *testing.T) {
|
|
nkp, _ := nkeys.CreateUser()
|
|
pub, _ := nkp.PublicKey()
|
|
|
|
wsnkp, _ := nkeys.CreateUser()
|
|
wspub, _ := wsnkp.PublicKey()
|
|
|
|
badkp, _ := nkeys.CreateUser()
|
|
badpub, _ := badkp.PublicKey()
|
|
|
|
for _, test := range []struct {
|
|
name string
|
|
opts func() *Options
|
|
nkey string
|
|
kp nkeys.KeyPair
|
|
err string
|
|
}{
|
|
{
|
|
"no filtering, wrong nkey",
|
|
func() *Options {
|
|
o := testWSOptions()
|
|
o.Nkeys = []*NkeyUser{&NkeyUser{Nkey: pub}}
|
|
return o
|
|
},
|
|
badpub, badkp, "-ERR 'Authorization Violation'",
|
|
},
|
|
{
|
|
"no filtering, correct nkey",
|
|
func() *Options {
|
|
o := testWSOptions()
|
|
o.Nkeys = []*NkeyUser{&NkeyUser{Nkey: pub}}
|
|
return o
|
|
},
|
|
pub, nkp, "",
|
|
},
|
|
{
|
|
"filtering, nkey not allowed",
|
|
func() *Options {
|
|
o := testWSOptions()
|
|
o.Nkeys = []*NkeyUser{
|
|
&NkeyUser{
|
|
Nkey: pub,
|
|
AllowedConnectionTypes: testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard}),
|
|
},
|
|
&NkeyUser{
|
|
Nkey: wspub,
|
|
AllowedConnectionTypes: testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeWebsocket}),
|
|
},
|
|
}
|
|
return o
|
|
},
|
|
pub, nkp, "-ERR 'Authorization Violation'",
|
|
},
|
|
{
|
|
"filtering, correct nkey",
|
|
func() *Options {
|
|
o := testWSOptions()
|
|
o.Nkeys = []*NkeyUser{
|
|
&NkeyUser{Nkey: pub},
|
|
&NkeyUser{
|
|
Nkey: wspub,
|
|
AllowedConnectionTypes: testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket}),
|
|
},
|
|
}
|
|
return o
|
|
},
|
|
wspub, wsnkp, "",
|
|
},
|
|
{
|
|
"filtering, wrong nkey",
|
|
func() *Options {
|
|
o := testWSOptions()
|
|
o.Nkeys = []*NkeyUser{
|
|
&NkeyUser{
|
|
Nkey: wspub,
|
|
AllowedConnectionTypes: testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket}),
|
|
},
|
|
}
|
|
return o
|
|
},
|
|
badpub, badkp, "-ERR 'Authorization Violation'",
|
|
},
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
o := test.opts()
|
|
s := RunServer(o)
|
|
defer s.Shutdown()
|
|
|
|
wsc, br, infoMsg := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
|
|
defer wsc.Close()
|
|
|
|
// Sign Nonce
|
|
var info nonceInfo
|
|
json.Unmarshal([]byte(infoMsg[5:]), &info)
|
|
sigraw, _ := test.kp.Sign([]byte(info.Nonce))
|
|
sig := base64.RawURLEncoding.EncodeToString(sigraw)
|
|
|
|
connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"nkey\":\"%s\",\"sig\":\"%s\"}\r\nPING\r\n", test.nkey, sig)
|
|
|
|
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
|
|
if _, err := wsc.Write(wsmsg); err != nil {
|
|
t.Fatalf("Error sending message: %v", err)
|
|
}
|
|
msg := testWSReadFrame(t, br)
|
|
if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
|
|
t.Fatalf("Expected to receive PONG, got %q", msg)
|
|
} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
|
|
t.Fatalf("Expected to receive %q, got %q", test.err, msg)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWSJWTWithAllowedConnectionTypes(t *testing.T) {
|
|
o := testWSOptions()
|
|
setupAddTrusted(o)
|
|
s := RunServer(o)
|
|
buildMemAccResolver(s)
|
|
defer s.Shutdown()
|
|
|
|
for _, test := range []struct {
|
|
name string
|
|
connectionTypes []string
|
|
expectedAnswer string
|
|
}{
|
|
{"not allowed", []string{jwt.ConnectionTypeStandard}, "-ERR"},
|
|
{"allowed", []string{jwt.ConnectionTypeStandard, strings.ToLower(jwt.ConnectionTypeWebsocket)}, "+OK"},
|
|
{"allowed with unknown", []string{jwt.ConnectionTypeWebsocket, "SomeNewType"}, "+OK"},
|
|
{"not allowed with unknown", []string{"SomeNewType"}, "-ERR"},
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
nuc := newJWTTestUserClaims()
|
|
nuc.AllowedConnectionTypes = test.connectionTypes
|
|
claimOpt := testClaimsOptions{
|
|
nuc: nuc,
|
|
expectAnswer: test.expectedAnswer,
|
|
}
|
|
_, c, _, _ := testWSWithClaims(t, s, testWSClientOptions{host: o.Websocket.Host, port: o.Websocket.Port}, claimOpt)
|
|
c.Close()
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWSJWTCookieUser(t *testing.T) {
|
|
|
|
nucSigFunc := func() *jwt.UserClaims { return newJWTTestUserClaims() }
|
|
nucBearerFunc := func() *jwt.UserClaims {
|
|
ret := newJWTTestUserClaims()
|
|
ret.BearerToken = true
|
|
return ret
|
|
}
|
|
|
|
o := testWSOptions()
|
|
setupAddTrusted(o)
|
|
setupAddCookie(o)
|
|
s := RunServer(o)
|
|
buildMemAccResolver(s)
|
|
defer s.Shutdown()
|
|
|
|
genJwt := func(t *testing.T, nuc *jwt.UserClaims) string {
|
|
okp, _ := nkeys.FromSeed(oSeed)
|
|
|
|
akp, _ := nkeys.CreateAccount()
|
|
apub, _ := akp.PublicKey()
|
|
|
|
nac := jwt.NewAccountClaims(apub)
|
|
ajwt, err := nac.Encode(okp)
|
|
if err != nil {
|
|
t.Fatalf("Error generating account JWT: %v", err)
|
|
}
|
|
|
|
nkp, _ := nkeys.CreateUser()
|
|
pub, _ := nkp.PublicKey()
|
|
nuc.Subject = pub
|
|
jwt, err := nuc.Encode(akp)
|
|
if err != nil {
|
|
t.Fatalf("Error generating user JWT: %v", err)
|
|
}
|
|
addAccountToMemResolver(s, apub, ajwt)
|
|
return jwt
|
|
}
|
|
|
|
cliOpts := testWSClientOptions{
|
|
host: o.Websocket.Host,
|
|
port: o.Websocket.Port,
|
|
}
|
|
for _, test := range []struct {
|
|
name string
|
|
nuc *jwt.UserClaims
|
|
opts func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions)
|
|
expectAnswer string
|
|
}{
|
|
{
|
|
name: "protocol auth, non-bearer key, with signature",
|
|
nuc: nucSigFunc(),
|
|
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
|
|
return cliOpts, testClaimsOptions{nuc: claims}
|
|
},
|
|
expectAnswer: "+OK",
|
|
},
|
|
{
|
|
name: "protocol auth, non-bearer key, w/o required signature",
|
|
nuc: nucSigFunc(),
|
|
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
|
|
return cliOpts, testClaimsOptions{nuc: claims, dontSign: true}
|
|
},
|
|
expectAnswer: "-ERR",
|
|
},
|
|
{
|
|
name: "protocol auth, bearer key, w/o signature",
|
|
nuc: nucBearerFunc(),
|
|
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
|
|
return cliOpts, testClaimsOptions{nuc: claims, dontSign: true}
|
|
},
|
|
expectAnswer: "+OK",
|
|
},
|
|
{
|
|
name: "cookie auth, non-bearer key, protocol auth fail",
|
|
nuc: nucSigFunc(),
|
|
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
|
|
co := cliOpts
|
|
co.extraHeaders = map[string]string{}
|
|
co.extraHeaders["Cookie"] = o.Websocket.JWTCookie + "=" + genJwt(t, claims)
|
|
return co, testClaimsOptions{connectRequest: struct{}{}}
|
|
},
|
|
expectAnswer: "-ERR",
|
|
},
|
|
{
|
|
name: "cookie auth, bearer key, protocol auth success with implied cookie jwt",
|
|
nuc: nucBearerFunc(),
|
|
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
|
|
co := cliOpts
|
|
co.extraHeaders = map[string]string{}
|
|
co.extraHeaders["Cookie"] = o.Websocket.JWTCookie + "=" + genJwt(t, claims)
|
|
return co, testClaimsOptions{connectRequest: struct{}{}}
|
|
},
|
|
expectAnswer: "+OK",
|
|
},
|
|
{
|
|
name: "cookie auth, non-bearer key, protocol auth success via override jwt in CONNECT opts",
|
|
nuc: nucSigFunc(),
|
|
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
|
|
co := cliOpts
|
|
co.extraHeaders = map[string]string{}
|
|
co.extraHeaders["Cookie"] = o.Websocket.JWTCookie + "=" + genJwt(t, claims)
|
|
return co, testClaimsOptions{nuc: nucBearerFunc()}
|
|
},
|
|
expectAnswer: "+OK",
|
|
},
|
|
} {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
cliOpt, claimOpt := test.opts(t, test.nuc)
|
|
claimOpt.expectAnswer = test.expectAnswer
|
|
_, c, _, _ := testWSWithClaims(t, s, cliOpt, claimOpt)
|
|
c.Close()
|
|
})
|
|
}
|
|
s.Shutdown()
|
|
|
|
}
|
|
|
|
// ==================================================================
|
|
// = Benchmark tests
|
|
// ==================================================================
|
|
|
|
const testWSBenchSubject = "a"
|
|
|
|
var ch = []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@$#%^&*()")
|
|
|
|
func sizedString(sz int) string {
|
|
b := make([]byte, sz)
|
|
for i := range b {
|
|
b[i] = ch[rand.Intn(len(ch))]
|
|
}
|
|
return string(b)
|
|
}
|
|
|
|
func sizedStringForCompression(sz int) string {
|
|
b := make([]byte, sz)
|
|
c := byte(0)
|
|
s := 0
|
|
for i := range b {
|
|
if s%20 == 0 {
|
|
c = ch[rand.Intn(len(ch))]
|
|
}
|
|
b[i] = c
|
|
}
|
|
return string(b)
|
|
}
|
|
|
|
func testWSFlushConn(b *testing.B, compress bool, c net.Conn, br *bufio.Reader) {
|
|
buf := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, []byte(pingProto))
|
|
c.Write(buf)
|
|
c.SetReadDeadline(time.Now().Add(5 * time.Second))
|
|
res := testWSReadFrame(b, br)
|
|
c.SetReadDeadline(time.Time{})
|
|
if !bytes.HasPrefix(res, []byte(pongProto)) {
|
|
b.Fatalf("Failed read of PONG: %s\n", res)
|
|
}
|
|
}
|
|
|
|
func wsBenchPub(b *testing.B, numPubs int, compress bool, payload string) {
|
|
b.StopTimer()
|
|
opts := testWSOptions()
|
|
opts.Websocket.Compression = compress
|
|
s := RunServer(opts)
|
|
defer s.Shutdown()
|
|
|
|
n := b.N
|
|
extra := 0
|
|
pubProto := []byte(fmt.Sprintf("PUB %s %d\r\n%s\r\n", testWSBenchSubject, len(payload), payload))
|
|
singleOpBuf := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, pubProto)
|
|
|
|
// Simulate client that would buffer messages before framing/sending.
|
|
// Figure out how many we can fit in one frame based on b.N and length of pubProto
|
|
const bufSize = 32768
|
|
tmpa := [bufSize]byte{}
|
|
tmp := tmpa[:0]
|
|
pb := 0
|
|
for i := 0; i < b.N; i++ {
|
|
tmp = append(tmp, pubProto...)
|
|
pb++
|
|
if len(tmp) >= bufSize {
|
|
break
|
|
}
|
|
}
|
|
sendBuf := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, tmp)
|
|
n = b.N / pb
|
|
extra = b.N - (n * pb)
|
|
|
|
wg := sync.WaitGroup{}
|
|
wg.Add(numPubs)
|
|
|
|
type pub struct {
|
|
c net.Conn
|
|
br *bufio.Reader
|
|
bw *bufio.Writer
|
|
}
|
|
var pubs []pub
|
|
for i := 0; i < numPubs; i++ {
|
|
wsc, br := testWSCreateClient(b, compress, false, opts.Websocket.Host, opts.Websocket.Port)
|
|
defer wsc.Close()
|
|
bw := bufio.NewWriterSize(wsc, bufSize)
|
|
pubs = append(pubs, pub{wsc, br, bw})
|
|
}
|
|
|
|
// Average the amount of bytes sent by iteration
|
|
avg := len(sendBuf) / pb
|
|
if extra > 0 {
|
|
avg += len(singleOpBuf)
|
|
avg /= 2
|
|
}
|
|
b.SetBytes(int64(numPubs * avg))
|
|
b.StartTimer()
|
|
|
|
for i := 0; i < numPubs; i++ {
|
|
p := pubs[i]
|
|
go func(p pub) {
|
|
defer wg.Done()
|
|
for i := 0; i < n; i++ {
|
|
p.bw.Write(sendBuf)
|
|
}
|
|
for i := 0; i < extra; i++ {
|
|
p.bw.Write(singleOpBuf)
|
|
}
|
|
p.bw.Flush()
|
|
testWSFlushConn(b, compress, p.c, p.br)
|
|
}(p)
|
|
}
|
|
wg.Wait()
|
|
b.StopTimer()
|
|
}
|
|
|
|
func Benchmark_WS_Pubx1_CN_____0b(b *testing.B) {
|
|
wsBenchPub(b, 1, false, "")
|
|
}
|
|
|
|
func Benchmark_WS_Pubx1_CY_____0b(b *testing.B) {
|
|
wsBenchPub(b, 1, true, "")
|
|
}
|
|
|
|
func Benchmark_WS_Pubx1_CN___128b(b *testing.B) {
|
|
s := sizedString(128)
|
|
wsBenchPub(b, 1, false, s)
|
|
}
|
|
|
|
func Benchmark_WS_Pubx1_CY___128b(b *testing.B) {
|
|
s := sizedStringForCompression(128)
|
|
wsBenchPub(b, 1, true, s)
|
|
}
|
|
|
|
func Benchmark_WS_Pubx1_CN__1024b(b *testing.B) {
|
|
s := sizedString(1024)
|
|
wsBenchPub(b, 1, false, s)
|
|
}
|
|
|
|
func Benchmark_WS_Pubx1_CY__1024b(b *testing.B) {
|
|
s := sizedStringForCompression(1024)
|
|
wsBenchPub(b, 1, true, s)
|
|
}
|
|
|
|
func Benchmark_WS_Pubx1_CN__4096b(b *testing.B) {
|
|
s := sizedString(4 * 1024)
|
|
wsBenchPub(b, 1, false, s)
|
|
}
|
|
|
|
func Benchmark_WS_Pubx1_CY__4096b(b *testing.B) {
|
|
s := sizedStringForCompression(4 * 1024)
|
|
wsBenchPub(b, 1, true, s)
|
|
}
|
|
|
|
func Benchmark_WS_Pubx1_CN__8192b(b *testing.B) {
|
|
s := sizedString(8 * 1024)
|
|
wsBenchPub(b, 1, false, s)
|
|
}
|
|
|
|
func Benchmark_WS_Pubx1_CY__8192b(b *testing.B) {
|
|
s := sizedStringForCompression(8 * 1024)
|
|
wsBenchPub(b, 1, true, s)
|
|
}
|
|
|
|
func Benchmark_WS_Pubx1_CN_32768b(b *testing.B) {
|
|
s := sizedString(32 * 1024)
|
|
wsBenchPub(b, 1, false, s)
|
|
}
|
|
|
|
func Benchmark_WS_Pubx1_CY_32768b(b *testing.B) {
|
|
s := sizedStringForCompression(32 * 1024)
|
|
wsBenchPub(b, 1, true, s)
|
|
}
|
|
|
|
func Benchmark_WS_Pubx5_CN_____0b(b *testing.B) {
|
|
wsBenchPub(b, 5, false, "")
|
|
}
|
|
|
|
func Benchmark_WS_Pubx5_CY_____0b(b *testing.B) {
|
|
wsBenchPub(b, 5, true, "")
|
|
}
|
|
|
|
func Benchmark_WS_Pubx5_CN___128b(b *testing.B) {
|
|
s := sizedString(128)
|
|
wsBenchPub(b, 5, false, s)
|
|
}
|
|
|
|
func Benchmark_WS_Pubx5_CY___128b(b *testing.B) {
|
|
s := sizedStringForCompression(128)
|
|
wsBenchPub(b, 5, true, s)
|
|
}
|
|
|
|
func Benchmark_WS_Pubx5_CN__1024b(b *testing.B) {
|
|
s := sizedString(1024)
|
|
wsBenchPub(b, 5, false, s)
|
|
}
|
|
|
|
func Benchmark_WS_Pubx5_CY__1024b(b *testing.B) {
|
|
s := sizedStringForCompression(1024)
|
|
wsBenchPub(b, 5, true, s)
|
|
}
|
|
|
|
func Benchmark_WS_Pubx5_CN__4096b(b *testing.B) {
|
|
s := sizedString(4 * 1024)
|
|
wsBenchPub(b, 5, false, s)
|
|
}
|
|
|
|
func Benchmark_WS_Pubx5_CY__4096b(b *testing.B) {
|
|
s := sizedStringForCompression(4 * 1024)
|
|
wsBenchPub(b, 5, true, s)
|
|
}
|
|
|
|
func Benchmark_WS_Pubx5_CN__8192b(b *testing.B) {
|
|
s := sizedString(8 * 1024)
|
|
wsBenchPub(b, 5, false, s)
|
|
}
|
|
|
|
func Benchmark_WS_Pubx5_CY__8192b(b *testing.B) {
|
|
s := sizedStringForCompression(8 * 1024)
|
|
wsBenchPub(b, 5, true, s)
|
|
}
|
|
|
|
func Benchmark_WS_Pubx5_CN_32768b(b *testing.B) {
|
|
s := sizedString(32 * 1024)
|
|
wsBenchPub(b, 5, false, s)
|
|
}
|
|
|
|
func Benchmark_WS_Pubx5_CY_32768b(b *testing.B) {
|
|
s := sizedStringForCompression(32 * 1024)
|
|
wsBenchPub(b, 5, true, s)
|
|
}
|
|
|
|
func wsBenchSub(b *testing.B, numSubs int, compress bool, payload string) {
|
|
b.StopTimer()
|
|
opts := testWSOptions()
|
|
opts.Websocket.Compression = compress
|
|
s := RunServer(opts)
|
|
defer s.Shutdown()
|
|
|
|
var subs []*bufio.Reader
|
|
for i := 0; i < numSubs; i++ {
|
|
wsc, br := testWSCreateClient(b, compress, false, opts.Websocket.Host, opts.Websocket.Port)
|
|
defer wsc.Close()
|
|
subProto := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress,
|
|
[]byte(fmt.Sprintf("SUB %s 1\r\nPING\r\n", testWSBenchSubject)))
|
|
wsc.Write(subProto)
|
|
// Waiting for PONG
|
|
testWSReadFrame(b, br)
|
|
subs = append(subs, br)
|
|
}
|
|
|
|
wg := sync.WaitGroup{}
|
|
wg.Add(numSubs)
|
|
|
|
// Use regular NATS client to publish messages
|
|
nc := natsConnect(b, s.ClientURL())
|
|
defer nc.Close()
|
|
|
|
b.StartTimer()
|
|
|
|
for i := 0; i < numSubs; i++ {
|
|
br := subs[i]
|
|
go func(br *bufio.Reader) {
|
|
defer wg.Done()
|
|
for count := 0; count < b.N; {
|
|
msgs := testWSReadFrame(b, br)
|
|
count += bytes.Count(msgs, []byte("MSG "))
|
|
}
|
|
}(br)
|
|
}
|
|
for i := 0; i < b.N; i++ {
|
|
natsPub(b, nc, testWSBenchSubject, []byte(payload))
|
|
}
|
|
wg.Wait()
|
|
b.StopTimer()
|
|
}
|
|
|
|
func Benchmark_WS_Subx1_CN_____0b(b *testing.B) {
|
|
wsBenchSub(b, 1, false, "")
|
|
}
|
|
|
|
func Benchmark_WS_Subx1_CY_____0b(b *testing.B) {
|
|
wsBenchSub(b, 1, true, "")
|
|
}
|
|
|
|
func Benchmark_WS_Subx1_CN___128b(b *testing.B) {
|
|
s := sizedString(128)
|
|
wsBenchSub(b, 1, false, s)
|
|
}
|
|
|
|
func Benchmark_WS_Subx1_CY___128b(b *testing.B) {
|
|
s := sizedStringForCompression(128)
|
|
wsBenchSub(b, 1, true, s)
|
|
}
|
|
|
|
func Benchmark_WS_Subx1_CN__1024b(b *testing.B) {
|
|
s := sizedString(1024)
|
|
wsBenchSub(b, 1, false, s)
|
|
}
|
|
|
|
func Benchmark_WS_Subx1_CY__1024b(b *testing.B) {
|
|
s := sizedStringForCompression(1024)
|
|
wsBenchSub(b, 1, true, s)
|
|
}
|
|
|
|
func Benchmark_WS_Subx1_CN__4096b(b *testing.B) {
|
|
s := sizedString(4096)
|
|
wsBenchSub(b, 1, false, s)
|
|
}
|
|
|
|
func Benchmark_WS_Subx1_CY__4096b(b *testing.B) {
|
|
s := sizedStringForCompression(4096)
|
|
wsBenchSub(b, 1, true, s)
|
|
}
|
|
|
|
func Benchmark_WS_Subx1_CN__8192b(b *testing.B) {
|
|
s := sizedString(8192)
|
|
wsBenchSub(b, 1, false, s)
|
|
}
|
|
|
|
func Benchmark_WS_Subx1_CY__8192b(b *testing.B) {
|
|
s := sizedStringForCompression(8192)
|
|
wsBenchSub(b, 1, true, s)
|
|
}
|
|
|
|
func Benchmark_WS_Subx1_CN_32768b(b *testing.B) {
|
|
s := sizedString(32768)
|
|
wsBenchSub(b, 1, false, s)
|
|
}
|
|
|
|
func Benchmark_WS_Subx1_CY_32768b(b *testing.B) {
|
|
s := sizedStringForCompression(32768)
|
|
wsBenchSub(b, 1, true, s)
|
|
}
|
|
|
|
func Benchmark_WS_Subx5_CN_____0b(b *testing.B) {
|
|
wsBenchSub(b, 5, false, "")
|
|
}
|
|
|
|
func Benchmark_WS_Subx5_CY_____0b(b *testing.B) {
|
|
wsBenchSub(b, 5, true, "")
|
|
}
|
|
|
|
func Benchmark_WS_Subx5_CN___128b(b *testing.B) {
|
|
s := sizedString(128)
|
|
wsBenchSub(b, 5, false, s)
|
|
}
|
|
|
|
func Benchmark_WS_Subx5_CY___128b(b *testing.B) {
|
|
s := sizedStringForCompression(128)
|
|
wsBenchSub(b, 5, true, s)
|
|
}
|
|
|
|
func Benchmark_WS_Subx5_CN__1024b(b *testing.B) {
|
|
s := sizedString(1024)
|
|
wsBenchSub(b, 5, false, s)
|
|
}
|
|
|
|
func Benchmark_WS_Subx5_CY__1024b(b *testing.B) {
|
|
s := sizedStringForCompression(1024)
|
|
wsBenchSub(b, 5, true, s)
|
|
}
|
|
|
|
func Benchmark_WS_Subx5_CN__4096b(b *testing.B) {
|
|
s := sizedString(4096)
|
|
wsBenchSub(b, 5, false, s)
|
|
}
|
|
|
|
func Benchmark_WS_Subx5_CY__4096b(b *testing.B) {
|
|
s := sizedStringForCompression(4096)
|
|
wsBenchSub(b, 5, true, s)
|
|
}
|
|
|
|
func Benchmark_WS_Subx5_CN__8192b(b *testing.B) {
|
|
s := sizedString(8192)
|
|
wsBenchSub(b, 5, false, s)
|
|
}
|
|
|
|
func Benchmark_WS_Subx5_CY__8192b(b *testing.B) {
|
|
s := sizedStringForCompression(8192)
|
|
wsBenchSub(b, 5, true, s)
|
|
}
|
|
|
|
func Benchmark_WS_Subx5_CN_32768b(b *testing.B) {
|
|
s := sizedString(32768)
|
|
wsBenchSub(b, 5, false, s)
|
|
}
|
|
|
|
func Benchmark_WS_Subx5_CY_32768b(b *testing.B) {
|
|
s := sizedStringForCompression(32768)
|
|
wsBenchSub(b, 5, true, s)
|
|
}
|