fastmirror/fastmirror.go

364 lines
9.3 KiB
Go

package fastmirror
import (
"bufio"
_ "embed"
"fmt"
"github.com/schollz/progressbar/v3"
"html/template"
"io"
"io/ioutil"
"log"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
)
// legacyTemplate holds the contents of sources_list_legacy.tmpl, embedded at
// build time. updateSourcesFile renders it for releases older than 24.04.
//go:embed sources_list_legacy.tmpl
var legacyTemplate string
// modernTemplate holds the contents of sources_list_modern.tmpl, embedded at
// build time. updateSourcesFile renders it for release 24.04 and later.
//go:embed sources_list_modern.tmpl
var modernTemplate string
// CLIEntry runs the fastmirror command-line workflow on an Ubuntu host.
//
// It locates the APT sources file that references ubuntu.com and reads the
// components it enables. It then probes the HTTPS mirrors listed by
// mirrors.ubuntu.com for the fastest one serving this release. Finally it
// backs the sources file up to "<file>.bak" and rewrites the file to use that
// mirror.
//
// Every failure is fatal (log.Fatal exits the process). All read-only and
// network steps run before the first write under /etc/apt, so a failure while
// probing mirrors leaves the system's sources untouched. If the final rewrite
// fails, the backup is moved back into place on a best-effort basis.
func CLIEntry() {
	log.Println("Starting Fastmirror CLI")
	if runtime.GOOS != "linux" {
		log.Fatal("This program is only for Linux")
	}
	if err := isUbuntu(); err != nil {
		log.Fatal(err)
	}
	// The backup and rewrite below modify /etc/apt. Fail here rather than
	// after the (slow) mirror probe.
	if os.Geteuid() != 0 {
		log.Fatal("This program must be run as root")
	}

	sourcesFilePath, err := findSourcesFilePath()
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("Found sources file: %s", sourcesFilePath)

	// Extract the components from the existing sources file.
	suites, err := extractSuites(sourcesFilePath)
	if err != nil {
		log.Fatal(err)
	}
	if len(suites) == 0 {
		// Rewriting with an empty component list would produce a sources
		// file APT cannot use.
		log.Fatalf("no components found in %s; refusing to rewrite it", sourcesFilePath)
	}
	log.Printf("Extracted suites: %v", suites)

	// Probe mirrors before touching anything on disk. This step depends on
	// the network and is the most likely to fail. findFastestMirror already
	// logs the latency statistics, so they are not displayed again here.
	fastestMirrorURL, _, _, err := findFastestMirror()
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("Fastest mirror: %s", fastestMirrorURL)

	// Create sources.list.d directory if it doesn't exist.
	if err := createSourcesListD(); err != nil {
		log.Fatal(err)
	}
	if err := createBackup(sourcesFilePath); err != nil {
		log.Fatal(err)
	}
	if err := updateSourcesFile(sourcesFilePath, fastestMirrorURL, suites); err != nil {
		restoreBackup(sourcesFilePath)
		log.Fatal(err)
	}
	log.Println("Fastmirror CLI finished successfully")
}

