Add support for OpenSSL 3.0+
[tinc] / src / openssl / rsa.c
index d50ba1d..bf6938b 100644 (file)
 #include "../system.h"
 
 #include <openssl/pem.h>
-#include <openssl/err.h>
 #include <openssl/rsa.h>
 
 #define TINC_RSA_INTERNAL
+
+#if OPENSSL_VERSION_MAJOR < 3
 typedef RSA rsa_t;
+#else
+#include <openssl/encoder.h>
+#include <openssl/decoder.h>
+#include <openssl/core_names.h>
+#include <openssl/param_build.h>
+#include <assert.h>
+
+typedef EVP_PKEY rsa_t;
+#endif
 
+#include "log.h"
 #include "../logger.h"
 #include "../rsa.h"
 
 // Set RSA keys
 
-rsa_t *rsa_set_hex_public_key(const char *n, const char *e) {
-       BIGNUM *bn_n = NULL;
-       BIGNUM *bn_e = NULL;
+#if OPENSSL_VERSION_MAJOR >= 3
+static EVP_PKEY *build_rsa_key(int selection, const BIGNUM *bn_n, const BIGNUM *bn_e, const BIGNUM *bn_d) {
+       assert(bn_n);
+       assert(bn_e);
 
-       if((size_t)BN_hex2bn(&bn_n, n) != strlen(n) || (size_t)BN_hex2bn(&bn_e, e) != strlen(e)) {
-               BN_free(bn_e);
-               BN_free(bn_n);
-               return false;
+       EVP_PKEY_CTX *ctx = EVP_PKEY_CTX_new_from_name(NULL, "RSA", NULL);
+
+       if(!ctx) {
+               openssl_err("initialize key context");
+               return NULL;
        }
 
-       rsa_t *rsa = RSA_new();
+       OSSL_PARAM_BLD *bld = OSSL_PARAM_BLD_new();
+       OSSL_PARAM_BLD_push_BN(bld, OSSL_PKEY_PARAM_RSA_N, bn_n);
+       OSSL_PARAM_BLD_push_BN(bld, OSSL_PKEY_PARAM_RSA_E, bn_e);
 
-       if(!rsa) {
-               return NULL;
+       if(bn_d) {
+               OSSL_PARAM_BLD_push_BN(bld, OSSL_PKEY_PARAM_RSA_D, bn_d);
        }
 
-       RSA_set0_key(rsa, bn_n, bn_e, NULL);
+       OSSL_PARAM *params = OSSL_PARAM_BLD_to_param(bld);
+       EVP_PKEY *key = NULL;
 
-       return rsa;
+       bool ok = EVP_PKEY_fromdata_init(ctx) > 0
+                 && EVP_PKEY_fromdata(ctx, &key, selection, params) > 0;
+
+       OSSL_PARAM_free(params);
+       OSSL_PARAM_BLD_free(bld);
+       EVP_PKEY_CTX_free(ctx);
+
+       if(ok) {
+               return key;
+       }
+
+       openssl_err("build key");
+       return NULL;
 }
+#endif
 
-rsa_t *rsa_set_hex_private_key(const char *n, const char *e, const char *d) {
+static bool hex_to_bn(BIGNUM **bn, const char *hex) {
+       return (size_t)BN_hex2bn(bn, hex) == strlen(hex);
+}
+
+static rsa_t *rsa_set_hex_key(const char *n, const char *e, const char *d) {
+       rsa_t *rsa = NULL;
        BIGNUM *bn_n = NULL;
        BIGNUM *bn_e = NULL;
        BIGNUM *bn_d = NULL;
 
-       if((size_t)BN_hex2bn(&bn_n, n) != strlen(n) || (size_t)BN_hex2bn(&bn_e, e) != strlen(e) || (size_t)BN_hex2bn(&bn_d, d) != strlen(d)) {
+       if(!hex_to_bn(&bn_n, n) || !hex_to_bn(&bn_e, e) || (d && !hex_to_bn(&bn_d, d))) {
+               goto exit;
+       }
+
+#if OPENSSL_VERSION_MAJOR < 3
+       rsa = RSA_new();
+
+       if(rsa) {
+               RSA_set0_key(rsa, bn_n, bn_e, bn_d);
+       }
+
+#else
+       int selection = bn_d ? EVP_PKEY_KEYPAIR : EVP_PKEY_PUBLIC_KEY;
+       rsa = build_rsa_key(selection, bn_n, bn_e, bn_d);
+#endif
+
+exit:
+#if OPENSSL_VERSION_MAJOR < 3
+
+       if(!rsa)
+#endif
+       {
                BN_free(bn_d);
                BN_free(bn_e);
                BN_free(bn_n);
-               return false;
        }
 
-       rsa_t *rsa = RSA_new();
+       return rsa;
+}
 
-       if(!rsa) {
+rsa_t *rsa_set_hex_public_key(const char *n, const char *e) {
+       return rsa_set_hex_key(n, e, NULL);
+}
+
+rsa_t *rsa_set_hex_private_key(const char *n, const char *e, const char *d) {
+       return rsa_set_hex_key(n, e, d);
+}
+
+// Read PEM RSA keys
+
+#if OPENSSL_VERSION_MAJOR >= 3
+static rsa_t *read_key_from_pem(FILE *fp, int selection) {
+       rsa_t *rsa = NULL;
+       OSSL_DECODER_CTX *ctx = OSSL_DECODER_CTX_new_for_pkey(&rsa, "PEM", NULL, "RSA", selection, NULL, NULL);
+
+       if(!ctx) {
+               openssl_err("initialize decoder");
                return NULL;
        }
 
-       RSA_set0_key(rsa, bn_n, bn_e, bn_d);
+       bool ok = OSSL_DECODER_from_fp(ctx, fp);
+       OSSL_DECODER_CTX_free(ctx);
+
+       if(!ok) {
+               rsa = NULL;
+               openssl_err("read RSA key from file");
+       }
 
        return rsa;
 }
-
-// Read PEM RSA keys
+#endif
 
 rsa_t *rsa_read_pem_public_key(FILE *fp) {
-       rsa_t *rsa = PEM_read_RSAPublicKey(fp, NULL, NULL, NULL);
+       rsa_t *rsa;
+
+#if OPENSSL_VERSION_MAJOR < 3
+       rsa = PEM_read_RSAPublicKey(fp, NULL, NULL, NULL);
 
        if(!rsa) {
                rewind(fp);
                rsa = PEM_read_RSA_PUBKEY(fp, NULL, NULL, NULL);
        }
 
+#else
+       rsa = read_key_from_pem(fp, OSSL_KEYMGMT_SELECT_PUBLIC_KEY);
+#endif
+
        if(!rsa) {
-               logger(DEBUG_ALWAYS, LOG_ERR, "Unable to read RSA public key: %s", ERR_error_string(ERR_get_error(), NULL));
+               openssl_err("read RSA public key");
        }
 
        return rsa;
 }
 
 rsa_t *rsa_read_pem_private_key(FILE *fp) {
-       rsa_t *rsa = PEM_read_RSAPrivateKey(fp, NULL, NULL, NULL);
+       rsa_t *rsa;
+
+#if OPENSSL_VERSION_MAJOR < 3
+       rsa = PEM_read_RSAPrivateKey(fp, NULL, NULL, NULL);
+#else
+       rsa = read_key_from_pem(fp, OSSL_KEYMGMT_SELECT_PRIVATE_KEY);
+#endif
 
        if(!rsa) {
-               logger(DEBUG_ALWAYS, LOG_ERR, "Unable to read RSA private key: %s", ERR_error_string(ERR_get_error(), NULL));
+               openssl_err("read RSA private key");
        }
 
        return rsa;
 }
 
 size_t rsa_size(const rsa_t *rsa) {
+#if OPENSSL_VERSION_MAJOR < 3
        return RSA_size(rsa);
+#else
+       return EVP_PKEY_get_size(rsa);
+#endif
 }
 
 bool rsa_public_encrypt(rsa_t *rsa, const void *in, size_t len, void *out) {
+#if OPENSSL_VERSION_MAJOR < 3
+
        if((size_t)RSA_public_encrypt((int) len, in, out, rsa, RSA_NO_PADDING) == len) {
                return true;
        }
 
-       logger(DEBUG_ALWAYS, LOG_ERR, "Unable to perform RSA encryption: %s", ERR_error_string(ERR_get_error(), NULL));
+#else
+       EVP_PKEY_CTX *ctx = EVP_PKEY_CTX_new(rsa, NULL);
+
+       if(ctx) {
+               size_t outlen = len;
+
+               bool ok = EVP_PKEY_encrypt_init(ctx) > 0
+                         && EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_NO_PADDING) > 0
+                         && EVP_PKEY_encrypt(ctx, out, &outlen, in, len) > 0
+                         && outlen == len;
+
+               EVP_PKEY_CTX_free(ctx);
+
+               if(ok) {
+                       return true;
+               }
+       }
+
+#endif
+
+       openssl_err("perform RSA encryption");
        return false;
 }
 
 bool rsa_private_decrypt(rsa_t *rsa, const void *in, size_t len, void *out) {
+#if OPENSSL_VERSION_MAJOR < 3
+
        if((size_t)RSA_private_decrypt((int) len, in, out, rsa, RSA_NO_PADDING) == len) {
                return true;
        }
 
-       logger(DEBUG_ALWAYS, LOG_ERR, "Unable to perform RSA decryption: %s", ERR_error_string(ERR_get_error(), NULL));
+#else
+       EVP_PKEY_CTX *ctx = EVP_PKEY_CTX_new(rsa, NULL);
+
+       if(ctx) {
+               size_t outlen = len;
+
+               bool ok = EVP_PKEY_decrypt_init(ctx) > 0
+                         && EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_NO_PADDING) > 0
+                         && EVP_PKEY_decrypt(ctx, out, &outlen, in, len) > 0
+                         && outlen == len;
+
+               EVP_PKEY_CTX_free(ctx);
+
+               if(ok) {
+                       return true;
+               }
+       }
+
+#endif
+
+       openssl_err("perform RSA decryption");
        return false;
 }
 
 void rsa_free(rsa_t *rsa) {
        if(rsa) {
+#if OPENSSL_VERSION_MAJOR < 3
                RSA_free(rsa);
+#else
+               EVP_PKEY_free(rsa);
+#endif
        }
 }