You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
560 lines
17 KiB
560 lines
17 KiB
use std::{
|
|
fmt,
|
|
net::{IpAddr, SocketAddr},
|
|
sync::{Arc, LazyLock, Mutex},
|
|
time::Duration,
|
|
};
|
|
|
|
use hickory_resolver::{net::runtime::TokioRuntimeProvider, TokioResolver};
|
|
use regex::Regex;
|
|
use reqwest::{
|
|
dns::{Name, Resolve, Resolving},
|
|
header, Client, ClientBuilder,
|
|
};
|
|
use url::Host;
|
|
|
|
use crate::{util::is_global, CONFIG};
|
|
|
|
pub fn make_http_request(method: reqwest::Method, url: &str) -> Result<reqwest::RequestBuilder, crate::Error> {
|
|
let Ok(url) = url::Url::parse(url) else {
|
|
err!("Invalid URL");
|
|
};
|
|
let Some(host) = url.host() else {
|
|
err!("Invalid host");
|
|
};
|
|
|
|
should_block_host(&host)?;
|
|
|
|
static INSTANCE: LazyLock<Client> =
|
|
LazyLock::new(|| get_reqwest_client_builder().build().expect("Failed to build client"));
|
|
|
|
Ok(INSTANCE.request(method, url))
|
|
}
|
|
|
|
pub fn get_reqwest_client_builder() -> ClientBuilder {
|
|
let mut headers = header::HeaderMap::new();
|
|
headers.insert(header::USER_AGENT, header::HeaderValue::from_static("Vaultwarden"));
|
|
|
|
let redirect_policy = reqwest::redirect::Policy::custom(|attempt| {
|
|
if attempt.previous().len() >= 5 {
|
|
return attempt.error("Too many redirects");
|
|
}
|
|
|
|
let Some(host) = attempt.url().host() else {
|
|
return attempt.error("Invalid host");
|
|
};
|
|
|
|
if let Err(e) = should_block_host(&host) {
|
|
return attempt.error(e);
|
|
}
|
|
|
|
attempt.follow()
|
|
});
|
|
|
|
Client::builder()
|
|
.default_headers(headers)
|
|
.redirect(redirect_policy)
|
|
.dns_resolver(CustomDnsResolver::instance())
|
|
.timeout(Duration::from_secs(10))
|
|
}
|
|
|
|
fn should_block_ip(ip: IpAddr) -> bool {
|
|
if !CONFIG.http_request_block_non_global_ips() {
|
|
return false;
|
|
}
|
|
|
|
!is_global(ip)
|
|
}
|
|
|
|
fn should_block_address_regex(domain_or_ip: &str) -> bool {
|
|
let Some(block_regex) = CONFIG.http_request_block_regex() else {
|
|
return false;
|
|
};
|
|
|
|
static COMPILED_REGEX: Mutex<Option<(String, Regex)>> = Mutex::new(None);
|
|
let mut guard = COMPILED_REGEX.lock().unwrap();
|
|
|
|
// If the stored regex is up to date, use it
|
|
if let Some((value, regex)) = &*guard {
|
|
if value == &block_regex {
|
|
return regex.is_match(domain_or_ip);
|
|
}
|
|
}
|
|
|
|
// If we don't have a regex stored, or it's not up to date, recreate it
|
|
let regex = Regex::new(&block_regex).unwrap();
|
|
let is_match = regex.is_match(domain_or_ip);
|
|
*guard = Some((block_regex, regex));
|
|
|
|
is_match
|
|
}
|
|
|
|
pub fn get_valid_host(host: &str) -> Result<Host, CustomHttpClientError> {
|
|
let Ok(host) = Host::parse(host) else {
|
|
return Err(CustomHttpClientError::Invalid {
|
|
domain: host.to_string(),
|
|
});
|
|
};
|
|
|
|
// Some extra checks to validate hosts
|
|
match host {
|
|
Host::Domain(ref domain) => {
|
|
// Host::parse() does not verify length or all possible invalid characters
|
|
// We do some extra checks here to prevent issues
|
|
if domain.len() > 253 {
|
|
debug!("Domain validation error: '{domain}' exceeds 253 characters");
|
|
return Err(CustomHttpClientError::Invalid {
|
|
domain: host.to_string(),
|
|
});
|
|
}
|
|
if !domain.split('.').all(|label| {
|
|
!label.is_empty()
|
|
// Labels can't be longer than 63 chars
|
|
&& label.len() <= 63
|
|
// Labels are not allowed to start or end with a hyphen `-`
|
|
&& !label.starts_with('-')
|
|
&& !label.ends_with('-')
|
|
// Only ASCII Alphanumeric characters are allowed
|
|
// We already received a punycoded domain back, so no unicode should exists here
|
|
&& label.chars().all(|c| c.is_ascii_alphanumeric() || c == '-')
|
|
}) {
|
|
debug!(
|
|
"Domain validation error: '{domain}' labels contain invalid characters or exceed the maximum length"
|
|
);
|
|
return Err(CustomHttpClientError::Invalid {
|
|
domain: host.to_string(),
|
|
});
|
|
}
|
|
}
|
|
Host::Ipv4(_) | Host::Ipv6(_) => {}
|
|
}
|
|
|
|
Ok(host)
|
|
}
|
|
|
|
pub fn should_block_host<S: AsRef<str>>(host: &Host<S>) -> Result<(), CustomHttpClientError> {
|
|
let (ip, host_str): (Option<IpAddr>, String) = match host {
|
|
Host::Ipv4(ip) => (Some(IpAddr::V4(*ip)), ip.to_string()),
|
|
Host::Ipv6(ip) => (Some(IpAddr::V6(*ip)), ip.to_string()),
|
|
Host::Domain(d) => (None, d.as_ref().to_string()),
|
|
};
|
|
|
|
if let Some(ip) = ip {
|
|
if should_block_ip(ip) {
|
|
return Err(CustomHttpClientError::NonGlobalIp {
|
|
domain: None,
|
|
ip,
|
|
});
|
|
}
|
|
}
|
|
|
|
if should_block_address_regex(&host_str) {
|
|
return Err(CustomHttpClientError::Blocked {
|
|
domain: host_str,
|
|
});
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Reasons why the custom HTTP client refuses to contact a host.
#[derive(Debug, Clone)]
pub enum CustomHttpClientError {
    /// The host matched the configured block regex.
    Blocked {
        domain: String,
    },
    /// The address is not a global IP. `domain` is set when the address came
    /// from resolving a name rather than from an IP literal.
    NonGlobalIp {
        domain: Option<String>,
        ip: IpAddr,
    },
    /// The host could not be parsed or failed validation.
    Invalid {
        domain: String,
    },
}

impl CustomHttpClientError {
    /// Searches the `source()` chain of `e` for a `CustomHttpClientError`.
    ///
    /// Only the sources are inspected, never `e` itself. The first match
    /// found while walking the chain is returned, or `None` if there is none.
    pub fn downcast_ref(e: &dyn std::error::Error) -> Option<&Self> {
        std::iter::successors(e.source(), |&err| err.source()).find_map(|err| err.downcast_ref::<Self>())
    }
}

impl fmt::Display for CustomHttpClientError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Blocked {
                domain,
            } => write!(f, "Blocked domain: '{domain}' matched HTTP_REQUEST_BLOCK_REGEX"),
            Self::NonGlobalIp {
                domain,
                ip,
            } => match domain {
                Some(domain) => write!(f, "IP {ip} for domain '{domain}' is not a global IP!"),
                None => write!(f, "IP '{ip}' is not a global IP!"),
            },
            Self::Invalid {
                domain,
            } => write!(f, "Invalid host: '{domain}' contains invalid characters or exceeds the maximum length"),
        }
    }
}

