1
2
3
4
5 package proto_test
6
7 import (
8 "fmt"
9 "reflect"
10 "sync"
11 "testing"
12
13 "github.com/google/go-cmp/cmp"
14
15 "google.golang.org/protobuf/proto"
16 "google.golang.org/protobuf/reflect/protoreflect"
17 "google.golang.org/protobuf/runtime/protoimpl"
18 "google.golang.org/protobuf/testing/protocmp"
19
20 legacy1pb "google.golang.org/protobuf/internal/testprotos/legacy/proto2_20160225_2fc053c5"
21 testpb "google.golang.org/protobuf/internal/testprotos/test"
22 test3pb "google.golang.org/protobuf/internal/testprotos/test3"
23 descpb "google.golang.org/protobuf/types/descriptorpb"
24 )
25
26 func TestExtensionFuncs(t *testing.T) {
27 for _, test := range []struct {
28 message proto.Message
29 ext protoreflect.ExtensionType
30 wantDefault interface{}
31 value interface{}
32 }{
33 {
34 message: &testpb.TestAllExtensions{},
35 ext: testpb.E_OptionalInt32,
36 wantDefault: int32(0),
37 value: int32(1),
38 },
39 {
40 message: &testpb.TestAllExtensions{},
41 ext: testpb.E_RepeatedString,
42 wantDefault: ([]string)(nil),
43 value: []string{"a", "b", "c"},
44 },
45 {
46 message: protoimpl.X.MessageOf(&legacy1pb.Message{}).Interface(),
47 ext: legacy1pb.E_Message_ExtensionOptionalBool,
48 wantDefault: false,
49 value: true,
50 },
51 } {
52 desc := fmt.Sprintf("Extension %v, value %v", test.ext.TypeDescriptor().FullName(), test.value)
53 if proto.HasExtension(test.message, test.ext) {
54 t.Errorf("%v:\nbefore setting extension HasExtension(...) = true, want false", desc)
55 }
56 got := proto.GetExtension(test.message, test.ext)
57 if d := cmp.Diff(test.wantDefault, got); d != "" {
58 t.Errorf("%v:\nbefore setting extension GetExtension(...) returns unexpected value (-want,+got):\n%v", desc, d)
59 }
60 proto.SetExtension(test.message, test.ext, test.value)
61 if !proto.HasExtension(test.message, test.ext) {
62 t.Errorf("%v:\nafter setting extension HasExtension(...) = false, want true", desc)
63 }
64 got = proto.GetExtension(test.message, test.ext)
65 if d := cmp.Diff(test.value, got); d != "" {
66 t.Errorf("%v:\nafter setting extension GetExtension(...) returns unexpected value (-want,+got):\n%v", desc, d)
67 }
68 proto.ClearExtension(test.message, test.ext)
69 if proto.HasExtension(test.message, test.ext) {
70 t.Errorf("%v:\nafter clearing extension HasExtension(...) = true, want false", desc)
71 }
72 }
73 }
74
75 func TestIsValid(t *testing.T) {
76 tests := []struct {
77 xt protoreflect.ExtensionType
78 vi interface{}
79 want bool
80 }{
81 {testpb.E_OptionalBool, nil, false},
82 {testpb.E_OptionalBool, bool(true), true},
83 {testpb.E_OptionalBool, new(bool), false},
84 {testpb.E_OptionalInt32, nil, false},
85 {testpb.E_OptionalInt32, int32(0), true},
86 {testpb.E_OptionalInt32, new(int32), false},
87 {testpb.E_OptionalInt64, nil, false},
88 {testpb.E_OptionalInt64, int64(0), true},
89 {testpb.E_OptionalInt64, new(int64), false},
90 {testpb.E_OptionalUint32, nil, false},
91 {testpb.E_OptionalUint32, uint32(0), true},
92 {testpb.E_OptionalUint32, new(uint32), false},
93 {testpb.E_OptionalUint64, nil, false},
94 {testpb.E_OptionalUint64, uint64(0), true},
95 {testpb.E_OptionalUint64, new(uint64), false},
96 {testpb.E_OptionalFloat, nil, false},
97 {testpb.E_OptionalFloat, float32(0), true},
98 {testpb.E_OptionalFloat, new(float32), false},
99 {testpb.E_OptionalDouble, nil, false},
100 {testpb.E_OptionalDouble, float64(0), true},
101 {testpb.E_OptionalDouble, new(float32), false},
102 {testpb.E_OptionalString, nil, false},
103 {testpb.E_OptionalString, string(""), true},
104 {testpb.E_OptionalString, new(string), false},
105 {testpb.E_OptionalNestedEnum, nil, false},
106 {testpb.E_OptionalNestedEnum, testpb.TestAllTypes_BAZ, true},
107 {testpb.E_OptionalNestedEnum, testpb.TestAllTypes_BAZ.Enum(), false},
108 {testpb.E_OptionalNestedMessage, nil, false},
109 {testpb.E_OptionalNestedMessage, (*testpb.TestAllExtensions_NestedMessage)(nil), true},
110 {testpb.E_OptionalNestedMessage, new(testpb.TestAllExtensions_NestedMessage), true},
111 {testpb.E_OptionalNestedMessage, new(testpb.TestAllExtensions), false},
112 {testpb.E_RepeatedBool, nil, false},
113 {testpb.E_RepeatedBool, []bool(nil), true},
114 {testpb.E_RepeatedBool, []bool{}, true},
115 {testpb.E_RepeatedBool, []bool{false}, true},
116 {testpb.E_RepeatedBool, []*bool{}, false},
117 {testpb.E_RepeatedInt32, nil, false},
118 {testpb.E_RepeatedInt32, []int32(nil), true},
119 {testpb.E_RepeatedInt32, []int32{}, true},
120 {testpb.E_RepeatedInt32, []int32{0}, true},
121 {testpb.E_RepeatedInt32, []*int32{}, false},
122 {testpb.E_RepeatedInt64, nil, false},
123 {testpb.E_RepeatedInt64, []int64(nil), true},
124 {testpb.E_RepeatedInt64, []int64{}, true},
125 {testpb.E_RepeatedInt64, []int64{0}, true},
126 {testpb.E_RepeatedInt64, []*int64{}, false},
127 {testpb.E_RepeatedUint32, nil, false},
128 {testpb.E_RepeatedUint32, []uint32(nil), true},
129 {testpb.E_RepeatedUint32, []uint32{}, true},
130 {testpb.E_RepeatedUint32, []uint32{0}, true},
131 {testpb.E_RepeatedUint32, []*uint32{}, false},
132 {testpb.E_RepeatedUint64, nil, false},
133 {testpb.E_RepeatedUint64, []uint64(nil), true},
134 {testpb.E_RepeatedUint64, []uint64{}, true},
135 {testpb.E_RepeatedUint64, []uint64{0}, true},
136 {testpb.E_RepeatedUint64, []*uint64{}, false},
137 {testpb.E_RepeatedFloat, nil, false},
138 {testpb.E_RepeatedFloat, []float32(nil), true},
139 {testpb.E_RepeatedFloat, []float32{}, true},
140 {testpb.E_RepeatedFloat, []float32{0}, true},
141 {testpb.E_RepeatedFloat, []*float32{}, false},
142 {testpb.E_RepeatedDouble, nil, false},
143 {testpb.E_RepeatedDouble, []float64(nil), true},
144 {testpb.E_RepeatedDouble, []float64{}, true},
145 {testpb.E_RepeatedDouble, []float64{0}, true},
146 {testpb.E_RepeatedDouble, []*float64{}, false},
147 {testpb.E_RepeatedString, nil, false},
148 {testpb.E_RepeatedString, []string(nil), true},
149 {testpb.E_RepeatedString, []string{}, true},
150 {testpb.E_RepeatedString, []string{""}, true},
151 {testpb.E_RepeatedString, []*string{}, false},
152 {testpb.E_RepeatedNestedEnum, nil, false},
153 {testpb.E_RepeatedNestedEnum, []testpb.TestAllTypes_NestedEnum(nil), true},
154 {testpb.E_RepeatedNestedEnum, []testpb.TestAllTypes_NestedEnum{}, true},
155 {testpb.E_RepeatedNestedEnum, []testpb.TestAllTypes_NestedEnum{0}, true},
156 {testpb.E_RepeatedNestedEnum, []*testpb.TestAllTypes_NestedEnum{}, false},
157 {testpb.E_RepeatedNestedMessage, nil, false},
158 {testpb.E_RepeatedNestedMessage, []*testpb.TestAllExtensions_NestedMessage(nil), true},
159 {testpb.E_RepeatedNestedMessage, []*testpb.TestAllExtensions_NestedMessage{}, true},
160 {testpb.E_RepeatedNestedMessage, []*testpb.TestAllExtensions_NestedMessage{{}}, true},
161 {testpb.E_RepeatedNestedMessage, []*testpb.TestAllExtensions{}, false},
162 }
163
164 for _, tt := range tests {
165
166 got := tt.xt.IsValidInterface(tt.vi)
167 if got != tt.want {
168 t.Errorf("%v.IsValidInterface() = %v, want %v", tt.xt.TypeDescriptor().FullName(), got, tt.want)
169 }
170 if !got {
171 continue
172 }
173
174
175 wantHas := true
176 pv := tt.xt.ValueOf(tt.vi)
177 switch v := pv.Interface().(type) {
178 case protoreflect.List:
179 wantHas = v.Len() > 0
180 case protoreflect.Message:
181 wantHas = v.IsValid()
182 }
183 m := &testpb.TestAllExtensions{}
184 proto.SetExtension(m, tt.xt, tt.vi)
185 gotHas := proto.HasExtension(m, tt.xt)
186 if gotHas != wantHas {
187 t.Errorf("HasExtension(%q) = %v, want %v", tt.xt.TypeDescriptor().FullName(), gotHas, wantHas)
188 }
189
190
191 got = tt.xt.IsValidValue(pv)
192 if got != tt.want {
193 t.Errorf("%v.IsValidValue() = %v, want %v", tt.xt.TypeDescriptor().FullName(), got, tt.want)
194 }
195 if !got {
196 continue
197 }
198
199
200
201 vi := tt.xt.InterfaceOf(pv)
202 if !reflect.DeepEqual(vi, tt.vi) {
203 t.Errorf("InterfaceOf(ValueOf(...)) round-trip mismatch: got %v, want %v", vi, tt.vi)
204 }
205 }
206 }
207
208 func TestExtensionRanger(t *testing.T) {
209 tests := []struct {
210 msg proto.Message
211 want map[protoreflect.ExtensionType]interface{}
212 }{{
213 msg: &testpb.TestAllExtensions{},
214 want: map[protoreflect.ExtensionType]interface{}{
215 testpb.E_OptionalInt32: int32(5),
216 testpb.E_OptionalString: string("hello"),
217 testpb.E_OptionalNestedMessage: &testpb.TestAllExtensions_NestedMessage{},
218 testpb.E_OptionalNestedEnum: testpb.TestAllTypes_BAZ,
219 testpb.E_RepeatedFloat: []float32{+32.32, -32.32},
220 testpb.E_RepeatedNestedMessage: []*testpb.TestAllExtensions_NestedMessage{{}},
221 testpb.E_RepeatedNestedEnum: []testpb.TestAllTypes_NestedEnum{testpb.TestAllTypes_BAZ},
222 },
223 }, {
224 msg: &descpb.MessageOptions{},
225 want: map[protoreflect.ExtensionType]interface{}{
226 test3pb.E_OptionalInt32: int32(5),
227 test3pb.E_OptionalString: string("hello"),
228 test3pb.E_OptionalForeignMessage: &test3pb.ForeignMessage{},
229 test3pb.E_OptionalForeignEnum: test3pb.ForeignEnum_FOREIGN_BAR,
230
231 test3pb.E_OptionalOptionalInt32: int32(5),
232 test3pb.E_OptionalOptionalString: string("hello"),
233 test3pb.E_OptionalOptionalForeignMessage: &test3pb.ForeignMessage{},
234 test3pb.E_OptionalOptionalForeignEnum: test3pb.ForeignEnum_FOREIGN_BAR,
235 },
236 }}
237
238 for _, tt := range tests {
239 for xt, v := range tt.want {
240 proto.SetExtension(tt.msg, xt, v)
241 }
242
243 got := make(map[protoreflect.ExtensionType]interface{})
244 proto.RangeExtensions(tt.msg, func(xt protoreflect.ExtensionType, v interface{}) bool {
245 got[xt] = v
246 return true
247 })
248
249 if diff := cmp.Diff(tt.want, got, protocmp.Transform()); diff != "" {
250 t.Errorf("proto.RangeExtensions mismatch (-want +got):\n%s", diff)
251 }
252 }
253 }
254
255 func TestExtensionGetRace(t *testing.T) {
256
257
258
259 want := int32(42)
260 m1 := &testpb.TestAllExtensions{}
261 proto.SetExtension(m1, testpb.E_OptionalNestedMessage, &testpb.TestAllExtensions_NestedMessage{A: proto.Int32(want)})
262 b, err := proto.Marshal(m1)
263 if err != nil {
264 t.Fatal(err)
265 }
266 m := &testpb.TestAllExtensions{}
267 if err := proto.Unmarshal(b, m); err != nil {
268 t.Fatal(err)
269 }
270 var wg sync.WaitGroup
271 for i := 0; i < 3; i++ {
272 wg.Add(1)
273 go func() {
274 defer wg.Done()
275 if _, err := proto.Marshal(m); err != nil {
276 t.Error(err)
277 }
278 }()
279 wg.Add(1)
280 go func() {
281 defer wg.Done()
282 got := proto.GetExtension(m, testpb.E_OptionalNestedMessage).(*testpb.TestAllExtensions_NestedMessage).GetA()
283 if got != want {
284 t.Errorf("GetExtension(optional_nested_message).a = %v, want %v", got, want)
285 }
286 }()
287 }
288 wg.Wait()
289 }
290
View as plain text