Source file src/crypto/internal/fips140/bigmod/nat_test.go

     1  // Copyright 2021 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package bigmod
     6  
     7  import (
     8  	"bufio"
     9  	"bytes"
    10  	cryptorand "crypto/rand"
    11  	"encoding/hex"
    12  	"fmt"
    13  	"math/big"
    14  	"math/bits"
    15  	"math/rand"
    16  	"os"
    17  	"reflect"
    18  	"slices"
    19  	"strings"
    20  	"testing"
    21  	"testing/quick"
    22  )
    23  
    24  // setBig assigns x = n, optionally resizing n to the appropriate size.
    25  //
    26  // The announced length of x is set based on the actual bit size of the input,
    27  // ignoring leading zeroes.
    28  func (x *Nat) setBig(n *big.Int) *Nat {
    29  	limbs := n.Bits()
    30  	x.reset(len(limbs))
    31  	for i := range limbs {
    32  		x.limbs[i] = uint(limbs[i])
    33  	}
    34  	return x
    35  }
    36  
    37  func (n *Nat) asBig() *big.Int {
    38  	bits := make([]big.Word, len(n.limbs))
    39  	for i := range n.limbs {
    40  		bits[i] = big.Word(n.limbs[i])
    41  	}
    42  	return new(big.Int).SetBits(bits)
    43  }
    44  
    45  func (n *Nat) String() string {
    46  	var limbs []string
    47  	for i := range n.limbs {
    48  		limbs = append(limbs, fmt.Sprintf("%016X", n.limbs[len(n.limbs)-1-i]))
    49  	}
    50  	return "{" + strings.Join(limbs, " ") + "}"
    51  }
    52  
    53  // Generate generates an even nat. It's used by testing/quick to produce random
    54  // *nat values for quick.Check invocations.
    55  func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {
    56  	limbs := make([]uint, size)
    57  	for i := 0; i < size; i++ {
    58  		limbs[i] = uint(r.Uint64()) & ((1 << _W) - 2)
    59  	}
    60  	return reflect.ValueOf(&Nat{limbs})
    61  }
    62  
    63  func testModAddCommutative(a *Nat, b *Nat) bool {
    64  	m := maxModulus(uint(len(a.limbs)))
    65  	aPlusB := new(Nat).set(a)
    66  	aPlusB.Add(b, m)
    67  	bPlusA := new(Nat).set(b)
    68  	bPlusA.Add(a, m)
    69  	return aPlusB.Equal(bPlusA) == 1
    70  }
    71  
    72  func TestModAddCommutative(t *testing.T) {
    73  	err := quick.Check(testModAddCommutative, &quick.Config{})
    74  	if err != nil {
    75  		t.Error(err)
    76  	}
    77  }
    78  
    79  func testModSubThenAddIdentity(a *Nat, b *Nat) bool {
    80  	m := maxModulus(uint(len(a.limbs)))
    81  	original := new(Nat).set(a)
    82  	a.Sub(b, m)
    83  	a.Add(b, m)
    84  	return a.Equal(original) == 1
    85  }
    86  
    87  func TestModSubThenAddIdentity(t *testing.T) {
    88  	err := quick.Check(testModSubThenAddIdentity, &quick.Config{})
    89  	if err != nil {
    90  		t.Error(err)
    91  	}
    92  }
    93  
    94  func TestMontgomeryRoundtrip(t *testing.T) {
    95  	err := quick.Check(func(a *Nat) bool {
    96  		one := &Nat{make([]uint, len(a.limbs))}
    97  		one.limbs[0] = 1
    98  		aPlusOne := new(big.Int).SetBytes(natBytes(a))
    99  		aPlusOne.Add(aPlusOne, big.NewInt(1))
   100  		m, _ := NewModulus(aPlusOne.Bytes())
   101  		monty := new(Nat).set(a)
   102  		monty.montgomeryRepresentation(m)
   103  		aAgain := new(Nat).set(monty)
   104  		aAgain.montgomeryMul(monty, one, m)
   105  		if a.Equal(aAgain) != 1 {
   106  			t.Errorf("%v != %v", a, aAgain)
   107  			return false
   108  		}
   109  		return true
   110  	}, &quick.Config{})
   111  	if err != nil {
   112  		t.Error(err)
   113  	}
   114  }
   115  
   116  func TestShiftIn(t *testing.T) {
   117  	if bits.UintSize != 64 {
   118  		t.Skip("examples are only valid in 64 bit")
   119  	}
   120  	examples := []struct {
   121  		m, x, expected []byte
   122  		y              uint64
   123  	}{{
   124  		m:        []byte{13},
   125  		x:        []byte{0},
   126  		y:        0xFFFF_FFFF_FFFF_FFFF,
   127  		expected: []byte{2},
   128  	}, {
   129  		m:        []byte{13},
   130  		x:        []byte{7},
   131  		y:        0xFFFF_FFFF_FFFF_FFFF,
   132  		expected: []byte{10},
   133  	}, {
   134  		m:        []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
   135  		x:        make([]byte, 9),
   136  		y:        0xFFFF_FFFF_FFFF_FFFF,
   137  		expected: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   138  	}, {
   139  		m:        []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
   140  		x:        []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   141  		y:        0,
   142  		expected: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06},
   143  	}}
   144  
   145  	for i, tt := range examples {
   146  		m := modulusFromBytes(tt.m)
   147  		got := natFromBytes(tt.x).ExpandFor(m).shiftIn(uint(tt.y), m)
   148  		if exp := natFromBytes(tt.expected).ExpandFor(m); got.Equal(exp) != 1 {
   149  			t.Errorf("%d: got %v, expected %v", i, got, exp)
   150  		}
   151  	}
   152  }
   153  
   154  func TestModulusAndNatSizes(t *testing.T) {
   155  	// These are 126 bit (2 * _W on 64-bit architectures) values, serialized as
   156  	// 128 bits worth of bytes. If leading zeroes are stripped, they fit in two
   157  	// limbs, if they are not, they fit in three. This can be a problem because
   158  	// modulus strips leading zeroes and nat does not.
   159  	m := modulusFromBytes([]byte{
   160  		0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
   161  		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
   162  	xb := []byte{0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
   163  		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}
   164  	natFromBytes(xb).ExpandFor(m) // must not panic for shrinking
   165  	NewNat().SetBytes(xb, m)
   166  }
   167  
   168  func TestSetBytes(t *testing.T) {
   169  	tests := []struct {
   170  		m, b []byte
   171  		fail bool
   172  	}{{
   173  		m: []byte{0xff, 0xff},
   174  		b: []byte{0x00, 0x01},
   175  	}, {
   176  		m:    []byte{0xff, 0xff},
   177  		b:    []byte{0xff, 0xff},
   178  		fail: true,
   179  	}, {
   180  		m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   181  		b: []byte{0x00, 0x01},
   182  	}, {
   183  		m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   184  		b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
   185  	}, {
   186  		m:    []byte{0xff, 0xff},
   187  		b:    []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
   188  		fail: true,
   189  	}, {
   190  		m:    []byte{0xff, 0xff},
   191  		b:    []byte{0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
   192  		fail: true,
   193  	}, {
   194  		m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   195  		b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
   196  	}, {
   197  		m:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   198  		b:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
   199  		fail: true,
   200  	}, {
   201  		m:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   202  		b:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   203  		fail: true,
   204  	}, {
   205  		m:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   206  		b:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
   207  		fail: true,
   208  	}, {
   209  		m:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfd},
   210  		b:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   211  		fail: true,
   212  	}}
   213  
   214  	for i, tt := range tests {
   215  		m := modulusFromBytes(tt.m)
   216  		got, err := NewNat().SetBytes(tt.b, m)
   217  		if err != nil {
   218  			if !tt.fail {
   219  				t.Errorf("%d: unexpected error: %v", i, err)
   220  			}
   221  			continue
   222  		}
   223  		if tt.fail {
   224  			t.Errorf("%d: unexpected success", i)
   225  			continue
   226  		}
   227  		if expected := natFromBytes(tt.b).ExpandFor(m); got.Equal(expected) != yes {
   228  			t.Errorf("%d: got %v, expected %v", i, got, expected)
   229  		}
   230  	}
   231  
   232  	f := func(xBytes []byte) bool {
   233  		m := maxModulus(uint(len(xBytes)*8/_W + 1))
   234  		got, err := NewNat().SetBytes(xBytes, m)
   235  		if err != nil {
   236  			return false
   237  		}
   238  		return got.Equal(natFromBytes(xBytes).ExpandFor(m)) == yes
   239  	}
   240  
   241  	err := quick.Check(f, &quick.Config{})
   242  	if err != nil {
   243  		t.Error(err)
   244  	}
   245  }
   246  
   247  func TestExpand(t *testing.T) {
   248  	sliced := []uint{1, 2, 3, 4}
   249  	examples := []struct {
   250  		in  []uint
   251  		n   int
   252  		out []uint
   253  	}{{
   254  		[]uint{1, 2},
   255  		4,
   256  		[]uint{1, 2, 0, 0},
   257  	}, {
   258  		sliced[:2],
   259  		4,
   260  		[]uint{1, 2, 0, 0},
   261  	}, {
   262  		[]uint{1, 2},
   263  		2,
   264  		[]uint{1, 2},
   265  	}}
   266  
   267  	for i, tt := range examples {
   268  		got := (&Nat{tt.in}).expand(tt.n)
   269  		if len(got.limbs) != len(tt.out) || got.Equal(&Nat{tt.out}) != 1 {
   270  			t.Errorf("%d: got %v, expected %v", i, got, tt.out)
   271  		}
   272  	}
   273  }
   274  
   275  func TestMod(t *testing.T) {
   276  	m := modulusFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d})
   277  	x := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
   278  	out := new(Nat)
   279  	out.Mod(x, m)
   280  	expected := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09})
   281  	if out.Equal(expected) != 1 {
   282  		t.Errorf("%+v != %+v", out, expected)
   283  	}
   284  }
   285  
   286  func TestModSub(t *testing.T) {
   287  	m := modulusFromBytes([]byte{13})
   288  	x := &Nat{[]uint{6}}
   289  	y := &Nat{[]uint{7}}
   290  	x.Sub(y, m)
   291  	expected := &Nat{[]uint{12}}
   292  	if x.Equal(expected) != 1 {
   293  		t.Errorf("%+v != %+v", x, expected)
   294  	}
   295  	x.Sub(y, m)
   296  	expected = &Nat{[]uint{5}}
   297  	if x.Equal(expected) != 1 {
   298  		t.Errorf("%+v != %+v", x, expected)
   299  	}
   300  }
   301  
   302  func TestModAdd(t *testing.T) {
   303  	m := modulusFromBytes([]byte{13})
   304  	x := &Nat{[]uint{6}}
   305  	y := &Nat{[]uint{7}}
   306  	x.Add(y, m)
   307  	expected := &Nat{[]uint{0}}
   308  	if x.Equal(expected) != 1 {
   309  		t.Errorf("%+v != %+v", x, expected)
   310  	}
   311  	x.Add(y, m)
   312  	expected = &Nat{[]uint{7}}
   313  	if x.Equal(expected) != 1 {
   314  		t.Errorf("%+v != %+v", x, expected)
   315  	}
   316  }
   317  
   318  func TestExp(t *testing.T) {
   319  	m := modulusFromBytes([]byte{13})
   320  	x := &Nat{[]uint{3}}
   321  	out := &Nat{[]uint{0}}
   322  	out.Exp(x, []byte{12}, m)
   323  	expected := &Nat{[]uint{1}}
   324  	if out.Equal(expected) != 1 {
   325  		t.Errorf("%+v != %+v", out, expected)
   326  	}
   327  }
   328  
   329  func TestExpShort(t *testing.T) {
   330  	m := modulusFromBytes([]byte{13})
   331  	x := &Nat{[]uint{3}}
   332  	out := &Nat{[]uint{0}}
   333  	out.ExpShortVarTime(x, 12, m)
   334  	expected := &Nat{[]uint{1}}
   335  	if out.Equal(expected) != 1 {
   336  		t.Errorf("%+v != %+v", out, expected)
   337  	}
   338  }
   339  
   340  // TestMulReductions tests that Mul reduces results equal or slightly greater
   341  // than the modulus. Some Montgomery algorithms don't and need extra care to
   342  // return correct results. See https://go.dev/issue/13907.
   343  func TestMulReductions(t *testing.T) {
   344  	// Two short but multi-limb primes.
   345  	a, _ := new(big.Int).SetString("773608962677651230850240281261679752031633236267106044359907", 10)
   346  	b, _ := new(big.Int).SetString("180692823610368451951102211649591374573781973061758082626801", 10)
   347  	n := new(big.Int).Mul(a, b)
   348  
   349  	N, _ := NewModulus(n.Bytes())
   350  	A := NewNat().setBig(a).ExpandFor(N)
   351  	B := NewNat().setBig(b).ExpandFor(N)
   352  
   353  	if A.Mul(B, N).IsZero() != 1 {
   354  		t.Error("a * b mod (a * b) != 0")
   355  	}
   356  
   357  	i := new(big.Int).ModInverse(a, b)
   358  	N, _ = NewModulus(b.Bytes())
   359  	A = NewNat().setBig(a).ExpandFor(N)
   360  	I := NewNat().setBig(i).ExpandFor(N)
   361  	one := NewNat().setBig(big.NewInt(1)).ExpandFor(N)
   362  
   363  	if A.Mul(I, N).Equal(one) != 1 {
   364  		t.Error("a * inv(a) mod b != 1")
   365  	}
   366  }
   367  
   368  func TestMul(t *testing.T) {
   369  	t.Run("small", func(t *testing.T) { testMul(t, 760/8) })
   370  	t.Run("1024", func(t *testing.T) { testMul(t, 1024/8) })
   371  	t.Run("1536", func(t *testing.T) { testMul(t, 1536/8) })
   372  	t.Run("2048", func(t *testing.T) { testMul(t, 2048/8) })
   373  }
   374  
   375  func testMul(t *testing.T, n int) {
   376  	a, b, m := make([]byte, n), make([]byte, n), make([]byte, n)
   377  	cryptorand.Read(a)
   378  	cryptorand.Read(b)
   379  	cryptorand.Read(m)
   380  
   381  	// Pick the highest as the modulus.
   382  	if bytes.Compare(a, m) > 0 {
   383  		a, m = m, a
   384  	}
   385  	if bytes.Compare(b, m) > 0 {
   386  		b, m = m, b
   387  	}
   388  
   389  	M, err := NewModulus(m)
   390  	if err != nil {
   391  		t.Fatal(err)
   392  	}
   393  	A, err := NewNat().SetBytes(a, M)
   394  	if err != nil {
   395  		t.Fatal(err)
   396  	}
   397  	B, err := NewNat().SetBytes(b, M)
   398  	if err != nil {
   399  		t.Fatal(err)
   400  	}
   401  
   402  	A.Mul(B, M)
   403  	ABytes := A.Bytes(M)
   404  
   405  	mBig := new(big.Int).SetBytes(m)
   406  	aBig := new(big.Int).SetBytes(a)
   407  	bBig := new(big.Int).SetBytes(b)
   408  	nBig := new(big.Int).Mul(aBig, bBig)
   409  	nBig.Mod(nBig, mBig)
   410  	nBigBytes := make([]byte, len(ABytes))
   411  	nBig.FillBytes(nBigBytes)
   412  
   413  	if !bytes.Equal(ABytes, nBigBytes) {
   414  		t.Errorf("got %x, want %x", ABytes, nBigBytes)
   415  	}
   416  }
   417  
   418  func TestIs(t *testing.T) {
   419  	checkYes := func(c choice, err string) {
   420  		t.Helper()
   421  		if c != yes {
   422  			t.Error(err)
   423  		}
   424  	}
   425  	checkNot := func(c choice, err string) {
   426  		t.Helper()
   427  		if c != no {
   428  			t.Error(err)
   429  		}
   430  	}
   431  
   432  	mFour := modulusFromBytes([]byte{4})
   433  	n, err := NewNat().SetBytes([]byte{3}, mFour)
   434  	if err != nil {
   435  		t.Fatal(err)
   436  	}
   437  	checkYes(n.IsMinusOne(mFour), "3 is not -1 mod 4")
   438  	checkNot(n.IsZero(), "3 is zero")
   439  	checkNot(n.IsOne(), "3 is one")
   440  	checkYes(n.IsOdd(), "3 is not odd")
   441  	n.SubOne(mFour)
   442  	checkNot(n.IsMinusOne(mFour), "2 is -1 mod 4")
   443  	checkNot(n.IsZero(), "2 is zero")
   444  	checkNot(n.IsOne(), "2 is one")
   445  	checkNot(n.IsOdd(), "2 is odd")
   446  	n.SubOne(mFour)
   447  	checkNot(n.IsMinusOne(mFour), "1 is -1 mod 4")
   448  	checkNot(n.IsZero(), "1 is zero")
   449  	checkYes(n.IsOne(), "1 is not one")
   450  	checkYes(n.IsOdd(), "1 is not odd")
   451  	n.SubOne(mFour)
   452  	checkNot(n.IsMinusOne(mFour), "0 is -1 mod 4")
   453  	checkYes(n.IsZero(), "0 is not zero")
   454  	checkNot(n.IsOne(), "0 is one")
   455  	checkNot(n.IsOdd(), "0 is odd")
   456  	n.SubOne(mFour)
   457  	checkYes(n.IsMinusOne(mFour), "-1 is not -1 mod 4")
   458  	checkNot(n.IsZero(), "-1 is zero")
   459  	checkNot(n.IsOne(), "-1 is one")
   460  	checkYes(n.IsOdd(), "-1 mod 4 is not odd")
   461  
   462  	mTwoLimbs := maxModulus(2)
   463  	n, err = NewNat().SetBytes([]byte{0x01}, mTwoLimbs)
   464  	if err != nil {
   465  		t.Fatal(err)
   466  	}
   467  	if n.IsOne() != 1 {
   468  		t.Errorf("1 is not one")
   469  	}
   470  }
   471  
   472  func TestTrailingZeroBits(t *testing.T) {
   473  	nb := new(big.Int).SetBytes([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7e})
   474  	nb.Lsh(nb, 128)
   475  	expected := 129
   476  	for expected >= 0 {
   477  		n := NewNat().setBig(nb)
   478  		if n.TrailingZeroBitsVarTime() != uint(expected) {
   479  			t.Errorf("%d != %d", n.TrailingZeroBitsVarTime(), expected)
   480  		}
   481  		nb.Rsh(nb, 1)
   482  		expected--
   483  	}
   484  }
   485  
   486  func TestRightShift(t *testing.T) {
   487  	nb, err := cryptorand.Int(cryptorand.Reader, new(big.Int).Lsh(big.NewInt(1), 1024))
   488  	if err != nil {
   489  		t.Fatal(err)
   490  	}
   491  	for _, shift := range []uint{1, 32, 64, 128, 1024 - 128, 1024 - 64, 1024 - 32, 1024 - 1} {
   492  		testShift := func(t *testing.T, shift uint) {
   493  			n := NewNat().setBig(nb)
   494  			oldLen := len(n.limbs)
   495  			n.ShiftRightVarTime(shift)
   496  			if len(n.limbs) != oldLen {
   497  				t.Errorf("len(n.limbs) = %d, want %d", len(n.limbs), oldLen)
   498  			}
   499  			exp := new(big.Int).Rsh(nb, shift)
   500  			if n.asBig().Cmp(exp) != 0 {
   501  				t.Errorf("%v != %v", n.asBig(), exp)
   502  			}
   503  		}
   504  		t.Run(fmt.Sprint(shift-1), func(t *testing.T) { testShift(t, shift-1) })
   505  		t.Run(fmt.Sprint(shift), func(t *testing.T) { testShift(t, shift) })
   506  		t.Run(fmt.Sprint(shift+1), func(t *testing.T) { testShift(t, shift+1) })
   507  	}
   508  }
   509  
   510  func natBytes(n *Nat) []byte {
   511  	return n.Bytes(maxModulus(uint(len(n.limbs))))
   512  }
   513  
   514  func natFromBytes(b []byte) *Nat {
   515  	// Must not use Nat.SetBytes as it's used in TestSetBytes.
   516  	bb := new(big.Int).SetBytes(b)
   517  	return NewNat().setBig(bb)
   518  }
   519  
   520  func modulusFromBytes(b []byte) *Modulus {
   521  	bb := new(big.Int).SetBytes(b)
   522  	m, _ := NewModulus(bb.Bytes())
   523  	return m
   524  }
   525  
   526  // maxModulus returns the biggest modulus that can fit in n limbs.
   527  func maxModulus(n uint) *Modulus {
   528  	b := big.NewInt(1)
   529  	b.Lsh(b, n*_W)
   530  	b.Sub(b, big.NewInt(1))
   531  	m, _ := NewModulus(b.Bytes())
   532  	return m
   533  }
   534  
   535  func makeBenchmarkModulus() *Modulus {
   536  	return maxModulus(32)
   537  }
   538  
   539  func makeBenchmarkValue() *Nat {
   540  	x := make([]uint, 32)
   541  	for i := 0; i < 32; i++ {
   542  		x[i]--
   543  	}
   544  	return &Nat{limbs: x}
   545  }
   546  
   547  func makeBenchmarkExponent() []byte {
   548  	e := make([]byte, 256)
   549  	for i := 0; i < 32; i++ {
   550  		e[i] = 0xFF
   551  	}
   552  	return e
   553  }
   554  
   555  func BenchmarkModAdd(b *testing.B) {
   556  	x := makeBenchmarkValue()
   557  	y := makeBenchmarkValue()
   558  	m := makeBenchmarkModulus()
   559  
   560  	b.ResetTimer()
   561  	for i := 0; i < b.N; i++ {
   562  		x.Add(y, m)
   563  	}
   564  }
   565  
   566  func BenchmarkModSub(b *testing.B) {
   567  	x := makeBenchmarkValue()
   568  	y := makeBenchmarkValue()
   569  	m := makeBenchmarkModulus()
   570  
   571  	b.ResetTimer()
   572  	for i := 0; i < b.N; i++ {
   573  		x.Sub(y, m)
   574  	}
   575  }
   576  
   577  func BenchmarkMontgomeryRepr(b *testing.B) {
   578  	x := makeBenchmarkValue()
   579  	m := makeBenchmarkModulus()
   580  
   581  	b.ResetTimer()
   582  	for i := 0; i < b.N; i++ {
   583  		x.montgomeryRepresentation(m)
   584  	}
   585  }
   586  
   587  func BenchmarkMontgomeryMul(b *testing.B) {
   588  	x := makeBenchmarkValue()
   589  	y := makeBenchmarkValue()
   590  	out := makeBenchmarkValue()
   591  	m := makeBenchmarkModulus()
   592  
   593  	b.ResetTimer()
   594  	for i := 0; i < b.N; i++ {
   595  		out.montgomeryMul(x, y, m)
   596  	}
   597  }
   598  
   599  func BenchmarkModMul(b *testing.B) {
   600  	x := makeBenchmarkValue()
   601  	y := makeBenchmarkValue()
   602  	m := makeBenchmarkModulus()
   603  
   604  	b.ResetTimer()
   605  	for i := 0; i < b.N; i++ {
   606  		x.Mul(y, m)
   607  	}
   608  }
   609  
   610  func BenchmarkExpBig(b *testing.B) {
   611  	out := new(big.Int)
   612  	exponentBytes := makeBenchmarkExponent()
   613  	x := new(big.Int).SetBytes(exponentBytes)
   614  	e := new(big.Int).SetBytes(exponentBytes)
   615  	n := new(big.Int).SetBytes(exponentBytes)
   616  	one := new(big.Int).SetUint64(1)
   617  	n.Add(n, one)
   618  
   619  	b.ResetTimer()
   620  	for i := 0; i < b.N; i++ {
   621  		out.Exp(x, e, n)
   622  	}
   623  }
   624  
   625  func BenchmarkExp(b *testing.B) {
   626  	x := makeBenchmarkValue()
   627  	e := makeBenchmarkExponent()
   628  	out := makeBenchmarkValue()
   629  	m := makeBenchmarkModulus()
   630  
   631  	b.ResetTimer()
   632  	for i := 0; i < b.N; i++ {
   633  		out.Exp(x, e, m)
   634  	}
   635  }
   636  
   637  func TestNewModulus(t *testing.T) {
   638  	expected := "modulus must be > 1"
   639  	_, err := NewModulus([]byte{})
   640  	if err == nil || err.Error() != expected {
   641  		t.Errorf("NewModulus(0) got %q, want %q", err, expected)
   642  	}
   643  	_, err = NewModulus([]byte{0})
   644  	if err == nil || err.Error() != expected {
   645  		t.Errorf("NewModulus(0) got %q, want %q", err, expected)
   646  	}
   647  	_, err = NewModulus([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
   648  	if err == nil || err.Error() != expected {
   649  		t.Errorf("NewModulus(0) got %q, want %q", err, expected)
   650  	}
   651  	_, err = NewModulus([]byte{1})
   652  	if err == nil || err.Error() != expected {
   653  		t.Errorf("NewModulus(1) got %q, want %q", err, expected)
   654  	}
   655  	_, err = NewModulus([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})
   656  	if err == nil || err.Error() != expected {
   657  		t.Errorf("NewModulus(1) got %q, want %q", err, expected)
   658  	}
   659  }
   660  
   661  func makeTestValue(nbits int) []uint {
   662  	n := nbits / _W
   663  	x := make([]uint, n)
   664  	for i := range n {
   665  		x[i]--
   666  	}
   667  	return x
   668  }
   669  
   670  func TestAddMulVVWSized(t *testing.T) {
   671  	// Sized addMulVVW have architecture-specific implementations on
   672  	// a number of architectures. Test that they match the generic
   673  	// implementation.
   674  	tests := []struct {
   675  		n int
   676  		f func(z, x *uint, y uint) uint
   677  	}{
   678  		{1024, addMulVVW1024},
   679  		{1536, addMulVVW1536},
   680  		{2048, addMulVVW2048},
   681  	}
   682  	for _, test := range tests {
   683  		t.Run(fmt.Sprint(test.n), func(t *testing.T) {
   684  			x := makeTestValue(test.n)
   685  			z := makeTestValue(test.n)
   686  			z2 := slices.Clone(z)
   687  			var y uint
   688  			y--
   689  			c := addMulVVW(z, x, y)
   690  			c2 := test.f(&z2[0], &x[0], y)
   691  			if !slices.Equal(z, z2) || c != c2 {
   692  				t.Errorf("%016X, %016X != %016X, %016X", z, c, z2, c2)
   693  			}
   694  		})
   695  	}
   696  }
   697  
   698  func TestInverse(t *testing.T) {
   699  	f, err := os.Open("testdata/mod_inv_tests.txt")
   700  	if err != nil {
   701  		t.Fatal(err)
   702  	}
   703  
   704  	var ModInv, A, M string
   705  	var lineNum int
   706  	scanner := bufio.NewScanner(f)
   707  	for scanner.Scan() {
   708  		lineNum++
   709  		line := scanner.Text()
   710  		if len(line) == 0 || line[0] == '#' {
   711  			continue
   712  		}
   713  
   714  		k, v, _ := strings.Cut(line, " = ")
   715  		switch k {
   716  		case "ModInv":
   717  			ModInv = v
   718  		case "A":
   719  			A = v
   720  		case "M":
   721  			M = v
   722  
   723  			t.Run(fmt.Sprintf("line %d", lineNum), func(t *testing.T) {
   724  				m, err := NewModulus(decodeHex(t, M))
   725  				if err != nil {
   726  					t.Skip("modulus <= 1")
   727  				}
   728  				a, err := NewNat().SetBytes(decodeHex(t, A), m)
   729  				if err != nil {
   730  					t.Fatal(err)
   731  				}
   732  
   733  				got, ok := NewNat().InverseVarTime(a, m)
   734  				if !ok {
   735  					t.Fatal("not invertible")
   736  				}
   737  				exp, err := NewNat().SetBytes(decodeHex(t, ModInv), m)
   738  				if err != nil {
   739  					t.Fatal(err)
   740  				}
   741  				if got.Equal(exp) != 1 {
   742  					t.Errorf("%v != %v", got, exp)
   743  				}
   744  			})
   745  		default:
   746  			t.Fatalf("unknown key %q on line %d", k, lineNum)
   747  		}
   748  	}
   749  	if err := scanner.Err(); err != nil {
   750  		t.Fatal(err)
   751  	}
   752  }
   753  
   754  func decodeHex(t *testing.T, s string) []byte {
   755  	t.Helper()
   756  	if len(s)%2 != 0 {
   757  		s = "0" + s
   758  	}
   759  	b, err := hex.DecodeString(s)
   760  	if err != nil {
   761  		t.Fatalf("failed to decode hex %q: %v", s, err)
   762  	}
   763  	return b
   764  }
   765  

View as plain text