1
2
3
4
5 package main
6
7 import (
8 "bufio"
9 "bytes"
10 "fmt"
11 "go/format"
12 "log"
13 "os"
14 "path/filepath"
15 "reflect"
16 "slices"
17 "sort"
18 "strings"
19 "text/template"
20 "unicode"
21 )
22
// templateOf parses temp as a text/template named name and returns it.
// It panics, wrapping the parse error, when temp is not a valid template.
func templateOf(temp, name string) *template.Template {
	parsed, parseErr := template.New(name).Parse(temp)
	if parseErr == nil {
		return parsed
	}
	panic(fmt.Errorf("failed to parse template %s: %w", name, parseErr))
}
30
// createPath creates (or truncates) the file at goroot/file, first creating
// any missing parent directories with mode 0755. The caller owns the
// returned file and is responsible for closing it.
func createPath(goroot string, file string) (*os.File, error) {
	target := filepath.Join(goroot, file)
	parent := filepath.Dir(target)
	if mkErr := os.MkdirAll(parent, 0755); mkErr != nil {
		return nil, fmt.Errorf("failed to create directory %s: %w", parent, mkErr)
	}
	created, createErr := os.Create(target)
	if createErr != nil {
		return nil, fmt.Errorf("failed to create file %s: %w", target, createErr)
	}
	return created, nil
}
44
45 func formatWriteAndClose(out *bytes.Buffer, goroot string, file string) {
46 b, err := format.Source(out.Bytes())
47 if err != nil {
48 fmt.Fprintf(os.Stderr, "%v\n", err)
49 fmt.Fprintf(os.Stderr, "%s\n", numberLines(out.Bytes()))
50 fmt.Fprintf(os.Stderr, "%v\n", err)
51 panic(err)
52 } else {
53 writeAndClose(b, goroot, file)
54 }
55 }
56
57 func writeAndClose(b []byte, goroot string, file string) {
58 ofile, err := createPath(goroot, file)
59 if err != nil {
60 panic(err)
61 }
62 ofile.Write(b)
63 ofile.Close()
64 }
65
66
67
// numberLines returns data with every line prefixed by its 1-based line
// number and ": ". Lines are split as bufio.ScanLines splits them, and every
// emitted line ends with a newline.
func numberLines(data []byte) string {
	var buf bytes.Buffer
	s := bufio.NewScanner(bytes.NewReader(data))
	// The default Scanner buffer caps a line at bufio.MaxScanTokenSize
	// (64 KiB); a longer line would stop the scan and silently truncate the
	// listing. Allowing the buffer to grow to hold all of data means no line
	// can exceed the limit.
	s.Buffer(nil, len(data)+1)
	for i := 1; s.Scan(); i++ {
		fmt.Fprintf(&buf, "%d: %s\n", i, s.Text())
	}
	return buf.String()
}
77
// The shape types below classify an Operation's operands. The first four are
// produced by (*Operation).shape; memShape is consumed by
// (*Operation).regShape.
type (
	// inShape classifies the mix of immediate and mask inputs.
	inShape uint8
	// outShape classifies the single output operand.
	outShape uint8
	// maskShape classifies how many mask inputs there are.
	maskShape uint8
	// immShape classifies the immediate input, if any.
	immShape uint8
	// memShape classifies memory operands.
	memShape uint8
)

const (
	// InvalidIn is the zero value of inShape.
	InvalidIn inShape = iota
	// PureVregIn: no immediate and no mask inputs.
	PureVregIn
	// OneKmaskIn: no immediate and exactly one mask input.
	OneKmaskIn
	// OneImmIn: an immediate and no mask inputs.
	OneImmIn
	// OneKmaskImmIn: an immediate and exactly one mask input.
	OneKmaskImmIn
	// PureKmaskIn: more than one mask input and nothing but masks.
	PureKmaskIn
)

