diff --git a/src/auth.rs b/src/auth.rs index b904299b..57534769 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -9,7 +9,7 @@ use openssl::rsa::Rsa; use serde::de::DeserializeOwned; use serde::ser::Serialize; -use crate::config::extract_url_origin; +use crate::config::{extract_url_origin, extract_url_host}; use crate::{error::Error, CONFIG}; const JWT_ALGORITHM: Algorithm = Algorithm::RS256; @@ -371,6 +371,17 @@ pub struct HostInfo { pub origin: String, } +fn get_host_info(host: &str) -> Option { + CONFIG + .host_to_domain(host) + .and_then(|base_url| Some((base_url, CONFIG.domain_origin(host)?))) + .and_then(|(base_url, origin)| Some(HostInfo { base_url, origin })) +} + +fn get_main_host() -> String { + extract_url_host(&CONFIG.main_domain()) +} + #[rocket::async_trait] impl<'r> FromRequest<'r> for HostInfo { type Error = &'static str; @@ -381,28 +392,20 @@ impl<'r> FromRequest<'r> for HostInfo { // Get host // TODO: UPDATE THIS SECTION if CONFIG.domain_set() { - let host = if let Some(host) = headers.get_one("X-Forwarded-Host") { - host + let host: Cow<'_, str> = if let Some(host) = headers.get_one("X-Forwarded-Host") { + host.into() } else if let Some(host) = headers.get_one("Host") { - host + host.into() } else { - // TODO fix error handling - // This is probably a 400 bad request, - // because http requests require the host header - todo!() + get_main_host().into() }; + + let host_info = get_host_info(host.as_ref()) + .unwrap_or_else(|| { + get_host_info(&get_main_host()).expect("Main domain doesn't have entry!") + }); - // TODO fix error handling - // This is probably a 421 misdirected request - let (base_url, origin) = CONFIG - .host_to_domain(host) - .and_then(|base_url| Some((base_url, CONFIG.domain_origin(host)?))) - .expect("This should not be merged like this!!!"); - - return Outcome::Success(HostInfo { - base_url, - origin, - }); + return Outcome::Success(host_info); } else if let Some(referer) = headers.get_one("Referer") { return Outcome::Success(HostInfo { base_url: referer.to_string(), @@ -852,6 +855,7 @@ impl<'r> FromRequest<'r> for OwnerHeaders { } } +use std::borrow::Cow; // // Client IP address detection //