routewatch/internal/server/middleware.go
sneak 7d39bd18bc Fix concurrent map write panic in timeout middleware
- Add thread-safe header wrapper in timeoutWriter
- Check context cancellation before writing responses in handlers
- Protect header access after timeout with mutex
- Prevents race condition when requests timeout while handlers are still running
2025-07-28 21:54:58 +02:00

220 lines
4.8 KiB
Go

package server
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"sync"
"time"
)
// responseWriter buffers a handler's status code and body so that
// JSONResponseMiddleware can inspect and rewrite them before anything reaches
// the client. Header() is not buffered: it returns the wrapped writer's header
// map directly.
type responseWriter struct {
	http.ResponseWriter
	body       *bytes.Buffer
	statusCode int
	written    bool // set by the first Write or WriteHeader; later status codes are ignored
	mu         sync.Mutex
}

// Write appends b to the buffered body. A Write before any WriteHeader leaves
// statusCode at its initial value and makes later WriteHeader calls no-ops.
func (rw *responseWriter) Write(b []byte) (int, error) {
	rw.mu.Lock()
	defer rw.mu.Unlock()
	rw.written = true
	return rw.body.Write(b)
}

// WriteHeader records statusCode if nothing has been written yet.
func (rw *responseWriter) WriteHeader(statusCode int) {
	rw.mu.Lock()
	defer rw.mu.Unlock()
	if !rw.written {
		rw.statusCode = statusCode
		rw.written = true
	}
}

// Header returns the wrapped writer's header map.
func (rw *responseWriter) Header() http.Header {
	return rw.ResponseWriter.Header()
}

// JSONResponseMiddleware adds an "@meta" object (time zone, API version and
// handler execution time) to JSON object responses.
//
// Requests for "/" and "/status" bypass the middleware entirely. For all other
// paths the handler's body is buffered and rewritten only when the
// Content-Type header is exactly "application/json" and the body is a single
// JSON object. Anything else (other content types, empty bodies, arrays, null,
// invalid JSON, trailing data) is passed through unchanged with the handler's
// status code.
func JSONResponseMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/" || r.URL.Path == "/status" {
			next.ServeHTTP(w, r)
			return
		}

		startTime := time.Now()
		rw := &responseWriter{
			ResponseWriter: w,
			body:           &bytes.Buffer{},
			statusCode:     http.StatusOK,
		}
		next.ServeHTTP(rw, r)
		execTime := time.Since(startTime)

		body := rw.body.Bytes()
		writeOriginal := func() {
			w.WriteHeader(rw.statusCode)
			_, _ = w.Write(body) // nothing useful to do if the client has gone away
		}

		if rw.Header().Get("Content-Type") != "application/json" || len(body) == 0 {
			writeOriginal()
			return
		}

		// UseNumber keeps numbers as their original literals, so integers
		// wider than float64's 53-bit mantissa survive the round trip.
		dec := json.NewDecoder(bytes.NewReader(body))
		dec.UseNumber()
		var originalResponse map[string]interface{}
		// A literal `null` decodes without error into a nil map, which would
		// panic on the "@meta" assignment below, so treat it as "not an object".
		if err := dec.Decode(&originalResponse); err != nil || originalResponse == nil {
			writeOriginal()
			return
		}
		// Decoder stops after the first value; reject trailing data the way
		// json.Unmarshal would.
		if len(bytes.Trim(body[dec.InputOffset():], " \t\r\n")) != 0 {
			writeOriginal()
			return
		}

		originalResponse["@meta"] = map[string]interface{}{
			"time_zone":      "UTC",
			"api_version":    1,
			"execution_time": fmt.Sprintf("%d ms", execTime.Milliseconds()),
		}

		// The rewritten body differs in length from the handler's, so any
		// Content-Length the handler set would now be wrong.
		w.Header().Del("Content-Length")
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(rw.statusCode)
		_ = json.NewEncoder(w).Encode(originalResponse) // best effort, as above
	})
}
// timeoutWriter is the http.ResponseWriter handed to handlers wrapped by
// TimeoutMiddleware. It lets the middleware cut a handler off at its deadline
// without sharing mutable state with the still-running handler goroutine:
//
//   - the handler only ever sees header, a private header map. It is copied
//     to the underlying writer under mu when the handler sends its status
//     line, and again once the handler returns;
//   - once written is set (on timeout), Write and WriteHeader are discarded.
type timeoutWriter struct {
	http.ResponseWriter
	mu          sync.Mutex
	written     bool        // set on timeout; handler output is discarded from then on
	wroteHeader bool        // the handler has sent a final (non-1xx) status line
	header      http.Header // handler-private header map
}

// Write forwards b to the underlying writer, first sending an implicit 200
// status line if the handler has not sent one. After a timeout it writes
// nothing and returns http.ErrHandlerTimeout. io.Writer requires a non-nil
// error whenever fewer than len(b) bytes are written, and the error lets the
// handler stop early.
func (tw *timeoutWriter) Write(b []byte) (int, error) {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	if tw.written {
		return 0, http.ErrHandlerTimeout
	}
	if !tw.wroteHeader {
		tw.writeHeaderLocked(http.StatusOK)
	}
	return tw.ResponseWriter.Write(b)
}

