#include "types.h"
#include "log.h"
#include "lstring.h"
// NOTE(review): the names of the six system headers were lost when this file
// was collapsed; the set below is reconstructed from the library calls used
// in this file. Confirm against the original include list.
#include <assert.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include "wireguard.h"

// Name shown in the usage message; replaced by argv[0] at startup.
const char *program_name = "wg_quicker";

// An IPv4 address stored as its four dotted-decimal octets (a.b.c.d).
struct IP4 {
    u8 a;
    u8 b;
    u8 c;
    u8 d;
};
typedef struct IP4 IP4;

// Everything stored in a VPN's "<name>.txt" data file, one field per line
// and in this order (see VPN_Data_to_String).
struct VPN_Data {
    String name;
    IP4    network;            // network address; main() requires .d == 0 (/24)
    String server_host;
    String server_port;
    String pre_shared_key;     // base64 text
    String server_public_key;  // base64 text
    IP4    last_ip;            // last address handed out (server is .1)
};
typedef struct VPN_Data VPN_Data;

// Number of lines in a serialized VPN_Data record.
enum { VPN_DATA_LINE_COUNT = 7 };

// True for the blank characters tolerated around an address.
static bool Char_is_space(char c) {
    return c == ' ' || c == '\t' || c == '\r' || c == '\n';
}

// Parses a dotted-decimal IPv4 address ("a.b.c.d") from `str`.
//
// Leading and trailing whitespace is accepted; anything else around the
// address, a missing octet, or an octet above 255 makes the parse fail.
// `str` does not need to be zero terminated: only `str->length` bytes are
// examined.
//
// Returns true and stores the address in `*ip` on success; returns false and
// leaves `*ip` untouched on failure.
bool IP4_from_String(String *str, IP4 *ip) {
    char buffer[64];
    unsigned int a, b, c, d;
    int consumed = 0;

    if (str == NULL || ip == NULL || str->text == NULL) {
        return false;
    }

    // Work on a bounded, zero-terminated copy so sscanf cannot read past
    // the end of the string.
    u64 length = str->length;
    if (length >= sizeof(buffer)) {
        return false;
    }
    memcpy(buffer, str->text, length);
    buffer[length] = '\0';

    // Parse into wide integers so that out-of-range octets are detected
    // instead of silently wrapping.
    if (sscanf(buffer, "%u.%u.%u.%u%n", &a, &b, &c, &d, &consumed) != 4) {
        return false;
    }
    if (a > 255 || b > 255 || c > 255 || d > 255) {
        return false;
    }

    // Only whitespace may follow the address.
    for (const char *rest = buffer + consumed; *rest != '\0'; rest++) {
        if (!Char_is_space(*rest)) {
            return false;
        }
    }

    ip->a = (u8) a;
    ip->b = (u8) b;
    ip->c = (u8) c;
    ip->d = (u8) d;
    return true;
}

// Splits `str` into the pieces that are terminated by `divider`.
//
// Text after the last divider is NOT returned as a piece, so "a\nb\n" yields
// two pieces and "a\nb" yields one. The pieces are produced by
// string_substring(str, start, end) and do not include the divider.
// NOTE(review): this assumes string_substring takes a [start, end) index
// range - confirm against lstring.h.
//
// On return `*result_arr` is a malloc'ed array of `*result_count` pieces that
// the caller must free(). If memory cannot be allocated, `*result_arr` is
// NULL and `*result_count` is 0.
void String_split(String *str, char divider, String **result_arr, u64 *result_count) {
    u64 capacity = 8;
    u64 count = 0;
    u64 start = 0;

    *result_arr = NULL;
    *result_count = 0;

    String *res = malloc(sizeof(*res) * capacity);
    if (res == NULL) {
        LogError("Out of memory.");
        return;
    }

    for (u64 i = 0; i < str->length; i++) {
        if (str->text[i] != divider) {
            continue;
        }

        // Grow the backing storage by 1.5x when it is full.
        if (count >= capacity) {
            u64 new_capacity = capacity + capacity / 2;
            String *grown = realloc(res, sizeof(*res) * new_capacity);
            if (grown == NULL) {
                LogError("Out of memory.");
                free(res);
                return;
            }
            res = grown;
            capacity = new_capacity;
        }

        res[count++] = string_substring(str, start, i);
        // The next piece starts right after the divider.
        start = i + 1;
    }

    *result_arr = res;
    *result_count = count;
}

