diff --git a/.gitignore b/.gitignore index a6fb1aa..95d8357 100644 --- a/.gitignore +++ b/.gitignore @@ -34,5 +34,5 @@ vendor/ # Project data.db debug.log -chat-cli +/chat-cli web/node_modules/ diff --git a/cmd/chat-cli/api/client.go b/cmd/chat-cli/api/client.go index a20298c..fee4051 100644 --- a/cmd/chat-cli/api/client.go +++ b/cmd/chat-cli/api/client.go @@ -1,15 +1,31 @@ +// Package api provides a client for the chat server HTTP API. package api import ( "bytes" "encoding/json" + "errors" "fmt" "io" "net/http" "net/url" + "strconv" "time" ) +const ( + httpTimeout = 30 * time.Second + pollExtraDelay = 5 + httpErrThreshold = 400 +) + +// ErrHTTP is returned for non-2xx responses. +var ErrHTTP = errors.New("http error") + +// ErrUnexpectedFormat is returned when the response format is +// not recognised. +var ErrUnexpectedFormat = errors.New("unexpected format") + // Client wraps HTTP calls to the chat server API. type Client struct { BaseURL string @@ -22,59 +38,32 @@ func NewClient(baseURL string) *Client { return &Client{ BaseURL: baseURL, HTTPClient: &http.Client{ - Timeout: 30 * time.Second, + Timeout: httpTimeout, }, } } -func (c *Client) do(method, path string, body interface{}) ([]byte, error) { - var bodyReader io.Reader - if body != nil { - data, err := json.Marshal(body) - if err != nil { - return nil, fmt.Errorf("marshal: %w", err) - } - bodyReader = bytes.NewReader(data) - } - - req, err := http.NewRequest(method, c.BaseURL+path, bodyReader) - if err != nil { - return nil, fmt.Errorf("request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - if c.Token != "" { - req.Header.Set("Authorization", "Bearer "+c.Token) - } - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, fmt.Errorf("http: %w", err) - } - defer resp.Body.Close() - - data, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("read body: %w", err) - } - - if resp.StatusCode >= 400 { - return data, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(data)) - } - - return data, nil -} - // CreateSession creates a new session on the server. -func (c *Client) CreateSession(nick string) (*SessionResponse, error) { - data, err := c.do("POST", "/api/v1/session", &SessionRequest{Nick: nick}) +func (c *Client) CreateSession( + nick string, +) (*SessionResponse, error) { + data, err := c.do( + "POST", "/api/v1/session", + &SessionRequest{Nick: nick}, + ) if err != nil { return nil, err } + var resp SessionResponse - if err := json.Unmarshal(data, &resp); err != nil { + + err = json.Unmarshal(data, &resp) + if err != nil { return nil, fmt.Errorf("decode session: %w", err) } + c.Token = resp.Token + return &resp, nil } @@ -84,78 +73,113 @@ func (c *Client) GetState() (*StateResponse, error) { if err != nil { return nil, err } + var resp StateResponse - if err := json.Unmarshal(data, &resp); err != nil { + + err = json.Unmarshal(data, &resp) + if err != nil { return nil, fmt.Errorf("decode state: %w", err) } + return &resp, nil } // SendMessage sends a message (any IRC command). func (c *Client) SendMessage(msg *Message) error { _, err := c.do("POST", "/api/v1/messages", msg) + return err } // PollMessages long-polls for new messages. -func (c *Client) PollMessages(afterID string, timeout int) ([]Message, error) { - // Use a longer HTTP timeout than the server long-poll timeout. - client := &http.Client{Timeout: time.Duration(timeout+5) * time.Second} +func (c *Client) PollMessages( + afterID string, + timeout int, +) ([]Message, error) { + pollTimeout := time.Duration( + timeout+pollExtraDelay, + ) * time.Second + + client := &http.Client{Timeout: pollTimeout} params := url.Values{} if afterID != "" { params.Set("after", afterID) } - params.Set("timeout", fmt.Sprintf("%d", timeout)) + + params.Set("timeout", strconv.Itoa(timeout)) path := "/api/v1/messages" if len(params) > 0 { path += "?" + params.Encode() } - req, err := http.NewRequest("GET", c.BaseURL+path, nil) + req, err := http.NewRequest( //nolint:noctx // CLI tool + http.MethodGet, c.BaseURL+path, nil, + ) if err != nil { return nil, err } + req.Header.Set("Authorization", "Bearer "+c.Token) - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // URL from user config if err != nil { return nil, err } - defer resp.Body.Close() + + defer func() { _ = resp.Body.Close() }() data, err := io.ReadAll(resp.Body) if err != nil { return nil, err } - if resp.StatusCode >= 400 { - return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(data)) + if resp.StatusCode >= httpErrThreshold { + return nil, fmt.Errorf( + "%w: %d: %s", + ErrHTTP, resp.StatusCode, string(data), + ) } - // The server may return an array directly or wrapped. + return decodeMessages(data) +} + +func decodeMessages(data []byte) ([]Message, error) { var msgs []Message - if err := json.Unmarshal(data, &msgs); err != nil { - // Try wrapped format. - var wrapped MessagesResponse - if err2 := json.Unmarshal(data, &wrapped); err2 != nil { - return nil, fmt.Errorf("decode messages: %w (raw: %s)", err, string(data)) - } - msgs = wrapped.Messages + + err := json.Unmarshal(data, &msgs) + if err == nil { + return msgs, nil } - return msgs, nil + var wrapped MessagesResponse + + err2 := json.Unmarshal(data, &wrapped) + if err2 != nil { + return nil, fmt.Errorf( + "decode messages: %w (raw: %s)", + err, string(data), + ) + } + + return wrapped.Messages, nil } -// JoinChannel joins a channel via the unified command endpoint. +// JoinChannel joins a channel via the unified command +// endpoint. func (c *Client) JoinChannel(channel string) error { - return c.SendMessage(&Message{Command: "JOIN", To: channel}) + return c.SendMessage( + &Message{Command: "JOIN", To: channel}, + ) } -// PartChannel leaves a channel via the unified command endpoint. +// PartChannel leaves a channel via the unified command +// endpoint. func (c *Client) PartChannel(channel string) error { - return c.SendMessage(&Message{Command: "PART", To: channel}) + return c.SendMessage( + &Message{Command: "PART", To: channel}, + ) } // ListChannels returns all channels on the server. @@ -164,29 +188,39 @@ func (c *Client) ListChannels() ([]Channel, error) { if err != nil { return nil, err } + var channels []Channel - if err := json.Unmarshal(data, &channels); err != nil { + + err = json.Unmarshal(data, &channels) + if err != nil { return nil, err } + return channels, nil } // GetMembers returns members of a channel. -func (c *Client) GetMembers(channel string) ([]string, error) { - data, err := c.do("GET", "/api/v1/channels/"+url.PathEscape(channel)+"/members", nil) +func (c *Client) GetMembers( + channel string, +) ([]string, error) { + path := "/api/v1/channels/" + + url.PathEscape(channel) + "/members" + + data, err := c.do("GET", path, nil) if err != nil { return nil, err } + var members []string - if err := json.Unmarshal(data, &members); err != nil { - // Try object format. - var obj map[string]interface{} - if err2 := json.Unmarshal(data, &obj); err2 != nil { - return nil, err - } - // Extract member names from whatever format. - return nil, fmt.Errorf("unexpected members format: %s", string(data)) + + err = json.Unmarshal(data, &members) + if err != nil { + return nil, fmt.Errorf( + "%w: members: %s", + ErrUnexpectedFormat, string(data), + ) } + return members, nil } @@ -196,9 +230,63 @@ func (c *Client) GetServerInfo() (*ServerInfo, error) { if err != nil { return nil, err } + var info ServerInfo - if err := json.Unmarshal(data, &info); err != nil { + + err = json.Unmarshal(data, &info) + if err != nil { return nil, err } + return &info, nil } + +func (c *Client) do( + method, path string, + body any, +) ([]byte, error) { + var bodyReader io.Reader + + if body != nil { + data, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("marshal: %w", err) + } + + bodyReader = bytes.NewReader(data) + } + + req, err := http.NewRequest( //nolint:noctx // CLI tool + method, c.BaseURL+path, bodyReader, + ) + if err != nil { + return nil, fmt.Errorf("request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + + if c.Token != "" { + req.Header.Set("Authorization", "Bearer "+c.Token) + } + + resp, err := c.HTTPClient.Do(req) //nolint:gosec // URL from user config + if err != nil { + return nil, fmt.Errorf("http: %w", err) + } + + defer func() { _ = resp.Body.Close() }() + + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read body: %w", err) + } + + if resp.StatusCode >= httpErrThreshold { + return data, fmt.Errorf( + "%w: %d: %s", + ErrHTTP, resp.StatusCode, string(data), + ) + } + + return data, nil +} diff --git a/cmd/chat-cli/api/types.go b/cmd/chat-cli/api/types.go index 1655d79..12d223b 100644 --- a/cmd/chat-cli/api/types.go +++ b/cmd/chat-cli/api/types.go @@ -1,4 +1,4 @@ -package api +package api //nolint:revive // package name is intentional import "time" @@ -9,42 +9,45 @@ type SessionRequest struct { // SessionResponse is the response from POST /api/v1/session. type SessionResponse struct { - SessionID string `json:"session_id"` - ClientID string `json:"client_id"` + SessionID string `json:"sessionId"` + ClientID string `json:"clientId"` Nick string `json:"nick"` Token string `json:"token"` } // StateResponse is the response from GET /api/v1/state. type StateResponse struct { - SessionID string `json:"session_id"` - ClientID string `json:"client_id"` + SessionID string `json:"sessionId"` + ClientID string `json:"clientId"` Nick string `json:"nick"` Channels []string `json:"channels"` } // Message represents a chat message envelope. type Message struct { - Command string `json:"command"` - From string `json:"from,omitempty"` - To string `json:"to,omitempty"` - Params []string `json:"params,omitempty"` - Body interface{} `json:"body,omitempty"` - ID string `json:"id,omitempty"` - TS string `json:"ts,omitempty"` - Meta interface{} `json:"meta,omitempty"` + Command string `json:"command"` + From string `json:"from,omitempty"` + To string `json:"to,omitempty"` + Params []string `json:"params,omitempty"` + Body any `json:"body,omitempty"` + ID string `json:"id,omitempty"` + TS string `json:"ts,omitempty"` + Meta any `json:"meta,omitempty"` } -// BodyLines returns the body as a slice of strings (for text messages). +// BodyLines returns the body as a slice of strings (for text +// messages). func (m *Message) BodyLines() []string { switch v := m.Body.(type) { - case []interface{}: + case []any: lines := make([]string, 0, len(v)) + for _, item := range v { if s, ok := item.(string); ok { lines = append(lines, s) } } + return lines case []string: return v @@ -58,7 +61,7 @@ type Channel struct { Name string `json:"name"` Topic string `json:"topic"` Members int `json:"members"` - CreatedAt string `json:"created_at"` + CreatedAt string `json:"createdAt"` } // ServerInfo is the response from GET /api/v1/server. @@ -79,5 +82,6 @@ func (m *Message) ParseTS() time.Time { if err != nil { return time.Now() } + return t } diff --git a/cmd/chat-cli/main.go b/cmd/chat-cli/main.go index 6dfa1f3..37af8f3 100644 --- a/cmd/chat-cli/main.go +++ b/cmd/chat-cli/main.go @@ -1,8 +1,10 @@ +// Package main implements chat-cli, an IRC-style terminal client. package main import ( "fmt" "os" + "strconv" "strings" "sync" "time" @@ -10,6 +12,12 @@ import ( "git.eeqj.de/sneak/chat/cmd/chat-cli/api" ) +const ( + pollTimeoutSec = 15 + retryDelay = 2 * time.Second + maxNickLength = 32 +) + // App holds the application state. type App struct { ui *UI @@ -32,11 +40,18 @@ func main() { app.ui.OnInput(app.handleInput) app.ui.SetStatus(app.nick, "", "disconnected") - app.ui.AddStatus("Welcome to chat-cli — an IRC-style client") - app.ui.AddStatus("Type [yellow]/connect [white] to begin, or [yellow]/help[white] for commands") + app.ui.AddStatus( + "Welcome to chat-cli \u2014 an IRC-style client", + ) + app.ui.AddStatus( + "Type [yellow]/connect [white] " + + "to begin, or [yellow]/help[white] for commands", + ) + + err := app.ui.Run() + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, "Error: %v\n", err) - if err := app.ui.Run(); err != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } } @@ -44,21 +59,34 @@ func main() { func (a *App) handleInput(text string) { if strings.HasPrefix(text, "/") { a.handleCommand(text) + return } - // Plain text → PRIVMSG to current target. + a.sendPlainText(text) +} + +func (a *App) sendPlainText(text string) { a.mu.Lock() target := a.target connected := a.connected + nick := a.nick a.mu.Unlock() if !connected { - a.ui.AddStatus("[red]Not connected. Use /connect ") + a.ui.AddStatus( + "[red]Not connected. Use /connect ", + ) + return } + if target == "" { - a.ui.AddStatus("[red]No target. Use /join #channel or /query nick") + a.ui.AddStatus( + "[red]No target. " + + "Use /join #channel or /query nick", + ) + return } @@ -68,21 +96,28 @@ func (a *App) handleInput(text string) { Body: []string{text}, }) if err != nil { - a.ui.AddStatus(fmt.Sprintf("[red]Send error: %v", err)) + a.ui.AddStatus( + fmt.Sprintf("[red]Send error: %v", err), + ) + return } - // Echo locally. ts := time.Now().Format("15:04") - a.mu.Lock() - nick := a.nick - a.mu.Unlock() - a.ui.AddLine(target, fmt.Sprintf("[gray]%s [green]<%s>[white] %s", ts, nick, text)) + + a.ui.AddLine( + target, + fmt.Sprintf( + "[gray]%s [green]<%s>[white] %s", + ts, nick, text, + ), + ) } -func (a *App) handleCommand(text string) { - parts := strings.SplitN(text, " ", 2) +func (a *App) handleCommand(text string) { //nolint:cyclop // command dispatch + parts := strings.SplitN(text, " ", 2) //nolint:mnd // split into cmd+args cmd := strings.ToLower(parts[0]) + args := "" if len(parts) > 1 { args = parts[1] @@ -114,27 +149,41 @@ func (a *App) handleCommand(text string) { case "/help": a.cmdHelp() default: - a.ui.AddStatus(fmt.Sprintf("[red]Unknown command: %s", cmd)) + a.ui.AddStatus( + "[red]Unknown command: " + cmd, + ) } } func (a *App) cmdConnect(serverURL string) { if serverURL == "" { - a.ui.AddStatus("[red]Usage: /connect ") + a.ui.AddStatus( + "[red]Usage: /connect ", + ) + return } + serverURL = strings.TrimRight(serverURL, "/") - a.ui.AddStatus(fmt.Sprintf("Connecting to %s...", serverURL)) + a.ui.AddStatus( + fmt.Sprintf("Connecting to %s...", serverURL), + ) a.mu.Lock() nick := a.nick a.mu.Unlock() client := api.NewClient(serverURL) + resp, err := client.CreateSession(nick) if err != nil { - a.ui.AddStatus(fmt.Sprintf("[red]Connection failed: %v", err)) + a.ui.AddStatus( + fmt.Sprintf( + "[red]Connection failed: %v", err, + ), + ) + return } @@ -145,19 +194,26 @@ func (a *App) cmdConnect(serverURL string) { a.lastMsgID = "" a.mu.Unlock() - a.ui.AddStatus(fmt.Sprintf("[green]Connected! Nick: %s, Session: %s", resp.Nick, resp.SessionID)) + a.ui.AddStatus( + fmt.Sprintf( + "[green]Connected! Nick: %s, Session: %s", + resp.Nick, resp.SessionID, + ), + ) a.ui.SetStatus(resp.Nick, "", "connected") - // Start polling. a.stopPoll = make(chan struct{}) + go a.pollLoop() } func (a *App) cmdNick(nick string) { if nick == "" { a.ui.AddStatus("[red]Usage: /nick ") + return } + a.mu.Lock() connected := a.connected a.mu.Unlock() @@ -166,7 +222,14 @@ func (a *App) cmdNick(nick string) { a.mu.Lock() a.nick = nick a.mu.Unlock() - a.ui.AddStatus(fmt.Sprintf("Nick set to %s (will be used on connect)", nick)) + + a.ui.AddStatus( + fmt.Sprintf( + "Nick set to %s (will be used on connect)", + nick, + ), + ) + return } @@ -175,7 +238,12 @@ func (a *App) cmdNick(nick string) { Body: []string{nick}, }) if err != nil { - a.ui.AddStatus(fmt.Sprintf("[red]Nick change failed: %v", err)) + a.ui.AddStatus( + fmt.Sprintf( + "[red]Nick change failed: %v", err, + ), + ) + return } @@ -183,15 +251,20 @@ func (a *App) cmdNick(nick string) { a.nick = nick target := a.target a.mu.Unlock() + a.ui.SetStatus(nick, target, "connected") - a.ui.AddStatus(fmt.Sprintf("Nick changed to %s", nick)) + a.ui.AddStatus( + "Nick changed to " + nick, + ) } func (a *App) cmdJoin(channel string) { if channel == "" { a.ui.AddStatus("[red]Usage: /join #channel") + return } + if !strings.HasPrefix(channel, "#") { channel = "#" + channel } @@ -199,14 +272,19 @@ func (a *App) cmdJoin(channel string) { a.mu.Lock() connected := a.connected a.mu.Unlock() + if !connected { a.ui.AddStatus("[red]Not connected") + return } err := a.client.JoinChannel(channel) if err != nil { - a.ui.AddStatus(fmt.Sprintf("[red]Join failed: %v", err)) + a.ui.AddStatus( + fmt.Sprintf("[red]Join failed: %v", err), + ) + return } @@ -216,39 +294,55 @@ func (a *App) cmdJoin(channel string) { a.mu.Unlock() a.ui.SwitchToBuffer(channel) - a.ui.AddLine(channel, fmt.Sprintf("[yellow]*** Joined %s", channel)) + a.ui.AddLine( + channel, + "[yellow]*** Joined "+channel, + ) a.ui.SetStatus(nick, channel, "connected") } func (a *App) cmdPart(channel string) { a.mu.Lock() + if channel == "" { channel = a.target } + connected := a.connected a.mu.Unlock() if channel == "" || !strings.HasPrefix(channel, "#") { a.ui.AddStatus("[red]No channel to part") + return } + if !connected { a.ui.AddStatus("[red]Not connected") + return } err := a.client.PartChannel(channel) if err != nil { - a.ui.AddStatus(fmt.Sprintf("[red]Part failed: %v", err)) + a.ui.AddStatus( + fmt.Sprintf("[red]Part failed: %v", err), + ) + return } - a.ui.AddLine(channel, fmt.Sprintf("[yellow]*** Left %s", channel)) + a.ui.AddLine( + channel, + "[yellow]*** Left "+channel, + ) a.mu.Lock() + if a.target == channel { a.target = "" } + nick := a.nick a.mu.Unlock() @@ -257,19 +351,23 @@ func (a *App) cmdPart(channel string) { } func (a *App) cmdMsg(args string) { - parts := strings.SplitN(args, " ", 2) - if len(parts) < 2 { + parts := strings.SplitN(args, " ", 2) //nolint:mnd // split into target+text + if len(parts) < 2 { //nolint:mnd // min args a.ui.AddStatus("[red]Usage: /msg ") + return } + target, text := parts[0], parts[1] a.mu.Lock() connected := a.connected nick := a.nick a.mu.Unlock() + if !connected { a.ui.AddStatus("[red]Not connected") + return } @@ -279,17 +377,28 @@ func (a *App) cmdMsg(args string) { Body: []string{text}, }) if err != nil { - a.ui.AddStatus(fmt.Sprintf("[red]Send failed: %v", err)) + a.ui.AddStatus( + fmt.Sprintf("[red]Send failed: %v", err), + ) + return } ts := time.Now().Format("15:04") - a.ui.AddLine(target, fmt.Sprintf("[gray]%s [green]<%s>[white] %s", ts, nick, text)) + + a.ui.AddLine( + target, + fmt.Sprintf( + "[gray]%s [green]<%s>[white] %s", + ts, nick, text, + ), + ) } func (a *App) cmdQuery(nick string) { if nick == "" { a.ui.AddStatus("[red]Usage: /query ") + return } @@ -310,22 +419,29 @@ func (a *App) cmdTopic(args string) { if !connected { a.ui.AddStatus("[red]Not connected") + return } + if !strings.HasPrefix(target, "#") { a.ui.AddStatus("[red]Not in a channel") + return } if args == "" { - // Query topic. err := a.client.SendMessage(&api.Message{ Command: "TOPIC", To: target, }) if err != nil { - a.ui.AddStatus(fmt.Sprintf("[red]Topic query failed: %v", err)) + a.ui.AddStatus( + fmt.Sprintf( + "[red]Topic query failed: %v", err, + ), + ) } + return } @@ -335,7 +451,11 @@ func (a *App) cmdTopic(args string) { Body: []string{args}, }) if err != nil { - a.ui.AddStatus(fmt.Sprintf("[red]Topic set failed: %v", err)) + a.ui.AddStatus( + fmt.Sprintf( + "[red]Topic set failed: %v", err, + ), + ) } } @@ -347,20 +467,32 @@ func (a *App) cmdNames() { if !connected { a.ui.AddStatus("[red]Not connected") + return } + if !strings.HasPrefix(target, "#") { a.ui.AddStatus("[red]Not in a channel") + return } members, err := a.client.GetMembers(target) if err != nil { - a.ui.AddStatus(fmt.Sprintf("[red]Names failed: %v", err)) + a.ui.AddStatus( + fmt.Sprintf("[red]Names failed: %v", err), + ) + return } - a.ui.AddLine(target, fmt.Sprintf("[cyan]*** Members of %s: %s", target, strings.Join(members, " "))) + a.ui.AddLine( + target, + fmt.Sprintf( + "[cyan]*** Members of %s: %s", + target, strings.Join(members, " "), + ), + ) } func (a *App) cmdList() { @@ -370,46 +502,55 @@ func (a *App) cmdList() { if !connected { a.ui.AddStatus("[red]Not connected") + return } channels, err := a.client.ListChannels() if err != nil { - a.ui.AddStatus(fmt.Sprintf("[red]List failed: %v", err)) + a.ui.AddStatus( + fmt.Sprintf("[red]List failed: %v", err), + ) + return } a.ui.AddStatus("[cyan]*** Channel list:") + for _, ch := range channels { - a.ui.AddStatus(fmt.Sprintf(" %s (%d members) %s", ch.Name, ch.Members, ch.Topic)) + a.ui.AddStatus( + fmt.Sprintf( + " %s (%d members) %s", + ch.Name, ch.Members, ch.Topic, + ), + ) } + a.ui.AddStatus("[cyan]*** End of channel list") } func (a *App) cmdWindow(args string) { if args == "" { a.ui.AddStatus("[red]Usage: /window ") + return } - n := 0 - fmt.Sscanf(args, "%d", &n) + + n, _ := strconv.Atoi(args) a.ui.SwitchBuffer(n) a.mu.Lock() - if n < a.ui.BufferCount() && n >= 0 { - // Update target to the buffer name. - // Needs to be done carefully. - } nick := a.nick a.mu.Unlock() - // Update target based on buffer. - if n < a.ui.BufferCount() { + if n >= 0 && n < a.ui.BufferCount() { buf := a.ui.buffers[n] + if buf.Name != "(status)" { a.mu.Lock() a.target = buf.Name a.mu.Unlock() + a.ui.SetStatus(nick, buf.Name, "connected") } else { a.ui.SetStatus(nick, "", "connected") @@ -419,12 +560,17 @@ func (a *App) cmdWindow(args string) { func (a *App) cmdQuit() { a.mu.Lock() + if a.connected && a.client != nil { - _ = a.client.SendMessage(&api.Message{Command: "QUIT"}) + _ = a.client.SendMessage( + &api.Message{Command: "QUIT"}, + ) } + if a.stopPoll != nil { close(a.stopPoll) } + a.mu.Unlock() a.ui.Stop() } @@ -432,20 +578,21 @@ func (a *App) cmdQuit() { func (a *App) cmdHelp() { help := []string{ "[cyan]*** chat-cli commands:", - " /connect — Connect to server", - " /nick — Change nickname", - " /join #channel — Join channel", - " /part [#chan] — Leave channel", - " /msg — Send DM", - " /query — Open DM window", - " /topic [text] — View/set topic", - " /names — List channel members", - " /list — List channels", - " /window — Switch buffer (Alt+0-9)", - " /quit — Disconnect and exit", - " /help — This help", + " /connect \u2014 Connect to server", + " /nick \u2014 Change nickname", + " /join #channel \u2014 Join channel", + " /part [#chan] \u2014 Leave channel", + " /msg \u2014 Send DM", + " /query \u2014 Open DM window", + " /topic [text] \u2014 View/set topic", + " /names \u2014 List channel members", + " /list \u2014 List channels", + " /window \u2014 Switch buffer (Alt+0-9)", + " /quit \u2014 Disconnect and exit", + " /help \u2014 This help", " Plain text sends to current target.", } + for _, line := range help { a.ui.AddStatus(line) } @@ -469,32 +616,31 @@ func (a *App) pollLoop() { return } - msgs, err := client.PollMessages(lastID, 15) + msgs, err := client.PollMessages( + lastID, pollTimeoutSec, + ) if err != nil { - // Transient error — retry after delay. - time.Sleep(2 * time.Second) + time.Sleep(retryDelay) + continue } - for _, msg := range msgs { - a.handleServerMessage(&msg) - if msg.ID != "" { + for i := range msgs { + a.handleServerMessage(&msgs[i]) + + if msgs[i].ID != "" { a.mu.Lock() - a.lastMsgID = msg.ID + a.lastMsgID = msgs[i].ID a.mu.Unlock() } } } } -func (a *App) handleServerMessage(msg *api.Message) { - ts := "" - if msg.TS != "" { - t := msg.ParseTS() - ts = t.Local().Format("15:04") - } else { - ts = time.Now().Format("15:04") - } +func (a *App) handleServerMessage( + msg *api.Message, +) { + ts := a.parseMessageTS(msg) a.mu.Lock() myNick := a.nick @@ -502,79 +648,203 @@ func (a *App) handleServerMessage(msg *api.Message) { switch msg.Command { case "PRIVMSG": - lines := msg.BodyLines() - text := strings.Join(lines, " ") - if msg.From == myNick { - // Skip our own echoed messages (already displayed locally). - return - } - target := msg.To - if !strings.HasPrefix(target, "#") { - // DM — use sender's nick as buffer name. - target = msg.From - } - a.ui.AddLine(target, fmt.Sprintf("[gray]%s [green]<%s>[white] %s", ts, msg.From, text)) - + a.handlePrivmsgMsg(msg, ts, myNick) case "JOIN": - target := msg.To - if target != "" { - a.ui.AddLine(target, fmt.Sprintf("[gray]%s [yellow]*** %s has joined %s", ts, msg.From, target)) - } - + a.handleJoinMsg(msg, ts) case "PART": - target := msg.To - lines := msg.BodyLines() - reason := strings.Join(lines, " ") - if target != "" { - if reason != "" { - a.ui.AddLine(target, fmt.Sprintf("[gray]%s [yellow]*** %s has left %s (%s)", ts, msg.From, target, reason)) - } else { - a.ui.AddLine(target, fmt.Sprintf("[gray]%s [yellow]*** %s has left %s", ts, msg.From, target)) - } - } - + a.handlePartMsg(msg, ts) case "QUIT": - lines := msg.BodyLines() - reason := strings.Join(lines, " ") - if reason != "" { - a.ui.AddStatus(fmt.Sprintf("[gray]%s [yellow]*** %s has quit (%s)", ts, msg.From, reason)) - } else { - a.ui.AddStatus(fmt.Sprintf("[gray]%s [yellow]*** %s has quit", ts, msg.From)) - } - + a.handleQuitMsg(msg, ts) case "NICK": - lines := msg.BodyLines() - newNick := "" - if len(lines) > 0 { - newNick = lines[0] - } - if msg.From == myNick && newNick != "" { - a.mu.Lock() - a.nick = newNick - target := a.target - a.mu.Unlock() - a.ui.SetStatus(newNick, target, "connected") - } - a.ui.AddStatus(fmt.Sprintf("[gray]%s [yellow]*** %s is now known as %s", ts, msg.From, newNick)) - + a.handleNickMsg(msg, ts, myNick) case "NOTICE": - lines := msg.BodyLines() - text := strings.Join(lines, " ") - a.ui.AddStatus(fmt.Sprintf("[gray]%s [magenta]--%s-- %s", ts, msg.From, text)) - + a.handleNoticeMsg(msg, ts) case "TOPIC": - lines := msg.BodyLines() - text := strings.Join(lines, " ") - if msg.To != "" { - a.ui.AddLine(msg.To, fmt.Sprintf("[gray]%s [cyan]*** %s set topic: %s", ts, msg.From, text)) - } - + a.handleTopicMsg(msg, ts) default: - // Numeric replies and other messages → status window. - lines := msg.BodyLines() - text := strings.Join(lines, " ") - if text != "" { - a.ui.AddStatus(fmt.Sprintf("[gray]%s [white][%s] %s", ts, msg.Command, text)) - } + a.handleDefaultMsg(msg, ts) } } + +func (a *App) parseMessageTS(msg *api.Message) string { + if msg.TS != "" { + t := msg.ParseTS() + + return t.In(time.Local).Format("15:04") //nolint:gosmopolitan // CLI uses local time + } + + return time.Now().Format("15:04") +} + +func (a *App) handlePrivmsgMsg( + msg *api.Message, + ts, myNick string, +) { + lines := msg.BodyLines() + text := strings.Join(lines, " ") + + if msg.From == myNick { + return + } + + target := msg.To + if !strings.HasPrefix(target, "#") { + target = msg.From + } + + a.ui.AddLine( + target, + fmt.Sprintf( + "[gray]%s [green]<%s>[white] %s", + ts, msg.From, text, + ), + ) +} + +func (a *App) handleJoinMsg( + msg *api.Message, ts string, +) { + target := msg.To + if target == "" { + return + } + + a.ui.AddLine( + target, + fmt.Sprintf( + "[gray]%s [yellow]*** %s has joined %s", + ts, msg.From, target, + ), + ) +} + +func (a *App) handlePartMsg( + msg *api.Message, ts string, +) { + target := msg.To + if target == "" { + return + } + + lines := msg.BodyLines() + reason := strings.Join(lines, " ") + + if reason != "" { + a.ui.AddLine( + target, + fmt.Sprintf( + "[gray]%s [yellow]*** %s has left %s (%s)", + ts, msg.From, target, reason, + ), + ) + } else { + a.ui.AddLine( + target, + fmt.Sprintf( + "[gray]%s [yellow]*** %s has left %s", + ts, msg.From, target, + ), + ) + } +} + +func (a *App) handleQuitMsg( + msg *api.Message, ts string, +) { + lines := msg.BodyLines() + reason := strings.Join(lines, " ") + + if reason != "" { + a.ui.AddStatus( + fmt.Sprintf( + "[gray]%s [yellow]*** %s has quit (%s)", + ts, msg.From, reason, + ), + ) + } else { + a.ui.AddStatus( + fmt.Sprintf( + "[gray]%s [yellow]*** %s has quit", + ts, msg.From, + ), + ) + } +} + +func (a *App) handleNickMsg( + msg *api.Message, ts, myNick string, +) { + lines := msg.BodyLines() + + newNick := "" + if len(lines) > 0 { + newNick = lines[0] + } + + if msg.From == myNick && newNick != "" { + a.mu.Lock() + a.nick = newNick + target := a.target + a.mu.Unlock() + + a.ui.SetStatus(newNick, target, "connected") + } + + a.ui.AddStatus( + fmt.Sprintf( + "[gray]%s [yellow]*** %s is now known as %s", + ts, msg.From, newNick, + ), + ) +} + +func (a *App) handleNoticeMsg( + msg *api.Message, ts string, +) { + lines := msg.BodyLines() + text := strings.Join(lines, " ") + + a.ui.AddStatus( + fmt.Sprintf( + "[gray]%s [magenta]--%s-- %s", + ts, msg.From, text, + ), + ) +} + +func (a *App) handleTopicMsg( + msg *api.Message, ts string, +) { + if msg.To == "" { + return + } + + lines := msg.BodyLines() + text := strings.Join(lines, " ") + + a.ui.AddLine( + msg.To, + fmt.Sprintf( + "[gray]%s [cyan]*** %s set topic: %s", + ts, msg.From, text, + ), + ) +} + +func (a *App) handleDefaultMsg( + msg *api.Message, ts string, +) { + lines := msg.BodyLines() + text := strings.Join(lines, " ") + + if text == "" { + return + } + + a.ui.AddStatus( + fmt.Sprintf( + "[gray]%s [white][%s] %s", + ts, msg.Command, text, + ), + ) +} diff --git a/cmd/chat-cli/ui.go b/cmd/chat-cli/ui.go index 16449f2..40f55b3 100644 --- a/cmd/chat-cli/ui.go +++ b/cmd/chat-cli/ui.go @@ -31,6 +31,7 @@ type UI struct { } // NewUI creates the tview-based IRC-like UI. + func NewUI() *UI { ui := &UI{ app: tview.NewApplication(), @@ -39,80 +40,35 @@ func NewUI() *UI { }, } - // Message area. - ui.messages = tview.NewTextView(). - SetDynamicColors(true). - SetScrollable(true). - SetWordWrap(true). - SetChangedFunc(func() { - ui.app.Draw() - }) - ui.messages.SetBorder(false) - - // Status bar. - ui.statusBar = tview.NewTextView(). - SetDynamicColors(true) - ui.statusBar.SetBackgroundColor(tcell.ColorNavy) - ui.statusBar.SetTextColor(tcell.ColorWhite) - - // Input field. - ui.input = tview.NewInputField(). - SetFieldBackgroundColor(tcell.ColorBlack). - SetFieldTextColor(tcell.ColorWhite) - ui.input.SetDoneFunc(func(key tcell.Key) { - if key == tcell.KeyEnter { - text := ui.input.GetText() - if text == "" { - return - } - ui.input.SetText("") - if ui.onInput != nil { - ui.onInput(text) - } - } - }) - - // Capture Alt+N for window switching. - ui.app.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { - if event.Modifiers()&tcell.ModAlt != 0 { - r := event.Rune() - if r >= '0' && r <= '9' { - idx := int(r - '0') - ui.SwitchBuffer(idx) - return nil - } - } - return event - }) - - // Layout: messages on top, status bar, input at bottom. - ui.layout = tview.NewFlex().SetDirection(tview.FlexRow). - AddItem(ui.messages, 0, 1, false). - AddItem(ui.statusBar, 1, 0, false). - AddItem(ui.input, 1, 0, true) - - ui.app.SetRoot(ui.layout, true) - ui.app.SetFocus(ui.input) + ui.setupMessages() + ui.setupStatusBar() + ui.setupInput() + ui.setupKeybindings() + ui.setupLayout() return ui } // Run starts the UI event loop (blocks). + func (ui *UI) Run() error { return ui.app.Run() } // Stop stops the UI. + func (ui *UI) Stop() { ui.app.Stop() } // OnInput sets the callback for user input. + func (ui *UI) OnInput(fn func(string)) { ui.onInput = fn } // AddLine adds a line to the specified buffer. + func (ui *UI) AddLine(bufferName string, line string) { ui.app.QueueUpdateDraw(func() { buf := ui.getOrCreateBuffer(bufferName) @@ -121,89 +77,219 @@ func (ui *UI) AddLine(bufferName string, line string) { // Mark unread if not currently viewing this buffer. if ui.buffers[ui.currentBuffer] != buf { buf.Unread++ + ui.refreshStatus() } // If viewing this buffer, append to display. if ui.buffers[ui.currentBuffer] == buf { - fmt.Fprintln(ui.messages, line) + _, _ = fmt.Fprintln(ui.messages, line) } }) } // AddStatus adds a line to the status buffer (buffer 0). + func (ui *UI) AddStatus(line string) { ts := time.Now().Format("15:04") - ui.AddLine("(status)", fmt.Sprintf("[gray]%s[white] %s", ts, line)) + + ui.AddLine( + "(status)", + fmt.Sprintf("[gray]%s[white] %s", ts, line), + ) } // SwitchBuffer switches to the buffer at index n. + func (ui *UI) SwitchBuffer(n int) { ui.app.QueueUpdateDraw(func() { if n < 0 || n >= len(ui.buffers) { return } + ui.currentBuffer = n + buf := ui.buffers[n] buf.Unread = 0 + ui.messages.Clear() + for _, line := range buf.Lines { - fmt.Fprintln(ui.messages, line) + _, _ = fmt.Fprintln(ui.messages, line) } + ui.messages.ScrollToEnd() ui.refreshStatus() }) } -// SwitchToBuffer switches to the named buffer, creating it if needed. +// SwitchToBuffer switches to the named buffer, creating it + func (ui *UI) SwitchToBuffer(name string) { ui.app.QueueUpdateDraw(func() { buf := ui.getOrCreateBuffer(name) + for i, b := range ui.buffers { if b == buf { ui.currentBuffer = i + break } } + buf.Unread = 0 + ui.messages.Clear() + for _, line := range buf.Lines { - fmt.Fprintln(ui.messages, line) + _, _ = fmt.Fprintln(ui.messages, line) } + ui.messages.ScrollToEnd() ui.refreshStatus() }) } // SetStatus updates the status bar text. -func (ui *UI) SetStatus(nick, target, connStatus string) { + +func (ui *UI) SetStatus( + nick, target, connStatus string, +) { ui.app.QueueUpdateDraw(func() { ui.refreshStatusWith(nick, target, connStatus) }) } -func (ui *UI) refreshStatus() { - // Will be called from the main goroutine via QueueUpdateDraw parent. - // Rebuild status from app state — caller must provide context. +// BufferCount returns the number of buffers. + +func (ui *UI) BufferCount() int { + return len(ui.buffers) } -func (ui *UI) refreshStatusWith(nick, target, connStatus string) { - var unreadParts []string +// BufferIndex returns the index of a named buffer, or -1. + +func (ui *UI) BufferIndex(name string) int { for i, buf := range ui.buffers { - if buf.Unread > 0 { - unreadParts = append(unreadParts, fmt.Sprintf("%d:%s(%d)", i, buf.Name, buf.Unread)) + if buf.Name == name { + return i } } - unread := "" - if len(unreadParts) > 0 { - unread = " [Act: " + strings.Join(unreadParts, ",") + "]" + + return -1 +} + +func (ui *UI) setupMessages() { + ui.messages = tview.NewTextView(). + SetDynamicColors(true). + SetScrollable(true). + SetWordWrap(true). + SetChangedFunc(func() { + ui.app.Draw() + }) + ui.messages.SetBorder(false) +} + +func (ui *UI) setupStatusBar() { + ui.statusBar = tview.NewTextView(). + SetDynamicColors(true) + ui.statusBar.SetBackgroundColor(tcell.ColorNavy) + ui.statusBar.SetTextColor(tcell.ColorWhite) +} + +func (ui *UI) setupInput() { + ui.input = tview.NewInputField(). + SetFieldBackgroundColor(tcell.ColorBlack). + SetFieldTextColor(tcell.ColorWhite) + + ui.input.SetDoneFunc(func(key tcell.Key) { + if key != tcell.KeyEnter { + return + } + + text := ui.input.GetText() + if text == "" { + return + } + + ui.input.SetText("") + + if ui.onInput != nil { + ui.onInput(text) + } + }) +} + +func (ui *UI) setupKeybindings() { + ui.app.SetInputCapture( + func(event *tcell.EventKey) *tcell.EventKey { + if event.Modifiers()&tcell.ModAlt == 0 { + return event + } + + r := event.Rune() + if r >= '0' && r <= '9' { + idx := int(r - '0') + ui.SwitchBuffer(idx) + + return nil + } + + return event + }, + ) +} + +func (ui *UI) setupLayout() { + ui.layout = tview.NewFlex(). + SetDirection(tview.FlexRow). + AddItem(ui.messages, 0, 1, false). + AddItem(ui.statusBar, 1, 0, false). + AddItem(ui.input, 1, 0, true) + + ui.app.SetRoot(ui.layout, true) + ui.app.SetFocus(ui.input) +} + +// if needed. + +func (ui *UI) refreshStatus() { + // Rebuilt from app state by parent QueueUpdateDraw. +} + +func (ui *UI) refreshStatusWith( + nick, target, connStatus string, +) { + var unreadParts []string + + for i, buf := range ui.buffers { + if buf.Unread > 0 { + unreadParts = append( + unreadParts, + fmt.Sprintf( + "%d:%s(%d)", i, buf.Name, buf.Unread, + ), + ) + } } - bufInfo := fmt.Sprintf("[%d:%s]", ui.currentBuffer, ui.buffers[ui.currentBuffer].Name) + unread := "" + if len(unreadParts) > 0 { + unread = " [Act: " + + strings.Join(unreadParts, ",") + "]" + } + + bufInfo := fmt.Sprintf( + "[%d:%s]", + ui.currentBuffer, + ui.buffers[ui.currentBuffer].Name, + ) ui.statusBar.Clear() - fmt.Fprintf(ui.statusBar, " [%s] %s %s %s%s", - connStatus, nick, bufInfo, target, unread) + + _, _ = fmt.Fprintf( + ui.statusBar, " [%s] %s %s %s%s", + connStatus, nick, bufInfo, target, unread, + ) } func (ui *UI) getOrCreateBuffer(name string) *Buffer { @@ -212,22 +298,9 @@ func (ui *UI) getOrCreateBuffer(name string) *Buffer { return buf } } + buf := &Buffer{Name: name} ui.buffers = append(ui.buffers, buf) + return buf } - -// BufferCount returns the number of buffers. -func (ui *UI) BufferCount() int { - return len(ui.buffers) -} - -// BufferIndex returns the index of a named buffer, or -1. -func (ui *UI) BufferIndex(name string) int { - for i, buf := range ui.buffers { - if buf.Name == name { - return i - } - } - return -1 -} diff --git a/internal/db/db.go b/internal/db/db.go index b9a2b9f..5062827 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -88,8 +88,10 @@ func NewTest(dsn string) (*Database, error) { } // Item 9: Enable foreign keys - if _, err := d.Exec("PRAGMA foreign_keys = ON"); err != nil { + _, err = d.Exec("PRAGMA foreign_keys = ON") //nolint:noctx // no context in sql.Open path + if err != nil { _ = d.Close() + return nil, fmt.Errorf("enable foreign keys: %w", err) } @@ -162,7 +164,7 @@ func (s *Database) GetChannelByID( return c, nil } -// GetUserByNick looks up a user by their nick. +// GetUserByNickModel looks up a user by their nick. func (s *Database) GetUserByNickModel( ctx context.Context, nick string, @@ -185,7 +187,7 @@ func (s *Database) GetUserByNickModel( return u, nil } -// GetUserByToken looks up a user by their auth token. +// GetUserByTokenModel looks up a user by their auth token. func (s *Database) GetUserByTokenModel( ctx context.Context, token string, @@ -219,6 +221,7 @@ func (s *Database) DeleteAuthToken( _, err := s.db.ExecContext(ctx, `DELETE FROM auth_tokens WHERE token = ?`, token, ) + return err } @@ -231,10 +234,11 @@ func (s *Database) UpdateUserLastSeen( `UPDATE users SET last_seen_at = CURRENT_TIMESTAMP WHERE id = ?`, userID, ) + return err } -// CreateUser inserts a new user into the database. +// CreateUserModel inserts a new user into the database. func (s *Database) CreateUserModel( ctx context.Context, id, nick, passwordHash string, @@ -394,6 +398,7 @@ func (s *Database) DequeueMessages( if err != nil { return nil, err } + defer func() { _ = rows.Close() }() entries := []*models.MessageQueueEntry{} @@ -423,14 +428,14 @@ func (s *Database) AckMessages( } placeholders := make([]string, len(entryIDs)) - args := make([]interface{}, len(entryIDs)) + args := make([]any, len(entryIDs)) for i, id := range entryIDs { placeholders[i] = "?" args[i] = id } - query := fmt.Sprintf( + query := fmt.Sprintf( //nolint:gosec // placeholders are ?, not user input "DELETE FROM message_queue WHERE id IN (%s)", strings.Join(placeholders, ","), ) @@ -549,7 +554,8 @@ func (s *Database) connect(ctx context.Context) error { s.log.Info("database connected") // Item 9: Enable foreign keys on every connection - if _, err := s.db.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil { + _, err = s.db.ExecContext(ctx, "PRAGMA foreign_keys = ON") + if err != nil { return fmt.Errorf("enable foreign keys: %w", err) } @@ -676,41 +682,54 @@ func (s *Database) applyMigrations( "version", m.version, "name", m.name, ) - tx, err := s.db.BeginTx(ctx, nil) + err = s.executeMigration(ctx, m) if err != nil { - return fmt.Errorf( - "begin tx for migration %d: %w", m.version, err, - ) - } - - _, err = tx.ExecContext(ctx, m.sql) - if err != nil { - _ = tx.Rollback() - - return fmt.Errorf( - "apply migration %d (%s): %w", - m.version, m.name, err, - ) - } - - _, err = tx.ExecContext(ctx, - "INSERT INTO schema_migrations (version) VALUES (?)", - m.version, - ) - if err != nil { - _ = tx.Rollback() - - return fmt.Errorf( - "record migration %d: %w", m.version, err, - ) - } - - if err := tx.Commit(); err != nil { - return fmt.Errorf( - "commit migration %d: %w", m.version, err, - ) + return err } } return nil } + +func (s *Database) executeMigration( + ctx context.Context, + m migration, +) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf( + "begin tx for migration %d: %w", m.version, err, + ) + } + + _, err = tx.ExecContext(ctx, m.sql) + if err != nil { + _ = tx.Rollback() + + return fmt.Errorf( + "apply migration %d (%s): %w", + m.version, m.name, err, + ) + } + + _, err = tx.ExecContext(ctx, + "INSERT INTO schema_migrations (version) VALUES (?)", + m.version, + ) + if err != nil { + _ = tx.Rollback() + + return fmt.Errorf( + "record migration %d: %w", m.version, err, + ) + } + + err = tx.Commit() + if err != nil { + return fmt.Errorf( + "commit migration %d: %w", m.version, err, + ) + } + + return nil +} diff --git a/internal/db/queries.go b/internal/db/queries.go index 974f4db..f3567b4 100644 --- a/internal/db/queries.go +++ b/internal/db/queries.go @@ -3,107 +3,144 @@ package db import ( "context" "crypto/rand" + "database/sql" "encoding/hex" "fmt" "time" ) +const ( + defaultMessageLimit = 50 + defaultPollLimit = 100 + tokenBytes = 32 +) + func generateToken() string { - b := make([]byte, 32) + b := make([]byte, tokenBytes) _, _ = rand.Read(b) + return hex.EncodeToString(b) } -// CreateUser registers a new user with the given nick and returns the user with token. -func (s *Database) CreateUser(ctx context.Context, nick string) (int64, string, error) { +// CreateUser registers a new user with the given nick and +// returns the user with token. +func (s *Database) CreateUser( + ctx context.Context, + nick string, +) (int64, string, error) { token := generateToken() now := time.Now() + res, err := s.db.ExecContext(ctx, "INSERT INTO users (nick, token, created_at, last_seen) VALUES (?, ?, ?, ?)", nick, token, now, now) if err != nil { return 0, "", fmt.Errorf("create user: %w", err) } + id, _ := res.LastInsertId() + return id, token, nil } -// GetUserByToken returns user id and nick for a given auth token. -func (s *Database) GetUserByToken(ctx context.Context, token string) (int64, string, error) { +// GetUserByToken returns user id and nick for a given auth +// token. +func (s *Database) GetUserByToken( + ctx context.Context, + token string, +) (int64, string, error) { var id int64 + var nick string - err := s.db.QueryRowContext(ctx, "SELECT id, nick FROM users WHERE token = ?", token).Scan(&id, &nick) + + err := s.db.QueryRowContext( + ctx, + "SELECT id, nick FROM users WHERE token = ?", + token, + ).Scan(&id, &nick) if err != nil { return 0, "", err } + // Update last_seen - _, _ = s.db.ExecContext(ctx, "UPDATE users SET last_seen = ? WHERE id = ?", time.Now(), id) + _, _ = s.db.ExecContext( + ctx, + "UPDATE users SET last_seen = ? WHERE id = ?", + time.Now(), id, + ) + return id, nick, nil } // GetUserByNick returns user id for a given nick. -func (s *Database) GetUserByNick(ctx context.Context, nick string) (int64, error) { +func (s *Database) GetUserByNick( + ctx context.Context, + nick string, +) (int64, error) { var id int64 - err := s.db.QueryRowContext(ctx, "SELECT id FROM users WHERE nick = ?", nick).Scan(&id) + + err := s.db.QueryRowContext( + ctx, + "SELECT id FROM users WHERE nick = ?", + nick, + ).Scan(&id) + return id, err } -// GetOrCreateChannel returns the channel id, creating it if needed. -func (s *Database) GetOrCreateChannel(ctx context.Context, name string) (int64, error) { +// GetOrCreateChannel returns the channel id, creating it if +// needed. +func (s *Database) GetOrCreateChannel( + ctx context.Context, + name string, +) (int64, error) { var id int64 - err := s.db.QueryRowContext(ctx, "SELECT id FROM channels WHERE name = ?", name).Scan(&id) + + err := s.db.QueryRowContext( + ctx, + "SELECT id FROM channels WHERE name = ?", + name, + ).Scan(&id) if err == nil { return id, nil } + now := time.Now() + res, err := s.db.ExecContext(ctx, "INSERT INTO channels (name, created_at, updated_at) VALUES (?, ?, ?)", name, now, now) if err != nil { return 0, fmt.Errorf("create channel: %w", err) } + id, _ = res.LastInsertId() + return id, nil } // JoinChannel adds a user to a channel. -func (s *Database) JoinChannel(ctx context.Context, channelID, userID int64) error { +func (s *Database) JoinChannel( + ctx context.Context, + channelID, userID int64, +) error { _, err := s.db.ExecContext(ctx, "INSERT OR IGNORE INTO channel_members (channel_id, user_id, joined_at) VALUES (?, ?, ?)", channelID, userID, time.Now()) + return err } // PartChannel removes a user from a channel. -func (s *Database) PartChannel(ctx context.Context, channelID, userID int64) error { +func (s *Database) PartChannel( + ctx context.Context, + channelID, userID int64, +) error { _, err := s.db.ExecContext(ctx, "DELETE FROM channel_members WHERE channel_id = ? AND user_id = ?", channelID, userID) - return err -} -// ListChannels returns all channels the user has joined. -func (s *Database) ListChannels(ctx context.Context, userID int64) ([]ChannelInfo, error) { - rows, err := s.db.QueryContext(ctx, - `SELECT c.id, c.name, c.topic FROM channels c - INNER JOIN channel_members cm ON cm.channel_id = c.id - WHERE cm.user_id = ? ORDER BY c.name`, userID) - if err != nil { - return nil, err - } - defer rows.Close() - var channels []ChannelInfo - for rows.Next() { - var ch ChannelInfo - if err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic); err != nil { - return nil, err - } - channels = append(channels, ch) - } - if channels == nil { - channels = []ChannelInfo{} - } - return channels, nil + return err } // ChannelInfo is a lightweight channel representation. @@ -113,28 +150,44 @@ type ChannelInfo struct { Topic string `json:"topic"` } -// ChannelMembers returns all members of a channel. -func (s *Database) ChannelMembers(ctx context.Context, channelID int64) ([]MemberInfo, error) { +// ListChannels returns all channels the user has joined. +func (s *Database) ListChannels( + ctx context.Context, + userID int64, +) ([]ChannelInfo, error) { rows, err := s.db.QueryContext(ctx, - `SELECT u.id, u.nick, u.last_seen FROM users u - INNER JOIN channel_members cm ON cm.user_id = u.id - WHERE cm.channel_id = ? ORDER BY u.nick`, channelID) + `SELECT c.id, c.name, c.topic FROM channels c + INNER JOIN channel_members cm ON cm.channel_id = c.id + WHERE cm.user_id = ? ORDER BY c.name`, userID) if err != nil { return nil, err } - defer rows.Close() - var members []MemberInfo + + defer func() { _ = rows.Close() }() + + var channels []ChannelInfo + for rows.Next() { - var m MemberInfo - if err := rows.Scan(&m.ID, &m.Nick, &m.LastSeen); err != nil { + var ch ChannelInfo + + err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic) + if err != nil { return nil, err } - members = append(members, m) + + channels = append(channels, ch) } - if members == nil { - members = []MemberInfo{} + + err = rows.Err() + if err != nil { + return nil, err } - return members, nil + + if channels == nil { + channels = []ChannelInfo{} + } + + return channels, nil } // MemberInfo represents a channel member. @@ -144,6 +197,46 @@ type MemberInfo struct { LastSeen time.Time `json:"lastSeen"` } +// ChannelMembers returns all members of a channel. +func (s *Database) ChannelMembers( + ctx context.Context, + channelID int64, +) ([]MemberInfo, error) { + rows, err := s.db.QueryContext(ctx, + `SELECT u.id, u.nick, u.last_seen FROM users u + INNER JOIN channel_members cm ON cm.user_id = u.id + WHERE cm.channel_id = ? ORDER BY u.nick`, channelID) + if err != nil { + return nil, err + } + + defer func() { _ = rows.Close() }() + + var members []MemberInfo + + for rows.Next() { + var m MemberInfo + + err := rows.Scan(&m.ID, &m.Nick, &m.LastSeen) + if err != nil { + return nil, err + } + + members = append(members, m) + } + + err = rows.Err() + if err != nil { + return nil, err + } + + if members == nil { + members = []MemberInfo{} + } + + return members, nil +} + // MessageInfo represents a chat message. type MessageInfo struct { ID int64 `json:"id"` @@ -155,11 +248,18 @@ type MessageInfo struct { CreatedAt time.Time `json:"createdAt"` } -// GetMessages returns messages for a channel, optionally after a given ID. -func (s *Database) GetMessages(ctx context.Context, channelID int64, afterID int64, limit int) ([]MessageInfo, error) { +// GetMessages returns messages for a channel, optionally +// after a given ID. +func (s *Database) GetMessages( + ctx context.Context, + channelID int64, + afterID int64, + limit int, +) ([]MessageInfo, error) { if limit <= 0 { - limit = 50 + limit = defaultMessageLimit } + rows, err := s.db.QueryContext(ctx, `SELECT m.id, c.name, u.nick, m.content, m.created_at FROM messages m @@ -170,128 +270,288 @@ func (s *Database) GetMessages(ctx context.Context, channelID int64, afterID int if err != nil { return nil, err } - defer rows.Close() + + defer func() { _ = rows.Close() }() + var msgs []MessageInfo + for rows.Next() { var m MessageInfo - if err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &m.CreatedAt); err != nil { + + err := rows.Scan( + &m.ID, &m.Channel, &m.Nick, + &m.Content, &m.CreatedAt, + ) + if err != nil { return nil, err } + msgs = append(msgs, m) } + + err = rows.Err() + if err != nil { + return nil, err + } + if msgs == nil { msgs = []MessageInfo{} } + return msgs, nil } // SendMessage inserts a channel message. -func (s *Database) SendMessage(ctx context.Context, channelID, userID int64, content string) (int64, error) { +func (s *Database) SendMessage( + ctx context.Context, + channelID, userID int64, + content string, +) (int64, error) { res, err := s.db.ExecContext(ctx, "INSERT INTO messages (channel_id, user_id, content, is_dm, created_at) VALUES (?, ?, ?, 0, ?)", channelID, userID, content, time.Now()) if err != nil { return 0, err } + return res.LastInsertId() } // SendDM inserts a direct message. -func (s *Database) SendDM(ctx context.Context, fromID, toID int64, content string) (int64, error) { +func (s *Database) SendDM( + ctx context.Context, + fromID, toID int64, + content string, +) (int64, error) { res, err := s.db.ExecContext(ctx, "INSERT INTO messages (user_id, content, is_dm, dm_target_id, created_at) VALUES (?, ?, 1, ?, ?)", fromID, content, toID, time.Now()) if err != nil { return 0, err } + return res.LastInsertId() } -// GetDMs returns direct messages between two users after a given ID. -func (s *Database) GetDMs(ctx context.Context, userA, userB int64, afterID int64, limit int) ([]MessageInfo, error) { +// GetDMs returns direct messages between two users after a +// given ID. +func (s *Database) GetDMs( + ctx context.Context, + userA, userB int64, + afterID int64, + limit int, +) ([]MessageInfo, error) { if limit <= 0 { - limit = 50 + limit = defaultMessageLimit } + rows, err := s.db.QueryContext(ctx, `SELECT m.id, u.nick, m.content, t.nick, m.created_at FROM messages m INNER JOIN users u ON u.id = m.user_id INNER JOIN users t ON t.id = m.dm_target_id WHERE m.is_dm = 1 AND m.id > ? - AND ((m.user_id = ? AND m.dm_target_id = ?) OR (m.user_id = ? AND m.dm_target_id = ?)) - ORDER BY m.id ASC LIMIT ?`, afterID, userA, userB, userB, userA, limit) + AND ((m.user_id = ? AND m.dm_target_id = ?) + OR (m.user_id = ? AND m.dm_target_id = ?)) + ORDER BY m.id ASC LIMIT ?`, + afterID, userA, userB, userB, userA, limit) if err != nil { return nil, err } - defer rows.Close() + + defer func() { _ = rows.Close() }() + var msgs []MessageInfo + for rows.Next() { var m MessageInfo - if err := rows.Scan(&m.ID, &m.Nick, &m.Content, &m.DMTarget, &m.CreatedAt); err != nil { + + err := rows.Scan( + &m.ID, &m.Nick, &m.Content, + &m.DMTarget, &m.CreatedAt, + ) + if err != nil { return nil, err } + m.IsDM = true + msgs = append(msgs, m) } + + err = rows.Err() + if err != nil { + return nil, err + } + if msgs == nil { msgs = []MessageInfo{} } + return msgs, nil } -// PollMessages returns all new messages (channel + DM) for a user after a given ID. -func (s *Database) PollMessages(ctx context.Context, userID int64, afterID int64, limit int) ([]MessageInfo, error) { +// PollMessages returns all new messages (channel + DM) for +// a user after a given ID. +func (s *Database) PollMessages( + ctx context.Context, + userID int64, + afterID int64, + limit int, +) ([]MessageInfo, error) { if limit <= 0 { - limit = 100 + limit = defaultPollLimit } + rows, err := s.db.QueryContext(ctx, - `SELECT m.id, COALESCE(c.name, ''), u.nick, m.content, m.is_dm, COALESCE(t.nick, ''), m.created_at + `SELECT m.id, COALESCE(c.name, ''), u.nick, m.content, + m.is_dm, COALESCE(t.nick, ''), m.created_at FROM messages m INNER JOIN users u ON u.id = m.user_id LEFT JOIN channels c ON c.id = m.channel_id LEFT JOIN users t ON t.id = m.dm_target_id WHERE m.id > ? AND ( - (m.is_dm = 0 AND m.channel_id IN (SELECT channel_id FROM channel_members WHERE user_id = ?)) - OR (m.is_dm = 1 AND (m.user_id = ? OR m.dm_target_id = ?)) + (m.is_dm = 0 AND m.channel_id IN + (SELECT channel_id FROM channel_members + WHERE user_id = ?)) + OR (m.is_dm = 1 + AND (m.user_id = ? OR m.dm_target_id = ?)) ) - ORDER BY m.id ASC LIMIT ?`, afterID, userID, userID, userID, limit) + ORDER BY m.id ASC LIMIT ?`, + afterID, userID, userID, userID, limit) if err != nil { return nil, err } - defer rows.Close() - var msgs []MessageInfo + + defer func() { _ = rows.Close() }() + + msgs := make([]MessageInfo, 0) + for rows.Next() { - var m MessageInfo - var isDM int - if err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &isDM, &m.DMTarget, &m.CreatedAt); err != nil { + var ( + m MessageInfo + isDM int + ) + + err := rows.Scan( + &m.ID, &m.Channel, &m.Nick, &m.Content, + &isDM, &m.DMTarget, &m.CreatedAt, + ) + if err != nil { return nil, err } + m.IsDM = isDM == 1 msgs = append(msgs, m) } - if msgs == nil { - msgs = []MessageInfo{} + + err = rows.Err() + if err != nil { + return nil, err } + return msgs, nil } -// GetMessagesBefore returns channel messages before a given ID (for history scrollback). -func (s *Database) GetMessagesBefore(ctx context.Context, channelID int64, beforeID int64, limit int) ([]MessageInfo, error) { - if limit <= 0 { - limit = 50 +func scanChannelMessages( + rows *sql.Rows, +) ([]MessageInfo, error) { + var msgs []MessageInfo + + for rows.Next() { + var m MessageInfo + + err := rows.Scan( + &m.ID, &m.Channel, &m.Nick, + &m.Content, &m.CreatedAt, + ) + if err != nil { + return nil, err + } + + msgs = append(msgs, m) } + + err := rows.Err() + if err != nil { + return nil, err + } + + if msgs == nil { + msgs = []MessageInfo{} + } + + return msgs, nil +} + +func scanDMMessages( + rows *sql.Rows, +) ([]MessageInfo, error) { + var msgs []MessageInfo + + for rows.Next() { + var m MessageInfo + + err := rows.Scan( + &m.ID, &m.Nick, &m.Content, + &m.DMTarget, &m.CreatedAt, + ) + if err != nil { + return nil, err + } + + m.IsDM = true + + msgs = append(msgs, m) + } + + err := rows.Err() + if err != nil { + return nil, err + } + + if msgs == nil { + msgs = []MessageInfo{} + } + + return msgs, nil +} + +func reverseMessages(msgs []MessageInfo) { + for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 { + msgs[i], msgs[j] = msgs[j], msgs[i] + } +} + +// GetMessagesBefore returns channel messages before a given +// ID (for history scrollback). +func (s *Database) GetMessagesBefore( + ctx context.Context, + channelID int64, + beforeID int64, + limit int, +) ([]MessageInfo, error) { + if limit <= 0 { + limit = defaultMessageLimit + } + var query string + var args []any + if beforeID > 0 { - query = `SELECT m.id, c.name, u.nick, m.content, m.created_at + query = `SELECT m.id, c.name, u.nick, m.content, + m.created_at FROM messages m INNER JOIN users u ON u.id = m.user_id INNER JOIN channels c ON c.id = m.channel_id - WHERE m.channel_id = ? AND m.is_dm = 0 AND m.id < ? + WHERE m.channel_id = ? AND m.is_dm = 0 + AND m.id < ? ORDER BY m.id DESC LIMIT ?` args = []any{channelID, beforeID, limit} } else { - query = `SELECT m.id, c.name, u.nick, m.content, m.created_at + query = `SELECT m.id, c.name, u.nick, m.content, + m.created_at FROM messages m INNER JOIN users u ON u.id = m.user_id INNER JOIN channels c ON c.id = m.channel_id @@ -299,116 +559,153 @@ func (s *Database) GetMessagesBefore(ctx context.Context, channelID int64, befor ORDER BY m.id DESC LIMIT ?` args = []any{channelID, limit} } + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } - defer rows.Close() - var msgs []MessageInfo - for rows.Next() { - var m MessageInfo - if err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &m.CreatedAt); err != nil { - return nil, err - } - msgs = append(msgs, m) - } - if msgs == nil { - msgs = []MessageInfo{} - } - // Reverse to ascending order - for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 { - msgs[i], msgs[j] = msgs[j], msgs[i] + + defer func() { _ = rows.Close() }() + + msgs, scanErr := scanChannelMessages(rows) + if scanErr != nil { + return nil, scanErr } + + // Reverse to ascending order. + reverseMessages(msgs) + return msgs, nil } -// GetDMsBefore returns DMs between two users before a given ID (for history scrollback). -func (s *Database) GetDMsBefore(ctx context.Context, userA, userB int64, beforeID int64, limit int) ([]MessageInfo, error) { +// GetDMsBefore returns DMs between two users before a given +// ID (for history scrollback). +func (s *Database) GetDMsBefore( + ctx context.Context, + userA, userB int64, + beforeID int64, + limit int, +) ([]MessageInfo, error) { if limit <= 0 { - limit = 50 + limit = defaultMessageLimit } + var query string + var args []any + if beforeID > 0 { - query = `SELECT m.id, u.nick, m.content, t.nick, m.created_at + query = `SELECT m.id, u.nick, m.content, t.nick, + m.created_at FROM messages m INNER JOIN users u ON u.id = m.user_id INNER JOIN users t ON t.id = m.dm_target_id WHERE m.is_dm = 1 AND m.id < ? - AND ((m.user_id = ? AND m.dm_target_id = ?) OR (m.user_id = ? AND m.dm_target_id = ?)) + AND ((m.user_id = ? AND m.dm_target_id = ?) + OR (m.user_id = ? AND m.dm_target_id = ?)) ORDER BY m.id DESC LIMIT ?` - args = []any{beforeID, userA, userB, userB, userA, limit} + args = []any{ + beforeID, userA, userB, userB, userA, limit, + } } else { - query = `SELECT m.id, u.nick, m.content, t.nick, m.created_at + query = `SELECT m.id, u.nick, m.content, t.nick, + m.created_at FROM messages m INNER JOIN users u ON u.id = m.user_id INNER JOIN users t ON t.id = m.dm_target_id WHERE m.is_dm = 1 - AND ((m.user_id = ? AND m.dm_target_id = ?) OR (m.user_id = ? AND m.dm_target_id = ?)) + AND ((m.user_id = ? AND m.dm_target_id = ?) + OR (m.user_id = ? AND m.dm_target_id = ?)) ORDER BY m.id DESC LIMIT ?` args = []any{userA, userB, userB, userA, limit} } + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } - defer rows.Close() - var msgs []MessageInfo - for rows.Next() { - var m MessageInfo - if err := rows.Scan(&m.ID, &m.Nick, &m.Content, &m.DMTarget, &m.CreatedAt); err != nil { - return nil, err - } - m.IsDM = true - msgs = append(msgs, m) - } - if msgs == nil { - msgs = []MessageInfo{} - } - // Reverse to ascending order - for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 { - msgs[i], msgs[j] = msgs[j], msgs[i] + + defer func() { _ = rows.Close() }() + + msgs, scanErr := scanDMMessages(rows) + if scanErr != nil { + return nil, scanErr } + + // Reverse to ascending order. + reverseMessages(msgs) + return msgs, nil } // ChangeNick updates a user's nickname. -func (s *Database) ChangeNick(ctx context.Context, userID int64, newNick string) error { +func (s *Database) ChangeNick( + ctx context.Context, + userID int64, + newNick string, +) error { _, err := s.db.ExecContext(ctx, - "UPDATE users SET nick = ? WHERE id = ?", newNick, userID) + "UPDATE users SET nick = ? WHERE id = ?", + newNick, userID) + return err } // SetTopic sets the topic for a channel. -func (s *Database) SetTopic(ctx context.Context, channelName string, _ int64, topic string) error { +func (s *Database) SetTopic( + ctx context.Context, + channelName string, + _ int64, + topic string, +) error { _, err := s.db.ExecContext(ctx, - "UPDATE channels SET topic = ? WHERE name = ?", topic, channelName) + "UPDATE channels SET topic = ? WHERE name = ?", + topic, channelName) + return err } -// GetServerName returns the server name (unused, config provides this). +// GetServerName returns the server name (unused, config +// provides this). func (s *Database) GetServerName() string { return "" } // ListAllChannels returns all channels. -func (s *Database) ListAllChannels(ctx context.Context) ([]ChannelInfo, error) { +func (s *Database) ListAllChannels( + ctx context.Context, +) ([]ChannelInfo, error) { rows, err := s.db.QueryContext(ctx, "SELECT id, name, topic FROM channels ORDER BY name") if err != nil { return nil, err } - defer rows.Close() + + defer func() { _ = rows.Close() }() + var channels []ChannelInfo + for rows.Next() { var ch ChannelInfo - if err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic); err != nil { + + err := rows.Scan( + &ch.ID, &ch.Name, &ch.Topic, + ) + if err != nil { return nil, err } + channels = append(channels, ch) } + + err = rows.Err() + if err != nil { + return nil, err + } + if channels == nil { channels = []ChannelInfo{} } + return channels, nil } diff --git a/internal/handlers/api.go b/internal/handlers/api.go index 975b7a1..764a0fe 100644 --- a/internal/handlers/api.go +++ b/internal/handlers/api.go @@ -11,79 +11,181 @@ import ( "github.com/go-chi/chi" ) -// authUser extracts the user from the Authorization header (Bearer token). -func (s *Handlers) authUser(r *http.Request) (int64, string, error) { +const ( + maxNickLen = 32 + defaultHistory = 50 +) + +// authUser extracts the user from the Authorization header +// (Bearer token). +func (s *Handlers) authUser( + r *http.Request, +) (int64, string, error) { auth := r.Header.Get("Authorization") if !strings.HasPrefix(auth, "Bearer ") { return 0, "", sql.ErrNoRows } + token := strings.TrimPrefix(auth, "Bearer ") + return s.params.Database.GetUserByToken(r.Context(), token) } -func (s *Handlers) requireAuth(w http.ResponseWriter, r *http.Request) (int64, string, bool) { +func (s *Handlers) requireAuth( + w http.ResponseWriter, + r *http.Request, +) (int64, string, bool) { uid, nick, err := s.authUser(r) if err != nil { - s.respondJSON(w, r, map[string]string{"error": "unauthorized"}, http.StatusUnauthorized) + s.respondJSON( + w, r, + map[string]string{"error": "unauthorized"}, + http.StatusUnauthorized, + ) + return 0, "", false } + return uid, nick, true } -// HandleCreateSession creates a new user session and returns the auth token. +func (s *Handlers) respondError( + w http.ResponseWriter, + r *http.Request, + msg string, + code int, +) { + s.respondJSON(w, r, map[string]string{"error": msg}, code) +} + +func (s *Handlers) internalError( + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + s.log.Error(msg, "error", err) + s.respondError(w, r, "internal error", http.StatusInternalServerError) +} + +// bodyLines extracts body as string lines from a request body +// field. +func bodyLines(body any) []string { + switch v := body.(type) { + case []any: + lines := make([]string, 0, len(v)) + + for _, item := range v { + if s, ok := item.(string); ok { + lines = append(lines, s) + } + } + + return lines + case []string: + return v + default: + return nil + } +} + +// HandleCreateSession creates a new user session and returns +// the auth token. func (s *Handlers) HandleCreateSession() http.HandlerFunc { type request struct { Nick string `json:"nick"` } + type response struct { ID int64 `json:"id"` Nick string `json:"nick"` Token string `json:"token"` } + return func(w http.ResponseWriter, r *http.Request) { var req request - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest) + + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + s.respondError( + w, r, "invalid request", + http.StatusBadRequest, + ) + return } + req.Nick = strings.TrimSpace(req.Nick) - if req.Nick == "" || len(req.Nick) > 32 { - s.respondJSON(w, r, map[string]string{"error": "nick must be 1-32 characters"}, http.StatusBadRequest) + + if req.Nick == "" || len(req.Nick) > maxNickLen { + s.respondError( + w, r, "nick must be 1-32 characters", + http.StatusBadRequest, + ) + return } - id, token, err := s.params.Database.CreateUser(r.Context(), req.Nick) + + id, token, err := s.params.Database.CreateUser( + r.Context(), req.Nick, + ) if err != nil { if strings.Contains(err.Error(), "UNIQUE") { - s.respondJSON(w, r, map[string]string{"error": "nick already taken"}, http.StatusConflict) + s.respondError( + w, r, "nick already taken", + http.StatusConflict, + ) + return } - s.log.Error("create user failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + + s.internalError(w, r, "create user failed", err) + return } - s.respondJSON(w, r, &response{ID: id, Nick: req.Nick, Token: token}, http.StatusCreated) + + s.respondJSON( + w, r, + &response{ID: id, Nick: req.Nick, Token: token}, + http.StatusCreated, + ) } } -// HandleState returns the current user's info and joined channels. +// HandleState returns the current user's info and joined +// channels. func (s *Handlers) HandleState() http.HandlerFunc { type response struct { ID int64 `json:"id"` Nick string `json:"nick"` Channels []db.ChannelInfo `json:"channels"` } + return func(w http.ResponseWriter, r *http.Request) { uid, nick, ok := s.requireAuth(w, r) if !ok { return } - channels, err := s.params.Database.ListChannels(r.Context(), uid) + + channels, err := s.params.Database.ListChannels( + r.Context(), uid, + ) if err != nil { - s.log.Error("list channels failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + s.internalError( + w, r, "list channels failed", err, + ) + return } - s.respondJSON(w, r, &response{ID: uid, Nick: nick, Channels: channels}, http.StatusOK) + + s.respondJSON( + w, r, + &response{ + ID: uid, Nick: nick, + Channels: channels, + }, + http.StatusOK, + ) } } @@ -94,12 +196,18 @@ func (s *Handlers) HandleListAllChannels() http.HandlerFunc { if !ok { return } - channels, err := s.params.Database.ListAllChannels(r.Context()) + + channels, err := s.params.Database.ListAllChannels( + r.Context(), + ) if err != nil { - s.log.Error("list all channels failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + s.internalError( + w, r, "list all channels failed", err, + ) + return } + s.respondJSON(w, r, channels, http.StatusOK) } } @@ -111,286 +219,570 @@ func (s *Handlers) HandleChannelMembers() http.HandlerFunc { if !ok { return } + name := "#" + chi.URLParam(r, "channel") + var chID int64 - err := s.params.Database.GetDB().QueryRowContext(r.Context(), - "SELECT id FROM channels WHERE name = ?", name).Scan(&chID) + + err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec // parameterized query + r.Context(), + "SELECT id FROM channels WHERE name = ?", + name, + ).Scan(&chID) if err != nil { - s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) + s.respondError( + w, r, "channel not found", + http.StatusNotFound, + ) + return } - members, err := s.params.Database.ChannelMembers(r.Context(), chID) + + members, err := s.params.Database.ChannelMembers( + r.Context(), chID, + ) if err != nil { - s.log.Error("channel members failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + s.internalError( + w, r, "channel members failed", err, + ) + return } + s.respondJSON(w, r, members, http.StatusOK) } } -// HandleGetMessages returns all new messages (channel + DM) for the user via long-polling. -// This is the single unified message stream — replaces separate channel/DM/poll endpoints. +// HandleGetMessages returns all new messages (channel + DM) +// for the user via long-polling. func (s *Handlers) HandleGetMessages() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { uid, _, ok := s.requireAuth(w, r) if !ok { return } - afterID, _ := strconv.ParseInt(r.URL.Query().Get("after"), 10, 64) - limit, _ := strconv.Atoi(r.URL.Query().Get("limit")) - msgs, err := s.params.Database.PollMessages(r.Context(), uid, afterID, limit) + + afterID, _ := strconv.ParseInt( + r.URL.Query().Get("after"), 10, 64, + ) + + limit, _ := strconv.Atoi( + r.URL.Query().Get("limit"), + ) + + msgs, err := s.params.Database.PollMessages( + r.Context(), uid, afterID, limit, + ) if err != nil { - s.log.Error("get messages failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + s.internalError( + w, r, "get messages failed", err, + ) + return } + s.respondJSON(w, r, msgs, http.StatusOK) } } -// HandleSendCommand handles all C2S commands via POST /messages. -// The "command" field dispatches to the appropriate logic. +type sendRequest struct { + Command string `json:"command"` + To string `json:"to"` + Params []string `json:"params,omitempty"` + Body any `json:"body,omitempty"` +} + +// HandleSendCommand handles all C2S commands via POST +// /messages. func (s *Handlers) HandleSendCommand() http.HandlerFunc { - type request struct { - Command string `json:"command"` - To string `json:"to"` - Params []string `json:"params,omitempty"` - Body interface{} `json:"body,omitempty"` - } return func(w http.ResponseWriter, r *http.Request) { uid, nick, ok := s.requireAuth(w, r) if !ok { return } - var req request - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest) + + var req sendRequest + + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + s.respondError( + w, r, "invalid request", + http.StatusBadRequest, + ) + return } - req.Command = strings.ToUpper(strings.TrimSpace(req.Command)) + + req.Command = strings.ToUpper( + strings.TrimSpace(req.Command), + ) req.To = strings.TrimSpace(req.To) - // Helper to extract body as string lines. - bodyLines := func() []string { - switch v := req.Body.(type) { - case []interface{}: - lines := make([]string, 0, len(v)) - for _, item := range v { - if s, ok := item.(string); ok { - lines = append(lines, s) - } - } - return lines - case []string: - return v - default: - return nil - } - } - - switch req.Command { - case "PRIVMSG", "NOTICE": - if req.To == "" { - s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) - return - } - lines := bodyLines() - if len(lines) == 0 { - s.respondJSON(w, r, map[string]string{"error": "body required"}, http.StatusBadRequest) - return - } - content := strings.Join(lines, "\n") - - if strings.HasPrefix(req.To, "#") { - // Channel message - var chID int64 - err := s.params.Database.GetDB().QueryRowContext(r.Context(), - "SELECT id FROM channels WHERE name = ?", req.To).Scan(&chID) - if err != nil { - s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) - return - } - msgID, err := s.params.Database.SendMessage(r.Context(), chID, uid, content) - if err != nil { - s.log.Error("send message failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated) - } else { - // DM - targetID, err := s.params.Database.GetUserByNick(r.Context(), req.To) - if err != nil { - s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound) - return - } - msgID, err := s.params.Database.SendDM(r.Context(), uid, targetID, content) - if err != nil { - s.log.Error("send dm failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated) - } - - case "JOIN": - if req.To == "" { - s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) - return - } - channel := req.To - if !strings.HasPrefix(channel, "#") { - channel = "#" + channel - } - chID, err := s.params.Database.GetOrCreateChannel(r.Context(), channel) - if err != nil { - s.log.Error("get/create channel failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - if err := s.params.Database.JoinChannel(r.Context(), chID, uid); err != nil { - s.log.Error("join channel failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - s.respondJSON(w, r, map[string]string{"status": "joined", "channel": channel}, http.StatusOK) - - case "PART": - if req.To == "" { - s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) - return - } - channel := req.To - if !strings.HasPrefix(channel, "#") { - channel = "#" + channel - } - var chID int64 - err := s.params.Database.GetDB().QueryRowContext(r.Context(), - "SELECT id FROM channels WHERE name = ?", channel).Scan(&chID) - if err != nil { - s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) - return - } - if err := s.params.Database.PartChannel(r.Context(), chID, uid); err != nil { - s.log.Error("part channel failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - s.respondJSON(w, r, map[string]string{"status": "parted", "channel": channel}, http.StatusOK) - - case "NICK": - lines := bodyLines() - if len(lines) == 0 { - s.respondJSON(w, r, map[string]string{"error": "body required (new nick)"}, http.StatusBadRequest) - return - } - newNick := strings.TrimSpace(lines[0]) - if newNick == "" || len(newNick) > 32 { - s.respondJSON(w, r, map[string]string{"error": "nick must be 1-32 characters"}, http.StatusBadRequest) - return - } - if err := s.params.Database.ChangeNick(r.Context(), uid, newNick); err != nil { - if strings.Contains(err.Error(), "UNIQUE") { - s.respondJSON(w, r, map[string]string{"error": "nick already in use"}, http.StatusConflict) - return - } - s.log.Error("change nick failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - s.respondJSON(w, r, map[string]string{"status": "ok", "nick": newNick}, http.StatusOK) - - case "TOPIC": - if req.To == "" { - s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) - return - } - lines := bodyLines() - if len(lines) == 0 { - s.respondJSON(w, r, map[string]string{"error": "body required (topic text)"}, http.StatusBadRequest) - return - } - topic := strings.Join(lines, " ") - channel := req.To - if !strings.HasPrefix(channel, "#") { - channel = "#" + channel - } - if err := s.params.Database.SetTopic(r.Context(), channel, uid, topic); err != nil { - s.log.Error("set topic failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - s.respondJSON(w, r, map[string]string{"status": "ok", "topic": topic}, http.StatusOK) - - case "PING": - s.respondJSON(w, r, map[string]string{"command": "PONG", "from": s.params.Config.ServerName}, http.StatusOK) - - default: - _ = nick // suppress unused warning - s.respondJSON(w, r, map[string]string{"error": "unknown command: " + req.Command}, http.StatusBadRequest) - } + s.dispatchCommand(w, r, uid, nick, &req) } } -// HandleGetHistory returns message history for a specific target (channel or DM). +func (s *Handlers) dispatchCommand( + w http.ResponseWriter, + r *http.Request, + uid int64, + nick string, + req *sendRequest, +) { + switch req.Command { + case "PRIVMSG", "NOTICE": + s.handlePrivmsg(w, r, uid, req) + case "JOIN": + s.handleJoin(w, r, uid, req) + case "PART": + s.handlePart(w, r, uid, req) + case "NICK": + s.handleNick(w, r, uid, req) + case "TOPIC": + s.handleTopic(w, r, uid, req) + case "PING": + s.respondJSON( + w, r, + map[string]string{ + "command": "PONG", + "from": s.params.Config.ServerName, + }, + http.StatusOK, + ) + default: + _ = nick + + s.respondError( + w, r, + "unknown command: "+req.Command, + http.StatusBadRequest, + ) + } +} + +func (s *Handlers) handlePrivmsg( + w http.ResponseWriter, + r *http.Request, + uid int64, + req *sendRequest, +) { + if req.To == "" { + s.respondError( + w, r, "to field required", + http.StatusBadRequest, + ) + + return + } + + lines := bodyLines(req.Body) + if len(lines) == 0 { + s.respondError( + w, r, "body required", http.StatusBadRequest, + ) + + return + } + + content := strings.Join(lines, "\n") + + if strings.HasPrefix(req.To, "#") { + s.sendChannelMsg(w, r, uid, req.To, content) + } else { + s.sendDM(w, r, uid, req.To, content) + } +} + +func (s *Handlers) sendChannelMsg( + w http.ResponseWriter, + r *http.Request, + uid int64, + channel, content string, +) { + var chID int64 + + err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec // parameterized query + r.Context(), + "SELECT id FROM channels WHERE name = ?", + channel, + ).Scan(&chID) + if err != nil { + s.respondError( + w, r, "channel not found", + http.StatusNotFound, + ) + + return + } + + msgID, err := s.params.Database.SendMessage( + r.Context(), chID, uid, content, + ) + if err != nil { + s.internalError(w, r, "send message failed", err) + + return + } + + s.respondJSON( + w, r, + map[string]any{"id": msgID, "status": "sent"}, + http.StatusCreated, + ) +} + +func (s *Handlers) sendDM( + w http.ResponseWriter, + r *http.Request, + uid int64, + toNick, content string, +) { + targetID, err := s.params.Database.GetUserByNick( + r.Context(), toNick, + ) + if err != nil { + s.respondError( + w, r, "user not found", http.StatusNotFound, + ) + + return + } + + msgID, err := s.params.Database.SendDM( + r.Context(), uid, targetID, content, + ) + if err != nil { + s.internalError(w, r, "send dm failed", err) + + return + } + + s.respondJSON( + w, r, + map[string]any{"id": msgID, "status": "sent"}, + http.StatusCreated, + ) +} + +func (s *Handlers) handleJoin( + w http.ResponseWriter, + r *http.Request, + uid int64, + req *sendRequest, +) { + if req.To == "" { + s.respondError( + w, r, "to field required", + http.StatusBadRequest, + ) + + return + } + + channel := req.To + if !strings.HasPrefix(channel, "#") { + channel = "#" + channel + } + + chID, err := s.params.Database.GetOrCreateChannel( + r.Context(), channel, + ) + if err != nil { + s.internalError( + w, r, "get/create channel failed", err, + ) + + return + } + + err = s.params.Database.JoinChannel( + r.Context(), chID, uid, + ) + if err != nil { + s.internalError(w, r, "join channel failed", err) + + return + } + + s.respondJSON( + w, r, + map[string]string{ + "status": "joined", "channel": channel, + }, + http.StatusOK, + ) +} + +func (s *Handlers) handlePart( + w http.ResponseWriter, + r *http.Request, + uid int64, + req *sendRequest, +) { + if req.To == "" { + s.respondError( + w, r, "to field required", + http.StatusBadRequest, + ) + + return + } + + channel := req.To + if !strings.HasPrefix(channel, "#") { + channel = "#" + channel + } + + var chID int64 + + err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec // parameterized query + r.Context(), + "SELECT id FROM channels WHERE name = ?", + channel, + ).Scan(&chID) + if err != nil { + s.respondError( + w, r, "channel not found", + http.StatusNotFound, + ) + + return + } + + err = s.params.Database.PartChannel( + r.Context(), chID, uid, + ) + if err != nil { + s.internalError(w, r, "part channel failed", err) + + return + } + + s.respondJSON( + w, r, + map[string]string{ + "status": "parted", "channel": channel, + }, + http.StatusOK, + ) +} + +func (s *Handlers) handleNick( + w http.ResponseWriter, + r *http.Request, + uid int64, + req *sendRequest, +) { + lines := bodyLines(req.Body) + if len(lines) == 0 { + s.respondError( + w, r, "body required (new nick)", + http.StatusBadRequest, + ) + + return + } + + newNick := strings.TrimSpace(lines[0]) + if newNick == "" || len(newNick) > maxNickLen { + s.respondError( + w, r, "nick must be 1-32 characters", + http.StatusBadRequest, + ) + + return + } + + err := s.params.Database.ChangeNick( + r.Context(), uid, newNick, + ) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE") { + s.respondError( + w, r, "nick already in use", + http.StatusConflict, + ) + + return + } + + s.internalError(w, r, "change nick failed", err) + + return + } + + s.respondJSON( + w, r, + map[string]string{"status": "ok", "nick": newNick}, + http.StatusOK, + ) +} + +func (s *Handlers) handleTopic( + w http.ResponseWriter, + r *http.Request, + uid int64, + req *sendRequest, +) { + if req.To == "" { + s.respondError( + w, r, "to field required", + http.StatusBadRequest, + ) + + return + } + + lines := bodyLines(req.Body) + if len(lines) == 0 { + s.respondError( + w, r, "body required (topic text)", + http.StatusBadRequest, + ) + + return + } + + topic := strings.Join(lines, " ") + + channel := req.To + if !strings.HasPrefix(channel, "#") { + channel = "#" + channel + } + + err := s.params.Database.SetTopic( + r.Context(), channel, uid, topic, + ) + if err != nil { + s.internalError(w, r, "set topic failed", err) + + return + } + + s.respondJSON( + w, r, + map[string]string{"status": "ok", "topic": topic}, + http.StatusOK, + ) +} + +// HandleGetHistory returns message history for a specific +// target (channel or DM). func (s *Handlers) HandleGetHistory() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { uid, _, ok := s.requireAuth(w, r) if !ok { return } + target := r.URL.Query().Get("target") if target == "" { - s.respondJSON(w, r, map[string]string{"error": "target required"}, http.StatusBadRequest) + s.respondError( + w, r, "target required", + http.StatusBadRequest, + ) + return } - beforeID, _ := strconv.ParseInt(r.URL.Query().Get("before"), 10, 64) - limit, _ := strconv.Atoi(r.URL.Query().Get("limit")) + + beforeID, _ := strconv.ParseInt( + r.URL.Query().Get("before"), 10, 64, + ) + + limit, _ := strconv.Atoi( + r.URL.Query().Get("limit"), + ) if limit <= 0 { - limit = 50 + limit = defaultHistory } if strings.HasPrefix(target, "#") { - // Channel history - var chID int64 - err := s.params.Database.GetDB().QueryRowContext(r.Context(), - "SELECT id FROM channels WHERE name = ?", target).Scan(&chID) - if err != nil { - s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) - return - } - msgs, err := s.params.Database.GetMessagesBefore(r.Context(), chID, beforeID, limit) - if err != nil { - s.log.Error("get history failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - s.respondJSON(w, r, msgs, http.StatusOK) + s.getChannelHistory( + w, r, target, beforeID, limit, + ) } else { - // DM history - targetID, err := s.params.Database.GetUserByNick(r.Context(), target) - if err != nil { - s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound) - return - } - msgs, err := s.params.Database.GetDMsBefore(r.Context(), uid, targetID, beforeID, limit) - if err != nil { - s.log.Error("get dm history failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - s.respondJSON(w, r, msgs, http.StatusOK) + s.getDMHistory( + w, r, uid, target, beforeID, limit, + ) } } } +func (s *Handlers) getChannelHistory( + w http.ResponseWriter, + r *http.Request, + target string, + beforeID int64, + limit int, +) { + var chID int64 + + err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec // parameterized query + r.Context(), + "SELECT id FROM channels WHERE name = ?", + target, + ).Scan(&chID) + if err != nil { + s.respondError( + w, r, "channel not found", + http.StatusNotFound, + ) + + return + } + + msgs, err := s.params.Database.GetMessagesBefore( + r.Context(), chID, beforeID, limit, + ) + if err != nil { + s.internalError(w, r, "get history failed", err) + + return + } + + s.respondJSON(w, r, msgs, http.StatusOK) +} + +func (s *Handlers) getDMHistory( + w http.ResponseWriter, + r *http.Request, + uid int64, + target string, + beforeID int64, + limit int, +) { + targetID, err := s.params.Database.GetUserByNick( + r.Context(), target, + ) + if err != nil { + s.respondError( + w, r, "user not found", http.StatusNotFound, + ) + + return + } + + msgs, err := s.params.Database.GetDMsBefore( + r.Context(), uid, targetID, beforeID, limit, + ) + if err != nil { + s.internalError( + w, r, "get dm history failed", err, + ) + + return + } + + s.respondJSON(w, r, msgs, http.StatusOK) +} + // HandleServerInfo returns server metadata (MOTD, name). func (s *Handlers) HandleServerInfo() http.HandlerFunc { type response struct { Name string `json:"name"` MOTD string `json:"motd"` } + return func(w http.ResponseWriter, r *http.Request) { s.respondJSON(w, r, &response{ Name: s.params.Config.ServerName, diff --git a/internal/models/auth_token.go b/internal/models/auth_token.go index f1646e2..c2c3fd1 100644 --- a/internal/models/auth_token.go +++ b/internal/models/auth_token.go @@ -2,7 +2,6 @@ package models import ( "context" - "fmt" "time" ) @@ -23,5 +22,5 @@ func (t *AuthToken) User(ctx context.Context) (*User, error) { return ul.GetUserByID(ctx, t.UserID) } - return nil, fmt.Errorf("user lookup not available") + return nil, ErrUserLookupNotAvailable } diff --git a/internal/models/channel_member.go b/internal/models/channel_member.go index f93ed9d..59586c7 100644 --- a/internal/models/channel_member.go +++ b/internal/models/channel_member.go @@ -2,7 +2,6 @@ package models import ( "context" - "fmt" "time" ) @@ -23,7 +22,7 @@ func (cm *ChannelMember) User(ctx context.Context) (*User, error) { return ul.GetUserByID(ctx, cm.UserID) } - return nil, fmt.Errorf("user lookup not available") + return nil, ErrUserLookupNotAvailable } // Channel returns the full Channel for this membership. @@ -32,5 +31,5 @@ func (cm *ChannelMember) Channel(ctx context.Context) (*Channel, error) { return cl.GetChannelByID(ctx, cm.ChannelID) } - return nil, fmt.Errorf("channel lookup not available") + return nil, ErrChannelLookupNotAvailable } diff --git a/internal/models/model.go b/internal/models/model.go index b65c6e8..302f6b3 100644 --- a/internal/models/model.go +++ b/internal/models/model.go @@ -6,6 +6,7 @@ package models import ( "context" "database/sql" + "errors" ) // DB is the interface that models use to query the database. @@ -24,6 +25,12 @@ type ChannelLookup interface { GetChannelByID(ctx context.Context, id string) (*Channel, error) } +// Sentinel errors for model lookup methods. +var ( + ErrUserLookupNotAvailable = errors.New("user lookup not available") + ErrChannelLookupNotAvailable = errors.New("channel lookup not available") +) + // Base is embedded in all model structs to provide database access. type Base struct { db DB @@ -40,7 +47,7 @@ func (b *Base) GetDB() *sql.DB { } // GetUserLookup returns the DB as a UserLookup if it implements the interface. -func (b *Base) GetUserLookup() UserLookup { +func (b *Base) GetUserLookup() UserLookup { //nolint:ireturn // interface return is intentional if ul, ok := b.db.(UserLookup); ok { return ul } @@ -49,7 +56,7 @@ func (b *Base) GetUserLookup() UserLookup { } // GetChannelLookup returns the DB as a ChannelLookup if it implements the interface. -func (b *Base) GetChannelLookup() ChannelLookup { +func (b *Base) GetChannelLookup() ChannelLookup { //nolint:ireturn // interface return is intentional if cl, ok := b.db.(ChannelLookup); ok { return cl } diff --git a/internal/models/session.go b/internal/models/session.go index 42231d9..295def2 100644 --- a/internal/models/session.go +++ b/internal/models/session.go @@ -2,7 +2,6 @@ package models import ( "context" - "fmt" "time" ) @@ -23,5 +22,5 @@ func (s *Session) User(ctx context.Context) (*User, error) { return ul.GetUserByID(ctx, s.UserID) } - return nil, fmt.Errorf("user lookup not available") + return nil, ErrUserLookupNotAvailable } diff --git a/internal/server/routes.go b/internal/server/routes.go index e211492..9a70e3c 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -71,17 +71,37 @@ func (s *Server) SetupRoutes() { s.log.Error("failed to get web dist filesystem", "error", err) } else { fileServer := http.FileServer(http.FS(distFS)) + s.router.Get("/*", func(w http.ResponseWriter, r *http.Request) { - // Try to serve the file; if not found, serve index.html for SPA routing - f, err := distFS.(fs.ReadFileFS).ReadFile(r.URL.Path[1:]) - if err != nil || len(f) == 0 { - indexHTML, _ := distFS.(fs.ReadFileFS).ReadFile("index.html") - w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.WriteHeader(http.StatusOK) - _, _ = w.Write(indexHTML) - return - } - fileServer.ServeHTTP(w, r) + s.serveSPA(distFS, fileServer, w, r) }) } } + +func (s *Server) serveSPA( + distFS fs.FS, + fileServer http.Handler, + w http.ResponseWriter, + r *http.Request, +) { + readFS, ok := distFS.(fs.ReadFileFS) + if !ok { + http.Error(w, "filesystem error", http.StatusInternalServerError) + + return + } + + // Try to serve the file; fall back to index.html for SPA routing. + f, err := readFS.ReadFile(r.URL.Path[1:]) + if err != nil || len(f) == 0 { + indexHTML, _ := readFS.ReadFile("index.html") + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(indexHTML) + + return + } + + fileServer.ServeHTTP(w, r) +}