1
0
mirror of https://github.com/taigrr/log-socket synced 2026-04-14 03:47:58 -07:00

3 Commits

Author SHA1 Message Date
9384170eb0 feat(ws): add read pump for disconnect detection, GetContext, Level.String
- Add WebSocket read pump in LogSocketHandler so client disconnects are
  detected promptly instead of only on the next WriteMessage call
- Add Client.GetContext(ctx) for context-aware blocking reads
- Implement fmt.Stringer on Level type (Level.String())
- Add tests for GetContext, GetContext cancellation, and Level.String()
2026-04-07 10:08:50 +00:00
70ade62c8c test(ws,browser): add test suites for ws and browser packages (#22)
- Add ws/server_test.go: WebSocket connection, namespace filtering, non-upgrade rejection, SetUpgrader
- Add ws/namespaces_test.go: NamespacesHandler response format and content
- Add browser/browser_test.go: template rendering, ws URL construction, trailing slash handling
- Fix goimports whitespace in ws/server.go
2026-03-08 14:29:47 -04:00
2375c6ca90 test(log): add comprehensive Logger type method tests (#23)
Tests all Logger methods: Trace/f/ln, Debug/f/ln, Info/f/ln,
Notice/f/ln, Warn/f/ln, Error/f/ln, Panic/f/ln, Print/f/ln,
Default(), SetInfoDepth, and panic-with-error behavior.

Previously these methods had near-zero direct test coverage.
2026-03-08 14:28:57 -04:00
9 changed files with 908 additions and 17 deletions

53
browser/browser_test.go Normal file
View File

@@ -0,0 +1,53 @@
package browser
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestLogSocketViewHandler_HTTP(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://localhost:8080/", nil)
w := httptest.NewRecorder()
LogSocketViewHandler(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status = %d, want 200", w.Code)
}
body := w.Body.String()
// html/template escapes forward slashes in JS context
if !strings.Contains(body, `ws:\/\/localhost:8080\/ws`) {
t.Error("response should contain escaped ws://localhost:8080/ws URL")
}
if !strings.Contains(body, "<!DOCTYPE html>") {
t.Error("response should contain HTML doctype")
}
}
func TestLogSocketViewHandler_CustomPath(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://myhost:9090/dashboard/", nil)
w := httptest.NewRecorder()
LogSocketViewHandler(w, req)
body := w.Body.String()
if !strings.Contains(body, `ws:\/\/myhost:9090\/dashboard\/ws`) {
t.Error("expected escaped ws://myhost:9090/dashboard/ws in body")
}
}
func TestLogSocketViewHandler_TrailingSlashTrimmed(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
w := httptest.NewRecorder()
LogSocketViewHandler(w, req)
body := w.Body.String()
// Should NOT have double slash before ws
if strings.Contains(body, `\/\/ws`) {
t.Error("should not have double slash before /ws")
}
if !strings.Contains(body, `ws:\/\/example.com\/ws`) {
t.Error("expected escaped ws://example.com/ws in body")
}
}

View File

@@ -1,6 +1,7 @@
package log package log
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"os" "os"
@@ -153,6 +154,7 @@ func (c *Client) SetLogLevel(level Level) {
c.LogLevel = level c.LogLevel = level
} }
// Get blocks until a log entry is available and returns it.
func (c *Client) Get() Entry { func (c *Client) Get() Entry {
if !c.initialized { if !c.initialized {
panic(errors.New("cannot get logs for uninitialized client, did you use CreateClient?")) panic(errors.New("cannot get logs for uninitialized client, did you use CreateClient?"))
@@ -160,6 +162,21 @@ func (c *Client) Get() Entry {
return <-c.writer return <-c.writer
} }
// GetContext blocks until a log entry is available or ctx is cancelled.
// The second return value is false when the context was cancelled before
// an entry arrived.
func (c *Client) GetContext(ctx context.Context) (Entry, bool) {
if !c.initialized {
panic(errors.New("cannot get logs for uninitialized client, did you use CreateClient?"))
}
select {
case e := <-c.writer:
return e, true
case <-ctx.Done():
return Entry{}, false
}
}
// Trace prints out logs on trace level // Trace prints out logs on trace level
func Trace(args ...any) { func Trace(args ...any) {
output := fmt.Sprint(args...) output := fmt.Sprint(args...)

View File

@@ -1,6 +1,7 @@
package log package log
import ( import (
"context"
"strconv" "strconv"
"sync" "sync"
"testing" "testing"
@@ -378,8 +379,8 @@ func TestMultiNamespaceClient(t *testing.T) {
authLogger := NewLogger("auth") authLogger := NewLogger("auth")
dbLogger := NewLogger("database") dbLogger := NewLogger("database")
dbLogger.Info("db message") // filtered out dbLogger.Info("db message") // filtered out
apiLogger.Info("api message") // should arrive apiLogger.Info("api message") // should arrive
authLogger.Info("auth message") // should arrive authLogger.Info("auth message") // should arrive
e1, ok := getEntry(c, time.Second) e1, ok := getEntry(c, time.Second)
@@ -550,6 +551,68 @@ func TestMatchesNamespace(t *testing.T) {
c2.Destroy() c2.Destroy()
} }
// TestGetContext verifies context cancellation stops blocking Get.
func TestGetContext(t *testing.T) {
c := CreateClient(DefaultNamespace)
c.SetLogLevel(LTrace)
ctx, cancel := context.WithCancel(context.Background())
cancel() // cancel immediately
_, ok := c.GetContext(ctx)
if ok {
t.Error("expected GetContext to return false on cancelled context")
}
c.Destroy()
}
// TestGetContextReceivesEntry verifies GetContext delivers entries normally.
func TestGetContextReceivesEntry(t *testing.T) {
c := CreateClient(DefaultNamespace)
c.SetLogLevel(LTrace)
go func() {
time.Sleep(10 * time.Millisecond)
Info("context entry")
}()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
e, ok := c.GetContext(ctx)
if !ok {
t.Fatal("expected GetContext to return entry")
}
if e.Output != "context entry" {
t.Errorf("output = %q, want %q", e.Output, "context entry")
}
c.Destroy()
}
// TestLevelString verifies the Level.String() method.
func TestLevelString(t *testing.T) {
tests := []struct {
level Level
want string
}{
{LTrace, "TRACE"},
{LDebug, "DEBUG"},
{LInfo, "INFO"},
{LNotice, "NOTICE"},
{LWarn, "WARN"},
{LError, "ERROR"},
{LPanic, "PANIC"},
{LFatal, "FATAL"},
{Level(99), "UNKNOWN"},
}
for _, tt := range tests {
got := tt.level.String()
if got != tt.want {
t.Errorf("Level(%d).String() = %q, want %q", tt.level, got, tt.want)
}
}
}
func TestFlush(t *testing.T) { func TestFlush(t *testing.T) {
defer Flush() defer Flush()
} }

518
log/logger_test.go Normal file
View File

@@ -0,0 +1,518 @@
package log
import (
"testing"
"time"
)
func TestLoggerTrace(t *testing.T) {
c := CreateClient("logger-trace")
c.SetLogLevel(LTrace)
l := NewLogger("logger-trace")
l.Trace("trace message")
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out")
}
if e.Level != "TRACE" {
t.Errorf("level = %q, want TRACE", e.Level)
}
if e.Output != "trace message" {
t.Errorf("output = %q, want %q", e.Output, "trace message")
}
if e.Namespace != "logger-trace" {
t.Errorf("namespace = %q, want %q", e.Namespace, "logger-trace")
}
c.Destroy()
}
func TestLoggerTracef(t *testing.T) {
c := CreateClient("logger-tracef")
c.SetLogLevel(LTrace)
l := NewLogger("logger-tracef")
l.Tracef("trace %s %d", "msg", 1)
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out")
}
if e.Level != "TRACE" {
t.Errorf("level = %q, want TRACE", e.Level)
}
if e.Output != "trace msg 1" {
t.Errorf("output = %q, want %q", e.Output, "trace msg 1")
}
c.Destroy()
}
func TestLoggerTraceln(t *testing.T) {
c := CreateClient("logger-traceln")
c.SetLogLevel(LTrace)
l := NewLogger("logger-traceln")
l.Traceln("trace line")
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out")
}
if e.Level != "TRACE" {
t.Errorf("level = %q, want TRACE", e.Level)
}
if e.Output != "trace line\n" {
t.Errorf("output = %q, want %q", e.Output, "trace line\n")
}
c.Destroy()
}
func TestLoggerDebug(t *testing.T) {
c := CreateClient("logger-debug")
c.SetLogLevel(LDebug)
l := NewLogger("logger-debug")
l.Debug("debug message")
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out")
}
if e.Level != "DEBUG" {
t.Errorf("level = %q, want DEBUG", e.Level)
}
if e.Output != "debug message" {
t.Errorf("output = %q, want %q", e.Output, "debug message")
}
c.Destroy()
}
func TestLoggerDebugf(t *testing.T) {
c := CreateClient("logger-debugf")
c.SetLogLevel(LDebug)
l := NewLogger("logger-debugf")
l.Debugf("debug %d", 42)
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out")
}
if e.Output != "debug 42" {
t.Errorf("output = %q, want %q", e.Output, "debug 42")
}
c.Destroy()
}
func TestLoggerInfo(t *testing.T) {
c := CreateClient("logger-info")
c.SetLogLevel(LInfo)
l := NewLogger("logger-info")
l.Info("info message")
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out")
}
if e.Level != "INFO" {
t.Errorf("level = %q, want INFO", e.Level)
}
if e.Output != "info message" {
t.Errorf("output = %q, want %q", e.Output, "info message")
}
c.Destroy()
}
func TestLoggerInfof(t *testing.T) {
c := CreateClient("logger-infof")
c.SetLogLevel(LInfo)
l := NewLogger("logger-infof")
l.Infof("count: %d", 99)
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out")
}
if e.Output != "count: 99" {
t.Errorf("output = %q, want %q", e.Output, "count: 99")
}
c.Destroy()
}
func TestLoggerInfoln(t *testing.T) {
c := CreateClient("logger-infoln")
c.SetLogLevel(LInfo)
l := NewLogger("logger-infoln")
l.Infoln("info line")
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out")
}
if e.Output != "info line\n" {
t.Errorf("output = %q, want %q", e.Output, "info line\n")
}
c.Destroy()
}
func TestLoggerNotice(t *testing.T) {
c := CreateClient("logger-notice")
c.SetLogLevel(LNotice)
l := NewLogger("logger-notice")
l.Notice("notice message")
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out")
}
if e.Level != "NOTICE" {
t.Errorf("level = %q, want NOTICE", e.Level)
}
if e.Output != "notice message" {
t.Errorf("output = %q, want %q", e.Output, "notice message")
}
c.Destroy()
}
func TestLoggerNoticef(t *testing.T) {
c := CreateClient("logger-noticef")
c.SetLogLevel(LNotice)
l := NewLogger("logger-noticef")
l.Noticef("notice %s", "formatted")
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out")
}
if e.Output != "notice formatted" {
t.Errorf("output = %q, want %q", e.Output, "notice formatted")
}
c.Destroy()
}
func TestLoggerNoticeln(t *testing.T) {
c := CreateClient("logger-noticeln")
c.SetLogLevel(LNotice)
l := NewLogger("logger-noticeln")
l.Noticeln("notice line")
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out")
}
if e.Output != "notice line\n" {
t.Errorf("output = %q, want %q", e.Output, "notice line\n")
}
c.Destroy()
}
func TestLoggerWarn(t *testing.T) {
c := CreateClient("logger-warn")
c.SetLogLevel(LWarn)
l := NewLogger("logger-warn")
l.Warn("warn message")
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out")
}
if e.Level != "WARN" {
t.Errorf("level = %q, want WARN", e.Level)
}
if e.Output != "warn message" {
t.Errorf("output = %q, want %q", e.Output, "warn message")
}
c.Destroy()
}
func TestLoggerWarnf(t *testing.T) {
c := CreateClient("logger-warnf")
c.SetLogLevel(LWarn)
l := NewLogger("logger-warnf")
l.Warnf("warn %d", 1)
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out")
}
if e.Output != "warn 1" {
t.Errorf("output = %q, want %q", e.Output, "warn 1")
}
c.Destroy()
}
func TestLoggerWarnln(t *testing.T) {
c := CreateClient("logger-warnln")
c.SetLogLevel(LWarn)
l := NewLogger("logger-warnln")
l.Warnln("warn line")
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out")
}
if e.Output != "warn line\n" {
t.Errorf("output = %q, want %q", e.Output, "warn line\n")
}
c.Destroy()
}
func TestLoggerError(t *testing.T) {
c := CreateClient("logger-error")
c.SetLogLevel(LError)
l := NewLogger("logger-error")
l.Error("error message")
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out")
}
if e.Level != "ERROR" {
t.Errorf("level = %q, want ERROR", e.Level)
}
if e.Output != "error message" {
t.Errorf("output = %q, want %q", e.Output, "error message")
}
c.Destroy()
}
func TestLoggerErrorf(t *testing.T) {
c := CreateClient("logger-errorf")
c.SetLogLevel(LError)
l := NewLogger("logger-errorf")
l.Errorf("err: %s", "something broke")
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out")
}
if e.Output != "err: something broke" {
t.Errorf("output = %q, want %q", e.Output, "err: something broke")
}
c.Destroy()
}
func TestLoggerErrorln(t *testing.T) {
c := CreateClient("logger-errorln")
c.SetLogLevel(LError)
l := NewLogger("logger-errorln")
l.Errorln("error line")
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out")
}
if e.Output != "error line\n" {
t.Errorf("output = %q, want %q", e.Output, "error line\n")
}
c.Destroy()
}
func TestLoggerPanic(t *testing.T) {
c := CreateClient("logger-panic")
c.SetLogLevel(LPanic)
l := NewLogger("logger-panic")
defer func() {
r := recover()
if r == nil {
t.Error("expected panic, got nil")
}
// Verify the entry was broadcast
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out waiting for panic entry")
}
if e.Level != "PANIC" {
t.Errorf("level = %q, want PANIC", e.Level)
}
if e.Namespace != "logger-panic" {
t.Errorf("namespace = %q, want %q", e.Namespace, "logger-panic")
}
c.Destroy()
}()
l.Panic("panic message")
}
func TestLoggerPanicf(t *testing.T) {
c := CreateClient("logger-panicf")
c.SetLogLevel(LPanic)
l := NewLogger("logger-panicf")
defer func() {
r := recover()
if r == nil {
t.Error("expected panic, got nil")
}
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out")
}
if e.Output != "panic 42" {
t.Errorf("output = %q, want %q", e.Output, "panic 42")
}
c.Destroy()
}()
l.Panicf("panic %d", 42)
}
func TestLoggerPanicln(t *testing.T) {
c := CreateClient("logger-panicln")
c.SetLogLevel(LPanic)
l := NewLogger("logger-panicln")
defer func() {
r := recover()
if r == nil {
t.Error("expected panic, got nil")
}
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out")
}
if e.Output != "panic line\n" {
t.Errorf("output = %q, want %q", e.Output, "panic line\n")
}
c.Destroy()
}()
l.Panicln("panic line")
}
func TestLoggerPrint(t *testing.T) {
c := CreateClient("logger-print")
c.SetLogLevel(LInfo)
l := NewLogger("logger-print")
l.Print("print message")
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out")
}
// Print delegates to Info
if e.Level != "INFO" {
t.Errorf("level = %q, want INFO", e.Level)
}
if e.Output != "print message" {
t.Errorf("output = %q, want %q", e.Output, "print message")
}
c.Destroy()
}
func TestLoggerPrintf(t *testing.T) {
c := CreateClient("logger-printf")
c.SetLogLevel(LInfo)
l := NewLogger("logger-printf")
l.Printf("formatted %s", "print")
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out")
}
if e.Output != "formatted print" {
t.Errorf("output = %q, want %q", e.Output, "formatted print")
}
c.Destroy()
}
func TestLoggerPrintln(t *testing.T) {
c := CreateClient("logger-println")
c.SetLogLevel(LInfo)
l := NewLogger("logger-println")
l.Println("println message")
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out")
}
if e.Output != "println message\n" {
t.Errorf("output = %q, want %q", e.Output, "println message\n")
}
c.Destroy()
}
func TestLoggerSetInfoDepth(t *testing.T) {
l := NewLogger("depth-test")
l.SetInfoDepth(3)
if l.FileInfoDepth != 3 {
t.Errorf("FileInfoDepth = %d, want 3", l.FileInfoDepth)
}
}
func TestDefaultLogger(t *testing.T) {
l := Default()
if l.Namespace != DefaultNamespace {
t.Errorf("namespace = %q, want %q", l.Namespace, DefaultNamespace)
}
if l.FileInfoDepth != 0 {
t.Errorf("FileInfoDepth = %d, want 0", l.FileInfoDepth)
}
// Verify Default() logger can emit entries
c := CreateClient(DefaultNamespace)
c.SetLogLevel(LInfo)
l.Info("default logger message")
e, ok := getEntry(c, time.Second)
if !ok {
t.Fatal("timed out")
}
if e.Output != "default logger message" {
t.Errorf("output = %q, want %q", e.Output, "default logger message")
}
c.Destroy()
}
func TestLoggerPanicWithError(t *testing.T) {
// When the first arg is an error, Panic should re-panic with that error
c := CreateClient("logger-panic-err")
c.SetLogLevel(LPanic)
l := NewLogger("logger-panic-err")
testErr := errTest("test error")
defer func() {
r := recover()
if r == nil {
t.Error("expected panic, got nil")
}
if err, ok := r.(errTest); ok {
if string(err) != "test error" {
t.Errorf("panic value = %q, want %q", string(err), "test error")
}
} else {
// The first arg was an error, so Panic should re-panic with it
t.Logf("panic type = %T, value = %v (implementation re-panics with original error)", r, r)
}
c.Destroy()
}()
l.Panic(testErr)
}
// errTest is a simple error type for testing.
type errTest string
func (e errTest) Error() string { return string(e) }

