Source file src/simd/_gen/simdgen/gen_utility.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bufio"
     9  	"bytes"
    10  	"fmt"
    11  	"go/format"
    12  	"log"
    13  	"os"
    14  	"path/filepath"
    15  	"reflect"
    16  	"slices"
    17  	"sort"
    18  	"strings"
    19  	"text/template"
    20  	"unicode"
    21  )
    22  
    23  func templateOf(temp, name string) *template.Template {
    24  	t, err := template.New(name).Parse(temp)
    25  	if err != nil {
    26  		panic(fmt.Errorf("failed to parse template %s: %w", name, err))
    27  	}
    28  	return t
    29  }
    30  
    31  func createPath(goroot string, file string) (*os.File, error) {
    32  	fp := filepath.Join(goroot, file)
    33  	dir := filepath.Dir(fp)
    34  	err := os.MkdirAll(dir, 0755)
    35  	if err != nil {
    36  		return nil, fmt.Errorf("failed to create directory %s: %w", dir, err)
    37  	}
    38  	f, err := os.Create(fp)
    39  	if err != nil {
    40  		return nil, fmt.Errorf("failed to create file %s: %w", fp, err)
    41  	}
    42  	return f, nil
    43  }
    44  
    45  func formatWriteAndClose(out *bytes.Buffer, goroot string, file string) {
    46  	b, err := format.Source(out.Bytes())
    47  	if err != nil {
    48  		fmt.Fprintf(os.Stderr, "%v\n", err)
    49  		fmt.Fprintf(os.Stderr, "%s\n", numberLines(out.Bytes()))
    50  		fmt.Fprintf(os.Stderr, "%v\n", err)
    51  		panic(err)
    52  	} else {
    53  		writeAndClose(b, goroot, file)
    54  	}
    55  }
    56  
    57  func writeAndClose(b []byte, goroot string, file string) {
    58  	ofile, err := createPath(goroot, file)
    59  	if err != nil {
    60  		panic(err)
    61  	}
    62  	ofile.Write(b)
    63  	ofile.Close()
    64  }
    65  
    66  // numberLines takes a slice of bytes, and returns a string where each line
    67  // is numbered, starting from 1.
    68  func numberLines(data []byte) string {
    69  	var buf bytes.Buffer
    70  	r := bytes.NewReader(data)
    71  	s := bufio.NewScanner(r)
    72  	for i := 1; s.Scan(); i++ {
    73  		fmt.Fprintf(&buf, "%d: %s\n", i, s.Text())
    74  	}
    75  	return buf.String()
    76  }
    77  
// The following small integer types classify different aspects of an
// operation's operands. The first four are the result types of
// (*Operation).shape; memShape is the parameter type of (*Operation).regShape.

// inShape classifies the combination of input operand kinds.
type inShape uint8

// outShape classifies the output operand.
type outShape uint8

// maskShape classifies how masks are used by the operation.
type maskShape uint8

// immShape classifies the immediate operand, if any.
type immShape uint8

// memShape classifies memory operands.
type memShape uint8
    83  
// Values of inShape. InvalidIn is the zero value, so an inShape that was
// never assigned is distinguishable from every valid classification.
const (
	InvalidIn     inShape = iota
	PureVregIn            // vector register input only
	OneKmaskIn            // vector and kmask input
	OneImmIn              // vector and immediate input
	OneKmaskImmIn         // vector, kmask, and immediate inputs
	PureKmaskIn           // only mask inputs.
)
    92  
// Values of outShape. InvalidOut is the zero value, so an outShape that was
// never assigned is distinguishable from every valid classification.
const (
	InvalidOut     outShape = iota
	NoOut                   // no output
	OneVregOut              // (one) vector register output
	OneGregOut              // (one) general register output
	OneKmaskOut             // mask output
	OneVregOutAtIn          // the first input is also the output
)
   101  
// Values of maskShape. InvalidMask is the zero value, so a maskShape that was
// never assigned is distinguishable from every valid classification.
const (
	InvalidMask maskShape = iota
	NoMask                // no mask
	OneMask               // with mask (K1 to K7)
	AllMasks              // a K mask instruction (K0-K7)
)
   108  
// Values of immShape. InvalidImm is the zero value, so an immShape that was
// never assigned is distinguishable from every valid classification.
const (
	InvalidImm  immShape = iota
	NoImm                // no immediate
	ConstImm             // const only immediate
	VarImm               // pure imm argument provided by the users
	ConstVarImm          // a combination of user arg and const
)
   116  
// Values of memShape. InvalidMem is the zero value. VregMemIn is the only
// memory shape that (*Operation).regShape accepts for "memory" operands.
const (
	InvalidMem memShape = iota
	NoMem
	VregMemIn // The instruction contains a mem input which is loading a vreg.
)
   122  
