|
|
@ -3,10 +3,12 @@ use num_traits::FromPrimitive; |
|
|
|
use rocket::{ |
|
|
|
http::{RawStr, Status}, |
|
|
|
request::{Form, FormItems, FromForm}, |
|
|
|
response::Redirect, |
|
|
|
Route, |
|
|
|
}; |
|
|
|
use rocket_contrib::json::Json; |
|
|
|
use serde_json::Value; |
|
|
|
use std::iter::FromIterator; |
|
|
|
|
|
|
|
use crate::{ |
|
|
|
api::{ |
|
|
@ -44,6 +46,13 @@ fn login(data: Form<ConnectData>, conn: DbConn, ip: ClientIp) -> JsonResult { |
|
|
|
|
|
|
|
_password_login(data, conn, &ip) |
|
|
|
} |
|
|
|
"authorization_code" => { |
|
|
|
_check_is_some(&data.code, "code cannot be blank")?; |
|
|
|
_check_is_some(&data.org_identifier, "org_identifier cannot be blank")?; |
|
|
|
_check_is_some(&data.device_identifier, "device identifier cannot be blank")?; |
|
|
|
|
|
|
|
_authorization_login(data, conn) |
|
|
|
} |
|
|
|
t => err!("Invalid type", t), |
|
|
|
} |
|
|
|
} |
|
|
@ -78,6 +87,32 @@ fn _refresh_login(data: ConnectData, conn: DbConn) -> JsonResult { |
|
|
|
}))) |
|
|
|
} |
|
|
|
|
|
|
|
fn _authorization_login(data: ConnectData, conn: DbConn) -> JsonResult { |
|
|
|
let (access_token, refresh_token) = get_auth_code_access_token(data.code.unwrap(), data.org_identifier.unwrap(), &conn); |
|
|
|
// let expiry = jsonwebtoken::decode_header(access_token.as_str()).unwrap();
|
|
|
|
let time_now = std::time::SystemTime::now().duration_since(std::time::SystemTime::UNIX_EPOCH).unwrap().as_secs(); |
|
|
|
|
|
|
|
let mut device = Device::find_by_uuid(&data.device_identifier.unwrap(), &conn).map_res("device not found")?; |
|
|
|
|
|
|
|
// COMMON
|
|
|
|
let user = User::find_by_uuid(&device.user_uuid, &conn).unwrap(); |
|
|
|
|
|
|
|
Ok(Json(json!({ |
|
|
|
"access_token": access_token, |
|
|
|
"expires_in": 1000000, |
|
|
|
"token_type": "Bearer", |
|
|
|
"refresh_token": device.refresh_token, |
|
|
|
"Key": user.akey, |
|
|
|
"PrivateKey": user.private_key, |
|
|
|
|
|
|
|
"Kdf": user.client_kdf_type, |
|
|
|
"KdfIterations": user.client_kdf_iter, |
|
|
|
"ResetMasterPassword": false, // TODO: according to official server seems something like: user.password_hash.is_empty(), but would need testing
|
|
|
|
"scope": "api offline_access", |
|
|
|
"unofficialServer": true, |
|
|
|
}))) |
|
|
|
} |
|
|
|
|
|
|
|
fn _password_login(data: ConnectData, conn: DbConn, ip: &ClientIp) -> JsonResult { |
|
|
|
// Validate scope
|
|
|
|
let scope = data.scope.as_ref().unwrap(); |
|
|
@ -393,6 +428,10 @@ struct ConnectData { |
|
|
|
two_factor_provider: Option<i32>, |
|
|
|
two_factor_token: Option<String>, |
|
|
|
two_factor_remember: Option<i32>, |
|
|
|
|
|
|
|
// Needed for authorization code
|
|
|
|
code: Option<String>, |
|
|
|
org_identifier: Option<String>, |
|
|
|
} |
|
|
|
|
|
|
|
impl<'f> FromForm<'f> for ConnectData { |
|
|
@ -419,6 +458,8 @@ impl<'f> FromForm<'f> for ConnectData { |
|
|
|
"twofactorprovider" => form.two_factor_provider = value.parse().ok(), |
|
|
|
"twofactortoken" => form.two_factor_token = Some(value), |
|
|
|
"twofactorremember" => form.two_factor_remember = value.parse().ok(), |
|
|
|
"code" => form.code = Some(value), |
|
|
|
"orgidentifier" => form.org_identifier = Some(value), |
|
|
|
key => warn!("Detected unexpected parameter during login: {}", key), |
|
|
|
} |
|
|
|
} |
|
|
@ -465,25 +506,119 @@ fn prevalidate(domainHint: &RawStr, conn: DbConn) -> JsonResult { |
|
|
|
Ok(Json(empty_result)) |
|
|
|
} |
|
|
|
|
|
|
|
use openidconnect::core::{ |
|
|
|
CoreProviderMetadata, CoreClient, |
|
|
|
CoreResponseType, |
|
|
|
}; |
|
|
|
use openidconnect::reqwest::http_client; |
|
|
|
use openidconnect::{ |
|
|
|
AuthenticationFlow, AuthorizationCode, ClientId, ClientSecret, |
|
|
|
CsrfToken, IssuerUrl, Nonce, RedirectUrl, |
|
|
|
Scope, OAuth2TokenResponse, |
|
|
|
}; |
|
|
|
|
|
|
|
fn handle_error<T: std::error::Error>(fail: &T, msg: &'static str) { |
|
|
|
let mut err_msg = format!("ERROR: {}", msg); |
|
|
|
let mut cur_fail: Option<&dyn std::error::Error> = Some(fail); |
|
|
|
while let Some(cause) = cur_fail { |
|
|
|
err_msg += &format!("\n caused by: {}", cause); |
|
|
|
cur_fail = cause.source(); |
|
|
|
} |
|
|
|
panic!("{}", err_msg); |
|
|
|
} |
|
|
|
|
|
|
|
fn get_client_from_identifier (identifier: &str, conn: &DbConn) -> CoreClient { |
|
|
|
let organization = Organization::find_by_identifier(identifier, conn); |
|
|
|
|
|
|
|
#[get("/connect/authorize?<domain_hint>")] |
|
|
|
fn authorize( |
|
|
|
domain_hint: &RawStr, |
|
|
|
conn: DbConn, |
|
|
|
) { |
|
|
|
let empty_result = json!({}); |
|
|
|
let organization = Organization::find_by_identifier(domain_hint.as_str(), &conn); |
|
|
|
match organization { |
|
|
|
Some(organization) => { |
|
|
|
println!("found org. authority: {}", organization.authority); |
|
|
|
let redirect = Some(organization.callback_path.to_string()); |
|
|
|
let redirect = organization.callback_path.to_string(); |
|
|
|
let issuer = reqwest::Url::parse(&organization.authority).unwrap(); |
|
|
|
println!("got issuer: {}", issuer); |
|
|
|
// return Ok(Json(empty_result));
|
|
|
|
let client_id = ClientId::new(organization.client_id); |
|
|
|
let client_secret = ClientSecret::new(organization.client_secret); |
|
|
|
let issuer_url = IssuerUrl::new(organization.authority).expect("invalid issuer URL"); |
|
|
|
let provider_metadata = CoreProviderMetadata::discover(&issuer_url, http_client) |
|
|
|
.unwrap_or_else(|err| { |
|
|
|
handle_error(&err, "Failed to discover OpenID Provider"); |
|
|
|
unreachable!(); |
|
|
|
}); |
|
|
|
let client = CoreClient::from_provider_metadata( |
|
|
|
provider_metadata, |
|
|
|
client_id, |
|
|
|
Some(client_secret), |
|
|
|
) |
|
|
|
.set_redirect_uri(RedirectUrl::new(redirect).expect("Invalid redirect URL")); |
|
|
|
return client; |
|
|
|
}, |
|
|
|
None => { |
|
|
|
println!("error"); |
|
|
|
// return invalid_json("No Organization found", false);
|
|
|
|
} |
|
|
|
panic!("unable to find org"); |
|
|
|
}, |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
#[get("/connect/authorize?<domain_hint>&<state>")] |
|
|
|
fn authorize( |
|
|
|
domain_hint: &RawStr, |
|
|
|
state: &RawStr, |
|
|
|
conn: DbConn, |
|
|
|
) -> Redirect { |
|
|
|
let empty_result = json!({}); |
|
|
|
let client = get_client_from_identifier(domain_hint.as_str(), &conn); |
|
|
|
|
|
|
|
let (mut authorize_url, csrf_state, _nonce) = client |
|
|
|
.authorize_url( |
|
|
|
AuthenticationFlow::<CoreResponseType>::AuthorizationCode, |
|
|
|
CsrfToken::new_random, |
|
|
|
Nonce::new_random, |
|
|
|
) |
|
|
|
.add_scope(Scope::new("email".to_string())) |
|
|
|
.add_scope(Scope::new("profile".to_string())) |
|
|
|
.url(); |
|
|
|
|
|
|
|
// it seems impossible to set the state going in dynamically (requires static lifetime string)
|
|
|
|
// so I change it after the fact (will it work? Let's find out)
|
|
|
|
let old_pairs = authorize_url.query_pairs().clone(); |
|
|
|
let new_pairs = old_pairs.map(|pair| { |
|
|
|
let (key, value) = pair; |
|
|
|
if key == "state" { |
|
|
|
return format!("{}={}", key, state); |
|
|
|
} |
|
|
|
return format!("{}={}", key, value); |
|
|
|
}); |
|
|
|
let full_query = Vec::from_iter(new_pairs).join("&"); |
|
|
|
authorize_url.set_query(Some(full_query.as_str())); |
|
|
|
|
|
|
|
// return Redirect::to(rocket::uri!(&authorize_url.to_string()));
|
|
|
|
return Redirect::to(authorize_url.to_string()); |
|
|
|
// return Ok(Json(empty_result));
|
|
|
|
} |
|
|
|
|
|
|
|
fn get_auth_code_access_token ( |
|
|
|
code: String, |
|
|
|
org_identifier: String, |
|
|
|
conn: &DbConn, |
|
|
|
) -> (String, String) { |
|
|
|
let oidc_code = AuthorizationCode::new(code); |
|
|
|
|
|
|
|
println!("code: {}", oidc_code.secret()); |
|
|
|
println!("identifier: {}", org_identifier); |
|
|
|
|
|
|
|
let client = get_client_from_identifier(&org_identifier, conn); |
|
|
|
|
|
|
|
let token_response = client |
|
|
|
.exchange_code(oidc_code) |
|
|
|
.request(http_client) |
|
|
|
.unwrap_or_else(|err| { |
|
|
|
handle_error(&err, "Failed to contact token endpoint"); |
|
|
|
unreachable!(); |
|
|
|
}); |
|
|
|
|
|
|
|
|
|
|
|
let access_token = token_response.access_token().secret().to_string(); |
|
|
|
let refresh_token = token_response.refresh_token().unwrap().secret().to_string(); |
|
|
|
println!("access token: {}, refresh token: {}", access_token, refresh_token); |
|
|
|
|
|
|
|
(access_token, refresh_token) |
|
|
|
} |
|
|
|