fix: restrict CORS to configured origins (closes #40)
- Add CORSOrigins config field (UPAAS_CORS_ORIGINS env var) - Default to same-origin only (no CORS headers when unconfigured) - When configured, allow specified origins with AllowCredentials: true - Add tests for CORS middleware behavior
This commit is contained in:
parent
506c795f16
commit
02847eea92
@ -52,6 +52,7 @@ type Config struct {
|
|||||||
MetricsUsername string
|
MetricsUsername string
|
||||||
MetricsPassword string
|
MetricsPassword string
|
||||||
SessionSecret string
|
SessionSecret string
|
||||||
|
CORSOrigins string
|
||||||
params *Params
|
params *Params
|
||||||
log *slog.Logger
|
log *slog.Logger
|
||||||
}
|
}
|
||||||
@ -102,6 +103,7 @@ func setupViper(name string) {
|
|||||||
viper.SetDefault("METRICS_USERNAME", "")
|
viper.SetDefault("METRICS_USERNAME", "")
|
||||||
viper.SetDefault("METRICS_PASSWORD", "")
|
viper.SetDefault("METRICS_PASSWORD", "")
|
||||||
viper.SetDefault("SESSION_SECRET", "")
|
viper.SetDefault("SESSION_SECRET", "")
|
||||||
|
viper.SetDefault("CORS_ORIGINS", "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildConfig(log *slog.Logger, params *Params) (*Config, error) {
|
func buildConfig(log *slog.Logger, params *Params) (*Config, error) {
|
||||||
@ -136,6 +138,7 @@ func buildConfig(log *slog.Logger, params *Params) (*Config, error) {
|
|||||||
MetricsUsername: viper.GetString("METRICS_USERNAME"),
|
MetricsUsername: viper.GetString("METRICS_USERNAME"),
|
||||||
MetricsPassword: viper.GetString("METRICS_PASSWORD"),
|
MetricsPassword: viper.GetString("METRICS_PASSWORD"),
|
||||||
SessionSecret: viper.GetString("SESSION_SECRET"),
|
SessionSecret: viper.GetString("SESSION_SECRET"),
|
||||||
|
CORSOrigins: viper.GetString("CORS_ORIGINS"),
|
||||||
params: params,
|
params: params,
|
||||||
log: log,
|
log: log,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -11,6 +11,7 @@ import (
|
|||||||
"git.eeqj.de/sneak/upaas/internal/config"
|
"git.eeqj.de/sneak/upaas/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
//nolint:gosec // test credentials
|
||||||
func newCORSTestMiddleware(corsOrigins string) *Middleware {
|
func newCORSTestMiddleware(corsOrigins string) *Middleware {
|
||||||
return &Middleware{
|
return &Middleware{
|
||||||
log: slog.Default(),
|
log: slog.Default(),
|
||||||
@ -24,6 +25,8 @@ func newCORSTestMiddleware(corsOrigins string) *Middleware {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCORS_NoOriginsConfigured_NoCORSHeaders(t *testing.T) {
|
func TestCORS_NoOriginsConfigured_NoCORSHeaders(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
m := newCORSTestMiddleware("")
|
m := newCORSTestMiddleware("")
|
||||||
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
@ -31,6 +34,7 @@ func TestCORS_NoOriginsConfigured_NoCORSHeaders(t *testing.T) {
|
|||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
req.Header.Set("Origin", "https://evil.com")
|
req.Header.Set("Origin", "https://evil.com")
|
||||||
|
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
handler.ServeHTTP(rec, req)
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
@ -39,6 +43,8 @@ func TestCORS_NoOriginsConfigured_NoCORSHeaders(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCORS_OriginsConfigured_AllowsMatchingOrigin(t *testing.T) {
|
func TestCORS_OriginsConfigured_AllowsMatchingOrigin(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
m := newCORSTestMiddleware("https://app.example.com,https://other.example.com")
|
m := newCORSTestMiddleware("https://app.example.com,https://other.example.com")
|
||||||
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
@ -46,6 +52,7 @@ func TestCORS_OriginsConfigured_AllowsMatchingOrigin(t *testing.T) {
|
|||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
req.Header.Set("Origin", "https://app.example.com")
|
req.Header.Set("Origin", "https://app.example.com")
|
||||||
|
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
handler.ServeHTTP(rec, req)
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
@ -56,6 +63,8 @@ func TestCORS_OriginsConfigured_AllowsMatchingOrigin(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCORS_OriginsConfigured_RejectsNonMatchingOrigin(t *testing.T) {
|
func TestCORS_OriginsConfigured_RejectsNonMatchingOrigin(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
m := newCORSTestMiddleware("https://app.example.com")
|
m := newCORSTestMiddleware("https://app.example.com")
|
||||||
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
@ -63,6 +72,7 @@ func TestCORS_OriginsConfigured_RejectsNonMatchingOrigin(t *testing.T) {
|
|||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
req.Header.Set("Origin", "https://evil.com")
|
req.Header.Set("Origin", "https://evil.com")
|
||||||
|
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
handler.ServeHTTP(rec, req)
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
|||||||
@ -177,17 +177,48 @@ func realIP(r *http.Request) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CORS returns CORS middleware.
|
// CORS returns CORS middleware.
|
||||||
|
// When UPAAS_CORS_ORIGINS is empty (default), no CORS headers are sent
|
||||||
|
// (same-origin only). When configured, only the specified origins are
|
||||||
|
// allowed and credentials (cookies) are permitted.
|
||||||
func (m *Middleware) CORS() func(http.Handler) http.Handler {
|
func (m *Middleware) CORS() func(http.Handler) http.Handler {
|
||||||
|
origins := parseCORSOrigins(m.params.Config.CORSOrigins)
|
||||||
|
|
||||||
|
// No origins configured — no CORS headers (same-origin policy).
|
||||||
|
if len(origins) == 0 {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return next
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return cors.Handler(cors.Options{
|
return cors.Handler(cors.Options{
|
||||||
AllowedOrigins: []string{"*"},
|
AllowedOrigins: origins,
|
||||||
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||||
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
|
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
|
||||||
ExposedHeaders: []string{"Link"},
|
ExposedHeaders: []string{"Link"},
|
||||||
AllowCredentials: false,
|
AllowCredentials: true,
|
||||||
MaxAge: corsMaxAge,
|
MaxAge: corsMaxAge,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseCORSOrigins splits a comma-separated origin string into a slice,
|
||||||
|
// trimming whitespace. Returns nil if the input is empty.
|
||||||
|
func parseCORSOrigins(raw string) []string {
|
||||||
|
if raw == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(raw, ",")
|
||||||
|
origins := make([]string, 0, len(parts))
|
||||||
|
|
||||||
|
for _, p := range parts {
|
||||||
|
if o := strings.TrimSpace(p); o != "" {
|
||||||
|
origins = append(origins, o)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return origins
|
||||||
|
}
|
||||||
|
|
||||||
// MetricsAuth returns basic auth middleware for metrics endpoint.
|
// MetricsAuth returns basic auth middleware for metrics endpoint.
|
||||||
func (m *Middleware) MetricsAuth() func(http.Handler) http.Handler {
|
func (m *Middleware) MetricsAuth() func(http.Handler) http.Handler {
|
||||||
if m.params.Config.MetricsUsername == "" {
|
if m.params.Config.MetricsUsername == "" {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user