diff --git a/README.md b/README.md index 37525cd..4bf0583 100644 --- a/README.md +++ b/README.md @@ -207,6 +207,26 @@ removal. Identity verification at the message layer via cryptographic signatures (see [Security Model](#security-model)) remains independent of account registration. +### Hostmask (nick!user@host) + +Each session has an IRC-style hostmask composed of three parts: + +- **nick** — the user's current nick (changes with `NICK` command) +- **username** — an ident-like identifier set at session creation (optional + `username` field in the session/register request; defaults to the nick) +- **hostname** — automatically resolved via reverse DNS of the connecting + client's IP address at session creation time + +The hostmask appears in: + +- **WHOIS** (`311 RPL_WHOISUSER`) — `params` contains + `[nick, username, hostname, "*"]` +- **WHO** (`352 RPL_WHOREPLY`) — `params` contains + `[channel, username, hostname, server, nick, flags]` + +The hostmask format (`nick!user@host`) is stored for future use in ban matching +(`+b` mode) and other access control features. + ### Nick Semantics - Nicks are **unique per server at any point in time** — two sessions cannot @@ -976,7 +996,7 @@ the server to the client (never C2S) and use 3-digit string codes in the | `252` | RPL_LUSEROP | On connect or LUSERS command | `{"command":"252","to":"alice","params":["0"],"body":["operator(s) online"]}` | | `254` | RPL_LUSERCHANNELS | On connect or LUSERS command | `{"command":"254","to":"alice","params":["3"],"body":["channels formed"]}` | | `255` | RPL_LUSERME | On connect or LUSERS command | `{"command":"255","to":"alice","body":["I have 5 clients and 1 servers"]}` | -| `311` | RPL_WHOISUSER | In response to WHOIS | `{"command":"311","to":"alice","params":["bob","bob","neoirc","*"],"body":["bob"]}` | +| `311` | RPL_WHOISUSER | In response to WHOIS | `{"command":"311","to":"alice","params":["bob","bobident","host.example.com","*"],"body":["bob"]}` | | `312` | RPL_WHOISSERVER | In response to WHOIS | `{"command":"312","to":"alice","params":["bob","neoirc"],"body":["neoirc server"]}` | | `315` | RPL_ENDOFWHO | End of WHO response | `{"command":"315","to":"alice","params":["#general"],"body":["End of /WHO list"]}` | | `318` | RPL_ENDOFWHOIS | End of WHOIS response | `{"command":"318","to":"alice","params":["bob"],"body":["End of /WHOIS list"]}` | @@ -987,7 +1007,7 @@ the server to the client (never C2S) and use 3-digit string codes in the | `329` | RPL_CREATIONTIME | After channel MODE query | `{"command":"329","to":"alice","params":["#general","1709251200"]}` | | `331` | RPL_NOTOPIC | Channel has no topic (on JOIN) | `{"command":"331","to":"alice","params":["#general"],"body":["No topic is set"]}` | | `332` | RPL_TOPIC | On JOIN or TOPIC query | `{"command":"332","to":"alice","params":["#general"],"body":["Welcome!"]}` | -| `352` | RPL_WHOREPLY | In response to WHO | `{"command":"352","to":"alice","params":["#general","bob","neoirc","neoirc","bob","H"],"body":["0 bob"]}` | +| `352` | RPL_WHOREPLY | In response to WHO | `{"command":"352","to":"alice","params":["#general","bobident","host.example.com","neoirc","bob","H"],"body":["0 bob"]}` | | `353` | RPL_NAMREPLY | On JOIN or NAMES query | `{"command":"353","to":"alice","params":["=","#general"],"body":["@op1 alice bob +voiced1"]}` | | `366` | RPL_ENDOFNAMES | End of NAMES response | `{"command":"366","to":"alice","params":["#general"],"body":["End of /NAMES list"]}` | | `372` | RPL_MOTD | MOTD line | `{"command":"372","to":"alice","body":["Welcome to the server"]}` | @@ -1056,14 +1076,20 @@ difficulty is advertised via `GET /api/v1/server` in the `hashcash_bits` field. **Request Body:** ```json -{"nick": "alice", "pow_token": "1:20:260310:neoirc::3a2f1"} +{"nick": "alice", "username": "alice", "pow_token": "1:20:260310:neoirc::3a2f1"} ``` | Field | Type | Required | Constraints | |------------|--------|-------------|-------------| | `nick` | string | Yes | 1–32 characters, must be unique on the server | +| `username` | string | No | 1–32 characters, IRC ident-style. Defaults to nick if omitted. | | `pow_token` | string | Conditional | Hashcash stamp (required when server has `hashcash_bits` > 0) | +The `username` field sets the user portion of the IRC hostmask +(`nick!user@host`). The hostname is automatically resolved via reverse DNS of +the connecting client's IP address at session creation time. Together these form +the hostmask used in WHOIS, WHO, and future ban matching (`+b`). + **Response:** `201 Created` ```json { @@ -1084,6 +1110,7 @@ difficulty is advertised via `GET /api/v1/server` in the `hashcash_bits` field. | Status | Error | When | |--------|-------|------| | 400 | `nick must be 1-32 characters` | Empty or too-long nick | +| 400 | `invalid username format` | Username doesn't match allowed format | | 402 | `hashcash proof-of-work required` | Missing `pow_token` field in request body when hashcash is enabled | | 402 | `invalid hashcash stamp: ...` | Stamp fails validation (wrong bits, expired, reused, etc.) | | 409 | `nick already taken` | Another active session holds this nick | @@ -1105,14 +1132,18 @@ remains active. **Request Body:** ```json -{"nick": "alice", "password": "mypassword"} +{"nick": "alice", "username": "alice", "password": "mypassword"} ``` | Field | Type | Required | Constraints | |------------|--------|----------|-------------| | `nick` | string | Yes | 1–32 characters, must be unique on the server | +| `username` | string | No | 1–32 characters, IRC ident-style. Defaults to nick if omitted. | | `password` | string | Yes | Minimum 8 characters | +The `username` and hostname (auto-resolved via reverse DNS) form the IRC +hostmask (`nick!user@host`) shown in WHOIS and WHO responses. + **Response:** `201 Created` ```json { @@ -1133,6 +1164,7 @@ remains active. | Status | Error | When | |--------|-------|------| | 400 | `invalid nick format` | Nick doesn't match allowed format | +| 400 | `invalid username format` | Username doesn't match allowed format | | 400 | `password must be at least 8 characters` | Password too short | | 409 | `nick already taken` | Another active session holds this nick | @@ -1941,6 +1973,8 @@ The database schema is managed via embedded SQL migration files in | `id` | INTEGER | Primary key (auto-increment) | | `uuid` | TEXT | Unique session UUID | | `nick` | TEXT | Unique nick | +| `username` | TEXT | IRC ident/username portion of the hostmask (defaults to nick) | +| `hostname` | TEXT | Reverse DNS hostname of the connecting client IP | | `password_hash`| TEXT | bcrypt hash (empty string for anonymous sessions) | | `signing_key` | TEXT | Public signing key (empty string if unset) | | `away_message` | TEXT | Away message (empty string if not away) | diff --git a/internal/db/auth.go b/internal/db/auth.go index 7bf18bd..dc02ba4 100644 --- a/internal/db/auth.go +++ b/internal/db/auth.go @@ -20,8 +20,12 @@ var errNoPassword = errors.New( // and returns session ID, client ID, and token. func (database *Database) RegisterUser( ctx context.Context, - nick, password string, + nick, password, username, hostname string, ) (int64, int64, string, error) { + if username == "" { + username = nick + } + hash, err := bcrypt.GenerateFromPassword( []byte(password), bcryptCost, ) @@ -50,10 +54,11 @@ func (database *Database) RegisterUser( res, err := transaction.ExecContext(ctx, `INSERT INTO sessions - (uuid, nick, password_hash, - created_at, last_seen) - VALUES (?, ?, ?, ?, ?)`, - sessionUUID, nick, string(hash), now, now) + (uuid, nick, username, hostname, + password_hash, created_at, last_seen) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + sessionUUID, nick, username, hostname, + string(hash), now, now) if err != nil { _ = transaction.Rollback() diff --git a/internal/db/auth_test.go b/internal/db/auth_test.go index 5188925..084ea2d 100644 --- a/internal/db/auth_test.go +++ b/internal/db/auth_test.go @@ -13,7 +13,7 @@ func TestRegisterUser(t *testing.T) { ctx := t.Context() sessionID, clientID, token, err := - database.RegisterUser(ctx, "reguser", "password123") + database.RegisterUser(ctx, "reguser", "password123", "", "") if err != nil { t.Fatal(err) } @@ -38,6 +38,69 @@ func TestRegisterUser(t *testing.T) { } } +func TestRegisterUserWithUserHost(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + sessionID, _, _, err := database.RegisterUser( + ctx, "reguhost", "password123", + "myident", "example.org", + ) + if err != nil { + t.Fatal(err) + } + + info, err := database.GetSessionHostInfo( + ctx, sessionID, + ) + if err != nil { + t.Fatal(err) + } + + if info.Username != "myident" { + t.Fatalf( + "expected myident, got %s", info.Username, + ) + } + + if info.Hostname != "example.org" { + t.Fatalf( + "expected example.org, got %s", + info.Hostname, + ) + } +} + +func TestRegisterUserDefaultUsername(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + sessionID, _, _, err := database.RegisterUser( + ctx, "regdefault", "password123", "", "", + ) + if err != nil { + t.Fatal(err) + } + + info, err := database.GetSessionHostInfo( + ctx, sessionID, + ) + if err != nil { + t.Fatal(err) + } + + if info.Username != "regdefault" { + t.Fatalf( + "expected regdefault, got %s", + info.Username, + ) + } +} + func TestRegisterUserDuplicateNick(t *testing.T) { t.Parallel() @@ -45,7 +108,7 @@ func TestRegisterUserDuplicateNick(t *testing.T) { ctx := t.Context() regSID, regCID, regToken, err := - database.RegisterUser(ctx, "dupnick", "password123") + database.RegisterUser(ctx, "dupnick", "password123", "", "") if err != nil { t.Fatal(err) } @@ -55,7 +118,7 @@ func TestRegisterUserDuplicateNick(t *testing.T) { _ = regToken dupSID, dupCID, dupToken, dupErr := - database.RegisterUser(ctx, "dupnick", "other12345") + database.RegisterUser(ctx, "dupnick", "other12345", "", "") if dupErr == nil { t.Fatal("expected error for duplicate nick") } @@ -72,7 +135,7 @@ func TestLoginUser(t *testing.T) { ctx := t.Context() regSID, regCID, regToken, err := - database.RegisterUser(ctx, "loginuser", "mypassword") + database.RegisterUser(ctx, "loginuser", "mypassword", "", "") if err != nil { t.Fatal(err) } @@ -110,7 +173,7 @@ func TestLoginUserWrongPassword(t *testing.T) { ctx := t.Context() regSID, regCID, regToken, err := - database.RegisterUser(ctx, "wrongpw", "correctpass") + database.RegisterUser(ctx, "wrongpw", "correctpass", "", "") if err != nil { t.Fatal(err) } @@ -138,7 +201,7 @@ func TestLoginUserNoPassword(t *testing.T) { // Create anonymous session (no password). anonSID, anonCID, anonToken, err := - database.CreateSession(ctx, "anon") + database.CreateSession(ctx, "anon", "", "") if err != nil { t.Fatal(err) } diff --git a/internal/db/queries.go b/internal/db/queries.go index ff13174..a64b124 100644 --- a/internal/db/queries.go +++ b/internal/db/queries.go @@ -74,14 +74,40 @@ type ChannelInfo struct { type MemberInfo struct { ID int64 `json:"id"` Nick string `json:"nick"` + Username string `json:"username"` + Hostname string `json:"hostname"` LastSeen time.Time `json:"lastSeen"` } +// Hostmask returns the IRC hostmask in +// nick!user@host format. +func (m *MemberInfo) Hostmask() string { + return FormatHostmask(m.Nick, m.Username, m.Hostname) +} + +// FormatHostmask formats a nick, username, and hostname +// into a standard IRC hostmask string (nick!user@host). +func FormatHostmask(nick, username, hostname string) string { + if username == "" { + username = nick + } + + if hostname == "" { + hostname = "*" + } + + return nick + "!" + username + "@" + hostname +} + // CreateSession registers a new session and its first client. func (database *Database) CreateSession( ctx context.Context, - nick string, + nick, username, hostname string, ) (int64, int64, string, error) { + if username == "" { + username = nick + } + sessionUUID := uuid.New().String() clientUUID := uuid.New().String() @@ -101,9 +127,10 @@ func (database *Database) CreateSession( res, err := transaction.ExecContext(ctx, `INSERT INTO sessions - (uuid, nick, created_at, last_seen) - VALUES (?, ?, ?, ?)`, - sessionUUID, nick, now, now) + (uuid, nick, username, hostname, + created_at, last_seen) + VALUES (?, ?, ?, ?, ?, ?)`, + sessionUUID, nick, username, hostname, now, now) if err != nil { _ = transaction.Rollback() @@ -209,6 +236,35 @@ func (database *Database) GetSessionByNick( return sessionID, nil } +// SessionHostInfo holds the username and hostname for a session. +type SessionHostInfo struct { + Username string + Hostname string +} + +// GetSessionHostInfo returns the username and hostname +// for a session. +func (database *Database) GetSessionHostInfo( + ctx context.Context, + sessionID int64, +) (*SessionHostInfo, error) { + var info SessionHostInfo + + err := database.conn.QueryRowContext( + ctx, + `SELECT username, hostname + FROM sessions WHERE id = ?`, + sessionID, + ).Scan(&info.Username, &info.Hostname) + if err != nil { + return nil, fmt.Errorf( + "get session host info: %w", err, + ) + } + + return &info, nil +} + // GetChannelByName returns the channel ID for a name. func (database *Database) GetChannelByName( ctx context.Context, @@ -388,7 +444,8 @@ func (database *Database) ChannelMembers( channelID int64, ) ([]MemberInfo, error) { rows, err := database.conn.QueryContext(ctx, - `SELECT s.id, s.nick, s.last_seen + `SELECT s.id, s.nick, s.username, + s.hostname, s.last_seen FROM sessions s INNER JOIN channel_members cm ON cm.session_id = s.id @@ -408,7 +465,9 @@ func (database *Database) ChannelMembers( var member MemberInfo err = rows.Scan( - &member.ID, &member.Nick, &member.LastSeen, + &member.ID, &member.Nick, + &member.Username, &member.Hostname, + &member.LastSeen, ) if err != nil { return nil, fmt.Errorf( diff --git a/internal/db/queries_test.go b/internal/db/queries_test.go index 15814a2..3f75a0b 100644 --- a/internal/db/queries_test.go +++ b/internal/db/queries_test.go @@ -34,7 +34,7 @@ func TestCreateSession(t *testing.T) { ctx := t.Context() sessionID, _, token, err := database.CreateSession( - ctx, "alice", + ctx, "alice", "", "", ) if err != nil { t.Fatal(err) @@ -45,7 +45,7 @@ func TestCreateSession(t *testing.T) { } _, _, dupToken, dupErr := database.CreateSession( - ctx, "alice", + ctx, "alice", "", "", ) if dupErr == nil { t.Fatal("expected error for duplicate nick") @@ -54,13 +54,186 @@ func TestCreateSession(t *testing.T) { _ = dupToken } +// assertSessionHostInfo creates a session and verifies +// the stored username and hostname match expectations. +func assertSessionHostInfo( + t *testing.T, + database *db.Database, + nick, inputUser, inputHost, + expectUser, expectHost string, +) { + t.Helper() + + sessionID, _, _, err := database.CreateSession( + t.Context(), nick, inputUser, inputHost, + ) + if err != nil { + t.Fatal(err) + } + + info, err := database.GetSessionHostInfo( + t.Context(), sessionID, + ) + if err != nil { + t.Fatal(err) + } + + if info.Username != expectUser { + t.Fatalf( + "expected username %s, got %s", + expectUser, info.Username, + ) + } + + if info.Hostname != expectHost { + t.Fatalf( + "expected hostname %s, got %s", + expectHost, info.Hostname, + ) + } +} + +func TestCreateSessionWithUserHost(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + + assertSessionHostInfo( + t, database, + "hostuser", "myident", "example.com", + "myident", "example.com", + ) +} + +func TestCreateSessionDefaultUsername(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + + // Empty username defaults to nick. + assertSessionHostInfo( + t, database, + "defaultu", "", "host.local", + "defaultu", "host.local", + ) +} + +func TestGetSessionHostInfoNotFound(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + + _, err := database.GetSessionHostInfo( + t.Context(), 99999, + ) + if err == nil { + t.Fatal("expected error for nonexistent session") + } +} + +func TestFormatHostmask(t *testing.T) { + t.Parallel() + + result := db.FormatHostmask( + "nick", "user", "host.com", + ) + if result != "nick!user@host.com" { + t.Fatalf( + "expected nick!user@host.com, got %s", + result, + ) + } +} + +func TestFormatHostmaskDefaults(t *testing.T) { + t.Parallel() + + result := db.FormatHostmask("nick", "", "") + if result != "nick!nick@*" { + t.Fatalf( + "expected nick!nick@*, got %s", + result, + ) + } +} + +func TestMemberInfoHostmask(t *testing.T) { + t.Parallel() + + member := &db.MemberInfo{ //nolint:exhaustruct // test only uses hostmask fields + Nick: "alice", + Username: "aliceident", + Hostname: "alice.example.com", + } + + hostmask := member.Hostmask() + expected := "alice!aliceident@alice.example.com" + + if hostmask != expected { + t.Fatalf( + "expected %s, got %s", expected, hostmask, + ) + } +} + +func TestChannelMembersIncludeUserHost(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + sid, _, _, err := database.CreateSession( + ctx, "memuser", "myuser", "myhost.net", + ) + if err != nil { + t.Fatal(err) + } + + chID, err := database.GetOrCreateChannel( + ctx, "#hostchan", + ) + if err != nil { + t.Fatal(err) + } + + err = database.JoinChannel(ctx, chID, sid) + if err != nil { + t.Fatal(err) + } + + members, err := database.ChannelMembers(ctx, chID) + if err != nil { + t.Fatal(err) + } + + if len(members) != 1 { + t.Fatalf( + "expected 1 member, got %d", len(members), + ) + } + + if members[0].Username != "myuser" { + t.Fatalf( + "expected username myuser, got %s", + members[0].Username, + ) + } + + if members[0].Hostname != "myhost.net" { + t.Fatalf( + "expected hostname myhost.net, got %s", + members[0].Hostname, + ) + } +} + func TestGetSessionByToken(t *testing.T) { t.Parallel() database := setupTestDB(t) ctx := t.Context() - _, _, token, err := database.CreateSession(ctx, "bob") + _, _, token, err := database.CreateSession(ctx, "bob", "", "") if err != nil { t.Fatal(err) } @@ -93,7 +266,7 @@ func TestGetSessionByNick(t *testing.T) { ctx := t.Context() charlieID, charlieClientID, charlieToken, err := - database.CreateSession(ctx, "charlie") + database.CreateSession(ctx, "charlie", "", "") if err != nil { t.Fatal(err) } @@ -150,7 +323,7 @@ func TestJoinAndPart(t *testing.T) { database := setupTestDB(t) ctx := t.Context() - sid, _, _, err := database.CreateSession(ctx, "user1") + sid, _, _, err := database.CreateSession(ctx, "user1", "", "") if err != nil { t.Fatal(err) } @@ -199,7 +372,7 @@ func TestDeleteChannelIfEmpty(t *testing.T) { t.Fatal(err) } - sid, _, _, err := database.CreateSession(ctx, "temp") + sid, _, _, err := database.CreateSession(ctx, "temp", "", "") if err != nil { t.Fatal(err) } @@ -234,7 +407,7 @@ func createSessionWithChannels( ctx := t.Context() - sid, _, _, err := database.CreateSession(ctx, nick) + sid, _, _, err := database.CreateSession(ctx, nick, "", "") if err != nil { t.Fatal(err) } @@ -317,7 +490,7 @@ func TestChangeNick(t *testing.T) { ctx := t.Context() sid, _, token, err := database.CreateSession( - ctx, "old", + ctx, "old", "", "", ) if err != nil { t.Fatal(err) @@ -401,7 +574,7 @@ func TestPollMessages(t *testing.T) { ctx := t.Context() sid, _, token, err := database.CreateSession( - ctx, "poller", + ctx, "poller", "", "", ) if err != nil { t.Fatal(err) @@ -508,7 +681,7 @@ func TestDeleteSession(t *testing.T) { ctx := t.Context() sid, _, _, err := database.CreateSession( - ctx, "deleteme", + ctx, "deleteme", "", "", ) if err != nil { t.Fatal(err) @@ -548,12 +721,12 @@ func TestChannelMembers(t *testing.T) { database := setupTestDB(t) ctx := t.Context() - sid1, _, _, err := database.CreateSession(ctx, "m1") + sid1, _, _, err := database.CreateSession(ctx, "m1", "", "") if err != nil { t.Fatal(err) } - sid2, _, _, err := database.CreateSession(ctx, "m2") + sid2, _, _, err := database.CreateSession(ctx, "m2", "", "") if err != nil { t.Fatal(err) } @@ -611,7 +784,7 @@ func TestEnqueueToClient(t *testing.T) { ctx := t.Context() _, _, token, err := database.CreateSession( - ctx, "enqclient", + ctx, "enqclient", "", "", ) if err != nil { t.Fatal(err) diff --git a/internal/db/schema/001_initial.sql b/internal/db/schema/001_initial.sql index 4ea5e28..2366671 100644 --- a/internal/db/schema/001_initial.sql +++ b/internal/db/schema/001_initial.sql @@ -6,6 +6,8 @@ CREATE TABLE IF NOT EXISTS sessions ( id INTEGER PRIMARY KEY AUTOINCREMENT, uuid TEXT NOT NULL UNIQUE, nick TEXT NOT NULL UNIQUE, + username TEXT NOT NULL DEFAULT '', + hostname TEXT NOT NULL DEFAULT '', password_hash TEXT NOT NULL DEFAULT '', signing_key TEXT NOT NULL DEFAULT '', away_message TEXT NOT NULL DEFAULT '', diff --git a/internal/handlers/api.go b/internal/handlers/api.go index 7a06e49..973fd51 100644 --- a/internal/handlers/api.go +++ b/internal/handlers/api.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "net" "net/http" "regexp" "strconv" @@ -23,6 +24,12 @@ var validChannelRe = regexp.MustCompile( `^#[a-zA-Z0-9_\-]{1,63}$`, ) +var validUsernameRe = regexp.MustCompile( + `^[a-zA-Z0-9_\-\[\]\\^{}|` + "`" + `]{1,32}$`, +) + +const dnsLookupTimeout = 3 * time.Second + const ( maxLongPollTimeout = 30 pollMessageLimit = 100 @@ -39,6 +46,55 @@ func (hdlr *Handlers) maxBodySize() int64 { return defaultMaxBodySize } +// clientIP extracts the connecting client's IP address +// from the request, checking X-Forwarded-For and +// X-Real-IP headers before falling back to RemoteAddr. +func clientIP(request *http.Request) string { + if forwarded := request.Header.Get("X-Forwarded-For"); forwarded != "" { + // X-Forwarded-For can contain a comma-separated list; + // the first entry is the original client. + parts := strings.SplitN(forwarded, ",", 2) //nolint:mnd + ip := strings.TrimSpace(parts[0]) + + if ip != "" { + return ip + } + } + + if realIP := request.Header.Get("X-Real-IP"); realIP != "" { + return strings.TrimSpace(realIP) + } + + host, _, err := net.SplitHostPort(request.RemoteAddr) + if err != nil { + return request.RemoteAddr + } + + return host +} + +// resolveHostname performs a reverse DNS lookup on the +// given IP address. Returns the first PTR record with the +// trailing dot stripped, or the raw IP if lookup fails. +func resolveHostname( + reqCtx context.Context, + addr string, +) string { + resolver := &net.Resolver{} //nolint:exhaustruct // using default resolver + + ctx, cancel := context.WithTimeout( + reqCtx, dnsLookupTimeout, + ) + defer cancel() + + names, err := resolver.LookupAddr(ctx, addr) + if err != nil || len(names) == 0 { + return addr + } + + return strings.TrimSuffix(names[0], ".") +} + // authSession extracts the session from the client token. func (hdlr *Handlers) authSession( request *http.Request, @@ -146,6 +202,7 @@ func (hdlr *Handlers) handleCreateSession( ) { type createRequest struct { Nick string `json:"nick"` + Username string `json:"username,omitempty"` Hashcash string `json:"pow_token,omitempty"` //nolint:tagliatelle } @@ -162,30 +219,10 @@ func (hdlr *Handlers) handleCreateSession( return } - // Validate hashcash proof-of-work if configured. - if hdlr.params.Config.HashcashBits > 0 { - if payload.Hashcash == "" { - hdlr.respondError( - writer, request, - "hashcash proof-of-work required", - http.StatusPaymentRequired, - ) - - return - } - - err = hdlr.hashcashVal.Validate( - payload.Hashcash, hdlr.params.Config.HashcashBits, - ) - if err != nil { - hdlr.respondError( - writer, request, - "invalid hashcash stamp: "+err.Error(), - http.StatusPaymentRequired, - ) - - return - } + if !hdlr.validateHashcash( + writer, request, payload.Hashcash, + ) { + return } payload.Nick = strings.TrimSpace(payload.Nick) @@ -200,9 +237,28 @@ func (hdlr *Handlers) handleCreateSession( return } + username := resolveUsername( + payload.Username, payload.Nick, + ) + + if !validUsernameRe.MatchString(username) { + hdlr.respondError( + writer, request, + "invalid username format", + http.StatusBadRequest, + ) + + return + } + + hostname := resolveHostname( + request.Context(), clientIP(request), + ) + sessionID, clientID, token, err := hdlr.params.Database.CreateSession( - request.Context(), payload.Nick, + request.Context(), + payload.Nick, username, hostname, ) if err != nil { hdlr.handleCreateSessionError( @@ -224,6 +280,55 @@ func (hdlr *Handlers) handleCreateSession( }, http.StatusCreated) } +// validateHashcash validates a hashcash stamp if required. +// Returns false if validation failed and a response was +// already sent. +func (hdlr *Handlers) validateHashcash( + writer http.ResponseWriter, + request *http.Request, + stamp string, +) bool { + if hdlr.params.Config.HashcashBits == 0 { + return true + } + + if stamp == "" { + hdlr.respondError( + writer, request, + "hashcash proof-of-work required", + http.StatusPaymentRequired, + ) + + return false + } + + err := hdlr.hashcashVal.Validate( + stamp, hdlr.params.Config.HashcashBits, + ) + if err != nil { + hdlr.respondError( + writer, request, + "invalid hashcash stamp: "+err.Error(), + http.StatusPaymentRequired, + ) + + return false + } + + return true +} + +// resolveUsername returns the trimmed username, defaulting +// to the nick if empty. +func resolveUsername(username, nick string) string { + username = strings.TrimSpace(username) + if username == "" { + return nick + } + + return username +} + func (hdlr *Handlers) handleCreateSessionError( writer http.ResponseWriter, request *http.Request, @@ -2105,10 +2210,26 @@ func (hdlr *Handlers) executeWhois( return } + // Look up username and hostname for the target. + username := queryNick + hostname := srvName + + hostInfo, hostErr := hdlr.params.Database. + GetSessionHostInfo(ctx, targetSID) + if hostErr == nil && hostInfo != nil { + if hostInfo.Username != "" { + username = hostInfo.Username + } + + if hostInfo.Hostname != "" { + hostname = hostInfo.Hostname + } + } + // 311 RPL_WHOISUSER hdlr.enqueueNumeric( ctx, clientID, irc.RplWhoisUser, nick, - []string{queryNick, queryNick, srvName, "*"}, + []string{queryNick, username, hostname, "*"}, queryNick, ) @@ -2215,11 +2336,21 @@ func (hdlr *Handlers) handleWho( ) if memErr == nil { for _, mem := range members { + username := mem.Username + if username == "" { + username = mem.Nick + } + + hostname := mem.Hostname + if hostname == "" { + hostname = srvName + } + // 352 RPL_WHOREPLY hdlr.enqueueNumeric( ctx, clientID, irc.RplWhoReply, nick, []string{ - channel, mem.Nick, srvName, + channel, username, hostname, srvName, mem.Nick, "H", }, "0 "+mem.Nick, diff --git a/internal/handlers/api_test.go b/internal/handlers/api_test.go index 68a1d2e..38dccff 100644 --- a/internal/handlers/api_test.go +++ b/internal/handlers/api_test.go @@ -2130,6 +2130,249 @@ func TestSessionStillWorks(t *testing.T) { } } +// findNumericWithParams returns the first message matching +// the given numeric code. Returns nil if not found. +func findNumericWithParams( + msgs []map[string]any, + numeric string, +) map[string]any { + want, _ := strconv.Atoi(numeric) + + for _, msg := range msgs { + code, ok := msg["code"].(float64) + if ok && int(code) == want { + return msg + } + } + + return nil +} + +// getNumericParams extracts the params array from a +// numeric message as a string slice. +func getNumericParams( + msg map[string]any, +) []string { + raw, exists := msg["params"] + if !exists || raw == nil { + return nil + } + + arr, isArr := raw.([]any) + if !isArr { + return nil + } + + result := make([]string, 0, len(arr)) + + for _, val := range arr { + str, isString := val.(string) + if isString { + result = append(result, str) + } + } + + return result +} + +func TestWhoisShowsHostInfo(t *testing.T) { + tserver := newTestServer(t) + + token := tserver.createSessionWithUsername( + "whoisuser", "myident", + ) + + queryToken := tserver.createSession("querier") + + _, lastID := tserver.pollMessages(queryToken, 0) + + tserver.sendCommand(queryToken, map[string]any{ + commandKey: "WHOIS", + toKey: "whoisuser", + }) + + msgs, _ := tserver.pollMessages(queryToken, lastID) + + whoisMsg := findNumericWithParams(msgs, "311") + if whoisMsg == nil { + t.Fatalf( + "expected RPL_WHOISUSER (311), got %v", + msgs, + ) + } + + params := getNumericParams(whoisMsg) + + if len(params) < 2 { + t.Fatalf( + "expected at least 2 params, got %v", + params, + ) + } + + if params[1] != "myident" { + t.Fatalf( + "expected username myident, got %s", + params[1], + ) + } + + _ = token +} + +// createSessionWithUsername creates a session with a +// specific username and returns the token. +func (tserver *testServer) createSessionWithUsername( + nick, username string, +) string { + tserver.t.Helper() + + body, err := json.Marshal(map[string]string{ + "nick": nick, + "username": username, + }) + if err != nil { + tserver.t.Fatalf("marshal session: %v", err) + } + + resp, err := doRequest( + tserver.t, + http.MethodPost, + tserver.url(apiSession), + bytes.NewReader(body), + ) + if err != nil { + tserver.t.Fatalf("create session: %v", err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusCreated { + respBody, _ := io.ReadAll(resp.Body) + tserver.t.Fatalf( + "create session: status %d: %s", + resp.StatusCode, respBody, + ) + } + + var result struct { + Token string `json:"token"` + } + + _ = json.NewDecoder(resp.Body).Decode(&result) + + return result.Token +} + +func TestWhoShowsHostInfo(t *testing.T) { + tserver := newTestServer(t) + + whoToken := tserver.createSessionWithUsername( + "whouser", "whoident", + ) + + tserver.sendCommand(whoToken, map[string]any{ + commandKey: joinCmd, toKey: "#whotest", + }) + + queryToken := tserver.createSession("whoquerier") + + tserver.sendCommand(queryToken, map[string]any{ + commandKey: joinCmd, toKey: "#whotest", + }) + + _, lastID := tserver.pollMessages(queryToken, 0) + + tserver.sendCommand(queryToken, map[string]any{ + commandKey: "WHO", + toKey: "#whotest", + }) + + msgs, _ := tserver.pollMessages(queryToken, lastID) + + assertWhoReplyUsername(t, msgs, "whouser", "whoident") +} + +func assertWhoReplyUsername( + t *testing.T, + msgs []map[string]any, + targetNick, expectedUsername string, +) { + t.Helper() + + for _, msg := range msgs { + code, isCode := msg["code"].(float64) + if !isCode || int(code) != 352 { + continue + } + + params := getNumericParams(msg) + if len(params) < 5 || params[4] != targetNick { + continue + } + + if params[1] != expectedUsername { + t.Fatalf( + "expected username %s in WHO, got %s", + expectedUsername, params[1], + ) + } + + return + } + + t.Fatalf( + "expected RPL_WHOREPLY (352) for %s, msgs: %v", + targetNick, msgs, + ) +} + +func TestSessionUsernameDefault(t *testing.T) { + tserver := newTestServer(t) + + // Create session without specifying username. + token := tserver.createSession("defaultusr") + + queryToken := tserver.createSession("querier2") + + _, lastID := tserver.pollMessages(queryToken, 0) + + // WHOIS should show the nick as the username. + tserver.sendCommand(queryToken, map[string]any{ + commandKey: "WHOIS", + toKey: "defaultusr", + }) + + msgs, _ := tserver.pollMessages(queryToken, lastID) + + whoisMsg := findNumericWithParams(msgs, "311") + if whoisMsg == nil { + t.Fatalf( + "expected RPL_WHOISUSER (311), got %v", + msgs, + ) + } + + params := getNumericParams(whoisMsg) + + if len(params) < 2 { + t.Fatalf( + "expected at least 2 params, got %v", + params, + ) + } + + // Username defaults to nick. + if params[1] != "defaultusr" { + t.Fatalf( + "expected default username defaultusr, got %s", + params[1], + ) + } + + _ = token +} + func TestNickBroadcastToChannels(t *testing.T) { tserver := newTestServer(t) aliceToken := tserver.createSession("nick_a") diff --git a/internal/handlers/auth.go b/internal/handlers/auth.go index 293636f..1f26cdc 100644 --- a/internal/handlers/auth.go +++ b/internal/handlers/auth.go @@ -30,6 +30,7 @@ func (hdlr *Handlers) handleRegister( ) { type registerRequest struct { Nick string `json:"nick"` + Username string `json:"username,omitempty"` Password string `json:"password"` } @@ -58,6 +59,20 @@ func (hdlr *Handlers) handleRegister( return } + username := resolveUsername( + payload.Username, payload.Nick, + ) + + if !validUsernameRe.MatchString(username) { + hdlr.respondError( + writer, request, + "invalid username format", + http.StatusBadRequest, + ) + + return + } + if len(payload.Password) < minPasswordLength { hdlr.respondError( writer, request, @@ -68,11 +83,25 @@ func (hdlr *Handlers) handleRegister( return } + hdlr.executeRegister( + writer, request, + payload.Nick, payload.Password, username, + ) +} + +func (hdlr *Handlers) executeRegister( + writer http.ResponseWriter, + request *http.Request, + nick, password, username string, +) { + hostname := resolveHostname( + request.Context(), clientIP(request), + ) + sessionID, clientID, token, err := hdlr.params.Database.RegisterUser( request.Context(), - payload.Nick, - payload.Password, + nick, password, username, hostname, ) if err != nil { hdlr.handleRegisterError( @@ -85,11 +114,11 @@ func (hdlr *Handlers) handleRegister( hdlr.stats.IncrSessions() hdlr.stats.IncrConnections() - hdlr.deliverMOTD(request, clientID, sessionID, payload.Nick) + hdlr.deliverMOTD(request, clientID, sessionID, nick) hdlr.respondJSON(writer, request, map[string]any{ "id": sessionID, - "nick": payload.Nick, + "nick": nick, "token": token, }, http.StatusCreated) }