Source file src/cmd/compile/internal/liveness/mergelocals.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package liveness
     6  
     7  import (
     8  	"cmd/compile/internal/base"
     9  	"cmd/compile/internal/bitvec"
    10  	"cmd/compile/internal/ir"
    11  	"cmd/compile/internal/ssa"
    12  	"cmd/internal/src"
    13  	"fmt"
    14  	"os"
    15  	"path/filepath"
    16  	"sort"
    17  	"strings"
    18  )
    19  
// MergeLocalsState encapsulates information about which AUTO
// (stack-allocated) variables within a function can be safely
// merged/overlapped (e.g. share a stack slot with some other auto).
// An instance of MergeLocalsState is produced by MergeLocals() below
// and then consumed in ssagen.AllocFrame. The map 'partition'
// contains entries of the form <N,SL> where N is an *ir.Name and SL
// is a slice holding the indices (within 'vars') of all the variables
// (N included) that share the same slot, specifically the slot of the
// first element in the partition, which we'll call the "leader".
// Every member of a partition maps to the same slice. For example,
// if a function contains five variables where v1/v2/v3 are safe to
// overlap and v4/v5 are safe to overlap, the MergeLocalsState
// content might look like
//
//	vars: [v1, v2, v3, v4, v5]
//	partition: v1 -> [1, 0, 2], v2 -> [1, 0, 2], v3 -> [1, 0, 2]
//	           v4 -> [3, 4], v5 -> [3, 4]
//
// In this example v2 (index 1) is the leader of the first partition
// and v4 (index 3) is the leader of the second.
//
// A nil MergeLocalsState indicates that no local variables meet the
// necessary criteria for overlap.
type MergeLocalsState struct {
	// contains auto vars that participate in overlapping
	vars []*ir.Name
	// maps auto variable to overlap partition
	partition map[*ir.Name][]int
}
    45  
// candRegion describes a region of the candidate variable list: the
// closed interval [st,en] of indices into that list. Both st and en
// are positions of variables that belong to the region.
type candRegion struct {
	st, en int
}
    51  