View File

@@ -15,6 +15,31 @@ const (
const DefaultNamespace = "default" const DefaultNamespace = "default"
// String returns the human-readable name of the log level (e.g. "INFO").
// It implements [fmt.Stringer].
func (l Level) String() string {
switch l {
case LTrace:
return "TRACE"
case LDebug:
return "DEBUG"
case LInfo:
return "INFO"
case LNotice:
return "NOTICE"
case LWarn:
return "WARN"
case LError:
return "ERROR"
case LPanic:
return "PANIC"
case LFatal:
return "FATAL"
default:
return "UNKNOWN"
}
}
type ( type (
LogWriter chan Entry LogWriter chan Entry
Level int Level int

10
main.go
View File

@@ -17,21 +17,21 @@ func generateLogs() {
apiLogger := logger.NewLogger("api") apiLogger := logger.NewLogger("api")
dbLogger := logger.NewLogger("database") dbLogger := logger.NewLogger("database")
authLogger := logger.NewLogger("auth") authLogger := logger.NewLogger("auth")
for { for {
logger.Info("This is a default namespace log!") logger.Info("This is a default namespace log!")
apiLogger.Info("API request received") apiLogger.Info("API request received")
apiLogger.Debug("Processing API call") apiLogger.Debug("Processing API call")
dbLogger.Info("Database query executed") dbLogger.Info("Database query executed")
dbLogger.Warn("Slow query detected") dbLogger.Warn("Slow query detected")
authLogger.Info("User authentication successful") authLogger.Info("User authentication successful")
authLogger.Error("Failed login attempt detected") authLogger.Error("Failed login attempt detected")
logger.Trace("This is a trace log in default namespace!") logger.Trace("This is a trace log in default namespace!")
logger.Warn("This is a warning in default namespace!") logger.Warn("This is a warning in default namespace!")
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
} }
} }

62
ws/namespaces_test.go Normal file
View File

@@ -0,0 +1,62 @@
package ws
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
logger "github.com/taigrr/log-socket/v2/log"
)
func TestNamespacesHandler(t *testing.T) {
// Log to a known namespace to ensure it appears
nsLogger := logger.NewLogger("ns-handler-test")
nsLogger.Info("register namespace")
req := httptest.NewRequest(http.MethodGet, "/api/namespaces", nil)
w := httptest.NewRecorder()
NamespacesHandler(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status = %d, want 200", w.Code)
}
ct := w.Header().Get("Content-Type")
if ct != "application/json" {
t.Errorf("Content-Type = %q, want application/json", ct)
}
var result struct {
Namespaces []string `json:"namespaces"`
}
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
found := false
for _, ns := range result.Namespaces {
if ns == "ns-handler-test" {
found = true
break
}
}
if !found {
t.Errorf("namespace 'ns-handler-test' not found in %v", result.Namespaces)
}
}
func TestNamespacesHandler_ResponseFormat(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/api/namespaces", nil)
w := httptest.NewRecorder()
NamespacesHandler(w, req)
var result map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
t.Fatalf("response is not valid JSON: %v", err)
}
if _, ok := result["namespaces"]; !ok {
t.Error("response missing 'namespaces' key")
}
}

