Use hardening option to add only hardening flags
[tinc] / src / gcrypt / rsa.c
1 /*
2     rsa.c -- RSA key handling
3     Copyright (C) 2007-2022 Guus Sliepen <guus@tinc-vpn.org>
4
5     This program is free software; you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation; either version 2 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License along
16     with this program; if not, write to the Free Software Foundation, Inc.,
17     51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
18 */
19
20 #include "../system.h"
21
22 #include <gcrypt.h>
23
24 #include "pem.h"
25
26 #include "asn1.h"
27 #include "rsa.h"
28 #include "../logger.h"
29 #include "../rsa.h"
30 #include "../xalloc.h"
31
32 // BER decoding functions
33
34 static int ber_read_id(unsigned char **p, size_t *buflen) {
35         if(*buflen <= 0) {
36                 return -1;
37         }
38
39         if((**p & 0x1f) == 0x1f) {
40                 int id = 0;
41                 bool more;
42
43                 while(*buflen > 0) {
44                         id <<= 7;
45                         id |= **p & 0x7f;
46                         more = *(*p)++ & 0x80;
47                         (*buflen)--;
48
49                         if(!more) {
50                                 break;
51                         }
52                 }
53
54                 return id;
55         } else {
56                 (*buflen)--;
57                 return *(*p)++ & 0x1f;
58         }
59 }
60
61 static size_t ber_read_len(unsigned char **p, size_t *buflen) {
62         if(*buflen <= 0) {
63                 return -1;
64         }
65
66         if(**p & 0x80) {
67                 size_t result = 0;
68                 size_t len = *(*p)++ & 0x7f;
69                 (*buflen)--;
70
71                 if(len > *buflen) {
72                         return 0;
73                 }
74
75                 for(; len; --len) {
76                         result = (size_t)(result << 8);
77                         result |= *(*p)++;
78                         (*buflen)--;
79                 }
80
81                 return result;
82         } else {
83                 (*buflen)--;
84                 return *(*p)++;
85         }
86 }
87
88 static bool ber_skip_sequence(unsigned char **p, size_t *buflen) {
89         int tag = ber_read_id(p, buflen);
90
91         return tag == TAG_SEQUENCE &&
92                ber_read_len(p, buflen) > 0;
93 }
94
95 static bool ber_read_mpi(unsigned char **p, size_t *buflen, gcry_mpi_t *mpi) {
96         int tag = ber_read_id(p, buflen);
97         size_t len = ber_read_len(p, buflen);
98         gcry_error_t err = 0;
99
100         if(tag != 0x02 || len > *buflen) {
101                 return false;
102         }
103
104         if(mpi) {
105                 err = gcry_mpi_scan(mpi, GCRYMPI_FMT_USG, *p, len, NULL);
106         }
107
108         *p += len;
109         *buflen -= len;
110
111         return mpi ? !err : true;
112 }
113
114 rsa_t *rsa_new(void) {
115         return xzalloc(sizeof(rsa_t));
116 }
117
118 rsa_t *rsa_set_hex_public_key(const char *n, const char *e) {
119         rsa_t *rsa = rsa_new();
120
121         gcry_error_t err = gcry_mpi_scan(&rsa->n, GCRYMPI_FMT_HEX, n, 0, NULL);
122
123         if(!err) {
124                 err = gcry_mpi_scan(&rsa->e, GCRYMPI_FMT_HEX, e, 0, NULL);
125         }
126
127         if(err) {
128                 logger(DEBUG_ALWAYS, LOG_ERR, "Error while reading RSA public key: %s", gcry_strerror(errno));
129                 rsa_free(rsa);
130                 return false;
131         }
132
133         return rsa;
134 }
135
136 rsa_t *rsa_set_hex_private_key(const char *n, const char *e, const char *d) {
137         rsa_t *rsa = rsa_new();
138
139         gcry_error_t err = gcry_mpi_scan(&rsa->n, GCRYMPI_FMT_HEX, n, 0, NULL);
140
141         if(!err) {
142                 err = gcry_mpi_scan(&rsa->e, GCRYMPI_FMT_HEX, e, 0, NULL);
143         }
144
145         if(!err) {
146                 err = gcry_mpi_scan(&rsa->d, GCRYMPI_FMT_HEX, d, 0, NULL);
147         }
148
149         if(err) {
150                 logger(DEBUG_ALWAYS, LOG_ERR, "Error while reading RSA public key: %s", gcry_strerror(errno));
151                 rsa_free(rsa);
152                 return NULL;
153         }
154
155         return rsa;
156 }
157
158 // Read PEM RSA keys
159
160 rsa_t *rsa_read_pem_public_key(FILE *fp) {
161         uint8_t derbuf[8096], *derp = derbuf;
162         size_t derlen;
163
164         if(!pem_decode(fp, "RSA PUBLIC KEY", derbuf, sizeof(derbuf), &derlen)) {
165                 logger(DEBUG_ALWAYS, LOG_ERR, "Unable to read RSA public key: %s", strerror(errno));
166                 return NULL;
167         }
168
169         rsa_t *rsa = rsa_new();
170
171         if(!ber_skip_sequence(&derp, &derlen)
172                         || !ber_read_mpi(&derp, &derlen, &rsa->n)
173                         || !ber_read_mpi(&derp, &derlen, &rsa->e)
174                         || derlen) {
175                 logger(DEBUG_ALWAYS, LOG_ERR, "Error while decoding RSA public key");
176                 rsa_free(rsa);
177                 return NULL;
178         }
179
180         return rsa;
181 }
182
183 rsa_t *rsa_read_pem_private_key(FILE *fp) {
184         uint8_t derbuf[8096], *derp = derbuf;
185         size_t derlen;
186
187         if(!pem_decode(fp, "RSA PRIVATE KEY", derbuf, sizeof(derbuf), &derlen)) {
188                 logger(DEBUG_ALWAYS, LOG_ERR, "Unable to read RSA private key: %s", strerror(errno));
189                 return NULL;
190         }
191
192         rsa_t *rsa = rsa_new();
193
194         if(!ber_skip_sequence(&derp, &derlen)
195                         || !ber_read_mpi(&derp, &derlen, NULL)
196                         || !ber_read_mpi(&derp, &derlen, &rsa->n)
197                         || !ber_read_mpi(&derp, &derlen, &rsa->e)
198                         || !ber_read_mpi(&derp, &derlen, &rsa->d)
199                         || !ber_read_mpi(&derp, &derlen, NULL) // p
200                         || !ber_read_mpi(&derp, &derlen, NULL) // q
201                         || !ber_read_mpi(&derp, &derlen, NULL)
202                         || !ber_read_mpi(&derp, &derlen, NULL)
203                         || !ber_read_mpi(&derp, &derlen, NULL) // u
204                         || derlen) {
205                 logger(DEBUG_ALWAYS, LOG_ERR, "Error while decoding RSA private key");
206                 rsa_free(rsa);
207                 rsa = NULL;
208         }
209
210         memzero(derbuf, sizeof(derbuf));
211         return rsa;
212 }
213
214 size_t rsa_size(const rsa_t *rsa) {
215         return (gcry_mpi_get_nbits(rsa->n) + 7) / 8;
216 }
217
218 static bool check(gcry_error_t err) {
219         if(err) {
220                 logger(DEBUG_ALWAYS, LOG_ERR, "gcrypt error %s/%s", gcry_strsource(err), gcry_strerror(err));
221         }
222
223         return !err;
224 }
225
226 /* Well, libgcrypt has functions to handle RSA keys, but they suck.
227  * So we just use libgcrypt's mpi functions, and do the math ourselves.
228  */
229
230 static bool rsa_powm(const gcry_mpi_t ed, const gcry_mpi_t n, const void *in, size_t len, void *out) {
231         gcry_mpi_t inmpi = NULL;
232
233         if(!check(gcry_mpi_scan(&inmpi, GCRYMPI_FMT_USG, in, len, NULL))) {
234                 return false;
235         }
236
237         gcry_mpi_t outmpi = gcry_mpi_snew(len * 8);
238         gcry_mpi_powm(outmpi, inmpi, ed, n);
239
240         size_t out_bytes = (gcry_mpi_get_nbits(outmpi) + 7) / 8;
241         size_t pad = len - MIN(out_bytes, len);
242         unsigned char *pout = out;
243
244         for(; pad; --pad) {
245                 *pout++ = 0;
246         }
247
248         bool ok = check(gcry_mpi_print(GCRYMPI_FMT_USG, pout, len, NULL, outmpi));
249
250         gcry_mpi_release(outmpi);
251         gcry_mpi_release(inmpi);
252
253         return ok;
254 }
255
256 bool rsa_public_encrypt(rsa_t *rsa, const void *in, size_t len, void *out) {
257         return rsa_powm(rsa->e, rsa->n, in, len, out);
258 }
259
260 bool rsa_private_decrypt(rsa_t *rsa, const void *in, size_t len, void *out) {
261         return rsa_powm(rsa->d, rsa->n, in, len, out);
262 }
263
264 void rsa_free(rsa_t *rsa) {
265         if(rsa) {
266                 if(rsa->n) {
267                         gcry_mpi_release(rsa->n);
268                 }
269
270                 if(rsa->e) {
271                         gcry_mpi_release(rsa->e);
272                 }
273
274                 if(rsa->d) {
275                         gcry_mpi_release(rsa->d);
276                 }
277
278                 free(rsa);
279         }
280 }