diff --git a/internal/database/database.go b/internal/database/database.go
index 393f676..7450210 100644
--- a/internal/database/database.go
+++ b/internal/database/database.go
@@ -88,6 +88,7 @@ CREATE INDEX IF NOT EXISTS idx_announcements_prefix_id ON announcements(prefix_i
 CREATE INDEX IF NOT EXISTS idx_announcements_asn_id ON announcements(asn_id);
 CREATE INDEX IF NOT EXISTS idx_asn_peerings_from_asn ON asn_peerings(from_asn_id);
 CREATE INDEX IF NOT EXISTS idx_asn_peerings_to_asn ON asn_peerings(to_asn_id);
+CREATE INDEX IF NOT EXISTS idx_asn_peerings_lookup ON asn_peerings(from_asn_id, to_asn_id);
 
 -- Indexes for live routes table
 CREATE INDEX IF NOT EXISTS idx_live_routes_active
@@ -102,6 +103,19 @@ CREATE INDEX IF NOT EXISTS idx_live_routes_prefix
     ON live_routes(prefix_id)
     WHERE withdrawn_at IS NULL;
 
+-- Critical index for the most common query pattern
+CREATE INDEX IF NOT EXISTS idx_live_routes_lookup
+    ON live_routes(prefix_id, origin_asn_id, peer_asn)
+    WHERE withdrawn_at IS NULL;
+
+-- Index for withdrawal updates by prefix and peer
+CREATE INDEX IF NOT EXISTS idx_live_routes_withdraw
+    ON live_routes(prefix_id, peer_asn)
+    WHERE withdrawn_at IS NULL;
+
+-- Additional indexes for prefixes table
+CREATE INDEX IF NOT EXISTS idx_prefixes_prefix ON prefixes(prefix);
+
 -- Indexes for bgp_peers table
 CREATE INDEX IF NOT EXISTS idx_bgp_peers_asn ON bgp_peers(peer_asn);
 CREATE INDEX IF NOT EXISTS idx_bgp_peers_last_seen ON bgp_peers(last_seen);
@@ -209,7 +223,23 @@ func NewWithConfig(config Config, logger *slog.Logger) (*Database, error) {
 
 // Initialize creates the database schema if it doesn't exist.
 func (d *Database) Initialize() error {
-	_, err := d.db.Exec(dbSchema)
+	// Set SQLite pragmas for better performance
+	pragmas := []string{
+		"PRAGMA journal_mode=WAL",    // Already set in connection string
+		"PRAGMA synchronous=NORMAL",  // Faster than FULL, still safe
+		"PRAGMA cache_size=-524288",  // 512MB cache
+		"PRAGMA temp_store=MEMORY",   // Use memory for temp tables
+		"PRAGMA mmap_size=268435456", // 256MB memory-mapped I/O
+		"PRAGMA optimize",            // Run optimizer
+	}
+
+	for _, pragma := range pragmas {
+		if err := d.exec(pragma); err != nil {
+			d.logger.Warn("Failed to set pragma", "pragma", pragma, "error", err)
+		}
+	}
+
+	err := d.exec(dbSchema)
 
 	return err
 }
@@ -219,9 +249,19 @@ func (d *Database) Close() error {
 	return d.db.Close()
 }
 
+// beginTx starts a new transaction with logging
+func (d *Database) beginTx() (*loggingTx, error) {
+	tx, err := d.db.Begin()
+	if err != nil {
+		return nil, err
+	}
+
+	return &loggingTx{Tx: tx, logger: d.logger}, nil
+}
+
 // GetOrCreateASN retrieves an existing ASN or creates a new one if it doesn't exist.
 func (d *Database) GetOrCreateASN(number int, timestamp time.Time) (*ASN, error) {
-	tx, err := d.db.Begin()
+	tx, err := d.beginTx()
 	if err != nil {
 		return nil, err
 	}
@@ -282,7 +322,7 @@ func (d *Database) GetOrCreateASN(number int, timestamp time.Time) (*ASN, error)
 
 // GetOrCreatePrefix retrieves an existing prefix or creates a new one if it doesn't exist.
 func (d *Database) GetOrCreatePrefix(prefix string, timestamp time.Time) (*Prefix, error) {
-	tx, err := d.db.Begin()
+	tx, err := d.beginTx()
 	if err != nil {
 		return nil, err
 	}
@@ -344,7 +384,7 @@ func (d *Database) GetOrCreatePrefix(prefix string, timestamp time.Time) (*Prefi
 
 // RecordAnnouncement inserts a new BGP announcement or withdrawal into the database.
 func (d *Database) RecordAnnouncement(announcement *Announcement) error {
-	_, err := d.db.Exec(`
+	err := d.exec(`
 		INSERT INTO announcements (id, prefix_id, asn_id, origin_asn_id, path, next_hop, timestamp, is_withdrawal)
 		VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
 		announcement.ID.String(), announcement.PrefixID.String(),
@@ -356,7 +396,7 @@ func (d *Database) RecordAnnouncement(announcement *Announcement) error {
 
 // RecordPeering records a peering relationship between two ASNs.
 func (d *Database) RecordPeering(fromASNID, toASNID string, timestamp time.Time) error {
-	tx, err := d.db.Begin()
+	tx, err := d.beginTx()
 	if err != nil {
 		return err
 	}
@@ -408,7 +448,7 @@ func (d *Database) UpdateLiveRoute(
 ) error {
 	// Check if route already exists
 	var routeID sql.NullString
-	err := d.db.QueryRow(`
+	err := d.queryRow(`
 		SELECT id FROM live_routes
 		WHERE prefix_id = ? AND origin_asn_id = ? AND peer_asn = ? AND withdrawn_at IS NULL`,
 		prefixID.String(), originASNID.String(), peerASN).Scan(&routeID)
@@ -419,14 +459,14 @@ func (d *Database) UpdateLiveRoute(
 
 	if routeID.Valid {
 		// Route exists and is active, update it
-		_, err = d.db.Exec(`
+		err = d.exec(`
 			UPDATE live_routes
 			SET next_hop = ?, announced_at = ?
 			WHERE id = ?`,
 			nextHop, timestamp, routeID.String)
 	} else {
 		// Either new route or re-announcement of withdrawn route
-		_, err = d.db.Exec(`
+		err = d.exec(`
 			INSERT OR REPLACE INTO live_routes
 			(id, prefix_id, origin_asn_id, peer_asn, next_hop, announced_at, withdrawn_at)
 			VALUES (?, ?, ?, ?, ?, ?, NULL)`,
@@ -439,7 +479,7 @@ func (d *Database) UpdateLiveRoute(
 
 // WithdrawLiveRoute marks a route as withdrawn in the live routing table
 func (d *Database) WithdrawLiveRoute(prefixID uuid.UUID, peerASN int, timestamp time.Time) error {
-	_, err := d.db.Exec(`
+	err := d.exec(`
 		UPDATE live_routes SET withdrawn_at = ?
 		WHERE prefix_id = ? AND peer_asn = ?
 		AND withdrawn_at IS NULL`,
@@ -450,7 +490,7 @@ func (d *Database) WithdrawLiveRoute(prefixID uuid.UUID, peerASN int, timestamp
 
 // GetActiveLiveRoutes returns all currently active routes (not withdrawn)
 func (d *Database) GetActiveLiveRoutes() ([]LiveRoute, error) {
-	rows, err := d.db.Query(`
+	rows, err := d.query(`
 		SELECT id, prefix_id, origin_asn_id, peer_asn, next_hop, announced_at
 		FROM live_routes
 		WHERE withdrawn_at IS NULL
@@ -484,7 +524,7 @@ func (d *Database) GetActiveLiveRoutes() ([]LiveRoute, error) {
 
 // UpdatePeer updates or creates a BGP peer record
 func (d *Database) UpdatePeer(peerIP string, peerASN int, messageType string, timestamp time.Time) error {
-	tx, err := d.db.Begin()
+	tx, err := d.beginTx()
 	if err != nil {
 		return err
 	}
@@ -533,41 +573,49 @@ func (d *Database) GetStats() (Stats, error) {
 	var stats Stats
 
 	// Count ASNs
-	err := d.db.QueryRow("SELECT COUNT(*) FROM asns").Scan(&stats.ASNs)
+	d.logger.Info("Counting ASNs")
+	err := d.queryRow("SELECT COUNT(*) FROM asns").Scan(&stats.ASNs)
 	if err != nil {
 		return stats, err
 	}
 
 	// Count prefixes
-	err = d.db.QueryRow("SELECT COUNT(*) FROM prefixes").Scan(&stats.Prefixes)
+	d.logger.Info("Counting prefixes")
+	err = d.queryRow("SELECT COUNT(*) FROM prefixes").Scan(&stats.Prefixes)
 	if err != nil {
 		return stats, err
 	}
 
 	// Count IPv4 and IPv6 prefixes
+	d.logger.Info("Counting IPv4 prefixes")
 	const ipVersionV4 = 4
-	err = d.db.QueryRow("SELECT COUNT(*) FROM prefixes WHERE ip_version = ?", ipVersionV4).Scan(&stats.IPv4Prefixes)
+	err = d.queryRow("SELECT COUNT(*) FROM prefixes WHERE ip_version = ?", ipVersionV4).Scan(&stats.IPv4Prefixes)
 	if err != nil {
 		return stats, err
 	}
 
+	d.logger.Info("Counting IPv6 prefixes")
 	const ipVersionV6 = 6
-	err = d.db.QueryRow("SELECT COUNT(*) FROM prefixes WHERE ip_version = ?", ipVersionV6).Scan(&stats.IPv6Prefixes)
+	err = d.queryRow("SELECT COUNT(*) FROM prefixes WHERE ip_version = ?", ipVersionV6).Scan(&stats.IPv6Prefixes)
 	if err != nil {
 		return stats, err
 	}
 
 	// Count peerings
-	err = d.db.QueryRow("SELECT COUNT(*) FROM asn_peerings").Scan(&stats.Peerings)
+	d.logger.Info("Counting peerings")
+	err = d.queryRow("SELECT COUNT(*) FROM asn_peerings").Scan(&stats.Peerings)
 	if err != nil {
 		return stats, err
 	}
 
 	// Count live routes
-	err = d.db.QueryRow("SELECT COUNT(*) FROM live_routes WHERE withdrawn_at IS NULL").Scan(&stats.LiveRoutes)
+	d.logger.Info("Counting live routes")
+	err = d.queryRow("SELECT COUNT(*) FROM live_routes WHERE withdrawn_at IS NULL").Scan(&stats.LiveRoutes)
 	if err != nil {
 		return stats, err
 	}
 
+	d.logger.Info("Stats collection complete")
+
 	return stats, nil
 }
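
Note on the schema changes above: the new idx_live_routes_lookup partial index matches the exact predicate UpdateLiveRoute issues for every announcement (prefix_id, origin_asn_id, peer_asn with withdrawn_at IS NULL), and idx_live_routes_withdraw does the same for WithdrawLiveRoute. One way to confirm SQLite actually selects the new index after the schema change is EXPLAIN QUERY PLAN. The sketch below is illustrative only: checkLookupPlan is not part of this change, and it assumes it is handed the existing *sql.DB. Separately, SQLite's PRAGMA optimize works off statistics gathered during the connection's lifetime, so running it once at startup does little; it is usually run periodically or just before closing.

    package database

    import (
        "database/sql"
        "fmt"
    )

    // checkLookupPlan is a hypothetical helper (not part of this diff): it prints
    // the plan SQLite picks for the hot UpdateLiveRoute lookup so the effect of
    // idx_live_routes_lookup can be checked after the schema change.
    func checkLookupPlan(db *sql.DB) error {
        rows, err := db.Query(`EXPLAIN QUERY PLAN
            SELECT id FROM live_routes
            WHERE prefix_id = ? AND origin_asn_id = ? AND peer_asn = ?
            AND withdrawn_at IS NULL`, "prefix-uuid", "asn-uuid", 64496)
        if err != nil {
            return err
        }
        defer rows.Close()

        for rows.Next() {
            var id, parent, notUsed int
            var detail string
            if err := rows.Scan(&id, &parent, &notUsed, &detail); err != nil {
                return err
            }
            // Expect something like: SEARCH live_routes USING INDEX idx_live_routes_lookup (...)
            fmt.Println(detail)
        }

        return rows.Err()
    }
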
diff --git a/internal/database/slowquery.go b/internal/database/slowquery.go
new file mode 100644
index 0000000..471f307
--- /dev/null
+++ b/internal/database/slowquery.go
@@ -0,0 +1,98 @@
+package database
+
+import (
+	"context"
+	"database/sql"
+	"log/slog"
+	"time"
+)
+
+const slowQueryThreshold = 10 * time.Millisecond
+
+// logSlowQuery logs queries that take longer than slowQueryThreshold
+func logSlowQuery(logger *slog.Logger, query string, start time.Time) {
+	elapsed := time.Since(start)
+	if elapsed > slowQueryThreshold {
+		logger.Debug("Slow query", "query", query, "duration", elapsed)
+	}
+}
+
+// queryRow wraps QueryRow with slow query logging
+func (d *Database) queryRow(query string, args ...interface{}) *sql.Row {
+	start := time.Now()
+	defer logSlowQuery(d.logger, query, start)
+
+	return d.db.QueryRow(query, args...)
+}
+
+// query wraps Query with slow query logging
+func (d *Database) query(query string, args ...interface{}) (*sql.Rows, error) {
+	start := time.Now()
+	defer logSlowQuery(d.logger, query, start)
+
+	return d.db.Query(query, args...)
+}
+
+// exec wraps Exec with slow query logging
+func (d *Database) exec(query string, args ...interface{}) error {
+	start := time.Now()
+	defer logSlowQuery(d.logger, query, start)
+
+	_, err := d.db.Exec(query, args...)
+
+	return err
+}
+
+// loggingTx wraps sql.Tx to log slow queries
+type loggingTx struct {
+	*sql.Tx
+	logger *slog.Logger
+}
+
+// QueryRow wraps sql.Tx.QueryRow to log slow queries
+func (tx *loggingTx) QueryRow(query string, args ...interface{}) *sql.Row {
+	start := time.Now()
+	defer logSlowQuery(tx.logger, query, start)
+
+	return tx.Tx.QueryRow(query, args...)
+}
+
+// Query wraps sql.Tx.Query to log slow queries
+func (tx *loggingTx) Query(query string, args ...interface{}) (*sql.Rows, error) {
+	start := time.Now()
+	defer logSlowQuery(tx.logger, query, start)
+
+	return tx.Tx.Query(query, args...)
+}
+
+// QueryContext wraps sql.Tx.QueryContext to log slow queries
+func (tx *loggingTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
+	start := time.Now()
+	defer logSlowQuery(tx.logger, query, start)
+
+	return tx.Tx.QueryContext(ctx, query, args...)
+}
+
+// QueryRowContext wraps sql.Tx.QueryRowContext to log slow queries
+func (tx *loggingTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
+	start := time.Now()
+	defer logSlowQuery(tx.logger, query, start)
+
+	return tx.Tx.QueryRowContext(ctx, query, args...)
+}
+
+// Exec wraps sql.Tx.Exec to log slow queries
+func (tx *loggingTx) Exec(query string, args ...interface{}) (sql.Result, error) {
+	start := time.Now()
+	defer logSlowQuery(tx.logger, query, start)
+
+	return tx.Tx.Exec(query, args...)
+}
+
+// ExecContext wraps sql.Tx.ExecContext to log slow queries
+func (tx *loggingTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
+	start := time.Now()
+	defer logSlowQuery(tx.logger, query, start)
+
+	return tx.Tx.ExecContext(ctx, query, args...)
+}
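
The wrappers in slowquery.go log at Debug level, so slow statements only show up when the handler behind d.logger is configured for slog.LevelDebug; they also time only the wrapped call itself, not any row iteration the caller does afterwards. A minimal standalone sketch of the threshold behaviour (not part of the diff; the helper is copied here only to make the example runnable):

    package main

    import (
        "log/slog"
        "os"
        "time"
    )

    const slowQueryThreshold = 10 * time.Millisecond

    // logSlowQuery mirrors the helper added in slowquery.go.
    func logSlowQuery(logger *slog.Logger, query string, start time.Time) {
        if elapsed := time.Since(start); elapsed > slowQueryThreshold {
            logger.Debug("Slow query", "query", query, "duration", elapsed)
        }
    }

    func main() {
        // Debug level must be enabled, or the slow-query records are dropped.
        logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
            Level: slog.LevelDebug,
        }))

        start := time.Now()
        time.Sleep(25 * time.Millisecond) // simulate a query slower than the threshold
        logSlowQuery(logger, "SELECT COUNT(*) FROM live_routes", start)

        logSlowQuery(logger, "SELECT 1", time.Now()) // fast path: nothing is logged
    }
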
diff --git a/internal/server/middleware.go b/internal/server/middleware.go
new file mode 100644
index 0000000..d9d9689
--- /dev/null
+++ b/internal/server/middleware.go
@@ -0,0 +1,201 @@
+package server
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"sync"
+	"time"
+)
+
+// responseWriter wraps http.ResponseWriter to capture the response
+type responseWriter struct {
+	http.ResponseWriter
+	body       *bytes.Buffer
+	statusCode int
+	written    bool
+	mu         sync.Mutex
+}
+
+func (rw *responseWriter) Write(b []byte) (int, error) {
+	rw.mu.Lock()
+	defer rw.mu.Unlock()
+
+	if !rw.written {
+		rw.written = true
+	}
+
+	return rw.body.Write(b)
+}
+
+func (rw *responseWriter) WriteHeader(statusCode int) {
+	rw.mu.Lock()
+	defer rw.mu.Unlock()
+
+	if !rw.written {
+		rw.statusCode = statusCode
+		rw.written = true
+	}
+}
+
+func (rw *responseWriter) Header() http.Header {
+	return rw.ResponseWriter.Header()
+}
+
+// JSONResponseMiddleware wraps all JSON responses with metadata
+func JSONResponseMiddleware(next http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// Skip non-JSON endpoints
+		if r.URL.Path == "/" || r.URL.Path == "/status" {
+			next.ServeHTTP(w, r)
+
+			return
+		}
+
+		startTime := time.Now()
+
+		// Create a custom response writer to capture the response
+		rw := &responseWriter{
+			ResponseWriter: w,
+			body:           &bytes.Buffer{},
+			statusCode:     http.StatusOK,
+		}
+
+		// Serve the request
+		next.ServeHTTP(rw, r)
+
+		// Calculate execution time
+		execTime := time.Since(startTime)
+
+		// Only process JSON responses
+		contentType := rw.Header().Get("Content-Type")
+		if contentType != "application/json" || rw.body.Len() == 0 {
+			// Write the original response
+			w.WriteHeader(rw.statusCode)
+			_, _ = w.Write(rw.body.Bytes())
+
+			return
+		}
+
+		// Parse the original response
+		var originalResponse map[string]interface{}
+		if err := json.Unmarshal(rw.body.Bytes(), &originalResponse); err != nil {
+			// If we can't parse it, just write original
+			w.WriteHeader(rw.statusCode)
+			_, _ = w.Write(rw.body.Bytes())
+
+			return
+		}
+
+		// Add @meta field
+		originalResponse["@meta"] = map[string]interface{}{
+			"time_zone":      "UTC",
+			"api_version":    1,
+			"execution_time": fmt.Sprintf("%d ms", execTime.Milliseconds()),
+		}
+
+		// Write the enhanced response
+		w.Header().Set("Content-Type", "application/json")
+		w.WriteHeader(rw.statusCode)
+		_ = json.NewEncoder(w).Encode(originalResponse)
+	})
+}
+
+// timeoutWriter wraps ResponseWriter to prevent concurrent writes after timeout
+type timeoutWriter struct {
+	http.ResponseWriter
+	mu      sync.Mutex
+	written bool
+}
+
+func (tw *timeoutWriter) Write(b []byte) (int, error) {
+	tw.mu.Lock()
+	defer tw.mu.Unlock()
+
+	if tw.written {
+		return 0, nil // Discard writes after timeout
+	}
+
+	return tw.ResponseWriter.Write(b)
+}
+
+func (tw *timeoutWriter) WriteHeader(statusCode int) {
+	tw.mu.Lock()
+	defer tw.mu.Unlock()
+
+	if tw.written {
+		return // Discard writes after timeout
+	}
+
+	tw.ResponseWriter.WriteHeader(statusCode)
+}
+
+func (tw *timeoutWriter) Header() http.Header {
+	return tw.ResponseWriter.Header()
+}
+
+func (tw *timeoutWriter) markWritten() {
+	tw.mu.Lock()
+	defer tw.mu.Unlock()
+	tw.written = true
+}
+
+// TimeoutMiddleware creates a timeout middleware that returns JSON errors
+func TimeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
+	return func(next http.Handler) http.Handler {
+		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			startTime := time.Now()
+
+			ctx, cancel := context.WithTimeout(r.Context(), timeout)
+			defer cancel()
+
+			tw := &timeoutWriter{
+				ResponseWriter: w,
+			}
+
+			done := make(chan struct{})
+			panicChan := make(chan interface{}, 1)
+
+			go func() {
+				defer func() {
+					if p := recover(); p != nil {
+						panicChan <- p
+					}
+				}()
+
+				next.ServeHTTP(tw, r.WithContext(ctx))
+				close(done)
+			}()
+
+			select {
+			case p := <-panicChan:
+				panic(p)
+			case <-done:
+				return
+			case <-ctx.Done():
+				tw.markWritten() // Prevent the handler from writing after timeout
+				execTime := time.Since(startTime)
+
+				w.Header().Set("Content-Type", "application/json")
+				w.WriteHeader(http.StatusRequestTimeout)
+
+				response := map[string]interface{}{
+					"status": "error",
+					"error": map[string]interface{}{
+						"msg":  "Request timeout",
+						"code": http.StatusRequestTimeout,
+					},
+					"@meta": map[string]interface{}{
+						"time_zone":      "UTC",
+						"api_version":    1,
+						"execution_time": fmt.Sprintf("%d ms", execTime.Milliseconds()),
+					},
+				}
+
+				_ = json.NewEncoder(w).Encode(response)
+			}
+		})
+	}
+}
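
The envelope added by JSONResponseMiddleware is easy to exercise in isolation with net/http/httptest. The sketch below is illustrative only: it assumes it sits in the same server package (e.g. in a _test.go file), and the handler body and the /api/v1/example path are made up for the example.

    package server

    import (
        "encoding/json"
        "fmt"
        "net/http"
        "net/http/httptest"
    )

    // ExampleJSONResponseMiddleware wraps a trivial JSON handler and checks that
    // the middleware injects the @meta block. Hypothetical test, not in this diff.
    func ExampleJSONResponseMiddleware() {
        h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
            w.Header().Set("Content-Type", "application/json")
            _ = json.NewEncoder(w).Encode(map[string]interface{}{"status": "ok"})
        })

        rec := httptest.NewRecorder()
        req := httptest.NewRequest(http.MethodGet, "/api/v1/example", nil)
        JSONResponseMiddleware(h).ServeHTTP(rec, req)

        var body map[string]interface{}
        _ = json.Unmarshal(rec.Body.Bytes(), &body)
        _, hasMeta := body["@meta"]
        fmt.Println(rec.Code, hasMeta)
        // Output: 200 true
    }

Because the middleware buffers the whole response and re-encodes it, map key order is not preserved and large payloads pay the buffering cost; that seems acceptable for these small status endpoints.
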
diff --git a/internal/server/server.go b/internal/server/server.go
index 303565f..a6623d2 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -47,8 +47,9 @@ func (s *Server) setupRoutes() {
 	r.Use(middleware.RealIP)
 	r.Use(middleware.Logger)
 	r.Use(middleware.Recoverer)
-	const requestTimeout = 60 * time.Second
-	r.Use(middleware.Timeout(requestTimeout))
+	const requestTimeout = 2 * time.Second
+	r.Use(TimeoutMiddleware(requestTimeout))
+	r.Use(JSONResponseMiddleware)
 
 	// Routes
 	r.Get("/", s.handleRoot())
@@ -124,15 +125,60 @@ func (s *Server) handleStatusJSON() http.HandlerFunc {
 		LiveRoutes int `json:"live_routes"`
 	}
 
-	return func(w http.ResponseWriter, _ *http.Request) {
+	return func(w http.ResponseWriter, r *http.Request) {
+		// Create a 1 second timeout context for this request
+		ctx, cancel := context.WithTimeout(r.Context(), 1*time.Second)
+		defer cancel()
+
 		metrics := s.streamer.GetMetrics()
 
-		// Get database stats
-		dbStats, err := s.db.GetStats()
-		if err != nil {
-			http.Error(w, err.Error(), http.StatusInternalServerError)
+		// Get database stats with timeout
+		statsChan := make(chan database.Stats)
+		errChan := make(chan error)
+
+		go func() {
+			s.logger.Debug("Starting database stats query")
+			dbStats, err := s.db.GetStats()
+			if err != nil {
+				s.logger.Debug("Database stats query failed", "error", err)
+				errChan <- err
+
+				return
+			}
+			s.logger.Debug("Database stats query completed")
+			statsChan <- dbStats
+		}()
+
+		var dbStats database.Stats
+		select {
+		case <-ctx.Done():
+			s.logger.Error("Database stats timeout in status.json")
+			w.Header().Set("Content-Type", "application/json")
+			w.WriteHeader(http.StatusRequestTimeout)
+			_ = json.NewEncoder(w).Encode(map[string]interface{}{
+				"status": "error",
+				"error": map[string]interface{}{
+					"msg":  "Database timeout",
+					"code": http.StatusRequestTimeout,
+				},
+			})
 
 			return
+		case err := <-errChan:
+			s.logger.Error("Failed to get database stats", "error", err)
+			w.Header().Set("Content-Type", "application/json")
+			w.WriteHeader(http.StatusInternalServerError)
+			_ = json.NewEncoder(w).Encode(map[string]interface{}{
+				"status": "error",
+				"error": map[string]interface{}{
+					"msg":  err.Error(),
+					"code": http.StatusInternalServerError,
+				},
+			})
+
+			return
+		case dbStats = <-statsChan:
+			// Success
 		}
 
 		uptime := time.Since(metrics.ConnectedSince).Truncate(time.Second).String()
@@ -158,7 +204,11 @@ func (s *Server) handleStatusJSON() http.HandlerFunc {
 		}
 		w.Header().Set("Content-Type",
"application/json") - if err := json.NewEncoder(w).Encode(stats); err != nil { + response := map[string]interface{}{ + "status": "ok", + "data": stats, + } + if err := json.NewEncoder(w).Encode(response); err != nil { s.logger.Error("Failed to encode stats", "error", err) } } @@ -182,15 +232,53 @@ func (s *Server) handleStats() http.HandlerFunc { LiveRoutes int `json:"live_routes"` } - return func(w http.ResponseWriter, _ *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + // Create a 1 second timeout context for this request + ctx, cancel := context.WithTimeout(r.Context(), 1*time.Second) + defer cancel() + + // Check if context is already cancelled + select { + case <-ctx.Done(): + http.Error(w, "Request timeout", http.StatusRequestTimeout) + + return + default: + } + metrics := s.streamer.GetMetrics() - // Get database stats - dbStats, err := s.db.GetStats() - if err != nil { + // Get database stats with timeout + statsChan := make(chan database.Stats) + errChan := make(chan error) + + go func() { + s.logger.Debug("Starting database stats query") + dbStats, err := s.db.GetStats() + if err != nil { + s.logger.Debug("Database stats query failed", "error", err) + errChan <- err + + return + } + s.logger.Debug("Database stats query completed") + statsChan <- dbStats + }() + + var dbStats database.Stats + select { + case <-ctx.Done(): + s.logger.Error("Database stats timeout") + http.Error(w, "Database timeout", http.StatusRequestTimeout) + + return + case err := <-errChan: + s.logger.Error("Failed to get database stats", "error", err) http.Error(w, err.Error(), http.StatusInternalServerError) return + case dbStats = <-statsChan: + // Success } uptime := time.Since(metrics.ConnectedSince).Truncate(time.Second).String() @@ -216,7 +304,11 @@ func (s *Server) handleStats() http.HandlerFunc { } w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(stats); err != nil { + response := map[string]interface{}{ + "status": "ok", + "data": stats, + } + if err := json.NewEncoder(w).Encode(response); err != nil { s.logger.Error("Failed to encode stats", "error", err) } } diff --git a/internal/templates/status.html b/internal/templates/status.html index f7b7b93..279b009 100644 --- a/internal/templates/status.html +++ b/internal/templates/status.html @@ -145,7 +145,17 @@ function updateStatus() { fetch('/api/v1/stats') .then(response => response.json()) - .then(data => { + .then(response => { + // Check if response is an error + if (response.status === 'error') { + document.getElementById('error').textContent = 'Error: ' + response.error.msg; + document.getElementById('error').style.display = 'block'; + return; + } + + // Extract data from successful response + const data = response.data; + // Connection status const connectedEl = document.getElementById('connected'); connectedEl.textContent = data.connected ? 'Connected' : 'Disconnected';