1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26 package rsa
27
28 import (
29 "crypto"
30 "crypto/internal/bigmod"
31 "crypto/internal/boring"
32 "crypto/internal/boring/bbig"
33 "crypto/internal/randutil"
34 "crypto/rand"
35 "crypto/subtle"
36 "errors"
37 "hash"
38 "io"
39 "math"
40 "math/big"
41 )
42
// bigOne is the constant 1. It is only ever read (as an operand or a source
// for Set) by the big.Int arithmetic in this file; never use it as a
// destination.
var bigOne = new(big.Int).SetInt64(1)
44
45
46
47
48
49
// A PublicKey represents the public part of an RSA key.
type PublicKey struct {
	N *big.Int // modulus
	E int      // public exponent; checkPub requires 2 <= E <= 2^31-1
}
54
55
56
57
58
59
60 func (pub *PublicKey) Size() int {
61 return (pub.N.BitLen() + 7) / 8
62 }
63
64
65 func (pub *PublicKey) Equal(x crypto.PublicKey) bool {
66 xx, ok := x.(*PublicKey)
67 if !ok {
68 return false
69 }
70 return bigIntEqual(pub.N, xx.N) && pub.E == xx.E
71 }
72
73
74
// OAEPOptions selects OAEP decryption when passed as the opts argument of
// PrivateKey.Decrypt.
type OAEPOptions struct {
	// Hash is the hash function used to hash Label and, unless MGFHash is
	// set, also as the MGF1 hash function.
	Hash crypto.Hash

	// MGFHash is the hash function used for MGF1.
	// If zero, Hash is used instead.
	MGFHash crypto.Hash

	// Label is an arbitrary byte string that must be equal to the value
	// used during encryption; a mismatch makes decryption fail.
	Label []byte
}
87
// Errors returned by checkPub for malformed public keys.
var (
	// errPublicModulus: the key has no modulus (N is nil).
	errPublicModulus = errors.New("crypto/rsa: missing public modulus")
	// errPublicExponentSmall: E < 2.
	errPublicExponentSmall = errors.New("crypto/rsa: public exponent too small")
	// errPublicExponentLarge: E > 2^31-1.
	errPublicExponentLarge = errors.New("crypto/rsa: public exponent too large")
)
93
94
95
96
97
98
99 func checkPub(pub *PublicKey) error {
100 if pub.N == nil {
101 return errPublicModulus
102 }
103 if pub.E < 2 {
104 return errPublicExponentSmall
105 }
106 if pub.E > 1<<31-1 {
107 return errPublicExponentLarge
108 }
109 return nil
110 }
111
112
// A PrivateKey represents an RSA key.
type PrivateKey struct {
	PublicKey            // public part
	D         *big.Int   // private exponent
	Primes    []*big.Int // prime factors of N; Validate requires their product to equal N

	// Precomputed contains values that speed up private-key operations.
	// It is filled in by Precompute, and is not examined by Validate.
	Precomputed PrecomputedValues
}
123
124
125 func (priv *PrivateKey) Public() crypto.PublicKey {
126 return &priv.PublicKey
127 }
128
129
130
131 func (priv *PrivateKey) Equal(x crypto.PrivateKey) bool {
132 xx, ok := x.(*PrivateKey)
133 if !ok {
134 return false
135 }
136 if !priv.PublicKey.Equal(&xx.PublicKey) || !bigIntEqual(priv.D, xx.D) {
137 return false
138 }
139 if len(priv.Primes) != len(xx.Primes) {
140 return false
141 }
142 for i := range priv.Primes {
143 if !bigIntEqual(priv.Primes[i], xx.Primes[i]) {
144 return false
145 }
146 }
147 return true
148 }
149
150
151
// bigIntEqual reports whether a and b have the same absolute value. It uses
// subtle.ConstantTimeCompare, so timing depends only on the byte lengths of
// the values and not on their contents.
//
// A nil *big.Int is equal only to another nil and unequal to every non-nil
// value, including zero. Comparing keys with unset fields therefore returns
// a result instead of panicking inside (*big.Int).Bytes.
func bigIntEqual(a, b *big.Int) bool {
	if a == nil || b == nil {
		// Nil-ness is structural, not secret, so an early return is fine.
		return a == b
	}
	return subtle.ConstantTimeCompare(a.Bytes(), b.Bytes()) == 1
}
155
156
157
158
159
160
161
162
163
164 func (priv *PrivateKey) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
165 if pssOpts, ok := opts.(*PSSOptions); ok {
166 return SignPSS(rand, priv, pssOpts.Hash, digest, pssOpts)
167 }
168
169 return SignPKCS1v15(rand, priv, opts.HashFunc(), digest)
170 }
171
172
173
174
175 func (priv *PrivateKey) Decrypt(rand io.Reader, ciphertext []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) {
176 if opts == nil {
177 return DecryptPKCS1v15(rand, priv, ciphertext)
178 }
179
180 switch opts := opts.(type) {
181 case *OAEPOptions:
182 if opts.MGFHash == 0 {
183 return decryptOAEP(opts.Hash.New(), opts.Hash.New(), rand, priv, ciphertext, opts.Label)
184 } else {
185 return decryptOAEP(opts.Hash.New(), opts.MGFHash.New(), rand, priv, ciphertext, opts.Label)
186 }
187
188 case *PKCS1v15DecryptOptions:
189 if l := opts.SessionKeyLen; l > 0 {
190 plaintext = make([]byte, l)
191 if _, err := io.ReadFull(rand, plaintext); err != nil {
192 return nil, err
193 }
194 if err := DecryptPKCS1v15SessionKey(rand, priv, ciphertext, plaintext); err != nil {
195 return nil, err
196 }
197 return plaintext, nil
198 } else {
199 return DecryptPKCS1v15(rand, priv, ciphertext)
200 }
201
202 default:
203 return nil, errors.New("crypto/rsa: invalid options for Decrypt")
204 }
205 }
206
// PrecomputedValues holds values derived from a PrivateKey's primes that
// speed up private-key operations. They are filled in by Precompute or, on
// the boring key-generation path, directly by GenerateMultiPrimeKey.
type PrecomputedValues struct {
	Dp, Dq *big.Int // D mod (Primes[0]-1) and D mod (Primes[1]-1)
	Qinv   *big.Int // Primes[1]^-1 mod Primes[0]

	// CRTValues holds one entry per prime after the first two (Primes[2:]);
	// it is empty for two-prime keys.
	// NOTE(review): decrypt in this file never reads it; only the two-prime
	// CRT path (Dp, Dq, Qinv, n, p, q) is used.
	CRTValues []CRTValue

	// n, p and q are N, Primes[0] and Primes[1] as bigmod moduli for
	// decrypt's CRT path. They are set together, and only for two-prime
	// keys. When n is nil, decrypt falls back to exponentiation by D.
	n, p, q *bigmod.Modulus
}
224
225
// CRTValue contains the precomputed Chinese remainder theorem values for one
// of the third and subsequent primes (see Precompute).
type CRTValue struct {
	Exp   *big.Int // D mod (prime-1)
	Coeff *big.Int // R^-1 mod prime, i.e. R·Coeff ≡ 1 (mod prime)
	R     *big.Int // product of all preceding primes (including Primes[0] and Primes[1])
}
231
232
233
234 func (priv *PrivateKey) Validate() error {
235 if err := checkPub(&priv.PublicKey); err != nil {
236 return err
237 }
238
239
240 modulus := new(big.Int).Set(bigOne)
241 for _, prime := range priv.Primes {
242
243 if prime.Cmp(bigOne) <= 0 {
244 return errors.New("crypto/rsa: invalid prime value")
245 }
246 modulus.Mul(modulus, prime)
247 }
248 if modulus.Cmp(priv.N) != 0 {
249 return errors.New("crypto/rsa: invalid modulus")
250 }
251
252
253
254
255
256
257 congruence := new(big.Int)
258 de := new(big.Int).SetInt64(int64(priv.E))
259 de.Mul(de, priv.D)
260 for _, prime := range priv.Primes {
261 pminus1 := new(big.Int).Sub(prime, bigOne)
262 congruence.Mod(de, pminus1)
263 if congruence.Cmp(bigOne) != 0 {
264 return errors.New("crypto/rsa: invalid exponents")
265 }
266 }
267 return nil
268 }
269
270
271
272
273
274
275 func GenerateKey(random io.Reader, bits int) (*PrivateKey, error) {
276 return GenerateMultiPrimeKey(random, 2, bits)
277 }
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
// GenerateMultiPrimeKey generates a multi-prime RSA keypair of the given bit
// size and number of primes, using random as the source of randomness.
//
// nprimes must be at least 2. The generic path always uses the public
// exponent 65537, retries until it finds a suitable set of primes, and
// calls Precompute on the result. For bits < 64 it first estimates how many
// primes of the required size exist. If there are too few for generation to
// be practical, it returns an error.
//
// When boring.Enabled is set, random is boring.RandReader, nprimes is 2 and
// bits is 2048, 3072 or 4096, generation is delegated to
// boring.GenerateKeyRSA. The returned values (including Dp, Dq and Qinv) are
// converted into a PrivateKey whose bigmod moduli are already populated.
func GenerateMultiPrimeKey(random io.Reader, nprimes int, bits int) (*PrivateKey, error) {
	// NOTE(review): randutil.MaybeReadByte is opaque from here. Presumably it
	// randomly consumes a byte so callers cannot depend on deterministic
	// output for a given random stream; confirm in crypto/internal/randutil.
	randutil.MaybeReadByte(random)

	if boring.Enabled && random == boring.RandReader && nprimes == 2 &&
		(bits == 2048 || bits == 3072 || bits == 4096) {
		bN, bE, bD, bP, bQ, bDp, bDq, bQinv, err := boring.GenerateKeyRSA(bits)
		if err != nil {
			return nil, err
		}
		N := bbig.Dec(bN)
		E := bbig.Dec(bE)
		D := bbig.Dec(bD)
		P := bbig.Dec(bP)
		Q := bbig.Dec(bQ)
		Dp := bbig.Dec(bDp)
		Dq := bbig.Dec(bDq)
		Qinv := bbig.Dec(bQinv)
		// PublicKey.E is an int, so reject exponents that do not fit in one.
		e64 := E.Int64()
		if !E.IsInt64() || int64(int(e64)) != e64 {
			return nil, errors.New("crypto/rsa: generated key exponent too large")
		}

		// Build the moduli used by decrypt's CRT path, as Precompute would.
		mn, err := bigmod.NewModulusFromBig(N)
		if err != nil {
			return nil, err
		}
		mp, err := bigmod.NewModulusFromBig(P)
		if err != nil {
			return nil, err
		}
		mq, err := bigmod.NewModulusFromBig(Q)
		if err != nil {
			return nil, err
		}

		key := &PrivateKey{
			PublicKey: PublicKey{
				N: N,
				E: int(e64),
			},
			D:      D,
			Primes: []*big.Int{P, Q},
			Precomputed: PrecomputedValues{
				Dp:        Dp,
				Dq:        Dq,
				Qinv:      Qinv,
				CRTValues: make([]CRTValue, 0), // non-nil and empty, matching what Precompute produces for two primes
				n:         mn,
				p:         mp,
				q:         mq,
			},
		}
		return key, nil
	}

	priv := new(PrivateKey)
	priv.E = 65537

	if nprimes < 2 {
		return nil, errors.New("crypto/rsa: GenerateMultiPrimeKey: nprimes must be >= 2")
	}

	if bits < 64 {
		primeLimit := float64(uint64(1) << uint(bits/nprimes))
		// pi approximates the number of primes below primeLimit, using the
		// x / (ln x - 1) estimate of the prime-counting function.
		pi := primeLimit / (math.Log(primeLimit) - 1)
		// NOTE(review): presumably only about a quarter of those primes can
		// be produced, because generated primes have their top two bits set.
		// Confirm against crypto/rand.Prime.
		pi /= 4
		// A further factor of two of headroom, presumably so that the retry
		// loop below finishes in a reasonable amount of time.
		pi /= 2
		if pi <= float64(nprimes) {
			return nil, errors.New("crypto/rsa: too few primes of given length to generate an RSA key")
		}
	}

	primes := make([]*big.Int, nprimes)

	// Draw candidate prime sets until one yields pairwise-distinct primes, a
	// modulus of exactly bits bits, and an e that is invertible modulo the
	// totient.
NextSetOfPrimes:
	for {
		todo := bits
		// With many primes, the product of primes of about todo/nprimes bits
		// each presumably tends to come out shorter than bits (rejected by the
		// BitLen check below), so ask for slightly more bits in total.
		if nprimes >= 7 {
			todo += (nprimes - 2) / 5
		}
		// Spread the remaining bit budget evenly over the primes still to be
		// drawn, subtracting each prime's actual length as we go.
		for i := 0; i < nprimes; i++ {
			var err error
			primes[i], err = rand.Prime(random, todo/(nprimes-i))
			if err != nil {
				return nil, err
			}
			todo -= primes[i].BitLen()
		}

		// Make sure that the primes are pairwise unequal.
		for i, prime := range primes {
			for j := 0; j < i; j++ {
				if prime.Cmp(primes[j]) == 0 {
					continue NextSetOfPrimes
				}
			}
		}

		// n = Π primes, totient = Π (prime-1).
		n := new(big.Int).Set(bigOne)
		totient := new(big.Int).Set(bigOne)
		pminus1 := new(big.Int)
		for _, prime := range primes {
			n.Mul(n, prime)
			pminus1.Sub(prime, bigOne)
			totient.Mul(totient, pminus1)
		}
		if n.BitLen() != bits {
			// The modulus came out the wrong length; draw a fresh set.
			continue NextSetOfPrimes
		}

		// d = e^-1 mod totient. ModInverse returns nil when e and the totient
		// are not coprime; in that case the loop draws a fresh set of primes.
		priv.D = new(big.Int)
		e := big.NewInt(int64(priv.E))
		ok := priv.D.ModInverse(e, totient)

		if ok != nil {
			priv.Primes = primes
			priv.N = n
			break
		}
	}

	priv.Precompute()
	return priv, nil
}
441
442
// incCounter increments the four-byte big-endian counter c by one, in place.
// c[3] is the least significant byte. 0xffffffff wraps around to zero.
func incCounter(c *[4]byte) {
	for i := len(c) - 1; i >= 0; i-- {
		c[i]++
		if c[i] != 0 {
			return // no carry into the next byte
		}
	}
}
455
456
457
458 func mgf1XOR(out []byte, hash hash.Hash, seed []byte) {
459 var counter [4]byte
460 var digest []byte
461
462 done := 0
463 for done < len(out) {
464 hash.Write(seed)
465 hash.Write(counter[0:4])
466 digest = hash.Sum(digest[:0])
467 hash.Reset()
468
469 for i := 0; i < len(digest) && done < len(out); i++ {
470 out[done] ^= digest[i]
471 done++
472 }
473 incCounter(&counter)
474 }
475 }
476
477
478
479
// ErrMessageTooLong is returned when attempting to encrypt a message that is
// too large for the size of the public key. In this part of the file it is
// returned by EncryptOAEP. NOTE(review): other code in the package may also
// return it.
var ErrMessageTooLong = errors.New("crypto/rsa: message too long for RSA key size")
481
482 func encrypt(pub *PublicKey, plaintext []byte) ([]byte, error) {
483 boring.Unreachable()
484
485 N, err := bigmod.NewModulusFromBig(pub.N)
486 if err != nil {
487 return nil, err
488 }
489 m, err := bigmod.NewNat().SetBytes(plaintext, N)
490 if err != nil {
491 return nil, err
492 }
493 e := uint(pub.E)
494
495 return bigmod.NewNat().ExpShortVarTime(m, e, N).Bytes(N), nil
496 }
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
// EncryptOAEP encrypts the given message with RSA-OAEP.
//
// hash is used both to hash label and as the MGF1 hash function. It is Reset
// before use. random supplies the OAEP seed. label is hashed into the
// padding but not encrypted, and decryption must supply the same label.
//
// msg must be no longer than pub.Size() - 2*hash.Size() - 2 bytes; otherwise
// ErrMessageTooLong is returned. An invalid public key is reported with the
// checkPub error, and a failed read from random is returned as-is.
func EncryptOAEP(hash hash.Hash, random io.Reader, pub *PublicKey, msg []byte, label []byte) ([]byte, error) {
	// Note: unlike GenerateMultiPrimeKey, no randutil.MaybeReadByte call is
	// made here. On the generic path exactly hash.Size() bytes are read from
	// random and used as the seed, and callers may have come to rely on that.
	// NOTE(review): keep the amount of randomness consumed stable.

	if err := checkPub(pub); err != nil {
		return nil, err
	}
	hash.Reset()
	k := pub.Size()
	if len(msg) > k-2*hash.Size()-2 {
		return nil, ErrMessageTooLong
	}

	// boring performs the whole OAEP encryption itself, using hash for both
	// the label and MGF1.
	if boring.Enabled && random == boring.RandReader {
		bkey, err := boringPublicKey(pub)
		if err != nil {
			return nil, err
		}
		return boring.EncryptRSAOAEP(hash, hash, bkey, msg, label)
	}
	boring.UnreachableExceptTests()

	hash.Write(label)
	lHash := hash.Sum(nil)
	hash.Reset()

	// EM = 0x00 || seed || DB, where DB = lHash || 0x00...0x00 || 0x01 || msg.
	// em starts zeroed, so the leading byte and the zero padding are implicit.
	em := make([]byte, k)
	seed := em[1 : 1+hash.Size()]
	db := em[1+hash.Size():]

	copy(db[0:hash.Size()], lHash)
	db[len(db)-len(msg)-1] = 1
	copy(db[len(db)-len(msg):], msg)

	_, err := io.ReadFull(random, seed)
	if err != nil {
		return nil, err
	}

	// maskedDB = DB xor MGF1(seed), then maskedSeed = seed xor MGF1(maskedDB).
	mgf1XOR(db, hash, seed)
	mgf1XOR(seed, hash, db)

	// With boring enabled but a different random reader, the padding is built
	// here and only the raw RSA operation is delegated.
	if boring.Enabled {
		var bkey *boring.PublicKeyRSA
		bkey, err = boringPublicKey(pub)
		if err != nil {
			return nil, err
		}
		return boring.EncryptRSANoPadding(bkey, em)
	}

	return encrypt(pub, em)
}
572
573
574
// ErrDecryption represents a failure to decrypt a message. It is deliberately
// vague. decrypt returns it for every failure, and decryptOAEP returns it for
// every decryption or padding failure, without saying which check failed.
var ErrDecryption = errors.New("crypto/rsa: decryption error")

