1
2
3
4
5 package rsa
6
7
8
9
10 import (
11 "bytes"
12 "crypto/internal/fips140"
13 "crypto/internal/fips140/drbg"
14 "crypto/internal/fips140/sha256"
15 "crypto/internal/fips140/sha3"
16 "crypto/internal/fips140/sha512"
17 "crypto/internal/fips140/subtle"
18 "errors"
19 "hash"
20 "io"
21 )
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
// incCounter increments the four-byte big-endian counter c in place,
// propagating the carry from the last byte toward the first and wrapping
// around to all zeros after 0xffffffff.
func incCounter(c *[4]byte) {
	for i := len(c) - 1; i >= 0; i-- {
		c[i]++
		if c[i] != 0 {
			// No carry out of this byte, so the higher bytes stay as they are.
			return
		}
	}
}
49
50
51
// mgf1XOR XORs the MGF1 mask generated from seed into out, in place.
//
// The mask is the concatenation of hash(seed || counter) for a four-byte
// big-endian counter starting at zero, truncated to len(out) bytes. hash is
// reset before each mask block is computed; if out is empty, hash is not
// touched at all.
func mgf1XOR(out []byte, hash hash.Hash, seed []byte) {
	var counter [4]byte
	var block []byte

	for rest := out; len(rest) > 0; {
		hash.Reset()
		hash.Write(seed)
		hash.Write(counter[:])
		block = hash.Sum(block[:0])

		// The final block may only partially overlap what is left of out.
		n := len(block)
		if n > len(rest) {
			n = len(rest)
		}
		for i, b := range block[:n] {
			rest[i] ^= b
		}
		rest = rest[n:]

		incCounter(&counter)
	}
}
70
// emsaPSSEncode performs the EMSA-PSS encoding operation (RFC 8017,
// Section 9.1.1) and returns an encoded message of (emBits+7)/8 bytes.
//
// mHash must already be a digest produced by hash, and salt is embedded
// verbatim. It returns ErrMessageTooLong if the encoded message cannot hold
// the digest, the salt and the two bytes of framing.
func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) {
	hLen, sLen := hash.Size(), len(salt)
	emLen := (emBits + 7) / 8

	if len(mHash) != hLen {
		return nil, errors.New("crypto/rsa: input must be hashed with given hash")
	}
	if emLen < hLen+sLen+2 {
		return nil, ErrMessageTooLong
	}

	// Layout: EM = maskedDB || H || 0xbc, where DB = PS || 0x01 || salt.
	em := make([]byte, emLen)
	dbLen := emLen - hLen - 1
	psLen := dbLen - sLen - 1
	db := em[:dbLen]
	h := em[dbLen : emLen-1]

	// H = Hash(0x00 * 8 || mHash || salt). h[:0] has enough capacity for the
	// digest, so Sum writes it straight into em.
	var zeros [8]byte
	hash.Reset()
	hash.Write(zeros[:])
	hash.Write(mHash)
	hash.Write(salt)
	h = hash.Sum(h[:0])

	// PS is already zero from make; add the 0x01 separator and the salt.
	db[psLen] = 0x01
	copy(db[psLen+1:], salt)

	// maskedDB = DB xor MGF1(H, len(DB)).
	mgf1XOR(db, hash, h)

	// Clear the leftmost 8*emLen-emBits bits so EM fits in emBits bits.
	db[0] &= 0xff >> (8*emLen - emBits)

	em[emLen-1] = 0xbc
	return em, nil
}
145
// pssSaltLengthAutodetect is a sentinel salt length. When it is passed to
// emsaPSSVerify, the salt length is inferred from the position of the 0x01
// separator in DB instead of being required to match a fixed value.
const pssSaltLengthAutodetect = -1
147
// emsaPSSVerify performs the EMSA-PSS verification operation (RFC 8017,
// Section 9.1.2), checking that em is a valid encoding of mHash.
//
// em must be exactly (emBits+7)/8 bytes long and is unmasked in place. sLen
// is the expected salt length in bytes, or pssSaltLengthAutodetect to infer
// it from the position of the 0x01 separator. Any malformed encoding yields
// ErrVerification. Salts longer than the hash are accepted but recorded as
// not FIPS 140-3 approved.
func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
	hLen := hash.Size()
	emLen := (emBits + 7) / 8
	if emLen != len(em) {
		return errors.New("rsa: internal error: inconsistent length")
	}

	// Step 2: mHash must be a digest produced by hash.
	if hLen != len(mHash) {
		return ErrVerification
	}

	// Step 3: EM must hold H, the salt, the 0x01 separator and the 0xbc
	// trailer. When autodetecting, the salt may be empty, but DB must still
	// contain the separator. Checking with sLen == -1 would accept
	// emLen == hLen+1, which leaves DB empty and makes db[0] below panic.
	minSLen := sLen
	if sLen == pssSaltLengthAutodetect {
		minSLen = 0
	}
	if emLen < hLen+minSLen+2 {
		return ErrVerification
	}

	// Step 4: the rightmost byte must be 0xbc.
	if em[emLen-1] != 0xbc {
		return ErrVerification
	}

	// Step 5: EM = maskedDB || H || 0xbc.
	db := em[:emLen-hLen-1]
	h := em[emLen-hLen-1 : emLen-1]

	// Step 6: the leftmost 8*emLen-emBits bits of EM must be zero.
	var bitMask byte = 0xff >> (8*emLen - emBits)
	if em[0] & ^bitMask != 0 {
		return ErrVerification
	}

	// Steps 7-8: DB = maskedDB xor MGF1(H, len(DB)), computed in place.
	mgf1XOR(db, hash, h)

	// Step 9: clear the same high bits in the unmasked DB.
	db[0] &= bitMask

	// Infer the salt length from the first 0x01 byte in DB. Any non-zero
	// byte before it is rejected by the PS check below.
	if sLen == pssSaltLengthAutodetect {
		psLen := bytes.IndexByte(db, 0x01)
		if psLen < 0 {
			return ErrVerification
		}
		sLen = len(db) - psLen - 1
	}

	// A salt longer than the hash is valid PSS but is not FIPS 140-3
	// approved.
	if sLen > hLen {
		fips140.RecordNonApproved()
	}

	// Step 10: DB must be PS (all zeros) || 0x01 || salt.
	psLen := emLen - hLen - sLen - 2
	for _, e := range db[:psLen] {
		if e != 0x00 {
			return ErrVerification
		}
	}
	if db[psLen] != 0x01 {
		return ErrVerification
	}

	// Step 11: the salt is the last sLen bytes of DB.
	salt := db[len(db)-sLen:]

	// Steps 12-13: H' = Hash(0x00 * 8 || mHash || salt).
	hash.Reset()
	var prefix [8]byte
	hash.Write(prefix[:])
	hash.Write(mHash)
	hash.Write(salt)

	h0 := hash.Sum(nil)

	// Step 14: accept only if H' equals H.
	if !bytes.Equal(h0, h) {
		return ErrVerification
	}
	return nil
}
251
252
253
// PSSMaxSaltLength returns the largest salt length, in bytes, that fits in a
// PSS signature made with pub and hash. It returns ErrMessageTooLong if not
// even an empty salt fits. When FIPS 140-3 mode is enabled, the result is
// capped at hash.Size(); SignPSS records larger salts as non-approved.
func PSSMaxSaltLength(pub *PublicKey, hash hash.Hash) (int, error) {
	hLen := hash.Size()
	emLen := (pub.N.BitLen() - 1 + 7) / 8
	maxSalt := emLen - 2 - hLen

	switch {
	case maxSalt < 0:
		return 0, ErrMessageTooLong
	case fips140.Enabled && maxSalt > hLen:
		return hLen, nil
	default:
		return maxSalt, nil
	}
}
266
267
// SignPSS computes an RSASSA-PSS signature over hashed, which must be a
// digest produced by hash. It uses a salt of saltLength bytes read from
// rand via drbg.ReadWithReaderDeterministic.
//
// A negative saltLength is rejected. A saltLength larger than hash.Size()
// is allowed but recorded as not FIPS 140-3 approved.
func SignPSS(rand io.Reader, priv *PrivateKey, hash hash.Hash, hashed []byte, saltLength int) ([]byte, error) {
	fipsSelfTest()
	fips140.RecordApproved()
	checkApprovedHash(hash)

	if saltLength < 0 {
		return nil, errors.New("crypto/rsa: salt length cannot be negative")
	}
	if saltLength > hash.Size() {
		fips140.RecordNonApproved()
	}

	salt := make([]byte, saltLength)
	if err := drbg.ReadWithReaderDeterministic(rand, salt); err != nil {
		return nil, err
	}

	em, err := emsaPSSEncode(hashed, priv.pub.N.BitLen()-1, salt, hash)
	if err != nil {
		return nil, err
	}

	// The encoding covers emBits = BitLen-1 bits, so em can be shorter than
	// priv.pub.Size(). Left-pad it with zeros to that size before applying
	// the private key.
	if k := priv.pub.Size(); len(em) < k {
		padded := make([]byte, k)
		copy(padded[k-len(em):], em)
		em = padded
	}

	return decrypt(priv, em, withCheck)
}
313
314
// VerifyPSS verifies the RSASSA-PSS signature sig over digest, inferring the
// salt length from the signature itself. It returns nil on success.
func VerifyPSS(pub *PublicKey, hash hash.Hash, digest []byte, sig []byte) error {
	saltLength := pssSaltLengthAutodetect
	return verifyPSS(pub, hash, digest, sig, saltLength)
}
318
319
// VerifyPSSWithSaltLength is like VerifyPSS, but it requires the salt to be
// exactly saltLength bytes long. A negative saltLength is rejected.
func VerifyPSSWithSaltLength(pub *PublicKey, hash hash.Hash, digest []byte, sig []byte, saltLength int) error {
	if saltLength >= 0 {
		return verifyPSS(pub, hash, digest, sig, saltLength)
	}
	return errors.New("crypto/rsa: salt length cannot be negative")
}
326
// verifyPSS is the shared implementation of VerifyPSS and
// VerifyPSSWithSaltLength. saltLength is either a fixed, non-negative salt
// length or pssSaltLengthAutodetect.
func verifyPSS(pub *PublicKey, hash hash.Hash, digest []byte, sig []byte, saltLength int) error {
	fipsSelfTest()
	fips140.RecordApproved()
	checkApprovedHash(hash)

	approved, err := checkPublicKey(pub)
	if err != nil {
		return err
	}
	if !approved {
		fips140.RecordNonApproved()
	}

	if len(sig) != pub.Size() {
		return ErrVerification
	}

	emBits := pub.N.BitLen() - 1
	emLen := (emBits + 7) / 8
	em, err := encrypt(pub, sig)
	if err != nil {
		return ErrVerification
	}

	// The public-key operation can return more bytes than the emLen-byte
	// encoding. Every surplus leading byte must be zero.
	if surplus := len(em) - emLen; surplus > 0 {
		for _, b := range em[:surplus] {
			if b != 0 {
				return ErrVerification
			}
		}
		em = em[surplus:]
	}

	return emsaPSSVerify(digest, em, emBits, saltLength, hash)
}
362
// checkApprovedHash records the current operation as not FIPS 140-3
// approved unless hash is one of the module's sha256, sha512 or sha3
// digest implementations.
func checkApprovedHash(hash hash.Hash) {
	switch hash.(type) {
	case *sha256.Digest, *sha512.Digest, *sha3.Digest:
		return
	}
	fips140.RecordNonApproved()
}
370
371
// EncryptOAEP encrypts msg for pub using RSAES-OAEP (RFC 8017, Section
// 7.1.1).
//
// hash digests label and mgfHash drives both MGF1 masks. The seed is read
// from random via drbg.ReadWithReaderDeterministic. It returns
// ErrMessageTooLong if msg is longer than k - 2*hash.Size() - 2 bytes,
// where k is pub.Size().
func EncryptOAEP(hash, mgfHash hash.Hash, random io.Reader, pub *PublicKey, msg []byte, label []byte) ([]byte, error) {
	fipsSelfTest()
	fips140.RecordApproved()
	checkApprovedHash(hash)

	approved, err := checkPublicKey(pub)
	if err != nil {
		return nil, err
	}
	if !approved {
		fips140.RecordNonApproved()
	}

	k := pub.Size()
	hLen := hash.Size()
	if maxLen := k - 2*hLen - 2; len(msg) > maxLen {
		return nil, ErrMessageTooLong
	}

	hash.Reset()
	hash.Write(label)
	lHash := hash.Sum(nil)

	// EM = 0x00 || seed || DB, where DB = lHash || PS || 0x01 || msg and PS
	// is the zero fill left by make.
	em := make([]byte, k)
	seed, db := em[1:1+hLen], em[1+hLen:]
	copy(db[:hLen], lHash)
	sep := len(db) - len(msg) - 1
	db[sep] = 1
	copy(db[sep+1:], msg)

	if err := drbg.ReadWithReaderDeterministic(random, seed); err != nil {
		return nil, err
	}

	// maskedDB = DB xor MGF(seed), then maskedSeed = seed xor MGF(maskedDB).
	mgf1XOR(db, mgfHash, seed)
	mgf1XOR(seed, mgfHash, db)

	return encrypt(pub, em)
}
413
414
// DecryptOAEP decrypts ciphertext with priv using RSAES-OAEP (RFC 8017,
// Section 7.1.2) and returns the embedded message.
//
// hash digests label and mgfHash drives the MGF1 masks. Every padding
// failure returns the same ErrDecryption. The individual padding checks are
// computed with constant-time helpers and combined in one final test. The
// returned message is a sub-slice of the buffer returned by decrypt.
func DecryptOAEP(hash, mgfHash hash.Hash, priv *PrivateKey, ciphertext []byte, label []byte) ([]byte, error) {
	fipsSelfTest()
	fips140.RecordApproved()
	checkApprovedHash(hash)

	// Reject ciphertexts longer than the modulus, and keys too small to hold
	// 0x00 || seed || lHash || 0x01 (2*hLen + 2 bytes).
	k := priv.pub.Size()
	if len(ciphertext) > k ||
		k < hash.Size()*2+2 {
		return nil, ErrDecryption
	}

	em, err := decrypt(priv, ciphertext, noCheck)
	if err != nil {
		return nil, err
	}

	hash.Reset()
	hash.Write(label)
	lHash := hash.Sum(nil)

	// EM = Y || maskedSeed || maskedDB. Y must be zero, but the result is
	// only consulted in the final combined check.
	firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)

	seed := em[1 : hash.Size()+1]
	db := em[hash.Size()+1:]

	// Unmask in place: seed = maskedSeed xor MGF(maskedDB), then
	// DB = maskedDB xor MGF(seed).
	mgf1XOR(seed, mgfHash, db)
	mgf1XOR(db, mgfHash, seed)

	lHash2 := db[0:hash.Size()]

	// Compare the recovered label hash in constant time. The result is
	// again kept as a flag rather than acted on immediately.
	lHash2Good := subtle.ConstantTimeCompare(lHash, lHash2)

	// Scan PS || 0x01 || M without data-dependent branches:
	//   - index records the position of the first 0x01 byte in rest;
	//   - lookingForIndex stays 1 until that byte is seen;
	//   - invalid is set if a byte other than 0x00 or 0x01 appears before it.
	var lookingForIndex, index, invalid int
	lookingForIndex = 1
	rest := db[hash.Size():]

	for i := 0; i < len(rest); i++ {
		equals0 := subtle.ConstantTimeByteEq(rest[i], 0)
		equals1 := subtle.ConstantTimeByteEq(rest[i], 1)
		index = subtle.ConstantTimeSelect(lookingForIndex&equals1, i, index)
		lookingForIndex = subtle.ConstantTimeSelect(equals1, 0, lookingForIndex)
		invalid = subtle.ConstantTimeSelect(lookingForIndex&^equals0, 1, invalid)
	}

	// Succeed only if Y was zero, the label hash matched, PS contained only
	// zeros, and a 0x01 separator was found.
	if firstByteIsZero&lHash2Good&^invalid&^lookingForIndex != 1 {
		return nil, ErrDecryption
	}

	// The message is everything after the 0x01 separator.
	return rest[index+1:], nil
}
474
View as plain text