256
|
1 |
// Copyright 2019 The Go Authors. All rights reserved. |
|
2 |
// Use of this source code is governed by a BSD-style |
|
3 |
// license that can be found in the LICENSE file. |
|
4 |
|
|
5 |
package impl |
|
6 |
|
|
7 |
import ( |
|
8 |
"reflect" |
|
9 |
"sort" |
|
10 |
|
|
11 |
"google.golang.org/protobuf/encoding/protowire" |
|
12 |
"google.golang.org/protobuf/internal/genid" |
260
|
13 |
"google.golang.org/protobuf/reflect/protoreflect" |
256
|
14 |
) |
|
15 |
|
|
16 |
type mapInfo struct { |
|
17 |
goType reflect.Type |
|
18 |
keyWiretag uint64 |
|
19 |
valWiretag uint64 |
|
20 |
keyFuncs valueCoderFuncs |
|
21 |
valFuncs valueCoderFuncs |
260
|
22 |
keyZero protoreflect.Value |
|
23 |
keyKind protoreflect.Kind |
256
|
24 |
conv *mapConverter |
|
25 |
} |
|
26 |
|
260
|
27 |
func encoderFuncsForMap(fd protoreflect.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) { |
256
|
28 |
// TODO: Consider generating specialized map coders. |
|
29 |
keyField := fd.MapKey() |
|
30 |
valField := fd.MapValue() |
|
31 |
keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()]) |
|
32 |
valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()]) |
|
33 |
keyFuncs := encoderFuncsForValue(keyField) |
|
34 |
valFuncs := encoderFuncsForValue(valField) |
|
35 |
conv := newMapConverter(ft, fd) |
|
36 |
|
|
37 |
mapi := &mapInfo{ |
|
38 |
goType: ft, |
|
39 |
keyWiretag: keyWiretag, |
|
40 |
valWiretag: valWiretag, |
|
41 |
keyFuncs: keyFuncs, |
|
42 |
valFuncs: valFuncs, |
|
43 |
keyZero: keyField.Default(), |
|
44 |
keyKind: keyField.Kind(), |
|
45 |
conv: conv, |
|
46 |
} |
260
|
47 |
if valField.Kind() == protoreflect.MessageKind { |
256
|
48 |
valueMessage = getMessageInfo(ft.Elem()) |
|
49 |
} |
|
50 |
|
|
51 |
funcs = pointerCoderFuncs{ |
|
52 |
size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int { |
|
53 |
return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts) |
|
54 |
}, |
|
55 |
marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { |
|
56 |
return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts) |
|
57 |
}, |
|
58 |
unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) { |
|
59 |
mp := p.AsValueOf(ft) |
|
60 |
if mp.Elem().IsNil() { |
|
61 |
mp.Elem().Set(reflect.MakeMap(mapi.goType)) |
|
62 |
} |
|
63 |
if f.mi == nil { |
|
64 |
return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts) |
|
65 |
} else { |
|
66 |
return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts) |
|
67 |
} |
|
68 |
}, |
|
69 |
} |
|
70 |
switch valField.Kind() { |
260
|
71 |
case protoreflect.MessageKind: |
256
|
72 |
funcs.merge = mergeMapOfMessage |
260
|
73 |
case protoreflect.BytesKind: |
256
|
74 |
funcs.merge = mergeMapOfBytes |
|
75 |
default: |
|
76 |
funcs.merge = mergeMap |
|
77 |
} |
|
78 |
if valFuncs.isInit != nil { |
|
79 |
funcs.isInit = func(p pointer, f *coderFieldInfo) error { |
|
80 |
return isInitMap(p.AsValueOf(ft).Elem(), mapi, f) |
|
81 |
} |
|
82 |
} |
|
83 |
return valueMessage, funcs |
|
84 |
} |
|
85 |
|
|
86 |
// Encoded sizes, in bytes, of the field tags inside a map entry. The key is
// field 1 and the value is field 2; for either number combined with any 3-bit
// wire type the tag is at most (2<<3)|7 = 23, which is below 0x80 and
// therefore fits in a single varint byte.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
|
90 |
|
|
91 |
func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int { |
|
92 |
if mapv.Len() == 0 { |
|
93 |
return 0 |
|
94 |
} |
|
95 |
n := 0 |
|
96 |
iter := mapRange(mapv) |
|
97 |
for iter.Next() { |
|
98 |
key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey() |
|
99 |
keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts) |
|
100 |
var valSize int |
|
101 |
value := mapi.conv.valConv.PBValueOf(iter.Value()) |
|
102 |
if f.mi == nil { |
|
103 |
valSize = mapi.valFuncs.size(value, mapValTagSize, opts) |
|
104 |
} else { |
|
105 |
p := pointerOfValue(iter.Value()) |
|
106 |
valSize += mapValTagSize |
|
107 |
valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts)) |
|
108 |
} |
|
109 |
n += f.tagsize + protowire.SizeBytes(keySize+valSize) |
|
110 |
} |
|
111 |
return n |
|
112 |
} |
|
113 |
|
|
114 |
func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) { |
|
115 |
if wtyp != protowire.BytesType { |
|
116 |
return out, errUnknown |
|
117 |
} |
|
118 |
b, n := protowire.ConsumeBytes(b) |
|
119 |
if n < 0 { |
|
120 |
return out, errDecode |
|
121 |
} |
|
122 |
var ( |
|
123 |
key = mapi.keyZero |
|
124 |
val = mapi.conv.valConv.New() |
|
125 |
) |
|
126 |
for len(b) > 0 { |
|
127 |
num, wtyp, n := protowire.ConsumeTag(b) |
|
128 |
if n < 0 { |
|
129 |
return out, errDecode |
|
130 |
} |
|
131 |
if num > protowire.MaxValidNumber { |
|
132 |
return out, errDecode |
|
133 |
} |
|
134 |
b = b[n:] |
|
135 |
err := errUnknown |
|
136 |
switch num { |
|
137 |
case genid.MapEntry_Key_field_number: |
260
|
138 |
var v protoreflect.Value |
256
|
139 |
var o unmarshalOutput |
|
140 |
v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts) |
|
141 |
if err != nil { |
|
142 |
break |
|
143 |
} |
|
144 |
key = v |
|
145 |
n = o.n |
|
146 |
case genid.MapEntry_Value_field_number: |
260
|
147 |
var v protoreflect.Value |
256
|
148 |
var o unmarshalOutput |
|
149 |
v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts) |
|
150 |
if err != nil { |
|
151 |
break |
|
152 |
} |
|
153 |
val = v |
|
154 |
n = o.n |
|
155 |
} |
|
156 |
if err == errUnknown { |
|
157 |
n = protowire.ConsumeFieldValue(num, wtyp, b) |
|
158 |
if n < 0 { |
|
159 |
return out, errDecode |
|
160 |
} |
|
161 |
} else if err != nil { |
|
162 |
return out, err |
|
163 |
} |
|
164 |
b = b[n:] |
|
165 |
} |
|
166 |
mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val)) |
|
167 |
out.n = n |
|
168 |
return out, nil |
|
169 |
} |
|
170 |
|
|
171 |
func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) { |
|
172 |
if wtyp != protowire.BytesType { |
|
173 |
return out, errUnknown |
|
174 |
} |
|
175 |
b, n := protowire.ConsumeBytes(b) |
|
176 |
if n < 0 { |
|
177 |
return out, errDecode |
|
178 |
} |
|
179 |
var ( |
|
180 |
key = mapi.keyZero |
|
181 |
val = reflect.New(f.mi.GoReflectType.Elem()) |
|
182 |
) |
|
183 |
for len(b) > 0 { |
|
184 |
num, wtyp, n := protowire.ConsumeTag(b) |
|
185 |
if n < 0 { |
|
186 |
return out, errDecode |
|
187 |
} |
|
188 |
if num > protowire.MaxValidNumber { |
|
189 |
return out, errDecode |
|
190 |
} |
|
191 |
b = b[n:] |
|
192 |
err := errUnknown |
|
193 |
switch num { |
|
194 |
case 1: |
260
|
195 |
var v protoreflect.Value |
256
|
196 |
var o unmarshalOutput |
|
197 |
v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts) |
|
198 |
if err != nil { |
|
199 |
break |
|
200 |
} |
|
201 |
key = v |
|
202 |
n = o.n |
|
203 |
case 2: |
|
204 |
if wtyp != protowire.BytesType { |
|
205 |
break |
|
206 |
} |
|
207 |
var v []byte |
|
208 |
v, n = protowire.ConsumeBytes(b) |
|
209 |
if n < 0 { |
|
210 |
return out, errDecode |
|
211 |
} |
|
212 |
var o unmarshalOutput |
|
213 |
o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts) |
|
214 |
if o.initialized { |
|
215 |
// Consider this map item initialized so long as we see |
|
216 |
// an initialized value. |
|
217 |
out.initialized = true |
|
218 |
} |
|
219 |
} |
|
220 |
if err == errUnknown { |
|
221 |
n = protowire.ConsumeFieldValue(num, wtyp, b) |
|
222 |
if n < 0 { |
|
223 |
return out, errDecode |
|
224 |
} |
|
225 |
} else if err != nil { |
|
226 |
return out, err |
|
227 |
} |
|
228 |
b = b[n:] |
|
229 |
} |
|
230 |
mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val) |
|
231 |
out.n = n |
|
232 |
return out, nil |
|
233 |
} |
|
234 |
|
|
235 |
func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { |
|
236 |
if f.mi == nil { |
|
237 |
key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey() |
|
238 |
val := mapi.conv.valConv.PBValueOf(valrv) |
|
239 |
size := 0 |
|
240 |
size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts) |
|
241 |
size += mapi.valFuncs.size(val, mapValTagSize, opts) |
|
242 |
b = protowire.AppendVarint(b, uint64(size)) |
|
243 |
b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts) |
|
244 |
if err != nil { |
|
245 |
return nil, err |
|
246 |
} |
|
247 |
return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts) |
|
248 |
} else { |
|
249 |
key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey() |
|
250 |
val := pointerOfValue(valrv) |
|
251 |
valSize := f.mi.sizePointer(val, opts) |
|
252 |
size := 0 |
|
253 |
size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts) |
|
254 |
size += mapValTagSize + protowire.SizeBytes(valSize) |
|
255 |
b = protowire.AppendVarint(b, uint64(size)) |
|
256 |
b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts) |
|
257 |
if err != nil { |
|
258 |
return nil, err |
|
259 |
} |
|
260 |
b = protowire.AppendVarint(b, mapi.valWiretag) |
|
261 |
b = protowire.AppendVarint(b, uint64(valSize)) |
|
262 |
return f.mi.marshalAppendPointer(b, val, opts) |
|
263 |
} |
|
264 |
} |
|
265 |
|
|
266 |
func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { |
|
267 |
if mapv.Len() == 0 { |
|
268 |
return b, nil |
|
269 |
} |
|
270 |
if opts.Deterministic() { |
|
271 |
return appendMapDeterministic(b, mapv, mapi, f, opts) |
|
272 |
} |
|
273 |
iter := mapRange(mapv) |
|
274 |
for iter.Next() { |
|
275 |
var err error |
|
276 |
b = protowire.AppendVarint(b, f.wiretag) |
|
277 |
b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts) |
|
278 |
if err != nil { |
|
279 |
return b, err |
|
280 |
} |
|
281 |
} |
|
282 |
return b, nil |
|
283 |
} |
|
284 |
|
|
285 |
func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { |
|
286 |
keys := mapv.MapKeys() |
|
287 |
sort.Slice(keys, func(i, j int) bool { |
|
288 |
switch keys[i].Kind() { |
|
289 |
case reflect.Bool: |
|
290 |
return !keys[i].Bool() && keys[j].Bool() |
|
291 |
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: |
|
292 |
return keys[i].Int() < keys[j].Int() |
|
293 |
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: |
|
294 |
return keys[i].Uint() < keys[j].Uint() |
|
295 |
case reflect.Float32, reflect.Float64: |
|
296 |
return keys[i].Float() < keys[j].Float() |
|
297 |
case reflect.String: |
|
298 |
return keys[i].String() < keys[j].String() |
|
299 |
default: |
|
300 |
panic("invalid kind: " + keys[i].Kind().String()) |
|
301 |
} |
|
302 |
}) |
|
303 |
for _, key := range keys { |
|
304 |
var err error |
|
305 |
b = protowire.AppendVarint(b, f.wiretag) |
|
306 |
b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts) |
|
307 |
if err != nil { |
|
308 |
return b, err |
|
309 |
} |
|
310 |
} |
|
311 |
return b, nil |
|
312 |
} |
|
313 |
|
|
314 |
func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error { |
|
315 |
if mi := f.mi; mi != nil { |
|
316 |
mi.init() |
|
317 |
if !mi.needsInitCheck { |
|
318 |
return nil |
|
319 |
} |
|
320 |
iter := mapRange(mapv) |
|
321 |
for iter.Next() { |
|
322 |
val := pointerOfValue(iter.Value()) |
|
323 |
if err := mi.checkInitializedPointer(val); err != nil { |
|
324 |
return err |
|
325 |
} |
|
326 |
} |
|
327 |
} else { |
|
328 |
iter := mapRange(mapv) |
|
329 |
for iter.Next() { |
|
330 |
val := mapi.conv.valConv.PBValueOf(iter.Value()) |
|
331 |
if err := mapi.valFuncs.isInit(val); err != nil { |
|
332 |
return err |
|
333 |
} |
|
334 |
} |
|
335 |
} |
|
336 |
return nil |
|
337 |
} |
|
338 |
|
|
339 |
func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) { |
|
340 |
dstm := dst.AsValueOf(f.ft).Elem() |
|
341 |
srcm := src.AsValueOf(f.ft).Elem() |
|
342 |
if srcm.Len() == 0 { |
|
343 |
return |
|
344 |
} |
|
345 |
if dstm.IsNil() { |
|
346 |
dstm.Set(reflect.MakeMap(f.ft)) |
|
347 |
} |
|
348 |
iter := mapRange(srcm) |
|
349 |
for iter.Next() { |
|
350 |
dstm.SetMapIndex(iter.Key(), iter.Value()) |
|
351 |
} |
|
352 |
} |
|
353 |
|
|
354 |
func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) { |
|
355 |
dstm := dst.AsValueOf(f.ft).Elem() |
|
356 |
srcm := src.AsValueOf(f.ft).Elem() |
|
357 |
if srcm.Len() == 0 { |
|
358 |
return |
|
359 |
} |
|
360 |
if dstm.IsNil() { |
|
361 |
dstm.Set(reflect.MakeMap(f.ft)) |
|
362 |
} |
|
363 |
iter := mapRange(srcm) |
|
364 |
for iter.Next() { |
|
365 |
dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...))) |
|
366 |
} |
|
367 |
} |
|
368 |
|
|
369 |
func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) { |
|
370 |
dstm := dst.AsValueOf(f.ft).Elem() |
|
371 |
srcm := src.AsValueOf(f.ft).Elem() |
|
372 |
if srcm.Len() == 0 { |
|
373 |
return |
|
374 |
} |
|
375 |
if dstm.IsNil() { |
|
376 |
dstm.Set(reflect.MakeMap(f.ft)) |
|
377 |
} |
|
378 |
iter := mapRange(srcm) |
|
379 |
for iter.Next() { |
|
380 |
val := reflect.New(f.ft.Elem().Elem()) |
|
381 |
if f.mi != nil { |
|
382 |
f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts) |
|
383 |
} else { |
|
384 |
opts.Merge(asMessage(val), asMessage(iter.Value())) |
|
385 |
} |
|
386 |
dstm.SetMapIndex(iter.Key(), val) |
|
387 |
} |
|
388 |
} |