1
2
3
4
5 package proto_test
6
7 import (
8 "bytes"
9 "fmt"
10 "reflect"
11 "testing"
12
13 "google.golang.org/protobuf/encoding/prototext"
14 "google.golang.org/protobuf/proto"
15 "google.golang.org/protobuf/reflect/protoreflect"
16 "google.golang.org/protobuf/testing/protopack"
17
18 "google.golang.org/protobuf/internal/errors"
19 testpb "google.golang.org/protobuf/internal/testprotos/test"
20 test3pb "google.golang.org/protobuf/internal/testprotos/test3"
21 )
22
23 func TestDecode(t *testing.T) {
24 for _, test := range testValidMessages {
25 if len(test.decodeTo) == 0 {
26 t.Errorf("%v: no test message types", test.desc)
27 }
28 for _, want := range test.decodeTo {
29 t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
30 opts := test.unmarshalOptions
31 opts.AllowPartial = test.partial
32 wire := append(([]byte)(nil), test.wire...)
33 got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
34 if err := opts.Unmarshal(wire, got); err != nil {
35 t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, prototext.Format(want))
36 return
37 }
38
39
40
41
42 if !bytes.Equal(test.wire, wire) {
43 t.Errorf("Unmarshal unexpectedly modified its input")
44 }
45 for i := range wire {
46 wire[i] = 0
47 }
48 if !proto.Equal(got, want) && got.ProtoReflect().IsValid() && want.ProtoReflect().IsValid() {
49 t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", prototext.Format(got), prototext.Format(want))
50 }
51 })
52 }
53 }
54 }
55
56 func TestDecodeRequiredFieldChecks(t *testing.T) {
57 for _, test := range testValidMessages {
58 if !test.partial {
59 continue
60 }
61 for _, m := range test.decodeTo {
62 t.Run(fmt.Sprintf("%s (%T)", test.desc, m), func(t *testing.T) {
63 opts := test.unmarshalOptions
64 opts.AllowPartial = false
65 got := reflect.New(reflect.TypeOf(m).Elem()).Interface().(proto.Message)
66 if err := proto.Unmarshal(test.wire, got); err == nil {
67 t.Fatalf("Unmarshal succeeded (want error)\nMessage:\n%v", prototext.Format(got))
68 }
69 })
70 }
71 }
72 }
73
74 func TestDecodeInvalidMessages(t *testing.T) {
75 for _, test := range testInvalidMessages {
76 if len(test.decodeTo) == 0 {
77 t.Errorf("%v: no test message types", test.desc)
78 }
79 for _, want := range test.decodeTo {
80 t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
81 opts := test.unmarshalOptions
82 opts.AllowPartial = test.partial
83 got := want.ProtoReflect().New().Interface()
84 if err := opts.Unmarshal(test.wire, got); err == nil {
85 t.Errorf("Unmarshal unexpectedly succeeded\ninput bytes: [%x]\nMessage:\n%v", test.wire, prototext.Format(got))
86 } else if !errors.Is(err, proto.Error) {
87 t.Errorf("Unmarshal error is not a proto.Error: %v", err)
88 }
89 })
90 }
91 }
92 }
93
94 func TestDecodeZeroLengthBytes(t *testing.T) {
95
96
97 wire := protopack.Message{
98 protopack.Tag{94, protopack.BytesType}, protopack.Bytes(nil),
99 }.Marshal()
100 m := &test3pb.TestAllTypes{}
101 if err := proto.Unmarshal(wire, m); err != nil {
102 t.Fatal(err)
103 }
104 if m.OptionalBytes != nil {
105 t.Errorf("unmarshal zero-length proto3 bytes field: got %v, want nil", m.OptionalBytes)
106 }
107 }
108
109 func TestDecodeOneofNilWrapper(t *testing.T) {
110 wire := protopack.Message{
111 protopack.Tag{111, protopack.VarintType}, protopack.Varint(1111),
112 }.Marshal()
113 m := &testpb.TestAllTypes{OneofField: (*testpb.TestAllTypes_OneofUint32)(nil)}
114 if err := proto.Unmarshal(wire, m); err != nil {
115 t.Fatal(err)
116 }
117 if got := m.GetOneofUint32(); got != 1111 {
118 t.Errorf("GetOneofUint32() = %v, want %v", got, 1111)
119 }
120 }
121
122 func TestDecodeEmptyBytes(t *testing.T) {
123
124
125
126 m := &testpb.TestAllTypes{}
127 b := protopack.Message{
128 protopack.Tag{45, protopack.BytesType}, protopack.Bytes(nil),
129 }.Marshal()
130 if err := proto.Unmarshal(b, m); err != nil {
131 t.Fatal(err)
132 }
133 if m.RepeatedBytes[0] == nil {
134 t.Errorf("unmarshaling repeated bytes field containing zero-length value: Got nil bytes, want non-nil")
135 }
136 }
137
138 func build(m proto.Message, opts ...buildOpt) proto.Message {
139 for _, opt := range opts {
140 opt(m)
141 }
142 return m
143 }
144
145 type buildOpt func(proto.Message)
146
147 func unknown(raw protoreflect.RawFields) buildOpt {
148 return func(m proto.Message) {
149 m.ProtoReflect().SetUnknown(raw)
150 }
151 }
152
153 func extend(desc protoreflect.ExtensionType, value interface{}) buildOpt {
154 return func(m proto.Message) {
155 proto.SetExtension(m, desc, value)
156 }
157 }
158
View as plain text