14 "google.golang.org/protobuf/encoding/protowire" |
14 "google.golang.org/protobuf/encoding/protowire" |
15 "google.golang.org/protobuf/internal/encoding/messageset" |
15 "google.golang.org/protobuf/internal/encoding/messageset" |
16 "google.golang.org/protobuf/internal/flags" |
16 "google.golang.org/protobuf/internal/flags" |
17 "google.golang.org/protobuf/internal/genid" |
17 "google.golang.org/protobuf/internal/genid" |
18 "google.golang.org/protobuf/internal/strs" |
18 "google.golang.org/protobuf/internal/strs" |
19 pref "google.golang.org/protobuf/reflect/protoreflect" |
19 "google.golang.org/protobuf/reflect/protoreflect" |
20 preg "google.golang.org/protobuf/reflect/protoregistry" |
20 "google.golang.org/protobuf/reflect/protoregistry" |
21 piface "google.golang.org/protobuf/runtime/protoiface" |
21 "google.golang.org/protobuf/runtime/protoiface" |
22 ) |
22 ) |
23 |
23 |
24 // ValidationStatus is the result of validating the wire-format encoding of a message. |
24 // ValidationStatus is the result of validating the wire-format encoding of a message. |
25 type ValidationStatus int |
25 type ValidationStatus int |
26 |
26 |
54 |
54 |
55 // Validate determines whether the contents of the buffer are a valid wire encoding |
55 // Validate determines whether the contents of the buffer are a valid wire encoding |
56 // of the message type. |
56 // of the message type. |
57 // |
57 // |
58 // This function is exposed for testing. |
58 // This function is exposed for testing. |
59 func Validate(mt pref.MessageType, in piface.UnmarshalInput) (out piface.UnmarshalOutput, _ ValidationStatus) { |
59 func Validate(mt protoreflect.MessageType, in protoiface.UnmarshalInput) (out protoiface.UnmarshalOutput, _ ValidationStatus) { |
60 mi, ok := mt.(*MessageInfo) |
60 mi, ok := mt.(*MessageInfo) |
61 if !ok { |
61 if !ok { |
62 return out, ValidationUnknown |
62 return out, ValidationUnknown |
63 } |
63 } |
64 if in.Resolver == nil { |
64 if in.Resolver == nil { |
65 in.Resolver = preg.GlobalTypes |
65 in.Resolver = protoregistry.GlobalTypes |
66 } |
66 } |
67 o, st := mi.validate(in.Buf, 0, unmarshalOptions{ |
67 o, st := mi.validate(in.Buf, 0, unmarshalOptions{ |
68 flags: in.Flags, |
68 flags: in.Flags, |
69 resolver: in.Resolver, |
69 resolver: in.Resolver, |
70 }) |
70 }) |
71 if o.initialized { |
71 if o.initialized { |
72 out.Flags |= piface.UnmarshalInitialized |
72 out.Flags |= protoiface.UnmarshalInitialized |
73 } |
73 } |
74 return out, st |
74 return out, st |
75 } |
75 } |
76 |
76 |
77 type validationInfo struct { |
77 type validationInfo struct { |
104 validationTypeBytes |
104 validationTypeBytes |
105 validationTypeUTF8String |
105 validationTypeUTF8String |
106 validationTypeMessageSetItem |
106 validationTypeMessageSetItem |
107 ) |
107 ) |
108 |
108 |
109 func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd pref.FieldDescriptor, ft reflect.Type) validationInfo { |
109 func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo { |
110 var vi validationInfo |
110 var vi validationInfo |
111 switch { |
111 switch { |
112 case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic(): |
112 case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic(): |
113 switch fd.Kind() { |
113 switch fd.Kind() { |
114 case pref.MessageKind: |
114 case protoreflect.MessageKind: |
115 vi.typ = validationTypeMessage |
115 vi.typ = validationTypeMessage |
116 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok { |
116 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok { |
117 vi.mi = getMessageInfo(ot.Field(0).Type) |
117 vi.mi = getMessageInfo(ot.Field(0).Type) |
118 } |
118 } |
119 case pref.GroupKind: |
119 case protoreflect.GroupKind: |
120 vi.typ = validationTypeGroup |
120 vi.typ = validationTypeGroup |
121 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok { |
121 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok { |
122 vi.mi = getMessageInfo(ot.Field(0).Type) |
122 vi.mi = getMessageInfo(ot.Field(0).Type) |
123 } |
123 } |
124 case pref.StringKind: |
124 case protoreflect.StringKind: |
125 if strs.EnforceUTF8(fd) { |
125 if strs.EnforceUTF8(fd) { |
126 vi.typ = validationTypeUTF8String |
126 vi.typ = validationTypeUTF8String |
127 } |
127 } |
128 } |
128 } |
129 default: |
129 default: |
130 vi = newValidationInfo(fd, ft) |
130 vi = newValidationInfo(fd, ft) |
131 } |
131 } |
132 if fd.Cardinality() == pref.Required { |
132 if fd.Cardinality() == protoreflect.Required { |
133 // Avoid overflow. The required field check is done with a 64-bit mask, with |
133 // Avoid overflow. The required field check is done with a 64-bit mask, with |
134 // any message containing more than 64 required fields always reported as |
134 // any message containing more than 64 required fields always reported as |
135 // potentially uninitialized, so it is not important to get a precise count |
135 // potentially uninitialized, so it is not important to get a precise count |
136 // of the required fields past 64. |
136 // of the required fields past 64. |
137 if mi.numRequiredFields < math.MaxUint8 { |
137 if mi.numRequiredFields < math.MaxUint8 { |
140 } |
140 } |
141 } |
141 } |
142 return vi |
142 return vi |
143 } |
143 } |
144 |
144 |
145 func newValidationInfo(fd pref.FieldDescriptor, ft reflect.Type) validationInfo { |
145 func newValidationInfo(fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo { |
146 var vi validationInfo |
146 var vi validationInfo |
147 switch { |
147 switch { |
148 case fd.IsList(): |
148 case fd.IsList(): |
149 switch fd.Kind() { |
149 switch fd.Kind() { |
150 case pref.MessageKind: |
150 case protoreflect.MessageKind: |
151 vi.typ = validationTypeMessage |
151 vi.typ = validationTypeMessage |
152 if ft.Kind() == reflect.Slice { |
152 if ft.Kind() == reflect.Slice { |
153 vi.mi = getMessageInfo(ft.Elem()) |
153 vi.mi = getMessageInfo(ft.Elem()) |
154 } |
154 } |
155 case pref.GroupKind: |
155 case protoreflect.GroupKind: |
156 vi.typ = validationTypeGroup |
156 vi.typ = validationTypeGroup |
157 if ft.Kind() == reflect.Slice { |
157 if ft.Kind() == reflect.Slice { |
158 vi.mi = getMessageInfo(ft.Elem()) |
158 vi.mi = getMessageInfo(ft.Elem()) |
159 } |
159 } |
160 case pref.StringKind: |
160 case protoreflect.StringKind: |
161 vi.typ = validationTypeBytes |
161 vi.typ = validationTypeBytes |
162 if strs.EnforceUTF8(fd) { |
162 if strs.EnforceUTF8(fd) { |
163 vi.typ = validationTypeUTF8String |
163 vi.typ = validationTypeUTF8String |
164 } |
164 } |
165 default: |
165 default: |
173 } |
173 } |
174 } |
174 } |
175 case fd.IsMap(): |
175 case fd.IsMap(): |
176 vi.typ = validationTypeMap |
176 vi.typ = validationTypeMap |
177 switch fd.MapKey().Kind() { |
177 switch fd.MapKey().Kind() { |
178 case pref.StringKind: |
178 case protoreflect.StringKind: |
179 if strs.EnforceUTF8(fd) { |
179 if strs.EnforceUTF8(fd) { |
180 vi.keyType = validationTypeUTF8String |
180 vi.keyType = validationTypeUTF8String |
181 } |
181 } |
182 } |
182 } |
183 switch fd.MapValue().Kind() { |
183 switch fd.MapValue().Kind() { |
184 case pref.MessageKind: |
184 case protoreflect.MessageKind: |
185 vi.valType = validationTypeMessage |
185 vi.valType = validationTypeMessage |
186 if ft.Kind() == reflect.Map { |
186 if ft.Kind() == reflect.Map { |
187 vi.mi = getMessageInfo(ft.Elem()) |
187 vi.mi = getMessageInfo(ft.Elem()) |
188 } |
188 } |
189 case pref.StringKind: |
189 case protoreflect.StringKind: |
190 if strs.EnforceUTF8(fd) { |
190 if strs.EnforceUTF8(fd) { |
191 vi.valType = validationTypeUTF8String |
191 vi.valType = validationTypeUTF8String |
192 } |
192 } |
193 } |
193 } |
194 default: |
194 default: |
195 switch fd.Kind() { |
195 switch fd.Kind() { |
196 case pref.MessageKind: |
196 case protoreflect.MessageKind: |
197 vi.typ = validationTypeMessage |
197 vi.typ = validationTypeMessage |
198 if !fd.IsWeak() { |
198 if !fd.IsWeak() { |
199 vi.mi = getMessageInfo(ft) |
199 vi.mi = getMessageInfo(ft) |
200 } |
200 } |
201 case pref.GroupKind: |
201 case protoreflect.GroupKind: |
202 vi.typ = validationTypeGroup |
202 vi.typ = validationTypeGroup |
203 vi.mi = getMessageInfo(ft) |
203 vi.mi = getMessageInfo(ft) |
204 case pref.StringKind: |
204 case protoreflect.StringKind: |
205 vi.typ = validationTypeBytes |
205 vi.typ = validationTypeBytes |
206 if strs.EnforceUTF8(fd) { |
206 if strs.EnforceUTF8(fd) { |
207 vi.typ = validationTypeUTF8String |
207 vi.typ = validationTypeUTF8String |
208 } |
208 } |
209 default: |
209 default: |
312 fd := st.mi.Desc.Fields().ByNumber(num) |
312 fd := st.mi.Desc.Fields().ByNumber(num) |
313 if fd == nil || !fd.IsWeak() { |
313 if fd == nil || !fd.IsWeak() { |
314 break |
314 break |
315 } |
315 } |
316 messageName := fd.Message().FullName() |
316 messageName := fd.Message().FullName() |
317 messageType, err := preg.GlobalTypes.FindMessageByName(messageName) |
317 messageType, err := protoregistry.GlobalTypes.FindMessageByName(messageName) |
318 switch err { |
318 switch err { |
319 case nil: |
319 case nil: |
320 vi.mi, _ = messageType.(*MessageInfo) |
320 vi.mi, _ = messageType.(*MessageInfo) |
321 case preg.NotFound: |
321 case protoregistry.NotFound: |
322 vi.typ = validationTypeBytes |
322 vi.typ = validationTypeBytes |
323 default: |
323 default: |
324 return out, ValidationUnknown |
324 return out, ValidationUnknown |
325 } |
325 } |
326 } |
326 } |
333 // 2. The resolver returns preg.NotFound. |
333 // 2. The resolver returns preg.NotFound. |
334 // In this case, a type added to the resolver in the future could cause |
334 // In this case, a type added to the resolver in the future could cause |
335 // unmarshaling to begin failing. Supporting this requires some way to |
335 // unmarshaling to begin failing. Supporting this requires some way to |
336 // determine if the resolver is frozen. |
336 // determine if the resolver is frozen. |
337 xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num) |
337 xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num) |
338 if err != nil && err != preg.NotFound { |
338 if err != nil && err != protoregistry.NotFound { |
339 return out, ValidationUnknown |
339 return out, ValidationUnknown |
340 } |
340 } |
341 if err == nil { |
341 if err == nil { |
342 vi = getExtensionFieldInfo(xt).validation |
342 vi = getExtensionFieldInfo(xt).validation |
343 } |
343 } |
511 if err != nil { |
511 if err != nil { |
512 return out, ValidationInvalid |
512 return out, ValidationInvalid |
513 } |
513 } |
514 xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid) |
514 xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid) |
515 switch { |
515 switch { |
516 case err == preg.NotFound: |
516 case err == protoregistry.NotFound: |
517 b = b[n:] |
517 b = b[n:] |
518 case err != nil: |
518 case err != nil: |
519 return out, ValidationUnknown |
519 return out, ValidationUnknown |
520 default: |
520 default: |
521 xvi := getExtensionFieldInfo(xt).validation |
521 xvi := getExtensionFieldInfo(xt).validation |