// Parses the contents of a VPN data file (as produced by
// VPN_Data_to_String) into `*vpn`.
//
// NOTE(review): the String fields of the result come from string_trim on
// pieces of `str`; they presumably point into `str`'s buffer rather than own
// a copy, so `str` must outlive `*vpn` - confirm against lstring.h.
//
// Returns true on success. On failure an error is logged, false is returned
// and `*vpn` is left untouched.
bool VPN_Data_from_String(String *str, VPN_Data *vpn) {
    String *lines = NULL;
    u64 lines_count = 0;
    bool ok = false;
    VPN_Data res;

    String_split(str, '\n', &lines, &lines_count);
    if (lines_count != VPN_DATA_LINE_COUNT) {
        LogError("Error parsing data file.");
        goto cleanup;
    }

    res.name = string_trim(&lines[0]);
    if (!IP4_from_String(&lines[1], &res.network)) {
        LogError("Error parsing network address.");
        goto cleanup;
    }
    res.server_host       = string_trim(&lines[2]);
    res.server_port       = string_trim(&lines[3]);
    res.pre_shared_key    = string_trim(&lines[4]);
    res.server_public_key = string_trim(&lines[5]);
    if (!IP4_from_String(&lines[6], &res.last_ip)) {
        LogError("Error parsing last IP address.");
        goto cleanup;
    }

    *vpn = res;
    ok = true;

cleanup:
    // The array of pieces is only needed while parsing.
    free(lines);
    return ok;
}

// Formats `vpn` in the data-file layout into `buffer` (at most `capacity`
// bytes, terminator included). Returns what snprintf returns: the number of
// characters the full text needs, or a negative value on error.
static int VPN_Data_format(char *buffer, size_t capacity, const VPN_Data *vpn) {
    // The precision of "%.*s" must be an int, hence the casts.
    return snprintf(buffer, capacity,
        "%.*s\n"
        "%hhu.%hhu.%hhu.%hhu\n"
        "%.*s\n"
        "%.*s\n"
        "%.*s\n"
        "%.*s\n"
        "%hhu.%hhu.%hhu.%hhu\n",
        (int) vpn->name.length, vpn->name.text,
        vpn->network.a, vpn->network.b, vpn->network.c, vpn->network.d,
        (int) vpn->server_host.length, vpn->server_host.text,
        (int) vpn->server_port.length, vpn->server_port.text,
        (int) vpn->pre_shared_key.length, vpn->pre_shared_key.text,
        (int) vpn->server_public_key.length, vpn->server_public_key.text,
        vpn->last_ip.a, vpn->last_ip.b, vpn->last_ip.c, vpn->last_ip.d
    );
}

// Serializes `vpn` into the seven-line text stored in the data file:
// name, network, server host, server port, pre-shared key, server public
// key, last assigned IP.
//
// The returned String's text comes from dstring_new(); the callers in this
// file release it with free().
String VPN_Data_to_String(VPN_Data *vpn) {
    // First pass only measures, so the buffer is always large enough and the
    // record is never truncated.
    int needed = VPN_Data_format(NULL, 0, vpn);
    u64 capacity = 2048;
    if (needed >= 0 && (u64) needed + 1 > capacity) {
        capacity = (u64) needed + 1;
    }

    DString res = dstring_new(capacity);
    int written = VPN_Data_format(res.text, res.capacity, vpn);
    if (written > 0) {
        // Defensive clamp: never report more characters than were stored.
        u64 stored = (u64) written;
        if (stored >= res.capacity) {
            stored = res.capacity - 1;
        }
        res.length += stored;
    }
    return TO_STRING(res);
}

// Reads the whole content of `file`, from its beginning, into a newly
// malloc'ed String.
//
// When `zero_terminated` is true one extra '\0' byte is appended and is
// counted in the returned length. The caller owns the returned text and must
// free() it.
//
// On failure (seek/tell error, out of memory, short read) an error is logged
// and a String with NULL text and zero length is returned.
String Stream_ReadAll(FILE *file, bool zero_terminated) {
    String empty = {0};
    String res = {0};

    // Get file size
    if (fseek(file, 0, SEEK_END) != 0) {
        LogError("Cannot seek to the end of the file.");
        return empty;
    }
    long end = ftell(file);
    if (end < 0) {
        LogError("Cannot determine the file size.");
        return empty;
    }
    if (fseek(file, 0, SEEK_SET) != 0) {
        LogError("Cannot seek to the start of the file.");
        return empty;
    }
    u64 fsize = (u64) end;
    LogDebug("File size is %lu", (unsigned long) fsize);

    // Reserve memory for the text (never ask malloc for zero bytes, whose
    // result may legitimately be NULL).
    u64 buffer_size = zero_terminated ? fsize + 1 : fsize;
    res.text = malloc(buffer_size > 0 ? buffer_size : 1);
    if (res.text == NULL) {
        LogError("Out of memory.");
        return empty;
    }

    // Actually read data from file
    u64 read = fread(res.text, 1, fsize, file);
    if (read != fsize) {
        LogError("Cannot read the whole file.");
        free(res.text);
        return empty;
    }

    // The terminator goes in the last allocated byte, right after the data.
    if (zero_terminated) {
        res.text[fsize] = '\0';
    }
    res.length = buffer_size;
    return res;
}

