Source file src/simd/archsimd/internal/simd_test/arm64_shift_test.go

     1  // Copyright 2026 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build goexperiment.simd && arm64
     6  
     7  package simd_test
     8  
     9  import (
    10  	"simd/archsimd"
    11  	"testing"
    12  )
    13  
    14  func TestShift(t *testing.T) {
    15  	// Signed — reuse binary helpers, same-type operand pairs
    16  	testInt8x16Binary(t, archsimd.Int8x16.Shift, shiftSlice[int8])
    17  	testInt16x8Binary(t, archsimd.Int16x8.Shift, shiftSlice[int16])
    18  	testInt32x4Binary(t, archsimd.Int32x4.Shift, shiftSlice[int32])
    19  	testInt64x2Binary(t, archsimd.Int64x2.Shift, shiftSlice[int64])
    20  
    21  	// Unsigned — mixed-type operand pairs
    22  	testUint8x16Shift(t, archsimd.Uint8x16.Shift, shiftMixedSlice[uint8, int8])
    23  	testUint16x8Shift(t, archsimd.Uint16x8.Shift, shiftMixedSlice[uint16, int16])
    24  	testUint32x4Shift(t, archsimd.Uint32x4.Shift, shiftMixedSlice[uint32, int32])
    25  	testUint64x2Shift(t, archsimd.Uint64x2.Shift, shiftMixedSlice[uint64, int64])
    26  }
    27  
    28  func TestShiftSaturated(t *testing.T) {
    29  	// Signed — reuse binary helpers
    30  	testInt8x16Binary(t, archsimd.Int8x16.ShiftSaturated, shiftSaturatingSignedSlice[int8])
    31  	testInt16x8Binary(t, archsimd.Int16x8.ShiftSaturated, shiftSaturatingSignedSlice[int16])
    32  	testInt32x4Binary(t, archsimd.Int32x4.ShiftSaturated, shiftSaturatingSignedSlice[int32])
    33  	testInt64x2Binary(t, archsimd.Int64x2.ShiftSaturated, shiftSaturatingSignedSlice[int64])
    34  
    35  	// Unsigned — mixed-type
    36  	testUint8x16Shift(t, archsimd.Uint8x16.ShiftSaturated, shiftSaturatingUnsignedSlice[uint8, int8])
    37  	testUint16x8Shift(t, archsimd.Uint16x8.ShiftSaturated, shiftSaturatingUnsignedSlice[uint16, int16])
    38  	testUint32x4Shift(t, archsimd.Uint32x4.ShiftSaturated, shiftSaturatingUnsignedSlice[uint32, int32])
    39  	testUint64x2Shift(t, archsimd.Uint64x2.ShiftSaturated, shiftSaturatingUnsignedSlice[uint64, int64])
    40  }
    41  
    42  var testShiftConstAmt uint64 = 3
    43  
    44  func TestShiftLeftConst(t *testing.T) {
    45  	// Signed
    46  	testInt8x16ShiftConst(t, archsimd.Int8x16.ShiftAllLeft, shiftLeftByConstSlice[int8])
    47  	testInt16x8ShiftConst(t, archsimd.Int16x8.ShiftAllLeft, shiftLeftByConstSlice[int16])
    48  	testInt32x4ShiftConst(t, archsimd.Int32x4.ShiftAllLeft, shiftLeftByConstSlice[int32])
    49  	testInt64x2ShiftConst(t, archsimd.Int64x2.ShiftAllLeft, shiftLeftByConstSlice[int64])
    50  	// Unsigned
    51  	testUint8x16ShiftConst(t, archsimd.Uint8x16.ShiftAllLeft, shiftLeftByConstSlice[uint8])
    52  	testUint16x8ShiftConst(t, archsimd.Uint16x8.ShiftAllLeft, shiftLeftByConstSlice[uint16])
    53  	testUint32x4ShiftConst(t, archsimd.Uint32x4.ShiftAllLeft, shiftLeftByConstSlice[uint32])
    54  	testUint64x2ShiftConst(t, archsimd.Uint64x2.ShiftAllLeft, shiftLeftByConstSlice[uint64])
    55  
    56  	// Variable shift amount to prevent constant folding
    57  	forSlice(t, int32s, 4, func(x []int32) bool {
    58  		a := archsimd.LoadInt32x4(x)
    59  		g := make([]int32, 4)
    60  		a.ShiftAllLeft(testShiftConstAmt).Store(g)
    61  		w := shiftLeftByConstSlice(x, testShiftConstAmt)
    62  		return checkSlicesLogInput(t, g, w, 0.0, func() { t.Helper(); t.Logf("x=%v, amt=%d", x, testShiftConstAmt) })
    63  	})
    64  }
    65  
    66  func TestShiftRightConst(t *testing.T) {
    67  	// Signed (arithmetic right shift)
    68  	testInt8x16ShiftConst(t, archsimd.Int8x16.ShiftAllRight, shiftRightByConstSlice[int8])
    69  	testInt16x8ShiftConst(t, archsimd.Int16x8.ShiftAllRight, shiftRightByConstSlice[int16])
    70  	testInt32x4ShiftConst(t, archsimd.Int32x4.ShiftAllRight, shiftRightByConstSlice[int32])
    71  	testInt64x2ShiftConst(t, archsimd.Int64x2.ShiftAllRight, shiftRightByConstSlice[int64])
    72  	// Unsigned (logical right shift)
    73  	testUint8x16ShiftConst(t, archsimd.Uint8x16.ShiftAllRight, shiftRightByConstSlice[uint8])
    74  	testUint16x8ShiftConst(t, archsimd.Uint16x8.ShiftAllRight, shiftRightByConstSlice[uint16])
    75  	testUint32x4ShiftConst(t, archsimd.Uint32x4.ShiftAllRight, shiftRightByConstSlice[uint32])
    76  	testUint64x2ShiftConst(t, archsimd.Uint64x2.ShiftAllRight, shiftRightByConstSlice[uint64])
    77  
    78  	// Variable shift amount to prevent constant folding
    79  	forSlice(t, int32s, 4, func(x []int32) bool {
    80  		a := archsimd.LoadInt32x4(x)
    81  		g := make([]int32, 4)
    82  		a.ShiftAllRight(testShiftConstAmt).Store(g)
    83  		w := shiftRightByConstSlice(x, testShiftConstAmt)
    84  		return checkSlicesLogInput(t, g, w, 0.0, func() { t.Helper(); t.Logf("x=%v, amt=%d", x, testShiftConstAmt) })
    85  	})
    86  }
    87  
    88  // testShiftAllAmts contains shift amounts for ShiftAll tests, including
    89  // in-range amounts for all element sizes and out-of-range amounts to
    90  // verify CSEL/CMPconst clamping logic in the lowering rules.
    91  var testShiftAllAmts = []uint64{0, 1, 3, 7, 15, 31, 63, 128, 1024}
    92  
    93  // testShiftAllVarAmt is a non-constant shift amount to prevent constant folding.
    94  var testShiftAllVarAmt uint64 = 3
    95  
    96  func TestShiftAllLeft(t *testing.T) {
    97  	// Signed
    98  	testInt8x16ShiftAll(t, archsimd.Int8x16.ShiftAllLeft, shiftAllLeftSlice[int8])
    99  	testInt16x8ShiftAll(t, archsimd.Int16x8.ShiftAllLeft, shiftAllLeftSlice[int16])
   100  	testInt32x4ShiftAll(t, archsimd.Int32x4.ShiftAllLeft, shiftAllLeftSlice[int32])
   101  	testInt64x2ShiftAll(t, archsimd.Int64x2.ShiftAllLeft, shiftAllLeftSlice[int64])
   102  	// Unsigned
   103  	testUint8x16ShiftAll(t, archsimd.Uint8x16.ShiftAllLeft, shiftAllLeftSlice[uint8])
   104  	testUint16x8ShiftAll(t, archsimd.Uint16x8.ShiftAllLeft, shiftAllLeftSlice[uint16])
   105  	testUint32x4ShiftAll(t, archsimd.Uint32x4.ShiftAllLeft, shiftAllLeftSlice[uint32])
   106  	testUint64x2ShiftAll(t, archsimd.Uint64x2.ShiftAllLeft, shiftAllLeftSlice[uint64])
   107  
   108  	// Variable shift amount to prevent constant folding
   109  	forSlice(t, int32s, 4, func(x []int32) bool {
   110  		a := archsimd.LoadInt32x4(x)
   111  		g := make([]int32, 4)
   112  		a.ShiftAllLeft(testShiftAllVarAmt).Store(g)
   113  		w := shiftAllLeftSlice(x, testShiftAllVarAmt)
   114  		return checkSlicesLogInput(t, g, w, 0.0, func() { t.Helper(); t.Logf("x=%v, amt=%d", x, testShiftAllVarAmt) })
   115  	})
   116  }
   117  
   118  func TestShiftAllRight(t *testing.T) {
   119  	// Signed (arithmetic right shift)
   120  	testInt8x16ShiftAll(t, archsimd.Int8x16.ShiftAllRight, shiftAllRightSlice[int8])
   121  	testInt16x8ShiftAll(t, archsimd.Int16x8.ShiftAllRight, shiftAllRightSlice[int16])
   122  	testInt32x4ShiftAll(t, archsimd.Int32x4.ShiftAllRight, shiftAllRightSlice[int32])
   123  	testInt64x2ShiftAll(t, archsimd.Int64x2.ShiftAllRight, shiftAllRightSlice[int64])
   124  	// Unsigned (logical right shift)
   125  	testUint8x16ShiftAll(t, archsimd.Uint8x16.ShiftAllRight, shiftAllRightSlice[uint8])
   126  	testUint16x8ShiftAll(t, archsimd.Uint16x8.ShiftAllRight, shiftAllRightSlice[uint16])
   127  	testUint32x4ShiftAll(t, archsimd.Uint32x4.ShiftAllRight, shiftAllRightSlice[uint32])
   128  	testUint64x2ShiftAll(t, archsimd.Uint64x2.ShiftAllRight, shiftAllRightSlice[uint64])
   129  
   130  	// Variable shift amount to prevent constant folding
   131  	forSlice(t, int32s, 4, func(x []int32) bool {
   132  		a := archsimd.LoadInt32x4(x)
   133  		g := make([]int32, 4)
   134  		a.ShiftAllRight(testShiftAllVarAmt).Store(g)
   135  		w := shiftAllRightSlice(x, testShiftAllVarAmt)
   136  		return checkSlicesLogInput(t, g, w, 0.0, func() { t.Helper(); t.Logf("x=%v, amt=%d", x, testShiftAllVarAmt) })
   137  	})
   138  }
   139  
   140  func TestConcatShiftBytesRight(t *testing.T) {
   141  	hide := hideConst[uint64]
   142  
   143  	csbr := func(shift uint64) func(x, y []uint8) []uint8 {
   144  		return func(x, y []uint8) []uint8 {
   145  			z := make([]uint8, len(x))
   146  			for i := range z {
   147  				target := i + int(shift)
   148  				if target < 16 {
   149  					z[i] = y[target]
   150  				} else if target < 32 {
   151  					z[i] = x[(target - 16)]
   152  				}
   153  			}
   154  			return z
   155  		}
   156  	}
   157  
   158  	t.Run("Uint8x16", func(t *testing.T) {
   159  		for _, shift := range []uint64{0, 2, 8, 15} {
   160  			t.Log("shift", shift)
   161  			testUint8x16Binary(t,
   162  				func(x, y archsimd.Uint8x16) archsimd.Uint8x16 { return x.ConcatShiftBytesRight(y, shift) },
   163  				csbr(shift))
   164  			testUint8x16Binary(t,
   165  				func(x, y archsimd.Uint8x16) archsimd.Uint8x16 { return x.ConcatShiftBytesRight(y, hide(shift)) },
   166  				csbr(hide(shift)))
   167  		}
   168  	})
   169  }
   170  

View as plain text