package imgcache import ( "bytes" "io" "strings" "testing" ) func TestDetectFormat(t *testing.T) { tests := []struct { name string data []byte wantMIME MIMEType wantErr error }{ { name: "JPEG", data: append([]byte{0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00, 0x01}, make([]byte, 100)...), wantMIME: MIMETypeJPEG, wantErr: nil, }, { name: "PNG", data: append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D}, make([]byte, 100)...), wantMIME: MIMETypePNG, wantErr: nil, }, { name: "GIF87a", data: append([]byte{0x47, 0x49, 0x46, 0x38, 0x37, 0x61, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, make([]byte, 100)...), wantMIME: MIMETypeGIF, wantErr: nil, }, { name: "GIF89a", data: append([]byte{0x47, 0x49, 0x46, 0x38, 0x39, 0x61, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, make([]byte, 100)...), wantMIME: MIMETypeGIF, wantErr: nil, }, { name: "WebP", data: append([]byte{ 0x52, 0x49, 0x46, 0x46, // RIFF 0x00, 0x00, 0x00, 0x00, // file size (placeholder) 0x57, 0x45, 0x42, 0x50, // WEBP }, make([]byte, 100)...), wantMIME: MIMETypeWebP, wantErr: nil, }, { name: "AVIF", data: append([]byte{ 0x00, 0x00, 0x00, 0x1C, // box size 0x66, 0x74, 0x79, 0x70, // ftyp 0x61, 0x76, 0x69, 0x66, // avif brand }, make([]byte, 100)...), wantMIME: MIMETypeAVIF, wantErr: nil, }, { name: "AVIF sequence", data: append([]byte{ 0x00, 0x00, 0x00, 0x1C, // box size 0x66, 0x74, 0x79, 0x70, // ftyp 0x61, 0x76, 0x69, 0x73, // avis brand }, make([]byte, 100)...), wantMIME: MIMETypeAVIF, wantErr: nil, }, { name: "SVG with XML declaration", data: []byte(``), wantMIME: MIMETypeSVG, wantErr: nil, }, { name: "SVG without declaration", data: []byte(``), wantMIME: MIMETypeSVG, wantErr: nil, }, { name: "SVG with whitespace", data: []byte(` `), wantMIME: MIMETypeSVG, wantErr: nil, }, { name: "SVG with BOM", data: append([]byte{0xEF, 0xBB, 0xBF}, []byte(``)...), wantMIME: MIMETypeSVG, wantErr: nil, }, { name: "unknown format", data: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, wantMIME: "", wantErr: ErrUnknownFormat, }, { name: "too short", data: []byte{0xFF, 0xD8}, wantMIME: "", wantErr: ErrNotEnoughData, }, { name: "empty", data: []byte{}, wantMIME: "", wantErr: ErrNotEnoughData, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := DetectFormat(tt.data) if err != tt.wantErr { t.Errorf("DetectFormat() error = %v, wantErr %v", err, tt.wantErr) return } if got != tt.wantMIME { t.Errorf("DetectFormat() = %v, want %v", got, tt.wantMIME) } }) } } func TestValidateMagicBytes(t *testing.T) { jpegData := append([]byte{0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00, 0x01}, make([]byte, 100)...) pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D}, make([]byte, 100)...) tests := []struct { name string data []byte declaredType string wantErr error }{ { name: "matching JPEG", data: jpegData, declaredType: "image/jpeg", wantErr: nil, }, { name: "matching JPEG with params", data: jpegData, declaredType: "image/jpeg; charset=utf-8", wantErr: nil, }, { name: "matching PNG", data: pngData, declaredType: "image/png", wantErr: nil, }, { name: "mismatched type", data: jpegData, declaredType: "image/png", wantErr: ErrMagicByteMismatch, }, { name: "unknown data", data: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, declaredType: "image/jpeg", wantErr: ErrUnknownFormat, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := ValidateMagicBytes(tt.data, tt.declaredType) if err != tt.wantErr { t.Errorf("ValidateMagicBytes() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestIsSupportedMIMEType(t *testing.T) { tests := []struct { mimeType string want bool }{ {"image/jpeg", true}, {"image/png", true}, {"image/webp", true}, {"image/gif", true}, {"image/avif", true}, {"image/svg+xml", true}, {"IMAGE/JPEG", true}, {"image/jpeg; charset=utf-8", true}, {"image/tiff", false}, {"image/bmp", false}, {"application/octet-stream", false}, {"text/plain", false}, {"", false}, } for _, tt := range tests { t.Run(tt.mimeType, func(t *testing.T) { if got := IsSupportedMIMEType(tt.mimeType); got != tt.want { t.Errorf("IsSupportedMIMEType(%q) = %v, want %v", tt.mimeType, got, tt.want) } }) } } func TestPeekAndValidate(t *testing.T) { jpegData := append([]byte{0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00, 0x01}, []byte("rest of jpeg data")...) pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D}, []byte("rest of png data")...) tests := []struct { name string data []byte declaredType string wantErr bool wantData []byte }{ { name: "valid JPEG", data: jpegData, declaredType: "image/jpeg", wantErr: false, wantData: jpegData, }, { name: "valid PNG", data: pngData, declaredType: "image/png", wantErr: false, wantData: pngData, }, { name: "mismatched type", data: jpegData, declaredType: "image/png", wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { r := bytes.NewReader(tt.data) result, err := PeekAndValidate(r, tt.declaredType) if tt.wantErr { if err == nil { t.Error("PeekAndValidate() expected error, got nil") } return } if err != nil { t.Errorf("PeekAndValidate() unexpected error = %v", err) return } // Read all data from result reader got, err := io.ReadAll(result) if err != nil { t.Errorf("Failed to read result: %v", err) return } if !bytes.Equal(got, tt.wantData) { t.Errorf("PeekAndValidate() data mismatch: got %d bytes, want %d bytes", len(got), len(tt.wantData)) } }) } } func TestMIMEToImageFormat(t *testing.T) { tests := []struct { mimeType string wantFormat ImageFormat wantOk bool }{ {"image/jpeg", FormatJPEG, true}, {"image/png", FormatPNG, true}, {"image/webp", FormatWebP, true}, {"image/gif", FormatGIF, true}, {"image/avif", FormatAVIF, true}, {"image/svg+xml", "", false}, // SVG doesn't convert to ImageFormat {"image/tiff", "", false}, {"text/plain", "", false}, } for _, tt := range tests { t.Run(tt.mimeType, func(t *testing.T) { got, ok := MIMEToImageFormat(tt.mimeType) if ok != tt.wantOk { t.Errorf("MIMEToImageFormat(%q) ok = %v, want %v", tt.mimeType, ok, tt.wantOk) } if got != tt.wantFormat { t.Errorf("MIMEToImageFormat(%q) = %v, want %v", tt.mimeType, got, tt.wantFormat) } }) } } func TestImageFormatToMIME(t *testing.T) { tests := []struct { format ImageFormat wantMIME string }{ {FormatJPEG, "image/jpeg"}, {FormatPNG, "image/png"}, {FormatWebP, "image/webp"}, {FormatGIF, "image/gif"}, {FormatAVIF, "image/avif"}, {FormatOriginal, "application/octet-stream"}, {"unknown", "application/octet-stream"}, } for _, tt := range tests { t.Run(string(tt.format), func(t *testing.T) { got := ImageFormatToMIME(tt.format) if got != tt.wantMIME { t.Errorf("ImageFormatToMIME(%q) = %v, want %v", tt.format, got, tt.wantMIME) } }) } } func TestNormalizeMIMEType(t *testing.T) { tests := []struct { input string want string }{ {"image/jpeg", "image/jpeg"}, {"IMAGE/JPEG", "image/jpeg"}, {"image/jpeg; charset=utf-8", "image/jpeg"}, {" image/jpeg ", "image/jpeg"}, {"image/jpeg; boundary=something", "image/jpeg"}, } for _, tt := range tests { t.Run(tt.input, func(t *testing.T) { got := normalizeMIMEType(tt.input) if got != tt.want { t.Errorf("normalizeMIMEType(%q) = %q, want %q", tt.input, got, tt.want) } }) } } func TestDetectSVG(t *testing.T) { tests := []struct { name string data string want bool }{ {"xml declaration", ``, true}, {"svg element", ``, true}, {"doctype", ``, true}, {"with whitespace", ` `, true}, {"uppercase", ``, true}, {"not svg", ``, false}, {"random text", `hello world`, false}, {"empty", ``, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := detectSVG([]byte(tt.data)) if got != tt.want { t.Errorf("detectSVG(%q) = %v, want %v", tt.data, got, tt.want) } }) } } func TestSkipBOM(t *testing.T) { tests := []struct { name string data []byte want []byte }{ {"with BOM", []byte{0xEF, 0xBB, 0xBF, 'h', 'e', 'l', 'l', 'o'}, []byte("hello")}, {"without BOM", []byte("hello"), []byte("hello")}, {"empty", []byte{}, []byte{}}, {"only BOM", []byte{0xEF, 0xBB, 0xBF}, []byte{}}, {"partial BOM", []byte{0xEF, 0xBB}, []byte{0xEF, 0xBB}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := skipBOM(tt.data) if !bytes.Equal(got, tt.want) { t.Errorf("skipBOM() = %v, want %v", got, tt.want) } }) } } func TestRealWorldSVGPatterns(t *testing.T) { // Test various real-world SVG patterns svgPatterns := []string{ ` `, ` `, ` `, } for i, pattern := range svgPatterns { data := []byte(pattern) if len(data) < MinMagicBytes { // Pad short SVGs for detection data = append(data, make([]byte, MinMagicBytes-len(data))...) } got, err := DetectFormat(data) if err != nil { t.Errorf("Pattern %d: DetectFormat() error = %v", i, err) continue } if got != MIMETypeSVG { t.Errorf("Pattern %d: DetectFormat() = %v, want %v", i, got, MIMETypeSVG) } } } func TestDetectFormatRIFFNotWebP(t *testing.T) { // RIFF container but not WebP (e.g., WAV file) wavData := []byte{ 0x52, 0x49, 0x46, 0x46, // RIFF 0x00, 0x00, 0x00, 0x00, // file size 0x57, 0x41, 0x56, 0x45, // WAVE (not WEBP) } _, err := DetectFormat(wavData) if err != ErrUnknownFormat { t.Errorf("DetectFormat(WAV) error = %v, want %v", err, ErrUnknownFormat) } } func TestDetectFormatFtypNotAVIF(t *testing.T) { // ftyp container but not AVIF (e.g., MP4) mp4Data := []byte{ 0x00, 0x00, 0x00, 0x1C, // box size 0x66, 0x74, 0x79, 0x70, // ftyp 0x69, 0x73, 0x6F, 0x6D, // isom brand (not avif) } _, err := DetectFormat(mp4Data) if err != ErrUnknownFormat { t.Errorf("DetectFormat(MP4) error = %v, want %v", err, ErrUnknownFormat) } } func TestPeekAndValidatePreservesReader(t *testing.T) { // Ensure that after PeekAndValidate, we can read the complete original content originalContent := append( []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D}, []byte(strings.Repeat("PNG IDAT chunk data here ", 100))..., ) r := bytes.NewReader(originalContent) validated, err := PeekAndValidate(r, "image/png") if err != nil { t.Fatalf("PeekAndValidate() error = %v", err) } // Read everything from the validated reader got, err := io.ReadAll(validated) if err != nil { t.Fatalf("io.ReadAll() error = %v", err) } if !bytes.Equal(got, originalContent) { t.Errorf("Content mismatch: got %d bytes, want %d bytes", len(got), len(originalContent)) } }