Source file src/simd/_gen/simdgen/gen_simdrules.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"slices"
    11  	"strings"
    12  	"text/template"
    13  )
    14  
// tplRuleData carries the values that ruleTemplates substitutes into a rule
// template, together with the name of the template to use. The exported
// fields are referenced by name from the templates; MaskOptimization also
// reads Asm, Args and ElementSize.
type tplRuleData struct {
	tplName        string // name of the ruleTemplates template to execute, e.g. "sftimm"
	GoOp           string // generic Go op name, e.g. "ShiftAllLeft"
	GoType         string // Go vector type name, e.g. "Uint32x8"
	Args           string // argument pattern of the generic op, e.g. "x y"; may be prefixed with "[a] " for variable immediates
	Asm            string // machine op name, e.g. "VPSLLD256"
	ArgsOut        string // arguments of the machine op, e.g. "x y"; may be prefixed with an aux int such as "[a]"
	MaskInConvert  string // op converting the vector mask argument to a mask, e.g. "VPMOVVec32x8ToM"
	MaskOutConvert string // op converting a mask result back to a vector, e.g. "VPMOVMToVec32x8"
	ElementSize    int    // element width in bits, e.g. 32 (only set for "maskIn"-shaped ops)
	Size           int    // e.g. 128. NOTE(review): not set or read anywhere in this file; confirm it is still needed.
	ArgsLoadAddr   string // [Args] with its last vreg arg being a concrete "(VMOVDQUload* ptr mem)", and might contain mask.
	ArgsAddr       string // [Args] with its last vreg arg being replaced by "ptr", and might contain mask, and with a "mem" at the end.
	FeatCheck      string // rule condition for the required CPU feature, e.g. "v.Block.CPUfeatures.hasFeature(CPUavx512)" -- for a ssa/_gen rules file.
}
    30  
    31  var (
    32  	ruleTemplates = template.Must(template.New("simdRules").Parse(`
    33  {{define "pureVreg"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.Asm}} {{.ArgsOut}})
    34  {{end}}
    35  {{define "maskIn"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.Asm}} {{.ArgsOut}} ({{.MaskInConvert}} <types.TypeMask> mask))
    36  {{end}}
    37  {{define "maskOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.MaskOutConvert}} ({{.Asm}} {{.ArgsOut}}))
    38  {{end}}
    39  {{define "maskInMaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.MaskOutConvert}} ({{.Asm}} {{.ArgsOut}} ({{.MaskInConvert}} <types.TypeMask> mask)))
    40  {{end}}
    41  {{define "sftimm"}}({{.Asm}} x (MOVQconst [c])) => ({{.Asm}}const [uint8(c)] x)
    42  {{end}}
    43  {{define "masksftimm"}}({{.Asm}} x (MOVQconst [c]) mask) => ({{.Asm}}const [uint8(c)] x mask)
    44  {{end}}
    45  {{define "vregMem"}}({{.Asm}} {{.ArgsLoadAddr}}) && canMergeLoad(v, l) && clobber(l) => ({{.Asm}}load {{.ArgsAddr}})
    46  {{end}}
    47  {{define "vregMemFeatCheck"}}({{.Asm}} {{.ArgsLoadAddr}}) && {{.FeatCheck}} && canMergeLoad(v, l) && clobber(l)=> ({{.Asm}}load {{.ArgsAddr}})
    48  {{end}}
    49  `))
    50  )
    51  
    52  func (d tplRuleData) MaskOptimization(asmCheck map[string]bool) string {
    53  	asmNoMask := d.Asm
    54  	if i := strings.Index(asmNoMask, "Masked"); i == -1 {
    55  		return ""
    56  	}
    57  	asmNoMask = strings.ReplaceAll(asmNoMask, "Masked", "")
    58  	if asmCheck[asmNoMask] == false {
    59  		return ""
    60  	}
    61  
    62  	for _, nope := range []string{"VMOVDQU", "VPCOMPRESS", "VCOMPRESS", "VPEXPAND", "VEXPAND", "VPBLENDM", "VMOVUP"} {
    63  		if strings.HasPrefix(asmNoMask, nope) {
    64  			return ""
    65  		}
    66  	}
    67  
    68  	size := asmNoMask[len(asmNoMask)-3:]
    69  	if strings.HasSuffix(asmNoMask, "const") {
    70  		sufLen := len("128const")
    71  		size = asmNoMask[len(asmNoMask)-sufLen:][:3]
    72  	}
    73  	switch size {
    74  	case "128", "256", "512":
    75  	default:
    76  		panic("Unexpected operation size on " + d.Asm)
    77  	}
    78  
    79  	switch d.ElementSize {
    80  	case 8, 16, 32, 64:
    81  	default:
    82  		panic(fmt.Errorf("Unexpected operation width %d on %v", d.ElementSize, d.Asm))
    83  	}
    84  
    85  	return fmt.Sprintf("(VMOVDQU%dMasked%s (%s %s) mask) => (%s %s mask)\n", d.ElementSize, size, asmNoMask, d.Args, d.Asm, d.Args)
    86  }
    87  
    88  // SSA rewrite rules need to appear in a most-to-least-specific order.  This works for that.
    89  var tmplOrder = map[string]int{
    90  	"masksftimm":    0,
    91  	"sftimm":        1,
    92  	"maskInMaskOut": 2,
    93  	"maskOut":       3,
    94  	"maskIn":        4,
    95  	"pureVreg":      5,
    96  	"vregMem":       6,
    97  }
    98  
    99  func compareTplRuleData(x, y tplRuleData) int {
   100  	if c := compareNatural(x.GoOp, y.GoOp); c != 0 {
   101  		return c
   102  	}
   103  	if c := compareNatural(x.GoType, y.GoType); c != 0 {
   104  		return c
   105  	}
   106  	if c := compareNatural(x.Args, y.Args); c != 0 {
   107  		return c
   108  	}
   109  	if x.tplName == y.tplName {
   110  		return 0
   111  	}
   112  	xo, xok := tmplOrder[x.tplName]
   113  	yo, yok := tmplOrder[y.tplName]
   114  	if !xok {
   115  		panic(fmt.Errorf("Unexpected template name %s, please add to tmplOrder", x.tplName))
   116  	}
   117  	if !yok {
   118  		panic(fmt.Errorf("Unexpected template name %s, please add to tmplOrder", y.tplName))
   119  	}
   120  	return xo - yo
   121  }
   122  
   123  // writeSIMDRules generates the lowering and rewrite rules for ssa and writes it to simdAMD64.rules
   124  // within the specified directory.
   125  func writeSIMDRules(ops []Operation) *bytes.Buffer {
   126  	buffer := new(bytes.Buffer)
   127  	buffer.WriteString(generatedHeader + "\n")
   128  
   129  	// asm -> masked merging rules
   130  	maskedMergeOpts := make(map[string]string)
   131  	s2n := map[int]string{8: "B", 16: "W", 32: "D", 64: "Q"}
   132  	asmCheck := map[string]bool{}
   133  	var allData []tplRuleData
   134  	var optData []tplRuleData    // for mask peephole optimizations, and other misc
   135  	var memOptData []tplRuleData // for memory peephole optimizations
   136  	memOpSeen := make(map[string]bool)
   137  
   138  	for _, opr := range ops {
   139  		opInShape, opOutShape, maskType, immType, gOp := opr.shape()
   140  		asm := machineOpName(maskType, gOp)
   141  		vregInCnt := len(gOp.In)
   142  		if maskType == OneMask {
   143  			vregInCnt--
   144  		}
   145  
   146  		data := tplRuleData{
   147  			GoOp: gOp.Go,
   148  			Asm:  asm,
   149  		}
   150  
   151  		if vregInCnt == 1 {
   152  			data.Args = "x"
   153  			data.ArgsOut = data.Args
   154  		} else if vregInCnt == 2 {
   155  			data.Args = "x y"
   156  			data.ArgsOut = data.Args
   157  		} else if vregInCnt == 3 {
   158  			data.Args = "x y z"
   159  			data.ArgsOut = data.Args
   160  		} else {
   161  			panic(fmt.Errorf("simdgen does not support more than 3 vreg in inputs"))
   162  		}
   163  		if immType == ConstImm {
   164  			data.ArgsOut = fmt.Sprintf("[%s] %s", *opr.In[0].Const, data.ArgsOut)
   165  		} else if immType == VarImm {
   166  			data.Args = fmt.Sprintf("[a] %s", data.Args)
   167  			data.ArgsOut = fmt.Sprintf("[a] %s", data.ArgsOut)
   168  		} else if immType == ConstVarImm {
   169  			data.Args = fmt.Sprintf("[a] %s", data.Args)
   170  			data.ArgsOut = fmt.Sprintf("[a+%s] %s", *opr.In[0].Const, data.ArgsOut)
   171  		}
   172  
   173  		goType := func(op Operation) string {
   174  			if op.OperandOrder != nil {
   175  				switch *op.OperandOrder {
   176  				case "21Type1", "231Type1":
   177  					// Permute uses operand[1] for method receiver.
   178  					return *op.In[1].Go
   179  				}
   180  			}
   181  			return *op.In[0].Go
   182  		}
   183  		var tplName string
   184  		// If class overwrite is happening, that's not really a mask but a vreg.
   185  		if opOutShape == OneVregOut || opOutShape == OneVregOutAtIn || gOp.Out[0].OverwriteClass != nil {
   186  			switch opInShape {
   187  			case OneImmIn:
   188  				tplName = "pureVreg"
   189  				data.GoType = goType(gOp)
   190  			case PureVregIn:
   191  				tplName = "pureVreg"
   192  				data.GoType = goType(gOp)
   193  			case OneKmaskImmIn:
   194  				fallthrough
   195  			case OneKmaskIn:
   196  				tplName = "maskIn"
   197  				data.GoType = goType(gOp)
   198  				rearIdx := len(gOp.In) - 1
   199  				// Mask is at the end.
   200  				width := *gOp.In[rearIdx].ElemBits
   201  				data.MaskInConvert = fmt.Sprintf("VPMOVVec%dx%dToM", width, *gOp.In[rearIdx].Lanes)
   202  				data.ElementSize = width
   203  			case PureKmaskIn:
   204  				panic(fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations"))
   205  			}
   206  		} else if opOutShape == OneGregOut {
   207  			tplName = "pureVreg" // TODO this will be wrong
   208  			data.GoType = goType(gOp)
   209  		} else {
   210  			// OneKmaskOut case
   211  			data.MaskOutConvert = fmt.Sprintf("VPMOVMToVec%dx%d", *gOp.Out[0].ElemBits, *gOp.In[0].Lanes)
   212  			switch opInShape {
   213  			case OneImmIn:
   214  				fallthrough
   215  			case PureVregIn:
   216  				tplName = "maskOut"
   217  				data.GoType = goType(gOp)
   218  			case OneKmaskImmIn:
   219  				fallthrough
   220  			case OneKmaskIn:
   221  				tplName = "maskInMaskOut"
   222  				data.GoType = goType(gOp)
   223  				rearIdx := len(gOp.In) - 1
   224  				data.MaskInConvert = fmt.Sprintf("VPMOVVec%dx%dToM", *gOp.In[rearIdx].ElemBits, *gOp.In[rearIdx].Lanes)
   225  			case PureKmaskIn:
   226  				panic(fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations"))
   227  			}
   228  		}
   229  
   230  		if gOp.SpecialLower != nil {
   231  			if *gOp.SpecialLower == "sftimm" {
   232  				if data.GoType[0] == 'I' {
   233  					// only do these for signed types, it is a duplicate rewrite for unsigned
   234  					sftImmData := data
   235  					if tplName == "maskIn" {
   236  						sftImmData.tplName = "masksftimm"
   237  					} else {
   238  						sftImmData.tplName = "sftimm"
   239  					}
   240  					allData = append(allData, sftImmData)
   241  					asmCheck[sftImmData.Asm+"const"] = true
   242  				}
   243  			} else {
   244  				panic("simdgen sees unknwon special lower " + *gOp.SpecialLower + ", maybe implement it?")
   245  			}
   246  		}
   247  		if gOp.MemFeatures != nil && *gOp.MemFeatures == "vbcst" {
   248  			// sanity check
   249  			selected := true
   250  			for _, a := range gOp.In {
   251  				if a.TreatLikeAScalarOfSize != nil {
   252  					selected = false
   253  					break
   254  				}
   255  			}
   256  			if _, ok := memOpSeen[data.Asm]; ok {
   257  				selected = false
   258  			}
   259  			if selected {
   260  				memOpSeen[data.Asm] = true
   261  				lastVreg := gOp.In[vregInCnt-1]
   262  				// sanity check
   263  				if lastVreg.Class != "vreg" {
   264  					panic(fmt.Errorf("simdgen expects vbcst replaced operand to be a vreg, but %v found", lastVreg))
   265  				}
   266  				memOpData := data
   267  				// Remove the last vreg from the arg and change it to a load.
   268  				origArgs := data.Args[:len(data.Args)-1]
   269  				// Prepare imm args.
   270  				immArg := ""
   271  				immArgCombineOff := " [off] "
   272  				if immType != NoImm && immType != InvalidImm {
   273  					_, after, found := strings.Cut(origArgs, "]")
   274  					if found {
   275  						origArgs = after
   276  					}
   277  					immArg = "[c] "
   278  					immArgCombineOff = " [makeValAndOff(int32(int8(c)),off)] "
   279  				}
   280  				memOpData.ArgsLoadAddr = immArg + origArgs + fmt.Sprintf("l:(VMOVDQUload%d {sym} [off] ptr mem)", *lastVreg.Bits)
   281  				// Remove the last vreg from the arg and change it to "ptr".
   282  				memOpData.ArgsAddr = "{sym}" + immArgCombineOff + origArgs + "ptr"
   283  				if maskType == OneMask {
   284  					memOpData.ArgsAddr += " mask"
   285  					memOpData.ArgsLoadAddr += " mask"
   286  				}
   287  				memOpData.ArgsAddr += " mem"
   288  				if gOp.MemFeaturesData != nil {
   289  					_, feat2 := getVbcstData(*gOp.MemFeaturesData)
   290  					knownFeatChecks := map[string]string{
   291  						"AVX":    "v.Block.CPUfeatures.hasFeature(CPUavx)",
   292  						"AVX2":   "v.Block.CPUfeatures.hasFeature(CPUavx2)",
   293  						"AVX512": "v.Block.CPUfeatures.hasFeature(CPUavx512)",
   294  					}
   295  					memOpData.FeatCheck = knownFeatChecks[feat2]
   296  					memOpData.tplName = "vregMemFeatCheck"
   297  				} else {
   298  					memOpData.tplName = "vregMem"
   299  				}
   300  				memOptData = append(memOptData, memOpData)
   301  				asmCheck[memOpData.Asm+"load"] = true
   302  			}
   303  		}
   304  		// Generate the masked merging optimization rules
   305  		if gOp.hasMaskedMerging(maskType, opOutShape) {
   306  			// TODO: handle customized operand order and special lower.
   307  			maskElem := gOp.In[len(gOp.In)-1]
   308  			if maskElem.Bits == nil {
   309  				panic("mask has no bits")
   310  			}
   311  			if maskElem.ElemBits == nil {
   312  				panic("mask has no elemBits")
   313  			}
   314  			if maskElem.Lanes == nil {
   315  				panic("mask has no lanes")
   316  			}
   317  			switch *maskElem.Bits {
   318  			case 128, 256:
   319  				// VPBLENDVB cases.
   320  				noMaskName := machineOpName(NoMask, gOp)
   321  				ruleExisting, ok := maskedMergeOpts[noMaskName]
   322  				rule := fmt.Sprintf("(VPBLENDVB%d dst (%s %s) mask) && v.Block.CPUfeatures.hasFeature(CPUavx512) => (%sMerging dst %s (VPMOVVec%dx%dToM <types.TypeMask> mask))\n",
   323  					*maskElem.Bits, noMaskName, data.Args, data.Asm, data.Args, *maskElem.ElemBits, *maskElem.Lanes)
   324  				if ok && ruleExisting != rule {
   325  					panic("multiple masked merge rules for one op")
   326  				} else {
   327  					maskedMergeOpts[noMaskName] = rule
   328  				}
   329  			case 512:
   330  				// VPBLENDM[BWDQ] cases.
   331  				noMaskName := machineOpName(NoMask, gOp)
   332  				ruleExisting, ok := maskedMergeOpts[noMaskName]
   333  				rule := fmt.Sprintf("(VPBLENDM%sMasked%d dst (%s %s) mask) => (%sMerging dst %s mask)\n",
   334  					s2n[*maskElem.ElemBits], *maskElem.Bits, noMaskName, data.Args, data.Asm, data.Args)
   335  				if ok && ruleExisting != rule {
   336  					panic("multiple masked merge rules for one op")
   337  				} else {
   338  					maskedMergeOpts[noMaskName] = rule
   339  				}
   340  			}
   341  		}
   342  
   343  		if tplName == "pureVreg" && data.Args == data.ArgsOut {
   344  			data.Args = "..."
   345  			data.ArgsOut = "..."
   346  		}
   347  		data.tplName = tplName
   348  		if opr.NoGenericOps != nil && *opr.NoGenericOps == "true" ||
   349  			opr.SkipMaskedMethod() {
   350  			optData = append(optData, data)
   351  			continue
   352  		}
   353  		allData = append(allData, data)
   354  		asmCheck[data.Asm] = true
   355  	}
   356  
   357  	slices.SortFunc(allData, compareTplRuleData)
   358  
   359  	for _, data := range allData {
   360  		if err := ruleTemplates.ExecuteTemplate(buffer, data.tplName, data); err != nil {
   361  			panic(fmt.Errorf("failed to execute template %s for %s: %w", data.tplName, data.GoOp+data.GoType, err))
   362  		}
   363  	}
   364  
   365  	seen := make(map[string]bool)
   366  
   367  	for _, data := range optData {
   368  		if data.tplName == "maskIn" {
   369  			rule := data.MaskOptimization(asmCheck)
   370  			if seen[rule] {
   371  				continue
   372  			}
   373  			seen[rule] = true
   374  			buffer.WriteString(rule)
   375  		}
   376  	}
   377  
   378  	maskedMergeOptsRules := []string{}
   379  	for asm, rule := range maskedMergeOpts {
   380  		if !asmCheck[asm] {
   381  			continue
   382  		}
   383  		maskedMergeOptsRules = append(maskedMergeOptsRules, rule)
   384  	}
   385  	slices.Sort(maskedMergeOptsRules)
   386  	for _, rule := range maskedMergeOptsRules {
   387  		buffer.WriteString(rule)
   388  	}
   389  
   390  	for _, data := range memOptData {
   391  		if err := ruleTemplates.ExecuteTemplate(buffer, data.tplName, data); err != nil {
   392  			panic(fmt.Errorf("failed to execute template %s for %s: %w", data.tplName, data.Asm, err))
   393  		}
   394  	}
   395  
   396  	return buffer
   397  }
   398  

View as plain text