1
2
3
4
5 package main
6
7
8
9
10 import (
11 "bufio"
12 "bytes"
13 "flag"
14 "fmt"
15 "go/format"
16 "io"
17 "os"
18 "strings"
19 "text/template"
20 )
21
22 type resultTypeFunc func(t string, w, c int) (ot string, ow int, oc int)
23
24
25 type shapes struct {
26 vecs []int
27 ints []int
28 uints []int
29 floats []int
30 output resultTypeFunc
31 }
32
33
34 type shapeAndTemplate struct {
35 s *shapes
36 t *template.Template
37 }
38
39 func (sat shapeAndTemplate) target(outType string, width int) shapeAndTemplate {
40 newSat := sat
41 newShape := *sat.s
42 newShape.output = func(t string, w, c int) (ot string, ow int, oc int) {
43 oc = c
44 if width*c > 512 {
45 oc = 512 / width
46 } else if width*c < 128 {
47 oc = 128 / width
48 }
49 return outType, width, oc
50 }
51 newSat.s = &newShape
52 return newSat
53 }
54
55 func (sat shapeAndTemplate) targetFixed(outType string, width, count int) shapeAndTemplate {
56 newSat := sat
57 newShape := *sat.s
58 newShape.output = func(t string, w, c int) (ot string, ow int, oc int) {
59 return outType, width, count
60 }
61 newSat.s = &newShape
62 return newSat
63 }
64
// forAllShapes calls f once for every shape in s and passes out through
// unchanged. For each vector width, in order, it visits the int, then uint,
// then float element widths. The element count is vector width / element
// width. seq numbers the calls 0, 1, 2, ... in visit order. t is the
// lower-case kind ("int") and upperT the capitalized one ("Int").
func (s *shapes) forAllShapes(f func(seq int, t, upperT string, w, c int, out io.Writer), out io.Writer) {
	kinds := []struct {
		lower, upper string
		widths       []int
	}{
		{"int", "Int", s.ints},
		{"uint", "Uint", s.uints},
		{"float", "Float", s.floats},
	}
	seq := 0
	for _, vec := range s.vecs {
		for _, k := range kinds {
			for _, w := range k.widths {
				f(seq, k.lower, k.upper, w, vec/w, out)
				seq++
			}
		}
	}
}
89
90 var allShapes = &shapes{
91 vecs: []int{128, 256, 512},
92 ints: []int{8, 16, 32, 64},
93 uints: []int{8, 16, 32, 64},
94 floats: []int{32, 64},
95 }
96
97 var intShapes = &shapes{
98 vecs: []int{128, 256, 512},
99 ints: []int{8, 16, 32, 64},
100 }
101
102 var uintShapes = &shapes{
103 vecs: []int{128, 256, 512},
104 uints: []int{8, 16, 32, 64},
105 }
106
107 var floatShapes = &shapes{
108 vecs: []int{128, 256, 512},
109 floats: []int{32, 64},
110 }
111
112 var integerShapes = &shapes{
113 vecs: []int{128, 256, 512},
114 ints: []int{8, 16, 32, 64},
115 uints: []int{8, 16, 32, 64},
116 }
117
118 var avx512Shapes = &shapes{
119 vecs: []int{512},
120 ints: []int{8, 16, 32, 64},
121 uints: []int{8, 16, 32, 64},
122 floats: []int{32, 64},
123 }
124
125 var avx2Shapes = &shapes{
126 vecs: []int{128, 256},
127 ints: []int{8, 16, 32, 64},
128 uints: []int{8, 16, 32, 64},
129 floats: []int{32, 64},
130 }
131
132 var avx2MaskedLoadShapes = &shapes{
133 vecs: []int{128, 256},
134 ints: []int{32, 64},
135 uints: []int{32, 64},
136 floats: []int{32, 64},
137 }
138
139 var avx2SmallLoadPunShapes = &shapes{
140
141 vecs: []int{128, 256},
142 uints: []int{8, 16},
143 }
144
145 var unaryFlaky = &shapes{
146 vecs: []int{128, 256, 512},
147 floats: []int{32, 64},
148 }
149
150 var ternaryFlaky = &shapes{
151 vecs: []int{128, 256, 512},
152 floats: []int{32},
153 }
154
155 var avx2SignedComparisons = &shapes{
156 vecs: []int{128, 256},
157 ints: []int{8, 16, 32, 64},
158 }
159
160 var avx2UnsignedComparisons = &shapes{
161 vecs: []int{128, 256},
162 uints: []int{8, 16, 32, 64},
163 }
164
// templateData is the value every template is executed with. oneTemplate
// fills it in for one vector shape. The O-prefixed fields describe the output
// shape chosen by the shape set's resultTypeFunc; they equal the input shape
// when there is none.
type templateData struct {
	VType  string // vector type name, e.g. "Int32x4"
	AOrAn  string // article "a" or "an", chosen from the first letter of the lower-case base type
	EWidth int    // element width in bits
	Vwidth int    // total vector width in bits (EWidth * Count)
	Count  int    // number of elements
	WxC    string // "<EWidth>x<Count>", e.g. "32x4"
	BxC    string // the same vector viewed as bytes, e.g. "8x16"
	Base   string // capitalized base type: "Int", "Uint" or "Float"
	Etype  string // Go element type, e.g. "int32"
	OxFF   string // hex literal with the low Count bits set, e.g. "0xf"

	OVType string // output vector type name, e.g. "Uint64x2"
	OEtype string // output Go element type, e.g. "uint64"
	OEType string // capitalized output element type, e.g. "Uint64"
	OCount int    // output element count
}
182
// As128BitVec names the 128-bit vector type with the same base and element
// width as t, e.g. "Int32x4" for any Int32 shape.
func (t templateData) As128BitVec() string {
	lanes := 128 / t.EWidth
	return t.Base + fmt.Sprint(t.EWidth) + "x" + fmt.Sprint(lanes)
}
186
// oneTemplate instantiates t for one vector shape and writes the result to
// out. The shape is count elements of baseType ("int", "uint" or "float"),
// each width bits wide.
//
// Shapes whose total size is outside 128..512 bits are skipped. If rtf is
// non-nil, it maps the input (type, width, count) to the output shape that
// the template sees as OVType/OEtype/OEType/OCount. The shape is also skipped
// if that output is not 128..512 bits wide, has an element width outside
// 8..64, or is a float narrower than 32 bits. With a nil rtf, the output
// shape equals the input shape.
//
// oneTemplate panics if executing the template fails. text/template turns
// panics from called methods (such as CPUfeature) into errors, so dropping
// the error would silently produce truncated generated code.
func oneTemplate(t *template.Template, baseType string, width, count int, out io.Writer, rtf resultTypeFunc) {
	vecBits := width * count
	if vecBits < 128 || vecBits > 512 {
		return
	}

	ot, ow, oc := baseType, width, count
	if rtf != nil {
		ot, ow, oc = rtf(ot, ow, oc)
		if ow*oc > 512 || ow*oc < 128 || ow < 8 || ow > 64 {
			return
		}
		// No float element types narrower than 32 bits are generated.
		if ot == "float" && ow < 32 {
			return
		}
	}

	// capitalize turns "int" into "Int", "uint" into "Uint", and so on.
	capitalize := func(s string) string { return strings.ToUpper(s[:1]) + s[1:] }
	upperOT := capitalize(ot)
	baseUpper := capitalize(baseType)
	wxc := fmt.Sprintf("%dx%d", width, count)

	aOrAn := "a"
	if strings.Contains("aeiou", baseType[:1]) {
		aOrAn = "an"
	}

	data := templateData{
		VType:  baseUpper + wxc,
		AOrAn:  aOrAn,
		EWidth: width,
		Vwidth: vecBits,
		Count:  count,
		WxC:    wxc,
		// The same vector viewed as a vector of bytes.
		BxC:   fmt.Sprintf("%dx%d", 8, count*(width/8)),
		Base:  baseUpper,
		Etype: fmt.Sprintf("%s%d", baseType, width),
		// Low count bits set. The shift is done in uint64, so count == 64
		// wraps to all ones instead of overflowing.
		OxFF:   fmt.Sprintf("0x%x", uint64((1<<count)-1)),
		OVType: fmt.Sprintf("%s%dx%d", upperOT, ow, oc),
		OEtype: fmt.Sprintf("%s%d", ot, ow),
		OEType: fmt.Sprintf("%s%d", upperOT, ow),
		OCount: oc,
	}
	if err := t.Execute(out, data); err != nil {
		panic(fmt.Errorf("executing template %q for %s: %w", t.Name(), data.VType, err))
	}
}
236
237
238
// forTemplates instantiates sat's template for every shape in sat's shape
// set, writing each instantiation to out. Shapes are visited in the order
// defined by forAllShapes. Each one is filtered and rewritten by
// oneTemplate using the shape set's output function.
func (sat shapeAndTemplate) forTemplates(out io.Writer) {
	sat.s.forAllShapes(func(_ int, baseType, _ string, w, c int, dst io.Writer) {
		oneTemplate(sat.t, baseType, w, c, dst, sat.s.output)
	}, out)
}
260
261 func prologue(s string, out io.Writer) {
262 fmt.Fprintf(out,
263 `// Code generated by '%s'; DO NOT EDIT.
264
265 //go:build goexperiment.simd
266
267 package archsimd
268
269 `, s)
270 }
271
272 func ssaPrologue(s string, out io.Writer) {
273 fmt.Fprintf(out,
274 `// Code generated by '%s'; DO NOT EDIT.
275
276 package ssa
277
278 `, s)
279 }
280
281 func unsafePrologue(s string, out io.Writer) {
282 fmt.Fprintf(out,
283 `// Code generated by '%s'; DO NOT EDIT.
284
285 //go:build goexperiment.simd
286
287 package archsimd
288
289 import "unsafe"
290
291 `, s)
292 }
293
294 func testPrologue(t, s string, out io.Writer) {
295 fmt.Fprintf(out,
296 `// Code generated by '%s'; DO NOT EDIT.
297
298 //go:build goexperiment.simd && amd64
299
300 // This file contains functions testing %s.
301 // Each function in this file is specialized for a
302 // particular simd type <BaseType><Width>x<Count>.
303
304 package simd_test
305
306 import (
307 "simd/archsimd"
308 "testing"
309 )
310
311 `, s, t)
312 }
313
// curryTestPrologue binds the description t into testPrologue. The result
// has the same (generator name, writer) signature as prologue, ssaPrologue
// and unsafePrologue.
func curryTestPrologue(t string) func(s string, out io.Writer) {
	bound := func(s string, out io.Writer) {
		testPrologue(t, s, out)
	}
	return bound
}
319
320 func templateOf(name, temp string) shapeAndTemplate {
321 return shapeAndTemplate{s: allShapes,
322 t: template.Must(template.New(name).Parse(temp))}
323 }
324
325 func shapedTemplateOf(s *shapes, name, temp string) shapeAndTemplate {
326 return shapeAndTemplate{s: s,
327 t: template.Must(template.New(name).Parse(temp))}
328 }
329
// sliceTemplate generates, for every shape in allShapes, a
// Load<VType>Slice function and a StoreSlice method. Both convert the slice
// to a pointer to a Count-element array and delegate to Load<VType> and
// Store respectively.
var sliceTemplate = templateOf("slice", `
// Load{{.VType}}Slice loads {{.AOrAn}} {{.VType}} from a slice of at least {{.Count}} {{.Etype}}s.
func Load{{.VType}}Slice(s []{{.Etype}}) {{.VType}} {
return Load{{.VType}}((*[{{.Count}}]{{.Etype}})(s))
}

// StoreSlice stores x into a slice of at least {{.Count}} {{.Etype}}s.
func (x {{.VType}}) StoreSlice(s []{{.Etype}}) {
x.Store((*[{{.Count}}]{{.Etype}})(s))
}
`)
341
// unaryTemplate generates a test helper test<VType>Unary for every shape in
// allShapes. The helper compares a unary SIMD method f against a scalar
// reference want over inputs produced by forSlice, with zero tolerance.
// forSlice and checkSlicesLogInput are defined in the test package.
var unaryTemplate = templateOf("unary_helpers", `
// test{{.VType}}Unary tests the simd unary method f against the expected behavior generated by want
func test{{.VType}}Unary(t *testing.T, f func(_ archsimd.{{.VType}}) archsimd.{{.VType}}, want func(_ []{{.Etype}}) []{{.Etype}}) {
n := {{.Count}}
t.Helper()
forSlice(t, {{.Etype}}s, n, func(x []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
g := make([]{{.Etype}}, n)
f(a).StoreSlice(g)
w := want(x)
return checkSlicesLogInput(t, g, w, 0.0, func() {t.Helper(); t.Logf("x=%v", x)})
})
}
`)
357
// unaryFlakyTemplate is like unaryTemplate, but the generated helper
// test<VType>UnaryFlaky forwards a caller-supplied flakiness tolerance to
// checkSlicesLogInput. It covers only the float shapes in unaryFlaky.
var unaryFlakyTemplate = shapedTemplateOf(unaryFlaky, "unary_flaky_helpers", `
// test{{.VType}}UnaryFlaky tests the simd unary method f against the expected behavior generated by want,
// but using a flakiness parameter because we haven't exactly figured out how simd floating point works
func test{{.VType}}UnaryFlaky(t *testing.T, f func(x archsimd.{{.VType}}) archsimd.{{.VType}}, want func(x []{{.Etype}}) []{{.Etype}}, flakiness float64) {
n := {{.Count}}
t.Helper()
forSlice(t, {{.Etype}}s, n, func(x []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
g := make([]{{.Etype}}, n)
f(a).StoreSlice(g)
w := want(x)
return checkSlicesLogInput(t, g, w, flakiness, func() {t.Helper(); t.Logf("x=%v", x)})
})
}
`)
374
// convertTemplate generates test<VType>ConvertTo<OEType> helpers for
// conversions whose result is described by the O* template fields. main uses
// it only through the unaryTo* variants created with target, which fix the
// output element type.
var convertTemplate = templateOf("convert_helpers", `
// test{{.VType}}ConvertTo{{.OEType}} tests the simd conversion method f against the expected behavior generated by want.
// This is for count-preserving conversions, so if there is a change in size, then there is a change in vector width,
// (extended to at least 128 bits, or truncated to at most 512 bits).
func test{{.VType}}ConvertTo{{.OEType}}(t *testing.T, f func(x archsimd.{{.VType}}) archsimd.{{.OVType}}, want func(x []{{.Etype}}) []{{.OEtype}}) {
n := {{.Count}}
t.Helper()
forSlice(t, {{.Etype}}s, n, func(x []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
g := make([]{{.OEtype}}, {{.OCount}})
f(a).StoreSlice(g)
w := want(x)
return checkSlicesLogInput(t, g, w, 0.0, func() {t.Helper(); t.Logf("x=%v", x)})
})
}
`)
392
393 var (
394
395
396
397 unaryToInt8 = convertTemplate.target("int", 8)
398 unaryToUint8 = convertTemplate.target("uint", 8)
399 unaryToInt16 = convertTemplate.target("int", 16)
400 unaryToUint16 = convertTemplate.target("uint", 16)
401 unaryToInt32 = convertTemplate.target("int", 32)
402 unaryToUint32 = convertTemplate.target("uint", 32)
403 unaryToInt64 = convertTemplate.target("int", 64)
404 unaryToUint64 = convertTemplate.target("uint", 64)
405 unaryToFloat32 = convertTemplate.target("float", 32)
406 unaryToFloat64 = convertTemplate.target("float", 64)
407 )
408
// convertLoTemplate generates test<VType>ConvertLoTo<OVType> helpers for
// integer shapes. They test conversions that use only the low OCount input
// elements. main uses it only through the unaryTo*xN variants created with
// targetFixed, which pin the output vector type.
var convertLoTemplate = shapedTemplateOf(integerShapes, "convert_lo_helpers", `
// test{{.VType}}ConvertLoTo{{.OVType}} tests the simd conversion method f against the expected behavior generated by want.
// This converts only the low {{.OCount}} elements.
func test{{.VType}}ConvertLoTo{{.OVType}}(t *testing.T, f func(x archsimd.{{.VType}}) archsimd.{{.OVType}}, want func(x []{{.Etype}}) []{{.OEtype}}) {
n := {{.Count}}
t.Helper()
forSlice(t, {{.Etype}}s, n, func(x []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
g := make([]{{.OEtype}}, {{.OCount}})
f(a).StoreSlice(g)
w := want(x)
return checkSlicesLogInput(t, g, w, 0.0, func() {t.Helper(); t.Logf("x=%v", x)})
})
}
`)
425
426 var (
427
428
429
430
431
432 unaryToInt64x2 = convertLoTemplate.targetFixed("int", 64, 2)
433 unaryToInt64x4 = convertLoTemplate.targetFixed("int", 64, 4)
434 unaryToUint64x2 = convertLoTemplate.targetFixed("uint", 64, 2)
435 unaryToUint64x4 = convertLoTemplate.targetFixed("uint", 64, 4)
436 unaryToInt32x4 = convertLoTemplate.targetFixed("int", 32, 4)
437 unaryToInt32x8 = convertLoTemplate.targetFixed("int", 32, 8)
438 unaryToUint32x4 = convertLoTemplate.targetFixed("uint", 32, 4)
439 unaryToUint32x8 = convertLoTemplate.targetFixed("uint", 32, 8)
440 unaryToInt16x8 = convertLoTemplate.targetFixed("int", 16, 8)
441 unaryToUint16x8 = convertLoTemplate.targetFixed("uint", 16, 8)
442 )
443
// binaryTemplate generates a test helper test<VType>Binary for every shape
// in allShapes. The helper compares a two-operand SIMD method f against
// want over input pairs produced by forSlicePair, with zero tolerance.
var binaryTemplate = templateOf("binary_helpers", `
// test{{.VType}}Binary tests the simd binary method f against the expected behavior generated by want
func test{{.VType}}Binary(t *testing.T, f func(_, _ archsimd.{{.VType}}) archsimd.{{.VType}}, want func(_, _ []{{.Etype}}) []{{.Etype}}) {
n := {{.Count}}
t.Helper()
forSlicePair(t, {{.Etype}}s, n, func(x, y []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
b := archsimd.Load{{.VType}}Slice(y)
g := make([]{{.Etype}}, n)
f(a, b).StoreSlice(g)
w := want(x, y)
return checkSlicesLogInput(t, g, w, 0.0, func() {t.Helper(); t.Logf("x=%v", x); t.Logf("y=%v", y); })
})
}
`)
460
// ternaryTemplate generates a test helper test<VType>Ternary for every shape
// in allShapes. The helper compares a three-operand SIMD method f against
// want over input triples produced by forSliceTriple, with zero tolerance.
var ternaryTemplate = templateOf("ternary_helpers", `
// test{{.VType}}Ternary tests the simd ternary method f against the expected behavior generated by want
func test{{.VType}}Ternary(t *testing.T, f func(_, _, _ archsimd.{{.VType}}) archsimd.{{.VType}}, want func(_, _, _ []{{.Etype}}) []{{.Etype}}) {
n := {{.Count}}
t.Helper()
forSliceTriple(t, {{.Etype}}s, n, func(x, y, z []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
b := archsimd.Load{{.VType}}Slice(y)
c := archsimd.Load{{.VType}}Slice(z)
g := make([]{{.Etype}}, n)
f(a, b, c).StoreSlice(g)
w := want(x, y, z)
return checkSlicesLogInput(t, g, w, 0.0, func() {t.Helper(); t.Logf("x=%v", x); t.Logf("y=%v", y); t.Logf("z=%v", z); })
})
}
`)
478
479 var ternaryFlakyTemplate = shapedTemplateOf(ternaryFlaky, "ternary_helpers", `
480 // test{{.VType}}TernaryFlaky tests the simd ternary method f against the expected behavior generated by want,
481 // but using a flakiness parameter because we haven't exactly figured out how simd floating point works
482 func test{{.VType}}TernaryFlaky(t *testing.T, f func(x, y, z archsimd.{{.VType}}) archsimd.{{.VType}}, want func(x, y, z []{{.Etype}}) []{{.Etype}}, flakiness float64) {
483 n := {{.Count}}
484 t.Helper()
485 forSliceTriple(t, {{.Etype}}s, n, func(x, y, z []{{.Etype}}) bool {
486 t.Helper()
487 a := archsimd.Load{{.VType}}Slice(x)
488 b := archsimd.Load{{.VType}}Slice(y)
489 c := archsimd.Load{{.VType}}Slice(z)
490 g := make([]{{.Etype}}, n)
491 f(a, b, c).StoreSlice(g)
492 w := want(x, y, z)
493 return checkSlicesLogInput(t, g, w, flakiness, func() {t.Helper(); t.Logf("x=%v", x); t.Logf("y=%v", y); t.Logf("z=%v", z); })
494 })
495 }
496 `)
497
// compareTemplate generates a test helper test<VType>Compare for every shape
// in allShapes. The helper converts the Mask<WxC> returned by f to
// Int<WxC>, widens it with s64 (defined in the test package), and compares
// it with want's []int64.
var compareTemplate = templateOf("compare_helpers", `
// test{{.VType}}Compare tests the simd comparison method f against the expected behavior generated by want
func test{{.VType}}Compare(t *testing.T, f func(_, _ archsimd.{{.VType}}) archsimd.Mask{{.WxC}}, want func(_, _ []{{.Etype}}) []int64) {
n := {{.Count}}
t.Helper()
forSlicePair(t, {{.Etype}}s, n, func(x, y []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
b := archsimd.Load{{.VType}}Slice(y)
g := make([]int{{.EWidth}}, n)
f(a, b).ToInt{{.WxC}}().StoreSlice(g)
w := want(x, y)
return checkSlicesLogInput(t, s64(g), w, 0.0, func() {t.Helper(); t.Logf("x=%v", x); t.Logf("y=%v", y); })
})
}
`)
514
// compareUnaryTemplate generates a test helper test<VType>UnaryCompare for
// the float shapes. It checks a one-operand method f that returns a mask
// against want, in the same way as compareTemplate.
var compareUnaryTemplate = shapedTemplateOf(floatShapes, "compare_unary_helpers", `
// test{{.VType}}UnaryCompare tests the simd unary comparison method f against the expected behavior generated by want
func test{{.VType}}UnaryCompare(t *testing.T, f func(x archsimd.{{.VType}}) archsimd.Mask{{.WxC}}, want func(x []{{.Etype}}) []int64) {
n := {{.Count}}
t.Helper()
forSlice(t, {{.Etype}}s, n, func(x []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
g := make([]int{{.EWidth}}, n)
f(a).ToInt{{.WxC}}().StoreSlice(g)
w := want(x)
return checkSlicesLogInput(t, s64(g), w, 0.0, func() {t.Helper(); t.Logf("x=%v", x)})
})
}
`)
530
531
// compareMaskedTemplate generates a test helper test<VType>CompareMasked
// for every shape in allShapes. The helper builds a Mask<WxC> from the
// []bool m supplied by forSlicePairMasked and passes it to f. It then zeroes
// the lanes of want's result where m is false before comparing.
var compareMaskedTemplate = templateOf("comparemasked_helpers", `
// test{{.VType}}CompareMasked tests the simd masked comparison method f against the expected behavior generated by want
// The mask is applied to the output of want; anything not in the mask, is zeroed.
func test{{.VType}}CompareMasked(t *testing.T,
f func(_, _ archsimd.{{.VType}}, m archsimd.Mask{{.WxC}}) archsimd.Mask{{.WxC}},
want func(_, _ []{{.Etype}}) []int64) {
n := {{.Count}}
t.Helper()
forSlicePairMasked(t, {{.Etype}}s, n, func(x, y []{{.Etype}}, m []bool) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
b := archsimd.Load{{.VType}}Slice(y)
k := archsimd.LoadInt{{.WxC}}Slice(toVect[int{{.EWidth}}](m)).ToMask()
g := make([]int{{.EWidth}}, n)
f(a, b, k).ToInt{{.WxC}}().StoreSlice(g)
w := want(x, y)
for i := range m {
if !m[i] {
w[i] = 0
}
}
return checkSlicesLogInput(t, s64(g), w, 0.0, func() {t.Helper(); t.Logf("x=%v", x); t.Logf("y=%v", y); t.Logf("m=%v", m); })
})
}
`)
557
// avx512MaskedLoadSlicePartTemplate generates Load<VType>SlicePart and
// StoreSlicePart for the 512-bit shapes. A short slice of length l is handled
// by a masked load/store whose mask is built by Mask<WxC>FromBits from OxFF
// (Count one bits) shifted right by Count-l, leaving the low l bits set.
var avx512MaskedLoadSlicePartTemplate = shapedTemplateOf(avx512Shapes, "avx 512 load slice part", `
// Load{{.VType}}SlicePart loads a {{.VType}} from the slice s.
// If s has fewer than {{.Count}} elements, the remaining elements of the vector are filled with zeroes.
// If s has {{.Count}} or more elements, the function is equivalent to Load{{.VType}}Slice.
func Load{{.VType}}SlicePart(s []{{.Etype}}) {{.VType}} {
l := len(s)
if l >= {{.Count}} {
return Load{{.VType}}Slice(s)
}
if l == 0 {
var x {{.VType}}
return x
}
mask := Mask{{.WxC}}FromBits({{.OxFF}} >> ({{.Count}} - l))
return LoadMasked{{.VType}}(pa{{.VType}}(s), mask)
}

// StoreSlicePart stores the {{.Count}} elements of x into the slice s.
// It stores as many elements as will fit in s.
// If s has {{.Count}} or more elements, the method is equivalent to x.StoreSlice.
func (x {{.VType}}) StoreSlicePart(s []{{.Etype}}) {
l := len(s)
if l >= {{.Count}} {
x.StoreSlice(s)
return
}
if l == 0 {
return
}
mask := Mask{{.WxC}}FromBits({{.OxFF}} >> ({{.Count}} - l))
x.StoreMasked(pa{{.VType}}(s), mask)
}
`)
591
// avx2MaskedLoadSlicePartTemplate generates Load<VType>SlicePart and
// StoreSlicePart for the 128/256-bit shapes with 32- and 64-bit elements.
// For a short slice of length l, the mask vector is loaded from the slice
// vecMask<EWidth>[len/2-l:], which is defined elsewhere. It presumably
// holds all-ones elements followed by zeros, so that this window has l
// leading ones; confirm against the vecMask definitions.
var avx2MaskedLoadSlicePartTemplate = shapedTemplateOf(avx2MaskedLoadShapes, "avx 2 load slice part", `
// Load{{.VType}}SlicePart loads a {{.VType}} from the slice s.
// If s has fewer than {{.Count}} elements, the remaining elements of the vector are filled with zeroes.
// If s has {{.Count}} or more elements, the function is equivalent to Load{{.VType}}Slice.
func Load{{.VType}}SlicePart(s []{{.Etype}}) {{.VType}} {
l := len(s)
if l >= {{.Count}} {
return Load{{.VType}}Slice(s)
}
if l == 0 {
var x {{.VType}}
return x
}
mask := vecMask{{.EWidth}}[len(vecMask{{.EWidth}})/2-l:]
return LoadMasked{{.VType}}(pa{{.VType}}(s), LoadInt{{.WxC}}Slice(mask).asMask())
}

// StoreSlicePart stores the {{.Count}} elements of x into the slice s.
// It stores as many elements as will fit in s.
// If s has {{.Count}} or more elements, the method is equivalent to x.StoreSlice.
func (x {{.VType}}) StoreSlicePart(s []{{.Etype}}) {
l := len(s)
if l >= {{.Count}} {
x.StoreSlice(s)
return
}
if l == 0 {
return
}
mask := vecMask{{.EWidth}}[len(vecMask{{.EWidth}})/2-l:]
x.StoreMasked(pa{{.VType}}(s), LoadInt{{.WxC}}Slice(mask).asMask())
}
`)
625
// avx2SmallLoadSlicePartTemplate generates Load<VType>SlicePart and
// StoreSlicePart for the 8- and 16-bit unsigned 128/256-bit shapes. It uses
// unsafe.Slice to reinterpret the slice as the same-width signed type,
// delegates to the Int<WxC> SlicePart operations (defined elsewhere), and
// converts the result with As<VType>/AsInt<WxC>.
var avx2SmallLoadSlicePartTemplate = shapedTemplateOf(avx2SmallLoadPunShapes, "avx 2 small load slice part", `
// Load{{.VType}}SlicePart loads a {{.VType}} from the slice s.
// If s has fewer than {{.Count}} elements, the remaining elements of the vector are filled with zeroes.
// If s has {{.Count}} or more elements, the function is equivalent to Load{{.VType}}Slice.
func Load{{.VType}}SlicePart(s []{{.Etype}}) {{.VType}} {
if len(s) == 0 {
var zero {{.VType}}
return zero
}
t := unsafe.Slice((*int{{.EWidth}})(unsafe.Pointer(&s[0])), len(s))
return LoadInt{{.WxC}}SlicePart(t).As{{.VType}}()
}

// StoreSlicePart stores the {{.Count}} elements of x into the slice s.
// It stores as many elements as will fit in s.
// If s has {{.Count}} or more elements, the method is equivalent to x.StoreSlice.
func (x {{.VType}}) StoreSlicePart(s []{{.Etype}}) {
if len(s) == 0 {
return
}
t := unsafe.Slice((*int{{.EWidth}})(unsafe.Pointer(&s[0])), len(s))
x.AsInt{{.WxC}}().StoreSlicePart(t)
}
`)
650
// CPUfeature names the CPU feature reported in generated doc comments for a
// vector of t.Vwidth bits: "AVX" for 128, "AVX2" for 256, "AVX512" for 512.
// It panics for any other width.
func (t templateData) CPUfeature() string {
	if t.Vwidth == 128 {
		return "AVX"
	}
	if t.Vwidth == 256 {
		return "AVX2"
	}
	if t.Vwidth == 512 {
		return "AVX512"
	}
	panic(fmt.Errorf("unexpected vector width %d", t.Vwidth))
}
662
// avx2SignedComparisonsTemplate generates emulated Less, GreaterEqual,
// LessEqual and NotEqual methods for signed 128/256-bit integer vectors.
// They are built from Greater, Equal and Xor. An all-ones vector is obtained
// as x.Equal(x), and Xor with it complements a mask.
var avx2SignedComparisonsTemplate = shapedTemplateOf(avx2SignedComparisons, "avx2 signed comparisons", `
// Less returns a mask whose elements indicate whether x < y.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) Less(y {{.VType}}) Mask{{.WxC}} {
return y.Greater(x)
}

// GreaterEqual returns a mask whose elements indicate whether x >= y.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) GreaterEqual(y {{.VType}}) Mask{{.WxC}} {
ones := x.Equal(x).ToInt{{.WxC}}()
return y.Greater(x).ToInt{{.WxC}}().Xor(ones).asMask()
}

// LessEqual returns a mask whose elements indicate whether x <= y.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) LessEqual(y {{.VType}}) Mask{{.WxC}} {
ones := x.Equal(x).ToInt{{.WxC}}()
return x.Greater(y).ToInt{{.WxC}}().Xor(ones).asMask()
}

// NotEqual returns a mask whose elements indicate whether x != y.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) NotEqual(y {{.VType}}) Mask{{.WxC}} {
ones := x.Equal(x).ToInt{{.WxC}}()
return x.Equal(y).ToInt{{.WxC}}().Xor(ones).asMask()
}
`)
695
// bitWiseIntTemplate generates an emulated Not method for signed integer
// vectors: x is XORed with the all-ones vector x.Equal(x).
var bitWiseIntTemplate = shapedTemplateOf(intShapes, "bitwise int complement", `
// Not returns the bitwise complement of x.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) Not() {{.VType}} {
return x.Xor(x.Equal(x).ToInt{{.WxC}}())
}
`)
704
// bitWiseUintTemplate generates an emulated Not method for unsigned integer
// vectors. It works like bitWiseIntTemplate, but reinterprets the all-ones
// Int vector as the unsigned type first.
var bitWiseUintTemplate = shapedTemplateOf(uintShapes, "bitwise uint complement", `
// Not returns the bitwise complement of x.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) Not() {{.VType}} {
return x.Xor(x.Equal(x).ToInt{{.WxC}}().As{{.VType}}())
}
`)
713
714
715
716
717
718
719 func (t templateData) CPUfeatureAVX2if8() string {
720 if t.EWidth == 8 {
721 return "AVX2"
722 }
723 return t.CPUfeature()
724 }
725
// avx2UnsignedComparisonsTemplate generates emulated ordered comparisons and
// NotEqual for unsigned 128/256-bit integer vectors. Both operands are
// reinterpreted as signed and XORed with a vector whose elements have only
// the sign bit set, so that signed Greater orders them as unsigned values.
// For 8-bit elements the sign vector comes from BroadcastInt<WxC>. For other
// widths it is all-ones (x.Equal(x)) shifted left by EWidth-1.
var avx2UnsignedComparisonsTemplate = shapedTemplateOf(avx2UnsignedComparisons, "avx2 unsigned comparisons", `
// Greater returns a mask whose elements indicate whether x > y.
//
// Emulated, CPU Feature: {{.CPUfeatureAVX2if8}}
func (x {{.VType}}) Greater(y {{.VType}}) Mask{{.WxC}} {
a, b := x.AsInt{{.WxC}}(), y.AsInt{{.WxC}}()
{{- if eq .EWidth 8}}
signs := BroadcastInt{{.WxC}}(-1 << ({{.EWidth}}-1))
{{- else}}
ones := x.Equal(x).ToInt{{.WxC}}()
signs := ones.ShiftAllLeft({{.EWidth}}-1)
{{- end }}
return a.Xor(signs).Greater(b.Xor(signs))
}

// Less returns a mask whose elements indicate whether x < y.
//
// Emulated, CPU Feature: {{.CPUfeatureAVX2if8}}
func (x {{.VType}}) Less(y {{.VType}}) Mask{{.WxC}} {
a, b := x.AsInt{{.WxC}}(), y.AsInt{{.WxC}}()
{{- if eq .EWidth 8}}
signs := BroadcastInt{{.WxC}}(-1 << ({{.EWidth}}-1))
{{- else}}
ones := x.Equal(x).ToInt{{.WxC}}()
signs := ones.ShiftAllLeft({{.EWidth}}-1)
{{- end }}
return b.Xor(signs).Greater(a.Xor(signs))
}

// GreaterEqual returns a mask whose elements indicate whether x >= y.
//
// Emulated, CPU Feature: {{.CPUfeatureAVX2if8}}
func (x {{.VType}}) GreaterEqual(y {{.VType}}) Mask{{.WxC}} {
a, b := x.AsInt{{.WxC}}(), y.AsInt{{.WxC}}()
ones := x.Equal(x).ToInt{{.WxC}}()
{{- if eq .EWidth 8}}
signs := BroadcastInt{{.WxC}}(-1 << ({{.EWidth}}-1))
{{- else}}
signs := ones.ShiftAllLeft({{.EWidth}}-1)
{{- end }}
return b.Xor(signs).Greater(a.Xor(signs)).ToInt{{.WxC}}().Xor(ones).asMask()
}

// LessEqual returns a mask whose elements indicate whether x <= y.
//
// Emulated, CPU Feature: {{.CPUfeatureAVX2if8}}
func (x {{.VType}}) LessEqual(y {{.VType}}) Mask{{.WxC}} {
a, b := x.AsInt{{.WxC}}(), y.AsInt{{.WxC}}()
ones := x.Equal(x).ToInt{{.WxC}}()
{{- if eq .EWidth 8}}
signs := BroadcastInt{{.WxC}}(-1 << ({{.EWidth}}-1))
{{- else}}
signs := ones.ShiftAllLeft({{.EWidth}}-1)
{{- end }}
return a.Xor(signs).Greater(b.Xor(signs)).ToInt{{.WxC}}().Xor(ones).asMask()
}

// NotEqual returns a mask whose elements indicate whether x != y.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) NotEqual(y {{.VType}}) Mask{{.WxC}} {
a, b := x.AsInt{{.WxC}}(), y.AsInt{{.WxC}}()
ones := x.Equal(x).ToInt{{.WxC}}()
return a.Equal(b).ToInt{{.WxC}}().Xor(ones).asMask()
}
`)
792
// unsafePATemplate generates pa<VType> for every shape in allShapes. It
// converts &s[0] to a pointer to a Count-element array via unsafe.Pointer,
// with no length check; the SlicePart templates call it with a short s.
var unsafePATemplate = templateOf("unsafe PA helper", `
// pa{{.VType}} returns a type-unsafe pointer to array that can
// only be used with partial load/store operations that only
// access the known-safe portions of the array.
func pa{{.VType}}(s []{{.Etype}}) *[{{.Count}}]{{.Etype}} {
return (*[{{.Count}}]{{.Etype}})(unsafe.Pointer(&s[0]))
}
`)
801
// avx2MaskedTemplate generates Masked and Merge methods for the 128/256-bit
// shapes. Masked ANDs x with the mask converted to Int<WxC>. Merge blends on
// the byte view (Int<BxC>) of the vectors unless the shape is already an
// Int byte vector. The {{- / -}} markers control whitespace in the emitted
// code and are significant.
var avx2MaskedTemplate = shapedTemplateOf(avx2Shapes, "avx2 .Masked methods", `
// Masked returns x but with elements zeroed where mask is false.
func (x {{.VType}}) Masked(mask Mask{{.WxC}}) {{.VType}} {
im := mask.ToInt{{.WxC}}()
{{- if eq .Base "Int" }}
return im.And(x)
{{- else}}
return x.AsInt{{.WxC}}().And(im).As{{.VType}}()
{{- end -}}
}

// Merge returns x but with elements set to y where mask is false.
func (x {{.VType}}) Merge(y {{.VType}}, mask Mask{{.WxC}}) {{.VType}} {
{{- if eq .BxC .WxC -}}
im := mask.ToInt{{.BxC}}()
{{- else}}
im := mask.ToInt{{.WxC}}().AsInt{{.BxC}}()
{{- end -}}
{{- if and (eq .Base "Int") (eq .BxC .WxC) }}
return y.blend(x, im)
{{- else}}
ix := x.AsInt{{.BxC}}()
iy := y.AsInt{{.BxC}}()
return iy.blend(ix, im).As{{.VType}}()
{{- end -}}
}
`)
829
830
// avx512MaskedTemplate generates Masked and Merge methods for the 512-bit
// shapes. Masked works as in avx2MaskedTemplate. Merge uses blendMasked
// directly with the mask, on the Int<WxC> view for non-Int shapes.
var avx512MaskedTemplate = shapedTemplateOf(avx512Shapes, "avx512 .Masked methods", `
// Masked returns x but with elements zeroed where mask is false.
func (x {{.VType}}) Masked(mask Mask{{.WxC}}) {{.VType}} {
im := mask.ToInt{{.WxC}}()
{{- if eq .Base "Int" }}
return im.And(x)
{{- else}}
return x.AsInt{{.WxC}}().And(im).As{{.VType}}()
{{- end -}}
}

// Merge returns x but with elements set to y where mask is false.
func (x {{.VType}}) Merge(y {{.VType}}, mask Mask{{.WxC}}) {{.VType}} {
{{- if eq .Base "Int" }}
return y.blendMasked(x, mask)
{{- else}}
ix := x.AsInt{{.WxC}}()
iy := y.AsInt{{.WxC}}()
return iy.blendMasked(ix, mask).As{{.VType}}()
{{- end -}}
}
`)
853
// CPUfeatureBC names the CPU feature reported for the generated Broadcast
// functions. 128- and 256-bit vectors report "AVX2". 512-bit vectors report
// "AVX512BW" for 8- and 16-bit elements and "AVX512F" otherwise. It panics
// for any other vector width.
func (t templateData) CPUfeatureBC() string {
	switch {
	case t.Vwidth == 128, t.Vwidth == 256:
		return "AVX2"
	case t.Vwidth == 512 && t.EWidth <= 16:
		return "AVX512BW"
	case t.Vwidth == 512:
		return "AVX512F"
	}
	panic(fmt.Errorf("unexpected vector width %d", t.Vwidth))
}
868
// broadcastTemplate generates Broadcast<VType> for every shape in allShapes.
// It sets element 0 of a zero 128-bit vector of the same element type
// (As128BitVec) and calls Broadcast1To<Count> on it.
var broadcastTemplate = templateOf("Broadcast functions", `
// Broadcast{{.VType}} returns a vector with the input
// x assigned to all elements of the output.
//
// Emulated, CPU Feature: {{.CPUfeatureBC}}
func Broadcast{{.VType}}(x {{.Etype}}) {{.VType}} {
var z {{.As128BitVec }}
return z.SetElem(0, x).Broadcast1To{{.Count}}()
}
`)
879
// maskCvtTemplate generates a ToMask method for every signed integer shape.
// A lane is set when the element is non-zero, computed as NotEqual against
// the zero vector.
var maskCvtTemplate = shapedTemplateOf(intShapes, "Mask conversions", `
// ToMask converts from {{.Base}}{{.WxC}} to Mask{{.WxC}}, mask element is set to true when the corresponding vector element is non-zero.
func (from {{.Base}}{{.WxC}}) ToMask() (to Mask{{.WxC}}) {
return from.NotEqual({{.Base}}{{.WxC}}{})
}
`)
886
// stringTemplate generates a String method for every shape in allShapes.
// It stores the vector into a Count-element array and formats it with
// sliceToString, which is defined elsewhere in archsimd.
var stringTemplate = shapedTemplateOf(allShapes, "String methods", `
// String returns a string representation of SIMD vector x.
func (x {{.VType}}) String() string {
var s [{{.Count}}]{{.Etype}}
x.Store(&s)
return sliceToString(s[:])
}
`)
895
// Default output directories. They are relative paths, so they assume the
// generator runs from its own source directory (TODO confirm).
const (
	// SIMD is the directory of the generated archsimd package files.
	SIMD = "../../"
	// TD is the directory of the generated simd_test helper files.
	TD = "../../internal/simd_test/"
	// SSA is the compiler's cmd/compile/internal/ssa directory.
	SSA = "../../../../cmd/compile/internal/ssa/"
)
899
// main regenerates the SIMD source files. Each output file name comes from a
// flag whose default is under SIMD, TD or SSA. Passing an empty name for a
// flag (e.g. -sl="") skips that output. The SSA ternary-logic helpers have
// their own -tern flag, so they can be redirected or skipped like every
// other output; they were previously always written to a hard-coded path.
func main() {
	sl := flag.String("sl", SIMD+"slice_gen_amd64.go", "file name for slice operations")
	cm := flag.String("cm", SIMD+"compare_gen_amd64.go", "file name for comparison operations")
	mm := flag.String("mm", SIMD+"maskmerge_gen_amd64.go", "file name for mask/merge operations")
	op := flag.String("op", SIMD+"other_gen_amd64.go", "file name for other operations")
	ush := flag.String("ush", SIMD+"unsafe_helpers.go", "file name for unsafe helpers")
	bh := flag.String("bh", TD+"binary_helpers_test.go", "file name for binary test helpers")
	uh := flag.String("uh", TD+"unary_helpers_test.go", "file name for unary test helpers")
	th := flag.String("th", TD+"ternary_helpers_test.go", "file name for ternary test helpers")
	ch := flag.String("ch", TD+"compare_helpers_test.go", "file name for compare test helpers")
	cmh := flag.String("cmh", TD+"comparemasked_helpers_test.go", "file name for compare-masked test helpers")
	tern := flag.String("tern", SSA+"tern_helpers.go", "file name for SSA ternary-logic helpers")
	flag.Parse()

	if *sl != "" {
		one(*sl, unsafePrologue,
			sliceTemplate,
			avx512MaskedLoadSlicePartTemplate,
			avx2MaskedLoadSlicePartTemplate,
			avx2SmallLoadSlicePartTemplate,
		)
	}
	if *cm != "" {
		one(*cm, prologue,
			avx2SignedComparisonsTemplate,
			avx2UnsignedComparisonsTemplate,
		)
	}
	if *mm != "" {
		one(*mm, prologue,
			avx2MaskedTemplate,
			avx512MaskedTemplate,
		)
	}
	if *op != "" {
		one(*op, prologue,
			broadcastTemplate,
			maskCvtTemplate,
			bitWiseIntTemplate,
			bitWiseUintTemplate,
			stringTemplate,
		)
	}
	if *ush != "" {
		one(*ush, unsafePrologue, unsafePATemplate)
	}
	if *uh != "" {
		one(*uh, curryTestPrologue("unary simd methods"), unaryTemplate,
			unaryToInt8, unaryToUint8, unaryToInt16, unaryToUint16,
			unaryToInt32, unaryToUint32, unaryToInt64, unaryToUint64,
			unaryToFloat32, unaryToFloat64,
			unaryToInt64x2, unaryToInt64x4,
			unaryToUint64x2, unaryToUint64x4,
			unaryToInt32x4, unaryToInt32x8,
			unaryToUint32x4, unaryToUint32x8,
			unaryToInt16x8, unaryToUint16x8,
			unaryFlakyTemplate,
		)
	}
	if *bh != "" {
		one(*bh, curryTestPrologue("binary simd methods"), binaryTemplate)
	}
	if *th != "" {
		one(*th, curryTestPrologue("ternary simd methods"), ternaryTemplate, ternaryFlakyTemplate)
	}
	if *ch != "" {
		one(*ch, curryTestPrologue("simd methods that compare two operands"), compareTemplate, compareUnaryTemplate)
	}
	if *cmh != "" {
		one(*cmh, curryTestPrologue("simd methods that compare two operands under a mask"), compareMaskedTemplate)
	}
	if *tern != "" {
		nonTemplateRewrites(*tern, ssaPrologue, classifyBooleanSIMD, ternOpForLogical)
	}
}
974
// ternOpForLogical writes the source of the SSA helper ternOpForLogical to
// out. For each shape in intShapes, the generated switch maps the Int and
// Uint And/Or/Xor/AndNot ops to the Optern op of the same total width. Its
// element width is at least 32 bits: narrower elements are regrouped into
// 32-bit lanes. Any other op is returned unchanged.
func ternOpForLogical(out io.Writer) {
	fmt.Fprintf(out, `
func ternOpForLogical(op Op) Op {
switch op {
`)

	intShapes.forAllShapes(func(_ int, _, _ string, w, c int, dst io.Writer) {
		ternW, ternC := w, c
		if ternW < 32 {
			ternW, ternC = 32, (w*c)/32
		}
		for _, base := range [...]string{"Int", "Uint"} {
			fmt.Fprintf(dst, "case OpAnd%[1]s%[2]dx%[3]d, OpOr%[1]s%[2]dx%[3]d, OpXor%[1]s%[2]dx%[3]d,OpAndNot%[1]s%[2]dx%[3]d: return Optern%[1]s%[4]dx%[5]d\n",
				base, w, c, ternW, ternC)
		}
	}, out)

	fmt.Fprintf(out, `
}
return op
}
`)
}
998
// classifyBooleanSIMD writes the SSA helper classifyBooleanSIMD to out,
// together with the SIMDLogicalOP type and its slo* constants.
//
// The generated function classifies the integer-vector And, Or, AndNot and
// Xor ops for every shape in intShapes. An Xor whose second argument is an
// Equal op with identical operands is classified as sloNot instead of
// sloXor. Every other op is sloNone.
//
// The three comma-separated case lists are produced by one local helper
// instead of three copies of the same loop. The copies each incremented the
// closure's seq parameter, which had no effect. The generated text is
// unchanged.
func classifyBooleanSIMD(out io.Writer) {
	fmt.Fprintf(out, `
type SIMDLogicalOP uint8
const (
// boolean simd operations, for reducing expression to VPTERNLOG* instructions
// sloInterior is set for non-root nodes in logical-op expression trees.
// the operations are even-numbered.
sloInterior SIMDLogicalOP = 1
sloNone SIMDLogicalOP = 2 * iota
sloAnd
sloOr
sloAndNot
sloXor
sloNot
)
func classifyBooleanSIMD(v *Value) SIMDLogicalOP {
switch v.Op {
case `)

	// writeCaseList writes "Op<op><Type><W>x<C>" for each shape in intShapes,
	// separated by commas (no comma before the first entry).
	writeCaseList := func(op string) {
		intShapes.forAllShapes(func(seq int, _, upperT string, w, c int, dst io.Writer) {
			if seq > 0 {
				fmt.Fprint(dst, ",")
			}
			fmt.Fprintf(dst, "Op%s%s%dx%d", op, upperT, w, c)
		}, out)
	}

	writeCaseList("And")
	fmt.Fprintf(out, `:
return sloAnd

case `)

	writeCaseList("Or")
	fmt.Fprintf(out, `:
return sloOr

case `)

	writeCaseList("AndNot")
	fmt.Fprintf(out, `:
return sloAndNot
`)

	// Xor needs one case per shape: x ^ (y == y) is recognized as a Not.
	intShapes.forAllShapes(func(_ int, _, upperT string, w, c int, dst io.Writer) {
		fmt.Fprintf(dst, "case OpXor%s%dx%d: ", upperT, w, c)
		fmt.Fprintf(dst, `
if y := v.Args[1]; y.Op == OpEqual%s%dx%d &&
y.Args[0] == y.Args[1] {
return sloNot
}
`, upperT, w, c)
		fmt.Fprintf(dst, "return sloXor\n")
	}, out)

	fmt.Fprintf(out, `
}
return sloNone
}
`)
}
1081
1082
1083
// numberLines returns data with each line prefixed by its 1-based line number
// and ": ", one line per output line. It is used to print generated source
// that failed to format, so it must not lose lines.
//
// bufio.Scanner's default token limit (bufio.MaxScanTokenSize) would make
// scanning stop silently at the first longer line. The limit is therefore
// raised to cover the whole input, and any scan error is reported inline
// rather than dropped.
func numberLines(data []byte) string {
	var buf bytes.Buffer
	s := bufio.NewScanner(bytes.NewReader(data))

	// No line can be longer than the input itself.
	limit := len(data) + 1
	if limit < bufio.MaxScanTokenSize {
		limit = bufio.MaxScanTokenSize
	}
	s.Buffer(nil, limit)

	for i := 1; s.Scan(); i++ {
		fmt.Fprintf(&buf, "%d: %s\n", i, s.Text())
	}
	if err := s.Err(); err != nil {
		fmt.Fprintf(&buf, "(line numbering stopped early: %v)\n", err)
	}
	return buf.String()
}
1093
1094 func nonTemplateRewrites(filename string, prologue func(s string, out io.Writer), rewrites ...func(out io.Writer)) {
1095 if filename == "" {
1096 return
1097 }
1098
1099 ofile := os.Stdout
1100
1101 if filename != "-" {
1102 var err error
1103 ofile, err = os.Create(filename)
1104 if err != nil {
1105 fmt.Fprintf(os.Stderr, "Could not create the output file %s for the generated code, %v", filename, err)
1106 os.Exit(1)
1107 }
1108 }
1109
1110 out := new(bytes.Buffer)
1111
1112 prologue("tmplgen", out)
1113 for _, rewrite := range rewrites {
1114 rewrite(out)
1115 }
1116
1117 b, err := format.Source(out.Bytes())
1118 if err != nil {
1119 fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
1120 fmt.Fprintf(os.Stderr, "%s\n", numberLines(out.Bytes()))
1121 fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
1122 os.Exit(1)
1123 } else {
1124 ofile.Write(b)
1125 ofile.Close()
1126 }
1127
1128 }
1129
1130 func one(filename string, prologue func(s string, out io.Writer), sats ...shapeAndTemplate) {
1131 if filename == "" {
1132 return
1133 }
1134
1135 ofile := os.Stdout
1136
1137 if filename != "-" {
1138 var err error
1139 ofile, err = os.Create(filename)
1140 if err != nil {
1141 fmt.Fprintf(os.Stderr, "Could not create the output file %s for the generated code, %v", filename, err)
1142 os.Exit(1)
1143 }
1144 }
1145
1146 out := new(bytes.Buffer)
1147
1148 prologue("tmplgen", out)
1149 for _, sat := range sats {
1150 sat.forTemplates(out)
1151 }
1152
1153 b, err := format.Source(out.Bytes())
1154 if err != nil {
1155 fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
1156 fmt.Fprintf(os.Stderr, "%s\n", numberLines(out.Bytes()))
1157 fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
1158 os.Exit(1)
1159 } else {
1160 ofile.Write(b)
1161 ofile.Close()
1162 }
1163
1164 }
1165
View as plain text