vendor/google.golang.org/protobuf/internal/impl/decode.go
changeset 260 445e01aede7e
parent 256 6d9efbef00a9
--- a/vendor/google.golang.org/protobuf/internal/impl/decode.go	Tue Aug 23 22:33:28 2022 +0200
+++ b/vendor/google.golang.org/protobuf/internal/impl/decode.go	Tue Aug 23 22:39:43 2022 +0200
@@ -12,12 +12,12 @@
 	"google.golang.org/protobuf/internal/flags"
 	"google.golang.org/protobuf/proto"
 	"google.golang.org/protobuf/reflect/protoreflect"
-	preg "google.golang.org/protobuf/reflect/protoregistry"
+	"google.golang.org/protobuf/reflect/protoregistry"
 	"google.golang.org/protobuf/runtime/protoiface"
-	piface "google.golang.org/protobuf/runtime/protoiface"
 )
 
 var errDecode = errors.New("cannot parse invalid wire-format data")
+var errRecursionDepth = errors.New("exceeded maximum recursion depth")
 
 type unmarshalOptions struct {
 	flags    protoiface.UnmarshalInputFlags
@@ -25,6 +25,7 @@
 		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
 		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
 	}
+	depth int
 }
 
 func (o unmarshalOptions) Options() proto.UnmarshalOptions {
@@ -36,14 +37,17 @@
 	}
 }
 
-func (o unmarshalOptions) DiscardUnknown() bool { return o.flags&piface.UnmarshalDiscardUnknown != 0 }
+func (o unmarshalOptions) DiscardUnknown() bool {
+	return o.flags&protoiface.UnmarshalDiscardUnknown != 0
+}
 
 func (o unmarshalOptions) IsDefault() bool {
-	return o.flags == 0 && o.resolver == preg.GlobalTypes
+	return o.flags == 0 && o.resolver == protoregistry.GlobalTypes
 }
 
 var lazyUnmarshalOptions = unmarshalOptions{
-	resolver: preg.GlobalTypes,
+	resolver: protoregistry.GlobalTypes,
+	depth:    protowire.DefaultRecursionLimit,
 }
 
 type unmarshalOutput struct {
@@ -52,7 +56,7 @@
 }
 
 // unmarshal is protoreflect.Methods.Unmarshal.
-func (mi *MessageInfo) unmarshal(in piface.UnmarshalInput) (piface.UnmarshalOutput, error) {
+func (mi *MessageInfo) unmarshal(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
 	var p pointer
 	if ms, ok := in.Message.(*messageState); ok {
 		p = ms.pointer()
@@ -62,12 +66,13 @@
 	out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
 		flags:    in.Flags,
 		resolver: in.Resolver,
+		depth:    in.Depth,
 	})
-	var flags piface.UnmarshalOutputFlags
+	var flags protoiface.UnmarshalOutputFlags
 	if out.initialized {
-		flags |= piface.UnmarshalInitialized
+		flags |= protoiface.UnmarshalInitialized
 	}
-	return piface.UnmarshalOutput{
+	return protoiface.UnmarshalOutput{
 		Flags: flags,
 	}, err
 }
@@ -82,6 +87,10 @@
 
 func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
 	mi.init()
+	opts.depth--
+	if opts.depth < 0 {
+		return out, errRecursionDepth
+	}
 	if flags.ProtoLegacy && mi.isMessageSet {
 		return unmarshalMessageSet(mi, b, p, opts)
 	}
@@ -202,7 +211,7 @@
 		var err error
 		xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
 		if err != nil {
-			if err == preg.NotFound {
+			if err == protoregistry.NotFound {
 				return out, errUnknown
 			}
 			return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err)