From: Kirill Isakov Date: Mon, 11 Apr 2022 18:41:29 +0000 (+0600) Subject: Reduce duplication in request handler tables X-Git-Url: https://tinc-vpn.org/git/browse?a=commitdiff_plain;h=56621be326497d56db0c4c372ae3cc497018cfcf;p=tinc Reduce duplication in request handler tables Request handlers and request names are now grouped together so there's less chance of messing up the order (however unlikely it may have been). --- diff --git a/src/protocol.c b/src/protocol.c index 02dbf5bf..3c006d5c 100644 --- a/src/protocol.c +++ b/src/protocol.c @@ -33,30 +33,47 @@ bool tunnelserver = false; bool strictsubnets = false; bool experimental = true; -/* Jumptable for the request handlers */ - -static bool (*request_handlers[])(connection_t *, const char *) = { - id_h, metakey_h, challenge_h, chal_reply_h, ack_h, - NULL, NULL, termreq_h, - ping_h, pong_h, - add_subnet_h, del_subnet_h, - add_edge_h, del_edge_h, - key_changed_h, req_key_h, ans_key_h, tcppacket_h, control_h, - NULL, NULL, /* Not "real" requests (yet) */ - sptps_tcppacket_h, - udp_info_h, mtu_info_h, -}; +static inline bool is_valid_request(request_t req) { + return req > ALL && req < LAST; +} -/* Request names */ +/* Request handlers */ +const request_entry_t *get_request_entry(request_t req) { + if(!is_valid_request(req)) { + logger(DEBUG_ALWAYS, LOG_ERR, "Invalid request %d", req); + return NULL; + } -static const char (*request_name[]) = { - "ID", "METAKEY", "CHALLENGE", "CHAL_REPLY", "ACK", - "STATUS", "ERROR", "TERMREQ", - "PING", "PONG", - "ADD_SUBNET", "DEL_SUBNET", - "ADD_EDGE", "DEL_EDGE", "KEY_CHANGED", "REQ_KEY", "ANS_KEY", "PACKET", "CONTROL", - "REQ_PUBKEY", "ANS_PUBKEY", "SPTPS_PACKET", "UDP_INFO", "MTU_INFO", -}; + // Prevent user from accessing the table directly to always have bound checks + static const request_entry_t request_entries[] = { + [ID] = {id_h, "ID"}, + [METAKEY] = {metakey_h, "METAKEY"}, + [CHALLENGE] = {challenge_h, "CHALLENGE"}, + [CHAL_REPLY] = {chal_reply_h, "CHAL_REPLY"}, + [ACK] = {ack_h, "ACK"}, + [STATUS] = {NULL, "STATUS"}, + [ERROR] = {NULL, "ERROR"}, + [TERMREQ] = {termreq_h, "TERMREQ"}, + [PING] = {ping_h, "PING"}, + [PONG] = {pong_h, "PONG"}, + [ADD_SUBNET] = {add_subnet_h, "ADD_SUBNET"}, + [DEL_SUBNET] = {del_subnet_h, "DEL_SUBNET"}, + [ADD_EDGE] = {add_edge_h, "ADD_EDGE"}, + [DEL_EDGE] = {del_edge_h, "DEL_EDGE"}, + [KEY_CHANGED] = {key_changed_h, "KEY_CHANGED"}, + [REQ_KEY] = {req_key_h, "REQ_KEY"}, + [ANS_KEY] = {ans_key_h, "ANS_KEY"}, + [PACKET] = {tcppacket_h, "PACKET"}, + [CONTROL] = {control_h, "CONTROL"}, + /* Not "real" requests yet */ + [REQ_PUBKEY] = {NULL, "REQ_PUBKEY"}, + [ANS_PUBKEY] = {NULL, "ANS_PUBKEY"}, + [SPTPS_PACKET] = {sptps_tcppacket_h, "SPTPS_PACKET"}, + [UDP_INFO] = {udp_info_h, "UDP_INFO"}, + [MTU_INFO] = {mtu_info_h, "MTU_INFO"}, + }; + return &request_entries[req]; +} static int past_request_compare(const past_request_t *a, const past_request_t *b) { return strcmp(a->request, b->request); @@ -96,7 +113,7 @@ bool send_request(connection_t *c, const char *format, ...) { } int id = atoi(request); - logger(DEBUG_META, LOG_DEBUG, "Sending %s to %s (%s): %s", request_name[id], c->name, c->hostname, request); + logger(DEBUG_META, LOG_DEBUG, "Sending %s to %s (%s): %s", get_request_entry(id)->name, c->name, c->hostname, request); request[len++] = '\n'; @@ -114,7 +131,7 @@ bool send_request(connection_t *c, const char *format, ...) { } void forward_request(connection_t *from, const char *request) { - logger(DEBUG_META, LOG_DEBUG, "Forwarding %s from %s (%s): %s", request_name[atoi(request)], from->name, from->hostname, request); + logger(DEBUG_META, LOG_DEBUG, "Forwarding %s from %s (%s): %s", get_request_entry(atoi(request))->name, from->name, from->hostname, request); // Create a temporary newline-terminated copy of the request size_t len = strlen(request); @@ -145,23 +162,24 @@ bool receive_request(connection_t *c, const char *request) { int reqno = atoi(request); if(reqno || *request == '0') { - if((reqno < 0) || (reqno >= LAST) || !request_handlers[reqno]) { + if(!is_valid_request(reqno) || !get_request_entry(reqno)->handler) { logger(DEBUG_META, LOG_DEBUG, "Unknown request from %s (%s): %s", c->name, c->hostname, request); return false; - } else { - logger(DEBUG_META, LOG_DEBUG, "Got %s from %s (%s): %s", request_name[reqno], c->name, c->hostname, request); } + const request_entry_t *entry = get_request_entry(reqno); + logger(DEBUG_META, LOG_DEBUG, "Got %s from %s (%s): %s", entry->name, c->name, c->hostname, request); + if((c->allow_request != ALL) && (c->allow_request != reqno)) { logger(DEBUG_ALWAYS, LOG_ERR, "Unauthorized request from %s (%s)", c->name, c->hostname); return false; } - if(!request_handlers[reqno](c, request)) { + if(!entry->handler(c, request)) { /* Something went wrong. Probably scriptkiddies. Terminate. */ if(reqno != TERMREQ) { - logger(DEBUG_ALWAYS, LOG_ERR, "Error while processing %s from %s (%s)", request_name[reqno], c->name, c->hostname); + logger(DEBUG_ALWAYS, LOG_ERR, "Error while processing %s from %s (%s)", entry->name, c->name, c->hostname); } return false; diff --git a/src/protocol.h b/src/protocol.h index 392a1fe8..ced6554f 100644 --- a/src/protocol.h +++ b/src/protocol.h @@ -22,11 +22,14 @@ */ #include "ecdsa.h" +#include "connection.h" /* Protocol version. Different major versions are incompatible. */ #define PROT_MAJOR 17 -#define PROT_MINOR 7 /* Should not exceed 255! */ +#define PROT_MINOR 7 + +STATIC_ASSERT(PROT_MINOR <= 255, "PROT_MINOR must not exceed 255"); /* Silly Windows */ @@ -53,11 +56,18 @@ typedef enum request_t { LAST /* Guardian for the highest request number */ } request_t; +typedef bool (request_handler_t)(connection_t *c, const char *request); + typedef struct past_request_t { const char *request; time_t firstseen; } past_request_t; +typedef struct { + request_handler_t *const handler; + const char *name; +} request_entry_t; + extern bool tunnelserver; extern bool strictsubnets; extern bool experimental; @@ -87,11 +97,12 @@ extern bool receive_request(struct connection_t *c, const char *request); extern void exit_requests(void); extern bool seen_request(const char *request); +extern const request_entry_t *get_request_entry(request_t req); + /* Requests */ extern bool send_id(struct connection_t *c); extern bool send_metakey(struct connection_t *c); -extern bool send_metakey_ec(struct connection_t *c); extern bool send_challenge(struct connection_t *c); extern bool send_chal_reply(struct connection_t *c); extern bool send_ack(struct connection_t *c); @@ -112,27 +123,25 @@ extern bool send_mtu_info(struct node_t *from, struct node_t *to, int mtu); /* Request handlers */ -extern bool id_h(struct connection_t *c, const char *request); -extern bool metakey_h(struct connection_t *c, const char *request); -extern bool challenge_h(struct connection_t *c, const char *request); -extern bool chal_reply_h(struct connection_t *c, const char *request); -extern bool ack_h(struct connection_t *c, const char *request); -extern bool status_h(struct connection_t *c, const char *request); -extern bool error_h(struct connection_t *c, const char *request); -extern bool termreq_h(struct connection_t *c, const char *request); -extern bool ping_h(struct connection_t *c, const char *request); -extern bool pong_h(struct connection_t *c, const char *request); -extern bool add_subnet_h(struct connection_t *c, const char *request); -extern bool del_subnet_h(struct connection_t *c, const char *request); -extern bool add_edge_h(struct connection_t *c, const char *request); -extern bool del_edge_h(struct connection_t *c, const char *request); -extern bool key_changed_h(struct connection_t *c, const char *request); -extern bool req_key_h(struct connection_t *c, const char *request); -extern bool ans_key_h(struct connection_t *c, const char *request); -extern bool tcppacket_h(struct connection_t *c, const char *request); -extern bool sptps_tcppacket_h(struct connection_t *c, const char *request); -extern bool control_h(struct connection_t *c, const char *request); -extern bool udp_info_h(struct connection_t *c, const char *request); -extern bool mtu_info_h(struct connection_t *c, const char *request); +extern request_handler_t id_h; +extern request_handler_t metakey_h; +extern request_handler_t challenge_h; +extern request_handler_t chal_reply_h; +extern request_handler_t ack_h; +extern request_handler_t termreq_h; +extern request_handler_t ping_h; +extern request_handler_t pong_h; +extern request_handler_t add_subnet_h; +extern request_handler_t del_subnet_h; +extern request_handler_t add_edge_h; +extern request_handler_t del_edge_h; +extern request_handler_t key_changed_h; +extern request_handler_t req_key_h; +extern request_handler_t ans_key_h; +extern request_handler_t tcppacket_h; +extern request_handler_t sptps_tcppacket_h; +extern request_handler_t control_h; +extern request_handler_t udp_info_h; +extern request_handler_t mtu_info_h; #endif diff --git a/test/unit/meson.build b/test/unit/meson.build index 120d527a..0240dc96 100644 --- a/test/unit/meson.build +++ b/test/unit/meson.build @@ -27,6 +27,9 @@ tests = { 'subnet': { 'code': 'test_subnet.c', }, + 'protocol': { + 'code': 'test_protocol.c', + }, 'splay_tree': { 'code': 'test_splay_tree.c', 'link': link_tinc, diff --git a/test/unit/test_protocol.c b/test/unit/test_protocol.c new file mode 100644 index 00000000..af575141 --- /dev/null +++ b/test/unit/test_protocol.c @@ -0,0 +1,27 @@ +#include "unittest.h" +#include "../../src/protocol.h" + +static void test_get_invalid_request(void **state) { + (void)state; + + assert_null(get_request_entry(ALL)); + assert_null(get_request_entry(LAST)); +} + +static void test_get_valid_request_returns_nonnull(void **state) { + (void)state; + + for(request_t req = ID; req < LAST; ++req) { + const request_entry_t *ent = get_request_entry(req); + assert_non_null(ent); + assert_non_null(ent->name); + } +} + +int main(void) { + const struct CMUnitTest tests[] = { + cmocka_unit_test(test_get_invalid_request), + cmocka_unit_test(test_get_valid_request_returns_nonnull), + }; + return cmocka_run_group_tests(tests, NULL, NULL); +}