Extract common logic in OpenSSL-specific code
[tinc] / src / openssl / rsa.c
index bf6938b..a4e4ac6 100644 (file)
@@ -199,22 +199,23 @@ size_t rsa_size(const rsa_t *rsa) {
 #endif
 }
 
-bool rsa_public_encrypt(rsa_t *rsa, const void *in, size_t len, void *out) {
-#if OPENSSL_VERSION_MAJOR < 3
+#if OPENSSL_VERSION_MAJOR >= 3
+// Initialize encryption or decryption context. Must return >0 on success, ≤0 on failure.
+typedef int (enc_init_t)(EVP_PKEY_CTX *ctx);
 
-       if((size_t)RSA_public_encrypt((int) len, in, out, rsa, RSA_NO_PADDING) == len) {
-               return true;
-       }
+// Encrypt or decrypt data. Must return >0 on success, ≤0 on failure.
+typedef int (enc_process_t)(EVP_PKEY_CTX *ctx, unsigned char *out, size_t *outlen, const unsigned char *in, size_t inlen);
 
-#else
+static bool rsa_encrypt_decrypt(rsa_t *rsa, const void *in, size_t len, void *out,
+                                enc_init_t init, enc_process_t process) {
        EVP_PKEY_CTX *ctx = EVP_PKEY_CTX_new(rsa, NULL);
 
        if(ctx) {
                size_t outlen = len;
 
-               bool ok = EVP_PKEY_encrypt_init(ctx) > 0
+               bool ok = init(ctx) > 0
                          && EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_NO_PADDING) > 0
-                         && EVP_PKEY_encrypt(ctx, out, &outlen, in, len) > 0
+                         && process(ctx, out, &outlen, in, len) > 0
                          && outlen == len;
 
                EVP_PKEY_CTX_free(ctx);
@@ -224,8 +225,21 @@ bool rsa_public_encrypt(rsa_t *rsa, const void *in, size_t len, void *out) {
                }
        }
 
+       return false;
+}
 #endif
 
+bool rsa_public_encrypt(rsa_t *rsa, const void *in, size_t len, void *out) {
+#if OPENSSL_VERSION_MAJOR < 3
+
+       if((size_t)RSA_public_encrypt((int) len, in, out, rsa, RSA_NO_PADDING) == len) {
+#else
+
+       if(rsa_encrypt_decrypt(rsa, in, len, out, EVP_PKEY_encrypt_init, EVP_PKEY_encrypt)) {
+#endif
+               return true;
+       }
+
        openssl_err("perform RSA encryption");
        return false;
 }
@@ -234,28 +248,12 @@ bool rsa_private_decrypt(rsa_t *rsa, const void *in, size_t len, void *out) {
 #if OPENSSL_VERSION_MAJOR < 3
 
        if((size_t)RSA_private_decrypt((int) len, in, out, rsa, RSA_NO_PADDING) == len) {
-               return true;
-       }
-
 #else
-       EVP_PKEY_CTX *ctx = EVP_PKEY_CTX_new(rsa, NULL);
-
-       if(ctx) {
-               size_t outlen = len;
-
-               bool ok = EVP_PKEY_decrypt_init(ctx) > 0
-                         && EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_NO_PADDING) > 0
-                         && EVP_PKEY_decrypt(ctx, out, &outlen, in, len) > 0
-                         && outlen == len;
-
-               EVP_PKEY_CTX_free(ctx);
-
-               if(ok) {
-                       return true;
-               }
-       }
 
+       if(rsa_encrypt_decrypt(rsa, in, len, out, EVP_PKEY_decrypt_init, EVP_PKEY_decrypt)) {
 #endif
+               return true;
+       }
 
        openssl_err("perform RSA decryption");
        return false;