...

Source file src/google.golang.org/protobuf/internal/cmd/generate-types/main.go

Documentation: google.golang.org/protobuf/internal/cmd/generate-types

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:generate go run -tags protolegacy . -execute
     6  
     7  package main
     8  
     9  import (
    10  	"bytes"
    11  	"flag"
    12  	"fmt"
    13  	"go/format"
    14  	"io/ioutil"
    15  	"os"
    16  	"os/exec"
    17  	"path"
    18  	"path/filepath"
    19  	"regexp"
    20  	"strconv"
    21  	"strings"
    22  	"text/template"
    23  )
    24  
    25  var (
    26  	run      bool
    27  	outfile  string
    28  	repoRoot string
    29  )
    30  
    31  func main() {
    32  	flag.BoolVar(&run, "execute", false, "Write generated files to destination.")
    33  	flag.StringVar(&outfile, "outfile", "", "Write this specific file to stdout.")
    34  	flag.Parse()
    35  
    36  	// Determine repository root path.
    37  	if outfile == "" {
    38  		out, err := exec.Command("git", "rev-parse", "--show-toplevel").CombinedOutput()
    39  		check(err)
    40  		repoRoot = strings.TrimSpace(string(out))
    41  		chdirRoot()
    42  	}
    43  
    44  	writeSource("internal/filedesc/desc_list_gen.go", generateDescListTypes())
    45  	writeSource("internal/impl/codec_gen.go", generateImplCodec())
    46  	writeSource("internal/impl/message_reflect_gen.go", generateImplMessage())
    47  	writeSource("internal/impl/merge_gen.go", generateImplMerge())
    48  	writeSource("proto/decode_gen.go", generateProtoDecode())
    49  	writeSource("proto/encode_gen.go", generateProtoEncode())
    50  	writeSource("proto/size_gen.go", generateProtoSize())
    51  }
    52  
    53  // chdirRoot changes the working directory to the repository root.
    54  func chdirRoot() {
    55  	out, err := exec.Command("git", "rev-parse", "--show-toplevel").CombinedOutput()
    56  	check(err)
    57  	check(os.Chdir(strings.TrimSpace(string(out))))
    58  }
    59  
    60  // Expr is a single line Go expression.
    61  type Expr string
    62  
    63  type DescriptorType string
    64  
    65  const (
    66  	MessageDesc   DescriptorType = "Message"
    67  	FieldDesc     DescriptorType = "Field"
    68  	OneofDesc     DescriptorType = "Oneof"
    69  	ExtensionDesc DescriptorType = "Extension"
    70  	EnumDesc      DescriptorType = "Enum"
    71  	EnumValueDesc DescriptorType = "EnumValue"
    72  	ServiceDesc   DescriptorType = "Service"
    73  	MethodDesc    DescriptorType = "Method"
    74  )
    75  
    76  func (d DescriptorType) Expr() Expr {
    77  	return "protoreflect." + Expr(d) + "Descriptor"
    78  }
    79  func (d DescriptorType) NumberExpr() Expr {
    80  	switch d {
    81  	case FieldDesc:
    82  		return "protoreflect.FieldNumber"
    83  	case EnumValueDesc:
    84  		return "protoreflect.EnumNumber"
    85  	default:
    86  		return ""
    87  	}
    88  }
    89  
    90  func generateDescListTypes() string {
    91  	return mustExecute(descListTypesTemplate, []DescriptorType{
    92  		EnumDesc, EnumValueDesc, MessageDesc, FieldDesc, OneofDesc, ExtensionDesc, ServiceDesc, MethodDesc,
    93  	})
    94  }
    95  