// shape classifies op and returns several values describing its operand
// layout, plus a modified copy of op:
//
//   - shapeIn: which kinds of inputs (vector registers, masks, immediate) op takes.
//   - shapeOut: the kind of the single output, or OneVregOutAtIn when a vreg
//     input at AsmPos 0 shares its AsmPos with the output.
//   - maskType: NoMask, OneMask, or AllMasks (several masks and no other inputs).
//   - immType: whether the immediate is constant, user-provided, or both.
//   - opNoImm: a shallow copy of op whose In omits In[0] when op has an
//     immediate input. The code assumes the immediate is In[0]
//     (NOTE(review): presumably guaranteed by sortOperand — confirm).
//
// shape panics on operations simdgen does not support: anything other than
// exactly one output, an output class other than vreg/greg/mask, an
// unsupported input/output position overlap, an immediate that is not 8 bits
// wide or has neither Const nor ImmOffset set, and operations with more than
// one mask input that also have an immediate or a non-mask input.
//
// This function does not modify op. Note that opNoImm.In shares its backing
// array with op.In.
func (op *Operation) shape() (shapeIn inShape, shapeOut outShape, maskType maskShape, immType immShape,
	opNoImm Operation) {
	if len(op.Out) > 1 {
		panic(fmt.Errorf("simdgen only supports 1 output: %s", op))
	}
	// Assembly position of the output, compared against each input's AsmPos
	// below to detect an input that is also the output.
	var outputReg int
	if len(op.Out) == 1 {
		outputReg = op.Out[0].AsmPos
		if op.Out[0].Class == "vreg" {
			shapeOut = OneVregOut
		} else if op.Out[0].Class == "greg" {
			shapeOut = OneGregOut
		} else if op.Out[0].Class == "mask" {
			shapeOut = OneKmaskOut
		} else {
			panic(fmt.Errorf("simdgen only supports output of class vreg or mask: %s", op))
		}
	} else {
		shapeOut = NoOut
		// TODO: are these only Load/Stores?
		// We manually supported two Load and Store, are those enough?
		panic(fmt.Errorf("simdgen only supports 1 output: %s", op))
	}
	hasImm := false
	maskCount := 0
	// hasVreg is set for every input that is neither an immediate nor a mask.
	hasVreg := false
	for _, in := range op.In {
		if in.AsmPos == outputReg {
			// Only one shared position is accepted: a vreg input at AsmPos 0,
			// seen once. A second input at the output position panics.
			if shapeOut != OneVregOutAtIn && in.AsmPos == 0 && in.Class == "vreg" {
				shapeOut = OneVregOutAtIn
			} else {
				panic(fmt.Errorf("simdgen only support output and input sharing the same position case of \"the first input is vreg and the only output\": %s", op))
			}
		}
		if in.Class == "immediate" {
			// A manual check on XED data found that AMD64 SIMD instructions at most
			// have 1 immediates. So we don't need to check this here.
			if *in.Bits != 8 {
				panic(fmt.Errorf("simdgen only supports immediates of 8 bits: %s", op))
			}
			hasImm = true
		} else if in.Class == "mask" {
			maskCount++
		} else {
			hasVreg = true
		}
	}
	// Shallow copy: opNoImm.In still aliases op.In until it is resliced below.
	opNoImm = *op

	// removeImm drops the first input by reslicing; the backing array is shared.
	removeImm := func(o *Operation) {
		o.In = o.In[1:]
	}
	if hasImm {
		removeImm(&opNoImm)
		// The immediate is read from op.In[0]: Const marks a fixed value and
		// ImmOffset marks a user-provided part; at least one must be set.
		if op.In[0].Const != nil {
			if op.In[0].ImmOffset != nil {
				immType = ConstVarImm
			} else {
				immType = ConstImm
			}
		} else if op.In[0].ImmOffset != nil {
			immType = VarImm
		} else {
			panic(fmt.Errorf("simdgen requires imm to have at least one of ImmOffset or Const set: %s", op))
		}
	} else {
		immType = NoImm
	}
	if maskCount == 0 {
		maskType = NoMask
	} else {
		// May be upgraded to AllMasks below for pure mask operations.
		maskType = OneMask
	}
	// checkPureMask validates an operation with more than one mask input:
	// it must have neither an immediate nor a non-mask input. It panics on
	// violation and otherwise always returns false.
	checkPureMask := func() bool {
		if hasImm {
			panic(fmt.Errorf("simdgen does not support immediates in pure mask operations: %s", op))
		}
		if hasVreg {
			panic(fmt.Errorf("simdgen does not support more than 1 masks in non-pure mask operations: %s", op))
		}
		return false
	}
	if !hasImm && maskCount == 0 {
		shapeIn = PureVregIn
	} else if !hasImm && maskCount > 0 {
		if maskCount == 1 {
			shapeIn = OneKmaskIn
		} else {
			if checkPureMask() {
				return
			}
			shapeIn = PureKmaskIn
			maskType = AllMasks
		}
	} else if hasImm && maskCount == 0 {
		shapeIn = OneImmIn
	} else {
		if maskCount == 1 {
			shapeIn = OneKmaskImmIn
		} else {
			// hasImm is true on this path, so checkPureMask always panics here.
			checkPureMask()
			return
		}
	}
	return
}
   235  
   236  // regShape returns a string representation of the register shape.
   237  func (op *Operation) regShape(mem memShape) (string, error) {
   238  	_, _, _, _, gOp := op.shape()
   239  	var regInfo, fixedName string
   240  	var vRegInCnt, gRegInCnt, kMaskInCnt, vRegOutCnt, gRegOutCnt, kMaskOutCnt, memInCnt, memOutCnt int
   241  	for i, in := range gOp.In {
   242  		switch in.Class {
   243  		case "vreg":
   244  			vRegInCnt++
   245  		case "greg":
   246  			gRegInCnt++
   247  		case "mask":
   248  			kMaskInCnt++
   249  		case "memory":
   250  			if mem != VregMemIn {
   251  				panic("simdgen only knows VregMemIn in regShape")
   252  			}
   253  			memInCnt++
   254  			vRegInCnt++
   255  		}
   256  		if in.FixedReg != nil {
   257  			fixedName = fmt.Sprintf("%sAtIn%d", *in.FixedReg, i)
   258  		}
   259  	}
   260  	for i, out := range gOp.Out {
   261  		// If class overwrite is happening, that's not really a mask but a vreg.
   262  		if out.Class == "vreg" || out.OverwriteClass != nil {
   263  			vRegOutCnt++
   264  		} else if out.Class == "greg" {
   265  			gRegOutCnt++
   266  		} else if out.Class == "mask" {
   267  			kMaskOutCnt++
   268  		} else if out.Class == "memory" {
   269  			if mem != VregMemIn {
   270  				panic("simdgen only knows VregMemIn in regShape")
   271  			}
   272  			vRegOutCnt++
   273  			memOutCnt++
   274  		}
   275  		if out.FixedReg != nil {
   276  			fixedName = fmt.Sprintf("%sAtIn%d", *out.FixedReg, i)
   277  		}
   278  	}
   279  	var inRegs, inMasks, outRegs, outMasks string
   280  
   281  	rmAbbrev := func(s string, i int) string {
   282  		if i == 0 {
   283  			return ""
   284  		}
   285  		if i == 1 {
   286  			return s
   287  		}
   288  		return fmt.Sprintf("%s%d", s, i)
   289  
   290  	}
   291  
   292  	inRegs = rmAbbrev("v", vRegInCnt)
   293  	inRegs += rmAbbrev("gp", gRegInCnt)
   294  	inMasks = rmAbbrev("k", kMaskInCnt)
   295  
   296  	outRegs = rmAbbrev("v", vRegOutCnt)
   297  	outRegs += rmAbbrev("gp", gRegOutCnt)
   298  	outMasks = rmAbbrev("k", kMaskOutCnt)
   299  
   300  	if kMaskInCnt == 0 && kMaskOutCnt == 0 && gRegInCnt == 0 && gRegOutCnt == 0 {
   301  		// For pure v we can abbreviate it as v%d%d.
   302  		regInfo = fmt.Sprintf("v%d%d", vRegInCnt, vRegOutCnt)
   303  	} else if kMaskInCnt == 0 && kMaskOutCnt == 0 {
   304  		regInfo = fmt.Sprintf("%s%s", inRegs, outRegs)
   305  	} else {
   306  		regInfo = fmt.Sprintf("%s%s%s%s", inRegs, inMasks, outRegs, outMasks)
   307  	}
   308  	if memInCnt > 0 {
   309  		if memInCnt == 1 {
   310  			regInfo += "load"
   311  		} else {
   312  			panic("simdgen does not understand more than 1 mem op as of now")
   313  		}
   314  	}
   315  	if memOutCnt > 0 {
   316  		panic("simdgen does not understand memory as output as of now")
   317  	}
   318  	regInfo += fixedName
   319  	return regInfo, nil
   320  }
   321  
   322  // sortOperand sorts op.In by putting immediates first, then vreg, and mask the last.
   323  // TODO: verify that this is a safe assumption of the prog structure.
   324  // from my observation looks like in asm, imms are always the first,
   325  // masks are always the last, with vreg in between.
   326  func (op *Operation) sortOperand() {
   327  	priority := map[string]int{"immediate": 0, "vreg": 1, "greg": 1, "mask": 2}
   328  	sort.SliceStable(op.In, func(i, j int) bool {
   329  		pi := priority[op.In[i].Class]
   330  		pj := priority[op.In[j].Class]
   331  		if pi != pj {
   332  			return pi < pj
   333  		}
   334  		return op.In[i].AsmPos < op.In[j].AsmPos
   335  	})
   336  }
   337  
   338  // goNormalType returns the Go type name for the result of an Op that
   339  // does not return a vector, i.e., that returns a result in a general
   340  // register.  Currently there's only one family of Ops in Go's simd library
   341  // that does this (GetElem), and so this is specialized to work for that,
   342  // but the problem (mismatch betwen hardware register width and Go type
   343  // width) seems likely to recur if there are any other cases.
   344  func (op Operation) goNormalType() string {
   345  	if op.Go == "GetElem" {
   346  		// GetElem returns an element of the vector into a general register
   347  		// but as far as the hardware is concerned, that result is either 32
   348  		// or 64 bits wide, no matter what the vector element width is.
   349  		// This is not "wrong" but it is not the right answer for Go source code.
   350  		// To get the Go type right, combine the base type ("int", "uint", "float"),
   351  		// with the input vector element width in bits (8,16,32,64).
   352  
   353  		at := 0 // proper value of at depends on whether immediate was stripped or not
   354  		if op.In[at].Class == "immediate" {
   355  			at++
   356  		}
   357  		return fmt.Sprintf("%s%d", *op.Out[0].Base, *op.In[at].ElemBits)
   358  	}
   359  	panic(fmt.Errorf("Implement goNormalType for %v", op))
   360  }
   361  
   362  // SSAType returns the string for the type reference in SSA generation,
   363  // for example in the intrinsics generating template.
   364  func (op Operation) SSAType() string {
   365  	if op.Out[0].Class == "greg" {
   366  		return fmt.Sprintf("types.Types[types.T%s]", strings.ToUpper(op.goNormalType()))
   367  	}
   368  	return fmt.Sprintf("types.TypeVec%d", *op.Out[0].Bits)
   369  }
   370  
   371  // GoType returns the Go type returned by this operation (relative to the simd package),
   372  // for example "int32" or "Int8x16".  This is used in a template.
   373  func (op Operation) GoType() string {
   374  	if op.Out[0].Class == "greg" {
   375  		return op.goNormalType()
   376  	}
   377  	return *op.Out[0].Go
   378  }
   379  
