1
2
3
4
5 package hpke
6
7 import (
8 "crypto"
9 "crypto/aes"
10 "crypto/cipher"
11 "crypto/ecdh"
12 "crypto/internal/fips140/hkdf"
13 "crypto/rand"
14 "errors"
15 "internal/byteorder"
16 "math/bits"
17
18 "golang.org/x/crypto/chacha20poly1305"
19 )
20
21
22
// testingOnlyGenerateKey is a hook for overriding ephemeral key generation.
// When it is non-nil, Encap calls it instead of generating a key from
// rand.Reader; when it is nil (the zero value), Encap generates the key itself.
var testingOnlyGenerateKey func() (*ecdh.PrivateKey, error)
24
// hkdfKDF is the HKDF-based key derivation function used by this package.
// Its LabeledExtract and LabeledExpand methods run HKDF with the configured
// hash.
type hkdfKDF struct {
	// hash selects the hash function handed to HKDF (via hash.New).
	hash crypto.Hash
}
28
29 func (kdf *hkdfKDF) LabeledExtract(sid []byte, salt []byte, label string, inputKey []byte) []byte {
30 labeledIKM := make([]byte, 0, 7+len(sid)+len(label)+len(inputKey))
31 labeledIKM = append(labeledIKM, []byte("HPKE-v1")...)
32 labeledIKM = append(labeledIKM, sid...)
33 labeledIKM = append(labeledIKM, label...)
34 labeledIKM = append(labeledIKM, inputKey...)
35 return hkdf.Extract(kdf.hash.New, labeledIKM, salt)
36 }
37
38 func (kdf *hkdfKDF) LabeledExpand(suiteID []byte, randomKey []byte, label string, info []byte, length uint16) []byte {
39 labeledInfo := make([]byte, 0, 2+7+len(suiteID)+len(label)+len(info))
40 labeledInfo = byteorder.BEAppendUint16(labeledInfo, length)
41 labeledInfo = append(labeledInfo, []byte("HPKE-v1")...)
42 labeledInfo = append(labeledInfo, suiteID...)
43 labeledInfo = append(labeledInfo, label...)
44 labeledInfo = append(labeledInfo, info...)
45 return hkdf.Expand(kdf.hash.New, randomKey, string(labeledInfo), int(length))
46 }
47
48
// dhKEM is a Diffie-Hellman based KEM: an ECDH curve combined with an HKDF
// instance used to derive the shared secret. Instances are created by
// newDHKem.
type dhKEM struct {
	// dh is the curve used to generate ephemeral keys and parse public keys.
	dh ecdh.Curve
	// kdf derives the shared secret from the raw ECDH output.
	kdf hkdfKDF

	// suiteID is "KEM" followed by the big-endian 2-byte KEM identifier; it
	// is mixed into every labeled KDF call made by this KEM.
	suiteID []byte
	// nSecret is the length in bytes of the derived shared secret.
	nSecret uint16
}
56
// KemID is a 16-bit KEM identifier type.
// NOTE(review): the functions visible in this file take plain uint16
// identifiers rather than KemID — confirm whether the type is used elsewhere.
type KemID uint16

