Source file
src/cmd/fix/main.go
Documentation: cmd/fix
1
2
3
4
5 package main
6
7 import (
8 "bytes"
9 "flag"
10 "fmt"
11 "go/ast"
12 "go/format"
13 "go/parser"
14 "go/scanner"
15 "go/token"
16 "internal/diff"
17 "io"
18 "io/fs"
19 "os"
20 "path/filepath"
21 "sort"
22 "strconv"
23 "strings"
24 )
25
// fset holds position information for every file parsed or printed by this
// command; processFile, gofmtFile and gofmt all share this one set.
var fset = token.NewFileSet()

// exitCode is the process exit status used by main; report raises it to 2
// whenever an error is reported.
var exitCode int
30
// Command-line flags.
var (
	// allowedRewrites, when non-empty, restricts the fixes that run to the
	// comma-separated names it lists.
	allowedRewrites = flag.String("r", "",
		"restrict the rewrites to this comma-separated list")

	// forceRewrites lists fixes to run even when they are marked disabled.
	forceRewrites = flag.String("force", "",
		"force these fixes to run even if the code looks updated")

	// doDiff prints a diff of each change instead of rewriting the file.
	doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files")

	// goVersionStr is the raw -go value, e.g. "go1.21".
	goVersionStr = flag.String("go", "", "go language version for files")
)

// State that main derives from the flags above.
var (
	// allowed and force are the name sets built from -r and -force; each
	// stays nil when its flag is empty.
	allowed, force map[string]bool

	// goVersion is the -go value encoded as major*100 + minor (go1.21 is
	// 121); it stays 0 when -go is not given.
	goVersion int
)

