|
|
@ -1,4 +1,5 @@ |
|
|
|
from datetime import timedelta, datetime |
|
|
|
import ssl |
|
|
|
|
|
|
|
import jwt |
|
|
|
from fastapi import Depends, HTTPException |
|
|
@ -11,11 +12,24 @@ from starlette import status |
|
|
|
from starlette.requests import Request |
|
|
|
from starlette.responses import Response |
|
|
|
|
|
|
|
import ldap3 |
|
|
|
|
|
|
|
import db |
|
|
|
import const |
|
|
|
import schemas |
|
|
|
from database import models |
|
|
|
from database.database import SessionLocal |
|
|
|
|
|
|
|
if const.AUTH_LDAP_ENABLED: |
|
|
|
if const.AUTH_LDAP_SECURITY: |
|
|
|
ldap_tls_config=ldap3.Tls(validate=ssl.CERT_REQUIRED if const.AUTH_LDAP_SECURITY_VALID_CERTIFICATE else ssl.CERT_NONE) |
|
|
|
else: |
|
|
|
ldap_tls_config = False |
|
|
|
LDAP_SERVER = ldap3.Server(const.AUTH_LDAP_SERVER, const.AUTH_LDAP_PORT, get_info=ldap3.ALL, use_ssl=const.AUTH_LDAP_SECURITY=="SSL", tls=ldap_tls_config) |
|
|
|
else: |
|
|
|
LDAP_SERVER = None |
|
|
|
|
|
|
|
|
|
|
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/login", auto_error=False) |
|
|
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") |
|
|
|
|
|
|
@ -100,3 +114,120 @@ def auth(token: str = Depends(oauth2_scheme), api_key: str = Depends(retrieve_ap |
|
|
|
raise credentials_exception |
|
|
|
return user |
|
|
|
|
|
|
|
AUTH_ENGINES: dict = {} |
|
|
|
|
|
|
|
def authengine(name: str, sequence: int, enabled: bool): |
|
|
|
def decorator(f): |
|
|
|
AUTH_ENGINES[name] = { |
|
|
|
"function": f, |
|
|
|
"sequence": sequence, |
|
|
|
"enabled": enabled |
|
|
|
} |
|
|
|
|
|
|
|
return decorator |
|
|
|
|
|
|
|
class Authentication(object): |
|
|
|
|
|
|
|
def __init__(self, username: str, password: str, sess: Session): |
|
|
|
self.username = username |
|
|
|
self.password = password |
|
|
|
self.sess = sess |
|
|
|
|
|
|
|
def login(self): |
|
|
|
user: schemas.UserInDB = False |
|
|
|
|
|
|
|
for engine in sorted(AUTH_ENGINES.keys(), key=lambda x: AUTH_ENGINES[x]["sequence"]): |
|
|
|
if not AUTH_ENGINES[engine]["enabled"]: |
|
|
|
continue |
|
|
|
try: |
|
|
|
user = AUTH_ENGINES[engine]["function"](self) |
|
|
|
logger.info("User %s logged in via the %s authentication engine" % (self.username, engine)) |
|
|
|
break |
|
|
|
except Exception as err: |
|
|
|
logger.warning("Login failed for %s using the %s authentication engine: %s" % (self.username, engine, err)) |
|
|
|
|
|
|
|
if not user: |
|
|
|
raise HTTPException( |
|
|
|
status_code=status.HTTP_401_UNAUTHORIZED, |
|
|
|
detail="Incorrect username or password", |
|
|
|
headers={"WWW-Authenticate": "Bearer"}, |
|
|
|
) |
|
|
|
|
|
|
|
return user |
|
|
|
|
|
|
|
@authengine(name="builtin", sequence=10, enabled=const.AUTH_LOCAL_ENABLED) |
|
|
|
def _builtin(self): |
|
|
|
assert const.AUTH_LOCAL_ENABLED, "LOCAL authentication not enabled" |
|
|
|
user: schemas.UserInDB = schemas.UserInDB(username=self.username, password="").from_db(self.sess) |
|
|
|
|
|
|
|
# Verify password |
|
|
|
assert user and verify_password(self.password, user.password), "Invalid username or password" |
|
|
|
|
|
|
|
return user |
|
|
|
|
|
|
|
@authengine(name="LDAP", sequence=20, enabled=const.AUTH_LDAP_ENABLED) |
|
|
|
def _ldap(self): |
|
|
|
assert const.AUTH_LDAP_ENABLED, "LDAP authentication not enabled" |
|
|
|
|
|
|
|
def _get_ldap_attr(ldapobj, attribute): |
|
|
|
attr = ldapobj["attributes"].get(attribute, None) |
|
|
|
if isinstance(attr, list): |
|
|
|
try: |
|
|
|
return attr[0] |
|
|
|
except IndexError: |
|
|
|
return None |
|
|
|
return attr |
|
|
|
|
|
|
|
ldap_auth = ldap3.ANONYMOUS |
|
|
|
ldap_user = None |
|
|
|
valid: bool = False |
|
|
|
if const.AUTH_LDAP_USER: |
|
|
|
if const.AUTH_LDAP_ACTIVEDIRECTORY: |
|
|
|
ldap_auth = ldap3.NTLM |
|
|
|
else: |
|
|
|
ldap_auth = ldap3.SIMPLE |
|
|
|
|
|
|
|
# Connect with binddn, if set, to search the user |
|
|
|
with ldap3.Connection(LDAP_SERVER, user=const.AUTH_LDAP_USER, password=const.AUTH_LDAP_PASSWORD, authentication=ldap_auth, read_only=True, auto_bind=ldap3.AUTO_BIND_NONE) as cn: |
|
|
|
if const.AUTH_LDAP_SECURITY == "TLS": |
|
|
|
cn.start_tls() |
|
|
|
try: |
|
|
|
assert cn.bind() |
|
|
|
logger.debug("LDAP system bind complete") |
|
|
|
except: |
|
|
|
logger.exception("Unable to connect/bind to LDAP server") |
|
|
|
raise |
|
|
|
# TODO find a parsing tool like python-ldap.filter.filter_format |
|
|
|
ldap_filter: str = const.AUTH_LDAP_FILTER % self.username |
|
|
|
ldap_attributes: list = ["cn", "mail"] |
|
|
|
|
|
|
|
if const.AUTH_LDAP_ACTIVEDIRECTORY: |
|
|
|
ldap_attributes.extend(["samAccountName", "givenName"]) |
|
|
|
cn.search(search_base=const.AUTH_LDAP_BASE, search_filter=ldap_filter, attributes=ldap_attributes) |
|
|
|
assert len(cn.response) == 1, "Found %d LDAP users for the filter %s" % (len(cn.response), ldap_filter) |
|
|
|
ldap_user = cn.response[0].copy() |
|
|
|
|
|
|
|
logininfo: str = "%s\%s" % (const.AUTH_LDAP_DOMAIN, _get_ldap_attr(ldap_user, "samAccountName")) if const.AUTH_LDAP_ACTIVEDIRECTORY else ldap_user["dn"] |
|
|
|
with ldap3.Connection(LDAP_SERVER, user=logininfo, password=self.password, authentication=ldap3.NTLM if const.AUTH_LDAP_ACTIVEDIRECTORY else ldap3.SIMPLE, read_only=True, auto_bind=ldap3.AUTO_BIND_NONE) as cn: |
|
|
|
if const.AUTH_LDAP_SECURITY == "TLS": |
|
|
|
cn.start_tls() |
|
|
|
assert cn.bind(), "LDAP authentication failed for %s" % self.username |
|
|
|
cn.unbind() |
|
|
|
|
|
|
|
user: schema.UserInDB = schemas.UserInDB(username=self.username, password="").from_db(self.sess) |
|
|
|
if user: |
|
|
|
user.full_name = _get_ldap_attr(ldap_user, "givenName" if const.AUTH_LDAP_ACTIVEDIRECTORY else "cn") |
|
|
|
user.email = _get_ldap_attr(ldap_user, "mail") |
|
|
|
user.password = None |
|
|
|
db.user.update_user(self.sess, user) |
|
|
|
else: |
|
|
|
if not db.user.create_user(self.sess, models.User( |
|
|
|
username=username, |
|
|
|
password=None, |
|
|
|
full_name=_get_ldap_attr(ldap_user, "givenName" if const.AUTH_LDAP_ACTIVEDIRECTORY else "cn"), |
|
|
|
email=_get_ldap_attr(ldap_user, "mail"), |
|
|
|
role="user", # TODO: Map LDAP groups to roles |
|
|
|
)): |
|
|
|
raise HTTPException(status_code=400, detail="Could not create LDAP user") |
|
|
|
user: schema.UserInDB = schemas.UserInDB(username=self.username, password="").from_db(self.sess) |
|
|
|
return user |
|
|
|