72 lines
2.1 KiB
Go
72 lines
2.1 KiB
Go
package middleware //nolint:testpackage // tests internal CORS behavior
|
|
|
|
import (
|
|
"log/slog"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"git.eeqj.de/sneak/upaas/internal/config"
|
|
)
|
|
|
|
func newCORSTestMiddleware(corsOrigins string) *Middleware {
|
|
return &Middleware{
|
|
log: slog.Default(),
|
|
params: &Params{
|
|
Config: &config.Config{
|
|
CORSOrigins: corsOrigins,
|
|
SessionSecret: "test-secret-32-bytes-long-enough",
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func TestCORS_NoOriginsConfigured_NoCORSHeaders(t *testing.T) {
|
|
m := newCORSTestMiddleware("")
|
|
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Origin", "https://evil.com")
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
assert.Empty(t, rec.Header().Get("Access-Control-Allow-Origin"),
|
|
"expected no CORS headers when no origins configured")
|
|
}
|
|
|
|
func TestCORS_OriginsConfigured_AllowsMatchingOrigin(t *testing.T) {
|
|
m := newCORSTestMiddleware("https://app.example.com,https://other.example.com")
|
|
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Origin", "https://app.example.com")
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
assert.Equal(t, "https://app.example.com",
|
|
rec.Header().Get("Access-Control-Allow-Origin"))
|
|
assert.Equal(t, "true",
|
|
rec.Header().Get("Access-Control-Allow-Credentials"))
|
|
}
|
|
|
|
func TestCORS_OriginsConfigured_RejectsNonMatchingOrigin(t *testing.T) {
|
|
m := newCORSTestMiddleware("https://app.example.com")
|
|
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Origin", "https://evil.com")
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
assert.Empty(t, rec.Header().Get("Access-Control-Allow-Origin"),
|
|
"expected no CORS headers for non-matching origin")
|
|
}
|