Browse Source

Added SSO numeric normalizations

Some SSO platforms (most prominent being MS Entra ID platform) uses strings for numerical value fields like "expires_on", "expires_in", etc.
This commit make considerations for these kind of ineptitude to prevent frustrations
pull/6222/head
Kalyan Parajuli 2 weeks ago
parent
commit
fb0a0c2ca7
  1. 114
      src/sso_client.rs

114
src/sso_client.rs

@ -1,4 +1,5 @@
use regex::Regex;
use serde_json::Value;
use std::borrow::Cow;
use std::time::Duration;
use url::Url;
@ -15,7 +16,6 @@ use crate::{
sso::{OIDCCode, OIDCState},
CONFIG,
};
static CLIENT_CACHE_KEY: Lazy<String> = Lazy::new(|| "sso-client".to_string());
static CLIENT_CACHE: Lazy<Cache<String, Client>> = Lazy::new(|| {
Cache::builder().max_capacity(1).time_to_live(Duration::from_secs(CONFIG.sso_client_cache_expiration())).build()
@ -155,19 +155,8 @@ impl Client {
>,
IdTokenClaims<EmptyAdditionalClaims, CoreGenderClaim>,
)> {
let oidc_code = AuthorizationCode::new(code.to_string());
let mut exchange = self.core_client.exchange_code(oidc_code);
if CONFIG.sso_pkce() {
match nonce.verifier {
None => err!(format!("Missing verifier in the DB nonce table")),
Some(secret) => exchange = exchange.set_pkce_verifier(PkceCodeVerifier::new(secret.clone())),
}
}
match exchange.request_async(&self.http_client).await {
Err(err) => err!(format!("Failed to contact token endpoint: {:?}", err)),
match self.perform_code_request(code, &nonce).await {
Err(err) => err!(format!("Endpoint response error: {:?}", err)),
Ok(token_response) => {
let oidc_nonce = Nonce::new(nonce.nonce);
@ -196,6 +185,86 @@ impl Client {
}
}
async fn perform_code_request(
&self,
code: OIDCCode,
nonce: &SsoNonce,
) -> ApiResult<
StandardTokenResponse<
IdTokenFields<
EmptyAdditionalClaims,
EmptyExtraTokenFields,
CoreGenderClaim,
CoreJweContentEncryptionAlgorithm,
CoreJwsSigningAlgorithm,
>,
CoreTokenType,
>,
> {
let oidc_code = AuthorizationCode::new(code.to_string());
let mut exchange = self.core_client.exchange_code(oidc_code);
if CONFIG.sso_pkce() {
match &nonce.verifier {
None => err!(format!("Missing verifier in the DB nonce table")),
Some(secret) => exchange = exchange.set_pkce_verifier(PkceCodeVerifier::new(secret.clone())),
}
}
match exchange.request_async(&self.http_client).await {
Err(err) => Self::attempt_parsing_recovery(err),
Ok(token_response) => Ok(token_response),
}
}
fn attempt_parsing_recovery(
error: RequestTokenError<HttpClientError<reqwest::Error>, StandardErrorResponse<CoreErrorResponseType>>,
) -> ApiResult<
StandardTokenResponse<
IdTokenFields<
EmptyAdditionalClaims,
EmptyExtraTokenFields,
CoreGenderClaim,
CoreJweContentEncryptionAlgorithm,
CoreJwsSigningAlgorithm,
>,
CoreTokenType,
>,
> {
match &error {
RequestTokenError::Parse(_, response_bytes) => {
let response_body = String::from_utf8_lossy(response_bytes);
let mut parsed: Value = serde_json::from_str(&response_body)?;
// Normalize numeric fields which might be present as strings
Self::normalize_numeric_fields(&mut parsed);
// Parse back to token response
let token_response = serde_json::from_value(parsed);
match token_response {
Err(err) => {
err!(format!("Failed to parse token endpoint response: {:?}", err))
}
Ok(token_response) => Ok(token_response),
}
}
_ => err!(format!("Failed to contact access token endpoint: {:?}", error)),
}
}
fn normalize_numeric_fields(response_parsed: &mut Value) {
let numeric_fields = ["expires_in", "ext_expires_in", "expires_on"]; // MS Entra fields present as strings
for field in &numeric_fields {
if let Some(field_val) = response_parsed.get_mut(field) {
if let Some(string_value) = field_val.as_str() {
if let Ok(num_val) = string_value.parse::<u64>() {
*field_val = Value::Number(num_val.into());
}
}
}
}
}
pub async fn user_info(&self, access_token: AccessToken) -> ApiResult<CoreUserInfoClaims> {
match self.core_client.user_info(access_token, None).request_async(&self.http_client).await {
Err(err) => err!(format!("Request to user_info endpoint failed: {err}")),
@ -236,15 +305,18 @@ impl Client {
let client = Client::cached().await?;
let token_response =
match client.core_client.exchange_refresh_token(&rt).request_async(&client.http_client).await {
Err(err) => err!(format!("Request to exchange_refresh_token endpoint failed: {:?}", err)),
Ok(token_response) => token_response,
Err(err) => Self::attempt_parsing_recovery(err),
Ok(token_response) => Ok(token_response),
};
Ok((
token_response.refresh_token().map(|token| token.secret().clone()),
token_response.access_token().secret().clone(),
token_response.expires_in(),
))
match token_response {
Err(err) => err!(format!("Request to exchange_refresh_token endpoint failed: {:?}", err)),
Ok(token_response) => Ok((
token_response.refresh_token().map(|token| token.secret().clone()),
token_response.access_token().secret().clone(),
token_response.expires_in(),
)),
}
}
}

Loading…
Cancel
Save