package cli

import (
	"bytes"
	"context"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"testing"

	"github.com/spf13/afero"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"sneak.berlin/go/mfer/internal/scanner"
	"sneak.berlin/go/mfer/mfer"
)

// buildManifestBytes enumerates fs from "/" with the scanner and returns the
// serialized manifest it produces.
func buildManifestBytes(t *testing.T, fs afero.Fs) []byte {
	t.Helper()
	s := scanner.NewWithOptions(&scanner.Options{Fs: fs})
	require.NoError(t, s.EnumerateFS(fs, "/", nil))
	var buf bytes.Buffer
	require.NoError(t, s.ToManifest(context.Background(), &buf, nil))
	return buf.Bytes()
}

// chdirToTempDir makes a fresh per-test temporary directory the process
// working directory and returns its path. The previous working directory is
// restored when the test finishes, and a failure to restore it is reported
// rather than ignored.
func chdirToTempDir(t *testing.T) string {
	t.Helper()
	dir := t.TempDir() // its removal is registered as a cleanup first...
	origDir, err := os.Getwd()
	require.NoError(t, err)
	require.NoError(t, os.Chdir(dir))
	// ...so this cleanup, registered later, runs earlier (cleanups are LIFO):
	// we leave dir before it is deleted.
	t.Cleanup(func() {
		if err := os.Chdir(origDir); err != nil {
			t.Errorf("restoring working directory %q: %v", origDir, err)
		}
	})
	return dir
}

// TestSanitizePath checks that sanitizePath normalizes harmless relative
// paths and rejects empty, absolute, and parent-escaping paths.
func TestSanitizePath(t *testing.T) {
	// Valid paths that should be accepted
	validTests := []struct {
		input    string
		expected string
	}{
		{"file.txt", "file.txt"},
		{"dir/file.txt", "dir/file.txt"},
		{"dir/subdir/file.txt", "dir/subdir/file.txt"},
		{"./file.txt", "file.txt"},
		{"./dir/file.txt", "dir/file.txt"},
		{"dir/./file.txt", "dir/file.txt"},
	}
	for _, tt := range validTests {
		t.Run("valid:"+tt.input, func(t *testing.T) {
			result, err := sanitizePath(tt.input)
			assert.NoError(t, err)
			assert.Equal(t, tt.expected, result)
		})
	}

	// Invalid paths that should be rejected
	invalidTests := []struct {
		input string
		desc  string
	}{
		{"", "empty path"},
		{"..", "parent directory"},
		{"../file.txt", "parent traversal"},
		{"../../file.txt", "double parent traversal"},
		{"dir/../../../file.txt", "traversal escaping base"},
		{"/etc/passwd", "absolute path"},
		{"/file.txt", "absolute path with single component"},
		{"dir/../../etc/passwd", "traversal to system file"},
	}
	for _, tt := range invalidTests {
		t.Run("invalid:"+tt.desc, func(t *testing.T) {
			_, err := sanitizePath(tt.input)
			assert.Error(t, err, "expected error for path: %s", tt.input)
		})
	}
}

// TestResolveManifestURL checks how a user-supplied URL is turned into the
// URL of a manifest file.
func TestResolveManifestURL(t *testing.T) {
	tests := []struct {
		input    string
		expected string
	}{
		// Already ends with .mf - use as-is
		{"https://example.com/path/index.mf", "https://example.com/path/index.mf"},
		{"https://example.com/path/custom.mf", "https://example.com/path/custom.mf"},
		{"https://example.com/foo.mf", "https://example.com/foo.mf"},
		// Directory with trailing slash - append index.mf
		{"https://example.com/path/", "https://example.com/path/index.mf"},
		{"https://example.com/", "https://example.com/index.mf"},
		// Directory without trailing slash - add slash and index.mf
		{"https://example.com/path", "https://example.com/path/index.mf"},
		{"https://example.com", "https://example.com/index.mf"},
		// With query strings
		{"https://example.com/path?foo=bar", "https://example.com/path/index.mf?foo=bar"},
	}
	for _, tt := range tests {
		t.Run(tt.input, func(t *testing.T) {
			result, err := resolveManifestURL(tt.input)
			assert.NoError(t, err)
			assert.Equal(t, tt.expected, result)
		})
	}
}

// TestFetchFromHTTP builds a manifest for an in-memory tree, serves the tree
// and manifest over HTTP, downloads every manifest entry with downloadFile,
// and checks the files on disk match the originals byte for byte.
func TestFetchFromHTTP(t *testing.T) {
	// Source tree; keys are the manifest-relative paths served over HTTP.
	sourceFs := afero.NewMemMapFs()
	testFiles := map[string][]byte{
		"file1.txt":         []byte("Hello, World!"),
		"file2.txt":         []byte("This is file 2 with more content."),
		"subdir/file3.txt":  []byte("Nested file content here."),
		"subdir/deep/f.txt": []byte("Deeply nested file."),
	}
	for path, content := range testFiles {
		fullPath := "/" + path // MemMapFs needs absolute paths
		require.NoError(t, sourceFs.MkdirAll(filepath.Dir(fullPath), 0o755))
		require.NoError(t, afero.WriteFile(sourceFs, fullPath, content, 0o644))
	}
	manifestData := buildManifestBytes(t, sourceFs)

	// Serve index.mf at the root and every source file at its manifest path.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/index.mf" {
			w.Header().Set("Content-Type", "application/octet-stream")
			_, _ = w.Write(manifestData)
			return
		}
		content, exists := testFiles[strings.TrimPrefix(r.URL.Path, "/")]
		if !exists {
			http.NotFound(w, r)
			return
		}
		w.Header().Set("Content-Type", "application/octet-stream")
		_, _ = w.Write(content)
	}))
	defer server.Close()

	destDir := chdirToTempDir(t)

	// Parse the manifest to get file entries
	manifest, err := mfer.NewManifestFromReader(bytes.NewReader(manifestData))
	require.NoError(t, err)
	files := manifest.Files()
	require.Len(t, files, len(testFiles))

	// Progress updates are not inspected here, only drained. Closing via
	// defer ends the drain goroutine even if a require below aborts the test.
	progress := make(chan DownloadProgress, 10)
	defer close(progress)
	go func() {
		for range progress {
			// discard
		}
	}()

	baseURL := server.URL + "/"
	for _, f := range files {
		localPath, err := sanitizePath(f.Path)
		require.NoError(t, err)
		err = downloadFile(baseURL+f.Path, localPath, f, progress)
		require.NoError(t, err, "failed to download %s", f.Path)
	}

	// Verify downloaded files match originals
	for path, want := range testFiles {
		got, err := os.ReadFile(filepath.Join(destDir, path))
		require.NoError(t, err, "failed to read downloaded file %s", path)
		assert.Equal(t, want, got, "content mismatch for %s", path)
	}
}

