diff --git a/include/jwt-cpp/jwt.h b/include/jwt-cpp/jwt.h index 2a25d54b3..f7c785ec0 100644 --- a/include/jwt-cpp/jwt.h +++ b/include/jwt-cpp/jwt.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -3418,18 +3419,16 @@ namespace jwt { bool empty() const noexcept { return jwk_claims.empty(); } - helper::evp_pkey_handle get_pkey() const { return k.get_asymmetric_key(); } - - std::string get_oct_key() const { return k.get_symmetric_key(); } + key get_key() const { return k; } private: template static helper::evp_pkey_handle build_rsa_key(const details::map_of_claims& claims, Decode&& decode) { - EVP_PKEY* evp_key = nullptr; auto n = jwt::helper::raw2bn(decode(claims.get_claim("n").as_string())); auto e = jwt::helper::raw2bn(decode(claims.get_claim("e").as_string())); #ifdef JWT_OPENSSL_3_0 + EVP_PKEY* evp_key = nullptr; // https://www.openssl.org/docs/manmaster/man7/EVP_PKEY-RSA.html // see https://www.openssl.org/docs/man3.0/man3/EVP_PKEY_fromdata.html // and https://stackoverflow.com/questions/68465716/how-to-properly-create-an-rsa-key-from-raw-data-in-openssl-3-0-in-c-language @@ -3439,25 +3438,39 @@ namespace jwt { std::unique_ptr params_build(OSSL_PARAM_BLD_new(), OSSL_PARAM_BLD_free); - OSSL_PARAM_BLD_push_BN(params_build.get(), "n", n.get()); - OSSL_PARAM_BLD_push_BN(params_build.get(), "e", e.get()); + if (!params_build) { throw std::runtime_error("OSSL_PARAM_BLD_new failed"); } + if (OSSL_PARAM_BLD_push_BN(params_build.get(), "n", n.get()) != 1) { + throw std::runtime_error("OSSL_PARAM_BLD_push_BN failed"); + } + if (OSSL_PARAM_BLD_push_BN(params_build.get(), "e", e.get()) != 1) { + throw std::runtime_error("OSSL_PARAM_BLD_push_BN failed"); + } std::unique_ptr params(OSSL_PARAM_BLD_to_param(params_build.get()), OSSL_PARAM_free); - EVP_PKEY_fromdata_init(ctx.get()); - EVP_PKEY_fromdata(ctx.get(), &evp_key, EVP_PKEY_PUBLIC_KEY, params.get()); + if (!params) { throw std::runtime_error("OSSL_PARAM_BLD_to_param failed"); } + if (EVP_PKEY_fromdata_init(ctx.get()) != 1) { throw std::runtime_error("EVP_PKEY_fromdata_init failed"); } + if (EVP_PKEY_fromdata(ctx.get(), &evp_key, EVP_PKEY_PUBLIC_KEY, params.get()) != 1) { + throw std::runtime_error("EVP_PKEY_fromdata failed"); + } return helper::evp_pkey_handle(evp_key); #else - RSA* rsa = RSA_new(); - evp_key = EVP_PKEY_new(); + std::unique_ptr rsa(RSA_new(), RSA_free); + if (!rsa) { throw std::runtime_error("RSA_new failed"); } #if defined(JWT_OPENSSL_1_0_0) && !defined(LIBWOLFSSL_VERSION_HEX) rsa->e = e.release(); rsa->n = n.release(); #else - RSA_set0_key(rsa, n.release(), e.release(), nullptr); + if (RSA_set0_key(rsa.get(), n.release(), e.release(), nullptr) != 1) { + throw std::runtimeruntime_error("RSA_set0_key failed"); + } #endif - EVP_PKEY_assign_RSA(evp_key, rsa); - return helper::evp_pkey_handle(evp_key); + std::unique_ptr evp_key(EVP_PKEY_new(), EVP_PKEY_free); + if (EVP_PKEY_assign_RSA(evp_key.get(), rsa.get()) != 1) { + throw std::runtime_error("EVP_PKEY_assign_RSA failed"); + } + rsa.release(); + return helper::evp_pkey_handle(evp_key.release()); #endif } @@ -3489,6 +3502,97 @@ namespace jwt { key k; }; + struct algo_base { + virtual ~algo_base() = default; + virtual void verify(const std::string& data, const std::string& sig, std::error_code& ec) = 0; + }; + template + struct algo : public algo_base { + T alg; + explicit algo(T a) : alg(a) {} + void verify(const std::string& data, const std::string& sig, std::error_code& ec) override { + alg.verify(data, sig, ec); + } + }; + + struct algorithm_db { + using builder_fn = std::function(const key&)>; + using algname_to_builder_fn = std::map; + enum type { empty, basic }; + + algorithm_db() : algorithm_db(empty) {} + algorithm_db(type t) { + if (t == empty) { supported_algorithms.clear(); } + } + + builder_fn create_algorithm(const std::string& name) const { + const auto algorithm = supported_algorithms.find(name); + if (algorithm != supported_algorithms.end()) { return algorithm->second; } + return nullptr; + } + + void register_algorithm(const std::string& alg_name, builder_fn build_fn) { + supported_algorithms.insert_or_assign(alg_name, build_fn); + } + + private: + algname_to_builder_fn supported_algorithms = { + {"RS256", + [](const key& key) { + return std::make_unique>(jwt::algorithm::rs256(key.get_asymmetric_key())); + }}, + {"RS384", + [](const key& key) { + return std::make_unique>(jwt::algorithm::rs384(key.get_asymmetric_key())); + }}, + {"RS512", + [](const key& key) { + return std::make_unique>(jwt::algorithm::rs512(key.get_asymmetric_key())); + }}, + {"PS256", + [](const key& key) { + return std::make_unique>(jwt::algorithm::ps256(key.get_asymmetric_key())); + }}, + {"PS384", + [](const key& key) { + return std::make_unique>(jwt::algorithm::ps384(key.get_asymmetric_key())); + }}, + {"PS512", + [](const key& key) { + return std::make_unique>(jwt::algorithm::ps512(key.get_asymmetric_key())); + }}, + {"ES256", + [](const key& key) { + return std::make_unique>(jwt::algorithm::es256(key.get_asymmetric_key())); + }}, + {"ES384", + [](const key& key) { + return std::make_unique>(jwt::algorithm::es384(key.get_asymmetric_key())); + }}, + {"ES512", + [](const key& key) { + return std::make_unique>(jwt::algorithm::es512(key.get_asymmetric_key())); + }}, + {"ES256K", + [](const key& key) { + return std::make_unique>( + jwt::algorithm::es256k(key.get_asymmetric_key())); + }}, + {"HS256", + [](const key& key) { + return std::make_unique>(jwt::algorithm::hs256(key.get_symmetric_key())); + }}, + {"HS384", + [](const key& key) { + return std::make_unique>(jwt::algorithm::hs384(key.get_symmetric_key())); + }}, + {"HS512", + [](const key& key) { + return std::make_unique>(jwt::algorithm::hs512(key.get_symmetric_key())); + }}, + }; + }; + /** * Verifier class used to check if a decoded token contains all claims required by your application and has a valid * signature. @@ -3510,32 +3614,15 @@ namespace jwt { std::function&, std::error_code& ec)>; private: - struct algo_base { - virtual ~algo_base() = default; - virtual void verify(const std::string& data, const std::string& sig, std::error_code& ec) = 0; - }; - template - struct algo : public algo_base { - T alg; - explicit algo(T a) : alg(a) {} - void verify(const std::string& data, const std::string& sig, std::error_code& ec) override { - alg.verify(data, sig, ec); - } - }; /// Required claims std::unordered_map claims; /// Leeway time for exp, nbf and iat size_t default_leeway = 0; /// Instance of clock type Clock clock; + algorithm_db supported_algorithms; /// Supported algorithms std::unordered_map> algs; - using alg_name = std::string; - using alg_list = std::vector; - using algorithms = std::unordered_map; - algorithms supported_alg = {{"RSA", {"RS256", "RS384", "RS512", "PS256", "PS384", "PS512"}}, - {"EC", {"ES256", "ES384", "ES512", "ES256K"}}, - {"oct", {"HS256", "HS384", "HS512"}}}; typedef std::vector> key_list; /// https://datatracker.ietf.org/doc/html/rfc7517#section-4.5 - kid to keys @@ -3551,70 +3638,21 @@ namespace jwt { } } - bool is_valid_combination(const jwt::jwk& key, const std::string& alg_name) const { - const alg_list& x = supported_alg.find(key.get_key_type())->second; - return std::find(x.cbegin(), x.cend(), alg_name) != x.cend(); + bool is_valid_combination(const std::string& key_type, const std::string& alg_name) const { + // TODO:mk check whether key type can be used with the algorithm + return true; } - inline std::unique_ptr from_key_and_alg(const jwt::jwk& key, - const std::string& alg_name, std::error_code& ec) const { + std::unique_ptr from_key_and_alg(const jwt::jwk& key, const std::string& alg_name, + std::error_code& ec) const { ec.clear(); - algorithms::const_iterator it = supported_alg.find(key.get_key_type()); - if (it == supported_alg.end()) { - ec = error::token_verification_error::wrong_algorithm; - return nullptr; - } - - const alg_list& supported_jwt_algorithms = it->second; - if (std::find(supported_jwt_algorithms.begin(), supported_jwt_algorithms.end(), alg_name) == - supported_jwt_algorithms.end()) { + auto create = supported_algorithms.create_algorithm(alg_name); + if (create == nullptr) { ec = error::token_verification_error::wrong_algorithm; return nullptr; } - if (alg_name == "RS256") { - return std::unique_ptr>( - new algo(jwt::algorithm::rs256(key.get_pkey()))); - } else if (alg_name == "RS384") { - return std::unique_ptr>( - new algo(jwt::algorithm::rs384(key.get_pkey()))); - } else if (alg_name == "RS512") { - return std::unique_ptr>( - new algo(jwt::algorithm::rs512(key.get_pkey()))); - } else if (alg_name == "PS256") { - return std::unique_ptr>( - new algo(jwt::algorithm::ps256(key.get_pkey()))); - } else if (alg_name == "PS384") { - return std::unique_ptr>( - new algo(jwt::algorithm::ps384(key.get_pkey()))); - } else if (alg_name == "PS512") { - return std::unique_ptr>( - new algo(jwt::algorithm::ps512(key.get_pkey()))); - } else if (alg_name == "ES256") { - return std::unique_ptr>( - new algo(jwt::algorithm::es256(key.get_pkey()))); - } else if (alg_name == "ES384") { - return std::unique_ptr>( - new algo(jwt::algorithm::es384(key.get_pkey()))); - } else if (alg_name == "ES512") { - return std::unique_ptr>( - new algo(jwt::algorithm::es512(key.get_pkey()))); - } else if (alg_name == "ES256K") { - return std::unique_ptr>( - new algo(jwt::algorithm::es256k(key.get_pkey()))); - } else if (alg_name == "HS256") { - return std::unique_ptr>( - new algo(jwt::algorithm::hs256(key.get_oct_key()))); - } else if (alg_name == "HS384") { - return std::unique_ptr>( - new algo(jwt::algorithm::hs384(key.get_oct_key()))); - } else if (alg_name == "HS512") { - return std::unique_ptr>( - new algo(jwt::algorithm::hs512(key.get_oct_key()))); - } - - ec = error::token_verification_error::wrong_algorithm; - return nullptr; + return create(key.get_key()); } public: @@ -3622,7 +3660,9 @@ namespace jwt { * Constructor for building a new verifier instance * \param c Clock instance */ - explicit verifier(Clock c) : clock(c) { + explicit verifier(Clock c) : verifier(c, algorithm_db(algorithm_db::basic)) {} + + verifier(Clock c, algorithm_db algorithms) : clock(c), supported_algorithms(algorithms) { claims["exp"] = [](const verify_ops::verify_context& ctx, std::error_code& ec) { if (!ctx.jwt.has_expires_at()) return; auto exp = ctx.jwt.get_expires_at(); @@ -3821,7 +3861,7 @@ namespace jwt { if (key_set_it != keys.end()) { const key_list& keys = key_set_it->second; for (const auto& key : keys) { - if (is_valid_combination(key, algo)) { + if (is_valid_combination(key.get_key_type(), algo)) { key_found = true; auto alg = from_key_and_alg(key, algo, ec); alg->verify(data, sig, ec); diff --git a/tests/JwkTest.cpp b/tests/JwkTest.cpp index 37e7fd624..41bbb80c3 100644 --- a/tests/JwkTest.cpp +++ b/tests/JwkTest.cpp @@ -76,3 +76,28 @@ TEST(JwkTest, HmacKey) { auto decoded_token = jwt::decode(token); ASSERT_NO_THROW(verifier.verify(decoded_token)); } + +TEST(JwkTest, CustomAlgorithm) { + // {"alg":"my-custom-alg","typ":"JWS"}.{"iss":"auth0"}.valid_signature + std::string token = "eyJhbGciOiJteS1jdXN0b20tYWxnIiwidHlwIjoiSldTIn0.eyJpc3MiOiJhdXRoMCJ9.dmFsaWRfc2lnbmF0dXJl"; + std::string secret_key = R"({ + "kty": "oct", + "k": "c2VjcmV0" + })"; + + struct custom_verification_algorithm { + void verify(const std::string& data, const std::string& sig, std::error_code& ec) {} + }; + + jwt::algorithm_db my_verification_algorithms; + my_verification_algorithms.register_algorithm("my-custom-alg", [](const jwt::key&) { + return std::make_unique>(custom_verification_algorithm()); + }); + auto verifier = jwt::verifier(jwt::default_clock(), + my_verification_algorithms); + + auto jwk = jwt::parse_jwk(secret_key); + verifier.allow_key(jwk); + auto decoded_token = jwt::decode(token); + ASSERT_NO_THROW(verifier.verify(decoded_token)); +}