// WriteHeader publishes the handler's headers and forwards statusCode to the
// underlying writer, unless the request has already timed out.
func (tw *timeoutWriter) WriteHeader(statusCode int) {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	if tw.written {
		return // discard: the timeout response owns the underlying writer
	}
	tw.writeHeaderLocked(statusCode)
}

// Header returns the handler's private header map. It never returns the
// underlying writer's map, so a handler that keeps setting headers after its
// deadline cannot race with the goroutine writing the timeout response.
func (tw *timeoutWriter) Header() http.Header {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	if tw.header == nil {
		tw.header = make(http.Header)
	}
	return tw.header
}

// writeHeaderLocked copies the private headers to the underlying writer and
// sends statusCode. The caller must hold tw.mu.
func (tw *timeoutWriter) writeHeaderLocked(statusCode int) {
	tw.syncHeaderLocked()
	tw.ResponseWriter.WriteHeader(statusCode)
	// 1xx informational responses (other than 101) can be followed by the
	// final status, so they do not count as having sent the header.
	if statusCode >= 200 || statusCode == http.StatusSwitchingProtocols {
		tw.wroteHeader = true
	}
}

// syncHeaderLocked makes the underlying header map mirror the private one.
// Value slices are copied so the two maps share no storage. The caller must
// hold tw.mu (or otherwise know the handler has returned).
func (tw *timeoutWriter) syncHeaderLocked() {
	if tw.header == nil {
		return // the handler never obtained a header map
	}
	dst := tw.ResponseWriter.Header()
	for k := range dst {
		if _, ok := tw.header[k]; !ok {
			delete(dst, k)
		}
	}
	for k, v := range tw.header {
		// A nil slice stays nil, keeping net/http's "set to nil to suppress"
		// convention (e.g. for Date) intact.
		dst[k] = append([]string(nil), v...)
	}
}

// markWritten stops all further handler output. It reports whether the
// handler had already sent its status line, in which case a timeout response
// can no longer be delivered.
func (tw *timeoutWriter) markWritten() (headerSent bool) {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	tw.written = true
	return tw.wroteHeader
}

// TimeoutMiddleware returns middleware that runs each request's handler in its
// own goroutine. The handler receives a request context that ends after
// timeout.
//
// If that context ends (deadline reached or parent context canceled) before
// the handler returns, what happens depends on whether the handler has sent
// its status line:
//   - If it has not, the client receives a 408 JSON error body carrying the
//     same "@meta" fields as other API responses.
//   - If it has, the partial response is left as is, because its status can
//     no longer be changed.
//
// In both cases, later writes from the handler are discarded and return
// http.ErrHandlerTimeout. A panic in the handler is re-raised on the serving
// goroutine.
func TimeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			startTime := time.Now()
			ctx, cancel := context.WithTimeout(r.Context(), timeout)
			defer cancel()

			// Seed the handler's private header map with anything outer
			// middleware already set, so the handler sees the same headers it
			// would on the unwrapped writer. The handler goroutine has not
			// started yet, so reading w.Header() here is safe.
			tw := &timeoutWriter{
				ResponseWriter: w,
				header:         w.Header().Clone(),
			}

			done := make(chan struct{})
			// Buffered so a handler that panics after the timeout can still
			// hand off the value and exit instead of blocking forever.
			panicChan := make(chan interface{}, 1)

			go func() {
				defer func() {
					if p := recover(); p != nil {
						panicChan <- p
					}
				}()
				next.ServeHTTP(tw, r.WithContext(ctx))
				close(done)
			}()

			select {
			case p := <-panicChan:
				panic(p)
			case <-done:
				// The handler has returned, so nothing else uses its header
				// map. Publish it so headers set without any Write are not
				// lost, and so are values changed after the status line
				// (e.g. trailers).
				tw.mu.Lock()
				tw.syncHeaderLocked()
				tw.mu.Unlock()
			case <-ctx.Done():
				if tw.markWritten() {
					// The handler already committed its status; a 408 can no
					// longer be sent, so leave its partial response alone.
					return
				}
				execTime := time.Since(startTime)
				// From here on the handler can reach neither w nor the real
				// header map: markWritten blocks its Write/WriteHeader, and
				// Header() only returns the private copy. No lock is needed.
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusRequestTimeout)
				response := map[string]interface{}{
					"status": "error",
					"error": map[string]interface{}{
						"msg":  "Request timeout",
						"code": http.StatusRequestTimeout,
					},
					"@meta": map[string]interface{}{
						"time_zone":      "UTC",
						"api_version":    1,
						"execution_time": fmt.Sprintf("%d ms", execTime.Milliseconds()),
					},
				}
				// Best effort: the client may already have gone away.
				_ = json.NewEncoder(w).Encode(response)
			}
		})
	}
}