Source file src/simd/archsimd/internal/simd_test/arm64_compare_test.go

     1  // Copyright 2026 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build goexperiment.simd && arm64
     6  
     7  package simd_test
     8  
     9  import (
    10  	"simd/archsimd"
    11  	"testing"
    12  )
    13  
    14  // TestEqual hands each vector type's Equal method (hardware: CMEQ / FCMEQ) and the generic
    15  // equalSlice to the matching test*Compare helper; presumably equalSlice is the scalar reference.
    16  func TestEqual(t *testing.T) {
    17  	testInt8x16Compare(t, archsimd.Int8x16.Equal, equalSlice[int8])
    18  	testInt16x8Compare(t, archsimd.Int16x8.Equal, equalSlice[int16])
    19  	testInt32x4Compare(t, archsimd.Int32x4.Equal, equalSlice[int32])
    20  	testInt64x2Compare(t, archsimd.Int64x2.Equal, equalSlice[int64])
    21  
    22  	testUint8x16Compare(t, archsimd.Uint8x16.Equal, equalSlice[uint8])
    23  	testUint16x8Compare(t, archsimd.Uint16x8.Equal, equalSlice[uint16])
    24  	testUint32x4Compare(t, archsimd.Uint32x4.Equal, equalSlice[uint32])
    25  	testUint64x2Compare(t, archsimd.Uint64x2.Equal, equalSlice[uint64])
    26  
    27  	testFloat32x4Compare(t, archsimd.Float32x4.Equal, equalSlice[float32])
    28  	testFloat64x2Compare(t, archsimd.Float64x2.Equal, equalSlice[float64])
    29  }
    30  
    31  // TestGreater hands each vector type's Greater method (hardware: CMGT/CMHI for signed/unsigned,
    32  // FCMGT for float) and greaterSlice to the matching test*Compare helper (defined elsewhere).
    33  func TestGreater(t *testing.T) {
    34  	testInt8x16Compare(t, archsimd.Int8x16.Greater, greaterSlice[int8])
    35  	testInt16x8Compare(t, archsimd.Int16x8.Greater, greaterSlice[int16])
    36  	testInt32x4Compare(t, archsimd.Int32x4.Greater, greaterSlice[int32])
    37  	testInt64x2Compare(t, archsimd.Int64x2.Greater, greaterSlice[int64])
    38  
    39  	testUint8x16Compare(t, archsimd.Uint8x16.Greater, greaterSlice[uint8])
    40  	testUint16x8Compare(t, archsimd.Uint16x8.Greater, greaterSlice[uint16])
    41  	testUint32x4Compare(t, archsimd.Uint32x4.Greater, greaterSlice[uint32])
    42  	testUint64x2Compare(t, archsimd.Uint64x2.Greater, greaterSlice[uint64])
    43  
    44  	testFloat32x4Compare(t, archsimd.Float32x4.Greater, greaterSlice[float32])
    45  	testFloat64x2Compare(t, archsimd.Float64x2.Greater, greaterSlice[float64])
    46  }
    47  
    48  // TestGreaterEqual hands each vector type's GreaterEqual method (hardware: CMGE/CMHS for
    49  // signed/unsigned, FCMGE for float) and greaterEqualSlice to the matching test*Compare helper.
    50  func TestGreaterEqual(t *testing.T) {
    51  	testInt8x16Compare(t, archsimd.Int8x16.GreaterEqual, greaterEqualSlice[int8])
    52  	testInt16x8Compare(t, archsimd.Int16x8.GreaterEqual, greaterEqualSlice[int16])
    53  	testInt32x4Compare(t, archsimd.Int32x4.GreaterEqual, greaterEqualSlice[int32])
    54  	testInt64x2Compare(t, archsimd.Int64x2.GreaterEqual, greaterEqualSlice[int64])
    55  
    56  	testUint8x16Compare(t, archsimd.Uint8x16.GreaterEqual, greaterEqualSlice[uint8])
    57  	testUint16x8Compare(t, archsimd.Uint16x8.GreaterEqual, greaterEqualSlice[uint16])
    58  	testUint32x4Compare(t, archsimd.Uint32x4.GreaterEqual, greaterEqualSlice[uint32])
    59  	testUint64x2Compare(t, archsimd.Uint64x2.GreaterEqual, greaterEqualSlice[uint64])
    60  
    61  	testFloat32x4Compare(t, archsimd.Float32x4.GreaterEqual, greaterEqualSlice[float32])
    62  	testFloat64x2Compare(t, archsimd.Float64x2.GreaterEqual, greaterEqualSlice[float64])
    63  }
    64  
    65  // TestLess hands each vector type's Less method (derived: y.Greater(x)) and lessSlice to the
    66  // matching test*Compare helper (defined elsewhere); presumably lessSlice is the scalar reference.
    67  func TestLess(t *testing.T) {
    68  	testInt8x16Compare(t, archsimd.Int8x16.Less, lessSlice[int8])
    69  	testInt16x8Compare(t, archsimd.Int16x8.Less, lessSlice[int16])
    70  	testInt32x4Compare(t, archsimd.Int32x4.Less, lessSlice[int32])
    71  	testInt64x2Compare(t, archsimd.Int64x2.Less, lessSlice[int64])
    72  
    73  	testUint8x16Compare(t, archsimd.Uint8x16.Less, lessSlice[uint8])
    74  	testUint16x8Compare(t, archsimd.Uint16x8.Less, lessSlice[uint16])
    75  	testUint32x4Compare(t, archsimd.Uint32x4.Less, lessSlice[uint32])
    76  	testUint64x2Compare(t, archsimd.Uint64x2.Less, lessSlice[uint64])
    77  
    78  	testFloat32x4Compare(t, archsimd.Float32x4.Less, lessSlice[float32])
    79  	testFloat64x2Compare(t, archsimd.Float64x2.Less, lessSlice[float64])
    80  }
    81  
    82  // TestLessEqual hands each vector type's LessEqual method (derived: y.GreaterEqual(x)) and
    83  // lessEqualSlice to the matching test*Compare helper; the helpers are defined elsewhere.
    84  func TestLessEqual(t *testing.T) {
    85  	testInt8x16Compare(t, archsimd.Int8x16.LessEqual, lessEqualSlice[int8])
    86  	testInt16x8Compare(t, archsimd.Int16x8.LessEqual, lessEqualSlice[int16])
    87  	testInt32x4Compare(t, archsimd.Int32x4.LessEqual, lessEqualSlice[int32])
    88  	testInt64x2Compare(t, archsimd.Int64x2.LessEqual, lessEqualSlice[int64])
    89  
    90  	testUint8x16Compare(t, archsimd.Uint8x16.LessEqual, lessEqualSlice[uint8])
    91  	testUint16x8Compare(t, archsimd.Uint16x8.LessEqual, lessEqualSlice[uint16])
    92  	testUint32x4Compare(t, archsimd.Uint32x4.LessEqual, lessEqualSlice[uint32])
    93  	testUint64x2Compare(t, archsimd.Uint64x2.LessEqual, lessEqualSlice[uint64])
    94  
    95  	testFloat32x4Compare(t, archsimd.Float32x4.LessEqual, lessEqualSlice[float32])
    96  	testFloat64x2Compare(t, archsimd.Float64x2.LessEqual, lessEqualSlice[float64])
    97  }
    98  
    99  // TestNotEqual hands each vector type's NotEqual method (derived: Equal().Not()) and
   100  // notEqualSlice to the matching test*Compare helper; the helpers are defined elsewhere.
   101  func TestNotEqual(t *testing.T) {
   102  	testInt8x16Compare(t, archsimd.Int8x16.NotEqual, notEqualSlice[int8])
   103  	testInt16x8Compare(t, archsimd.Int16x8.NotEqual, notEqualSlice[int16])
   104  	testInt32x4Compare(t, archsimd.Int32x4.NotEqual, notEqualSlice[int32])
   105  	testInt64x2Compare(t, archsimd.Int64x2.NotEqual, notEqualSlice[int64])
   106  
   107  	testUint8x16Compare(t, archsimd.Uint8x16.NotEqual, notEqualSlice[uint8])
   108  	testUint16x8Compare(t, archsimd.Uint16x8.NotEqual, notEqualSlice[uint16])
   109  	testUint32x4Compare(t, archsimd.Uint32x4.NotEqual, notEqualSlice[uint32])
   110  	testUint64x2Compare(t, archsimd.Uint64x2.NotEqual, notEqualSlice[uint64])
   111  
   112  	testFloat32x4Compare(t, archsimd.Float32x4.NotEqual, notEqualSlice[float32])
   113  	testFloat64x2Compare(t, archsimd.Float64x2.NotEqual, notEqualSlice[float64])
   114  }
   115  
   116  // --- Masked: zero elements where mask is false ---
   117  
   118  func TestMasked(t *testing.T) {
   119  	// Test Masked for Int8x16
   120  	forSlicePair(t, int8s, 16, func(x, y []int8) bool {
   121  		t.Helper()
   122  		a := archsimd.LoadInt8x16(x)
   123  		mask := archsimd.LoadInt8x16(y).Greater(archsimd.Int8x16{}) // mask: y > 0
   124  		g := make([]int8, 16)
   125  		a.Masked(mask).Store(g)
   126  		w := make([]int8, 16)
   127  		for i := range w {
   128  			if y[i] > 0 {
   129  				w[i] = x[i]
   130  			}
   131  		}
   132  		return checkSlicesLogInput(t, g, w, 0.0, func() { t.Helper(); t.Logf("x=%v, y=%v", x, y) })
   133  	})
   134  
   135  	// Test Masked for Float64x2
   136  	forSlicePair(t, float64s, 2, func(x, y []float64) bool {
   137  		t.Helper()
   138  		a := archsimd.LoadFloat64x2(x)
   139  		mask := archsimd.LoadFloat64x2(y).Greater(archsimd.Float64x2{}) // mask: y > 0
   140  		g := make([]float64, 2)
   141  		a.Masked(mask).Store(g)
   142  		w := make([]float64, 2)
   143  		for i := range w {
   144  			if y[i] > 0 {
   145  				w[i] = x[i]
   146  			}
   147  		}
   148  		return checkSlicesLogInput(t, g, w, 0.0, func() { t.Helper(); t.Logf("x=%v, y=%v", x, y) })
   149  	})
   150  }
   151  
   152  // --- IfElse: set elements to y where mask is true, keep x where true ---
   153  
   154  func TestIfElse(t *testing.T) {
   155  	// Test Merge for Int8x16
   156  	forSliceTriple(t, int8s, 16, func(x, y, m []int8) bool {
   157  		t.Helper()
   158  		a := archsimd.LoadInt8x16(x)
   159  		b := archsimd.LoadInt8x16(y)
   160  		mask := archsimd.LoadInt8x16(m).Greater(archsimd.Int8x16{}) // mask: m > 0
   161  		g := make([]int8, 16)
   162  		a.IfElse(mask, b).Store(g)
   163  		w := make([]int8, 16)
   164  		for i := range w {
   165  			if m[i] > 0 {
   166  				w[i] = y[i]
   167  			} else {
   168  				w[i] = x[i]
   169  			}
   170  		}
   171  		return checkSlicesLogInput(t, g, w, 0.0, func() { t.Helper(); t.Logf("x=%v, y=%v, m=%v", x, y, m) })
   172  	})
   173  
   174  	// Test Merge for Float32x4
   175  	forSliceTriple(t, float32s, 4, func(x, y, m []float32) bool {
   176  		t.Helper()
   177  		a := archsimd.LoadFloat32x4(x)
   178  		b := archsimd.LoadFloat32x4(y)
   179  		mask := archsimd.LoadFloat32x4(m).Greater(archsimd.Float32x4{}) // mask: m > 0
   180  		g := make([]float32, 4)
   181  		a.IfElse(mask, b).Store(g)
   182  		w := make([]float32, 4)
   183  		for i := range w {
   184  			if m[i] > 0 {
   185  				w[i] = y[i]
   186  			} else {
   187  				w[i] = x[i]
   188  			}
   189  		}
   190  		return checkSlicesLogInput(t, g, w, 0.0, func() { t.Helper(); t.Logf("x=%v, y=%v, m=%v", x, y, m) })
   191  	})
   192  }
   193  

View as plain text