Source file src/simd/_gen/tmplgen/main.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
// This program generates type-instantiated boilerplate code for
// slice operations, emulated vector methods, and test helpers.
     9  
    10  import (
    11  	"bufio"
    12  	"bytes"
    13  	"flag"
    14  	"fmt"
    15  	"go/format"
    16  	"io"
    17  	"os"
    18  	"strings"
    19  	"text/template"
    20  )
    21  
// resultTypeFunc maps an input shape (element base type t, element bit
// width w and element count c) to the output shape (ot, ow, oc) that
// oneTemplate uses to fill in the output fields of templateData.
type resultTypeFunc func(t string, w, c int) (ot string, ow int, oc int)
    23  
// shapes describes a combination of vector widths and various element types.
// Every vector width in vecs is combined with every element width listed
// in ints, uints and floats.
type shapes struct {
	vecs   []int // Vector bit width(s) for this shape, e.g. 128, 256, 512.
	ints   []int // Int element bit width(s) for this shape
	uints  []int // Unsigned int element bit width(s) for this shape
	floats []int // Float element bit width(s) for this shape
	// output, if non-nil, computes the output shape for each input shape;
	// it is handed to oneTemplate by forTemplates.
	output resultTypeFunc
}
    32  
// shapeAndTemplate is a template and the set of shapes on which it will be expanded
type shapeAndTemplate struct {
	s *shapes            // the shapes to expand t for; may be shared between templates
	t *template.Template // the parsed template
}
    38  
    39  func (sat shapeAndTemplate) target(outType string, width int) shapeAndTemplate {
    40  	newSat := sat
    41  	newShape := *sat.s
    42  	newShape.output = func(t string, w, c int) (ot string, ow int, oc int) {
    43  		return outType, width, c
    44  	}
    45  	newSat.s = &newShape
    46  	return newSat
    47  }
    48  
    49  func (sat shapeAndTemplate) shrinkTo(outType string, by int) shapeAndTemplate {
    50  	newSat := sat
    51  	newShape := *sat.s
    52  	newShape.output = func(t string, w, c int) (ot string, ow int, oc int) {
    53  		return outType, w / by, c * by
    54  	}
    55  	newSat.s = &newShape
    56  	return newSat
    57  }
    58  
    59  func (s *shapes) forAllShapes(f func(seq int, t, upperT string, w, c int, out io.Writer), out io.Writer) {
    60  	vecs := s.vecs
    61  	ints := s.ints
    62  	uints := s.uints
    63  	floats := s.floats
    64  	seq := 0
    65  	for _, v := range vecs {
    66  		for _, w := range ints {
    67  			c := v / w
    68  			f(seq, "int", "Int", w, c, out)
    69  			seq++
    70  		}
    71  		for _, w := range uints {
    72  			c := v / w
    73  			f(seq, "uint", "Uint", w, c, out)
    74  			seq++
    75  		}
    76  		for _, w := range floats {
    77  			c := v / w
    78  			f(seq, "float", "Float", w, c, out)
    79  			seq++
    80  		}
    81  	}
    82  }
    83  
// allShapes combines the 128-, 256- and 512-bit vector widths with
// 8/16/32/64-bit ints and uints and 32/64-bit floats.
var allShapes = &shapes{
	vecs:   []int{128, 256, 512},
	ints:   []int{8, 16, 32, 64},
	uints:  []int{8, 16, 32, 64},
	floats: []int{32, 64},
}

// intShapes is the signed-integer-only subset of allShapes.
var intShapes = &shapes{
	vecs: []int{128, 256, 512},
	ints: []int{8, 16, 32, 64},
}

// uintShapes is the unsigned-integer-only subset of allShapes.
var uintShapes = &shapes{
	vecs:  []int{128, 256, 512},
	uints: []int{8, 16, 32, 64},
}

// avx512Shapes has the same element types as allShapes, but only
// for 512-bit vectors.
var avx512Shapes = &shapes{
	vecs:   []int{512},
	ints:   []int{8, 16, 32, 64},
	uints:  []int{8, 16, 32, 64},
	floats: []int{32, 64},
}

// avx2Shapes has the same element types as allShapes, but only
// for 128- and 256-bit vectors.
var avx2Shapes = &shapes{
	vecs:   []int{128, 256},
	ints:   []int{8, 16, 32, 64},
	uints:  []int{8, 16, 32, 64},
	floats: []int{32, 64},
}

// avx2MaskedLoadShapes lists the 128- and 256-bit vectors with 32- and
// 64-bit elements; it is the shape set of avx2MaskedLoadSlicePartTemplate.
var avx2MaskedLoadShapes = &shapes{
	vecs:   []int{128, 256},
	ints:   []int{32, 64},
	uints:  []int{32, 64},
	floats: []int{32, 64},
}

// avx2SmallLoadPunShapes lists the 128- and 256-bit vectors with 8- and
// 16-bit unsigned elements; it is the shape set of
// avx2SmallLoadSlicePartTemplate.
var avx2SmallLoadPunShapes = &shapes{
	// ints are done by hand, these are type-punned to int.
	vecs:  []int{128, 256},
	uints: []int{8, 16},
}

var unaryFlaky = &shapes{ // for tests that support flaky equality
	vecs:   []int{128, 256, 512},
	floats: []int{32, 64},
}

var ternaryFlaky = &shapes{ // for tests that support flaky equality
	vecs:   []int{128, 256, 512},
	floats: []int{32},
}

// avx2SignedComparisons lists the 128- and 256-bit signed integer vectors
// that get emulated comparison methods from avx2SignedComparisonsTemplate.
var avx2SignedComparisons = &shapes{
	vecs: []int{128, 256},
	ints: []int{8, 16, 32, 64},
}

// avx2UnsignedComparisons lists the 128- and 256-bit unsigned integer vectors
// that get emulated comparison methods from avx2UnsignedComparisonsTemplate.
var avx2UnsignedComparisons = &shapes{
	vecs:  []int{128, 256},
	uints: []int{8, 16, 32, 64},
}
   147  
