vendor/google.golang.org/protobuf/internal/impl/codec_map.go
changeset 256 6d9efbef00a9
child 260 445e01aede7e
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/vendor/google.golang.org/protobuf/internal/impl/codec_map.go	Sun Jul 11 10:35:56 2021 +0200
@@ -0,0 +1,388 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package impl
+
+import (
+	"reflect"
+	"sort"
+
+	"google.golang.org/protobuf/encoding/protowire"
+	"google.golang.org/protobuf/internal/genid"
+	pref "google.golang.org/protobuf/reflect/protoreflect"
+)
+
+type mapInfo struct {
+	goType     reflect.Type
+	keyWiretag uint64
+	valWiretag uint64
+	keyFuncs   valueCoderFuncs
+	valFuncs   valueCoderFuncs
+	keyZero    pref.Value
+	keyKind    pref.Kind
+	conv       *mapConverter
+}
+
+func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) {
+	// TODO: Consider generating specialized map coders.
+	keyField := fd.MapKey()
+	valField := fd.MapValue()
+	keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()])
+	valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()])
+	keyFuncs := encoderFuncsForValue(keyField)
+	valFuncs := encoderFuncsForValue(valField)
+	conv := newMapConverter(ft, fd)
+
+	mapi := &mapInfo{
+		goType:     ft,
+		keyWiretag: keyWiretag,
+		valWiretag: valWiretag,
+		keyFuncs:   keyFuncs,
+		valFuncs:   valFuncs,
+		keyZero:    keyField.Default(),
+		keyKind:    keyField.Kind(),
+		conv:       conv,
+	}
+	if valField.Kind() == pref.MessageKind {
+		valueMessage = getMessageInfo(ft.Elem())
+	}
+
+	funcs = pointerCoderFuncs{
+		size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
+			return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts)
+		},
+		marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
+			return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts)
+		},
+		unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
+			mp := p.AsValueOf(ft)
+			if mp.Elem().IsNil() {
+				mp.Elem().Set(reflect.MakeMap(mapi.goType))
+			}
+			if f.mi == nil {
+				return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts)
+			} else {
+				return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts)
+			}
+		},
+	}
+	switch valField.Kind() {
+	case pref.MessageKind:
+		funcs.merge = mergeMapOfMessage
+	case pref.BytesKind:
+		funcs.merge = mergeMapOfBytes
+	default:
+		funcs.merge = mergeMap
+	}
+	if valFuncs.isInit != nil {
+		funcs.isInit = func(p pointer, f *coderFieldInfo) error {
+			return isInitMap(p.AsValueOf(ft).Elem(), mapi, f)
+		}
+	}
+	return valueMessage, funcs
+}
+
+const (
+	mapKeyTagSize = 1 // field 1, tag size 1.
+	mapValTagSize = 1 // field 2, tag size 2.
+)
+
+func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int {
+	if mapv.Len() == 0 {
+		return 0
+	}
+	n := 0
+	iter := mapRange(mapv)
+	for iter.Next() {
+		key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
+		keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
+		var valSize int
+		value := mapi.conv.valConv.PBValueOf(iter.Value())
+		if f.mi == nil {
+			valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
+		} else {
+			p := pointerOfValue(iter.Value())
+			valSize += mapValTagSize
+			valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts))
+		}
+		n += f.tagsize + protowire.SizeBytes(keySize+valSize)
+	}
+	return n
+}
+
+func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
+	if wtyp != protowire.BytesType {
+		return out, errUnknown
+	}
+	b, n := protowire.ConsumeBytes(b)
+	if n < 0 {
+		return out, errDecode
+	}
+	var (
+		key = mapi.keyZero
+		val = mapi.conv.valConv.New()
+	)
+	for len(b) > 0 {
+		num, wtyp, n := protowire.ConsumeTag(b)
+		if n < 0 {
+			return out, errDecode
+		}
+		if num > protowire.MaxValidNumber {
+			return out, errDecode
+		}
+		b = b[n:]
+		err := errUnknown
+		switch num {
+		case genid.MapEntry_Key_field_number:
+			var v pref.Value
+			var o unmarshalOutput
+			v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
+			if err != nil {
+				break
+			}
+			key = v
+			n = o.n
+		case genid.MapEntry_Value_field_number:
+			var v pref.Value
+			var o unmarshalOutput
+			v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
+			if err != nil {
+				break
+			}
+			val = v
+			n = o.n
+		}
+		if err == errUnknown {
+			n = protowire.ConsumeFieldValue(num, wtyp, b)
+			if n < 0 {
+				return out, errDecode
+			}
+		} else if err != nil {
+			return out, err
+		}
+		b = b[n:]
+	}
+	mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
+	out.n = n
+	return out, nil
+}
+
+func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
+	if wtyp != protowire.BytesType {
+		return out, errUnknown
+	}
+	b, n := protowire.ConsumeBytes(b)
+	if n < 0 {
+		return out, errDecode
+	}
+	var (
+		key = mapi.keyZero
+		val = reflect.New(f.mi.GoReflectType.Elem())
+	)
+	for len(b) > 0 {
+		num, wtyp, n := protowire.ConsumeTag(b)
+		if n < 0 {
+			return out, errDecode
+		}
+		if num > protowire.MaxValidNumber {
+			return out, errDecode
+		}
+		b = b[n:]
+		err := errUnknown
+		switch num {
+		case 1:
+			var v pref.Value
+			var o unmarshalOutput
+			v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
+			if err != nil {
+				break
+			}
+			key = v
+			n = o.n
+		case 2:
+			if wtyp != protowire.BytesType {
+				break
+			}
+			var v []byte
+			v, n = protowire.ConsumeBytes(b)
+			if n < 0 {
+				return out, errDecode
+			}
+			var o unmarshalOutput
+			o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts)
+			if o.initialized {
+				// Consider this map item initialized so long as we see
+				// an initialized value.
+				out.initialized = true
+			}
+		}
+		if err == errUnknown {
+			n = protowire.ConsumeFieldValue(num, wtyp, b)
+			if n < 0 {
+				return out, errDecode
+			}
+		} else if err != nil {
+			return out, err
+		}
+		b = b[n:]
+	}
+	mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
+	out.n = n
+	return out, nil
+}
+
+func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
+	if f.mi == nil {
+		key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
+		val := mapi.conv.valConv.PBValueOf(valrv)
+		size := 0
+		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
+		size += mapi.valFuncs.size(val, mapValTagSize, opts)
+		b = protowire.AppendVarint(b, uint64(size))
+		b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
+		if err != nil {
+			return nil, err
+		}
+		return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
+	} else {
+		key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
+		val := pointerOfValue(valrv)
+		valSize := f.mi.sizePointer(val, opts)
+		size := 0
+		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
+		size += mapValTagSize + protowire.SizeBytes(valSize)
+		b = protowire.AppendVarint(b, uint64(size))
+		b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
+		if err != nil {
+			return nil, err
+		}
+		b = protowire.AppendVarint(b, mapi.valWiretag)
+		b = protowire.AppendVarint(b, uint64(valSize))
+		return f.mi.marshalAppendPointer(b, val, opts)
+	}
+}
+
+func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
+	if mapv.Len() == 0 {
+		return b, nil
+	}
+	if opts.Deterministic() {
+		return appendMapDeterministic(b, mapv, mapi, f, opts)
+	}
+	iter := mapRange(mapv)
+	for iter.Next() {
+		var err error
+		b = protowire.AppendVarint(b, f.wiretag)
+		b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts)
+		if err != nil {
+			return b, err
+		}
+	}
+	return b, nil
+}
+
+func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
+	keys := mapv.MapKeys()
+	sort.Slice(keys, func(i, j int) bool {
+		switch keys[i].Kind() {
+		case reflect.Bool:
+			return !keys[i].Bool() && keys[j].Bool()
+		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+			return keys[i].Int() < keys[j].Int()
+		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+			return keys[i].Uint() < keys[j].Uint()
+		case reflect.Float32, reflect.Float64:
+			return keys[i].Float() < keys[j].Float()
+		case reflect.String:
+			return keys[i].String() < keys[j].String()
+		default:
+			panic("invalid kind: " + keys[i].Kind().String())
+		}
+	})
+	for _, key := range keys {
+		var err error
+		b = protowire.AppendVarint(b, f.wiretag)
+		b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts)
+		if err != nil {
+			return b, err
+		}
+	}
+	return b, nil
+}
+
+func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error {
+	if mi := f.mi; mi != nil {
+		mi.init()
+		if !mi.needsInitCheck {
+			return nil
+		}
+		iter := mapRange(mapv)
+		for iter.Next() {
+			val := pointerOfValue(iter.Value())
+			if err := mi.checkInitializedPointer(val); err != nil {
+				return err
+			}
+		}
+	} else {
+		iter := mapRange(mapv)
+		for iter.Next() {
+			val := mapi.conv.valConv.PBValueOf(iter.Value())
+			if err := mapi.valFuncs.isInit(val); err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
+
+func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
+	dstm := dst.AsValueOf(f.ft).Elem()
+	srcm := src.AsValueOf(f.ft).Elem()
+	if srcm.Len() == 0 {
+		return
+	}
+	if dstm.IsNil() {
+		dstm.Set(reflect.MakeMap(f.ft))
+	}
+	iter := mapRange(srcm)
+	for iter.Next() {
+		dstm.SetMapIndex(iter.Key(), iter.Value())
+	}
+}
+
+func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
+	dstm := dst.AsValueOf(f.ft).Elem()
+	srcm := src.AsValueOf(f.ft).Elem()
+	if srcm.Len() == 0 {
+		return
+	}
+	if dstm.IsNil() {
+		dstm.Set(reflect.MakeMap(f.ft))
+	}
+	iter := mapRange(srcm)
+	for iter.Next() {
+		dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...)))
+	}
+}
+
+func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
+	dstm := dst.AsValueOf(f.ft).Elem()
+	srcm := src.AsValueOf(f.ft).Elem()
+	if srcm.Len() == 0 {
+		return
+	}
+	if dstm.IsNil() {
+		dstm.Set(reflect.MakeMap(f.ft))
+	}
+	iter := mapRange(srcm)
+	for iter.Next() {
+		val := reflect.New(f.ft.Elem().Elem())
+		if f.mi != nil {
+			f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts)
+		} else {
+			opts.Merge(asMessage(val), asMessage(iter.Value()))
+		}
+		dstm.SetMapIndex(iter.Key(), val)
+	}
+}