diff --git a/src/api/identity.rs b/src/api/identity.rs index b2955336..f2653390 100644 --- a/src/api/identity.rs +++ b/src/api/identity.rs @@ -1,7 +1,7 @@ use chrono::Local; use num_traits::FromPrimitive; use rocket::{ - http::{RawStr, Status}, + http::Status, request::{Form, FormItems, FromForm}, response::Redirect, Route, @@ -526,32 +526,25 @@ fn invalid_json(error_message: &str, exception: bool) -> JsonResult { #[get("/account/prevalidate?")] #[allow(non_snake_case)] -fn prevalidate(domainHint: &RawStr, conn: DbConn) -> JsonResult { +fn prevalidate(domainHint: String, conn: DbConn) -> JsonResult { let empty_result = json!({}); - match domainHint.percent_decode() { - Ok(domain_hint) => { - let organization = Organization::find_by_identifier(&domain_hint.to_owned(), &conn); - match organization { - Some(organization) => { - if !organization.use_sso { - return invalid_json("SSO Not allowed for organization", false); - } - }, - None => { - return invalid_json("Organization not found by identifier", false); - }, - } - - if domainHint == "" { - return invalid_json("No Organization Identifier Provided", false); + let organization = Organization::find_by_identifier(&domainHint, &conn); + match organization { + Some(organization) => { + if !organization.use_sso { + return invalid_json("SSO Not allowed for organization", false); } - - Ok(Json(empty_result)) }, - Err(_) => { - return invalid_json("Invalid domainHint received", false); + None => { + return invalid_json("Organization not found by identifier", false); }, } + + if domainHint == "" { + return invalid_json("No Organization Identifier Provided", false); + } + + Ok(Json(empty_result)) } use openidconnect::core::{ @@ -596,13 +589,11 @@ fn get_client_from_identifier (identifier: &str, conn: &DbConn) -> Result&")] fn authorize( - domain_hint: &RawStr, - state: &RawStr, + domain_hint: String, + state: String, conn: DbConn, ) -> ApiResult { - let domain_hint_decoded = &domain_hint.percent_decode().expect("Invalid domain_hint").into_owned(); - let state_decoded = &state.percent_decode().expect("Invalid state").into_owned(); - match get_client_from_identifier(domain_hint_decoded, &conn) { + match get_client_from_identifier(&domain_hint, &conn) { Ok(client) => { // TODO store the nonce for validation on authorization token exchange - unclear where to store // this @@ -622,7 +613,7 @@ fn authorize( let new_pairs = old_pairs.map(|pair| { let (key, value) = pair; if key == "state" { - return format!("{}={}", key, state_decoded); + return format!("{}={}", key, state); } return format!("{}={}", key, value); });