routewatch/internal/server/middleware.go
sneak 3a9ec98d5c Add structured HTTP request logging and increase timeouts
- Replace chi's Logger middleware with structured slog-based logging
- Log request start (debug) and completion (info/warn/error by status)
- Include method, path, status, duration_ms, remote_addr in logs
- Increase request timeout from 8s to 30s for slow queries
- Add read/write/idle timeouts to HTTP server config
- Better server startup logging to confirm listening state
2025-12-30 13:37:54 +07:00

369 lines
9.4 KiB
Go

package server
import (
"bytes"
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"sync"
"time"
)
// responseWriter is a buffering stand-in for http.ResponseWriter. Body bytes
// are collected in memory and the status code is only recorded; neither is
// forwarded to the wrapped writer. Header access is the exception: it goes
// straight to the wrapped writer's header map.
type responseWriter struct {
	http.ResponseWriter
	body       *bytes.Buffer // captured response body
	statusCode int           // status recorded by the first effective WriteHeader call
	written    bool          // true once WriteHeader or Write has been called
	mu         sync.Mutex    // serializes Write and WriteHeader
}

// Header returns the header map of the wrapped writer; headers are not buffered.
func (rw *responseWriter) Header() http.Header {
	return rw.ResponseWriter.Header()
}

// WriteHeader records statusCode without forwarding it. Only a call made
// before any other WriteHeader or Write takes effect; later calls are ignored.
func (rw *responseWriter) WriteHeader(statusCode int) {
	rw.mu.Lock()
	defer rw.mu.Unlock()

	if rw.written {
		return
	}
	rw.statusCode, rw.written = statusCode, true
}

// Write appends b to the in-memory body and marks the response as started, so
// a subsequent WriteHeader no longer changes the recorded status.
func (rw *responseWriter) Write(b []byte) (int, error) {
	rw.mu.Lock()
	defer rw.mu.Unlock()

	rw.written = true
	return rw.body.Write(b)
}
// JSONResponseMiddleware is an HTTP middleware that wraps all JSON responses
// with a @meta field containing execution metadata. The metadata includes the
// time zone (always UTC), API version, and request execution time in milliseconds.
//
// Endpoints "/" and "/status" are excluded from this processing and passed through
// unchanged. Every other response is buffered in memory and replayed once the
// handler returns. The buffered response is replayed unchanged when its
// Content-Type is not exactly "application/json", when the body is empty, or when
// the body is not a JSON object (arrays, scalars, null, or invalid JSON).
//
// Existing members of the object are carried over as raw JSON, so numeric values
// keep their full precision instead of being converted to float64. An existing
// "@meta" member is replaced. Any Content-Length set by the handler is dropped
// when the body is rewritten, because the rewritten body has a different length.
func JSONResponseMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Skip non-JSON endpoints
		if r.URL.Path == "/" || r.URL.Path == "/status" {
			next.ServeHTTP(w, r)
			return
		}

		startTime := time.Now()

		// Buffer the handler's output so the body can be rewritten before
		// anything is sent to the client.
		rw := &responseWriter{
			ResponseWriter: w,
			body:           &bytes.Buffer{},
			statusCode:     http.StatusOK,
		}
		next.ServeHTTP(rw, r)
		execTime := time.Since(startTime)

		// passThrough replays the captured response exactly as the handler
		// produced it. Write errors are ignored: the response is already
		// committed at that point, so there is no way to report them.
		passThrough := func() {
			w.WriteHeader(rw.statusCode)
			_, _ = w.Write(rw.body.Bytes())
		}

		// Only process JSON responses
		if rw.Header().Get("Content-Type") != "application/json" || rw.body.Len() == 0 {
			passThrough()
			return
		}

		// A JSON null decodes into a nil map without an error; it is not an
		// object, and assigning into the nil map would panic.
		var members map[string]json.RawMessage
		if err := json.Unmarshal(rw.body.Bytes(), &members); err != nil || members == nil {
			passThrough()
			return
		}

		meta, err := json.Marshal(map[string]interface{}{
			"time_zone":      "UTC",
			"api_version":    1,
			"execution_time": fmt.Sprintf("%d ms", execTime.Milliseconds()),
		})
		if err != nil {
			passThrough()
			return
		}
		members["@meta"] = meta

		// Write the enhanced response
		w.Header().Set("Content-Type", "application/json")
		w.Header().Del("Content-Length")
		w.WriteHeader(rw.statusCode)
		// The status line is already sent, so an encode/write failure cannot
		// be turned into a different response.
		_ = json.NewEncoder(w).Encode(members)
	})
}
// timeoutWriter wraps the ResponseWriter handed to a handler that runs under
// TimeoutMiddleware. It serializes the handler's writes with the middleware's
// timeout response and, once the deadline has passed, cuts the handler off from
// the real writer.
type timeoutWriter struct {
	http.ResponseWriter
	mu          sync.Mutex
	written     bool        // set by markWritten once the deadline has passed
	header      http.Header // detached header map handed out after timeout
	wroteHeader bool        // handler already sent a final status or body bytes
}

