Browse Source

Merge 3916c3c1b0 into d626ea81ab

pull/7246/merge
Timshel 3 days ago
committed by GitHub
parent
commit
5fa8947e36
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 2
      src/api/icons.rs
  2. 40
      src/http_client.rs
  3. 11
      src/sso_client.rs

2
src/api/icons.rs

@ -65,7 +65,7 @@ static CLIENT: LazyLock<Client> = LazyLock::new(|| {
let icon_download_timeout = Duration::from_secs(CONFIG.icon_download_timeout());
let pool_idle_timeout = Duration::from_secs(10);
// Reuse the client between requests
get_reqwest_client_builder()
get_reqwest_client_builder(true)
.cookie_provider(Arc::clone(&cookie_store))
.timeout(icon_download_timeout)
.pool_max_idle_per_host(5) // Configure the Hyper Pool to only have max 5 idle connections

40
src/http_client.rs

@ -18,7 +18,7 @@ use crate::{CONFIG, util::is_global};
pub fn make_http_request(method: reqwest::Method, url: &str) -> Result<reqwest::RequestBuilder, crate::Error> {
static INSTANCE: LazyLock<Client> =
LazyLock::new(|| get_reqwest_client_builder().build().expect("Failed to build client"));
LazyLock::new(|| get_reqwest_client_builder(true).build().expect("Failed to build client"));
let Ok(url) = url::Url::parse(url) else {
err!("Invalid URL");
@ -32,7 +32,7 @@ pub fn make_http_request(method: reqwest::Method, url: &str) -> Result<reqwest::
Ok(INSTANCE.request(method, url))
}
pub fn get_reqwest_client_builder() -> ClientBuilder {
pub fn get_reqwest_client_builder(enforce_block: bool) -> ClientBuilder {
let mut headers = header::HeaderMap::new();
headers.insert(header::USER_AGENT, header::HeaderValue::from_static("Vaultwarden"));
@ -55,7 +55,7 @@ pub fn get_reqwest_client_builder() -> ClientBuilder {
Client::builder()
.default_headers(headers)
.redirect(redirect_policy)
.dns_resolver(CustomDnsResolver::instance())
.dns_resolver(CustomDns::instance(enforce_block))
.timeout(Duration::from_secs(10))
}
@ -210,6 +210,11 @@ impl fmt::Display for CustomHttpClientError {
impl std::error::Error for CustomHttpClientError {}
pub struct CustomDns {
enforce_block: bool,
resolver: Arc<CustomDnsResolver>,
}
#[derive(Debug, Clone)]
enum CustomDnsResolver {
Default(),
@ -217,12 +222,18 @@ enum CustomDnsResolver {
}
type BoxError = Box<dyn std::error::Error + Send + Sync>;
impl CustomDnsResolver {
fn instance() -> Arc<Self> {
impl CustomDns {
fn instance(enforce_block: bool) -> Self {
static INSTANCE: LazyLock<Arc<CustomDnsResolver>> = LazyLock::new(CustomDnsResolver::new);
Arc::clone(&*INSTANCE)
CustomDns {
enforce_block,
resolver: Arc::clone(&*INSTANCE),
}
}
}
impl CustomDnsResolver {
fn new() -> Arc<Self> {
TokioResolver::builder(TokioRuntimeProvider::default())
.and_then(|mut builder| {
@ -239,30 +250,32 @@ impl CustomDnsResolver {
}
// Note that we get an iterator of addresses, but we only grab the first one for convenience
async fn resolve_domain(&self, name: &str) -> Result<Vec<SocketAddr>, BoxError> {
pre_resolve(name)?;
async fn resolve_domain(&self, name: &str, enforce_block: bool) -> Result<Vec<SocketAddr>, BoxError> {
pre_resolve(name, enforce_block)?;
let results: Vec<SocketAddr> = match self {
Self::Default() => tokio::net::lookup_host((name, 0)).await?.collect(),
Self::Hickory(r) => r.lookup_ip(name).await?.iter().map(|i| SocketAddr::new(i, 0)).collect(),
};
if enforce_block {
for addr in &results {
post_resolve(name, addr.ip())?;
}
}
Ok(results)
}
}
fn pre_resolve(name: &str) -> Result<(), CustomHttpClientError> {
fn pre_resolve(name: &str, enforce_block: bool) -> Result<(), CustomHttpClientError> {
let Ok(host) = get_valid_host(name) else {
return Err(CustomHttpClientError::Invalid {
domain: name.to_owned(),
});
};
if should_block_host(&host).is_err() {
if enforce_block && should_block_host(&host).is_err() {
return Err(CustomHttpClientError::Blocked {
domain: name.to_owned(),
});
@ -282,12 +295,13 @@ fn post_resolve(name: &str, ip: IpAddr) -> Result<(), CustomHttpClientError> {
}
}
impl Resolve for CustomDnsResolver {
impl Resolve for CustomDns {
fn resolve(&self, name: Name) -> Resolving {
let this = self.clone();
let enforce_block = self.enforce_block;
let this = Arc::clone(&self.resolver);
Box::pin(async move {
let name = name.as_str();
let results = this.resolve_domain(name).await?;
let results = this.resolve_domain(name, enforce_block).await?;
if results.is_empty() {
warn!("Unable to resolve {name} to any valid IP address");
}

11
src/sso_client.rs

@ -71,7 +71,7 @@ pub struct OidcHttpClient {
impl OidcHttpClient {
fn new() -> Result<Self, reqwest::Error> {
get_reqwest_client_builder().redirect(reqwest::redirect::Policy::none()).build().map(|client| Self {
get_reqwest_client_builder(false).redirect(reqwest::redirect::Policy::none()).build().map(|client| Self {
client,
})
}
@ -83,7 +83,10 @@ impl<'c> AsyncHttpClient<'c> for OidcHttpClient {
fn call(&'c self, request: HttpRequest) -> Self::Future {
Box::pin(async move {
let response = self.client.execute(request.try_into().map_err(Box::new)?).await.map_err(Box::new)?;
let response = self.client.execute(request.try_into().map_err(Box::new)?).await.map_err(|e| {
debug!("Request failed {e:?}");
Box::new(e)
})?;
let mut builder = http::Response::builder().status(response.status()).version(response.version());
@ -91,7 +94,9 @@ impl<'c> AsyncHttpClient<'c> for OidcHttpClient {
builder = builder.header(name, value);
}
builder.body(response.bytes().await.map_err(Box::new)?.to_vec()).map_err(HttpClientError::Http)
let body = response.bytes().await.map_err(Box::new)?;
debug!("Response body {}", String::from_utf8_lossy(&body));
builder.body(body.to_vec()).map_err(HttpClientError::Http)
})
}
}

Loading…
Cancel
Save