1
2
3
4
5
6
7 package main
8
9 import (
10 "bytes"
11 "fmt"
12 "io"
13 "log"
14 "os"
15 "slices"
16 "strconv"
17
18 "internal/runtime/gc"
19 "internal/runtime/gc/internal/gen"
20 )
21
// header is the "generated code" notice. generate writes it to the output
// stream immediately before calling file.Compile.
const header = "// Code generated by mkasm.go. DO NOT EDIT.\n\n"
23
24 func main() {
25 generate("expand_amd64.s", genExpanders)
26 }
27
28 func generate(fileName string, genFunc func(*gen.File)) {
29 var buf bytes.Buffer
30 tee := io.MultiWriter(&buf, os.Stdout)
31
32 file := gen.NewFile(tee)
33
34 genFunc(file)
35
36 fmt.Fprintf(tee, header)
37 file.Compile()
38
39 f, err := os.Create(fileName)
40 if err != nil {
41 log.Fatal(err)
42 }
43 defer f.Close()
44 _, err = f.Write(buf.Bytes())
45 if err != nil {
46 log.Fatal(err)
47 }
48 }
49
50 func genExpanders(file *gen.File) {
51 gcExpandersAVX512 := make([]*gen.Func, len(gc.SizeClassToSize))
52 for sc, ob := range gc.SizeClassToSize {
53 if gc.SizeClassToNPages[sc] != 1 {
54
55
56 continue
57 }
58 if ob > gc.MinSizeForMallocHeader {
59
60 break
61 }
62
63 xf := int(ob) / 8
64 log.Printf("size class %d bytes, expansion %dx", ob, xf)
65
66 fn := gen.NewFunc(fmt.Sprintf("expandAVX512_%d<>", xf))
67 ptrObjBits := gen.Arg[gen.Ptr[gen.Uint8x64]](fn)
68
69 if xf == 1 {
70 expandIdentity(ptrObjBits)
71 } else {
72 ok := gfExpander(xf, ptrObjBits)
73 if !ok {
74 log.Printf("failed to generate expander for size class %d", sc)
75 }
76 }
77 file.AddFunc(fn)
78 gcExpandersAVX512[sc] = fn
79 }
80
81
82 file.AddConst("·gcExpandersAVX512", gcExpandersAVX512)
83 }
84
85
// mat8x8 is an 8x8 bit matrix stored as eight one-byte rows. It is a
// comparable value type, so it can be used as a map key and compared with ==.
type mat8x8 struct {
	mat [8]uint8
}

