1
2
3
4
5
6
7 package main
8
9 import (
10 "bytes"
11 "flag"
12 "fmt"
13 "go/format"
14 "io/ioutil"
15 "os"
16 "os/exec"
17 "path"
18 "path/filepath"
19 "regexp"
20 "strconv"
21 "strings"
22 "text/template"
23 )
24
var (
	// run is bound to the -execute flag. When true, writeSource overwrites
	// the destination files; when false it prints a diff instead.
	run bool
	// outfile is bound to the -outfile flag. When non-empty, writeSource
	// writes only the matching file to stdout and touches nothing on disk.
	outfile string
	// repoRoot is the git top-level directory. main sets it only when
	// outfile is empty; writeSource joins destination paths onto it.
	repoRoot string
)
30
31 func main() {
32 flag.BoolVar(&run, "execute", false, "Write generated files to destination.")
33 flag.StringVar(&outfile, "outfile", "", "Write this specific file to stdout.")
34 flag.Parse()
35
36
37 if outfile == "" {
38 out, err := exec.Command("git", "rev-parse", "--show-toplevel").CombinedOutput()
39 check(err)
40 repoRoot = strings.TrimSpace(string(out))
41 chdirRoot()
42 }
43
44 writeSource("internal/filedesc/desc_list_gen.go", generateDescListTypes())
45 writeSource("internal/impl/codec_gen.go", generateImplCodec())
46 writeSource("internal/impl/message_reflect_gen.go", generateImplMessage())
47 writeSource("internal/impl/merge_gen.go", generateImplMerge())
48 writeSource("proto/decode_gen.go", generateProtoDecode())
49 writeSource("proto/encode_gen.go", generateProtoEncode())
50 writeSource("proto/size_gen.go", generateProtoSize())
51 }
52
53
54 func chdirRoot() {
55 out, err := exec.Command("git", "rev-parse", "--show-toplevel").CombinedOutput()
56 check(err)
57 check(os.Chdir(strings.TrimSpace(string(out))))
58 }
59
60
// Expr is a fragment of Go source text (for example "protoreflect.FieldNumber")
// that is substituted into the generated code.
type Expr string

// DescriptorType names a kind of descriptor, e.g. "Message" or "Field".
type DescriptorType string

