04aa358d3029cf0829efe7c6b47fad0617cfc5f5
[tinc] / src / gcrypt / rsa.c
1 /*
2     rsa.c -- RSA key handling
3     Copyright (C) 2007-2022 Guus Sliepen <guus@tinc-vpn.org>
4
5     This program is free software; you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation; either version 2 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License along
16     with this program; if not, write to the Free Software Foundation, Inc.,
17     51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
18 */
19
20 #include "../system.h"
21
22 #include <gcrypt.h>
23
24 #include "pem.h"
25
26 #include "asn1.h"
27 #include "rsa.h"
28 #include "../logger.h"
29 #include "../rsa.h"
30 #include "../xalloc.h"
31
32 // BER decoding functions
33
34 static int ber_read_id(unsigned char **p, size_t *buflen) {
35         if(*buflen <= 0) {
36                 return -1;
37         }
38
39         if((**p & 0x1f) == 0x1f) {
40                 int id = 0;
41                 bool more;
42
43                 while(*buflen > 0) {
44                         id <<= 7;
45                         id |= **p & 0x7f;
46                         more = *(*p)++ & 0x80;
47                         (*buflen)--;
48
49                         if(!more) {
50                                 break;
51                         }
52                 }
53
54                 return id;
55         } else {
56                 (*buflen)--;
57                 return *(*p)++ & 0x1f;
58         }
59 }
60
61 static size_t ber_read_len(unsigned char **p, size_t *buflen) {
62         if(*buflen <= 0) {
63                 return -1;
64         }
65
66         if(**p & 0x80) {
67                 size_t result = 0;
68                 size_t len = *(*p)++ & 0x7f;
69                 (*buflen)--;
70
71                 if(len > *buflen) {
72                         return 0;
73                 }
74
75                 for(; len; --len) {
76                         result = (size_t)(result << 8);
77                         result |= *(*p)++;
78                         (*buflen)--;
79                 }
80
81                 return result;
82         } else {
83                 (*buflen)--;
84                 return *(*p)++;
85         }
86 }
87
88 static bool ber_skip_sequence(unsigned char **p, size_t *buflen) {
89         int tag = ber_read_id(p, buflen);
90
91         return tag == TAG_SEQUENCE &&
92                ber_read_len(p, buflen) > 0;
93 }
94
95 static bool ber_read_mpi(unsigned char **p, size_t *buflen, gcry_mpi_t *mpi) {
96         int tag = ber_read_id(p, buflen);
97         size_t len = ber_read_len(p, buflen);
98         gcry_error_t err = 0;
99
100         if(tag != 0x02 || len > *buflen) {
101                 return false;
102         }
103
104         if(mpi) {
105                 err = gcry_mpi_scan(mpi, GCRYMPI_FMT_USG, *p, len, NULL);
106         }
107
108         *p += len;
109         *buflen -= len;
110
111         return mpi ? !err : true;
112 }
113
114 rsa_t *rsa_set_hex_public_key(const char *n, const char *e) {
115         rsa_t *rsa = xzalloc(sizeof(rsa_t));
116
117         gcry_error_t err = gcry_mpi_scan(&rsa->n, GCRYMPI_FMT_HEX, n, 0, NULL);
118
119         if(!err) {
120                 err = gcry_mpi_scan(&rsa->e, GCRYMPI_FMT_HEX, e, 0, NULL);
121         }
122
123         if(err) {
124                 logger(DEBUG_ALWAYS, LOG_ERR, "Error while reading RSA public key: %s", gcry_strerror(errno));
125                 rsa_free(rsa);
126                 return false;
127         }
128
129         return rsa;
130 }
131
132 rsa_t *rsa_set_hex_private_key(const char *n, const char *e, const char *d) {
133         rsa_t *rsa = xzalloc(sizeof(rsa_t));
134
135         gcry_error_t err = gcry_mpi_scan(&rsa->n, GCRYMPI_FMT_HEX, n, 0, NULL);
136
137         if(!err) {
138                 err = gcry_mpi_scan(&rsa->e, GCRYMPI_FMT_HEX, e, 0, NULL);
139         }
140
141         if(!err) {
142                 err = gcry_mpi_scan(&rsa->d, GCRYMPI_FMT_HEX, d, 0, NULL);
143         }
144
145         if(err) {
146                 logger(DEBUG_ALWAYS, LOG_ERR, "Error while reading RSA public key: %s", gcry_strerror(errno));
147                 rsa_free(rsa);
148                 return NULL;
149         }
150
151         return rsa;
152 }
153
154 // Read PEM RSA keys
155
156 rsa_t *rsa_read_pem_public_key(FILE *fp) {
157         uint8_t derbuf[8096], *derp = derbuf;
158         size_t derlen;
159
160         if(!pem_decode(fp, "RSA PUBLIC KEY", derbuf, sizeof(derbuf), &derlen)) {
161                 logger(DEBUG_ALWAYS, LOG_ERR, "Unable to read RSA public key: %s", strerror(errno));
162                 return NULL;
163         }
164
165         rsa_t *rsa = xzalloc(sizeof(rsa_t));
166
167         if(!ber_skip_sequence(&derp, &derlen)
168                         || !ber_read_mpi(&derp, &derlen, &rsa->n)
169                         || !ber_read_mpi(&derp, &derlen, &rsa->e)
170                         || derlen) {
171                 logger(DEBUG_ALWAYS, LOG_ERR, "Error while decoding RSA public key");
172                 rsa_free(rsa);
173                 return NULL;
174         }
175
176         return rsa;
177 }
178
179 rsa_t *rsa_read_pem_private_key(FILE *fp) {
180         uint8_t derbuf[8096], *derp = derbuf;
181         size_t derlen;
182
183         if(!pem_decode(fp, "RSA PRIVATE KEY", derbuf, sizeof(derbuf), &derlen)) {
184                 logger(DEBUG_ALWAYS, LOG_ERR, "Unable to read RSA private key: %s", strerror(errno));
185                 return NULL;
186         }
187
188         rsa_t *rsa = xzalloc(sizeof(rsa_t));
189
190         if(!ber_skip_sequence(&derp, &derlen)
191                         || !ber_read_mpi(&derp, &derlen, NULL)
192                         || !ber_read_mpi(&derp, &derlen, &rsa->n)
193                         || !ber_read_mpi(&derp, &derlen, &rsa->e)
194                         || !ber_read_mpi(&derp, &derlen, &rsa->d)
195                         || !ber_read_mpi(&derp, &derlen, NULL) // p
196                         || !ber_read_mpi(&derp, &derlen, NULL) // q
197                         || !ber_read_mpi(&derp, &derlen, NULL)
198                         || !ber_read_mpi(&derp, &derlen, NULL)
199                         || !ber_read_mpi(&derp, &derlen, NULL) // u
200                         || derlen) {
201                 logger(DEBUG_ALWAYS, LOG_ERR, "Error while decoding RSA private key");
202                 rsa_free(rsa);
203                 rsa = NULL;
204         }
205
206         memzero(derbuf, sizeof(derbuf));
207         return rsa;
208 }
209
210 size_t rsa_size(const rsa_t *rsa) {
211         return (gcry_mpi_get_nbits(rsa->n) + 7) / 8;
212 }
213
214 static bool check(gcry_error_t err) {
215         if(err) {
216                 logger(DEBUG_ALWAYS, LOG_ERR, "gcrypt error %s/%s", gcry_strsource(err), gcry_strerror(err));
217         }
218
219         return !err;
220 }
221
222 /* Well, libgcrypt has functions to handle RSA keys, but they suck.
223  * So we just use libgcrypt's mpi functions, and do the math ourselves.
224  */
225
226 static bool rsa_powm(const gcry_mpi_t ed, const gcry_mpi_t n, const void *in, size_t len, void *out) {
227         gcry_mpi_t inmpi = NULL;
228
229         if(!check(gcry_mpi_scan(&inmpi, GCRYMPI_FMT_USG, in, len, NULL))) {
230                 return false;
231         }
232
233         gcry_mpi_t outmpi = gcry_mpi_snew(len * 8);
234         gcry_mpi_powm(outmpi, inmpi, ed, n);
235
236         size_t out_bytes = (gcry_mpi_get_nbits(outmpi) + 7) / 8;
237         size_t pad = len - MIN(out_bytes, len);
238         unsigned char *pout = out;
239
240         for(; pad; --pad) {
241                 *pout++ = 0;
242         }
243
244         bool ok = check(gcry_mpi_print(GCRYMPI_FMT_USG, pout, len, NULL, outmpi));
245
246         gcry_mpi_release(outmpi);
247         gcry_mpi_release(inmpi);
248
249         return ok;
250 }
251
252 bool rsa_public_encrypt(rsa_t *rsa, const void *in, size_t len, void *out) {
253         return rsa_powm(rsa->e, rsa->n, in, len, out);
254 }
255
256 bool rsa_private_decrypt(rsa_t *rsa, const void *in, size_t len, void *out) {
257         return rsa_powm(rsa->d, rsa->n, in, len, out);
258 }
259
260 void rsa_free(rsa_t *rsa) {
261         if(rsa) {
262                 if(rsa->n) {
263                         gcry_mpi_release(rsa->n);
264                 }
265
266                 if(rsa->e) {
267                         gcry_mpi_release(rsa->e);
268                 }
269
270                 if(rsa->d) {
271                         gcry_mpi_release(rsa->d);
272                 }
273
274                 free(rsa);
275         }
276 }