dashboardwireguard-vpn-setupwireguard-vpnwireguard-tunnelwireguard-dashboardwireguardwg-managervpnsite-to-siteobfuscation
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
233 lines
8.4 KiB
233 lines
8.4 KiB
from datetime import timedelta, datetime
|
|
import ssl
|
|
|
|
import jwt
|
|
from fastapi import Depends, HTTPException
|
|
from fastapi.security import OAuth2PasswordBearer
|
|
from jwt import PyJWTError
|
|
from loguru import logger
|
|
from passlib.context import CryptContext
|
|
from sqlalchemy.orm import Session
|
|
from starlette import status
|
|
from starlette.requests import Request
|
|
from starlette.responses import Response
|
|
|
|
import ldap3
|
|
|
|
import db
|
|
import const
|
|
import schemas
|
|
from database import models
|
|
from database.database import SessionLocal
|
|
|
|
# Build the shared LDAP server handle once at import time. When LDAP
# authentication is switched off the handle is simply None.
if not const.AUTH_LDAP_ENABLED:
    LDAP_SERVER = None
else:
    # A Tls object is only built when a transport security mode is configured;
    # certificate validation is required or skipped depending on
    # AUTH_LDAP_SECURITY_VALID_CERTIFICATE. Otherwise a falsy placeholder is
    # handed to the server.
    ldap_tls_config = (
        ldap3.Tls(
            validate=(
                ssl.CERT_REQUIRED
                if const.AUTH_LDAP_SECURITY_VALID_CERTIFICATE
                else ssl.CERT_NONE
            )
        )
        if const.AUTH_LDAP_SECURITY
        else False
    )
    LDAP_SERVER = ldap3.Server(
        const.AUTH_LDAP_SERVER,
        const.AUTH_LDAP_PORT,
        get_info=ldap3.ALL,
        # Implicit SSL is used only for the "SSL" mode; the "TLS" mode is
        # upgraded per connection via start_tls() elsewhere in this module.
        use_ssl=const.AUTH_LDAP_SECURITY == "SSL",
        tls=ldap_tls_config,
    )
|
|
|
|
|
|
# Bearer-token extractor for the login endpoint. auto_error=False means a
# missing token does not raise by itself, so the auth dependency can still
# try the API-key header.
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="/api/v1/login",
    auto_error=False,
)

# passlib context used by get_password_hash() / verify_password().
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)
|
|
|
|
|
|
def get_password_hash(password):
    """Return the hash of ``password`` produced by the module-level ``pwd_context``."""
    hashed = pwd_context.hash(password)
    return hashed
|
|
|
|
|
|
def verify_password(plain_password, hashed_password):
    """Check ``plain_password`` against ``hashed_password`` via ``pwd_context``.

    Returns whatever ``pwd_context.verify`` returns for the pair.
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
|
|
|
|
|
|
async def logging_middleware(request: Request, call_next):
    """Log one line per request with its method, URL and response status code.

    The request is passed on to ``call_next`` first; the line is written once
    the response is available, and the response is returned unchanged.
    """
    outcome = await call_next(request)
    summary = f"{request.method} {request.url} - Code: {outcome.status_code}"
    logger.opt(depth=2).info(summary)
    return outcome
|
|
|
|
|
|
async def db_session_middleware(request: Request, call_next):
    """Give every request its own database session on ``request.state.db``.

    The session is created before the request is handed to ``call_next`` and
    is always closed afterwards, whether the handler returned or raised.
    Exceptions raised by the handler propagate unchanged.

    Returns:
        The response produced by ``call_next``.
    """
    # Create the session *outside* the try block: if SessionLocal() itself
    # fails there is nothing to close, and the original error must surface
    # instead of being replaced by an AttributeError on the unset
    # ``request.state.db`` inside ``finally``.
    request.state.db = SessionLocal()
    try:
        return await call_next(request)
    finally:
        request.state.db.close()
|
|
|
|
|
|
# NON MIDDLEWARE MIDDLEWARISH THING
|
|
|
|
|
|
# Dependency
def get_db(request: Request):
    """Return the database session stored on ``request.state.db``.

    The session is expected to have been placed there by
    ``db_session_middleware`` earlier in the request.
    """
    session = request.state.db
    return session
|
|
|
|
|
|
def create_access_token(*, data: dict, expires_delta: timedelta = None):
    """Encode ``data`` plus an ``exp`` claim into a signed JWT.

    Args:
        data: Claims to embed; the mapping is copied, not modified.
        expires_delta: Lifetime of the token. When falsy (``None`` or a zero
            delta) a 15 minute lifetime is used.

    Returns:
        The token produced by ``jwt.encode`` with ``const.SECRET_KEY`` and
        ``const.ALGORITHM``.
    """
    lifetime = expires_delta if expires_delta else timedelta(minutes=15)

    claims = data.copy()
    claims["exp"] = datetime.utcnow() + lifetime

    return jwt.encode(claims, const.SECRET_KEY, algorithm=const.ALGORITHM)
|
|
|
|
|
|
def retrieve_api_key(request: Request):
    """Return the value of the ``X-API-Key`` request header, or ``None`` if absent."""
    headers = request.headers
    return headers.get("X-API-Key")
|
|
|
|
|
|
def auth(token: str = Depends(oauth2_scheme), api_key: str = Depends(retrieve_api_key), sess: Session = Depends(get_db)):
    """Resolve the calling user from a bearer JWT and/or an ``X-API-Key`` header.

    The JWT ``sub`` claim is read first. If an API key is supplied and matches
    a stored key, the key owner's username takes precedence over the JWT one.

    Args:
        token: Bearer token extracted by ``oauth2_scheme`` (may be ``None``).
        api_key: Value of the ``X-API-Key`` header (may be ``None``).
        sess: Database session for the current request.

    Returns:
        A ``schemas.User`` built from the stored user record.

    Raises:
        HTTPException: 401 when neither credential yields a username, or when
            no stored user is found for that username.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    username = None

    # Attempt to authenticate using JWT. An invalid or expired token is not an
    # error by itself because the API key may still authenticate the request.
    if token:
        try:
            payload = jwt.decode(token, const.SECRET_KEY, algorithms=[const.ALGORITHM])
            username = payload.get("sub")
        except PyJWTError:
            pass

    # Attempt to authenticate using the API key. The lookup is skipped when no
    # header was sent, so a request without a key can never match a row whose
    # key column happens to be NULL.
    if api_key is not None:
        try:
            db_user_api_key = sess.query(models.UserAPIKey).filter_by(key=api_key).one()
            username = db_user_api_key.user.username
        except Exception as err:
            # Best effort: an unknown key or a failed lookup only means this
            # credential did not authenticate. The key itself is not logged.
            logger.debug("API key authentication failed: %s" % err)

    if username is None:
        raise credentials_exception

    # Look the user up *before* converting to the public schema. Other code in
    # this file tests the result of from_db() for truthiness, so a missing
    # user is presumably reported as a falsy value; converting that with
    # from_orm() would fail instead of producing the intended 401.
    db_user = schemas.UserInDB(username=username, password="").from_db(sess)
    if not db_user:
        raise credentials_exception

    return schemas.User.from_orm(db_user)
