package server

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"mime"
	"net/http"
	"sync"
	"time"
)

// responseWriter wraps an http.ResponseWriter and captures the status code
// and body produced by the next handler so the middleware can post-process
// them before anything is sent. Header is promoted from the embedded writer,
// so headers the handler sets go straight onto the real response.
type responseWriter struct {
	http.ResponseWriter
	body       *bytes.Buffer
	statusCode int
	written    bool // true once WriteHeader or Write has committed the status
	mu         sync.Mutex
}

// Write appends b to the captured body. As in net/http, a Write before any
// WriteHeader commits the current (default 200) status code.
func (rw *responseWriter) Write(b []byte) (int, error) {
	rw.mu.Lock()
	defer rw.mu.Unlock()
	rw.written = true
	return rw.body.Write(b)
}

// WriteHeader records statusCode. Only the first call made before any Write
// takes effect; later calls are ignored, mirroring net/http.
func (rw *responseWriter) WriteHeader(statusCode int) {
	rw.mu.Lock()
	defer rw.mu.Unlock()
	if !rw.written {
		rw.statusCode = statusCode
		rw.written = true
	}
}

// responseMeta builds the "@meta" object attached to JSON responses by both
// JSONResponseMiddleware and TimeoutMiddleware.
func responseMeta(execTime time.Duration) map[string]interface{} {
	return map[string]interface{}{
		"time_zone":      "UTC",
		"api_version":    1,
		"execution_time": fmt.Sprintf("%d ms", execTime.Milliseconds()),
	}
}

// isJSONMediaType reports whether contentType names application/json,
// ignoring case and parameters such as "; charset=utf-8".
func isJSONMediaType(contentType string) bool {
	mediaType, _, err := mime.ParseMediaType(contentType)
	return err == nil && mediaType == "application/json"
}

// injectMeta decodes body as a JSON object, adds an "@meta" member and
// re-encodes it. Members are held as json.RawMessage so numbers (e.g. 64-bit
// IDs) and other values are reproduced exactly instead of round-tripping
// through float64. ok is false when body is not a JSON object (array, scalar,
// null or invalid JSON) or encoding fails; the caller then sends body as is.
func injectMeta(body []byte, execTime time.Duration) (out []byte, ok bool) {
	var fields map[string]json.RawMessage
	// A literal `null` decodes without error into a nil map, so check for it
	// explicitly before assigning into the map.
	if err := json.Unmarshal(body, &fields); err != nil || fields == nil {
		return nil, false
	}
	meta, err := json.Marshal(responseMeta(execTime))
	if err != nil {
		return nil, false
	}
	fields["@meta"] = meta

	var buf bytes.Buffer
	// json.Encoder (rather than json.Marshal) keeps the trailing newline this
	// middleware has always emitted.
	if err := json.NewEncoder(&buf).Encode(fields); err != nil {
		return nil, false
	}
	return buf.Bytes(), true
}

// JSONResponseMiddleware buffers each response from next. If the response
// is a non-empty JSON object (Content-Type application/json, with or without
// parameters), it re-emits it with an added "@meta" member carrying the time
// zone, API version and handler execution time. Every other response is
// passed through unchanged. The "/" and "/status" paths bypass the middleware
// entirely. Because the whole body is buffered, incremental (streaming)
// delivery is not possible for wrapped routes.
func JSONResponseMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/" || r.URL.Path == "/status" {
			next.ServeHTTP(w, r)
			return
		}

		startTime := time.Now()
		rw := &responseWriter{
			ResponseWriter: w,
			body:           &bytes.Buffer{},
			statusCode:     http.StatusOK,
		}
		next.ServeHTTP(rw, r)
		execTime := time.Since(startTime)

		body := rw.body.Bytes()
		if len(body) > 0 && isJSONMediaType(w.Header().Get("Content-Type")) {
			if enhanced, ok := injectMeta(body, execTime); ok {
				// A Content-Length set by the handler describes the original
				// body, not the enhanced one; let net/http recompute it.
				w.Header().Del("Content-Length")
				body = enhanced
			}
		}

		w.WriteHeader(rw.statusCode)
		// Best effort: once the header is sent there is no way to report a
		// failed write to the client.
		_, _ = w.Write(body)
	})
}

// timeoutWriter is the http.ResponseWriter handed to handlers running under
// TimeoutMiddleware. It buffers the handler's headers, status code and body,
// so the handler goroutine never touches the real writer. The middleware
// goroutine later either copies the buffer to the client (handler finished in
// time) or discards it and sends the timeout error instead.
type timeoutWriter struct {
	http.ResponseWriter // real writer; used only by the middleware goroutine

	mu          sync.Mutex
	written     bool        // set by markWritten on timeout; later handler output is dropped
	header      http.Header // handler-private header map, never shared with the real writer
	buf         bytes.Buffer
	code        int
	wroteHeader bool
}

