Source file src/crypto/internal/fips140/bigmod/nat.go
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package bigmod

import (
	_ "crypto/internal/fips140/check"
	"crypto/internal/fips140deps/byteorder"
	"errors"
	"math/bits"
)

const (
	// _W is the size in bits of our limbs.
	_W = bits.UintSize
	// _S is the size in bytes of our limbs.
	_S = _W / 8
)

// Note: These functions make many loops over all the words in a Nat.
// These loops used to be in assembly, invisible to -race, -asan, and -msan,
// but now they are in Go and incur significant overhead in those modes.
// To bring the old performance back, we mark all functions that loop
// over Nat words with //go:norace. Because //go:norace does not
// propagate across inlining, we must also mark functions that inline
// //go:norace functions - specifically, those that inline add, addMulVVW,
// assign, cmpGeq, rshift1, and sub.

// choice represents a constant-time boolean. The value of choice is always
// either 1 or 0. We use a uint instead of bool in order to make decisions in
// constant time by turning it into a mask.
type choice uint

// not returns the negation of c. Flipping the low bit with XOR is only a
// logical negation because c is always 0 or 1.
func not(c choice) choice { return 1 ^ c }

// yes and no are the only two values a choice may hold.
const yes = choice(1)
const no = choice(0)

// ctMask is all 1s if on is yes, and all 0s otherwise.
// Negating uint(1) wraps around to the all-ones word; negating 0 stays 0.
func ctMask(on choice) uint { return -uint(on) }

// ctEq returns 1 if x == y, and 0 otherwise. The execution time of this
// function does not depend on its inputs.
func ctEq(x, y uint) choice {
	// If x != y, then either x - y or y - x will generate a carry.
	_, c1 := bits.Sub(x, y, 0)
	_, c2 := bits.Sub(y, x, 0)
	return not(choice(c1 | c2))
}

// Nat represents an arbitrary natural number.
//
// Each Nat has an announced length, which is the number of limbs it has stored.
55 // Operations on this number are allowed to leak this length, but will not leak 56 // any information about the values contained in those limbs. 57 type Nat struct { 58 // limbs is little-endian in base 2^W with W = bits.UintSize. 59 limbs []uint 60 } 61 62 // preallocTarget is the size in bits of the numbers used to implement the most 63 // common and most performant RSA key size. It's also enough to cover some of 64 // the operations of key sizes up to 4096. 65 const preallocTarget = 2048 66 const preallocLimbs = (preallocTarget + _W - 1) / _W 67 68 // NewNat returns a new nat with a size of zero, just like new(Nat), but with 69 // the preallocated capacity to hold a number of up to preallocTarget bits. 70 // NewNat inlines, so the allocation can live on the stack. 71 func NewNat() *Nat { 72 limbs := make([]uint, 0, preallocLimbs) 73 return &Nat{limbs} 74 } 75 76 // expand expands x to n limbs, leaving its value unchanged. 77 func (x *Nat) expand(n int) *Nat { 78 if len(x.limbs) > n { 79 panic("bigmod: internal error: shrinking nat") 80 } 81 if cap(x.limbs) < n { 82 newLimbs := make([]uint, n) 83 copy(newLimbs, x.limbs) 84 x.limbs = newLimbs 85 return x 86 } 87 extraLimbs := x.limbs[len(x.limbs):n] 88 clear(extraLimbs) 89 x.limbs = x.limbs[:n] 90 return x 91 } 92 93 // reset returns a zero nat of n limbs, reusing x's storage if n <= cap(x.limbs). 94 func (x *Nat) reset(n int) *Nat { 95 if cap(x.limbs) < n { 96 x.limbs = make([]uint, n) 97 return x 98 } 99 clear(x.limbs) 100 x.limbs = x.limbs[:n] 101 return x 102 } 103 104 // resetToBytes assigns x = b, where b is a slice of big-endian bytes, resizing 105 // n to the appropriate size. 106 // 107 // The announced length of x is set based on the actual bit size of the input, 108 // ignoring leading zeroes. 
func (x *Nat) resetToBytes(b []byte) *Nat {
	// Size x to exactly enough limbs to hold len(b) bytes, so setBytes cannot
	// run out of room; an error from it therefore indicates an internal bug.
	x.reset((len(b) + _S - 1) / _S)
	if err := x.setBytes(b); err != nil {
		panic("bigmod: internal error: bad arithmetic")
	}
	return x.trim()
}

// trim reduces the size of x to match its value.
//
// Only zero limbs are dropped, so the value of x is unchanged.
func (x *Nat) trim() *Nat {
	// Trim most significant (trailing in little-endian) zero limbs.
	// We assume comparison with zero (but not the branch) is constant time.
	for i := len(x.limbs) - 1; i >= 0; i-- {
		if x.limbs[i] != 0 {
			break
		}
		x.limbs = x.limbs[:i]
	}
	return x
}

// set assigns x = y, optionally resizing x to the appropriate size.
//
// After set, x has the same announced length as y.
func (x *Nat) set(y *Nat) *Nat {
	x.reset(len(y.limbs))
	copy(x.limbs, y.limbs)
	return x
}

