add IP to sessions, IP+hostname to clients
All checks were successful
check / check (push) Successful in 1m5s

- Add ip column to sessions table (real client IP of session creator)
- Add ip and hostname columns to clients table (per-connection tracking)
- Update CreateSession, RegisterUser, LoginUser to store new fields
- Add GetClientHostInfo query method
- Update SessionHostInfo to include IP
- Extract executeCreateSession to fix funlen lint
- Add tests for session IP, client IP/hostname, login client tracking
- Update README with new field documentation
This commit is contained in:
user
2026-03-17 08:52:50 -07:00
parent e42c6c1868
commit 953771f2aa
8 changed files with 261 additions and 56 deletions

View File

@@ -216,6 +216,12 @@ Each session has an IRC-style hostmask composed of three parts:
`username` field in the session/register request; defaults to the nick) `username` field in the session/register request; defaults to the nick)
- **hostname** — automatically resolved via reverse DNS of the connecting - **hostname** — automatically resolved via reverse DNS of the connecting
client's IP address at session creation time client's IP address at session creation time
- **ip** — the real IP address of the session creator, extracted from
`X-Forwarded-For`, `X-Real-IP`, or `RemoteAddr`
Each **client connection** (created at session creation, registration, or login)
also stores its own **ip** and **hostname**, allowing the server to track the
network origin of each individual client independently from the session.
The hostmask appears in: The hostmask appears in:

View File

