From fb0a0c2ca7463837f2f9d5e17151dd9b9c2c799e Mon Sep 17 00:00:00 2001 From: Kalyan Parajuli Date: Sun, 24 Aug 2025 22:27:15 -0700 Subject: [PATCH] Added SSO numeric normalizations Some SSO platforms (most prominent being MS Entra ID platform) uses strings for numerical value fields like "expires_on", "expires_in", etc. This commit make considerations for these kind of ineptitude to prevent frustrations --- src/sso_client.rs | 114 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 93 insertions(+), 21 deletions(-) diff --git a/src/sso_client.rs b/src/sso_client.rs index 3d2a3c48..f98886ea 100644 --- a/src/sso_client.rs +++ b/src/sso_client.rs @@ -1,4 +1,5 @@ use regex::Regex; +use serde_json::Value; use std::borrow::Cow; use std::time::Duration; use url::Url; @@ -15,7 +16,6 @@ use crate::{ sso::{OIDCCode, OIDCState}, CONFIG, }; - static CLIENT_CACHE_KEY: Lazy = Lazy::new(|| "sso-client".to_string()); static CLIENT_CACHE: Lazy> = Lazy::new(|| { Cache::builder().max_capacity(1).time_to_live(Duration::from_secs(CONFIG.sso_client_cache_expiration())).build() @@ -155,19 +155,8 @@ impl Client { >, IdTokenClaims, )> { - let oidc_code = AuthorizationCode::new(code.to_string()); - - let mut exchange = self.core_client.exchange_code(oidc_code); - - if CONFIG.sso_pkce() { - match nonce.verifier { - None => err!(format!("Missing verifier in the DB nonce table")), - Some(secret) => exchange = exchange.set_pkce_verifier(PkceCodeVerifier::new(secret.clone())), - } - } - - match exchange.request_async(&self.http_client).await { - Err(err) => err!(format!("Failed to contact token endpoint: {:?}", err)), + match self.perform_code_request(code, &nonce).await { + Err(err) => err!(format!("Endpoint response error: {:?}", err)), Ok(token_response) => { let oidc_nonce = Nonce::new(nonce.nonce); @@ -196,6 +185,86 @@ impl Client { } } + async fn perform_code_request( + &self, + code: OIDCCode, + nonce: &SsoNonce, + ) -> ApiResult< + StandardTokenResponse< + IdTokenFields< + EmptyAdditionalClaims, + EmptyExtraTokenFields, + CoreGenderClaim, + CoreJweContentEncryptionAlgorithm, + CoreJwsSigningAlgorithm, + >, + CoreTokenType, + >, + > { + let oidc_code = AuthorizationCode::new(code.to_string()); + + let mut exchange = self.core_client.exchange_code(oidc_code); + + if CONFIG.sso_pkce() { + match &nonce.verifier { + None => err!(format!("Missing verifier in the DB nonce table")), + Some(secret) => exchange = exchange.set_pkce_verifier(PkceCodeVerifier::new(secret.clone())), + } + } + + match exchange.request_async(&self.http_client).await { + Err(err) => Self::attempt_parsing_recovery(err), + Ok(token_response) => Ok(token_response), + } + } + + fn attempt_parsing_recovery( + error: RequestTokenError, StandardErrorResponse>, + ) -> ApiResult< + StandardTokenResponse< + IdTokenFields< + EmptyAdditionalClaims, + EmptyExtraTokenFields, + CoreGenderClaim, + CoreJweContentEncryptionAlgorithm, + CoreJwsSigningAlgorithm, + >, + CoreTokenType, + >, + > { + match &error { + RequestTokenError::Parse(_, response_bytes) => { + let response_body = String::from_utf8_lossy(response_bytes); + let mut parsed: Value = serde_json::from_str(&response_body)?; + // Normalize numeric fields which might be present as strings + Self::normalize_numeric_fields(&mut parsed); + // Parse back to token response + let token_response = serde_json::from_value(parsed); + match token_response { + Err(err) => { + err!(format!("Failed to parse token endpoint response: {:?}", err)) + } + Ok(token_response) => Ok(token_response), + } + } + _ => err!(format!("Failed to contact access token endpoint: {:?}", error)), + } + } + + fn normalize_numeric_fields(response_parsed: &mut Value) { + let numeric_fields = ["expires_in", "ext_expires_in", "expires_on"]; // MS Entra fields present as strings + + for field in &numeric_fields { + if let Some(field_val) = response_parsed.get_mut(field) { + if let Some(string_value) = field_val.as_str() { + if let Ok(num_val) = string_value.parse::() { + *field_val = Value::Number(num_val.into()); + } + } + } + } + } + pub async fn user_info(&self, access_token: AccessToken) -> ApiResult { match self.core_client.user_info(access_token, None).request_async(&self.http_client).await { Err(err) => err!(format!("Request to user_info endpoint failed: {err}")), @@ -236,15 +305,18 @@ impl Client { let client = Client::cached().await?; let token_response = match client.core_client.exchange_refresh_token(&rt).request_async(&client.http_client).await { - Err(err) => err!(format!("Request to exchange_refresh_token endpoint failed: {:?}", err)), - Ok(token_response) => token_response, + Err(err) => Self::attempt_parsing_recovery(err), + Ok(token_response) => Ok(token_response), }; - Ok(( - token_response.refresh_token().map(|token| token.secret().clone()), - token_response.access_token().secret().clone(), - token_response.expires_in(), - )) + match token_response { + Err(err) => err!(format!("Request to exchange_refresh_token endpoint failed: {:?}", err)), + Ok(token_response) => Ok(( + token_response.refresh_token().map(|token| token.secret().clone()), + token_response.access_token().secret().clone(), + token_response.expires_in(), + )), + } } }