routewatch/internal/server/middleware.go
sneak e1d0ab5ea6 Add detailed godoc documentation to CLIEntry function
Expand the documentation comment for CLIEntry to provide more context
about what the function does, including its use of the fx dependency
injection framework, signal handling, and blocking behavior.
2025-12-27 12:24:22 +07:00

304 lines
7.8 KiB
Go

package server
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"mime"
	"net/http"
	"sync"
	"time"
)
// responseWriter is an http.ResponseWriter that buffers the response body in
// memory instead of sending it, so a middleware can inspect or rewrite the
// output before forwarding it to the embedded writer.
//
// Only the first status is kept: after WriteHeader or Write has been called,
// further WriteHeader calls are ignored. Headers are not buffered; Header
// returns the embedded writer's own header map.
type responseWriter struct {
	http.ResponseWriter
	body       *bytes.Buffer
	statusCode int
	written    bool
	mu         sync.Mutex
}

// Write appends p to the in-memory body and marks the status as committed,
// so a later WriteHeader has no effect (mirroring net/http's implicit 200).
func (rw *responseWriter) Write(p []byte) (int, error) {
	rw.mu.Lock()
	defer rw.mu.Unlock()
	rw.written = true
	return rw.body.Write(p)
}

// WriteHeader records statusCode unless a status was already committed by an
// earlier WriteHeader or Write call.
func (rw *responseWriter) WriteHeader(statusCode int) {
	rw.mu.Lock()
	defer rw.mu.Unlock()
	if rw.written {
		return
	}
	rw.statusCode, rw.written = statusCode, true
}

// Header exposes the embedded writer's header map directly.
func (rw *responseWriter) Header() http.Header {
	return rw.ResponseWriter.Header()
}
// JSONResponseMiddleware is an HTTP middleware that adds a "@meta" field to
// JSON object responses. The field carries execution metadata: the time zone
// (always "UTC"), the API version (1), and the request execution time
// formatted as "<n> ms".
//
// Requests for "/" and "/status" are passed straight through without
// buffering. For every other request the handler's output is buffered in
// memory. It is rewritten only when two conditions hold:
//   - the Content-Type media type is application/json (parameters such as
//     "; charset=utf-8" are accepted);
//   - the body is a JSON object.
//
// Everything else is forwarded unchanged with the handler's original status
// code. That covers non-JSON content, empty bodies, arrays, scalars, null and
// malformed JSON.
//
// Values of the original object are copied through as raw JSON, so large
// integers are not rounded through float64. Top-level keys are re-emitted in
// sorted order.
func JSONResponseMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Skip non-JSON endpoints
		if r.URL.Path == "/" || r.URL.Path == "/status" {
			next.ServeHTTP(w, r)
			return
		}

		startTime := time.Now()

		// Capture the handler's status and body; headers go straight to w.
		rw := &responseWriter{
			ResponseWriter: w,
			body:           &bytes.Buffer{},
			statusCode:     http.StatusOK,
		}
		next.ServeHTTP(rw, r)
		execTime := time.Since(startTime)

		// writeOriginal forwards the captured response untouched.
		writeOriginal := func() {
			w.WriteHeader(rw.statusCode)
			_, _ = w.Write(rw.body.Bytes())
		}

		mediaType, _, err := mime.ParseMediaType(rw.Header().Get("Content-Type"))
		if err != nil || mediaType != "application/json" || rw.body.Len() == 0 {
			writeOriginal()
			return
		}

		// Decode only the top level and keep each value as its raw JSON text.
		// A literal `null` decodes without error into a nil map, which cannot
		// be assigned to. It is therefore forwarded unchanged, like any other
		// non-object body.
		var originalResponse map[string]json.RawMessage
		if err := json.Unmarshal(rw.body.Bytes(), &originalResponse); err != nil || originalResponse == nil {
			writeOriginal()
			return
		}

		meta, err := json.Marshal(map[string]interface{}{
			"time_zone":      "UTC",
			"api_version":    1,
			"execution_time": fmt.Sprintf("%d ms", execTime.Milliseconds()),
		})
		if err != nil {
			writeOriginal()
			return
		}
		originalResponse["@meta"] = meta

		// The body changes length, so any Content-Length the handler set is
		// now wrong; drop it and let net/http compute the framing.
		w.Header().Del("Content-Length")
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(rw.statusCode)
		// Headers are already sent, so an encode/write error cannot be
		// reported to the client.
		_ = json.NewEncoder(w).Encode(originalResponse)
	})
}
// timeoutWriter is the http.ResponseWriter that TimeoutMiddleware passes to the
// wrapped handler, which runs in a separate goroutine. It keeps that goroutine
// from racing with the middleware once the deadline has passed:
//   - Headers are staged in header, a private copy of the underlying header
//     map. They are copied to the underlying writer when the handler commits
//     a status (WriteHeader, or its first Write). The middleware also copies
//     them after the handler has returned. The handler goroutine touches the
//     underlying writer only while committing, under mu, and only before the
//     timeout. It therefore never races with the middleware's 408 response.
//   - Once markWritten has been called, WriteHeader and Write from the
//     handler are discarded.
//
// written and wroteHeader are guarded by mu. The header map belongs to the
// handler goroutine until it commits or returns.
type timeoutWriter struct {
	http.ResponseWriter
	mu          sync.Mutex
	written     bool        // set by markWritten when the timeout fires
	wroteHeader bool        // a final (>= 200) status has been sent to the underlying writer
	header      http.Header // handler-visible headers, staged until commit
}

