// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
|
5 |
package proto |
|
6 |
|
|
7 |
import ( |
|
8 |
"bytes" |
|
9 |
"math" |
|
10 |
"reflect" |
|
11 |
|
|
12 |
"google.golang.org/protobuf/encoding/protowire" |
260
|
13 |
"google.golang.org/protobuf/reflect/protoreflect" |
256
|
14 |
) |
|
15 |
|
|
16 |
// Equal reports whether two messages are equal.
// If two messages marshal to the same bytes under deterministic serialization,
// then Equal is guaranteed to report true.
//
// Two messages are equal if they belong to the same message descriptor,
// have the same set of populated known and extension field values,
// and the same set of unknown field values. If either of the top-level
// messages is invalid, then Equal reports true only if both are invalid.
//
// Scalar values are compared with the equivalent of the == operator in Go,
// except bytes values which are compared using bytes.Equal and
// floating point values which specially treat NaNs as equal.
// Message values are compared by recursively calling Equal.
// Lists are equal if each element value is also equal.
// Maps are equal if they have the same set of keys, where the pair of values
// for each key is also equal.
func Equal(x, y Message) bool {
	switch {
	case x == nil || y == nil:
		// A nil interface only equals another nil interface.
		return x == nil && y == nil
	case reflect.TypeOf(x).Kind() == reflect.Ptr && x == y:
		// Identical pointers refer to the same message, so the expensive
		// field-by-field walk can be skipped. The Kind check also ensures the
		// interface comparison x == y is only evaluated for pointer types.
		return true
	}

	mx, my := x.ProtoReflect(), y.ProtoReflect()
	if mx.IsValid() != my.IsValid() {
		// One valid and one invalid message are never equal.
		return false
	}
	return equalMessage(mx, my)
}
|
47 |
|
|
48 |
// equalMessage compares two messages. |
260
|
49 |
func equalMessage(mx, my protoreflect.Message) bool { |
256
|
50 |
if mx.Descriptor() != my.Descriptor() { |
|
51 |
return false |
|
52 |
} |
|
53 |
|
|
54 |
nx := 0 |
|
55 |
equal := true |
260
|
56 |
mx.Range(func(fd protoreflect.FieldDescriptor, vx protoreflect.Value) bool { |
256
|
57 |
nx++ |
|
58 |
vy := my.Get(fd) |
|
59 |
equal = my.Has(fd) && equalField(fd, vx, vy) |
|
60 |
return equal |
|
61 |
}) |
|
62 |
if !equal { |
|
63 |
return false |
|
64 |
} |
|
65 |
ny := 0 |
260
|
66 |
my.Range(func(fd protoreflect.FieldDescriptor, vx protoreflect.Value) bool { |
256
|
67 |
ny++ |
|
68 |
return true |
|
69 |
}) |
|
70 |
if nx != ny { |
|
71 |
return false |
|
72 |
} |
|
73 |
|
|
74 |
return equalUnknown(mx.GetUnknown(), my.GetUnknown()) |
|
75 |
} |
|
76 |
|
|
77 |
// equalField compares two fields. |
260
|
78 |
func equalField(fd protoreflect.FieldDescriptor, x, y protoreflect.Value) bool { |
256
|
79 |
switch { |
|
80 |
case fd.IsList(): |
|
81 |
return equalList(fd, x.List(), y.List()) |
|
82 |
case fd.IsMap(): |
|
83 |
return equalMap(fd, x.Map(), y.Map()) |
|
84 |
default: |
|
85 |
return equalValue(fd, x, y) |
|
86 |
} |
|
87 |
} |
|
88 |
|
|
89 |
// equalMap compares two maps. |
260
|
90 |
func equalMap(fd protoreflect.FieldDescriptor, x, y protoreflect.Map) bool { |
256
|
91 |
if x.Len() != y.Len() { |
|
92 |
return false |
|
93 |
} |
|
94 |
equal := true |
260
|
95 |
x.Range(func(k protoreflect.MapKey, vx protoreflect.Value) bool { |
256
|
96 |
vy := y.Get(k) |
|
97 |
equal = y.Has(k) && equalValue(fd.MapValue(), vx, vy) |
|
98 |
return equal |
|
99 |
}) |
|
100 |
return equal |
|
101 |
} |
|
102 |
|
|
103 |
// equalList compares two lists. |
260
|
104 |
func equalList(fd protoreflect.FieldDescriptor, x, y protoreflect.List) bool { |
256
|
105 |
if x.Len() != y.Len() { |
|
106 |
return false |
|
107 |
} |
|
108 |
for i := x.Len() - 1; i >= 0; i-- { |
|
109 |
if !equalValue(fd, x.Get(i), y.Get(i)) { |
|
110 |
return false |
|
111 |
} |
|
112 |
} |
|
113 |
return true |
|
114 |
} |
|
115 |
|
|
116 |
// equalValue compares two singular values x and y of field fd based on the
// field's kind. Scalars use ==, except that two NaN floats are treated as
// equal and bytes are compared with bytes.Equal. Message and group values are
// compared recursively with equalMessage.
func equalValue(fd protoreflect.FieldDescriptor, x, y protoreflect.Value) bool {
	switch fd.Kind() {
	case protoreflect.BoolKind:
		return x.Bool() == y.Bool()

	case protoreflect.EnumKind:
		return x.Enum() == y.Enum()

	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind,
		protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
		return x.Int() == y.Int()

	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind,
		protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
		return x.Uint() == y.Uint()

	case protoreflect.FloatKind, protoreflect.DoubleKind:
		fx, fy := x.Float(), y.Float()
		xIsNaN, yIsNaN := math.IsNaN(fx), math.IsNaN(fy)
		if xIsNaN || yIsNaN {
			// NaN never equals anything under ==, so NaNs are special-cased:
			// the values match only when both sides are NaN.
			return xIsNaN && yIsNaN
		}
		return fx == fy

	case protoreflect.StringKind:
		return x.String() == y.String()

	case protoreflect.BytesKind:
		return bytes.Equal(x.Bytes(), y.Bytes())

	case protoreflect.MessageKind, protoreflect.GroupKind:
		return equalMessage(x.Message(), y.Message())

	default:
		return x.Interface() == y.Interface()
	}
}
|
147 |
|
|
148 |
// equalUnknown compares unknown fields by direct comparison on the raw bytes
// of each individual field number. Fields sharing a field number must appear
// in the same relative order in x and y. Fields with different numbers may be
// interleaved differently.
//
// If x and y are not byte-identical and either one cannot be parsed as a
// sequence of wire-format fields, equalUnknown reports false instead of
// panicking.
func equalUnknown(x, y protoreflect.RawFields) bool {
	if len(x) != len(y) {
		return false
	}
	if bytes.Equal([]byte(x), []byte(y)) {
		return true
	}

	mx, ok := groupUnknownByNumber(x)
	if !ok {
		return false
	}
	my, ok := groupUnknownByNumber(y)
	if !ok {
		return false
	}

	if len(mx) != len(my) {
		return false
	}
	for num, bx := range mx {
		by, found := my[num]
		if !found || !bytes.Equal([]byte(bx), []byte(by)) {
			return false
		}
	}
	return true
}

// groupUnknownByNumber splits b into wire-format fields and concatenates the
// raw bytes of all fields sharing a field number, preserving their order.
// It reports false if b contains a malformed field.
func groupUnknownByNumber(b protoreflect.RawFields) (map[protoreflect.FieldNumber]protoreflect.RawFields, bool) {
	m := make(map[protoreflect.FieldNumber]protoreflect.RawFields)
	for len(b) > 0 {
		num, _, n := protowire.ConsumeField(b)
		if n < 0 {
			// ConsumeField signals a parse error with a negative length.
			// Using it as a slice bound below would panic.
			return nil, false
		}
		m[num] = append(m[num], b[:n]...)
		b = b[n:]
	}
	return m, true
}