// restoreBackup moves "<filePath>.bak" back to filePath after a failed
// update. It is best effort: a failure is logged, not returned.
func restoreBackup(filePath string) {
	backupPath := filePath + ".bak"
	if err := os.Rename(backupPath, filePath); err != nil {
		log.Printf("Failed to restore %s from %s: %v", filePath, backupPath, err)
		return
	}
	log.Printf("Restored %s from %s", filePath, backupPath)
}
// getUbuntuCodename returns the release codename printed by
// `lsb_release -cs`, with surrounding whitespace trimmed. It fails if the
// command cannot be run or exits with an error.
func getUbuntuCodename() (string, error) {
	out, runErr := exec.Command("lsb_release", "-cs").Output()
	if runErr != nil {
		return "", fmt.Errorf("failed to run lsb_release: %v", runErr)
	}
	codename := strings.TrimSpace(string(out))
	return codename, nil
}
// getUbuntuVersion runs `lsb_release -rs` and splits its "MAJOR.MINOR"
// output into two integers. For example, "22.04" yields (22, 4).
//
// It fails if the command cannot be run, if the output does not have exactly
// two dot-separated parts, or if either part is not a decimal integer.
func getUbuntuVersion() (int, int, error) {
	out, runErr := exec.Command("lsb_release", "-rs").Output()
	if runErr != nil {
		return 0, 0, fmt.Errorf("failed to run lsb_release: %v", runErr)
	}

	release := strings.TrimSpace(string(out))
	fields := strings.Split(release, ".")
	if len(fields) != 2 {
		return 0, 0, fmt.Errorf("unexpected version format: %s", release)
	}

	major, convErr := strconv.Atoi(fields[0])
	if convErr != nil {
		return 0, 0, fmt.Errorf("failed to parse major version: %v", convErr)
	}
	minor, convErr := strconv.Atoi(fields[1])
	if convErr != nil {
		return 0, 0, fmt.Errorf("failed to parse minor version: %v", convErr)
	}
	return major, minor, nil
}
// isUbuntu returns nil when `lsb_release -is` reports exactly "Ubuntu"
// (after trimming whitespace). Otherwise it returns an error explaining
// whether the command failed or the distribution is different.
func isUbuntu() error {
	out, runErr := exec.Command("lsb_release", "-is").Output()
	switch {
	case runErr != nil:
		return fmt.Errorf("failed to run lsb_release: %v", runErr)
	case strings.TrimSpace(string(out)) != "Ubuntu":
		return fmt.Errorf("this program is only for Ubuntu")
	default:
		return nil
	}
}
// findSourcesFilePath locates the APT sources file that points at Ubuntu's
// archive, as judged by isUbuntuSourcesFile.
//
// /etc/apt/sources.list is preferred. Otherwise the *.list and *.sources
// files in /etc/apt/sources.list.d/ are checked in the order ioutil.ReadDir
// returns them (sorted by file name), and the first match wins.
//
// It returns an error if the directory cannot be read or no candidate
// matches.
func findSourcesFilePath() (string, error) {
	const sourcesListPath = "/etc/apt/sources.list"
	const sourcesListDPath = "/etc/apt/sources.list.d/"
	if isUbuntuSourcesFile(sourcesListPath) {
		log.Printf("Found Ubuntu sources file: %s", sourcesListPath)
		return sourcesListPath, nil
	}
	files, err := ioutil.ReadDir(sourcesListDPath)
	if err != nil {
		return "", fmt.Errorf("failed to read directory %s: %w", sourcesListDPath, err)
	}
	for _, file := range files {
		if file.IsDir() {
			continue
		}
		// APT only reads *.list and *.sources files from this directory.
		// Skipping everything else keeps backups (*.bak from a previous run,
		// *.save, *.orig, ...) from being mistaken for the live file.
		switch filepath.Ext(file.Name()) {
		case ".list", ".sources":
		default:
			continue
		}
		filePath := filepath.Join(sourcesListDPath, file.Name())
		if isUbuntuSourcesFile(filePath) {
			log.Printf("Found Ubuntu sources file: %s", filePath)
			return filePath, nil
		}
	}
	return "", fmt.Errorf("no Ubuntu sources file found")
}
// isUbuntuSourcesFile reports whether the file at filePath has at least one
// non-comment, non-blank line containing "ubuntu.com" (case-insensitive).
//
// Lines whose first non-space character is '#' are ignored. Stock files
// carry comments like "# See http://help.ubuntu.com/..." even when they
// define no sources, and those must not count.
//
// A file that cannot be read is logged and reported as false.
func isUbuntuSourcesFile(filePath string) bool {
	content, err := ioutil.ReadFile(filePath)
	if err != nil {
		log.Printf("Failed to read file %s: %v", filePath, err)
		return false
	}
	for _, line := range strings.Split(string(content), "\n") {
		trimmed := strings.TrimSpace(line)
		if trimmed == "" || strings.HasPrefix(trimmed, "#") {
			continue
		}
		if strings.Contains(strings.ToLower(trimmed), "ubuntu.com") {
			return true
		}
	}
	return false
}
// createBackup copies filePath to filePath+".bak", preserving the original
// file's permission bits. filePath itself is left in place.
//
// The old implementation renamed the file instead of copying it. Copying
// matters because the caller rewrites filePath afterwards: if anything fails
// in between, the system still has a working sources file.
//
// The backup is created with O_EXCL, so an existing backup is never
// overwritten and a "backup file ... already exists" error is returned
// instead. A partially written backup is removed before an error is
// returned.
func createBackup(filePath string) error {
	backupPath := filePath + ".bak"

	src, err := os.Open(filePath)
	if err != nil {
		return fmt.Errorf("failed to create backup of %s: %w", filePath, err)
	}
	defer src.Close()

	info, err := src.Stat()
	if err != nil {
		return fmt.Errorf("failed to create backup of %s: %w", filePath, err)
	}

	// O_EXCL turns "does the backup exist?" and "create it" into a single
	// atomic step (no Stat-then-create race).
	dst, err := os.OpenFile(backupPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, info.Mode().Perm())
	if err != nil {
		if os.IsExist(err) {
			return fmt.Errorf("backup file %s already exists", backupPath)
		}
		return fmt.Errorf("failed to create backup of %s: %w", filePath, err)
	}

	_, err = io.Copy(dst, src)
	if err == nil {
		// Make sure the backup is on disk before the caller overwrites the original.
		err = dst.Sync()
	}
	if closeErr := dst.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		// Don't leave a truncated backup behind: it would block the next run.
		_ = os.Remove(backupPath)
		return fmt.Errorf("failed to create backup of %s: %w", filePath, err)
	}

	log.Printf("Created backup: %s", backupPath)
	return nil
}
// createSourcesListD makes /etc/apt/sources.list.d/ (mode 0755) when os.Stat
// reports that it does not exist.
//
// Any other Stat outcome, including the directory already existing or a
// different Stat error, is treated as success and nothing is created.
func createSourcesListD() error {
	const dir = "/etc/apt/sources.list.d/"
	_, statErr := os.Stat(dir)
	if !os.IsNotExist(statErr) {
		return nil
	}
	if mkErr := os.Mkdir(dir, 0755); mkErr != nil {
		return fmt.Errorf("failed to create directory %s: %v", dir, mkErr)
	}
	log.Printf("Created directory: %s", dir)
	return nil
}
// findFastestMirror downloads the mirror list from mirrors.ubuntu.com and
// probes every HTTPS entry with isValidMirror. The probe requests
// dists/<codename>/Release with a 1-second client timeout.
//
// It returns:
//   - the mirror URL (trailing slash removed) with the shortest successful
//     probe time;
//   - that time;
//   - the probe times of every mirror that answered 200 OK.
//
// Non-HTTPS entries are skipped. Skip/failure counts and latency statistics
// are logged. An error is returned if the list cannot be fetched, the
// codename cannot be determined, or no HTTPS mirror succeeds.
func findFastestMirror() (string, time.Duration, []time.Duration, error) {
	log.Println("Fetching mirrors list")
	mirrors, err := fetchMirrorList("http://mirrors.ubuntu.com/mirrors.txt")
	if err != nil {
		return "", 0, nil, err
	}
	log.Printf("Found %d mirrors", len(mirrors))

	codename, err := getUbuntuCodename()
	if err != nil {
		return "", 0, nil, err
	}

	httpClient := http.Client{Timeout: 1 * time.Second}

	var (
		fastestMirrorURL string
		fastestTime      time.Duration
		latencies        []time.Duration
		skippedCount     int // entries that are not HTTPS
		failedCount      int // HTTPS mirrors that errored/timed out or lacked the Release file
	)
	log.Println("Testing mirrors for latency")
	bar := progressbar.Default(int64(len(mirrors)))
	for _, mirror := range mirrors {
		_ = bar.Add(1) // a progress-display error is not a reason to stop probing
		if !strings.HasPrefix(mirror, "https://") {
			skippedCount++
			continue
		}
		mirror = strings.TrimSuffix(mirror, "/")
		start := time.Now()
		ok := isValidMirror(httpClient, mirror, codename)
		elapsed := time.Since(start)
		if !ok {
			failedCount++
			continue
		}
		latencies = append(latencies, elapsed)
		if fastestMirrorURL == "" || elapsed < fastestTime {
			fastestMirrorURL, fastestTime = mirror, elapsed
		}
	}

	// Logged before the "no mirror" check so that failure is diagnosable.
	log.Printf("Number of non-HTTPS mirrors skipped: %d", skippedCount)
	log.Printf("Number of HTTPS mirrors unreachable or missing dists/%s/Release: %d", codename, failedCount)
	if fastestMirrorURL == "" {
		return "", 0, nil, fmt.Errorf("no suitable HTTPS mirror found")
	}
	displayLatencyStatistics(latencies, fastestMirrorURL, fastestTime)
	return fastestMirrorURL, fastestTime, latencies, nil
}

