1
2
3
4
5 package rsa
6
7
8
9
10 import (
11 "bytes"
12 "crypto/internal/constanttime"
13 "crypto/internal/fips140"
14 "crypto/internal/fips140/drbg"
15 "crypto/internal/fips140/sha256"
16 "crypto/internal/fips140/sha3"
17 "crypto/internal/fips140/sha512"
18 "crypto/internal/fips140/subtle"
19 "errors"
20 "hash"
21 "io"
22 )
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
// incCounter adds one to the 4-byte big-endian counter c, carrying from the
// last byte toward the first. An all-0xff counter wraps around to zero.
func incCounter(c *[4]byte) {
	for i := len(c) - 1; i >= 0; i-- {
		c[i]++
		if c[i] != 0 {
			// No carry out of this byte; the higher bytes are unchanged.
			return
		}
	}
}
50
51
52
// mgf1XOR XORs out in place with the MGF1 mask generated from seed using hash:
// the concatenation of hash(seed || C) for C = 0, 1, 2, ... as a 4-byte
// big-endian counter, truncated to len(out) bytes. hash is Reset before each
// block, so its previous state is discarded.
func mgf1XOR(out []byte, hash hash.Hash, seed []byte) {
	var ctr [4]byte
	var block []byte

	for pos := 0; pos < len(out); {
		hash.Reset()
		hash.Write(seed)
		hash.Write(ctr[:])
		// Reuse block's backing array across iterations.
		block = hash.Sum(block[:0])

		for _, b := range block {
			if pos >= len(out) {
				break
			}
			out[pos] ^= b
			pos++
		}

		// Advance the big-endian counter, carrying toward ctr[0].
		for i := len(ctr) - 1; i >= 0; i-- {
			ctr[i]++
			if ctr[i] != 0 {
				break
			}
		}
	}
}
71
72 func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) {
73
74
75 hLen := hash.Size()
76 sLen := len(salt)
77 emLen := (emBits + 7) / 8
78
79
80
81
82
83
84
85 if len(mHash) != hLen {
86 return nil, errors.New("crypto/rsa: input must be hashed with given hash")
87 }
88
89
90
91 if emLen < hLen+sLen+2 {
92 return nil, ErrMessageTooLong
93 }
94
95 em := make([]byte, emLen)
96 psLen := emLen - sLen - hLen - 2
97 db := em[:psLen+1+sLen]
98 h := em[psLen+1+sLen : emLen-1]
99
100
101
102
103
104
105
106
107
108
109
110
111 var prefix [8]byte
112
113 hash.Reset()
114 hash.Write(prefix[:])
115 hash.Write(mHash)
116 hash.Write(salt)
117
118 h = hash.Sum(h[:0])
119
120
121
122
123
124
125
126 db[psLen] = 0x01
127 copy(db[psLen+1:], salt)
128
129
130
131
132
133 mgf1XOR(db, hash, h)
134
135
136
137
138 db[0] &= 0xff >> (8*emLen - emBits)
139
140
141 em[emLen-1] = 0xbc
142
143
144 return em, nil
145 }
146
// pssSaltLengthAutodetect is a sentinel salt length for emsaPSSVerify. It asks
// the verifier to infer the salt length from the position of the first 0x01
// byte in the unmasked DB instead of checking a fixed length.
const (
	pssSaltLengthAutodetect = -1
)
148
149 func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
150
151
152 hLen := hash.Size()
153 emLen := (emBits + 7) / 8
154 if emLen != len(em) {
155 return errors.New("rsa: internal error: inconsistent length")
156 }
157
158
159
160
161
162
163 if hLen != len(mHash) {
164 return ErrVerification
165 }
166
167
168 if emLen < hLen+sLen+2 {
169 return ErrVerification
170 }
171
172
173
174 if em[emLen-1] != 0xbc {
175 return ErrVerification
176 }
177
178
179
180 db := em[:emLen-hLen-1]
181 h := em[emLen-hLen-1 : emLen-1]
182
183
184
185
186 var bitMask byte = 0xff >> (8*emLen - emBits)
187 if em[0] & ^bitMask != 0 {
188 return ErrVerification
189 }
190
191
192
193
194 mgf1XOR(db, hash, h)
195
196
197
198 db[0] &= bitMask
199
200
201 if sLen == pssSaltLengthAutodetect {
202 psLen := bytes.IndexByte(db, 0x01)
203 if psLen < 0 {
204 return ErrVerification
205 }
206 sLen = len(db) - psLen - 1
207 }
208
209
210
211 if sLen > hLen {
212 fips140.RecordNonApproved()
213 }
214
215
216
217
218
219 psLen := emLen - hLen - sLen - 2
220 for _, e := range db[:psLen] {
221 if e != 0x00 {
222 return ErrVerification
223 }
224 }
225 if db[psLen] != 0x01 {
226 return ErrVerification
227 }
228
229
230 salt := db[len(db)-sLen:]
231
232
233
234
235
236
237
238 hash.Reset()
239 var prefix [8]byte
240 hash.Write(prefix[:])
241 hash.Write(mHash)
242 hash.Write(salt)
243
244 h0 := hash.Sum(nil)
245
246
247 if !bytes.Equal(h0, h) {
248 return ErrVerification
249 }
250 return nil
251 }
252
253
254
255 func PSSMaxSaltLength(pub *PublicKey, hash hash.Hash) (int, error) {
256 saltLength := (pub.N.BitLen()-1+7)/8 - 2 - hash.Size()
257 if saltLength < 0 {
258 return 0, ErrMessageTooLong
259 }
260
261
262 if fips140.Enabled && saltLength > hash.Size() {
263 return hash.Size(), nil
264 }
265 return saltLength, nil
266 }
267
268
269 func SignPSS(rand io.Reader, priv *PrivateKey, hash hash.Hash, hashed []byte, saltLength int) ([]byte, error) {
270 fipsSelfTest()
271 fips140.RecordApproved()
272 checkApprovedHash(hash)
273
274
275
276
277
278
279
280 if saltLength < 0 {
281 return nil, errors.New("crypto/rsa: salt length cannot be negative")
282 }
283
284
285 if saltLength > hash.Size() {
286 fips140.RecordNonApproved()
287 }
288 salt := make([]byte, saltLength)
289 if err := drbg.ReadWithReader(rand, salt); err != nil {
290 return nil, err
291 }
292
293 emBits := priv.pub.N.BitLen() - 1
294 em, err := emsaPSSEncode(hashed, emBits, salt, hash)
295 if err != nil {
296 return nil, err
297 }
298
299
300
301
302
303
304
305
306 if emLen, k := len(em), priv.pub.Size(); emLen < k {
307 emNew := make([]byte, k)
308 copy(emNew[k-emLen:], em)
309 em = emNew
310 }
311
312 return decrypt(priv, em, withCheck)
313 }
314
315
316 func VerifyPSS(pub *PublicKey, hash hash.Hash, digest []byte, sig []byte) error {
317 return verifyPSS(pub, hash, digest, sig, pssSaltLengthAutodetect)
318 }
319
320
321 func VerifyPSSWithSaltLength(pub *PublicKey, hash hash.Hash, digest []byte, sig []byte, saltLength int) error {
322 if saltLength < 0 {
323 return errors.New("crypto/rsa: salt length cannot be negative")
324 }
325 return verifyPSS(pub, hash, digest, sig, saltLength)
326 }
327
328 func verifyPSS(pub *PublicKey, hash hash.Hash, digest []byte, sig []byte, saltLength int) error {
329 fipsSelfTest()
330 fips140.RecordApproved()
331 checkApprovedHash(hash)
332 if fipsApproved, err := checkPublicKey(pub); err != nil {
333 return err
334 } else if !fipsApproved {
335 fips140.RecordNonApproved()
336 }
337
338 if len(sig) != pub.Size() {
339 return ErrVerification
340 }
341
342 emBits := pub.N.BitLen() - 1
343 emLen := (emBits + 7) / 8
344 em, err := encrypt(pub, sig)
345 if err != nil {
346 return ErrVerification
347 }
348
349
350
351
352
353
354 for len(em) > emLen && len(em) > 0 {
355 if em[0] != 0 {
356 return ErrVerification
357 }
358 em = em[1:]
359 }
360
361 return emsaPSSVerify(digest, em, emBits, saltLength, hash)
362 }
363
364 func checkApprovedHash(hash hash.Hash) {
365 switch hash.(type) {
366 case *sha256.Digest, *sha512.Digest, *sha3.Digest:
367 default:
368 fips140.RecordNonApproved()
369 }
370 }
371
372
373 func EncryptOAEP(hash, mgfHash hash.Hash, random io.Reader, pub *PublicKey, msg []byte, label []byte) ([]byte, error) {
374
375
376
377
378
379
380 fipsSelfTest()
381 fips140.RecordApproved()
382 checkApprovedHash(hash)
383 if fipsApproved, err := checkPublicKey(pub); err != nil {
384 return nil, err
385 } else if !fipsApproved {
386 fips140.RecordNonApproved()
387 }
388 k := pub.Size()
389 if len(msg) > k-2*hash.Size()-2 {
390 return nil, ErrMessageTooLong
391 }
392
393 hash.Reset()
394 hash.Write(label)
395 lHash := hash.Sum(nil)
396
397 em := make([]byte, k)
398 seed := em[1 : 1+hash.Size()]
399 db := em[1+hash.Size():]
400
401 copy(db[0:hash.Size()], lHash)
402 db[len(db)-len(msg)-1] = 1
403 copy(db[len(db)-len(msg):], msg)
404
405 if err := drbg.ReadWithReader(random, seed); err != nil {
406 return nil, err
407 }
408
409 mgf1XOR(db, mgfHash, seed)
410 mgf1XOR(seed, mgfHash, db)
411
412 return encrypt(pub, em)
413 }
414
415
// DecryptOAEP decrypts ciphertext with priv and removes RSAES-OAEP padding
// (RFC 8017, Section 7.1.2), returning the embedded message.
//
// hash computes lHash from label and mgfHash is used for MGF1. It returns
// ErrDecryption in these cases:
//   - ciphertext is longer than priv.pub.Size();
//   - the modulus is too short to hold two digests plus two bytes;
//   - any padding check fails.
//
// Errors from decrypt are returned unchanged. The returned slice aliases the
// buffer produced by decrypt.
func DecryptOAEP(hash, mgfHash hash.Hash, priv *PrivateKey, ciphertext []byte, label []byte) ([]byte, error) {
	fipsSelfTest()
	fips140.RecordApproved()
	// Only hash is checked against the approved list; mgfHash is not.
	checkApprovedHash(hash)

	// Reject inputs longer than k bytes, and moduli too short to hold
	// 0x00 || seed || lHash || 0x01.
	k := priv.pub.Size()
	if len(ciphertext) > k ||
		k < hash.Size()*2+2 {
		return nil, ErrDecryption
	}

	// NOTE(review): em[0] below assumes decrypt returns a non-empty buffer;
	// presumably it is always k bytes. Confirm against decrypt.
	em, err := decrypt(priv, ciphertext, noCheck)
	if err != nil {
		return nil, err
	}

	hash.Reset()
	hash.Write(label)
	lHash := hash.Sum(nil)

	// EM = Y || maskedSeed || maskedDB, and Y must be 0x00. The result is
	// kept as a 0/1 flag and checked at the end, not returned early.
	firstByteIsZero := constanttime.ByteEq(em[0], 0)

	seed := em[1 : hash.Size()+1]
	db := em[hash.Size()+1:]

	// seed = maskedSeed xor MGF(maskedDB); DB = maskedDB xor MGF(seed).
	// Both are unmasked in place inside em.
	mgf1XOR(seed, mgfHash, db)
	mgf1XOR(db, mgfHash, seed)

	lHash2 := db[0:hash.Size()]

	// All validity checks from here on are combined into 0/1 flags using
	// constant-time helpers and tested once at the end. This avoids revealing
	// through timing which check failed (cf. Manger's attack on OAEP). Keep
	// this branch-free structure when editing.
	lHash2Good := subtle.ConstantTimeCompare(lHash, lHash2)

	// After lHash, DB must be zero or more 0x00 bytes, then 0x01, then the
	// message.
	//   lookingForIndex: 1 while the 0x01 separator has not yet been seen.
	//   index: offset in rest of the first 0x01 byte.
	//   invalid: 1 if a byte other than 0x00 appeared before the separator.
	var lookingForIndex, index, invalid int
	lookingForIndex = 1
	rest := db[hash.Size():]

	// Every byte of rest is visited regardless of content.
	for i := 0; i < len(rest); i++ {
		equals0 := constanttime.ByteEq(rest[i], 0)
		equals1 := constanttime.ByteEq(rest[i], 1)
		index = constanttime.Select(lookingForIndex&equals1, i, index)
		lookingForIndex = constanttime.Select(equals1, 0, lookingForIndex)
		invalid = constanttime.Select(lookingForIndex&^equals0, 1, invalid)
	}

	// Valid iff Y == 0, lHash matches, no bad padding byte, separator found.
	if firstByteIsZero&lHash2Good&^invalid&^lookingForIndex != 1 {
		return nil, ErrDecryption
	}

	// The message is everything after the 0x01 separator.
	return rest[index+1:], nil
}
475
View as plain text