// ImmName returns the name to use for an operation's immediate operand.
// This can be overridden in the yaml with "name" on an operand,
// otherwise, for now, "constant".
// It reads input 0, so it panics if op has no inputs.
func (op Operation) ImmName() string {
	return op.Op0Name("constant")
}
   386  
   387  func (o Operand) OpName(s string) string {
   388  	if n := o.Name; n != nil {
   389  		return *n
   390  	}
   391  	if o.Class == "mask" {
   392  		return "mask"
   393  	}
   394  	return s
   395  }
   396  
// OpNameAndType returns the operand's name, as chosen by OpName with default
// s, followed by a space and the operand's Go type, e.g. for use as a
// parameter declaration. It dereferences o.Go and so panics if Go is nil.
func (o Operand) OpNameAndType(s string) string {
	return o.OpName(s) + " " + *o.Go
}
   400  
// GoExported returns [Go] with first character capitalized
// (see capitalizeFirst); an empty name stays empty.
func (op Operation) GoExported() string {
	return capitalizeFirst(op.Go)
}
   405  
// DocumentationExported returns [Documentation] with method name capitalized.
// Every occurrence of op.Go in the documentation text is replaced by
// op.GoExported(), including occurrences inside longer words.
func (op Operation) DocumentationExported() string {
	return strings.ReplaceAll(op.Documentation, op.Go, op.GoExported())
}
   410  
