Source file src/simd/archsimd/_gen/simdgen/gen_simdTypes.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"cmp"
    10  	"fmt"
    11  	"maps"
    12  	"slices"
    13  	"sort"
    14  	"strings"
    15  	"unicode"
    16  )
    17  
    18  type simdType struct {
    19  	Name                    string // The go type name of this simd type, for example Int32x4.
    20  	Lanes                   int    // The number of elements in this vector/mask.
    21  	Base                    string // The element's type, like for Int32x4 it will be int32.
    22  	Fields                  string // The struct fields, it should be right formatted.
    23  	Type                    string // Either "mask" or "vreg"
    24  	VectorCounterpart       string // For mask use only: just replacing the "Mask" in [simdType.Name] with "Int"
    25  	ReshapedVectorWithAndOr string // For mask use only: vector AND and OR are only available in some shape with element width 32.
    26  	Size                    int    // The size of the vector type
    27  }
    28  
    29  func (x simdType) ElemBits() int {
    30  	return x.Size / x.Lanes
    31  }
    32  
    33  func (x simdType) Article() string {
    34  	if strings.HasPrefix(x.Name, "Int") {
    35  		return "an"
    36  	}
    37  	return "a" // Float, Uint
    38  }
    39  
    40  // LanesContainer returns the smallest int/uint bit size that is
    41  // large enough to hold one bit for each lane.  E.g., Mask32x4
    42  // is 4 lanes, and a uint8 is the smallest uint that has 4 bits.
    43  func (x simdType) LanesContainer() int {
    44  	if x.Lanes > 64 {
    45  		panic("too many lanes")
    46  	}
    47  	if x.Lanes > 32 {
    48  		return 64
    49  	}
    50  	if x.Lanes > 16 {
    51  		return 32
    52  	}
    53  	if x.Lanes > 8 {
    54  		return 16
    55  	}
    56  	return 8
    57  }
    58  
    59  // MaskedLoadStoreFilter encodes which simd type type currently
    60  // get masked loads/stores generated, it is used in two places,
    61  // this forces coordination.
    62  func (x simdType) MaskedLoadStoreFilter() bool {
    63  	return x.Size == 512 || x.ElemBits() >= 32 && x.Type != "mask"
    64  }
    65  
    66  func (x simdType) IntelSizeSuffix() string {
    67  	switch x.ElemBits() {
    68  	case 8:
    69  		return "B"
    70  	case 16:
    71  		return "W"
    72  	case 32:
    73  		return "D"
    74  	case 64:
    75  		return "Q"
    76  	}
    77  	panic("oops")
    78  }
    79  
    80  func (x simdType) MaskedLoadDoc() string {
    81  	if x.Size == 512 || x.ElemBits() < 32 {
    82  		return fmt.Sprintf("// Asm: VMOVDQU%d.Z, CPU Feature: AVX512", x.ElemBits())
    83  	} else {
    84  		return fmt.Sprintf("// Asm: VMASKMOV%s, CPU Feature: AVX2", x.IntelSizeSuffix())
    85  	}
    86  }
    87  
    88  func (x simdType) MaskedStoreDoc() string {
    89  	if x.Size == 512 || x.ElemBits() < 32 {
    90  		return fmt.Sprintf("// Asm: VMOVDQU%d, CPU Feature: AVX512", x.ElemBits())
    91  	} else {
    92  		return fmt.Sprintf("// Asm: VMASKMOV%s, CPU Feature: AVX2", x.IntelSizeSuffix())
    93  	}
    94  }
    95  
    96  func (x simdType) ToBitsDoc() string {
    97  	if x.Size == 512 || x.ElemBits() == 16 {
    98  		return fmt.Sprintf("// Asm: KMOV%s, CPU Features: AVX512", x.IntelSizeSuffix())
    99  	}
   100  	// 128/256 bit vectors with 8, 32, 64 bit elements
   101  	var asm string
   102  	var feat string
   103  	switch x.ElemBits() {
   104  	case 8:
   105  		asm = "VPMOVMSKB"
   106  		if x.Size == 256 {
   107  			feat = "AVX2"
   108  		} else {
   109  			feat = "AVX"
   110  		}
   111  	case 32:
   112  		asm = "VMOVMSKPS"
   113  		feat = "AVX"
   114  	case 64:
   115  		asm = "VMOVMSKPD"
   116  		feat = "AVX"
   117  	default:
   118  		panic("unexpected ElemBits")
   119  	}
   120  	return fmt.Sprintf("// Asm: %s, CPU Features: %s", asm, feat)
   121  }
   122  
   123  func compareSimdTypes(x, y simdType) int {
   124  	// "vreg" then "mask"
   125  	if c := -compareNatural(x.Type, y.Type); c != 0 {
   126  		return c
   127  	}
   128  	// want "flo" < "int" < "uin" (and then 8 < 16 < 32 < 64),
   129  	// not "int16" < "int32" < "int64" < "int8")
   130  	// so limit comparison to first 3 bytes in string.
   131  	if c := compareNatural(x.Base[:3], y.Base[:3]); c != 0 {
   132  		return c
   133  	}
   134  	// base type size, 8 < 16 < 32 < 64
   135  	if c := x.ElemBits() - y.ElemBits(); c != 0 {
   136  		return c
   137  	}
   138  	// vector size last
   139  	return x.Size - y.Size
   140  }
   141  
