256
|
1 |
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
|
4 |
|
|
5 |
package proto |
|
6 |
|
|
7 |
import ( |
|
8 |
"google.golang.org/protobuf/encoding/protowire" |
|
9 |
"google.golang.org/protobuf/internal/encoding/messageset" |
|
10 |
"google.golang.org/protobuf/internal/errors" |
|
11 |
"google.golang.org/protobuf/internal/flags" |
|
12 |
"google.golang.org/protobuf/internal/genid" |
|
13 |
"google.golang.org/protobuf/internal/pragma" |
|
14 |
"google.golang.org/protobuf/reflect/protoreflect" |
|
15 |
"google.golang.org/protobuf/reflect/protoregistry" |
|
16 |
"google.golang.org/protobuf/runtime/protoiface" |
|
17 |
) |
|
18 |
|
|
19 |
// UnmarshalOptions configures the unmarshaler. |
|
20 |
// |
|
21 |
// Example usage: |
|
22 |
// err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m) |
|
23 |
type UnmarshalOptions struct { |
|
24 |
pragma.NoUnkeyedLiterals |
|
25 |
|
|
26 |
// Merge merges the input into the destination message. |
|
27 |
// The default behavior is to always reset the message before unmarshaling, |
|
28 |
// unless Merge is specified. |
|
29 |
Merge bool |
|
30 |
|
|
31 |
// AllowPartial accepts input for messages that will result in missing |
|
32 |
// required fields. If AllowPartial is false (the default), Unmarshal will |
|
33 |
// return an error if there are any missing required fields. |
|
34 |
AllowPartial bool |
|
35 |
|
|
36 |
// If DiscardUnknown is set, unknown fields are ignored. |
|
37 |
DiscardUnknown bool |
|
38 |
|
|
39 |
// Resolver is used for looking up types when unmarshaling extension fields. |
|
40 |
// If nil, this defaults to using protoregistry.GlobalTypes. |
|
41 |
Resolver interface { |
|
42 |
FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) |
|
43 |
FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) |
|
44 |
} |
|
45 |
} |
|
46 |
|
|
47 |
// Unmarshal parses the wire-format message in b and places the result in m. |
|
48 |
// The provided message must be mutable (e.g., a non-nil pointer to a message). |
|
49 |
func Unmarshal(b []byte, m Message) error { |
|
50 |
_, err := UnmarshalOptions{}.unmarshal(b, m.ProtoReflect()) |
|
51 |
return err |
|
52 |
} |
|
53 |
|
|
54 |
// Unmarshal parses the wire-format message in b and places the result in m. |
|
55 |
// The provided message must be mutable (e.g., a non-nil pointer to a message). |
|
56 |
func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error { |
|
57 |
_, err := o.unmarshal(b, m.ProtoReflect()) |
|
58 |
return err |
|
59 |
} |
|
60 |
|
|
61 |
// UnmarshalState parses a wire-format message and places the result in m. |
|
62 |
// |
|
63 |
// This method permits fine-grained control over the unmarshaler. |
|
64 |
// Most users should use Unmarshal instead. |
|
65 |
func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) { |
|
66 |
return o.unmarshal(in.Buf, in.Message) |
|
67 |
} |
|
68 |
|
|
69 |
// unmarshal is a centralized function that all unmarshal operations go through. |
|
70 |
// For profiling purposes, avoid changing the name of this function or |
|
71 |
// introducing other code paths for unmarshal that do not go through this. |
|
72 |
func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) { |
|
73 |
if o.Resolver == nil { |
|
74 |
o.Resolver = protoregistry.GlobalTypes |
|
75 |
} |
|
76 |
if !o.Merge { |
|
77 |
Reset(m.Interface()) |
|
78 |
} |
|
79 |
allowPartial := o.AllowPartial |
|
80 |
o.Merge = true |
|
81 |
o.AllowPartial = true |
|
82 |
methods := protoMethods(m) |
|
83 |
if methods != nil && methods.Unmarshal != nil && |
|
84 |
!(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) { |
|
85 |
in := protoiface.UnmarshalInput{ |
|
86 |
Message: m, |
|
87 |
Buf: b, |
|
88 |
Resolver: o.Resolver, |
|
89 |
} |
|
90 |
if o.DiscardUnknown { |
|
91 |
in.Flags |= protoiface.UnmarshalDiscardUnknown |
|
92 |
} |
|
93 |
out, err = methods.Unmarshal(in) |
|
94 |
} else { |
|
95 |
err = o.unmarshalMessageSlow(b, m) |
|
96 |
} |
|
97 |
if err != nil { |
|
98 |
return out, err |
|
99 |
} |
|
100 |
if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) { |
|
101 |
return out, nil |
|
102 |
} |
|
103 |
return out, checkInitialized(m) |
|
104 |
} |
|
105 |
|
|
106 |
func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error { |
|
107 |
_, err := o.unmarshal(b, m) |
|
108 |
return err |
|
109 |
} |
|
110 |
|
|
111 |
func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error { |
|
112 |
md := m.Descriptor() |
|
113 |
if messageset.IsMessageSet(md) { |
|
114 |
return o.unmarshalMessageSet(b, m) |
|
115 |
} |
|
116 |
fields := md.Fields() |
|
117 |
for len(b) > 0 { |
|
118 |
// Parse the tag (field number and wire type). |
|
119 |
num, wtyp, tagLen := protowire.ConsumeTag(b) |
|
120 |
if tagLen < 0 { |
|
121 |
return errDecode |
|
122 |
} |
|
123 |
if num > protowire.MaxValidNumber { |
|
124 |
return errDecode |
|
125 |
} |
|
126 |
|
|
127 |
// Find the field descriptor for this field number. |
|
128 |
fd := fields.ByNumber(num) |
|
129 |
if fd == nil && md.ExtensionRanges().Has(num) { |
|
130 |
extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num) |
|
131 |
if err != nil && err != protoregistry.NotFound { |
|
132 |
return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, err) |
|
133 |
} |
|
134 |
if extType != nil { |
|
135 |
fd = extType.TypeDescriptor() |
|
136 |
} |
|
137 |
} |
|
138 |
var err error |
|
139 |
if fd == nil { |
|
140 |
err = errUnknown |
|
141 |
} else if flags.ProtoLegacy { |
|
142 |
if fd.IsWeak() && fd.Message().IsPlaceholder() { |
|
143 |
err = errUnknown // weak referent is not linked in |
|
144 |
} |
|
145 |
} |
|
146 |
|
|
147 |
// Parse the field value. |
|
148 |
var valLen int |
|
149 |
switch { |
|
150 |
case err != nil: |
|
151 |
case fd.IsList(): |
|
152 |
valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd) |
|
153 |
case fd.IsMap(): |
|
154 |
valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd) |
|
155 |
default: |
|
156 |
valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd) |
|
157 |
} |
|
158 |
if err != nil { |
|
159 |
if err != errUnknown { |
|
160 |
return err |
|
161 |
} |
|
162 |
valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:]) |
|
163 |
if valLen < 0 { |
|
164 |
return errDecode |
|
165 |
} |
|
166 |
if !o.DiscardUnknown { |
|
167 |
m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...)) |
|
168 |
} |
|
169 |
} |
|
170 |
b = b[tagLen+valLen:] |
|
171 |
} |
|
172 |
return nil |
|
173 |
} |
|
174 |
|
|
175 |
func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) { |
|
176 |
v, n, err := o.unmarshalScalar(b, wtyp, fd) |
|
177 |
if err != nil { |
|
178 |
return 0, err |
|
179 |
} |
|
180 |
switch fd.Kind() { |
|
181 |
case protoreflect.GroupKind, protoreflect.MessageKind: |
|
182 |
m2 := m.Mutable(fd).Message() |
|
183 |
if err := o.unmarshalMessage(v.Bytes(), m2); err != nil { |
|
184 |
return n, err |
|
185 |
} |
|
186 |
default: |
|
187 |
// Non-message scalars replace the previous value. |
|
188 |
m.Set(fd, v) |
|
189 |
} |
|
190 |
return n, nil |
|
191 |
} |
|
192 |
|
|
193 |
func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) { |
|
194 |
if wtyp != protowire.BytesType { |
|
195 |
return 0, errUnknown |
|
196 |
} |
|
197 |
b, n = protowire.ConsumeBytes(b) |
|
198 |
if n < 0 { |
|
199 |
return 0, errDecode |
|
200 |
} |
|
201 |
var ( |
|
202 |
keyField = fd.MapKey() |
|
203 |
valField = fd.MapValue() |
|
204 |
key protoreflect.Value |
|
205 |
val protoreflect.Value |
|
206 |
haveKey bool |
|
207 |
haveVal bool |
|
208 |
) |
|
209 |
switch valField.Kind() { |
|
210 |
case protoreflect.GroupKind, protoreflect.MessageKind: |
|
211 |
val = mapv.NewValue() |
|
212 |
} |
|
213 |
// Map entries are represented as a two-element message with fields |
|
214 |
// containing the key and value. |
|
215 |
for len(b) > 0 { |
|
216 |
num, wtyp, n := protowire.ConsumeTag(b) |
|
217 |
if n < 0 { |
|
218 |
return 0, errDecode |
|
219 |
} |
|
220 |
if num > protowire.MaxValidNumber { |
|
221 |
return 0, errDecode |
|
222 |
} |
|
223 |
b = b[n:] |
|
224 |
err = errUnknown |
|
225 |
switch num { |
|
226 |
case genid.MapEntry_Key_field_number: |
|
227 |
key, n, err = o.unmarshalScalar(b, wtyp, keyField) |
|
228 |
if err != nil { |
|
229 |
break |
|
230 |
} |
|
231 |
haveKey = true |
|
232 |
case genid.MapEntry_Value_field_number: |
|
233 |
var v protoreflect.Value |
|
234 |
v, n, err = o.unmarshalScalar(b, wtyp, valField) |
|
235 |
if err != nil { |
|
236 |
break |
|
237 |
} |
|
238 |
switch valField.Kind() { |
|
239 |
case protoreflect.GroupKind, protoreflect.MessageKind: |
|
240 |
if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil { |
|
241 |
return 0, err |
|
242 |
} |
|
243 |
default: |
|
244 |
val = v |
|
245 |
} |
|
246 |
haveVal = true |
|
247 |
} |
|
248 |
if err == errUnknown { |
|
249 |
n = protowire.ConsumeFieldValue(num, wtyp, b) |
|
250 |
if n < 0 { |
|
251 |
return 0, errDecode |
|
252 |
} |
|
253 |
} else if err != nil { |
|
254 |
return 0, err |
|
255 |
} |
|
256 |
b = b[n:] |
|
257 |
} |
|
258 |
// Every map entry should have entries for key and value, but this is not strictly required. |
|
259 |
if !haveKey { |
|
260 |
key = keyField.Default() |
|
261 |
} |
|
262 |
if !haveVal { |
|
263 |
switch valField.Kind() { |
|
264 |
case protoreflect.GroupKind, protoreflect.MessageKind: |
|
265 |
default: |
|
266 |
val = valField.Default() |
|
267 |
} |
|
268 |
} |
|
269 |
mapv.Set(key.MapKey(), val) |
|
270 |
return n, nil |
|
271 |
} |
|
272 |
|
|
273 |
// errUnknown is used internally to indicate fields which should be added |
|
274 |
// to the unknown field set of a message. It is never returned from an exported |
|
275 |
// function. |
|
276 |
var errUnknown = errors.New("BUG: internal error (unknown)") |
|
277 |
|
|
278 |
var errDecode = errors.New("cannot parse invalid wire-format data") |