Source file src/crypto/internal/fips140/mlkem/mlkem768.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package mlkem implements the quantum-resistant key encapsulation method
     6  // ML-KEM (formerly known as Kyber), as specified in [NIST FIPS 203].
     7  //
     8  // [NIST FIPS 203]: https://doi.org/10.6028/NIST.FIPS.203
     9  package mlkem
    10  
    11  // This package targets security, correctness, simplicity, readability, and
    12  // reviewability as its primary goals. All critical operations are performed in
    13  // constant time.
    14  //
    15  // Variable and function names, as well as code layout, are selected to
    16  // facilitate reviewing the implementation against the NIST FIPS 203 document.
    17  //
    18  // Reviewers unfamiliar with polynomials or linear algebra might find the
    19  // background at https://words.filippo.io/kyber-math/ useful.
    20  //
    21  // This file implements the recommended parameter set ML-KEM-768. The ML-KEM-1024
    22  // parameter set implementation is auto-generated from this file.
    23  //
    24  //go:generate go run generate1024.go -input mlkem768.go -output mlkem1024.go
    25  
    26  import (
    27  	"crypto/internal/fips140"
    28  	"crypto/internal/fips140/drbg"
    29  	"crypto/internal/fips140/sha3"
    30  	"crypto/internal/fips140/subtle"
    31  	"errors"
    32  )
    33  
    34  const (
    35  	// ML-KEM global constants.
    36  	n = 256
    37  	q = 3329
    38  
    39  	// encodingSizeX is the byte size of a ringElement or nttElement encoded
    40  	// by ByteEncode_X (FIPS 203, Algorithm 5).
    41  	encodingSize12 = n * 12 / 8
    42  	encodingSize11 = n * 11 / 8
    43  	encodingSize10 = n * 10 / 8
    44  	encodingSize5  = n * 5 / 8
    45  	encodingSize4  = n * 4 / 8
    46  	encodingSize1  = n * 1 / 8
    47  
    48  	messageSize = encodingSize1
    49  
    50  	SharedKeySize = 32
    51  	SeedSize      = 32 + 32
    52  )
    53  
    54  // ML-KEM-768 parameters.
    55  const (
    56  	k = 3
    57  
    58  	CiphertextSize768       = k*encodingSize10 + encodingSize4
    59  	EncapsulationKeySize768 = k*encodingSize12 + 32
    60  )
    61  
    62  // ML-KEM-1024 parameters.
    63  const (
    64  	k1024 = 4
    65  
    66  	CiphertextSize1024       = k1024*encodingSize11 + encodingSize5
    67  	EncapsulationKeySize1024 = k1024*encodingSize12 + 32
    68  )
    69  
// A DecapsulationKey768 is the secret key used to decapsulate a shared key from a
// ciphertext. It includes various precomputed values.
//
// All fields are populated by kemKeyGen from the seeds d and z. Only those two
// seeds are serialized (see Bytes); everything else is re-derived from them.
type DecapsulationKey768 struct {
	d [32]byte // decapsulation key seed
	z [32]byte // implicit rejection sampling seed

	ρ [32]byte // sampleNTT seed for A, stored for the encapsulation key
	h [32]byte // H(ek), stored for ML-KEM.Decaps_internal

	encryptionKey // t and A, copied into the key returned by EncapsulationKey
	decryptionKey // s, the secret vector in the NTT domain
}
    82  
    83  // Bytes returns the decapsulation key as a 64-byte seed in the "d || z" form.
    84  //
    85  // The decapsulation key must be kept secret.
    86  func (dk *DecapsulationKey768) Bytes() []byte {
    87  	var b [SeedSize]byte
    88  	copy(b[:], dk.d[:])
    89  	copy(b[32:], dk.z[:])
    90  	return b[:]
    91  }
    92  
    93  // EncapsulationKey returns the public encapsulation key necessary to produce
    94  // ciphertexts.
    95  func (dk *DecapsulationKey768) EncapsulationKey() *EncapsulationKey768 {
    96  	return &EncapsulationKey768{
    97  		ρ:             dk.ρ,
    98  		h:             dk.h,
    99  		encryptionKey: dk.encryptionKey,
   100  	}
   101  }
   102  
