Source file src/cmd/compile/internal/inline/interleaved/interleaved.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package interleaved implements the interleaved devirtualization and
     6  // inlining pass.
     7  package interleaved
     8  
     9  import (
    10  	"cmd/compile/internal/base"
    11  	"cmd/compile/internal/devirtualize"
    12  	"cmd/compile/internal/inline"
    13  	"cmd/compile/internal/inline/inlheur"
    14  	"cmd/compile/internal/ir"
    15  	"cmd/compile/internal/pgoir"
    16  	"cmd/compile/internal/typecheck"
    17  	"fmt"
    18  )
    19  
    20  // DevirtualizeAndInlinePackage interleaves devirtualization and inlining on
    21  // all functions within pkg.
    22  func DevirtualizeAndInlinePackage(pkg *ir.Package, profile *pgoir.Profile) {
    23  	if profile != nil && base.Debug.PGODevirtualize > 0 {
    24  		// TODO(mdempsky): Integrate into DevirtualizeAndInlineFunc below.
    25  		ir.VisitFuncsBottomUp(typecheck.Target.Funcs, func(list []*ir.Func, recursive bool) {
    26  			for _, fn := range list {
    27  				devirtualize.ProfileGuided(fn, profile)
    28  			}
    29  		})
    30  		ir.CurFunc = nil
    31  	}
    32  
    33  	if base.Flag.LowerL != 0 {
    34  		inlheur.SetupScoreAdjustments()
    35  	}
    36  
    37  	var inlProfile *pgoir.Profile // copy of profile for inlining
    38  	if base.Debug.PGOInline != 0 {
    39  		inlProfile = profile
    40  	}
    41  
    42  	// First compute inlinability of all functions in the package.
    43  	inline.CanInlineFuncs(pkg.Funcs, inlProfile)
    44  
    45  	inlState := make(map[*ir.Func]*inlClosureState)
    46  	calleeUseCounts := make(map[*ir.Func]int)
    47  
    48  	// Pre-process all the functions, adding parentheses around call sites and starting their "inl state".
    49  	for _, fn := range typecheck.Target.Funcs {
    50  		bigCaller := base.Flag.LowerL != 0 && inline.IsBigFunc(fn)
    51  		if bigCaller && base.Flag.LowerM > 1 {
    52  			fmt.Printf("%v: function %v considered 'big'; reducing max cost of inlinees\n", ir.Line(fn), fn)
    53  		}
    54  
    55  		s := &inlClosureState{bigCaller: bigCaller, profile: profile, fn: fn, callSites: make(map[*ir.ParenExpr]bool), useCounts: calleeUseCounts}
    56  		s.parenthesize()
    57  		inlState[fn] = s
    58  
    59  		// Do a first pass at counting call sites.
    60  		for i := range s.parens {
    61  			s.resolve(i)
    62  		}
    63  	}
    64  
    65  	ir.VisitFuncsBottomUp(typecheck.Target.Funcs, func(list []*ir.Func, recursive bool) {
    66  
    67  		anyInlineHeuristics := false
    68  
    69  		// inline heuristics, placed here because they have static state and that's what seems to work.
    70  		for _, fn := range list {
    71  			if base.Flag.LowerL != 0 {
    72  				if inlheur.Enabled() && !fn.Wrapper() {
    73  					inlheur.ScoreCalls(fn)
    74  					anyInlineHeuristics = true
    75  				}
    76  				if base.Debug.DumpInlFuncProps != "" && !fn.Wrapper() {
    77  					inlheur.DumpFuncProps(fn, base.Debug.DumpInlFuncProps)
    78  				}
    79  			}
    80  		}
    81  
    82  		if anyInlineHeuristics {
    83  			defer inlheur.ScoreCallsCleanup()
    84  		}
    85  
    86  		// Iterate to a fixed point over all the functions.
    87  		done := false
    88  		for !done {
    89  			done = true
    90  			for _, fn := range list {
    91  				s := inlState[fn]
    92  
    93  				ir.WithFunc(fn, func() {
    94  					l1 := len(s.parens)
    95  					l0 := 0
    96  
    97  					// Batch iterations so that newly discovered call sites are
    98  					// resolved in a batch before inlining attempts.
    99  					// Do this to avoid discovering new closure calls 1 at a time
   100  					// which might cause first call to be seen as a single (high-budget)
   101  					// call before the second is observed.
   102  					for {
   103  						for i := l0; i < l1; i++ { // can't use "range parens" here
   104  							paren := s.parens[i]
   105  							if new := s.edit(i); new != nil {
   106  								// Update AST and recursively mark nodes.
   107  								paren.X = new
   108  								ir.EditChildren(new, s.mark) // mark may append to parens
   109  								done = false
   110  							}
   111  						}
   112  						l0, l1 = l1, len(s.parens)
   113  						if l0 == l1 {
   114  							break
   115  						}
   116  						for i := l0; i < l1; i++ {
   117  							s.resolve(i)
   118  						}
   119  
   120  					}
   121  
   122  				}) // WithFunc
   123  
   124  			}
   125  		}
   126  	})
   127  
   128  	ir.CurFunc = nil
   129  
   130  	if base.Flag.LowerL != 0 {
   131  		if base.Debug.DumpInlFuncProps != "" {
   132  			inlheur.DumpFuncProps(nil, base.Debug.DumpInlFuncProps)
   133  		}
   134  		if inlheur.Enabled() {
   135  			inline.PostProcessCallSites(inlProfile)
   136  			inlheur.TearDown()
   137  		}
   138  	}
   139  
   140  	// remove parentheses
   141  	for _, fn := range typecheck.Target.Funcs {
   142  		inlState[fn].unparenthesize()
   143  	}
   144  
   145  }
   146  
   147  // DevirtualizeAndInlineFunc interleaves devirtualization and inlining
   148  // on a single function.
   149  func DevirtualizeAndInlineFunc(fn *ir.Func, profile *pgoir.Profile) {
   150  	ir.WithFunc(fn, func() {
   151  		if base.Flag.LowerL != 0 {
   152  			if inlheur.Enabled() && !fn.Wrapper() {
   153  				inlheur.ScoreCalls(fn)
   154  				defer inlheur.ScoreCallsCleanup()
   155  			}
   156  			if base.Debug.DumpInlFuncProps != "" && !fn.Wrapper() {
   157  				inlheur.DumpFuncProps(fn, base.Debug.DumpInlFuncProps)
   158  			}
   159  		}
   160  
   161  		bigCaller := base.Flag.LowerL != 0 && inline.IsBigFunc(fn)
   162  		if bigCaller && base.Flag.LowerM > 1 {
   163  			fmt.Printf("%v: function %v considered 'big'; reducing max cost of inlinees\n", ir.Line(fn), fn)
   164  		}
   165  
   166  		s := &inlClosureState{bigCaller: bigCaller, profile: profile, fn: fn, callSites: make(map[*ir.ParenExpr]bool), useCounts: make(map[*ir.Func]int)}
   167  		s.parenthesize()
   168  		s.fixpoint()
   169  		s.unparenthesize()
   170  	})
   171  }
   172  
