...
1
2
3
4
5 package ssagen
6
7 import (
8 "fmt"
9 "strings"
10
11 "cmd/compile/internal/base"
12 "cmd/compile/internal/ir"
13 "cmd/compile/internal/typecheck"
14 "cmd/compile/internal/types"
15 "cmd/internal/obj"
16 "cmd/internal/src"
17 )
18
19 func EnableNoWriteBarrierRecCheck() {
20 nowritebarrierrecCheck = newNowritebarrierrecChecker()
21 }
22
23 func NoWriteBarrierRecCheck() {
24
25
26 nowritebarrierrecCheck.check()
27 nowritebarrierrecCheck = nil
28 }
29
30 var nowritebarrierrecCheck *nowritebarrierrecChecker
31
32 type nowritebarrierrecChecker struct {
33
34
35
36 extraCalls map[*ir.Func][]nowritebarrierrecCall
37
38
39 curfn *ir.Func
40 }
41
42 type nowritebarrierrecCall struct {
43 target *ir.Func
44 lineno src.XPos
45 }
46
47
48
49 func newNowritebarrierrecChecker() *nowritebarrierrecChecker {
50 c := &nowritebarrierrecChecker{
51 extraCalls: make(map[*ir.Func][]nowritebarrierrecCall),
52 }
53
54
55
56
57
58
59 for _, n := range typecheck.Target.Funcs {
60 c.curfn = n
61 if c.curfn.ABIWrapper() {
62
63
64
65 continue
66 }
67 ir.Visit(n, c.findExtraCalls)
68 }
69 c.curfn = nil
70 return c
71 }
72
73 func (c *nowritebarrierrecChecker) findExtraCalls(nn ir.Node) {
74 if nn.Op() != ir.OCALLFUNC {
75 return
76 }
77 n := nn.(*ir.CallExpr)
78 if n.Fun == nil || n.Fun.Op() != ir.ONAME {
79 return
80 }
81 fn := n.Fun.(*ir.Name)
82 if fn.Class != ir.PFUNC || fn.Defn == nil {
83 return
84 }
85 if types.RuntimeSymName(fn.Sym()) != "systemstack" {
86 return
87 }
88
89 var callee *ir.Func
90 arg := n.Args[0]
91 switch arg.Op() {
92 case ir.ONAME:
93 arg := arg.(*ir.Name)
94 callee = arg.Defn.(*ir.Func)
95 case ir.OCLOSURE:
96 arg := arg.(*ir.ClosureExpr)
97 callee = arg.Func
98 default:
99 base.Fatalf("expected ONAME or OCLOSURE node, got %+v", arg)
100 }
101 c.extraCalls[c.curfn] = append(c.extraCalls[c.curfn], nowritebarrierrecCall{callee, n.Pos()})
102 }
103
104
105
106
107
108
109
110
111
112 func (c *nowritebarrierrecChecker) recordCall(fn *ir.Func, to *obj.LSym, pos src.XPos) {
113
114 if fn.NWBRCalls == nil {
115 fn.NWBRCalls = new([]ir.SymAndPos)
116 }
117 *fn.NWBRCalls = append(*fn.NWBRCalls, ir.SymAndPos{Sym: to, Pos: pos})
118 }
119
120 func (c *nowritebarrierrecChecker) check() {
121
122
123
124
125 symToFunc := make(map[*obj.LSym]*ir.Func)
126
127
128
129
130
131
132 funcs := make(map[*ir.Func]nowritebarrierrecCall)
133
134 var q ir.NameQueue
135
136 for _, fn := range typecheck.Target.Funcs {
137 symToFunc[fn.LSym] = fn
138
139
140 if fn.Pragma&ir.Nowritebarrierrec != 0 {
141 funcs[fn] = nowritebarrierrecCall{}
142 q.PushRight(fn.Nname)
143 }
144
145 if fn.Pragma&ir.Nowritebarrier != 0 && fn.WBPos.IsKnown() {
146 base.ErrorfAt(fn.WBPos, 0, "write barrier prohibited")
147 }
148 }
149
150
151
152 enqueue := func(src, target *ir.Func, pos src.XPos) {
153 if target.Pragma&ir.Yeswritebarrierrec != 0 {
154
155 return
156 }
157 if _, ok := funcs[target]; ok {
158
159 return
160 }
161
162
163 funcs[target] = nowritebarrierrecCall{target: src, lineno: pos}
164 q.PushRight(target.Nname)
165 }
166 for !q.Empty() {
167 fn := q.PopLeft().Func
168
169
170 if fn.WBPos.IsKnown() {
171 var err strings.Builder
172 call := funcs[fn]
173 for call.target != nil {
174 fmt.Fprintf(&err, "\n\t%v: called by %v", base.FmtPos(call.lineno), call.target.Nname)
175 call = funcs[call.target]
176 }
177 base.ErrorfAt(fn.WBPos, 0, "write barrier prohibited by caller; %v%s", fn.Nname, err.String())
178 continue
179 }
180
181
182 for _, callee := range c.extraCalls[fn] {
183 enqueue(fn, callee.target, callee.lineno)
184 }
185 if fn.NWBRCalls == nil {
186 continue
187 }
188 for _, callee := range *fn.NWBRCalls {
189 target := symToFunc[callee.Sym]
190 if target != nil {
191 enqueue(fn, target, callee.Pos)
192 }
193 }
194 }
195 }
196
View as plain text