vendor/google.golang.org/protobuf/internal/impl/validate.go
changeset 260 445e01aede7e
parent 256 6d9efbef00a9
equal deleted inserted replaced
259:db4911b0c721 260:445e01aede7e
    14 	"google.golang.org/protobuf/encoding/protowire"
    14 	"google.golang.org/protobuf/encoding/protowire"
    15 	"google.golang.org/protobuf/internal/encoding/messageset"
    15 	"google.golang.org/protobuf/internal/encoding/messageset"
    16 	"google.golang.org/protobuf/internal/flags"
    16 	"google.golang.org/protobuf/internal/flags"
    17 	"google.golang.org/protobuf/internal/genid"
    17 	"google.golang.org/protobuf/internal/genid"
    18 	"google.golang.org/protobuf/internal/strs"
    18 	"google.golang.org/protobuf/internal/strs"
    19 	pref "google.golang.org/protobuf/reflect/protoreflect"
    19 	"google.golang.org/protobuf/reflect/protoreflect"
    20 	preg "google.golang.org/protobuf/reflect/protoregistry"
    20 	"google.golang.org/protobuf/reflect/protoregistry"
    21 	piface "google.golang.org/protobuf/runtime/protoiface"
    21 	"google.golang.org/protobuf/runtime/protoiface"
    22 )
    22 )
    23 
    23 
    24 // ValidationStatus is the result of validating the wire-format encoding of a message.
    24 // ValidationStatus is the result of validating the wire-format encoding of a message.
    25 type ValidationStatus int
    25 type ValidationStatus int
    26 
    26 
    54 
    54 
    55 // Validate determines whether the contents of the buffer are a valid wire encoding
    55 // Validate determines whether the contents of the buffer are a valid wire encoding
    56 // of the message type.
    56 // of the message type.
    57 //
    57 //
    58 // This function is exposed for testing.
    58 // This function is exposed for testing.
    59 func Validate(mt pref.MessageType, in piface.UnmarshalInput) (out piface.UnmarshalOutput, _ ValidationStatus) {
    59 func Validate(mt protoreflect.MessageType, in protoiface.UnmarshalInput) (out protoiface.UnmarshalOutput, _ ValidationStatus) {
    60 	mi, ok := mt.(*MessageInfo)
    60 	mi, ok := mt.(*MessageInfo)
    61 	if !ok {
    61 	if !ok {
    62 		return out, ValidationUnknown
    62 		return out, ValidationUnknown
    63 	}
    63 	}
    64 	if in.Resolver == nil {
    64 	if in.Resolver == nil {
    65 		in.Resolver = preg.GlobalTypes
    65 		in.Resolver = protoregistry.GlobalTypes
    66 	}
    66 	}
    67 	o, st := mi.validate(in.Buf, 0, unmarshalOptions{
    67 	o, st := mi.validate(in.Buf, 0, unmarshalOptions{
    68 		flags:    in.Flags,
    68 		flags:    in.Flags,
    69 		resolver: in.Resolver,
    69 		resolver: in.Resolver,
    70 	})
    70 	})
    71 	if o.initialized {
    71 	if o.initialized {
    72 		out.Flags |= piface.UnmarshalInitialized
    72 		out.Flags |= protoiface.UnmarshalInitialized
    73 	}
    73 	}
    74 	return out, st
    74 	return out, st
    75 }
    75 }
    76 
    76 
    77 type validationInfo struct {
    77 type validationInfo struct {
   104 	validationTypeBytes
   104 	validationTypeBytes
   105 	validationTypeUTF8String
   105 	validationTypeUTF8String
   106 	validationTypeMessageSetItem
   106 	validationTypeMessageSetItem
   107 )
   107 )
   108 
   108 
   109 func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
   109 func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
   110 	var vi validationInfo
   110 	var vi validationInfo
   111 	switch {
   111 	switch {
   112 	case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
   112 	case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
   113 		switch fd.Kind() {
   113 		switch fd.Kind() {
   114 		case pref.MessageKind:
   114 		case protoreflect.MessageKind:
   115 			vi.typ = validationTypeMessage
   115 			vi.typ = validationTypeMessage
   116 			if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
   116 			if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
   117 				vi.mi = getMessageInfo(ot.Field(0).Type)
   117 				vi.mi = getMessageInfo(ot.Field(0).Type)
   118 			}
   118 			}
   119 		case pref.GroupKind:
   119 		case protoreflect.GroupKind:
   120 			vi.typ = validationTypeGroup
   120 			vi.typ = validationTypeGroup
   121 			if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
   121 			if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
   122 				vi.mi = getMessageInfo(ot.Field(0).Type)
   122 				vi.mi = getMessageInfo(ot.Field(0).Type)
   123 			}
   123 			}
   124 		case pref.StringKind:
   124 		case protoreflect.StringKind:
   125 			if strs.EnforceUTF8(fd) {
   125 			if strs.EnforceUTF8(fd) {
   126 				vi.typ = validationTypeUTF8String
   126 				vi.typ = validationTypeUTF8String
   127 			}
   127 			}
   128 		}
   128 		}
   129 	default:
   129 	default:
   130 		vi = newValidationInfo(fd, ft)
   130 		vi = newValidationInfo(fd, ft)
   131 	}
   131 	}
   132 	if fd.Cardinality() == pref.Required {
   132 	if fd.Cardinality() == protoreflect.Required {
   133 		// Avoid overflow. The required field check is done with a 64-bit mask, with
   133 		// Avoid overflow. The required field check is done with a 64-bit mask, with
   134 		// any message containing more than 64 required fields always reported as
   134 		// any message containing more than 64 required fields always reported as
   135 		// potentially uninitialized, so it is not important to get a precise count
   135 		// potentially uninitialized, so it is not important to get a precise count
   136 		// of the required fields past 64.
   136 		// of the required fields past 64.
   137 		if mi.numRequiredFields < math.MaxUint8 {
   137 		if mi.numRequiredFields < math.MaxUint8 {
   140 		}
   140 		}
   141 	}
   141 	}
   142 	return vi
   142 	return vi
   143 }
   143 }
   144 
   144 
   145 func newValidationInfo(fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
   145 func newValidationInfo(fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
   146 	var vi validationInfo
   146 	var vi validationInfo
   147 	switch {
   147 	switch {
   148 	case fd.IsList():
   148 	case fd.IsList():
   149 		switch fd.Kind() {
   149 		switch fd.Kind() {
   150 		case pref.MessageKind:
   150 		case protoreflect.MessageKind:
   151 			vi.typ = validationTypeMessage
   151 			vi.typ = validationTypeMessage
   152 			if ft.Kind() == reflect.Slice {
   152 			if ft.Kind() == reflect.Slice {
   153 				vi.mi = getMessageInfo(ft.Elem())
   153 				vi.mi = getMessageInfo(ft.Elem())
   154 			}
   154 			}
   155 		case pref.GroupKind:
   155 		case protoreflect.GroupKind:
   156 			vi.typ = validationTypeGroup
   156 			vi.typ = validationTypeGroup
   157 			if ft.Kind() == reflect.Slice {
   157 			if ft.Kind() == reflect.Slice {
   158 				vi.mi = getMessageInfo(ft.Elem())
   158 				vi.mi = getMessageInfo(ft.Elem())
   159 			}
   159 			}
   160 		case pref.StringKind:
   160 		case protoreflect.StringKind:
   161 			vi.typ = validationTypeBytes
   161 			vi.typ = validationTypeBytes
   162 			if strs.EnforceUTF8(fd) {
   162 			if strs.EnforceUTF8(fd) {
   163 				vi.typ = validationTypeUTF8String
   163 				vi.typ = validationTypeUTF8String
   164 			}
   164 			}
   165 		default:
   165 		default:
   173 			}
   173 			}
   174 		}
   174 		}
   175 	case fd.IsMap():
   175 	case fd.IsMap():
   176 		vi.typ = validationTypeMap
   176 		vi.typ = validationTypeMap
   177 		switch fd.MapKey().Kind() {
   177 		switch fd.MapKey().Kind() {
   178 		case pref.StringKind:
   178 		case protoreflect.StringKind:
   179 			if strs.EnforceUTF8(fd) {
   179 			if strs.EnforceUTF8(fd) {
   180 				vi.keyType = validationTypeUTF8String
   180 				vi.keyType = validationTypeUTF8String
   181 			}
   181 			}
   182 		}
   182 		}
   183 		switch fd.MapValue().Kind() {
   183 		switch fd.MapValue().Kind() {
   184 		case pref.MessageKind:
   184 		case protoreflect.MessageKind:
   185 			vi.valType = validationTypeMessage
   185 			vi.valType = validationTypeMessage
   186 			if ft.Kind() == reflect.Map {
   186 			if ft.Kind() == reflect.Map {
   187 				vi.mi = getMessageInfo(ft.Elem())
   187 				vi.mi = getMessageInfo(ft.Elem())
   188 			}
   188 			}
   189 		case pref.StringKind:
   189 		case protoreflect.StringKind:
   190 			if strs.EnforceUTF8(fd) {
   190 			if strs.EnforceUTF8(fd) {
   191 				vi.valType = validationTypeUTF8String
   191 				vi.valType = validationTypeUTF8String
   192 			}
   192 			}
   193 		}
   193 		}
   194 	default:
   194 	default:
   195 		switch fd.Kind() {
   195 		switch fd.Kind() {
   196 		case pref.MessageKind:
   196 		case protoreflect.MessageKind:
   197 			vi.typ = validationTypeMessage
   197 			vi.typ = validationTypeMessage
   198 			if !fd.IsWeak() {
   198 			if !fd.IsWeak() {
   199 				vi.mi = getMessageInfo(ft)
   199 				vi.mi = getMessageInfo(ft)
   200 			}
   200 			}
   201 		case pref.GroupKind:
   201 		case protoreflect.GroupKind:
   202 			vi.typ = validationTypeGroup
   202 			vi.typ = validationTypeGroup
   203 			vi.mi = getMessageInfo(ft)
   203 			vi.mi = getMessageInfo(ft)
   204 		case pref.StringKind:
   204 		case protoreflect.StringKind:
   205 			vi.typ = validationTypeBytes
   205 			vi.typ = validationTypeBytes
   206 			if strs.EnforceUTF8(fd) {
   206 			if strs.EnforceUTF8(fd) {
   207 				vi.typ = validationTypeUTF8String
   207 				vi.typ = validationTypeUTF8String
   208 			}
   208 			}
   209 		default:
   209 		default:
   312 						fd := st.mi.Desc.Fields().ByNumber(num)
   312 						fd := st.mi.Desc.Fields().ByNumber(num)
   313 						if fd == nil || !fd.IsWeak() {
   313 						if fd == nil || !fd.IsWeak() {
   314 							break
   314 							break
   315 						}
   315 						}
   316 						messageName := fd.Message().FullName()
   316 						messageName := fd.Message().FullName()
   317 						messageType, err := preg.GlobalTypes.FindMessageByName(messageName)
   317 						messageType, err := protoregistry.GlobalTypes.FindMessageByName(messageName)
   318 						switch err {
   318 						switch err {
   319 						case nil:
   319 						case nil:
   320 							vi.mi, _ = messageType.(*MessageInfo)
   320 							vi.mi, _ = messageType.(*MessageInfo)
   321 						case preg.NotFound:
   321 						case protoregistry.NotFound:
   322 							vi.typ = validationTypeBytes
   322 							vi.typ = validationTypeBytes
   323 						default:
   323 						default:
   324 							return out, ValidationUnknown
   324 							return out, ValidationUnknown
   325 						}
   325 						}
   326 					}
   326 					}
   333 				//   2. The resolver returns preg.NotFound.
   333 				//   2. The resolver returns preg.NotFound.
   334 				// In this case, a type added to the resolver in the future could cause
   334 				// In this case, a type added to the resolver in the future could cause
   335 				// unmarshaling to begin failing. Supporting this requires some way to
   335 				// unmarshaling to begin failing. Supporting this requires some way to
   336 				// determine if the resolver is frozen.
   336 				// determine if the resolver is frozen.
   337 				xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
   337 				xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
   338 				if err != nil && err != preg.NotFound {
   338 				if err != nil && err != protoregistry.NotFound {
   339 					return out, ValidationUnknown
   339 					return out, ValidationUnknown
   340 				}
   340 				}
   341 				if err == nil {
   341 				if err == nil {
   342 					vi = getExtensionFieldInfo(xt).validation
   342 					vi = getExtensionFieldInfo(xt).validation
   343 				}
   343 				}
   511 					if err != nil {
   511 					if err != nil {
   512 						return out, ValidationInvalid
   512 						return out, ValidationInvalid
   513 					}
   513 					}
   514 					xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
   514 					xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
   515 					switch {
   515 					switch {
   516 					case err == preg.NotFound:
   516 					case err == protoregistry.NotFound:
   517 						b = b[n:]
   517 						b = b[n:]
   518 					case err != nil:
   518 					case err != nil:
   519 						return out, ValidationUnknown
   519 						return out, ValidationUnknown
   520 					default:
   520 					default:
   521 						xvi := getExtensionFieldInfo(xt).validation
   521 						xvi := getExtensionFieldInfo(xt).validation