vendor/google.golang.org/protobuf/proto/decode.go
changeset 256 6d9efbef00a9
child 260 445e01aede7e
equal deleted inserted replaced
255:4f153a23adab 256:6d9efbef00a9
       
     1 // Copyright 2018 The Go Authors. All rights reserved.
       
     2 // Use of this source code is governed by a BSD-style
       
     3 // license that can be found in the LICENSE file.
       
     4 
       
     5 package proto
       
     6 
       
     7 import (
       
     8 	"google.golang.org/protobuf/encoding/protowire"
       
     9 	"google.golang.org/protobuf/internal/encoding/messageset"
       
    10 	"google.golang.org/protobuf/internal/errors"
       
    11 	"google.golang.org/protobuf/internal/flags"
       
    12 	"google.golang.org/protobuf/internal/genid"
       
    13 	"google.golang.org/protobuf/internal/pragma"
       
    14 	"google.golang.org/protobuf/reflect/protoreflect"
       
    15 	"google.golang.org/protobuf/reflect/protoregistry"
       
    16 	"google.golang.org/protobuf/runtime/protoiface"
       
    17 )
       
    18 
       
// UnmarshalOptions configures the unmarshaler.
//
// The zero value resets the destination, rejects missing required fields,
// keeps unknown fields, and resolves extensions via protoregistry.GlobalTypes.
//
// Example usage:
//
//	err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
type UnmarshalOptions struct {
	// NOTE(review): presumably embedded to forbid unkeyed struct literals so
	// that fields can be added compatibly — confirm against internal/pragma.
	pragma.NoUnkeyedLiterals

	// Merge merges the input into the destination message.
	// The default behavior is to always reset the message before unmarshaling,
	// unless Merge is specified.
	Merge bool

	// AllowPartial accepts input for messages that will result in missing
	// required fields. If AllowPartial is false (the default), Unmarshal will
	// return an error if there are any missing required fields.
	AllowPartial bool

	// If DiscardUnknown is set, unknown fields are dropped instead of being
	// appended to the destination message's unknown-field bytes.
	DiscardUnknown bool

	// Resolver is used for looking up types when unmarshaling extension fields.
	// If nil, this defaults to using protoregistry.GlobalTypes.
	// The slow-path decoder looks extensions up by number (FindExtensionByNumber).
	Resolver interface {
		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
	}
}
       
    46 
       
    47 // Unmarshal parses the wire-format message in b and places the result in m.
       
    48 // The provided message must be mutable (e.g., a non-nil pointer to a message).
       
    49 func Unmarshal(b []byte, m Message) error {
       
    50 	_, err := UnmarshalOptions{}.unmarshal(b, m.ProtoReflect())
       
    51 	return err
       
    52 }
       
    53 
       
    54 // Unmarshal parses the wire-format message in b and places the result in m.
       
    55 // The provided message must be mutable (e.g., a non-nil pointer to a message).
       
    56 func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
       
    57 	_, err := o.unmarshal(b, m.ProtoReflect())
       
    58 	return err
       
    59 }
       
    60 
       
    61 // UnmarshalState parses a wire-format message and places the result in m.
       
    62 //
       
    63 // This method permits fine-grained control over the unmarshaler.
       
    64 // Most users should use Unmarshal instead.
       
    65 func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
       
    66 	return o.unmarshal(in.Buf, in.Message)
       
    67 }
       
    68 
       
    69 // unmarshal is a centralized function that all unmarshal operations go through.
       
    70 // For profiling purposes, avoid changing the name of this function or
       
    71 // introducing other code paths for unmarshal that do not go through this.
       
    72 func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) {
       
    73 	if o.Resolver == nil {
       
    74 		o.Resolver = protoregistry.GlobalTypes
       
    75 	}
       
    76 	if !o.Merge {
       
    77 		Reset(m.Interface())
       
    78 	}
       
    79 	allowPartial := o.AllowPartial
       
    80 	o.Merge = true
       
    81 	o.AllowPartial = true
       
    82 	methods := protoMethods(m)
       
    83 	if methods != nil && methods.Unmarshal != nil &&
       
    84 		!(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
       
    85 		in := protoiface.UnmarshalInput{
       
    86 			Message:  m,
       
    87 			Buf:      b,
       
    88 			Resolver: o.Resolver,
       
    89 		}
       
    90 		if o.DiscardUnknown {
       
    91 			in.Flags |= protoiface.UnmarshalDiscardUnknown
       
    92 		}
       
    93 		out, err = methods.Unmarshal(in)
       
    94 	} else {
       
    95 		err = o.unmarshalMessageSlow(b, m)
       
    96 	}
       
    97 	if err != nil {
       
    98 		return out, err
       
    99 	}
       
   100 	if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) {
       
   101 		return out, nil
       
   102 	}
       
   103 	return out, checkInitialized(m)
       
   104 }
       
   105 
       
   106 func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
       
   107 	_, err := o.unmarshal(b, m)
       
   108 	return err
       
   109 }
       
   110 
       
   111 func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
       
   112 	md := m.Descriptor()
       
   113 	if messageset.IsMessageSet(md) {
       
   114 		return o.unmarshalMessageSet(b, m)
       
   115 	}
       
   116 	fields := md.Fields()
       
   117 	for len(b) > 0 {
       
   118 		// Parse the tag (field number and wire type).
       
   119 		num, wtyp, tagLen := protowire.ConsumeTag(b)
       
   120 		if tagLen < 0 {
       
   121 			return errDecode
       
   122 		}
       
   123 		if num > protowire.MaxValidNumber {
       
   124 			return errDecode
       
   125 		}
       
   126 
       
   127 		// Find the field descriptor for this field number.
       
   128 		fd := fields.ByNumber(num)
       
   129 		if fd == nil && md.ExtensionRanges().Has(num) {
       
   130 			extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
       
   131 			if err != nil && err != protoregistry.NotFound {
       
   132 				return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, err)
       
   133 			}
       
   134 			if extType != nil {
       
   135 				fd = extType.TypeDescriptor()
       
   136 			}
       
   137 		}
       
   138 		var err error
       
   139 		if fd == nil {
       
   140 			err = errUnknown
       
   141 		} else if flags.ProtoLegacy {
       
   142 			if fd.IsWeak() && fd.Message().IsPlaceholder() {
       
   143 				err = errUnknown // weak referent is not linked in
       
   144 			}
       
   145 		}
       
   146 
       
   147 		// Parse the field value.
       
   148 		var valLen int
       
   149 		switch {
       
   150 		case err != nil:
       
   151 		case fd.IsList():
       
   152 			valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
       
   153 		case fd.IsMap():
       
   154 			valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd)
       
   155 		default:
       
   156 			valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
       
   157 		}
       
   158 		if err != nil {
       
   159 			if err != errUnknown {
       
   160 				return err
       
   161 			}
       
   162 			valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:])
       
   163 			if valLen < 0 {
       
   164 				return errDecode
       
   165 			}
       
   166 			if !o.DiscardUnknown {
       
   167 				m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
       
   168 			}
       
   169 		}
       
   170 		b = b[tagLen+valLen:]
       
   171 	}
       
   172 	return nil
       
   173 }
       
   174 
       
   175 func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) {
       
   176 	v, n, err := o.unmarshalScalar(b, wtyp, fd)
       
   177 	if err != nil {
       
   178 		return 0, err
       
   179 	}
       
   180 	switch fd.Kind() {
       
   181 	case protoreflect.GroupKind, protoreflect.MessageKind:
       
   182 		m2 := m.Mutable(fd).Message()
       
   183 		if err := o.unmarshalMessage(v.Bytes(), m2); err != nil {
       
   184 			return n, err
       
   185 		}
       
   186 	default:
       
   187 		// Non-message scalars replace the previous value.
       
   188 		m.Set(fd, v)
       
   189 	}
       
   190 	return n, nil
       
   191 }
       
   192 
       
