1
2
3
4
5
6
7
8
9
10
11 package protogen
12
13 import (
14 "bufio"
15 "bytes"
16 "fmt"
17 "go/ast"
18 "go/parser"
19 "go/printer"
20 "go/token"
21 "go/types"
22 "io/ioutil"
23 "os"
24 "path"
25 "path/filepath"
26 "sort"
27 "strconv"
28 "strings"
29
30 "google.golang.org/protobuf/encoding/prototext"
31 "google.golang.org/protobuf/internal/genid"
32 "google.golang.org/protobuf/internal/strs"
33 "google.golang.org/protobuf/proto"
34 "google.golang.org/protobuf/reflect/protodesc"
35 "google.golang.org/protobuf/reflect/protoreflect"
36 "google.golang.org/protobuf/reflect/protoregistry"
37
38 "google.golang.org/protobuf/types/descriptorpb"
39 "google.golang.org/protobuf/types/dynamicpb"
40 "google.golang.org/protobuf/types/pluginpb"
41 )
42
// goPackageDocURL is linked from error messages that explain how to choose a
// Go import path for a .proto file.
const (
	goPackageDocURL = "https://protobuf.dev/reference/go/go-generated#package"
)
44
45
46
47
48
49
50
51
52 func (opts Options) Run(f func(*Plugin) error) {
53 if err := run(opts, f); err != nil {
54 fmt.Fprintf(os.Stderr, "%s: %v\n", filepath.Base(os.Args[0]), err)
55 os.Exit(1)
56 }
57 }
58
59 func run(opts Options, f func(*Plugin) error) error {
60 if len(os.Args) > 1 {
61 return fmt.Errorf("unknown argument %q (this program should be run by protoc, not directly)", os.Args[1])
62 }
63 in, err := ioutil.ReadAll(os.Stdin)
64 if err != nil {
65 return err
66 }
67 req := &pluginpb.CodeGeneratorRequest{}
68 if err := proto.Unmarshal(in, req); err != nil {
69 return err
70 }
71 gen, err := opts.New(req)
72 if err != nil {
73 return err
74 }
75 if err := f(gen); err != nil {
76
77
78
79
80
81
82 gen.Error(err)
83 }
84 resp := gen.Response()
85 out, err := proto.Marshal(resp)
86 if err != nil {
87 return err
88 }
89 if _, err := os.Stdout.Write(out); err != nil {
90 return err
91 }
92 return nil
93 }
94
95
// A Plugin is a protoc plugin invocation, built by Options.New from a
// CodeGeneratorRequest.
type Plugin struct {
	// Request is the CodeGeneratorRequest provided by protoc.
	Request *pluginpb.CodeGeneratorRequest

	// Files holds one File for every entry in Request.ProtoFile, in request
	// order. FilesByPath maps each .proto file name to its File.
	Files       []*File
	FilesByPath map[string]*File

	// SupportedFeatures, if non-zero, is copied into the response's
	// SupportedFeatures field by Response. NOTE(review): presumably a bitmask
	// of pluginpb.CodeGeneratorResponse_Feature values — confirm.
	SupportedFeatures uint64

	// Descriptor registry used to build each File's descriptor.
	fileReg *protoregistry.Files

	// Lookup tables filled while building Files and used to resolve field
	// and method type references.
	enumsByName    map[protoreflect.FullName]*Enum
	messagesByName map[protoreflect.FullName]*Message

	// Settings taken from the generator parameter ("annotate_code",
	// "paths", and "module").
	annotateCode bool
	pathType     pathType
	module       string

	// Files created by NewGeneratedFile, emitted by Response.
	genFiles []*GeneratedFile

	// Options passed to New.
	opts Options

	// First error recorded via Error; when set, Response reports only it.
	err error
}
121
// Options are optional parameters to New.
type Options struct {
	// If ParamFunc is non-nil, New calls it with each generator parameter it
	// does not handle itself. The handled parameters are "module", "paths",
	// "annotate_code", and names starting with "M". Parameters are split on
	// ","; an entry without "=" is passed with an empty value, and empty
	// entries are skipped. An error returned by ParamFunc is returned by New.
	ParamFunc func(name, value string) error

	// ImportRewriteFunc is called with the import path of each package
	// imported by a generated Go file. That includes packages referenced
	// through QualifiedGoIdent and packages requested via Import. It returns
	// the import path to use for that package.
	ImportRewriteFunc func(GoImportPath) GoImportPath
}
153
154
// New returns a new Plugin for req.
//
// New interprets the comma-separated generator parameter of req:
//   - "module" and "paths" (import or source_relative) control output names,
//   - "annotate_code" enables .meta annotation files,
//   - "M<file>=<import path>[;<package name>]" maps a .proto file to a Go
//     package,
//   - any other parameter is passed to opts.ParamFunc when it is non-nil.
//
// It then determines the Go import path and package name of every file in
// req.ProtoFile, builds a File for each, and marks the files listed in
// req.FileToGenerate for generation. An "M" mapping takes precedence over
// the file's go_package option.
func (opts Options) New(req *pluginpb.CodeGeneratorRequest) (*Plugin, error) {
	gen := &Plugin{
		Request:        req,
		FilesByPath:    make(map[string]*File),
		fileReg:        new(protoregistry.Files),
		enumsByName:    make(map[protoreflect.FullName]*Enum),
		messagesByName: make(map[protoreflect.FullName]*Message),
		opts:           opts,
	}

	// Go package name and import path per .proto file name. "M" parameters
	// are recorded first; go_package options only fill in missing entries.
	packageNames := make(map[string]GoPackageName)
	importPaths := make(map[string]GoImportPath)
	for _, param := range strings.Split(req.GetParameter(), ",") {
		var value string
		if i := strings.Index(param, "="); i >= 0 {
			value = param[i+1:]
			param = param[0:i]
		}
		switch param {
		case "":
			// Ignore empty entries (empty parameter string, stray commas).
		case "module":
			gen.module = value
		case "paths":
			switch value {
			case "import":
				gen.pathType = pathTypeImport
			case "source_relative":
				gen.pathType = pathTypeSourceRelative
			default:
				return nil, fmt.Errorf(`unknown path type %q: want "import" or "source_relative"`, value)
			}
		case "annotate_code":
			switch value {
			case "true", "":
				gen.annotateCode = true
			case "false":
				// annotateCode is already false (zero value).
			default:
				return nil, fmt.Errorf(`bad value for parameter %q: want "true" or "false"`, param)
			}
		default:
			// param is non-empty here because the "" case is handled above.
			// "M<file>" entries map a .proto file to "<import path>[;<name>]".
			if param[0] == 'M' {
				impPath, pkgName := splitImportPathAndPackageName(value)
				if pkgName != "" {
					packageNames[param[1:]] = pkgName
				}
				if impPath != "" {
					importPaths[param[1:]] = impPath
				}
				continue
			}
			if opts.ParamFunc != nil {
				if err := opts.ParamFunc(param, value); err != nil {
					return nil, err
				}
			}
		}
	}

	// The module= prefix is stripped from import-path-based output names
	// (see Response), which does not apply to source-relative output.
	if gen.module != "" && gen.pathType == pathTypeSourceRelative {
		return nil, fmt.Errorf("cannot use module= with paths=source_relative")
	}

	// Resolve the import path and package name of every file in the request.
	for _, fdesc := range gen.Request.ProtoFile {
		// Values from go_package fill in anything not set by an "M" parameter.
		filename := fdesc.GetName()
		impPath, pkgName := splitImportPathAndPackageName(fdesc.GetOptions().GetGoPackage())
		if importPaths[filename] == "" && impPath != "" {
			importPaths[filename] = impPath
		}
		if packageNames[filename] == "" && pkgName != "" {
			packageNames[filename] = pkgName
		}
		switch {
		case importPaths[filename] == "":
			// Neither an "M" parameter nor go_package provided an import path.
			return nil, fmt.Errorf(
				"unable to determine Go import path for %q\n\n"+
					"Please specify either:\n"+
					"\t• a \"go_package\" option in the .proto source file, or\n"+
					"\t• a \"M\" argument on the command line.\n\n"+
					"See %v for more information.\n",
				fdesc.GetName(), goPackageDocURL)
		case !strings.Contains(string(importPaths[filename]), ".") &&
			!strings.Contains(string(importPaths[filename]), "/"):
			// Reject import paths with neither '.' nor '/'. The error message
			// below states this requirement.
			return nil, fmt.Errorf(
				"invalid Go import path %q for %q\n\n"+
					"The import path must contain at least one period ('.') or forward slash ('/') character.\n\n"+
					"See %v for more information.\n",
				string(importPaths[filename]), fdesc.GetName(), goPackageDocURL)
		case packageNames[filename] == "":
			// No explicit package name: derive one from the last path
			// element. When go_package has an import path, it is used even
			// if an "M" parameter overrode the import path; otherwise the
			// resolved import path is used.
			if impPath == "" {
				impPath = importPaths[filename]
			}
			packageNames[filename] = cleanPackageName(path.Base(string(impPath)))
		}
	}

	// All files that share an import path must agree on the package name.
	packageFiles := make(map[GoImportPath][]string)
	for filename, importPath := range importPaths {
		if _, ok := packageNames[filename]; !ok {
			// An "M" import path for a file that received no package name
			// (e.g. not in the request); it takes no part in the check.
			continue
		}
		packageFiles[importPath] = append(packageFiles[importPath], filename)
	}
	for importPath, filenames := range packageFiles {
		for i := 1; i < len(filenames); i++ {
			if a, b := packageNames[filenames[0]], packageNames[filenames[i]]; a != b {
				return nil, fmt.Errorf("Go package %v has inconsistent names %v (%v) and %v (%v)",
					importPath, a, filenames[0], b, filenames[i])
			}
		}
	}

	// Build a File for every descriptor in request order, collecting the
	// extensions they declare that are not already known globally.
	typeRegistry := newExtensionRegistry()
	for _, fdesc := range gen.Request.ProtoFile {
		filename := fdesc.GetName()
		if gen.FilesByPath[filename] != nil {
			return nil, fmt.Errorf("duplicate file name: %q", filename)
		}
		f, err := newFile(gen, fdesc, packageNames[filename], importPaths[filename])
		if err != nil {
			return nil, err
		}
		gen.Files = append(gen.Files, f)
		gen.FilesByPath[filename] = f
		if err = typeRegistry.registerAllExtensionsFromFile(f.Desc); err != nil {
			return nil, err
		}
	}
	for _, filename := range gen.Request.FileToGenerate {
		f, ok := gen.FilesByPath[filename]
		if !ok {
			return nil, fmt.Errorf("no descriptor for generated file: %v", filename)
		}
		f.Generate = true
	}

	// If the request declares extensions unknown to the global registry,
	// re-decode every FileDescriptorProto with a resolver that knows them.
	// This way custom options declared in the request are parsed as
	// extension fields.
	if typeRegistry.hasNovelExtensions() {
		for _, f := range gen.Files {
			b, err := proto.Marshal(f.Proto.ProtoReflect().Interface())
			if err != nil {
				return nil, err
			}
			err = proto.UnmarshalOptions{Resolver: typeRegistry}.Unmarshal(b, f.Proto)
			if err != nil {
				return nil, err
			}
		}
	}
	return gen, nil
}
343
344
345
346 func (gen *Plugin) Error(err error) {
347 if gen.err == nil {
348 gen.err = err
349 }
350 }
351
352
353 func (gen *Plugin) Response() *pluginpb.CodeGeneratorResponse {
354 resp := &pluginpb.CodeGeneratorResponse{}
355 if gen.err != nil {
356 resp.Error = proto.String(gen.err.Error())
357 return resp
358 }
359 for _, g := range gen.genFiles {
360 if g.skip {
361 continue
362 }
363 content, err := g.Content()
364 if err != nil {
365 return &pluginpb.CodeGeneratorResponse{
366 Error: proto.String(err.Error()),
367 }
368 }
369 filename := g.filename
370 if gen.module != "" {
371 trim := gen.module + "/"
372 if !strings.HasPrefix(filename, trim) {
373 return &pluginpb.CodeGeneratorResponse{
374 Error: proto.String(fmt.Sprintf("%v: generated file does not match prefix %q", filename, gen.module)),
375 }
376 }
377 filename = strings.TrimPrefix(filename, trim)
378 }
379 resp.File = append(resp.File, &pluginpb.CodeGeneratorResponse_File{
380 Name: proto.String(filename),
381 Content: proto.String(string(content)),
382 })
383 if gen.annotateCode && strings.HasSuffix(g.filename, ".go") {
384 meta, err := g.metaFile(content)
385 if err != nil {
386 return &pluginpb.CodeGeneratorResponse{
387 Error: proto.String(err.Error()),
388 }
389 }
390 resp.File = append(resp.File, &pluginpb.CodeGeneratorResponse_File{
391 Name: proto.String(filename + ".meta"),
392 Content: proto.String(meta),
393 })
394 }
395 }
396 if gen.SupportedFeatures > 0 {
397 resp.SupportedFeatures = proto.Uint64(gen.SupportedFeatures)
398 }
399 return resp
400 }
401
402
// A File describes a .proto source file.
type File struct {
	Desc  protoreflect.FileDescriptor       // descriptor built from Proto by newFile
	Proto *descriptorpb.FileDescriptorProto // descriptor as provided in the request

	GoDescriptorIdent GoIdent       // "File_" + sanitized .proto file name, in GoImportPath
	GoPackageName     GoPackageName // name of this file's Go package
	GoImportPath      GoImportPath  // import path of this file's Go package

	Enums      []*Enum      // top-level enum declarations
	Messages   []*Message   // top-level message declarations
	Extensions []*Extension // top-level extension declarations
	Services   []*Service   // top-level service declarations

	Generate bool // true if the file is listed in Request.FileToGenerate

	// GeneratedFilenamePrefix is used to construct filenames for generated
	// files associated with this source file.
	//
	// It is the .proto file name with a trailing ".proto" or ".protodevel"
	// removed. With paths=import (the default), only the base name is kept
	// and it is joined onto GoImportPath. With paths=source_relative, the
	// file's own directory is kept, so "dir/foo.proto" yields "dir/foo".
	GeneratedFilenamePrefix string

	location Location // root location (SourceFile = Desc.Path()) for nested declarations
}
427
428 func newFile(gen *Plugin, p *descriptorpb.FileDescriptorProto, packageName GoPackageName, importPath GoImportPath) (*File, error) {
429 desc, err := protodesc.NewFile(p, gen.fileReg)
430 if err != nil {
431 return nil, fmt.Errorf("invalid FileDescriptorProto %q: %v", p.GetName(), err)
432 }
433 if err := gen.fileReg.RegisterFile(desc); err != nil {
434 return nil, fmt.Errorf("cannot register descriptor %q: %v", p.GetName(), err)
435 }
436 f := &File{
437 Desc: desc,
438 Proto: p,
439 GoPackageName: packageName,
440 GoImportPath: importPath,
441 location: Location{SourceFile: desc.Path()},
442 }
443
444
445 prefix := p.GetName()
446 if ext := path.Ext(prefix); ext == ".proto" || ext == ".protodevel" {
447 prefix = prefix[:len(prefix)-len(ext)]
448 }
449 switch gen.pathType {
450 case pathTypeImport:
451
452 prefix = path.Join(string(f.GoImportPath), path.Base(prefix))
453 case pathTypeSourceRelative:
454
455
456 }
457 f.GoDescriptorIdent = GoIdent{
458 GoName: "File_" + strs.GoSanitized(p.GetName()),
459 GoImportPath: f.GoImportPath,
460 }
461 f.GeneratedFilenamePrefix = prefix
462
463 for i, eds := 0, desc.Enums(); i < eds.Len(); i++ {
464 f.Enums = append(f.Enums, newEnum(gen, f, nil, eds.Get(i)))
465 }
466 for i, mds := 0, desc.Messages(); i < mds.Len(); i++ {
467 f.Messages = append(f.Messages, newMessage(gen, f, nil, mds.Get(i)))
468 }
469 for i, xds := 0, desc.Extensions(); i < xds.Len(); i++ {
470 f.Extensions = append(f.Extensions, newField(gen, f, nil, xds.Get(i)))
471 }
472 for i, sds := 0, desc.Services(); i < sds.Len(); i++ {
473 f.Services = append(f.Services, newService(gen, f, sds.Get(i)))
474 }
475 for _, message := range f.Messages {
476 if err := message.resolveDependencies(gen); err != nil {
477 return nil, err
478 }
479 }
480 for _, extension := range f.Extensions {
481 if err := extension.resolveDependencies(gen); err != nil {
482 return nil, err
483 }
484 }
485 for _, service := range f.Services {
486 for _, method := range service.Methods {
487 if err := method.resolveDependencies(gen); err != nil {
488 return nil, err
489 }
490 }
491 }
492 return f, nil
493 }
494
495
496
497 func splitImportPathAndPackageName(s string) (GoImportPath, GoPackageName) {
498 if i := strings.Index(s, ";"); i >= 0 {
499 return GoImportPath(s[:i]), GoPackageName(s[i+1:])
500 }
501 return GoImportPath(s), ""
502 }
503
504
// An Enum describes an enum.
type Enum struct {
	Desc protoreflect.EnumDescriptor

	GoIdent GoIdent // Go identifier for the enum (camel-cased name relative to the proto package)

	Values []*EnumValue // enum value declarations

	Location Location   // location of this enum
	Comments CommentSet // comments associated with this enum
}
515
516 func newEnum(gen *Plugin, f *File, parent *Message, desc protoreflect.EnumDescriptor) *Enum {
517 var loc Location
518 if parent != nil {
519 loc = parent.Location.appendPath(genid.DescriptorProto_EnumType_field_number, desc.Index())
520 } else {
521 loc = f.location.appendPath(genid.FileDescriptorProto_EnumType_field_number, desc.Index())
522 }
523 enum := &Enum{
524 Desc: desc,
525 GoIdent: newGoIdent(f, desc),
526 Location: loc,
527 Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)),
528 }
529 gen.enumsByName[desc.FullName()] = enum
530 for i, vds := 0, enum.Desc.Values(); i < vds.Len(); i++ {
531 enum.Values = append(enum.Values, newEnumValue(gen, f, parent, enum, vds.Get(i)))
532 }
533 return enum
534 }
535
536
// An EnumValue describes an enum value.
type EnumValue struct {
	Desc protoreflect.EnumValueDescriptor

	// GoIdent is "<parent>_<value name>". <parent> is the Go name of the
	// enclosing message for nested enums, otherwise of the enum itself. The
	// value name is the proto name, not camel-cased.
	GoIdent GoIdent

	Parent *Enum // enum in which this value is declared

	Location Location   // location of this enum value
	Comments CommentSet // comments associated with this enum value
}
547
548 func newEnumValue(gen *Plugin, f *File, message *Message, enum *Enum, desc protoreflect.EnumValueDescriptor) *EnumValue {
549
550
551
552
553 parentIdent := enum.GoIdent
554 if message != nil {
555 parentIdent = message.GoIdent
556 }
557 name := parentIdent.GoName + "_" + string(desc.Name())
558 loc := enum.Location.appendPath(genid.EnumDescriptorProto_Value_field_number, desc.Index())
559 return &EnumValue{
560 Desc: desc,
561 GoIdent: f.GoImportPath.Ident(name),
562 Parent: enum,
563 Location: loc,
564 Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)),
565 }
566 }
567
568
// A Message describes a message.
type Message struct {
	Desc protoreflect.MessageDescriptor

	GoIdent GoIdent // Go identifier for the message (camel-cased name relative to the proto package)

	Fields []*Field // message field declarations
	Oneofs []*Oneof // message oneof declarations

	Enums      []*Enum      // nested enum declarations
	Messages   []*Message   // nested message declarations
	Extensions []*Extension // nested extension declarations

	Location Location   // location of this message
	Comments CommentSet // comments associated with this message
}
584
// newMessage creates the Message for desc together with all of its nested
// enums, messages, fields, oneofs, and extensions. The message is declared
// at the top level of f when parent is nil, otherwise inside parent. It is
// registered in gen.messagesByName. newMessage also links fields to their
// oneofs and adjusts Go names to avoid collisions within the message.
func newMessage(gen *Plugin, f *File, parent *Message, desc protoreflect.MessageDescriptor) *Message {
	var loc Location
	if parent != nil {
		loc = parent.Location.appendPath(genid.DescriptorProto_NestedType_field_number, desc.Index())
	} else {
		loc = f.location.appendPath(genid.FileDescriptorProto_MessageType_field_number, desc.Index())
	}
	message := &Message{
		Desc:     desc,
		GoIdent:  newGoIdent(f, desc),
		Location: loc,
		Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)),
	}
	gen.messagesByName[desc.FullName()] = message
	for i, eds := 0, desc.Enums(); i < eds.Len(); i++ {
		message.Enums = append(message.Enums, newEnum(gen, f, message, eds.Get(i)))
	}
	for i, mds := 0, desc.Messages(); i < mds.Len(); i++ {
		message.Messages = append(message.Messages, newMessage(gen, f, message, mds.Get(i)))
	}
	for i, fds := 0, desc.Fields(); i < fds.Len(); i++ {
		message.Fields = append(message.Fields, newField(gen, f, message, fds.Get(i)))
	}
	for i, ods := 0, desc.Oneofs(); i < ods.Len(); i++ {
		message.Oneofs = append(message.Oneofs, newOneof(gen, f, message, ods.Get(i)))
	}
	for i, xds := 0, desc.Extensions(); i < xds.Len(); i++ {
		message.Extensions = append(message.Extensions, newField(gen, f, message, xds.Get(i)))
	}

	// Link each oneof member field to its Oneof and vice versa. Members are
	// appended in the order of message.Fields.
	for _, field := range message.Fields {
		if od := field.Desc.ContainingOneof(); od != nil {
			oneof := message.Oneofs[od.Index()]
			field.Oneof = oneof
			oneof.Fields = append(oneof.Fields, field)
		}
	}

	// Make field (and oneof) Go names unique within the message by appending
	// "_" until no collision remains. A field also reserves "Get"+name
	// because it is assumed to have a getter. The names pre-seeded below are
	// always reserved. NOTE(review): presumably these are methods generated
	// on every message type — confirm against the generator.
	usedNames := map[string]bool{
		"Reset":               true,
		"String":              true,
		"ProtoMessage":        true,
		"Marshal":             true,
		"Unmarshal":           true,
		"ExtensionRangeArray": true,
		"ExtensionMap":        true,
		"Descriptor":          true,
	}
	// NOTE(review): usedNames["Get"+name] is assigned unconditionally. For a
	// oneof (hasGetter == false) it can reset to false an entry an earlier
	// field set to true — confirm this is intended.
	makeNameUnique := func(name string, hasGetter bool) string {
		for usedNames[name] || (hasGetter && usedNames["Get"+name]) {
			name += "_"
		}
		usedNames[name] = true
		usedNames["Get"+name] = hasGetter
		return name
	}
	for _, field := range message.Fields {
		field.GoName = makeNameUnique(field.GoName, true)
		field.GoIdent.GoName = message.GoIdent.GoName + "_" + field.GoName
		if field.Oneof != nil && field.Oneof.Fields[0] == field {
			// Name a oneof when its first member field is reached, so oneof
			// names interleave with field names in declaration order. Oneofs
			// are treated as having no getter.
			field.Oneof.GoName = makeNameUnique(field.Oneof.GoName, false)
			field.Oneof.GoIdent.GoName = message.GoIdent.GoName + "_" + field.Oneof.GoName
		}
	}

	// A oneof member's GoIdent (<Message>_<Field>) can coincide with the
	// GoIdent of a nested message or enum. Append "_" until it matches
	// neither, rescanning both lists after every change.
	for _, field := range message.Fields {
		if field.Oneof != nil {
		Loop:
			for {
				for _, nestedMessage := range message.Messages {
					if nestedMessage.GoIdent == field.GoIdent {
						field.GoIdent.GoName += "_"
						continue Loop
					}
				}
				for _, nestedEnum := range message.Enums {
					if nestedEnum.GoIdent == field.GoIdent {
						field.GoIdent.GoName += "_"
						continue Loop
					}
				}
				break Loop
			}
		}
	}

	return message
}
696
697 func (message *Message) resolveDependencies(gen *Plugin) error {
698 for _, field := range message.Fields {
699 if err := field.resolveDependencies(gen); err != nil {
700 return err
701 }
702 }
703 for _, message := range message.Messages {
704 if err := message.resolveDependencies(gen); err != nil {
705 return err
706 }
707 }
708 for _, extension := range message.Extensions {
709 if err := extension.resolveDependencies(gen); err != nil {
710 return err
711 }
712 }
713 return nil
714 }
715
716
// A Field describes a message field. Extensions use the same type (see
// Extension).
type Field struct {
	Desc protoreflect.FieldDescriptor

	// GoName is the camel-cased proto field name. For regular fields of a
	// message, newMessage appends "_" as needed so that neither GoName nor
	// "Get"+GoName collides with another name reserved in that message.
	GoName string

	// GoIdent is <message Go name>_<GoName> for fields declared in a
	// message, or just GoName for top-level extensions. For oneof members,
	// newMessage appends further "_" to avoid a nested message or enum of
	// the same name. NOTE(review): presumably used to name oneof wrapper
	// types and extension variables — confirm against the generator.
	GoIdent GoIdent

	Parent   *Message // message in which this field is declared; nil for top-level extensions
	Oneof    *Oneof   // containing oneof; nil if not part of a oneof
	Extendee *Message // extended message for extension fields; nil otherwise

	Enum    *Enum    // type for enum fields; nil otherwise
	Message *Message // type for message or group fields; nil otherwise

	Location Location   // location of this field
	Comments CommentSet // comments associated with this field
}
741
742 func newField(gen *Plugin, f *File, message *Message, desc protoreflect.FieldDescriptor) *Field {
743 var loc Location
744 switch {
745 case desc.IsExtension() && message == nil:
746 loc = f.location.appendPath(genid.FileDescriptorProto_Extension_field_number, desc.Index())
747 case desc.IsExtension() && message != nil:
748 loc = message.Location.appendPath(genid.DescriptorProto_Extension_field_number, desc.Index())
749 default:
750 loc = message.Location.appendPath(genid.DescriptorProto_Field_field_number, desc.Index())
751 }
752 camelCased := strs.GoCamelCase(string(desc.Name()))
753 var parentPrefix string
754 if message != nil {
755 parentPrefix = message.GoIdent.GoName + "_"
756 }
757 field := &Field{
758 Desc: desc,
759 GoName: camelCased,
760 GoIdent: GoIdent{
761 GoImportPath: f.GoImportPath,
762 GoName: parentPrefix + camelCased,
763 },
764 Parent: message,
765 Location: loc,
766 Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)),
767 }
768 return field
769 }
770
771 func (field *Field) resolveDependencies(gen *Plugin) error {
772 desc := field.Desc
773 switch desc.Kind() {
774 case protoreflect.EnumKind:
775 name := field.Desc.Enum().FullName()
776 enum, ok := gen.enumsByName[name]
777 if !ok {
778 return fmt.Errorf("field %v: no descriptor for enum %v", desc.FullName(), name)
779 }
780 field.Enum = enum
781 case protoreflect.MessageKind, protoreflect.GroupKind:
782 name := desc.Message().FullName()
783 message, ok := gen.messagesByName[name]
784 if !ok {
785 return fmt.Errorf("field %v: no descriptor for type %v", desc.FullName(), name)
786 }
787 field.Message = message
788 }
789 if desc.IsExtension() {
790 name := desc.ContainingMessage().FullName()
791 message, ok := gen.messagesByName[name]
792 if !ok {
793 return fmt.Errorf("field %v: no descriptor for type %v", desc.FullName(), name)
794 }
795 field.Extendee = message
796 }
797 return nil
798 }
799
800
// A Oneof describes a message oneof.
type Oneof struct {
	Desc protoreflect.OneofDescriptor

	// GoName is the camel-cased oneof name, made unique within the message
	// by newMessage.
	GoName string

	// GoIdent is the Go identifier <message Go name>_<GoName>.
	GoIdent GoIdent

	Parent *Message // message in which this oneof is declared

	Fields []*Field // member fields, in the order of Parent.Fields

	Location Location   // location of this oneof
	Comments CommentSet // comments associated with this oneof
}
819
820 func newOneof(gen *Plugin, f *File, message *Message, desc protoreflect.OneofDescriptor) *Oneof {
821 loc := message.Location.appendPath(genid.DescriptorProto_OneofDecl_field_number, desc.Index())
822 camelCased := strs.GoCamelCase(string(desc.Name()))
823 parentPrefix := message.GoIdent.GoName + "_"
824 return &Oneof{
825 Desc: desc,
826 Parent: message,
827 GoName: camelCased,
828 GoIdent: GoIdent{
829 GoImportPath: f.GoImportPath,
830 GoName: parentPrefix + camelCased,
831 },
832 Location: loc,
833 Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)),
834 }
835 }
836
837
// Extension is an alias of Field; extension declarations are represented
// as Fields built by newField.
type Extension = Field
839
840
// A Service describes a service.
type Service struct {
	Desc protoreflect.ServiceDescriptor

	GoName string // camel-cased service name

	Methods []*Method // service method declarations

	Location Location   // location of this service
	Comments CommentSet // comments associated with this service
}
851
852 func newService(gen *Plugin, f *File, desc protoreflect.ServiceDescriptor) *Service {
853 loc := f.location.appendPath(genid.FileDescriptorProto_Service_field_number, desc.Index())
854 service := &Service{
855 Desc: desc,
856 GoName: strs.GoCamelCase(string(desc.Name())),
857 Location: loc,
858 Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)),
859 }
860 for i, mds := 0, desc.Methods(); i < mds.Len(); i++ {
861 service.Methods = append(service.Methods, newMethod(gen, f, service, mds.Get(i)))
862 }
863 return service
864 }
865
866
// A Method describes a method in a service.
type Method struct {
	Desc protoreflect.MethodDescriptor

	GoName string // camel-cased method name

	Parent *Service // service in which this method is declared

	Input  *Message // request message type, set by resolveDependencies
	Output *Message // response message type, set by resolveDependencies

	Location Location   // location of this method
	Comments CommentSet // comments associated with this method
}
880
881 func newMethod(gen *Plugin, f *File, service *Service, desc protoreflect.MethodDescriptor) *Method {
882 loc := service.Location.appendPath(genid.ServiceDescriptorProto_Method_field_number, desc.Index())
883 method := &Method{
884 Desc: desc,
885 GoName: strs.GoCamelCase(string(desc.Name())),
886 Parent: service,
887 Location: loc,
888 Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)),
889 }
890 return method
891 }
892
893 func (method *Method) resolveDependencies(gen *Plugin) error {
894 desc := method.Desc
895
896 inName := desc.Input().FullName()
897 in, ok := gen.messagesByName[inName]
898 if !ok {
899 return fmt.Errorf("method %v: no descriptor for type %v", desc.FullName(), inName)
900 }
901 method.Input = in
902
903 outName := desc.Output().FullName()
904 out, ok := gen.messagesByName[outName]
905 if !ok {
906 return fmt.Errorf("method %v: no descriptor for type %v", desc.FullName(), outName)
907 }
908 method.Output = out
909
910 return nil
911 }
912
913
// A GeneratedFile is a generated file.
type GeneratedFile struct {
	gen              *Plugin
	skip             bool                           // omitted from Response when set
	filename         string                         // output file name
	goImportPath     GoImportPath                   // identifiers in this import path are not qualified
	buf              bytes.Buffer                   // raw content written via P and Write
	packageNames     map[GoImportPath]GoPackageName // imported packages and the local names chosen for them
	usedPackageNames map[GoPackageName]bool         // names unavailable for new imports
	manualImports    map[GoImportPath]bool          // packages requested via Import
	annotations      map[string][]Annotation        // symbol -> annotations added via AnnotateSymbol
}
925
926
927
928 func (gen *Plugin) NewGeneratedFile(filename string, goImportPath GoImportPath) *GeneratedFile {
929 g := &GeneratedFile{
930 gen: gen,
931 filename: filename,
932 goImportPath: goImportPath,
933 packageNames: make(map[GoImportPath]GoPackageName),
934 usedPackageNames: make(map[GoPackageName]bool),
935 manualImports: make(map[GoImportPath]bool),
936 annotations: make(map[string][]Annotation),
937 }
938
939
940 for _, s := range types.Universe.Names() {
941 g.usedPackageNames[GoPackageName(s)] = true
942 }
943
944 gen.genFiles = append(gen.genFiles, g)
945 return g
946 }
947
948
949
950
951 func (g *GeneratedFile) P(v ...interface{}) {
952 for _, x := range v {
953 switch x := x.(type) {
954 case GoIdent:
955 fmt.Fprint(&g.buf, g.QualifiedGoIdent(x))
956 default:
957 fmt.Fprint(&g.buf, x)
958 }
959 }
960 fmt.Fprintln(&g.buf)
961 }
962
963
964
965
966
967
968 func (g *GeneratedFile) QualifiedGoIdent(ident GoIdent) string {
969 if ident.GoImportPath == g.goImportPath {
970 return ident.GoName
971 }
972 if packageName, ok := g.packageNames[ident.GoImportPath]; ok {
973 return string(packageName) + "." + ident.GoName
974 }
975 packageName := cleanPackageName(path.Base(string(ident.GoImportPath)))
976 for i, orig := 1, packageName; g.usedPackageNames[packageName]; i++ {
977 packageName = orig + GoPackageName(strconv.Itoa(i))
978 }
979 g.packageNames[ident.GoImportPath] = packageName
980 g.usedPackageNames[packageName] = true
981 return string(packageName) + "." + ident.GoName
982 }
983
984
985
986
987
988
989 func (g *GeneratedFile) Import(importPath GoImportPath) {
990 g.manualImports[importPath] = true
991 }
992
993
994 func (g *GeneratedFile) Write(p []byte) (n int, err error) {
995 return g.buf.Write(p)
996 }
997
998
999 func (g *GeneratedFile) Skip() {
1000 g.skip = true
1001 }
1002
1003
1004
1005 func (g *GeneratedFile) Unskip() {
1006 g.skip = false
1007 }
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017 func (g *GeneratedFile) Annotate(symbol string, loc Location) {
1018 g.AnnotateSymbol(symbol, Annotation{Location: loc})
1019 }
1020
1021
1022
1023
1024
// An Annotation provides semantic detail for a generated proto element.
type Annotation struct {
	// Location is the source .proto file location of the element.
	Location Location

	// Semantic is copied as-is into the emitted GeneratedCodeInfo
	// annotation. A nil value leaves that field unset.
	Semantic *descriptorpb.GeneratedCodeInfo_Annotation_Semantic
}
1032
1033
1034
1035
1036
1037
1038
1039 func (g *GeneratedFile) AnnotateSymbol(symbol string, info Annotation) {
1040 g.annotations[symbol] = append(g.annotations[symbol], info)
1041 }
1042
1043
// Content returns the contents of the generated file.
//
// Files whose name does not end in ".go" are returned as-is; the returned
// slice is the buffer's own storage. Go files are parsed and given an
// import block, then reprinted with go/printer. The import block covers
// every package referenced through QualifiedGoIdent or requested via
// Import, with ImportRewriteFunc applied. If the source does not parse, the
// error includes the whole source with line numbers.
func (g *GeneratedFile) Content() ([]byte, error) {
	if !strings.HasSuffix(g.filename, ".go") {
		return g.buf.Bytes(), nil
	}

	// Parse the generated Go code so it can be reformatted.
	original := g.buf.Bytes()
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "", original, parser.ParseComments)
	if err != nil {
		// Include the unparsable source, with line numbers, in the error as
		// a debugging aid for generator authors.
		var src bytes.Buffer
		s := bufio.NewScanner(bytes.NewReader(original))
		for line := 1; s.Scan(); line++ {
			fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes())
		}
		return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.filename, err, src.String())
	}

	// Collect all imports as {name, path} pairs, sorted by path. Packages
	// requested only through Import get the blank name "_".
	var importPaths [][2]string
	rewriteImport := func(importPath string) string {
		if f := g.gen.opts.ImportRewriteFunc; f != nil {
			return string(f(GoImportPath(importPath)))
		}
		return importPath
	}
	for importPath := range g.packageNames {
		pkgName := string(g.packageNames[GoImportPath(importPath)])
		pkgPath := rewriteImport(string(importPath))
		importPaths = append(importPaths, [2]string{pkgName, pkgPath})
	}
	for importPath := range g.manualImports {
		if _, ok := g.packageNames[importPath]; !ok {
			pkgPath := rewriteImport(string(importPath))
			importPaths = append(importPaths, [2]string{"_", pkgPath})
		}
	}
	sort.Slice(importPaths, func(i, j int) bool {
		return importPaths[i][1] < importPaths[j][1]
	})

	// Insert an import declaration into the AST.
	if len(importPaths) > 0 {
		// Place it after the package clause. If comments start on or before
		// the package line, place it after the last of them.
		pos := file.Package
		tokFile := fset.File(file.Package)
		pkgLine := tokFile.Line(file.Package)
		for _, c := range file.Comments {
			if tokFile.Line(c.Pos()) > pkgLine {
				break
			}
			pos = c.End()
		}

		// Build the import declaration. Every position is set to pos so the
		// printer emits the whole block at that point.
		impDecl := &ast.GenDecl{
			Tok:    token.IMPORT,
			TokPos: pos,
			Lparen: pos,
			Rparen: pos,
		}
		for _, importPath := range importPaths {
			impDecl.Specs = append(impDecl.Specs, &ast.ImportSpec{
				Name: &ast.Ident{
					Name:    importPath[0],
					NamePos: pos,
				},
				Path: &ast.BasicLit{
					Kind:     token.STRING,
					Value:    strconv.Quote(importPath[1]),
					ValuePos: pos,
				},
				EndPos: pos,
			})
		}
		file.Decls = append([]ast.Decl{impDecl}, file.Decls...)
	}

	var out bytes.Buffer
	if err = (&printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}).Fprint(&out, fset, file); err != nil {
		return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.filename, err)
	}
	return out.Bytes(), nil
}
1132
// generatedCodeInfo builds the GeneratedCodeInfo for content. Response
// passes the output of Content here, so offsets refer to the final file.
//
// Recorded annotations are matched against the file's top-level
// declarations:
//   - types, plus their struct fields and interface methods as "T.name",
//   - package-level constants and variables,
//   - functions,
//   - methods as "Recv.name", with pointer receivers unwrapped.
//
// Begin/End are the byte offsets of the matching identifier. It fails if
// any annotated symbol is not found. NOTE(review): receivers that are
// neither an identifier nor a pointer to one (e.g. generic receivers) are
// not matched.
func (g *GeneratedFile) generatedCodeInfo(content []byte) (*descriptorpb.GeneratedCodeInfo, error) {
	fset := token.NewFileSet()
	astFile, err := parser.ParseFile(fset, "", content, 0)
	if err != nil {
		return nil, err
	}
	info := &descriptorpb.GeneratedCodeInfo{}

	// seenAnnotations records every symbol name visited, annotated or not,
	// for the missing-symbol check below.
	seenAnnotations := make(map[string]bool)
	annotate := func(s string, ident *ast.Ident) {
		seenAnnotations[s] = true
		for _, a := range g.annotations[s] {
			info.Annotation = append(info.Annotation, &descriptorpb.GeneratedCodeInfo_Annotation{
				SourceFile: proto.String(a.Location.SourceFile),
				Path:       a.Location.Path,
				Begin:      proto.Int32(int32(fset.Position(ident.Pos()).Offset)),
				End:        proto.Int32(int32(fset.Position(ident.End()).Offset)),
				Semantic:   a.Semantic,
			})
		}
	}
	for _, decl := range astFile.Decls {
		switch decl := decl.(type) {
		case *ast.GenDecl:
			for _, spec := range decl.Specs {
				switch spec := spec.(type) {
				case *ast.TypeSpec:
					annotate(spec.Name.Name, spec.Name)
					switch st := spec.Type.(type) {
					case *ast.StructType:
						for _, field := range st.Fields.List {
							for _, name := range field.Names {
								annotate(spec.Name.Name+"."+name.Name, name)
							}
						}
					case *ast.InterfaceType:
						for _, field := range st.Methods.List {
							for _, name := range field.Names {
								annotate(spec.Name.Name+"."+name.Name, name)
							}
						}
					}
				case *ast.ValueSpec:
					for _, name := range spec.Names {
						annotate(name.Name, name)
					}
				}
			}
		case *ast.FuncDecl:
			if decl.Recv == nil {
				annotate(decl.Name.Name, decl.Name)
			} else {
				recv := decl.Recv.List[0].Type
				if s, ok := recv.(*ast.StarExpr); ok {
					recv = s.X
				}
				if id, ok := recv.(*ast.Ident); ok {
					annotate(id.Name+"."+decl.Name.Name, decl.Name)
				}
			}
		}
	}
	for a := range g.annotations {
		if !seenAnnotations[a] {
			return nil, fmt.Errorf("%v: no symbol matching annotation %q", g.filename, a)
		}
	}

	return info, nil
}
1203
1204
1205
1206 func (g *GeneratedFile) metaFile(content []byte) (string, error) {
1207 info, err := g.generatedCodeInfo(content)
1208 if err != nil {
1209 return "", err
1210 }
1211
1212 b, err := prototext.Marshal(info)
1213 if err != nil {
1214 return "", err
1215 }
1216 return string(b), nil
1217 }
1218
1219
1220
1221 type GoIdent struct {
1222 GoName string
1223 GoImportPath GoImportPath
1224 }
1225
1226 func (id GoIdent) String() string { return fmt.Sprintf("%q.%v", id.GoImportPath, id.GoName) }
1227
1228
1229 func newGoIdent(f *File, d protoreflect.Descriptor) GoIdent {
1230 name := strings.TrimPrefix(string(d.FullName()), string(f.Desc.Package())+".")
1231 return GoIdent{
1232 GoName: strs.GoCamelCase(name),
1233 GoImportPath: f.GoImportPath,
1234 }
1235 }
1236
1237
1238
1239 type GoImportPath string
1240
1241 func (p GoImportPath) String() string { return strconv.Quote(string(p)) }
1242
1243
1244 func (p GoImportPath) Ident(s string) GoIdent {
1245 return GoIdent{GoName: s, GoImportPath: p}
1246 }
1247
1248
1249 type GoPackageName string
1250
1251
1252 func cleanPackageName(name string) GoPackageName {
1253 return GoPackageName(strs.GoSanitized(name))
1254 }
1255
// pathType is the value of the "paths" generator parameter, which controls
// how GeneratedFilenamePrefix is derived.
type pathType int

