package cli

import (
	"bytes"
	"crypto/sha256"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"path"
	"path/filepath"
	"strings"
	"time"

	"github.com/multiformats/go-multihash"
	"github.com/urfave/cli/v2"
	"sneak.berlin/go/mfer/internal/log"
	"sneak.berlin/go/mfer/mfer"
)

// DownloadProgress reports the progress of a single file download.
type DownloadProgress struct {
	Path        string        // File path being downloaded
	BytesRead   int64         // Bytes downloaded so far
	TotalBytes  int64         // Total expected bytes (Content-Length, or the manifest size when the server sends none)
	BytesPerSec float64       // Average download rate since the file's download started
	ETA         time.Duration // Estimated time to completion (zero when it cannot be estimated)
}

// fetchManifestOperation implements the manifest fetch command: it resolves
// the URL given as the first CLI argument to a manifest URL, downloads and
// parses the manifest, then downloads every file it lists relative to the
// directory containing the manifest, verifying size and hash of each.
//
// It returns an error if the argument is missing, the manifest cannot be
// fetched or parsed, a manifest path is unsafe, or any file download fails.
// Files are downloaded sequentially and the first failure aborts the run.
func (mfa *CLIApp) fetchManifestOperation(ctx *cli.Context) error {
	log.Debug("fetchManifestOperation()")

	if ctx.Args().Len() == 0 {
		return fmt.Errorf("URL argument required")
	}

	inputURL := ctx.Args().Get(0)
	manifestURL, err := resolveManifestURL(inputURL)
	if err != nil {
		return fmt.Errorf("invalid URL: %w", err)
	}

	log.Infof("fetching manifest from %s", manifestURL)

	// Fetch manifest
	resp, err := http.Get(manifestURL)
	if err != nil {
		return fmt.Errorf("failed to fetch manifest: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("failed to fetch manifest: HTTP %d", resp.StatusCode)
	}

	// Parse manifest
	manifest, err := mfer.NewManifestFromReader(resp.Body)
	if err != nil {
		return fmt.Errorf("failed to parse manifest: %w", err)
	}

	files := manifest.Files()
	log.Infof("manifest contains %d files", len(files))

	// Compute base URL (directory containing manifest)
	baseURL, err := url.Parse(manifestURL)
	if err != nil {
		return err
	}
	baseURL.Path = path.Dir(baseURL.Path)
	if !strings.HasSuffix(baseURL.Path, "/") {
		baseURL.Path += "/"
	}

	// Calculate total bytes to download (used for the summary line only)
	var totalBytes int64
	for _, f := range files {
		totalBytes += f.Size
	}

	// Progress updates are sent without blocking (see sendProgress), so the
	// small buffer only smooths bursts; updates are dropped when it is full.
	progress := make(chan DownloadProgress, 10)

	// Reporter goroutine: logs every update until progress is closed, then
	// signals completion by closing done.
	done := make(chan struct{})
	go func() {
		defer close(done)
		for p := range progress {
			rate := formatBitrate(p.BytesPerSec * 8)
			if p.ETA > 0 {
				log.Infof("%s: %d/%d bytes, %s, ETA %s",
					p.Path, p.BytesRead, p.TotalBytes, rate, p.ETA.Round(time.Second))
			} else {
				log.Infof("%s: %d/%d bytes, %s", p.Path, p.BytesRead, p.TotalBytes, rate)
			}
		}
	}()

	// Track download start time
	startTime := time.Now()

	// downloadAll fetches each manifest entry in order and stops at the
	// first failure. It is a closure so the reporter shutdown below happens
	// in exactly one place regardless of how the loop ends.
	downloadAll := func() error {
		for _, f := range files {
			// Sanitize the path to prevent path traversal attacks
			localPath, err := sanitizePath(f.Path)
			if err != nil {
				return fmt.Errorf("invalid path in manifest: %w", err)
			}

			fileURL := manifestEntryURL(baseURL, f.Path)
			log.Infof("fetching %s", f.Path)

			if err := downloadFile(fileURL, localPath, f, progress); err != nil {
				return fmt.Errorf("failed to download %s: %w", f.Path, err)
			}
		}
		return nil
	}

	dlErr := downloadAll()

	// Stop the reporter and wait for it to drain before returning or
	// printing the summary, so progress lines never follow the summary.
	close(progress)
	<-done

	if dlErr != nil {
		return dlErr
	}

	// Print summary if not quiet
	if !ctx.Bool("quiet") {
		elapsed := time.Since(startTime)
		avgBytesPerSec := float64(totalBytes) / elapsed.Seconds()
		avgRate := formatBitrate(avgBytesPerSec * 8)
		log.Infof("downloaded %d files (%.1f MB) in %.1fs (%s avg)",
			len(files), float64(totalBytes)/1e6, elapsed.Seconds(), avgRate)
	}

	return nil
}

// manifestEntryURL returns the URL of a manifest entry located at relPath
// below base. base.Path is expected to end in "/".
//
// relPath is appended to the URL *path* and escaped by net/url, so characters
// such as '#', '?', '%' and spaces in file names are percent-encoded instead
// of being interpreted as URL syntax. The query string and fragment of base
// are not carried over to the result.
func manifestEntryURL(base *url.URL, relPath string) string {
	u := *base // shallow copy; base itself is left untouched
	u.Path = base.Path + relPath
	u.RawPath = "" // force String() to derive the escaped form from Path
	u.RawQuery = ""
	u.ForceQuery = false
	u.Fragment = ""
	u.RawFragment = ""
	return u.String()
}

// sanitizePath validates and sanitizes a file path from the manifest.
// It returns the cleaned path, or an error if the path is empty, absolute,
// escapes the current directory, or resolves to the current directory itself.
func sanitizePath(p string) (string, error) {
	// Reject empty paths
	if p == "" {
		return "", fmt.Errorf("empty path")
	}

	// Reject absolute paths
	if filepath.IsAbs(p) {
		return "", fmt.Errorf("absolute path not allowed: %s", p)
	}

	// Clean the path to resolve . and ..
	cleaned := filepath.Clean(p)

	// A path that cleans to "." (".", "./", "a/..") names the download
	// directory, not a file; it could never be written successfully.
	if cleaned == "." {
		return "", fmt.Errorf("path does not name a file: %s", p)
	}

	// Reject paths that escape the current directory
	if cleaned == ".." || strings.HasPrefix(cleaned, ".."+string(filepath.Separator)) {
		return "", fmt.Errorf("path traversal not allowed: %s", p)
	}

	// Also check for absolute paths after cleaning (handles edge cases)
	if filepath.IsAbs(cleaned) {
		return "", fmt.Errorf("absolute path not allowed: %s", p)
	}

	return cleaned, nil
}

// resolveManifestURL takes a URL and returns the manifest URL.
// If the URL path already ends with .mf, the input is returned as-is.
// Otherwise, index.mf is appended to the path (adding a "/" separator first
// when the path does not end with one).
func resolveManifestURL(inputURL string) (string, error) {
	parsed, err := url.Parse(inputURL)
	if err != nil {
		return "", err
	}

	// Check if URL already ends with .mf
	if strings.HasSuffix(parsed.Path, ".mf") {
		return inputURL, nil
	}

	// Ensure path ends with /
	if !strings.HasSuffix(parsed.Path, "/") {
		parsed.Path += "/"
	}

	// Append index.mf
	parsed.Path += "index.mf"

	return parsed.String(), nil
}

// progressWriter wraps an io.Writer and reports progress to a channel.
type progressWriter struct {
	w         io.Writer               // destination for the data
	path      string                  // path reported in progress updates
	total     int64                   // expected total bytes, used for the ETA
	written   int64                   // bytes written so far
	startTime time.Time               // when the transfer started, used for the rate
	progress  chan<- DownloadProgress // receives updates; nil disables reporting
}

// Write writes p to the wrapped writer, adds the number of bytes actually
// written to the running total and, when a progress channel is set, emits a
// non-blocking progress update. It returns the wrapped writer's result.
func (pw *progressWriter) Write(p []byte) (int, error) {
	n, err := pw.w.Write(p)
	pw.written += int64(n)

	if pw.progress == nil {
		return n, err
	}

	var bytesPerSec float64
	var eta time.Duration
	elapsed := time.Since(pw.startTime)
	if elapsed > 0 && pw.written > 0 {
		bytesPerSec = float64(pw.written) / elapsed.Seconds()
		if bytesPerSec > 0 && pw.total > 0 {
			remainingBytes := pw.total - pw.written
			// The quotient is truncated to whole seconds before scaling.
			eta = time.Duration(float64(remainingBytes)/bytesPerSec) * time.Second
		}
	}

	sendProgress(pw.progress, DownloadProgress{
		Path:        pw.path,
		BytesRead:   pw.written,
		TotalBytes:  pw.total,
		BytesPerSec: bytesPerSec,
		ETA:         eta,
	})

	return n, err
}

// formatBitrate formats a bits-per-second value with appropriate unit prefix.
func formatBitrate(bps float64) string {
	switch {
	case bps >= 1e9:
		return fmt.Sprintf("%.1f Gbps", bps/1e9)
	case bps >= 1e6:
		return fmt.Sprintf("%.1f Mbps", bps/1e6)
	case bps >= 1e3:
		return fmt.Sprintf("%.1f Kbps", bps/1e3)
	default:
		return fmt.Sprintf("%.0f bps", bps)
	}
}

// sendProgress sends a progress update without blocking; the update is
// dropped when the channel cannot accept it immediately.
func sendProgress(ch chan<- DownloadProgress, p DownloadProgress) {
	select {
	case ch <- p:
	default:
	}
}

// downloadFile downloads a URL to a local file path with hash verification.
// It downloads to a temporary file in the destination directory, verifies the
// size and SHA-256 multihash against the manifest entry, then renames the
// temporary file to the final path. The temporary file is removed on every
// failure path. Progress is reported via the progress channel.
//
// The response body is read only up to one byte past the size declared in the
// manifest, so an oversized response is rejected without being written to
// disk in full.
func downloadFile(fileURL, localPath string, entry *mfer.MFFilePath, progress chan<- DownloadProgress) error {
	expectedSize := entry.Size
	if expectedSize < 0 {
		return fmt.Errorf("invalid size in manifest: %d", expectedSize)
	}

	// Compute temp file path in the same directory
	// For dotfiles, just append .tmp (they're already hidden)
	// For regular files, prefix with . and append .tmp
	dir := filepath.Dir(localPath)
	base := filepath.Base(localPath)
	tmpName := "." + base + ".tmp"
	if strings.HasPrefix(base, ".") {
		tmpName = base + ".tmp"
	}
	tmpPath := filepath.Join(dir, tmpName)

	// Fetch file
	resp, err := http.Get(fileURL)
	if err != nil {
		return err
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("HTTP %d", resp.StatusCode)
	}

	// Create parent directories only once the server has agreed to send the
	// file, so failed requests do not leave empty directories behind.
	if dir != "." {
		if err := os.MkdirAll(dir, 0755); err != nil {
			return err
		}
	}

	// Total used for progress/ETA: prefer Content-Length, fall back to the
	// manifest size when the server does not send one.
	totalBytes := resp.ContentLength
	if totalBytes < 0 {
		totalBytes = expectedSize
	}

	// Create temp file
	out, err := os.Create(tmpPath)
	if err != nil {
		return err
	}

	// From here on, any return that has not committed the file removes the
	// temp file. Removal is best-effort; its error is deliberately ignored.
	committed := false
	defer func() {
		if !committed {
			_ = os.Remove(tmpPath)
		}
	}()

	// Set up hash computation
	h := sha256.New()

	// Create progress-reporting writer that also computes hash
	pw := &progressWriter{
		w:         io.MultiWriter(out, h),
		path:      localPath,
		total:     totalBytes,
		startTime: time.Now(),
		progress:  progress,
	}

	// Read at most expectedSize+1 bytes: enough to detect an oversized
	// response, without letting the server fill the disk. The guard avoids
	// overflowing int64 for an absurdly large declared size.
	var body io.Reader = resp.Body
	if expectedSize < 1<<63-1 {
		body = io.LimitReader(resp.Body, expectedSize+1)
	}

	// Copy content while hashing and reporting progress
	written, copyErr := io.Copy(pw, body)

	// Close file before checking errors (to flush writes)
	closeErr := out.Close()

	if copyErr != nil {
		return copyErr
	}
	if closeErr != nil {
		return closeErr
	}

	// Verify size
	switch {
	case written > expectedSize:
		return fmt.Errorf("size mismatch: expected %d bytes, server sent more", expectedSize)
	case written < expectedSize:
		return fmt.Errorf("size mismatch: expected %d bytes, got %d", expectedSize, written)
	}

	// Encode computed hash as multihash
	computed, err := multihash.Encode(h.Sum(nil), multihash.SHA2_256)
	if err != nil {
		return fmt.Errorf("failed to encode hash: %w", err)
	}

	// Verify hash against manifest (at least one must match)
	hashMatch := false
	for _, hash := range entry.Hashes {
		if bytes.Equal(computed, hash.MultiHash) {
			hashMatch = true
			break
		}
	}
	if !hashMatch {
		return fmt.Errorf("hash mismatch")
	}

	// Rename temp file to final path
	if err := os.Rename(tmpPath, localPath); err != nil {
		return fmt.Errorf("failed to rename temp file: %w", err)
	}
	committed = true

	return nil
}