diff --git a/wg-manager-backend/requirements.txt b/wg-manager-backend/requirements.txt index 95fd9a5..c0ec4c0 100644 --- a/wg-manager-backend/requirements.txt +++ b/wg-manager-backend/requirements.txt @@ -20,3 +20,4 @@ qrcode[pil] alembic loguru ldap3==2.9 +pywireguard==0.1.3 diff --git a/wg-manager-backend/script/wireguard.py b/wg-manager-backend/script/wireguard.py index 6e23c22..2038438 100644 --- a/wg-manager-backend/script/wireguard.py +++ b/wg-manager-backend/script/wireguard.py @@ -5,6 +5,8 @@ import tempfile import requests import typing import configparser +import warnings +import base64 from sqlalchemy.orm import Session @@ -14,6 +16,7 @@ import os import re import ipaddress import util +import pywireguard from database import models from database.database import SessionLocal @@ -52,6 +55,7 @@ class TempServerFile(): def _run_wg(server: schemas.WGServer, command): + warnings.DeprecationWarning("_run_wg will be depretated in favor of pywireguard") try: output = subprocess.check_output(const.CMD_WG_COMMAND + command, stderr=subprocess.STDOUT) return output @@ -65,49 +69,34 @@ def is_installed(): return output == b'' or b'interface' in output -def generate_keys() -> typing.Dict[str, str]: - private_key = subprocess.check_output(const.CMD_WG_COMMAND + ["genkey"]) - public_key = subprocess.check_output( - const.CMD_WG_COMMAND + ["pubkey"], - input=private_key - ) - - private_key = private_key.decode("utf-8").strip() - public_key = public_key.decode("utf-8").strip() - return dict( - private_key=private_key, - public_key=public_key - ) - +def generate_keys() -> typing.Dict[str, str, str]: + return pywireguard.generate_keys() def generate_psk(): - return subprocess.check_output(const.CMD_WG_COMMAND + ["genpsk"]).decode("utf-8").strip() + return generate_keys()["preshared_key"] def start_interface(server: typing.Union[schemas.WGServer, schemas.WGPeer]): - with TempServerFile(server) as server_file: - try: - # print(*const.CMD_WG_QUICK, "up", server_file) - output = subprocess.check_output(const.CMD_WG_QUICK + ["up", server_file], stderr=subprocess.STDOUT) - return output - except Exception as e: - print(e.output) - if b'already exists' in e.output: - raise WGAlreadyStartedError("The wireguard device %s is already started." % server.interface) - elif b'Address already in use' in e.output: - raise WGPortAlreadyInUse("The port %s is already used by another application." % server.listen_port) + device = pywireguard.Device(server.interface) + if not device.has_private_key: + device.private_key = base64.b64encode(server.private_key) + if not device.has_listen_port: + device.listen_port = server.listen_port + # Remove existing peers + device.clear_peers() + for peer in server.peers: + add_peer(server, peer) -def stop_interface(server: schemas.WGServer): - with TempServerFile(server) as server_file: - try: - output = subprocess.check_output(const.CMD_WG_QUICK + ["down", server_file], stderr=subprocess.STDOUT) - return output - except Exception as e: - if b'is not a WireGuard interface' in e.output: - raise WGAlreadyStoppedError("The wireguard device %s is already stopped." % server.interface) + #TODO Post up and post down + + device.update() +def stop_interface(server: schemas.WGServer): + device = pywireguard.Device(server.interface) + device.delete() + def restart_interface(server: schemas.WGServer): try: stop_interface(server) @@ -129,21 +118,33 @@ def is_running(server: schemas.WGServer): def add_peer(server: schemas.WGServer, peer: schemas.WGPeer): - try: - output = _run_wg(server, ["set", server.interface, "peer", peer.public_key, "allowed-ips", peer.address]) - return output == b'' - except Exception as e: - _LOGGER.exception(e) - return False + device = pywireguard.Device(server.interface) + wgpeer = pywireguard.Peer(public_key=peer.public_key) + if peer.shared_key: + wgpeer.preshared_key = peer.shared_key + # TODO: Set peer endpoint. Need to wait for pywireguard implementation + for allowedip in peer.allowed_ips: + splited = allowedip.split("/") + ipaddr = splited[0] + if len(splited) == 2: + cidr = splited[1] + else: + cidr = 32 + wgallowed = pywireguard.AllowedIP(ip=ipaddr, cidr=cidr) + peer.add_allowed_ip(wgallowed) + device.add_peer(wgpeer) def remove_peer(server: schemas.WGServer, peer: schemas.WGPeer): - try: - output = _run_wg(server, ["set", server.interface, "peer", peer.public_key, "remove"]) - return output == b'' - except Exception as e: - _LOGGER.exception(e) - return False + removed = False + device = pywireguard.Device(server.interface) + for wgpeer in device.get_peers(): + if wgpeer.public_key == peer.public_key.encode(): + wgpeer.remove_me() + removed = True + break + device.update() + return removed def get_stats(server: schemas.WGServer): @@ -207,6 +208,7 @@ def move_server_dir(interface, interface1): def generate_config(obj: typing.Union[typing.Dict[schemas.WGPeer, schemas.WGServer], schemas.WGServer]): + warnings.DeprecationWarning("generate_config will be depretated in favor of pywireguard") if isinstance(obj, dict) and "server" in obj and "peer" in obj: template = "peer.j2" is_ipv6 = obj["server"].v6_address is not None