const (
	// pathTypeImport ("paths=import") places output under the Go import
	// path. It is the zero value and therefore the default.
	pathTypeImport pathType = 0
	// pathTypeSourceRelative ("paths=source_relative") keeps output next to
	// the .proto file's own path.
	pathTypeSourceRelative pathType = 1
)
1262
1263
1264
1265
1266
1267 type Location struct {
1268 SourceFile string
1269 Path protoreflect.SourcePath
1270 }
1271
1272
1273 func (loc Location) appendPath(num protoreflect.FieldNumber, idx int) Location {
1274 loc.Path = append(protoreflect.SourcePath(nil), loc.Path...)
1275 loc.Path = append(loc.Path, int32(num), int32(idx))
1276 return loc
1277 }
1278
1279
1280
1281 type CommentSet struct {
1282 LeadingDetached []Comments
1283 Leading Comments
1284 Trailing Comments
1285 }
1286
1287 func makeCommentSet(loc protoreflect.SourceLocation) CommentSet {
1288 var leadingDetached []Comments
1289 for _, s := range loc.LeadingDetachedComments {
1290 leadingDetached = append(leadingDetached, Comments(s))
1291 }
1292 return CommentSet{
1293 LeadingDetached: leadingDetached,
1294 Leading: Comments(loc.LeadingComments),
1295 Trailing: Comments(loc.TrailingComments),
1296 }
1297 }
1298
1299
// Comments is a comments string as recorded in the .proto source locations.
type Comments string

