Source file src/crypto/internal/fips140/mldsa/field.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mldsa
     6  
     7  import (
     8  	"crypto/internal/constanttime"
     9  	"crypto/internal/fips140/sha3"
    10  	"errors"
    11  	"math/bits"
    12  )
    13  
    14  const (
    15  	q        = 8380417    // 2²³ - 2¹³ + 1
    16  	R        = 4294967296 // 2³²
    17  	RR       = 2365951    // R² mod q, aka R in the Montgomery domain
    18  	qNegInv  = 4236238847 // -q⁻¹ mod R (q * qNegInv ≡ -1 mod R)
    19  	one      = 4193792    // R mod q, aka 1 in the Montgomery domain
    20  	minusOne = 4186625    // (q - 1) * R mod q, aka -1 in the Montgomery domain
    21  )
    22  
    23  // fieldElement is an element n of ℤ_q in the Montgomery domain, represented as
    24  // an integer x in [0, q) such that x ≡ n * R (mod q) where R = 2³².
    25  type fieldElement uint32
    26  
    27  var errUnreducedFieldElement = errors.New("mldsa: unreduced field element")
    28  
    29  // fieldToMontgomery checks that a value a is < q, and converts it to
    30  // Montgomery form.
    31  func fieldToMontgomery(a uint32) (fieldElement, error) {
    32  	if a >= q {
    33  		return 0, errUnreducedFieldElement
    34  	}
    35  	// a * R² * R⁻¹ ≡ a * R (mod q)
    36  	return fieldMontgomeryMul(fieldElement(a), RR), nil
    37  }
    38  
    39  // fieldSubToMontgomery converts a difference a - b to Montgomery form.
    40  // a and b must be < q. (This bound can probably be relaxed.)
    41  func fieldSubToMontgomery(a, b uint32) fieldElement {
    42  	x := a - b + q
    43  	return fieldMontgomeryMul(fieldElement(x), RR)
    44  }
    45  
    46  // fieldFromMontgomery converts a value a in Montgomery form back to
    47  // standard representation.
    48  func fieldFromMontgomery(a fieldElement) uint32 {
    49  	// (a * R) * 1 * R⁻¹ ≡ a (mod q)
    50  	return uint32(fieldMontgomeryReduce(uint64(a)))
    51  }
    52  
    53  // fieldCenteredMod returns r mod± q, the value r reduced to the range
    54  // [−(q−1)/2, (q−1)/2].
    55  func fieldCenteredMod(r fieldElement) int32 {
    56  	x := int32(fieldFromMontgomery(r))
    57  	// x <= q / 2 ? x : x - q
    58  	return constantTimeSelectLessOrEqual(x, q/2, x, x-q)
    59  }
    60  
    61  // fieldInfinityNorm returns the infinity norm ||r||∞ of r, or the absolute
    62  // value of r centered around 0.
    63  func fieldInfinityNorm(r fieldElement) uint32 {
    64  	x := int32(fieldFromMontgomery(r))
    65  	// x <= q / 2 ? x : |x - q|
    66  	// |x - q| = -(x - q) = q - x because x < q => x - q < 0
    67  	return uint32(constantTimeSelectLessOrEqual(x, q/2, x, q-x))
    68  }
    69  
    70  // fieldReduceOnce reduces a value a < 2q.
    71  func fieldReduceOnce(a uint32) fieldElement {
    72  	x, b := bits.Sub64(uint64(a), uint64(q), 0)
    73  	return fieldElement(x + b*q)
    74  }
    75  
    76  // fieldAdd returns a + b mod q.
    77  func fieldAdd(a, b fieldElement) fieldElement {
    78  	x := uint32(a + b)
    79  	return fieldReduceOnce(x)
    80  }
    81  
    82  // fieldSub returns a - b mod q.
    83  func fieldSub(a, b fieldElement) fieldElement {
    84  	x := uint32(a - b + q)
    85  	return fieldReduceOnce(x)
    86  }
    87  
    88  // fieldMontgomeryMul returns a * b * R⁻¹ mod q.
    89  func fieldMontgomeryMul(a, b fieldElement) fieldElement {
    90  	x := uint64(a) * uint64(b)
    91  	return fieldMontgomeryReduce(x)
    92  }
    93  
    94  // fieldMontgomeryReduce returns x * R⁻¹ mod q for x < q * R.
    95  func fieldMontgomeryReduce(x uint64) fieldElement {
    96  	t := uint32(x) * qNegInv
    97  	u := (x + uint64(t)*q) >> 32
    98  	return fieldReduceOnce(uint32(u))
    99  }
   100  
   101  // fieldMontgomeryMulSub returns a * (b - c). This operation is fused to save a
   102  // fieldReduceOnce after the subtraction.
   103  func fieldMontgomeryMulSub(a, b, c fieldElement) fieldElement {
   104  	x := uint64(a) * uint64(b-c+q)
   105  	return fieldMontgomeryReduce(x)
   106  }
   107  
   108  // fieldMontgomeryAddMul returns a * b + c * d. This operation is fused to save
   109  // a fieldReduceOnce and a fieldReduce.
   110  func fieldMontgomeryAddMul(a, b, c, d fieldElement) fieldElement {
   111  	x := uint64(a) * uint64(b)
   112  	x += uint64(c) * uint64(d)
   113  	return fieldMontgomeryReduce(x)
   114  }
   115  
   116  const n = 256
   117  
   118  // ringElement is a polynomial, an element of R_q.
   119  type ringElement [n]fieldElement
   120  
   121  // polyAdd adds two ringElements or nttElements.
   122  func polyAdd[T ~[n]fieldElement](a, b T) (s T) {
   123  	for i := range s {
   124  		s[i] = fieldAdd(a[i], b[i])
   125  	}
   126  	return s
   127  }
   128  
   129  // polySub subtracts two ringElements or nttElements.
   130  func polySub[T ~[n]fieldElement](a, b T) (s T) {
   131  	for i := range s {
   132  		s[i] = fieldSub(a[i], b[i])
   133  	}
   134  	return s
   135  }
   136  
   137  // nttElement is an NTT representation, an element of T_q.
   138  type nttElement [n]fieldElement
   139  
