Source file src/crypto/internal/fips140/mldsa/semiexpanded.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mldsa
     6  
     7  import (
     8  	"crypto/internal/fips140/drbg"
     9  	"errors"
    10  	"math/bits"
    11  )
    12  
    13  // FIPS 204 defines a needless semi-expanded format for private keys. This is
    14  // not a good format for key storage and exchange, because it is large and
    15  // requires careful parsing to reject malformed keys. Seeds instead are just 32
    16  // bytes, are always valid, and always expand to valid keys in memory. It is
    17  // *also* a poor in-memory format, because it defers computing the NTT of s1,
    18  // s2, and t0 and the expansion of A until signing time, which is inefficient.
    19  // For a hot second, it looked like we could have all agreed to only use seeds,
    20  // but unfortunately OpenSSL and BouncyCastle lobbied hard against that during
    21  // the WGLC of the LAMPS IETF working group. Also, ACVP tests provide and expect
    22  // semi-expanded keys, so we implement them here for testing purposes.
    23  
    24  func semiExpandedPrivKeySize(p parameters) int {
    25  	k, l := p.k, p.l
    26  	ηBitlen := bits.Len(uint(p.η)) + 1
    27  	// ρ + K + tr + l × n × η-bit coefficients of s₁ +
    28  	// k × n × η-bit coefficients of s₂ + k × n × 13-bit coefficients of t₀
    29  	return 32 + 32 + 64 + l*n*ηBitlen/8 + k*n*ηBitlen/8 + k*n*13/8
    30  }
    31  
    32  // TestingOnlyNewPrivateKeyFromSemiExpanded creates a PrivateKey from a
    33  // semi-expanded private key encoding, for testing purposes. It rejects
    34  // inconsistent keys.
    35  //
    36  // [PrivateKey.Bytes] must NOT be called on the resulting key, as it will
    37  // produce a random value.
    38  func TestingOnlyNewPrivateKeyFromSemiExpanded(sk []byte) (*PrivateKey, error) {
    39  	var p parameters
    40  	switch len(sk) {
    41  	case semiExpandedPrivKeySize(params44):
    42  		p = params44
    43  	case semiExpandedPrivKeySize(params65):
    44  		p = params65
    45  	case semiExpandedPrivKeySize(params87):
    46  		p = params87
    47  	default:
    48  		return nil, errors.New("mldsa: invalid semi-expanded private key size")
    49  	}
    50  	k, l := p.k, p.l
    51  
    52  	ρ, K, tr, s1, s2, t0, err := skDecode(sk, p)
    53  	if err != nil {
    54  		return nil, err
    55  	}
    56  
    57  	priv := &PrivateKey{pub: PublicKey{p: p}}
    58  	priv.k = K
    59  	priv.pub.tr = tr
    60  	A := priv.pub.a[:k*l]
    61  	computeMatrixA(A, ρ[:], p)
    62  	for r := range l {
    63  		priv.s1[r] = ntt(s1[r])
    64  	}
    65  	for r := range k {
    66  		priv.s2[r] = ntt(s2[r])
    67  	}
    68  	for r := range k {
    69  		priv.t0[r] = ntt(t0[r])
    70  	}
    71  
    72  	// We need to put something in priv.seed, and putting random bytes feels
    73  	// safer than putting anything predictable.
    74  	drbg.Read(priv.seed[:])
    75  
    76  	// Making this format *even more* annoying, we need to recompute t1 from ρ,
    77  	// s1, and s2 if we want to generate the public key. This is essentially as
    78  	// much work as regenerating everything from seed.
    79  	//
    80  	// You might also notice that the semi-expanded format also stores t0 and a
    81  	// hash of the public key, though. How are we supposed to check they are
    82  	// consistent without regenerating the public key? Do we even need to check?
    83  	// Who knows! FIPS 204 says
    84  	//
    85  	//  > Note that there exist malformed inputs that can cause skDecode to
    86  	//  > return values that are not in the correct range. Hence, skDecode
    87  	//  > should only be run on inputs that come from trusted sources.
    88  	//
    89  	// so it sounds like it doesn't even want us to check the coefficients are
    90  	// within bounds, but especially if using this format for key exchange, that
    91  	// sounds like a bad idea. So we check everything.
    92  
    93  	t1 := make([][n]uint16, k, maxK)
    94  	for i := range k {
    95  		tHat := priv.s2[i]
    96  		for j := range l {
    97  			tHat = polyAdd(tHat, nttMul(A[i*l+j], priv.s1[j]))
    98  		}
    99  		t := inverseNTT(tHat)
   100  		for j := range n {
   101  			r1, r0 := power2Round(t[j])
   102  			t1[i][j] = r1
   103  			if r0 != t0[i][j] {
   104  				return nil, errors.New("mldsa: semi-expanded private key inconsistent with t0")
   105  			}
   106  		}
   107  	}
   108  
   109  	pk := pkEncode(priv.pub.raw[:0], ρ[:], t1, p)
   110  	if computePublicKeyHash(pk) != tr {
   111  		return nil, errors.New("mldsa: semi-expanded private key inconsistent with public key hash")
   112  	}
   113  	computeT1Hat(priv.pub.t1[:k], t1) // NTT(t₁ ⋅ 2ᵈ)
   114  
   115  	return priv, nil
   116  }
   117  
   118  func TestingOnlyPrivateKeySemiExpandedBytes(priv *PrivateKey) []byte {
   119  	k, l, η := priv.pub.p.k, priv.pub.p.l, priv.pub.p.η
   120  	sk := make([]byte, 0, semiExpandedPrivKeySize(priv.pub.p))
   121  	sk = append(sk, priv.pub.raw[:32]...) // ρ
   122  	sk = append(sk, priv.k[:]...)         // K
   123  	sk = append(sk, priv.pub.tr[:]...)    // tr
   124  	for i := range l {
   125  		sk = bitPackSlow(sk, inverseNTT(priv.s1[i]), η, η)
   126  	}
   127  	for i := range k {
   128  		sk = bitPackSlow(sk, inverseNTT(priv.s2[i]), η, η)
   129  	}
   130  	const bound = 1 << (13 - 1) // 2^(d-1)
   131  	for i := range k {
   132  		sk = bitPackSlow(sk, inverseNTT(priv.t0[i]), bound-1, bound)
   133  	}
   134  	return sk
   135  }
   136  
   137  func skDecode(sk []byte, p parameters) (ρ, K [32]byte, tr [64]byte, s1, s2, t0 []ringElement, err error) {
   138  	k, l, η := p.k, p.l, p.η
   139  	if len(sk) != semiExpandedPrivKeySize(p) {
   140  		err = errors.New("mldsa: invalid semi-expanded private key size")
   141  		return
   142  	}
   143  	copy(ρ[:], sk[:32])
   144  	sk = sk[32:]
   145  	copy(K[:], sk[:32])
   146  	sk = sk[32:]
   147  	copy(tr[:], sk[:64])
   148  	sk = sk[64:]
   149  
   150  	s1 = make([]ringElement, l)
   151  	for i := range l {
   152  		length := n * bits.Len(uint(η)*2) / 8
   153  		s1[i], err = bitUnpackSlow(sk[:length], η, η)
   154  		if err != nil {
   155  			return
   156  		}
   157  		sk = sk[length:]
   158  	}
   159  
   160  	s2 = make([]ringElement, k)
   161  	for i := range k {
   162  		length := n * bits.Len(uint(η)*2) / 8
   163  		s2[i], err = bitUnpackSlow(sk[:length], η, η)
   164  		if err != nil {
   165  			return
   166  		}
   167  		sk = sk[length:]
   168  	}
   169  
   170  	const bound = 1 << (13 - 1) // 2^(d-1)
   171  	t0 = make([]ringElement, k)
   172  	for i := range k {
   173  		length := n * 13 / 8
   174  		t0[i], err = bitUnpackSlow(sk[:length], bound-1, bound)
   175  		if err != nil {
   176  			return
   177  		}
   178  		sk = sk[length:]
   179  	}
   180  
   181  	return
   182  }
   183  
   184  func bitPackSlow(buf []byte, r ringElement, a, b int) []byte {
   185  	bitlen := bits.Len(uint(a + b))
   186  	if bitlen <= 0 || bitlen > 16 {
   187  		panic("mldsa: internal error: invalid bitlen")
   188  	}
   189  	out, v := sliceForAppend(buf, n*bitlen/8)
   190  	var acc uint32
   191  	var accBits uint
   192  	for i := range r {
   193  		w := int32(b) - fieldCenteredMod(r[i])
   194  		acc |= uint32(w) << accBits
   195  		accBits += uint(bitlen)
   196  		for accBits >= 8 {
   197  			v[0] = byte(acc)
   198  			v = v[1:]
   199  			acc >>= 8
   200  			accBits -= 8
   201  		}
   202  	}
   203  	if accBits > 0 {
   204  		v[0] = byte(acc)
   205  	}
   206  	return out
   207  }
   208  
   209  func bitUnpackSlow(v []byte, a, b int) (ringElement, error) {
   210  	bitlen := bits.Len(uint(a + b))
   211  	if bitlen <= 0 || bitlen > 16 {
   212  		panic("mldsa: internal error: invalid bitlen")
   213  	}
   214  	if len(v) != n*bitlen/8 {
   215  		return ringElement{}, errors.New("mldsa: invalid input length for bitUnpackSlow")
   216  	}
   217  
   218  	mask := uint32((1 << bitlen) - 1)
   219  	maxValue := uint32(a + b)
   220  
   221  	var r ringElement
   222  	var acc uint32
   223  	var accBits uint
   224  	vIdx := 0
   225  
   226  	for i := range r {
   227  		for accBits < uint(bitlen) {
   228  			if vIdx < len(v) {
   229  				acc |= uint32(v[vIdx]) << accBits
   230  				vIdx++
   231  				accBits += 8
   232  			}
   233  		}
   234  		w := acc & mask
   235  		if w > maxValue {
   236  			return ringElement{}, errors.New("mldsa: coefficient out of range")
   237  		}
   238  		r[i] = fieldSubToMontgomery(uint32(b), w)
   239  		acc >>= bitlen
   240  		accBits -= uint(bitlen)
   241  	}
   242  
   243  	return r, nil
   244  }
   245  

View as plain text