Source file src/simd/archsimd/internal/simd_test/shift_test.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build goexperiment.simd && amd64
     6  
     7  package simd_test
     8  
     9  import (
    10  	"simd/archsimd"
    11  	"testing"
    12  )
    13  
    14  func TestRotateAllLeft(t *testing.T) {
    15  	x := uint8(0x81)
    16  	if y := rotl(x, 1); y != 3 {
    17  		t.Errorf("Expected 3, got 0x%x", y)
    18  	}
    19  	if y := rotl(x, 7); y != 0xc0 {
    20  		t.Errorf("Expected 0xc0, got 0x%x", y)
    21  	}
    22  	if y := rotr(x, 4); y != 0x18 {
    23  		t.Errorf("Expected 0x18, got 0x%x", y)
    24  	}
    25  
    26  	for i := uint64(0); i < 65; i++ {
    27  		testUint64x2Unary(t, curry2(archsimd.Uint64x2.RotateAllLeft, i), rotlOfSlice[uint64](i))
    28  		testUint32x4Unary(t, curry2(archsimd.Uint32x4.RotateAllLeft, i), rotlOfSlice[uint32](i))
    29  		//		testUint16x8Unary(t, curry2(archsimd.Uint16x8.RotateAllLeft, i), rotlOfSlice[uint16](i))
    30  		//		testUint8x16Unary(t, curry2(archsimd.Uint8x16.RotateAllLeft, i), rotlOfSlice[uint8](i))
    31  	}
    32  	for i := uint64(0); i < 65; i++ {
    33  		testUint64x4Unary(t, curry2(archsimd.Uint64x4.RotateAllLeft, i), rotlOfSlice[uint64](i))
    34  		testUint32x8Unary(t, curry2(archsimd.Uint32x8.RotateAllLeft, i), rotlOfSlice[uint32](i))
    35  		//		testUint16x16Unary(t, curry2(archsimd.Uint16x16.RotateAllLeft, i), rotlOfSlice[uint16](i))
    36  		//		testUint8x32Unary(t, curry2(archsimd.Uint8x32.RotateAllLeft, i), rotlOfSlice[uint8](i))
    37  	}
    38  
    39  }
    40  
    41  func TestRotateAllRight(t *testing.T) {
    42  	x := uint8(0x81)
    43  	if y := rotr(x, 1); y != 0xc0 {
    44  		t.Errorf("Expected 0xc0, got 0x%x", y)
    45  	}
    46  	if y := rotr(x, 7); y != 3 {
    47  		t.Errorf("Expected 3, got 0x%x", y)
    48  	}
    49  	if y := rotr(x, 4); y != 0x18 {
    50  		t.Errorf("Expected 0x18, got 0x%x", y)
    51  	}
    52  
    53  	for i := uint64(0); i < 65; i++ {
    54  		testUint64x2Unary(t, curry2(archsimd.Uint64x2.RotateAllRight, i), rotrOfSlice[uint64](i))
    55  		testUint32x4Unary(t, curry2(archsimd.Uint32x4.RotateAllRight, i), rotrOfSlice[uint32](i))
    56  		//		testUint16x8Unary(t, curry2(archsimd.Uint16x8.RotateAllLeft, i), rotlOfSlice[uint16](i))
    57  		//		testUint8x16Unary(t, curry2(archsimd.Uint8x16.RotateAllLeft, i), rotlOfSlice[uint8](i))
    58  	}
    59  	for i := uint64(0); i < 65; i++ {
    60  		testUint64x4Unary(t, curry2(archsimd.Uint64x4.RotateAllRight, i), rotrOfSlice[uint64](i))
    61  		testUint32x8Unary(t, curry2(archsimd.Uint32x8.RotateAllRight, i), rotrOfSlice[uint32](i))
    62  		//		testUint16x16Unary(t, curry2(archsimd.Uint16x16.RotateAllLeft, i), rotlOfSlice[uint16](i))
    63  		//		testUint8x32Unary(t, curry2(archsimd.Uint8x32.RotateAllLeft, i), rotlOfSlice[uint8](i))
    64  	}
    65  }
    66  
    67  func TestShift(t *testing.T) {
    68  	if !archsimd.X86.AVX2() {
    69  		t.Skip("requires AVX2")
    70  	}
    71  
    72  	testInt32x4Binary(t,
    73  		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftLeft(y.AsUint32x4()) },
    74  		map2(func(x, y int32) int32 { return x << uint32(y) }))
    75  	testInt32x4Binary(t,
    76  		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftRight(y.AsUint32x4()) },
    77  		map2(func(x, y int32) int32 { return x >> uint32(y) }))
    78  	testUint32x4Binary(t,
    79  		func(x, y archsimd.Uint32x4) archsimd.Uint32x4 { return x.ShiftRight(y) },
    80  		map2(func(x, y uint32) uint32 { return x >> y }))
    81  }
    82  
    83  func concatInt32s(x, y int32) int64 {
    84  	return (int64(x) << 32) | int64(uint32(y))
    85  }
    86  
    87  func concatUint32s(x, y uint32) uint64 {
    88  	return (uint64(x) << 32) | uint64(y)
    89  }
    90  
    91  func TestShiftAllConcat(t *testing.T) {
    92  	if !archsimd.X86.AVX512VBMI2() {
    93  		t.Skip("requires AVX512-VBMI2")
    94  	}
    95  
    96  	// Note that unlike their non-Concat counterparts, these wrap the shift count.
    97  
    98  	hide := hideConst[uint64]
    99  
   100  	// ShiftAllLeftConcat
   101  	salc := func(shift uint64) func(x, y int32) int32 {
   102  		return func(x, y int32) int32 {
   103  			return int32(concatInt32s(x, y) >> (32 - shift%32))
   104  		}
   105  	}
   106  
   107  	testInt32x4Binary(t,
   108  		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllLeftConcatMod32(y, 2) },
   109  		map2(salc(2)))
   110  	testInt32x4Binary(t,
   111  		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllLeftConcatMod32(y, hide(2)) },
   112  		map2(salc(hide(2))))
   113  
   114  	testInt32x4Binary(t,
   115  		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllLeftConcatMod32(y, 128) },
   116  		map2(salc(128)))
   117  	testInt32x4Binary(t,
   118  		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllLeftConcatMod32(y, hide(128)) },
   119  		map2(salc(hide(128))))
   120  
   121  	// Signed ShiftAllRightConcat
   122  	sarc := func(shift uint64) func(x, y int32) int32 {
   123  		return func(x, y int32) int32 {
   124  			return int32(concatInt32s(y, x) >> (shift % 32))
   125  		}
   126  	}
   127  
   128  	testInt32x4Binary(t,
   129  		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllRightConcatMod32(y, 2) },
   130  		map2(sarc(2)))
   131  	testInt32x4Binary(t,
   132  		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllRightConcatMod32(y, hide(2)) },
   133  		map2(sarc(hide(2))))
   134  
   135  	testInt32x4Binary(t,
   136  		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllRightConcatMod32(y, 128) },
   137  		map2(sarc(128)))
   138  	testInt32x4Binary(t,
   139  		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllRightConcatMod32(y, hide(128)) },
   140  		map2(sarc(hide(128))))
   141  
   142  	// Unsigned ShiftAllRightConcat
   143  	usarc := func(shift uint64) func(x, y uint32) uint32 {
   144  		return func(x, y uint32) uint32 {
   145  			return uint32(concatUint32s(y, x) >> (shift % 32))
   146  		}
   147  	}
   148  
   149  	testUint32x4Binary(t,
   150  		func(x, y archsimd.Uint32x4) archsimd.Uint32x4 { return x.ShiftAllRightConcatMod32(y, 2) },
   151  		map2(usarc(2)))
   152  	testUint32x4Binary(t,
   153  		func(x, y archsimd.Uint32x4) archsimd.Uint32x4 { return x.ShiftAllRightConcatMod32(y, hide(2)) },
   154  		map2(usarc(hide(2))))
   155  
   156  	testUint32x4Binary(t,
   157  		func(x, y archsimd.Uint32x4) archsimd.Uint32x4 { return x.ShiftAllRightConcatMod32(y, 128) },
   158  		map2(usarc(128)))
   159  	testUint32x4Binary(t,
   160  		func(x, y archsimd.Uint32x4) archsimd.Uint32x4 { return x.ShiftAllRightConcatMod32(y, hide(128)) },
   161  		map2(usarc(hide(128))))
   162  }
   163  
   164  func TestShiftConcat(t *testing.T) {
   165  	if !archsimd.X86.AVX512VBMI2() {
   166  		t.Skip("requires AVX512-VBMI2")
   167  	}
   168  
   169  	// Note that unlike their non-Concat counterparts, these wrap the shift count.
   170  
   171  	testInt32x4Ternary(t,
   172  		func(x, y, z archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftLeftConcatMod32(y, z.AsUint32x4()) },
   173  		map3(func(x, y, z int32) int32 {
   174  			return int32(concatInt32s(x, y) >> (32 - uint32(z)%32))
   175  		}))
   176  
   177  	testInt32x4Ternary(t,
   178  		func(x, y, z archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftRightConcatMod32(y, z.AsUint32x4()) },
   179  		map3(func(x, y, z int32) int32 {
   180  			return int32(concatInt32s(y, x) >> (uint32(z) % 32))
   181  		}))
   182  
   183  	testUint32x4Ternary(t,
   184  		func(x, y, z archsimd.Uint32x4) archsimd.Uint32x4 { return x.ShiftRightConcatMod32(y, z) },
   185  		map3(func(x, y, z uint32) uint32 {
   186  			return uint32(concatUint32s(y, x) >> (z % 32))
   187  		}))
   188  }
   189  
   190  func TestConcatShiftBytesRight(t *testing.T) {
   191  	hide := hideConst[uint64]
   192  
   193  	csbr := func(shift uint64) func(x, y []uint8) []uint8 {
   194  		return func(x, y []uint8) []uint8 {
   195  			z := make([]uint8, len(x))
   196  			for i := range z {
   197  				target := i + int(shift)
   198  				if target < 16 {
   199  					z[i] = y[target]
   200  				} else if target < 32 {
   201  					z[i] = x[(target - 16)]
   202  				}
   203  			}
   204  			return z
   205  		}
   206  	}
   207  
   208  	t.Run("Uint8x16", func(t *testing.T) {
   209  		if !archsimd.X86.AVX() {
   210  			t.Skip("requires AVX")
   211  		}
   212  		for _, shift := range []uint64{0, 2, 16, 20, 32, 128} {
   213  			t.Log("shift", shift)
   214  			testUint8x16Binary(t,
   215  				func(x, y archsimd.Uint8x16) archsimd.Uint8x16 { return x.ConcatShiftBytesRight(y, shift) },
   216  				csbr(shift))
   217  			testUint8x16Binary(t,
   218  				func(x, y archsimd.Uint8x16) archsimd.Uint8x16 { return x.ConcatShiftBytesRight(y, hide(shift)) },
   219  				csbr(hide(shift)))
   220  		}
   221  	})
   222  
   223  	t.Run("Uint8x32", func(t *testing.T) {
   224  		if !archsimd.X86.AVX2() {
   225  			t.Skip("requires AVX2")
   226  		}
   227  		for _, shift := range []uint64{0, 2, 16, 20, 32, 128} {
   228  			t.Log("shift", shift)
   229  			testUint8x32Binary(t,
   230  				func(x, y archsimd.Uint8x32) archsimd.Uint8x32 { return x.ConcatShiftBytesRightGrouped(y, shift) },
   231  				grouped2(csbr(shift)))
   232  			testUint8x32Binary(t,
   233  				func(x, y archsimd.Uint8x32) archsimd.Uint8x32 { return x.ConcatShiftBytesRightGrouped(y, hide(shift)) },
   234  				grouped2(csbr(hide(shift))))
   235  		}
   236  	})
   237  
   238  	t.Run("Uint8x64", func(t *testing.T) {
   239  		if !archsimd.X86.AVX512() {
   240  			t.Skip("requires AVX512")
   241  		}
   242  		for _, shift := range []uint64{0, 2, 16, 20, 32, 128} {
   243  			t.Log("shift", shift)
   244  			testUint8x64Binary(t,
   245  				func(x, y archsimd.Uint8x64) archsimd.Uint8x64 { return x.ConcatShiftBytesRightGrouped(y, shift) },
   246  				grouped2(csbr(shift)))
   247  			testUint8x64Binary(t,
   248  				func(x, y archsimd.Uint8x64) archsimd.Uint8x64 { return x.ConcatShiftBytesRightGrouped(y, hide(shift)) },
   249  				grouped2(csbr(hide(shift))))
   250  		}
   251  	})
   252  }
   253  

View as plain text