1
2
3
4
5
6
7
8
9 package hpke
10
11 import (
12 "crypto/cipher"
13 "errors"
14 "internal/byteorder"
15 )
16
// context holds the per-session state produced by the key schedule in
// newContext. It is shared by Sender and Recipient through embedding.
type context struct {
	// suiteID is the output of suiteID(kemID, kdfID, aeadID): "HPKE"
	// followed by the three 16-bit IDs. It is passed to every labeled KDF
	// call made by newContext and by export.
	suiteID []byte

	// export derives exporter output for a given exporter context and
	// length from the exporter secret captured when newContext built it.
	export func(string, uint16) ([]byte, error)

	// aead is the keyed AEAD. Seal and Open treat a nil aead as an
	// export-only instantiation and refuse to encrypt or decrypt.
	aead      cipher.AEAD
	baseNonce []byte

	// seqNum is the per-message sequence number that nextNonce encodes and
	// XORs into baseNonce. Seal and Open advance it after each successful
	// operation. It must never wrap around, because a repeated sequence
	// number yields a repeated nonce.
	seqNum uint64
}
28
29
30
31
32 type Sender struct {
33 *context
34 }
35
36
37
38
39 type Recipient struct {
40 *context
41 }
42
43 func newContext(sharedSecret []byte, kemID uint16, kdf KDF, aead AEAD, info []byte) (*context, error) {
44 sid := suiteID(kemID, kdf.ID(), aead.ID())
45
46 if kdf.oneStage() {
47 secrets := make([]byte, 0, 2+2+len(sharedSecret))
48 secrets = byteorder.BEAppendUint16(secrets, 0)
49 secrets = byteorder.BEAppendUint16(secrets, uint16(len(sharedSecret)))
50 secrets = append(secrets, sharedSecret...)
51
52 ksContext := make([]byte, 0, 1+2+2+len(info))
53 ksContext = append(ksContext, 0)
54 ksContext = byteorder.BEAppendUint16(ksContext, 0)
55 ksContext = byteorder.BEAppendUint16(ksContext, uint16(len(info)))
56 ksContext = append(ksContext, info...)
57
58 secret, err := kdf.labeledDerive(sid, secrets, "secret", ksContext,
59 uint16(aead.keySize()+aead.nonceSize()+kdf.size()))
60 if err != nil {
61 return nil, err
62 }
63 key := secret[:aead.keySize()]
64 baseNonce := secret[aead.keySize() : aead.keySize()+aead.nonceSize()]
65 expSecret := secret[aead.keySize()+aead.nonceSize():]
66
67 a, err := aead.aead(key)
68 if err != nil {
69 return nil, err
70 }
71 export := func(exporterContext string, length uint16) ([]byte, error) {
72 return kdf.labeledDerive(sid, expSecret, "sec", []byte(exporterContext), length)
73 }
74
75 return &context{
76 aead: a,
77 suiteID: sid,
78 export: export,
79 baseNonce: baseNonce,
80 }, nil
81 }
82
83 pskIDHash, err := kdf.labeledExtract(sid, nil, "psk_id_hash", nil)
84 if err != nil {
85 return nil, err
86 }
87 infoHash, err := kdf.labeledExtract(sid, nil, "info_hash", info)
88 if err != nil {
89 return nil, err
90 }
91 ksContext := append([]byte{0}, pskIDHash...)
92 ksContext = append(ksContext, infoHash...)
93
94 secret, err := kdf.labeledExtract(sid, sharedSecret, "secret", nil)
95 if err != nil {
96 return nil, err
97 }
98 key, err := kdf.labeledExpand(sid, secret, "key", ksContext, uint16(aead.keySize()))
99 if err != nil {
100 return nil, err
101 }
102 a, err := aead.aead(key)
103 if err != nil {
104 return nil, err
105 }
106 baseNonce, err := kdf.labeledExpand(sid, secret, "base_nonce", ksContext, uint16(aead.nonceSize()))
107 if err != nil {
108 return nil, err
109 }
110 expSecret, err := kdf.labeledExpand(sid, secret, "exp", ksContext, uint16(kdf.size()))
111 if err != nil {
112 return nil, err
113 }
114 export := func(exporterContext string, length uint16) ([]byte, error) {
115 return kdf.labeledExpand(sid, expSecret, "sec", []byte(exporterContext), length)
116 }
117
118 return &context{
119 aead: a,
120 suiteID: sid,
121 export: export,
122 baseNonce: baseNonce,
123 }, nil
124 }
125
126
127
128
129
130
131
132
133
134
135 func NewSender(pk PublicKey, kdf KDF, aead AEAD, info []byte) (enc []byte, s *Sender, err error) {
136 sharedSecret, encapsulatedKey, err := pk.encap()
137 if err != nil {
138 return nil, nil, err
139 }
140 context, err := newContext(sharedSecret, pk.KEM().ID(), kdf, aead, info)
141 if err != nil {
142 return nil, nil, err
143 }
144 return encapsulatedKey, &Sender{context}, nil
145 }
146
147
148
149
150
151
152
153
154 func NewRecipient(enc []byte, k PrivateKey, kdf KDF, aead AEAD, info []byte) (*Recipient, error) {
155 sharedSecret, err := k.decap(enc)
156 if err != nil {
157 return nil, err
158 }
159 context, err := newContext(sharedSecret, k.KEM().ID(), kdf, aead, info)
160 if err != nil {
161 return nil, err
162 }
163 return &Recipient{context}, nil
164 }
165
166
167
168
169
170
171 func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) {
172 if s.aead == nil {
173 return nil, errors.New("export-only instantiation")
174 }
175 ciphertext := s.aead.Seal(nil, s.nextNonce(), plaintext, aad)
176 s.seqNum++
177 return ciphertext, nil
178 }
179
180
181
182
183 func Seal(pk PublicKey, kdf KDF, aead AEAD, info, plaintext []byte) ([]byte, error) {
184 enc, s, err := NewSender(pk, kdf, aead, info)
185 if err != nil {
186 return nil, err
187 }
188 ct, err := s.Seal(nil, plaintext)
189 if err != nil {
190 return nil, err
191 }
192 return append(enc, ct...), nil
193 }
194
195
196
197 func (s *Sender) Export(exporterContext string, length int) ([]byte, error) {
198 if length < 0 || length > 0xFFFF {
199 return nil, errors.New("invalid length")
200 }
201 return s.export(exporterContext, uint16(length))
202 }
203
204
205
206
207
208
209 func (r *Recipient) Open(aad, ciphertext []byte) ([]byte, error) {
210 if r.aead == nil {
211 return nil, errors.New("export-only instantiation")
212 }
213 plaintext, err := r.aead.Open(nil, r.nextNonce(), ciphertext, aad)
214 if err != nil {
215 return nil, err
216 }
217 r.seqNum++
218 return plaintext, nil
219 }
220
221
222
223
224 func Open(k PrivateKey, kdf KDF, aead AEAD, info, ciphertext []byte) ([]byte, error) {
225 encSize := k.KEM().encSize()
226 if len(ciphertext) < encSize {
227 return nil, errors.New("ciphertext too short")
228 }
229 enc, ciphertext := ciphertext[:encSize], ciphertext[encSize:]
230 r, err := NewRecipient(enc, k, kdf, aead, info)
231 if err != nil {
232 return nil, err
233 }
234 return r.Open(nil, ciphertext)
235 }
236
237
238
239 func (r *Recipient) Export(exporterContext string, length int) ([]byte, error) {
240 if length < 0 || length > 0xFFFF {
241 return nil, errors.New("invalid length")
242 }
243 return r.export(exporterContext, uint16(length))
244 }
245
246 func (ctx *context) nextNonce() []byte {
247 nonce := make([]byte, ctx.aead.NonceSize())
248 byteorder.BEPutUint64(nonce[len(nonce)-8:], ctx.seqNum)
249 for i := range ctx.baseNonce {
250 nonce[i] ^= ctx.baseNonce[i]
251 }
252 return nonce
253 }
254
// suiteID builds the HPKE ciphersuite identifier: the ASCII bytes "HPKE"
// followed by kemID, kdfID and aeadID, each encoded as a 2-byte big-endian
// integer. The result is always 10 bytes long, with capacity 10.
func suiteID(kemID, kdfID, aeadID uint16) []byte {
	id := make([]byte, 0, 4+2+2+2)
	id = append(id, "HPKE"...)
	for _, v := range [...]uint16{kemID, kdfID, aeadID} {
		id = append(id, byte(v>>8), byte(v))
	}
	return id
}
263
View as plain text