|
|
@ -96,7 +96,10 @@ struct TokenPayload { |
|
|
|
fn _authorization_login(data: ConnectData, conn: DbConn, ip: &ClientIp) -> JsonResult { |
|
|
|
let org_identifier = data.org_identifier.as_ref().unwrap(); |
|
|
|
let code = data.code.as_ref().unwrap(); |
|
|
|
let (access_token, refresh_token) = get_auth_code_access_token(&code, &org_identifier, &conn); |
|
|
|
let (access_token, refresh_token) = match get_auth_code_access_token(&code, &org_identifier, &conn) { |
|
|
|
Ok((access_token, refresh_token)) => (access_token, refresh_token), |
|
|
|
Err(err) => err!(err), |
|
|
|
}; |
|
|
|
let token = jsonwebtoken::dangerous_insecure_decode::<TokenPayload>(access_token.as_str()).unwrap().claims; |
|
|
|
let expiry = token.exp; |
|
|
|
let user_email = token.email; |
|
|
@ -562,17 +565,7 @@ use openidconnect::{ |
|
|
|
Scope, OAuth2TokenResponse, |
|
|
|
}; |
|
|
|
|
|
|
|
fn handle_error<T: std::error::Error>(fail: &T, msg: &'static str) { |
|
|
|
let mut err_msg = format!("ERROR: {}", msg); |
|
|
|
let mut cur_fail: Option<&dyn std::error::Error> = Some(fail); |
|
|
|
while let Some(cause) = cur_fail { |
|
|
|
err_msg += &format!("\n caused by: {}", cause); |
|
|
|
cur_fail = cause.source(); |
|
|
|
} |
|
|
|
panic!("{}", err_msg); |
|
|
|
} |
|
|
|
|
|
|
|
fn get_client_from_identifier (identifier: &str, conn: &DbConn) -> CoreClient { |
|
|
|
fn get_client_from_identifier (identifier: &str, conn: &DbConn) -> Result<CoreClient, &'static str> { |
|
|
|
let organization = Organization::find_by_identifier(identifier, conn); |
|
|
|
|
|
|
|
match organization { |
|
|
@ -581,21 +574,22 @@ fn get_client_from_identifier (identifier: &str, conn: &DbConn) -> CoreClient { |
|
|
|
let client_id = ClientId::new(organization.client_id); |
|
|
|
let client_secret = ClientSecret::new(organization.client_secret); |
|
|
|
let issuer_url = IssuerUrl::new(organization.authority).expect("invalid issuer URL"); |
|
|
|
let provider_metadata = CoreProviderMetadata::discover(&issuer_url, http_client) |
|
|
|
.unwrap_or_else(|err| { |
|
|
|
handle_error(&err, "Failed to discover OpenID Provider"); |
|
|
|
unreachable!(); |
|
|
|
}); |
|
|
|
let provider_metadata = match CoreProviderMetadata::discover(&issuer_url, http_client) { |
|
|
|
Ok(metadata) => metadata, |
|
|
|
Err(_err) => { |
|
|
|
return Err("Failed to discover OpenID provider"); |
|
|
|
}, |
|
|
|
}; |
|
|
|
let client = CoreClient::from_provider_metadata( |
|
|
|
provider_metadata, |
|
|
|
client_id, |
|
|
|
Some(client_secret), |
|
|
|
) |
|
|
|
.set_redirect_uri(RedirectUrl::new(redirect).expect("Invalid redirect URL")); |
|
|
|
return client; |
|
|
|
return Ok(client); |
|
|
|
}, |
|
|
|
None => { |
|
|
|
panic!("unable to find org"); |
|
|
|
Err("unable to find org") |
|
|
|
}, |
|
|
|
} |
|
|
|
} |
|
|
@ -605,11 +599,11 @@ fn authorize( |
|
|
|
domain_hint: &RawStr, |
|
|
|
state: &RawStr, |
|
|
|
conn: DbConn, |
|
|
|
) -> Redirect { |
|
|
|
) -> ApiResult<Redirect> { |
|
|
|
let domain_hint_decoded = &domain_hint.percent_decode().expect("Invalid domain_hint").into_owned(); |
|
|
|
let state_decoded = &state.percent_decode().expect("Invalid state").into_owned(); |
|
|
|
let client = get_client_from_identifier(domain_hint_decoded, &conn); |
|
|
|
|
|
|
|
match get_client_from_identifier(domain_hint_decoded, &conn) { |
|
|
|
Ok(client) => { |
|
|
|
// TODO store the nonce for validation on authorization token exchange - unclear where to store
|
|
|
|
// this
|
|
|
|
let (mut authorize_url, _csrf_state, _nonce) = client |
|
|
@ -635,29 +629,32 @@ fn authorize( |
|
|
|
let full_query = Vec::from_iter(new_pairs).join("&"); |
|
|
|
authorize_url.set_query(Some(full_query.as_str())); |
|
|
|
|
|
|
|
return Redirect::to(authorize_url.to_string()); |
|
|
|
return Ok(Redirect::to(authorize_url.to_string())); |
|
|
|
}, |
|
|
|
Err(_err) => err!("Unable to find client from identifier"), |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
fn get_auth_code_access_token ( |
|
|
|
code: &str, |
|
|
|
org_identifier: &str, |
|
|
|
conn: &DbConn, |
|
|
|
) -> (String, String) { |
|
|
|
) -> Result<(String, String), &'static str> { |
|
|
|
let oidc_code = AuthorizationCode::new(String::from(code)); |
|
|
|
|
|
|
|
let client = get_client_from_identifier(org_identifier, conn); |
|
|
|
|
|
|
|
let token_response = client |
|
|
|
.exchange_code(oidc_code) |
|
|
|
.request(http_client) |
|
|
|
.unwrap_or_else(|err| { |
|
|
|
handle_error(&err, "Failed to contact token endpoint"); |
|
|
|
unreachable!(); |
|
|
|
}); |
|
|
|
|
|
|
|
|
|
|
|
match get_client_from_identifier(org_identifier, conn) { |
|
|
|
Ok(client) => { |
|
|
|
match client.exchange_code(oidc_code).request(http_client) { |
|
|
|
Ok(token_response) => { |
|
|
|
let access_token = token_response.access_token().secret().to_string(); |
|
|
|
let refresh_token = token_response.refresh_token().unwrap().secret().to_string(); |
|
|
|
|
|
|
|
(access_token, refresh_token) |
|
|
|
Ok((access_token, refresh_token)) |
|
|
|
}, |
|
|
|
Err(_err) => Err("Failed to contact token endpoint"), |
|
|
|
} |
|
|
|
|
|
|
|
}, |
|
|
|
Err(_err) => Err("unable to find client"), |
|
|
|
} |
|
|
|
} |
|
|
|