Source file src/simd/clmul_test.go

     1  // Copyright 2026 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build goexperiment.simd
     6  
     7  package simd_test
     8  
     9  import (
    10  	"fmt"
    11  	"math/bits"
    12  	"math/rand"
    13  	"simd"
    14  	"testing"
    15  )
    16  
    17  func vpsumd(xlo, xhi, ylo, yhi uint64) (lo, hi uint64) {
    18  	lo, hi = clmul64(xhi, yhi)
    19  	l, h := clmul64(xlo, ylo)
    20  	hi ^= h
    21  	lo ^= l
    22  	return
    23  }
    24  
    25  func clmul64(a, b uint64) (lo, hi uint64) {
    26  	for i := range uint(64) {
    27  		if (a>>i)&1 == 1 {
    28  			if i == 0 {
    29  				lo ^= b
    30  			} else {
    31  				lo ^= b << i
    32  				hi ^= b >> (64 - i)
    33  			}
    34  		}
    35  	}
    36  	return
    37  }
    38  
    39  func em2(xlo, xhi, ylo, yhi uint64) string {
    40  	lx := newT(xlo, 0)
    41  	ly := newT(ylo, 0)
    42  	hx := newT(xhi, 0)
    43  	hy := newT(yhi, 0)
    44  
    45  	z := (lx.ClMul(ly)).Xor(hx.ClMul(hy))
    46  
    47  	return fmt.Sprintf("0x%08x%08x", z.b, z.a)
    48  }
    49  
    50  func em1(a, b, c, d uint64) string {
    51  	lo, hi := vpsumd(a, b, c, d)
    52  	return fmt.Sprintf("0x%08x%08x", hi, lo)
    53  }
    54  
    55  func set0(v uint64) simd.Uint64s {
    56  	a := [2]uint64{v, 0}
    57  	r, _ := simd.LoadUint64sPart(a[:])
    58  	return r
    59  }
    60  
    61  func get(v simd.Uint64s) (lo, hi uint64) {
    62  	var a [2]uint64
    63  	v.StorePart(a[:])
    64  	return a[0], a[1]
    65  }
    66  
    67  func em3(xlo, xhi, ylo, yhi uint64) string {
    68  	lx := set0(xlo)
    69  	ly := set0(ylo)
    70  	hx := set0(xhi)
    71  	hy := set0(yhi)
    72  
    73  	z := (lx.CarrylessMultiplyEven(ly)).Xor(hx.CarrylessMultiplyEven(hy))
    74  
    75  	lo, hi := get(z)
    76  	return fmt.Sprintf("0x%08x%08x", hi, lo)
    77  }
    78  
    79  func TestClMul(t *testing.T) {
    80  	fmt.Println("Vector length:", simd.VectorBitSize())
    81  	fmt.Println("Emulated:", simd.Emulated())
    82  	fmt.Println("HasHWCLMUL:", simd.HasHardwareCarrylessMultiply())
    83  
    84  	x := uint64(0x0807060504030201)
    85  	y := uint64(0x0101010101010101)
    86  
    87  	var a, b, c, d uint64
    88  	a, b, c, d = 0x66b32838754f59a3, 0xaeba319ab2418c50, 0x45678b3c7f11fc73, 0xd62ef8ae5f7b693
    89  
    90  	fmt.Println("EMULATION 1")
    91  	fmt.Printf("clmul(0x%08x, 0x%08x, 0x%08x, 0x%08x) = %s\n", x, x, 1, 16, em1(x, x, 1, 16))
    92  	fmt.Printf("clmul(0x%08x, 0x%08x, 0x%08x, 0x%08x) = %s\n", x, y, 1, 16, em1(x, y, 1, 16))
    93  	fmt.Printf("clmul(0x%08x, 0x%08x, 0x%08x, 0x%08x) = %s\n", x, y, x, y, em1(x, y, x, y))
    94  	fmt.Printf("clmul(0x%08x, 0x%08x, 0x%08x, 0x%08x) = %s\n", 1, 2, y*4, y, em1(1, 2, y*4, y))
    95  	fmt.Printf("clmul(0x%08x, 0x%08x, 0x%08x, 0x%08x) = %s\n", a, b, c, d, em1(a, b, c, d))
    96  
    97  	fmt.Println("EMULATION 2")
    98  	fmt.Printf("clmul(0x%08x, 0x%08x, 0x%08x, 0x%08x) = %s\n", x, x, 1, 16, em2(x, x, 1, 16))
    99  	fmt.Printf("clmul(0x%08x, 0x%08x, 0x%08x, 0x%08x) = %s\n", x, y, 1, 16, em2(x, y, 1, 16))
   100  	fmt.Printf("clmul(0x%08x, 0x%08x, 0x%08x, 0x%08x) = %s\n", x, y, x, y, em2(x, y, x, y))
   101  	fmt.Printf("clmul(0x%08x, 0x%08x, 0x%08x, 0x%08x) = %s\n", 1, 2, y*4, y, em2(1, 2, y*4, y))
   102  	fmt.Printf("clmul(0x%08x, 0x%08x, 0x%08x, 0x%08x) = %s\n", a, b, c, d, em2(a, b, c, d))
   103  
   104  	fmt.Println("EMULATION 3")
   105  	fmt.Printf("clmul(0x%08x, 0x%08x, 0x%08x, 0x%08x) = %s\n", x, x, 1, 16, em3(x, x, 1, 16))
   106  	fmt.Printf("clmul(0x%08x, 0x%08x, 0x%08x, 0x%08x) = %s\n", x, y, 1, 16, em3(x, y, 1, 16))
   107  	fmt.Printf("clmul(0x%08x, 0x%08x, 0x%08x, 0x%08x) = %s\n", x, y, x, y, em3(x, y, x, y))
   108  	fmt.Printf("clmul(0x%08x, 0x%08x, 0x%08x, 0x%08x) = %s\n", 1, 2, y*4, y, em3(1, 2, y*4, y))
   109  	fmt.Printf("clmul(0x%08x, 0x%08x, 0x%08x, 0x%08x) = %s\n", a, b, c, d, em3(a, b, c, d))
   110  
   111  	for i := range 10000 {
   112  		a, b, c, d := rand.Uint64(), rand.Uint64(), rand.Uint64(), rand.Uint64()
   113  
   114  		e1 := em1(a, b, c, d)
   115  		e2 := em2(a, b, c, d)
   116  		e3 := em3(a, b, c, d)
   117  
   118  		if e1 != e2 || e1 != e3 {
   119  			t.Errorf("Mismatch at %d, a,b,c,d = 0x%08x, 0x%08x, 0x%08x, 0x%08x; e1=%s, e2=%s, e3=%s", i, a, b, c, d, e1, e2, e3)
   120  			if i > 5 {
   121  				return
   122  			}
   123  		}
   124  
   125  	}
   126  
   127  }
   128  
   129  type T struct {
   130  	a, b uint64
   131  }
   132  
   133  func newT(lo, hi uint64) T {
   134  	return T{a: lo, b: hi}
   135  }
   136  
   137  func (x T) And(y T) T {
   138  	return T{a: x.a & y.a, b: x.b & y.b}
   139  }
   140  
   141  func (x T) Xor(y T) T {
   142  	return T{a: x.a ^ y.a, b: x.b ^ y.b}
   143  }
   144  
   145  func (x T) Or(y T) T {
   146  	return T{a: x.a | y.a, b: x.b | y.b}
   147  }
   148  
   149  func (x T) MWL(y T) T { // MulWidenLo
   150  	hi, lo := bits.Mul64(x.a, y.a)
   151  	return T{a: lo, b: hi}
   152  }
   153  
   154  func (x T) ClMul(y T) T {
   155  	m1 := newT(0x1084210842108421, 0x2108421084210842)
   156  	m2 := newT(0x2108421084210842, 0x4210842108421084)
   157  	m3 := newT(0x4210842108421084, 0x8421084210842108)
   158  	m4 := newT(0x8421084210842108, 0x0842108421084210)
   159  	m5 := newT(0x0842108421084210, 0x1084210842108421)
   160  
   161  	x1 := x.And(m1)
   162  	x2 := x.And(m2)
   163  	x3 := x.And(m3)
   164  	x4 := x.And(m4)
   165  	x5 := x.And(m5)
   166  
   167  	y1 := y.And(m1)
   168  	y2 := y.And(m2)
   169  	y3 := y.And(m3)
   170  	y4 := y.And(m4)
   171  	y5 := y.And(m5)
   172  
   173  	// sum of x, y indices == K mod 5; mask index = K-1
   174  	z := (x1.MWL(y1)).Xor(x2.MWL(y5)).Xor(x5.MWL(y2)).Xor(x3.MWL(y4)).Xor(x4.MWL(y3)).And(m1)
   175  	z = (x4.MWL(y4)).Xor(x3.MWL(y5)).Xor(x5.MWL(y3)).Xor(x1.MWL(y2)).Xor(x2.MWL(y1)).And(m2).Or(z)
   176  	z = (x2.MWL(y2)).Xor(x4.MWL(y5)).Xor(x5.MWL(y4)).Xor(x1.MWL(y3)).Xor(x3.MWL(y1)).And(m3).Or(z)
   177  	z = (x5.MWL(y5)).Xor(x1.MWL(y4)).Xor(x4.MWL(y1)).Xor(x2.MWL(y3)).Xor(x3.MWL(y2)).And(m4).Or(z)
   178  	z = (x3.MWL(y3)).Xor(x1.MWL(y5)).Xor(x5.MWL(y1)).Xor(x2.MWL(y4)).Xor(x4.MWL(y2)).And(m5).Or(z)
   179  
   180  	return z
   181  }
   182  

View as plain text