Add context cancellation support to database operations
- Add context-aware versions of all read operations in the database - Update handlers to use context from HTTP requests - Allows database queries to be cancelled when HTTP requests timeout - Prevents database operations from continuing after client disconnects
This commit is contained in:
@@ -92,7 +92,7 @@ func (s *Server) handleStatusJSON() http.HandlerFunc {
|
||||
errChan := make(chan error)
|
||||
|
||||
go func() {
|
||||
dbStats, err := s.db.GetStats()
|
||||
dbStats, err := s.db.GetStatsContext(ctx)
|
||||
if err != nil {
|
||||
s.logger.Debug("Database stats query failed", "error", err)
|
||||
errChan <- err
|
||||
@@ -126,7 +126,7 @@ func (s *Server) handleStatusJSON() http.HandlerFunc {
|
||||
const bitsPerMegabit = 1000000.0
|
||||
|
||||
// Get route counts from database
|
||||
ipv4Routes, ipv6Routes, err := s.db.GetLiveRouteCounts()
|
||||
ipv4Routes, ipv6Routes, err := s.db.GetLiveRouteCountsContext(ctx)
|
||||
if err != nil {
|
||||
s.logger.Warn("Failed to get live route counts", "error", err)
|
||||
// Continue with zero counts
|
||||
@@ -234,7 +234,7 @@ func (s *Server) handleStats() http.HandlerFunc {
|
||||
errChan := make(chan error)
|
||||
|
||||
go func() {
|
||||
dbStats, err := s.db.GetStats()
|
||||
dbStats, err := s.db.GetStatsContext(ctx)
|
||||
if err != nil {
|
||||
s.logger.Debug("Database stats query failed", "error", err)
|
||||
errChan <- err
|
||||
@@ -267,7 +267,7 @@ func (s *Server) handleStats() http.HandlerFunc {
|
||||
const bitsPerMegabit = 1000000.0
|
||||
|
||||
// Get route counts from database
|
||||
ipv4Routes, ipv6Routes, err := s.db.GetLiveRouteCounts()
|
||||
ipv4Routes, ipv6Routes, err := s.db.GetLiveRouteCountsContext(ctx)
|
||||
if err != nil {
|
||||
s.logger.Warn("Failed to get live route counts", "error", err)
|
||||
// Continue with zero counts
|
||||
@@ -354,7 +354,7 @@ func (s *Server) handleIPLookup() http.HandlerFunc {
|
||||
}
|
||||
|
||||
// Look up AS information for the IP
|
||||
asInfo, err := s.db.GetASInfoForIP(ip)
|
||||
asInfo, err := s.db.GetASInfoForIPContext(r.Context(), ip)
|
||||
if err != nil {
|
||||
// Check if it's an invalid IP error
|
||||
if errors.Is(err, database.ErrInvalidIP) {
|
||||
@@ -385,7 +385,7 @@ func (s *Server) handleASDetailJSON() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
asInfo, prefixes, err := s.db.GetASDetails(asn)
|
||||
asInfo, prefixes, err := s.db.GetASDetailsContext(r.Context(), asn)
|
||||
if err != nil {
|
||||
if errors.Is(err, database.ErrNoRoute) {
|
||||
writeJSONError(w, http.StatusNotFound, err.Error())
|
||||
@@ -438,7 +438,7 @@ func (s *Server) handlePrefixDetailJSON() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
routes, err := s.db.GetPrefixDetails(prefix)
|
||||
routes, err := s.db.GetPrefixDetailsContext(r.Context(), prefix)
|
||||
if err != nil {
|
||||
if errors.Is(err, database.ErrNoRoute) {
|
||||
writeJSONError(w, http.StatusNotFound, err.Error())
|
||||
@@ -480,7 +480,7 @@ func (s *Server) handleASDetail() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
asInfo, prefixes, err := s.db.GetASDetails(asn)
|
||||
asInfo, prefixes, err := s.db.GetASDetailsContext(r.Context(), asn)
|
||||
if err != nil {
|
||||
if errors.Is(err, database.ErrNoRoute) {
|
||||
http.Error(w, "AS not found", http.StatusNotFound)
|
||||
@@ -584,7 +584,7 @@ func (s *Server) handlePrefixDetail() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
routes, err := s.db.GetPrefixDetails(prefix)
|
||||
routes, err := s.db.GetPrefixDetailsContext(r.Context(), prefix)
|
||||
if err != nil {
|
||||
if errors.Is(err, database.ErrNoRoute) {
|
||||
http.Error(w, "Prefix not found", http.StatusNotFound)
|
||||
@@ -607,7 +607,7 @@ func (s *Server) handlePrefixDetail() http.HandlerFunc {
|
||||
for _, route := range routes {
|
||||
if _, exists := originMap[route.OriginASN]; !exists {
|
||||
// Get AS info from database
|
||||
asInfo, _, _ := s.db.GetASDetails(route.OriginASN)
|
||||
asInfo, _, _ := s.db.GetASDetailsContext(r.Context(), route.OriginASN)
|
||||
handle := ""
|
||||
description := ""
|
||||
if asInfo != nil {
|
||||
@@ -768,7 +768,7 @@ func (s *Server) handlePrefixLength() http.HandlerFunc {
|
||||
|
||||
// Get random sample of prefixes
|
||||
const maxPrefixes = 500
|
||||
prefixes, err := s.db.GetRandomPrefixesByLength(maskLength, ipVersion, maxPrefixes)
|
||||
prefixes, err := s.db.GetRandomPrefixesByLengthContext(r.Context(), maskLength, ipVersion, maxPrefixes)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get prefixes by length", "error", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
|
||||
Reference in New Issue
Block a user