1
2
3
4
5 package lostcancel
6
7 import (
8 _ "embed"
9 "fmt"
10 "go/ast"
11 "go/types"
12
13 "golang.org/x/tools/go/analysis"
14 "golang.org/x/tools/go/analysis/passes/ctrlflow"
15 "golang.org/x/tools/go/analysis/passes/inspect"
16 "golang.org/x/tools/go/analysis/passes/internal/analysisutil"
17 "golang.org/x/tools/go/ast/inspector"
18 "golang.org/x/tools/go/cfg"
19 )
20
21
22 var doc string
23
24 var Analyzer = &analysis.Analyzer{
25 Name: "lostcancel",
26 Doc: analysisutil.MustExtractDoc(doc, "lostcancel"),
27 URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/lostcancel",
28 Run: run,
29 Requires: []*analysis.Analyzer{
30 inspect.Analyzer,
31 ctrlflow.Analyzer,
32 },
33 }
34
35 const debug = false
36
37 var contextPackage = "context"
38
39
40
41
42
43
44
45
46
47
48
49 func run(pass *analysis.Pass) (interface{}, error) {
50
51 if !analysisutil.Imports(pass.Pkg, contextPackage) {
52 return nil, nil
53 }
54
55
56 inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
57 nodeTypes := []ast.Node{
58 (*ast.FuncLit)(nil),
59 (*ast.FuncDecl)(nil),
60 }
61 inspect.Preorder(nodeTypes, func(n ast.Node) {
62 runFunc(pass, n)
63 })
64 return nil, nil
65 }
66
// runFunc analyzes a single function, node, which is an *ast.FuncDecl or an
// *ast.FuncLit. It reports cancel functions from context.With* calls (as
// recognized by isContextWithCancel) that are assigned to the blank
// identifier. For each cancel variable it collects, it also reports the
// definition and a return statement reachable without a use of that variable,
// as found by lostCancelPath.
func runFunc(pass *analysis.Pass, node ast.Node) {
	// Find the scope of the function node; used below to ignore cancel
	// variables that were declared outside this function.
	var funcScope *types.Scope
	switch v := node.(type) {
	case *ast.FuncLit:
		funcScope = pass.TypesInfo.Scopes[v.Type]
	case *ast.FuncDecl:
		funcScope = pass.TypesInfo.Scopes[v.Type]
	}

	// Maps each cancel variable to its defining ValueSpec or AssignStmt.
	cancelvars := make(map[*types.Var]ast.Node)

	// Find the set of cancel vars to analyze. stack mirrors the path from
	// node down to the node currently being visited.
	stack := make([]ast.Node, 0, 32)
	ast.Inspect(node, func(n ast.Node) bool {
		switch n.(type) {
		case *ast.FuncLit:
			if len(stack) > 0 {
				return false // don't stray into nested functions; Preorder visits them separately
			}
		case nil:
			stack = stack[:len(stack)-1] // pop: Inspect signals the end of a node's children with nil
			return true
		}
		stack = append(stack, n) // push

		// Look for the node sequence [{AssignStmt,ValueSpec} CallExpr SelectorExpr]:
		//
		//   ctx, cancel    := context.WithCancel(...)
		//   ctx, cancel     = context.WithCancel(...)
		//   var ctx, cancel = context.WithCancel(...)
		//
		// Only a call that is directly the right-hand side of such a
		// statement is considered; the statement is two levels above n.
		if !isContextWithCancel(pass.TypesInfo, n) || !isCall(stack[len(stack)-2]) {
			return true
		}
		var id *ast.Ident // the second (cancel) operand, if any
		stmt := stack[len(stack)-3]
		switch stmt := stmt.(type) {
		case *ast.ValueSpec:
			if len(stmt.Names) > 1 {
				id = stmt.Names[1]
			}
		case *ast.AssignStmt:
			if len(stmt.Lhs) > 1 {
				id, _ = stmt.Lhs[1].(*ast.Ident)
			}
		}
		if id != nil {
			if id.Name == "_" {
				// The safe type assertion: isContextWithCancel only
				// returns true for *ast.SelectorExpr.
				pass.ReportRangef(id,
					"the cancel function returned by context.%s should be called, not discarded, to avoid a context leak",
					n.(*ast.SelectorExpr).Sel.Name)
			} else if v, ok := pass.TypesInfo.Uses[id].(*types.Var); ok {
				// The identifier refers to an existing variable (e.g. a
				// plain "=" assignment). If that variable is defined
				// outside the function's scope, do not analyze it.
				if funcScope.Contains(v.Pos()) {
					cancelvars[v] = stmt
				}
			} else if v, ok := pass.TypesInfo.Defs[id].(*types.Var); ok {
				// The identifier declares a new variable here.
				cancelvars[v] = stmt
			}
		}
		return true
	})

	if len(cancelvars) == 0 {
		return // nothing to check; no need to inspect the CFG
	}

	// Obtain the CFG and the function's signature.
	cfgs := pass.ResultOf[ctrlflow.Analyzer].(*ctrlflow.CFGs)
	var g *cfg.CFG
	var sig *types.Signature
	switch node := node.(type) {
	case *ast.FuncDecl:
		sig, _ = pass.TypesInfo.Defs[node.Name].Type().(*types.Signature)
		if node.Name.Name == "main" && sig.Recv() == nil && pass.Pkg.Name() == "main" {
			// Returning from main.main terminates the process,
			// so there's no need to cancel contexts.
			return
		}
		g = cfgs.FuncDecl(node)

	case *ast.FuncLit:
		sig, _ = pass.TypesInfo.Types[node.Type].Type.(*types.Signature)
		g = cfgs.FuncLit(node)
	}
	if sig == nil {
		return // missing type information
	}

	// Print the CFG when debugging.
	if debug {
		fmt.Println(g.Format(pass.Fset))
	}

	// Examine the CFG for each cancel variable in turn, reporting both the
	// definition and the offending return statement. Map iteration order
	// is unspecified, so variables are visited in no particular order.
	for v, stmt := range cancelvars {
		if ret := lostCancelPath(pass, g, v, stmt, sig); ret != nil {
			lineno := pass.Fset.Position(stmt.Pos()).Line
			pass.ReportRangef(stmt, "the %s function is not used on all paths (possible context leak)", v.Name())
			pass.ReportRangef(ret, "this return statement may be reached without using the %s var defined on line %d", v.Name(), lineno)
		}
	}
}
179
// isCall reports whether the node n is a function call expression.
// A nil node is not a call.
func isCall(n ast.Node) bool {
	switch n.(type) {
	case *ast.CallExpr:
		return true
	default:
		return false
	}
}
181
182
183
184 func isContextWithCancel(info *types.Info, n ast.Node) bool {
185 sel, ok := n.(*ast.SelectorExpr)
186 if !ok {
187 return false
188 }
189 switch sel.Sel.Name {
190 case "WithCancel", "WithTimeout", "WithDeadline":
191 default:
192 return false
193 }
194 if x, ok := sel.X.(*ast.Ident); ok {
195 if pkgname, ok := info.Uses[x].(*types.PkgName); ok {
196 return pkgname.Imported().Path() == contextPackage
197 }
198
199
200 return x.Name == "context"
201 }
202 return false
203 }
204
205
206
207
208
// lostCancelPath searches the CFG g for a path from stmt (the statement that
// defines the cancel variable v) to a return statement along which v is never
// "used". If such a path exists it returns the return statement reported by
// the corresponding cfg.Block's Return method; otherwise it returns nil.
// sig is the enclosing function's signature, if known.
//
// A "use" is any identifier that type-checks to v, or a bare "return" when v
// is one of the function's named results. The function panics if stmt does
// not appear in any block of g.
func lostCancelPath(pass *analysis.Pass, g *cfg.CFG, v *types.Var, stmt ast.Node, sig *types.Signature) *ast.ReturnStmt {
	vIsNamedResult := sig != nil && tupleContains(sig.Results(), v)

	// uses reports whether stmts contain a "use" of variable v.
	uses := func(pass *analysis.Pass, v *types.Var, stmts []ast.Node) bool {
		found := false
		for _, stmt := range stmts {
			ast.Inspect(stmt, func(n ast.Node) bool {
				switch n := n.(type) {
				case *ast.Ident:
					if pass.TypesInfo.Uses[n] == v {
						found = true
					}
				case *ast.ReturnStmt:
					// A naked return statement counts as a use
					// of the named result variables.
					if n.Results == nil && vIsNamedResult {
						found = true
					}
				}
				return !found // stop descending once a use has been found
			})
		}
		return found
	}

	// blockUses computes "uses" for each block, caching the result.
	memo := make(map[*cfg.Block]bool)
	blockUses := func(pass *analysis.Pass, v *types.Var, b *cfg.Block) bool {
		res, ok := memo[b]
		if !ok {
			res = uses(pass, v, b.Nodes)
			memo[b] = res
		}
		return res
	}

	// Find the var's defining block in the CFG,
	// plus the rest of the statements of that block.
	var defblock *cfg.Block
	var rest []ast.Node
outer:
	for _, b := range g.Blocks {
		for i, n := range b.Nodes {
			if n == stmt {
				defblock = b
				rest = b.Nodes[i+1:]
				break outer
			}
		}
	}
	if defblock == nil {
		panic("internal error: can't find defining block for cancel var")
	}

	// Is v "used" in the remainder of its defining block?
	if uses(pass, v, rest) {
		return nil
	}

	// Does the defining block return without using v?
	if ret := defblock.Return(); ret != nil {
		return ret
	}

	// Search the CFG depth-first for a path, from defblock to a
	// return block, in which v is never "used". Each block is
	// visited at most once.
	seen := make(map[*cfg.Block]bool)
	var search func(blocks []*cfg.Block) *ast.ReturnStmt
	search = func(blocks []*cfg.Block) *ast.ReturnStmt {
		for _, b := range blocks {
			if seen[b] {
				continue
			}
			seen[b] = true

			// Prune the search if the block uses v.
			if blockUses(pass, v, b) {
				continue
			}

			// Found a path to a return statement?
			if ret := b.Return(); ret != nil {
				if debug {
					fmt.Printf("found path to return in block %s\n", b)
				}
				return ret // found
			}

			// Recur into the successors.
			if ret := search(b.Succs); ret != nil {
				if debug {
					fmt.Printf(" from block %s\n", b)
				}
				return ret
			}
		}
		return nil
	}
	return search(defblock.Succs)
}
310
// tupleContains reports whether v is identical (pointer-equal) to one of
// the variables held in tuple.
func tupleContains(tuple *types.Tuple, v *types.Var) bool {
	found := false
	for idx := tuple.Len() - 1; idx >= 0 && !found; idx-- {
		found = tuple.At(idx) == v
	}
	return found
}
319
View as plain text