1
2
3 package mlkem
4
5 import (
6 "bytes"
7 "crypto/internal/fips140"
8 "crypto/internal/fips140/drbg"
9 "crypto/internal/fips140/sha3"
10 "crypto/internal/fips140/subtle"
11 "errors"
12 )
13
14
15
16 type DecapsulationKey1024 struct {
17 d [32]byte
18 z [32]byte
19
20 ρ [32]byte
21 h [32]byte
22
23 encryptionKey1024
24 decryptionKey1024
25 }
26
27
28
29
30 func (dk *DecapsulationKey1024) Bytes() []byte {
31 var b [SeedSize]byte
32 copy(b[:], dk.d[:])
33 copy(b[32:], dk.z[:])
34 return b[:]
35 }
36
37
38
39
40
41
42 func TestingOnlyExpandedBytes1024(dk *DecapsulationKey1024) []byte {
43 b := make([]byte, 0, decapsulationKeySize1024)
44
45
46 for i := range dk.s {
47 b = polyByteEncode(b, dk.s[i])
48 }
49
50
51 for i := range dk.t {
52 b = polyByteEncode(b, dk.t[i])
53 }
54 b = append(b, dk.ρ[:]...)
55
56
57 b = append(b, dk.h[:]...)
58 b = append(b, dk.z[:]...)
59
60 return b
61 }
62
63
64
65 func (dk *DecapsulationKey1024) EncapsulationKey() *EncapsulationKey1024 {
66 return &EncapsulationKey1024{
67 ρ: dk.ρ,
68 h: dk.h,
69 encryptionKey1024: dk.encryptionKey1024,
70 }
71 }
72
73
74
75 type EncapsulationKey1024 struct {
76 ρ [32]byte
77 h [32]byte
78 encryptionKey1024
79 }
80
81
82 func (ek *EncapsulationKey1024) Bytes() []byte {
83
84 b := make([]byte, 0, EncapsulationKeySize1024)
85 return ek.bytes(b)
86 }
87
88 func (ek *EncapsulationKey1024) bytes(b []byte) []byte {
89 for i := range ek.t {
90 b = polyByteEncode(b, ek.t[i])
91 }
92 b = append(b, ek.ρ[:]...)
93 return b
94 }
95
96
// encryptionKey1024 is the expanded public-key-encryption (K-PKE) key used by
// pkeEncrypt1024. It is embedded in both EncapsulationKey1024 and
// DecapsulationKey1024.
type encryptionKey1024 struct {
	// t is the public vector t = A ∘ s + e, kept in the NTT domain.
	t [k1024]nttElement
	// a is the k×k matrix A sampled from ρ, stored row-major: entry (i, j)
	// lives at index i*k1024+j.
	a [k1024 * k1024]nttElement
}
101
102
// decryptionKey1024 is the expanded public-key-encryption (K-PKE) decryption
// key used by pkeDecrypt1024.
type decryptionKey1024 struct {
	// s is the secret vector, kept in the NTT domain.
	s [k1024]nttElement
}
106
107
108
109 func GenerateKey1024() (*DecapsulationKey1024, error) {
110
111 dk := &DecapsulationKey1024{}
112 return generateKey1024(dk)
113 }
114
115 func generateKey1024(dk *DecapsulationKey1024) (*DecapsulationKey1024, error) {
116 fipsSelfTest()
117 var d [32]byte
118 drbg.Read(d[:])
119 var z [32]byte
120 drbg.Read(z[:])
121 kemKeyGen1024(dk, &d, &z)
122 fips140.PCT("ML-KEM PCT", func() error { return kemPCT1024(dk) })
123 fips140.RecordApproved()
124 return dk, nil
125 }
126
127
128
129 func GenerateKeyInternal1024(d, z *[32]byte) *DecapsulationKey1024 {
130 fipsSelfTest()
131 dk := &DecapsulationKey1024{}
132 kemKeyGen1024(dk, d, z)
133 return dk
134 }
135
136
137
138 func NewDecapsulationKey1024(seed []byte) (*DecapsulationKey1024, error) {
139
140 dk := &DecapsulationKey1024{}
141 return newKeyFromSeed1024(dk, seed)
142 }
143
144 func newKeyFromSeed1024(dk *DecapsulationKey1024, seed []byte) (*DecapsulationKey1024, error) {
145 if len(seed) != SeedSize {
146 return nil, errors.New("mlkem: invalid seed length")
147 }
148 d := (*[32]byte)(seed[:32])
149 z := (*[32]byte)(seed[32:])
150 kemKeyGen1024(dk, d, z)
151 fips140.RecordApproved()
152 return dk, nil
153 }
154
155
156
157
158
159
160
161
// TestingOnlyNewDecapsulationKey1024 parses a decapsulation key from the
// expanded NIST encoding produced by TestingOnlyExpandedBytes1024:
//
//	ByteEncode₁₂(s) || ek || H(ek) || z
//
// where ek is an encapsulation key in the form NewEncapsulationKey1024 parses.
// It returns an error if b has the wrong length, if s or ek fail to decode, or
// if the embedded H(ek) does not match the hash of ek.
//
// The encoding does not contain the seed half d, so a random d is drawn from
// the DRBG instead. Bytes on the returned key therefore will not, in general,
// yield a seed that regenerates this key. As its name says, this function is
// meant for tests only.
func TestingOnlyNewDecapsulationKey1024(b []byte) (*DecapsulationKey1024, error) {
	if len(b) != decapsulationKeySize1024 {
		return nil, errors.New("mlkem: invalid NIST decapsulation key length")
	}

	// ByteEncode₁₂(s): one encodingSize12-byte chunk per polynomial.
	dk := &DecapsulationKey1024{}
	for i := range dk.s {
		var err error
		dk.s[i], err = polyByteDecode[nttElement](b[:encodingSize12])
		if err != nil {
			return nil, errors.New("mlkem: invalid secret key encoding")
		}
		b = b[encodingSize12:]
	}

	// ek = ByteEncode₁₂(t) || ρ. Parsing it also computes h and samples A.
	ek, err := NewEncapsulationKey1024(b[:EncapsulationKeySize1024])
	if err != nil {
		return nil, err
	}
	dk.ρ = ek.ρ
	dk.h = ek.h
	dk.encryptionKey1024 = ek.encryptionKey1024
	b = b[EncapsulationKeySize1024:]

	// The stored H(ek) must match the hash recomputed from ek.
	if !bytes.Equal(dk.h[:], b[:32]) {
		return nil, errors.New("mlkem: inconsistent H(ek) in encoded bytes")
	}
	b = b[32:]

	// The remaining bytes are z.
	copy(dk.z[:], b)

	// d is not part of this encoding. A random d is used so that Bytes
	// returns a random seed rather than an all-zero one; it does not
	// correspond to the parsed key.
	drbg.Read(dk.d[:])

	return dk, nil
}
201
202
203
204
205
206
// kemKeyGen1024 deterministically derives a complete decapsulation key from
// the 32-byte seeds d and z and stores it in dk. It writes every field: d, z,
// ρ, h, the matrix A, and the vectors s and t. It follows
// ML-KEM.KeyGen_internal from FIPS 203, with K-PKE.KeyGen inlined.
//
// It performs no consistency check of its own. Of its callers in this file,
// only generateKey1024 runs the pairwise consistency test.
func kemKeyGen1024(dk *DecapsulationKey1024, d, z *[32]byte) {
	dk.d = *d
	dk.z = *z

	// (ρ, σ) = G(d || k), where G is SHA3-512 and the module dimension k1024
	// is appended to d as a domain separator.
	g := sha3.New512()
	g.Write(d[:])
	g.Write([]byte{k1024})
	G := g.Sum(make([]byte, 0, 64))
	ρ, σ := G[:32], G[32:]
	dk.ρ = [32]byte(ρ)

	// Sample the k×k matrix A from ρ, stored row-major in dk.a. sampleNTT is
	// given the column index j before the row index i.
	A := &dk.a
	for i := byte(0); i < k1024; i++ {
		for j := byte(0); j < k1024; j++ {
			A[i*k1024+j] = sampleNTT(ρ, j, i)
		}
	}

	// Sample the secret vector s and the error vector e from σ, both moved
	// into the NTT domain. N is the samplePolyCBD nonce; it keeps counting
	// across s and e, so every polynomial uses a distinct value.
	var N byte
	s := &dk.s
	for i := range s {
		s[i] = ntt(samplePolyCBD(σ, N))
		N++
	}
	e := make([]nttElement, k1024)
	for i := range e {
		e[i] = ntt(samplePolyCBD(σ, N))
		N++
	}

	// t = A ∘ s + e, computed entirely in the NTT domain.
	t := &dk.t
	for i := range t {
		t[i] = e[i]
		for j := range s {
			t[i] = polyAdd(t[i], nttMul(A[i*k1024+j], s[j]))
		}
	}

	// h = H(ek), where H is SHA3-256 over the encoded encapsulation key. It
	// is cached in dk for use by encapsulation and decapsulation.
	H := sha3.New256()
	ek := dk.EncapsulationKey().Bytes()
	H.Write(ek)
	H.Sum(dk.h[:0])
}
250
251
252
253
254
255
256
257
258
259 func kemPCT1024(dk *DecapsulationKey1024) error {
260 ek := dk.EncapsulationKey()
261 K, c := ek.Encapsulate()
262 K1, err := dk.Decapsulate(c)
263 if err != nil {
264 return err
265 }
266 if subtle.ConstantTimeCompare(K, K1) != 1 {
267 return errors.New("mlkem: PCT failed")
268 }
269 return nil
270 }
271
272
273
274
275
276 func (ek *EncapsulationKey1024) Encapsulate() (sharedKey, ciphertext []byte) {
277
278 var cc [CiphertextSize1024]byte
279 return ek.encapsulate(&cc)
280 }
281
282 func (ek *EncapsulationKey1024) encapsulate(cc *[CiphertextSize1024]byte) (sharedKey, ciphertext []byte) {
283 fipsSelfTest()
284 var m [messageSize]byte
285 drbg.Read(m[:])
286
287
288 fips140.RecordApproved()
289 return kemEncaps1024(cc, ek, &m)
290 }
291
292
293
294 func (ek *EncapsulationKey1024) EncapsulateInternal(m *[32]byte) (sharedKey, ciphertext []byte) {
295 fipsSelfTest()
296 cc := &[CiphertextSize1024]byte{}
297 return kemEncaps1024(cc, ek, m)
298 }
299
300
301
302
303 func kemEncaps1024(cc *[CiphertextSize1024]byte, ek *EncapsulationKey1024, m *[messageSize]byte) (K, c []byte) {
304 g := sha3.New512()
305 g.Write(m[:])
306 g.Write(ek.h[:])
307 G := g.Sum(nil)
308 K, r := G[:SharedKeySize], G[SharedKeySize:]
309 c = pkeEncrypt1024(cc, &ek.encryptionKey1024, m, r)
310 return K, c
311 }
312
313
314
315 func NewEncapsulationKey1024(encapsulationKey []byte) (*EncapsulationKey1024, error) {
316
317 ek := &EncapsulationKey1024{}
318 return parseEK1024(ek, encapsulationKey)
319 }
320
321
322
323
324
// parseEK1024 decodes the encapsulation key encoding ekPKE
// (ByteEncode₁₂(t) || ρ) into ek. It also sets h = SHA3-256(ekPKE) and
// re-samples the matrix A from ρ, then returns ek.
//
// It returns an error if ekPKE is not EncapsulationKeySize1024 bytes long, or
// the polyByteDecode error if a polynomial of t fails to decode. In the latter
// case ek may already have been partially written.
//
// NOTE(review): the FIPS 203 modulus check on ek is presumably enforced by
// polyByteDecode rejecting out-of-range coefficients — confirm.
func parseEK1024(ek *EncapsulationKey1024, ekPKE []byte) (*EncapsulationKey1024, error) {
	if len(ekPKE) != EncapsulationKeySize1024 {
		return nil, errors.New("mlkem: invalid encapsulation key length")
	}

	// h = H(ek), hashed over the raw encoding before any of it is consumed.
	h := sha3.New256()
	h.Write(ekPKE)
	h.Sum(ek.h[:0])

	// ByteEncode₁₂(t): one encodingSize12-byte chunk per polynomial.
	for i := range ek.t {
		var err error
		ek.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12])
		if err != nil {
			return nil, err
		}
		ekPKE = ekPKE[encodingSize12:]
	}
	// The remaining bytes are ρ.
	copy(ek.ρ[:], ekPKE)

	// Re-sample A from ρ exactly as kemKeyGen1024 does: row-major storage,
	// with sampleNTT given the column index before the row index.
	for i := byte(0); i < k1024; i++ {
		for j := byte(0); j < k1024; j++ {
			ek.a[i*k1024+j] = sampleNTT(ek.ρ[:], j, i)
		}
	}

	return ek, nil
}
352
353
354
355
356
// pkeEncrypt1024 encrypts the message m under the expanded encryption key ex
// using the randomness seed rnd, following K-PKE.Encrypt from FIPS 203. The
// ciphertext is encoded into cc's storage, starting from cc[:0], and the
// resulting slice is returned.
//
// Callers in this file pass as rnd the second part of the G output computed
// in kemEncaps1024 and kemDecaps1024. Its length is not checked here.
func pkeEncrypt1024(cc *[CiphertextSize1024]byte, ex *encryptionKey1024, m *[messageSize]byte, rnd []byte) []byte {
	// Sample r (moved into the NTT domain), e1 and e2 from rnd. N is the
	// samplePolyCBD nonce, shared across all three samplings.
	var N byte
	r, e1 := make([]nttElement, k1024), make([]ringElement, k1024)
	for i := range r {
		r[i] = ntt(samplePolyCBD(rnd, N))
		N++
	}
	for i := range e1 {
		e1[i] = samplePolyCBD(rnd, N)
		N++
	}
	e2 := samplePolyCBD(rnd, N)

	// u = NTT⁻¹(Aᵀ ∘ r) + e1.
	u := make([]ringElement, k1024)
	for i := range u {
		u[i] = e1[i]
		for j := range r {
			// Indexing a[j*k1024+i] (rather than a[i*k1024+j]) reads A
			// transposed.
			u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.a[j*k1024+i], r[j])))
		}
	}

	// μ = Decompress₁(ByteDecode₁(m)).
	μ := ringDecodeAndDecompress1(m)

	// v = NTT⁻¹(tᵀ ∘ r) + e2 + μ.
	var vNTT nttElement
	for i := range ex.t {
		vNTT = polyAdd(vNTT, nttMul(ex.t[i], r[i]))
	}
	v := polyAdd(polyAdd(inverseNTT(vNTT), e2), μ)

	// c = ByteEncode₁₁(Compress₁₁(u)) || ByteEncode₅(Compress₅(v)).
	c := cc[:0]
	for _, f := range u {
		c = ringCompressAndEncode11(c, f)
	}
	c = ringCompressAndEncode5(c, v)

	return c
}
395
396
397
398
399
400 func (dk *DecapsulationKey1024) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) {
401 fipsSelfTest()
402 if len(ciphertext) != CiphertextSize1024 {
403 return nil, errors.New("mlkem: invalid ciphertext length")
404 }
405 c := (*[CiphertextSize1024]byte)(ciphertext)
406
407
408
409 return kemDecaps1024(dk, c), nil
410 }
411
412
413
414
// kemDecaps1024 recovers the shared key from the ciphertext c using dk,
// following ML-KEM.Decaps_internal from FIPS 203 with implicit rejection. If
// re-encrypting the decrypted message does not reproduce c exactly, the
// returned key is J(z || c) instead of the key derived from the message.
func kemDecaps1024(dk *DecapsulationKey1024, c *[CiphertextSize1024]byte) (K []byte) {
	fips140.RecordApproved()
	// m' = K-PKE.Decrypt(c).
	m := pkeDecrypt1024(&dk.decryptionKey1024, c)
	// (K', r') = G(m' || h), where G is SHA3-512.
	g := sha3.New512()
	g.Write(m[:])
	g.Write(dk.h[:])
	G := g.Sum(make([]byte, 0, 64))
	Kprime, r := G[:SharedKeySize], G[SharedKeySize:]
	// K̄ = J(z || c), where J is SHAKE256: the implicit-rejection key.
	J := sha3.NewShake256()
	J.Write(dk.z[:])
	J.Write(c[:])
	Kout := make([]byte, SharedKeySize)
	J.Read(Kout)
	// c' = K-PKE.Encrypt(m', r'): re-encrypt with the derived randomness.
	var cc [CiphertextSize1024]byte
	c1 := pkeEncrypt1024(&cc, &dk.encryptionKey1024, (*[32]byte)(m), r)

	// If c == c', overwrite the rejection key with K'. Both the comparison
	// and the copy are constant-time (ConstantTimeCopy(v, x, y) copies y into
	// x when v == 1, as in crypto/subtle). This avoids leaking through
	// timing whether c was rejected.
	subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime)
	return Kout
}
434
435
436
437
438
// pkeDecrypt1024 decrypts the ciphertext c with the expanded decryption key dx
// and returns the encoded message, following K-PKE.Decrypt from FIPS 203.
// kemDecaps1024 converts the result to *[32]byte.
//
// None of the decoding helpers it calls return an error, so every
// CiphertextSize1024-byte input decrypts to some message.
func pkeDecrypt1024(dx *decryptionKey1024, c *[CiphertextSize1024]byte) []byte {
	// u = Decompress₁₁(ByteDecode₁₁(c₁)): one encodingSize11-byte chunk per
	// polynomial at the start of c.
	u := make([]ringElement, k1024)
	for i := range u {
		b := (*[encodingSize11]byte)(c[encodingSize11*i : encodingSize11*(i+1)])
		u[i] = ringDecodeAndDecompress11(b)
	}

	// v = Decompress₅(ByteDecode₅(c₂)), taken from the bytes after c₁.
	b := (*[encodingSize5]byte)(c[encodingSize11*k1024:])
	v := ringDecodeAndDecompress5(b)

	// w = v − NTT⁻¹(sᵀ ∘ NTT(u)).
	var mask nttElement
	for i := range dx.s {
		mask = polyAdd(mask, nttMul(dx.s[i], ntt(u[i])))
	}
	w := polySub(v, inverseNTT(mask))

	// m = ByteEncode₁(Compress₁(w)).
	return ringCompressAndEncode1(nil, w)
}
457
View as plain text