package httpfetcher import ( "context" "errors" "io" "net" "testing" "testing/fstest" ) func TestDefaultConfig(t *testing.T) { cfg := DefaultConfig() if cfg.Timeout != DefaultFetchTimeout { t.Errorf("Timeout = %v, want %v", cfg.Timeout, DefaultFetchTimeout) } if cfg.MaxResponseSize != DefaultMaxResponseSize { t.Errorf("MaxResponseSize = %d, want %d", cfg.MaxResponseSize, DefaultMaxResponseSize) } if cfg.MaxConnectionsPerHost != DefaultMaxConnectionsPerHost { t.Errorf("MaxConnectionsPerHost = %d, want %d", cfg.MaxConnectionsPerHost, DefaultMaxConnectionsPerHost) } if cfg.AllowHTTP { t.Error("AllowHTTP should default to false") } if len(cfg.AllowedContentTypes) == 0 { t.Error("AllowedContentTypes should not be empty") } } func TestNewWithNilConfigUsesDefaults(t *testing.T) { f := New(nil) if f == nil { t.Fatal("New(nil) returned nil") } if f.config == nil { t.Fatal("config should be populated from DefaultConfig") } if f.config.Timeout != DefaultFetchTimeout { t.Errorf("Timeout = %v, want %v", f.config.Timeout, DefaultFetchTimeout) } } func TestIsAllowedContentType(t *testing.T) { f := New(DefaultConfig()) tests := []struct { contentType string want bool }{ {"image/jpeg", true}, {"image/png", true}, {"image/webp", true}, {"image/jpeg; charset=utf-8", true}, {"IMAGE/JPEG", true}, {"text/html", false}, {"application/octet-stream", false}, {"", false}, } for _, tc := range tests { t.Run(tc.contentType, func(t *testing.T) { got := f.isAllowedContentType(tc.contentType) if got != tc.want { t.Errorf("isAllowedContentType(%q) = %v, want %v", tc.contentType, got, tc.want) } }) } } func TestExtractHost(t *testing.T) { tests := []struct { url string want string }{ {"https://example.com/path", "example.com"}, {"http://example.com:8080/path", "example.com:8080"}, {"https://example.com", "example.com"}, {"https://example.com?q=1", "example.com"}, {"example.com/path", "example.com"}, {"", ""}, } for _, tc := range tests { t.Run(tc.url, func(t *testing.T) { got := extractHost(tc.url) if got != tc.want { t.Errorf("extractHost(%q) = %q, want %q", tc.url, got, tc.want) } }) } } func TestIsLocalhost(t *testing.T) { tests := []struct { host string want bool }{ {"localhost", true}, {"LOCALHOST", true}, {"127.0.0.1", true}, {"::1", true}, {"[::1]", true}, {"foo.localhost", true}, {"foo.local", true}, {"example.com", false}, {"127.0.0.2", false}, // Handled by isPrivateIP, not isLocalhost string match } for _, tc := range tests { t.Run(tc.host, func(t *testing.T) { got := isLocalhost(tc.host) if got != tc.want { t.Errorf("isLocalhost(%q) = %v, want %v", tc.host, got, tc.want) } }) } } func TestIsPrivateIP(t *testing.T) { tests := []struct { ip string want bool }{ {"127.0.0.1", true}, // loopback {"10.0.0.1", true}, // private {"192.168.1.1", true}, // private {"172.16.0.1", true}, // private {"169.254.1.1", true}, // link-local {"0.0.0.0", true}, // unspecified {"224.0.0.1", true}, // multicast {"::1", true}, // IPv6 loopback {"fe80::1", true}, // IPv6 link-local {"8.8.8.8", false}, // public {"2001:4860:4860::8888", false}, // public IPv6 } for _, tc := range tests { t.Run(tc.ip, func(t *testing.T) { ip := net.ParseIP(tc.ip) if ip == nil { t.Fatalf("failed to parse IP %q", tc.ip) } got := isPrivateIP(ip) if got != tc.want { t.Errorf("isPrivateIP(%q) = %v, want %v", tc.ip, got, tc.want) } }) } if !isPrivateIP(nil) { t.Error("isPrivateIP(nil) should return true") } } func TestValidateURL_RejectsNonHTTPS(t *testing.T) { err := validateURL("http://example.com/path", false) if !errors.Is(err, ErrUnsupportedScheme) { t.Errorf("validateURL http = %v, want ErrUnsupportedScheme", err) } } func TestValidateURL_AllowsHTTPWhenConfigured(t *testing.T) { // Use a host that won't resolve (explicit .invalid TLD) so we don't hit DNS. err := validateURL("http://nonexistent.invalid/path", true) // We expect a host resolution error, not ErrUnsupportedScheme. if errors.Is(err, ErrUnsupportedScheme) { t.Error("validateURL with AllowHTTP should not return ErrUnsupportedScheme") } } func TestValidateURL_RejectsLocalhost(t *testing.T) { err := validateURL("https://localhost/path", false) if !errors.Is(err, ErrSSRFBlocked) { t.Errorf("validateURL localhost = %v, want ErrSSRFBlocked", err) } } func TestValidateURL_EmptyHost(t *testing.T) { err := validateURL("https:///path", false) if !errors.Is(err, ErrInvalidHost) { t.Errorf("validateURL empty host = %v, want ErrInvalidHost", err) } } func TestMockFetcher_FetchesFile(t *testing.T) { mockFS := fstest.MapFS{ "example.com/images/photo.jpg": &fstest.MapFile{Data: []byte("fake-jpeg-data")}, } m := NewMock(mockFS) result, err := m.Fetch(context.Background(), "https://example.com/images/photo.jpg") if err != nil { t.Fatalf("Fetch() error = %v", err) } defer func() { _ = result.Content.Close() }() if result.ContentType != "image/jpeg" { t.Errorf("ContentType = %q, want image/jpeg", result.ContentType) } data, err := io.ReadAll(result.Content) if err != nil { t.Fatalf("read content: %v", err) } if string(data) != "fake-jpeg-data" { t.Errorf("Content = %q, want %q", string(data), "fake-jpeg-data") } if result.ContentLength != int64(len("fake-jpeg-data")) { t.Errorf("ContentLength = %d, want %d", result.ContentLength, len("fake-jpeg-data")) } } func TestMockFetcher_MissingFileReturnsUpstreamError(t *testing.T) { mockFS := fstest.MapFS{} m := NewMock(mockFS) _, err := m.Fetch(context.Background(), "https://example.com/missing.jpg") if !errors.Is(err, ErrUpstreamError) { t.Errorf("Fetch() error = %v, want ErrUpstreamError", err) } } func TestMockFetcher_RespectsContextCancellation(t *testing.T) { mockFS := fstest.MapFS{ "example.com/photo.jpg": &fstest.MapFile{Data: []byte("data")}, } m := NewMock(mockFS) ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := m.Fetch(ctx, "https://example.com/photo.jpg") if !errors.Is(err, context.Canceled) { t.Errorf("Fetch() error = %v, want context.Canceled", err) } } func TestDetectContentTypeFromPath(t *testing.T) { tests := []struct { path string want string }{ {"foo/bar.jpg", "image/jpeg"}, {"foo/bar.JPG", "image/jpeg"}, {"foo/bar.jpeg", "image/jpeg"}, {"foo/bar.png", "image/png"}, {"foo/bar.gif", "image/gif"}, {"foo/bar.webp", "image/webp"}, {"foo/bar.avif", "image/avif"}, {"foo/bar.svg", "image/svg+xml"}, {"foo/bar.bin", "application/octet-stream"}, {"foo/bar", "application/octet-stream"}, } for _, tc := range tests { t.Run(tc.path, func(t *testing.T) { got := detectContentTypeFromPath(tc.path) if got != tc.want { t.Errorf("detectContentTypeFromPath(%q) = %q, want %q", tc.path, got, tc.want) } }) } } func TestLimitedReader_EnforcesLimit(t *testing.T) { src := make([]byte, 100) r := &limitedReader{ reader: &byteReader{data: src}, remaining: 50, } buf := make([]byte, 100) n, err := r.Read(buf) if err != nil { t.Fatalf("first Read error = %v", err) } if n > 50 { t.Errorf("read %d bytes, should be capped at 50", n) } // Drain until limit is exhausted. total := n for total < 50 { nn, err := r.Read(buf) total += nn if err != nil { t.Fatalf("during drain: %v", err) } } // Now the limit is exhausted — next read should error. _, err = r.Read(buf) if !errors.Is(err, ErrResponseTooLarge) { t.Errorf("exhausted Read error = %v, want ErrResponseTooLarge", err) } } // byteReader is a minimal io.Reader over a byte slice for testing. type byteReader struct { data []byte pos int } func (r *byteReader) Read(p []byte) (int, error) { if r.pos >= len(r.data) { return 0, io.EOF } n := copy(p, r.data[r.pos:]) r.pos += n return n, nil }