// ErrVerification represents a failure to verify a signature. It is
// deliberately vague. NOTE(review): it is not referenced in this part of the
// file; presumably it is returned by signature verification elsewhere in the
// package.
var ErrVerification = errors.New("crypto/rsa: verification error")
580
581
582
// Precompute performs some calculations that speed up private key operations
// in the future.
//
// It works in two phases:
//
//  1. For two-prime keys without moduli yet, it builds the bigmod moduli for
//     N, Primes[0] and Primes[1] used by decrypt's CRT path. This is
//     all-or-nothing: if any modulus cannot be built, the ones already
//     assigned are cleared and Precompute returns immediately.
//  2. Unless Dp is already set, it fills in the exported *big.Int fields:
//     Dp, Dq, Qinv, and CRTValues (empty for two-prime keys).
//
// NOTE(review): when Dp is unset, phase 2 indexes Primes[0] and Primes[1]
// unconditionally, so keys with fewer than two primes panic here.
func (priv *PrivateKey) Precompute() {
	if priv.Precomputed.n == nil && len(priv.Primes) == 2 {
		// Precomputed values should always be valid. If building them fails,
		// just return rather than panicking.
		var err error
		priv.Precomputed.n, err = bigmod.NewModulusFromBig(priv.N)
		if err != nil {
			return
		}
		priv.Precomputed.p, err = bigmod.NewModulusFromBig(priv.Primes[0])
		if err != nil {
			// Unset previous values, so we either have everything or nothing.
			priv.Precomputed.n = nil
			return
		}
		priv.Precomputed.q, err = bigmod.NewModulusFromBig(priv.Primes[1])
		if err != nil {
			// Unset previous values, so we either have everything or nothing.
			priv.Precomputed.n, priv.Precomputed.p = nil, nil
			return
		}
	}

	// The *big.Int values are only filled in if not already present.
	if priv.Precomputed.Dp != nil {
		return
	}

	// Dp = D mod (P-1)
	priv.Precomputed.Dp = new(big.Int).Sub(priv.Primes[0], bigOne)
	priv.Precomputed.Dp.Mod(priv.D, priv.Precomputed.Dp)

	// Dq = D mod (Q-1)
	priv.Precomputed.Dq = new(big.Int).Sub(priv.Primes[1], bigOne)
	priv.Precomputed.Dq.Mod(priv.D, priv.Precomputed.Dq)

	// Qinv = Q^-1 mod P. NOTE(review): ModInverse returns nil when no inverse
	// exists (e.g. P == Q), which would leave Qinv nil.
	priv.Precomputed.Qinv = new(big.Int).ModInverse(priv.Primes[1], priv.Primes[0])

	// Each further prime gets Exp = D mod (prime-1), R = product of all
	// preceding primes, and Coeff = R^-1 mod prime.
	r := new(big.Int).Mul(priv.Primes[0], priv.Primes[1])
	priv.Precomputed.CRTValues = make([]CRTValue, len(priv.Primes)-2)
	for i := 2; i < len(priv.Primes); i++ {
		prime := priv.Primes[i]
		values := &priv.Precomputed.CRTValues[i-2]

		values.Exp = new(big.Int).Sub(prime, bigOne)
		values.Exp.Mod(priv.D, values.Exp)

		values.R = new(big.Int).Set(r)
		values.Coeff = new(big.Int).ModInverse(r, prime)

		r.Mul(r, prime)
	}
}
634
// Values for decrypt's check argument. withCheck re-encrypts the result and
// compares it with the input ciphertext; noCheck skips that verification.
const (
	withCheck = true
	noCheck   = false
)
637
638
639
640
// decrypt performs the raw RSA private-key operation m = c^D mod N on
// ciphertext and returns m encoded with m.Bytes(N). No padding is removed.
//
// If the bigmod moduli have not been precomputed (Precomputed.n == nil), D
// is used directly. Otherwise the two-prime CRT values Dp, Dq and Qinv are
// used. If check is true (withCheck), m is re-encrypted with the public
// exponent and compared against c, to detect a faulty private computation.
//
// Every failure, including ciphertext rejected by SetBytes for N, is
// reported as ErrDecryption.
func decrypt(priv *PrivateKey, ciphertext []byte, check bool) ([]byte, error) {
	// NOTE(review): presumably, with boring enabled, keys of at most two
	// primes are always handled by boring before reaching this code.
	if len(priv.Primes) <= 2 {
		boring.Unreachable()
	}

	var (
		err  error
		m, c *bigmod.Nat
		N    *bigmod.Modulus
		t0   = bigmod.NewNat() // scratch value for reductions
	)
	if priv.Precomputed.n == nil {
		N, err = bigmod.NewModulusFromBig(priv.N)
		if err != nil {
			return nil, ErrDecryption
		}
		c, err = bigmod.NewNat().SetBytes(ciphertext, N)
		if err != nil {
			return nil, ErrDecryption
		}
		m = bigmod.NewNat().Exp(c, priv.D.Bytes(), N)
	} else {
		N = priv.Precomputed.n
		P, Q := priv.Precomputed.p, priv.Precomputed.q
		Qinv, err := bigmod.NewNat().SetBytes(priv.Precomputed.Qinv.Bytes(), P)
		if err != nil {
			return nil, ErrDecryption
		}
		c, err = bigmod.NewNat().SetBytes(ciphertext, N)
		if err != nil {
			return nil, ErrDecryption
		}

		// CRT recombination: m = ((m1 - m2) * Qinv mod p) * q + m2.
		// m = c ^ Dp mod p
		m = bigmod.NewNat().Exp(t0.Mod(c, P), priv.Precomputed.Dp.Bytes(), P)
		// m2 = c ^ Dq mod q
		m2 := bigmod.NewNat().Exp(t0.Mod(c, Q), priv.Precomputed.Dq.Bytes(), Q)
		// m = m - m2 mod p
		m.Sub(t0.Mod(m2, P), P)
		// m = m * Qinv mod p
		m.Mul(Qinv, P)
		// m = m * q mod N
		m.ExpandFor(N).Mul(t0.Mod(Q.Nat(), N), N)
		// m = m + m2 mod N
		m.Add(m2.ExpandFor(N), N)
	}

	if check {
		// Re-encrypt with the public exponent and require the original c back.
		c1 := bigmod.NewNat().ExpShortVarTime(m, uint(priv.E), N)
		if c1.Equal(c) != 1 {
			return nil, ErrDecryption
		}
	}

	return m.Bytes(N), nil
}
697
698
699
700
701
702
703
704
705
706
707
708 func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext []byte, label []byte) ([]byte, error) {
709 return decryptOAEP(hash, hash, random, priv, ciphertext, label)
710 }
711
712 func decryptOAEP(hash, mgfHash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext []byte, label []byte) ([]byte, error) {
713 if err := checkPub(&priv.PublicKey); err != nil {
714 return nil, err
715 }
716 k := priv.Size()
717 if len(ciphertext) > k ||
718 k < hash.Size()*2+2 {
719 return nil, ErrDecryption
720 }
721
722 if boring.Enabled {
723 bkey, err := boringPrivateKey(priv)
724 if err != nil {
725 return nil, err
726 }
727 out, err := boring.DecryptRSAOAEP(hash, mgfHash, bkey, ciphertext, label)
728 if err != nil {
729 return nil, ErrDecryption
730 }
731 return out, nil
732 }
733
734 em, err := decrypt(priv, ciphertext, noCheck)
735 if err != nil {
736 return nil, err
737 }
738
739 hash.Write(label)
740 lHash := hash.Sum(nil)
741 hash.Reset()
742
743 firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)
744
745 seed := em[1 : hash.Size()+1]
746 db := em[hash.Size()+1:]
747
748 mgf1XOR(seed, mgfHash, db)
749 mgf1XOR(db, mgfHash, seed)
750
751 lHash2 := db[0:hash.Size()]
752
753
754
755
756
757 lHash2Good := subtle.ConstantTimeCompare(lHash, lHash2)
758
759
760
761
762
763
764 var lookingForIndex, index, invalid int
765 lookingForIndex = 1
766 rest := db[hash.Size():]
767
768 for i := 0; i < len(rest); i++ {
769 equals0 := subtle.ConstantTimeByteEq(rest[i], 0)
770 equals1 := subtle.ConstantTimeByteEq(rest[i], 1)
771 index = subtle.ConstantTimeSelect(lookingForIndex&equals1, i, index)
772 lookingForIndex = subtle.ConstantTimeSelect(equals1, 0, lookingForIndex)
773 invalid = subtle.ConstantTimeSelect(lookingForIndex&^equals0, 1, invalid)
774 }
775
776 if firstByteIsZero&lHash2Good&^invalid&^lookingForIndex != 1 {
777 return nil, ErrDecryption
778 }
779
780 return rest[index+1:], nil
781 }
782
View as plain text