package models import ( "context" "crypto/rand" "database/sql" "encoding/hex" "errors" "fmt" "time" "github.com/oklog/ulid/v2" "git.eeqj.de/sneak/upaas/internal/database" ) // tokenRandomBytes is the number of random bytes for token generation. const tokenRandomBytes = 16 // tokenPrefix is prepended to generated API tokens. const tokenPrefix = "upaas_" // APIToken represents an API authentication token. type APIToken struct { db *database.Database ID string UserID int64 Name string TokenHash string CreatedAt time.Time ExpiresAt sql.NullTime LastUsedAt sql.NullTime } // NewAPIToken creates a new APIToken with a database reference. func NewAPIToken(db *database.Database) *APIToken { return &APIToken{db: db} } // GenerateToken generates a random API token string. func GenerateToken() (string, error) { b := make([]byte, tokenRandomBytes) _, err := rand.Read(b) if err != nil { return "", fmt.Errorf("generating token: %w", err) } return tokenPrefix + hex.EncodeToString(b), nil } // Save inserts the API token into the database. func (t *APIToken) Save(ctx context.Context) error { if t.ID == "" { t.ID = ulid.Make().String() } query := `INSERT INTO api_tokens (id, user_id, name, token_hash, expires_at) VALUES (?, ?, ?, ?, ?)` _, err := t.db.Exec( ctx, query, t.ID, t.UserID, t.Name, t.TokenHash, t.ExpiresAt, ) if err != nil { return fmt.Errorf("inserting api token: %w", err) } return t.Reload(ctx) } // Reload refreshes the token from the database. func (t *APIToken) Reload(ctx context.Context) error { row := t.db.QueryRow(ctx, `SELECT id, user_id, name, token_hash, created_at, expires_at, last_used_at FROM api_tokens WHERE id = ?`, t.ID) return t.scan(row) } // Delete removes the token from the database. func (t *APIToken) Delete(ctx context.Context) error { _, err := t.db.Exec(ctx, "DELETE FROM api_tokens WHERE id = ?", t.ID) return err } // TouchLastUsed updates the last_used_at timestamp. func (t *APIToken) TouchLastUsed(ctx context.Context) error { _, err := t.db.Exec(ctx, "UPDATE api_tokens SET last_used_at = ? WHERE id = ?", time.Now().UTC(), t.ID) return err } // IsExpired reports whether the token has expired. func (t *APIToken) IsExpired() bool { return t.ExpiresAt.Valid && t.ExpiresAt.Time.Before(time.Now()) } func (t *APIToken) scan(row *sql.Row) error { return row.Scan( &t.ID, &t.UserID, &t.Name, &t.TokenHash, &t.CreatedAt, &t.ExpiresAt, &t.LastUsedAt, ) } // FindAPITokenByHash finds a token by its hash. // //nolint:nilnil // nil,nil is idiomatic for "not found" func FindAPITokenByHash( ctx context.Context, db *database.Database, hash string, ) (*APIToken, error) { token := NewAPIToken(db) row := db.QueryRow(ctx, `SELECT id, user_id, name, token_hash, created_at, expires_at, last_used_at FROM api_tokens WHERE token_hash = ?`, hash) err := token.scan(row) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } return nil, fmt.Errorf("finding api token by hash: %w", err) } return token, nil } // FindAPIToken finds a token by ID. // //nolint:nilnil // nil,nil is idiomatic for "not found" func FindAPIToken( ctx context.Context, db *database.Database, id string, ) (*APIToken, error) { token := NewAPIToken(db) row := db.QueryRow(ctx, `SELECT id, user_id, name, token_hash, created_at, expires_at, last_used_at FROM api_tokens WHERE id = ?`, id) err := token.scan(row) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } return nil, fmt.Errorf("finding api token: %w", err) } return token, nil } // ListAPITokensByUser returns all tokens for a user. func ListAPITokensByUser( ctx context.Context, db *database.Database, userID int64, ) ([]*APIToken, error) { rows, err := db.Query(ctx, `SELECT id, user_id, name, token_hash, created_at, expires_at, last_used_at FROM api_tokens WHERE user_id = ? ORDER BY created_at DESC`, userID) if err != nil { return nil, fmt.Errorf("listing api tokens: %w", err) } defer func() { _ = rows.Close() }() var tokens []*APIToken for rows.Next() { t := NewAPIToken(db) scanErr := rows.Scan( &t.ID, &t.UserID, &t.Name, &t.TokenHash, &t.CreatedAt, &t.ExpiresAt, &t.LastUsedAt, ) if scanErr != nil { return nil, fmt.Errorf("scanning api token: %w", scanErr) } tokens = append(tokens, t) } rowsErr := rows.Err() if rowsErr != nil { return nil, fmt.Errorf("iterating api tokens: %w", rowsErr) } return tokens, nil }