256
|
1 |
// Copyright 2019 The Go Authors. All rights reserved. |
|
2 |
// Use of this source code is governed by a BSD-style |
|
3 |
// license that can be found in the LICENSE file. |
|
4 |
|
|
5 |
package impl |
|
6 |
|
|
7 |
import ( |
|
8 |
"math/bits" |
|
9 |
|
|
10 |
"google.golang.org/protobuf/encoding/protowire" |
|
11 |
"google.golang.org/protobuf/internal/errors" |
|
12 |
"google.golang.org/protobuf/internal/flags" |
|
13 |
"google.golang.org/protobuf/proto" |
|
14 |
"google.golang.org/protobuf/reflect/protoreflect" |
260
|
15 |
"google.golang.org/protobuf/reflect/protoregistry" |
256
|
16 |
"google.golang.org/protobuf/runtime/protoiface" |
|
17 |
) |
|
18 |
|
|
19 |
var errDecode = errors.New("cannot parse invalid wire-format data") |
260
|
20 |
var errRecursionDepth = errors.New("exceeded maximum recursion depth") |
256
|
21 |
|
|
22 |
type unmarshalOptions struct { |
|
23 |
flags protoiface.UnmarshalInputFlags |
|
24 |
resolver interface { |
|
25 |
FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) |
|
26 |
FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) |
|
27 |
} |
260
|
28 |
depth int |
256
|
29 |
} |
|
30 |
|
|
31 |
func (o unmarshalOptions) Options() proto.UnmarshalOptions { |
|
32 |
return proto.UnmarshalOptions{ |
|
33 |
Merge: true, |
|
34 |
AllowPartial: true, |
|
35 |
DiscardUnknown: o.DiscardUnknown(), |
|
36 |
Resolver: o.resolver, |
|
37 |
} |
|
38 |
} |
|
39 |
|
260
|
40 |
func (o unmarshalOptions) DiscardUnknown() bool { |
|
41 |
return o.flags&protoiface.UnmarshalDiscardUnknown != 0 |
|
42 |
} |
256
|
43 |
|
|
44 |
func (o unmarshalOptions) IsDefault() bool { |
260
|
45 |
return o.flags == 0 && o.resolver == protoregistry.GlobalTypes |
256
|
46 |
} |
|
47 |
|
|
48 |
var lazyUnmarshalOptions = unmarshalOptions{ |
260
|
49 |
resolver: protoregistry.GlobalTypes, |
|
50 |
depth: protowire.DefaultRecursionLimit, |
256
|
51 |
} |
|
52 |
|
|
53 |
// unmarshalOutput is the internal result of decoding a message, field, or
// extension.
type unmarshalOutput struct {
	// n is the number of input bytes consumed.
	n int

	// initialized is true when decoding could confirm that all required
	// fields are set. False means this was not confirmed here.
	initialized bool
}
|
57 |
|
|
58 |
// unmarshal is protoreflect.Methods.Unmarshal. |
260
|
59 |
func (mi *MessageInfo) unmarshal(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) { |
256
|
60 |
var p pointer |
|
61 |
if ms, ok := in.Message.(*messageState); ok { |
|
62 |
p = ms.pointer() |
|
63 |
} else { |
|
64 |
p = in.Message.(*messageReflectWrapper).pointer() |
|
65 |
} |
|
66 |
out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{ |
|
67 |
flags: in.Flags, |
|
68 |
resolver: in.Resolver, |
260
|
69 |
depth: in.Depth, |
256
|
70 |
}) |
260
|
71 |
var flags protoiface.UnmarshalOutputFlags |
256
|
72 |
if out.initialized { |
260
|
73 |
flags |= protoiface.UnmarshalInitialized |
256
|
74 |
} |
260
|
75 |
return protoiface.UnmarshalOutput{ |
256
|
76 |
Flags: flags, |
|
77 |
}, err |
|
78 |
} |
|
79 |
|
|
80 |
// errUnknown is an internal sentinel returned during unmarshaling for a field
// that should be kept in the unknown-fields section rather than failing the
// whole decode; one example is a field whose wire type does not match.
// Errors such as a field extending past the end of the input still abort the
// entire operation.
//
// This sentinel should never be visible to the user.
var errUnknown = errors.New("unknown")
|
87 |
|
|
88 |
func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) { |
|
89 |
mi.init() |
260
|
90 |
opts.depth-- |
|
91 |
if opts.depth < 0 { |
|
92 |
return out, errRecursionDepth |
|
93 |
} |
256
|
94 |
if flags.ProtoLegacy && mi.isMessageSet { |
|
95 |
return unmarshalMessageSet(mi, b, p, opts) |
|
96 |
} |
|
97 |
initialized := true |
|
98 |
var requiredMask uint64 |
|
99 |
var exts *map[int32]ExtensionField |
|
100 |
start := len(b) |
|
101 |
for len(b) > 0 { |
|
102 |
// Parse the tag (field number and wire type). |
|
103 |
var tag uint64 |
|
104 |
if b[0] < 0x80 { |
|
105 |
tag = uint64(b[0]) |
|
106 |
b = b[1:] |
|
107 |
} else if len(b) >= 2 && b[1] < 128 { |
|
108 |
tag = uint64(b[0]&0x7f) + uint64(b[1])<<7 |
|
109 |
b = b[2:] |
|
110 |
} else { |
|
111 |
var n int |
|
112 |
tag, n = protowire.ConsumeVarint(b) |
|
113 |
if n < 0 { |
|
114 |
return out, errDecode |
|
115 |
} |
|
116 |
b = b[n:] |
|
117 |
} |
|
118 |
var num protowire.Number |
|
119 |
if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) { |
|
120 |
return out, errDecode |
|
121 |
} else { |
|
122 |
num = protowire.Number(n) |
|
123 |
} |
|
124 |
wtyp := protowire.Type(tag & 7) |
|
125 |
|
|
126 |
if wtyp == protowire.EndGroupType { |
|
127 |
if num != groupTag { |
|
128 |
return out, errDecode |
|
129 |
} |
|
130 |
groupTag = 0 |
|
131 |
break |
|
132 |
} |
|
133 |
|
|
134 |
var f *coderFieldInfo |
|
135 |
if int(num) < len(mi.denseCoderFields) { |
|
136 |
f = mi.denseCoderFields[num] |
|
137 |
} else { |
|
138 |
f = mi.coderFields[num] |
|
139 |
} |
|
140 |
var n int |
|
141 |
err := errUnknown |
|
142 |
switch { |
|
143 |
case f != nil: |
|
144 |
if f.funcs.unmarshal == nil { |
|
145 |
break |
|
146 |
} |
|
147 |
var o unmarshalOutput |
|
148 |
o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts) |
|
149 |
n = o.n |
|
150 |
if err != nil { |
|
151 |
break |
|
152 |
} |
|
153 |
requiredMask |= f.validation.requiredBit |
|
154 |
if f.funcs.isInit != nil && !o.initialized { |
|
155 |
initialized = false |
|
156 |
} |
|
157 |
default: |
|
158 |
// Possible extension. |
|
159 |
if exts == nil && mi.extensionOffset.IsValid() { |
|
160 |
exts = p.Apply(mi.extensionOffset).Extensions() |
|
161 |
if *exts == nil { |
|
162 |
*exts = make(map[int32]ExtensionField) |
|
163 |
} |
|
164 |
} |
|
165 |
if exts == nil { |
|
166 |
break |
|
167 |
} |
|
168 |
var o unmarshalOutput |
|
169 |
o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts) |
|
170 |
if err != nil { |
|
171 |
break |
|
172 |
} |
|
173 |
n = o.n |
|
174 |
if !o.initialized { |
|
175 |
initialized = false |
|
176 |
} |
|
177 |
} |
|
178 |
if err != nil { |
|
179 |
if err != errUnknown { |
|
180 |
return out, err |
|
181 |
} |
|
182 |
n = protowire.ConsumeFieldValue(num, wtyp, b) |
|
183 |
if n < 0 { |
|
184 |
return out, errDecode |
|
185 |
} |
|
186 |
if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() { |
|
187 |
u := mi.mutableUnknownBytes(p) |
|
188 |
*u = protowire.AppendTag(*u, num, wtyp) |
|
189 |
*u = append(*u, b[:n]...) |
|
190 |
} |
|
191 |
} |
|
192 |
b = b[n:] |
|
193 |
} |
|
194 |
if groupTag != 0 { |
|
195 |
return out, errDecode |
|
196 |
} |
|
197 |
if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) { |
|
198 |
initialized = false |
|
199 |
} |
|
200 |
if initialized { |
|
201 |
out.initialized = true |
|
202 |
} |
|
203 |
out.n = start - len(b) |
|
204 |
return out, nil |
|
205 |
} |
|
206 |
|
|
207 |
func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) { |
|
208 |
x := exts[int32(num)] |
|
209 |
xt := x.Type() |
|
210 |
if xt == nil { |
|
211 |
var err error |
|
212 |
xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num) |
|
213 |
if err != nil { |
260
|
214 |
if err == protoregistry.NotFound { |
256
|
215 |
return out, errUnknown |
|
216 |
} |
|
217 |
return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err) |
|
218 |
} |
|
219 |
} |
|
220 |
xi := getExtensionFieldInfo(xt) |
|
221 |
if xi.funcs.unmarshal == nil { |
|
222 |
return out, errUnknown |
|
223 |
} |
|
224 |
if flags.LazyUnmarshalExtensions { |
|
225 |
if opts.IsDefault() && x.canLazy(xt) { |
|
226 |
out, valid := skipExtension(b, xi, num, wtyp, opts) |
|
227 |
switch valid { |
|
228 |
case ValidationValid: |
|
229 |
if out.initialized { |
|
230 |
x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n]) |
|
231 |
exts[int32(num)] = x |
|
232 |
return out, nil |
|
233 |
} |
|
234 |
case ValidationInvalid: |
|
235 |
return out, errDecode |
|
236 |
case ValidationUnknown: |
|
237 |
} |
|
238 |
} |
|
239 |
} |
|
240 |
ival := x.Value() |
|
241 |
if !ival.IsValid() && xi.unmarshalNeedsValue { |
|
242 |
// Create a new message, list, or map value to fill in. |
|
243 |
// For enums, create a prototype value to let the unmarshal func know the |
|
244 |
// concrete type. |
|
245 |
ival = xt.New() |
|
246 |
} |
|
247 |
v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts) |
|
248 |
if err != nil { |
|
249 |
return out, err |
|
250 |
} |
|
251 |
if xi.funcs.isInit == nil { |
|
252 |
out.initialized = true |
|
253 |
} |
|
254 |
x.Set(xt, v) |
|
255 |
exts[int32(num)] = x |
|
256 |
return out, nil |
|
257 |
} |
|
258 |
|
|
259 |
func skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) { |
|
260 |
if xi.validation.mi == nil { |
|
261 |
return out, ValidationUnknown |
|
262 |
} |
|
263 |
xi.validation.mi.init() |
|
264 |
switch xi.validation.typ { |
|
265 |
case validationTypeMessage: |
|
266 |
if wtyp != protowire.BytesType { |
|
267 |
return out, ValidationUnknown |
|
268 |
} |
|
269 |
v, n := protowire.ConsumeBytes(b) |
|
270 |
if n < 0 { |
|
271 |
return out, ValidationUnknown |
|
272 |
} |
|
273 |
out, st := xi.validation.mi.validate(v, 0, opts) |
|
274 |
out.n = n |
|
275 |
return out, st |
|
276 |
case validationTypeGroup: |
|
277 |
if wtyp != protowire.StartGroupType { |
|
278 |
return out, ValidationUnknown |
|
279 |
} |
|
280 |
out, st := xi.validation.mi.validate(b, num, opts) |
|
281 |
return out, st |
|
282 |
default: |
|
283 |
return out, ValidationUnknown |
|
284 |
} |
|
285 |
} |