Source file src/simd/simd_test.go

     1  // Copyright 2026 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build goexperiment.simd
     6  
     7  package simd_test
     8  
     9  import (
    10  	"simd"
    11  	"testing"
    12  )
    13  
    14  type signed interface {
    15  	~int | ~int8 | ~int16 | ~int32 | ~int64
    16  }
    17  
    18  type unsigned interface {
    19  	~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr
    20  }
    21  
    22  type integer interface {
    23  	~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr
    24  }
    25  
    26  type float interface {
    27  	~float32 | ~float64
    28  }
    29  
    30  type number interface {
    31  	~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr | ~float32 | ~float64
    32  }
    33  
    34  func TestInt8s(t *testing.T) {
    35  	// 64 elements = 512 bits
    36  	in1 := []int8{
    37  		1, -2, 3, -4, 5, -6, 7, -8, 9, -10, 11, -12, 13, -14, 15, -16,
    38  		17, -18, 19, -20, 21, -22, 23, -24, 25, -26, 27, -28, 29, -30, 31, -32,
    39  		33, -34, 35, -36, 37, -38, 39, -40, 41, -42, 43, -44, 45, -46, 47, -48,
    40  		49, -50, 51, -52, 53, -54, 55, -56, 57, -58, 59, -60, 61, -62, 63, -64,
    41  	}
    42  	in2 := make([]int8, 64)
    43  	for i := range in2 {
    44  		in2[i] = 2
    45  	}
    46  
    47  	x := simd.LoadInt8s(in1)
    48  	y := simd.LoadInt8s(in2)
    49  
    50  	if x.Len() <= 0 {
    51  		t.Errorf("Int8s.Len() returned <= 0")
    52  	}
    53  
    54  	sum := x.Add(y)
    55  	diff := x.Sub(y)
    56  	neg := x.Neg()
    57  	abs := x.Abs()
    58  
    59  	buf := make([]int8, x.Len())
    60  	sum.Store(buf)
    61  	for i := 0; i < x.Len() && i < len(in1); i++ {
    62  		expected := in1[i] + in2[i]
    63  		if buf[i] != expected {
    64  			t.Errorf("Add at %d: got %d, want %d", i, buf[i], expected)
    65  		}
    66  	}
    67  
    68  	diff.Store(buf)
    69  	for i := 0; i < x.Len() && i < len(in1); i++ {
    70  		expected := in1[i] - in2[i]
    71  		if buf[i] != expected {
    72  			t.Errorf("Sub at %d: got %d, want %d", i, buf[i], expected)
    73  		}
    74  	}
    75  
    76  	neg.Store(buf)
    77  	for i := 0; i < x.Len() && i < len(in1); i++ {
    78  		expected := -in1[i]
    79  		if buf[i] != expected {
    80  			t.Errorf("Neg at %d: got %d, want %d", i, buf[i], expected)
    81  		}
    82  	}
    83  
    84  	abs.Store(buf)
    85  	for i := 0; i < x.Len() && i < len(in1); i++ {
    86  		expected := in1[i]
    87  		if expected < 0 {
    88  			expected = -expected
    89  		}
    90  		if buf[i] != expected {
    91  			t.Errorf("Abs at %d: got %d, want %d", i, buf[i], expected)
    92  		}
    93  	}
    94  }
    95  
    96  func TestInt16s(t *testing.T) {
    97  	// 32 elements = 512 bits
    98  	in1 := make([]int16, 32)
    99  	in2 := make([]int16, 32)
   100  	for i := range in1 {
   101  		in1[i] = int16((i + 1) * 100)
   102  		if i%2 != 0 {
   103  			in1[i] = -in1[i]
   104  		}
   105  		in2[i] = 10
   106  	}
   107  
   108  	x := simd.LoadInt16s(in1)
   109  	y := simd.LoadInt16s(in2)
   110  
   111  	sum := x.Add(y)
   112  	buf := make([]int16, x.Len())
   113  	sum.Store(buf)
   114  
   115  	for i := 0; i < x.Len() && i < len(in1); i++ {
   116  		expected := in1[i] + in2[i]
   117  		if buf[i] != expected {
   118  			t.Errorf("Int16s Add at %d: got %d, want %d", i, buf[i], expected)
   119  		}
   120  	}
   121  
   122  	// Test RotateAllLeft
   123  	rotLeft := x.RotateAllLeft(3)
   124  	rotLeft.Store(buf)
   125  	for i := 0; i < x.Len() && i < len(in1); i++ {
   126  		val := uint16(in1[i])
   127  		expected := int16((val << 3) | (val >> 13))
   128  		if buf[i] != expected {
   129  			t.Errorf("Int16s RotateAllLeft at %d: got %d, want %d", i, buf[i], expected)
   130  		}
   131  	}
   132  
   133  	// Test RotateAllRight with large distance
   134  	rotRight := x.RotateAllRight(19)
   135  	rotRight.Store(buf)
   136  	for i := 0; i < x.Len() && i < len(in1); i++ {
   137  		val := uint16(in1[i])
   138  		expected := int16((val >> 3) | (val << 13))
   139  		if buf[i] != expected {
   140  			t.Errorf("Int16s RotateAllRight(19) at %d: got %d, want %d", i, buf[i], expected)
   141  		}
   142  	}
   143  }
   144  
   145  func TestInt32s(t *testing.T) {
   146  	// 16 elements = 512 bits
   147  	in1 := make([]int32, 16)
   148  	in2 := make([]int32, 16)
   149  	for i := range in1 {
   150  		in1[i] = int32((i + 1) * 1000)
   151  		if i%2 != 0 {
   152  			in1[i] = -in1[i]
   153  		}
   154  		in2[i] = 100
   155  	}
   156  
   157  	x := simd.LoadInt32s(in1)
   158  	y := simd.LoadInt32s(in2)
   159  
   160  	sum := x.Add(y)
   161  	buf := make([]int32, x.Len())
   162  	sum.Store(buf)
   163  
   164  	for i := 0; i < x.Len() && i < len(in1); i++ {
   165  		expected := in1[i] + in2[i]
   166  		if buf[i] != expected {
   167  			t.Errorf("Int32s Add at %d: got %d, want %d", i, buf[i], expected)
   168  		}
   169  	}
   170  
   171  	// Test RotateAllLeft
   172  	rotLeft := x.RotateAllLeft(5)
   173  	rotLeft.Store(buf)
   174  	for i := 0; i < x.Len() && i < len(in1); i++ {
   175  		val := uint32(in1[i])
   176  		expected := int32((val << 5) | (val >> 27))
   177  		if buf[i] != expected {
   178  			t.Errorf("Int32s RotateAllLeft at %d: got %d, want %d", i, buf[i], expected)
   179  		}
   180  	}
   181  
   182  	// Test RotateAllRight with large distance
   183  	rotRight := x.RotateAllRight(37)
   184  	rotRight.Store(buf)
   185  	for i := 0; i < x.Len() && i < len(in1); i++ {
   186  		val := uint32(in1[i])
   187  		expected := int32((val >> 5) | (val << 27))
   188  		if buf[i] != expected {
   189  			t.Errorf("Int32s RotateAllRight(37) at %d: got %d, want %d", i, buf[i], expected)
   190  		}
   191  	}
   192  }
   193  
   194  func TestInt64s(t *testing.T) {
   195  	// 8 elements = 512 bits
   196  	in1 := make([]int64, 8)
   197  	in2 := make([]int64, 8)
   198  	for i := range in1 {
   199  		in1[i] = int64((i + 1) * 10000)
   200  		if i%2 != 0 {
   201  			in1[i] = -in1[i]
   202  		}
   203  		in2[i] = 1000
   204  	}
   205  
   206  	x := simd.LoadInt64s(in1)
   207  	y := simd.LoadInt64s(in2)
   208  
   209  	sum := x.Add(y)
   210  	buf := make([]int64, x.Len())
   211  	sum.Store(buf)
   212  
   213  	for i := 0; i < x.Len() && i < len(in1); i++ {
   214  		expected := in1[i] + in2[i]
   215  		if buf[i] != expected {
   216  			t.Errorf("Int64s Add at %d: got %d, want %d", i, buf[i], expected)
   217  		}
   218  	}
   219  
   220  	// Test RotateAllLeft
   221  	rotLeft := x.RotateAllLeft(7)
   222  	rotLeft.Store(buf)
   223  	for i := 0; i < x.Len() && i < len(in1); i++ {
   224  		val := uint64(in1[i])
   225  		expected := int64((val << 7) | (val >> 57))
   226  		if buf[i] != expected {
   227  			t.Errorf("Int64s RotateAllLeft at %d: got %d, want %d", i, buf[i], expected)
   228  		}
   229  	}
   230  
   231  	// Test RotateAllRight with large distance
   232  	rotRight := x.RotateAllRight(71)
   233  	rotRight.Store(buf)
   234  	for i := 0; i < x.Len() && i < len(in1); i++ {
   235  		val := uint64(in1[i])
   236  		expected := int64((val >> 7) | (val << 57))
   237  		if buf[i] != expected {
   238  			t.Errorf("Int64s RotateAllRight(71) at %d: got %d, want %d", i, buf[i], expected)
   239  		}
   240  	}
   241  }
   242  
   243  func TestUint8s(t *testing.T) {
   244  	// 64 elements = 512 bits
   245  	in1 := make([]uint8, 64)
   246  	in2 := make([]uint8, 64)
   247  	for i := range in1 {
   248  		in1[i] = uint8(i + 1)
   249  		in2[i] = 10
   250  	}
   251  
   252  	x := simd.LoadUint8s(in1)
   253  	y := simd.LoadUint8s(in2)
   254  
   255  	avg := x.Average(y)
   256  	buf := make([]uint8, x.Len())
   257  	avg.Store(buf)
   258  
   259  	for i := 0; i < x.Len() && i < len(in1); i++ {
   260  		expected := uint8((int(in1[i]) + int(in2[i]) + 1) >> 1)
   261  		if buf[i] != expected {
   262  			t.Errorf("Uint8s Average at %d: got %d, want %d", i, buf[i], expected)
   263  		}
   264  	}
   265  }
   266  
   267  func TestFloat32s(t *testing.T) {
   268  	// 16 elements = 512 bits
   269  	in1 := make([]float32, 16)
   270  	in2 := make([]float32, 16)
   271  	for i := range in1 {
   272  		val := float32(i) + 1.5
   273  		if i%2 != 0 {
   274  			val = -val
   275  		}
   276  		in1[i] = val
   277  		in2[i] = 0.5
   278  	}
   279  
   280  	x := simd.LoadFloat32s(in1)
   281  	y := simd.LoadFloat32s(in2)
   282  
   283  	sum := x.Add(y)
   284  	buf := make([]float32, x.Len())
   285  	sum.Store(buf)
   286  
   287  	for i := 0; i < x.Len() && i < len(in1); i++ {
   288  		expected := in1[i] + in2[i]
   289  		if buf[i] != expected {
   290  			t.Errorf("Float32s Add at %d: got %f, want %f", i, buf[i], expected)
   291  		}
   292  	}
   293  }
   294  
   295  func TestFloat64s(t *testing.T) {
   296  	// 8 elements = 512 bits
   297  	in1 := make([]float64, 8)
   298  	in2 := make([]float64, 8)
   299  	for i := range in1 {
   300  		val := float64(i)*10.0 + 10.25
   301  		if i%2 != 0 {
   302  			val = -val
   303  		}
   304  		in1[i] = val
   305  		in2[i] = 1.0
   306  	}
   307  
   308  	x := simd.LoadFloat64s(in1)
   309  	y := simd.LoadFloat64s(in2)
   310  
   311  	mul := x.Mul(y)
   312  	buf := make([]float64, x.Len())
   313  	mul.Store(buf)
   314  
   315  	for i := 0; i < x.Len() && i < len(in1); i++ {
   316  		expected := in1[i] * in2[i]
   317  		if buf[i] != expected {
   318  			t.Errorf("Float64s Mul at %d: got %f, want %f", i, buf[i], expected)
   319  		}
   320  	}
   321  }
   322  
   323  func TestUint16s(t *testing.T) {
   324  	in1 := make([]uint16, 32)
   325  	for i := range in1 {
   326  		in1[i] = uint16((i + 1) * 100)
   327  	}
   328  
   329  	x := simd.LoadUint16s(in1)
   330  	buf := make([]uint16, x.Len())
   331  
   332  	// Test RotateAllLeft
   333  	rotLeft := x.RotateAllLeft(3)
   334  	rotLeft.Store(buf)
   335  	for i := 0; i < x.Len() && i < len(in1); i++ {
   336  		val := in1[i]
   337  		expected := (val << 3) | (val >> 13)
   338  		if buf[i] != expected {
   339  			t.Errorf("Uint16s RotateAllLeft at %d: got %d, want %d", i, buf[i], expected)
   340  		}
   341  	}
   342  
   343  	// Test RotateAllRight with large distance
   344  	rotRight := x.RotateAllRight(19)
   345  	rotRight.Store(buf)
   346  	for i := 0; i < x.Len() && i < len(in1); i++ {
   347  		val := in1[i]
   348  		expected := (val >> 3) | (val << 13)
   349  		if buf[i] != expected {
   350  			t.Errorf("Uint16s RotateAllRight(19) at %d: got %d, want %d", i, buf[i], expected)
   351  		}
   352  	}
   353  }
   354  
   355  func TestUint32s(t *testing.T) {
   356  	in1 := make([]uint32, 16)
   357  	for i := range in1 {
   358  		in1[i] = uint32((i + 1) * 1000)
   359  	}
   360  
   361  	x := simd.LoadUint32s(in1)
   362  	buf := make([]uint32, x.Len())
   363  
   364  	// Test RotateAllLeft
   365  	rotLeft := x.RotateAllLeft(5)
   366  	rotLeft.Store(buf)
   367  	for i := 0; i < x.Len() && i < len(in1); i++ {
   368  		val := in1[i]
   369  		expected := (val << 5) | (val >> 27)
   370  		if buf[i] != expected {
   371  			t.Errorf("Uint32s RotateAllLeft at %d: got %d, want %d", i, buf[i], expected)
   372  		}
   373  	}
   374  
   375  	// Test RotateAllRight with large distance
   376  	rotRight := x.RotateAllRight(37)
   377  	rotRight.Store(buf)
   378  	for i := 0; i < x.Len() && i < len(in1); i++ {
   379  		val := in1[i]
   380  		expected := (val >> 5) | (val << 27)
   381  		if buf[i] != expected {
   382  			t.Errorf("Uint32s RotateAllRight(37) at %d: got %d, want %d", i, buf[i], expected)
   383  		}
   384  	}
   385  }
   386  
   387  func TestUint64s(t *testing.T) {
   388  	in1 := make([]uint64, 8)
   389  	for i := range in1 {
   390  		in1[i] = uint64((i + 1) * 10000)
   391  	}
   392  
   393  	x := simd.LoadUint64s(in1)
   394  	buf := make([]uint64, x.Len())
   395  
   396  	// Test RotateAllLeft
   397  	rotLeft := x.RotateAllLeft(7)
   398  	rotLeft.Store(buf)
   399  	for i := 0; i < x.Len() && i < len(in1); i++ {
   400  		val := in1[i]
   401  		expected := (val << 7) | (val >> 57)
   402  		if buf[i] != expected {
   403  			t.Errorf("Uint64s RotateAllLeft at %d: got %d, want %d", i, buf[i], expected)
   404  		}
   405  	}
   406  
   407  	// Test RotateAllRight with large distance
   408  	rotRight := x.RotateAllRight(71)
   409  	rotRight.Store(buf)
   410  	for i := 0; i < x.Len() && i < len(in1); i++ {
   411  		val := in1[i]
   412  		expected := (val >> 7) | (val << 57)
   413  		if buf[i] != expected {
   414  			t.Errorf("Uint64s RotateAllRight(71) at %d: got %d, want %d", i, buf[i], expected)
   415  		}
   416  	}
   417  }
   418  
   419  type HasStoreLen[E number] interface {
   420  	Store(s []E)
   421  	Len() int
   422  }
   423  
   424  func testBroadcast[E number, V HasStoreLen[E]](t *testing.T, x E, f func(e E) V) {
   425  	v := f(x)
   426  	s := make([]E, v.Len())
   427  	v.Store(s)
   428  	for _, e := range s {
   429  		if e != x {
   430  			t.Errorf("Expected %v, saw %v", x, e)
   431  		}
   432  	}
   433  }
   434  
   435  func TestBroadcast(t *testing.T) {
   436  	testBroadcast(t, int8(-2), simd.BroadcastInt8s)
   437  	testBroadcast(t, int16(-2), simd.BroadcastInt16s)
   438  	testBroadcast(t, int32(-2), simd.BroadcastInt32s)
   439  	testBroadcast(t, int64(-2), simd.BroadcastInt64s)
   440  
   441  	testBroadcast(t, uint8(99), simd.BroadcastUint8s)
   442  	testBroadcast(t, uint16(9999), simd.BroadcastUint16s)
   443  	testBroadcast(t, uint32(99991111), simd.BroadcastUint32s)
   444  	testBroadcast(t, uint64(112233445599887766), simd.BroadcastUint64s)
   445  
   446  	testBroadcast(t, float32(99991111), simd.BroadcastFloat32s)
   447  	testBroadcast(t, float64(112233445599887766), simd.BroadcastFloat64s)
   448  }
   449  

View as plain text