// An EncapsulationKey768 is the public key used to produce ciphertexts to be
// decapsulated by the corresponding [DecapsulationKey768].
//
// Besides ρ and H(ek), it carries the decoded vector t and the sampled matrix
// A, which parseEK or kemKeyGen compute once so that pkeEncrypt can use them
// directly on every encapsulation.
type EncapsulationKey768 struct {
	ρ [32]byte // sampleNTT seed for A
	h [32]byte // H(ek)
	encryptionKey
}
   110  
   111  // Bytes returns the encapsulation key as a byte slice.
   112  func (ek *EncapsulationKey768) Bytes() []byte {
   113  	// The actual logic is in a separate function to outline this allocation.
   114  	b := make([]byte, 0, EncapsulationKeySize768)
   115  	return ek.bytes(b)
   116  }
   117  
   118  func (ek *EncapsulationKey768) bytes(b []byte) []byte {
   119  	for i := range ek.t {
   120  		b = polyByteEncode(b, ek.t[i])
   121  	}
   122  	b = append(b, ek.ρ[:]...)
   123  	return b
   124  }
   125  
// encryptionKey is the parsed and expanded form of a PKE encryption key.
//
// Both fields hold NTT-domain elements. A is stored row-major and
// untransposed; pkeEncrypt swaps the indices when it needs the transpose.
type encryptionKey struct {
	t [k]nttElement     // ByteDecode₁₂(ek[:384k])
	a [k * k]nttElement // A[i*k+j] = sampleNTT(ρ, j, i)
}
   131  
// decryptionKey is the parsed and expanded form of a PKE decryption key.
//
// It is never decoded from bytes here: kemKeyGen computes s directly from the
// seed, which is equivalent to ByteDecode₁₂ of the FIPS 203 K-PKE dk.
type decryptionKey struct {
	s [k]nttElement // NTT(s), i.e. ByteDecode₁₂(dk_PKE)
}
   136  
   137  // GenerateKey768 generates a new decapsulation key, drawing random bytes from
   138  // a DRBG. The decapsulation key must be kept secret.
   139  func GenerateKey768() (*DecapsulationKey768, error) {
   140  	// The actual logic is in a separate function to outline this allocation.
   141  	dk := &DecapsulationKey768{}
   142  	return generateKey(dk)
   143  }
   144  
   145  func generateKey(dk *DecapsulationKey768) (*DecapsulationKey768, error) {
   146  	var d [32]byte
   147  	drbg.Read(d[:])
   148  	var z [32]byte
   149  	drbg.Read(z[:])
   150  	kemKeyGen(dk, &d, &z)
   151  	if err := fips140.PCT("ML-KEM PCT", func() error { return kemPCT(dk) }); err != nil {
   152  		// This clearly can't happen, but FIPS 140-3 requires us to check.
   153  		panic(err)
   154  	}
   155  	fips140.RecordApproved()
   156  	return dk, nil
   157  }
   158  
   159  // GenerateKeyInternal768 is a derandomized version of GenerateKey768,
   160  // exclusively for use in tests.
   161  func GenerateKeyInternal768(d, z *[32]byte) *DecapsulationKey768 {
   162  	dk := &DecapsulationKey768{}
   163  	kemKeyGen(dk, d, z)
   164  	return dk
   165  }
   166  
   167  // NewDecapsulationKey768 parses a decapsulation key from a 64-byte
   168  // seed in the "d || z" form. The seed must be uniformly random.
   169  func NewDecapsulationKey768(seed []byte) (*DecapsulationKey768, error) {
   170  	// The actual logic is in a separate function to outline this allocation.
   171  	dk := &DecapsulationKey768{}
   172  	return newKeyFromSeed(dk, seed)
   173  }
   174  
   175  func newKeyFromSeed(dk *DecapsulationKey768, seed []byte) (*DecapsulationKey768, error) {
   176  	if len(seed) != SeedSize {
   177  		return nil, errors.New("mlkem: invalid seed length")
   178  	}
   179  	d := (*[32]byte)(seed[:32])
   180  	z := (*[32]byte)(seed[32:])
   181  	kemKeyGen(dk, d, z)
   182  	if err := fips140.PCT("ML-KEM PCT", func() error { return kemPCT(dk) }); err != nil {
   183  		// This clearly can't happen, but FIPS 140-3 requires us to check.
   184  		panic(err)
   185  	}
   186  	fips140.RecordApproved()
   187  	return dk, nil
   188  }
   189  
