Add JSON validation middleware to ensure valid API responses
- Created JSONValidationMiddleware that validates all JSON responses - Ensures that even on timeout or internal errors, a valid JSON error response is returned - Applied to all API endpoints including /status.json - Prevents client-side JSON parse errors when server encounters issues
This commit is contained in:
parent
7aec01c499
commit
3338e92785
@ -1482,6 +1482,7 @@ func (d *Database) GetASPeersContext(ctx context.Context, asn int) ([]ASPeer, er
|
|||||||
err := rows.Scan(&peer.ASN, &peer.Handle, &peer.Description, &peer.FirstSeen, &peer.LastSeen)
|
err := rows.Scan(&peer.ASN, &peer.Handle, &peer.Description, &peer.FirstSeen, &peer.LastSeen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
d.logger.Error("Failed to scan peer row", "error", err, "asn", asn)
|
d.logger.Error("Failed to scan peer row", "error", err, "asn", asn)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
peers = append(peers, peer)
|
peers = append(peers, peer)
|
||||||
|
|||||||
@ -217,3 +217,68 @@ func TimeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JSONValidationMiddleware ensures all JSON API responses are valid JSON
|
||||||
|
func JSONValidationMiddleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Create a custom response writer to capture the response
|
||||||
|
rw := &responseWriter{
|
||||||
|
ResponseWriter: w,
|
||||||
|
body: &bytes.Buffer{},
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serve the request
|
||||||
|
next.ServeHTTP(rw, r)
|
||||||
|
|
||||||
|
// Check if it's meant to be a JSON response
|
||||||
|
contentType := rw.Header().Get("Content-Type")
|
||||||
|
isJSON := contentType == "application/json" || contentType == ""
|
||||||
|
|
||||||
|
// If it's not JSON or has content, pass through
|
||||||
|
if !isJSON && rw.body.Len() > 0 {
|
||||||
|
w.WriteHeader(rw.statusCode)
|
||||||
|
_, _ = w.Write(rw.body.Bytes())
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// For JSON responses, validate the JSON
|
||||||
|
if rw.body.Len() > 0 {
|
||||||
|
var testParse interface{}
|
||||||
|
if err := json.Unmarshal(rw.body.Bytes(), &testParse); err == nil {
|
||||||
|
// Valid JSON, write it out
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(rw.statusCode)
|
||||||
|
_, _ = w.Write(rw.body.Bytes())
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we get here, either there's no body or invalid JSON
|
||||||
|
// Write a proper error response
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
// Determine appropriate status code
|
||||||
|
statusCode := rw.statusCode
|
||||||
|
if statusCode == http.StatusOK {
|
||||||
|
statusCode = http.StatusInternalServerError
|
||||||
|
}
|
||||||
|
w.WriteHeader(statusCode)
|
||||||
|
|
||||||
|
errorMsg := "Internal server error"
|
||||||
|
if statusCode == http.StatusRequestTimeout {
|
||||||
|
errorMsg = "Request timeout"
|
||||||
|
}
|
||||||
|
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"status": "error",
|
||||||
|
"error": map[string]interface{}{
|
||||||
|
"msg": errorMsg,
|
||||||
|
"code": statusCode,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(response)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@ -23,7 +23,7 @@ func (s *Server) setupRoutes() {
|
|||||||
// Routes
|
// Routes
|
||||||
r.Get("/", s.handleRoot())
|
r.Get("/", s.handleRoot())
|
||||||
r.Get("/status", s.handleStatusHTML())
|
r.Get("/status", s.handleStatusHTML())
|
||||||
r.Get("/status.json", s.handleStatusJSON())
|
r.Get("/status.json", JSONValidationMiddleware(s.handleStatusJSON()).ServeHTTP)
|
||||||
|
|
||||||
// AS and prefix detail pages
|
// AS and prefix detail pages
|
||||||
r.Get("/as/{asn}", s.handleASDetail())
|
r.Get("/as/{asn}", s.handleASDetail())
|
||||||
@ -33,6 +33,7 @@ func (s *Server) setupRoutes() {
|
|||||||
|
|
||||||
// API routes
|
// API routes
|
||||||
r.Route("/api/v1", func(r chi.Router) {
|
r.Route("/api/v1", func(r chi.Router) {
|
||||||
|
r.Use(JSONValidationMiddleware)
|
||||||
r.Get("/stats", s.handleStats())
|
r.Get("/stats", s.handleStats())
|
||||||
r.Get("/ip/{ip}", s.handleIPLookup())
|
r.Get("/ip/{ip}", s.handleIPLookup())
|
||||||
r.Get("/as/{asn}", s.handleASDetailJSON())
|
r.Get("/as/{asn}", s.handleASDetailJSON())
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user