// Copyright 2021 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package bigmod import ( "bufio" "bytes" cryptorand "crypto/rand" "encoding/hex" "fmt" "math/big" "math/bits" "math/rand" "os" "reflect" "slices" "strings" "testing" "testing/quick" ) // setBig assigns x = n, optionally resizing n to the appropriate size. // // The announced length of x is set based on the actual bit size of the input, // ignoring leading zeroes. func (x *Nat) setBig(n *big.Int) *Nat { limbs := n.Bits() x.reset(len(limbs)) for i := range limbs { x.limbs[i] = uint(limbs[i]) } return x } func (n *Nat) asBig() *big.Int { bits := make([]big.Word, len(n.limbs)) for i := range n.limbs { bits[i] = big.Word(n.limbs[i]) } return new(big.Int).SetBits(bits) } func (n *Nat) String() string { var limbs []string for i := range n.limbs { limbs = append(limbs, fmt.Sprintf("%016X", n.limbs[len(n.limbs)-1-i])) } return "{" + strings.Join(limbs, " ") + "}" } // Generate generates an even nat. It's used by testing/quick to produce random // *nat values for quick.Check invocations. func (*Nat) Generate(r *rand.Rand, size int) reflect.Value { limbs := make([]uint, size) for i := 0; i < size; i++ { limbs[i] = uint(r.Uint64()) & ((1 << _W) - 2) } return reflect.ValueOf(&Nat{limbs}) } func testModAddCommutative(a *Nat, b *Nat) bool { m := maxModulus(uint(len(a.limbs))) aPlusB := new(Nat).set(a) aPlusB.Add(b, m) bPlusA := new(Nat).set(b) bPlusA.Add(a, m) return aPlusB.Equal(bPlusA) == 1 } func TestModAddCommutative(t *testing.T) { err := quick.Check(testModAddCommutative, &quick.Config{}) if err != nil { t.Error(err) } } func testModSubThenAddIdentity(a *Nat, b *Nat) bool { m := maxModulus(uint(len(a.limbs))) original := new(Nat).set(a) a.Sub(b, m) a.Add(b, m) return a.Equal(original) == 1 } func TestModSubThenAddIdentity(t *testing.T) { err := quick.Check(testModSubThenAddIdentity, &quick.Config{}) if err != nil { t.Error(err) } } func TestMontgomeryRoundtrip(t *testing.T) { err := quick.Check(func(a *Nat) bool { one := &Nat{make([]uint, len(a.limbs))} one.limbs[0] = 1 aPlusOne := new(big.Int).SetBytes(natBytes(a)) aPlusOne.Add(aPlusOne, big.NewInt(1)) m, _ := NewModulus(aPlusOne.Bytes()) monty := new(Nat).set(a) monty.montgomeryRepresentation(m) aAgain := new(Nat).set(monty) aAgain.montgomeryMul(monty, one, m) if a.Equal(aAgain) != 1 { t.Errorf("%v != %v", a, aAgain) return false } return true }, &quick.Config{}) if err != nil { t.Error(err) } } func TestShiftIn(t *testing.T) { if bits.UintSize != 64 { t.Skip("examples are only valid in 64 bit") } examples := []struct { m, x, expected []byte y uint64 }{{ m: []byte{13}, x: []byte{0}, y: 0xFFFF_FFFF_FFFF_FFFF, expected: []byte{2}, }, { m: []byte{13}, x: []byte{7}, y: 0xFFFF_FFFF_FFFF_FFFF, expected: []byte{10}, }, { m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}, x: make([]byte, 9), y: 0xFFFF_FFFF_FFFF_FFFF, expected: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, }, { m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}, x: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, y: 0, expected: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06}, }} for i, tt := range examples { m := modulusFromBytes(tt.m) got := natFromBytes(tt.x).ExpandFor(m).shiftIn(uint(tt.y), m) if exp := natFromBytes(tt.expected).ExpandFor(m); got.Equal(exp) != 1 { t.Errorf("%d: got %v, expected %v", i, got, exp) } } } func TestModulusAndNatSizes(t *testing.T) { // These are 126 bit (2 * _W on 64-bit architectures) values, serialized as // 128 bits worth of bytes. If leading zeroes are stripped, they fit in two // limbs, if they are not, they fit in three. This can be a problem because // modulus strips leading zeroes and nat does not. m := modulusFromBytes([]byte{ 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) xb := []byte{0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe} natFromBytes(xb).ExpandFor(m) // must not panic for shrinking NewNat().SetBytes(xb, m) } func TestSetBytes(t *testing.T) { tests := []struct { m, b []byte fail bool }{{ m: []byte{0xff, 0xff}, b: []byte{0x00, 0x01}, }, { m: []byte{0xff, 0xff}, b: []byte{0xff, 0xff}, fail: true, }, { m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, b: []byte{0x00, 0x01}, }, { m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, }, { m: []byte{0xff, 0xff}, b: []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, fail: true, }, { m: []byte{0xff, 0xff}, b: []byte{0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, fail: true, }, { m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, }, { m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, fail: true, }, { m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, fail: true, }, { m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, fail: true, }, { m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfd}, b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, fail: true, }} for i, tt := range tests { m := modulusFromBytes(tt.m) got, err := NewNat().SetBytes(tt.b, m) if err != nil { if !tt.fail { t.Errorf("%d: unexpected error: %v", i, err) } continue } if tt.fail { t.Errorf("%d: unexpected success", i) continue } if expected := natFromBytes(tt.b).ExpandFor(m); got.Equal(expected) != yes { t.Errorf("%d: got %v, expected %v", i, got, expected) } } f := func(xBytes []byte) bool { m := maxModulus(uint(len(xBytes)*8/_W + 1)) got, err := NewNat().SetBytes(xBytes, m) if err != nil { return false } return got.Equal(natFromBytes(xBytes).ExpandFor(m)) == yes } err := quick.Check(f, &quick.Config{}) if err != nil { t.Error(err) } } func TestExpand(t *testing.T) { sliced := []uint{1, 2, 3, 4} examples := []struct { in []uint n int out []uint }{{ []uint{1, 2}, 4, []uint{1, 2, 0, 0}, }, { sliced[:2], 4, []uint{1, 2, 0, 0}, }, { []uint{1, 2}, 2, []uint{1, 2}, }} for i, tt := range examples { got := (&Nat{tt.in}).expand(tt.n) if len(got.limbs) != len(tt.out) || got.Equal(&Nat{tt.out}) != 1 { t.Errorf("%d: got %v, expected %v", i, got, tt.out) } } } func TestMod(t *testing.T) { m := modulusFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}) x := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}) out := new(Nat) out.Mod(x, m) expected := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09}) if out.Equal(expected) != 1 { t.Errorf("%+v != %+v", out, expected) } } func TestModSub(t *testing.T) { m := modulusFromBytes([]byte{13}) x := &Nat{[]uint{6}} y := &Nat{[]uint{7}} x.Sub(y, m) expected := &Nat{[]uint{12}} if x.Equal(expected) != 1 { t.Errorf("%+v != %+v", x, expected) } x.Sub(y, m) expected = &Nat{[]uint{5}} if x.Equal(expected) != 1 { t.Errorf("%+v != %+v", x, expected) } } func TestModAdd(t *testing.T) { m := modulusFromBytes([]byte{13}) x := &Nat{[]uint{6}} y := &Nat{[]uint{7}} x.Add(y, m) expected := &Nat{[]uint{0}} if x.Equal(expected) != 1 { t.Errorf("%+v != %+v", x, expected) } x.Add(y, m) expected = &Nat{[]uint{7}} if x.Equal(expected) != 1 { t.Errorf("%+v != %+v", x, expected) } } func TestExp(t *testing.T) { m := modulusFromBytes([]byte{13}) x := &Nat{[]uint{3}} out := &Nat{[]uint{0}} out.Exp(x, []byte{12}, m) expected := &Nat{[]uint{1}} if out.Equal(expected) != 1 { t.Errorf("%+v != %+v", out, expected) } } func TestExpShort(t *testing.T) { m := modulusFromBytes([]byte{13}) x := &Nat{[]uint{3}} out := &Nat{[]uint{0}} out.ExpShortVarTime(x, 12, m) expected := &Nat{[]uint{1}} if out.Equal(expected) != 1 { t.Errorf("%+v != %+v", out, expected) } } // TestMulReductions tests that Mul reduces results equal or slightly greater // than the modulus. Some Montgomery algorithms don't and need extra care to // return correct results. See https://go.dev/issue/13907. func TestMulReductions(t *testing.T) { // Two short but multi-limb primes. a, _ := new(big.Int).SetString("773608962677651230850240281261679752031633236267106044359907", 10) b, _ := new(big.Int).SetString("180692823610368451951102211649591374573781973061758082626801", 10) n := new(big.Int).Mul(a, b) N, _ := NewModulus(n.Bytes()) A := NewNat().setBig(a).ExpandFor(N) B := NewNat().setBig(b).ExpandFor(N) if A.Mul(B, N).IsZero() != 1 { t.Error("a * b mod (a * b) != 0") } i := new(big.Int).ModInverse(a, b) N, _ = NewModulus(b.Bytes()) A = NewNat().setBig(a).ExpandFor(N) I := NewNat().setBig(i).ExpandFor(N) one := NewNat().setBig(big.NewInt(1)).ExpandFor(N) if A.Mul(I, N).Equal(one) != 1 { t.Error("a * inv(a) mod b != 1") } } func TestMul(t *testing.T) { t.Run("small", func(t *testing.T) { testMul(t, 760/8) }) t.Run("1024", func(t *testing.T) { testMul(t, 1024/8) }) t.Run("1536", func(t *testing.T) { testMul(t, 1536/8) }) t.Run("2048", func(t *testing.T) { testMul(t, 2048/8) }) } func testMul(t *testing.T, n int) { a, b, m := make([]byte, n), make([]byte, n), make([]byte, n) cryptorand.Read(a) cryptorand.Read(b) cryptorand.Read(m) // Pick the highest as the modulus. if bytes.Compare(a, m) > 0 { a, m = m, a } if bytes.Compare(b, m) > 0 { b, m = m, b } M, err := NewModulus(m) if err != nil { t.Fatal(err) } A, err := NewNat().SetBytes(a, M) if err != nil { t.Fatal(err) } B, err := NewNat().SetBytes(b, M) if err != nil { t.Fatal(err) } A.Mul(B, M) ABytes := A.Bytes(M) mBig := new(big.Int).SetBytes(m) aBig := new(big.Int).SetBytes(a) bBig := new(big.Int).SetBytes(b) nBig := new(big.Int).Mul(aBig, bBig) nBig.Mod(nBig, mBig) nBigBytes := make([]byte, len(ABytes)) nBig.FillBytes(nBigBytes) if !bytes.Equal(ABytes, nBigBytes) { t.Errorf("got %x, want %x", ABytes, nBigBytes) } } func TestIs(t *testing.T) { checkYes := func(c choice, err string) { t.Helper() if c != yes { t.Error(err) } } checkNot := func(c choice, err string) { t.Helper() if c != no { t.Error(err) } } mFour := modulusFromBytes([]byte{4}) n, err := NewNat().SetBytes([]byte{3}, mFour) if err != nil { t.Fatal(err) } checkYes(n.IsMinusOne(mFour), "3 is not -1 mod 4") checkNot(n.IsZero(), "3 is zero") checkNot(n.IsOne(), "3 is one") checkYes(n.IsOdd(), "3 is not odd") n.SubOne(mFour) checkNot(n.IsMinusOne(mFour), "2 is -1 mod 4") checkNot(n.IsZero(), "2 is zero") checkNot(n.IsOne(), "2 is one") checkNot(n.IsOdd(), "2 is odd") n.SubOne(mFour) checkNot(n.IsMinusOne(mFour), "1 is -1 mod 4") checkNot(n.IsZero(), "1 is zero") checkYes(n.IsOne(), "1 is not one") checkYes(n.IsOdd(), "1 is not odd") n.SubOne(mFour) checkNot(n.IsMinusOne(mFour), "0 is -1 mod 4") checkYes(n.IsZero(), "0 is not zero") checkNot(n.IsOne(), "0 is one") checkNot(n.IsOdd(), "0 is odd") n.SubOne(mFour) checkYes(n.IsMinusOne(mFour), "-1 is not -1 mod 4") checkNot(n.IsZero(), "-1 is zero") checkNot(n.IsOne(), "-1 is one") checkYes(n.IsOdd(), "-1 mod 4 is not odd") mTwoLimbs := maxModulus(2) n, err = NewNat().SetBytes([]byte{0x01}, mTwoLimbs) if err != nil { t.Fatal(err) } if n.IsOne() != 1 { t.Errorf("1 is not one") } } func TestTrailingZeroBits(t *testing.T) { nb := new(big.Int).SetBytes([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7e}) nb.Lsh(nb, 128) expected := 129 for expected >= 0 { n := NewNat().setBig(nb) if n.TrailingZeroBitsVarTime() != uint(expected) { t.Errorf("%d != %d", n.TrailingZeroBitsVarTime(), expected) } nb.Rsh(nb, 1) expected-- } } func TestRightShift(t *testing.T) { nb, err := cryptorand.Int(cryptorand.Reader, new(big.Int).Lsh(big.NewInt(1), 1024)) if err != nil { t.Fatal(err) } for _, shift := range []uint{1, 32, 64, 128, 1024 - 128, 1024 - 64, 1024 - 32, 1024 - 1} { testShift := func(t *testing.T, shift uint) { n := NewNat().setBig(nb) oldLen := len(n.limbs) n.ShiftRightVarTime(shift) if len(n.limbs) != oldLen { t.Errorf("len(n.limbs) = %d, want %d", len(n.limbs), oldLen) } exp := new(big.Int).Rsh(nb, shift) if n.asBig().Cmp(exp) != 0 { t.Errorf("%v != %v", n.asBig(), exp) } } t.Run(fmt.Sprint(shift-1), func(t *testing.T) { testShift(t, shift-1) }) t.Run(fmt.Sprint(shift), func(t *testing.T) { testShift(t, shift) }) t.Run(fmt.Sprint(shift+1), func(t *testing.T) { testShift(t, shift+1) }) } } func natBytes(n *Nat) []byte { return n.Bytes(maxModulus(uint(len(n.limbs)))) } func natFromBytes(b []byte) *Nat { // Must not use Nat.SetBytes as it's used in TestSetBytes. bb := new(big.Int).SetBytes(b) return NewNat().setBig(bb) } func modulusFromBytes(b []byte) *Modulus { bb := new(big.Int).SetBytes(b) m, _ := NewModulus(bb.Bytes()) return m } // maxModulus returns the biggest modulus that can fit in n limbs. func maxModulus(n uint) *Modulus { b := big.NewInt(1) b.Lsh(b, n*_W) b.Sub(b, big.NewInt(1)) m, _ := NewModulus(b.Bytes()) return m } func makeBenchmarkModulus() *Modulus { return maxModulus(32) } func makeBenchmarkValue() *Nat { x := make([]uint, 32) for i := 0; i < 32; i++ { x[i]-- } return &Nat{limbs: x} } func makeBenchmarkExponent() []byte { e := make([]byte, 256) for i := 0; i < 32; i++ { e[i] = 0xFF } return e } func BenchmarkModAdd(b *testing.B) { x := makeBenchmarkValue() y := makeBenchmarkValue() m := makeBenchmarkModulus() b.ResetTimer() for i := 0; i < b.N; i++ { x.Add(y, m) } } func BenchmarkModSub(b *testing.B) { x := makeBenchmarkValue() y := makeBenchmarkValue() m := makeBenchmarkModulus() b.ResetTimer() for i := 0; i < b.N; i++ { x.Sub(y, m) } } func BenchmarkMontgomeryRepr(b *testing.B) { x := makeBenchmarkValue() m := makeBenchmarkModulus() b.ResetTimer() for i := 0; i < b.N; i++ { x.montgomeryRepresentation(m) } } func BenchmarkMontgomeryMul(b *testing.B) { x := makeBenchmarkValue() y := makeBenchmarkValue() out := makeBenchmarkValue() m := makeBenchmarkModulus() b.ResetTimer() for i := 0; i < b.N; i++ { out.montgomeryMul(x, y, m) } } func BenchmarkModMul(b *testing.B) { x := makeBenchmarkValue() y := makeBenchmarkValue() m := makeBenchmarkModulus() b.ResetTimer() for i := 0; i < b.N; i++ { x.Mul(y, m) } } func BenchmarkExpBig(b *testing.B) { out := new(big.Int) exponentBytes := makeBenchmarkExponent() x := new(big.Int).SetBytes(exponentBytes) e := new(big.Int).SetBytes(exponentBytes) n := new(big.Int).SetBytes(exponentBytes) one := new(big.Int).SetUint64(1) n.Add(n, one) b.ResetTimer() for i := 0; i < b.N; i++ { out.Exp(x, e, n) } } func BenchmarkExp(b *testing.B) { x := makeBenchmarkValue() e := makeBenchmarkExponent() out := makeBenchmarkValue() m := makeBenchmarkModulus() b.ResetTimer() for i := 0; i < b.N; i++ { out.Exp(x, e, m) } } func TestNewModulus(t *testing.T) { expected := "modulus must be > 1" _, err := NewModulus([]byte{}) if err == nil || err.Error() != expected { t.Errorf("NewModulus(0) got %q, want %q", err, expected) } _, err = NewModulus([]byte{0}) if err == nil || err.Error() != expected { t.Errorf("NewModulus(0) got %q, want %q", err, expected) } _, err = NewModulus([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) if err == nil || err.Error() != expected { t.Errorf("NewModulus(0) got %q, want %q", err, expected) } _, err = NewModulus([]byte{1}) if err == nil || err.Error() != expected { t.Errorf("NewModulus(1) got %q, want %q", err, expected) } _, err = NewModulus([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) if err == nil || err.Error() != expected { t.Errorf("NewModulus(1) got %q, want %q", err, expected) } } func makeTestValue(nbits int) []uint { n := nbits / _W x := make([]uint, n) for i := range n { x[i]-- } return x } func TestAddMulVVWSized(t *testing.T) { // Sized addMulVVW have architecture-specific implementations on // a number of architectures. Test that they match the generic // implementation. tests := []struct { n int f func(z, x *uint, y uint) uint }{ {1024, addMulVVW1024}, {1536, addMulVVW1536}, {2048, addMulVVW2048}, } for _, test := range tests { t.Run(fmt.Sprint(test.n), func(t *testing.T) { x := makeTestValue(test.n) z := makeTestValue(test.n) z2 := slices.Clone(z) var y uint y-- c := addMulVVW(z, x, y) c2 := test.f(&z2[0], &x[0], y) if !slices.Equal(z, z2) || c != c2 { t.Errorf("%016X, %016X != %016X, %016X", z, c, z2, c2) } }) } } func TestInverse(t *testing.T) { f, err := os.Open("testdata/mod_inv_tests.txt") if err != nil { t.Fatal(err) } var ModInv, A, M string var lineNum int scanner := bufio.NewScanner(f) for scanner.Scan() { lineNum++ line := scanner.Text() if len(line) == 0 || line[0] == '#' { continue } k, v, _ := strings.Cut(line, " = ") switch k { case "ModInv": ModInv = v case "A": A = v case "M": M = v t.Run(fmt.Sprintf("line %d", lineNum), func(t *testing.T) { m, err := NewModulus(decodeHex(t, M)) if err != nil { t.Skip("modulus <= 1") } a, err := NewNat().SetBytes(decodeHex(t, A), m) if err != nil { t.Fatal(err) } got, ok := NewNat().InverseVarTime(a, m) if !ok { t.Fatal("not invertible") } exp, err := NewNat().SetBytes(decodeHex(t, ModInv), m) if err != nil { t.Fatal(err) } if got.Equal(exp) != 1 { t.Errorf("%v != %v", got, exp) } }) default: t.Fatalf("unknown key %q on line %d", k, lineNum) } } if err := scanner.Err(); err != nil { t.Fatal(err) } } func decodeHex(t *testing.T, s string) []byte { t.Helper() if len(s)%2 != 0 { s = "0" + s } b, err := hex.DecodeString(s) if err != nil { t.Fatalf("failed to decode hex %q: %v", s, err) } return b }