@@ -20,7 +20,7 @@ var errNoPassword = errors.New(
// and returns session ID, client ID, and token. // and returns session ID, client ID, and token.
func (database *Database) RegisterUser( func (database *Database) RegisterUser(
ctx context.Context, ctx context.Context,
nick, password, username, hostname string, nick, password, username, hostname, remoteIP string,
) (int64, int64, string, error) { ) (int64, int64, string, error) {
if username == "" { if username == "" {
username = nick username = nick
@@ -54,11 +54,11 @@ func (database *Database) RegisterUser(
res, err := transaction.ExecContext(ctx, res, err := transaction.ExecContext(ctx,
`INSERT INTO sessions `INSERT INTO sessions
(uuid, nick, username, hostname, (uuid, nick, username, hostname, ip,
password_hash, created_at, last_seen) password_hash, created_at, last_seen)
VALUES (?, ?, ?, ?, ?, ?, ?)`, VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
sessionUUID, nick, username, hostname, sessionUUID, nick, username, hostname,
string(hash), now, now) remoteIP, string(hash), now, now)
if err != nil { if err != nil {
_ = transaction.Rollback() _ = transaction.Rollback()
@@ -73,10 +73,11 @@ func (database *Database) RegisterUser(
clientRes, err := transaction.ExecContext(ctx, clientRes, err := transaction.ExecContext(ctx,
`INSERT INTO clients `INSERT INTO clients
(uuid, session_id, token, (uuid, session_id, token, ip, hostname,
created_at, last_seen) created_at, last_seen)
VALUES (?, ?, ?, ?, ?)`, VALUES (?, ?, ?, ?, ?, ?, ?)`,
clientUUID, sessionID, tokenHash, now, now) clientUUID, sessionID, tokenHash,
remoteIP, hostname, now, now)
if err != nil { if err != nil {
_ = transaction.Rollback() _ = transaction.Rollback()
@@ -101,7 +102,7 @@ func (database *Database) RegisterUser(
// client token. // client token.
func (database *Database) LoginUser( func (database *Database) LoginUser(
ctx context.Context, ctx context.Context,
nick, password string, nick, password, remoteIP, hostname string,
) (int64, int64, string, error) { ) (int64, int64, string, error) {
var ( var (
sessionID int64 sessionID int64
@@ -148,10 +149,11 @@ func (database *Database) LoginUser(
res, err := database.conn.ExecContext(ctx, res, err := database.conn.ExecContext(ctx,
`INSERT INTO clients `INSERT INTO clients
(uuid, session_id, token, (uuid, session_id, token, ip, hostname,
created_at, last_seen) created_at, last_seen)
VALUES (?, ?, ?, ?, ?)`, VALUES (?, ?, ?, ?, ?, ?, ?)`,
clientUUID, sessionID, tokenHash, now, now) clientUUID, sessionID, tokenHash,
remoteIP, hostname, now, now)
if err != nil { if err != nil {
return 0, 0, "", fmt.Errorf( return 0, 0, "", fmt.Errorf(
"create login client: %w", err, "create login client: %w", err,

View File

@@ -13,7 +13,7 @@ func TestRegisterUser(t *testing.T) {
ctx := t.Context() ctx := t.Context()
sessionID, clientID, token, err := sessionID, clientID, token, err :=
database.RegisterUser(ctx, "reguser", "password123", "", "") database.RegisterUser(ctx, "reguser", "password123", "", "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -46,7 +46,7 @@ func TestRegisterUserWithUserHost(t *testing.T) {
sessionID, _, _, err := database.RegisterUser( sessionID, _, _, err := database.RegisterUser(
ctx, "reguhost", "password123", ctx, "reguhost", "password123",
"myident", "example.org", "myident", "example.org", "",
) )
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -80,7 +80,7 @@ func TestRegisterUserDefaultUsername(t *testing.T) {
ctx := t.Context() ctx := t.Context()
sessionID, _, _, err := database.RegisterUser( sessionID, _, _, err := database.RegisterUser(
ctx, "regdefault", "password123", "", "", ctx, "regdefault", "password123", "", "", "",
) )
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -108,7 +108,7 @@ func TestRegisterUserDuplicateNick(t *testing.T) {
ctx := t.Context() ctx := t.Context()
regSID, regCID, regToken, err := regSID, regCID, regToken, err :=
database.RegisterUser(ctx, "dupnick", "password123", "", "") database.RegisterUser(ctx, "dupnick", "password123", "", "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -118,7 +118,7 @@ func TestRegisterUserDuplicateNick(t *testing.T) {
_ = regToken _ = regToken
dupSID, dupCID, dupToken, dupErr := dupSID, dupCID, dupToken, dupErr :=
database.RegisterUser(ctx, "dupnick", "other12345", "", "") database.RegisterUser(ctx, "dupnick", "other12345", "", "", "")
if dupErr == nil { if dupErr == nil {
t.Fatal("expected error for duplicate nick") t.Fatal("expected error for duplicate nick")
} }
@@ -135,7 +135,7 @@ func TestLoginUser(t *testing.T) {
ctx := t.Context() ctx := t.Context()
regSID, regCID, regToken, err := regSID, regCID, regToken, err :=
database.RegisterUser(ctx, "loginuser", "mypassword", "", "") database.RegisterUser(ctx, "loginuser", "mypassword", "", "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -145,7 +145,7 @@ func TestLoginUser(t *testing.T) {
_ = regToken _ = regToken
sessionID, clientID, token, err := sessionID, clientID, token, err :=
database.LoginUser(ctx, "loginuser", "mypassword") database.LoginUser(ctx, "loginuser", "mypassword", "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -166,6 +166,83 @@ func TestLoginUser(t *testing.T) {
} }
} }
func TestLoginUserStoresClientIPHostname(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
regSID, regCID, regToken, err := database.RegisterUser(
ctx, "loginipuser", "password123",
"", "", "10.0.0.1",
)
_ = regSID
_ = regCID
_ = regToken
if err != nil {
t.Fatal(err)
}
_, clientID, _, err := database.LoginUser(
ctx, "loginipuser", "password123",
"10.0.0.99", "newhost.example.com",
)
if err != nil {
t.Fatal(err)
}
clientInfo, err := database.GetClientHostInfo(
ctx, clientID,
)
if err != nil {
t.Fatal(err)
}
if clientInfo.IP != "10.0.0.99" {
t.Fatalf(
"expected client IP 10.0.0.99, got %s",
clientInfo.IP,
)
}
if clientInfo.Hostname != "newhost.example.com" {
t.Fatalf(
"expected hostname newhost.example.com, got %s",
clientInfo.Hostname,
)
}
}
func TestRegisterUserStoresSessionIP(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sessionID, _, _, err := database.RegisterUser(
ctx, "regipuser", "password123",
"ident", "host.local", "172.16.0.5",
)
if err != nil {
t.Fatal(err)
}
info, err := database.GetSessionHostInfo(
ctx, sessionID,
)
if err != nil {
t.Fatal(err)
}
if info.IP != "172.16.0.5" {
t.Fatalf(
"expected session IP 172.16.0.5, got %s",
info.IP,
)
}
}
func TestLoginUserWrongPassword(t *testing.T) { func TestLoginUserWrongPassword(t *testing.T) {
t.Parallel() t.Parallel()
@@ -173,7 +250,7 @@ func TestLoginUserWrongPassword(t *testing.T) {
ctx := t.Context() ctx := t.Context()
regSID, regCID, regToken, err := regSID, regCID, regToken, err :=
database.RegisterUser(ctx, "wrongpw", "correctpass", "", "") database.RegisterUser(ctx, "wrongpw", "correctpass", "", "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -183,7 +260,7 @@ func TestLoginUserWrongPassword(t *testing.T) {
_ = regToken _ = regToken
loginSID, loginCID, loginToken, loginErr := loginSID, loginCID, loginToken, loginErr :=
database.LoginUser(ctx, "wrongpw", "wrongpass12") database.LoginUser(ctx, "wrongpw", "wrongpass12", "", "")
if loginErr == nil { if loginErr == nil {
t.Fatal("expected error for wrong password") t.Fatal("expected error for wrong password")
} }
@@ -201,7 +278,7 @@ func TestLoginUserNoPassword(t *testing.T) {
// Create anonymous session (no password). // Create anonymous session (no password).
anonSID, anonCID, anonToken, err := anonSID, anonCID, anonToken, err :=
database.CreateSession(ctx, "anon", "", "") database.CreateSession(ctx, "anon", "", "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -211,7 +288,7 @@ func TestLoginUserNoPassword(t *testing.T) {
_ = anonToken _ = anonToken
loginSID, loginCID, loginToken, loginErr := loginSID, loginCID, loginToken, loginErr :=
database.LoginUser(ctx, "anon", "anything1") database.LoginUser(ctx, "anon", "anything1", "", "")
if loginErr == nil { if loginErr == nil {
t.Fatal( t.Fatal(
"expected error for login on passwordless account", "expected error for login on passwordless account",
@@ -230,7 +307,7 @@ func TestLoginUserNonexistent(t *testing.T) {
ctx := t.Context() ctx := t.Context()
loginSID, loginCID, loginToken, err := loginSID, loginCID, loginToken, err :=
database.LoginUser(ctx, "ghost", "password123") database.LoginUser(ctx, "ghost", "password123", "", "")
if err == nil { if err == nil {
t.Fatal("expected error for nonexistent user") t.Fatal("expected error for nonexistent user")
} }

View File

@@ -102,7 +102,7 @@ func FormatHostmask(nick, username, hostname string) string {
// CreateSession registers a new session and its first client. // CreateSession registers a new session and its first client.
func (database *Database) CreateSession( func (database *Database) CreateSession(
ctx context.Context, ctx context.Context,
nick, username, hostname string, nick, username, hostname, remoteIP string,
) (int64, int64, string, error) { ) (int64, int64, string, error) {
if username == "" { if username == "" {
username = nick username = nick
@@ -127,10 +127,11 @@ func (database *Database) CreateSession(
res, err := transaction.ExecContext(ctx, res, err := transaction.ExecContext(ctx,
`INSERT INTO sessions `INSERT INTO sessions
(uuid, nick, username, hostname, (uuid, nick, username, hostname, ip,
created_at, last_seen) created_at, last_seen)
VALUES (?, ?, ?, ?, ?, ?)`, VALUES (?, ?, ?, ?, ?, ?, ?)`,
sessionUUID, nick, username, hostname, now, now) sessionUUID, nick, username, hostname,
remoteIP, now, now)
if err != nil { if err != nil {
_ = transaction.Rollback() _ = transaction.Rollback()
@@ -145,10 +146,11 @@ func (database *Database) CreateSession(
clientRes, err := transaction.ExecContext(ctx, clientRes, err := transaction.ExecContext(ctx,
`INSERT INTO clients `INSERT INTO clients
(uuid, session_id, token, (uuid, session_id, token, ip, hostname,
created_at, last_seen) created_at, last_seen)
VALUES (?, ?, ?, ?, ?)`, VALUES (?, ?, ?, ?, ?, ?, ?)`,
clientUUID, sessionID, tokenHash, now, now) clientUUID, sessionID, tokenHash,
remoteIP, hostname, now, now)
if err != nil { if err != nil {
_ = transaction.Rollback() _ = transaction.Rollback()
@@ -236,14 +238,16 @@ func (database *Database) GetSessionByNick(
return sessionID, nil return sessionID, nil
} }
// SessionHostInfo holds the username and hostname for a session. // SessionHostInfo holds the username, hostname, and IP
// for a session.
type SessionHostInfo struct { type SessionHostInfo struct {
Username string Username string
Hostname string Hostname string
IP string
} }
// GetSessionHostInfo returns the username and hostname // GetSessionHostInfo returns the username, hostname,
// for a session. // and IP for a session.
func (database *Database) GetSessionHostInfo( func (database *Database) GetSessionHostInfo(
ctx context.Context, ctx context.Context,
sessionID int64, sessionID int64,
@@ -252,10 +256,10 @@ func (database *Database) GetSessionHostInfo(
err := database.conn.QueryRowContext( err := database.conn.QueryRowContext(
ctx, ctx,
`SELECT username, hostname `SELECT username, hostname, ip
FROM sessions WHERE id = ?`, FROM sessions WHERE id = ?`,
sessionID, sessionID,
).Scan(&info.Username, &info.Hostname) ).Scan(&info.Username, &info.Hostname, &info.IP)
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"get session host info: %w", err, "get session host info: %w", err,
@@ -265,6 +269,35 @@ func (database *Database) GetSessionHostInfo(
return &info, nil return &info, nil
} }
// ClientHostInfo holds the IP and hostname for a client.
type ClientHostInfo struct {
IP string
Hostname string
}
// GetClientHostInfo returns the IP and hostname for a
// client.
func (database *Database) GetClientHostInfo(
ctx context.Context,
clientID int64,
) (*ClientHostInfo, error) {
var info ClientHostInfo
err := database.conn.QueryRowContext(
ctx,
`SELECT ip, hostname
FROM clients WHERE id = ?`,
clientID,
).Scan(&info.IP, &info.Hostname)
if err != nil {
return nil, fmt.Errorf(
"get client host info: %w", err,
)
}
return &info, nil
}
// GetChannelByName returns the channel ID for a name. // GetChannelByName returns the channel ID for a name.
func (database *Database) GetChannelByName( func (database *Database) GetChannelByName(
ctx context.Context, ctx context.Context,

View File

@@ -34,7 +34,7 @@ func TestCreateSession(t *testing.T) {
ctx := t.Context() ctx := t.Context()
sessionID, _, token, err := database.CreateSession( sessionID, _, token, err := database.CreateSession(
ctx, "alice", "", "", ctx, "alice", "", "", "",
) )
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -45,7 +45,7 @@ func TestCreateSession(t *testing.T) {
} }
_, _, dupToken, dupErr := database.CreateSession( _, _, dupToken, dupErr := database.CreateSession(
ctx, "alice", "", "", ctx, "alice", "", "", "",
) )
if dupErr == nil { if dupErr == nil {
t.Fatal("expected error for duplicate nick") t.Fatal("expected error for duplicate nick")
@@ -65,7 +65,7 @@ func assertSessionHostInfo(
t.Helper() t.Helper()
sessionID, _, _, err := database.CreateSession( sessionID, _, _, err := database.CreateSession(
t.Context(), nick, inputUser, inputHost, t.Context(), nick, inputUser, inputHost, "",
) )
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -118,6 +118,69 @@ func TestCreateSessionDefaultUsername(t *testing.T) {
) )
} }
func TestCreateSessionStoresIP(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
ctx := t.Context()
sessionID, clientID, _, err := database.CreateSession(
ctx, "ipuser", "ident", "host.example.com",
"192.168.1.42",
)
if err != nil {
t.Fatal(err)
}
info, err := database.GetSessionHostInfo(
ctx, sessionID,
)
if err != nil {
t.Fatal(err)
}
if info.IP != "192.168.1.42" {
t.Fatalf(
"expected session IP 192.168.1.42, got %s",
info.IP,
)
}
clientInfo, err := database.GetClientHostInfo(
ctx, clientID,
)
if err != nil {
t.Fatal(err)
}
if clientInfo.IP != "192.168.1.42" {
t.Fatalf(
"expected client IP 192.168.1.42, got %s",
clientInfo.IP,
)
}
if clientInfo.Hostname != "host.example.com" {
t.Fatalf(
"expected client hostname host.example.com, got %s",
clientInfo.Hostname,
)
}
}
func TestGetClientHostInfoNotFound(t *testing.T) {
t.Parallel()
database := setupTestDB(t)
_, err := database.GetClientHostInfo(
t.Context(), 99999,
)
if err == nil {
t.Fatal("expected error for nonexistent client")
}
}
func TestGetSessionHostInfoNotFound(t *testing.T) { func TestGetSessionHostInfoNotFound(t *testing.T) {
t.Parallel() t.Parallel()
@@ -183,7 +246,7 @@ func TestChannelMembersIncludeUserHost(t *testing.T) {
ctx := t.Context() ctx := t.Context()
sid, _, _, err := database.CreateSession( sid, _, _, err := database.CreateSession(
ctx, "memuser", "myuser", "myhost.net", ctx, "memuser", "myuser", "myhost.net", "",
) )
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -233,7 +296,7 @@ func TestGetSessionByToken(t *testing.T) {
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
_, _, token, err := database.CreateSession(ctx, "bob", "", "") _, _, token, err := database.CreateSession(ctx, "bob", "", "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -266,7 +329,7 @@ func TestGetSessionByNick(t *testing.T) {
ctx := t.Context() ctx := t.Context()
charlieID, charlieClientID, charlieToken, err := charlieID, charlieClientID, charlieToken, err :=
database.CreateSession(ctx, "charlie", "", "") database.CreateSession(ctx, "charlie", "", "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -323,7 +386,7 @@ func TestJoinAndPart(t *testing.T) {
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
sid, _, _, err := database.CreateSession(ctx, "user1", "", "") sid, _, _, err := database.CreateSession(ctx, "user1", "", "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -372,7 +435,7 @@ func TestDeleteChannelIfEmpty(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
sid, _, _, err := database.CreateSession(ctx, "temp", "", "") sid, _, _, err := database.CreateSession(ctx, "temp", "", "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -407,7 +470,7 @@ func createSessionWithChannels(
ctx := t.Context() ctx := t.Context()
sid, _, _, err := database.CreateSession(ctx, nick, "", "") sid, _, _, err := database.CreateSession(ctx, nick, "", "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -490,7 +553,7 @@ func TestChangeNick(t *testing.T) {
ctx := t.Context() ctx := t.Context()
sid, _, token, err := database.CreateSession( sid, _, token, err := database.CreateSession(
ctx, "old", "", "", ctx, "old", "", "", "",
) )
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -574,7 +637,7 @@ func TestPollMessages(t *testing.T) {
ctx := t.Context() ctx := t.Context()
sid, _, token, err := database.CreateSession( sid, _, token, err := database.CreateSession(
ctx, "poller", "", "", ctx, "poller", "", "", "",
) )
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -681,7 +744,7 @@ func TestDeleteSession(t *testing.T) {
ctx := t.Context() ctx := t.Context()
sid, _, _, err := database.CreateSession( sid, _, _, err := database.CreateSession(
ctx, "deleteme", "", "", ctx, "deleteme", "", "", "",
) )
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -721,12 +784,12 @@ func TestChannelMembers(t *testing.T) {
database := setupTestDB(t) database := setupTestDB(t)
ctx := t.Context() ctx := t.Context()
sid1, _, _, err := database.CreateSession(ctx, "m1", "", "") sid1, _, _, err := database.CreateSession(ctx, "m1", "", "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
sid2, _, _, err := database.CreateSession(ctx, "m2", "", "") sid2, _, _, err := database.CreateSession(ctx, "m2", "", "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -784,7 +847,7 @@ func TestEnqueueToClient(t *testing.T) {
ctx := t.Context() ctx := t.Context()
_, _, token, err := database.CreateSession( _, _, token, err := database.CreateSession(
ctx, "enqclient", "", "", ctx, "enqclient", "", "", "",
) )
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@@ -8,6 +8,7 @@ CREATE TABLE IF NOT EXISTS sessions (
nick TEXT NOT NULL UNIQUE, nick TEXT NOT NULL UNIQUE,
username TEXT NOT NULL DEFAULT '', username TEXT NOT NULL DEFAULT '',
hostname TEXT NOT NULL DEFAULT '', hostname TEXT NOT NULL DEFAULT '',
ip TEXT NOT NULL DEFAULT '',
password_hash TEXT NOT NULL DEFAULT '', password_hash TEXT NOT NULL DEFAULT '',
signing_key TEXT NOT NULL DEFAULT '', signing_key TEXT NOT NULL DEFAULT '',
away_message TEXT NOT NULL DEFAULT '', away_message TEXT NOT NULL DEFAULT '',
@@ -22,6 +23,8 @@ CREATE TABLE IF NOT EXISTS clients (
uuid TEXT NOT NULL UNIQUE, uuid TEXT NOT NULL UNIQUE,
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE, session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
token TEXT NOT NULL UNIQUE, token TEXT NOT NULL UNIQUE,
ip TEXT NOT NULL DEFAULT '',
hostname TEXT NOT NULL DEFAULT '',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP, created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
last_seen DATETIME DEFAULT CURRENT_TIMESTAMP last_seen DATETIME DEFAULT CURRENT_TIMESTAMP
); );

View File

@@ -251,14 +251,26 @@ func (hdlr *Handlers) handleCreateSession(
return return
} }
hdlr.executeCreateSession(
writer, request, payload.Nick, username,
)
}
func (hdlr *Handlers) executeCreateSession(
writer http.ResponseWriter,
request *http.Request,
nick, username string,
) {
remoteIP := clientIP(request)
hostname := resolveHostname( hostname := resolveHostname(
request.Context(), clientIP(request), request.Context(), remoteIP,
) )
sessionID, clientID, token, err := sessionID, clientID, token, err :=
hdlr.params.Database.CreateSession( hdlr.params.Database.CreateSession(
request.Context(), request.Context(),
payload.Nick, username, hostname, nick, username, hostname, remoteIP,
) )
if err != nil { if err != nil {
hdlr.handleCreateSessionError( hdlr.handleCreateSessionError(
@@ -271,11 +283,11 @@ func (hdlr *Handlers) handleCreateSession(
hdlr.stats.IncrSessions() hdlr.stats.IncrSessions()
hdlr.stats.IncrConnections() hdlr.stats.IncrConnections()
hdlr.deliverMOTD(request, clientID, sessionID, payload.Nick) hdlr.deliverMOTD(request, clientID, sessionID, nick)
hdlr.respondJSON(writer, request, map[string]any{ hdlr.respondJSON(writer, request, map[string]any{
"id": sessionID, "id": sessionID,
"nick": payload.Nick, "nick": nick,
"token": token, "token": token,
}, http.StatusCreated) }, http.StatusCreated)
} }

View File

@@ -94,14 +94,16 @@ func (hdlr *Handlers) executeRegister(
request *http.Request, request *http.Request,
nick, password, username string, nick, password, username string,
) { ) {
remoteIP := clientIP(request)
hostname := resolveHostname( hostname := resolveHostname(
request.Context(), clientIP(request), request.Context(), remoteIP,
) )
sessionID, clientID, token, err := sessionID, clientID, token, err :=
hdlr.params.Database.RegisterUser( hdlr.params.Database.RegisterUser(
request.Context(), request.Context(),
nick, password, username, hostname, nick, password, username, hostname, remoteIP,
) )
if err != nil { if err != nil {
hdlr.handleRegisterError( hdlr.handleRegisterError(
@@ -196,11 +198,18 @@ func (hdlr *Handlers) handleLogin(
return return
} }
remoteIP := clientIP(request)
hostname := resolveHostname(
request.Context(), remoteIP,
)
sessionID, clientID, token, err := sessionID, clientID, token, err :=
hdlr.params.Database.LoginUser( hdlr.params.Database.LoginUser(
request.Context(), request.Context(),
payload.Nick, payload.Nick,
payload.Password, payload.Password,
remoteIP, hostname,
) )
if err != nil { if err != nil {
hdlr.respondError( hdlr.respondError(