add IP to sessions, IP+hostname to clients
All checks were successful
check / check (push) Successful in 1m5s

- Add ip column to sessions table (real client IP of session creator)
- Add ip and hostname columns to clients table (per-connection tracking)
- Update CreateSession, RegisterUser, LoginUser to store new fields
- Add GetClientHostInfo query method
- Update SessionHostInfo to include IP
- Extract executeCreateSession to fix funlen lint
- Add tests for session IP, client IP/hostname, login client tracking
- Update README with new field documentation
This commit is contained in:
user
2026-03-17 08:52:50 -07:00
parent e42c6c1868
commit 953771f2aa
8 changed files with 261 additions and 56 deletions

View File

@@ -216,6 +216,12 @@ Each session has an IRC-style hostmask composed of three parts:
`username` field in the session/register request; defaults to the nick)
- **hostname** — automatically resolved via reverse DNS of the connecting
client's IP address at session creation time
- **ip** — the real IP address of the session creator, extracted from
`X-Forwarded-For`, `X-Real-IP`, or `RemoteAddr`
Each **client connection** (created at session creation, registration, or login)
also stores its own **ip** and **hostname**, allowing the server to track the
network origin of each individual client independently from the session.
The hostmask appears in:

View File

@@ -20,7 +20,7 @@ var errNoPassword = errors.New(
// and returns session ID, client ID, and token.
func (database *Database) RegisterUser(
ctx context.Context,
nick, password, username, hostname string,
nick, password, username, hostname, remoteIP string,
) (int64, int64, string, error) {
if username == "" {
username = nick
@@ -54,11 +54,11 @@ func (database *Database) RegisterUser(
res, err := transaction.ExecContext(ctx,
`INSERT INTO sessions
(uuid, nick, username, hostname,
(uuid, nick, username, hostname, ip,
password_hash, created_at, last_seen)
VALUES (?, ?, ?, ?, ?, ?, ?)`,
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
sessionUUID, nick, username, hostname,
string(hash), now, now)
remoteIP, string(hash), now, now)
if err != nil {
_ = transaction.Rollback()
@@ -73,10 +73,11 @@ func (database *Database) RegisterUser(
clientRes, err := transaction.ExecContext(ctx,
`INSERT INTO clients
(uuid, session_id, token,
(uuid, session_id, token, ip, hostname,
created_at, last_seen)
VALUES (?, ?, ?, ?, ?)`,
clientUUID, sessionID, tokenHash, now, now)
VALUES (?, ?, ?, ?, ?, ?, ?)`,
clientUUID, sessionID, tokenHash,
remoteIP, hostname, now, now)
if err != nil {
_ = transaction.Rollback()
@@ -101,7 +102,7 @@ func (database *Database) RegisterUser(
// client token.
func (database *Database) LoginUser(
ctx context.Context,
nick, password string,
nick, password, remoteIP, hostname string,
) (int64, int64, string, error) {
var (
sessionID int64
@@ -148,10 +149,11 @@ func (database *Database) LoginUser(
res, err := database.conn.ExecContext(ctx,
`INSERT INTO clients
(uuid, session_id, token,
(uuid, session_id, token, ip, hostname,
created_at, last_seen)
VALUES (?, ?, ?, ?, ?)`,
clientUUID, sessionID, tokenHash, now, now)
VALUES (?, ?, ?, ?, ?, ?, ?)`,
clientUUID, sessionID, tokenHash,
remoteIP, hostname, now, now)
if err != nil {
return 0, 0, "", fmt.Errorf(
"create login client: %w", err,

View File

@@ -13,7 +13,7 @@ func TestRegisterUser(t *testing.T) {
ctx := t.Context()
sessionID, clientID, token, err :=
database.RegisterUser(ctx, "reguser", "password123", "", "")
database.RegisterUser(ctx, "reguser", "password123", "", "", "")
if err != nil {
t.Fatal(err)
}
@@ -46,7 +46,7 @@ func TestRegisterUserWithUserHost(t *testing.T) {
sessionID, _, _, err := database.RegisterUser(
ctx, "reguhost", "password123",
"myident", "example.org",
"myident", "example.org", "",
)
if err != nil {
t.Fatal(err)
@@ -80,7 +80,7 @@ func TestRegisterUserDefaultUsername(t *testing.T) {
ctx := t.Context()
sessionID, _, _, err := database.RegisterUser(
ctx, "regdefault", "password123", "", "",
ctx, "regdefault", "password123", "", "", "",
)
if err != nil {
t.Fatal(err)
@@ -108,7 +108,7 @@ func TestRegisterUserDuplicateNick(t *testing.T) {
ctx := t.Context()
regSID, regCID, regToken, err :=
database.RegisterUser(ctx, "dupnick", "password123", "", "")
database.RegisterUser(ctx, "dupnick", "password123", "", "", "")
if err != nil {
t.Fatal(err)
}
@@ -118,7 +118,7 @@ func TestRegisterUserDuplicateNick(t *testing.T) {
_ = regToken
dupSID, dupCID, dupToken, dupErr :=
database.RegisterUser(ctx, "dupnick", "other12345", "", "")
database.RegisterUser(ctx, "dupnick", "other12345", "", "", "")
if dupErr == nil {
t.Fatal("expected error for duplicate nick")
}
@@ -135,7 +135,7 @@ func TestLoginUser(t *testing.T) {
ctx := t.Context()
regSID, regCID, regToken, err :=
database.RegisterUser(ctx, "loginuser", "mypassword", "", "")
database.RegisterUser(ctx, "loginuser", "mypassword", "", "", "")
if err != nil {
t.Fatal(err)
}
@@ -145,7 +145,7 @@ func TestLoginUser(t *testing.T) {
_ = regToken
sessionID, clientID, token, err :=
database.LoginUser(ctx, "loginuser", "mypassword")
database.LoginUser(ctx, "loginuser", "mypassword", "", "")
if err != nil {
t.Fatal(err)
}
@@ -166,6 +166,83 @@ func TestLoginUser(t *testing.T) {
}
}
func TestLoginUserStoresClientIPHostname(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
regSID, regCID, regToken, err := database.RegisterUser(
ctx, "loginipuser", "password123",
"", "", "10.0.0.1",
)
_ = regSID
_ = regCID
_ = regToken
if err != nil {
t.Fatal(err)
}
_, clientID, _, err := database.LoginUser(
ctx, "loginipuser", "password123",
"10.0.0.99", "newhost.example.com",
)
if err != nil {
t.Fatal(err)
}
clientInfo, err := database.GetClientHostInfo(
ctx, clientID,
)
if err != nil {
t.Fatal(err)
}
if clientInfo.IP != "10.0.0.99" {
t.Fatalf(
"expected client IP 10.0.0.99, got %s",
clientInfo.IP,
)
}
if clientInfo.Hostname != "newhost.example.com" {
t.Fatalf(
"expected hostname newhost.example.com, got %s",
clientInfo.Hostname,
)
}
}
func TestRegisterUserStoresSessionIP(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sessionID, _, _, err := database.RegisterUser(
ctx, "regipuser", "password123",
"ident", "host.local", "172.16.0.5",
)
if err != nil {
t.Fatal(err)
}
info, err := database.GetSessionHostInfo(
ctx, sessionID,
)
if err != nil {
t.Fatal(err)
}
if info.IP != "172.16.0.5" {
t.Fatalf(
"expected session IP 172.16.0.5, got %s",
info.IP,
)
}
}
func TestLoginUserWrongPassword(t *testing.T) {
t.Parallel()
@@ -173,7 +250,7 @@ func TestLoginUserWrongPassword(t *testing.T) {
ctx := t.Context()
regSID, regCID, regToken, err :=
database.RegisterUser(ctx, "wrongpw", "correctpass", "", "")
database.RegisterUser(ctx, "wrongpw", "correctpass", "", "", "")
if err != nil {
t.Fatal(err)
}
@@ -183,7 +260,7 @@ func TestLoginUserWrongPassword(t *testing.T) {
_ = regToken
loginSID, loginCID, loginToken, loginErr :=
database.LoginUser(ctx, "wrongpw", "wrongpass12")
database.LoginUser(ctx, "wrongpw", "wrongpass12", "", "")
if loginErr == nil {
t.Fatal("expected error for wrong password")
}
@@ -201,7 +278,7 @@ func TestLoginUserNoPassword(t *testing.T) {
// Create anonymous session (no password).
anonSID, anonCID, anonToken, err :=
database.CreateSession(ctx, "anon", "", "")
database.CreateSession(ctx, "anon", "", "", "")
if err != nil {
t.Fatal(err)
}
@@ -211,7 +288,7 @@ func TestLoginUserNoPassword(t *testing.T) {
_ = anonToken
loginSID, loginCID, loginToken, loginErr :=
database.LoginUser(ctx, "anon", "anything1")
database.LoginUser(ctx, "anon", "anything1", "", "")
if loginErr == nil {
t.Fatal(
"expected error for login on passwordless account",
@@ -230,7 +307,7 @@ func TestLoginUserNonexistent(t *testing.T) {
ctx := t.Context()
loginSID, loginCID, loginToken, err :=
database.LoginUser(ctx, "ghost", "password123")
database.LoginUser(ctx, "ghost", "password123", "", "")
if err == nil {
t.Fatal("expected error for nonexistent user")
}

View File

@@ -102,7 +102,7 @@ func FormatHostmask(nick, username, hostname string) string {
// CreateSession registers a new session and its first client.
func (database *Database) CreateSession(
ctx context.Context,
nick, username, hostname string,
nick, username, hostname, remoteIP string,
) (int64, int64, string, error) {
if username == "" {
username = nick
@@ -127,10 +127,11 @@ func (database *Database) CreateSession(
res, err := transaction.ExecContext(ctx,
`INSERT INTO sessions
(uuid, nick, username, hostname,
(uuid, nick, username, hostname, ip,
created_at, last_seen)
VALUES (?, ?, ?, ?, ?, ?)`,
sessionUUID, nick, username, hostname, now, now)
VALUES (?, ?, ?, ?, ?, ?, ?)`,
sessionUUID, nick, username, hostname,
remoteIP, now, now)
if err != nil {
_ = transaction.Rollback()
@@ -145,10 +146,11 @@ func (database *Database) CreateSession(
clientRes, err := transaction.ExecContext(ctx,
`INSERT INTO clients
(uuid, session_id, token,
(uuid, session_id, token, ip, hostname,
created_at, last_seen)
VALUES (?, ?, ?, ?, ?)`,
clientUUID, sessionID, tokenHash, now, now)
VALUES (?, ?, ?, ?, ?, ?, ?)`,
clientUUID, sessionID, tokenHash,
remoteIP, hostname, now, now)
if err != nil {
_ = transaction.Rollback()
@@ -236,14 +238,16 @@ func (database *Database) GetSessionByNick(
return sessionID, nil
}
// SessionHostInfo holds the username and hostname for a session.
// SessionHostInfo holds the username, hostname, and IP
// for a session.
type SessionHostInfo struct {
Username string
Hostname string
IP string
}
// GetSessionHostInfo returns the username and hostname
// for a session.
// GetSessionHostInfo returns the username, hostname,
// and IP for a session.
func (database *Database) GetSessionHostInfo(
ctx context.Context,
sessionID int64,
@@ -252,10 +256,10 @@ func (database *Database) GetSessionHostInfo(
err := database.conn.QueryRowContext(
ctx,
`SELECT username, hostname
`SELECT username, hostname, ip
FROM sessions WHERE id = ?`,
sessionID,
).Scan(&info.Username, &info.Hostname)
).Scan(&info.Username, &info.Hostname, &info.IP)
if err != nil {
return nil, fmt.Errorf(
"get session host info: %w", err,
@@ -265,6 +269,35 @@ func (database *Database) GetSessionHostInfo(
return &info, nil
}
// ClientHostInfo holds the IP and hostname for a client.
type ClientHostInfo struct {
IP string
Hostname string
}
// GetClientHostInfo returns the IP and hostname for a
// client.
func (database *Database) GetClientHostInfo(
ctx context.Context,
clientID int64,
) (*ClientHostInfo, error) {
var info ClientHostInfo
err := database.conn.QueryRowContext(
ctx,
`SELECT ip, hostname
FROM clients WHERE id = ?`,
clientID,
).Scan(&info.IP, &info.Hostname)
if err != nil {
return nil, fmt.Errorf(
"get client host info: %w", err,
)
}
return &info, nil
}
// GetChannelByName returns the channel ID for a name.
func (database *Database) GetChannelByName(
ctx context.Context,

View File

@@ -34,7 +34,7 @@ func TestCreateSession(t *testing.T) {
ctx := t.Context()
sessionID, _, token, err := database.CreateSession(
ctx, "alice", "", "",
ctx, "alice", "", "", "",
)
if err != nil {
t.Fatal(err)
@@ -45,7 +45,7 @@ func TestCreateSession(t *testing.T) {
}
_, _, dupToken, dupErr := database.CreateSession(
ctx, "alice", "", "",
ctx, "alice", "", "", "",
)
if dupErr == nil {
t.Fatal("expected error for duplicate nick")
@@ -65,7 +65,7 @@ func assertSessionHostInfo(
t.Helper()
sessionID, _, _, err := database.CreateSession(
t.Context(), nick, inputUser, inputHost,
t.Context(), nick, inputUser, inputHost, "",
)
if err != nil {
t.Fatal(err)
@@ -118,6 +118,69 @@ func TestCreateSessionDefaultUsername(t *testing.T) {
)
}
func TestCreateSessionStoresIP(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sessionID, clientID, _, err := database.CreateSession(
ctx, "ipuser", "ident", "host.example.com",
"192.168.1.42",
)
if err != nil {
t.Fatal(err)
}
info, err := database.GetSessionHostInfo(
ctx, sessionID,
)
if err != nil {
t.Fatal(err)
}
if info.IP != "192.168.1.42" {
t.Fatalf(
"expected session IP 192.168.1.42, got %s",
info.IP,
)
}
clientInfo, err := database.GetClientHostInfo(
ctx, clientID,
)
if err != nil {
t.Fatal(err)
}
if clientInfo.IP != "192.168.1.42" {
t.Fatalf(
"expected client IP 192.168.1.42, got %s",
clientInfo.IP,
)
}
if clientInfo.Hostname != "host.example.com" {
t.Fatalf(
"expected client hostname host.example.com, got %s",
clientInfo.Hostname,
)
}
}
func TestGetClientHostInfoNotFound(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
_, err := database.GetClientHostInfo(
t.Context(), 99999,
)
if err == nil {
t.Fatal("expected error for nonexistent client")
}
}
func TestGetSessionHostInfoNotFound(t *testing.T) {
t.Parallel()
@@ -183,7 +246,7 @@ func TestChannelMembersIncludeUserHost(t *testing.T) {
ctx := t.Context()
sid, _, _, err := database.CreateSession(
ctx, "memuser", "myuser", "myhost.net",
ctx, "memuser", "myuser", "myhost.net", "",
)
if err != nil {
t.Fatal(err)
@@ -233,7 +296,7 @@ func TestGetSessionByToken(t *testing.T) {
database := setupTestDB(t)
ctx := t.Context()
_, _, token, err := database.CreateSession(ctx, "bob", "", "")
_, _, token, err := database.CreateSession(ctx, "bob", "", "", "")
if err != nil {
t.Fatal(err)
}
@@ -266,7 +329,7 @@ func TestGetSessionByNick(t *testing.T) {
ctx := t.Context()
charlieID, charlieClientID, charlieToken, err :=
database.CreateSession(ctx, "charlie", "", "")
database.CreateSession(ctx, "charlie", "", "", "")
if err != nil {
t.Fatal(err)
}
@@ -323,7 +386,7 @@ func TestJoinAndPart(t *testing.T) {
database := setupTestDB(t)
ctx := t.Context()
sid, _, _, err := database.CreateSession(ctx, "user1", "", "")
sid, _, _, err := database.CreateSession(ctx, "user1", "", "", "")
if err != nil {
t.Fatal(err)
}
@@ -372,7 +435,7 @@ func TestDeleteChannelIfEmpty(t *testing.T) {
t.Fatal(err)
}
sid, _, _, err := database.CreateSession(ctx, "temp", "", "")
sid, _, _, err := database.CreateSession(ctx, "temp", "", "", "")
if err != nil {
t.Fatal(err)
}
@@ -407,7 +470,7 @@ func createSessionWithChannels(
ctx := t.Context()
sid, _, _, err := database.CreateSession(ctx, nick, "", "")
sid, _, _, err := database.CreateSession(ctx, nick, "", "", "")
if err != nil {
t.Fatal(err)
}
@@ -490,7 +553,7 @@ func TestChangeNick(t *testing.T) {
ctx := t.Context()
sid, _, token, err := database.CreateSession(
ctx, "old", "", "",
ctx, "old", "", "", "",
)
if err != nil {
t.Fatal(err)
@@ -574,7 +637,7 @@ func TestPollMessages(t *testing.T) {
ctx := t.Context()
sid, _, token, err := database.CreateSession(
ctx, "poller", "", "",
ctx, "poller", "", "", "",
)
if err != nil {
t.Fatal(err)
@@ -681,7 +744,7 @@ func TestDeleteSession(t *testing.T) {
ctx := t.Context()
sid, _, _, err := database.CreateSession(
ctx, "deleteme", "", "",
ctx, "deleteme", "", "", "",
)
if err != nil {
t.Fatal(err)
@@ -721,12 +784,12 @@ func TestChannelMembers(t *testing.T) {
database := setupTestDB(t)
ctx := t.Context()
sid1, _, _, err := database.CreateSession(ctx, "m1", "", "")
sid1, _, _, err := database.CreateSession(ctx, "m1", "", "", "")
if err != nil {
t.Fatal(err)
}
sid2, _, _, err := database.CreateSession(ctx, "m2", "", "")
sid2, _, _, err := database.CreateSession(ctx, "m2", "", "", "")
if err != nil {
t.Fatal(err)
}
@@ -784,7 +847,7 @@ func TestEnqueueToClient(t *testing.T) {
ctx := t.Context()
_, _, token, err := database.CreateSession(
ctx, "enqclient", "", "",
ctx, "enqclient", "", "", "",
)
if err != nil {
t.Fatal(err)

View File

@@ -8,6 +8,7 @@ CREATE TABLE IF NOT EXISTS sessions (
nick TEXT NOT NULL UNIQUE,
username TEXT NOT NULL DEFAULT '',
hostname TEXT NOT NULL DEFAULT '',
ip TEXT NOT NULL DEFAULT '',
password_hash TEXT NOT NULL DEFAULT '',
signing_key TEXT NOT NULL DEFAULT '',
away_message TEXT NOT NULL DEFAULT '',
@@ -22,6 +23,8 @@ CREATE TABLE IF NOT EXISTS clients (
uuid TEXT NOT NULL UNIQUE,
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
token TEXT NOT NULL UNIQUE,
ip TEXT NOT NULL DEFAULT '',
hostname TEXT NOT NULL DEFAULT '',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
last_seen DATETIME DEFAULT CURRENT_TIMESTAMP
);

View File

@@ -251,14 +251,26 @@ func (hdlr *Handlers) handleCreateSession(
return
}
hdlr.executeCreateSession(
writer, request, payload.Nick, username,
)
}
func (hdlr *Handlers) executeCreateSession(
writer http.ResponseWriter,
request *http.Request,
nick, username string,
) {
remoteIP := clientIP(request)
hostname := resolveHostname(
request.Context(), clientIP(request),
request.Context(), remoteIP,
)
sessionID, clientID, token, err :=
hdlr.params.Database.CreateSession(
request.Context(),
payload.Nick, username, hostname,
nick, username, hostname, remoteIP,
)
if err != nil {
hdlr.handleCreateSessionError(
@@ -271,11 +283,11 @@ func (hdlr *Handlers) handleCreateSession(
hdlr.stats.IncrSessions()
hdlr.stats.IncrConnections()
hdlr.deliverMOTD(request, clientID, sessionID, payload.Nick)
hdlr.deliverMOTD(request, clientID, sessionID, nick)
hdlr.respondJSON(writer, request, map[string]any{
"id": sessionID,
"nick": payload.Nick,
"nick": nick,
"token": token,
}, http.StatusCreated)
}

View File

@@ -94,14 +94,16 @@ func (hdlr *Handlers) executeRegister(
request *http.Request,
nick, password, username string,
) {
remoteIP := clientIP(request)
hostname := resolveHostname(
request.Context(), clientIP(request),
request.Context(), remoteIP,
)
sessionID, clientID, token, err :=
hdlr.params.Database.RegisterUser(
request.Context(),
nick, password, username, hostname,
nick, password, username, hostname, remoteIP,
)
if err != nil {
hdlr.handleRegisterError(
@@ -196,11 +198,18 @@ func (hdlr *Handlers) handleLogin(
return
}
remoteIP := clientIP(request)
hostname := resolveHostname(
request.Context(), remoteIP,
)
sessionID, clientID, token, err :=
hdlr.params.Database.LoginUser(
request.Context(),
payload.Nick,
payload.Password,
remoteIP, hostname,
)
if err != nil {
hdlr.respondError(