Source file src/crypto/internal/fips140test/nistec_test.go

     1  // Copyright 2021 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fipstest
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/elliptic"
    10  	"crypto/internal/cryptotest"
    11  	"crypto/internal/fips140/nistec"
    12  	"fmt"
    13  	"math/big"
    14  	"math/rand"
    15  	"testing"
    16  )
    17  
    18  func TestNISTECAllocations(t *testing.T) {
    19  	cryptotest.SkipTestAllocations(t)
    20  	t.Run("P224", func(t *testing.T) {
    21  		if allocs := testing.AllocsPerRun(10, func() {
    22  			p := nistec.NewP224Point().SetGenerator()
    23  			scalar := make([]byte, 28)
    24  			rand.Read(scalar)
    25  			p.ScalarBaseMult(scalar)
    26  			p.ScalarMult(p, scalar)
    27  			out := p.Bytes()
    28  			if _, err := nistec.NewP224Point().SetBytes(out); err != nil {
    29  				t.Fatal(err)
    30  			}
    31  			out = p.BytesCompressed()
    32  			if _, err := p.SetBytes(out); err != nil {
    33  				t.Fatal(err)
    34  			}
    35  		}); allocs > 0 {
    36  			t.Errorf("expected zero allocations, got %0.1f", allocs)
    37  		}
    38  	})
    39  	t.Run("P256", func(t *testing.T) {
    40  		if allocs := testing.AllocsPerRun(10, func() {
    41  			p := nistec.NewP256Point().SetGenerator()
    42  			scalar := make([]byte, 32)
    43  			rand.Read(scalar)
    44  			p.ScalarBaseMult(scalar)
    45  			p.ScalarMult(p, scalar)
    46  			out := p.Bytes()
    47  			if _, err := nistec.NewP256Point().SetBytes(out); err != nil {
    48  				t.Fatal(err)
    49  			}
    50  			out = p.BytesCompressed()
    51  			if _, err := p.SetBytes(out); err != nil {
    52  				t.Fatal(err)
    53  			}
    54  		}); allocs > 0 {
    55  			t.Errorf("expected zero allocations, got %0.1f", allocs)
    56  		}
    57  	})
    58  	t.Run("P384", func(t *testing.T) {
    59  		if allocs := testing.AllocsPerRun(10, func() {
    60  			p := nistec.NewP384Point().SetGenerator()
    61  			scalar := make([]byte, 48)
    62  			rand.Read(scalar)
    63  			p.ScalarBaseMult(scalar)
    64  			p.ScalarMult(p, scalar)
    65  			out := p.Bytes()
    66  			if _, err := nistec.NewP384Point().SetBytes(out); err != nil {
    67  				t.Fatal(err)
    68  			}
    69  			out = p.BytesCompressed()
    70  			if _, err := p.SetBytes(out); err != nil {
    71  				t.Fatal(err)
    72  			}
    73  		}); allocs > 0 {
    74  			t.Errorf("expected zero allocations, got %0.1f", allocs)
    75  		}
    76  	})
    77  	t.Run("P521", func(t *testing.T) {
    78  		if allocs := testing.AllocsPerRun(10, func() {
    79  			p := nistec.NewP521Point().SetGenerator()
    80  			scalar := make([]byte, 66)
    81  			rand.Read(scalar)
    82  			p.ScalarBaseMult(scalar)
    83  			p.ScalarMult(p, scalar)
    84  			out := p.Bytes()
    85  			if _, err := nistec.NewP521Point().SetBytes(out); err != nil {
    86  				t.Fatal(err)
    87  			}
    88  			out = p.BytesCompressed()
    89  			if _, err := p.SetBytes(out); err != nil {
    90  				t.Fatal(err)
    91  			}
    92  		}); allocs > 0 {
    93  			t.Errorf("expected zero allocations, got %0.1f", allocs)
    94  		}
    95  	})
    96  }
    97  
    98  type nistPoint[T any] interface {
    99  	Bytes() []byte
   100  	SetGenerator() T
   101  	SetBytes([]byte) (T, error)
   102  	Add(T, T) T
   103  	Double(T) T
   104  	ScalarMult(T, []byte) (T, error)
   105  	ScalarBaseMult([]byte) (T, error)
   106  }
   107  
   108  func TestEquivalents(t *testing.T) {
   109  	t.Run("P224", func(t *testing.T) {
   110  		testEquivalents(t, nistec.NewP224Point, elliptic.P224())
   111  	})
   112  	t.Run("P256", func(t *testing.T) {
   113  		testEquivalents(t, nistec.NewP256Point, elliptic.P256())
   114  	})
   115  	t.Run("P384", func(t *testing.T) {
   116  		testEquivalents(t, nistec.NewP384Point, elliptic.P384())
   117  	})
   118  	t.Run("P521", func(t *testing.T) {
   119  		testEquivalents(t, nistec.NewP521Point, elliptic.P521())
   120  	})
   121  }
   122  
   123  func testEquivalents[P nistPoint[P]](t *testing.T, newPoint func() P, c elliptic.Curve) {
   124  	p := newPoint().SetGenerator()
   125  
   126  	elementSize := (c.Params().BitSize + 7) / 8
   127  	two := make([]byte, elementSize)
   128  	two[len(two)-1] = 2
   129  	nPlusTwo := make([]byte, elementSize)
   130  	new(big.Int).Add(c.Params().N, big.NewInt(2)).FillBytes(nPlusTwo)
   131  
   132  	p1 := newPoint().Double(p)
   133  	p2 := newPoint().Add(p, p)
   134  	p3, err := newPoint().ScalarMult(p, two)
   135  	fatalIfErr(t, err)
   136  	p4, err := newPoint().ScalarBaseMult(two)
   137  	fatalIfErr(t, err)
   138  	p5, err := newPoint().ScalarMult(p, nPlusTwo)
   139  	fatalIfErr(t, err)
   140  	p6, err := newPoint().ScalarBaseMult(nPlusTwo)
   141  	fatalIfErr(t, err)
   142  
   143  	if !bytes.Equal(p1.Bytes(), p2.Bytes()) {
   144  		t.Error("P+P != 2*P")
   145  	}
   146  	if !bytes.Equal(p1.Bytes(), p3.Bytes()) {
   147  		t.Error("P+P != [2]P")
   148  	}
   149  	if !bytes.Equal(p1.Bytes(), p4.Bytes()) {
   150  		t.Error("G+G != [2]G")
   151  	}
   152  	if !bytes.Equal(p1.Bytes(), p5.Bytes()) {
   153  		t.Error("P+P != [N+2]P")
   154  	}
   155  	if !bytes.Equal(p1.Bytes(), p6.Bytes()) {
   156  		t.Error("G+G != [N+2]G")
   157  	}
   158  }
   159  
   160  func TestScalarMult(t *testing.T) {
   161  	t.Run("P224", func(t *testing.T) {
   162  		testScalarMult(t, nistec.NewP224Point, elliptic.P224())
   163  	})
   164  	t.Run("P256", func(t *testing.T) {
   165  		testScalarMult(t, nistec.NewP256Point, elliptic.P256())
   166  	})
   167  	t.Run("P384", func(t *testing.T) {
   168  		testScalarMult(t, nistec.NewP384Point, elliptic.P384())
   169  	})
   170  	t.Run("P521", func(t *testing.T) {
   171  		testScalarMult(t, nistec.NewP521Point, elliptic.P521())
   172  	})
   173  }
   174  
   175  func testScalarMult[P nistPoint[P]](t *testing.T, newPoint func() P, c elliptic.Curve) {
   176  	G := newPoint().SetGenerator()
   177  	checkScalar := func(t *testing.T, scalar []byte) {
   178  		p1, err := newPoint().ScalarBaseMult(scalar)
   179  		fatalIfErr(t, err)
   180  		p2, err := newPoint().ScalarMult(G, scalar)
   181  		fatalIfErr(t, err)
   182  		if !bytes.Equal(p1.Bytes(), p2.Bytes()) {
   183  			t.Error("[k]G != ScalarBaseMult(k)")
   184  		}
   185  
   186  		expectInfinity := new(big.Int).Mod(new(big.Int).SetBytes(scalar), c.Params().N).Sign() == 0
   187  		if expectInfinity {
   188  			if !bytes.Equal(p1.Bytes(), newPoint().Bytes()) {
   189  				t.Error("ScalarBaseMult(k) != ∞")
   190  			}
   191  			if !bytes.Equal(p2.Bytes(), newPoint().Bytes()) {
   192  				t.Error("[k]G != ∞")
   193  			}
   194  		} else {
   195  			if bytes.Equal(p1.Bytes(), newPoint().Bytes()) {
   196  				t.Error("ScalarBaseMult(k) == ∞")
   197  			}
   198  			if bytes.Equal(p2.Bytes(), newPoint().Bytes()) {
   199  				t.Error("[k]G == ∞")
   200  			}
   201  		}
   202  
   203  		d := new(big.Int).SetBytes(scalar)
   204  		d.Sub(c.Params().N, d)
   205  		d.Mod(d, c.Params().N)
   206  		g1, err := newPoint().ScalarBaseMult(d.FillBytes(make([]byte, len(scalar))))
   207  		fatalIfErr(t, err)
   208  		g1.Add(g1, p1)
   209  		if !bytes.Equal(g1.Bytes(), newPoint().Bytes()) {
   210  			t.Error("[N - k]G + [k]G != ∞")
   211  		}
   212  	}
   213  
   214  	byteLen := len(c.Params().N.Bytes())
   215  	bitLen := c.Params().N.BitLen()
   216  	t.Run("0", func(t *testing.T) { checkScalar(t, make([]byte, byteLen)) })
   217  	t.Run("1", func(t *testing.T) {
   218  		checkScalar(t, big.NewInt(1).FillBytes(make([]byte, byteLen)))
   219  	})
   220  	t.Run("N-1", func(t *testing.T) {
   221  		checkScalar(t, new(big.Int).Sub(c.Params().N, big.NewInt(1)).Bytes())
   222  	})
   223  	t.Run("N", func(t *testing.T) { checkScalar(t, c.Params().N.Bytes()) })
   224  	t.Run("N+1", func(t *testing.T) {
   225  		checkScalar(t, new(big.Int).Add(c.Params().N, big.NewInt(1)).Bytes())
   226  	})
   227  	t.Run("all1s", func(t *testing.T) {
   228  		s := new(big.Int).Lsh(big.NewInt(1), uint(bitLen))
   229  		s.Sub(s, big.NewInt(1))
   230  		checkScalar(t, s.Bytes())
   231  	})
   232  	if testing.Short() {
   233  		return
   234  	}
   235  	for i := 0; i < bitLen; i++ {
   236  		t.Run(fmt.Sprintf("1<<%d", i), func(t *testing.T) {
   237  			s := new(big.Int).Lsh(big.NewInt(1), uint(i))
   238  			checkScalar(t, s.FillBytes(make([]byte, byteLen)))
   239  		})
   240  	}
   241  	for i := 0; i <= 64; i++ {
   242  		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
   243  			checkScalar(t, big.NewInt(int64(i)).FillBytes(make([]byte, byteLen)))
   244  		})
   245  	}
   246  	// Test N-64...N+64 since they risk overlapping with precomputed table values
   247  	// in the final additions.
   248  	for i := int64(-64); i <= 64; i++ {
   249  		t.Run(fmt.Sprintf("N%+d", i), func(t *testing.T) {
   250  			checkScalar(t, new(big.Int).Add(c.Params().N, big.NewInt(i)).Bytes())
   251  		})
   252  	}
   253  }
   254  
   255  func fatalIfErr(t *testing.T, err error) {
   256  	t.Helper()
   257  	if err != nil {
   258  		t.Fatal(err)
   259  	}
   260  }
   261  

View as plain text