Source file src/internal/runtime/gc/scan/mkasm.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build ignore
     6  
     7  package main
     8  
     9  import (
    10  	"bytes"
    11  	"fmt"
    12  	"io"
    13  	"log"
    14  	"os"
    15  	"slices"
    16  	"strconv"
    17  
    18  	"internal/runtime/gc"
    19  	"internal/runtime/gc/internal/gen"
    20  )
    21  
    22  const header = "// Code generated by mkasm.go. DO NOT EDIT.\n\n"
    23  
    24  func main() {
    25  	generate("expand_amd64.s", genExpanders)
    26  }
    27  
    28  func generate(fileName string, genFunc func(*gen.File)) {
    29  	var buf bytes.Buffer
    30  	tee := io.MultiWriter(&buf, os.Stdout)
    31  
    32  	file := gen.NewFile(tee)
    33  
    34  	genFunc(file)
    35  
    36  	fmt.Fprintf(tee, header)
    37  	file.Compile()
    38  
    39  	f, err := os.Create(fileName)
    40  	if err != nil {
    41  		log.Fatal(err)
    42  	}
    43  	defer f.Close()
    44  	_, err = f.Write(buf.Bytes())
    45  	if err != nil {
    46  		log.Fatal(err)
    47  	}
    48  }
    49  
    50  func genExpanders(file *gen.File) {
    51  	gcExpandersAVX512 := make([]*gen.Func, len(gc.SizeClassToSize))
    52  	for sc, ob := range gc.SizeClassToSize {
    53  		if gc.SizeClassToNPages[sc] != 1 {
    54  			// These functions all produce a bitmap that covers exactly one
    55  			// page.
    56  			continue
    57  		}
    58  		if ob > gc.MinSizeForMallocHeader {
    59  			// This size class is too big to have a packed pointer/scalar bitmap.
    60  			break
    61  		}
    62  
    63  		xf := int(ob) / 8
    64  		log.Printf("size class %d bytes, expansion %dx", ob, xf)
    65  
    66  		fn := gen.NewFunc(fmt.Sprintf("expandAVX512_%d<>", xf))
    67  		ptrObjBits := gen.Arg[gen.Ptr[gen.Uint8x64]](fn)
    68  
    69  		if xf == 1 {
    70  			expandIdentity(ptrObjBits)
    71  		} else {
    72  			ok := gfExpander(xf, ptrObjBits)
    73  			if !ok {
    74  				log.Printf("failed to generate expander for size class %d", sc)
    75  			}
    76  		}
    77  		file.AddFunc(fn)
    78  		gcExpandersAVX512[sc] = fn
    79  	}
    80  
    81  	// Generate table mapping size class to expander PC
    82  	file.AddConst("·gcExpandersAVX512", gcExpandersAVX512)
    83  }
    84  
// mat8x8 is an 8x8 bit matrix over GF(2), used as an operand of GF2P8Affine.
//
// mat[i] is row i. When the matrix is applied to an input byte, output bit i
// is the parity of (mat[i] AND input). matGroupToVec packs the rows into a
// uint64 in the row order the hardware expects. The type is comparable and is
// used directly as a map key when grouping terms.
type mat8x8 struct {
	mat [8]uint8
}
    89  
    90  func matGroupToVec(mats *[8]mat8x8) [8]uint64 {
    91  	var out [8]uint64
    92  	for i, mat := range mats {
    93  		for j, row := range mat.mat {
    94  			// For some reason, Intel flips the rows.
    95  			out[i] |= uint64(row) << ((7 - j) * 8)
    96  		}
    97  	}
    98  	return out
    99  }
   100  
   101  // expandIdentity implements 1x expansion (that is, no expansion).
   102  func expandIdentity(ptrObjBits gen.Ptr[gen.Uint8x64]) {
   103  	objBitsLo := gen.Deref(ptrObjBits)
   104  	objBitsHi := gen.Deref(ptrObjBits.AddConst(64))
   105  	gen.Return(objBitsLo, objBitsHi)
   106  }
   107  