// templateData is the value every template is executed with. It describes
// one input vector shape and, in the O* fields, the output shape computed
// for it (equal to the input shape when the shapes have no output function).
type templateData struct {
	VType  string // the type of the vector, e.g. Float32x4
	AOrAn  string // for documentation, the article "a" or "an"
	EWidth int    // the bit width of the element type, e.g. 32
	Vwidth int    // the width of the vector type, e.g. 128
	Count  int    // the number of elements, e.g. 4
	WxC    string // the width-by-count string, e.g., "32x4"
	BxC    string // as if bytes, in the proper count, e.g., "8x16" (W==8)
	Base   string // the title-case Base Type of the vector, e.g., "Float"
	Etype  string // the element type, e.g. "float32"
	OxFF   string // a hex literal with the lowest 'count' bits set, e.g. "0xf"

	OVType string // type of output vector
	OEtype string // output element type
	OEType string // output element type, title-case
	OCount int    // output element count
}
   165  
   166  func (t templateData) As128BitVec() string {
   167  	return fmt.Sprintf("%s%dx%d", t.Base, t.EWidth, 128/t.EWidth)
   168  }
   169  
   170  func oneTemplate(t *template.Template, baseType string, width, count int, out io.Writer, rtf resultTypeFunc) {
   171  	b := width * count
   172  	if b < 128 || b > 512 {
   173  		return
   174  	}
   175  
   176  	ot, ow, oc := baseType, width, count
   177  	if rtf != nil {
   178  		ot, ow, oc = rtf(ot, ow, oc)
   179  		if ow*oc > 512 || ow*oc < 128 || ow < 8 || ow > 64 {
   180  			return
   181  		}
   182  		// TODO someday we will support conversions to 16-bit floats
   183  		if ot == "float" && ow < 32 {
   184  			return
   185  		}
   186  	}
   187  	ovType := fmt.Sprintf("%s%dx%d", strings.ToUpper(ot[:1])+ot[1:], ow, oc)
   188  	oeType := fmt.Sprintf("%s%d", ot, ow)
   189  	oEType := fmt.Sprintf("%s%d", strings.ToUpper(ot[:1])+ot[1:], ow)
   190  
   191  	wxc := fmt.Sprintf("%dx%d", width, count)
   192  	BaseType := strings.ToUpper(baseType[:1]) + baseType[1:]
   193  	vType := fmt.Sprintf("%s%s", BaseType, wxc)
   194  	eType := fmt.Sprintf("%s%d", baseType, width)
   195  
   196  	bxc := fmt.Sprintf("%dx%d", 8, count*(width/8))
   197  	aOrAn := "a"
   198  	if strings.Contains("aeiou", baseType[:1]) {
   199  		aOrAn = "an"
   200  	}
   201  	oxFF := fmt.Sprintf("0x%x", uint64((1<<count)-1))
   202  	t.Execute(out, templateData{
   203  		VType:  vType,
   204  		AOrAn:  aOrAn,
   205  		EWidth: width,
   206  		Vwidth: b,
   207  		Count:  count,
   208  		WxC:    wxc,
   209  		BxC:    bxc,
   210  		Base:   BaseType,
   211  		Etype:  eType,
   212  		OxFF:   oxFF,
   213  		OVType: ovType,
   214  		OEtype: oeType,
   215  		OCount: oc,
   216  		OEType: oEType,
   217  	})
   218  }
   219  
   220  // forTemplates expands the template sat.t for each shape
   221  // in sat.s, writing to out.
   222  func (sat shapeAndTemplate) forTemplates(out io.Writer) {
   223  	t, s := sat.t, sat.s
   224  	vecs := s.vecs
   225  	ints := s.ints
   226  	uints := s.uints
   227  	floats := s.floats
   228  	for _, v := range vecs {
   229  		for _, w := range ints {
   230  			c := v / w
   231  			oneTemplate(t, "int", w, c, out, sat.s.output)
   232  		}
   233  		for _, w := range uints {
   234  			c := v / w
   235  			oneTemplate(t, "uint", w, c, out, sat.s.output)
   236  		}
   237  		for _, w := range floats {
   238  			c := v / w
   239  			oneTemplate(t, "float", w, c, out, sat.s.output)
   240  		}
   241  	}
   242  }
   243  
   244  func prologue(s string, out io.Writer) {
   245  	fmt.Fprintf(out,
   246  		`// Code generated by '%s'; DO NOT EDIT.
   247  
   248  //go:build goexperiment.simd
   249  
   250  package simd
   251  
   252  `, s)
   253  }
   254  
   255  func ssaPrologue(s string, out io.Writer) {
   256  	fmt.Fprintf(out,
   257  		`// Code generated by '%s'; DO NOT EDIT.
   258  
   259  package ssa
   260  
   261  `, s)
   262  }
   263  
   264  func unsafePrologue(s string, out io.Writer) {
   265  	fmt.Fprintf(out,
   266  		`// Code generated by '%s'; DO NOT EDIT.
   267  
   268  //go:build goexperiment.simd
   269  
   270  package simd
   271  
   272  import "unsafe"
   273  
   274  `, s)
   275  }
   276  
   277  func testPrologue(t, s string, out io.Writer) {
   278  	fmt.Fprintf(out,
   279  		`// Code generated by '%s'; DO NOT EDIT.
   280  
   281  //go:build goexperiment.simd
   282  
   283  // This file contains functions testing %s.
   284  // Each function in this file is specialized for a
   285  // particular simd type <BaseType><Width>x<Count>.
   286  
   287  package simd_test
   288  
   289  import (
   290  	"simd"
   291  	"testing"
   292  )
   293  
   294  `, s, t)
   295  }
   296  
   297  func curryTestPrologue(t string) func(s string, out io.Writer) {
   298  	return func(s string, out io.Writer) {
   299  		testPrologue(t, s, out)
   300  	}
   301  }
   302  
   303  func templateOf(name, temp string) shapeAndTemplate {
   304  	return shapeAndTemplate{s: allShapes,
   305  		t: template.Must(template.New(name).Parse(temp))}
   306  }
   307  
   308  func shapedTemplateOf(s *shapes, name, temp string) shapeAndTemplate {
   309  	return shapeAndTemplate{s: s,
   310  		t: template.Must(template.New(name).Parse(temp))}
   311  }
   312  
// sliceTemplate generates, for every shape in allShapes, a Load<VType>Slice
// function and a StoreSlice method. Both convert the slice to a pointer to
// an array of Count elements and forward to Load<VType> / Store.
var sliceTemplate = templateOf("slice", `
// Load{{.VType}}Slice loads {{.AOrAn}} {{.VType}} from a slice of at least {{.Count}} {{.Etype}}s
func Load{{.VType}}Slice(s []{{.Etype}}) {{.VType}} {
	return Load{{.VType}}((*[{{.Count}}]{{.Etype}})(s))
}

// StoreSlice stores x into a slice of at least {{.Count}} {{.Etype}}s
func (x {{.VType}}) StoreSlice(s []{{.Etype}}) {
	x.Store((*[{{.Count}}]{{.Etype}})(s))
}
`)
   324  
