1
2
3
4
5 package hpke
6
7 import (
8 "bytes"
9 "crypto"
10 "crypto/ecdh"
11 "crypto/mlkem"
12 "crypto/rand"
13 "crypto/sha3"
14 "errors"
15 "internal/byteorder"
16 )
17
18 var mlkem768X25519 = &hybridKEM{
19 id: 0x647a,
20 label: `\./` +
21 `/^\`,
22 curve: ecdh.X25519(),
23
24 curveSeedSize: 32,
25 curvePointSize: 32,
26 pqEncapsKeySize: mlkem.EncapsulationKeySize768,
27 pqCiphertextSize: mlkem.CiphertextSize768,
28
29 pqNewPublicKey: func(data []byte) (crypto.Encapsulator, error) {
30 return mlkem.NewEncapsulationKey768(data)
31 },
32 pqNewPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
33 return mlkem.NewDecapsulationKey768(data)
34 },
35 pqGenerateKey: func() (crypto.Decapsulator, error) {
36 return mlkem.GenerateKey768()
37 },
38 }
39
40
41
42 func MLKEM768X25519() KEM {
43 return mlkem768X25519
44 }
45
46 var mlkem768P256 = &hybridKEM{
47 id: 0x0050,
48 label: "MLKEM768-P256",
49 curve: ecdh.P256(),
50
51 curveSeedSize: 32,
52 curvePointSize: 65,
53 pqEncapsKeySize: mlkem.EncapsulationKeySize768,
54 pqCiphertextSize: mlkem.CiphertextSize768,
55
56 pqNewPublicKey: func(data []byte) (crypto.Encapsulator, error) {
57 return mlkem.NewEncapsulationKey768(data)
58 },
59 pqNewPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
60 return mlkem.NewDecapsulationKey768(data)
61 },
62 pqGenerateKey: func() (crypto.Decapsulator, error) {
63 return mlkem.GenerateKey768()
64 },
65 }
66
67
68 func MLKEM768P256() KEM {
69 return mlkem768P256
70 }
71
72 var mlkem1024P384 = &hybridKEM{
73 id: 0x0051,
74 label: "MLKEM1024-P384",
75 curve: ecdh.P384(),
76
77 curveSeedSize: 48,
78 curvePointSize: 97,
79 pqEncapsKeySize: mlkem.EncapsulationKeySize1024,
80 pqCiphertextSize: mlkem.CiphertextSize1024,
81
82 pqNewPublicKey: func(data []byte) (crypto.Encapsulator, error) {
83 return mlkem.NewEncapsulationKey1024(data)
84 },
85 pqNewPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
86 return mlkem.NewDecapsulationKey1024(data)
87 },
88 pqGenerateKey: func() (crypto.Decapsulator, error) {
89 return mlkem.GenerateKey1024()
90 },
91 }
92
93
94 func MLKEM1024P384() KEM {
95 return mlkem1024P384
96 }
97
98 type hybridKEM struct {
99 id uint16
100 label string
101 curve ecdh.Curve
102
103 curveSeedSize int
104 curvePointSize int
105 pqEncapsKeySize int
106 pqCiphertextSize int
107
108 pqNewPublicKey func(data []byte) (crypto.Encapsulator, error)
109 pqNewPrivateKey func(data []byte) (crypto.Decapsulator, error)
110 pqGenerateKey func() (crypto.Decapsulator, error)
111 }
112
113 func (kem *hybridKEM) ID() uint16 {
114 return kem.id
115 }
116
117 func (kem *hybridKEM) encSize() int {
118 return kem.pqCiphertextSize + kem.curvePointSize
119 }
120
121 func (kem *hybridKEM) sharedSecret(ssPQ, ssT, ctT, ekT []byte) []byte {
122 h := sha3.New256()
123 h.Write(ssPQ)
124 h.Write(ssT)
125 h.Write(ctT)
126 h.Write(ekT)
127 h.Write([]byte(kem.label))
128 return h.Sum(nil)
129 }
130
// hybridPublicKey is a hybrid KEM public key: an ECDH public key (t) and a PQ
// encapsulation key (pq), tagged with the hybridKEM they were validated for.
// Field order matters: it is constructed with positional composite literals.
type hybridPublicKey struct {
	kem *hybridKEM
	t   *ecdh.PublicKey
	pq  crypto.Encapsulator
}
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150 func NewHybridPublicKey(pq crypto.Encapsulator, t *ecdh.PublicKey) (PublicKey, error) {
151 switch t.Curve() {
152 case ecdh.X25519():
153 if _, ok := pq.(*mlkem.EncapsulationKey768); !ok {
154 return nil, errors.New("invalid PQ KEM for X25519 hybrid")
155 }
156 return &hybridPublicKey{mlkem768X25519, t, pq}, nil
157 case ecdh.P256():
158 if _, ok := pq.(*mlkem.EncapsulationKey768); !ok {
159 return nil, errors.New("invalid PQ KEM for P-256 hybrid")
160 }
161 return &hybridPublicKey{mlkem768P256, t, pq}, nil
162 case ecdh.P384():
163 if _, ok := pq.(*mlkem.EncapsulationKey1024); !ok {
164 return nil, errors.New("invalid PQ KEM for P-384 hybrid")
165 }
166 return &hybridPublicKey{mlkem1024P384, t, pq}, nil
167 default:
168 return nil, errors.New("unsupported curve")
169 }
170 }
171
172 func (kem *hybridKEM) NewPublicKey(data []byte) (PublicKey, error) {
173 if len(data) != kem.pqEncapsKeySize+kem.curvePointSize {
174 return nil, errors.New("invalid public key size")
175 }
176 pq, err := kem.pqNewPublicKey(data[:kem.pqEncapsKeySize])
177 if err != nil {
178 return nil, err
179 }
180 k, err := kem.curve.NewPublicKey(data[kem.pqEncapsKeySize:])
181 if err != nil {
182 return nil, err
183 }
184 return NewHybridPublicKey(pq, k)
185 }
186
187 func (pk *hybridPublicKey) KEM() KEM {
188 return pk.kem
189 }
190
191 func (pk *hybridPublicKey) Bytes() []byte {
192 return append(pk.pq.Bytes(), pk.t.Bytes()...)
193 }
194
195 var testingOnlyEncapsulate func() (ss, ct []byte)
196
197 func (pk *hybridPublicKey) encap() (sharedSecret []byte, encapPub []byte, err error) {
198 skE, err := pk.t.Curve().GenerateKey(rand.Reader)
199 if err != nil {
200 return nil, nil, err
201 }
202 if testingOnlyGenerateKey != nil {
203 skE = testingOnlyGenerateKey()
204 }
205 ssT, err := skE.ECDH(pk.t)
206 if err != nil {
207 return nil, nil, err
208 }
209 ctT := skE.PublicKey().Bytes()
210
211 ssPQ, ctPQ := pk.pq.Encapsulate()
212 if testingOnlyEncapsulate != nil {
213 ssPQ, ctPQ = testingOnlyEncapsulate()
214 }
215
216 ss := pk.kem.sharedSecret(ssPQ, ssT, ctT, pk.t.Bytes())
217 ct := append(ctPQ, ctT...)
218 return ss, ct, nil
219 }
220
// hybridPrivateKey is a hybrid KEM private key made of an ECDH key exchanger
// (t) and a PQ decapsulator (pq), tagged with its hybridKEM. seed holds the
// 32-byte secret the key was expanded from. It is nil when the key was built
// from existing component keys through NewHybridPrivateKey. Field order
// matters: it is constructed with positional composite literals.
type hybridPrivateKey struct {
	kem  *hybridKEM
	seed []byte
	t    ecdh.KeyExchanger
	pq   crypto.Decapsulator
}
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243 func NewHybridPrivateKey(pq crypto.Decapsulator, t ecdh.KeyExchanger) (PrivateKey, error) {
244 return newHybridPrivateKey(pq, t, nil)
245 }
246
247 func (kem *hybridKEM) GenerateKey() (PrivateKey, error) {
248 seed := make([]byte, 32)
249 rand.Read(seed)
250 return kem.NewPrivateKey(seed)
251 }
252
253 func (kem *hybridKEM) NewPrivateKey(priv []byte) (PrivateKey, error) {
254 if len(priv) != 32 {
255 return nil, errors.New("hpke: invalid hybrid KEM secret length")
256 }
257
258 s := sha3.NewSHAKE256()
259 s.Write(priv)
260
261 seedPQ := make([]byte, mlkem.SeedSize)
262 s.Read(seedPQ)
263 pq, err := kem.pqNewPrivateKey(seedPQ)
264 if err != nil {
265 return nil, err
266 }
267
268 seedT := make([]byte, kem.curveSeedSize)
269 for {
270 s.Read(seedT)
271 k, err := kem.curve.NewPrivateKey(seedT)
272 if err != nil {
273 continue
274 }
275 return newHybridPrivateKey(pq, k, priv)
276 }
277 }
278
279 func newHybridPrivateKey(pq crypto.Decapsulator, t ecdh.KeyExchanger, seed []byte) (PrivateKey, error) {
280 switch t.Curve() {
281 case ecdh.X25519():
282 if _, ok := pq.Encapsulator().(*mlkem.EncapsulationKey768); !ok {
283 return nil, errors.New("invalid PQ KEM for X25519 hybrid")
284 }
285 return &hybridPrivateKey{mlkem768X25519, bytes.Clone(seed), t, pq}, nil
286 case ecdh.P256():
287 if _, ok := pq.Encapsulator().(*mlkem.EncapsulationKey768); !ok {
288 return nil, errors.New("invalid PQ KEM for P-256 hybrid")
289 }
290 return &hybridPrivateKey{mlkem768P256, bytes.Clone(seed), t, pq}, nil
291 case ecdh.P384():
292 if _, ok := pq.Encapsulator().(*mlkem.EncapsulationKey1024); !ok {
293 return nil, errors.New("invalid PQ KEM for P-384 hybrid")
294 }
295 return &hybridPrivateKey{mlkem1024P384, bytes.Clone(seed), t, pq}, nil
296 default:
297 return nil, errors.New("unsupported curve")
298 }
299 }
300
301 func (kem *hybridKEM) DeriveKeyPair(ikm []byte) (PrivateKey, error) {
302 suiteID := byteorder.BEAppendUint16([]byte("KEM"), kem.id)
303 dk, err := SHAKE256().labeledDerive(suiteID, ikm, "DeriveKeyPair", nil, 32)
304 if err != nil {
305 return nil, err
306 }
307 return kem.NewPrivateKey(dk)
308 }
309
310 func (k *hybridPrivateKey) KEM() KEM {
311 return k.kem
312 }
313
314 func (k *hybridPrivateKey) Bytes() ([]byte, error) {
315 if k.seed == nil {
316 return nil, errors.New("private key seed not available")
317 }
318 return k.seed, nil
319 }
320
321 func (k *hybridPrivateKey) PublicKey() PublicKey {
322 return &hybridPublicKey{
323 kem: k.kem,
324 t: k.t.PublicKey(),
325 pq: k.pq.Encapsulator(),
326 }
327 }
328
329 func (k *hybridPrivateKey) decap(enc []byte) ([]byte, error) {
330 if len(enc) != k.kem.pqCiphertextSize+k.kem.curvePointSize {
331 return nil, errors.New("invalid encapsulated key size")
332 }
333 ctPQ, ctT := enc[:k.kem.pqCiphertextSize], enc[k.kem.pqCiphertextSize:]
334 ssPQ, err := k.pq.Decapsulate(ctPQ)
335 if err != nil {
336 return nil, err
337 }
338 pub, err := k.t.Curve().NewPublicKey(ctT)
339 if err != nil {
340 return nil, err
341 }
342 ssT, err := k.t.ECDH(pub)
343 if err != nil {
344 return nil, err
345 }
346 ss := k.kem.sharedSecret(ssPQ, ssT, ctT, k.t.PublicKey().Bytes())
347 return ss, nil
348 }
349
350 var mlkem768 = &mlkemKEM{
351 id: 0x0041,
352 ciphertextSize: mlkem.CiphertextSize768,
353 newPublicKey: func(data []byte) (crypto.Encapsulator, error) {
354 return mlkem.NewEncapsulationKey768(data)
355 },
356 newPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
357 return mlkem.NewDecapsulationKey768(data)
358 },
359 generateKey: func() (crypto.Decapsulator, error) {
360 return mlkem.GenerateKey768()
361 },
362 }
363
364
365 func MLKEM768() KEM {
366 return mlkem768
367 }
368
369 var mlkem1024 = &mlkemKEM{
370 id: 0x0042,
371 ciphertextSize: mlkem.CiphertextSize1024,
372 newPublicKey: func(data []byte) (crypto.Encapsulator, error) {
373 return mlkem.NewEncapsulationKey1024(data)
374 },
375 newPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
376 return mlkem.NewDecapsulationKey1024(data)
377 },
378 generateKey: func() (crypto.Decapsulator, error) {
379 return mlkem.GenerateKey1024()
380 },
381 }
382
383
384 func MLKEM1024() KEM {
385 return mlkem1024
386 }
387
// mlkemKEM describes a standalone ML-KEM parameter set used as an HPKE KEM.
type mlkemKEM struct {
	id             uint16 // HPKE KEM identifier reported by ID
	ciphertextSize int    // encapsulated key length in bytes (see encSize)
	// Constructors for the underlying ML-KEM keys.
	newPublicKey  func(data []byte) (crypto.Encapsulator, error)
	newPrivateKey func(data []byte) (crypto.Decapsulator, error)
	generateKey   func() (crypto.Decapsulator, error)
}
395
396 func (kem *mlkemKEM) ID() uint16 {
397 return kem.id
398 }
399
400 func (kem *mlkemKEM) encSize() int {
401 return kem.ciphertextSize
402 }
403
// mlkemPublicKey is a standalone ML-KEM public key (pq) tagged with the
// mlkemKEM it belongs to. Field order matters: it is constructed with
// positional composite literals.
type mlkemPublicKey struct {
	kem *mlkemKEM
	pq  crypto.Encapsulator
}
408
409
410
411
412
413
414
415
416
417
418
419
420 func NewMLKEMPublicKey(pub crypto.Encapsulator) (PublicKey, error) {
421 switch pub.(type) {
422 case *mlkem.EncapsulationKey768:
423 return &mlkemPublicKey{mlkem768, pub}, nil
424 case *mlkem.EncapsulationKey1024:
425 return &mlkemPublicKey{mlkem1024, pub}, nil
426 default:
427 return nil, errors.New("unsupported public key type")
428 }
429 }
430
431 func (kem *mlkemKEM) NewPublicKey(data []byte) (PublicKey, error) {
432 pq, err := kem.newPublicKey(data)
433 if err != nil {
434 return nil, err
435 }
436 return NewMLKEMPublicKey(pq)
437 }
438
439 func (pk *mlkemPublicKey) KEM() KEM {
440 return pk.kem
441 }
442
443 func (pk *mlkemPublicKey) Bytes() []byte {
444 return pk.pq.Bytes()
445 }
446
447 func (pk *mlkemPublicKey) encap() (sharedSecret []byte, encapPub []byte, err error) {
448 ss, ct := pk.pq.Encapsulate()
449 if testingOnlyEncapsulate != nil {
450 ss, ct = testingOnlyEncapsulate()
451 }
452 return ss, ct, nil
453 }
454
// mlkemPrivateKey is a standalone ML-KEM private key (pq) tagged with the
// mlkemKEM it belongs to. Field order matters: it is constructed with
// positional composite literals.
type mlkemPrivateKey struct {
	kem *mlkemKEM
	pq  crypto.Decapsulator
}
459
460
461
462
463
464
465
466
467
468
469
470
471 func NewMLKEMPrivateKey(priv crypto.Decapsulator) (PrivateKey, error) {
472 switch priv.Encapsulator().(type) {
473 case *mlkem.EncapsulationKey768:
474 return &mlkemPrivateKey{mlkem768, priv}, nil
475 case *mlkem.EncapsulationKey1024:
476 return &mlkemPrivateKey{mlkem1024, priv}, nil
477 default:
478 return nil, errors.New("unsupported public key type")
479 }
480 }
481
482 func (kem *mlkemKEM) GenerateKey() (PrivateKey, error) {
483 pq, err := kem.generateKey()
484 if err != nil {
485 return nil, err
486 }
487 return NewMLKEMPrivateKey(pq)
488 }
489
490 func (kem *mlkemKEM) NewPrivateKey(priv []byte) (PrivateKey, error) {
491 pq, err := kem.newPrivateKey(priv)
492 if err != nil {
493 return nil, err
494 }
495 return NewMLKEMPrivateKey(pq)
496 }
497
498 func (kem *mlkemKEM) DeriveKeyPair(ikm []byte) (PrivateKey, error) {
499 suiteID := byteorder.BEAppendUint16([]byte("KEM"), kem.id)
500 dk, err := SHAKE256().labeledDerive(suiteID, ikm, "DeriveKeyPair", nil, 64)
501 if err != nil {
502 return nil, err
503 }
504 return kem.NewPrivateKey(dk)
505 }
506
507 func (k *mlkemPrivateKey) KEM() KEM {
508 return k.kem
509 }
510
511 func (k *mlkemPrivateKey) Bytes() ([]byte, error) {
512 pq, ok := k.pq.(interface {
513 Bytes() []byte
514 })
515 if !ok {
516 return nil, errors.New("private key seed not available")
517 }
518 return pq.Bytes(), nil
519 }
520
521 func (k *mlkemPrivateKey) PublicKey() PublicKey {
522 return &mlkemPublicKey{
523 kem: k.kem,
524 pq: k.pq.Encapsulator(),
525 }
526 }
527
528 func (k *mlkemPrivateKey) decap(enc []byte) ([]byte, error) {
529 return k.pq.Decapsulate(enc)
530 }
531
View as plain text