diff --git a/server.go b/server.go index aee5819..174ad4b 100644 --- a/server.go +++ b/server.go @@ -43,6 +43,9 @@ const ( // maxPasteSize limits the maximum size of a paste body (1 MB). maxPasteSize = 1 << 20 + // maxCollisionRetries is the number of times to retry generating a paste ID on collision. + maxCollisionRetries = 3 + // shutdownTimeout is the maximum time to wait for in-flight requests during shutdown. shutdownTimeout = 10 * time.Second @@ -117,7 +120,13 @@ func (s *Server) initRoutes() { s.mux.HandleFunc("GET /healthz", s.handleHealthz) } -func (s *Server) renderTemplate(name string, w http.ResponseWriter, data any) { +// templateData holds data passed to HTML templates. +type templateData struct { + Blob string + UUID string +} + +func (s *Server) renderTemplate(name string, w http.ResponseWriter, data *templateData) { var buf strings.Builder if err := s.templates.ExecuteTemplate(&buf, name, data); err != nil { log.Printf("error executing template %s: %v", name, err) @@ -140,7 +149,7 @@ func (s *Server) handleIndex(w http.ResponseWriter, r *http.Request) { contentType := negotiateContentType(r) switch contentType { case contentTypeHTML: - s.renderTemplate("base", w, nil) + s.renderTemplate("base", w, &templateData{}) default: w.Header().Set(headerContentType, contentTypePlain) _, _ = fmt.Fprintln(w, "pastebin service - POST a 'blob' form field to create a paste") @@ -159,12 +168,16 @@ func (s *Server) handlePaste(w http.ResponseWriter, r *http.Request) { pasteID := RandomString(pasteIDLength) // Retry on the extremely unlikely collision. - for retries := 0; retries < 3; retries++ { + for retries := 0; retries < maxCollisionRetries; retries++ { if _, found := s.store.Get(pasteID); !found { break } pasteID = RandomString(pasteIDLength) } + if _, found := s.store.Get(pasteID); found { + http.Error(w, "Internal Server Error: ID collision", http.StatusInternalServerError) + return + } s.store.Set(pasteID, blob, cache.DefaultExpiration) pastePath, err := url.Parse(fmt.Sprintf("./p/%s", pasteID)) @@ -208,10 +221,7 @@ func (s *Server) handleView(w http.ResponseWriter, r *http.Request) { contentType := negotiateContentType(r) switch contentType { case contentTypeHTML: - s.renderTemplate("base", w, struct { - Blob string - UUID string - }{ + s.renderTemplate("base", w, &templateData{ Blob: blob, UUID: pasteID, }) diff --git a/utils.go b/utils.go index 2b29e38..b73e8d7 100644 --- a/utils.go +++ b/utils.go @@ -3,14 +3,22 @@ package main import ( "crypto/rand" "encoding/base64" + "fmt" ) // RandomString generates a URL-safe random string of the specified length. // It reads random bytes and encodes them using base64 URL encoding. // The resulting string is truncated to the requested length. +// It panics if the system's cryptographic random number generator fails, +// which indicates a fundamental system issue. func RandomString(length int) string { + if length <= 0 { + return "" + } rawBytes := make([]byte, length*2) - _, _ = rand.Read(rawBytes) + if _, err := rand.Read(rawBytes); err != nil { + panic(fmt.Sprintf("crypto/rand.Read failed: %v", err)) + } encoded := base64.URLEncoding.EncodeToString(rawBytes) return encoded[:length] } diff --git a/utils_test.go b/utils_test.go index fb40ef8..ae02e13 100644 --- a/utils_test.go +++ b/utils_test.go @@ -23,6 +23,11 @@ func TestRandomStringUniqueness(t *testing.T) { } } +func TestRandomStringZeroLength(t *testing.T) { + assert.Equal(t, "", RandomString(0)) + assert.Equal(t, "", RandomString(-1)) +} + func TestRandomStringURLSafe(t *testing.T) { for range 50 { result := RandomString(32)