|
1 // Copyright 2019 The Go Authors. All rights reserved. |
|
2 // Use of this source code is governed by a BSD-style |
|
3 // license that can be found in the LICENSE file. |
|
4 |
|
5 package impl |
|
6 |
|
7 import ( |
|
8 "math/bits" |
|
9 |
|
10 "google.golang.org/protobuf/encoding/protowire" |
|
11 "google.golang.org/protobuf/internal/errors" |
|
12 "google.golang.org/protobuf/internal/flags" |
|
13 "google.golang.org/protobuf/proto" |
|
14 "google.golang.org/protobuf/reflect/protoreflect" |
|
15 preg "google.golang.org/protobuf/reflect/protoregistry" |
|
16 "google.golang.org/protobuf/runtime/protoiface" |
|
17 piface "google.golang.org/protobuf/runtime/protoiface" |
|
18 ) |
|
19 |
|
20 var errDecode = errors.New("cannot parse invalid wire-format data") |
|
21 |
|
22 type unmarshalOptions struct { |
|
23 flags protoiface.UnmarshalInputFlags |
|
24 resolver interface { |
|
25 FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) |
|
26 FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) |
|
27 } |
|
28 } |
|
29 |
|
30 func (o unmarshalOptions) Options() proto.UnmarshalOptions { |
|
31 return proto.UnmarshalOptions{ |
|
32 Merge: true, |
|
33 AllowPartial: true, |
|
34 DiscardUnknown: o.DiscardUnknown(), |
|
35 Resolver: o.resolver, |
|
36 } |
|
37 } |
|
38 |
|
39 func (o unmarshalOptions) DiscardUnknown() bool { return o.flags&piface.UnmarshalDiscardUnknown != 0 } |
|
40 |
|
41 func (o unmarshalOptions) IsDefault() bool { |
|
42 return o.flags == 0 && o.resolver == preg.GlobalTypes |
|
43 } |
|
44 |
|
45 var lazyUnmarshalOptions = unmarshalOptions{ |
|
46 resolver: preg.GlobalTypes, |
|
47 } |
|
48 |
|
// unmarshalOutput is the result of an internal unmarshal step.
type unmarshalOutput struct {
	// n is the number of bytes consumed from the input.
	n int

	// initialized reports whether the decoder found the message fully
	// initialized: all required fields present and every decoded field that
	// has an initialization check reporting itself initialized.
	initialized bool
}
|
53 |
|
54 // unmarshal is protoreflect.Methods.Unmarshal. |
|
55 func (mi *MessageInfo) unmarshal(in piface.UnmarshalInput) (piface.UnmarshalOutput, error) { |
|
56 var p pointer |
|
57 if ms, ok := in.Message.(*messageState); ok { |
|
58 p = ms.pointer() |
|
59 } else { |
|
60 p = in.Message.(*messageReflectWrapper).pointer() |
|
61 } |
|
62 out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{ |
|
63 flags: in.Flags, |
|
64 resolver: in.Resolver, |
|
65 }) |
|
66 var flags piface.UnmarshalOutputFlags |
|
67 if out.initialized { |
|
68 flags |= piface.UnmarshalInitialized |
|
69 } |
|
70 return piface.UnmarshalOutput{ |
|
71 Flags: flags, |
|
72 }, err |
|
73 } |
|
74 |
|
75 // errUnknown is returned during unmarshaling to indicate a parse error that |
|
76 // should result in a field being placed in the unknown fields section (for example, |
|
77 // when the wire type doesn't match) as opposed to the entire unmarshal operation |
|
78 // failing (for example, when a field extends past the available input). |
|
79 // |
|
80 // This is a sentinel error which should never be visible to the user. |
|
81 var errUnknown = errors.New("unknown") |
|
82 |
|
83 func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) { |
|
84 mi.init() |
|
85 if flags.ProtoLegacy && mi.isMessageSet { |
|
86 return unmarshalMessageSet(mi, b, p, opts) |
|
87 } |
|
88 initialized := true |
|
89 var requiredMask uint64 |
|
90 var exts *map[int32]ExtensionField |
|
91 start := len(b) |
|
92 for len(b) > 0 { |
|
93 // Parse the tag (field number and wire type). |
|
94 var tag uint64 |
|
95 if b[0] < 0x80 { |
|
96 tag = uint64(b[0]) |
|
97 b = b[1:] |
|
98 } else if len(b) >= 2 && b[1] < 128 { |
|
99 tag = uint64(b[0]&0x7f) + uint64(b[1])<<7 |
|
100 b = b[2:] |
|
101 } else { |
|
102 var n int |
|
103 tag, n = protowire.ConsumeVarint(b) |
|
104 if n < 0 { |
|
105 return out, errDecode |
|
106 } |
|
107 b = b[n:] |
|
108 } |
|
109 var num protowire.Number |
|
110 if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) { |
|
111 return out, errDecode |
|
112 } else { |
|
113 num = protowire.Number(n) |
|
114 } |
|
115 wtyp := protowire.Type(tag & 7) |
|
116 |
|
117 if wtyp == protowire.EndGroupType { |
|
118 if num != groupTag { |
|
119 return out, errDecode |
|
120 } |
|
121 groupTag = 0 |
|
122 break |
|
123 } |
|
124 |
|
125 var f *coderFieldInfo |
|
126 if int(num) < len(mi.denseCoderFields) { |
|
127 f = mi.denseCoderFields[num] |
|
128 } else { |
|
129 f = mi.coderFields[num] |
|
130 } |
|
131 var n int |
|
132 err := errUnknown |
|
133 switch { |
|
134 case f != nil: |
|
135 if f.funcs.unmarshal == nil { |
|
136 break |
|
137 } |
|
138 var o unmarshalOutput |
|
139 o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts) |
|
140 n = o.n |
|
141 if err != nil { |
|
142 break |
|
143 } |
|
144 requiredMask |= f.validation.requiredBit |
|
145 if f.funcs.isInit != nil && !o.initialized { |
|
146 initialized = false |
|
147 } |
|
148 default: |
|
149 // Possible extension. |
|
150 if exts == nil && mi.extensionOffset.IsValid() { |
|
151 exts = p.Apply(mi.extensionOffset).Extensions() |
|
152 if *exts == nil { |
|
153 *exts = make(map[int32]ExtensionField) |
|
154 } |
|
155 } |
|
156 if exts == nil { |
|
157 break |
|
158 } |
|
159 var o unmarshalOutput |
|
160 o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts) |
|
161 if err != nil { |
|
162 break |
|
163 } |
|
164 n = o.n |
|
165 if !o.initialized { |
|
166 initialized = false |
|
167 } |
|
168 } |
|
169 if err != nil { |
|
170 if err != errUnknown { |
|
171 return out, err |
|
172 } |
|
173 n = protowire.ConsumeFieldValue(num, wtyp, b) |
|
174 if n < 0 { |
|
175 return out, errDecode |
|
176 } |
|
177 if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() { |
|
178 u := mi.mutableUnknownBytes(p) |
|
179 *u = protowire.AppendTag(*u, num, wtyp) |
|
180 *u = append(*u, b[:n]...) |
|
181 } |
|
182 } |
|
183 b = b[n:] |
|
184 } |
|
185 if groupTag != 0 { |
|
186 return out, errDecode |
|
187 } |
|
188 if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) { |
|
189 initialized = false |
|
190 } |
|
191 if initialized { |
|
192 out.initialized = true |
|
193 } |
|
194 out.n = start - len(b) |
|
195 return out, nil |
|
196 } |
|
197 |
|
198 func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) { |
|
199 x := exts[int32(num)] |
|
200 xt := x.Type() |
|
201 if xt == nil { |
|
202 var err error |
|
203 xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num) |
|
204 if err != nil { |
|
205 if err == preg.NotFound { |
|
206 return out, errUnknown |
|
207 } |
|
208 return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err) |
|
209 } |
|
210 } |
|
211 xi := getExtensionFieldInfo(xt) |
|
212 if xi.funcs.unmarshal == nil { |
|
213 return out, errUnknown |
|
214 } |
|
215 if flags.LazyUnmarshalExtensions { |
|
216 if opts.IsDefault() && x.canLazy(xt) { |
|
217 out, valid := skipExtension(b, xi, num, wtyp, opts) |
|
218 switch valid { |
|
219 case ValidationValid: |
|
220 if out.initialized { |
|
221 x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n]) |
|
222 exts[int32(num)] = x |
|
223 return out, nil |
|
224 } |
|
225 case ValidationInvalid: |
|
226 return out, errDecode |
|
227 case ValidationUnknown: |
|
228 } |
|
229 } |
|
230 } |
|
231 ival := x.Value() |
|
232 if !ival.IsValid() && xi.unmarshalNeedsValue { |
|
233 // Create a new message, list, or map value to fill in. |
|
234 // For enums, create a prototype value to let the unmarshal func know the |
|
235 // concrete type. |
|
236 ival = xt.New() |
|
237 } |
|
238 v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts) |
|
239 if err != nil { |
|
240 return out, err |
|
241 } |
|
242 if xi.funcs.isInit == nil { |
|
243 out.initialized = true |
|
244 } |
|
245 x.Set(xt, v) |
|
246 exts[int32(num)] = x |
|
247 return out, nil |
|
248 } |
|
249 |
|
250 func skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) { |
|
251 if xi.validation.mi == nil { |
|
252 return out, ValidationUnknown |
|
253 } |
|
254 xi.validation.mi.init() |
|
255 switch xi.validation.typ { |
|
256 case validationTypeMessage: |
|
257 if wtyp != protowire.BytesType { |
|
258 return out, ValidationUnknown |
|
259 } |
|
260 v, n := protowire.ConsumeBytes(b) |
|
261 if n < 0 { |
|
262 return out, ValidationUnknown |
|
263 } |
|
264 out, st := xi.validation.mi.validate(v, 0, opts) |
|
265 out.n = n |
|
266 return out, st |
|
267 case validationTypeGroup: |
|
268 if wtyp != protowire.StartGroupType { |
|
269 return out, ValidationUnknown |
|
270 } |
|
271 out, st := xi.validation.mi.validate(b, num, opts) |
|
272 return out, st |
|
273 default: |
|
274 return out, ValidationUnknown |
|
275 } |
|
276 } |