From 83317d90b7201779ca76fccbc93a45903a6d5ca9 Mon Sep 17 00:00:00 2001 From: Mathijs van Veluw Date: Sat, 2 May 2026 18:56:15 +0200 Subject: [PATCH] Update to Rust 2024 Edition Updated to the Rust 2024 Edition and added and fixed several lint checks. This is a large change which, because of the extra lints, added some possible fixes for issues. Signed-off-by: BlackDex --- Cargo.toml | 169 ++++++++---- build.rs | 10 +- macros/src/lib.rs | 9 +- src/api/admin.rs | 133 +++++---- src/api/core/accounts.rs | 140 +++++----- src/api/core/ciphers.rs | 216 +++++++-------- src/api/core/emergency_access.rs | 52 ++-- src/api/core/events.rs | 114 ++++---- src/api/core/folders.rs | 9 +- src/api/core/mod.rs | 54 ++-- src/api/core/organizations.rs | 260 +++++++++--------- src/api/core/public.rs | 119 ++++---- src/api/core/sends.rs | 66 ++--- src/api/core/two_factor/authenticator.rs | 21 +- src/api/core/two_factor/duo.rs | 32 +-- src/api/core/two_factor/duo_oidc.rs | 39 ++- src/api/core/two_factor/email.rs | 39 +-- src/api/core/two_factor/mod.rs | 34 +-- src/api/core/two_factor/protected_actions.rs | 26 +- src/api/core/two_factor/webauthn.rs | 51 ++-- src/api/core/two_factor/yubikey.rs | 19 +- src/api/icons.rs | 75 +++-- src/api/identity.rs | 272 ++++++++++--------- src/api/mod.rs | 5 +- src/api/notifications.rs | 27 +- src/api/push.rs | 48 ++-- src/api/web.rs | 14 +- src/auth.rs | 118 ++++---- src/config.rs | 261 +++++++++--------- src/db/mod.rs | 53 ++-- src/db/models/archive.rs | 2 +- src/db/models/attachment.rs | 2 +- src/db/models/cipher.rs | 146 +++++----- src/db/models/collection.rs | 12 +- src/db/models/device.rs | 1 + src/db/models/emergency_access.rs | 13 +- src/db/models/event.rs | 2 +- src/db/models/folder.rs | 2 +- src/db/models/group.rs | 8 +- src/db/models/mod.rs | 4 +- src/db/models/org_policy.rs | 64 ++--- src/db/models/organization.rs | 16 +- src/db/models/send.rs | 51 ++-- src/db/models/two_factor.rs | 11 +- src/db/models/two_factor_duo_context.rs | 2 +- src/db/models/two_factor_incomplete.rs | 4 +- src/db/models/user.rs | 25 +- src/db/query_logger.rs | 8 +- src/error.rs | 89 +++--- src/http_client.rs | 53 ++-- src/mail.rs | 49 ++-- src/main.rs | 51 ++-- src/ratelimit.rs | 8 +- src/sso.rs | 62 +++-- src/sso_client.rs | 50 ++-- src/storage.rs | 30 +- src/util.rs | 83 +++--- 57 files changed, 1710 insertions(+), 1623 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 61910302..599af5c8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace.package] -edition = "2021" +edition = "2024" rust-version = "1.93.0" license = "AGPL-3.0-only" repository = "https://github.com/dani-garcia/vaultwarden" @@ -23,7 +23,8 @@ publish.workspace = true [features] default = [ - # "sqlite" or "sqlite_system", + # "sqlite", + # "sqlite_system", # "mysql", # "postgresql", ] @@ -32,14 +33,22 @@ enable_syslog = [] # Please enable at least one of these DB backends. mysql = ["diesel/mysql", "diesel_migrations/mysql"] postgresql = ["diesel/postgres", "diesel_migrations/postgres"] -sqlite_system = ["diesel/sqlite", "diesel_migrations/sqlite"] -sqlite = ["sqlite_system", "libsqlite3-sys/bundled"] # Alternative to the above, statically linked SQLite into the binary instead of dynamically. +sqlite_system = ["diesel/sqlite", "diesel_migrations/sqlite"] # Dynamically link SQLite +sqlite = ["sqlite_system", "libsqlite3-sys/bundled"] # Statically link SQLite into the binary instead of dynamically. # Enable to use a vendored and statically linked openssl vendored_openssl = ["openssl/vendored"] # Enable MiMalloc memory allocator to replace the default malloc # This can improve performance for Alpine builds enable_mimalloc = ["dep:mimalloc"] -s3 = ["opendal/services-s3", "dep:aws-config", "dep:aws-credential-types", "dep:aws-smithy-runtime-api", "dep:http", "dep:reqsign-aws-v4", "dep:reqsign-core"] +s3 = [ + "opendal/services-s3", + "dep:aws-config", + "dep:aws-credential-types", + "dep:aws-smithy-runtime-api", + "dep:http", + "dep:reqsign-aws-v4", + "dep:reqsign-core", +] # OIDC specific features oidc-accept-rfc3339-timestamps = ["openidconnect/accept-rfc3339-timestamps"] @@ -59,7 +68,8 @@ macros = { path = "./macros" } # Logging log = "0.4.29" fern = { version = "0.7.1", features = ["syslog-7", "reopen-1"] } -tracing = { version = "0.1.44", features = ["log"] } # Needed to have lettre and webauthn-rs trace logging to work +# We need the `log` feature for `tracing` to enable logging for several crates to work, like lettre or webauthn-rs +tracing = { version = "0.1.44", features = ["log"] } # A `dotenv` implementation for Rust dotenvy = { version = "0.15.7", default-features = false } @@ -70,8 +80,8 @@ num-derive = "0.4.2" bigdecimal = "0.4.10" # Web framework -rocket = { version = "0.5.1", features = ["tls", "json"], default-features = false } -rocket_ws = { version ="0.1.1" } +rocket = { version = "0.5.1", default-features = false, features = ["json", "tls"] } +rocket_ws = { version = "0.1.1" } # WebSockets libraries rmpv = "1.3.1" # MessagePack library @@ -81,19 +91,32 @@ dashmap = "6.1.0" # Async futures futures = "0.3.32" -tokio = { version = "1.52.3", features = ["rt-multi-thread", "fs", "io-util", "parking_lot", "time", "signal", "net"] } -tokio-util = { version = "0.7.18", features = ["compat"]} +tokio = { version = "1.52.3", features = [ + "fs", + "io-util", + "net", + "parking_lot", + "rt-multi-thread", + "signal", + "time", +] } +tokio-util = { version = "0.7.18", features = ["compat"] } # A generic serialization/deserialization framework serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.149" # A safe, extensible ORM and Query builder -# Currently pinned diesel to v2.3.3 as newer version break MySQL/MariaDB compatibility diesel = { version = "2.3.9", features = ["chrono", "r2d2", "numeric"] } diesel_migrations = "2.3.2" -derive_more = { version = "2.1.1", features = ["from", "into", "as_ref", "deref", "display"] } +derive_more = { version = "2.1.1", features = [ + "as_ref", + "deref", + "display", + "from", + "into", +] } diesel-derive-newtype = "2.1.2" # SQLite, statically bundled unless the `sqlite_system` feature is enabled @@ -109,7 +132,7 @@ subtle = "2.6.1" uuid = { version = "1.23.1", features = ["v4"] } # Date and time libraries -chrono = { version = "0.4.44", features = ["clock", "serde"], default-features = false } +chrono = { version = "0.4.44", default-features = false, features = ["clock", "serde"] } chrono-tz = "0.10.4" time = "0.3.47" @@ -120,13 +143,13 @@ job_scheduler_ng = "2.4.0" data-encoding = "2.11.0" # JWT library -jsonwebtoken = { version = "10.4.0", features = ["use_pem", "rust_crypto"], default-features = false } +jsonwebtoken = { version = "10.4.0", default-features = false, features = ["rust_crypto", "use_pem"] } # TOTP library totp-lite = "2.0.1" # Yubico Library -yubico = { package = "yubico_ng", version = "0.15.0", features = ["online-tokio"], default-features = false } +yubico = { package = "yubico_ng", version = "0.15.0", default-features = false, features = ["online-tokio"] } # WebAuthn libraries # danger-allow-state-serialisation is needed to save the state in the db @@ -139,7 +162,20 @@ webauthn-rs-core = "0.5.5" url = "2.5.8" # Email libraries -lettre = { version = "0.11.22", features = ["smtp-transport", "sendmail-transport", "builder", "serde", "hostname", "tracing", "tokio1-rustls", "ring", "rustls-native-certs"], default-features = false } +lettre = { version = "0.11.22", default-features = false, features = [ + # Misc + "tracing", + "serde", + "builder", + "hostname", + # TLS/Security + "ring", + "rustls-native-certs", + "tokio1-rustls", + # Transport + "smtp-transport", + "sendmail-transport", +] } percent-encoding = "2.3.2" # URL encoding library used for URL's in the emails email_address = "0.2.9" @@ -147,12 +183,33 @@ email_address = "0.2.9" handlebars = { version = "6.4.0", features = ["dir_source"] } # HTTP client (Used for favicons, version check, DUO and HIBP API) -reqwest = { version = "0.13.3", features = ["rustls-no-provider", "stream", "json", "form", "deflate", "gzip", "brotli", "zstd", "socks", "cookies", "charset", "http2", "system-proxy"], default-features = false} +reqwest = { version = "0.13.3", default-features = false, features = [ + # Misc + "charset", + "cookies", + "http2", + "json", + "form", + "rustls-no-provider", + "stream", + # Compression + "brotli", + "deflate", + "gzip", + "zstd", + # Proxy + "socks", + "system-proxy", +] } hickory-resolver = "0.26.1" # Favicon extraction libraries html5gum = "0.8.3" -regex = { version = "1.12.3", features = ["std", "perf", "unicode-perl"], default-features = false } +regex = { version = "1.12.3", default-features = false, features = [ + "perf", + "std", + "unicode-perl", +] } data-url = "0.3.2" bytes = "1.11.1" svg-hush = "0.9.6" @@ -183,7 +240,7 @@ semver = "1.0.28" # Allow overriding the default memory allocator # Mainly used for the musl builds, since the default musl malloc is very slow -mimalloc = { version = "0.1.50", features = ["secure"], default-features = false, optional = true } +mimalloc = { version = "0.1.50", optional = true, default-features = false, features = ["secure"] } which = "8.0.2" @@ -197,10 +254,15 @@ rpassword = "7.5.2" grass_compiler = { version = "0.13.4", default-features = false } # File are accessed through Apache OpenDAL -opendal = { version = "0.56.0", features = ["services-fs"], default-features = false } +opendal = { version = "0.56.0", default-features = false, features = ["services-fs"] } # For retrieving AWS credentials, including temporary SSO credentials -aws-config = { version = "1.8.16", features = ["behavior-version-latest", "rt-tokio", "credentials-process", "sso"], default-features = false, optional = true } +aws-config = { version = "1.8.16", optional = true, default-features = false, features = [ + "behavior-version-latest", + "credentials-process", + "rt-tokio", + "sso", +] } aws-credential-types = { version = "1.2.14", optional = true } aws-smithy-runtime-api = { version = "1.12.0", optional = true } http = { version = "1.4.0", optional = true } @@ -265,77 +327,74 @@ unsafe_code = "forbid" non_ascii_idents = "forbid" # Deny -deprecated_in_future = "deny" +warnings = "deny" # Explicitly deny all warnings since we deny all warnings in the end + +# Deny lint groups deprecated_safe = { level = "deny", priority = -1 } future_incompatible = { level = "deny", priority = -1 } keyword_idents = { level = "deny", priority = -1 } let_underscore = { level = "deny", priority = -1 } nonstandard_style = { level = "deny", priority = -1 } -noop_method_call = "deny" refining_impl_trait = { level = "deny", priority = -1 } rust_2018_idioms = { level = "deny", priority = -1 } rust_2021_compatibility = { level = "deny", priority = -1 } rust_2024_compatibility = { level = "deny", priority = -1 } +unused = { level = "deny", priority = -1 } + +# Deny individual lints +closure_returning_async_block = "deny" +deprecated_in_future = "deny" single_use_lifetimes = "deny" trivial_casts = "deny" trivial_numeric_casts = "deny" -unused = { level = "deny", priority = -1 } unused_import_braces = "deny" unused_lifetimes = "deny" unused_qualifications = "deny" variant_size_differences = "deny" -# Allow the following lints since these cause issues with Rust v1.84.0 or newer -# Building Vaultwarden with Rust v1.85.0 with edition 2024 also works without issues -edition_2024_expr_fragment_specifier = "allow" # Once changed to Rust 2024 this should be removed and macro's should be validated again -if_let_rescope = "allow" -tail_expr_drop_order = "allow" # https://rust-lang.github.io/rust-clippy/stable/index.html [workspace.lints.clippy] -# Warn +# Warn only so you can still use these during development, but not in the final code dbg_macro = "warn" todo = "warn" # Ignore/Allow result_large_err = "allow" -# Deny +# Warn on these lint group (Some might be warn by default already though) +# Will be denied during CI! +complexity = { level = "warn", priority = -1 } +pedantic = { level = "warn", priority = -1 } +perf = { level = "warn", priority = -1 } +style = { level = "warn", priority = -1 } +suspicious = { level = "warn", priority = -1 } + +# Deny individual lints branches_sharing_code = "deny" -case_sensitive_file_extension_comparisons = "deny" -cast_lossless = "deny" clone_on_ref_ptr = "deny" -duration_suboptimal_units = "deny" equatable_if_let = "deny" -excessive_precision = "deny" -filter_map_next = "deny" float_cmp_const = "deny" -implicit_clone = "deny" -inefficient_to_string = "deny" iter_on_empty_collections = "deny" iter_on_single_items = "deny" -linkedlist = "deny" -macro_use_imports = "deny" -manual_assert = "deny" -manual_instant_elapsed = "deny" -manual_string_new = "deny" -match_wildcard_for_single_variants = "deny" mem_forget = "deny" -needless_borrow = "deny" needless_collect = "deny" -needless_continue = "deny" -needless_lifetimes = "deny" -option_option = "deny" redundant_clone = "deny" -ref_option = "deny" -string_add_assign = "deny" -unnecessary_join = "deny" unnecessary_self_imports = "deny" -unnested_or_patterns = "deny" -unused_async = "deny" -unused_self = "deny" useless_let_if_seq = "deny" verbose_file_reads = "deny" -zero_sized_map_values = "deny" +str_to_string = "deny" + +# Pedantic Opt-Outs +inline_always = "allow" # We use this sparsely +struct_field_names = "allow" # Noisy and some items are Bitwarden controlled +large_futures = "allow" # Causes a fail in some Rocket macro's, since we experience no issues, allow it +too_many_lines = "allow" # For now, allow this, good to enable in the future and see if we can refactor +unnecessary_wraps = "allow" # Too much false positives because of Rocket integrations +# We do not use these doc items +doc_link_with_quotes = "allow" +doc_markdown = "allow" +missing_errors_doc = "allow" +missing_panics_doc = "allow" [lints] workspace = true diff --git a/build.rs b/build.rs index 2d1106c2..32fcf845 100644 --- a/build.rs +++ b/build.rs @@ -1,5 +1,4 @@ -use std::env; -use std::process::Command; +use std::{env, io::Error, process::Command}; fn main() { // These allow using e.g. #[cfg(mysql)] instead of #[cfg(feature = "mysql")], which helps when trying to add them through macros @@ -42,13 +41,12 @@ fn main() { } } -fn run(args: &[&str]) -> Result { +fn run(args: &[&str]) -> Result { let out = Command::new(args[0]).args(&args[1..]).output()?; if !out.status.success() { - use std::io::Error; return Err(Error::other("Command not successful")); } - Ok(String::from_utf8(out.stdout).unwrap().trim().to_string()) + Ok(String::from_utf8(out.stdout).unwrap().trim().to_owned()) } /// This method reads info from Git, namely tags, branch, and revision @@ -58,7 +56,7 @@ fn run(args: &[&str]) -> Result { /// - `env!("GIT_BRANCH")` /// - `env!("GIT_REV")` /// - `env!("VW_VERSION")` -fn version_from_git_info() -> Result { +fn version_from_git_info() -> Result { // The exact tag for the current commit, can be empty when // the current commit doesn't have an associated tag let exact_tag = run(&["git", "describe", "--abbrev=0", "--tags", "--exact-match"]).ok(); diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 2d923ce1..73b23a22 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -1,14 +1,15 @@ use proc_macro::TokenStream; use quote::quote; +use syn::{DeriveInput, parse_macro_input}; #[proc_macro_derive(UuidFromParam)] pub fn derive_uuid_from_param(input: TokenStream) -> TokenStream { - let ast = syn::parse(input).unwrap(); + let ast = parse_macro_input!(input as DeriveInput); impl_derive_uuid_macro(&ast) } -fn impl_derive_uuid_macro(ast: &syn::DeriveInput) -> TokenStream { +fn impl_derive_uuid_macro(ast: &DeriveInput) -> TokenStream { let name = &ast.ident; let gen_derive = quote! { #[automatically_derived] @@ -30,12 +31,12 @@ fn impl_derive_uuid_macro(ast: &syn::DeriveInput) -> TokenStream { #[proc_macro_derive(IdFromParam)] pub fn derive_id_from_param(input: TokenStream) -> TokenStream { - let ast = syn::parse(input).unwrap(); + let ast = parse_macro_input!(input as DeriveInput); impl_derive_safestring_macro(&ast) } -fn impl_derive_safestring_macro(ast: &syn::DeriveInput) -> TokenStream { +fn impl_derive_safestring_macro(ast: &DeriveInput) -> TokenStream { let name = &ast.ident; let gen_derive = quote! { #[automatically_derived] diff --git a/src/api/admin.rs b/src/api/admin.rs index 02c976cc..cb31a353 100644 --- a/src/api/admin.rs +++ b/src/api/admin.rs @@ -2,40 +2,40 @@ use std::{env, sync::LazyLock}; use reqwest::Method; use rocket::{ + Catcher, Route, form::Form, http::{Cookie, CookieJar, MediaType, SameSite, Status}, request::{FromRequest, Outcome, Request}, - response::{content::RawHtml as Html, Redirect}, + response::{Redirect, content::RawHtml as Html}, serde::json::Json, - Catcher, Route, }; use serde::de::DeserializeOwned; use serde_json::Value; use crate::{ + CONFIG, VERSION, api::{ + ApiResult, EmptyResult, JsonResult, Notify, core::{log_event, two_factor}, - unregister_push_device, ApiResult, EmptyResult, JsonResult, Notify, + unregister_push_device, }, - auth::{decode_admin, encode_jwt, generate_admin_claims, ClientIp, Secure}, + auth::{ClientIp, Secure, decode_admin, encode_jwt, generate_admin_claims}, config::ConfigBuilder, db::{ - backup_sqlite, get_sql_server_version, + ACTIVE_DB_TYPE, DbConn, DbConnType, backup_sqlite, get_sql_server_version, models::{ Attachment, Cipher, Collection, Device, Event, EventType, Group, Invitation, Membership, MembershipId, MembershipType, OrgPolicy, Organization, OrganizationId, SsoUser, TwoFactor, User, UserId, }, - DbConn, DbConnType, ACTIVE_DB_TYPE, }, error::{Error, MapResult}, http_client::make_http_request, mail, sso::FAKE_SSO_IDENTIFIER, util::{ - container_base_image, format_naive_datetime_local, get_active_web_release, get_display_size, - is_running_in_container, parse_experimental_client_feature_flags, FeatureFlagFilter, NumberOrString, + FeatureFlagFilter, NumberOrString, container_base_image, format_naive_datetime_local, get_active_web_release, + get_display_size, is_running_in_container, parse_experimental_client_feature_flags, }, - CONFIG, VERSION, }; pub fn routes() -> Vec { @@ -93,8 +93,7 @@ static DB_TYPE: LazyLock<&str> = LazyLock::new(|| match ACTIVE_DB_TYPE.get() { }); #[cfg(sqlite)] -static CAN_BACKUP: LazyLock = - LazyLock::new(|| ACTIVE_DB_TYPE.get().map(|t| *t == DbConnType::Sqlite).unwrap_or(false)); +static CAN_BACKUP: LazyLock = LazyLock::new(|| ACTIVE_DB_TYPE.get().is_some_and(|t| *t == DbConnType::Sqlite)); #[cfg(not(sqlite))] static CAN_BACKUP: LazyLock = LazyLock::new(|| false); @@ -200,13 +199,7 @@ fn post_admin_login( } // If the token is invalid, redirect to login page - if !_validate_token(&data.token) { - error!("Invalid admin token. IP: {}", ip.ip); - Err(AdminResponse::Unauthorized(render_admin_login( - Some("Invalid admin token, please try again."), - redirect.as_deref(), - ))) - } else { + if validate_token(&data.token) { // If the token received is valid, generate JWT and save it as a cookie let claims = generate_admin_claims(); let jwt = encode_jwt(&claims); @@ -224,10 +217,16 @@ fn post_admin_login( } else { Err(AdminResponse::Ok(render_admin_page())) } + } else { + error!("Invalid admin token. IP: {}", ip.ip); + Err(AdminResponse::Unauthorized(render_admin_login( + Some("Invalid admin token, please try again."), + redirect.as_deref(), + ))) } } -fn _validate_token(token: &str) -> bool { +fn validate_token(token: &str) -> bool { match CONFIG.admin_token().as_ref() { None => false, Some(t) if t.starts_with("$argon2") => { @@ -307,21 +306,14 @@ async fn get_user_or_404(user_id: &UserId, conn: &DbConn) -> ApiResult { #[post("/invite", format = "application/json", data = "")] async fn invite_user(data: Json, _token: AdminToken, conn: DbConn) -> JsonResult { - let data: InviteData = data.into_inner(); - if User::find_by_mail(&data.email, &conn).await.is_some() { - err_code!("User already exists", Status::Conflict.code) - } - - let mut user = User::new(&data.email, None); - - async fn _generate_invite(user: &User, conn: &DbConn) -> EmptyResult { + async fn generate_invite(user: &User, conn: &DbConn) -> EmptyResult { if CONFIG.mail_enabled() { let org_id: OrganizationId = if CONFIG.sso_enabled() { FAKE_SSO_IDENTIFIER.into() } else { FAKE_ADMIN_UUID.into() }; - let member_id: MembershipId = FAKE_ADMIN_UUID.to_string().into(); + let member_id: MembershipId = FAKE_ADMIN_UUID.to_owned().into(); mail::send_invite(user, org_id, member_id, &CONFIG.invitation_org_name(), None).await } else { let invitation = Invitation::new(&user.email); @@ -329,7 +321,14 @@ async fn invite_user(data: Json, _token: AdminToken, conn: DbConn) - } } - _generate_invite(&user, &conn).await.map_err(|e| e.with_code(Status::InternalServerError.code))?; + let data: InviteData = data.into_inner(); + if User::find_by_mail(&data.email, &conn).await.is_some() { + err_code!("User already exists", Status::Conflict.code) + } + + let mut user = User::new(&data.email, None); + + generate_invite(&user, &conn).await.map_err(|e| e.with_code(Status::InternalServerError.code))?; user.save(&conn).await.map_err(|e| e.with_code(Status::InternalServerError.code))?; Ok(Json(user.to_json(&conn).await)) @@ -386,7 +385,7 @@ async fn users_overview(_token: AdminToken, conn: DbConn) -> ApiResult json!("Never"), }; - usr["sso_identifier"] = json!(sso_u.map(|u| u.identifier.to_string()).unwrap_or(String::new())); + usr["sso_identifier"] = json!(sso_u.map_or(String::new(), |u| u.identifier.to_string())); users_json.push(usr); } @@ -472,7 +471,7 @@ async fn deauth_user(user_id: UserId, _token: AdminToken, conn: DbConn, nt: Noti match unregister_push_device(device.push_uuid.as_ref()).await { Ok(r) => r, Err(e) => error!("Unable to unregister devices from Bitwarden server: {e}"), - }; + } } } @@ -528,7 +527,7 @@ async fn resend_user_invite(user_id: UserId, _token: AdminToken, conn: DbConn) - } else { FAKE_ADMIN_UUID.into() }; - let member_id: MembershipId = FAKE_ADMIN_UUID.to_string().into(); + let member_id: MembershipId = FAKE_ADMIN_UUID.to_owned().into(); mail::send_invite(&user, org_id, member_id, &CONFIG.invitation_org_name(), None).await } else { Ok(()) @@ -554,9 +553,10 @@ async fn update_membership_type(data: Json, token: AdminToke err!("The specified user isn't member of the organization") }; - let new_type = match MembershipType::from_str(&data.user_type.into_string()) { - Some(new_type) => new_type as i32, - None => err!("Invalid type"), + let new_type = if let Some(new_type) = MembershipType::from_str(&data.user_type.into_string()) { + new_type as i32 + } else { + err!("Invalid type") }; if member_to_edit.atype == MembershipType::Owner && new_type != MembershipType::Owner { @@ -656,42 +656,40 @@ async fn get_release_info(has_http_access: bool) -> (String, String, String) { .await { Ok(r) => r.tag_name, - _ => "-".to_string(), + _ => "-".to_owned(), }, match get_json_api::("https://api.github.com/repos/dani-garcia/vaultwarden/commits/main").await { Ok(mut c) => { c.sha.truncate(8); c.sha } - _ => "-".to_string(), + _ => "-".to_owned(), }, // Do not fetch the web-vault version when running within a container // The web-vault version is embedded within the container it self, and should not be updated manually match get_json_api::("https://api.github.com/repos/dani-garcia/bw_web_builds/releases/latest") .await { - Ok(r) => r.tag_name.trim_start_matches('v').to_string(), - _ => "-".to_string(), + Ok(r) => r.tag_name.trim_start_matches('v').to_owned(), + _ => "-".to_owned(), }, ) } else { - ("-".to_string(), "-".to_string(), "-".to_string()) + ("-".to_owned(), "-".to_owned(), "-".to_owned()) } } async fn get_ntp_time(has_http_access: bool) -> String { - if has_http_access { - if let Ok(cf_trace) = get_text_api("https://cloudflare.com/cdn-cgi/trace").await { - for line in cf_trace.lines() { - if let Some((key, value)) = line.split_once('=') { - if key == "ts" { - let ts = value.split_once('.').map_or(value, |(s, _)| s); - if let Ok(dt) = chrono::DateTime::parse_from_str(ts, "%s") { - return dt.format("%Y-%m-%d %H:%M:%S UTC").to_string(); - } - break; - } + if has_http_access && let Ok(cf_trace) = get_text_api("https://cloudflare.com/cdn-cgi/trace").await { + for line in cf_trace.lines() { + if let Some((key, value)) = line.split_once('=') + && key == "ts" + { + let ts = value.split_once('.').map_or(value, |(s, _)| s); + if let Ok(dt) = chrono::DateTime::parse_from_str(ts, "%s") { + return dt.format("%Y-%m-%d %H:%M:%S UTC").to_string(); } + break; } } } @@ -734,7 +732,7 @@ async fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> A // Check if we are able to resolve DNS entries let dns_resolved = match ("github.com", 0).to_socket_addrs().map(|mut i| i.next()) { Ok(Some(a)) => a.ip().to_string(), - _ => "Unable to resolve domain name.".to_string(), + _ => "Unable to resolve domain name.".to_owned(), }; let (latest_vw_release, latest_vw_commit, latest_web_release) = get_release_info(has_http_access).await; @@ -745,7 +743,7 @@ async fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> A let invalid_feature_flags: Vec = parse_experimental_client_feature_flags( &CONFIG.experimental_client_feature_flags(), - FeatureFlagFilter::InvalidOnly, + &FeatureFlagFilter::InvalidOnly, ) .into_keys() .collect(); @@ -834,33 +832,30 @@ impl<'r> FromRequest<'r> for AdminToken { type Error = &'static str; async fn from_request(request: &'r Request<'_>) -> Outcome { - let ip = match ClientIp::from_request(request).await { - Outcome::Success(ip) => ip, - _ => err_handler!("Error getting Client IP"), + let Outcome::Success(ip) = ClientIp::from_request(request).await else { + err_handler!("Error getting Client IP") }; if !CONFIG.disable_admin_token() { let cookies = request.cookies(); - let access_token = match cookies.get(COOKIE_NAME) { - Some(cookie) => cookie.value(), - None => { - let requested_page = - request.segments::(0..).unwrap_or_default().display().to_string(); - // When the requested page is empty, it is `/admin`, in that case, Forward, so it will render the login page - // Else, return a 401 failure, which will be caught - if requested_page.is_empty() { - return Outcome::Forward(Status::Unauthorized); - } else { - return Outcome::Error((Status::Unauthorized, "Unauthorized")); - } + let access_token = if let Some(cookie) = cookies.get(COOKIE_NAME) { + cookie.value() + } else { + let requested_page = + request.segments::(0..).unwrap_or_default().display().to_string(); + // When the requested page is empty, it is `/admin`, in that case, Forward, so it will render the login page + // Else, return a 401 failure, which will be caught + if requested_page.is_empty() { + return Outcome::Forward(Status::Unauthorized); } + return Outcome::Error((Status::Unauthorized, "Unauthorized")); }; if decode_admin(access_token).is_err() { // Remove admin cookie cookies.remove(Cookie::build(COOKIE_NAME).path(admin_path())); - error!("Invalid or expired admin JWT. IP: {}.", &ip.ip); + error!("Invalid or expired admin JWT. IP: {}.", ip.ip); return Outcome::Error((Status::Unauthorized, "Session expired")); } } diff --git a/src/api/core/accounts.rs b/src/api/core/accounts.rs index a8f9768e..f2852e5c 100644 --- a/src/api/core/accounts.rs +++ b/src/api/core/accounts.rs @@ -2,33 +2,36 @@ use std::collections::HashSet; use crate::db::DbPool; use chrono::Utc; -use rocket::serde::json::Json; use serde_json::Value; use crate::{ + CONFIG, api::{ + AnonymousNotify, ApiResult, EmptyResult, JsonResult, Notify, PasswordOrOtpData, UpdateType, core::{accept_org_invite, log_user_event, two_factor::email}, - master_password_policy, register_push_device, unregister_push_device, AnonymousNotify, ApiResult, EmptyResult, - JsonResult, Notify, PasswordOrOtpData, UpdateType, + master_password_policy, register_push_device, unregister_push_device, }, - auth::{decode_delete, decode_invite, decode_verify_email, ClientHeaders, Headers}, + auth::{ClientHeaders, Headers, decode_delete, decode_invite, decode_verify_email}, crypto, db::{ + DbConn, models::{ - AuthRequest, AuthRequestId, Cipher, CipherId, Device, DeviceId, DeviceType, EmergencyAccess, - EmergencyAccessId, EventType, Folder, FolderId, Invitation, Membership, MembershipId, OrgPolicy, - OrgPolicyType, Organization, OrganizationId, Send, SendId, User, UserId, UserKdfType, + AuthRequest, AuthRequestId, Cipher, CipherId, Device, DeviceId, DeviceType, DeviceWithAuthRequest, + EmergencyAccess, EmergencyAccessId, EventType, Folder, FolderId, Invitation, Membership, MembershipId, + OrgPolicy, OrgPolicyType, Organization, OrganizationId, Send, SendId, User, UserId, UserKdfType, }, - DbConn, }, mail, - util::{deser_opt_nonempty_str, format_date, NumberOrString}, - CONFIG, + util::{NumberOrString, deser_opt_nonempty_str, format_date}, }; +use super::ciphers::{CipherData, update_cipher_from_data}; +use super::sends::{SendData, update_send_from_data}; + use rocket::{ http::Status, request::{FromRequest, Outcome, Request}, + serde::json::Json, }; pub fn routes() -> Vec { @@ -54,9 +57,9 @@ pub fn routes() -> Vec { delete_account, revision_date, password_hint, - prelogin, + post_prelogin, verify_password, - api_key, + post_api_key, rotate_api_key, get_known_device, get_all_devices, @@ -142,7 +145,7 @@ fn clean_password_hint(password_hint: Option<&String>) -> Option { None => None, Some(h) => match h.trim() { "" => None, - ht => Some(ht.to_string()), + ht => Some(ht.to_owned()), }, } } @@ -166,7 +169,7 @@ async fn is_email_2fa_required(member_id: Option, conn: &DbConn) - false } -pub async fn _register(data: Json, email_verification: bool, conn: DbConn) -> JsonResult { +pub async fn register(data: Json, email_verification: bool, conn: DbConn) -> JsonResult { let mut data: RegisterData = data.into_inner(); let email = data.email.to_lowercase(); @@ -237,10 +240,10 @@ pub async fn _register(data: Json, email_verification: bool, conn: // Check if the length of the username exceeds 50 characters (Same is Upstream Bitwarden) // This also prevents issues with very long usernames causing to large JWT's. See #2419 - if let Some(ref name) = data.name { - if name.len() > 50 { - err!("The field Name must be a string with a maximum length of 50."); - } + if let Some(ref name) = data.name + && name.len() > 50 + { + err!("The field Name must be a string with a maximum length of 50."); } // Check against the password hint setting here so if it fails, the user @@ -373,18 +376,19 @@ async fn post_set_password(data: Json, headers: Headers, conn: user.public_key = Some(keys.public_key); } - if let Some(identifier) = data.org_identifier { - if identifier != crate::sso::FAKE_SSO_IDENTIFIER && identifier != crate::api::admin::FAKE_ADMIN_UUID { - let Some(org) = Organization::find_by_uuid(&identifier.into(), &conn).await else { - err!("Failed to retrieve the associated organization") - }; + if let Some(identifier) = data.org_identifier + && identifier != crate::sso::FAKE_SSO_IDENTIFIER + && identifier != crate::api::admin::FAKE_ADMIN_UUID + { + let Some(org) = Organization::find_by_uuid(&identifier.into(), &conn).await else { + err!("Failed to retrieve the associated organization") + }; - let Some(membership) = Membership::find_by_user_and_org(&user.uuid, &org.uuid, &conn).await else { - err!("Failed to retrieve the invitation") - }; + let Some(membership) = Membership::find_by_user_and_org(&user.uuid, &org.uuid, &conn).await else { + err!("Failed to retrieve the invitation") + }; - accept_org_invite(&user, membership, None, &conn).await?; - } + accept_org_invite(&user, membership, None, &conn).await?; } if CONFIG.mail_enabled() { @@ -451,10 +455,10 @@ async fn put_avatar(data: Json, headers: Headers, conn: DbConn) -> J // It looks like it only supports the 6 hex color format. // If you try to add the short value it will not show that color. // Check and force 7 chars, including the #. - if let Some(color) = &data.avatar_color { - if color.len() != 7 { - err!("The field AvatarColor must be a HTML/Hex color code with a length of 7 characters") - } + if let Some(color) = &data.avatar_color + && color.len() != 7 + { + err!("The field AvatarColor must be a HTML/Hex color code with a length of 7 characters") } let mut user = headers.user; @@ -668,9 +672,6 @@ struct UpdateResetPasswordData { reset_password_key: String, } -use super::ciphers::CipherData; -use super::sends::{update_send_from_data, SendData}; - #[derive(Deserialize)] #[serde(rename_all = "camelCase")] struct KeyData { @@ -840,7 +841,7 @@ async fn post_rotatekey(data: Json, headers: Headers, conn: DbConn, nt: }; saved_folder.name = folder_data.name; - saved_folder.save(&conn).await? + saved_folder.save(&conn).await?; } } @@ -853,7 +854,7 @@ async fn post_rotatekey(data: Json, headers: Headers, conn: DbConn, nt: }; saved_emergency_access.key_encrypted = Some(emergency_access_data.key_encrypted); - saved_emergency_access.save(&conn).await? + saved_emergency_access.save(&conn).await?; } // Update reset password data @@ -865,7 +866,7 @@ async fn post_rotatekey(data: Json, headers: Headers, conn: DbConn, nt: }; membership.reset_password_key = Some(reset_password_data.reset_password_key); - membership.save(&conn).await? + membership.save(&conn).await?; } // Update send data @@ -878,8 +879,6 @@ async fn post_rotatekey(data: Json, headers: Headers, conn: DbConn, nt: } // Update cipher data - use super::ciphers::update_cipher_from_data; - for cipher_data in data.account_data.ciphers { if cipher_data.organization_id.is_none() { let Some(saved_cipher) = existing_ciphers.iter_mut().find(|c| &c.uuid == cipher_data.id.as_ref().unwrap()) @@ -890,7 +889,7 @@ async fn post_rotatekey(data: Json, headers: Headers, conn: DbConn, nt: // Prevent triggering cipher updates via WebSockets by settings UpdateType::None // The user sessions are invalidated because all the ciphers were re-encrypted and thus triggering an update could cause issues. // We force the users to logout after the user has been saved to try and prevent these issues. - update_cipher_from_data(saved_cipher, cipher_data, &headers, None, &conn, &nt, UpdateType::None).await? + update_cipher_from_data(saved_cipher, cipher_data, &headers, None, &conn, &nt, UpdateType::None).await?; } } @@ -1020,24 +1019,22 @@ async fn post_email(data: Json, headers: Headers, conn: DbConn, err!("Email already in use"); } - match user.email_new { - Some(ref val) => { - if val != &data.new_email { - err!("Email change mismatch"); - } + if let Some(ref val) = user.email_new { + if val != &data.new_email { + err!("Email change mismatch"); } - None => err!("No email change pending"), + } else { + err!("No email change pending") } if CONFIG.mail_enabled() { // Only check the token if we sent out an email... - match user.email_new_token { - Some(ref val) => { - if *val != data.token.into_string() { - err!("Token mismatch"); - } + if let Some(ref val) = user.email_new_token { + if *val != data.token.into_string() { + err!("Token mismatch"); } - None => err!("No email change pending"), + } else { + err!("No email change pending") } user.verified_at = Some(Utc::now().naive_utc()); } else { @@ -1114,10 +1111,10 @@ async fn post_delete_recover(data: Json, conn: DbConn) -> Emp let data: DeleteRecoverData = data.into_inner(); if CONFIG.mail_enabled() { - if let Some(user) = User::find_by_mail(&data.email, &conn).await { - if let Err(e) = mail::send_delete_account(&user.email, &user.uuid).await { - error!("Error sending delete account email: {e:#?}"); - } + if let Some(user) = User::find_by_mail(&data.email, &conn).await + && let Err(e) = mail::send_delete_account(&user.email, &user.uuid).await + { + error!("Error sending delete account email: {e:#?}"); } Ok(()) } else { @@ -1169,6 +1166,7 @@ async fn delete_account(data: Json, headers: Headers, conn: D user.delete(&conn).await } +#[expect(clippy::needless_pass_by_value, reason = "Not beneficial for Headers")] #[get("/accounts/revision-date")] fn revision_date(headers: Headers) -> JsonResult { let revision_date = headers.user.updated_at.and_utc().timestamp_millis(); @@ -1183,12 +1181,12 @@ struct PasswordHintData { #[post("/accounts/password-hint", data = "")] async fn password_hint(data: Json, conn: DbConn) -> EmptyResult { + const NO_HINT: &str = "Sorry, you have no password hint..."; + if !CONFIG.password_hints_allowed() || (!CONFIG.mail_enabled() && !CONFIG.show_password_hint()) { err!("This server is not configured to provide password hints."); } - const NO_HINT: &str = "Sorry, you have no password hint..."; - let data: PasswordHintData = data.into_inner(); let email = &data.email; @@ -1199,9 +1197,9 @@ async fn password_hint(data: Json, conn: DbConn) -> EmptyResul // There is still a timing side channel here in that the code // paths that send mail take noticeably longer than ones that // don't. Add a randomized sleep to mitigate this somewhat. - use rand::{rngs::SmallRng, RngExt}; + use rand::{RngExt, rngs::SmallRng}; let mut rng: SmallRng = rand::make_rng(); - let sleep_ms = rng.random_range(900..=1100) as u64; + let sleep_ms: u64 = rng.random_range(900..=1100); tokio::time::sleep(tokio::time::Duration::from_millis(sleep_ms)).await; Ok(()) } else { @@ -1229,11 +1227,11 @@ pub struct PreloginData { } #[post("/accounts/prelogin", data = "")] -async fn prelogin(data: Json, conn: DbConn) -> Json { - _prelogin(data, conn).await +async fn post_prelogin(data: Json, conn: DbConn) -> Json { + prelogin(data, conn).await } -pub async fn _prelogin(data: Json, conn: DbConn) -> Json { +pub async fn prelogin(data: Json, conn: DbConn) -> Json { let data: PreloginData = data.into_inner(); let (kdf_type, kdf_iter, kdf_mem, kdf_para) = match User::find_by_mail(&data.email, &conn).await { @@ -1283,7 +1281,7 @@ async fn verify_password(data: Json, headers: Headers Ok(Json(master_password_policy(&user, &conn).await)) } -async fn _api_key(data: Json, rotate: bool, headers: Headers, conn: DbConn) -> JsonResult { +async fn update_api_key(data: Json, rotate: bool, headers: Headers, conn: DbConn) -> JsonResult { use crate::util::format_date; let data: PasswordOrOtpData = data.into_inner(); @@ -1304,13 +1302,13 @@ async fn _api_key(data: Json, rotate: bool, headers: Headers, } #[post("/accounts/api-key", data = "")] -async fn api_key(data: Json, headers: Headers, conn: DbConn) -> JsonResult { - _api_key(data, false, headers, conn).await +async fn post_api_key(data: Json, headers: Headers, conn: DbConn) -> JsonResult { + update_api_key(data, false, headers, conn).await } #[post("/accounts/rotate-api-key", data = "")] async fn rotate_api_key(data: Json, headers: Headers, conn: DbConn) -> JsonResult { - _api_key(data, true, headers, conn).await + update_api_key(data, true, headers, conn).await } #[get("/devices/knowndevice")] @@ -1353,7 +1351,7 @@ impl<'r> FromRequest<'r> for KnownDevice { }; let uuid = if let Some(uuid) = req.headers().get_one("X-Device-Identifier") { - uuid.to_string().into() + uuid.to_owned().into() } else { return Outcome::Error((Status::BadRequest, "X-Device-Identifier value is required")); }; @@ -1368,7 +1366,7 @@ impl<'r> FromRequest<'r> for KnownDevice { #[get("/devices")] async fn get_all_devices(headers: Headers, conn: DbConn) -> JsonResult { let devices = Device::find_with_auth_request_by_user(&headers.user.uuid, &conn).await; - let devices = devices.iter().map(|device| device.to_json()).collect::>(); + let devices = devices.iter().map(DeviceWithAuthRequest::to_json).collect::>(); Ok(Json(json!({ "data": devices, @@ -1708,6 +1706,6 @@ pub async fn purge_auth_requests(pool: DbPool) { if let Ok(conn) = pool.get().await { AuthRequest::purge_expired_auth_requests(&conn).await; } else { - error!("Failed to get DB connection while purging auth requests") + error!("Failed to get DB connection while purging auth requests"); } } diff --git a/src/api/core/ciphers.rs b/src/api/core/ciphers.rs index 43e555e2..7e6a34fa 100644 --- a/src/api/core/ciphers.rs +++ b/src/api/core/ciphers.rs @@ -5,27 +5,27 @@ use num_traits::ToPrimitive; use rocket::fs::TempFile; use rocket::serde::json::Json; use rocket::{ - form::{Form, FromForm}, Route, + form::{Form, FromForm}, }; use serde_json::Value; use crate::auth::ClientVersion; -use crate::util::{deser_opt_nonempty_str, save_temp_file, NumberOrString}; +use crate::util::{NumberOrString, deser_opt_nonempty_str, save_temp_file}; use crate::{ - api::{self, core::log_event, EmptyResult, JsonResult, Notify, PasswordOrOtpData, UpdateType}, + CONFIG, + api::{self, EmptyResult, JsonResult, Notify, PasswordOrOtpData, UpdateType, core::log_event}, auth::{Headers, OrgIdGuard, OwnerHeaders}, config::PathType, crypto, db::{ + DbConn, DbPool, models::{ Archive, Attachment, AttachmentId, Cipher, CipherId, Collection, CollectionCipher, CollectionGroup, CollectionId, CollectionUser, EventType, Favorite, Folder, FolderCipher, FolderId, Group, Membership, MembershipType, OrgPolicy, OrgPolicyType, OrganizationId, RepromptType, Send, UserId, }, - DbConn, DbPool, }, - CONFIG, }; use super::folders::FolderData; @@ -108,7 +108,7 @@ pub async fn purge_trashed_ciphers(pool: DbPool) { if let Ok(conn) = pool.get().await { Cipher::purge_trash(&conn).await; } else { - error!("Failed to get DB connection while purging trashed ciphers") + error!("Failed to get DB connection while purging trashed ciphers"); } } @@ -164,7 +164,7 @@ async fn sync(data: SyncData, headers: Headers, client_version: Option, ut: UpdateType, ) -> EmptyResult { + // Cleanup cipher data, like removing the 'Response' key. + // This key is somewhere generated during Javascript so no way for us this fix this. + // Also, upstream only retrieves keys they actually want to store, and thus skip the 'Response' key. + // We do not mind which data is in it, the keep our model more flexible when there are upstream changes. + // But, we at least know we do not need to store and return this specific key. + fn clean_cipher_data(mut json_data: Value) -> Value { + if json_data.is_array() { + json_data.as_array_mut().unwrap().iter_mut().for_each(|ref mut f| { + f.as_object_mut().unwrap().remove("response"); + }); + } + json_data + } + enforce_personal_ownership_policy(Some(&data), headers, conn).await?; // Check that the client isn't updating an existing cipher with stale data. // And only perform this check when not importing ciphers, else the date/time check will fail. - if ut != UpdateType::None { - if let Some(dt) = data.last_known_revision_date { - match NaiveDateTime::parse_from_str(&dt, "%+") { - // ISO 8601 format - Err(err) => warn!("Error parsing LastKnownRevisionDate '{dt}': {err}"), - Ok(dt) if cipher.updated_at.signed_duration_since(dt).num_seconds() > 1 => { - err!("The client copy of this cipher is out of date. Resync the client and try again.") - } - Ok(_) => (), + if ut != UpdateType::None + && let Some(dt) = data.last_known_revision_date + { + match NaiveDateTime::parse_from_str(&dt, "%+") { + // ISO 8601 format + Err(err) => warn!("Error parsing LastKnownRevisionDate '{dt}': {err}"), + Ok(dt) if cipher.updated_at.signed_duration_since(dt).num_seconds() > 1 => { + err!("The client copy of this cipher is out of date. Resync the client and try again.") } + Ok(_) => (), } } @@ -456,25 +470,22 @@ pub async fn update_cipher_from_data( cipher.user_uuid = Some(headers.user.uuid.clone()); } - if let Some(ref folder_id) = data.folder_id { - if Folder::find_by_uuid_and_user(folder_id, &headers.user.uuid, conn).await.is_none() { - err!("Invalid folder", "Folder does not exist or belongs to another user"); - } + if let Some(ref folder_id) = data.folder_id + && Folder::find_by_uuid_and_user(folder_id, &headers.user.uuid, conn).await.is_none() + { + err!("Invalid folder", "Folder does not exist or belongs to another user"); } // Modify attachments name and keys when rotating if let Some(attachments) = data.attachments2 { for (id, attachment) in attachments { - let mut saved_att = match Attachment::find_by_id(&id, conn).await { - Some(att) => att, - None => { - // Warn and continue here. - // A missing attachment means it was removed via an other client. - // Also the Desktop Client supports removing attachments and save an update afterwards. - // Bitwarden it self ignores these mismatches server side. - warn!("Attachment {id} doesn't exist"); - continue; - } + let Some(mut saved_att) = Attachment::find_by_id(&id, conn).await else { + // Warn and continue here. + // A missing attachment means it was removed via an other client. + // Also the Desktop Client supports removing attachments and save an update afterwards. + // Bitwarden it self ignores these mismatches server side. + warn!("Attachment {id} doesn't exist"); + continue; }; if saved_att.cipher_uuid != cipher.uuid { @@ -491,20 +502,6 @@ pub async fn update_cipher_from_data( } } - // Cleanup cipher data, like removing the 'Response' key. - // This key is somewhere generated during Javascript so no way for us this fix this. - // Also, upstream only retrieves keys they actually want to store, and thus skip the 'Response' key. - // We do not mind which data is in it, the keep our model more flexible when there are upstream changes. - // But, we at least know we do not need to store and return this specific key. - fn _clean_cipher_data(mut json_data: Value) -> Value { - if json_data.is_array() { - json_data.as_array_mut().unwrap().iter_mut().for_each(|ref mut f| { - f.as_object_mut().unwrap().remove("response"); - }); - }; - json_data - } - let type_data_opt = match data.r#type { 1 => data.login, 2 => data.secure_note, @@ -514,23 +511,22 @@ pub async fn update_cipher_from_data( _ => err!("Invalid type"), }; - let type_data = match type_data_opt { - Some(mut data) => { - // Remove the 'Response' key from the base object. - data.as_object_mut().unwrap().remove("response"); - // Remove the 'Response' key from every Uri. - if data["uris"].is_array() { - data["uris"] = _clean_cipher_data(data["uris"].clone()); - } - data + let type_data = if let Some(mut data) = type_data_opt { + // Remove the 'Response' key from the base object. + data.as_object_mut().unwrap().remove("response"); + // Remove the 'Response' key from every Uri. + if data["uris"].is_array() { + data["uris"] = clean_cipher_data(data["uris"].clone()); } - None => err!("Data missing"), + data + } else { + err!("Data missing") }; cipher.key = data.key; cipher.name = data.name; cipher.notes = data.notes; - cipher.fields = data.fields.map(|f| _clean_cipher_data(f).to_string()); + cipher.fields = data.fields.map(|f| clean_cipher_data(f).to_string()); cipher.data = type_data.to_string(); cipher.password_history = data.password_history.map(|f| f.to_string()); cipher.reprompt = data.reprompt.filter(|r| *r == RepromptType::None as i32 || *r == RepromptType::Password as i32); @@ -612,7 +608,7 @@ async fn post_ciphers_import(data: Json, headers: Headers, conn: DbC let existing_folders: HashSet> = Folder::find_by_user(&headers.user.uuid, &conn).await.into_iter().map(|f| Some(f.uuid)).collect(); let mut folders: Vec = Vec::with_capacity(data.folders.len()); - for folder in data.folders.into_iter() { + for folder in data.folders { let folder_id = if existing_folders.contains(&folder.id) { folder.id.unwrap() } else { @@ -737,10 +733,10 @@ async fn put_cipher_partial( err!("Cipher does not exist", "Cipher is not accessible for the current user") } - if let Some(ref folder_id) = data.folder_id { - if Folder::find_by_uuid_and_user(folder_id, &headers.user.uuid, &conn).await.is_none() { - err!("Invalid folder", "Folder does not exist or belongs to another user"); - } + if let Some(ref folder_id) = data.folder_id + && Folder::find_by_uuid_and_user(folder_id, &headers.user.uuid, &conn).await.is_none() + { + err!("Invalid folder", "Folder does not exist or belongs to another user"); } // Move cipher @@ -1004,7 +1000,7 @@ async fn put_cipher_share_selected( err!("You must select at least one collection.") } - for cipher in data.ciphers.iter() { + for cipher in &data.ciphers { if cipher.id.is_none() { err!("Request missing ids field") } @@ -1016,11 +1012,10 @@ async fn put_cipher_share_selected( collection_ids: data.collection_ids.clone(), }; - match shared_cipher_data.cipher.id.take() { - Some(id) => { - share_cipher_by_uuid(&id, shared_cipher_data, &headers, &conn, &nt, Some(UpdateType::None)).await? - } - None => err!("Request missing ids field"), + if let Some(id) = shared_cipher_data.cipher.id.take() { + share_cipher_by_uuid(&id, shared_cipher_data, &headers, &conn, &nt, Some(UpdateType::None)).await? + } else { + err!("Request missing ids field") }; } @@ -1038,15 +1033,14 @@ async fn share_cipher_by_uuid( nt: &Notify<'_>, override_ut: Option, ) -> JsonResult { - let mut cipher = match Cipher::find_by_uuid(cipher_id, conn).await { - Some(cipher) => { - if cipher.is_write_accessible_to_user(&headers.user.uuid, conn).await { - cipher - } else { - err!("Cipher is not write accessible") - } + let mut cipher = if let Some(cipher) = Cipher::find_by_uuid(cipher_id, conn).await { + if cipher.is_write_accessible_to_user(&headers.user.uuid, conn).await { + cipher + } else { + err!("Cipher is not write accessible") } - None => err!("Cipher doesn't exist"), + } else { + err!("Cipher doesn't exist") }; let mut shared_to_collections = vec![]; @@ -1065,7 +1059,7 @@ async fn share_cipher_by_uuid( } } } - }; + } // When LastKnownRevisionDate is None, it is a new cipher, so send CipherCreate. // If there is an override, like when handling multiple items, we want to prevent a push notification for every single item @@ -1263,10 +1257,10 @@ async fn save_attachment( err!("Cipher is neither owned by a user nor an organization"); }; - if let Some(size_limit) = size_limit { - if size > size_limit { - err!("Attachment storage limit exceeded with this file"); - } + if let Some(size_limit) = size_limit + && size > size_limit + { + err!("Attachment storage limit exceeded with this file"); } let file_id = match &attachment { @@ -1408,7 +1402,7 @@ async fn post_attachment_share( conn: DbConn, nt: Notify<'_>, ) -> JsonResult { - _delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await?; + delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await?; post_attachment(cipher_id, data, headers, conn, nt).await } @@ -1442,7 +1436,7 @@ async fn delete_attachment( conn: DbConn, nt: Notify<'_>, ) -> JsonResult { - _delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await + delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await } #[delete("/ciphers//attachment//admin")] @@ -1453,42 +1447,42 @@ async fn delete_attachment_admin( conn: DbConn, nt: Notify<'_>, ) -> JsonResult { - _delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await + delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await } #[post("/ciphers//delete")] async fn delete_cipher_post(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult { - _delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await + delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await // permanent delete } #[post("/ciphers//delete-admin")] async fn delete_cipher_post_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult { - _delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await + delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await // permanent delete } #[put("/ciphers//delete")] async fn delete_cipher_put(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult { - _delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::SoftSingle, &nt).await + delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::SoftSingle, &nt).await // soft delete } #[put("/ciphers//delete-admin")] async fn delete_cipher_put_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult { - _delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::SoftSingle, &nt).await + delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::SoftSingle, &nt).await // soft delete } #[delete("/ciphers/")] async fn delete_cipher(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult { - _delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await + delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await // permanent delete } #[delete("/ciphers//admin")] async fn delete_cipher_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult { - _delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await + delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await // permanent delete } @@ -1499,7 +1493,7 @@ async fn delete_cipher_selected( conn: DbConn, nt: Notify<'_>, ) -> EmptyResult { - _delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await + delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await // permanent delete } @@ -1510,7 +1504,7 @@ async fn delete_cipher_selected_post( conn: DbConn, nt: Notify<'_>, ) -> EmptyResult { - _delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await + delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await // permanent delete } @@ -1521,7 +1515,7 @@ async fn delete_cipher_selected_put( conn: DbConn, nt: Notify<'_>, ) -> EmptyResult { - _delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::SoftMulti, nt).await + delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::SoftMulti, nt).await // soft delete } @@ -1532,7 +1526,7 @@ async fn delete_cipher_selected_admin( conn: DbConn, nt: Notify<'_>, ) -> EmptyResult { - _delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await + delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await // permanent delete } @@ -1543,7 +1537,7 @@ async fn delete_cipher_selected_post_admin( conn: DbConn, nt: Notify<'_>, ) -> EmptyResult { - _delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await + delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await // permanent delete } @@ -1554,18 +1548,18 @@ async fn delete_cipher_selected_put_admin( conn: DbConn, nt: Notify<'_>, ) -> EmptyResult { - _delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::SoftMulti, nt).await + delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::SoftMulti, nt).await // soft delete } #[put("/ciphers//restore")] async fn restore_cipher_put(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult { - _restore_cipher_by_uuid(&cipher_id, &headers, false, &conn, &nt).await + restore_cipher_by_uuid(&cipher_id, &headers, false, &conn, &nt).await } #[put("/ciphers//restore-admin")] async fn restore_cipher_put_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult { - _restore_cipher_by_uuid(&cipher_id, &headers, false, &conn, &nt).await + restore_cipher_by_uuid(&cipher_id, &headers, false, &conn, &nt).await } #[put("/ciphers/restore-admin", data = "")] @@ -1575,7 +1569,7 @@ async fn restore_cipher_selected_admin( conn: DbConn, nt: Notify<'_>, ) -> JsonResult { - _restore_multiple_ciphers(data, &headers, &conn, &nt).await + restore_multiple_ciphers(data, &headers, &conn, &nt).await } #[put("/ciphers/restore", data = "")] @@ -1585,7 +1579,7 @@ async fn restore_cipher_selected( conn: DbConn, nt: Notify<'_>, ) -> JsonResult { - _restore_multiple_ciphers(data, &headers, &conn, &nt).await + restore_multiple_ciphers(data, &headers, &conn, &nt).await } #[derive(Deserialize)] @@ -1606,10 +1600,10 @@ async fn move_cipher_selected( let data = data.into_inner(); let user_id = &headers.user.uuid; - if let Some(ref folder_id) = data.folder_id { - if Folder::find_by_uuid_and_user(folder_id, user_id, &conn).await.is_none() { - err!("Invalid folder", "Folder does not exist or belongs to another user"); - } + if let Some(ref folder_id) = data.folder_id + && Folder::find_by_uuid_and_user(folder_id, user_id, &conn).await.is_none() + { + err!("Invalid folder", "Folder does not exist or belongs to another user"); } let cipher_count = data.ids.len(); @@ -1773,7 +1767,7 @@ pub enum CipherDeleteOptions { HardMulti, } -async fn _delete_cipher_by_uuid( +async fn delete_cipher_by_uuid( cipher_id: &CipherId, headers: &Headers, conn: &DbConn, @@ -1839,7 +1833,7 @@ struct CipherIdsData { ids: Vec, } -async fn _delete_multiple_ciphers( +async fn delete_multiple_ciphers( data: Json, headers: Headers, conn: DbConn, @@ -1849,9 +1843,9 @@ async fn _delete_multiple_ciphers( let data = data.into_inner(); for cipher_id in data.ids { - if let error @ Err(_) = _delete_cipher_by_uuid(&cipher_id, &headers, &conn, &delete_options, &nt).await { + if let error @ Err(_) = delete_cipher_by_uuid(&cipher_id, &headers, &conn, &delete_options, &nt).await { return error; - }; + } } // Multi delete actions do not send out a push for each cipher, we need to send a general sync here @@ -1860,7 +1854,7 @@ async fn _delete_multiple_ciphers( Ok(()) } -async fn _restore_cipher_by_uuid( +async fn restore_cipher_by_uuid( cipher_id: &CipherId, headers: &Headers, multi_restore: bool, @@ -1906,7 +1900,7 @@ async fn _restore_cipher_by_uuid( Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, conn).await?)) } -async fn _restore_multiple_ciphers( +async fn restore_multiple_ciphers( data: Json, headers: &Headers, conn: &DbConn, @@ -1916,7 +1910,7 @@ async fn _restore_multiple_ciphers( let mut ciphers: Vec = Vec::new(); for cipher_id in data.ids { - match _restore_cipher_by_uuid(&cipher_id, headers, true, conn, nt).await { + match restore_cipher_by_uuid(&cipher_id, headers, true, conn, nt).await { Ok(json) => ciphers.push(json.into_inner()), err => return err, } @@ -1932,7 +1926,7 @@ async fn _restore_multiple_ciphers( }))) } -async fn _delete_cipher_attachment_by_id( +async fn delete_cipher_attachment_by_id( cipher_id: &CipherId, attachment_id: &AttachmentId, headers: &Headers, @@ -2206,11 +2200,11 @@ impl CipherSyncData { }; Self { - cipher_archives, cipher_attachments, cipher_folders, cipher_favorites, cipher_collections, + cipher_archives, members, user_collections, user_collections_groups, diff --git a/src/api/core/emergency_access.rs b/src/api/core/emergency_access.rs index 29a15c8d..2eb95502 100644 --- a/src/api/core/emergency_access.rs +++ b/src/api/core/emergency_access.rs @@ -1,23 +1,23 @@ use chrono::{TimeDelta, Utc}; -use rocket::{serde::json::Json, Route}; +use rocket::{Route, serde::json::Json}; use serde_json::Value; use crate::{ + CONFIG, api::{ - core::{CipherSyncData, CipherSyncType}, EmptyResult, JsonResult, + core::{CipherSyncData, CipherSyncType}, }, - auth::{decode_emergency_access_invite, Headers}, + auth::{Headers, decode_emergency_access_invite}, db::{ + DbConn, DbPool, models::{ Cipher, EmergencyAccess, EmergencyAccessId, EmergencyAccessStatus, EmergencyAccessType, Invitation, Membership, MembershipType, OrgPolicy, TwoFactor, User, UserId, }, - DbConn, DbPool, }, mail, util::NumberOrString, - CONFIG, }; pub fn routes() -> Vec { @@ -55,7 +55,7 @@ async fn get_contacts(headers: Headers, conn: DbConn) -> Json { let mut emergency_access_list_json = Vec::with_capacity(emergency_access_list.len()); for ea in emergency_access_list { if let Some(grantee) = ea.to_json_grantee_details(&conn).await { - emergency_access_list_json.push(grantee) + emergency_access_list_json.push(grantee); } } @@ -89,11 +89,14 @@ async fn get_grantees(headers: Headers, conn: DbConn) -> Json { async fn get_emergency_access(emer_id: EmergencyAccessId, headers: Headers, conn: DbConn) -> JsonResult { check_emergency_access_enabled()?; - match EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &conn).await { - Some(emergency_access) => Ok(Json( + if let Some(emergency_access) = + EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &conn).await + { + Ok(Json( emergency_access.to_json_grantee_details(&conn).await.expect("Grantee user should exist but does not!"), - )), - None => err!("Emergency access not valid."), + )) + } else { + err!("Emergency access not valid.") } } @@ -136,9 +139,10 @@ async fn post_emergency_access( err!("Emergency access not valid.") }; - let new_type = match EmergencyAccessType::from_str(&data.r#type.into_string()) { - Some(new_type) => new_type as i32, - None => err!("Invalid emergency access type."), + let new_type = if let Some(new_type) = EmergencyAccessType::from_str(&data.r#type.into_string()) { + new_type as i32 + } else { + err!("Invalid emergency access type.") }; emergency_access.atype = new_type; @@ -205,9 +209,10 @@ async fn send_invite(data: Json, headers: Headers, co let emergency_access_status = EmergencyAccessStatus::Invited as i32; - let new_type = match EmergencyAccessType::from_str(&data.r#type.into_string()) { - Some(new_type) => new_type as i32, - None => err!("Invalid emergency access type."), + let new_type = if let Some(new_type) = EmergencyAccessType::from_str(&data.r#type.into_string()) { + new_type as i32 + } else { + err!("Invalid emergency access type.") }; let grantor_user = headers.user; @@ -342,12 +347,11 @@ async fn accept_invite( err!("Claim email does not match current users email") } - let grantee_user = match User::find_by_mail(&claims.email, &conn).await { - Some(user) => { - Invitation::take(&claims.email, &conn).await; - user - } - None => err!("Invited user not found"), + let grantee_user = if let Some(user) = User::find_by_mail(&claims.email, &conn).await { + Invitation::take(&claims.email, &conn).await; + user + } else { + err!("Invited user not found") }; // We need to search for the uuid in combination with the email, since we do not yet store the uuid of the grantee in the database. @@ -766,7 +770,7 @@ pub async fn emergency_request_timeout_job(pool: DbPool) { } } } else { - error!("Failed to get DB connection while searching emergency request timed out") + error!("Failed to get DB connection while searching emergency request timed out"); } } @@ -825,6 +829,6 @@ pub async fn emergency_notification_reminder_job(pool: DbPool) { } } } else { - error!("Failed to get DB connection while searching emergency notification reminder") + error!("Failed to get DB connection while searching emergency notification reminder"); } } diff --git a/src/api/core/events.rs b/src/api/core/events.rs index d1612255..b6e2bacd 100644 --- a/src/api/core/events.rs +++ b/src/api/core/events.rs @@ -1,18 +1,18 @@ use std::net::IpAddr; use chrono::NaiveDateTime; -use rocket::{form::FromForm, serde::json::Json, Route}; +use rocket::{Route, form::FromForm, serde::json::Json}; use serde_json::Value; use crate::{ + CONFIG, api::{EmptyResult, JsonResult}, auth::{AdminHeaders, Headers}, db::{ - models::{Cipher, CipherId, Event, Membership, MembershipId, OrganizationId, UserId}, DbConn, DbPool, + models::{Cipher, CipherId, Event, Membership, MembershipId, OrganizationId, UserId}, }, util::parse_date, - CONFIG, }; /// ############################################################################################################### @@ -38,9 +38,7 @@ async fn get_org_events(org_id: OrganizationId, data: EventRange, headers: Admin // Return an empty vec when we org events are disabled. // This prevents client errors - let events_json: Vec = if !CONFIG.org_events_enabled() { - Vec::with_capacity(0) - } else { + let events_json: Vec = if CONFIG.org_events_enabled() { let start_date = parse_date(&data.start); let end_date = if let Some(before_date) = &data.continuation_token { parse_date(before_date) @@ -51,8 +49,10 @@ async fn get_org_events(org_id: OrganizationId, data: EventRange, headers: Admin Event::find_by_organization_uuid(&org_id, &start_date, &end_date, &conn) .await .iter() - .map(|e| e.to_json()) + .map(Event::to_json) .collect() + } else { + Vec::with_capacity(0) }; Ok(Json(json!({ @@ -64,27 +64,21 @@ async fn get_org_events(org_id: OrganizationId, data: EventRange, headers: Admin #[get("/ciphers//events?")] async fn get_cipher_events(cipher_id: CipherId, data: EventRange, headers: Headers, conn: DbConn) -> JsonResult { - // Return an empty vec when we org events are disabled. + // Return an empty vec when org events are disabled. // This prevents client errors - let events_json: Vec = if !CONFIG.org_events_enabled() { - Vec::with_capacity(0) - } else { - let mut events_json = Vec::with_capacity(0); - if Membership::user_has_ge_admin_access_to_cipher(&headers.user.uuid, &cipher_id, &conn).await { - let start_date = parse_date(&data.start); - let end_date = if let Some(before_date) = &data.continuation_token { - parse_date(before_date) - } else { - parse_date(&data.end) - }; + let events_json: Vec = if CONFIG.org_events_enabled() + && Membership::user_has_ge_admin_access_to_cipher(&headers.user.uuid, &cipher_id, &conn).await + { + let start_date = parse_date(&data.start); + let end_date = if let Some(before_date) = &data.continuation_token { + parse_date(before_date) + } else { + parse_date(&data.end) + }; - events_json = Event::find_by_cipher_uuid(&cipher_id, &start_date, &end_date, &conn) - .await - .iter() - .map(|e| e.to_json()) - .collect() - } - events_json + Event::find_by_cipher_uuid(&cipher_id, &start_date, &end_date, &conn).await.iter().map(Event::to_json).collect() + } else { + Vec::with_capacity(0) }; Ok(Json(json!({ @@ -107,9 +101,7 @@ async fn get_user_events( } // Return an empty vec when we org events are disabled. // This prevents client errors - let events_json: Vec = if !CONFIG.org_events_enabled() { - Vec::with_capacity(0) - } else { + let events_json: Vec = if CONFIG.org_events_enabled() { let start_date = parse_date(&data.start); let end_date = if let Some(before_date) = &data.continuation_token { parse_date(before_date) @@ -120,8 +112,10 @@ async fn get_user_events( Event::find_by_org_and_member(&org_id, &member_id, &start_date, &end_date, &conn) .await .iter() - .map(|e| e.to_json()) + .map(Event::to_json) .collect() + } else { + Vec::with_capacity(0) }; Ok(Json(json!({ @@ -134,7 +128,8 @@ async fn get_user_events( fn get_continuation_token(events_json: &[Value]) -> Option<&str> { // When the length of the vec equals the max page_size there probably is more data // When it is less, then all events are loaded. - if events_json.len() as i64 == Event::PAGE_SIZE { + #[expect(clippy::cast_possible_truncation, reason = "PAGE_SIZE fits within usize")] + if events_json.len() == Event::PAGE_SIZE as usize { if let Some(last_event) = events_json.last() { last_event["date"].as_str() } else { @@ -176,7 +171,7 @@ async fn post_events_collect(data: Json>, headers: Headers, let event_date = parse_date(&event.date); match event.r#type { 1000..=1099 => { - _log_user_event( + log_user_event_impl( event.r#type, &headers.user.uuid, headers.device.atype, @@ -188,7 +183,7 @@ async fn post_events_collect(data: Json>, headers: Headers, } 1600..=1699 => { if let Some(org_id) = &event.organization_id { - _log_event( + log_event_impl( event.r#type, org_id, org_id, @@ -202,22 +197,21 @@ async fn post_events_collect(data: Json>, headers: Headers, } } _ => { - if let Some(cipher_uuid) = &event.cipher_id { - if let Some(cipher) = Cipher::find_by_uuid(cipher_uuid, &conn).await { - if let Some(org_id) = cipher.organization_uuid { - _log_event( - event.r#type, - cipher_uuid, - &org_id, - &headers.user.uuid, - headers.device.atype, - Some(event_date), - &headers.ip.ip, - &conn, - ) - .await; - } - } + if let Some(cipher_uuid) = &event.cipher_id + && let Some(cipher) = Cipher::find_by_uuid(cipher_uuid, &conn).await + && let Some(org_id) = cipher.organization_uuid + { + log_event_impl( + event.r#type, + cipher_uuid, + &org_id, + &headers.user.uuid, + headers.device.atype, + Some(event_date), + &headers.ip.ip, + &conn, + ) + .await; } } } @@ -229,10 +223,10 @@ pub async fn log_user_event(event_type: i32, user_id: &UserId, device_type: i32, if !CONFIG.org_events_enabled() { return; } - _log_user_event(event_type, user_id, device_type, None, ip, conn).await; + log_user_event_impl(event_type, user_id, device_type, None, ip, conn).await; } -async fn _log_user_event( +async fn log_user_event_impl( event_type: i32, user_id: &UserId, device_type: i32, @@ -278,11 +272,11 @@ pub async fn log_event( if !CONFIG.org_events_enabled() { return; } - _log_event(event_type, source_uuid, org_id, act_user_id, device_type, None, ip, conn).await; + log_event_impl(event_type, source_uuid, org_id, act_user_id, device_type, None, ip, conn).await; } -#[allow(clippy::too_many_arguments)] -async fn _log_event( +#[expect(clippy::too_many_arguments)] +async fn log_event_impl( event_type: i32, source_uuid: &str, org_id: &OrganizationId, @@ -298,24 +292,24 @@ async fn _log_event( // 1000..=1099 Are user events, they need to be logged via log_user_event() // Cipher Events 1100..=1199 => { - event.cipher_uuid = Some(source_uuid.to_string().into()); + event.cipher_uuid = Some(source_uuid.to_owned().into()); } // Collection Events 1300..=1399 => { - event.collection_uuid = Some(source_uuid.to_string().into()); + event.collection_uuid = Some(source_uuid.to_owned().into()); } // Group Events 1400..=1499 => { - event.group_uuid = Some(source_uuid.to_string().into()); + event.group_uuid = Some(source_uuid.to_owned().into()); } // Org User Events 1500..=1599 => { - event.org_user_uuid = Some(source_uuid.to_string().into()); + event.org_user_uuid = Some(source_uuid.to_owned().into()); } // 1600..=1699 Are organizational events, and they do not need the source_uuid // Policy Events 1700..=1799 => { - event.policy_uuid = Some(source_uuid.to_string().into()); + event.policy_uuid = Some(source_uuid.to_owned().into()); } // Ignore others _ => {} @@ -338,6 +332,6 @@ pub async fn event_cleanup_job(pool: DbPool) { if let Ok(conn) = pool.get().await { Event::clean_events(&conn).await.ok(); } else { - error!("Failed to get DB connection while trying to cleanup the events table") + error!("Failed to get DB connection while trying to cleanup the events table"); } } diff --git a/src/api/core/folders.rs b/src/api/core/folders.rs index 1b3fd714..8c930093 100644 --- a/src/api/core/folders.rs +++ b/src/api/core/folders.rs @@ -5,8 +5,8 @@ use crate::{ api::{EmptyResult, JsonResult, Notify, UpdateType}, auth::Headers, db::{ - models::{Folder, FolderId}, DbConn, + models::{Folder, FolderId}, }, util::deser_opt_nonempty_str, }; @@ -29,9 +29,10 @@ async fn get_folders(headers: Headers, conn: DbConn) -> Json { #[get("/folders/")] async fn get_folder(folder_id: FolderId, headers: Headers, conn: DbConn) -> JsonResult { - match Folder::find_by_uuid_and_user(&folder_id, &headers.user.uuid, &conn).await { - Some(folder) => Ok(Json(folder.to_json())), - _ => err!("Invalid folder", "Folder does not exist or belongs to another user"), + if let Some(folder) = Folder::find_by_uuid_and_user(&folder_id, &headers.user.uuid, &conn).await { + Ok(Json(folder.to_json())) + } else { + err!("Invalid folder", "Folder does not exist or belongs to another user") } } diff --git a/src/api/core/mod.rs b/src/api/core/mod.rs index ad9002fd..178d7c45 100644 --- a/src/api/core/mod.rs +++ b/src/api/core/mod.rs @@ -9,14 +9,14 @@ mod sends; pub mod two_factor; pub use accounts::purge_auth_requests; -pub use ciphers::{purge_trashed_ciphers, CipherData, CipherSyncData, CipherSyncType}; +pub use ciphers::{CipherData, CipherSyncData, CipherSyncType, purge_trashed_ciphers}; pub use emergency_access::{emergency_notification_reminder_job, emergency_request_timeout_job}; pub use events::{event_cleanup_job, log_event, log_user_event}; use reqwest::Method; pub use sends::purge_sends; pub fn routes() -> Vec { - let mut eq_domains_routes = routes![get_eq_domains, post_eq_domains, put_eq_domains]; + let mut eq_domains_routes = routes![get_settings_domains, post_settings_domains, put_settings_domains]; let mut hibp_routes = routes![hibp_breach]; let mut meta_routes = routes![alive, now, version, config, get_api_webauthn]; @@ -47,20 +47,20 @@ pub fn events_routes() -> Vec { // // Move this somewhere else // -use rocket::{serde::json::Json, serde::json::Value, Catcher, Route}; +use rocket::{Catcher, Route, serde::json::Json, serde::json::Value}; use crate::{ + CONFIG, api::{EmptyResult, JsonResult, Notify, UpdateType}, auth::Headers, db::{ - models::{Membership, MembershipStatus, OrgPolicy, Organization, User}, DbConn, + models::{Membership, MembershipStatus, OrgPolicy, Organization, User}, }, error::Error, http_client::make_http_request, mail, - util::{parse_experimental_client_feature_flags, FeatureFlagFilter}, - CONFIG, + util::{FeatureFlagFilter, parse_experimental_client_feature_flags}, }; #[derive(Debug, Serialize, Deserialize)] @@ -73,15 +73,17 @@ struct GlobalDomain { const GLOBAL_DOMAINS: &str = include_str!("../../static/global_domains.json"); +#[expect(clippy::needless_pass_by_value, reason = "Not beneficial for Headers")] #[get("/settings/domains")] -fn get_eq_domains(headers: Headers) -> Json { - _get_eq_domains(&headers, false) +fn get_settings_domains(headers: Headers) -> Json { + get_eq_domains(&headers, false) } -fn _get_eq_domains(headers: &Headers, no_excluded: bool) -> Json { - let user = &headers.user; +fn get_eq_domains(headers: &Headers, no_excluded: bool) -> Json { use serde_json::from_str; + let user = &headers.user; + let equivalent_domains: Vec> = from_str(&user.equivalent_domains).unwrap(); let excluded_globals: Vec = from_str(&user.excluded_globals).unwrap(); @@ -110,17 +112,23 @@ struct EquivDomainData { } #[post("/settings/domains", data = "")] -async fn post_eq_domains(data: Json, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult { +async fn post_settings_domains( + data: Json, + headers: Headers, + conn: DbConn, + nt: Notify<'_>, +) -> JsonResult { + use serde_json::to_string; + let data: EquivDomainData = data.into_inner(); let excluded_globals = data.excluded_global_equivalent_domains.unwrap_or_default(); let equivalent_domains = data.equivalent_domains.unwrap_or_default(); let mut user = headers.user; - use serde_json::to_string; - user.excluded_globals = to_string(&excluded_globals).unwrap_or_else(|_| "[]".to_string()); - user.equivalent_domains = to_string(&equivalent_domains).unwrap_or_else(|_| "[]".to_string()); + user.excluded_globals = to_string(&excluded_globals).unwrap_or_else(|_| "[]".to_owned()); + user.equivalent_domains = to_string(&equivalent_domains).unwrap_or_else(|_| "[]".to_owned()); user.save(&conn).await?; @@ -130,8 +138,13 @@ async fn post_eq_domains(data: Json, headers: Headers, conn: Db } #[put("/settings/domains", data = "")] -async fn put_eq_domains(data: Json, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult { - post_eq_domains(data, headers, conn, nt).await +async fn put_settings_domains( + data: Json, + headers: Headers, + conn: DbConn, + nt: Notify<'_>, +) -> JsonResult { + post_settings_domains(data, headers, conn, nt).await } #[get("/hibp/breach?")] @@ -206,9 +219,9 @@ fn config() -> Json { // iOS (v2026.2.1): https://github.com/bitwarden/ios/blob/cdd9ba1770ca2ffc098d02d12cc3208e3a830454/BitwardenShared/Core/Platform/Models/Enum/FeatureFlag.swift#L7 let mut feature_states = parse_experimental_client_feature_flags( &CONFIG.experimental_client_feature_flags(), - FeatureFlagFilter::ValidOnly, + &FeatureFlagFilter::ValidOnly, ); - feature_states.insert("pm-19148-innovation-archive".to_string(), true); + feature_states.insert("pm-19148-innovation-archive".to_owned(), true); Json(json!({ // Note: The clients use this version to handle backwards compatibility concerns @@ -278,9 +291,8 @@ async fn accept_org_invite( member.save(conn).await?; if CONFIG.mail_enabled() { - let org = match Organization::find_by_uuid(&member.org_uuid, conn).await { - Some(org) => org, - None => err!("Organization not found."), + let Some(org) = Organization::find_by_uuid(&member.org_uuid, conn).await else { + err!("Organization not found.") }; // User was invited to an organization, so they must be confirmed manually after acceptance mail::send_invite_accepted(&user.email, &member.invited_by_email.unwrap_or(org.billing_email), &org.name) diff --git a/src/api/core/organizations.rs b/src/api/core/organizations.rs index 3e6eb767..d801af31 100644 --- a/src/api/core/organizations.rs +++ b/src/api/core/organizations.rs @@ -1,28 +1,28 @@ use num_traits::FromPrimitive; -use rocket::serde::json::Json; use rocket::Route; +use rocket::serde::json::Json; use serde_json::Value; use std::collections::{HashMap, HashSet}; use crate::api::admin::FAKE_ADMIN_UUID; use crate::{ + CONFIG, api::{ - core::{accept_org_invite, log_event, two_factor, CipherSyncData, CipherSyncType}, EmptyResult, JsonResult, Notify, PasswordOrOtpData, UpdateType, + core::{CipherSyncData, CipherSyncType, accept_org_invite, log_event, two_factor}, }, - auth::{decode_invite, AdminHeaders, Headers, ManagerHeaders, ManagerHeadersLoose, OrgMemberHeaders, OwnerHeaders}, + auth::{AdminHeaders, Headers, ManagerHeaders, ManagerHeadersLoose, OrgMemberHeaders, OwnerHeaders, decode_invite}, db::{ + DbConn, models::{ Cipher, CipherId, Collection, CollectionCipher, CollectionGroup, CollectionId, CollectionUser, EventType, Group, GroupId, GroupUser, Invitation, Membership, MembershipId, MembershipStatus, MembershipType, OrgPolicy, OrgPolicyType, Organization, OrganizationApiKey, OrganizationId, User, UserId, }, - DbConn, }, mail, sso::FAKE_SSO_IDENTIFIER, - util::{convert_json_key_lcase_first, NumberOrString}, - CONFIG, + util::{NumberOrString, convert_json_key_lcase_first}, }; pub fn routes() -> Vec { @@ -97,7 +97,7 @@ pub fn routes() -> Vec { get_reset_password_details, put_reset_password, get_org_export, - api_key, + post_api_key, rotate_api_key, get_billing_metadata, get_billing_warnings, @@ -286,9 +286,10 @@ async fn get_organization(org_id: OrganizationId, headers: OwnerHeaders, conn: D if org_id != headers.org_id { err!("Organization not found", "Organization id's do not match"); } - match Organization::find_by_uuid(&org_id, &conn).await { - Some(organization) => Ok(Json(organization.to_json())), - None => err!("Can't find organization details"), + if let Some(organization) = Organization::find_by_uuid(&org_id, &conn).await { + Ok(Json(organization.to_json())) + } else { + err!("Can't find organization details") } } @@ -367,7 +368,7 @@ async fn get_auto_enroll_status(identifier: &str, headers: Headers, conn: DbConn }; let (id, identifier, rp_auto_enroll) = match org { - None => (identifier.to_string(), identifier.to_string(), false), + None => (identifier.to_owned(), identifier.to_owned(), false), Some(org) => ( org.uuid.to_string(), org.uuid.to_string(), @@ -393,7 +394,7 @@ async fn get_org_collections(org_id: OrganizationId, headers: ManagerHeadersLoos } Ok(Json(json!({ - "data": _get_org_collections(&org_id, &conn).await, + "data": get_org_collections_impl(&org_id, &conn).await, "object": "list", "continuationToken": null, }))) @@ -465,7 +466,7 @@ async fn get_org_collections_details(org_id: OrganizationId, headers: ManagerHea CollectionGroup::find_by_collection(&col.uuid, &conn) .await .iter() - .map(|collection_group| collection_group.to_json_details_for_group()) + .map(CollectionGroup::to_json_details_for_group) .collect() } else { Vec::with_capacity(0) @@ -477,7 +478,7 @@ async fn get_org_collections_details(org_id: OrganizationId, headers: ManagerHea json_object["groups"] = json!(groups); json_object["object"] = json!("collectionAccessDetails"); json_object["unmanaged"] = json!(false); - data.push(json_object) + data.push(json_object); } Ok(Json(json!({ @@ -487,7 +488,7 @@ async fn get_org_collections_details(org_id: OrganizationId, headers: ManagerHea }))) } -async fn _get_org_collections(org_id: &OrganizationId, conn: &DbConn) -> Value { +async fn get_org_collections_impl(org_id: &OrganizationId, conn: &DbConn) -> Value { Collection::find_by_organization(org_id, conn).await.iter().map(Collection::to_json).collect::() } @@ -573,7 +574,7 @@ async fn post_bulk_access_collections( if Organization::find_by_uuid(&org_id, &conn).await.is_none() { err!("Can't find organization details") - }; + } for col_id in data.collection_ids { let Some(collection) = Collection::find_by_uuid_and_org(&col_id, &org_id, &conn).await else { @@ -650,7 +651,7 @@ async fn post_organization_collection_update( if Organization::find_by_uuid(&org_id, &conn).await.is_none() { err!("Can't find organization details") - }; + } let Some(mut collection) = Collection::find_by_uuid_and_org(&col_id, &org_id, &conn).await else { err!("Collection not found") @@ -701,7 +702,7 @@ async fn post_organization_collection_update( Ok(Json(collection.to_json_details(&headers.user.uuid, None, &conn).await)) } -async fn _delete_organization_collection( +async fn delete_organization_collection_impl( org_id: &OrganizationId, col_id: &CollectionId, headers: &ManagerHeaders, @@ -733,7 +734,7 @@ async fn delete_organization_collection( headers: ManagerHeaders, conn: DbConn, ) -> EmptyResult { - _delete_organization_collection(&org_id, &col_id, &headers, &conn).await + delete_organization_collection_impl(&org_id, &col_id, &headers, &conn).await } #[post("/organizations//collections//delete")] @@ -743,7 +744,7 @@ async fn post_organization_collection_delete( headers: ManagerHeaders, conn: DbConn, ) -> EmptyResult { - _delete_organization_collection(&org_id, &col_id, &headers, &conn).await + delete_organization_collection_impl(&org_id, &col_id, &headers, &conn).await } #[derive(Deserialize, Debug)] @@ -769,7 +770,7 @@ async fn bulk_delete_organization_collections( let headers = ManagerHeaders::from_loose(headers, &collections, &conn).await?; for col_id in collections { - _delete_organization_collection(&org_id, &col_id, &headers, &conn).await? + delete_organization_collection_impl(&org_id, &col_id, &headers, &conn).await?; } Ok(()) } @@ -799,7 +800,7 @@ async fn get_org_collection_detail( CollectionGroup::find_by_collection(&collection.uuid, &conn) .await .iter() - .map(|collection_group| collection_group.to_json_details_for_group()) + .map(CollectionGroup::to_json_details_for_group) .collect() } else { // The Bitwarden clients seem to call this API regardless of whether groups are enabled, @@ -886,13 +887,13 @@ async fn get_org_details(data: OrgIdData, headers: ManagerHeadersLoose, conn: Db } Ok(Json(json!({ - "data": _get_org_details(&data.organization_id, &headers.host, &headers.user.uuid, &conn).await?, + "data": get_org_details_impl(&data.organization_id, &headers.host, &headers.user.uuid, &conn).await?, "object": "list", "continuationToken": null, }))) } -async fn _get_org_details( +async fn get_org_details_impl( org_id: &OrganizationId, host: &str, user_id: &UserId, @@ -975,14 +976,13 @@ async fn post_org_keys( } let data: OrgKeyData = data.into_inner(); - let mut org = match Organization::find_by_uuid(&org_id, &conn).await { - Some(organization) => { - if organization.private_key.is_some() && organization.public_key.is_some() { - err!("Organization Keys already exist") - } - organization + let mut org = if let Some(organization) = Organization::find_by_uuid(&org_id, &conn).await { + if organization.private_key.is_some() && organization.public_key.is_some() { + err!("Organization Keys already exist") } - None => err!("Can't find organization details"), + organization + } else { + err!("Can't find organization details") }; org.private_key = Some(data.encrypted_private_key); @@ -1043,9 +1043,10 @@ async fn send_invite( // The from_str() will convert the custom role type into a manager role type let raw_type = &data.r#type.into_string(); // Membership::from_str will convert custom (4) to manager (3) - let new_type = match MembershipType::from_str(raw_type) { - Some(new_type) => new_type as i32, - None => err!("Invalid type"), + let new_type = if let Some(new_type) = MembershipType::from_str(raw_type) { + new_type as i32 + } else { + err!("Invalid type") }; if new_type != MembershipType::User && headers.membership_type != MembershipType::Owner { @@ -1062,7 +1063,7 @@ async fn send_invite( && data.permissions.get("createNewCollections") == Some(&json!(true))); let mut user_created: bool = false; - for email in data.emails.iter() { + for email in &data.emails { let mut member_status = MembershipStatus::Invited as i32; let user = match User::find_by_mail(email, &conn).await { None => { @@ -1086,13 +1087,13 @@ async fn send_invite( Some(user) => { if Membership::find_by_user_and_org(&user.uuid, &org_id, &conn).await.is_some() { err!(format!("User already in organization: {email}")) - } else { - // automatically accept existing users if mail is disabled - if !CONFIG.mail_enabled() && !user.password_hash.is_empty() { - member_status = MembershipStatus::Accepted as i32; - } - user } + + // automatically accept existing users if mail is disabled + if !CONFIG.mail_enabled() && !user.password_hash.is_empty() { + member_status = MembershipStatus::Accepted as i32; + } + user } }; @@ -1103,9 +1104,10 @@ async fn send_invite( new_member.save(&conn).await?; if CONFIG.mail_enabled() { - let org_name = match Organization::find_by_uuid(&org_id, &conn).await { - Some(org) => org.name, - None => err!("Error looking up organization"), + let org_name = if let Some(org) = Organization::find_by_uuid(&org_id, &conn).await { + org.name + } else { + err!("Error looking up organization") }; if let Err(e) = mail::send_invite( @@ -1159,7 +1161,7 @@ async fn send_invite( } } - for group_id in data.groups.iter() { + for group_id in &data.groups { let mut group_entry = GroupUser::new(group_id.clone(), new_member.uuid.clone()); group_entry.save(&conn).await?; } @@ -1182,8 +1184,8 @@ async fn bulk_reinvite_members( let mut bulk_response = Vec::new(); for member_id in data.ids { - let err_msg = match _reinvite_member(&org_id, &member_id, &headers.user.email, &conn).await { - Ok(_) => String::new(), + let err_msg = match reinvite_member_impl(&org_id, &member_id, &headers.user.email, &conn).await { + Ok(()) => String::new(), Err(e) => format!("{e:?}"), }; @@ -1193,7 +1195,7 @@ async fn bulk_reinvite_members( "id": member_id, "error": err_msg } - )) + )); } Ok(Json(json!({ @@ -1213,10 +1215,10 @@ async fn reinvite_member( if org_id != headers.org_id { err!("Organization not found", "Organization id's do not match"); } - _reinvite_member(&org_id, &member_id, &headers.user.email, &conn).await + reinvite_member_impl(&org_id, &member_id, &headers.user.email, &conn).await } -async fn _reinvite_member( +async fn reinvite_member_impl( org_id: &OrganizationId, member_id: &MembershipId, invited_by_email: &str, @@ -1238,13 +1240,14 @@ async fn _reinvite_member( err!("Invitations are not allowed.") } - let org_name = match Organization::find_by_uuid(org_id, conn).await { - Some(org) => org.name, - None => err!("Error looking up organization."), + let org_name = if let Some(org) = Organization::find_by_uuid(org_id, conn).await { + org.name + } else { + err!("Error looking up organization.") }; if CONFIG.mail_enabled() { - mail::send_invite(&user, org_id.clone(), member.uuid, &org_name, Some(invited_by_email.to_string())).await?; + mail::send_invite(&user, org_id.clone(), member.uuid, &org_name, Some(invited_by_email.to_owned())).await?; } else if user.password_hash.is_empty() { let invitation = Invitation::new(&user.email); invitation.save(conn).await?; @@ -1352,8 +1355,8 @@ async fn bulk_confirm_invite( for invite in keys { let member_id = invite.id.unwrap(); let user_key = invite.key.unwrap_or_default(); - let err_msg = match _confirm_invite(&org_id, &member_id, &user_key, &headers, &conn, &nt).await { - Ok(_) => String::new(), + let err_msg = match confirm_invite_impl(&org_id, &member_id, &user_key, &headers, &conn, &nt).await { + Ok(()) => String::new(), Err(e) => format!("{e:?}"), }; @@ -1387,10 +1390,10 @@ async fn confirm_invite( ) -> EmptyResult { let data = data.into_inner(); let user_key = data.key.unwrap_or_default(); - _confirm_invite(&org_id, &member_id, &user_key, &headers, &conn, &nt).await + confirm_invite_impl(&org_id, &member_id, &user_key, &headers, &conn, &nt).await } -async fn _confirm_invite( +async fn confirm_invite_impl( org_id: &OrganizationId, member_id: &MembershipId, key: &str, @@ -1418,7 +1421,7 @@ async fn _confirm_invite( } member_to_confirm.status = MembershipStatus::Confirmed as i32; - member_to_confirm.akey = key.to_string(); + member_to_confirm.akey = key.to_owned(); // This check is also done at accept_invite, _confirm_invite, _activate_member, edit_member, admin::update_membership_type OrgPolicy::check_user_allowed(&member_to_confirm, "confirm", conn).await?; @@ -1435,13 +1438,15 @@ async fn _confirm_invite( .await; if CONFIG.mail_enabled() { - let org_name = match Organization::find_by_uuid(org_id, conn).await { - Some(org) => org.name, - None => err!("Error looking up organization."), + let org_name = if let Some(org) = Organization::find_by_uuid(org_id, conn).await { + org.name + } else { + err!("Error looking up organization.") }; - let address = match User::find_by_uuid(&member_to_confirm.user_uuid, conn).await { - Some(user) => user.email, - None => err!("Error looking up user."), + let address = if let Some(user) = User::find_by_uuid(&member_to_confirm.user_uuid, conn).await { + user.email + } else { + err!("Error looking up user.") }; mail::send_invite_confirmed(&address, &org_name).await?; } @@ -1637,8 +1642,8 @@ async fn bulk_delete_member( let mut bulk_response = Vec::new(); for member_id in data.ids { - let err_msg = match _delete_member(&org_id, &member_id, &headers, &conn, &nt).await { - Ok(_) => String::new(), + let err_msg = match delete_member_impl(&org_id, &member_id, &headers, &conn, &nt).await { + Ok(()) => String::new(), Err(e) => format!("{e:?}"), }; @@ -1648,7 +1653,7 @@ async fn bulk_delete_member( "id": member_id, "error": err_msg } - )) + )); } Ok(Json(json!({ @@ -1666,10 +1671,10 @@ async fn delete_member( conn: DbConn, nt: Notify<'_>, ) -> EmptyResult { - _delete_member(&org_id, &member_id, &headers, &conn, &nt).await + delete_member_impl(&org_id, &member_id, &headers, &conn, &nt).await } -async fn _delete_member( +async fn delete_member_impl( org_id: &OrganizationId, member_id: &MembershipId, headers: &AdminHeaders, @@ -1753,8 +1758,8 @@ async fn bulk_public_keys( }))) } -use super::ciphers::update_cipher_from_data; use super::ciphers::CipherData; +use super::ciphers::update_cipher_from_data; #[derive(Deserialize)] #[serde(rename_all = "camelCase")] @@ -1902,24 +1907,24 @@ async fn post_bulk_collections(data: Json, headers: Headers } } - for cipher_id in data.cipher_ids.iter() { + for cipher_id in &data.cipher_ids { // Only act on existing cipher uuid's // Do not abort the operation just ignore it, it could be a cipher was just deleted for example - if let Some(cipher) = Cipher::find_by_uuid_and_org(cipher_id, &data.organization_id, &conn).await { - if cipher.is_write_accessible_to_user(&headers.user.uuid, &conn).await { - // When selecting a specific collection from the left filter list, and use the bulk option, you can remove an item from that collection - // In these cases the client will call this endpoint twice, once for adding the new collections and a second for deleting. - if data.remove_collections { - for collection in &data.collection_ids { - CollectionCipher::delete(&cipher.uuid, collection, &conn).await?; - } - } else { - for collection in &data.collection_ids { - CollectionCipher::save(&cipher.uuid, collection, &conn).await?; - } + if let Some(cipher) = Cipher::find_by_uuid_and_org(cipher_id, &data.organization_id, &conn).await + && cipher.is_write_accessible_to_user(&headers.user.uuid, &conn).await + { + // When selecting a specific collection from the left filter list, and use the bulk option, you can remove an item from that collection + // In these cases the client will call this endpoint twice, once for adding the new collections and a second for deleting. + if data.remove_collections { + for collection in &data.collection_ids { + CollectionCipher::delete(&cipher.uuid, collection, &conn).await?; + } + } else { + for collection in &data.collection_ids { + CollectionCipher::save(&cipher.uuid, collection, &conn).await?; } } - }; + } } Ok(()) @@ -1969,7 +1974,7 @@ async fn list_policies_token(org_id: OrganizationId, token: &str, conn: DbConn) fn get_dummy_master_password_policy() -> JsonResult { let (enabled, data) = match CONFIG.sso_master_password_policy_value() { Some(policy) if CONFIG.sso_enabled() => (true, policy.to_string()), - _ => (false, "null".to_string()), + _ => (false, "null".to_owned()), }; let policy = OrgPolicy::new(FAKE_SSO_IDENTIFIER.into(), OrgPolicyType::MasterPassword, enabled, data); Ok(Json(policy.to_json())) @@ -1982,7 +1987,7 @@ async fn get_master_password_policy(org_id: OrganizationId, _headers: OrgMemberH OrgPolicy::find_by_org_and_type(&org_id, OrgPolicyType::MasterPassword, &conn).await.unwrap_or_else(|| { let (enabled, data) = match CONFIG.sso_master_password_policy_value() { Some(policy) if CONFIG.sso_enabled() => (true, policy.to_string()), - _ => (false, "null".to_string()), + _ => (false, "null".to_owned()), }; OrgPolicy::new(org_id, OrgPolicyType::MasterPassword, enabled, data) @@ -2003,7 +2008,7 @@ async fn get_policy(org_id: OrganizationId, pol_type: i32, headers: AdminHeaders let policy = match OrgPolicy::find_by_org_and_type(&org_id, pol_type_enum, &conn).await { Some(p) => p, - None => OrgPolicy::new(org_id.clone(), pol_type_enum, false, "null".to_string()), + None => OrgPolicy::new(org_id.clone(), pol_type_enum, false, "null".to_owned()), }; Ok(Json(policy.to_json())) @@ -2078,7 +2083,7 @@ async fn put_policy( // When enabling the SingleOrg policy, remove this org's members that are members of other orgs if pol_type_enum == OrgPolicyType::SingleOrg && data.enabled { - for mut member in Membership::find_by_org(&org_id, &conn).await.into_iter() { + for mut member in Membership::find_by_org(&org_id, &conn).await { // Policy only applies to non-Owner/non-Admin members who have accepted joining the org // Exclude invited and revoked users when checking for this policy. // Those users will not be allowed to accept or be activated because of the policy checks done there. @@ -2113,7 +2118,7 @@ async fn put_policy( let mut policy = match OrgPolicy::find_by_org_and_type(&org_id, pol_type_enum, &conn).await { Some(p) => p, - None => OrgPolicy::new(org_id.clone(), pol_type_enum, false, "{}".to_string()), + None => OrgPolicy::new(org_id.clone(), pol_type_enum, false, "{}".to_owned()), }; policy.enabled = data.enabled; @@ -2187,7 +2192,7 @@ fn get_plans() -> Json { #[get("/organizations/<_org_id>/billing/metadata")] fn get_billing_metadata(_org_id: OrganizationId, _headers: OrgMemberHeaders) -> Json { // Prevent a 404 error, which also causes Javascript errors. - Json(_empty_data_json()) + Json(empty_data_json()) } #[get("/organizations/<_org_id>/billing/vnext/warnings")] @@ -2209,7 +2214,7 @@ fn get_self_host_billing_metadata(_org_id: OrganizationId, _headers: OrgMemberHe })) } -fn _empty_data_json() -> Value { +fn empty_data_json() -> Value { json!({ "object": "list", "data": [], @@ -2230,7 +2235,7 @@ async fn revoke_member( headers: AdminHeaders, conn: DbConn, ) -> EmptyResult { - _revoke_member(&org_id, &member_id, &headers, &conn).await + revoke_member_impl(&org_id, &member_id, &headers, &conn).await } #[put("/organizations//users/revoke", data = "")] @@ -2249,8 +2254,8 @@ async fn bulk_revoke_members( match data.ids { Some(members) => { for member_id in members { - let err_msg = match _revoke_member(&org_id, &member_id, &headers, &conn).await { - Ok(_) => String::new(), + let err_msg = match revoke_member_impl(&org_id, &member_id, &headers, &conn).await { + Ok(()) => String::new(), Err(e) => format!("{e:?}"), }; @@ -2273,7 +2278,7 @@ async fn bulk_revoke_members( }))) } -async fn _revoke_member( +async fn revoke_member_impl( org_id: &OrganizationId, member_id: &MembershipId, headers: &AdminHeaders, @@ -2325,7 +2330,7 @@ async fn restore_member_vnext( ) -> EmptyResult { // Vaultwarden does not (yet) support the per User Collection linked to the `Enforce organization data ownership` policy. // Therefor we ignore the `defaultUserCollectionName` data sent and just call restore_member - _restore_member(&org_id, &member_id, &headers, &conn).await + restore_member_impl(&org_id, &member_id, &headers, &conn).await } #[put("/organizations//users//restore")] @@ -2335,7 +2340,7 @@ async fn restore_member( headers: AdminHeaders, conn: DbConn, ) -> EmptyResult { - _restore_member(&org_id, &member_id, &headers, &conn).await + restore_member_impl(&org_id, &member_id, &headers, &conn).await } #[put("/organizations//users/restore", data = "")] @@ -2352,8 +2357,8 @@ async fn bulk_restore_members( let mut bulk_response = Vec::new(); for member_id in data.ids { - let err_msg = match _restore_member(&org_id, &member_id, &headers, &conn).await { - Ok(_) => String::new(), + let err_msg = match restore_member_impl(&org_id, &member_id, &headers, &conn).await { + Ok(()) => String::new(), Err(e) => format!("{e:?}"), }; @@ -2373,7 +2378,7 @@ async fn bulk_restore_members( }))) } -async fn _restore_member( +async fn restore_member_impl( org_id: &OrganizationId, member_id: &MembershipId, headers: &AdminHeaders, @@ -2429,11 +2434,11 @@ async fn get_groups_data( if details { for g in groups { - groups_json.push(g.to_json_details(&conn).await) + groups_json.push(g.to_json_details(&conn).await); } } else { for g in groups { - groups_json.push(g.to_json()) + groups_json.push(g.to_json()); } } groups_json @@ -2672,15 +2677,15 @@ async fn post_delete_group( headers: AdminHeaders, conn: DbConn, ) -> EmptyResult { - _delete_group(&org_id, &group_id, &headers, &conn).await + delete_group_impl(&org_id, &group_id, &headers, &conn).await } #[delete("/organizations//groups/")] async fn delete_group(org_id: OrganizationId, group_id: GroupId, headers: AdminHeaders, conn: DbConn) -> EmptyResult { - _delete_group(&org_id, &group_id, &headers, &conn).await + delete_group_impl(&org_id, &group_id, &headers, &conn).await } -async fn _delete_group( +async fn delete_group_impl( org_id: &OrganizationId, group_id: &GroupId, headers: &AdminHeaders, @@ -2728,7 +2733,7 @@ async fn bulk_delete_groups( let data: BulkGroupIds = data.into_inner(); for group_id in data.ids { - _delete_group(&org_id, &group_id, &headers, &conn).await? + delete_group_impl(&org_id, &group_id, &headers, &conn).await?; } Ok(()) } @@ -2765,7 +2770,7 @@ async fn get_group_members( if Group::find_by_uuid_and_org(&group_id, &org_id, &conn).await.is_none() { err!("Group could not be found!", "Group uuid is invalid or does not belong to the organization") - }; + } let group_members: Vec = GroupUser::find_by_group(&group_id, &org_id, &conn) .await @@ -2793,7 +2798,7 @@ async fn put_group_members( if Group::find_by_uuid_and_org(&group_id, &org_id, &conn).await.is_none() { err!("Group could not be found!", "Group uuid is invalid or does not belong to the organization") - }; + } let assigned_members = data.into_inner(); @@ -3100,12 +3105,12 @@ async fn get_org_export(org_id: OrganizationId, headers: AdminHeaders, conn: DbC } Ok(Json(json!({ - "collections": convert_json_key_lcase_first(_get_org_collections(&org_id, &conn).await), - "ciphers": convert_json_key_lcase_first(_get_org_details(&org_id, &headers.host, &headers.user.uuid, &conn).await?), + "collections": convert_json_key_lcase_first(get_org_collections_impl(&org_id, &conn).await), + "ciphers": convert_json_key_lcase_first(get_org_details_impl(&org_id, &headers.host, &headers.user.uuid, &conn).await?), }))) } -async fn _api_key( +async fn api_key( org_id: &OrganizationId, data: Json, rotate: bool, @@ -3121,21 +3126,18 @@ async fn _api_key( // Validate the admin users password/otp data.validate(&user, true, &conn).await?; - let org_api_key = match OrganizationApiKey::find_by_org_uuid(org_id, &conn).await { - Some(mut org_api_key) => { - if rotate { - org_api_key.api_key = crate::crypto::generate_api_key(); - org_api_key.revision_date = chrono::Utc::now().naive_utc(); - org_api_key.save(&conn).await.expect("Error rotating organization API Key"); - } - org_api_key - } - None => { - let api_key = crate::crypto::generate_api_key(); - let new_org_api_key = OrganizationApiKey::new(org_id.clone(), api_key); - new_org_api_key.save(&conn).await.expect("Error creating organization API Key"); - new_org_api_key + let org_api_key = if let Some(mut org_api_key) = OrganizationApiKey::find_by_org_uuid(org_id, &conn).await { + if rotate { + org_api_key.api_key = crate::crypto::generate_api_key(); + org_api_key.revision_date = chrono::Utc::now().naive_utc(); + org_api_key.save(&conn).await.expect("Error rotating organization API Key"); } + org_api_key + } else { + let api_key = crate::crypto::generate_api_key(); + let new_org_api_key = OrganizationApiKey::new(org_id.clone(), api_key); + new_org_api_key.save(&conn).await.expect("Error creating organization API Key"); + new_org_api_key }; Ok(Json(json!({ @@ -3146,13 +3148,13 @@ async fn _api_key( } #[post("/organizations//api-key", data = "")] -async fn api_key( +async fn post_api_key( org_id: OrganizationId, data: Json, headers: AdminHeaders, conn: DbConn, ) -> JsonResult { - _api_key(&org_id, data, false, headers, conn).await + api_key(&org_id, data, false, headers, conn).await } #[post("/organizations//rotate-api-key", data = "")] @@ -3162,5 +3164,5 @@ async fn rotate_api_key( headers: AdminHeaders, conn: DbConn, ) -> JsonResult { - _api_key(&org_id, data, true, headers, conn).await + api_key(&org_id, data, true, headers, conn).await } diff --git a/src/api/core/public.rs b/src/api/core/public.rs index d757d953..f50b2543 100644 --- a/src/api/core/public.rs +++ b/src/api/core/public.rs @@ -1,23 +1,24 @@ use chrono::Utc; use rocket::{ + Request, Route, request::{FromRequest, Outcome}, serde::json::Json, - Request, Route, }; use std::collections::HashSet; use crate::{ + CONFIG, api::EmptyResult, auth, db::{ + DbConn, models::{ Group, GroupUser, Invitation, Membership, MembershipStatus, MembershipType, Organization, OrganizationApiKey, OrganizationId, User, }, - DbConn, }, - mail, CONFIG, + mail, }; pub fn routes() -> Vec { @@ -90,19 +91,18 @@ async fn ldap_import(data: Json, token: PublicToken, conn: DbConn } } else { // If user is not part of the organization - let user = match User::find_by_mail(&user_data.email, &conn).await { - Some(user) => user, // exists in vaultwarden - None => { - // User does not exist yet - let mut new_user = User::new(&user_data.email, None); - new_user.save(&conn).await?; - - if !CONFIG.mail_enabled() { - Invitation::new(&new_user.email).save(&conn).await?; - } - user_created = true; - new_user + let user = if let Some(user) = User::find_by_mail(&user_data.email, &conn).await { + user + } else { + // User does not exist yet + let mut new_user = User::new(&user_data.email, None); + new_user.save(&conn).await?; + + if !CONFIG.mail_enabled() { + Invitation::new(&new_user.email).save(&conn).await?; } + user_created = true; + new_user }; let member_status = if CONFIG.mail_enabled() || user.password_hash.is_empty() { MembershipStatus::Invited as i32 @@ -110,9 +110,10 @@ async fn ldap_import(data: Json, token: PublicToken, conn: DbConn MembershipStatus::Accepted as i32 // Automatically mark user as accepted if no email invites }; - let (org_name, org_email) = match Organization::find_by_uuid(&org_id, &conn).await { - Some(org) => (org.name, org.billing_email), - None => err!("Error looking up organization"), + let (org_name, org_email) = if let Some(org) = Organization::find_by_uuid(&org_id, &conn).await { + (org.name, org.billing_email) + } else { + err!("Error looking up organization") }; let mut new_member = Membership::new(user.uuid.clone(), org_id.clone(), Some(org_email.clone())); @@ -123,37 +124,33 @@ async fn ldap_import(data: Json, token: PublicToken, conn: DbConn new_member.save(&conn).await?; - if CONFIG.mail_enabled() { - if let Err(e) = + if CONFIG.mail_enabled() + && let Err(e) = mail::send_invite(&user, org_id.clone(), new_member.uuid.clone(), &org_name, Some(org_email)).await - { - // Upon error delete the user, invite and org member records when needed - if user_created { - user.delete(&conn).await?; - } else { - new_member.delete(&conn).await?; - } - - err!(format!("Error sending invite: {e:?} ")); + { + // Upon error delete the user, invite and org member records when needed + if user_created { + user.delete(&conn).await?; + } else { + new_member.delete(&conn).await?; } + + err!(format!("Error sending invite: {e:?} ")); } } } if CONFIG.org_groups_enabled() { for group_data in &data.groups { - let group_uuid = match Group::find_by_external_id_and_org(&group_data.external_id, &org_id, &conn).await { - Some(group) => group.uuid, - None => { - let mut group = Group::new( - org_id.clone(), - group_data.name.clone(), - false, - Some(group_data.external_id.clone()), - ); - group.save(&conn).await?; - group.uuid - } + let group_uuid = if let Some(group) = + Group::find_by_external_id_and_org(&group_data.external_id, &org_id, &conn).await + { + group.uuid + } else { + let mut group = + Group::new(org_id.clone(), group_data.name.clone(), false, Some(group_data.external_id.clone())); + group.save(&conn).await?; + group.uuid }; GroupUser::delete_all_by_group(&group_uuid, &org_id, &conn).await?; @@ -174,18 +171,17 @@ async fn ldap_import(data: Json, token: PublicToken, conn: DbConn // Generate a HashSet to quickly verify if a member is listed or not. let sync_members: HashSet = data.members.into_iter().map(|m| m.external_id).collect(); for member in Membership::find_by_org(&org_id, &conn).await { - if let Some(ref user_external_id) = member.external_id { - if !sync_members.contains(user_external_id) { - if member.atype == MembershipType::Owner && member.status == MembershipStatus::Confirmed as i32 { - // Removing owner, check that there is at least one other confirmed owner - if Membership::count_confirmed_by_org_and_type(&org_id, MembershipType::Owner, &conn).await <= 1 - { - warn!("Can't delete the last owner"); - continue; - } + if let Some(ref user_external_id) = member.external_id + && !sync_members.contains(user_external_id) + { + if member.atype == MembershipType::Owner && member.status == MembershipStatus::Confirmed as i32 { + // Removing owner, check that there is at least one other confirmed owner + if Membership::count_confirmed_by_org_and_type(&org_id, MembershipType::Owner, &conn).await <= 1 { + warn!("Can't delete the last owner"); + continue; } - member.delete(&conn).await?; } + member.delete(&conn).await?; } } } @@ -202,12 +198,14 @@ impl<'r> FromRequest<'r> for PublicToken { async fn from_request(request: &'r Request<'_>) -> Outcome { let headers = request.headers(); // Get access_token - let access_token: &str = match headers.get_one("Authorization") { - Some(a) => match a.rsplit("Bearer ").next() { - Some(split) => split, - None => err_handler!("No access token provided"), - }, - None => err_handler!("No access token provided"), + let access_token: &str = if let Some(a) = headers.get_one("Authorization") { + if let Some(split) = a.rsplit("Bearer ").next() { + split + } else { + err_handler!("No access token provided") + } + } else { + err_handler!("No access token provided") }; // Check JWT token is valid and get device and user from it let Ok(claims) = auth::decode_api_org(access_token) else { @@ -229,14 +227,13 @@ impl<'r> FromRequest<'r> for PublicToken { // Check if claims.sub is org_api_key.uuid // Check if claims.client_sub is org_api_key.org_uuid - let conn = match DbConn::from_request(request).await { - Outcome::Success(conn) => conn, - _ => err_handler!("Error getting DB"), + let Outcome::Success(conn) = DbConn::from_request(request).await else { + err_handler!("Error getting DB") }; let Some(org_id) = claims.client_id.strip_prefix("organization.") else { err_handler!("Malformed client_id") }; - let org_id: OrganizationId = org_id.to_string().into(); + let org_id: OrganizationId = org_id.to_owned().into(); let Some(org_api_key) = OrganizationApiKey::find_by_org_uuid(&org_id, &conn).await else { err_handler!("Invalid client_id") }; diff --git a/src/api/core/sends.rs b/src/api/core/sends.rs index 45ead810..2a7e06c1 100644 --- a/src/api/core/sends.rs +++ b/src/api/core/sends.rs @@ -10,15 +10,15 @@ use rocket::{ use serde_json::Value; use crate::{ + CONFIG, api::{ApiResult, EmptyResult, JsonResult, Notify, UpdateType}, auth::{ClientIp, Headers, Host}, config::PathType, db::{ - models::{Device, OrgPolicy, OrgPolicyType, Send, SendFileId, SendId, SendType, UserId}, DbConn, DbPool, + models::{Device, OrgPolicy, OrgPolicyType, Send, SendFileId, SendId, SendType, UserId}, }, - util::{save_temp_file, NumberOrString}, - CONFIG, + util::{NumberOrString, save_temp_file}, }; const SEND_INACCESSIBLE_MSG: &str = "Send does not exist or is no longer available"; @@ -63,7 +63,7 @@ pub async fn purge_sends(pool: DbPool) { if let Ok(conn) = pool.get().await { Send::purge(&conn).await; } else { - error!("Failed to get DB connection while purging sends") + error!("Failed to get DB connection while purging sends"); } } @@ -168,7 +168,7 @@ fn create_send(data: SendData, user_id: UserId) -> ApiResult { #[get("/sends")] async fn get_sends(headers: Headers, conn: DbConn) -> Json { let sends = Send::find_by_user(&headers.user.uuid, &conn); - let sends_json: Vec = sends.await.iter().map(|s| s.to_json()).collect(); + let sends_json: Vec = sends.await.iter().map(Send::to_json).collect(); Json(json!({ "data": sends_json, @@ -179,9 +179,10 @@ async fn get_sends(headers: Headers, conn: DbConn) -> Json { #[get("/sends/")] async fn get_send(send_id: SendId, headers: Headers, conn: DbConn) -> JsonResult { - match Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &conn).await { - Some(send) => Ok(Json(send.to_json())), - None => err!("Send not found", "Invalid send uuid or does not belong to user"), + if let Some(send) = Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &conn).await { + Ok(Json(send.to_json())) + } else { + err!("Send not found", "Invalid send uuid or does not belong to user") } } @@ -310,9 +311,10 @@ async fn post_send_file_v2(data: Json, headers: Headers, conn: DbConn) enforce_disable_hide_email_policy(&data, &headers, &conn).await?; - let file_length = match &data.file_length { - Some(m) => m.into_i64()?, - _ => err!("Invalid send length"), + let file_length = if let Some(m) = &data.file_length { + m.into_i64()? + } else { + err!("Invalid send length") }; if file_length < 0 { err!("Send size can't be negative") @@ -457,16 +459,16 @@ async fn post_access( err_code!(SEND_INACCESSIBLE_MSG, 404) }; - if let Some(max_access_count) = send.max_access_count { - if send.access_count >= max_access_count { - err_code!(SEND_INACCESSIBLE_MSG, 404); - } + if let Some(max_access_count) = send.max_access_count + && send.access_count >= max_access_count + { + err_code!(SEND_INACCESSIBLE_MSG, 404); } - if let Some(expiration) = send.expiration_date { - if Utc::now().naive_utc() >= expiration { - err_code!(SEND_INACCESSIBLE_MSG, 404) - } + if let Some(expiration) = send.expiration_date + && Utc::now().naive_utc() >= expiration + { + err_code!(SEND_INACCESSIBLE_MSG, 404) } if Utc::now().naive_utc() >= send.deletion_date { @@ -517,16 +519,16 @@ async fn post_access_file( err_code!(SEND_INACCESSIBLE_MSG, 404) }; - if let Some(max_access_count) = send.max_access_count { - if send.access_count >= max_access_count { - err_code!(SEND_INACCESSIBLE_MSG, 404) - } + if let Some(max_access_count) = send.max_access_count + && send.access_count >= max_access_count + { + err_code!(SEND_INACCESSIBLE_MSG, 404) } - if let Some(expiration) = send.expiration_date { - if Utc::now().naive_utc() >= expiration { - err_code!(SEND_INACCESSIBLE_MSG, 404) - } + if let Some(expiration) = send.expiration_date + && Utc::now().naive_utc() >= expiration + { + err_code!(SEND_INACCESSIBLE_MSG, 404) } if Utc::now().naive_utc() >= send.deletion_date { @@ -572,7 +574,7 @@ async fn download_url(host: &Host, send_id: &SendId, file_id: &SendFileId) -> Re let token_claims = crate::auth::generate_send_claims(send_id, file_id); let token = crate::auth::encode_jwt(&token_claims); - Ok(format!("{}/api/sends/{send_id}/{file_id}?t={token}", &host.host)) + Ok(format!("{}/api/sends/{send_id}/{file_id}?t={token}", host.host)) } else { Ok(operator.presign_read(&format!("{send_id}/{file_id}"), Duration::from_mins(5)).await?.uri().to_string()) } @@ -580,10 +582,10 @@ async fn download_url(host: &Host, send_id: &SendId, file_id: &SendFileId) -> Re #[get("/sends//?")] async fn download_send(send_id: SendId, file_id: SendFileId, t: &str) -> Option { - if let Ok(claims) = crate::auth::decode_send(t) { - if claims.sub == format!("{send_id}/{file_id}") { - return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(send_id).join(file_id)).await.ok(); - } + if let Ok(claims) = crate::auth::decode_send(t) + && claims.sub == format!("{send_id}/{file_id}") + { + return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(send_id).join(file_id)).await.ok(); } None } diff --git a/src/api/core/two_factor/authenticator.rs b/src/api/core/two_factor/authenticator.rs index 4759aa3c..44cd7427 100644 --- a/src/api/core/two_factor/authenticator.rs +++ b/src/api/core/two_factor/authenticator.rs @@ -1,14 +1,14 @@ use data_encoding::BASE32; -use rocket::serde::json::Json; use rocket::Route; +use rocket::serde::json::Json; use crate::{ - api::{core::log_user_event, core::two_factor::_generate_recover_code, EmptyResult, JsonResult, PasswordOrOtpData}, + api::{EmptyResult, JsonResult, PasswordOrOtpData, core::log_user_event, core::two_factor::generate_recover_code}, auth::{ClientIp, Headers}, crypto, db::{ - models::{EventType, TwoFactor, TwoFactorType, UserId}, DbConn, + models::{EventType, TwoFactor, TwoFactorType, UserId}, }, util::NumberOrString, }; @@ -70,9 +70,10 @@ async fn activate_authenticator(data: Json, headers: He .await?; // Validate key as base32 and 20 bytes length - let decoded_key: Vec = match BASE32.decode(key.as_bytes()) { - Ok(decoded) => decoded, - _ => err!("Invalid totp secret"), + let decoded_key: Vec = if let Ok(decoded) = BASE32.decode(key.as_bytes()) { + decoded + } else { + err!("Invalid totp secret") }; if decoded_key.len() != 20 { @@ -82,7 +83,7 @@ async fn activate_authenticator(data: Json, headers: He // Validate the token provided with the key, and save new twofactor validate_totp_code(&user.uuid, &token, &key.to_uppercase(), &headers.ip, &conn).await?; - _generate_recover_code(&mut user, &conn).await; + generate_recover_code(&mut user, &conn).await; log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await; @@ -119,7 +120,7 @@ pub async fn validate_totp_code( ip: &ClientIp, conn: &DbConn, ) -> EmptyResult { - use totp_lite::{totp_custom, Sha1}; + use totp_lite::{Sha1, totp_custom}; let Ok(decoded_secret) = BASE32.decode(secret.as_bytes()) else { err!("Invalid TOTP secret") @@ -128,7 +129,7 @@ pub async fn validate_totp_code( let mut twofactor = match TwoFactor::find_by_user_and_type(user_id, TwoFactorType::Authenticator as i32, conn).await { Some(tf) => tf, - _ => TwoFactor::new(user_id.clone(), TwoFactorType::Authenticator, secret.to_string()), + _ => TwoFactor::new(user_id.clone(), TwoFactorType::Authenticator, secret.to_owned()), }; // The amount of steps back and forward in time @@ -145,7 +146,7 @@ pub async fn validate_totp_code( // We need to calculate the time offsite and cast it as an u64. // Since we only have times into the future and the totp generator needs an u64 instead of the default i64. - let time = (current_timestamp + step * 30i64) as u64; + let time: u64 = (current_timestamp + step * 30i64).cast_unsigned(); let generated = totp_custom::(30, 6, &decoded_secret, time); // Check the given code equals the generated and if the time_step is larger then the one last used. diff --git a/src/api/core/two_factor/duo.rs b/src/api/core/two_factor/duo.rs index f2de50c3..512de9c1 100644 --- a/src/api/core/two_factor/duo.rs +++ b/src/api/core/two_factor/duo.rs @@ -1,22 +1,22 @@ use chrono::Utc; use data_encoding::BASE64; -use rocket::serde::json::Json; use rocket::Route; +use rocket::serde::json::Json; use crate::{ + CONFIG, api::{ - core::log_user_event, core::two_factor::_generate_recover_code, ApiResult, EmptyResult, JsonResult, - PasswordOrOtpData, + ApiResult, EmptyResult, JsonResult, PasswordOrOtpData, core::log_user_event, + core::two_factor::generate_recover_code, }, auth::Headers, crypto, db::{ - models::{EventType, TwoFactor, TwoFactorType, User, UserId}, DbConn, + models::{EventType, TwoFactor, TwoFactorType, User, UserId}, }, error::MapResult, http_client::make_http_request, - CONFIG, }; pub fn routes() -> Vec { @@ -82,8 +82,7 @@ enum DuoStatus { impl DuoStatus { fn data(self) -> Option { match self { - DuoStatus::Global(data) => Some(data), - DuoStatus::User(data) => Some(data), + DuoStatus::Global(data) | DuoStatus::User(data) => Some(data), DuoStatus::Disabled(_) => None, } } @@ -182,7 +181,7 @@ async fn activate_duo(data: Json, headers: Headers, conn: DbConn) let twofactor = TwoFactor::new(user.uuid.clone(), type_, data_str); twofactor.save(&conn).await?; - _generate_recover_code(&mut user, &conn).await; + generate_recover_code(&mut user, &conn).await; log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await; @@ -201,14 +200,14 @@ async fn activate_duo_put(data: Json, headers: Headers, conn: DbC } async fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) -> EmptyResult { - use reqwest::{header, Method}; + use reqwest::{Method, header}; use std::str::FromStr; // https://duo.com/docs/authapi#api-details - let url = format!("https://{}{path}", &data.host); - let date = Utc::now().to_rfc2822(); + let url = format!("https://{}{path}", data.host); + let dt = Utc::now().to_rfc2822(); let username = &data.ik; - let fields = [&date, method, &data.host, path, params]; + let fields = [&dt, method, &data.host, path, params]; let password = crypto::hmac_sign(&data.sk, &fields.join("\n")); let m = Method::from_str(method).unwrap_or_default(); @@ -216,7 +215,7 @@ async fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) make_http_request(m, &url)? .basic_auth(username, Some(password)) .header(header::USER_AGENT, "vaultwarden:Duo/1.0 (Rust)") - .header(header::DATE, date) + .header(header::DATE, dt) .send() .await? .error_for_status()?; @@ -356,9 +355,10 @@ fn parse_duo_values(key: &str, val: &str, ikey: &str, prefix: &str, time: i64) - err!("Invalid ikey") } - let expire: i64 = match expire.parse() { - Ok(e) => e, - Err(_) => err!("Invalid expire time"), + let expire: i64 = if let Ok(e) = expire.parse() { + e + } else { + err!("Invalid expire time") }; if time >= expire { diff --git a/src/api/core/two_factor/duo_oidc.rs b/src/api/core/two_factor/duo_oidc.rs index 144ffe84..560e6f65 100644 --- a/src/api/core/two_factor/duo_oidc.rs +++ b/src/api/core/two_factor/duo_oidc.rs @@ -1,21 +1,21 @@ use chrono::Utc; use data_encoding::HEXLOWER; use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation}; -use reqwest::{header, StatusCode}; -use ring::digest::{digest, Digest, SHA512_256}; +use reqwest::{StatusCode, header}; +use ring::digest::{Digest, SHA512_256, digest}; use serde::Serialize; use std::collections::HashMap; use crate::{ - api::{core::two_factor::duo::get_duo_keys_email, EmptyResult}, + CONFIG, + api::{EmptyResult, core::two_factor::duo::get_duo_keys_email}, crypto, db::{ - models::{DeviceId, EventType, TwoFactorDuoContext}, DbConn, DbPool, + models::{DeviceId, EventType, TwoFactorDuoContext}, }, error::Error, http_client::make_http_request, - CONFIG, }; use url::Url; @@ -124,7 +124,7 @@ impl DuoClient { ClientAssertion { iss: self.client_id.clone(), sub: self.client_id.clone(), - aud: url.to_string(), + aud: url.to_owned(), exp: now + JWT_VALIDITY_SECS, jti: jwt_id, iat: now, @@ -302,7 +302,7 @@ impl DuoClient { if !(matching_nonces && matching_usernames) { err!("Error validating Duo authorization, nonce or username mismatch.") - }; + } Ok(()) } @@ -347,7 +347,7 @@ pub async fn purge_duo_contexts(pool: DbPool) { if let Ok(conn) = pool.get().await { TwoFactorDuoContext::purge_expired_duo_contexts(&conn).await; } else { - error!("Failed to get DB connection while purging expired Duo authentications") + error!("Failed to get DB connection while purging expired Duo authentications"); } } @@ -394,7 +394,7 @@ pub async fn get_duo_auth_url( match client.health_check().await { Ok(()) => {} Err(e) => return Err(e), - }; + } // Generate random OAuth2 state and OIDC Nonce let state: String = crypto::get_random_string_alphanum(STATE_LENGTH); @@ -438,16 +438,13 @@ pub async fn validate_duo_login( // Get the context by the state reported by the client. If we don't have one, // it means the context is either missing or expired. - let ctx = match extract_context(state, conn).await { - Some(c) => c, - None => { - err!( - "Error validating duo authentication", - ErrorEvent { - event: EventType::UserFailedLogIn2fa - } - ) - } + let Some(ctx) = extract_context(state, conn).await else { + err!( + "Error validating duo authentication", + ErrorEvent { + event: EventType::UserFailedLogIn2fa + } + ) }; // Context validation steps @@ -476,13 +473,13 @@ pub async fn validate_duo_login( match client.health_check().await { Ok(()) => {} Err(e) => return Err(e), - }; + } let d: Digest = digest(&SHA512_256, format!("{}{device_identifier}", ctx.nonce).as_bytes()); let hash: String = HEXLOWER.encode(d.as_ref()); match client.exchange_authz_code_for_result(code, email, hash.as_str()).await { - Ok(_) => Ok(()), + Ok(()) => Ok(()), Err(_) => { err!( "Error validating duo authentication", diff --git a/src/api/core/two_factor/email.rs b/src/api/core/two_factor/email.rs index 7fa350de..d2ede49f 100644 --- a/src/api/core/two_factor/email.rs +++ b/src/api/core/two_factor/email.rs @@ -1,20 +1,21 @@ use chrono::{DateTime, TimeDelta, Utc}; -use rocket::serde::json::Json; use rocket::Route; +use rocket::serde::json::Json; use crate::{ + CONFIG, api::{ - core::{log_user_event, two_factor::_generate_recover_code}, EmptyResult, JsonResult, PasswordOrOtpData, + core::{log_user_event, two_factor::generate_recover_code}, }, auth::{ClientHeaders, Headers}, crypto, db::{ - models::{AuthRequest, AuthRequestId, DeviceId, EventType, TwoFactor, TwoFactorType, User, UserId}, DbConn, + models::{AuthRequest, AuthRequestId, DeviceId, EventType, TwoFactor, TwoFactorType, User, UserId}, }, error::{Error, MapResult}, - mail, CONFIG, + mail, }; pub fn routes() -> Vec { @@ -232,7 +233,7 @@ async fn email(data: Json, headers: Headers, conn: DbConn) -> JsonRes twofactor.data = email_data.to_json(); twofactor.save(&conn).await?; - _generate_recover_code(&mut user, &conn).await; + generate_recover_code(&mut user, &conn).await; log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await; @@ -284,9 +285,9 @@ pub async fn validate_email_code_str( twofactor.data = email_data.to_json(); twofactor.save(conn).await?; - let date = DateTime::from_timestamp(email_data.token_sent, 0).expect("Email token timestamp invalid.").naive_utc(); - let max_time = CONFIG.email_expiration_time() as i64; - if date + TimeDelta::try_seconds(max_time).unwrap() < Utc::now().naive_utc() { + let dt = DateTime::from_timestamp(email_data.token_sent, 0).expect("Email token timestamp invalid.").naive_utc(); + let max_time = CONFIG.email_expiration_time().cast_signed(); + if dt + TimeDelta::try_seconds(max_time).unwrap() < Utc::now().naive_utc() { err!( "Token has expired", ErrorEvent { @@ -342,9 +343,10 @@ impl EmailTokenData { pub fn from_json(string: &str) -> Result { let res: Result = serde_json::from_str(string); - match res { - Ok(x) => Ok(x), - Err(_) => err!("Could not decode EmailTokenData from string"), + if let Ok(x) = res { + Ok(x) + } else { + err!("Could not decode EmailTokenData from string") } } } @@ -362,18 +364,17 @@ pub async fn activate_email_2fa(user: &User, conn: &DbConn) -> EmptyResult { pub fn obscure_email(email: &str) -> String { let split: Vec<&str> = email.rsplitn(2, '@').collect(); - let mut name = split[1].to_string(); + let mut name = split[1].to_owned(); let domain = &split[0]; let name_size = name.chars().count(); - let new_name = match name_size { - 1..=3 => "*".repeat(name_size), - _ => { - let stars = "*".repeat(name_size - 2); - name.truncate(2); - format!("{name}{stars}") - } + let new_name = if let 1..=3 = name_size { + "*".repeat(name_size) + } else { + let stars = "*".repeat(name_size - 2); + name.truncate(2); + format!("{name}{stars}") }; format!("{new_name}@{domain}") diff --git a/src/api/core/two_factor/mod.rs b/src/api/core/two_factor/mod.rs index 3a503a23..bf4e2282 100644 --- a/src/api/core/two_factor/mod.rs +++ b/src/api/core/two_factor/mod.rs @@ -1,28 +1,28 @@ use chrono::{TimeDelta, Utc}; use data_encoding::BASE32; use num_traits::FromPrimitive; -use rocket::serde::json::Json; use rocket::Route; +use rocket::serde::json::Json; use serde::Deserialize; use serde_json::Value; use crate::{ + CONFIG, api::{ - core::{log_event, log_user_event}, EmptyResult, JsonResult, PasswordOrOtpData, + core::{log_event, log_user_event}, }, auth::Headers, crypto, db::{ + DbConn, DbPool, models::{ DeviceType, EventType, Membership, MembershipType, OrgPolicyType, Organization, OrganizationId, TwoFactor, TwoFactorIncomplete, TwoFactorType, User, UserId, }, - DbConn, DbPool, }, mail, util::NumberOrString, - CONFIG, }; pub mod authenticator; @@ -37,7 +37,7 @@ fn has_global_duo_credentials() -> bool { CONFIG._enable_duo() && CONFIG.duo_host().is_some() && CONFIG.duo_ikey().is_some() && CONFIG.duo_skey().is_some() } -pub fn is_twofactor_provider_usable(provider_type: TwoFactorType, provider_data: Option<&str>) -> bool { +pub fn is_twofactor_provider_usable(provider_type: &TwoFactorType, provider_data: Option<&str>) -> bool { #[derive(Deserialize)] struct DuoProviderData { host: String, @@ -46,7 +46,7 @@ pub fn is_twofactor_provider_usable(provider_type: TwoFactorType, provider_data: } match provider_type { - TwoFactorType::Authenticator => true, + TwoFactorType::Authenticator | TwoFactorType::RecoveryCode => true, TwoFactorType::Email => CONFIG._enable_email_2fa(), TwoFactorType::Duo | TwoFactorType::OrganizationDuo => { provider_data @@ -59,7 +59,6 @@ pub fn is_twofactor_provider_usable(provider_type: TwoFactorType, provider_data: } TwoFactorType::Webauthn => CONFIG.is_webauthn_2fa_supported(), TwoFactorType::Remember => !CONFIG.disable_2fa_remember(), - TwoFactorType::RecoveryCode => true, TwoFactorType::U2f | TwoFactorType::U2fRegisterChallenge | TwoFactorType::U2fLoginChallenge @@ -96,7 +95,7 @@ async fn get_twofactor(headers: Headers, conn: DbConn) -> Json { .iter() .filter_map(|tf| { let provider_type = TwoFactorType::from_i32(tf.atype)?; - is_twofactor_provider_usable(provider_type, Some(&tf.data)).then(|| TwoFactor::to_json_provider(tf)) + is_twofactor_provider_usable(&provider_type, Some(&tf.data)).then(|| TwoFactor::to_json_provider(tf)) }) .collect(); @@ -120,7 +119,7 @@ async fn get_recover(data: Json, headers: Headers, conn: DbCo }))) } -async fn _generate_recover_code(user: &mut User, conn: &DbConn) { +async fn generate_recover_code(user: &mut User, conn: &DbConn) { if user.totp_recover.is_none() { let totp_recover = crypto::encode_random_bytes::<20>(&BASE32); user.totp_recover = Some(totp_recover); @@ -180,9 +179,7 @@ pub async fn enforce_2fa_policy( ip: &std::net::IpAddr, conn: &DbConn, ) -> EmptyResult { - for member in - Membership::find_by_user_and_policy(&user.uuid, OrgPolicyType::TwoFactorAuthentication, conn).await.into_iter() - { + for member in Membership::find_by_user_and_policy(&user.uuid, OrgPolicyType::TwoFactorAuthentication, conn).await { // Policy only applies to non-Owner/non-Admin members who have accepted joining the org if member.atype < MembershipType::Admin { if CONFIG.mail_enabled() { @@ -217,7 +214,7 @@ pub async fn enforce_2fa_policy_for_org( conn: &DbConn, ) -> EmptyResult { let org = Organization::find_by_uuid(org_id, conn).await.unwrap(); - for member in Membership::find_confirmed_by_org(org_id, conn).await.into_iter() { + for member in Membership::find_confirmed_by_org(org_id, conn).await { // Don't enforce the policy for Admins and Owners. if member.atype < MembershipType::Admin && TwoFactor::find_by_user(&member.user_uuid, conn).await.is_empty() { if CONFIG.mail_enabled() { @@ -251,12 +248,9 @@ pub async fn send_incomplete_2fa_notifications(pool: DbPool) { return; } - let conn = match pool.get().await { - Ok(conn) => conn, - _ => { - error!("Failed to get DB connection in send_incomplete_2fa_notifications()"); - return; - } + let Ok(conn) = pool.get().await else { + error!("Failed to get DB connection in send_incomplete_2fa_notifications()"); + return; }; let now = Utc::now().naive_utc(); @@ -278,7 +272,7 @@ pub async fn send_incomplete_2fa_notifications(pool: DbPool) { ) .await { - Ok(_) => { + Ok(()) => { if let Err(e) = login.delete(&conn).await { error!("Error deleting incomplete 2FA record: {e:#?}"); } diff --git a/src/api/core/two_factor/protected_actions.rs b/src/api/core/two_factor/protected_actions.rs index 800a6cf4..c0c1b5e8 100644 --- a/src/api/core/two_factor/protected_actions.rs +++ b/src/api/core/two_factor/protected_actions.rs @@ -1,16 +1,17 @@ -use chrono::{naive::serde::ts_seconds, NaiveDateTime, TimeDelta, Utc}; -use rocket::{serde::json::Json, Route}; +use chrono::{NaiveDateTime, TimeDelta, Utc, naive::serde::ts_seconds}; +use rocket::{Route, serde::json::Json}; use crate::{ + CONFIG, api::EmptyResult, auth::Headers, crypto, db::{ - models::{TwoFactor, TwoFactorType, UserId}, DbConn, + models::{TwoFactor, TwoFactorType, UserId}, }, error::{Error, MapResult}, - mail, CONFIG, + mail, }; pub fn routes() -> Vec { @@ -44,9 +45,10 @@ impl ProtectedActionData { pub fn from_json(string: &str) -> Result { let res: Result = serde_json::from_str(string); - match res { - Ok(x) => Ok(x), - Err(_) => err!("Could not decode ProtectedActionData from string"), + if let Ok(x) = res { + Ok(x) + } else { + err!("Could not decode ProtectedActionData from string") } } @@ -62,7 +64,9 @@ impl ProtectedActionData { #[post("/accounts/request-otp")] async fn request_otp(headers: Headers, conn: DbConn) -> EmptyResult { if !CONFIG.mail_enabled() { - err!("Email is disabled for this server. Either enable email or login using your master password instead of login via device."); + err!( + "Email is disabled for this server. Either enable email or login using your master password instead of login via device." + ); } let user = headers.user; @@ -102,7 +106,9 @@ struct ProtectedActionVerify { #[post("/accounts/verify-otp", data = "")] async fn verify_otp(data: Json, headers: Headers, conn: DbConn) -> EmptyResult { if !CONFIG.mail_enabled() { - err!("Email is disabled for this server. Either enable email or login using your master password instead of login via device."); + err!( + "Email is disabled for this server. Either enable email or login using your master password instead of login via device." + ); } let user = headers.user; @@ -133,7 +139,7 @@ pub async fn validate_protected_action_otp( } // Check if the token has expired (Using the email 2fa expiration time) - let max_time = CONFIG.email_expiration_time() as i64; + let max_time = CONFIG.email_expiration_time().cast_signed(); if pa_data.time_since_sent().num_seconds() > max_time { pa.delete(conn).await?; err!("Token has expired") diff --git a/src/api/core/two_factor/webauthn.rs b/src/api/core/two_factor/webauthn.rs index ad17ce36..ddcdc75a 100644 --- a/src/api/core/two_factor/webauthn.rs +++ b/src/api/core/two_factor/webauthn.rs @@ -1,20 +1,20 @@ use crate::{ + CONFIG, api::{ - core::{log_user_event, two_factor::_generate_recover_code}, EmptyResult, JsonResult, PasswordOrOtpData, + core::{log_user_event, two_factor::generate_recover_code}, }, auth::Headers, crypto::ct_eq, db::{ - models::{EventType, TwoFactor, TwoFactorType, UserId}, DbConn, + models::{EventType, TwoFactor, TwoFactorType, UserId}, }, error::Error, util::NumberOrString, - CONFIG, }; -use rocket::serde::json::Json; use rocket::Route; +use rocket::serde::json::Json; use serde_json::Value; use std::str::FromStr; use std::sync::LazyLock; @@ -149,7 +149,7 @@ async fn generate_webauthn_challenge(data: Json, headers: Hea )?; let mut state = serde_json::to_value(&state)?; - state["rs"]["policy"] = Value::String("discouraged".to_string()); + state["rs"]["policy"] = Value::String("discouraged".to_owned()); state["rs"]["extensions"].as_object_mut().unwrap().clear(); let type_ = TwoFactorType::WebauthnRegisterChallenge; @@ -265,13 +265,12 @@ async fn activate_webauthn(data: Json, headers: Headers, con // Retrieve and delete the saved challenge state let type_ = TwoFactorType::WebauthnRegisterChallenge as i32; - let state = match TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await { - Some(tf) => { - let state: PasskeyRegistration = serde_json::from_str(&tf.data)?; - tf.delete(&conn).await?; - state - } - None => err!("Can't recover challenge"), + let state = if let Some(tf) = TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await { + let state: PasskeyRegistration = serde_json::from_str(&tf.data)?; + tf.delete(&conn).await?; + state + } else { + err!("Can't recover challenge") }; // Verify the credentials with the saved state @@ -291,7 +290,7 @@ async fn activate_webauthn(data: Json, headers: Headers, con TwoFactor::new(user.uuid.clone(), TwoFactorType::Webauthn, serde_json::to_string(®istrations)?) .save(&conn) .await?; - _generate_recover_code(&mut user, &conn).await; + generate_recover_code(&mut user, &conn).await; log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await; @@ -342,9 +341,10 @@ async fn delete_webauthn(data: Json, headers: Headers, conn: DbCo // If entry is migrated from u2f, delete the u2f entry as well if let Some(mut u2f) = TwoFactor::find_by_user_and_type(&headers.user.uuid, TwoFactorType::U2f as i32, &conn).await { - let mut data: Vec = match serde_json::from_str(&u2f.data) { - Ok(d) => d, - Err(_) => err!("Error parsing U2F data"), + let mut data: Vec = if let Ok(d) = serde_json::from_str(&u2f.data) { + d + } else { + err!("Error parsing U2F data") }; data.retain(|r| r.reg.key_handle != removed_item.credential.cred_id().as_slice()); @@ -388,10 +388,10 @@ pub async fn generate_webauthn_login(user_id: &UserId, conn: &DbConn) -> JsonRes // Modify to discourage user verification let mut state = serde_json::to_value(&state)?; - state["ast"]["policy"] = Value::String("discouraged".to_string()); + state["ast"]["policy"] = Value::String("discouraged".to_owned()); // Add appid, this is only needed for U2F compatibility, so maybe it can be removed as well - let app_id = format!("{}/app-id.json", &CONFIG.domain()); + let app_id = format!("{}/app-id.json", CONFIG.domain()); state["ast"]["appid"] = Value::String(app_id.clone()); response.public_key.user_verification = UserVerificationPolicy::Discouraged_DO_NOT_USE; @@ -416,18 +416,17 @@ pub async fn generate_webauthn_login(user_id: &UserId, conn: &DbConn) -> JsonRes pub async fn validate_webauthn_login(user_id: &UserId, response: &str, conn: &DbConn) -> EmptyResult { let type_ = TwoFactorType::WebauthnLoginChallenge as i32; - let mut state = match TwoFactor::find_by_user_and_type(user_id, type_, conn).await { - Some(tf) => { - let state: PasskeyAuthentication = serde_json::from_str(&tf.data)?; - tf.delete(conn).await?; - state - } - None => err!( + let mut state = if let Some(tf) = TwoFactor::find_by_user_and_type(user_id, type_, conn).await { + let state: PasskeyAuthentication = serde_json::from_str(&tf.data)?; + tf.delete(conn).await?; + state + } else { + err!( "Can't recover login challenge", ErrorEvent { event: EventType::UserFailedLogIn2fa } - ), + ) }; let rsp: PublicKeyCredentialCopy = serde_json::from_str(response)?; diff --git a/src/api/core/two_factor/yubikey.rs b/src/api/core/two_factor/yubikey.rs index 1cf11255..7412371d 100644 --- a/src/api/core/two_factor/yubikey.rs +++ b/src/api/core/two_factor/yubikey.rs @@ -1,20 +1,20 @@ -use rocket::serde::json::Json; use rocket::Route; +use rocket::serde::json::Json; use serde_json::Value; use yubico::{config::Config, verify_async}; use crate::{ + CONFIG, api::{ - core::{log_user_event, two_factor::_generate_recover_code}, EmptyResult, JsonResult, PasswordOrOtpData, + core::{log_user_event, two_factor::generate_recover_code}, }, auth::Headers, db::{ - models::{EventType, TwoFactor, TwoFactorType}, DbConn, + models::{EventType, TwoFactor, TwoFactorType}, }, error::{Error, MapResult}, - CONFIG, }; pub fn routes() -> Vec { @@ -46,7 +46,7 @@ pub struct YubikeyMetadata { fn parse_yubikeys(data: &EnableYubikeyData) -> Vec { let data_keys = [&data.key1, &data.key2, &data.key3, &data.key4, &data.key5]; - data_keys.iter().filter_map(|e| e.as_ref().cloned()).collect() + data_keys.into_iter().flatten().cloned().collect() } fn jsonify_yubikeys(yubikeys: Vec) -> Value { @@ -64,9 +64,10 @@ fn get_yubico_credentials() -> Result<(String, String), Error> { err!("Yubico support is disabled"); } - match (CONFIG.yubico_client_id(), CONFIG.yubico_secret_key()) { - (Some(id), Some(secret)) => Ok((id, secret)), - _ => err!("`YUBICO_CLIENT_ID` or `YUBICO_SECRET_KEY` environment variable is not set. Yubikey OTP Disabled"), + if let (Some(id), Some(secret)) = (CONFIG.yubico_client_id(), CONFIG.yubico_secret_key()) { + Ok((id, secret)) + } else { + err!("`YUBICO_CLIENT_ID` or `YUBICO_SECRET_KEY` environment variable is not set. Yubikey OTP Disabled") } } @@ -162,7 +163,7 @@ async fn activate_yubikey(data: Json, headers: Headers, conn: yubikey_data.data = serde_json::to_string(&yubikey_metadata).unwrap(); yubikey_data.save(&conn).await?; - _generate_recover_code(&mut user, &conn).await; + generate_recover_code(&mut user, &conn).await; log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await; diff --git a/src/api/icons.rs b/src/api/icons.rs index 5c9ed113..02a14844 100644 --- a/src/api/icons.rs +++ b/src/api/icons.rs @@ -6,28 +6,29 @@ use std::{ }; use bytes::{Bytes, BytesMut}; -use futures::{stream::StreamExt, TryFutureExt}; +use futures::{TryFutureExt, stream::StreamExt}; use html5gum::{Emitter, HtmlString, Readable, StringReader, Tokenizer}; use regex::Regex; use reqwest::{ - header::{self, HeaderMap, HeaderValue}, Client, Response, + header::{self, HeaderMap, HeaderValue}, }; -use rocket::{http::ContentType, response::Redirect, Route}; -use svg_hush::{data_url_filter, Filter}; +use rocket::{Route, http::ContentType, response::Redirect}; +use svg_hush::{Filter, data_url_filter}; use crate::{ + CONFIG, config::PathType, error::Error, - http_client::{get_reqwest_client_builder, get_valid_host, should_block_host, CustomHttpClientError}, + http_client::{CustomHttpClientError, get_reqwest_client_builder, get_valid_host, should_block_host}, util::Cached, - CONFIG, }; pub fn routes() -> Vec { - match CONFIG.icon_service().as_str() { - "internal" => routes![icon_internal], - _ => routes![icon_external], + if CONFIG.icon_service().as_str() == "internal" { + routes![icon_internal] + } else { + routes![icon_external] } } @@ -147,7 +148,7 @@ async fn get_icon(domain: &str) -> Option<(Vec, String)> { if let Some(icon) = get_cached_icon(&path).await { let icon_type = get_icon_type(&icon).unwrap_or("x-icon"); - return Some((icon, icon_type.to_string())); + return Some((icon, icon_type.to_owned())); } if CONFIG.disable_icon_download() { @@ -158,7 +159,7 @@ async fn get_icon(domain: &str) -> Option<(Vec, String)> { match download_icon(domain).await { Ok((icon, icon_type)) => { save_icon(&path, icon.to_vec()).await; - Some((icon.to_vec(), icon_type.unwrap_or("x-icon").to_string())) + Some((icon.to_vec(), icon_type.unwrap_or("x-icon").to_owned())) } Err(e) => { // If this error comes from the custom resolver, this means this is a blocked domain @@ -183,10 +184,10 @@ async fn get_cached_icon(path: &str) -> Option> { } // Try to read the cached icon, and return it if it exists - if let Ok(operator) = CONFIG.opendal_operator_for_path_type(&PathType::IconCache) { - if let Ok(buf) = operator.read(path).await { - return Some(buf.to_vec()); - } + if let Ok(operator) = CONFIG.opendal_operator_for_path_type(&PathType::IconCache) + && let Ok(buf) = operator.read(path).await + { + return Some(buf.to_vec()); } None @@ -280,17 +281,17 @@ fn get_favicons_node(dom: Tokenizer, FaviconEmitter>, icons: &m } for icon_tag in icon_tags { - if let Some(icon_href) = icon_tag.attributes.get(ATTR_HREF) { - if let Ok(full_href) = base_url.join(std::str::from_utf8(icon_href).unwrap_or_default()) { - let sizes = if let Some(v) = icon_tag.attributes.get(ATTR_SIZES) { - std::str::from_utf8(v).unwrap_or_default() - } else { - "" - }; - let priority = get_icon_priority(full_href.as_str(), sizes); - icons.push(Icon::new(priority, full_href.to_string())); - } - }; + if let Some(icon_href) = icon_tag.attributes.get(ATTR_HREF) + && let Ok(full_href) = base_url.join(std::str::from_utf8(icon_href).unwrap_or_default()) + { + let sizes = if let Some(v) = icon_tag.attributes.get(ATTR_SIZES) { + std::str::from_utf8(v).unwrap_or_default() + } else { + "" + }; + let priority = get_icon_priority(full_href.as_str(), sizes); + icons.push(Icon::new(priority, full_href.to_string())); + } } } @@ -406,7 +407,7 @@ async fn get_page(url: &str) -> Result { async fn get_page_with_referer(url: &str, referer: &str) -> Result { let mut client = CLIENT.get(url); if !referer.is_empty() { - client = client.header("Referer", referer) + client = client.header("Referer", referer); } Ok(client.send().await?.error_for_status()?) @@ -494,12 +495,10 @@ async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> { let mut buffer = Bytes::new(); let mut icon_type: Option<&str> = None; - use data_url::DataUrl; - let mut icons = icon_result.iconlist.iter().take(5).peekable(); while let Some(icon) = icons.next() { if icon.href.starts_with("data:image") { - let Ok(datauri) = DataUrl::process(&icon.href) else { + let Ok(datauri) = data_url::DataUrl::process(&icon.href) else { continue; }; // Check if we are able to decode the data uri @@ -523,7 +522,7 @@ async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> { } } _ => debug!("Extracted icon from data:image uri is invalid"), - }; + } } else { debug!("Trying {}", icon.href); // Make sure all icons are checked before returning error @@ -587,10 +586,10 @@ async fn save_icon(path: &str, icon: Vec) { fn get_icon_type(bytes: &[u8]) -> Option<&'static str> { fn check_svg_after_xml_declaration(bytes: &[u8]) -> Option<&'static str> { // Look for SVG tag within the first 1KB - if let Ok(content) = std::str::from_utf8(&bytes[..bytes.len().min(1024)]) { - if content.contains(" (), @@ -806,13 +805,13 @@ impl Emitter for FaviconEmitter { fn push_attribute_name(&mut self, s: &[u8]) { if let Some(attr) = &mut self.current_attribute { - attr.0.extend(s) + attr.0.extend(s); } } fn push_attribute_value(&mut self, s: &[u8]) { if let Some(attr) = &mut self.current_attribute { - attr.1.extend(s) + attr.1.extend(s); } } diff --git a/src/api/identity.rs b/src/api/identity.rs index 34078529..3962827d 100644 --- a/src/api/identity.rs +++ b/src/api/identity.rs @@ -1,18 +1,20 @@ use chrono::Utc; use num_traits::FromPrimitive; use rocket::{ + Route, form::{Form, FromForm}, http::{Cookie, CookieJar, SameSite}, response::Redirect, serde::json::Json, - Route, }; use serde_json::Value; use crate::{ + CONFIG, api::{ + ApiResult, EmptyResult, JsonResult, core::{ - accounts::{_prelogin, _register, kdf_upgrade, PreloginData, RegisterData}, + accounts::{PreloginData, RegisterData, kdf_upgrade, prelogin, register}, log_user_event, two_factor::{ authenticator, duo, duo_oidc, email, enforce_2fa_policy, is_twofactor_provider_usable, webauthn, @@ -21,29 +23,28 @@ use crate::{ }, master_password_policy, push::register_push_device, - ApiResult, EmptyResult, JsonResult, }, auth, - auth::{generate_organization_api_key_login_claims, AuthMethod, ClientHeaders, ClientIp, ClientVersion, Secure}, + auth::{AuthMethod, ClientHeaders, ClientIp, ClientVersion, Secure, generate_organization_api_key_login_claims}, crypto, db::{ + DbConn, models::{ AuthRequest, AuthRequestId, Device, DeviceId, EventType, Invitation, OIDCCodeResponseError, OrganizationApiKey, OrganizationId, SsoAuth, SsoUser, TwoFactor, TwoFactorIncomplete, TwoFactorType, User, UserId, }, - DbConn, }, error::MapResult, mail, sso, sso::{OIDCCode, OIDCCodeChallenge, OIDCCodeVerifier, OIDCState}, - util, CONFIG, + util, }; pub fn routes() -> Vec { routes![ login, - prelogin, + post_prelogin, prelogin_password, identity_register, register_verification_email, @@ -68,43 +69,43 @@ async fn login( let login_result = match data.grant_type.as_ref() { "refresh_token" => { - _check_is_some(data.refresh_token.as_ref(), "refresh_token cannot be blank")?; - _refresh_login(data, &conn, &client_header.ip).await + check_is_some(data.refresh_token.as_ref(), "refresh_token cannot be blank")?; + refresh_login(data, &conn, &client_header.ip).await } "password" if CONFIG.sso_enabled() && CONFIG.sso_only() => err!("SSO sign-in is required"), "password" => { - _check_is_some(data.client_id.as_ref(), "client_id cannot be blank")?; - _check_is_some(data.password.as_ref(), "password cannot be blank")?; - _check_is_some(data.scope.as_ref(), "scope cannot be blank")?; - _check_is_some(data.username.as_ref(), "username cannot be blank")?; + check_is_some(data.client_id.as_ref(), "client_id cannot be blank")?; + check_is_some(data.password.as_ref(), "password cannot be blank")?; + check_is_some(data.scope.as_ref(), "scope cannot be blank")?; + check_is_some(data.username.as_ref(), "username cannot be blank")?; - _check_is_some(data.device_identifier.as_ref(), "device_identifier cannot be blank")?; - _check_is_some(data.device_name.as_ref(), "device_name cannot be blank")?; - _check_is_some(data.device_type.as_ref(), "device_type cannot be blank")?; + check_is_some(data.device_identifier.as_ref(), "device_identifier cannot be blank")?; + check_is_some(data.device_name.as_ref(), "device_name cannot be blank")?; + check_is_some(data.device_type.as_ref(), "device_type cannot be blank")?; - _password_login(data, &mut user_id, &conn, &client_header.ip, client_version.as_ref()).await + password_login(data, &mut user_id, &conn, &client_header.ip, client_version.as_ref()).await } "client_credentials" => { - _check_is_some(data.client_id.as_ref(), "client_id cannot be blank")?; - _check_is_some(data.client_secret.as_ref(), "client_secret cannot be blank")?; - _check_is_some(data.scope.as_ref(), "scope cannot be blank")?; + check_is_some(data.client_id.as_ref(), "client_id cannot be blank")?; + check_is_some(data.client_secret.as_ref(), "client_secret cannot be blank")?; + check_is_some(data.scope.as_ref(), "scope cannot be blank")?; - _check_is_some(data.device_identifier.as_ref(), "device_identifier cannot be blank")?; - _check_is_some(data.device_name.as_ref(), "device_name cannot be blank")?; - _check_is_some(data.device_type.as_ref(), "device_type cannot be blank")?; + check_is_some(data.device_identifier.as_ref(), "device_identifier cannot be blank")?; + check_is_some(data.device_name.as_ref(), "device_name cannot be blank")?; + check_is_some(data.device_type.as_ref(), "device_type cannot be blank")?; - _api_key_login(data, &mut user_id, &conn, &client_header.ip).await + api_key_login(data, &mut user_id, &conn, &client_header.ip).await } "authorization_code" if CONFIG.sso_enabled() => { - _check_is_some(data.client_id.as_ref(), "client_id cannot be blank")?; - _check_is_some(data.code.as_ref(), "code cannot be blank")?; - _check_is_some(data.code_verifier.as_ref(), "code verifier cannot be blank")?; + check_is_some(data.client_id.as_ref(), "client_id cannot be blank")?; + check_is_some(data.code.as_ref(), "code cannot be blank")?; + check_is_some(data.code_verifier.as_ref(), "code verifier cannot be blank")?; - _check_is_some(data.device_identifier.as_ref(), "device_identifier cannot be blank")?; - _check_is_some(data.device_name.as_ref(), "device_name cannot be blank")?; - _check_is_some(data.device_type.as_ref(), "device_type cannot be blank")?; + check_is_some(data.device_identifier.as_ref(), "device_identifier cannot be blank")?; + check_is_some(data.device_name.as_ref(), "device_name cannot be blank")?; + check_is_some(data.device_type.as_ref(), "device_type cannot be blank")?; - _sso_login(data, &mut user_id, &conn, &client_header.ip, client_version.as_ref()).await + sso_login(data, &mut user_id, &conn, &client_header.ip, client_version.as_ref()).await } "authorization_code" => err!("SSO sign-in is not available"), t => err!("Invalid type", t), @@ -125,7 +126,7 @@ async fn login( Err(e) => { if let Some(ev) = e.get_event() { log_user_event(ev.event as i32, &user_id, client_header.device_type, &client_header.ip.ip, &conn) - .await + .await; } } } @@ -134,7 +135,7 @@ async fn login( login_result } -async fn _refresh_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> JsonResult { +async fn refresh_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> JsonResult { // When a refresh token is invalid or missing we need to respond with an HTTP BadRequest (400) // It also needs to return a json which holds at least a key `error` with the value `invalid_grant` // See the link below for details @@ -175,7 +176,7 @@ async fn _refresh_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> Json } // After exchanging the code we need to check first if 2FA is needed before continuing -async fn _sso_login( +async fn sso_login( data: ConnectData, user_id: &mut Option, conn: &DbConn, @@ -344,7 +345,7 @@ async fn _sso_login( authenticated_response(&user, &mut device, auth_tokens, twofactor_token, conn, ip).await } -async fn _password_login( +async fn password_login( data: ConnectData, user_id: &mut Option, conn: &DbConn, @@ -428,9 +429,9 @@ async fn _password_login( if user.verified_at.is_none() && CONFIG.mail_enabled() && CONFIG.signups_verify() { if user.last_verifying_at.is_none() || now.signed_duration_since(user.last_verifying_at.unwrap()).num_seconds() - > CONFIG.signups_verify_resend_time() as i64 + > CONFIG.signups_verify_resend_time().cast_signed() { - let resend_limit = CONFIG.signups_verify_resend_limit() as i32; + let resend_limit = CONFIG.signups_verify_resend_limit().cast_signed(); if resend_limit == 0 || user.login_verify_count < resend_limit { // We want to send another email verification if we require signups to verify // their email address, and we haven't sent them a reminder in a while... @@ -566,19 +567,19 @@ async fn authenticated_response( Ok(Json(result)) } -async fn _api_key_login(data: ConnectData, user_id: &mut Option, conn: &DbConn, ip: &ClientIp) -> JsonResult { +async fn api_key_login(data: ConnectData, user_id: &mut Option, conn: &DbConn, ip: &ClientIp) -> JsonResult { // Ratelimit the login crate::ratelimit::check_limit_login(&ip.ip)?; // Validate scope match data.scope.as_ref() { - Some(scope) if scope == &AuthMethod::UserApiKey.scope() => _user_api_key_login(data, user_id, conn, ip).await, - Some(scope) if scope == &AuthMethod::OrgApiKey.scope() => _organization_api_key_login(data, conn, ip).await, + Some(scope) if scope == &AuthMethod::UserApiKey.scope() => user_api_key_login(data, user_id, conn, ip).await, + Some(scope) if scope == &AuthMethod::OrgApiKey.scope() => organization_api_key_login(data, conn, ip).await, _ => err!("Scope not supported"), } } -async fn _user_api_key_login( +async fn user_api_key_login( data: ConnectData, user_id: &mut Option, conn: &DbConn, @@ -710,13 +711,13 @@ async fn _user_api_key_login( Ok(Json(result)) } -async fn _organization_api_key_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> JsonResult { +async fn organization_api_key_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> JsonResult { // Get the org via the client_id let client_id = data.client_id.as_ref().unwrap(); let Some(org_id) = client_id.strip_prefix("organization.") else { err!("Malformed client_id", format!("IP: {}.", ip.ip)) }; - let org_id: OrganizationId = org_id.to_string().into(); + let org_id: OrganizationId = org_id.to_owned().into(); let Some(org_api_key) = OrganizationApiKey::find_by_org_uuid(&org_id, conn).await else { err!("Invalid client_id", format!("IP: {}.", ip.ip)) }; @@ -747,14 +748,13 @@ async fn get_device(data: &ConnectData, conn: &DbConn, user: &User) -> ApiResult let device_name = data.device_name.clone().expect("No device name provided"); // Find device or create new - match Device::find_by_uuid_and_user(&device_id, &user.uuid, conn).await { - Some(device) => Ok(device), - None => { - let mut device = Device::new(device_id, user.uuid.clone(), device_name, device_type); - // save device without updating `device.updated_at` - device.save(false, conn).await?; - Ok(device) - } + if let Some(device) = Device::find_by_uuid_and_user(&device_id, &user.uuid, conn).await { + Ok(device) + } else { + let mut device = Device::new(device_id, user.uuid.clone(), device_name, device_type); + // save device without updating `device.updated_at` + device.save(false, conn).await?; + Ok(device) } } @@ -780,7 +780,7 @@ async fn twofactor_auth( .iter() .filter_map(|tf| { let provider_type = TwoFactorType::from_i32(tf.atype)?; - (tf.enabled && is_twofactor_provider_usable(provider_type, Some(&tf.data))).then_some(tf.atype) + (tf.enabled && is_twofactor_provider_usable(&provider_type, Some(&tf.data))).then_some(tf.atype) }) .collect(); if twofactor_ids.is_empty() { @@ -788,59 +788,51 @@ async fn twofactor_auth( } let selected_id = data.two_factor_provider.unwrap_or(twofactor_ids[0]); // If we aren't given a two factor provider, assume the first one - // Ignore Remember and RecoveryCode Types during this check, these are special + // Ignore Remember and RecoveryCode Types during this check, these are special if ![TwoFactorType::Remember as i32, TwoFactorType::RecoveryCode as i32].contains(&selected_id) && !twofactor_ids.contains(&selected_id) { err_json!( - _json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?, + json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?, "Invalid two factor provider" ) } - let twofactor_code = match data.two_factor_token { - Some(ref code) => code, - None => { - err_json!( - _json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?, - "2FA token not provided" - ) - } + let Some(ref twofactor_code) = data.two_factor_token else { + err_json!( + json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?, + "2FA token not provided" + ) }; let selected_twofactor = twofactors.into_iter().find(|tf| tf.atype == selected_id && tf.enabled); - use crate::crypto::ct_eq; - - let selected_data = _selected_data(selected_twofactor); + let selected_data = selected_data(selected_twofactor); match TwoFactorType::from_i32(selected_id) { Some(TwoFactorType::Authenticator) => { - authenticator::validate_totp_code_str(&user.uuid, twofactor_code, &selected_data?, ip, conn).await? + authenticator::validate_totp_code_str(&user.uuid, twofactor_code, &selected_data?, ip, conn).await?; } Some(TwoFactorType::Webauthn) => webauthn::validate_webauthn_login(&user.uuid, twofactor_code, conn).await?, Some(TwoFactorType::YubiKey) => yubikey::validate_yubikey_login(twofactor_code, &selected_data?).await?, Some(TwoFactorType::Duo) => { - match CONFIG.duo_use_iframe() { - true => { - // Legacy iframe prompt flow - duo::validate_duo_login(&user.email, twofactor_code, conn).await? - } - false => { - // OIDC based flow - duo_oidc::validate_duo_login( - &user.email, - twofactor_code, - data.client_id.as_ref().unwrap(), - data.device_identifier.as_ref().unwrap(), - conn, - ) - .await? - } + if CONFIG.duo_use_iframe() { + // Legacy iframe prompt flow + duo::validate_duo_login(&user.email, twofactor_code, conn).await?; + } else { + // OIDC based flow + duo_oidc::validate_duo_login( + &user.email, + twofactor_code, + data.client_id.as_ref().unwrap(), + data.device_identifier.as_ref().unwrap(), + conn, + ) + .await?; } } Some(TwoFactorType::Email) => { - email::validate_email_code_str(&user.uuid, twofactor_code, &selected_data?, &ip.ip, conn).await? + email::validate_email_code_str(&user.uuid, twofactor_code, &selected_data?, &ip.ip, conn).await?; } Some(TwoFactorType::Remember) => { match device.twofactor_remember { @@ -848,7 +840,7 @@ async fn twofactor_auth( // If it is invalid we need to trigger the 2FA Login prompt Some(ref token) if !CONFIG.disable_2fa_remember() - && (ct_eq(token, twofactor_code) + && (crypto::ct_eq(token, twofactor_code) && auth::decode_2fa_remember(twofactor_code) .is_ok_and(|t| t.sub == device.uuid && t.user_uuid == user.uuid)) => {} _ => { @@ -859,7 +851,7 @@ async fn twofactor_auth( device.save(true, conn).await?; } err_json!( - _json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?, + json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?, "2FA Remember token not provided or expired" ) } @@ -900,11 +892,11 @@ async fn twofactor_auth( Ok(two_factor) } -fn _selected_data(tf: Option) -> ApiResult { +fn selected_data(tf: Option) -> ApiResult { tf.map(|t| t.data).map_res("Two factor doesn't exist") } -async fn _json_err_twofactor( +async fn json_err_twofactor( providers: &[i32], user_id: &UserId, data: &ConnectData, @@ -925,42 +917,38 @@ async fn _json_err_twofactor( result["TwoFactorProviders2"][provider.to_string()] = Value::Null; match TwoFactorType::from_i32(*provider) { - Some(TwoFactorType::Authenticator) => { /* Nothing to do for TOTP */ } - Some(TwoFactorType::Webauthn) if CONFIG.is_webauthn_2fa_supported() => { let request = webauthn::generate_webauthn_login(user_id, conn).await?; result["TwoFactorProviders2"][provider.to_string()] = request.0; } Some(TwoFactorType::Duo) => { - let email = match User::find_by_uuid(user_id, conn).await { - Some(u) => u.email, - None => err!("User does not exist"), + let email = if let Some(u) = User::find_by_uuid(user_id, conn).await { + u.email + } else { + err!("User does not exist") }; - match CONFIG.duo_use_iframe() { - true => { - // Legacy iframe prompt flow - let (signature, host) = duo::generate_duo_signature(&email, conn).await?; - result["TwoFactorProviders2"][provider.to_string()] = json!({ - "Host": host, - "Signature": signature, - }) - } - false => { - // OIDC based flow - let auth_url = duo_oidc::get_duo_auth_url( - &email, - data.client_id.as_ref().unwrap(), - data.device_identifier.as_ref().unwrap(), - conn, - ) - .await?; - - result["TwoFactorProviders2"][provider.to_string()] = json!({ - "AuthUrl": auth_url, - }) - } + if CONFIG.duo_use_iframe() { + // Legacy iframe prompt flow + let (signature, host) = duo::generate_duo_signature(&email, conn).await?; + result["TwoFactorProviders2"][provider.to_string()] = json!({ + "Host": host, + "Signature": signature, + }); + } else { + // OIDC based flow + let auth_url = duo_oidc::get_duo_auth_url( + &email, + data.client_id.as_ref().unwrap(), + data.device_identifier.as_ref().unwrap(), + conn, + ) + .await?; + + result["TwoFactorProviders2"][provider.to_string()] = json!({ + "AuthUrl": auth_url, + }); } } @@ -973,7 +961,7 @@ async fn _json_err_twofactor( result["TwoFactorProviders2"][provider.to_string()] = json!({ "Nfc": yubikey_metadata.nfc, - }) + }); } Some(tf_type @ TwoFactorType::Email) => { @@ -991,16 +979,30 @@ async fn _json_err_twofactor( // Send email immediately if email is the only 2FA option. if providers.len() == 1 && !disabled_send { - email::send_token(user_id, conn).await? + email::send_token(user_id, conn).await?; } let email_data = email::EmailTokenData::from_json(&twofactor.data)?; result["TwoFactorProviders2"][provider.to_string()] = json!({ "Email": email::obscure_email(&email_data.email), - }) + }); } - _ => {} + None + | Some( + TwoFactorType::Authenticator + | TwoFactorType::EmailVerificationChallenge + | TwoFactorType::OrganizationDuo + | TwoFactorType::ProtectedActions + | TwoFactorType::RecoveryCode + | TwoFactorType::Remember + | TwoFactorType::U2f + | TwoFactorType::U2fLoginChallenge + | TwoFactorType::U2fRegisterChallenge + | TwoFactorType::Webauthn + | TwoFactorType::WebauthnLoginChallenge + | TwoFactorType::WebauthnRegisterChallenge, + ) => { /* Nothing special to do for these providers */ } } } @@ -1008,18 +1010,18 @@ async fn _json_err_twofactor( } #[post("/accounts/prelogin", data = "")] -async fn prelogin(data: Json, conn: DbConn) -> Json { - _prelogin(data, conn).await +async fn post_prelogin(data: Json, conn: DbConn) -> Json { + prelogin(data, conn).await } #[post("/accounts/prelogin/password", data = "")] async fn prelogin_password(data: Json, conn: DbConn) -> Json { - _prelogin(data, conn).await + prelogin(data, conn).await } #[post("/accounts/register", data = "")] async fn identity_register(data: Json, conn: DbConn) -> JsonResult { - _register(data, false, conn).await + register(data, false, conn).await } #[derive(Debug, Deserialize)] @@ -1058,13 +1060,13 @@ async fn register_verification_email( if should_send_mail { let user = User::find_by_mail(&data.email, &conn).await; - if user.filter(|u| u.private_key.is_some()).is_some() { + if user.as_ref().is_some_and(|u| u.private_key.is_some()) { // There is still a timing side channel here in that the code // paths that send mail take noticeably longer than ones that don't. // Add a randomized sleep to mitigate this somewhat. - use rand::{rngs::SmallRng, RngExt}; + use rand::{RngExt, rngs::SmallRng}; let mut rng: SmallRng = rand::make_rng(); - let sleep_ms = rng.random_range(900..=1100) as u64; + let sleep_ms: u64 = rng.random_range(900..=1100); tokio::time::sleep(tokio::time::Duration::from_millis(sleep_ms)).await; } else { mail::send_register_verify_email(&data.email, &token).await?; @@ -1080,7 +1082,7 @@ async fn register_verification_email( #[post("/accounts/register/finish", data = "")] async fn register_finish(data: Json, conn: DbConn) -> JsonResult { - _register(data, true, conn).await + register(data, true, conn).await } // https://github.com/bitwarden/jslib/blob/master/common/src/models/request/tokenRequest.ts @@ -1143,7 +1145,7 @@ struct ConnectData { #[field(name = uncased("code_verifier"))] code_verifier: Option, } -fn _check_is_some(value: Option<&T>, msg: &str) -> EmptyResult { +fn check_is_some(value: Option<&T>, msg: &str) -> EmptyResult { if value.is_none() { err!(msg) } @@ -1166,7 +1168,7 @@ const SSO_BINDING_COOKIE: &str = "VW_SSO_BINDING"; #[get("/connect/oidc-signin?&", rank = 1)] async fn oidcsignin(code: OIDCCode, state: String, cookies: &CookieJar<'_>, mut conn: DbConn) -> ApiResult { - _oidcsignin_redirect(state, code, None, cookies, &mut conn).await + oidcsignin_redirect(state, code, None, cookies, &mut conn).await } // Bitwarden client appear to only care for code and state @@ -1180,7 +1182,7 @@ async fn oidcsignin_error( cookies: &CookieJar<'_>, mut conn: DbConn, ) -> ApiResult { - _oidcsignin_redirect( + oidcsignin_redirect( state.clone(), state.into(), Some(OIDCCodeResponseError { @@ -1195,7 +1197,8 @@ async fn oidcsignin_error( // The state was encoded using Base64 to ensure no issue with providers. // iss and scope parameters are needed for redirection to work on IOS. -async fn _oidcsignin_redirect( +// We pass the state as the code to get it back later on. +async fn oidcsignin_redirect( base64_state: String, code: OIDCCode, error: Option, @@ -1204,14 +1207,13 @@ async fn _oidcsignin_redirect( ) -> ApiResult { let state = sso::decode_state(&base64_state)?; - let mut sso_auth = match SsoAuth::find(&state, conn).await { - None => err!(format!("Cannot retrieve sso_auth for {state}")), - Some(sso_auth) => sso_auth, + let Some(mut sso_auth) = SsoAuth::find(&state, conn).await else { + err!(format!("Cannot retrieve sso_auth for {state}")) }; // Browser-binding check // The cookie was set on /connect/authorize and must come from the same browser that initiated the flow. - let cookie_value = cookies.get(SSO_BINDING_COOKIE).map(|c| c.value().to_string()); + let cookie_value = cookies.get(SSO_BINDING_COOKIE).map(|c| c.value().to_owned()); let provided_hash = cookie_value.as_deref().map(|v| crypto::sha256_hex(v.as_bytes())); match (sso_auth.binding_hash.as_deref(), provided_hash.as_deref()) { (Some(expected), Some(actual)) if crypto::ct_eq(expected, actual) => {} diff --git a/src/api/mod.rs b/src/api/mod.rs index ecdf9408..3c85d17e 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -9,6 +9,7 @@ mod web; use rocket::serde::json::Json; use serde_json::Value; +use crate::CONFIG; pub use crate::api::{ admin::catchers as admin_catchers, admin::routes as admin_routes, @@ -33,10 +34,9 @@ pub use crate::api::{ web::static_files, }; use crate::db::{ - models::{OrgPolicy, OrgPolicyType, User}, DbConn, + models::{OrgPolicy, OrgPolicyType, User}, }; -use crate::CONFIG; // Type aliases for API methods results pub type ApiResult = Result; @@ -74,6 +74,7 @@ impl PasswordOrOtpData { } } +#[expect(clippy::struct_excessive_bools, reason = "Bitwarden clients expect the data in this specific format")] #[derive(Debug, Default, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct MasterPasswordPolicy { diff --git a/src/api/notifications.rs b/src/api/notifications.rs index b1d64472..27fa1fd6 100644 --- a/src/api/notifications.rs +++ b/src/api/notifications.rs @@ -6,17 +6,17 @@ use std::{ use chrono::{NaiveDateTime, Utc}; use rmpv::Value; -use rocket::{futures::StreamExt, Route}; +use rocket::{Route, futures::StreamExt}; use rocket_ws::{Message, WebSocket}; use tokio::sync::mpsc::Sender; use crate::{ + CONFIG, Error, auth::{ClientIp, WsAccessTokenHeader}, db::{ - models::{AuthRequestId, Cipher, CollectionId, Device, DeviceId, Folder, PushId, Send as DbSend, User, UserId}, DbConn, + models::{AuthRequestId, Cipher, CollectionId, Device, DeviceId, Folder, PushId, Send as DbSend, User, UserId}, }, - Error, CONFIG, }; pub static WS_USERS: LazyLock> = LazyLock::new(|| { @@ -102,7 +102,7 @@ impl Drop for WSAnonymousEntryMapGuard { } } -#[allow(tail_expr_drop_order)] +#[expect(tail_expr_drop_order)] #[get("/hub?")] fn websockets_hub<'r>( ws: WebSocket, @@ -186,7 +186,7 @@ fn websockets_hub<'r>( }) } -#[allow(tail_expr_drop_order)] +#[expect(tail_expr_drop_order)] #[get("/anonymous-hub?")] fn anonymous_websockets_hub<'r>(ws: WebSocket, token: String, ip: ClientIp) -> Result { info!("Accepting Anonymous Rocket WS connection from {}", ip.ip); @@ -268,14 +268,15 @@ fn serialize(val: &Value) -> Vec { let mut len_buf: Vec = Vec::new(); loop { - let mut size_part = size & 0x7f; + #[expect(clippy::cast_possible_truncation, reason = "masked to 7 bits, fits u8")] + let mut size_part = (size & 0x7f) as u8; size >>= 7; if size > 0 { size_part |= 0x80; } - len_buf.push(size_part as u8); + len_buf.push(size_part); if size == 0 { break; @@ -329,7 +330,7 @@ pub struct WebSocketUsers { impl WebSocketUsers { async fn send_update(&self, user_id: &UserId, data: &[u8]) { if let Some(user) = self.map.get(user_id.as_ref()).map(|v| v.clone()) { - for (_, sender) in user.iter() { + for (_, sender) in &user { if let Err(e) = sender.send(Message::binary(data)).await { error!("Error sending WS update {e}"); } @@ -538,10 +539,10 @@ pub struct AnonymousWebSocketSubscriptions { impl AnonymousWebSocketSubscriptions { async fn send_update(&self, token: &str, data: &[u8]) { - if let Some(sender) = self.map.get(token).map(|v| v.clone()) { - if let Err(e) = sender.send(Message::binary(data)).await { - error!("Error sending WS update {e}"); - } + if let Some(sender) = self.map.get(token).map(|v| v.clone()) + && let Err(e) = sender.send(Message::binary(data)).await + { + error!("Error sending WS update {e}"); } } @@ -582,7 +583,7 @@ fn create_update(payload: Vec<(Value, Value)>, ut: UpdateType, acting_device_id: V::Nil, "ReceiveMessage".into(), V::Array(vec![V::Map(vec![ - ("ContextId".into(), acting_device_id.map(|v| v.to_string().into()).unwrap_or_else(|| V::Nil)), + ("ContextId".into(), acting_device_id.map_or(V::Nil, |v| v.to_string().into())), ("Type".into(), (ut as i32).into()), ("Payload".into(), payload.into()), ])]), diff --git a/src/api/push.rs b/src/api/push.rs index e3ff1383..e87a0985 100644 --- a/src/api/push.rs +++ b/src/api/push.rs @@ -4,21 +4,21 @@ use std::{ }; use reqwest::{ - header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE}, Method, + header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE}, }; use serde_json::Value; use tokio::sync::RwLock; use crate::{ + CONFIG, api::{ApiResult, EmptyResult, UpdateType}, db::{ - models::{AuthRequestId, Cipher, Device, Folder, PushId, Send, User, UserId}, DbConn, + models::{AuthRequestId, Cipher, Device, Folder, PushId, Send, User, UserId}, }, http_client::make_http_request, util::{format_date, get_uuid}, - CONFIG, }; #[derive(Deserialize)] @@ -74,9 +74,9 @@ async fn get_auth_api_token() -> ApiResult { }; let mut api_token = API_TOKEN.write().await; - api_token.valid_until = Instant::now() - .checked_add(Duration::new((json_pushtoken.expires_in / 2) as u64, 0)) // Token valid for half the specified time - .unwrap(); + // Token valid for half the specified time + let half_expires_in = u64::from((json_pushtoken.expires_in / 2).max(0).cast_unsigned()); + api_token.valid_until = Instant::now().checked_add(Duration::from_secs(half_expires_in)).unwrap(); api_token.access_token = json_pushtoken.access_token; @@ -161,7 +161,7 @@ pub async fn push_cipher_update(ut: UpdateType, cipher: &Cipher, device: &Device // We shouldn't send a push notification on cipher update if the cipher belongs to an organization, this isn't implemented in the upstream server too. if cipher.organization_uuid.is_some() { return; - }; + } let Some(user_id) = &cipher.user_uuid else { debug!("Cipher has no uuid"); return; @@ -244,23 +244,23 @@ pub async fn push_folder_update(ut: UpdateType, folder: &Folder, device: &Device } pub async fn push_send_update(ut: UpdateType, send: &Send, device: &Device, conn: &DbConn) { - if let Some(s) = &send.user_uuid { - if Device::check_user_has_push_device(s, conn).await { - tokio::task::spawn(send_to_push_relay(json!({ + if let Some(s) = &send.user_uuid + && Device::check_user_has_push_device(s, conn).await + { + tokio::task::spawn(send_to_push_relay(json!({ + "userId": send.user_uuid, + "organizationId": null, + "deviceId": device.push_uuid, // Should be the records unique uuid of the acting device (unique uuid per user/device) + "identifier": device.uuid, // Should be the acting device id (aka uuid per device/app) + "type": ut as i32, + "payload": { + "id": send.uuid, "userId": send.user_uuid, - "organizationId": null, - "deviceId": device.push_uuid, // Should be the records unique uuid of the acting device (unique uuid per user/device) - "identifier": device.uuid, // Should be the acting device id (aka uuid per device/app) - "type": ut as i32, - "payload": { - "id": send.uuid, - "userId": send.user_uuid, - "revisionDate": format_date(&send.revision_date) - }, - "clientType": null, - "installationId": null - }))); - } + "revisionDate": format_date(&send.revision_date) + }, + "clientType": null, + "installationId": null + }))); } } @@ -296,7 +296,7 @@ async fn send_to_push_relay(notification_data: Value) { .await { error!("An error occurred while sending a send update to the push relay: {e}"); - }; + } } pub async fn push_auth_request(user_id: &UserId, auth_request_id: &str, device: &Device, conn: &DbConn) { diff --git a/src/api/web.rs b/src/api/web.rs index 0ae9c7db..771a08b0 100644 --- a/src/api/web.rs +++ b/src/api/web.rs @@ -1,21 +1,21 @@ use std::path::{Path, PathBuf}; use rocket::{ + Catcher, Route, fs::NamedFile, http::ContentType, - response::{content::RawCss as Css, content::RawHtml as Html, Redirect}, + response::{Redirect, content::RawCss as Css, content::RawHtml as Html}, serde::json::Json, - Catcher, Route, }; use serde_json::Value; use crate::{ - api::{core::now, ApiResult, EmptyResult}, + CONFIG, + api::{ApiResult, EmptyResult, core::now}, auth::decode_file_download, db::models::{AttachmentId, CipherId}, error::Error, util::Cached, - CONFIG, }; pub fn routes() -> Vec { @@ -28,7 +28,7 @@ pub fn routes() -> Vec { #[cfg(debug_assertions)] if CONFIG.reload_templates() { - routes.append(&mut routes![_static_files_dev]); + routes.append(&mut routes![static_files_dev]); } routes @@ -197,7 +197,7 @@ fn alive_head(_conn: DbConn) -> EmptyResult { // NOTE: Do not forget to add any new files added to the `static_files` function below! #[cfg(debug_assertions)] #[get("/vw_static/", rank = 1)] -pub async fn _static_files_dev(filename: PathBuf) -> Option { +pub async fn static_files_dev(filename: PathBuf) -> Option { warn!("LOADING STATIC FILES FROM DISK"); let file = filename.to_str().unwrap_or_default(); let ext = filename.extension().unwrap_or_default(); @@ -210,7 +210,7 @@ pub async fn _static_files_dev(filename: PathBuf) -> Option { if let Ok(path) = path { return NamedFile::open(path).await.ok(); - }; + } None } diff --git a/src/auth.rs b/src/auth.rs index 06bd9c22..ba58501d 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -5,13 +5,14 @@ use std::{ }; use chrono::{DateTime, TimeDelta, Utc}; -use jsonwebtoken::{errors::ErrorKind, Algorithm, DecodingKey, EncodingKey, Header}; +use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, errors::ErrorKind}; use num_traits::FromPrimitive; use openssl::rsa::Rsa; use serde::de::DeserializeOwned; use serde::ser::Serialize; use crate::{ + CONFIG, api::ApiResult, config::PathType, db::models::{ @@ -19,7 +20,7 @@ use crate::{ OrganizationId, SendFileId, SendId, UserId, }, error::Error, - sso, CONFIG, + sso, }; const JWT_ALGORITHM: Algorithm = Algorithm::RS256; @@ -226,7 +227,7 @@ impl LoginJwtClaims { // let orgmanager: Vec<_> = orgs.iter().filter(|o| o.atype == 3).map(|o| o.org_uuid.clone()).collect(); if exp <= (now + *BW_EXPIRATION).timestamp() { - warn!("Raise access_token lifetime to more than 5min.") + warn!("Raise access_token lifetime to more than 5min."); } // Create the JWT claims struct, to send to the client @@ -253,7 +254,7 @@ impl LoginJwtClaims { sstamp: user.security_stamp.clone(), device: device.uuid.clone(), devicetype: DeviceType::from_i32(device.atype).to_string(), - client_id: client_id.unwrap_or("undefined".to_string()), + client_id: client_id.unwrap_or("undefined".to_owned()), scope, amr: vec!["Application".into()], } @@ -506,7 +507,7 @@ pub fn generate_admin_claims() -> BasicJwtClaims { nbf: time_now.timestamp(), exp: (time_now + TimeDelta::try_minutes(CONFIG.admin_session_lifetime()).unwrap()).timestamp(), iss: JWT_ADMIN_ISSUER.to_string(), - sub: "admin_panel".to_string(), + sub: "admin_panel".to_owned(), } } @@ -529,8 +530,8 @@ use rocket::{ }; use crate::db::{ - models::{Collection, Device, Membership, MembershipStatus, MembershipType, User, UserStampException}, DbConn, + models::{Collection, Device, Membership, MembershipStatus, MembershipType, User, UserStampException}, }; pub struct Host { @@ -548,7 +549,7 @@ impl<'r> FromRequest<'r> for Host { let host = if CONFIG.domain_set() { CONFIG.domain() } else if let Some(referer) = headers.get_one("Referer") { - referer.to_string() + referer.to_owned() } else { // Try to guess from the headers let protocol = if let Some(proto) = headers.get_one("X-Forwarded-Proto") { @@ -584,13 +585,15 @@ impl<'r> FromRequest<'r> for ClientHeaders { type Error = &'static str; async fn from_request(request: &'r Request<'_>) -> Outcome { - let ip = match ClientIp::from_request(request).await { - Outcome::Success(ip) => ip, - _ => err_handler!("Error getting Client IP"), + let Outcome::Success(ip) = ClientIp::from_request(request).await else { + err_handler!("Error getting Client IP") }; - // When unknown or unable to parse, return 14, which is 'Unknown Browser' - let device_type: i32 = - request.headers().get_one("device-type").map(|d| d.parse().unwrap_or(14)).unwrap_or_else(|| 14); + // When unknown or unable to parse, return 'UnknownBrowser' + let device_type: i32 = request + .headers() + .get_one("device-type") + .and_then(|d| d.parse().ok()) + .unwrap_or(DeviceType::UnknownBrowser as i32); Outcome::Success(ClientHeaders { device_type, @@ -614,18 +617,19 @@ impl<'r> FromRequest<'r> for Headers { let headers = request.headers(); let host = try_outcome!(Host::from_request(request).await).host; - let ip = match ClientIp::from_request(request).await { - Outcome::Success(ip) => ip, - _ => err_handler!("Error getting Client IP"), + let Outcome::Success(ip) = ClientIp::from_request(request).await else { + err_handler!("Error getting Client IP") }; // Get access_token - let access_token: &str = match headers.get_one("Authorization") { - Some(a) => match a.rsplit("Bearer ").next() { - Some(split) => split, - None => err_handler!("No access token provided"), - }, - None => err_handler!("No access token provided"), + let access_token: &str = if let Some(a) = headers.get_one("Authorization") { + if let Some(split) = a.rsplit("Bearer ").next() { + split + } else { + err_handler!("No access token provided") + } + } else { + err_handler!("No access token provided") }; // Check JWT token is valid and get device and user from it @@ -636,9 +640,8 @@ impl<'r> FromRequest<'r> for Headers { let device_id = claims.device; let user_id = claims.sub; - let conn = match DbConn::from_request(request).await { - Outcome::Success(conn) => conn, - _ => err_handler!("Error getting DB"), + let Outcome::Success(conn) = DbConn::from_request(request).await else { + err_handler!("Error getting DB") }; let Some(device) = Device::find_by_uuid_and_user(&device_id, &user_id, &conn).await else { @@ -669,7 +672,7 @@ impl<'r> FromRequest<'r> for Headers { error!("Error updating user: {e:#?}"); } err_handler!("Stamp exception is expired") - } else if !stamp_exception.routes.contains(¤t_route.to_string()) { + } else if !stamp_exception.routes.contains(¤t_route.to_owned()) { err_handler!("Invalid security stamp: Current route and exception route do not match") } else if stamp_exception.security_stamp != claims.sstamp { err_handler!("Invalid security stamp for matched stamp exception") @@ -757,9 +760,8 @@ impl<'r> FromRequest<'r> for OrgHeaders { match url_org_id { Some(org_id) if uuid::Uuid::parse_str(&org_id).is_ok() => { - let conn = match DbConn::from_request(request).await { - Outcome::Success(conn) => conn, - _ => err_handler!("Error getting DB"), + let Outcome::Success(conn) = DbConn::from_request(request).await else { + err_handler!("Error getting DB") }; let user = headers.user; @@ -831,16 +833,16 @@ impl<'r> FromRequest<'r> for AdminHeaders { // but there could be cases where it is a query value. // First check the path, if this is not a valid uuid, try the query values. fn get_col_id(request: &Request<'_>) -> Option { - if let Some(Ok(col_id)) = request.param::(3) { - if uuid::Uuid::parse_str(&col_id).is_ok() { - return Some(col_id.into()); - } + if let Some(Ok(col_id)) = request.param::(3) + && uuid::Uuid::parse_str(&col_id).is_ok() + { + return Some(col_id.into()); } - if let Some(Ok(col_id)) = request.query_value::("collectionId") { - if uuid::Uuid::parse_str(&col_id).is_ok() { - return Some(col_id.into()); - } + if let Some(Ok(col_id)) = request.query_value::("collectionId") + && uuid::Uuid::parse_str(&col_id).is_ok() + { + return Some(col_id.into()); } None @@ -864,18 +866,16 @@ impl<'r> FromRequest<'r> for ManagerHeaders { async fn from_request(request: &'r Request<'_>) -> Outcome { let headers = try_outcome!(OrgHeaders::from_request(request).await); if headers.is_confirmed_and_manager() { - match get_col_id(request) { - Some(col_id) => { - let conn = match DbConn::from_request(request).await { - Outcome::Success(conn) => conn, - _ => err_handler!("Error getting DB"), - }; - - if !Collection::is_coll_manageable_by_user(&col_id, &headers.membership.user_uuid, &conn).await { - err_handler!("The current user isn't a manager for this collection") - } + if let Some(col_id) = get_col_id(request) { + let Outcome::Success(conn) = DbConn::from_request(request).await else { + err_handler!("Error getting DB") + }; + + if !Collection::is_coll_manageable_by_user(&col_id, &headers.membership.user_uuid, &conn).await { + err_handler!("The current user isn't a manager for this collection") } - _ => err_handler!("Error getting the collection id"), + } else { + err_handler!("Error getting the collection id") } Outcome::Success(Self { @@ -1036,7 +1036,7 @@ impl From for Headers { // // Client IP address detection // - +#[derive(Copy, Clone)] pub struct ClientIp { pub ip: IpAddr, } @@ -1068,6 +1068,7 @@ impl<'r> FromRequest<'r> for ClientIp { } } +#[derive(Copy, Clone)] pub struct Secure { pub https: bool, } @@ -1153,15 +1154,14 @@ pub enum AuthMethod { impl AuthMethod { pub fn scope(&self) -> String { match self { - AuthMethod::OrgApiKey => "api.organization".to_string(), - AuthMethod::Password => "api offline_access".to_string(), - AuthMethod::Sso => "api offline_access".to_string(), - AuthMethod::UserApiKey => "api".to_string(), + AuthMethod::OrgApiKey => "api.organization".to_owned(), + AuthMethod::UserApiKey => "api".to_owned(), + AuthMethod::Password | AuthMethod::Sso => "api offline_access".to_owned(), } } pub fn scope_vec(&self) -> Vec { - self.scope().split_whitespace().map(str::to_string).collect() + self.scope().split_whitespace().map(str::to_owned).collect() } pub fn check_scope(&self, scope: Option<&String>) -> ApiResult { @@ -1274,17 +1274,15 @@ pub async fn refresh_tokens( }; // Get device by refresh token - let mut device = match Device::find_by_refresh_token(&refresh_claims.device_token, conn).await { - None => err!("Invalid refresh token"), - Some(device) => device, + let Some(mut device) = Device::find_by_refresh_token(&refresh_claims.device_token, conn).await else { + err!("Invalid refresh token") }; // Save to update `updated_at`. device.save(true, conn).await?; - let user = match User::find_by_uuid(&device.user_uuid, conn).await { - None => err!("Impossible to find user"), - Some(user) => user, + let Some(user) = User::find_by_uuid(&device.user_uuid, conn).await else { + err!("Impossible to find user") }; let auth_tokens = match refresh_claims.sub { diff --git a/src/config.rs b/src/config.rs index 770fb553..5c826fe9 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,8 +3,8 @@ use std::{ fmt, process::exit, sync::{ - atomic::{AtomicBool, Ordering}, LazyLock, RwLock, + atomic::{AtomicBool, Ordering}, }, }; @@ -16,8 +16,8 @@ use crate::{ error::Error, storage, util::{ - get_active_web_release, get_env, get_env_bool, is_valid_email, parse_experimental_client_feature_flags, - FeatureFlagFilter, + FeatureFlagFilter, get_active_web_release, get_env, get_env_bool, is_valid_email, + parse_experimental_client_feature_flags, }, }; @@ -27,10 +27,10 @@ static CONFIG_FILE: LazyLock = LazyLock::new(|| { }); static CONFIG_FILE_PARENT_DIR: LazyLock = - LazyLock::new(|| storage::parent(&CONFIG_FILE).unwrap_or_else(|| "data".to_string())); + LazyLock::new(|| storage::parent(&CONFIG_FILE).unwrap_or_else(|| "data".to_owned())); static CONFIG_FILENAME: LazyLock = - LazyLock::new(|| storage::file_name(&CONFIG_FILE).unwrap_or_else(|| "config.json".to_string())); + LazyLock::new(|| storage::file_name(&CONFIG_FILE).unwrap_or_else(|| "config.json".to_owned())); pub static SKIP_CONFIG_VALIDATION: AtomicBool = AtomicBool::new(false); @@ -360,13 +360,7 @@ macro_rules! make_config { )+)+ pub fn prepare_json(&self) -> serde_json::Value { - let (def, cfg, overridden) = { - // Lock the inner as short as possible and clone what is needed to prevent deadlocks - let inner = &self.inner.read().unwrap(); - (inner._env.build(), inner.config.clone(), inner._overrides.clone()) - }; - - fn _get_form_type(rust_type: &'static str) -> &'static str { + fn get_form_type(rust_type: &'static str) -> &'static str { match rust_type { "Pass" => "password", "String" => "text", @@ -375,7 +369,7 @@ macro_rules! make_config { } } - fn _get_doc(doc_str: &'static str) -> ElementDoc { + fn get_doc(doc_str: &'static str) -> ElementDoc { let mut split = doc_str.split("|>").map(str::trim); ElementDoc { name: split.next().unwrap_or_default(), @@ -383,6 +377,12 @@ macro_rules! make_config { } } + let (def, cfg, overridden) = { + // Lock the inner as short as possible and clone what is needed to prevent deadlocks + let inner = &self.inner.read().unwrap(); + (inner._env.build(), inner.config.clone(), inner._overrides.clone()) + }; + let data: Vec = vec![ $( // This repetition is for each group GroupData { @@ -397,8 +397,8 @@ macro_rules! make_config { name: stringify!($name), value: serde_json::to_value(&cfg.$name).unwrap_or_default(), default: serde_json::to_value(&def.$name).unwrap_or_default(), - r#type: _get_form_type(stringify!($ty)), - doc: _get_doc(concat!($($doc),+)), + r#type: get_form_type(stringify!($ty)), + doc: get_doc(concat!($($doc),+)), overridden: overridden.contains(&pastey::paste!(stringify!([<$name:upper>]))), }, )+], // End of elements repetition @@ -408,9 +408,31 @@ macro_rules! make_config { } pub fn get_support_json(&self) -> serde_json::Value { + /// We map over the string and remove all alphanumeric, _ and - characters. + /// This is the fastest way (within micro-seconds) instead of using a regex (which takes mili-seconds) + fn privacy_mask(value: &str) -> String { + let mut n: u16 = 0; + let mut colon_match = false; + value + .chars() + .map(|c| { + n += 1; + match c { + ':' if n <= 11 => { + colon_match = true; + c + } + '/' if n <= 13 && colon_match => c, + ',' => c, + _ => '*', + } + }) + .collect::() + } + // Define which config keys need to be masked. // Pass types will always be masked and no need to put them in the list. - // Besides Pass, only String types will be masked via _privacy_mask. + // Besides Pass, only String types will be masked via privacy_mask. const PRIVACY_CONFIG: &[&str] = &[ "allowed_connect_src", "allowed_iframe_ancestors", @@ -437,28 +459,6 @@ macro_rules! make_config { inner.config.clone() }; - /// We map over the string and remove all alphanumeric, _ and - characters. - /// This is the fastest way (within micro-seconds) instead of using a regex (which takes mili-seconds) - fn _privacy_mask(value: &str) -> String { - let mut n: u16 = 0; - let mut colon_match = false; - value - .chars() - .map(|c| { - n += 1; - match c { - ':' if n <= 11 => { - colon_match = true; - c - } - '/' if n <= 13 && colon_match => c, - ',' => c, - _ => '*', - } - }) - .collect::() - } - serde_json::Value::Object({ let mut json = serde_json::Map::new(); $($( @@ -468,7 +468,7 @@ macro_rules! make_config { for mask_key in PRIVACY_CONFIG { if let Some(value) = json.get_mut(*mask_key) { if let Some(s) = value.as_str() { - *value = _privacy_mask(s).into(); + *value = privacy_mask(s).into(); } } } @@ -502,7 +502,7 @@ macro_rules! make_config { make_config! { folders { /// Data folder |> Main data folder - data_folder: String, false, def, "data".to_string(); + data_folder: String, false, def, "data".to_owned(); /// Database URL database_url: String, false, auto, |c| format!("sqlite://{}", storage::join_path(&c.data_folder, "db.sqlite3")); /// Icon cache folder @@ -518,7 +518,7 @@ make_config! { /// Session JWT key rsa_key_filename: String, false, auto, |c| storage::join_path(&c.data_folder, "rsa_key"); /// Web vault folder - web_vault_folder: String, false, def, "web-vault/".to_string(); + web_vault_folder: String, false, def, "web-vault/".to_owned(); }, ws { /// Enable websocket notifications @@ -528,9 +528,9 @@ make_config! { /// Enable push notifications push_enabled: bool, false, def, false; /// Push relay uri - push_relay_uri: String, false, def, "https://push.bitwarden.com".to_string(); + push_relay_uri: String, false, def, "https://push.bitwarden.com".to_owned(); /// Push identity uri - push_identity_uri: String, false, def, "https://identity.bitwarden.com".to_string(); + push_identity_uri: String, false, def, "https://identity.bitwarden.com".to_owned(); /// Installation id |> The installation id from https://bitwarden.com/host push_installation_id: Pass, false, def, String::new(); /// Installation key |> The installation key from https://bitwarden.com/host @@ -542,38 +542,38 @@ make_config! { job_poll_interval_ms: u64, false, def, 30_000; /// Send purge schedule |> Cron schedule of the job that checks for Sends past their deletion date. /// Defaults to hourly. Set blank to disable this job. - send_purge_schedule: String, false, def, "0 5 * * * *".to_string(); + send_purge_schedule: String, false, def, "0 5 * * * *".to_owned(); /// Trash purge schedule |> Cron schedule of the job that checks for trashed items to delete permanently. /// Defaults to daily. Set blank to disable this job. - trash_purge_schedule: String, false, def, "0 5 0 * * *".to_string(); + trash_purge_schedule: String, false, def, "0 5 0 * * *".to_owned(); /// Incomplete 2FA login schedule |> Cron schedule of the job that checks for incomplete 2FA logins. /// Defaults to once every minute. Set blank to disable this job. - incomplete_2fa_schedule: String, false, def, "30 * * * * *".to_string(); + incomplete_2fa_schedule: String, false, def, "30 * * * * *".to_owned(); /// Emergency notification reminder schedule |> Cron schedule of the job that sends expiration reminders to emergency access grantors. /// Defaults to hourly. (3 minutes after the hour) Set blank to disable this job. - emergency_notification_reminder_schedule: String, false, def, "0 3 * * * *".to_string(); + emergency_notification_reminder_schedule: String, false, def, "0 3 * * * *".to_owned(); /// Emergency request timeout schedule |> Cron schedule of the job that grants emergency access requests that have met the required wait time. /// Defaults to hourly. (7 minutes after the hour) Set blank to disable this job. - emergency_request_timeout_schedule: String, false, def, "0 7 * * * *".to_string(); + emergency_request_timeout_schedule: String, false, def, "0 7 * * * *".to_owned(); /// Event cleanup schedule |> Cron schedule of the job that cleans old events from the event table. /// Defaults to daily. Set blank to disable this job. - event_cleanup_schedule: String, false, def, "0 10 0 * * *".to_string(); + event_cleanup_schedule: String, false, def, "0 10 0 * * *".to_owned(); /// Auth Request cleanup schedule |> Cron schedule of the job that cleans old auth requests from the auth request. /// Defaults to every minute. Set blank to disable this job. - auth_request_purge_schedule: String, false, def, "30 * * * * *".to_string(); + auth_request_purge_schedule: String, false, def, "30 * * * * *".to_owned(); /// Duo Auth context cleanup schedule |> Cron schedule of the job that cleans expired Duo contexts from the database. Does nothing if Duo MFA is disabled or set to use the legacy iframe prompt. /// Defaults to once every minute. Set blank to disable this job. - duo_context_purge_schedule: String, false, def, "30 * * * * *".to_string(); + duo_context_purge_schedule: String, false, def, "30 * * * * *".to_owned(); /// Purge incomplete SSO auth. |> Cron schedule of the job that cleans leftover auth in db due to incomplete SSO login. /// Defaults to daily. Set blank to disable this job. - purge_incomplete_sso_auth: String, false, def, "0 20 0 * * *".to_string(); + purge_incomplete_sso_auth: String, false, def, "0 20 0 * * *".to_owned(); }, /// General settings settings { /// Domain URL |> This needs to be set to the URL used to access the server, including 'http[s]://' /// and port, if it's different than the default. Some server functions don't work correctly without this value - domain: String, true, def, "http://localhost".to_string(); + domain: String, true, def, "http://localhost".to_owned(); /// Domain Set |> Indicates if the domain is set by the admin. Otherwise the default will be used. domain_set: bool, false, def, false; /// Domain origin |> Domain URL origin (in https://example.com:8443/path, https://example.com:8443 is the origin) @@ -653,7 +653,7 @@ make_config! { admin_token: Pass, true, option; /// Invitation organization name |> Name shown in the invitation emails that don't come from a specific organization - invitation_org_name: String, true, def, "Vaultwarden".to_string(); + invitation_org_name: String, true, def, "Vaultwarden".to_owned(); /// Events days retain |> Number of days to retain events stored in the database. If unset, events are kept indefinitely. events_days_retain: i64, false, option; @@ -663,7 +663,7 @@ make_config! { advanced { /// Client IP header |> If not present, the remote IP is used. /// Set to the string "none" (without quotes), to disable any headers and just use the remote IP - ip_header: String, true, def, "X-Real-IP".to_string(); + ip_header: String, true, def, "X-Real-IP".to_owned(); /// Internal IP header property, used to avoid recomputing each time _ip_header_enabled: bool, false, generated, |c| &c.ip_header.trim().to_lowercase() != "none"; /// Icon service |> The predefined icon services are: internal, bitwarden, duckduckgo, google. @@ -672,7 +672,7 @@ make_config! { /// `internal` refers to Vaultwarden's built-in icon fetching implementation. If an external /// service is set, an icon request to Vaultwarden will return an HTTP redirect to the /// corresponding icon at the external service. - icon_service: String, false, def, "internal".to_string(); + icon_service: String, false, def, "internal".to_owned(); /// _icon_service_url _icon_service_url: String, false, generated, |c| generate_icon_service_url(&c.icon_service); /// _icon_service_csp @@ -723,14 +723,14 @@ make_config! { /// Enable extended logging extended_logging: bool, false, def, true; /// Log timestamp format - log_timestamp_format: String, true, def, "%Y-%m-%d %H:%M:%S.%3f".to_string(); + log_timestamp_format: String, true, def, "%Y-%m-%d %H:%M:%S.%3f".to_owned(); /// Enable the log to output to Syslog use_syslog: bool, false, def, false; /// Log file path log_file: String, false, option; /// Log level |> Valid values are "trace", "debug", "info", "warn", "error" and "off" /// For a specific module append it as a comma separated value "info,path::to::module=debug" - log_level: String, false, def, "info".to_string(); + log_level: String, false, def, "info".to_owned(); /// Enable DB WAL |> Turning this off might lead to worse performance, but might help if using vaultwarden on some exotic filesystems, /// that do not support WAL. Please make sure you read project wiki on the topic before changing this setting. @@ -812,7 +812,7 @@ make_config! { /// Authority Server |> Base url of the OIDC provider discovery endpoint (without `/.well-known/openid-configuration`) sso_authority: String, true, def, String::new(); /// Authorization request scopes |> List the of the needed scope (`openid` is implicit) - sso_scopes: String, true, def, "email profile".to_string(); + sso_scopes: String, true, def, "email profile".to_owned(); /// Authorization request extra parameters sso_authorize_extra_params: String, true, def, String::new(); /// Use PKCE during Authorization flow @@ -880,7 +880,7 @@ make_config! { /// From Address smtp_from: String, true, def, String::new(); /// From Name - smtp_from_name: String, true, def, "Vaultwarden".to_string(); + smtp_from_name: String, true, def, "Vaultwarden".to_owned(); /// Username smtp_username: String, true, option; /// Password @@ -930,13 +930,13 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> { let file_path = url.strip_prefix("sqlite://").unwrap_or(url); if file_path.contains('/') { let path = std::path::Path::new(file_path); - if let Some(parent) = path.parent() { - if !parent.is_dir() { - err!(format!( - "SQLite database directory `{}` does not exist or is not a directory", - parent.display() - )); - } + if let Some(parent) = path.parent() + && !parent.is_dir() + { + err!(format!( + "SQLite database directory `{}` does not exist or is not a directory", + parent.display() + )); } } } @@ -959,10 +959,10 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> { err!(format!("`DATABASE_MIN_CONNS` must be smaller than or equal to `DATABASE_MAX_CONNS`.",)); } - if let Some(log_file) = &cfg.log_file { - if std::fs::OpenOptions::new().append(true).create(true).open(log_file).is_err() { - err!("Unable to write to log file", log_file); - } + if let Some(log_file) = &cfg.log_file + && std::fs::OpenOptions::new().append(true).create(true).open(log_file).is_err() + { + err!("Unable to write to log file", log_file); } let dom = cfg.domain.to_lowercase(); @@ -975,7 +975,9 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> { let connect_src = cfg.allowed_connect_src.to_lowercase(); for url in connect_src.split_whitespace() { if !url.starts_with("https://") || Url::parse(url).is_err() { - err!("ALLOWED_CONNECT_SRC variable contains one or more invalid URLs. Only FQDN's starting with https are allowed"); + err!( + "ALLOWED_CONNECT_SRC variable contains one or more invalid URLs. Only FQDN's starting with https are allowed" + ); } } @@ -991,11 +993,12 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> { err!("`ORG_CREATION_USERS` contains invalid email addresses"); } - if let Some(ref token) = cfg.admin_token { - if token.trim().is_empty() && !cfg.disable_admin_token { - println!("[WARNING] `ADMIN_TOKEN` is enabled but has an empty value, so the admin page will be disabled."); - println!("[WARNING] To enable the admin page without a token, use `DISABLE_ADMIN_TOKEN`."); - } + if let Some(ref token) = cfg.admin_token + && token.trim().is_empty() + && !cfg.disable_admin_token + { + println!("[WARNING] `ADMIN_TOKEN` is enabled but has an empty value, so the admin page will be disabled."); + println!("[WARNING] To enable the admin page without a token, use `DISABLE_ADMIN_TOKEN`."); } if cfg.push_enabled && (cfg.push_installation_id == String::new() || cfg.push_installation_key == String::new()) { @@ -1029,37 +1032,41 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> { } } - let invalid_flags = - parse_experimental_client_feature_flags(&cfg.experimental_client_feature_flags, FeatureFlagFilter::InvalidOnly); + let invalid_flags = parse_experimental_client_feature_flags( + &cfg.experimental_client_feature_flags, + &FeatureFlagFilter::InvalidOnly, + ); if !invalid_flags.is_empty() { - let feature_flags_error = format!("Unrecognized experimental client feature flags: {:?}.\n\ + let feature_flags_error = format!( + "Unrecognized experimental client feature flags: {invalid_flags:?}.\n\ Please ensure all feature flags are spelled correctly and that they are supported in this version.\n\ - Supported flags: {:?}\n", invalid_flags, SUPPORTED_FEATURE_FLAGS); + Supported flags: {SUPPORTED_FEATURE_FLAGS:?}\n" + ); if on_update { err!(feature_flags_error); - } else { - println!("[WARNING] {feature_flags_error}"); } + println!("[WARNING] {feature_flags_error}"); } + #[expect(clippy::items_after_statements, reason = "Keep this close to where it is used")] const MAX_FILESIZE_KB: i64 = i64::MAX >> 10; - if let Some(limit) = cfg.user_attachment_limit { - if !(0i64..=MAX_FILESIZE_KB).contains(&limit) { - err!("`USER_ATTACHMENT_LIMIT` is out of bounds"); - } + if let Some(limit) = cfg.user_attachment_limit + && !(0i64..=MAX_FILESIZE_KB).contains(&limit) + { + err!("`USER_ATTACHMENT_LIMIT` is out of bounds"); } - if let Some(limit) = cfg.org_attachment_limit { - if !(0i64..=MAX_FILESIZE_KB).contains(&limit) { - err!("`ORG_ATTACHMENT_LIMIT` is out of bounds"); - } + if let Some(limit) = cfg.org_attachment_limit + && !(0i64..=MAX_FILESIZE_KB).contains(&limit) + { + err!("`ORG_ATTACHMENT_LIMIT` is out of bounds"); } - if let Some(limit) = cfg.user_send_limit { - if !(0i64..=MAX_FILESIZE_KB).contains(&limit) { - err!("`USER_SEND_LIMIT` is out of bounds"); - } + if let Some(limit) = cfg.user_send_limit + && !(0i64..=MAX_FILESIZE_KB).contains(&limit) + { + err!("`USER_SEND_LIMIT` is out of bounds"); } if cfg._enable_duo @@ -1087,7 +1094,9 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> { if let Some(yubico_server) = &cfg.yubico_server { let yubico_server = yubico_server.to_lowercase(); if !yubico_server.starts_with("https://") { - err!("`YUBICO_SERVER` must be a valid URL and start with 'https://'. Either unset this variable or provide a valid URL.") + err!( + "`YUBICO_SERVER` must be a valid URL and start with 'https://'. Either unset this variable or provide a valid URL." + ) } } } @@ -1139,7 +1148,9 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> { } if cfg.smtp_username.is_some() != cfg.smtp_password.is_some() { - err!("Both `SMTP_USERNAME` and `SMTP_PASSWORD` need to be set to enable email authentication without `USE_SENDMAIL`") + err!( + "Both `SMTP_USERNAME` and `SMTP_PASSWORD` need to be set to enable email authentication without `USE_SENDMAIL`" + ) } } @@ -1300,7 +1311,7 @@ fn extract_url_origin(url: &str) -> String { /// All trailing '/' chars are trimmed, even if the path is a lone '/'. fn extract_url_path(url: &str) -> String { match Url::parse(url) { - Ok(u) => u.path().trim_end_matches('/').to_string(), + Ok(u) => u.path().trim_end_matches('/').to_owned(), Err(_) => { // We already print it in the method above, no need to do it again String::new() @@ -1310,7 +1321,7 @@ fn extract_url_path(url: &str) -> String { fn generate_smtp_img_src(embed_images: bool, domain: &str) -> String { if embed_images { - "cid:".to_string() + "cid:".to_owned() } else { // normalize base_url let base_url = domain.trim_end_matches('/'); @@ -1329,10 +1340,10 @@ fn generate_sso_callback_path(domain: &str) -> String { fn generate_icon_service_url(icon_service: &str) -> String { match icon_service { "internal" => String::new(), - "bitwarden" => "https://icons.bitwarden.net/{}/icon.png".to_string(), - "duckduckgo" => "https://icons.duckduckgo.com/ip3/{}.ico".to_string(), - "google" => "https://www.google.com/s2/favicons?domain={}&sz=32".to_string(), - _ => icon_service.to_string(), + "bitwarden" => "https://icons.bitwarden.net/{}/icon.png".to_owned(), + "duckduckgo" => "https://icons.duckduckgo.com/ip3/{}.ico".to_owned(), + "google" => "https://www.google.com/s2/favicons?domain={}&sz=32".to_owned(), + _ => icon_service.to_owned(), } } @@ -1341,7 +1352,7 @@ fn generate_icon_service_csp(icon_service: &str, icon_service_url: &str) -> Stri // We split on the first '{', since that is the variable delimiter for an icon service URL. // Everything up until the first '{' should be fixed and can be used as an CSP string. let csp_string = match icon_service_url.split_once('{') { - Some((c, _)) => c.to_string(), + Some((c, _)) => c.to_owned(), None => String::new(), }; @@ -1358,12 +1369,12 @@ fn smtp_convert_deprecated_ssl_options(smtp_ssl: Option, smtp_explicit_tls println!("[DEPRECATED]: `SMTP_SSL` or `SMTP_EXPLICIT_TLS` is set. Please use `SMTP_SECURITY` instead."); } if smtp_explicit_tls.is_some() && smtp_explicit_tls.unwrap() { - return "force_tls".to_string(); + return "force_tls".to_owned(); } else if smtp_ssl.is_some() && !smtp_ssl.unwrap() { - return "off".to_string(); + return "off".to_owned(); } // Return the default `starttls` in all other cases - "starttls".to_string() + "starttls".to_owned() } pub enum PathType { @@ -1406,12 +1417,12 @@ pub const SUPPORTED_FEATURE_FLAGS: &[&str] = &[ impl Config { pub async fn load() -> Result { // Loading from env and file - let _env = ConfigBuilder::from_env(); - let _usr = ConfigBuilder::from_file().await.unwrap_or_default(); + let env = ConfigBuilder::from_env(); + let usr = ConfigBuilder::from_file().await.unwrap_or_default(); // Create merged config, config file overwrites env - let mut _overrides = Vec::new(); - let builder = _env.merge(&_usr, true, &mut _overrides); + let mut overrides = Vec::new(); + let builder = env.merge(&usr, true, &mut overrides); // Fill any missing with defaults let config = builder.build(); @@ -1424,9 +1435,9 @@ impl Config { rocket_shutdown_handle: None, templates: load_templates(&config.templates_folder), config, - _env, - _usr, - _overrides, + _env: env, + _usr: usr, + _overrides: overrides, }), }) } @@ -1472,8 +1483,8 @@ impl Config { async fn update_config_partial(&self, other: ConfigBuilder) -> Result<(), Error> { let builder = { let usr = &self.inner.read().unwrap()._usr; - let mut _overrides = Vec::new(); - usr.merge(&other, false, &mut _overrides) + let mut overrides = Vec::new(); + usr.merge(&other, false, &mut overrides) }; self.update_config(builder, false).await } @@ -1496,11 +1507,11 @@ impl Config { /// Tests whether signup is allowed for an email address, taking into /// account the signups_allowed and signups_domains_whitelist settings. pub fn is_signup_allowed(&self, email: &str) -> bool { - if !self.signups_domains_whitelist().is_empty() { + if self.signups_domains_whitelist().is_empty() { + self.signups_allowed() + } else { // The whitelist setting overrides the signups_allowed setting. self.is_email_domain_allowed(email) - } else { - self.signups_allowed() } } @@ -1621,10 +1632,10 @@ impl Config { } pub fn shutdown(&self) { - if let Ok(mut c) = self.inner.write() { - if let Some(handle) = c.rocket_shutdown_handle.take() { - handle.notify(); - } + if let Ok(mut c) = self.inner.write() + && let Some(handle) = c.rocket_shutdown_handle.take() + { + handle.notify(); } } @@ -1641,7 +1652,7 @@ impl Config { } pub fn sso_scopes_vec(&self) -> Vec { - self.sso_scopes().split_whitespace().map(str::to_string).collect() + self.sso_scopes().split_whitespace().map(str::to_owned).collect() } pub fn sso_authorize_extra_params_vec(&self) -> Vec<(String, String)> { @@ -1751,7 +1762,7 @@ fn case_helper<'reg, 'rc>( let value = param.value().clone(); if h.params().iter().skip(1).any(|x| x.value() == &value) { - h.template().map(|t| t.render(r, ctx, rc, out)).unwrap_or_else(|| Ok(())) + h.template().map_or(Ok(()), |t| t.render(r, ctx, rc, out)) } else { Ok(()) } diff --git a/src/db/mod.rs b/src/db/mod.rs index 4aafe995..ff568bd0 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -6,15 +6,15 @@ use std::{ }; use diesel::{ + Connection, RunQueryDsl, connection::SimpleConnection, r2d2::{CustomizeConnection, Pool, PooledConnection}, - Connection, RunQueryDsl, }; use rocket::{ + Request, http::Status, request::{FromRequest, Outcome}, - Request, }; use tokio::{ @@ -23,8 +23,8 @@ use tokio::{ }; use crate::{ - error::{Error, MapResult}, CONFIG, + error::{Error, MapResult}, }; // These changes are based on Rocket 0.5-rc wrapper of Diesel: https://github.com/SergioBenitez/Rocket/blob/v0.5-rc/contrib/sync_db_pools @@ -62,7 +62,7 @@ pub struct DbConnManager { impl DbConnManager { pub fn new(database_url: &str) -> Self { Self { - database_url: database_url.to_string(), + database_url: database_url.to_owned(), } } @@ -224,7 +224,7 @@ impl DbPool { // Set a global to determine the database more easily throughout the rest of the code if ACTIVE_DB_TYPE.set(conn_type).is_err() { - error!("Tried to set the active database connection type more than once.") + error!("Tried to set the active database connection type more than once."); } Ok(DbPool { @@ -279,34 +279,33 @@ impl DbConnType { #[cfg(not(sqlite))] err!("`DATABASE_URL` is a SQLite URL, but the 'sqlite' feature is not enabled") + } // No recognized scheme — assume legacy bare-path SQLite, but the database file must already exist. // This prevents misconfigured URLs (typos, quoted strings) from silently creating a new empty SQLite database. - } else { - #[cfg(sqlite)] - { - if std::path::Path::new(url).exists() { - return Ok(DbConnType::Sqlite); - } - err!(format!( - "`DATABASE_URL` does not match any known database scheme (mysql://, postgresql://, sqlite://) \ - and no existing SQLite database was found at '{url}'. \ - If you intend to use SQLite, use an explicit `sqlite://` scheme in your `DATABASE_URL`. \ - Otherwise, check your DATABASE_URL for typos or quoting issues." - )) + #[cfg(sqlite)] + { + if std::path::Path::new(url).exists() { + return Ok(DbConnType::Sqlite); } - - #[cfg(not(sqlite))] - err!("`DATABASE_URL` does not match any known database scheme (mysql://, postgresql://, sqlite://)") + err!(format!( + "`DATABASE_URL` does not match any known database scheme (mysql://, postgresql://, sqlite://) \ + and no existing SQLite database was found at '{url}'. \ + If you intend to use SQLite, use an explicit `sqlite://` scheme in your `DATABASE_URL`. \ + Otherwise, check your DATABASE_URL for typos or quoting issues." + )) } + + #[cfg(not(sqlite))] + err!("`DATABASE_URL` does not match any known database scheme (mysql://, postgresql://, sqlite://)") } pub fn get_init_stmts(&self) -> String { let init_stmts = CONFIG.database_conn_init(); - if !init_stmts.is_empty() { - init_stmts - } else { + if init_stmts.is_empty() { self.default_init_stmts() + } else { + init_stmts } } @@ -317,7 +316,7 @@ impl DbConnType { #[cfg(postgresql)] Self::Postgresql => String::new(), #[cfg(sqlite)] - Self::Sqlite => "PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;".to_string(), + Self::Sqlite => "PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;".to_owned(), } } } @@ -408,7 +407,7 @@ pub fn backup_sqlite() -> Result { use diesel::Connection; let db_url = CONFIG.database_url(); - if DbConnType::from_url(&CONFIG.database_url()).map(|t| t == DbConnType::Sqlite).unwrap_or(false) { + if DbConnType::from_url(&CONFIG.database_url()).is_ok_and(|t| t == DbConnType::Sqlite) { // Strip the sqlite:// prefix if present to get the raw file path let file_path = db_url.strip_prefix("sqlite://").unwrap_or(&db_url); // Open a read-only connection for the backup @@ -443,12 +442,12 @@ pub async fn get_sql_server_version(conn: &DbConn) -> String { postgresql,mysql { diesel::select(diesel::dsl::sql::("version();")) .get_result::(conn) - .unwrap_or_else(|_| "Unknown".to_string()) + .unwrap_or_else(|_| "Unknown".to_owned()) } sqlite { diesel::select(diesel::dsl::sql::("sqlite_version();")) .get_result::(conn) - .unwrap_or_else(|_| "Unknown".to_string()) + .unwrap_or_else(|_| "Unknown".to_owned()) } } } diff --git a/src/db/models/archive.rs b/src/db/models/archive.rs index f576e7ed..ac15fa93 100644 --- a/src/db/models/archive.rs +++ b/src/db/models/archive.rs @@ -3,8 +3,8 @@ use diesel::prelude::*; use super::{CipherId, User, UserId}; use crate::api::EmptyResult; -use crate::db::schema::archives; use crate::db::DbConn; +use crate::db::schema::archives; use crate::error::MapResult; #[derive(Identifiable, Queryable, Insertable)] diff --git a/src/db/models/attachment.rs b/src/db/models/attachment.rs index dad081bd..f1972813 100644 --- a/src/db/models/attachment.rs +++ b/src/db/models/attachment.rs @@ -6,7 +6,7 @@ use std::time::Duration; use super::{CipherId, OrganizationId, UserId}; use crate::db::schema::{attachments, ciphers}; -use crate::{config::PathType, CONFIG}; +use crate::{CONFIG, config::PathType}; use macros::IdFromParam; #[derive(Identifiable, Queryable, Insertable, AsChangeset)] diff --git a/src/db/models/cipher.rs b/src/db/models/cipher.rs index db906179..f23bb6f8 100644 --- a/src/db/models/cipher.rs +++ b/src/db/models/cipher.rs @@ -1,9 +1,9 @@ +use crate::CONFIG; use crate::db::schema::{ ciphers, ciphers_collections, collections, collections_groups, folders, folders_ciphers, groups, groups_users, users_collections, users_organizations, }; use crate::util::LowerCase; -use crate::CONFIG; use chrono::{NaiveDateTime, TimeDelta, Utc}; use derive_more::{AsRef, Deref, Display, From}; use diesel::prelude::*; @@ -91,27 +91,27 @@ impl Cipher { format!("The field Notes exceeds the maximum encrypted value length of {max_note_size} characters."); for (index, cipher) in cipher_data.iter().enumerate() { // Validate the note size and if it is exceeded return a warning - if let Some(note) = &cipher.notes { - if note.len() > max_note_size { - validation_errors - .insert(format!("Ciphers[{index}].Notes"), serde_json::to_value([&max_note_size_msg]).unwrap()); - } + if let Some(note) = &cipher.notes + && note.len() > max_note_size + { + validation_errors + .insert(format!("Ciphers[{index}].Notes"), serde_json::to_value([&max_note_size_msg]).unwrap()); } // Validate the password history if it contains `null` values and if so, return a warning if let Some(Value::Array(password_history)) = &cipher.password_history { for pwh in password_history { - if let Value::Object(pwo) = pwh { - if pwo.get("password").is_some_and(|p| !p.is_string()) { - validation_errors.insert( - format!("Ciphers[{index}].Notes"), - serde_json::to_value([ - "The password history contains a `null` value. Only strings are allowed.", - ]) - .unwrap(), - ); - break; - } + if let Value::Object(pwo) = pwh + && pwo.get("password").is_some_and(|p| !p.is_string()) + { + validation_errors.insert( + format!("Ciphers[{index}].Notes"), + serde_json::to_value([ + "The password history contains a `null` value. Only strings are allowed.", + ]) + .unwrap(), + ); + break; } } } @@ -124,9 +124,9 @@ impl Cipher { "object": "error" }); err_json!(err_json, "Import validation errors") - } else { - Ok(()) } + + Ok(()) } } @@ -149,14 +149,14 @@ impl Cipher { let mut attachments_json: Value = Value::Null; if let Some(cipher_sync_data) = cipher_sync_data { - if let Some(attachments) = cipher_sync_data.cipher_attachments.get(&self.uuid) { - if !attachments.is_empty() { - let mut attachments_json_vec = vec![]; - for attachment in attachments { - attachments_json_vec.push(attachment.to_json(host).await?); - } - attachments_json = Value::Array(attachments_json_vec); + if let Some(attachments) = cipher_sync_data.cipher_attachments.get(&self.uuid) + && !attachments.is_empty() + { + let mut attachments_json_vec = vec![]; + for attachment in attachments { + attachments_json_vec.push(attachment.to_json(host).await?); } + attachments_json = Value::Array(attachments_json_vec); } } else { let attachments = Attachment::find_by_cipher(&self.uuid, conn).await; @@ -172,12 +172,11 @@ impl Cipher { // We don't need these values at all for Organizational syncs // Skip any other database calls if this is the case and just return false. let (read_only, hide_passwords, _) = if sync_type == CipherSyncType::User { - match self.get_access_restrictions(user_uuid, cipher_sync_data, conn).await { - Some((ro, hp, mn)) => (ro, hp, mn), - None => { - error!("Cipher ownership assertion failure"); - (true, true, false) - } + if let Some((ro, hp, mn)) = self.get_access_restrictions(user_uuid, cipher_sync_data, conn).await { + (ro, hp, mn) + } else { + error!("Cipher ownership assertion failure"); + (true, true, false) } } else { (false, false, false) @@ -231,15 +230,14 @@ impl Cipher { Some(p) if p.is_string() => Some(d.data), _ => None, }) - .map(|mut d| match d.get("lastUsedDate").and_then(|l| l.as_str()) { - Some(l) => { - d["lastUsedDate"] = json!(validate_and_format_date(l)); - d - } - _ => { - d["lastUsedDate"] = json!("1970-01-01T00:00:00.000000Z"); - d - } + .map(|mut d| { + let lud = if let Some(l) = d.get("lastUsedDate").and_then(|l| l.as_str()) { + validate_and_format_date(l) + } else { + "1970-01-01T00:00:00.000000Z".to_owned() + }; + d["lastUsedDate"] = json!(lud); + d }) .collect() }) @@ -247,32 +245,30 @@ impl Cipher { // Get the type_data or a default to an empty json object '{}'. // If not passing an empty object, mobile clients will crash. - let mut type_data_json = - serde_json::from_str::>(&self.data).map(|d| d.data).unwrap_or_else(|_| { - warn!("Error parsing data field for {}", self.uuid); - Value::Object(serde_json::Map::new()) - }); + let mut type_data_json = serde_json::from_str::>(&self.data) + .inspect_err(|_| warn!("Error parsing data field for {}", self.uuid)) + .map_or_else(|_| Value::Object(serde_json::Map::new()), |d| d.data); // NOTE: This was marked as *Backwards Compatibility Code*, but as of January 2021 this is still being used by upstream // Set the first element of the Uris array as Uri, this is needed several (mobile) clients. if self.atype == 1 { // Upstream always has an `uri` key/value type_data_json["uri"] = Value::Null; - if let Some(uris) = type_data_json["uris"].as_array_mut() { - if !uris.is_empty() { - // Fix uri match values first, they are only allowed to be a number or null - // If it is a string, convert it to an int or null if that fails - for uri in &mut *uris { - if uri["match"].is_string() { - let match_value = match uri["match"].as_str().unwrap_or_default().parse::() { - Ok(n) => json!(n), - _ => Value::Null, - }; - uri["match"] = match_value; - } + if let Some(uris) = type_data_json["uris"].as_array_mut() + && !uris.is_empty() + { + // Fix uri match values first, they are only allowed to be a number or null + // If it is a string, convert it to an int or null if that fails + for uri in &mut *uris { + if uri["match"].is_string() { + let match_value = match uri["match"].as_str().unwrap_or_default().parse::() { + Ok(n) => json!(n), + _ => Value::Null, + }; + uri["match"] = match_value; } - type_data_json["uri"] = uris[0]["uri"].clone(); } + type_data_json["uri"] = uris[0]["uri"].clone(); } // Check if `passwordRevisionDate` is a valid date, else convert it @@ -285,7 +281,7 @@ impl Cipher { // This breaks at least the native mobile clients if self.atype == 2 { match type_data_json { - Value::Object(ref t) if t.get("type").is_some_and(|t| t.is_number()) => {} + Value::Object(ref t) if t.get("type").is_some_and(Value::is_number) => {} _ => { type_data_json = json!({"type": 0}); } @@ -297,9 +293,9 @@ impl Cipher { // The only way to fix this is by setting type_data_json to `null` // Opening this ssh-key in the mobile client will probably crash the client, but you can edit, save and afterwards delete it if self.atype == 5 - && (type_data_json["keyFingerprint"].as_str().is_none_or(|v| v.is_empty()) - || type_data_json["privateKey"].as_str().is_none_or(|v| v.is_empty()) - || type_data_json["publicKey"].as_str().is_none_or(|v| v.is_empty())) + && (type_data_json["keyFingerprint"].as_str().is_none_or(str::is_empty) + || type_data_json["privateKey"].as_str().is_none_or(str::is_empty) + || type_data_json["publicKey"].as_str().is_none_or(str::is_empty)) { warn!("Error parsing ssh-key, mandatory fields are invalid for {}", self.uuid); type_data_json = Value::Null; @@ -415,7 +411,7 @@ impl Cipher { match self.user_uuid { Some(ref user_uuid) => { User::update_uuid_revision(user_uuid, conn).await; - user_uuids.push(user_uuid.clone()) + user_uuids.push(user_uuid.clone()); } None => { // Belongs to Organization, need to update affected users @@ -430,11 +426,11 @@ impl Cipher { } for member in collection_users { User::update_uuid_revision(&member.user_uuid, conn).await; - user_uuids.push(member.user_uuid.clone()) + user_uuids.push(member.user_uuid.clone()); } } } - }; + } user_uuids } @@ -531,9 +527,10 @@ impl Cipher { // Remove from folder (Some(old_folder), None) => { - match FolderCipher::find_by_folder_and_cipher(&old_folder, &self.uuid, conn).await { - Some(old_folder) => old_folder.delete(conn).await, - None => err!("Couldn't move from previous folder"), + if let Some(old_folder) = FolderCipher::find_by_folder_and_cipher(&old_folder, &self.uuid, conn).await { + old_folder.delete(conn).await + } else { + err!("Couldn't move from previous folder") } } @@ -584,9 +581,8 @@ impl Cipher { if let Some(ref org_uuid) = self.organization_uuid { if let Some(cipher_sync_data) = cipher_sync_data { return cipher_sync_data.user_group_full_access_for_organizations.contains(org_uuid); - } else { - return Group::is_in_full_access_group(user_uuid, org_uuid, conn).await; } + return Group::is_in_full_access_group(user_uuid, org_uuid, conn).await; } false } @@ -628,10 +624,10 @@ impl Cipher { rows } else { let user_permissions = self.get_user_collections_access_flags(user_uuid, conn).await; - if !user_permissions.is_empty() { - user_permissions - } else { + if user_permissions.is_empty() { self.get_group_collections_access_flags(user_uuid, conn).await + } else { + user_permissions } }; @@ -657,7 +653,7 @@ impl Cipher { let mut read_only = true; let mut hide_passwords = true; let mut manage = false; - for (ro, hp, mn) in rows.iter() { + for (ro, hp, mn) in &rows { read_only &= ro; hide_passwords &= hp; manage |= mn; diff --git a/src/db/models/collection.rs b/src/db/models/collection.rs index b1f82335..679550d6 100644 --- a/src/db/models/collection.rs +++ b/src/db/models/collection.rs @@ -5,10 +5,10 @@ use super::{ CipherId, CollectionGroup, GroupUser, Membership, MembershipId, MembershipStatus, MembershipType, OrganizationId, User, UserId, }; +use crate::CONFIG; use crate::db::schema::{ ciphers_collections, collections, collections_groups, groups, groups_users, users_collections, users_organizations, }; -use crate::CONFIG; use diesel::prelude::*; use macros::UuidFromParam; @@ -74,7 +74,7 @@ impl Collection { if external_id.is_empty() { self.external_id = None; } else { - self.external_id = Some(external_id) + self.external_id = Some(external_id); } } None => self.external_id = None, @@ -208,7 +208,7 @@ impl Collection { } pub async fn update_users_revision(&self, conn: &DbConn) { - for member in Membership::find_by_collection_and_org(&self.uuid, &self.org_uuid, conn).await.iter() { + for member in &Membership::find_by_collection_and_org(&self.uuid, &self.org_uuid, conn).await { User::update_uuid_revision(&member.user_uuid, conn).await; } } @@ -597,7 +597,7 @@ impl CollectionUser { .load::(conn) .expect("Error loading users_collections") }}; - col_users.into_iter().map(|c| c.into()).collect() + col_users.into_iter().map(Into::into).collect() } pub async fn save( @@ -701,7 +701,7 @@ impl CollectionUser { .load::(conn) .expect("Error loading users_collections") }}; - col_users.into_iter().map(|c| c.into()).collect() + col_users.into_iter().map(Into::into).collect() } pub async fn find_by_collection_and_user( @@ -730,7 +730,7 @@ impl CollectionUser { } pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> EmptyResult { - for collection in CollectionUser::find_by_collection(collection_uuid, conn).await.iter() { + for collection in &CollectionUser::find_by_collection(collection_uuid, conn).await { User::update_uuid_revision(&collection.user_uuid, conn).await; } diff --git a/src/db/models/device.rs b/src/db/models/device.rs index 7364a2ec..3eb817c5 100644 --- a/src/db/models/device.rs +++ b/src/db/models/device.rs @@ -337,6 +337,7 @@ pub enum DeviceType { } impl DeviceType { + #[expect(clippy::match_same_arms, reason = "Specifically define 14 and have a fallback for new types")] pub fn from_i32(value: i32) -> DeviceType { match value { 0 => DeviceType::Android, diff --git a/src/db/models/emergency_access.rs b/src/db/models/emergency_access.rs index 5ea334a4..fea034a3 100644 --- a/src/db/models/emergency_access.rs +++ b/src/db/models/emergency_access.rs @@ -87,13 +87,12 @@ impl EmergencyAccess { User::find_by_uuid(grantee_uuid, conn).await.expect("Grantee user not found.") } else { let email = self.email.as_deref()?; - match User::find_by_mail(email, conn).await { - Some(user) => user, - None => { - // remove outstanding invitations which should not exist - Self::delete_all_by_grantee_email(email, conn).await.ok(); - return None; - } + if let Some(user) = User::find_by_mail(email, conn).await { + user + } else { + // remove outstanding invitations which should not exist + Self::delete_all_by_grantee_email(email, conn).await.ok(); + return None; } }; diff --git a/src/db/models/event.rs b/src/db/models/event.rs index bd4b2310..bea01ce1 100644 --- a/src/db/models/event.rs +++ b/src/db/models/event.rs @@ -4,7 +4,7 @@ use serde_json::Value; use super::{CipherId, CollectionId, GroupId, MembershipId, OrgPolicyId, OrganizationId, UserId}; use crate::db::schema::{event, users_organizations}; -use crate::{api::EmptyResult, db::DbConn, error::MapResult, CONFIG}; +use crate::{CONFIG, api::EmptyResult, db::DbConn, error::MapResult}; use diesel::prelude::*; // https://bitwarden.com/help/event-logs/ diff --git a/src/db/models/folder.rs b/src/db/models/folder.rs index b4cbc7ff..f63e3378 100644 --- a/src/db/models/folder.rs +++ b/src/db/models/folder.rs @@ -56,8 +56,8 @@ impl Folder { impl FolderCipher { pub fn new(folder_uuid: FolderId, cipher_uuid: CipherId) -> Self { Self { - folder_uuid, cipher_uuid, + folder_uuid, } } } diff --git a/src/db/models/group.rs b/src/db/models/group.rs index f41ad9ca..ea1e7d48 100644 --- a/src/db/models/group.rs +++ b/src/db/models/group.rs @@ -1,7 +1,7 @@ use super::{CollectionId, Membership, MembershipId, OrganizationId, User, UserId}; use crate::api::EmptyResult; -use crate::db::schema::{collections, collections_groups, groups, groups_users, users_organizations}; use crate::db::DbConn; +use crate::db::schema::{collections, collections_groups, groups, groups_users, users_organizations}; use crate::error::MapResult; use chrono::{NaiveDateTime, Utc}; use derive_more::{AsRef, Deref, Display, From}; @@ -288,12 +288,12 @@ impl Group { } pub async fn update_revision(uuid: &GroupId, conn: &DbConn) { - if let Err(e) = Self::_update_revision(uuid, &Utc::now().naive_utc(), conn).await { + if let Err(e) = Self::update_revision_impl(uuid, &Utc::now().naive_utc(), conn).await { warn!("Failed to update revision for {uuid}: {e:#?}"); } } - async fn _update_revision(uuid: &GroupId, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult { + async fn update_revision_impl(uuid: &GroupId, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult { db_run! { conn: { crate::util::retry(|| { diesel::update(groups::table.filter(groups::uuid.eq(uuid))) @@ -606,7 +606,7 @@ impl GroupUser { match Membership::find_by_uuid(member_uuid, conn).await { Some(member) => User::update_uuid_revision(&member.user_uuid, conn).await, None => warn!("Member could not be found!"), - }; + } db_run! { conn: { diesel::delete(groups_users::table) diff --git a/src/db/models/mod.rs b/src/db/models/mod.rs index 7cc81852..1cacbcac 100644 --- a/src/db/models/mod.rs +++ b/src/db/models/mod.rs @@ -23,7 +23,7 @@ pub use self::attachment::{Attachment, AttachmentId}; pub use self::auth_request::{AuthRequest, AuthRequestId}; pub use self::cipher::{Cipher, CipherId, RepromptType}; pub use self::collection::{Collection, CollectionCipher, CollectionId, CollectionUser}; -pub use self::device::{Device, DeviceId, DeviceType, PushId}; +pub use self::device::{Device, DeviceId, DeviceType, DeviceWithAuthRequest, PushId}; pub use self::emergency_access::{EmergencyAccess, EmergencyAccessId, EmergencyAccessStatus, EmergencyAccessType}; pub use self::event::{Event, EventType}; pub use self::favorite::Favorite; @@ -35,8 +35,8 @@ pub use self::organization::{ OrganizationId, }; pub use self::send::{ - id::{SendFileId, SendId}, Send, SendType, + id::{SendFileId, SendId}, }; pub use self::sso_auth::{OIDCAuthenticatedUser, OIDCCodeResponseError, SsoAuth}; pub use self::two_factor::{TwoFactor, TwoFactorType}; diff --git a/src/db/models/org_policy.rs b/src/db/models/org_policy.rs index 7e922f35..701b6ff9 100644 --- a/src/db/models/org_policy.rs +++ b/src/db/models/org_policy.rs @@ -2,12 +2,12 @@ use derive_more::{AsRef, From}; use serde::Deserialize; use serde_json::Value; -use crate::api::core::two_factor; +use crate::CONFIG; use crate::api::EmptyResult; -use crate::db::schema::{org_policies, users_organizations}; +use crate::api::core::two_factor; use crate::db::DbConn; +use crate::db::schema::{org_policies, users_organizations}; use crate::error::MapResult; -use crate::CONFIG; use diesel::prelude::*; use super::{Membership, MembershipId, MembershipStatus, MembershipType, OrganizationId, TwoFactor, UserId}; @@ -269,10 +269,10 @@ impl OrgPolicy { continue; } - if let Some(user) = Membership::find_confirmed_by_user_and_org(user_uuid, &policy.org_uuid, conn).await { - if user.atype < MembershipType::Admin { - return true; - } + if let Some(user) = Membership::find_confirmed_by_user_and_org(user_uuid, &policy.org_uuid, conn).await + && user.atype < MembershipType::Admin + { + return true; } } false @@ -282,13 +282,13 @@ impl OrgPolicy { if m.atype < MembershipType::Admin && m.status > (MembershipStatus::Invited as i32) { // Enforce TwoFactor/TwoStep login if let Some(p) = Self::find_by_org_and_type(&m.org_uuid, OrgPolicyType::TwoFactorAuthentication, conn).await + && p.enabled + && TwoFactor::find_by_user(&m.user_uuid, conn).await.is_empty() { - if p.enabled && TwoFactor::find_by_user(&m.user_uuid, conn).await.is_empty() { - if CONFIG.email_2fa_auto_fallback() { - two_factor::email::find_and_activate_email_2fa(&m.user_uuid, conn).await?; - } else { - err!(format!("Cannot {} because 2FA is required (membership {})", action, m.uuid)); - } + if CONFIG.email_2fa_auto_fallback() { + two_factor::email::find_and_activate_email_2fa(&m.user_uuid, conn).await?; + } else { + err!(format!("Cannot {} because 2FA is required (membership {})", action, m.uuid)); } } @@ -300,12 +300,14 @@ impl OrgPolicy { )); } - if let Some(p) = Self::find_by_org_and_type(&m.org_uuid, OrgPolicyType::SingleOrg, conn).await { - if p.enabled - && Membership::count_accepted_and_confirmed_by_user(&m.user_uuid, &m.org_uuid, conn).await > 0 - { - err!(format!("Cannot {} because the organization policy forbids being part of other organization (membership {})", action, m.uuid)); - } + if let Some(p) = Self::find_by_org_and_type(&m.org_uuid, OrgPolicyType::SingleOrg, conn).await + && p.enabled + && Membership::count_accepted_and_confirmed_by_user(&m.user_uuid, &m.org_uuid, conn).await > 0 + { + err!(format!( + "Cannot {} because the organization policy forbids being part of other organization (membership {})", + action, m.uuid + )); } } @@ -332,16 +334,16 @@ impl OrgPolicy { for policy in OrgPolicy::find_confirmed_by_user_and_active_policy(user_uuid, OrgPolicyType::SendOptions, conn).await { - if let Some(user) = Membership::find_confirmed_by_user_and_org(user_uuid, &policy.org_uuid, conn).await { - if user.atype < MembershipType::Admin { - match serde_json::from_str::(&policy.data) { - Ok(opts) => { - if opts.disable_hide_email { - return true; - } + if let Some(user) = Membership::find_confirmed_by_user_and_org(user_uuid, &policy.org_uuid, conn).await + && user.atype < MembershipType::Admin + { + match serde_json::from_str::(&policy.data) { + Ok(opts) => { + if opts.disable_hide_email { + return true; } - _ => error!("Failed to deserialize SendOptionsPolicyData: {}", policy.data), } + _ => error!("Failed to deserialize SendOptionsPolicyData: {}", policy.data), } } } @@ -349,10 +351,10 @@ impl OrgPolicy { } pub async fn is_enabled_for_member(member_uuid: &MembershipId, policy_type: OrgPolicyType, conn: &DbConn) -> bool { - if let Some(member) = Membership::find_by_uuid(member_uuid, conn).await { - if let Some(policy) = OrgPolicy::find_by_org_and_type(&member.org_uuid, policy_type, conn).await { - return policy.enabled; - } + if let Some(member) = Membership::find_by_uuid(member_uuid, conn).await + && let Some(policy) = OrgPolicy::find_by_org_and_type(&member.org_uuid, policy_type, conn).await + { + return policy.enabled; } false } diff --git a/src/db/models/organization.rs b/src/db/models/organization.rs index ae19b30c..c2c64acb 100644 --- a/src/db/models/organization.rs +++ b/src/db/models/organization.rs @@ -12,11 +12,11 @@ use super::{ CipherId, Collection, CollectionGroup, CollectionId, CollectionUser, Group, GroupId, GroupUser, OrgPolicy, OrgPolicyType, TwoFactor, User, UserId, }; +use crate::CONFIG; use crate::db::schema::{ ciphers, ciphers_collections, collections_groups, groups, groups_users, org_policies, organization_api_key, organizations, users, users_collections, users_organizations, }; -use crate::CONFIG; use macros::UuidFromParam; #[derive(Identifiable, Queryable, Insertable, AsChangeset)] @@ -93,6 +93,10 @@ pub enum MembershipType { impl MembershipType { pub fn from_str(s: &str) -> Option { + #[expect( + clippy::match_same_arms, + reason = "Specifically define `4|Custom` since this is a hack, not a default" + )] match s { "0" | "Owner" => Some(MembershipType::Owner), "1" | "Admin" => Some(MembershipType::Admin), @@ -333,7 +337,7 @@ impl Organization { err!(format!("BillingEmail {} is not a valid email address", self.billing_email)) } - for member in Membership::find_by_org(&self.uuid, conn).await.iter() { + for member in &Membership::find_by_org(&self.uuid, conn).await { User::update_uuid_revision(&member.user_uuid, conn).await; } @@ -802,10 +806,10 @@ impl Membership { } pub async fn find_by_email_and_org(email: &str, org_uuid: &OrganizationId, conn: &DbConn) -> Option { - if let Some(user) = User::find_by_mail(email, conn).await { - if let Some(member) = Membership::find_by_user_and_org(&user.uuid, org_uuid, conn).await { - return Some(member); - } + if let Some(user) = User::find_by_mail(email, conn).await + && let Some(member) = Membership::find_by_user_and_org(&user.uuid, org_uuid, conn).await + { + return Some(member); } None diff --git a/src/db/models/send.rs b/src/db/models/send.rs index 5b6611fa..4c595bc1 100644 --- a/src/db/models/send.rs +++ b/src/db/models/send.rs @@ -1,7 +1,7 @@ use chrono::{NaiveDateTime, Utc}; use serde_json::Value; -use crate::{config::PathType, util::LowerCase, CONFIG}; +use crate::{CONFIG, config::PathType, util::LowerCase}; use super::{OrganizationId, User, UserId}; use crate::db::schema::sends; @@ -107,23 +107,23 @@ impl Send { pub fn check_password(&self, password: &str) -> bool { match (&self.password_hash, &self.password_salt, self.password_iter) { (Some(hash), Some(salt), Some(iter)) => { - crate::crypto::verify_password_hash(password.as_bytes(), salt, hash, iter as u32) + crate::crypto::verify_password_hash(password.as_bytes(), salt, hash, iter.cast_unsigned()) } _ => false, } } pub async fn creator_identifier(&self, conn: &DbConn) -> Option { - if let Some(hide_email) = self.hide_email { - if hide_email { - return None; - } + if let Some(hide_email) = self.hide_email + && hide_email + { + return None; } - if let Some(user_uuid) = &self.user_uuid { - if let Some(user) = User::find_by_uuid(user_uuid, conn).await { - return Some(user.email); - } + if let Some(user_uuid) = &self.user_uuid + && let Some(user) = User::find_by_uuid(user_uuid, conn).await + { + return Some(user.email); } None @@ -137,7 +137,7 @@ impl Send { let mut data = serde_json::from_str::>(&self.data).map(|d| d.data).unwrap_or_default(); // Mobile clients expect size to be a string instead of a number - if let Some(size) = data.get("size").and_then(|v| v.as_i64()) { + if let Some(size) = data.get("size").and_then(Value::as_i64) { data["size"] = Value::String(size.to_string()); } @@ -172,7 +172,7 @@ impl Send { let mut data = serde_json::from_str::>(&self.data).map(|d| d.data).unwrap_or_default(); // Mobile clients expect size to be a string instead of a number - if let Some(size) = data.get("size").and_then(|v| v.as_i64()) { + if let Some(size) = data.get("size").and_then(Value::as_i64) { data["size"] = Value::String(size.to_string()); } @@ -256,15 +256,12 @@ impl Send { pub async fn update_users_revision(&self, conn: &DbConn) -> Vec { let mut user_uuids = Vec::new(); - match &self.user_uuid { - Some(user_uuid) => { - User::update_uuid_revision(user_uuid, conn).await; - user_uuids.push(user_uuid.clone()) - } - None => { - // Belongs to Organization, not implemented - } - }; + if let Some(user_uuid) = &self.user_uuid { + User::update_uuid_revision(user_uuid, conn).await; + user_uuids.push(user_uuid.clone()); + } else { + // Belongs to Organization, not implemented + } user_uuids } @@ -320,22 +317,20 @@ impl Send { } pub async fn size_by_user(user_uuid: &UserId, conn: &DbConn) -> Option { - let sends = Self::find_by_user(user_uuid, conn).await; - #[derive(serde::Deserialize)] struct FileData { #[serde(rename = "size", alias = "Size")] size: NumberOrString, } + let sends = Self::find_by_user(user_uuid, conn).await; let mut total: i64 = 0; for send in sends { - if send.atype == SendType::File as i32 { - if let Ok(size) = + if send.atype == SendType::File as i32 + && let Ok(size) = serde_json::from_str::(&send.data).map_err(Into::into).and_then(|d| d.size.into_i64()) - { - total = total.checked_add(size)?; - }; + { + total = total.checked_add(size)?; } } diff --git a/src/db/models/two_factor.rs b/src/db/models/two_factor.rs index 0dc08e3e..cf64d950 100644 --- a/src/db/models/two_factor.rs +++ b/src/db/models/two_factor.rs @@ -150,6 +150,10 @@ impl TwoFactor { } pub async fn migrate_u2f_to_webauthn(conn: &DbConn) -> EmptyResult { + use crate::api::core::two_factor::webauthn::{U2FRegistration, get_webauthn_registrations}; + use webauthn_rs::prelude::{COSEEC2Key, COSEKey, COSEKeyType, ECDSACurve}; + use webauthn_rs_proto::{COSEAlgorithm, UserVerificationPolicy}; + let u2f_factors = db_run! { conn: { twofactor::table .filter(twofactor::atype.eq(TwoFactorType::U2f as i32)) @@ -157,11 +161,6 @@ impl TwoFactor { .expect("Error loading twofactor") }}; - use crate::api::core::two_factor::webauthn::U2FRegistration; - use crate::api::core::two_factor::webauthn::{get_webauthn_registrations, WebauthnRegistration}; - use webauthn_rs::prelude::{COSEEC2Key, COSEKey, COSEKeyType, ECDSACurve}; - use webauthn_rs_proto::{COSEAlgorithm, UserVerificationPolicy}; - for mut u2f in u2f_factors { let mut regs: Vec = serde_json::from_str(&u2f.data)?; // If there are no registrations or they are migrated (we do the migration in batch so we can consider them all migrated when the first one is) @@ -241,7 +240,7 @@ impl TwoFactor { continue; }; - let regs = regs.into_iter().map(|r| r.into()).collect::>(); + let regs = regs.into_iter().map(Into::into).collect::>(); TwoFactor::new(webauthn_factor.user_uuid.clone(), TwoFactorType::Webauthn, serde_json::to_string(®s)?) .save(conn) diff --git a/src/db/models/two_factor_duo_context.rs b/src/db/models/two_factor_duo_context.rs index 205a57d8..4f7d2388 100644 --- a/src/db/models/two_factor_duo_context.rs +++ b/src/db/models/two_factor_duo_context.rs @@ -29,7 +29,7 @@ impl TwoFactorDuoContext { let exists = Self::find_by_state(state, conn).await; if exists.is_some() { return Ok(()); - }; + } let exp = Utc::now().timestamp() + ttl; diff --git a/src/db/models/two_factor_incomplete.rs b/src/db/models/two_factor_incomplete.rs index 2f7e4779..ca008821 100644 --- a/src/db/models/two_factor_incomplete.rs +++ b/src/db/models/two_factor_incomplete.rs @@ -2,14 +2,14 @@ use chrono::{NaiveDateTime, Utc}; use crate::db::schema::twofactor_incomplete; use crate::{ + CONFIG, api::EmptyResult, auth::ClientIp, db::{ - models::{DeviceId, UserId}, DbConn, + models::{DeviceId, UserId}, }, error::MapResult, - CONFIG, }; use diesel::prelude::*; diff --git a/src/db/models/user.rs b/src/db/models/user.rs index ebc72101..9e43bc11 100644 --- a/src/db/models/user.rs +++ b/src/db/models/user.rs @@ -8,13 +8,13 @@ use super::{ Cipher, Device, EmergencyAccess, Favorite, Folder, Membership, MembershipType, TwoFactor, TwoFactorIncomplete, }; use crate::{ + CONFIG, api::EmptyResult, crypto, - db::{models::DeviceId, DbConn}, + db::{DbConn, models::DeviceId}, error::MapResult, sso::OIDCIdentifier, util::{format_date, get_uuid, retry}, - CONFIG, }; use macros::UuidFromParam; @@ -137,8 +137,8 @@ impl User { _totp_secret: None, totp_recover: None, - equivalent_domains: "[]".to_string(), - excluded_globals: "[]".to_string(), + equivalent_domains: "[]".to_owned(), + excluded_globals: "[]".to_owned(), client_kdf_type: Self::CLIENT_KDF_TYPE_DEFAULT, client_kdf_iter: Self::CLIENT_KDF_ITER_DEFAULT, @@ -158,7 +158,7 @@ impl User { password.as_bytes(), &self.salt, &self.password_hash, - self.password_iterations as u32, + self.password_iterations.cast_unsigned(), ) } @@ -193,7 +193,8 @@ impl User { allow_next_route: Option>, conn: &DbConn, ) -> EmptyResult { - self.password_hash = crypto::hash_password(password.as_bytes(), &self.salt, self.password_iterations as u32); + self.password_hash = + crypto::hash_password(password.as_bytes(), &self.salt, self.password_iterations.cast_unsigned()); if let Some(route) = allow_next_route { self.set_stamp_exception(route); @@ -238,10 +239,10 @@ impl User { pub fn display_name(&self) -> &str { // default to email if name is empty - if !&self.name.is_empty() { - &self.name - } else { + if self.name.is_empty() { &self.email + } else { + &self.name } } } @@ -345,7 +346,7 @@ impl User { } pub async fn update_uuid_revision(uuid: &UserId, conn: &DbConn) { - if let Err(e) = Self::_update_revision(uuid, &Utc::now().naive_utc(), conn).await { + if let Err(e) = Self::update_revision_impl(uuid, &Utc::now().naive_utc(), conn).await { warn!("Failed to update revision for {uuid}: {e:#?}"); } } @@ -366,10 +367,10 @@ impl User { pub async fn update_revision(&mut self, conn: &DbConn) -> EmptyResult { self.updated_at = Utc::now().naive_utc(); - Self::_update_revision(&self.uuid, &self.updated_at, conn).await + Self::update_revision_impl(&self.uuid, &self.updated_at, conn).await } - async fn _update_revision(uuid: &UserId, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult { + async fn update_revision_impl(uuid: &UserId, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult { db_run! { conn: { retry(|| { diesel::update(users::table.filter(users::uuid.eq(uuid))) diff --git a/src/db/query_logger.rs b/src/db/query_logger.rs index 0a207918..e8312aac 100644 --- a/src/db/query_logger.rs +++ b/src/db/query_logger.rs @@ -11,7 +11,7 @@ pub fn simple_logger() -> Option> { url, .. } => { - debug!("Establishing connection: {url}") + debug!("Establishing connection: {url}"); } InstrumentationEvent::FinishEstablishConnection { url, @@ -19,9 +19,9 @@ pub fn simple_logger() -> Option> { .. } => { if let Some(e) = error { - error!("Error during establishing a connection with {url}: {e:?}") + error!("Error during establishing a connection with {url}: {e:?}"); } else { - debug!("Connection established: {url}") + debug!("Connection established: {url}"); } } InstrumentationEvent::StartQuery { @@ -47,7 +47,7 @@ pub fn simple_logger() -> Option> { } else if duration.as_secs() >= 1 { info!("SLOW QUERY [{:.2}s]: {}", duration.as_secs_f32(), query_string); } else { - debug!("QUERY [{:?}]: {}", duration, query_string); + debug!("QUERY [{duration:?}]: {query_string}"); } } }); diff --git a/src/error.rs b/src/error.rs index 1a258fd1..ccc23e15 100644 --- a/src/error.rs +++ b/src/error.rs @@ -14,24 +14,24 @@ macro_rules! make_error { #[derive(Debug)] pub struct ErrorEvent { pub event: EventType } - pub struct Error { message: String, error: ErrorKind, error_code: u16, event: Option } + pub struct Error { message: String, kind: ErrorKind, code: u16, event: Option } $(impl From<$ty> for Error { fn from(err: $ty) -> Self { Error::from((stringify!($name), err)) } })+ $(impl> From<(S, $ty)> for Error { fn from(val: (S, $ty)) -> Self { - Error { message: val.0.into(), error: ErrorKind::$name(val.1), error_code: BAD_REQUEST, event: None } + Error { message: val.0.into(), kind: ErrorKind::$name(val.1), code: BAD_REQUEST, event: None } } })+ impl StdError for Error { fn source(&self) -> Option<&(dyn StdError + 'static)> { - match &self.error {$( ErrorKind::$name(e) => $src_fn(e), )+} + match &self.kind {$( ErrorKind::$name(e) => $src_fn(e), )+} } } impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match &self.error {$( + match &self.kind {$( ErrorKind::$name(e) => f.write_str(&$usr_msg_fun(e, &self.message)), )+} } @@ -39,10 +39,10 @@ macro_rules! make_error { }; } +use diesel::ConnectionError as DieselConErr; use diesel::r2d2::Error as R2d2Err; use diesel::r2d2::PoolError as R2d2PoolErr; use diesel::result::Error as DieselErr; -use diesel::ConnectionError as DieselConErr; use handlebars::RenderError as HbErr; use jsonwebtoken::errors::Error as JwtErr; use lettre::address::AddressError as AddrErr; @@ -71,46 +71,46 @@ pub struct Compact {} // The second one contains the function used to obtain the response sent to the client make_error! { // Just an empty error - Empty(Empty): _no_source, _serialize, + Empty(Empty): no_source, serialize, // Used to represent err! calls - Simple(String): _no_source, _api_error, - Compact(Compact): _no_source, _compact_api_error, + Simple(String): no_source, api_error, + Compact(Compact): no_source, compact_api_error, // Used in our custom http client to handle non-global IPs and blocked domains - CustomHttpClient(CustomHttpClientError): _has_source, _api_error, + CustomHttpClient(CustomHttpClientError): has_source, api_error, // Used for special return values, like 2FA errors - Json(Value): _no_source, _serialize, - Db(DieselErr): _has_source, _api_error, - R2d2(R2d2Err): _has_source, _api_error, - R2d2Pool(R2d2PoolErr): _has_source, _api_error, - Serde(SerdeErr): _has_source, _api_error, - JWt(JwtErr): _has_source, _api_error, - Handlebars(HbErr): _has_source, _api_error, - - Io(IoErr): _has_source, _api_error, - Time(TimeErr): _has_source, _api_error, - Req(ReqErr): _has_source, _api_error, - Regex(RegexErr): _has_source, _api_error, - Yubico(YubiErr): _has_source, _api_error, - - Lettre(LettreErr): _has_source, _api_error, - Address(AddrErr): _has_source, _api_error, - Smtp(SmtpErr): _has_source, _api_error, - OpenSSL(SSLErr): _has_source, _api_error, - Rocket(RocketErr): _has_source, _api_error, - - DieselCon(DieselConErr): _has_source, _api_error, - Webauthn(WebauthnErr): _has_source, _api_error, - - OpenDAL(OpenDALErr): _has_source, _api_error, + Json(Value): no_source, serialize, + Db(DieselErr): has_source, api_error, + R2d2(R2d2Err): has_source, api_error, + R2d2Pool(R2d2PoolErr): has_source, api_error, + Serde(SerdeErr): has_source, api_error, + JWt(JwtErr): has_source, api_error, + Handlebars(HbErr): has_source, api_error, + + Io(IoErr): has_source, api_error, + Time(TimeErr): has_source, api_error, + Req(ReqErr): has_source, api_error, + Regex(RegexErr): has_source, api_error, + Yubico(YubiErr): has_source, api_error, + + Lettre(LettreErr): has_source, api_error, + Address(AddrErr): has_source, api_error, + Smtp(SmtpErr): has_source, api_error, + OpenSSL(SSLErr): has_source, api_error, + Rocket(RocketErr): has_source, api_error, + + DieselCon(DieselConErr): has_source, api_error, + Webauthn(WebauthnErr): has_source, api_error, + + OpenDAL(OpenDALErr): has_source, api_error, } impl std::fmt::Debug for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self.source() { Some(e) => write!(f, "{}.\n[CAUSE] {:#?}", self.message, e), - None => match self.error { + None => match self.kind { ErrorKind::Empty(_) => Ok(()), ErrorKind::Simple(ref s) => { if &self.message == s { @@ -135,6 +135,7 @@ impl Error { (usr_msg.clone(), usr_msg.into()).into() } + #[must_use] pub fn empty() -> Self { Empty {}.into() } @@ -147,13 +148,13 @@ impl Error { #[must_use] pub fn with_kind(mut self, kind: ErrorKind) -> Self { - self.error = kind; + self.kind = kind; self } #[must_use] pub const fn with_code(mut self, code: u16) -> Self { - self.error_code = code; + self.code = code; self } @@ -194,14 +195,14 @@ impl MapResult for Option { } } -const fn _has_source(e: T) -> Option { +const fn has_source(e: T) -> Option { Some(e) } -fn _no_source(_: T) -> Option { +fn no_source(_: T) -> Option { None } -fn _serialize(e: &impl Serialize, _msg: &str) -> String { +fn serialize(e: &impl Serialize, _msg: &str) -> String { serde_json::to_string(e).unwrap() } @@ -280,14 +281,14 @@ struct ApiErrorResponse<'a>(ApiErrorMsg<'a>); /// The custom serialization adds all other needed fields struct CompactApiErrorResponse<'a>(ApiErrorMsg<'a>); -fn _api_error(_: &impl std::any::Any, msg: &str) -> String { +fn api_error(_: &impl std::any::Any, msg: &str) -> String { let response = ApiErrorMsg { message: msg, }; serde_json::to_string(&ApiErrorResponse(response)).unwrap() } -fn _compact_api_error(_: &impl std::any::Any, msg: &str) -> String { +fn compact_api_error(_: &impl std::any::Any, msg: &str) -> String { let response = ApiErrorMsg { message: msg, }; @@ -305,12 +306,12 @@ use rocket::response::{self, Responder, Response}; impl Responder<'_, 'static> for Error { fn respond_to(self, _: &Request<'_>) -> response::Result<'static> { - match self.error { + match self.kind { ErrorKind::Empty(_) | ErrorKind::Simple(_) | ErrorKind::Compact(_) => {} // Don't print the error in this situation _ => error!(target: "error", "{self:#?}"), - }; + } - let code = Status::from_code(self.error_code).unwrap_or(Status::BadRequest); + let code = Status::from_code(self.code).unwrap_or(Status::BadRequest); let body = self.to_string(); Response::build().status(code).header(ContentType::JSON).sized_body(Some(body.len()), Cursor::new(body)).ok() } diff --git a/src/http_client.rs b/src/http_client.rs index d39b884d..232ba7da 100644 --- a/src/http_client.rs +++ b/src/http_client.rs @@ -5,17 +5,21 @@ use std::{ time::Duration, }; -use hickory_resolver::{net::runtime::TokioRuntimeProvider, TokioResolver}; +use hickory_resolver::{TokioResolver, net::runtime::TokioRuntimeProvider}; use regex::Regex; use reqwest::{ + Client, ClientBuilder, dns::{Name, Resolve, Resolving}, - header, Client, ClientBuilder, + header, }; use url::Host; -use crate::{util::is_global, CONFIG}; +use crate::{CONFIG, util::is_global}; pub fn make_http_request(method: reqwest::Method, url: &str) -> Result { + static INSTANCE: LazyLock = + LazyLock::new(|| get_reqwest_client_builder().build().expect("Failed to build client")); + let Ok(url) = url::Url::parse(url) else { err!("Invalid URL"); }; @@ -25,9 +29,6 @@ pub fn make_http_request(method: reqwest::Method, url: &str) -> Result = - LazyLock::new(|| get_reqwest_client_builder().build().expect("Failed to build client")); - Ok(INSTANCE.request(method, url)) } @@ -67,18 +68,19 @@ fn should_block_ip(ip: IpAddr) -> bool { } fn should_block_address_regex(domain_or_ip: &str) -> bool { + static COMPILED_REGEX: Mutex> = Mutex::new(None); + let Some(block_regex) = CONFIG.http_request_block_regex() else { return false; }; - static COMPILED_REGEX: Mutex> = Mutex::new(None); let mut guard = COMPILED_REGEX.lock().unwrap(); // If the stored regex is up to date, use it - if let Some((value, regex)) = &*guard { - if value == &block_regex { - return regex.is_match(domain_or_ip); - } + if let Some((value, regex)) = &*guard + && value == &block_regex + { + return regex.is_match(domain_or_ip); } // If we don't have a regex stored, or it's not up to date, recreate it @@ -92,7 +94,7 @@ fn should_block_address_regex(domain_or_ip: &str) -> bool { pub fn get_valid_host(host: &str) -> Result { let Ok(host) = Host::parse(host) else { return Err(CustomHttpClientError::Invalid { - domain: host.to_string(), + domain: host.to_owned(), }); }; @@ -136,16 +138,16 @@ pub fn should_block_host>(host: &Host) -> Result<(), CustomHttp let (ip, host_str): (Option, String) = match host { Host::Ipv4(ip) => (Some(IpAddr::V4(*ip)), ip.to_string()), Host::Ipv6(ip) => (Some(IpAddr::V6(*ip)), ip.to_string()), - Host::Domain(d) => (None, d.as_ref().to_string()), + Host::Domain(d) => (None, d.as_ref().to_owned()), }; - if let Some(ip) = ip { - if should_block_ip(ip) { - return Err(CustomHttpClientError::NonGlobalIp { - domain: None, - ip, - }); - } + if let Some(ip) = ip + && should_block_ip(ip) + { + return Err(CustomHttpClientError::NonGlobalIp { + domain: None, + ip, + }); } if should_block_address_regex(&host_str) { @@ -233,8 +235,7 @@ impl CustomDnsResolver { builder.build() }) .inspect_err(|e| warn!("Error creating Hickory resolver, falling back to default: {e:?}")) - .map(|resolver| Arc::new(Self::Hickory(Arc::new(resolver)))) - .unwrap_or_else(|_| Arc::new(Self::Default())) + .map_or_else(|_| Arc::new(Self::Default()), |resolver| Arc::new(Self::Hickory(Arc::new(resolver)))) } // Note that we get an iterator of addresses, but we only grab the first one for convenience @@ -257,13 +258,13 @@ impl CustomDnsResolver { fn pre_resolve(name: &str) -> Result<(), CustomHttpClientError> { let Ok(host) = get_valid_host(name) else { return Err(CustomHttpClientError::Invalid { - domain: name.to_string(), + domain: name.to_owned(), }); }; if should_block_host(&host).is_err() { return Err(CustomHttpClientError::Blocked { - domain: name.to_string(), + domain: name.to_owned(), }); } @@ -273,7 +274,7 @@ fn pre_resolve(name: &str) -> Result<(), CustomHttpClientError> { fn post_resolve(name: &str, ip: IpAddr) -> Result<(), CustomHttpClientError> { if should_block_ip(ip) { Err(CustomHttpClientError::NonGlobalIp { - domain: Some(name.to_string()), + domain: Some(name.to_owned()), ip, }) } else { @@ -318,7 +319,7 @@ pub(crate) mod aws { let future = async move { let method = reqwest::Method::from_bytes(request.method().as_bytes()) .map_err(|e| ConnectorError::user(Box::new(e)))?; - let mut req_builder = client.request(method, request.uri().to_string()); + let mut req_builder = client.request(method, request.uri().to_owned()); for (name, value) in request.headers() { req_builder = req_builder.header(name, value); diff --git a/src/mail.rs b/src/mail.rs index cdbd269a..5da753bf 100644 --- a/src/mail.rs +++ b/src/mail.rs @@ -1,16 +1,17 @@ use chrono::NaiveDateTime; -use percent_encoding::{percent_encode, NON_ALPHANUMERIC}; +use percent_encoding::{NON_ALPHANUMERIC, percent_encode}; use std::{env::consts::EXE_SUFFIX, str::FromStr}; use lettre::{ + Address, AsyncSendmailTransport, AsyncSmtpTransport, AsyncTransport, Tokio1Executor, message::{Attachment, Body, Mailbox, Message, MultiPart, SinglePart}, transport::smtp::authentication::{Credentials, Mechanism as SmtpAuthMechanism}, transport::smtp::client::{Tls, TlsParameters}, transport::smtp::extension::ClientId, - Address, AsyncSendmailTransport, AsyncSmtpTransport, AsyncTransport, Tokio1Executor, }; use crate::{ + CONFIG, api::EmptyResult, auth::{ encode_jwt, generate_delete_claims, generate_emergency_access_invite_claims, generate_invite_claims, @@ -18,7 +19,6 @@ use crate::{ }, db::models::{Device, DeviceType, EmergencyAccessId, MembershipId, OrganizationId, User, UserId}, error::Error, - CONFIG, }; fn sendmail_transport() -> AsyncSendmailTransport { @@ -38,7 +38,9 @@ fn smtp_transport() -> AsyncSmtpTransport { .timeout(Some(Duration::from_secs(CONFIG.smtp_timeout()))); // Determine security - let smtp_client = if CONFIG.smtp_security() != *"off" { + let smtp_client = if CONFIG.smtp_security() == *"off" { + smtp_client + } else { let mut tls_parameters = TlsParameters::builder(host); if CONFIG.smtp_accept_invalid_hostnames() { tls_parameters = tls_parameters.dangerous_accept_invalid_hostnames(true); @@ -53,8 +55,6 @@ fn smtp_transport() -> AsyncSmtpTransport { } else { smtp_client.tls(Tls::Required(tls_parameters)) } - } else { - smtp_client }; let smtp_client = match (CONFIG.smtp_username(), CONFIG.smtp_password()) { @@ -81,12 +81,12 @@ fn smtp_transport() -> AsyncSmtpTransport { } } - if !selected_mechanisms.is_empty() { - smtp_client.authentication(selected_mechanisms) - } else { + if selected_mechanisms.is_empty() { // Only show a warning, and return without setting an actual authentication mechanism warn!("No valid SMTP Auth mechanism found for '{mechanism}', using default values"); smtp_client + } else { + smtp_client.authentication(selected_mechanisms) } } _ => smtp_client, @@ -129,14 +129,16 @@ fn get_template(template_name: &str, data: &serde_json::Value) -> Result<(String let text = CONFIG.render_template(template_name, data)?; let mut text_split = text.split(""); - let subject = match text_split.next() { - Some(s) => s.trim().to_string(), - None => err!("Template doesn't contain subject"), + let subject = if let Some(s) = text_split.next() { + s.trim().to_owned() + } else { + err!("Template doesn't contain subject") }; - let body = match text_split.next() { - Some(s) => s.trim().to_string(), - None => err!("Template doesn't contain body"), + let body = if let Some(s) = text_split.next() { + s.trim().to_owned() + } else { + err!("Template doesn't contain body") }; if text_split.next().is_some() { @@ -204,9 +206,8 @@ pub async fn send_verify_email(address: &str, user_id: &UserId) -> EmptyResult { pub async fn send_register_verify_email(email: &str, token: &str) -> EmptyResult { let mut query = url::Url::parse("https://query.builder").unwrap(); query.query_pairs_mut().append_pair("email", email).append_pair("token", token); - let query_string = match query.query() { - None => err!("Failed to build verify URL query parameters"), - Some(query) => query, + let Some(query_string) = query.query() else { + err!("Failed to build verify URL query parameters") }; let (subject, body_html, body_text) = get_text( @@ -655,7 +656,7 @@ pub async fn send_protected_action_token(address: &str, token: &str) -> EmptyRes async fn send_with_selected_transport(email: Message) -> EmptyResult { if CONFIG.use_sendmail() { match sendmail_transport().send(email).await { - Ok(_) => Ok(()), + Ok(()) => Ok(()), // Match some common errors and make them more user friendly Err(e) => { if e.is_client() { @@ -664,10 +665,9 @@ async fn send_with_selected_transport(email: Message) -> EmptyResult { } else if e.is_response() { debug!("Sendmail response error: {e:?}"); err!(format!("Sendmail response error: {e}")); - } else { - debug!("Sendmail error: {e:?}"); - err!(format!("Sendmail error: {e}")); } + debug!("Sendmail error: {e:?}"); + err!(format!("Sendmail error: {e}")); } } } else { @@ -695,10 +695,9 @@ async fn send_with_selected_transport(email: Message) -> EmptyResult { } else if e.is_tls() { debug!("SMTP encryption error: {e:#?}"); err!(format!("SMTP encryption error: {e}")); - } else { - debug!("SMTP error: {e:#?}"); - err!(format!("SMTP error: {e}")); } + debug!("SMTP error: {e:#?}"); + err!(format!("SMTP error: {e}")); } } } diff --git a/src/main.rs b/src/main.rs index 4ffeacc1..42e10f53 100644 --- a/src/main.rs +++ b/src/main.rs @@ -63,10 +63,10 @@ mod util; use crate::api::core::two_factor::duo_oidc::purge_duo_contexts; use crate::api::purge_auth_requests; use crate::api::{WS_ANONYMOUS_SUBSCRIPTIONS, WS_USERS}; -pub use config::{PathType, CONFIG}; +pub use config::{CONFIG, PathType}; pub use error::{Error, MapResult}; use rocket::data::{Limits, ToByteUnit}; -use std::sync::{atomic::Ordering, Arc}; +use std::sync::{Arc, atomic::Ordering}; pub use util::is_running_in_container; #[rocket::main] @@ -137,26 +137,23 @@ fn parse_args() { if let Some(command) = pargs.subcommand().unwrap_or_default() { if command == "hash" { use argon2::{ - password_hash::SaltString, Algorithm::Argon2id, Argon2, ParamsBuilder, PasswordHasher, Version::V0x13, + Algorithm::Argon2id, Argon2, ParamsBuilder, PasswordHasher, Version::V0x13, password_hash::SaltString, }; let mut argon2_params = ParamsBuilder::new(); let preset: Option = pargs.opt_value_from_str(["-p", "--preset"]).unwrap_or_default(); let selected_preset; - match preset.as_deref() { - Some("owasp") => { - selected_preset = "owasp"; - argon2_params.m_cost(19456); - argon2_params.t_cost(2); - argon2_params.p_cost(1); - } - _ => { - // Bitwarden preset is the default - selected_preset = "bitwarden"; - argon2_params.m_cost(65540); - argon2_params.t_cost(3); - argon2_params.p_cost(4); - } + if preset.as_deref() == Some("owasp") { + selected_preset = "owasp"; + argon2_params.m_cost(19456); + argon2_params.t_cost(2); + argon2_params.p_cost(1); + } else { + // Bitwarden preset is the default + selected_preset = "bitwarden"; + argon2_params.m_cost(65540); + argon2_params.t_cost(3); + argon2_params.p_cost(4); } println!("Generate an Argon2id PHC string using the '{selected_preset}' preset:\n"); @@ -247,7 +244,7 @@ fn init_logging() -> Result { let level = caps .get(1) .and_then(|m| log::LevelFilter::from_str(m.as_str()).ok()) - .ok_or(Error::new("Failed to parse global log level".to_string(), ""))?; + .ok_or(Error::new("Failed to parse global log level".to_owned(), ""))?; let levels_override: Vec<(&str, log::LevelFilter)> = caps .get(2) @@ -256,13 +253,13 @@ fn init_logging() -> Result { .split(',') .collect::>() .into_iter() - .flat_map(|s| match s.split_once('=') { + .filter_map(|s| match s.split_once('=') { Some((log, lvl_str)) => log::LevelFilter::from_str(lvl_str).ok().map(|lvl| (log, lvl)), _ => None, }) .collect() }) - .ok_or(Error::new("Failed to parse overrides".to_string(), ""))?; + .ok_or(Error::new("Failed to parse overrides".to_owned(), ""))?; (level, levels_override) } else { @@ -338,7 +335,7 @@ fn init_logging() -> Result { ("vaultwarden::db::query_logger", log::LevelFilter::Off), ]); - for (path, level) in levels_override.into_iter() { + for (path, level) in levels_override { let _ = default_levels.insert(path, level); } @@ -352,7 +349,7 @@ fn init_logging() -> Result { let mut logger = fern::Dispatch::new().level(level).chain(std::io::stdout()); for (path, level) in default_levels { - logger = logger.level_for(path.to_string(), level); + logger = logger.level_for(path.to_owned(), level); } if CONFIG.extended_logging() { @@ -363,7 +360,7 @@ fn init_logging() -> Result { record.target(), record.level(), message - )) + )); }); } else { logger = logger.format(|out, message, _| out.finish(format_args!("{message}"))); @@ -609,9 +606,7 @@ async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error> #[cfg(all(unix, sqlite))] { - if db::ACTIVE_DB_TYPE.get() != Some(&db::DbConnType::Sqlite) { - debug!("PostgreSQL and MySQL/MariaDB do not support this backup feature, skip adding USR1 signal."); - } else { + if db::ACTIVE_DB_TYPE.get() == Some(&db::DbConnType::Sqlite) { tokio::spawn(async move { let mut signal_user1 = tokio::signal::unix::signal(SignalKind::user_defined1()).unwrap(); loop { @@ -624,6 +619,8 @@ async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error> } } }); + } else { + debug!("PostgreSQL and MySQL/MariaDB do not support this backup feature, skip adding USR1 signal."); } } @@ -671,7 +668,7 @@ fn schedule_jobs(pool: db::DbPool) { let runtime = tokio::runtime::Runtime::new().unwrap(); thread::Builder::new() - .name("job-scheduler".to_string()) + .name("job-scheduler".to_owned()) .spawn(move || { use job_scheduler_ng::{Job, JobScheduler}; let _runtime_guard = runtime.enter(); diff --git a/src/ratelimit.rs b/src/ratelimit.rs index 854bcc53..2b422924 100644 --- a/src/ratelimit.rs +++ b/src/ratelimit.rs @@ -1,8 +1,8 @@ use std::{net::IpAddr, num::NonZeroU32, sync::LazyLock, time::Duration}; -use governor::{clock::DefaultClock, state::keyed::DashMapStateStore, Quota, RateLimiter}; +use governor::{Quota, RateLimiter, clock::DefaultClock, state::keyed::DashMapStateStore}; -use crate::{Error, CONFIG}; +use crate::{CONFIG, Error}; type Limiter = RateLimiter, DefaultClock>; @@ -20,7 +20,7 @@ static LIMITER_ADMIN: LazyLock = LazyLock::new(|| { pub fn check_limit_login(ip: &IpAddr) -> Result<(), Error> { match LIMITER_LOGIN.check_key(ip) { - Ok(_) => Ok(()), + Ok(()) => Ok(()), Err(_e) => { err_code!("Too many login requests", 429); } @@ -29,7 +29,7 @@ pub fn check_limit_login(ip: &IpAddr) -> Result<(), Error> { pub fn check_limit_admin(ip: &IpAddr) -> Result<(), Error> { match LIMITER_ADMIN.check_key(ip) { - Ok(_) => Ok(()), + Ok(()) => Ok(()), Err(_e) => { err_code!("Too many admin requests", 429); } diff --git a/src/sso.rs b/src/sso.rs index 56e9a534..01fbd906 100644 --- a/src/sso.rs +++ b/src/sso.rs @@ -6,15 +6,15 @@ use regex::Regex; use url::Url; use crate::{ + CONFIG, api::ApiResult, auth, - auth::{AuthMethod, AuthTokens, TokenWrapper, BW_EXPIRATION, DEFAULT_REFRESH_VALIDITY}, + auth::{AuthMethod, AuthTokens, BW_EXPIRATION, DEFAULT_REFRESH_VALIDITY, TokenWrapper}, db::{ - models::{Device, OIDCAuthenticatedUser, SsoAuth, SsoUser, User}, DbConn, + models::{Device, OIDCAuthenticatedUser, SsoAuth, SsoUser, User}, }, sso_client::Client, - CONFIG, }; pub static FAKE_SSO_IDENTIFIER: &str = "00000000-01DC-01DC-01DC-000000000000"; @@ -123,7 +123,7 @@ pub fn encode_ssotoken_claims() -> String { nbf: time_now.timestamp(), exp: (time_now + chrono::TimeDelta::try_minutes(2).unwrap()).timestamp(), iss: SSO_JWT_ISSUER.to_string(), - sub: "vaultwarden".to_string(), + sub: "vaultwarden".to_owned(), }; auth::encode_jwt(&claims) @@ -171,12 +171,14 @@ fn decode_token_claims(token_name: &str, token: &str) -> ApiResult ApiResult { - let state = match data_encoding::BASE64.decode(base64_state.as_bytes()) { - Ok(vec) => match String::from_utf8(vec) { - Ok(valid) => OIDCState(valid), - Err(_) => err!(format!("Invalid utf8 chars in {base64_state} after base64 decoding")), - }, - Err(_) => err!(format!("Failed to decode {base64_state} using base64")), + let state = if let Ok(vec) = data_encoding::BASE64.decode(base64_state.as_bytes()) { + if let Ok(valid) = String::from_utf8(vec) { + OIDCState(valid) + } else { + err!(format!("Invalid utf8 chars in {base64_state} after base64 decoding")) + } + } else { + err!(format!("Failed to decode {base64_state} using base64")) }; Ok(state) @@ -193,12 +195,15 @@ pub async fn authorize_url( ) -> ApiResult { let redirect_uri = match client_id { "web" | "browser" => format!("{}/sso-connector.html", CONFIG.domain()), - "desktop" | "mobile" => "bitwarden://sso-callback".to_string(), + "desktop" | "mobile" => "bitwarden://sso-callback".to_owned(), "cli" => { let port_regex = Regex::new(r"^http://localhost:([0-9]{4})$").unwrap(); - match port_regex.captures(raw_redirect_uri).and_then(|captures| captures.get(1).map(|c| c.as_str())) { - Some(port) => format!("http://localhost:{port}"), - None => err!("Failed to extract port number"), + if let Some(port) = + port_regex.captures(raw_redirect_uri).and_then(|captures| captures.get(1).map(|c| c.as_str())) + { + format!("http://localhost:{port}") + } else { + err!("Failed to extract port number") } } _ => err!(format!("Unsupported client {client_id}")), @@ -246,9 +251,8 @@ pub async fn exchange_code( ) -> ApiResult<(SsoAuth, OIDCAuthenticatedUser)> { use openidconnect::OAuth2TokenResponse; - let mut sso_auth = match SsoAuth::find_by_code(code, conn).await { - None => err!(format!("Invalid code cannot retrieve sso auth")), - Some(sso_auth) => sso_auth, + let Some(mut sso_auth) = SsoAuth::find_by_code(code, conn).await else { + err!("Invalid code cannot retrieve sso auth") }; if let Some(authenticated_user) = sso_auth.auth_response.clone() { @@ -286,8 +290,8 @@ pub async fn exchange_code( let user_name = id_claims.preferred_username().or(user_info.preferred_username()).map(|un| un.to_string()); - let refresh_token = token_response.refresh_token().map(|t| t.secret()); - if refresh_token.is_none() && CONFIG.sso_scopes_vec().contains(&"offline_access".to_string()) { + let refresh_token = token_response.refresh_token().map(openidconnect::RefreshToken::secret); + if refresh_token.is_none() && CONFIG.sso_scopes_vec().contains(&"offline_access".to_owned()) { error!("Scope offline_access is present but response contain no refresh_token"); } @@ -331,7 +335,9 @@ pub async fn redeem( user_sso.save(conn).await?; } - if !CONFIG.sso_auth_only_not_session() { + if CONFIG.sso_auth_only_not_session() { + Ok(AuthTokens::new(device, user, AuthMethod::Sso, client_id)) + } else { let now = Utc::now(); let (ap_nbf, ap_exp) = @@ -344,9 +350,7 @@ pub async fn redeem( let access_claims = auth::LoginJwtClaims::new(device, user, ap_nbf, ap_exp, AuthMethod::Sso.scope_vec(), client_id, now); - _create_auth_tokens(device, auth_user.refresh_token, access_claims, auth_user.access_token) - } else { - Ok(AuthTokens::new(device, user, AuthMethod::Sso, client_id)) + create_auth_tokens_impl(device, auth_user.refresh_token, access_claims, auth_user.access_token) } } @@ -360,7 +364,9 @@ pub fn create_auth_tokens( access_token: String, expires_in: Option, ) -> ApiResult { - if !CONFIG.sso_auth_only_not_session() { + if CONFIG.sso_auth_only_not_session() { + Ok(AuthTokens::new(device, user, AuthMethod::Sso, client_id)) + } else { let now = Utc::now(); let (ap_nbf, ap_exp) = match (decode_token_claims("access_token", &access_token), expires_in) { @@ -372,13 +378,11 @@ pub fn create_auth_tokens( let access_claims = auth::LoginJwtClaims::new(device, user, ap_nbf, ap_exp, AuthMethod::Sso.scope_vec(), client_id, now); - _create_auth_tokens(device, refresh_token, access_claims, access_token) - } else { - Ok(AuthTokens::new(device, user, AuthMethod::Sso, client_id)) + create_auth_tokens_impl(device, refresh_token, access_claims, access_token) } } -fn _create_auth_tokens( +fn create_auth_tokens_impl( device: &Device, refresh_token: Option, access_claims: auth::LoginJwtClaims, @@ -462,7 +466,7 @@ pub async fn exchange_refresh_token( now, ); - _create_auth_tokens(device, None, access_claims, access_token) + create_auth_tokens_impl(device, None, access_claims, access_token) } None => err!("No token present while in SSO"), } diff --git a/src/sso_client.rs b/src/sso_client.rs index 68e171c6..5aa77750 100644 --- a/src/sso_client.rs +++ b/src/sso_client.rs @@ -1,18 +1,31 @@ use std::{borrow::Cow, future::Future, pin::Pin, sync::LazyLock, time::Duration}; -use openidconnect::{core::*, *}; +use openidconnect::{ + AccessToken, AsyncHttpClient, AuthDisplay, AuthPrompt, AuthenticationFlow, AuthorizationCode, AuthorizationRequest, + ClientId, ClientSecret, CsrfToken, EmptyAdditionalClaims, EmptyExtraTokenFields, EndpointNotSet, EndpointSet, + HttpClientError, HttpRequest, HttpResponse, IdTokenClaims, IdTokenFields, Nonce, OAuth2TokenResponse, + PkceCodeChallenge, PkceCodeVerifier, RefreshToken, ResponseType, Scope, StandardErrorResponse, + StandardTokenResponse, + core::{ + CoreAuthDisplay, CoreAuthPrompt, CoreClient, CoreErrorResponseType, CoreGenderClaim, CoreIdTokenVerifier, + CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, CoreJwsSigningAlgorithm, CoreProviderMetadata, + CoreResponseType, CoreRevocableToken, CoreRevocationErrorResponse, CoreTokenIntrospectionResponse, + CoreTokenResponse, CoreTokenType, CoreUserInfoClaims, + }, + http, url, +}; use regex::Regex; use url::Url; use crate::{ + CONFIG, api::{ApiResult, EmptyResult}, db::models::SsoAuth, http_client::get_reqwest_client_builder, sso::{OIDCCode, OIDCCodeChallenge, OIDCCodeVerifier, OIDCState}, - CONFIG, }; -static CLIENT_CACHE_KEY: LazyLock = LazyLock::new(|| "sso-client".to_string()); +static CLIENT_CACHE_KEY: LazyLock = LazyLock::new(|| "sso-client".to_owned()); static CLIENT_CACHE: LazyLock> = LazyLock::new(|| { moka::sync::Cache::builder() .max_capacity(1) @@ -85,7 +98,7 @@ impl<'c> AsyncHttpClient<'c> for OidcHttpClient { impl Client { // Call the OpenId discovery endpoint to retrieve configuration - async fn _get_client() -> ApiResult { + async fn get_client() -> ApiResult { let client_id = ClientId::new(CONFIG.sso_client_id()); let client_secret = ClientSecret::new(CONFIG.sso_client_secret()); @@ -103,14 +116,16 @@ impl Client { let base_client = CoreClient::from_provider_metadata(provider_metadata, client_id, Some(client_secret)); - let token_uri = match base_client.token_uri() { - Some(uri) => uri.clone(), - None => err!("Failed to discover token_url, cannot proceed"), + let token_uri = if let Some(uri) = base_client.token_uri() { + uri.clone() + } else { + err!("Failed to discover token_url, cannot proceed") }; - let user_info_url = match base_client.user_info_url() { - Some(url) => url.clone(), - None => err!("Failed to discover user_info url, cannot proceed"), + let user_info_url = if let Some(url) = base_client.user_info_url() { + url.clone() + } else { + err!("Failed to discover user_info url, cannot proceed") }; let core_client = base_client @@ -129,13 +144,13 @@ impl Client { if CONFIG.sso_client_cache_expiration() > 0 { match CLIENT_CACHE.get(&*CLIENT_CACHE_KEY) { Some(client) => Ok(client), - None => Self::_get_client().await.inspect(|client| { + None => Self::get_client().await.inspect(|client| { debug!("Inserting new client in cache"); CLIENT_CACHE.insert(CLIENT_CACHE_KEY.clone(), client.clone()); }), } } else { - Self::_get_client().await + Self::get_client().await } } @@ -214,15 +229,14 @@ impl Client { Ok(token_response) => { let oidc_nonce = Nonce::new(sso_auth.nonce.clone()); - let id_token = match token_response.extra_fields().id_token() { - None => err!("Token response did not contain an id_token"), - Some(token) => token, + let Some(id_token) = token_response.extra_fields().id_token() else { + err!("Token response did not contain an id_token") }; if CONFIG.sso_debug_tokens() { debug!("Id token: {}", id_token.to_string()); debug!("Access token: {}", token_response.access_token().secret()); - debug!("Refresh token: {:?}", token_response.refresh_token().map(|t| t.secret())); + debug!("Refresh token: {:?}", token_response.refresh_token().map(RefreshToken::secret)); debug!("Expiration time: {:?}", token_response.expires_in()); } @@ -275,12 +289,12 @@ impl Client { let client = Client::cached().await?; REFRESH_CACHE - .get_with(refresh_token.clone(), async move { client._exchange_refresh_token(refresh_token).await }) + .get_with(refresh_token.clone(), async move { client.exchange_refresh_token_impl(refresh_token).await }) .await .map_err(Into::into) } - async fn _exchange_refresh_token(&self, refresh_token: String) -> Result { + async fn exchange_refresh_token_impl(&self, refresh_token: String) -> Result { let rt = RefreshToken::new(refresh_token); match self.core_client.exchange_refresh_token(&rt).request_async(&self.http_client).await { diff --git a/src/storage.rs b/src/storage.rs index ada2a951..ac88d026 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -9,9 +9,9 @@ pub(crate) fn join_path(base: &str, child: &str) -> String { let base = base.trim_end_matches('/'); let child = child.trim_start_matches('/'); if base.is_empty() { - child.to_string() + child.to_owned() } else if child.is_empty() { - base.to_string() + base.to_owned() } else { format!("{base}/{child}") } @@ -34,7 +34,7 @@ pub(crate) fn parent(path: &str) -> Option { return s3::parent(path); } - std::path::Path::new(path).parent()?.to_str().map(ToString::to_string) + std::path::Path::new(path).parent()?.to_str().map(str::to_owned) } pub(crate) fn file_name(path: &str) -> Option { @@ -43,7 +43,7 @@ pub(crate) fn file_name(path: &str) -> Option { return s3::file_name(path); } - std::path::Path::new(path).file_name()?.to_str().map(ToString::to_string) + std::path::Path::new(path).file_name()?.to_str().map(str::to_owned) } pub(crate) fn is_fs_operator(operator: &opendal::Operator) -> bool { @@ -70,7 +70,7 @@ pub(crate) fn operator_for_path(path: &str) -> Result String { if let Ok(mut url) = Url::parse(base) { let mut segments = path_segments(&url); - segments.extend(child.split('/').filter(|segment| !segment.is_empty()).map(ToString::to_string)); + segments.extend(child.split('/').filter(|segment| !segment.is_empty()).map(str::to_owned)); set_path_segments(&mut url, &segments); return url.to_string(); } @@ -96,9 +96,9 @@ mod s3 { let base = base.trim_end_matches('/'); let child = child.trim_start_matches('/'); if base.is_empty() { - child.to_string() + child.to_owned() } else if child.is_empty() { - base.to_string() + base.to_owned() } else { format!("{base}/{child}") } @@ -126,7 +126,7 @@ mod s3 { return Some(url.to_string()); } - std::path::Path::new(path).parent()?.to_str().map(ToString::to_string) + std::path::Path::new(path).parent()?.to_str().map(str::to_owned) } pub(super) fn file_name(path: &str) -> Option { @@ -134,12 +134,12 @@ mod s3 { return path_segments(&url).pop(); } - std::path::Path::new(path).file_name()?.to_str().map(ToString::to_string) + std::path::Path::new(path).file_name()?.to_str().map(str::to_owned) } fn path_segments(url: &Url) -> Vec { url.path_segments() - .map(|segments| segments.filter(|segment| !segment.is_empty()).map(ToString::to_string).collect()) + .map(|segments| segments.filter(|segment| !segment.is_empty()).map(str::to_owned).collect()) .unwrap_or_default() } @@ -206,9 +206,9 @@ mod s3 { }; Ok(Some(Credential { - access_key_id: creds.access_key_id().to_string(), - secret_access_key: creds.secret_access_key().to_string(), - session_token: creds.session_token().map(|s| s.to_string()), + access_key_id: creds.access_key_id().to_owned(), + secret_access_key: creds.secret_access_key().to_owned(), + session_token: creds.session_token().map(ToOwned::to_owned), expires_in, })) } @@ -218,7 +218,7 @@ mod s3 { let mut config = opendal::services::S3Config::from_uri(&uri)?; if !uri_has_option(&uri, &["default_storage_class"]) { - config.default_storage_class = Some("INTELLIGENT_TIERING".to_string()); + config.default_storage_class = Some("INTELLIGENT_TIERING".to_owned()); } if !uri_has_option( diff --git a/src/util.rs b/src/util.rs index 5cd78eed..2e505dee 100644 --- a/src/util.rs +++ b/src/util.rs @@ -5,20 +5,20 @@ use std::{collections::HashMap, io::Cursor, path::Path}; use num_traits::ToPrimitive; use rocket::{ + Data, Orbit, Request, Response, Rocket, fairing::{Fairing, Info, Kind}, http::{ContentType, Header, HeaderMap, Method, Status}, response::{self, Responder}, - Data, Orbit, Request, Response, Rocket, }; use tokio::{ runtime::Handle, - time::{sleep, Duration}, + time::{Duration, sleep}, }; use crate::{ - config::{PathType, SUPPORTED_FEATURE_FLAGS}, CONFIG, + config::{PathType, SUPPORTED_FEATURE_FLAGS}, }; pub struct AppHeaders(); @@ -75,11 +75,16 @@ impl Fairing for AppHeaders { // Do not send the Content-Security-Policy (CSP) Header and X-Frame-Options for the *-connector.html files. // This can cause issues when some MFA requests needs to open a popup or page within the clients like WebAuthn, or Duo. // This is the same behavior as upstream Bitwarden. - if !req_uri_path.ends_with("connector.html") { + if req_uri_path.ends_with("connector.html") { + // It looks like this header get's set somewhere else also, make sure this is not sent for these files, it will cause MFA issues. + res.remove_header("X-Frame-Options"); + } else { let csp = if is_image { // Prevent scripts, frames, objects, etc., from loading with images, mainly for SVG images, since these could contain JavaScript and other unsafe items. // Even though we sanitize SVG images before storing and viewing them, it's better to prevent allowing these elements. - String::from("default-src 'none'; img-src 'self' data:; style-src 'unsafe-inline'; script-src 'none'; frame-src 'none'; object-src 'none") + String::from( + "default-src 'none'; img-src 'self' data:; style-src 'unsafe-inline'; script-src 'none'; frame-src 'none'; object-src 'none", + ) } else { // # Frame Ancestors: // Chrome Web Store: https://chrome.google.com/webstore/detail/bitwarden-free-password-m/nngceckbapebfimnlniiiahkandclblb @@ -129,9 +134,6 @@ impl Fairing for AppHeaders { res.set_raw_header("Content-Security-Policy", csp); res.set_raw_header("X-Frame-Options", "SAMEORIGIN"); - } else { - // It looks like this header get's set somewhere else also, make sure this is not sent for these files, it will cause MFA issues. - res.remove_header("X-Frame-Options"); } // Disable cache unless otherwise specified @@ -146,7 +148,7 @@ pub struct Cors(); impl Cors { fn get_header(headers: &HeaderMap<'_>, name: &str) -> String { match headers.get_one(name) { - Some(h) => h.to_string(), + Some(h) => h.to_owned(), _ => String::new(), } } @@ -212,7 +214,7 @@ impl Cached { Self { response, is_immutable, - ttl: 604800, // 7 days + ttl: 604_800, // 7 days } } @@ -286,7 +288,7 @@ impl Fairing for BetterLogging { } else { "http" }; - let addr = format!("{scheme}://{}:{}", &config.address, &config.port); + let addr = format!("{scheme}://{}:{}", config.address, config.port); info!(target: "start", "Rocket has launched from {addr}"); } @@ -303,7 +305,7 @@ impl Fairing for BetterLogging { match uri.query() { Some(q) => info!(target: "request", "{method} {uri_path_str}?{}", &q[..q.len().min(30)]), None => info!(target: "request", "{method} {uri_path_str}"), - }; + } } } @@ -316,10 +318,10 @@ impl Fairing for BetterLogging { let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str); if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) { let status = response.status(); - if let Some(ref route) = request.route() { - info!(target: "response", "{route} => {status}") + if let Some(route) = request.route() { + info!(target: "response", "{route} => {status}"); } else { - info!(target: "response", "{status}") + info!(target: "response", "{status}"); } } } @@ -402,7 +404,7 @@ pub fn get_env_str_value(key: &str) -> Option { (Ok(_), Ok(_)) => panic!("You should not define both {key} and {key_file}!"), (Ok(v_env), Err(_)) => Some(v_env), (Err(_), Ok(v_file)) => match std::fs::read_to_string(v_file) { - Ok(content) => Some(content.trim().to_string()), + Ok(content) => Some(content.trim().to_owned()), Err(e) => panic!("Failed to load {key}: {e:?}"), }, _ => None, @@ -457,10 +459,10 @@ pub fn validate_and_format_date(dt: &str) -> String { pub fn format_datetime_local(dt: &DateTime, fmt: &str) -> String { // Try parsing the `TZ` environment variable to enable formatting `%Z` as // a time zone abbreviation. - if let Ok(tz) = env::var("TZ") { - if let Ok(tz) = tz.parse::() { - return dt.with_timezone(&tz).format(fmt).to_string(); - } + if let Ok(tz) = env::var("TZ") + && let Ok(tz) = tz.parse::() + { + return dt.with_timezone(&tz).format(fmt).to_string(); } // Otherwise, fall back to formatting `%Z` as a UTC offset. @@ -512,6 +514,7 @@ pub fn is_valid_email(email: &str) -> bool { // /// Returns true if the program is running in Docker, Podman or Kubernetes. +#[must_use] pub fn is_running_in_container() -> bool { Path::new("/.dockerenv").exists() || Path::new("/run/.containerenv").exists() @@ -543,10 +546,10 @@ pub fn get_active_web_release() -> String { ]; for version_file in version_files { - if let Ok(version_str) = std::fs::read_to_string(&version_file) { - if let Ok(version) = serde_json::from_str::(&version_str) { - return String::from(version.version.trim_start_matches('v')); - } + if let Ok(version_str) = std::fs::read_to_string(&version_file) + && let Ok(version) = serde_json::from_str::(&version_str) + { + return String::from(version.version.trim_start_matches('v')); } } @@ -605,7 +608,7 @@ impl<'de> Visitor<'de> for LowerCaseVisitor { let mut result_map = JsonMap::new(); while let Some((key, value)) = map.next_entry()? { - result_map.insert(_process_key(key), convert_json_key_lcase_first(value)); + result_map.insert(process_json_key(key), convert_json_key_lcase_first(value)); } Ok(Value::Object(result_map)) @@ -627,7 +630,7 @@ impl<'de> Visitor<'de> for LowerCaseVisitor { // Inner function to handle a special case for the 'ssn' key. // This key is part of the Identity Cipher (Social Security Number) -fn _process_key(key: &str) -> String { +fn process_json_key(key: &str) -> String { match key.to_lowercase().as_ref() { "ssn" => "ssn".into(), _ => lcase_first(key), @@ -664,21 +667,24 @@ impl NumberOrString { } } - #[allow(clippy::wrong_self_convention)] + #[expect(clippy::wrong_self_convention)] pub fn into_i32(&self) -> Result { use std::num::ParseIntError as PIE; match self { - NumberOrString::Number(n) => match n.to_i32() { - Some(n) => Ok(n), - None => err!("Number does not fit in i32"), - }, + NumberOrString::Number(n) => { + if let Some(n) = n.to_i32() { + Ok(n) + } else { + err!("Number does not fit in i32") + } + } NumberOrString::String(s) => { s.parse().map_err(|e: PIE| crate::Error::new("Can't convert to number", e.to_string())) } } } - #[allow(clippy::wrong_self_convention)] + #[expect(clippy::wrong_self_convention)] pub fn into_i64(&self) -> Result { use std::num::ParseIntError as PIE; match self { @@ -753,11 +759,11 @@ pub fn convert_json_key_lcase_first(src_json: Value) -> Value { Value::Object(obj) => { let mut json_map = JsonMap::new(); - for (key, value) in obj.into_iter() { + for (key, value) in obj { match (key, value) { (key, Value::Object(elm)) => { let inner_value = convert_json_key_lcase_first(Value::Object(elm)); - json_map.insert(_process_key(&key), inner_value); + json_map.insert(process_json_key(&key), inner_value); } (key, Value::Array(elm)) => { @@ -767,11 +773,11 @@ pub fn convert_json_key_lcase_first(src_json: Value) -> Value { inner_array.push(convert_json_key_lcase_first(inner_obj)); } - json_map.insert(_process_key(&key), Value::Array(inner_array)); + json_map.insert(process_json_key(&key), Value::Array(inner_array)); } (key, value) => { - json_map.insert(_process_key(&key), value); + json_map.insert(process_json_key(&key), value); } } } @@ -793,7 +799,7 @@ pub enum FeatureFlagFilter { /// Parses the experimental client feature flags string into a HashMap. pub fn parse_experimental_client_feature_flags( experimental_client_feature_flags: &str, - filter_mode: FeatureFlagFilter, + filter_mode: &FeatureFlagFilter, ) -> HashMap { experimental_client_feature_flags .split(',') @@ -811,7 +817,8 @@ pub fn parse_experimental_client_feature_flags( /// TODO: This is extracted from IpAddr::is_global, which is unstable: /// https://doc.rust-lang.org/nightly/std/net/enum.IpAddr.html#method.is_global /// Remove once https://github.com/rust-lang/rust/issues/27709 is merged -#[allow(clippy::nonminimal_bool)] +// #[expect(clippy::nonminimal_bool, reason = "Mostly copy/paste from std, keep as-is")] +#[expect(clippy::decimal_bitwise_operands, reason = "Mostly copy/paste from std, keep as-is")] #[cfg(any(not(feature = "unstable"), test))] pub fn is_global_hardcoded(ip: std::net::IpAddr) -> bool { match ip {