// unaryTemplate generates, for every shape in allShapes, a test<VType>Unary
// helper that checks a one-operand vector function against a slice-based
// reference function, requiring exact equality (tolerance 0.0).
var unaryTemplate = templateOf("unary_helpers", `
// test{{.VType}}Unary tests the simd unary method f against the expected behavior generated by want
func test{{.VType}}Unary(t *testing.T, f func(_ simd.{{.VType}}) simd.{{.VType}}, want func(_ []{{.Etype}}) []{{.Etype}}) {
	n := {{.Count}}
	t.Helper()
	forSlice(t, {{.Etype}}s, n, func(x []{{.Etype}}) bool {
	 	t.Helper()
		a := simd.Load{{.VType}}Slice(x)
		g := make([]{{.Etype}}, n)
		f(a).StoreSlice(g)
		w := want(x)
		return checkSlicesLogInput(t, g, w, 0.0, func() {t.Helper(); t.Logf("x=%v", x)})
	})
}
`)
   340  
// unaryFlakyTemplate generates, for the float shapes in unaryFlaky, a
// test<VType>UnaryFlaky helper. It is the same as the unary helper except
// that the caller supplies the tolerance passed to checkSlicesLogInput.
var unaryFlakyTemplate = shapedTemplateOf(unaryFlaky, "unary_flaky_helpers", `
// test{{.VType}}UnaryFlaky tests the simd unary method f against the expected behavior generated by want,
// but using a flakiness parameter because we haven't exactly figured out how simd floating point works
func test{{.VType}}UnaryFlaky(t *testing.T, f func(x simd.{{.VType}}) simd.{{.VType}}, want func(x []{{.Etype}}) []{{.Etype}}, flakiness float64) {
	n := {{.Count}}
	t.Helper()
	forSlice(t, {{.Etype}}s, n, func(x []{{.Etype}}) bool {
	 	t.Helper()
		a := simd.Load{{.VType}}Slice(x)
		g := make([]{{.Etype}}, n)
		f(a).StoreSlice(g)
		w := want(x)
		return checkSlicesLogInput(t, g, w, flakiness, func() {t.Helper(); t.Logf("x=%v", x)})
	})
}
`)
   357  
// convertTemplate generates test<VType>ConvertTo<OEType> helpers for
// element-count-preserving conversions. The output type names come from the
// O* fields of templateData, so the template is meant to be expanded through
// a copy with an output function attached (see target).
var convertTemplate = templateOf("convert_helpers", `
// test{{.VType}}ConvertTo{{.OEType}} tests the simd conversion method f against the expected behavior generated by want
// This is for count-preserving conversions, so if there is a change in size, then there is a change in vector width.
func test{{.VType}}ConvertTo{{.OEType}}(t *testing.T, f func(x simd.{{.VType}}) simd.{{.OVType}}, want func(x []{{.Etype}}) []{{.OEtype}}) {
	n := {{.Count}}
	t.Helper()
	forSlice(t, {{.Etype}}s, n, func(x []{{.Etype}}) bool {
	 	t.Helper()
		a := simd.Load{{.VType}}Slice(x)
		g := make([]{{.OEtype}}, n)
		f(a).StoreSlice(g)
		w := want(x)
		return checkSlicesLogInput(t, g, w, 0.0, func() {t.Helper(); t.Logf("x=%v", x)})
	})
}
`)
   374  
// Conversion test helpers: convertTemplate retargeted to a fixed output
// element type. Input shapes whose output vector would not be a
// 128..512-bit vector are skipped by oneTemplate.
var unaryToInt32 = convertTemplate.target("int", 32)
var unaryToUint32 = convertTemplate.target("uint", 32)
var unaryToUint16 = convertTemplate.target("uint", 16)
   378  
// binaryTemplate generates, for every shape in allShapes, a
// test<VType>Binary helper that checks a two-operand vector function
// against a slice-based reference function, requiring exact equality.
var binaryTemplate = templateOf("binary_helpers", `
// test{{.VType}}Binary tests the simd binary method f against the expected behavior generated by want
func test{{.VType}}Binary(t *testing.T, f func(_, _ simd.{{.VType}}) simd.{{.VType}}, want func(_, _ []{{.Etype}}) []{{.Etype}}) {
	n := {{.Count}}
	t.Helper()
	forSlicePair(t, {{.Etype}}s, n, func(x, y []{{.Etype}}) bool {
	 	t.Helper()
		a := simd.Load{{.VType}}Slice(x)
		b := simd.Load{{.VType}}Slice(y)
		g := make([]{{.Etype}}, n)
		f(a, b).StoreSlice(g)
		w := want(x, y)
		return checkSlicesLogInput(t, g, w, 0.0, func() {t.Helper(); t.Logf("x=%v", x); t.Logf("y=%v", y); })
	})
}
`)
   395  
// ternaryTemplate generates, for every shape in allShapes, a
// test<VType>Ternary helper that checks a three-operand vector function
// against a slice-based reference function, requiring exact equality.
var ternaryTemplate = templateOf("ternary_helpers", `
// test{{.VType}}Ternary tests the simd ternary method f against the expected behavior generated by want
func test{{.VType}}Ternary(t *testing.T, f func(_, _, _ simd.{{.VType}}) simd.{{.VType}}, want func(_, _, _ []{{.Etype}}) []{{.Etype}}) {
	n := {{.Count}}
	t.Helper()
	forSliceTriple(t, {{.Etype}}s, n, func(x, y, z []{{.Etype}}) bool {
	 	t.Helper()
		a := simd.Load{{.VType}}Slice(x)
		b := simd.Load{{.VType}}Slice(y)
		c := simd.Load{{.VType}}Slice(z)
		g := make([]{{.Etype}}, n)
		f(a, b, c).StoreSlice(g)
		w := want(x, y, z)
		return checkSlicesLogInput(t, g, w, 0.0, func() {t.Helper(); t.Logf("x=%v", x); t.Logf("y=%v", y); t.Logf("z=%v", z); })
	})
}
`)
   413  
   414  var ternaryFlakyTemplate = shapedTemplateOf(ternaryFlaky, "ternary_helpers", `
   415  // test{{.VType}}TernaryFlaky tests the simd ternary method f against the expected behavior generated by want,
   416  // but using a flakiness parameter because we haven't exactly figured out how simd floating point works
   417  func test{{.VType}}TernaryFlaky(t *testing.T, f func(x, y, z simd.{{.VType}}) simd.{{.VType}}, want func(x, y, z []{{.Etype}}) []{{.Etype}}, flakiness float64) {
   418  	n := {{.Count}}
   419  	t.Helper()
   420  	forSliceTriple(t, {{.Etype}}s, n, func(x, y, z []{{.Etype}}) bool {
   421  	 	t.Helper()
   422  		a := simd.Load{{.VType}}Slice(x)
   423  		b := simd.Load{{.VType}}Slice(y)
   424  		c := simd.Load{{.VType}}Slice(z)
   425  		g := make([]{{.Etype}}, n)
   426  		f(a, b, c).StoreSlice(g)
   427  		w := want(x, y, z)
   428  		return checkSlicesLogInput(t, g, w, flakiness, func() {t.Helper(); t.Logf("x=%v", x); t.Logf("y=%v", y); t.Logf("z=%v", z); })
   429  	})
   430  }
   431  `)
   432  
