diff --git a/Cargo.toml b/Cargo.toml index cf9e3fac..80bf7d8d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -137,6 +137,8 @@ backtrace = "0.3.60" # Macro ident concatenation paste = "1.0.5" +openidconnect = "2.0.1" +urlencoding = "1.1.1" [patch.crates-io] # Use newest ring diff --git a/src/api/identity.rs b/src/api/identity.rs index 7a13cdfd..ba34a0ba 100644 --- a/src/api/identity.rs +++ b/src/api/identity.rs @@ -3,10 +3,12 @@ use num_traits::FromPrimitive; use rocket::{ http::{RawStr, Status}, request::{Form, FormItems, FromForm}, + response::Redirect, Route, }; use rocket_contrib::json::Json; use serde_json::Value; +use std::iter::FromIterator; use crate::{ api::{ @@ -44,6 +46,13 @@ fn login(data: Form, conn: DbConn, ip: ClientIp) -> JsonResult { _password_login(data, conn, &ip) } + "authorization_code" => { + _check_is_some(&data.code, "code cannot be blank")?; + _check_is_some(&data.org_identifier, "org_identifier cannot be blank")?; + _check_is_some(&data.device_identifier, "device identifier cannot be blank")?; + + _authorization_login(data, conn) + } t => err!("Invalid type", t), } } @@ -78,6 +87,32 @@ fn _refresh_login(data: ConnectData, conn: DbConn) -> JsonResult { }))) } +fn _authorization_login(data: ConnectData, conn: DbConn) -> JsonResult { + let (access_token, refresh_token) = get_auth_code_access_token(data.code.unwrap(), data.org_identifier.unwrap(), &conn); + // let expiry = jsonwebtoken::decode_header(access_token.as_str()).unwrap(); + let time_now = std::time::SystemTime::now().duration_since(std::time::SystemTime::UNIX_EPOCH).unwrap().as_secs(); + + let mut device = Device::find_by_uuid(&data.device_identifier.unwrap(), &conn).map_res("device not found")?; + + // COMMON + let user = User::find_by_uuid(&device.user_uuid, &conn).unwrap(); + + Ok(Json(json!({ + "access_token": access_token, + "expires_in": 1000000, + "token_type": "Bearer", + "refresh_token": device.refresh_token, + "Key": user.akey, + "PrivateKey": user.private_key, + + "Kdf": user.client_kdf_type, + "KdfIterations": user.client_kdf_iter, + "ResetMasterPassword": false, // TODO: according to official server seems something like: user.password_hash.is_empty(), but would need testing + "scope": "api offline_access", + "unofficialServer": true, + }))) +} + fn _password_login(data: ConnectData, conn: DbConn, ip: &ClientIp) -> JsonResult { // Validate scope let scope = data.scope.as_ref().unwrap(); @@ -393,6 +428,10 @@ struct ConnectData { two_factor_provider: Option, two_factor_token: Option, two_factor_remember: Option, + + // Needed for authorization code + code: Option, + org_identifier: Option, } impl<'f> FromForm<'f> for ConnectData { @@ -419,6 +458,8 @@ impl<'f> FromForm<'f> for ConnectData { "twofactorprovider" => form.two_factor_provider = value.parse().ok(), "twofactortoken" => form.two_factor_token = Some(value), "twofactorremember" => form.two_factor_remember = value.parse().ok(), + "code" => form.code = Some(value), + "orgidentifier" => form.org_identifier = Some(value), key => warn!("Detected unexpected parameter during login: {}", key), } } @@ -465,25 +506,119 @@ fn prevalidate(domainHint: &RawStr, conn: DbConn) -> JsonResult { Ok(Json(empty_result)) } +use openidconnect::core::{ + CoreProviderMetadata, CoreClient, + CoreResponseType, +}; +use openidconnect::reqwest::http_client; +use openidconnect::{ + AuthenticationFlow, AuthorizationCode, ClientId, ClientSecret, + CsrfToken, IssuerUrl, Nonce, RedirectUrl, + Scope, OAuth2TokenResponse, +}; + +fn handle_error(fail: &T, msg: &'static str) { + let mut err_msg = format!("ERROR: {}", msg); + let mut cur_fail: Option<&dyn std::error::Error> = Some(fail); + while let Some(cause) = cur_fail { + err_msg += &format!("\n caused by: {}", cause); + cur_fail = cause.source(); + } + panic!("{}", err_msg); +} + +fn get_client_from_identifier (identifier: &str, conn: &DbConn) -> CoreClient { + let organization = Organization::find_by_identifier(identifier, conn); -#[get("/connect/authorize?")] -fn authorize( - domain_hint: &RawStr, - conn: DbConn, -) { - let empty_result = json!({}); - let organization = Organization::find_by_identifier(domain_hint.as_str(), &conn); match organization { Some(organization) => { println!("found org. authority: {}", organization.authority); - let redirect = Some(organization.callback_path.to_string()); + let redirect = organization.callback_path.to_string(); let issuer = reqwest::Url::parse(&organization.authority).unwrap(); println!("got issuer: {}", issuer); - // return Ok(Json(empty_result)); + let client_id = ClientId::new(organization.client_id); + let client_secret = ClientSecret::new(organization.client_secret); + let issuer_url = IssuerUrl::new(organization.authority).expect("invalid issuer URL"); + let provider_metadata = CoreProviderMetadata::discover(&issuer_url, http_client) + .unwrap_or_else(|err| { + handle_error(&err, "Failed to discover OpenID Provider"); + unreachable!(); + }); + let client = CoreClient::from_provider_metadata( + provider_metadata, + client_id, + Some(client_secret), + ) + .set_redirect_uri(RedirectUrl::new(redirect).expect("Invalid redirect URL")); + return client; }, None => { - println!("error"); - // return invalid_json("No Organization found", false); - } + panic!("unable to find org"); + }, } } + +#[get("/connect/authorize?&")] +fn authorize( + domain_hint: &RawStr, + state: &RawStr, + conn: DbConn, +) -> Redirect { + let empty_result = json!({}); + let client = get_client_from_identifier(domain_hint.as_str(), &conn); + + let (mut authorize_url, csrf_state, _nonce) = client + .authorize_url( + AuthenticationFlow::::AuthorizationCode, + CsrfToken::new_random, + Nonce::new_random, + ) + .add_scope(Scope::new("email".to_string())) + .add_scope(Scope::new("profile".to_string())) + .url(); + + // it seems impossible to set the state going in dynamically (requires static lifetime string) + // so I change it after the fact (will it work? Let's find out) + let old_pairs = authorize_url.query_pairs().clone(); + let new_pairs = old_pairs.map(|pair| { + let (key, value) = pair; + if key == "state" { + return format!("{}={}", key, state); + } + return format!("{}={}", key, value); + }); + let full_query = Vec::from_iter(new_pairs).join("&"); + authorize_url.set_query(Some(full_query.as_str())); + + // return Redirect::to(rocket::uri!(&authorize_url.to_string())); + return Redirect::to(authorize_url.to_string()); + // return Ok(Json(empty_result)); +} + +fn get_auth_code_access_token ( + code: String, + org_identifier: String, + conn: &DbConn, +) -> (String, String) { + let oidc_code = AuthorizationCode::new(code); + + println!("code: {}", oidc_code.secret()); + println!("identifier: {}", org_identifier); + + let client = get_client_from_identifier(&org_identifier, conn); + + let token_response = client + .exchange_code(oidc_code) + .request(http_client) + .unwrap_or_else(|err| { + handle_error(&err, "Failed to contact token endpoint"); + unreachable!(); + }); + + + let access_token = token_response.access_token().secret().to_string(); + let refresh_token = token_response.refresh_token().unwrap().secret().to_string(); + println!("access token: {}, refresh token: {}", access_token, refresh_token); + + (access_token, refresh_token) +} diff --git a/src/db/models/organization.rs b/src/db/models/organization.rs index 538fdfbf..c9b4bfa8 100644 --- a/src/db/models/organization.rs +++ b/src/db/models/organization.rs @@ -142,8 +142,8 @@ impl Organization { public_key, identifier: String::from(""), use_sso: false, - callback_path: String::from("http://localhost/oidc-signin"), - signed_out_callback_path: String::from("http://localhost/sso/oidc-signin"), + callback_path: String::from("http://localhost/#/sso/"), + signed_out_callback_path: String::from("http://localhost/#/sso/"), authority: String::from(""), client_id: String::from(""), client_secret: String::from(""),