1
2
3
4
5 package nullable
6
7 import (
8 "reflect"
9 "testing"
10
11 "github.com/google/go-cmp/cmp"
12 "google.golang.org/protobuf/proto"
13 "google.golang.org/protobuf/reflect/protoreflect"
14 "google.golang.org/protobuf/runtime/protoimpl"
15 "google.golang.org/protobuf/testing/protocmp"
16 )
17
18 func Test(t *testing.T) {
19 for _, mt := range []protoreflect.MessageType{
20 protoimpl.X.ProtoMessageV2Of((*Proto2)(nil)).ProtoReflect().Type(),
21 protoimpl.X.ProtoMessageV2Of((*Proto3)(nil)).ProtoReflect().Type(),
22 } {
23 t.Run(string(mt.Descriptor().FullName()), func(t *testing.T) {
24 testEmptyMessage(t, mt.Zero(), false)
25 testEmptyMessage(t, mt.New(), true)
26
27 })
28 }
29 }
30
31 var methodTestProtos = []protoreflect.MessageType{
32 protoimpl.X.ProtoMessageV2Of((*Methods)(nil)).ProtoReflect().Type(),
33 }
34
35 func TestMethods(t *testing.T) {
36 for _, mt := range methodTestProtos {
37 t.Run(string(mt.Descriptor().FullName()), func(t *testing.T) {
38 testMethods(t, mt)
39 })
40 }
41 }
42
43 func testMethods(t *testing.T, mt protoreflect.MessageType) {
44 m1 := mt.New()
45 populated := testPopulateMessage(t, m1, 2)
46 b, err := proto.Marshal(m1.Interface())
47 if err != nil {
48 t.Errorf("proto.Marshal error: %v", err)
49 }
50 if populated && len(b) == 0 {
51 t.Errorf("len(proto.Marshal) = 0, want >0")
52 }
53 m2 := mt.New()
54 if err := proto.Unmarshal(b, m2.Interface()); err != nil {
55 t.Errorf("proto.Unmarshal error: %v", err)
56 }
57 if diff := cmp.Diff(m1.Interface(), m2.Interface(), protocmp.Transform()); diff != "" {
58 t.Errorf("message mismatch:\n%v", diff)
59 }
60 proto.Reset(m2.Interface())
61 testEmptyMessage(t, m2, true)
62 proto.Merge(m2.Interface(), m1.Interface())
63 if diff := cmp.Diff(m1.Interface(), m2.Interface(), protocmp.Transform()); diff != "" {
64 t.Errorf("message mismatch:\n%v", diff)
65 }
66 proto.Merge(mt.New().Interface(), mt.Zero().Interface())
67 }
68
69 func testEmptyMessage(t *testing.T, m protoreflect.Message, wantValid bool) {
70 numFields := func(m protoreflect.Message) (n int) {
71 m.Range(func(protoreflect.FieldDescriptor, protoreflect.Value) bool {
72 n++
73 return true
74 })
75 return n
76 }
77
78 md := m.Descriptor()
79 if gotValid := m.IsValid(); gotValid != wantValid {
80 t.Errorf("%v.IsValid = %v, want %v", md.FullName(), gotValid, wantValid)
81 }
82 m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
83 t.Errorf("%v.Range iterated over field %v, want no iteration", md.FullName(), fd.Name())
84 return true
85 })
86 fds := md.Fields()
87 for i := 0; i < fds.Len(); i++ {
88 fd := fds.Get(i)
89 if m.Has(fd) {
90 t.Errorf("%v.Has(%v) = true, want false", md.FullName(), fd.Name())
91 }
92 v := m.Get(fd)
93 switch {
94 case fd.IsList():
95 if n := v.List().Len(); n > 0 {
96 t.Errorf("%v.Get(%v).List().Len() = %v, want 0", md.FullName(), fd.Name(), n)
97 }
98 ls := m.NewField(fd).List()
99 if fd.Message() != nil {
100 if n := numFields(ls.NewElement().Message()); n > 0 {
101 t.Errorf("%v.NewField(%v).List().NewElement().Message().Len() = %v, want 0", md.FullName(), fd.Name(), n)
102 }
103 }
104 case fd.IsMap():
105 if n := v.Map().Len(); n > 0 {
106 t.Errorf("%v.Get(%v).Map().Len() = %v, want 0", md.FullName(), fd.Name(), n)
107 }
108 ms := m.NewField(fd).Map()
109 if fd.MapValue().Message() != nil {
110 if n := numFields(ms.NewValue().Message()); n > 0 {
111 t.Errorf("%v.NewField(%v).Map().NewValue().Message().Len() = %v, want 0", md.FullName(), fd.Name(), n)
112 }
113 }
114 case fd.Message() != nil:
115 if n := numFields(v.Message()); n > 0 {
116 t.Errorf("%v.Get(%v).Message().Len() = %v, want 0", md.FullName(), fd.Name(), n)
117 }
118 if n := numFields(m.NewField(fd).Message()); n > 0 {
119 t.Errorf("%v.NewField(%v).Message().Len() = %v, want 0", md.FullName(), fd.Name(), n)
120 }
121 default:
122 if !reflect.DeepEqual(v.Interface(), fd.Default().Interface()) {
123 t.Errorf("%v.Get(%v) = %v, want %v", md.FullName(), fd.Name(), v, fd.Default())
124 }
125 m.NewField(fd)
126 }
127 }
128 ods := md.Oneofs()
129 for i := 0; i < ods.Len(); i++ {
130 od := ods.Get(i)
131 if fd := m.WhichOneof(od); fd != nil {
132 t.Errorf("%v.WhichOneof(%v) = %v, want nil", md.FullName(), od.Name(), fd.Name())
133 }
134 }
135 if b := m.GetUnknown(); b != nil {
136 t.Errorf("%v.GetUnknown() = %v, want nil", md.FullName(), b)
137 }
138 }
139
140 func testPopulateMessage(t *testing.T, m protoreflect.Message, depth int) bool {
141 if depth == 0 {
142 return false
143 }
144 md := m.Descriptor()
145 fds := md.Fields()
146 var populatedMessage bool
147 for i := 0; i < fds.Len(); i++ {
148 populatedField := true
149 fd := fds.Get(i)
150 m.Clear(fd)
151 switch {
152 case fd.IsList():
153 ls := m.Mutable(fd).List()
154 if fd.Message() == nil {
155 ls.Append(scalarValue(fd.Kind()))
156 } else {
157 populatedField = testPopulateMessage(t, ls.AppendMutable().Message(), depth-1)
158 }
159 case fd.IsMap():
160 ms := m.Mutable(fd).Map()
161 if fd.MapValue().Message() == nil {
162 ms.Set(
163 scalarValue(fd.MapKey().Kind()).MapKey(),
164 scalarValue(fd.MapValue().Kind()),
165 )
166 } else {
167
168 m2 := ms.NewValue().Message()
169 populatedField = testPopulateMessage(t, m2, depth-1)
170 ms.Set(
171 scalarValue(fd.MapKey().Kind()).MapKey(),
172 protoreflect.ValueOfMessage(m2),
173 )
174 }
175 case fd.Message() != nil:
176 populatedField = testPopulateMessage(t, m.Mutable(fd).Message(), depth-1)
177 default:
178 m.Set(fd, scalarValue(fd.Kind()))
179 }
180 if populatedField && !m.Has(fd) {
181 t.Errorf("%v.Has(%v) = false, want true", md.FullName(), fd.Name())
182 }
183 populatedMessage = populatedMessage || populatedField
184 }
185 m.SetUnknown(m.GetUnknown())
186 return populatedMessage
187 }
188
189 func scalarValue(k protoreflect.Kind) protoreflect.Value {
190 switch k {
191 case protoreflect.BoolKind:
192 return protoreflect.ValueOfBool(true)
193 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
194 return protoreflect.ValueOfInt32(-32)
195 case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
196 return protoreflect.ValueOfInt64(-64)
197 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
198 return protoreflect.ValueOfUint32(32)
199 case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
200 return protoreflect.ValueOfUint64(64)
201 case protoreflect.FloatKind:
202 return protoreflect.ValueOfFloat32(32.32)
203 case protoreflect.DoubleKind:
204 return protoreflect.ValueOfFloat64(64.64)
205 case protoreflect.StringKind:
206 return protoreflect.ValueOfString(string("string"))
207 case protoreflect.BytesKind:
208 return protoreflect.ValueOfBytes([]byte("bytes"))
209 case protoreflect.EnumKind:
210 return protoreflect.ValueOfEnum(1)
211 default:
212 panic("unknown kind: " + k.String())
213 }
214 }
215
View as plain text