From 30d63e80dcc9adea68f77046821ec35a83c2f9d6 Mon Sep 17 00:00:00 2001 From: sneak Date: Thu, 8 Jan 2026 03:35:29 -0800 Subject: [PATCH] Add magic byte detection for image format validation Implements format detection by checking file magic bytes for JPEG, PNG, GIF, WebP, AVIF, and SVG. Includes validation against declared Content-Type. --- internal/imgcache/magic.go | 227 +++++++++++++++ internal/imgcache/magic_test.go | 497 ++++++++++++++++++++++++++++++++ 2 files changed, 724 insertions(+) create mode 100644 internal/imgcache/magic.go create mode 100644 internal/imgcache/magic_test.go diff --git a/internal/imgcache/magic.go b/internal/imgcache/magic.go new file mode 100644 index 0000000..3643cf8 --- /dev/null +++ b/internal/imgcache/magic.go @@ -0,0 +1,227 @@ +package imgcache + +import ( + "bytes" + "errors" + "io" + "strings" +) + +// Magic byte errors. +var ( + ErrUnknownFormat = errors.New("unknown image format") + ErrMagicByteMismatch = errors.New("content does not match declared Content-Type") + ErrNotEnoughData = errors.New("not enough data to detect format") +) + +// MIMEType represents a supported MIME type for input images. +type MIMEType string + +// Supported input MIME types. +const ( + MIMETypeJPEG = MIMEType("image/jpeg") + MIMETypePNG = MIMEType("image/png") + MIMETypeWebP = MIMEType("image/webp") + MIMETypeGIF = MIMEType("image/gif") + MIMETypeAVIF = MIMEType("image/avif") + MIMETypeSVG = MIMEType("image/svg+xml") +) + +// MinMagicBytes is the minimum number of bytes needed to detect format. +const MinMagicBytes = 12 + +// Magic byte signatures for supported formats. +// These are effectively constants but Go doesn't support const slices. +// +//nolint:gochecknoglobals // immutable lookup data +var ( + magicJPEG = []byte{0xFF, 0xD8, 0xFF} + magicPNG = []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A} + magicGIF = []byte{0x47, 0x49, 0x46, 0x38} // GIF8 (GIF87a or GIF89a) + magicWebP = []byte{0x52, 0x49, 0x46, 0x46} // RIFF (WebP starts with RIFF....WEBP) + // AVIF uses the ftyp box with brand "avif" or "avis" + // Format: size(4 bytes) + "ftyp" + brand(4 bytes) + magicFtyp = []byte{0x66, 0x74, 0x79, 0x70} // "ftyp" +) + +// WebP identifier appears at offset 8 after RIFF header. +// +//nolint:gochecknoglobals // immutable lookup data +var webpIdent = []byte{0x57, 0x45, 0x42, 0x50} // "WEBP" + +// AVIF brand identifiers. +// +//nolint:gochecknoglobals // immutable lookup data +var ( + avifBrand = []byte{0x61, 0x76, 0x69, 0x66} // "avif" + avisBrand = []byte{0x61, 0x76, 0x69, 0x73} // "avis" (AVIF sequence) +) + +// DetectFormat detects the image format from magic bytes. +// Returns the MIME type and nil error on success. +func DetectFormat(data []byte) (MIMEType, error) { + if len(data) < MinMagicBytes { + return "", ErrNotEnoughData + } + + // Check JPEG (FFD8FF) + if bytes.HasPrefix(data, magicJPEG) { + return MIMETypeJPEG, nil + } + + // Check PNG (89504E47 0D0A1A0A) + if bytes.HasPrefix(data, magicPNG) { + return MIMETypePNG, nil + } + + // Check GIF (GIF87a or GIF89a) + if bytes.HasPrefix(data, magicGIF) { + return MIMETypeGIF, nil + } + + // Check WebP (RIFF....WEBP) + if bytes.HasPrefix(data, magicWebP) && len(data) >= 12 { + if bytes.Equal(data[8:12], webpIdent) { + return MIMETypeWebP, nil + } + } + + // Check AVIF (....ftypavif or ....ftypavis) + // The ftyp box can start at offset 4 (after size bytes) + if len(data) >= 12 && bytes.Equal(data[4:8], magicFtyp) { + brand := data[8:12] + if bytes.Equal(brand, avifBrand) || bytes.Equal(brand, avisBrand) { + return MIMETypeAVIF, nil + } + } + + // Check SVG - look for XML declaration or SVG tag + if detectSVG(data) { + return MIMETypeSVG, nil + } + + return "", ErrUnknownFormat +} + +// detectSVG checks if data appears to be SVG content. +func detectSVG(data []byte) bool { + // Skip BOM if present + content := skipBOM(data) + + // Convert to string for easier pattern matching + s := strings.ToLower(string(content)) + + // Skip leading whitespace + s = strings.TrimSpace(s) + + // Check for XML declaration or SVG element + return strings.HasPrefix(s, "= 3 && data[0] == 0xEF && data[1] == 0xBB && data[2] == 0xBF { + return data[3:] + } + + return data +} + +// ValidateMagicBytes validates that the content matches the declared MIME type. +func ValidateMagicBytes(data []byte, declaredType string) error { + detected, err := DetectFormat(data) + if err != nil { + return err + } + + // Normalize the declared type (remove parameters like charset) + normalizedDeclared := normalizeMIMEType(declaredType) + + // Check if they match + if string(detected) != normalizedDeclared { + return ErrMagicByteMismatch + } + + return nil +} + +// normalizeMIMEType extracts just the media type, removing parameters. +func normalizeMIMEType(mimeType string) string { + // Handle "image/jpeg; charset=utf-8" -> "image/jpeg" + if idx := strings.Index(mimeType, ";"); idx != -1 { + mimeType = mimeType[:idx] + } + + return strings.TrimSpace(strings.ToLower(mimeType)) +} + +// IsSupportedMIMEType checks if a MIME type is supported for input. +func IsSupportedMIMEType(mimeType string) bool { + normalized := normalizeMIMEType(mimeType) + switch MIMEType(normalized) { + case MIMETypeJPEG, MIMETypePNG, MIMETypeWebP, MIMETypeGIF, MIMETypeAVIF, MIMETypeSVG: + return true + default: + return false + } +} + +// PeekAndValidate reads the minimum bytes needed for format detection, +// validates against the declared type, and returns a reader that includes +// those bytes for subsequent reading. +func PeekAndValidate(r io.Reader, declaredType string) (io.Reader, error) { + // Read minimum bytes for detection + buf := make([]byte, MinMagicBytes) + n, err := io.ReadFull(r, buf) + if err != nil && err != io.ErrUnexpectedEOF { + return nil, err + } + buf = buf[:n] + + // Validate magic bytes + if err := ValidateMagicBytes(buf, declaredType); err != nil { + return nil, err + } + + // Return a reader that includes the peeked bytes + return io.MultiReader(bytes.NewReader(buf), r), nil +} + +// MIMEToImageFormat converts a MIME type to our ImageFormat type. +func MIMEToImageFormat(mimeType string) (ImageFormat, bool) { + normalized := normalizeMIMEType(mimeType) + switch MIMEType(normalized) { + case MIMETypeJPEG: + return FormatJPEG, true + case MIMETypePNG: + return FormatPNG, true + case MIMETypeWebP: + return FormatWebP, true + case MIMETypeGIF: + return FormatGIF, true + case MIMETypeAVIF: + return FormatAVIF, true + default: + return "", false + } +} + +// ImageFormatToMIME converts our ImageFormat to a MIME type string. +func ImageFormatToMIME(format ImageFormat) string { + switch format { + case FormatJPEG: + return string(MIMETypeJPEG) + case FormatPNG: + return string(MIMETypePNG) + case FormatWebP: + return string(MIMETypeWebP) + case FormatGIF: + return string(MIMETypeGIF) + case FormatAVIF: + return string(MIMETypeAVIF) + default: + return "application/octet-stream" + } +} diff --git a/internal/imgcache/magic_test.go b/internal/imgcache/magic_test.go new file mode 100644 index 0000000..36d8bd5 --- /dev/null +++ b/internal/imgcache/magic_test.go @@ -0,0 +1,497 @@ +package imgcache + +import ( + "bytes" + "io" + "strings" + "testing" +) + +func TestDetectFormat(t *testing.T) { + tests := []struct { + name string + data []byte + wantMIME MIMEType + wantErr error + }{ + { + name: "JPEG", + data: append([]byte{0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00, 0x01}, make([]byte, 100)...), + wantMIME: MIMETypeJPEG, + wantErr: nil, + }, + { + name: "PNG", + data: append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D}, make([]byte, 100)...), + wantMIME: MIMETypePNG, + wantErr: nil, + }, + { + name: "GIF87a", + data: append([]byte{0x47, 0x49, 0x46, 0x38, 0x37, 0x61, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, make([]byte, 100)...), + wantMIME: MIMETypeGIF, + wantErr: nil, + }, + { + name: "GIF89a", + data: append([]byte{0x47, 0x49, 0x46, 0x38, 0x39, 0x61, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, make([]byte, 100)...), + wantMIME: MIMETypeGIF, + wantErr: nil, + }, + { + name: "WebP", + data: append([]byte{ + 0x52, 0x49, 0x46, 0x46, // RIFF + 0x00, 0x00, 0x00, 0x00, // file size (placeholder) + 0x57, 0x45, 0x42, 0x50, // WEBP + }, make([]byte, 100)...), + wantMIME: MIMETypeWebP, + wantErr: nil, + }, + { + name: "AVIF", + data: append([]byte{ + 0x00, 0x00, 0x00, 0x1C, // box size + 0x66, 0x74, 0x79, 0x70, // ftyp + 0x61, 0x76, 0x69, 0x66, // avif brand + }, make([]byte, 100)...), + wantMIME: MIMETypeAVIF, + wantErr: nil, + }, + { + name: "AVIF sequence", + data: append([]byte{ + 0x00, 0x00, 0x00, 0x1C, // box size + 0x66, 0x74, 0x79, 0x70, // ftyp + 0x61, 0x76, 0x69, 0x73, // avis brand + }, make([]byte, 100)...), + wantMIME: MIMETypeAVIF, + wantErr: nil, + }, + { + name: "SVG with XML declaration", + data: []byte(``), + wantMIME: MIMETypeSVG, + wantErr: nil, + }, + { + name: "SVG without declaration", + data: []byte(``), + wantMIME: MIMETypeSVG, + wantErr: nil, + }, + { + name: "SVG with whitespace", + data: []byte(` `), + wantMIME: MIMETypeSVG, + wantErr: nil, + }, + { + name: "SVG with BOM", + data: append([]byte{0xEF, 0xBB, 0xBF}, []byte(``)...), + wantMIME: MIMETypeSVG, + wantErr: nil, + }, + { + name: "unknown format", + data: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + wantMIME: "", + wantErr: ErrUnknownFormat, + }, + { + name: "too short", + data: []byte{0xFF, 0xD8}, + wantMIME: "", + wantErr: ErrNotEnoughData, + }, + { + name: "empty", + data: []byte{}, + wantMIME: "", + wantErr: ErrNotEnoughData, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := DetectFormat(tt.data) + + if err != tt.wantErr { + t.Errorf("DetectFormat() error = %v, wantErr %v", err, tt.wantErr) + + return + } + + if got != tt.wantMIME { + t.Errorf("DetectFormat() = %v, want %v", got, tt.wantMIME) + } + }) + } +} + +func TestValidateMagicBytes(t *testing.T) { + jpegData := append([]byte{0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00, 0x01}, make([]byte, 100)...) + pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D}, make([]byte, 100)...) + + tests := []struct { + name string + data []byte + declaredType string + wantErr error + }{ + { + name: "matching JPEG", + data: jpegData, + declaredType: "image/jpeg", + wantErr: nil, + }, + { + name: "matching JPEG with params", + data: jpegData, + declaredType: "image/jpeg; charset=utf-8", + wantErr: nil, + }, + { + name: "matching PNG", + data: pngData, + declaredType: "image/png", + wantErr: nil, + }, + { + name: "mismatched type", + data: jpegData, + declaredType: "image/png", + wantErr: ErrMagicByteMismatch, + }, + { + name: "unknown data", + data: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + declaredType: "image/jpeg", + wantErr: ErrUnknownFormat, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateMagicBytes(tt.data, tt.declaredType) + + if err != tt.wantErr { + t.Errorf("ValidateMagicBytes() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestIsSupportedMIMEType(t *testing.T) { + tests := []struct { + mimeType string + want bool + }{ + {"image/jpeg", true}, + {"image/png", true}, + {"image/webp", true}, + {"image/gif", true}, + {"image/avif", true}, + {"image/svg+xml", true}, + {"IMAGE/JPEG", true}, + {"image/jpeg; charset=utf-8", true}, + {"image/tiff", false}, + {"image/bmp", false}, + {"application/octet-stream", false}, + {"text/plain", false}, + {"", false}, + } + + for _, tt := range tests { + t.Run(tt.mimeType, func(t *testing.T) { + if got := IsSupportedMIMEType(tt.mimeType); got != tt.want { + t.Errorf("IsSupportedMIMEType(%q) = %v, want %v", tt.mimeType, got, tt.want) + } + }) + } +} + +func TestPeekAndValidate(t *testing.T) { + jpegData := append([]byte{0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00, 0x01}, []byte("rest of jpeg data")...) + pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D}, []byte("rest of png data")...) + + tests := []struct { + name string + data []byte + declaredType string + wantErr bool + wantData []byte + }{ + { + name: "valid JPEG", + data: jpegData, + declaredType: "image/jpeg", + wantErr: false, + wantData: jpegData, + }, + { + name: "valid PNG", + data: pngData, + declaredType: "image/png", + wantErr: false, + wantData: pngData, + }, + { + name: "mismatched type", + data: jpegData, + declaredType: "image/png", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := bytes.NewReader(tt.data) + result, err := PeekAndValidate(r, tt.declaredType) + + if tt.wantErr { + if err == nil { + t.Error("PeekAndValidate() expected error, got nil") + } + + return + } + + if err != nil { + t.Errorf("PeekAndValidate() unexpected error = %v", err) + + return + } + + // Read all data from result reader + got, err := io.ReadAll(result) + if err != nil { + t.Errorf("Failed to read result: %v", err) + + return + } + + if !bytes.Equal(got, tt.wantData) { + t.Errorf("PeekAndValidate() data mismatch: got %d bytes, want %d bytes", len(got), len(tt.wantData)) + } + }) + } +} + +func TestMIMEToImageFormat(t *testing.T) { + tests := []struct { + mimeType string + wantFormat ImageFormat + wantOk bool + }{ + {"image/jpeg", FormatJPEG, true}, + {"image/png", FormatPNG, true}, + {"image/webp", FormatWebP, true}, + {"image/gif", FormatGIF, true}, + {"image/avif", FormatAVIF, true}, + {"image/svg+xml", "", false}, // SVG doesn't convert to ImageFormat + {"image/tiff", "", false}, + {"text/plain", "", false}, + } + + for _, tt := range tests { + t.Run(tt.mimeType, func(t *testing.T) { + got, ok := MIMEToImageFormat(tt.mimeType) + + if ok != tt.wantOk { + t.Errorf("MIMEToImageFormat(%q) ok = %v, want %v", tt.mimeType, ok, tt.wantOk) + } + + if got != tt.wantFormat { + t.Errorf("MIMEToImageFormat(%q) = %v, want %v", tt.mimeType, got, tt.wantFormat) + } + }) + } +} + +func TestImageFormatToMIME(t *testing.T) { + tests := []struct { + format ImageFormat + wantMIME string + }{ + {FormatJPEG, "image/jpeg"}, + {FormatPNG, "image/png"}, + {FormatWebP, "image/webp"}, + {FormatGIF, "image/gif"}, + {FormatAVIF, "image/avif"}, + {FormatOriginal, "application/octet-stream"}, + {"unknown", "application/octet-stream"}, + } + + for _, tt := range tests { + t.Run(string(tt.format), func(t *testing.T) { + got := ImageFormatToMIME(tt.format) + + if got != tt.wantMIME { + t.Errorf("ImageFormatToMIME(%q) = %v, want %v", tt.format, got, tt.wantMIME) + } + }) + } +} + +func TestNormalizeMIMEType(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"image/jpeg", "image/jpeg"}, + {"IMAGE/JPEG", "image/jpeg"}, + {"image/jpeg; charset=utf-8", "image/jpeg"}, + {" image/jpeg ", "image/jpeg"}, + {"image/jpeg; boundary=something", "image/jpeg"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := normalizeMIMEType(tt.input) + + if got != tt.want { + t.Errorf("normalizeMIMEType(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestDetectSVG(t *testing.T) { + tests := []struct { + name string + data string + want bool + }{ + {"xml declaration", ``, true}, + {"svg element", ``, true}, + {"doctype", ``, true}, + {"with whitespace", ` + `, true}, + {"uppercase", ``, true}, + {"not svg", ``, false}, + {"random text", `hello world`, false}, + {"empty", ``, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := detectSVG([]byte(tt.data)) + + if got != tt.want { + t.Errorf("detectSVG(%q) = %v, want %v", tt.data, got, tt.want) + } + }) + } +} + +func TestSkipBOM(t *testing.T) { + tests := []struct { + name string + data []byte + want []byte + }{ + {"with BOM", []byte{0xEF, 0xBB, 0xBF, 'h', 'e', 'l', 'l', 'o'}, []byte("hello")}, + {"without BOM", []byte("hello"), []byte("hello")}, + {"empty", []byte{}, []byte{}}, + {"only BOM", []byte{0xEF, 0xBB, 0xBF}, []byte{}}, + {"partial BOM", []byte{0xEF, 0xBB}, []byte{0xEF, 0xBB}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := skipBOM(tt.data) + + if !bytes.Equal(got, tt.want) { + t.Errorf("skipBOM() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestRealWorldSVGPatterns(t *testing.T) { + // Test various real-world SVG patterns + svgPatterns := []string{ + ` + + +`, + ` + +`, + ` + +`, + } + + for i, pattern := range svgPatterns { + data := []byte(pattern) + if len(data) < MinMagicBytes { + // Pad short SVGs for detection + data = append(data, make([]byte, MinMagicBytes-len(data))...) + } + + got, err := DetectFormat(data) + if err != nil { + t.Errorf("Pattern %d: DetectFormat() error = %v", i, err) + + continue + } + + if got != MIMETypeSVG { + t.Errorf("Pattern %d: DetectFormat() = %v, want %v", i, got, MIMETypeSVG) + } + } +} + +func TestDetectFormatRIFFNotWebP(t *testing.T) { + // RIFF container but not WebP (e.g., WAV file) + wavData := []byte{ + 0x52, 0x49, 0x46, 0x46, // RIFF + 0x00, 0x00, 0x00, 0x00, // file size + 0x57, 0x41, 0x56, 0x45, // WAVE (not WEBP) + } + + _, err := DetectFormat(wavData) + if err != ErrUnknownFormat { + t.Errorf("DetectFormat(WAV) error = %v, want %v", err, ErrUnknownFormat) + } +} + +func TestDetectFormatFtypNotAVIF(t *testing.T) { + // ftyp container but not AVIF (e.g., MP4) + mp4Data := []byte{ + 0x00, 0x00, 0x00, 0x1C, // box size + 0x66, 0x74, 0x79, 0x70, // ftyp + 0x69, 0x73, 0x6F, 0x6D, // isom brand (not avif) + } + + _, err := DetectFormat(mp4Data) + if err != ErrUnknownFormat { + t.Errorf("DetectFormat(MP4) error = %v, want %v", err, ErrUnknownFormat) + } +} + +func TestPeekAndValidatePreservesReader(t *testing.T) { + // Ensure that after PeekAndValidate, we can read the complete original content + originalContent := append( + []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D}, + []byte(strings.Repeat("PNG IDAT chunk data here ", 100))..., + ) + + r := bytes.NewReader(originalContent) + validated, err := PeekAndValidate(r, "image/png") + if err != nil { + t.Fatalf("PeekAndValidate() error = %v", err) + } + + // Read everything from the validated reader + got, err := io.ReadAll(validated) + if err != nil { + t.Fatalf("io.ReadAll() error = %v", err) + } + + if !bytes.Equal(got, originalContent) { + t.Errorf("Content mismatch: got %d bytes, want %d bytes", len(got), len(originalContent)) + } +}