From d0d261a3468c97e7500b8b3aa956ea99ad18e9d2 Mon Sep 17 00:00:00 2001 From: Stuart Heap Date: Wed, 1 Sep 2021 16:48:51 +0200 Subject: [PATCH] safe handling of RawStrs --- src/api/identity.rs | 39 +++++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/src/api/identity.rs b/src/api/identity.rs index 8f0a2ea7..f7bc54df 100644 --- a/src/api/identity.rs +++ b/src/api/identity.rs @@ -525,25 +525,30 @@ fn invalid_json(error_message: &str, exception: bool) -> JsonResult { #[allow(non_snake_case)] fn prevalidate(domainHint: &RawStr, conn: DbConn) -> JsonResult { let empty_result = json!({}); + match domainHint.percent_decode() { + Ok(domain_hint) => { + let organization = Organization::find_by_identifier(&domain_hint.to_owned(), &conn); + match organization { + Some(organization) => { + if !organization.use_sso { + return invalid_json("SSO Not allowed for organization", false); + } + }, + None => { + return invalid_json("Organization not found by identifier", false); + }, + } - // TODO as_str shouldn't be used here - let organization = Organization::find_by_identifier(domainHint.as_str(), &conn); - match organization { - Some(organization) => { - if !organization.use_sso { - return invalid_json("SSO Not allowed for organization", false); + if domainHint == "" { + return invalid_json("No Organization Identifier Provided", false); } + + Ok(Json(empty_result)) }, - None => { - return invalid_json("Organization not found by identifier", false); + Err(_) => { + return invalid_json("Invalid domainHint received", false); }, } - - if domainHint == "" { - return invalid_json("No Organization Identifier Provided", false); - } - - Ok(Json(empty_result)) } use openidconnect::core::{ @@ -601,7 +606,9 @@ fn authorize( state: &RawStr, conn: DbConn, ) -> Redirect { - let client = get_client_from_identifier(domain_hint.as_str(), &conn); + let domain_hint_decoded = &domain_hint.percent_decode().expect("Invalid domain_hint").into_owned(); + let state_decoded = &state.percent_decode().expect("Invalid state").into_owned(); + let client = get_client_from_identifier(domain_hint_decoded, &conn); let (mut authorize_url, _csrf_state, _nonce) = client .authorize_url( @@ -619,7 +626,7 @@ fn authorize( let new_pairs = old_pairs.map(|pair| { let (key, value) = pair; if key == "state" { - return format!("{}={}", key, state); + return format!("{}={}", key, state_decoded); } return format!("{}={}", key, value); });