1
2
3
4
5 package bigmod
6
7 import (
8 "bufio"
9 "bytes"
10 cryptorand "crypto/rand"
11 "encoding/hex"
12 "fmt"
13 "math/big"
14 "math/bits"
15 "math/rand"
16 "os"
17 "reflect"
18 "slices"
19 "strings"
20 "testing"
21 "testing/quick"
22 )
23
24
25
26
27
28 func (x *Nat) setBig(n *big.Int) *Nat {
29 limbs := n.Bits()
30 x.reset(len(limbs))
31 for i := range limbs {
32 x.limbs[i] = uint(limbs[i])
33 }
34 return x
35 }
36
37 func (n *Nat) asBig() *big.Int {
38 bits := make([]big.Word, len(n.limbs))
39 for i := range n.limbs {
40 bits[i] = big.Word(n.limbs[i])
41 }
42 return new(big.Int).SetBits(bits)
43 }
44
45 func (n *Nat) String() string {
46 var limbs []string
47 for i := range n.limbs {
48 limbs = append(limbs, fmt.Sprintf("%016X", n.limbs[len(n.limbs)-1-i]))
49 }
50 return "{" + strings.Join(limbs, " ") + "}"
51 }
52
53
54
55 func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {
56 limbs := make([]uint, size)
57 for i := 0; i < size; i++ {
58 limbs[i] = uint(r.Uint64()) & ((1 << _W) - 2)
59 }
60 return reflect.ValueOf(&Nat{limbs})
61 }
62
63 func testModAddCommutative(a *Nat, b *Nat) bool {
64 m := maxModulus(uint(len(a.limbs)))
65 aPlusB := new(Nat).set(a)
66 aPlusB.Add(b, m)
67 bPlusA := new(Nat).set(b)
68 bPlusA.Add(a, m)
69 return aPlusB.Equal(bPlusA) == 1
70 }
71
72 func TestModAddCommutative(t *testing.T) {
73 err := quick.Check(testModAddCommutative, &quick.Config{})
74 if err != nil {
75 t.Error(err)
76 }
77 }
78
79 func testModSubThenAddIdentity(a *Nat, b *Nat) bool {
80 m := maxModulus(uint(len(a.limbs)))
81 original := new(Nat).set(a)
82 a.Sub(b, m)
83 a.Add(b, m)
84 return a.Equal(original) == 1
85 }
86
87 func TestModSubThenAddIdentity(t *testing.T) {
88 err := quick.Check(testModSubThenAddIdentity, &quick.Config{})
89 if err != nil {
90 t.Error(err)
91 }
92 }
93
94 func TestMontgomeryRoundtrip(t *testing.T) {
95 err := quick.Check(func(a *Nat) bool {
96 one := &Nat{make([]uint, len(a.limbs))}
97 one.limbs[0] = 1
98 aPlusOne := new(big.Int).SetBytes(natBytes(a))
99 aPlusOne.Add(aPlusOne, big.NewInt(1))
100 m, _ := NewModulus(aPlusOne.Bytes())
101 monty := new(Nat).set(a)
102 monty.montgomeryRepresentation(m)
103 aAgain := new(Nat).set(monty)
104 aAgain.montgomeryMul(monty, one, m)
105 if a.Equal(aAgain) != 1 {
106 t.Errorf("%v != %v", a, aAgain)
107 return false
108 }
109 return true
110 }, &quick.Config{})
111 if err != nil {
112 t.Error(err)
113 }
114 }
115
116 func TestShiftIn(t *testing.T) {
117 if bits.UintSize != 64 {
118 t.Skip("examples are only valid in 64 bit")
119 }
120 examples := []struct {
121 m, x, expected []byte
122 y uint64
123 }{{
124 m: []byte{13},
125 x: []byte{0},
126 y: 0xFFFF_FFFF_FFFF_FFFF,
127 expected: []byte{2},
128 }, {
129 m: []byte{13},
130 x: []byte{7},
131 y: 0xFFFF_FFFF_FFFF_FFFF,
132 expected: []byte{10},
133 }, {
134 m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
135 x: make([]byte, 9),
136 y: 0xFFFF_FFFF_FFFF_FFFF,
137 expected: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
138 }, {
139 m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
140 x: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
141 y: 0,
142 expected: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06},
143 }}
144
145 for i, tt := range examples {
146 m := modulusFromBytes(tt.m)
147 got := natFromBytes(tt.x).ExpandFor(m).shiftIn(uint(tt.y), m)
148 if exp := natFromBytes(tt.expected).ExpandFor(m); got.Equal(exp) != 1 {
149 t.Errorf("%d: got %v, expected %v", i, got, exp)
150 }
151 }
152 }
153
154 func TestModulusAndNatSizes(t *testing.T) {
155
156
157
158
159 m := modulusFromBytes([]byte{
160 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
161 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
162 xb := []byte{0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
163 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}
164 natFromBytes(xb).ExpandFor(m)
165 NewNat().SetBytes(xb, m)
166 }
167
168 func TestSetBytes(t *testing.T) {
169 tests := []struct {
170 m, b []byte
171 fail bool
172 }{{
173 m: []byte{0xff, 0xff},
174 b: []byte{0x00, 0x01},
175 }, {
176 m: []byte{0xff, 0xff},
177 b: []byte{0xff, 0xff},
178 fail: true,
179 }, {
180 m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
181 b: []byte{0x00, 0x01},
182 }, {
183 m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
184 b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
185 }, {
186 m: []byte{0xff, 0xff},
187 b: []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
188 fail: true,
189 }, {
190 m: []byte{0xff, 0xff},
191 b: []byte{0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
192 fail: true,
193 }, {
194 m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
195 b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
196 }, {
197 m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
198 b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
199 fail: true,
200 }, {
201 m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
202 b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
203 fail: true,
204 }, {
205 m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
206 b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
207 fail: true,
208 }, {
209 m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfd},
210 b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
211 fail: true,
212 }}
213
214 for i, tt := range tests {
215 m := modulusFromBytes(tt.m)
216 got, err := NewNat().SetBytes(tt.b, m)
217 if err != nil {
218 if !tt.fail {
219 t.Errorf("%d: unexpected error: %v", i, err)
220 }
221 continue
222 }
223 if tt.fail {
224 t.Errorf("%d: unexpected success", i)
225 continue
226 }
227 if expected := natFromBytes(tt.b).ExpandFor(m); got.Equal(expected) != yes {
228 t.Errorf("%d: got %v, expected %v", i, got, expected)
229 }
230 }
231
232 f := func(xBytes []byte) bool {
233 m := maxModulus(uint(len(xBytes)*8/_W + 1))
234 got, err := NewNat().SetBytes(xBytes, m)
235 if err != nil {
236 return false
237 }
238 return got.Equal(natFromBytes(xBytes).ExpandFor(m)) == yes
239 }
240
241 err := quick.Check(f, &quick.Config{})
242 if err != nil {
243 t.Error(err)
244 }
245 }
246
247 func TestExpand(t *testing.T) {
248 sliced := []uint{1, 2, 3, 4}
249 examples := []struct {
250 in []uint
251 n int
252 out []uint
253 }{{
254 []uint{1, 2},
255 4,
256 []uint{1, 2, 0, 0},
257 }, {
258 sliced[:2],
259 4,
260 []uint{1, 2, 0, 0},
261 }, {
262 []uint{1, 2},
263 2,
264 []uint{1, 2},
265 }}
266
267 for i, tt := range examples {
268 got := (&Nat{tt.in}).expand(tt.n)
269 if len(got.limbs) != len(tt.out) || got.Equal(&Nat{tt.out}) != 1 {
270 t.Errorf("%d: got %v, expected %v", i, got, tt.out)
271 }
272 }
273 }
274
275 func TestMod(t *testing.T) {
276 m := modulusFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d})
277 x := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
278 out := new(Nat)
279 out.Mod(x, m)
280 expected := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09})
281 if out.Equal(expected) != 1 {
282 t.Errorf("%+v != %+v", out, expected)
283 }
284 }
285
286 func TestModSub(t *testing.T) {
287 m := modulusFromBytes([]byte{13})
288 x := &Nat{[]uint{6}}
289 y := &Nat{[]uint{7}}
290 x.Sub(y, m)
291 expected := &Nat{[]uint{12}}
292 if x.Equal(expected) != 1 {
293 t.Errorf("%+v != %+v", x, expected)
294 }
295 x.Sub(y, m)
296 expected = &Nat{[]uint{5}}
297 if x.Equal(expected) != 1 {
298 t.Errorf("%+v != %+v", x, expected)
299 }
300 }
301
302 func TestModAdd(t *testing.T) {
303 m := modulusFromBytes([]byte{13})
304 x := &Nat{[]uint{6}}
305 y := &Nat{[]uint{7}}
306 x.Add(y, m)
307 expected := &Nat{[]uint{0}}
308 if x.Equal(expected) != 1 {
309 t.Errorf("%+v != %+v", x, expected)
310 }
311 x.Add(y, m)
312 expected = &Nat{[]uint{7}}
313 if x.Equal(expected) != 1 {
314 t.Errorf("%+v != %+v", x, expected)
315 }
316 }
317
318 func TestExp(t *testing.T) {
319 m := modulusFromBytes([]byte{13})
320 x := &Nat{[]uint{3}}
321 out := &Nat{[]uint{0}}
322 out.Exp(x, []byte{12}, m)
323 expected := &Nat{[]uint{1}}
324 if out.Equal(expected) != 1 {
325 t.Errorf("%+v != %+v", out, expected)
326 }
327 }
328
329 func TestExpShort(t *testing.T) {
330 m := modulusFromBytes([]byte{13})
331 x := &Nat{[]uint{3}}
332 out := &Nat{[]uint{0}}
333 out.ExpShortVarTime(x, 12, m)
334 expected := &Nat{[]uint{1}}
335 if out.Equal(expected) != 1 {
336 t.Errorf("%+v != %+v", out, expected)
337 }
338 }
339
340
341
342
343 func TestMulReductions(t *testing.T) {
344
345 a, _ := new(big.Int).SetString("773608962677651230850240281261679752031633236267106044359907", 10)
346 b, _ := new(big.Int).SetString("180692823610368451951102211649591374573781973061758082626801", 10)
347 n := new(big.Int).Mul(a, b)
348
349 N, _ := NewModulus(n.Bytes())
350 A := NewNat().setBig(a).ExpandFor(N)
351 B := NewNat().setBig(b).ExpandFor(N)
352
353 if A.Mul(B, N).IsZero() != 1 {
354 t.Error("a * b mod (a * b) != 0")
355 }
356
357 i := new(big.Int).ModInverse(a, b)
358 N, _ = NewModulus(b.Bytes())
359 A = NewNat().setBig(a).ExpandFor(N)
360 I := NewNat().setBig(i).ExpandFor(N)
361 one := NewNat().setBig(big.NewInt(1)).ExpandFor(N)
362
363 if A.Mul(I, N).Equal(one) != 1 {
364 t.Error("a * inv(a) mod b != 1")
365 }
366 }
367
368 func TestMul(t *testing.T) {
369 t.Run("small", func(t *testing.T) { testMul(t, 760/8) })
370 t.Run("1024", func(t *testing.T) { testMul(t, 1024/8) })
371 t.Run("1536", func(t *testing.T) { testMul(t, 1536/8) })
372 t.Run("2048", func(t *testing.T) { testMul(t, 2048/8) })
373 }
374
375 func testMul(t *testing.T, n int) {
376 a, b, m := make([]byte, n), make([]byte, n), make([]byte, n)
377 cryptorand.Read(a)
378 cryptorand.Read(b)
379 cryptorand.Read(m)
380
381
382 if bytes.Compare(a, m) > 0 {
383 a, m = m, a
384 }
385 if bytes.Compare(b, m) > 0 {
386 b, m = m, b
387 }
388
389 M, err := NewModulus(m)
390 if err != nil {
391 t.Fatal(err)
392 }
393 A, err := NewNat().SetBytes(a, M)
394 if err != nil {
395 t.Fatal(err)
396 }
397 B, err := NewNat().SetBytes(b, M)
398 if err != nil {
399 t.Fatal(err)
400 }
401
402 A.Mul(B, M)
403 ABytes := A.Bytes(M)
404
405 mBig := new(big.Int).SetBytes(m)
406 aBig := new(big.Int).SetBytes(a)
407 bBig := new(big.Int).SetBytes(b)
408 nBig := new(big.Int).Mul(aBig, bBig)
409 nBig.Mod(nBig, mBig)
410 nBigBytes := make([]byte, len(ABytes))
411 nBig.FillBytes(nBigBytes)
412
413 if !bytes.Equal(ABytes, nBigBytes) {
414 t.Errorf("got %x, want %x", ABytes, nBigBytes)
415 }
416 }
417
418 func TestIs(t *testing.T) {
419 checkYes := func(c choice, err string) {
420 t.Helper()
421 if c != yes {
422 t.Error(err)
423 }
424 }
425 checkNot := func(c choice, err string) {
426 t.Helper()
427 if c != no {
428 t.Error(err)
429 }
430 }
431
432 mFour := modulusFromBytes([]byte{4})
433 n, err := NewNat().SetBytes([]byte{3}, mFour)
434 if err != nil {
435 t.Fatal(err)
436 }
437 checkYes(n.IsMinusOne(mFour), "3 is not -1 mod 4")
438 checkNot(n.IsZero(), "3 is zero")
439 checkNot(n.IsOne(), "3 is one")
440 checkYes(n.IsOdd(), "3 is not odd")
441 n.SubOne(mFour)
442 checkNot(n.IsMinusOne(mFour), "2 is -1 mod 4")
443 checkNot(n.IsZero(), "2 is zero")
444 checkNot(n.IsOne(), "2 is one")
445 checkNot(n.IsOdd(), "2 is odd")
446 n.SubOne(mFour)
447 checkNot(n.IsMinusOne(mFour), "1 is -1 mod 4")
448 checkNot(n.IsZero(), "1 is zero")
449 checkYes(n.IsOne(), "1 is not one")
450 checkYes(n.IsOdd(), "1 is not odd")
451 n.SubOne(mFour)
452 checkNot(n.IsMinusOne(mFour), "0 is -1 mod 4")
453 checkYes(n.IsZero(), "0 is not zero")
454 checkNot(n.IsOne(), "0 is one")
455 checkNot(n.IsOdd(), "0 is odd")
456 n.SubOne(mFour)
457 checkYes(n.IsMinusOne(mFour), "-1 is not -1 mod 4")
458 checkNot(n.IsZero(), "-1 is zero")
459 checkNot(n.IsOne(), "-1 is one")
460 checkYes(n.IsOdd(), "-1 mod 4 is not odd")
461
462 mTwoLimbs := maxModulus(2)
463 n, err = NewNat().SetBytes([]byte{0x01}, mTwoLimbs)
464 if err != nil {
465 t.Fatal(err)
466 }
467 if n.IsOne() != 1 {
468 t.Errorf("1 is not one")
469 }
470 }
471
472 func TestTrailingZeroBits(t *testing.T) {
473 nb := new(big.Int).SetBytes([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7e})
474 nb.Lsh(nb, 128)
475 expected := 129
476 for expected >= 0 {
477 n := NewNat().setBig(nb)
478 if n.TrailingZeroBitsVarTime() != uint(expected) {
479 t.Errorf("%d != %d", n.TrailingZeroBitsVarTime(), expected)
480 }
481 nb.Rsh(nb, 1)
482 expected--
483 }
484 }
485
486 func TestRightShift(t *testing.T) {
487 nb, err := cryptorand.Int(cryptorand.Reader, new(big.Int).Lsh(big.NewInt(1), 1024))
488 if err != nil {
489 t.Fatal(err)
490 }
491 for _, shift := range []uint{1, 32, 64, 128, 1024 - 128, 1024 - 64, 1024 - 32, 1024 - 1} {
492 testShift := func(t *testing.T, shift uint) {
493 n := NewNat().setBig(nb)
494 oldLen := len(n.limbs)
495 n.ShiftRightVarTime(shift)
496 if len(n.limbs) != oldLen {
497 t.Errorf("len(n.limbs) = %d, want %d", len(n.limbs), oldLen)
498 }
499 exp := new(big.Int).Rsh(nb, shift)
500 if n.asBig().Cmp(exp) != 0 {
501 t.Errorf("%v != %v", n.asBig(), exp)
502 }
503 }
504 t.Run(fmt.Sprint(shift-1), func(t *testing.T) { testShift(t, shift-1) })
505 t.Run(fmt.Sprint(shift), func(t *testing.T) { testShift(t, shift) })
506 t.Run(fmt.Sprint(shift+1), func(t *testing.T) { testShift(t, shift+1) })
507 }
508 }
509
510 func natBytes(n *Nat) []byte {
511 return n.Bytes(maxModulus(uint(len(n.limbs))))
512 }
513
514 func natFromBytes(b []byte) *Nat {
515
516 bb := new(big.Int).SetBytes(b)
517 return NewNat().setBig(bb)
518 }
519
520 func modulusFromBytes(b []byte) *Modulus {
521 bb := new(big.Int).SetBytes(b)
522 m, _ := NewModulus(bb.Bytes())
523 return m
524 }
525
526
527 func maxModulus(n uint) *Modulus {
528 b := big.NewInt(1)
529 b.Lsh(b, n*_W)
530 b.Sub(b, big.NewInt(1))
531 m, _ := NewModulus(b.Bytes())
532 return m
533 }
534
535 func makeBenchmarkModulus() *Modulus {
536 return maxModulus(32)
537 }
538
539 func makeBenchmarkValue() *Nat {
540 x := make([]uint, 32)
541 for i := 0; i < 32; i++ {
542 x[i]--
543 }
544 return &Nat{limbs: x}
545 }
546
// makeBenchmarkExponent returns a 256-byte big-endian exponent whose first 32
// bytes are 0xFF and whose remaining 224 bytes are zero.
func makeBenchmarkExponent() []byte {
	exp := make([]byte, 256)
	head := exp[:32]
	for i := range head {
		head[i] = 0xFF
	}
	return exp
}
554
555 func BenchmarkModAdd(b *testing.B) {
556 x := makeBenchmarkValue()
557 y := makeBenchmarkValue()
558 m := makeBenchmarkModulus()
559
560 b.ResetTimer()
561 for i := 0; i < b.N; i++ {
562 x.Add(y, m)
563 }
564 }
565
566 func BenchmarkModSub(b *testing.B) {
567 x := makeBenchmarkValue()
568 y := makeBenchmarkValue()
569 m := makeBenchmarkModulus()
570
571 b.ResetTimer()
572 for i := 0; i < b.N; i++ {
573 x.Sub(y, m)
574 }
575 }
576
577 func BenchmarkMontgomeryRepr(b *testing.B) {
578 x := makeBenchmarkValue()
579 m := makeBenchmarkModulus()
580
581 b.ResetTimer()
582 for i := 0; i < b.N; i++ {
583 x.montgomeryRepresentation(m)
584 }
585 }
586
587 func BenchmarkMontgomeryMul(b *testing.B) {
588 x := makeBenchmarkValue()
589 y := makeBenchmarkValue()
590 out := makeBenchmarkValue()
591 m := makeBenchmarkModulus()
592
593 b.ResetTimer()
594 for i := 0; i < b.N; i++ {
595 out.montgomeryMul(x, y, m)
596 }
597 }
598
599 func BenchmarkModMul(b *testing.B) {
600 x := makeBenchmarkValue()
601 y := makeBenchmarkValue()
602 m := makeBenchmarkModulus()
603
604 b.ResetTimer()
605 for i := 0; i < b.N; i++ {
606 x.Mul(y, m)
607 }
608 }
609
610 func BenchmarkExpBig(b *testing.B) {
611 out := new(big.Int)
612 exponentBytes := makeBenchmarkExponent()
613 x := new(big.Int).SetBytes(exponentBytes)
614 e := new(big.Int).SetBytes(exponentBytes)
615 n := new(big.Int).SetBytes(exponentBytes)
616 one := new(big.Int).SetUint64(1)
617 n.Add(n, one)
618
619 b.ResetTimer()
620 for i := 0; i < b.N; i++ {
621 out.Exp(x, e, n)
622 }
623 }
624
625 func BenchmarkExp(b *testing.B) {
626 x := makeBenchmarkValue()
627 e := makeBenchmarkExponent()
628 out := makeBenchmarkValue()
629 m := makeBenchmarkModulus()
630
631 b.ResetTimer()
632 for i := 0; i < b.N; i++ {
633 out.Exp(x, e, m)
634 }
635 }
636
637 func TestNewModulus(t *testing.T) {
638 expected := "modulus must be > 1"
639 _, err := NewModulus([]byte{})
640 if err == nil || err.Error() != expected {
641 t.Errorf("NewModulus(0) got %q, want %q", err, expected)
642 }
643 _, err = NewModulus([]byte{0})
644 if err == nil || err.Error() != expected {
645 t.Errorf("NewModulus(0) got %q, want %q", err, expected)
646 }
647 _, err = NewModulus([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
648 if err == nil || err.Error() != expected {
649 t.Errorf("NewModulus(0) got %q, want %q", err, expected)
650 }
651 _, err = NewModulus([]byte{1})
652 if err == nil || err.Error() != expected {
653 t.Errorf("NewModulus(1) got %q, want %q", err, expected)
654 }
655 _, err = NewModulus([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})
656 if err == nil || err.Error() != expected {
657 t.Errorf("NewModulus(1) got %q, want %q", err, expected)
658 }
659 }
660
661 func makeTestValue(nbits int) []uint {
662 n := nbits / _W
663 x := make([]uint, n)
664 for i := range n {
665 x[i]--
666 }
667 return x
668 }
669
670 func TestAddMulVVWSized(t *testing.T) {
671
672
673
674 tests := []struct {
675 n int
676 f func(z, x *uint, y uint) uint
677 }{
678 {1024, addMulVVW1024},
679 {1536, addMulVVW1536},
680 {2048, addMulVVW2048},
681 }
682 for _, test := range tests {
683 t.Run(fmt.Sprint(test.n), func(t *testing.T) {
684 x := makeTestValue(test.n)
685 z := makeTestValue(test.n)
686 z2 := slices.Clone(z)
687 var y uint
688 y--
689 c := addMulVVW(z, x, y)
690 c2 := test.f(&z2[0], &x[0], y)
691 if !slices.Equal(z, z2) || c != c2 {
692 t.Errorf("%016X, %016X != %016X, %016X", z, c, z2, c2)
693 }
694 })
695 }
696 }
697
698 func TestInverse(t *testing.T) {
699 f, err := os.Open("testdata/mod_inv_tests.txt")
700 if err != nil {
701 t.Fatal(err)
702 }
703
704 var ModInv, A, M string
705 var lineNum int
706 scanner := bufio.NewScanner(f)
707 for scanner.Scan() {
708 lineNum++
709 line := scanner.Text()
710 if len(line) == 0 || line[0] == '#' {
711 continue
712 }
713
714 k, v, _ := strings.Cut(line, " = ")
715 switch k {
716 case "ModInv":
717 ModInv = v
718 case "A":
719 A = v
720 case "M":
721 M = v
722
723 t.Run(fmt.Sprintf("line %d", lineNum), func(t *testing.T) {
724 m, err := NewModulus(decodeHex(t, M))
725 if err != nil {
726 t.Skip("modulus <= 1")
727 }
728 a, err := NewNat().SetBytes(decodeHex(t, A), m)
729 if err != nil {
730 t.Fatal(err)
731 }
732
733 got, ok := NewNat().InverseVarTime(a, m)
734 if !ok {
735 t.Fatal("not invertible")
736 }
737 exp, err := NewNat().SetBytes(decodeHex(t, ModInv), m)
738 if err != nil {
739 t.Fatal(err)
740 }
741 if got.Equal(exp) != 1 {
742 t.Errorf("%v != %v", got, exp)
743 }
744 })
745 default:
746 t.Fatalf("unknown key %q on line %d", k, lineNum)
747 }
748 }
749 if err := scanner.Err(); err != nil {
750 t.Fatal(err)
751 }
752 }
753
// decodeHex decodes the hex string s. An odd-length input is first
// left-padded with a single '0'. Any decode error fails the test via t.Fatalf.
func decodeHex(t *testing.T, s string) []byte {
	t.Helper()
	if len(s)&1 == 1 {
		s = "0" + s
	}
	out, err := hex.DecodeString(s)
	if err != nil {
		t.Fatalf("failed to decode hex %q: %v", s, err)
	}
	return out
}
765
View as plain text