@ -1,4 +1,5 @@
from datetime import timedelta , datetime
import ssl
import jwt
from fastapi import Depends , HTTPException
@ -11,11 +12,24 @@ from starlette import status
from starlette . requests import Request
from starlette . responses import Response
import ldap3
import db
import const
import schemas
from database import models
from database . database import SessionLocal
if const . AUTH_LDAP_ENABLED :
if const . AUTH_LDAP_SECURITY :
ldap_tls_config = ldap3 . Tls ( validate = ssl . CERT_REQUIRED if const . AUTH_LDAP_SECURITY_VALID_CERTIFICATE else ssl . CERT_NONE )
else :
ldap_tls_config = False
LDAP_SERVER = ldap3 . Server ( const . AUTH_LDAP_SERVER , const . AUTH_LDAP_PORT , get_info = ldap3 . ALL , use_ssl = const . AUTH_LDAP_SECURITY == " SSL " , tls = ldap_tls_config )
else :
LDAP_SERVER = None
oauth2_scheme = OAuth2PasswordBearer ( tokenUrl = " /api/v1/login " , auto_error = False )
pwd_context = CryptContext ( schemes = [ " bcrypt " ] , deprecated = " auto " )
@ -100,3 +114,120 @@ def auth(token: str = Depends(oauth2_scheme), api_key: str = Depends(retrieve_ap
raise credentials_exception
return user
AUTH_ENGINES : dict = { }
def authengine ( name : str , sequence : int , enabled : bool ) :
def decorator ( f ) :
AUTH_ENGINES [ name ] = {
" function " : f ,
" sequence " : sequence ,
" enabled " : enabled
}
return decorator
class Authentication ( object ) :
def __init__ ( self , username : str , password : str , sess : Session ) :
self . username = username
self . password = password
self . sess = sess
def login ( self ) :
user : schemas . UserInDB = False
for engine in sorted ( AUTH_ENGINES . keys ( ) , key = lambda x : AUTH_ENGINES [ x ] [ " sequence " ] ) :
if not AUTH_ENGINES [ engine ] [ " enabled " ] :
continue
try :
user = AUTH_ENGINES [ engine ] [ " function " ] ( self )
logger . info ( " User %s logged in via the %s authentication engine " % ( self . username , engine ) )
break
except Exception as err :
logger . warning ( " Login failed for %s using the %s authentication engine: %s " % ( self . username , engine , err ) )
if not user :
raise HTTPException (
status_code = status . HTTP_401_UNAUTHORIZED ,
detail = " Incorrect username or password " ,
headers = { " WWW-Authenticate " : " Bearer " } ,
)
return user
@authengine ( name = " builtin " , sequence = 10 , enabled = const . AUTH_LOCAL_ENABLED )
def _builtin ( self ) :
assert const . AUTH_LOCAL_ENABLED , " LOCAL authentication not enabled "
user : schemas . UserInDB = schemas . UserInDB ( username = self . username , password = " " ) . from_db ( self . sess )
# Verify password
assert user and verify_password ( self . password , user . password ) , " Invalid username or password "
return user
@authengine ( name = " LDAP " , sequence = 20 , enabled = const . AUTH_LDAP_ENABLED )
def _ldap ( self ) :
assert const . AUTH_LDAP_ENABLED , " LDAP authentication not enabled "
def _get_ldap_attr ( ldapobj , attribute ) :
attr = ldapobj [ " attributes " ] . get ( attribute , None )
if isinstance ( attr , list ) :
try :
return attr [ 0 ]
except IndexError :
return None
return attr
ldap_auth = ldap3 . ANONYMOUS
ldap_user = None
valid : bool = False
if const . AUTH_LDAP_USER :
if const . AUTH_LDAP_ACTIVEDIRECTORY :
ldap_auth = ldap3 . NTLM
else :
ldap_auth = ldap3 . SIMPLE
# Connect with binddn, if set, to search the user
with ldap3 . Connection ( LDAP_SERVER , user = const . AUTH_LDAP_USER , password = const . AUTH_LDAP_PASSWORD , authentication = ldap_auth , read_only = True , auto_bind = ldap3 . AUTO_BIND_NONE ) as cn :
if const . AUTH_LDAP_SECURITY == " TLS " :
cn . start_tls ( )
try :
assert cn . bind ( )
logger . debug ( " LDAP system bind complete " )
except :
logger . exception ( " Unable to connect/bind to LDAP server " )
raise
# TODO find a parsing tool like python-ldap.filter.filter_format
ldap_filter : str = const . AUTH_LDAP_FILTER % self . username
ldap_attributes : list = [ " cn " , " mail " ]
if const . AUTH_LDAP_ACTIVEDIRECTORY :
ldap_attributes . extend ( [ " samAccountName " , " givenName " ] )
cn . search ( search_base = const . AUTH_LDAP_BASE , search_filter = ldap_filter , attributes = ldap_attributes )
assert len ( cn . response ) == 1 , " Found %d LDAP users for the filter %s " % ( len ( cn . response ) , ldap_filter )
ldap_user = cn . response [ 0 ] . copy ( )
logininfo : str = " %s \ %s " % ( const . AUTH_LDAP_DOMAIN , _get_ldap_attr ( ldap_user , " samAccountName " ) ) if const . AUTH_LDAP_ACTIVEDIRECTORY else ldap_user [ " dn " ]
with ldap3 . Connection ( LDAP_SERVER , user = logininfo , password = self . password , authentication = ldap3 . NTLM if const . AUTH_LDAP_ACTIVEDIRECTORY else ldap3 . SIMPLE , read_only = True , auto_bind = ldap3 . AUTO_BIND_NONE ) as cn :
if const . AUTH_LDAP_SECURITY == " TLS " :
cn . start_tls ( )
assert cn . bind ( ) , " LDAP authentication failed for %s " % self . username
cn . unbind ( )
user : schema . UserInDB = schemas . UserInDB ( username = self . username , password = " " ) . from_db ( self . sess )
if user :
user . full_name = _get_ldap_attr ( ldap_user , " givenName " if const . AUTH_LDAP_ACTIVEDIRECTORY else " cn " )
user . email = _get_ldap_attr ( ldap_user , " mail " )
user . password = None
db . user . update_user ( self . sess , user )
else :
if not db . user . create_user ( self . sess , models . User (
username = username ,
password = None ,
full_name = _get_ldap_attr ( ldap_user , " givenName " if const . AUTH_LDAP_ACTIVEDIRECTORY else " cn " ) ,
email = _get_ldap_attr ( ldap_user , " mail " ) ,
role = " user " , # TODO: Map LDAP groups to roles
) ) :
raise HTTPException ( status_code = 400 , detail = " Could not create LDAP user " )
user : schema . UserInDB = schemas . UserInDB ( username = self . username , password = " " ) . from_db ( self . sess )
return user