"""Helpers for managing WireGuard interfaces, peers and client configurations.

Most operations go through ``pywireguard``; ``_run_wg`` and ``generate_config``
are legacy paths kept for backwards compatibility.
"""
import base64
import configparser
import ipaddress
import logging
import os
import random
import re
import subprocess
import tempfile
import typing
import warnings

import pywireguard
import requests
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound

import const
import schemas
import util
from database import models
from database.database import SessionLocal

_LOGGER = logging.getLogger(__name__)

# Upper bound (seconds) for the HTTP call that fetches a client configuration,
# so an unreachable server cannot block the caller forever.
_CLIENT_CONF_TIMEOUT = 30


class WGAlreadyStartedError(Exception):
    pass


class WGAlreadyStoppedError(Exception):
    pass


class WGPermissionsError(Exception):
    pass


class WGPortAlreadyInUse(Exception):
    pass


class TempServerFile:
    """Context manager that writes ``server.configuration`` to a temporary file.

    Entering returns the path of ``<interface>.conf`` inside a fresh temporary
    directory; leaving removes the directory again.
    """

    def __init__(self, server: schemas.WGServer):
        self.server = server
        self.td = tempfile.TemporaryDirectory(prefix="wg_man_")
        self.server_file = os.path.join(self.td.name, f"{server.interface}.conf")

    def __enter__(self):
        with open(self.server_file, "w+") as f:
            f.write(self.server.configuration)
        return self.server_file

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.td.cleanup()


def _run_wg(server: schemas.WGServer, command):
    """Run the ``wg`` command line tool and return its combined output.

    :param server: server whose interface name is used in error messages.
    :param command: list of arguments appended to ``const.CMD_WG_COMMAND``.
    :return: raw bytes written by the command, or ``None`` when the command
        exited with a non-zero status for a reason other than permissions.
    :raises WGPermissionsError: if the output contains "Operation not permitted".
    """
    warnings.warn(
        "_run_wg will be deprecated in favor of pywireguard",
        DeprecationWarning,
        stacklevel=2,
    )
    try:
        return subprocess.check_output(
            const.CMD_WG_COMMAND + command, stderr=subprocess.STDOUT
        )
    except subprocess.CalledProcessError as e:
        if b'Operation not permitted' in (e.output or b''):
            raise WGPermissionsError(
                "The user has insufficient permissions for interface %s" % server.interface
            ) from e
        # Any other failure (e.g. the interface does not exist) is reported to
        # callers as "no output".
        return None


def is_installed():
    """Return True if running ``const.CMD_WG_COMMAND`` looks like a working ``wg``.

    A missing executable yields ``False`` instead of an exception.
    """
    try:
        output = subprocess.check_output(const.CMD_WG_COMMAND)
    except FileNotFoundError:
        return False
    return output == b'' or b'interface' in output


def generate_keys() -> typing.Dict[str, str]:
    """Return a freshly generated key set as produced by ``pywireguard.generate_keys``."""
    return pywireguard.generate_keys()


def generate_psk():
    """Return only the ``"preshared_key"`` entry of a freshly generated key set."""
    return generate_keys()["preshared_key"]


def start_interface(server: typing.Union[schemas.WGServer, schemas.WGPeer]):
    """Configure the WireGuard device for ``server`` and (re)register its peers."""
    device = pywireguard.Device(server.interface)

    if not device.has_private_key:
        # NOTE(review): base64.b64encode() only accepts bytes-like input; if
        # ``server.private_key`` is a str this raises TypeError -- confirm the
        # format pywireguard expects before changing it.
        device.private_key = base64.b64encode(server.private_key)

    if not device.has_listen_port:
        device.listen_port = server.listen_port

    # Remove existing peers
    device.clear_peers()

    for peer in server.peers:
        add_peer(server, peer)

    # TODO Post up and post down
    device.update()


def stop_interface(server: schemas.WGServer):
    """Delete the WireGuard device that belongs to ``server``."""
    device = pywireguard.Device(server.interface)
    device.delete()


def restart_interface(server: schemas.WGServer):
    """Stop (ignoring "already stopped") and then start the interface."""
    try:
        stop_interface(server)
    except WGAlreadyStoppedError:
        pass
    start_interface(server)


