Source file
src/crypto/hpke/hpke_test.go
1
2
3
4
5 package hpke
6
7 import (
8 "bytes"
9 "crypto/ecdh"
10 "crypto/mlkem"
11 "crypto/mlkem/mlkemtest"
12 "crypto/sha3"
13 "encoding/hex"
14 "encoding/json"
15 "fmt"
16 "io"
17 "os"
18 "testing"
19 )
20
21 func Example() {
22
23
24
25
26 kem, kdf, aead := MLKEM768X25519(), HKDFSHA256(), AES256GCM()
27
28
29 var (
30 recipientPrivateKey PrivateKey
31 publicKeyBytes []byte
32 )
33 {
34 k, err := kem.GenerateKey()
35 if err != nil {
36 panic(err)
37 }
38 recipientPrivateKey = k
39 publicKeyBytes = k.PublicKey().Bytes()
40 }
41
42
43 var ciphertext []byte
44 {
45 publicKey, err := kem.NewPublicKey(publicKeyBytes)
46 if err != nil {
47 panic(err)
48 }
49
50 message := []byte("|-()-|")
51 ct, err := Seal(publicKey, kdf, aead, []byte("example"), message)
52 if err != nil {
53 panic(err)
54 }
55
56 ciphertext = ct
57 }
58
59
60 {
61 plaintext, err := Open(recipientPrivateKey, kdf, aead, []byte("example"), ciphertext)
62 if err != nil {
63 panic(err)
64 }
65 fmt.Printf("Decrypted message: %s\n", plaintext)
66 }
67
68
69
70 }
71
// mustDecodeHex returns the bytes encoded by the hex string in, and aborts
// the calling test via t.Fatal if in is not valid hex.
func mustDecodeHex(t *testing.T, in string) []byte {
	t.Helper()
	decoded, err := hex.DecodeString(in)
	if err == nil {
		return decoded
	}
	t.Fatal(err)
	return nil // unreachable: t.Fatal stops the test goroutine
}
80
81 func TestVectors(t *testing.T) {
82 t.Run("rfc9180", func(t *testing.T) {
83 testVectors(t, "rfc9180")
84 })
85 t.Run("hpke-pq", func(t *testing.T) {
86 testVectors(t, "hpke-pq")
87 })
88 }
89
90 func testVectors(t *testing.T, name string) {
91 vectorsJSON, err := os.ReadFile("testdata/" + name + ".json")
92 if err != nil {
93 t.Fatal(err)
94 }
95 var vectors []struct {
96 Mode uint16 `json:"mode"`
97 KEM uint16 `json:"kem_id"`
98 KDF uint16 `json:"kdf_id"`
99 AEAD uint16 `json:"aead_id"`
100 Info string `json:"info"`
101 IkmE string `json:"ikmE"`
102 IkmR string `json:"ikmR"`
103 SkRm string `json:"skRm"`
104 PkRm string `json:"pkRm"`
105 Enc string `json:"enc"`
106 Encryptions []struct {
107 Aad string `json:"aad"`
108 Ct string `json:"ct"`
109 Nonce string `json:"nonce"`
110 Pt string `json:"pt"`
111 } `json:"encryptions"`
112 Exports []struct {
113 Context string `json:"exporter_context"`
114 L int `json:"L"`
115 Value string `json:"exported_value"`
116 } `json:"exports"`
117
118
119
120 AccEncryptions string `json:"encryptions_accumulated"`
121 AccExports string `json:"exports_accumulated"`
122 }
123 if err := json.Unmarshal(vectorsJSON, &vectors); err != nil {
124 t.Fatal(err)
125 }
126
127 for _, vector := range vectors {
128 name := fmt.Sprintf("mode %04x kem %04x kdf %04x aead %04x",
129 vector.Mode, vector.KEM, vector.KDF, vector.AEAD)
130 t.Run(name, func(t *testing.T) {
131 if vector.Mode != 0 {
132 t.Skip("only mode 0 (base) is supported")
133 }
134 if vector.KEM == 0x0021 {
135 t.Skip("KEM 0x0021 (DHKEM(X448)) not supported")
136 }
137 if vector.KEM == 0x0040 {
138 t.Skip("KEM 0x0040 (ML-KEM-512) not supported")
139 }
140 if vector.KDF == 0x0012 || vector.KDF == 0x0013 {
141 t.Skipf("TurboSHAKE KDF not supported")
142 }
143
144 kdf, err := NewKDF(vector.KDF)
145 if err != nil {
146 t.Fatal(err)
147 }
148 if kdf.ID() != vector.KDF {
149 t.Errorf("unexpected KDF ID: got %04x, want %04x", kdf.ID(), vector.KDF)
150 }
151
152 aead, err := NewAEAD(vector.AEAD)
153 if err != nil {
154 t.Fatal(err)
155 }
156 if aead.ID() != vector.AEAD {
157 t.Errorf("unexpected AEAD ID: got %04x, want %04x", aead.ID(), vector.AEAD)
158 }
159
160 kem, err := NewKEM(vector.KEM)
161 if err != nil {
162 t.Fatal(err)
163 }
164 if kem.ID() != vector.KEM {
165 t.Errorf("unexpected KEM ID: got %04x, want %04x", kem.ID(), vector.KEM)
166 }
167
168 pubKeyBytes := mustDecodeHex(t, vector.PkRm)
169 kemSender, err := kem.NewPublicKey(pubKeyBytes)
170 if err != nil {
171 t.Fatal(err)
172 }
173 if kemSender.KEM() != kem {
174 t.Errorf("unexpected KEM from sender: got %04x, want %04x", kemSender.KEM().ID(), kem.ID())
175 }
176 if !bytes.Equal(kemSender.Bytes(), pubKeyBytes) {
177 t.Errorf("unexpected KEM bytes: got %x, want %x", kemSender.Bytes(), pubKeyBytes)
178 }
179
180 ikmE := mustDecodeHex(t, vector.IkmE)
181 setupDerandomizedEncap(t, ikmE, kemSender)
182
183 info := mustDecodeHex(t, vector.Info)
184 encap, sender, err := NewSender(kemSender, kdf, aead, info)
185 if err != nil {
186 t.Fatal(err)
187 }
188 if len(encap) != kem.encSize() {
189 t.Errorf("unexpected encapsulated key size: got %d, want %d", len(encap), kem.encSize())
190 }
191
192 expectedEncap := mustDecodeHex(t, vector.Enc)
193 if !bytes.Equal(encap, expectedEncap) {
194 t.Errorf("unexpected encapsulated key, got: %x, want %x", encap, expectedEncap)
195 }
196
197 privKeyBytes := mustDecodeHex(t, vector.SkRm)
198 kemRecipient, err := kem.NewPrivateKey(privKeyBytes)
199 if err != nil {
200 t.Fatal(err)
201 }
202 if kemRecipient.KEM() != kem {
203 t.Errorf("unexpected KEM from recipient: got %04x, want %04x", kemRecipient.KEM().ID(), kem.ID())
204 }
205 kemRecipientBytes, err := kemRecipient.Bytes()
206 if err != nil {
207 t.Fatal(err)
208 }
209
210 if !bytes.Equal(kemRecipientBytes, privKeyBytes) && vector.KEM != DHKEM(ecdh.X25519()).ID() {
211 t.Errorf("unexpected KEM bytes: got %x, want %x", kemRecipientBytes, privKeyBytes)
212 }
213 if vector.KEM == DHKEM(ecdh.X25519()).ID() {
214 kem2, err := kem.NewPrivateKey(kemRecipientBytes)
215 if err != nil {
216 t.Fatal(err)
217 }
218 kemRecipientBytes2, err := kem2.Bytes()
219 if err != nil {
220 t.Fatal(err)
221 }
222 if !bytes.Equal(kemRecipientBytes2, kemRecipientBytes) {
223 t.Errorf("X25519 re-serialized key differs: got %x, want %x", kemRecipientBytes2, kemRecipientBytes)
224 }
225 if !bytes.Equal(kem2.PublicKey().Bytes(), pubKeyBytes) {
226 t.Errorf("X25519 re-derived public key differs: got %x, want %x", kem2.PublicKey().Bytes(), pubKeyBytes)
227 }
228 }
229 if !bytes.Equal(kemRecipient.PublicKey().Bytes(), pubKeyBytes) {
230 t.Errorf("unexpected KEM sender bytes: got %x, want %x", kemRecipient.PublicKey().Bytes(), pubKeyBytes)
231 }
232
233 ikm := mustDecodeHex(t, vector.IkmR)
234 derivRecipient, err := kem.DeriveKeyPair(ikm)
235 if err != nil {
236 t.Fatal(err)
237 }
238 derivRecipientBytes, err := derivRecipient.Bytes()
239 if err != nil {
240 t.Fatal(err)
241 }
242 if !bytes.Equal(derivRecipientBytes, privKeyBytes) && vector.KEM != DHKEM(ecdh.X25519()).ID() {
243 t.Errorf("unexpected KEM bytes from seed: got %x, want %x", derivRecipientBytes, privKeyBytes)
244 }
245 if !bytes.Equal(derivRecipient.PublicKey().Bytes(), pubKeyBytes) {
246 t.Errorf("unexpected KEM sender bytes from seed: got %x, want %x", derivRecipient.PublicKey().Bytes(), pubKeyBytes)
247 }
248
249 recipient, err := NewRecipient(encap, kemRecipient, kdf, aead, info)
250 if err != nil {
251 t.Fatal(err)
252 }
253
254 if aead != ExportOnly() && len(vector.AccEncryptions) != 0 {
255 source, sink := sha3.NewSHAKE128(), sha3.NewSHAKE128()
256 for range 1000 {
257 aad, plaintext := drawRandomInput(t, source), drawRandomInput(t, source)
258 ciphertext, err := sender.Seal(aad, plaintext)
259 if err != nil {
260 t.Fatal(err)
261 }
262 sink.Write(ciphertext)
263 got, err := recipient.Open(aad, ciphertext)
264 if err != nil {
265 t.Fatal(err)
266 }
267 if !bytes.Equal(got, plaintext) {
268 t.Errorf("unexpected plaintext: got %x want %x", got, plaintext)
269 }
270 }
271 encryptions := make([]byte, 16)
272 sink.Read(encryptions)
273 expectedEncryptions := mustDecodeHex(t, vector.AccEncryptions)
274 if !bytes.Equal(encryptions, expectedEncryptions) {
275 t.Errorf("unexpected accumulated encryptions, got: %x, want %x", encryptions, expectedEncryptions)
276 }
277 } else if aead != ExportOnly() {
278 for _, enc := range vector.Encryptions {
279 aad := mustDecodeHex(t, enc.Aad)
280 plaintext := mustDecodeHex(t, enc.Pt)
281 expectedCiphertext := mustDecodeHex(t, enc.Ct)
282
283 ciphertext, err := sender.Seal(aad, plaintext)
284 if err != nil {
285 t.Fatal(err)
286 }
287 if !bytes.Equal(ciphertext, expectedCiphertext) {
288 t.Errorf("unexpected ciphertext, got: %x, want %x", ciphertext, expectedCiphertext)
289 }
290
291 got, err := recipient.Open(aad, ciphertext)
292 if err != nil {
293 t.Fatal(err)
294 }
295 if !bytes.Equal(got, plaintext) {
296 t.Errorf("unexpected plaintext: got %x want %x", got, plaintext)
297 }
298 }
299 } else {
300 if _, err := sender.Seal(nil, nil); err == nil {
301 t.Error("expected error from Seal with export-only AEAD")
302 }
303 if _, err := recipient.Open(nil, nil); err == nil {
304 t.Error("expected error from Open with export-only AEAD")
305 }
306 }
307
308 if len(vector.AccExports) != 0 {
309 source, sink := sha3.NewSHAKE128(), sha3.NewSHAKE128()
310 for l := range 1000 {
311 context := string(drawRandomInput(t, source))
312 value, err := sender.Export(context, l)
313 if err != nil {
314 t.Fatal(err)
315 }
316 sink.Write(value)
317 got, err := recipient.Export(context, l)
318 if err != nil {
319 t.Fatal(err)
320 }
321 if !bytes.Equal(got, value) {
322 t.Errorf("recipient: unexpected exported secret: got %x want %x", got, value)
323 }
324 }
325 exports := make([]byte, 16)
326 sink.Read(exports)
327 expectedExports := mustDecodeHex(t, vector.AccExports)
328 if !bytes.Equal(exports, expectedExports) {
329 t.Errorf("unexpected accumulated exports, got: %x, want %x", exports, expectedExports)
330 }
331 } else {
332 for _, exp := range vector.Exports {
333 context := string(mustDecodeHex(t, exp.Context))
334 expectedValue := mustDecodeHex(t, exp.Value)
335
336 value, err := sender.Export(context, exp.L)
337 if err != nil {
338 t.Fatal(err)
339 }
340 if !bytes.Equal(value, expectedValue) {
341 t.Errorf("unexpected exported value, got: %x, want %x", value, expectedValue)
342 }
343
344 got, err := recipient.Export(context, exp.L)
345 if err != nil {
346 t.Fatal(err)
347 }
348 if !bytes.Equal(got, value) {
349 t.Errorf("recipient: unexpected exported secret: got %x want %x", got, value)
350 }
351 }
352 }
353 })
354 }
355 }
356
// drawRandomInput reads one length byte n from r, then exactly n more bytes,
// and returns those n bytes (0 to 255 bytes long). It fails the test if r
// cannot supply them.
//
// io.ReadFull is used because io.Reader allows Read to return fewer bytes
// than requested. A bare r.Read could otherwise silently return a
// zero-padded buffer.
func drawRandomInput(t *testing.T, r io.Reader) []byte {
	t.Helper()
	var l [1]byte
	if _, err := io.ReadFull(r, l[:]); err != nil {
		t.Fatal(err)
	}
	b := make([]byte, int(l[0]))
	if _, err := io.ReadFull(r, b); err != nil {
		t.Fatal(err)
	}
	return b
}
370
371 func setupDerandomizedEncap(t *testing.T, randBytes []byte, pk PublicKey) {
372 t.Cleanup(func() {
373 testingOnlyGenerateKey = nil
374 testingOnlyEncapsulate = nil
375 })
376 switch pk.KEM() {
377 case DHKEM(ecdh.P256()), DHKEM(ecdh.P384()), DHKEM(ecdh.P521()), DHKEM(ecdh.X25519()):
378 r, err := pk.KEM().DeriveKeyPair(randBytes)
379 if err != nil {
380 t.Fatal(err)
381 }
382 testingOnlyGenerateKey = func() *ecdh.PrivateKey {
383 return r.(*dhKEMPrivateKey).priv.(*ecdh.PrivateKey)
384 }
385 case mlkem768:
386 pq := pk.(*mlkemPublicKey).pq.(*mlkem.EncapsulationKey768)
387 testingOnlyEncapsulate = func() ([]byte, []byte) {
388 ss, ct, err := mlkemtest.Encapsulate768(pq, randBytes)
389 if err != nil {
390 t.Fatal(err)
391 }
392 return ss, ct
393 }
394 case mlkem1024:
395 pq := pk.(*mlkemPublicKey).pq.(*mlkem.EncapsulationKey1024)
396 testingOnlyEncapsulate = func() ([]byte, []byte) {
397 ss, ct, err := mlkemtest.Encapsulate1024(pq, randBytes)
398 if err != nil {
399 t.Fatal(err)
400 }
401 return ss, ct
402 }
403 case mlkem768X25519:
404 pqRand, tRand := randBytes[:32], randBytes[32:]
405 pq := pk.(*hybridPublicKey).pq.(*mlkem.EncapsulationKey768)
406 k, err := ecdh.X25519().NewPrivateKey(tRand)
407 if err != nil {
408 t.Fatal(err)
409 }
410 testingOnlyGenerateKey = func() *ecdh.PrivateKey {
411 return k
412 }
413 testingOnlyEncapsulate = func() ([]byte, []byte) {
414 ss, ct, err := mlkemtest.Encapsulate768(pq, pqRand)
415 if err != nil {
416 t.Fatal(err)
417 }
418 return ss, ct
419 }
420 case mlkem768P256:
421
422
423 pqRand, tRand := randBytes[:32], randBytes[32:64]
424 pq := pk.(*hybridPublicKey).pq.(*mlkem.EncapsulationKey768)
425 k, err := ecdh.P256().NewPrivateKey(tRand)
426 if err != nil {
427 t.Fatal(err)
428 }
429 testingOnlyGenerateKey = func() *ecdh.PrivateKey {
430 return k
431 }
432 testingOnlyEncapsulate = func() ([]byte, []byte) {
433 ss, ct, err := mlkemtest.Encapsulate768(pq, pqRand)
434 if err != nil {
435 t.Fatal(err)
436 }
437 return ss, ct
438 }
439 case mlkem1024P384:
440 pqRand, tRand := randBytes[:32], randBytes[32:]
441 pq := pk.(*hybridPublicKey).pq.(*mlkem.EncapsulationKey1024)
442 k, err := ecdh.P384().NewPrivateKey(tRand)
443 if err != nil {
444 t.Fatal(err)
445 }
446 testingOnlyGenerateKey = func() *ecdh.PrivateKey {
447 return k
448 }
449 testingOnlyEncapsulate = func() ([]byte, []byte) {
450 ss, ct, err := mlkemtest.Encapsulate1024(pq, pqRand)
451 if err != nil {
452 t.Fatal(err)
453 }
454 return ss, ct
455 }
456 default:
457 t.Fatalf("unsupported KEM %04x", pk.KEM().ID())
458 }
459 }
460
461 func TestSingletons(t *testing.T) {
462 if HKDFSHA256() != HKDFSHA256() {
463 t.Error("HKDFSHA256() != HKDFSHA256()")
464 }
465 if HKDFSHA384() != HKDFSHA384() {
466 t.Error("HKDFSHA384() != HKDFSHA384()")
467 }
468 if HKDFSHA512() != HKDFSHA512() {
469 t.Error("HKDFSHA512() != HKDFSHA512()")
470 }
471 if AES128GCM() != AES128GCM() {
472 t.Error("AES128GCM() != AES128GCM()")
473 }
474 if AES256GCM() != AES256GCM() {
475 t.Error("AES256GCM() != AES256GCM()")
476 }
477 if ChaCha20Poly1305() != ChaCha20Poly1305() {
478 t.Error("ChaCha20Poly1305() != ChaCha20Poly1305()")
479 }
480 if ExportOnly() != ExportOnly() {
481 t.Error("ExportOnly() != ExportOnly()")
482 }
483 if DHKEM(ecdh.P256()) != DHKEM(ecdh.P256()) {
484 t.Error("DHKEM(P-256) != DHKEM(P-256)")
485 }
486 if DHKEM(ecdh.P384()) != DHKEM(ecdh.P384()) {
487 t.Error("DHKEM(P-384) != DHKEM(P-384)")
488 }
489 if DHKEM(ecdh.P521()) != DHKEM(ecdh.P521()) {
490 t.Error("DHKEM(P-521) != DHKEM(P-521)")
491 }
492 if DHKEM(ecdh.X25519()) != DHKEM(ecdh.X25519()) {
493 t.Error("DHKEM(X25519) != DHKEM(X25519)")
494 }
495 if MLKEM768() != MLKEM768() {
496 t.Error("MLKEM768() != MLKEM768()")
497 }
498 if MLKEM1024() != MLKEM1024() {
499 t.Error("MLKEM1024() != MLKEM1024()")
500 }
501 if MLKEM768X25519() != MLKEM768X25519() {
502 t.Error("MLKEM768X25519() != MLKEM768X25519()")
503 }
504 if MLKEM768P256() != MLKEM768P256() {
505 t.Error("MLKEM768P256() != MLKEM768P256()")
506 }
507 if MLKEM1024P384() != MLKEM1024P384() {
508 t.Error("MLKEM1024P384() != MLKEM1024P384()")
509 }
510 }
511
View as plain text