diff --git a/internal/db/queries.go b/internal/db/queries.go index d6fbb91..5b209dd 100644 --- a/internal/db/queries.go +++ b/internal/db/queries.go @@ -869,6 +869,60 @@ func (database *Database) DeleteStaleUsers( return deleted, nil } +// StaleSession holds the id and nick of a session +// whose clients are all stale. +type StaleSession struct { + ID int64 + Nick string +} + +// GetStaleOrphanSessions returns sessions where every +// client has a last_seen before cutoff. +func (database *Database) GetStaleOrphanSessions( + ctx context.Context, + cutoff time.Time, +) ([]StaleSession, error) { + rows, err := database.conn.QueryContext(ctx, + `SELECT s.id, s.nick + FROM sessions s + WHERE s.id IN ( + SELECT DISTINCT session_id FROM clients + WHERE last_seen < ? + ) + AND s.id NOT IN ( + SELECT DISTINCT session_id FROM clients + WHERE last_seen >= ? + )`, cutoff, cutoff) + if err != nil { + return nil, fmt.Errorf( + "get stale orphan sessions: %w", err, + ) + } + + defer func() { _ = rows.Close() }() + + var result []StaleSession + + for rows.Next() { + var stale StaleSession + if err := rows.Scan(&stale.ID, &stale.Nick); err != nil { + return nil, fmt.Errorf( + "scan stale session: %w", err, + ) + } + + result = append(result, stale) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf( + "iterate stale sessions: %w", err, + ) + } + + return result, nil +} + // GetSessionChannels returns channels a session // belongs to. func (database *Database) GetSessionChannels( diff --git a/internal/handlers/api.go b/internal/handlers/api.go index b963f9b..ad31dd3 100644 --- a/internal/handlers/api.go +++ b/internal/handlers/api.go @@ -1405,7 +1405,7 @@ func (hdlr *Handlers) HandleLogout() http.HandlerFunc { if remaining == 0 { hdlr.cleanupUser( - request, sessionID, nick, + ctx, sessionID, nick, ) } @@ -1418,12 +1418,10 @@ func (hdlr *Handlers) HandleLogout() http.HandlerFunc { // cleanupUser parts the user from all channels (notifying // members) and deletes the session. func (hdlr *Handlers) cleanupUser( - request *http.Request, + ctx context.Context, sessionID int64, nick string, ) { - ctx := request.Context() - channels, _ := hdlr.params.Database. GetSessionChannels(ctx, sessionID) diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 7ce952e..ed6ef04 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -169,6 +169,20 @@ func (hdlr *Handlers) runCleanup( ) { cutoff := time.Now().Add(-timeout) + // Find sessions that will be orphaned so we can send + // QUIT notifications before deleting anything. + stale, err := hdlr.params.Database. + GetStaleOrphanSessions(ctx, cutoff) + if err != nil { + hdlr.log.Error( + "stale session lookup failed", "error", err, + ) + } + + for _, ss := range stale { + hdlr.cleanupUser(ctx, ss.ID, ss.Nick) + } + deleted, err := hdlr.params.Database.DeleteStaleUsers( ctx, cutoff, )