// compareTemplate generates, for every shape in allShapes, a
// test<VType>Compare helper. The mask returned by the comparison is
// converted to an integer vector of the same element width, stored, and
// compared (after passing through s64) with the []int64 produced by want.
var compareTemplate = templateOf("compare_helpers", `
// test{{.VType}}Compare tests the simd comparison method f against the expected behavior generated by want
func test{{.VType}}Compare(t *testing.T, f func(_, _ simd.{{.VType}}) simd.Mask{{.WxC}}, want func(_, _ []{{.Etype}}) []int64) {
	n := {{.Count}}
	t.Helper()
	forSlicePair(t, {{.Etype}}s, n, func(x, y []{{.Etype}}) bool {
	 	t.Helper()
		a := simd.Load{{.VType}}Slice(x)
		b := simd.Load{{.VType}}Slice(y)
		g := make([]int{{.EWidth}}, n)
		f(a, b).AsInt{{.WxC}}().StoreSlice(g)
		w := want(x, y)
		return checkSlicesLogInput(t, s64(g), w, 0.0, func() {t.Helper(); t.Logf("x=%v", x); t.Logf("y=%v", y); })
	})
}
`)
   449  
// compareMaskedTemplate generates, for every shape in allShapes, a
// test<VType>CompareMasked helper. It works like the compare helper, but
// additionally builds a mask from a []bool and zeroes the expected result
// wherever the mask is false.
//
// TODO this has not been tested yet.
var compareMaskedTemplate = templateOf("comparemasked_helpers", `
// test{{.VType}}CompareMasked tests the simd masked comparison method f against the expected behavior generated by want
// The mask is applied to the output of want; anything not in the mask, is zeroed.
func test{{.VType}}CompareMasked(t *testing.T,
	f func(_, _ simd.{{.VType}}, m simd.Mask{{.WxC}}) simd.Mask{{.WxC}},
	want func(_, _ []{{.Etype}}) []int64) {
	n := {{.Count}}
	t.Helper()
	forSlicePairMasked(t, {{.Etype}}s, n, func(x, y []{{.Etype}}, m []bool) bool {
	 	t.Helper()
		a := simd.Load{{.VType}}Slice(x)
		b := simd.Load{{.VType}}Slice(y)
		k := simd.LoadInt{{.WxC}}Slice(toVect[int{{.EWidth}}](m)).ToMask()
		g := make([]int{{.EWidth}}, n)
		f(a, b, k).AsInt{{.WxC}}().StoreSlice(g)
		w := want(x, y)
		for i := range m {
			if !m[i] {
				w[i] = 0
			}
		}
		return checkSlicesLogInput(t, s64(g), w, 0.0, func() {t.Helper(); t.Logf("x=%v", x); t.Logf("y=%v", y); t.Logf("m=%v", m); })
	})
}
`)
   476  
// avx512MaskedLoadSlicePartTemplate generates Load<VType>SlicePart and
// StoreSlicePart for the 512-bit shapes in avx512Shapes. For short slices
// the generated code builds a mask from OxFF shifted right by the number
// of missing elements and uses a masked load/store on the pa<VType> pointer.
var avx512MaskedLoadSlicePartTemplate = shapedTemplateOf(avx512Shapes, "avx 512 load slice part", `
// Load{{.VType}}SlicePart loads a {{.VType}} from the slice s.
// If s has fewer than {{.Count}} elements, the remaining elements of the vector are filled with zeroes.
// If s has {{.Count}} or more elements, the function is equivalent to Load{{.VType}}Slice.
func Load{{.VType}}SlicePart(s []{{.Etype}}) {{.VType}} {
	l := len(s)
	if l >= {{.Count}} {
		return Load{{.VType}}Slice(s)
	}
	if l == 0 {
		var x {{.VType}}
		return x
	}
	mask := Mask{{.WxC}}FromBits({{.OxFF}} >> ({{.Count}} - l))
	return LoadMasked{{.VType}}(pa{{.VType}}(s), mask)
}

// StoreSlicePart stores the {{.Count}} elements of x into the slice s.
// It stores as many elements as will fit in s.
// If s has {{.Count}} or more elements, the method is equivalent to x.StoreSlice.
func (x {{.VType}}) StoreSlicePart(s []{{.Etype}}) {
	l := len(s)
	if l >= {{.Count}} {
		x.StoreSlice(s)
		return
	}
	if l == 0 {
		return
	}
	mask := Mask{{.WxC}}FromBits({{.OxFF}} >> ({{.Count}} - l))
	x.StoreMasked(pa{{.VType}}(s), mask)
}
`)
   510  
// avx2MaskedLoadSlicePartTemplate generates Load<VType>SlicePart and
// StoreSlicePart for the shapes in avx2MaskedLoadShapes. For short slices
// the generated code takes its mask from a window of the vecMask<EWidth>
// table (defined outside this template) and uses a masked load/store on
// the pa<VType> pointer.
var avx2MaskedLoadSlicePartTemplate = shapedTemplateOf(avx2MaskedLoadShapes, "avx 2 load slice part", `
// Load{{.VType}}SlicePart loads a {{.VType}} from the slice s.
// If s has fewer than {{.Count}} elements, the remaining elements of the vector are filled with zeroes.
// If s has {{.Count}} or more elements, the function is equivalent to Load{{.VType}}Slice.
func Load{{.VType}}SlicePart(s []{{.Etype}}) {{.VType}} {
	l := len(s)
	if l >= {{.Count}} {
		return Load{{.VType}}Slice(s)
	}
	if l == 0 {
		var x {{.VType}}
		return x
	}
	mask := vecMask{{.EWidth}}[len(vecMask{{.EWidth}})/2-l:]
	return LoadMasked{{.VType}}(pa{{.VType}}(s), LoadInt{{.WxC}}Slice(mask).asMask())
}

// StoreSlicePart stores the {{.Count}} elements of x into the slice s.
// It stores as many elements as will fit in s.
// If s has {{.Count}} or more elements, the method is equivalent to x.StoreSlice.
func (x {{.VType}}) StoreSlicePart(s []{{.Etype}}) {
	l := len(s)
	if l >= {{.Count}} {
		x.StoreSlice(s)
		return
	}
	if l == 0 {
		return
	}
	mask := vecMask{{.EWidth}}[len(vecMask{{.EWidth}})/2-l:]
	x.StoreMasked(pa{{.VType}}(s), LoadInt{{.WxC}}Slice(mask).asMask())
}
`)
   544  
