Source file
src/simd/ip_test.go
1
2
3
4
5
6
7 package simd_test
8
9 import (
10 "fmt"
11 "simd"
12 "testing"
13 )
14
15 func TestIP(t *testing.T) {
16
17 var a, b [50]float32
18 for i := 0; i < 50; i++ {
19 a[i] = float32(i)
20 b[i] = float32(i)
21 }
22 x, sumWidth, emulated := ip(a[:50], b[:50])
23
24 if x != 40425 {
25 t.Errorf("Expected 40425, got %f", x)
26 }
27
28 fmt.Printf("ip: sum was computed in width %d, emulated = %v\n", sumWidth, emulated)
29 }
30
31 func TestIPGoTo(t *testing.T) {
32
33 var a, b [50]float32
34 for i := 0; i < 50; i++ {
35 a[i] = float32(i)
36 b[i] = float32(i)
37 }
38 x, sumWidth, emulated := ipGoTo(a[:50], b[:50])
39
40 if x != 40425 {
41 t.Errorf("Expected 40425, got %f", x)
42 }
43
44 fmt.Printf("ipgoto: sum was computed in width %d, emulated = %v\n", sumWidth, emulated)
45 }
46
// first returns its first argument and ignores the second. It lets a call
// with two results, such as simd.LoadFloat32sPart(...), be passed straight
// in as first(f(...)) when only the leading result is needed.
func first[T, U any](keep T, _ U) T {
	return keep
}
50
51 func ip(x, y []float32) (float32, int, bool) {
52 var a simd.Float32s
53 sumWidth := a.Len() * 32
54 emulated := simd.Emulated()
55 var i int
56 for i = 0; i < len(x)-a.Len()+1; i += a.Len() {
57 u := simd.LoadFloat32s(x[i : i+a.Len()])
58 v := simd.LoadFloat32s(y[i : i+a.Len()])
59 a = a.Add(u.Mul(v))
60 }
61 if i < len(x) {
62 a = a.Add(first(simd.LoadFloat32sPart(x[i:])).
63 Mul(first(simd.LoadFloat32sPart(y[i:]))))
64 }
65
66 return sum(a), sumWidth, emulated
67 }
68
// ipGoTo computes the same three results as ip:
//   - the inner product of x and y;
//   - the accumulator width in bits;
//   - simd.Emulated().
//
// The main loop is written with explicit labels and goto statements instead of
// a for statement. Presumably this is meant to exercise SIMD values that stay
// live across backward and forward jumps — TODO confirm intent. Do not rewrite
// it as a for loop.
// y must be at least as long as x, or the slicing below panics.
func ipGoTo(x, y []float32) (float32, int, bool) {
	var a simd.Float32s
	sumWidth := a.Len() * 32 // 32 bits per float32 lane
	emulated := simd.Emulated()
	var i int
	// u and v are declared before the label on purpose. The Go spec forbids a
	// goto that brings a variable into scope, so "goto done" could not jump
	// past a := declaration in the loop body below.
	var u, v simd.Float32s
loop:
	// Stop once fewer than a.Len() elements remain.
	if !(i < len(x)-a.Len()+1) {
		goto done
	}
	u = simd.LoadFloat32s(x[i : i+a.Len()])
	v = simd.LoadFloat32s(y[i : i+a.Len()])
	a = a.Add(u.Mul(v))
	i += a.Len()
	goto loop
done:
	// Handle the tail shorter than one vector. first keeps only the vector
	// result of LoadFloat32sPart.
	if i < len(x) {
		a = a.Add(first(simd.LoadFloat32sPart(x[i:])).
			Mul(first(simd.LoadFloat32sPart(y[i:]))))
	}

	return sum(a), sumWidth, emulated
}
92
93 func boringSum(x simd.Float32s) float32 {
94 s := make([]float32, x.Len())
95 x.Store(s)
96 var r float32
97 for _, e := range s {
98 r += e
99 }
100 return r
101 }
102
View as plain text