// Change notes:
//   - Add thread-safe header wrapper in timeoutWriter.
//   - Check context cancellation before writing responses in handlers.
//   - Protect header access after timeout with mutex.
//   - Prevents race condition when requests time out while handlers are still running.
//
// Listing: 220 lines, 4.8 KiB, Go.
package server
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"mime"
	"net/http"
	"sync"
	"time"
)
|
|
|
|
// responseWriter wraps http.ResponseWriter to capture the response.
//
// Body bytes and the status code are buffered in memory instead of being sent
// to the wrapped writer, so the caller can inspect or rewrite the payload
// after the handler returns. Header is passed straight through to the wrapped
// writer.
type responseWriter struct {
	http.ResponseWriter

	body       *bytes.Buffer // captured response body
	statusCode int           // status recorded by WriteHeader; initial value is chosen by the creator
	written    bool          // set by the first WriteHeader or Write call
	mu         sync.Mutex    // held by Write and WriteHeader
}

// Write appends b to the captured body; nothing is sent to the wrapped writer.
// The first call also freezes the status code, so a later WriteHeader is
// ignored. It returns the result of the underlying buffer write.
func (rw *responseWriter) Write(b []byte) (int, error) {
	rw.mu.Lock()
	defer rw.mu.Unlock()

	rw.written = true

	return rw.body.Write(b)
}

// WriteHeader records statusCode without sending it. Only the first call has
// an effect; calls made after a previous WriteHeader or Write are ignored.
func (rw *responseWriter) WriteHeader(statusCode int) {
	rw.mu.Lock()
	defer rw.mu.Unlock()

	if rw.written {
		return
	}

	rw.statusCode = statusCode
	rw.written = true
}

// Header returns the header map of the wrapped writer, so headers set by the
// handler are applied to the real response directly.
func (rw *responseWriter) Header() http.Header {
	return rw.ResponseWriter.Header()
}
|
|
|
|
// JSONResponseMiddleware wraps all JSON responses with metadata
|
|
func JSONResponseMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Skip non-JSON endpoints
|
|
if r.URL.Path == "/" || r.URL.Path == "/status" {
|
|
next.ServeHTTP(w, r)
|
|
|
|
return
|
|
}
|
|
|
|
startTime := time.Now()
|
|
|
|
// Create a custom response writer to capture the response
|
|
rw := &responseWriter{
|
|
ResponseWriter: w,
|
|
body: &bytes.Buffer{},
|
|
statusCode: http.StatusOK,
|
|
}
|
|
|
|
// Serve the request
|
|
next.ServeHTTP(rw, r)
|
|
|
|
// Calculate execution time
|
|
execTime := time.Since(startTime)
|
|
|
|
// Only process JSON responses
|
|
contentType := rw.Header().Get("Content-Type")
|
|
if contentType != "application/json" || rw.body.Len() == 0 {
|
|
// Write the original response
|
|
w.WriteHeader(rw.statusCode)
|
|
_, _ = w.Write(rw.body.Bytes())
|
|
|
|
return
|
|
}
|
|
|
|
// Parse the original response
|
|
var originalResponse map[string]interface{}
|
|
if err := json.Unmarshal(rw.body.Bytes(), &originalResponse); err != nil {
|
|
// If we can't parse it, just write original
|
|
w.WriteHeader(rw.statusCode)
|
|
_, _ = w.Write(rw.body.Bytes())
|
|
|
|
return
|
|
}
|
|
|
|
// Add @meta field
|
|
originalResponse["@meta"] = map[string]interface{}{
|
|
"time_zone": "UTC",
|
|
"api_version": 1,
|
|
"execution_time": fmt.Sprintf("%d ms", execTime.Milliseconds()),
|
|
}
|
|
|
|
// Write the enhanced response
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(rw.statusCode)
|
|
_ = json.NewEncoder(w).Encode(originalResponse)
|
|
})
|
|
}
|
|
|
|
// timeoutWriter is the http.ResponseWriter handed to a handler that runs under
// TimeoutMiddleware. Every use of the wrapped writer goes through mu, and
// handler output is rejected once the request has been marked as timed out.
//
// The handler never receives the wrapped writer's header map. Header returns
// a private map, which is copied into the real header (under mu) when the
// handler commits its response with WriteHeader/Write, or when it returns
// without writing anything. This keeps the handler goroutine from mutating
// the map the middleware uses for the timeout response.
type timeoutWriter struct {
	http.ResponseWriter

	mu        sync.Mutex  // serializes access to the wrapped writer and the flags below
	written   bool        // request timed out; handler output is rejected from now on
	committed bool        // the private header has been published to the wrapped writer
	header    http.Header // handler-private header map
}