// Bytes returns x as a zero-extended big-endian byte slice. The size of the
// slice will match the size of m.
//
// x must have the same size as m and it must be less than or equal to m.
func (x *Nat) Bytes(m *Modulus) []byte {
	i := m.Size()
	bytes := make([]byte, i)
	// Emit bytes from least to most significant, filling the output from the
	// end towards the start.
	for _, limb := range x.limbs {
		for j := 0; j < _S; j++ {
			i--
			if i < 0 {
				// The output is full. Any remaining bytes of x must be zero;
				// the break only skips the rest of the current limb, and later
				// limbs are checked the same way.
				if limb == 0 {
					break
				}
				panic("bigmod: modulus is smaller than nat")
			}
			bytes[i] = byte(limb)
			limb >>= 8
		}
	}
	return bytes
}

// SetBytes assigns x = b, where b is a slice of big-endian bytes.
// SetBytes returns an error if b >= m.
//
// The output will be resized to the size of m and overwritten.
//
//go:norace
func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) {
	x.resetFor(m)
	if err := x.setBytes(b); err != nil {
		return nil, err
	}
	// x and m.nat now have the same announced length, as cmpGeq requires.
	if x.cmpGeq(m.nat) == yes {
		return nil, errors.New("input overflows the modulus")
	}
	return x, nil
}

// SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes.
// SetOverflowingBytes returns an error if b has a longer bit length than m, but
// reduces overflowing values up to 2^⌈log2(m)⌉ - 1.
//
// The output will be resized to the size of m and overwritten.
func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
	x.resetFor(m)
	if err := x.setBytes(b); err != nil {
		return nil, err
	}
	// setBytes would have returned an error if the input overflowed the limb
	// size of the modulus, so now we only need to check if the most significant
	// limb of x has more bits than the most significant limb of the modulus.
	if bitLen(x.limbs[len(x.limbs)-1]) > bitLen(m.nat.limbs[len(m.nat.limbs)-1]) {
		return nil, errors.New("input overflows the modulus size")
	}
	// With the bit length bounded by m's, a single conditional subtraction is
	// enough to reduce x below m.
	x.maybeSubtractModulus(no, m)
	return x, nil
}

// bigEndianUint returns the contents of buf interpreted as a
// big-endian encoded uint value.
func bigEndianUint(buf []byte) uint {
	if _W == 64 {
		return uint(byteorder.BEUint64(buf))
	}
	return uint(byteorder.BEUint32(buf))
}

// setBytes assigns x = b, where b is a slice of big-endian bytes, without
// changing the announced length of x.
//
// It returns an error if b has more bytes than fit in x's limbs, even if the
// extra leading bytes are zero. The partial most significant limb is combined
// with |=, so callers reset x to zero first (reset, resetFor).
func (x *Nat) setBytes(b []byte) error {
	i, k := len(b), 0
	// Fill whole limbs from the end of b (its least significant bytes).
	for k < len(x.limbs) && i >= _S {
		x.limbs[k] = bigEndianUint(b[i-_S : i])
		i -= _S
		k++
	}
	// Pack any remaining bytes (fewer than _S) into the next limb.
	for s := 0; s < _W && k < len(x.limbs) && i > 0; s += 8 {
		x.limbs[k] |= uint(b[i-1]) << s
		i--
	}
	if i > 0 {
		return errors.New("input overflows the modulus size")
	}
	return nil
}

// SetUint assigns x = y.
//
// The output will be resized to a single limb and overwritten.
func (x *Nat) SetUint(y uint) *Nat {
	x.reset(1)
	x.limbs[0] = y
	return x
}

// Equal returns 1 if x == y, and 0 otherwise.
//
// Both operands must have the same announced length.
//
//go:norace
func (x *Nat) Equal(y *Nat) choice {
	// Eliminate bounds checks in the loop.
	size := len(x.limbs)
	xLimbs := x.limbs[:size]
	yLimbs := y.limbs[:size]

	// Every limb is compared; there is no early exit on a mismatch.
	equal := yes
	for i := 0; i < size; i++ {
		equal &= ctEq(xLimbs[i], yLimbs[i])
	}
	return equal
}

// IsZero returns 1 if x == 0, and 0 otherwise.
//
// A Nat with no limbs is zero.
//
//go:norace
func (x *Nat) IsZero() choice {
	// Eliminate bounds checks in the loop.
	size := len(x.limbs)
	xLimbs := x.limbs[:size]

	zero := yes
	for i := 0; i < size; i++ {
		zero &= ctEq(xLimbs[i], 0)
	}
	return zero
}

// IsOne returns 1 if x == 1, and 0 otherwise.
//
//go:norace
func (x *Nat) IsOne() choice {
	// Eliminate bounds checks in the loop.
	size := len(x.limbs)
	xLimbs := x.limbs[:size]

	// A Nat with no limbs is zero, hence not one.
	if len(xLimbs) == 0 {
		return no
	}

	one := ctEq(xLimbs[0], 1)
	for i := 1; i < size; i++ {
		one &= ctEq(xLimbs[i], 0)
	}
	return one
}

// IsMinusOne returns 1 if x == -1 mod m, and 0 otherwise.
//
// The length of x must be the same as the modulus. x must already be reduced
// modulo m.
//
//go:norace
func (x *Nat) IsMinusOne(m *Modulus) choice {
	minusOne := m.Nat()
	minusOne.SubOne(m)
	return x.Equal(minusOne)
}

// IsOdd returns 1 if x is odd, and 0 otherwise.
func (x *Nat) IsOdd() choice {
	if len(x.limbs) == 0 {
		return no
	}
	return choice(x.limbs[0] & 1)
}