// callSite holds a function together with an integer named whichParen.
//
// NOTE(review): nothing in the visible part of this file references
// callSite; confirm whether it is still needed.
type callSite struct {
	fn         *ir.Func
	whichParen int // presumably an index into inlClosureState.parens — TODO confirm
}
   177  
   178  type inlClosureState struct {
   179  	fn        *ir.Func
   180  	profile   *pgoir.Profile
   181  	callSites map[*ir.ParenExpr]bool // callSites[p] == "p appears in parens" (do not append again)
   182  	resolved  []*ir.Func             // for each call in parens, the resolved target of the call
   183  	useCounts map[*ir.Func]int       // shared among all InlClosureStates
   184  	parens    []*ir.ParenExpr
   185  	bigCaller bool
   186  }
   187  
   188  // resolve attempts to resolve a call to a potentially inlineable callee
   189  // and updates use counts on the callees.  Returns the call site count
   190  // for that callee.
   191  func (s *inlClosureState) resolve(i int) (*ir.Func, int) {
   192  	p := s.parens[i]
   193  	if i < len(s.resolved) {
   194  		if callee := s.resolved[i]; callee != nil {
   195  			return callee, s.useCounts[callee]
   196  		}
   197  	}
   198  	n := p.X
   199  	call, ok := n.(*ir.CallExpr)
   200  	if !ok { // previously inlined
   201  		return nil, -1
   202  	}
   203  	devirtualize.StaticCall(call)
   204  	if callee := inline.InlineCallTarget(s.fn, call, s.profile); callee != nil {
   205  		for len(s.resolved) <= i {
   206  			s.resolved = append(s.resolved, nil)
   207  		}
   208  		s.resolved[i] = callee
   209  		c := s.useCounts[callee] + 1
   210  		s.useCounts[callee] = c
   211  		return callee, c
   212  	}
   213  	return nil, 0
   214  }
   215  
   216  func (s *inlClosureState) edit(i int) ir.Node {
   217  	n := s.parens[i].X
   218  	call, ok := n.(*ir.CallExpr)
   219  	if !ok {
   220  		return nil
   221  	}
   222  	// This is redundant with earlier calls to
   223  	// resolve, but because things can change it
   224  	// must be re-checked.
   225  	callee, count := s.resolve(i)
   226  	if count <= 0 {
   227  		return nil
   228  	}
   229  	if inlCall := inline.TryInlineCall(s.fn, call, s.bigCaller, s.profile, count == 1 && callee.ClosureParent != nil); inlCall != nil {
   230  		return inlCall
   231  	}
   232  	return nil
   233  }
   234  
   235  // Mark inserts parentheses, and is called repeatedly.
   236  // These inserted parentheses mark the call sites where
   237  // inlining will be attempted.
   238  func (s *inlClosureState) mark(n ir.Node) ir.Node {
   239  	// Consider the expression "f(g())". We want to be able to replace
   240  	// "g()" in-place with its inlined representation. But if we first
   241  	// replace "f(...)" with its inlined representation, then "g()" will
   242  	// instead appear somewhere within this new AST.
   243  	//
   244  	// To mitigate this, each matched node n is wrapped in a ParenExpr,
   245  	// so we can reliably replace n in-place by assigning ParenExpr.X.
   246  	// It's safe to use ParenExpr here, because typecheck already
   247  	// removed them all.
   248  
   249  	p, _ := n.(*ir.ParenExpr)
   250  	if p != nil && s.callSites[p] {
   251  		return n // already visited n.X before wrapping
   252  	}
   253  
   254  	if isTestingBLoop(n) {
   255  		// No inlining nor devirtualization performed on b.Loop body
   256  		if base.Flag.LowerM > 1 {
   257  			fmt.Printf("%v: skip inlining within testing.B.loop for %v\n", ir.Line(n), n)
   258  		}
   259  		// We still want to explore inlining opportunities in other parts of ForStmt.
   260  		nFor, _ := n.(*ir.ForStmt)
   261  		nForInit := nFor.Init()
   262  		for i, x := range nForInit {
   263  			if x != nil {
   264  				nForInit[i] = s.mark(x)
   265  			}
   266  		}
   267  		if nFor.Cond != nil {
   268  			nFor.Cond = s.mark(nFor.Cond)
   269  		}
   270  		if nFor.Post != nil {
   271  			nFor.Post = s.mark(nFor.Post)
   272  		}
   273  		return n
   274  	}
   275  
   276  	if p != nil {
   277  		n = p.X // in this case p was copied in from a (marked) inlined function, this is a new unvisited node.
   278  	}
   279  
   280  	ok := match(n)
   281  
   282  	// can't wrap TailCall's child into ParenExpr
   283  	if t, ok := n.(*ir.TailCallStmt); ok {
   284  		ir.EditChildren(t.Call, s.mark)
   285  	} else {
   286  		ir.EditChildren(n, s.mark)
   287  	}
   288  
   289  	if ok {
   290  		if p == nil {
   291  			p = ir.NewParenExpr(n.Pos(), n)
   292  			p.SetType(n.Type())
   293  			p.SetTypecheck(n.Typecheck())
   294  			s.callSites[p] = true
   295  		}
   296  
   297  		s.parens = append(s.parens, p)
   298  		n = p
   299  	} else if p != nil {
   300  		n = p // didn't change anything, restore n
   301  	}
   302  	return n
   303  }
   304  
   305  // parenthesize applies s.mark to all the nodes within
   306  // s.fn to mark calls and simplify rewriting them in place.
   307  func (s *inlClosureState) parenthesize() {
   308  	ir.EditChildren(s.fn, s.mark)
   309  }
   310  
   311  func (s *inlClosureState) unparenthesize() {
   312  	if s == nil {
   313  		return
   314  	}
   315  	if len(s.parens) == 0 {
   316  		return // short circuit
   317  	}
   318  
   319  	var unparen func(ir.Node) ir.Node
   320  	unparen = func(n ir.Node) ir.Node {
   321  		if paren, ok := n.(*ir.ParenExpr); ok {
   322  			n = paren.X
   323  		}
   324  		ir.EditChildren(n, unparen)
   325  		return n
   326  	}
   327  	ir.EditChildren(s.fn, unparen)
   328  }
   329  
   330  // fixpoint repeatedly edits a function until it stabilizes, returning
   331  // whether anything changed in any of the fixpoint iterations.
   332  //
   333  // It applies s.edit(n) to each node n within the parentheses in s.parens.
   334  // If s.edit(n) returns nil, no change is made. Otherwise, the result
   335  // replaces n in fn's body, and fixpoint iterates at least once more.
   336  //
   337  // After an iteration where all edit calls return nil, fixpoint
   338  // returns.
   339  func (s *inlClosureState) fixpoint() bool {
   340  	changed := false
   341  	ir.WithFunc(s.fn, func() {
   342  		done := false
   343  		for !done {
   344  			done = true
   345  			for i := 0; i < len(s.parens); i++ { // can't use "range parens" here
   346  				paren := s.parens[i]
   347  				if new := s.edit(i); new != nil {
   348  					// Update AST and recursively mark nodes.
   349  					paren.X = new
   350  					ir.EditChildren(new, s.mark) // mark may append to parens
   351  					done = false
   352  					changed = true
   353  				}
   354  			}
   355  		}
   356  	})
   357  	return changed
   358  }
   359  
   360  func match(n ir.Node) bool {
   361  	switch n := n.(type) {
   362  	case *ir.CallExpr:
   363  		return true
   364  	case *ir.TailCallStmt:
   365  		n.Call.NoInline = true // can't inline yet
   366  	}
   367  	return false
   368  }
   369  
   370  // isTestingBLoop returns true if it matches the node as a
   371  // testing.(*B).Loop. See issue #61515.
   372  func isTestingBLoop(t ir.Node) bool {
   373  	if t.Op() != ir.OFOR {
   374  		return false
   375  	}
   376  	nFor, ok := t.(*ir.ForStmt)
   377  	if !ok || nFor.Cond == nil || nFor.Cond.Op() != ir.OCALLFUNC {
   378  		return false
   379  	}
   380  	n, ok := nFor.Cond.(*ir.CallExpr)
   381  	if !ok || n.Fun == nil || n.Fun.Op() != ir.OMETHEXPR {
   382  		return false
   383  	}
   384  	name := ir.MethodExprName(n.Fun)
   385  	if name == nil {
   386  		return false
   387  	}
   388  	if fSym := name.Sym(); fSym != nil && name.Class == ir.PFUNC && fSym.Pkg != nil &&
   389  		fSym.Name == "(*B).Loop" && fSym.Pkg.Path == "testing" {
   390  		// Attempting to match a function call to testing.(*B).Loop
   391  		return true
   392  	}
   393  	return false
   394  }
   395  

View as plain text