Change FileProgress callback to channel-based progress
Replace callback-based progress reporting in Builder.AddFile with channel-based FileHashProgress for consistency with EnumerateStatus and ScanStatus patterns. Update scanner.go to use the new channel API.
This commit is contained in:
369
internal/cli/fetch_test.go
Normal file
369
internal/cli/fetch_test.go
Normal file
@@ -0,0 +1,369 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"sneak.berlin/go/mfer/internal/scanner"
|
||||
"sneak.berlin/go/mfer/mfer"
|
||||
)
|
||||
|
||||
func TestSanitizePath(t *testing.T) {
|
||||
// Valid paths that should be accepted
|
||||
validTests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"file.txt", "file.txt"},
|
||||
{"dir/file.txt", "dir/file.txt"},
|
||||
{"dir/subdir/file.txt", "dir/subdir/file.txt"},
|
||||
{"./file.txt", "file.txt"},
|
||||
{"./dir/file.txt", "dir/file.txt"},
|
||||
{"dir/./file.txt", "dir/file.txt"},
|
||||
}
|
||||
|
||||
for _, tt := range validTests {
|
||||
t.Run("valid:"+tt.input, func(t *testing.T) {
|
||||
result, err := sanitizePath(tt.input)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
|
||||
// Invalid paths that should be rejected
|
||||
invalidTests := []struct {
|
||||
input string
|
||||
desc string
|
||||
}{
|
||||
{"", "empty path"},
|
||||
{"..", "parent directory"},
|
||||
{"../file.txt", "parent traversal"},
|
||||
{"../../file.txt", "double parent traversal"},
|
||||
{"dir/../../../file.txt", "traversal escaping base"},
|
||||
{"/etc/passwd", "absolute path"},
|
||||
{"/file.txt", "absolute path with single component"},
|
||||
{"dir/../../etc/passwd", "traversal to system file"},
|
||||
}
|
||||
|
||||
for _, tt := range invalidTests {
|
||||
t.Run("invalid:"+tt.desc, func(t *testing.T) {
|
||||
_, err := sanitizePath(tt.input)
|
||||
assert.Error(t, err, "expected error for path: %s", tt.input)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveManifestURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
// Already ends with .mf - use as-is
|
||||
{"https://example.com/path/index.mf", "https://example.com/path/index.mf"},
|
||||
{"https://example.com/path/custom.mf", "https://example.com/path/custom.mf"},
|
||||
{"https://example.com/foo.mf", "https://example.com/foo.mf"},
|
||||
|
||||
// Directory with trailing slash - append index.mf
|
||||
{"https://example.com/path/", "https://example.com/path/index.mf"},
|
||||
{"https://example.com/", "https://example.com/index.mf"},
|
||||
|
||||
// Directory without trailing slash - add slash and index.mf
|
||||
{"https://example.com/path", "https://example.com/path/index.mf"},
|
||||
{"https://example.com", "https://example.com/index.mf"},
|
||||
|
||||
// With query strings
|
||||
{"https://example.com/path?foo=bar", "https://example.com/path/index.mf?foo=bar"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result, err := resolveManifestURL(tt.input)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchFromHTTP(t *testing.T) {
|
||||
// Create source filesystem with test files
|
||||
sourceFs := afero.NewMemMapFs()
|
||||
|
||||
testFiles := map[string][]byte{
|
||||
"file1.txt": []byte("Hello, World!"),
|
||||
"file2.txt": []byte("This is file 2 with more content."),
|
||||
"subdir/file3.txt": []byte("Nested file content here."),
|
||||
"subdir/deep/f.txt": []byte("Deeply nested file."),
|
||||
}
|
||||
|
||||
for path, content := range testFiles {
|
||||
fullPath := "/" + path // MemMapFs needs absolute paths
|
||||
dir := filepath.Dir(fullPath)
|
||||
require.NoError(t, sourceFs.MkdirAll(dir, 0755))
|
||||
require.NoError(t, afero.WriteFile(sourceFs, fullPath, content, 0644))
|
||||
}
|
||||
|
||||
// Generate manifest using scanner
|
||||
opts := &scanner.Options{
|
||||
Fs: sourceFs,
|
||||
}
|
||||
s := scanner.NewWithOptions(opts)
|
||||
require.NoError(t, s.EnumerateFS(sourceFs, "/", nil))
|
||||
|
||||
var manifestBuf bytes.Buffer
|
||||
require.NoError(t, s.ToManifest(context.Background(), &manifestBuf, nil))
|
||||
manifestData := manifestBuf.Bytes()
|
||||
|
||||
// Create HTTP server that serves the source filesystem
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
path := r.URL.Path
|
||||
if path == "/index.mf" {
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
w.Write(manifestData)
|
||||
return
|
||||
}
|
||||
|
||||
// Strip leading slash
|
||||
if len(path) > 0 && path[0] == '/' {
|
||||
path = path[1:]
|
||||
}
|
||||
|
||||
content, exists := testFiles[path]
|
||||
if !exists {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
w.Write(content)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create destination directory
|
||||
destDir, err := os.MkdirTemp("", "mfer-fetch-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(destDir)
|
||||
|
||||
// Change to dest directory for the test
|
||||
origDir, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.Chdir(destDir))
|
||||
defer os.Chdir(origDir)
|
||||
|
||||
// Parse the manifest to get file entries
|
||||
manifest, err := mfer.NewManifestFromReader(bytes.NewReader(manifestData))
|
||||
require.NoError(t, err)
|
||||
|
||||
files := manifest.Files()
|
||||
require.Len(t, files, len(testFiles))
|
||||
|
||||
// Download each file using downloadFile
|
||||
progress := make(chan DownloadProgress, 10)
|
||||
go func() {
|
||||
for range progress {
|
||||
// Drain progress channel
|
||||
}
|
||||
}()
|
||||
|
||||
baseURL := server.URL + "/"
|
||||
for _, f := range files {
|
||||
localPath, err := sanitizePath(f.Path)
|
||||
require.NoError(t, err)
|
||||
|
||||
fileURL := baseURL + f.Path
|
||||
err = downloadFile(fileURL, localPath, f, progress)
|
||||
require.NoError(t, err, "failed to download %s", f.Path)
|
||||
}
|
||||
close(progress)
|
||||
|
||||
// Verify downloaded files match originals
|
||||
for path, expectedContent := range testFiles {
|
||||
downloadedPath := filepath.Join(destDir, path)
|
||||
downloadedContent, err := os.ReadFile(downloadedPath)
|
||||
require.NoError(t, err, "failed to read downloaded file %s", path)
|
||||
assert.Equal(t, expectedContent, downloadedContent, "content mismatch for %s", path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchHashMismatch(t *testing.T) {
|
||||
// Create source filesystem with a test file
|
||||
sourceFs := afero.NewMemMapFs()
|
||||
originalContent := []byte("Original content")
|
||||
require.NoError(t, afero.WriteFile(sourceFs, "/file.txt", originalContent, 0644))
|
||||
|
||||
// Generate manifest
|
||||
opts := &scanner.Options{Fs: sourceFs}
|
||||
s := scanner.NewWithOptions(opts)
|
||||
require.NoError(t, s.EnumerateFS(sourceFs, "/", nil))
|
||||
|
||||
var manifestBuf bytes.Buffer
|
||||
require.NoError(t, s.ToManifest(context.Background(), &manifestBuf, nil))
|
||||
|
||||
// Parse manifest
|
||||
manifest, err := mfer.NewManifestFromReader(bytes.NewReader(manifestBuf.Bytes()))
|
||||
require.NoError(t, err)
|
||||
files := manifest.Files()
|
||||
require.Len(t, files, 1)
|
||||
|
||||
// Create server that serves DIFFERENT content (to trigger hash mismatch)
|
||||
tamperedContent := []byte("Tampered content!")
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
w.Write(tamperedContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create temp directory
|
||||
destDir, err := os.MkdirTemp("", "mfer-fetch-hash-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(destDir)
|
||||
|
||||
origDir, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.Chdir(destDir))
|
||||
defer os.Chdir(origDir)
|
||||
|
||||
// Try to download - should fail with hash mismatch
|
||||
err = downloadFile(server.URL+"/file.txt", "file.txt", files[0], nil)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "mismatch")
|
||||
|
||||
// Verify temp file was cleaned up
|
||||
_, err = os.Stat(".file.txt.tmp")
|
||||
assert.True(t, os.IsNotExist(err), "temp file should be cleaned up on hash mismatch")
|
||||
|
||||
// Verify final file was not created
|
||||
_, err = os.Stat("file.txt")
|
||||
assert.True(t, os.IsNotExist(err), "final file should not exist on hash mismatch")
|
||||
}
|
||||
|
||||
func TestFetchSizeMismatch(t *testing.T) {
|
||||
// Create source filesystem with a test file
|
||||
sourceFs := afero.NewMemMapFs()
|
||||
originalContent := []byte("Original content with specific size")
|
||||
require.NoError(t, afero.WriteFile(sourceFs, "/file.txt", originalContent, 0644))
|
||||
|
||||
// Generate manifest
|
||||
opts := &scanner.Options{Fs: sourceFs}
|
||||
s := scanner.NewWithOptions(opts)
|
||||
require.NoError(t, s.EnumerateFS(sourceFs, "/", nil))
|
||||
|
||||
var manifestBuf bytes.Buffer
|
||||
require.NoError(t, s.ToManifest(context.Background(), &manifestBuf, nil))
|
||||
|
||||
// Parse manifest
|
||||
manifest, err := mfer.NewManifestFromReader(bytes.NewReader(manifestBuf.Bytes()))
|
||||
require.NoError(t, err)
|
||||
files := manifest.Files()
|
||||
require.Len(t, files, 1)
|
||||
|
||||
// Create server that serves content with wrong size
|
||||
wrongSizeContent := []byte("Short")
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
w.Write(wrongSizeContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create temp directory
|
||||
destDir, err := os.MkdirTemp("", "mfer-fetch-size-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(destDir)
|
||||
|
||||
origDir, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.Chdir(destDir))
|
||||
defer os.Chdir(origDir)
|
||||
|
||||
// Try to download - should fail with size mismatch
|
||||
err = downloadFile(server.URL+"/file.txt", "file.txt", files[0], nil)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "size mismatch")
|
||||
|
||||
// Verify temp file was cleaned up
|
||||
_, err = os.Stat(".file.txt.tmp")
|
||||
assert.True(t, os.IsNotExist(err), "temp file should be cleaned up on size mismatch")
|
||||
}
|
||||
|
||||
func TestFetchProgress(t *testing.T) {
|
||||
// Create source filesystem with a larger test file
|
||||
sourceFs := afero.NewMemMapFs()
|
||||
// Create content large enough to trigger multiple progress updates
|
||||
content := bytes.Repeat([]byte("x"), 100*1024) // 100KB
|
||||
require.NoError(t, afero.WriteFile(sourceFs, "/large.txt", content, 0644))
|
||||
|
||||
// Generate manifest
|
||||
opts := &scanner.Options{Fs: sourceFs}
|
||||
s := scanner.NewWithOptions(opts)
|
||||
require.NoError(t, s.EnumerateFS(sourceFs, "/", nil))
|
||||
|
||||
var manifestBuf bytes.Buffer
|
||||
require.NoError(t, s.ToManifest(context.Background(), &manifestBuf, nil))
|
||||
|
||||
// Parse manifest
|
||||
manifest, err := mfer.NewManifestFromReader(bytes.NewReader(manifestBuf.Bytes()))
|
||||
require.NoError(t, err)
|
||||
files := manifest.Files()
|
||||
require.Len(t, files, 1)
|
||||
|
||||
// Create server that serves the content
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
w.Header().Set("Content-Length", "102400")
|
||||
// Write in chunks to allow progress reporting
|
||||
reader := bytes.NewReader(content)
|
||||
io.Copy(w, reader)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create temp directory
|
||||
destDir, err := os.MkdirTemp("", "mfer-fetch-progress-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(destDir)
|
||||
|
||||
origDir, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.Chdir(destDir))
|
||||
defer os.Chdir(origDir)
|
||||
|
||||
// Set up progress channel and collect updates
|
||||
progress := make(chan DownloadProgress, 100)
|
||||
var progressUpdates []DownloadProgress
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
for p := range progress {
|
||||
progressUpdates = append(progressUpdates, p)
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Download
|
||||
err = downloadFile(server.URL+"/large.txt", "large.txt", files[0], progress)
|
||||
close(progress)
|
||||
<-done
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify we got progress updates
|
||||
assert.NotEmpty(t, progressUpdates, "should have received progress updates")
|
||||
|
||||
// Verify final progress shows complete
|
||||
if len(progressUpdates) > 0 {
|
||||
last := progressUpdates[len(progressUpdates)-1]
|
||||
assert.Equal(t, int64(len(content)), last.BytesRead, "final progress should show all bytes read")
|
||||
assert.Equal(t, "large.txt", last.Path)
|
||||
}
|
||||
|
||||
// Verify file was downloaded correctly
|
||||
downloaded, err := os.ReadFile("large.txt")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, downloaded)
|
||||
}
|
||||
Reference in New Issue
Block a user