fix: extract real client IP from proxy headers (X-Real-IP / X-Forwarded-For)

Behind a reverse proxy like Traefik, RemoteAddr always contains the
proxy's IP. Add realIP() helper that checks X-Real-IP first, then the
first entry of X-Forwarded-For, falling back to RemoteAddr.

Update both LoginRateLimit and Logging middleware to use realIP().
Add comprehensive tests for the new function.

Fixes #12
This commit is contained in:
2026-02-15 21:14:12 -08:00
parent a1b06219e7
commit ef0786c4b4
2 changed files with 108 additions and 2 deletions

View File

@@ -7,6 +7,7 @@ import (
"math"
"net"
"net/http"
"strings"
"sync"
"time"
@@ -90,7 +91,7 @@ func (m *Middleware) Logging() func(http.Handler) http.Handler {
"request_id", reqID,
"referer", request.Referer(),
"proto", request.Proto,
"remoteIP", ipFromHostPort(request.RemoteAddr),
"remoteIP", realIP(request),
"status", lrw.statusCode,
"latency_ms", latency.Milliseconds(),
)
@@ -110,6 +111,28 @@ func ipFromHostPort(hostPort string) string {
return host
}
// realIP extracts the client's real IP address from the request,
// checking proxy headers first (trusted reverse proxy like Traefik),
// then falling back to RemoteAddr.
func realIP(r *http.Request) string {
// 1. X-Real-IP (set by Traefik/nginx)
if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
return ip
}
// 2. X-Forwarded-For: take the first (leftmost/client) IP
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
if parts := strings.SplitN(xff, ",", 2); len(parts) > 0 { //nolint:mnd
if ip := strings.TrimSpace(parts[0]); ip != "" {
return ip
}
}
}
// 3. Fall back to RemoteAddr
return ipFromHostPort(r.RemoteAddr)
}
// CORS returns CORS middleware.
func (m *Middleware) CORS() func(http.Handler) http.Handler {
return cors.Handler(cors.Options{
@@ -241,7 +264,7 @@ func (m *Middleware) LoginRateLimit() func(http.Handler) http.Handler {
writer http.ResponseWriter,
request *http.Request,
) {
ip := ipFromHostPort(request.RemoteAddr)
ip := realIP(request)
limiter := loginLimiter.getLimiter(ip)
if !limiter.Allow() {