1
2
3
4
5
6
7
8
9 package mlkem
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26 import (
27 "crypto/internal/fips140"
28 "crypto/internal/fips140/drbg"
29 "crypto/internal/fips140/sha3"
30 "crypto/internal/fips140/subtle"
31 "errors"
32 )
33
// Ring parameters shared by every ML-KEM parameter set.
const (
	n = 256  // number of coefficients in a ring element
	q = 3329 // coefficient modulus
)

// encodingSizeD is the size in bytes of a ring element whose n coefficients
// are each encoded with D bits.
const (
	encodingSize12 = n * 12 / 8
	encodingSize11 = n * 11 / 8
	encodingSize10 = n * 10 / 8
	encodingSize5  = n * 5 / 8
	encodingSize4  = n * 4 / 8
	encodingSize1  = n * 1 / 8

	// messageSize is the size of a K-PKE plaintext: n coefficients at one
	// bit each.
	messageSize = encodingSize1
)

// Sizes that do not depend on the parameter set.
const (
	// SharedKeySize is the size in bytes of an encapsulated shared key.
	SharedKeySize = 32
	// SeedSize is the size in bytes of a decapsulation key seed, d ‖ z.
	SeedSize = 32 + 32
)
53
54
// ML-KEM-768 parameter set.
const (
	// k is the module rank of ML-KEM-768: vectors have k polynomials and
	// the matrix A is k×k.
	k = 3

	// CiphertextSize768 is the size in bytes of a ciphertext: k polynomials
	// encoded at 10 bits per coefficient (u) followed by one polynomial at
	// 4 bits per coefficient (v).
	CiphertextSize768 = k*encodingSize10 + encodingSize4
	// EncapsulationKeySize768 is the size in bytes of an encoded
	// encapsulation key: k polynomials at 12 bits per coefficient (t)
	// followed by the 32-byte seed ρ.
	EncapsulationKeySize768 = k*encodingSize12 + 32
)
61
62
// ML-KEM-1024 parameter set.
const (
	// k1024 is the module rank of ML-KEM-1024.
	k1024 = 4

	// CiphertextSize1024 is the size in bytes of an ML-KEM-1024 ciphertext:
	// k1024 polynomials at 11 bits per coefficient followed by one
	// polynomial at 5 bits per coefficient.
	CiphertextSize1024 = k1024*encodingSize11 + encodingSize5
	// EncapsulationKeySize1024 is the size in bytes of an encoded
	// ML-KEM-1024 encapsulation key: k1024 polynomials at 12 bits per
	// coefficient followed by a 32-byte seed.
	EncapsulationKeySize1024 = k1024*encodingSize12 + 32
)
69
70
71
// A DecapsulationKey768 is the secret ML-KEM-768 key used to recover a shared
// key from a ciphertext. Besides the two seeds it was derived from, it keeps
// the values precomputed by key generation so decapsulation can use them
// directly.
type DecapsulationKey768 struct {
	d [32]byte // key-generation seed; first half of Bytes()
	z [32]byte // implicit-rejection seed; second half of Bytes()

	ρ [32]byte // seed of the matrix A, copied into the encapsulation key
	h [32]byte // H(ek): SHA3-256 of the encoded encapsulation key

	encryptionKey // t and the expanded matrix A
	decryptionKey // secret vector s
}
82
83
84
85
86 func (dk *DecapsulationKey768) Bytes() []byte {
87 var b [SeedSize]byte
88 copy(b[:], dk.d[:])
89 copy(b[32:], dk.z[:])
90 return b[:]
91 }
92
93
94
95 func (dk *DecapsulationKey768) EncapsulationKey() *EncapsulationKey768 {
96 return &EncapsulationKey768{
97 ρ: dk.ρ,
98 h: dk.h,
99 encryptionKey: dk.encryptionKey,
100 }
101 }
102
103
104
// An EncapsulationKey768 is the ML-KEM-768 key used to produce ciphertexts
// that the matching DecapsulationKey768 can decapsulate.
type EncapsulationKey768 struct {
	ρ [32]byte // seed of the matrix A; the last 32 bytes of the encoding
	h [32]byte // H(ek): SHA3-256 of the encoded key
	encryptionKey
}
110
111
112 func (ek *EncapsulationKey768) Bytes() []byte {
113
114 b := make([]byte, 0, EncapsulationKeySize768)
115 return ek.bytes(b)
116 }
117
118 func (ek *EncapsulationKey768) bytes(b []byte) []byte {
119 for i := range ek.t {
120 b = polyByteEncode(b, ek.t[i])
121 }
122 b = append(b, ek.ρ[:]...)
123 return b
124 }
125
126
// encryptionKey is the parsed and expanded form of a K-PKE encryption key.
type encryptionKey struct {
	t [k]nttElement     // public vector t, in the NTT domain
	a [k * k]nttElement // matrix A, row-major: a[i*k+j] = sampleNTT(ρ, j, i)
}
131
132
// decryptionKey is the expanded form of a K-PKE decryption key.
type decryptionKey struct {
	s [k]nttElement // secret vector s, in the NTT domain
}
136
137
138
139 func GenerateKey768() (*DecapsulationKey768, error) {
140
141 dk := &DecapsulationKey768{}
142 return generateKey(dk)
143 }
144
145 func generateKey(dk *DecapsulationKey768) (*DecapsulationKey768, error) {
146 var d [32]byte
147 drbg.Read(d[:])
148 var z [32]byte
149 drbg.Read(z[:])
150 kemKeyGen(dk, &d, &z)
151 if err := fips140.PCT("ML-KEM PCT", func() error { return kemPCT(dk) }); err != nil {
152
153 panic(err)
154 }
155 fips140.RecordApproved()
156 return dk, nil
157 }
158
159
160
161 func GenerateKeyInternal768(d, z *[32]byte) *DecapsulationKey768 {
162 dk := &DecapsulationKey768{}
163 kemKeyGen(dk, d, z)
164 return dk
165 }
166
167
168
169 func NewDecapsulationKey768(seed []byte) (*DecapsulationKey768, error) {
170
171 dk := &DecapsulationKey768{}
172 return newKeyFromSeed(dk, seed)
173 }
174
175 func newKeyFromSeed(dk *DecapsulationKey768, seed []byte) (*DecapsulationKey768, error) {
176 if len(seed) != SeedSize {
177 return nil, errors.New("mlkem: invalid seed length")
178 }
179 d := (*[32]byte)(seed[:32])
180 z := (*[32]byte)(seed[32:])
181 kemKeyGen(dk, d, z)
182 if err := fips140.PCT("ML-KEM PCT", func() error { return kemPCT(dk) }); err != nil {
183
184 panic(err)
185 }
186 fips140.RecordApproved()
187 return dk, nil
188 }
189
190
191
192
193
194
// kemKeyGen deterministically derives a decapsulation key from the seeds d
// and z and stores every component in dk.
//
// It corresponds to ML-KEM.KeyGen_internal (FIPS 203, Algorithm 16) merged
// with K-PKE.KeyGen (FIPS 203, Algorithm 13). The PKE key material is written
// directly into dk's embedded encryptionKey and decryptionKey rather than
// being encoded and re-parsed.
func kemKeyGen(dk *DecapsulationKey768, d, z *[32]byte) {
	dk.d = *d
	dk.z = *z

	// (ρ, σ) = G(d ‖ k) using SHA3-512; the module rank k is hashed in as a
	// one-byte domain separator.
	g := sha3.New512()
	g.Write(d[:])
	g.Write([]byte{k})
	G := g.Sum(make([]byte, 0, 64))
	ρ, σ := G[:32], G[32:]
	dk.ρ = [32]byte(ρ)

	// Expand the public matrix A from ρ, row-major. Note that sampleNTT
	// receives the column index j before the row index i.
	A := &dk.a
	for i := byte(0); i < k; i++ {
		for j := byte(0); j < k; j++ {
			A[i*k+j] = sampleNTT(ρ, j, i)
		}
	}

	// Sample the secret s and the error e from σ. N is a running nonce
	// shared by both loops (s uses 0..k-1, e uses k..2k-1), so the loop
	// order is significant.
	var N byte
	s := &dk.s
	for i := range s {
		s[i] = ntt(samplePolyCBD(σ, N))
		N++
	}
	e := make([]nttElement, k)
	for i := range e {
		e[i] = ntt(samplePolyCBD(σ, N))
		N++
	}

	// t = A ◦ s + e, computed in the NTT domain.
	t := &dk.t
	for i := range t {
		t[i] = e[i]
		for j := range s {
			t[i] = polyAdd(t[i], nttMul(A[i*k+j], s[j]))
		}
	}

	// h = H(ek): SHA3-256 of the encoded encapsulation key, cached for use by
	// encapsulation (ek.h) and decapsulation (dk.h).
	H := sha3.New256()
	ek := dk.EncapsulationKey().Bytes()
	H.Write(ek)
	H.Sum(dk.h[:0])
}
238
239
240
241
242
243
244
245
246
247 func kemPCT(dk *DecapsulationKey768) error {
248 ek := dk.EncapsulationKey()
249 c, K := ek.Encapsulate()
250 K1, err := dk.Decapsulate(c)
251 if err != nil {
252 return err
253 }
254 if subtle.ConstantTimeCompare(K, K1) != 1 {
255 return errors.New("mlkem: PCT failed")
256 }
257 return nil
258 }
259
260
261
262
263
264 func (ek *EncapsulationKey768) Encapsulate() (ciphertext, sharedKey []byte) {
265
266 var cc [CiphertextSize768]byte
267 return ek.encapsulate(&cc)
268 }
269
270 func (ek *EncapsulationKey768) encapsulate(cc *[CiphertextSize768]byte) (ciphertext, sharedKey []byte) {
271 var m [messageSize]byte
272 drbg.Read(m[:])
273
274
275 fips140.RecordApproved()
276 return kemEncaps(cc, ek, &m)
277 }
278
279
280
281 func (ek *EncapsulationKey768) EncapsulateInternal(m *[32]byte) (ciphertext, sharedKey []byte) {
282 cc := &[CiphertextSize768]byte{}
283 return kemEncaps(cc, ek, m)
284 }
285
286
287
288
289 func kemEncaps(cc *[CiphertextSize768]byte, ek *EncapsulationKey768, m *[messageSize]byte) (c, K []byte) {
290 g := sha3.New512()
291 g.Write(m[:])
292 g.Write(ek.h[:])
293 G := g.Sum(nil)
294 K, r := G[:SharedKeySize], G[SharedKeySize:]
295 c = pkeEncrypt(cc, &ek.encryptionKey, m, r)
296 return c, K
297 }
298
299
300
301 func NewEncapsulationKey768(encapsulationKey []byte) (*EncapsulationKey768, error) {
302
303 ek := &EncapsulationKey768{}
304 return parseEK(ek, encapsulationKey)
305 }
306
307
308
309
310
311 func parseEK(ek *EncapsulationKey768, ekPKE []byte) (*EncapsulationKey768, error) {
312 if len(ekPKE) != EncapsulationKeySize768 {
313 return nil, errors.New("mlkem: invalid encapsulation key length")
314 }
315
316 h := sha3.New256()
317 h.Write(ekPKE)
318 h.Sum(ek.h[:0])
319
320 for i := range ek.t {
321 var err error
322 ek.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12])
323 if err != nil {
324 return nil, err
325 }
326 ekPKE = ekPKE[encodingSize12:]
327 }
328 copy(ek.ρ[:], ekPKE)
329
330 for i := byte(0); i < k; i++ {
331 for j := byte(0); j < k; j++ {
332 ek.a[i*k+j] = sampleNTT(ek.ρ[:], j, i)
333 }
334 }
335
336 return ek, nil
337 }
338
339
340
341
342
// pkeEncrypt encrypts the 32-byte message m under ex using the randomness
// seed rnd. It encodes the ciphertext into cc and returns a slice aliasing cc.
//
// It corresponds to K-PKE.Encrypt (FIPS 203, Algorithm 14). ex.a must already
// hold the expanded matrix A, as populated by kemKeyGen or parseEK.
func pkeEncrypt(cc *[CiphertextSize768]byte, ex *encryptionKey, m *[messageSize]byte, rnd []byte) []byte {
	// Sample r (NTT domain), e1 and e2 from rnd with a running nonce N:
	// r uses 0..k-1, e1 uses k..2k-1 and e2 uses 2k, so order is significant.
	var N byte
	r, e1 := make([]nttElement, k), make([]ringElement, k)
	for i := range r {
		r[i] = ntt(samplePolyCBD(rnd, N))
		N++
	}
	for i := range e1 {
		e1[i] = samplePolyCBD(rnd, N)
		N++
	}
	e2 := samplePolyCBD(rnd, N)

	// u = NTT⁻¹(Aᵀ ◦ r) + e1
	u := make([]ringElement, k)
	for i := range u {
		u[i] = e1[i]
		for j := range r {
			// Indices are swapped (a[j*k+i]) to multiply by the transpose of A.
			u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.a[j*k+i], r[j])))
		}
	}

	// μ is the message m decoded and decompressed at one bit per coefficient.
	μ := ringDecodeAndDecompress1(m)

	// v = NTT⁻¹(tᵀ ◦ r) + e2 + μ
	var vNTT nttElement
	for i := range ex.t {
		vNTT = polyAdd(vNTT, nttMul(ex.t[i], r[i]))
	}
	v := polyAdd(polyAdd(inverseNTT(vNTT), e2), μ)

	// c = Compress₁₀(u) ‖ Compress₄(v), appended into cc's backing array.
	c := cc[:0]
	for _, f := range u {
		c = ringCompressAndEncode10(c, f)
	}
	c = ringCompressAndEncode4(c, v)

	return c
}
381
382
383
384
385
386 func (dk *DecapsulationKey768) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) {
387 if len(ciphertext) != CiphertextSize768 {
388 return nil, errors.New("mlkem: invalid ciphertext length")
389 }
390 c := (*[CiphertextSize768]byte)(ciphertext)
391
392
393
394 return kemDecaps(dk, c), nil
395 }
396
397
398
399
400 func kemDecaps(dk *DecapsulationKey768, c *[CiphertextSize768]byte) (K []byte) {
401 fips140.RecordApproved()
402 m := pkeDecrypt(&dk.decryptionKey, c)
403 g := sha3.New512()
404 g.Write(m[:])
405 g.Write(dk.h[:])
406 G := g.Sum(make([]byte, 0, 64))
407 Kprime, r := G[:SharedKeySize], G[SharedKeySize:]
408 J := sha3.NewShake256()
409 J.Write(dk.z[:])
410 J.Write(c[:])
411 Kout := make([]byte, SharedKeySize)
412 J.Read(Kout)
413 var cc [CiphertextSize768]byte
414 c1 := pkeEncrypt(&cc, &dk.encryptionKey, (*[32]byte)(m), r)
415
416 subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime)
417 return Kout
418 }
419
420
421
422
423
// pkeDecrypt recovers the message encrypted in c under dx. The result is
// presumably messageSize (32) bytes; kemDecaps converts it to a *[32]byte.
//
// It corresponds to K-PKE.Decrypt (FIPS 203, Algorithm 15). The ciphertext
// layout is k polynomials compressed to 10 bits per coefficient (u),
// followed by one polynomial compressed to 4 bits per coefficient (v).
func pkeDecrypt(dx *decryptionKey, c *[CiphertextSize768]byte) []byte {
	// u[i] is decoded from the i-th encodingSize10-byte chunk of c.
	u := make([]ringElement, k)
	for i := range u {
		b := (*[encodingSize10]byte)(c[encodingSize10*i : encodingSize10*(i+1)])
		u[i] = ringDecodeAndDecompress10(b)
	}

	// The remaining encodingSize4 bytes hold v.
	b := (*[encodingSize4]byte)(c[encodingSize10*k:])
	v := ringDecodeAndDecompress4(b)

	// mask = sᵀ ◦ NTT(u), accumulated in the NTT domain.
	var mask nttElement
	for i := range dx.s {
		mask = polyAdd(mask, nttMul(dx.s[i], ntt(u[i])))
	}
	w := polySub(v, inverseNTT(mask))

	// Compressing w to one bit per coefficient yields the message bytes.
	return ringCompressAndEncode1(nil, w)
}
442
View as plain text