// avx2SmallLoadSlicePartTemplate generates Load<VType>SlicePart and
// StoreSlicePart for the 8- and 16-bit unsigned shapes in
// avx2SmallLoadPunShapes. The generated code reinterprets the slice as a
// slice of the signed type of the same width and forwards to the
// corresponding Int<WxC> SlicePart operation.
var avx2SmallLoadSlicePartTemplate = shapedTemplateOf(avx2SmallLoadPunShapes, "avx 2 small load slice part", `
// Load{{.VType}}SlicePart loads a {{.VType}} from the slice s.
// If s has fewer than {{.Count}} elements, the remaining elements of the vector are filled with zeroes.
// If s has {{.Count}} or more elements, the function is equivalent to Load{{.VType}}Slice.
func Load{{.VType}}SlicePart(s []{{.Etype}}) {{.VType}} {
	if len(s) == 0 {
		var zero {{.VType}}
		return zero
	}
	t := unsafe.Slice((*int{{.EWidth}})(unsafe.Pointer(&s[0])), len(s))
	return LoadInt{{.WxC}}SlicePart(t).As{{.VType}}()
}

// StoreSlicePart stores the {{.Count}} elements of x into the slice s.
// It stores as many elements as will fit in s.
// If s has {{.Count}} or more elements, the method is equivalent to x.StoreSlice.
func (x {{.VType}}) StoreSlicePart(s []{{.Etype}}) {
	if len(s) == 0 {
		return
	}
	t := unsafe.Slice((*int{{.EWidth}})(unsafe.Pointer(&s[0])), len(s))
	x.AsInt{{.WxC}}().StoreSlicePart(t)
}
`)
   569  
   570  func (t templateData) CPUfeature() string {
   571  	switch t.Vwidth {
   572  	case 128:
   573  		return "AVX"
   574  	case 256:
   575  		return "AVX2"
   576  	case 512:
   577  		return "AVX512"
   578  	}
   579  	panic(fmt.Errorf("unexpected vector width %d", t.Vwidth))
   580  }
   581  
// avx2SignedComparisonsTemplate generates emulated Less, GreaterEqual,
// LessEqual and NotEqual methods for the signed shapes in
// avx2SignedComparisons. Each is expressed through Greater or Equal; the
// negated forms XOR the result with an all-ones vector obtained from
// x.Equal(x).
var avx2SignedComparisonsTemplate = shapedTemplateOf(avx2SignedComparisons, "avx2 signed comparisons", `
// Less returns a mask whose elements indicate whether x < y
//
// Emulated, CPU Feature {{.CPUfeature}}
func (x {{.VType}}) Less(y {{.VType}}) Mask{{.WxC}} {
	return y.Greater(x)
}

// GreaterEqual returns a mask whose elements indicate whether x >= y
//
// Emulated, CPU Feature {{.CPUfeature}}
func (x {{.VType}}) GreaterEqual(y {{.VType}}) Mask{{.WxC}} {
	ones := x.Equal(x).AsInt{{.WxC}}()
	return y.Greater(x).AsInt{{.WxC}}().Xor(ones).asMask()
}

// LessEqual returns a mask whose elements indicate whether x <= y
//
// Emulated, CPU Feature {{.CPUfeature}}
func (x {{.VType}}) LessEqual(y {{.VType}}) Mask{{.WxC}} {
	ones := x.Equal(x).AsInt{{.WxC}}()
	return x.Greater(y).AsInt{{.WxC}}().Xor(ones).asMask()
}

// NotEqual returns a mask whose elements indicate whether x != y
//
// Emulated, CPU Feature {{.CPUfeature}}
func (x {{.VType}}) NotEqual(y {{.VType}}) Mask{{.WxC}} {
	ones := x.Equal(x).AsInt{{.WxC}}()
	return x.Equal(y).AsInt{{.WxC}}().Xor(ones).asMask()
}
`)
   614  
// bitWiseIntTemplate generates an emulated Not method for the signed
// shapes in intShapes by XORing x with an all-ones vector obtained from
// x.Equal(x).
var bitWiseIntTemplate = shapedTemplateOf(intShapes, "bitwise int complement", `
// Not returns the bitwise complement of x
//
// Emulated, CPU Feature {{.CPUfeature}}
func (x {{.VType}}) Not() {{.VType}} {
	return x.Xor(x.Equal(x).As{{.VType}}())
}
`)
   623  
// bitWiseUintTemplate generates an emulated Not method for the unsigned
// shapes in uintShapes. It matches the signed version except that the
// all-ones mask is converted through Int<WxC> before becoming a <VType>.
var bitWiseUintTemplate = shapedTemplateOf(uintShapes, "bitwise uint complement", `
// Not returns the bitwise complement of x
//
// Emulated, CPU Feature {{.CPUfeature}}
func (x {{.VType}}) Not() {{.VType}} {
	return x.Xor(x.Equal(x).AsInt{{.WxC}}().As{{.VType}}())
}
`)
   632  
   633  // CPUfeatureAVX2if8 return AVX2 if the element width is 8,
   634  // otherwise, it returns CPUfeature.  This is for the cpufeature
   635  // of unsigned comparison emulation, which uses shifts for all
   636  // the sizes > 8 (shifts are AVX) but must use broadcast (AVX2)
   637  // for bytes.
   638  func (t templateData) CPUfeatureAVX2if8() string {
   639  	if t.EWidth == 8 {
   640  		return "AVX2"
   641  	}
   642  	return t.CPUfeature()
   643  }
   644  