// unmarshalMap decodes one length-delimited map entry from b and stores the
// resulting key/value pair in mapv.
//
// On success it returns the number of bytes of b occupied by the entry, as
// reported by protowire.ConsumeBytes. A wire type other than BytesType
// yields errUnknown so the caller can treat the record as an unknown field.
// A malformed length prefix or entry tag yields errDecode. A missing key or
// scalar value falls back to the field's default.
func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) {
	if wtyp != protowire.BytesType {
		return 0, errUnknown
	}
	// From here on, b is the entry payload and n is the entry's total length.
	// This outer n is the success return value. The loop below declares its
	// own n, so this value is never overwritten.
	b, n = protowire.ConsumeBytes(b)
	if n < 0 {
		return 0, errDecode
	}
	var (
		keyField = fd.MapKey()
		valField = fd.MapValue()
		key      protoreflect.Value
		val      protoreflect.Value
		haveKey  bool
		haveVal  bool
	)
	// For message-valued maps, allocate the value up front. Every occurrence of
	// the value field in this entry is then decoded into this same message.
	switch valField.Kind() {
	case protoreflect.GroupKind, protoreflect.MessageKind:
		val = mapv.NewValue()
	}
	// Map entries are represented as a two-element message with fields
	// containing the key and value.
	for len(b) > 0 {
		// num, wtyp and n are new variables scoped to this iteration. wtyp and
		// n shadow the parameter and the named result. Here n is the size of
		// the current tag, then of the current value.
		num, wtyp, n := protowire.ConsumeTag(b)
		if n < 0 {
			return 0, errDecode
		}
		if num > protowire.MaxValidNumber {
			return 0, errDecode
		}
		b = b[n:]
		// Assume the record is unknown. The key/value cases overwrite err. Any
		// other field number leaves errUnknown, and the value is skipped below.
		err = errUnknown
		switch num {
		case genid.MapEntry_Key_field_number:
			key, n, err = o.unmarshalScalar(b, wtyp, keyField)
			if err != nil {
				break // leaves the switch; err is handled after it
			}
			haveKey = true
		case genid.MapEntry_Value_field_number:
			var v protoreflect.Value
			v, n, err = o.unmarshalScalar(b, wtyp, valField)
			if err != nil {
				break // leaves the switch; err is handled after it
			}
			switch valField.Kind() {
			case protoreflect.GroupKind, protoreflect.MessageKind:
				if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
					return 0, err
				}
			default:
				val = v
			}
			haveVal = true
		}
		// NOTE(review): errUnknown can also come from unmarshalScalar (defined
		// elsewhere), presumably on a wire-type mismatch. Such a value is then
		// skipped like an unrecognized field number.
		if err == errUnknown {
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return 0, errDecode
			}
		} else if err != nil {
			return 0, err
		}
		b = b[n:]
	}
	// Every map entry should have entries for key and value, but this is not strictly required.
	if !haveKey {
		key = keyField.Default()
	}
	if !haveVal {
		switch valField.Kind() {
		case protoreflect.GroupKind, protoreflect.MessageKind:
			// val already holds the empty message allocated above.
		default:
			val = valField.Default()
		}
	}
	mapv.Set(key.MapKey(), val)
	return n, nil
}
       
   272 
       
   273 // errUnknown is used internally to indicate fields which should be added
       
   274 // to the unknown field set of a message. It is never returned from an exported
       
   275 // function.
       
   276 var errUnknown = errors.New("BUG: internal error (unknown)")
       
   277 
       
   278 var errDecode = errors.New("cannot parse invalid wire-format data")