1
2
3
4
5 package main
6
7 import (
8 "cmp"
9 "fmt"
10 "log"
11 "maps"
12 "reflect"
13 "regexp"
14 "slices"
15 "strconv"
16 "strings"
17
18 "simd/_gen/unify"
19
20 "golang.org/x/arch/x86/xeddata"
21 "gopkg.in/yaml.v3"
22 )
23
// Register classes reported by decodeReg.
const (
	// NOT_REG_CLASS is returned when the operand is not a register that
	// decodeReg recognizes.
	NOT_REG_CLASS = iota
	// VREG_CLASS is returned for vector registers (XMM/YMM/ZMM in decodeReg).
	VREG_CLASS
	// GREG_CLASS is returned for general-purpose registers (GPR32/GPR64 in
	// decodeReg).
	GREG_CLASS
)
29
30
31
// instVariant is a bit set selecting which variant of an instruction is being
// emitted; addOperandsToDef tests it with a bitwise AND.
type instVariant uint8

const (
	// instVariantNone selects the base form of an instruction.
	instVariantNone instVariant = 0

	// instVariantMasked selects the form in which optional read masks are
	// emitted (see addOperandsToDef). iota is 1 on this line, so the value
	// is 1<<1.
	instVariantMasked instVariant = 1 << iota
)
41
// operandRemarks counts the instructions loadXED skipped because their
// operands could not be decoded.
var operandRemarks int
43
44
45 func loadXED(xedPath string) []*unify.Value {
46
47
48 db, err := xeddata.NewDatabase(xedPath)
49 if err != nil {
50 log.Fatalf("open database: %v", err)
51 }
52
53 var defs []*unify.Value
54 type opData struct {
55 inst *xeddata.Inst
56 ops []operand
57 mem string
58 }
59
60 memOps := make(map[string][]opData, 0)
61 otherOps := make(map[string][]opData, 0)
62 appendDefs := func(inst *xeddata.Inst, ops []operand, addFields map[string]string) {
63 applyQuirks(inst, ops)
64
65 defsPos := len(defs)
66 defs = append(defs, instToUVal(inst, ops, addFields)...)
67
68 if *flagDebugXED {
69 for i := defsPos; i < len(defs); i++ {
70 y, _ := yaml.Marshal(defs[i])
71 fmt.Printf("==>\n%s\n", y)
72 }
73 }
74 }
75 err = xeddata.WalkInsts(xedPath, func(inst *xeddata.Inst) {
76 inst.Pattern = xeddata.ExpandStates(db, inst.Pattern)
77
78 switch {
79 case inst.RealOpcode == "N":
80 return
81 case !(strings.HasPrefix(inst.Extension, "AVX") || strings.HasPrefix(inst.Extension, "SHA")):
82
83 return
84 }
85
86 if *flagDebugXED {
87 fmt.Printf("%s:\n%+v\n", inst.Pos, inst)
88 }
89
90 ops, err := decodeOperands(db, strings.Fields(inst.Operands))
91 if err != nil {
92 operandRemarks++
93 if *Verbose {
94 log.Printf("%s: [%s] %s", inst.Pos, inst.Opcode(), err)
95 }
96 return
97 }
98 var data map[string][]opData
99 mem := checkMem(ops)
100 if mem == "vbcst" {
101
102
103 data = memOps
104 } else {
105 data = otherOps
106 }
107 opcode := inst.Opcode()
108 if _, ok := data[opcode]; !ok {
109 s := make([]opData, 1)
110 s[0] = opData{inst, ops, mem}
111 data[opcode] = s
112 } else {
113 data[opcode] = append(data[opcode], opData{inst, ops, mem})
114 }
115 })
116 for _, s := range otherOps {
117 for _, o := range s {
118 addFields := map[string]string{}
119 if o.mem == "noMem" {
120 opcode := o.inst.Opcode()
121
122
123
124 if ms, ok := memOps[opcode]; ok {
125 feat1, ok1 := decodeCPUFeature(o.inst)
126
127
128 var feat1Match, feat2Match string
129 matchIdx := -1
130 var featMismatchCnt int
131 outer:
132 for i, m := range ms {
133
134 var featMismatch bool
135 feat2, ok2 := decodeCPUFeature(m.inst)
136 if !ok1 || !ok2 {
137 continue
138 }
139 if feat1 != feat2 {
140 featMismatch = true
141 featMismatchCnt++
142 }
143 if len(o.ops) == len(m.ops) {
144 for j := range o.ops {
145 if reflect.TypeOf(o.ops[j]) == reflect.TypeOf(m.ops[j]) {
146 v1, ok3 := o.ops[j].(operandVReg)
147 v2, _ := m.ops[j].(operandVReg)
148 if !ok3 {
149 continue
150 }
151 if v1.vecShape != v2.vecShape {
152
153 continue outer
154 }
155 } else {
156 _, ok3 := o.ops[j].(operandVReg)
157 _, ok4 := m.ops[j].(operandMem)
158
159 if !ok3 || !ok4 {
160
161 continue outer
162 }
163 }
164 }
165
166 matchIdx = i
167 feat1Match = feat1
168 feat2Match = feat2
169 if featMismatchCnt > 1 {
170 panic("multiple feature mismatch vbcst memops detected, simdgen failed to distinguish")
171 }
172 if !featMismatch {
173
174 break
175 }
176 }
177 }
178
179 if matchIdx != -1 {
180 memOps[opcode] = append(memOps[opcode][:matchIdx], memOps[opcode][matchIdx+1:]...)
181
182
183 addFields["memFeatures"] = "vbcst"
184 if feat1Match != feat2Match {
185 addFields["memFeaturesData"] = fmt.Sprintf("feat1=%s;feat2=%s", feat1Match, feat2Match)
186 }
187 }
188 }
189 }
190 appendDefs(o.inst, o.ops, addFields)
191 }
192 }
193 for _, ms := range memOps {
194 for _, m := range ms {
195 if *Verbose {
196 log.Printf("mem op not merged: %s, %v\n", m.inst.Opcode(), m)
197 }
198 appendDefs(m.inst, m.ops, nil)
199 }
200 }
201 if err != nil {
202 log.Fatalf("walk insts: %v", err)
203 }
204
205 if len(unknownFeatures) > 0 {
206 if !*Verbose {
207 nInst := 0
208 for _, insts := range unknownFeatures {
209 nInst += len(insts)
210 }
211 log.Printf("%d unhandled CPU features for %d instructions (use -v for details)", len(unknownFeatures), nInst)
212 } else {
213 keys := slices.SortedFunc(maps.Keys(unknownFeatures), func(a, b cpuFeatureKey) int {
214 return cmp.Or(cmp.Compare(a.Extension, b.Extension),
215 cmp.Compare(a.ISASet, b.ISASet))
216 })
217 for _, key := range keys {
218 if key.ISASet == "" || key.ISASet == key.Extension {
219 log.Printf("unhandled Extension %s", key.Extension)
220 } else {
221 log.Printf("unhandled Extension %s and ISASet %s", key.Extension, key.ISASet)
222 }
223 log.Printf(" opcodes: %s", slices.Sorted(maps.Keys(unknownFeatures[key])))
224 }
225 }
226 }
227
228 return defs
229 }
230
// Opcode patterns consulted by applyQuirks. Neither pattern is anchored, so
// each matches anywhere within the opcode.
var (
	// maskRequiredRe selects the compress/expand opcodes, for which
	// applyQuirks marks every mask operand as non-optional.
	maskRequiredRe = regexp.MustCompile(`VPCOMPRESS[BWDQ]|VCOMPRESSP[SD]|VPEXPAND[BWDQ]|VEXPANDP[SD]`)
	// maskOptionalRe selects the compare opcodes, for which applyQuirks marks
	// every read mask operand as optional.
	maskOptionalRe = regexp.MustCompile(`VPCMP(EQ|GT|U)?[BWDQ]|VCMPP[SD]`)
)
235
236 func applyQuirks(inst *xeddata.Inst, ops []operand) {
237 opc := inst.Opcode()
238 switch {
239 case maskRequiredRe.MatchString(opc):
240
241
242 for i, op := range ops {
243 if op, ok := op.(operandMask); ok {
244 op.optional = false
245 ops[i] = op
246 }
247 }
248
249 case maskOptionalRe.MatchString(opc):
250
251 for i, op := range ops {
252 if op, ok := op.(operandMask); ok && op.action.r {
253 op.optional = true
254 ops[i] = op
255 }
256 }
257 }
258 }
259
// operandCommon holds the state shared by every operand kind. It is embedded
// in each concrete operand type, which thereby gains its common method.
type operandCommon struct {
	// action records how the instruction accesses the operand.
	action operandAction
}
263
264
265
266
// operandAction describes how an instruction accesses an operand. Values are
// obtained from XED action strings through actionEncoding.
type operandAction struct {
	r  bool // operand is read
	w  bool // operand is written
	cr bool // set for the "cr" and "crw" actions; presumably a conditional read — TODO confirm
	cw bool // set for the "cw" and "rcw" actions; presumably a conditional write — TODO confirm
}
273
// operandMem is a memory operand.
type operandMem struct {
	operandCommon
	vecShape
	elemBaseType scalarBaseType

	// vbcst is set by decodeOperand for memory operands of width "vv" that
	// carry one of the broadcast text attributes. For those, decodeOperand
	// sets both bits and elemBits to the element width.
	vbcst bool
	// unknown is set by decodeOperand for every other memory operand; the
	// shape and element type are then left at their zero values.
	unknown bool
}
284
// vecShape describes the size of an operand. It is comparable, so shapes can
// be tested for equality with ==.
type vecShape struct {
	elemBits  int    // width of one element, in bits
	bits      int    // width of the whole operand, in bits
	fixedName string // name of the fixed register, or "" if none; emitted as "fixedReg"
}
290
// operandVReg is a vector register operand. decodeOperand produces it when
// decodeReg reports VREG_CLASS.
type operandVReg struct {
	operandCommon
	vecShape
	elemBaseType scalarBaseType
}
296
// operandGReg is a general-purpose register operand. decodeOperand produces
// it for register classes other than VREG_CLASS, with bits and elemBits both
// set to the smaller of the register width and the element width.
type operandGReg struct {
	operandCommon
	vecShape
	elemBaseType scalarBaseType
}
302
303
304
305
306
307
308
309
// operandMask is a mask register operand. decodeOperand creates it without a
// shape; inferMaskSizes fills the shape in afterwards.
type operandMask struct {
	operandCommon
	vecShape

	// allMasks is set by inferMaskSizes when the instruction has only mask
	// operands, so no shape could be inferred. addToDef then emits the class
	// field only.
	allMasks bool

	// optional reports whether the mask may be left out. addOperandsToDef
	// emits an optional read mask only for the masked instruction variant.
	optional bool
}
320
// operandImm is an immediate operand.
type operandImm struct {
	operandCommon
	bits int // width of the immediate in bits, as reported by decodeType
}
325
// operand is implemented by every decoded operand kind.
type operand interface {
	// common returns the state shared by all operand kinds.
	common() operandCommon
	// addToDef adds the fields describing this operand to b.
	addToDef(b *unify.DefBuilder)
}
330
331 func strVal(s any) *unify.Value {
332 return unify.NewValue(unify.NewStringExact(fmt.Sprint(s)))
333 }
334
// common returns o itself. Because operandCommon is embedded in every
// concrete operand type, this one method provides the common half of the
// operand interface for all of them.
func (o operandCommon) common() operandCommon {
	return o
}
338
339 func (o operandMem) addToDef(b *unify.DefBuilder) {
340 b.Add("class", strVal("memory"))
341 if o.unknown {
342 return
343 }
344 baseDomain, err := unify.NewStringRegex(o.elemBaseType.regex())
345 if err != nil {
346 panic("parsing baseRe: " + err.Error())
347 }
348 b.Add("base", unify.NewValue(baseDomain))
349 b.Add("bits", strVal(o.bits))
350 if o.elemBits != o.bits {
351 b.Add("elemBits", strVal(o.elemBits))
352 }
353 }
354
355 func (o operandVReg) addToDef(b *unify.DefBuilder) {
356 baseDomain, err := unify.NewStringRegex(o.elemBaseType.regex())
357 if err != nil {
358 panic("parsing baseRe: " + err.Error())
359 }
360 b.Add("class", strVal("vreg"))
361 b.Add("bits", strVal(o.bits))
362 b.Add("base", unify.NewValue(baseDomain))
363
364
365 if o.elemBits != o.bits {
366 b.Add("elemBits", strVal(o.elemBits))
367 }
368 if o.fixedName != "" {
369 b.Add("fixedReg", strVal(o.fixedName))
370 }
371 }
372
373 func (o operandGReg) addToDef(b *unify.DefBuilder) {
374 baseDomain, err := unify.NewStringRegex(o.elemBaseType.regex())
375 if err != nil {
376 panic("parsing baseRe: " + err.Error())
377 }
378 b.Add("class", strVal("greg"))
379 b.Add("bits", strVal(o.bits))
380 b.Add("base", unify.NewValue(baseDomain))
381 if o.elemBits != o.bits {
382 b.Add("elemBits", strVal(o.elemBits))
383 }
384 if o.fixedName != "" {
385 b.Add("fixedReg", strVal(o.fixedName))
386 }
387 }
388
389 func (o operandMask) addToDef(b *unify.DefBuilder) {
390 b.Add("class", strVal("mask"))
391 if o.allMasks {
392
393 return
394 }
395 b.Add("elemBits", strVal(o.elemBits))
396 b.Add("bits", strVal(o.bits))
397 if o.fixedName != "" {
398 b.Add("fixedReg", strVal(o.fixedName))
399 }
400 }
401
402 func (o operandImm) addToDef(b *unify.DefBuilder) {
403 b.Add("class", strVal("immediate"))
404 b.Add("bits", strVal(o.bits))
405 }
406
// actionEncoding maps the action strings accepted by decodeOperand to their
// operandAction. An action string missing from this map makes decodeOperand
// return an "unknown action" error.
var actionEncoding = map[string]operandAction{
	"r":   {r: true},
	"cr":  {r: true, cr: true},
	"w":   {w: true},
	"cw":  {w: true, cw: true},
	"rw":  {r: true, w: true},
	"crw": {r: true, w: true, cr: true}, // the "c" qualifies the read
	"rcw": {r: true, w: true, cw: true}, // the "c" qualifies the write
}
416
417 func decodeOperand(db *xeddata.Database, operand string) (operand, error) {
418 op, err := xeddata.NewOperand(db, operand)
419 if err != nil {
420 log.Fatalf("parsing operand %q: %v", operand, err)
421 }
422 if *flagDebugXED {
423 fmt.Printf(" %+v\n", op)
424 }
425
426 if strings.HasPrefix(op.Name, "EMX_BROADCAST") {
427
428
429
430
431 return nil, nil
432 }
433
434
435
436 action, ok := actionEncoding[op.Action]
437 if !ok {
438 return nil, fmt.Errorf("unknown action %q", op.Action)
439 }
440 common := operandCommon{action: action}
441
442 lhs := op.NameLHS()
443 if strings.HasPrefix(lhs, "MEM") {
444
445
446 if op.Width == "vv" && (op.Attributes["TXT=BCASTSTR"] ||
447 op.Attributes["TXT=VPBROADCASTD"]) {
448 baseType, elemBits, ok := decodeType(op)
449 if !ok {
450 return nil, fmt.Errorf("failed to decode memory width %q", operand)
451 }
452
453
454
455
456
457 shape := vecShape{elemBits: elemBits, bits: elemBits}
458 return operandMem{
459 operandCommon: common,
460 vecShape: shape,
461 elemBaseType: baseType,
462 vbcst: true,
463 unknown: false,
464 }, nil
465 }
466
467
468 return operandMem{
469 operandCommon: common,
470 unknown: true,
471 }, nil
472 } else if strings.HasPrefix(lhs, "REG") {
473 if op.Width == "mskw" {
474
475
476
477
478 return operandMask{
479 operandCommon: common,
480 optional: op.Attributes["TXT=ZEROSTR"],
481 }, nil
482 } else {
483 class, regBits, fixedReg := decodeReg(op)
484 if class == NOT_REG_CLASS {
485 return nil, fmt.Errorf("failed to decode register %q", operand)
486 }
487 baseType, elemBits, ok := decodeType(op)
488 if !ok {
489 return nil, fmt.Errorf("failed to decode register width %q", operand)
490 }
491 shape := vecShape{elemBits: elemBits, bits: regBits, fixedName: fixedReg}
492 if class == VREG_CLASS {
493 return operandVReg{
494 operandCommon: common,
495 vecShape: shape,
496 elemBaseType: baseType,
497 }, nil
498 }
499
500 m := min(shape.bits, shape.elemBits)
501 shape.bits, shape.elemBits = m, m
502 return operandGReg{
503 operandCommon: common,
504 vecShape: shape,
505 elemBaseType: baseType,
506 }, nil
507
508 }
509 } else if strings.HasPrefix(lhs, "IMM") {
510 _, bits, ok := decodeType(op)
511 if !ok {
512 return nil, fmt.Errorf("failed to decode register width %q", operand)
513 }
514 return operandImm{
515 operandCommon: common,
516 bits: bits,
517 }, nil
518 }
519
520
521 return nil, fmt.Errorf("unknown operand LHS %q in %q", lhs, operand)
522 }
523
524 func decodeOperands(db *xeddata.Database, operands []string) (ops []operand, err error) {
525
526 for _, o := range operands {
527 op, err := decodeOperand(db, o)
528 if err != nil {
529 return nil, err
530 }
531 if op != nil {
532 ops = append(ops, op)
533 }
534 }
535
536
537
538 if err := inferMaskSizes(ops); err != nil {
539 return nil, fmt.Errorf("%w in operands %+v", err, operands)
540 }
541
542 return ops, nil
543 }
544
545 func inferMaskSizes(ops []operand) error {
546
547
548
549
550
551
552
553
554
555
556
557
558
559 inferMask := func(r, w bool) error {
560 var masks []int
561 var rSizes, wSizes, sizes []vecShape
562 allMasks := true
563 hasWMask := false
564 for i, op := range ops {
565 action := op.common().action
566 if _, ok := op.(operandMask); ok {
567 if action.r && action.w {
568 return fmt.Errorf("unexpected rw mask")
569 }
570 if action.r == r || action.w == w {
571 masks = append(masks, i)
572 }
573 if action.w {
574 hasWMask = true
575 }
576 } else {
577 allMasks = false
578 if reg, ok := op.(operandVReg); ok {
579 if action.r {
580 rSizes = append(rSizes, reg.vecShape)
581 }
582 if action.w {
583 wSizes = append(wSizes, reg.vecShape)
584 }
585 }
586 }
587 }
588 if len(masks) == 0 {
589 return nil
590 }
591
592 if r {
593 sizes = rSizes
594 if len(sizes) == 0 {
595 sizes = wSizes
596 }
597 }
598 if w {
599 sizes = wSizes
600 if len(sizes) == 0 {
601 sizes = rSizes
602 }
603 }
604
605 if len(sizes) == 0 {
606
607 if allMasks {
608 for _, i := range masks {
609 m := ops[i].(operandMask)
610 m.allMasks = true
611 ops[i] = m
612 }
613 return nil
614 }
615 return fmt.Errorf("cannot infer mask size: no register operands")
616 }
617 shape, ok := singular(sizes)
618 if !ok {
619 if !hasWMask && len(wSizes) == 1 && len(masks) == 1 {
620
621
622 shape = wSizes[0]
623 } else {
624 return fmt.Errorf("cannot infer mask size: multiple register sizes %v", sizes)
625 }
626 }
627 for _, i := range masks {
628 m := ops[i].(operandMask)
629 m.vecShape = shape
630 ops[i] = m
631 }
632 return nil
633 }
634 if err := inferMask(true, false); err != nil {
635 return err
636 }
637 if err := inferMask(false, true); err != nil {
638 return err
639 }
640 return nil
641 }
642
643
644
645
646
647 func addOperandsToDef(ops []operand, instDB *unify.DefBuilder, variant instVariant) {
648 var inVals, inVar, outVals []*unify.Value
649 asmPos := 0
650 for _, op := range ops {
651 var db unify.DefBuilder
652 op.addToDef(&db)
653 db.Add("asmPos", unify.NewValue(unify.NewStringExact(fmt.Sprint(asmPos))))
654
655 action := op.common().action
656 asmCount := 1
657 if action.r {
658 inVal := unify.NewValue(db.Build())
659
660 if mask, ok := op.(operandMask); ok && mask.optional {
661 if variant&instVariantMasked != 0 {
662 inVar = append(inVar, inVal)
663 } else {
664
665 asmCount = 0
666 }
667 } else {
668
669 inVals = append(inVals, inVal)
670 }
671 }
672 if action.w {
673 outVal := unify.NewValue(db.Build())
674 outVals = append(outVals, outVal)
675 }
676
677 asmPos += asmCount
678 }
679
680 instDB.Add("in", unify.NewValue(unify.NewTuple(inVals...)))
681 instDB.Add("inVariant", unify.NewValue(unify.NewTuple(inVar...)))
682 instDB.Add("out", unify.NewValue(unify.NewTuple(outVals...)))
683 memFeatures := checkMem(ops)
684 if memFeatures != "noMem" {
685 instDB.Add("memFeatures", unify.NewValue(unify.NewStringExact(memFeatures)))
686 }
687 }
688
689
690
691 func checkMem(ops []operand) string {
692 memState := "noMem"
693 var mem *operandMem
694 memCnt := 0
695 for _, op := range ops {
696 if m, ok := op.(operandMem); ok {
697 mem = &m
698 memCnt++
699 }
700 }
701 if mem != nil {
702 if mem.unknown {
703 memState = "unknown"
704 } else if memCnt > 1 {
705 memState = "tooManyMem"
706 } else {
707
708
709
710
711 memState = "vbcst"
712 }
713 }
714 return memState
715 }
716
717 func instToUVal(inst *xeddata.Inst, ops []operand, addFields map[string]string) []*unify.Value {
718 feature, ok := decodeCPUFeature(inst)
719 if !ok {
720 return nil
721 }
722
723 var vals []*unify.Value
724 vals = append(vals, instToUVal1(inst, ops, feature, instVariantNone, addFields))
725 if hasOptionalMask(ops) {
726 vals = append(vals, instToUVal1(inst, ops, feature, instVariantMasked, addFields))
727 }
728 return vals
729 }
730
731 func instToUVal1(inst *xeddata.Inst, ops []operand, feature string, variant instVariant, addFields map[string]string) *unify.Value {
732 var db unify.DefBuilder
733 db.Add("goarch", unify.NewValue(unify.NewStringExact("amd64")))
734 db.Add("asm", unify.NewValue(unify.NewStringExact(inst.Opcode())))
735 addOperandsToDef(ops, &db, variant)
736 db.Add("cpuFeature", unify.NewValue(unify.NewStringExact(feature)))
737 for k, v := range addFields {
738 db.Add(k, unify.NewValue(unify.NewStringExact(v)))
739 }
740
741 if strings.Contains(inst.Pattern, "ZEROING=0") {
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757 db.Add("zeroing", unify.NewValue(unify.NewStringExact("false")))
758 }
759 pos := unify.Pos{Path: inst.Pos.Path, Line: inst.Pos.Line}
760 return unify.NewValuePos(db.Build(), pos)
761 }
762
763
764
765 func decodeCPUFeature(inst *xeddata.Inst) (string, bool) {
766 key := cpuFeatureKey{
767 Extension: inst.Extension,
768 ISASet: isaSetStrip.ReplaceAllLiteralString(inst.ISASet, ""),
769 }
770 feat, ok := cpuFeatureMap[key]
771 if !ok {
772 imap := unknownFeatures[key]
773 if imap == nil {
774 imap = make(map[string]struct{})
775 unknownFeatures[key] = imap
776 }
777 imap[inst.Opcode()] = struct{}{}
778 return "", false
779 }
780 if feat == "ignore" {
781 return "", false
782 }
783 return feat, true
784 }
785
// isaSetStrip matches a trailing "_128", "_128N", "_256", "_256N" or "_512".
// decodeCPUFeature removes that suffix from an ISA set name before looking it
// up in cpuFeatureMap.
var isaSetStrip = regexp.MustCompile("_(128N?|256N?|512)$")
787
// cpuFeatureKey identifies a CPU feature by the XED Extension of an
// instruction and its ISASet as stripped by isaSetStrip.
type cpuFeatureKey struct {
	Extension, ISASet string
}
791
792
793
// cpuFeatureMap maps an (Extension, stripped ISASet) pair to the CPU feature
// string emitted as "cpuFeature". The special value "ignore" makes
// decodeCPUFeature drop the instruction without recording it in
// unknownFeatures.
var cpuFeatureMap = map[cpuFeatureKey]string{
	{"SHA", "SHA"}: "SHA",

	{"AVX", ""}:              "AVX",
	{"AVX_VNNI", "AVX_VNNI"}: "AVXVNNI",
	{"AVX2", ""}:             "AVX2",
	{"AVXAES", ""}:           "AVX, AES",

	// These AVX-512 ISA sets all map to the single feature name "AVX512".
	{"AVX512EVEX", "AVX512F"}:  "AVX512",
	{"AVX512EVEX", "AVX512CD"}: "AVX512",
	{"AVX512EVEX", "AVX512BW"}: "AVX512",
	{"AVX512EVEX", "AVX512DQ"}: "AVX512",

	// These AVX-512 ISA sets each map to their own feature name.
	{"AVX512EVEX", "AVX512_BITALG"}:    "AVX512BITALG",
	{"AVX512EVEX", "AVX512_GFNI"}:      "AVX512GFNI",
	{"AVX512EVEX", "AVX512_VBMI2"}:     "AVX512VBMI2",
	{"AVX512EVEX", "AVX512_VBMI"}:      "AVX512VBMI",
	{"AVX512EVEX", "AVX512_VNNI"}:      "AVX512VNNI",
	{"AVX512EVEX", "AVX512_VPOPCNTDQ"}: "AVX512VPOPCNTDQ",
	{"AVX512EVEX", "AVX512_VAES"}:      "AVX512VAES",

	// Deliberately skipped.
	{"AVX512EVEX", "AVX10_2_RC"}: "ignore",
}
822
// unknownFeatures records, for every feature key missing from cpuFeatureMap,
// the set of opcodes seen with that key. decodeCPUFeature fills it in and
// loadXED reports it.
var unknownFeatures = map[cpuFeatureKey]map[string]struct{}{}
824
825
826 func hasOptionalMask(ops []operand) bool {
827 for _, op := range ops {
828 if op, ok := op.(operandMask); ok && op.optional {
829 return true
830 }
831 }
832 return false
833 }
834
// singular returns the common value of xs and true when xs is non-empty and
// all of its elements are equal. Otherwise it returns the zero value of T and
// false.
func singular[T comparable](xs []T) (T, bool) {
	var zero T
	if len(xs) == 0 {
		return zero, false
	}
	first := xs[0]
	for i := 1; i < len(xs); i++ {
		if xs[i] != first {
			return zero, false
		}
	}
	return first, true
}
846
// fixedReg describes a specific register that an operand names directly, as
// listed in fixedRegMap.
type fixedReg struct {
	class int    // one of the *_CLASS register class constants
	name  string // register name that decodeReg returns for this operand
	width int    // register width in bits
}
852
// fixedRegMap lists the fixed registers decodeReg understands, keyed by the
// right-hand side of the XED operand name. Any other operand name that does
// not end in "()" is rejected by decodeReg.
var fixedRegMap = map[string]fixedReg{
	"XED_REG_XMM0": {VREG_CLASS, "x0", 128},
}
856
857
858
859
860 func decodeReg(op *xeddata.Operand) (class, width int, name string) {
861
862
863
864
865
866
867
868
869
870
871
872 if !strings.HasPrefix(op.NameLHS(), "REG") {
873 return NOT_REG_CLASS, 0, ""
874 }
875
876
877 rhs := op.NameRHS()
878 if !strings.HasSuffix(rhs, "()") {
879 if fixedReg, ok := fixedRegMap[rhs]; ok {
880 return fixedReg.class, fixedReg.width, fixedReg.name
881 }
882 return NOT_REG_CLASS, 0, ""
883 }
884 switch {
885 case strings.HasPrefix(rhs, "XMM_"):
886 return VREG_CLASS, 128, ""
887 case strings.HasPrefix(rhs, "YMM_"):
888 return VREG_CLASS, 256, ""
889 case strings.HasPrefix(rhs, "ZMM_"):
890 return VREG_CLASS, 512, ""
891 case strings.HasPrefix(rhs, "GPR64_"), strings.HasPrefix(rhs, "VGPR64_"):
892 return GREG_CLASS, 64, ""
893 case strings.HasPrefix(rhs, "GPR32_"), strings.HasPrefix(rhs, "VGPR32_"):
894 return GREG_CLASS, 32, ""
895 }
896 return NOT_REG_CLASS, 0, ""
897 }
898
// xtypeRe matches the plain xtypes handled generically by decodeType: a kind
// letter (i, u or f) followed by a decimal bit width.
var xtypeRe = regexp.MustCompile(`^([iuf])([0-9]+)$`)
900
901
902
903
// scalarBaseType is the base kind of a scalar element, independent of its
// width.
type scalarBaseType int

