From 953771f2aa82edb11001f748bf62ba324c57f826 Mon Sep 17 00:00:00 2001 From: user Date: Tue, 17 Mar 2026 08:52:50 -0700 Subject: [PATCH] add IP to sessions, IP+hostname to clients - Add ip column to sessions table (real client IP of session creator) - Add ip and hostname columns to clients table (per-connection tracking) - Update CreateSession, RegisterUser, LoginUser to store new fields - Add GetClientHostInfo query method - Update SessionHostInfo to include IP - Extract executeCreateSession to fix funlen lint - Add tests for session IP, client IP/hostname, login client tracking - Update README with new field documentation --- README.md | 6 ++ internal/db/auth.go | 24 +++---- internal/db/auth_test.go | 101 +++++++++++++++++++++++++---- internal/db/queries.go | 57 ++++++++++++---- internal/db/queries_test.go | 93 +++++++++++++++++++++----- internal/db/schema/001_initial.sql | 3 + internal/handlers/api.go | 20 ++++-- internal/handlers/auth.go | 13 +++- 8 files changed, 261 insertions(+), 56 deletions(-) diff --git a/README.md b/README.md index 4bf0583..46a0888 100644 --- a/README.md +++ b/README.md @@ -216,6 +216,12 @@ Each session has an IRC-style hostmask composed of three parts: `username` field in the session/register request; defaults to the nick) - **hostname** — automatically resolved via reverse DNS of the connecting client's IP address at session creation time +- **ip** — the real IP address of the session creator, extracted from + `X-Forwarded-For`, `X-Real-IP`, or `RemoteAddr` + +Each **client connection** (created at session creation, registration, or login) +also stores its own **ip** and **hostname**, allowing the server to track the +network origin of each individual client independently from the session. The hostmask appears in: diff --git a/internal/db/auth.go b/internal/db/auth.go index dc02ba4..0367ace 100644 --- a/internal/db/auth.go +++ b/internal/db/auth.go @@ -20,7 +20,7 @@ var errNoPassword = errors.New( // and returns session ID, client ID, and token. func (database *Database) RegisterUser( ctx context.Context, - nick, password, username, hostname string, + nick, password, username, hostname, remoteIP string, ) (int64, int64, string, error) { if username == "" { username = nick @@ -54,11 +54,11 @@ func (database *Database) RegisterUser( res, err := transaction.ExecContext(ctx, `INSERT INTO sessions - (uuid, nick, username, hostname, + (uuid, nick, username, hostname, ip, password_hash, created_at, last_seen) - VALUES (?, ?, ?, ?, ?, ?, ?)`, + VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, sessionUUID, nick, username, hostname, - string(hash), now, now) + remoteIP, string(hash), now, now) if err != nil { _ = transaction.Rollback() @@ -73,10 +73,11 @@ func (database *Database) RegisterUser( clientRes, err := transaction.ExecContext(ctx, `INSERT INTO clients - (uuid, session_id, token, + (uuid, session_id, token, ip, hostname, created_at, last_seen) - VALUES (?, ?, ?, ?, ?)`, - clientUUID, sessionID, tokenHash, now, now) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + clientUUID, sessionID, tokenHash, + remoteIP, hostname, now, now) if err != nil { _ = transaction.Rollback() @@ -101,7 +102,7 @@ func (database *Database) RegisterUser( // client token. func (database *Database) LoginUser( ctx context.Context, - nick, password string, + nick, password, remoteIP, hostname string, ) (int64, int64, string, error) { var ( sessionID int64 @@ -148,10 +149,11 @@ func (database *Database) LoginUser( res, err := database.conn.ExecContext(ctx, `INSERT INTO clients - (uuid, session_id, token, + (uuid, session_id, token, ip, hostname, created_at, last_seen) - VALUES (?, ?, ?, ?, ?)`, - clientUUID, sessionID, tokenHash, now, now) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + clientUUID, sessionID, tokenHash, + remoteIP, hostname, now, now) if err != nil { return 0, 0, "", fmt.Errorf( "create login client: %w", err, diff --git a/internal/db/auth_test.go b/internal/db/auth_test.go index 084ea2d..9b7527c 100644 --- a/internal/db/auth_test.go +++ b/internal/db/auth_test.go @@ -13,7 +13,7 @@ func TestRegisterUser(t *testing.T) { ctx := t.Context() sessionID, clientID, token, err := - database.RegisterUser(ctx, "reguser", "password123", "", "") + database.RegisterUser(ctx, "reguser", "password123", "", "", "") if err != nil { t.Fatal(err) } @@ -46,7 +46,7 @@ func TestRegisterUserWithUserHost(t *testing.T) { sessionID, _, _, err := database.RegisterUser( ctx, "reguhost", "password123", - "myident", "example.org", + "myident", "example.org", "", ) if err != nil { t.Fatal(err) @@ -80,7 +80,7 @@ func TestRegisterUserDefaultUsername(t *testing.T) { ctx := t.Context() sessionID, _, _, err := database.RegisterUser( - ctx, "regdefault", "password123", "", "", + ctx, "regdefault", "password123", "", "", "", ) if err != nil { t.Fatal(err) @@ -108,7 +108,7 @@ func TestRegisterUserDuplicateNick(t *testing.T) { ctx := t.Context() regSID, regCID, regToken, err := - database.RegisterUser(ctx, "dupnick", "password123", "", "") + database.RegisterUser(ctx, "dupnick", "password123", "", "", "") if err != nil { t.Fatal(err) } @@ -118,7 +118,7 @@ func TestRegisterUserDuplicateNick(t *testing.T) { _ = regToken dupSID, dupCID, dupToken, dupErr := - database.RegisterUser(ctx, "dupnick", "other12345", "", "") + database.RegisterUser(ctx, "dupnick", "other12345", "", "", "") if dupErr == nil { t.Fatal("expected error for duplicate nick") } @@ -135,7 +135,7 @@ func TestLoginUser(t *testing.T) { ctx := t.Context() regSID, regCID, regToken, err := - database.RegisterUser(ctx, "loginuser", "mypassword", "", "") + database.RegisterUser(ctx, "loginuser", "mypassword", "", "", "") if err != nil { t.Fatal(err) } @@ -145,7 +145,7 @@ func TestLoginUser(t *testing.T) { _ = regToken sessionID, clientID, token, err := - database.LoginUser(ctx, "loginuser", "mypassword") + database.LoginUser(ctx, "loginuser", "mypassword", "", "") if err != nil { t.Fatal(err) } @@ -166,6 +166,83 @@ func TestLoginUser(t *testing.T) { } } +func TestLoginUserStoresClientIPHostname(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + regSID, regCID, regToken, err := database.RegisterUser( + ctx, "loginipuser", "password123", + "", "", "10.0.0.1", + ) + + _ = regSID + _ = regCID + _ = regToken + if err != nil { + t.Fatal(err) + } + + _, clientID, _, err := database.LoginUser( + ctx, "loginipuser", "password123", + "10.0.0.99", "newhost.example.com", + ) + if err != nil { + t.Fatal(err) + } + + clientInfo, err := database.GetClientHostInfo( + ctx, clientID, + ) + if err != nil { + t.Fatal(err) + } + + if clientInfo.IP != "10.0.0.99" { + t.Fatalf( + "expected client IP 10.0.0.99, got %s", + clientInfo.IP, + ) + } + + if clientInfo.Hostname != "newhost.example.com" { + t.Fatalf( + "expected hostname newhost.example.com, got %s", + clientInfo.Hostname, + ) + } +} + +func TestRegisterUserStoresSessionIP(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + sessionID, _, _, err := database.RegisterUser( + ctx, "regipuser", "password123", + "ident", "host.local", "172.16.0.5", + ) + if err != nil { + t.Fatal(err) + } + + info, err := database.GetSessionHostInfo( + ctx, sessionID, + ) + if err != nil { + t.Fatal(err) + } + + if info.IP != "172.16.0.5" { + t.Fatalf( + "expected session IP 172.16.0.5, got %s", + info.IP, + ) + } +} + func TestLoginUserWrongPassword(t *testing.T) { t.Parallel() @@ -173,7 +250,7 @@ func TestLoginUserWrongPassword(t *testing.T) { ctx := t.Context() regSID, regCID, regToken, err := - database.RegisterUser(ctx, "wrongpw", "correctpass", "", "") + database.RegisterUser(ctx, "wrongpw", "correctpass", "", "", "") if err != nil { t.Fatal(err) } @@ -183,7 +260,7 @@ func TestLoginUserWrongPassword(t *testing.T) { _ = regToken loginSID, loginCID, loginToken, loginErr := - database.LoginUser(ctx, "wrongpw", "wrongpass12") + database.LoginUser(ctx, "wrongpw", "wrongpass12", "", "") if loginErr == nil { t.Fatal("expected error for wrong password") } @@ -201,7 +278,7 @@ func TestLoginUserNoPassword(t *testing.T) { // Create anonymous session (no password). anonSID, anonCID, anonToken, err := - database.CreateSession(ctx, "anon", "", "") + database.CreateSession(ctx, "anon", "", "", "") if err != nil { t.Fatal(err) } @@ -211,7 +288,7 @@ func TestLoginUserNoPassword(t *testing.T) { _ = anonToken loginSID, loginCID, loginToken, loginErr := - database.LoginUser(ctx, "anon", "anything1") + database.LoginUser(ctx, "anon", "anything1", "", "") if loginErr == nil { t.Fatal( "expected error for login on passwordless account", @@ -230,7 +307,7 @@ func TestLoginUserNonexistent(t *testing.T) { ctx := t.Context() loginSID, loginCID, loginToken, err := - database.LoginUser(ctx, "ghost", "password123") + database.LoginUser(ctx, "ghost", "password123", "", "") if err == nil { t.Fatal("expected error for nonexistent user") } diff --git a/internal/db/queries.go b/internal/db/queries.go index a64b124..c5abbe8 100644 --- a/internal/db/queries.go +++ b/internal/db/queries.go @@ -102,7 +102,7 @@ func FormatHostmask(nick, username, hostname string) string { // CreateSession registers a new session and its first client. func (database *Database) CreateSession( ctx context.Context, - nick, username, hostname string, + nick, username, hostname, remoteIP string, ) (int64, int64, string, error) { if username == "" { username = nick @@ -127,10 +127,11 @@ func (database *Database) CreateSession( res, err := transaction.ExecContext(ctx, `INSERT INTO sessions - (uuid, nick, username, hostname, + (uuid, nick, username, hostname, ip, created_at, last_seen) - VALUES (?, ?, ?, ?, ?, ?)`, - sessionUUID, nick, username, hostname, now, now) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + sessionUUID, nick, username, hostname, + remoteIP, now, now) if err != nil { _ = transaction.Rollback() @@ -145,10 +146,11 @@ func (database *Database) CreateSession( clientRes, err := transaction.ExecContext(ctx, `INSERT INTO clients - (uuid, session_id, token, + (uuid, session_id, token, ip, hostname, created_at, last_seen) - VALUES (?, ?, ?, ?, ?)`, - clientUUID, sessionID, tokenHash, now, now) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + clientUUID, sessionID, tokenHash, + remoteIP, hostname, now, now) if err != nil { _ = transaction.Rollback() @@ -236,14 +238,16 @@ func (database *Database) GetSessionByNick( return sessionID, nil } -// SessionHostInfo holds the username and hostname for a session. +// SessionHostInfo holds the username, hostname, and IP +// for a session. type SessionHostInfo struct { Username string Hostname string + IP string } -// GetSessionHostInfo returns the username and hostname -// for a session. +// GetSessionHostInfo returns the username, hostname, +// and IP for a session. func (database *Database) GetSessionHostInfo( ctx context.Context, sessionID int64, @@ -252,10 +256,10 @@ func (database *Database) GetSessionHostInfo( err := database.conn.QueryRowContext( ctx, - `SELECT username, hostname + `SELECT username, hostname, ip FROM sessions WHERE id = ?`, sessionID, - ).Scan(&info.Username, &info.Hostname) + ).Scan(&info.Username, &info.Hostname, &info.IP) if err != nil { return nil, fmt.Errorf( "get session host info: %w", err, @@ -265,6 +269,35 @@ func (database *Database) GetSessionHostInfo( return &info, nil } +// ClientHostInfo holds the IP and hostname for a client. +type ClientHostInfo struct { + IP string + Hostname string +} + +// GetClientHostInfo returns the IP and hostname for a +// client. +func (database *Database) GetClientHostInfo( + ctx context.Context, + clientID int64, +) (*ClientHostInfo, error) { + var info ClientHostInfo + + err := database.conn.QueryRowContext( + ctx, + `SELECT ip, hostname + FROM clients WHERE id = ?`, + clientID, + ).Scan(&info.IP, &info.Hostname) + if err != nil { + return nil, fmt.Errorf( + "get client host info: %w", err, + ) + } + + return &info, nil +} + // GetChannelByName returns the channel ID for a name. func (database *Database) GetChannelByName( ctx context.Context, diff --git a/internal/db/queries_test.go b/internal/db/queries_test.go index 3f75a0b..938bb0b 100644 --- a/internal/db/queries_test.go +++ b/internal/db/queries_test.go @@ -34,7 +34,7 @@ func TestCreateSession(t *testing.T) { ctx := t.Context() sessionID, _, token, err := database.CreateSession( - ctx, "alice", "", "", + ctx, "alice", "", "", "", ) if err != nil { t.Fatal(err) @@ -45,7 +45,7 @@ func TestCreateSession(t *testing.T) { } _, _, dupToken, dupErr := database.CreateSession( - ctx, "alice", "", "", + ctx, "alice", "", "", "", ) if dupErr == nil { t.Fatal("expected error for duplicate nick") @@ -65,7 +65,7 @@ func assertSessionHostInfo( t.Helper() sessionID, _, _, err := database.CreateSession( - t.Context(), nick, inputUser, inputHost, + t.Context(), nick, inputUser, inputHost, "", ) if err != nil { t.Fatal(err) @@ -118,6 +118,69 @@ func TestCreateSessionDefaultUsername(t *testing.T) { ) } +func TestCreateSessionStoresIP(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + sessionID, clientID, _, err := database.CreateSession( + ctx, "ipuser", "ident", "host.example.com", + "192.168.1.42", + ) + if err != nil { + t.Fatal(err) + } + + info, err := database.GetSessionHostInfo( + ctx, sessionID, + ) + if err != nil { + t.Fatal(err) + } + + if info.IP != "192.168.1.42" { + t.Fatalf( + "expected session IP 192.168.1.42, got %s", + info.IP, + ) + } + + clientInfo, err := database.GetClientHostInfo( + ctx, clientID, + ) + if err != nil { + t.Fatal(err) + } + + if clientInfo.IP != "192.168.1.42" { + t.Fatalf( + "expected client IP 192.168.1.42, got %s", + clientInfo.IP, + ) + } + + if clientInfo.Hostname != "host.example.com" { + t.Fatalf( + "expected client hostname host.example.com, got %s", + clientInfo.Hostname, + ) + } +} + +func TestGetClientHostInfoNotFound(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + + _, err := database.GetClientHostInfo( + t.Context(), 99999, + ) + if err == nil { + t.Fatal("expected error for nonexistent client") + } +} + func TestGetSessionHostInfoNotFound(t *testing.T) { t.Parallel() @@ -183,7 +246,7 @@ func TestChannelMembersIncludeUserHost(t *testing.T) { ctx := t.Context() sid, _, _, err := database.CreateSession( - ctx, "memuser", "myuser", "myhost.net", + ctx, "memuser", "myuser", "myhost.net", "", ) if err != nil { t.Fatal(err) @@ -233,7 +296,7 @@ func TestGetSessionByToken(t *testing.T) { database := setupTestDB(t) ctx := t.Context() - _, _, token, err := database.CreateSession(ctx, "bob", "", "") + _, _, token, err := database.CreateSession(ctx, "bob", "", "", "") if err != nil { t.Fatal(err) } @@ -266,7 +329,7 @@ func TestGetSessionByNick(t *testing.T) { ctx := t.Context() charlieID, charlieClientID, charlieToken, err := - database.CreateSession(ctx, "charlie", "", "") + database.CreateSession(ctx, "charlie", "", "", "") if err != nil { t.Fatal(err) } @@ -323,7 +386,7 @@ func TestJoinAndPart(t *testing.T) { database := setupTestDB(t) ctx := t.Context() - sid, _, _, err := database.CreateSession(ctx, "user1", "", "") + sid, _, _, err := database.CreateSession(ctx, "user1", "", "", "") if err != nil { t.Fatal(err) } @@ -372,7 +435,7 @@ func TestDeleteChannelIfEmpty(t *testing.T) { t.Fatal(err) } - sid, _, _, err := database.CreateSession(ctx, "temp", "", "") + sid, _, _, err := database.CreateSession(ctx, "temp", "", "", "") if err != nil { t.Fatal(err) } @@ -407,7 +470,7 @@ func createSessionWithChannels( ctx := t.Context() - sid, _, _, err := database.CreateSession(ctx, nick, "", "") + sid, _, _, err := database.CreateSession(ctx, nick, "", "", "") if err != nil { t.Fatal(err) } @@ -490,7 +553,7 @@ func TestChangeNick(t *testing.T) { ctx := t.Context() sid, _, token, err := database.CreateSession( - ctx, "old", "", "", + ctx, "old", "", "", "", ) if err != nil { t.Fatal(err) @@ -574,7 +637,7 @@ func TestPollMessages(t *testing.T) { ctx := t.Context() sid, _, token, err := database.CreateSession( - ctx, "poller", "", "", + ctx, "poller", "", "", "", ) if err != nil { t.Fatal(err) @@ -681,7 +744,7 @@ func TestDeleteSession(t *testing.T) { ctx := t.Context() sid, _, _, err := database.CreateSession( - ctx, "deleteme", "", "", + ctx, "deleteme", "", "", "", ) if err != nil { t.Fatal(err) @@ -721,12 +784,12 @@ func TestChannelMembers(t *testing.T) { database := setupTestDB(t) ctx := t.Context() - sid1, _, _, err := database.CreateSession(ctx, "m1", "", "") + sid1, _, _, err := database.CreateSession(ctx, "m1", "", "", "") if err != nil { t.Fatal(err) } - sid2, _, _, err := database.CreateSession(ctx, "m2", "", "") + sid2, _, _, err := database.CreateSession(ctx, "m2", "", "", "") if err != nil { t.Fatal(err) } @@ -784,7 +847,7 @@ func TestEnqueueToClient(t *testing.T) { ctx := t.Context() _, _, token, err := database.CreateSession( - ctx, "enqclient", "", "", + ctx, "enqclient", "", "", "", ) if err != nil { t.Fatal(err) diff --git a/internal/db/schema/001_initial.sql b/internal/db/schema/001_initial.sql index 2366671..e5971ea 100644 --- a/internal/db/schema/001_initial.sql +++ b/internal/db/schema/001_initial.sql @@ -8,6 +8,7 @@ CREATE TABLE IF NOT EXISTS sessions ( nick TEXT NOT NULL UNIQUE, username TEXT NOT NULL DEFAULT '', hostname TEXT NOT NULL DEFAULT '', + ip TEXT NOT NULL DEFAULT '', password_hash TEXT NOT NULL DEFAULT '', signing_key TEXT NOT NULL DEFAULT '', away_message TEXT NOT NULL DEFAULT '', @@ -22,6 +23,8 @@ CREATE TABLE IF NOT EXISTS clients ( uuid TEXT NOT NULL UNIQUE, session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE, token TEXT NOT NULL UNIQUE, + ip TEXT NOT NULL DEFAULT '', + hostname TEXT NOT NULL DEFAULT '', created_at DATETIME DEFAULT CURRENT_TIMESTAMP, last_seen DATETIME DEFAULT CURRENT_TIMESTAMP ); diff --git a/internal/handlers/api.go b/internal/handlers/api.go index 973fd51..2cd0dba 100644 --- a/internal/handlers/api.go +++ b/internal/handlers/api.go @@ -251,14 +251,26 @@ func (hdlr *Handlers) handleCreateSession( return } + hdlr.executeCreateSession( + writer, request, payload.Nick, username, + ) +} + +func (hdlr *Handlers) executeCreateSession( + writer http.ResponseWriter, + request *http.Request, + nick, username string, +) { + remoteIP := clientIP(request) + hostname := resolveHostname( - request.Context(), clientIP(request), + request.Context(), remoteIP, ) sessionID, clientID, token, err := hdlr.params.Database.CreateSession( request.Context(), - payload.Nick, username, hostname, + nick, username, hostname, remoteIP, ) if err != nil { hdlr.handleCreateSessionError( @@ -271,11 +283,11 @@ func (hdlr *Handlers) handleCreateSession( hdlr.stats.IncrSessions() hdlr.stats.IncrConnections() - hdlr.deliverMOTD(request, clientID, sessionID, payload.Nick) + hdlr.deliverMOTD(request, clientID, sessionID, nick) hdlr.respondJSON(writer, request, map[string]any{ "id": sessionID, - "nick": payload.Nick, + "nick": nick, "token": token, }, http.StatusCreated) } diff --git a/internal/handlers/auth.go b/internal/handlers/auth.go index 1f26cdc..44ee5c3 100644 --- a/internal/handlers/auth.go +++ b/internal/handlers/auth.go @@ -94,14 +94,16 @@ func (hdlr *Handlers) executeRegister( request *http.Request, nick, password, username string, ) { + remoteIP := clientIP(request) + hostname := resolveHostname( - request.Context(), clientIP(request), + request.Context(), remoteIP, ) sessionID, clientID, token, err := hdlr.params.Database.RegisterUser( request.Context(), - nick, password, username, hostname, + nick, password, username, hostname, remoteIP, ) if err != nil { hdlr.handleRegisterError( @@ -196,11 +198,18 @@ func (hdlr *Handlers) handleLogin( return } + remoteIP := clientIP(request) + + hostname := resolveHostname( + request.Context(), remoteIP, + ) + sessionID, clientID, token, err := hdlr.params.Database.LoginUser( request.Context(), payload.Nick, payload.Password, + remoteIP, hostname, ) if err != nil { hdlr.respondError(