fix: restrict CORS to configured origins (closes #40)
- Add CORSOrigins config field (UPAAS_CORS_ORIGINS env var) - Default to same-origin only (no CORS headers when unconfigured) - When configured, allow specified origins with AllowCredentials: true - Add tests for CORS middleware behavior
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
"git.eeqj.de/sneak/upaas/internal/config"
|
||||
)
|
||||
|
||||
//nolint:gosec // test credentials
|
||||
func newCORSTestMiddleware(corsOrigins string) *Middleware {
|
||||
return &Middleware{
|
||||
log: slog.Default(),
|
||||
@@ -24,6 +25,8 @@ func newCORSTestMiddleware(corsOrigins string) *Middleware {
|
||||
}
|
||||
|
||||
func TestCORS_NoOriginsConfigured_NoCORSHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := newCORSTestMiddleware("")
|
||||
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -31,6 +34,7 @@ func TestCORS_NoOriginsConfigured_NoCORSHeaders(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "https://evil.com")
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
@@ -39,6 +43,8 @@ func TestCORS_NoOriginsConfigured_NoCORSHeaders(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCORS_OriginsConfigured_AllowsMatchingOrigin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := newCORSTestMiddleware("https://app.example.com,https://other.example.com")
|
||||
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -46,6 +52,7 @@ func TestCORS_OriginsConfigured_AllowsMatchingOrigin(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "https://app.example.com")
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
@@ -56,6 +63,8 @@ func TestCORS_OriginsConfigured_AllowsMatchingOrigin(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCORS_OriginsConfigured_RejectsNonMatchingOrigin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := newCORSTestMiddleware("https://app.example.com")
|
||||
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -63,6 +72,7 @@ func TestCORS_OriginsConfigured_RejectsNonMatchingOrigin(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "https://evil.com")
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
|
||||
@@ -177,17 +177,48 @@ func realIP(r *http.Request) string {
|
||||
}
|
||||
|
||||
// CORS returns CORS middleware.
|
||||
// When UPAAS_CORS_ORIGINS is empty (default), no CORS headers are sent
|
||||
// (same-origin only). When configured, only the specified origins are
|
||||
// allowed and credentials (cookies) are permitted.
|
||||
func (m *Middleware) CORS() func(http.Handler) http.Handler {
|
||||
origins := parseCORSOrigins(m.params.Config.CORSOrigins)
|
||||
|
||||
// No origins configured — no CORS headers (same-origin policy).
|
||||
if len(origins) == 0 {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return next
|
||||
}
|
||||
}
|
||||
|
||||
return cors.Handler(cors.Options{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowedOrigins: origins,
|
||||
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
|
||||
ExposedHeaders: []string{"Link"},
|
||||
AllowCredentials: false,
|
||||
AllowCredentials: true,
|
||||
MaxAge: corsMaxAge,
|
||||
})
|
||||
}
|
||||
|
||||
// parseCORSOrigins splits a comma-separated origin string into a slice,
|
||||
// trimming whitespace. Returns nil if the input is empty.
|
||||
func parseCORSOrigins(raw string) []string {
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := strings.Split(raw, ",")
|
||||
origins := make([]string, 0, len(parts))
|
||||
|
||||
for _, p := range parts {
|
||||
if o := strings.TrimSpace(p); o != "" {
|
||||
origins = append(origins, o)
|
||||
}
|
||||
}
|
||||
|
||||
return origins
|
||||
}
|
||||
|
||||
// MetricsAuth returns basic auth middleware for metrics endpoint.
|
||||
func (m *Middleware) MetricsAuth() func(http.Handler) http.Handler {
|
||||
if m.params.Config.MetricsUsername == "" {
|
||||
@@ -235,9 +266,9 @@ func (m *Middleware) CSRF() func(http.Handler) http.Handler {
|
||||
// loginRateLimit configures the login rate limiter.
|
||||
const (
|
||||
loginRateLimit = rate.Limit(5.0 / 60.0) // 5 requests per 60 seconds
|
||||
loginBurst = 5 // allow burst of 5
|
||||
limiterExpiry = 10 * time.Minute // evict entries not seen in 10 minutes
|
||||
limiterCleanupEvery = 1 * time.Minute // sweep interval
|
||||
loginBurst = 5 // allow burst of 5
|
||||
limiterExpiry = 10 * time.Minute // evict entries not seen in 10 minutes
|
||||
limiterCleanupEvery = 1 * time.Minute // sweep interval
|
||||
)
|
||||
|
||||
// ipLimiterEntry stores a rate limiter with its last-seen timestamp.
|
||||
@@ -249,8 +280,8 @@ type ipLimiterEntry struct {
|
||||
// ipLimiter tracks per-IP rate limiters for login attempts with automatic
|
||||
// eviction of stale entries to prevent unbounded memory growth.
|
||||
type ipLimiter struct {
|
||||
mu sync.Mutex
|
||||
limiters map[string]*ipLimiterEntry
|
||||
mu sync.Mutex
|
||||
limiters map[string]*ipLimiterEntry
|
||||
lastSweep time.Time
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user