256
|
1 |
// Copyright 2019 The Go Authors. All rights reserved. |
|
2 |
// Use of this source code is governed by a BSD-style |
|
3 |
// license that can be found in the LICENSE file. |
|
4 |
|
|
5 |
package proto |
|
6 |
|
|
7 |
import ( |
|
8 |
"bytes" |
|
9 |
"math" |
|
10 |
"reflect" |
|
11 |
|
|
12 |
"google.golang.org/protobuf/encoding/protowire" |
|
13 |
pref "google.golang.org/protobuf/reflect/protoreflect" |
|
14 |
) |
|
15 |
|
|
16 |
// Equal reports whether two messages are equal. |
|
17 |
// If two messages marshal to the same bytes under deterministic serialization, |
|
18 |
// then Equal is guaranteed to report true. |
|
19 |
// |
|
20 |
// Two messages are equal if they belong to the same message descriptor, |
|
21 |
// have the same set of populated known and extension field values, |
|
22 |
// and the same set of unknown fields values. If either of the top-level |
|
23 |
// messages are invalid, then Equal reports true only if both are invalid. |
|
24 |
// |
|
25 |
// Scalar values are compared with the equivalent of the == operator in Go, |
|
26 |
// except bytes values which are compared using bytes.Equal and |
|
27 |
// floating point values which specially treat NaNs as equal. |
|
28 |
// Message values are compared by recursively calling Equal. |
|
29 |
// Lists are equal if each element value is also equal. |
|
30 |
// Maps are equal if they have the same set of keys, where the pair of values |
|
31 |
// for each key is also equal. |
|
32 |
func Equal(x, y Message) bool { |
|
33 |
if x == nil || y == nil { |
|
34 |
return x == nil && y == nil |
|
35 |
} |
|
36 |
mx := x.ProtoReflect() |
|
37 |
my := y.ProtoReflect() |
|
38 |
if mx.IsValid() != my.IsValid() { |
|
39 |
return false |
|
40 |
} |
|
41 |
return equalMessage(mx, my) |
|
42 |
} |
|
43 |
|
|
44 |
// equalMessage compares two messages. |
|
45 |
func equalMessage(mx, my pref.Message) bool { |
|
46 |
if mx.Descriptor() != my.Descriptor() { |
|
47 |
return false |
|
48 |
} |
|
49 |
|
|
50 |
nx := 0 |
|
51 |
equal := true |
|
52 |
mx.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool { |
|
53 |
nx++ |
|
54 |
vy := my.Get(fd) |
|
55 |
equal = my.Has(fd) && equalField(fd, vx, vy) |
|
56 |
return equal |
|
57 |
}) |
|
58 |
if !equal { |
|
59 |
return false |
|
60 |
} |
|
61 |
ny := 0 |
|
62 |
my.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool { |
|
63 |
ny++ |
|
64 |
return true |
|
65 |
}) |
|
66 |
if nx != ny { |
|
67 |
return false |
|
68 |
} |
|
69 |
|
|
70 |
return equalUnknown(mx.GetUnknown(), my.GetUnknown()) |
|
71 |
} |
|
72 |
|
|
73 |
// equalField compares two fields. |
|
74 |
func equalField(fd pref.FieldDescriptor, x, y pref.Value) bool { |
|
75 |
switch { |
|
76 |
case fd.IsList(): |
|
77 |
return equalList(fd, x.List(), y.List()) |
|
78 |
case fd.IsMap(): |
|
79 |
return equalMap(fd, x.Map(), y.Map()) |
|
80 |
default: |
|
81 |
return equalValue(fd, x, y) |
|
82 |
} |
|
83 |
} |
|
84 |
|
|
85 |
// equalMap compares two maps. |
|
86 |
func equalMap(fd pref.FieldDescriptor, x, y pref.Map) bool { |
|
87 |
if x.Len() != y.Len() { |
|
88 |
return false |
|
89 |
} |
|
90 |
equal := true |
|
91 |
x.Range(func(k pref.MapKey, vx pref.Value) bool { |
|
92 |
vy := y.Get(k) |
|
93 |
equal = y.Has(k) && equalValue(fd.MapValue(), vx, vy) |
|
94 |
return equal |
|
95 |
}) |
|
96 |
return equal |
|
97 |
} |
|
98 |
|
|
99 |
// equalList compares two lists. |
|
100 |
func equalList(fd pref.FieldDescriptor, x, y pref.List) bool { |
|
101 |
if x.Len() != y.Len() { |
|
102 |
return false |
|
103 |
} |
|
104 |
for i := x.Len() - 1; i >= 0; i-- { |
|
105 |
if !equalValue(fd, x.Get(i), y.Get(i)) { |
|
106 |
return false |
|
107 |
} |
|
108 |
} |
|
109 |
return true |
|
110 |
} |
|
111 |
|
|
112 |
// equalValue compares two singular values. |
|
113 |
func equalValue(fd pref.FieldDescriptor, x, y pref.Value) bool { |
|
114 |
switch fd.Kind() { |
|
115 |
case pref.BoolKind: |
|
116 |
return x.Bool() == y.Bool() |
|
117 |
case pref.EnumKind: |
|
118 |
return x.Enum() == y.Enum() |
|
119 |
case pref.Int32Kind, pref.Sint32Kind, |
|
120 |
pref.Int64Kind, pref.Sint64Kind, |
|
121 |
pref.Sfixed32Kind, pref.Sfixed64Kind: |
|
122 |
return x.Int() == y.Int() |
|
123 |
case pref.Uint32Kind, pref.Uint64Kind, |
|
124 |
pref.Fixed32Kind, pref.Fixed64Kind: |
|
125 |
return x.Uint() == y.Uint() |
|
126 |
case pref.FloatKind, pref.DoubleKind: |
|
127 |
fx := x.Float() |
|
128 |
fy := y.Float() |
|
129 |
if math.IsNaN(fx) || math.IsNaN(fy) { |
|
130 |
return math.IsNaN(fx) && math.IsNaN(fy) |
|
131 |
} |
|
132 |
return fx == fy |
|
133 |
case pref.StringKind: |
|
134 |
return x.String() == y.String() |
|
135 |
case pref.BytesKind: |
|
136 |
return bytes.Equal(x.Bytes(), y.Bytes()) |
|
137 |
case pref.MessageKind, pref.GroupKind: |
|
138 |
return equalMessage(x.Message(), y.Message()) |
|
139 |
default: |
|
140 |
return x.Interface() == y.Interface() |
|
141 |
} |
|
142 |
} |
|
143 |
|
|
144 |
// equalUnknown compares unknown fields by direct comparison on the raw bytes |
|
145 |
// of each individual field number. |
|
146 |
func equalUnknown(x, y pref.RawFields) bool { |
|
147 |
if len(x) != len(y) { |
|
148 |
return false |
|
149 |
} |
|
150 |
if bytes.Equal([]byte(x), []byte(y)) { |
|
151 |
return true |
|
152 |
} |
|
153 |
|
|
154 |
mx := make(map[pref.FieldNumber]pref.RawFields) |
|
155 |
my := make(map[pref.FieldNumber]pref.RawFields) |
|
156 |
for len(x) > 0 { |
|
157 |
fnum, _, n := protowire.ConsumeField(x) |
|
158 |
mx[fnum] = append(mx[fnum], x[:n]...) |
|
159 |
x = x[n:] |
|
160 |
} |
|
161 |
for len(y) > 0 { |
|
162 |
fnum, _, n := protowire.ConsumeField(y) |
|
163 |
my[fnum] = append(my[fnum], y[:n]...) |
|
164 |
y = y[n:] |
|
165 |
} |
|
166 |
return reflect.DeepEqual(mx, my) |
|
167 |
} |