// cstate holds state information we'll need during the analysis
// phase of stack slot merging but can be discarded when the analysis
// is done. Fields:
//
//   - fn, f: the function being analyzed, in IR and SSA form.
//   - lv: liveness state created and solved by MergeLocals.
//   - cands: the final, sorted list of merge candidates.
//   - nameToSlot: maps each candidate to its index within cands.
//   - regions: sub-ranges of cands holding mutually compatible vars.
//   - indirectUE: maps an SSA value ID to the candidates whose
//     materialized address that value uses (see
//     populateIndirectUseTable).
//   - ivs: lifetime intervals, indexed by candidate slot.
//   - hashDeselected: candidates excluded by hash bisection; left
//     nil when nothing was deselected (see setupHashBisection).
//   - trace: debug trace level.
type cstate struct {
	fn             *ir.Func
	f              *ssa.Func
	lv             *liveness
	cands          []*ir.Name
	nameToSlot     map[*ir.Name]int32
	regions        []candRegion
	indirectUE     map[ssa.ID][]*ir.Name
	ivs            []Intervals
	hashDeselected map[*ir.Name]bool
	trace          int // debug trace level
}
    67  
    68  // MergeLocals analyzes the specified ssa function f to determine which
    69  // of its auto variables can safely share the same stack slot, returning
    70  // a state object that describes how the overlap should be done.
    71  func MergeLocals(fn *ir.Func, f *ssa.Func) *MergeLocalsState {
    72  
    73  	// Create a container object for useful state info and then
    74  	// call collectMergeCandidates to see if there are vars suitable
    75  	// for stack slot merging.
    76  	cs := &cstate{
    77  		fn:    fn,
    78  		f:     f,
    79  		trace: base.Debug.MergeLocalsTrace,
    80  	}
    81  	cs.collectMergeCandidates()
    82  	if len(cs.regions) == 0 {
    83  		return nil
    84  	}
    85  
    86  	// Kick off liveness analysis.
    87  	//
    88  	// If we have a local variable such as "r2" below that's written
    89  	// but then not read, something like:
    90  	//
    91  	//      vardef r1
    92  	//      r1.x = ...
    93  	//      vardef r2
    94  	//      r2.x = 0
    95  	//      r2.y = ...
    96  	//      <call foo>
    97  	//      // no subsequent use of r2
    98  	//      ... = r1.x
    99  	//
   100  	// then for the purpose of calculating stack maps at the call, we
   101  	// can ignore "r2" completely during liveness analysis for stack
   102  	// maps, however for stack slock merging we most definitely want
   103  	// to treat the writes as "uses".
   104  	cs.lv = newliveness(fn, f, cs.cands, cs.nameToSlot, 0)
   105  	cs.lv.conservativeWrites = true
   106  	cs.lv.prologue()
   107  	cs.lv.solve()
   108  
   109  	// Compute intervals for each candidate based on the liveness and
   110  	// on block effects.
   111  	cs.computeIntervals()
   112  
   113  	// Perform merging within each region of the candidates list.
   114  	rv := cs.performMerging()
   115  	if err := rv.check(); err != nil {
   116  		base.FatalfAt(fn.Pos(), "invalid mergelocals state: %v", err)
   117  	}
   118  	return rv
   119  }
   120  
   121  // Subsumed returns whether variable n is subsumed, e.g. appears
   122  // in an overlap position but is not the leader in that partition.
   123  func (mls *MergeLocalsState) Subsumed(n *ir.Name) bool {
   124  	if sl, ok := mls.partition[n]; ok && mls.vars[sl[0]] != n {
   125  		return true
   126  	}
   127  	return false
   128  }
   129  
   130  // IsLeader returns whether a variable n is the leader (first element)
   131  // in a sharing partition.
   132  func (mls *MergeLocalsState) IsLeader(n *ir.Name) bool {
   133  	if sl, ok := mls.partition[n]; ok && mls.vars[sl[0]] == n {
   134  		return true
   135  	}
   136  	return false
   137  }
   138  
   139  // Leader returns the leader variable for subsumed var n.
   140  func (mls *MergeLocalsState) Leader(n *ir.Name) *ir.Name {
   141  	if sl, ok := mls.partition[n]; ok {
   142  		if mls.vars[sl[0]] == n {
   143  			panic("variable is not subsumed")
   144  		}
   145  		return mls.vars[sl[0]]
   146  	}
   147  	panic("not a merge candidate")
   148  }
   149  
   150  // Followers writes a list of the followers for leader n into the slice tmp.
   151  func (mls *MergeLocalsState) Followers(n *ir.Name, tmp []*ir.Name) []*ir.Name {
   152  	tmp = tmp[:0]
   153  	sl, ok := mls.partition[n]
   154  	if !ok {
   155  		panic("no entry for leader")
   156  	}
   157  	if mls.vars[sl[0]] != n {
   158  		panic("followers invoked on subsumed var")
   159  	}
   160  	for _, k := range sl[1:] {
   161  		tmp = append(tmp, mls.vars[k])
   162  	}
   163  	sort.SliceStable(tmp, func(i, j int) bool {
   164  		return tmp[i].Sym().Name < tmp[j].Sym().Name
   165  	})
   166  	return tmp
   167  }
   168  
   169  // EstSavings returns the estimated reduction in stack size (number of bytes) for
   170  // the given merge locals state via a pair of ints, the first for non-pointer types and the second for pointer types.
   171  func (mls *MergeLocalsState) EstSavings() (int, int) {
   172  	totnp := 0
   173  	totp := 0
   174  	for n := range mls.partition {
   175  		if mls.Subsumed(n) {
   176  			sz := int(n.Type().Size())
   177  			if n.Type().HasPointers() {
   178  				totp += sz
   179  			} else {
   180  				totnp += sz
   181  			}
   182  		}
   183  	}
   184  	return totnp, totp
   185  }
   186  
   187  // check tests for various inconsistencies and problems in mls,
   188  // returning an error if any problems are found.
   189  func (mls *MergeLocalsState) check() error {
   190  	if mls == nil {
   191  		return nil
   192  	}
   193  	used := make(map[int]bool)
   194  	seenv := make(map[*ir.Name]int)
   195  	for ii, v := range mls.vars {
   196  		if prev, ok := seenv[v]; ok {
   197  			return fmt.Errorf("duplicate var %q in vslots: %d and %d\n",
   198  				v.Sym().Name, ii, prev)
   199  		}
   200  		seenv[v] = ii
   201  	}
   202  	for k, sl := range mls.partition {
   203  		// length of slice value needs to be more than 1
   204  		if len(sl) < 2 {
   205  			return fmt.Errorf("k=%q v=%+v slice len %d invalid",
   206  				k.Sym().Name, sl, len(sl))
   207  		}
   208  		// values in the slice need to be var indices
   209  		for i, v := range sl {
   210  			if v < 0 || v > len(mls.vars)-1 {
   211  				return fmt.Errorf("k=%q v=+%v slpos %d vslot %d out of range of m.v", k.Sym().Name, sl, i, v)
   212  			}
   213  		}
   214  	}
   215  	for k, sl := range mls.partition {
   216  		foundk := false
   217  		for i, v := range sl {
   218  			vv := mls.vars[v]
   219  			if i == 0 {
   220  				if !mls.IsLeader(vv) {
   221  					return fmt.Errorf("k=%s v=+%v slpos 0 vslot %d IsLeader(%q) is false should be true", k.Sym().Name, sl, v, vv.Sym().Name)
   222  				}
   223  			} else {
   224  				if !mls.Subsumed(vv) {
   225  					return fmt.Errorf("k=%s v=+%v slpos %d vslot %d Subsumed(%q) is false should be true", k.Sym().Name, sl, i, v, vv.Sym().Name)
   226  				}
   227  				if mls.Leader(vv) != mls.vars[sl[0]] {
   228  					return fmt.Errorf("k=%s v=+%v slpos %d vslot %d Leader(%q) got %v want %v", k.Sym().Name, sl, i, v, vv.Sym().Name, mls.Leader(vv), mls.vars[sl[0]])
   229  				}
   230  			}
   231  			if vv == k {
   232  				foundk = true
   233  				if used[v] {
   234  					return fmt.Errorf("k=%s v=+%v val slice used violation at slpos %d vslot %d", k.Sym().Name, sl, i, v)
   235  				}
   236  				used[v] = true
   237  			}
   238  		}
   239  		if !foundk {
   240  			return fmt.Errorf("k=%s v=+%v slice value missing k", k.Sym().Name, sl)
   241  		}
   242  		vl := mls.vars[sl[0]]
   243  		for _, v := range sl[1:] {
   244  			vv := mls.vars[v]
   245  			if vv.Type().Size() > vl.Type().Size() {
   246  				return fmt.Errorf("k=%s v=+%v follower %s size %d larger than leader %s size %d", k.Sym().Name, sl, vv.Sym().Name, vv.Type().Size(), vl.Sym().Name, vl.Type().Size())
   247  			}
   248  			if vv.Type().HasPointers() && !vl.Type().HasPointers() {
   249  				return fmt.Errorf("k=%s v=+%v follower %s hasptr=true but leader %s hasptr=false", k.Sym().Name, sl, vv.Sym().Name, vl.Sym().Name)
   250  			}
   251  			if vv.Type().Alignment() > vl.Type().Alignment() {
   252  				return fmt.Errorf("k=%s v=+%v follower %s align %d greater than leader %s align %d", k.Sym().Name, sl, vv.Sym().Name, vv.Type().Alignment(), vl.Sym().Name, vl.Type().Alignment())
   253  			}
   254  		}
   255  	}
   256  	for i := range used {
   257  		if !used[i] {
   258  			return fmt.Errorf("pos %d var %q unused", i, mls.vars[i])
   259  		}
   260  	}
   261  	return nil
   262  }
   263  
   264  func (mls *MergeLocalsState) String() string {
   265  	var leaders []*ir.Name
   266  	for n, sl := range mls.partition {
   267  		if n == mls.vars[sl[0]] {
   268  			leaders = append(leaders, n)
   269  		}
   270  	}
   271  	sort.Slice(leaders, func(i, j int) bool {
   272  		return leaders[i].Sym().Name < leaders[j].Sym().Name
   273  	})
   274  	var sb strings.Builder
   275  	for _, n := range leaders {
   276  		sb.WriteString(n.Sym().Name + ":")
   277  		sl := mls.partition[n]
   278  		for _, k := range sl[1:] {
   279  			n := mls.vars[k]
   280  			sb.WriteString(" " + n.Sym().Name)
   281  		}
   282  		sb.WriteString("\n")
   283  	}
   284  	return sb.String()
   285  }
   286  
   287  // collectMergeCandidates visits all of the AUTO vars declared in
   288  // function fn and identifies a list of candidate variables for
   289  // merging / overlapping. On return the "cands" field of cs will be
   290  // filled in with our set of potentially overlappable candidate
   291  // variables, the "regions" field will hold regions/sequence of
   292  // compatible vars within the candidates list, "nameToSlot" field will
   293  // be populated, and the "indirectUE" field will be filled in with
   294  // information about indirect upwards-exposed uses in the func.
   295  func (cs *cstate) collectMergeCandidates() {
   296  	var cands []*ir.Name
   297  
   298  	// Collect up the available set of appropriate AUTOs in the
   299  	// function as a first step, and bail if we have fewer than
   300  	// two candidates.
   301  	for _, n := range cs.fn.Dcl {
   302  		if !n.Used() {
   303  			continue
   304  		}
   305  		if !ssa.IsMergeCandidate(n) {
   306  			continue
   307  		}
   308  		cands = append(cands, n)
   309  	}
   310  	if len(cands) < 2 {
   311  		return
   312  	}
   313  
   314  	// Sort by pointerness, size, and then name.
   315  	sort.SliceStable(cands, func(i, j int) bool {
   316  		return nameLess(cands[i], cands[j])
   317  	})
   318  
   319  	if cs.trace > 1 {
   320  		fmt.Fprintf(os.Stderr, "=-= raw cand list for func %v:\n", cs.fn)
   321  		for i := range cands {
   322  			dumpCand(cands[i], i)
   323  		}
   324  	}
   325  
   326  	// Now generate an initial pruned candidate list and regions list.
   327  	// This may be empty if we don't have enough compatible candidates.
   328  	initial, _ := cs.genRegions(cands)
   329  	if len(initial) < 2 {
   330  		return
   331  	}
   332  
   333  	// Set up for hash bisection if enabled.
   334  	cs.setupHashBisection(initial)
   335  
   336  	// Create and populate an indirect use table that we'll use
   337  	// during interval construction. As part of this process we may
   338  	// wind up tossing out additional candidates, so check to make
   339  	// sure we still have something to work with.
   340  	cs.cands, cs.regions = cs.populateIndirectUseTable(initial)
   341  	if len(cs.cands) < 2 {
   342  		return
   343  	}
   344  
   345  	// At this point we have a final pruned set of candidates and a
   346  	// corresponding set of regions for the candidates. Build a
   347  	// name-to-slot map for the candidates.
   348  	cs.nameToSlot = make(map[*ir.Name]int32)
   349  	for i, n := range cs.cands {
   350  		cs.nameToSlot[n] = int32(i)
   351  	}
   352  
   353  	if cs.trace > 1 {
   354  		fmt.Fprintf(os.Stderr, "=-= pruned candidate list for fn %v:\n", cs.fn)
   355  		for i := range cs.cands {
   356  			dumpCand(cs.cands[i], i)
   357  		}
   358  	}
   359  }
   360  
   361  // genRegions generates a set of regions within cands corresponding
   362  // to potentially overlappable/mergeable variables.
   363  func (cs *cstate) genRegions(cands []*ir.Name) ([]*ir.Name, []candRegion) {
   364  	var pruned []*ir.Name
   365  	var regions []candRegion
   366  	st := 0
   367  	for {
   368  		en := nextRegion(cands, st)
   369  		if en == -1 {
   370  			break
   371  		}
   372  		if st == en {
   373  			// region has just one element, we can skip it
   374  			st++
   375  			continue
   376  		}
   377  		pst := len(pruned)
   378  		pen := pst + (en - st)
   379  		if cs.trace > 1 {
   380  			fmt.Fprintf(os.Stderr, "=-= addregion st=%d en=%d: add part %d -> %d\n", st, en, pst, pen)
   381  		}
   382  
   383  		// non-empty region, add to pruned
   384  		pruned = append(pruned, cands[st:en+1]...)
   385  		regions = append(regions, candRegion{st: pst, en: pen})
   386  		st = en + 1
   387  	}
   388  	if len(pruned) < 2 {
   389  		return nil, nil
   390  	}
   391  	return pruned, regions
   392  }
   393  
   394  func (cs *cstate) dumpFunc() {
   395  	fmt.Fprintf(os.Stderr, "=-= mergelocalsdumpfunc %v:\n", cs.fn)
   396  	ii := 0
   397  	for k, b := range cs.f.Blocks {
   398  		fmt.Fprintf(os.Stderr, "b%d:\n", k)
   399  		for _, v := range b.Values {
   400  			pos := base.Ctxt.PosTable.Pos(v.Pos)
   401  			fmt.Fprintf(os.Stderr, "=-= %d L%d|C%d %s\n", ii, pos.RelLine(), pos.RelCol(), v.LongString())
   402  			ii++
   403  		}
   404  	}
   405  }
   406  
   407  func (cs *cstate) dumpFuncIfSelected() {
   408  	if base.Debug.MergeLocalsDumpFunc == "" {
   409  		return
   410  	}
   411  	if !strings.HasSuffix(fmt.Sprintf("%v", cs.fn),
   412  		base.Debug.MergeLocalsDumpFunc) {
   413  		return
   414  	}
   415  	cs.dumpFunc()
   416  }
   417  
   418  // setupHashBisection checks to see if any of the candidate
   419  // variables have been de-selected by our hash debug. Here
   420  // we also implement the -d=mergelocalshtrace flag, which turns
   421  // on debug tracing only if we have at least two candidates
   422  // selected by the hash debug for this function.
   423  func (cs *cstate) setupHashBisection(cands []*ir.Name) {
   424  	if base.Debug.MergeLocalsHash == "" {
   425  		return
   426  	}
   427  	deselected := make(map[*ir.Name]bool)
   428  	selCount := 0
   429  	for _, cand := range cands {
   430  		if !base.MergeLocalsHash.MatchPosWithInfo(cand.Pos(), "mergelocals", nil) {
   431  			deselected[cand] = true
   432  		} else {
   433  			deselected[cand] = false
   434  			selCount++
   435  		}
   436  	}
   437  	if selCount < len(cands) {
   438  		cs.hashDeselected = deselected
   439  	}
   440  	if base.Debug.MergeLocalsHTrace != 0 && selCount >= 2 {
   441  		cs.trace = base.Debug.MergeLocalsHTrace
   442  	}
   443  }
   444  
   445  // populateIndirectUseTable creates and populates the "indirectUE" table
   446  // within cs by doing some additional analysis of how the vars in
   447  // cands are accessed in the function.
   448  //
   449  // It is possible to have situations where a given ir.Name is
   450  // non-address-taken at the source level, but whose address is
   451  // materialized in order to accommodate the needs of
   452  // architecture-dependent operations or one sort or another (examples
   453  // include things like LoweredZero/DuffZero, etc). The issue here is
   454  // that the SymAddr op will show up as touching a variable of
   455  // interest, but the subsequent memory op will not. This is generally
   456  // not an issue for computing whether something is live across a call,
   457  // but it is problematic for collecting the more fine-grained live
   458  // interval info that drives stack slot merging.
   459  //
   460  // To handle this problem, make a forward pass over each basic block
   461  // looking for instructions of the form vK := SymAddr(N) where N is a
   462  // raw candidate. Create an entry in a map at that point from vK to
   463  // its use count. Continue the walk, looking for uses of vK: when we
   464  // see one, record it in a side table as an upwards exposed use of N.
   465  // Each time we see a use, decrement the use count in the map, and if
   466  // we hit zero, remove the map entry. If we hit the end of the basic
   467  // block and we still have map entries, then evict the name in
   468  // question from the candidate set.
   469  func (cs *cstate) populateIndirectUseTable(cands []*ir.Name) ([]*ir.Name, []candRegion) {
   470  
   471  	// main indirect UE table, this is what we're producing in this func
   472  	indirectUE := make(map[ssa.ID][]*ir.Name)
   473  
   474  	// this map holds the current set of candidates; the set may
   475  	// shrink if we have to evict any candidates.
   476  	rawcands := make(map[*ir.Name]struct{})
   477  
   478  	// maps ssa value V to the ir.Name it is taking the addr of,
   479  	// plus a count of the uses we've seen of V during a block walk.
   480  	pendingUses := make(map[ssa.ID]nameCount)
   481  
   482  	// A temporary indirect UE tab just for the current block
   483  	// being processed; used to help with evictions.
   484  	blockIndirectUE := make(map[ssa.ID][]*ir.Name)
   485  
   486  	// temporary map used to record evictions in a given block.
   487  	evicted := make(map[*ir.Name]bool)
   488  	for _, n := range cands {
   489  		rawcands[n] = struct{}{}
   490  	}
   491  	for k := 0; k < len(cs.f.Blocks); k++ {
   492  		genmapclear(pendingUses)
   493  		genmapclear(blockIndirectUE)
   494  		b := cs.f.Blocks[k]
   495  		for _, v := range b.Values {
   496  			if n, e := affectedVar(v); n != nil {
   497  				if _, ok := rawcands[n]; ok {
   498  					if e&ssa.SymAddr != 0 && v.Uses != 0 {
   499  						// we're taking the address of candidate var n
   500  						if _, ok := pendingUses[v.ID]; ok {
   501  							// should never happen
   502  							base.FatalfAt(v.Pos, "internal error: apparent multiple defs for SSA value %d", v.ID)
   503  						}
   504  						// Stash an entry in pendingUses recording
   505  						// that we took the address of "n" via this
   506  						// val.
   507  						pendingUses[v.ID] = nameCount{n: n, count: v.Uses}
   508  						if cs.trace > 2 {
   509  							fmt.Fprintf(os.Stderr, "=-= SymAddr(%s) on %s\n",
   510  								n.Sym().Name, v.LongString())
   511  						}
   512  					}
   513  				}
   514  			}
   515  			for _, arg := range v.Args {
   516  				if nc, ok := pendingUses[arg.ID]; ok {
   517  					// We found a use of some value that took the
   518  					// address of nc.n. Record this inst as a
   519  					// potential indirect use.
   520  					if cs.trace > 2 {
   521  						fmt.Fprintf(os.Stderr, "=-= add indirectUE(%s) count=%d on %s\n", nc.n.Sym().Name, nc.count, v.LongString())
   522  					}
   523  					blockIndirectUE[v.ID] = append(blockIndirectUE[v.ID], nc.n)
   524  					nc.count--
   525  					if nc.count == 0 {
   526  						// That was the last use of the value. Clean
   527  						// up the entry in pendingUses.
   528  						if cs.trace > 2 {
   529  							fmt.Fprintf(os.Stderr, "=-= last use of v%d\n",
   530  								arg.ID)
   531  						}
   532  						delete(pendingUses, arg.ID)
   533  					} else {
   534  						// Not the last use; record the decremented
   535  						// use count and move on.
   536  						pendingUses[arg.ID] = nc
   537  					}
   538  				}
   539  			}
   540  		}
   541  
   542  		// We've reached the end of this basic block: if we have any
   543  		// leftover entries in pendingUses, then evict the
   544  		// corresponding names from the candidate set. The idea here
   545  		// is that if we materialized the address of some local and
   546  		// that value is flowing out of the block off somewhere else,
   547  		// we're going to treat that local as truly address-taken and
   548  		// not have it be a merge candidate.
   549  		genmapclear(evicted)
   550  		if len(pendingUses) != 0 {
   551  			for id, nc := range pendingUses {
   552  				if cs.trace > 2 {
   553  					fmt.Fprintf(os.Stderr, "=-= evicting %q due to pendingUse %d count %d\n", nc.n.Sym().Name, id, nc.count)
   554  				}
   555  				delete(rawcands, nc.n)
   556  				evicted[nc.n] = true
   557  			}
   558  		}
   559  		// Copy entries from blockIndirectUE into final indirectUE. Skip
   560  		// anything that we evicted in the loop above.
   561  		for id, sl := range blockIndirectUE {
   562  			for _, n := range sl {
   563  				if evicted[n] {
   564  					continue
   565  				}
   566  				indirectUE[id] = append(indirectUE[id], n)
   567  				if cs.trace > 2 {
   568  					fmt.Fprintf(os.Stderr, "=-= add final indUE v%d name %s\n", id, n.Sym().Name)
   569  				}
   570  			}
   571  		}
   572  	}
   573  	if len(rawcands) < 2 {
   574  		return nil, nil
   575  	}
   576  	cs.indirectUE = indirectUE
   577  	if cs.trace > 2 {
   578  		fmt.Fprintf(os.Stderr, "=-= iuetab:\n")
   579  		ids := make([]ssa.ID, 0, len(indirectUE))
   580  		for k := range indirectUE {
   581  			ids = append(ids, k)
   582  		}
   583  		sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] })
   584  		for _, id := range ids {
   585  			fmt.Fprintf(os.Stderr, "  v%d:", id)
   586  			for _, n := range indirectUE[id] {
   587  				fmt.Fprintf(os.Stderr, " %s", n.Sym().Name)
   588  			}
   589  			fmt.Fprintf(os.Stderr, "\n")
   590  		}
   591  	}
   592  
   593  	pruned := cands[:0]
   594  	for k := range rawcands {
   595  		pruned = append(pruned, k)
   596  	}
   597  	sort.Slice(pruned, func(i, j int) bool {
   598  		return nameLess(pruned[i], pruned[j])
   599  	})
   600  	var regions []candRegion
   601  	pruned, regions = cs.genRegions(pruned)
   602  	if len(pruned) < 2 {
   603  		return nil, nil
   604  	}
   605  	return pruned, regions
   606  }
   607  
   608  // FIXME: bootstrap tool compiler is build with a "go 1.20" go.mod, so
   609  // we are not allowed to use map clear yet. Use this helper instead.
   610  func genmapclear[KT comparable, VT any](m map[KT]VT) {
   611  	for k := range m {
   612  		delete(m, k)
   613  	}
   614  }
   615  
