Files
chat/internal/db/queries.go
clawbot 67446b36a1
All checks were successful
check / check (push) Successful in 5s
feat: store auth tokens as SHA-256 hashes instead of plaintext (#69)
## Summary

Hash client auth tokens with SHA-256 before storing in the database. When validating tokens, hash the incoming token and compare against the stored hash. This prevents token exposure if the database is compromised.

Existing plaintext tokens are implicitly invalidated since they will not match the new hashed lookups — users will need to create new sessions.

## Changes

- **`internal/db/queries.go`**: Added `hashToken()` helper using `crypto/sha256`. Updated `CreateSession` to store hashed token. Updated `GetSessionByToken` to hash the incoming token before querying.
- **`internal/db/auth.go`**: Updated `RegisterUser` and `LoginUser` to store hashed tokens.

## Migration

No schema changes needed. The `token` column remains `TEXT` but now stores 64-char hex SHA-256 digests instead of 64-char hex random tokens. Existing plaintext tokens are effectively invalidated.

closes #34

Co-authored-by: user <user@Mac.lan guest wan>
Reviewed-on: #69
Co-authored-by: clawbot <clawbot@noreply.example.org>
Co-committed-by: clawbot <clawbot@noreply.example.org>
2026-03-10 12:44:29 +01:00

1112 lines
23 KiB
Go

package db
import (
"context"
"crypto/rand"
"crypto/sha256"
"database/sql"
"encoding/hex"
"encoding/json"
"fmt"
"strconv"
"time"
"git.eeqj.de/sneak/neoirc/internal/irc"
"github.com/google/uuid"
)
// Sizing and paging defaults used by the queries in this file.
const (
// tokenBytes is the number of random bytes read for a new
// token by generateToken; hex encoding doubles that length.
tokenBytes = 32
// defaultPollLimit is the row limit PollMessages applies
// when the caller passes a limit of zero or less.
defaultPollLimit = 100
// defaultHistLimit is the row limit GetHistory applies
// when the caller passes a limit of zero or less.
defaultHistLimit = 50
)
func generateToken() (string, error) {
buf := make([]byte, tokenBytes)
_, err := rand.Read(buf)
if err != nil {
return "", fmt.Errorf("generate token: %w", err)
}
return hex.EncodeToString(buf), nil
}
// hashToken computes the SHA-256 digest of the plaintext
// token and returns it as a lowercase hex string (64
// characters).
func hashToken(token string) string {
	hasher := sha256.New()
	// hash.Hash documents that Write never returns an error.
	_, _ = hasher.Write([]byte(token))

	return hex.EncodeToString(hasher.Sum(nil))
}
// IRCMessage is the IRC envelope for all messages.
// Values are populated by scanMessages from rows of the
// messages table.
type IRCMessage struct {
// ID is the message UUID (the uuid column), not the
// numeric database key.
ID string `json:"id"`
// Command is the stored command; scanMessages replaces a
// three-digit numeric command with irc.Name(code) when
// that lookup yields a non-empty name.
Command string `json:"command"`
// Code is the integer value of a three-digit numeric
// command; it stays zero (and is omitted) otherwise.
Code int `json:"code,omitempty"`
// From and To mirror the msg_from and msg_to columns.
From string `json:"from,omitempty"`
To string `json:"to,omitempty"`
// Params is raw JSON; scanMessages leaves it unset when
// the stored value is empty or "[]".
Params json.RawMessage `json:"params,omitempty"`
// Body is the stored body column passed through as raw JSON.
Body json.RawMessage `json:"body,omitempty"`
// TS is created_at formatted with time.RFC3339Nano.
TS string `json:"ts"`
// Meta is the stored meta column passed through as raw JSON.
Meta json.RawMessage `json:"meta,omitempty"`
// DBID is the first column of the scanned row (queue id
// for PollMessages, message id for history queries). It
// is never serialized.
DBID int64 `json:"-"`
}
// isNumericCode reports whether s consists of exactly
// three ASCII digits, i.e. has the form of an IRC numeric
// reply code such as "001" or "433".
func isNumericCode(s string) bool {
	if len(s) != 3 {
		return false
	}

	for i := 0; i < len(s); i++ {
		if s[i] < '0' || s[i] > '9' {
			return false
		}
	}

	return true
}
// ChannelInfo is a lightweight channel representation.
// scanChannels fills it from the id, name and topic
// columns of the channels table, in that order.
type ChannelInfo struct {
// ID is the channel's numeric database id.
ID int64 `json:"id"`
// Name is the channel name.
Name string `json:"name"`
// Topic is the channel's current topic.
Topic string `json:"topic"`
}
// MemberInfo represents a channel member.
// ChannelMembers fills it from the id, nick and last_seen
// columns of the sessions table.
type MemberInfo struct {
// ID is the member's session id.
ID int64 `json:"id"`
// Nick is the session's nickname.
Nick string `json:"nick"`
// LastSeen is the session's last_seen timestamp.
LastSeen time.Time `json:"lastSeen"`
}
// CreateSession registers a new session and its first
// client inside a single transaction.
//
// A random token is generated for the client. Only the
// SHA-256 hash of that token is written to the clients
// table; the plaintext token is handed back to the caller.
//
// It returns the new session id, the new client id and the
// plaintext token. On any failure the transaction is rolled
// back and the ids and token are zero values.
func (database *Database) CreateSession(
	ctx context.Context,
	nick string,
) (int64, int64, string, error) {
	sessionUUID := uuid.New().String()
	clientUUID := uuid.New().String()

	token, err := generateToken()
	if err != nil {
		return 0, 0, "", err
	}

	now := time.Now()

	transaction, err := database.conn.BeginTx(ctx, nil)
	if err != nil {
		return 0, 0, "", fmt.Errorf(
			"begin tx: %w", err,
		)
	}

	// One deferred rollback covers every early return below.
	// Once Commit has run, database/sql turns this call into
	// a no-op that reports sql.ErrTxDone, which is discarded.
	defer func() { _ = transaction.Rollback() }()

	res, err := transaction.ExecContext(ctx,
		`INSERT INTO sessions
(uuid, nick, created_at, last_seen)
VALUES (?, ?, ?, ?)`,
		sessionUUID, nick, now, now)
	if err != nil {
		return 0, 0, "", fmt.Errorf(
			"create session: %w", err,
		)
	}

	// Without a trustworthy id the client row would be linked
	// to session 0, so a failure here aborts the transaction.
	sessionID, err := res.LastInsertId()
	if err != nil {
		return 0, 0, "", fmt.Errorf(
			"create session: last insert id: %w", err,
		)
	}

	tokenHash := hashToken(token)

	clientRes, err := transaction.ExecContext(ctx,
		`INSERT INTO clients
(uuid, session_id, token,
created_at, last_seen)
VALUES (?, ?, ?, ?, ?)`,
		clientUUID, sessionID, tokenHash, now, now)
	if err != nil {
		return 0, 0, "", fmt.Errorf(
			"create client: %w", err,
		)
	}

	clientID, err := clientRes.LastInsertId()
	if err != nil {
		return 0, 0, "", fmt.Errorf(
			"create client: last insert id: %w", err,
		)
	}

	if err = transaction.Commit(); err != nil {
		return 0, 0, "", fmt.Errorf(
			"commit session: %w", err,
		)
	}

	return sessionID, clientID, token, nil
}
// GetSessionByToken resolves a plaintext client token to
// its session id, client id and session nick.
//
// The token is hashed with hashToken before the lookup, so
// it is compared against the stored digest. If no row
// matches, the Scan error is wrapped and returned.
//
// After a successful lookup the last_seen columns of the
// session and the client are refreshed. Errors from those
// two updates are discarded; the lookup result is returned
// regardless.
func (database *Database) GetSessionByToken(
	ctx context.Context,
	token string,
) (int64, int64, string, error) {
	const lookup = `SELECT s.id, c.id, s.nick
FROM clients c
INNER JOIN sessions s
ON s.id = c.session_id
WHERE c.token = ?`

	var (
		sessionID, clientID int64
		nick                string
	)

	row := database.conn.QueryRowContext(
		ctx, lookup, hashToken(token),
	)
	if err := row.Scan(&sessionID, &clientID, &nick); err != nil {
		return 0, 0, "", fmt.Errorf(
			"get session by token: %w", err,
		)
	}

	// Both rows receive the same timestamp.
	seenAt := time.Now()
	touches := []struct {
		query string
		id    int64
	}{
		{"UPDATE sessions SET last_seen = ? WHERE id = ?", sessionID},
		{"UPDATE clients SET last_seen = ? WHERE id = ?", clientID},
	}

	for _, touch := range touches {
		_, _ = database.conn.ExecContext(
			ctx, touch.query, seenAt, touch.id,
		)
	}

	return sessionID, clientID, nick, nil
}
// GetSessionByNick looks up the id of the session whose
// nick equals the given value. When no row matches, the
// Scan error is wrapped and returned with a zero id.
func (database *Database) GetSessionByNick(
	ctx context.Context,
	nick string,
) (int64, error) {
	row := database.conn.QueryRowContext(
		ctx, "SELECT id FROM sessions WHERE nick = ?", nick,
	)

	var id int64
	if err := row.Scan(&id); err != nil {
		return 0, fmt.Errorf(
			"get session by nick: %w", err,
		)
	}

	return id, nil
}
// GetChannelByName looks up the id of the channel with the
// given name. When no row matches, the Scan error is
// wrapped and returned with a zero id.
func (database *Database) GetChannelByName(
	ctx context.Context,
	name string,
) (int64, error) {
	row := database.conn.QueryRowContext(
		ctx, "SELECT id FROM channels WHERE name = ?", name,
	)

	var id int64
	if err := row.Scan(&id); err != nil {
		return 0, fmt.Errorf(
			"get channel by name: %w", err,
		)
	}

	return id, nil
}
// GetOrCreateChannel returns the id of the channel with
// the given name, inserting the channel first if needed.
//
// The insert uses INSERT OR IGNORE and is followed by a
// SELECT, instead of checking for existence and then
// inserting, to avoid a TOCTOU race between callers.
// NOTE(review): this relies on channels.name being unique;
// the schema is not visible here.
func (database *Database) GetOrCreateChannel(
	ctx context.Context,
	name string,
) (int64, error) {
	const insertChannel = `INSERT OR IGNORE INTO channels
(name, created_at, updated_at)
VALUES (?, ?, ?)`

	stamp := time.Now()

	if _, err := database.conn.ExecContext(
		ctx, insertChannel, name, stamp, stamp,
	); err != nil {
		return 0, fmt.Errorf("create channel: %w", err)
	}

	row := database.conn.QueryRowContext(
		ctx, "SELECT id FROM channels WHERE name = ?", name,
	)

	var id int64
	if err := row.Scan(&id); err != nil {
		return 0, fmt.Errorf("get channel: %w", err)
	}

	return id, nil
}
// JoinChannel records the session as a member of the
// channel, stamping joined_at with the current time. The
// statement is INSERT OR IGNORE, so a conflicting row is
// skipped without an error.
func (database *Database) JoinChannel(
	ctx context.Context,
	channelID, sessionID int64,
) error {
	const insertMember = `INSERT OR IGNORE INTO channel_members
(channel_id, session_id, joined_at)
VALUES (?, ?, ?)`

	if _, err := database.conn.ExecContext(
		ctx, insertMember, channelID, sessionID, time.Now(),
	); err != nil {
		return fmt.Errorf("join channel: %w", err)
	}

	return nil
}
// PartChannel deletes the membership row linking the
// session to the channel. Deleting a membership that does
// not exist is not treated as an error.
func (database *Database) PartChannel(
	ctx context.Context,
	channelID, sessionID int64,
) error {
	const deleteMember = `DELETE FROM channel_members
WHERE channel_id = ? AND session_id = ?`

	if _, err := database.conn.ExecContext(
		ctx, deleteMember, channelID, sessionID,
	); err != nil {
		return fmt.Errorf("part channel: %w", err)
	}

	return nil
}
// DeleteChannelIfEmpty deletes the channel row only when
// no channel_members row references it. The emptiness
// check and the delete happen in one SQL statement; a
// channel that still has members is left untouched and no
// error is reported.
func (database *Database) DeleteChannelIfEmpty(
	ctx context.Context,
	channelID int64,
) error {
	const deleteEmpty = `DELETE FROM channels WHERE id = ?
AND NOT EXISTS
(SELECT 1 FROM channel_members
WHERE channel_id = ?)`

	if _, err := database.conn.ExecContext(
		ctx, deleteEmpty, channelID, channelID,
	); err != nil {
		return fmt.Errorf(
			"delete channel if empty: %w", err,
		)
	}

	return nil
}
// scanChannels drains rows into a ChannelInfo slice,
// expecting the columns id, name, topic in that order.
// It always closes rows. On success the slice is non-nil
// even when there were no rows.
func scanChannels(
	rows *sql.Rows,
) ([]ChannelInfo, error) {
	defer func() { _ = rows.Close() }()

	// Start from an empty, non-nil slice so that a result
	// with no rows is never nil.
	channels := []ChannelInfo{}

	for rows.Next() {
		var entry ChannelInfo

		if err := rows.Scan(
			&entry.ID, &entry.Name, &entry.Topic,
		); err != nil {
			return nil, fmt.Errorf("scan channel: %w", err)
		}

		channels = append(channels, entry)
	}

	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("rows error: %w", err)
	}

	return channels, nil
}
// ListChannels returns the channels the given session is
// a member of, ordered by channel name.
func (database *Database) ListChannels(
	ctx context.Context,
	sessionID int64,
) ([]ChannelInfo, error) {
	const joinedChannels = `SELECT c.id, c.name, c.topic
FROM channels c
INNER JOIN channel_members cm
ON cm.channel_id = c.id
WHERE cm.session_id = ?
ORDER BY c.name`

	rows, err := database.conn.QueryContext(
		ctx, joinedChannels, sessionID,
	)
	if err != nil {
		return nil, fmt.Errorf("list channels: %w", err)
	}

	// scanChannels takes ownership of rows and closes them.
	return scanChannels(rows)
}
// ListAllChannels returns every row of the channels table,
// ordered by name.
func (database *Database) ListAllChannels(
	ctx context.Context,
) ([]ChannelInfo, error) {
	const allChannels = `SELECT id, name, topic
FROM channels ORDER BY name`

	rows, err := database.conn.QueryContext(ctx, allChannels)
	if err != nil {
		return nil, fmt.Errorf(
			"list all channels: %w", err,
		)
	}

	// scanChannels takes ownership of rows and closes them.
	return scanChannels(rows)
}
// ChannelMembers returns the sessions that belong to the
// channel, ordered by nick. On success the slice is
// non-nil even when the channel has no members.
func (database *Database) ChannelMembers(
	ctx context.Context,
	channelID int64,
) ([]MemberInfo, error) {
	const memberQuery = `SELECT s.id, s.nick, s.last_seen
FROM sessions s
INNER JOIN channel_members cm
ON cm.session_id = s.id
WHERE cm.channel_id = ?
ORDER BY s.nick`

	rows, err := database.conn.QueryContext(
		ctx, memberQuery, channelID,
	)
	if err != nil {
		return nil, fmt.Errorf(
			"query channel members: %w", err,
		)
	}

	defer func() { _ = rows.Close() }()

	// Non-nil from the start, so an empty channel yields an
	// empty slice instead of nil.
	roster := []MemberInfo{}

	for rows.Next() {
		var entry MemberInfo

		if scanErr := rows.Scan(
			&entry.ID, &entry.Nick, &entry.LastSeen,
		); scanErr != nil {
			return nil, fmt.Errorf(
				"scan member: %w", scanErr,
			)
		}

		roster = append(roster, entry)
	}

	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("rows error: %w", iterErr)
	}

	return roster, nil
}
// IsChannelMember reports whether a channel_members row
// exists for the given channel and session pair.
func (database *Database) IsChannelMember(
	ctx context.Context,
	channelID, sessionID int64,
) (bool, error) {
	const countMembership = `SELECT COUNT(*) FROM channel_members
WHERE channel_id = ? AND session_id = ?`

	row := database.conn.QueryRowContext(
		ctx, countMembership, channelID, sessionID,
	)

	var matches int
	if err := row.Scan(&matches); err != nil {
		return false, fmt.Errorf(
			"check membership: %w", err,
		)
	}

	return matches > 0, nil
}
// scanInt64s drains single-column rows into an int64
// slice and always closes rows. The returned slice is nil
// when there were no rows.
func scanInt64s(rows *sql.Rows) ([]int64, error) {
	defer func() { _ = rows.Close() }()

	var collected []int64

	for rows.Next() {
		var id int64

		if err := rows.Scan(&id); err != nil {
			return nil, fmt.Errorf(
				"scan int64: %w", err,
			)
		}

		collected = append(collected, id)
	}

	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("rows error: %w", err)
	}

	return collected, nil
}
// GetChannelMemberIDs returns the session ids of all
// members of the channel, in no particular order.
func (database *Database) GetChannelMemberIDs(
	ctx context.Context,
	channelID int64,
) ([]int64, error) {
	const memberIDs = `SELECT session_id FROM channel_members
WHERE channel_id = ?`

	rows, err := database.conn.QueryContext(
		ctx, memberIDs, channelID,
	)
	if err != nil {
		return nil, fmt.Errorf(
			"get channel member ids: %w", err,
		)
	}

	// scanInt64s takes ownership of rows and closes them.
	return scanInt64s(rows)
}
// GetSessionChannelIDs returns the ids of all channels the
// session is a member of, in no particular order.
func (database *Database) GetSessionChannelIDs(
	ctx context.Context,
	sessionID int64,
) ([]int64, error) {
	const channelIDs = `SELECT channel_id FROM channel_members
WHERE session_id = ?`

	rows, err := database.conn.QueryContext(
		ctx, channelIDs, sessionID,
	)
	if err != nil {
		return nil, fmt.Errorf(
			"get session channel ids: %w", err,
		)
	}

	// scanInt64s takes ownership of rows and closes them.
	return scanInt64s(rows)
}
// InsertMessage stores a message and returns its database
// id together with the UUID generated for it.
//
// A nil params or body is stored as the JSON array "[]";
// a nil meta is stored as the JSON object "{}". Values
// that are non-nil are stored as given. created_at is the
// current time in UTC.
//
// If the insert fails, or the inserted row id cannot be
// read back, the returned id is 0 and the UUID is empty.
func (database *Database) InsertMessage(
	ctx context.Context,
	command, from, target string,
	params json.RawMessage,
	body json.RawMessage,
	meta json.RawMessage,
) (int64, string, error) {
	msgUUID := uuid.New().String()
	now := time.Now().UTC()

	if params == nil {
		params = json.RawMessage("[]")
	}

	if body == nil {
		body = json.RawMessage("[]")
	}

	if meta == nil {
		meta = json.RawMessage("{}")
	}

	res, err := database.conn.ExecContext(ctx,
		`INSERT INTO messages
(uuid, command, msg_from, msg_to,
params, body, meta, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
		msgUUID, command, from, target,
		string(params), string(body), string(meta), now)
	if err != nil {
		return 0, "", fmt.Errorf(
			"insert message: %w", err,
		)
	}

	// Callers enqueue by this id, so a failure to read it
	// must be reported rather than returned as id 0.
	dbID, err := res.LastInsertId()
	if err != nil {
		return 0, "", fmt.Errorf(
			"insert message: last insert id: %w", err,
		)
	}

	return dbID, msgUUID, nil
}
// EnqueueToSession queues the message for every client of
// the session: a single INSERT ... SELECT adds one
// client_queues row per matching clients row. The
// statement is INSERT OR IGNORE, so conflicting rows are
// skipped without an error.
func (database *Database) EnqueueToSession(
	ctx context.Context,
	sessionID, messageID int64,
) error {
	const fanOut = `INSERT OR IGNORE INTO client_queues
(client_id, message_id, created_at)
SELECT c.id, ?, ?
FROM clients c
WHERE c.session_id = ?`

	if _, err := database.conn.ExecContext(
		ctx, fanOut, messageID, time.Now(), sessionID,
	); err != nil {
		return fmt.Errorf(
			"enqueue to session: %w", err,
		)
	}

	return nil
}
// EnqueueToClient queues the message for one specific
// client. The statement is INSERT OR IGNORE, so a
// conflicting row is skipped without an error.
func (database *Database) EnqueueToClient(
	ctx context.Context,
	clientID, messageID int64,
) error {
	const enqueue = `INSERT OR IGNORE INTO client_queues
(client_id, message_id, created_at)
VALUES (?, ?, ?)`

	if _, err := database.conn.ExecContext(
		ctx, enqueue, clientID, messageID, time.Now(),
	); err != nil {
		return fmt.Errorf(
			"enqueue to client: %w", err,
		)
	}

	return nil
}
// PollMessages returns the client's queued messages whose
// queue id is greater than afterQueueID, oldest first, at
// most limit entries. A limit of zero or less is replaced
// by defaultPollLimit.
//
// The second result is the queue id of the last message
// returned, or afterQueueID when nothing was returned or
// an error occurred.
func (database *Database) PollMessages(
	ctx context.Context,
	clientID, afterQueueID int64,
	limit int,
) ([]IRCMessage, int64, error) {
	const pollQuery = `SELECT cq.id, m.uuid, m.command,
m.msg_from, m.msg_to,
m.params, m.body, m.meta, m.created_at
FROM client_queues cq
INNER JOIN messages m
ON m.id = cq.message_id
WHERE cq.client_id = ? AND cq.id > ?
ORDER BY cq.id ASC LIMIT ?`

	if limit < 1 {
		limit = defaultPollLimit
	}

	rows, err := database.conn.QueryContext(
		ctx, pollQuery, clientID, afterQueueID, limit,
	)
	if err != nil {
		return nil, afterQueueID, fmt.Errorf(
			"poll messages: %w", err,
		)
	}

	// scanMessages closes rows; afterQueueID is the cursor
	// it reports when no row is read.
	queued, cursor, err := scanMessages(rows, afterQueueID)
	if err != nil {
		return nil, afterQueueID, err
	}

	return queued, cursor, nil
}
// GetHistory returns stored messages addressed to target,
// in ascending id order. The underlying query selects the
// newest rows first (see queryHistory), so the result is
// reversed before it is returned.
//
// A limit of zero or less is replaced by defaultHistLimit.
// beforeID is passed through to queryHistory unchanged.
// On success the slice is never nil.
func (database *Database) GetHistory(
	ctx context.Context,
	target string,
	beforeID int64,
	limit int,
) ([]IRCMessage, error) {
	if limit < 1 {
		limit = defaultHistLimit
	}

	rows, err := database.queryHistory(
		ctx, target, beforeID, limit,
	)
	if err != nil {
		return nil, err
	}

	// The cursor value from scanMessages is irrelevant here.
	history, _, err := scanMessages(rows, 0)
	if err != nil {
		return nil, err
	}

	if history == nil {
		history = []IRCMessage{}
	}

	reverseMessages(history)

	return history, nil
}
// queryHistory runs the history query for target and
// returns the open rows; the caller must close them.
//
// Only rows whose command is 'PRIVMSG' are selected,
// newest first, capped at limit. When beforeID is greater
// than zero the result is further restricted to message
// ids below it.
func (database *Database) queryHistory(
	ctx context.Context,
	target string,
	beforeID int64,
	limit int,
) (*sql.Rows, error) {
	// Default form: no upper bound on the message id.
	query := `SELECT id, uuid, command, msg_from,
msg_to, params, body, meta, created_at
FROM messages
WHERE msg_to = ?
AND command = 'PRIVMSG'
ORDER BY id DESC LIMIT ?`
	args := []any{target, limit}

	if beforeID > 0 {
		query = `SELECT id, uuid, command, msg_from,
msg_to, params, body, meta, created_at
FROM messages
WHERE msg_to = ? AND id < ?
AND command = 'PRIVMSG'
ORDER BY id DESC LIMIT ?`
		args = []any{target, beforeID, limit}
	}

	rows, err := database.conn.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("query history: %w", err)
	}

	return rows, nil
}
// scanMessages drains rows into IRCMessage values and
// always closes rows.
//
// Each row must provide nine columns in this order: a
// numeric id, uuid, command, from, to, params, body, meta
// and created_at. The numeric id becomes the message's
// DBID and is also tracked as the cursor.
//
// It returns the messages (non-nil on success), the id of
// the last row read (fallbackQID when no row was read),
// and an error. On error the slice is nil and the cursor
// is fallbackQID.
func scanMessages(
	rows *sql.Rows,
	fallbackQID int64,
) ([]IRCMessage, int64, error) {
	defer func() { _ = rows.Close() }()

	// Non-nil from the start so an empty result is an empty
	// slice, not nil.
	out := []IRCMessage{}
	cursor := fallbackQID

	for rows.Next() {
		var (
			rowID                       int64
			entry                       IRCMessage
			rawParams, rawBody, rawMeta string
			stamp                       time.Time
		)

		if err := rows.Scan(
			&rowID, &entry.ID, &entry.Command,
			&entry.From, &entry.To,
			&rawParams, &rawBody, &rawMeta, &stamp,
		); err != nil {
			return nil, fallbackQID, fmt.Errorf(
				"scan message: %w", err,
			)
		}

		switch rawParams {
		case "", "[]":
			// Leave Params unset so omitempty drops it.
		default:
			entry.Params = json.RawMessage(rawParams)
		}

		entry.Body = json.RawMessage(rawBody)
		entry.Meta = json.RawMessage(rawMeta)
		entry.TS = stamp.Format(time.RFC3339Nano)
		entry.DBID = rowID
		cursor = rowID

		if isNumericCode(entry.Command) {
			// Atoi cannot fail here: isNumericCode has just
			// confirmed three ASCII digits.
			entry.Code, _ = strconv.Atoi(entry.Command)

			// Prefer the symbolic name when one is known.
			if name := irc.Name(entry.Code); name != "" {
				entry.Command = name
			}
		}

		out = append(out, entry)
	}

	if err := rows.Err(); err != nil {
		return nil, fallbackQID, fmt.Errorf(
			"rows error: %w", err,
		)
	}

	return out, cursor, nil
}
// reverseMessages reverses msgs in place by swapping the
// elements at mirrored positions. Empty and single-element
// slices are left unchanged.
func reverseMessages(msgs []IRCMessage) {
	last := len(msgs) - 1

	for front := 0; front < len(msgs)/2; front++ {
		back := last - front
		msgs[front], msgs[back] = msgs[back], msgs[front]
	}
}
// ChangeNick sets the nick column of the given session to
// newNick. An id that matches no row is not reported as an
// error.
func (database *Database) ChangeNick(
	ctx context.Context,
	sessionID int64,
	newNick string,
) error {
	if _, err := database.conn.ExecContext(
		ctx,
		"UPDATE sessions SET nick = ? WHERE id = ?",
		newNick, sessionID,
	); err != nil {
		return fmt.Errorf("change nick: %w", err)
	}

	return nil
}
// SetTopic stores a new topic for the channel identified
// by name and stamps updated_at with the current time. A
// name that matches no row is not reported as an error.
func (database *Database) SetTopic(
	ctx context.Context,
	channelName, topic string,
) error {
	const updateTopic = `UPDATE channels SET topic = ?,
updated_at = ? WHERE name = ?`

	if _, err := database.conn.ExecContext(
		ctx, updateTopic, topic, time.Now(), channelName,
	); err != nil {
		return fmt.Errorf("set topic: %w", err)
	}

	return nil
}
// DeleteSession deletes the sessions row with the given
// id. NOTE(review): dependent rows (clients, memberships,
// queues) are removed only if the schema cascades the
// delete; that is not visible here.
func (database *Database) DeleteSession(
	ctx context.Context,
	sessionID int64,
) error {
	if _, err := database.conn.ExecContext(
		ctx, "DELETE FROM sessions WHERE id = ?", sessionID,
	); err != nil {
		return fmt.Errorf("delete session: %w", err)
	}

	return nil
}
// DeleteClient deletes the clients row with the given id.
// An id that matches no row is not reported as an error.
func (database *Database) DeleteClient(
	ctx context.Context,
	clientID int64,
) error {
	if _, err := database.conn.ExecContext(
		ctx, "DELETE FROM clients WHERE id = ?", clientID,
	); err != nil {
		return fmt.Errorf("delete client: %w", err)
	}

	return nil
}
// GetUserCount returns the number of rows in the sessions
// table; every stored session counts as one user.
func (database *Database) GetUserCount(
	ctx context.Context,
) (int64, error) {
	row := database.conn.QueryRowContext(
		ctx, "SELECT COUNT(*) FROM sessions",
	)

	var total int64
	if err := row.Scan(&total); err != nil {
		return 0, fmt.Errorf(
			"get user count: %w", err,
		)
	}

	return total, nil
}
// ClientCountForSession returns how many clients rows
// reference the given session.
func (database *Database) ClientCountForSession(
	ctx context.Context,
	sessionID int64,
) (int64, error) {
	const countClients = `SELECT COUNT(*) FROM clients
WHERE session_id = ?`

	row := database.conn.QueryRowContext(
		ctx, countClients, sessionID,
	)

	var total int64
	if err := row.Scan(&total); err != nil {
		return 0, fmt.Errorf(
			"client count for session: %w", err,
		)
	}

	return total, nil
}
// DeleteStaleUsers deletes every client whose last_seen is
// before cutoff, then deletes every session that no client
// references any more.
//
// It returns the number of client rows deleted by the
// first statement. If the second statement fails, that
// count is still returned alongside the error. The two
// statements run separately, not inside a transaction.
func (database *Database) DeleteStaleUsers(
	ctx context.Context,
	cutoff time.Time,
) (int64, error) {
	clientRes, err := database.conn.ExecContext(
		ctx, "DELETE FROM clients WHERE last_seen < ?", cutoff,
	)
	if err != nil {
		return 0, fmt.Errorf(
			"delete stale clients: %w", err,
		)
	}

	// The count is informational; an error reading it is
	// discarded.
	removed, _ := clientRes.RowsAffected()

	const pruneSessions = `DELETE FROM sessions WHERE id NOT IN
(SELECT DISTINCT session_id FROM clients)`

	if _, err = database.conn.ExecContext(
		ctx, pruneSessions,
	); err != nil {
		return removed, fmt.Errorf(
			"delete orphan sessions: %w", err,
		)
	}

	return removed, nil
}
// StaleSession holds the id and nick of a session
// whose clients are all stale. It is the row type
// returned by GetStaleOrphanSessions.
type StaleSession struct {
// ID is the session's id column.
ID int64
// Nick is the session's nick column.
Nick string
}
// GetStaleOrphanSessions returns the sessions that have at
// least one client last seen before cutoff and no client
// last seen at or after cutoff. Sessions without any
// client at all are not matched by this query.
//
// The returned slice is nil when no session qualifies.
func (database *Database) GetStaleOrphanSessions(
	ctx context.Context,
	cutoff time.Time,
) ([]StaleSession, error) {
	const staleQuery = `SELECT s.id, s.nick
FROM sessions s
WHERE s.id IN (
SELECT DISTINCT session_id FROM clients
WHERE last_seen < ?
)
AND s.id NOT IN (
SELECT DISTINCT session_id FROM clients
WHERE last_seen >= ?
)`

	rows, err := database.conn.QueryContext(
		ctx, staleQuery, cutoff, cutoff,
	)
	if err != nil {
		return nil, fmt.Errorf(
			"get stale orphan sessions: %w", err,
		)
	}

	defer func() { _ = rows.Close() }()

	var stale []StaleSession

	for rows.Next() {
		var entry StaleSession

		err = rows.Scan(&entry.ID, &entry.Nick)
		if err != nil {
			return nil, fmt.Errorf(
				"scan stale session: %w", err,
			)
		}

		stale = append(stale, entry)
	}

	err = rows.Err()
	if err != nil {
		return nil, fmt.Errorf(
			"iterate stale sessions: %w", err,
		)
	}

	return stale, nil
}
// GetSessionChannels returns the channels the session
// belongs to. The query has no ORDER BY clause, so the
// result order is unspecified.
func (database *Database) GetSessionChannels(
	ctx context.Context,
	sessionID int64,
) ([]ChannelInfo, error) {
	const sessionChannels = `SELECT c.id, c.name, c.topic
FROM channels c
INNER JOIN channel_members cm
ON cm.channel_id = c.id
WHERE cm.session_id = ?`

	rows, err := database.conn.QueryContext(
		ctx, sessionChannels, sessionID,
	)
	if err != nil {
		return nil, fmt.Errorf(
			"get session channels: %w", err,
		)
	}

	// scanChannels takes ownership of rows and closes them.
	return scanChannels(rows)
}
// GetChannelCount returns the number of rows in the
// channels table.
func (database *Database) GetChannelCount(
	ctx context.Context,
) (int64, error) {
	row := database.conn.QueryRowContext(
		ctx, "SELECT COUNT(*) FROM channels",
	)

	var total int64
	if err := row.Scan(&total); err != nil {
		return 0, fmt.Errorf(
			"get channel count: %w", err,
		)
	}

	return total, nil
}
// ChannelInfoFull contains extended channel information:
// the same fields as ChannelInfo plus a member count. It
// is the row type returned by ListAllChannelsWithCounts.
type ChannelInfoFull struct {
// ID is the channel's numeric database id.
ID int64 `json:"id"`
// Name is the channel name.
Name string `json:"name"`
// Topic is the channel's current topic.
Topic string `json:"topic"`
// MemberCount is the number of channel_members rows for
// the channel; it is 0 for a channel with no members.
MemberCount int64 `json:"memberCount"`
}
// ListAllChannelsWithCounts returns every channel together
// with its member count, ordered by name. Channels without
// members are included with a count of zero because the
// query uses a LEFT JOIN. On success the slice is non-nil
// even when there are no channels.
func (database *Database) ListAllChannelsWithCounts(
	ctx context.Context,
) ([]ChannelInfoFull, error) {
	const countedChannels = `SELECT c.id, c.name, c.topic,
COUNT(cm.session_id) AS member_count
FROM channels c
LEFT JOIN channel_members cm
ON cm.channel_id = c.id
GROUP BY c.id
ORDER BY c.name`

	rows, err := database.conn.QueryContext(ctx, countedChannels)
	if err != nil {
		return nil, fmt.Errorf(
			"list channels with counts: %w", err,
		)
	}

	defer func() { _ = rows.Close() }()

	// Non-nil from the start, so no channels yields an
	// empty slice instead of nil.
	listing := []ChannelInfoFull{}

	for rows.Next() {
		var entry ChannelInfoFull

		if scanErr := rows.Scan(
			&entry.ID, &entry.Name,
			&entry.Topic, &entry.MemberCount,
		); scanErr != nil {
			return nil, fmt.Errorf(
				"scan channel full: %w", scanErr,
			)
		}

		listing = append(listing, entry)
	}

	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("rows error: %w", iterErr)
	}

	return listing, nil
}
// GetChannelCreatedAt returns the created_at timestamp of
// the channel with the given id. When no row matches, the
// Scan error is wrapped and returned with a zero
// time.Time.
func (database *Database) GetChannelCreatedAt(
	ctx context.Context,
	channelID int64,
) (time.Time, error) {
	row := database.conn.QueryRowContext(
		ctx,
		"SELECT created_at FROM channels WHERE id = ?",
		channelID,
	)

	var created time.Time
	if err := row.Scan(&created); err != nil {
		return time.Time{}, fmt.Errorf(
			"get channel created_at: %w", err,
		)
	}

	return created, nil
}
// GetSessionCreatedAt returns the created_at timestamp of
// the session with the given id. When no row matches, the
// Scan error is wrapped and returned with a zero
// time.Time.
func (database *Database) GetSessionCreatedAt(
	ctx context.Context,
	sessionID int64,
) (time.Time, error) {
	row := database.conn.QueryRowContext(
		ctx,
		"SELECT created_at FROM sessions WHERE id = ?",
		sessionID,
	)

	var created time.Time
	if err := row.Scan(&created); err != nil {
		return time.Time{}, fmt.Errorf(
			"get session created_at: %w", err,
		)
	}

	return created, nil
}