#include "pem.h"
+#include "asn1.h"
#include "rsa.h"
#include "../logger.h"
#include "../rsa.h"
return 0;
}
- while(len--) {
+ for(; len; --len) {
result = (size_t)(result << 8);
result |= *(*p)++;
(*buflen)--;
}
}
-
-static bool ber_read_sequence(unsigned char **p, size_t *buflen, size_t *result) {
+static bool ber_skip_sequence(unsigned char **p, size_t *buflen) {
int tag = ber_read_id(p, buflen);
- size_t len = ber_read_len(p, buflen);
- if(tag == 0x10) {
- if(result) {
- *result = len;
- }
-
- return true;
- } else {
- return false;
- }
+ return tag == TAG_SEQUENCE &&
+ ber_read_len(p, buflen) > 0;
}
static bool ber_read_mpi(unsigned char **p, size_t *buflen, gcry_mpi_t *mpi) {
return mpi ? !err : true;
}
+rsa_t *rsa_new(void) {
+ return xzalloc(sizeof(rsa_t));
+}
+
rsa_t *rsa_set_hex_public_key(const char *n, const char *e) {
- rsa_t *rsa = xzalloc(sizeof(rsa_t));
+ rsa_t *rsa = rsa_new();
gcry_error_t err = gcry_mpi_scan(&rsa->n, GCRYMPI_FMT_HEX, n, 0, NULL);
if(err) {
logger(DEBUG_ALWAYS, LOG_ERR, "Error while reading RSA public key: %s", gcry_strerror(errno));
- free(rsa);
+ rsa_free(rsa);
return false;
}
}
rsa_t *rsa_set_hex_private_key(const char *n, const char *e, const char *d) {
- rsa_t *rsa = xzalloc(sizeof(rsa_t));
+ rsa_t *rsa = rsa_new();
gcry_error_t err = gcry_mpi_scan(&rsa->n, GCRYMPI_FMT_HEX, n, 0, NULL);
if(err) {
logger(DEBUG_ALWAYS, LOG_ERR, "Error while reading RSA public key: %s", gcry_strerror(errno));
- free(rsa);
- return false;
+ rsa_free(rsa);
+ return NULL;
}
return rsa;
return NULL;
}
- rsa_t *rsa = xzalloc(sizeof(rsa_t));
+ rsa_t *rsa = rsa_new();
- if(!ber_read_sequence(&derp, &derlen, NULL)
+ if(!ber_skip_sequence(&derp, &derlen)
|| !ber_read_mpi(&derp, &derlen, &rsa->n)
|| !ber_read_mpi(&derp, &derlen, &rsa->e)
|| derlen) {
logger(DEBUG_ALWAYS, LOG_ERR, "Error while decoding RSA public key");
- free(rsa);
+ rsa_free(rsa);
return NULL;
}
return NULL;
}
- rsa_t *rsa = xzalloc(sizeof(rsa_t));
+ rsa_t *rsa = rsa_new();
- if(!ber_read_sequence(&derp, &derlen, NULL)
+ if(!ber_skip_sequence(&derp, &derlen)
|| !ber_read_mpi(&derp, &derlen, NULL)
|| !ber_read_mpi(&derp, &derlen, &rsa->n)
|| !ber_read_mpi(&derp, &derlen, &rsa->e)
|| !ber_read_mpi(&derp, &derlen, NULL) // u
|| derlen) {
logger(DEBUG_ALWAYS, LOG_ERR, "Error while decoding RSA private key");
- free(rsa);
- return NULL;
+ rsa_free(rsa);
+ rsa = NULL;
}
+ memzero(derbuf, sizeof(derbuf));
return rsa;
}
return (gcry_mpi_get_nbits(rsa->n) + 7) / 8;
}
+static bool check(gcry_error_t err) {
+ if(err) {
+ logger(DEBUG_ALWAYS, LOG_ERR, "gcrypt error %s/%s", gcry_strsource(err), gcry_strerror(err));
+ }
+
+ return !err;
+}
+
/* Well, libgcrypt has functions to handle RSA keys, but they suck.
* So we just use libgcrypt's mpi functions, and do the math ourselves.
*/
-// TODO: get rid of this macro, properly clean up gcry_ structures after use
-#define check(foo) { gcry_error_t err = (foo); if(err) {logger(DEBUG_ALWAYS, LOG_ERR, "gcrypt error %s/%s at %s:%d", gcry_strsource(err), gcry_strerror(err), __FILE__, __LINE__); return false; }}
+static bool rsa_powm(const gcry_mpi_t ed, const gcry_mpi_t n, const void *in, size_t len, void *out) {
+ gcry_mpi_t inmpi = NULL;
-bool rsa_public_encrypt(rsa_t *rsa, const void *in, size_t len, void *out) {
- gcry_mpi_t inmpi;
- check(gcry_mpi_scan(&inmpi, GCRYMPI_FMT_USG, in, len, NULL));
+ if(!check(gcry_mpi_scan(&inmpi, GCRYMPI_FMT_USG, in, len, NULL))) {
+ return false;
+ }
- gcry_mpi_t outmpi = gcry_mpi_new(len * 8);
- gcry_mpi_powm(outmpi, inmpi, rsa->e, rsa->n);
+ gcry_mpi_t outmpi = gcry_mpi_snew(len * 8);
+ gcry_mpi_powm(outmpi, inmpi, ed, n);
size_t out_bytes = (gcry_mpi_get_nbits(outmpi) + 7) / 8;
size_t pad = len - MIN(out_bytes, len);
*pout++ = 0;
}
- check(gcry_mpi_print(GCRYMPI_FMT_USG, pout, len, NULL, outmpi));
-
- return true;
-}
-
-bool rsa_private_decrypt(rsa_t *rsa, const void *in, size_t len, void *out) {
- gcry_mpi_t inmpi;
- check(gcry_mpi_scan(&inmpi, GCRYMPI_FMT_USG, in, len, NULL));
-
- gcry_mpi_t outmpi = gcry_mpi_new(len * 8);
- gcry_mpi_powm(outmpi, inmpi, rsa->d, rsa->n);
+ bool ok = check(gcry_mpi_print(GCRYMPI_FMT_USG, pout, len, NULL, outmpi));
- size_t pad = len - (gcry_mpi_get_nbits(outmpi) + 7) / 8;
- unsigned char *pout = out;
+ gcry_mpi_release(outmpi);
+ gcry_mpi_release(inmpi);
- for(; pad; --pad) {
- *pout++ = 0;
- }
+ return ok;
+}
- check(gcry_mpi_print(GCRYMPI_FMT_USG, pout, len, NULL, outmpi));
+bool rsa_public_encrypt(rsa_t *rsa, const void *in, size_t len, void *out) {
+ return rsa_powm(rsa->e, rsa->n, in, len, out);
+}
- return true;
+bool rsa_private_decrypt(rsa_t *rsa, const void *in, size_t len, void *out) {
+ return rsa_powm(rsa->d, rsa->n, in, len, out);
}
void rsa_free(rsa_t *rsa) {