1
2
3
4
5 package field
6
7 import (
8 "bytes"
9 "crypto/rand"
10 "encoding/hex"
11 "io"
12 "math/big"
13 "math/bits"
14 mathrand "math/rand"
15 "reflect"
16 "testing"
17 "testing/quick"
18 )
19
20 func (v Element) String() string {
21 return hex.EncodeToString(v.Bytes())
22 }
23
24
25
// quickCheckConfig returns a testing/quick configuration. With -short the
// package defaults are used unchanged; otherwise the default iteration count
// is multiplied by slowScale.
func quickCheckConfig(slowScale int) *quick.Config {
	if testing.Short() {
		return &quick.Config{}
	}
	return &quick.Config{MaxCountScale: float64(slowScale)}
}
33
34 func generateFieldElement(rand *mathrand.Rand) Element {
35 const maskLow52Bits = (1 << 52) - 1
36 return Element{
37 rand.Uint64() & maskLow52Bits,
38 rand.Uint64() & maskLow52Bits,
39 rand.Uint64() & maskLow52Bits,
40 rand.Uint64() & maskLow52Bits,
41 rand.Uint64() & maskLow52Bits,
42 }
43 }
44
45
46
// Boundary-case limb values used by generateWeirdFieldElement: zero, values
// around 19, alternating bit patterns, and values near 2^51 and 2^52.
// Entries are picked uniformly by index, so repeated entries (the leading
// zeros and the copies of 2^51-1) are proportionally more likely.
var (
	// weirdLimbs51 contains only values below 2^51.
	weirdLimbs51 = []uint64{
		0, 0, 0, 0,
		1,
		19 - 1,
		19,
		0x2aaaaaaaaaaaa,
		0x5555555555555,
		1<<51 - 20,
		1<<51 - 19,
		1<<51 - 1, 1<<51 - 1,
		1<<51 - 1, 1<<51 - 1,
	}

	// weirdLimbs52 repeats the same values with heavier weighting and adds
	// values that need the 52nd bit.
	weirdLimbs52 = []uint64{
		0, 0, 0, 0, 0, 0,
		1,
		19 - 1,
		19,
		0x2aaaaaaaaaaaa,
		0x5555555555555,
		1<<51 - 20,
		1<<51 - 19,
		1<<51 - 1, 1<<51 - 1,
		1<<51 - 1, 1<<51 - 1,
		1<<51 - 1, 1<<51 - 1,
		1 << 51,
		1<<51 + 1,
		1<<52 - 19,
		1<<52 - 1,
	}
)
78
79 func generateWeirdFieldElement(rand *mathrand.Rand) Element {
80 return Element{
81 weirdLimbs52[rand.Intn(len(weirdLimbs52))],
82 weirdLimbs51[rand.Intn(len(weirdLimbs51))],
83 weirdLimbs51[rand.Intn(len(weirdLimbs51))],
84 weirdLimbs51[rand.Intn(len(weirdLimbs51))],
85 weirdLimbs51[rand.Intn(len(weirdLimbs51))],
86 }
87 }
88
89 func (Element) Generate(rand *mathrand.Rand, size int) reflect.Value {
90 if rand.Intn(2) == 0 {
91 return reflect.ValueOf(generateWeirdFieldElement(rand))
92 }
93 return reflect.ValueOf(generateFieldElement(rand))
94 }
95
96
97
98 func isInBounds(x *Element) bool {
99 return bits.Len64(x.l0) <= 52 &&
100 bits.Len64(x.l1) <= 52 &&
101 bits.Len64(x.l2) <= 52 &&
102 bits.Len64(x.l3) <= 52 &&
103 bits.Len64(x.l4) <= 52
104 }
105
106 func TestMultiplyDistributesOverAdd(t *testing.T) {
107 multiplyDistributesOverAdd := func(x, y, z Element) bool {
108
109 t1 := new(Element)
110 t1.Add(&x, &y)
111 t1.Multiply(t1, &z)
112
113
114 t2 := new(Element)
115 t3 := new(Element)
116 t2.Multiply(&x, &z)
117 t3.Multiply(&y, &z)
118 t2.Add(t2, t3)
119
120 return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2)
121 }
122
123 if err := quick.Check(multiplyDistributesOverAdd, quickCheckConfig(1024)); err != nil {
124 t.Error(err)
125 }
126 }
127
128 func TestMul64to128(t *testing.T) {
129 a := uint64(5)
130 b := uint64(5)
131 r := mul64(a, b)
132 if r.lo != 0x19 || r.hi != 0 {
133 t.Errorf("lo-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi)
134 }
135
136 a = uint64(18014398509481983)
137 b = uint64(18014398509481983)
138 r = mul64(a, b)
139 if r.lo != 0xff80000000000001 || r.hi != 0xfffffffffff {
140 t.Errorf("hi-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi)
141 }
142
143 a = uint64(1125899906842661)
144 b = uint64(2097155)
145 r = mul64(a, b)
146 r = addMul64(r, a, b)
147 r = addMul64(r, a, b)
148 r = addMul64(r, a, b)
149 r = addMul64(r, a, b)
150 if r.lo != 16888498990613035 || r.hi != 640 {
151 t.Errorf("wrong answer: %d + %d*(2**64)", r.lo, r.hi)
152 }
153 }
154
155 func TestSetBytesRoundTrip(t *testing.T) {
156 f1 := func(in [32]byte, fe Element) bool {
157 fe.SetBytes(in[:])
158
159
160
161 in[len(in)-1] &= (1 << 7) - 1
162
163 return bytes.Equal(in[:], fe.Bytes()) && isInBounds(&fe)
164 }
165 if err := quick.Check(f1, nil); err != nil {
166 t.Errorf("failed bytes->FE->bytes round-trip: %v", err)
167 }
168
169 f2 := func(fe, r Element) bool {
170 r.SetBytes(fe.Bytes())
171
172
173
174
175 fe.reduce()
176 r.reduce()
177 return fe == r
178 }
179 if err := quick.Check(f2, nil); err != nil {
180 t.Errorf("failed FE->bytes->FE round-trip: %v", err)
181 }
182
183
184 type feRTTest struct {
185 fe Element
186 b []byte
187 }
188 var tests = []feRTTest{
189 {
190 fe: Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676},
191 b: []byte{74, 209, 69, 197, 70, 70, 161, 222, 56, 226, 229, 19, 112, 60, 25, 92, 187, 74, 222, 56, 50, 153, 51, 233, 40, 74, 57, 6, 160, 185, 213, 31},
192 },
193 {
194 fe: Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972},
195 b: []byte{199, 23, 106, 112, 61, 77, 216, 79, 186, 60, 11, 118, 13, 16, 103, 15, 42, 32, 83, 250, 44, 57, 204, 198, 78, 199, 253, 119, 146, 172, 3, 122},
196 },
197 }
198
199 for _, tt := range tests {
200 b := tt.fe.Bytes()
201 fe, _ := new(Element).SetBytes(tt.b)
202 if !bytes.Equal(b, tt.b) || fe.Equal(&tt.fe) != 1 {
203 t.Errorf("Failed fixed roundtrip: %v", tt)
204 }
205 }
206 }
207
// swapEndianness reverses buf in place and returns it. This converts a
// little-endian byte string to big-endian and back.
func swapEndianness(buf []byte) []byte {
	for lo, hi := 0, len(buf)-1; lo < hi; lo, hi = lo+1, hi-1 {
		buf[lo], buf[hi] = buf[hi], buf[lo]
	}
	return buf
}
214
215 func TestBytesBigEquivalence(t *testing.T) {
216 f1 := func(in [32]byte, fe, fe1 Element) bool {
217 fe.SetBytes(in[:])
218
219 in[len(in)-1] &= (1 << 7) - 1
220 b := new(big.Int).SetBytes(swapEndianness(in[:]))
221 fe1.fromBig(b)
222
223 if fe != fe1 {
224 return false
225 }
226
227 buf := make([]byte, 32)
228 buf = swapEndianness(fe1.toBig().FillBytes(buf))
229
230 return bytes.Equal(fe.Bytes(), buf) && isInBounds(&fe) && isInBounds(&fe1)
231 }
232 if err := quick.Check(f1, nil); err != nil {
233 t.Error(err)
234 }
235 }
236
237
238 func (v *Element) fromBig(n *big.Int) *Element {
239 if n.BitLen() > 32*8 {
240 panic("edwards25519: invalid field element input size")
241 }
242
243 buf := make([]byte, 0, 32)
244 for _, word := range n.Bits() {
245 for i := 0; i < bits.UintSize; i += 8 {
246 if len(buf) >= cap(buf) {
247 break
248 }
249 buf = append(buf, byte(word))
250 word >>= 8
251 }
252 }
253
254 v.SetBytes(buf[:32])
255 return v
256 }
257
258 func (v *Element) fromDecimal(s string) *Element {
259 n, ok := new(big.Int).SetString(s, 10)
260 if !ok {
261 panic("not a valid decimal: " + s)
262 }
263 return v.fromBig(n)
264 }
265
266
267 func (v *Element) toBig() *big.Int {
268 buf := v.Bytes()
269
270 words := make([]big.Word, 32*8/bits.UintSize)
271 for n := range words {
272 for i := 0; i < bits.UintSize; i += 8 {
273 if len(buf) == 0 {
274 break
275 }
276 words[n] |= big.Word(buf[0]) << big.Word(i)
277 buf = buf[1:]
278 }
279 }
280
281 return new(big.Int).SetBits(words)
282 }
283
284 func TestDecimalConstants(t *testing.T) {
285 sqrtM1String := "19681161376707505956807079304988542015446066515923890162744021073123829784752"
286 if exp := new(Element).fromDecimal(sqrtM1String); sqrtM1.Equal(exp) != 1 {
287 t.Errorf("sqrtM1 is %v, expected %v", sqrtM1, exp)
288 }
289
290
291
292
293
294 }
295
296 func TestSetBytesRoundTripEdgeCases(t *testing.T) {
297
298
299
300 }
301
302
303 func TestConsistency(t *testing.T) {
304 var x Element
305 var x2, x2sq Element
306
307 x = Element{1, 1, 1, 1, 1}
308 x2.Multiply(&x, &x)
309 x2sq.Square(&x)
310
311 if x2 != x2sq {
312 t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq)
313 }
314
315 var bytes [32]byte
316
317 _, err := io.ReadFull(rand.Reader, bytes[:])
318 if err != nil {
319 t.Fatal(err)
320 }
321 x.SetBytes(bytes[:])
322
323 x2.Multiply(&x, &x)
324 x2sq.Square(&x)
325
326 if x2 != x2sq {
327 t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq)
328 }
329 }
330
331 func TestEqual(t *testing.T) {
332 x := Element{1, 1, 1, 1, 1}
333 y := Element{5, 4, 3, 2, 1}
334
335 eq := x.Equal(&x)
336 if eq != 1 {
337 t.Errorf("wrong about equality")
338 }
339
340 eq = x.Equal(&y)
341 if eq != 0 {
342 t.Errorf("wrong about inequality")
343 }
344 }
345
346 func TestInvert(t *testing.T) {
347 x := Element{1, 1, 1, 1, 1}
348 one := Element{1, 0, 0, 0, 0}
349 var xinv, r Element
350
351 xinv.Invert(&x)
352 r.Multiply(&x, &xinv)
353 r.reduce()
354
355 if one != r {
356 t.Errorf("inversion identity failed, got: %x", r)
357 }
358
359 var bytes [32]byte
360
361 _, err := io.ReadFull(rand.Reader, bytes[:])
362 if err != nil {
363 t.Fatal(err)
364 }
365 x.SetBytes(bytes[:])
366
367 xinv.Invert(&x)
368 r.Multiply(&x, &xinv)
369 r.reduce()
370
371 if one != r {
372 t.Errorf("random inversion identity failed, got: %x for field element %x", r, x)
373 }
374
375 zero := Element{}
376 x.Set(&zero)
377 if xx := xinv.Invert(&x); xx != &xinv {
378 t.Errorf("inverting zero did not return the receiver")
379 } else if xinv.Equal(&zero) != 1 {
380 t.Errorf("inverting zero did not return zero")
381 }
382 }
383
384 func TestSelectSwap(t *testing.T) {
385 a := Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676}
386 b := Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972}
387
388 var c, d Element
389
390 c.Select(&a, &b, 1)
391 d.Select(&a, &b, 0)
392
393 if c.Equal(&a) != 1 || d.Equal(&b) != 1 {
394 t.Errorf("Select failed")
395 }
396
397 c.Swap(&d, 0)
398
399 if c.Equal(&a) != 1 || d.Equal(&b) != 1 {
400 t.Errorf("Swap failed")
401 }
402
403 c.Swap(&d, 1)
404
405 if c.Equal(&b) != 1 || d.Equal(&a) != 1 {
406 t.Errorf("Swap failed")
407 }
408 }
409
410 func TestMult32(t *testing.T) {
411 mult32EquivalentToMul := func(x Element, y uint32) bool {
412 t1 := new(Element)
413 for i := 0; i < 100; i++ {
414 t1.Mult32(&x, y)
415 }
416
417 ty := new(Element)
418 ty.l0 = uint64(y)
419
420 t2 := new(Element)
421 for i := 0; i < 100; i++ {
422 t2.Multiply(&x, ty)
423 }
424
425 return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2)
426 }
427
428 if err := quick.Check(mult32EquivalentToMul, quickCheckConfig(1024)); err != nil {
429 t.Error(err)
430 }
431 }
432
433 func TestSqrtRatio(t *testing.T) {
434
435 type test struct {
436 u, v []byte
437 wasSquare int
438 r []byte
439 }
440 var tests = []test{
441
442
443
444 {
445 decodeHex("0000000000000000000000000000000000000000000000000000000000000000"),
446 decodeHex("0000000000000000000000000000000000000000000000000000000000000000"),
447 1, decodeHex("0000000000000000000000000000000000000000000000000000000000000000"),
448 },
449
450 {
451 decodeHex("0000000000000000000000000000000000000000000000000000000000000000"),
452 decodeHex("0100000000000000000000000000000000000000000000000000000000000000"),
453 1, decodeHex("0000000000000000000000000000000000000000000000000000000000000000"),
454 },
455
456 {
457 decodeHex("0100000000000000000000000000000000000000000000000000000000000000"),
458 decodeHex("0000000000000000000000000000000000000000000000000000000000000000"),
459 0, decodeHex("0000000000000000000000000000000000000000000000000000000000000000"),
460 },
461
462 {
463 decodeHex("0200000000000000000000000000000000000000000000000000000000000000"),
464 decodeHex("0100000000000000000000000000000000000000000000000000000000000000"),
465 0, decodeHex("3c5ff1b5d8e4113b871bd052f9e7bcd0582804c266ffb2d4f4203eb07fdb7c54"),
466 },
467
468 {
469 decodeHex("0400000000000000000000000000000000000000000000000000000000000000"),
470 decodeHex("0100000000000000000000000000000000000000000000000000000000000000"),
471 1, decodeHex("0200000000000000000000000000000000000000000000000000000000000000"),
472 },
473
474 {
475 decodeHex("0100000000000000000000000000000000000000000000000000000000000000"),
476 decodeHex("0400000000000000000000000000000000000000000000000000000000000000"),
477 1, decodeHex("f6ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff3f"),
478 },
479 }
480
481 for i, tt := range tests {
482 u, _ := new(Element).SetBytes(tt.u)
483 v, _ := new(Element).SetBytes(tt.v)
484 want, _ := new(Element).SetBytes(tt.r)
485 got, wasSquare := new(Element).SqrtRatio(u, v)
486 if got.Equal(want) == 0 || wasSquare != tt.wasSquare {
487 t.Errorf("%d: got (%v, %v), want (%v, %v)", i, got, wasSquare, want, tt.wasSquare)
488 }
489 }
490 }
491
492 func TestCarryPropagate(t *testing.T) {
493 asmLikeGeneric := func(a [5]uint64) bool {
494 t1 := &Element{a[0], a[1], a[2], a[3], a[4]}
495 t2 := &Element{a[0], a[1], a[2], a[3], a[4]}
496
497 t1.carryPropagate()
498 t2.carryPropagateGeneric()
499
500 if *t1 != *t2 {
501 t.Logf("got: %#v,\nexpected: %#v", t1, t2)
502 }
503
504 return *t1 == *t2 && isInBounds(t2)
505 }
506
507 if err := quick.Check(asmLikeGeneric, quickCheckConfig(1024)); err != nil {
508 t.Error(err)
509 }
510
511 if !asmLikeGeneric([5]uint64{0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}) {
512 t.Errorf("failed for {0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}")
513 }
514 }
515
516 func TestFeSquare(t *testing.T) {
517 asmLikeGeneric := func(a Element) bool {
518 t1 := a
519 t2 := a
520
521 feSquareGeneric(&t1, &t1)
522 feSquare(&t2, &t2)
523
524 if t1 != t2 {
525 t.Logf("got: %#v,\nexpected: %#v", t1, t2)
526 }
527
528 return t1 == t2 && isInBounds(&t2)
529 }
530
531 if err := quick.Check(asmLikeGeneric, quickCheckConfig(1024)); err != nil {
532 t.Error(err)
533 }
534 }
535
536 func TestFeMul(t *testing.T) {
537 asmLikeGeneric := func(a, b Element) bool {
538 a1 := a
539 a2 := a
540 b1 := b
541 b2 := b
542
543 feMulGeneric(&a1, &a1, &b1)
544 feMul(&a2, &a2, &b2)
545
546 if a1 != a2 || b1 != b2 {
547 t.Logf("got: %#v,\nexpected: %#v", a1, a2)
548 t.Logf("got: %#v,\nexpected: %#v", b1, b2)
549 }
550
551 return a1 == a2 && isInBounds(&a2) &&
552 b1 == b2 && isInBounds(&b2)
553 }
554
555 if err := quick.Check(asmLikeGeneric, quickCheckConfig(1024)); err != nil {
556 t.Error(err)
557 }
558 }
559
// decodeHex decodes the hexadecimal string s. It panics with the decoding
// error on malformed input, so it is suitable only for fixed test vectors.
func decodeHex(s string) []byte {
	b, err := hex.DecodeString(s)
	if err == nil {
		return b
	}
	panic(err)
}
567
View as plain text