import ipaddress import json import os import shutil import typing from starlette.exceptions import HTTPException import const import script.wireguard from sqlalchemy.orm import Session from database import models import schemas import logging _LOGGER = logging.getLogger(__name__) _LOGGER.setLevel(logging.DEBUG) def start_client(sess: Session, peer: schemas.WGPeer): db_peer: models.WGPeer = peer_query_get_by_address(sess, peer.address, peer.server).one() client_file = os.path.join(const.CLIENT_DIR(db_peer.server.interface), str(db_peer.id) + ".conf") import subprocess output = subprocess.check_output(const.CMD_WG_QUICK + ["up", client_file], stderr=subprocess.STDOUT) def get_server_by_id(sess: Session, server_id): return sess.query(models.WGServer).filter_by(id=server_id).one() def peer_query_get_by_address(sess: Session, address: str, server: str): return sess.query(models.WGPeer) \ .filter(models.WGPeer.address == address) \ .filter(models.WGPeer.server == server) def peer_dns_set(sess: Session, peer: schemas.WGPeer) -> schemas.WGPeer: db_peer: models.WGPeer = peer_query_get_by_address(sess, peer.address, peer.server).one() db_peer.dns = peer.dns sess.add(db_peer) sess.commit() return peer.from_orm(db_peer) def peer_remove(sess: Session, peer: schemas.WGPeer) -> bool: db_peers = sess.query(models.WGPeer).filter_by(id=peer.id).all() for db_peer in db_peers: sess.delete(db_peer) sess.commit() server_update_configuration(sess, peer.server_id) return True def peer_edit(sess: Session, peer: schemas.WGPeer): # Retrieve server from db server: models.WGServer = get_server_by_id(sess, peer.server_id) # Generate peer configuration peer.configuration = script.wireguard.generate_config(dict( peer=peer, server=server )) # Update database record for Peer sess.query(models.WGPeer) \ .filter_by(id=peer.id) \ .update(peer.dict(exclude={"id"})) sess.commit() server_update_configuration(sess, server.id) return peer def peer_key_pair_generate(sess: Session, peer: schemas.WGPeer) -> schemas.WGPeer: db_peer: models.WGPeer = peer_query_get_by_address(sess, peer.address, peer.server).one() private_key, public_key = script.wireguard.generate_keys() db_peer.private_key = private_key db_peer.public_key = public_key sess.add(db_peer) sess.commit() return peer.from_orm(db_peer) def peer_ip_address_set(sess: Session, peer: schemas.WGPeer) -> schemas.WGPeer: db_peer: models.WGPeer = peer_query_get_by_address(sess, peer.address, peer.server).one() db_peer.address = peer.address sess.add(db_peer) sess.commit() return peer.from_orm(db_peer) def peer_update(sess: Session, peer: schemas.WGPeer) -> schemas.WGPeer: db_peer: models.WGPeer = peer_query_get_by_address(sess, peer.address, peer.server).one() db_peer.address = peer.address db_peer.public_key = peer.public_key db_peer.private_key = peer.private_key db_peer.name = peer.name db_peer.dns = peer.dns db_peer.allowed_ips = peer.allowed_ips sess.add(db_peer) sess.commit() return peer.from_orm(db_peer) def peer_get(sess: Session, server: schemas.WGServer) -> typing.List[schemas.WGPeer]: db_server = server_query_get_by_interface(sess, server.interface).one() return db_server.peers def server_query_get_by_interface(sess: Session, interface: str): return sess.query(models.WGServer) \ .filter(models.WGServer.interface == interface) def server_update_field(sess: Session, interface: str, server: schemas.WGServer, fields: typing.Set): if server_query_get_by_interface(sess, interface) \ .update( server.dict(include=fields), synchronize_session=False ) == 1: sess.commit() return True return False def server_get_all(sess: Session) -> typing.List[schemas.WGServer]: db_interfaces = sess.query(models.WGServer) \ .all() return [schemas.WGServer.from_orm(db_interface) for db_interface in db_interfaces] def server_add_on_init(sess: Session): """ Routine for adding server from env variable. :param server: :param sess: :return: """ try: init_data = json.loads(const.SERVER_INIT_INTERFACE) if init_data["endpoint"] == "||external||": import requests init_data["endpoint"] = requests.get("https://api.ipify.org").text elif init_data["endpoint"] == "||internal||": import socket init_data["endpoint"] = socket.gethostbyname(socket.gethostname()) if sess.query(models.WGServer) \ .filter_by(endpoint=init_data["endpoint"], listen_port=init_data["listen_port"]) \ .count() == 0: # Only add if it does not already exists. server_add(schemas.WGServerAdd(**init_data), sess, start=const.SERVER_INIT_INTERFACE_START) except Exception as e: _LOGGER.warning("Failed to setup initial server interface with exception:") _LOGGER.exception(e) def server_add(server: schemas.WGServerAdd, sess: Session, start=False): # Configure POST UP with defaults if not manually set. if server.post_up == "": server.post_up = const.DEFAULT_POST_UP if server.v6_address is not None: server.post_up += const.DEFAULT_POST_UP_v6 # Configure POST DOWN with defaults if not manually set. if server.post_down == "": server.post_down = const.DEFAULT_POST_DOWN if server.v6_address is not None: server.post_down += const.DEFAULT_POST_DOWN_v6 peers = server.peers if server.peers else [] # Public/Private key try: if sess.query(models.WGServer) \ .filter( (models.WGServer.interface == server.interface) | (models.WGServer.address == server.address) | (models.WGServer.v6_address == server.v6_address)).count() != 0: raise HTTPException(status_code=400, detail="The server interface or ip %s already exists in the database" % server.interface) if not server.private_key: keys = script.wireguard.generate_keys() server.private_key = keys["private_key"] server.public_key = keys["public_key"] server.configuration = script.wireguard.generate_config(server) server.peers = [] server.sync(sess) if len(peers) > 0: server.from_db(sess) for schemaPeer in peers: schemaPeer.server_id = server.id schemaPeer.configuration = script.wireguard.generate_config(dict( peer=schemaPeer, server=server )) dbPeer = models.WGPeer(**schemaPeer.dict()) sess.add(dbPeer) sess.commit() server.from_db(sess) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) if start and not script.wireguard.is_running(server): script.wireguard.start_interface(server) return server def server_remove(sess: Session, server: schemas.WGServer) -> bool: db_server = server_query_get_by_interface(sess, server.interface).one() if db_server is None: raise ValueError("The server with interface %s is already deleted." % server.interface) sess.delete(db_server) sess.commit() shutil.rmtree(const.SERVER_DIR(db_server.interface)) return True def server_preshared_key(sess: Session, server: schemas.WGServer) -> bool: return server_update_field(sess, server.interface, server, {"shared_key"}) def server_key_pair_set(sess: Session, server: schemas.WGServer) -> bool: return server_update_field(sess, server.interface, server, {"private_key", "public_key"}) def server_listen_port_set(sess: Session, server: schemas.WGServer) -> bool: if server.listen_port < 1024 or server.listen_port > 65535: raise ValueError("The listen_port is not in port range 1024 < x < 65535") return server_update_field(sess, server.interface, server, {"listen_port"}) def server_ip_address_set(sess: Session, server: schemas.WGServer) -> bool: network = ipaddress.ip_network(server.address, False) if not network.is_private: raise ValueError("The network is not in private range") return server_update_field(sess, server.interface, server, {"address"}) def server_post_up_set(sess: Session, server: schemas.WGServer) -> bool: return server_update_field(sess, server.interface, server, {"post_up"}) def server_post_down_set(sess: Session, server: schemas.WGServer) -> bool: return server_update_field(sess, server.interface, server, {"post_down"}) def server_endpoint_set(sess: Session, server: schemas.WGServer) -> bool: return server_update_field(sess, server.interface, server, {"endpoint"}) def server_update_configuration(sess: Session, server_id: int) -> bool: # Generate server configuration server: models.WGServer = sess.query(models.WGServer).filter_by(id=server_id).one() server.configuration = script.wireguard.generate_config(server) sess.add(server) sess.commit()