diff --git a/README.md b/README.md index bb5315b..f8056f5 100644 --- a/README.md +++ b/README.md @@ -157,6 +157,10 @@ It uses: logging with TTY detection (text for dev, JSON for prod) - **[gorilla/sessions](https://github.com/gorilla/sessions)** for encrypted cookie-based session management +- **[gorilla/csrf](https://github.com/gorilla/csrf)** for CSRF + protection (cookie-based double-submit tokens) +- **[go-chi/httprate](https://github.com/go-chi/httprate)** for + per-IP login rate limiting (sliding window counter) - **[Prometheus](https://prometheus.io)** for metrics, served at `/metrics` behind basic auth - **[Sentry](https://sentry.io)** for optional error reporting @@ -720,7 +724,8 @@ webhooker/ │ │ └── globals.go # Build-time variables (appname, version, arch) │ ├── delivery/ │ │ ├── engine.go # Event-driven delivery engine (channel + timer based) -│ │ └── circuit_breaker.go # Per-target circuit breaker for HTTP targets with retries +│ │ ├── circuit_breaker.go # Per-target circuit breaker for HTTP targets with retries +│ │ └── ssrf.go # SSRF prevention (IP validation, safe HTTP transport) │ ├── handlers/ │ │ ├── handlers.go # Base handler struct, JSON helpers, template rendering │ │ ├── auth.go # Login, logout handlers @@ -734,7 +739,9 @@ webhooker/ │ ├── logger/ │ │ └── logger.go # slog setup with TTY detection │ ├── middleware/ -│ │ └── middleware.go # Logging, CORS, Auth, Metrics, MetricsAuth, SecurityHeaders, MaxBodySize +│ │ ├── middleware.go # Logging, CORS, Auth, Metrics, MetricsAuth, SecurityHeaders, MaxBodySize +│ │ ├── csrf.go # CSRF protection middleware (gorilla/csrf) +│ │ └── ratelimit.go # Per-IP rate limiting middleware (go-chi/httprate) │ ├── server/ │ │ ├── server.go # Server struct, fx lifecycle, signal handling │ │ ├── http.go # HTTP server setup with timeouts @@ -821,6 +828,19 @@ Additionally, form endpoints (`/pages`, `/sources`, `/source/*`) apply a (`nosniff`), X-Frame-Options (`DENY`), Content-Security-Policy, Referrer-Policy, and Permissions-Policy - Request body size limits (1 MB) on all form POST endpoints +- **CSRF protection** via [gorilla/csrf](https://github.com/gorilla/csrf) + on all state-changing forms (cookie-based double-submit tokens with + HMAC authentication). Applied to `/pages`, `/sources`, `/source`, and + `/user` routes. Excluded from `/webhook` (inbound webhook POSTs) and + `/api` (stateless API) +- **SSRF prevention** for HTTP delivery targets: private/reserved IP + ranges (RFC 1918, loopback, link-local, cloud metadata) are blocked + both at target creation time (URL validation) and at delivery time + (custom HTTP transport with SSRF-safe dialer that validates resolved + IPs before connecting, preventing DNS rebinding attacks) +- **Login rate limiting** via [go-chi/httprate](https://github.com/go-chi/httprate): + per-IP sliding-window rate limiter on the login endpoint (5 POST + attempts per minute per IP) to prevent brute-force attacks - Prometheus metrics behind basic auth - Static assets embedded in binary (no filesystem access needed at runtime) @@ -907,7 +927,12 @@ linted, tested, and compiled. ### Remaining: Core Features - [ ] Per-webhook rate limiting in the receiver handler - [ ] Webhook signature verification (GitHub, Stripe formats) -- [ ] CSRF protection for forms +- [x] CSRF protection for forms + ([#35](https://git.eeqj.de/sneak/webhooker/issues/35)) +- [x] SSRF prevention for HTTP delivery targets + ([#36](https://git.eeqj.de/sneak/webhooker/issues/36)) +- [x] Login rate limiting (per-IP brute-force protection) + ([#37](https://git.eeqj.de/sneak/webhooker/issues/37)) - [ ] Session expiration and "remember me" - [ ] Password change/reset flow - [ ] API key authentication for programmatic access diff --git a/go.mod b/go.mod index 84cbebe..6f4ff2a 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module sneak.berlin/go/webhooker -go 1.23.0 +go 1.24.0 toolchain go1.24.1 @@ -9,7 +9,9 @@ require ( github.com/getsentry/sentry-go v0.25.0 github.com/go-chi/chi v1.5.5 github.com/go-chi/cors v1.2.1 + github.com/go-chi/httprate v0.15.0 github.com/google/uuid v1.6.0 + github.com/gorilla/csrf v1.7.3 github.com/gorilla/sessions v1.4.0 github.com/joho/godotenv v1.5.1 github.com/prometheus/client_golang v1.18.0 @@ -31,6 +33,7 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect + github.com/klauspost/cpuid/v2 v2.2.10 // indirect github.com/kr/text v0.2.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-sqlite3 v1.14.17 // indirect @@ -40,6 +43,7 @@ require ( github.com/prometheus/common v0.45.0 // indirect github.com/prometheus/procfs v0.12.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/zeebo/xxh3 v1.0.2 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/dig v1.17.0 // indirect go.uber.org/multierr v1.9.0 // indirect diff --git a/go.sum b/go.sum index 1571d29..f7e0ba6 100644 --- a/go.sum +++ b/go.sum @@ -19,6 +19,8 @@ github.com/go-chi/chi v1.5.5 h1:vOB/HbEMt9QqBqErz07QehcOKHaWFtuj87tTDVz2qXE= github.com/go-chi/chi v1.5.5/go.mod h1:C9JqLr3tIYjDOZpzn+BCuxY8z8vmca43EeMgyZt7irw= github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4= github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= +github.com/go-chi/httprate v0.15.0 h1:j54xcWV9KGmPf/X4H32/aTH+wBlrvxL7P+SdnRqxh5g= +github.com/go-chi/httprate v0.15.0/go.mod h1:rzGHhVrsBn3IMLYDOZQsSU4fJNWcjui4fWKJcCId1R4= github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= @@ -31,6 +33,8 @@ github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbu github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/csrf v1.7.3 h1:BHWt6FTLZAb2HtWT5KDBf6qgpZzvtbp9QWDRKZMXJC0= +github.com/gorilla/csrf v1.7.3/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ= @@ -43,6 +47,8 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= +github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE= +github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -80,6 +86,10 @@ github.com/stretchr/objx v0.5.1/go.mod h1:/iHQpkQwBD6DLUmQ4pE+s1TXdob1mORJ4/UFdr github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= +github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= +github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= +github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/dig v1.17.0 h1:5Chju+tUvcC+N7N6EV08BJz41UZuO3BmHcN4A287ZLI= diff --git a/internal/delivery/engine.go b/internal/delivery/engine.go index 77fd8df..78e692d 100644 --- a/internal/delivery/engine.go +++ b/internal/delivery/engine.go @@ -153,7 +153,8 @@ func New(lc fx.Lifecycle, params EngineParams) *Engine { dbManager: params.DBManager, log: params.Logger.Get(), client: &http.Client{ - Timeout: httpClientTimeout, + Timeout: httpClientTimeout, + Transport: NewSSRFSafeTransport(), }, deliveryCh: make(chan DeliveryTask, deliveryChannelSize), retryCh: make(chan DeliveryTask, retryChannelSize), diff --git a/internal/delivery/ssrf.go b/internal/delivery/ssrf.go new file mode 100644 index 0000000..73b5c3e --- /dev/null +++ b/internal/delivery/ssrf.go @@ -0,0 +1,153 @@ +package delivery + +import ( + "context" + "fmt" + "net" + "net/http" + "net/url" + "time" +) + +const ( + // dnsResolutionTimeout is the maximum time to wait for DNS resolution + // during SSRF validation. + dnsResolutionTimeout = 5 * time.Second +) + +// blockedNetworks contains all private/reserved IP ranges that should be +// blocked to prevent SSRF attacks. This includes RFC 1918 private +// addresses, loopback, link-local, and IPv6 equivalents. +// +//nolint:gochecknoglobals // package-level network list is appropriate here +var blockedNetworks []*net.IPNet + +//nolint:gochecknoinits // init is the idiomatic way to parse CIDRs once at startup +func init() { + cidrs := []string{ + // IPv4 private/reserved ranges + "127.0.0.0/8", // Loopback + "10.0.0.0/8", // RFC 1918 Class A private + "172.16.0.0/12", // RFC 1918 Class B private + "192.168.0.0/16", // RFC 1918 Class C private + "169.254.0.0/16", // Link-local (cloud metadata) + "0.0.0.0/8", // "This" network + "100.64.0.0/10", // Shared address space (CGN) + "192.0.0.0/24", // IETF protocol assignments + "192.0.2.0/24", // TEST-NET-1 + "198.18.0.0/15", // Benchmarking + "198.51.100.0/24", // TEST-NET-2 + "203.0.113.0/24", // TEST-NET-3 + "224.0.0.0/4", // Multicast + "240.0.0.0/4", // Reserved for future use + + // IPv6 private/reserved ranges + "::1/128", // Loopback + "fc00::/7", // Unique local addresses + "fe80::/10", // Link-local + } + + for _, cidr := range cidrs { + _, network, err := net.ParseCIDR(cidr) + if err != nil { + panic(fmt.Sprintf("ssrf: failed to parse CIDR %q: %v", cidr, err)) + } + blockedNetworks = append(blockedNetworks, network) + } +} + +// isBlockedIP checks whether an IP address falls within any blocked +// private/reserved network range. +func isBlockedIP(ip net.IP) bool { + for _, network := range blockedNetworks { + if network.Contains(ip) { + return true + } + } + return false +} + +// ValidateTargetURL checks that an HTTP delivery target URL is safe +// from SSRF attacks. It validates the URL format, resolves the hostname +// to IP addresses, and verifies that none of the resolved IPs are in +// blocked private/reserved ranges. +// +// Returns nil if the URL is safe, or an error describing the issue. +func ValidateTargetURL(targetURL string) error { + parsed, err := url.Parse(targetURL) + if err != nil { + return fmt.Errorf("invalid URL: %w", err) + } + + // Only allow http and https schemes + if parsed.Scheme != "http" && parsed.Scheme != "https" { + return fmt.Errorf("unsupported URL scheme %q: only http and https are allowed", parsed.Scheme) + } + + host := parsed.Hostname() + if host == "" { + return fmt.Errorf("URL has no hostname") + } + + // Check if the host is a raw IP address first + if ip := net.ParseIP(host); ip != nil { + if isBlockedIP(ip) { + return fmt.Errorf("target IP %s is in a blocked private/reserved range", ip) + } + return nil + } + + // Resolve hostname to IPs and check each one + ctx, cancel := context.WithTimeout(context.Background(), dnsResolutionTimeout) + defer cancel() + + ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return fmt.Errorf("failed to resolve hostname %q: %w", host, err) + } + + if len(ips) == 0 { + return fmt.Errorf("hostname %q resolved to no IP addresses", host) + } + + for _, ipAddr := range ips { + if isBlockedIP(ipAddr.IP) { + return fmt.Errorf("hostname %q resolves to blocked IP %s (private/reserved range)", host, ipAddr.IP) + } + } + + return nil +} + +// NewSSRFSafeTransport creates an http.Transport with a custom DialContext +// that blocks connections to private/reserved IP addresses. This provides +// defense-in-depth SSRF protection at the network layer, catching cases +// where DNS records change between target creation and delivery time +// (DNS rebinding attacks). +func NewSSRFSafeTransport() *http.Transport { + return &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("ssrf: invalid address %q: %w", addr, err) + } + + // Resolve hostname to IPs + ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, fmt.Errorf("ssrf: DNS resolution failed for %q: %w", host, err) + } + + // Check all resolved IPs + for _, ipAddr := range ips { + if isBlockedIP(ipAddr.IP) { + return nil, fmt.Errorf("ssrf: connection to %s (%s) blocked — private/reserved IP range", host, ipAddr.IP) + } + } + + // Connect to the first allowed IP + var dialer net.Dialer + return dialer.DialContext(ctx, network, net.JoinHostPort(ips[0].IP.String(), port)) + }, + } +} diff --git a/internal/delivery/ssrf_test.go b/internal/delivery/ssrf_test.go new file mode 100644 index 0000000..3a12a03 --- /dev/null +++ b/internal/delivery/ssrf_test.go @@ -0,0 +1,142 @@ +package delivery + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIsBlockedIP_PrivateRanges(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + ip string + blocked bool + }{ + // Loopback + {"loopback 127.0.0.1", "127.0.0.1", true}, + {"loopback 127.0.0.2", "127.0.0.2", true}, + {"loopback 127.255.255.255", "127.255.255.255", true}, + + // RFC 1918 - Class A + {"10.0.0.0", "10.0.0.0", true}, + {"10.0.0.1", "10.0.0.1", true}, + {"10.255.255.255", "10.255.255.255", true}, + + // RFC 1918 - Class B + {"172.16.0.1", "172.16.0.1", true}, + {"172.31.255.255", "172.31.255.255", true}, + {"172.15.255.255", "172.15.255.255", false}, + {"172.32.0.0", "172.32.0.0", false}, + + // RFC 1918 - Class C + {"192.168.0.1", "192.168.0.1", true}, + {"192.168.255.255", "192.168.255.255", true}, + + // Link-local / cloud metadata + {"169.254.0.1", "169.254.0.1", true}, + {"169.254.169.254", "169.254.169.254", true}, + + // Public IPs (should NOT be blocked) + {"8.8.8.8", "8.8.8.8", false}, + {"1.1.1.1", "1.1.1.1", false}, + {"93.184.216.34", "93.184.216.34", false}, + + // IPv6 loopback + {"::1", "::1", true}, + + // IPv6 unique local + {"fd00::1", "fd00::1", true}, + {"fc00::1", "fc00::1", true}, + + // IPv6 link-local + {"fe80::1", "fe80::1", true}, + + // IPv6 public (should NOT be blocked) + {"2607:f8b0:4004:800::200e", "2607:f8b0:4004:800::200e", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ip := net.ParseIP(tt.ip) + require.NotNil(t, ip, "failed to parse IP %s", tt.ip) + assert.Equal(t, tt.blocked, isBlockedIP(ip), + "isBlockedIP(%s) = %v, want %v", tt.ip, isBlockedIP(ip), tt.blocked) + }) + } +} + +func TestValidateTargetURL_Blocked(t *testing.T) { + t.Parallel() + + blockedURLs := []string{ + "http://127.0.0.1/hook", + "http://127.0.0.1:8080/hook", + "https://10.0.0.1/hook", + "http://192.168.1.1/webhook", + "http://172.16.0.1/api", + "http://169.254.169.254/latest/meta-data/", + "http://[::1]/hook", + "http://[fc00::1]/hook", + "http://[fe80::1]/hook", + "http://0.0.0.0/hook", + } + + for _, u := range blockedURLs { + t.Run(u, func(t *testing.T) { + t.Parallel() + err := ValidateTargetURL(u) + assert.Error(t, err, "URL %s should be blocked", u) + }) + } +} + +func TestValidateTargetURL_Allowed(t *testing.T) { + t.Parallel() + + // These are public IPs and should be allowed + allowedURLs := []string{ + "https://example.com/hook", + "http://93.184.216.34/webhook", + "https://hooks.slack.com/services/T00/B00/xxx", + } + + for _, u := range allowedURLs { + t.Run(u, func(t *testing.T) { + t.Parallel() + err := ValidateTargetURL(u) + assert.NoError(t, err, "URL %s should be allowed", u) + }) + } +} + +func TestValidateTargetURL_InvalidScheme(t *testing.T) { + t.Parallel() + err := ValidateTargetURL("ftp://example.com/hook") + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported URL scheme") +} + +func TestValidateTargetURL_EmptyHost(t *testing.T) { + t.Parallel() + err := ValidateTargetURL("http:///path") + assert.Error(t, err) +} + +func TestValidateTargetURL_InvalidURL(t *testing.T) { + t.Parallel() + err := ValidateTargetURL("://invalid") + assert.Error(t, err) +} + +func TestBlockedNetworks_Initialized(t *testing.T) { + t.Parallel() + assert.NotEmpty(t, blockedNetworks, "blockedNetworks should be initialized") + // Should have at least the main RFC 1918 + loopback + link-local ranges + assert.GreaterOrEqual(t, len(blockedNetworks), 8, + "should have at least 8 blocked network ranges") +} diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 625a12f..47915f3 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -13,6 +13,7 @@ import ( "sneak.berlin/go/webhooker/internal/globals" "sneak.berlin/go/webhooker/internal/healthcheck" "sneak.berlin/go/webhooker/internal/logger" + "sneak.berlin/go/webhooker/internal/middleware" "sneak.berlin/go/webhooker/internal/session" "sneak.berlin/go/webhooker/templates" ) @@ -128,9 +129,13 @@ func (s *Handlers) renderTemplate(w http.ResponseWriter, r *http.Request, pageTe } } - // If data is a map, merge user info into it + // Get CSRF token from request context (set by CSRF middleware) + csrfToken := middleware.CSRFToken(r) + + // If data is a map, merge user info and CSRF token into it if m, ok := data.(map[string]interface{}); ok { m["User"] = userInfo + m["CSRFToken"] = csrfToken if err := tmpl.Execute(w, m); err != nil { s.log.Error("failed to execute template", "error", err) http.Error(w, "Internal server error", http.StatusInternalServerError) @@ -140,13 +145,15 @@ func (s *Handlers) renderTemplate(w http.ResponseWriter, r *http.Request, pageTe // Wrap data with base template data type templateDataWrapper struct { - User *UserInfo - Data interface{} + User *UserInfo + CSRFToken string + Data interface{} } wrapper := templateDataWrapper{ - User: userInfo, - Data: data, + User: userInfo, + CSRFToken: csrfToken, + Data: data, } if err := tmpl.Execute(w, wrapper); err != nil { diff --git a/internal/handlers/source_management.go b/internal/handlers/source_management.go index 0a6fcf5..66e6c27 100644 --- a/internal/handlers/source_management.go +++ b/internal/handlers/source_management.go @@ -8,6 +8,7 @@ import ( "github.com/go-chi/chi" "github.com/google/uuid" "sneak.berlin/go/webhooker/internal/database" + "sneak.berlin/go/webhooker/internal/delivery" ) // WebhookListItem holds data for the webhook list view. @@ -533,6 +534,17 @@ func (h *Handlers) HandleTargetCreate() http.HandlerFunc { http.Error(w, "URL is required for HTTP targets", http.StatusBadRequest) return } + + // Validate URL against SSRF: block private/reserved IP ranges + if err := delivery.ValidateTargetURL(url); err != nil { + h.log.Warn("target URL blocked by SSRF protection", + "url", url, + "error", err, + ) + http.Error(w, "Invalid target URL: "+err.Error(), http.StatusBadRequest) + return + } + cfg := map[string]interface{}{ "url": url, } diff --git a/internal/middleware/csrf.go b/internal/middleware/csrf.go new file mode 100644 index 0000000..df2ffb2 --- /dev/null +++ b/internal/middleware/csrf.go @@ -0,0 +1,56 @@ +package middleware + +import ( + "net/http" + + "github.com/gorilla/csrf" +) + +// CSRFToken retrieves the CSRF token from the request context. +// Returns an empty string if the gorilla/csrf middleware has not run. +func CSRFToken(r *http.Request) string { + return csrf.Token(r) +} + +// CSRF returns middleware that provides CSRF protection using the +// gorilla/csrf library. The middleware uses the session authentication +// key to sign a CSRF cookie and validates a masked token submitted via +// the "csrf_token" form field (or the "X-CSRF-Token" header) on +// POST/PUT/PATCH/DELETE requests. Requests with an invalid or missing +// token receive a 403 Forbidden response. +// +// In development mode, requests are marked as plaintext HTTP so that +// gorilla/csrf skips the strict Referer-origin check (which is only +// meaningful over TLS). +func (m *Middleware) CSRF() func(http.Handler) http.Handler { + protect := csrf.Protect( + m.session.GetKey(), + csrf.FieldName("csrf_token"), + csrf.Secure(!m.params.Config.IsDev()), + csrf.SameSite(csrf.SameSiteLaxMode), + csrf.Path("/"), + csrf.ErrorHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + m.log.Warn("csrf: token validation failed", + "method", r.Method, + "path", r.URL.Path, + "remote_addr", r.RemoteAddr, + "reason", csrf.FailureReason(r), + ) + http.Error(w, "Forbidden - invalid CSRF token", http.StatusForbidden) + })), + ) + + // In development (plaintext HTTP), signal gorilla/csrf to skip + // the strict TLS Referer check by injecting the PlaintextHTTP + // context key before the CSRF handler sees the request. + if m.params.Config.IsDev() { + return func(next http.Handler) http.Handler { + csrfHandler := protect(next) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + csrfHandler.ServeHTTP(w, csrf.PlaintextHTTPRequest(r)) + }) + } + } + + return protect +} diff --git a/internal/middleware/csrf_test.go b/internal/middleware/csrf_test.go new file mode 100644 index 0000000..ae1791d --- /dev/null +++ b/internal/middleware/csrf_test.go @@ -0,0 +1,157 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "sneak.berlin/go/webhooker/internal/config" +) + +func TestCSRF_GETSetsToken(t *testing.T) { + t.Parallel() + m, _ := testMiddleware(t, config.EnvironmentDev) + + var gotToken string + handler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + gotToken = CSRFToken(r) + })) + + req := httptest.NewRequest(http.MethodGet, "/form", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.NotEmpty(t, gotToken, "CSRF token should be set in context on GET") +} + +func TestCSRF_POSTWithValidToken(t *testing.T) { + t.Parallel() + m, _ := testMiddleware(t, config.EnvironmentDev) + + // Capture the token from a GET request + var token string + csrfMiddleware := m.CSRF() + getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + token = CSRFToken(r) + })) + + getReq := httptest.NewRequest(http.MethodGet, "/form", nil) + getW := httptest.NewRecorder() + getHandler.ServeHTTP(getW, getReq) + + cookies := getW.Result().Cookies() + require.NotEmpty(t, cookies) + require.NotEmpty(t, token) + + // POST with valid token and cookies from the GET response + var called bool + postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + called = true + })) + + form := url.Values{"csrf_token": {token}} + postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode())) + postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range cookies { + postReq.AddCookie(c) + } + postW := httptest.NewRecorder() + + postHandler.ServeHTTP(postW, postReq) + + assert.True(t, called, "handler should be called with valid CSRF token") +} + +func TestCSRF_POSTWithoutToken(t *testing.T) { + t.Parallel() + m, _ := testMiddleware(t, config.EnvironmentDev) + + csrfMiddleware := m.CSRF() + + // GET to establish the CSRF cookie + getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) + getReq := httptest.NewRequest(http.MethodGet, "/form", nil) + getW := httptest.NewRecorder() + getHandler.ServeHTTP(getW, getReq) + cookies := getW.Result().Cookies() + + // POST without CSRF token + var called bool + postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + called = true + })) + + postReq := httptest.NewRequest(http.MethodPost, "/form", nil) + postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range cookies { + postReq.AddCookie(c) + } + postW := httptest.NewRecorder() + + postHandler.ServeHTTP(postW, postReq) + + assert.False(t, called, "handler should NOT be called without CSRF token") + assert.Equal(t, http.StatusForbidden, postW.Code) +} + +func TestCSRF_POSTWithInvalidToken(t *testing.T) { + t.Parallel() + m, _ := testMiddleware(t, config.EnvironmentDev) + + csrfMiddleware := m.CSRF() + + // GET to establish the CSRF cookie + getHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) + getReq := httptest.NewRequest(http.MethodGet, "/form", nil) + getW := httptest.NewRecorder() + getHandler.ServeHTTP(getW, getReq) + cookies := getW.Result().Cookies() + + // POST with wrong CSRF token + var called bool + postHandler := csrfMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + called = true + })) + + form := url.Values{"csrf_token": {"invalid-token-value"}} + postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode())) + postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range cookies { + postReq.AddCookie(c) + } + postW := httptest.NewRecorder() + + postHandler.ServeHTTP(postW, postReq) + + assert.False(t, called, "handler should NOT be called with invalid CSRF token") + assert.Equal(t, http.StatusForbidden, postW.Code) +} + +func TestCSRF_GETDoesNotValidate(t *testing.T) { + t.Parallel() + m, _ := testMiddleware(t, config.EnvironmentDev) + + var called bool + handler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + called = true + })) + + // GET requests should pass through without CSRF validation + req := httptest.NewRequest(http.MethodGet, "/form", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.True(t, called, "GET requests should pass through CSRF middleware") +} + +func TestCSRFToken_NoMiddleware(t *testing.T) { + t.Parallel() + req := httptest.NewRequest(http.MethodGet, "/", nil) + assert.Empty(t, CSRFToken(req), "CSRFToken should return empty string when middleware has not run") +} diff --git a/internal/middleware/middleware_test.go b/internal/middleware/middleware_test.go index 9267393..6702ad2 100644 --- a/internal/middleware/middleware_test.go +++ b/internal/middleware/middleware_test.go @@ -40,7 +40,7 @@ func testMiddleware(t *testing.T, env string) (*Middleware, *session.Session) { SameSite: http.SameSiteLaxMode, } - sessManager := newTestSession(t, store, cfg, log) + sessManager := newTestSession(t, store, cfg, log, key) m := &Middleware{ log: log, @@ -55,9 +55,9 @@ func testMiddleware(t *testing.T, env string) (*Middleware, *session.Session) { // newTestSession creates a session.Session with a pre-configured cookie store // for testing. This avoids needing the fx lifecycle and database. -func newTestSession(t *testing.T, store *sessions.CookieStore, cfg *config.Config, log *slog.Logger) *session.Session { +func newTestSession(t *testing.T, store *sessions.CookieStore, cfg *config.Config, log *slog.Logger, key []byte) *session.Session { t.Helper() - return session.NewForTest(store, cfg, log) + return session.NewForTest(store, cfg, log, key) } // --- Logging Middleware Tests --- @@ -326,7 +326,7 @@ func TestMetricsAuth_ValidCredentials(t *testing.T) { store := sessions.NewCookieStore(key) store.Options = &sessions.Options{Path: "/", MaxAge: 86400} - sessManager := session.NewForTest(store, cfg, log) + sessManager := session.NewForTest(store, cfg, log, key) m := &Middleware{ log: log, @@ -366,7 +366,7 @@ func TestMetricsAuth_InvalidCredentials(t *testing.T) { store := sessions.NewCookieStore(key) store.Options = &sessions.Options{Path: "/", MaxAge: 86400} - sessManager := session.NewForTest(store, cfg, log) + sessManager := session.NewForTest(store, cfg, log, key) m := &Middleware{ log: log, @@ -406,7 +406,7 @@ func TestMetricsAuth_NoCredentials(t *testing.T) { store := sessions.NewCookieStore(key) store.Options = &sessions.Options{Path: "/", MaxAge: 86400} - sessManager := session.NewForTest(store, cfg, log) + sessManager := session.NewForTest(store, cfg, log, key) m := &Middleware{ log: log, diff --git a/internal/middleware/ratelimit.go b/internal/middleware/ratelimit.go new file mode 100644 index 0000000..e82bec3 --- /dev/null +++ b/internal/middleware/ratelimit.go @@ -0,0 +1,48 @@ +package middleware + +import ( + "net/http" + "time" + + "github.com/go-chi/httprate" +) + +const ( + // loginRateLimit is the maximum number of login attempts per interval. + loginRateLimit = 5 + + // loginRateInterval is the time window for the rate limit. + loginRateInterval = 1 * time.Minute +) + +// LoginRateLimit returns middleware that enforces per-IP rate limiting +// on login attempts using go-chi/httprate. Only POST requests are +// rate-limited; GET requests (rendering the login form) pass through +// unaffected. When the rate limit is exceeded, a 429 Too Many Requests +// response is returned. IP extraction honours X-Forwarded-For, +// X-Real-IP, and True-Client-IP headers for reverse-proxy setups. +func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler { + limiter := httprate.Limit( + loginRateLimit, + loginRateInterval, + httprate.WithKeyFuncs(httprate.KeyByRealIP), + httprate.WithLimitHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + m.log.Warn("login rate limit exceeded", + "path", r.URL.Path, + ) + http.Error(w, "Too many login attempts. Please try again later.", http.StatusTooManyRequests) + })), + ) + + return func(next http.Handler) http.Handler { + limited := limiter(next) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Only rate-limit POST requests (actual login attempts) + if r.Method != http.MethodPost { + next.ServeHTTP(w, r) + return + } + limited.ServeHTTP(w, r) + }) + } +} diff --git a/internal/middleware/ratelimit_test.go b/internal/middleware/ratelimit_test.go new file mode 100644 index 0000000..6cea882 --- /dev/null +++ b/internal/middleware/ratelimit_test.go @@ -0,0 +1,90 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "sneak.berlin/go/webhooker/internal/config" +) + +func TestLoginRateLimit_AllowsGET(t *testing.T) { + t.Parallel() + m, _ := testMiddleware(t, config.EnvironmentDev) + + var callCount int + handler := m.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + callCount++ + w.WriteHeader(http.StatusOK) + })) + + // GET requests should never be rate-limited + for i := 0; i < 20; i++ { + req := httptest.NewRequest(http.MethodGet, "/pages/login", nil) + req.RemoteAddr = "192.168.1.1:12345" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code, "GET request %d should pass", i) + } + assert.Equal(t, 20, callCount) +} + +func TestLoginRateLimit_LimitsPOST(t *testing.T) { + t.Parallel() + m, _ := testMiddleware(t, config.EnvironmentDev) + + var callCount int + handler := m.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + callCount++ + w.WriteHeader(http.StatusOK) + })) + + // First loginRateLimit POST requests should succeed + for i := 0; i < loginRateLimit; i++ { + req := httptest.NewRequest(http.MethodPost, "/pages/login", nil) + req.RemoteAddr = "10.0.0.1:12345" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code, "POST request %d should pass", i) + } + + // Next POST should be rate-limited + req := httptest.NewRequest(http.MethodPost, "/pages/login", nil) + req.RemoteAddr = "10.0.0.1:12345" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + assert.Equal(t, http.StatusTooManyRequests, w.Code, "POST after limit should be 429") + assert.Equal(t, loginRateLimit, callCount) +} + +func TestLoginRateLimit_IndependentPerIP(t *testing.T) { + t.Parallel() + m, _ := testMiddleware(t, config.EnvironmentDev) + + handler := m.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Exhaust limit for IP1 + for i := 0; i < loginRateLimit; i++ { + req := httptest.NewRequest(http.MethodPost, "/pages/login", nil) + req.RemoteAddr = "1.2.3.4:12345" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + } + + // IP1 should be rate-limited + req := httptest.NewRequest(http.MethodPost, "/pages/login", nil) + req.RemoteAddr = "1.2.3.4:12345" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + assert.Equal(t, http.StatusTooManyRequests, w.Code) + + // IP2 should still be allowed + req2 := httptest.NewRequest(http.MethodPost, "/pages/login", nil) + req2.RemoteAddr = "5.6.7.8:12345" + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + assert.Equal(t, http.StatusOK, w2.Code, "different IP should not be affected") +} diff --git a/internal/server/routes.go b/internal/server/routes.go index c357614..6fd908e 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -64,13 +64,18 @@ func (s *Server) SetupRoutes() { }) } - // pages that are rendered server-side + // pages that are rendered server-side — CSRF-protected, body-size + // limited, and with per-IP rate limiting on the login endpoint. s.router.Route("/pages", func(r chi.Router) { + r.Use(s.mw.CSRF()) r.Use(s.mw.MaxBodySize(maxFormBodySize)) - // Login page (no auth required) - r.Get("/login", s.h.HandleLoginPage()) - r.Post("/login", s.h.HandleLoginSubmit()) + // Login page — rate-limited to prevent brute-force attacks + r.Group(func(r chi.Router) { + r.Use(s.mw.LoginRateLimit()) + r.Get("/login", s.h.HandleLoginPage()) + r.Post("/login", s.h.HandleLoginSubmit()) + }) // Logout (auth required) r.Post("/logout", s.h.HandleLogout()) @@ -78,11 +83,13 @@ func (s *Server) SetupRoutes() { // User profile routes s.router.Route("/user/{username}", func(r chi.Router) { + r.Use(s.mw.CSRF()) r.Get("/", s.h.HandleProfile()) }) - // Webhook management routes (require authentication) + // Webhook management routes (require authentication, CSRF-protected) s.router.Route("/sources", func(r chi.Router) { + r.Use(s.mw.CSRF()) r.Use(s.mw.RequireAuth()) r.Use(s.mw.MaxBodySize(maxFormBodySize)) r.Get("/", s.h.HandleSourceList()) // List all webhooks @@ -91,6 +98,7 @@ func (s *Server) SetupRoutes() { }) s.router.Route("/source/{sourceID}", func(r chi.Router) { + r.Use(s.mw.CSRF()) r.Use(s.mw.RequireAuth()) r.Use(s.mw.MaxBodySize(maxFormBodySize)) r.Get("/", s.h.HandleSourceDetail()) // View webhook details diff --git a/internal/session/session.go b/internal/session/session.go index 73a89ba..225ce15 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -39,6 +39,7 @@ type SessionParams struct { // Session manages encrypted session storage type Session struct { store *sessions.CookieStore + key []byte // raw 32-byte auth key, also used for CSRF cookie signing log *slog.Logger config *config.Config } @@ -79,6 +80,7 @@ func New(lc fx.Lifecycle, params SessionParams) (*Session, error) { SameSite: http.SameSiteLaxMode, } + s.key = keyBytes s.store = store s.log.Info("session manager initialized") return nil @@ -93,6 +95,12 @@ func (s *Session) Get(r *http.Request) (*sessions.Session, error) { return s.store.Get(r, SessionName) } +// GetKey returns the raw 32-byte authentication key used for session +// encryption. This key is also suitable for CSRF cookie signing. +func (s *Session) GetKey() []byte { + return s.key +} + // Save saves the session func (s *Session) Save(r *http.Request, w http.ResponseWriter, sess *sessions.Session) error { return sess.Save(r, w) diff --git a/internal/session/session_test.go b/internal/session/session_test.go index 7c694f8..0950be5 100644 --- a/internal/session/session_test.go +++ b/internal/session/session_test.go @@ -34,7 +34,7 @@ func testSession(t *testing.T) *Session { } log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug})) - return NewForTest(store, cfg, log) + return NewForTest(store, cfg, log, key) } // --- Get and Save Tests --- diff --git a/internal/session/testing.go b/internal/session/testing.go index 2d20420..d15c4af 100644 --- a/internal/session/testing.go +++ b/internal/session/testing.go @@ -9,10 +9,13 @@ import ( // NewForTest creates a Session with a pre-configured cookie store for use // in tests. This bypasses the fx lifecycle and database dependency, allowing -// middleware and handler tests to use real session functionality. -func NewForTest(store *sessions.CookieStore, cfg *config.Config, log *slog.Logger) *Session { +// middleware and handler tests to use real session functionality. The key +// parameter is the raw 32-byte authentication key used for session encryption +// and CSRF cookie signing. +func NewForTest(store *sessions.CookieStore, cfg *config.Config, log *slog.Logger, key []byte) *Session { return &Session{ store: store, + key: key, config: cfg, log: log, } diff --git a/templates/login.html b/templates/login.html index b4b1d69..6e45e67 100644 --- a/templates/login.html +++ b/templates/login.html @@ -23,6 +23,7 @@ {{end}}
+
+ {{else}} @@ -40,6 +41,7 @@ Sources Profile
+
{{else}} diff --git a/templates/source_detail.html b/templates/source_detail.html index 5a62a4b..5f20957 100644 --- a/templates/source_detail.html +++ b/templates/source_detail.html @@ -17,6 +17,7 @@ Event Log Edit
+
@@ -39,6 +40,7 @@
+
@@ -56,11 +58,13 @@ Inactive {{end}}
+
+
@@ -88,6 +92,7 @@
+
+
diff --git a/templates/source_edit.html b/templates/source_edit.html index 365146a..7bce014 100644 --- a/templates/source_edit.html +++ b/templates/source_edit.html @@ -15,6 +15,7 @@ {{end}}
+
diff --git a/templates/sources_new.html b/templates/sources_new.html index 321ce8f..60cae8d 100644 --- a/templates/sources_new.html +++ b/templates/sources_new.html @@ -15,6 +15,7 @@ {{end}} +