// kemKeyGen generates a decapsulation key.
//
// It implements ML-KEM.KeyGen_internal according to FIPS 203, Algorithm 16, and
// K-PKE.KeyGen according to FIPS 203, Algorithm 13. The two are merged to save
// copies and allocations.
//
// d and z are stored verbatim in dk (they are what Bytes serializes). Every
// other field (ρ, A, s, t, h) is derived deterministically from d, overwriting
// whatever dk previously held.
func kemKeyGen(dk *DecapsulationKey768, d, z *[32]byte) {
	dk.d = *d
	dk.z = *z

	// (ρ, σ) ← G(d || k): ρ seeds the matrix A, σ seeds the noise vectors.
	g := sha3.New512()
	g.Write(d[:])
	g.Write([]byte{k}) // Module dimension as a domain separator.
	G := g.Sum(make([]byte, 0, 64))
	ρ, σ := G[:32], G[32:]
	dk.ρ = [32]byte(ρ)

	// Sample A directly in the NTT domain, stored row-major.
	A := &dk.a
	for i := byte(0); i < k; i++ {
		for j := byte(0); j < k; j++ {
			A[i*k+j] = sampleNTT(ρ, j, i)
		}
	}

	// N is the nonce passed to samplePolyCBD with σ, incremented after every
	// sample: s uses nonces 0..k-1 and e uses k..2k-1. Both vectors are
	// converted to the NTT domain immediately.
	var N byte
	s := &dk.s
	for i := range s {
		s[i] = ntt(samplePolyCBD(σ, N))
		N++
	}
	e := make([]nttElement, k)
	for i := range e {
		e[i] = ntt(samplePolyCBD(σ, N))
		N++
	}

	t := &dk.t
	for i := range t { // t = A ◦ s + e
		t[i] = e[i]
		for j := range s {
			t[i] = polyAdd(t[i], nttMul(A[i*k+j], s[j]))
		}
	}

	// Cache H(ek) of the encoded encapsulation key; it is mixed into G by both
	// kemEncaps and kemDecaps. Sum appends into dk.h's own storage.
	H := sha3.New256()
	ek := dk.EncapsulationKey().Bytes()
	H.Write(ek)
	H.Sum(dk.h[:0])
}
   238  
   239  // kemPCT performs a Pairwise Consistency Test per FIPS 140-3 IG 10.3.A
   240  // Additional Comment 1: "For key pairs generated for use with approved KEMs in
   241  // FIPS 203, the PCT shall consist of applying the encapsulation key ek to
   242  // encapsulate a shared secret K leading to ciphertext c, and then applying
   243  // decapsulation key dk to retrieve the same shared secret K. The PCT passes if
   244  // the two shared secret K values are equal. The PCT shall be performed either
   245  // when keys are generated/imported, prior to the first exportation, or prior to
   246  // the first operational use (if not exported before the first use)."
   247  func kemPCT(dk *DecapsulationKey768) error {
   248  	ek := dk.EncapsulationKey()
   249  	c, K := ek.Encapsulate()
   250  	K1, err := dk.Decapsulate(c)
   251  	if err != nil {
   252  		return err
   253  	}
   254  	if subtle.ConstantTimeCompare(K, K1) != 1 {
   255  		return errors.New("mlkem: PCT failed")
   256  	}
   257  	return nil
   258  }
   259  
   260  // Encapsulate generates a shared key and an associated ciphertext from an
   261  // encapsulation key, drawing random bytes from a DRBG.
   262  //
   263  // The shared key must be kept secret.
   264  func (ek *EncapsulationKey768) Encapsulate() (ciphertext, sharedKey []byte) {
   265  	// The actual logic is in a separate function to outline this allocation.
   266  	var cc [CiphertextSize768]byte
   267  	return ek.encapsulate(&cc)
   268  }
   269  
   270  func (ek *EncapsulationKey768) encapsulate(cc *[CiphertextSize768]byte) (ciphertext, sharedKey []byte) {
   271  	var m [messageSize]byte
   272  	drbg.Read(m[:])
   273  	// Note that the modulus check (step 2 of the encapsulation key check from
   274  	// FIPS 203, Section 7.2) is performed by polyByteDecode in parseEK.
   275  	fips140.RecordApproved()
   276  	return kemEncaps(cc, ek, &m)
   277  }
   278  
   279  // EncapsulateInternal is a derandomized version of Encapsulate, exclusively for
   280  // use in tests.
   281  func (ek *EncapsulationKey768) EncapsulateInternal(m *[32]byte) (ciphertext, sharedKey []byte) {
   282  	cc := &[CiphertextSize768]byte{}
   283  	return kemEncaps(cc, ek, m)
   284  }
   285  