// Write forwards b to the wrapped writer, publishing the handler's headers
// first. After the timeout it writes nothing and returns
// http.ErrHandlerTimeout, so callers that loop on short writes terminate.
func (tw *timeoutWriter) Write(b []byte) (int, error) {
	tw.mu.Lock()
	defer tw.mu.Unlock()

	if tw.written {
		return 0, http.ErrHandlerTimeout
	}

	tw.commitLocked()

	return tw.ResponseWriter.Write(b)
}

// WriteHeader forwards statusCode to the wrapped writer, publishing the
// handler's headers first. After the timeout the call is discarded.
func (tw *timeoutWriter) WriteHeader(statusCode int) {
	tw.mu.Lock()
	defer tw.mu.Unlock()

	if tw.written {
		return // Discard writes after timeout
	}

	tw.commitLocked()
	tw.ResponseWriter.WriteHeader(statusCode)
}

// Header returns the handler-private header map. Changes made before the
// response is committed reach the client; later changes are not sent.
func (tw *timeoutWriter) Header() http.Header {
	tw.mu.Lock()
	defer tw.mu.Unlock()

	if tw.header == nil {
		tw.header = make(http.Header)
	}

	return tw.header
}

// markWritten flags the request as timed out so that every later Write or
// WriteHeader from the handler is rejected.
func (tw *timeoutWriter) markWritten() {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	tw.written = true
}

// responded reports whether the handler's response has already been committed
// to the wrapped writer.
func (tw *timeoutWriter) responded() bool {
	tw.mu.Lock()
	defer tw.mu.Unlock()

	return tw.committed
}

// finish publishes the handler's headers if it returned without writing, so
// the implicit response sent by the server still carries them.
func (tw *timeoutWriter) finish() {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	tw.commitLocked()
}

// commitLocked replaces the wrapped writer's header with a copy of the private
// map, once. Replacing (rather than merging) also propagates deletions made by
// the handler. The caller must hold tw.mu.
func (tw *timeoutWriter) commitLocked() {
	if tw.committed {
		return
	}
	tw.committed = true

	if tw.header == nil {
		return // handler never asked for the header; leave the real one alone
	}

	dst := tw.ResponseWriter.Header()
	for k := range dst {
		delete(dst, k)
	}
	for k, vv := range tw.header {
		dst[k] = append([]string(nil), vv...)
	}
}

// TimeoutMiddleware creates a timeout middleware that returns JSON errors.
//
// The wrapped handler runs in its own goroutine with a context that is
// cancelled after timeout. If the deadline passes before the handler has sent
// anything, a 408 JSON error with an "@meta" member is written and all later
// handler output is rejected. If the handler has already started its
// response, the response is left as it is (no second status line, no JSON
// appended) and only further handler output is cut off. A panic in the
// handler is re-raised on the calling goroutine.
//
// The handler goroutine is not stopped on timeout; handlers are expected to
// watch the request context.
func TimeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			startTime := time.Now()

			ctx, cancel := context.WithTimeout(r.Context(), timeout)
			defer cancel()

			// Start the private header from what outer layers already set so
			// the handler can still read (and override or delete) those values.
			header := w.Header().Clone()
			if header == nil {
				header = make(http.Header)
			}

			tw := &timeoutWriter{
				ResponseWriter: w,
				header:         header,
			}

			done := make(chan struct{})
			// Buffered so a handler that panics after the timeout can still exit.
			panicChan := make(chan interface{}, 1)

			go func() {
				defer func() {
					if p := recover(); p != nil {
						panicChan <- p
					}
				}()

				next.ServeHTTP(tw, r.WithContext(ctx))
				close(done)
			}()

			select {
			case p := <-panicChan:
				panic(p)
			case <-done:
				tw.finish()

				return
			case <-ctx.Done():
				// select picks at random among ready cases; if the handler
				// finished (or panicked) right at the deadline, honor that.
				select {
				case p := <-panicChan:
					panic(p)
				case <-done:
					tw.finish()

					return
				default:
				}

				tw.markWritten() // Prevent the handler from writing after timeout

				// Once marked, committed can no longer change. If the handler
				// already sent a status or body, a timeout response would only
				// corrupt it.
				if tw.responded() {
					return
				}

				execTime := time.Since(startTime)

				// Writing to w directly is safe here: the handler can no
				// longer reach it, and none of its headers were published.
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusRequestTimeout)

				response := map[string]interface{}{
					"status": "error",
					"error": map[string]interface{}{
						"msg":  "Request timeout",
						"code": http.StatusRequestTimeout,
					},
					"@meta": map[string]interface{}{
						"time_zone":      "UTC",
						"api_version":    1,
						"execution_time": fmt.Sprintf("%d ms", execTime.Milliseconds()),
					},
				}

				// The status is already sent; an encode/write error here can
				// only mean the client is gone.
				_ = json.NewEncoder(w).Encode(response)
			}
		})
	}
}
|