// Write sends b to the client. If the handler has not called WriteHeader, it
// first commits the staged headers with an implicit 200 status. After the
// timeout it discards b and returns http.ErrHandlerTimeout, so callers see a
// short write together with an error, as io.Writer requires.
func (tw *timeoutWriter) Write(b []byte) (int, error) {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	if tw.written {
		return 0, http.ErrHandlerTimeout // Discard writes after timeout
	}
	tw.writeHeaderLocked(http.StatusOK)
	return tw.ResponseWriter.Write(b)
}

// WriteHeader commits the staged headers and statusCode to the underlying
// writer. It is a no-op after the timeout or once a final status was sent.
func (tw *timeoutWriter) WriteHeader(statusCode int) {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	if tw.written {
		return // Discard writes after timeout
	}
	tw.writeHeaderLocked(statusCode)
}

// writeHeaderLocked publishes the staged headers and sends statusCode, unless
// a final status was already sent. Informational (1xx) statuses are
// forwarded without being recorded as final, so a final status may follow.
// The caller must hold tw.mu.
func (tw *timeoutWriter) writeHeaderLocked(statusCode int) {
	if tw.wroteHeader {
		return
	}
	tw.syncHeaderLocked()
	if statusCode >= 200 {
		tw.wroteHeader = true
	}
	tw.ResponseWriter.WriteHeader(statusCode)
}

// syncHeaderLocked makes the underlying header map mirror the staged one.
// Keys missing from the staged map are deleted, and every staged key is
// copied with its own value slice. The caller must hold tw.mu.
func (tw *timeoutWriter) syncHeaderLocked() {
	if tw.header == nil {
		return
	}
	dst := tw.ResponseWriter.Header()
	for k := range dst {
		if _, ok := tw.header[k]; !ok {
			delete(dst, k)
		}
	}
	for k, v := range tw.header {
		dst[k] = append([]string(nil), v...)
	}
}

// Header returns the handler's staged header map. Before a commit, changes to
// it do not reach the underlying writer; after the timeout, they never do.
func (tw *timeoutWriter) Header() http.Header {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	if tw.header == nil {
		tw.header = make(http.Header)
	}
	return tw.header
}

// markWritten records that the timeout fired; all later handler output is
// discarded.
func (tw *timeoutWriter) markWritten() {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	tw.written = true
}

// headerCommitted reports whether the handler has already sent a final status
// to the underlying writer.
func (tw *timeoutWriter) headerCommitted() bool {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	return tw.wroteHeader
}

