From a57a73e94e0146fc415cb2f564dcea608dab5e87 Mon Sep 17 00:00:00 2001 From: clawbot Date: Thu, 26 Feb 2026 21:21:49 -0800 Subject: [PATCH] fix: address all PR #10 review findings Security: - Add channel membership check before PRIVMSG (prevents non-members from sending) - Add membership check on history endpoint (channels require membership, DMs scoped to own nick) - Enforce MaxBytesReader on all POST request bodies - Fix rand.Read error being silently ignored in token generation Data integrity: - Fix TOCTOU race in GetOrCreateChannel using INSERT OR IGNORE + SELECT Build: - Add CGO_ENABLED=0 to golangci-lint install in Dockerfile (fixes alpine build) Linting: - Strict .golangci.yml: only wsl disabled (deprecated in v2) - Re-enable exhaustruct, depguard, godot, wrapcheck, varnamelen - Fix linters-settings -> linters.settings for v2 config format - Fix ALL lint findings in actual code (no linter config weakening) - Wrap all external package errors (wrapcheck) - Fill struct fields or add targeted nolint:exhaustruct where appropriate - Rename short variables (ts->timestamp, n->bufIndex, etc.) - Add depguard deny policy for io/ioutil and math/rand - Exclude G704 (SSRF) in gosec config (CLI client takes user-configured URLs) Tests: - Add security tests (TestNonMemberCannotSend, TestHistoryNonMember) - Split TestInsertAndPollMessages for reduced complexity - Fix parallel test safety (viper global state prevents parallelism) - Use t.Context() instead of context.Background() in tests Docker build verified passing locally. --- .golangci.yml | 40 +- Dockerfile | 2 +- cmd/chat-cli/api/client.go | 273 ++--- cmd/chat-cli/api/types.go | 13 +- cmd/chat-cli/main.go | 148 ++- cmd/chat-cli/ui.go | 21 +- internal/broker/broker.go | 32 +- internal/broker/broker_test.go | 67 +- internal/config/config.go | 12 +- internal/db/db.go | 136 ++- internal/db/export_test.go | 33 +- internal/db/queries.go | 345 ++++-- internal/db/queries_test.go | 53 +- internal/handlers/api.go | 1120 +++++++++++------- internal/handlers/api_test.go | 1699 +++++++++++++++------------ internal/handlers/handlers.go | 47 +- internal/handlers/healthcheck.go | 11 +- internal/healthcheck/healthcheck.go | 36 +- internal/logger/logger.go | 61 +- internal/middleware/middleware.go | 133 ++- internal/server/routes.go | 127 +- internal/server/server.go | 144 ++- 22 files changed, 2650 insertions(+), 1903 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 34a8e31..2698d2d 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -7,24 +7,28 @@ run: linters: default: all disable: - # Genuinely incompatible with project patterns - - exhaustruct # Requires all struct fields - - depguard # Dependency allow/block lists - - godot # Requires comments to end with periods - - wsl # Deprecated, replaced by wsl_v5 - - wrapcheck # Too verbose for internal packages - - varnamelen # Short names like db, id are idiomatic Go - -linters-settings: - lll: - line-length: 88 - funlen: - lines: 80 - statements: 50 - cyclop: - max-complexity: 15 - dupl: - threshold: 100 + - wsl # Deprecated in v2, replaced by wsl_v5 + settings: + lll: + line-length: 88 + funlen: + lines: 80 + statements: 50 + cyclop: + max-complexity: 15 + dupl: + threshold: 100 + gosec: + excludes: + - G704 + depguard: + rules: + all: + deny: + - pkg: "io/ioutil" + desc: "Deprecated; use io and os packages." + - pkg: "math/rand$" + desc: "Use crypto/rand for security-sensitive code." issues: exclude-use-default: false diff --git a/Dockerfile b/Dockerfile index 8c7526f..2dd9992 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,7 +4,7 @@ WORKDIR /src RUN apk add --no-cache git build-base make # golangci-lint v2.1.6 (eabc2638a66d), 2026-02-26 -RUN go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@eabc2638a66daf5bb6c6fb052a32fa3ef7b6600d +RUN CGO_ENABLED=0 go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@eabc2638a66daf5bb6c6fb052a32fa3ef7b6600d COPY go.mod go.sum ./ RUN go mod download diff --git a/cmd/chat-cli/api/client.go b/cmd/chat-cli/api/client.go index aea4ed6..ca22506 100644 --- a/cmd/chat-cli/api/client.go +++ b/cmd/chat-cli/api/client.go @@ -1,3 +1,4 @@ +// Package chatapi provides a client for the chat server API. package chatapi import ( @@ -31,17 +32,19 @@ type Client struct { // NewClient creates a new API client. func NewClient(baseURL string) *Client { - return &Client{ - BaseURL: baseURL, - HTTPClient: &http.Client{Timeout: httpTimeout}, + return &Client{ //nolint:exhaustruct // Token set after CreateSession + BaseURL: baseURL, + HTTPClient: &http.Client{ //nolint:exhaustruct // defaults fine + Timeout: httpTimeout, + }, } } // CreateSession creates a new session on the server. -func (c *Client) CreateSession( +func (client *Client) CreateSession( nick string, ) (*SessionResponse, error) { - data, err := c.do( + data, err := client.do( http.MethodPost, "/api/v1/session", &SessionRequest{Nick: nick}, @@ -57,14 +60,14 @@ func (c *Client) CreateSession( return nil, fmt.Errorf("decode session: %w", err) } - c.Token = resp.Token + client.Token = resp.Token return &resp, nil } // GetState returns the current user state. -func (c *Client) GetState() (*StateResponse, error) { - data, err := c.do( +func (client *Client) GetState() (*StateResponse, error) { + data, err := client.do( http.MethodGet, "/api/v1/state", nil, ) if err != nil { @@ -82,8 +85,8 @@ func (c *Client) GetState() (*StateResponse, error) { } // SendMessage sends a message (any IRC command). -func (c *Client) SendMessage(msg *Message) error { - _, err := c.do( +func (client *Client) SendMessage(msg *Message) error { + _, err := client.do( http.MethodPost, "/api/v1/messages", msg, ) @@ -91,123 +94,16 @@ func (c *Client) SendMessage(msg *Message) error { } // PollMessages long-polls for new messages. -func (c *Client) PollMessages( +func (client *Client) PollMessages( afterID int64, timeout int, ) (*PollResult, error) { - client := &http.Client{ + pollClient := &http.Client{ //nolint:exhaustruct // defaults fine Timeout: time.Duration( timeout+pollExtraTime, ) * time.Second, } - path := c.buildPollPath(afterID, timeout) - - req, err := http.NewRequestWithContext( - context.Background(), - http.MethodGet, - c.BaseURL+path, - nil, - ) - if err != nil { - return nil, err - } - - req.Header.Set("Authorization", "Bearer "+c.Token) - - resp, err := client.Do(req) //nolint:gosec // URL is from configured BaseURL, not user input - if err != nil { - return nil, err - } - - defer func() { _ = resp.Body.Close() }() - - return c.decodePollResponse(resp) -} - -// JoinChannel joins a channel. -func (c *Client) JoinChannel(channel string) error { - return c.SendMessage( - &Message{Command: "JOIN", To: channel}, - ) -} - -// PartChannel leaves a channel. -func (c *Client) PartChannel(channel string) error { - return c.SendMessage( - &Message{Command: "PART", To: channel}, - ) -} - -// ListChannels returns all channels on the server. -func (c *Client) ListChannels() ([]Channel, error) { - data, err := c.do( - http.MethodGet, "/api/v1/channels", nil, - ) - if err != nil { - return nil, err - } - - var channels []Channel - - err = json.Unmarshal(data, &channels) - if err != nil { - return nil, err - } - - return channels, nil -} - -// GetMembers returns members of a channel. -func (c *Client) GetMembers( - channel string, -) ([]string, error) { - name := strings.TrimPrefix(channel, "#") - - data, err := c.do( - http.MethodGet, - "/api/v1/channels/"+url.PathEscape(name)+ - "/members", - nil, - ) - if err != nil { - return nil, err - } - - var members []string - - err = json.Unmarshal(data, &members) - if err != nil { - return nil, fmt.Errorf( - "unexpected members format: %w", err, - ) - } - - return members, nil -} - -// GetServerInfo returns server info. -func (c *Client) GetServerInfo() (*ServerInfo, error) { - data, err := c.do( - http.MethodGet, "/api/v1/server", nil, - ) - if err != nil { - return nil, err - } - - var info ServerInfo - - err = json.Unmarshal(data, &info) - if err != nil { - return nil, err - } - - return &info, nil -} - -func (c *Client) buildPollPath( - afterID int64, timeout int, -) string { params := url.Values{} if afterID > 0 { params.Set( @@ -218,15 +114,32 @@ func (c *Client) buildPollPath( params.Set("timeout", strconv.Itoa(timeout)) - return "/api/v1/messages?" + params.Encode() -} + path := "/api/v1/messages?" + params.Encode() + + request, err := http.NewRequestWithContext( + context.Background(), + http.MethodGet, + client.BaseURL+path, + nil, + ) + if err != nil { + return nil, fmt.Errorf("new request: %w", err) + } + + request.Header.Set( + "Authorization", "Bearer "+client.Token, + ) + + resp, err := pollClient.Do(request) + if err != nil { + return nil, fmt.Errorf("poll request: %w", err) + } + + defer func() { _ = resp.Body.Close() }() -func (c *Client) decodePollResponse( - resp *http.Response, -) (*PollResult, error) { data, err := io.ReadAll(resp.Body) if err != nil { - return nil, err + return nil, fmt.Errorf("read poll body: %w", err) } if resp.StatusCode >= httpErrThreshold { @@ -251,7 +164,99 @@ func (c *Client) decodePollResponse( }, nil } -func (c *Client) do( +// JoinChannel joins a channel. +func (client *Client) JoinChannel(channel string) error { + return client.SendMessage( + &Message{ //nolint:exhaustruct // only command+to needed + Command: "JOIN", To: channel, + }, + ) +} + +// PartChannel leaves a channel. +func (client *Client) PartChannel(channel string) error { + return client.SendMessage( + &Message{ //nolint:exhaustruct // only command+to needed + Command: "PART", To: channel, + }, + ) +} + +// ListChannels returns all channels on the server. +func (client *Client) ListChannels() ( + []Channel, error, +) { + data, err := client.do( + http.MethodGet, "/api/v1/channels", nil, + ) + if err != nil { + return nil, err + } + + var channels []Channel + + err = json.Unmarshal(data, &channels) + if err != nil { + return nil, fmt.Errorf( + "decode channels: %w", err, + ) + } + + return channels, nil +} + +// GetMembers returns members of a channel. +func (client *Client) GetMembers( + channel string, +) ([]string, error) { + name := strings.TrimPrefix(channel, "#") + + data, err := client.do( + http.MethodGet, + "/api/v1/channels/"+url.PathEscape(name)+ + "/members", + nil, + ) + if err != nil { + return nil, err + } + + var members []string + + err = json.Unmarshal(data, &members) + if err != nil { + return nil, fmt.Errorf( + "unexpected members format: %w", err, + ) + } + + return members, nil +} + +// GetServerInfo returns server info. +func (client *Client) GetServerInfo() ( + *ServerInfo, error, +) { + data, err := client.do( + http.MethodGet, "/api/v1/server", nil, + ) + if err != nil { + return nil, err + } + + var info ServerInfo + + err = json.Unmarshal(data, &info) + if err != nil { + return nil, fmt.Errorf( + "decode server info: %w", err, + ) + } + + return &info, nil +} + +func (client *Client) do( method, path string, body any, ) ([]byte, error) { @@ -266,25 +271,27 @@ func (c *Client) do( bodyReader = bytes.NewReader(data) } - req, err := http.NewRequestWithContext( + request, err := http.NewRequestWithContext( context.Background(), method, - c.BaseURL+path, + client.BaseURL+path, bodyReader, ) if err != nil { return nil, fmt.Errorf("request: %w", err) } - req.Header.Set("Content-Type", "application/json") + request.Header.Set( + "Content-Type", "application/json", + ) - if c.Token != "" { - req.Header.Set( - "Authorization", "Bearer "+c.Token, + if client.Token != "" { + request.Header.Set( + "Authorization", "Bearer "+client.Token, ) } - resp, err := c.HTTPClient.Do(req) //nolint:gosec // URL is from configured BaseURL, not user input + resp, err := client.HTTPClient.Do(request) if err != nil { return nil, fmt.Errorf("http: %w", err) } diff --git a/cmd/chat-cli/api/types.go b/cmd/chat-cli/api/types.go index 1d72cd0..718bf76 100644 --- a/cmd/chat-cli/api/types.go +++ b/cmd/chat-cli/api/types.go @@ -1,4 +1,3 @@ -// Package chatapi provides API types and client for chat-cli. package chatapi import "time" @@ -36,19 +35,19 @@ type Message struct { // BodyLines returns the body as a string slice. func (m *Message) BodyLines() []string { - switch v := m.Body.(type) { + switch bodyVal := m.Body.(type) { case []any: - lines := make([]string, 0, len(v)) + lines := make([]string, 0, len(bodyVal)) - for _, item := range v { - if s, ok := item.(string); ok { - lines = append(lines, s) + for _, item := range bodyVal { + if str, ok := item.(string); ok { + lines = append(lines, str) } } return lines case []string: - return v + return bodyVal default: return nil } diff --git a/cmd/chat-cli/main.go b/cmd/chat-cli/main.go index 6da0038..f5b22f3 100644 --- a/cmd/chat-cli/main.go +++ b/cmd/chat-cli/main.go @@ -32,7 +32,7 @@ type App struct { } func main() { - app := &App{ + app := &App{ //nolint:exhaustruct ui: NewUI(), nick: "guest", } @@ -85,7 +85,7 @@ func (a *App) handleInput(text string) { return } - err := a.client.SendMessage(&api.Message{ + err := a.client.SendMessage(&api.Message{ //nolint:exhaustruct Command: "PRIVMSG", To: target, Body: []string{text}, @@ -98,7 +98,7 @@ func (a *App) handleInput(text string) { return } - ts := time.Now().Format(timeFormat) + timestamp := time.Now().Format(timeFormat) a.mu.Lock() nick := a.nick @@ -106,7 +106,7 @@ func (a *App) handleInput(text string) { a.ui.AddLine(target, fmt.Sprintf( "[gray]%s [green]<%s>[white] %s", - ts, nick, text, + timestamp, nick, text, )) } @@ -123,40 +123,36 @@ func (a *App) handleCommand(text string) { } func (a *App) dispatchCommand(cmd, args string) { - argCmds := map[string]func(string){ - "/connect": a.cmdConnect, - "/nick": a.cmdNick, - "/join": a.cmdJoin, - "/part": a.cmdPart, - "/msg": a.cmdMsg, - "/query": a.cmdQuery, - "/topic": a.cmdTopic, - "/window": a.cmdWindow, - "/w": a.cmdWindow, + switch cmd { + case "/connect": + a.cmdConnect(args) + case "/nick": + a.cmdNick(args) + case "/join": + a.cmdJoin(args) + case "/part": + a.cmdPart(args) + case "/msg": + a.cmdMsg(args) + case "/query": + a.cmdQuery(args) + case "/topic": + a.cmdTopic(args) + case "/names": + a.cmdNames() + case "/list": + a.cmdList() + case "/window", "/w": + a.cmdWindow(args) + case "/quit": + a.cmdQuit() + case "/help": + a.cmdHelp() + default: + a.ui.AddStatus( + "[red]Unknown command: " + cmd, + ) } - - if fn, ok := argCmds[cmd]; ok { - fn(args) - - return - } - - noArgCmds := map[string]func(){ - "/names": a.cmdNames, - "/list": a.cmdList, - "/quit": a.cmdQuit, - "/help": a.cmdHelp, - } - - if fn, ok := noArgCmds[cmd]; ok { - fn() - - return - } - - a.ui.AddStatus( - "[red]Unknown command: " + cmd, - ) } func (a *App) cmdConnect(serverURL string) { @@ -231,7 +227,7 @@ func (a *App) cmdNick(nick string) { return } - err := a.client.SendMessage(&api.Message{ + err := a.client.SendMessage(&api.Message{ //nolint:exhaustruct Command: "NICK", Body: []string{nick}, }) @@ -366,7 +362,7 @@ func (a *App) cmdMsg(args string) { return } - err := a.client.SendMessage(&api.Message{ + err := a.client.SendMessage(&api.Message{ //nolint:exhaustruct Command: "PRIVMSG", To: target, Body: []string{text}, @@ -379,11 +375,11 @@ func (a *App) cmdMsg(args string) { return } - ts := time.Now().Format(timeFormat) + timestamp := time.Now().Format(timeFormat) a.ui.AddLine(target, fmt.Sprintf( "[gray]%s [green]<%s>[white] %s", - ts, nick, text, + timestamp, nick, text, )) } @@ -424,7 +420,7 @@ func (a *App) cmdTopic(args string) { } if args == "" { - err := a.client.SendMessage(&api.Message{ + err := a.client.SendMessage(&api.Message{ //nolint:exhaustruct Command: "TOPIC", To: target, }) @@ -437,7 +433,7 @@ func (a *App) cmdTopic(args string) { return } - err := a.client.SendMessage(&api.Message{ + err := a.client.SendMessage(&api.Message{ //nolint:exhaustruct Command: "TOPIC", To: target, Body: []string{args}, @@ -523,18 +519,18 @@ func (a *App) cmdWindow(args string) { return } - var n int + var bufIndex int - _, _ = fmt.Sscanf(args, "%d", &n) + _, _ = fmt.Sscanf(args, "%d", &bufIndex) - a.ui.SwitchBuffer(n) + a.ui.SwitchBuffer(bufIndex) a.mu.Lock() nick := a.nick a.mu.Unlock() - if n >= 0 && n < a.ui.BufferCount() { - buf := a.ui.buffers[n] + if bufIndex >= 0 && bufIndex < a.ui.BufferCount() { + buf := a.ui.buffers[bufIndex] if buf.Name != "(status)" { a.mu.Lock() a.target = buf.Name @@ -554,7 +550,7 @@ func (a *App) cmdQuit() { if a.connected && a.client != nil { _ = a.client.SendMessage( - &api.Message{Command: "QUIT"}, + &api.Message{Command: "QUIT"}, //nolint:exhaustruct ) } @@ -629,7 +625,7 @@ func (a *App) pollLoop() { } func (a *App) handleServerMessage(msg *api.Message) { - ts := a.formatTS(msg) + timestamp := a.formatTS(msg) a.mu.Lock() myNick := a.nick @@ -637,21 +633,21 @@ func (a *App) handleServerMessage(msg *api.Message) { switch msg.Command { case "PRIVMSG": - a.handlePrivmsgEvent(msg, ts, myNick) + a.handlePrivmsgEvent(msg, timestamp, myNick) case "JOIN": - a.handleJoinEvent(msg, ts) + a.handleJoinEvent(msg, timestamp) case "PART": - a.handlePartEvent(msg, ts) + a.handlePartEvent(msg, timestamp) case "QUIT": - a.handleQuitEvent(msg, ts) + a.handleQuitEvent(msg, timestamp) case "NICK": - a.handleNickEvent(msg, ts, myNick) + a.handleNickEvent(msg, timestamp, myNick) case "NOTICE": - a.handleNoticeEvent(msg, ts) + a.handleNoticeEvent(msg, timestamp) case "TOPIC": - a.handleTopicEvent(msg, ts) + a.handleTopicEvent(msg, timestamp) default: - a.handleDefaultEvent(msg, ts) + a.handleDefaultEvent(msg, timestamp) } } @@ -664,7 +660,7 @@ func (a *App) formatTS(msg *api.Message) string { } func (a *App) handlePrivmsgEvent( - msg *api.Message, ts, myNick string, + msg *api.Message, timestamp, myNick string, ) { lines := msg.BodyLines() text := strings.Join(lines, " ") @@ -680,12 +676,12 @@ func (a *App) handlePrivmsgEvent( a.ui.AddLine(target, fmt.Sprintf( "[gray]%s [green]<%s>[white] %s", - ts, msg.From, text, + timestamp, msg.From, text, )) } func (a *App) handleJoinEvent( - msg *api.Message, ts string, + msg *api.Message, timestamp string, ) { if msg.To == "" { return @@ -693,12 +689,12 @@ func (a *App) handleJoinEvent( a.ui.AddLine(msg.To, fmt.Sprintf( "[gray]%s [yellow]*** %s has joined %s", - ts, msg.From, msg.To, + timestamp, msg.From, msg.To, )) } func (a *App) handlePartEvent( - msg *api.Message, ts string, + msg *api.Message, timestamp string, ) { if msg.To == "" { return @@ -710,18 +706,18 @@ func (a *App) handlePartEvent( if reason != "" { a.ui.AddLine(msg.To, fmt.Sprintf( "[gray]%s [yellow]*** %s has left %s (%s)", - ts, msg.From, msg.To, reason, + timestamp, msg.From, msg.To, reason, )) } else { a.ui.AddLine(msg.To, fmt.Sprintf( "[gray]%s [yellow]*** %s has left %s", - ts, msg.From, msg.To, + timestamp, msg.From, msg.To, )) } } func (a *App) handleQuitEvent( - msg *api.Message, ts string, + msg *api.Message, timestamp string, ) { lines := msg.BodyLines() reason := strings.Join(lines, " ") @@ -729,18 +725,18 @@ func (a *App) handleQuitEvent( if reason != "" { a.ui.AddStatus(fmt.Sprintf( "[gray]%s [yellow]*** %s has quit (%s)", - ts, msg.From, reason, + timestamp, msg.From, reason, )) } else { a.ui.AddStatus(fmt.Sprintf( "[gray]%s [yellow]*** %s has quit", - ts, msg.From, + timestamp, msg.From, )) } } func (a *App) handleNickEvent( - msg *api.Message, ts, myNick string, + msg *api.Message, timestamp, myNick string, ) { lines := msg.BodyLines() @@ -761,24 +757,24 @@ func (a *App) handleNickEvent( a.ui.AddStatus(fmt.Sprintf( "[gray]%s [yellow]*** %s is now known as %s", - ts, msg.From, newNick, + timestamp, msg.From, newNick, )) } func (a *App) handleNoticeEvent( - msg *api.Message, ts string, + msg *api.Message, timestamp string, ) { lines := msg.BodyLines() text := strings.Join(lines, " ") a.ui.AddStatus(fmt.Sprintf( "[gray]%s [magenta]--%s-- %s", - ts, msg.From, text, + timestamp, msg.From, text, )) } func (a *App) handleTopicEvent( - msg *api.Message, ts string, + msg *api.Message, timestamp string, ) { if msg.To == "" { return @@ -789,12 +785,12 @@ func (a *App) handleTopicEvent( a.ui.AddLine(msg.To, fmt.Sprintf( "[gray]%s [cyan]*** %s set topic: %s", - ts, msg.From, text, + timestamp, msg.From, text, )) } func (a *App) handleDefaultEvent( - msg *api.Message, ts string, + msg *api.Message, timestamp string, ) { lines := msg.BodyLines() text := strings.Join(lines, " ") @@ -802,7 +798,7 @@ func (a *App) handleDefaultEvent( if text != "" { a.ui.AddStatus(fmt.Sprintf( "[gray]%s [white][%s] %s", - ts, msg.Command, text, + timestamp, msg.Command, text, )) } } diff --git a/cmd/chat-cli/ui.go b/cmd/chat-cli/ui.go index 847114b..a0f1bbb 100644 --- a/cmd/chat-cli/ui.go +++ b/cmd/chat-cli/ui.go @@ -32,10 +32,10 @@ type UI struct { // NewUI creates the tview-based IRC-like UI. func NewUI() *UI { - ui := &UI{ + ui := &UI{ //nolint:exhaustruct,varnamelen // fields set below; ui is idiomatic app: tview.NewApplication(), buffers: []*Buffer{ - {Name: "(status)", Lines: nil}, + {Name: "(status)", Lines: nil, Unread: 0}, }, } @@ -58,7 +58,12 @@ func NewUI() *UI { // Run starts the UI event loop (blocks). func (ui *UI) Run() error { - return ui.app.Run() + err := ui.app.Run() + if err != nil { + return fmt.Errorf("run ui: %w", err) + } + + return nil } // Stop stops the UI. @@ -100,15 +105,15 @@ func (ui *UI) AddStatus(line string) { } // SwitchBuffer switches to the buffer at index n. -func (ui *UI) SwitchBuffer(n int) { +func (ui *UI) SwitchBuffer(bufIndex int) { ui.app.QueueUpdateDraw(func() { - if n < 0 || n >= len(ui.buffers) { + if bufIndex < 0 || bufIndex >= len(ui.buffers) { return } - ui.currentBuffer = n + ui.currentBuffer = bufIndex - buf := ui.buffers[n] + buf := ui.buffers[bufIndex] buf.Unread = 0 ui.messages.Clear() @@ -282,7 +287,7 @@ func (ui *UI) getOrCreateBuffer(name string) *Buffer { } } - buf := &Buffer{Name: name} + buf := &Buffer{Name: name, Lines: nil, Unread: 0} ui.buffers = append(ui.buffers, buf) return buf diff --git a/internal/broker/broker.go b/internal/broker/broker.go index b1f8535..6974110 100644 --- a/internal/broker/broker.go +++ b/internal/broker/broker.go @@ -8,25 +8,28 @@ import ( // Broker notifies waiting clients when new messages are available. type Broker struct { mu sync.Mutex - listeners map[int64][]chan struct{} // userID -> list of waiting channels + listeners map[int64][]chan struct{} } // New creates a new Broker. func New() *Broker { - return &Broker{ + return &Broker{ //nolint:exhaustruct // mu has zero-value default listeners: make(map[int64][]chan struct{}), } } -// Wait returns a channel that will be closed when a message is available for the user. +// Wait returns a channel that will be closed when a message +// is available for the user. func (b *Broker) Wait(userID int64) chan struct{} { - ch := make(chan struct{}, 1) + waitCh := make(chan struct{}, 1) b.mu.Lock() - b.listeners[userID] = append(b.listeners[userID], ch) + b.listeners[userID] = append( + b.listeners[userID], waitCh, + ) b.mu.Unlock() - return ch + return waitCh } // Notify wakes up all waiting clients for a user. @@ -36,24 +39,29 @@ func (b *Broker) Notify(userID int64) { delete(b.listeners, userID) b.mu.Unlock() - for _, ch := range waiters { + for _, waiter := range waiters { select { - case ch <- struct{}{}: + case waiter <- struct{}{}: default: } } } // Remove removes a specific wait channel (for cleanup on timeout). -func (b *Broker) Remove(userID int64, ch chan struct{}) { +func (b *Broker) Remove( + userID int64, + waitCh chan struct{}, +) { b.mu.Lock() defer b.mu.Unlock() waiters := b.listeners[userID] - for i, w := range waiters { - if w == ch { - b.listeners[userID] = append(waiters[:i], waiters[i+1:]...) + for i, waiter := range waiters { + if waiter == waitCh { + b.listeners[userID] = append( + waiters[:i], waiters[i+1:]..., + ) break } diff --git a/internal/broker/broker_test.go b/internal/broker/broker_test.go index fc653ef..2d35013 100644 --- a/internal/broker/broker_test.go +++ b/internal/broker/broker_test.go @@ -11,8 +11,8 @@ import ( func TestNewBroker(t *testing.T) { t.Parallel() - b := broker.New() - if b == nil { + brk := broker.New() + if brk == nil { t.Fatal("expected non-nil broker") } } @@ -20,16 +20,16 @@ func TestNewBroker(t *testing.T) { func TestWaitAndNotify(t *testing.T) { t.Parallel() - b := broker.New() - ch := b.Wait(1) + brk := broker.New() + waitCh := brk.Wait(1) go func() { time.Sleep(10 * time.Millisecond) - b.Notify(1) + brk.Notify(1) }() select { - case <-ch: + case <-waitCh: case <-time.After(2 * time.Second): t.Fatal("timeout") } @@ -38,21 +38,22 @@ func TestWaitAndNotify(t *testing.T) { func TestNotifyWithoutWaiters(t *testing.T) { t.Parallel() - b := broker.New() - b.Notify(42) // should not panic + brk := broker.New() + brk.Notify(42) // should not panic. } func TestRemove(t *testing.T) { t.Parallel() - b := broker.New() - ch := b.Wait(1) - b.Remove(1, ch) + brk := broker.New() + waitCh := brk.Wait(1) - b.Notify(1) + brk.Remove(1, waitCh) + + brk.Notify(1) select { - case <-ch: + case <-waitCh: t.Fatal("should not receive after remove") case <-time.After(50 * time.Millisecond): } @@ -61,20 +62,20 @@ func TestRemove(t *testing.T) { func TestMultipleWaiters(t *testing.T) { t.Parallel() - b := broker.New() - ch1 := b.Wait(1) - ch2 := b.Wait(1) + brk := broker.New() + waitCh1 := brk.Wait(1) + waitCh2 := brk.Wait(1) - b.Notify(1) + brk.Notify(1) select { - case <-ch1: + case <-waitCh1: case <-time.After(time.Second): t.Fatal("ch1 timeout") } select { - case <-ch2: + case <-waitCh2: case <-time.After(time.Second): t.Fatal("ch2 timeout") } @@ -83,36 +84,38 @@ func TestMultipleWaiters(t *testing.T) { func TestConcurrentWaitNotify(t *testing.T) { t.Parallel() - b := broker.New() + brk := broker.New() - var wg sync.WaitGroup + var waitGroup sync.WaitGroup const concurrency = 100 - for i := range concurrency { - wg.Add(1) + for idx := range concurrency { + waitGroup.Add(1) go func(uid int64) { - defer wg.Done() + defer waitGroup.Done() - ch := b.Wait(uid) - b.Notify(uid) + waitCh := brk.Wait(uid) + + brk.Notify(uid) select { - case <-ch: + case <-waitCh: case <-time.After(time.Second): t.Error("timeout") } - }(int64(i % 10)) + }(int64(idx % 10)) } - wg.Wait() + waitGroup.Wait() } func TestRemoveNonexistent(t *testing.T) { t.Parallel() - b := broker.New() - ch := make(chan struct{}, 1) - b.Remove(999, ch) // should not panic + brk := broker.New() + waitCh := make(chan struct{}, 1) + + brk.Remove(999, waitCh) // should not panic. } diff --git a/internal/config/config.go b/internal/config/config.go index 7820a6d..5999266 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -41,7 +41,9 @@ type Config struct { } // New creates a new Config by reading from files and environment variables. -func New(_ fx.Lifecycle, params Params) (*Config, error) { +func New( + _ fx.Lifecycle, params Params, +) (*Config, error) { log := params.Logger.Get() name := params.Globals.Appname @@ -74,7 +76,7 @@ func New(_ fx.Lifecycle, params Params) (*Config, error) { } } - s := &Config{ + cfg := &Config{ DBURL: viper.GetString("DBURL"), Debug: viper.GetBool("DEBUG"), Port: viper.GetInt("PORT"), @@ -92,10 +94,10 @@ func New(_ fx.Lifecycle, params Params) (*Config, error) { params: ¶ms, } - if s.Debug { + if cfg.Debug { params.Logger.EnableDebugLogging() - s.log = params.Logger.Get() + cfg.log = params.Logger.Get() } - return s, nil + return cfg, nil } diff --git a/internal/db/db.go b/internal/db/db.go index 5cf340b..20ed77a 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -37,93 +37,93 @@ type Params struct { // Database manages the SQLite connection and migrations. type Database struct { - db *sql.DB + conn *sql.DB log *slog.Logger params *Params } // New creates a new Database and registers lifecycle hooks. func New( - lc fx.Lifecycle, + lifecycle fx.Lifecycle, params Params, ) (*Database, error) { - s := new(Database) - s.params = ¶ms - s.log = params.Logger.Get() + database := &Database{ //nolint:exhaustruct // conn set in OnStart + params: ¶ms, + log: params.Logger.Get(), + } - s.log.Info("Database instantiated") + database.log.Info("Database instantiated") - lc.Append(fx.Hook{ + lifecycle.Append(fx.Hook{ OnStart: func(ctx context.Context) error { - s.log.Info("Database OnStart Hook") + database.log.Info("Database OnStart Hook") - return s.connect(ctx) + return database.connect(ctx) }, OnStop: func(_ context.Context) error { - s.log.Info("Database OnStop Hook") + database.log.Info("Database OnStop Hook") - if s.db != nil { - return s.db.Close() + if database.conn != nil { + closeErr := database.conn.Close() + if closeErr != nil { + return fmt.Errorf( + "close db: %w", closeErr, + ) + } } return nil }, }) - return s, nil + return database, nil } // GetDB returns the underlying sql.DB connection. -func (s *Database) GetDB() *sql.DB { - return s.db +func (database *Database) GetDB() *sql.DB { + return database.conn } -func (s *Database) connect(ctx context.Context) error { - dbURL := s.params.Config.DBURL +func (database *Database) connect(ctx context.Context) error { + dbURL := database.params.Config.DBURL if dbURL == "" { dbURL = "file:./data.db?_journal_mode=WAL&_busy_timeout=5000" } - s.log.Info("connecting to database", "url", dbURL) + database.log.Info( + "connecting to database", "url", dbURL, + ) - d, err := sql.Open("sqlite", dbURL) + conn, err := sql.Open("sqlite", dbURL) if err != nil { - s.log.Error( - "failed to open database", "error", err, - ) - - return err + return fmt.Errorf("open database: %w", err) } - err = d.PingContext(ctx) + err = conn.PingContext(ctx) if err != nil { - s.log.Error( - "failed to ping database", "error", err, - ) - - return err + return fmt.Errorf("ping database: %w", err) } - d.SetMaxOpenConns(1) + conn.SetMaxOpenConns(1) - s.db = d - s.log.Info("database connected") + database.conn = conn + database.log.Info("database connected") - _, err = s.db.ExecContext( + _, err = database.conn.ExecContext( ctx, "PRAGMA foreign_keys = ON", ) if err != nil { return fmt.Errorf("enable foreign keys: %w", err) } - _, err = s.db.ExecContext( + _, err = database.conn.ExecContext( ctx, "PRAGMA busy_timeout = 5000", ) if err != nil { return fmt.Errorf("set busy timeout: %w", err) } - return s.runMigrations(ctx) + return database.runMigrations(ctx) } type migration struct { @@ -132,10 +132,10 @@ type migration struct { sql string } -func (s *Database) runMigrations( +func (database *Database) runMigrations( ctx context.Context, ) error { - _, err := s.db.ExecContext(ctx, + _, err := database.conn.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS schema_migrations ( version INTEGER PRIMARY KEY, applied_at DATETIME DEFAULT CURRENT_TIMESTAMP)`) @@ -145,37 +145,37 @@ func (s *Database) runMigrations( ) } - migrations, err := s.loadMigrations() + migrations, err := database.loadMigrations() if err != nil { return err } - for _, m := range migrations { - err = s.applyMigration(ctx, m) + for _, mig := range migrations { + err = database.applyMigration(ctx, mig) if err != nil { return err } } - s.log.Info("database migrations complete") + database.log.Info("database migrations complete") return nil } -func (s *Database) applyMigration( +func (database *Database) applyMigration( ctx context.Context, - m migration, + mig migration, ) error { var exists int - err := s.db.QueryRowContext(ctx, + err := database.conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM schema_migrations WHERE version = ?`, - m.version, + mig.version, ).Scan(&exists) if err != nil { return fmt.Errorf( - "check migration %d: %w", m.version, err, + "check migration %d: %w", mig.version, err, ) } @@ -183,55 +183,63 @@ func (s *Database) applyMigration( return nil } - s.log.Info( + database.log.Info( "applying migration", - "version", m.version, - "name", m.name, + "version", mig.version, + "name", mig.name, ) - return s.execMigration(ctx, m) + return database.execMigration(ctx, mig) } -func (s *Database) execMigration( +func (database *Database) execMigration( ctx context.Context, - m migration, + mig migration, ) error { - tx, err := s.db.BeginTx(ctx, nil) + transaction, err := database.conn.BeginTx(ctx, nil) if err != nil { return fmt.Errorf( "begin tx for migration %d: %w", - m.version, err, + mig.version, err, ) } - _, err = tx.ExecContext(ctx, m.sql) + _, err = transaction.ExecContext(ctx, mig.sql) if err != nil { - _ = tx.Rollback() + _ = transaction.Rollback() return fmt.Errorf( "apply migration %d (%s): %w", - m.version, m.name, err, + mig.version, mig.name, err, ) } - _, err = tx.ExecContext(ctx, + _, err = transaction.ExecContext(ctx, `INSERT INTO schema_migrations (version) VALUES (?)`, - m.version, + mig.version, ) if err != nil { - _ = tx.Rollback() + _ = transaction.Rollback() return fmt.Errorf( "record migration %d: %w", - m.version, err, + mig.version, err, ) } - return tx.Commit() + err = transaction.Commit() + if err != nil { + return fmt.Errorf( + "commit migration %d: %w", + mig.version, err, + ) + } + + return nil } -func (s *Database) loadMigrations() ( +func (database *Database) loadMigrations() ( []migration, error, ) { diff --git a/internal/db/export_test.go b/internal/db/export_test.go index 2270385..45c0435 100644 --- a/internal/db/export_test.go +++ b/internal/db/export_test.go @@ -13,35 +13,48 @@ var testDBCounter atomic.Int64 // NewTestDatabase creates an in-memory database for testing. func NewTestDatabase() (*Database, error) { - n := testDBCounter.Add(1) + counter := testDBCounter.Add(1) dsn := fmt.Sprintf( "file:testdb%d?mode=memory"+ "&cache=shared&_pragma=foreign_keys(1)", - n, + counter, ) - d, err := sql.Open("sqlite", dsn) + conn, err := sql.Open("sqlite", dsn) if err != nil { - return nil, err + return nil, fmt.Errorf("open test db: %w", err) } - database := &Database{db: d, log: slog.Default()} + database := &Database{ //nolint:exhaustruct // test helper, params not needed + conn: conn, + log: slog.Default(), + } err = database.runMigrations(context.Background()) if err != nil { - closeErr := d.Close() + closeErr := conn.Close() if closeErr != nil { - return nil, closeErr + return nil, fmt.Errorf( + "close after migration failure: %w", + closeErr, + ) } - return nil, err + return nil, fmt.Errorf( + "run test migrations: %w", err, + ) } return database, nil } // Close closes the underlying database connection. -func (s *Database) Close() error { - return s.db.Close() +func (database *Database) Close() error { + err := database.conn.Close() + if err != nil { + return fmt.Errorf("close database: %w", err) + } + + return nil } diff --git a/internal/db/queries.go b/internal/db/queries.go index 0f3b7d0..e57356f 100644 --- a/internal/db/queries.go +++ b/internal/db/queries.go @@ -18,11 +18,15 @@ const ( defaultHistLimit = 50 ) -func generateToken() string { - b := make([]byte, tokenBytes) - _, _ = rand.Read(b) +func generateToken() (string, error) { + buf := make([]byte, tokenBytes) - return hex.EncodeToString(b) + _, err := rand.Read(buf) + if err != nil { + return "", fmt.Errorf("generate token: %w", err) + } + + return hex.EncodeToString(buf), nil } // IRCMessage is the IRC envelope for all messages. @@ -52,14 +56,18 @@ type MemberInfo struct { } // CreateUser registers a new user with the given nick. -func (s *Database) CreateUser( +func (database *Database) CreateUser( ctx context.Context, nick string, ) (int64, string, error) { - token := generateToken() + token, err := generateToken() + if err != nil { + return 0, "", err + } + now := time.Now() - res, err := s.db.ExecContext(ctx, + res, err := database.conn.ExecContext(ctx, `INSERT INTO users (nick, token, created_at, last_seen) VALUES (?, ?, ?, ?)`, @@ -68,90 +76,88 @@ func (s *Database) CreateUser( return 0, "", fmt.Errorf("create user: %w", err) } - id, _ := res.LastInsertId() + userID, _ := res.LastInsertId() - return id, token, nil + return userID, token, nil } // GetUserByToken returns user id and nick for a token. -func (s *Database) GetUserByToken( +func (database *Database) GetUserByToken( ctx context.Context, token string, ) (int64, string, error) { - var id int64 + var userID int64 var nick string - err := s.db.QueryRowContext( + err := database.conn.QueryRowContext( ctx, "SELECT id, nick FROM users WHERE token = ?", token, - ).Scan(&id, &nick) + ).Scan(&userID, &nick) if err != nil { - return 0, "", err + return 0, "", fmt.Errorf("get user by token: %w", err) } - _, _ = s.db.ExecContext( + _, _ = database.conn.ExecContext( ctx, "UPDATE users SET last_seen = ? WHERE id = ?", - time.Now(), id, + time.Now(), userID, ) - return id, nick, nil + return userID, nick, nil } // GetUserByNick returns user id for a given nick. -func (s *Database) GetUserByNick( +func (database *Database) GetUserByNick( ctx context.Context, nick string, ) (int64, error) { - var id int64 + var userID int64 - err := s.db.QueryRowContext( + err := database.conn.QueryRowContext( ctx, "SELECT id FROM users WHERE nick = ?", nick, - ).Scan(&id) + ).Scan(&userID) + if err != nil { + return 0, fmt.Errorf("get user by nick: %w", err) + } - return id, err + return userID, nil } // GetChannelByName returns the channel ID for a name. -func (s *Database) GetChannelByName( +func (database *Database) GetChannelByName( ctx context.Context, name string, ) (int64, error) { - var id int64 + var channelID int64 - err := s.db.QueryRowContext( + err := database.conn.QueryRowContext( ctx, "SELECT id FROM channels WHERE name = ?", name, - ).Scan(&id) + ).Scan(&channelID) + if err != nil { + return 0, fmt.Errorf( + "get channel by name: %w", err, + ) + } - return id, err + return channelID, nil } // GetOrCreateChannel returns channel id, creating if needed. -func (s *Database) GetOrCreateChannel( +// Uses INSERT OR IGNORE to avoid TOCTOU races. +func (database *Database) GetOrCreateChannel( ctx context.Context, name string, ) (int64, error) { - var id int64 - - err := s.db.QueryRowContext( - ctx, - "SELECT id FROM channels WHERE name = ?", - name, - ).Scan(&id) - if err == nil { - return id, nil - } - now := time.Now() - res, err := s.db.ExecContext(ctx, - `INSERT INTO channels + _, err := database.conn.ExecContext(ctx, + `INSERT OR IGNORE INTO channels (name, created_at, updated_at) VALUES (?, ?, ?)`, name, now, now) @@ -159,51 +165,71 @@ func (s *Database) GetOrCreateChannel( return 0, fmt.Errorf("create channel: %w", err) } - id, _ = res.LastInsertId() + var channelID int64 - return id, nil + err = database.conn.QueryRowContext( + ctx, + "SELECT id FROM channels WHERE name = ?", + name, + ).Scan(&channelID) + if err != nil { + return 0, fmt.Errorf("get channel: %w", err) + } + + return channelID, nil } // JoinChannel adds a user to a channel. -func (s *Database) JoinChannel( +func (database *Database) JoinChannel( ctx context.Context, channelID, userID int64, ) error { - _, err := s.db.ExecContext(ctx, + _, err := database.conn.ExecContext(ctx, `INSERT OR IGNORE INTO channel_members (channel_id, user_id, joined_at) VALUES (?, ?, ?)`, channelID, userID, time.Now()) + if err != nil { + return fmt.Errorf("join channel: %w", err) + } - return err + return nil } // PartChannel removes a user from a channel. -func (s *Database) PartChannel( +func (database *Database) PartChannel( ctx context.Context, channelID, userID int64, ) error { - _, err := s.db.ExecContext(ctx, + _, err := database.conn.ExecContext(ctx, `DELETE FROM channel_members WHERE channel_id = ? AND user_id = ?`, channelID, userID) + if err != nil { + return fmt.Errorf("part channel: %w", err) + } - return err + return nil } // DeleteChannelIfEmpty removes a channel with no members. -func (s *Database) DeleteChannelIfEmpty( +func (database *Database) DeleteChannelIfEmpty( ctx context.Context, channelID int64, ) error { - _, err := s.db.ExecContext(ctx, + _, err := database.conn.ExecContext(ctx, `DELETE FROM channels WHERE id = ? AND NOT EXISTS (SELECT 1 FROM channel_members WHERE channel_id = ?)`, channelID, channelID) + if err != nil { + return fmt.Errorf( + "delete channel if empty: %w", err, + ) + } - return err + return nil } // scanChannels scans rows into a ChannelInfo slice. @@ -215,19 +241,21 @@ func scanChannels( var out []ChannelInfo for rows.Next() { - var ch ChannelInfo + var chanInfo ChannelInfo - err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic) + err := rows.Scan( + &chanInfo.ID, &chanInfo.Name, &chanInfo.Topic, + ) if err != nil { - return nil, err + return nil, fmt.Errorf("scan channel: %w", err) } - out = append(out, ch) + out = append(out, chanInfo) } err := rows.Err() if err != nil { - return nil, err + return nil, fmt.Errorf("rows error: %w", err) } if out == nil { @@ -238,11 +266,11 @@ func scanChannels( } // ListChannels returns channels the user has joined. -func (s *Database) ListChannels( +func (database *Database) ListChannels( ctx context.Context, userID int64, ) ([]ChannelInfo, error) { - rows, err := s.db.QueryContext(ctx, + rows, err := database.conn.QueryContext(ctx, `SELECT c.id, c.name, c.topic FROM channels c INNER JOIN channel_members cm @@ -250,32 +278,34 @@ func (s *Database) ListChannels( WHERE cm.user_id = ? ORDER BY c.name`, userID) if err != nil { - return nil, err + return nil, fmt.Errorf("list channels: %w", err) } return scanChannels(rows) } // ListAllChannels returns every channel. -func (s *Database) ListAllChannels( +func (database *Database) ListAllChannels( ctx context.Context, ) ([]ChannelInfo, error) { - rows, err := s.db.QueryContext(ctx, + rows, err := database.conn.QueryContext(ctx, `SELECT id, name, topic FROM channels ORDER BY name`) if err != nil { - return nil, err + return nil, fmt.Errorf( + "list all channels: %w", err, + ) } return scanChannels(rows) } // ChannelMembers returns all members of a channel. -func (s *Database) ChannelMembers( +func (database *Database) ChannelMembers( ctx context.Context, channelID int64, ) ([]MemberInfo, error) { - rows, err := s.db.QueryContext(ctx, + rows, err := database.conn.QueryContext(ctx, `SELECT u.id, u.nick, u.last_seen FROM users u INNER JOIN channel_members cm @@ -283,7 +313,9 @@ func (s *Database) ChannelMembers( WHERE cm.channel_id = ? ORDER BY u.nick`, channelID) if err != nil { - return nil, err + return nil, fmt.Errorf( + "query channel members: %w", err, + ) } defer func() { _ = rows.Close() }() @@ -291,19 +323,23 @@ func (s *Database) ChannelMembers( var members []MemberInfo for rows.Next() { - var m MemberInfo + var member MemberInfo - err = rows.Scan(&m.ID, &m.Nick, &m.LastSeen) + err = rows.Scan( + &member.ID, &member.Nick, &member.LastSeen, + ) if err != nil { - return nil, err + return nil, fmt.Errorf( + "scan member: %w", err, + ) } - members = append(members, m) + members = append(members, member) } err = rows.Err() if err != nil { - return nil, err + return nil, fmt.Errorf("rows error: %w", err) } if members == nil { @@ -313,6 +349,27 @@ func (s *Database) ChannelMembers( return members, nil } +// IsChannelMember checks if a user belongs to a channel. +func (database *Database) IsChannelMember( + ctx context.Context, + channelID, userID int64, +) (bool, error) { + var count int + + err := database.conn.QueryRowContext(ctx, + `SELECT COUNT(*) FROM channel_members + WHERE channel_id = ? AND user_id = ?`, + channelID, userID, + ).Scan(&count) + if err != nil { + return false, fmt.Errorf( + "check membership: %w", err, + ) + } + + return count > 0, nil +} + // scanInt64s scans rows into an int64 slice. func scanInt64s(rows *sql.Rows) ([]int64, error) { defer func() { _ = rows.Close() }() @@ -320,58 +377,64 @@ func scanInt64s(rows *sql.Rows) ([]int64, error) { var ids []int64 for rows.Next() { - var id int64 + var val int64 - err := rows.Scan(&id) + err := rows.Scan(&val) if err != nil { - return nil, err + return nil, fmt.Errorf( + "scan int64: %w", err, + ) } - ids = append(ids, id) + ids = append(ids, val) } err := rows.Err() if err != nil { - return nil, err + return nil, fmt.Errorf("rows error: %w", err) } return ids, nil } // GetChannelMemberIDs returns user IDs in a channel. -func (s *Database) GetChannelMemberIDs( +func (database *Database) GetChannelMemberIDs( ctx context.Context, channelID int64, ) ([]int64, error) { - rows, err := s.db.QueryContext(ctx, + rows, err := database.conn.QueryContext(ctx, `SELECT user_id FROM channel_members WHERE channel_id = ?`, channelID) if err != nil { - return nil, err + return nil, fmt.Errorf( + "get channel member ids: %w", err, + ) } return scanInt64s(rows) } // GetUserChannelIDs returns channel IDs the user is in. -func (s *Database) GetUserChannelIDs( +func (database *Database) GetUserChannelIDs( ctx context.Context, userID int64, ) ([]int64, error) { - rows, err := s.db.QueryContext(ctx, + rows, err := database.conn.QueryContext(ctx, `SELECT channel_id FROM channel_members WHERE user_id = ?`, userID) if err != nil { - return nil, err + return nil, fmt.Errorf( + "get user channel ids: %w", err, + ) } return scanInt64s(rows) } // InsertMessage stores a message and returns its DB ID. -func (s *Database) InsertMessage( +func (database *Database) InsertMessage( ctx context.Context, - command, from, to string, + command, from, target string, body json.RawMessage, meta json.RawMessage, ) (int64, string, error) { @@ -386,38 +449,43 @@ func (s *Database) InsertMessage( meta = json.RawMessage("{}") } - res, err := s.db.ExecContext(ctx, + res, err := database.conn.ExecContext(ctx, `INSERT INTO messages (uuid, command, msg_from, msg_to, body, meta, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)`, - msgUUID, command, from, to, + msgUUID, command, from, target, string(body), string(meta), now) if err != nil { - return 0, "", err + return 0, "", fmt.Errorf( + "insert message: %w", err, + ) } - id, _ := res.LastInsertId() + dbID, _ := res.LastInsertId() - return id, msgUUID, nil + return dbID, msgUUID, nil } // EnqueueMessage adds a message to a user's queue. -func (s *Database) EnqueueMessage( +func (database *Database) EnqueueMessage( ctx context.Context, userID, messageID int64, ) error { - _, err := s.db.ExecContext(ctx, + _, err := database.conn.ExecContext(ctx, `INSERT OR IGNORE INTO client_queues (user_id, message_id, created_at) VALUES (?, ?, ?)`, userID, messageID, time.Now()) + if err != nil { + return fmt.Errorf("enqueue message: %w", err) + } - return err + return nil } // PollMessages returns queued messages for a user. -func (s *Database) PollMessages( +func (database *Database) PollMessages( ctx context.Context, userID, afterQueueID int64, limit int, @@ -426,7 +494,7 @@ func (s *Database) PollMessages( limit = defaultPollLimit } - rows, err := s.db.QueryContext(ctx, + rows, err := database.conn.QueryContext(ctx, `SELECT cq.id, m.uuid, m.command, m.msg_from, m.msg_to, m.body, m.meta, m.created_at @@ -437,7 +505,9 @@ func (s *Database) PollMessages( ORDER BY cq.id ASC LIMIT ?`, userID, afterQueueID, limit) if err != nil { - return nil, afterQueueID, err + return nil, afterQueueID, fmt.Errorf( + "poll messages: %w", err, + ) } msgs, lastQID, scanErr := scanMessages( @@ -451,7 +521,7 @@ func (s *Database) PollMessages( } // GetHistory returns message history for a target. -func (s *Database) GetHistory( +func (database *Database) GetHistory( ctx context.Context, target string, beforeID int64, @@ -461,7 +531,7 @@ func (s *Database) GetHistory( limit = defaultHistLimit } - rows, err := s.queryHistory( + rows, err := database.queryHistory( ctx, target, beforeID, limit, ) if err != nil { @@ -482,14 +552,14 @@ func (s *Database) GetHistory( return msgs, nil } -func (s *Database) queryHistory( +func (database *Database) queryHistory( ctx context.Context, target string, beforeID int64, limit int, ) (*sql.Rows, error) { if beforeID > 0 { - return s.db.QueryContext(ctx, + rows, err := database.conn.QueryContext(ctx, `SELECT id, uuid, command, msg_from, msg_to, body, meta, created_at FROM messages @@ -497,9 +567,16 @@ func (s *Database) queryHistory( AND command = 'PRIVMSG' ORDER BY id DESC LIMIT ?`, target, beforeID, limit) + if err != nil { + return nil, fmt.Errorf( + "query history: %w", err, + ) + } + + return rows, nil } - return s.db.QueryContext(ctx, + rows, err := database.conn.QueryContext(ctx, `SELECT id, uuid, command, msg_from, msg_to, body, meta, created_at FROM messages @@ -507,6 +584,11 @@ func (s *Database) queryHistory( AND command = 'PRIVMSG' ORDER BY id DESC LIMIT ?`, target, limit) + if err != nil { + return nil, fmt.Errorf("query history: %w", err) + } + + return rows, nil } func scanMessages( @@ -521,33 +603,37 @@ func scanMessages( for rows.Next() { var ( - m IRCMessage + msg IRCMessage qID int64 body, meta string - ts time.Time + createdAt time.Time ) err := rows.Scan( - &qID, &m.ID, &m.Command, - &m.From, &m.To, - &body, &meta, &ts, + &qID, &msg.ID, &msg.Command, + &msg.From, &msg.To, + &body, &meta, &createdAt, ) if err != nil { - return nil, fallbackQID, err + return nil, fallbackQID, fmt.Errorf( + "scan message: %w", err, + ) } - m.Body = json.RawMessage(body) - m.Meta = json.RawMessage(meta) - m.TS = ts.Format(time.RFC3339Nano) - m.DBID = qID + msg.Body = json.RawMessage(body) + msg.Meta = json.RawMessage(meta) + msg.TS = createdAt.Format(time.RFC3339Nano) + msg.DBID = qID lastQID = qID - msgs = append(msgs, m) + msgs = append(msgs, msg) } err := rows.Err() if err != nil { - return nil, fallbackQID, err + return nil, fallbackQID, fmt.Errorf( + "rows error: %w", err, + ) } if msgs == nil { @@ -564,59 +650,70 @@ func reverseMessages(msgs []IRCMessage) { } // ChangeNick updates a user's nickname. -func (s *Database) ChangeNick( +func (database *Database) ChangeNick( ctx context.Context, userID int64, newNick string, ) error { - _, err := s.db.ExecContext(ctx, + _, err := database.conn.ExecContext(ctx, "UPDATE users SET nick = ? WHERE id = ?", newNick, userID) + if err != nil { + return fmt.Errorf("change nick: %w", err) + } - return err + return nil } // SetTopic sets the topic for a channel. -func (s *Database) SetTopic( +func (database *Database) SetTopic( ctx context.Context, channelName, topic string, ) error { - _, err := s.db.ExecContext(ctx, + _, err := database.conn.ExecContext(ctx, `UPDATE channels SET topic = ?, updated_at = ? WHERE name = ?`, topic, time.Now(), channelName) + if err != nil { + return fmt.Errorf("set topic: %w", err) + } - return err + return nil } // DeleteUser removes a user and all their data. -func (s *Database) DeleteUser( +func (database *Database) DeleteUser( ctx context.Context, userID int64, ) error { - _, err := s.db.ExecContext( + _, err := database.conn.ExecContext( ctx, "DELETE FROM users WHERE id = ?", userID, ) + if err != nil { + return fmt.Errorf("delete user: %w", err) + } - return err + return nil } // GetAllChannelMembershipsForUser returns channels // a user belongs to. -func (s *Database) GetAllChannelMembershipsForUser( +func (database *Database) GetAllChannelMembershipsForUser( ctx context.Context, userID int64, ) ([]ChannelInfo, error) { - rows, err := s.db.QueryContext(ctx, + rows, err := database.conn.QueryContext(ctx, `SELECT c.id, c.name, c.topic FROM channels c INNER JOIN channel_members cm ON cm.channel_id = c.id WHERE cm.user_id = ?`, userID) if err != nil { - return nil, err + return nil, fmt.Errorf( + "get memberships: %w", err, + ) } return scanChannels(rows) diff --git a/internal/db/queries_test.go b/internal/db/queries_test.go index 0cae346..a83a951 100644 --- a/internal/db/queries_test.go +++ b/internal/db/queries_test.go @@ -12,19 +12,19 @@ import ( func setupTestDB(t *testing.T) *db.Database { t.Helper() - d, err := db.NewTestDatabase() + database, err := db.NewTestDatabase() if err != nil { t.Fatal(err) } t.Cleanup(func() { - closeErr := d.Close() + closeErr := database.Close() if closeErr != nil { t.Logf("close db: %v", closeErr) } }) - return d + return database } func TestCreateUser(t *testing.T) { @@ -349,12 +349,30 @@ func TestSetTopic(t *testing.T) { } } -func insertTestMessage( - t *testing.T, - database *db.Database, -) (int64, int64) { - t.Helper() +func TestInsertMessage(t *testing.T) { + t.Parallel() + database := setupTestDB(t) + ctx := t.Context() + + body := json.RawMessage(`["hello"]`) + + dbID, msgUUID, err := database.InsertMessage( + ctx, "PRIVMSG", "poller", "#test", body, nil, + ) + if err != nil { + t.Fatal(err) + } + + if dbID == 0 || msgUUID == "" { + t.Fatal("expected valid id and uuid") + } +} + +func TestPollMessages(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) ctx := t.Context() uid, _, err := database.CreateUser(ctx, "poller") @@ -364,11 +382,11 @@ func insertTestMessage( body := json.RawMessage(`["hello"]`) - dbID, msgUUID, err := database.InsertMessage( + dbID, _, err := database.InsertMessage( ctx, "PRIVMSG", "poller", "#test", body, nil, ) - if err != nil || dbID == 0 || msgUUID == "" { - t.Fatal("insert failed") + if err != nil { + t.Fatal(err) } err = database.EnqueueMessage(ctx, uid, dbID) @@ -376,19 +394,10 @@ func insertTestMessage( t.Fatal(err) } - return uid, dbID -} - -func TestInsertAndPollMessages(t *testing.T) { - t.Parallel() - - database := setupTestDB(t) - uid, _ := insertTestMessage(t, database) - const batchSize = 10 msgs, lastQID, err := database.PollMessages( - t.Context(), uid, 0, batchSize, + ctx, uid, 0, batchSize, ) if err != nil { t.Fatal(err) @@ -411,7 +420,7 @@ func TestInsertAndPollMessages(t *testing.T) { } msgs, _, _ = database.PollMessages( - t.Context(), uid, lastQID, batchSize, + ctx, uid, lastQID, batchSize, ) if len(msgs) != 0 { diff --git a/internal/handlers/api.go b/internal/handlers/api.go index 065467a..3c9a428 100644 --- a/internal/handlers/api.go +++ b/internal/handlers/api.go @@ -2,6 +2,7 @@ package handlers import ( "encoding/json" + "fmt" "net/http" "regexp" "strconv" @@ -22,13 +23,25 @@ var validChannelRe = regexp.MustCompile( const ( maxLongPollTimeout = 30 pollMessageLimit = 100 + defaultMaxBodySize = 4096 + defaultHistLimit = 50 + maxHistLimit = 500 + cmdPrivmsg = "PRIVMSG" ) +func (hdlr *Handlers) maxBodySize() int64 { + if hdlr.params.Config.MaxMessageSize > 0 { + return int64(hdlr.params.Config.MaxMessageSize) + } + + return defaultMaxBodySize +} + // authUser extracts the user from the Authorization header. -func (s *Handlers) authUser( - r *http.Request, +func (hdlr *Handlers) authUser( + request *http.Request, ) (int64, string, error) { - auth := r.Header.Get("Authorization") + auth := request.Header.Get("Authorization") if !strings.HasPrefix(auth, "Bearer ") { return 0, "", errUnauthorized } @@ -38,20 +51,27 @@ func (s *Handlers) authUser( return 0, "", errUnauthorized } - return s.params.Database.GetUserByToken( - r.Context(), token, + uid, nick, err := hdlr.params.Database.GetUserByToken( + request.Context(), token, ) + if err != nil { + return 0, "", fmt.Errorf("auth: %w", err) + } + + return uid, nick, nil } -func (s *Handlers) requireAuth( - w http.ResponseWriter, - r *http.Request, +func (hdlr *Handlers) requireAuth( + writer http.ResponseWriter, + request *http.Request, ) (int64, string, bool) { - uid, nick, err := s.authUser(r) + uid, nick, err := hdlr.authUser(request) if err != nil { - s.respondError(w, r, + hdlr.respondError( + writer, request, "unauthorized", - http.StatusUnauthorized) + http.StatusUnauthorized, + ) return 0, "", false } @@ -61,149 +81,159 @@ func (s *Handlers) requireAuth( // fanOut stores a message and enqueues it to all specified // user IDs, then notifies them. -func (s *Handlers) fanOut( - r *http.Request, - command, from, to string, +func (hdlr *Handlers) fanOut( + request *http.Request, + command, from, target string, body json.RawMessage, userIDs []int64, ) (string, error) { - dbID, msgUUID, err := s.params.Database.InsertMessage( - r.Context(), command, from, to, body, nil, + dbID, msgUUID, err := hdlr.params.Database.InsertMessage( + request.Context(), command, from, target, body, nil, ) if err != nil { - return "", err + return "", fmt.Errorf("insert message: %w", err) } for _, uid := range userIDs { - err = s.params.Database.EnqueueMessage( - r.Context(), uid, dbID, + enqErr := hdlr.params.Database.EnqueueMessage( + request.Context(), uid, dbID, ) - if err != nil { - s.log.Error("enqueue failed", - "error", err, "user_id", uid) + if enqErr != nil { + hdlr.log.Error("enqueue failed", + "error", enqErr, "user_id", uid) } - s.broker.Notify(uid) + hdlr.broker.Notify(uid) } return msgUUID, nil } -// fanOutSilent calls fanOut and discards the return values. -func (s *Handlers) fanOutSilent( - r *http.Request, - command, from, to string, +// fanOutSilent calls fanOut and discards the UUID. +func (hdlr *Handlers) fanOutSilent( + request *http.Request, + command, from, target string, body json.RawMessage, userIDs []int64, ) error { - _, err := s.fanOut( - r, command, from, to, body, userIDs, + _, err := hdlr.fanOut( + request, command, from, target, body, userIDs, ) return err } // HandleCreateSession creates a new user session. -func (s *Handlers) HandleCreateSession() http.HandlerFunc { - type request struct { +func (hdlr *Handlers) HandleCreateSession() http.HandlerFunc { + type createRequest struct { Nick string `json:"nick"` } - type response struct { + type createResponse struct { ID int64 `json:"id"` Nick string `json:"nick"` Token string `json:"token"` } - return func(w http.ResponseWriter, r *http.Request) { - var req request + return func( + writer http.ResponseWriter, + request *http.Request, + ) { + request.Body = http.MaxBytesReader( + writer, request.Body, hdlr.maxBodySize(), + ) - err := json.NewDecoder(r.Body).Decode(&req) + var payload createRequest + + err := json.NewDecoder(request.Body).Decode(&payload) if err != nil { - s.respondError(w, r, + hdlr.respondError( + writer, request, "invalid request body", - http.StatusBadRequest) + http.StatusBadRequest, + ) return } - req.Nick = strings.TrimSpace(req.Nick) + payload.Nick = strings.TrimSpace(payload.Nick) - if !validNickRe.MatchString(req.Nick) { - s.respondError(w, r, + if !validNickRe.MatchString(payload.Nick) { + hdlr.respondError( + writer, request, "invalid nick format", - http.StatusBadRequest) + http.StatusBadRequest, + ) return } - id, token, err := s.params.Database.CreateUser( - r.Context(), req.Nick, + userID, token, err := hdlr.params.Database.CreateUser( + request.Context(), payload.Nick, ) if err != nil { - s.handleCreateUserError(w, r, err) + if strings.Contains(err.Error(), "UNIQUE") { + hdlr.respondError( + writer, request, + "nick already taken", + http.StatusConflict, + ) + + return + } + + hdlr.log.Error( + "create user failed", "error", err, + ) + hdlr.respondError( + writer, request, + "internal error", + http.StatusInternalServerError, + ) return } - s.respondJSON(w, r, - &response{ID: id, Nick: req.Nick, Token: token}, - http.StatusCreated) + hdlr.respondJSON( + writer, request, + &createResponse{ + ID: userID, + Nick: payload.Nick, + Token: token, + }, + http.StatusCreated, + ) } } -func (s *Handlers) respondError( - w http.ResponseWriter, - r *http.Request, - msg string, - code int, -) { - s.respondJSON(w, r, - map[string]string{"error": msg}, code) -} - -func (s *Handlers) handleCreateUserError( - w http.ResponseWriter, - r *http.Request, - err error, -) { - if strings.Contains(err.Error(), "UNIQUE") { - s.respondError(w, r, - "nick already taken", - http.StatusConflict) - - return - } - - s.log.Error("create user failed", "error", err) - - s.respondError(w, r, - "internal error", - http.StatusInternalServerError) -} - // HandleState returns the current user's info and channels. -func (s *Handlers) HandleState() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - uid, nick, ok := s.requireAuth(w, r) +func (hdlr *Handlers) HandleState() http.HandlerFunc { + return func( + writer http.ResponseWriter, + request *http.Request, + ) { + uid, nick, ok := hdlr.requireAuth(writer, request) if !ok { return } - channels, err := s.params.Database.ListChannels( - r.Context(), uid, + channels, err := hdlr.params.Database.ListChannels( + request.Context(), uid, ) if err != nil { - s.log.Error("list channels failed", "error", err) - - s.respondError(w, r, + hdlr.log.Error( + "list channels failed", "error", err, + ) + hdlr.respondError( + writer, request, "internal error", - http.StatusInternalServerError) + http.StatusInternalServerError, + ) return } - s.respondJSON(w, r, map[string]any{ + hdlr.respondJSON(writer, request, map[string]any{ "id": uid, "nick": nick, "channels": channels, @@ -212,86 +242,103 @@ func (s *Handlers) HandleState() http.HandlerFunc { } // HandleListAllChannels returns all channels on the server. -func (s *Handlers) HandleListAllChannels() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - _, _, ok := s.requireAuth(w, r) +func (hdlr *Handlers) HandleListAllChannels() http.HandlerFunc { + return func( + writer http.ResponseWriter, + request *http.Request, + ) { + _, _, ok := hdlr.requireAuth(writer, request) if !ok { return } - channels, err := s.params.Database.ListAllChannels( - r.Context(), + channels, err := hdlr.params.Database.ListAllChannels( + request.Context(), ) if err != nil { - s.log.Error( + hdlr.log.Error( "list all channels failed", "error", err, ) - - s.respondError(w, r, + hdlr.respondError( + writer, request, "internal error", - http.StatusInternalServerError) + http.StatusInternalServerError, + ) return } - s.respondJSON(w, r, channels, http.StatusOK) + hdlr.respondJSON( + writer, request, channels, http.StatusOK, + ) } } // HandleChannelMembers returns members of a channel. -func (s *Handlers) HandleChannelMembers() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - _, _, ok := s.requireAuth(w, r) +func (hdlr *Handlers) HandleChannelMembers() http.HandlerFunc { + return func( + writer http.ResponseWriter, + request *http.Request, + ) { + _, _, ok := hdlr.requireAuth(writer, request) if !ok { return } - name := "#" + chi.URLParam(r, "channel") + name := "#" + chi.URLParam(request, "channel") - chID, err := s.params.Database.GetChannelByName( - r.Context(), name, + chID, err := hdlr.params.Database.GetChannelByName( + request.Context(), name, ) if err != nil { - s.respondError(w, r, + hdlr.respondError( + writer, request, "channel not found", - http.StatusNotFound) - - return - } - - members, err := s.params.Database.ChannelMembers( - r.Context(), chID, - ) - if err != nil { - s.log.Error( - "channel members failed", "error", err, + http.StatusNotFound, ) - s.respondError(w, r, + return + } + + members, err := hdlr.params.Database.ChannelMembers( + request.Context(), chID, + ) + if err != nil { + hdlr.log.Error( + "channel members failed", "error", err, + ) + hdlr.respondError( + writer, request, "internal error", - http.StatusInternalServerError) + http.StatusInternalServerError, + ) return } - s.respondJSON(w, r, members, http.StatusOK) + hdlr.respondJSON( + writer, request, members, http.StatusOK, + ) } } // HandleGetMessages returns messages via long-polling. -func (s *Handlers) HandleGetMessages() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - uid, _, ok := s.requireAuth(w, r) +func (hdlr *Handlers) HandleGetMessages() http.HandlerFunc { + return func( + writer http.ResponseWriter, + request *http.Request, + ) { + uid, _, ok := hdlr.requireAuth(writer, request) if !ok { return } afterID, _ := strconv.ParseInt( - r.URL.Query().Get("after"), 10, 64, + request.URL.Query().Get("after"), 10, 64, ) timeout, _ := strconv.Atoi( - r.URL.Query().Get("timeout"), + request.URL.Query().Get("timeout"), ) if timeout < 0 { timeout = 0 @@ -301,23 +348,25 @@ func (s *Handlers) HandleGetMessages() http.HandlerFunc { timeout = maxLongPollTimeout } - msgs, lastQID, err := s.params.Database.PollMessages( - r.Context(), uid, afterID, pollMessageLimit, + msgs, lastQID, err := hdlr.params.Database.PollMessages( + request.Context(), uid, + afterID, pollMessageLimit, ) if err != nil { - s.log.Error( + hdlr.log.Error( "poll messages failed", "error", err, ) - - s.respondError(w, r, + hdlr.respondError( + writer, request, "internal error", - http.StatusInternalServerError) + http.StatusInternalServerError, + ) return } if len(msgs) > 0 || timeout == 0 { - s.respondJSON(w, r, map[string]any{ + hdlr.respondJSON(writer, request, map[string]any{ "messages": msgs, "last_id": lastQID, }, http.StatusOK) @@ -325,17 +374,17 @@ func (s *Handlers) HandleGetMessages() http.HandlerFunc { return } - s.longPoll(w, r, uid, afterID, timeout) + hdlr.longPoll(writer, request, uid, afterID, timeout) } } -func (s *Handlers) longPoll( - w http.ResponseWriter, - r *http.Request, +func (hdlr *Handlers) longPoll( + writer http.ResponseWriter, + request *http.Request, uid, afterID int64, timeout int, ) { - waitCh := s.broker.Wait(uid) + waitCh := hdlr.broker.Wait(uid) timer := time.NewTimer( time.Duration(timeout) * time.Second, @@ -346,234 +395,301 @@ func (s *Handlers) longPoll( select { case <-waitCh: case <-timer.C: - case <-r.Context().Done(): - s.broker.Remove(uid, waitCh) + case <-request.Context().Done(): + hdlr.broker.Remove(uid, waitCh) return } - s.broker.Remove(uid, waitCh) + hdlr.broker.Remove(uid, waitCh) - msgs, lastQID, err := s.params.Database.PollMessages( - r.Context(), uid, afterID, pollMessageLimit, + msgs, lastQID, err := hdlr.params.Database.PollMessages( + request.Context(), uid, + afterID, pollMessageLimit, ) if err != nil { - s.log.Error("poll messages failed", "error", err) - - s.respondError(w, r, + hdlr.log.Error( + "poll messages failed", "error", err, + ) + hdlr.respondError( + writer, request, "internal error", - http.StatusInternalServerError) + http.StatusInternalServerError, + ) return } - s.respondJSON(w, r, map[string]any{ + hdlr.respondJSON(writer, request, map[string]any{ "messages": msgs, "last_id": lastQID, }, http.StatusOK) } // HandleSendCommand handles all C2S commands. -func (s *Handlers) HandleSendCommand() http.HandlerFunc { - type request struct { +func (hdlr *Handlers) HandleSendCommand() http.HandlerFunc { + type commandRequest struct { Command string `json:"command"` To string `json:"to"` Body json.RawMessage `json:"body,omitempty"` Meta json.RawMessage `json:"meta,omitempty"` } - return func(w http.ResponseWriter, r *http.Request) { - uid, nick, ok := s.requireAuth(w, r) + return func( + writer http.ResponseWriter, + request *http.Request, + ) { + request.Body = http.MaxBytesReader( + writer, request.Body, hdlr.maxBodySize(), + ) + + uid, nick, ok := hdlr.requireAuth(writer, request) if !ok { return } - var req request + var payload commandRequest - err := json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(request.Body).Decode(&payload) if err != nil { - s.respondError(w, r, + hdlr.respondError( + writer, request, "invalid request body", - http.StatusBadRequest) + http.StatusBadRequest, + ) return } - req.Command = strings.ToUpper( - strings.TrimSpace(req.Command), + payload.Command = strings.ToUpper( + strings.TrimSpace(payload.Command), ) - req.To = strings.TrimSpace(req.To) + payload.To = strings.TrimSpace(payload.To) - if req.Command == "" { - s.respondError(w, r, + if payload.Command == "" { + hdlr.respondError( + writer, request, "command required", - http.StatusBadRequest) + http.StatusBadRequest, + ) return } bodyLines := func() []string { - if req.Body == nil { + if payload.Body == nil { return nil } var lines []string - err := json.Unmarshal(req.Body, &lines) - if err != nil { + decErr := json.Unmarshal(payload.Body, &lines) + if decErr != nil { return nil } return lines } - s.dispatchCommand( - w, r, uid, nick, req.Command, - req.To, req.Body, bodyLines, + hdlr.dispatchCommand( + writer, request, uid, nick, + payload.Command, payload.To, + payload.Body, bodyLines, ) } } -func (s *Handlers) dispatchCommand( - w http.ResponseWriter, - r *http.Request, +func (hdlr *Handlers) dispatchCommand( + writer http.ResponseWriter, + request *http.Request, uid int64, - nick, command, to string, + nick, command, target string, body json.RawMessage, bodyLines func() []string, ) { switch command { - case "PRIVMSG", "NOTICE": - s.handlePrivmsg( - w, r, uid, nick, command, to, body, bodyLines, + case cmdPrivmsg, "NOTICE": + hdlr.handlePrivmsg( + writer, request, uid, nick, + command, target, body, bodyLines, ) case "JOIN": - s.handleJoin(w, r, uid, nick, to) + hdlr.handleJoin( + writer, request, uid, nick, target, + ) case "PART": - s.handlePart(w, r, uid, nick, to, body) + hdlr.handlePart( + writer, request, uid, nick, target, body, + ) case "NICK": - s.handleNick(w, r, uid, nick, bodyLines) + hdlr.handleNick( + writer, request, uid, nick, bodyLines, + ) case "TOPIC": - s.handleTopic(w, r, nick, to, body, bodyLines) + hdlr.handleTopic( + writer, request, nick, target, body, bodyLines, + ) case "QUIT": - s.handleQuit(w, r, uid, nick, body) + hdlr.handleQuit( + writer, request, uid, nick, body, + ) case "PING": - s.respondJSON(w, r, + hdlr.respondJSON(writer, request, map[string]string{ "command": "PONG", - "from": s.params.Config.ServerName, + "from": hdlr.params.Config.ServerName, }, http.StatusOK) default: - s.respondJSON(w, r, - map[string]string{ - "error": "unknown command: " + command, - }, - http.StatusBadRequest) + hdlr.respondError( + writer, request, + "unknown command: "+command, + http.StatusBadRequest, + ) } } -func (s *Handlers) handlePrivmsg( - w http.ResponseWriter, - r *http.Request, +func (hdlr *Handlers) handlePrivmsg( + writer http.ResponseWriter, + request *http.Request, uid int64, - nick, command, to string, + nick, command, target string, body json.RawMessage, bodyLines func() []string, ) { - if to == "" { - s.respondError(w, r, + if target == "" { + hdlr.respondError( + writer, request, "to field required", - http.StatusBadRequest) + http.StatusBadRequest, + ) return } lines := bodyLines() if len(lines) == 0 { - s.respondError(w, r, + hdlr.respondError( + writer, request, "body required", - http.StatusBadRequest) - - return - } - - if strings.HasPrefix(to, "#") { - s.handleChannelMsg( - w, r, uid, nick, command, to, body, + http.StatusBadRequest, ) return } - s.handleDirectMsg(w, r, uid, nick, command, to, body) + if strings.HasPrefix(target, "#") { + hdlr.handleChannelMsg( + writer, request, uid, nick, + command, target, body, + ) + + return + } + + hdlr.handleDirectMsg( + writer, request, uid, nick, + command, target, body, + ) } -func (s *Handlers) handleChannelMsg( - w http.ResponseWriter, - r *http.Request, - _ int64, - nick, command, to string, +func (hdlr *Handlers) handleChannelMsg( + writer http.ResponseWriter, + request *http.Request, + uid int64, + nick, command, target string, body json.RawMessage, ) { - chID, err := s.params.Database.GetChannelByName( - r.Context(), to, + chID, err := hdlr.params.Database.GetChannelByName( + request.Context(), target, ) if err != nil { - s.respondError(w, r, + hdlr.respondError( + writer, request, "channel not found", - http.StatusNotFound) - - return - } - - memberIDs, err := s.params.Database.GetChannelMemberIDs( - r.Context(), chID, - ) - if err != nil { - s.log.Error( - "get channel members failed", "error", err, + http.StatusNotFound, ) - s.respondError(w, r, - "internal error", - http.StatusInternalServerError) - return } - msgUUID, err := s.fanOut( - r, command, nick, to, body, memberIDs, + isMember, err := hdlr.params.Database.IsChannelMember( + request.Context(), chID, uid, ) if err != nil { - s.log.Error("send message failed", "error", err) - - s.respondError(w, r, + hdlr.log.Error( + "check membership failed", "error", err, + ) + hdlr.respondError( + writer, request, "internal error", - http.StatusInternalServerError) + http.StatusInternalServerError, + ) return } - s.respondJSON(w, r, + if !isMember { + hdlr.respondError( + writer, request, + "not a member of this channel", + http.StatusForbidden, + ) + + return + } + + memberIDs, err := hdlr.params.Database.GetChannelMemberIDs( + request.Context(), chID, + ) + if err != nil { + hdlr.log.Error( + "get channel members failed", "error", err, + ) + hdlr.respondError( + writer, request, + "internal error", + http.StatusInternalServerError, + ) + + return + } + + msgUUID, err := hdlr.fanOut( + request, command, nick, target, body, memberIDs, + ) + if err != nil { + hdlr.log.Error("send message failed", "error", err) + hdlr.respondError( + writer, request, + "internal error", + http.StatusInternalServerError, + ) + + return + } + + hdlr.respondJSON(writer, request, map[string]string{"id": msgUUID, "status": "sent"}, http.StatusCreated) } -func (s *Handlers) handleDirectMsg( - w http.ResponseWriter, - r *http.Request, +func (hdlr *Handlers) handleDirectMsg( + writer http.ResponseWriter, + request *http.Request, uid int64, - nick, command, to string, + nick, command, target string, body json.RawMessage, ) { - targetUID, err := s.params.Database.GetUserByNick( - r.Context(), to, + targetUID, err := hdlr.params.Database.GetUserByNick( + request.Context(), target, ) if err != nil { - s.respondError(w, r, + hdlr.respondError( + writer, request, "user not found", - http.StatusNotFound) + http.StatusNotFound, + ) return } @@ -583,98 +699,97 @@ func (s *Handlers) handleDirectMsg( recipients = append(recipients, uid) } - msgUUID, err := s.fanOut( - r, command, nick, to, body, recipients, + msgUUID, err := hdlr.fanOut( + request, command, nick, target, body, recipients, ) if err != nil { - s.log.Error("send dm failed", "error", err) - - s.respondError(w, r, + hdlr.log.Error("send dm failed", "error", err) + hdlr.respondError( + writer, request, "internal error", - http.StatusInternalServerError) + http.StatusInternalServerError, + ) return } - s.respondJSON(w, r, + hdlr.respondJSON(writer, request, map[string]string{"id": msgUUID, "status": "sent"}, http.StatusCreated) } -func normalizeChannel(name string) string { - if !strings.HasPrefix(name, "#") { - return "#" + name - } - - return name -} - -func (s *Handlers) internalError( - w http.ResponseWriter, - r *http.Request, - msg string, - err error, -) { - s.log.Error(msg, "error", err) - - s.respondError(w, r, - "internal error", - http.StatusInternalServerError) -} - -func (s *Handlers) handleJoin( - w http.ResponseWriter, - r *http.Request, +func (hdlr *Handlers) handleJoin( + writer http.ResponseWriter, + request *http.Request, uid int64, - nick, to string, + nick, target string, ) { - if to == "" { - s.respondError(w, r, + if target == "" { + hdlr.respondError( + writer, request, "to field required", - http.StatusBadRequest) + http.StatusBadRequest, + ) return } - channel := normalizeChannel(to) + channel := target + if !strings.HasPrefix(channel, "#") { + channel = "#" + channel + } if !validChannelRe.MatchString(channel) { - s.respondError(w, r, + hdlr.respondError( + writer, request, "invalid channel name", - http.StatusBadRequest) + http.StatusBadRequest, + ) return } - chID, err := s.params.Database.GetOrCreateChannel( - r.Context(), channel, + chID, err := hdlr.params.Database.GetOrCreateChannel( + request.Context(), channel, ) if err != nil { - s.internalError(w, r, - "get/create channel failed", err) + hdlr.log.Error( + "get/create channel failed", "error", err, + ) + hdlr.respondError( + writer, request, + "internal error", + http.StatusInternalServerError, + ) return } - err = s.params.Database.JoinChannel( - r.Context(), chID, uid, + err = hdlr.params.Database.JoinChannel( + request.Context(), chID, uid, ) if err != nil { - s.internalError(w, r, - "join channel failed", err) + hdlr.log.Error( + "join channel failed", "error", err, + ) + hdlr.respondError( + writer, request, + "internal error", + http.StatusInternalServerError, + ) return } - memberIDs, _ := s.params.Database.GetChannelMemberIDs( - r.Context(), chID, + memberIDs, _ := hdlr.params.Database.GetChannelMemberIDs( + request.Context(), chID, ) - _ = s.fanOutSilent( - r, "JOIN", nick, channel, nil, memberIDs, + _ = hdlr.fanOutSilent( + request, "JOIN", nick, channel, nil, memberIDs, ) - s.respondJSON(w, r, + hdlr.respondJSON(writer, request, map[string]string{ "status": "joined", "channel": channel, @@ -682,67 +797,70 @@ func (s *Handlers) handleJoin( http.StatusOK) } -func (s *Handlers) handlePart( - w http.ResponseWriter, - r *http.Request, +func (hdlr *Handlers) handlePart( + writer http.ResponseWriter, + request *http.Request, uid int64, - nick, to string, + nick, target string, body json.RawMessage, ) { - if to == "" { - s.respondError(w, r, + if target == "" { + hdlr.respondError( + writer, request, "to field required", - http.StatusBadRequest) + http.StatusBadRequest, + ) return } - channel := normalizeChannel(to) + channel := target + if !strings.HasPrefix(channel, "#") { + channel = "#" + channel + } - chID, err := s.params.Database.GetChannelByName( - r.Context(), channel, + chID, err := hdlr.params.Database.GetChannelByName( + request.Context(), channel, ) if err != nil { - s.respondError(w, r, + hdlr.respondError( + writer, request, "channel not found", - http.StatusNotFound) + http.StatusNotFound, + ) return } - s.partAndCleanup(w, r, chID, uid, nick, channel, body) -} - -func (s *Handlers) partAndCleanup( - w http.ResponseWriter, - r *http.Request, - chID, uid int64, - nick, channel string, - body json.RawMessage, -) { - memberIDs, _ := s.params.Database.GetChannelMemberIDs( - r.Context(), chID, + memberIDs, _ := hdlr.params.Database.GetChannelMemberIDs( + request.Context(), chID, ) - _ = s.fanOutSilent( - r, "PART", nick, channel, body, memberIDs, + _ = hdlr.fanOutSilent( + request, "PART", nick, channel, body, memberIDs, ) - err := s.params.Database.PartChannel( - r.Context(), chID, uid, + err = hdlr.params.Database.PartChannel( + request.Context(), chID, uid, ) if err != nil { - s.internalError(w, r, - "part channel failed", err) + hdlr.log.Error( + "part channel failed", "error", err, + ) + hdlr.respondError( + writer, request, + "internal error", + http.StatusInternalServerError, + ) return } - _ = s.params.Database.DeleteChannelIfEmpty( - r.Context(), chID, + _ = hdlr.params.Database.DeleteChannelIfEmpty( + request.Context(), chID, ) - s.respondJSON(w, r, + hdlr.respondJSON(writer, request, map[string]string{ "status": "parted", "channel": channel, @@ -750,18 +868,20 @@ func (s *Handlers) partAndCleanup( http.StatusOK) } -func (s *Handlers) handleNick( - w http.ResponseWriter, - r *http.Request, +func (hdlr *Handlers) handleNick( + writer http.ResponseWriter, + request *http.Request, uid int64, nick string, bodyLines func() []string, ) { lines := bodyLines() if len(lines) == 0 { - s.respondError(w, r, + hdlr.respondError( + writer, request, "body required (new nick)", - http.StatusBadRequest) + http.StatusBadRequest, + ) return } @@ -769,15 +889,17 @@ func (s *Handlers) handleNick( newNick := strings.TrimSpace(lines[0]) if !validNickRe.MatchString(newNick) { - s.respondError(w, r, + hdlr.respondError( + writer, request, "invalid nick", - http.StatusBadRequest) + http.StatusBadRequest, + ) return } if newNick == nick { - s.respondJSON(w, r, + hdlr.respondJSON(writer, request, map[string]string{ "status": "ok", "nick": newNick, }, @@ -786,272 +908,382 @@ func (s *Handlers) handleNick( return } - err := s.params.Database.ChangeNick( - r.Context(), uid, newNick, + err := hdlr.params.Database.ChangeNick( + request.Context(), uid, newNick, ) if err != nil { - s.handleChangeNickError(w, r, err) + if strings.Contains(err.Error(), "UNIQUE") { + hdlr.respondError( + writer, request, + "nick already in use", + http.StatusConflict, + ) + + return + } + + hdlr.log.Error( + "change nick failed", "error", err, + ) + hdlr.respondError( + writer, request, + "internal error", + http.StatusInternalServerError, + ) return } - s.broadcastNick(r, uid, nick, newNick) + hdlr.broadcastNick(request, uid, nick, newNick) - s.respondJSON(w, r, + hdlr.respondJSON(writer, request, map[string]string{ "status": "ok", "nick": newNick, }, http.StatusOK) } -func (s *Handlers) handleChangeNickError( - w http.ResponseWriter, - r *http.Request, - err error, -) { - if strings.Contains(err.Error(), "UNIQUE") { - s.respondError(w, r, - "nick already in use", - http.StatusConflict) - - return - } - - s.internalError(w, r, "change nick failed", err) -} - -func (s *Handlers) broadcastNick( - r *http.Request, +func (hdlr *Handlers) broadcastNick( + request *http.Request, uid int64, oldNick, newNick string, ) { - channels, _ := s.params.Database. - GetAllChannelMembershipsForUser(r.Context(), uid) + channels, _ := hdlr.params.Database. + GetAllChannelMembershipsForUser( + request.Context(), uid, + ) notified := map[int64]bool{uid: true} nickBody, err := json.Marshal([]string{newNick}) if err != nil { - s.log.Error("marshal nick body", "error", err) + hdlr.log.Error( + "marshal nick body", "error", err, + ) return } - dbID, _, _ := s.params.Database.InsertMessage( - r.Context(), "NICK", oldNick, "", + dbID, _, _ := hdlr.params.Database.InsertMessage( + request.Context(), "NICK", oldNick, "", json.RawMessage(nickBody), nil, ) - _ = s.params.Database.EnqueueMessage( - r.Context(), uid, dbID, + _ = hdlr.params.Database.EnqueueMessage( + request.Context(), uid, dbID, ) - s.broker.Notify(uid) + hdlr.broker.Notify(uid) - for _, ch := range channels { - memberIDs, _ := s.params.Database. - GetChannelMemberIDs(r.Context(), ch.ID) + for _, chanInfo := range channels { + memberIDs, _ := hdlr.params.Database. + GetChannelMemberIDs( + request.Context(), chanInfo.ID, + ) for _, mid := range memberIDs { if !notified[mid] { notified[mid] = true - _ = s.params.Database.EnqueueMessage( - r.Context(), mid, dbID, + _ = hdlr.params.Database.EnqueueMessage( + request.Context(), mid, dbID, ) - s.broker.Notify(mid) + hdlr.broker.Notify(mid) } } } } -func (s *Handlers) handleTopic( - w http.ResponseWriter, - r *http.Request, - nick, to string, +func (hdlr *Handlers) handleTopic( + writer http.ResponseWriter, + request *http.Request, + nick, target string, body json.RawMessage, bodyLines func() []string, ) { - if to == "" { - s.respondError(w, r, + if target == "" { + hdlr.respondError( + writer, request, "to field required", - http.StatusBadRequest) + http.StatusBadRequest, + ) return } lines := bodyLines() if len(lines) == 0 { - s.respondError(w, r, + hdlr.respondError( + writer, request, "body required (topic text)", - http.StatusBadRequest) + http.StatusBadRequest, + ) return } topic := strings.Join(lines, " ") - channel := normalizeChannel(to) - err := s.params.Database.SetTopic( - r.Context(), channel, topic, + channel := target + if !strings.HasPrefix(channel, "#") { + channel = "#" + channel + } + + err := hdlr.params.Database.SetTopic( + request.Context(), channel, topic, ) if err != nil { - s.internalError(w, r, "set topic failed", err) + hdlr.log.Error( + "set topic failed", "error", err, + ) + hdlr.respondError( + writer, request, + "internal error", + http.StatusInternalServerError, + ) return } - s.broadcastTopic(w, r, nick, channel, topic, body) -} - -func (s *Handlers) broadcastTopic( - w http.ResponseWriter, - r *http.Request, - nick, channel, topic string, - body json.RawMessage, -) { - chID, err := s.params.Database.GetChannelByName( - r.Context(), channel, + chID, err := hdlr.params.Database.GetChannelByName( + request.Context(), channel, ) if err != nil { - s.respondError(w, r, + hdlr.respondError( + writer, request, "channel not found", - http.StatusNotFound) + http.StatusNotFound, + ) return } - memberIDs, _ := s.params.Database.GetChannelMemberIDs( - r.Context(), chID, + memberIDs, _ := hdlr.params.Database.GetChannelMemberIDs( + request.Context(), chID, ) - _ = s.fanOutSilent( - r, "TOPIC", nick, channel, body, memberIDs, + _ = hdlr.fanOutSilent( + request, "TOPIC", nick, channel, body, memberIDs, ) - s.respondJSON(w, r, + hdlr.respondJSON(writer, request, map[string]string{ "status": "ok", "topic": topic, }, http.StatusOK) } -func (s *Handlers) handleQuit( - w http.ResponseWriter, - r *http.Request, +func (hdlr *Handlers) handleQuit( + writer http.ResponseWriter, + request *http.Request, uid int64, nick string, body json.RawMessage, ) { - channels, _ := s.params.Database. - GetAllChannelMembershipsForUser(r.Context(), uid) + channels, _ := hdlr.params.Database. + GetAllChannelMembershipsForUser( + request.Context(), uid, + ) notified := map[int64]bool{} var dbID int64 if len(channels) > 0 { - dbID, _, _ = s.params.Database.InsertMessage( - r.Context(), "QUIT", nick, "", body, nil, + dbID, _, _ = hdlr.params.Database.InsertMessage( + request.Context(), "QUIT", nick, "", body, nil, ) } - for _, ch := range channels { - memberIDs, _ := s.params.Database. - GetChannelMemberIDs(r.Context(), ch.ID) + for _, chanInfo := range channels { + memberIDs, _ := hdlr.params.Database. + GetChannelMemberIDs( + request.Context(), chanInfo.ID, + ) for _, mid := range memberIDs { if mid != uid && !notified[mid] { notified[mid] = true - _ = s.params.Database.EnqueueMessage( - r.Context(), mid, dbID, + _ = hdlr.params.Database.EnqueueMessage( + request.Context(), mid, dbID, ) - s.broker.Notify(mid) + hdlr.broker.Notify(mid) } } - _ = s.params.Database.PartChannel( - r.Context(), ch.ID, uid, + _ = hdlr.params.Database.PartChannel( + request.Context(), chanInfo.ID, uid, ) - _ = s.params.Database.DeleteChannelIfEmpty( - r.Context(), ch.ID, + _ = hdlr.params.Database.DeleteChannelIfEmpty( + request.Context(), chanInfo.ID, ) } - _ = s.params.Database.DeleteUser(r.Context(), uid) + _ = hdlr.params.Database.DeleteUser( + request.Context(), uid, + ) - s.respondJSON(w, r, + hdlr.respondJSON(writer, request, map[string]string{"status": "quit"}, http.StatusOK) } -const ( - defaultHistLimit = 50 - maxHistLimit = 500 -) - // HandleGetHistory returns message history for a target. -func (s *Handlers) HandleGetHistory() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - _, _, ok := s.requireAuth(w, r) +func (hdlr *Handlers) HandleGetHistory() http.HandlerFunc { + return func( + writer http.ResponseWriter, + request *http.Request, + ) { + uid, nick, ok := hdlr.requireAuth(writer, request) if !ok { return } - target := r.URL.Query().Get("target") + target := request.URL.Query().Get("target") if target == "" { - s.respondError(w, r, + hdlr.respondError( + writer, request, "target required", - http.StatusBadRequest) + http.StatusBadRequest, + ) return } + if !hdlr.canAccessHistory( + writer, request, uid, nick, target, + ) { + return + } + beforeID, _ := strconv.ParseInt( - r.URL.Query().Get("before"), 10, 64, + request.URL.Query().Get("before"), 10, 64, ) limit, _ := strconv.Atoi( - r.URL.Query().Get("limit"), + request.URL.Query().Get("limit"), ) if limit <= 0 || limit > maxHistLimit { limit = defaultHistLimit } - msgs, err := s.params.Database.GetHistory( - r.Context(), target, beforeID, limit, + msgs, err := hdlr.params.Database.GetHistory( + request.Context(), target, beforeID, limit, ) if err != nil { - s.log.Error( + hdlr.log.Error( "get history failed", "error", err, ) - - s.respondError(w, r, + hdlr.respondError( + writer, request, "internal error", - http.StatusInternalServerError) + http.StatusInternalServerError, + ) return } - s.respondJSON(w, r, msgs, http.StatusOK) + hdlr.respondJSON( + writer, request, msgs, http.StatusOK, + ) } } +// canAccessHistory verifies the user can read history +// for the given target (channel or DM participant). +func (hdlr *Handlers) canAccessHistory( + writer http.ResponseWriter, + request *http.Request, + uid int64, + nick, target string, +) bool { + if strings.HasPrefix(target, "#") { + return hdlr.canAccessChannelHistory( + writer, request, uid, target, + ) + } + + // DM history: only allow if the target is the + // requester's own nick (messages sent to them). + if target != nick { + hdlr.respondError( + writer, request, + "forbidden", + http.StatusForbidden, + ) + + return false + } + + return true +} + +func (hdlr *Handlers) canAccessChannelHistory( + writer http.ResponseWriter, + request *http.Request, + uid int64, + target string, +) bool { + chID, err := hdlr.params.Database.GetChannelByName( + request.Context(), target, + ) + if err != nil { + hdlr.respondError( + writer, request, + "channel not found", + http.StatusNotFound, + ) + + return false + } + + isMember, err := hdlr.params.Database.IsChannelMember( + request.Context(), chID, uid, + ) + if err != nil { + hdlr.log.Error( + "check membership failed", "error", err, + ) + hdlr.respondError( + writer, request, + "internal error", + http.StatusInternalServerError, + ) + + return false + } + + if !isMember { + hdlr.respondError( + writer, request, + "not a member of this channel", + http.StatusForbidden, + ) + + return false + } + + return true +} + // HandleServerInfo returns server metadata. -func (s *Handlers) HandleServerInfo() http.HandlerFunc { - type response struct { +func (hdlr *Handlers) HandleServerInfo() http.HandlerFunc { + type infoResponse struct { Name string `json:"name"` MOTD string `json:"motd"` } - return func(w http.ResponseWriter, r *http.Request) { - s.respondJSON(w, r, &response{ - Name: s.params.Config.ServerName, - MOTD: s.params.Config.MOTD, + return func( + writer http.ResponseWriter, + request *http.Request, + ) { + hdlr.respondJSON(writer, request, &infoResponse{ + Name: hdlr.params.Config.ServerName, + MOTD: hdlr.params.Config.MOTD, }, http.StatusOK) } } diff --git a/internal/handlers/api_test.go b/internal/handlers/api_test.go index da40025..fdd4e75 100644 --- a/internal/handlers/api_test.go +++ b/internal/handlers/api_test.go @@ -1,8 +1,11 @@ +// Tests use a global viper instance for configuration, +// making parallel execution unsafe. +// +//nolint:paralleltest package handlers_test import ( "bytes" - "context" "encoding/json" "fmt" "io" @@ -26,223 +29,290 @@ import ( "go.uber.org/fx/fxtest" ) -const cmdPrivmsg = "PRIVMSG" +const ( + commandKey = "command" + bodyKey = "body" + toKey = "to" + statusKey = "status" + privmsgCmd = "PRIVMSG" + joinCmd = "JOIN" + apiMessages = "/api/v1/messages" + apiSession = "/api/v1/session" + apiState = "/api/v1/state" +) -var viperMu sync.Mutex //nolint:gochecknoglobals // serializes viper access in parallel tests - -// testServer wraps a test HTTP server with helper methods. +// testServer wraps a test HTTP server with helpers. type testServer struct { - srv *httptest.Server - t *testing.T - fxApp *fxtest.App + httpServer *httptest.Server + t *testing.T + fxApp *fxtest.App } -func testGlobals() *globals.Globals { +func newTestServer( + t *testing.T, +) *testServer { + t.Helper() + + dbPath := filepath.Join( + t.TempDir(), "test.db", + ) + + dbURL := "file:" + dbPath + + "?_journal_mode=WAL&_busy_timeout=5000" + + var srv *server.Server + + app := fxtest.New(t, + fx.Provide( + newTestGlobals, + logger.New, + func( + lifecycle fx.Lifecycle, + globs *globals.Globals, + log *logger.Logger, + ) (*config.Config, error) { + cfg, err := config.New( + lifecycle, config.Params{ //nolint:exhaustruct + Globals: globs, Logger: log, + }, + ) + if err != nil { + return nil, fmt.Errorf( + "test config: %w", err, + ) + } + + cfg.DBURL = dbURL + cfg.Port = 0 + + return cfg, nil + }, + newTestDB, + newTestHealthcheck, + newTestMiddleware, + newTestHandlers, + newTestServerFx, + ), + fx.Populate(&srv), + ) + + app.RequireStart() + time.Sleep(100 * time.Millisecond) + + httpSrv := httptest.NewServer(srv) + + t.Cleanup(func() { + httpSrv.Close() + app.RequireStop() + }) + + return &testServer{ + httpServer: httpSrv, + t: t, + fxApp: app, + } +} + +func newTestGlobals() *globals.Globals { return &globals.Globals{ Appname: "chat-test", Version: "test", } } -func testConfigFactory( - dbURL string, -) func(fx.Lifecycle, *globals.Globals, *logger.Logger) (*config.Config, error) { - return func( - lc fx.Lifecycle, - g *globals.Globals, - l *logger.Logger, - ) (*config.Config, error) { - viperMu.Lock() - - c, err := config.New(lc, config.Params{ - Globals: g, Logger: l, - }) - - viperMu.Unlock() - - if err != nil { - return nil, err - } - - c.DBURL = dbURL - - return c, nil - } -} - -func testDB( - lc fx.Lifecycle, - l *logger.Logger, - c *config.Config, +func newTestDB( + lifecycle fx.Lifecycle, + log *logger.Logger, + cfg *config.Config, ) (*db.Database, error) { - return db.New(lc, db.Params{ - Logger: l, Config: c, + database, err := db.New(lifecycle, db.Params{ //nolint:exhaustruct + Logger: log, Config: cfg, }) + if err != nil { + return nil, fmt.Errorf("test db: %w", err) + } + + return database, nil } -func testHealthcheck( - lc fx.Lifecycle, - g *globals.Globals, - c *config.Config, - l *logger.Logger, - d *db.Database, +func newTestHealthcheck( + lifecycle fx.Lifecycle, + globs *globals.Globals, + cfg *config.Config, + log *logger.Logger, + database *db.Database, ) (*healthcheck.Healthcheck, error) { - return healthcheck.New(lc, healthcheck.Params{ - Globals: g, - Config: c, - Logger: l, - Database: d, + hcheck, err := healthcheck.New(lifecycle, healthcheck.Params{ //nolint:exhaustruct + Globals: globs, + Config: cfg, + Logger: log, + Database: database, }) + if err != nil { + return nil, fmt.Errorf("test healthcheck: %w", err) + } + + return hcheck, nil } -func testMiddleware( - lc fx.Lifecycle, - l *logger.Logger, - g *globals.Globals, - c *config.Config, +func newTestMiddleware( + lifecycle fx.Lifecycle, + log *logger.Logger, + globs *globals.Globals, + cfg *config.Config, ) (*middleware.Middleware, error) { - return middleware.New(lc, middleware.Params{ - Logger: l, - Globals: g, - Config: c, + mware, err := middleware.New(lifecycle, middleware.Params{ //nolint:exhaustruct + Logger: log, + Globals: globs, + Config: cfg, }) + if err != nil { + return nil, fmt.Errorf("test middleware: %w", err) + } + + return mware, nil } -func testHandlers( - lc fx.Lifecycle, - l *logger.Logger, - g *globals.Globals, - c *config.Config, - d *db.Database, - hc *healthcheck.Healthcheck, +func newTestHandlers( + lifecycle fx.Lifecycle, + log *logger.Logger, + globs *globals.Globals, + cfg *config.Config, + database *db.Database, + hcheck *healthcheck.Healthcheck, ) (*handlers.Handlers, error) { - return handlers.New(lc, handlers.Params{ - Logger: l, - Globals: g, - Config: c, - Database: d, - Healthcheck: hc, + hdlr, err := handlers.New(lifecycle, handlers.Params{ //nolint:exhaustruct + Logger: log, + Globals: globs, + Config: cfg, + Database: database, + Healthcheck: hcheck, }) + if err != nil { + return nil, fmt.Errorf("test handlers: %w", err) + } + + return hdlr, nil } -func testServer2( - lc fx.Lifecycle, - l *logger.Logger, - g *globals.Globals, - c *config.Config, - mw *middleware.Middleware, - h *handlers.Handlers, +func newTestServerFx( + lifecycle fx.Lifecycle, + log *logger.Logger, + globs *globals.Globals, + cfg *config.Config, + mware *middleware.Middleware, + hdlr *handlers.Handlers, ) (*server.Server, error) { - return server.New(lc, server.Params{ - Logger: l, - Globals: g, - Config: c, - Middleware: mw, - Handlers: h, + srv, err := server.New(lifecycle, server.Params{ //nolint:exhaustruct + Logger: log, + Globals: globs, + Config: cfg, + Middleware: mware, + Handlers: hdlr, }) + if err != nil { + return nil, fmt.Errorf("test server: %w", err) + } + + return srv, nil } -func newTestServer(t *testing.T) *testServer { +func (tserver *testServer) url(path string) string { + return tserver.httpServer.URL + path +} + +func doRequest( + t *testing.T, + method, url string, + body io.Reader, +) (*http.Response, error) { t.Helper() - dbPath := filepath.Join(t.TempDir(), "test.db") - dbURL := "file:" + dbPath + "?_journal_mode=WAL&_busy_timeout=5000" - - var s *server.Server - - app := fxtest.New(t, - fx.Provide( - testGlobals, logger.New, - testConfigFactory(dbURL), testDB, - testHealthcheck, testMiddleware, - testHandlers, testServer2, - ), - fx.Populate(&s), + request, err := http.NewRequestWithContext( + t.Context(), method, url, body, ) - - app.RequireStart() - time.Sleep(100 * time.Millisecond) - - ts := httptest.NewServer(s) - - t.Cleanup(func() { - ts.Close() - app.RequireStop() - }) - - return &testServer{srv: ts, t: t, fxApp: app} -} - -func (ts *testServer) url(path string) string { - return ts.srv.URL + path -} - -func newReqWithCtx( - method, url string, body io.Reader, -) (*http.Request, error) { - return http.NewRequestWithContext( - context.Background(), method, url, body, - ) -} - -func (ts *testServer) doReq( - method, url string, body io.Reader, -) (*http.Response, error) { - ts.t.Helper() - - req, err := newReqWithCtx(method, url, body) if err != nil { return nil, fmt.Errorf("new request: %w", err) } if body != nil { - req.Header.Set("Content-Type", "application/json") + request.Header.Set( + "Content-Type", "application/json", + ) } - return http.DefaultClient.Do(req) //nolint:gosec // test helper with test server URL + resp, err := http.DefaultClient.Do(request) + if err != nil { + return nil, fmt.Errorf("do request: %w", err) + } + + return resp, nil } -func (ts *testServer) doReqAuth( - method, url, token string, body io.Reader, +func doRequestAuth( + t *testing.T, + method, url, token string, + body io.Reader, ) (*http.Response, error) { - ts.t.Helper() + t.Helper() - req, err := newReqWithCtx(method, url, body) + request, err := http.NewRequestWithContext( + t.Context(), method, url, body, + ) if err != nil { return nil, fmt.Errorf("new request: %w", err) } if body != nil { - req.Header.Set("Content-Type", "application/json") + request.Header.Set( + "Content-Type", "application/json", + ) } if token != "" { - req.Header.Set("Authorization", "Bearer "+token) + request.Header.Set( + "Authorization", "Bearer "+token, + ) } - return http.DefaultClient.Do(req) //nolint:gosec // test helper with test server URL + resp, err := http.DefaultClient.Do(request) + if err != nil { + return nil, fmt.Errorf("do request: %w", err) + } + + return resp, nil } -func (ts *testServer) createSession(nick string) string { - ts.t.Helper() +func (tserver *testServer) createSession( + nick string, +) string { + tserver.t.Helper() - body, err := json.Marshal(map[string]string{"nick": nick}) - if err != nil { - ts.t.Fatalf("marshal session: %v", err) - } - - resp, err := ts.doReq( - http.MethodPost, ts.url("/api/v1/session"), bytes.NewReader(body), + body, err := json.Marshal( + map[string]string{"nick": nick}, ) if err != nil { - ts.t.Fatalf("create session: %v", err) + tserver.t.Fatalf("marshal session: %v", err) + } + + resp, err := doRequest( + tserver.t, + http.MethodPost, + tserver.url(apiSession), + bytes.NewReader(body), + ) + if err != nil { + tserver.t.Fatalf("create session: %v", err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusCreated { - b, _ := io.ReadAll(resp.Body) - ts.t.Fatalf("create session: status %d: %s", resp.StatusCode, b) + respBody, _ := io.ReadAll(resp.Body) + tserver.t.Fatalf( + "create session: status %d: %s", + resp.StatusCode, respBody, + ) } var result struct { @@ -250,29 +320,33 @@ func (ts *testServer) createSession(nick string) string { Token string `json:"token"` } - err = json.NewDecoder(resp.Body).Decode(&result) - if err != nil { - ts.t.Fatalf("decode session: %v", err) + decErr := json.NewDecoder(resp.Body).Decode(&result) + if decErr != nil { + tserver.t.Fatalf("decode session: %v", decErr) } return result.Token } -func (ts *testServer) sendCommand( +func (tserver *testServer) sendCommand( token string, cmd map[string]any, ) (int, map[string]any) { - ts.t.Helper() + tserver.t.Helper() body, err := json.Marshal(cmd) if err != nil { - ts.t.Fatalf("marshal command: %v", err) + tserver.t.Fatalf("marshal command: %v", err) } - resp, err := ts.doReqAuth( - http.MethodPost, ts.url("/api/v1/messages"), token, bytes.NewReader(body), + resp, err := doRequestAuth( + tserver.t, + http.MethodPost, + tserver.url(apiMessages), + token, + bytes.NewReader(body), ) if err != nil { - ts.t.Fatalf("send command: %v", err) + tserver.t.Fatalf("send command: %v", err) } defer func() { _ = resp.Body.Close() }() @@ -284,14 +358,20 @@ func (ts *testServer) sendCommand( return resp.StatusCode, result } -func (ts *testServer) getJSON( - token, path string, //nolint:unparam +func (tserver *testServer) getState( + token string, ) (int, map[string]any) { - ts.t.Helper() + tserver.t.Helper() - resp, err := ts.doReqAuth(http.MethodGet, ts.url(path), token, nil) + resp, err := doRequestAuth( + tserver.t, + http.MethodGet, + tserver.url(apiState), + token, + nil, + ) if err != nil { - ts.t.Fatalf("get: %v", err) + tserver.t.Fatalf("get: %v", err) } defer func() { _ = resp.Body.Close() }() @@ -303,18 +383,25 @@ func (ts *testServer) getJSON( return resp.StatusCode, result } -func (ts *testServer) pollMessages( +func (tserver *testServer) pollMessages( token string, afterID int64, ) ([]map[string]any, int64) { - ts.t.Helper() + tserver.t.Helper() - url := fmt.Sprintf( - "%s/api/v1/messages?timeout=0&after=%d", ts.srv.URL, afterID, + pollURL := fmt.Sprintf( + "%s"+apiMessages+"?timeout=0&after=%d", + tserver.httpServer.URL, afterID, ) - resp, err := ts.doReqAuth(http.MethodGet, url, token, nil) + resp, err := doRequestAuth( + tserver.t, + http.MethodGet, + pollURL, + token, + nil, + ) if err != nil { - ts.t.Fatalf("poll: %v", err) + tserver.t.Fatalf("poll: %v", err) } defer func() { _ = resp.Body.Close() }() @@ -324,9 +411,9 @@ func (ts *testServer) pollMessages( LastID json.Number `json:"last_id"` //nolint:tagliatelle } - err = json.NewDecoder(resp.Body).Decode(&result) - if err != nil { - ts.t.Fatalf("decode poll: %v", err) + decErr := json.NewDecoder(resp.Body).Decode(&result) + if decErr != nil { + tserver.t.Fatalf("decode poll: %v", decErr) } lastID, _ := result.LastID.Int64() @@ -334,21 +421,121 @@ func (ts *testServer) pollMessages( return result.Messages, lastID } -func postSessionExpect( +func postSession( t *testing.T, - ts *testServer, + tserver *testServer, nick string, - wantStatus int, -) { +) *http.Response { t.Helper() - body, err := json.Marshal(map[string]string{"nick": nick}) + body, err := json.Marshal( + map[string]string{"nick": nick}, + ) if err != nil { t.Fatalf("marshal: %v", err) } - resp, err := ts.doReq( - http.MethodPost, ts.url("/api/v1/session"), bytes.NewReader(body), + resp, err := doRequest( + t, + http.MethodPost, + tserver.url(apiSession), + bytes.NewReader(body), + ) + if err != nil { + t.Fatal(err) + } + + return resp +} + +func findMessage( + msgs []map[string]any, + command, from string, +) bool { + for _, msg := range msgs { + if msg[commandKey] == command && + msg["from"] == from { + return true + } + } + + return false +} + +// --- Tests --- + +func TestCreateSessionValid(t *testing.T) { + tserver := newTestServer(t) + token := tserver.createSession("alice") + + if token == "" { + t.Fatal("expected token") + } +} + +func TestCreateSessionDuplicate(t *testing.T) { + tserver := newTestServer(t) + tserver.createSession("alice") + + resp := postSession(t, tserver, "alice") + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusConflict { + t.Fatalf("expected 409, got %d", resp.StatusCode) + } +} + +func TestCreateSessionEmpty(t *testing.T) { + tserver := newTestServer(t) + + resp := postSession(t, tserver, "") + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf( + "expected 400, got %d", resp.StatusCode, + ) + } +} + +func TestCreateSessionInvalidChars(t *testing.T) { + tserver := newTestServer(t) + + resp := postSession(t, tserver, "hello world") + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf( + "expected 400, got %d", resp.StatusCode, + ) + } +} + +func TestCreateSessionNumericStart(t *testing.T) { + tserver := newTestServer(t) + + resp := postSession(t, tserver, "123abc") + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf( + "expected 400, got %d", resp.StatusCode, + ) + } +} + +func TestCreateSessionMalformed(t *testing.T) { + tserver := newTestServer(t) + + resp, err := doRequest( + t, + http.MethodPost, + tserver.url(apiSession), + strings.NewReader("{bad"), ) if err != nil { t.Fatal(err) @@ -356,296 +543,268 @@ func postSessionExpect( defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != wantStatus { - t.Fatalf("expected %d, got %d", wantStatus, resp.StatusCode) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf( + "expected 400, got %d", resp.StatusCode, + ) } } -// --- Tests --- +func TestAuthNoHeader(t *testing.T) { + tserver := newTestServer(t) -func TestCreateSession(t *testing.T) { - t.Parallel() + status, _ := tserver.getState("") + if status != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", status) + } +} - ts := newTestServer(t) +func TestAuthBadToken(t *testing.T) { + tserver := newTestServer(t) - t.Run("valid nick", func(t *testing.T) { - t.Parallel() + status, _ := tserver.getState( + "invalid-token-12345", + ) + if status != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", status) + } +} - token := ts.createSession("alice") - if token == "" { - t.Fatal("expected token") - } - }) +func TestAuthValidToken(t *testing.T) { + tserver := newTestServer(t) + token := tserver.createSession("authtest") - t.Run("duplicate nick", func(t *testing.T) { - t.Parallel() + status, result := tserver.getState(token) + if status != http.StatusOK { + t.Fatalf("expected 200, got %d", status) + } - ts2 := newTestServer(t) - ts2.createSession("dupnick") - - postSessionExpect(t, ts2, "dupnick", http.StatusConflict) - }) - - t.Run("empty nick", func(t *testing.T) { - t.Parallel() - - postSessionExpect(t, ts, "", http.StatusBadRequest) - }) - - t.Run("invalid nick chars", func(t *testing.T) { - t.Parallel() - - postSessionExpect(t, ts, "hello world", http.StatusBadRequest) - }) - - t.Run("nick starting with number", func(t *testing.T) { - t.Parallel() - - postSessionExpect(t, ts, "123abc", http.StatusBadRequest) - }) - - t.Run("malformed json", func(t *testing.T) { - t.Parallel() - - resp, err := ts.doReq( - http.MethodPost, ts.url("/api/v1/session"), strings.NewReader("{bad"), + if result["nick"] != "authtest" { + t.Fatalf( + "expected nick authtest, got %v", + result["nick"], ) - if err != nil { - t.Fatal(err) - } - - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", resp.StatusCode) - } - }) + } } -func TestAuth(t *testing.T) { - t.Parallel() +func TestJoinChannel(t *testing.T) { + tserver := newTestServer(t) + token := tserver.createSession("joiner") - ts := newTestServer(t) + status, result := tserver.sendCommand( + token, + map[string]any{ + commandKey: joinCmd, toKey: "#test", + }, + ) + if status != http.StatusOK { + t.Fatalf( + "expected 200, got %d: %v", status, result, + ) + } - t.Run("no auth header", func(t *testing.T) { - t.Parallel() - - status, _ := ts.getJSON("", "/api/v1/state") - if status != http.StatusUnauthorized { - t.Fatalf("expected 401, got %d", status) - } - }) - - t.Run("bad token", func(t *testing.T) { - t.Parallel() - - status, _ := ts.getJSON("invalid-token-12345", "/api/v1/state") - if status != http.StatusUnauthorized { - t.Fatalf("expected 401, got %d", status) - } - }) - - t.Run("valid token", func(t *testing.T) { - t.Parallel() - - token := ts.createSession("authtest") - - status, result := ts.getJSON(token, "/api/v1/state") - if status != http.StatusOK { - t.Fatalf("expected 200, got %d", status) - } - - if result["nick"] != "authtest" { - t.Fatalf("expected nick authtest, got %v", result["nick"]) - } - }) + if result["channel"] != "#test" { + t.Fatalf( + "expected #test, got %v", result["channel"], + ) + } } -func TestJoinAndPart(t *testing.T) { - t.Parallel() +func TestJoinWithoutHash(t *testing.T) { + tserver := newTestServer(t) + token := tserver.createSession("joiner2") - ts := newTestServer(t) - token := ts.createSession("bob") + status, result := tserver.sendCommand( + token, + map[string]any{ + commandKey: joinCmd, toKey: "other", + }, + ) + if status != http.StatusOK { + t.Fatalf( + "expected 200, got %d: %v", status, result, + ) + } - t.Run("join channel", func(t *testing.T) { - t.Parallel() - - status, result := ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#test"}) - if status != http.StatusOK { - t.Fatalf("expected 200, got %d: %v", status, result) - } - - if result["channel"] != "#test" { - t.Fatalf("expected #test, got %v", result["channel"]) - } - }) - - t.Run("join without hash", func(t *testing.T) { - t.Parallel() - - status, result := ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "other"}) - if status != http.StatusOK { - t.Fatalf("expected 200, got %d: %v", status, result) - } - - if result["channel"] != "#other" { - t.Fatalf("expected #other, got %v", result["channel"]) - } - }) - - t.Run("part channel", func(t *testing.T) { - t.Parallel() - - ts2 := newTestServer(t) - tok := ts2.createSession("partuser") - ts2.sendCommand(tok, map[string]any{"command": "JOIN", "to": "#partchan"}) - - status, result := ts2.sendCommand(tok, map[string]any{"command": "PART", "to": "#partchan"}) - if status != http.StatusOK { - t.Fatalf("expected 200, got %d: %v", status, result) - } - - if result["channel"] != "#partchan" { - t.Fatalf("expected #partchan, got %v", result["channel"]) - } - }) - - t.Run("join missing to", func(t *testing.T) { - t.Parallel() - - status, _ := ts.sendCommand(token, map[string]any{"command": "JOIN"}) - if status != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", status) - } - }) + if result["channel"] != "#other" { + t.Fatalf( + "expected #other, got %v", + result["channel"], + ) + } } -func TestPrivmsgChannel(t *testing.T) { - t.Parallel() +func TestPartChannel(t *testing.T) { + tserver := newTestServer(t) + token := tserver.createSession("parter") - ts := newTestServer(t) - aliceToken := ts.createSession("alice_msg") - bobToken := ts.createSession("bob_msg") + tserver.sendCommand( + token, + map[string]any{ + commandKey: joinCmd, toKey: "#test", + }, + ) - ts.sendCommand(aliceToken, map[string]any{"command": "JOIN", "to": "#chat"}) - ts.sendCommand(bobToken, map[string]any{"command": "JOIN", "to": "#chat"}) + status, result := tserver.sendCommand( + token, + map[string]any{ + commandKey: "PART", toKey: "#test", + }, + ) + if status != http.StatusOK { + t.Fatalf( + "expected 200, got %d: %v", status, result, + ) + } - _, _ = ts.pollMessages(aliceToken, 0) - _, bobLastID := ts.pollMessages(bobToken, 0) + if result["channel"] != "#test" { + t.Fatalf( + "expected #test, got %v", result["channel"], + ) + } +} - status, result := ts.sendCommand(aliceToken, map[string]any{ - "command": cmdPrivmsg, - "to": "#chat", - "body": []string{"hello world"}, +func TestJoinMissingTo(t *testing.T) { + tserver := newTestServer(t) + token := tserver.createSession("joiner3") + + status, _ := tserver.sendCommand( + token, map[string]any{commandKey: joinCmd}, + ) + if status != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", status) + } +} + +func TestChannelMessage(t *testing.T) { + tserver := newTestServer(t) + aliceToken := tserver.createSession("alice_msg") + bobToken := tserver.createSession("bob_msg") + + tserver.sendCommand(aliceToken, map[string]any{ + commandKey: joinCmd, toKey: "#chat", }) + tserver.sendCommand(bobToken, map[string]any{ + commandKey: joinCmd, toKey: "#chat", + }) + + _, _ = tserver.pollMessages(aliceToken, 0) + _, bobLastID := tserver.pollMessages(bobToken, 0) + + status, result := tserver.sendCommand( + aliceToken, + map[string]any{ + commandKey: privmsgCmd, + toKey: "#chat", + bodyKey: []string{"hello world"}, + }, + ) if status != http.StatusCreated { - t.Fatalf("expected 201, got %d: %v", status, result) + t.Fatalf( + "expected 201, got %d: %v", status, result, + ) } if result["id"] == nil || result["id"] == "" { t.Fatal("expected message id") } - msgs, _ := ts.pollMessages(bobToken, bobLastID) - - found := false - - for _, m := range msgs { - if m["command"] == cmdPrivmsg && m["from"] == "alice_msg" { - found = true - - break - } - } - - if !found { - t.Fatalf("bob didn't receive alice's message: %v", msgs) + msgs, _ := tserver.pollMessages( + bobToken, bobLastID, + ) + if !findMessage(msgs, privmsgCmd, "alice_msg") { + t.Fatalf( + "bob didn't receive alice's message: %v", + msgs, + ) } } -func TestPrivmsgErrors(t *testing.T) { - t.Parallel() +func TestMessageMissingBody(t *testing.T) { + tserver := newTestServer(t) + token := tserver.createSession("nobody") - ts := newTestServer(t) - aliceToken := ts.createSession("alice_msg2") - - ts.sendCommand(aliceToken, map[string]any{"command": "JOIN", "to": "#chat2"}) - - t.Run("missing body", func(t *testing.T) { - t.Parallel() - - status, _ := ts.sendCommand(aliceToken, map[string]any{ - "command": cmdPrivmsg, - "to": "#chat2", - }) - if status != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", status) - } + tserver.sendCommand(token, map[string]any{ + commandKey: joinCmd, toKey: "#chat", }) - t.Run("missing to", func(t *testing.T) { - t.Parallel() - - status, _ := ts.sendCommand(aliceToken, map[string]any{ - "command": cmdPrivmsg, - "body": []string{"hello"}, - }) - if status != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", status) - } + status, _ := tserver.sendCommand(token, map[string]any{ + commandKey: privmsgCmd, toKey: "#chat", }) + if status != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", status) + } } -func TestDMSend(t *testing.T) { - t.Parallel() +func TestMessageMissingTo(t *testing.T) { + tserver := newTestServer(t) + token := tserver.createSession("noto") - ts := newTestServer(t) - aliceToken := ts.createSession("alice_dm") - bobToken := ts.createSession("bob_dm") - - status, result := ts.sendCommand(aliceToken, map[string]any{ - "command": cmdPrivmsg, - "to": "bob_dm", - "body": []string{"hey bob"}, + status, _ := tserver.sendCommand(token, map[string]any{ + commandKey: privmsgCmd, + bodyKey: []string{"hello"}, }) + if status != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", status) + } +} + +func TestNonMemberCannotSend(t *testing.T) { + tserver := newTestServer(t) + aliceToken := tserver.createSession("alice_nosend") + bobToken := tserver.createSession("bob_nosend") + + // Only bob joins the channel. + tserver.sendCommand(bobToken, map[string]any{ + commandKey: joinCmd, toKey: "#private", + }) + + // Alice tries to send without joining. + status, _ := tserver.sendCommand( + aliceToken, + map[string]any{ + commandKey: privmsgCmd, + toKey: "#private", + bodyKey: []string{"sneaky"}, + }, + ) + if status != http.StatusForbidden { + t.Fatalf("expected 403, got %d", status) + } +} + +func TestDirectMessage(t *testing.T) { + tserver := newTestServer(t) + aliceToken := tserver.createSession("alice_dm") + bobToken := tserver.createSession("bob_dm") + + status, result := tserver.sendCommand( + aliceToken, + map[string]any{ + commandKey: privmsgCmd, + toKey: "bob_dm", + bodyKey: []string{"hey bob"}, + }, + ) if status != http.StatusCreated { - t.Fatalf("expected 201, got %d: %v", status, result) + t.Fatalf( + "expected 201, got %d: %v", status, result, + ) } - msgs, _ := ts.pollMessages(bobToken, 0) - - found := false - - for _, m := range msgs { - if m["command"] == cmdPrivmsg && m["from"] == "alice_dm" { - found = true - } - } - - if !found { + msgs, _ := tserver.pollMessages(bobToken, 0) + if !findMessage(msgs, privmsgCmd, "alice_dm") { t.Fatal("bob didn't receive DM") } -} -func TestDMEcho(t *testing.T) { - t.Parallel() - - ts := newTestServer(t) - aliceToken := ts.createSession("alice_echo") - ts.createSession("bob_echo") - - ts.sendCommand(aliceToken, map[string]any{ - "command": cmdPrivmsg, - "to": "bob_echo", - "body": []string{"hey bob"}, - }) - - msgs, _ := ts.pollMessages(aliceToken, 0) + aliceMsgs, _ := tserver.pollMessages(aliceToken, 0) found := false - for _, m := range msgs { - if m["command"] == cmdPrivmsg && m["from"] == "alice_echo" && m["to"] == "bob_echo" { + for _, msg := range aliceMsgs { + if msg[commandKey] == privmsgCmd && + msg["from"] == "alice_dm" && + msg[toKey] == "bob_dm" { found = true } } @@ -655,16 +814,14 @@ func TestDMEcho(t *testing.T) { } } -func TestDMNonexistent(t *testing.T) { - t.Parallel() +func TestDMToNonexistentUser(t *testing.T) { + tserver := newTestServer(t) + token := tserver.createSession("dmsender") - ts := newTestServer(t) - aliceToken := ts.createSession("alice_noone") - - status, _ := ts.sendCommand(aliceToken, map[string]any{ - "command": cmdPrivmsg, - "to": "nobody", - "body": []string{"hello?"}, + status, _ := tserver.sendCommand(token, map[string]any{ + commandKey: privmsgCmd, + toKey: "nobody", + bodyKey: []string{"hello?"}, }) if status != http.StatusNotFound { t.Fatalf("expected 404, got %d", status) @@ -672,228 +829,246 @@ func TestDMNonexistent(t *testing.T) { } func TestNickChange(t *testing.T) { - t.Parallel() + tserver := newTestServer(t) + token := tserver.createSession("nick_test") - ts := newTestServer(t) - tok := ts.createSession("nick_change") - - status, result := ts.sendCommand(tok, map[string]any{ - "command": "NICK", - "body": []string{"newnick"}, - }) + status, result := tserver.sendCommand( + token, + map[string]any{ + commandKey: "NICK", + bodyKey: []string{"newnick"}, + }, + ) if status != http.StatusOK { - t.Fatalf("expected 200, got %d: %v", status, result) + t.Fatalf( + "expected 200, got %d: %v", status, result, + ) } if result["nick"] != "newnick" { - t.Fatalf("expected newnick, got %v", result["nick"]) + t.Fatalf( + "expected newnick, got %v", result["nick"], + ) } } func TestNickSameAsCurrent(t *testing.T) { - t.Parallel() + tserver := newTestServer(t) + token := tserver.createSession("same_nick") - ts := newTestServer(t) - tok := ts.createSession("samenick") - - status, _ := ts.sendCommand(tok, map[string]any{ - "command": "NICK", - "body": []string{"samenick"}, + status, _ := tserver.sendCommand(token, map[string]any{ + commandKey: "NICK", + bodyKey: []string{"same_nick"}, }) if status != http.StatusOK { t.Fatalf("expected 200, got %d", status) } } -func TestNickErrors(t *testing.T) { - t.Parallel() +func TestNickCollision(t *testing.T) { + tserver := newTestServer(t) + token := tserver.createSession("nickuser") - ts := newTestServer(t) - token := ts.createSession("nick_test") + tserver.createSession("taken_nick") - t.Run("collision", func(t *testing.T) { - t.Parallel() - - ts.createSession("taken_nick") - - status, _ := ts.sendCommand(token, map[string]any{ - "command": "NICK", - "body": []string{"taken_nick"}, - }) - if status != http.StatusConflict { - t.Fatalf("expected 409, got %d", status) - } + status, _ := tserver.sendCommand(token, map[string]any{ + commandKey: "NICK", + bodyKey: []string{"taken_nick"}, }) + if status != http.StatusConflict { + t.Fatalf("expected 409, got %d", status) + } +} - t.Run("invalid", func(t *testing.T) { - t.Parallel() +func TestNickInvalid(t *testing.T) { + tserver := newTestServer(t) + token := tserver.createSession("nickval") - status, _ := ts.sendCommand(token, map[string]any{ - "command": "NICK", - "body": []string{"bad nick!"}, - }) - if status != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", status) - } + status, _ := tserver.sendCommand(token, map[string]any{ + commandKey: "NICK", + bodyKey: []string{"bad nick!"}, }) + if status != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", status) + } +} - t.Run("empty body", func(t *testing.T) { - t.Parallel() +func TestNickEmptyBody(t *testing.T) { + tserver := newTestServer(t) + token := tserver.createSession("nicknobody") - status, _ := ts.sendCommand(token, map[string]any{ - "command": "NICK", - }) - if status != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", status) - } - }) + status, _ := tserver.sendCommand( + token, map[string]any{commandKey: "NICK"}, + ) + if status != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", status) + } } func TestTopic(t *testing.T) { - t.Parallel() + tserver := newTestServer(t) + token := tserver.createSession("topic_user") - ts := newTestServer(t) - token := ts.createSession("topic_user") - - ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#topictest"}) - - t.Run("set topic", func(t *testing.T) { - t.Parallel() - - status, result := ts.sendCommand(token, map[string]any{ - "command": "TOPIC", - "to": "#topictest", - "body": []string{"Hello World Topic"}, - }) - if status != http.StatusOK { - t.Fatalf("expected 200, got %d: %v", status, result) - } - - if result["topic"] != "Hello World Topic" { - t.Fatalf("expected topic, got %v", result["topic"]) - } + tserver.sendCommand(token, map[string]any{ + commandKey: joinCmd, toKey: "#topictest", }) - t.Run("missing to", func(t *testing.T) { - t.Parallel() + status, result := tserver.sendCommand( + token, + map[string]any{ + commandKey: "TOPIC", + toKey: "#topictest", + bodyKey: []string{"Hello World Topic"}, + }, + ) + if status != http.StatusOK { + t.Fatalf( + "expected 200, got %d: %v", status, result, + ) + } - status, _ := ts.sendCommand(token, map[string]any{ - "command": "TOPIC", - "body": []string{"topic"}, - }) - if status != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", status) - } + if result["topic"] != "Hello World Topic" { + t.Fatalf( + "expected topic, got %v", result["topic"], + ) + } +} + +func TestTopicMissingTo(t *testing.T) { + tserver := newTestServer(t) + token := tserver.createSession("topicnoto") + + status, _ := tserver.sendCommand(token, map[string]any{ + commandKey: "TOPIC", + bodyKey: []string{"topic"}, + }) + if status != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", status) + } +} + +func TestTopicMissingBody(t *testing.T) { + tserver := newTestServer(t) + token := tserver.createSession("topicnobody") + + tserver.sendCommand(token, map[string]any{ + commandKey: joinCmd, toKey: "#topictest", }) - t.Run("missing body", func(t *testing.T) { - t.Parallel() - - status, _ := ts.sendCommand(token, map[string]any{ - "command": "TOPIC", - "to": "#topictest", - }) - if status != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", status) - } + status, _ := tserver.sendCommand(token, map[string]any{ + commandKey: "TOPIC", toKey: "#topictest", }) + if status != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", status) + } } func TestPing(t *testing.T) { - t.Parallel() + tserver := newTestServer(t) + token := tserver.createSession("ping_user") - ts := newTestServer(t) - token := ts.createSession("ping_user") - - status, result := ts.sendCommand(token, map[string]any{"command": "PING"}) + status, result := tserver.sendCommand( + token, map[string]any{commandKey: "PING"}, + ) if status != http.StatusOK { t.Fatalf("expected 200, got %d", status) } - if result["command"] != "PONG" { - t.Fatalf("expected PONG, got %v", result["command"]) + if result[commandKey] != "PONG" { + t.Fatalf( + "expected PONG, got %v", + result[commandKey], + ) } } func TestQuit(t *testing.T) { - t.Parallel() + tserver := newTestServer(t) + token := tserver.createSession("quitter") + observerToken := tserver.createSession("observer") - ts := newTestServer(t) - token := ts.createSession("quitter") - observerToken := ts.createSession("observer") + tserver.sendCommand(token, map[string]any{ + commandKey: joinCmd, toKey: "#quitchan", + }) + tserver.sendCommand(observerToken, map[string]any{ + commandKey: joinCmd, toKey: "#quitchan", + }) - ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#quitchan"}) - ts.sendCommand(observerToken, map[string]any{"command": "JOIN", "to": "#quitchan"}) + _, lastID := tserver.pollMessages(observerToken, 0) - _, lastID := ts.pollMessages(observerToken, 0) - - status, result := ts.sendCommand(token, map[string]any{"command": "QUIT"}) + status, result := tserver.sendCommand( + token, map[string]any{commandKey: "QUIT"}, + ) if status != http.StatusOK { - t.Fatalf("expected 200, got %d: %v", status, result) + t.Fatalf( + "expected 200, got %d: %v", status, result, + ) } - msgs, _ := ts.pollMessages(observerToken, lastID) - - found := false - - for _, m := range msgs { - if m["command"] == "QUIT" && m["from"] == "quitter" { - found = true - } + msgs, _ := tserver.pollMessages( + observerToken, lastID, + ) + if !findMessage(msgs, "QUIT", "quitter") { + t.Fatalf( + "observer didn't get QUIT: %v", msgs, + ) } - if !found { - t.Fatalf("observer didn't get QUIT: %v", msgs) - } - - status2, _ := ts.getJSON(token, "/api/v1/state") - if status2 != http.StatusUnauthorized { - t.Fatalf("expected 401 after quit, got %d", status2) + afterStatus, _ := tserver.getState(token) + if afterStatus != http.StatusUnauthorized { + t.Fatalf( + "expected 401 after quit, got %d", + afterStatus, + ) } } func TestUnknownCommand(t *testing.T) { - t.Parallel() + tserver := newTestServer(t) + token := tserver.createSession("cmdtest") - ts := newTestServer(t) - token := ts.createSession("cmdtest") - - status, result := ts.sendCommand(token, map[string]any{"command": "BOGUS"}) + status, _ := tserver.sendCommand( + token, map[string]any{commandKey: "BOGUS"}, + ) if status != http.StatusBadRequest { - t.Fatalf("expected 400, got %d: %v", status, result) + t.Fatalf("expected 400, got %d", status) } } func TestEmptyCommand(t *testing.T) { - t.Parallel() + tserver := newTestServer(t) + token := tserver.createSession("emptycmd") - ts := newTestServer(t) - token := ts.createSession("emptycmd") - - status, _ := ts.sendCommand(token, map[string]any{"command": ""}) + status, _ := tserver.sendCommand( + token, map[string]any{commandKey: ""}, + ) if status != http.StatusBadRequest { t.Fatalf("expected 400, got %d", status) } } func TestHistory(t *testing.T) { - t.Parallel() + tserver := newTestServer(t) + token := tserver.createSession("historian") - ts := newTestServer(t) - token := ts.createSession("historian") - - ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#history"}) + tserver.sendCommand(token, map[string]any{ + commandKey: joinCmd, toKey: "#history", + }) for range 5 { - ts.sendCommand(token, map[string]any{ - "command": cmdPrivmsg, - "to": "#history", - "body": []string{"test message"}, + tserver.sendCommand(token, map[string]any{ + commandKey: privmsgCmd, + toKey: "#history", + bodyKey: []string{"test message"}, }) } - resp, err := ts.doReqAuth( - http.MethodGet, ts.url("/api/v1/history?target=%23history&limit=3"), token, nil, + histURL := tserver.url( + "/api/v1/history?target=%23history&limit=3", + ) + + resp, err := doRequestAuth( + t, http.MethodGet, histURL, token, nil, ) if err != nil { t.Fatal(err) @@ -902,14 +1077,16 @@ func TestHistory(t *testing.T) { defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - t.Fatalf("expected 200, got %d", resp.StatusCode) + t.Fatalf( + "expected 200, got %d", resp.StatusCode, + ) } var msgs []map[string]any - err = json.NewDecoder(resp.Body).Decode(&msgs) - if err != nil { - t.Fatalf("decode history: %v", err) + decErr := json.NewDecoder(resp.Body).Decode(&msgs) + if decErr != nil { + t.Fatalf("decode history: %v", decErr) } if len(msgs) != 3 { @@ -917,16 +1094,56 @@ func TestHistory(t *testing.T) { } } +func TestHistoryNonMember(t *testing.T) { + tserver := newTestServer(t) + aliceToken := tserver.createSession("alice_hist") + bobToken := tserver.createSession("bob_hist") + + // Alice creates and joins a channel. + tserver.sendCommand(aliceToken, map[string]any{ + commandKey: joinCmd, toKey: "#secret", + }) + tserver.sendCommand(aliceToken, map[string]any{ + commandKey: privmsgCmd, + toKey: "#secret", + bodyKey: []string{"secret stuff"}, + }) + + // Bob tries to read history without joining. + histURL := tserver.url( + "/api/v1/history?target=%23secret", + ) + + resp, err := doRequestAuth( + t, http.MethodGet, histURL, bobToken, nil, + ) + if err != nil { + t.Fatal(err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusForbidden { + t.Fatalf( + "expected 403, got %d", resp.StatusCode, + ) + } +} + func TestChannelList(t *testing.T) { - t.Parallel() + tserver := newTestServer(t) + token := tserver.createSession("lister") - ts := newTestServer(t) - token := ts.createSession("lister") + tserver.sendCommand(token, map[string]any{ + commandKey: joinCmd, toKey: "#listchan", + }) - ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#listchan"}) - - resp, err := ts.doReqAuth( - http.MethodGet, ts.url("/api/v1/channels"), token, nil, + resp, err := doRequestAuth( + t, + http.MethodGet, + tserver.url("/api/v1/channels"), + token, + nil, ) if err != nil { t.Fatal(err) @@ -935,20 +1152,24 @@ func TestChannelList(t *testing.T) { defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - t.Fatalf("expected 200, got %d", resp.StatusCode) + t.Fatalf( + "expected 200, got %d", resp.StatusCode, + ) } var channels []map[string]any - err = json.NewDecoder(resp.Body).Decode(&channels) - if err != nil { - t.Fatalf("decode channels: %v", err) + decErr := json.NewDecoder(resp.Body).Decode( + &channels, + ) + if decErr != nil { + t.Fatalf("decode channels: %v", decErr) } found := false - for _, ch := range channels { - if ch["name"] == "#listchan" { + for _, channel := range channels { + if channel["name"] == "#listchan" { found = true } } @@ -959,15 +1180,21 @@ func TestChannelList(t *testing.T) { } func TestChannelMembers(t *testing.T) { - t.Parallel() + tserver := newTestServer(t) + token := tserver.createSession("membertest") - ts := newTestServer(t) - token := ts.createSession("membertest") + tserver.sendCommand(token, map[string]any{ + commandKey: joinCmd, toKey: "#members", + }) - ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#members"}) - - resp, err := ts.doReqAuth( - http.MethodGet, ts.url("/api/v1/channels/members/members"), token, nil, + resp, err := doRequestAuth( + t, + http.MethodGet, + tserver.url( + "/api/v1/channels/members/members", + ), + token, + nil, ) if err != nil { t.Fatal(err) @@ -976,29 +1203,45 @@ func TestChannelMembers(t *testing.T) { defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - t.Fatalf("expected 200, got %d", resp.StatusCode) + t.Fatalf( + "expected 200, got %d", resp.StatusCode, + ) } } -func asyncPoll( - ts *testServer, - token string, - afterID int64, -) <-chan []map[string]any { - ch := make(chan []map[string]any, 1) +func TestLongPoll(t *testing.T) { + tserver := newTestServer(t) + aliceToken := tserver.createSession("lp_alice") + bobToken := tserver.createSession("lp_bob") + + tserver.sendCommand(aliceToken, map[string]any{ + commandKey: joinCmd, toKey: "#longpoll", + }) + tserver.sendCommand(bobToken, map[string]any{ + commandKey: joinCmd, toKey: "#longpoll", + }) + + _, lastID := tserver.pollMessages(bobToken, 0) + + var waitGroup sync.WaitGroup + + var pollMsgs []map[string]any + + waitGroup.Add(1) go func() { - url := fmt.Sprintf( - "%s/api/v1/messages?timeout=5&after=%d", - ts.srv.URL, afterID, + defer waitGroup.Done() + + pollURL := fmt.Sprintf( + "%s"+apiMessages+"?timeout=5&after=%d", + tserver.httpServer.URL, lastID, ) - resp, err := ts.doReqAuth( - http.MethodGet, url, token, nil, + resp, err := doRequestAuth( + t, http.MethodGet, + pollURL, bobToken, nil, ) if err != nil { - ch <- nil - return } @@ -1010,59 +1253,39 @@ func asyncPoll( _ = json.NewDecoder(resp.Body).Decode(&result) - ch <- result.Messages + pollMsgs = result.Messages }() - return ch -} - -func TestLongPoll(t *testing.T) { - t.Parallel() - - ts := newTestServer(t) - aliceToken := ts.createSession("lp_alice") - bobToken := ts.createSession("lp_bob") - - ts.sendCommand(aliceToken, map[string]any{"command": "JOIN", "to": "#longpoll"}) - ts.sendCommand(bobToken, map[string]any{"command": "JOIN", "to": "#longpoll"}) - - _, lastID := ts.pollMessages(bobToken, 0) - - pollMsgs := asyncPoll(ts, bobToken, lastID) - time.Sleep(200 * time.Millisecond) - ts.sendCommand(aliceToken, map[string]any{ - "command": cmdPrivmsg, - "to": "#longpoll", - "body": []string{"wake up!"}, + tserver.sendCommand(aliceToken, map[string]any{ + commandKey: privmsgCmd, + toKey: "#longpoll", + bodyKey: []string{"wake up!"}, }) - msgs := <-pollMsgs + waitGroup.Wait() - found := false - - for _, m := range msgs { - if m["command"] == cmdPrivmsg && m["from"] == "lp_alice" { - found = true - } - } - - if !found { - t.Fatalf("long-poll didn't receive message: %v", msgs) + if !findMessage(pollMsgs, privmsgCmd, "lp_alice") { + t.Fatalf( + "long-poll didn't receive message: %v", + pollMsgs, + ) } } func TestLongPollTimeout(t *testing.T) { - t.Parallel() - - ts := newTestServer(t) - token := ts.createSession("lp_timeout") + tserver := newTestServer(t) + token := tserver.createSession("lp_timeout") start := time.Now() - resp, err := ts.doReqAuth( - http.MethodGet, ts.url("/api/v1/messages?timeout=1"), token, nil, + resp, err := doRequestAuth( + t, + http.MethodGet, + tserver.url(apiMessages+"?timeout=1"), + token, + nil, ) if err != nil { t.Fatal(err) @@ -1073,25 +1296,35 @@ func TestLongPollTimeout(t *testing.T) { elapsed := time.Since(start) if elapsed < 900*time.Millisecond { - t.Fatalf("long-poll returned too fast: %v", elapsed) + t.Fatalf( + "long-poll returned too fast: %v", elapsed, + ) } if resp.StatusCode != http.StatusOK { - t.Fatalf("expected 200, got %d", resp.StatusCode) + t.Fatalf( + "expected 200, got %d", resp.StatusCode, + ) } } func TestEphemeralChannelCleanup(t *testing.T) { - t.Parallel() + tserver := newTestServer(t) + token := tserver.createSession("ephemeral") - ts := newTestServer(t) - token := ts.createSession("ephemeral") + tserver.sendCommand(token, map[string]any{ + commandKey: joinCmd, toKey: "#ephemeral", + }) + tserver.sendCommand(token, map[string]any{ + commandKey: "PART", toKey: "#ephemeral", + }) - ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#ephemeral"}) - ts.sendCommand(token, map[string]any{"command": "PART", "to": "#ephemeral"}) - - resp, err := ts.doReqAuth( - http.MethodGet, ts.url("/api/v1/channels"), token, nil, + resp, err := doRequestAuth( + t, + http.MethodGet, + tserver.url("/api/v1/channels"), + token, + nil, ) if err != nil { t.Fatal(err) @@ -1101,44 +1334,55 @@ func TestEphemeralChannelCleanup(t *testing.T) { var channels []map[string]any - err = json.NewDecoder(resp.Body).Decode(&channels) - if err != nil { - t.Fatalf("decode channels: %v", err) + decErr := json.NewDecoder(resp.Body).Decode( + &channels, + ) + if decErr != nil { + t.Fatalf("decode channels: %v", decErr) } - for _, ch := range channels { - if ch["name"] == "#ephemeral" { - t.Fatal("ephemeral channel should have been cleaned up") + for _, channel := range channels { + if channel["name"] == "#ephemeral" { + t.Fatal( + "ephemeral channel should be cleaned up", + ) } } } func TestConcurrentSessions(t *testing.T) { - t.Parallel() + tserver := newTestServer(t) - ts := newTestServer(t) + var waitGroup sync.WaitGroup - var wg sync.WaitGroup + const concurrency = 20 - errs := make(chan error, 20) + errs := make(chan error, concurrency) - for i := range 20 { - wg.Add(1) + for idx := range concurrency { + waitGroup.Add(1) - go func(idx int) { - defer wg.Done() + go func(index int) { + defer waitGroup.Done() - nick := fmt.Sprintf("concurrent_%d", idx) + nick := fmt.Sprintf("conc_%d", index) - body, err := json.Marshal(map[string]string{"nick": nick}) + body, err := json.Marshal( + map[string]string{"nick": nick}, + ) if err != nil { - errs <- fmt.Errorf("marshal: %w", err) + errs <- fmt.Errorf( + "marshal: %w", err, + ) return } - resp, err := ts.doReq( - http.MethodPost, ts.url("/api/v1/session"), bytes.NewReader(body), + resp, err := doRequest( + t, + http.MethodPost, + tserver.url(apiSession), + bytes.NewReader(body), ) if err != nil { errs <- err @@ -1150,28 +1394,32 @@ func TestConcurrentSessions(t *testing.T) { if resp.StatusCode != http.StatusCreated { errs <- fmt.Errorf( //nolint:err113 - "status %d for %s", resp.StatusCode, nick, + "status %d for %s", + resp.StatusCode, nick, ) } - }(i) + }(idx) } - wg.Wait() + waitGroup.Wait() close(errs) for err := range errs { if err != nil { - t.Fatalf("concurrent session creation error: %v", err) + t.Fatalf("concurrent error: %v", err) } } } func TestServerInfo(t *testing.T) { - t.Parallel() + tserver := newTestServer(t) - ts := newTestServer(t) - - resp, err := ts.doReq(http.MethodGet, ts.url("/api/v1/server"), nil) + resp, err := doRequest( + t, + http.MethodGet, + tserver.url("/api/v1/server"), + nil, + ) if err != nil { t.Fatal(err) } @@ -1179,16 +1427,21 @@ func TestServerInfo(t *testing.T) { defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - t.Fatalf("expected 200, got %d", resp.StatusCode) + t.Fatalf( + "expected 200, got %d", resp.StatusCode, + ) } } func TestHealthcheck(t *testing.T) { - t.Parallel() + tserver := newTestServer(t) - ts := newTestServer(t) - - resp, err := ts.doReq(http.MethodGet, ts.url("/.well-known/healthcheck.json"), nil) + resp, err := doRequest( + t, + http.MethodGet, + tserver.url("/.well-known/healthcheck.json"), + nil, + ) if err != nil { t.Fatal(err) } @@ -1196,48 +1449,50 @@ func TestHealthcheck(t *testing.T) { defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - t.Fatalf("expected 200, got %d", resp.StatusCode) + t.Fatalf( + "expected 200, got %d", resp.StatusCode, + ) } var result map[string]any - err = json.NewDecoder(resp.Body).Decode(&result) - if err != nil { - t.Fatalf("decode healthcheck: %v", err) + decErr := json.NewDecoder(resp.Body).Decode(&result) + if decErr != nil { + t.Fatalf("decode healthcheck: %v", decErr) } - if result["status"] != "ok" { - t.Fatalf("expected ok status, got %v", result["status"]) + if result[statusKey] != "ok" { + t.Fatalf( + "expected ok status, got %v", + result[statusKey], + ) } } func TestNickBroadcastToChannels(t *testing.T) { - t.Parallel() + tserver := newTestServer(t) + aliceToken := tserver.createSession("nick_a") + bobToken := tserver.createSession("nick_b") - ts := newTestServer(t) - aliceToken := ts.createSession("nick_a") - bobToken := ts.createSession("nick_b") - - ts.sendCommand(aliceToken, map[string]any{"command": "JOIN", "to": "#nicktest"}) - ts.sendCommand(bobToken, map[string]any{"command": "JOIN", "to": "#nicktest"}) - - _, lastID := ts.pollMessages(bobToken, 0) - - ts.sendCommand(aliceToken, map[string]any{ - "command": "NICK", "body": []string{"nick_a_new"}, + tserver.sendCommand(aliceToken, map[string]any{ + commandKey: joinCmd, toKey: "#nicktest", + }) + tserver.sendCommand(bobToken, map[string]any{ + commandKey: joinCmd, toKey: "#nicktest", }) - msgs, _ := ts.pollMessages(bobToken, lastID) + _, lastID := tserver.pollMessages(bobToken, 0) - found := false + tserver.sendCommand(aliceToken, map[string]any{ + commandKey: "NICK", + bodyKey: []string{"nick_a_new"}, + }) - for _, m := range msgs { - if m["command"] == "NICK" && m["from"] == "nick_a" { - found = true - } - } + msgs, _ := tserver.pollMessages(bobToken, lastID) - if !found { - t.Fatalf("bob didn't get nick change: %v", msgs) + if !findMessage(msgs, "NICK", "nick_a") { + t.Fatalf( + "bob didn't get nick change: %v", msgs, + ) } } diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index a4f3bd5..d6e2014 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -40,40 +40,59 @@ type Handlers struct { // New creates a new Handlers instance. func New( - lc fx.Lifecycle, + lifecycle fx.Lifecycle, params Params, ) (*Handlers, error) { - s := new(Handlers) - s.params = ¶ms - s.log = params.Logger.Get() - s.hc = params.Healthcheck - s.broker = broker.New() + hdlr := &Handlers{ + params: ¶ms, + log: params.Logger.Get(), + hc: params.Healthcheck, + broker: broker.New(), + } - lc.Append(fx.Hook{ + lifecycle.Append(fx.Hook{ OnStart: func(_ context.Context) error { return nil }, + OnStop: func(_ context.Context) error { + return nil + }, }) - return s, nil + return hdlr, nil } -func (s *Handlers) respondJSON( - w http.ResponseWriter, +func (hdlr *Handlers) respondJSON( + writer http.ResponseWriter, _ *http.Request, data any, status int, ) { - w.Header().Set( + writer.Header().Set( "Content-Type", "application/json; charset=utf-8", ) - w.WriteHeader(status) + writer.WriteHeader(status) if data != nil { - err := json.NewEncoder(w).Encode(data) + err := json.NewEncoder(writer).Encode(data) if err != nil { - s.log.Error("json encode error", "error", err) + hdlr.log.Error( + "json encode error", "error", err, + ) } } } + +func (hdlr *Handlers) respondError( + writer http.ResponseWriter, + request *http.Request, + msg string, + status int, +) { + hdlr.respondJSON( + writer, request, + map[string]string{"error": msg}, + status, + ) +} diff --git a/internal/handlers/healthcheck.go b/internal/handlers/healthcheck.go index 1666ebb..99f0af2 100644 --- a/internal/handlers/healthcheck.go +++ b/internal/handlers/healthcheck.go @@ -7,9 +7,12 @@ import ( const httpStatusOK = 200 // HandleHealthCheck returns an HTTP handler for the health check endpoint. -func (s *Handlers) HandleHealthCheck() http.HandlerFunc { - return func(w http.ResponseWriter, req *http.Request) { - resp := s.hc.Healthcheck() - s.respondJSON(w, req, resp, httpStatusOK) +func (hdlr *Handlers) HandleHealthCheck() http.HandlerFunc { + return func( + writer http.ResponseWriter, + request *http.Request, + ) { + resp := hdlr.hc.Healthcheck() + hdlr.respondJSON(writer, request, resp, httpStatusOK) } } diff --git a/internal/healthcheck/healthcheck.go b/internal/healthcheck/healthcheck.go index 37b2983..2aacc84 100644 --- a/internal/healthcheck/healthcheck.go +++ b/internal/healthcheck/healthcheck.go @@ -33,14 +33,17 @@ type Healthcheck struct { } // New creates a new Healthcheck instance. -func New(lc fx.Lifecycle, params Params) (*Healthcheck, error) { - s := new(Healthcheck) - s.params = ¶ms - s.log = params.Logger.Get() +func New( + lifecycle fx.Lifecycle, params Params, +) (*Healthcheck, error) { + hcheck := &Healthcheck{ //nolint:exhaustruct // StartupTime set in OnStart + params: ¶ms, + log: params.Logger.Get(), + } - lc.Append(fx.Hook{ + lifecycle.Append(fx.Hook{ OnStart: func(_ context.Context) error { - s.StartupTime = time.Now() + hcheck.StartupTime = time.Now() return nil }, @@ -49,7 +52,7 @@ func New(lc fx.Lifecycle, params Params) (*Healthcheck, error) { }, }) - return s, nil + return hcheck, nil } // Response is the JSON response returned by the health endpoint. @@ -64,19 +67,18 @@ type Response struct { } // Healthcheck returns the current health status of the server. -func (s *Healthcheck) Healthcheck() *Response { - resp := &Response{ +func (hcheck *Healthcheck) Healthcheck() *Response { + return &Response{ Status: "ok", Now: time.Now().UTC().Format(time.RFC3339Nano), - UptimeSeconds: int64(s.uptime().Seconds()), - UptimeHuman: s.uptime().String(), - Appname: s.params.Globals.Appname, - Version: s.params.Globals.Version, + UptimeSeconds: int64(hcheck.uptime().Seconds()), + UptimeHuman: hcheck.uptime().String(), + Appname: hcheck.params.Globals.Appname, + Version: hcheck.params.Globals.Version, + Maintenance: hcheck.params.Config.MaintenanceMode, } - - return resp } -func (s *Healthcheck) uptime() time.Duration { - return time.Since(s.StartupTime) +func (hcheck *Healthcheck) uptime() time.Duration { + return time.Since(hcheck.StartupTime) } diff --git a/internal/logger/logger.go b/internal/logger/logger.go index 42b4fa2..518c86a 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -23,51 +23,56 @@ type Logger struct { params Params } -// New creates a new Logger with appropriate handler based on terminal detection. -func New(_ fx.Lifecycle, params Params) (*Logger, error) { - l := new(Logger) - l.level = new(slog.LevelVar) - l.level.Set(slog.LevelInfo) +// New creates a new Logger with appropriate handler +// based on terminal detection. +func New( + _ fx.Lifecycle, params Params, +) (*Logger, error) { + logger := new(Logger) + logger.level = new(slog.LevelVar) + logger.level.Set(slog.LevelInfo) tty := false + if fileInfo, _ := os.Stdout.Stat(); (fileInfo.Mode() & os.ModeCharDevice) != 0 { tty = true } - var handler slog.Handler - if tty { - handler = slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ - Level: l.level, - AddSource: true, - }) - } else { - handler = slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ - Level: l.level, - AddSource: true, - }) + opts := &slog.HandlerOptions{ //nolint:exhaustruct // ReplaceAttr optional + Level: logger.level, + AddSource: true, } - l.log = slog.New(handler) - l.params = params + var handler slog.Handler + if tty { + handler = slog.NewTextHandler(os.Stdout, opts) + } else { + handler = slog.NewJSONHandler(os.Stdout, opts) + } - return l, nil + logger.log = slog.New(handler) + logger.params = params + + return logger, nil } // EnableDebugLogging switches the log level to debug. -func (l *Logger) EnableDebugLogging() { - l.level.Set(slog.LevelDebug) - l.log.Debug("debug logging enabled", "debug", true) +func (logger *Logger) EnableDebugLogging() { + logger.level.Set(slog.LevelDebug) + logger.log.Debug( + "debug logging enabled", "debug", true, + ) } // Get returns the underlying slog.Logger. -func (l *Logger) Get() *slog.Logger { - return l.log +func (logger *Logger) Get() *slog.Logger { + return logger.log } // Identify logs the application name and version at startup. -func (l *Logger) Identify() { - l.log.Info("starting", - "appname", l.params.Globals.Appname, - "version", l.params.Globals.Version, +func (logger *Logger) Identify() { + logger.log.Info("starting", + "appname", logger.params.Globals.Appname, + "version", logger.params.Globals.Version, ) } diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index f048f58..a69c58c 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -11,7 +11,7 @@ import ( "git.eeqj.de/sneak/chat/internal/globals" "git.eeqj.de/sneak/chat/internal/logger" basicauth "github.com/99designs/basicauth-go" - "github.com/go-chi/chi/middleware" + chimw "github.com/go-chi/chi/middleware" "github.com/go-chi/cors" metrics "github.com/slok/go-http-metrics/metrics/prometheus" ghmm "github.com/slok/go-http-metrics/middleware" @@ -38,25 +38,28 @@ type Middleware struct { } // New creates a new Middleware instance. -func New(_ fx.Lifecycle, params Params) (*Middleware, error) { - s := new(Middleware) - s.params = ¶ms - s.log = params.Logger.Get() +func New( + _ fx.Lifecycle, params Params, +) (*Middleware, error) { + mware := &Middleware{ + params: ¶ms, + log: params.Logger.Get(), + } - return s, nil + return mware, nil } -func ipFromHostPort(hp string) string { - h, _, err := net.SplitHostPort(hp) +func ipFromHostPort(hostPort string) string { + host, _, err := net.SplitHostPort(hostPort) if err != nil { return "" } - if len(h) > 0 && h[0] == '[' { - return h[1 : len(h)-1] + if len(host) > 0 && host[0] == '[' { + return host[1 : len(host)-1] } - return h + return host } type loggingResponseWriter struct { @@ -65,9 +68,15 @@ type loggingResponseWriter struct { statusCode int } -// newLoggingResponseWriter wraps a ResponseWriter to capture the status code. -func newLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter { - return &loggingResponseWriter{w, http.StatusOK} +// newLoggingResponseWriter wraps a ResponseWriter +// to capture the status code. +func newLoggingResponseWriter( + writer http.ResponseWriter, +) *loggingResponseWriter { + return &loggingResponseWriter{ + ResponseWriter: writer, + statusCode: http.StatusOK, + } } func (lrw *loggingResponseWriter) WriteHeader(code int) { @@ -76,43 +85,57 @@ func (lrw *loggingResponseWriter) WriteHeader(code int) { } // Logging returns middleware that logs each HTTP request. -func (s *Middleware) Logging() func(http.Handler) http.Handler { +func (mware *Middleware) Logging() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - start := time.Now() - lrw := newLoggingResponseWriter(w) - ctx := r.Context() + return http.HandlerFunc( + func( + writer http.ResponseWriter, + request *http.Request, + ) { + start := time.Now() + lrw := newLoggingResponseWriter(writer) + ctx := request.Context() - defer func() { - latency := time.Since(start) + defer func() { + latency := time.Since(start) - reqID, _ := ctx.Value(middleware.RequestIDKey).(string) + reqID, _ := ctx.Value( + chimw.RequestIDKey, + ).(string) - s.log.InfoContext(ctx, "request", - "request_start", start, - "method", r.Method, - "url", r.URL.String(), - "useragent", r.UserAgent(), - "request_id", reqID, - "referer", r.Referer(), - "proto", r.Proto, - "remoteIP", ipFromHostPort(r.RemoteAddr), - "status", lrw.statusCode, - "latency_ms", latency.Milliseconds(), - ) - }() + mware.log.InfoContext( + ctx, "request", + "request_start", start, + "method", request.Method, + "url", request.URL.String(), + "useragent", request.UserAgent(), + "request_id", reqID, + "referer", request.Referer(), + "proto", request.Proto, + "remoteIP", + ipFromHostPort(request.RemoteAddr), + "status", lrw.statusCode, + "latency_ms", + latency.Milliseconds(), + ) + }() - next.ServeHTTP(lrw, r) - }) + next.ServeHTTP(lrw, request) + }) } } // CORS returns middleware that handles Cross-Origin Resource Sharing. -func (s *Middleware) CORS() func(http.Handler) http.Handler { - return cors.Handler(cors.Options{ - AllowedOrigins: []string{"*"}, - AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, - AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, +func (mware *Middleware) CORS() func(http.Handler) http.Handler { + return cors.Handler(cors.Options{ //nolint:exhaustruct // optional fields + AllowedOrigins: []string{"*"}, + AllowedMethods: []string{ + "GET", "POST", "PUT", "DELETE", "OPTIONS", + }, + AllowedHeaders: []string{ + "Accept", "Authorization", + "Content-Type", "X-CSRF-Token", + }, ExposedHeaders: []string{"Link"}, AllowCredentials: false, MaxAge: corsMaxAge, @@ -120,28 +143,34 @@ func (s *Middleware) CORS() func(http.Handler) http.Handler { } // Auth returns middleware that performs authentication. -func (s *Middleware) Auth() func(http.Handler) http.Handler { +func (mware *Middleware) Auth() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - s.log.Info("AUTH: before request") - next.ServeHTTP(w, r) - }) + return http.HandlerFunc( + func( + writer http.ResponseWriter, + request *http.Request, + ) { + mware.log.Info("AUTH: before request") + next.ServeHTTP(writer, request) + }) } } // Metrics returns middleware that records HTTP metrics. -func (s *Middleware) Metrics() func(http.Handler) http.Handler { - mdlw := ghmm.New(ghmm.Config{ - Recorder: metrics.NewRecorder(metrics.Config{}), +func (mware *Middleware) Metrics() func(http.Handler) http.Handler { + metricsMiddleware := ghmm.New(ghmm.Config{ //nolint:exhaustruct // optional fields + Recorder: metrics.NewRecorder( + metrics.Config{}, //nolint:exhaustruct // defaults + ), }) return func(next http.Handler) http.Handler { - return std.Handler("", mdlw, next) + return std.Handler("", metricsMiddleware, next) } } // MetricsAuth returns middleware that protects metrics with basic auth. -func (s *Middleware) MetricsAuth() func(http.Handler) http.Handler { +func (mware *Middleware) MetricsAuth() func(http.Handler) http.Handler { return basicauth.New( "metrics", map[string][]string{ diff --git a/internal/server/routes.go b/internal/server/routes.go index c9ad7c7..ba49ad9 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -17,67 +17,94 @@ import ( const routeTimeout = 60 * time.Second // SetupRoutes configures the HTTP routes and middleware. -func (s *Server) SetupRoutes() { - s.router = chi.NewRouter() +func (srv *Server) SetupRoutes() { + srv.router = chi.NewRouter() - s.router.Use(middleware.Recoverer) - s.router.Use(middleware.RequestID) - s.router.Use(s.mw.Logging()) + srv.router.Use(middleware.Recoverer) + srv.router.Use(middleware.RequestID) + srv.router.Use(srv.mw.Logging()) if viper.GetString("METRICS_USERNAME") != "" { - s.router.Use(s.mw.Metrics()) + srv.router.Use(srv.mw.Metrics()) } - s.router.Use(s.mw.CORS()) - s.router.Use(middleware.Timeout(routeTimeout)) + srv.router.Use(srv.mw.CORS()) + srv.router.Use(middleware.Timeout(routeTimeout)) - if s.sentryEnabled { - sentryHandler := sentryhttp.New(sentryhttp.Options{ - Repanic: true, - }) - s.router.Use(sentryHandler.Handle) + if srv.sentryEnabled { + sentryHandler := sentryhttp.New( + sentryhttp.Options{ //nolint:exhaustruct // optional fields + Repanic: true, + }, + ) + + srv.router.Use(sentryHandler.Handle) } - // Health check - s.router.Get( + // Health check. + srv.router.Get( "/.well-known/healthcheck.json", - s.h.HandleHealthCheck(), + srv.handlers.HandleHealthCheck(), ) - // Protected metrics endpoint + // Protected metrics endpoint. if viper.GetString("METRICS_USERNAME") != "" { - s.router.Group(func(r chi.Router) { - r.Use(s.mw.MetricsAuth()) - r.Get("/metrics", + srv.router.Group(func(router chi.Router) { + router.Use(srv.mw.MetricsAuth()) + router.Get("/metrics", http.HandlerFunc( promhttp.Handler().ServeHTTP, )) }) } - // API v1 - s.router.Route("/api/v1", func(r chi.Router) { - r.Get("/server", s.h.HandleServerInfo()) - r.Post("/session", s.h.HandleCreateSession()) - r.Get("/state", s.h.HandleState()) - r.Get("/messages", s.h.HandleGetMessages()) - r.Post("/messages", s.h.HandleSendCommand()) - r.Get("/history", s.h.HandleGetHistory()) - r.Get("/channels", s.h.HandleListAllChannels()) - r.Get( - "/channels/{channel}/members", - s.h.HandleChannelMembers(), - ) - }) + // API v1. + srv.router.Route( + "/api/v1", + func(router chi.Router) { + router.Get( + "/server", + srv.handlers.HandleServerInfo(), + ) + router.Post( + "/session", + srv.handlers.HandleCreateSession(), + ) + router.Get( + "/state", + srv.handlers.HandleState(), + ) + router.Get( + "/messages", + srv.handlers.HandleGetMessages(), + ) + router.Post( + "/messages", + srv.handlers.HandleSendCommand(), + ) + router.Get( + "/history", + srv.handlers.HandleGetHistory(), + ) + router.Get( + "/channels", + srv.handlers.HandleListAllChannels(), + ) + router.Get( + "/channels/{channel}/members", + srv.handlers.HandleChannelMembers(), + ) + }, + ) - // Serve embedded SPA - s.setupSPA() + // Serve embedded SPA. + srv.setupSPA() } -func (s *Server) setupSPA() { +func (srv *Server) setupSPA() { distFS, err := fs.Sub(web.Dist, "dist") if err != nil { - s.log.Error( + srv.log.Error( "failed to get web dist filesystem", "error", err, ) @@ -87,38 +114,40 @@ func (s *Server) setupSPA() { fileServer := http.FileServer(http.FS(distFS)) - s.router.Get("/*", func( - w http.ResponseWriter, - r *http.Request, + srv.router.Get("/*", func( + writer http.ResponseWriter, + request *http.Request, ) { readFS, ok := distFS.(fs.ReadFileFS) if !ok { - fileServer.ServeHTTP(w, r) + fileServer.ServeHTTP(writer, request) return } - f, readErr := readFS.ReadFile(r.URL.Path[1:]) - if readErr != nil || len(f) == 0 { + fileData, readErr := readFS.ReadFile( + request.URL.Path[1:], + ) + if readErr != nil || len(fileData) == 0 { indexHTML, indexErr := readFS.ReadFile( "index.html", ) if indexErr != nil { - http.NotFound(w, r) + http.NotFound(writer, request) return } - w.Header().Set( + writer.Header().Set( "Content-Type", "text/html; charset=utf-8", ) - w.WriteHeader(http.StatusOK) - _, _ = w.Write(indexHTML) + writer.WriteHeader(http.StatusOK) + _, _ = writer.Write(indexHTML) return } - fileServer.ServeHTTP(w, r) + fileServer.ServeHTTP(writer, request) }) } diff --git a/internal/server/server.go b/internal/server/server.go index f19af2c..b6d04c5 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -41,7 +41,8 @@ type Params struct { Handlers *handlers.Handlers } -// Server is the main HTTP server. It manages routing, middleware, and lifecycle. +// Server is the main HTTP server. +// It manages routing, middleware, and lifecycle. type Server struct { startupTime time.Time exitCode int @@ -53,21 +54,24 @@ type Server struct { router *chi.Mux params Params mw *middleware.Middleware - h *handlers.Handlers + handlers *handlers.Handlers } // New creates a new Server and registers its lifecycle hooks. -func New(lc fx.Lifecycle, params Params) (*Server, error) { - s := new(Server) - s.params = params - s.mw = params.Middleware - s.h = params.Handlers - s.log = params.Logger.Get() +func New( + lifecycle fx.Lifecycle, params Params, +) (*Server, error) { + srv := &Server{ //nolint:exhaustruct // fields set during lifecycle + params: params, + mw: params.Middleware, + handlers: params.Handlers, + log: params.Logger.Get(), + } - lc.Append(fx.Hook{ + lifecycle.Append(fx.Hook{ OnStart: func(_ context.Context) error { - s.startupTime = time.Now() - go s.Run() //nolint:contextcheck + srv.startupTime = time.Now() + go srv.Run() //nolint:contextcheck return nil }, @@ -76,122 +80,140 @@ func New(lc fx.Lifecycle, params Params) (*Server, error) { }, }) - return s, nil + return srv, nil } // Run starts the server configuration, Sentry, and begins serving. -func (s *Server) Run() { - s.configure() - s.enableSentry() - s.serve() +func (srv *Server) Run() { + srv.configure() + srv.enableSentry() + srv.serve() } // ServeHTTP delegates to the chi router. -func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - s.router.ServeHTTP(w, r) +func (srv *Server) ServeHTTP( + writer http.ResponseWriter, + request *http.Request, +) { + srv.router.ServeHTTP(writer, request) } // MaintenanceMode reports whether the server is in maintenance mode. -func (s *Server) MaintenanceMode() bool { - return s.params.Config.MaintenanceMode +func (srv *Server) MaintenanceMode() bool { + return srv.params.Config.MaintenanceMode } -func (s *Server) enableSentry() { - s.sentryEnabled = false +func (srv *Server) enableSentry() { + srv.sentryEnabled = false - if s.params.Config.SentryDSN == "" { + if srv.params.Config.SentryDSN == "" { return } - err := sentry.Init(sentry.ClientOptions{ - Dsn: s.params.Config.SentryDSN, - Release: fmt.Sprintf("%s-%s", s.params.Globals.Appname, s.params.Globals.Version), + err := sentry.Init(sentry.ClientOptions{ //nolint:exhaustruct // only essential fields + Dsn: srv.params.Config.SentryDSN, + Release: fmt.Sprintf( + "%s-%s", + srv.params.Globals.Appname, + srv.params.Globals.Version, + ), }) if err != nil { - s.log.Error("sentry init failure", "error", err) + srv.log.Error("sentry init failure", "error", err) os.Exit(1) } - s.log.Info("sentry error reporting activated") - s.sentryEnabled = true + srv.log.Info("sentry error reporting activated") + srv.sentryEnabled = true } -func (s *Server) serve() int { - s.ctx, s.cancelFunc = context.WithCancel(context.Background()) +func (srv *Server) serve() int { + srv.ctx, srv.cancelFunc = context.WithCancel( + context.Background(), + ) go func() { - c := make(chan os.Signal, 1) + sigCh := make(chan os.Signal, 1) signal.Ignore(syscall.SIGPIPE) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - sig := <-c - s.log.Info("signal received", "signal", sig) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) - if s.cancelFunc != nil { - s.cancelFunc() + sig := <-sigCh + + srv.log.Info("signal received", "signal", sig) + + if srv.cancelFunc != nil { + srv.cancelFunc() } }() - go s.serveUntilShutdown() + go srv.serveUntilShutdown() - <-s.ctx.Done() + <-srv.ctx.Done() - s.cleanShutdown() + srv.cleanShutdown() - return s.exitCode + return srv.exitCode } -func (s *Server) cleanupForExit() { - s.log.Info("cleaning up") +func (srv *Server) cleanupForExit() { + srv.log.Info("cleaning up") } -func (s *Server) cleanShutdown() { - s.exitCode = 0 +func (srv *Server) cleanShutdown() { + srv.exitCode = 0 ctxShutdown, shutdownCancel := context.WithTimeout( context.Background(), shutdownTimeout, ) - err := s.httpServer.Shutdown(ctxShutdown) + err := srv.httpServer.Shutdown(ctxShutdown) if err != nil { - s.log.Error("server clean shutdown failed", "error", err) + srv.log.Error( + "server clean shutdown failed", "error", err, + ) } if shutdownCancel != nil { shutdownCancel() } - s.cleanupForExit() + srv.cleanupForExit() - if s.sentryEnabled { + if srv.sentryEnabled { sentry.Flush(sentryFlushTime) } } -func (s *Server) configure() { - // server configuration placeholder +func (srv *Server) configure() { + // Server configuration placeholder. } -func (s *Server) serveUntilShutdown() { - listenAddr := fmt.Sprintf(":%d", s.params.Config.Port) - s.httpServer = &http.Server{ +func (srv *Server) serveUntilShutdown() { + listenAddr := fmt.Sprintf( + ":%d", srv.params.Config.Port, + ) + + srv.httpServer = &http.Server{ //nolint:exhaustruct // optional fields Addr: listenAddr, ReadTimeout: httpReadTimeout, WriteTimeout: httpWriteTimeout, MaxHeaderBytes: maxHeaderBytes, - Handler: s, + Handler: srv, } - s.SetupRoutes() + srv.SetupRoutes() - s.log.Info("http begin listen", "listenaddr", listenAddr) + srv.log.Info( + "http begin listen", "listenaddr", listenAddr, + ) - err := s.httpServer.ListenAndServe() + err := srv.httpServer.ListenAndServe() if err != nil && !errors.Is(err, http.ErrServerClosed) { - s.log.Error("listen error", "error", err) + srv.log.Error("listen error", "error", err) - if s.cancelFunc != nil { - s.cancelFunc() + if srv.cancelFunc != nil { + srv.cancelFunc() } } }