fix: restrict CORS to configured origins (closes #40) #92
71
internal/middleware/cors_test.go
Normal file
71
internal/middleware/cors_test.go
Normal file
@ -0,0 +1,71 @@
|
||||
package middleware //nolint:testpackage // tests internal CORS behavior
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"git.eeqj.de/sneak/upaas/internal/config"
|
||||
)
|
||||
|
||||
func newCORSTestMiddleware(corsOrigins string) *Middleware {
|
||||
return &Middleware{
|
||||
log: slog.Default(),
|
||||
params: &Params{
|
||||
Config: &config.Config{
|
||||
CORSOrigins: corsOrigins,
|
||||
SessionSecret: "test-secret-32-bytes-long-enough",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORS_NoOriginsConfigured_NoCORSHeaders(t *testing.T) {
|
||||
m := newCORSTestMiddleware("")
|
||||
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "https://evil.com")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Empty(t, rec.Header().Get("Access-Control-Allow-Origin"),
|
||||
"expected no CORS headers when no origins configured")
|
||||
}
|
||||
|
||||
func TestCORS_OriginsConfigured_AllowsMatchingOrigin(t *testing.T) {
|
||||
m := newCORSTestMiddleware("https://app.example.com,https://other.example.com")
|
||||
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "https://app.example.com")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, "https://app.example.com",
|
||||
rec.Header().Get("Access-Control-Allow-Origin"))
|
||||
assert.Equal(t, "true",
|
||||
rec.Header().Get("Access-Control-Allow-Credentials"))
|
||||
}
|
||||
|
||||
func TestCORS_OriginsConfigured_RejectsNonMatchingOrigin(t *testing.T) {
|
||||
m := newCORSTestMiddleware("https://app.example.com")
|
||||
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "https://evil.com")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Empty(t, rec.Header().Get("Access-Control-Allow-Origin"),
|
||||
"expected no CORS headers for non-matching origin")
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user