// simdTypeMap groups simd types by vector size: the key is the size that
// parseSIMDTypes reads from an operand's Bits field, and the value is the
// list of simd types recorded for that size.
type simdTypeMap map[int][]simdType
   143  
// simdTypePair is a (source, destination) pair of simd types. It is the
// data passed to the "vectorConversion" template, which declares an
// As<Tdst> method on Tsrc.
type simdTypePair struct {
	Tsrc simdType // type the conversion starts from (the method receiver)
	Tdst simdType // type the conversion produces
}
   148  
   149  func compareSimdTypePairs(x, y simdTypePair) int {
   150  	c := compareSimdTypes(x.Tsrc, y.Tsrc)
   151  	if c != 0 {
   152  		return c
   153  	}
   154  	return compareSimdTypes(x.Tdst, y.Tdst)
   155  }
   156  
// simdPackageHeader is written at the top of every buffer produced by the
// write* functions in this file: generatedHeader (defined elsewhere)
// followed by the goexperiment.simd build constraint and the package clause.
const simdPackageHeader = generatedHeader + `
//go:build goexperiment.simd

package archsimd
`
   162  
// simdTypesTemplates defines two named templates used by writeSIMDTypes:
//
//   - "sizeTmpl", executed with a vector size (int), declares the v<size>
//     tag type.
//   - "typeTmpl", executed with a simdType, declares the vector or mask
//     struct with its doc comment and the preformatted Fields.
const simdTypesTemplates = `
{{define "sizeTmpl"}}
// v{{.}} is a tag type that tells the compiler that this is really {{.}}-bit SIMD
type v{{.}} struct {
	_{{.}} [0]func() // uncomparable
}
{{end}}

{{define "typeTmpl"}}
{{- if eq .Type "mask"}}
// {{.Name}} is a mask for a SIMD vector of {{.Lanes}} {{.ElemBits}}-bit elements.
{{- else}}
// {{.Name}} is a {{.Size}}-bit SIMD vector of {{.Lanes}} {{.Base}}s.
{{- end}}
type {{.Name}} struct {
{{.Fields}}
}

{{end}}
`
   183  
// simdFeaturesTemplate generates the X86Features type and one method per
// CPU feature. It is executed by writeSIMDFeatures with a slice of values
// exposing Feature and GoArch fields; the "AVX512" feature gets a longer
// doc comment describing the bundled sub-features.
const simdFeaturesTemplate = `
import "internal/cpu"

type X86Features struct {}

var X86 X86Features

{{range .}}
{{- if eq .Feature "AVX512"}}
// {{.Feature}} returns whether the CPU supports the AVX512F+CD+BW+DQ+VL features.
//
// These five CPU features are bundled together, and no use of AVX-512
// is allowed unless all of these features are supported together.
// Nearly every CPU that has shipped with any support for AVX-512 has
// supported all five of these features.
{{- else -}}
// {{.Feature}} returns whether the CPU supports the {{.Feature}} feature.
{{- end}}
//
// {{.Feature}} is defined on all GOARCHes, but will only return true on
// GOARCH {{.GoArch}}.
func (X86Features) {{.Feature}}() bool {
	return cpu.X86.Has{{.Feature}}
}
{{end}}
`
   210  
