package mfer

import (
	"bytes"
	"crypto/sha256"
	"errors"
	"io"

	"github.com/google/uuid"
	"github.com/klauspost/compress/zstd"
	"github.com/spf13/afero"
	"google.golang.org/protobuf/proto"
	"sneak.berlin/go/mfer/internal/bork"
	"sneak.berlin/go/mfer/internal/log"
)

// validateUUID checks that data is a well-formed UUID: exactly 16 bytes that
// uuid.FromBytes accepts. It returns a descriptive error otherwise.
func validateUUID(data []byte) error {
	if len(data) != 16 {
		return errors.New("invalid UUID length")
	}
	if _, err := uuid.FromBytes(data); err != nil {
		return errors.New("invalid UUID format")
	}
	return nil
}

// deserializeInner validates the already-decoded outer message in m.pbOuter,
// decompresses its inner payload, and decodes that payload into m.pbInner.
//
// Checks are performed in this order, each returning an error on failure:
//   - the outer version and compression type must be the supported values;
//   - the outer UUID must be well-formed;
//   - the SHA-256 of the compressed payload must equal the recorded hash
//     (verified before any decompression work is done);
//   - the decompressed payload must be exactly m.pbOuter.Size bytes, otherwise
//     bork.ErrFileTruncated is returned;
//   - the inner message must unmarshal, carry a well-formed UUID, and that
//     UUID must equal the outer one.
func (m *manifest) deserializeInner() error {
	if m.pbOuter.Version != MFFileOuter_VERSION_ONE {
		return errors.New("unknown version")
	}
	if m.pbOuter.CompressionType != MFFileOuter_COMPRESSION_ZSTD {
		return errors.New("unknown compression type")
	}

	// Validate the outer UUID before any decompression.
	if err := validateUUID(m.pbOuter.Uuid); err != nil {
		return errors.New("outer UUID invalid: " + err.Error())
	}

	// Verify the hash of the compressed data before decompressing it.
	sum := sha256.Sum256(m.pbOuter.InnerMessage)
	if !bytes.Equal(sum[:], m.pbOuter.Sha256) {
		return errors.New("compressed data hash mismatch")
	}

	zr, err := zstd.NewReader(bytes.NewReader(m.pbOuter.InnerMessage))
	if err != nil {
		return err
	}
	defer zr.Close()

	// The payload is untrusted: a tiny compressed blob can inflate to an
	// enormous size. Read at most one byte more than the declared size so an
	// oversize payload is still detected (and rejected below) without being
	// inflated in full. A negative declared size yields a non-positive limit,
	// so nothing is read and the size check below fails.
	const maxInt64 = 1<<63 - 1
	want := m.pbOuter.Size
	limit := want
	if limit < maxInt64 {
		limit++
	}
	dat, err := io.ReadAll(io.LimitReader(zr, limit))
	if err != nil {
		return err
	}

	// For oversize payloads the reported "got" value is capped at want+1.
	if got := len(dat); int64(got) != want {
		log.Debugf("truncated data, got %d expected %d", got, want)
		return bork.ErrFileTruncated
	}

	// Deserialize the inner message.
	m.pbInner = new(MFFile)
	if err := proto.Unmarshal(dat, m.pbInner); err != nil {
		return err
	}

	// Validate the inner UUID and make sure it matches the outer one.
	if err := validateUUID(m.pbInner.Uuid); err != nil {
		return errors.New("inner UUID invalid: " + err.Error())
	}
	if !bytes.Equal(m.pbOuter.Uuid, m.pbInner.Uuid) {
		return errors.New("outer and inner UUID mismatch")
	}

	log.Infof("loaded manifest with %d files", len(m.pbInner.Files))
	return nil
}

// validateMagic reports whether dat begins with the MAGIC byte sequence.
func validateMagic(dat []byte) bool {
	return bytes.HasPrefix(dat, []byte(MAGIC))
}

// NewManifestFromReader reads a manifest from an io.Reader.
//
// The whole input is read into memory. It must start with the MAGIC prefix,
// followed by a serialized MFFileOuter message whose inner payload passes all
// checks in deserializeInner. On any failure a nil manifest and the error are
// returned.
func NewManifestFromReader(input io.Reader) (*manifest, error) {
	dat, err := io.ReadAll(input)
	if err != nil {
		return nil, err
	}
	if !validateMagic(dat) {
		return nil, errors.New("invalid file format")
	}

	// Everything after the magic prefix is the serialized outer message.
	payload := dat[len([]byte(MAGIC)):]

	m := &manifest{}

	// Deserialize outer.
	m.pbOuter = new(MFFileOuter)
	if err := proto.Unmarshal(payload, m.pbOuter); err != nil {
		return nil, err
	}

	// Deserialize inner.
	if err := m.deserializeInner(); err != nil {
		return nil, err
	}
	return m, nil
}

// NewManifestFromFile reads a manifest from a file path using the given filesystem.
// If fs is nil, the real filesystem (OsFs) is used.
func NewManifestFromFile(fs afero.Fs, path string) (*manifest, error) {
	if fs == nil {
		fs = afero.NewOsFs()
	}
	f, err := fs.Open(path)
	if err != nil {
		return nil, err
	}
	// The file is only read, so a Close error carries no useful information.
	defer func() { _ = f.Close() }()

	return NewManifestFromReader(f)
}