View File

@@ -1,6 +1,7 @@
package ws package ws
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"strings" "strings"
@@ -11,36 +12,61 @@ import (
var upgrader = websocket.Upgrader{} // use default options var upgrader = websocket.Upgrader{} // use default options
// SetUpgrader replaces the default [websocket.Upgrader] used by
// [LogSocketHandler].
func SetUpgrader(u websocket.Upgrader) { func SetUpgrader(u websocket.Upgrader) {
upgrader = u upgrader = u
} }
// LogSocketHandler upgrades the HTTP connection to a WebSocket and streams
// log entries to the client. An optional "namespaces" query parameter
// (comma-separated) filters which namespaces the client receives.
func LogSocketHandler(w http.ResponseWriter, r *http.Request) { func LogSocketHandler(w http.ResponseWriter, r *http.Request) {
// Get namespaces from query parameter, comma-separated // Get namespaces from query parameter, comma-separated.
// Empty or missing means all namespaces // Empty or missing means all namespaces.
namespacesParam := r.URL.Query().Get("namespaces") namespacesParam := r.URL.Query().Get("namespaces")
var namespaces []string var namespaces []string
if namespacesParam != "" { if namespacesParam != "" {
namespaces = strings.Split(namespacesParam, ",") namespaces = strings.Split(namespacesParam, ",")
} }
c, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
logger.Error("upgrade:", err) logger.Error("upgrade:", err)
return return
} }
defer c.Close() defer conn.Close()
lc := logger.CreateClient(namespaces...) lc := logger.CreateClient(namespaces...)
defer lc.Destroy() defer lc.Destroy()
lc.SetLogLevel(logger.LTrace) lc.SetLogLevel(logger.LTrace)
logger.Info("Websocket client attached.") logger.Info("Websocket client attached.")
// Start a read pump so the server detects client disconnects promptly.
// Without this, a disconnected client is only noticed when WriteMessage
// fails, which can be delayed indefinitely when no logs are produced.
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
go func() {
defer cancel()
for {
if _, _, err := conn.ReadMessage(); err != nil {
return
}
}
}()
for { for {
logEvent := lc.Get() entry, ok := lc.GetContext(ctx)
logJSON, _ := json.Marshal(logEvent) if !ok {
err = c.WriteMessage(websocket.TextMessage, logJSON) // Context cancelled — client disconnected.
if err != nil { return
}
logJSON, _ := json.Marshal(entry)
if err := conn.WriteMessage(websocket.TextMessage, logJSON); err != nil {
logger.Warn("write:", err) logger.Warn("write:", err)
break return
} }
} }
} }