// simdLoadStoreTemplate declares Len, Load<Name> and Store for a simd type.
// writeSIMDTypes executes it with a simdType for every non-mask type.
const simdLoadStoreTemplate = `
// Len returns the number of elements in {{.Article}} {{.Name}}.
func (x {{.Name}}) Len() int { return {{.Lanes}} }

// Load{{.Name}} loads {{.Article}} {{.Name}} from an array.
//
//go:noescape
func Load{{.Name}}(y *[{{.Lanes}}]{{.Base}}) {{.Name}}

// Store stores {{.Article}} {{.Name}} to an array.
//
//go:noescape
func (x {{.Name}}) Store(y *[{{.Lanes}}]{{.Base}})
`
   225  
   226  const simdMaskFromValTemplate = `
   227  // {{.Name}}FromBits constructs a {{.Name}} from a bitmap value, where 1 means set for the indexed element, 0 means unset.
   228  {{- if ne .Lanes .LanesContainer}}
   229  // Only the lower {{.Lanes}} bits of y are used.
   230  {{- end}}
   231  //
   232  // Asm: KMOV{{.IntelSizeSuffix}}, CPU Feature: AVX512
   233  func {{.Name}}FromBits(y uint{{.LanesContainer}}) {{.Name}}
   234  
   235  // ToBits constructs a bitmap from a {{.Name}}, where 1 means set for the indexed element, 0 means unset.
   236  {{- if ne .Lanes .LanesContainer}}
   237  // Only the lower {{.Lanes}} bits of y are used.
   238  {{- end}}
   239  //
   240  {{.ToBitsDoc}}
   241  func (x {{.Name}}) ToBits() uint{{.LanesContainer}}
   242  `
   243  
// simdMaskedLoadStoreTemplate declares LoadMasked<Name> and StoreMasked for
// a simd type. writeSIMDTypes executes it with a simdType for non-mask
// types accepted by MaskedLoadStoreFilter. The mask parameter type is built
// from the element width and lane count (Mask<ElemBits>x<Lanes>).
const simdMaskedLoadStoreTemplate = `
// LoadMasked{{.Name}} loads {{.Article}} {{.Name}} from an array,
// at those elements enabled by mask.
//
{{.MaskedLoadDoc}}
//
//go:noescape
func LoadMasked{{.Name}}(y *[{{.Lanes}}]{{.Base}}, mask Mask{{.ElemBits}}x{{.Lanes}}) {{.Name}}

// StoreMasked stores {{.Article}} {{.Name}} to an array,
// at those elements enabled by mask.
//
{{.MaskedStoreDoc}}
//
//go:noescape
func (x {{.Name}}) StoreMasked(y *[{{.Lanes}}]{{.Base}}, mask Mask{{.ElemBits}}x{{.Lanes}})
`
   261  
