Replace callback-based progress reporting in Builder.AddFile with channel-based FileHashProgress for consistency with EnumerateStatus and ScanStatus patterns. Update scanner.go to use the new channel API.
370 lines
11 KiB
Go
370 lines
11 KiB
Go
package cli
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
"github.com/spf13/afero"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"sneak.berlin/go/mfer/internal/scanner"
|
|
"sneak.berlin/go/mfer/mfer"
|
|
)
|
|
|
|
func TestSanitizePath(t *testing.T) {
|
|
// Valid paths that should be accepted
|
|
validTests := []struct {
|
|
input string
|
|
expected string
|
|
}{
|
|
{"file.txt", "file.txt"},
|
|
{"dir/file.txt", "dir/file.txt"},
|
|
{"dir/subdir/file.txt", "dir/subdir/file.txt"},
|
|
{"./file.txt", "file.txt"},
|
|
{"./dir/file.txt", "dir/file.txt"},
|
|
{"dir/./file.txt", "dir/file.txt"},
|
|
}
|
|
|
|
for _, tt := range validTests {
|
|
t.Run("valid:"+tt.input, func(t *testing.T) {
|
|
result, err := sanitizePath(tt.input)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tt.expected, result)
|
|
})
|
|
}
|
|
|
|
// Invalid paths that should be rejected
|
|
invalidTests := []struct {
|
|
input string
|
|
desc string
|
|
}{
|
|
{"", "empty path"},
|
|
{"..", "parent directory"},
|
|
{"../file.txt", "parent traversal"},
|
|
{"../../file.txt", "double parent traversal"},
|
|
{"dir/../../../file.txt", "traversal escaping base"},
|
|
{"/etc/passwd", "absolute path"},
|
|
{"/file.txt", "absolute path with single component"},
|
|
{"dir/../../etc/passwd", "traversal to system file"},
|
|
}
|
|
|
|
for _, tt := range invalidTests {
|
|
t.Run("invalid:"+tt.desc, func(t *testing.T) {
|
|
_, err := sanitizePath(tt.input)
|
|
assert.Error(t, err, "expected error for path: %s", tt.input)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestResolveManifestURL(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected string
|
|
}{
|
|
// Already ends with .mf - use as-is
|
|
{"https://example.com/path/index.mf", "https://example.com/path/index.mf"},
|
|
{"https://example.com/path/custom.mf", "https://example.com/path/custom.mf"},
|
|
{"https://example.com/foo.mf", "https://example.com/foo.mf"},
|
|
|
|
// Directory with trailing slash - append index.mf
|
|
{"https://example.com/path/", "https://example.com/path/index.mf"},
|
|
{"https://example.com/", "https://example.com/index.mf"},
|
|
|
|
// Directory without trailing slash - add slash and index.mf
|
|
{"https://example.com/path", "https://example.com/path/index.mf"},
|
|
{"https://example.com", "https://example.com/index.mf"},
|
|
|
|
// With query strings
|
|
{"https://example.com/path?foo=bar", "https://example.com/path/index.mf?foo=bar"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.input, func(t *testing.T) {
|
|
result, err := resolveManifestURL(tt.input)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tt.expected, result)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFetchFromHTTP(t *testing.T) {
|
|
// Create source filesystem with test files
|
|
sourceFs := afero.NewMemMapFs()
|
|
|
|
testFiles := map[string][]byte{
|
|
"file1.txt": []byte("Hello, World!"),
|
|
"file2.txt": []byte("This is file 2 with more content."),
|
|
"subdir/file3.txt": []byte("Nested file content here."),
|
|
"subdir/deep/f.txt": []byte("Deeply nested file."),
|
|
}
|
|
|
|
for path, content := range testFiles {
|
|
fullPath := "/" + path // MemMapFs needs absolute paths
|
|
dir := filepath.Dir(fullPath)
|
|
require.NoError(t, sourceFs.MkdirAll(dir, 0755))
|
|
require.NoError(t, afero.WriteFile(sourceFs, fullPath, content, 0644))
|
|
}
|
|
|
|
// Generate manifest using scanner
|
|
opts := &scanner.Options{
|
|
Fs: sourceFs,
|
|
}
|
|
s := scanner.NewWithOptions(opts)
|
|
require.NoError(t, s.EnumerateFS(sourceFs, "/", nil))
|
|
|
|
var manifestBuf bytes.Buffer
|
|
require.NoError(t, s.ToManifest(context.Background(), &manifestBuf, nil))
|
|
manifestData := manifestBuf.Bytes()
|
|
|
|
// Create HTTP server that serves the source filesystem
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
path := r.URL.Path
|
|
if path == "/index.mf" {
|
|
w.Header().Set("Content-Type", "application/octet-stream")
|
|
w.Write(manifestData)
|
|
return
|
|
}
|
|
|
|
// Strip leading slash
|
|
if len(path) > 0 && path[0] == '/' {
|
|
path = path[1:]
|
|
}
|
|
|
|
content, exists := testFiles[path]
|
|
if !exists {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/octet-stream")
|
|
w.Write(content)
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Create destination directory
|
|
destDir, err := os.MkdirTemp("", "mfer-fetch-test-*")
|
|
require.NoError(t, err)
|
|
defer os.RemoveAll(destDir)
|
|
|
|
// Change to dest directory for the test
|
|
origDir, err := os.Getwd()
|
|
require.NoError(t, err)
|
|
require.NoError(t, os.Chdir(destDir))
|
|
defer os.Chdir(origDir)
|
|
|
|
// Parse the manifest to get file entries
|
|
manifest, err := mfer.NewManifestFromReader(bytes.NewReader(manifestData))
|
|
require.NoError(t, err)
|
|
|
|
files := manifest.Files()
|
|
require.Len(t, files, len(testFiles))
|
|
|
|
// Download each file using downloadFile
|
|
progress := make(chan DownloadProgress, 10)
|
|
go func() {
|
|
for range progress {
|
|
// Drain progress channel
|
|
}
|
|
}()
|
|
|
|
baseURL := server.URL + "/"
|
|
for _, f := range files {
|
|
localPath, err := sanitizePath(f.Path)
|
|
require.NoError(t, err)
|
|
|
|
fileURL := baseURL + f.Path
|
|
err = downloadFile(fileURL, localPath, f, progress)
|
|
require.NoError(t, err, "failed to download %s", f.Path)
|
|
}
|
|
close(progress)
|
|
|
|
// Verify downloaded files match originals
|
|
for path, expectedContent := range testFiles {
|
|
downloadedPath := filepath.Join(destDir, path)
|
|
downloadedContent, err := os.ReadFile(downloadedPath)
|
|
require.NoError(t, err, "failed to read downloaded file %s", path)
|
|
assert.Equal(t, expectedContent, downloadedContent, "content mismatch for %s", path)
|
|
}
|
|
}
|
|
|
|
func TestFetchHashMismatch(t *testing.T) {
|
|
// Create source filesystem with a test file
|
|
sourceFs := afero.NewMemMapFs()
|
|
originalContent := []byte("Original content")
|
|
require.NoError(t, afero.WriteFile(sourceFs, "/file.txt", originalContent, 0644))
|
|
|
|
// Generate manifest
|
|
opts := &scanner.Options{Fs: sourceFs}
|
|
s := scanner.NewWithOptions(opts)
|
|
require.NoError(t, s.EnumerateFS(sourceFs, "/", nil))
|
|
|
|
var manifestBuf bytes.Buffer
|
|
require.NoError(t, s.ToManifest(context.Background(), &manifestBuf, nil))
|
|
|
|
// Parse manifest
|
|
manifest, err := mfer.NewManifestFromReader(bytes.NewReader(manifestBuf.Bytes()))
|
|
require.NoError(t, err)
|
|
files := manifest.Files()
|
|
require.Len(t, files, 1)
|
|
|
|
// Create server that serves DIFFERENT content (to trigger hash mismatch)
|
|
tamperedContent := []byte("Tampered content!")
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/octet-stream")
|
|
w.Write(tamperedContent)
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Create temp directory
|
|
destDir, err := os.MkdirTemp("", "mfer-fetch-hash-test-*")
|
|
require.NoError(t, err)
|
|
defer os.RemoveAll(destDir)
|
|
|
|
origDir, err := os.Getwd()
|
|
require.NoError(t, err)
|
|
require.NoError(t, os.Chdir(destDir))
|
|
defer os.Chdir(origDir)
|
|
|
|
// Try to download - should fail with hash mismatch
|
|
err = downloadFile(server.URL+"/file.txt", "file.txt", files[0], nil)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "mismatch")
|
|
|
|
// Verify temp file was cleaned up
|
|
_, err = os.Stat(".file.txt.tmp")
|
|
assert.True(t, os.IsNotExist(err), "temp file should be cleaned up on hash mismatch")
|
|
|
|
// Verify final file was not created
|
|
_, err = os.Stat("file.txt")
|
|
assert.True(t, os.IsNotExist(err), "final file should not exist on hash mismatch")
|
|
}
|
|
|
|
func TestFetchSizeMismatch(t *testing.T) {
|
|
// Create source filesystem with a test file
|
|
sourceFs := afero.NewMemMapFs()
|
|
originalContent := []byte("Original content with specific size")
|
|
require.NoError(t, afero.WriteFile(sourceFs, "/file.txt", originalContent, 0644))
|
|
|
|
// Generate manifest
|
|
opts := &scanner.Options{Fs: sourceFs}
|
|
s := scanner.NewWithOptions(opts)
|
|
require.NoError(t, s.EnumerateFS(sourceFs, "/", nil))
|
|
|
|
var manifestBuf bytes.Buffer
|
|
require.NoError(t, s.ToManifest(context.Background(), &manifestBuf, nil))
|
|
|
|
// Parse manifest
|
|
manifest, err := mfer.NewManifestFromReader(bytes.NewReader(manifestBuf.Bytes()))
|
|
require.NoError(t, err)
|
|
files := manifest.Files()
|
|
require.Len(t, files, 1)
|
|
|
|
// Create server that serves content with wrong size
|
|
wrongSizeContent := []byte("Short")
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/octet-stream")
|
|
w.Write(wrongSizeContent)
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Create temp directory
|
|
destDir, err := os.MkdirTemp("", "mfer-fetch-size-test-*")
|
|
require.NoError(t, err)
|
|
defer os.RemoveAll(destDir)
|
|
|
|
origDir, err := os.Getwd()
|
|
require.NoError(t, err)
|
|
require.NoError(t, os.Chdir(destDir))
|
|
defer os.Chdir(origDir)
|
|
|
|
// Try to download - should fail with size mismatch
|
|
err = downloadFile(server.URL+"/file.txt", "file.txt", files[0], nil)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "size mismatch")
|
|
|
|
// Verify temp file was cleaned up
|
|
_, err = os.Stat(".file.txt.tmp")
|
|
assert.True(t, os.IsNotExist(err), "temp file should be cleaned up on size mismatch")
|
|
}
|
|
|
|
func TestFetchProgress(t *testing.T) {
|
|
// Create source filesystem with a larger test file
|
|
sourceFs := afero.NewMemMapFs()
|
|
// Create content large enough to trigger multiple progress updates
|
|
content := bytes.Repeat([]byte("x"), 100*1024) // 100KB
|
|
require.NoError(t, afero.WriteFile(sourceFs, "/large.txt", content, 0644))
|
|
|
|
// Generate manifest
|
|
opts := &scanner.Options{Fs: sourceFs}
|
|
s := scanner.NewWithOptions(opts)
|
|
require.NoError(t, s.EnumerateFS(sourceFs, "/", nil))
|
|
|
|
var manifestBuf bytes.Buffer
|
|
require.NoError(t, s.ToManifest(context.Background(), &manifestBuf, nil))
|
|
|
|
// Parse manifest
|
|
manifest, err := mfer.NewManifestFromReader(bytes.NewReader(manifestBuf.Bytes()))
|
|
require.NoError(t, err)
|
|
files := manifest.Files()
|
|
require.Len(t, files, 1)
|
|
|
|
// Create server that serves the content
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/octet-stream")
|
|
w.Header().Set("Content-Length", "102400")
|
|
// Write in chunks to allow progress reporting
|
|
reader := bytes.NewReader(content)
|
|
io.Copy(w, reader)
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Create temp directory
|
|
destDir, err := os.MkdirTemp("", "mfer-fetch-progress-test-*")
|
|
require.NoError(t, err)
|
|
defer os.RemoveAll(destDir)
|
|
|
|
origDir, err := os.Getwd()
|
|
require.NoError(t, err)
|
|
require.NoError(t, os.Chdir(destDir))
|
|
defer os.Chdir(origDir)
|
|
|
|
// Set up progress channel and collect updates
|
|
progress := make(chan DownloadProgress, 100)
|
|
var progressUpdates []DownloadProgress
|
|
done := make(chan struct{})
|
|
go func() {
|
|
for p := range progress {
|
|
progressUpdates = append(progressUpdates, p)
|
|
}
|
|
close(done)
|
|
}()
|
|
|
|
// Download
|
|
err = downloadFile(server.URL+"/large.txt", "large.txt", files[0], progress)
|
|
close(progress)
|
|
<-done
|
|
|
|
require.NoError(t, err)
|
|
|
|
// Verify we got progress updates
|
|
assert.NotEmpty(t, progressUpdates, "should have received progress updates")
|
|
|
|
// Verify final progress shows complete
|
|
if len(progressUpdates) > 0 {
|
|
last := progressUpdates[len(progressUpdates)-1]
|
|
assert.Equal(t, int64(len(content)), last.BytesRead, "final progress should show all bytes read")
|
|
assert.Equal(t, "large.txt", last.Path)
|
|
}
|
|
|
|
// Verify file was downloaded correctly
|
|
downloaded, err := os.ReadFile("large.txt")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, content, downloaded)
|
|
}
|