package imgcache
import (
	"bytes"
	"errors"
	"io"
	"strings"
	"testing"
)
// TestDetectFormat feeds DetectFormat a recognizable header for every
// supported format (JPEG, PNG, GIF87a/89a, WebP, AVIF still and sequence, and
// several SVG prologue variants) plus inputs that must be rejected. It checks
// both the returned MIME type and the sentinel error.
func TestDetectFormat(t *testing.T) {
	// pad appends 100 zero bytes after a binary header so each fixture is well
	// past the minimum sniff length.
	pad := func(header ...byte) []byte {
		return append(header, make([]byte, 100)...)
	}
	// svgDoc is a minimal but complete SVG document used by the text-based cases.
	const svgDoc = `<svg xmlns="http://www.w3.org/2000/svg" width="10" height="10"></svg>`

	tests := []struct {
		name     string
		data     []byte
		wantMIME MIMEType
		wantErr  error
	}{
		{
			name:     "JPEG",
			data:     pad(0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00, 0x01),
			wantMIME: MIMETypeJPEG,
		},
		{
			name:     "PNG",
			data:     pad(0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D),
			wantMIME: MIMETypePNG,
		},
		{
			name:     "GIF87a",
			data:     pad(0x47, 0x49, 0x46, 0x38, 0x37, 0x61, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00),
			wantMIME: MIMETypeGIF,
		},
		{
			name:     "GIF89a",
			data:     pad(0x47, 0x49, 0x46, 0x38, 0x39, 0x61, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00),
			wantMIME: MIMETypeGIF,
		},
		{
			name: "WebP",
			data: pad(
				0x52, 0x49, 0x46, 0x46, // RIFF
				0x00, 0x00, 0x00, 0x00, // file size (placeholder)
				0x57, 0x45, 0x42, 0x50, // WEBP
			),
			wantMIME: MIMETypeWebP,
		},
		{
			name: "AVIF",
			data: pad(
				0x00, 0x00, 0x00, 0x1C, // box size
				0x66, 0x74, 0x79, 0x70, // ftyp
				0x61, 0x76, 0x69, 0x66, // avif brand
			),
			wantMIME: MIMETypeAVIF,
		},
		{
			name: "AVIF sequence",
			data: pad(
				0x00, 0x00, 0x00, 0x1C, // box size
				0x66, 0x74, 0x79, 0x70, // ftyp
				0x61, 0x76, 0x69, 0x73, // avis brand
			),
			wantMIME: MIMETypeAVIF,
		},
		{
			name:     "SVG with XML declaration",
			data:     []byte(`<?xml version="1.0" encoding="UTF-8"?>` + svgDoc),
			wantMIME: MIMETypeSVG,
		},
		{
			name:     "SVG without declaration",
			data:     []byte(svgDoc),
			wantMIME: MIMETypeSVG,
		},
		{
			name:     "SVG with whitespace",
			data:     []byte("  \n\t" + svgDoc),
			wantMIME: MIMETypeSVG,
		},
		{
			name:     "SVG with BOM",
			data:     append([]byte{0xEF, 0xBB, 0xBF}, svgDoc...),
			wantMIME: MIMETypeSVG,
		},
		{
			name:    "unknown format",
			data:    []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
			wantErr: ErrUnknownFormat,
		},
		{
			name:    "too short",
			data:    []byte{0xFF, 0xD8},
			wantErr: ErrNotEnoughData,
		},
		{
			name:    "empty",
			data:    []byte{},
			wantErr: ErrNotEnoughData,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := DetectFormat(tt.data)
			// errors.Is also matches sentinels wrapped with %w. With a nil
			// target it reduces to err == nil.
			if !errors.Is(err, tt.wantErr) {
				t.Errorf("DetectFormat() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if got != tt.wantMIME {
				t.Errorf("DetectFormat() = %v, want %v", got, tt.wantMIME)
			}
		})
	}
}
// TestValidateMagicBytes checks ValidateMagicBytes in three situations:
//   - it accepts data whose sniffed format matches the declared type,
//     including a declared type that carries parameters;
//   - it reports ErrMagicByteMismatch when the types disagree;
//   - it reports ErrUnknownFormat when the data cannot be sniffed.
func TestValidateMagicBytes(t *testing.T) {
	jpegData := append([]byte{0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00, 0x01}, make([]byte, 100)...)
	pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D}, make([]byte, 100)...)
	tests := []struct {
		name         string
		data         []byte
		declaredType string
		wantErr      error
	}{
		{
			name:         "matching JPEG",
			data:         jpegData,
			declaredType: "image/jpeg",
			wantErr:      nil,
		},
		{
			name:         "matching JPEG with params",
			data:         jpegData,
			declaredType: "image/jpeg; charset=utf-8",
			wantErr:      nil,
		},
		{
			name:         "matching PNG",
			data:         pngData,
			declaredType: "image/png",
			wantErr:      nil,
		},
		{
			name:         "mismatched type",
			data:         jpegData,
			declaredType: "image/png",
			wantErr:      ErrMagicByteMismatch,
		},
		{
			name:         "unknown data",
			data:         []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
			declaredType: "image/jpeg",
			wantErr:      ErrUnknownFormat,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := ValidateMagicBytes(tt.data, tt.declaredType)
			// Compare with errors.Is so a sentinel wrapped with context still
			// matches. A nil wantErr requires err == nil.
			if !errors.Is(err, tt.wantErr) {
				t.Errorf("ValidateMagicBytes() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}
// TestIsSupportedMIMEType lists MIME strings and whether IsSupportedMIMEType
// should accept them. The list includes an upper-case spelling and one with a
// trailing parameter, both of which are expected to be accepted.
func TestIsSupportedMIMEType(t *testing.T) {
	cases := []struct {
		in        string
		supported bool
	}{
		// Supported image types.
		{in: "image/jpeg", supported: true},
		{in: "image/png", supported: true},
		{in: "image/webp", supported: true},
		{in: "image/gif", supported: true},
		{in: "image/avif", supported: true},
		{in: "image/svg+xml", supported: true},
		// Case and parameters must not matter.
		{in: "IMAGE/JPEG", supported: true},
		{in: "image/jpeg; charset=utf-8", supported: true},
		// Everything else is rejected.
		{in: "image/tiff", supported: false},
		{in: "image/bmp", supported: false},
		{in: "application/octet-stream", supported: false},
		{in: "text/plain", supported: false},
		{in: "", supported: false},
	}
	for _, tc := range cases {
		t.Run(tc.in, func(t *testing.T) {
			got := IsSupportedMIMEType(tc.in)
			if got == tc.supported {
				return
			}
			t.Errorf("IsSupportedMIMEType(%q) = %v, want %v", tc.in, got, tc.supported)
		})
	}
}
// TestPeekAndValidate checks two things:
//   - on success, PeekAndValidate returns a reader that yields the full
//     original input;
//   - it returns an error when the declared type does not match the data.
func TestPeekAndValidate(t *testing.T) {
	jpegHeader := []byte{0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00, 0x01}
	pngHeader := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D}
	jpegBody := append(jpegHeader, "rest of jpeg data"...)
	pngBody := append(pngHeader, "rest of png data"...)

	cases := []struct {
		name         string
		data         []byte
		declaredType string
		expectErr    bool
		expect       []byte
	}{
		{name: "valid JPEG", data: jpegBody, declaredType: "image/jpeg", expect: jpegBody},
		{name: "valid PNG", data: pngBody, declaredType: "image/png", expect: pngBody},
		{name: "mismatched type", data: jpegBody, declaredType: "image/png", expectErr: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			out, err := PeekAndValidate(bytes.NewReader(tc.data), tc.declaredType)
			switch {
			case tc.expectErr && err == nil:
				t.Fatal("PeekAndValidate() expected error, got nil")
			case tc.expectErr:
				return
			case err != nil:
				t.Fatalf("PeekAndValidate() unexpected error = %v", err)
			}

			// Drain the returned reader; the peeked header must be replayed.
			got, err := io.ReadAll(out)
			if err != nil {
				t.Fatalf("Failed to read result: %v", err)
			}
			if !bytes.Equal(got, tc.expect) {
				t.Errorf("PeekAndValidate() data mismatch: got %d bytes, want %d bytes", len(got), len(tc.expect))
			}
		})
	}
}
// TestMIMEToImageFormat checks the MIME-to-ImageFormat mapping. Raster types
// map to their format with ok=true. SVG and unsupported types yield the empty
// format with ok=false.
func TestMIMEToImageFormat(t *testing.T) {
	cases := []struct {
		mime   string
		format ImageFormat
		ok     bool
	}{
		{mime: "image/jpeg", format: FormatJPEG, ok: true},
		{mime: "image/png", format: FormatPNG, ok: true},
		{mime: "image/webp", format: FormatWebP, ok: true},
		{mime: "image/gif", format: FormatGIF, ok: true},
		{mime: "image/avif", format: FormatAVIF, ok: true},
		// SVG has no ImageFormat counterpart.
		{mime: "image/svg+xml", format: "", ok: false},
		{mime: "image/tiff", format: "", ok: false},
		{mime: "text/plain", format: "", ok: false},
	}
	for _, tc := range cases {
		t.Run(tc.mime, func(t *testing.T) {
			format, ok := MIMEToImageFormat(tc.mime)
			if ok != tc.ok {
				t.Errorf("MIMEToImageFormat(%q) ok = %v, want %v", tc.mime, ok, tc.ok)
			}
			if format != tc.format {
				t.Errorf("MIMEToImageFormat(%q) = %v, want %v", tc.mime, format, tc.format)
			}
		})
	}
}
// TestImageFormatToMIME checks the reverse mapping. Known formats map to their
// MIME type. FormatOriginal and unrecognized values fall back to
// "application/octet-stream".
func TestImageFormatToMIME(t *testing.T) {
	const fallback = "application/octet-stream"
	cases := []struct {
		in  ImageFormat
		out string
	}{
		{in: FormatJPEG, out: "image/jpeg"},
		{in: FormatPNG, out: "image/png"},
		{in: FormatWebP, out: "image/webp"},
		{in: FormatGIF, out: "image/gif"},
		{in: FormatAVIF, out: "image/avif"},
		{in: FormatOriginal, out: fallback},
		{in: "unknown", out: fallback},
	}
	for _, tc := range cases {
		t.Run(string(tc.in), func(t *testing.T) {
			if mime := ImageFormatToMIME(tc.in); mime != tc.out {
				t.Errorf("ImageFormatToMIME(%q) = %v, want %v", tc.in, mime, tc.out)
			}
		})
	}
}
// TestNormalizeMIMEType checks that normalizeMIMEType does three things:
// lower-cases the input, trims surrounding spaces, and drops any
// ";"-separated parameters.
func TestNormalizeMIMEType(t *testing.T) {
	cases := []struct {
		raw        string
		normalized string
	}{
		{raw: "image/jpeg", normalized: "image/jpeg"},
		{raw: "IMAGE/JPEG", normalized: "image/jpeg"},
		{raw: "image/jpeg; charset=utf-8", normalized: "image/jpeg"},
		{raw: " image/jpeg ", normalized: "image/jpeg"},
		{raw: "image/jpeg; boundary=something", normalized: "image/jpeg"},
	}
	for _, tc := range cases {
		t.Run(tc.raw, func(t *testing.T) {
			if out := normalizeMIMEType(tc.raw); out != tc.normalized {
				t.Errorf("normalizeMIMEType(%q) = %q, want %q", tc.raw, out, tc.normalized)
			}
		})
	}
}
// TestDetectSVG checks detectSVG on the SVG prologue shapes it must recognize
// and on inputs it must reject. Recognized shapes are an XML declaration, a
// bare root element, a DOCTYPE, leading whitespace and an upper-case tag.
// Rejected inputs are other markup, plain text and empty input.
func TestDetectSVG(t *testing.T) {
	tests := []struct {
		name string
		data string
		want bool
	}{
		{"xml declaration", `<?xml version="1.0" encoding="UTF-8"?><svg xmlns="http://www.w3.org/2000/svg"></svg>`, true},
		{"svg element", `<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 10 10"></svg>`, true},
		{"doctype", `<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"><svg xmlns="http://www.w3.org/2000/svg"></svg>`, true},
		{"with whitespace", "\n\t  " + `<svg xmlns="http://www.w3.org/2000/svg"></svg>`, true},
		{"uppercase", `<SVG xmlns="http://www.w3.org/2000/svg"></SVG>`, true},
		// Well-formed markup that is not SVG must not be detected.
		{"not svg", `<html><body><p>hello</p></body></html>`, false},
		{"random text", `hello world`, false},
		{"empty", ``, false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := detectSVG([]byte(tt.data))
			if got != tt.want {
				t.Errorf("detectSVG(%q) = %v, want %v", tt.data, got, tt.want)
			}
		})
	}
}
// TestSkipBOM checks skipBOM's handling of the UTF-8 byte order mark
// (EF BB BF). A full leading BOM is removed. Input without a BOM, or with only
// its first two bytes, is returned unchanged.
func TestSkipBOM(t *testing.T) {
	cases := []struct {
		name string
		in   []byte
		out  []byte
	}{
		{name: "with BOM", in: []byte("\xEF\xBB\xBFhello"), out: []byte("hello")},
		{name: "without BOM", in: []byte("hello"), out: []byte("hello")},
		{name: "empty", in: []byte{}, out: []byte{}},
		{name: "only BOM", in: []byte("\xEF\xBB\xBF"), out: []byte{}},
		// A truncated BOM is not a BOM and must be kept.
		{name: "partial BOM", in: []byte("\xEF\xBB"), out: []byte("\xEF\xBB")},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := skipBOM(tc.in); !bytes.Equal(got, tc.out) {
				t.Errorf("skipBOM() = %v, want %v", got, tc.out)
			}
		})
	}
}
// TestRealWorldSVGPatterns runs DetectFormat over SVG documents shaped like
// typical editor exports and minifier output, and expects every one to be
// reported as MIMETypeSVG.
func TestRealWorldSVGPatterns(t *testing.T) {
	svgPatterns := []string{
		// Editor export: XML declaration, then the root element on its own line.
		`<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24"><path d="M0 0h24v24H0z"/></svg>`,
		// Minified single-line document with a self-closing root element.
		`<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1 1"/>`,
		// Legacy SVG 1.1 DOCTYPE ahead of the root element.
		`<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg xmlns="http://www.w3.org/2000/svg" version="1.1"></svg>`,
	}
	for i, pattern := range svgPatterns {
		data := []byte(pattern)
		if len(data) < MinMagicBytes {
			// Pad short SVGs so DetectFormat does not reject them as too short.
			data = append(data, make([]byte, MinMagicBytes-len(data))...)
		}
		got, err := DetectFormat(data)
		if err != nil {
			t.Errorf("Pattern %d: DetectFormat() error = %v", i, err)
			continue
		}
		if got != MIMETypeSVG {
			t.Errorf("Pattern %d: DetectFormat() = %v, want %v", i, got, MIMETypeSVG)
		}
	}
}
// TestDetectFormatRIFFNotWebP checks that a RIFF container whose form type is
// not "WEBP" (here a WAV header) is rejected with ErrUnknownFormat, not
// misdetected as WebP.
func TestDetectFormatRIFFNotWebP(t *testing.T) {
	wavData := []byte{
		0x52, 0x49, 0x46, 0x46, // RIFF
		0x00, 0x00, 0x00, 0x00, // file size
		0x57, 0x41, 0x56, 0x45, // WAVE (not WEBP)
	}
	// errors.Is keeps the check valid if DetectFormat wraps the sentinel.
	if _, err := DetectFormat(wavData); !errors.Is(err, ErrUnknownFormat) {
		t.Errorf("DetectFormat(WAV) error = %v, want %v", err, ErrUnknownFormat)
	}
}
// TestDetectFormatFtypNotAVIF checks that an ISO-BMFF "ftyp" box with a
// non-AVIF major brand (here the MP4 "isom" brand) is rejected with
// ErrUnknownFormat, not misdetected as AVIF.
func TestDetectFormatFtypNotAVIF(t *testing.T) {
	mp4Data := []byte{
		0x00, 0x00, 0x00, 0x1C, // box size
		0x66, 0x74, 0x79, 0x70, // ftyp
		0x69, 0x73, 0x6F, 0x6D, // isom brand (not avif)
	}
	// errors.Is keeps the check valid if DetectFormat wraps the sentinel.
	if _, err := DetectFormat(mp4Data); !errors.Is(err, ErrUnknownFormat) {
		t.Errorf("DetectFormat(MP4) error = %v, want %v", err, ErrUnknownFormat)
	}
}
// TestPeekAndValidatePreservesReader checks that peeking does not consume the
// input. It passes a PNG payload longer than the sniffed header through
// PeekAndValidate, and the returned reader must yield exactly those bytes.
func TestPeekAndValidatePreservesReader(t *testing.T) {
	pngHeader := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D}
	body := strings.Repeat("PNG IDAT chunk data here ", 100)
	want := append(pngHeader, body...)

	validated, err := PeekAndValidate(bytes.NewReader(want), "image/png")
	if err != nil {
		t.Fatalf("PeekAndValidate() error = %v", err)
	}

	got, err := io.ReadAll(validated)
	if err != nil {
		t.Fatalf("io.ReadAll() error = %v", err)
	}
	if bytes.Equal(got, want) {
		return
	}
	t.Errorf("Content mismatch: got %d bytes, want %d bytes", len(got), len(want))
}