1
2
3
4
5 package devirtualize
6
7 import (
8 "cmd/compile/internal/base"
9 "cmd/compile/internal/inline"
10 "cmd/compile/internal/ir"
11 "cmd/compile/internal/logopt"
12 "cmd/compile/internal/pgo"
13 "cmd/compile/internal/typecheck"
14 "cmd/compile/internal/types"
15 "cmd/internal/obj"
16 "cmd/internal/src"
17 "encoding/json"
18 "fmt"
19 "os"
20 "strings"
21 )
22
23
24
25
// CallStat summarizes one call site. When base.Debug.PGODebug >= 3,
// ProfileGuided JSON-encodes one of these to stdout for each call it visits,
// so the stream can be consumed by external analysis tools.
type CallStat struct {
	// Pkg is the package path of the compilation unit containing the call;
	// Pos is the call's source position.
	Pkg, Pos string

	// Caller is the linker symbol name of the function making the call.
	Caller string

	// Direct is set for calls whose callee is statically known; Interface
	// is set for interface method calls.
	Direct, Interface bool

	// Weight is the summed profile weight of every edge leaving this call
	// site, regardless of callee.
	Weight int64

	// Hottest names the callee with the greatest edge weight at this call
	// site (ties broken by name) and HottestWeight is that weight. For a
	// direct call with no profile edges, Hottest is the static callee.
	Hottest       string
	HottestWeight int64

	// Devirtualized names the callee chosen by PGO devirtualization and
	// DevirtualizedWeight is its edge weight. Both stay at their zero value
	// when the call was not devirtualized.
	Devirtualized       string
	DevirtualizedWeight int64
}
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105 func ProfileGuided(fn *ir.Func, p *pgo.Profile) {
106 ir.CurFunc = fn
107
108 name := ir.LinkFuncName(fn)
109
110 var jsonW *json.Encoder
111 if base.Debug.PGODebug >= 3 {
112 jsonW = json.NewEncoder(os.Stdout)
113 }
114
115 var edit func(n ir.Node) ir.Node
116 edit = func(n ir.Node) ir.Node {
117 if n == nil {
118 return n
119 }
120
121 ir.EditChildren(n, edit)
122
123 call, ok := n.(*ir.CallExpr)
124 if !ok {
125 return n
126 }
127
128 var stat *CallStat
129 if base.Debug.PGODebug >= 3 {
130
131
132
133 stat = constructCallStat(p, fn, name, call)
134 if stat != nil {
135 defer func() {
136 jsonW.Encode(&stat)
137 }()
138 }
139 }
140
141 op := call.Op()
142 if op != ir.OCALLFUNC && op != ir.OCALLINTER {
143 return n
144 }
145
146 if base.Debug.PGODebug >= 2 {
147 fmt.Printf("%v: PGO devirtualize considering call %v\n", ir.Line(call), call)
148 }
149
150 if call.GoDefer {
151 if base.Debug.PGODebug >= 2 {
152 fmt.Printf("%v: can't PGO devirtualize go/defer call %v\n", ir.Line(call), call)
153 }
154 return n
155 }
156
157 var newNode ir.Node
158 var callee *ir.Func
159 var weight int64
160 switch op {
161 case ir.OCALLFUNC:
162 newNode, callee, weight = maybeDevirtualizeFunctionCall(p, fn, call)
163 case ir.OCALLINTER:
164 newNode, callee, weight = maybeDevirtualizeInterfaceCall(p, fn, call)
165 default:
166 panic("unreachable")
167 }
168
169 if newNode == nil {
170 return n
171 }
172
173 if stat != nil {
174 stat.Devirtualized = ir.LinkFuncName(callee)
175 stat.DevirtualizedWeight = weight
176 }
177
178 return newNode
179 }
180
181 ir.EditChildren(fn, edit)
182 }
183
184
185
186
187 func maybeDevirtualizeInterfaceCall(p *pgo.Profile, fn *ir.Func, call *ir.CallExpr) (ir.Node, *ir.Func, int64) {
188 if base.Debug.PGODevirtualize < 1 {
189 return nil, nil, 0
190 }
191
192
193 callee, weight := findHotConcreteInterfaceCallee(p, fn, call)
194 if callee == nil {
195 return nil, nil, 0
196 }
197
198 ctyp := methodRecvType(callee)
199 if ctyp == nil {
200 return nil, nil, 0
201 }
202
203 if !shouldPGODevirt(callee) {
204 return nil, nil, 0
205 }
206
207 if !base.PGOHash.MatchPosWithInfo(call.Pos(), "devirt", nil) {
208 return nil, nil, 0
209 }
210
211 return rewriteInterfaceCall(call, fn, callee, ctyp), callee, weight
212 }
213
214
215
216
217 func maybeDevirtualizeFunctionCall(p *pgo.Profile, fn *ir.Func, call *ir.CallExpr) (ir.Node, *ir.Func, int64) {
218 if base.Debug.PGODevirtualize < 2 {
219 return nil, nil, 0
220 }
221
222
223 callee := pgo.DirectCallee(call.Fun)
224 if callee != nil {
225 return nil, nil, 0
226 }
227
228
229 callee, weight := findHotConcreteFunctionCallee(p, fn, call)
230 if callee == nil {
231 return nil, nil, 0
232 }
233
234
235
236
237 if callee.OClosure != nil {
238 if base.Debug.PGODebug >= 3 {
239 fmt.Printf("callee %s is a closure, skipping\n", ir.FuncName(callee))
240 }
241 return nil, nil, 0
242 }
243
244
245
246 if callee.Sym().Pkg.Path == "runtime" && callee.Sym().Name == "memhash_varlen" {
247 if base.Debug.PGODebug >= 3 {
248 fmt.Printf("callee %s is a closure (runtime.memhash_varlen), skipping\n", ir.FuncName(callee))
249 }
250 return nil, nil, 0
251 }
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279 if callee.Type().Recv() != nil {
280 if base.Debug.PGODebug >= 3 {
281 fmt.Printf("callee %s is a method, skipping\n", ir.FuncName(callee))
282 }
283 return nil, nil, 0
284 }
285
286
287 if !shouldPGODevirt(callee) {
288 return nil, nil, 0
289 }
290
291 if !base.PGOHash.MatchPosWithInfo(call.Pos(), "devirt", nil) {
292 return nil, nil, 0
293 }
294
295 return rewriteFunctionCall(call, fn, callee), callee, weight
296 }
297
298
299
300
301
302
303 func shouldPGODevirt(fn *ir.Func) bool {
304 var reason string
305 if base.Flag.LowerM > 1 || logopt.Enabled() {
306 defer func() {
307 if reason != "" {
308 if base.Flag.LowerM > 1 {
309 fmt.Printf("%v: should not PGO devirtualize %v: %s\n", ir.Line(fn), ir.FuncName(fn), reason)
310 }
311 if logopt.Enabled() {
312 logopt.LogOpt(fn.Pos(), ": should not PGO devirtualize function", "pgo-devirtualize", ir.FuncName(fn), reason)
313 }
314 }
315 }()
316 }
317
318 reason = inline.InlineImpossible(fn)
319 if reason != "" {
320 return false
321 }
322
323
324
325
326
327
328
329
330
331
332
333 return true
334 }
335
336
337
338
339 func constructCallStat(p *pgo.Profile, fn *ir.Func, name string, call *ir.CallExpr) *CallStat {
340 switch call.Op() {
341 case ir.OCALLFUNC, ir.OCALLINTER, ir.OCALLMETH:
342 default:
343
344 return nil
345 }
346
347 stat := CallStat{
348 Pkg: base.Ctxt.Pkgpath,
349 Pos: ir.Line(call),
350 Caller: name,
351 }
352
353 offset := pgo.NodeLineOffset(call, fn)
354
355 hotter := func(e *pgo.IREdge) bool {
356 if stat.Hottest == "" {
357 return true
358 }
359 if e.Weight != stat.HottestWeight {
360 return e.Weight > stat.HottestWeight
361 }
362
363
364 return e.Dst.Name() < stat.Hottest
365 }
366
367
368
369
370
371 callerNode := p.WeightedCG.IRNodes[name]
372 for _, edge := range callerNode.OutEdges {
373 if edge.CallSiteOffset != offset {
374 continue
375 }
376 stat.Weight += edge.Weight
377 if hotter(edge) {
378 stat.HottestWeight = edge.Weight
379 stat.Hottest = edge.Dst.Name()
380 }
381 }
382
383 switch call.Op() {
384 case ir.OCALLFUNC:
385 stat.Interface = false
386
387 callee := pgo.DirectCallee(call.Fun)
388 if callee != nil {
389 stat.Direct = true
390 if stat.Hottest == "" {
391 stat.Hottest = ir.LinkFuncName(callee)
392 }
393 } else {
394 stat.Direct = false
395 }
396 case ir.OCALLINTER:
397 stat.Direct = false
398 stat.Interface = true
399 case ir.OCALLMETH:
400 base.FatalfAt(call.Pos(), "OCALLMETH missed by typecheck")
401 }
402
403 return &stat
404 }
405
406
407
408
409
410
411
412 func copyInputs(curfn *ir.Func, pos src.XPos, recvOrFn ir.Node, args []ir.Node, init *ir.Nodes) (ir.Node, []ir.Node) {
413
414
415
416
417
418
419
420 var lhs, rhs []ir.Node
421 newRecvOrFn := typecheck.TempAt(pos, curfn, recvOrFn.Type())
422 lhs = append(lhs, newRecvOrFn)
423 rhs = append(rhs, recvOrFn)
424
425 for _, arg := range args {
426 argvar := typecheck.TempAt(pos, curfn, arg.Type())
427
428 lhs = append(lhs, argvar)
429 rhs = append(rhs, arg)
430 }
431
432 asList := ir.NewAssignListStmt(pos, ir.OAS2, lhs, rhs)
433 init.Append(typecheck.Stmt(asList))
434
435 return newRecvOrFn, lhs[1:]
436 }
437
438
439 func retTemps(curfn *ir.Func, pos src.XPos, call *ir.CallExpr) []ir.Node {
440 sig := call.Fun.Type()
441 var retvars []ir.Node
442 for _, ret := range sig.Results() {
443 retvars = append(retvars, typecheck.TempAt(pos, curfn, ret.Type))
444 }
445 return retvars
446 }
447
448
449
450
451 func condCall(curfn *ir.Func, pos src.XPos, cond ir.Node, thenCall, elseCall *ir.CallExpr, init ir.Nodes) *ir.InlinedCallExpr {
452
453
454 retvars := retTemps(curfn, pos, thenCall)
455
456 var thenBlock, elseBlock ir.Nodes
457 if len(retvars) == 0 {
458 thenBlock.Append(thenCall)
459 elseBlock.Append(elseCall)
460 } else {
461
462 thenRet := append([]ir.Node(nil), retvars...)
463 thenAsList := ir.NewAssignListStmt(pos, ir.OAS2, thenRet, []ir.Node{thenCall})
464 thenBlock.Append(typecheck.Stmt(thenAsList))
465
466 elseRet := append([]ir.Node(nil), retvars...)
467 elseAsList := ir.NewAssignListStmt(pos, ir.OAS2, elseRet, []ir.Node{elseCall})
468 elseBlock.Append(typecheck.Stmt(elseAsList))
469 }
470
471 nif := ir.NewIfStmt(pos, cond, thenBlock, elseBlock)
472 nif.SetInit(init)
473 nif.Likely = true
474
475 body := []ir.Node{typecheck.Stmt(nif)}
476
477
478
479 res := ir.NewInlinedCallExpr(pos, body, retvars)
480 res.SetType(thenCall.Type())
481 res.SetTypecheck(1)
482 return res
483 }
484
485
486
487 func rewriteInterfaceCall(call *ir.CallExpr, curfn, callee *ir.Func, concretetyp *types.Type) ir.Node {
488 if base.Flag.LowerM != 0 {
489 fmt.Printf("%v: PGO devirtualizing interface call %v to %v\n", ir.Line(call), call.Fun, callee)
490 }
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520 sel := call.Fun.(*ir.SelectorExpr)
521 method := sel.Sel
522 pos := call.Pos()
523 init := ir.TakeInit(call)
524
525 recv, args := copyInputs(curfn, pos, sel.X, call.Args.Take(), &init)
526
527
528 argvars := append([]ir.Node(nil), args...)
529 call.Args = argvars
530
531 tmpnode := typecheck.TempAt(base.Pos, curfn, concretetyp)
532 tmpok := typecheck.TempAt(base.Pos, curfn, types.Types[types.TBOOL])
533
534 assert := ir.NewTypeAssertExpr(pos, recv, concretetyp)
535
536 assertAsList := ir.NewAssignListStmt(pos, ir.OAS2, []ir.Node{tmpnode, tmpok}, []ir.Node{typecheck.Expr(assert)})
537 init.Append(typecheck.Stmt(assertAsList))
538
539 concreteCallee := typecheck.XDotMethod(pos, tmpnode, method, true)
540
541 argvars = append([]ir.Node(nil), argvars...)
542 concreteCall := typecheck.Call(pos, concreteCallee, argvars, call.IsDDD).(*ir.CallExpr)
543
544 res := condCall(curfn, pos, tmpok, concreteCall, call, init)
545
546 if base.Debug.PGODebug >= 3 {
547 fmt.Printf("PGO devirtualizing interface call to %+v. After: %+v\n", concretetyp, res)
548 }
549
550 return res
551 }
552
553
554
555 func rewriteFunctionCall(call *ir.CallExpr, curfn, callee *ir.Func) ir.Node {
556 if base.Flag.LowerM != 0 {
557 fmt.Printf("%v: PGO devirtualizing function call %v to %v\n", ir.Line(call), call.Fun, callee)
558 }
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586 pos := call.Pos()
587 init := ir.TakeInit(call)
588
589 fn, args := copyInputs(curfn, pos, call.Fun, call.Args.Take(), &init)
590
591
592 argvars := append([]ir.Node(nil), args...)
593 call.Args = argvars
594
595
596
597 fnIface := typecheck.Expr(ir.NewConvExpr(pos, ir.OCONV, types.Types[types.TINTER], fn))
598 calleeIface := typecheck.Expr(ir.NewConvExpr(pos, ir.OCONV, types.Types[types.TINTER], callee.Nname))
599
600 fnPC := ir.FuncPC(pos, fnIface, obj.ABIInternal)
601 concretePC := ir.FuncPC(pos, calleeIface, obj.ABIInternal)
602
603 pcEq := typecheck.Expr(ir.NewBinaryExpr(base.Pos, ir.OEQ, fnPC, concretePC))
604
605
606
607
608 if callee.OClosure != nil {
609 base.Fatalf("Callee is a closure: %+v", callee)
610 }
611
612
613 argvars = append([]ir.Node(nil), argvars...)
614 concreteCall := typecheck.Call(pos, callee.Nname, argvars, call.IsDDD).(*ir.CallExpr)
615
616 res := condCall(curfn, pos, pcEq, concreteCall, call, init)
617
618 if base.Debug.PGODebug >= 3 {
619 fmt.Printf("PGO devirtualizing function call to %+v. After: %+v\n", ir.FuncName(callee), res)
620 }
621
622 return res
623 }
624
625
626
627 func methodRecvType(fn *ir.Func) *types.Type {
628 recv := fn.Nname.Type().Recv()
629 if recv == nil {
630 return nil
631 }
632 return recv.Type
633 }
634
635
636
637 func interfaceCallRecvTypeAndMethod(call *ir.CallExpr) (*types.Type, *types.Sym) {
638 if call.Op() != ir.OCALLINTER {
639 base.Fatalf("Call isn't OCALLINTER: %+v", call)
640 }
641
642 sel, ok := call.Fun.(*ir.SelectorExpr)
643 if !ok {
644 base.Fatalf("OCALLINTER doesn't contain SelectorExpr: %+v", call)
645 }
646
647 return sel.X.Type(), sel.Sel
648 }
649
650
651
652
653
654 func findHotConcreteCallee(p *pgo.Profile, caller *ir.Func, call *ir.CallExpr, extraFn func(callerName string, callOffset int, candidate *pgo.IREdge) bool) (*ir.Func, int64) {
655 callerName := ir.LinkFuncName(caller)
656 callerNode := p.WeightedCG.IRNodes[callerName]
657 callOffset := pgo.NodeLineOffset(call, caller)
658
659 var hottest *pgo.IREdge
660
661
662
663
664
665
666
667 hotter := func(e *pgo.IREdge) bool {
668 if hottest == nil {
669 return true
670 }
671 if e.Weight != hottest.Weight {
672 return e.Weight > hottest.Weight
673 }
674
675
676
677
678
679 if (hottest.Dst.AST == nil) != (e.Dst.AST == nil) {
680 if e.Dst.AST != nil {
681 return true
682 }
683 return false
684 }
685
686
687
688 return e.Dst.Name() < hottest.Dst.Name()
689 }
690
691 for _, e := range callerNode.OutEdges {
692 if e.CallSiteOffset != callOffset {
693 continue
694 }
695
696 if !hotter(e) {
697
698
699
700
701
702 if base.Debug.PGODebug >= 2 {
703 fmt.Printf("%v: edge %s:%d -> %s (weight %d): too cold (hottest %d)\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight, hottest.Weight)
704 }
705 continue
706 }
707
708 if e.Dst.AST == nil {
709
710
711
712
713
714
715
716
717
718 if base.Debug.PGODebug >= 2 {
719 fmt.Printf("%v: edge %s:%d -> %s (weight %d) (missing IR): hottest so far\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight)
720 }
721 hottest = e
722 continue
723 }
724
725 if extraFn != nil && !extraFn(callerName, callOffset, e) {
726 continue
727 }
728
729 if base.Debug.PGODebug >= 2 {
730 fmt.Printf("%v: edge %s:%d -> %s (weight %d): hottest so far\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight)
731 }
732 hottest = e
733 }
734
735 if hottest == nil {
736 if base.Debug.PGODebug >= 2 {
737 fmt.Printf("%v: call %s:%d: no hot callee\n", ir.Line(call), callerName, callOffset)
738 }
739 return nil, 0
740 }
741
742 if base.Debug.PGODebug >= 2 {
743 fmt.Printf("%v call %s:%d: hottest callee %s (weight %d)\n", ir.Line(call), callerName, callOffset, hottest.Dst.Name(), hottest.Weight)
744 }
745 return hottest.Dst.AST, hottest.Weight
746 }
747
748
749
750 func findHotConcreteInterfaceCallee(p *pgo.Profile, caller *ir.Func, call *ir.CallExpr) (*ir.Func, int64) {
751 inter, method := interfaceCallRecvTypeAndMethod(call)
752
753 return findHotConcreteCallee(p, caller, call, func(callerName string, callOffset int, e *pgo.IREdge) bool {
754 ctyp := methodRecvType(e.Dst.AST)
755 if ctyp == nil {
756
757
758 if base.Debug.PGODebug >= 2 {
759 fmt.Printf("%v: edge %s:%d -> %s (weight %d): callee not a method\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight)
760 }
761 return false
762 }
763
764
765
766 if !typecheck.Implements(ctyp, inter) {
767
768
769
770
771
772
773
774
775
776 if base.Debug.PGODebug >= 2 {
777 why := typecheck.ImplementsExplain(ctyp, inter)
778 fmt.Printf("%v: edge %s:%d -> %s (weight %d): %v doesn't implement %v (%s)\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight, ctyp, inter, why)
779 }
780 return false
781 }
782
783
784
785 if !strings.HasSuffix(e.Dst.Name(), "."+method.Name) {
786 if base.Debug.PGODebug >= 2 {
787 fmt.Printf("%v: edge %s:%d -> %s (weight %d): callee is a different method\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight)
788 }
789 return false
790 }
791
792 return true
793 })
794 }
795
796
797
798 func findHotConcreteFunctionCallee(p *pgo.Profile, caller *ir.Func, call *ir.CallExpr) (*ir.Func, int64) {
799 typ := call.Fun.Type().Underlying()
800
801 return findHotConcreteCallee(p, caller, call, func(callerName string, callOffset int, e *pgo.IREdge) bool {
802 ctyp := e.Dst.AST.Type().Underlying()
803
804
805
806
807
808
809
810
811 if !types.Identical(typ, ctyp) {
812 if base.Debug.PGODebug >= 2 {
813 fmt.Printf("%v: edge %s:%d -> %s (weight %d): %v doesn't match %v\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight, ctyp, typ)
814 }
815 return false
816 }
817
818 return true
819 })
820 }
821
View as plain text