From f82a142ceed15f425e92c877abc11b4f5b3a0294 Mon Sep 17 00:00:00 2001 From: BlockListed <44610569+BlockListed@users.noreply.github.com> Date: Sat, 9 Sep 2023 13:50:56 +0200 Subject: [PATCH] get domain and origin with single extractor --- src/auth.rs | 38 ++++++++++++++++++++++---------------- src/config.rs | 2 +- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/src/auth.rs b/src/auth.rs index d4762d57..214bf855 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -10,6 +10,7 @@ use serde::de::DeserializeOwned; use serde::ser::Serialize; use crate::{error::Error, CONFIG}; +use crate::config::extract_url_origin; const JWT_ALGORITHM: Algorithm = Algorithm::RS256; @@ -360,12 +361,13 @@ use crate::db::{ DbConn, }; -pub struct BaseURL { +pub struct HostInfo { pub base_url: String, + pub origin: String, } #[rocket::async_trait] -impl<'r> FromRequest<'r> for BaseURL { +impl<'r> FromRequest<'r> for HostInfo { type Error = &'static str; async fn from_request(request: &'r Request<'_>) -> Outcome { @@ -373,7 +375,7 @@ impl<'r> FromRequest<'r> for BaseURL { // Get host // TODO: UPDATE THIS SECTION - let base_url = if CONFIG.domain_set() { + if CONFIG.domain_set() { let host = if let Some(host) = headers.get_one("X-Forwarded-Host") { host } else if let Some(host) = headers.get_one("Host") { @@ -385,15 +387,18 @@ impl<'r> FromRequest<'r> for BaseURL { todo!() }; - let Some(base_url) = CONFIG.host_to_base_url(host) else { - // TODO fix error handling - // This is probably a 421 misdirected request. - todo!() - }; - - base_url + // TODO fix error handling + // This is probably a 421 misdirected request + let (base_url, origin) = CONFIG.host_to_domain(host).and_then(|base_url| { + Some((base_url, CONFIG.domain_origin(host)?)) + }).expect("This should not be merged like this!!!"); + + return Outcome::Success(HostInfo { base_url, origin }); } else if let Some(referer) = headers.get_one("Referer") { - referer.to_string() + return Outcome::Success(HostInfo { + base_url: referer.to_string(), + origin: extract_url_origin(referer), + }); } else { // Try to guess from the headers use std::env; @@ -414,12 +419,13 @@ impl<'r> FromRequest<'r> for BaseURL { "" }; - format!("{protocol}://{host}") - }; + let base_url_origin = format!("{protocol}://{host}"); - Outcome::Success(BaseURL { - base_url, - }) + return Outcome::Success(HostInfo { + base_url: base_url_origin, + origin: base_url_origin, + }); + } } } diff --git a/src/config.rs b/src/config.rs index 70878f06..1f62c22f 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1003,7 +1003,7 @@ fn validate_config(cfg: &ConfigItems) -> Result<(), Error> { } /// Extracts an RFC 6454 web origin from a URL. -fn extract_url_origin(url: &str) -> String { +pub fn extract_url_origin(url: &str) -> String { match Url::parse(url) { Ok(u) => u.origin().ascii_serialization(), Err(e) => {