// TrailingZeroBitsVarTime returns the number of trailing zero bits in x.
//
// For x == 0 it returns _W times the number of limbs. The result leaks
// through timing side-channels.
func (x *Nat) TrailingZeroBitsVarTime() uint {
	var t uint
	limbs := x.limbs
	for _, l := range limbs {
		if l == 0 {
			t += _W
			continue
		}
		t += uint(bits.TrailingZeros(l))
		break
	}
	return t
}

// cmpGeq returns 1 if x >= y, and 0 otherwise.
//
// Both operands must have the same announced length.
//
//go:norace
func (x *Nat) cmpGeq(y *Nat) choice {
	// Eliminate bounds checks in the loop.
	size := len(x.limbs)
	xLimbs := x.limbs[:size]
	yLimbs := y.limbs[:size]

	// Compute the borrow of x - y without storing the difference.
	var c uint
	for i := 0; i < size; i++ {
		_, c = bits.Sub(xLimbs[i], yLimbs[i], c)
	}
	// If there was a carry, then subtracting y underflowed, so
	// x is not greater than or equal to y.
	return not(choice(c))
}

// assign sets x <- y if on == 1, and does nothing otherwise.
//
// Both operands must have the same announced length.
//
//go:norace
func (x *Nat) assign(on choice, y *Nat) *Nat {
	// Eliminate bounds checks in the loop.
	size := len(x.limbs)
	xLimbs := x.limbs[:size]
	yLimbs := y.limbs[:size]

	// With an all-ones mask each limb becomes y's; with a zero mask it is
	// unchanged. Every limb is written either way.
	mask := ctMask(on)
	for i := 0; i < size; i++ {
		xLimbs[i] ^= mask & (xLimbs[i] ^ yLimbs[i])
	}
	return x
}

// add computes x += y and returns the carry.
//
// Both operands must have the same announced length.
//
//go:norace
func (x *Nat) add(y *Nat) (c uint) {
	// Eliminate bounds checks in the loop.
	size := len(x.limbs)
	xLimbs := x.limbs[:size]
	yLimbs := y.limbs[:size]

	for i := 0; i < size; i++ {
		xLimbs[i], c = bits.Add(xLimbs[i], yLimbs[i], c)
	}
	return
}

// sub computes x -= y. It returns the borrow of the subtraction.
//
// Both operands must have the same announced length.
//
//go:norace
func (x *Nat) sub(y *Nat) (c uint) {
	// Eliminate bounds checks in the loop.
	size := len(x.limbs)
	xLimbs := x.limbs[:size]
	yLimbs := y.limbs[:size]

	for i := 0; i < size; i++ {
		xLimbs[i], c = bits.Sub(xLimbs[i], yLimbs[i], c)
	}
	return
}

// ShiftRightVarTime sets x = x >> n.
//
// The announced length of x is unchanged.
//
//go:norace
func (x *Nat) ShiftRightVarTime(n uint) *Nat {
	// Eliminate bounds checks in the loop.
	size := len(x.limbs)
	xLimbs := x.limbs[:size]

	shift := int(n % _W)
	shiftLimbs := int(n / _W)

	// shiftedLimbs aliases xLimbs. This is safe because iteration i reads only
	// indices i+shiftLimbs and i+shiftLimbs+1, which are >= i and therefore
	// not yet overwritten. If n shifts out every limb it stays nil.
	var shiftedLimbs []uint
	if shiftLimbs < size {
		shiftedLimbs = xLimbs[shiftLimbs:]
	}

	for i := range xLimbs {
		if i >= len(shiftedLimbs) {
			xLimbs[i] = 0
			continue
		}

		xLimbs[i] = shiftedLimbs[i] >> shift
		if i+1 < len(shiftedLimbs) {
			// When shift is 0 this shifts by _W, which yields 0 in Go.
			xLimbs[i] |= shiftedLimbs[i+1] << (_W - shift)
		}
	}

	return x
}

// BitLenVarTime returns the actual size of x in bits.
//
// The actual size of x (but nothing more) leaks through timing side-channels.
// Note that this is ordinarily secret, as opposed to the announced size of x.
func (x *Nat) BitLenVarTime() int {
	// Eliminate bounds checks in the loop.
	size := len(x.limbs)
	xLimbs := x.limbs[:size]

	for i := size - 1; i >= 0; i-- {
		if xLimbs[i] != 0 {
			return i*_W + bitLen(xLimbs[i])
		}
	}
	return 0
}

// bitLen is a version of bits.Len that only leaks the bit length of n, but not
// its value. bits.Len and bits.LeadingZeros use a lookup table for the
// low-order bits on some architectures.
func bitLen(n uint) int {
	len := 0
	// We assume, here and elsewhere, that comparison to zero is constant time
	// with respect to different non-zero values.
	for n != 0 {
		len++
		n >>= 1
	}
	return len
}

// Modulus is used for modular arithmetic, precomputing relevant constants.
//
// A Modulus can leak the exact number of bits needed to store its value
// and is stored without padding. Its actual value is still kept secret.
type Modulus struct {
	// The underlying natural number for this modulus.
	//
	// This will be stored without any padding, and shouldn't alias with any
	// other natural number being used.
	nat *Nat

	// If m is even, the following fields are not set.
	odd   bool
	m0inv uint // -nat.limbs[0]⁻¹ mod 2^_W
	rr    *Nat // R*R for montgomeryRepresentation
}

// rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs).
//
// It uses montgomeryMul, which reads m.m0inv, so m.m0inv must be set first.
func rr(m *Modulus) *Nat {
	rr := NewNat().ExpandFor(m)
	n := uint(len(rr.limbs))
	mLen := uint(m.BitLen())
	logR := _W * n

	// We start by computing R = 2^(_W * n) mod m. We can get pretty close, to
	// 2^⌊log₂m⌋, by setting the highest bit we can without having to reduce.
	rr.limbs[n-1] = 1 << ((mLen - 1) % _W)
	// Then we double until we reach 2^(_W * n).
	for i := mLen - 1; i < logR; i++ {
		rr.Add(rr, m)
	}

	// Next we need to get from R to 2^(_W * n) R mod m (aka from one to R in
	// the Montgomery domain, meaning we can use Montgomery multiplication now).
	// We could do that by doubling _W * n times, or with a square-and-double
	// chain log2(_W * n) long. Turns out the fastest thing is to start out with
	// doublings, and switch to square-and-double once the exponent is large
	// enough to justify the cost of the multiplications.

	// The threshold is selected experimentally as a linear function of n.
	threshold := n / 4

	// We calculate how many of the most-significant bits of the exponent we can
	// compute before crossing the threshold, and we do it with doublings.
	i := bits.UintSize
	for logR>>i <= threshold {
		i--
	}
	for k := uint(0); k < logR>>i; k++ {
		rr.Add(rr, m)
	}

	// Then we process the remaining bits of the exponent with a
	// square-and-double chain.
	for i > 0 {
		rr.montgomeryMul(rr, rr, m)
		i--
		if logR>>i&1 != 0 {
			rr.Add(rr, m)
		}
	}

	return rr
}

// minusInverseModW computes -x⁻¹ mod 2^_W with x odd.
//
// This operation is used to precompute a constant involved in Montgomery
// multiplication.
524 func minusInverseModW(x uint) uint { 525 // Every iteration of this loop doubles the least-significant bits of 526 // correct inverse in y. The first three bits are already correct (1⁻¹ = 1, 527 // 3⁻¹ = 3, 5⁻¹ = 5, and 7⁻¹ = 7 mod 8), so doubling five times is enough 528 // for 64 bits (and wastes only one iteration for 32 bits). 529 // 530 // See https://crypto.stackexchange.com/a/47496. 531 y := x 532 for i := 0; i < 5; i++ { 533 y = y * (2 - x*y) 534 } 535 return -y 536 } 537 538 // NewModulus creates a new Modulus from a slice of big-endian bytes. The 539 // modulus must be greater than one. 540 // 541 // The number of significant bits and whether the modulus is even is leaked 542 // through timing side-channels. 543 func NewModulus(b []byte) (*Modulus, error) { 544 n := NewNat().resetToBytes(b) 545 return newModulus(n) 546 } 547 548 // NewModulusProduct creates a new Modulus from the product of two numbers 549 // represented as big-endian byte slices. The result must be greater than one. 550 // 551 //go:norace 552 func NewModulusProduct(a, b []byte) (*Modulus, error) { 553 x := NewNat().resetToBytes(a) 554 y := NewNat().resetToBytes(b) 555 n := NewNat().reset(len(x.limbs) + len(y.limbs)) 556 for i := range y.limbs { 557 n.limbs[i+len(x.limbs)] = addMulVVW(n.limbs[i:i+len(x.limbs)], x.limbs, y.limbs[i]) 558 } 559 return newModulus(n.trim()) 560 } 561 562 func newModulus(n *Nat) (*Modulus, error) { 563 m := &Modulus{nat: n} 564 if m.nat.IsZero() == yes || m.nat.IsOne() == yes { 565 return nil, errors.New("modulus must be > 1") 566 } 567 if m.nat.IsOdd() == 1 { 568 m.odd = true 569 m.m0inv = minusInverseModW(m.nat.limbs[0]) 570 m.rr = rr(m) 571 } 572 return m, nil 573 } 574 575 // Size returns the size of m in bytes. 576 func (m *Modulus) Size() int { 577 return (m.BitLen() + 7) / 8 578 } 579 580 // BitLen returns the size of m in bits. 581 func (m *Modulus) BitLen() int { 582 return m.nat.BitLenVarTime() 583 } 584 585 // Nat returns m as a Nat. 
586 func (m *Modulus) Nat() *Nat { 587 // Make a copy so that the caller can't modify m.nat or alias it with 588 // another Nat in a modulus operation. 589 n := NewNat() 590 n.set(m.nat) 591 return n 592 } 593 594 // shiftIn calculates x = x << _W + y mod m. 595 // 596 // This assumes that x is already reduced mod m. 597 // 598 //go:norace 599 func (x *Nat) shiftIn(y uint, m *Modulus) *Nat { 600 d := NewNat().resetFor(m) 601 602 // Eliminate bounds checks in the loop. 603 size := len(m.nat.limbs) 604 xLimbs := x.limbs[:size] 605 dLimbs := d.limbs[:size] 606 mLimbs := m.nat.limbs[:size] 607 608 // Each iteration of this loop computes x = 2x + b mod m, where b is a bit 609 // from y. Effectively, it left-shifts x and adds y one bit at a time, 610 // reducing it every time. 611 // 612 // To do the reduction, each iteration computes both 2x + b and 2x + b - m. 613 // The next iteration (and finally the return line) will use either result 614 // based on whether 2x + b overflows m. 615 needSubtraction := no 616 for i := _W - 1; i >= 0; i-- { 617 carry := (y >> i) & 1 618 var borrow uint 619 mask := ctMask(needSubtraction) 620 for i := 0; i < size; i++ { 621 l := xLimbs[i] ^ (mask & (xLimbs[i] ^ dLimbs[i])) 622 xLimbs[i], carry = bits.Add(l, l, carry) 623 dLimbs[i], borrow = bits.Sub(xLimbs[i], mLimbs[i], borrow) 624 } 625 // Like in maybeSubtractModulus, we need the subtraction if either it 626 // didn't underflow (meaning 2x + b > m) or if computing 2x + b 627 // overflowed (meaning 2x + b > 2^_W*n > m). 628 needSubtraction = not(choice(borrow)) | choice(carry) 629 } 630 return x.assign(needSubtraction, d) 631 } 632 633 // Mod calculates out = x mod m. 634 // 635 // This works regardless how large the value of x is. 636 // 637 // The output will be resized to the size of m and overwritten. 
//
// out must not alias x: resetFor zeroes out's limbs before any limb of x is
// read.
//
//go:norace
func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
	out.resetFor(m)
	// Working our way from the most significant to the least significant limb,
	// we can insert each limb at the least significant position, shifting all
	// previous limbs left by _W. This way each limb will get shifted by the
	// correct number of bits. We can insert at least N - 1 limbs without
	// overflowing m. After that, we need to reduce every time we shift.
	i := len(x.limbs) - 1
	// For the first N - 1 limbs we can skip the actual shifting and position
	// them at the shifted position, which starts at min(N - 2, i).
	start := len(m.nat.limbs) - 2
	if i < start {
		start = i
	}
	for j := start; j >= 0; j-- {
		out.limbs[j] = x.limbs[i]
		i--
	}
	// We shift in the remaining limbs, reducing modulo m each time.
	for i >= 0 {
		out.shiftIn(x.limbs[i], m)
		i--
	}
	return out
}

// ExpandFor ensures x has the right size to work with operations modulo m.
//
// The announced size of x must be smaller than or equal to that of m.
func (x *Nat) ExpandFor(m *Modulus) *Nat {
	return x.expand(len(m.nat.limbs))
}

// resetFor ensures out has the right size to work with operations modulo m.
//
// out is zeroed and may start at any size.
func (out *Nat) resetFor(m *Modulus) *Nat {
	return out.reset(len(m.nat.limbs))
}

// maybeSubtractModulus computes x -= m if and only if x >= m or if "always" is yes.
//
// It can be used to reduce modulo m a value up to 2m - 1, which is a common
// range for results computed by higher level operations.
//
// always is usually a carry that indicates that the operation that produced x
// overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m.
//
// x and m operands must have the same announced length.
//
//go:norace
func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
	// Always compute x - m into a scratch copy, then select it in constant
	// time, so the work done does not depend on whether it is kept.
	t := NewNat().set(x)
	underflow := t.sub(m.nat)
	// We keep the result if x - m didn't underflow (meaning x >= m)
	// or if always was set.
	keep := not(choice(underflow)) | choice(always)
	x.assign(keep, t)
}

