feat: add CSRF protection, SSRF prevention, and login rate limiting
All checks were successful
check / check (push) Successful in 5s
All checks were successful
check / check (push) Successful in 5s
Security hardening implementing three issues: CSRF Protection (#35): - Session-based CSRF tokens with cryptographically random generation - Constant-time token comparison to prevent timing attacks - CSRF middleware applied to /pages, /sources, /source, and /user routes - Hidden csrf_token field added to all 12+ POST forms in templates - Excluded from /webhook (inbound) and /api (stateless) routes SSRF Prevention (#36): - ValidateTargetURL blocks private/reserved IP ranges at target creation - Blocked ranges: 127.0.0.0/8, 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, 169.254.0.0/16, ::1, fc00::/7, fe80::/10, plus multicast, reserved, test-net, and CGN ranges - SSRF-safe HTTP transport with custom DialContext for defense-in-depth at delivery time (prevents DNS rebinding attacks) - Only http/https schemes allowed Login Rate Limiting (#37): - Per-IP rate limiter using golang.org/x/time/rate - 5 attempts per minute per IP on POST /pages/login - GET requests (form rendering) pass through unaffected - Automatic cleanup of stale per-IP limiter entries - X-Forwarded-For and X-Real-IP header support for reverse proxies Closes #35, closes #36, closes #37
This commit is contained in:
114
internal/middleware/csrf.go
Normal file
114
internal/middleware/csrf.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
// csrfTokenLength is the byte length of generated CSRF tokens.
|
||||
// 32 bytes = 64 hex characters, providing 256 bits of entropy.
|
||||
csrfTokenLength = 32
|
||||
|
||||
// csrfSessionKey is the session key where the CSRF token is stored.
|
||||
csrfSessionKey = "csrf_token"
|
||||
|
||||
// csrfFormField is the HTML form field name for the CSRF token.
|
||||
csrfFormField = "csrf_token"
|
||||
)
|
||||
|
||||
// csrfContextKey is the context key type for CSRF tokens.
|
||||
type csrfContextKey struct{}
|
||||
|
||||
// CSRFToken retrieves the CSRF token from the request context.
|
||||
// Returns an empty string if no token is present.
|
||||
func CSRFToken(r *http.Request) string {
|
||||
if token, ok := r.Context().Value(csrfContextKey{}).(string); ok {
|
||||
return token
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// CSRF returns middleware that provides CSRF protection for state-changing
|
||||
// requests. For every request, it ensures a CSRF token exists in the
|
||||
// session and makes it available via the request context. For POST, PUT,
|
||||
// PATCH, and DELETE requests, it validates the submitted csrf_token form
|
||||
// field against the session token. Requests with an invalid or missing
|
||||
// token receive a 403 Forbidden response.
|
||||
func (m *Middleware) CSRF() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sess, err := m.session.Get(r)
|
||||
if err != nil {
|
||||
m.log.Error("csrf: failed to get session", "error", err)
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure a CSRF token exists in the session
|
||||
token, ok := sess.Values[csrfSessionKey].(string)
|
||||
if !ok {
|
||||
token = ""
|
||||
}
|
||||
if token == "" {
|
||||
token, err = generateCSRFToken()
|
||||
if err != nil {
|
||||
m.log.Error("csrf: failed to generate token", "error", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
sess.Values[csrfSessionKey] = token
|
||||
if saveErr := m.session.Save(r, w, sess); saveErr != nil {
|
||||
m.log.Error("csrf: failed to save session", "error", saveErr)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Store token in context for templates
|
||||
ctx := context.WithValue(r.Context(), csrfContextKey{}, token)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
// Validate token on state-changing methods
|
||||
switch r.Method {
|
||||
case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete:
|
||||
submitted := r.FormValue(csrfFormField)
|
||||
if !secureCompare(submitted, token) {
|
||||
m.log.Warn("csrf: token mismatch",
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"remote_addr", r.RemoteAddr,
|
||||
)
|
||||
http.Error(w, "Forbidden - invalid CSRF token", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// generateCSRFToken creates a cryptographically random hex-encoded token.
|
||||
func generateCSRFToken() (string, error) {
|
||||
b := make([]byte, csrfTokenLength)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// secureCompare performs a constant-time string comparison to prevent
|
||||
// timing attacks on CSRF token validation.
|
||||
func secureCompare(a, b string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
var result byte
|
||||
for i := 0; i < len(a); i++ {
|
||||
result |= a[i] ^ b[i]
|
||||
}
|
||||
return result == 0
|
||||
}
|
||||
184
internal/middleware/csrf_test.go
Normal file
184
internal/middleware/csrf_test.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"sneak.berlin/go/webhooker/internal/config"
|
||||
)
|
||||
|
||||
func TestCSRF_GETSetsToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
var gotToken string
|
||||
handler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
gotToken = CSRFToken(r)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/form", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.NotEmpty(t, gotToken, "CSRF token should be set in context on GET")
|
||||
assert.Len(t, gotToken, csrfTokenLength*2, "CSRF token should be hex-encoded 32 bytes")
|
||||
}
|
||||
|
||||
func TestCSRF_POSTWithValidToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
// Use a separate handler for the GET to capture the token
|
||||
var token string
|
||||
getHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
token = CSRFToken(r)
|
||||
}))
|
||||
|
||||
// GET to establish the session and capture token
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
||||
getW := httptest.NewRecorder()
|
||||
getHandler.ServeHTTP(getW, getReq)
|
||||
|
||||
cookies := getW.Result().Cookies()
|
||||
require.NotEmpty(t, cookies)
|
||||
require.NotEmpty(t, token)
|
||||
|
||||
// POST handler that tracks whether it was called
|
||||
var called bool
|
||||
postHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
}))
|
||||
|
||||
// POST with valid token
|
||||
form := url.Values{csrfFormField: {token}}
|
||||
postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode()))
|
||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
for _, c := range cookies {
|
||||
postReq.AddCookie(c)
|
||||
}
|
||||
postW := httptest.NewRecorder()
|
||||
|
||||
postHandler.ServeHTTP(postW, postReq)
|
||||
|
||||
assert.True(t, called, "handler should be called with valid CSRF token")
|
||||
}
|
||||
|
||||
func TestCSRF_POSTWithoutToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
// GET handler to establish session
|
||||
getHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
// no-op — just establishes session
|
||||
}))
|
||||
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
||||
getW := httptest.NewRecorder()
|
||||
getHandler.ServeHTTP(getW, getReq)
|
||||
cookies := getW.Result().Cookies()
|
||||
|
||||
// POST handler that tracks whether it was called
|
||||
var called bool
|
||||
postHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
}))
|
||||
|
||||
// POST without CSRF token
|
||||
postReq := httptest.NewRequest(http.MethodPost, "/form", nil)
|
||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
for _, c := range cookies {
|
||||
postReq.AddCookie(c)
|
||||
}
|
||||
postW := httptest.NewRecorder()
|
||||
|
||||
postHandler.ServeHTTP(postW, postReq)
|
||||
|
||||
assert.False(t, called, "handler should NOT be called without CSRF token")
|
||||
assert.Equal(t, http.StatusForbidden, postW.Code)
|
||||
}
|
||||
|
||||
func TestCSRF_POSTWithInvalidToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
// GET handler to establish session
|
||||
getHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
// no-op — just establishes session
|
||||
}))
|
||||
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/form", nil)
|
||||
getW := httptest.NewRecorder()
|
||||
getHandler.ServeHTTP(getW, getReq)
|
||||
cookies := getW.Result().Cookies()
|
||||
|
||||
// POST handler that tracks whether it was called
|
||||
var called bool
|
||||
postHandler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
}))
|
||||
|
||||
// POST with wrong CSRF token
|
||||
form := url.Values{csrfFormField: {"invalid-token-value"}}
|
||||
postReq := httptest.NewRequest(http.MethodPost, "/form", strings.NewReader(form.Encode()))
|
||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
for _, c := range cookies {
|
||||
postReq.AddCookie(c)
|
||||
}
|
||||
postW := httptest.NewRecorder()
|
||||
|
||||
postHandler.ServeHTTP(postW, postReq)
|
||||
|
||||
assert.False(t, called, "handler should NOT be called with invalid CSRF token")
|
||||
assert.Equal(t, http.StatusForbidden, postW.Code)
|
||||
}
|
||||
|
||||
func TestCSRF_GETDoesNotValidate(t *testing.T) {
|
||||
t.Parallel()
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
var called bool
|
||||
handler := m.CSRF()(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
}))
|
||||
|
||||
// GET requests should pass through without CSRF validation
|
||||
req := httptest.NewRequest(http.MethodGet, "/form", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.True(t, called, "GET requests should pass through CSRF middleware")
|
||||
}
|
||||
|
||||
func TestCSRFToken_NoContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
assert.Empty(t, CSRFToken(req), "CSRFToken should return empty string when no token in context")
|
||||
}
|
||||
|
||||
func TestGenerateCSRFToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
token, err := generateCSRFToken()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, token, csrfTokenLength*2, "token should be hex-encoded")
|
||||
|
||||
// Verify uniqueness
|
||||
token2, err := generateCSRFToken()
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, token, token2, "each generated token should be unique")
|
||||
}
|
||||
|
||||
func TestSecureCompare(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.True(t, secureCompare("abc", "abc"))
|
||||
assert.False(t, secureCompare("abc", "abd"))
|
||||
assert.False(t, secureCompare("abc", "ab"))
|
||||
assert.False(t, secureCompare("", "a"))
|
||||
assert.True(t, secureCompare("", ""))
|
||||
}
|
||||
172
internal/middleware/ratelimit.go
Normal file
172
internal/middleware/ratelimit.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
const (
|
||||
// loginRateLimit is the maximum number of login attempts per interval.
|
||||
loginRateLimit = 5
|
||||
|
||||
// loginRateInterval is the time window for the rate limit.
|
||||
loginRateInterval = 1 * time.Minute
|
||||
|
||||
// limiterCleanupInterval is how often stale per-IP limiters are pruned.
|
||||
limiterCleanupInterval = 5 * time.Minute
|
||||
|
||||
// limiterMaxAge is how long an unused limiter is kept before pruning.
|
||||
limiterMaxAge = 10 * time.Minute
|
||||
)
|
||||
|
||||
// ipLimiter holds a rate limiter and the time it was last used.
|
||||
type ipLimiter struct {
|
||||
limiter *rate.Limiter
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
// rateLimiterMap manages per-IP rate limiters with periodic cleanup.
|
||||
type rateLimiterMap struct {
|
||||
mu sync.Mutex
|
||||
limiters map[string]*ipLimiter
|
||||
rate rate.Limit
|
||||
burst int
|
||||
}
|
||||
|
||||
// newRateLimiterMap creates a new per-IP rate limiter map.
|
||||
func newRateLimiterMap(r rate.Limit, burst int) *rateLimiterMap {
|
||||
rlm := &rateLimiterMap{
|
||||
limiters: make(map[string]*ipLimiter),
|
||||
rate: r,
|
||||
burst: burst,
|
||||
}
|
||||
|
||||
// Start background cleanup goroutine
|
||||
go rlm.cleanup()
|
||||
|
||||
return rlm
|
||||
}
|
||||
|
||||
// getLimiter returns the rate limiter for the given IP, creating one if
|
||||
// it doesn't exist.
|
||||
func (rlm *rateLimiterMap) getLimiter(ip string) *rate.Limiter {
|
||||
rlm.mu.Lock()
|
||||
defer rlm.mu.Unlock()
|
||||
|
||||
if entry, ok := rlm.limiters[ip]; ok {
|
||||
entry.lastSeen = time.Now()
|
||||
return entry.limiter
|
||||
}
|
||||
|
||||
limiter := rate.NewLimiter(rlm.rate, rlm.burst)
|
||||
rlm.limiters[ip] = &ipLimiter{
|
||||
limiter: limiter,
|
||||
lastSeen: time.Now(),
|
||||
}
|
||||
return limiter
|
||||
}
|
||||
|
||||
// cleanup periodically removes stale rate limiters to prevent unbounded
|
||||
// memory growth from unique IPs.
|
||||
func (rlm *rateLimiterMap) cleanup() {
|
||||
ticker := time.NewTicker(limiterCleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
rlm.mu.Lock()
|
||||
cutoff := time.Now().Add(-limiterMaxAge)
|
||||
for ip, entry := range rlm.limiters {
|
||||
if entry.lastSeen.Before(cutoff) {
|
||||
delete(rlm.limiters, ip)
|
||||
}
|
||||
}
|
||||
rlm.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// LoginRateLimit returns middleware that enforces per-IP rate limiting
|
||||
// on login attempts. Only POST requests are rate-limited; GET requests
|
||||
// (rendering the login form) pass through unaffected. When the rate
|
||||
// limit is exceeded, a 429 Too Many Requests response is returned.
|
||||
func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler {
|
||||
// Calculate rate: loginRateLimit events per loginRateInterval
|
||||
r := rate.Limit(float64(loginRateLimit) / loginRateInterval.Seconds())
|
||||
rlm := newRateLimiterMap(r, loginRateLimit)
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Only rate-limit POST requests (actual login attempts)
|
||||
if r.Method != http.MethodPost {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
ip := extractIP(r)
|
||||
limiter := rlm.getLimiter(ip)
|
||||
|
||||
if !limiter.Allow() {
|
||||
m.log.Warn("login rate limit exceeded",
|
||||
"ip", ip,
|
||||
"path", r.URL.Path,
|
||||
)
|
||||
http.Error(w, "Too many login attempts. Please try again later.", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// extractIP extracts the client IP address from the request. It checks
|
||||
// X-Forwarded-For and X-Real-IP headers first (for reverse proxy setups),
|
||||
// then falls back to RemoteAddr.
|
||||
func extractIP(r *http.Request) string {
|
||||
// Check X-Forwarded-For header (first IP in chain)
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// X-Forwarded-For can contain multiple IPs: client, proxy1, proxy2
|
||||
// The first one is the original client
|
||||
for i := 0; i < len(xff); i++ {
|
||||
if xff[i] == ',' {
|
||||
ip := xff[:i]
|
||||
// Trim whitespace
|
||||
for len(ip) > 0 && ip[0] == ' ' {
|
||||
ip = ip[1:]
|
||||
}
|
||||
for len(ip) > 0 && ip[len(ip)-1] == ' ' {
|
||||
ip = ip[:len(ip)-1]
|
||||
}
|
||||
if ip != "" {
|
||||
return ip
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
trimmed := xff
|
||||
for len(trimmed) > 0 && trimmed[0] == ' ' {
|
||||
trimmed = trimmed[1:]
|
||||
}
|
||||
for len(trimmed) > 0 && trimmed[len(trimmed)-1] == ' ' {
|
||||
trimmed = trimmed[:len(trimmed)-1]
|
||||
}
|
||||
if trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
|
||||
// Check X-Real-IP header
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return ip
|
||||
}
|
||||
121
internal/middleware/ratelimit_test.go
Normal file
121
internal/middleware/ratelimit_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"sneak.berlin/go/webhooker/internal/config"
|
||||
)
|
||||
|
||||
func TestLoginRateLimit_AllowsGET(t *testing.T) {
|
||||
t.Parallel()
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
var callCount int
|
||||
handler := m.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
callCount++
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// GET requests should never be rate-limited
|
||||
for i := 0; i < 20; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "/pages/login", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code, "GET request %d should pass", i)
|
||||
}
|
||||
assert.Equal(t, 20, callCount)
|
||||
}
|
||||
|
||||
func TestLoginRateLimit_LimitsPOST(t *testing.T) {
|
||||
t.Parallel()
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
var callCount int
|
||||
handler := m.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
callCount++
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// First loginRateLimit POST requests should succeed
|
||||
for i := 0; i < loginRateLimit; i++ {
|
||||
req := httptest.NewRequest(http.MethodPost, "/pages/login", nil)
|
||||
req.RemoteAddr = "10.0.0.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code, "POST request %d should pass", i)
|
||||
}
|
||||
|
||||
// Next POST should be rate-limited
|
||||
req := httptest.NewRequest(http.MethodPost, "/pages/login", nil)
|
||||
req.RemoteAddr = "10.0.0.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, w.Code, "POST after limit should be 429")
|
||||
assert.Equal(t, loginRateLimit, callCount)
|
||||
}
|
||||
|
||||
func TestLoginRateLimit_IndependentPerIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
m, _ := testMiddleware(t, config.EnvironmentDev)
|
||||
|
||||
handler := m.LoginRateLimit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Exhaust limit for IP1
|
||||
for i := 0; i < loginRateLimit; i++ {
|
||||
req := httptest.NewRequest(http.MethodPost, "/pages/login", nil)
|
||||
req.RemoteAddr = "1.2.3.4:12345"
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
// IP1 should be rate-limited
|
||||
req := httptest.NewRequest(http.MethodPost, "/pages/login", nil)
|
||||
req.RemoteAddr = "1.2.3.4:12345"
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||
|
||||
// IP2 should still be allowed
|
||||
req2 := httptest.NewRequest(http.MethodPost, "/pages/login", nil)
|
||||
req2.RemoteAddr = "5.6.7.8:12345"
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
assert.Equal(t, http.StatusOK, w2.Code, "different IP should not be affected")
|
||||
}
|
||||
|
||||
func TestExtractIP_RemoteAddr(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "192.168.1.100:54321"
|
||||
assert.Equal(t, "192.168.1.100", extractIP(req))
|
||||
}
|
||||
|
||||
func TestExtractIP_XForwardedFor(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "10.0.0.1:1234"
|
||||
req.Header.Set("X-Forwarded-For", "203.0.113.50, 70.41.3.18, 150.172.238.178")
|
||||
assert.Equal(t, "203.0.113.50", extractIP(req))
|
||||
}
|
||||
|
||||
func TestExtractIP_XRealIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "10.0.0.1:1234"
|
||||
req.Header.Set("X-Real-IP", "203.0.113.50")
|
||||
assert.Equal(t, "203.0.113.50", extractIP(req))
|
||||
}
|
||||
|
||||
func TestExtractIP_XForwardedForSingle(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "10.0.0.1:1234"
|
||||
req.Header.Set("X-Forwarded-For", "203.0.113.50")
|
||||
assert.Equal(t, "203.0.113.50", extractIP(req))
|
||||
}
|
||||
Reference in New Issue
Block a user