// Copyright 2025 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. //go:build ignore package main import ( "bytes" "fmt" "io" "log" "os" "slices" "strconv" "internal/runtime/gc" "internal/runtime/gc/internal/gen" ) const header = "// Code generated by mkasm.go. DO NOT EDIT.\n\n" func main() { generate("expand_amd64.s", genExpanders) } func generate(fileName string, genFunc func(*gen.File)) { var buf bytes.Buffer tee := io.MultiWriter(&buf, os.Stdout) file := gen.NewFile(tee) genFunc(file) fmt.Fprintf(tee, header) file.Compile() f, err := os.Create(fileName) if err != nil { log.Fatal(err) } defer f.Close() _, err = f.Write(buf.Bytes()) if err != nil { log.Fatal(err) } } func genExpanders(file *gen.File) { gcExpandersAVX512 := make([]*gen.Func, len(gc.SizeClassToSize)) for sc, ob := range gc.SizeClassToSize { if gc.SizeClassToNPages[sc] != 1 { // These functions all produce a bitmap that covers exactly one // page. continue } if ob > gc.MinSizeForMallocHeader { // This size class is too big to have a packed pointer/scalar bitmap. break } xf := int(ob) / 8 log.Printf("size class %d bytes, expansion %dx", ob, xf) fn := gen.NewFunc(fmt.Sprintf("expandAVX512_%d<>", xf)) ptrObjBits := gen.Arg[gen.Ptr[gen.Uint8x64]](fn) if xf == 1 { expandIdentity(ptrObjBits) } else { ok := gfExpander(xf, ptrObjBits) if !ok { log.Printf("failed to generate expander for size class %d", sc) } } file.AddFunc(fn) gcExpandersAVX512[sc] = fn } // Generate table mapping size class to expander PC file.AddConst("·gcExpandersAVX512", gcExpandersAVX512) } // mat8x8 is an 8x8 bit matrix. type mat8x8 struct { mat [8]uint8 } func matGroupToVec(mats *[8]mat8x8) [8]uint64 { var out [8]uint64 for i, mat := range mats { for j, row := range mat.mat { // For some reason, Intel flips the rows. out[i] |= uint64(row) << ((7 - j) * 8) } } return out } // expandIdentity implements 1x expansion (that is, no expansion). 
func expandIdentity(ptrObjBits gen.Ptr[gen.Uint8x64]) {
	// 1x "expansion" is just a copy: load the two 64-byte halves of the
	// input bitmap and return them as the low and high output vectors.
	objBitsLo := gen.Deref(ptrObjBits)
	objBitsHi := gen.Deref(ptrObjBits.AddConst(64))
	gen.Return(objBitsLo, objBitsHi)
}