// String formats the comments as Go line comments. It prefixes each line
// with "//" and ends every line, including the last, with a newline. A
// single trailing newline in c does not produce an extra empty comment
// line. An empty Comments formats as the empty string.
func (c Comments) String() string {
	if c == "" {
		return ""
	}
	var sb strings.Builder
	body := strings.TrimSuffix(string(c), "\n")
	for _, line := range strings.Split(body, "\n") {
		sb.WriteString("//")
		sb.WriteString(line)
		sb.WriteByte('\n')
	}
	return sb.String()
}
1317
1318
1319
1320
1321
1322
1323 type extensionRegistry struct {
1324 base *protoregistry.Types
1325 local *protoregistry.Types
1326 }
1327
1328 func newExtensionRegistry() *extensionRegistry {
1329 return &extensionRegistry{
1330 base: protoregistry.GlobalTypes,
1331 local: &protoregistry.Types{},
1332 }
1333 }
1334
1335
1336 func (e *extensionRegistry) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
1337 if xt, err := e.local.FindExtensionByName(field); err == nil {
1338 return xt, nil
1339 }
1340
1341 return e.base.FindExtensionByName(field)
1342 }
1343
1344
1345 func (e *extensionRegistry) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
1346 if xt, err := e.local.FindExtensionByNumber(message, field); err == nil {
1347 return xt, nil
1348 }
1349
1350 return e.base.FindExtensionByNumber(message, field)
1351 }
1352
1353 func (e *extensionRegistry) hasNovelExtensions() bool {
1354 return e.local.NumExtensions() > 0
1355 }
1356
1357 func (e *extensionRegistry) registerAllExtensionsFromFile(f protoreflect.FileDescriptor) error {
1358 if err := e.registerAllExtensions(f.Extensions()); err != nil {
1359 return err
1360 }
1361 return nil
1362 }
1363
1364 func (e *extensionRegistry) registerAllExtensionsFromMessage(ms protoreflect.MessageDescriptors) error {
1365 for i := 0; i < ms.Len(); i++ {
1366 m := ms.Get(i)
1367 if err := e.registerAllExtensions(m.Extensions()); err != nil {
1368 return err
1369 }
1370 }
1371 return nil
1372 }
1373
1374 func (e *extensionRegistry) registerAllExtensions(exts protoreflect.ExtensionDescriptors) error {
1375 for i := 0; i < exts.Len(); i++ {
1376 if err := e.registerExtension(exts.Get(i)); err != nil {
1377 return err
1378 }
1379 }
1380 return nil
1381 }
1382
1383
1384
1385 func (e *extensionRegistry) registerExtension(xd protoreflect.ExtensionDescriptor) error {
1386 if _, err := e.FindExtensionByName(xd.FullName()); err != protoregistry.NotFound {
1387
1388 return err
1389 }
1390 return e.local.RegisterExtension(dynamicpb.NewExtensionType(xd))
1391 }
1392
View as plain text