From 9384170eb025f4771ba8ac96eb2ae51e5dbbf82f Mon Sep 17 00:00:00 2001 From: Tai Groot Date: Tue, 7 Apr 2026 10:08:50 +0000 Subject: [PATCH] feat(ws): add read pump for disconnect detection, GetContext, Level.String - Add WebSocket read pump in LogSocketHandler so client disconnects are detected promptly instead of only on the next WriteMessage call - Add Client.GetContext(ctx) for context-aware blocking reads - Implement fmt.Stringer on Level type (Level.String()) - Add tests for GetContext, GetContext cancellation, and Level.String() --- log/log.go | 17 +++++++++++++ log/log_test.go | 67 +++++++++++++++++++++++++++++++++++++++++++++++-- log/types.go | 25 ++++++++++++++++++ main.go | 10 ++++---- ws/server.go | 44 +++++++++++++++++++++++++------- 5 files changed, 147 insertions(+), 16 deletions(-) diff --git a/log/log.go b/log/log.go index 896b368..f3243f6 100644 --- a/log/log.go +++ b/log/log.go @@ -1,6 +1,7 @@ package log import ( + "context" "errors" "fmt" "os" @@ -153,6 +154,7 @@ func (c *Client) SetLogLevel(level Level) { c.LogLevel = level } +// Get blocks until a log entry is available and returns it. func (c *Client) Get() Entry { if !c.initialized { panic(errors.New("cannot get logs for uninitialized client, did you use CreateClient?")) @@ -160,6 +162,21 @@ func (c *Client) Get() Entry { return <-c.writer } +// GetContext blocks until a log entry is available or ctx is cancelled. +// The second return value is false when the context was cancelled before +// an entry arrived. +func (c *Client) GetContext(ctx context.Context) (Entry, bool) { + if !c.initialized { + panic(errors.New("cannot get logs for uninitialized client, did you use CreateClient?")) + } + select { + case e := <-c.writer: + return e, true + case <-ctx.Done(): + return Entry{}, false + } +} + // Trace prints out logs on trace level func Trace(args ...any) { output := fmt.Sprint(args...) diff --git a/log/log_test.go b/log/log_test.go index e23eed7..f9477a6 100644 --- a/log/log_test.go +++ b/log/log_test.go @@ -1,6 +1,7 @@ package log import ( + "context" "strconv" "sync" "testing" @@ -378,8 +379,8 @@ func TestMultiNamespaceClient(t *testing.T) { authLogger := NewLogger("auth") dbLogger := NewLogger("database") - dbLogger.Info("db message") // filtered out - apiLogger.Info("api message") // should arrive + dbLogger.Info("db message") // filtered out + apiLogger.Info("api message") // should arrive authLogger.Info("auth message") // should arrive e1, ok := getEntry(c, time.Second) @@ -550,6 +551,68 @@ func TestMatchesNamespace(t *testing.T) { c2.Destroy() } +// TestGetContext verifies context cancellation stops blocking Get. +func TestGetContext(t *testing.T) { + c := CreateClient(DefaultNamespace) + c.SetLogLevel(LTrace) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + _, ok := c.GetContext(ctx) + if ok { + t.Error("expected GetContext to return false on cancelled context") + } + c.Destroy() +} + +// TestGetContextReceivesEntry verifies GetContext delivers entries normally. +func TestGetContextReceivesEntry(t *testing.T) { + c := CreateClient(DefaultNamespace) + c.SetLogLevel(LTrace) + + go func() { + time.Sleep(10 * time.Millisecond) + Info("context entry") + }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + e, ok := c.GetContext(ctx) + if !ok { + t.Fatal("expected GetContext to return entry") + } + if e.Output != "context entry" { + t.Errorf("output = %q, want %q", e.Output, "context entry") + } + c.Destroy() +} + +// TestLevelString verifies the Level.String() method. +func TestLevelString(t *testing.T) { + tests := []struct { + level Level + want string + }{ + {LTrace, "TRACE"}, + {LDebug, "DEBUG"}, + {LInfo, "INFO"}, + {LNotice, "NOTICE"}, + {LWarn, "WARN"}, + {LError, "ERROR"}, + {LPanic, "PANIC"}, + {LFatal, "FATAL"}, + {Level(99), "UNKNOWN"}, + } + for _, tt := range tests { + got := tt.level.String() + if got != tt.want { + t.Errorf("Level(%d).String() = %q, want %q", tt.level, got, tt.want) + } + } +} + func TestFlush(t *testing.T) { defer Flush() } diff --git a/log/types.go b/log/types.go index dcaf7bd..baa94ba 100644 --- a/log/types.go +++ b/log/types.go @@ -15,6 +15,31 @@ const ( const DefaultNamespace = "default" +// String returns the human-readable name of the log level (e.g. "INFO"). +// It implements [fmt.Stringer]. +func (l Level) String() string { + switch l { + case LTrace: + return "TRACE" + case LDebug: + return "DEBUG" + case LInfo: + return "INFO" + case LNotice: + return "NOTICE" + case LWarn: + return "WARN" + case LError: + return "ERROR" + case LPanic: + return "PANIC" + case LFatal: + return "FATAL" + default: + return "UNKNOWN" + } +} + type ( LogWriter chan Entry Level int diff --git a/main.go b/main.go index 954285d..f1c2d77 100644 --- a/main.go +++ b/main.go @@ -17,21 +17,21 @@ func generateLogs() { apiLogger := logger.NewLogger("api") dbLogger := logger.NewLogger("database") authLogger := logger.NewLogger("auth") - + for { logger.Info("This is a default namespace log!") apiLogger.Info("API request received") apiLogger.Debug("Processing API call") - + dbLogger.Info("Database query executed") dbLogger.Warn("Slow query detected") - + authLogger.Info("User authentication successful") authLogger.Error("Failed login attempt detected") - + logger.Trace("This is a trace log in default namespace!") logger.Warn("This is a warning in default namespace!") - + time.Sleep(2 * time.Second) } } diff --git a/ws/server.go b/ws/server.go index 4c7027f..9c914b2 100644 --- a/ws/server.go +++ b/ws/server.go @@ -1,6 +1,7 @@ package ws import ( + "context" "encoding/json" "net/http" "strings" @@ -11,36 +12,61 @@ import ( var upgrader = websocket.Upgrader{} // use default options +// SetUpgrader replaces the default [websocket.Upgrader] used by +// [LogSocketHandler]. func SetUpgrader(u websocket.Upgrader) { upgrader = u } +// LogSocketHandler upgrades the HTTP connection to a WebSocket and streams +// log entries to the client. An optional "namespaces" query parameter +// (comma-separated) filters which namespaces the client receives. func LogSocketHandler(w http.ResponseWriter, r *http.Request) { - // Get namespaces from query parameter, comma-separated - // Empty or missing means all namespaces + // Get namespaces from query parameter, comma-separated. + // Empty or missing means all namespaces. namespacesParam := r.URL.Query().Get("namespaces") var namespaces []string if namespacesParam != "" { namespaces = strings.Split(namespacesParam, ",") } - c, err := upgrader.Upgrade(w, r, nil) + conn, err := upgrader.Upgrade(w, r, nil) if err != nil { logger.Error("upgrade:", err) return } - defer c.Close() + defer conn.Close() + lc := logger.CreateClient(namespaces...) defer lc.Destroy() lc.SetLogLevel(logger.LTrace) logger.Info("Websocket client attached.") + + // Start a read pump so the server detects client disconnects promptly. + // Without this, a disconnected client is only noticed when WriteMessage + // fails, which can be delayed indefinitely when no logs are produced. + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() + + go func() { + defer cancel() + for { + if _, _, err := conn.ReadMessage(); err != nil { + return + } + } + }() + for { - logEvent := lc.Get() - logJSON, _ := json.Marshal(logEvent) - err = c.WriteMessage(websocket.TextMessage, logJSON) - if err != nil { + entry, ok := lc.GetContext(ctx) + if !ok { + // Context cancelled — client disconnected. + return + } + logJSON, _ := json.Marshal(entry) + if err := conn.WriteMessage(websocket.TextMessage, logJSON); err != nil { logger.Warn("write:", err) - break + return } } }