- Atomic writes for mfer gen: writes to a temp file, renames on success, cleans up the temp file on error/interrupt. Prevents empty manifests on Ctrl-C. - Humanized byte sizes using dustin/go-humanize (e.g., "10 MiB" not "10485760") - Progress lines clear when done (using ANSI escape \r\033[K) - Debug logging when files are added to manifest (mfer gen -vv) - Move -v/-q flags from global to per-command for better UX - Add tests for atomic write behavior with failing filesystem mock
366 lines
8.9 KiB
Go
366 lines
8.9 KiB
Go
package cli
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/sha256"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/dustin/go-humanize"
|
|
"github.com/multiformats/go-multihash"
|
|
"github.com/urfave/cli/v2"
|
|
"sneak.berlin/go/mfer/internal/log"
|
|
"sneak.berlin/go/mfer/mfer"
|
|
)
|
|
|
|
// DownloadProgress is a point-in-time snapshot of a single file download.
// A new snapshot is produced each time another chunk of the file is written.
type DownloadProgress struct {
	// Path identifies the file being downloaded.
	Path string

	// BytesRead counts the bytes received so far.
	BytesRead int64

	// TotalBytes is the size the download is expected to reach, or -1 when
	// that size is unknown.
	TotalBytes int64

	// BytesPerSec is the observed transfer rate.
	BytesPerSec float64

	// ETA estimates how long remains until the download completes.
	ETA time.Duration
}
|
|
|
|
func (mfa *CLIApp) fetchManifestOperation(ctx *cli.Context) error {
|
|
log.Debug("fetchManifestOperation()")
|
|
|
|
if ctx.Args().Len() == 0 {
|
|
return fmt.Errorf("URL argument required")
|
|
}
|
|
|
|
inputURL := ctx.Args().Get(0)
|
|
manifestURL, err := resolveManifestURL(inputURL)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid URL: %w", err)
|
|
}
|
|
|
|
log.Infof("fetching manifest from %s", manifestURL)
|
|
|
|
// Fetch manifest
|
|
resp, err := http.Get(manifestURL)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to fetch manifest: %w", err)
|
|
}
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return fmt.Errorf("failed to fetch manifest: HTTP %d", resp.StatusCode)
|
|
}
|
|
|
|
// Parse manifest
|
|
manifest, err := mfer.NewManifestFromReader(resp.Body)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to parse manifest: %w", err)
|
|
}
|
|
|
|
files := manifest.Files()
|
|
log.Infof("manifest contains %d files", len(files))
|
|
|
|
// Compute base URL (directory containing manifest)
|
|
baseURL, err := url.Parse(manifestURL)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
baseURL.Path = path.Dir(baseURL.Path)
|
|
if !strings.HasSuffix(baseURL.Path, "/") {
|
|
baseURL.Path += "/"
|
|
}
|
|
|
|
// Calculate total bytes to download
|
|
var totalBytes int64
|
|
for _, f := range files {
|
|
totalBytes += f.Size
|
|
}
|
|
|
|
// Create progress channel
|
|
progress := make(chan DownloadProgress, 10)
|
|
|
|
// Start progress reporter goroutine
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
for p := range progress {
|
|
rate := formatBitrate(p.BytesPerSec * 8)
|
|
if p.ETA > 0 {
|
|
log.Infof("%s: %s/%s, %s, ETA %s",
|
|
p.Path, humanize.IBytes(uint64(p.BytesRead)), humanize.IBytes(uint64(p.TotalBytes)),
|
|
rate, p.ETA.Round(time.Second))
|
|
} else {
|
|
log.Infof("%s: %s/%s, %s",
|
|
p.Path, humanize.IBytes(uint64(p.BytesRead)), humanize.IBytes(uint64(p.TotalBytes)), rate)
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Track download start time
|
|
startTime := time.Now()
|
|
|
|
// Download each file
|
|
for _, f := range files {
|
|
// Sanitize the path to prevent path traversal attacks
|
|
localPath, err := sanitizePath(f.Path)
|
|
if err != nil {
|
|
close(progress)
|
|
<-done
|
|
return fmt.Errorf("invalid path in manifest: %w", err)
|
|
}
|
|
|
|
fileURL := baseURL.String() + f.Path
|
|
log.Infof("fetching %s", f.Path)
|
|
|
|
if err := downloadFile(fileURL, localPath, f, progress); err != nil {
|
|
close(progress)
|
|
<-done
|
|
return fmt.Errorf("failed to download %s: %w", f.Path, err)
|
|
}
|
|
}
|
|
|
|
close(progress)
|
|
<-done
|
|
|
|
// Print summary
|
|
elapsed := time.Since(startTime)
|
|
avgBytesPerSec := float64(totalBytes) / elapsed.Seconds()
|
|
avgRate := formatBitrate(avgBytesPerSec * 8)
|
|
log.Infof("downloaded %d files (%s) in %.1fs (%s avg)",
|
|
len(files),
|
|
humanize.IBytes(uint64(totalBytes)),
|
|
elapsed.Seconds(),
|
|
avgRate)
|
|
|
|
return nil
|
|
}
|
|
|
|
// sanitizePath validates a file path taken from a (possibly untrusted)
// manifest. On success it returns the cleaned, OS-native form, which is safe
// to use as a path relative to the current directory.
//
// It rejects:
//   - empty paths;
//   - paths that clean to "." (they name the directory itself, not a file);
//   - paths that escape the current directory through "..";
//   - absolute or rooted paths. This includes Windows "\foo" (and "/foo"),
//     which filepath.IsAbs does not report, and volume-qualified paths such
//     as "C:foo".
func sanitizePath(p string) (string, error) {
	if p == "" {
		return "", fmt.Errorf("empty path")
	}

	if filepath.IsAbs(p) {
		return "", fmt.Errorf("absolute path not allowed: %s", p)
	}

	// Clean resolves "." and ".." elements, and on Windows it also converts
	// "/" to the native separator. The checks below therefore only need to
	// handle one form.
	cleaned := filepath.Clean(p)

	if cleaned == "." {
		return "", fmt.Errorf("path does not name a file: %s", p)
	}

	if cleaned == ".." || strings.HasPrefix(cleaned, ".."+string(filepath.Separator)) {
		return "", fmt.Errorf("path traversal not allowed: %s", p)
	}

	// On Windows, IsAbs requires a volume name. As a result, rooted paths
	// ("\etc\passwd") and drive-relative paths ("C:foo") would pass it and
	// resolve outside the working directory, so reject them explicitly.
	if filepath.IsAbs(cleaned) ||
		filepath.VolumeName(cleaned) != "" ||
		strings.HasPrefix(cleaned, string(filepath.Separator)) {
		return "", fmt.Errorf("absolute path not allowed: %s", p)
	}

	return cleaned, nil
}
|
|
|
|
// resolveManifestURL maps user input to the URL of a manifest.
//
// Input whose path already ends in ".mf" is taken to be the manifest itself
// and is returned unchanged. Any other input is treated as a directory:
// "index.mf" is appended to its path, with a "/" inserted first when needed.
// Query and fragment components are carried over. A parse error from url.Parse
// is returned as-is.
func resolveManifestURL(inputURL string) (string, error) {
	u, err := url.Parse(inputURL)
	if err != nil {
		return "", err
	}

	if strings.HasSuffix(u.Path, ".mf") {
		return inputURL, nil
	}

	dirPath := u.Path
	if !strings.HasSuffix(dirPath, "/") {
		dirPath += "/"
	}
	u.Path = dirPath + "index.mf"

	return u.String(), nil
}
|
|
|
|
// progressWriter wraps an io.Writer and reports progress to a channel.
|
|
type progressWriter struct {
|
|
w io.Writer
|
|
path string
|
|
total int64
|
|
written int64
|
|
startTime time.Time
|
|
progress chan<- DownloadProgress
|
|
}
|
|
|
|
func (pw *progressWriter) Write(p []byte) (int, error) {
|
|
n, err := pw.w.Write(p)
|
|
pw.written += int64(n)
|
|
if pw.progress != nil {
|
|
var bytesPerSec float64
|
|
var eta time.Duration
|
|
elapsed := time.Since(pw.startTime)
|
|
if elapsed > 0 && pw.written > 0 {
|
|
bytesPerSec = float64(pw.written) / elapsed.Seconds()
|
|
if bytesPerSec > 0 && pw.total > 0 {
|
|
remainingBytes := pw.total - pw.written
|
|
eta = time.Duration(float64(remainingBytes)/bytesPerSec) * time.Second
|
|
}
|
|
}
|
|
sendProgress(pw.progress, DownloadProgress{
|
|
Path: pw.path,
|
|
BytesRead: pw.written,
|
|
TotalBytes: pw.total,
|
|
BytesPerSec: bytesPerSec,
|
|
ETA: eta,
|
|
})
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
// formatBitrate renders a rate given in bits per second using the largest
// decimal (SI) unit it reaches. Gbps, Mbps and Kbps are shown with one decimal
// place; values below 1000 are shown as whole bps.
func formatBitrate(bps float64) string {
	scales := [...]struct {
		threshold float64
		format    string
	}{
		{1e9, "%.1f Gbps"},
		{1e6, "%.1f Mbps"},
		{1e3, "%.1f Kbps"},
	}

	for _, s := range scales {
		if bps >= s.threshold {
			return fmt.Sprintf(s.format, bps/s.threshold)
		}
	}
	return fmt.Sprintf("%.0f bps", bps)
}
|
|
|
|
// sendProgress sends a progress update without blocking.
|
|
func sendProgress(ch chan<- DownloadProgress, p DownloadProgress) {
|
|
select {
|
|
case ch <- p:
|
|
default:
|
|
}
|
|
}
|
|
|
|
// downloadFile downloads a URL to a local file path with hash verification.
|
|
// It downloads to a temporary file, verifies the hash, then renames to the final path.
|
|
// Progress is reported via the progress channel.
|
|
func downloadFile(fileURL, localPath string, entry *mfer.MFFilePath, progress chan<- DownloadProgress) error {
|
|
// Create parent directories if needed
|
|
dir := filepath.Dir(localPath)
|
|
if dir != "" && dir != "." {
|
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Compute temp file path in the same directory
|
|
// For dotfiles, just append .tmp (they're already hidden)
|
|
// For regular files, prefix with . and append .tmp
|
|
base := filepath.Base(localPath)
|
|
var tmpName string
|
|
if strings.HasPrefix(base, ".") {
|
|
tmpName = base + ".tmp"
|
|
} else {
|
|
tmpName = "." + base + ".tmp"
|
|
}
|
|
tmpPath := filepath.Join(dir, tmpName)
|
|
if dir == "" || dir == "." {
|
|
tmpPath = tmpName
|
|
}
|
|
|
|
// Fetch file
|
|
resp, err := http.Get(fileURL)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return fmt.Errorf("HTTP %d", resp.StatusCode)
|
|
}
|
|
|
|
// Determine expected size
|
|
expectedSize := entry.Size
|
|
totalBytes := resp.ContentLength
|
|
if totalBytes < 0 {
|
|
totalBytes = expectedSize
|
|
}
|
|
|
|
// Create temp file
|
|
out, err := os.Create(tmpPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Set up hash computation
|
|
h := sha256.New()
|
|
|
|
// Create progress-reporting writer that also computes hash
|
|
pw := &progressWriter{
|
|
w: io.MultiWriter(out, h),
|
|
path: localPath,
|
|
total: totalBytes,
|
|
startTime: time.Now(),
|
|
progress: progress,
|
|
}
|
|
|
|
// Copy content while hashing and reporting progress
|
|
written, copyErr := io.Copy(pw, resp.Body)
|
|
|
|
// Close file before checking errors (to flush writes)
|
|
closeErr := out.Close()
|
|
|
|
// If copy failed, clean up temp file and return error
|
|
if copyErr != nil {
|
|
_ = os.Remove(tmpPath)
|
|
return copyErr
|
|
}
|
|
if closeErr != nil {
|
|
_ = os.Remove(tmpPath)
|
|
return closeErr
|
|
}
|
|
|
|
// Verify size
|
|
if written != expectedSize {
|
|
_ = os.Remove(tmpPath)
|
|
return fmt.Errorf("size mismatch: expected %d bytes, got %d", expectedSize, written)
|
|
}
|
|
|
|
// Encode computed hash as multihash
|
|
computed, err := multihash.Encode(h.Sum(nil), multihash.SHA2_256)
|
|
if err != nil {
|
|
_ = os.Remove(tmpPath)
|
|
return fmt.Errorf("failed to encode hash: %w", err)
|
|
}
|
|
|
|
// Verify hash against manifest (at least one must match)
|
|
hashMatch := false
|
|
for _, hash := range entry.Hashes {
|
|
if bytes.Equal(computed, hash.MultiHash) {
|
|
hashMatch = true
|
|
break
|
|
}
|
|
}
|
|
if !hashMatch {
|
|
_ = os.Remove(tmpPath)
|
|
return fmt.Errorf("hash mismatch")
|
|
}
|
|
|
|
// Rename temp file to final path
|
|
if err := os.Rename(tmpPath, localPath); err != nil {
|
|
_ = os.Remove(tmpPath)
|
|
return fmt.Errorf("failed to rename temp file: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|