add IP to sessions, IP+hostname to clients
All checks were successful
check / check (push) Successful in 1m5s
All checks were successful
check / check (push) Successful in 1m5s
- Add ip column to sessions table (real client IP of session creator) - Add ip and hostname columns to clients table (per-connection tracking) - Update CreateSession, RegisterUser, LoginUser to store new fields - Add GetClientHostInfo query method - Update SessionHostInfo to include IP - Extract executeCreateSession to fix funlen lint - Add tests for session IP, client IP/hostname, login client tracking - Update README with new field documentation
This commit is contained in:
@@ -216,6 +216,12 @@ Each session has an IRC-style hostmask composed of three parts:
|
|||||||
`username` field in the session/register request; defaults to the nick)
|
`username` field in the session/register request; defaults to the nick)
|
||||||
- **hostname** — automatically resolved via reverse DNS of the connecting
|
- **hostname** — automatically resolved via reverse DNS of the connecting
|
||||||
client's IP address at session creation time
|
client's IP address at session creation time
|
||||||
|
- **ip** — the real IP address of the session creator, extracted from
|
||||||
|
`X-Forwarded-For`, `X-Real-IP`, or `RemoteAddr`
|
||||||
|
|
||||||
|
Each **client connection** (created at session creation, registration, or login)
|
||||||
|
also stores its own **ip** and **hostname**, allowing the server to track the
|
||||||
|
network origin of each individual client independently from the session.
|
||||||
|
|
||||||
The hostmask appears in:
|
The hostmask appears in:
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ var errNoPassword = errors.New(
|
|||||||
// and returns session ID, client ID, and token.
|
// and returns session ID, client ID, and token.
|
||||||
func (database *Database) RegisterUser(
|
func (database *Database) RegisterUser(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
nick, password, username, hostname string,
|
nick, password, username, hostname, remoteIP string,
|
||||||
) (int64, int64, string, error) {
|
) (int64, int64, string, error) {
|
||||||
if username == "" {
|
if username == "" {
|
||||||
username = nick
|
username = nick
|
||||||
@@ -54,11 +54,11 @@ func (database *Database) RegisterUser(
|
|||||||
|
|
||||||
res, err := transaction.ExecContext(ctx,
|
res, err := transaction.ExecContext(ctx,
|
||||||
`INSERT INTO sessions
|
`INSERT INTO sessions
|
||||||
(uuid, nick, username, hostname,
|
(uuid, nick, username, hostname, ip,
|
||||||
password_hash, created_at, last_seen)
|
password_hash, created_at, last_seen)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||||
sessionUUID, nick, username, hostname,
|
sessionUUID, nick, username, hostname,
|
||||||
string(hash), now, now)
|
remoteIP, string(hash), now, now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = transaction.Rollback()
|
_ = transaction.Rollback()
|
||||||
|
|
||||||
@@ -73,10 +73,11 @@ func (database *Database) RegisterUser(
|
|||||||
|
|
||||||
clientRes, err := transaction.ExecContext(ctx,
|
clientRes, err := transaction.ExecContext(ctx,
|
||||||
`INSERT INTO clients
|
`INSERT INTO clients
|
||||||
(uuid, session_id, token,
|
(uuid, session_id, token, ip, hostname,
|
||||||
created_at, last_seen)
|
created_at, last_seen)
|
||||||
VALUES (?, ?, ?, ?, ?)`,
|
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||||
clientUUID, sessionID, tokenHash, now, now)
|
clientUUID, sessionID, tokenHash,
|
||||||
|
remoteIP, hostname, now, now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = transaction.Rollback()
|
_ = transaction.Rollback()
|
||||||
|
|
||||||
@@ -101,7 +102,7 @@ func (database *Database) RegisterUser(
|
|||||||
// client token.
|
// client token.
|
||||||
func (database *Database) LoginUser(
|
func (database *Database) LoginUser(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
nick, password string,
|
nick, password, remoteIP, hostname string,
|
||||||
) (int64, int64, string, error) {
|
) (int64, int64, string, error) {
|
||||||
var (
|
var (
|
||||||
sessionID int64
|
sessionID int64
|
||||||
@@ -148,10 +149,11 @@ func (database *Database) LoginUser(
|
|||||||
|
|
||||||
res, err := database.conn.ExecContext(ctx,
|
res, err := database.conn.ExecContext(ctx,
|
||||||
`INSERT INTO clients
|
`INSERT INTO clients
|
||||||
(uuid, session_id, token,
|
(uuid, session_id, token, ip, hostname,
|
||||||
created_at, last_seen)
|
created_at, last_seen)
|
||||||
VALUES (?, ?, ?, ?, ?)`,
|
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||||
clientUUID, sessionID, tokenHash, now, now)
|
clientUUID, sessionID, tokenHash,
|
||||||
|
remoteIP, hostname, now, now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, "", fmt.Errorf(
|
return 0, 0, "", fmt.Errorf(
|
||||||
"create login client: %w", err,
|
"create login client: %w", err,
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ func TestRegisterUser(t *testing.T) {
|
|||||||
ctx := t.Context()
|
ctx := t.Context()
|
||||||
|
|
||||||
sessionID, clientID, token, err :=
|
sessionID, clientID, token, err :=
|
||||||
database.RegisterUser(ctx, "reguser", "password123", "", "")
|
database.RegisterUser(ctx, "reguser", "password123", "", "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -46,7 +46,7 @@ func TestRegisterUserWithUserHost(t *testing.T) {
|
|||||||
|
|
||||||
sessionID, _, _, err := database.RegisterUser(
|
sessionID, _, _, err := database.RegisterUser(
|
||||||
ctx, "reguhost", "password123",
|
ctx, "reguhost", "password123",
|
||||||
"myident", "example.org",
|
"myident", "example.org", "",
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -80,7 +80,7 @@ func TestRegisterUserDefaultUsername(t *testing.T) {
|
|||||||
ctx := t.Context()
|
ctx := t.Context()
|
||||||
|
|
||||||
sessionID, _, _, err := database.RegisterUser(
|
sessionID, _, _, err := database.RegisterUser(
|
||||||
ctx, "regdefault", "password123", "", "",
|
ctx, "regdefault", "password123", "", "", "",
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -108,7 +108,7 @@ func TestRegisterUserDuplicateNick(t *testing.T) {
|
|||||||
ctx := t.Context()
|
ctx := t.Context()
|
||||||
|
|
||||||
regSID, regCID, regToken, err :=
|
regSID, regCID, regToken, err :=
|
||||||
database.RegisterUser(ctx, "dupnick", "password123", "", "")
|
database.RegisterUser(ctx, "dupnick", "password123", "", "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -118,7 +118,7 @@ func TestRegisterUserDuplicateNick(t *testing.T) {
|
|||||||
_ = regToken
|
_ = regToken
|
||||||
|
|
||||||
dupSID, dupCID, dupToken, dupErr :=
|
dupSID, dupCID, dupToken, dupErr :=
|
||||||
database.RegisterUser(ctx, "dupnick", "other12345", "", "")
|
database.RegisterUser(ctx, "dupnick", "other12345", "", "", "")
|
||||||
if dupErr == nil {
|
if dupErr == nil {
|
||||||
t.Fatal("expected error for duplicate nick")
|
t.Fatal("expected error for duplicate nick")
|
||||||
}
|
}
|
||||||
@@ -135,7 +135,7 @@ func TestLoginUser(t *testing.T) {
|
|||||||
ctx := t.Context()
|
ctx := t.Context()
|
||||||
|
|
||||||
regSID, regCID, regToken, err :=
|
regSID, regCID, regToken, err :=
|
||||||
database.RegisterUser(ctx, "loginuser", "mypassword", "", "")
|
database.RegisterUser(ctx, "loginuser", "mypassword", "", "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -145,7 +145,7 @@ func TestLoginUser(t *testing.T) {
|
|||||||
_ = regToken
|
_ = regToken
|
||||||
|
|
||||||
sessionID, clientID, token, err :=
|
sessionID, clientID, token, err :=
|
||||||
database.LoginUser(ctx, "loginuser", "mypassword")
|
database.LoginUser(ctx, "loginuser", "mypassword", "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -166,6 +166,83 @@ func TestLoginUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLoginUserStoresClientIPHostname(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
database := setupTestDB(t)
|
||||||
|
ctx := t.Context()
|
||||||
|
|
||||||
|
regSID, regCID, regToken, err := database.RegisterUser(
|
||||||
|
ctx, "loginipuser", "password123",
|
||||||
|
"", "", "10.0.0.1",
|
||||||
|
)
|
||||||
|
|
||||||
|
_ = regSID
|
||||||
|
_ = regCID
|
||||||
|
_ = regToken
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, clientID, _, err := database.LoginUser(
|
||||||
|
ctx, "loginipuser", "password123",
|
||||||
|
"10.0.0.99", "newhost.example.com",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientInfo, err := database.GetClientHostInfo(
|
||||||
|
ctx, clientID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if clientInfo.IP != "10.0.0.99" {
|
||||||
|
t.Fatalf(
|
||||||
|
"expected client IP 10.0.0.99, got %s",
|
||||||
|
clientInfo.IP,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if clientInfo.Hostname != "newhost.example.com" {
|
||||||
|
t.Fatalf(
|
||||||
|
"expected hostname newhost.example.com, got %s",
|
||||||
|
clientInfo.Hostname,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisterUserStoresSessionIP(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
database := setupTestDB(t)
|
||||||
|
ctx := t.Context()
|
||||||
|
|
||||||
|
sessionID, _, _, err := database.RegisterUser(
|
||||||
|
ctx, "regipuser", "password123",
|
||||||
|
"ident", "host.local", "172.16.0.5",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := database.GetSessionHostInfo(
|
||||||
|
ctx, sessionID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.IP != "172.16.0.5" {
|
||||||
|
t.Fatalf(
|
||||||
|
"expected session IP 172.16.0.5, got %s",
|
||||||
|
info.IP,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestLoginUserWrongPassword(t *testing.T) {
|
func TestLoginUserWrongPassword(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -173,7 +250,7 @@ func TestLoginUserWrongPassword(t *testing.T) {
|
|||||||
ctx := t.Context()
|
ctx := t.Context()
|
||||||
|
|
||||||
regSID, regCID, regToken, err :=
|
regSID, regCID, regToken, err :=
|
||||||
database.RegisterUser(ctx, "wrongpw", "correctpass", "", "")
|
database.RegisterUser(ctx, "wrongpw", "correctpass", "", "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -183,7 +260,7 @@ func TestLoginUserWrongPassword(t *testing.T) {
|
|||||||
_ = regToken
|
_ = regToken
|
||||||
|
|
||||||
loginSID, loginCID, loginToken, loginErr :=
|
loginSID, loginCID, loginToken, loginErr :=
|
||||||
database.LoginUser(ctx, "wrongpw", "wrongpass12")
|
database.LoginUser(ctx, "wrongpw", "wrongpass12", "", "")
|
||||||
if loginErr == nil {
|
if loginErr == nil {
|
||||||
t.Fatal("expected error for wrong password")
|
t.Fatal("expected error for wrong password")
|
||||||
}
|
}
|
||||||
@@ -201,7 +278,7 @@ func TestLoginUserNoPassword(t *testing.T) {
|
|||||||
|
|
||||||
// Create anonymous session (no password).
|
// Create anonymous session (no password).
|
||||||
anonSID, anonCID, anonToken, err :=
|
anonSID, anonCID, anonToken, err :=
|
||||||
database.CreateSession(ctx, "anon", "", "")
|
database.CreateSession(ctx, "anon", "", "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -211,7 +288,7 @@ func TestLoginUserNoPassword(t *testing.T) {
|
|||||||
_ = anonToken
|
_ = anonToken
|
||||||
|
|
||||||
loginSID, loginCID, loginToken, loginErr :=
|
loginSID, loginCID, loginToken, loginErr :=
|
||||||
database.LoginUser(ctx, "anon", "anything1")
|
database.LoginUser(ctx, "anon", "anything1", "", "")
|
||||||
if loginErr == nil {
|
if loginErr == nil {
|
||||||
t.Fatal(
|
t.Fatal(
|
||||||
"expected error for login on passwordless account",
|
"expected error for login on passwordless account",
|
||||||
@@ -230,7 +307,7 @@ func TestLoginUserNonexistent(t *testing.T) {
|
|||||||
ctx := t.Context()
|
ctx := t.Context()
|
||||||
|
|
||||||
loginSID, loginCID, loginToken, err :=
|
loginSID, loginCID, loginToken, err :=
|
||||||
database.LoginUser(ctx, "ghost", "password123")
|
database.LoginUser(ctx, "ghost", "password123", "", "")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for nonexistent user")
|
t.Fatal("expected error for nonexistent user")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ func FormatHostmask(nick, username, hostname string) string {
|
|||||||
// CreateSession registers a new session and its first client.
|
// CreateSession registers a new session and its first client.
|
||||||
func (database *Database) CreateSession(
|
func (database *Database) CreateSession(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
nick, username, hostname string,
|
nick, username, hostname, remoteIP string,
|
||||||
) (int64, int64, string, error) {
|
) (int64, int64, string, error) {
|
||||||
if username == "" {
|
if username == "" {
|
||||||
username = nick
|
username = nick
|
||||||
@@ -127,10 +127,11 @@ func (database *Database) CreateSession(
|
|||||||
|
|
||||||
res, err := transaction.ExecContext(ctx,
|
res, err := transaction.ExecContext(ctx,
|
||||||
`INSERT INTO sessions
|
`INSERT INTO sessions
|
||||||
(uuid, nick, username, hostname,
|
(uuid, nick, username, hostname, ip,
|
||||||
created_at, last_seen)
|
created_at, last_seen)
|
||||||
VALUES (?, ?, ?, ?, ?, ?)`,
|
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||||
sessionUUID, nick, username, hostname, now, now)
|
sessionUUID, nick, username, hostname,
|
||||||
|
remoteIP, now, now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = transaction.Rollback()
|
_ = transaction.Rollback()
|
||||||
|
|
||||||
@@ -145,10 +146,11 @@ func (database *Database) CreateSession(
|
|||||||
|
|
||||||
clientRes, err := transaction.ExecContext(ctx,
|
clientRes, err := transaction.ExecContext(ctx,
|
||||||
`INSERT INTO clients
|
`INSERT INTO clients
|
||||||
(uuid, session_id, token,
|
(uuid, session_id, token, ip, hostname,
|
||||||
created_at, last_seen)
|
created_at, last_seen)
|
||||||
VALUES (?, ?, ?, ?, ?)`,
|
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||||
clientUUID, sessionID, tokenHash, now, now)
|
clientUUID, sessionID, tokenHash,
|
||||||
|
remoteIP, hostname, now, now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = transaction.Rollback()
|
_ = transaction.Rollback()
|
||||||
|
|
||||||
@@ -236,14 +238,16 @@ func (database *Database) GetSessionByNick(
|
|||||||
return sessionID, nil
|
return sessionID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionHostInfo holds the username and hostname for a session.
|
// SessionHostInfo holds the username, hostname, and IP
|
||||||
|
// for a session.
|
||||||
type SessionHostInfo struct {
|
type SessionHostInfo struct {
|
||||||
Username string
|
Username string
|
||||||
Hostname string
|
Hostname string
|
||||||
|
IP string
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSessionHostInfo returns the username and hostname
|
// GetSessionHostInfo returns the username, hostname,
|
||||||
// for a session.
|
// and IP for a session.
|
||||||
func (database *Database) GetSessionHostInfo(
|
func (database *Database) GetSessionHostInfo(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
sessionID int64,
|
sessionID int64,
|
||||||
@@ -252,10 +256,10 @@ func (database *Database) GetSessionHostInfo(
|
|||||||
|
|
||||||
err := database.conn.QueryRowContext(
|
err := database.conn.QueryRowContext(
|
||||||
ctx,
|
ctx,
|
||||||
`SELECT username, hostname
|
`SELECT username, hostname, ip
|
||||||
FROM sessions WHERE id = ?`,
|
FROM sessions WHERE id = ?`,
|
||||||
sessionID,
|
sessionID,
|
||||||
).Scan(&info.Username, &info.Hostname)
|
).Scan(&info.Username, &info.Hostname, &info.IP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf(
|
return nil, fmt.Errorf(
|
||||||
"get session host info: %w", err,
|
"get session host info: %w", err,
|
||||||
@@ -265,6 +269,35 @@ func (database *Database) GetSessionHostInfo(
|
|||||||
return &info, nil
|
return &info, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClientHostInfo holds the IP and hostname for a client.
|
||||||
|
type ClientHostInfo struct {
|
||||||
|
IP string
|
||||||
|
Hostname string
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClientHostInfo returns the IP and hostname for a
|
||||||
|
// client.
|
||||||
|
func (database *Database) GetClientHostInfo(
|
||||||
|
ctx context.Context,
|
||||||
|
clientID int64,
|
||||||
|
) (*ClientHostInfo, error) {
|
||||||
|
var info ClientHostInfo
|
||||||
|
|
||||||
|
err := database.conn.QueryRowContext(
|
||||||
|
ctx,
|
||||||
|
`SELECT ip, hostname
|
||||||
|
FROM clients WHERE id = ?`,
|
||||||
|
clientID,
|
||||||
|
).Scan(&info.IP, &info.Hostname)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"get client host info: %w", err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &info, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetChannelByName returns the channel ID for a name.
|
// GetChannelByName returns the channel ID for a name.
|
||||||
func (database *Database) GetChannelByName(
|
func (database *Database) GetChannelByName(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ func TestCreateSession(t *testing.T) {
|
|||||||
ctx := t.Context()
|
ctx := t.Context()
|
||||||
|
|
||||||
sessionID, _, token, err := database.CreateSession(
|
sessionID, _, token, err := database.CreateSession(
|
||||||
ctx, "alice", "", "",
|
ctx, "alice", "", "", "",
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -45,7 +45,7 @@ func TestCreateSession(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, _, dupToken, dupErr := database.CreateSession(
|
_, _, dupToken, dupErr := database.CreateSession(
|
||||||
ctx, "alice", "", "",
|
ctx, "alice", "", "", "",
|
||||||
)
|
)
|
||||||
if dupErr == nil {
|
if dupErr == nil {
|
||||||
t.Fatal("expected error for duplicate nick")
|
t.Fatal("expected error for duplicate nick")
|
||||||
@@ -65,7 +65,7 @@ func assertSessionHostInfo(
|
|||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
sessionID, _, _, err := database.CreateSession(
|
sessionID, _, _, err := database.CreateSession(
|
||||||
t.Context(), nick, inputUser, inputHost,
|
t.Context(), nick, inputUser, inputHost, "",
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -118,6 +118,69 @@ func TestCreateSessionDefaultUsername(t *testing.T) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateSessionStoresIP(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
database := setupTestDB(t)
|
||||||
|
ctx := t.Context()
|
||||||
|
|
||||||
|
sessionID, clientID, _, err := database.CreateSession(
|
||||||
|
ctx, "ipuser", "ident", "host.example.com",
|
||||||
|
"192.168.1.42",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := database.GetSessionHostInfo(
|
||||||
|
ctx, sessionID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.IP != "192.168.1.42" {
|
||||||
|
t.Fatalf(
|
||||||
|
"expected session IP 192.168.1.42, got %s",
|
||||||
|
info.IP,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientInfo, err := database.GetClientHostInfo(
|
||||||
|
ctx, clientID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if clientInfo.IP != "192.168.1.42" {
|
||||||
|
t.Fatalf(
|
||||||
|
"expected client IP 192.168.1.42, got %s",
|
||||||
|
clientInfo.IP,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if clientInfo.Hostname != "host.example.com" {
|
||||||
|
t.Fatalf(
|
||||||
|
"expected client hostname host.example.com, got %s",
|
||||||
|
clientInfo.Hostname,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetClientHostInfoNotFound(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
database := setupTestDB(t)
|
||||||
|
|
||||||
|
_, err := database.GetClientHostInfo(
|
||||||
|
t.Context(), 99999,
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for nonexistent client")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetSessionHostInfoNotFound(t *testing.T) {
|
func TestGetSessionHostInfoNotFound(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -183,7 +246,7 @@ func TestChannelMembersIncludeUserHost(t *testing.T) {
|
|||||||
ctx := t.Context()
|
ctx := t.Context()
|
||||||
|
|
||||||
sid, _, _, err := database.CreateSession(
|
sid, _, _, err := database.CreateSession(
|
||||||
ctx, "memuser", "myuser", "myhost.net",
|
ctx, "memuser", "myuser", "myhost.net", "",
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -233,7 +296,7 @@ func TestGetSessionByToken(t *testing.T) {
|
|||||||
database := setupTestDB(t)
|
database := setupTestDB(t)
|
||||||
ctx := t.Context()
|
ctx := t.Context()
|
||||||
|
|
||||||
_, _, token, err := database.CreateSession(ctx, "bob", "", "")
|
_, _, token, err := database.CreateSession(ctx, "bob", "", "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -266,7 +329,7 @@ func TestGetSessionByNick(t *testing.T) {
|
|||||||
ctx := t.Context()
|
ctx := t.Context()
|
||||||
|
|
||||||
charlieID, charlieClientID, charlieToken, err :=
|
charlieID, charlieClientID, charlieToken, err :=
|
||||||
database.CreateSession(ctx, "charlie", "", "")
|
database.CreateSession(ctx, "charlie", "", "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -323,7 +386,7 @@ func TestJoinAndPart(t *testing.T) {
|
|||||||
database := setupTestDB(t)
|
database := setupTestDB(t)
|
||||||
ctx := t.Context()
|
ctx := t.Context()
|
||||||
|
|
||||||
sid, _, _, err := database.CreateSession(ctx, "user1", "", "")
|
sid, _, _, err := database.CreateSession(ctx, "user1", "", "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -372,7 +435,7 @@ func TestDeleteChannelIfEmpty(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
sid, _, _, err := database.CreateSession(ctx, "temp", "", "")
|
sid, _, _, err := database.CreateSession(ctx, "temp", "", "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -407,7 +470,7 @@ func createSessionWithChannels(
|
|||||||
|
|
||||||
ctx := t.Context()
|
ctx := t.Context()
|
||||||
|
|
||||||
sid, _, _, err := database.CreateSession(ctx, nick, "", "")
|
sid, _, _, err := database.CreateSession(ctx, nick, "", "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -490,7 +553,7 @@ func TestChangeNick(t *testing.T) {
|
|||||||
ctx := t.Context()
|
ctx := t.Context()
|
||||||
|
|
||||||
sid, _, token, err := database.CreateSession(
|
sid, _, token, err := database.CreateSession(
|
||||||
ctx, "old", "", "",
|
ctx, "old", "", "", "",
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -574,7 +637,7 @@ func TestPollMessages(t *testing.T) {
|
|||||||
ctx := t.Context()
|
ctx := t.Context()
|
||||||
|
|
||||||
sid, _, token, err := database.CreateSession(
|
sid, _, token, err := database.CreateSession(
|
||||||
ctx, "poller", "", "",
|
ctx, "poller", "", "", "",
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -681,7 +744,7 @@ func TestDeleteSession(t *testing.T) {
|
|||||||
ctx := t.Context()
|
ctx := t.Context()
|
||||||
|
|
||||||
sid, _, _, err := database.CreateSession(
|
sid, _, _, err := database.CreateSession(
|
||||||
ctx, "deleteme", "", "",
|
ctx, "deleteme", "", "", "",
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -721,12 +784,12 @@ func TestChannelMembers(t *testing.T) {
|
|||||||
database := setupTestDB(t)
|
database := setupTestDB(t)
|
||||||
ctx := t.Context()
|
ctx := t.Context()
|
||||||
|
|
||||||
sid1, _, _, err := database.CreateSession(ctx, "m1", "", "")
|
sid1, _, _, err := database.CreateSession(ctx, "m1", "", "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
sid2, _, _, err := database.CreateSession(ctx, "m2", "", "")
|
sid2, _, _, err := database.CreateSession(ctx, "m2", "", "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -784,7 +847,7 @@ func TestEnqueueToClient(t *testing.T) {
|
|||||||
ctx := t.Context()
|
ctx := t.Context()
|
||||||
|
|
||||||
_, _, token, err := database.CreateSession(
|
_, _, token, err := database.CreateSession(
|
||||||
ctx, "enqclient", "", "",
|
ctx, "enqclient", "", "", "",
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ CREATE TABLE IF NOT EXISTS sessions (
|
|||||||
nick TEXT NOT NULL UNIQUE,
|
nick TEXT NOT NULL UNIQUE,
|
||||||
username TEXT NOT NULL DEFAULT '',
|
username TEXT NOT NULL DEFAULT '',
|
||||||
hostname TEXT NOT NULL DEFAULT '',
|
hostname TEXT NOT NULL DEFAULT '',
|
||||||
|
ip TEXT NOT NULL DEFAULT '',
|
||||||
password_hash TEXT NOT NULL DEFAULT '',
|
password_hash TEXT NOT NULL DEFAULT '',
|
||||||
signing_key TEXT NOT NULL DEFAULT '',
|
signing_key TEXT NOT NULL DEFAULT '',
|
||||||
away_message TEXT NOT NULL DEFAULT '',
|
away_message TEXT NOT NULL DEFAULT '',
|
||||||
@@ -22,6 +23,8 @@ CREATE TABLE IF NOT EXISTS clients (
|
|||||||
uuid TEXT NOT NULL UNIQUE,
|
uuid TEXT NOT NULL UNIQUE,
|
||||||
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
|
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
|
||||||
token TEXT NOT NULL UNIQUE,
|
token TEXT NOT NULL UNIQUE,
|
||||||
|
ip TEXT NOT NULL DEFAULT '',
|
||||||
|
hostname TEXT NOT NULL DEFAULT '',
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
last_seen DATETIME DEFAULT CURRENT_TIMESTAMP
|
last_seen DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -251,14 +251,26 @@ func (hdlr *Handlers) handleCreateSession(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
hdlr.executeCreateSession(
|
||||||
|
writer, request, payload.Nick, username,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hdlr *Handlers) executeCreateSession(
|
||||||
|
writer http.ResponseWriter,
|
||||||
|
request *http.Request,
|
||||||
|
nick, username string,
|
||||||
|
) {
|
||||||
|
remoteIP := clientIP(request)
|
||||||
|
|
||||||
hostname := resolveHostname(
|
hostname := resolveHostname(
|
||||||
request.Context(), clientIP(request),
|
request.Context(), remoteIP,
|
||||||
)
|
)
|
||||||
|
|
||||||
sessionID, clientID, token, err :=
|
sessionID, clientID, token, err :=
|
||||||
hdlr.params.Database.CreateSession(
|
hdlr.params.Database.CreateSession(
|
||||||
request.Context(),
|
request.Context(),
|
||||||
payload.Nick, username, hostname,
|
nick, username, hostname, remoteIP,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hdlr.handleCreateSessionError(
|
hdlr.handleCreateSessionError(
|
||||||
@@ -271,11 +283,11 @@ func (hdlr *Handlers) handleCreateSession(
|
|||||||
hdlr.stats.IncrSessions()
|
hdlr.stats.IncrSessions()
|
||||||
hdlr.stats.IncrConnections()
|
hdlr.stats.IncrConnections()
|
||||||
|
|
||||||
hdlr.deliverMOTD(request, clientID, sessionID, payload.Nick)
|
hdlr.deliverMOTD(request, clientID, sessionID, nick)
|
||||||
|
|
||||||
hdlr.respondJSON(writer, request, map[string]any{
|
hdlr.respondJSON(writer, request, map[string]any{
|
||||||
"id": sessionID,
|
"id": sessionID,
|
||||||
"nick": payload.Nick,
|
"nick": nick,
|
||||||
"token": token,
|
"token": token,
|
||||||
}, http.StatusCreated)
|
}, http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -94,14 +94,16 @@ func (hdlr *Handlers) executeRegister(
|
|||||||
request *http.Request,
|
request *http.Request,
|
||||||
nick, password, username string,
|
nick, password, username string,
|
||||||
) {
|
) {
|
||||||
|
remoteIP := clientIP(request)
|
||||||
|
|
||||||
hostname := resolveHostname(
|
hostname := resolveHostname(
|
||||||
request.Context(), clientIP(request),
|
request.Context(), remoteIP,
|
||||||
)
|
)
|
||||||
|
|
||||||
sessionID, clientID, token, err :=
|
sessionID, clientID, token, err :=
|
||||||
hdlr.params.Database.RegisterUser(
|
hdlr.params.Database.RegisterUser(
|
||||||
request.Context(),
|
request.Context(),
|
||||||
nick, password, username, hostname,
|
nick, password, username, hostname, remoteIP,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hdlr.handleRegisterError(
|
hdlr.handleRegisterError(
|
||||||
@@ -196,11 +198,18 @@ func (hdlr *Handlers) handleLogin(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
remoteIP := clientIP(request)
|
||||||
|
|
||||||
|
hostname := resolveHostname(
|
||||||
|
request.Context(), remoteIP,
|
||||||
|
)
|
||||||
|
|
||||||
sessionID, clientID, token, err :=
|
sessionID, clientID, token, err :=
|
||||||
hdlr.params.Database.LoginUser(
|
hdlr.params.Database.LoginUser(
|
||||||
request.Context(),
|
request.Context(),
|
||||||
payload.Nick,
|
payload.Nick,
|
||||||
payload.Password,
|
payload.Password,
|
||||||
|
remoteIP, hostname,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hdlr.respondError(
|
hdlr.respondError(
|
||||||
|
|||||||
Reference in New Issue
Block a user