// simdStubsTmpl holds the named templates that writeSIMDStubs uses to emit
// method declarations (stubs without bodies).
//
// The "op*" templates are executed with an operation value and are selected
// by the name that classifyOp returns (with "VecAsScalar" appended by
// writeSIMDStubs for op2/op3 when checkVecAsScalar reports an index). The
// digits after an underscore in a template name (e.g. "op3_231Type1")
// correspond to the order in which operands OpN are used as receiver and
// parameters in that template. The "Imm8" variants take an extra uint8
// parameter named by ImmName.
//
// "vectorConversion" is executed with a simdTypePair and "mask" with a mask
// simdType.
const simdStubsTmpl = `
{{define "op1"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op0NameAndType "x"}}) {{.Go}}() {{.GoType}}
{{end}}

{{define "op2"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op0NameAndType "x"}}) {{.Go}}({{.Op1NameAndType "y"}}) {{.GoType}}
{{end}}

{{define "op2_21"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op0NameAndType "y"}}) {{.GoType}}
{{end}}

{{define "op2_21Type1"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op0NameAndType "y"}}) {{.GoType}}
{{end}}

{{define "op3"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op0NameAndType "x"}}) {{.Go}}({{.Op1NameAndType "y"}}, {{.Op2NameAndType "z"}}) {{.GoType}}
{{end}}

{{define "op3_31Zero3"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op2NameAndType "x"}}) {{.Go}}({{.Op1NameAndType "y"}}) {{.GoType}}
{{end}}

{{define "op3_21"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op0NameAndType "y"}}, {{.Op2NameAndType "z"}}) {{.GoType}}
{{end}}

{{define "op3_21Type1"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op0NameAndType "y"}}, {{.Op2NameAndType "z"}}) {{.GoType}}
{{end}}

{{define "op3_231Type1"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op2NameAndType "y"}}, {{.Op0NameAndType "z"}}) {{.GoType}}
{{end}}

{{define "op2VecAsScalar"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op0NameAndType "x"}}) {{.Go}}(y uint{{(index .In 1).TreatLikeAScalarOfSize}}) {{(index .Out 0).Go}}
{{end}}

{{define "op3VecAsScalar"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op0NameAndType "x"}}) {{.Go}}(y uint{{(index .In 1).TreatLikeAScalarOfSize}}, {{.Op2NameAndType "z"}}) {{(index .Out 0).Go}}
{{end}}

{{define "op4"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op0NameAndType "x"}}) {{.Go}}({{.Op1NameAndType "y"}}, {{.Op2NameAndType "z"}}, {{.Op3NameAndType "u"}}) {{.GoType}}
{{end}}

{{define "op4_231Type1"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op2NameAndType "y"}}, {{.Op0NameAndType "z"}}, {{.Op3NameAndType "u"}}) {{.GoType}}
{{end}}

{{define "op4_31"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op2NameAndType "x"}}) {{.Go}}({{.Op1NameAndType "y"}}, {{.Op0NameAndType "z"}}, {{.Op3NameAndType "u"}}) {{.GoType}}
{{end}}

{{define "op1Imm8"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// {{.ImmName}} results in better performance when it's a constant, a non-constant value will be translated into a jump table.
//
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.ImmName}} uint8) {{.GoType}}
{{end}}

{{define "op2Imm8"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// {{.ImmName}} results in better performance when it's a constant, a non-constant value will be translated into a jump table.
//
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.ImmName}} uint8, {{.Op2NameAndType "y"}}) {{.GoType}}
{{end}}

{{define "op2Imm8_2I"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// {{.ImmName}} results in better performance when it's a constant, a non-constant value will be translated into a jump table.
//
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op2NameAndType "y"}}, {{.ImmName}} uint8) {{.GoType}}
{{end}}

{{define "op2Imm8_II"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// {{.ImmName}} result in better performance when they are constants, non-constant values will be translated into a jump table.
// {{.ImmName}} should be between 0 and 3, inclusive; other values may result in a runtime panic.
//
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.ImmName}} uint8, {{.Op2NameAndType "y"}}) {{.GoType}}
{{end}}

{{define "op2Imm8_SHA1RNDS4"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// {{.ImmName}} results in better performance when it's a constant, a non-constant value will be translated into a jump table.
//
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.ImmName}} uint8, {{.Op2NameAndType "y"}}) {{.GoType}}
{{end}}

{{define "op3Imm8"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// {{.ImmName}} results in better performance when it's a constant, a non-constant value will be translated into a jump table.
//
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.ImmName}} uint8, {{.Op2NameAndType "y"}}, {{.Op3NameAndType "z"}}) {{.GoType}}
{{end}}

{{define "op3Imm8_2I"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// {{.ImmName}} results in better performance when it's a constant, a non-constant value will be translated into a jump table.
//
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op2NameAndType "y"}}, {{.ImmName}} uint8, {{.Op3NameAndType "z"}}) {{.GoType}}
{{end}}


{{define "op4Imm8"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// {{.ImmName}} results in better performance when it's a constant, a non-constant value will be translated into a jump table.
//
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.ImmName}} uint8, {{.Op2NameAndType "y"}}, {{.Op3NameAndType "z"}}, {{.Op4NameAndType "u"}}) {{.GoType}}
{{end}}

{{define "vectorConversion"}}
// As{{.Tdst.Name}} returns {{.Tdst.Article}} {{.Tdst.Name}} with the same bit representation as x.
func (x {{.Tsrc.Name}}) As{{.Tdst.Name}}() {{.Tdst.Name}}
{{end}}

{{define "mask"}}
// To{{.VectorCounterpart}} converts from {{.Name}} to {{.VectorCounterpart}}.
func (from {{.Name}}) To{{.VectorCounterpart}}() (to {{.VectorCounterpart}})

// asMask converts from {{.VectorCounterpart}} to {{.Name}}.
func (from {{.VectorCounterpart}}) asMask() (to {{.Name}})

func (x {{.Name}}) And(y {{.Name}}) {{.Name}}

func (x {{.Name}}) Or(y {{.Name}}) {{.Name}}
{{end}}
`
   452  
   453  // parseSIMDTypes groups go simd types by their vector sizes, and
   454  // returns a map whose key is the vector size, value is the simd type.
   455  func parseSIMDTypes(ops []Operation) simdTypeMap {
   456  	// TODO: maybe instead of going over ops, let's try go over types.yaml.
   457  	ret := map[int][]simdType{}
   458  	seen := map[string]struct{}{}
   459  	processArg := func(arg Operand) {
   460  		if arg.Class == "immediate" || arg.Class == "greg" {
   461  			// Immediates are not encoded as vector types.
   462  			return
   463  		}
   464  		if _, ok := seen[*arg.Go]; ok {
   465  			return
   466  		}
   467  		seen[*arg.Go] = struct{}{}
   468  
   469  		lanes := *arg.Lanes
   470  		base := fmt.Sprintf("%s%d", *arg.Base, *arg.ElemBits)
   471  		tagFieldNameS := fmt.Sprintf("%sx%d", base, lanes)
   472  		tagFieldS := fmt.Sprintf("%s v%d", tagFieldNameS, *arg.Bits)
   473  		valFieldS := fmt.Sprintf("vals%s[%d]%s", strings.Repeat(" ", len(tagFieldNameS)-3), lanes, base)
   474  		fields := fmt.Sprintf("\t%s\n\t%s", tagFieldS, valFieldS)
   475  		if arg.Class == "mask" {
   476  			vectorCounterpart := strings.ReplaceAll(*arg.Go, "Mask", "Int")
   477  			reshapedVectorWithAndOr := fmt.Sprintf("Int32x%d", *arg.Bits/32)
   478  			ret[*arg.Bits] = append(ret[*arg.Bits], simdType{*arg.Go, lanes, base, fields, arg.Class, vectorCounterpart, reshapedVectorWithAndOr, *arg.Bits})
   479  			// In case the vector counterpart of a mask is not present, put its vector counterpart typedef into the map as well.
   480  			if _, ok := seen[vectorCounterpart]; !ok {
   481  				seen[vectorCounterpart] = struct{}{}
   482  				ret[*arg.Bits] = append(ret[*arg.Bits], simdType{vectorCounterpart, lanes, base, fields, "vreg", "", "", *arg.Bits})
   483  			}
   484  		} else {
   485  			ret[*arg.Bits] = append(ret[*arg.Bits], simdType{*arg.Go, lanes, base, fields, arg.Class, "", "", *arg.Bits})
   486  		}
   487  	}
   488  	for _, op := range ops {
   489  		for _, arg := range op.In {
   490  			processArg(arg)
   491  		}
   492  		for _, arg := range op.Out {
   493  			processArg(arg)
   494  		}
   495  	}
   496  	return ret
   497  }
   498  
   499  func vConvertFromTypeMap(typeMap simdTypeMap) []simdTypePair {
   500  	v := []simdTypePair{}
   501  	for _, ts := range typeMap {
   502  		for i, tsrc := range ts {
   503  			for j, tdst := range ts {
   504  				if i != j && tsrc.Type == tdst.Type && tsrc.Type == "vreg" &&
   505  					tsrc.Lanes > 1 && tdst.Lanes > 1 {
   506  					v = append(v, simdTypePair{tsrc, tdst})
   507  				}
   508  			}
   509  		}
   510  	}
   511  	slices.SortFunc(v, compareSimdTypePairs)
   512  	return v
   513  }
   514  
   515  func masksFromTypeMap(typeMap simdTypeMap) []simdType {
   516  	m := []simdType{}
   517  	for _, ts := range typeMap {
   518  		for _, tsrc := range ts {
   519  			if tsrc.Type == "mask" {
   520  				m = append(m, tsrc)
   521  			}
   522  		}
   523  	}
   524  	slices.SortFunc(m, compareSimdTypes)
   525  	return m
   526  }
   527  
   528  func typesFromTypeMap(typeMap simdTypeMap) []simdType {
   529  	m := []simdType{}
   530  	for _, ts := range typeMap {
   531  		for _, tsrc := range ts {
   532  			if tsrc.Lanes > 1 {
   533  				m = append(m, tsrc)
   534  			}
   535  		}
   536  	}
   537  	slices.SortFunc(m, compareSimdTypes)
   538  	return m
   539  }
   540  
   541  // writeSIMDTypes generates the simd vector types into a bytes.Buffer
   542  func writeSIMDTypes(typeMap simdTypeMap) *bytes.Buffer {
   543  	t := templateOf(simdTypesTemplates, "types_amd64")
   544  	loadStore := templateOf(simdLoadStoreTemplate, "loadstore_amd64")
   545  	maskedLoadStore := templateOf(simdMaskedLoadStoreTemplate, "maskedloadstore_amd64")
   546  	maskFromVal := templateOf(simdMaskFromValTemplate, "maskFromVal_amd64")
   547  
   548  	buffer := new(bytes.Buffer)
   549  	buffer.WriteString(simdPackageHeader)
   550  
   551  	sizes := make([]int, 0, len(typeMap))
   552  	for size, types := range typeMap {
   553  		slices.SortFunc(types, compareSimdTypes)
   554  		sizes = append(sizes, size)
   555  	}
   556  	sort.Ints(sizes)
   557  
   558  	for _, size := range sizes {
   559  		if size <= 64 {
   560  			// these are scalar
   561  			continue
   562  		}
   563  		if err := t.ExecuteTemplate(buffer, "sizeTmpl", size); err != nil {
   564  			panic(fmt.Errorf("failed to execute size template for size %d: %w", size, err))
   565  		}
   566  		for _, typeDef := range typeMap[size] {
   567  			if typeDef.Lanes == 1 {
   568  				continue
   569  			}
   570  			if err := t.ExecuteTemplate(buffer, "typeTmpl", typeDef); err != nil {
   571  				panic(fmt.Errorf("failed to execute type template for type %s: %w", typeDef.Name, err))
   572  			}
   573  			if typeDef.Type != "mask" {
   574  				if err := loadStore.ExecuteTemplate(buffer, "loadstore_amd64", typeDef); err != nil {
   575  					panic(fmt.Errorf("failed to execute loadstore template for type %s: %w", typeDef.Name, err))
   576  				}
   577  				// restrict to AVX2 masked loads/stores first.
   578  				if typeDef.MaskedLoadStoreFilter() {
   579  					if err := maskedLoadStore.ExecuteTemplate(buffer, "maskedloadstore_amd64", typeDef); err != nil {
   580  						panic(fmt.Errorf("failed to execute maskedloadstore template for type %s: %w", typeDef.Name, err))
   581  					}
   582  				}
   583  			} else {
   584  				if err := maskFromVal.ExecuteTemplate(buffer, "maskFromVal_amd64", typeDef); err != nil {
   585  					panic(fmt.Errorf("failed to execute maskFromVal template for type %s: %w", typeDef.Name, err))
   586  				}
   587  			}
   588  		}
   589  	}
   590  
   591  	return buffer
   592  }
   593  
   594  func writeSIMDFeatures(ops []Operation) *bytes.Buffer {
   595  	// Gather all features
   596  	type featureKey struct {
   597  		GoArch  string
   598  		Feature string
   599  	}
   600  	featureSet := make(map[featureKey]struct{})
   601  	for _, op := range ops {
   602  		// Generate a feature check for each independant feature in a
   603  		// composite feature.
   604  		for feature := range strings.SplitSeq(op.CPUFeature, ",") {
   605  			feature = strings.TrimSpace(feature)
   606  			featureSet[featureKey{op.GoArch, feature}] = struct{}{}
   607  		}
   608  	}
   609  	features := slices.SortedFunc(maps.Keys(featureSet), func(a, b featureKey) int {
   610  		if c := cmp.Compare(a.GoArch, b.GoArch); c != 0 {
   611  			return c
   612  		}
   613  		return compareNatural(a.Feature, b.Feature)
   614  	})
   615  
   616  	// If we ever have the same feature name on more than one GOARCH, we'll have
   617  	// to be more careful about this.
   618  	t := templateOf(simdFeaturesTemplate, "features")
   619  
   620  	buffer := new(bytes.Buffer)
   621  	buffer.WriteString(simdPackageHeader)
   622  
   623  	if err := t.Execute(buffer, features); err != nil {
   624  		panic(fmt.Errorf("failed to execute features template: %w", err))
   625  	}
   626  
   627  	return buffer
   628  }
   629  
   630  // writeSIMDStubs returns two bytes.Buffers containing the declarations for the public
   631  // and internal-use vector intrinsics.
   632  func writeSIMDStubs(ops []Operation, typeMap simdTypeMap) (f, fI *bytes.Buffer) {
   633  	t := templateOf(simdStubsTmpl, "simdStubs")
   634  	f = new(bytes.Buffer)
   635  	fI = new(bytes.Buffer)
   636  	f.WriteString(simdPackageHeader)
   637  	fI.WriteString(simdPackageHeader)
   638  
   639  	slices.SortFunc(ops, compareOperations)
   640  
   641  	for i, op := range ops {
   642  		if op.NoTypes != nil && *op.NoTypes == "true" {
   643  			continue
   644  		}
   645  		if op.SkipMaskedMethod() {
   646  			continue
   647  		}
   648  		idxVecAsScalar, err := checkVecAsScalar(op)
   649  		if err != nil {
   650  			panic(err)
   651  		}
   652  		if s, op, err := classifyOp(op); err == nil {
   653  			if idxVecAsScalar != -1 {
   654  				if s == "op2" || s == "op3" {
   655  					s += "VecAsScalar"
   656  				} else {
   657  					panic(fmt.Errorf("simdgen only supports op2 or op3 with TreatLikeAScalarOfSize"))
   658  				}
   659  			}
   660  			if i == 0 || op.Go != ops[i-1].Go {
   661  				if unicode.IsUpper([]rune(op.Go)[0]) {
   662  					fmt.Fprintf(f, "\n/* %s */\n", op.Go)
   663  				} else {
   664  					fmt.Fprintf(fI, "\n/* %s */\n", op.Go)
   665  				}
   666  			}
   667  			if unicode.IsUpper([]rune(op.Go)[0]) {
   668  				if err := t.ExecuteTemplate(f, s, op); err != nil {
   669  					panic(fmt.Errorf("failed to execute template %s for op %v: %w", s, op, err))
   670  				}
   671  			} else {
   672  				if err := t.ExecuteTemplate(fI, s, op); err != nil {
   673  					panic(fmt.Errorf("failed to execute template %s for op %v: %w", s, op, err))
   674  				}
   675  			}
   676  		} else {
   677  			panic(fmt.Errorf("failed to classify op %v: %w", op.Go, err))
   678  		}
   679  	}
   680  
   681  	vectorConversions := vConvertFromTypeMap(typeMap)
   682  	for _, conv := range vectorConversions {
   683  		if err := t.ExecuteTemplate(f, "vectorConversion", conv); err != nil {
   684  			panic(fmt.Errorf("failed to execute vectorConversion template: %w", err))
   685  		}
   686  	}
   687  
   688  	masks := masksFromTypeMap(typeMap)
   689  	for _, mask := range masks {
   690  		if err := t.ExecuteTemplate(f, "mask", mask); err != nil {
   691  			panic(fmt.Errorf("failed to execute mask template for mask %s: %w", mask.Name, err))
   692  		}
   693  	}
   694  
   695  	return
   696  }
   697  

View as plain text