// Write buffers b. After a timeout it returns http.ErrHandlerTimeout, so
// handlers that check write errors can stop early (and n < len(b) is never
// paired with a nil error).
func (tw *timeoutWriter) Write(b []byte) (int, error) {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	if tw.written {
		return 0, http.ErrHandlerTimeout
	}
	if !tw.wroteHeader {
		tw.code = http.StatusOK
		tw.wroteHeader = true
	}
	return tw.buf.Write(b)
}

// WriteHeader records the status code to send if the handler finishes in
// time. Only the first call takes effect. Informational (1xx) codes cannot be
// deferred, so they are dropped and the final status is awaited.
func (tw *timeoutWriter) WriteHeader(statusCode int) {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	if tw.written || tw.wroteHeader || (statusCode >= 100 && statusCode < 200) {
		return
	}
	tw.code = statusCode
	tw.wroteHeader = true
}

// Header returns the handler's private header map. It is copied onto the
// real response only if the handler completes before the deadline.
func (tw *timeoutWriter) Header() http.Header {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	if tw.header == nil {
		tw.header = make(http.Header)
	}
	return tw.header
}

// markWritten flags the request as timed out; subsequent handler writes are
// discarded.
func (tw *timeoutWriter) markWritten() {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	tw.written = true
}

// flushTo copies the buffered headers, status and body to dst. Call it only
// from the middleware goroutine, after the handler has returned in time.
func (tw *timeoutWriter) flushTo(dst http.ResponseWriter) {
	tw.mu.Lock()
	defer tw.mu.Unlock()

	// tw.header started as a clone of dst's headers. Make dst match it
	// exactly, so deletions made by the handler are honoured too.
	h := dst.Header()
	for k := range h {
		if _, ok := tw.header[k]; !ok {
			delete(h, k)
		}
	}
	for k, vv := range tw.header {
		h[k] = vv
	}

	code := tw.code
	if !tw.wroteHeader {
		code = http.StatusOK
	}
	dst.WriteHeader(code)
	_, _ = dst.Write(tw.buf.Bytes()) // best effort, as for any handler write
}

// TimeoutMiddleware returns middleware that runs the next handler with a
// context that expires after timeout. If the handler returns first, its
// buffered response is sent unchanged. Otherwise everything it produced is
// discarded and a JSON error with status 408 and an "@meta" member is sent.
// A handler panic is re-raised on the serving goroutine.
//
// The derived context is also canceled when the parent request context is
// (e.g. the client disconnected). That case takes the same timeout branch.
//
// NOTE(review): 408 is defined for clients that are slow to *send* a request,
// and some clients retry it automatically. 503/504 is conventional for
// server-side timeouts. 408 is kept here for API compatibility.
func TimeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			startTime := time.Now()
			ctx, cancel := context.WithTimeout(r.Context(), timeout)
			defer cancel()

			tw := &timeoutWriter{
				ResponseWriter: w,
				// Start from a copy of the real headers, so the handler sees
				// what upstream middleware set without sharing the map across
				// goroutines.
				header: w.Header().Clone(),
			}
			done := make(chan struct{})
			panicChan := make(chan interface{}, 1) // buffered: a late panic must not block

			go func() {
				defer func() {
					if p := recover(); p != nil {
						panicChan <- p
					}
				}()
				next.ServeHTTP(tw, r.WithContext(ctx))
				close(done)
			}()

			select {
			case p := <-panicChan:
				panic(p)
			case <-done:
				tw.flushTo(w)
			case <-ctx.Done():
				// If the handler finished at the same instant the deadline
				// fired, prefer its real response.
				select {
				case <-done:
					tw.flushTo(w)
					return
				default:
				}

				tw.markWritten() // from here on the handler's output is dropped
				execTime := time.Since(startTime)

				// Only this goroutine touches w: the handler writes into tw's
				// private buffers, so no locking is needed here.
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusRequestTimeout)

				response := map[string]interface{}{
					"status": "error",
					"error": map[string]interface{}{
						"msg":  "Request timeout",
						"code": http.StatusRequestTimeout,
					},
					"@meta": responseMeta(execTime),
				}
				// Best effort: the status is already sent.
				_ = json.NewEncoder(w).Encode(response)
			}
		})
	}
}