// Copyright 2025 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package mldsa import ( "crypto/internal/fips140/drbg" "errors" "math/bits" ) // FIPS 204 defines a needless semi-expanded format for private keys. This is // not a good format for key storage and exchange, because it is large and // requires careful parsing to reject malformed keys. Seeds instead are just 32 // bytes, are always valid, and always expand to valid keys in memory. It is // *also* a poor in-memory format, because it defers computing the NTT of s1, // s2, and t0 and the expansion of A until signing time, which is inefficient. // For a hot second, it looked like we could have all agreed to only use seeds, // but unfortunately OpenSSL and BouncyCastle lobbied hard against that during // the WGLC of the LAMPS IETF working group. Also, ACVP tests provide and expect // semi-expanded keys, so we implement them here for testing purposes. func semiExpandedPrivKeySize(p parameters) int { k, l := p.k, p.l ηBitlen := bits.Len(uint(p.η)) + 1 // ρ + K + tr + l × n × η-bit coefficients of s₁ + // k × n × η-bit coefficients of s₂ + k × n × 13-bit coefficients of t₀ return 32 + 32 + 64 + l*n*ηBitlen/8 + k*n*ηBitlen/8 + k*n*13/8 } // TestingOnlyNewPrivateKeyFromSemiExpanded creates a PrivateKey from a // semi-expanded private key encoding, for testing purposes. It rejects // inconsistent keys. // // [PrivateKey.Bytes] must NOT be called on the resulting key, as it will // produce a random value. func TestingOnlyNewPrivateKeyFromSemiExpanded(sk []byte) (*PrivateKey, error) { var p parameters switch len(sk) { case semiExpandedPrivKeySize(params44): p = params44 case semiExpandedPrivKeySize(params65): p = params65 case semiExpandedPrivKeySize(params87): p = params87 default: return nil, errors.New("mldsa: invalid semi-expanded private key size") } k, l := p.k, p.l ρ, K, tr, s1, s2, t0, err := skDecode(sk, p) if err != nil { return nil, err } priv := &PrivateKey{pub: PublicKey{p: p}} priv.k = K priv.pub.tr = tr A := priv.pub.a[:k*l] computeMatrixA(A, ρ[:], p) for r := range l { priv.s1[r] = ntt(s1[r]) } for r := range k { priv.s2[r] = ntt(s2[r]) } for r := range k { priv.t0[r] = ntt(t0[r]) } // We need to put something in priv.seed, and putting random bytes feels // safer than putting anything predictable. drbg.Read(priv.seed[:]) // Making this format *even more* annoying, we need to recompute t1 from ρ, // s1, and s2 if we want to generate the public key. This is essentially as // much work as regenerating everything from seed. // // You might also notice that the semi-expanded format also stores t0 and a // hash of the public key, though. How are we supposed to check they are // consistent without regenerating the public key? Do we even need to check? // Who knows! FIPS 204 says // // > Note that there exist malformed inputs that can cause skDecode to // > return values that are not in the correct range. Hence, skDecode // > should only be run on inputs that come from trusted sources. // // so it sounds like it doesn't even want us to check the coefficients are // within bounds, but especially if using this format for key exchange, that // sounds like a bad idea. So we check everything. t1 := make([][n]uint16, k, maxK) for i := range k { tHat := priv.s2[i] for j := range l { tHat = polyAdd(tHat, nttMul(A[i*l+j], priv.s1[j])) } t := inverseNTT(tHat) for j := range n { r1, r0 := power2Round(t[j]) t1[i][j] = r1 if r0 != t0[i][j] { return nil, errors.New("mldsa: semi-expanded private key inconsistent with t0") } } } pk := pkEncode(priv.pub.raw[:0], ρ[:], t1, p) if computePublicKeyHash(pk) != tr { return nil, errors.New("mldsa: semi-expanded private key inconsistent with public key hash") } computeT1Hat(priv.pub.t1[:k], t1) // NTT(t₁ ⋅ 2ᵈ) return priv, nil } func TestingOnlyPrivateKeySemiExpandedBytes(priv *PrivateKey) []byte { k, l, η := priv.pub.p.k, priv.pub.p.l, priv.pub.p.η sk := make([]byte, 0, semiExpandedPrivKeySize(priv.pub.p)) sk = append(sk, priv.pub.raw[:32]...) // ρ sk = append(sk, priv.k[:]...) // K sk = append(sk, priv.pub.tr[:]...) // tr for i := range l { sk = bitPackSlow(sk, inverseNTT(priv.s1[i]), η, η) } for i := range k { sk = bitPackSlow(sk, inverseNTT(priv.s2[i]), η, η) } const bound = 1 << (13 - 1) // 2^(d-1) for i := range k { sk = bitPackSlow(sk, inverseNTT(priv.t0[i]), bound-1, bound) } return sk } func skDecode(sk []byte, p parameters) (ρ, K [32]byte, tr [64]byte, s1, s2, t0 []ringElement, err error) { k, l, η := p.k, p.l, p.η if len(sk) != semiExpandedPrivKeySize(p) { err = errors.New("mldsa: invalid semi-expanded private key size") return } copy(ρ[:], sk[:32]) sk = sk[32:] copy(K[:], sk[:32]) sk = sk[32:] copy(tr[:], sk[:64]) sk = sk[64:] s1 = make([]ringElement, l) for i := range l { length := n * bits.Len(uint(η)*2) / 8 s1[i], err = bitUnpackSlow(sk[:length], η, η) if err != nil { return } sk = sk[length:] } s2 = make([]ringElement, k) for i := range k { length := n * bits.Len(uint(η)*2) / 8 s2[i], err = bitUnpackSlow(sk[:length], η, η) if err != nil { return } sk = sk[length:] } const bound = 1 << (13 - 1) // 2^(d-1) t0 = make([]ringElement, k) for i := range k { length := n * 13 / 8 t0[i], err = bitUnpackSlow(sk[:length], bound-1, bound) if err != nil { return } sk = sk[length:] } return } func bitPackSlow(buf []byte, r ringElement, a, b int) []byte { bitlen := bits.Len(uint(a + b)) if bitlen <= 0 || bitlen > 16 { panic("mldsa: internal error: invalid bitlen") } out, v := sliceForAppend(buf, n*bitlen/8) var acc uint32 var accBits uint for i := range r { w := int32(b) - fieldCenteredMod(r[i]) acc |= uint32(w) << accBits accBits += uint(bitlen) for accBits >= 8 { v[0] = byte(acc) v = v[1:] acc >>= 8 accBits -= 8 } } if accBits > 0 { v[0] = byte(acc) } return out } func bitUnpackSlow(v []byte, a, b int) (ringElement, error) { bitlen := bits.Len(uint(a + b)) if bitlen <= 0 || bitlen > 16 { panic("mldsa: internal error: invalid bitlen") } if len(v) != n*bitlen/8 { return ringElement{}, errors.New("mldsa: invalid input length for bitUnpackSlow") } mask := uint32((1 << bitlen) - 1) maxValue := uint32(a + b) var r ringElement var acc uint32 var accBits uint vIdx := 0 for i := range r { for accBits < uint(bitlen) { if vIdx < len(v) { acc |= uint32(v[vIdx]) << accBits vIdx++ accBits += 8 } } w := acc & mask if w > maxValue { return ringElement{}, errors.New("mldsa: coefficient out of range") } r[i] = fieldSubToMontgomery(uint32(b), w) acc >>= bitlen accBits -= uint(bitlen) } return r, nil }