vendor/google.golang.org/protobuf/internal/impl/decode.go
changeset 256 6d9efbef00a9
child 260 445e01aede7e
equal deleted inserted replaced
255:4f153a23adab 256:6d9efbef00a9
       
     1 // Copyright 2019 The Go Authors. All rights reserved.
       
     2 // Use of this source code is governed by a BSD-style
       
     3 // license that can be found in the LICENSE file.
       
     4 
       
     5 package impl
       
     6 
       
     7 import (
       
     8 	"math/bits"
       
     9 
       
    10 	"google.golang.org/protobuf/encoding/protowire"
       
    11 	"google.golang.org/protobuf/internal/errors"
       
    12 	"google.golang.org/protobuf/internal/flags"
       
    13 	"google.golang.org/protobuf/proto"
       
    14 	"google.golang.org/protobuf/reflect/protoreflect"
       
    15 	preg "google.golang.org/protobuf/reflect/protoregistry"
       
    16 	"google.golang.org/protobuf/runtime/protoiface"
       
    17 	piface "google.golang.org/protobuf/runtime/protoiface"
       
    18 )
       
    19 
       
    20 var errDecode = errors.New("cannot parse invalid wire-format data")
       
    21 
       
    22 type unmarshalOptions struct {
       
    23 	flags    protoiface.UnmarshalInputFlags
       
    24 	resolver interface {
       
    25 		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
       
    26 		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
       
    27 	}
       
    28 }
       
    29 
       
    30 func (o unmarshalOptions) Options() proto.UnmarshalOptions {
       
    31 	return proto.UnmarshalOptions{
       
    32 		Merge:          true,
       
    33 		AllowPartial:   true,
       
    34 		DiscardUnknown: o.DiscardUnknown(),
       
    35 		Resolver:       o.resolver,
       
    36 	}
       
    37 }
       
    38 
       
    39 func (o unmarshalOptions) DiscardUnknown() bool { return o.flags&piface.UnmarshalDiscardUnknown != 0 }
       
    40 
       
    41 func (o unmarshalOptions) IsDefault() bool {
       
    42 	return o.flags == 0 && o.resolver == preg.GlobalTypes
       
    43 }
       
    44 
       
    45 var lazyUnmarshalOptions = unmarshalOptions{
       
    46 	resolver: preg.GlobalTypes,
       
    47 }
       
    48 
       
    49 type unmarshalOutput struct {
       
    50 	n           int // number of bytes consumed
       
    51 	initialized bool
       
    52 }
       
    53 
       
    54 // unmarshal is protoreflect.Methods.Unmarshal.
       
    55 func (mi *MessageInfo) unmarshal(in piface.UnmarshalInput) (piface.UnmarshalOutput, error) {
       
    56 	var p pointer
       
    57 	if ms, ok := in.Message.(*messageState); ok {
       
    58 		p = ms.pointer()
       
    59 	} else {
       
    60 		p = in.Message.(*messageReflectWrapper).pointer()
       
    61 	}
       
    62 	out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
       
    63 		flags:    in.Flags,
       
    64 		resolver: in.Resolver,
       
    65 	})
       
    66 	var flags piface.UnmarshalOutputFlags
       
    67 	if out.initialized {
       
    68 		flags |= piface.UnmarshalInitialized
       
    69 	}
       
    70 	return piface.UnmarshalOutput{
       
    71 		Flags: flags,
       
    72 	}, err
       
    73 }
       
    74 
       
    75 // errUnknown is returned during unmarshaling to indicate a parse error that
       
    76 // should result in a field being placed in the unknown fields section (for example,
       
    77 // when the wire type doesn't match) as opposed to the entire unmarshal operation
       
    78 // failing (for example, when a field extends past the available input).
       
    79 //
       
    80 // This is a sentinel error which should never be visible to the user.
       
    81 var errUnknown = errors.New("unknown")
       
    82 
       
    83 func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
       
    84 	mi.init()
       
    85 	if flags.ProtoLegacy && mi.isMessageSet {
       
    86 		return unmarshalMessageSet(mi, b, p, opts)
       
    87 	}
       
    88 	initialized := true
       
    89 	var requiredMask uint64
       
    90 	var exts *map[int32]ExtensionField
       
    91 	start := len(b)
       
    92 	for len(b) > 0 {
       
    93 		// Parse the tag (field number and wire type).
       
    94 		var tag uint64
       
    95 		if b[0] < 0x80 {
       
    96 			tag = uint64(b[0])
       
    97 			b = b[1:]
       
    98 		} else if len(b) >= 2 && b[1] < 128 {
       
    99 			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
       
   100 			b = b[2:]
       
   101 		} else {
       
   102 			var n int
       
   103 			tag, n = protowire.ConsumeVarint(b)
       
   104 			if n < 0 {
       
   105 				return out, errDecode
       
   106 			}
       
   107 			b = b[n:]
       
   108 		}
       
   109 		var num protowire.Number
       
   110 		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
       
   111 			return out, errDecode
       
   112 		} else {
       
   113 			num = protowire.Number(n)
       
   114 		}
       
   115 		wtyp := protowire.Type(tag & 7)
       
   116 
       
   117 		if wtyp == protowire.EndGroupType {
       
   118 			if num != groupTag {
       
   119 				return out, errDecode
       
   120 			}
       
   121 			groupTag = 0
       
   122 			break
       
   123 		}
       
   124 
       
   125 		var f *coderFieldInfo
       
   126 		if int(num) < len(mi.denseCoderFields) {
       
   127 			f = mi.denseCoderFields[num]
       
   128 		} else {
       
   129 			f = mi.coderFields[num]
       
   130 		}
       
   131 		var n int
       
   132 		err := errUnknown
       
   133 		switch {
       
   134 		case f != nil:
       
   135 			if f.funcs.unmarshal == nil {
       
   136 				break
       
   137 			}
       
   138 			var o unmarshalOutput
       
   139 			o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
       
   140 			n = o.n
       
   141 			if err != nil {
       
   142 				break
       
   143 			}
       
   144 			requiredMask |= f.validation.requiredBit
       
   145 			if f.funcs.isInit != nil && !o.initialized {
       
   146 				initialized = false
       
   147 			}
       
   148 		default:
       
   149 			// Possible extension.
       
   150 			if exts == nil && mi.extensionOffset.IsValid() {
       
   151 				exts = p.Apply(mi.extensionOffset).Extensions()
       
   152 				if *exts == nil {
       
   153 					*exts = make(map[int32]ExtensionField)
       
   154 				}
       
   155 			}
       
   156 			if exts == nil {
       
   157 				break
       
   158 			}
       
   159 			var o unmarshalOutput
       
   160 			o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
       
   161 			if err != nil {
       
   162 				break
       
   163 			}
       
   164 			n = o.n
       
   165 			if !o.initialized {
       
   166 				initialized = false
       
   167 			}
       
   168 		}
       
   169 		if err != nil {
       
   170 			if err != errUnknown {
       
   171 				return out, err
       
   172 			}
       
   173 			n = protowire.ConsumeFieldValue(num, wtyp, b)
       
   174 			if n < 0 {
       
   175 				return out, errDecode
       
   176 			}
       
   177 			if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
       
   178 				u := mi.mutableUnknownBytes(p)
       
   179 				*u = protowire.AppendTag(*u, num, wtyp)
       
   180 				*u = append(*u, b[:n]...)
       
   181 			}
       
   182 		}
       
   183 		b = b[n:]
       
   184 	}
       
   185 	if groupTag != 0 {
       
   186 		return out, errDecode
       
   187 	}
       
   188 	if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
       
   189 		initialized = false
       
   190 	}
       
   191 	if initialized {
       
   192 		out.initialized = true
       
   193 	}
       
   194 	out.n = start - len(b)
       
   195 	return out, nil
       
   196 }
       
   197 
       
   198 func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) {
       
   199 	x := exts[int32(num)]
       
   200 	xt := x.Type()
       
   201 	if xt == nil {
       
   202 		var err error
       
   203 		xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
       
   204 		if err != nil {
       
   205 			if err == preg.NotFound {
       
   206 				return out, errUnknown
       
   207 			}
       
   208 			return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err)
       
   209 		}
       
   210 	}
       
   211 	xi := getExtensionFieldInfo(xt)
       
   212 	if xi.funcs.unmarshal == nil {
       
   213 		return out, errUnknown
       
   214 	}
       
   215 	if flags.LazyUnmarshalExtensions {
       
   216 		if opts.IsDefault() && x.canLazy(xt) {
       
   217 			out, valid := skipExtension(b, xi, num, wtyp, opts)
       
   218 			switch valid {
       
   219 			case ValidationValid:
       
   220 				if out.initialized {
       
   221 					x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
       
   222 					exts[int32(num)] = x
       
   223 					return out, nil
       
   224 				}
       
   225 			case ValidationInvalid:
       
   226 				return out, errDecode
       
   227 			case ValidationUnknown:
       
   228 			}
       
   229 		}
       
   230 	}
       
   231 	ival := x.Value()
       
   232 	if !ival.IsValid() && xi.unmarshalNeedsValue {
       
   233 		// Create a new message, list, or map value to fill in.
       
   234 		// For enums, create a prototype value to let the unmarshal func know the
       
   235 		// concrete type.
       
   236 		ival = xt.New()
       
   237 	}
       
   238 	v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
       
   239 	if err != nil {
       
   240 		return out, err
       
   241 	}
       
   242 	if xi.funcs.isInit == nil {
       
   243 		out.initialized = true
       
   244 	}
       
   245 	x.Set(xt, v)
       
   246 	exts[int32(num)] = x
       
   247 	return out, nil
       
   248 }
       
   249 
       
   250 func skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
       
   251 	if xi.validation.mi == nil {
       
   252 		return out, ValidationUnknown
       
   253 	}
       
   254 	xi.validation.mi.init()
       
   255 	switch xi.validation.typ {
       
   256 	case validationTypeMessage:
       
   257 		if wtyp != protowire.BytesType {
       
   258 			return out, ValidationUnknown
       
   259 		}
       
   260 		v, n := protowire.ConsumeBytes(b)
       
   261 		if n < 0 {
       
   262 			return out, ValidationUnknown
       
   263 		}
       
   264 		out, st := xi.validation.mi.validate(v, 0, opts)
       
   265 		out.n = n
       
   266 		return out, st
       
   267 	case validationTypeGroup:
       
   268 		if wtyp != protowire.StartGroupType {
       
   269 			return out, ValidationUnknown
       
   270 		}
       
   271 		out, st := xi.validation.mi.validate(b, num, opts)
       
   272 		return out, st
       
   273 	default:
       
   274 		return out, ValidationUnknown
       
   275 	}
       
   276 }