// TestFetchHashMismatch serves content whose length matches the manifest but
// whose bytes do not, so that only the content-hash check can reject it. It
// then checks that neither the temp file nor the final file is left behind.
func TestFetchHashMismatch(t *testing.T) {
	originalContent := []byte("Original content")
	// Same length as originalContent, so a size check cannot be what fails.
	tamperedContent := []byte("Tampered content")
	require.Len(t, tamperedContent, len(originalContent))

	sourceFs := afero.NewMemMapFs()
	require.NoError(t, afero.WriteFile(sourceFs, "/file.txt", originalContent, 0o644))
	manifest, err := mfer.NewManifestFromReader(bytes.NewReader(buildManifestBytes(t, sourceFs)))
	require.NoError(t, err)
	files := manifest.Files()
	require.Len(t, files, 1)

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/octet-stream")
		_, _ = w.Write(tamperedContent)
	}))
	defer server.Close()

	chdirToTempDir(t)

	// require (not assert): err.Error() below would panic on a nil error.
	err = downloadFile(server.URL+"/file.txt", "file.txt", files[0], nil)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "mismatch")

	// Verify temp file was cleaned up
	_, err = os.Stat(".file.txt.tmp")
	assert.True(t, os.IsNotExist(err), "temp file should be cleaned up on hash mismatch")

	// Verify final file was not created
	_, err = os.Stat("file.txt")
	assert.True(t, os.IsNotExist(err), "final file should not exist on hash mismatch")
}

// TestFetchSizeMismatch serves a body shorter than the manifest records and
// checks that downloadFile reports a size mismatch and leaves no files behind.
func TestFetchSizeMismatch(t *testing.T) {
	sourceFs := afero.NewMemMapFs()
	originalContent := []byte("Original content with specific size")
	require.NoError(t, afero.WriteFile(sourceFs, "/file.txt", originalContent, 0o644))
	manifest, err := mfer.NewManifestFromReader(bytes.NewReader(buildManifestBytes(t, sourceFs)))
	require.NoError(t, err)
	files := manifest.Files()
	require.Len(t, files, 1)

	wrongSizeContent := []byte("Short")
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/octet-stream")
		_, _ = w.Write(wrongSizeContent)
	}))
	defer server.Close()

	chdirToTempDir(t)

	// require (not assert): err.Error() below would panic on a nil error.
	err = downloadFile(server.URL+"/file.txt", "file.txt", files[0], nil)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "size mismatch")

	// Verify temp file was cleaned up
	_, err = os.Stat(".file.txt.tmp")
	assert.True(t, os.IsNotExist(err), "temp file should be cleaned up on size mismatch")

	// Verify final file was not created
	_, err = os.Stat("file.txt")
	assert.True(t, os.IsNotExist(err), "final file should not exist on size mismatch")
}

// TestFetchProgress downloads a 100 KiB file and checks that progress updates
// are sent, that the last one reports every byte for the right path, and that
// the file lands on disk intact.
func TestFetchProgress(t *testing.T) {
	content := bytes.Repeat([]byte("x"), 100*1024) // 100 KiB
	sourceFs := afero.NewMemMapFs()
	require.NoError(t, afero.WriteFile(sourceFs, "/large.txt", content, 0o644))
	manifest, err := mfer.NewManifestFromReader(bytes.NewReader(buildManifestBytes(t, sourceFs)))
	require.NoError(t, err)
	files := manifest.Files()
	require.Len(t, files, 1)

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/octet-stream")
		w.Header().Set("Content-Length", strconv.Itoa(len(content)))
		// The server sends the whole body; how many progress updates result
		// depends on how downloadFile reads it.
		_, _ = io.Copy(w, bytes.NewReader(content))
	}))
	defer server.Close()

	chdirToTempDir(t)

	// Collect every update; updates is read only after done is closed.
	progress := make(chan DownloadProgress, 100)
	var updates []DownloadProgress
	done := make(chan struct{})
	go func() {
		defer close(done)
		for p := range progress {
			updates = append(updates, p)
		}
	}()

	err = downloadFile(server.URL+"/large.txt", "large.txt", files[0], progress)
	close(progress)
	<-done
	require.NoError(t, err)

	require.NotEmpty(t, updates, "should have received progress updates")
	last := updates[len(updates)-1]
	assert.Equal(t, int64(len(content)), last.BytesRead, "final progress should show all bytes read")
	assert.Equal(t, "large.txt", last.Path)

	// Verify file was downloaded correctly
	downloaded, err := os.ReadFile("large.txt")
	require.NoError(t, err)
	assert.Equal(t, content, downloaded)
}