// avx2UnsignedComparisonsTemplate generates emulated Greater, Less,
// GreaterEqual, LessEqual and NotEqual methods for the unsigned shapes in
// avx2UnsignedComparisons. The ordered comparisons XOR both operands with
// a vector holding only the sign bit of each element and then use the
// signed Greater; for 8-bit elements that vector comes from a broadcast,
// otherwise from shifting an all-ones vector left by EWidth-1.
var avx2UnsignedComparisonsTemplate = shapedTemplateOf(avx2UnsignedComparisons, "avx2 unsigned comparisons", `
// Greater returns a mask whose elements indicate whether x > y
//
// Emulated, CPU Feature {{.CPUfeatureAVX2if8}}
func (x {{.VType}}) Greater(y {{.VType}}) Mask{{.WxC}} {
	a, b := x.AsInt{{.WxC}}(), y.AsInt{{.WxC}}()
{{- if eq .EWidth 8}}
	signs := BroadcastInt{{.WxC}}(-1 << ({{.EWidth}}-1))
{{- else}}
	ones := x.Equal(x).AsInt{{.WxC}}()
	signs := ones.ShiftAllLeft({{.EWidth}}-1)
{{- end }}
	return a.Xor(signs).Greater(b.Xor(signs))
}

// Less returns a mask whose elements indicate whether x < y
//
// Emulated, CPU Feature {{.CPUfeatureAVX2if8}}
func (x {{.VType}}) Less(y {{.VType}}) Mask{{.WxC}} {
	a, b := x.AsInt{{.WxC}}(), y.AsInt{{.WxC}}()
{{- if eq .EWidth 8}}
	signs := BroadcastInt{{.WxC}}(-1 << ({{.EWidth}}-1))
{{- else}}
	ones := x.Equal(x).AsInt{{.WxC}}()
	signs := ones.ShiftAllLeft({{.EWidth}}-1)
{{- end }}
	return b.Xor(signs).Greater(a.Xor(signs))
}

// GreaterEqual returns a mask whose elements indicate whether x >= y
//
// Emulated, CPU Feature {{.CPUfeatureAVX2if8}}
func (x {{.VType}}) GreaterEqual(y {{.VType}}) Mask{{.WxC}} {
	a, b := x.AsInt{{.WxC}}(), y.AsInt{{.WxC}}()
	ones := x.Equal(x).AsInt{{.WxC}}()
{{- if eq .EWidth 8}}
	signs := BroadcastInt{{.WxC}}(-1 << ({{.EWidth}}-1))
{{- else}}
	signs := ones.ShiftAllLeft({{.EWidth}}-1)
{{- end }}
	return b.Xor(signs).Greater(a.Xor(signs)).AsInt{{.WxC}}().Xor(ones).asMask()
}

// LessEqual returns a mask whose elements indicate whether x <= y
//
// Emulated, CPU Feature {{.CPUfeatureAVX2if8}}
func (x {{.VType}}) LessEqual(y {{.VType}}) Mask{{.WxC}} {
	a, b := x.AsInt{{.WxC}}(), y.AsInt{{.WxC}}()
	ones := x.Equal(x).AsInt{{.WxC}}()
{{- if eq .EWidth 8}}
	signs := BroadcastInt{{.WxC}}(-1 << ({{.EWidth}}-1))
{{- else}}
	signs := ones.ShiftAllLeft({{.EWidth}}-1)
{{- end }}
	return a.Xor(signs).Greater(b.Xor(signs)).AsInt{{.WxC}}().Xor(ones).asMask()
}

// NotEqual returns a mask whose elements indicate whether x != y
//
// Emulated, CPU Feature {{.CPUfeature}}
func (x {{.VType}}) NotEqual(y {{.VType}}) Mask{{.WxC}} {
	a, b := x.AsInt{{.WxC}}(), y.AsInt{{.WxC}}()
	ones := x.Equal(x).AsInt{{.WxC}}()
	return a.Equal(b).AsInt{{.WxC}}().Xor(ones).asMask()
}
`)
   711  
// unsafePATemplate generates, for every shape in allShapes, a pa<VType>
// helper that converts the address of s[0] into a pointer to an array of
// Count elements without checking the slice length. The generated helper
// indexes s[0], so callers must not pass an empty slice.
var unsafePATemplate = templateOf("unsafe PA helper", `
// pa{{.VType}} returns a type-unsafe pointer to array that can
// only be used with partial load/store operations that only
// access the known-safe portions of the array.
func pa{{.VType}}(s []{{.Etype}}) *[{{.Count}}]{{.Etype}} {
	return (*[{{.Count}}]{{.Etype}})(unsafe.Pointer(&s[0]))
}
`)
   720  
// avx2MaskedTemplate generates Masked and Merge methods for the 128- and
// 256-bit shapes in avx2Shapes. Masked ANDs x with the mask viewed as an
// integer vector; Merge uses blend on the byte-wise (BxC) integer view of
// the operands. Non-Int types are converted to the integer view and back.
var avx2MaskedTemplate = shapedTemplateOf(avx2Shapes, "avx2 .Masked methods", `
// Masked returns x but with elements zeroed where mask is false.
func (x {{.VType}}) Masked(mask Mask{{.WxC}}) {{.VType}} {
	im := mask.AsInt{{.WxC}}()
{{- if eq .Base "Int" }}
	return im.And(x)
{{- else}}
    return x.AsInt{{.WxC}}().And(im).As{{.VType}}()
{{- end -}}
}

// Merge returns x but with elements set to y where mask is false.
func (x {{.VType}}) Merge(y {{.VType}}, mask Mask{{.WxC}}) {{.VType}} {
{{- if eq .BxC .WxC -}}
	im := mask.AsInt{{.BxC}}()
{{- else}}
    im := mask.AsInt{{.WxC}}().AsInt{{.BxC}}()
{{- end -}}
{{- if and (eq .Base "Int") (eq .BxC .WxC) }}
	return y.blend(x, im)
{{- else}}
	ix := x.AsInt{{.BxC}}()
	iy := y.AsInt{{.BxC}}()
	return iy.blend(ix, im).As{{.VType}}()
{{- end -}}
}
`)
   748  
   749  // TODO perhaps write these in ways that work better on AVX512
   750  var avx512MaskedTemplate = shapedTemplateOf(avx512Shapes, "avx512 .Masked methods", `
   751  // Masked returns x but with elements zeroed where mask is false.
   752  func (x {{.VType}}) Masked(mask Mask{{.WxC}}) {{.VType}} {
   753  	im := mask.AsInt{{.WxC}}()
   754  {{- if eq .Base "Int" }}
   755  	return im.And(x)
   756  {{- else}}
   757      return x.AsInt{{.WxC}}().And(im).As{{.VType}}()
   758  {{- end -}}
   759  }
   760  
   761  // Merge returns x but with elements set to y where m is false.
   762  func (x {{.VType}}) Merge(y {{.VType}}, mask Mask{{.WxC}}) {{.VType}} {
   763  {{- if eq .Base "Int" }}
   764  	return y.blendMasked(x, mask)
   765  {{- else}}
   766  	ix := x.AsInt{{.WxC}}()
   767  	iy := y.AsInt{{.WxC}}()
   768  	return iy.blendMasked(ix, mask).As{{.VType}}()
   769  {{- end -}}
   770  }
   771  `)
   772  
   773  func (t templateData) CPUfeatureBC() string {
   774  	switch t.Vwidth {
   775  	case 128:
   776  		return "AVX2"
   777  	case 256:
   778  		return "AVX2"
   779  	case 512:
   780  		if t.EWidth <= 16 {
   781  			return "AVX512BW"
   782  		}
   783  		return "AVX512F"
   784  	}
   785  	panic(fmt.Errorf("unexpected vector width %d", t.Vwidth))
   786  }
   787  
