From 21bc47697bc2575d9017be1d0d540f9121aa3e99 Mon Sep 17 00:00:00 2001 From: Tai Groot Date: Mon, 9 Mar 2026 06:00:16 +0000 Subject: [PATCH] feat: add graceful shutdown, health endpoint, and improved client API --- client/client.go | 58 ++++++++++++++++++++++++++++++++----------- client/client_test.go | 31 ++++++++++++++++++++--- server.go | 44 ++++++++++++++++++++++++++++++-- server_test.go | 13 ++++++++++ 4 files changed, 127 insertions(+), 19 deletions(-) diff --git a/client/client.go b/client/client.go index 4220afd..206d17b 100644 --- a/client/client.go +++ b/client/client.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "net/url" + "os" "strings" ) @@ -17,23 +18,38 @@ const ( // Client is a pastebin API client. type Client struct { - url string - insecure bool + serviceURL string + insecure bool + output io.Writer } // NewClient creates a new pastebin Client. // When insecure is true, TLS certificate verification is skipped. +// Output is written to stdout by default; use WithOutput to change. func NewClient(serviceURL string, insecure bool) *Client { - return &Client{url: serviceURL, insecure: insecure} + return &Client{serviceURL: serviceURL, insecure: insecure, output: os.Stdout} +} + +// WithOutput sets the writer where paste URLs are printed. +// Returns the Client for chaining. +func (client *Client) WithOutput(writer io.Writer) *Client { + client.output = writer + return client } // Paste reads from body and submits it as a new paste. -// It prints the resulting paste URL to stdout. -func (c *Client) Paste(body io.Reader) error { +// It prints the resulting paste URL to the configured output writer. +func (client *Client) Paste(body io.Reader) error { transport := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: c.insecure}, //nolint:gosec // user-requested skip + TLSClientConfig: &tls.Config{InsecureSkipVerify: client.insecure}, //nolint:gosec // user-requested skip + } + httpClient := &http.Client{ + Transport: transport, + // Don't follow redirects; capture the URL from the response. + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, } - httpClient := &http.Client{Transport: transport} var builder strings.Builder if _, err := io.Copy(&builder, body); err != nil { @@ -43,16 +59,30 @@ func (c *Client) Paste(body io.Reader) error { formValues := url.Values{} formValues.Set(formFieldBlob, builder.String()) - resp, err := httpClient.PostForm(c.url, formValues) + resp, err := httpClient.PostForm(client.serviceURL, formValues) if err != nil { - return fmt.Errorf("posting paste to %s: %w", c.url, err) + return fmt.Errorf("posting paste to %s: %w", client.serviceURL, err) } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusMovedPermanently { - return fmt.Errorf("unexpected response from %s: %d", c.url, resp.StatusCode) + switch resp.StatusCode { + case http.StatusOK: + // Plain text response contains the paste URL in the body. + responseBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return fmt.Errorf("reading response body: %w", readErr) + } + fmt.Fprint(client.output, string(responseBody)) + return nil + case http.StatusFound, http.StatusMovedPermanently: + // HTML response redirects to the paste URL. + location := resp.Header.Get("Location") + if location == "" { + return fmt.Errorf("redirect response missing Location header") + } + fmt.Fprint(client.output, location) + return nil + default: + return fmt.Errorf("unexpected response from %s: %d", client.serviceURL, resp.StatusCode) } - - fmt.Print(resp.Request.URL.String()) - return nil } diff --git a/client/client_test.go b/client/client_test.go index 3c1d1e9..3d58faa 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1,6 +1,7 @@ package client import ( + "bytes" "net/http" "net/http/httptest" "strings" @@ -18,12 +19,29 @@ func TestPasteSuccess(t *testing.T) { blob := r.FormValue("blob") assert.Equal(t, "test content", blob) w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(r.Host + "/p/abc123")) })) defer server.Close() - cli := NewClient(server.URL, false) + var output bytes.Buffer + cli := NewClient(server.URL, false).WithOutput(&output) err := cli.Paste(strings.NewReader("test content")) assert.NoError(t, err) + assert.Contains(t, output.String(), "/p/abc123") +} + +func TestPasteRedirect(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", "/p/redirect123") + w.WriteHeader(http.StatusFound) + })) + defer server.Close() + + var output bytes.Buffer + cli := NewClient(server.URL, false).WithOutput(&output) + err := cli.Paste(strings.NewReader("redirect content")) + assert.NoError(t, err) + assert.Equal(t, "/p/redirect123", output.String()) } func TestPasteServerError(t *testing.T) { @@ -32,7 +50,8 @@ func TestPasteServerError(t *testing.T) { })) defer server.Close() - cli := NewClient(server.URL, false) + var output bytes.Buffer + cli := NewClient(server.URL, false).WithOutput(&output) err := cli.Paste(strings.NewReader("test content")) assert.Error(t, err) assert.Contains(t, err.Error(), "unexpected response") @@ -46,6 +65,12 @@ func TestPasteInvalidURL(t *testing.T) { func TestNewClient(t *testing.T) { cli := NewClient("http://example.com", true) - assert.Equal(t, "http://example.com", cli.url) + assert.Equal(t, "http://example.com", cli.serviceURL) assert.True(t, cli.insecure) } + +func TestWithOutput(t *testing.T) { + var buf bytes.Buffer + cli := NewClient("http://example.com", false).WithOutput(&buf) + assert.Equal(t, &buf, cli.output) +} diff --git a/server.go b/server.go index c908eae..aee5819 100644 --- a/server.go +++ b/server.go @@ -1,6 +1,7 @@ package main import ( + "context" "embed" "encoding/json" "fmt" @@ -10,8 +11,10 @@ import ( "log" "net/http" "net/url" + "os/signal" "strconv" "strings" + "syscall" "time" "github.com/patrickmn/go-cache" @@ -39,6 +42,11 @@ const ( // maxPasteSize limits the maximum size of a paste body (1 MB). maxPasteSize = 1 << 20 + + // shutdownTimeout is the maximum time to wait for in-flight requests during shutdown. + shutdownTimeout = 10 * time.Second + + statusHealthy = "healthy" ) // Server holds the pastebin HTTP server state. @@ -65,9 +73,31 @@ func NewServer(config Config) *Server { } // ListenAndServe starts the HTTP server on the configured bind address. +// It handles graceful shutdown on SIGINT and SIGTERM. func (s *Server) ListenAndServe() error { - log.Printf("pastebin listening on %s", s.config.Bind) - return http.ListenAndServe(s.config.Bind, s.mux) + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + + server := &http.Server{ + Addr: s.config.Bind, + Handler: s.mux, + } + + errChan := make(chan error, 1) + go func() { + log.Printf("pastebin listening on %s", s.config.Bind) + errChan <- server.ListenAndServe() + }() + + select { + case err := <-errChan: + return err + case <-ctx.Done(): + log.Println("shutting down gracefully...") + shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) + defer cancel() + return server.Shutdown(shutdownCtx) + } } func (s *Server) initRoutes() { @@ -84,6 +114,7 @@ func (s *Server) initRoutes() { s.mux.HandleFunc("POST /delete/{uuid}", s.handleDelete) s.mux.HandleFunc("GET /download/{uuid}", s.handleDownload) s.mux.HandleFunc("GET /debug/stats", s.handleStats) + s.mux.HandleFunc("GET /healthz", s.handleHealthz) } func (s *Server) renderTemplate(name string, w http.ResponseWriter, data any) { @@ -248,3 +279,12 @@ func (s *Server) handleStats(w http.ResponseWriter, _ *http.Request) { log.Printf("error encoding stats: %v", err) } } + +func (s *Server) handleHealthz(w http.ResponseWriter, _ *http.Request) { + w.Header().Set(headerContentType, contentTypeJSON) + if err := json.NewEncoder(w).Encode(struct { + Status string `json:"status"` + }{Status: statusHealthy}); err != nil { + log.Printf("error encoding health response: %v", err) + } +} diff --git a/server_test.go b/server_test.go index e1223d2..c57a592 100644 --- a/server_test.go +++ b/server_test.go @@ -353,6 +353,19 @@ func TestPasteRoundTripSpecialChars(t *testing.T) { assert.Equal(t, specialContent, viewRec.Body.String()) } +func TestHealthzHandler(t *testing.T) { + server := newTestServer() + + req := httptest.NewRequest(http.MethodGet, "/healthz", nil) + rec := httptest.NewRecorder() + + server.mux.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Contains(t, rec.Header().Get("Content-Type"), "application/json") + assert.Contains(t, rec.Body.String(), `"status":"healthy"`) +} + func TestViewWithTabs(t *testing.T) { server := newTestServer()