1
2
3
4
5 package main
6
7 import (
8 "fmt"
9 "os"
10 "simd"
11 "unsafe"
12 )
13
14 func load(s []float64) simd.Float64x4 {
15 return simd.LoadFloat64x4((*[4]float64)(s[:4]))
16 }
17
// S1 is an alias for simd.Float64x4: it is the same type and therefore has
// exactly the same method set.
type S1 = simd.Float64x4
19
// S2 is a distinct defined type sharing the underlying type of
// simd.Float64x4. It does not inherit that type's methods; the S2 methods in
// this file convert back to simd.Float64x4 and forward to it.
type S2 simd.Float64x4
21
22 func (s S2) Len() int {
23 return simd.Float64x4(s).Len()
24 }
25
26 func (s S2) Load(a []float64) S2 {
27 return S2(load(a))
28 }
29
30 func (s S2) Store(a *[4]float64) {
31 simd.Float64x4(s).Store(a)
32 }
33
34 func (s S2) Add(a S2) S2 {
35 return S2(simd.Float64x4(s).Add(simd.Float64x4(a)))
36 }
37
38 func (s S2) Mul(a S2) S2 {
39 return S2(simd.Float64x4(s).Mul(simd.Float64x4(a)))
40 }
41
// S3 wraps simd.Float64x4 by embedding it, so the embedded type's methods
// are promoted and can be called directly on an S3 value.
type S3 struct {
	simd.Float64x4
}
45
// ip64_0 is the scalar reference implementation: it returns the sum of
// a[i]*b[i] over every index of a, accumulated in index order.
// It panics if b is shorter than a.
func ip64_0(a, b []float64) float64 {
	acc := 0.0
	for i, x := range a {
		acc += x * b[i]
	}
	return acc
}
53
54 func ip64_1(a, b []float64) float64 {
55 var z S1
56 sum := z
57 var i int
58 stride := z.Len()
59 for ; i <= len(a)-stride; i += stride {
60 va := load(a[i:])
61 vb := load(b[i:])
62 sum = sum.Add(va.Mul(vb))
63 }
64 var tmp [4]float64
65 sum.Store(&tmp)
66 return tmp[0] + tmp[1] + tmp[2] + tmp[3]
67 }
68
69 func ip64_1a(a, b []float64) float64 {
70 var z S1
71 sum := z
72 var i int
73 stride := z.Len()
74 for ; i <= len(a)-stride; i += stride {
75 va := load(a[i:])
76 vb := load(b[i:])
77 sum = FMA(sum, va, vb)
78 }
79 var tmp [4]float64
80 sum.Store(&tmp)
81 return tmp[0] + tmp[1] + tmp[2] + tmp[3]
82 }
83
84
85 func FMA(a, b, c simd.Float64x4) simd.Float64x4 {
86 return a.Add(b.Mul(c))
87 }
88
89 func ip64_2(a, b []float64) float64 {
90 var z S2
91 sum := z
92 var i int
93 stride := z.Len()
94 for ; i <= len(a)-stride; i += stride {
95 va := z.Load(a[i:])
96 vb := z.Load(b[i:])
97 sum = sum.Add(va.Mul(vb))
98 }
99 var tmp [4]float64
100 sum.Store(&tmp)
101 return tmp[0] + tmp[1] + tmp[2] + tmp[3]
102 }
103
104 func ip64_3(a, b []float64) float64 {
105 var z S3
106 sum := z
107 var i int
108 stride := z.Len()
109 for ; i <= len(a)-stride; i += stride {
110 va := load(a[i:])
111 vb := load(b[i:])
112 sum = S3{sum.Add(va.Mul(vb))}
113 }
114 var tmp [4]float64
115 sum.Store(&tmp)
116 return tmp[0] + tmp[1] + tmp[2] + tmp[3]
117 }
118
119 func main() {
120 a := []float64{1, 2, 3, 4, 5, 6, 7, 8}
121 ip0 := ip64_0(a, a)
122 ip1 := ip64_1(a, a)
123 ip1a := ip64_1a(a, a)
124 ip2 := ip64_2(a, a)
125 ip3 := ip64_3(a, a)
126 fmt.Printf("Test IP = %f\n", ip0)
127 fmt.Printf("SIMD IP 1 = %f\n", ip1)
128 fmt.Printf("SIMD IP 1a = %f\n", ip1a)
129 fmt.Printf("SIMD IP 2 = %f\n", ip2)
130 fmt.Printf("SIMD IP 3 = %f\n", ip3)
131 var z1 S1
132 var z2 S2
133 var z3 S2
134
135 s1, s2, s3 := unsafe.Sizeof(z1), unsafe.Sizeof(z2), unsafe.Sizeof(z3)
136
137 fmt.Printf("unsafe.Sizeof(z1, z2, z3)=%d, %d, %d\n", s1, s2, s3)
138
139 fail := false
140
141 if s1 != 32 || s2 != 32 || s3 != 32 {
142 fmt.Println("Failed a sizeof check, should all be 32")
143 fail = true
144 }
145
146 if ip1 != ip0 || ip1a != ip0 || ip2 != ip0 || ip3 != ip0 {
147 fmt.Println("Failed an inner product check, should all be", ip0)
148 fail = true
149 }
150
151 if fail {
152 os.Exit(1)
153 }
154 }
155
View as plain text