256
|
1 |
// Copyright 2019 The Go Authors. All rights reserved. |
|
2 |
// Use of this source code is governed by a BSD-style |
|
3 |
// license that can be found in the LICENSE file. |
|
4 |
|
|
5 |
package impl |
|
6 |
|
|
7 |
import ( |
|
8 |
"sync" |
|
9 |
|
|
10 |
"google.golang.org/protobuf/internal/errors" |
260
|
11 |
"google.golang.org/protobuf/reflect/protoreflect" |
|
12 |
"google.golang.org/protobuf/runtime/protoiface" |
256
|
13 |
) |
|
14 |
|
260
|
15 |
func (mi *MessageInfo) checkInitialized(in protoiface.CheckInitializedInput) (protoiface.CheckInitializedOutput, error) { |
256
|
16 |
var p pointer |
|
17 |
if ms, ok := in.Message.(*messageState); ok { |
|
18 |
p = ms.pointer() |
|
19 |
} else { |
|
20 |
p = in.Message.(*messageReflectWrapper).pointer() |
|
21 |
} |
260
|
22 |
return protoiface.CheckInitializedOutput{}, mi.checkInitializedPointer(p) |
256
|
23 |
} |
|
24 |
|
|
25 |
func (mi *MessageInfo) checkInitializedPointer(p pointer) error { |
|
26 |
mi.init() |
|
27 |
if !mi.needsInitCheck { |
|
28 |
return nil |
|
29 |
} |
|
30 |
if p.IsNil() { |
|
31 |
for _, f := range mi.orderedCoderFields { |
|
32 |
if f.isRequired { |
|
33 |
return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName())) |
|
34 |
} |
|
35 |
} |
|
36 |
return nil |
|
37 |
} |
|
38 |
if mi.extensionOffset.IsValid() { |
|
39 |
e := p.Apply(mi.extensionOffset).Extensions() |
|
40 |
if err := mi.isInitExtensions(e); err != nil { |
|
41 |
return err |
|
42 |
} |
|
43 |
} |
|
44 |
for _, f := range mi.orderedCoderFields { |
|
45 |
if !f.isRequired && f.funcs.isInit == nil { |
|
46 |
continue |
|
47 |
} |
|
48 |
fptr := p.Apply(f.offset) |
|
49 |
if f.isPointer && fptr.Elem().IsNil() { |
|
50 |
if f.isRequired { |
|
51 |
return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName())) |
|
52 |
} |
|
53 |
continue |
|
54 |
} |
|
55 |
if f.funcs.isInit == nil { |
|
56 |
continue |
|
57 |
} |
|
58 |
if err := f.funcs.isInit(fptr, f); err != nil { |
|
59 |
return err |
|
60 |
} |
|
61 |
} |
|
62 |
return nil |
|
63 |
} |
|
64 |
|
|
65 |
func (mi *MessageInfo) isInitExtensions(ext *map[int32]ExtensionField) error { |
|
66 |
if ext == nil { |
|
67 |
return nil |
|
68 |
} |
|
69 |
for _, x := range *ext { |
|
70 |
ei := getExtensionFieldInfo(x.Type()) |
|
71 |
if ei.funcs.isInit == nil { |
|
72 |
continue |
|
73 |
} |
|
74 |
v := x.Value() |
|
75 |
if !v.IsValid() { |
|
76 |
continue |
|
77 |
} |
|
78 |
if err := ei.funcs.isInit(v); err != nil { |
|
79 |
return err |
|
80 |
} |
|
81 |
} |
|
82 |
return nil |
|
83 |
} |
|
84 |
|
|
85 |
var (
	// needsInitCheckMu is held by needsInitCheck around calls to
	// needsInitCheckLocked, so only one traversal of the message graph
	// runs at a time.
	needsInitCheckMu sync.Mutex

	// needsInitCheckMap caches needsInitCheck answers keyed by
	// protoreflect.MessageDescriptor. A bool value is a computed answer;
	// any other value is a placeholder for a message still being explored.
	needsInitCheckMap sync.Map
)
|
89 |
|
|
90 |
// needsInitCheck reports whether a message needs to be checked for partial initialization. |
|
91 |
// |
|
92 |
// It returns true if the message transitively includes any required or extension fields. |
260
|
93 |
func needsInitCheck(md protoreflect.MessageDescriptor) bool { |
256
|
94 |
if v, ok := needsInitCheckMap.Load(md); ok { |
|
95 |
if has, ok := v.(bool); ok { |
|
96 |
return has |
|
97 |
} |
|
98 |
} |
|
99 |
needsInitCheckMu.Lock() |
|
100 |
defer needsInitCheckMu.Unlock() |
|
101 |
return needsInitCheckLocked(md) |
|
102 |
} |
|
103 |
|
260
|
104 |
func needsInitCheckLocked(md protoreflect.MessageDescriptor) (has bool) { |
256
|
105 |
if v, ok := needsInitCheckMap.Load(md); ok { |
|
106 |
// If has is true, we've previously determined that this message |
|
107 |
// needs init checks. |
|
108 |
// |
|
109 |
// If has is false, we've previously determined that it can never |
|
110 |
// be uninitialized. |
|
111 |
// |
|
112 |
// If has is not a bool, we've just encountered a cycle in the |
|
113 |
// message graph. In this case, it is safe to return false: If |
|
114 |
// the message does have required fields, we'll detect them later |
|
115 |
// in the graph traversal. |
|
116 |
has, ok := v.(bool) |
|
117 |
return ok && has |
|
118 |
} |
|
119 |
needsInitCheckMap.Store(md, struct{}{}) // avoid cycles while descending into this message |
|
120 |
defer func() { |
|
121 |
needsInitCheckMap.Store(md, has) |
|
122 |
}() |
|
123 |
if md.RequiredNumbers().Len() > 0 { |
|
124 |
return true |
|
125 |
} |
|
126 |
if md.ExtensionRanges().Len() > 0 { |
|
127 |
return true |
|
128 |
} |
|
129 |
for i := 0; i < md.Fields().Len(); i++ { |
|
130 |
fd := md.Fields().Get(i) |
|
131 |
// Map keys are never messages, so just consider the map value. |
|
132 |
if fd.IsMap() { |
|
133 |
fd = fd.MapValue() |
|
134 |
} |
|
135 |
fmd := fd.Message() |
|
136 |
if fmd != nil && needsInitCheckLocked(fmd) { |
|
137 |
return true |
|
138 |
} |
|
139 |
} |
|
140 |
return false |
|
141 |
} |