// gfExpander generates code, via the gen package, that expands each bit in an
// input bitmap into f consecutive bits in an output bitmap. It does this with
// GF(2) affine transforms (GF2P8Affine) over 8x8 bit matrices, followed by
// byte shuffles.
//
// The input is
//
//	AX *[8]uint64 = A pointer to floor(1024/f) bits (f >= 2, so at most 512 bits)
//
// The output is
//
//	Z1 [64]uint8  = The bottom 512 bits of the expanded bitmap
//	Z2 [64]uint8  = The top 512 bits of the expanded bitmap
//
// It returns true after emitting the final gen.Return. It returns false, after
// logging the reason, when f yields a layout this generator does not handle:
//   - an output byte that draws from two input bytes,
//   - super-groups that would need different matrix vectors when f <= 8, or
//   - a final permutation with more than four sources.
//
// On failure no Return is emitted, but some ops (at least the load of
// objBits) have already been generated.
//
// TODO(austin): This should use Z0/Z1.
func gfExpander(f int, ptrObjBits gen.Ptr[gen.Uint8x64]) bool {
	// TODO(austin): For powers of 2 >= 8, we can use mask expansion ops to make this much simpler.

	// TODO(austin): For f >= 8, I suspect there are better ways to do this.
	//
	// For example, we could use a mask expansion to get a full byte for each
	// input bit, and separately create the bytes that blend adjacent bits, then
	// shuffle those bytes together. Certainly for f >= 16 this makes sense
	// because each of those bytes will be used, possibly more than once.

	// Load one 64-byte vector of input. Since f >= 2, the 1024 output bits
	// come from input bits 0..511 at most, so these 64 bytes are enough.
	objBits := gen.Deref(ptrObjBits)

	// A term computes output byte oByte by multiplying input byte iByte by
	// the matrix mat.
	type term struct {
		iByte, oByte int
		mat          mat8x8
	}
	var terms []term

	// Iterate over all output bytes and construct the 8x8 GF2 matrix to compute
	// the output byte from the appropriate input byte. Gather all of these into
	// "terms".
	for oByte := 0; oByte < 1024/8; oByte++ {
		var byteMat mat8x8
		iByte := -1
		for oBit := oByte * 8; oBit < oByte*8+8; oBit++ {
			iBit := oBit / f
			if iByte == -1 {
				iByte = iBit / 8
			} else if iByte != iBit/8 {
				// A single matrix multiply reads only one input byte.
				log.Printf("output byte %d straddles input bytes %d and %d", oByte, iByte, iBit/8)
				return false
			}
			// One way to view this is that the i'th row of the matrix will be
			// ANDed with the input byte, and the parity of the result will set
			// the i'th bit in the output. We use a simple 1 bit mask, so the
			// parity is irrelevant beyond selecting out that one bit.
			byteMat.mat[oBit%8] = 1 << (iBit % 8)
		}
		terms = append(terms, term{iByte, oByte, byteMat})
	}

	if false {
		// Debugging aid (disabled): print input byte -> output byte as a
		// matrix, labeling each distinct mat8x8 with a letter.
		maxIByte, maxOByte := 0, 0
		for _, term := range terms {
			maxIByte = max(maxIByte, term.iByte)
			maxOByte = max(maxOByte, term.oByte)
		}
		iToO := make([][]rune, maxIByte+1)
		for i := range iToO {
			iToO[i] = make([]rune, maxOByte+1)
		}
		matMap := make(map[mat8x8]int)
		for _, term := range terms {
			i, ok := matMap[term.mat]
			if !ok {
				i = len(matMap)
				matMap[term.mat] = i
			}
			iToO[term.iByte][term.oByte] = 'A' + rune(i)
		}
		for o := range maxOByte + 1 {
			fmt.Printf("%d", o)
			for i := range maxIByte + 1 {
				fmt.Printf(",")
				if mat := iToO[i][o]; mat != 0 {
					fmt.Printf("%c", mat)
				}
			}
			fmt.Println()
		}
	}

	// In hardware, each (8 byte) matrix applies to 8 bytes of data in parallel,
	// and we get to operate on up to 8 matrixes in parallel (or 64 values). That is:
	//
	//  abcdefgh ijklmnop qrstuvwx yzABCDEF GHIJKLMN OPQRSTUV WXYZ0123 456789_+
	//    mat0     mat1     mat2     mat3     mat4     mat5     mat6     mat7

	// Group the terms by matrix, but limit each group to 8 terms.
	const termsPerGroup = 8       // Number of terms we can multiply by the same matrix.
	const groupsPerSuperGroup = 8 // Number of matrixes we can fit in a vector.

	// matMap maps a matrix to the index of its currently open (not yet full)
	// term group. allMats records every distinct matrix seen.
	matMap := make(map[mat8x8]int)
	allMats := make(map[mat8x8]bool)
	var termGroups [][]term
	for _, term := range terms {
		allMats[term.mat] = true

		i, ok := matMap[term.mat]
		if ok && f > groupsPerSuperGroup {
			// The output is ultimately produced in two [64]uint8 registers.
			// Getting every byte in the right place of each of these requires a
			// final permutation that often requires more than one source.
			//
			// Up to 8x expansion, we can get a really nice grouping so we can use
			// the same 8 matrix vector several times, without producing
			// permutations that require more than two sources.
			//
			// Above 8x, however, we can't get nice matrixes anyway, so we
			// instead prefer reducing the complexity of the permutations we
			// need to produce the final outputs. To do this, avoid grouping
			// together terms that are split across the two registers.
			outRegister := termGroups[i][0].oByte / 64
			if term.oByte/64 != outRegister {
				ok = false
			}
		}
		if !ok {
			// Start a new term group.
			i = len(termGroups)
			matMap[term.mat] = i
			termGroups = append(termGroups, nil)
		}

		termGroups[i] = append(termGroups[i], term)

		if len(termGroups[i]) == termsPerGroup {
			// This term group is full.
			delete(matMap, term.mat)
		}
	}

	for i, termGroup := range termGroups {
		log.Printf("term group %d:", i)
		for _, term := range termGroup {
			log.Printf("  %+v", term)
		}
	}

	// We can do 8 matrix multiplies in parallel, which is 8 term groups. Pack
	// as many term groups as we can into each super-group to minimize the
	// number of matrix multiplies.
	//
	// Ideally, we use the same matrix in each super-group, which might mean
	// doing fewer than 8 multiplies at a time. That's fine because it never
	// increases the total number of matrix multiplies.
	//
	// TODO: Packing the matrixes less densely may let us use more broadcast
	// loads instead of general permutations, though. That replaces a load of
	// the permutation with a load of the matrix, but is probably still slightly
	// better.
	var sgSize, nSuperGroups int
	oneMatVec := f <= groupsPerSuperGroup
	if oneMatVec {
		// We can use the same matrix in each multiply by doing sgSize
		// multiplies at a time.
		//
		// sgSize is the largest multiple of len(allMats) that is <= 8. It is
		// never 0 here. Output byte oByte+f uses the same bit positions of the
		// next input byte as oByte, so there are at most f <= 8 distinct
		// matrices.
		sgSize = groupsPerSuperGroup / len(allMats) * len(allMats)
		nSuperGroups = (len(termGroups) + sgSize - 1) / sgSize
	} else {
		// We can't use the same matrix for each multiply. Just do as many at a
		// time as we can.
		//
		// TODO: This is going to produce several distinct matrixes, when we
		// probably only need two. Be smarter about how we create super-groups
		// in this case. Maybe we build up an array of super-groups and then the
		// loop below just turns them into ops?
		sgSize = 8
		nSuperGroups = (len(termGroups) + groupsPerSuperGroup - 1) / groupsPerSuperGroup
	}

	// Construct each super-group.
	//
	// Each super-group becomes one matrix multiply (matMuls[sgi]). Lane k of
	// that multiply reads input byte iperm[k] and is transformed by
	// thisMatGroup[k/8]. perm[oByte] records where output byte oByte ended
	// up: lane perm[oByte]%64 of matMuls[perm[oByte]/64].
	var matGroup [8]mat8x8
	var matMuls []gen.Uint8x64
	var perm [128]int
	for sgi := range nSuperGroups {
		var iperm [64]uint8
		for i := range iperm {
			iperm[i] = 0xff // "Don't care"
		}
		// Pick off sgSize term groups.
		superGroup := termGroups[:min(len(termGroups), sgSize)]
		termGroups = termGroups[len(superGroup):]
		// Build the matrix and permutations for this super-group.
		var thisMatGroup [8]mat8x8
		for i, termGroup := range superGroup {
			// All terms in this group have the same matrix. Pick one.
			thisMatGroup[i] = termGroup[0].mat
			for j, term := range termGroup {
				// Build the input permutation.
				iperm[i*termsPerGroup+j] = uint8(term.iByte)
				// Build the output permutation.
				perm[term.oByte] = sgi*groupsPerSuperGroup*termsPerGroup + i*termsPerGroup + j
			}
		}
		log.Printf("input permutation %d: %v", sgi, iperm)

		// Check that we're not making more distinct matrixes than expected.
		if oneMatVec {
			if sgi == 0 {
				matGroup = thisMatGroup
			} else if matGroup != thisMatGroup {
				log.Printf("super-groups have different matrixes:\n%+v\n%+v", matGroup, thisMatGroup)
				return false
			}
		}

		// Emit matrix op.
		matConst := gen.ConstUint64x8(matGroupToVec(&thisMatGroup), fmt.Sprintf("*_mat%d<>", sgi))
		inOp := objBits.Shuffle(gen.ConstUint8x64(iperm, fmt.Sprintf("*_inShuf%d<>", sgi)))
		matMul := matConst.GF2P8Affine(inOp)
		matMuls = append(matMuls, matMul)
	}

	log.Printf("output permutation: %v", perm)

	// Gather the two 64-byte output halves from the matrix-multiply results.
	//
	// NOTE(review): genShuffle also accepts 3 sources, and it fails only
	// when one half draws from more than 4 of the matMuls, so the messages
	// below understate what is supported.
	outLo, ok := genShuffle("*_outShufLo", (*[64]int)(perm[:64]), matMuls...)
	if !ok {
		log.Printf("bad number of inputs to final shuffle: %d != 1, 2, or 4", len(matMuls))
		return false
	}
	outHi, ok := genShuffle("*_outShufHi", (*[64]int)(perm[64:]), matMuls...)
	if !ok {
		log.Printf("bad number of inputs to final shuffle: %d != 1, 2, or 4", len(matMuls))
		return false
	}
	gen.Return(outLo, outHi)

	return true
}
   341  
   342  func genShuffle(name string, perm *[64]int, args ...gen.Uint8x64) (gen.Uint8x64, bool) {
   343  	// Construct flattened permutation.
   344  	var vperm [64]byte
   345  
   346  	// Get the inputs used by this permutation.
   347  	var inputs []int
   348  	for i, src := range perm {
   349  		inputIdx := slices.Index(inputs, src/64)
   350  		if inputIdx == -1 {
   351  			inputIdx = len(inputs)
   352  			inputs = append(inputs, src/64)
   353  		}
   354  		vperm[i] = byte(src%64 | (inputIdx << 6))
   355  	}
   356  
   357  	// Emit instructions for easy cases.
   358  	switch len(inputs) {
   359  	case 1:
   360  		constOp := gen.ConstUint8x64(vperm, name)
   361  		return args[inputs[0]].Shuffle(constOp), true
   362  	case 2:
   363  		constOp := gen.ConstUint8x64(vperm, name)
   364  		return args[inputs[0]].Shuffle2(args[inputs[1]], constOp), true
   365  	}
   366  
   367  	// Harder case, we need to shuffle in from up to 2 more tables.
   368  	//
   369  	// Perform two shuffles. One shuffle will get its data from the first
   370  	// two inputs, the other shuffle will get its data from the other one
   371  	// or two inputs. All values they don't care each don't care about will
   372  	// be zeroed.
   373  	var vperms [2][64]byte
   374  	var masks [2]uint64
   375  	for j, idx := range vperm {
   376  		for i := range vperms {
   377  			vperms[i][j] = 0xff // "Don't care"
   378  		}
   379  		if idx == 0xff {
   380  			continue
   381  		}
   382  		vperms[idx/128][j] = idx % 128
   383  		masks[idx/128] |= uint64(1) << j
   384  	}
   385  
   386  	// Validate that the masks are fully disjoint.
   387  	if masks[0]^masks[1] != ^uint64(0) {
   388  		panic("bad shuffle!")
   389  	}
   390  
   391  	// Generate constants.
   392  	constOps := make([]gen.Uint8x64, len(vperms))
   393  	for i, v := range vperms {
   394  		constOps[i] = gen.ConstUint8x64(v, name+strconv.Itoa(i))
   395  	}
   396  
   397  	// Generate shuffles.
   398  	switch len(inputs) {
   399  	case 3:
   400  		r0 := args[inputs[0]].Shuffle2Zeroed(args[inputs[1]], constOps[0], gen.ConstMask64(masks[0]))
   401  		r1 := args[inputs[2]].ShuffleZeroed(constOps[1], gen.ConstMask64(masks[1]))
   402  		return r0.ToUint64x8().Or(r1.ToUint64x8()).ToUint8x64(), true
   403  	case 4:
   404  		r0 := args[inputs[0]].Shuffle2Zeroed(args[inputs[1]], constOps[0], gen.ConstMask64(masks[0]))
   405  		r1 := args[inputs[2]].Shuffle2Zeroed(args[inputs[3]], constOps[1], gen.ConstMask64(masks[1]))
   406  		return r0.ToUint64x8().Or(r1.ToUint64x8()).ToUint8x64(), true
   407  	}
   408  
   409  	// Too many inputs. To support more, we'd need to separate tables much earlier.
   410  	// Right now all the indices fit in a byte, but with >4 inputs they might not (>256 bytes).
   411  	return args[0], false
   412  }
   413  

View as plain text