const (
	scalarBaseInt scalarBaseType = iota
	scalarBaseUint
	scalarBaseIntOrUint // either signed or unsigned integer
	scalarBaseFloat
	scalarBaseComplex
	scalarBaseBFloat
	scalarBaseHFloat
)

// regex returns the regular expression source that matches the base type
// names denoted by s. It panics when s is not one of the declared constants.
func (s scalarBaseType) regex() string {
	patterns := [...]string{
		scalarBaseInt:       "int",
		scalarBaseUint:      "uint",
		scalarBaseIntOrUint: "int|uint",
		scalarBaseFloat:     "float",
		scalarBaseComplex:   "complex",
		scalarBaseBFloat:    "BFloat",
		scalarBaseHFloat:    "HFloat",
	}
	if s < 0 || int(s) >= len(patterns) {
		panic(fmt.Sprintf("unknown scalar base type %d", s))
	}
	return patterns[s]
}
935
936 func decodeType(op *xeddata.Operand) (base scalarBaseType, bits int, ok bool) {
937
938
939
940
941
942
943
944
945
946 switch op.Xtype {
947
948
949 case "bf8":
950 return scalarBaseBFloat, 8, true
951 case "hf8":
952 return scalarBaseHFloat, 8, true
953 case "bf16":
954 return scalarBaseBFloat, 16, true
955 case "2f16":
956
957
958 return scalarBaseComplex, 32, true
959 case "2i8", "2I8":
960
961
962 return scalarBaseInt, 8, true
963 case "2u16", "2U16":
964
965
966 return scalarBaseUint, 16, true
967 case "2i16", "2I16":
968
969 return scalarBaseInt, 16, true
970 case "4u8", "4U8":
971
972 return scalarBaseUint, 8, true
973 case "4i8", "4I8":
974
975 return scalarBaseInt, 8, true
976 }
977
978
979 m := xtypeRe.FindStringSubmatch(op.Xtype)
980 if m == nil {
981
982 return 0, 0, false
983 }
984 bits, _ = strconv.Atoi(m[2])
985 switch m[1] {
986 case "i", "u":
987
988
989
990 return scalarBaseIntOrUint, bits, true
991 case "f":
992 return scalarBaseFloat, bits, true
993 default:
994 panic("unreachable")
995 }
996 }
997
View as plain text