Source file src/crypto/internal/mlkem768/mlkem768_test.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mlkem768
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	_ "embed"
    11  	"encoding/hex"
    12  	"errors"
    13  	"flag"
    14  	"math/big"
    15  	"strconv"
    16  	"testing"
    17  
    18  	"golang.org/x/crypto/sha3"
    19  )
    20  
    21  func TestFieldReduce(t *testing.T) {
    22  	for a := uint32(0); a < 2*q*q; a++ {
    23  		got := fieldReduce(a)
    24  		exp := fieldElement(a % q)
    25  		if got != exp {
    26  			t.Fatalf("reduce(%d) = %d, expected %d", a, got, exp)
    27  		}
    28  	}
    29  }
    30  
    31  func TestFieldAdd(t *testing.T) {
    32  	for a := fieldElement(0); a < q; a++ {
    33  		for b := fieldElement(0); b < q; b++ {
    34  			got := fieldAdd(a, b)
    35  			exp := (a + b) % q
    36  			if got != exp {
    37  				t.Fatalf("%d + %d = %d, expected %d", a, b, got, exp)
    38  			}
    39  		}
    40  	}
    41  }
    42  
    43  func TestFieldSub(t *testing.T) {
    44  	for a := fieldElement(0); a < q; a++ {
    45  		for b := fieldElement(0); b < q; b++ {
    46  			got := fieldSub(a, b)
    47  			exp := (a - b + q) % q
    48  			if got != exp {
    49  				t.Fatalf("%d - %d = %d, expected %d", a, b, got, exp)
    50  			}
    51  		}
    52  	}
    53  }
    54  
    55  func TestFieldMul(t *testing.T) {
    56  	for a := fieldElement(0); a < q; a++ {
    57  		for b := fieldElement(0); b < q; b++ {
    58  			got := fieldMul(a, b)
    59  			exp := fieldElement((uint32(a) * uint32(b)) % q)
    60  			if got != exp {
    61  				t.Fatalf("%d * %d = %d, expected %d", a, b, got, exp)
    62  			}
    63  		}
    64  	}
    65  }
    66  
    67  func TestDecompressCompress(t *testing.T) {
    68  	for _, bits := range []uint8{1, 4, 10} {
    69  		for a := uint16(0); a < 1<<bits; a++ {
    70  			f := decompress(a, bits)
    71  			if f >= q {
    72  				t.Fatalf("decompress(%d, %d) = %d >= q", a, bits, f)
    73  			}
    74  			got := compress(f, bits)
    75  			if got != a {
    76  				t.Fatalf("compress(decompress(%d, %d), %d) = %d", a, bits, bits, got)
    77  			}
    78  		}
    79  
    80  		for a := fieldElement(0); a < q; a++ {
    81  			c := compress(a, bits)
    82  			if c >= 1<<bits {
    83  				t.Fatalf("compress(%d, %d) = %d >= 2^bits", a, bits, c)
    84  			}
    85  			got := decompress(c, bits)
    86  			diff := min(a-got, got-a, a-got+q, got-a+q)
    87  			ceil := q / (1 << bits)
    88  			if diff > fieldElement(ceil) {
    89  				t.Fatalf("decompress(compress(%d, %d), %d) = %d (diff %d, max diff %d)",
    90  					a, bits, bits, got, diff, ceil)
    91  			}
    92  		}
    93  	}
    94  }
    95  
    96  func CompressRat(x fieldElement, d uint8) uint16 {
    97  	if x >= q {
    98  		panic("x out of range")
    99  	}
   100  	if d <= 0 || d >= 12 {
   101  		panic("d out of range")
   102  	}
   103  
   104  	precise := big.NewRat((1<<d)*int64(x), q) // (2ᵈ / q) * x == (2ᵈ * x) / q
   105  
   106  	// FloatString rounds halves away from 0, and our result should always be positive,
   107  	// so it should work as we expect. (There's no direct way to round a Rat.)
   108  	rounded, err := strconv.ParseInt(precise.FloatString(0), 10, 64)
   109  	if err != nil {
   110  		panic(err)
   111  	}
   112  
   113  	// If we rounded up, `rounded` may be equal to 2ᵈ, so we perform a final reduction.
   114  	return uint16(rounded % (1 << d))
   115  }
   116  
   117  func TestCompress(t *testing.T) {
   118  	for d := 1; d < 12; d++ {
   119  		for n := 0; n < q; n++ {
   120  			expected := CompressRat(fieldElement(n), uint8(d))
   121  			result := compress(fieldElement(n), uint8(d))
   122  			if result != expected {
   123  				t.Errorf("compress(%d, %d): got %d, expected %d", n, d, result, expected)
   124  			}
   125  		}
   126  	}
   127  }
   128  
   129  func DecompressRat(y uint16, d uint8) fieldElement {
   130  	if y >= 1<<d {
   131  		panic("y out of range")
   132  	}
   133  	if d <= 0 || d >= 12 {
   134  		panic("d out of range")
   135  	}
   136  
   137  	precise := big.NewRat(q*int64(y), 1<<d) // (q / 2ᵈ) * y  ==  (q * y) / 2ᵈ
   138  
   139  	// FloatString rounds halves away from 0, and our result should always be positive,
   140  	// so it should work as we expect. (There's no direct way to round a Rat.)
   141  	rounded, err := strconv.ParseInt(precise.FloatString(0), 10, 64)
   142  	if err != nil {
   143  		panic(err)
   144  	}
   145  
   146  	// If we rounded up, `rounded` may be equal to q, so we perform a final reduction.
   147  	return fieldElement(rounded % q)
   148  }
   149  
   150  func TestDecompress(t *testing.T) {
   151  	for d := 1; d < 12; d++ {
   152  		for n := 0; n < (1 << d); n++ {
   153  			expected := DecompressRat(uint16(n), uint8(d))
   154  			result := decompress(uint16(n), uint8(d))
   155  			if result != expected {
   156  				t.Errorf("decompress(%d, %d): got %d, expected %d", n, d, result, expected)
   157  			}
   158  		}
   159  	}
   160  }
   161  
   162  func BitRev7(n uint8) uint8 {
   163  	if n>>7 != 0 {
   164  		panic("not 7 bits")
   165  	}
   166  	var r uint8
   167  	r |= n >> 6 & 0b0000_0001
   168  	r |= n >> 4 & 0b0000_0010
   169  	r |= n >> 2 & 0b0000_0100
   170  	r |= n /**/ & 0b0000_1000
   171  	r |= n << 2 & 0b0001_0000
   172  	r |= n << 4 & 0b0010_0000
   173  	r |= n << 6 & 0b0100_0000
   174  	return r
   175  }
   176  
   177  func TestZetas(t *testing.T) {
   178  	ζ := big.NewInt(17)
   179  	q := big.NewInt(q)
   180  	for k, zeta := range zetas {
   181  		// ζ^BitRev7(k) mod q
   182  		exp := new(big.Int).Exp(ζ, big.NewInt(int64(BitRev7(uint8(k)))), q)
   183  		if big.NewInt(int64(zeta)).Cmp(exp) != 0 {
   184  			t.Errorf("zetas[%d] = %v, expected %v", k, zeta, exp)
   185  		}
   186  	}
   187  }
   188  
   189  func TestGammas(t *testing.T) {
   190  	ζ := big.NewInt(17)
   191  	q := big.NewInt(q)
   192  	for k, gamma := range gammas {
   193  		// ζ^2BitRev7(i)+1
   194  		exp := new(big.Int).Exp(ζ, big.NewInt(int64(BitRev7(uint8(k)))*2+1), q)
   195  		if big.NewInt(int64(gamma)).Cmp(exp) != 0 {
   196  			t.Errorf("gammas[%d] = %v, expected %v", k, gamma, exp)
   197  		}
   198  	}
   199  }
   200  
   201  func TestRoundTrip(t *testing.T) {
   202  	dk, err := GenerateKey()
   203  	if err != nil {
   204  		t.Fatal(err)
   205  	}
   206  	c, Ke, err := Encapsulate(dk.EncapsulationKey())
   207  	if err != nil {
   208  		t.Fatal(err)
   209  	}
   210  	Kd, err := Decapsulate(dk, c)
   211  	if err != nil {
   212  		t.Fatal(err)
   213  	}
   214  	if !bytes.Equal(Ke, Kd) {
   215  		t.Fail()
   216  	}
   217  
   218  	dk1, err := GenerateKey()
   219  	if err != nil {
   220  		t.Fatal(err)
   221  	}
   222  	if bytes.Equal(dk.EncapsulationKey(), dk1.EncapsulationKey()) {
   223  		t.Fail()
   224  	}
   225  	if bytes.Equal(dk.Bytes(), dk1.Bytes()) {
   226  		t.Fail()
   227  	}
   228  	if bytes.Equal(dk.Bytes()[EncapsulationKeySize-32:], dk1.Bytes()[EncapsulationKeySize-32:]) {
   229  		t.Fail()
   230  	}
   231  
   232  	c1, Ke1, err := Encapsulate(dk.EncapsulationKey())
   233  	if err != nil {
   234  		t.Fatal(err)
   235  	}
   236  	if bytes.Equal(c, c1) {
   237  		t.Fail()
   238  	}
   239  	if bytes.Equal(Ke, Ke1) {
   240  		t.Fail()
   241  	}
   242  }
   243  
   244  func TestBadLengths(t *testing.T) {
   245  	dk, err := GenerateKey()
   246  	if err != nil {
   247  		t.Fatal(err)
   248  	}
   249  	ek := dk.EncapsulationKey()
   250  
   251  	for i := 0; i < len(ek)-1; i++ {
   252  		if _, _, err := Encapsulate(ek[:i]); err == nil {
   253  			t.Errorf("expected error for ek length %d", i)
   254  		}
   255  	}
   256  	ekLong := ek
   257  	for i := 0; i < 100; i++ {
   258  		ekLong = append(ekLong, 0)
   259  		if _, _, err := Encapsulate(ekLong); err == nil {
   260  			t.Errorf("expected error for ek length %d", len(ekLong))
   261  		}
   262  	}
   263  
   264  	c, _, err := Encapsulate(ek)
   265  	if err != nil {
   266  		t.Fatal(err)
   267  	}
   268  
   269  	for i := 0; i < len(dk.Bytes())-1; i++ {
   270  		if _, err := NewKeyFromExtendedEncoding(dk.Bytes()[:i]); err == nil {
   271  			t.Errorf("expected error for dk length %d", i)
   272  		}
   273  	}
   274  	dkLong := dk.Bytes()
   275  	for i := 0; i < 100; i++ {
   276  		dkLong = append(dkLong, 0)
   277  		if _, err := NewKeyFromExtendedEncoding(dkLong); err == nil {
   278  			t.Errorf("expected error for dk length %d", len(dkLong))
   279  		}
   280  	}
   281  
   282  	for i := 0; i < len(c)-1; i++ {
   283  		if _, err := Decapsulate(dk, c[:i]); err == nil {
   284  			t.Errorf("expected error for c length %d", i)
   285  		}
   286  	}
   287  	cLong := c
   288  	for i := 0; i < 100; i++ {
   289  		cLong = append(cLong, 0)
   290  		if _, err := Decapsulate(dk, cLong); err == nil {
   291  			t.Errorf("expected error for c length %d", len(cLong))
   292  		}
   293  	}
   294  }
   295  
   296  func EncapsulateDerand(ek, m []byte) (c, K []byte, err error) {
   297  	if len(m) != messageSize {
   298  		return nil, nil, errors.New("bad message length")
   299  	}
   300  	return kemEncaps(nil, ek, (*[messageSize]byte)(m))
   301  }
   302  
   303  func DecapsulateFromBytes(dkBytes []byte, c []byte) ([]byte, error) {
   304  	dk, err := NewKeyFromExtendedEncoding(dkBytes)
   305  	if err != nil {
   306  		return nil, err
   307  	}
   308  	return Decapsulate(dk, c)
   309  }
   310  
   311  func GenerateKeyDerand(t testing.TB, d, z []byte) ([]byte, *DecapsulationKey) {
   312  	if len(d) != 32 || len(z) != 32 {
   313  		t.Fatal("bad length")
   314  	}
   315  	dk := kemKeyGen(nil, (*[32]byte)(d), (*[32]byte)(z))
   316  	return dk.EncapsulationKey(), dk
   317  }
   318  
