1
2
3 package mlkem
4
5 import (
6 "crypto/internal/fips140"
7 "crypto/internal/fips140/drbg"
8 "crypto/internal/fips140/sha3"
9 "crypto/internal/fips140/subtle"
10 "errors"
11 )
12
13
14
15 type DecapsulationKey1024 struct {
16 d [32]byte
17 z [32]byte
18
19 ρ [32]byte
20 h [32]byte
21
22 encryptionKey1024
23 decryptionKey1024
24 }
25
26
27
28
29 func (dk *DecapsulationKey1024) Bytes() []byte {
30 var b [SeedSize]byte
31 copy(b[:], dk.d[:])
32 copy(b[32:], dk.z[:])
33 return b[:]
34 }
35
36
37
38 func (dk *DecapsulationKey1024) EncapsulationKey() *EncapsulationKey1024 {
39 return &EncapsulationKey1024{
40 ρ: dk.ρ,
41 h: dk.h,
42 encryptionKey1024: dk.encryptionKey1024,
43 }
44 }
45
46
47
48 type EncapsulationKey1024 struct {
49 ρ [32]byte
50 h [32]byte
51 encryptionKey1024
52 }
53
54
55 func (ek *EncapsulationKey1024) Bytes() []byte {
56
57 b := make([]byte, 0, EncapsulationKeySize1024)
58 return ek.bytes(b)
59 }
60
61 func (ek *EncapsulationKey1024) bytes(b []byte) []byte {
62 for i := range ek.t {
63 b = polyByteEncode(b, ek.t[i])
64 }
65 b = append(b, ek.ρ[:]...)
66 return b
67 }
68
69
// encryptionKey1024 is the parsed and expanded form of the public
// encryption key used by pkeEncrypt1024.
type encryptionKey1024 struct {
	// t holds k1024 polynomials in NTT form; kemKeyGen1024 computes them as
	// the matrix-vector product of a with the secret s, plus a noise vector.
	t [k1024]nttElement
	// a is the k1024×k1024 matrix stored row-major: entry i*k1024+j is
	// sampleNTT(ρ, j, i).
	a [k1024 * k1024]nttElement
}
74
75
// decryptionKey1024 is the expanded form of the secret decryption key
// used by pkeDecrypt1024.
type decryptionKey1024 struct {
	// s holds k1024 secret polynomials in NTT form, each produced by
	// ntt(samplePolyCBD(σ, N)) in kemKeyGen1024.
	s [k1024]nttElement
}
79
80
81
82 func GenerateKey1024() (*DecapsulationKey1024, error) {
83
84 dk := &DecapsulationKey1024{}
85 return generateKey1024(dk)
86 }
87
88 func generateKey1024(dk *DecapsulationKey1024) (*DecapsulationKey1024, error) {
89 var d [32]byte
90 drbg.Read(d[:])
91 var z [32]byte
92 drbg.Read(z[:])
93 kemKeyGen1024(dk, &d, &z)
94 if err := fips140.PCT("ML-KEM PCT", func() error { return kemPCT1024(dk) }); err != nil {
95
96 panic(err)
97 }
98 fips140.RecordApproved()
99 return dk, nil
100 }
101
102
103
104 func GenerateKeyInternal1024(d, z *[32]byte) *DecapsulationKey1024 {
105 dk := &DecapsulationKey1024{}
106 kemKeyGen1024(dk, d, z)
107 return dk
108 }
109
110
111
112 func NewDecapsulationKey1024(seed []byte) (*DecapsulationKey1024, error) {
113
114 dk := &DecapsulationKey1024{}
115 return newKeyFromSeed1024(dk, seed)
116 }
117
118 func newKeyFromSeed1024(dk *DecapsulationKey1024, seed []byte) (*DecapsulationKey1024, error) {
119 if len(seed) != SeedSize {
120 return nil, errors.New("mlkem: invalid seed length")
121 }
122 d := (*[32]byte)(seed[:32])
123 z := (*[32]byte)(seed[32:])
124 kemKeyGen1024(dk, d, z)
125 if err := fips140.PCT("ML-KEM PCT", func() error { return kemPCT1024(dk) }); err != nil {
126
127 panic(err)
128 }
129 fips140.RecordApproved()
130 return dk, nil
131 }
132
133
134
135
136
137
138 func kemKeyGen1024(dk *DecapsulationKey1024, d, z *[32]byte) {
139 dk.d = *d
140 dk.z = *z
141
142 g := sha3.New512()
143 g.Write(d[:])
144 g.Write([]byte{k1024})
145 G := g.Sum(make([]byte, 0, 64))
146 ρ, σ := G[:32], G[32:]
147 dk.ρ = [32]byte(ρ)
148
149 A := &dk.a
150 for i := byte(0); i < k1024; i++ {
151 for j := byte(0); j < k1024; j++ {
152 A[i*k1024+j] = sampleNTT(ρ, j, i)
153 }
154 }
155
156 var N byte
157 s := &dk.s
158 for i := range s {
159 s[i] = ntt(samplePolyCBD(σ, N))
160 N++
161 }
162 e := make([]nttElement, k1024)
163 for i := range e {
164 e[i] = ntt(samplePolyCBD(σ, N))
165 N++
166 }
167
168 t := &dk.t
169 for i := range t {
170 t[i] = e[i]
171 for j := range s {
172 t[i] = polyAdd(t[i], nttMul(A[i*k1024+j], s[j]))
173 }
174 }
175
176 H := sha3.New256()
177 ek := dk.EncapsulationKey().Bytes()
178 H.Write(ek)
179 H.Sum(dk.h[:0])
180 }
181
182
183
184
185
186
187
188
189
190 func kemPCT1024(dk *DecapsulationKey1024) error {
191 ek := dk.EncapsulationKey()
192 c, K := ek.Encapsulate()
193 K1, err := dk.Decapsulate(c)
194 if err != nil {
195 return err
196 }
197 if subtle.ConstantTimeCompare(K, K1) != 1 {
198 return errors.New("mlkem: PCT failed")
199 }
200 return nil
201 }
202
203
204
205
206
207 func (ek *EncapsulationKey1024) Encapsulate() (ciphertext, sharedKey []byte) {
208
209 var cc [CiphertextSize1024]byte
210 return ek.encapsulate(&cc)
211 }
212
213 func (ek *EncapsulationKey1024) encapsulate(cc *[CiphertextSize1024]byte) (ciphertext, sharedKey []byte) {
214 var m [messageSize]byte
215 drbg.Read(m[:])
216
217
218 fips140.RecordApproved()
219 return kemEncaps1024(cc, ek, &m)
220 }
221
222
223
224 func (ek *EncapsulationKey1024) EncapsulateInternal(m *[32]byte) (ciphertext, sharedKey []byte) {
225 cc := &[CiphertextSize1024]byte{}
226 return kemEncaps1024(cc, ek, m)
227 }
228
229
230
231
232 func kemEncaps1024(cc *[CiphertextSize1024]byte, ek *EncapsulationKey1024, m *[messageSize]byte) (c, K []byte) {
233 g := sha3.New512()
234 g.Write(m[:])
235 g.Write(ek.h[:])
236 G := g.Sum(nil)
237 K, r := G[:SharedKeySize], G[SharedKeySize:]
238 c = pkeEncrypt1024(cc, &ek.encryptionKey1024, m, r)
239 return c, K
240 }
241
242
243
244 func NewEncapsulationKey1024(encapsulationKey []byte) (*EncapsulationKey1024, error) {
245
246 ek := &EncapsulationKey1024{}
247 return parseEK1024(ek, encapsulationKey)
248 }
249
250
251
252
253
254 func parseEK1024(ek *EncapsulationKey1024, ekPKE []byte) (*EncapsulationKey1024, error) {
255 if len(ekPKE) != EncapsulationKeySize1024 {
256 return nil, errors.New("mlkem: invalid encapsulation key length")
257 }
258
259 h := sha3.New256()
260 h.Write(ekPKE)
261 h.Sum(ek.h[:0])
262
263 for i := range ek.t {
264 var err error
265 ek.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12])
266 if err != nil {
267 return nil, err
268 }
269 ekPKE = ekPKE[encodingSize12:]
270 }
271 copy(ek.ρ[:], ekPKE)
272
273 for i := byte(0); i < k1024; i++ {
274 for j := byte(0); j < k1024; j++ {
275 ek.a[i*k1024+j] = sampleNTT(ek.ρ[:], j, i)
276 }
277 }
278
279 return ek, nil
280 }
281
282
283
284
285
286 func pkeEncrypt1024(cc *[CiphertextSize1024]byte, ex *encryptionKey1024, m *[messageSize]byte, rnd []byte) []byte {
287 var N byte
288 r, e1 := make([]nttElement, k1024), make([]ringElement, k1024)
289 for i := range r {
290 r[i] = ntt(samplePolyCBD(rnd, N))
291 N++
292 }
293 for i := range e1 {
294 e1[i] = samplePolyCBD(rnd, N)
295 N++
296 }
297 e2 := samplePolyCBD(rnd, N)
298
299 u := make([]ringElement, k1024)
300 for i := range u {
301 u[i] = e1[i]
302 for j := range r {
303
304 u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.a[j*k1024+i], r[j])))
305 }
306 }
307
308 μ := ringDecodeAndDecompress1(m)
309
310 var vNTT nttElement
311 for i := range ex.t {
312 vNTT = polyAdd(vNTT, nttMul(ex.t[i], r[i]))
313 }
314 v := polyAdd(polyAdd(inverseNTT(vNTT), e2), μ)
315
316 c := cc[:0]
317 for _, f := range u {
318 c = ringCompressAndEncode11(c, f)
319 }
320 c = ringCompressAndEncode5(c, v)
321
322 return c
323 }
324
325
326
327
328
329 func (dk *DecapsulationKey1024) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) {
330 if len(ciphertext) != CiphertextSize1024 {
331 return nil, errors.New("mlkem: invalid ciphertext length")
332 }
333 c := (*[CiphertextSize1024]byte)(ciphertext)
334
335
336
337 return kemDecaps1024(dk, c), nil
338 }
339
340
341
342
343 func kemDecaps1024(dk *DecapsulationKey1024, c *[CiphertextSize1024]byte) (K []byte) {
344 fips140.RecordApproved()
345 m := pkeDecrypt1024(&dk.decryptionKey1024, c)
346 g := sha3.New512()
347 g.Write(m[:])
348 g.Write(dk.h[:])
349 G := g.Sum(make([]byte, 0, 64))
350 Kprime, r := G[:SharedKeySize], G[SharedKeySize:]
351 J := sha3.NewShake256()
352 J.Write(dk.z[:])
353 J.Write(c[:])
354 Kout := make([]byte, SharedKeySize)
355 J.Read(Kout)
356 var cc [CiphertextSize1024]byte
357 c1 := pkeEncrypt1024(&cc, &dk.encryptionKey1024, (*[32]byte)(m), r)
358
359 subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime)
360 return Kout
361 }
362
363
364
365
366
367 func pkeDecrypt1024(dx *decryptionKey1024, c *[CiphertextSize1024]byte) []byte {
368 u := make([]ringElement, k1024)
369 for i := range u {
370 b := (*[encodingSize11]byte)(c[encodingSize11*i : encodingSize11*(i+1)])
371 u[i] = ringDecodeAndDecompress11(b)
372 }
373
374 b := (*[encodingSize5]byte)(c[encodingSize11*k1024:])
375 v := ringDecodeAndDecompress5(b)
376
377 var mask nttElement
378 for i := range dx.s {
379 mask = polyAdd(mask, nttMul(dx.s[i], ntt(u[i])))
380 }
381 w := polySub(v, inverseNTT(mask))
382
383 return ringCompressAndEncode1(nil, w)
384 }
385
View as plain text