// Logs `error_msg` followed by the command line usage.
void Print_ErrorAndUsage(const char *error_msg) {
    LogError(
        "%s\n"
        "Usage: %s <command> ...\n"
        "Commands:\n"
        " new_vpn <vpn_name> <network_address> <server_host> <server_port>\n"
        " add_client <vpn_name> <client_name>\n"
        , error_msg, program_name
    );
}

// Implements "new_vpn <vpn_name> <network_address> <server_host> <server_port>".
//
// Generates the server key pair and the pre-shared key, writes the data file
// "<vpn_name>.txt" in the current directory and the wg-quick configuration
// "/etc/wireguard/<vpn_name>.conf". Neither file may already exist.
//
// Returns the process exit code: 0 on success, 1 for bad arguments, 2 for an
// invalid network or a file that cannot be opened, 3 when a file already
// exists.
static int Command_NewVPN(int argc, char *argv[]) {
    if (argc < 6) {
        Print_ErrorAndUsage("Missing argument.");
        return 1;
    }

    String arg_vpn_net_addr = string_take(argv[3]);

    VPN_Data vpn;
    vpn.name        = string_take(argv[2]);
    vpn.server_host = string_take(argv[4]);
    vpn.server_port = string_take(argv[5]);
    if (!IP4_from_String(&arg_vpn_net_addr, &vpn.network)) {
        LogError("Error parsing argument: network address.");
        return 1;
    }
    if (vpn.network.d != 0) {
        LogError("Address %hhu.%hhu.%hhu.%hhu is not a valid /24 network address.",
            vpn.network.a, vpn.network.b, vpn.network.c, vpn.network.d);
        return 2;
    }

    // Generate private/public key for server and pre shared key
    wg_key_b64_string priv_b64, publ_b64, pre_shared_b64;
    {
        wg_key priv, publ, pre_shared;
        wg_generate_private_key(priv);
        wg_generate_public_key(publ, priv);
        wg_generate_preshared_key(pre_shared);
        wg_key_to_base64(priv_b64, priv);
        wg_key_to_base64(publ_b64, publ);
        wg_key_to_base64(pre_shared_b64, pre_shared);
    }
    vpn.pre_shared_key    = string_take(pre_shared_b64);
    vpn.server_public_key = string_take(publ_b64);
    // The server takes the first address of the network.
    vpn.last_ip   = vpn.network;
    vpn.last_ip.d = 1;

    int exit_code = 0;
    FILE *data_f = NULL;
    FILE *wg_config_f = NULL;
    String vpn_str = VPN_Data_to_String(&vpn);
    String data_filename = string_concat(
        2,
        vpn.name.text, vpn.name.length,
        ".txt", rstring_length(".txt")
    );
    String wg_config_filename = string_concat(
        3,
        "/etc/wireguard/", rstring_length("/etc/wireguard/"),
        vpn.name.text, vpn.name.length,
        ".conf", rstring_length(".conf")
    );

    // Check both targets before creating anything, so that a failure does
    // not leave a data file behind without its configuration.
    if (access(data_filename.text, F_OK) == 0) {
        LogError("File \"%s\" already exists.", data_filename.text);
        exit_code = 3;
        goto cleanup;
    }
    if (access(wg_config_filename.text, F_OK) == 0) {
        LogError("File \"%s\" already exists.", wg_config_filename.text);
        exit_code = 3;
        goto cleanup;
    }

    data_f = fopen(data_filename.text, "w");
    if (!data_f) {
        LogError("Cannot open \"%s\"", data_filename.text);
        exit_code = 2;
        goto cleanup;
    }
    // NOTE(review): this file holds the server private key but is created
    // with the permissions allowed by the current umask; consider restricting
    // it to the owner.
    wg_config_f = fopen(wg_config_filename.text, "w");
    if (!wg_config_f) {
        LogError("Cannot open \"%s\"", wg_config_filename.text);
        // Remove the (still empty) data file created just above.
        fclose(data_f);
        data_f = NULL;
        remove(data_filename.text);
        exit_code = 2;
        goto cleanup;
    }

    // Save config data
    fprintf(data_f, "%.*s", (int) vpn_str.length, vpn_str.text);

    // Create wg-quick configuration file
    fprintf(wg_config_f, "[Interface]\n");
    fprintf(wg_config_f, "Address = %hhu.%hhu.%hhu.%hhu/24\n",
        vpn.last_ip.a, vpn.last_ip.b, vpn.last_ip.c, vpn.last_ip.d);
    fprintf(wg_config_f, "PrivateKey = %s\n", priv_b64);
    fprintf(wg_config_f, "PostUp = firewall-cmd --zone=public --add-port=%.*s/udp\n",
        (int) vpn.server_port.length, vpn.server_port.text);
    fprintf(wg_config_f, "PostDown = firewall-cmd --zone=public --remove-port=%.*s/udp\n",
        (int) vpn.server_port.length, vpn.server_port.text);
    fprintf(wg_config_f, "ListenPort = %.*s\n",
        (int) vpn.server_port.length, vpn.server_port.text);

    printf("You can now activate the Wireguard service with the command \"systemctl start wg-quick@%.*s\"\n",
        (int) vpn.name.length, vpn.name.text);

cleanup:
    if (wg_config_f) {
        fclose(wg_config_f);
    }
    if (data_f) {
        fclose(data_f);
    }
    free(wg_config_filename.text);
    free(data_filename.text);
    free(vpn_str.text);
    return exit_code;
}