// Write forwards b to the wrapped writer. After the timeout it writes nothing
// and returns http.ErrHandlerTimeout, so the handler can tell that its output
// was dropped (returning 0 with a nil error would break the io.Writer contract).
func (tw *timeoutWriter) Write(b []byte) (int, error) {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	if tw.written {
		return 0, http.ErrHandlerTimeout
	}
	tw.wroteHeader = true
	return tw.ResponseWriter.Write(b)
}

// WriteHeader forwards statusCode to the wrapped writer unless the timeout has
// already fired, in which case the call is discarded.
func (tw *timeoutWriter) WriteHeader(statusCode int) {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	if tw.written {
		return // Discard writes after timeout
	}
	// Informational 1xx responses (except 101) may still be followed by a
	// final status, so they do not count as a committed response.
	if statusCode >= http.StatusOK || statusCode == http.StatusSwitchingProtocols {
		tw.wroteHeader = true
	}
	tw.ResponseWriter.WriteHeader(statusCode)
}

// Header returns the wrapped writer's header map. After the timeout it returns
// a separate, initially empty map instead, so late header changes made by the
// handler never reach the real response.
func (tw *timeoutWriter) Header() http.Header {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	if tw.written {
		if tw.header == nil {
			tw.header = make(http.Header)
		}
		return tw.header
	}
	return tw.ResponseWriter.Header()
}

// markWritten flags the writer as timed out; every later Write and WriteHeader
// from the handler is rejected.
func (tw *timeoutWriter) markWritten() {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	tw.written = true
}

// TimeoutMiddleware creates an HTTP middleware that enforces a request timeout.
// If the handler does not complete within the specified duration, the middleware
// returns a JSON error response with HTTP status 408 (Request Timeout).
//
// The timeout parameter specifies the maximum duration allowed for request processing.
// The handler runs in its own goroutine with a context derived from the request
// context, so the timeout branch is also taken when the request context itself is
// cancelled. The handler goroutine is not stopped on timeout; it keeps running until
// it returns, and any Write it attempts afterwards fails with http.ErrHandlerTimeout.
//
// If the handler has already sent a status code or body bytes when the deadline
// passes, no error response is written: the status line is already committed and
// appending the error JSON would corrupt the partial body.
//
// The returned middleware handles panics from the wrapped handler by re-panicking
// on the calling goroutine, and prevents concurrent writes to the response after
// timeout occurs.
func TimeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			startTime := time.Now()

			ctx, cancel := context.WithTimeout(r.Context(), timeout)
			defer cancel()

			tw := &timeoutWriter{
				ResponseWriter: w,
				header:         make(http.Header),
			}

			done := make(chan struct{})
			// Buffered so a handler that panics after the timeout does not
			// block forever trying to report it.
			panicChan := make(chan interface{}, 1)

			go func() {
				defer func() {
					if p := recover(); p != nil {
						panicChan <- p
					}
				}()
				next.ServeHTTP(tw, r.WithContext(ctx))
				close(done)
			}()

			select {
			case p := <-panicChan:
				panic(p)
			case <-done:
				return
			case <-ctx.Done():
				tw.markWritten() // Prevent the handler from writing after timeout
				execTime := time.Since(startTime)

				// wroteHeader can no longer change once markWritten has run.
				// Writing to w directly is safe for the same reason: the
				// handler's own writes are now rejected by tw.
				tw.mu.Lock()
				committed := tw.wroteHeader
				if !committed {
					w.Header().Set("Content-Type", "application/json")
					// A length announced by the handler does not match the
					// error body written below.
					w.Header().Del("Content-Length")
					w.WriteHeader(http.StatusRequestTimeout)
				}
				tw.mu.Unlock()

				if committed {
					return
				}

				response := map[string]interface{}{
					"status": "error",
					"error": map[string]interface{}{
						"msg":  "Request timeout",
						"code": http.StatusRequestTimeout,
					},
					"@meta": map[string]interface{}{
						"time_zone":      "UTC",
						"api_version":    1,
						"execution_time": fmt.Sprintf("%d ms", execTime.Milliseconds()),
					},
				}
				// The status line is already sent; an encode/write failure
				// cannot be reported to the client any more.
				_ = json.NewEncoder(w).Encode(response)
			}
		})
	}
}
// JSONValidationMiddleware is an HTTP middleware that validates JSON API responses.
// It ensures that responses with Content-Type "application/json" contain valid JSON.
// A response without any Content-Type is treated as a JSON response as well.
//
// If a response is not valid JSON or is empty when JSON is expected, the middleware
// returns a properly formatted JSON error response. For timeout errors (status 408),
// the error message will be "Request timeout". For other errors, it returns
// "Internal server error" with status 500 if the original status was 200; any other
// status is kept as is.
//
// Non-JSON responses are passed through unchanged, whether or not they carry a body.
// Empty responses with status 204 (No Content) or 304 (Not Modified) are passed
// through as well, since those statuses do not carry a body.
//
// The whole response is buffered in memory until the handler returns.
func JSONValidationMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Create a custom response writer to capture the response
		rw := &responseWriter{
			ResponseWriter: w,
			body:           &bytes.Buffer{},
			statusCode:     http.StatusOK,
		}

		// Serve the request
		next.ServeHTTP(rw, r)

		body := rw.body.Bytes()

		// Check if it's meant to be a JSON response
		contentType := rw.Header().Get("Content-Type")
		isJSON := contentType == "application/json" || contentType == ""

		// An empty body is the expected shape of a 204/304 response, not a
		// sign of a failed handler.
		emptyByDesign := len(body) == 0 &&
			(rw.statusCode == http.StatusNoContent || rw.statusCode == http.StatusNotModified)

		if !isJSON || emptyByDesign {
			w.WriteHeader(rw.statusCode)
			// The response is committed; a write error cannot be reported.
			_, _ = w.Write(body)
			return
		}

		// For JSON responses, validate the JSON
		if len(body) > 0 && json.Valid(body) {
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(rw.statusCode)
			_, _ = w.Write(body)
			return
		}

		// If we get here, either there's no body or invalid JSON.
		// Determine appropriate status code
		statusCode := rw.statusCode
		if statusCode == http.StatusOK {
			statusCode = http.StatusInternalServerError
		}

		errorMsg := "Internal server error"
		if statusCode == http.StatusRequestTimeout {
			errorMsg = "Request timeout"
		}

		// Write a proper error response. A Content-Length announced for the
		// discarded body would not match the replacement, so drop it.
		w.Header().Set("Content-Type", "application/json")
		w.Header().Del("Content-Length")
		w.WriteHeader(statusCode)

		response := map[string]interface{}{
			"status": "error",
			"error": map[string]interface{}{
				"msg":  errorMsg,
				"code": statusCode,
			},
		}
		// The status line is already sent; an encode/write failure cannot be
		// reported to the client any more.
		_ = json.NewEncoder(w).Encode(response)
	})
}
// statusWriter is a pass-through http.ResponseWriter that remembers the status
// code of the response while forwarding every call to the wrapped writer.
type statusWriter struct {
	http.ResponseWriter
	statusCode int  // first status seen, or 200 when the body was written first
	written    bool // true once a status has been recorded
}

