diff --git a/internal/db/queries.go b/internal/db/queries.go index e4d6250..60e409e 100644 --- a/internal/db/queries.go +++ b/internal/db/queries.go @@ -779,6 +779,96 @@ func (database *Database) DeleteSession( return nil } +// DeleteClient removes a single client record by ID. +func (database *Database) DeleteClient( + ctx context.Context, + clientID int64, +) error { + _, err := database.conn.ExecContext( + ctx, + "DELETE FROM clients WHERE id = ?", + clientID, + ) + if err != nil { + return fmt.Errorf("delete client: %w", err) + } + + return nil +} + +// GetSessionCount returns the number of active sessions. +func (database *Database) GetSessionCount( + ctx context.Context, +) (int64, error) { + var count int64 + + err := database.conn.QueryRowContext( + ctx, + "SELECT COUNT(*) FROM sessions", + ).Scan(&count) + if err != nil { + return 0, fmt.Errorf( + "get session count: %w", err, + ) + } + + return count, nil +} + +// ClientCountForSession returns the number of clients +// belonging to a session. +func (database *Database) ClientCountForSession( + ctx context.Context, + sessionID int64, +) (int64, error) { + var count int64 + + err := database.conn.QueryRowContext( + ctx, + `SELECT COUNT(*) FROM clients + WHERE session_id = ?`, + sessionID, + ).Scan(&count) + if err != nil { + return 0, fmt.Errorf( + "client count for session: %w", err, + ) + } + + return count, nil +} + +// DeleteStaleSessions removes clients not seen since the +// cutoff and cleans up orphaned sessions. +func (database *Database) DeleteStaleSessions( + ctx context.Context, + cutoff time.Time, +) (int64, error) { + res, err := database.conn.ExecContext(ctx, + "DELETE FROM clients WHERE last_seen < ?", + cutoff, + ) + if err != nil { + return 0, fmt.Errorf( + "delete stale clients: %w", err, + ) + } + + deleted, _ := res.RowsAffected() + + _, err = database.conn.ExecContext(ctx, + `DELETE FROM sessions WHERE id NOT IN + (SELECT DISTINCT session_id FROM clients)`, + ) + if err != nil { + return deleted, fmt.Errorf( + "delete orphan sessions: %w", err, + ) + } + + return deleted, nil +} + // GetSessionChannels returns channels a session // belongs to. func (database *Database) GetSessionChannels( diff --git a/internal/handlers/api.go b/internal/handlers/api.go index 7159884..47923f4 100644 --- a/internal/handlers/api.go +++ b/internal/handlers/api.go @@ -1361,20 +1361,71 @@ func (hdlr *Handlers) canAccessChannelHistory( return true } -// HandleServerInfo returns server metadata. -func (hdlr *Handlers) HandleServerInfo() http.HandlerFunc { - type infoResponse struct { - Name string `json:"name"` - MOTD string `json:"motd"` - } - +// HandleLogout deletes the authenticated client's token. +func (hdlr *Handlers) HandleLogout() http.HandlerFunc { return func( writer http.ResponseWriter, request *http.Request, ) { - hdlr.respondJSON(writer, request, &infoResponse{ - Name: hdlr.params.Config.ServerName, - MOTD: hdlr.params.Config.MOTD, + _, clientID, _, ok := + hdlr.requireAuth(writer, request) + if !ok { + return + } + + err := hdlr.params.Database.DeleteClient( + request.Context(), clientID, + ) + if err != nil { + hdlr.log.Error( + "delete client failed", "error", err, + ) + hdlr.respondError( + writer, request, + "internal error", + http.StatusInternalServerError, + ) + + return + } + + hdlr.respondJSON(writer, request, + map[string]string{"status": "ok"}, + http.StatusOK) + } +} + +// HandleUsersMe returns the current user's session info. +func (hdlr *Handlers) HandleUsersMe() http.HandlerFunc { + return hdlr.HandleState() +} + +// HandleServerInfo returns server metadata. +func (hdlr *Handlers) HandleServerInfo() http.HandlerFunc { + return func( + writer http.ResponseWriter, + request *http.Request, + ) { + users, err := hdlr.params.Database.GetSessionCount( + request.Context(), + ) + if err != nil { + hdlr.log.Error( + "get session count failed", "error", err, + ) + hdlr.respondError( + writer, request, + "internal error", + http.StatusInternalServerError, + ) + + return + } + + hdlr.respondJSON(writer, request, map[string]any{ + "name": hdlr.params.Config.ServerName, + "motd": hdlr.params.Config.MOTD, + "users": users, }, http.StatusOK) } } diff --git a/internal/server/routes.go b/internal/server/routes.go index e7b632e..9e945e7 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -86,6 +86,14 @@ func (srv *Server) setupAPIv1(router chi.Router) { "/state", srv.handlers.HandleState(), ) + router.Post( + "/logout", + srv.handlers.HandleLogout(), + ) + router.Get( + "/users/me", + srv.handlers.HandleUsersMe(), + ) router.Get( "/messages", srv.handlers.HandleGetMessages(),