diff --git a/internal/vaultik/snapshot.go b/internal/vaultik/snapshot.go index 2960389..8acaeaf 100644 --- a/internal/vaultik/snapshot.go +++ b/internal/vaultik/snapshot.go @@ -5,7 +5,6 @@ import ( "fmt" "os" "path/filepath" - "regexp" "sort" "strings" "text/tabwriter" @@ -1127,18 +1126,25 @@ func (v *Vaultik) PruneDatabase() (*PruneResult, error) { return result, nil } -// validTableNameRe matches table names containing only lowercase alphanumeric characters and underscores. -var validTableNameRe = regexp.MustCompile(`^[a-z0-9_]+$`) +// allowedTableNames is the exhaustive whitelist of table names that may be +// passed to getTableCount. Any name not in this set is rejected, preventing +// SQL injection even if caller-controlled input is accidentally supplied. +var allowedTableNames = map[string]struct{}{ + "files": {}, + "chunks": {}, + "blobs": {}, +} -// getTableCount returns the count of rows in a table. -// The tableName is sanitized to only allow [a-z0-9_] characters to prevent SQL injection. +// getTableCount returns the number of rows in the given table. +// tableName must appear in the allowedTableNames whitelist; all other values +// are rejected with an error, preventing SQL injection. func (v *Vaultik) getTableCount(tableName string) (int64, error) { - if v.DB == nil { - return 0, nil + if _, ok := allowedTableNames[tableName]; !ok { + return 0, fmt.Errorf("table name not allowed: %q", tableName) } - if !validTableNameRe.MatchString(tableName) { - return 0, fmt.Errorf("invalid table name: %q", tableName) + if v.DB == nil { + return 0, nil } var count int64 diff --git a/internal/vaultik/table_count_test.go b/internal/vaultik/table_count_test.go new file mode 100644 index 0000000..2ecb439 --- /dev/null +++ b/internal/vaultik/table_count_test.go @@ -0,0 +1,51 @@ +package vaultik + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAllowedTableNames(t *testing.T) { + // Verify the whitelist contains exactly the expected tables + expected := []string{"files", "chunks", "blobs"} + assert.Len(t, allowedTableNames, len(expected)) + for _, name := range expected { + _, ok := allowedTableNames[name] + assert.True(t, ok, "expected %q in allowedTableNames", name) + } +} + +func TestGetTableCount_RejectsInvalidNames(t *testing.T) { + v := &Vaultik{} // DB is nil, but rejection happens before DB access + v.DB = nil // explicit + + tests := []struct { + name string + tableName string + wantErr bool + }{ + {"allowed files", "files", false}, + {"allowed chunks", "chunks", false}, + {"allowed blobs", "blobs", false}, + {"sql injection attempt", "files; DROP TABLE files--", true}, + {"unknown table", "users", true}, + {"empty string", "", true}, + {"uppercase", "FILES", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + count, err := v.getTableCount(tt.tableName) + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), "not allowed") + assert.Equal(t, int64(0), count) + } else { + // DB is nil so returns 0, nil for allowed names + assert.NoError(t, err) + assert.Equal(t, int64(0), count) + } + }) + } +}