// The descriptor kinds known to the generator.
const (
	MessageDesc   DescriptorType = "Message"
	FieldDesc     DescriptorType = "Field"
	OneofDesc     DescriptorType = "Oneof"
	ExtensionDesc DescriptorType = "Extension"
	EnumDesc      DescriptorType = "Enum"
	EnumValueDesc DescriptorType = "EnumValue"
	ServiceDesc   DescriptorType = "Service"
	MethodDesc    DescriptorType = "Method"
)
75
76 func (d DescriptorType) Expr() Expr {
77 return "protoreflect." + Expr(d) + "Descriptor"
78 }
79 func (d DescriptorType) NumberExpr() Expr {
80 switch d {
81 case FieldDesc:
82 return "protoreflect.FieldNumber"
83 case EnumValueDesc:
84 return "protoreflect.EnumNumber"
85 default:
86 return ""
87 }
88 }
89
90 func generateDescListTypes() string {
91 return mustExecute(descListTypesTemplate, []DescriptorType{
92 EnumDesc, EnumValueDesc, MessageDesc, FieldDesc, OneofDesc, ExtensionDesc, ServiceDesc, MethodDesc,
93 })
94 }
95
// descListTypesTemplate emits, for each DescriptorType it ranges over, a list
// type named after the plural of the descriptor (e.g., "Messages").
//
// Each generated type holds a List slice and a byName map. The "Field"
// descriptor additionally gets byJSON and byText maps, and any descriptor
// with a non-empty NumberExpr gets a byNum map. The maps are built inside
// once.Do by lazyInit, which every By* accessor calls first; when two entries
// share a key, the one that appears first in List is kept.
var descListTypesTemplate = template.Must(template.New("").Parse(`
{{- range .}}
{{$nameList := (printf "%ss" .)}} {{/* e.g., "Messages" */}}
{{$nameDesc := (printf "%s" .)}} {{/* e.g., "Message" */}}

type {{$nameList}} struct {
List []{{$nameDesc}}
once sync.Once
byName map[protoreflect.Name]*{{$nameDesc}} // protected by once
{{- if (eq . "Field")}}
byJSON map[string]*{{$nameDesc}} // protected by once
byText map[string]*{{$nameDesc}} // protected by once
{{- end}}
{{- if .NumberExpr}}
byNum map[{{.NumberExpr}}]*{{$nameDesc}} // protected by once
{{- end}}
}

func (p *{{$nameList}}) Len() int {
return len(p.List)
}
func (p *{{$nameList}}) Get(i int) {{.Expr}} {
return &p.List[i]
}
func (p *{{$nameList}}) ByName(s protoreflect.Name) {{.Expr}} {
if d := p.lazyInit().byName[s]; d != nil {
return d
}
return nil
}
{{- if (eq . "Field")}}
func (p *{{$nameList}}) ByJSONName(s string) {{.Expr}} {
if d := p.lazyInit().byJSON[s]; d != nil {
return d
}
return nil
}
func (p *{{$nameList}}) ByTextName(s string) {{.Expr}} {
if d := p.lazyInit().byText[s]; d != nil {
return d
}
return nil
}
{{- end}}
{{- if .NumberExpr}}
func (p *{{$nameList}}) ByNumber(n {{.NumberExpr}}) {{.Expr}} {
if d := p.lazyInit().byNum[n]; d != nil {
return d
}
return nil
}
{{- end}}
func (p *{{$nameList}}) Format(s fmt.State, r rune) {
descfmt.FormatList(s, r, p)
}
func (p *{{$nameList}}) ProtoInternal(pragma.DoNotImplement) {}
func (p *{{$nameList}}) lazyInit() *{{$nameList}} {
p.once.Do(func() {
if len(p.List) > 0 {
p.byName = make(map[protoreflect.Name]*{{$nameDesc}}, len(p.List))
{{- if (eq . "Field")}}
p.byJSON = make(map[string]*{{$nameDesc}}, len(p.List))
p.byText = make(map[string]*{{$nameDesc}}, len(p.List))
{{- end}}
{{- if .NumberExpr}}
p.byNum = make(map[{{.NumberExpr}}]*{{$nameDesc}}, len(p.List))
{{- end}}
for i := range p.List {
d := &p.List[i]
if _, ok := p.byName[d.Name()]; !ok {
p.byName[d.Name()] = d
}
{{- if (eq . "Field")}}
if _, ok := p.byJSON[d.JSONName()]; !ok {
p.byJSON[d.JSONName()] = d
}
if _, ok := p.byText[d.TextName()]; !ok {
p.byText[d.TextName()] = d
}
{{- end}}
{{- if .NumberExpr}}
if _, ok := p.byNum[d.Number()]; !ok {
p.byNum[d.Number()] = d
}
{{- end}}
}
}
})
return p
}
{{- end}}
`))
188
// mustExecute runs template t against data and returns the rendered text.
// It panics with the template error if execution fails.
func mustExecute(t *template.Template, data interface{}) string {
	rendered := new(bytes.Buffer)
	err := t.Execute(rendered, data)
	if err != nil {
		panic(err)
	}
	return rendered.String()
}
196
197 func writeSource(file, src string) {
198
199 var imports []string
200 for _, pkg := range []string{
201 "fmt",
202 "math",
203 "reflect",
204 "sync",
205 "unicode/utf8",
206 "",
207 "google.golang.org/protobuf/internal/descfmt",
208 "google.golang.org/protobuf/encoding/protowire",
209 "google.golang.org/protobuf/internal/errors",
210 "google.golang.org/protobuf/internal/strs",
211 "google.golang.org/protobuf/internal/pragma",
212 "google.golang.org/protobuf/reflect/protoreflect",
213 "google.golang.org/protobuf/runtime/protoiface",
214 } {
215 if pkg == "" {
216 imports = append(imports, "")
217 } else if regexp.MustCompile(`[^\pL_0-9]` + path.Base(pkg) + `\.`).MatchString(src) {
218 imports = append(imports, strconv.Quote(pkg))
219 }
220 }
221
222 s := strings.Join([]string{
223 "// Copyright 2018 The Go Authors. All rights reserved.",
224 "// Use of this source code is governed by a BSD-style",
225 "// license that can be found in the LICENSE file.",
226 "",
227 "// Code generated by generate-types. DO NOT EDIT.",
228 "",
229 "package " + path.Base(path.Dir(path.Join("proto", file))),
230 "",
231 "import (" + strings.Join(imports, "\n") + ")",
232 "",
233 src,
234 }, "\n")
235 b, err := format.Source([]byte(s))
236 if err != nil {
237
238 fmt.Fprintf(os.Stderr, "%v:%v\n", file, err)
239 b = []byte(s)
240 }
241
242 if outfile != "" {
243 if outfile == file {
244 os.Stdout.Write(b)
245 }
246 return
247 }
248
249 absFile := filepath.Join(repoRoot, file)
250 if run {
251 prev, _ := ioutil.ReadFile(absFile)
252 if !bytes.Equal(b, prev) {
253 fmt.Println("#", file)
254 check(ioutil.WriteFile(absFile, b, 0664))
255 }
256 } else {
257 check(ioutil.WriteFile(absFile+".tmp", b, 0664))
258 defer os.Remove(absFile + ".tmp")
259
260 cmd := exec.Command("diff", file, file+".tmp", "-N", "-u")
261 cmd.Dir = repoRoot
262 cmd.Stdout = os.Stdout
263 cmd.Run()
264 }
265 }
266
// check panics with err when err is non-nil and does nothing otherwise.
func check(err error) {
	if err == nil {
		return
	}
	panic(err)
}
272
View as plain text