You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

240 lines
7.6 KiB

import ipaddress
import os
import shutil
import typing
import const
import script.wireguard
from sqlalchemy import exists
from sqlalchemy.orm import Session
import util
import models
import schemas
import logging
_LOGGER = logging.getLogger(__name__)
_LOGGER.setLevel(logging.DEBUG)
def start_client(sess: Session, peer: schemas.WGPeer):
db_peer: models.WGPeer = peer_query_get_by_address(sess, peer.address, peer.server).one()
client_file = os.path.join(const.CLIENT_DIR(db_peer.server_ref.interface), str(db_peer.id) + ".conf")
import subprocess
output = subprocess.check_output(const.CMD_WG_QUICK + ["up", client_file], stderr=subprocess.STDOUT)
print(output)
def server_generate_config(sess: Session, server: schemas.WGServer):
db_server: models.WGServer = server_query_get_by_interface(sess, server.interface).one()
result = util.jinja_env.get_template("server.j2").render(
data=db_server
)
interface = db_server.interface
server_file = const.SERVER_FILE(interface)
with open(server_file, "w+") as f:
f.write(result)
os.chmod(server_file, 0o600)
return result
def peer_generate_config(sess: Session, peer: schemas.WGPeer):
db_peer: models.WGPeer = peer_query_get_by_address(sess, peer.address, peer.server).one()
result = util.jinja_env.get_template("peer.j2").render(
data=db_peer
)
peer_file = const.PEER_FILE(db_peer)
with open(peer_file, "w+") as f:
f.write(result)
os.chmod(peer_file, 0o600)
return result
def peer_query_get_by_address(sess: Session, address: str, server: str):
return sess.query(models.WGPeer) \
.filter(models.WGPeer.address == address) \
.filter(models.WGPeer.server == server)
def peer_insert(sess: Session, peer: schemas.WGPeer) -> schemas.WGPeer:
db_server: models.WGServer = server_query_get_by_interface(sess, peer.server).one()
db_peer = models.WGPeer(**peer.dict())
address_space = set(ipaddress.ip_network(db_server.address, strict=False).hosts())
occupied_space = set()
for p in db_server.peers:
try:
occupied_space.add(ipaddress.ip_address(p.address.split("/")[0]))
except ValueError as e:
print(e)
pass # Ignore invalid addresses. These are out of address_space
address_space -= occupied_space
# Select first available address
db_peer.address = str(list(address_space).pop(0)) + "/32"
# Private public key generation
private_key, public_key = script.wireguard.generate_keys()
db_peer.private_key = private_key
db_peer.public_key = public_key
# Set 0.0.0.0/0, ::/0 as default allowed ips
db_peer.allowed_ips = ', '.join(const.PEER_DEFAULT_ALLOWED_IPS)
# Set unnamed
db_peer.name = "Unnamed"
db_peer.dns = db_server.endpoint
sess.add(db_peer)
sess.commit()
return peer.from_orm(db_peer)
def peer_dns_set(sess: Session, peer: schemas.WGPeer) -> schemas.WGPeer:
db_peer: models.WGPeer = peer_query_get_by_address(sess, peer.address, peer.server).one()
db_peer.dns = peer.dns
sess.add(db_peer)
sess.commit()
return peer.from_orm(db_peer)
def peer_remove(sess: Session, peer: schemas.WGPeer) -> bool:
db_peers: models.WGPeer = peer_query_get_by_address(sess, peer.address, peer.server).all()
for db_peer in db_peers:
sess.delete(db_peer)
sess.commit()
try:
os.remove(const.PEER_FILE(db_peer))
except:
pass
return True
def peer_key_pair_generate(sess: Session, peer: schemas.WGPeer) -> schemas.WGPeer:
db_peer: models.WGPeer = peer_query_get_by_address(sess, peer.address, peer.server).one()
private_key, public_key = script.wireguard.generate_keys()
db_peer.private_key = private_key
db_peer.public_key = public_key
sess.add(db_peer)
sess.commit()
return peer.from_orm(db_peer)
def peer_ip_address_set(sess: Session, peer: schemas.WGPeer) -> schemas.WGPeer:
db_peer: models.WGPeer = peer_query_get_by_address(sess, peer.address, peer.server).one()
db_peer.address = peer.address
sess.add(db_peer)
sess.commit()
return peer.from_orm(db_peer)
def peer_update(sess: Session, peer: schemas.WGPeer) -> schemas.WGPeer:
db_peer: models.WGPeer = peer_query_get_by_address(sess, peer.address, peer.server).one()
db_peer.address = peer.address
db_peer.public_key = peer.public_key
db_peer.private_key = peer.private_key
db_peer.name = peer.name
db_peer.dns = peer.dns
db_peer.allowed_ips = peer.allowed_ips
sess.add(db_peer)
sess.commit()
return peer.from_orm(db_peer)
def peer_get(sess: Session, server: schemas.WGServer) -> typing.List[schemas.WGPeer]:
db_server = server_query_get_by_interface(sess, server.interface).one()
return db_server.peers
def server_query_get_by_interface(sess: Session, interface: str):
return sess.query(models.WGServer) \
.filter(models.WGServer.interface == interface)
def server_update_field(sess: Session, interface: str, server: schemas.WGServer, fields: typing.Set):
if server_query_get_by_interface(sess, interface) \
.update(
server.dict(include=fields), synchronize_session=False
) == 1:
sess.commit()
return True
return False
def server_get_all(sess: Session) -> typing.List[schemas.WGServer]:
db_interfaces = sess.query(models.WGServer).all()
return [schemas.WGServer.from_orm(db_interface) for db_interface in db_interfaces]
def server_add(sess: Session, server: schemas.WGServer) -> schemas.WGServer:
if sess.query(exists().where(models.WGServer.interface == server.interface)).scalar():
raise ValueError("The server interface %s already exists in the database" % server.interface)
db_server = server.convert()
sess.add(db_server)
sess.commit()
return server.from_orm(db_server)
def server_remove(sess: Session, server: schemas.WGServer) -> bool:
db_server = server_query_get_by_interface(sess, server.interface).one()
if db_server is None:
raise ValueError("The server with interface %s is already deleted." % server.interface)
sess.delete(db_server)
sess.commit()
shutil.rmtree(const.SERVER_DIR(db_server.interface))
return True
def server_preshared_key(sess: Session, server: schemas.WGServer) -> bool:
return server_update_field(sess, server.interface, server, {"shared_key"})
def server_key_pair_set(sess: Session, server: schemas.WGServer) -> bool:
return server_update_field(sess, server.interface, server, {"private_key", "public_key"})
def server_listen_port_set(sess: Session, server: schemas.WGServer) -> bool:
if server.listen_port < 1024 or server.listen_port > 65535:
raise ValueError("The listen_port is not in port range 1024 < x < 65535")
return server_update_field(sess, server.interface, server, {"listen_port"})
def server_ip_address_set(sess: Session, server: schemas.WGServer) -> bool:
network = ipaddress.ip_network(server.address, False)
if not network.is_private:
raise ValueError("The network is not in private range")
return server_update_field(sess, server.interface, server, {"address"})
def server_post_up_set(sess: Session, server: schemas.WGServer) -> bool:
return server_update_field(sess, server.interface, server, {"post_up"})
def server_post_down_set(sess: Session, server: schemas.WGServer) -> bool:
return server_update_field(sess, server.interface, server, {"post_down"})
def server_endpoint_set(sess: Session, server: schemas.WGServer) -> bool:
return server_update_field(sess, server.interface, server, {"endpoint"})