const (
	// InvalidOut is the zero value of outShape.
	InvalidOut outShape = iota
	// NoOut: the operation has no output.
	NoOut
	// OneVregOut: the output has class "vreg".
	OneVregOut
	// OneGregOut: the output has class "greg".
	OneGregOut
	// OneKmaskOut: the output has class "mask".
	OneKmaskOut
	// OneVregOutAtIn: a "vreg" input at AsmPos 0 shares the output's AsmPos.
	OneVregOutAtIn
)

const (
	// InvalidMask is the zero value of maskShape.
	InvalidMask maskShape = iota
	// NoMask: there are no mask inputs.
	NoMask
	// OneMask: there is at least one mask input and the operation is not a
	// pure-mask operation.
	OneMask
	// AllMasks: reported together with PureKmaskIn.
	AllMasks
)

const (
	// InvalidImm is the zero value of immShape.
	InvalidImm immShape = iota
	// NoImm: there is no immediate input.
	NoImm
	// ConstImm: the immediate has Const set and ImmOffset unset.
	ConstImm
	// VarImm: the immediate has ImmOffset set and Const unset.
	VarImm
	// ConstVarImm: the immediate has both Const and ImmOffset set.
	ConstVarImm
)

const (
	// InvalidMem is the zero value of memShape.
	InvalidMem memShape = iota
	// NoMem: presumably "no memory operand"; not referenced by the code
	// visible here.
	NoMem
	// VregMemIn: the only value under which regShape accepts a "memory"
	// operand; such an input is counted as a vector-register input.
	VregMemIn
)
122
123
124
125
126
127
128
129 func (op *Operation) shape() (shapeIn inShape, shapeOut outShape, maskType maskShape, immType immShape,
130 opNoImm Operation) {
131 if len(op.Out) > 1 {
132 panic(fmt.Errorf("simdgen only supports 1 output: %s", op))
133 }
134 var outputReg int
135 if len(op.Out) == 1 {
136 outputReg = op.Out[0].AsmPos
137 if op.Out[0].Class == "vreg" {
138 shapeOut = OneVregOut
139 } else if op.Out[0].Class == "greg" {
140 shapeOut = OneGregOut
141 } else if op.Out[0].Class == "mask" {
142 shapeOut = OneKmaskOut
143 } else {
144 panic(fmt.Errorf("simdgen only supports output of class vreg or mask: %s", op))
145 }
146 } else {
147 shapeOut = NoOut
148
149
150 panic(fmt.Errorf("simdgen only supports 1 output: %s", op))
151 }
152 hasImm := false
153 maskCount := 0
154 hasVreg := false
155 for _, in := range op.In {
156 if in.AsmPos == outputReg {
157 if shapeOut != OneVregOutAtIn && in.AsmPos == 0 && in.Class == "vreg" {
158 shapeOut = OneVregOutAtIn
159 } else {
160 panic(fmt.Errorf("simdgen only support output and input sharing the same position case of \"the first input is vreg and the only output\": %s", op))
161 }
162 }
163 if in.Class == "immediate" {
164
165
166 if *in.Bits != 8 {
167 panic(fmt.Errorf("simdgen only supports immediates of 8 bits: %s", op))
168 }
169 hasImm = true
170 } else if in.Class == "mask" {
171 maskCount++
172 } else {
173 hasVreg = true
174 }
175 }
176 opNoImm = *op
177
178 removeImm := func(o *Operation) {
179 o.In = o.In[1:]
180 }
181 if hasImm {
182 removeImm(&opNoImm)
183 if op.In[0].Const != nil {
184 if op.In[0].ImmOffset != nil {
185 immType = ConstVarImm
186 } else {
187 immType = ConstImm
188 }
189 } else if op.In[0].ImmOffset != nil {
190 immType = VarImm
191 } else {
192 panic(fmt.Errorf("simdgen requires imm to have at least one of ImmOffset or Const set: %s", op))
193 }
194 } else {
195 immType = NoImm
196 }
197 if maskCount == 0 {
198 maskType = NoMask
199 } else {
200 maskType = OneMask
201 }
202 checkPureMask := func() bool {
203 if hasImm {
204 panic(fmt.Errorf("simdgen does not support immediates in pure mask operations: %s", op))
205 }
206 if hasVreg {
207 panic(fmt.Errorf("simdgen does not support more than 1 masks in non-pure mask operations: %s", op))
208 }
209 return false
210 }
211 if !hasImm && maskCount == 0 {
212 shapeIn = PureVregIn
213 } else if !hasImm && maskCount > 0 {
214 if maskCount == 1 {
215 shapeIn = OneKmaskIn
216 } else {
217 if checkPureMask() {
218 return
219 }
220 shapeIn = PureKmaskIn
221 maskType = AllMasks
222 }
223 } else if hasImm && maskCount == 0 {
224 shapeIn = OneImmIn
225 } else {
226 if maskCount == 1 {
227 shapeIn = OneKmaskImmIn
228 } else {
229 checkPureMask()
230 return
231 }
232 }
233 return
234 }
235
236
237 func (op *Operation) regShape(mem memShape) (string, error) {
238 _, _, _, _, gOp := op.shape()
239 var regInfo, fixedName string
240 var vRegInCnt, gRegInCnt, kMaskInCnt, vRegOutCnt, gRegOutCnt, kMaskOutCnt, memInCnt, memOutCnt int
241 for i, in := range gOp.In {
242 switch in.Class {
243 case "vreg":
244 vRegInCnt++
245 case "greg":
246 gRegInCnt++
247 case "mask":
248 kMaskInCnt++
249 case "memory":
250 if mem != VregMemIn {
251 panic("simdgen only knows VregMemIn in regShape")
252 }
253 memInCnt++
254 vRegInCnt++
255 }
256 if in.FixedReg != nil {
257 fixedName = fmt.Sprintf("%sAtIn%d", *in.FixedReg, i)
258 }
259 }
260 for i, out := range gOp.Out {
261
262 if out.Class == "vreg" || out.OverwriteClass != nil {
263 vRegOutCnt++
264 } else if out.Class == "greg" {
265 gRegOutCnt++
266 } else if out.Class == "mask" {
267 kMaskOutCnt++
268 } else if out.Class == "memory" {
269 if mem != VregMemIn {
270 panic("simdgen only knows VregMemIn in regShape")
271 }
272 vRegOutCnt++
273 memOutCnt++
274 }
275 if out.FixedReg != nil {
276 fixedName = fmt.Sprintf("%sAtIn%d", *out.FixedReg, i)
277 }
278 }
279 var inRegs, inMasks, outRegs, outMasks string
280
281 rmAbbrev := func(s string, i int) string {
282 if i == 0 {
283 return ""
284 }
285 if i == 1 {
286 return s
287 }
288 return fmt.Sprintf("%s%d", s, i)
289
290 }
291
292 inRegs = rmAbbrev("v", vRegInCnt)
293 inRegs += rmAbbrev("gp", gRegInCnt)
294 inMasks = rmAbbrev("k", kMaskInCnt)
295
296 outRegs = rmAbbrev("v", vRegOutCnt)
297 outRegs += rmAbbrev("gp", gRegOutCnt)
298 outMasks = rmAbbrev("k", kMaskOutCnt)
299
300 if kMaskInCnt == 0 && kMaskOutCnt == 0 && gRegInCnt == 0 && gRegOutCnt == 0 {
301
302 regInfo = fmt.Sprintf("v%d%d", vRegInCnt, vRegOutCnt)
303 } else if kMaskInCnt == 0 && kMaskOutCnt == 0 {
304 regInfo = fmt.Sprintf("%s%s", inRegs, outRegs)
305 } else {
306 regInfo = fmt.Sprintf("%s%s%s%s", inRegs, inMasks, outRegs, outMasks)
307 }
308 if memInCnt > 0 {
309 if memInCnt == 1 {
310 regInfo += "load"
311 } else {
312 panic("simdgen does not understand more than 1 mem op as of now")
313 }
314 }
315 if memOutCnt > 0 {
316 panic("simdgen does not understand memory as output as of now")
317 }
318 regInfo += fixedName
319 return regInfo, nil
320 }
321
322
323
324
325
326 func (op *Operation) sortOperand() {
327 priority := map[string]int{"immediate": 0, "vreg": 1, "greg": 1, "mask": 2}
328 sort.SliceStable(op.In, func(i, j int) bool {
329 pi := priority[op.In[i].Class]
330 pj := priority[op.In[j].Class]
331 if pi != pj {
332 return pi < pj
333 }
334 return op.In[i].AsmPos < op.In[j].AsmPos
335 })
336 }
337
338
339
340
341
342
343
344 func (op Operation) goNormalType() string {
345 if op.Go == "GetElem" {
346
347
348
349
350
351
352
353 at := 0
354 if op.In[at].Class == "immediate" {
355 at++
356 }
357 return fmt.Sprintf("%s%d", *op.Out[0].Base, *op.In[at].ElemBits)
358 }
359 panic(fmt.Errorf("Implement goNormalType for %v", op))
360 }
361
362
363
364 func (op Operation) SSAType() string {
365 if op.Out[0].Class == "greg" {
366 return fmt.Sprintf("types.Types[types.T%s]", strings.ToUpper(op.goNormalType()))
367 }
368 return fmt.Sprintf("types.TypeVec%d", *op.Out[0].Bits)
369 }
370
371
372
373 func (op Operation) GoType() string {
374 if op.Out[0].Class == "greg" {
375 return op.goNormalType()
376 }
377 return *op.Out[0].Go
378 }
379
380
381
382
383 func (op Operation) ImmName() string {
384 return op.Op0Name("constant")
385 }
386
387 func (o Operand) OpName(s string) string {
388 if n := o.Name; n != nil {
389 return *n
390 }
391 if o.Class == "mask" {
392 return "mask"
393 }
394 return s
395 }
396
397 func (o Operand) OpNameAndType(s string) string {
398 return o.OpName(s) + " " + *o.Go
399 }
400
401
402 func (op Operation) GoExported() string {
403 return capitalizeFirst(op.Go)
404 }
405
406
407 func (op Operation) DocumentationExported() string {
408 return strings.ReplaceAll(op.Documentation, op.Go, op.GoExported())
409 }
410
411
412
413 func (op Operation) Op0Name(s string) string {
414 return op.In[0].OpName(s)
415 }
416
417
418
419 func (op Operation) Op1Name(s string) string {
420 return op.In[1].OpName(s)
421 }
422
423
424
425 func (op Operation) Op2Name(s string) string {
426 return op.In[2].OpName(s)
427 }
428
429
430
431 func (op Operation) Op3Name(s string) string {
432 return op.In[3].OpName(s)
433 }
434
435
436
437
438 func (op Operation) Op0NameAndType(s string) string {
439 return op.In[0].OpNameAndType(s)
440 }
441
442
443
444
445 func (op Operation) Op1NameAndType(s string) string {
446 return op.In[1].OpNameAndType(s)
447 }
448
449
450
451
452 func (op Operation) Op2NameAndType(s string) string {
453 return op.In[2].OpNameAndType(s)
454 }
455
456
457
458
459 func (op Operation) Op3NameAndType(s string) string {
460 return op.In[3].OpNameAndType(s)
461 }
462
463
464
465
466 func (op Operation) Op4NameAndType(s string) string {
467 return op.In[4].OpNameAndType(s)
468 }
469
// Operation class names, looked up by classifyOp.
var (
	// immClasses is indexed by the total number of inputs, immediate
	// included. The "BAD" entries at indices 0 and 1 are placeholders that
	// classifyOp never selects.
	immClasses = []string{"BAD0Imm", "BAD1Imm", "op1Imm8", "op2Imm8", "op3Imm8", "op4Imm8"}
	// classes is indexed by the number of inputs once any immediate has been
	// removed. Index 0 is a placeholder that classifyOp never selects.
	classes = []string{"BAD0", "op1", "op2", "op3", "op4"}
)
472
473
474
475
476
477
478
479 func classifyOp(op Operation) (string, Operation, error) {
480 _, _, _, immType, gOp := op.shape()
481
482 var class string
483
484 if immType == VarImm || immType == ConstVarImm {
485 switch l := len(op.In); l {
486 case 1:
487 return "", op, fmt.Errorf("simdgen does not recognize this operation of only immediate input: %s", op)
488 case 2, 3, 4, 5:
489 class = immClasses[l]
490 default:
491 return "", op, fmt.Errorf("simdgen does not recognize this operation of input length %d: %s", len(op.In), op)
492 }
493 if order := op.OperandOrder; order != nil {
494 class += "_" + *order
495 }
496 return class, op, nil
497 } else {
498 switch l := len(gOp.In); l {
499 case 1, 2, 3, 4:
500 class = classes[l]
501 default:
502 return "", op, fmt.Errorf("simdgen does not recognize this operation of input length %d: %s", len(op.In), op)
503 }
504 if order := op.OperandOrder; order != nil {
505 class += "_" + *order
506 }
507 return class, gOp, nil
508 }
509 }
510
511 func checkVecAsScalar(op Operation) (idx int, err error) {
512 idx = -1
513 sSize := 0
514 for i, o := range op.In {
515 if o.TreatLikeAScalarOfSize != nil {
516 if idx == -1 {
517 idx = i
518 sSize = *o.TreatLikeAScalarOfSize
519 } else {
520 err = fmt.Errorf("simdgen only supports one TreatLikeAScalarOfSize in the arg list: %s", op)
521 return
522 }
523 }
524 }
525 if idx >= 0 {
526 if sSize != 8 && sSize != 16 && sSize != 32 && sSize != 64 {
527 err = fmt.Errorf("simdgen does not recognize this uint size: %d, %s", sSize, op)
528 return
529 }
530 }
531 return
532 }
533
534 func rewriteVecAsScalarRegInfo(op Operation, regInfo string) (string, error) {
535 idx, err := checkVecAsScalar(op)
536 if err != nil {
537 return "", err
538 }
539 if idx != -1 {
540 if regInfo == "v21" {
541 regInfo = "vfpv"
542 } else if regInfo == "v2kv" {
543 regInfo = "vfpkv"
544 } else if regInfo == "v31" {
545 regInfo = "v2fpv"
546 } else if regInfo == "v3kv" {
547 regInfo = "v2fpkv"
548 } else {
549 return "", fmt.Errorf("simdgen does not recognize uses of treatLikeAScalarOfSize with op regShape %s in op: %s", regInfo, op)
550 }
551 }
552 return regInfo, nil
553 }
554
555 func rewriteLastVregToMem(op Operation) Operation {
556 newIn := make([]Operand, len(op.In))
557 lastVregIdx := -1
558 for i := range len(op.In) {
559 newIn[i] = op.In[i]
560 if op.In[i].Class == "vreg" {
561 lastVregIdx = i
562 }
563 }
564
565 if lastVregIdx == -1 {
566 panic("simdgen cannot find one vreg in the mem op vreg original")
567 }
568 newIn[lastVregIdx].Class = "memory"
569 op.In = newIn
570
571 return op
572 }
573
574
575 func dedup(ops []Operation) (deduped []Operation) {
576 for _, op := range ops {
577 seen := false
578 for _, dop := range deduped {
579 if reflect.DeepEqual(op, dop) {
580 seen = true
581 break
582 }
583 }
584 if !seen {
585 deduped = append(deduped, op)
586 }
587 }
588 return
589 }
590
591 func (op Operation) GenericName() string {
592 if op.OperandOrder != nil {
593 switch *op.OperandOrder {
594 case "21Type1", "231Type1":
595
596 return op.Go + *op.In[1].Go
597 }
598 }
599 if op.In[0].Class == "immediate" {
600 return op.Go + *op.In[1].Go
601 }
602 return op.Go + *op.In[0].Go
603 }
604
605
606
607
608
609 func dedupGodef(ops []Operation) ([]Operation, error) {
610 seen := map[string][]Operation{}
611 for _, op := range ops {
612 _, _, _, _, gOp := op.shape()
613
614 gN := gOp.GenericName()
615 seen[gN] = append(seen[gN], op)
616 }
617 if *FlagReportDup {
618 for gName, dup := range seen {
619 if len(dup) > 1 {
620 log.Printf("Duplicate for %s:\n", gName)
621 for _, op := range dup {
622 log.Printf("%s\n", op)
623 }
624 }
625 }
626 return ops, nil
627 }
628 isAVX512 := func(op Operation) bool {
629 return strings.Contains(op.CPUFeature, "AVX512")
630 }
631 deduped := []Operation{}
632 for _, dup := range seen {
633 if len(dup) > 1 {
634 slices.SortFunc(dup, func(i, j Operation) int {
635
636 if !isAVX512(i) && isAVX512(j) {
637 return -1
638 }
639 if isAVX512(i) && !isAVX512(j) {
640 return 1
641 }
642 if i.CPUFeature != j.CPUFeature {
643 return strings.Compare(i.CPUFeature, j.CPUFeature)
644 }
645
646
647
648
649 if i.MemFeatures != nil && j.MemFeatures == nil {
650 return -1
651 }
652 if i.MemFeatures == nil && j.MemFeatures != nil {
653 return 1
654 }
655
656 return 0
657 })
658 }
659 deduped = append(deduped, dup[0])
660 }
661 slices.SortFunc(deduped, compareOperations)
662 return deduped, nil
663 }
664
665
666
667 func copyConstImm(ops []Operation) error {
668 for _, op := range ops {
669 if op.ConstImm == nil {
670 continue
671 }
672 _, _, _, immType, _ := op.shape()
673
674 if immType == ConstImm || immType == ConstVarImm {
675 op.In[0].Const = op.ConstImm
676 }
677
678
679 }
680 return nil
681 }
682
// capitalizeFirst returns s with its first rune mapped to upper case.
// The empty string is returned unchanged.
func capitalizeFirst(s string) string {
	runes := []rune(s)
	if len(runes) == 0 {
		return ""
	}
	runes[0] = unicode.ToUpper(runes[0])
	return string(runes)
}
692
693
694
695
696
697
698
699 func overwrite(ops []Operation) error {
700 hasClassOverwrite := false
701 overwrite := func(op []Operand, idx int, o Operation) error {
702 if op[idx].OverwriteElementBits != nil {
703 if op[idx].ElemBits == nil {
704 panic(fmt.Errorf("ElemBits is nil at operand %d of %v", idx, o))
705 }
706 *op[idx].ElemBits = *op[idx].OverwriteElementBits
707 *op[idx].Lanes = *op[idx].Bits / *op[idx].ElemBits
708 *op[idx].Go = fmt.Sprintf("%s%dx%d", capitalizeFirst(*op[idx].Base), *op[idx].ElemBits, *op[idx].Lanes)
709 }
710 if op[idx].OverwriteClass != nil {
711 if op[idx].OverwriteBase == nil {
712 panic(fmt.Errorf("simdgen: [OverwriteClass] must be set together with [OverwriteBase]: %s", op[idx]))
713 }
714 oBase := *op[idx].OverwriteBase
715 oClass := *op[idx].OverwriteClass
716 if oClass != "mask" {
717 panic(fmt.Errorf("simdgen: [Class] overwrite only supports overwritting to mask: %s", op[idx]))
718 }
719 if oBase != "int" {
720 panic(fmt.Errorf("simdgen: [Class] overwrite must set [OverwriteBase] to int: %s", op[idx]))
721 }
722 if op[idx].Class != "vreg" {
723 panic(fmt.Errorf("simdgen: [Class] overwrite must be overwriting [Class] from vreg: %s", op[idx]))
724 }
725 hasClassOverwrite = true
726 *op[idx].Base = oBase
727 op[idx].Class = oClass
728 *op[idx].Go = fmt.Sprintf("Mask%dx%d", *op[idx].ElemBits, *op[idx].Lanes)
729 } else if op[idx].OverwriteBase != nil {
730 oBase := *op[idx].OverwriteBase
731 *op[idx].Go = strings.ReplaceAll(*op[idx].Go, capitalizeFirst(*op[idx].Base), capitalizeFirst(oBase))
732 if op[idx].Class == "greg" {
733 *op[idx].Go = strings.ReplaceAll(*op[idx].Go, *op[idx].Base, oBase)
734 }
735 *op[idx].Base = oBase
736 }
737 return nil
738 }
739 for i, o := range ops {
740 hasClassOverwrite = false
741 for j := range ops[i].In {
742 if err := overwrite(ops[i].In, j, o); err != nil {
743 return err
744 }
745 if hasClassOverwrite {
746 return fmt.Errorf("simdgen does not support [OverwriteClass] in inputs: %s", ops[i])
747 }
748 }
749 for j := range ops[i].Out {
750 if err := overwrite(ops[i].Out, j, o); err != nil {
751 return err
752 }
753 }
754 if hasClassOverwrite {
755 for _, in := range ops[i].In {
756 if in.Class == "mask" {
757 return fmt.Errorf("simdgen only supports [OverwriteClass] for operations without mask inputs")
758 }
759 }
760 }
761 }
762 return nil
763 }
764
765
766
767
768
769
770
771 func reportXEDInconsistency(ops []Operation) error {
772 for _, o := range ops {
773 if o.NameAndSizeCheck != nil {
774 suffixSizeMap := map[byte]int{'B': 8, 'W': 16, 'D': 32, 'Q': 64}
775 checkOperand := func(opr Operand) error {
776 if opr.ElemBits == nil {
777 return fmt.Errorf("simdgen expects elemBits to be set when performing NameAndSizeCheck")
778 }
779 if v, ok := suffixSizeMap[o.Asm[len(o.Asm)-1]]; !ok {
780 return fmt.Errorf("simdgen expects asm to end with [BWDQ] when performing NameAndSizeCheck")
781 } else {
782 if v != *opr.ElemBits {
783 return fmt.Errorf("simdgen finds NameAndSizeCheck inconsistency in def: %s", o)
784 }
785 }
786 return nil
787 }
788 for _, in := range o.In {
789 if in.Class != "vreg" && in.Class != "mask" {
790 continue
791 }
792 if in.TreatLikeAScalarOfSize != nil {
793
794 continue
795 }
796 if err := checkOperand(in); err != nil {
797 return err
798 }
799 }
800 for _, out := range o.Out {
801 if err := checkOperand(out); err != nil {
802 return err
803 }
804 }
805 }
806 }
807 return nil
808 }
809
810 func (o *Operation) hasMaskedMerging(maskType maskShape, outType outShape) bool {
811
812 return o.OperandOrder == nil && o.SpecialLower == nil && maskType == OneMask && outType == OneVregOut &&
813 len(o.InVariant) == 1 && !strings.Contains(o.Asm, "BLEND") && !strings.Contains(o.Asm, "VMOVDQU")
814 }
815
// getVbcstData parses a string of the form "feat1=<a>;feat2=<b>" and returns
// <a> and <b>. <a> is everything up to the first ';' and must not be empty;
// <b> is the first whitespace-delimited word after "feat2=".
// It panics when s does not have that form.
//
// The parsing is done by hand because fmt.Sscanf does not implement C's
// "%[...]" scanset verb: a format containing "%[^;]" fails on every input.
func getVbcstData(s string) (feat1Match, feat2Match string) {
	malformed := func() {
		panic(fmt.Errorf("simdgen expects \"feat1=<name>;feat2=<name>\", got %q", s))
	}
	rest, ok := strings.CutPrefix(s, "feat1=")
	if !ok {
		malformed()
	}
	feat1Match, rest, ok = strings.Cut(rest, ";")
	if !ok || feat1Match == "" {
		malformed()
	}
	rest, ok = strings.CutPrefix(rest, "feat2=")
	if !ok {
		malformed()
	}
	words := strings.Fields(rest)
	if len(words) == 0 {
		malformed()
	}
	return feat1Match, words[0]
}
823
824 func (o Operation) String() string {
825 return pprints(o)
826 }
827
828 func (op Operand) String() string {
829 return pprints(op)
830 }
831
View as plain text