|
1 // Copyright 2019 The Go Authors. All rights reserved. |
|
2 // Use of this source code is governed by a BSD-style |
|
3 // license that can be found in the LICENSE file. |
|
4 |
|
5 package impl |
|
6 |
|
7 import ( |
|
8 "reflect" |
|
9 "sort" |
|
10 |
|
11 "google.golang.org/protobuf/encoding/protowire" |
|
12 "google.golang.org/protobuf/internal/genid" |
|
13 pref "google.golang.org/protobuf/reflect/protoreflect" |
|
14 ) |
|
15 |
|
16 type mapInfo struct { |
|
17 goType reflect.Type |
|
18 keyWiretag uint64 |
|
19 valWiretag uint64 |
|
20 keyFuncs valueCoderFuncs |
|
21 valFuncs valueCoderFuncs |
|
22 keyZero pref.Value |
|
23 keyKind pref.Kind |
|
24 conv *mapConverter |
|
25 } |
|
26 |
|
27 func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) { |
|
28 // TODO: Consider generating specialized map coders. |
|
29 keyField := fd.MapKey() |
|
30 valField := fd.MapValue() |
|
31 keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()]) |
|
32 valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()]) |
|
33 keyFuncs := encoderFuncsForValue(keyField) |
|
34 valFuncs := encoderFuncsForValue(valField) |
|
35 conv := newMapConverter(ft, fd) |
|
36 |
|
37 mapi := &mapInfo{ |
|
38 goType: ft, |
|
39 keyWiretag: keyWiretag, |
|
40 valWiretag: valWiretag, |
|
41 keyFuncs: keyFuncs, |
|
42 valFuncs: valFuncs, |
|
43 keyZero: keyField.Default(), |
|
44 keyKind: keyField.Kind(), |
|
45 conv: conv, |
|
46 } |
|
47 if valField.Kind() == pref.MessageKind { |
|
48 valueMessage = getMessageInfo(ft.Elem()) |
|
49 } |
|
50 |
|
51 funcs = pointerCoderFuncs{ |
|
52 size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int { |
|
53 return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts) |
|
54 }, |
|
55 marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { |
|
56 return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts) |
|
57 }, |
|
58 unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) { |
|
59 mp := p.AsValueOf(ft) |
|
60 if mp.Elem().IsNil() { |
|
61 mp.Elem().Set(reflect.MakeMap(mapi.goType)) |
|
62 } |
|
63 if f.mi == nil { |
|
64 return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts) |
|
65 } else { |
|
66 return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts) |
|
67 } |
|
68 }, |
|
69 } |
|
70 switch valField.Kind() { |
|
71 case pref.MessageKind: |
|
72 funcs.merge = mergeMapOfMessage |
|
73 case pref.BytesKind: |
|
74 funcs.merge = mergeMapOfBytes |
|
75 default: |
|
76 funcs.merge = mergeMap |
|
77 } |
|
78 if valFuncs.isInit != nil { |
|
79 funcs.isInit = func(p pointer, f *coderFieldInfo) error { |
|
80 return isInitMap(p.AsValueOf(ft).Elem(), mapi, f) |
|
81 } |
|
82 } |
|
83 return valueMessage, funcs |
|
84 } |
|
85 |
|
// Encoded sizes, in bytes, of the tags for a map entry's key (field 1) and
// value (field 2). A tag is the varint (field number << 3) | wire type; for
// field numbers 1 and 2 that value is below 0x80, so each tag is one byte.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
|
90 |
|
91 func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int { |
|
92 if mapv.Len() == 0 { |
|
93 return 0 |
|
94 } |
|
95 n := 0 |
|
96 iter := mapRange(mapv) |
|
97 for iter.Next() { |
|
98 key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey() |
|
99 keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts) |
|
100 var valSize int |
|
101 value := mapi.conv.valConv.PBValueOf(iter.Value()) |
|
102 if f.mi == nil { |
|
103 valSize = mapi.valFuncs.size(value, mapValTagSize, opts) |
|
104 } else { |
|
105 p := pointerOfValue(iter.Value()) |
|
106 valSize += mapValTagSize |
|
107 valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts)) |
|
108 } |
|
109 n += f.tagsize + protowire.SizeBytes(keySize+valSize) |
|
110 } |
|
111 return n |
|
112 } |
|
113 |
|
114 func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) { |
|
115 if wtyp != protowire.BytesType { |
|
116 return out, errUnknown |
|
117 } |
|
118 b, n := protowire.ConsumeBytes(b) |
|
119 if n < 0 { |
|
120 return out, errDecode |
|
121 } |
|
122 var ( |
|
123 key = mapi.keyZero |
|
124 val = mapi.conv.valConv.New() |
|
125 ) |
|
126 for len(b) > 0 { |
|
127 num, wtyp, n := protowire.ConsumeTag(b) |
|
128 if n < 0 { |
|
129 return out, errDecode |
|
130 } |
|
131 if num > protowire.MaxValidNumber { |
|
132 return out, errDecode |
|
133 } |
|
134 b = b[n:] |
|
135 err := errUnknown |
|
136 switch num { |
|
137 case genid.MapEntry_Key_field_number: |
|
138 var v pref.Value |
|
139 var o unmarshalOutput |
|
140 v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts) |
|
141 if err != nil { |
|
142 break |
|
143 } |
|
144 key = v |
|
145 n = o.n |
|
146 case genid.MapEntry_Value_field_number: |
|
147 var v pref.Value |
|
148 var o unmarshalOutput |
|
149 v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts) |
|
150 if err != nil { |
|
151 break |
|
152 } |
|
153 val = v |
|
154 n = o.n |
|
155 } |
|
156 if err == errUnknown { |
|
157 n = protowire.ConsumeFieldValue(num, wtyp, b) |
|
158 if n < 0 { |
|
159 return out, errDecode |
|
160 } |
|
161 } else if err != nil { |
|
162 return out, err |
|
163 } |
|
164 b = b[n:] |
|
165 } |
|
166 mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val)) |
|
167 out.n = n |
|
168 return out, nil |
|
169 } |
|
170 |
|
171 func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) { |
|
172 if wtyp != protowire.BytesType { |
|
173 return out, errUnknown |
|
174 } |
|
175 b, n := protowire.ConsumeBytes(b) |
|
176 if n < 0 { |
|
177 return out, errDecode |
|
178 } |
|
179 var ( |
|
180 key = mapi.keyZero |
|
181 val = reflect.New(f.mi.GoReflectType.Elem()) |
|
182 ) |
|
183 for len(b) > 0 { |
|
184 num, wtyp, n := protowire.ConsumeTag(b) |
|
185 if n < 0 { |
|
186 return out, errDecode |
|
187 } |
|
188 if num > protowire.MaxValidNumber { |
|
189 return out, errDecode |
|
190 } |
|
191 b = b[n:] |
|
192 err := errUnknown |
|
193 switch num { |
|
194 case 1: |
|
195 var v pref.Value |
|
196 var o unmarshalOutput |
|
197 v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts) |
|
198 if err != nil { |
|
199 break |
|
200 } |
|
201 key = v |
|
202 n = o.n |
|
203 case 2: |
|
204 if wtyp != protowire.BytesType { |
|
205 break |
|
206 } |
|
207 var v []byte |
|
208 v, n = protowire.ConsumeBytes(b) |
|
209 if n < 0 { |
|
210 return out, errDecode |
|
211 } |
|
212 var o unmarshalOutput |
|
213 o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts) |
|
214 if o.initialized { |
|
215 // Consider this map item initialized so long as we see |
|
216 // an initialized value. |
|
217 out.initialized = true |
|
218 } |
|
219 } |
|
220 if err == errUnknown { |
|
221 n = protowire.ConsumeFieldValue(num, wtyp, b) |
|
222 if n < 0 { |
|
223 return out, errDecode |
|
224 } |
|
225 } else if err != nil { |
|
226 return out, err |
|
227 } |
|
228 b = b[n:] |
|
229 } |
|
230 mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val) |
|
231 out.n = n |
|
232 return out, nil |
|
233 } |
|
234 |
|
235 func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { |
|
236 if f.mi == nil { |
|
237 key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey() |
|
238 val := mapi.conv.valConv.PBValueOf(valrv) |
|
239 size := 0 |
|
240 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts) |
|
241 size += mapi.valFuncs.size(val, mapValTagSize, opts) |
|
242 b = protowire.AppendVarint(b, uint64(size)) |
|
243 b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts) |
|
244 if err != nil { |
|
245 return nil, err |
|
246 } |
|
247 return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts) |
|
248 } else { |
|
249 key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey() |
|
250 val := pointerOfValue(valrv) |
|
251 valSize := f.mi.sizePointer(val, opts) |
|
252 size := 0 |
|
253 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts) |
|
254 size += mapValTagSize + protowire.SizeBytes(valSize) |
|
255 b = protowire.AppendVarint(b, uint64(size)) |
|
256 b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts) |
|
257 if err != nil { |
|
258 return nil, err |
|
259 } |
|
260 b = protowire.AppendVarint(b, mapi.valWiretag) |
|
261 b = protowire.AppendVarint(b, uint64(valSize)) |
|
262 return f.mi.marshalAppendPointer(b, val, opts) |
|
263 } |
|
264 } |
|
265 |
|
266 func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { |
|
267 if mapv.Len() == 0 { |
|
268 return b, nil |
|
269 } |
|
270 if opts.Deterministic() { |
|
271 return appendMapDeterministic(b, mapv, mapi, f, opts) |
|
272 } |
|
273 iter := mapRange(mapv) |
|
274 for iter.Next() { |
|
275 var err error |
|
276 b = protowire.AppendVarint(b, f.wiretag) |
|
277 b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts) |
|
278 if err != nil { |
|
279 return b, err |
|
280 } |
|
281 } |
|
282 return b, nil |
|
283 } |
|
284 |
|
285 func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { |
|
286 keys := mapv.MapKeys() |
|
287 sort.Slice(keys, func(i, j int) bool { |
|
288 switch keys[i].Kind() { |
|
289 case reflect.Bool: |
|
290 return !keys[i].Bool() && keys[j].Bool() |
|
291 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: |
|
292 return keys[i].Int() < keys[j].Int() |
|
293 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: |
|
294 return keys[i].Uint() < keys[j].Uint() |
|
295 case reflect.Float32, reflect.Float64: |
|
296 return keys[i].Float() < keys[j].Float() |
|
297 case reflect.String: |
|
298 return keys[i].String() < keys[j].String() |
|
299 default: |
|
300 panic("invalid kind: " + keys[i].Kind().String()) |
|
301 } |
|
302 }) |
|
303 for _, key := range keys { |
|
304 var err error |
|
305 b = protowire.AppendVarint(b, f.wiretag) |
|
306 b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts) |
|
307 if err != nil { |
|
308 return b, err |
|
309 } |
|
310 } |
|
311 return b, nil |
|
312 } |
|
313 |
|
314 func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error { |
|
315 if mi := f.mi; mi != nil { |
|
316 mi.init() |
|
317 if !mi.needsInitCheck { |
|
318 return nil |
|
319 } |
|
320 iter := mapRange(mapv) |
|
321 for iter.Next() { |
|
322 val := pointerOfValue(iter.Value()) |
|
323 if err := mi.checkInitializedPointer(val); err != nil { |
|
324 return err |
|
325 } |
|
326 } |
|
327 } else { |
|
328 iter := mapRange(mapv) |
|
329 for iter.Next() { |
|
330 val := mapi.conv.valConv.PBValueOf(iter.Value()) |
|
331 if err := mapi.valFuncs.isInit(val); err != nil { |
|
332 return err |
|
333 } |
|
334 } |
|
335 } |
|
336 return nil |
|
337 } |
|
338 |
|
339 func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) { |
|
340 dstm := dst.AsValueOf(f.ft).Elem() |
|
341 srcm := src.AsValueOf(f.ft).Elem() |
|
342 if srcm.Len() == 0 { |
|
343 return |
|
344 } |
|
345 if dstm.IsNil() { |
|
346 dstm.Set(reflect.MakeMap(f.ft)) |
|
347 } |
|
348 iter := mapRange(srcm) |
|
349 for iter.Next() { |
|
350 dstm.SetMapIndex(iter.Key(), iter.Value()) |
|
351 } |
|
352 } |
|
353 |
|
354 func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) { |
|
355 dstm := dst.AsValueOf(f.ft).Elem() |
|
356 srcm := src.AsValueOf(f.ft).Elem() |
|
357 if srcm.Len() == 0 { |
|
358 return |
|
359 } |
|
360 if dstm.IsNil() { |
|
361 dstm.Set(reflect.MakeMap(f.ft)) |
|
362 } |
|
363 iter := mapRange(srcm) |
|
364 for iter.Next() { |
|
365 dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...))) |
|
366 } |
|
367 } |
|
368 |
|
369 func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) { |
|
370 dstm := dst.AsValueOf(f.ft).Elem() |
|
371 srcm := src.AsValueOf(f.ft).Elem() |
|
372 if srcm.Len() == 0 { |
|
373 return |
|
374 } |
|
375 if dstm.IsNil() { |
|
376 dstm.Set(reflect.MakeMap(f.ft)) |
|
377 } |
|
378 iter := mapRange(srcm) |
|
379 for iter.Next() { |
|
380 val := reflect.New(f.ft.Elem().Elem()) |
|
381 if f.mi != nil { |
|
382 f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts) |
|
383 } else { |
|
384 opts.Merge(asMessage(val), asMessage(iter.Value())) |
|
385 } |
|
386 dstm.SetMapIndex(iter.Key(), val) |
|
387 } |
|
388 } |