|
1 // Copyright 2018 The Go Authors. All rights reserved. |
|
2 // Use of this source code is governed by a BSD-style |
|
3 // license that can be found in the LICENSE file. |
|
4 |
|
5 package proto |
|
6 |
|
7 import ( |
|
8 "google.golang.org/protobuf/encoding/protowire" |
|
9 "google.golang.org/protobuf/internal/encoding/messageset" |
|
10 "google.golang.org/protobuf/internal/errors" |
|
11 "google.golang.org/protobuf/internal/flags" |
|
12 "google.golang.org/protobuf/internal/genid" |
|
13 "google.golang.org/protobuf/internal/pragma" |
|
14 "google.golang.org/protobuf/reflect/protoreflect" |
|
15 "google.golang.org/protobuf/reflect/protoregistry" |
|
16 "google.golang.org/protobuf/runtime/protoiface" |
|
17 ) |
|
18 |
|
19 // UnmarshalOptions configures the unmarshaler. |
|
20 // |
|
21 // Example usage: |
|
22 // err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m) |
|
23 type UnmarshalOptions struct { |
|
24 pragma.NoUnkeyedLiterals |
|
25 |
|
26 // Merge merges the input into the destination message. |
|
27 // The default behavior is to always reset the message before unmarshaling, |
|
28 // unless Merge is specified. |
|
29 Merge bool |
|
30 |
|
31 // AllowPartial accepts input for messages that will result in missing |
|
32 // required fields. If AllowPartial is false (the default), Unmarshal will |
|
33 // return an error if there are any missing required fields. |
|
34 AllowPartial bool |
|
35 |
|
36 // If DiscardUnknown is set, unknown fields are ignored. |
|
37 DiscardUnknown bool |
|
38 |
|
39 // Resolver is used for looking up types when unmarshaling extension fields. |
|
40 // If nil, this defaults to using protoregistry.GlobalTypes. |
|
41 Resolver interface { |
|
42 FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) |
|
43 FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) |
|
44 } |
|
45 } |
|
46 |
|
47 // Unmarshal parses the wire-format message in b and places the result in m. |
|
48 // The provided message must be mutable (e.g., a non-nil pointer to a message). |
|
49 func Unmarshal(b []byte, m Message) error { |
|
50 _, err := UnmarshalOptions{}.unmarshal(b, m.ProtoReflect()) |
|
51 return err |
|
52 } |
|
53 |
|
54 // Unmarshal parses the wire-format message in b and places the result in m. |
|
55 // The provided message must be mutable (e.g., a non-nil pointer to a message). |
|
56 func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error { |
|
57 _, err := o.unmarshal(b, m.ProtoReflect()) |
|
58 return err |
|
59 } |
|
60 |
|
61 // UnmarshalState parses a wire-format message and places the result in m. |
|
62 // |
|
63 // This method permits fine-grained control over the unmarshaler. |
|
64 // Most users should use Unmarshal instead. |
|
65 func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) { |
|
66 return o.unmarshal(in.Buf, in.Message) |
|
67 } |
|
68 |
|
69 // unmarshal is a centralized function that all unmarshal operations go through. |
|
70 // For profiling purposes, avoid changing the name of this function or |
|
71 // introducing other code paths for unmarshal that do not go through this. |
|
72 func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) { |
|
73 if o.Resolver == nil { |
|
74 o.Resolver = protoregistry.GlobalTypes |
|
75 } |
|
76 if !o.Merge { |
|
77 Reset(m.Interface()) |
|
78 } |
|
79 allowPartial := o.AllowPartial |
|
80 o.Merge = true |
|
81 o.AllowPartial = true |
|
82 methods := protoMethods(m) |
|
83 if methods != nil && methods.Unmarshal != nil && |
|
84 !(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) { |
|
85 in := protoiface.UnmarshalInput{ |
|
86 Message: m, |
|
87 Buf: b, |
|
88 Resolver: o.Resolver, |
|
89 } |
|
90 if o.DiscardUnknown { |
|
91 in.Flags |= protoiface.UnmarshalDiscardUnknown |
|
92 } |
|
93 out, err = methods.Unmarshal(in) |
|
94 } else { |
|
95 err = o.unmarshalMessageSlow(b, m) |
|
96 } |
|
97 if err != nil { |
|
98 return out, err |
|
99 } |
|
100 if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) { |
|
101 return out, nil |
|
102 } |
|
103 return out, checkInitialized(m) |
|
104 } |
|
105 |
|
106 func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error { |
|
107 _, err := o.unmarshal(b, m) |
|
108 return err |
|
109 } |
|
110 |
|
111 func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error { |
|
112 md := m.Descriptor() |
|
113 if messageset.IsMessageSet(md) { |
|
114 return o.unmarshalMessageSet(b, m) |
|
115 } |
|
116 fields := md.Fields() |
|
117 for len(b) > 0 { |
|
118 // Parse the tag (field number and wire type). |
|
119 num, wtyp, tagLen := protowire.ConsumeTag(b) |
|
120 if tagLen < 0 { |
|
121 return errDecode |
|
122 } |
|
123 if num > protowire.MaxValidNumber { |
|
124 return errDecode |
|
125 } |
|
126 |
|
127 // Find the field descriptor for this field number. |
|
128 fd := fields.ByNumber(num) |
|
129 if fd == nil && md.ExtensionRanges().Has(num) { |
|
130 extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num) |
|
131 if err != nil && err != protoregistry.NotFound { |
|
132 return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, err) |
|
133 } |
|
134 if extType != nil { |
|
135 fd = extType.TypeDescriptor() |
|
136 } |
|
137 } |
|
138 var err error |
|
139 if fd == nil { |
|
140 err = errUnknown |
|
141 } else if flags.ProtoLegacy { |
|
142 if fd.IsWeak() && fd.Message().IsPlaceholder() { |
|
143 err = errUnknown // weak referent is not linked in |
|
144 } |
|
145 } |
|
146 |
|
147 // Parse the field value. |
|
148 var valLen int |
|
149 switch { |
|
150 case err != nil: |
|
151 case fd.IsList(): |
|
152 valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd) |
|
153 case fd.IsMap(): |
|
154 valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd) |
|
155 default: |
|
156 valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd) |
|
157 } |
|
158 if err != nil { |
|
159 if err != errUnknown { |
|
160 return err |
|
161 } |
|
162 valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:]) |
|
163 if valLen < 0 { |
|
164 return errDecode |
|
165 } |
|
166 if !o.DiscardUnknown { |
|
167 m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...)) |
|
168 } |
|
169 } |
|
170 b = b[tagLen+valLen:] |
|
171 } |
|
172 return nil |
|
173 } |
|
174 |
|
175 func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) { |
|
176 v, n, err := o.unmarshalScalar(b, wtyp, fd) |
|
177 if err != nil { |
|
178 return 0, err |
|
179 } |
|
180 switch fd.Kind() { |
|
181 case protoreflect.GroupKind, protoreflect.MessageKind: |
|
182 m2 := m.Mutable(fd).Message() |
|
183 if err := o.unmarshalMessage(v.Bytes(), m2); err != nil { |
|
184 return n, err |
|
185 } |
|
186 default: |
|
187 // Non-message scalars replace the previous value. |
|
188 m.Set(fd, v) |
|
189 } |
|
190 return n, nil |
|
191 } |
|
192 |
|
193 func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) { |
|
194 if wtyp != protowire.BytesType { |
|
195 return 0, errUnknown |
|
196 } |
|
197 b, n = protowire.ConsumeBytes(b) |
|
198 if n < 0 { |
|
199 return 0, errDecode |
|
200 } |
|
201 var ( |
|
202 keyField = fd.MapKey() |
|
203 valField = fd.MapValue() |
|
204 key protoreflect.Value |
|
205 val protoreflect.Value |
|
206 haveKey bool |
|
207 haveVal bool |
|
208 ) |
|
209 switch valField.Kind() { |
|
210 case protoreflect.GroupKind, protoreflect.MessageKind: |
|
211 val = mapv.NewValue() |
|
212 } |
|
213 // Map entries are represented as a two-element message with fields |
|
214 // containing the key and value. |
|
215 for len(b) > 0 { |
|
216 num, wtyp, n := protowire.ConsumeTag(b) |
|
217 if n < 0 { |
|
218 return 0, errDecode |
|
219 } |
|
220 if num > protowire.MaxValidNumber { |
|
221 return 0, errDecode |
|
222 } |
|
223 b = b[n:] |
|
224 err = errUnknown |
|
225 switch num { |
|
226 case genid.MapEntry_Key_field_number: |
|
227 key, n, err = o.unmarshalScalar(b, wtyp, keyField) |
|
228 if err != nil { |
|
229 break |
|
230 } |
|
231 haveKey = true |
|
232 case genid.MapEntry_Value_field_number: |
|
233 var v protoreflect.Value |
|
234 v, n, err = o.unmarshalScalar(b, wtyp, valField) |
|
235 if err != nil { |
|
236 break |
|
237 } |
|
238 switch valField.Kind() { |
|
239 case protoreflect.GroupKind, protoreflect.MessageKind: |
|
240 if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil { |
|
241 return 0, err |
|
242 } |
|
243 default: |
|
244 val = v |
|
245 } |
|
246 haveVal = true |
|
247 } |
|
248 if err == errUnknown { |
|
249 n = protowire.ConsumeFieldValue(num, wtyp, b) |
|
250 if n < 0 { |
|
251 return 0, errDecode |
|
252 } |
|
253 } else if err != nil { |
|
254 return 0, err |
|
255 } |
|
256 b = b[n:] |
|
257 } |
|
258 // Every map entry should have entries for key and value, but this is not strictly required. |
|
259 if !haveKey { |
|
260 key = keyField.Default() |
|
261 } |
|
262 if !haveVal { |
|
263 switch valField.Kind() { |
|
264 case protoreflect.GroupKind, protoreflect.MessageKind: |
|
265 default: |
|
266 val = valField.Default() |
|
267 } |
|
268 } |
|
269 mapv.Set(key.MapKey(), val) |
|
270 return n, nil |
|
271 } |
|
272 |
|
var (
	// errUnknown is an internal sentinel marking a field that should be
	// skipped and, unless DiscardUnknown is set, added to the message's
	// unknown-field set. It is never returned from an exported function.
	errUnknown = errors.New("BUG: internal error (unknown)")

	// errDecode reports input that is not valid wire-format data.
	errDecode = errors.New("cannot parse invalid wire-format data")
)