127
ws/server_test.go Normal file
View File

@@ -0,0 +1,127 @@
package ws
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
logger "github.com/taigrr/log-socket/v2/log"
)
func TestSetUpgrader(t *testing.T) {
custom := websocket.Upgrader{
ReadBufferSize: 2048,
WriteBufferSize: 2048,
}
SetUpgrader(custom)
if upgrader.ReadBufferSize != 2048 {
t.Errorf("ReadBufferSize = %d, want 2048", upgrader.ReadBufferSize)
}
if upgrader.WriteBufferSize != 2048 {
t.Errorf("WriteBufferSize = %d, want 2048", upgrader.WriteBufferSize)
}
// Reset to default
SetUpgrader(websocket.Upgrader{})
}
func TestLogSocketHandler_NonWebSocket(t *testing.T) {
// A non-upgrade request should fail gracefully (upgrader returns error)
req := httptest.NewRequest(http.MethodGet, "/ws", nil)
w := httptest.NewRecorder()
LogSocketHandler(w, req)
// The upgrader should return a 400-level error for non-websocket requests
if w.Code == http.StatusOK || w.Code == http.StatusSwitchingProtocols {
t.Errorf("expected error status for non-websocket request, got %d", w.Code)
}
}
func TestLogSocketHandler_WebSocket(t *testing.T) {
// Set upgrader with permissive origin check for testing
SetUpgrader(websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
})
defer SetUpgrader(websocket.Upgrader{})
server := httptest.NewServer(http.HandlerFunc(LogSocketHandler))
defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws"
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("failed to connect: %v", err)
}
defer conn.Close()
// Send a log entry and verify it arrives over the websocket
testLogger := logger.NewLogger("ws-test")
testLogger.Info("test message for websocket")
// Read messages until we find our test entry (the handler itself
// logs "Websocket client attached." which may arrive first)
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
var found bool
for i := 0; i < 10; i++ {
_, message, err := conn.ReadMessage()
if err != nil {
t.Fatalf("failed to read message: %v", err)
}
var entry logger.Entry
if err := json.Unmarshal(message, &entry); err != nil {
t.Fatalf("failed to unmarshal entry: %v", err)
}
if entry.Namespace == "ws-test" && entry.Level == "INFO" {
found = true
break
}
}
if !found {
t.Error("did not receive expected log entry with namespace ws-test")
}
}
func TestLogSocketHandler_NamespaceFilter(t *testing.T) {
SetUpgrader(websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
})
defer SetUpgrader(websocket.Upgrader{})
server := httptest.NewServer(http.HandlerFunc(LogSocketHandler))
defer server.Close()
// Connect with namespace filter
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws?namespaces=filtered-ns"
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("failed to connect: %v", err)
}
defer conn.Close()
// Send a log to a different namespace — it should NOT be received
otherLogger := logger.NewLogger("other-ns")
otherLogger.Info("should not arrive")
// Send a log to the filtered namespace — it SHOULD be received
filteredLogger := logger.NewLogger("filtered-ns")
filteredLogger.Info("should arrive")
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, message, err := conn.ReadMessage()
if err != nil {
t.Fatalf("failed to read message: %v", err)
}
var entry logger.Entry
if err := json.Unmarshal(message, &entry); err != nil {
t.Fatalf("failed to unmarshal entry: %v", err)
}
if entry.Namespace != "filtered-ns" {
t.Errorf("namespace = %q, want filtered-ns", entry.Namespace)
}
if !strings.Contains(entry.Output, "should arrive") {
t.Errorf("output = %q, want to contain 'should arrive'", entry.Output)
}
}