// Copyright 2019 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package impl import ( "reflect" "sort" "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/internal/genid" pref "google.golang.org/protobuf/reflect/protoreflect" ) type mapInfo struct { goType reflect.Type keyWiretag uint64 valWiretag uint64 keyFuncs valueCoderFuncs valFuncs valueCoderFuncs keyZero pref.Value keyKind pref.Kind conv *mapConverter } func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) { // TODO: Consider generating specialized map coders. keyField := fd.MapKey() valField := fd.MapValue() keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()]) valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()]) keyFuncs := encoderFuncsForValue(keyField) valFuncs := encoderFuncsForValue(valField) conv := newMapConverter(ft, fd) mapi := &mapInfo{ goType: ft, keyWiretag: keyWiretag, valWiretag: valWiretag, keyFuncs: keyFuncs, valFuncs: valFuncs, keyZero: keyField.Default(), keyKind: keyField.Kind(), conv: conv, } if valField.Kind() == pref.MessageKind { valueMessage = getMessageInfo(ft.Elem()) } funcs = pointerCoderFuncs{ size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int { return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts) }, marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts) }, unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) { mp := p.AsValueOf(ft) if mp.Elem().IsNil() { mp.Elem().Set(reflect.MakeMap(mapi.goType)) } if f.mi == nil { return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts) } else { return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts) } }, } switch valField.Kind() { case pref.MessageKind: funcs.merge = mergeMapOfMessage case pref.BytesKind: funcs.merge = mergeMapOfBytes default: funcs.merge = mergeMap } if valFuncs.isInit != nil { funcs.isInit = func(p pointer, f *coderFieldInfo) error { return isInitMap(p.AsValueOf(ft).Elem(), mapi, f) } } return valueMessage, funcs } const ( mapKeyTagSize = 1 // field 1, tag size 1. mapValTagSize = 1 // field 2, tag size 2. ) func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int { if mapv.Len() == 0 { return 0 } n := 0 iter := mapRange(mapv) for iter.Next() { key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey() keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts) var valSize int value := mapi.conv.valConv.PBValueOf(iter.Value()) if f.mi == nil { valSize = mapi.valFuncs.size(value, mapValTagSize, opts) } else { p := pointerOfValue(iter.Value()) valSize += mapValTagSize valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts)) } n += f.tagsize + protowire.SizeBytes(keySize+valSize) } return n } func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) { if wtyp != protowire.BytesType { return out, errUnknown } b, n := protowire.ConsumeBytes(b) if n < 0 { return out, errDecode } var ( key = mapi.keyZero val = mapi.conv.valConv.New() ) for len(b) > 0 { num, wtyp, n := protowire.ConsumeTag(b) if n < 0 { return out, errDecode } if num > protowire.MaxValidNumber { return out, errDecode } b = b[n:] err := errUnknown switch num { case genid.MapEntry_Key_field_number: var v pref.Value var o unmarshalOutput v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts) if err != nil { break } key = v n = o.n case genid.MapEntry_Value_field_number: var v pref.Value var o unmarshalOutput v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts) if err != nil { break } val = v n = o.n } if err == errUnknown { n = protowire.ConsumeFieldValue(num, wtyp, b) if n < 0 { return out, errDecode } } else if err != nil { return out, err } b = b[n:] } mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val)) out.n = n return out, nil } func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) { if wtyp != protowire.BytesType { return out, errUnknown } b, n := protowire.ConsumeBytes(b) if n < 0 { return out, errDecode } var ( key = mapi.keyZero val = reflect.New(f.mi.GoReflectType.Elem()) ) for len(b) > 0 { num, wtyp, n := protowire.ConsumeTag(b) if n < 0 { return out, errDecode } if num > protowire.MaxValidNumber { return out, errDecode } b = b[n:] err := errUnknown switch num { case 1: var v pref.Value var o unmarshalOutput v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts) if err != nil { break } key = v n = o.n case 2: if wtyp != protowire.BytesType { break } var v []byte v, n = protowire.ConsumeBytes(b) if n < 0 { return out, errDecode } var o unmarshalOutput o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts) if o.initialized { // Consider this map item initialized so long as we see // an initialized value. out.initialized = true } } if err == errUnknown { n = protowire.ConsumeFieldValue(num, wtyp, b) if n < 0 { return out, errDecode } } else if err != nil { return out, err } b = b[n:] } mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val) out.n = n return out, nil } func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { if f.mi == nil { key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey() val := mapi.conv.valConv.PBValueOf(valrv) size := 0 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts) size += mapi.valFuncs.size(val, mapValTagSize, opts) b = protowire.AppendVarint(b, uint64(size)) b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts) if err != nil { return nil, err } return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts) } else { key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey() val := pointerOfValue(valrv) valSize := f.mi.sizePointer(val, opts) size := 0 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts) size += mapValTagSize + protowire.SizeBytes(valSize) b = protowire.AppendVarint(b, uint64(size)) b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts) if err != nil { return nil, err } b = protowire.AppendVarint(b, mapi.valWiretag) b = protowire.AppendVarint(b, uint64(valSize)) return f.mi.marshalAppendPointer(b, val, opts) } } func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { if mapv.Len() == 0 { return b, nil } if opts.Deterministic() { return appendMapDeterministic(b, mapv, mapi, f, opts) } iter := mapRange(mapv) for iter.Next() { var err error b = protowire.AppendVarint(b, f.wiretag) b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts) if err != nil { return b, err } } return b, nil } func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { keys := mapv.MapKeys() sort.Slice(keys, func(i, j int) bool { switch keys[i].Kind() { case reflect.Bool: return !keys[i].Bool() && keys[j].Bool() case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return keys[i].Int() < keys[j].Int() case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: return keys[i].Uint() < keys[j].Uint() case reflect.Float32, reflect.Float64: return keys[i].Float() < keys[j].Float() case reflect.String: return keys[i].String() < keys[j].String() default: panic("invalid kind: " + keys[i].Kind().String()) } }) for _, key := range keys { var err error b = protowire.AppendVarint(b, f.wiretag) b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts) if err != nil { return b, err } } return b, nil } func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error { if mi := f.mi; mi != nil { mi.init() if !mi.needsInitCheck { return nil } iter := mapRange(mapv) for iter.Next() { val := pointerOfValue(iter.Value()) if err := mi.checkInitializedPointer(val); err != nil { return err } } } else { iter := mapRange(mapv) for iter.Next() { val := mapi.conv.valConv.PBValueOf(iter.Value()) if err := mapi.valFuncs.isInit(val); err != nil { return err } } } return nil } func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) { dstm := dst.AsValueOf(f.ft).Elem() srcm := src.AsValueOf(f.ft).Elem() if srcm.Len() == 0 { return } if dstm.IsNil() { dstm.Set(reflect.MakeMap(f.ft)) } iter := mapRange(srcm) for iter.Next() { dstm.SetMapIndex(iter.Key(), iter.Value()) } } func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) { dstm := dst.AsValueOf(f.ft).Elem() srcm := src.AsValueOf(f.ft).Elem() if srcm.Len() == 0 { return } if dstm.IsNil() { dstm.Set(reflect.MakeMap(f.ft)) } iter := mapRange(srcm) for iter.Next() { dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...))) } } func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) { dstm := dst.AsValueOf(f.ft).Elem() srcm := src.AsValueOf(f.ft).Elem() if srcm.Len() == 0 { return } if dstm.IsNil() { dstm.Set(reflect.MakeMap(f.ft)) } iter := mapRange(srcm) for iter.Next() { val := reflect.New(f.ft.Elem().Elem()) if f.mi != nil { f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts) } else { opts.Merge(asMessage(val), asMessage(iter.Value())) } dstm.SetMapIndex(iter.Key(), val) } }