Browse Source

Fix enforce blocked

pull/7246/head
Timshel 3 days ago
parent
commit
3916c3c1b0
  1. 2
      src/api/icons.rs
  2. 40
      src/http_client.rs
  3. 11
      src/sso_client.rs

2
src/api/icons.rs

@ -65,7 +65,7 @@ static CLIENT: LazyLock<Client> = LazyLock::new(|| {
let icon_download_timeout = Duration::from_secs(CONFIG.icon_download_timeout()); let icon_download_timeout = Duration::from_secs(CONFIG.icon_download_timeout());
let pool_idle_timeout = Duration::from_secs(10); let pool_idle_timeout = Duration::from_secs(10);
// Reuse the client between requests // Reuse the client between requests
get_reqwest_client_builder() get_reqwest_client_builder(true)
.cookie_provider(Arc::clone(&cookie_store)) .cookie_provider(Arc::clone(&cookie_store))
.timeout(icon_download_timeout) .timeout(icon_download_timeout)
.pool_max_idle_per_host(5) // Configure the Hyper Pool to only have max 5 idle connections .pool_max_idle_per_host(5) // Configure the Hyper Pool to only have max 5 idle connections

40
src/http_client.rs

@ -18,7 +18,7 @@ use crate::{CONFIG, util::is_global};
pub fn make_http_request(method: reqwest::Method, url: &str) -> Result<reqwest::RequestBuilder, crate::Error> { pub fn make_http_request(method: reqwest::Method, url: &str) -> Result<reqwest::RequestBuilder, crate::Error> {
static INSTANCE: LazyLock<Client> = static INSTANCE: LazyLock<Client> =
LazyLock::new(|| get_reqwest_client_builder().build().expect("Failed to build client")); LazyLock::new(|| get_reqwest_client_builder(true).build().expect("Failed to build client"));
let Ok(url) = url::Url::parse(url) else { let Ok(url) = url::Url::parse(url) else {
err!("Invalid URL"); err!("Invalid URL");
@ -32,7 +32,7 @@ pub fn make_http_request(method: reqwest::Method, url: &str) -> Result<reqwest::
Ok(INSTANCE.request(method, url)) Ok(INSTANCE.request(method, url))
} }
pub fn get_reqwest_client_builder() -> ClientBuilder { pub fn get_reqwest_client_builder(enforce_block: bool) -> ClientBuilder {
let mut headers = header::HeaderMap::new(); let mut headers = header::HeaderMap::new();
headers.insert(header::USER_AGENT, header::HeaderValue::from_static("Vaultwarden")); headers.insert(header::USER_AGENT, header::HeaderValue::from_static("Vaultwarden"));
@ -55,7 +55,7 @@ pub fn get_reqwest_client_builder() -> ClientBuilder {
Client::builder() Client::builder()
.default_headers(headers) .default_headers(headers)
.redirect(redirect_policy) .redirect(redirect_policy)
.dns_resolver(CustomDnsResolver::instance()) .dns_resolver(CustomDns::instance(enforce_block))
.timeout(Duration::from_secs(10)) .timeout(Duration::from_secs(10))
} }
@ -210,6 +210,11 @@ impl fmt::Display for CustomHttpClientError {
impl std::error::Error for CustomHttpClientError {} impl std::error::Error for CustomHttpClientError {}
pub struct CustomDns {
enforce_block: bool,
resolver: Arc<CustomDnsResolver>,
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
enum CustomDnsResolver { enum CustomDnsResolver {
Default(), Default(),
@ -217,12 +222,18 @@ enum CustomDnsResolver {
} }
type BoxError = Box<dyn std::error::Error + Send + Sync>; type BoxError = Box<dyn std::error::Error + Send + Sync>;
impl CustomDnsResolver { impl CustomDns {
fn instance() -> Arc<Self> { fn instance(enforce_block: bool) -> Self {
static INSTANCE: LazyLock<Arc<CustomDnsResolver>> = LazyLock::new(CustomDnsResolver::new); static INSTANCE: LazyLock<Arc<CustomDnsResolver>> = LazyLock::new(CustomDnsResolver::new);
Arc::clone(&*INSTANCE)
CustomDns {
enforce_block,
resolver: Arc::clone(&*INSTANCE),
} }
}
}
impl CustomDnsResolver {
fn new() -> Arc<Self> { fn new() -> Arc<Self> {
TokioResolver::builder(TokioRuntimeProvider::default()) TokioResolver::builder(TokioRuntimeProvider::default())
.and_then(|mut builder| { .and_then(|mut builder| {
@ -239,30 +250,32 @@ impl CustomDnsResolver {
} }
// Note that we get an iterator of addresses, but we only grab the first one for convenience // Note that we get an iterator of addresses, but we only grab the first one for convenience
async fn resolve_domain(&self, name: &str) -> Result<Vec<SocketAddr>, BoxError> { async fn resolve_domain(&self, name: &str, enforce_block: bool) -> Result<Vec<SocketAddr>, BoxError> {
pre_resolve(name)?; pre_resolve(name, enforce_block)?;
let results: Vec<SocketAddr> = match self { let results: Vec<SocketAddr> = match self {
Self::Default() => tokio::net::lookup_host((name, 0)).await?.collect(), Self::Default() => tokio::net::lookup_host((name, 0)).await?.collect(),
Self::Hickory(r) => r.lookup_ip(name).await?.iter().map(|i| SocketAddr::new(i, 0)).collect(), Self::Hickory(r) => r.lookup_ip(name).await?.iter().map(|i| SocketAddr::new(i, 0)).collect(),
}; };
if enforce_block {
for addr in &results { for addr in &results {
post_resolve(name, addr.ip())?; post_resolve(name, addr.ip())?;
} }
}
Ok(results) Ok(results)
} }
} }
fn pre_resolve(name: &str) -> Result<(), CustomHttpClientError> { fn pre_resolve(name: &str, enforce_block: bool) -> Result<(), CustomHttpClientError> {
let Ok(host) = get_valid_host(name) else { let Ok(host) = get_valid_host(name) else {
return Err(CustomHttpClientError::Invalid { return Err(CustomHttpClientError::Invalid {
domain: name.to_owned(), domain: name.to_owned(),
}); });
}; };
if should_block_host(&host).is_err() { if enforce_block && should_block_host(&host).is_err() {
return Err(CustomHttpClientError::Blocked { return Err(CustomHttpClientError::Blocked {
domain: name.to_owned(), domain: name.to_owned(),
}); });
@ -282,12 +295,13 @@ fn post_resolve(name: &str, ip: IpAddr) -> Result<(), CustomHttpClientError> {
} }
} }
impl Resolve for CustomDnsResolver { impl Resolve for CustomDns {
fn resolve(&self, name: Name) -> Resolving { fn resolve(&self, name: Name) -> Resolving {
let this = self.clone(); let enforce_block = self.enforce_block;
let this = Arc::clone(&self.resolver);
Box::pin(async move { Box::pin(async move {
let name = name.as_str(); let name = name.as_str();
let results = this.resolve_domain(name).await?; let results = this.resolve_domain(name, enforce_block).await?;
if results.is_empty() { if results.is_empty() {
warn!("Unable to resolve {name} to any valid IP address"); warn!("Unable to resolve {name} to any valid IP address");
} }

11
src/sso_client.rs

@ -71,7 +71,7 @@ pub struct OidcHttpClient {
impl OidcHttpClient { impl OidcHttpClient {
fn new() -> Result<Self, reqwest::Error> { fn new() -> Result<Self, reqwest::Error> {
get_reqwest_client_builder().redirect(reqwest::redirect::Policy::none()).build().map(|client| Self { get_reqwest_client_builder(false).redirect(reqwest::redirect::Policy::none()).build().map(|client| Self {
client, client,
}) })
} }
@ -83,7 +83,10 @@ impl<'c> AsyncHttpClient<'c> for OidcHttpClient {
fn call(&'c self, request: HttpRequest) -> Self::Future { fn call(&'c self, request: HttpRequest) -> Self::Future {
Box::pin(async move { Box::pin(async move {
let response = self.client.execute(request.try_into().map_err(Box::new)?).await.map_err(Box::new)?; let response = self.client.execute(request.try_into().map_err(Box::new)?).await.map_err(|e| {
debug!("Request failed {e:?}");
Box::new(e)
})?;
let mut builder = http::Response::builder().status(response.status()).version(response.version()); let mut builder = http::Response::builder().status(response.status()).version(response.version());
@ -91,7 +94,9 @@ impl<'c> AsyncHttpClient<'c> for OidcHttpClient {
builder = builder.header(name, value); builder = builder.header(name, value);
} }
builder.body(response.bytes().await.map_err(Box::new)?.to_vec()).map_err(HttpClientError::Http) let body = response.bytes().await.map_err(Box::new)?;
debug!("Response body {}", String::from_utf8_lossy(&body));
builder.body(body.to_vec()).map_err(HttpClientError::Http)
}) })
} }
} }

Loading…
Cancel
Save