vendor/google.golang.org/protobuf/internal/impl/codec_map.go
changeset 256 6d9efbef00a9
child 260 445e01aede7e
equal deleted inserted replaced
255:4f153a23adab 256:6d9efbef00a9
       
     1 // Copyright 2019 The Go Authors. All rights reserved.
       
     2 // Use of this source code is governed by a BSD-style
       
     3 // license that can be found in the LICENSE file.
       
     4 
       
     5 package impl
       
     6 
       
     7 import (
       
     8 	"reflect"
       
     9 	"sort"
       
    10 
       
    11 	"google.golang.org/protobuf/encoding/protowire"
       
    12 	"google.golang.org/protobuf/internal/genid"
       
    13 	pref "google.golang.org/protobuf/reflect/protoreflect"
       
    14 )
       
    15 
       
    16 type mapInfo struct {
       
    17 	goType     reflect.Type
       
    18 	keyWiretag uint64
       
    19 	valWiretag uint64
       
    20 	keyFuncs   valueCoderFuncs
       
    21 	valFuncs   valueCoderFuncs
       
    22 	keyZero    pref.Value
       
    23 	keyKind    pref.Kind
       
    24 	conv       *mapConverter
       
    25 }
       
    26 
       
    27 func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) {
       
    28 	// TODO: Consider generating specialized map coders.
       
    29 	keyField := fd.MapKey()
       
    30 	valField := fd.MapValue()
       
    31 	keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()])
       
    32 	valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()])
       
    33 	keyFuncs := encoderFuncsForValue(keyField)
       
    34 	valFuncs := encoderFuncsForValue(valField)
       
    35 	conv := newMapConverter(ft, fd)
       
    36 
       
    37 	mapi := &mapInfo{
       
    38 		goType:     ft,
       
    39 		keyWiretag: keyWiretag,
       
    40 		valWiretag: valWiretag,
       
    41 		keyFuncs:   keyFuncs,
       
    42 		valFuncs:   valFuncs,
       
    43 		keyZero:    keyField.Default(),
       
    44 		keyKind:    keyField.Kind(),
       
    45 		conv:       conv,
       
    46 	}
       
    47 	if valField.Kind() == pref.MessageKind {
       
    48 		valueMessage = getMessageInfo(ft.Elem())
       
    49 	}
       
    50 
       
    51 	funcs = pointerCoderFuncs{
       
    52 		size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
       
    53 			return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts)
       
    54 		},
       
    55 		marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
       
    56 			return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts)
       
    57 		},
       
    58 		unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
       
    59 			mp := p.AsValueOf(ft)
       
    60 			if mp.Elem().IsNil() {
       
    61 				mp.Elem().Set(reflect.MakeMap(mapi.goType))
       
    62 			}
       
    63 			if f.mi == nil {
       
    64 				return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts)
       
    65 			} else {
       
    66 				return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts)
       
    67 			}
       
    68 		},
       
    69 	}
       
    70 	switch valField.Kind() {
       
    71 	case pref.MessageKind:
       
    72 		funcs.merge = mergeMapOfMessage
       
    73 	case pref.BytesKind:
       
    74 		funcs.merge = mergeMapOfBytes
       
    75 	default:
       
    76 		funcs.merge = mergeMap
       
    77 	}
       
    78 	if valFuncs.isInit != nil {
       
    79 		funcs.isInit = func(p pointer, f *coderFieldInfo) error {
       
    80 			return isInitMap(p.AsValueOf(ft).Elem(), mapi, f)
       
    81 		}
       
    82 	}
       
    83 	return valueMessage, funcs
       
    84 }
       
    85 
       
    86 const (
       
    87 	mapKeyTagSize = 1 // field 1, tag size 1.
       
    88 	mapValTagSize = 1 // field 2, tag size 2.
       
    89 )
       
    90 
       
    91 func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int {
       
    92 	if mapv.Len() == 0 {
       
    93 		return 0
       
    94 	}
       
    95 	n := 0
       
    96 	iter := mapRange(mapv)
       
    97 	for iter.Next() {
       
    98 		key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
       
    99 		keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
       
   100 		var valSize int
       
   101 		value := mapi.conv.valConv.PBValueOf(iter.Value())
       
   102 		if f.mi == nil {
       
   103 			valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
       
   104 		} else {
       
   105 			p := pointerOfValue(iter.Value())
       
   106 			valSize += mapValTagSize
       
   107 			valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts))
       
   108 		}
       
   109 		n += f.tagsize + protowire.SizeBytes(keySize+valSize)
       
   110 	}
       
   111 	return n
       
   112 }
       
   113 
       
   114 func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
       
   115 	if wtyp != protowire.BytesType {
       
   116 		return out, errUnknown
       
   117 	}
       
   118 	b, n := protowire.ConsumeBytes(b)
       
   119 	if n < 0 {
       
   120 		return out, errDecode
       
   121 	}
       
   122 	var (
       
   123 		key = mapi.keyZero
       
   124 		val = mapi.conv.valConv.New()
       
   125 	)
       
   126 	for len(b) > 0 {
       
   127 		num, wtyp, n := protowire.ConsumeTag(b)
       
   128 		if n < 0 {
       
   129 			return out, errDecode
       
   130 		}
       
   131 		if num > protowire.MaxValidNumber {
       
   132 			return out, errDecode
       
   133 		}
       
   134 		b = b[n:]
       
   135 		err := errUnknown
       
   136 		switch num {
       
   137 		case genid.MapEntry_Key_field_number:
       
   138 			var v pref.Value
       
   139 			var o unmarshalOutput
       
   140 			v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
       
   141 			if err != nil {
       
   142 				break
       
   143 			}
       
   144 			key = v
       
   145 			n = o.n
       
   146 		case genid.MapEntry_Value_field_number:
       
   147 			var v pref.Value
       
   148 			var o unmarshalOutput
       
   149 			v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
       
   150 			if err != nil {
       
   151 				break
       
   152 			}
       
   153 			val = v
       
   154 			n = o.n
       
   155 		}
       
   156 		if err == errUnknown {
       
   157 			n = protowire.ConsumeFieldValue(num, wtyp, b)
       
   158 			if n < 0 {
       
   159 				return out, errDecode
       
   160 			}
       
   161 		} else if err != nil {
       
   162 			return out, err
       
   163 		}
       
   164 		b = b[n:]
       
   165 	}
       
   166 	mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
       
   167 	out.n = n
       
   168 	return out, nil
       
   169 }
       
   170 
       
   171 func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
       
   172 	if wtyp != protowire.BytesType {
       
   173 		return out, errUnknown
       
   174 	}
       
   175 	b, n := protowire.ConsumeBytes(b)
       
   176 	if n < 0 {
       
   177 		return out, errDecode
       
   178 	}
       
   179 	var (
       
   180 		key = mapi.keyZero
       
   181 		val = reflect.New(f.mi.GoReflectType.Elem())
       
   182 	)
       
   183 	for len(b) > 0 {
       
   184 		num, wtyp, n := protowire.ConsumeTag(b)
       
   185 		if n < 0 {
       
   186 			return out, errDecode
       
   187 		}
       
   188 		if num > protowire.MaxValidNumber {
       
   189 			return out, errDecode
       
   190 		}
       
   191 		b = b[n:]
       
   192 		err := errUnknown
       
   193 		switch num {
       
   194 		case 1:
       
   195 			var v pref.Value
       
   196 			var o unmarshalOutput
       
   197 			v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
       
   198 			if err != nil {
       
   199 				break
       
   200 			}
       
   201 			key = v
       
   202 			n = o.n
       
   203 		case 2:
       
   204 			if wtyp != protowire.BytesType {
       
   205 				break
       
   206 			}
       
   207 			var v []byte
       
   208 			v, n = protowire.ConsumeBytes(b)
       
   209 			if n < 0 {
       
   210 				return out, errDecode
       
   211 			}
       
   212 			var o unmarshalOutput
       
   213 			o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts)
       
   214 			if o.initialized {
       
   215 				// Consider this map item initialized so long as we see
       
   216 				// an initialized value.
       
   217 				out.initialized = true
       
   218 			}
       
   219 		}
       
   220 		if err == errUnknown {
       
   221 			n = protowire.ConsumeFieldValue(num, wtyp, b)
       
   222 			if n < 0 {
       
   223 				return out, errDecode
       
   224 			}
       
   225 		} else if err != nil {
       
   226 			return out, err
       
   227 		}
       
   228 		b = b[n:]
       
   229 	}
       
   230 	mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
       
   231 	out.n = n
       
   232 	return out, nil
       
   233 }
       
   234 
       
   235 func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
       
   236 	if f.mi == nil {
       
   237 		key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
       
   238 		val := mapi.conv.valConv.PBValueOf(valrv)
       
   239 		size := 0
       
   240 		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
       
   241 		size += mapi.valFuncs.size(val, mapValTagSize, opts)
       
   242 		b = protowire.AppendVarint(b, uint64(size))
       
   243 		b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
       
   244 		if err != nil {
       
   245 			return nil, err
       
   246 		}
       
   247 		return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
       
   248 	} else {
       
   249 		key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
       
   250 		val := pointerOfValue(valrv)
       
   251 		valSize := f.mi.sizePointer(val, opts)
       
   252 		size := 0
       
   253 		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
       
   254 		size += mapValTagSize + protowire.SizeBytes(valSize)
       
   255 		b = protowire.AppendVarint(b, uint64(size))
       
   256 		b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
       
   257 		if err != nil {
       
   258 			return nil, err
       
   259 		}
       
   260 		b = protowire.AppendVarint(b, mapi.valWiretag)
       
   261 		b = protowire.AppendVarint(b, uint64(valSize))
       
   262 		return f.mi.marshalAppendPointer(b, val, opts)
       
   263 	}
       
   264 }
       
   265 
       
   266 func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
       
   267 	if mapv.Len() == 0 {
       
   268 		return b, nil
       
   269 	}
       
   270 	if opts.Deterministic() {
       
   271 		return appendMapDeterministic(b, mapv, mapi, f, opts)
       
   272 	}
       
   273 	iter := mapRange(mapv)
       
   274 	for iter.Next() {
       
   275 		var err error
       
   276 		b = protowire.AppendVarint(b, f.wiretag)
       
   277 		b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts)
       
   278 		if err != nil {
       
   279 			return b, err
       
   280 		}
       
   281 	}
       
   282 	return b, nil
       
   283 }
       
   284 
       
   285 func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
       
   286 	keys := mapv.MapKeys()
       
   287 	sort.Slice(keys, func(i, j int) bool {
       
   288 		switch keys[i].Kind() {
       
   289 		case reflect.Bool:
       
   290 			return !keys[i].Bool() && keys[j].Bool()
       
   291 		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
       
   292 			return keys[i].Int() < keys[j].Int()
       
   293 		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
       
   294 			return keys[i].Uint() < keys[j].Uint()
       
   295 		case reflect.Float32, reflect.Float64:
       
   296 			return keys[i].Float() < keys[j].Float()
       
   297 		case reflect.String:
       
   298 			return keys[i].String() < keys[j].String()
       
   299 		default:
       
   300 			panic("invalid kind: " + keys[i].Kind().String())
       
   301 		}
       
   302 	})
       
   303 	for _, key := range keys {
       
   304 		var err error
       
   305 		b = protowire.AppendVarint(b, f.wiretag)
       
   306 		b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts)
       
   307 		if err != nil {
       
   308 			return b, err
       
   309 		}
       
   310 	}
       
   311 	return b, nil
       
   312 }
       
   313 
       
   314 func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error {
       
   315 	if mi := f.mi; mi != nil {
       
   316 		mi.init()
       
   317 		if !mi.needsInitCheck {
       
   318 			return nil
       
   319 		}
       
   320 		iter := mapRange(mapv)
       
   321 		for iter.Next() {
       
   322 			val := pointerOfValue(iter.Value())
       
   323 			if err := mi.checkInitializedPointer(val); err != nil {
       
   324 				return err
       
   325 			}
       
   326 		}
       
   327 	} else {
       
   328 		iter := mapRange(mapv)
       
   329 		for iter.Next() {
       
   330 			val := mapi.conv.valConv.PBValueOf(iter.Value())
       
   331 			if err := mapi.valFuncs.isInit(val); err != nil {
       
   332 				return err
       
   333 			}
       
   334 		}
       
   335 	}
       
   336 	return nil
       
   337 }
       
   338 
       
   339 func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
       
   340 	dstm := dst.AsValueOf(f.ft).Elem()
       
   341 	srcm := src.AsValueOf(f.ft).Elem()
       
   342 	if srcm.Len() == 0 {
       
   343 		return
       
   344 	}
       
   345 	if dstm.IsNil() {
       
   346 		dstm.Set(reflect.MakeMap(f.ft))
       
   347 	}
       
   348 	iter := mapRange(srcm)
       
   349 	for iter.Next() {
       
   350 		dstm.SetMapIndex(iter.Key(), iter.Value())
       
   351 	}
       
   352 }
       
   353 
       
   354 func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
       
   355 	dstm := dst.AsValueOf(f.ft).Elem()
       
   356 	srcm := src.AsValueOf(f.ft).Elem()
       
   357 	if srcm.Len() == 0 {
       
   358 		return
       
   359 	}
       
   360 	if dstm.IsNil() {
       
   361 		dstm.Set(reflect.MakeMap(f.ft))
       
   362 	}
       
   363 	iter := mapRange(srcm)
       
   364 	for iter.Next() {
       
   365 		dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...)))
       
   366 	}
       
   367 }
       
   368 
       
   369 func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
       
   370 	dstm := dst.AsValueOf(f.ft).Elem()
       
   371 	srcm := src.AsValueOf(f.ft).Elem()
       
   372 	if srcm.Len() == 0 {
       
   373 		return
       
   374 	}
       
   375 	if dstm.IsNil() {
       
   376 		dstm.Set(reflect.MakeMap(f.ft))
       
   377 	}
       
   378 	iter := mapRange(srcm)
       
   379 	for iter.Next() {
       
   380 		val := reflect.New(f.ft.Elem().Elem())
       
   381 		if f.mi != nil {
       
   382 			f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts)
       
   383 		} else {
       
   384 			opts.Merge(asMessage(val), asMessage(iter.Value()))
       
   385 		}
       
   386 		dstm.SetMapIndex(iter.Key(), val)
       
   387 	}
       
   388 }