5d12985d63d366d7ed2fc5951217bb1bede10db0
[tinc] / src / gcrypt / rsa.c
1 /*
2     rsa.c -- RSA key handling
3     Copyright (C) 2007-2012 Guus Sliepen <guus@tinc-vpn.org>
4
5     This program is free software; you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation; either version 2 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License along
16     with this program; if not, write to the Free Software Foundation, Inc.,
17     51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
18 */
19
20 #include "system.h"
21
22 #include <gcrypt.h>
23
24 #include "logger.h"
25 #include "rsa.h"
26 #include "pem.h"
27 #include "xalloc.h"
28
29 // BER decoding functions
30
31 static int ber_read_id(unsigned char **p, size_t *buflen) {
32         if(*buflen <= 0) {
33                 return -1;
34         }
35
36         if((**p & 0x1f) == 0x1f) {
37                 int id = 0;
38                 bool more;
39
40                 while(*buflen > 0) {
41                         id <<= 7;
42                         id |= **p & 0x7f;
43                         more = *(*p)++ & 0x80;
44                         (*buflen)--;
45
46                         if(!more) {
47                                 break;
48                         }
49                 }
50
51                 return id;
52         } else {
53                 (*buflen)--;
54                 return *(*p)++ & 0x1f;
55         }
56 }
57
58 static size_t ber_read_len(unsigned char **p, size_t *buflen) {
59         if(*buflen <= 0) {
60                 return -1;
61         }
62
63         if(**p & 0x80) {
64                 size_t result = 0;
65                 int len = *(*p)++ & 0x7f;
66                 (*buflen)--;
67
68                 if(len > *buflen) {
69                         return 0;
70                 }
71
72                 while(len--) {
73                         result = (size_t)(result << 8);
74                         result |= *(*p)++;
75                         (*buflen)--;
76                 }
77
78                 return result;
79         } else {
80                 (*buflen)--;
81                 return *(*p)++;
82         }
83 }
84
85
86 static bool ber_read_sequence(unsigned char **p, size_t *buflen, size_t *result) {
87         int tag = ber_read_id(p, buflen);
88         size_t len = ber_read_len(p, buflen);
89
90         if(tag == 0x10) {
91                 if(result) {
92                         *result = len;
93                 }
94
95                 return true;
96         } else {
97                 return false;
98         }
99 }
100
101 static bool ber_read_mpi(unsigned char **p, size_t *buflen, gcry_mpi_t *mpi) {
102         int tag = ber_read_id(p, buflen);
103         size_t len = ber_read_len(p, buflen);
104         gcry_error_t err = 0;
105
106         if(tag != 0x02 || len > *buflen) {
107                 return false;
108         }
109
110         if(mpi) {
111                 err = gcry_mpi_scan(mpi, GCRYMPI_FMT_USG, *p, len, NULL);
112         }
113
114         *p += len;
115         *buflen -= len;
116
117         return mpi ? !err : true;
118 }
119
120 rsa_t *rsa_set_hex_public_key(char *n, char *e) {
121         rsa_t *rsa = xzalloc(sizeof(rsa_t));
122
123         gcry_error_t err = gcry_mpi_scan(&rsa->n, GCRYMPI_FMT_HEX, n, 0, NULL)
124                            ? : gcry_mpi_scan(&rsa->e, GCRYMPI_FMT_HEX, e, 0, NULL);
125
126         if(err) {
127                 logger(DEBUG_ALWAYS, LOG_ERR, "Error while reading RSA public key: %s", gcry_strerror(errno));
128                 free(rsa);
129                 return false;
130         }
131
132         return rsa;
133 }
134
135 rsa_t *rsa_set_hex_private_key(char *n, char *e, char *d) {
136         rsa_t *rsa = xzalloc(sizeof(rsa_t));
137
138         gcry_error_t err = gcry_mpi_scan(&rsa->n, GCRYMPI_FMT_HEX, n, 0, NULL)
139                            ? : gcry_mpi_scan(&rsa->e, GCRYMPI_FMT_HEX, e, 0, NULL)
140                            ? : gcry_mpi_scan(&rsa->d, GCRYMPI_FMT_HEX, d, 0, NULL);
141
142         if(err) {
143                 logger(DEBUG_ALWAYS, LOG_ERR, "Error while reading RSA public key: %s", gcry_strerror(errno));
144                 free(rsa);
145                 return false;
146         }
147
148         return rsa;
149 }
150
151 // Read PEM RSA keys
152
153 rsa_t *rsa_read_pem_public_key(FILE *fp) {
154         uint8_t derbuf[8096], *derp = derbuf;
155         size_t derlen;
156
157         if(!pem_decode(fp, "RSA PUBLIC KEY", derbuf, sizeof(derbuf), &derlen)) {
158                 logger(DEBUG_ALWAYS, LOG_ERR, "Unable to read RSA public key: %s", strerror(errno));
159                 return NULL;
160         }
161
162         rsa_t *rsa = xzalloc(sizeof(rsa_t));
163
164         if(!ber_read_sequence(&derp, &derlen, NULL)
165                         || !ber_read_mpi(&derp, &derlen, &rsa->n)
166                         || !ber_read_mpi(&derp, &derlen, &rsa->e)
167                         || derlen) {
168                 logger(DEBUG_ALWAYS, LOG_ERR, "Error while decoding RSA public key");
169                 free(rsa);
170                 return NULL;
171         }
172
173         return rsa;
174 }
175
176 rsa_t *rsa_read_pem_private_key(FILE *fp) {
177         uint8_t derbuf[8096], *derp = derbuf;
178         size_t derlen;
179
180         if(!pem_decode(fp, "RSA PRIVATE KEY", derbuf, sizeof(derbuf), &derlen)) {
181                 logger(DEBUG_ALWAYS, LOG_ERR, "Unable to read RSA private key: %s", strerror(errno));
182                 return NULL;
183         }
184
185         rsa_t *rsa = xzalloc(sizeof(rsa_t));
186
187         if(!ber_read_sequence(&derp, &derlen, NULL)
188                         || !ber_read_mpi(&derp, &derlen, NULL)
189                         || !ber_read_mpi(&derp, &derlen, &rsa->n)
190                         || !ber_read_mpi(&derp, &derlen, &rsa->e)
191                         || !ber_read_mpi(&derp, &derlen, &rsa->d)
192                         || !ber_read_mpi(&derp, &derlen, NULL) // p
193                         || !ber_read_mpi(&derp, &derlen, NULL) // q
194                         || !ber_read_mpi(&derp, &derlen, NULL)
195                         || !ber_read_mpi(&derp, &derlen, NULL)
196                         || !ber_read_mpi(&derp, &derlen, NULL) // u
197                         || derlen) {
198                 logger(DEBUG_ALWAYS, LOG_ERR, "Error while decoding RSA private key");
199                 free(rsa);
200                 return NULL;
201         }
202
203         return rsa;
204 }
205
206 size_t rsa_size(rsa_t *rsa) {
207         return (gcry_mpi_get_nbits(rsa->n) + 7) / 8;
208 }
209
210 /* Well, libgcrypt has functions to handle RSA keys, but they suck.
211  * So we just use libgcrypt's mpi functions, and do the math ourselves.
212  */
213
214 // TODO: get rid of this macro, properly clean up gcry_ structures after use
215 #define check(foo) { gcry_error_t err = (foo); if(err) {logger(DEBUG_ALWAYS, LOG_ERR, "gcrypt error %s/%s at %s:%d", gcry_strsource(err), gcry_strerror(err), __FILE__, __LINE__); return false; }}
216
217 bool rsa_public_encrypt(rsa_t *rsa, void *in, size_t len, void *out) {
218         gcry_mpi_t inmpi;
219         check(gcry_mpi_scan(&inmpi, GCRYMPI_FMT_USG, in, len, NULL));
220
221         gcry_mpi_t outmpi = gcry_mpi_new(len * 8);
222         gcry_mpi_powm(outmpi, inmpi, rsa->e, rsa->n);
223
224         size_t out_bytes = (gcry_mpi_get_nbits(outmpi) + 7) / 8;
225         size_t pad = len - MIN(out_bytes, len);
226
227         for(; pad; --pad) {
228                 *(char *)out++ = 0;
229         }
230
231         check(gcry_mpi_print(GCRYMPI_FMT_USG, out, len, NULL, outmpi));
232
233         return true;
234 }
235
236 bool rsa_private_decrypt(rsa_t *rsa, void *in, size_t len, void *out) {
237         gcry_mpi_t inmpi;
238         check(gcry_mpi_scan(&inmpi, GCRYMPI_FMT_USG, in, len, NULL));
239
240         gcry_mpi_t outmpi = gcry_mpi_new(len * 8);
241         gcry_mpi_powm(outmpi, inmpi, rsa->d, rsa->n);
242
243         size_t pad = len - (gcry_mpi_get_nbits(outmpi) + 7) / 8;
244
245         for(; pad; --pad) {
246                 *(char *)out++ = 0;
247         }
248
249         check(gcry_mpi_print(GCRYMPI_FMT_USG, out, len, NULL, outmpi));
250
251         return true;
252 }
253
254 void rsa_free(rsa_t *rsa) {
255         if(rsa) {
256                 if(rsa->n) {
257                         gcry_mpi_release(rsa->n);
258                 }
259
260                 if(rsa->e) {
261                         gcry_mpi_release(rsa->e);
262                 }
263
264                 if(rsa->d) {
265                         gcry_mpi_release(rsa->d);
266                 }
267
268                 free(rsa);
269         }
270 }