// Implements "add_client <vpn_name> <client_name>".
//
// Reads "<vpn_name>.txt", assigns the next free address, appends a [Peer]
// section to "/etc/wireguard/<name>.conf", writes "<client_name>.conf" in the
// current directory and stores the updated last address back into the data
// file.
//
// Returns the process exit code: 0 on success, 1 for bad arguments, 2 for
// I/O, parse or address-exhaustion errors, 3 when the server configuration
// file does not exist.
static int Command_AddClient(int argc, char *argv[]) {
    if (argc < 4) {
        Print_ErrorAndUsage("Missing argument.");
        return 1;
    }
    String vpn_name    = string_take(argv[2]);
    String client_name = string_take(argv[3]);

    // Everything released in `cleanup` starts out empty so that every error
    // path can jump there.
    int exit_code = 2;
    FILE *data_f = NULL;
    FILE *wg_config_f = NULL;
    FILE *client_conf_f = NULL;
    String vpn_str = {0};
    String vpn_str_bis = {0};
    String wg_config_filename = {0};
    String client_conf_path = {0};
    VPN_Data vpn;
    wg_key_b64_string priv_b64, publ_b64;

    // Read data file
    String data_filename = string_concat(
        2,
        vpn_name.text, vpn_name.length,
        ".txt", rstring_length(".txt")
    );
    data_f = fopen(data_filename.text, "r+");
    if (!data_f) {
        LogError("Cannot open \"%s\"", data_filename.text);
        goto cleanup;
    }
    vpn_str = Stream_ReadAll(data_f, true);
    if (!VPN_Data_from_String(&vpn_str, &vpn)) {
        LogError("Cannot parse data file.");
        goto cleanup;
    }

    if (vpn.last_ip.d >= 254) {
        LogError("Address space full. (You already generated configs for 253 clients)");
        goto cleanup;
    }
    vpn.last_ip.d += 1;

    // Generate client private and public keys
    {
        wg_key priv, publ;
        wg_generate_private_key(priv);
        wg_key_to_base64(priv_b64, priv);
        wg_generate_public_key(publ, priv);
        wg_key_to_base64(publ_b64, publ);
    }

    // The existence check must come before fopen: append mode would create
    // the missing file and hide the problem.
    wg_config_filename = string_concat(
        3,
        "/etc/wireguard/", rstring_length("/etc/wireguard/"),
        vpn.name.text, vpn.name.length,
        ".conf", rstring_length(".conf")
    );
    if (access(wg_config_filename.text, F_OK) != 0) {
        LogError("File \"%s\" does not exist.", wg_config_filename.text);
        exit_code = 3;
        goto cleanup;
    }
    wg_config_f = fopen(wg_config_filename.text, "a");
    if (!wg_config_f) {
        LogError("Cannot open \"%s\"", wg_config_filename.text);
        goto cleanup;
    }

    // Open the client file before touching the server configuration, so the
    // peer is not registered when its config cannot be written.
    // NOTE(review): this file holds the client private key but is created
    // with the permissions allowed by the current umask.
    client_conf_path = string_concat(
        2,
        client_name.text, client_name.length,
        ".conf", rstring_length(".conf")
    );
    client_conf_f = fopen(client_conf_path.text, "w");
    if (!client_conf_f) {
        LogError("Cannot open \"%s\"", client_conf_path.text);
        goto cleanup;
    }

    // Update server config file
    fprintf(wg_config_f, "\n");
    fprintf(wg_config_f, "[Peer]\n");
    fprintf(wg_config_f, "# User: %.*s\n", (int) client_name.length, client_name.text);
    fprintf(wg_config_f, "PublicKey = %s\n", publ_b64);
    fprintf(wg_config_f, "PresharedKey = %.*s\n",
        (int) vpn.pre_shared_key.length, vpn.pre_shared_key.text);
    fprintf(wg_config_f, "AllowedIPs = %hhu.%hhu.%hhu.%hhu/32\n",
        vpn.last_ip.a, vpn.last_ip.b, vpn.last_ip.c, vpn.last_ip.d);

    // Create client config file
    fprintf(client_conf_f, "[Interface]\n");
    fprintf(client_conf_f, "Address = %hhu.%hhu.%hhu.%hhu/32\n",
        vpn.last_ip.a, vpn.last_ip.b, vpn.last_ip.c, vpn.last_ip.d);
    fprintf(client_conf_f, "PrivateKey = %s\n", priv_b64);
    fprintf(client_conf_f, "\n");
    fprintf(client_conf_f, "[Peer]\n");
    fprintf(client_conf_f, "PublicKey = %.*s\n",
        (int) vpn.server_public_key.length, vpn.server_public_key.text);
    fprintf(client_conf_f, "PresharedKey = %.*s\n",
        (int) vpn.pre_shared_key.length, vpn.pre_shared_key.text);
    fprintf(client_conf_f, "AllowedIPs = %hhu.%hhu.%hhu.%hhu/24\n",
        vpn.network.a, vpn.network.b, vpn.network.c, vpn.network.d);
    fprintf(client_conf_f, "\n");
    fprintf(client_conf_f, "Endpoint = %.*s:%.*s\n",
        (int) vpn.server_host.length, vpn.server_host.text,
        (int) vpn.server_port.length, vpn.server_port.text);
    fprintf(client_conf_f, "PersistentKeepalive = 30\n");

    // Save config data (last ip changed). Losing this update would hand the
    // same address to the next client, so the write is checked.
    vpn_str_bis = VPN_Data_to_String(&vpn);
    if (fseek(data_f, 0, SEEK_SET) != 0
        || fprintf(data_f, "%.*s", (int) vpn_str_bis.length, vpn_str_bis.text) < 0
        || fflush(data_f) != 0) {
        LogError("Cannot write \"%s\"", data_filename.text);
        goto cleanup;
    }

    printf("Client config file created: \"%.*s\".\n",
        (int) client_conf_path.length, client_conf_path.text);
    printf("Remember to restart the Wireguard service with the command \"systemctl restart wg-quick@%.*s\" to apply the changes.\n",
        (int) vpn.name.length, vpn.name.text);
    exit_code = 0;

cleanup:
    if (client_conf_f) {
        fclose(client_conf_f);
    }
    if (wg_config_f) {
        fclose(wg_config_f);
    }
    if (data_f) {
        fclose(data_f);
    }
    free(client_conf_path.text);
    free(wg_config_filename.text);
    free(vpn_str_bis.text);
    // `vpn` may reference this buffer, so it is released last.
    free(vpn_str.text);
    free(data_filename.text);
    return exit_code;
}

// Entry point: dispatches on the command given as first argument.
// Returns 0 on success and a non-zero exit code on error.
int main(int argc, char *argv[]) {
    // Get program name from args
    if (argc < 1) {
        LogError("Internal error (missing program name from args. This should never happen. There is a bug.)");
        return 1;
    }
    program_name = argv[0];

    // Get command name from args
    if (argc < 2) {
        Print_ErrorAndUsage("Missing command.");
        return 1;
    }
    String command = string_take(argv[1]);

    if (string_equal(command, string_take("new_vpn"))) {
        return Command_NewVPN(argc, argv);
    }
    if (string_equal(command, string_take("add_client"))) {
        return Command_AddClient(argc, argv);
    }

    Print_ErrorAndUsage("Unrecognized command.");
    return 1;
}