Source file src/simd/testdata/sample.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"fmt"
     9  	"os"
    10  	"simd"
    11  	"unsafe"
    12  )
    13  
    14  func load(s []float64) simd.Float64x4 {
    15  	return simd.LoadFloat64x4((*[4]float64)(s[:4]))
    16  }
    17  
    18  type S1 = simd.Float64x4
    19  
    20  type S2 simd.Float64x4
    21  
    22  func (s S2) Len() int {
    23  	return simd.Float64x4(s).Len()
    24  }
    25  
    26  func (s S2) Load(a []float64) S2 {
    27  	return S2(load(a))
    28  }
    29  
    30  func (s S2) Store(a *[4]float64) {
    31  	simd.Float64x4(s).Store(a)
    32  }
    33  
    34  func (s S2) Add(a S2) S2 {
    35  	return S2(simd.Float64x4(s).Add(simd.Float64x4(a)))
    36  }
    37  
    38  func (s S2) Mul(a S2) S2 {
    39  	return S2(simd.Float64x4(s).Mul(simd.Float64x4(a)))
    40  }
    41  
    42  type S3 struct {
    43  	simd.Float64x4
    44  }
    45  
    46  func ip64_0(a, b []float64) float64 {
    47  	s := 0.0
    48  	for i := range a {
    49  		s += a[i] * b[i]
    50  	}
    51  	return s
    52  }
    53  
    54  func ip64_1(a, b []float64) float64 {
    55  	var z S1
    56  	sum := z
    57  	var i int
    58  	stride := z.Len()
    59  	for ; i <= len(a)-stride; i += stride {
    60  		va := load(a[i:])
    61  		vb := load(b[i:])
    62  		sum = sum.Add(va.Mul(vb))
    63  	}
    64  	var tmp [4]float64
    65  	sum.Store(&tmp)
    66  	return tmp[0] + tmp[1] + tmp[2] + tmp[3]
    67  }
    68  
    69  func ip64_1a(a, b []float64) float64 {
    70  	var z S1
    71  	sum := z
    72  	var i int
    73  	stride := z.Len()
    74  	for ; i <= len(a)-stride; i += stride {
    75  		va := load(a[i:])
    76  		vb := load(b[i:])
    77  		sum = FMA(sum, va, vb)
    78  	}
    79  	var tmp [4]float64
    80  	sum.Store(&tmp)
    81  	return tmp[0] + tmp[1] + tmp[2] + tmp[3]
    82  }
    83  
    84  //go:noinline
    85  func FMA(a, b, c simd.Float64x4) simd.Float64x4 {
    86  	return a.Add(b.Mul(c))
    87  }
    88  
    89  func ip64_2(a, b []float64) float64 {
    90  	var z S2
    91  	sum := z
    92  	var i int
    93  	stride := z.Len()
    94  	for ; i <= len(a)-stride; i += stride {
    95  		va := z.Load(a[i:])
    96  		vb := z.Load(b[i:])
    97  		sum = sum.Add(va.Mul(vb))
    98  	}
    99  	var tmp [4]float64
   100  	sum.Store(&tmp)
   101  	return tmp[0] + tmp[1] + tmp[2] + tmp[3]
   102  }
   103  
   104  func ip64_3(a, b []float64) float64 {
   105  	var z S3
   106  	sum := z
   107  	var i int
   108  	stride := z.Len()
   109  	for ; i <= len(a)-stride; i += stride {
   110  		va := load(a[i:])
   111  		vb := load(b[i:])
   112  		sum = S3{sum.Add(va.Mul(vb))}
   113  	}
   114  	var tmp [4]float64
   115  	sum.Store(&tmp)
   116  	return tmp[0] + tmp[1] + tmp[2] + tmp[3]
   117  }
   118  
   119  func main() {
   120  	a := []float64{1, 2, 3, 4, 5, 6, 7, 8}
   121  	ip0 := ip64_0(a, a)
   122  	ip1 := ip64_1(a, a)
   123  	ip1a := ip64_1a(a, a)
   124  	ip2 := ip64_2(a, a)
   125  	ip3 := ip64_3(a, a)
   126  	fmt.Printf("Test IP    = %f\n", ip0)
   127  	fmt.Printf("SIMD IP 1  = %f\n", ip1)
   128  	fmt.Printf("SIMD IP 1a = %f\n", ip1a)
   129  	fmt.Printf("SIMD IP 2  = %f\n", ip2)
   130  	fmt.Printf("SIMD IP 3 = %f\n", ip3)
   131  	var z1 S1
   132  	var z2 S2
   133  	var z3 S2
   134  
   135  	s1, s2, s3 := unsafe.Sizeof(z1), unsafe.Sizeof(z2), unsafe.Sizeof(z3)
   136  
   137  	fmt.Printf("unsafe.Sizeof(z1, z2, z3)=%d, %d, %d\n", s1, s2, s3)
   138  
   139  	fail := false
   140  
   141  	if s1 != 32 || s2 != 32 || s3 != 32 {
   142  		fmt.Println("Failed a sizeof check, should all be 32")
   143  		fail = true
   144  	}
   145  
   146  	if ip1 != ip0 || ip1a != ip0 || ip2 != ip0 || ip3 != ip0 {
   147  		fmt.Println("Failed an inner product check, should all be", ip0)
   148  		fail = true
   149  	}
   150  
   151  	if fail {
   152  		os.Exit(1)
   153  	}
   154  }
   155  

View as plain text