// TimeoutMiddleware creates an HTTP middleware that enforces a request timeout.
// The wrapped handler runs in its own goroutine. Its request context is
// derived from the incoming one with the given timeout.
//
// If that context ends before the handler returns, two cases apply. When the
// handler has not yet sent its status line, the middleware responds with
// HTTP 408 (Request Timeout) and a JSON error body that includes an "@meta"
// block. When the handler already started its response, that partial
// response is left as is, because a second status line cannot be sent. The
// same path is taken if the incoming request's context is cancelled first.
//
// After that point, output from the handler is discarded: its writes fail
// with http.ErrHandlerTimeout. The handler goroutine itself is not stopped; it
// should observe the cancelled context. A panic in the handler is re-raised on
// the serving goroutine.
func TimeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			startTime := time.Now()
			ctx, cancel := context.WithTimeout(r.Context(), timeout)
			defer cancel()

			// Stage a copy of the current headers (e.g. set by outer
			// middleware) so the handler sees them but only edits its copy.
			tw := &timeoutWriter{
				ResponseWriter: w,
				header:         w.Header().Clone(),
			}

			done := make(chan struct{})
			// Buffered so the handler goroutine can report a panic and exit
			// even if the middleware has already returned on timeout.
			panicChan := make(chan interface{}, 1)
			go func() {
				defer func() {
					if p := recover(); p != nil {
						panicChan <- p
					}
				}()
				next.ServeHTTP(tw, r.WithContext(ctx))
				close(done)
			}()

			select {
			case p := <-panicChan:
				panic(p)
			case <-done:
				// The handler has returned, so it no longer mutates its header
				// map. Publishing it here delivers headers the handler set
				// without writing anything, and trailer values set after the
				// body.
				tw.mu.Lock()
				tw.syncHeaderLocked()
				tw.mu.Unlock()
				return
			case <-ctx.Done():
				// From here on the handler's writes are discarded.
				tw.markWritten()
				if tw.headerCommitted() {
					// The handler already sent its status line. A 408 now
					// would be a superfluous WriteHeader, and its JSON would
					// be appended to the body already sent.
					return
				}
				execTime := time.Since(startTime)

				// Any earlier handler access to w happened under tw.mu before
				// markWritten, so writing to w directly here does not race.
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusRequestTimeout)
				response := map[string]interface{}{
					"status": "error",
					"error": map[string]interface{}{
						"msg":  "Request timeout",
						"code": http.StatusRequestTimeout,
					},
					"@meta": map[string]interface{}{
						"time_zone":      "UTC",
						"api_version":    1,
						"execution_time": fmt.Sprintf("%d ms", execTime.Milliseconds()),
					},
				}
				// Headers are already sent; an encode/write error cannot be
				// reported to the client.
				_ = json.NewEncoder(w).Encode(response)
			}
		})
	}
}
// JSONValidationMiddleware is an HTTP middleware that validates JSON API
// responses. The handler's output is buffered in memory and inspected before
// anything is sent to the client.
//
// A response is treated as JSON when its Content-Type is empty or its media
// type is application/json (parameters such as "; charset=utf-8" are
// accepted). Responses with any other Content-Type are passed through
// unchanged, whether or not they have a body.
//
// For JSON responses:
//   - A non-empty body that is valid JSON is forwarded unchanged. An empty
//     Content-Type is set to "application/json".
//   - An empty body is forwarded unchanged for HEAD requests and for 204 and
//     304 statuses, which carry no body.
//   - Otherwise (an empty or invalid body) a JSON error envelope is written
//     instead. Status 200 becomes 500 and other statuses are kept. The
//     message is "Request timeout" for 408 and "Internal server error" for
//     everything else.
func JSONValidationMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Capture the handler's status and body; headers go straight to w.
		rw := &responseWriter{
			ResponseWriter: w,
			body:           &bytes.Buffer{},
			statusCode:     http.StatusOK,
		}
		next.ServeHTTP(rw, r)

		// passThrough forwards the captured response untouched.
		passThrough := func() {
			w.WriteHeader(rw.statusCode)
			_, _ = w.Write(rw.body.Bytes())
		}

		contentType := rw.Header().Get("Content-Type")
		// An unset Content-Type is assumed to be meant as JSON.
		isJSON := contentType == ""
		if !isJSON {
			mediaType, _, err := mime.ParseMediaType(contentType)
			isJSON = err == nil && mediaType == "application/json"
		}
		if !isJSON {
			passThrough()
			return
		}

		body := rw.body.Bytes()
		if len(body) == 0 {
			if r.Method == http.MethodHead ||
				rw.statusCode == http.StatusNoContent ||
				rw.statusCode == http.StatusNotModified {
				passThrough()
				return
			}
		} else if json.Valid(body) {
			if contentType == "" {
				w.Header().Set("Content-Type", "application/json")
			}
			passThrough()
			return
		}

		// Either no body where one was expected, or a body that is not valid
		// JSON: replace it with an error envelope.
		statusCode := rw.statusCode
		if statusCode == http.StatusOK {
			statusCode = http.StatusInternalServerError
		}
		errorMsg := "Internal server error"
		if statusCode == http.StatusRequestTimeout {
			errorMsg = "Request timeout"
		}

		h := w.Header()
		h.Set("Content-Type", "application/json")
		// Any Content-Length from the handler described the discarded body.
		h.Del("Content-Length")
		w.WriteHeader(statusCode)

		response := map[string]interface{}{
			"status": "error",
			"error": map[string]interface{}{
				"msg":  errorMsg,
				"code": statusCode,
			},
		}
		// Headers are already sent; an encode/write error cannot be reported.
		_ = json.NewEncoder(w).Encode(response)
	})
}