diff --git a/crypto/rsa.c b/crypto/rsa.c index 77d737f52147..dc692d43b666 100644 --- a/crypto/rsa.c +++ b/crypto/rsa.c @@ -10,16 +10,23 @@ */ #include +#include #include #include #include #include +struct rsa_mpi_key { + MPI n; + MPI e; + MPI d; +}; + /* * RSAEP function [RFC3447 sec 5.1.1] * c = m^e mod n; */ -static int _rsa_enc(const struct rsa_key *key, MPI c, MPI m) +static int _rsa_enc(const struct rsa_mpi_key *key, MPI c, MPI m) { /* (1) Validate 0 <= m < n */ if (mpi_cmp_ui(m, 0) < 0 || mpi_cmp(m, key->n) >= 0) @@ -33,7 +40,7 @@ static int _rsa_enc(const struct rsa_key *key, MPI c, MPI m) * RSADP function [RFC3447 sec 5.1.2] * m = c^d mod n; */ -static int _rsa_dec(const struct rsa_key *key, MPI m, MPI c) +static int _rsa_dec(const struct rsa_mpi_key *key, MPI m, MPI c) { /* (1) Validate 0 <= c < n */ if (mpi_cmp_ui(c, 0) < 0 || mpi_cmp(c, key->n) >= 0) @@ -47,7 +54,7 @@ static int _rsa_dec(const struct rsa_key *key, MPI m, MPI c) * RSASP1 function [RFC3447 sec 5.2.1] * s = m^d mod n */ -static int _rsa_sign(const struct rsa_key *key, MPI s, MPI m) +static int _rsa_sign(const struct rsa_mpi_key *key, MPI s, MPI m) { /* (1) Validate 0 <= m < n */ if (mpi_cmp_ui(m, 0) < 0 || mpi_cmp(m, key->n) >= 0) @@ -61,7 +68,7 @@ static int _rsa_sign(const struct rsa_key *key, MPI s, MPI m) * RSAVP1 function [RFC3447 sec 5.2.2] * m = s^e mod n; */ -static int _rsa_verify(const struct rsa_key *key, MPI m, MPI s) +static int _rsa_verify(const struct rsa_mpi_key *key, MPI m, MPI s) { /* (1) Validate 0 <= s < n */ if (mpi_cmp_ui(s, 0) < 0 || mpi_cmp(s, key->n) >= 0) @@ -71,7 +78,7 @@ static int _rsa_verify(const struct rsa_key *key, MPI m, MPI s) return mpi_powm(m, s, key->e, key->n); } -static inline struct rsa_key *rsa_get_key(struct crypto_akcipher *tfm) +static inline struct rsa_mpi_key *rsa_get_key(struct crypto_akcipher *tfm) { return akcipher_tfm_ctx(tfm); } @@ -79,7 +86,7 @@ static inline struct rsa_key *rsa_get_key(struct crypto_akcipher *tfm) static int rsa_enc(struct akcipher_request *req) { struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req); - const struct rsa_key *pkey = rsa_get_key(tfm); + const struct rsa_mpi_key *pkey = rsa_get_key(tfm); MPI m, c = mpi_alloc(0); int ret = 0; int sign; @@ -118,7 +125,7 @@ static int rsa_enc(struct akcipher_request *req) static int rsa_dec(struct akcipher_request *req) { struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req); - const struct rsa_key *pkey = rsa_get_key(tfm); + const struct rsa_mpi_key *pkey = rsa_get_key(tfm); MPI c, m = mpi_alloc(0); int ret = 0; int sign; @@ -156,7 +163,7 @@ static int rsa_dec(struct akcipher_request *req) static int rsa_sign(struct akcipher_request *req) { struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req); - const struct rsa_key *pkey = rsa_get_key(tfm); + const struct rsa_mpi_key *pkey = rsa_get_key(tfm); MPI m, s = mpi_alloc(0); int ret = 0; int sign; @@ -195,7 +202,7 @@ static int rsa_sign(struct akcipher_request *req) static int rsa_verify(struct akcipher_request *req) { struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req); - const struct rsa_key *pkey = rsa_get_key(tfm); + const struct rsa_mpi_key *pkey = rsa_get_key(tfm); MPI s, m = mpi_alloc(0); int ret = 0; int sign; @@ -233,6 +240,16 @@ static int rsa_verify(struct akcipher_request *req) return ret; } +static void rsa_free_mpi_key(struct rsa_mpi_key *key) +{ + mpi_free(key->d); + mpi_free(key->e); + mpi_free(key->n); + key->d = NULL; + key->e = NULL; + key->n = NULL; +} + static int rsa_check_key_length(unsigned int len) { switch (len) { @@ -251,49 +268,87 @@ static int rsa_check_key_length(unsigned int len) static int rsa_set_pub_key(struct crypto_akcipher *tfm, const void *key, unsigned int keylen) { - struct rsa_key *pkey = akcipher_tfm_ctx(tfm); + struct rsa_mpi_key *mpi_key = akcipher_tfm_ctx(tfm); + struct rsa_key raw_key = {0}; int ret; - ret = rsa_parse_pub_key(pkey, key, keylen); + /* Free the old MPI key if any */ + rsa_free_mpi_key(mpi_key); + + ret = rsa_parse_pub_key(&raw_key, key, keylen); if (ret) return ret; - if (rsa_check_key_length(mpi_get_size(pkey->n) << 3)) { - rsa_free_key(pkey); - ret = -EINVAL; + mpi_key->e = mpi_read_raw_data(raw_key.e, raw_key.e_sz); + if (!mpi_key->e) + goto err; + + mpi_key->n = mpi_read_raw_data(raw_key.n, raw_key.n_sz); + if (!mpi_key->n) + goto err; + + if (rsa_check_key_length(mpi_get_size(mpi_key->n) << 3)) { + rsa_free_mpi_key(mpi_key); + return -EINVAL; } - return ret; + + return 0; + +err: + rsa_free_mpi_key(mpi_key); + return -ENOMEM; } static int rsa_set_priv_key(struct crypto_akcipher *tfm, const void *key, unsigned int keylen) { - struct rsa_key *pkey = akcipher_tfm_ctx(tfm); + struct rsa_mpi_key *mpi_key = akcipher_tfm_ctx(tfm); + struct rsa_key raw_key = {0}; int ret; - ret = rsa_parse_priv_key(pkey, key, keylen); + /* Free the old MPI key if any */ + rsa_free_mpi_key(mpi_key); + + ret = rsa_parse_priv_key(&raw_key, key, keylen); if (ret) return ret; - if (rsa_check_key_length(mpi_get_size(pkey->n) << 3)) { - rsa_free_key(pkey); - ret = -EINVAL; + mpi_key->d = mpi_read_raw_data(raw_key.d, raw_key.d_sz); + if (!mpi_key->d) + goto err; + + mpi_key->e = mpi_read_raw_data(raw_key.e, raw_key.e_sz); + if (!mpi_key->e) + goto err; + + mpi_key->n = mpi_read_raw_data(raw_key.n, raw_key.n_sz); + if (!mpi_key->n) + goto err; + + if (rsa_check_key_length(mpi_get_size(mpi_key->n) << 3)) { + rsa_free_mpi_key(mpi_key); + return -EINVAL; } - return ret; + + return 0; + +err: + rsa_free_mpi_key(mpi_key); + return -ENOMEM; } static int rsa_max_size(struct crypto_akcipher *tfm) { - struct rsa_key *pkey = akcipher_tfm_ctx(tfm); + struct rsa_mpi_key *pkey = akcipher_tfm_ctx(tfm); return pkey->n ? mpi_get_size(pkey->n) : -EINVAL; } static void rsa_exit_tfm(struct crypto_akcipher *tfm) { - struct rsa_key *pkey = akcipher_tfm_ctx(tfm); + struct rsa_mpi_key *pkey = akcipher_tfm_ctx(tfm); - rsa_free_key(pkey); + rsa_free_mpi_key(pkey); } static struct akcipher_alg rsa = { @@ -310,7 +365,7 @@ static struct akcipher_alg rsa = { .cra_driver_name = "rsa-generic", .cra_priority = 100, .cra_module = THIS_MODULE, - .cra_ctxsize = sizeof(struct rsa_key), + .cra_ctxsize = sizeof(struct rsa_mpi_key), }, }; diff --git a/crypto/rsa_helper.c b/crypto/rsa_helper.c index d226f48d0907..583656af4fe2 100644 --- a/crypto/rsa_helper.c +++ b/crypto/rsa_helper.c @@ -22,20 +22,29 @@ int rsa_get_n(void *context, size_t hdrlen, unsigned char tag, const void *value, size_t vlen) { struct rsa_key *key = context; + const u8 *ptr = value; + size_t n_sz = vlen; - key->n = mpi_read_raw_data(value, vlen); - - if (!key->n) - return -ENOMEM; - - /* In FIPS mode only allow key size 2K & 3K */ - if (fips_enabled && (mpi_get_size(key->n) != 256 && - mpi_get_size(key->n) != 384)) { - pr_err("RSA: key size not allowed in FIPS mode\n"); - mpi_free(key->n); - key->n = NULL; + /* invalid key provided */ + if (!value || !vlen) return -EINVAL; + + if (fips_enabled) { + while (!*ptr && n_sz) { + ptr++; + n_sz--; + } + + /* In FIPS mode only allow key size 2K & 3K */ + if (n_sz != 256 && n_sz != 384) { + pr_err("RSA: key size not allowed in FIPS mode\n"); + return -EINVAL; + } } + + key->n = value; + key->n_sz = vlen; + return 0; } @@ -44,10 +53,12 @@ int rsa_get_e(void *context, size_t hdrlen, unsigned char tag, { struct rsa_key *key = context; - key->e = mpi_read_raw_data(value, vlen); + /* invalid key provided */ + if (!value || !key->n_sz || !vlen || vlen > key->n_sz) + return -EINVAL; - if (!key->e) - return -ENOMEM; + key->e = value; + key->e_sz = vlen; return 0; } @@ -57,46 +68,20 @@ int rsa_get_d(void *context, size_t hdrlen, unsigned char tag, { struct rsa_key *key = context; - key->d = mpi_read_raw_data(value, vlen); - - if (!key->d) - return -ENOMEM; - - /* In FIPS mode only allow key size 2K & 3K */ - if (fips_enabled && (mpi_get_size(key->d) != 256 && - mpi_get_size(key->d) != 384)) { - pr_err("RSA: key size not allowed in FIPS mode\n"); - mpi_free(key->d); - key->d = NULL; + /* invalid key provided */ + if (!value || !key->n_sz || !vlen || vlen > key->n_sz) return -EINVAL; - } + + key->d = value; + key->d_sz = vlen; + return 0; } -static void free_mpis(struct rsa_key *key) -{ - mpi_free(key->n); - mpi_free(key->e); - mpi_free(key->d); - key->n = NULL; - key->e = NULL; - key->d = NULL; -} - /** - * rsa_free_key() - frees rsa key allocated by rsa_parse_key() - * - * @rsa_key: struct rsa_key key representation - */ -void rsa_free_key(struct rsa_key *key) -{ - free_mpis(key); -} -EXPORT_SYMBOL_GPL(rsa_free_key); - -/** - * rsa_parse_pub_key() - extracts an rsa public key from BER encoded buffer - * and stores it in the provided struct rsa_key + * rsa_parse_pub_key() - decodes the BER encoded buffer and stores in the + * provided struct rsa_key, pointers to the raw key as is, + * so that the caller can copy it or MPI parse it, etc. * * @rsa_key: struct rsa_key key representation * @key: key in BER format @@ -107,23 +92,15 @@ EXPORT_SYMBOL_GPL(rsa_free_key); int rsa_parse_pub_key(struct rsa_key *rsa_key, const void *key, unsigned int key_len) { - int ret; - - free_mpis(rsa_key); - ret = asn1_ber_decoder(&rsapubkey_decoder, rsa_key, key, key_len); - if (ret < 0) - goto error; - - return 0; -error: - free_mpis(rsa_key); - return ret; + return asn1_ber_decoder(&rsapubkey_decoder, rsa_key, key, key_len); } EXPORT_SYMBOL_GPL(rsa_parse_pub_key); /** - * rsa_parse_pub_key() - extracts an rsa private key from BER encoded buffer - * and stores it in the provided struct rsa_key + * rsa_parse_priv_key() - decodes the BER encoded buffer and stores in the + * provided struct rsa_key, pointers to the raw key + * as is, so that the caller can copy it or MPI parse it, + * etc. * * @rsa_key: struct rsa_key key representation * @key: key in BER format @@ -134,16 +111,6 @@ EXPORT_SYMBOL_GPL(rsa_parse_pub_key); int rsa_parse_priv_key(struct rsa_key *rsa_key, const void *key, unsigned int key_len) { - int ret; - - free_mpis(rsa_key); - ret = asn1_ber_decoder(&rsaprivkey_decoder, rsa_key, key, key_len); - if (ret < 0) - goto error; - - return 0; -error: - free_mpis(rsa_key); - return ret; + return asn1_ber_decoder(&rsaprivkey_decoder, rsa_key, key, key_len); } EXPORT_SYMBOL_GPL(rsa_parse_priv_key); diff --git a/include/crypto/internal/rsa.h b/include/crypto/internal/rsa.h index c7585bdecbc2..d6c042a2ee52 100644 --- a/include/crypto/internal/rsa.h +++ b/include/crypto/internal/rsa.h @@ -12,12 +12,24 @@ */ #ifndef _RSA_HELPER_ #define _RSA_HELPER_ -#include +#include +/** + * rsa_key - RSA key structure + * @n : RSA modulus raw byte stream + * @e : RSA public exponent raw byte stream + * @d : RSA private exponent raw byte stream + * @n_sz : length in bytes of RSA modulus n + * @e_sz : length in bytes of RSA public exponent + * @d_sz : length in bytes of RSA private exponent + */ struct rsa_key { - MPI n; - MPI e; - MPI d; + const u8 *n; + const u8 *e; + const u8 *d; + size_t n_sz; + size_t e_sz; + size_t d_sz; }; int rsa_parse_pub_key(struct rsa_key *rsa_key, const void *key, @@ -26,7 +38,5 @@ int rsa_parse_pub_key(struct rsa_key *rsa_key, const void *key, int rsa_parse_priv_key(struct rsa_key *rsa_key, const void *key, unsigned int key_len); -void rsa_free_key(struct rsa_key *rsa_key); - extern struct crypto_template rsa_pkcs1pad_tmpl; #endif