// nameCount pairs a candidate variable n with a count of the uses
// still outstanding for an SSA value that takes n's address. It is
// the value type of the pendingUses map in populateIndirectUseTable,
// where count starts out as the SSA value's use count and is
// decremented each time a use is seen.
type nameCount struct {
	n     *ir.Name
	count int32
}
   620  
   621  // nameLess compares ci with cj to see if ci should be less than cj in
   622  // a relative ordering of candidate variables. This is used to sort
   623  // vars by pointerness (variables with pointers first), then in order
   624  // of decreasing alignment, then by decreasing size. We are assuming a
   625  // merging algorithm that merges later entries in the list into
   626  // earlier entries. An example ordered candidate list produced by
   627  // nameLess:
   628  //
   629  //	idx   name    type       align    size
   630  //	0:    abc     [10]*int   8        80
   631  //	1:    xyz     [9]*int    8        72
   632  //	2:    qrs     [2]*int    8        16
   633  //	3:    tuv     [9]int     8        72
   634  //	4:    wxy     [9]int32   4        36
   635  //	5:    jkl     [8]int32   4        32
   636  func nameLess(ci, cj *ir.Name) bool {
   637  	if ci.Type().HasPointers() != cj.Type().HasPointers() {
   638  		return ci.Type().HasPointers()
   639  	}
   640  	if ci.Type().Alignment() != cj.Type().Alignment() {
   641  		return cj.Type().Alignment() < ci.Type().Alignment()
   642  	}
   643  	if ci.Type().Size() != cj.Type().Size() {
   644  		return cj.Type().Size() < ci.Type().Size()
   645  	}
   646  	if ci.Sym().Name != cj.Sym().Name {
   647  		return ci.Sym().Name < cj.Sym().Name
   648  	}
   649  	return fmt.Sprintf("%v", ci.Pos()) < fmt.Sprintf("%v", cj.Pos())
   650  }
   651  
   652  // nextRegion starts at location idx and walks forward in the cands
   653  // slice looking for variables that are "compatible" (potentially
   654  // overlappable, in the sense that they could potentially share the
   655  // stack slot of cands[idx]); it returns the end of the new region
   656  // (range of compatible variables starting at idx).
   657  func nextRegion(cands []*ir.Name, idx int) int {
   658  	n := len(cands)
   659  	if idx >= n {
   660  		return -1
   661  	}
   662  	c0 := cands[idx]
   663  	szprev := c0.Type().Size()
   664  	alnprev := c0.Type().Alignment()
   665  	for j := idx + 1; j < n; j++ {
   666  		cj := cands[j]
   667  		szj := cj.Type().Size()
   668  		if szj > szprev {
   669  			return j - 1
   670  		}
   671  		alnj := cj.Type().Alignment()
   672  		if alnj > alnprev {
   673  			return j - 1
   674  		}
   675  		szprev = szj
   676  		alnprev = alnj
   677  	}
   678  	return n - 1
   679  }
   680  
   681  // mergeVisitRegion tries to perform overlapping of variables with a
   682  // given subrange of cands described by st and en (indices into our
   683  // candidate var list), where the variables within this range have
   684  // already been determined to be compatible with respect to type,
   685  // size, etc. Overlapping is done in a a greedy fashion: we select the
   686  // first element in the st->en range, then walk the rest of the
   687  // elements adding in vars whose lifetimes don't overlap with the
   688  // first element, then repeat the process until we run out of work.
   689  // Ordering of the candidates within the region [st,en] is important;
   690  // within the list the assumption is that if we overlap two variables
   691  // X and Y where X precedes Y in the list, we need to make X the
   692  // "leader" (keep X's slot and set Y's frame offset to X's) as opposed
   693  // to the other way around, since it's possible that Y is smaller in
   694  // size than X.
   695  func (cs *cstate) mergeVisitRegion(mls *MergeLocalsState, st, en int) {
   696  	if cs.trace > 1 {
   697  		fmt.Fprintf(os.Stderr, "=-= mergeVisitRegion(st=%d, en=%d)\n", st, en)
   698  	}
   699  	n := en - st + 1
   700  	used := bitvec.New(int32(n))
   701  
   702  	nxt := func(slot int) int {
   703  		for c := slot - st; c < n; c++ {
   704  			if used.Get(int32(c)) {
   705  				continue
   706  			}
   707  			return c + st
   708  		}
   709  		return -1
   710  	}
   711  
   712  	navail := n
   713  	cands := cs.cands
   714  	ivs := cs.ivs
   715  	if cs.trace > 1 {
   716  		fmt.Fprintf(os.Stderr, "  =-= navail = %d\n", navail)
   717  	}
   718  	for navail >= 2 {
   719  		leader := nxt(st)
   720  		used.Set(int32(leader - st))
   721  		navail--
   722  
   723  		if cs.trace > 1 {
   724  			fmt.Fprintf(os.Stderr, "  =-= begin leader %d used=%s\n", leader,
   725  				used.String())
   726  		}
   727  		elems := []int{leader}
   728  		lints := ivs[leader]
   729  
   730  		for succ := nxt(leader + 1); succ != -1; succ = nxt(succ + 1) {
   731  
   732  			// Skip if de-selected by merge locals hash.
   733  			if cs.hashDeselected != nil && cs.hashDeselected[cands[succ]] {
   734  				continue
   735  			}
   736  			// Skip if already used.
   737  			if used.Get(int32(succ - st)) {
   738  				continue
   739  			}
   740  			if cs.trace > 1 {
   741  				fmt.Fprintf(os.Stderr, "  =-= overlap of %d[%v] {%s} with %d[%v] {%s} is: %v\n", leader, cands[leader], lints.String(), succ, cands[succ], ivs[succ].String(), lints.Overlaps(ivs[succ]))
   742  			}
   743  
   744  			// Can we overlap leader with this var?
   745  			if lints.Overlaps(ivs[succ]) {
   746  				continue
   747  			} else {
   748  				// Add to overlap set.
   749  				elems = append(elems, succ)
   750  				lints = lints.Merge(ivs[succ])
   751  			}
   752  		}
   753  		if len(elems) > 1 {
   754  			// We found some things to overlap with leader. Add the
   755  			// candidate elements to "vars" and update "partition".
   756  			off := len(mls.vars)
   757  			sl := make([]int, len(elems))
   758  			for i, candslot := range elems {
   759  				sl[i] = off + i
   760  				mls.vars = append(mls.vars, cands[candslot])
   761  				mls.partition[cands[candslot]] = sl
   762  			}
   763  			navail -= (len(elems) - 1)
   764  			for i := range elems {
   765  				used.Set(int32(elems[i] - st))
   766  			}
   767  			if cs.trace > 1 {
   768  				fmt.Fprintf(os.Stderr, "=-= overlapping %+v:\n", sl)
   769  				for i := range sl {
   770  					dumpCand(mls.vars[sl[i]], sl[i])
   771  				}
   772  				for i, v := range elems {
   773  					fmt.Fprintf(os.Stderr, "=-= %d: sl=%d %s\n", i, v, ivs[v])
   774  				}
   775  			}
   776  		}
   777  	}
   778  }
   779  
   780  // performMerging carries out variable merging within each of the
   781  // candidate ranges in regions, returning a state object
   782  // that describes the variable overlaps.
   783  func (cs *cstate) performMerging() *MergeLocalsState {
   784  	cands := cs.cands
   785  
   786  	mls := &MergeLocalsState{
   787  		partition: make(map[*ir.Name][]int),
   788  	}
   789  
   790  	// Dump state before attempting overlap.
   791  	if cs.trace > 1 {
   792  		fmt.Fprintf(os.Stderr, "=-= cands live before overlap:\n")
   793  		for i := range cands {
   794  			c := cands[i]
   795  			fmt.Fprintf(os.Stderr, "%d: %v sz=%d ivs=%s\n",
   796  				i, c.Sym().Name, c.Type().Size(), cs.ivs[i].String())
   797  		}
   798  		fmt.Fprintf(os.Stderr, "=-= regions (%d): ", len(cs.regions))
   799  		for _, cr := range cs.regions {
   800  			fmt.Fprintf(os.Stderr, " [%d,%d]", cr.st, cr.en)
   801  		}
   802  		fmt.Fprintf(os.Stderr, "\n")
   803  	}
   804  
   805  	// Apply a greedy merge/overlap strategy within each region
   806  	// of compatible variables.
   807  	for _, cr := range cs.regions {
   808  		cs.mergeVisitRegion(mls, cr.st, cr.en)
   809  	}
   810  	if len(mls.vars) == 0 {
   811  		return nil
   812  	}
   813  	return mls
   814  }
   815  
