Source file src/simd/archsimd/clmul_emulated.go

     1  // Copyright 2026 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build goexperiment.simd && (arm64 || wasm)
     6  
     7  package archsimd
     8  
     9  func new64x2(lo, hi uint64) Uint64x2 {
    10  	return Uint64x2{}.SetElem(0, lo).SetElem(1, hi)
    11  }
    12  
    13  // These masks all have 4 zeroes between 1s.
    14  var m0 = new64x2(0x1084210842108421, 0x2108421084210842)
    15  var m1 = new64x2(0x2108421084210842, 0x4210842108421084)
    16  var m2 = new64x2(0x4210842108421084, 0x8421084210842108)
    17  var m3 = new64x2(0x8421084210842108, 0x0842108421084210)
    18  var m4 = new64x2(0x0842108421084210, 0x1084210842108421)
    19  
    20  // Selects the middle 64 bits of a 128-bit simd value
    21  var middle = new64x2(0xffffffff00000000, 0x00000000ffffffff)
    22  
    23  // mwl is a 64x64 into 128 multiply that is missing
    24  // some carries that we don't need for CLMUL emulation.
    25  // The high 64 bits of each input are ignored.
    26  // Also just for fun, accumulate sums with Xor.
    27  func (x Uint64x2) mwl(y Uint64x2) Uint64x2 {
    28  	// reshape input into Uint32x4
    29  	// input is  {a b _ _}.mwl{c d _ _}
    30  	// need the sum of
    31  	// ac0_ac1
    32  	//   0 ad0_ad1
    33  	//   0 bc0_bc1
    34  	//   0   0 bd0_bd1
    35  	// This "sum" is where the carries (not propagated
    36  	// across lanes) are lost.
    37  	ab__ := x.ReshapeToUint32s()
    38  	cd__ := y.ReshapeToUint32s()
    39  	ac0_ac1_bd0_bd1 := ab__.MulWidenLo(cd__)
    40  
    41  	dc__ := y.RotateAllLeft(32).ReshapeToUint32s()
    42  	ad0_ad1_bc0_bc1 := ab__.MulWidenLo(dc__)
    43  	//
    44  	// have        ad0, ad1, bc0, bc1
    45  	// want        0, ad0+bc0, ad1+bc1, 0
    46  	// to add to    ac0_ac1_bd0_bd1
    47  	//
    48  	// swap 64-bit halves of ad0_ad1_bc0_bc1
    49  	// to get   bc0_bc1_ad0_ad1
    50  	bc0_bc1_ad0_ad1 := Uint64x2{}.SetElem(0, ad0_ad1_bc0_bc1.GetElem(1)).SetElem(1, ad0_ad1_bc0_bc1.GetElem(0))
    51  
    52  	// added to ad0_ad1_bc0_bc1 yields
    53  	//   bc0+ad0, bc1+ad1, bc0+ad0, bc1+ad1
    54  	// rotate 32 (within the two 64-bit elements) yields
    55  	//   bc1+ad1, bc0+ad0, bc1+ad1, bc0+ad0
    56  	// and then intersect with mask:
    57  	//   0      , bc0+ad0, bc1+ad1, 0
    58  	//
    59  	// use xor to make it a worse multiply
    60  	zzz_adPbc0_adPbc1_zzz := bc0_bc1_ad0_ad1.Xor(ad0_ad1_bc0_bc1).RotateAllLeft(32).And(middle)
    61  	return ac0_ac1_bd0_bd1.Xor(zzz_adPbc0_adPbc1_zzz)
    62  }
    63  
    64  // carrylessMultiply is constant time carrless multiply implemented with an
    65  // absurd number of multiplication given that the emulation platforms only have
    66  // 32x32 into 64, it might make sense to rework this into that primitive, but,
    67  // for now this works and is easily tested in scalar Go.
    68  func (x Uint64x2) carrylessMultiply(y Uint64x2) Uint64x2 {
    69  
    70  	// This by masking the two inputs into 5 thinned inputs, with
    71  	// 4 zeroes separating any 2 set bits.  Multiply will potentially
    72  	// set more bits with addition of overlapping terms, however this
    73  	// technique allows as many as 31 additions (filling all 4 separation
    74  	// positions with 1) without perturbing the bits we care about.  Since
    75  	// there's at most 13 set bits in a thinned input, 31 is not a problem.
    76  	// If there were only 3 set bits, there are 16 1s per thinned input and
    77  	// only 15 additions can be tolerated -- so that's not possible.
    78  
    79  	// This is also discussed at
    80  	// https://timtaubert.de/blog/2017/06/verified-binary-multiplication-for-ghash/
    81  
    82  	x0 := x.And(m0)
    83  	x1 := x.And(m1)
    84  	x2 := x.And(m2)
    85  	x3 := x.And(m3)
    86  	x4 := x.And(m4)
    87  
    88  	y0 := y.And(m0)
    89  	y1 := y.And(m1)
    90  	y2 := y.And(m2)
    91  	y3 := y.And(m3)
    92  	y4 := y.And(m4)
    93  
    94  	var z Uint64x2
    95  	// for a given line, combining (xI).mwl(yJ) terms, I+J == K mod 5; mask index = K
    96  	z = (x0.mwl(y0)).Xor(x1.mwl(y4)).Xor(x4.mwl(y1)).Xor(x2.mwl(y3)).Xor(x3.mwl(y2)).And(m0)
    97  	z = (x3.mwl(y3)).Xor(x2.mwl(y4)).Xor(x4.mwl(y2)).Xor(x0.mwl(y1)).Xor(x1.mwl(y0)).And(m1).Or(z)
    98  	z = (x1.mwl(y1)).Xor(x3.mwl(y4)).Xor(x4.mwl(y3)).Xor(x0.mwl(y2)).Xor(x2.mwl(y0)).And(m2).Or(z)
    99  	z = (x4.mwl(y4)).Xor(x0.mwl(y3)).Xor(x3.mwl(y0)).Xor(x1.mwl(y2)).Xor(x2.mwl(y1)).And(m3).Or(z)
   100  	z = (x2.mwl(y2)).Xor(x0.mwl(y4)).Xor(x4.mwl(y0)).Xor(x1.mwl(y3)).Xor(x3.mwl(y1)).And(m4).Or(z)
   101  
   102  	return z
   103  }
   104  

View as plain text