// descListTypesTemplate is executed with a []DescriptorType. For each
// descriptor kind it emits a list type named after the kind plus "s" (for
// example, Messages for Message) that wraps a List slice of that descriptor.
//
// Each list type has Len, Get, ByName, Format, and ProtoInternal methods.
// The Field list additionally gets ByJSONName and ByTextName. Kinds with a
// non-empty NumberExpr (Field and EnumValue) additionally get ByNumber.
//
// The lookup maps are built lazily, once, inside lazyInit via sync.Once, and
// only when List is non-empty. When several entries share a key, the first
// entry in List is the one that is kept.
//
// The lookup methods return the found pointer only when it is non-nil and
// otherwise return a literal nil. This appears intended to avoid returning a
// non-nil interface that wraps a nil pointer (assuming the protoreflect
// descriptor return types are interfaces).
var descListTypesTemplate = template.Must(template.New("").Parse(`
	{{- range .}}
	{{$nameList := (printf "%ss" .)}} {{/* e.g., "Messages" */}}
	{{$nameDesc := (printf "%s"  .)}} {{/* e.g., "Message" */}}

	type {{$nameList}} struct {
		List   []{{$nameDesc}}
		once   sync.Once
		byName map[protoreflect.Name]*{{$nameDesc}} // protected by once
		{{- if (eq . "Field")}}
		byJSON map[string]*{{$nameDesc}}            // protected by once
		byText map[string]*{{$nameDesc}}            // protected by once
		{{- end}}
		{{- if .NumberExpr}}
		byNum  map[{{.NumberExpr}}]*{{$nameDesc}}   // protected by once
		{{- end}}
	}

	func (p *{{$nameList}}) Len() int {
		return len(p.List)
	}
	func (p *{{$nameList}}) Get(i int) {{.Expr}} {
		return &p.List[i]
	}
	func (p *{{$nameList}}) ByName(s protoreflect.Name) {{.Expr}} {
		if d := p.lazyInit().byName[s]; d != nil {
			return d
		}
		return nil
	}
	{{- if (eq . "Field")}}
	func (p *{{$nameList}}) ByJSONName(s string) {{.Expr}} {
		if d := p.lazyInit().byJSON[s]; d != nil {
			return d
		}
		return nil
	}
	func (p *{{$nameList}}) ByTextName(s string) {{.Expr}} {
		if d := p.lazyInit().byText[s]; d != nil {
			return d
		}
		return nil
	}
	{{- end}}
	{{- if .NumberExpr}}
	func (p *{{$nameList}}) ByNumber(n {{.NumberExpr}}) {{.Expr}} {
		if d := p.lazyInit().byNum[n]; d != nil {
			return d
		}
		return nil
	}
	{{- end}}
	func (p *{{$nameList}}) Format(s fmt.State, r rune) {
		descfmt.FormatList(s, r, p)
	}
	func (p *{{$nameList}}) ProtoInternal(pragma.DoNotImplement) {}
	func (p *{{$nameList}}) lazyInit() *{{$nameList}} {
		p.once.Do(func() {
			if len(p.List) > 0 {
				p.byName = make(map[protoreflect.Name]*{{$nameDesc}}, len(p.List))
				{{- if (eq . "Field")}}
				p.byJSON = make(map[string]*{{$nameDesc}}, len(p.List))
				p.byText = make(map[string]*{{$nameDesc}}, len(p.List))
				{{- end}}
				{{- if .NumberExpr}}
				p.byNum = make(map[{{.NumberExpr}}]*{{$nameDesc}}, len(p.List))
				{{- end}}
				for i := range p.List {
					d := &p.List[i]
					if _, ok := p.byName[d.Name()]; !ok {
						p.byName[d.Name()] = d
					}
					{{- if (eq . "Field")}}
					if _, ok := p.byJSON[d.JSONName()]; !ok {
						p.byJSON[d.JSONName()] = d
					}
					if _, ok := p.byText[d.TextName()]; !ok {
						p.byText[d.TextName()] = d
					}
					{{- end}}
					{{- if .NumberExpr}}
					if _, ok := p.byNum[d.Number()]; !ok {
						p.byNum[d.Number()] = d
					}
					{{- end}}
				}
			}
		})
		return p
	}
	{{- end}}
`))
   188  
   189  func mustExecute(t *template.Template, data interface{}) string {
   190  	var b bytes.Buffer
   191  	if err := t.Execute(&b, data); err != nil {
   192  		panic(err)
   193  	}
   194  	return b.String()
   195  }
   196  
   197  func writeSource(file, src string) {
   198  	// Crude but effective way to detect used imports.
   199  	var imports []string
   200  	for _, pkg := range []string{
   201  		"fmt",
   202  		"math",
   203  		"reflect",
   204  		"sync",
   205  		"unicode/utf8",
   206  		"",
   207  		"google.golang.org/protobuf/internal/descfmt",
   208  		"google.golang.org/protobuf/encoding/protowire",
   209  		"google.golang.org/protobuf/internal/errors",
   210  		"google.golang.org/protobuf/internal/strs",
   211  		"google.golang.org/protobuf/internal/pragma",
   212  		"google.golang.org/protobuf/reflect/protoreflect",
   213  		"google.golang.org/protobuf/runtime/protoiface",
   214  	} {
   215  		if pkg == "" {
   216  			imports = append(imports, "") // blank line between stdlib and proto packages
   217  		} else if regexp.MustCompile(`[^\pL_0-9]` + path.Base(pkg) + `\.`).MatchString(src) {
   218  			imports = append(imports, strconv.Quote(pkg))
   219  		}
   220  	}
   221  
   222  	s := strings.Join([]string{
   223  		"// Copyright 2018 The Go Authors. All rights reserved.",
   224  		"// Use of this source code is governed by a BSD-style",
   225  		"// license that can be found in the LICENSE file.",
   226  		"",
   227  		"// Code generated by generate-types. DO NOT EDIT.",
   228  		"",
   229  		"package " + path.Base(path.Dir(path.Join("proto", file))),
   230  		"",
   231  		"import (" + strings.Join(imports, "\n") + ")",
   232  		"",
   233  		src,
   234  	}, "\n")
   235  	b, err := format.Source([]byte(s))
   236  	if err != nil {
   237  		// Just print the error and output the unformatted file for examination.
   238  		fmt.Fprintf(os.Stderr, "%v:%v\n", file, err)
   239  		b = []byte(s)
   240  	}
   241  
   242  	if outfile != "" {
   243  		if outfile == file {
   244  			os.Stdout.Write(b)
   245  		}
   246  		return
   247  	}
   248  
   249  	absFile := filepath.Join(repoRoot, file)
   250  	if run {
   251  		prev, _ := ioutil.ReadFile(absFile)
   252  		if !bytes.Equal(b, prev) {
   253  			fmt.Println("#", file)
   254  			check(ioutil.WriteFile(absFile, b, 0664))
   255  		}
   256  	} else {
   257  		check(ioutil.WriteFile(absFile+".tmp", b, 0664))
   258  		defer os.Remove(absFile + ".tmp")
   259  
   260  		cmd := exec.Command("diff", file, file+".tmp", "-N", "-u")
   261  		cmd.Dir = repoRoot
   262  		cmd.Stdout = os.Stdout
   263  		cmd.Run()
   264  	}
   265  }
   266  
   267  func check(err error) {
   268  	if err != nil {
   269  		panic(err)
   270  	}
   271  }
   272  

View as plain text