// broadcastTemplate generates, for every shape in allShapes, a
// Broadcast<VType> function. The generated code sets element 0 of a zero
// 128-bit vector of the same element type (see As128BitVec) to x and then
// calls its Broadcast<Vwidth> method.
var broadcastTemplate = templateOf("Broadcast functions", `
// Broadcast{{.VType}} returns a vector with the input
// x assigned to all elements of the output.
//
// Emulated, CPU Feature {{.CPUfeatureBC}}
func Broadcast{{.VType}}(x {{.Etype}}) {{.VType}} {
	var z {{.As128BitVec }}
	return z.SetElem(0, x).Broadcast{{.Vwidth}}()
}
`)
   798  
// maskCvtTemplate generates, for every shape in allShapes, a ToMask method
// that compares the vector against the zero vector with NotEqual.
var maskCvtTemplate = templateOf("Mask conversions", `
// ToMask converts from {{.Base}}{{.WxC}} to Mask{{.WxC}}, mask element is set to true when the corresponding vector element is non-zero.
func (from {{.Base}}{{.WxC}}) ToMask() (to Mask{{.WxC}}) {
	return from.NotEqual({{.Base}}{{.WxC}}{})
}
`)
   805  
// stringTemplate generates, for every shape in allShapes, a String method
// that stores the vector into an array and formats it with sliceToString
// (defined outside this template).
var stringTemplate = shapedTemplateOf(allShapes, "String methods", `
// String returns a string representation of SIMD vector x
func (x {{.VType}}) String() string {
	var s [{{.Count}}]{{.Etype}}
	x.Store(&s)
	return sliceToString(s[:])
}
`)
   814  
// Directory prefixes (relative paths) used to build the default output
// file names of the command-line flags in main.

// SIMD is the prefix for the generated simd package sources.
const SIMD = "../../"

// TD is the prefix for the generated simd test helpers.
const TD = "../../internal/simd_test/"

