package mfer

import (
	"bytes"
	"compress/gzip"
	"crypto/sha256"
	"errors"
	"io"

	"git.eeqj.de/sneak/mfer/internal/bork"
	"git.eeqj.de/sneak/mfer/internal/log"
	"google.golang.org/protobuf/proto"
)

// validateProtoInner checks the decoded inner MFFile message (m.pbInner)
// and populates m.files from its file entries via addFileLoadTime.
//
// It returns an error if the version is unknown, if the manifest lists no
// files, or if m.files was already non-nil (it must only be called while
// loading into a freshly created manifest).
func (m *manifest) validateProtoInner() error {
	inner := m.pbInner
	if inner.Version != MFFile_VERSION_ONE {
		return errors.New("unknown version") // FIXME move to bork
	}
	if len(inner.Files) == 0 {
		return errors.New("manifest without files") // FIXME move to bork
	}

	// there is no way we should be doing validateProtoInner()
	// outside of a load into a blank/empty/new *manifest.
	// This guard runs once, before the loop: m.files is initialized just
	// below, so checking it per-file would reject any multi-file manifest.
	if m.files != nil {
		return errors.New("shouldn't happen, internal error")
	}
	m.files = make([]*manifestFile, 0, len(inner.Files))

	for _, mfp := range inner.Files {
		// we can skip error handling here thanks to the magic of protobuf
		m.addFileLoadTime(mfp)
	}
	return nil
}

// validateProtoOuter checks the outer MFFileOuter envelope (m.pbOuter).
// The version and compression type must be known. InnerMessage must
// gunzip to exactly Size bytes, and the SHA-256 of those bytes must equal
// Sha256. It returns nil when all checks pass.
func (m *manifest) validateProtoOuter() error {
	_, err := m.inflateVerifiedInner()
	return err
}

// inflateVerifiedInner performs the envelope checks described on
// validateProtoOuter. On success it returns the decompressed inner message
// bytes, ready to be unmarshaled into an MFFile.
//
// Errors:
//   - a plain error for an unknown version or compression type;
//   - any error from the gzip reader (bad header, corrupt stream, bad CRC);
//   - bork.ErrFileTruncated when the decompressed length differs from Size;
//   - bork.ErrFileIntegrity when the SHA-256 does not match. In this case
//     m.pbOuter.InnerMessage is also cleared.
func (m *manifest) inflateVerifiedInner() ([]byte, error) {
	outer := m.pbOuter
	if outer.Version != MFFileOuter_VERSION_ONE {
		return nil, errors.New("unknown version") // FIXME move to bork
	}
	if outer.CompressionType != MFFileOuter_COMPRESSION_GZIP {
		return nil, errors.New("unknown compression type") // FIXME move to bork
	}

	gzr, err := gzip.NewReader(bytes.NewReader(outer.InnerMessage))
	if err != nil {
		return nil, err
	}
	defer gzr.Close()

	// The file contents are untrusted. Read at most one byte past the
	// declared size so a decompression bomb cannot exhaust memory. Reading
	// more than Size bytes still fails the length check below. When the
	// stream is exactly Size bytes, the reader still reaches gzip EOF, so
	// the gzip CRC is verified.
	dat, err := io.ReadAll(io.LimitReader(gzr, outer.Size+1))
	if err != nil {
		return nil, err
	}

	isize := len(dat)
	if int64(isize) != outer.Size {
		log.Tracef("truncated data, got %d expected %d", isize, outer.Size)
		return nil, bork.ErrFileTruncated
	}
	log.Tracef("inner data size is %d", isize)
	log.TraceDump(dat)

	shaGot := sha256.Sum256(dat)
	log.TraceDump("got: ", shaGot[:])
	log.TraceDump("expected: ", outer.Sha256)
	if !bytes.Equal(shaGot[:], outer.Sha256) {
		outer.InnerMessage = nil // don't try to mess with it
		return nil, bork.ErrFileIntegrity
	}
	return dat, nil
}

// validateMagic reports whether dat begins with the MAGIC file signature.
func validateMagic(dat []byte) bool {
	return bytes.HasPrefix(dat, []byte(MAGIC))
}

// NewFromReader reads an entire serialized manifest from input and decodes it.
//
// The expected layout has three parts:
//   - the MAGIC prefix;
//   - a protobuf-encoded MFFileOuter;
//   - inside that envelope, InnerMessage, a gzip-compressed protobuf MFFile.
//
// Both messages are validated before return. The outer envelope is checked
// for version, compression type, size and SHA-256. The inner message is
// checked for version and a non-empty file list.
func NewFromReader(input io.Reader) (*manifest, error) {
	m := New()

	dat, err := io.ReadAll(input)
	if err != nil {
		return nil, err
	}
	if !validateMagic(dat) {
		return nil, errors.New("invalid file format")
	}

	// remove magic bytes prefix:
	dat = dat[len([]byte(MAGIC)):]
	log.TraceDump(dat)

	// deserialize the outer envelope:
	m.pbOuter = new(MFFileOuter)
	if err := proto.Unmarshal(dat, m.pbOuter); err != nil {
		return nil, err
	}

	// Decompress and integrity-check the inner message. The decompressed
	// bytes are what get unmarshaled; InnerMessage itself is gzip data.
	innerBytes, err := m.inflateVerifiedInner()
	if err != nil {
		return nil, err
	}

	m.pbInner = new(MFFile)
	if err := proto.Unmarshal(innerBytes, m.pbInner); err != nil {
		return nil, err
	}
	log.TraceDump(m.pbInner)

	if err := m.validateProtoInner(); err != nil {
		return nil, err
	}

	m.rescanInternal()
	log.TraceDump(m)
	return m, nil
}