// gfExpander produces a function that expands each bit in an input bitmap into
// f consecutive bits in an output bitmap.
//
// The input is
//
//	AX *[8]uint64 = A pointer to floor(1024/f) bits (f >= 2, so at most 512 bits)
//
// The output is
//
//	Z1 [64]uint8 = The bottom 512 bits of the expanded bitmap
//	Z2 [64]uint8 = The top 512 bits of the expanded bitmap
//
// It reports whether it succeeded; on failure it logs the reason and the
// emitted function is incomplete.
//
// TODO(austin): This should be Z0/Z1.
func gfExpander(f int, ptrObjBits gen.Ptr[gen.Uint8x64]) bool {
	// TODO(austin): For powers of 2 >= 8, we can use mask expansion ops to make this much simpler.

	// TODO(austin): For f >= 8, I suspect there are better ways to do this.
	//
	// For example, we could use a mask expansion to get a full byte for each
	// input bit, and separately create the bytes that blend adjacent bits, then
	// shuffle those bytes together. Certainly for f >= 16 this makes sense
	// because each of those bytes will be used, possibly more than once.

	objBits := gen.Deref(ptrObjBits)

	// A term records that output byte oByte is computed from input byte
	// iByte via the 8x8 GF(2) matrix mat.
	type term struct {
		iByte, oByte int
		mat          mat8x8
	}
	var terms []term

	// Iterate over all output bytes and construct the 8x8 GF2 matrix to compute
	// the output byte from the appropriate input byte. Gather all of these into
	// "terms".
	for oByte := 0; oByte < 1024/8; oByte++ {
		var byteMat mat8x8
		iByte := -1
		for oBit := oByte * 8; oBit < oByte*8+8; oBit++ {
			// Output bit oBit is a copy of input bit oBit/f.
			iBit := oBit / f
			if iByte == -1 {
				iByte = iBit / 8
			} else if iByte != iBit/8 {
				// Each output byte must be derived from a single input
				// byte for the matrix-multiply approach to work.
				log.Printf("output byte %d straddles input bytes %d and %d", oByte, iByte, iBit/8)
				return false
			}
			// One way to view this is that the i'th row of the matrix will be
			// ANDed with the input byte, and the parity of the result will set
			// the i'th bit in the output. We use a simple 1 bit mask, so the
			// parity is irrelevant beyond selecting out that one bit.
			byteMat.mat[oBit%8] = 1 << (iBit % 8)
		}
		terms = append(terms, term{iByte, oByte, byteMat})
	}

	if false {
		// Debugging aid: print input byte -> output byte as a matrix,
		// labeling each distinct 8x8 matrix with a letter.
		maxIByte, maxOByte := 0, 0
		for _, term := range terms {
			maxIByte = max(maxIByte, term.iByte)
			maxOByte = max(maxOByte, term.oByte)
		}
		iToO := make([][]rune, maxIByte+1)
		for i := range iToO {
			iToO[i] = make([]rune, maxOByte+1)
		}
		matMap := make(map[mat8x8]int)
		for _, term := range terms {
			i, ok := matMap[term.mat]
			if !ok {
				i = len(matMap)
				matMap[term.mat] = i
			}
			iToO[term.iByte][term.oByte] = 'A' + rune(i)
		}
		for o := range maxOByte + 1 {
			fmt.Printf("%d", o)
			for i := range maxIByte + 1 {
				fmt.Printf(",")
				if mat := iToO[i][o]; mat != 0 {
					fmt.Printf("%c", mat)
				}
			}
			fmt.Println()
		}
	}

	// In hardware, each (8 byte) matrix applies to 8 bytes of data in parallel,
	// and we get to operate on up to 8 matrixes in parallel (or 64 values). That is:
	//
	//	abcdefgh ijklmnop qrstuvwx yzABCDEF GHIJKLMN OPQRSTUV WXYZ0123 456789_+
	//	  mat0     mat1     mat2     mat3     mat4     mat5     mat6     mat7

	// Group the terms by matrix, but limit each group to 8 terms.
	const termsPerGroup = 8       // Number of terms we can multiply by the same matrix.
	const groupsPerSuperGroup = 8 // Number of matrixes we can fit in a vector.
	// matMap maps a matrix to the index of its current (non-full) term
	// group, if any.
	matMap := make(map[mat8x8]int)
	// allMats is the set of distinct matrixes seen across all terms.
	allMats := make(map[mat8x8]bool)
	var termGroups [][]term
	for _, term := range terms {
		allMats[term.mat] = true
		i, ok := matMap[term.mat]
		if ok && f > groupsPerSuperGroup {
			// The output is ultimately produced in two [64]uint8 registers.
			// Getting every byte in the right place of each of these requires a
			// final permutation that often requires more than one source.
			//
			// Up to 8x expansion, we can get a really nice grouping so we can use
			// the same 8 matrix vector several times, without producing
			// permutations that require more than two sources.
			//
			// Above 8x, however, we can't get nice matrixes anyway, so we
			// instead prefer reducing the complexity of the permutations we
			// need to produce the final outputs. To do this, avoid grouping
			// together terms that are split across the two registers.
			outRegister := termGroups[i][0].oByte / 64
			if term.oByte/64 != outRegister {
				ok = false
			}
		}
		if !ok {
			// Start a new term group.
			i = len(termGroups)
			matMap[term.mat] = i
			termGroups = append(termGroups, nil)
		}
		termGroups[i] = append(termGroups[i], term)
		if len(termGroups[i]) == termsPerGroup {
			// This term group is full.
			delete(matMap, term.mat)
		}
	}
	for i, termGroup := range termGroups {
		log.Printf("term group %d:", i)
		for _, term := range termGroup {
			log.Printf(" %+v", term)
		}
	}

	// We can do 8 matrix multiplies in parallel, which is 8 term groups. Pack
	// as many term groups as we can into each super-group to minimize the
	// number of matrix multiplies.
	//
	// Ideally, we use the same matrix in each super-group, which might mean
	// doing fewer than 8 multiplies at a time. That's fine because it never
	// increases the total number of matrix multiplies.
	//
	// TODO: Packing the matrixes less densely may let us use more broadcast
	// loads instead of general permutations, though. That replaces a load of
	// the permutation with a load of the matrix, but is probably still slightly
	// better.
	var sgSize, nSuperGroups int
	oneMatVec := f <= groupsPerSuperGroup
	if oneMatVec {
		// We can use the same matrix in each multiply by doing sgSize
		// multiplies at a time.
		sgSize = groupsPerSuperGroup / len(allMats) * len(allMats)
		nSuperGroups = (len(termGroups) + sgSize - 1) / sgSize
	} else {
		// We can't use the same matrix for each multiply. Just do as many at a
		// time as we can.
		//
		// TODO: This is going to produce several distinct matrixes, when we
		// probably only need two. Be smarter about how we create super-groups
		// in this case. Maybe we build up an array of super-groups and then the
		// loop below just turns them into ops? 
		sgSize = 8
		nSuperGroups = (len(termGroups) + groupsPerSuperGroup - 1) / groupsPerSuperGroup
	}

	// Construct each super-group.
	var matGroup [8]mat8x8
	var matMuls []gen.Uint8x64
	// perm maps each output byte to its position in the concatenated
	// matrix-multiply results; it feeds the final shuffles below.
	var perm [128]int
	for sgi := range nSuperGroups {
		var iperm [64]uint8
		for i := range iperm {
			iperm[i] = 0xff // "Don't care"
		}

		// Pick off sgSize term groups.
		superGroup := termGroups[:min(len(termGroups), sgSize)]
		termGroups = termGroups[len(superGroup):]

		// Build the matrix and permutations for this super-group.
		var thisMatGroup [8]mat8x8
		for i, termGroup := range superGroup {
			// All terms in this group have the same matrix. Pick one.
			thisMatGroup[i] = termGroup[0].mat
			for j, term := range termGroup {
				// Build the input permutation.
				iperm[i*termsPerGroup+j] = uint8(term.iByte)
				// Build the output permutation.
				perm[term.oByte] = sgi*groupsPerSuperGroup*termsPerGroup + i*termsPerGroup + j
			}
		}
		log.Printf("input permutation %d: %v", sgi, iperm)

		// Check that we're not making more distinct matrixes than expected.
		if oneMatVec {
			if sgi == 0 {
				matGroup = thisMatGroup
			} else if matGroup != thisMatGroup {
				log.Printf("super-groups have different matrixes:\n%+v\n%+v", matGroup, thisMatGroup)
				return false
			}
		}

		// Emit matrix op.
		matConst := gen.ConstUint64x8(matGroupToVec(&thisMatGroup), fmt.Sprintf("*_mat%d<>", sgi))
		inOp := objBits.Shuffle(gen.ConstUint8x64(iperm, fmt.Sprintf("*_inShuf%d<>", sgi)))
		matMul := matConst.GF2P8Affine(inOp)
		matMuls = append(matMuls, matMul)
	}

	log.Printf("output permutation: %v", perm)
	// Gather the multiply results into the two output vectors.
	outLo, ok := genShuffle("*_outShufLo", (*[64]int)(perm[:64]), matMuls...)
	if !ok {
		log.Printf("bad number of inputs to final shuffle: %d != 1, 2, or 4", len(matMuls))
		return false
	}
	outHi, ok := genShuffle("*_outShufHi", (*[64]int)(perm[64:]), matMuls...)
	if !ok {
		log.Printf("bad number of inputs to final shuffle: %d != 1, 2, or 4", len(matMuls))
		return false
	}
	gen.Return(outLo, outHi)

	return true
}