def is_running(server: schemas.WGServer):
    """Return True if ``wg show <interface>`` succeeds for ``server``.

    :raises WGPermissionsError: propagated from ``_run_wg`` when the current
        user may not query the interface.
    """
    output = _run_wg(server, ["show", server.interface])
    if output is None or b'Unable to access interface: No such device' in output:
        return False
    return True


def add_peer(server: schemas.WGServer, peer: schemas.WGPeer):
    """Register ``peer`` (key, optional PSK and allowed IPs) on the server device.

    ``peer.allowed_ips`` may be an iterable of ``"ip[/cidr]"`` strings or a
    single comma-separated string. Entries without a prefix length default to
    a host route (/32 for IPv4, /128 for IPv6).
    """
    device = pywireguard.Device(server.interface)
    wgpeer = pywireguard.Peer(public_key=peer.public_key)

    if peer.shared_key:
        wgpeer.preshared_key = peer.shared_key

    # TODO: Set peer endpoint. Need to wait for pywireguard implementation

    allowed_ips = peer.allowed_ips or []
    if isinstance(allowed_ips, str):
        # Iterating a str would yield single characters, never addresses.
        allowed_ips = [item.strip() for item in allowed_ips.split(",") if item.strip()]

    for allowedip in allowed_ips:
        ipaddr, sep, prefix = allowedip.partition("/")
        if sep:
            cidr = int(prefix)
        else:
            cidr = 128 if ":" in ipaddr else 32
        wgallowed = pywireguard.AllowedIP(ip=ipaddr, cidr=cidr)
        # The allowed IP belongs on the pywireguard peer being built, not on
        # the schema object that was passed in.
        wgpeer.add_allowed_ip(wgallowed)

    device.add_peer(wgpeer)


def remove_peer(server: schemas.WGServer, peer: schemas.WGPeer):
    """Remove ``peer`` from the server device.

    :return: True if a peer with a matching public key was found and removed.
    """
    removed = False
    device = pywireguard.Device(server.interface)
    for wgpeer in device.get_peers():
        if wgpeer.public_key == peer.public_key.encode():
            wgpeer.remove_me()
            removed = True
            break

    device.update()
    return removed


def get_stats(server: schemas.WGServer):
    """Parse ``wg show <interface>`` into a list of per-peer dictionaries.

    Each dictionary has the keys ``public_key``, ``client_endpoint``,
    ``allowed_ips``, ``handshake``, ``rx`` and ``tx``. Peers whose block has
    an unexpected number of lines are skipped with a warning. Any error while
    running or parsing is logged and results in an empty list.
    """
    try:
        output = _run_wg(server, ["show", server.interface])
        if not output:
            return []

        regex = r"peer:.*?^\n"
        test_str = output.decode("utf-8") + "\n"
        peers = []

        peers_raw = re.findall(regex, test_str, re.MULTILINE | re.DOTALL)

        for peer in peers_raw:
            peer = peer.strip()
            # maxsplit=1 keeps values that themselves contain ": " intact.
            lines = [x.split(": ", 1)[1] for x in peer.split("\n")]

            if len(lines) == 2:
                public_key, allowed_ips = lines
                peers.append(dict(
                    public_key=public_key,
                    client_endpoint=None,
                    allowed_ips=allowed_ips,
                    handshake=None,
                    rx=None,
                    tx=None
                ))
            elif len(lines) in (5, 6):
                public_key = lines[0]
                client_endpoint, allowed_ips, handshake, rx_tx = lines[-4:]  # [1] is sometimes psk

                rx = re.match(r"^(.*) received", rx_tx).group(1)
                tx = re.match(r"^.*, (.*)sent", rx_tx).group(1)
                peers.append(dict(
                    public_key=public_key,
                    client_endpoint=client_endpoint,
                    allowed_ips=allowed_ips,
                    handshake=handshake,
                    rx=rx,
                    tx=tx
                ))
            else:
                _LOGGER.warning(
                    "We have not handled peers with line number of %s", len(lines)
                )

        return peers
    except Exception as e:
        _LOGGER.exception(e)
        return []


def move_server_dir(interface, interface1):
    """Rename a server's configuration file and directory from ``interface`` to ``interface1``."""
    old_server_dir = const.SERVER_DIR(interface)
    old_server_file = const.SERVER_FILE(interface)
    new_server_dir = const.SERVER_DIR(interface1)
    new_server_file = old_server_file.replace(f"{interface}.conf", f"{interface1}.conf")
    # Rename the file first: it still lives inside the old directory.
    os.rename(old_server_file, new_server_file)
    os.rename(old_server_dir, new_server_dir)