|
|
|
|
# Registry of authentication engines, keyed by engine name. Each entry holds
# the engine callable ("function"), its ordering key ("sequence") and whether
# it is switched on ("enabled").
AUTH_ENGINES: dict = {}


def authengine(name: str, sequence: int, enabled: bool):
    """Return a decorator that registers a function in ``AUTH_ENGINES``.

    Args:
        name: Key under which the engine is stored; registering the same name
            again replaces the earlier entry.
        sequence: Ordering key stored with the engine.
        enabled: Flag stored with the engine.

    Returns:
        A decorator that records the function and hands it back unchanged, so
        the decorated name keeps referring to the function instead of being
        rebound to ``None``.
    """

    def decorator(f):
        AUTH_ENGINES[name] = {
            "function": f,
            "sequence": sequence,
            "enabled": enabled,
        }
        return f

    return decorator
|
|
|
|
class Authentication(object):
    """Authenticate a username/password pair against the registered engines.

    Engines are the functions registered in ``AUTH_ENGINES`` through the
    ``@authengine`` decorator. ``login`` calls them with this instance as
    their only argument.
    """

    def __init__(self, username: str, password: str, sess: Session):
        self.username = username
        self.password = password
        self.sess = sess

    @staticmethod
    def _require(condition, message=""):
        """Raise ``AssertionError(message)`` unless ``condition`` is truthy.

        Used instead of the ``assert`` statement so that credential checks are
        not stripped when Python runs with ``-O``. The exception type is kept
        as ``AssertionError`` so failures look the same to existing handlers.
        """
        if not condition:
            raise AssertionError(message)

    def login(self):
        """Run the enabled engines in ``sequence`` order until one succeeds.

        An engine that raises is logged as a failed attempt and the next one
        is tried. The first engine that returns without raising ends the loop.

        Returns:
            The user object returned by the successful engine.

        Raises:
            HTTPException: 401 when no engine produced a truthy user.
        """
        user: schemas.UserInDB = False

        for engine in sorted(AUTH_ENGINES, key=lambda x: AUTH_ENGINES[x]["sequence"]):
            if not AUTH_ENGINES[engine]["enabled"]:
                continue
            try:
                user = AUTH_ENGINES[engine]["function"](self)
                logger.info("User %s logged in via the %s authentication engine" % (self.username, engine))
                break
            except Exception as err:
                logger.warning("Login failed for %s using the %s authentication engine: %s" % (self.username, engine, err))

        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Incorrect username or password",
                headers={"WWW-Authenticate": "Bearer"},
            )

        return user

    @authengine(name="builtin", sequence=10, enabled=const.AUTH_LOCAL_ENABLED)
    def _builtin(self):
        """Authenticate against the password stored for the user.

        Returns:
            The stored user record.

        Raises:
            AssertionError: If local authentication is disabled, no stored
                user is found, or the password does not verify.
        """
        self._require(const.AUTH_LOCAL_ENABLED, "LOCAL authentication not enabled")
        user: schemas.UserInDB = schemas.UserInDB(username=self.username, password="").from_db(self.sess)

        # Verify password
        self._require(user and verify_password(self.password, user.password), "Invalid username or password")

        return user

    @authengine(name="LDAP", sequence=20, enabled=const.AUTH_LDAP_ENABLED)
    def _ldap(self):
        """Authenticate against the configured LDAP server.

        Steps:
          1. Bind with the configured service account (anonymously when none
             is set) and search for exactly one entry matching the username.
          2. Bind again as that entry with the supplied password.
          3. Update the stored user with the name and mail from LDAP, or
             create it with role "user" when it does not exist yet.

        Returns:
            The stored user record after the update/creation.

        Raises:
            AssertionError: If LDAP is disabled, a bind is rejected, or the
                search does not return exactly one entry.
            HTTPException: 400 when the new user could not be created.
        """
        # Imported locally so this block does not depend on the file header.
        from ldap3.utils.conv import escape_filter_chars

        self._require(const.AUTH_LDAP_ENABLED, "LDAP authentication not enabled")

        def _get_ldap_attr(ldapobj, attribute):
            """Return an attribute of a search result; the first item if it is a list."""
            attr = ldapobj["attributes"].get(attribute, None)
            if isinstance(attr, list):
                return attr[0] if attr else None
            return attr

        ldap_auth = ldap3.ANONYMOUS
        ldap_user = None
        if const.AUTH_LDAP_USER:
            if const.AUTH_LDAP_ACTIVEDIRECTORY:
                ldap_auth = ldap3.NTLM
            else:
                ldap_auth = ldap3.SIMPLE

        # Connect with binddn, if set, to search the user
        with ldap3.Connection(LDAP_SERVER, user=const.AUTH_LDAP_USER, password=const.AUTH_LDAP_PASSWORD, authentication=ldap_auth, read_only=True, auto_bind=ldap3.AUTO_BIND_NONE) as cn:
            if const.AUTH_LDAP_SECURITY == "TLS":
                cn.start_tls()
            try:
                self._require(cn.bind(), "LDAP system bind was rejected")
                logger.debug("LDAP system bind complete")
            except Exception:
                logger.exception("Unable to connect/bind to LDAP server")
                raise

            # The username comes straight from the login request, so filter
            # metacharacters are escaped before it is placed in the filter.
            ldap_filter: str = const.AUTH_LDAP_FILTER % escape_filter_chars(self.username)
            ldap_attributes: list = ["cn", "mail"]

            if const.AUTH_LDAP_ACTIVEDIRECTORY:
                ldap_attributes.extend(["samAccountName", "givenName"])
            cn.search(search_base=const.AUTH_LDAP_BASE, search_filter=ldap_filter, attributes=ldap_attributes)
            self._require(len(cn.response) == 1, "Found %d LDAP users for the filter %s" % (len(cn.response), ldap_filter))
            ldap_user = cn.response[0].copy()

        # Active Directory binds as DOMAIN\account; other servers bind with
        # the DN of the entry that was found.
        if const.AUTH_LDAP_ACTIVEDIRECTORY:
            logininfo: str = "%s\\%s" % (const.AUTH_LDAP_DOMAIN, _get_ldap_attr(ldap_user, "samAccountName"))
        else:
            logininfo = ldap_user["dn"]

        with ldap3.Connection(LDAP_SERVER, user=logininfo, password=self.password, authentication=ldap3.NTLM if const.AUTH_LDAP_ACTIVEDIRECTORY else ldap3.SIMPLE, read_only=True, auto_bind=ldap3.AUTO_BIND_NONE) as cn:
            if const.AUTH_LDAP_SECURITY == "TLS":
                cn.start_tls()
            self._require(cn.bind(), "LDAP authentication failed for %s" % self.username)
            cn.unbind()

        full_name = _get_ldap_attr(ldap_user, "givenName" if const.AUTH_LDAP_ACTIVEDIRECTORY else "cn")
        email = _get_ldap_attr(ldap_user, "mail")

        user: schemas.UserInDB = schemas.UserInDB(username=self.username, password="").from_db(self.sess)
        if user:
            # Existing user: refresh the profile from LDAP and clear the
            # stored password.
            user.full_name = full_name
            user.email = email
            user.password = None
            db.user.update_user(self.sess, user)
        else:
            # First LDAP login: create the user, then re-read the record.
            created = db.user.create_user(self.sess, models.User(
                username=self.username,
                password=None,
                full_name=full_name,
                email=email,
                role="user",  # TODO: Map LDAP groups to roles
            ))
            if not created:
                raise HTTPException(status_code=400, detail="Could not create LDAP user")
            user = schemas.UserInDB(username=self.username, password="").from_db(self.sess)
        return user
|
|
|