// kemEncaps generates a shared key and an associated ciphertext.
//
// It implements ML-KEM.Encaps_internal according to FIPS 203, Algorithm 17:
// (K, r) ← G(m || H(ek)), then c ← K-PKE.Encrypt(ek, m, r) encoded into the
// caller-provided buffer cc. In normal use m is read from the DRBG by
// encapsulate; EncapsulateInternal passes it in directly for tests.
func kemEncaps(cc *[CiphertextSize768]byte, ek *EncapsulationKey768, m *[messageSize]byte) (c, K []byte) {
	g := sha3.New512()
	g.Write(m[:])
	g.Write(ek.h[:]) // H(ek), cached when the key was generated or parsed
	G := g.Sum(nil)
	K, r := G[:SharedKeySize], G[SharedKeySize:]
	c = pkeEncrypt(cc, &ek.encryptionKey, m, r)
	return c, K
}
   298  
   299  // NewEncapsulationKey768 parses an encapsulation key from its encoded form.
   300  // If the encapsulation key is not valid, NewEncapsulationKey768 returns an error.
   301  func NewEncapsulationKey768(encapsulationKey []byte) (*EncapsulationKey768, error) {
   302  	// The actual logic is in a separate function to outline this allocation.
   303  	ek := &EncapsulationKey768{}
   304  	return parseEK(ek, encapsulationKey)
   305  }
   306  
// parseEK parses an encryption key from its encoded form.
//
// It implements the initial stages of K-PKE.Encrypt according to FIPS 203,
// Algorithm 14.
//
// ekPKE must be exactly EncapsulationKeySize768 bytes: the ByteEncode₁₂
// encoding of each element of t, followed by the 32-byte seed ρ. On success,
// parseEK fills ek (h, t, ρ and A) and returns it. On failure it returns a nil
// key and an error, and ek may be left partially written.
func parseEK(ek *EncapsulationKey768, ekPKE []byte) (*EncapsulationKey768, error) {
	if len(ekPKE) != EncapsulationKeySize768 {
		return nil, errors.New("mlkem: invalid encapsulation key length")
	}

	// h ← H(ek), over the full encoding before ekPKE is consumed below.
	h := sha3.New256()
	h.Write(ekPKE)
	h.Sum(ek.h[:0])

	// polyByteDecode performs the encapsulation key modulus check
	// (FIPS 203, Section 7.2, step 2) and returns an error if it fails.
	for i := range ek.t {
		var err error
		ek.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12])
		if err != nil {
			return nil, err
		}
		ekPKE = ekPKE[encodingSize12:]
	}
	// The remaining 32 bytes are ρ.
	copy(ek.ρ[:], ekPKE)

	// Sample A exactly as kemKeyGen does: row-major, A[i*k+j] = sampleNTT(ρ, j, i).
	for i := byte(0); i < k; i++ {
		for j := byte(0); j < k; j++ {
			ek.a[i*k+j] = sampleNTT(ek.ρ[:], j, i)
		}
	}

	return ek, nil
}
   338  