// Sub computes x = x - y mod m.
//
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
//
//go:norace
func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
	underflow := x.sub(y)
	// If the subtraction underflowed, add m.
	t := NewNat().set(x)
	t.add(m.nat)
	x.assign(choice(underflow), t)
	return x
}

// SubOne computes x = x - 1 mod m.
//
// The length of x must be the same as the modulus.
func (x *Nat) SubOne(m *Modulus) *Nat {
	one := NewNat().ExpandFor(m)
	one.limbs[0] = 1
	// Sub asks for x to be reduced modulo m, while SubOne doesn't, but when
	// y = 1, it works, and this is an internal use.
	return x.Sub(one, m)
}

// Add computes x = x + y mod m.
//
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
//
//go:norace
func (x *Nat) Add(y *Nat, m *Modulus) *Nat {
	overflow := x.add(y)
	x.maybeSubtractModulus(choice(overflow), m)
	return x
}

// montgomeryRepresentation calculates x = x * R mod m, with R = 2^(_W * n) and
// n = len(m.nat.limbs).
//
// Faster Montgomery multiplication replaces standard modular multiplication for
// numbers in this representation.
//
// This assumes that x is already reduced mod m.
func (x *Nat) montgomeryRepresentation(m *Modulus) *Nat {
	// A Montgomery multiplication (which computes a * b / R) by R * R works out
	// to a multiplication by R, which takes the value into the Montgomery domain.
	return x.montgomeryMul(x, m.rr, m)
}

// montgomeryReduction calculates x = x / R mod m, with R = 2^(_W * n) and
// n = len(m.nat.limbs).
//
// This assumes that x is already reduced mod m.
func (x *Nat) montgomeryReduction(m *Modulus) *Nat {
	// By Montgomery multiplying with 1 not in Montgomery representation, we
	// convert x back from Montgomery representation, because it works out to
	// dividing by R.
	one := NewNat().ExpandFor(m)
	one.limbs[0] = 1
	return x.montgomeryMul(x, one, m)
}

