Validate table name against allowlist in getTableCount (closes #27) #32
@ -1126,12 +1126,26 @@ func (v *Vaultik) PruneDatabase() (*PruneResult, error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// getTableCount returns the count of rows in a table
|
||||
// validTableNames is the allowlist of table names that can be counted.
|
||||
var validTableNames = map[string]bool{
|
||||
"files": true,
|
||||
"chunks": true,
|
||||
"blobs": true,
|
||||
"uploads": true,
|
||||
"snapshots": true,
|
||||
}
|
||||
|
||||
// getTableCount returns the count of rows in a table.
|
||||
// The tableName must be in the validTableNames allowlist to prevent SQL injection.
|
||||
func (v *Vaultik) getTableCount(tableName string) (int64, error) {
|
||||
if v.DB == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if !validTableNames[tableName] {
|
||||
return 0, fmt.Errorf("invalid table name: %q", tableName)
|
||||
}
|
||||
|
||||
var count int64
|
||||
query := fmt.Sprintf("SELECT COUNT(*) FROM %s", tableName)
|
||||
err := v.DB.Conn().QueryRowContext(v.ctx, query).Scan(&count)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user