From 506c795f168e79e006bc958cb9a67aea6c00b521 Mon Sep 17 00:00:00 2001 From: clawbot Date: Thu, 19 Feb 2026 13:43:33 -0800 Subject: [PATCH] test: add CORS middleware tests (failing - TDD) --- internal/middleware/cors_test.go | 71 ++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 internal/middleware/cors_test.go diff --git a/internal/middleware/cors_test.go b/internal/middleware/cors_test.go new file mode 100644 index 0000000..dd6ac12 --- /dev/null +++ b/internal/middleware/cors_test.go @@ -0,0 +1,71 @@ +package middleware //nolint:testpackage // tests internal CORS behavior + +import ( + "log/slog" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + "git.eeqj.de/sneak/upaas/internal/config" +) + +func newCORSTestMiddleware(corsOrigins string) *Middleware { + return &Middleware{ + log: slog.Default(), + params: &Params{ + Config: &config.Config{ + CORSOrigins: corsOrigins, + SessionSecret: "test-secret-32-bytes-long-enough", + }, + }, + } +} + +func TestCORS_NoOriginsConfigured_NoCORSHeaders(t *testing.T) { + m := newCORSTestMiddleware("") + handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "https://evil.com") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Empty(t, rec.Header().Get("Access-Control-Allow-Origin"), + "expected no CORS headers when no origins configured") +} + +func TestCORS_OriginsConfigured_AllowsMatchingOrigin(t *testing.T) { + m := newCORSTestMiddleware("https://app.example.com,https://other.example.com") + handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "https://app.example.com") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, "https://app.example.com", + rec.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "true", + rec.Header().Get("Access-Control-Allow-Credentials")) +} + +func TestCORS_OriginsConfigured_RejectsNonMatchingOrigin(t *testing.T) { + m := newCORSTestMiddleware("https://app.example.com") + handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "https://evil.com") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Empty(t, rec.Header().Get("Access-Control-Allow-Origin"), + "expected no CORS headers for non-matching origin") +}