// Op0Name returns the name to use for input operand 0 as chosen by
// [Operand.OpName]: the operand's explicit name if it has one, "mask" for a
// mask operand, otherwise the parameter s.
// It panics if op has no input at index 0.
func (op Operation) Op0Name(s string) string {
	return op.In[0].OpName(s)
}
   416  
// Op1Name returns the name to use for input operand 1 as chosen by
// [Operand.OpName]: the operand's explicit name if it has one, "mask" for a
// mask operand, otherwise the parameter s.
// It panics if op has no input at index 1.
func (op Operation) Op1Name(s string) string {
	return op.In[1].OpName(s)
}
   422  
// Op2Name returns the name to use for input operand 2 as chosen by
// [Operand.OpName]: the operand's explicit name if it has one, "mask" for a
// mask operand, otherwise the parameter s.
// It panics if op has no input at index 2.
func (op Operation) Op2Name(s string) string {
	return op.In[2].OpName(s)
}
   428  
// Op3Name returns the name to use for input operand 3 as chosen by
// [Operand.OpName]: the operand's explicit name if it has one, "mask" for a
// mask operand, otherwise the parameter s.
// It panics if op has no input at index 3.
func (op Operation) Op3Name(s string) string {
	return op.In[3].OpName(s)
}
   434  
