|
@ -1,7 +1,7 @@ |
|
|
use chrono::Local; |
|
|
use chrono::Local; |
|
|
use num_traits::FromPrimitive; |
|
|
use num_traits::FromPrimitive; |
|
|
use rocket::{ |
|
|
use rocket::{ |
|
|
http::{RawStr, Status}, |
|
|
http::Status, |
|
|
request::{Form, FormItems, FromForm}, |
|
|
request::{Form, FormItems, FromForm}, |
|
|
response::Redirect, |
|
|
response::Redirect, |
|
|
Route, |
|
|
Route, |
|
@ -526,32 +526,25 @@ fn invalid_json(error_message: &str, exception: bool) -> JsonResult { |
|
|
|
|
|
|
|
|
#[get("/account/prevalidate?<domainHint>")] |
|
|
#[get("/account/prevalidate?<domainHint>")] |
|
|
#[allow(non_snake_case)] |
|
|
#[allow(non_snake_case)] |
|
|
fn prevalidate(domainHint: &RawStr, conn: DbConn) -> JsonResult { |
|
|
fn prevalidate(domainHint: String, conn: DbConn) -> JsonResult { |
|
|
let empty_result = json!({}); |
|
|
let empty_result = json!({}); |
|
|
match domainHint.percent_decode() { |
|
|
let organization = Organization::find_by_identifier(&domainHint, &conn); |
|
|
Ok(domain_hint) => { |
|
|
match organization { |
|
|
let organization = Organization::find_by_identifier(&domain_hint.to_owned(), &conn); |
|
|
Some(organization) => { |
|
|
match organization { |
|
|
if !organization.use_sso { |
|
|
Some(organization) => { |
|
|
return invalid_json("SSO Not allowed for organization", false); |
|
|
if !organization.use_sso { |
|
|
|
|
|
return invalid_json("SSO Not allowed for organization", false); |
|
|
|
|
|
} |
|
|
|
|
|
}, |
|
|
|
|
|
None => { |
|
|
|
|
|
return invalid_json("Organization not found by identifier", false); |
|
|
|
|
|
}, |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if domainHint == "" { |
|
|
|
|
|
return invalid_json("No Organization Identifier Provided", false); |
|
|
|
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
Ok(Json(empty_result)) |
|
|
|
|
|
}, |
|
|
}, |
|
|
Err(_) => { |
|
|
None => { |
|
|
return invalid_json("Invalid domainHint received", false); |
|
|
return invalid_json("Organization not found by identifier", false); |
|
|
}, |
|
|
}, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if domainHint == "" { |
|
|
|
|
|
return invalid_json("No Organization Identifier Provided", false); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Ok(Json(empty_result)) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
use openidconnect::core::{ |
|
|
use openidconnect::core::{ |
|
@ -596,13 +589,11 @@ fn get_client_from_identifier (identifier: &str, conn: &DbConn) -> Result<CoreCl |
|
|
|
|
|
|
|
|
#[get("/connect/authorize?<domain_hint>&<state>")] |
|
|
#[get("/connect/authorize?<domain_hint>&<state>")] |
|
|
fn authorize( |
|
|
fn authorize( |
|
|
domain_hint: &RawStr, |
|
|
domain_hint: String, |
|
|
state: &RawStr, |
|
|
state: String, |
|
|
conn: DbConn, |
|
|
conn: DbConn, |
|
|
) -> ApiResult<Redirect> { |
|
|
) -> ApiResult<Redirect> { |
|
|
let domain_hint_decoded = &domain_hint.percent_decode().expect("Invalid domain_hint").into_owned(); |
|
|
match get_client_from_identifier(&domain_hint, &conn) { |
|
|
let state_decoded = &state.percent_decode().expect("Invalid state").into_owned(); |
|
|
|
|
|
match get_client_from_identifier(domain_hint_decoded, &conn) { |
|
|
|
|
|
Ok(client) => { |
|
|
Ok(client) => { |
|
|
// TODO store the nonce for validation on authorization token exchange - unclear where to store
|
|
|
// TODO store the nonce for validation on authorization token exchange - unclear where to store
|
|
|
// this
|
|
|
// this
|
|
@ -622,7 +613,7 @@ fn authorize( |
|
|
let new_pairs = old_pairs.map(|pair| { |
|
|
let new_pairs = old_pairs.map(|pair| { |
|
|
let (key, value) = pair; |
|
|
let (key, value) = pair; |
|
|
if key == "state" { |
|
|
if key == "state" { |
|
|
return format!("{}={}", key, state_decoded); |
|
|
return format!("{}={}", key, state); |
|
|
} |
|
|
} |
|
|
return format!("{}={}", key, value); |
|
|
return format!("{}={}", key, value); |
|
|
}); |
|
|
}); |
|
|