1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package main
24
25 import (
26 "bufio"
27 "bytes"
28 "flag"
29 "fmt"
30 "go/ast"
31 "go/format"
32 "go/parser"
33 "go/token"
34 "io"
35 "log"
36 "os"
37 "path"
38 "path/filepath"
39 "regexp"
40 "strconv"
41 "strings"
42 )
43
// validGOOS lists the GOOS values accepted as the suffix of a merged file name.
const validGOOS = "aix|darwin|dragonfly|freebsd|linux|netbsd|openbsd|solaris"

// validGOOSRegexp matches file names ending in "_<goos>.go" for one of the
// GOOS values in validGOOS, capturing the GOOS. It is compiled once at package
// initialization instead of on every getValidGOOS call.
var validGOOSRegexp = regexp.MustCompile(`_(` + validGOOS + `)\.go$`)

// getValidGOOS extracts the GOOS from a file name of the form "*_<goos>.go".
// It returns the GOOS and true on success. It returns "" and false if the name
// does not end in "_<goos>.go" for one of the values listed in validGOOS
// (e.g. "zerrors_linux_amd64.go" does not match).
func getValidGOOS(filename string) (string, bool) {
	matches := validGOOSRegexp.FindStringSubmatch(filename)
	// A successful match yields the full match plus the single capture group.
	if len(matches) != 2 {
		return "", false
	}
	return matches[1], true
}
54
55
// codeElem is one comparable unit of Go source: a declaration spec, a
// function declaration or a comment group. Values are used as map keys, so two
// elements are equal exactly when their kind and formatted text match.
type codeElem struct {
	tok token.Token // element kind, e.g. token.CONST, token.TYPE, token.FUNC or token.COMMENT
	src string      // gofmt-formatted source text of the element
}