// computeIntervals performs a backwards sweep over the instructions
// of the function we're compiling, building up an Intervals object
// for each candidate variable by looking for upwards exposed uses
// and kills.
//
// Instructions are numbered 0..ninstr-1 in the order they appear when
// walking lv.f.Blocks front to back; the sweep starts at ninstr-1 and
// counts down. The finished intervals are stored in cs.ivs. If an
// interval builder reports an error the function is dumped and
// compilation is aborted via base.FatalfAt.
//
// NOTE(review): ibuilders is sized by len(cs.cands) but indexed with
// positions taken from lv.vars / lv.idx; this presumably relies on
// lv.vars holding exactly the candidates in the same order -- confirm
// against the code that sets up cs.lv.
func (cs *cstate) computeIntervals() {
	lv := cs.lv
	ibuilders := make([]IntervalsBuilder, len(cs.cands))
	nvars := int32(len(lv.vars))
	// Working "currently live" set, updated as we walk backwards.
	liveout := bitvec.New(nvars)

	cs.dumpFuncIfSelected()

	// Count instructions.
	ninstr := 0
	for _, b := range lv.f.Blocks {
		ninstr += len(b.Values)
	}
	// current instruction index during backwards walk
	iidx := ninstr - 1

	// Make a backwards pass over all blocks
	for k := len(lv.f.Blocks) - 1; k >= 0; k-- {
		b := lv.f.Blocks[k]
		be := lv.blockEffects(b)

		if cs.trace > 2 {
			fmt.Fprintf(os.Stderr, "=-= liveout from tail of b%d: ", k)
			for j := range lv.vars {
				if be.liveout.Get(int32(j)) {
					fmt.Fprintf(os.Stderr, " %q", lv.vars[j].Sym().Name)
				}
			}
			fmt.Fprintf(os.Stderr, "\n")
		}

		// Take into account effects taking place at end of this basic
		// block by comparing our current live set with liveout for
		// the block. If a given var was not live before and is now
		// becoming live we need to mark this transition with a
		// builder "Live" call; similarly if a var was live before and
		// is now no longer live, we need a "Kill" call.
		for j := range lv.vars {
			isLive := liveout.Get(int32(j))
			blockLiveOut := be.liveout.Get(int32(j))
			if isLive {
				if !blockLiveOut {
					if cs.trace > 2 {
						fmt.Fprintf(os.Stderr, "=+= at instr %d block boundary kill of %v\n", iidx, lv.vars[j])
					}
					ibuilders[j].Kill(iidx)
				}
			} else if blockLiveOut {
				if cs.trace > 2 {
					fmt.Fprintf(os.Stderr, "=+= at block-end instr %d %v becomes live\n",
						iidx, lv.vars[j])
				}
				ibuilders[j].Live(iidx)
			}
		}

		// Set our working "currently live" set to the previously
		// computed live out set for the block.
		liveout.Copy(be.liveout)

		// Now walk backwards through this block.
		for i := len(b.Values) - 1; i >= 0; i-- {
			v := b.Values[i]

			if cs.trace > 2 {
				fmt.Fprintf(os.Stderr, "=-= b%d instr %d: %s\n", k, iidx, v.LongString())
			}

			// Update liveness based on what we see happening in this
			// instruction.
			pos, e := lv.valueEffects(v)
			becomeslive := e&uevar != 0
			iskilled := e&varkill != 0
			if becomeslive && iskilled {
				// we do not ever expect to see both a kill and an
				// upwards exposed use given our size constraints.
				panic("should never happen")
			}
			// Only transitions are recorded: a kill matters only if
			// the variable is currently live, and a use matters only
			// if it is not.
			if iskilled && liveout.Get(pos) {
				ibuilders[pos].Kill(iidx)
				liveout.Unset(pos)
				if cs.trace > 2 {
					fmt.Fprintf(os.Stderr, "=+= at instr %d kill of %v\n",
						iidx, lv.vars[pos])
				}
			} else if becomeslive && !liveout.Get(pos) {
				ibuilders[pos].Live(iidx)
				liveout.Set(pos)
				if cs.trace > 2 {
					fmt.Fprintf(os.Stderr, "=+= at instr %d upwards-exposed use of %v\n",
						iidx, lv.vars[pos])
				}
			}

			if cs.indirectUE != nil {
				// Now handle "indirect" upwards-exposed uses.
				// cs.indirectUE is keyed by value ID; each name listed
				// for this value that is tracked in lv.idx is marked
				// live here if it is not live already.
				ues := cs.indirectUE[v.ID]
				for _, n := range ues {
					if pos, ok := lv.idx[n]; ok {
						if !liveout.Get(pos) {
							ibuilders[pos].Live(iidx)
							liveout.Set(pos)
							if cs.trace > 2 {
								fmt.Fprintf(os.Stderr, "=+= at instr %d v%d indirect upwards-exposed use of %v\n", iidx, v.ID, lv.vars[pos])
							}
						}
					}
				}
			}
			iidx--
		}

		// This check disabled for now due to the way scheduling works
		// for ops that materialize values of local variables. For
		// many architectures we have rewrite rules of this form:
		//
		// (LocalAddr <t> {sym} base mem) && t.Elem().HasPointers() => (MOVDaddr {sym} (SPanchored base mem))
		// (LocalAddr <t> {sym} base _)  && !t.Elem().HasPointers() => (MOVDaddr {sym} base)
		//
		// which are designed to ensure that if you have a pointerful
		// variable "abc" sequence
		//
		//    v30 = VarDef <mem> {abc} v21
		//    v31 = LocalAddr <*SB> {abc} v2 v30
		//    v32 = Zero <mem> {SB} [2056] v31 v30
		//
		// this will be lowered into
		//
		//    v30 = VarDef <mem> {sb} v21
		//   v106 = SPanchored <uintptr> v2 v30
		//    v31 = MOVDaddr <*SB> {sb} v106
		//     v3 = DUFFZERO <mem> [2056] v31 v30
		//
		// Note the SPanchored: this ensures that the scheduler won't
		// move the MOVDaddr earlier than the vardef. With a variable
		// "xyz" that has no pointers, however, if we start with
		//
		//    v66 = VarDef <mem> {t2} v65
		//    v67 = LocalAddr <*T> {t2} v2 v66
		//    v68 = Zero <mem> {T} [2056] v67 v66
		//
		// we might lower to
		//
		//    v66 = VarDef <mem> {t2} v65
		//    v29 = MOVDaddr <*T> {t2} [2032] v2
		//    v43 = LoweredZero <mem> v67 v29 v66
		//    v68 = Zero [2056] v2 v43
		//
		// where that MOVDaddr can float around arbitrarily, meaning
		// that we may see an upwards-exposed use to it before the
		// VarDef.
		//
		// One avenue to restoring the check below would be to change
		// the rewrite rules to something like
		//
		// (LocalAddr <t> {sym} base mem) && (t.Elem().HasPointers() || isMergeCandidate(t)) => (MOVDaddr {sym} (SPanchored base mem))
		//
		// however that change will have to be carefully evaluated,
		// since it would constrain the scheduler for _all_ LocalAddr
		// ops for potential merge candidates, even if we don't
		// actually succeed in any overlaps. This will be revisited in
		// a later CL if possible.
		//
		const checkLiveOnEntry = false
		if checkLiveOnEntry && b == lv.f.Entry {
			for j, v := range lv.vars {
				if liveout.Get(int32(j)) {
					lv.f.Fatalf("%v %L recorded as live on entry",
						lv.fn.Nname, v)
				}
			}
		}
	}
	// Sanity check: every counted instruction must have been visited
	// exactly once, leaving the index one below zero.
	if iidx != -1 {
		panic("iidx underflow")
	}

	// Finish intervals construction.
	ivs := make([]Intervals, len(cs.cands))
	for i := range cs.cands {
		var err error
		ivs[i], err = ibuilders[i].Finish()
		if err != nil {
			// Dump the function before aborting so the failure can
			// be diagnosed.
			cs.dumpFunc()
			base.FatalfAt(cs.cands[i].Pos(), "interval construct error for var %q in func %q (%d instrs): %v", cs.cands[i].Sym().Name, ir.FuncName(cs.fn), ninstr, err)
		}
	}
	cs.ivs = ivs
}
  1009  
  1010  func fmtFullPos(p src.XPos) string {
  1011  	var sb strings.Builder
  1012  	sep := ""
  1013  	base.Ctxt.AllPos(p, func(pos src.Pos) {
  1014  		fmt.Fprintf(&sb, sep)
  1015  		sep = "|"
  1016  		file := filepath.Base(pos.Filename())
  1017  		fmt.Fprintf(&sb, "%s:%d:%d", file, pos.Line(), pos.Col())
  1018  	})
  1019  	return sb.String()
  1020  }
  1021  
  1022  func dumpCand(c *ir.Name, i int) {
  1023  	fmt.Fprintf(os.Stderr, " %d: %s %q sz=%d hp=%v align=%d t=%v\n",
  1024  		i, fmtFullPos(c.Pos()), c.Sym().Name, c.Type().Size(),
  1025  		c.Type().HasPointers(), c.Type().Alignment(), c.Type())
  1026  }
  1027  
  1028  // for unit testing only.
  1029  func MakeMergeLocalsState(partition map[*ir.Name][]int, vars []*ir.Name) (*MergeLocalsState, error) {
  1030  	mls := &MergeLocalsState{partition: partition, vars: vars}
  1031  	if err := mls.check(); err != nil {
  1032  		return nil, err
  1033  	}
  1034  	return mls, nil
  1035  }
  1036  

View as plain text