File paths containing spaces, #, ?, %, or other URL-reserved characters were concatenated directly into URLs without encoding, producing malformed download URLs. Add encodeFilePath(), which percent-encodes each path segment individually (preserving the "/" directory separators), and use it when fetching files.
375 lines
9.2 KiB
Go
375 lines
9.2 KiB
Go
package cli
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/sha256"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/dustin/go-humanize"
|
|
"github.com/multiformats/go-multihash"
|
|
"github.com/urfave/cli/v2"
|
|
"sneak.berlin/go/mfer/internal/log"
|
|
"sneak.berlin/go/mfer/mfer"
|
|
)
|
|
|
|
// DownloadProgress is a point-in-time snapshot of a single file download,
// passed from the downloader to the progress reporter.
type DownloadProgress struct {
	// Path is the file currently being downloaded.
	Path string
	// BytesRead counts the bytes received so far.
	BytesRead int64
	// TotalBytes is the expected size in bytes, or -1 when unknown.
	TotalBytes int64
	// BytesPerSec is the transfer rate measured since the download started.
	BytesPerSec float64
	// ETA estimates how long remains until the download completes.
	ETA time.Duration
}
|
|
|
|
func (mfa *CLIApp) fetchManifestOperation(ctx *cli.Context) error {
|
|
log.Debug("fetchManifestOperation()")
|
|
|
|
if ctx.Args().Len() == 0 {
|
|
return fmt.Errorf("URL argument required")
|
|
}
|
|
|
|
inputURL := ctx.Args().Get(0)
|
|
manifestURL, err := resolveManifestURL(inputURL)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid URL: %w", err)
|
|
}
|
|
|
|
log.Infof("fetching manifest from %s", manifestURL)
|
|
|
|
// Fetch manifest
|
|
resp, err := http.Get(manifestURL)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to fetch manifest: %w", err)
|
|
}
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return fmt.Errorf("failed to fetch manifest: HTTP %d", resp.StatusCode)
|
|
}
|
|
|
|
// Parse manifest
|
|
manifest, err := mfer.NewManifestFromReader(resp.Body)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to parse manifest: %w", err)
|
|
}
|
|
|
|
files := manifest.Files()
|
|
log.Infof("manifest contains %d files", len(files))
|
|
|
|
// Compute base URL (directory containing manifest)
|
|
baseURL, err := url.Parse(manifestURL)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
baseURL.Path = path.Dir(baseURL.Path)
|
|
if !strings.HasSuffix(baseURL.Path, "/") {
|
|
baseURL.Path += "/"
|
|
}
|
|
|
|
// Calculate total bytes to download
|
|
var totalBytes int64
|
|
for _, f := range files {
|
|
totalBytes += f.Size
|
|
}
|
|
|
|
// Create progress channel
|
|
progress := make(chan DownloadProgress, 10)
|
|
|
|
// Start progress reporter goroutine
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
for p := range progress {
|
|
rate := formatBitrate(p.BytesPerSec * 8)
|
|
if p.ETA > 0 {
|
|
log.Infof("%s: %s/%s, %s, ETA %s",
|
|
p.Path, humanize.IBytes(uint64(p.BytesRead)), humanize.IBytes(uint64(p.TotalBytes)),
|
|
rate, p.ETA.Round(time.Second))
|
|
} else {
|
|
log.Infof("%s: %s/%s, %s",
|
|
p.Path, humanize.IBytes(uint64(p.BytesRead)), humanize.IBytes(uint64(p.TotalBytes)), rate)
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Track download start time
|
|
startTime := time.Now()
|
|
|
|
// Download each file
|
|
for _, f := range files {
|
|
// Sanitize the path to prevent path traversal attacks
|
|
localPath, err := sanitizePath(f.Path)
|
|
if err != nil {
|
|
close(progress)
|
|
<-done
|
|
return fmt.Errorf("invalid path in manifest: %w", err)
|
|
}
|
|
|
|
fileURL := baseURL.String() + encodeFilePath(f.Path)
|
|
log.Infof("fetching %s", f.Path)
|
|
|
|
if err := downloadFile(fileURL, localPath, f, progress); err != nil {
|
|
close(progress)
|
|
<-done
|
|
return fmt.Errorf("failed to download %s: %w", f.Path, err)
|
|
}
|
|
}
|
|
|
|
close(progress)
|
|
<-done
|
|
|
|
// Print summary
|
|
elapsed := time.Since(startTime)
|
|
avgBytesPerSec := float64(totalBytes) / elapsed.Seconds()
|
|
avgRate := formatBitrate(avgBytesPerSec * 8)
|
|
log.Infof("downloaded %d files (%s) in %.1fs (%s avg)",
|
|
len(files),
|
|
humanize.IBytes(uint64(totalBytes)),
|
|
elapsed.Seconds(),
|
|
avgRate)
|
|
|
|
return nil
|
|
}
|
|
|
|
// encodeFilePath turns a slash-separated manifest path into its escaped URL
// path form. Each segment between '/' separators goes through url.PathEscape,
// so characters such as spaces, '#', '?' and '%' are percent-encoded. The
// separators themselves are copied through literally.
func encodeFilePath(p string) string {
	var b strings.Builder
	for idx, segment := range strings.Split(p, "/") {
		if idx > 0 {
			b.WriteByte('/')
		}
		b.WriteString(url.PathEscape(segment))
	}
	return b.String()
}
|
|
|
|
// sanitizePath validates a file path taken from a manifest and returns the
// cleaned, relative local path to write it to.
//
// Manifest contents are untrusted, so this rejects paths that:
//   - are empty,
//   - are absolute (before or after cleaning),
//   - clean to the current directory itself ("." , "./", "a/.."), which names
//     no file and would make the caller try to replace the directory,
//   - escape the current directory via "..".
func sanitizePath(p string) (string, error) {
	// Reject empty paths
	if p == "" {
		return "", fmt.Errorf("empty path")
	}

	// Reject absolute paths
	if filepath.IsAbs(p) {
		return "", fmt.Errorf("absolute path not allowed: %s", p)
	}

	// Clean the path to resolve . and ..
	cleaned := filepath.Clean(p)

	// A path that collapses to "." does not name a file.
	if cleaned == "." {
		return "", fmt.Errorf("path does not name a file: %s", p)
	}

	// Reject paths that escape the current directory
	if cleaned == ".." || strings.HasPrefix(cleaned, ".."+string(filepath.Separator)) {
		return "", fmt.Errorf("path traversal not allowed: %s", p)
	}

	// Also check for absolute paths after cleaning (handles edge cases)
	if filepath.IsAbs(cleaned) {
		return "", fmt.Errorf("absolute path not allowed: %s", p)
	}

	return cleaned, nil
}
|
|
|
|
// resolveManifestURL maps a user-supplied URL to the URL of a manifest file.
//
// A URL whose path already ends in ".mf" names a manifest and is returned
// exactly as given. Any other URL is treated as a directory, and "index.mf"
// inside that directory is returned. The query string, if any, is kept.
// An error is returned only if inputURL cannot be parsed.
func resolveManifestURL(inputURL string) (string, error) {
	u, err := url.Parse(inputURL)
	if err != nil {
		return "", err
	}

	if strings.HasSuffix(u.Path, ".mf") {
		return inputURL, nil
	}

	dir := u.Path
	if !strings.HasSuffix(dir, "/") {
		dir += "/"
	}
	u.Path = dir + "index.mf"

	return u.String(), nil
}
|
|
|
|
// progressWriter wraps an io.Writer and reports progress to a channel.
|
|
type progressWriter struct {
|
|
w io.Writer
|
|
path string
|
|
total int64
|
|
written int64
|
|
startTime time.Time
|
|
progress chan<- DownloadProgress
|
|
}
|
|
|
|
func (pw *progressWriter) Write(p []byte) (int, error) {
|
|
n, err := pw.w.Write(p)
|
|
pw.written += int64(n)
|
|
if pw.progress != nil {
|
|
var bytesPerSec float64
|
|
var eta time.Duration
|
|
elapsed := time.Since(pw.startTime)
|
|
if elapsed > 0 && pw.written > 0 {
|
|
bytesPerSec = float64(pw.written) / elapsed.Seconds()
|
|
if bytesPerSec > 0 && pw.total > 0 {
|
|
remainingBytes := pw.total - pw.written
|
|
eta = time.Duration(float64(remainingBytes)/bytesPerSec) * time.Second
|
|
}
|
|
}
|
|
sendProgress(pw.progress, DownloadProgress{
|
|
Path: pw.path,
|
|
BytesRead: pw.written,
|
|
TotalBytes: pw.total,
|
|
BytesPerSec: bytesPerSec,
|
|
ETA: eta,
|
|
})
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
// formatBitrate renders a rate given in bits per second using the largest
// decimal (SI) unit it reaches: Gbps, Mbps or Kbps with one decimal place.
// Rates below 1000 are shown as whole bps.
func formatBitrate(bps float64) string {
	if bps >= 1e9 {
		return fmt.Sprintf("%.1f Gbps", bps/1e9)
	}
	if bps >= 1e6 {
		return fmt.Sprintf("%.1f Mbps", bps/1e6)
	}
	if bps >= 1e3 {
		return fmt.Sprintf("%.1f Kbps", bps/1e3)
	}
	return fmt.Sprintf("%.0f bps", bps)
}
|
|
|
|
// sendProgress sends a progress update without blocking.
|
|
func sendProgress(ch chan<- DownloadProgress, p DownloadProgress) {
|
|
select {
|
|
case ch <- p:
|
|
default:
|
|
}
|
|
}
|
|
|
|
// downloadFile downloads a URL to a local file path with hash verification.
|
|
// It downloads to a temporary file, verifies the hash, then renames to the final path.
|
|
// Progress is reported via the progress channel.
|
|
func downloadFile(fileURL, localPath string, entry *mfer.MFFilePath, progress chan<- DownloadProgress) error {
|
|
// Create parent directories if needed
|
|
dir := filepath.Dir(localPath)
|
|
if dir != "" && dir != "." {
|
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Compute temp file path in the same directory
|
|
// For dotfiles, just append .tmp (they're already hidden)
|
|
// For regular files, prefix with . and append .tmp
|
|
base := filepath.Base(localPath)
|
|
var tmpName string
|
|
if strings.HasPrefix(base, ".") {
|
|
tmpName = base + ".tmp"
|
|
} else {
|
|
tmpName = "." + base + ".tmp"
|
|
}
|
|
tmpPath := filepath.Join(dir, tmpName)
|
|
if dir == "" || dir == "." {
|
|
tmpPath = tmpName
|
|
}
|
|
|
|
// Fetch file
|
|
resp, err := http.Get(fileURL)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return fmt.Errorf("HTTP %d", resp.StatusCode)
|
|
}
|
|
|
|
// Determine expected size
|
|
expectedSize := entry.Size
|
|
totalBytes := resp.ContentLength
|
|
if totalBytes < 0 {
|
|
totalBytes = expectedSize
|
|
}
|
|
|
|
// Create temp file
|
|
out, err := os.Create(tmpPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Set up hash computation
|
|
h := sha256.New()
|
|
|
|
// Create progress-reporting writer that also computes hash
|
|
pw := &progressWriter{
|
|
w: io.MultiWriter(out, h),
|
|
path: localPath,
|
|
total: totalBytes,
|
|
startTime: time.Now(),
|
|
progress: progress,
|
|
}
|
|
|
|
// Copy content while hashing and reporting progress
|
|
written, copyErr := io.Copy(pw, resp.Body)
|
|
|
|
// Close file before checking errors (to flush writes)
|
|
closeErr := out.Close()
|
|
|
|
// If copy failed, clean up temp file and return error
|
|
if copyErr != nil {
|
|
_ = os.Remove(tmpPath)
|
|
return copyErr
|
|
}
|
|
if closeErr != nil {
|
|
_ = os.Remove(tmpPath)
|
|
return closeErr
|
|
}
|
|
|
|
// Verify size
|
|
if written != expectedSize {
|
|
_ = os.Remove(tmpPath)
|
|
return fmt.Errorf("size mismatch: expected %d bytes, got %d", expectedSize, written)
|
|
}
|
|
|
|
// Encode computed hash as multihash
|
|
computed, err := multihash.Encode(h.Sum(nil), multihash.SHA2_256)
|
|
if err != nil {
|
|
_ = os.Remove(tmpPath)
|
|
return fmt.Errorf("failed to encode hash: %w", err)
|
|
}
|
|
|
|
// Verify hash against manifest (at least one must match)
|
|
hashMatch := false
|
|
for _, hash := range entry.Hashes {
|
|
if bytes.Equal(computed, hash.MultiHash) {
|
|
hashMatch = true
|
|
break
|
|
}
|
|
}
|
|
if !hashMatch {
|
|
_ = os.Remove(tmpPath)
|
|
return fmt.Errorf("hash mismatch")
|
|
}
|
|
|
|
// Rename temp file to final path
|
|
if err := os.Rename(tmpPath, localPath); err != nil {
|
|
_ = os.Remove(tmpPath)
|
|
return fmt.Errorf("failed to rename temp file: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|