// DHKEM_X25519_HKDF_SHA256 is the (untyped) identifier for the KEM that
// SupportedKEMs maps to X25519 with SHA-256.
const DHKEM_X25519_HKDF_SHA256 = 0x0020
60
// SupportedKEMs maps a KEM identifier to the parameters needed to
// instantiate it: the ECDH curve, the hash used for HKDF, and the length in
// bytes of the derived shared secret. Identifiers missing from this map are
// rejected by newDHKem, ParseHPKEPublicKey and ParseHPKEPrivateKey.
var SupportedKEMs = map[uint16]struct {
	curve   ecdh.Curve
	hash    crypto.Hash
	nSecret uint16
}{

	DHKEM_X25519_HKDF_SHA256: {ecdh.X25519(), crypto.SHA256, 32},
}
69
70 func newDHKem(kemID uint16) (*dhKEM, error) {
71 suite, ok := SupportedKEMs[kemID]
72 if !ok {
73 return nil, errors.New("unsupported suite ID")
74 }
75 return &dhKEM{
76 dh: suite.curve,
77 kdf: hkdfKDF{suite.hash},
78 suiteID: byteorder.BEAppendUint16([]byte("KEM"), kemID),
79 nSecret: suite.nSecret,
80 }, nil
81 }
82
83 func (dh *dhKEM) ExtractAndExpand(dhKey, kemContext []byte) []byte {
84 eaePRK := dh.kdf.LabeledExtract(dh.suiteID[:], nil, "eae_prk", dhKey)
85 return dh.kdf.LabeledExpand(dh.suiteID[:], eaePRK, "shared_secret", kemContext, dh.nSecret)
86 }
87
88 func (dh *dhKEM) Encap(pubRecipient *ecdh.PublicKey) (sharedSecret []byte, encapPub []byte, err error) {
89 var privEph *ecdh.PrivateKey
90 if testingOnlyGenerateKey != nil {
91 privEph, err = testingOnlyGenerateKey()
92 } else {
93 privEph, err = dh.dh.GenerateKey(rand.Reader)
94 }
95 if err != nil {
96 return nil, nil, err
97 }
98 dhVal, err := privEph.ECDH(pubRecipient)
99 if err != nil {
100 return nil, nil, err
101 }
102 encPubEph := privEph.PublicKey().Bytes()
103
104 encPubRecip := pubRecipient.Bytes()
105 kemContext := append(encPubEph, encPubRecip...)
106
107 return dh.ExtractAndExpand(dhVal, kemContext), encPubEph, nil
108 }
109
110 func (dh *dhKEM) Decap(encPubEph []byte, secRecipient *ecdh.PrivateKey) ([]byte, error) {
111 pubEph, err := dh.dh.NewPublicKey(encPubEph)
112 if err != nil {
113 return nil, err
114 }
115 dhVal, err := secRecipient.ECDH(pubEph)
116 if err != nil {
117 return nil, err
118 }
119 kemContext := append(encPubEph, secRecipient.PublicKey().Bytes()...)
120
121 return dh.ExtractAndExpand(dhVal, kemContext), nil
122 }
123
// context holds the state shared by Sender and Receipient: the AEAD, the
// secrets derived by newContext, and the message sequence counter.
type context struct {
	// aead is the cipher built from key; it seals and opens messages.
	aead cipher.AEAD

	// sharedSecret is the KEM output that newContext was given.
	sharedSecret []byte

	// suiteID is the "HPKE" || kemID || kdfID || aeadID identifier used in
	// every labeled KDF call made by newContext.
	suiteID []byte

	// key, baseNonce and exporterSecret are the values expanded by
	// newContext under the "key", "base_nonce" and "exp" labels.
	key            []byte
	baseNonce      []byte
	exporterSecret []byte

	// seqNum counts processed messages. newContext leaves it at its zero
	// value; nextNonce XORs it into baseNonce and incrementNonce advances it.
	seqNum uint128
}
137
// Sender is the encrypting side of an HPKE context. It is created by
// SetupSender and encrypts messages with Seal.
type Sender struct {
	*context
}
141
// Receipient is the decrypting side of an HPKE context. It is created by
// SetupReceipient and decrypts messages with Open.
// NOTE(review): the name looks like a misspelling of "Recipient"; it is
// exported, so renaming it would break callers.
type Receipient struct {
	*context
}
145
// aesGCMNew constructs an AES-GCM AEAD from a raw AES key. The AES variant
// is selected by the key length; a key that aes.NewCipher rejects yields a
// nil AEAD and that error.
var aesGCMNew = func(key []byte) (cipher.AEAD, error) {
	blockCipher, cipherErr := aes.NewCipher(key)
	if cipherErr == nil {
		return cipher.NewGCM(blockCipher)
	}
	return nil, cipherErr
}
153
// AEADID is a 16-bit AEAD identifier type.
// NOTE(review): the functions visible in this file take plain uint16
// identifiers rather than AEADID — confirm whether the type is used elsewhere.
type AEADID uint16

