1 // Copyright 2013 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package rsa 6 7 // This file implements the RSASSA-PSS signature scheme according to RFC 8017. 8 9 import ( 10 "bytes" 11 "crypto" 12 "crypto/internal/boring" 13 "errors" 14 "hash" 15 "io" 16 ) 17 18 // Per RFC 8017, Section 9.1 19 // 20 // EM = MGF1 xor DB || H( 8*0x00 || mHash || salt ) || 0xbc 21 // 22 // where 23 // 24 // DB = PS || 0x01 || salt 25 // 26 // and PS can be empty so 27 // 28 // emLen = dbLen + hLen + 1 = psLen + sLen + hLen + 2 29 // 30 31 func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) { 32 // See RFC 8017, Section 9.1.1. 33 34 hLen := hash.Size() 35 sLen := len(salt) 36 emLen := (emBits + 7) / 8 37 38 // 1. If the length of M is greater than the input limitation for the 39 // hash function (2^61 - 1 octets for SHA-1), output "message too 40 // long" and stop. 41 // 42 // 2. Let mHash = Hash(M), an octet string of length hLen. 43 44 if len(mHash) != hLen { 45 return nil, errors.New("crypto/rsa: input must be hashed with given hash") 46 } 47 48 // 3. If emLen < hLen + sLen + 2, output "encoding error" and stop. 49 50 if emLen < hLen+sLen+2 { 51 return nil, ErrMessageTooLong 52 } 53 54 em := make([]byte, emLen) 55 psLen := emLen - sLen - hLen - 2 56 db := em[:psLen+1+sLen] 57 h := em[psLen+1+sLen : emLen-1] 58 59 // 4. Generate a random octet string salt of length sLen; if sLen = 0, 60 // then salt is the empty string. 61 // 62 // 5. Let 63 // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt; 64 // 65 // M' is an octet string of length 8 + hLen + sLen with eight 66 // initial zero octets. 67 // 68 // 6. Let H = Hash(M'), an octet string of length hLen. 69 70 var prefix [8]byte 71 72 hash.Write(prefix[:]) 73 hash.Write(mHash) 74 hash.Write(salt) 75 76 h = hash.Sum(h[:0]) 77 hash.Reset() 78 79 // 7. Generate an octet string PS consisting of emLen - sLen - hLen - 2 80 // zero octets. The length of PS may be 0. 81 // 82 // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length 83 // emLen - hLen - 1. 84 85 db[psLen] = 0x01 86 copy(db[psLen+1:], salt) 87 88 // 9. Let dbMask = MGF(H, emLen - hLen - 1). 89 // 90 // 10. Let maskedDB = DB \xor dbMask. 91 92 mgf1XOR(db, hash, h) 93 94 // 11. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in 95 // maskedDB to zero. 96 97 db[0] &= 0xff >> (8*emLen - emBits) 98 99 // 12. Let EM = maskedDB || H || 0xbc. 100 em[emLen-1] = 0xbc 101 102 // 13. Output EM. 103 return em, nil 104 } 105 106 func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error { 107 // See RFC 8017, Section 9.1.2. 108 109 hLen := hash.Size() 110 if sLen == PSSSaltLengthEqualsHash { 111 sLen = hLen 112 } 113 emLen := (emBits + 7) / 8 114 if emLen != len(em) { 115 return errors.New("rsa: internal error: inconsistent length") 116 } 117 118 // 1. If the length of M is greater than the input limitation for the 119 // hash function (2^61 - 1 octets for SHA-1), output "inconsistent" 120 // and stop. 121 // 122 // 2. Let mHash = Hash(M), an octet string of length hLen. 123 if hLen != len(mHash) { 124 return ErrVerification 125 } 126 127 // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. 128 if emLen < hLen+sLen+2 { 129 return ErrVerification 130 } 131 132 // 4. If the rightmost octet of EM does not have hexadecimal value 133 // 0xbc, output "inconsistent" and stop. 134 if em[emLen-1] != 0xbc { 135 return ErrVerification 136 } 137 138 // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and 139 // let H be the next hLen octets. 140 db := em[:emLen-hLen-1] 141 h := em[emLen-hLen-1 : emLen-1] 142 143 // 6. If the leftmost 8 * emLen - emBits bits of the leftmost octet in 144 // maskedDB are not all equal to zero, output "inconsistent" and 145 // stop. 146 var bitMask byte = 0xff >> (8*emLen - emBits) 147 if em[0] & ^bitMask != 0 { 148 return ErrVerification 149 } 150 151 // 7. Let dbMask = MGF(H, emLen - hLen - 1). 152 // 153 // 8. Let DB = maskedDB \xor dbMask. 154 mgf1XOR(db, hash, h) 155 156 // 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB 157 // to zero. 158 db[0] &= bitMask 159 160 // If we don't know the salt length, look for the 0x01 delimiter. 161 if sLen == PSSSaltLengthAuto { 162 psLen := bytes.IndexByte(db, 0x01) 163 if psLen < 0 { 164 return ErrVerification 165 } 166 sLen = len(db) - psLen - 1 167 } 168 169 // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero 170 // or if the octet at position emLen - hLen - sLen - 1 (the leftmost 171 // position is "position 1") does not have hexadecimal value 0x01, 172 // output "inconsistent" and stop. 173 psLen := emLen - hLen - sLen - 2 174 for _, e := range db[:psLen] { 175 if e != 0x00 { 176 return ErrVerification 177 } 178 } 179 if db[psLen] != 0x01 { 180 return ErrVerification 181 } 182 183 // 11. Let salt be the last sLen octets of DB. 184 salt := db[len(db)-sLen:] 185 186 // 12. Let 187 // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; 188 // M' is an octet string of length 8 + hLen + sLen with eight 189 // initial zero octets. 190 // 191 // 13. Let H' = Hash(M'), an octet string of length hLen. 192 var prefix [8]byte 193 hash.Write(prefix[:]) 194 hash.Write(mHash) 195 hash.Write(salt) 196 197 h0 := hash.Sum(nil) 198 199 // 14. If H = H', output "consistent." Otherwise, output "inconsistent." 200 if !bytes.Equal(h0, h) { // TODO: constant time? 201 return ErrVerification 202 } 203 return nil 204 } 205 206 // signPSSWithSalt calculates the signature of hashed using PSS with specified salt. 207 // Note that hashed must be the result of hashing the input message using the 208 // given hash function. salt is a random sequence of bytes whose length will be 209 // later used to verify the signature. 210 func signPSSWithSalt(priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) { 211 emBits := priv.N.BitLen() - 1 212 em, err := emsaPSSEncode(hashed, emBits, salt, hash.New()) 213 if err != nil { 214 return nil, err 215 } 216 217 if boring.Enabled { 218 bkey, err := boringPrivateKey(priv) 219 if err != nil { 220 return nil, err 221 } 222 // Note: BoringCrypto always does decrypt "withCheck". 223 // (It's not just decrypt.) 224 s, err := boring.DecryptRSANoPadding(bkey, em) 225 if err != nil { 226 return nil, err 227 } 228 return s, nil 229 } 230 231 // RFC 8017: "Note that the octet length of EM will be one less than k if 232 // modBits - 1 is divisible by 8 and equal to k otherwise, where k is the 233 // length in octets of the RSA modulus n." 🙄 234 // 235 // This is extremely annoying, as all other encrypt and decrypt inputs are 236 // always the exact same size as the modulus. Since it only happens for 237 // weird modulus sizes, fix it by padding inefficiently. 238 if emLen, k := len(em), priv.Size(); emLen < k { 239 emNew := make([]byte, k) 240 copy(emNew[k-emLen:], em) 241 em = emNew 242 } 243 244 return decrypt(priv, em, withCheck) 245 } 246 247 const ( 248 // PSSSaltLengthAuto causes the salt in a PSS signature to be as large 249 // as possible when signing, and to be auto-detected when verifying. 250 PSSSaltLengthAuto = 0 251 // PSSSaltLengthEqualsHash causes the salt length to equal the length 252 // of the hash used in the signature. 253 PSSSaltLengthEqualsHash = -1 254 ) 255 256 // PSSOptions contains options for creating and verifying PSS signatures. 257 type PSSOptions struct { 258 // SaltLength controls the length of the salt used in the PSS signature. It 259 // can either be a positive number of bytes, or one of the special 260 // PSSSaltLength constants. 261 SaltLength int 262 263 // Hash is the hash function used to generate the message digest. If not 264 // zero, it overrides the hash function passed to SignPSS. It's required 265 // when using PrivateKey.Sign. 266 Hash crypto.Hash 267 } 268 269 // HashFunc returns opts.Hash so that [PSSOptions] implements [crypto.SignerOpts]. 270 func (opts *PSSOptions) HashFunc() crypto.Hash { 271 return opts.Hash 272 } 273 274 func (opts *PSSOptions) saltLength() int { 275 if opts == nil { 276 return PSSSaltLengthAuto 277 } 278 return opts.SaltLength 279 } 280 281 var invalidSaltLenErr = errors.New("crypto/rsa: PSSOptions.SaltLength cannot be negative") 282 283 // SignPSS calculates the signature of digest using PSS. 284 // 285 // digest must be the result of hashing the input message using the given hash 286 // function. The opts argument may be nil, in which case sensible defaults are 287 // used. If opts.Hash is set, it overrides hash. 288 // 289 // The signature is randomized depending on the message, key, and salt size, 290 // using bytes from rand. Most applications should use [crypto/rand.Reader] as 291 // rand. 292 func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, digest []byte, opts *PSSOptions) ([]byte, error) { 293 // Note that while we don't commit to deterministic execution with respect 294 // to the rand stream, we also don't apply MaybeReadByte, so per Hyrum's Law 295 // it's probably relied upon by some. It's a tolerable promise because a 296 // well-specified number of random bytes is included in the signature, in a 297 // well-specified way. 298 299 if boring.Enabled && rand == boring.RandReader { 300 bkey, err := boringPrivateKey(priv) 301 if err != nil { 302 return nil, err 303 } 304 return boring.SignRSAPSS(bkey, hash, digest, opts.saltLength()) 305 } 306 boring.UnreachableExceptTests() 307 308 if opts != nil && opts.Hash != 0 { 309 hash = opts.Hash 310 } 311 312 saltLength := opts.saltLength() 313 switch saltLength { 314 case PSSSaltLengthAuto: 315 saltLength = (priv.N.BitLen()-1+7)/8 - 2 - hash.Size() 316 if saltLength < 0 { 317 return nil, ErrMessageTooLong 318 } 319 case PSSSaltLengthEqualsHash: 320 saltLength = hash.Size() 321 default: 322 // If we get here saltLength is either > 0 or < -1, in the 323 // latter case we fail out. 324 if saltLength <= 0 { 325 return nil, invalidSaltLenErr 326 } 327 } 328 salt := make([]byte, saltLength) 329 if _, err := io.ReadFull(rand, salt); err != nil { 330 return nil, err 331 } 332 return signPSSWithSalt(priv, hash, digest, salt) 333 } 334 335 // VerifyPSS verifies a PSS signature. 336 // 337 // A valid signature is indicated by returning a nil error. digest must be the 338 // result of hashing the input message using the given hash function. The opts 339 // argument may be nil, in which case sensible defaults are used. opts.Hash is 340 // ignored. 341 func VerifyPSS(pub *PublicKey, hash crypto.Hash, digest []byte, sig []byte, opts *PSSOptions) error { 342 if boring.Enabled { 343 bkey, err := boringPublicKey(pub) 344 if err != nil { 345 return err 346 } 347 if err := boring.VerifyRSAPSS(bkey, hash, digest, sig, opts.saltLength()); err != nil { 348 return ErrVerification 349 } 350 return nil 351 } 352 if len(sig) != pub.Size() { 353 return ErrVerification 354 } 355 // Salt length must be either one of the special constants (-1 or 0) 356 // or otherwise positive. If it is < PSSSaltLengthEqualsHash (-1) 357 // we return an error. 358 if opts.saltLength() < PSSSaltLengthEqualsHash { 359 return invalidSaltLenErr 360 } 361 362 emBits := pub.N.BitLen() - 1 363 emLen := (emBits + 7) / 8 364 em, err := encrypt(pub, sig) 365 if err != nil { 366 return ErrVerification 367 } 368 369 // Like in signPSSWithSalt, deal with mismatches between emLen and the size 370 // of the modulus. The spec would have us wire emLen into the encoding 371 // function, but we'd rather always encode to the size of the modulus and 372 // then strip leading zeroes if necessary. This only happens for weird 373 // modulus sizes anyway. 374 for len(em) > emLen && len(em) > 0 { 375 if em[0] != 0 { 376 return ErrVerification 377 } 378 em = em[1:] 379 } 380 381 return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New()) 382 } 383