def generate_config(obj: typing.Union[typing.Dict[str, typing.Any], schemas.WGServer]):
    """Render a WireGuard configuration from the jinja templates.

    :param obj: either a server object (rendered with ``server.j2``) or a dict
        with the keys ``"server"`` and ``"peer"`` (rendered with ``peer.j2``).
    :return: the rendered configuration text.
    :raises ValueError: if ``obj`` is neither of the supported shapes.
    """
    warnings.warn(
        "generate_config will be deprecated in favor of pywireguard",
        DeprecationWarning,
        stacklevel=2,
    )
    if isinstance(obj, dict) and "server" in obj and "peer" in obj:
        template = "peer.j2"
        is_ipv6 = obj["server"].v6_address is not None
    elif isinstance(obj, (schemas.WGServer, models.WGServer)):
        template = "server.j2"
        is_ipv6 = obj.v6_address is not None
    else:
        raise ValueError("Incorrect input type. Should be WGPeer or WGServer")

    result = util.jinja_env.get_template(template).render(
        data=obj,
        is_ipv6=is_ipv6
    )

    return result


def retrieve_client_conf_from_server(
        client_name,
        server_interface,
        server_host,
        server_api_key
):
    """Fetch (or create) the configuration of ``client_name`` from a remote server.

    :return: the response body of ``/api/v1/peer/configuration/get_or_add``.
    :raises RuntimeError: if the server does not answer with HTTP 200.
    """
    # The hard-coded ``const.CLIENT_*`` overrides that used to live here were
    # debugging leftovers: they clobbered the configured values on every call
    # and embedded an API key in the source.

    assert server_interface is not None and \
           server_host is not None and \
           server_api_key is not None, "Client configuration is invalid: %s, %s, api-key-is-null?: %s" % (
        server_interface,
        server_host,
        server_api_key is None
    )

    api_get_or_add = f"{server_host}/api/v1/peer/configuration/get_or_add"

    response = requests.post(
        api_get_or_add,
        json={
            "server_interface": server_interface,
            "name": client_name
        },
        headers={
            "X-API-Key": server_api_key
        },
        timeout=_CLIENT_CONF_TIMEOUT,
    )

    if response.status_code != 200:
        _LOGGER.error(response.text)
        raise RuntimeError("Could not retrieve config from server: %s" % (api_get_or_add,))

    return response.text


