Source file src/crypto/internal/fips140/edwards25519/field/fe_test.go

     1  // Copyright (c) 2017 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package field
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"encoding/hex"
    11  	"io"
    12  	"math/big"
    13  	"math/bits"
    14  	mathrand "math/rand"
    15  	"reflect"
    16  	"testing"
    17  	"testing/quick"
    18  )
    19  
    20  func (v Element) String() string {
    21  	return hex.EncodeToString(v.Bytes())
    22  }
    23  
    24  // quickCheckConfig returns a quick.Config that scales the max count by the
    25  // given factor if the -short flag is not set.
    26  func quickCheckConfig(slowScale int) *quick.Config {
    27  	cfg := new(quick.Config)
    28  	if !testing.Short() {
    29  		cfg.MaxCountScale = float64(slowScale)
    30  	}
    31  	return cfg
    32  }
    33  
    34  func generateFieldElement(rand *mathrand.Rand) Element {
    35  	const maskLow52Bits = (1 << 52) - 1
    36  	return Element{
    37  		rand.Uint64() & maskLow52Bits,
    38  		rand.Uint64() & maskLow52Bits,
    39  		rand.Uint64() & maskLow52Bits,
    40  		rand.Uint64() & maskLow52Bits,
    41  		rand.Uint64() & maskLow52Bits,
    42  	}
    43  }
    44  
    45  // weirdLimbs can be combined to generate a range of edge-case field elements.
    46  // 0 and -1 are intentionally more weighted, as they combine well.
    47  var (
    48  	weirdLimbs51 = []uint64{
    49  		0, 0, 0, 0,
    50  		1,
    51  		19 - 1,
    52  		19,
    53  		0x2aaaaaaaaaaaa,
    54  		0x5555555555555,
    55  		(1 << 51) - 20,
    56  		(1 << 51) - 19,
    57  		(1 << 51) - 1, (1 << 51) - 1,
    58  		(1 << 51) - 1, (1 << 51) - 1,
    59  	}
    60  	weirdLimbs52 = []uint64{
    61  		0, 0, 0, 0, 0, 0,
    62  		1,
    63  		19 - 1,
    64  		19,
    65  		0x2aaaaaaaaaaaa,
    66  		0x5555555555555,
    67  		(1 << 51) - 20,
    68  		(1 << 51) - 19,
    69  		(1 << 51) - 1, (1 << 51) - 1,
    70  		(1 << 51) - 1, (1 << 51) - 1,
    71  		(1 << 51) - 1, (1 << 51) - 1,
    72  		1 << 51,
    73  		(1 << 51) + 1,
    74  		(1 << 52) - 19,
    75  		(1 << 52) - 1,
    76  	}
    77  )
    78  
    79  func generateWeirdFieldElement(rand *mathrand.Rand) Element {
    80  	return Element{
    81  		weirdLimbs52[rand.Intn(len(weirdLimbs52))],
    82  		weirdLimbs51[rand.Intn(len(weirdLimbs51))],
    83  		weirdLimbs51[rand.Intn(len(weirdLimbs51))],
    84  		weirdLimbs51[rand.Intn(len(weirdLimbs51))],
    85  		weirdLimbs51[rand.Intn(len(weirdLimbs51))],
    86  	}
    87  }
    88  
    89  func (Element) Generate(rand *mathrand.Rand, size int) reflect.Value {
    90  	if rand.Intn(2) == 0 {
    91  		return reflect.ValueOf(generateWeirdFieldElement(rand))
    92  	}
    93  	return reflect.ValueOf(generateFieldElement(rand))
    94  }
    95  
    96  // isInBounds returns whether the element is within the expected bit size bounds
    97  // after a light reduction.
    98  func isInBounds(x *Element) bool {
    99  	return bits.Len64(x.l0) <= 52 &&
   100  		bits.Len64(x.l1) <= 52 &&
   101  		bits.Len64(x.l2) <= 52 &&
   102  		bits.Len64(x.l3) <= 52 &&
   103  		bits.Len64(x.l4) <= 52
   104  }
   105  
   106  func TestMultiplyDistributesOverAdd(t *testing.T) {
   107  	multiplyDistributesOverAdd := func(x, y, z Element) bool {
   108  		// Compute t1 = (x+y)*z
   109  		t1 := new(Element)
   110  		t1.Add(&x, &y)
   111  		t1.Multiply(t1, &z)
   112  
   113  		// Compute t2 = x*z + y*z
   114  		t2 := new(Element)
   115  		t3 := new(Element)
   116  		t2.Multiply(&x, &z)
   117  		t3.Multiply(&y, &z)
   118  		t2.Add(t2, t3)
   119  
   120  		return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2)
   121  	}
   122  
   123  	if err := quick.Check(multiplyDistributesOverAdd, quickCheckConfig(1024)); err != nil {
   124  		t.Error(err)
   125  	}
   126  }
   127  
   128  func TestMul64to128(t *testing.T) {
   129  	a := uint64(5)
   130  	b := uint64(5)
   131  	r := mul64(a, b)
   132  	if r.lo != 0x19 || r.hi != 0 {
   133  		t.Errorf("lo-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi)
   134  	}
   135  
   136  	a = uint64(18014398509481983) // 2^54 - 1
   137  	b = uint64(18014398509481983) // 2^54 - 1
   138  	r = mul64(a, b)
   139  	if r.lo != 0xff80000000000001 || r.hi != 0xfffffffffff {
   140  		t.Errorf("hi-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi)
   141  	}
   142  
   143  	a = uint64(1125899906842661)
   144  	b = uint64(2097155)
   145  	r = mul64(a, b)
   146  	r = addMul64(r, a, b)
   147  	r = addMul64(r, a, b)
   148  	r = addMul64(r, a, b)
   149  	r = addMul64(r, a, b)
   150  	if r.lo != 16888498990613035 || r.hi != 640 {
   151  		t.Errorf("wrong answer: %d + %d*(2**64)", r.lo, r.hi)
   152  	}
   153  }
   154  
   155  func TestSetBytesRoundTrip(t *testing.T) {
   156  	f1 := func(in [32]byte, fe Element) bool {
   157  		fe.SetBytes(in[:])
   158  
   159  		// Mask the most significant bit as it's ignored by SetBytes. (Now
   160  		// instead of earlier so we check the masking in SetBytes is working.)
   161  		in[len(in)-1] &= (1 << 7) - 1
   162  
   163  		return bytes.Equal(in[:], fe.Bytes()) && isInBounds(&fe)
   164  	}
   165  	if err := quick.Check(f1, nil); err != nil {
   166  		t.Errorf("failed bytes->FE->bytes round-trip: %v", err)
   167  	}
   168  
   169  	f2 := func(fe, r Element) bool {
   170  		r.SetBytes(fe.Bytes())
   171  
   172  		// Intentionally not using Equal not to go through Bytes again.
   173  		// Calling reduce because both Generate and SetBytes can produce
   174  		// non-canonical representations.
   175  		fe.reduce()
   176  		r.reduce()
   177  		return fe == r
   178  	}
   179  	if err := quick.Check(f2, nil); err != nil {
   180  		t.Errorf("failed FE->bytes->FE round-trip: %v", err)
   181  	}
   182  
   183  	// Check some fixed vectors from dalek
   184  	type feRTTest struct {
   185  		fe Element
   186  		b  []byte
   187  	}
   188  	var tests = []feRTTest{
   189  		{
   190  			fe: Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676},
   191  			b:  []byte{74, 209, 69, 197, 70, 70, 161, 222, 56, 226, 229, 19, 112, 60, 25, 92, 187, 74, 222, 56, 50, 153, 51, 233, 40, 74, 57, 6, 160, 185, 213, 31},
   192  		},
   193  		{
   194  			fe: Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972},
   195  			b:  []byte{199, 23, 106, 112, 61, 77, 216, 79, 186, 60, 11, 118, 13, 16, 103, 15, 42, 32, 83, 250, 44, 57, 204, 198, 78, 199, 253, 119, 146, 172, 3, 122},
   196  		},
   197  	}
   198  
   199  	for _, tt := range tests {
   200  		b := tt.fe.Bytes()
   201  		fe, _ := new(Element).SetBytes(tt.b)
   202  		if !bytes.Equal(b, tt.b) || fe.Equal(&tt.fe) != 1 {
   203  			t.Errorf("Failed fixed roundtrip: %v", tt)
   204  		}
   205  	}
   206  }
   207  
   208  func swapEndianness(buf []byte) []byte {
   209  	for i := 0; i < len(buf)/2; i++ {
   210  		buf[i], buf[len(buf)-i-1] = buf[len(buf)-i-1], buf[i]
   211  	}
   212  	return buf
   213  }
   214  
   215  func TestBytesBigEquivalence(t *testing.T) {
   216  	f1 := func(in [32]byte, fe, fe1 Element) bool {
   217  		fe.SetBytes(in[:])
   218  
   219  		in[len(in)-1] &= (1 << 7) - 1 // mask the most significant bit
   220  		b := new(big.Int).SetBytes(swapEndianness(in[:]))
   221  		fe1.fromBig(b)
   222  
   223  		if fe != fe1 {
   224  			return false
   225  		}
   226  
   227  		buf := make([]byte, 32)
   228  		buf = swapEndianness(fe1.toBig().FillBytes(buf))
   229  
   230  		return bytes.Equal(fe.Bytes(), buf) && isInBounds(&fe) && isInBounds(&fe1)
   231  	}
   232  	if err := quick.Check(f1, nil); err != nil {
   233  		t.Error(err)
   234  	}
   235  }
   236  
   237  // fromBig sets v = n, and returns v. The bit length of n must not exceed 256.
   238  func (v *Element) fromBig(n *big.Int) *Element {
   239  	if n.BitLen() > 32*8 {
   240  		panic("edwards25519: invalid field element input size")
   241  	}
   242  
   243  	buf := make([]byte, 0, 32)
   244  	for _, word := range n.Bits() {
   245  		for i := 0; i < bits.UintSize; i += 8 {
   246  			if len(buf) >= cap(buf) {
   247  				break
   248  			}
   249  			buf = append(buf, byte(word))
   250  			word >>= 8
   251  		}
   252  	}
   253  
   254  	v.SetBytes(buf[:32])
   255  	return v
   256  }
   257  
   258  func (v *Element) fromDecimal(s string) *Element {
   259  	n, ok := new(big.Int).SetString(s, 10)
   260  	if !ok {
   261  		panic("not a valid decimal: " + s)
   262  	}
   263  	return v.fromBig(n)
   264  }
   265  
   266  // toBig returns v as a big.Int.
   267  func (v *Element) toBig() *big.Int {
   268  	buf := v.Bytes()
   269  
   270  	words := make([]big.Word, 32*8/bits.UintSize)
   271  	for n := range words {
   272  		for i := 0; i < bits.UintSize; i += 8 {
   273  			if len(buf) == 0 {
   274  				break
   275  			}
   276  			words[n] |= big.Word(buf[0]) << big.Word(i)
   277  			buf = buf[1:]
   278  		}
   279  	}
   280  
   281  	return new(big.Int).SetBits(words)
   282  }
   283  
   284  func TestDecimalConstants(t *testing.T) {
   285  	sqrtM1String := "19681161376707505956807079304988542015446066515923890162744021073123829784752"
   286  	if exp := new(Element).fromDecimal(sqrtM1String); sqrtM1.Equal(exp) != 1 {
   287  		t.Errorf("sqrtM1 is %v, expected %v", sqrtM1, exp)
   288  	}
   289  	// d is in the parent package, and we don't want to expose d or fromDecimal.
   290  	// dString := "37095705934669439343138083508754565189542113879843219016388785533085940283555"
   291  	// if exp := new(Element).fromDecimal(dString); d.Equal(exp) != 1 {
   292  	// 	t.Errorf("d is %v, expected %v", d, exp)
   293  	// }
   294  }
   295  
   296  func TestSetBytesRoundTripEdgeCases(t *testing.T) {
   297  	// TODO: values close to 0, close to 2^255-19, between 2^255-19 and 2^255-1,
   298  	// and between 2^255 and 2^256-1. Test both the documented SetBytes
   299  	// behavior, and that Bytes reduces them.
   300  }
   301  
   302  // Tests self-consistency between Multiply and Square.
   303  func TestConsistency(t *testing.T) {
   304  	var x Element
   305  	var x2, x2sq Element
   306  
   307  	x = Element{1, 1, 1, 1, 1}
   308  	x2.Multiply(&x, &x)
   309  	x2sq.Square(&x)
   310  
   311  	if x2 != x2sq {
   312  		t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq)
   313  	}
   314  
   315  	var bytes [32]byte
   316  
   317  	_, err := io.ReadFull(rand.Reader, bytes[:])
   318  	if err != nil {
   319  		t.Fatal(err)
   320  	}
   321  	x.SetBytes(bytes[:])
   322  
   323  	x2.Multiply(&x, &x)
   324  	x2sq.Square(&x)
   325  
   326  	if x2 != x2sq {
   327  		t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq)
   328  	}
   329  }
   330  
   331  func TestEqual(t *testing.T) {
   332  	x := Element{1, 1, 1, 1, 1}
   333  	y := Element{5, 4, 3, 2, 1}
   334  
   335  	eq := x.Equal(&x)
   336  	if eq != 1 {
   337  		t.Errorf("wrong about equality")
   338  	}
   339  
   340  	eq = x.Equal(&y)
   341  	if eq != 0 {
   342  		t.Errorf("wrong about inequality")
   343  	}
   344  }
   345  
   346  func TestInvert(t *testing.T) {
   347  	x := Element{1, 1, 1, 1, 1}
   348  	one := Element{1, 0, 0, 0, 0}
   349  	var xinv, r Element
   350  
   351  	xinv.Invert(&x)
   352  	r.Multiply(&x, &xinv)
   353  	r.reduce()
   354  
   355  	if one != r {
   356  		t.Errorf("inversion identity failed, got: %x", r)
   357  	}
   358  
   359  	var bytes [32]byte
   360  
   361  	_, err := io.ReadFull(rand.Reader, bytes[:])
   362  	if err != nil {
   363  		t.Fatal(err)
   364  	}
   365  	x.SetBytes(bytes[:])
   366  
   367  	xinv.Invert(&x)
   368  	r.Multiply(&x, &xinv)
   369  	r.reduce()
   370  
   371  	if one != r {
   372  		t.Errorf("random inversion identity failed, got: %x for field element %x", r, x)
   373  	}
   374  
   375  	zero := Element{}
   376  	x.Set(&zero)
   377  	if xx := xinv.Invert(&x); xx != &xinv {
   378  		t.Errorf("inverting zero did not return the receiver")
   379  	} else if xinv.Equal(&zero) != 1 {
   380  		t.Errorf("inverting zero did not return zero")
   381  	}
   382  }
   383  
   384  func TestSelectSwap(t *testing.T) {
   385  	a := Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676}
   386  	b := Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972}
   387  
   388  	var c, d Element
   389  
   390  	c.Select(&a, &b, 1)
   391  	d.Select(&a, &b, 0)
   392  
   393  	if c.Equal(&a) != 1 || d.Equal(&b) != 1 {
   394  		t.Errorf("Select failed")
   395  	}
   396  
   397  	c.Swap(&d, 0)
   398  
   399  	if c.Equal(&a) != 1 || d.Equal(&b) != 1 {
   400  		t.Errorf("Swap failed")
   401  	}
   402  
   403  	c.Swap(&d, 1)
   404  
   405  	if c.Equal(&b) != 1 || d.Equal(&a) != 1 {
   406  		t.Errorf("Swap failed")
   407  	}
   408  }
   409  
   410  func TestMult32(t *testing.T) {
   411  	mult32EquivalentToMul := func(x Element, y uint32) bool {
   412  		t1 := new(Element)
   413  		for i := 0; i < 100; i++ {
   414  			t1.Mult32(&x, y)
   415  		}
   416  
   417  		ty := new(Element)
   418  		ty.l0 = uint64(y)
   419  
   420  		t2 := new(Element)
   421  		for i := 0; i < 100; i++ {
   422  			t2.Multiply(&x, ty)
   423  		}
   424  
   425  		return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2)
   426  	}
   427  
   428  	if err := quick.Check(mult32EquivalentToMul, quickCheckConfig(1024)); err != nil {
   429  		t.Error(err)
   430  	}
   431  }
   432  
   433  func TestSqrtRatio(t *testing.T) {
   434  	// From draft-irtf-cfrg-ristretto255-decaf448-00, Appendix A.4.
   435  	type test struct {
   436  		u, v      []byte
   437  		wasSquare int
   438  		r         []byte
   439  	}
   440  	var tests = []test{
   441  		// If u is 0, the function is defined to return (0, TRUE), even if v
   442  		// is zero. Note that where used in this package, the denominator v
   443  		// is never zero.
   444  		{
   445  			decodeHex("0000000000000000000000000000000000000000000000000000000000000000"),
   446  			decodeHex("0000000000000000000000000000000000000000000000000000000000000000"),
   447  			1, decodeHex("0000000000000000000000000000000000000000000000000000000000000000"),
   448  		},
   449  		// 0/1 == 0²
   450  		{
   451  			decodeHex("0000000000000000000000000000000000000000000000000000000000000000"),
   452  			decodeHex("0100000000000000000000000000000000000000000000000000000000000000"),
   453  			1, decodeHex("0000000000000000000000000000000000000000000000000000000000000000"),
   454  		},
   455  		// If u is non-zero and v is zero, defined to return (0, FALSE).
   456  		{
   457  			decodeHex("0100000000000000000000000000000000000000000000000000000000000000"),
   458  			decodeHex("0000000000000000000000000000000000000000000000000000000000000000"),
   459  			0, decodeHex("0000000000000000000000000000000000000000000000000000000000000000"),
   460  		},
   461  		// 2/1 is not square in this field.
   462  		{
   463  			decodeHex("0200000000000000000000000000000000000000000000000000000000000000"),
   464  			decodeHex("0100000000000000000000000000000000000000000000000000000000000000"),
   465  			0, decodeHex("3c5ff1b5d8e4113b871bd052f9e7bcd0582804c266ffb2d4f4203eb07fdb7c54"),
   466  		},
   467  		// 4/1 == 2²
   468  		{
   469  			decodeHex("0400000000000000000000000000000000000000000000000000000000000000"),
   470  			decodeHex("0100000000000000000000000000000000000000000000000000000000000000"),
   471  			1, decodeHex("0200000000000000000000000000000000000000000000000000000000000000"),
   472  		},
   473  		// 1/4 == (2⁻¹)² == (2^(p-2))² per Euler's theorem
   474  		{
   475  			decodeHex("0100000000000000000000000000000000000000000000000000000000000000"),
   476  			decodeHex("0400000000000000000000000000000000000000000000000000000000000000"),
   477  			1, decodeHex("f6ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff3f"),
   478  		},
   479  	}
   480  
   481  	for i, tt := range tests {
   482  		u, _ := new(Element).SetBytes(tt.u)
   483  		v, _ := new(Element).SetBytes(tt.v)
   484  		want, _ := new(Element).SetBytes(tt.r)
   485  		got, wasSquare := new(Element).SqrtRatio(u, v)
   486  		if got.Equal(want) == 0 || wasSquare != tt.wasSquare {
   487  			t.Errorf("%d: got (%v, %v), want (%v, %v)", i, got, wasSquare, want, tt.wasSquare)
   488  		}
   489  	}
   490  }
   491  
   492  func TestCarryPropagate(t *testing.T) {
   493  	asmLikeGeneric := func(a [5]uint64) bool {
   494  		t1 := &Element{a[0], a[1], a[2], a[3], a[4]}
   495  		t2 := &Element{a[0], a[1], a[2], a[3], a[4]}
   496  
   497  		t1.carryPropagate()
   498  		t2.carryPropagateGeneric()
   499  
   500  		if *t1 != *t2 {
   501  			t.Logf("got: %#v,\nexpected: %#v", t1, t2)
   502  		}
   503  
   504  		return *t1 == *t2 && isInBounds(t2)
   505  	}
   506  
   507  	if err := quick.Check(asmLikeGeneric, quickCheckConfig(1024)); err != nil {
   508  		t.Error(err)
   509  	}
   510  
   511  	if !asmLikeGeneric([5]uint64{0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}) {
   512  		t.Errorf("failed for {0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}")
   513  	}
   514  }
   515  
   516  func TestFeSquare(t *testing.T) {
   517  	asmLikeGeneric := func(a Element) bool {
   518  		t1 := a
   519  		t2 := a
   520  
   521  		feSquareGeneric(&t1, &t1)
   522  		feSquare(&t2, &t2)
   523  
   524  		if t1 != t2 {
   525  			t.Logf("got: %#v,\nexpected: %#v", t1, t2)
   526  		}
   527  
   528  		return t1 == t2 && isInBounds(&t2)
   529  	}
   530  
   531  	if err := quick.Check(asmLikeGeneric, quickCheckConfig(1024)); err != nil {
   532  		t.Error(err)
   533  	}
   534  }
   535  
   536  func TestFeMul(t *testing.T) {
   537  	asmLikeGeneric := func(a, b Element) bool {
   538  		a1 := a
   539  		a2 := a
   540  		b1 := b
   541  		b2 := b
   542  
   543  		feMulGeneric(&a1, &a1, &b1)
   544  		feMul(&a2, &a2, &b2)
   545  
   546  		if a1 != a2 || b1 != b2 {
   547  			t.Logf("got: %#v,\nexpected: %#v", a1, a2)
   548  			t.Logf("got: %#v,\nexpected: %#v", b1, b2)
   549  		}
   550  
   551  		return a1 == a2 && isInBounds(&a2) &&
   552  			b1 == b2 && isInBounds(&b2)
   553  	}
   554  
   555  	if err := quick.Check(asmLikeGeneric, quickCheckConfig(1024)); err != nil {
   556  		t.Error(err)
   557  	}
   558  }
   559  
   560  func decodeHex(s string) []byte {
   561  	b, err := hex.DecodeString(s)
   562  	if err != nil {
   563  		panic(err)
   564  	}
   565  	return b
   566  }
   567  

View as plain text