// fetchMirrorList downloads listURL (30s timeout) and returns its non-blank
// lines with surrounding whitespace (including '\r') trimmed. A non-200
// response is an error.
func fetchMirrorList(listURL string) ([]string, error) {
	client := http.Client{Timeout: 30 * time.Second}
	resp, err := client.Get(listURL)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch mirrors list: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("failed to fetch mirrors list: unexpected HTTP status %s", resp.Status)
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read mirrors list: %w", err)
	}
	var mirrors []string
	for _, line := range strings.Split(string(body), "\n") {
		if line = strings.TrimSpace(line); line != "" {
			mirrors = append(mirrors, line)
		}
	}
	return mirrors, nil
}
// isValidMirror reports whether a GET of <mirrorURL>/dists/<codename>/Release
// through httpClient completes with status 200 OK. Request-construction and
// transport errors (including the client's timeout) yield false. The
// response body is closed without being read.
func isValidMirror(httpClient http.Client, mirrorURL, codename string) bool {
	endpoint := fmt.Sprintf("%s/dists/%s/Release", mirrorURL, codename)
	req, reqErr := http.NewRequest(http.MethodGet, endpoint, nil)
	if reqErr != nil {
		return false
	}
	resp, doErr := httpClient.Do(req)
	if doErr != nil {
		return false
	}
	defer resp.Body.Close()
	ok := resp.StatusCode == http.StatusOK
	return ok
}
// extractSuites returns the APT components (e.g. "main", "universe") enabled
// by the sources file at filePath. Despite the name, these are the fields
// after the suite on a one-line entry, not the suites themselves.
//
// Both source formats are understood:
//   - one-line style: "deb [options] URI suite component..." (and deb-src).
//     Leading whitespace is allowed and an optional "[ ... ]" option block
//     is skipped.
//   - deb822 style: a "Components: component..." field (name matched
//     case-insensitively).
//
// Commented-out entries never match. Each component is reported once, in
// order of first appearance, so several entries sharing the same components
// do not yield "main main restricted restricted ...". A nil slice with a nil
// error means no components were found.
func extractSuites(filePath string) ([]string, error) {
	const componentsKey = "Components:"

	content, err := ioutil.ReadFile(filePath)
	if err != nil {
		return nil, fmt.Errorf("failed to read file %s: %w", filePath, err)
	}

	var suites []string
	seen := make(map[string]bool)
	addUnique := func(components []string) {
		for _, c := range components {
			if !seen[c] {
				seen[c] = true
				suites = append(suites, c)
			}
		}
	}

	scanner := bufio.NewScanner(strings.NewReader(string(content)))
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		fields := strings.Fields(line)
		if len(fields) == 0 {
			continue
		}
		switch {
		case fields[0] == "deb" || fields[0] == "deb-src":
			// After any option block come URI, suite, then the components.
			if rest := skipSourceOptions(fields[1:]); len(rest) >= 3 {
				addUnique(rest[2:])
			}
		case len(line) >= len(componentsKey) && strings.EqualFold(line[:len(componentsKey)], componentsKey):
			addUnique(strings.Fields(line[len(componentsKey):]))
		}
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("failed to scan file %s: %w", filePath, err)
	}
	return suites, nil
}

