Change FileProgress callback to channel-based progress
Replace callback-based progress reporting in Builder.AddFile with channel-based FileHashProgress for consistency with EnumerateStatus and ScanStatus patterns. Update scanner.go to use the new channel API.
This commit is contained in:
@@ -1,12 +1,366 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"github.com/apex/log"
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/multiformats/go-multihash"
|
||||
"github.com/urfave/cli/v2"
|
||||
"sneak.berlin/go/mfer/internal/log"
|
||||
"sneak.berlin/go/mfer/mfer"
|
||||
)
|
||||
|
||||
func (mfa *CLIApp) fetchManifestOperation(c *cli.Context) error {
|
||||
log.Debugf("fetchManifestOperation()")
|
||||
panic("not implemented")
|
||||
return nil //nolint
|
||||
// DownloadProgress reports the progress of a single file download.
|
||||
type DownloadProgress struct {
|
||||
Path string // File path being downloaded
|
||||
BytesRead int64 // Bytes downloaded so far
|
||||
TotalBytes int64 // Total expected bytes (-1 if unknown)
|
||||
BytesPerSec float64 // Current download rate
|
||||
ETA time.Duration // Estimated time to completion
|
||||
}
|
||||
|
||||
func (mfa *CLIApp) fetchManifestOperation(ctx *cli.Context) error {
|
||||
log.Debug("fetchManifestOperation()")
|
||||
|
||||
if ctx.Args().Len() == 0 {
|
||||
return fmt.Errorf("URL argument required")
|
||||
}
|
||||
|
||||
inputURL := ctx.Args().Get(0)
|
||||
manifestURL, err := resolveManifestURL(inputURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("fetching manifest from %s", manifestURL)
|
||||
|
||||
// Fetch manifest
|
||||
resp, err := http.Get(manifestURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch manifest: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("failed to fetch manifest: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Parse manifest
|
||||
manifest, err := mfer.NewManifestFromReader(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse manifest: %w", err)
|
||||
}
|
||||
|
||||
files := manifest.Files()
|
||||
log.Infof("manifest contains %d files", len(files))
|
||||
|
||||
// Compute base URL (directory containing manifest)
|
||||
baseURL, err := url.Parse(manifestURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
baseURL.Path = path.Dir(baseURL.Path)
|
||||
if !strings.HasSuffix(baseURL.Path, "/") {
|
||||
baseURL.Path += "/"
|
||||
}
|
||||
|
||||
// Calculate total bytes to download
|
||||
var totalBytes int64
|
||||
for _, f := range files {
|
||||
totalBytes += f.Size
|
||||
}
|
||||
|
||||
// Create progress channel
|
||||
progress := make(chan DownloadProgress, 10)
|
||||
|
||||
// Start progress reporter goroutine
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for p := range progress {
|
||||
rate := formatBitrate(p.BytesPerSec * 8)
|
||||
if p.ETA > 0 {
|
||||
log.Infof("%s: %d/%d bytes, %s, ETA %s",
|
||||
p.Path, p.BytesRead, p.TotalBytes,
|
||||
rate, p.ETA.Round(time.Second))
|
||||
} else {
|
||||
log.Infof("%s: %d/%d bytes, %s",
|
||||
p.Path, p.BytesRead, p.TotalBytes, rate)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Track download start time
|
||||
startTime := time.Now()
|
||||
|
||||
// Download each file
|
||||
for _, f := range files {
|
||||
// Sanitize the path to prevent path traversal attacks
|
||||
localPath, err := sanitizePath(f.Path)
|
||||
if err != nil {
|
||||
close(progress)
|
||||
<-done
|
||||
return fmt.Errorf("invalid path in manifest: %w", err)
|
||||
}
|
||||
|
||||
fileURL := baseURL.String() + f.Path
|
||||
log.Infof("fetching %s", f.Path)
|
||||
|
||||
if err := downloadFile(fileURL, localPath, f, progress); err != nil {
|
||||
close(progress)
|
||||
<-done
|
||||
return fmt.Errorf("failed to download %s: %w", f.Path, err)
|
||||
}
|
||||
}
|
||||
|
||||
close(progress)
|
||||
<-done
|
||||
|
||||
// Print summary if not quiet
|
||||
if !ctx.Bool("quiet") {
|
||||
elapsed := time.Since(startTime)
|
||||
avgBytesPerSec := float64(totalBytes) / elapsed.Seconds()
|
||||
avgRate := formatBitrate(avgBytesPerSec * 8)
|
||||
log.Infof("downloaded %d files (%.1f MB) in %.1fs (%s avg)",
|
||||
len(files),
|
||||
float64(totalBytes)/1e6,
|
||||
elapsed.Seconds(),
|
||||
avgRate)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sanitizePath validates and sanitizes a file path from the manifest.
|
||||
// It prevents path traversal attacks and rejects unsafe paths.
|
||||
func sanitizePath(p string) (string, error) {
|
||||
// Reject empty paths
|
||||
if p == "" {
|
||||
return "", fmt.Errorf("empty path")
|
||||
}
|
||||
|
||||
// Reject absolute paths
|
||||
if filepath.IsAbs(p) {
|
||||
return "", fmt.Errorf("absolute path not allowed: %s", p)
|
||||
}
|
||||
|
||||
// Clean the path to resolve . and ..
|
||||
cleaned := filepath.Clean(p)
|
||||
|
||||
// Reject paths that escape the current directory
|
||||
if strings.HasPrefix(cleaned, ".."+string(filepath.Separator)) || cleaned == ".." {
|
||||
return "", fmt.Errorf("path traversal not allowed: %s", p)
|
||||
}
|
||||
|
||||
// Also check for absolute paths after cleaning (handles edge cases)
|
||||
if filepath.IsAbs(cleaned) {
|
||||
return "", fmt.Errorf("absolute path not allowed: %s", p)
|
||||
}
|
||||
|
||||
return cleaned, nil
|
||||
}
|
||||
|
||||
// resolveManifestURL takes a URL and returns the manifest URL.
|
||||
// If the URL already ends with .mf, it's returned as-is.
|
||||
// Otherwise, index.mf is appended.
|
||||
func resolveManifestURL(inputURL string) (string, error) {
|
||||
parsed, err := url.Parse(inputURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Check if URL already ends with .mf
|
||||
if strings.HasSuffix(parsed.Path, ".mf") {
|
||||
return inputURL, nil
|
||||
}
|
||||
|
||||
// Ensure path ends with /
|
||||
if !strings.HasSuffix(parsed.Path, "/") {
|
||||
parsed.Path += "/"
|
||||
}
|
||||
|
||||
// Append index.mf
|
||||
parsed.Path += "index.mf"
|
||||
|
||||
return parsed.String(), nil
|
||||
}
|
||||
|
||||
// progressWriter wraps an io.Writer and reports progress to a channel.
|
||||
type progressWriter struct {
|
||||
w io.Writer
|
||||
path string
|
||||
total int64
|
||||
written int64
|
||||
startTime time.Time
|
||||
progress chan<- DownloadProgress
|
||||
}
|
||||
|
||||
func (pw *progressWriter) Write(p []byte) (int, error) {
|
||||
n, err := pw.w.Write(p)
|
||||
pw.written += int64(n)
|
||||
if pw.progress != nil {
|
||||
var bytesPerSec float64
|
||||
var eta time.Duration
|
||||
elapsed := time.Since(pw.startTime)
|
||||
if elapsed > 0 && pw.written > 0 {
|
||||
bytesPerSec = float64(pw.written) / elapsed.Seconds()
|
||||
if bytesPerSec > 0 && pw.total > 0 {
|
||||
remainingBytes := pw.total - pw.written
|
||||
eta = time.Duration(float64(remainingBytes)/bytesPerSec) * time.Second
|
||||
}
|
||||
}
|
||||
sendProgress(pw.progress, DownloadProgress{
|
||||
Path: pw.path,
|
||||
BytesRead: pw.written,
|
||||
TotalBytes: pw.total,
|
||||
BytesPerSec: bytesPerSec,
|
||||
ETA: eta,
|
||||
})
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// formatBitrate formats a bits-per-second value with appropriate unit prefix.
|
||||
func formatBitrate(bps float64) string {
|
||||
switch {
|
||||
case bps >= 1e9:
|
||||
return fmt.Sprintf("%.1f Gbps", bps/1e9)
|
||||
case bps >= 1e6:
|
||||
return fmt.Sprintf("%.1f Mbps", bps/1e6)
|
||||
case bps >= 1e3:
|
||||
return fmt.Sprintf("%.1f Kbps", bps/1e3)
|
||||
default:
|
||||
return fmt.Sprintf("%.0f bps", bps)
|
||||
}
|
||||
}
|
||||
|
||||
// sendProgress sends a progress update without blocking.
|
||||
func sendProgress(ch chan<- DownloadProgress, p DownloadProgress) {
|
||||
select {
|
||||
case ch <- p:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// downloadFile downloads a URL to a local file path with hash verification.
|
||||
// It downloads to a temporary file, verifies the hash, then renames to the final path.
|
||||
// Progress is reported via the progress channel.
|
||||
func downloadFile(fileURL, localPath string, entry *mfer.MFFilePath, progress chan<- DownloadProgress) error {
|
||||
// Create parent directories if needed
|
||||
dir := filepath.Dir(localPath)
|
||||
if dir != "" && dir != "." {
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Compute temp file path in the same directory
|
||||
// For dotfiles, just append .tmp (they're already hidden)
|
||||
// For regular files, prefix with . and append .tmp
|
||||
base := filepath.Base(localPath)
|
||||
var tmpName string
|
||||
if strings.HasPrefix(base, ".") {
|
||||
tmpName = base + ".tmp"
|
||||
} else {
|
||||
tmpName = "." + base + ".tmp"
|
||||
}
|
||||
tmpPath := filepath.Join(dir, tmpName)
|
||||
if dir == "" || dir == "." {
|
||||
tmpPath = tmpName
|
||||
}
|
||||
|
||||
// Fetch file
|
||||
resp, err := http.Get(fileURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Determine expected size
|
||||
expectedSize := entry.Size
|
||||
totalBytes := resp.ContentLength
|
||||
if totalBytes < 0 {
|
||||
totalBytes = expectedSize
|
||||
}
|
||||
|
||||
// Create temp file
|
||||
out, err := os.Create(tmpPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set up hash computation
|
||||
h := sha256.New()
|
||||
|
||||
// Create progress-reporting writer that also computes hash
|
||||
pw := &progressWriter{
|
||||
w: io.MultiWriter(out, h),
|
||||
path: localPath,
|
||||
total: totalBytes,
|
||||
startTime: time.Now(),
|
||||
progress: progress,
|
||||
}
|
||||
|
||||
// Copy content while hashing and reporting progress
|
||||
written, copyErr := io.Copy(pw, resp.Body)
|
||||
|
||||
// Close file before checking errors (to flush writes)
|
||||
closeErr := out.Close()
|
||||
|
||||
// If copy failed, clean up temp file and return error
|
||||
if copyErr != nil {
|
||||
os.Remove(tmpPath)
|
||||
return copyErr
|
||||
}
|
||||
if closeErr != nil {
|
||||
os.Remove(tmpPath)
|
||||
return closeErr
|
||||
}
|
||||
|
||||
// Verify size
|
||||
if written != expectedSize {
|
||||
os.Remove(tmpPath)
|
||||
return fmt.Errorf("size mismatch: expected %d bytes, got %d", expectedSize, written)
|
||||
}
|
||||
|
||||
// Encode computed hash as multihash
|
||||
computed, err := multihash.Encode(h.Sum(nil), multihash.SHA2_256)
|
||||
if err != nil {
|
||||
os.Remove(tmpPath)
|
||||
return fmt.Errorf("failed to encode hash: %w", err)
|
||||
}
|
||||
|
||||
// Verify hash against manifest (at least one must match)
|
||||
hashMatch := false
|
||||
for _, hash := range entry.Hashes {
|
||||
if bytes.Equal(computed, hash.MultiHash) {
|
||||
hashMatch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hashMatch {
|
||||
os.Remove(tmpPath)
|
||||
return fmt.Errorf("hash mismatch")
|
||||
}
|
||||
|
||||
// Rename temp file to final path
|
||||
if err := os.Rename(tmpPath, localPath); err != nil {
|
||||
os.Remove(tmpPath)
|
||||
return fmt.Errorf("failed to rename temp file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user