Add context cancellation support to database operations

- Add context-aware versions of all read operations in the database
- Update handlers to use context from HTTP requests
- Allows database queries to be cancelled when HTTP requests timeout
- Prevents database operations from continuing after client disconnects
This commit is contained in:
2025-07-28 19:27:55 +02:00
parent 0196251906
commit e0a4c8642e
7 changed files with 1347 additions and 49 deletions

View File

@@ -92,7 +92,7 @@ func (s *Server) handleStatusJSON() http.HandlerFunc {
errChan := make(chan error)
go func() {
dbStats, err := s.db.GetStats()
dbStats, err := s.db.GetStatsContext(ctx)
if err != nil {
s.logger.Debug("Database stats query failed", "error", err)
errChan <- err
@@ -126,7 +126,7 @@ func (s *Server) handleStatusJSON() http.HandlerFunc {
const bitsPerMegabit = 1000000.0
// Get route counts from database
ipv4Routes, ipv6Routes, err := s.db.GetLiveRouteCounts()
ipv4Routes, ipv6Routes, err := s.db.GetLiveRouteCountsContext(ctx)
if err != nil {
s.logger.Warn("Failed to get live route counts", "error", err)
// Continue with zero counts
@@ -234,7 +234,7 @@ func (s *Server) handleStats() http.HandlerFunc {
errChan := make(chan error)
go func() {
dbStats, err := s.db.GetStats()
dbStats, err := s.db.GetStatsContext(ctx)
if err != nil {
s.logger.Debug("Database stats query failed", "error", err)
errChan <- err
@@ -267,7 +267,7 @@ func (s *Server) handleStats() http.HandlerFunc {
const bitsPerMegabit = 1000000.0
// Get route counts from database
ipv4Routes, ipv6Routes, err := s.db.GetLiveRouteCounts()
ipv4Routes, ipv6Routes, err := s.db.GetLiveRouteCountsContext(ctx)
if err != nil {
s.logger.Warn("Failed to get live route counts", "error", err)
// Continue with zero counts
@@ -354,7 +354,7 @@ func (s *Server) handleIPLookup() http.HandlerFunc {
}
// Look up AS information for the IP
asInfo, err := s.db.GetASInfoForIP(ip)
asInfo, err := s.db.GetASInfoForIPContext(r.Context(), ip)
if err != nil {
// Check if it's an invalid IP error
if errors.Is(err, database.ErrInvalidIP) {
@@ -385,7 +385,7 @@ func (s *Server) handleASDetailJSON() http.HandlerFunc {
return
}
asInfo, prefixes, err := s.db.GetASDetails(asn)
asInfo, prefixes, err := s.db.GetASDetailsContext(r.Context(), asn)
if err != nil {
if errors.Is(err, database.ErrNoRoute) {
writeJSONError(w, http.StatusNotFound, err.Error())
@@ -438,7 +438,7 @@ func (s *Server) handlePrefixDetailJSON() http.HandlerFunc {
return
}
routes, err := s.db.GetPrefixDetails(prefix)
routes, err := s.db.GetPrefixDetailsContext(r.Context(), prefix)
if err != nil {
if errors.Is(err, database.ErrNoRoute) {
writeJSONError(w, http.StatusNotFound, err.Error())
@@ -480,7 +480,7 @@ func (s *Server) handleASDetail() http.HandlerFunc {
return
}
asInfo, prefixes, err := s.db.GetASDetails(asn)
asInfo, prefixes, err := s.db.GetASDetailsContext(r.Context(), asn)
if err != nil {
if errors.Is(err, database.ErrNoRoute) {
http.Error(w, "AS not found", http.StatusNotFound)
@@ -584,7 +584,7 @@ func (s *Server) handlePrefixDetail() http.HandlerFunc {
return
}
routes, err := s.db.GetPrefixDetails(prefix)
routes, err := s.db.GetPrefixDetailsContext(r.Context(), prefix)
if err != nil {
if errors.Is(err, database.ErrNoRoute) {
http.Error(w, "Prefix not found", http.StatusNotFound)
@@ -607,7 +607,7 @@ func (s *Server) handlePrefixDetail() http.HandlerFunc {
for _, route := range routes {
if _, exists := originMap[route.OriginASN]; !exists {
// Get AS info from database
asInfo, _, _ := s.db.GetASDetails(route.OriginASN)
asInfo, _, _ := s.db.GetASDetailsContext(r.Context(), route.OriginASN)
handle := ""
description := ""
if asInfo != nil {
@@ -768,7 +768,7 @@ func (s *Server) handlePrefixLength() http.HandlerFunc {
// Get random sample of prefixes
const maxPrefixes = 500
prefixes, err := s.db.GetRandomPrefixesByLength(maskLength, ipVersion, maxPrefixes)
prefixes, err := s.db.GetRandomPrefixesByLengthContext(r.Context(), maskLength, ipVersion, maxPrefixes)
if err != nil {
s.logger.Error("Failed to get prefixes by length", "error", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)