// pkeEncrypt encrypts a plaintext message.
//
// It implements K-PKE.Encrypt according to FIPS 203, Algorithm 14, although the
// decoding of t and the sampling of A are done ahead of time in parseEK (or in
// kemKeyGen, for keys derived from a decapsulation key).
//
// rnd is the encryption randomness r, the second half of the G output computed
// by the caller. The ciphertext, u followed by v, is encoded into cc starting
// from cc[:0], and the resulting slice is returned.
func pkeEncrypt(cc *[CiphertextSize768]byte, ex *encryptionKey, m *[messageSize]byte, rnd []byte) []byte {
	// N is the samplePolyCBD nonce, incremented after every sample: r uses
	// nonces 0..k-1, e1 uses k..2k-1, and e2 uses 2k. Only r is moved into the
	// NTT domain, because e1 and e2 are added after the inverse NTT.
	var N byte
	r, e1 := make([]nttElement, k), make([]ringElement, k)
	for i := range r {
		r[i] = ntt(samplePolyCBD(rnd, N))
		N++
	}
	for i := range e1 {
		e1[i] = samplePolyCBD(rnd, N)
		N++
	}
	e2 := samplePolyCBD(rnd, N)

	u := make([]ringElement, k) // NTT⁻¹(AT ◦ r) + e1
	for i := range u {
		u[i] = e1[i]
		for j := range r {
			// Note that i and j are inverted, as we need the transpose of A.
			u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.a[j*k+i], r[j])))
		}
	}

	μ := ringDecodeAndDecompress1(m) // μ = Decompress₁(ByteDecode₁(m))

	var vNTT nttElement // t⊺ ◦ r
	for i := range ex.t {
		vNTT = polyAdd(vNTT, nttMul(ex.t[i], r[i]))
	}
	v := polyAdd(polyAdd(inverseNTT(vNTT), e2), μ) // v = NTT⁻¹(t⊺ ◦ r) + e2 + μ

	// Compress and encode each element of u, then v, appending into cc.
	c := cc[:0]
	for _, f := range u {
		c = ringCompressAndEncode10(c, f)
	}
	c = ringCompressAndEncode4(c, v)

	return c
}
   381  
   382  // Decapsulate generates a shared key from a ciphertext and a decapsulation key.
   383  // If the ciphertext is not valid, Decapsulate returns an error.
   384  //
   385  // The shared key must be kept secret.
   386  func (dk *DecapsulationKey768) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) {
   387  	if len(ciphertext) != CiphertextSize768 {
   388  		return nil, errors.New("mlkem: invalid ciphertext length")
   389  	}
   390  	c := (*[CiphertextSize768]byte)(ciphertext)
   391  	// Note that the hash check (step 3 of the decapsulation input check from
   392  	// FIPS 203, Section 7.3) is foregone as a DecapsulationKey is always
   393  	// validly generated by ML-KEM.KeyGen_internal.
   394  	return kemDecaps(dk, c), nil
   395  }
   396  
// kemDecaps produces a shared key from a ciphertext.
//
// It implements ML-KEM.Decaps_internal according to FIPS 203, Algorithm 18.
//
// It decrypts c to a candidate message m', re-derives (K', r') from
// G(m' || H(ek)), and re-encrypts m' with r'. If the re-encryption equals c,
// the result is K'. Otherwise it is the implicit rejection key J(z || c). The
// selection uses constant-time compare and copy, so it does not branch on
// secret data.
func kemDecaps(dk *DecapsulationKey768, c *[CiphertextSize768]byte) (K []byte) {
	fips140.RecordApproved()
	m := pkeDecrypt(&dk.decryptionKey, c)
	g := sha3.New512()
	g.Write(m[:])
	g.Write(dk.h[:])
	G := g.Sum(make([]byte, 0, 64))
	Kprime, r := G[:SharedKeySize], G[SharedKeySize:]
	// Kout starts as the implicit rejection key J(z || c), read from SHAKE256.
	J := sha3.NewShake256()
	J.Write(dk.z[:])
	J.Write(c[:])
	Kout := make([]byte, SharedKeySize)
	J.Read(Kout)
	var cc [CiphertextSize768]byte
	c1 := pkeEncrypt(&cc, &dk.encryptionKey, (*[32]byte)(m), r)

	// Overwrite Kout with K' only when c == c1.
	subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime)
	return Kout
}
   419  
// pkeDecrypt decrypts a ciphertext.
//
// It implements K-PKE.Decrypt according to FIPS 203, Algorithm 15,
// although s is retained from kemKeyGen.
//
// The first k*encodingSize10 bytes of c hold the compressed vector u and the
// remainder holds the compressed polynomial v. The message is recovered as
// w = v − NTT⁻¹(s⊺ ◦ NTT(u)) and returned as ByteEncode₁(Compress₁(w)) in a
// newly allocated slice.
func pkeDecrypt(dx *decryptionKey, c *[CiphertextSize768]byte) []byte {
	u := make([]ringElement, k)
	for i := range u {
		b := (*[encodingSize10]byte)(c[encodingSize10*i : encodingSize10*(i+1)])
		u[i] = ringDecodeAndDecompress10(b)
	}

	b := (*[encodingSize4]byte)(c[encodingSize10*k:])
	v := ringDecodeAndDecompress4(b)

	var mask nttElement // s⊺ ◦ NTT(u)
	for i := range dx.s {
		mask = polyAdd(mask, nttMul(dx.s[i], ntt(u[i])))
	}
	w := polySub(v, inverseNTT(mask))

	return ringCompressAndEncode1(nil, w)
}
   442  

View as plain text