// Write forwards b to the wrapped writer. When no status has been recorded yet
// it records 200 OK first.
func (sw *statusWriter) Write(b []byte) (int, error) {
	if !sw.written {
		sw.statusCode, sw.written = http.StatusOK, true
	}
	return sw.ResponseWriter.Write(b)
}

// WriteHeader records statusCode if it is the first status seen and forwards
// the call to the wrapped writer in every case.
func (sw *statusWriter) WriteHeader(statusCode int) {
	if !sw.written {
		sw.statusCode, sw.written = statusCode, true
	}
	sw.ResponseWriter.WriteHeader(statusCode)
}
// RequestLoggerMiddleware creates a structured logging middleware using slog.
//
// For every request it emits a debug record ("HTTP request started") before
// calling the next handler and a completion record ("HTTP request completed")
// after it returns. The completion level follows the response status: error
// for 5xx, warn for 4xx, info for everything else. Both records are logged with
// the request context, so context-aware slog handlers see the same context for
// the start and the completion of a request.
//
// A nil logger falls back to slog.Default().
func RequestLoggerMiddleware(logger *slog.Logger) func(http.Handler) http.Handler {
	if logger == nil {
		logger = slog.Default()
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			start := time.Now()
			ctx := r.Context()

			// Wrap response writer to capture status
			sw := &statusWriter{ResponseWriter: w, statusCode: http.StatusOK}

			// Log request start
			logger.DebugContext(ctx, "HTTP request started",
				"method", r.Method,
				"path", r.URL.Path,
				"remote_addr", r.RemoteAddr,
				"user_agent", r.UserAgent(),
			)

			// Serve the request
			next.ServeHTTP(sw, r)

			// Log request completion
			duration := time.Since(start)

			logLevel := slog.LevelInfo
			switch {
			case sw.statusCode >= http.StatusInternalServerError:
				logLevel = slog.LevelError
			case sw.statusCode >= http.StatusBadRequest:
				logLevel = slog.LevelWarn
			}

			logger.Log(ctx, logLevel, "HTTP request completed",
				"method", r.Method,
				"path", r.URL.Path,
				"status", sw.statusCode,
				"duration_ms", duration.Milliseconds(),
				"remote_addr", r.RemoteAddr,
			)
		})
	}
}