// matGroupToVec packs each of the eight matrices into a single uint64.
//
// Within a packed word the rows appear in reverse byte order: row 0 lands in
// the most significant byte and row 7 in the least significant byte.
func matGroupToVec(mats *[8]mat8x8) [8]uint64 {
	var packed [8]uint64
	for m := range mats {
		var word uint64
		// Shifting in one row at a time leaves row 0 in the top byte.
		for _, row := range mats[m].mat {
			word = word<<8 | uint64(row)
		}
		packed[m] = word
	}
	return packed
}
100
101
102 func expandIdentity(ptrObjBits gen.Ptr[gen.Uint8x64]) {
103 objBitsLo := gen.Deref(ptrObjBits)
104 objBitsHi := gen.Deref(ptrObjBits.AddConst(64))
105 gen.Return(objBitsLo, objBitsHi)
106 }
107
108
109
110
111
112
113
114
115
116
117
118
119
120
// gfExpander emits the body of an expander in which every input bit is
// repeated f times: output bit o is taken from input bit o/f, for 1024 output
// bits (128 output bytes). The result is returned from the generated function
// as two 64-byte vectors (low half, high half).
//
// The expansion is built from three stages per "super-group": a byte shuffle
// of the input, a GF2P8Affine op with a vector of 8x8 bit matrices, and a
// final byte shuffle (genShuffle) that gathers the results into output order.
//
// It returns false, after logging the reason, if an output byte would need
// bits from more than one input byte, if the single-matrix mode finds
// super-groups with differing matrices, or if genShuffle rejects the final
// shuffle. Ops emitted before the failure are not undone.
func gfExpander(f int, ptrObjBits gen.Ptr[gen.Uint8x64]) bool {
	// Only the first 64-byte vector behind the pointer is read here.
	objBits := gen.Deref(ptrObjBits)

	// A term describes how one output byte is produced: which input byte it
	// reads and the bit-selection matrix applied to that byte.
	type term struct {
		iByte, oByte int
		mat          mat8x8
	}
	var terms []term

	// Build one term per output byte.
	for oByte := 0; oByte < 1024/8; oByte++ {
		var byteMat mat8x8
		iByte := -1
		for oBit := oByte * 8; oBit < oByte*8+8; oBit++ {
			iBit := oBit / f
			if iByte == -1 {
				iByte = iBit / 8
			} else if iByte != iBit/8 {
				// All eight bits of an output byte must come from the same
				// input byte, since each term carries a single iByte.
				log.Printf("output byte %d straddles input bytes %d and %d", oByte, iByte, iBit/8)
				return false
			}
			// Row oBit%8 has exactly one bit set: the position (within the
			// input byte) of the input bit that feeds this output bit.
			byteMat.mat[oBit%8] = 1 << (iBit % 8)
		}
		terms = append(terms, term{iByte, oByte, byteMat})
	}

	if false {
		// Disabled debugging aid: prints one comma-separated line per output
		// byte with one column per input byte; a cell holds a letter that
		// identifies the matrix connecting that input byte to that output
		// byte, or is empty.
		maxIByte, maxOByte := 0, 0
		for _, term := range terms {
			maxIByte = max(maxIByte, term.iByte)
			maxOByte = max(maxOByte, term.oByte)
		}
		iToO := make([][]rune, maxIByte+1)
		for i := range iToO {
			iToO[i] = make([]rune, maxOByte+1)
		}
		matMap := make(map[mat8x8]int)
		for _, term := range terms {
			i, ok := matMap[term.mat]
			if !ok {
				i = len(matMap)
				matMap[term.mat] = i
			}
			iToO[term.iByte][term.oByte] = 'A' + rune(i)
		}
		for o := range maxOByte + 1 {
			fmt.Printf("%d", o)
			for i := range maxIByte + 1 {
				fmt.Printf(",")
				if mat := iToO[i][o]; mat != 0 {
					fmt.Printf("%c", mat)
				}
			}
			fmt.Println()
		}
	}

	// Terms are collected into groups of up to termsPerGroup terms that all
	// share one matrix; a group occupies eight consecutive byte lanes of the
	// input shuffle and one uint64 of the matrix vector. Up to
	// groupsPerSuperGroup groups form a super-group, which corresponds to one
	// shuffle + GF2P8Affine pair below.
	const termsPerGroup = 8
	const groupsPerSuperGroup = 8

	// matMap maps a matrix to the index of the group that is currently
	// accepting terms with that matrix. allMats is the set of all distinct
	// matrices seen.
	matMap := make(map[mat8x8]int)
	allMats := make(map[mat8x8]bool)
	var termGroups [][]term
	for _, term := range terms {
		allMats[term.mat] = true

		i, ok := matMap[term.mat]
		if ok && f > groupsPerSuperGroup {
			// For large f, a group may only hold terms whose output bytes
			// are in the same 64-byte half of the output as the group's
			// first term. Otherwise a new group is started, which also
			// becomes the open group for this matrix.
			// NOTE(review): presumably this limits how many inputs each
			// final output shuffle has to read — confirm.
			outRegister := termGroups[i][0].oByte / 64
			if term.oByte/64 != outRegister {
				ok = false
			}
		}
		if !ok {
			// No open group for this matrix: start a new one.
			i = len(termGroups)
			matMap[term.mat] = i
			termGroups = append(termGroups, nil)
		}

		termGroups[i] = append(termGroups[i], term)

		if len(termGroups[i]) == termsPerGroup {
			// The group is full; the next term with this matrix opens a
			// fresh group.
			delete(matMap, term.mat)
		}
	}

	for i, termGroup := range termGroups {
		log.Printf("term group %d:", i)
		for _, term := range termGroup {
			log.Printf(" %+v", term)
		}
	}

	// Decide how many groups go into each super-group.
	var sgSize, nSuperGroups int
	oneMatVec := f <= groupsPerSuperGroup
	if oneMatVec {
		// Use the largest multiple of the number of distinct matrices that
		// fits in a super-group. The loop below then requires every
		// super-group to end up with the same matrix vector.
		sgSize = groupsPerSuperGroup / len(allMats) * len(allMats)
		nSuperGroups = (len(termGroups) + sgSize - 1) / sgSize
	} else {
		// Fill every super-group completely; each gets its own matrix
		// vector.
		sgSize = 8
		nSuperGroups = (len(termGroups) + groupsPerSuperGroup - 1) / groupsPerSuperGroup
	}

	// matGroup remembers the first super-group's matrices for the oneMatVec
	// consistency check. matMuls collects one GF2P8Affine result per
	// super-group. perm[o] is the position of output byte o within the
	// concatenation of all matMuls results (64 bytes per super-group).
	var matGroup [8]mat8x8
	var matMuls []gen.Uint8x64
	var perm [128]int
	for sgi := range nSuperGroups {
		// iperm selects, for each of the 64 byte lanes, which input byte to
		// place there. Lanes that receive no term keep the value 0xff.
		var iperm [64]uint8
		for i := range iperm {
			iperm[i] = 0xff
		}

		// Take the next sgSize groups (fewer for the last super-group).
		superGroup := termGroups[:min(len(termGroups), sgSize)]
		termGroups = termGroups[len(superGroup):]

		var thisMatGroup [8]mat8x8
		for i, termGroup := range superGroup {
			// All terms in a group share a matrix, so the first is
			// representative.
			thisMatGroup[i] = termGroup[0].mat
			for j, term := range termGroup {
				// Lane i*termsPerGroup+j reads this term's input byte...
				iperm[i*termsPerGroup+j] = uint8(term.iByte)
				// ...and its transformed value must end up in this term's
				// output byte.
				perm[term.oByte] = sgi*groupsPerSuperGroup*termsPerGroup + i*termsPerGroup + j
			}
		}
		log.Printf("input permutation %d: %v", sgi, iperm)

		// In single-matrix mode, every super-group must use exactly the
		// same matrices as the first one.
		if oneMatVec {
			if sgi == 0 {
				matGroup = thisMatGroup
			} else if matGroup != thisMatGroup {
				log.Printf("super-groups have different matrixes:\n%+v\n%+v", matGroup, thisMatGroup)
				return false
			}
		}

		// Emit the ops for this super-group: shuffle the input bytes into
		// lane order, then apply the matrix vector. A separately named
		// matrix constant is created per super-group in both modes.
		matConst := gen.ConstUint64x8(matGroupToVec(&thisMatGroup), fmt.Sprintf("*_mat%d<>", sgi))
		inOp := objBits.Shuffle(gen.ConstUint8x64(iperm, fmt.Sprintf("*_inShuf%d<>", sgi)))
		matMul := matConst.GF2P8Affine(inOp)
		matMuls = append(matMuls, matMul)
	}

	log.Printf("output permutation: %v", perm)

	// Gather the low and high 64 output bytes from the per-super-group
	// results.
	// NOTE(review): genShuffle fails based on the number of distinct inputs
	// referenced by perm (and also accepts three), so the counts in the log
	// messages below are approximate.
	outLo, ok := genShuffle("*_outShufLo", (*[64]int)(perm[:64]), matMuls...)
	if !ok {
		log.Printf("bad number of inputs to final shuffle: %d != 1, 2, or 4", len(matMuls))
		return false
	}
	outHi, ok := genShuffle("*_outShufHi", (*[64]int)(perm[64:]), matMuls...)
	if !ok {
		log.Printf("bad number of inputs to final shuffle: %d != 1, 2, or 4", len(matMuls))
		return false
	}
	gen.Return(outLo, outHi)

	return true
}
341
342 func genShuffle(name string, perm *[64]int, args ...gen.Uint8x64) (gen.Uint8x64, bool) {
343
344 var vperm [64]byte
345
346
347 var inputs []int
348 for i, src := range perm {
349 inputIdx := slices.Index(inputs, src/64)
350 if inputIdx == -1 {
351 inputIdx = len(inputs)
352 inputs = append(inputs, src/64)
353 }
354 vperm[i] = byte(src%64 | (inputIdx << 6))
355 }
356
357
358 switch len(inputs) {
359 case 1:
360 constOp := gen.ConstUint8x64(vperm, name)
361 return args[inputs[0]].Shuffle(constOp), true
362 case 2:
363 constOp := gen.ConstUint8x64(vperm, name)
364 return args[inputs[0]].Shuffle2(args[inputs[1]], constOp), true
365 }
366
367
368
369
370
371
372
373 var vperms [2][64]byte
374 var masks [2]uint64
375 for j, idx := range vperm {
376 for i := range vperms {
377 vperms[i][j] = 0xff
378 }
379 if idx == 0xff {
380 continue
381 }
382 vperms[idx/128][j] = idx % 128
383 masks[idx/128] |= uint64(1) << j
384 }
385
386
387 if masks[0]^masks[1] != ^uint64(0) {
388 panic("bad shuffle!")
389 }
390
391
392 constOps := make([]gen.Uint8x64, len(vperms))
393 for i, v := range vperms {
394 constOps[i] = gen.ConstUint8x64(v, name+strconv.Itoa(i))
395 }
396
397
398 switch len(inputs) {
399 case 3:
400 r0 := args[inputs[0]].Shuffle2Zeroed(args[inputs[1]], constOps[0], gen.ConstMask64(masks[0]))
401 r1 := args[inputs[2]].ShuffleZeroed(constOps[1], gen.ConstMask64(masks[1]))
402 return r0.ToUint64x8().Or(r1.ToUint64x8()).ToUint8x64(), true
403 case 4:
404 r0 := args[inputs[0]].Shuffle2Zeroed(args[inputs[1]], constOps[0], gen.ConstMask64(masks[0]))
405 r1 := args[inputs[2]].Shuffle2Zeroed(args[inputs[3]], constOps[1], gen.ConstMask64(masks[1]))
406 return r0.ToUint64x8().Or(r1.ToUint64x8()).ToUint8x64(), true
407 }
408
409
410
411 return args[0], false
412 }
413
View as plain text