// SSA is the prefix for the generated file in the compiler's ssa package.
const SSA = "../../../cmd/compile/internal/ssa/"
   818  
   819  func main() {
   820  	sl := flag.String("sl", SIMD+"slice_gen_amd64.go", "file name for slice operations")
   821  	cm := flag.String("cm", SIMD+"compare_gen_amd64.go", "file name for comparison operations")
   822  	mm := flag.String("mm", SIMD+"maskmerge_gen_amd64.go", "file name for mask/merge operations")
   823  	op := flag.String("op", SIMD+"other_gen_amd64.go", "file name for other operations")
   824  	ush := flag.String("ush", SIMD+"unsafe_helpers.go", "file name for unsafe helpers")
   825  	bh := flag.String("bh", TD+"binary_helpers_test.go", "file name for binary test helpers")
   826  	uh := flag.String("uh", TD+"unary_helpers_test.go", "file name for unary test helpers")
   827  	th := flag.String("th", TD+"ternary_helpers_test.go", "file name for ternary test helpers")
   828  	ch := flag.String("ch", TD+"compare_helpers_test.go", "file name for compare test helpers")
   829  	cmh := flag.String("cmh", TD+"comparemasked_helpers_test.go", "file name for compare-masked test helpers")
   830  	flag.Parse()
   831  
   832  	if *sl != "" {
   833  		one(*sl, unsafePrologue,
   834  			sliceTemplate,
   835  			avx512MaskedLoadSlicePartTemplate,
   836  			avx2MaskedLoadSlicePartTemplate,
   837  			avx2SmallLoadSlicePartTemplate,
   838  		)
   839  	}
   840  	if *cm != "" {
   841  		one(*cm, prologue,
   842  			avx2SignedComparisonsTemplate,
   843  			avx2UnsignedComparisonsTemplate,
   844  		)
   845  	}
   846  	if *mm != "" {
   847  		one(*mm, prologue,
   848  			avx2MaskedTemplate,
   849  			avx512MaskedTemplate,
   850  		)
   851  	}
   852  	if *op != "" {
   853  		one(*op, prologue,
   854  			broadcastTemplate,
   855  			maskCvtTemplate,
   856  			bitWiseIntTemplate,
   857  			bitWiseUintTemplate,
   858  			stringTemplate,
   859  		)
   860  	}
   861  	if *ush != "" {
   862  		one(*ush, unsafePrologue, unsafePATemplate)
   863  	}
   864  	if *uh != "" {
   865  		one(*uh, curryTestPrologue("unary simd methods"), unaryTemplate, unaryToInt32, unaryToUint32, unaryToUint16, unaryFlakyTemplate)
   866  	}
   867  	if *bh != "" {
   868  		one(*bh, curryTestPrologue("binary simd methods"), binaryTemplate)
   869  	}
   870  	if *th != "" {
   871  		one(*th, curryTestPrologue("ternary simd methods"), ternaryTemplate, ternaryFlakyTemplate)
   872  	}
   873  	if *ch != "" {
   874  		one(*ch, curryTestPrologue("simd methods that compare two operands"), compareTemplate)
   875  	}
   876  	if *cmh != "" {
   877  		one(*cmh, curryTestPrologue("simd methods that compare two operands under a mask"), compareMaskedTemplate)
   878  	}
   879  
   880  	nonTemplateRewrites(SSA+"tern_helpers.go", ssaPrologue, classifyBooleanSIMD, ternOpForLogical)
   881  
   882  }
   883  
   884  func ternOpForLogical(out io.Writer) {
   885  	fmt.Fprintf(out, `
   886  func ternOpForLogical(op Op) Op {
   887  	switch op {
   888  `)
   889  
   890  	intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) {
   891  		wt, ct := w, c
   892  		if wt < 32 {
   893  			wt = 32
   894  			ct = (w * c) / wt
   895  		}
   896  		fmt.Fprintf(out, "case OpAndInt%[1]dx%[2]d, OpOrInt%[1]dx%[2]d, OpXorInt%[1]dx%[2]d,OpAndNotInt%[1]dx%[2]d: return OpternInt%dx%d\n", w, c, wt, ct)
   897  		fmt.Fprintf(out, "case OpAndUint%[1]dx%[2]d, OpOrUint%[1]dx%[2]d, OpXorUint%[1]dx%[2]d,OpAndNotUint%[1]dx%[2]d: return OpternUint%dx%d\n", w, c, wt, ct)
   898  	}, out)
   899  
   900  	fmt.Fprintf(out, `
   901  	}
   902  	return op
   903  }
   904  `)
   905  
   906  }
   907  
   908  func classifyBooleanSIMD(out io.Writer) {
   909  	fmt.Fprintf(out, `
   910  type SIMDLogicalOP uint8
   911  const (
   912  	// boolean simd operations, for reducing expression to VPTERNLOG* instructions
   913  	// sloInterior is set for non-root nodes in logical-op expression trees.
   914  	// the operations are even-numbered.
   915  	sloInterior SIMDLogicalOP = 1
   916  	sloNone SIMDLogicalOP = 2 * iota
   917  	sloAnd
   918  	sloOr
   919  	sloAndNot
   920  	sloXor
   921  	sloNot
   922  )
   923  func classifyBooleanSIMD(v *Value) SIMDLogicalOP {
   924  	switch v.Op {
   925  		case `)
   926  	intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) {
   927  		op := "And"
   928  		if seq > 0 {
   929  			fmt.Fprintf(out, ",Op%s%s%dx%d", op, upperT, w, c)
   930  		} else {
   931  			fmt.Fprintf(out, "Op%s%s%dx%d", op, upperT, w, c)
   932  		}
   933  		seq++
   934  	}, out)
   935  
   936  	fmt.Fprintf(out, `:
   937  		return sloAnd
   938  
   939  		case `)
   940  	intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) {
   941  		op := "Or"
   942  		if seq > 0 {
   943  			fmt.Fprintf(out, ",Op%s%s%dx%d", op, upperT, w, c)
   944  		} else {
   945  			fmt.Fprintf(out, "Op%s%s%dx%d", op, upperT, w, c)
   946  		}
   947  		seq++
   948  	}, out)
   949  
   950  	fmt.Fprintf(out, `:
   951  		return sloOr
   952  
   953  		case `)
   954  	intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) {
   955  		op := "AndNot"
   956  		if seq > 0 {
   957  			fmt.Fprintf(out, ",Op%s%s%dx%d", op, upperT, w, c)
   958  		} else {
   959  			fmt.Fprintf(out, "Op%s%s%dx%d", op, upperT, w, c)
   960  		}
   961  		seq++
   962  	}, out)
   963  
   964  	fmt.Fprintf(out, `:
   965  		return sloAndNot
   966  `)
   967  
   968  	// "Not" is encoded as x.Xor(x.Equal(x).AsInt8x16())
   969  	// i.e. xor.Args[0] == x, xor.Args[1].Op == As...
   970  	// but AsInt8x16 is a pun/passthrough.
   971  
   972  	intShapes.forAllShapes(
   973  		func(seq int, t, upperT string, w, c int, out io.Writer) {
   974  			fmt.Fprintf(out, "case OpXor%s%dx%d: ", upperT, w, c)
   975  			fmt.Fprintf(out, `
   976  				if y := v.Args[1]; y.Op == OpEqual%s%dx%d &&
   977  				   y.Args[0] == y.Args[1] {
   978  				   		return sloNot
   979  				}
   980  				`, upperT, w, c)
   981  			fmt.Fprintf(out, "return sloXor\n")
   982  		}, out)
   983  
   984  	fmt.Fprintf(out, `
   985  	}
   986  	return sloNone
   987  }
   988  `)
   989  }
   990  
   991  // numberLines takes a slice of bytes, and returns a string where each line
   992  // is numbered, starting from 1.
   993  func numberLines(data []byte) string {
   994  	var buf bytes.Buffer
   995  	r := bytes.NewReader(data)
   996  	s := bufio.NewScanner(r)
   997  	for i := 1; s.Scan(); i++ {
   998  		fmt.Fprintf(&buf, "%d: %s\n", i, s.Text())
   999  	}
  1000  	return buf.String()
  1001  }
  1002  
  1003  func nonTemplateRewrites(filename string, prologue func(s string, out io.Writer), rewrites ...func(out io.Writer)) {
  1004  	if filename == "" {
  1005  		return
  1006  	}
  1007  
  1008  	ofile := os.Stdout
  1009  
  1010  	if filename != "-" {
  1011  		var err error
  1012  		ofile, err = os.Create(filename)
  1013  		if err != nil {
  1014  			fmt.Fprintf(os.Stderr, "Could not create the output file %s for the generated code, %v", filename, err)
  1015  			os.Exit(1)
  1016  		}
  1017  	}
  1018  
  1019  	out := new(bytes.Buffer)
  1020  
  1021  	prologue("go run genfiles.go", out)
  1022  	for _, rewrite := range rewrites {
  1023  		rewrite(out)
  1024  	}
  1025  
  1026  	b, err := format.Source(out.Bytes())
  1027  	if err != nil {
  1028  		fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
  1029  		fmt.Fprintf(os.Stderr, "%s\n", numberLines(out.Bytes()))
  1030  		fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
  1031  		os.Exit(1)
  1032  	} else {
  1033  		ofile.Write(b)
  1034  		ofile.Close()
  1035  	}
  1036  
  1037  }
  1038  
  1039  func one(filename string, prologue func(s string, out io.Writer), sats ...shapeAndTemplate) {
  1040  	if filename == "" {
  1041  		return
  1042  	}
  1043  
  1044  	ofile := os.Stdout
  1045  
  1046  	if filename != "-" {
  1047  		var err error
  1048  		ofile, err = os.Create(filename)
  1049  		if err != nil {
  1050  			fmt.Fprintf(os.Stderr, "Could not create the output file %s for the generated code, %v", filename, err)
  1051  			os.Exit(1)
  1052  		}
  1053  	}
  1054  
  1055  	out := new(bytes.Buffer)
  1056  
  1057  	prologue("go run genfiles.go", out)
  1058  	for _, sat := range sats {
  1059  		sat.forTemplates(out)
  1060  	}
  1061  
  1062  	b, err := format.Source(out.Bytes())
  1063  	if err != nil {
  1064  		fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
  1065  		fmt.Fprintf(os.Stderr, "%s\n", numberLines(out.Bytes()))
  1066  		fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
  1067  		os.Exit(1)
  1068  	} else {
  1069  		ofile.Write(b)
  1070  		ofile.Close()
  1071  	}
  1072  
  1073  }
  1074  

View as plain text