File paths containing spaces, `#`, `?`, `%`, or other URL-reserved characters were concatenated directly into URLs without percent-encoding, producing malformed download URLs. Add encodeFilePath(), which percent-encodes each path segment individually (preserving the `/` directory separators), and use it when building download URLs in fetch.
392 lines · 12 KiB · Go
package cli
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
"github.com/spf13/afero"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"sneak.berlin/go/mfer/mfer"
|
|
)
|
|
|
|
func TestEncodeFilePath(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected string
|
|
}{
|
|
{"file.txt", "file.txt"},
|
|
{"dir/file.txt", "dir/file.txt"},
|
|
{"my file.txt", "my%20file.txt"},
|
|
{"dir/my file.txt", "dir/my%20file.txt"},
|
|
{"file#1.txt", "file%231.txt"},
|
|
{"file?v=1.txt", "file%3Fv=1.txt"},
|
|
{"path/to/file with spaces.txt", "path/to/file%20with%20spaces.txt"},
|
|
{"100%done.txt", "100%25done.txt"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.input, func(t *testing.T) {
|
|
result := encodeFilePath(tt.input)
|
|
assert.Equal(t, tt.expected, result)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSanitizePath(t *testing.T) {
|
|
// Valid paths that should be accepted
|
|
validTests := []struct {
|
|
input string
|
|
expected string
|
|
}{
|
|
{"file.txt", "file.txt"},
|
|
{"dir/file.txt", "dir/file.txt"},
|
|
{"dir/subdir/file.txt", "dir/subdir/file.txt"},
|
|
{"./file.txt", "file.txt"},
|
|
{"./dir/file.txt", "dir/file.txt"},
|
|
{"dir/./file.txt", "dir/file.txt"},
|
|
}
|
|
|
|
for _, tt := range validTests {
|
|
t.Run("valid:"+tt.input, func(t *testing.T) {
|
|
result, err := sanitizePath(tt.input)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tt.expected, result)
|
|
})
|
|
}
|
|
|
|
// Invalid paths that should be rejected
|
|
invalidTests := []struct {
|
|
input string
|
|
desc string
|
|
}{
|
|
{"", "empty path"},
|
|
{"..", "parent directory"},
|
|
{"../file.txt", "parent traversal"},
|
|
{"../../file.txt", "double parent traversal"},
|
|
{"dir/../../../file.txt", "traversal escaping base"},
|
|
{"/etc/passwd", "absolute path"},
|
|
{"/file.txt", "absolute path with single component"},
|
|
{"dir/../../etc/passwd", "traversal to system file"},
|
|
}
|
|
|
|
for _, tt := range invalidTests {
|
|
t.Run("invalid:"+tt.desc, func(t *testing.T) {
|
|
_, err := sanitizePath(tt.input)
|
|
assert.Error(t, err, "expected error for path: %s", tt.input)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestResolveManifestURL(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected string
|
|
}{
|
|
// Already ends with .mf - use as-is
|
|
{"https://example.com/path/index.mf", "https://example.com/path/index.mf"},
|
|
{"https://example.com/path/custom.mf", "https://example.com/path/custom.mf"},
|
|
{"https://example.com/foo.mf", "https://example.com/foo.mf"},
|
|
|
|
// Directory with trailing slash - append index.mf
|
|
{"https://example.com/path/", "https://example.com/path/index.mf"},
|
|
{"https://example.com/", "https://example.com/index.mf"},
|
|
|
|
// Directory without trailing slash - add slash and index.mf
|
|
{"https://example.com/path", "https://example.com/path/index.mf"},
|
|
{"https://example.com", "https://example.com/index.mf"},
|
|
|
|
// With query strings
|
|
{"https://example.com/path?foo=bar", "https://example.com/path/index.mf?foo=bar"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.input, func(t *testing.T) {
|
|
result, err := resolveManifestURL(tt.input)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tt.expected, result)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFetchFromHTTP(t *testing.T) {
|
|
// Create source filesystem with test files
|
|
sourceFs := afero.NewMemMapFs()
|
|
|
|
testFiles := map[string][]byte{
|
|
"file1.txt": []byte("Hello, World!"),
|
|
"file2.txt": []byte("This is file 2 with more content."),
|
|
"subdir/file3.txt": []byte("Nested file content here."),
|
|
"subdir/deep/f.txt": []byte("Deeply nested file."),
|
|
}
|
|
|
|
for path, content := range testFiles {
|
|
fullPath := "/" + path // MemMapFs needs absolute paths
|
|
dir := filepath.Dir(fullPath)
|
|
require.NoError(t, sourceFs.MkdirAll(dir, 0o755))
|
|
require.NoError(t, afero.WriteFile(sourceFs, fullPath, content, 0o644))
|
|
}
|
|
|
|
// Generate manifest using scanner
|
|
opts := &mfer.ScannerOptions{
|
|
Fs: sourceFs,
|
|
}
|
|
s := mfer.NewScannerWithOptions(opts)
|
|
require.NoError(t, s.EnumerateFS(sourceFs, "/", nil))
|
|
|
|
var manifestBuf bytes.Buffer
|
|
require.NoError(t, s.ToManifest(context.Background(), &manifestBuf, nil))
|
|
manifestData := manifestBuf.Bytes()
|
|
|
|
// Create HTTP server that serves the source filesystem
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
path := r.URL.Path
|
|
if path == "/index.mf" {
|
|
w.Header().Set("Content-Type", "application/octet-stream")
|
|
_, _ = w.Write(manifestData)
|
|
return
|
|
}
|
|
|
|
// Strip leading slash
|
|
if len(path) > 0 && path[0] == '/' {
|
|
path = path[1:]
|
|
}
|
|
|
|
content, exists := testFiles[path]
|
|
if !exists {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/octet-stream")
|
|
_, _ = w.Write(content)
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Create destination directory
|
|
destDir, err := os.MkdirTemp("", "mfer-fetch-test-*")
|
|
require.NoError(t, err)
|
|
defer func() { _ = os.RemoveAll(destDir) }()
|
|
|
|
// Change to dest directory for the test
|
|
origDir, err := os.Getwd()
|
|
require.NoError(t, err)
|
|
require.NoError(t, os.Chdir(destDir))
|
|
defer func() { _ = os.Chdir(origDir) }()
|
|
|
|
// Parse the manifest to get file entries
|
|
manifest, err := mfer.NewManifestFromReader(bytes.NewReader(manifestData))
|
|
require.NoError(t, err)
|
|
|
|
files := manifest.Files()
|
|
require.Len(t, files, len(testFiles))
|
|
|
|
// Download each file using downloadFile
|
|
progress := make(chan DownloadProgress, 10)
|
|
go func() {
|
|
for range progress {
|
|
// Drain progress channel
|
|
}
|
|
}()
|
|
|
|
baseURL := server.URL + "/"
|
|
for _, f := range files {
|
|
localPath, err := sanitizePath(f.Path)
|
|
require.NoError(t, err)
|
|
|
|
fileURL := baseURL + f.Path
|
|
err = downloadFile(fileURL, localPath, f, progress)
|
|
require.NoError(t, err, "failed to download %s", f.Path)
|
|
}
|
|
close(progress)
|
|
|
|
// Verify downloaded files match originals
|
|
for path, expectedContent := range testFiles {
|
|
downloadedPath := filepath.Join(destDir, path)
|
|
downloadedContent, err := os.ReadFile(downloadedPath)
|
|
require.NoError(t, err, "failed to read downloaded file %s", path)
|
|
assert.Equal(t, expectedContent, downloadedContent, "content mismatch for %s", path)
|
|
}
|
|
}
|
|
|
|
func TestFetchHashMismatch(t *testing.T) {
|
|
// Create source filesystem with a test file
|
|
sourceFs := afero.NewMemMapFs()
|
|
originalContent := []byte("Original content")
|
|
require.NoError(t, afero.WriteFile(sourceFs, "/file.txt", originalContent, 0o644))
|
|
|
|
// Generate manifest
|
|
opts := &mfer.ScannerOptions{Fs: sourceFs}
|
|
s := mfer.NewScannerWithOptions(opts)
|
|
require.NoError(t, s.EnumerateFS(sourceFs, "/", nil))
|
|
|
|
var manifestBuf bytes.Buffer
|
|
require.NoError(t, s.ToManifest(context.Background(), &manifestBuf, nil))
|
|
|
|
// Parse manifest
|
|
manifest, err := mfer.NewManifestFromReader(bytes.NewReader(manifestBuf.Bytes()))
|
|
require.NoError(t, err)
|
|
files := manifest.Files()
|
|
require.Len(t, files, 1)
|
|
|
|
// Create server that serves DIFFERENT content (to trigger hash mismatch)
|
|
tamperedContent := []byte("Tampered content!")
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/octet-stream")
|
|
_, _ = w.Write(tamperedContent)
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Create temp directory
|
|
destDir, err := os.MkdirTemp("", "mfer-fetch-hash-test-*")
|
|
require.NoError(t, err)
|
|
defer func() { _ = os.RemoveAll(destDir) }()
|
|
|
|
origDir, err := os.Getwd()
|
|
require.NoError(t, err)
|
|
require.NoError(t, os.Chdir(destDir))
|
|
defer func() { _ = os.Chdir(origDir) }()
|
|
|
|
// Try to download - should fail with hash mismatch
|
|
err = downloadFile(server.URL+"/file.txt", "file.txt", files[0], nil)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "mismatch")
|
|
|
|
// Verify temp file was cleaned up
|
|
_, err = os.Stat(".file.txt.tmp")
|
|
assert.True(t, os.IsNotExist(err), "temp file should be cleaned up on hash mismatch")
|
|
|
|
// Verify final file was not created
|
|
_, err = os.Stat("file.txt")
|
|
assert.True(t, os.IsNotExist(err), "final file should not exist on hash mismatch")
|
|
}
|
|
|
|
func TestFetchSizeMismatch(t *testing.T) {
|
|
// Create source filesystem with a test file
|
|
sourceFs := afero.NewMemMapFs()
|
|
originalContent := []byte("Original content with specific size")
|
|
require.NoError(t, afero.WriteFile(sourceFs, "/file.txt", originalContent, 0o644))
|
|
|
|
// Generate manifest
|
|
opts := &mfer.ScannerOptions{Fs: sourceFs}
|
|
s := mfer.NewScannerWithOptions(opts)
|
|
require.NoError(t, s.EnumerateFS(sourceFs, "/", nil))
|
|
|
|
var manifestBuf bytes.Buffer
|
|
require.NoError(t, s.ToManifest(context.Background(), &manifestBuf, nil))
|
|
|
|
// Parse manifest
|
|
manifest, err := mfer.NewManifestFromReader(bytes.NewReader(manifestBuf.Bytes()))
|
|
require.NoError(t, err)
|
|
files := manifest.Files()
|
|
require.Len(t, files, 1)
|
|
|
|
// Create server that serves content with wrong size
|
|
wrongSizeContent := []byte("Short")
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/octet-stream")
|
|
_, _ = w.Write(wrongSizeContent)
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Create temp directory
|
|
destDir, err := os.MkdirTemp("", "mfer-fetch-size-test-*")
|
|
require.NoError(t, err)
|
|
defer func() { _ = os.RemoveAll(destDir) }()
|
|
|
|
origDir, err := os.Getwd()
|
|
require.NoError(t, err)
|
|
require.NoError(t, os.Chdir(destDir))
|
|
defer func() { _ = os.Chdir(origDir) }()
|
|
|
|
// Try to download - should fail with size mismatch
|
|
err = downloadFile(server.URL+"/file.txt", "file.txt", files[0], nil)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "size mismatch")
|
|
|
|
// Verify temp file was cleaned up
|
|
_, err = os.Stat(".file.txt.tmp")
|
|
assert.True(t, os.IsNotExist(err), "temp file should be cleaned up on size mismatch")
|
|
}
|
|
|
|
func TestFetchProgress(t *testing.T) {
|
|
// Create source filesystem with a larger test file
|
|
sourceFs := afero.NewMemMapFs()
|
|
// Create content large enough to trigger multiple progress updates
|
|
content := bytes.Repeat([]byte("x"), 100*1024) // 100KB
|
|
require.NoError(t, afero.WriteFile(sourceFs, "/large.txt", content, 0o644))
|
|
|
|
// Generate manifest
|
|
opts := &mfer.ScannerOptions{Fs: sourceFs}
|
|
s := mfer.NewScannerWithOptions(opts)
|
|
require.NoError(t, s.EnumerateFS(sourceFs, "/", nil))
|
|
|
|
var manifestBuf bytes.Buffer
|
|
require.NoError(t, s.ToManifest(context.Background(), &manifestBuf, nil))
|
|
|
|
// Parse manifest
|
|
manifest, err := mfer.NewManifestFromReader(bytes.NewReader(manifestBuf.Bytes()))
|
|
require.NoError(t, err)
|
|
files := manifest.Files()
|
|
require.Len(t, files, 1)
|
|
|
|
// Create server that serves the content
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/octet-stream")
|
|
w.Header().Set("Content-Length", "102400")
|
|
// Write in chunks to allow progress reporting
|
|
reader := bytes.NewReader(content)
|
|
_, _ = io.Copy(w, reader)
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Create temp directory
|
|
destDir, err := os.MkdirTemp("", "mfer-fetch-progress-test-*")
|
|
require.NoError(t, err)
|
|
defer func() { _ = os.RemoveAll(destDir) }()
|
|
|
|
origDir, err := os.Getwd()
|
|
require.NoError(t, err)
|
|
require.NoError(t, os.Chdir(destDir))
|
|
defer func() { _ = os.Chdir(origDir) }()
|
|
|
|
// Set up progress channel and collect updates
|
|
progress := make(chan DownloadProgress, 100)
|
|
var progressUpdates []DownloadProgress
|
|
done := make(chan struct{})
|
|
go func() {
|
|
for p := range progress {
|
|
progressUpdates = append(progressUpdates, p)
|
|
}
|
|
close(done)
|
|
}()
|
|
|
|
// Download
|
|
err = downloadFile(server.URL+"/large.txt", "large.txt", files[0], progress)
|
|
close(progress)
|
|
<-done
|
|
|
|
require.NoError(t, err)
|
|
|
|
// Verify we got progress updates
|
|
assert.NotEmpty(t, progressUpdates, "should have received progress updates")
|
|
|
|
// Verify final progress shows complete
|
|
if len(progressUpdates) > 0 {
|
|
last := progressUpdates[len(progressUpdates)-1]
|
|
assert.Equal(t, int64(len(content)), last.BytesRead, "final progress should show all bytes read")
|
|
assert.Equal(t, "large.txt", last.Path)
|
|
}
|
|
|
|
// Verify file was downloaded correctly
|
|
downloaded, err := os.ReadFile("large.txt")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, content, downloaded)
|
|
}
|