1
2
3
4
5 package loopclosure
6
7 import (
8 _ "embed"
9 "go/ast"
10 "go/types"
11
12 "golang.org/x/tools/go/analysis"
13 "golang.org/x/tools/go/analysis/passes/inspect"
14 "golang.org/x/tools/go/analysis/passes/internal/analysisutil"
15 "golang.org/x/tools/go/ast/inspector"
16 "golang.org/x/tools/go/types/typeutil"
17 "golang.org/x/tools/internal/versions"
18 )
19
20
21 var doc string
22
23 var Analyzer = &analysis.Analyzer{
24 Name: "loopclosure",
25 Doc: analysisutil.MustExtractDoc(doc, "loopclosure"),
26 URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/loopclosure",
27 Requires: []*analysis.Analyzer{inspect.Analyzer},
28 Run: run,
29 }
30
// run implements Analyzer. For every for/range loop in a file whose Go
// version is before go1.22 (or unknown), it collects the loop's iteration
// variables and reports uses of them inside function literals that are:
//   - started by a go, defer, or errgroup.Group.Go statement in a "last"
//     position of the loop body (see forEachLastStmt), or
//   - the body of a testing.T.Run subtest, after its t.Parallel() call
//     (see parallelSubtest).
//
// Diagnostics are reported through pass; the result is always (nil, nil).
func run(pass *analysis.Pass) (interface{}, error) {
	inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)

	nodeFilter := []ast.Node{
		(*ast.File)(nil),
		(*ast.RangeStmt)(nil),
		(*ast.ForStmt)(nil),
	}
	inspect.Nodes(nodeFilter, func(n ast.Node, push bool) bool {
		if !push {
			// All work happens on the way down; nothing to do on pop.
			return true
		}

		// vars accumulates the objects of the variables updated by this
		// loop statement (its iteration variables).
		var vars []types.Object
		addVar := func(expr ast.Expr) {
			if id, _ := expr.(*ast.Ident); id != nil {
				if obj := pass.TypesInfo.ObjectOf(id); obj != nil {
					vars = append(vars, obj)
				}
			}
		}
		var body *ast.BlockStmt
		switch n := n.(type) {
		case *ast.File:
			// Descend into a file only if its Go version is strictly
			// before go1.22; from go1.22 on, each iteration has its own
			// loop variables, so closures no longer share them.
			goversion := versions.Lang(versions.FileVersions(pass.TypesInfo, n))
			// An empty goversion (version unknown or invalid) is treated
			// as old, so the file is still checked.
			return goversion == "" || versions.Compare(goversion, "go1.22") < 0
		case *ast.RangeStmt:
			body = n.Body
			addVar(n.Key)
			addVar(n.Value)
		case *ast.ForStmt:
			body = n.Body
			// For a for statement, only the variables assigned by its post
			// statement count as loop variables.
			switch post := n.Post.(type) {
			case *ast.AssignStmt:
				// e.g. for p = head; p != nil; p = p.next
				for _, lhs := range post.Lhs {
					addVar(lhs)
				}
			case *ast.IncDecStmt:
				// e.g. for i := 0; i < n; i++
				addVar(post.X)
			}
		}
		if vars == nil {
			// No identifiable loop variables (e.g. `for range ch`, or a
			// for statement without a post statement): nothing to capture.
			return true
		}

		// Check func literals started by go, defer, or errgroup.Group.Go.
		// Only "last" statements of the body are considered (recursively,
		// see forEachLastStmt): a goroutine or defer started earlier may be
		// followed by code (such as a wait or a return) that makes the
		// capture harmless, which this analysis does not try to prove.
		forEachLastStmt(body.List, func(last ast.Stmt) {
			var stmts []ast.Stmt
			switch s := last.(type) {
			case *ast.GoStmt:
				stmts = litStmts(s.Call.Fun)
			case *ast.DeferStmt:
				stmts = litStmts(s.Call.Fun)
			case *ast.ExprStmt:
				if call, ok := s.X.(*ast.CallExpr); ok {
					stmts = litStmts(goInvoke(pass.TypesInfo, call))
				}
			}
			for _, stmt := range stmts {
				reportCaptured(pass, vars, stmt)
			}
		})

		// Also check t.Run subtests that call t.Parallel. Unlike the check
		// above, every top-level statement of the loop body is examined,
		// not just the last one.
		for _, s := range body.List {
			switch s := s.(type) {
			case *ast.ExprStmt:
				if call, ok := s.X.(*ast.CallExpr); ok {
					for _, stmt := range parallelSubtest(pass.TypesInfo, call) {
						reportCaptured(pass, vars, stmt)
					}
				}
			}
		}
		return true
	})
	return nil, nil
}
131
132
133
134
135
136 func reportCaptured(pass *analysis.Pass, vars []types.Object, checkStmt ast.Stmt) {
137 ast.Inspect(checkStmt, func(n ast.Node) bool {
138 id, ok := n.(*ast.Ident)
139 if !ok {
140 return true
141 }
142 obj := pass.TypesInfo.Uses[id]
143 if obj == nil {
144 return true
145 }
146 for _, v := range vars {
147 if v == obj {
148 pass.ReportRangef(id, "loop variable %s captured by func literal", id.Name)
149 }
150 }
151 return true
152 })
153 }
154
155
156
157
158
// forEachLastStmt calls onLast for each "last" statement of stmts.
//
// Only the final element of stmts is considered. When it is an if/else
// chain, a for or range loop, or a switch, type switch or select, the search
// continues recursively into every branch body or clause body. Any other
// statement is handed to onLast directly. An empty list is a no-op.
func forEachLastStmt(stmts []ast.Stmt, onLast func(last ast.Stmt)) {
	n := len(stmts)
	if n == 0 {
		return
	}

	switch last := stmts[n-1].(type) {
	case *ast.IfStmt:
		// Follow the "else if" chain, visiting each branch body, then the
		// trailing plain else block, if there is one.
		branch := last
		for {
			forEachLastStmt(branch.Body.List, onLast)
			next, isElseIf := branch.Else.(*ast.IfStmt)
			if !isElseIf {
				break
			}
			branch = next
		}
		if elseBlock, ok := branch.Else.(*ast.BlockStmt); ok {
			forEachLastStmt(elseBlock.List, onLast)
		}
	case *ast.ForStmt:
		forEachLastStmt(last.Body.List, onLast)
	case *ast.RangeStmt:
		forEachLastStmt(last.Body.List, onLast)
	case *ast.SwitchStmt:
		for _, clause := range last.Body.List {
			forEachLastStmt(clause.(*ast.CaseClause).Body, onLast)
		}
	case *ast.TypeSwitchStmt:
		for _, clause := range last.Body.List {
			forEachLastStmt(clause.(*ast.CaseClause).Body, onLast)
		}
	case *ast.SelectStmt:
		for _, clause := range last.Body.List {
			forEachLastStmt(clause.(*ast.CommClause).Body, onLast)
		}
	default:
		onLast(last)
	}
}
203
204
205
206
207
// litStmts returns the body statements of fun when fun is a function
// literal. For any other expression, including a nil one, it returns nil.
func litStmts(fun ast.Expr) []ast.Stmt {
	if lit, ok := fun.(*ast.FuncLit); ok && lit != nil {
		return lit.Body.List
	}
	return nil
}
215
216
217
218
219
220
221
222
223
224
225 func goInvoke(info *types.Info, call *ast.CallExpr) ast.Expr {
226 if !isMethodCall(info, call, "golang.org/x/sync/errgroup", "Group", "Go") {
227 return nil
228 }
229 return call.Args[0]
230 }
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258 func parallelSubtest(info *types.Info, call *ast.CallExpr) []ast.Stmt {
259 if !isMethodCall(info, call, "testing", "T", "Run") {
260 return nil
261 }
262
263 if len(call.Args) != 2 {
264
265 return nil
266 }
267
268 lit, _ := call.Args[1].(*ast.FuncLit)
269 if lit == nil {
270 return nil
271 }
272
273
274
275 if len(lit.Type.Params.List[0].Names) == 0 {
276 return nil
277 }
278
279 tObj := info.Defs[lit.Type.Params.List[0].Names[0]]
280 if tObj == nil {
281 return nil
282 }
283
284
285
286
287
288
289
290 var stmts []ast.Stmt
291 afterParallel := false
292 for _, stmt := range lit.Body.List {
293 stmt, labeled := unlabel(stmt)
294 if labeled {
295
296
297 stmts = nil
298 afterParallel = false
299 }
300
301 if afterParallel {
302 stmts = append(stmts, stmt)
303 continue
304 }
305
306
307 exprStmt, ok := stmt.(*ast.ExprStmt)
308 if !ok {
309 continue
310 }
311 expr := exprStmt.X
312 if isMethodCall(info, expr, "testing", "T", "Parallel") {
313 call, _ := expr.(*ast.CallExpr)
314 if call == nil {
315 continue
316 }
317 x, _ := call.Fun.(*ast.SelectorExpr)
318 if x == nil {
319 continue
320 }
321 id, _ := x.X.(*ast.Ident)
322 if id == nil {
323 continue
324 }
325 if info.Uses[id] == tObj {
326 afterParallel = true
327 }
328 }
329 }
330
331 return stmts
332 }
333
334
335
336
337
// unlabel strips any number of labels from stmt. It returns the innermost
// unlabeled statement and reports whether at least one label was removed.
func unlabel(stmt ast.Stmt) (ast.Stmt, bool) {
	ls, ok := stmt.(*ast.LabeledStmt)
	if !ok {
		return stmt, false
	}
	inner, _ := unlabel(ls.Stmt)
	return inner, true
}
349
350
351
352 func isMethodCall(info *types.Info, expr ast.Expr, pkgPath, typeName, method string) bool {
353 call, ok := expr.(*ast.CallExpr)
354 if !ok {
355 return false
356 }
357
358
359 f := typeutil.StaticCallee(info, call)
360 if f == nil || f.Name() != method {
361 return false
362 }
363 recv := f.Type().(*types.Signature).Recv()
364 if recv == nil {
365 return false
366 }
367
368
369
370 rtype := recv.Type()
371 if ptr, ok := recv.Type().(*types.Pointer); ok {
372 rtype = ptr.Elem()
373 }
374 return analysisutil.IsNamedType(rtype, pkgPath, typeName)
375 }
376
View as plain text