def create_client_config(sess: Session, configuration, client_name, client_routes):
    """Persist a client-mode server/peer pair parsed from a WireGuard config.

    :param sess: database session used for the queries and commits.
    :param configuration: text of the client configuration ([Interface]/[Peer]).
    :param client_name: name stored on the peer and used for the interface name.
    :param client_routes: optional override for the peer's allowed IPs.
    """
    # Values are read verbatim; disable interpolation so a literal "%" in the
    # configuration cannot break option access.
    parser = configparser.ConfigParser(interpolation=None)
    parser.read_string(configuration)
    public_key = parser["Peer"]["PublicKey"]

    assert len(set(parser.sections()) - {"Interface", "Peer"}) == 0, "Missing Interface or Peer section"

    private_key = parser["Interface"]["PrivateKey"]

    # Parse Server
    # Check if server already exists.
    is_new_server = False
    is_new_peer = False

    try:
        db_server = sess.query(models.WGServer).filter_by(
            public_key=public_key,
            read_only=1
        ).one()
    except (NoResultFound, MultipleResultsFound):
        db_server = None

    if db_server is None:
        db_server = models.WGServer()
        is_new_server = True

    db_server.read_only = 1
    db_server.public_key = public_key
    # NOTE(review): this value is overwritten with "N/A" a few lines below;
    # confirm whether the endpoint was meant to be stored in another column.
    db_server.address = parser["Peer"]["Endpoint"]
    # NOTE(review): the range lies outside valid UDP ports; presumably a
    # placeholder that cannot collide with real servers -- confirm.
    db_server.listen_port = random.randint(69000, 19292009)
    db_server.v6_address = "N/A"
    db_server.v6_subnet = 0
    db_server.address = "N/A"
    db_server.subnet = 0
    db_server.private_key = "N/A"
    db_server.dns = "N/A"
    db_server.post_up = "N/A"
    db_server.post_down = "N/A"
    db_server.is_running = False
    db_server.configuration = "N/A"

    # Parse client
    try:
        db_peer = sess.query(models.WGPeer).filter_by(
            private_key=private_key,
            read_only=1
        ).one()
    except (NoResultFound, MultipleResultsFound):
        db_peer = None

    if db_peer is None:
        db_peer = models.WGPeer()
        is_new_peer = True

    db_peer.read_only = 1
    db_peer.name = client_name

    for raw_address in parser["Interface"]["Address"].split(","):
        # Entries after a comma usually carry a leading space.
        address_with_subnet = raw_address.strip()
        # Unpacking enforces the "address/prefix" form; ip_address() raises
        # ValueError for anything that is not a valid IPv4/IPv6 address.
        addr, _subnet = address_with_subnet.split("/")
        if ipaddress.ip_address(addr).version == 4:
            db_peer.address = address_with_subnet
        else:
            db_peer.v6_address = address_with_subnet

    db_peer.private_key = private_key
    db_peer.public_key = "N/A"
    db_peer.allowed_ips = client_routes if client_routes else parser["Peer"]["AllowedIPs"]
    db_peer.configuration = configuration
    db_server.interface = f"client_{db_peer.name}"
    db_server.configuration = configuration
    try:
        # The preshared key is an optional option of the [Peer] section.
        db_peer.shared_key = parser["Peer"]["PresharedKey"]
    except KeyError:
        db_peer.shared_key = "N/A"

    db_peer.dns = parser["Interface"]["DNS"]
    db_peer.server = db_server

    if is_new_server:
        sess.add(db_server)
    else:
        sess.merge(db_server)
    sess.commit()

    if is_new_peer:
        sess.add(db_peer)
    else:
        sess.merge(db_peer)
    sess.commit()

    if const.CLIENT_START_AUTOMATICALLY and not is_running(db_server):
        start_interface(db_server)


def load_environment_clients(sess: Session):
    """Create client configurations from ``CLIENT_<n>_*`` environment variables.

    Clients are read for n = 1, 2, ... until one of NAME, SERVER_INTERFACE,
    SERVER_HOST or API_KEY is missing; ROUTES is optional.
    """
    i = 1
    while True:
        client_name = os.getenv(f"CLIENT_{i}_NAME", None)
        client_server_interface = os.getenv(f"CLIENT_{i}_SERVER_INTERFACE", None)
        client_server_host = os.getenv(f"CLIENT_{i}_SERVER_HOST", None)
        client_api_key = os.getenv(f"CLIENT_{i}_API_KEY", None)
        client_routes = os.getenv(f"CLIENT_{i}_ROUTES", None)

        if client_name is None or \
                client_server_interface is None or \
                client_server_host is None or \
                client_api_key is None:
            break

        _LOGGER.warning(
            f"Found client configuration: name={client_name},siface={client_server_interface},shost={client_server_host}")

        config = retrieve_client_conf_from_server(
            client_name=client_name,
            server_interface=client_server_interface,
            server_host=client_server_host,
            server_api_key=client_api_key
        )

        create_client_config(sess, configuration=config, client_name=client_name, client_routes=client_routes)

        i += 1


if __name__ == "__main__":
    # Sample values for running this module directly during development.
    os.environ["CLIENT_1_NAME"] = "client-1"
    os.environ["CLIENT_1_SERVER_INTERFACE"] = "wg0"
    os.environ["CLIENT_1_SERVER_HOST"] = "http://localhost:4200"
    os.environ["CLIENT_1_API_KEY"] = "8bae20143fb962930614952d80634822361fd5ab9488053866a56de5881f9d7b"

    os.environ["CLIENT_2_NAME"] = "client-2"
    os.environ["CLIENT_2_SERVER_INTERFACE"] = "wg0"
    os.environ["CLIENT_2_SERVER_HOST"] = "http://localhost:4200"
    os.environ["CLIENT_2_API_KEY"] = "8bae20143fb962930614952d80634822361fd5ab9488053866a56de5881f9d7b"

    sess: Session = SessionLocal()
    try:
        load_environment_clients(sess)
    finally:
        sess.close()