mfer/internal/cli/fetch_test.go
sneak a5b0343b28 Use Go 1.13+ octal literal syntax throughout codebase
Update file permission literals from legacy octal format (0755, 0644)
to explicit Go 1.13+ format (0o755, 0o644) for improved readability.
2025-12-18 01:29:40 -08:00

369 lines
11 KiB
Go

package cli
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"sneak.berlin/go/mfer/mfer"
)
// TestSanitizePath exercises sanitizePath with two tables: inputs that must be
// accepted (and the exact result expected for each), and inputs that must be
// rejected with an error.
func TestSanitizePath(t *testing.T) {
	// Accepted inputs paired with the result sanitizePath should return.
	accepted := []struct {
		in   string
		want string
	}{
		{"file.txt", "file.txt"},
		{"dir/file.txt", "dir/file.txt"},
		{"dir/subdir/file.txt", "dir/subdir/file.txt"},
		{"./file.txt", "file.txt"},
		{"./dir/file.txt", "dir/file.txt"},
		{"dir/./file.txt", "dir/file.txt"},
	}
	for i := range accepted {
		tc := accepted[i]
		t.Run("valid:"+tc.in, func(t *testing.T) {
			got, err := sanitizePath(tc.in)
			assert.NoError(t, err)
			assert.Equal(t, tc.want, got)
		})
	}

	// Rejected inputs; the second field names the subtest.
	rejected := []struct {
		in    string
		label string
	}{
		{"", "empty path"},
		{"..", "parent directory"},
		{"../file.txt", "parent traversal"},
		{"../../file.txt", "double parent traversal"},
		{"dir/../../../file.txt", "traversal escaping base"},
		{"/etc/passwd", "absolute path"},
		{"/file.txt", "absolute path with single component"},
		{"dir/../../etc/passwd", "traversal to system file"},
	}
	for i := range rejected {
		tc := rejected[i]
		t.Run("invalid:"+tc.label, func(t *testing.T) {
			_, err := sanitizePath(tc.in)
			assert.Error(t, err, "expected error for path: %s", tc.in)
		})
	}
}
// TestResolveManifestURL checks the URL that resolveManifestURL returns for a
// set of inputs: URLs already ending in .mf, directory URLs with and without a
// trailing slash, and a URL carrying a query string.
func TestResolveManifestURL(t *testing.T) {
	cases := []struct {
		rawURL string
		want   string
	}{
		// Already ends with .mf: expected back unchanged.
		{"https://example.com/path/index.mf", "https://example.com/path/index.mf"},
		{"https://example.com/path/custom.mf", "https://example.com/path/custom.mf"},
		{"https://example.com/foo.mf", "https://example.com/foo.mf"},
		// Directory with trailing slash: index.mf is appended.
		{"https://example.com/path/", "https://example.com/path/index.mf"},
		{"https://example.com/", "https://example.com/index.mf"},
		// Directory without trailing slash: slash and index.mf are appended.
		{"https://example.com/path", "https://example.com/path/index.mf"},
		{"https://example.com", "https://example.com/index.mf"},
		// Query string stays after the appended path.
		{"https://example.com/path?foo=bar", "https://example.com/path/index.mf?foo=bar"},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.rawURL, func(t *testing.T) {
			got, err := resolveManifestURL(tc.rawURL)
			assert.NoError(t, err)
			assert.Equal(t, tc.want, got)
		})
	}
}
// TestFetchFromHTTP builds a manifest from an in-memory filesystem, serves the
// manifest and the file contents from an httptest server, downloads every
// manifest entry with downloadFile into a temporary working directory, and
// checks that the bytes written to disk equal the original contents.
func TestFetchFromHTTP(t *testing.T) {
	// Create source filesystem with test files.
	sourceFs := afero.NewMemMapFs()
	testFiles := map[string][]byte{
		"file1.txt":         []byte("Hello, World!"),
		"file2.txt":         []byte("This is file 2 with more content."),
		"subdir/file3.txt":  []byte("Nested file content here."),
		"subdir/deep/f.txt": []byte("Deeply nested file."),
	}
	for path, content := range testFiles {
		fullPath := "/" + path // MemMapFs needs absolute paths
		require.NoError(t, sourceFs.MkdirAll(filepath.Dir(fullPath), 0o755))
		require.NoError(t, afero.WriteFile(sourceFs, fullPath, content, 0o644))
	}

	// Generate manifest using scanner.
	s := mfer.NewScannerWithOptions(&mfer.ScannerOptions{Fs: sourceFs})
	require.NoError(t, s.EnumerateFS(sourceFs, "/", nil))
	var manifestBuf bytes.Buffer
	require.NoError(t, s.ToManifest(context.Background(), &manifestBuf, nil))
	manifestData := manifestBuf.Bytes()

	// HTTP server: /index.mf returns the manifest, any other path is looked up
	// in testFiles (without its leading slash) and 404s when absent.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		if path == "/index.mf" {
			w.Header().Set("Content-Type", "application/octet-stream")
			_, _ = w.Write(manifestData) // write errors are irrelevant to the test server
			return
		}
		// Strip leading slash so the path matches the testFiles keys.
		if len(path) > 0 && path[0] == '/' {
			path = path[1:]
		}
		content, exists := testFiles[path]
		if !exists {
			http.NotFound(w, r)
			return
		}
		w.Header().Set("Content-Type", "application/octet-stream")
		_, _ = w.Write(content) // write errors are irrelevant to the test server
	}))
	defer server.Close()

	// Destination directory; t.TempDir removes it when the test finishes.
	destDir := t.TempDir()

	// downloadFile is given relative paths, so run the test from destDir and
	// restore the original working directory afterwards. Cleanups run in LIFO
	// order, so the directory is left before TempDir removes it.
	origDir, err := os.Getwd()
	require.NoError(t, err)
	require.NoError(t, os.Chdir(destDir))
	t.Cleanup(func() {
		if err := os.Chdir(origDir); err != nil {
			t.Errorf("restoring working directory: %v", err)
		}
	})

	// Parse the manifest to get file entries.
	manifest, err := mfer.NewManifestFromReader(bytes.NewReader(manifestData))
	require.NoError(t, err)
	files := manifest.Files()
	require.Len(t, files, len(testFiles))

	// Drain progress updates in the background. The channel is closed in a
	// defer so the drainer also terminates when a require below aborts the
	// test early, instead of leaking.
	progress := make(chan DownloadProgress, 10)
	drained := make(chan struct{})
	go func() {
		defer close(drained)
		for range progress {
			// Discard progress updates.
		}
	}()
	defer func() {
		close(progress)
		<-drained
	}()

	// Download each file using downloadFile.
	baseURL := server.URL + "/"
	for _, f := range files {
		localPath, err := sanitizePath(f.Path)
		require.NoError(t, err)
		fileURL := baseURL + f.Path
		err = downloadFile(fileURL, localPath, f, progress)
		require.NoError(t, err, "failed to download %s", f.Path)
	}

	// Verify downloaded files match originals.
	for path, expectedContent := range testFiles {
		downloadedPath := filepath.Join(destDir, path)
		downloadedContent, err := os.ReadFile(downloadedPath)
		require.NoError(t, err, "failed to read downloaded file %s", path)
		assert.Equal(t, expectedContent, downloadedContent, "content mismatch for %s", path)
	}
}
// TestFetchHashMismatch generates a manifest for one file, serves different
// bytes for it over HTTP, and checks that downloadFile returns an error
// containing "mismatch" and leaves neither the temp file nor the final file on
// disk.
func TestFetchHashMismatch(t *testing.T) {
	// Create source filesystem with a test file.
	sourceFs := afero.NewMemMapFs()
	originalContent := []byte("Original content")
	require.NoError(t, afero.WriteFile(sourceFs, "/file.txt", originalContent, 0o644))

	// Generate manifest.
	s := mfer.NewScannerWithOptions(&mfer.ScannerOptions{Fs: sourceFs})
	require.NoError(t, s.EnumerateFS(sourceFs, "/", nil))
	var manifestBuf bytes.Buffer
	require.NoError(t, s.ToManifest(context.Background(), &manifestBuf, nil))

	// Parse manifest.
	manifest, err := mfer.NewManifestFromReader(bytes.NewReader(manifestBuf.Bytes()))
	require.NoError(t, err)
	files := manifest.Files()
	require.Len(t, files, 1)

	// Server that returns DIFFERENT content for every request, so the
	// downloaded bytes cannot match what the manifest recorded.
	tamperedContent := []byte("Tampered content!")
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/octet-stream")
		_, _ = w.Write(tamperedContent) // write errors are irrelevant to the test server
	}))
	defer server.Close()

	// Run from a fresh temp directory (removed automatically by t.TempDir) and
	// restore the original working directory before it is removed.
	destDir := t.TempDir()
	origDir, err := os.Getwd()
	require.NoError(t, err)
	require.NoError(t, os.Chdir(destDir))
	t.Cleanup(func() {
		if err := os.Chdir(origDir); err != nil {
			t.Errorf("restoring working directory: %v", err)
		}
	})

	// Try to download - should fail with hash mismatch. require (not assert)
	// stops the test here on a nil error, so err.Error() below cannot panic.
	err = downloadFile(server.URL+"/file.txt", "file.txt", files[0], nil)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "mismatch")

	// Verify temp file was cleaned up.
	_, err = os.Stat(".file.txt.tmp")
	assert.True(t, os.IsNotExist(err), "temp file should be cleaned up on hash mismatch")

	// Verify final file was not created.
	_, err = os.Stat("file.txt")
	assert.True(t, os.IsNotExist(err), "final file should not exist on hash mismatch")
}
// TestFetchSizeMismatch generates a manifest for one file, serves a shorter
// body for it over HTTP, and checks that downloadFile returns an error
// containing "size mismatch" and removes its temp file.
func TestFetchSizeMismatch(t *testing.T) {
	// Create source filesystem with a test file.
	sourceFs := afero.NewMemMapFs()
	originalContent := []byte("Original content with specific size")
	require.NoError(t, afero.WriteFile(sourceFs, "/file.txt", originalContent, 0o644))

	// Generate manifest.
	s := mfer.NewScannerWithOptions(&mfer.ScannerOptions{Fs: sourceFs})
	require.NoError(t, s.EnumerateFS(sourceFs, "/", nil))
	var manifestBuf bytes.Buffer
	require.NoError(t, s.ToManifest(context.Background(), &manifestBuf, nil))

	// Parse manifest.
	manifest, err := mfer.NewManifestFromReader(bytes.NewReader(manifestBuf.Bytes()))
	require.NoError(t, err)
	files := manifest.Files()
	require.Len(t, files, 1)

	// Server that returns a body shorter than the file the manifest describes.
	wrongSizeContent := []byte("Short")
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/octet-stream")
		_, _ = w.Write(wrongSizeContent) // write errors are irrelevant to the test server
	}))
	defer server.Close()

	// Run from a fresh temp directory (removed automatically by t.TempDir) and
	// restore the original working directory before it is removed.
	destDir := t.TempDir()
	origDir, err := os.Getwd()
	require.NoError(t, err)
	require.NoError(t, os.Chdir(destDir))
	t.Cleanup(func() {
		if err := os.Chdir(origDir); err != nil {
			t.Errorf("restoring working directory: %v", err)
		}
	})

	// Try to download - should fail with size mismatch. require (not assert)
	// stops the test here on a nil error, so err.Error() below cannot panic.
	err = downloadFile(server.URL+"/file.txt", "file.txt", files[0], nil)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "size mismatch")

	// Verify temp file was cleaned up.
	_, err = os.Stat(".file.txt.tmp")
	assert.True(t, os.IsNotExist(err), "temp file should be cleaned up on size mismatch")
}
// TestFetchProgress downloads a 100 KiB file with downloadFile while
// collecting everything sent on the progress channel, then checks that at
// least one update arrived, that the last update reports the full byte count
// and the expected path, and that the file on disk equals the served content.
func TestFetchProgress(t *testing.T) {
	// Source filesystem holding one 100 KiB file.
	const payloadSize = 100 * 1024
	payload := bytes.Repeat([]byte("x"), payloadSize)
	srcFs := afero.NewMemMapFs()
	require.NoError(t, afero.WriteFile(srcFs, "/large.txt", payload, 0o644))

	// Build the manifest and parse it back to obtain the file entry.
	scanner := mfer.NewScannerWithOptions(&mfer.ScannerOptions{Fs: srcFs})
	require.NoError(t, scanner.EnumerateFS(srcFs, "/", nil))
	var encoded bytes.Buffer
	require.NoError(t, scanner.ToManifest(context.Background(), &encoded, nil))

	parsed, err := mfer.NewManifestFromReader(bytes.NewReader(encoded.Bytes()))
	require.NoError(t, err)
	entries := parsed.Files()
	require.Len(t, entries, 1)

	// Server streaming the payload with an explicit Content-Length.
	handler := func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		hdr.Set("Content-Type", "application/octet-stream")
		hdr.Set("Content-Length", "102400")
		_, _ = io.Copy(w, bytes.NewReader(payload))
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	// Work inside a scratch directory, restoring the previous one on exit.
	workDir, err := os.MkdirTemp("", "mfer-fetch-progress-test-*")
	require.NoError(t, err)
	defer func() { _ = os.RemoveAll(workDir) }()
	prevDir, err := os.Getwd()
	require.NoError(t, err)
	require.NoError(t, os.Chdir(workDir))
	defer func() { _ = os.Chdir(prevDir) }()

	// Collect every progress update until the channel is closed.
	updatesCh := make(chan DownloadProgress, 100)
	collected := make(chan struct{})
	var updates []DownloadProgress
	go func() {
		defer close(collected)
		for u := range updatesCh {
			updates = append(updates, u)
		}
	}()

	// Download, then stop the collector before looking at the result.
	err = downloadFile(srv.URL+"/large.txt", "large.txt", entries[0], updatesCh)
	close(updatesCh)
	<-collected
	require.NoError(t, err)

	// At least one update must have arrived, and the last one must be final.
	assert.NotEmpty(t, updates, "should have received progress updates")
	if n := len(updates); n > 0 {
		final := updates[n-1]
		assert.Equal(t, int64(len(payload)), final.BytesRead, "final progress should show all bytes read")
		assert.Equal(t, "large.txt", final.Path)
	}

	// The file on disk must match what was served.
	onDisk, err := os.ReadFile("large.txt")
	require.NoError(t, err)
	assert.Equal(t, payload, onDisk)
}