// skipSourceOptions drops a leading "[ ... ]" option block from the fields
// that follow "deb"/"deb-src". The block may span several
// whitespace-separated fields, e.g. "[arch=amd64 signed-by=/k.gpg]" or
// "[ arch=amd64 ]". Fields without an option block are returned unchanged;
// an unterminated block yields nil.
func skipSourceOptions(fields []string) []string {
	if len(fields) == 0 || !strings.HasPrefix(fields[0], "[") {
		return fields
	}
	for i, f := range fields {
		if strings.HasSuffix(f, "]") {
			return fields[i+1:]
		}
	}
	return nil
}
// displayLatencyStatistics logs four lines: the minimum, maximum and mean of
// latencies, then the chosen mirror with its latency. The mean uses
// time.Duration integer division, so it truncates. An empty slice produces a
// single "No latencies recorded" line instead.
func displayLatencyStatistics(latencies []time.Duration, fastestMirrorURL string, fastestTime time.Duration) {
	n := len(latencies)
	if n == 0 {
		log.Println("No latencies recorded")
		return
	}

	lo, hi, sum := latencies[0], latencies[0], time.Duration(0)
	for i := 0; i < n; i++ {
		d := latencies[i]
		sum += d
		switch {
		case d < lo:
			lo = d
		case d > hi:
			hi = d
		}
	}
	mean := sum / time.Duration(n)

	log.Printf("Minimum latency: %s", lo)
	log.Printf("Maximum latency: %s", hi)
	log.Printf("Average latency: %s", mean)
	log.Printf("Chosen fastest mirror: %s with latency: %s", fastestMirrorURL, fastestTime)
}
// updateSourcesFile overwrites filePath with the rendered sources template.
// The template receives MirrorURL, Codename (from lsb_release) and Suites
// (the components joined with single spaces).
//
// Ubuntu 24.04 and later use the embedded "modern" template; older releases
// use "legacy". The file is written with ioutil.WriteFile (mode 0644 if it
// has to be created).
//
// NOTE(review): this uses html/template, which HTML-escapes interpolated
// values (e.g. '&', '+', quotes). That is harmless for typical mirror URLs
// but wrong for a plain-text config file; text/template is likely what was
// intended (switching needs an import change).
//
// NOTE(review): the template is chosen by release, not by filePath's format.
// If the modern template is deb822 and filePath is a one-line *.list file
// (or vice versa), APT may reject the result. Confirm against the templates.
func updateSourcesFile(filePath, mirrorURL string, suites []string) error {
	codename, err := getUbuntuCodename()
	if err != nil {
		return err
	}
	majorVersion, minorVersion, err := getUbuntuVersion()
	if err != nil {
		return err
	}

	data := struct {
		MirrorURL string
		Codename  string
		Suites    string
	}{
		MirrorURL: mirrorURL,
		Codename:  codename,
		Suites:    strings.Join(suites, " "),
	}

	// 24.04 is the cut-over release for the modern template.
	name, text := "legacy", legacyTemplate
	if majorVersion > 24 || (majorVersion == 24 && minorVersion >= 4) {
		name, text = "modern", modernTemplate
	}

	tmpl, err := template.New(name).Parse(text)
	if err != nil {
		return fmt.Errorf("failed to parse %s template: %w", name, err)
	}
	var buf strings.Builder
	if err := tmpl.Execute(&buf, data); err != nil {
		return fmt.Errorf("failed to execute %s template: %w", name, err)
	}

	if err := ioutil.WriteFile(filePath, []byte(buf.String()), 0644); err != nil {
		return fmt.Errorf("failed to update sources file %s: %w", filePath, err)
	}
	log.Printf("Updated sources file: %s", filePath)
	return nil
}