// Op0NameAndType returns the name and Go type, separated by a space, to use
// for input operand 0. The name is the operand's own name if one is
// provided ("mask" for mask operands); otherwise the parameter value s is
// used as the default.
// It panics if op has no input at index 0.
func (op Operation) Op0NameAndType(s string) string {
	return op.In[0].OpNameAndType(s)
}
   441  
// Op1NameAndType returns the name and Go type, separated by a space, to use
// for input operand 1. The name is the operand's own name if one is
// provided ("mask" for mask operands); otherwise the parameter value s is
// used as the default.
// It panics if op has no input at index 1.
func (op Operation) Op1NameAndType(s string) string {
	return op.In[1].OpNameAndType(s)
}
   448  
// Op2NameAndType returns the name and Go type, separated by a space, to use
// for input operand 2. The name is the operand's own name if one is
// provided ("mask" for mask operands); otherwise the parameter value s is
// used as the default.
// It panics if op has no input at index 2.
func (op Operation) Op2NameAndType(s string) string {
	return op.In[2].OpNameAndType(s)
}
   455  
// Op3NameAndType returns the name and Go type, separated by a space, to use
// for input operand 3. The name is the operand's own name if one is
// provided ("mask" for mask operands); otherwise the parameter value s is
// used as the default.
// It panics if op has no input at index 3.
func (op Operation) Op3NameAndType(s string) string {
	return op.In[3].OpNameAndType(s)
}
   462  
// Op4NameAndType returns the name and Go type, separated by a space, to use
// for input operand 4. The name is the operand's own name if one is
// provided ("mask" for mask operands); otherwise the parameter value s is
// used as the default.
// It panics if op has no input at index 4.
func (op Operation) Op4NameAndType(s string) string {
	return op.In[4].OpNameAndType(s)
}
   469  
// immClasses holds the classification strings for operations that keep a
// user-visible immediate. classifyOp indexes it by the number of inputs
// including the immediate, and only uses indexes 2 through 5; the "BAD"
// entries at 0 and 1 are placeholders that it never selects.
var immClasses []string = []string{"BAD0Imm", "BAD1Imm", "op1Imm8", "op2Imm8", "op3Imm8", "op4Imm8"}

