mfer/internal/cli/fetch.go
sneak a5b0343b28 Use Go 1.13+ octal literal syntax throughout codebase
Update file permission literals from legacy octal format (0755, 0644)
to explicit Go 1.13+ format (0o755, 0o644) for improved readability.
2025-12-18 01:29:40 -08:00

366 lines
8.9 KiB
Go

package cli
import (
"bytes"
"crypto/sha256"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"strings"
"time"
"github.com/dustin/go-humanize"
"github.com/multiformats/go-multihash"
"github.com/urfave/cli/v2"
"sneak.berlin/go/mfer/internal/log"
"sneak.berlin/go/mfer/mfer"
)
// DownloadProgress is a point-in-time snapshot of one file's download.
type DownloadProgress struct {
	// Path is the path of the file being downloaded.
	Path string
	// BytesRead is the number of bytes downloaded so far; TotalBytes is the
	// total number of bytes expected (-1 if unknown).
	BytesRead, TotalBytes int64
	// BytesPerSec is the current download rate.
	BytesPerSec float64
	// ETA is the estimated time to completion.
	ETA time.Duration
}
// fetchManifestOperation implements the fetch command. It downloads the
// manifest named by the first CLI argument and then every file listed in it,
// writing each one to the relative local path returned by sanitizePath.
//
// The argument is turned into a manifest URL by resolveManifestURL. File URLs
// are resolved relative to the directory that contains the manifest. Progress
// for each file is logged while it downloads and a summary with the average
// rate is logged at the end. The first failure aborts the run and is returned.
func (mfa *CLIApp) fetchManifestOperation(ctx *cli.Context) error {
	log.Debug("fetchManifestOperation()")

	if ctx.Args().Len() == 0 {
		return fmt.Errorf("URL argument required")
	}

	inputURL := ctx.Args().Get(0)
	manifestURL, err := resolveManifestURL(inputURL)
	if err != nil {
		return fmt.Errorf("invalid URL: %w", err)
	}

	log.Infof("fetching manifest from %s", manifestURL)

	// Fetch manifest
	resp, err := http.Get(manifestURL)
	if err != nil {
		return fmt.Errorf("failed to fetch manifest: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("failed to fetch manifest: HTTP %d", resp.StatusCode)
	}

	// Parse manifest
	manifest, err := mfer.NewManifestFromReader(resp.Body)
	if err != nil {
		return fmt.Errorf("failed to parse manifest: %w", err)
	}

	files := manifest.Files()
	log.Infof("manifest contains %d files", len(files))

	// Compute base URL (directory containing manifest). The directory is
	// derived from the escaped path and written back to both Path and RawPath
	// so that percent-encoded characters in it (e.g. %2F) are preserved
	// instead of being re-encoded from the decoded form.
	baseURL, err := url.Parse(manifestURL)
	if err != nil {
		return err
	}
	escapedDir := path.Dir(baseURL.EscapedPath())
	if !strings.HasSuffix(escapedDir, "/") {
		escapedDir += "/"
	}
	unescapedDir, err := url.PathUnescape(escapedDir)
	if err != nil {
		return err
	}
	baseURL.Path, baseURL.RawPath = unescapedDir, escapedDir

	// Calculate total bytes to download
	var totalBytes int64
	for _, f := range files {
		totalBytes += f.Size
	}

	// Progress updates are sent on this channel by downloadFile and logged by
	// the reporter goroutine below; done is closed once the reporter has
	// drained the channel and exited.
	progress := make(chan DownloadProgress, 10)
	done := make(chan struct{})
	go func() {
		defer close(done)
		for p := range progress {
			rate := formatBitrate(p.BytesPerSec * 8)
			if p.ETA > 0 {
				log.Infof("%s: %s/%s, %s, ETA %s",
					p.Path, humanize.IBytes(uint64(p.BytesRead)), humanize.IBytes(uint64(p.TotalBytes)),
					rate, p.ETA.Round(time.Second))
			} else {
				log.Infof("%s: %s/%s, %s",
					p.Path, humanize.IBytes(uint64(p.BytesRead)), humanize.IBytes(uint64(p.TotalBytes)), rate)
			}
		}
	}()

	// Track download start time
	startTime := time.Now()

	// Download each file, stopping at the first failure. The error is held
	// until the reporter has been shut down so that there is a single
	// shutdown path for the progress channel.
	var downloadErr error
	for _, f := range files {
		// Sanitize the path to prevent path traversal attacks
		localPath, err := sanitizePath(f.Path)
		if err != nil {
			downloadErr = fmt.Errorf("invalid path in manifest: %w", err)
			break
		}

		// Resolve the manifest path as a URL reference rather than
		// concatenating strings: this percent-encodes characters such as
		// '%', '?' and '#' in file names and drops any query string or
		// fragment carried by the manifest URL.
		fileURL := baseURL.ResolveReference(&url.URL{Path: f.Path}).String()

		log.Infof("fetching %s", f.Path)
		if err := downloadFile(fileURL, localPath, f, progress); err != nil {
			downloadErr = fmt.Errorf("failed to download %s: %w", f.Path, err)
			break
		}
	}

	close(progress)
	<-done

	if downloadErr != nil {
		return downloadErr
	}

	// Print summary. Guard the division so a zero elapsed time reports a
	// rate of 0 instead of +Inf/NaN.
	elapsed := time.Since(startTime)
	var avgBytesPerSec float64
	if secs := elapsed.Seconds(); secs > 0 {
		avgBytesPerSec = float64(totalBytes) / secs
	}
	avgRate := formatBitrate(avgBytesPerSec * 8)
	log.Infof("downloaded %d files (%s) in %.1fs (%s avg)",
		len(files),
		humanize.IBytes(uint64(totalBytes)),
		elapsed.Seconds(),
		avgRate)

	return nil
}
// sanitizePath validates a file path taken from the manifest and returns its
// cleaned form (filepath.Clean) for use as a local destination.
//
// It returns an error when the path is empty, is absolute, is rooted or
// carries a volume name (relevant on Windows, where such paths are not
// reported as absolute but still leave the working directory), climbs out of
// the current directory via "..", or cleans to "." and therefore names the
// directory itself rather than a file inside it.
func sanitizePath(p string) (string, error) {
	// Reject empty paths
	if p == "" {
		return "", fmt.Errorf("empty path")
	}

	// Reject absolute paths
	if filepath.IsAbs(p) {
		return "", fmt.Errorf("absolute path not allowed: %s", p)
	}

	// Clean the path to resolve . and ..
	cleaned := filepath.Clean(p)
	sep := string(filepath.Separator)

	switch {
	case cleaned == ".":
		// e.g. ".", "./" or "a/..": there is no file name left to write to.
		return "", fmt.Errorf("path does not name a file: %s", p)
	case cleaned == ".." || strings.HasPrefix(cleaned, ".."+sep):
		// Reject paths that escape the current directory
		return "", fmt.Errorf("path traversal not allowed: %s", p)
	case filepath.IsAbs(cleaned),
		strings.HasPrefix(cleaned, sep),
		filepath.VolumeName(cleaned) != "":
		// Absolute after cleaning, rooted without a drive, or drive-relative.
		return "", fmt.Errorf("absolute path not allowed: %s", p)
	}

	return cleaned, nil
}
// resolveManifestURL takes a URL and returns the manifest URL.
//
// If the URL's path already ends with ".mf", inputURL is returned unchanged.
// Otherwise the URL is treated as a directory and "index.mf" is appended to
// its path (inserting a "/" first when needed); query and fragment are kept.
// An error is returned when inputURL cannot be parsed.
//
// The suffix is appended to the escaped form of the path and written back to
// both Path and RawPath. Appending to Path alone would invalidate RawPath and
// make String() re-encode the decoded path, silently turning an encoded
// reserved character such as "%2F" into a real "/".
func resolveManifestURL(inputURL string) (string, error) {
	parsed, err := url.Parse(inputURL)
	if err != nil {
		return "", err
	}

	// Check if URL already ends with .mf
	if strings.HasSuffix(parsed.Path, ".mf") {
		return inputURL, nil
	}

	// Ensure path ends with /, then append index.mf
	escaped := parsed.EscapedPath()
	if !strings.HasSuffix(escaped, "/") {
		escaped += "/"
	}
	escaped += "index.mf"

	unescaped, err := url.PathUnescape(escaped)
	if err != nil {
		return "", err
	}
	parsed.Path, parsed.RawPath = unescaped, escaped

	return parsed.String(), nil
}
// progressWriter is an io.Writer decorator: bytes are forwarded to w while a
// running byte count is kept and published on the progress channel.
type progressWriter struct {
	// w is the destination that receives the data.
	w io.Writer
	// path is the file path reported in progress updates.
	path string
	// total is the expected byte count; written is the count forwarded so far.
	total, written int64
	// startTime is the reference point for rate and ETA calculations.
	startTime time.Time
	// progress receives updates; when nil, no updates are sent.
	progress chan<- DownloadProgress
}
// Write forwards p to the wrapped writer, adds the number of bytes actually
// written to the running total and, when a progress channel is set, publishes
// an update via sendProgress. The count and error of the underlying write are
// returned unchanged.
func (pw *progressWriter) Write(p []byte) (int, error) {
	n, err := pw.w.Write(p)
	pw.written += int64(n)

	// Nothing to report to.
	if pw.progress == nil {
		return n, err
	}

	update := DownloadProgress{
		Path:       pw.path,
		BytesRead:  pw.written,
		TotalBytes: pw.total,
	}

	// Rate and ETA stay at zero until some time has passed and some data has
	// been written; ETA additionally needs a known, positive total.
	if elapsed := time.Since(pw.startTime); elapsed > 0 && pw.written > 0 {
		update.BytesPerSec = float64(pw.written) / elapsed.Seconds()
		if update.BytesPerSec > 0 && pw.total > 0 {
			remaining := pw.total - pw.written
			update.ETA = time.Duration(float64(remaining)/update.BytesPerSec) * time.Second
		}
	}

	sendProgress(pw.progress, update)
	return n, err
}
// formatBitrate renders a bits-per-second value as a human-readable string,
// scaling to Kbps, Mbps or Gbps (powers of 1000, one decimal place) once the
// value reaches the respective threshold; smaller values are printed as whole
// bps.
func formatBitrate(bps float64) string {
	// Ordered from largest to smallest so the first match wins.
	scales := []struct {
		threshold float64
		format    string
	}{
		{1e9, "%.1f Gbps"},
		{1e6, "%.1f Mbps"},
		{1e3, "%.1f Kbps"},
	}

	for _, s := range scales {
		if bps >= s.threshold {
			return fmt.Sprintf(s.format, bps/s.threshold)
		}
	}
	return fmt.Sprintf("%.0f bps", bps)
}
// sendProgress offers p to ch and returns immediately: if the send cannot
// proceed right away the update is dropped rather than waited for.
func sendProgress(ch chan<- DownloadProgress, p DownloadProgress) {
	select {
	default:
		// Channel not ready to accept the value: drop this update.
	case ch <- p:
		// Delivered.
	}
}
// downloadFile downloads fileURL to localPath and verifies it against entry.
//
// The body is streamed into a temporary file next to localPath (".<name>.tmp",
// or "<name>.tmp" when the name already starts with a dot) while a SHA-256
// digest is computed. The temporary file is renamed to localPath only if the
// byte count equals entry.Size and the multihash-encoded digest equals at
// least one of entry.Hashes; on any failure the temporary file is removed.
// Parent directories of localPath are created as needed. Progress is reported
// via the progress channel.
//
// Errors are returned for directory/file creation failures, transport
// failures, a non-200 status ("HTTP <code>"), a size mismatch, a hash
// mismatch, or a failed rename.
func downloadFile(fileURL, localPath string, entry *mfer.MFFilePath, progress chan<- DownloadProgress) error {
	// Create parent directories if needed
	dir := filepath.Dir(localPath)
	if dir != "" && dir != "." {
		if err := os.MkdirAll(dir, 0o755); err != nil {
			return err
		}
	}

	// Compute temp file path in the same directory
	// For dotfiles, just append .tmp (they're already hidden)
	// For regular files, prefix with . and append .tmp
	base := filepath.Base(localPath)
	tmpName := base + ".tmp"
	if !strings.HasPrefix(base, ".") {
		tmpName = "." + tmpName
	}
	// filepath.Join cleans its result, so a dir of "." yields just tmpName.
	tmpPath := filepath.Join(dir, tmpName)

	// Fetch file
	resp, err := http.Get(fileURL)
	if err != nil {
		return err
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("HTTP %d", resp.StatusCode)
	}

	// The manifest is authoritative for the size. If the server announces a
	// different length, fail before transferring or writing anything.
	expectedSize := entry.Size
	if resp.ContentLength >= 0 && resp.ContentLength != expectedSize {
		return fmt.Errorf("size mismatch: expected %d bytes, got %d", expectedSize, resp.ContentLength)
	}

	// Create temp file
	out, err := os.Create(tmpPath)
	if err != nil {
		return err
	}

	// From here on, every exit path except a successful rename must remove
	// the temp file; do it in one place.
	committed := false
	defer func() {
		if !committed {
			_ = os.Remove(tmpPath) // best-effort cleanup
		}
	}()

	// Set up hash computation
	h := sha256.New()

	// Create progress-reporting writer that also computes hash
	pw := &progressWriter{
		w:         io.MultiWriter(out, h),
		path:      localPath,
		total:     expectedSize,
		startTime: time.Now(),
		progress:  progress,
	}

	// Copy content while hashing and reporting progress. The server is not
	// trusted: read at most one byte more than expected, which is enough to
	// detect an oversized body without letting it fill the disk.
	written, copyErr := io.Copy(pw, io.LimitReader(resp.Body, expectedSize+1))

	// Close file before checking errors (to flush writes)
	closeErr := out.Close()

	if copyErr != nil {
		return copyErr
	}
	if closeErr != nil {
		return closeErr
	}

	// Verify size
	if written > expectedSize {
		return fmt.Errorf("size mismatch: expected %d bytes, got more", expectedSize)
	}
	if written != expectedSize {
		return fmt.Errorf("size mismatch: expected %d bytes, got %d", expectedSize, written)
	}

	// Encode computed hash as multihash
	computed, err := multihash.Encode(h.Sum(nil), multihash.SHA2_256)
	if err != nil {
		return fmt.Errorf("failed to encode hash: %w", err)
	}

	// Verify hash against manifest (at least one must match)
	hashMatch := false
	for _, hash := range entry.Hashes {
		if bytes.Equal(computed, hash.MultiHash) {
			hashMatch = true
			break
		}
	}
	if !hashMatch {
		return fmt.Errorf("hash mismatch")
	}

	// Rename temp file to final path
	if err := os.Rename(tmpPath, localPath); err != nil {
		return fmt.Errorf("failed to rename temp file: %w", err)
	}
	committed = true

	return nil
}