// montgomeryMul calculates x = a * b / R mod m, with R = 2^(_W * n) and
// n = len(m.nat.limbs), also known as a Montgomery multiplication.
//
// All inputs should be the same length and already reduced modulo m.
// x will be resized to the size of m and overwritten.
//
// x may alias a or b: in every branch x is only reset after a and b have been
// fully read into the scratch buffer T.
//
//go:norace
func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
	n := len(m.nat.limbs)
	mLimbs := m.nat.limbs[:n]
	aLimbs := a.limbs[:n]
	bLimbs := b.limbs[:n]

	switch n {
	default:
		// Attempt to use a stack-allocated backing array.
		T := make([]uint, 0, preallocLimbs*2)
		if cap(T) < n*2 {
			T = make([]uint, 0, n*2)
		}
		T = T[:n*2]

		// This loop implements Word-by-Word Montgomery Multiplication, as
		// described in Algorithm 4 (Fig. 3) of "Efficient Software
		// Implementations of Modular Exponentiation" by Shay Gueron
		// [https://eprint.iacr.org/2011/239.pdf].
		var c uint
		for i := 0; i < n; i++ {
			_ = T[n+i] // bounds check elimination hint

			// Step 1 (T = a × b) is computed as a large pen-and-paper column
			// multiplication of two numbers with n base-2^_W digits. If we just
			// wanted to produce 2n-wide T, we would do
			//
			//	for i := 0; i < n; i++ {
			//		d := bLimbs[i]
			//		T[n+i] = addMulVVW(T[i:n+i], aLimbs, d)
			//	}
			//
			// where d is a digit of the multiplier, T[i:n+i] is the shifted
			// position of the product of that digit, and T[n+i] is the final carry.
			// Note that T[i] isn't modified after processing the i-th digit.
			//
			// Instead of running two loops, one for Step 1 and one for Steps 2–6,
			// the result of Step 1 is computed during the next loop. This is
			// possible because each iteration only uses T[i] in Step 2 and then
			// discards it in Step 6.
			d := bLimbs[i]
			c1 := addMulVVW(T[i:n+i], aLimbs, d)

			// Step 6 is replaced by shifting the virtual window we operate
			// over: T of the algorithm is T[i:] for us. That means that T1 in
			// Step 2 (T mod 2^_W) is simply T[i]. k0 in Step 3 is our m0inv.
			Y := T[i] * m.m0inv

			// Step 4 and 5 add Y × m to T, which as mentioned above is stored
			// at T[i:]. The two carries (from a × d and Y × m) are added up in
			// the next word T[n+i], and the carry bit from that addition is
			// brought forward to the next iteration.
			c2 := addMulVVW(T[i:n+i], mLimbs, Y)
			T[n+i], c = bits.Add(c1, c2, c)
		}

		// Finally for Step 7 we copy the final T window into x, and subtract m
		// if necessary (which as explained in maybeSubtractModulus can be the
		// case both if x >= m, or if x overflowed).
		//
		// The paper suggests in Section 4 that we can do an "Almost Montgomery
		// Multiplication" by subtracting only in the overflow case, but the
		// cost is very similar since the constant time subtraction tells us if
		// x >= m as a side effect, and taking care of the broken invariant is
		// highly undesirable (see https://go.dev/issue/13907).
		copy(x.reset(n).limbs, T[n:])
		x.maybeSubtractModulus(choice(c), m)

	// The following specialized cases follow the exact same algorithm, but
	// optimized for the sizes most used in RSA. addMulVVW is implemented in
	// assembly with loop unrolling depending on the architecture and bounds
	// checks are removed by the compiler thanks to the constant size.
	case 1024 / _W:
		const n = 1024 / _W // compiler hint
		T := make([]uint, n*2)
		var c uint
		for i := 0; i < n; i++ {
			d := bLimbs[i]
			c1 := addMulVVW1024(&T[i], &aLimbs[0], d)
			Y := T[i] * m.m0inv
			c2 := addMulVVW1024(&T[i], &mLimbs[0], Y)
			T[n+i], c = bits.Add(c1, c2, c)
		}
		copy(x.reset(n).limbs, T[n:])
		x.maybeSubtractModulus(choice(c), m)

	case 1536 / _W:
		const n = 1536 / _W // compiler hint
		T := make([]uint, n*2)
		var c uint
		for i := 0; i < n; i++ {
			d := bLimbs[i]
			c1 := addMulVVW1536(&T[i], &aLimbs[0], d)
			Y := T[i] * m.m0inv
			c2 := addMulVVW1536(&T[i], &mLimbs[0], Y)
			T[n+i], c = bits.Add(c1, c2, c)
		}
		copy(x.reset(n).limbs, T[n:])
		x.maybeSubtractModulus(choice(c), m)

	case 2048 / _W:
		const n = 2048 / _W // compiler hint
		T := make([]uint, n*2)
		var c uint
		for i := 0; i < n; i++ {
			d := bLimbs[i]
			c1 := addMulVVW2048(&T[i], &aLimbs[0], d)
			Y := T[i] * m.m0inv
			c2 := addMulVVW2048(&T[i], &mLimbs[0], Y)
			T[n+i], c = bits.Add(c1, c2, c)
		}
		copy(x.reset(n).limbs, T[n:])
		x.maybeSubtractModulus(choice(c), m)
	}

	return x
}