// classes holds the classification strings for all other operations.
// classifyOp indexes it by the number of inputs after the immediate has been
// removed, and only uses indexes 1 through 4; "BAD0" is a placeholder.
var classes []string = []string{"BAD0", "op1", "op2", "op3", "op4"}
   472  
   473  // classifyOp returns a classification string, modified operation, and perhaps error based
   474  // on the stub and intrinsic shape for the operation.
   475  // The classification string is in the regular expression set "op[1234](Imm8)?(_<order>)?"
   476  // where the "<order>" suffix is optionally attached to the Operation in its input yaml.
   477  // The classification string is used to select a template or a clause of a template
   478  // for intrinsics declaration and the ssagen intrinisics glue code in the compiler.
   479  func classifyOp(op Operation) (string, Operation, error) {
   480  	_, _, _, immType, gOp := op.shape()
   481  
   482  	var class string
   483  
   484  	if immType == VarImm || immType == ConstVarImm {
   485  		switch l := len(op.In); l {
   486  		case 1:
   487  			return "", op, fmt.Errorf("simdgen does not recognize this operation of only immediate input: %s", op)
   488  		case 2, 3, 4, 5:
   489  			class = immClasses[l]
   490  		default:
   491  			return "", op, fmt.Errorf("simdgen does not recognize this operation of input length %d: %s", len(op.In), op)
   492  		}
   493  		if order := op.OperandOrder; order != nil {
   494  			class += "_" + *order
   495  		}
   496  		return class, op, nil
   497  	} else {
   498  		switch l := len(gOp.In); l {
   499  		case 1, 2, 3, 4:
   500  			class = classes[l]
   501  		default:
   502  			return "", op, fmt.Errorf("simdgen does not recognize this operation of input length %d: %s", len(op.In), op)
   503  		}
   504  		if order := op.OperandOrder; order != nil {
   505  			class += "_" + *order
   506  		}
   507  		return class, gOp, nil
   508  	}
   509  }
   510  
   511  func checkVecAsScalar(op Operation) (idx int, err error) {
   512  	idx = -1
   513  	sSize := 0
   514  	for i, o := range op.In {
   515  		if o.TreatLikeAScalarOfSize != nil {
   516  			if idx == -1 {
   517  				idx = i
   518  				sSize = *o.TreatLikeAScalarOfSize
   519  			} else {
   520  				err = fmt.Errorf("simdgen only supports one TreatLikeAScalarOfSize in the arg list: %s", op)
   521  				return
   522  			}
   523  		}
   524  	}
   525  	if idx >= 0 {
   526  		if sSize != 8 && sSize != 16 && sSize != 32 && sSize != 64 {
   527  			err = fmt.Errorf("simdgen does not recognize this uint size: %d, %s", sSize, op)
   528  			return
   529  		}
   530  	}
   531  	return
   532  }
   533  
   534  func rewriteVecAsScalarRegInfo(op Operation, regInfo string) (string, error) {
   535  	idx, err := checkVecAsScalar(op)
   536  	if err != nil {
   537  		return "", err
   538  	}
   539  	if idx != -1 {
   540  		if regInfo == "v21" {
   541  			regInfo = "vfpv"
   542  		} else if regInfo == "v2kv" {
   543  			regInfo = "vfpkv"
   544  		} else if regInfo == "v31" {
   545  			regInfo = "v2fpv"
   546  		} else if regInfo == "v3kv" {
   547  			regInfo = "v2fpkv"
   548  		} else {
   549  			return "", fmt.Errorf("simdgen does not recognize uses of treatLikeAScalarOfSize with op regShape %s in op: %s", regInfo, op)
   550  		}
   551  	}
   552  	return regInfo, nil
   553  }
   554  
   555  func rewriteLastVregToMem(op Operation) Operation {
   556  	newIn := make([]Operand, len(op.In))
   557  	lastVregIdx := -1
   558  	for i := range len(op.In) {
   559  		newIn[i] = op.In[i]
   560  		if op.In[i].Class == "vreg" {
   561  			lastVregIdx = i
   562  		}
   563  	}
   564  	// vbcst operations put their mem op always as the last vreg.
   565  	if lastVregIdx == -1 {
   566  		panic("simdgen cannot find one vreg in the mem op vreg original")
   567  	}
   568  	newIn[lastVregIdx].Class = "memory"
   569  	op.In = newIn
   570  
   571  	return op
   572  }
   573  
   574  // dedup is deduping operations in the full structure level.
   575  func dedup(ops []Operation) (deduped []Operation) {
   576  	for _, op := range ops {
   577  		seen := false
   578  		for _, dop := range deduped {
   579  			if reflect.DeepEqual(op, dop) {
   580  				seen = true
   581  				break
   582  			}
   583  		}
   584  		if !seen {
   585  			deduped = append(deduped, op)
   586  		}
   587  	}
   588  	return
   589  }
   590  
   591  func (op Operation) GenericName() string {
   592  	if op.OperandOrder != nil {
   593  		switch *op.OperandOrder {
   594  		case "21Type1", "231Type1":
   595  			// Permute uses operand[1] for method receiver.
   596  			return op.Go + *op.In[1].Go
   597  		}
   598  	}
   599  	if op.In[0].Class == "immediate" {
   600  		return op.Go + *op.In[1].Go
   601  	}
   602  	return op.Go + *op.In[0].Go
   603  }
   604  
   605  // dedupGodef is deduping operations in [Op.Go]+[*Op.In[0].Go] level.
   606  // By deduping, it means picking the least advanced architecture that satisfy the requirement:
   607  // AVX512 will be least preferred.
   608  // If FlagNoDedup is set, it will report the duplicates to the console.
   609  func dedupGodef(ops []Operation) ([]Operation, error) {
   610  	seen := map[string][]Operation{}
   611  	for _, op := range ops {
   612  		_, _, _, _, gOp := op.shape()
   613  
   614  		gN := gOp.GenericName()
   615  		seen[gN] = append(seen[gN], op)
   616  	}
   617  	if *FlagReportDup {
   618  		for gName, dup := range seen {
   619  			if len(dup) > 1 {
   620  				log.Printf("Duplicate for %s:\n", gName)
   621  				for _, op := range dup {
   622  					log.Printf("%s\n", op)
   623  				}
   624  			}
   625  		}
   626  		return ops, nil
   627  	}
   628  	isAVX512 := func(op Operation) bool {
   629  		return strings.Contains(op.CPUFeature, "AVX512")
   630  	}
   631  	deduped := []Operation{}
   632  	for _, dup := range seen {
   633  		if len(dup) > 1 {
   634  			slices.SortFunc(dup, func(i, j Operation) int {
   635  				// Put non-AVX512 candidates at the beginning
   636  				if !isAVX512(i) && isAVX512(j) {
   637  					return -1
   638  				}
   639  				if isAVX512(i) && !isAVX512(j) {
   640  					return 1
   641  				}
   642  				if i.CPUFeature != j.CPUFeature {
   643  					return strings.Compare(i.CPUFeature, j.CPUFeature)
   644  				}
   645  				// Weirdly Intel sometimes has duplicated definitions for the same instruction,
   646  				// this confuses the XED mem-op merge logic: [MemFeature] will only be attached to an instruction
   647  				// for only once, which means that for essentially duplicated instructions only one will have the
   648  				// proper [MemFeature] set. We have to make this sort deterministic for [MemFeature].
   649  				if i.MemFeatures != nil && j.MemFeatures == nil {
   650  					return -1
   651  				}
   652  				if i.MemFeatures == nil && j.MemFeatures != nil {
   653  					return 1
   654  				}
   655  				// Their order does not matter anymore, at least for now.
   656  				return 0
   657  			})
   658  		}
   659  		deduped = append(deduped, dup[0])
   660  	}
   661  	slices.SortFunc(deduped, compareOperations)
   662  	return deduped, nil
   663  }
   664  
   665  // Copy op.ConstImm to op.In[0].Const
   666  // This is a hack to reduce the size of defs we need for const imm operations.
   667  func copyConstImm(ops []Operation) error {
   668  	for _, op := range ops {
   669  		if op.ConstImm == nil {
   670  			continue
   671  		}
   672  		_, _, _, immType, _ := op.shape()
   673  
   674  		if immType == ConstImm || immType == ConstVarImm {
   675  			op.In[0].Const = op.ConstImm
   676  		}
   677  		// Otherwise, just not port it - e.g. {VPCMP[BWDQ] imm=0} and {VPCMPEQ[BWDQ]} are
   678  		// the same operations "Equal", [dedupgodef] should be able to distinguish them.
   679  	}
   680  	return nil
   681  }
   682  
   683  func capitalizeFirst(s string) string {
   684  	if s == "" {
   685  		return ""
   686  	}
   687  	// Convert the string to a slice of runes to handle multi-byte characters correctly.
   688  	r := []rune(s)
   689  	r[0] = unicode.ToUpper(r[0])
   690  	return string(r)
   691  }
   692  