// debug, when true, makes processFile dump the generated source and exit if
// a fix produces output that fails to reparse.
const debug = false
48
49 func usage() {
50 fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n")
51 flag.PrintDefaults()
52 fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
53 sort.Sort(byName(fixes))
54 for _, f := range fixes {
55 if f.disabled {
56 fmt.Fprintf(os.Stderr, "\n%s (disabled)\n", f.name)
57 } else {
58 fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
59 }
60 desc := strings.TrimSpace(f.desc)
61 desc = strings.ReplaceAll(desc, "\n", "\n\t")
62 fmt.Fprintf(os.Stderr, "\t%s\n", desc)
63 }
64 os.Exit(2)
65 }
66
67 func main() {
68 flag.Usage = usage
69 flag.Parse()
70
71 if *goVersionStr != "" {
72 if !strings.HasPrefix(*goVersionStr, "go") {
73 report(fmt.Errorf("invalid -go=%s", *goVersionStr))
74 os.Exit(exitCode)
75 }
76 majorStr := (*goVersionStr)[len("go"):]
77 minorStr := "0"
78 if before, after, found := strings.Cut(majorStr, "."); found {
79 majorStr, minorStr = before, after
80 }
81 major, err1 := strconv.Atoi(majorStr)
82 minor, err2 := strconv.Atoi(minorStr)
83 if err1 != nil || err2 != nil || major < 0 || major >= 100 || minor < 0 || minor >= 100 {
84 report(fmt.Errorf("invalid -go=%s", *goVersionStr))
85 os.Exit(exitCode)
86 }
87
88 goVersion = major*100 + minor
89 }
90
91 sort.Sort(byDate(fixes))
92
93 if *allowedRewrites != "" {
94 allowed = make(map[string]bool)
95 for _, f := range strings.Split(*allowedRewrites, ",") {
96 allowed[f] = true
97 }
98 }
99
100 if *forceRewrites != "" {
101 force = make(map[string]bool)
102 for _, f := range strings.Split(*forceRewrites, ",") {
103 force[f] = true
104 }
105 }
106
107 if flag.NArg() == 0 {
108 if err := processFile("standard input", true); err != nil {
109 report(err)
110 }
111 os.Exit(exitCode)
112 }
113
114 for i := 0; i < flag.NArg(); i++ {
115 path := flag.Arg(i)
116 switch dir, err := os.Stat(path); {
117 case err != nil:
118 report(err)
119 case dir.IsDir():
120 walkDir(path)
121 default:
122 if err := processFile(path, false); err != nil {
123 report(err)
124 }
125 }
126 }
127
128 os.Exit(exitCode)
129 }
130
131 const parserMode = parser.ParseComments
132
133 func gofmtFile(f *ast.File) ([]byte, error) {
134 var buf bytes.Buffer
135 if err := format.Node(&buf, fset, f); err != nil {
136 return nil, err
137 }
138 return buf.Bytes(), nil
139 }
140
141 func processFile(filename string, useStdin bool) error {
142 var f *os.File
143 var err error
144 var fixlog strings.Builder
145
146 if useStdin {
147 f = os.Stdin
148 } else {
149 f, err = os.Open(filename)
150 if err != nil {
151 return err
152 }
153 defer f.Close()
154 }
155
156 src, err := io.ReadAll(f)
157 if err != nil {
158 return err
159 }
160
161 file, err := parser.ParseFile(fset, filename, src, parserMode)
162 if err != nil {
163 return err
164 }
165
166
167
168 newSrc, err := gofmtFile(file)
169 if err != nil {
170 return err
171 }
172 if !bytes.Equal(newSrc, src) {
173 newFile, err := parser.ParseFile(fset, filename, newSrc, parserMode)
174 if err != nil {
175 return err
176 }
177 file = newFile
178 fmt.Fprintf(&fixlog, " fmt")
179 }
180
181
182 newFile := file
183 fixed := false
184 for _, fix := range fixes {
185 if allowed != nil && !allowed[fix.name] {
186 continue
187 }
188 if fix.disabled && !force[fix.name] {
189 continue
190 }
191 if fix.f(newFile) {
192 fixed = true
193 fmt.Fprintf(&fixlog, " %s", fix.name)
194
195
196
197
198 newSrc, err := gofmtFile(newFile)
199 if err != nil {
200 return err
201 }
202 newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
203 if err != nil {
204 if debug {
205 fmt.Printf("%s", newSrc)
206 report(err)
207 os.Exit(exitCode)
208 }
209 return err
210 }
211 }
212 }
213 if !fixed {
214 return nil
215 }
216 fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])
217
218
219
220
221
222
223
224 newSrc, err = gofmtFile(newFile)
225 if err != nil {
226 return err
227 }
228
229 if *doDiff {
230 os.Stdout.Write(diff.Diff(filename, src, "fixed/"+filename, newSrc))
231 return nil
232 }
233
234 if useStdin {
235 os.Stdout.Write(newSrc)
236 return nil
237 }
238
239 return os.WriteFile(f.Name(), newSrc, 0)
240 }
241
242 func gofmt(n any) string {
243 var gofmtBuf strings.Builder
244 if err := format.Node(&gofmtBuf, fset, n); err != nil {
245 return "<" + err.Error() + ">"
246 }
247 return gofmtBuf.String()
248 }
249
250 func report(err error) {
251 scanner.PrintError(os.Stderr, err)
252 exitCode = 2
253 }
254
255 func walkDir(path string) {
256 filepath.WalkDir(path, visitFile)
257 }
258
259 func visitFile(path string, f fs.DirEntry, err error) error {
260 if err == nil && isGoFile(f) {
261 err = processFile(path, false)
262 }
263 if err != nil {
264 report(err)
265 }
266 return nil
267 }
268
// isGoFile reports whether f names a Go source file to process: it is not a
// directory, its name ends in ".go", and it is not hidden (does not begin
// with ".").
func isGoFile(f fs.DirEntry) bool {
	if f.IsDir() {
		return false
	}
	name := f.Name()
	return strings.HasSuffix(name, ".go") && !strings.HasPrefix(name, ".")
}
274
View as plain text