impl std::error::Error for CustomHttpClientError {}
|
|
|
|
#[derive(Debug, Clone)]
|
|
enum CustomDnsResolver {
|
|
Default(),
|
|
Hickory(Arc<TokioResolver>),
|
|
}
|
|
type BoxError = Box<dyn std::error::Error + Send + Sync>;
|
|
|
|
impl CustomDnsResolver {
|
|
fn instance() -> Arc<Self> {
|
|
static INSTANCE: LazyLock<Arc<CustomDnsResolver>> = LazyLock::new(CustomDnsResolver::new);
|
|
Arc::clone(&*INSTANCE)
|
|
}
|
|
|
|
fn new() -> Arc<Self> {
|
|
TokioResolver::builder(TokioRuntimeProvider::default())
|
|
.and_then(|mut builder| {
|
|
// Hickory's default since v0.26 is `Ipv6AndIpv4`, which sorts IPv6 first
|
|
// This might cause issues on IPv4 only systems or containers
|
|
// Unless someone enabled DNS_PREFER_IPV6, use Ipv4AndIpv6, which returns IPv4 first which was our previous default
|
|
if !CONFIG.dns_prefer_ipv6() {
|
|
builder.options_mut().ip_strategy = hickory_resolver::config::LookupIpStrategy::Ipv4AndIpv6;
|
|
}
|
|
builder.build()
|
|
})
|
|
.inspect_err(|e| warn!("Error creating Hickory resolver, falling back to default: {e:?}"))
|
|
.map(|resolver| Arc::new(Self::Hickory(Arc::new(resolver))))
|
|
.unwrap_or_else(|_| Arc::new(Self::Default()))
|
|
}
|
|
|
|
// Note that we get an iterator of addresses, but we only grab the first one for convenience
|
|
async fn resolve_domain(&self, name: &str) -> Result<Vec<SocketAddr>, BoxError> {
|
|
pre_resolve(name)?;
|
|
|
|
let results: Vec<SocketAddr> = match self {
|
|
Self::Default() => tokio::net::lookup_host((name, 0)).await?.collect(),
|
|
Self::Hickory(r) => r.lookup_ip(name).await?.iter().map(|i| SocketAddr::new(i, 0)).collect(),
|
|
};
|
|
|
|
for addr in &results {
|
|
post_resolve(name, addr.ip())?;
|
|
}
|
|
|
|
Ok(results)
|
|
}
|
|
}
|
|
|
|
fn pre_resolve(name: &str) -> Result<(), CustomHttpClientError> {
|
|
let Ok(host) = get_valid_host(name) else {
|
|
return Err(CustomHttpClientError::Invalid {
|
|
domain: name.to_string(),
|
|
});
|
|
};
|
|
|
|
if should_block_host(&host).is_err() {
|
|
return Err(CustomHttpClientError::Blocked {
|
|
domain: name.to_string(),
|
|
});
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn post_resolve(name: &str, ip: IpAddr) -> Result<(), CustomHttpClientError> {
|
|
if should_block_ip(ip) {
|
|
Err(CustomHttpClientError::NonGlobalIp {
|
|
domain: Some(name.to_string()),
|
|
ip,
|
|
})
|
|
} else {
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
impl Resolve for CustomDnsResolver {
|
|
fn resolve(&self, name: Name) -> Resolving {
|
|
let this = self.clone();
|
|
Box::pin(async move {
|
|
let name = name.as_str();
|
|
let results = this.resolve_domain(name).await?;
|
|
if results.is_empty() {
|
|
warn!("Unable to resolve {name} to any valid IP address");
|
|
}
|
|
Ok::<reqwest::dns::Addrs, _>(Box::new(results.into_iter()))
|
|
})
|
|
}
|
|
}
|
|
|
|
#[cfg(s3)]
pub(crate) mod aws {
    use aws_smithy_runtime_api::client::{
        http::{HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, SharedHttpConnector},
        orchestrator::HttpResponse,
        result::ConnectorError,
        runtime_components::RuntimeComponents,
    };
    use reqwest::Client;

    /// Adapter that wraps a `reqwest` client so the AWS SDK can use it.
    #[derive(Debug)]
    pub(crate) struct AwsReqwestConnector {
        pub(crate) client: Client,
    }

    impl HttpConnector for AwsReqwestConnector {
        /// Converts the AWS-style request into a `reqwest` request, sends it,
        /// and converts the response back.
        ///
        /// A request body is attached only when `request.body().bytes()`
        /// returns data. The whole response body is read before returning.
        fn call(&self, request: aws_smithy_runtime_api::client::orchestrator::HttpRequest) -> HttpConnectorFuture {
            let client = self.client.clone();

            let exchange = async move {
                let method = reqwest::Method::from_bytes(request.method().as_bytes())
                    .map_err(|e| ConnectorError::user(Box::new(e)))?;

                // Start from the bare request and copy every header across.
                let mut outgoing = request
                    .headers()
                    .into_iter()
                    .fold(client.request(method, request.uri().to_string()), |builder, (name, value)| {
                        builder.header(name, value)
                    });

                if let Some(payload) = request.body().bytes() {
                    outgoing = outgoing.body(payload.to_vec());
                }

                let response = outgoing.send().await.map_err(|e| ConnectorError::io(Box::new(e)))?;

                // Read the status first; reading the body consumes the response.
                let status = response.status().into();
                let payload = response.bytes().await.map_err(|e| ConnectorError::io(Box::new(e)))?;

                Ok(HttpResponse::new(status, payload.into()))
            };

            HttpConnectorFuture::new(Box::pin(exchange))
        }
    }

    impl HttpClient for AwsReqwestConnector {
        /// Returns a connector that shares this adapter's `reqwest` client.
        /// The settings and runtime components are ignored.
        fn http_connector(
            &self,
            _settings: &HttpConnectorSettings,
            _components: &RuntimeComponents,
        ) -> SharedHttpConnector {
            SharedHttpConnector::new(Self {
                client: self.client.clone(),
            })
        }
    }
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::is_global_hardcoded;
    use std::net::Ipv4Addr;
    use url::Host;

    // ===
    // Shared helpers

    /// Parses `s` as a host and returns its address when it is an IP literal.
    fn parse_to_ip(s: &str) -> Option<IpAddr> {
        Host::parse(s).ok().and_then(|host| match host {
            Host::Ipv4(v4) => Some(IpAddr::V4(v4)),
            Host::Ipv6(v6) => Some(IpAddr::V6(v6)),
            Host::Domain(_) => None,
        })
    }

    /// Asserts that `input` parses to 127.0.0.1 and is not considered global.
    fn assert_loopback_v4_not_global(input: &str) {
        let ip = parse_to_ip(input).unwrap();
        assert_eq!(ip, IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
        assert!(!is_global_hardcoded(ip));
    }

    /// Asserts that `get_valid_host` turns `input` into the IPv4 loopback host.
    fn assert_valid_host_is_loopback_v4(input: &str) {
        let h = get_valid_host(input).expect("valid");
        assert!(matches!(h, Host::Ipv4(ip) if ip == Ipv4Addr::new(127, 0, 0, 1)));
    }

    /// Asserts that `input` is a valid IPv6 host that is not considered global.
    fn assert_v6_host_not_global(input: &str) {
        let h = get_valid_host(input).expect("valid");
        let Host::Ipv6(ip) = h else {
            panic!("expected v6")
        };
        assert!(!is_global_hardcoded(IpAddr::V6(ip)));
    }

    /// Asserts that `input` validates to the domain `expected`.
    fn assert_valid_domain(input: &str, expected: &str) {
        match get_valid_host(input).expect("valid") {
            Host::Domain(d) => assert_eq!(d, expected),
            _ => panic!("expected domain"),
        }
    }

    // ===
    // IPv4 numeric-format normalization

    #[test]
    fn dotted_decimal_loopback_normalizes() {
        assert_loopback_v4_not_global("127.0.0.1");
    }

    #[test]
    fn single_decimal_loopback_normalizes() {
        // 127.0.0.1 == 2130706433
        assert_loopback_v4_not_global("2130706433");
    }

    #[test]
    fn hex_loopback_normalizes() {
        assert_loopback_v4_not_global("0x7f000001");
    }

    #[test]
    fn dotted_hex_loopback_normalizes() {
        assert_loopback_v4_not_global("0x7f.0.0.1");
    }

    #[test]
    fn octal_loopback_normalizes() {
        // 017700000001 == 127.0.0.1
        assert_loopback_v4_not_global("017700000001");
    }

    #[test]
    fn dotted_octal_loopback_normalizes() {
        assert_loopback_v4_not_global("0177.0.0.01");
    }

    #[test]
    fn aws_metadata_decimal_blocked() {
        // 169.254.169.254 == 2852039166 (link-local, AWS IMDS)
        let ip = parse_to_ip("2852039166").unwrap();
        assert_eq!(ip, IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254)));
        assert!(!is_global_hardcoded(ip));
    }

    #[test]
    fn rfc1918_hex_blocked() {
        // 10.0.0.1
        let ip = parse_to_ip("0x0a000001").unwrap();
        assert!(!is_global_hardcoded(ip));
    }

    #[test]
    fn public_ip_decimal_allowed() {
        // 8.8.8.8 == 134744072
        let ip = parse_to_ip("134744072").unwrap();
        assert_eq!(ip, IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)));
        assert!(is_global_hardcoded(ip));
    }

    // ===
    // get_valid_host integration: numeric forms become Host::Ipv4

    #[test]
    fn get_valid_host_normalizes_decimal_int() {
        assert_valid_host_is_loopback_v4("2130706433");
    }

    #[test]
    fn get_valid_host_normalizes_hex() {
        assert_valid_host_is_loopback_v4("0x7f000001");
    }

    #[test]
    fn get_valid_host_normalizes_octal() {
        assert_valid_host_is_loopback_v4("017700000001");
    }

    // ===
    // IPv6 formats

    #[test]
    fn ipv6_loopback_blocked() {
        assert_v6_host_not_global("[::1]");
    }

    #[test]
    fn ipv4_mapped_in_ipv6_loopback_blocked() {
        // ::ffff:127.0.0.1 — v4-mapped form; is_global_hardcoded blocks via ::ffff:0:0/96
        assert_v6_host_not_global("[::ffff:127.0.0.1]");
    }

    #[test]
    fn ipv6_unique_local_blocked() {
        assert_v6_host_not_global("[fc00::1]");
    }

    // ===
    // Punycode / IDN

    #[test]
    fn punycode_passthrough() {
        assert_valid_domain("xn--deadbeafcaf-lbb.test", "xn--deadbeafcaf-lbb.test");
    }

    #[test]
    fn idn_unicode_gets_punycoded() {
        assert_valid_domain("deadbeafcafé.test", "xn--deadbeafcaf-lbb.test");
    }

    #[test]
    fn idn_unicode_gets_punycoded_tld() {
        assert_valid_domain("deadbeaf.café", "deadbeaf.xn--caf-dma");
    }

    #[test]
    fn idn_emoji_gets_punycoded() {
        // Punycoded form of 🛡️.test
        assert_valid_domain("xn--t88h.test", "xn--t88h.test");
    }

    #[test]
    fn idn_unicode_to_punycode_roundtrip() {
        let from_unicode = get_valid_host("🛡️.test").expect("valid");
        let from_puny = get_valid_host("xn--t88h.test").expect("valid");
        let (Host::Domain(a), Host::Domain(b)) = (from_unicode, from_puny) else {
            panic!("expected domains")
        };
        assert_eq!(a, b);
    }

    #[test]
    fn invalid_punycode_rejected() {
        // bare invalid punycode
        assert!(get_valid_host("xn--").is_err());
    }

    #[test]
    fn underscore_in_label_rejected() {
        assert!(get_valid_host("dead_beaf.cafe").is_err());
    }

    #[test]
    fn label_too_long_rejected() {
        let oversized_label = "a".repeat(64);
        assert!(get_valid_host(&format!("{oversized_label}.test")).is_err());
    }

    #[test]
    fn domain_too_long_rejected() {
        // 130 two-character labels plus "test" is longer than 253 characters
        let oversized_domain = format!("{}test", "a.".repeat(130));
        assert!(get_valid_host(&oversized_domain).is_err());
    }
}
|
|
|