// millionFlag (-million) switches TestPQCrystalsAccumulated to its
// 1,000,000-vector run.
var millionFlag = flag.Bool("million", false, "run the million vector test")

// TestPQCrystalsAccumulated accumulates the 10k vectors generated by the
// reference implementation and checks the hash of the result, to avoid checking
// in 150MB of test vectors.
//
// All inputs (d, z, the message, and an unrelated ciphertext) are squeezed from
// a SHAKE128 instance s that is never written to. Every output (ek, the dk
// encoding, ct, K, and the key decapsulated from the unrelated ciphertext) is
// absorbed into a second SHAKE128 instance o, whose digest is compared with
// the expected hex string. The exact order of Read and Write calls is part of
// that digest, so it must not be changed. With -short only 100 vectors are
// checked; -million selects 1,000,000 vectors and overrides -short.
func TestPQCrystalsAccumulated(t *testing.T) {
	n := 10000
	expected := "f7db260e1137a742e05fe0db9525012812b004d29040a5b606aad3d134b548d3"
	if testing.Short() {
		n = 100
		expected = "8d0c478ead6037897a0da6be21e5399545babf5fc6dd10c061c99b7dee2bf0dc"
	}
	// Checked after -short, so -million wins when both are set.
	if *millionFlag {
		n = 1000000
		expected = "70090cc5842aad0ec43d5042c783fae9bc320c047b5dafcb6e134821db02384d"
	}

	s := sha3.NewShake128() // input stream: only read from
	o := sha3.NewShake128() // output accumulator: only written to, then summed
	d := make([]byte, 32)
	z := make([]byte, 32)
	msg := make([]byte, 32)
	ct1 := make([]byte, CiphertextSize)

	// Errors from s.Read are ignored; x/crypto/sha3 documents that reading
	// from a SHAKE instance never returns an error.
	for i := 0; i < n; i++ {
		s.Read(d)
		s.Read(z)
		ek, dk := GenerateKeyDerand(t, d, z)
		o.Write(ek)
		o.Write(dk.Bytes())

		s.Read(msg)
		ct, k, err := EncapsulateDerand(ek, msg)
		if err != nil {
			t.Fatal(err)
		}
		o.Write(ct)
		o.Write(k)

		// The honest ciphertext must decapsulate back to the encapsulated key.
		kk, err := Decapsulate(dk, ct)
		if err != nil {
			t.Fatal(err)
		}
		if !bytes.Equal(kk, k) {
			t.Errorf("k: got %x, expected %x", kk, k)
		}

		// Decapsulate a ciphertext that Encapsulate did not produce (presumably
		// exercising ML-KEM's implicit-rejection path) and accumulate the
		// resulting key as well.
		s.Read(ct1)
		k1, err := Decapsulate(dk, ct1)
		if err != nil {
			t.Fatal(err)
		}
		o.Write(k1)
	}

	got := hex.EncodeToString(o.Sum(nil))
	if got != expected {
		t.Errorf("got %s, expected %s", got, expected)
	}
}
   379  
   380  var sink byte
   381  
   382  func BenchmarkKeyGen(b *testing.B) {
   383  	var dk DecapsulationKey
   384  	var d, z [32]byte
   385  	rand.Read(d[:])
   386  	rand.Read(z[:])
   387  	b.ResetTimer()
   388  	for i := 0; i < b.N; i++ {
   389  		dk := kemKeyGen(&dk, &d, &z)
   390  		sink ^= dk.EncapsulationKey()[0]
   391  	}
   392  }
   393  
   394  func BenchmarkEncaps(b *testing.B) {
   395  	d := make([]byte, 32)
   396  	rand.Read(d)
   397  	z := make([]byte, 32)
   398  	rand.Read(z)
   399  	var m [messageSize]byte
   400  	rand.Read(m[:])
   401  	ek, _ := GenerateKeyDerand(b, d, z)
   402  	var c [CiphertextSize]byte
   403  	b.ResetTimer()
   404  	for i := 0; i < b.N; i++ {
   405  		c, K, err := kemEncaps(&c, ek, &m)
   406  		if err != nil {
   407  			b.Fatal(err)
   408  		}
   409  		sink ^= c[0] ^ K[0]
   410  	}
   411  }
   412  
   413  func BenchmarkDecaps(b *testing.B) {
   414  	d := make([]byte, 32)
   415  	rand.Read(d)
   416  	z := make([]byte, 32)
   417  	rand.Read(z)
   418  	m := make([]byte, 32)
   419  	rand.Read(m)
   420  	ek, dk := GenerateKeyDerand(b, d, z)
   421  	c, _, err := EncapsulateDerand(ek, m)
   422  	if err != nil {
   423  		b.Fatal(err)
   424  	}
   425  	b.ResetTimer()
   426  	for i := 0; i < b.N; i++ {
   427  		K := kemDecaps(dk, (*[CiphertextSize]byte)(c))
   428  		sink ^= K[0]
   429  	}
   430  }
   431  
   432  func BenchmarkRoundTrip(b *testing.B) {
   433  	dk, err := GenerateKey()
   434  	if err != nil {
   435  		b.Fatal(err)
   436  	}
   437  	ek := dk.EncapsulationKey()
   438  	c, _, err := Encapsulate(ek)
   439  	if err != nil {
   440  		b.Fatal(err)
   441  	}
   442  	b.Run("Alice", func(b *testing.B) {
   443  		for i := 0; i < b.N; i++ {
   444  			dkS, err := GenerateKey()
   445  			if err != nil {
   446  				b.Fatal(err)
   447  			}
   448  			ekS := dkS.EncapsulationKey()
   449  			sink ^= ekS[0]
   450  
   451  			Ks, err := Decapsulate(dk, c)
   452  			if err != nil {
   453  				b.Fatal(err)
   454  			}
   455  			sink ^= Ks[0]
   456  		}
   457  	})
   458  	b.Run("Bob", func(b *testing.B) {
   459  		for i := 0; i < b.N; i++ {
   460  			cS, Ks, err := Encapsulate(ek)
   461  			if err != nil {
   462  				b.Fatal(err)
   463  			}
   464  			sink ^= cS[0] ^ Ks[0]
   465  		}
   466  	})
   467  }
   468  

View as plain text