// newCodeElem formats node with go/format and pairs the resulting text with
// tok. Formatting uses a fresh, empty FileSet. Any formatting error is
// returned together with a zero codeElem.
func newCodeElem(tok token.Token, node ast.Node) (codeElem, error) {
	var sb strings.Builder
	if err := format.Node(&sb, token.NewFileSet(), node); err != nil {
		return codeElem{}, err
	}
	return codeElem{tok: tok, src: sb.String()}, nil
}
70
71
72 type codeSet struct {
73 set map[codeElem]bool
74 }
75
76
77 func newCodeSet() *codeSet { return &codeSet{make(map[codeElem]bool)} }
78
79
80 func (c *codeSet) add(elem codeElem) { c.set[elem] = true }
81
82
83 func (c *codeSet) has(elem codeElem) bool { return c.set[elem] }
84
85
86 func (c *codeSet) isEmpty() bool { return len(c.set) == 0 }
87
88
89 func (c *codeSet) intersection(a *codeSet) *codeSet {
90 res := newCodeSet()
91
92 for elem := range c.set {
93 if a.has(elem) {
94 res.add(elem)
95 }
96 }
97 return res
98 }
99
100
101 func (c *codeSet) keepCommon(elem codeElem) bool {
102 switch elem.tok {
103 case token.VAR:
104
105 return false
106 case token.CONST, token.TYPE, token.FUNC, token.COMMENT:
107
108 return c.has(elem)
109 case token.IMPORT:
110
111 return true
112 }
113
114 log.Fatalf("keepCommon: invalid elem %v", elem)
115 return true
116 }
117
118
119 func (c *codeSet) keepArchSpecific(elem codeElem) bool {
120 switch elem.tok {
121 case token.CONST, token.TYPE, token.FUNC:
122
123 return !c.has(elem)
124 }
125 return true
126 }
127
128
// srcFile pairs a source file's path with its raw contents.
type srcFile struct {
	name string // path the file was read from (and is later rewritten to)
	src  []byte // unmodified file contents
}
133
134
135 type filterFn func(codeElem) bool
136
137
138
139
140 func filter(src interface{}, keep filterFn) ([]byte, error) {
141
142 fset := token.NewFileSet()
143 f, err := parser.ParseFile(fset, "", src, parser.ParseComments)
144 if err != nil {
145 return nil, err
146 }
147 cmap := ast.NewCommentMap(fset, f, f.Comments)
148
149
150 var groups specGroups = make(map[string]int)
151 var groupID int
152
153 decls := f.Decls
154 f.Decls = f.Decls[:0]
155 for _, decl := range decls {
156 switch decl := decl.(type) {
157 case *ast.GenDecl:
158
159 specs := decl.Specs
160 decl.Specs = decl.Specs[:0]
161 for i, spec := range specs {
162 elem, err := newCodeElem(decl.Tok, spec)
163 if err != nil {
164 return nil, err
165 }
166
167
168 if i > 0 && fset.Position(specs[i-1].End()).Line < fset.Position(spec.Pos()).Line-1 {
169 groupID++
170 }
171
172
173 if keep(elem) {
174 decl.Specs = append(decl.Specs, spec)
175 groups.add(elem.src, groupID)
176 }
177 }
178
179 if len(decl.Specs) > 0 {
180 f.Decls = append(f.Decls, decl)
181 }
182 case *ast.FuncDecl:
183
184 elem, err := newCodeElem(token.FUNC, decl)
185 if err != nil {
186 return nil, err
187 }
188 if keep(elem) {
189 f.Decls = append(f.Decls, decl)
190 }
191 }
192 }
193
194
195 if cmap[f] != nil {
196 commentGroups := cmap[f]
197 cmap[f] = cmap[f][:0]
198 for _, cGrp := range commentGroups {
199 if keep(codeElem{token.COMMENT, cGrp.Text()}) {
200 cmap[f] = append(cmap[f], cGrp)
201 }
202 }
203 }
204 f.Comments = cmap.Filter(f).Comments()
205
206
207 var buf bytes.Buffer
208 if err = format.Node(&buf, fset, f); err != nil {
209 return nil, err
210 }
211
212 groupedSrc, err := groups.filterEmptyLines(&buf)
213 if err != nil {
214 return nil, err
215 }
216
217 return filterImports(groupedSrc)
218 }
219
220
221 func getCommonSet(files []srcFile) (*codeSet, error) {
222 if len(files) == 0 {
223 return nil, fmt.Errorf("no files provided")
224 }
225
226 baseSet, err := getCodeSet(files[0].src)
227 if err != nil {
228 return nil, err
229 }
230
231
232
233 for _, f := range files[1:] {
234 set, err := getCodeSet(f.src)
235 if err != nil {
236 return nil, err
237 }
238
239 baseSet = baseSet.intersection(set)
240 }
241 return baseSet, nil
242 }
243
244
245
246 func getCodeSet(src interface{}) (*codeSet, error) {
247 set := newCodeSet()
248
249 fset := token.NewFileSet()
250 f, err := parser.ParseFile(fset, "", src, parser.ParseComments)
251 if err != nil {
252 return nil, err
253 }
254
255 for _, decl := range f.Decls {
256 switch decl := decl.(type) {
257 case *ast.GenDecl:
258
259 if !(decl.Tok == token.CONST || decl.Tok == token.TYPE) {
260 break
261 }
262
263 for _, spec := range decl.Specs {
264 elem, err := newCodeElem(decl.Tok, spec)
265 if err != nil {
266 return nil, err
267 }
268
269 set.add(elem)
270 }
271 case *ast.FuncDecl:
272
273 elem, err := newCodeElem(token.FUNC, decl)
274 if err != nil {
275 return nil, err
276 }
277
278 set.add(elem)
279 }
280 }
281
282
283 cmap := ast.NewCommentMap(fset, f, f.Comments)
284 for _, cGrp := range cmap[f] {
285 text := cGrp.Text()
286 if text == "" && len(cGrp.List) == 1 && strings.HasPrefix(cGrp.List[0].Text, "//go:build ") {
287
288
289
290 set.add(codeElem{token.COMMENT, cGrp.List[0].Text[len("//"):] + "\n"})
291 continue
292 }
293 set.add(codeElem{token.COMMENT, text})
294 }
295
296 return set, nil
297 }
298
299
// importName returns the identifier under which iSpec is assumed to be
// referenced in the importing file. That is the explicit local name when one
// is given; otherwise it is the last element of the unquoted import path.
// NOTE(review): the path-based fallback assumes the package name equals the
// last path element, which is not true for every package.
// It returns an error if the import path literal cannot be unquoted.
func importName(iSpec *ast.ImportSpec) (string, error) {
	if iSpec.Name != nil {
		return iSpec.Name.Name, nil
	}
	importPath, err := strconv.Unquote(iSpec.Path.Value)
	if err != nil {
		return "", err
	}
	return path.Base(importPath), nil
}
310
311
// specGroups maps the gofmt-formatted source text of a spec to the ID of the
// group of specs it belonged to in the original source.
type specGroups map[string]int