// zetas are the values ζ^BitRev₈(k) mod q for each index k, converted to the
// Montgomery domain.
//
// zetas[0] equals one (ζ⁰ in Montgomery form) and is never read: ntt consumes
// indices 1 through 255 in increasing order, and inverseNTT consumes them from
// 255 down to 1.
var zetas = [256]fieldElement{4193792, 25847, 5771523, 7861508, 237124, 7602457, 7504169, 466468, 1826347, 2353451, 8021166, 6288512, 3119733, 5495562, 3111497, 2680103, 2725464, 1024112, 7300517, 3585928, 7830929, 7260833, 2619752, 6271868, 6262231, 4520680, 6980856, 5102745, 1757237, 8360995, 4010497, 280005, 2706023, 95776, 3077325, 3530437, 6718724, 4788269, 5842901, 3915439, 4519302, 5336701, 3574422, 5512770, 3539968, 8079950, 2348700, 7841118, 6681150, 6736599, 3505694, 4558682, 3507263, 6239768, 6779997, 3699596, 811944, 531354, 954230, 3881043, 3900724, 5823537, 2071892, 5582638, 4450022, 6851714, 4702672, 5339162, 6927966, 3475950, 2176455, 6795196, 7122806, 1939314, 4296819, 7380215, 5190273, 5223087, 4747489, 126922, 3412210, 7396998, 2147896, 2715295, 5412772, 4686924, 7969390, 5903370, 7709315, 7151892, 8357436, 7072248, 7998430, 1349076, 1852771, 6949987, 5037034, 264944, 508951, 3097992, 44288, 7280319, 904516, 3958618, 4656075, 8371839, 1653064, 5130689, 2389356, 8169440, 759969, 7063561, 189548, 4827145, 3159746, 6529015, 5971092, 8202977, 1315589, 1341330, 1285669, 6795489, 7567685, 6940675, 5361315, 4499357, 4751448, 3839961, 2091667, 3407706, 2316500, 3817976, 5037939, 2244091, 5933984, 4817955, 266997, 2434439, 7144689, 3513181, 4860065, 4621053, 7183191, 5187039, 900702, 1859098, 909542, 819034, 495491, 6767243, 8337157, 7857917, 7725090, 5257975, 2031748, 3207046, 4823422, 7855319, 7611795, 4784579, 342297, 286988, 5942594, 4108315, 3437287, 5038140, 1735879, 203044, 2842341, 2691481, 5790267, 1265009, 4055324, 1247620, 2486353, 1595974, 4613401, 1250494, 2635921, 4832145, 5386378, 1869119, 1903435, 7329447, 7047359, 1237275, 5062207, 6950192, 7929317, 1312455, 3306115, 6417775, 7100756, 1917081, 5834105, 7005614, 1500165, 777191, 2235880, 3406031, 7838005, 5548557, 6709241, 6533464, 5796124, 4656147, 594136, 4603424, 6366809, 2432395, 2454455, 8215696, 1957272, 3369112, 185531, 7173032, 5196991, 162844, 1616392, 3014001, 810149,
1652634, 4686184, 6581310, 5341501, 3523897, 3866901, 269760, 2213111, 7404533, 1717735, 472078, 7953734, 1723600, 6577327, 1910376, 6712985, 7276084, 8119771, 4546524, 5441381, 6144432, 7959518, 6094090, 183443, 7403526, 1612842, 4834730, 7826001, 3919660, 8332111, 7018208, 3937738, 1400424, 7534263, 1976782}
   143  
   144  // ntt maps a ringElement to its nttElement representation.
   145  //
   146  // It implements NTT, according to FIPS 203, Algorithm 9.
   147  func ntt(f ringElement) nttElement {
   148  	var m uint8
   149  
   150  	for len := 128; len >= 8; len /= 2 {
   151  		for start := 0; start < 256; start += 2 * len {
   152  			m++
   153  			zeta := zetas[m]
   154  
   155  			// Bounds check elimination hint.
   156  			f, flen := f[start:start+len], f[start+len:start+len+len]
   157  			for j := 0; j < len; j += 2 {
   158  				t := fieldMontgomeryMul(zeta, flen[j])
   159  				flen[j] = fieldSub(f[j], t)
   160  				f[j] = fieldAdd(f[j], t)
   161  
   162  				// Unroll by 2 for performance.
   163  				t = fieldMontgomeryMul(zeta, flen[j+1])
   164  				flen[j+1] = fieldSub(f[j+1], t)
   165  				f[j+1] = fieldAdd(f[j+1], t)
   166  			}
   167  		}
   168  	}
   169  
   170  	// Unroll len = 4, 2, and 1.
   171  	for start := 0; start < 256; start += 8 {
   172  		m++
   173  		zeta := zetas[m]
   174  
   175  		t := fieldMontgomeryMul(zeta, f[start+4])
   176  		f[start+4] = fieldSub(f[start], t)
   177  		f[start] = fieldAdd(f[start], t)
   178  
   179  		t = fieldMontgomeryMul(zeta, f[start+5])
   180  		f[start+5] = fieldSub(f[start+1], t)
   181  		f[start+1] = fieldAdd(f[start+1], t)
   182  
   183  		t = fieldMontgomeryMul(zeta, f[start+6])
   184  		f[start+6] = fieldSub(f[start+2], t)
   185  		f[start+2] = fieldAdd(f[start+2], t)
   186  
   187  		t = fieldMontgomeryMul(zeta, f[start+7])
   188  		f[start+7] = fieldSub(f[start+3], t)
   189  		f[start+3] = fieldAdd(f[start+3], t)
   190  	}
   191  	for start := 0; start < 256; start += 4 {
   192  		m++
   193  		zeta := zetas[m]
   194  
   195  		t := fieldMontgomeryMul(zeta, f[start+2])
   196  		f[start+2] = fieldSub(f[start], t)
   197  		f[start] = fieldAdd(f[start], t)
   198  
   199  		t = fieldMontgomeryMul(zeta, f[start+3])
   200  		f[start+3] = fieldSub(f[start+1], t)
   201  		f[start+1] = fieldAdd(f[start+1], t)
   202  	}
   203  	for start := 0; start < 256; start += 2 {
   204  		m++
   205  		zeta := zetas[m]
   206  
   207  		t := fieldMontgomeryMul(zeta, f[start+1])
   208  		f[start+1] = fieldSub(f[start], t)
   209  		f[start] = fieldAdd(f[start], t)
   210  	}
   211  
   212  	return nttElement(f)
   213  }
   214  
   215  // inverseNTT maps a nttElement back to the ringElement it represents.
   216  //
   217  // It implements NTT⁻¹, according to FIPS 203, Algorithm 10.
   218  func inverseNTT(f nttElement) ringElement {
   219  	var m uint8 = 255
   220  
   221  	// Unroll len = 1, 2, and 4.
   222  	for start := 0; start < 256; start += 2 {
   223  		zeta := zetas[m]
   224  		m--
   225  
   226  		t := f[start]
   227  		f[start] = fieldAdd(t, f[start+1])
   228  		f[start+1] = fieldMontgomeryMulSub(zeta, f[start+1], t)
   229  	}
   230  	for start := 0; start < 256; start += 4 {
   231  		zeta := zetas[m]
   232  		m--
   233  
   234  		t := f[start]
   235  		f[start] = fieldAdd(t, f[start+2])
   236  		f[start+2] = fieldMontgomeryMulSub(zeta, f[start+2], t)
   237  
   238  		t = f[start+1]
   239  		f[start+1] = fieldAdd(t, f[start+3])
   240  		f[start+3] = fieldMontgomeryMulSub(zeta, f[start+3], t)
   241  	}
   242  	for start := 0; start < 256; start += 8 {
   243  		zeta := zetas[m]
   244  		m--
   245  
   246  		t := f[start]
   247  		f[start] = fieldAdd(t, f[start+4])
   248  		f[start+4] = fieldMontgomeryMulSub(zeta, f[start+4], t)
   249  
   250  		t = f[start+1]
   251  		f[start+1] = fieldAdd(t, f[start+5])
   252  		f[start+5] = fieldMontgomeryMulSub(zeta, f[start+5], t)
   253  
   254  		t = f[start+2]
   255  		f[start+2] = fieldAdd(t, f[start+6])
   256  		f[start+6] = fieldMontgomeryMulSub(zeta, f[start+6], t)
   257  
   258  		t = f[start+3]
   259  		f[start+3] = fieldAdd(t, f[start+7])
   260  		f[start+7] = fieldMontgomeryMulSub(zeta, f[start+7], t)
   261  	}
   262  
   263  	for len := 8; len < 256; len *= 2 {
   264  		for start := 0; start < 256; start += 2 * len {
   265  			zeta := zetas[m]
   266  			m--
   267  
   268  			// Bounds check elimination hint.
   269  			f, flen := f[start:start+len], f[start+len:start+len+len]
   270  			for j := 0; j < len; j += 2 {
   271  				t := f[j]
   272  				f[j] = fieldAdd(t, flen[j])
   273  				// -z * (t - flen[j]) = z * (flen[j] - t)
   274  				flen[j] = fieldMontgomeryMulSub(zeta, flen[j], t)
   275  
   276  				// Unroll by 2 for performance.
   277  				t = f[j+1]
   278  				f[j+1] = fieldAdd(t, flen[j+1])
   279  				flen[j+1] = fieldMontgomeryMulSub(zeta, flen[j+1], t)
   280  			}
   281  		}
   282  	}
   283  
   284  	for i := range f {
   285  		f[i] = fieldMontgomeryMul(f[i], 16382) // 16382 = 256⁻¹ * R mod q
   286  	}
   287  	return ringElement(f)
   288  }
   289  
   290  // nttMul multiplies two nttElements.
   291  func nttMul(a, b nttElement) (p nttElement) {
   292  	for i := range p {
   293  		p[i] = fieldMontgomeryMul(a[i], b[i])
   294  	}
   295  	return p
   296  }
   297  
   298  // sampleNTT samples an nttElement uniformly at random from the seed rho and the
   299  // indices s and r. It implements Step 3 of ExpandA, RejNTTPoly, and
   300  // CoeffFromThreeBytes from FIPS 204, passing in ρ, s, and r instead of ρ'.
   301  func sampleNTT(rho []byte, s, r byte) nttElement {
   302  	G := sha3.NewShake128()
   303  	G.Write(rho)
   304  	G.Write([]byte{s, r})
   305  
   306  	var a nttElement
   307  	var j int         // index into a
   308  	var buf [168]byte // buffered reads from B, matching the rate of SHAKE-128
   309  	off := len(buf)   // index into buf, starts in a "buffer fully consumed" state
   310  	for j < n {
   311  		if off >= len(buf) {
   312  			G.Read(buf[:])
   313  			off = 0
   314  		}
   315  		v := uint32(buf[off]) | uint32(buf[off+1])<<8 | uint32(buf[off+2])<<16
   316  		off += 3
   317  		f, err := fieldToMontgomery(v & 0b01111111_11111111_11111111) // 23 bits
   318  		if err != nil {
   319  			continue
   320  		}
   321  		a[j] = f
   322  		j++
   323  	}
   324  	return a
   325  }
   326  
   327  // sampleBoundedPoly samples a ringElement with coefficients in [−η, η] from the
   328  // seed rho and the index r. It implements RejBoundedPoly and CoeffFromHalfByte
   329  // from FIPS 204, passing in ρ and r separately from ExpandS.
   330  func sampleBoundedPoly(rho []byte, r byte, p parameters) ringElement {
   331  	H := sha3.NewShake256()
   332  	H.Write(rho)
   333  	H.Write([]byte{r, 0}) // IntegerToBytes(r, 2)
   334  
   335  	var a ringElement
   336  	var j int
   337  	var buf [136]byte // buffered reads from H, matching the rate of SHAKE-256
   338  	off := len(buf)   // index into buf, starts in a "buffer fully consumed" state
   339  	for {
   340  		if off >= len(buf) {
   341  			H.Read(buf[:])
   342  			off = 0
   343  		}
   344  		z0 := buf[off] & 0x0F
   345  		z1 := buf[off] >> 4
   346  		off++
   347  		coeff, ok := coeffFromHalfByte(z0, p)
   348  		if ok {
   349  			a[j] = coeff
   350  			j++
   351  		}
   352  		if j >= len(a) {
   353  			break
   354  		}
   355  		coeff, ok = coeffFromHalfByte(z1, p)
   356  		if ok {
   357  			a[j] = coeff
   358  			j++
   359  		}
   360  		if j >= len(a) {
   361  			break
   362  		}
   363  	}
   364  	return a
   365  }
   366  
   367  // sampleInBall samples a ringElement with coefficients in {−1, 0, 1}, and τ
   368  // non-zero coefficients. It is not constant-time.
   369  func sampleInBall(rho []byte, p parameters) ringElement {
   370  	H := sha3.NewShake256()
   371  	H.Write(rho)
   372  	s := make([]byte, 8)
   373  	H.Read(s)
   374  
   375  	var c ringElement
   376  	for i := 256 - p.τ; i < 256; i++ {
   377  		j := make([]byte, 1)
   378  		H.Read(j)
   379  		for j[0] > byte(i) {
   380  			H.Read(j)
   381  		}
   382  		c[i] = c[j[0]]
   383  		// c[j] = (−1) ^ h[i+τ−256], where h are the bits in s in little-endian.
   384  		// That is, -1⁰ = 1 if the bit is 0, -1¹ = -1 if it is 1.
   385  		bitIdx := i + p.τ - 256
   386  		bit := (s[bitIdx/8] >> (bitIdx % 8)) & 1
   387  		if bit == 0 {
   388  			c[j[0]] = one
   389  		} else {
   390  			c[j[0]] = minusOne
   391  		}
   392  	}
   393  
   394  	return c
   395  }
   396  
   397  // coeffFromHalfByte implements CoeffFromHalfByte from FIPS 204.
   398  //
   399  // It maps a value in [0, 15] to a coefficient in [−η, η]
   400  func coeffFromHalfByte(b byte, p parameters) (fieldElement, bool) {
   401  	if b > 15 {
   402  		panic("internal error: half-byte out of range")
   403  	}
   404  	switch p.η {
   405  	case 2:
   406  		// Return z = 2 − (b mod 5), which maps from
   407  		//
   408  		//     b = ( 14, 13, 12, 11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0 )
   409  		//
   410  		// to
   411  		//
   412  		//   b%5 = (  4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,  2,  1,  0 )
   413  		//
   414  		// to
   415  		//
   416  		//     z = ( -2, -1,  0,  1,  2, -2, -1,  0,  1,  2, -2, -1,  0,  1,  2 )
   417  		//
   418  		if b > 14 {
   419  			return 0, false
   420  		}
   421  		// Calculate b % 5 with Barrett reduction, to avoid a potentially
   422  		// variable-time division.
   423  		const barrettMultiplier = 0x3334 // ⌈2¹⁶ / 5⌉
   424  		const barrettShift = 16          // log₂(2¹⁶)
   425  		quotient := (uint32(b) * barrettMultiplier) >> barrettShift
   426  		remainder := uint32(b) - quotient*5
   427  		return fieldSubToMontgomery(2, remainder), true
   428  	case 4:
   429  		// Return z = 4 − b, which maps from
   430  		//
   431  		//   b = (  8,  7,  6,  5,  4,  3,  2,  1,  0 )
   432  		//
   433  		// to
   434  		//
   435  		//   z = ( −4, -3, -2, -1,  0,  1,  2,  3,  4 )
   436  		//
   437  		if b > 8 {
   438  			return 0, false
   439  		}
   440  		return fieldSubToMontgomery(4, uint32(b)), true
   441  	default:
   442  		panic("internal error: unsupported η")
   443  	}
   444  }
   445  
   446  // power2Round implements Power2Round from FIPS 204.
   447  //
   448  // It separates the bottom d = 13 bits of each 23-bit coefficient, rounding the
   449  // high part based on the low part, and correcting the low part accordingly.
   450  func power2Round(r fieldElement) (hi uint16, lo fieldElement) {
   451  	rr := fieldFromMontgomery(r)
   452  	// Add 2¹² - 1 to round up r1 by one if r0 > 2¹².
   453  	// r is at most 2²³ - 2¹³ + 1, so rr + (2¹² - 1) won't overflow 23 bits.
   454  	r1 := rr + 1<<12 - 1
   455  	r1 >>= 13
   456  	// r1 <= 2¹⁰ - 1
   457  	// r1 * 2¹³ <= (2¹⁰ - 1) * 2¹³ = 2²³ - 2¹³ < q
   458  	r0 := fieldSubToMontgomery(rr, r1<<13)
   459  	return uint16(r1), r0
   460  }
   461  
   462  // highBits implements HighBits from FIPS 204.
   463  func highBits(r ringElement, p parameters) [n]byte {
   464  	var w [n]byte
   465  	switch p.γ2 {
   466  	case 32:
   467  		for i := range n {
   468  			w[i] = highBits32(fieldFromMontgomery(r[i]))
   469  		}
   470  	case 88:
   471  		for i := range n {
   472  			w[i] = highBits88(fieldFromMontgomery(r[i]))
   473  		}
   474  	default:
   475  		panic("mldsa: internal error: unsupported γ2")
   476  	}
   477  	return w
   478  }
   479  
   480  // useHint implements UseHint from FIPS 204.
   481  //
   482  // It is not constant-time.
   483  func useHint(r ringElement, h [n]byte, p parameters) [n]byte {
   484  	var w [n]byte
   485  	switch p.γ2 {
   486  	case 32:
   487  		for i := range n {
   488  			w[i] = useHint32(r[i], h[i])
   489  		}
   490  	case 88:
   491  		for i := range n {
   492  			w[i] = useHint88(r[i], h[i])
   493  		}
   494  	default:
   495  		panic("mldsa: internal error: unsupported γ2")
   496  	}
   497  	return w
   498  }
   499  
   500  // makeHint implements MakeHint from FIPS 204.
   501  func makeHint(ct0, w, cs2 ringElement, p parameters) (h [n]byte, count1s int) {
   502  	switch p.γ2 {
   503  	case 32:
   504  		for i := range n {
   505  			h[i] = makeHint32(ct0[i], w[i], cs2[i])
   506  			count1s += int(h[i])
   507  		}
   508  	case 88:
   509  		for i := range n {
   510  			h[i] = makeHint88(ct0[i], w[i], cs2[i])
   511  			count1s += int(h[i])
   512  		}
   513  	default:
   514  		panic("mldsa: internal error: unsupported γ2")
   515  	}
   516  	return h, count1s
   517  }
   518  
   519  // highBits32 implements HighBits from FIPS 204 for γ2 = (q - 1) / 32.
   520  func highBits32(x uint32) byte {
   521  	// The implementation is based on the reference implementation and on
   522  	// BoringSSL. There are exhaustive tests in TestDecompose that compare it to
   523  	// a straightforward implementation of Decompose from the spec, so for our
   524  	// purposes it only has to work and be constant-time.
   525  	r1 := (x + 127) >> 7
   526  	r1 = (r1*1025 + (1 << 21)) >> 22
   527  	r1 &= 0b1111
   528  	return byte(r1)
   529  }
   530  
   531  // decompose32 implements Decompose from FIPS 204 for γ2 = (q - 1) / 32.
   532  //
   533  // r1 is in [0, 15].
   534  func decompose32(r fieldElement) (r1 byte, r0 int32) {
   535  	x := fieldFromMontgomery(r)
   536  	r1 = highBits32(x)
   537  
   538  	// r - r1 * (2 * γ2) mod± q
   539  	r0 = int32(x) - int32(r1)*2*(q-1)/32
   540  	r0 = constantTimeSelectLessOrEqual(q/2+1, r0, r0-q, r0)
   541  
   542  	return r1, r0
   543  }
   544  
   545  // useHint32 implements UseHint from FIPS 204 for γ2 = (q - 1) / 32.
   546  func useHint32(r fieldElement, hint byte) byte {
   547  	const m = 16 // (q − 1) / (2 * γ2)
   548  	r1, r0 := decompose32(r)
   549  	if hint == 1 {
   550  		if r0 > 0 {
   551  			r1 = (r1 + 1) % m
   552  		} else {
   553  			// Underflow is safe, because it operates modulo 256 (since the type
   554  			// is byte), which is a multiple of m.
   555  			r1 = (r1 - 1) % m
   556  		}
   557  	}
   558  	return r1
   559  }
   560  
   561  // makeHint32 implements MakeHint from FIPS 204 for γ2 = (q - 1) / 32.
   562  func makeHint32(ct0, w, cs2 fieldElement) byte {
   563  	// v1 = HighBits(r + z) = HighBits(w - cs2 + ct0 - ct0) = HighBits(w - cs2)
   564  	rPlusZ := fieldSub(w, cs2)
   565  	v1 := highBits32(fieldFromMontgomery(rPlusZ))
   566  	// r1 = HighBits(r) = HighBits(w - cs2 + ct0)
   567  	r1 := highBits32(fieldFromMontgomery(fieldAdd(rPlusZ, ct0)))
   568  
   569  	return byte(constanttime.ByteEq(v1, r1) ^ 1)
   570  }
   571  
   572  // highBits88 implements HighBits from FIPS 204 for γ2 = (q - 1) / 88.
   573  func highBits88(x uint32) byte {
   574  	// Like highBits32, this is exhaustively tested in TestDecompose.
   575  	r1 := (x + 127) >> 7
   576  	r1 = (r1*11275 + (1 << 23)) >> 24
   577  	r1 = constantTimeSelectEqual(r1, 44, 0, r1)
   578  	return byte(r1)
   579  }
   580  
   581  // decompose88 implements Decompose from FIPS 204 for γ2 = (q - 1) / 88.
   582  //
   583  // r1 is in [0, 43].
   584  func decompose88(r fieldElement) (r1 byte, r0 int32) {
   585  	x := fieldFromMontgomery(r)
   586  	r1 = highBits88(x)
   587  
   588  	// r - r1 * (2 * γ2) mod± q
   589  	r0 = int32(x) - int32(r1)*2*(q-1)/88
   590  	r0 = constantTimeSelectLessOrEqual(q/2+1, r0, r0-q, r0)
   591  
   592  	return r1, r0
   593  }
   594  
   595  // useHint88 implements UseHint from FIPS 204 for γ2 = (q - 1) / 88.
   596  func useHint88(r fieldElement, hint byte) byte {
   597  	const m = 44 // (q − 1) / (2 * γ2)
   598  	r1, r0 := decompose88(r)
   599  	if hint == 1 {
   600  		if r0 > 0 {
   601  			// (r1 + 1) mod m, for r1 in [0, m-1]
   602  			if r1 == m-1 {
   603  				r1 = 0
   604  			} else {
   605  				r1++
   606  			}
   607  		} else {
   608  			// (r1 - 1) % m, for r1 in [0, m-1]
   609  			if r1 == 0 {
   610  				r1 = m - 1
   611  			} else {
   612  				r1--
   613  			}
   614  		}
   615  	}
   616  	return r1
   617  }
   618  
   619  // makeHint88 implements MakeHint from FIPS 204 for γ2 = (q - 1) / 88.
   620  func makeHint88(ct0, w, cs2 fieldElement) byte {
   621  	// Same as makeHint32 above.
   622  	rPlusZ := fieldSub(w, cs2)
   623  	v1 := highBits88(fieldFromMontgomery(rPlusZ))
   624  	r1 := highBits88(fieldFromMontgomery(fieldAdd(rPlusZ, ct0)))
   625  	return byte(constanttime.ByteEq(v1, r1) ^ 1)
   626  }
   627  
   628  // bitPack implements BitPack(r mod± q, γ₁-1, γ₁), which packs the centered
   629  // coefficients of r into little-endian γ1+1-bit chunks. It appends to buf.
   630  //
   631  // It must only be applied to r with coefficients in [−γ₁+1, γ₁], as
   632  // guaranteed by the rejection conditions in Sign.
   633  func bitPack(b []byte, r ringElement, p parameters) []byte {
   634  	switch p.γ1 {
   635  	case 17:
   636  		return bitPack18(b, r)
   637  	case 19:
   638  		return bitPack20(b, r)
   639  	default:
   640  		panic("mldsa: internal error: unsupported γ1")
   641  	}
   642  }
   643  
   644  // bitPack18 implements BitPack(r mod± q, 2¹⁷-1, 2¹⁷), which packs the centered
   645  // coefficients of r into little-endian 18-bit chunks. It appends to buf.
   646  //
   647  // It must only be applied to r with coefficients in [−2¹⁷+1, 2¹⁷], as
   648  // guaranteed by the rejection conditions in Sign.
   649  func bitPack18(buf []byte, r ringElement) []byte {
   650  	out, v := sliceForAppend(buf, 18*n/8)
   651  	const b = 1 << 17
   652  	for i := 0; i < n; i += 4 {
   653  		// b - [−2¹⁷+1, 2¹⁷] = [0, 2²⁸-1]
   654  		w0 := b - fieldCenteredMod(r[i])
   655  		v[0] = byte(w0 << 0)
   656  		v[1] = byte(w0 >> 8)
   657  		v[2] = byte(w0 >> 16)
   658  		w1 := b - fieldCenteredMod(r[i+1])
   659  		v[2] |= byte(w1 << 2)
   660  		v[3] = byte(w1 >> 6)
   661  		v[4] = byte(w1 >> 14)
   662  		w2 := b - fieldCenteredMod(r[i+2])
   663  		v[4] |= byte(w2 << 4)
   664  		v[5] = byte(w2 >> 4)
   665  		v[6] = byte(w2 >> 12)
   666  		w3 := b - fieldCenteredMod(r[i+3])
   667  		v[6] |= byte(w3 << 6)
   668  		v[7] = byte(w3 >> 2)
   669  		v[8] = byte(w3 >> 10)
   670  		v = v[4*18/8:]
   671  	}
   672  	return out
   673  }
   674  
   675  // bitPack20 implements BitPack(r mod± q, 2¹⁹-1, 2¹⁹), which packs the centered
   676  // coefficients of r into little-endian 20-bit chunks. It appends to buf.
   677  //
   678  // It must only be applied to r with coefficients in [−2¹⁹+1, 2¹⁹], as
   679  // guaranteed by the rejection conditions in Sign.
   680  func bitPack20(buf []byte, r ringElement) []byte {
   681  	out, v := sliceForAppend(buf, 20*n/8)
   682  	const b = 1 << 19
   683  	for i := 0; i < n; i += 2 {
   684  		// b - [−2¹⁹+1, 2¹⁹] = [0, 2²⁰-1]
   685  		w0 := b - fieldCenteredMod(r[i])
   686  		v[0] = byte(w0 << 0)
   687  		v[1] = byte(w0 >> 8)
   688  		v[2] = byte(w0 >> 16)
   689  		w1 := b - fieldCenteredMod(r[i+1])
   690  		v[2] |= byte(w1 << 4)
   691  		v[3] = byte(w1 >> 4)
   692  		v[4] = byte(w1 >> 12)
   693  		v = v[2*20/8:]
   694  	}
   695  	return out
   696  }
   697  
   698  // bitUnpack implements BitUnpack(v, 2^γ1-1, 2^γ1), which unpacks each γ1+1 bits
   699  // in little-endian into a coefficient in [-2^γ1+1, 2^γ1].
   700  func bitUnpack(v []byte, p parameters) ringElement {
   701  	switch p.γ1 {
   702  	case 17:
   703  		return bitUnpack18(v)
   704  	case 19:
   705  		return bitUnpack20(v)
   706  	default:
   707  		panic("mldsa: internal error: unsupported γ1")
   708  	}
   709  }
   710  
   711  // bitUnpack18 implements BitUnpack(v, 2¹⁷-1, 2¹⁷), which unpacks each 18 bits
   712  // in little-endian into a coefficient in [-2¹⁷+1, 2¹⁷].
   713  func bitUnpack18(v []byte) ringElement {
   714  	if len(v) != 18*n/8 {
   715  		panic("mldsa: internal error: invalid bitUnpack18 input length")
   716  	}
   717  	const b = 1 << 17
   718  	const mask18 = 1<<18 - 1
   719  	var r ringElement
   720  	for i := 0; i < n; i += 4 {
   721  		w0 := uint32(v[0]) | uint32(v[1])<<8 | uint32(v[2])<<16
   722  		r[i+0] = fieldSubToMontgomery(b, w0&mask18)
   723  		w1 := uint32(v[2])>>2 | uint32(v[3])<<6 | uint32(v[4])<<14
   724  		r[i+1] = fieldSubToMontgomery(b, w1&mask18)
   725  		w2 := uint32(v[4])>>4 | uint32(v[5])<<4 | uint32(v[6])<<12
   726  		r[i+2] = fieldSubToMontgomery(b, w2&mask18)
   727  		w3 := uint32(v[6])>>6 | uint32(v[7])<<2 | uint32(v[8])<<10
   728  		r[i+3] = fieldSubToMontgomery(b, w3&mask18)
   729  		v = v[4*18/8:]
   730  	}
   731  	return r
   732  }
   733  
   734  // bitUnpack20 implements BitUnpack(v, 2¹⁹-1, 2¹⁹), which unpacks each 20 bits
   735  // in little-endian into a coefficient in [-2¹⁹+1, 2¹⁹].
   736  func bitUnpack20(v []byte) ringElement {
   737  	if len(v) != 20*n/8 {
   738  		panic("mldsa: internal error: invalid bitUnpack20 input length")
   739  	}
   740  	const b = 1 << 19
   741  	const mask20 = 1<<20 - 1
   742  	var r ringElement
   743  	for i := 0; i < n; i += 2 {
   744  		w0 := uint32(v[0]) | uint32(v[1])<<8 | uint32(v[2])<<16
   745  		r[i+0] = fieldSubToMontgomery(b, w0&mask20)
   746  		w1 := uint32(v[2])>>4 | uint32(v[3])<<4 | uint32(v[4])<<12
   747  		r[i+1] = fieldSubToMontgomery(b, w1&mask20)
   748  		v = v[2*20/8:]
   749  	}
   750  	return r
   751  }
   752  
   753  // sliceForAppend takes a slice and a requested number of bytes. It returns a
   754  // slice with the contents of the given slice followed by that many bytes and a
   755  // second slice that aliases into it and contains only the extra bytes. If the
   756  // original slice has sufficient capacity then no allocation is performed.
   757  func sliceForAppend(in []byte, n int) (head, tail []byte) {
   758  	if total := len(in) + n; cap(in) >= total {
   759  		head = in[:total]
   760  	} else {
   761  		head = make([]byte, total)
   762  		copy(head, in)
   763  	}
   764  	tail = head[len(in):]
   765  	return
   766  }
   767  
   768  // constantTimeSelectLessOrEqual returns yes if a <= b, no otherwise, in constant time.
   769  func constantTimeSelectLessOrEqual(a, b, yes, no int32) int32 {
   770  	return int32(constanttime.Select(constanttime.LessOrEq(int(a), int(b)), int(yes), int(no)))
   771  }
   772  
   773  // constantTimeSelectEqual returns yes if a == b, no otherwise, in constant time.
   774  func constantTimeSelectEqual(a, b, yes, no uint32) uint32 {
   775  	return uint32(constanttime.Select(constanttime.Eq(int32(a), int32(b)), int(yes), int(no)))
   776  }
   777  
   778  // constantTimeAbs returns the absolute value of x in constant time.
   779  func constantTimeAbs(x int32) uint32 {
   780  	return uint32(constantTimeSelectLessOrEqual(0, x, x, -x))
   781  }
   782  

View as plain text