diff --git a/README.md b/README.md index 1927e93..9128c6d 100644 --- a/README.md +++ b/README.md @@ -710,7 +710,8 @@ webhooker/ │ │ └── globals.go # Build-time variables (appname, version, arch) │ ├── delivery/ │ │ ├── engine.go # Event-driven delivery engine (channel + timer based) -│ │ └── circuit_breaker.go # Per-target circuit breaker for HTTP targets with retries +│ │ ├── circuit_breaker.go # Per-target circuit breaker for HTTP targets with retries +│ │ └── ssrf.go # SSRF prevention (IP validation, safe HTTP transport) │ ├── handlers/ │ │ ├── handlers.go # Base handler struct, JSON helpers, template rendering │ │ ├── auth.go # Login, logout handlers @@ -724,7 +725,9 @@ webhooker/ │ ├── logger/ │ │ └── logger.go # slog setup with TTY detection │ ├── middleware/ -│ │ └── middleware.go # Logging, CORS, Auth, Metrics, MetricsAuth +│ │ ├── middleware.go # Logging, CORS, Auth, Metrics, MetricsAuth +│ │ ├── csrf.go # CSRF protection middleware (session-based tokens) +│ │ └── ratelimit.go # Per-IP rate limiting middleware (login endpoint) │ ├── server/ │ │ ├── server.go # Server struct, fx lifecycle, signal handling │ │ ├── http.go # HTTP server setup with timeouts @@ -799,6 +802,17 @@ Applied to all routes in this order: - Session cookies are HttpOnly, SameSite Lax, Secure (prod only) - Session key is a 32-byte value auto-generated on first startup and stored in the database +- **CSRF protection** on all state-changing forms (session-based tokens + with constant-time comparison). Applied to `/pages`, `/sources`, + `/source`, and `/user` routes. Excluded from `/webhook` (inbound + webhook POSTs) and `/api` (stateless API) +- **SSRF prevention** for HTTP delivery targets: private/reserved IP + ranges (RFC 1918, loopback, link-local, cloud metadata) are blocked + both at target creation time (URL validation) and at delivery time + (custom HTTP transport with SSRF-safe dialer that validates resolved + IPs before connecting, preventing DNS rebinding attacks) +- **Login rate limiting**: per-IP rate limiter on the login endpoint + (5 attempts per minute per IP) to prevent brute-force attacks - Prometheus metrics behind basic auth - Static assets embedded in binary (no filesystem access needed at runtime) @@ -875,7 +889,12 @@ linted, tested, and compiled. - [ ] Per-webhook rate limiting in the receiver handler - [ ] Webhook signature verification (GitHub, Stripe formats) - [ ] Security headers (HSTS, CSP, X-Frame-Options) -- [ ] CSRF protection for forms +- [x] CSRF protection for forms + ([#35](https://git.eeqj.de/sneak/webhooker/issues/35)) +- [x] SSRF prevention for HTTP delivery targets + ([#36](https://git.eeqj.de/sneak/webhooker/issues/36)) +- [x] Login rate limiting (per-IP brute-force protection) + ([#37](https://git.eeqj.de/sneak/webhooker/issues/37)) - [ ] Session expiration and "remember me" - [ ] Password change/reset flow - [ ] API key authentication for programmatic access diff --git a/go.mod b/go.mod index 84cbebe..16f4ce1 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module sneak.berlin/go/webhooker -go 1.23.0 +go 1.24.0 toolchain go1.24.1 @@ -17,6 +17,7 @@ require ( github.com/stretchr/testify v1.8.4 go.uber.org/fx v1.20.1 golang.org/x/crypto v0.38.0 + golang.org/x/time v0.14.0 gorm.io/driver/sqlite v1.5.4 gorm.io/gorm v1.25.5 modernc.org/sqlite v1.28.0 diff --git a/go.sum b/go.sum index 1571d29..5b4395e 100644 --- a/go.sum +++ b/go.sum @@ -103,6 +103,8 @@ golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/delivery/engine.go b/internal/delivery/engine.go index 4a4a1a6..2c7044d 100644 --- a/internal/delivery/engine.go +++ b/internal/delivery/engine.go @@ -146,7 +146,8 @@ func New(lc fx.Lifecycle, params EngineParams) *Engine { dbManager: params.DBManager, log: params.Logger.Get(), client: &http.Client{ - Timeout: httpClientTimeout, + Timeout: httpClientTimeout, + Transport: NewSSRFSafeTransport(), }, deliveryCh: make(chan DeliveryTask, deliveryChannelSize), retryCh: make(chan DeliveryTask, retryChannelSize), diff --git a/internal/delivery/ssrf.go b/internal/delivery/ssrf.go new file mode 100644 index 0000000..73b5c3e --- /dev/null +++ b/internal/delivery/ssrf.go @@ -0,0 +1,153 @@ +package delivery + +import ( + "context" + "fmt" + "net" + "net/http" + "net/url" + "time" +) + +const ( + // dnsResolutionTimeout is the maximum time to wait for DNS resolution + // during SSRF validation. + dnsResolutionTimeout = 5 * time.Second +) + +// blockedNetworks contains all private/reserved IP ranges that should be +// blocked to prevent SSRF attacks. This includes RFC 1918 private +// addresses, loopback, link-local, and IPv6 equivalents. +// +//nolint:gochecknoglobals // package-level network list is appropriate here +var blockedNetworks []*net.IPNet + +//nolint:gochecknoinits // init is the idiomatic way to parse CIDRs once at startup +func init() { + cidrs := []string{ + // IPv4 private/reserved ranges + "127.0.0.0/8", // Loopback + "10.0.0.0/8", // RFC 1918 Class A private + "172.16.0.0/12", // RFC 1918 Class B private + "192.168.0.0/16", // RFC 1918 Class C private + "169.254.0.0/16", // Link-local (cloud metadata) + "0.0.0.0/8", // "This" network + "100.64.0.0/10", // Shared address space (CGN) + "192.0.0.0/24", // IETF protocol assignments + "192.0.2.0/24", // TEST-NET-1 + "198.18.0.0/15", // Benchmarking + "198.51.100.0/24", // TEST-NET-2 + "203.0.113.0/24", // TEST-NET-3 + "224.0.0.0/4", // Multicast + "240.0.0.0/4", // Reserved for future use + + // IPv6 private/reserved ranges + "::1/128", // Loopback + "fc00::/7", // Unique local addresses + "fe80::/10", // Link-local + } + + for _, cidr := range cidrs { + _, network, err := net.ParseCIDR(cidr) + if err != nil { + panic(fmt.Sprintf("ssrf: failed to parse CIDR %q: %v", cidr, err)) + } + blockedNetworks = append(blockedNetworks, network) + } +} + +// isBlockedIP checks whether an IP address falls within any blocked +// private/reserved network range. +func isBlockedIP(ip net.IP) bool { + for _, network := range blockedNetworks { + if network.Contains(ip) { + return true + } + } + return false +} + +// ValidateTargetURL checks that an HTTP delivery target URL is safe +// from SSRF attacks. It validates the URL format, resolves the hostname +// to IP addresses, and verifies that none of the resolved IPs are in +// blocked private/reserved ranges. +// +// Returns nil if the URL is safe, or an error describing the issue. +func ValidateTargetURL(targetURL string) error { + parsed, err := url.Parse(targetURL) + if err != nil { + return fmt.Errorf("invalid URL: %w", err) + } + + // Only allow http and https schemes + if parsed.Scheme != "http" && parsed.Scheme != "https" { + return fmt.Errorf("unsupported URL scheme %q: only http and https are allowed", parsed.Scheme) + } + + host := parsed.Hostname() + if host == "" { + return fmt.Errorf("URL has no hostname") + } + + // Check if the host is a raw IP address first + if ip := net.ParseIP(host); ip != nil { + if isBlockedIP(ip) { + return fmt.Errorf("target IP %s is in a blocked private/reserved range", ip) + } + return nil + } + + // Resolve hostname to IPs and check each one + ctx, cancel := context.WithTimeout(context.Background(), dnsResolutionTimeout) + defer cancel() + + ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return fmt.Errorf("failed to resolve hostname %q: %w", host, err) + } + + if len(ips) == 0 { + return fmt.Errorf("hostname %q resolved to no IP addresses", host) + } + + for _, ipAddr := range ips { + if isBlockedIP(ipAddr.IP) { + return fmt.Errorf("hostname %q resolves to blocked IP %s (private/reserved range)", host, ipAddr.IP) + } + } + + return nil +} + +// NewSSRFSafeTransport creates an http.Transport with a custom DialContext +// that blocks connections to private/reserved IP addresses. This provides +// defense-in-depth SSRF protection at the network layer, catching cases +// where DNS records change between target creation and delivery time +// (DNS rebinding attacks). +func NewSSRFSafeTransport() *http.Transport { + return &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("ssrf: invalid address %q: %w", addr, err) + } + + // Resolve hostname to IPs + ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, fmt.Errorf("ssrf: DNS resolution failed for %q: %w", host, err) + } + + // Check all resolved IPs + for _, ipAddr := range ips { + if isBlockedIP(ipAddr.IP) { + return nil, fmt.Errorf("ssrf: connection to %s (%s) blocked — private/reserved IP range", host, ipAddr.IP) + } + } + + // Connect to the first allowed IP + var dialer net.Dialer + return dialer.DialContext(ctx, network, net.JoinHostPort(ips[0].IP.String(), port)) + }, + } +} diff --git a/internal/delivery/ssrf_test.go b/internal/delivery/ssrf_test.go new file mode 100644 index 0000000..3a12a03 --- /dev/null +++ b/internal/delivery/ssrf_test.go @@ -0,0 +1,142 @@ +package delivery + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIsBlockedIP_PrivateRanges(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + ip string + blocked bool + }{ + // Loopback + {"loopback 127.0.0.1", "127.0.0.1", true}, + {"loopback 127.0.0.2", "127.0.0.2", true}, + {"loopback 127.255.255.255", "127.255.255.255", true}, + + // RFC 1918 - Class A + {"10.0.0.0", "10.0.0.0", true}, + {"10.0.0.1", "10.0.0.1", true}, + {"10.255.255.255", "10.255.255.255", true}, + + // RFC 1918 - Class B + {"172.16.0.1", "172.16.0.1", true}, + {"172.31.255.255", "172.31.255.255", true}, + {"172.15.255.255", "172.15.255.255", false}, + {"172.32.0.0", "172.32.0.0", false}, + + // RFC 1918 - Class C + {"192.168.0.1", "192.168.0.1", true}, + {"192.168.255.255", "192.168.255.255", true}, + + // Link-local / cloud metadata + {"169.254.0.1", "169.254.0.1", true}, + {"169.254.169.254", "169.254.169.254", true}, + + // Public IPs (should NOT be blocked) + {"8.8.8.8", "8.8.8.8", false}, + {"1.1.1.1", "1.1.1.1", false}, + {"93.184.216.34", "93.184.216.34", false}, + + // IPv6 loopback + {"::1", "::1", true}, + + // IPv6 unique local + {"fd00::1", "fd00::1", true}, + {"fc00::1", "fc00::1", true}, + + // IPv6 link-local + {"fe80::1", "fe80::1", true}, + + // IPv6 public (should NOT be blocked) + {"2607:f8b0:4004:800::200e", "2607:f8b0:4004:800::200e", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ip := net.ParseIP(tt.ip) + require.NotNil(t, ip, "failed to parse IP %s", tt.ip) + assert.Equal(t, tt.blocked, isBlockedIP(ip), + "isBlockedIP(%s) = %v, want %v", tt.ip, isBlockedIP(ip), tt.blocked) + }) + } +} + +func TestValidateTargetURL_Blocked(t *testing.T) { + t.Parallel() + + blockedURLs := []string{ + "http://127.0.0.1/hook", + "http://127.0.0.1:8080/hook", + "https://10.0.0.1/hook", + "http://192.168.1.1/webhook", + "http://172.16.0.1/api", + "http://169.254.169.254/latest/meta-data/", + "http://[::1]/hook", + "http://[fc00::1]/hook", + "http://[fe80::1]/hook", + "http://0.0.0.0/hook", + } + + for _, u := range blockedURLs { + t.Run(u, func(t *testing.T) { + t.Parallel() + err := ValidateTargetURL(u) + assert.Error(t, err, "URL %s should be blocked", u) + }) + } +} + +func TestValidateTargetURL_Allowed(t *testing.T) { + t.Parallel() + + // These are public IPs and should be allowed + allowedURLs := []string{ + "https://example.com/hook", + "http://93.184.216.34/webhook", + "https://hooks.slack.com/services/T00/B00/xxx", + } + + for _, u := range allowedURLs { + t.Run(u, func(t *testing.T) { + t.Parallel() + err := ValidateTargetURL(u) + assert.NoError(t, err, "URL %s should be allowed", u) + }) + } +} + +func TestValidateTargetURL_InvalidScheme(t *testing.T) { + t.Parallel() + err := ValidateTargetURL("ftp://example.com/hook") + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported URL scheme") +} + +func TestValidateTargetURL_EmptyHost(t *testing.T) { + t.Parallel() + err := ValidateTargetURL("http:///path") + assert.Error(t, err) +} + +func TestValidateTargetURL_InvalidURL(t *testing.T) { + t.Parallel() + err := ValidateTargetURL("://invalid") + assert.Error(t, err) +} + +func TestBlockedNetworks_Initialized(t *testing.T) { + t.Parallel() + assert.NotEmpty(t, blockedNetworks, "blockedNetworks should be initialized") + // Should have at least the main RFC 1918 + loopback + link-local ranges + assert.GreaterOrEqual(t, len(blockedNetworks), 8, + "should have at least 8 blocked network ranges") +} diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 625a12f..47915f3 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -13,6 +13,7 @@ import ( "sneak.berlin/go/webhooker/internal/globals" "sneak.berlin/go/webhooker/internal/healthcheck" "sneak.berlin/go/webhooker/internal/logger" + "sneak.berlin/go/webhooker/internal/middleware" "sneak.berlin/go/webhooker/internal/session" "sneak.berlin/go/webhooker/templates" ) @@ -128,9 +129,13 @@ func (s *Handlers) renderTemplate(w http.ResponseWriter, r *http.Request, pageTe } } - // If data is a map, merge user info into it + // Get CSRF token from request context (set by CSRF middleware) + csrfToken := middleware.CSRFToken(r) + + // If data is a map, merge user info and CSRF token into it if m, ok := data.(map[string]interface{}); ok { m["User"] = userInfo + m["CSRFToken"] = csrfToken if err := tmpl.Execute(w, m); err != nil { s.log.Error("failed to execute template", "error", err) http.Error(w, "Internal server error", http.StatusInternalServerError) @@ -140,13 +145,15 @@ func (s *Handlers) renderTemplate(w http.ResponseWriter, r *http.Request, pageTe // Wrap data with base template data type templateDataWrapper struct { - User *UserInfo - Data interface{} + User *UserInfo + CSRFToken string + Data interface{} } wrapper := templateDataWrapper{ - User: userInfo, - Data: data, + User: userInfo, + CSRFToken: csrfToken, + Data: data, } if err := tmpl.Execute(w, wrapper); err != nil { diff --git a/internal/handlers/source_management.go b/internal/handlers/source_management.go index 66a4873..c211f07 100644 --- a/internal/handlers/source_management.go +++ b/internal/handlers/source_management.go @@ -8,6 +8,7 @@ import ( "github.com/go-chi/chi" "github.com/google/uuid" "sneak.berlin/go/webhooker/internal/database" + "sneak.berlin/go/webhooker/internal/delivery" ) // WebhookListItem holds data for the webhook list view. @@ -533,6 +534,17 @@ func (h *Handlers) HandleTargetCreate() http.HandlerFunc { http.Error(w, "URL is required for HTTP targets", http.StatusBadRequest) return } + + // Validate URL against SSRF: block private/reserved IP ranges + if err := delivery.ValidateTargetURL(url); err != nil { + h.log.Warn("target URL blocked by SSRF protection", + "url", url, + "error", err, + ) + http.Error(w, "Invalid target URL: "+err.Error(), http.StatusBadRequest) + return + } + cfg := map[string]interface{}{ "url": url, } diff --git a/internal/middleware/csrf.go b/internal/middleware/csrf.go new file mode 100644 index 0000000..ac86471 --- /dev/null +++ b/internal/middleware/csrf.go @@ -0,0 +1,114 @@ +package middleware + +import ( + "context" + "crypto/rand" + "encoding/hex" + "net/http" +) + +const ( + // csrfTokenLength is the byte length of generated CSRF tokens. + // 32 bytes = 64 hex characters, providing 256 bits of entropy. + csrfTokenLength = 32 + + // csrfSessionKey is the session key where the CSRF token is stored. + csrfSessionKey = "csrf_token" + + // csrfFormField is the HTML form field name for the CSRF token. + csrfFormField = "csrf_token" +) + +// csrfContextKey is the context key type for CSRF tokens. +type csrfContextKey struct{} + +// CSRFToken retrieves the CSRF token from the request context. +// Returns an empty string if no token is present. +func CSRFToken(r *http.Request) string { + if token, ok := r.Context().Value(csrfContextKey{}).(string); ok { + return token + } + return "" +} + +// CSRF returns middleware that provides CSRF protection for state-changing +// requests. For every request, it ensures a CSRF token exists in the +// session and makes it available via the request context. For POST, PUT, +// PATCH, and DELETE requests, it validates the submitted csrf_token form +// field against the session token. Requests with an invalid or missing +// token receive a 403 Forbidden response. +func (m *Middleware) CSRF() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sess, err := m.session.Get(r) + if err != nil { + m.log.Error("csrf: failed to get session", "error", err) + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + + // Ensure a CSRF token exists in the session + token, ok := sess.Values[csrfSessionKey].(string) + if !ok { + token = "" + } + if token == "" { + token, err = generateCSRFToken() + if err != nil { + m.log.Error("csrf: failed to generate token", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + sess.Values[csrfSessionKey] = token + if saveErr := m.session.Save(r, w, sess); saveErr != nil { + m.log.Error("csrf: failed to save session", "error", saveErr) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + } + + // Store token in context for templates + ctx := context.WithValue(r.Context(), csrfContextKey{}, token) + r = r.WithContext(ctx) + + // Validate token on state-changing methods + switch r.Method { + case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete: + submitted := r.FormValue(csrfFormField) + if !secureCompare(submitted, token) { + m.log.Warn("csrf: token mismatch", + "method", r.Method, + "path", r.URL.Path, + "remote_addr", r.RemoteAddr, + ) + http.Error(w, "Forbidden - invalid CSRF token", http.StatusForbidden) + return + } + } + + next.ServeHTTP(w, r) + }) + } +} + +// generateCSRFToken creates a cryptographically random hex-encoded token. +func generateCSRFToken() (string, error) { + b := make([]byte, csrfTokenLength) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} + +// secureCompare performs a constant-time string comparison to prevent +// timing attacks on CSRF token validation. +func secureCompare(a, b string) bool { + if len(a) != len(b) { + return false + } + var result byte + for i := 0; i < len(a); i++ { + result |= a[i] ^ b[i] + } + return result == 0 +} diff --git a/internal/middleware/csrf_test.go b/internal/middleware/csrf_test.go new file mode 100644 index 0000000..f0022ac --- /dev/null +++ b/internal/middleware/csrf_test.go @@ -0,0 +1,184 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "sneak.berlin/go/webhooker/internal/config" +) + +func TestCSRF_GETSetsToken(t *testing.T) { + t.Parallel() + m, _ := testMiddleware(t, config.EnvironmentDev) + + var gotToken string + handler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + gotToken = CSRFToken(r) + })) + + req := httptest.NewRequest(http.MethodGet, "/form", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.NotEmpty(t, gotToken, "CSRF token should be set in context on GET") + assert.Len(t, gotToken, csrfTokenLength*2, "CSRF token should be hex-encoded 32 bytes") +} + +func TestCSRF_POSTWithValidToken(t *testing.T) { + t.Parallel() + m, _ := testMiddleware(t, config.EnvironmentDev) + + // Use a separate handler for the GET to capture the token + var token string + getHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + token = CSRFToken(r) + })) + + // GET to establish the session and capture token + getReq := httptest.NewRequest(http.MethodGet, "/form", nil) + getW := httptest.NewRecorder() + getHandler.ServeHTTP(getW, getReq) + + cookies := getW.Result().Cookies() + require.NotEmpty(t, cookies) + require.NotEmpty(t, token) + + // POST handler that tracks whether it was called + var called bool + postHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + called = true + })) + + // POST with valid token + form := url.Values{csrfFormField: {token}} + postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode())) + postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range cookies { + postReq.AddCookie(c) + } + postW := httptest.NewRecorder() + + postHandler.ServeHTTP(postW, postReq) + + assert.True(t, called, "handler should be called with valid CSRF token") +} + +func TestCSRF_POSTWithoutToken(t *testing.T) { + t.Parallel() + m, _ := testMiddleware(t, config.EnvironmentDev) + + // GET handler to establish session + getHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + // no-op — just establishes session + })) + + getReq := httptest.NewRequest(http.MethodGet, "/form", nil) + getW := httptest.NewRecorder() + getHandler.ServeHTTP(getW, getReq) + cookies := getW.Result().Cookies() + + // POST handler that tracks whether it was called + var called bool + postHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + called = true + })) + + // POST without CSRF token + postReq := httptest.NewRequest(http.MethodPost, "/form", nil) + postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range cookies { + postReq.AddCookie(c) + } + postW := httptest.NewRecorder() + + postHandler.ServeHTTP(postW, postReq) + + assert.False(t, called, "handler should NOT be called without CSRF token") + assert.Equal(t, http.StatusForbidden, postW.Code) +} + +func TestCSRF_POSTWithInvalidToken(t *testing.T) { + t.Parallel() + m, _ := testMiddleware(t, config.EnvironmentDev) + + // GET handler to establish session + getHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + // no-op — just establishes session + })) + + getReq := httptest.NewRequest(http.MethodGet, "/form", nil) + getW := httptest.NewRecorder() + getHandler.ServeHTTP(getW, getReq) + cookies := getW.Result().Cookies() + + // POST handler that tracks whether it was called + var called bool + postHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + called = true + })) + + // POST with wrong CSRF token + form := url.Values{csrfFormField: {"invalid-token-value"}} + postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode())) + postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range cookies { + postReq.AddCookie(c) + } + postW := httptest.NewRecorder() + + postHandler.ServeHTTP(postW, postReq) + + assert.False(t, called, "handler should NOT be called with invalid CSRF token") + assert.Equal(t, http.StatusForbidden, postW.Code) +} + +func TestCSRF_GETDoesNotValidate(t *testing.T) { + t.Parallel() + m, _ := testMiddleware(t, config.EnvironmentDev) + + var called bool + handler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + called = true + })) + + // GET requests should pass through without CSRF validation + req := httptest.NewRequest(http.MethodGet, "/form", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.True(t, called, "GET requests should pass through CSRF middleware") +} + +func TestCSRFToken_NoContext(t *testing.T) { + t.Parallel() + req := httptest.NewRequest(http.MethodGet, "/", nil) + assert.Empty(t, CSRFToken(req), "CSRFToken should return empty string when no token in context") +} + +func TestGenerateCSRFToken(t *testing.T) { + t.Parallel() + token, err := generateCSRFToken() + require.NoError(t, err) + assert.Len(t, token, csrfTokenLength*2, "token should be hex-encoded") + + // Verify uniqueness + token2, err := generateCSRFToken() + require.NoError(t, err) + assert.NotEqual(t, token, token2, "each generated token should be unique") +} + +func TestSecureCompare(t *testing.T) { + t.Parallel() + assert.True(t, secureCompare("abc", "abc")) + assert.False(t, secureCompare("abc", "abd")) + assert.False(t, secureCompare("abc", "ab")) + assert.False(t, secureCompare("", "a")) + assert.True(t, secureCompare("", "")) +} diff --git a/internal/middleware/ratelimit.go b/internal/middleware/ratelimit.go new file mode 100644 index 0000000..db23efd --- /dev/null +++ b/internal/middleware/ratelimit.go @@ -0,0 +1,172 @@ +package middleware + +import ( + "net" + "net/http" + "sync" + "time" + + "golang.org/x/time/rate" +) + +const ( + // loginRateLimit is the maximum number of login attempts per interval. + loginRateLimit = 5 + + // loginRateInterval is the time window for the rate limit. + loginRateInterval = 1 * time.Minute + + // limiterCleanupInterval is how often stale per-IP limiters are pruned. + limiterCleanupInterval = 5 * time.Minute + + // limiterMaxAge is how long an unused limiter is kept before pruning. + limiterMaxAge = 10 * time.Minute +) + +// ipLimiter holds a rate limiter and the time it was last used. +type ipLimiter struct { + limiter *rate.Limiter + lastSeen time.Time +} + +// rateLimiterMap manages per-IP rate limiters with periodic cleanup. +type rateLimiterMap struct { + mu sync.Mutex + limiters map[string]*ipLimiter + rate rate.Limit + burst int +} + +// newRateLimiterMap creates a new per-IP rate limiter map. +func newRateLimiterMap(r rate.Limit, burst int) *rateLimiterMap { + rlm := &rateLimiterMap{ + limiters: make(map[string]*ipLimiter), + rate: r, + burst: burst, + } + + // Start background cleanup goroutine + go rlm.cleanup() + + return rlm +} + +// getLimiter returns the rate limiter for the given IP, creating one if +// it doesn't exist. +func (rlm *rateLimiterMap) getLimiter(ip string) *rate.Limiter { + rlm.mu.Lock() + defer rlm.mu.Unlock() + + if entry, ok := rlm.limiters[ip]; ok { + entry.lastSeen = time.Now() + return entry.limiter + } + + limiter := rate.NewLimiter(rlm.rate, rlm.burst) + rlm.limiters[ip] = &ipLimiter{ + limiter: limiter, + lastSeen: time.Now(), + } + return limiter +} + +// cleanup periodically removes stale rate limiters to prevent unbounded +// memory growth from unique IPs. +func (rlm *rateLimiterMap) cleanup() { + ticker := time.NewTicker(limiterCleanupInterval) + defer ticker.Stop() + + for range ticker.C { + rlm.mu.Lock() + cutoff := time.Now().Add(-limiterMaxAge) + for ip, entry := range rlm.limiters { + if entry.lastSeen.Before(cutoff) { + delete(rlm.limiters, ip) + } + } + rlm.mu.Unlock() + } +} + +// LoginRateLimit returns middleware that enforces per-IP rate limiting +// on login attempts. Only POST requests are rate-limited; GET requests +// (rendering the login form) pass through unaffected. When the rate +// limit is exceeded, a 429 Too Many Requests response is returned. +func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler { + // Calculate rate: loginRateLimit events per loginRateInterval + r := rate.Limit(float64(loginRateLimit) / loginRateInterval.Seconds()) + rlm := newRateLimiterMap(r, loginRateLimit) + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Only rate-limit POST requests (actual login attempts) + if r.Method != http.MethodPost { + next.ServeHTTP(w, r) + return + } + + ip := extractIP(r) + limiter := rlm.getLimiter(ip) + + if !limiter.Allow() { + m.log.Warn("login rate limit exceeded", + "ip", ip, + "path", r.URL.Path, + ) + http.Error(w, "Too many login attempts. Please try again later.", http.StatusTooManyRequests) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// extractIP extracts the client IP address from the request. It checks +// X-Forwarded-For and X-Real-IP headers first (for reverse proxy setups), +// then falls back to RemoteAddr. +func extractIP(r *http.Request) string { + // Check X-Forwarded-For header (first IP in chain) + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + // X-Forwarded-For can contain multiple IPs: client, proxy1, proxy2 + // The first one is the original client + for i := 0; i < len(xff); i++ { + if xff[i] == ',' { + ip := xff[:i] + // Trim whitespace + for len(ip) > 0 && ip[0] == ' ' { + ip = ip[1:] + } + for len(ip) > 0 && ip[len(ip)-1] == ' ' { + ip = ip[:len(ip)-1] + } + if ip != "" { + return ip + } + break + } + } + trimmed := xff + for len(trimmed) > 0 && trimmed[0] == ' ' { + trimmed = trimmed[1:] + } + for len(trimmed) > 0 && trimmed[len(trimmed)-1] == ' ' { + trimmed = trimmed[:len(trimmed)-1] + } + if trimmed != "" { + return trimmed + } + } + + // Check X-Real-IP header + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return xri + } + + // Fall back to RemoteAddr + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + return ip +} diff --git a/internal/middleware/ratelimit_test.go b/internal/middleware/ratelimit_test.go new file mode 100644 index 0000000..bd4e85f --- /dev/null +++ b/internal/middleware/ratelimit_test.go @@ -0,0 +1,121 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "sneak.berlin/go/webhooker/internal/config" +) + +func TestLoginRateLimit_AllowsGET(t *testing.T) { + t.Parallel() + m, _ := testMiddleware(t, config.EnvironmentDev) + + var callCount int + handler := m.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + callCount++ + w.WriteHeader(http.StatusOK) + })) + + // GET requests should never be rate-limited + for i := 0; i < 20; i++ { + req := httptest.NewRequest(http.MethodGet, "/pages/login", nil) + req.RemoteAddr = "192.168.1.1:12345" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code, "GET request %d should pass", i) + } + assert.Equal(t, 20, callCount) +} + +func TestLoginRateLimit_LimitsPOST(t *testing.T) { + t.Parallel() + m, _ := testMiddleware(t, config.EnvironmentDev) + + var callCount int + handler := m.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + callCount++ + w.WriteHeader(http.StatusOK) + })) + + // First loginRateLimit POST requests should succeed + for i := 0; i < loginRateLimit; i++ { + req := httptest.NewRequest(http.MethodPost, "/pages/login", nil) + req.RemoteAddr = "10.0.0.1:12345" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code, "POST request %d should pass", i) + } + + // Next POST should be rate-limited + req := httptest.NewRequest(http.MethodPost, "/pages/login", nil) + req.RemoteAddr = "10.0.0.1:12345" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + assert.Equal(t, http.StatusTooManyRequests, w.Code, "POST after limit should be 429") + assert.Equal(t, loginRateLimit, callCount) +} + +func TestLoginRateLimit_IndependentPerIP(t *testing.T) { + t.Parallel() + m, _ := testMiddleware(t, config.EnvironmentDev) + + handler := m.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Exhaust limit for IP1 + for i := 0; i < loginRateLimit; i++ { + req := httptest.NewRequest(http.MethodPost, "/pages/login", nil) + req.RemoteAddr = "1.2.3.4:12345" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + } + + // IP1 should be rate-limited + req := httptest.NewRequest(http.MethodPost, "/pages/login", nil) + req.RemoteAddr = "1.2.3.4:12345" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + assert.Equal(t, http.StatusTooManyRequests, w.Code) + + // IP2 should still be allowed + req2 := httptest.NewRequest(http.MethodPost, "/pages/login", nil) + req2.RemoteAddr = "5.6.7.8:12345" + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + assert.Equal(t, http.StatusOK, w2.Code, "different IP should not be affected") +} + +func TestExtractIP_RemoteAddr(t *testing.T) { + t.Parallel() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "192.168.1.100:54321" + assert.Equal(t, "192.168.1.100", extractIP(req)) +} + +func TestExtractIP_XForwardedFor(t *testing.T) { + t.Parallel() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "10.0.0.1:1234" + req.Header.Set("X-Forwarded-For", "203.0.113.50, 70.41.3.18, 150.172.238.178") + assert.Equal(t, "203.0.113.50", extractIP(req)) +} + +func TestExtractIP_XRealIP(t *testing.T) { + t.Parallel() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "10.0.0.1:1234" + req.Header.Set("X-Real-IP", "203.0.113.50") + assert.Equal(t, "203.0.113.50", extractIP(req)) +} + +func TestExtractIP_XForwardedForSingle(t *testing.T) { + t.Parallel() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "10.0.0.1:1234" + req.Header.Set("X-Forwarded-For", "203.0.113.50") + assert.Equal(t, "203.0.113.50", extractIP(req)) +} diff --git a/internal/server/routes.go b/internal/server/routes.go index 347b976..0169e43 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -58,11 +58,17 @@ func (s *Server) SetupRoutes() { }) } - // pages that are rendered server-side + // pages that are rendered server-side — CSRF-protected and with + // per-IP rate limiting on the login endpoint. s.router.Route("/pages", func(r chi.Router) { - // Login page (no auth required) - r.Get("/login", s.h.HandleLoginPage()) - r.Post("/login", s.h.HandleLoginSubmit()) + r.Use(s.mw.CSRF()) + + // Login page — rate-limited to prevent brute-force attacks + r.Group(func(r chi.Router) { + r.Use(s.mw.LoginRateLimit()) + r.Get("/login", s.h.HandleLoginPage()) + r.Post("/login", s.h.HandleLoginSubmit()) + }) // Logout (auth required) r.Post("/logout", s.h.HandleLogout()) @@ -70,11 +76,13 @@ func (s *Server) SetupRoutes() { // User profile routes s.router.Route("/user/{username}", func(r chi.Router) { + r.Use(s.mw.CSRF()) r.Get("/", s.h.HandleProfile()) }) - // Webhook management routes (require authentication) + // Webhook management routes (require authentication, CSRF-protected) s.router.Route("/sources", func(r chi.Router) { + r.Use(s.mw.CSRF()) r.Use(s.mw.RequireAuth()) r.Get("/", s.h.HandleSourceList()) // List all webhooks r.Get("/new", s.h.HandleSourceCreate()) // Show create form @@ -82,6 +90,7 @@ func (s *Server) SetupRoutes() { }) s.router.Route("/source/{sourceID}", func(r chi.Router) { + r.Use(s.mw.CSRF()) r.Use(s.mw.RequireAuth()) r.Get("/", s.h.HandleSourceDetail()) // View webhook details r.Get("/edit", s.h.HandleSourceEdit()) // Show edit form diff --git a/templates/login.html b/templates/login.html index b4b1d69..6e45e67 100644 --- a/templates/login.html +++ b/templates/login.html @@ -23,6 +23,7 @@ {{end}}