// addMulVVW multiplies the multi-word value x by the single-word value y,
// adding the result to the multi-word value z and returning the final carry.
// It can be thought of as one row of a pen-and-paper column multiplication.
892 // 893 //go:norace 894 func addMulVVW(z, x []uint, y uint) (carry uint) { 895 _ = x[len(z)-1] // bounds check elimination hint 896 for i := range z { 897 hi, lo := bits.Mul(x[i], y) 898 lo, c := bits.Add(lo, z[i], 0) 899 // We use bits.Add with zero to get an add-with-carry instruction that 900 // absorbs the carry from the previous bits.Add. 901 hi, _ = bits.Add(hi, 0, c) 902 lo, c = bits.Add(lo, carry, 0) 903 hi, _ = bits.Add(hi, 0, c) 904 carry = hi 905 z[i] = lo 906 } 907 return carry 908 } 909 910 // Mul calculates x = x * y mod m. 911 // 912 // The length of both operands must be the same as the modulus. Both operands 913 // must already be reduced modulo m. 914 // 915 //go:norace 916 func (x *Nat) Mul(y *Nat, m *Modulus) *Nat { 917 if m.odd { 918 // A Montgomery multiplication by a value out of the Montgomery domain 919 // takes the result out of Montgomery representation. 920 xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m 921 return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m 922 } 923 924 n := len(m.nat.limbs) 925 xLimbs := x.limbs[:n] 926 yLimbs := y.limbs[:n] 927 928 switch n { 929 default: 930 // Attempt to use a stack-allocated backing array. 931 T := make([]uint, 0, preallocLimbs*2) 932 if cap(T) < n*2 { 933 T = make([]uint, 0, n*2) 934 } 935 T = T[:n*2] 936 937 // T = x * y 938 for i := 0; i < n; i++ { 939 T[n+i] = addMulVVW(T[i:n+i], xLimbs, yLimbs[i]) 940 } 941 942 // x = T mod m 943 return x.Mod(&Nat{limbs: T}, m) 944 945 // The following specialized cases follow the exact same algorithm, but 946 // optimized for the sizes most used in RSA. See montgomeryMul for details. 
947 case 1024 / _W: 948 const n = 1024 / _W // compiler hint 949 T := make([]uint, n*2) 950 for i := 0; i < n; i++ { 951 T[n+i] = addMulVVW1024(&T[i], &xLimbs[0], yLimbs[i]) 952 } 953 return x.Mod(&Nat{limbs: T}, m) 954 case 1536 / _W: 955 const n = 1536 / _W // compiler hint 956 T := make([]uint, n*2) 957 for i := 0; i < n; i++ { 958 T[n+i] = addMulVVW1536(&T[i], &xLimbs[0], yLimbs[i]) 959 } 960 return x.Mod(&Nat{limbs: T}, m) 961 case 2048 / _W: 962 const n = 2048 / _W // compiler hint 963 T := make([]uint, n*2) 964 for i := 0; i < n; i++ { 965 T[n+i] = addMulVVW2048(&T[i], &xLimbs[0], yLimbs[i]) 966 } 967 return x.Mod(&Nat{limbs: T}, m) 968 } 969 } 970 971 // Exp calculates out = x^e mod m. 972 // 973 // The exponent e is represented in big-endian order. The output will be resized 974 // to the size of m and overwritten. x must already be reduced modulo m. 975 // 976 // m must be odd, or Exp will panic. 977 // 978 //go:norace 979 func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat { 980 if !m.odd { 981 panic("bigmod: modulus for Exp must be odd") 982 } 983 984 // We use a 4 bit window. For our RSA workload, 4 bit windows are faster 985 // than 2 bit windows, but use an extra 12 nats worth of scratch space. 986 // Using bit sizes that don't divide 8 are more complex to implement, but 987 // are likely to be more efficient if necessary. 988 989 table := [(1 << 4) - 1]*Nat{ // table[i] = x ^ (i+1) 990 // newNat calls are unrolled so they are allocated on the stack. 
991 NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), 992 NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), 993 NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), 994 } 995 table[0].set(x).montgomeryRepresentation(m) 996 for i := 1; i < len(table); i++ { 997 table[i].montgomeryMul(table[i-1], table[0], m) 998 } 999 1000 out.resetFor(m) 1001 out.limbs[0] = 1 1002 out.montgomeryRepresentation(m) 1003 tmp := NewNat().ExpandFor(m) 1004 for _, b := range e { 1005 for _, j := range []int{4, 0} { 1006 // Square four times. Optimization note: this can be implemented 1007 // more efficiently than with generic Montgomery multiplication. 1008 out.montgomeryMul(out, out, m) 1009 out.montgomeryMul(out, out, m) 1010 out.montgomeryMul(out, out, m) 1011 out.montgomeryMul(out, out, m) 1012 1013 // Select x^k in constant time from the table. 1014 k := uint((b >> j) & 0b1111) 1015 for i := range table { 1016 tmp.assign(ctEq(k, uint(i+1)), table[i]) 1017 } 1018 1019 // Multiply by x^k, discarding the result if k = 0. 1020 tmp.montgomeryMul(out, tmp, m) 1021 out.assign(not(ctEq(k, 0)), tmp) 1022 } 1023 } 1024 1025 return out.montgomeryReduction(m) 1026 } 1027 1028 // ExpShortVarTime calculates out = x^e mod m. 1029 // 1030 // The output will be resized to the size of m and overwritten. x must already 1031 // be reduced modulo m. This leaks the exponent through timing side-channels. 1032 // 1033 // m must be odd, or ExpShortVarTime will panic. 1034 func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat { 1035 if !m.odd { 1036 panic("bigmod: modulus for ExpShortVarTime must be odd") 1037 } 1038 // For short exponents, precomputing a table and using a window like in Exp 1039 // doesn't pay off. Instead, we do a simple conditional square-and-multiply 1040 // chain, skipping the initial run of zeroes. 
1041 xR := NewNat().set(x).montgomeryRepresentation(m) 1042 out.set(xR) 1043 for i := bits.UintSize - bits.Len(e) + 1; i < bits.UintSize; i++ { 1044 out.montgomeryMul(out, out, m) 1045 if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 { 1046 out.montgomeryMul(out, xR, m) 1047 } 1048 } 1049 return out.montgomeryReduction(m) 1050 } 1051 1052 // InverseVarTime calculates x = a⁻¹ mod m and returns (x, true) if a is 1053 // invertible. Otherwise, InverseVarTime returns (x, false) and x is not 1054 // modified. 1055 // 1056 // a must be reduced modulo m, but doesn't need to have the same size. The 1057 // output will be resized to the size of m and overwritten. 1058 // 1059 //go:norace 1060 func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) { 1061 // This is the extended binary GCD algorithm described in the Handbook of 1062 // Applied Cryptography, Algorithm 14.61, adapted by BoringSSL to bound 1063 // coefficients and avoid negative numbers. For more details and proof of 1064 // correctness, see https://github.com/mit-plv/fiat-crypto/pull/333/files. 1065 // 1066 // Following the proof linked in the PR above, the changes are: 1067 // 1068 // 1. Negate [B] and [C] so they are positive. The invariant now involves a 1069 // subtraction. 1070 // 2. If step 2 (both [x] and [y] are even) runs, abort immediately. This 1071 // algorithm only cares about [x] and [y] relatively prime. 1072 // 3. Subtract copies of [x] and [y] as needed in step 6 (both [u] and [v] 1073 // are odd) so coefficients stay in bounds. 1074 // 4. Replace the [u >= v] check with [u > v]. This changes the end 1075 // condition to [v = 0] rather than [u = 0]. This saves an extra 1076 // subtraction due to which coefficients were negated. 1077 // 5. Rename x and y to a and n, to capture that one is a modulus. 1078 // 6. Rearrange steps 4 through 6 slightly. 
Merge the loops in steps 4 and 1079 // 5 into the main loop (step 7's goto), and move step 6 to the start of 1080 // the loop iteration, ensuring each loop iteration halves at least one 1081 // value. 1082 // 1083 // Note this algorithm does not handle either input being zero. 1084 1085 if a.IsZero() == yes { 1086 return x, false 1087 } 1088 if a.IsOdd() == no && !m.odd { 1089 // a and m are not coprime, as they are both even. 1090 return x, false 1091 } 1092 1093 u := NewNat().set(a).ExpandFor(m) 1094 v := m.Nat() 1095 1096 A := NewNat().reset(len(m.nat.limbs)) 1097 A.limbs[0] = 1 1098 B := NewNat().reset(len(a.limbs)) 1099 C := NewNat().reset(len(m.nat.limbs)) 1100 D := NewNat().reset(len(a.limbs)) 1101 D.limbs[0] = 1 1102 1103 // Before and after each loop iteration, the following hold: 1104 // 1105 // u = A*a - B*m 1106 // v = D*m - C*a 1107 // 0 < u <= a 1108 // 0 <= v <= m 1109 // 0 <= A < m 1110 // 0 <= B <= a 1111 // 0 <= C < m 1112 // 0 <= D <= a 1113 // 1114 // After each loop iteration, u and v only get smaller, and at least one of 1115 // them shrinks by at least a factor of two. 1116 for { 1117 // If both u and v are odd, subtract the smaller from the larger. 1118 // If u = v, we need to subtract from v to hit the modified exit condition. 1119 if u.IsOdd() == yes && v.IsOdd() == yes { 1120 if v.cmpGeq(u) == no { 1121 u.sub(v) 1122 A.Add(C, m) 1123 B.Add(D, &Modulus{nat: a}) 1124 } else { 1125 v.sub(u) 1126 C.Add(A, m) 1127 D.Add(B, &Modulus{nat: a}) 1128 } 1129 } 1130 1131 // Exactly one of u and v is now even. 1132 if u.IsOdd() == v.IsOdd() { 1133 panic("bigmod: internal error: u and v are not in the expected state") 1134 } 1135 1136 // Halve the even one and adjust the corresponding coefficient. 
1137 if u.IsOdd() == no { 1138 rshift1(u, 0) 1139 if A.IsOdd() == yes || B.IsOdd() == yes { 1140 rshift1(A, A.add(m.nat)) 1141 rshift1(B, B.add(a)) 1142 } else { 1143 rshift1(A, 0) 1144 rshift1(B, 0) 1145 } 1146 } else { // v.IsOdd() == no 1147 rshift1(v, 0) 1148 if C.IsOdd() == yes || D.IsOdd() == yes { 1149 rshift1(C, C.add(m.nat)) 1150 rshift1(D, D.add(a)) 1151 } else { 1152 rshift1(C, 0) 1153 rshift1(D, 0) 1154 } 1155 } 1156 1157 if v.IsZero() == yes { 1158 if u.IsOne() == no { 1159 return x, false 1160 } 1161 return x.set(A), true 1162 } 1163 } 1164 } 1165 1166 //go:norace 1167 func rshift1(a *Nat, carry uint) { 1168 size := len(a.limbs) 1169 aLimbs := a.limbs[:size] 1170 1171 for i := range size { 1172 aLimbs[i] >>= 1 1173 if i+1 < size { 1174 aLimbs[i] |= aLimbs[i+1] << (_W - 1) 1175 } else { 1176 aLimbs[i] |= carry << (_W - 1) 1177 } 1178 } 1179 } 1180