// overwrite corrects some errors due to:
//   - The XED data is wrong
//   - Go's SIMD API requirement, for example AVX2 compares should also produce masks.
//     This rewrite has strict constraints, please see the error message.
//     These constraints are also exploited in [writeSIMDRules], [writeSIMDMachineOps]
//     and [writeSIMDSSA], please be careful when updating these constraints.
//
// For every input and output operand of every operation it applies the
// operand's OverwriteElementBits, OverwriteClass and OverwriteBase fields,
// updating the operand in place (including the values behind its ElemBits,
// Lanes, Go and Base pointers).
//
// It returns an error if an input operand has OverwriteClass set, or if an
// operation with a class overwrite on its output also has a mask input.
// Malformed overwrite definitions panic instead. Operands processed before
// an error is detected have already been modified.
func overwrite(ops []Operation) error {
	// Set by the closure below whenever it applies a class overwrite; reset
	// at the start of each operation.
	hasClassOverwrite := false
	// overwrite applies the overwrite fields of op[idx]; o is used only for
	// error messages. It always returns nil.
	overwrite := func(op []Operand, idx int, o Operation) error {
		if op[idx].OverwriteElementBits != nil {
			if op[idx].ElemBits == nil {
				panic(fmt.Errorf("ElemBits is nil at operand %d of %v", idx, o))
			}
			// Changing the element width changes the lane count and with it
			// the Go type name, so all three are recomputed together.
			*op[idx].ElemBits = *op[idx].OverwriteElementBits
			*op[idx].Lanes = *op[idx].Bits / *op[idx].ElemBits
			*op[idx].Go = fmt.Sprintf("%s%dx%d", capitalizeFirst(*op[idx].Base), *op[idx].ElemBits, *op[idx].Lanes)
		}
		if op[idx].OverwriteClass != nil {
			// The only supported class overwrite is vreg -> mask with base int.
			if op[idx].OverwriteBase == nil {
				panic(fmt.Errorf("simdgen: [OverwriteClass] must be set together with [OverwriteBase]: %s", op[idx]))
			}
			oBase := *op[idx].OverwriteBase
			oClass := *op[idx].OverwriteClass
			if oClass != "mask" {
				panic(fmt.Errorf("simdgen: [Class] overwrite only supports overwritting to mask: %s", op[idx]))
			}
			if oBase != "int" {
				panic(fmt.Errorf("simdgen: [Class] overwrite must set [OverwriteBase] to int: %s", op[idx]))
			}
			if op[idx].Class != "vreg" {
				panic(fmt.Errorf("simdgen: [Class] overwrite must be overwriting [Class] from vreg: %s", op[idx]))
			}
			hasClassOverwrite = true
			*op[idx].Base = oBase
			op[idx].Class = oClass
			*op[idx].Go = fmt.Sprintf("Mask%dx%d", *op[idx].ElemBits, *op[idx].Lanes)
		} else if op[idx].OverwriteBase != nil {
			// Base-only overwrite: substitute the capitalized base name inside
			// the Go type name; for general registers also substitute the
			// base name as spelled.
			oBase := *op[idx].OverwriteBase
			*op[idx].Go = strings.ReplaceAll(*op[idx].Go, capitalizeFirst(*op[idx].Base), capitalizeFirst(oBase))
			if op[idx].Class == "greg" {
				*op[idx].Go = strings.ReplaceAll(*op[idx].Go, *op[idx].Base, oBase)
			}
			*op[idx].Base = oBase
		}
		return nil
	}
	for i, o := range ops {
		hasClassOverwrite = false
		for j := range ops[i].In {
			if err := overwrite(ops[i].In, j, o); err != nil {
				return err
			}
			// Checked after every input, so the first input with a class
			// overwrite is rejected immediately.
			if hasClassOverwrite {
				return fmt.Errorf("simdgen does not support [OverwriteClass] in inputs: %s", ops[i])
			}
		}
		for j := range ops[i].Out {
			if err := overwrite(ops[i].Out, j, o); err != nil {
				return err
			}
		}
		// At this point a set flag can only come from an output operand.
		if hasClassOverwrite {
			for _, in := range ops[i].In {
				if in.Class == "mask" {
					return fmt.Errorf("simdgen only supports [OverwriteClass] for operations without mask inputs")
				}
			}
		}
	}
	return nil
}
   764  
   765  // reportXEDInconsistency reports potential XED inconsistencies.
   766  // We can add more fields to [Operation] to enable more checks and implement it here.
   767  // Supported checks:
   768  // [NameAndSizeCheck]: NAME[BWDQ] should set the elemBits accordingly.
   769  // This check is useful to find inconsistencies, then we can add overwrite fields to
   770  // those defs to correct them manually.
   771  func reportXEDInconsistency(ops []Operation) error {
   772  	for _, o := range ops {
   773  		if o.NameAndSizeCheck != nil {
   774  			suffixSizeMap := map[byte]int{'B': 8, 'W': 16, 'D': 32, 'Q': 64}
   775  			checkOperand := func(opr Operand) error {
   776  				if opr.ElemBits == nil {
   777  					return fmt.Errorf("simdgen expects elemBits to be set when performing NameAndSizeCheck")
   778  				}
   779  				if v, ok := suffixSizeMap[o.Asm[len(o.Asm)-1]]; !ok {
   780  					return fmt.Errorf("simdgen expects asm to end with [BWDQ] when performing NameAndSizeCheck")
   781  				} else {
   782  					if v != *opr.ElemBits {
   783  						return fmt.Errorf("simdgen finds NameAndSizeCheck inconsistency in def: %s", o)
   784  					}
   785  				}
   786  				return nil
   787  			}
   788  			for _, in := range o.In {
   789  				if in.Class != "vreg" && in.Class != "mask" {
   790  					continue
   791  				}
   792  				if in.TreatLikeAScalarOfSize != nil {
   793  					// This is an irregular operand, don't check it.
   794  					continue
   795  				}
   796  				if err := checkOperand(in); err != nil {
   797  					return err
   798  				}
   799  			}
   800  			for _, out := range o.Out {
   801  				if err := checkOperand(out); err != nil {
   802  					return err
   803  				}
   804  			}
   805  		}
   806  	}
   807  	return nil
   808  }
   809  
   810  func (o *Operation) hasMaskedMerging(maskType maskShape, outType outShape) bool {
   811  	// BLEND and VMOVDQU are not user-facing ops so we should filter them out.
   812  	return o.OperandOrder == nil && o.SpecialLower == nil && maskType == OneMask && outType == OneVregOut &&
   813  		len(o.InVariant) == 1 && !strings.Contains(o.Asm, "BLEND") && !strings.Contains(o.Asm, "VMOVDQU")
   814  }
   815  
   816  func getVbcstData(s string) (feat1Match, feat2Match string) {
   817  	_, err := fmt.Sscanf(s, "feat1=%[^;];feat2=%s", &feat1Match, &feat2Match)
   818  	if err != nil {
   819  		panic(err)
   820  	}
   821  	return
   822  }
   823  
// String returns the textual form of o produced by pprints. Having this
// method makes Operation satisfy fmt.Stringer, so the %s and %v verbs used
// in this file's error messages print operations this way.
func (o Operation) String() string {
	return pprints(o)
}
   827  
// String returns the textual form of op produced by pprints. Having this
// method makes Operand satisfy fmt.Stringer, so the %s and %v verbs used
// in this file's error messages print operands this way.
func (op Operand) String() string {
	return pprints(op)
}
   831  

View as plain text