1
2
3
4
5
6
7
8 package main
9
10 import (
11 "fmt"
12 "os"
13 "simd/archsimd"
14 "unsafe"
15 )
16
17 func load(s []float32) archsimd.Float32x4 {
18 return archsimd.LoadFloat32x4Array((*[4]float32)(s[:4]))
19 }
20
// S1 is a type alias for archsimd.Float32x4: both names denote the same type,
// so an S1 value has the vector type's methods directly.
type S1 = archsimd.Float32x4
22
23 type S2 archsimd.Float32x4
24
25 func (s S2) Len() int {
26 return archsimd.Float32x4(s).Len()
27 }
28
29 func (s S2) Load(a []float32) S2 {
30 return S2(load(a))
31 }
32
33 func (s S2) Store(a *[4]float32) {
34 archsimd.Float32x4(s).StoreArray(a)
35 }
36
37 func (s S2) Add(a S2) S2 {
38 return S2(archsimd.Float32x4(s).Add(archsimd.Float32x4(a)))
39 }
40
41 func (s S2) Mul(a S2) S2 {
42 return S2(archsimd.Float32x4(s).Mul(archsimd.Float32x4(a)))
43 }
44
// S3 wraps archsimd.Float32x4 by embedding it as the only field of a struct,
// so the vector type's methods are promoted and callable on an S3 value.
type S3 struct {
	archsimd.Float32x4
}
48
// ip32_0 is the plain scalar inner product: it returns the sum of a[i]*b[i]
// over every index of a. b is indexed with a's indices, so it must have at
// least len(a) elements or the function panics.
func ip32_0(a, b []float32) float32 {
	total := float32(0)
	for i, x := range a {
		total += x * b[i]
	}
	return total
}
56
57 func ip32_1(a, b []float32) float32 {
58 var z S1
59 sum := z
60 var i int
61 stride := z.Len()
62 for ; i <= len(a)-stride; i += stride {
63 va := load(a[i:])
64 vb := load(b[i:])
65 sum = sum.Add(va.Mul(vb))
66 }
67 var tmp [4]float32
68 sum.StoreArray(&tmp)
69 return tmp[0] + tmp[1] + tmp[2] + tmp[3]
70 }
71
72 func ip32_1a(a, b []float32) float32 {
73 var z S1
74 sum := z
75 var i int
76 stride := z.Len()
77 for ; i <= len(a)-stride; i += stride {
78 va := load(a[i:])
79 vb := load(b[i:])
80 sum = FMA(sum, va, vb)
81 }
82 var tmp [4]float32
83 sum.StoreArray(&tmp)
84 return tmp[0] + tmp[1] + tmp[2] + tmp[3]
85 }
86
87
88 func FMA(a, b, c archsimd.Float32x4) archsimd.Float32x4 {
89 return a.Add(b.Mul(c))
90 }
91
92 func ip32_2(a, b []float32) float32 {
93 var z S2
94 sum := z
95 var i int
96 stride := z.Len()
97 for ; i <= len(a)-stride; i += stride {
98 va := z.Load(a[i:])
99 vb := z.Load(b[i:])
100 sum = sum.Add(va.Mul(vb))
101 }
102 var tmp [4]float32
103 sum.Store(&tmp)
104 return tmp[0] + tmp[1] + tmp[2] + tmp[3]
105 }
106
107 func ip32_3(a, b []float32) float32 {
108 var z S3
109 sum := z
110 var i int
111 stride := z.Len()
112 for ; i <= len(a)-stride; i += stride {
113 va := load(a[i:])
114 vb := load(b[i:])
115 sum = S3{sum.Add(va.Mul(vb))}
116 }
117 var tmp [4]float32
118 sum.StoreArray(&tmp)
119 return tmp[0] + tmp[1] + tmp[2] + tmp[3]
120 }
121
122 func main() {
123 a := []float32{1, 2, 3, 4, 5, 6, 7, 8}
124 ip0 := ip32_0(a, a)
125 ip1 := ip32_1(a, a)
126 ip1a := ip32_1a(a, a)
127 ip2 := ip32_2(a, a)
128 ip3 := ip32_3(a, a)
129 fmt.Printf("Test IP = %f\n", ip0)
130 fmt.Printf("SIMD IP 1 = %f\n", ip1)
131 fmt.Printf("SIMD IP 1a = %f\n", ip1a)
132 fmt.Printf("SIMD IP 2 = %f\n", ip2)
133 fmt.Printf("SIMD IP 3 = %f\n", ip3)
134 var z1 S1
135 var z2 S2
136 var z3 S2
137
138 s1, s2, s3 := unsafe.Sizeof(z1), unsafe.Sizeof(z2), unsafe.Sizeof(z3)
139
140 fmt.Printf("unsafe.Sizeof(z1, z2, z3)=%d, %d, %d\n", s1, s2, s3)
141
142 fail := false
143
144 if s1 != 16 || s2 != 16 || s3 != 16 {
145 fmt.Println("Failed a sizeof check, should all be 16")
146 fail = true
147 }
148
149 if ip1 != ip0 || ip1a != ip0 || ip2 != ip0 || ip3 != ip0 {
150 fmt.Println("Failed an inner product check, should all be", ip0)
151 fail = true
152 }
153
154 if fail {
155 os.Exit(1)
156 }
157 }
158
View as plain text