// Untyped AEAD identifiers; each is a key of SupportedAEADs.
const (
	AEAD_AES_128_GCM      = 0x0001
	AEAD_AES_256_GCM      = 0x0002
	AEAD_ChaCha20Poly1305 = 0x0003
)
161
// SupportedAEADs maps an AEAD identifier to its key size and nonce size (in
// bytes) and to the constructor that turns a key into a cipher.AEAD.
// newContext rejects identifiers that are missing from this map.
var SupportedAEADs = map[uint16]struct {
	keySize   int
	nonceSize int
	aead      func([]byte) (cipher.AEAD, error)
}{

	// Both AES-GCM entries share aesGCMNew; they differ only in key size.
	AEAD_AES_128_GCM:      {keySize: 16, nonceSize: 12, aead: aesGCMNew},
	AEAD_AES_256_GCM:      {keySize: 32, nonceSize: 12, aead: aesGCMNew},
	AEAD_ChaCha20Poly1305: {keySize: chacha20poly1305.KeySize, nonceSize: chacha20poly1305.NonceSize, aead: chacha20poly1305.New},
}
172
// KDFID is a 16-bit KDF identifier type.
// NOTE(review): the functions visible in this file take plain uint16
// identifiers rather than KDFID — confirm whether the type is used elsewhere.
type KDFID uint16

