Source file src/simd/archsimd/testdata/arm64/sample.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build goexperiment.simd
     6  
     7  // This is a copy of amd64 sample.go adapted for Float32x4 (NEON 128-bit registers).
     8  package main
     9  
    10  import (
    11  	"fmt"
    12  	"os"
    13  	"simd/archsimd"
    14  	"unsafe"
    15  )
    16  
    17  func load(s []float32) archsimd.Float32x4 {
    18  	return archsimd.LoadFloat32x4Array((*[4]float32)(s[:4]))
    19  }
    20  
// S1 is a type alias: it is exactly archsimd.Float32x4 and carries all of
// Float32x4's methods unchanged.
type S1 = archsimd.Float32x4
    22  
    23  type S2 archsimd.Float32x4
    24  
    25  func (s S2) Len() int {
    26  	return archsimd.Float32x4(s).Len()
    27  }
    28  
    29  func (s S2) Load(a []float32) S2 {
    30  	return S2(load(a))
    31  }
    32  
    33  func (s S2) Store(a *[4]float32) {
    34  	archsimd.Float32x4(s).StoreArray(a)
    35  }
    36  
    37  func (s S2) Add(a S2) S2 {
    38  	return S2(archsimd.Float32x4(s).Add(archsimd.Float32x4(a)))
    39  }
    40  
    41  func (s S2) Mul(a S2) S2 {
    42  	return S2(archsimd.Float32x4(s).Mul(archsimd.Float32x4(a)))
    43  }
    44  
// S3 wraps a Float32x4 by embedding it, so Float32x4's methods are promoted
// to S3 (ip32_3 relies on the promoted Len, Add and StoreArray).
type S3 struct {
	archsimd.Float32x4
}
    48  
    49  func ip32_0(a, b []float32) float32 {
    50  	var s float32
    51  	for i := range a {
    52  		s += a[i] * b[i]
    53  	}
    54  	return s
    55  }
    56  
    57  func ip32_1(a, b []float32) float32 {
    58  	var z S1
    59  	sum := z
    60  	var i int
    61  	stride := z.Len()
    62  	for ; i <= len(a)-stride; i += stride {
    63  		va := load(a[i:])
    64  		vb := load(b[i:])
    65  		sum = sum.Add(va.Mul(vb))
    66  	}
    67  	var tmp [4]float32
    68  	sum.StoreArray(&tmp)
    69  	return tmp[0] + tmp[1] + tmp[2] + tmp[3]
    70  }
    71  
    72  func ip32_1a(a, b []float32) float32 {
    73  	var z S1
    74  	sum := z
    75  	var i int
    76  	stride := z.Len()
    77  	for ; i <= len(a)-stride; i += stride {
    78  		va := load(a[i:])
    79  		vb := load(b[i:])
    80  		sum = FMA(sum, va, vb)
    81  	}
    82  	var tmp [4]float32
    83  	sum.StoreArray(&tmp)
    84  	return tmp[0] + tmp[1] + tmp[2] + tmp[3]
    85  }
    86  
    87  //go:noinline
    88  func FMA(a, b, c archsimd.Float32x4) archsimd.Float32x4 {
    89  	return a.Add(b.Mul(c))
    90  }
    91  
    92  func ip32_2(a, b []float32) float32 {
    93  	var z S2
    94  	sum := z
    95  	var i int
    96  	stride := z.Len()
    97  	for ; i <= len(a)-stride; i += stride {
    98  		va := z.Load(a[i:])
    99  		vb := z.Load(b[i:])
   100  		sum = sum.Add(va.Mul(vb))
   101  	}
   102  	var tmp [4]float32
   103  	sum.Store(&tmp)
   104  	return tmp[0] + tmp[1] + tmp[2] + tmp[3]
   105  }
   106  
   107  func ip32_3(a, b []float32) float32 {
   108  	var z S3
   109  	sum := z
   110  	var i int
   111  	stride := z.Len()
   112  	for ; i <= len(a)-stride; i += stride {
   113  		va := load(a[i:])
   114  		vb := load(b[i:])
   115  		sum = S3{sum.Add(va.Mul(vb))}
   116  	}
   117  	var tmp [4]float32
   118  	sum.StoreArray(&tmp)
   119  	return tmp[0] + tmp[1] + tmp[2] + tmp[3]
   120  }
   121  
   122  func main() {
   123  	a := []float32{1, 2, 3, 4, 5, 6, 7, 8}
   124  	ip0 := ip32_0(a, a)
   125  	ip1 := ip32_1(a, a)
   126  	ip1a := ip32_1a(a, a)
   127  	ip2 := ip32_2(a, a)
   128  	ip3 := ip32_3(a, a)
   129  	fmt.Printf("Test IP    = %f\n", ip0)
   130  	fmt.Printf("SIMD IP 1  = %f\n", ip1)
   131  	fmt.Printf("SIMD IP 1a = %f\n", ip1a)
   132  	fmt.Printf("SIMD IP 2  = %f\n", ip2)
   133  	fmt.Printf("SIMD IP 3 = %f\n", ip3)
   134  	var z1 S1
   135  	var z2 S2
   136  	var z3 S2
   137  
   138  	s1, s2, s3 := unsafe.Sizeof(z1), unsafe.Sizeof(z2), unsafe.Sizeof(z3)
   139  
   140  	fmt.Printf("unsafe.Sizeof(z1, z2, z3)=%d, %d, %d\n", s1, s2, s3)
   141  
   142  	fail := false
   143  
   144  	if s1 != 16 || s2 != 16 || s3 != 16 {
   145  		fmt.Println("Failed a sizeof check, should all be 16")
   146  		fail = true
   147  	}
   148  
   149  	if ip1 != ip0 || ip1a != ip0 || ip2 != ip0 || ip3 != ip0 {
   150  		fmt.Println("Failed an inner product check, should all be", ip0)
   151  		fail = true
   152  	}
   153  
   154  	if fail {
   155  		os.Exit(1)
   156  	}
   157  }
   158  

View as plain text