vendor/google.golang.org/protobuf/internal/impl/checkinit.go
changeset 256 6d9efbef00a9
child 260 445e01aede7e
equal deleted inserted replaced
255:4f153a23adab 256:6d9efbef00a9
       
     1 // Copyright 2019 The Go Authors. All rights reserved.
       
     2 // Use of this source code is governed by a BSD-style
       
     3 // license that can be found in the LICENSE file.
       
     4 
       
     5 package impl
       
     6 
       
     7 import (
       
     8 	"sync"
       
     9 
       
    10 	"google.golang.org/protobuf/internal/errors"
       
    11 	pref "google.golang.org/protobuf/reflect/protoreflect"
       
    12 	piface "google.golang.org/protobuf/runtime/protoiface"
       
    13 )
       
    14 
       
    15 func (mi *MessageInfo) checkInitialized(in piface.CheckInitializedInput) (piface.CheckInitializedOutput, error) {
       
    16 	var p pointer
       
    17 	if ms, ok := in.Message.(*messageState); ok {
       
    18 		p = ms.pointer()
       
    19 	} else {
       
    20 		p = in.Message.(*messageReflectWrapper).pointer()
       
    21 	}
       
    22 	return piface.CheckInitializedOutput{}, mi.checkInitializedPointer(p)
       
    23 }
       
    24 
       
    25 func (mi *MessageInfo) checkInitializedPointer(p pointer) error {
       
    26 	mi.init()
       
    27 	if !mi.needsInitCheck {
       
    28 		return nil
       
    29 	}
       
    30 	if p.IsNil() {
       
    31 		for _, f := range mi.orderedCoderFields {
       
    32 			if f.isRequired {
       
    33 				return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName()))
       
    34 			}
       
    35 		}
       
    36 		return nil
       
    37 	}
       
    38 	if mi.extensionOffset.IsValid() {
       
    39 		e := p.Apply(mi.extensionOffset).Extensions()
       
    40 		if err := mi.isInitExtensions(e); err != nil {
       
    41 			return err
       
    42 		}
       
    43 	}
       
    44 	for _, f := range mi.orderedCoderFields {
       
    45 		if !f.isRequired && f.funcs.isInit == nil {
       
    46 			continue
       
    47 		}
       
    48 		fptr := p.Apply(f.offset)
       
    49 		if f.isPointer && fptr.Elem().IsNil() {
       
    50 			if f.isRequired {
       
    51 				return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName()))
       
    52 			}
       
    53 			continue
       
    54 		}
       
    55 		if f.funcs.isInit == nil {
       
    56 			continue
       
    57 		}
       
    58 		if err := f.funcs.isInit(fptr, f); err != nil {
       
    59 			return err
       
    60 		}
       
    61 	}
       
    62 	return nil
       
    63 }
       
    64 
       
    65 func (mi *MessageInfo) isInitExtensions(ext *map[int32]ExtensionField) error {
       
    66 	if ext == nil {
       
    67 		return nil
       
    68 	}
       
    69 	for _, x := range *ext {
       
    70 		ei := getExtensionFieldInfo(x.Type())
       
    71 		if ei.funcs.isInit == nil {
       
    72 			continue
       
    73 		}
       
    74 		v := x.Value()
       
    75 		if !v.IsValid() {
       
    76 			continue
       
    77 		}
       
    78 		if err := ei.funcs.isInit(v); err != nil {
       
    79 			return err
       
    80 		}
       
    81 	}
       
    82 	return nil
       
    83 }
       
    84 
       
    85 var (
       
    86 	needsInitCheckMu  sync.Mutex
       
    87 	needsInitCheckMap sync.Map
       
    88 )
       
    89 
       
    90 // needsInitCheck reports whether a message needs to be checked for partial initialization.
       
    91 //
       
    92 // It returns true if the message transitively includes any required or extension fields.
       
    93 func needsInitCheck(md pref.MessageDescriptor) bool {
       
    94 	if v, ok := needsInitCheckMap.Load(md); ok {
       
    95 		if has, ok := v.(bool); ok {
       
    96 			return has
       
    97 		}
       
    98 	}
       
    99 	needsInitCheckMu.Lock()
       
   100 	defer needsInitCheckMu.Unlock()
       
   101 	return needsInitCheckLocked(md)
       
   102 }
       
   103 
       
   104 func needsInitCheckLocked(md pref.MessageDescriptor) (has bool) {
       
   105 	if v, ok := needsInitCheckMap.Load(md); ok {
       
   106 		// If has is true, we've previously determined that this message
       
   107 		// needs init checks.
       
   108 		//
       
   109 		// If has is false, we've previously determined that it can never
       
   110 		// be uninitialized.
       
   111 		//
       
   112 		// If has is not a bool, we've just encountered a cycle in the
       
   113 		// message graph. In this case, it is safe to return false: If
       
   114 		// the message does have required fields, we'll detect them later
       
   115 		// in the graph traversal.
       
   116 		has, ok := v.(bool)
       
   117 		return ok && has
       
   118 	}
       
   119 	needsInitCheckMap.Store(md, struct{}{}) // avoid cycles while descending into this message
       
   120 	defer func() {
       
   121 		needsInitCheckMap.Store(md, has)
       
   122 	}()
       
   123 	if md.RequiredNumbers().Len() > 0 {
       
   124 		return true
       
   125 	}
       
   126 	if md.ExtensionRanges().Len() > 0 {
       
   127 		return true
       
   128 	}
       
   129 	for i := 0; i < md.Fields().Len(); i++ {
       
   130 		fd := md.Fields().Get(i)
       
   131 		// Map keys are never messages, so just consider the map value.
       
   132 		if fd.IsMap() {
       
   133 			fd = fd.MapValue()
       
   134 		}
       
   135 		fmd := fd.Message()
       
   136 		if fmd != nil && needsInitCheckLocked(fmd) {
       
   137 			return true
       
   138 		}
       
   139 	}
       
   140 	return false
       
   141 }