// KDF_HKDF_SHA256 is the (untyped) identifier that SupportedKDFs maps to
// HKDF with SHA-256.
const KDF_HKDF_SHA256 = 0x0001
176
// SupportedKDFs maps a KDF identifier to a constructor returning a new
// hkdfKDF configured with the matching hash. newContext rejects identifiers
// that are missing from this map.
var SupportedKDFs = map[uint16]func() *hkdfKDF{

	KDF_HKDF_SHA256: func() *hkdfKDF { return &hkdfKDF{crypto.SHA256} },
}
181
182 func newContext(sharedSecret []byte, kemID, kdfID, aeadID uint16, info []byte) (*context, error) {
183 sid := suiteID(kemID, kdfID, aeadID)
184
185 kdfInit, ok := SupportedKDFs[kdfID]
186 if !ok {
187 return nil, errors.New("unsupported KDF id")
188 }
189 kdf := kdfInit()
190
191 aeadInfo, ok := SupportedAEADs[aeadID]
192 if !ok {
193 return nil, errors.New("unsupported AEAD id")
194 }
195
196 pskIDHash := kdf.LabeledExtract(sid, nil, "psk_id_hash", nil)
197 infoHash := kdf.LabeledExtract(sid, nil, "info_hash", info)
198 ksContext := append([]byte{0}, pskIDHash...)
199 ksContext = append(ksContext, infoHash...)
200
201 secret := kdf.LabeledExtract(sid, sharedSecret, "secret", nil)
202
203 key := kdf.LabeledExpand(sid, secret, "key", ksContext, uint16(aeadInfo.keySize) )
204 baseNonce := kdf.LabeledExpand(sid, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) )
205 exporterSecret := kdf.LabeledExpand(sid, secret, "exp", ksContext, uint16(kdf.hash.Size()) )
206
207 aead, err := aeadInfo.aead(key)
208 if err != nil {
209 return nil, err
210 }
211
212 return &context{
213 aead: aead,
214 sharedSecret: sharedSecret,
215 suiteID: sid,
216 key: key,
217 baseNonce: baseNonce,
218 exporterSecret: exporterSecret,
219 }, nil
220 }
221
222 func SetupSender(kemID, kdfID, aeadID uint16, pub *ecdh.PublicKey, info []byte) ([]byte, *Sender, error) {
223 kem, err := newDHKem(kemID)
224 if err != nil {
225 return nil, nil, err
226 }
227 sharedSecret, encapsulatedKey, err := kem.Encap(pub)
228 if err != nil {
229 return nil, nil, err
230 }
231
232 context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info)
233 if err != nil {
234 return nil, nil, err
235 }
236
237 return encapsulatedKey, &Sender{context}, nil
238 }
239
240 func SetupReceipient(kemID, kdfID, aeadID uint16, priv *ecdh.PrivateKey, info, encPubEph []byte) (*Receipient, error) {
241 kem, err := newDHKem(kemID)
242 if err != nil {
243 return nil, err
244 }
245 sharedSecret, err := kem.Decap(encPubEph, priv)
246 if err != nil {
247 return nil, err
248 }
249
250 context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info)
251 if err != nil {
252 return nil, err
253 }
254
255 return &Receipient{context}, nil
256 }
257
258 func (ctx *context) nextNonce() []byte {
259 nonce := ctx.seqNum.bytes()[16-ctx.aead.NonceSize():]
260 for i := range ctx.baseNonce {
261 nonce[i] ^= ctx.baseNonce[i]
262 }
263 return nonce
264 }
265
266 func (ctx *context) incrementNonce() {
267
268
269 if ctx.seqNum.bitLen() >= (ctx.aead.NonceSize()*8)-1 {
270 panic("message limit reached")
271 }
272 ctx.seqNum = ctx.seqNum.addOne()
273 }
274
275 func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) {
276 ciphertext := s.aead.Seal(nil, s.nextNonce(), plaintext, aad)
277 s.incrementNonce()
278 return ciphertext, nil
279 }
280
281 func (r *Receipient) Open(aad, ciphertext []byte) ([]byte, error) {
282 plaintext, err := r.aead.Open(nil, r.nextNonce(), ciphertext, aad)
283 if err != nil {
284 return nil, err
285 }
286 r.incrementNonce()
287 return plaintext, nil
288 }
289
290 func suiteID(kemID, kdfID, aeadID uint16) []byte {
291 suiteID := make([]byte, 0, 4+2+2+2)
292 suiteID = append(suiteID, []byte("HPKE")...)
293 suiteID = byteorder.BEAppendUint16(suiteID, kemID)
294 suiteID = byteorder.BEAppendUint16(suiteID, kdfID)
295 suiteID = byteorder.BEAppendUint16(suiteID, aeadID)
296 return suiteID
297 }
298
299 func ParseHPKEPublicKey(kemID uint16, bytes []byte) (*ecdh.PublicKey, error) {
300 kemInfo, ok := SupportedKEMs[kemID]
301 if !ok {
302 return nil, errors.New("unsupported KEM id")
303 }
304 return kemInfo.curve.NewPublicKey(bytes)
305 }
306
307 func ParseHPKEPrivateKey(kemID uint16, bytes []byte) (*ecdh.PrivateKey, error) {
308 kemInfo, ok := SupportedKEMs[kemID]
309 if !ok {
310 return nil, errors.New("unsupported KEM id")
311 }
312 return kemInfo.curve.NewPrivateKey(bytes)
313 }
314
// uint128 is an unsigned 128-bit integer stored as two 64-bit halves.
type uint128 struct {
	// hi holds the most significant 64 bits, lo the least significant 64.
	hi, lo uint64
}

// addOne returns u+1. The carry out of lo is propagated into hi; a carry out
// of hi is discarded, so the maximum value wraps around to zero.
func (u uint128) addOne() uint128 {
	lo, carry := bits.Add64(u.lo, 1, 0)
	return uint128{u.hi + carry, lo}
}

// bitLen returns the minimum number of bits needed to represent u, and 0 for
// u == 0.
func (u uint128) bitLen() int {
	// When any bit of the high half is set, all 64 positions of the low half
	// count towards the length regardless of the low half's own value.
	// Summing the two halves' lengths would, for example, report 1 instead
	// of 65 for the value 1<<64.
	if u.hi != 0 {
		return 64 + bits.Len64(u.hi)
	}
	return bits.Len64(u.lo)
}
327
328 func (u uint128) bytes() []byte {
329 b := make([]byte, 16)
330 byteorder.BEPutUint64(b[0:], u.hi)
331 byteorder.BEPutUint64(b[8:], u.lo)
332 return b
333 }
334
View as plain text