1
2
3
4
5 package mldsa
6
7 import (
8 "crypto/internal/fips140/drbg"
9 "errors"
10 "math/bits"
11 )
12
13
14
15
16
17
18
19
20
21
22
23
24 func semiExpandedPrivKeySize(p parameters) int {
25 k, l := p.k, p.l
26 ηBitlen := bits.Len(uint(p.η)) + 1
27
28
29 return 32 + 32 + 64 + l*n*ηBitlen/8 + k*n*ηBitlen/8 + k*n*13/8
30 }
31
32
33
34
35
36
37
38 func TestingOnlyNewPrivateKeyFromSemiExpanded(sk []byte) (*PrivateKey, error) {
39 var p parameters
40 switch len(sk) {
41 case semiExpandedPrivKeySize(params44):
42 p = params44
43 case semiExpandedPrivKeySize(params65):
44 p = params65
45 case semiExpandedPrivKeySize(params87):
46 p = params87
47 default:
48 return nil, errors.New("mldsa: invalid semi-expanded private key size")
49 }
50 k, l := p.k, p.l
51
52 ρ, K, tr, s1, s2, t0, err := skDecode(sk, p)
53 if err != nil {
54 return nil, err
55 }
56
57 priv := &PrivateKey{pub: PublicKey{p: p}}
58 priv.k = K
59 priv.pub.tr = tr
60 A := priv.pub.a[:k*l]
61 computeMatrixA(A, ρ[:], p)
62 for r := range l {
63 priv.s1[r] = ntt(s1[r])
64 }
65 for r := range k {
66 priv.s2[r] = ntt(s2[r])
67 }
68 for r := range k {
69 priv.t0[r] = ntt(t0[r])
70 }
71
72
73
74 drbg.Read(priv.seed[:])
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93 t1 := make([][n]uint16, k, maxK)
94 for i := range k {
95 tHat := priv.s2[i]
96 for j := range l {
97 tHat = polyAdd(tHat, nttMul(A[i*l+j], priv.s1[j]))
98 }
99 t := inverseNTT(tHat)
100 for j := range n {
101 r1, r0 := power2Round(t[j])
102 t1[i][j] = r1
103 if r0 != t0[i][j] {
104 return nil, errors.New("mldsa: semi-expanded private key inconsistent with t0")
105 }
106 }
107 }
108
109 pk := pkEncode(priv.pub.raw[:0], ρ[:], t1, p)
110 if computePublicKeyHash(pk) != tr {
111 return nil, errors.New("mldsa: semi-expanded private key inconsistent with public key hash")
112 }
113 computeT1Hat(priv.pub.t1[:k], t1)
114
115 return priv, nil
116 }
117
118 func TestingOnlyPrivateKeySemiExpandedBytes(priv *PrivateKey) []byte {
119 k, l, η := priv.pub.p.k, priv.pub.p.l, priv.pub.p.η
120 sk := make([]byte, 0, semiExpandedPrivKeySize(priv.pub.p))
121 sk = append(sk, priv.pub.raw[:32]...)
122 sk = append(sk, priv.k[:]...)
123 sk = append(sk, priv.pub.tr[:]...)
124 for i := range l {
125 sk = bitPackSlow(sk, inverseNTT(priv.s1[i]), η, η)
126 }
127 for i := range k {
128 sk = bitPackSlow(sk, inverseNTT(priv.s2[i]), η, η)
129 }
130 const bound = 1 << (13 - 1)
131 for i := range k {
132 sk = bitPackSlow(sk, inverseNTT(priv.t0[i]), bound-1, bound)
133 }
134 return sk
135 }
136
137 func skDecode(sk []byte, p parameters) (ρ, K [32]byte, tr [64]byte, s1, s2, t0 []ringElement, err error) {
138 k, l, η := p.k, p.l, p.η
139 if len(sk) != semiExpandedPrivKeySize(p) {
140 err = errors.New("mldsa: invalid semi-expanded private key size")
141 return
142 }
143 copy(ρ[:], sk[:32])
144 sk = sk[32:]
145 copy(K[:], sk[:32])
146 sk = sk[32:]
147 copy(tr[:], sk[:64])
148 sk = sk[64:]
149
150 s1 = make([]ringElement, l)
151 for i := range l {
152 length := n * bits.Len(uint(η)*2) / 8
153 s1[i], err = bitUnpackSlow(sk[:length], η, η)
154 if err != nil {
155 return
156 }
157 sk = sk[length:]
158 }
159
160 s2 = make([]ringElement, k)
161 for i := range k {
162 length := n * bits.Len(uint(η)*2) / 8
163 s2[i], err = bitUnpackSlow(sk[:length], η, η)
164 if err != nil {
165 return
166 }
167 sk = sk[length:]
168 }
169
170 const bound = 1 << (13 - 1)
171 t0 = make([]ringElement, k)
172 for i := range k {
173 length := n * 13 / 8
174 t0[i], err = bitUnpackSlow(sk[:length], bound-1, bound)
175 if err != nil {
176 return
177 }
178 sk = sk[length:]
179 }
180
181 return
182 }
183
184 func bitPackSlow(buf []byte, r ringElement, a, b int) []byte {
185 bitlen := bits.Len(uint(a + b))
186 if bitlen <= 0 || bitlen > 16 {
187 panic("mldsa: internal error: invalid bitlen")
188 }
189 out, v := sliceForAppend(buf, n*bitlen/8)
190 var acc uint32
191 var accBits uint
192 for i := range r {
193 w := int32(b) - fieldCenteredMod(r[i])
194 acc |= uint32(w) << accBits
195 accBits += uint(bitlen)
196 for accBits >= 8 {
197 v[0] = byte(acc)
198 v = v[1:]
199 acc >>= 8
200 accBits -= 8
201 }
202 }
203 if accBits > 0 {
204 v[0] = byte(acc)
205 }
206 return out
207 }
208
209 func bitUnpackSlow(v []byte, a, b int) (ringElement, error) {
210 bitlen := bits.Len(uint(a + b))
211 if bitlen <= 0 || bitlen > 16 {
212 panic("mldsa: internal error: invalid bitlen")
213 }
214 if len(v) != n*bitlen/8 {
215 return ringElement{}, errors.New("mldsa: invalid input length for bitUnpackSlow")
216 }
217
218 mask := uint32((1 << bitlen) - 1)
219 maxValue := uint32(a + b)
220
221 var r ringElement
222 var acc uint32
223 var accBits uint
224 vIdx := 0
225
226 for i := range r {
227 for accBits < uint(bitlen) {
228 if vIdx < len(v) {
229 acc |= uint32(v[vIdx]) << accBits
230 vIdx++
231 accBits += 8
232 }
233 }
234 w := acc & mask
235 if w > maxValue {
236 return ringElement{}, errors.New("mldsa: coefficient out of range")
237 }
238 r[i] = fieldSubToMontgomery(uint32(b), w)
239 acc >>= bitlen
240 accBits -= uint(bitlen)
241 }
242
243 return r, nil
244 }
245
View as plain text