package main

import (
	"bytes"
	"fmt"
	"slices"
	"strings"
	"text/template"
)

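// tplRuleData holds the values substituted into ruleTemplates when
// rendering a single rewrite rule. tplName selects the template; the
// exported fields fill its placeholders.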
type tplRuleData struct {
	tplName        string
	GoOp           string
	GoType         string
	Args           string
	Asm            string
	ArgsOut        string
	MaskInConvert  string
	MaskOutConvert string
	ElementSize    int
	Size           int
	ArgsLoadAddr   string
	ArgsAddr       string
	FeatCheck      string
}

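// ruleTemplates holds one template per rule shape: plain vreg ops, ops
// taking a mask input and/or producing a mask output, shift-by-immediate
// lowerings, and load-merging variants (optionally feature-gated).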
var (
	ruleTemplates = template.Must(template.New("simdRules").Parse(`
{{define "pureVreg"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.Asm}} {{.ArgsOut}})
{{end}}
{{define "maskIn"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.Asm}} {{.ArgsOut}} ({{.MaskInConvert}} <types.TypeMask> mask))
{{end}}
{{define "maskOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.MaskOutConvert}} ({{.Asm}} {{.ArgsOut}}))
{{end}}
{{define "maskInMaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.MaskOutConvert}} ({{.Asm}} {{.ArgsOut}} ({{.MaskInConvert}} <types.TypeMask> mask)))
{{end}}
{{define "sftimm"}}({{.Asm}} x (MOVQconst [c])) => ({{.Asm}}const [uint8(c)] x)
{{end}}
{{define "masksftimm"}}({{.Asm}} x (MOVQconst [c]) mask) => ({{.Asm}}const [uint8(c)] x mask)
{{end}}
{{define "vregMem"}}({{.Asm}} {{.ArgsLoadAddr}}) && canMergeLoad(v, l) && clobber(l) => ({{.Asm}}load {{.ArgsAddr}})
{{end}}
{{define "vregMemFeatCheck"}}({{.Asm}} {{.ArgsLoadAddr}}) && {{.FeatCheck}} && canMergeLoad(v, l) && clobber(l) => ({{.Asm}}load {{.ArgsAddr}})
{{end}}
`))
)

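// MaskOptimization returns a rule that folds a masked move of an
// unmasked op's result into the masked form of the op, e.g.
//
//	(VMOVDQU32Masked512 (VPADDD512 x y) mask) => (VPADDDMasked512 x y mask)
//
// (The op pairing in the example is illustrative.) It returns "" if d
// is not a masked op, if no unmasked counterpart is in asmCheck, or if
// the op does not suit this rewrite.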
func (d tplRuleData) MaskOptimization(asmCheck map[string]bool) string {
	if !strings.Contains(d.Asm, "Masked") {
		return ""
	}
	asmNoMask := strings.ReplaceAll(d.Asm, "Masked", "")
	if !asmCheck[asmNoMask] {
		return ""
	}

	// Skip ops whose masked forms do not follow the simple
	// element-masking pattern this rewrite assumes.
	for _, nope := range []string{"VMOVDQU", "VPCOMPRESS", "VCOMPRESS", "VPEXPAND", "VEXPAND", "VPBLENDM", "VMOVUP"} {
		if strings.HasPrefix(asmNoMask, nope) {
			return ""
		}
	}

	// The vector width is encoded as a "128"/"256"/"512" suffix,
	// possibly followed by "const".
	size := asmNoMask[len(asmNoMask)-3:]
	if strings.HasSuffix(asmNoMask, "const") {
		sufLen := len("128const")
		size = asmNoMask[len(asmNoMask)-sufLen:][:3]
	}
	switch size {
	case "128", "256", "512":
	default:
		panic("unexpected operation size on " + d.Asm)
	}

	switch d.ElementSize {
	case 8, 16, 32, 64:
	default:
		panic(fmt.Errorf("unexpected operation width %d on %v", d.ElementSize, d.Asm))
	}

	return fmt.Sprintf("(VMOVDQU%dMasked%s (%s %s) mask) => (%s %s mask)\n", d.ElementSize, size, asmNoMask, d.Args, d.Asm, d.Args)
}

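// tmplOrder fixes the relative order of rules generated for the same
// operation, so that output is deterministic and more specific shapes
// (such as shift-by-immediate) come before more general ones.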
var tmplOrder = map[string]int{
	"masksftimm":    0,
	"sftimm":        1,
	"maskInMaskOut": 2,
	"maskOut":       3,
	"maskIn":        4,
	"pureVreg":      5,
	"vregMem":       6,
}

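// compareTplRuleData orders rules by Go op, type, and arguments,
// breaking remaining ties with tmplOrder.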
func compareTplRuleData(x, y tplRuleData) int {
	if c := compareNatural(x.GoOp, y.GoOp); c != 0 {
		return c
	}
	if c := compareNatural(x.GoType, y.GoType); c != 0 {
		return c
	}
	if c := compareNatural(x.Args, y.Args); c != 0 {
		return c
	}
	if x.tplName == y.tplName {
		return 0
	}
	xo, xok := tmplOrder[x.tplName]
	yo, yok := tmplOrder[y.tplName]
	if !xok {
		panic(fmt.Errorf("unexpected template name %s, please add it to tmplOrder", x.tplName))
	}
	if !yok {
		panic(fmt.Errorf("unexpected template name %s, please add it to tmplOrder", y.tplName))
	}
	return xo - yo
}
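// writeSIMDRules generates the machine lowering and optimization rules
// for the SIMD ops and returns them in a buffer, one rule per line,
// ready to be written to the rules file. For a binary op the pureVreg
// shape produces rules such as "(AddInt32x4 ...) => (VPADDD128 ...)";
// that exact pairing is illustrative.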
func writeSIMDRules(ops []Operation) *bytes.Buffer {
	buffer := new(bytes.Buffer)
	buffer.WriteString(generatedHeader + "\n")

	maskedMergeOpts := make(map[string]string)
	// s2n maps an element size in bits to its instruction suffix letter.
	s2n := map[int]string{8: "B", 16: "W", 32: "D", 64: "Q"}
	// asmCheck records every machine op that a rule was generated for.
	asmCheck := map[string]bool{}
	var allData []tplRuleData
	var optData []tplRuleData
	var memOptData []tplRuleData
	memOpSeen := make(map[string]bool)

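	// Build the rule data for each operation, splitting off special
	// lowerings (shift immediates, load merging, masked merging) as
	// they are encountered.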
	for _, opr := range ops {
		opInShape, opOutShape, maskType, immType, gOp := opr.shape()
		asm := machineOpName(maskType, gOp)
		vregInCnt := len(gOp.In)
		if maskType == OneMask {
			vregInCnt--
		}

		data := tplRuleData{
			GoOp: gOp.Go,
			Asm:  asm,
		}

		switch vregInCnt {
		case 1:
			data.Args = "x"
		case 2:
			data.Args = "x y"
		case 3:
			data.Args = "x y z"
		default:
			panic(fmt.Errorf("simdgen does not support more than 3 vreg in inputs"))
		}
		data.ArgsOut = data.Args
		switch immType {
		case ConstImm:
			data.ArgsOut = fmt.Sprintf("[%s] %s", *opr.In[0].Const, data.ArgsOut)
		case VarImm:
			data.Args = fmt.Sprintf("[a] %s", data.Args)
			data.ArgsOut = fmt.Sprintf("[a] %s", data.ArgsOut)
		case ConstVarImm:
			data.Args = fmt.Sprintf("[a] %s", data.Args)
			data.ArgsOut = fmt.Sprintf("[a+%s] %s", *opr.In[0].Const, data.ArgsOut)
		}

		goType := func(op Operation) string {
			if op.OperandOrder != nil {
				switch *op.OperandOrder {
				case "21Type1", "231Type1":
					// These operand orders take the rule's type
					// suffix from the second operand.
					return *op.In[1].Go
				}
			}
			return *op.In[0].Go
		}
		var tplName string

		if opOutShape == OneVregOut || opOutShape == OneVregOutAtIn || gOp.Out[0].OverwriteClass != nil {
			switch opInShape {
			case OneImmIn, PureVregIn:
				tplName = "pureVreg"
				data.GoType = goType(gOp)
			case OneKmaskImmIn, OneKmaskIn:
				tplName = "maskIn"
				data.GoType = goType(gOp)
				// The mask is the last input; its shape selects the
				// vector-mask-to-k-register conversion.
				rearIdx := len(gOp.In) - 1
				width := *gOp.In[rearIdx].ElemBits
				data.MaskInConvert = fmt.Sprintf("VPMOVVec%dx%dToM", width, *gOp.In[rearIdx].Lanes)
				data.ElementSize = width
			case PureKmaskIn:
				panic(fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations"))
			}
		} else if opOutShape == OneGregOut {
			tplName = "pureVreg"
			data.GoType = goType(gOp)
		} else {
			// Mask results live in k registers; convert them back to
			// a vector mask.
			data.MaskOutConvert = fmt.Sprintf("VPMOVMToVec%dx%d", *gOp.Out[0].ElemBits, *gOp.In[0].Lanes)
			switch opInShape {
			case OneImmIn, PureVregIn:
				tplName = "maskOut"
				data.GoType = goType(gOp)
			case OneKmaskImmIn, OneKmaskIn:
				tplName = "maskInMaskOut"
				data.GoType = goType(gOp)
				rearIdx := len(gOp.In) - 1
				data.MaskInConvert = fmt.Sprintf("VPMOVVec%dx%dToM", *gOp.In[rearIdx].ElemBits, *gOp.In[rearIdx].Lanes)
			case PureKmaskIn:
				panic(fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations"))
			}
		}

		if gOp.SpecialLower != nil {
			if *gOp.SpecialLower == "sftimm" {
				if data.GoType[0] == 'I' {
					// Shift-by-immediate lowering only applies to
					// integer types.
					sftImmData := data
					if tplName == "maskIn" {
						sftImmData.tplName = "masksftimm"
					} else {
						sftImmData.tplName = "sftimm"
					}
					allData = append(allData, sftImmData)
					asmCheck[sftImmData.Asm+"const"] = true
				}
			} else {
				panic("simdgen sees unknown special lower " + *gOp.SpecialLower + ", maybe implement it?")
			}
		}
		if gOp.MemFeatures != nil && *gOp.MemFeatures == "vbcst" {
			// Merge a load into the op only when no input is treated
			// as a scalar, and only once per machine op.
			selected := true
			for _, a := range gOp.In {
				if a.TreatLikeAScalarOfSize != nil {
					selected = false
					break
				}
			}
			if _, ok := memOpSeen[data.Asm]; ok {
				selected = false
			}
			if selected {
				memOpSeen[data.Asm] = true
				// The load replaces the last vreg operand.
				lastVreg := gOp.In[vregInCnt-1]
				if lastVreg.Class != "vreg" {
					panic(fmt.Errorf("simdgen expects vbcst replaced operand to be a vreg, got %v", lastVreg))
				}
				memOpData := data

				// Drop the single-letter last operand ("x y" -> "x ");
				// the load takes its place.
				origArgs := data.Args[:len(data.Args)-1]

				immArg := ""
				immArgCombineOff := " [off] "
				if immType != NoImm && immType != InvalidImm {
					// Strip the leading "[a] " immediate; on the
					// result side it is folded into the load offset.
					_, after, found := strings.Cut(origArgs, "]")
					if found {
						origArgs = after
					}
					immArg = "[c] "
					immArgCombineOff = " [makeValAndOff(int32(int8(c)),off)] "
				}
				memOpData.ArgsLoadAddr = immArg + origArgs + fmt.Sprintf("l:(VMOVDQUload%d {sym} [off] ptr mem)", *lastVreg.Bits)
				memOpData.ArgsAddr = "{sym}" + immArgCombineOff + origArgs + "ptr"
				if maskType == OneMask {
					memOpData.ArgsAddr += " mask"
					memOpData.ArgsLoadAddr += " mask"
				}
				memOpData.ArgsAddr += " mem"
				if gOp.MemFeaturesData != nil {
					_, feat2 := getVbcstData(*gOp.MemFeaturesData)
					knownFeatChecks := map[string]string{
						"AVX":    "v.Block.CPUfeatures.hasFeature(CPUavx)",
						"AVX2":   "v.Block.CPUfeatures.hasFeature(CPUavx2)",
						"AVX512": "v.Block.CPUfeatures.hasFeature(CPUavx512)",
					}
					memOpData.FeatCheck = knownFeatChecks[feat2]
					memOpData.tplName = "vregMemFeatCheck"
				} else {
					memOpData.tplName = "vregMem"
				}
				memOptData = append(memOptData, memOpData)
				asmCheck[memOpData.Asm+"load"] = true
			}
		}

		if gOp.hasMaskedMerging(maskType, opOutShape) {
			// The mask is the last input.
			maskElem := gOp.In[len(gOp.In)-1]
			if maskElem.Bits == nil {
				panic("mask has no bits")
			}
			if maskElem.ElemBits == nil {
				panic("mask has no elemBits")
			}
			if maskElem.Lanes == nil {
				panic("mask has no lanes")
			}
			switch *maskElem.Bits {
			case 128, 256:
				// On 128- and 256-bit vectors the blend is a plain
				// VPBLENDVB; rewriting it to a merging op requires
				// AVX-512.
				noMaskName := machineOpName(NoMask, gOp)
				ruleExisting, ok := maskedMergeOpts[noMaskName]
				rule := fmt.Sprintf("(VPBLENDVB%d dst (%s %s) mask) && v.Block.CPUfeatures.hasFeature(CPUavx512) => (%sMerging dst %s (VPMOVVec%dx%dToM <types.TypeMask> mask))\n",
					*maskElem.Bits, noMaskName, data.Args, data.Asm, data.Args, *maskElem.ElemBits, *maskElem.Lanes)
				if ok && ruleExisting != rule {
					panic("multiple masked merge rules for one op")
				}
				maskedMergeOpts[noMaskName] = rule
			case 512:
				// 512-bit vectors use the masked VPBLENDM form
				// directly.
				noMaskName := machineOpName(NoMask, gOp)
				ruleExisting, ok := maskedMergeOpts[noMaskName]
				rule := fmt.Sprintf("(VPBLENDM%sMasked%d dst (%s %s) mask) => (%sMerging dst %s mask)\n",
					s2n[*maskElem.ElemBits], *maskElem.Bits, noMaskName, data.Args, data.Asm, data.Args)
				if ok && ruleExisting != rule {
					panic("multiple masked merge rules for one op")
				}
				maskedMergeOpts[noMaskName] = rule
			}
		}

		if tplName == "pureVreg" && data.Args == data.ArgsOut {
			// Identical match and result argument lists can be
			// abbreviated with "...".
			data.Args = "..."
			data.ArgsOut = "..."
		}
		data.tplName = tplName
		if (opr.NoGenericOps != nil && *opr.NoGenericOps == "true") ||
			opr.SkipMaskedMethod() {
			optData = append(optData, data)
			continue
		}
		allData = append(allData, data)
		asmCheck[data.Asm] = true
	}

	slices.SortFunc(allData, compareTplRuleData)

	for _, data := range allData {
		if err := ruleTemplates.ExecuteTemplate(buffer, data.tplName, data); err != nil {
			panic(fmt.Errorf("failed to execute template %s for %s: %w", data.tplName, data.GoOp+data.GoType, err))
		}
	}

	// Ops diverted to optData still get mask-elision rules; the same
	// rule can be produced for several ops, so deduplicate.
	seen := make(map[string]bool)
	for _, data := range optData {
		if data.tplName == "maskIn" {
			rule := data.MaskOptimization(asmCheck)
			if rule == "" || seen[rule] {
				continue
			}
			seen[rule] = true
			buffer.WriteString(rule)
		}
	}

	// Emit masked-merge rules only when the unmasked op was actually
	// generated, in sorted order for deterministic output.
	maskedMergeOptsRules := []string{}
	for asm, rule := range maskedMergeOpts {
		if !asmCheck[asm] {
			continue
		}
		maskedMergeOptsRules = append(maskedMergeOptsRules, rule)
	}
	slices.Sort(maskedMergeOptsRules)
	for _, rule := range maskedMergeOptsRules {
		buffer.WriteString(rule)
	}

	for _, data := range memOptData {
		if err := ruleTemplates.ExecuteTemplate(buffer, data.tplName, data); err != nil {
			panic(fmt.Errorf("failed to execute template %s for %s: %w", data.tplName, data.Asm, err))
		}
	}

	return buffer
}