// genShuffle emits ops that gather 64 bytes selected by perm into one
// [64]uint8 result. Each perm entry is a global byte index into the
// concatenation of args (entry/64 selects the arg, entry%64 the byte
// within it). Up to 4 distinct source vectors are supported; with more,
// it reports false.
func genShuffle(name string, perm *[64]int, args ...gen.Uint8x64) (gen.Uint8x64, bool) {
	// Construct flattened permutation.
	var vperm [64]byte
	// Get the inputs used by this permutation.
	var inputs []int
	for i, src := range perm {
		inputIdx := slices.Index(inputs, src/64)
		if inputIdx == -1 {
			inputIdx = len(inputs)
			inputs = append(inputs, src/64)
		}
		// Pack the source-vector index into bits 6-7 and the byte index
		// into bits 0-5.
		vperm[i] = byte(src%64 | (inputIdx << 6))
	}
	// Emit instructions for easy cases: a one- or two-source shuffle
	// maps directly onto a single shuffle op.
	switch len(inputs) {
	case 1:
		constOp := gen.ConstUint8x64(vperm, name)
		return args[inputs[0]].Shuffle(constOp), true
	case 2:
		constOp := gen.ConstUint8x64(vperm, name)
		return args[inputs[0]].Shuffle2(args[inputs[1]], constOp), true
	}

	// Harder case, we need to shuffle in from up to 2 more tables.
	//
	// Perform two shuffles. One shuffle will get its data from the first
	// two inputs, the other shuffle will get its data from the other one
	// or two inputs. All values each shuffle doesn't care about will
	// be zeroed, so the two results can simply be ORed together.
	var vperms [2][64]byte
	var masks [2]uint64
	for j, idx := range vperm {
		for i := range vperms {
			vperms[i][j] = 0xff // "Don't care"
		}
		if idx == 0xff {
			continue
		}
		// idx/128 picks which of the two shuffles handles byte j;
		// idx%128 is the two-source index within that shuffle.
		vperms[idx/128][j] = idx % 128
		masks[idx/128] |= uint64(1) << j
	}
	// Validate that the masks are fully disjoint.
	if masks[0]^masks[1] != ^uint64(0) {
		panic("bad shuffle!")
	}
	// Generate constants.
	constOps := make([]gen.Uint8x64, len(vperms))
	for i, v := range vperms {
		constOps[i] = gen.ConstUint8x64(v, name+strconv.Itoa(i))
	}
	// Generate shuffles.
	switch len(inputs) {
	case 3:
		r0 := args[inputs[0]].Shuffle2Zeroed(args[inputs[1]], constOps[0], gen.ConstMask64(masks[0]))
		r1 := args[inputs[2]].ShuffleZeroed(constOps[1], gen.ConstMask64(masks[1]))
		return r0.ToUint64x8().Or(r1.ToUint64x8()).ToUint8x64(), true
	case 4:
		r0 := args[inputs[0]].Shuffle2Zeroed(args[inputs[1]], constOps[0], gen.ConstMask64(masks[0]))
		r1 := args[inputs[2]].Shuffle2Zeroed(args[inputs[3]], constOps[1], gen.ConstMask64(masks[1]))
		return r0.ToUint64x8().Or(r1.ToUint64x8()).ToUint8x64(), true
	}
	// Too many inputs. To support more, we'd need to separate tables much earlier.
	// Right now all the indices fit in a byte, but with >4 inputs they might not (>256 bytes).
	return args[0], false
}