// add records that the spec with source text src belongs to group groupID.
// The text is trimmed and formatted with format.Source so that it can later be
// matched against single lines of formatted output. If src cannot be
// formatted, nothing is recorded and the error is returned.
func (s specGroups) add(src string, groupID int) error {
	srcBytes, err := format.Source(bytes.TrimSpace([]byte(src)))
	if err != nil {
		return err
	}
	s[string(srcBytes)] = groupID
	return nil
}

// filterEmptyLines copies src line by line. It discards the run of blank
// lines directly before a line when that line and the previous recognized
// non-blank line are specs of the same group. Each non-blank line is trimmed
// and formatted to look it up in s.
//
// A non-blank line that formats but is not a recorded spec (for example a
// comment) resets the "previous group", so the following spec keeps its
// leading blank lines. Lines that cannot be formatted on their own (e.g.
// "const (" or ")") leave the previous group unchanged. Blank lines at the
// very end of src are dropped. Every emitted line ends in "\n".
func (s specGroups) filterEmptyLines(src io.Reader) ([]byte, error) {
	scanner := bufio.NewScanner(src)
	var out bytes.Buffer

	// pendingBlanks holds blank lines until we know whether to keep them.
	var pendingBlanks bytes.Buffer
	// prevGroupID is the group of the last recognized spec line, or -1 if
	// the last formattable line was not a recorded spec.
	prevGroupID := -1
	for scanner.Scan() {
		raw := scanner.Bytes()
		line := bytes.TrimSpace(raw)

		if len(line) == 0 {
			fmt.Fprintf(&pendingBlanks, "%s\n", raw)
			continue
		}

		if formatted, err := format.Source(line); err == nil {
			groupID, ok := s[string(formatted)]
			if !ok {
				// Not a recorded spec: do not let the map's zero value
				// masquerade as membership in group 0.
				groupID = -1
			}
			if ok && groupID == prevGroupID {
				pendingBlanks.Reset()
			}
			prevGroupID = groupID
		}

		pendingBlanks.WriteTo(&out)
		fmt.Fprintf(&out, "%s\n", raw)
	}
	if err := scanner.Err(); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
358
359
360 func filterImports(fileSrc []byte) ([]byte, error) {
361 fset := token.NewFileSet()
362 file, err := parser.ParseFile(fset, "", fileSrc, parser.ParseComments)
363 if err != nil {
364 return nil, err
365 }
366 cmap := ast.NewCommentMap(fset, file, file.Comments)
367
368
369 keepImport := make(map[string]bool)
370 for _, u := range file.Unresolved {
371 keepImport[u.Name] = true
372 }
373
374
375 decls := file.Decls
376 file.Decls = file.Decls[:0]
377 for _, decl := range decls {
378 importDecl, ok := decl.(*ast.GenDecl)
379
380
381 if !ok || importDecl.Tok != token.IMPORT {
382 file.Decls = append(file.Decls, decl)
383 continue
384 }
385
386
387 specs := importDecl.Specs
388 importDecl.Specs = importDecl.Specs[:0]
389 for _, spec := range specs {
390 iSpec := spec.(*ast.ImportSpec)
391 name, err := importName(iSpec)
392 if err != nil {
393 return nil, err
394 }
395
396 if keepImport[name] {
397 importDecl.Specs = append(importDecl.Specs, iSpec)
398 }
399 }
400 if len(importDecl.Specs) > 0 {
401 file.Decls = append(file.Decls, importDecl)
402 }
403 }
404
405
406 imports := file.Imports
407 file.Imports = file.Imports[:0]
408 for _, spec := range imports {
409 name, err := importName(spec)
410 if err != nil {
411 return nil, err
412 }
413
414 if keepImport[name] {
415 file.Imports = append(file.Imports, spec)
416 }
417 }
418 file.Comments = cmap.Filter(file).Comments()
419
420 var buf bytes.Buffer
421 err = format.Node(&buf, fset, file)
422 if err != nil {
423 return nil, err
424 }
425
426 return buf.Bytes(), nil
427 }
428
429
430
431
432
433 func merge(mergedFile string, archFiles ...string) error {
434
435 goos, ok := getValidGOOS(mergedFile)
436 if !ok {
437 return fmt.Errorf("invalid GOOS in merged file name %s", mergedFile)
438 }
439
440
441 var inSrc []srcFile
442 for _, file := range archFiles {
443 src, err := os.ReadFile(file)
444 if err != nil {
445 return fmt.Errorf("cannot read archfile %s: %w", file, err)
446 }
447
448 inSrc = append(inSrc, srcFile{file, src})
449 }
450
451
452 commonSet, err := getCommonSet(inSrc)
453 if err != nil {
454 return err
455 }
456 if commonSet.isEmpty() {
457
458 return nil
459 }
460
461
462 mergedSrc, err := filter(inSrc[0].src, commonSet.keepCommon)
463 if err != nil {
464 return err
465 }
466
467 f, err := os.Create(mergedFile)
468 if err != nil {
469 return err
470 }
471
472 buf := bufio.NewWriter(f)
473 fmt.Fprintln(buf, "// Code generated by mkmerge; DO NOT EDIT.")
474 fmt.Fprintln(buf)
475 fmt.Fprintf(buf, "//go:build %s\n", goos)
476 fmt.Fprintln(buf)
477 buf.Write(mergedSrc)
478
479 err = buf.Flush()
480 if err != nil {
481 return err
482 }
483 err = f.Close()
484 if err != nil {
485 return err
486 }
487
488
489 for _, inFile := range inSrc {
490 src, err := filter(inFile.src, commonSet.keepArchSpecific)
491 if err != nil {
492 return err
493 }
494 err = os.WriteFile(inFile.name, src, 0644)
495 if err != nil {
496 return err
497 }
498 }
499 return nil
500 }
501
502 func main() {
503 var mergedFile string
504 flag.StringVar(&mergedFile, "out", "", "Write merged code to `FILE`")
505 flag.Parse()
506
507
508 var filenames []string
509 for _, arg := range flag.Args() {
510 matches, err := filepath.Glob(arg)
511 if err != nil {
512 fmt.Fprintf(os.Stderr, "Invalid command line argument %q: %v\n", arg, err)
513 os.Exit(1)
514 }
515 filenames = append(filenames, matches...)
516 }
517
518 if len(filenames) < 2 {
519
520 return
521 }
522
523 err := merge(mergedFile, filenames...)
524 if err != nil {
525 fmt.Fprintf(os.Stderr, "Merge failed with error: %v\n", err)
526 os.Exit(1)
527 }
528 }
529
View as plain text