From 3916c3c1b02e07f95c48aeec0ce66e116f772748 Mon Sep 17 00:00:00 2001 From: Timshel Date: Wed, 20 May 2026 19:42:55 +0200 Subject: [PATCH] Fix enforce blocked --- src/api/icons.rs | 2 +- src/http_client.rs | 44 +++++++++++++++++++++++++++++--------------- src/sso_client.rs | 11 ++++++++--- 3 files changed, 38 insertions(+), 19 deletions(-) diff --git a/src/api/icons.rs b/src/api/icons.rs index 02a14844..81191e38 100644 --- a/src/api/icons.rs +++ b/src/api/icons.rs @@ -65,7 +65,7 @@ static CLIENT: LazyLock = LazyLock::new(|| { let icon_download_timeout = Duration::from_secs(CONFIG.icon_download_timeout()); let pool_idle_timeout = Duration::from_secs(10); // Reuse the client between requests - get_reqwest_client_builder() + get_reqwest_client_builder(true) .cookie_provider(Arc::clone(&cookie_store)) .timeout(icon_download_timeout) .pool_max_idle_per_host(5) // Configure the Hyper Pool to only have max 5 idle connections diff --git a/src/http_client.rs b/src/http_client.rs index 232ba7da..205b1cc3 100644 --- a/src/http_client.rs +++ b/src/http_client.rs @@ -18,7 +18,7 @@ use crate::{CONFIG, util::is_global}; pub fn make_http_request(method: reqwest::Method, url: &str) -> Result { static INSTANCE: LazyLock = - LazyLock::new(|| get_reqwest_client_builder().build().expect("Failed to build client")); + LazyLock::new(|| get_reqwest_client_builder(true).build().expect("Failed to build client")); let Ok(url) = url::Url::parse(url) else { err!("Invalid URL"); @@ -32,7 +32,7 @@ pub fn make_http_request(method: reqwest::Method, url: &str) -> Result ClientBuilder { +pub fn get_reqwest_client_builder(enforce_block: bool) -> ClientBuilder { let mut headers = header::HeaderMap::new(); headers.insert(header::USER_AGENT, header::HeaderValue::from_static("Vaultwarden")); @@ -55,7 +55,7 @@ pub fn get_reqwest_client_builder() -> ClientBuilder { Client::builder() .default_headers(headers) .redirect(redirect_policy) - .dns_resolver(CustomDnsResolver::instance()) + .dns_resolver(CustomDns::instance(enforce_block)) .timeout(Duration::from_secs(10)) } @@ -210,6 +210,11 @@ impl fmt::Display for CustomHttpClientError { impl std::error::Error for CustomHttpClientError {} +pub struct CustomDns { + enforce_block: bool, + resolver: Arc, +} + #[derive(Debug, Clone)] enum CustomDnsResolver { Default(), @@ -217,12 +222,18 @@ enum CustomDnsResolver { } type BoxError = Box; -impl CustomDnsResolver { - fn instance() -> Arc { +impl CustomDns { + fn instance(enforce_block: bool) -> Self { static INSTANCE: LazyLock> = LazyLock::new(CustomDnsResolver::new); - Arc::clone(&*INSTANCE) + + CustomDns { + enforce_block, + resolver: Arc::clone(&*INSTANCE), + } } +} +impl CustomDnsResolver { fn new() -> Arc { TokioResolver::builder(TokioRuntimeProvider::default()) .and_then(|mut builder| { @@ -239,30 +250,32 @@ impl CustomDnsResolver { } // Note that we get an iterator of addresses, but we only grab the first one for convenience - async fn resolve_domain(&self, name: &str) -> Result, BoxError> { - pre_resolve(name)?; + async fn resolve_domain(&self, name: &str, enforce_block: bool) -> Result, BoxError> { + pre_resolve(name, enforce_block)?; let results: Vec = match self { Self::Default() => tokio::net::lookup_host((name, 0)).await?.collect(), Self::Hickory(r) => r.lookup_ip(name).await?.iter().map(|i| SocketAddr::new(i, 0)).collect(), }; - for addr in &results { - post_resolve(name, addr.ip())?; + if enforce_block { + for addr in &results { + post_resolve(name, addr.ip())?; + } } Ok(results) } } -fn pre_resolve(name: &str) -> Result<(), CustomHttpClientError> { +fn pre_resolve(name: &str, enforce_block: bool) -> Result<(), CustomHttpClientError> { let Ok(host) = get_valid_host(name) else { return Err(CustomHttpClientError::Invalid { domain: name.to_owned(), }); }; - if should_block_host(&host).is_err() { + if enforce_block && should_block_host(&host).is_err() { return Err(CustomHttpClientError::Blocked { domain: name.to_owned(), }); @@ -282,12 +295,13 @@ fn post_resolve(name: &str, ip: IpAddr) -> Result<(), CustomHttpClientError> { } } -impl Resolve for CustomDnsResolver { +impl Resolve for CustomDns { fn resolve(&self, name: Name) -> Resolving { - let this = self.clone(); + let enforce_block = self.enforce_block; + let this = Arc::clone(&self.resolver); Box::pin(async move { let name = name.as_str(); - let results = this.resolve_domain(name).await?; + let results = this.resolve_domain(name, enforce_block).await?; if results.is_empty() { warn!("Unable to resolve {name} to any valid IP address"); } diff --git a/src/sso_client.rs b/src/sso_client.rs index 5aa77750..55b9677c 100644 --- a/src/sso_client.rs +++ b/src/sso_client.rs @@ -71,7 +71,7 @@ pub struct OidcHttpClient { impl OidcHttpClient { fn new() -> Result { - get_reqwest_client_builder().redirect(reqwest::redirect::Policy::none()).build().map(|client| Self { + get_reqwest_client_builder(false).redirect(reqwest::redirect::Policy::none()).build().map(|client| Self { client, }) } @@ -83,7 +83,10 @@ impl<'c> AsyncHttpClient<'c> for OidcHttpClient { fn call(&'c self, request: HttpRequest) -> Self::Future { Box::pin(async move { - let response = self.client.execute(request.try_into().map_err(Box::new)?).await.map_err(Box::new)?; + let response = self.client.execute(request.try_into().map_err(Box::new)?).await.map_err(|e| { + debug!("Request failed {e:?}"); + Box::new(e) + })?; let mut builder = http::Response::builder().status(response.status()).version(response.version()); @@ -91,7 +94,9 @@ impl<'c> AsyncHttpClient<'c> for OidcHttpClient { builder = builder.header(name, value); } - builder.body(response.bytes().await.map_err(Box::new)?.to_vec()).map_err(HttpClientError::Http) + let body = response.bytes().await.map_err(Box::new)?; + debug!("Response body {}", String::from_utf8_lossy(&body)); + builder.body(body.to_vec()).map_err(HttpClientError::Http) }) } }