1
2
3
4
5 package hpke
6
7 import (
8 "bytes"
9 "encoding/hex"
10 "encoding/json"
11 "os"
12 "strconv"
13 "strings"
14 "testing"
15
16 "crypto/ecdh"
17 _ "crypto/sha256"
18 _ "crypto/sha512"
19 )
20
// mustDecodeHex returns the bytes represented by the hexadecimal string in.
// A malformed string is reported through t.Fatal, which stops the calling
// test, so callers never have to handle a decoding error themselves.
func mustDecodeHex(t *testing.T, in string) []byte {
	t.Helper()
	out, decodeErr := hex.DecodeString(in)
	if decodeErr != nil {
		t.Fatal(decodeErr)
	}
	return out
}
29
// parseVectorSetup parses the "Setup" text of a test vector. The text is a
// newline-separated list of "key: value" lines, and the result maps each key
// to its value.
//
// Each line is split at its first ": " only, so a value that itself contains
// ": " is kept intact. Lines without a ": " separator, such as the blank line
// left by a trailing newline, are ignored instead of triggering an
// index-out-of-range panic.
func parseVectorSetup(vector string) map[string]string {
	vals := map[string]string{}
	for _, l := range strings.Split(vector, "\n") {
		key, value, ok := strings.Cut(l, ": ")
		if !ok {
			continue
		}
		vals[key] = value
	}
	return vals
}
38
// parseVectorEncryptions parses the "Encryptions" text of a test vector. The
// text is a sequence of sections separated by blank lines, and each section is
// a newline-separated list of "key: value" lines. The result holds one map per
// section, in input order.
//
// Each line is split at its first ": " only, and lines without a separator are
// ignored. A section that yields no fields (for example one produced by a
// trailing newline or extra blank lines) is skipped instead of causing an
// index-out-of-range panic.
func parseVectorEncryptions(vector string) []map[string]string {
	sections := strings.Split(vector, "\n\n")
	vals := make([]map[string]string, 0, len(sections))
	for _, section := range sections {
		e := map[string]string{}
		for _, l := range strings.Split(section, "\n") {
			key, value, ok := strings.Cut(l, ": ")
			if !ok {
				continue
			}
			e[key] = value
		}
		if len(e) == 0 {
			continue
		}
		vals = append(vals, e)
	}
	return vals
}
51
52 func TestRFC9180Vectors(t *testing.T) {
53 vectorsJSON, err := os.ReadFile("testdata/rfc9180-vectors.json")
54 if err != nil {
55 t.Fatal(err)
56 }
57
58 var vectors []struct {
59 Name string
60 Setup string
61 Encryptions string
62 }
63 if err := json.Unmarshal(vectorsJSON, &vectors); err != nil {
64 t.Fatal(err)
65 }
66
67 for _, vector := range vectors {
68 t.Run(vector.Name, func(t *testing.T) {
69 setup := parseVectorSetup(vector.Setup)
70
71 kemID, err := strconv.Atoi(setup["kem_id"])
72 if err != nil {
73 t.Fatal(err)
74 }
75 if _, ok := SupportedKEMs[uint16(kemID)]; !ok {
76 t.Skip("unsupported KEM")
77 }
78 kdfID, err := strconv.Atoi(setup["kdf_id"])
79 if err != nil {
80 t.Fatal(err)
81 }
82 if _, ok := SupportedKDFs[uint16(kdfID)]; !ok {
83 t.Skip("unsupported KDF")
84 }
85 aeadID, err := strconv.Atoi(setup["aead_id"])
86 if err != nil {
87 t.Fatal(err)
88 }
89 if _, ok := SupportedAEADs[uint16(aeadID)]; !ok {
90 t.Skip("unsupported AEAD")
91 }
92
93 info := mustDecodeHex(t, setup["info"])
94 pubKeyBytes := mustDecodeHex(t, setup["pkRm"])
95 pub, err := ParseHPKEPublicKey(uint16(kemID), pubKeyBytes)
96 if err != nil {
97 t.Fatal(err)
98 }
99
100 ephemeralPrivKey := mustDecodeHex(t, setup["skEm"])
101
102 testingOnlyGenerateKey = func() (*ecdh.PrivateKey, error) {
103 return SupportedKEMs[uint16(kemID)].curve.NewPrivateKey(ephemeralPrivKey)
104 }
105 t.Cleanup(func() { testingOnlyGenerateKey = nil })
106
107 encap, sender, err := SetupSender(
108 uint16(kemID),
109 uint16(kdfID),
110 uint16(aeadID),
111 pub,
112 info,
113 )
114 if err != nil {
115 t.Fatal(err)
116 }
117
118 expectedEncap := mustDecodeHex(t, setup["enc"])
119 if !bytes.Equal(encap, expectedEncap) {
120 t.Errorf("unexpected encapsulated key, got: %x, want %x", encap, expectedEncap)
121 }
122
123 privKeyBytes := mustDecodeHex(t, setup["skRm"])
124 priv, err := ParseHPKEPrivateKey(uint16(kemID), privKeyBytes)
125 if err != nil {
126 t.Fatal(err)
127 }
128
129 receipient, err := SetupReceipient(
130 uint16(kemID),
131 uint16(kdfID),
132 uint16(aeadID),
133 priv,
134 info,
135 encap,
136 )
137 if err != nil {
138 t.Fatal(err)
139 }
140
141 for _, ctx := range []*context{sender.context, receipient.context} {
142 expectedSharedSecret := mustDecodeHex(t, setup["shared_secret"])
143 if !bytes.Equal(ctx.sharedSecret, expectedSharedSecret) {
144 t.Errorf("unexpected shared secret, got: %x, want %x", ctx.sharedSecret, expectedSharedSecret)
145 }
146 expectedKey := mustDecodeHex(t, setup["key"])
147 if !bytes.Equal(ctx.key, expectedKey) {
148 t.Errorf("unexpected key, got: %x, want %x", ctx.key, expectedKey)
149 }
150 expectedBaseNonce := mustDecodeHex(t, setup["base_nonce"])
151 if !bytes.Equal(ctx.baseNonce, expectedBaseNonce) {
152 t.Errorf("unexpected base nonce, got: %x, want %x", ctx.baseNonce, expectedBaseNonce)
153 }
154 expectedExporterSecret := mustDecodeHex(t, setup["exporter_secret"])
155 if !bytes.Equal(ctx.exporterSecret, expectedExporterSecret) {
156 t.Errorf("unexpected exporter secret, got: %x, want %x", ctx.exporterSecret, expectedExporterSecret)
157 }
158 }
159
160 for _, enc := range parseVectorEncryptions(vector.Encryptions) {
161 t.Run("seq num "+enc["sequence number"], func(t *testing.T) {
162 seqNum, err := strconv.Atoi(enc["sequence number"])
163 if err != nil {
164 t.Fatal(err)
165 }
166 sender.seqNum = uint128{lo: uint64(seqNum)}
167 receipient.seqNum = uint128{lo: uint64(seqNum)}
168 expectedNonce := mustDecodeHex(t, enc["nonce"])
169 computedNonce := sender.nextNonce()
170 if !bytes.Equal(computedNonce, expectedNonce) {
171 t.Errorf("unexpected nonce: got %x, want %x", computedNonce, expectedNonce)
172 }
173
174 expectedCiphertext := mustDecodeHex(t, enc["ct"])
175 ciphertext, err := sender.Seal(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["pt"]))
176 if err != nil {
177 t.Fatal(err)
178 }
179 if !bytes.Equal(ciphertext, expectedCiphertext) {
180 t.Errorf("unexpected ciphertext: got %x want %x", ciphertext, expectedCiphertext)
181 }
182
183 expectedPlaintext := mustDecodeHex(t, enc["pt"])
184 plaintext, err := receipient.Open(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["ct"]))
185 if err != nil {
186 t.Fatal(err)
187 }
188 if !bytes.Equal(plaintext, expectedPlaintext) {
189 t.Errorf("unexpected plaintext: got %x want %x", plaintext, expectedPlaintext)
190 }
191 })
192 }
193 })
194 }
195 }
196
View as plain text