Browse Source

Switch to Edition 2024, more clippy lints, and less macro calls (#7200)

* Update to Rust 2024 Edition

Updated to the Rust 2024 Edition and added and fixed several lint checks.
This is a large change which, because of the extra lints, added some possible fixes for issues.

Signed-off-by: BlackDex <black.dex@gmail.com>

* Reorder and merge imports

Signed-off-by: BlackDex <black.dex@gmail.com>

* Remove "db_run!" macro calls where possible

Signed-off-by: BlackDex <black.dex@gmail.com>

---------

Signed-off-by: BlackDex <black.dex@gmail.com>
pull/7104/merge
Mathijs van Veluw 6 days ago
committed by GitHub
parent
commit
1ba2c6a26c
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 169
      Cargo.toml
  2. 10
      build.rs
  3. 9
      macros/src/lib.rs
  4. 2
      rustfmt.toml
  5. 111
      src/api/admin.rs
  6. 122
      src/api/core/accounts.rs
  7. 170
      src/api/core/ciphers.rs
  8. 48
      src/api/core/emergency_access.rs
  9. 82
      src/api/core/events.rs
  10. 9
      src/api/core/folders.rs
  11. 86
      src/api/core/mod.rs
  12. 214
      src/api/core/organizations.rs
  13. 77
      src/api/core/public.rs
  14. 56
      src/api/core/sends.rs
  15. 22
      src/api/core/two_factor/authenticator.rs
  16. 33
      src/api/core/two_factor/duo.rs
  17. 32
      src/api/core/two_factor/duo_oidc.rs
  18. 34
      src/api/core/two_factor/email.rs
  19. 31
      src/api/core/two_factor/mod.rs
  20. 26
      src/api/core/two_factor/protected_actions.rs
  21. 68
      src/api/core/two_factor/webauthn.rs
  22. 20
      src/api/core/two_factor/yubikey.rs
  23. 55
      src/api/icons.rs
  24. 206
      src/api/identity.rs
  25. 9
      src/api/mod.rs
  26. 35
      src/api/notifications.rs
  27. 22
      src/api/push.rs
  28. 20
      src/api/web.rs
  29. 138
      src/auth.rs
  30. 239
      src/config.rs
  31. 29
      src/db/mod.rs
  32. 32
      src/db/models/archive.rs
  33. 77
      src/db/models/attachment.rs
  34. 52
      src/db/models/auth_request.rs
  35. 598
      src/db/models/cipher.rs
  36. 633
      src/db/models/collection.rs
  37. 73
      src/db/models/device.rs
  38. 98
      src/db/models/emergency_access.rs
  39. 62
      src/db/models/event.rs
  40. 56
      src/db/models/favorite.rs
  41. 71
      src/db/models/folder.rs
  42. 259
      src/db/models/group.rs
  43. 4
      src/db/models/mod.rs
  44. 118
      src/db/models/org_policy.rs
  45. 335
      src/db/models/organization.rs
  46. 121
      src/db/models/send.rs
  47. 40
      src/db/models/sso_auth.rs
  48. 55
      src/db/models/two_factor.rs
  49. 42
      src/db/models/two_factor_duo_context.rs
  50. 39
      src/db/models/two_factor_incomplete.rs
  51. 123
      src/db/models/user.rs
  52. 11
      src/db/query_logger.rs
  53. 100
      src/error.rs
  54. 43
      src/http_client.rs
  55. 52
      src/mail.rs
  56. 42
      src/main.rs
  57. 8
      src/ratelimit.rs
  58. 62
      src/sso.rs
  59. 50
      src/sso_client.rs
  60. 30
      src/storage.rs
  61. 107
      src/util.rs

169
Cargo.toml

@ -1,5 +1,5 @@
[workspace.package] [workspace.package]
edition = "2021" edition = "2024"
rust-version = "1.93.0" rust-version = "1.93.0"
license = "AGPL-3.0-only" license = "AGPL-3.0-only"
repository = "https://github.com/dani-garcia/vaultwarden" repository = "https://github.com/dani-garcia/vaultwarden"
@ -23,7 +23,8 @@ publish.workspace = true
[features] [features]
default = [ default = [
# "sqlite" or "sqlite_system", # "sqlite",
# "sqlite_system",
# "mysql", # "mysql",
# "postgresql", # "postgresql",
] ]
@ -32,14 +33,22 @@ enable_syslog = []
# Please enable at least one of these DB backends. # Please enable at least one of these DB backends.
mysql = ["diesel/mysql", "diesel_migrations/mysql"] mysql = ["diesel/mysql", "diesel_migrations/mysql"]
postgresql = ["diesel/postgres", "diesel_migrations/postgres"] postgresql = ["diesel/postgres", "diesel_migrations/postgres"]
sqlite_system = ["diesel/sqlite", "diesel_migrations/sqlite"] sqlite_system = ["diesel/sqlite", "diesel_migrations/sqlite"] # Dynamically link SQLite
sqlite = ["sqlite_system", "libsqlite3-sys/bundled"] # Alternative to the above, statically linked SQLite into the binary instead of dynamically. sqlite = ["sqlite_system", "libsqlite3-sys/bundled"] # Statically link SQLite into the binary instead of dynamically.
# Enable to use a vendored and statically linked openssl # Enable to use a vendored and statically linked openssl
vendored_openssl = ["openssl/vendored"] vendored_openssl = ["openssl/vendored"]
# Enable MiMalloc memory allocator to replace the default malloc # Enable MiMalloc memory allocator to replace the default malloc
# This can improve performance for Alpine builds # This can improve performance for Alpine builds
enable_mimalloc = ["dep:mimalloc"] enable_mimalloc = ["dep:mimalloc"]
s3 = ["opendal/services-s3", "dep:aws-config", "dep:aws-credential-types", "dep:aws-smithy-runtime-api", "dep:http", "dep:reqsign-aws-v4", "dep:reqsign-core"] s3 = [
"opendal/services-s3",
"dep:aws-config",
"dep:aws-credential-types",
"dep:aws-smithy-runtime-api",
"dep:http",
"dep:reqsign-aws-v4",
"dep:reqsign-core",
]
# OIDC specific features # OIDC specific features
oidc-accept-rfc3339-timestamps = ["openidconnect/accept-rfc3339-timestamps"] oidc-accept-rfc3339-timestamps = ["openidconnect/accept-rfc3339-timestamps"]
@ -59,7 +68,8 @@ macros = { path = "./macros" }
# Logging # Logging
log = "0.4.29" log = "0.4.29"
fern = { version = "0.7.1", features = ["syslog-7", "reopen-1"] } fern = { version = "0.7.1", features = ["syslog-7", "reopen-1"] }
tracing = { version = "0.1.44", features = ["log"] } # Needed to have lettre and webauthn-rs trace logging to work # We need the `log` feature for `tracing` to enable logging for several crates to work, like lettre or webauthn-rs
tracing = { version = "0.1.44", features = ["log"] }
# A `dotenv` implementation for Rust # A `dotenv` implementation for Rust
dotenvy = { version = "0.15.7", default-features = false } dotenvy = { version = "0.15.7", default-features = false }
@ -70,8 +80,8 @@ num-derive = "0.4.2"
bigdecimal = "0.4.10" bigdecimal = "0.4.10"
# Web framework # Web framework
rocket = { version = "0.5.1", features = ["tls", "json"], default-features = false } rocket = { version = "0.5.1", default-features = false, features = ["json", "tls"] }
rocket_ws = { version ="0.1.1" } rocket_ws = { version = "0.1.1" }
# WebSockets libraries # WebSockets libraries
rmpv = "1.3.1" # MessagePack library rmpv = "1.3.1" # MessagePack library
@ -81,19 +91,32 @@ dashmap = "6.1.0"
# Async futures # Async futures
futures = "0.3.32" futures = "0.3.32"
tokio = { version = "1.52.3", features = ["rt-multi-thread", "fs", "io-util", "parking_lot", "time", "signal", "net"] } tokio = { version = "1.52.3", features = [
tokio-util = { version = "0.7.18", features = ["compat"]} "fs",
"io-util",
"net",
"parking_lot",
"rt-multi-thread",
"signal",
"time",
] }
tokio-util = { version = "0.7.18", features = ["compat"] }
# A generic serialization/deserialization framework # A generic serialization/deserialization framework
serde = { version = "1.0.228", features = ["derive"] } serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.149" serde_json = "1.0.149"
# A safe, extensible ORM and Query builder # A safe, extensible ORM and Query builder
# Currently pinned diesel to v2.3.3 as newer version break MySQL/MariaDB compatibility
diesel = { version = "2.3.9", features = ["chrono", "r2d2", "numeric"] } diesel = { version = "2.3.9", features = ["chrono", "r2d2", "numeric"] }
diesel_migrations = "2.3.2" diesel_migrations = "2.3.2"
derive_more = { version = "2.1.1", features = ["from", "into", "as_ref", "deref", "display"] } derive_more = { version = "2.1.1", features = [
"as_ref",
"deref",
"display",
"from",
"into",
] }
diesel-derive-newtype = "2.1.2" diesel-derive-newtype = "2.1.2"
# SQLite, statically bundled unless the `sqlite_system` feature is enabled # SQLite, statically bundled unless the `sqlite_system` feature is enabled
@ -109,7 +132,7 @@ subtle = "2.6.1"
uuid = { version = "1.23.1", features = ["v4"] } uuid = { version = "1.23.1", features = ["v4"] }
# Date and time libraries # Date and time libraries
chrono = { version = "0.4.44", features = ["clock", "serde"], default-features = false } chrono = { version = "0.4.44", default-features = false, features = ["clock", "serde"] }
chrono-tz = "0.10.4" chrono-tz = "0.10.4"
time = "0.3.47" time = "0.3.47"
@ -120,13 +143,13 @@ job_scheduler_ng = "2.4.0"
data-encoding = "2.11.0" data-encoding = "2.11.0"
# JWT library # JWT library
jsonwebtoken = { version = "10.4.0", features = ["use_pem", "rust_crypto"], default-features = false } jsonwebtoken = { version = "10.4.0", default-features = false, features = ["rust_crypto", "use_pem"] }
# TOTP library # TOTP library
totp-lite = "2.0.1" totp-lite = "2.0.1"
# Yubico Library # Yubico Library
yubico = { package = "yubico_ng", version = "0.15.0", features = ["online-tokio"], default-features = false } yubico = { package = "yubico_ng", version = "0.15.0", default-features = false, features = ["online-tokio"] }
# WebAuthn libraries # WebAuthn libraries
# danger-allow-state-serialisation is needed to save the state in the db # danger-allow-state-serialisation is needed to save the state in the db
@ -139,7 +162,20 @@ webauthn-rs-core = "0.5.5"
url = "2.5.8" url = "2.5.8"
# Email libraries # Email libraries
lettre = { version = "0.11.22", features = ["smtp-transport", "sendmail-transport", "builder", "serde", "hostname", "tracing", "tokio1-rustls", "ring", "rustls-native-certs"], default-features = false } lettre = { version = "0.11.22", default-features = false, features = [
# Misc
"tracing",
"serde",
"builder",
"hostname",
# TLS/Security
"ring",
"rustls-native-certs",
"tokio1-rustls",
# Transport
"smtp-transport",
"sendmail-transport",
] }
percent-encoding = "2.3.2" # URL encoding library used for URL's in the emails percent-encoding = "2.3.2" # URL encoding library used for URL's in the emails
email_address = "0.2.9" email_address = "0.2.9"
@ -147,12 +183,33 @@ email_address = "0.2.9"
handlebars = { version = "6.4.0", features = ["dir_source"] } handlebars = { version = "6.4.0", features = ["dir_source"] }
# HTTP client (Used for favicons, version check, DUO and HIBP API) # HTTP client (Used for favicons, version check, DUO and HIBP API)
reqwest = { version = "0.13.3", features = ["rustls-no-provider", "stream", "json", "form", "deflate", "gzip", "brotli", "zstd", "socks", "cookies", "charset", "http2", "system-proxy"], default-features = false} reqwest = { version = "0.13.3", default-features = false, features = [
# Misc
"charset",
"cookies",
"http2",
"json",
"form",
"rustls-no-provider",
"stream",
# Compression
"brotli",
"deflate",
"gzip",
"zstd",
# Proxy
"socks",
"system-proxy",
] }
hickory-resolver = "0.26.1" hickory-resolver = "0.26.1"
# Favicon extraction libraries # Favicon extraction libraries
html5gum = "0.8.3" html5gum = "0.8.3"
regex = { version = "1.12.3", features = ["std", "perf", "unicode-perl"], default-features = false } regex = { version = "1.12.3", default-features = false, features = [
"perf",
"std",
"unicode-perl",
] }
data-url = "0.3.2" data-url = "0.3.2"
bytes = "1.11.1" bytes = "1.11.1"
svg-hush = "0.9.6" svg-hush = "0.9.6"
@ -183,7 +240,7 @@ semver = "1.0.28"
# Allow overriding the default memory allocator # Allow overriding the default memory allocator
# Mainly used for the musl builds, since the default musl malloc is very slow # Mainly used for the musl builds, since the default musl malloc is very slow
mimalloc = { version = "0.1.50", features = ["secure"], default-features = false, optional = true } mimalloc = { version = "0.1.50", optional = true, default-features = false, features = ["secure"] }
which = "8.0.2" which = "8.0.2"
@ -197,10 +254,15 @@ rpassword = "7.5.2"
grass_compiler = { version = "0.13.4", default-features = false } grass_compiler = { version = "0.13.4", default-features = false }
# File are accessed through Apache OpenDAL # File are accessed through Apache OpenDAL
opendal = { version = "0.56.0", features = ["services-fs"], default-features = false } opendal = { version = "0.56.0", default-features = false, features = ["services-fs"] }
# For retrieving AWS credentials, including temporary SSO credentials # For retrieving AWS credentials, including temporary SSO credentials
aws-config = { version = "1.8.16", features = ["behavior-version-latest", "rt-tokio", "credentials-process", "sso"], default-features = false, optional = true } aws-config = { version = "1.8.16", optional = true, default-features = false, features = [
"behavior-version-latest",
"credentials-process",
"rt-tokio",
"sso",
] }
aws-credential-types = { version = "1.2.14", optional = true } aws-credential-types = { version = "1.2.14", optional = true }
aws-smithy-runtime-api = { version = "1.12.0", optional = true } aws-smithy-runtime-api = { version = "1.12.0", optional = true }
http = { version = "1.4.0", optional = true } http = { version = "1.4.0", optional = true }
@ -265,77 +327,74 @@ unsafe_code = "forbid"
non_ascii_idents = "forbid" non_ascii_idents = "forbid"
# Deny # Deny
deprecated_in_future = "deny" warnings = "deny" # Explicitly deny all warnings since we deny all warnings in the end
# Deny lint groups
deprecated_safe = { level = "deny", priority = -1 } deprecated_safe = { level = "deny", priority = -1 }
future_incompatible = { level = "deny", priority = -1 } future_incompatible = { level = "deny", priority = -1 }
keyword_idents = { level = "deny", priority = -1 } keyword_idents = { level = "deny", priority = -1 }
let_underscore = { level = "deny", priority = -1 } let_underscore = { level = "deny", priority = -1 }
nonstandard_style = { level = "deny", priority = -1 } nonstandard_style = { level = "deny", priority = -1 }
noop_method_call = "deny"
refining_impl_trait = { level = "deny", priority = -1 } refining_impl_trait = { level = "deny", priority = -1 }
rust_2018_idioms = { level = "deny", priority = -1 } rust_2018_idioms = { level = "deny", priority = -1 }
rust_2021_compatibility = { level = "deny", priority = -1 } rust_2021_compatibility = { level = "deny", priority = -1 }
rust_2024_compatibility = { level = "deny", priority = -1 } rust_2024_compatibility = { level = "deny", priority = -1 }
unused = { level = "deny", priority = -1 }
# Deny individual lints
closure_returning_async_block = "deny"
deprecated_in_future = "deny"
single_use_lifetimes = "deny" single_use_lifetimes = "deny"
trivial_casts = "deny" trivial_casts = "deny"
trivial_numeric_casts = "deny" trivial_numeric_casts = "deny"
unused = { level = "deny", priority = -1 }
unused_import_braces = "deny" unused_import_braces = "deny"
unused_lifetimes = "deny" unused_lifetimes = "deny"
unused_qualifications = "deny" unused_qualifications = "deny"
variant_size_differences = "deny" variant_size_differences = "deny"
# Allow the following lints since these cause issues with Rust v1.84.0 or newer
# Building Vaultwarden with Rust v1.85.0 with edition 2024 also works without issues
edition_2024_expr_fragment_specifier = "allow" # Once changed to Rust 2024 this should be removed and macro's should be validated again
if_let_rescope = "allow"
tail_expr_drop_order = "allow"
# https://rust-lang.github.io/rust-clippy/stable/index.html # https://rust-lang.github.io/rust-clippy/stable/index.html
[workspace.lints.clippy] [workspace.lints.clippy]
# Warn # Warn only so you can still use these during development, but not in the final code
dbg_macro = "warn" dbg_macro = "warn"
todo = "warn" todo = "warn"
# Ignore/Allow # Ignore/Allow
result_large_err = "allow" result_large_err = "allow"
# Deny # Warn on these lint group (Some might be warn by default already though)
# Will be denied during CI!
complexity = { level = "warn", priority = -1 }
pedantic = { level = "warn", priority = -1 }
perf = { level = "warn", priority = -1 }
style = { level = "warn", priority = -1 }
suspicious = { level = "warn", priority = -1 }
# Deny individual lints
branches_sharing_code = "deny" branches_sharing_code = "deny"
case_sensitive_file_extension_comparisons = "deny"
cast_lossless = "deny"
clone_on_ref_ptr = "deny" clone_on_ref_ptr = "deny"
duration_suboptimal_units = "deny"
equatable_if_let = "deny" equatable_if_let = "deny"
excessive_precision = "deny"
filter_map_next = "deny"
float_cmp_const = "deny" float_cmp_const = "deny"
implicit_clone = "deny"
inefficient_to_string = "deny"
iter_on_empty_collections = "deny" iter_on_empty_collections = "deny"
iter_on_single_items = "deny" iter_on_single_items = "deny"
linkedlist = "deny"
macro_use_imports = "deny"
manual_assert = "deny"
manual_instant_elapsed = "deny"
manual_string_new = "deny"
match_wildcard_for_single_variants = "deny"
mem_forget = "deny" mem_forget = "deny"
needless_borrow = "deny"
needless_collect = "deny" needless_collect = "deny"
needless_continue = "deny"
needless_lifetimes = "deny"
option_option = "deny"
redundant_clone = "deny" redundant_clone = "deny"
ref_option = "deny"
string_add_assign = "deny"
unnecessary_join = "deny"
unnecessary_self_imports = "deny" unnecessary_self_imports = "deny"
unnested_or_patterns = "deny"
unused_async = "deny"
unused_self = "deny"
useless_let_if_seq = "deny" useless_let_if_seq = "deny"
verbose_file_reads = "deny" verbose_file_reads = "deny"
zero_sized_map_values = "deny" str_to_string = "deny"
# Pedantic Opt-Outs
inline_always = "allow" # We use this sparsely
struct_field_names = "allow" # Noisy and some items are Bitwarden controlled
large_futures = "allow" # Causes a fail in some Rocket macro's, since we experience no issues, allow it
too_many_lines = "allow" # For now, allow this, good to enable in the future and see if we can refactor
unnecessary_wraps = "allow" # Too much false positives because of Rocket integrations
# We do not use these doc items
doc_link_with_quotes = "allow"
doc_markdown = "allow"
missing_errors_doc = "allow"
missing_panics_doc = "allow"
[lints] [lints]
workspace = true workspace = true

10
build.rs

@ -1,5 +1,4 @@
use std::env; use std::{env, io::Error, process::Command};
use std::process::Command;
fn main() { fn main() {
// These allow using e.g. #[cfg(mysql)] instead of #[cfg(feature = "mysql")], which helps when trying to add them through macros // These allow using e.g. #[cfg(mysql)] instead of #[cfg(feature = "mysql")], which helps when trying to add them through macros
@ -42,13 +41,12 @@ fn main() {
} }
} }
fn run(args: &[&str]) -> Result<String, std::io::Error> { fn run(args: &[&str]) -> Result<String, Error> {
let out = Command::new(args[0]).args(&args[1..]).output()?; let out = Command::new(args[0]).args(&args[1..]).output()?;
if !out.status.success() { if !out.status.success() {
use std::io::Error;
return Err(Error::other("Command not successful")); return Err(Error::other("Command not successful"));
} }
Ok(String::from_utf8(out.stdout).unwrap().trim().to_string()) Ok(String::from_utf8(out.stdout).unwrap().trim().to_owned())
} }
/// This method reads info from Git, namely tags, branch, and revision /// This method reads info from Git, namely tags, branch, and revision
@ -58,7 +56,7 @@ fn run(args: &[&str]) -> Result<String, std::io::Error> {
/// - `env!("GIT_BRANCH")` /// - `env!("GIT_BRANCH")`
/// - `env!("GIT_REV")` /// - `env!("GIT_REV")`
/// - `env!("VW_VERSION")` /// - `env!("VW_VERSION")`
fn version_from_git_info() -> Result<String, std::io::Error> { fn version_from_git_info() -> Result<String, Error> {
// The exact tag for the current commit, can be empty when // The exact tag for the current commit, can be empty when
// the current commit doesn't have an associated tag // the current commit doesn't have an associated tag
let exact_tag = run(&["git", "describe", "--abbrev=0", "--tags", "--exact-match"]).ok(); let exact_tag = run(&["git", "describe", "--abbrev=0", "--tags", "--exact-match"]).ok();

9
macros/src/lib.rs

@ -1,14 +1,15 @@
use proc_macro::TokenStream; use proc_macro::TokenStream;
use quote::quote; use quote::quote;
use syn::{DeriveInput, parse_macro_input};
#[proc_macro_derive(UuidFromParam)] #[proc_macro_derive(UuidFromParam)]
pub fn derive_uuid_from_param(input: TokenStream) -> TokenStream { pub fn derive_uuid_from_param(input: TokenStream) -> TokenStream {
let ast = syn::parse(input).unwrap(); let ast = parse_macro_input!(input as DeriveInput);
impl_derive_uuid_macro(&ast) impl_derive_uuid_macro(&ast)
} }
fn impl_derive_uuid_macro(ast: &syn::DeriveInput) -> TokenStream { fn impl_derive_uuid_macro(ast: &DeriveInput) -> TokenStream {
let name = &ast.ident; let name = &ast.ident;
let gen_derive = quote! { let gen_derive = quote! {
#[automatically_derived] #[automatically_derived]
@ -30,12 +31,12 @@ fn impl_derive_uuid_macro(ast: &syn::DeriveInput) -> TokenStream {
#[proc_macro_derive(IdFromParam)] #[proc_macro_derive(IdFromParam)]
pub fn derive_id_from_param(input: TokenStream) -> TokenStream { pub fn derive_id_from_param(input: TokenStream) -> TokenStream {
let ast = syn::parse(input).unwrap(); let ast = parse_macro_input!(input as DeriveInput);
impl_derive_safestring_macro(&ast) impl_derive_safestring_macro(&ast)
} }
fn impl_derive_safestring_macro(ast: &syn::DeriveInput) -> TokenStream { fn impl_derive_safestring_macro(ast: &DeriveInput) -> TokenStream {
let name = &ast.ident; let name = &ast.ident;
let gen_derive = quote! { let gen_derive = quote! {
#[automatically_derived] #[automatically_derived]

2
rustfmt.toml

@ -1,4 +1,4 @@
edition = "2021" edition = "2024"
max_width = 120 max_width = 120
newline_style = "Unix" newline_style = "Unix"
use_small_heuristics = "Off" use_small_heuristics = "Off"

111
src/api/admin.rs

@ -2,40 +2,40 @@ use std::{env, sync::LazyLock};
use reqwest::Method; use reqwest::Method;
use rocket::{ use rocket::{
Catcher, Route,
form::Form, form::Form,
http::{Cookie, CookieJar, MediaType, SameSite, Status}, http::{Cookie, CookieJar, MediaType, SameSite, Status},
request::{FromRequest, Outcome, Request}, request::{FromRequest, Outcome, Request},
response::{content::RawHtml as Html, Redirect}, response::{Redirect, content::RawHtml as Html},
serde::json::Json, serde::json::Json,
Catcher, Route,
}; };
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde_json::Value; use serde_json::Value;
use crate::{ use crate::{
CONFIG, VERSION,
api::{ api::{
ApiResult, EmptyResult, JsonResult, Notify,
core::{log_event, two_factor}, core::{log_event, two_factor},
unregister_push_device, ApiResult, EmptyResult, JsonResult, Notify, unregister_push_device,
}, },
auth::{decode_admin, encode_jwt, generate_admin_claims, ClientIp, Secure}, auth::{ClientIp, Secure, decode_admin, encode_jwt, generate_admin_claims},
config::ConfigBuilder, config::ConfigBuilder,
db::{ db::{
backup_sqlite, get_sql_server_version, ACTIVE_DB_TYPE, DbConn, DbConnType, backup_sqlite, get_sql_server_version,
models::{ models::{
Attachment, Cipher, Collection, Device, Event, EventType, Group, Invitation, Membership, MembershipId, Attachment, Cipher, Collection, Device, Event, EventType, Group, Invitation, Membership, MembershipId,
MembershipType, OrgPolicy, Organization, OrganizationId, SsoUser, TwoFactor, User, UserId, MembershipType, OrgPolicy, Organization, OrganizationId, SsoUser, TwoFactor, User, UserId,
}, },
DbConn, DbConnType, ACTIVE_DB_TYPE,
}, },
error::{Error, MapResult}, error::{Error, MapResult},
http_client::make_http_request, http_client::make_http_request,
mail, mail,
sso::FAKE_SSO_IDENTIFIER, sso::FAKE_SSO_IDENTIFIER,
util::{ util::{
container_base_image, format_naive_datetime_local, get_active_web_release, get_display_size, FeatureFlagFilter, NumberOrString, container_base_image, format_naive_datetime_local, get_active_web_release,
is_running_in_container, parse_experimental_client_feature_flags, FeatureFlagFilter, NumberOrString, get_display_size, is_running_in_container, parse_experimental_client_feature_flags,
}, },
CONFIG, VERSION,
}; };
pub fn routes() -> Vec<Route> { pub fn routes() -> Vec<Route> {
@ -93,8 +93,7 @@ static DB_TYPE: LazyLock<&str> = LazyLock::new(|| match ACTIVE_DB_TYPE.get() {
}); });
#[cfg(sqlite)] #[cfg(sqlite)]
static CAN_BACKUP: LazyLock<bool> = static CAN_BACKUP: LazyLock<bool> = LazyLock::new(|| ACTIVE_DB_TYPE.get().is_some_and(|t| *t == DbConnType::Sqlite));
LazyLock::new(|| ACTIVE_DB_TYPE.get().map(|t| *t == DbConnType::Sqlite).unwrap_or(false));
#[cfg(not(sqlite))] #[cfg(not(sqlite))]
static CAN_BACKUP: LazyLock<bool> = LazyLock::new(|| false); static CAN_BACKUP: LazyLock<bool> = LazyLock::new(|| false);
@ -200,13 +199,7 @@ fn post_admin_login(
} }
// If the token is invalid, redirect to login page // If the token is invalid, redirect to login page
if !_validate_token(&data.token) { if validate_token(&data.token) {
error!("Invalid admin token. IP: {}", ip.ip);
Err(AdminResponse::Unauthorized(render_admin_login(
Some("Invalid admin token, please try again."),
redirect.as_deref(),
)))
} else {
// If the token received is valid, generate JWT and save it as a cookie // If the token received is valid, generate JWT and save it as a cookie
let claims = generate_admin_claims(); let claims = generate_admin_claims();
let jwt = encode_jwt(&claims); let jwt = encode_jwt(&claims);
@ -224,10 +217,16 @@ fn post_admin_login(
} else { } else {
Err(AdminResponse::Ok(render_admin_page())) Err(AdminResponse::Ok(render_admin_page()))
} }
} else {
error!("Invalid admin token. IP: {}", ip.ip);
Err(AdminResponse::Unauthorized(render_admin_login(
Some("Invalid admin token, please try again."),
redirect.as_deref(),
)))
} }
} }
fn _validate_token(token: &str) -> bool { fn validate_token(token: &str) -> bool {
match CONFIG.admin_token().as_ref() { match CONFIG.admin_token().as_ref() {
None => false, None => false,
Some(t) if t.starts_with("$argon2") => { Some(t) if t.starts_with("$argon2") => {
@ -307,21 +306,14 @@ async fn get_user_or_404(user_id: &UserId, conn: &DbConn) -> ApiResult<User> {
#[post("/invite", format = "application/json", data = "<data>")] #[post("/invite", format = "application/json", data = "<data>")]
async fn invite_user(data: Json<InviteData>, _token: AdminToken, conn: DbConn) -> JsonResult { async fn invite_user(data: Json<InviteData>, _token: AdminToken, conn: DbConn) -> JsonResult {
let data: InviteData = data.into_inner(); async fn generate_invite(user: &User, conn: &DbConn) -> EmptyResult {
if User::find_by_mail(&data.email, &conn).await.is_some() {
err_code!("User already exists", Status::Conflict.code)
}
let mut user = User::new(&data.email, None);
async fn _generate_invite(user: &User, conn: &DbConn) -> EmptyResult {
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
let org_id: OrganizationId = if CONFIG.sso_enabled() { let org_id: OrganizationId = if CONFIG.sso_enabled() {
FAKE_SSO_IDENTIFIER.into() FAKE_SSO_IDENTIFIER.into()
} else { } else {
FAKE_ADMIN_UUID.into() FAKE_ADMIN_UUID.into()
}; };
let member_id: MembershipId = FAKE_ADMIN_UUID.to_string().into(); let member_id: MembershipId = FAKE_ADMIN_UUID.to_owned().into();
mail::send_invite(user, org_id, member_id, &CONFIG.invitation_org_name(), None).await mail::send_invite(user, org_id, member_id, &CONFIG.invitation_org_name(), None).await
} else { } else {
let invitation = Invitation::new(&user.email); let invitation = Invitation::new(&user.email);
@ -329,7 +321,14 @@ async fn invite_user(data: Json<InviteData>, _token: AdminToken, conn: DbConn) -
} }
} }
_generate_invite(&user, &conn).await.map_err(|e| e.with_code(Status::InternalServerError.code))?; let data: InviteData = data.into_inner();
if User::find_by_mail(&data.email, &conn).await.is_some() {
err_code!("User already exists", Status::Conflict.code)
}
let mut user = User::new(&data.email, None);
generate_invite(&user, &conn).await.map_err(|e| e.with_code(Status::InternalServerError.code))?;
user.save(&conn).await.map_err(|e| e.with_code(Status::InternalServerError.code))?; user.save(&conn).await.map_err(|e| e.with_code(Status::InternalServerError.code))?;
Ok(Json(user.to_json(&conn).await)) Ok(Json(user.to_json(&conn).await))
@ -386,7 +385,7 @@ async fn users_overview(_token: AdminToken, conn: DbConn) -> ApiResult<Html<Stri
None => json!("Never"), None => json!("Never"),
}; };
usr["sso_identifier"] = json!(sso_u.map(|u| u.identifier.to_string()).unwrap_or(String::new())); usr["sso_identifier"] = json!(sso_u.map_or(String::new(), |u| u.identifier.to_string()));
users_json.push(usr); users_json.push(usr);
} }
@ -472,7 +471,7 @@ async fn deauth_user(user_id: UserId, _token: AdminToken, conn: DbConn, nt: Noti
match unregister_push_device(device.push_uuid.as_ref()).await { match unregister_push_device(device.push_uuid.as_ref()).await {
Ok(r) => r, Ok(r) => r,
Err(e) => error!("Unable to unregister devices from Bitwarden server: {e}"), Err(e) => error!("Unable to unregister devices from Bitwarden server: {e}"),
}; }
} }
} }
@ -528,7 +527,7 @@ async fn resend_user_invite(user_id: UserId, _token: AdminToken, conn: DbConn) -
} else { } else {
FAKE_ADMIN_UUID.into() FAKE_ADMIN_UUID.into()
}; };
let member_id: MembershipId = FAKE_ADMIN_UUID.to_string().into(); let member_id: MembershipId = FAKE_ADMIN_UUID.to_owned().into();
mail::send_invite(&user, org_id, member_id, &CONFIG.invitation_org_name(), None).await mail::send_invite(&user, org_id, member_id, &CONFIG.invitation_org_name(), None).await
} else { } else {
Ok(()) Ok(())
@ -554,9 +553,10 @@ async fn update_membership_type(data: Json<MembershipTypeData>, token: AdminToke
err!("The specified user isn't member of the organization") err!("The specified user isn't member of the organization")
}; };
let new_type = match MembershipType::from_str(&data.user_type.into_string()) { let new_type = if let Some(new_type) = MembershipType::from_str(&data.user_type.into_string()) {
Some(new_type) => new_type as i32, new_type as i32
None => err!("Invalid type"), } else {
err!("Invalid type")
}; };
if member_to_edit.atype == MembershipType::Owner && new_type != MembershipType::Owner { if member_to_edit.atype == MembershipType::Owner && new_type != MembershipType::Owner {
@ -656,35 +656,35 @@ async fn get_release_info(has_http_access: bool) -> (String, String, String) {
.await .await
{ {
Ok(r) => r.tag_name, Ok(r) => r.tag_name,
_ => "-".to_string(), _ => "-".to_owned(),
}, },
match get_json_api::<GitCommit>("https://api.github.com/repos/dani-garcia/vaultwarden/commits/main").await { match get_json_api::<GitCommit>("https://api.github.com/repos/dani-garcia/vaultwarden/commits/main").await {
Ok(mut c) => { Ok(mut c) => {
c.sha.truncate(8); c.sha.truncate(8);
c.sha c.sha
} }
_ => "-".to_string(), _ => "-".to_owned(),
}, },
// Do not fetch the web-vault version when running within a container // Do not fetch the web-vault version when running within a container
// The web-vault version is embedded within the container it self, and should not be updated manually // The web-vault version is embedded within the container it self, and should not be updated manually
match get_json_api::<GitRelease>("https://api.github.com/repos/dani-garcia/bw_web_builds/releases/latest") match get_json_api::<GitRelease>("https://api.github.com/repos/dani-garcia/bw_web_builds/releases/latest")
.await .await
{ {
Ok(r) => r.tag_name.trim_start_matches('v').to_string(), Ok(r) => r.tag_name.trim_start_matches('v').to_owned(),
_ => "-".to_string(), _ => "-".to_owned(),
}, },
) )
} else { } else {
("-".to_string(), "-".to_string(), "-".to_string()) ("-".to_owned(), "-".to_owned(), "-".to_owned())
} }
} }
async fn get_ntp_time(has_http_access: bool) -> String { async fn get_ntp_time(has_http_access: bool) -> String {
if has_http_access { if has_http_access && let Ok(cf_trace) = get_text_api("https://cloudflare.com/cdn-cgi/trace").await {
if let Ok(cf_trace) = get_text_api("https://cloudflare.com/cdn-cgi/trace").await {
for line in cf_trace.lines() { for line in cf_trace.lines() {
if let Some((key, value)) = line.split_once('=') { if let Some((key, value)) = line.split_once('=')
if key == "ts" { && key == "ts"
{
let ts = value.split_once('.').map_or(value, |(s, _)| s); let ts = value.split_once('.').map_or(value, |(s, _)| s);
if let Ok(dt) = chrono::DateTime::parse_from_str(ts, "%s") { if let Ok(dt) = chrono::DateTime::parse_from_str(ts, "%s") {
return dt.format("%Y-%m-%d %H:%M:%S UTC").to_string(); return dt.format("%Y-%m-%d %H:%M:%S UTC").to_string();
@ -693,8 +693,6 @@ async fn get_ntp_time(has_http_access: bool) -> String {
} }
} }
} }
}
}
String::from("Unable to fetch NTP time.") String::from("Unable to fetch NTP time.")
} }
@ -734,7 +732,7 @@ async fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> A
// Check if we are able to resolve DNS entries // Check if we are able to resolve DNS entries
let dns_resolved = match ("github.com", 0).to_socket_addrs().map(|mut i| i.next()) { let dns_resolved = match ("github.com", 0).to_socket_addrs().map(|mut i| i.next()) {
Ok(Some(a)) => a.ip().to_string(), Ok(Some(a)) => a.ip().to_string(),
_ => "Unable to resolve domain name.".to_string(), _ => "Unable to resolve domain name.".to_owned(),
}; };
let (latest_vw_release, latest_vw_commit, latest_web_release) = get_release_info(has_http_access).await; let (latest_vw_release, latest_vw_commit, latest_web_release) = get_release_info(has_http_access).await;
@ -745,7 +743,7 @@ async fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> A
let invalid_feature_flags: Vec<String> = parse_experimental_client_feature_flags( let invalid_feature_flags: Vec<String> = parse_experimental_client_feature_flags(
&CONFIG.experimental_client_feature_flags(), &CONFIG.experimental_client_feature_flags(),
FeatureFlagFilter::InvalidOnly, &FeatureFlagFilter::InvalidOnly,
) )
.into_keys() .into_keys()
.collect(); .collect();
@ -834,33 +832,30 @@ impl<'r> FromRequest<'r> for AdminToken {
type Error = &'static str; type Error = &'static str;
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let ip = match ClientIp::from_request(request).await { let Outcome::Success(ip) = ClientIp::from_request(request).await else {
Outcome::Success(ip) => ip, err_handler!("Error getting Client IP")
_ => err_handler!("Error getting Client IP"),
}; };
if !CONFIG.disable_admin_token() { if !CONFIG.disable_admin_token() {
let cookies = request.cookies(); let cookies = request.cookies();
let access_token = match cookies.get(COOKIE_NAME) { let access_token = if let Some(cookie) = cookies.get(COOKIE_NAME) {
Some(cookie) => cookie.value(), cookie.value()
None => { } else {
let requested_page = let requested_page =
request.segments::<std::path::PathBuf>(0..).unwrap_or_default().display().to_string(); request.segments::<std::path::PathBuf>(0..).unwrap_or_default().display().to_string();
// When the requested page is empty, it is `/admin`, in that case, Forward, so it will render the login page // When the requested page is empty, it is `/admin`, in that case, Forward, so it will render the login page
// Else, return a 401 failure, which will be caught // Else, return a 401 failure, which will be caught
if requested_page.is_empty() { if requested_page.is_empty() {
return Outcome::Forward(Status::Unauthorized); return Outcome::Forward(Status::Unauthorized);
} else {
return Outcome::Error((Status::Unauthorized, "Unauthorized"));
}
} }
return Outcome::Error((Status::Unauthorized, "Unauthorized"));
}; };
if decode_admin(access_token).is_err() { if decode_admin(access_token).is_err() {
// Remove admin cookie // Remove admin cookie
cookies.remove(Cookie::build(COOKIE_NAME).path(admin_path())); cookies.remove(Cookie::build(COOKIE_NAME).path(admin_path()));
error!("Invalid or expired admin JWT. IP: {}.", &ip.ip); error!("Invalid or expired admin JWT. IP: {}.", ip.ip);
return Outcome::Error((Status::Unauthorized, "Session expired")); return Outcome::Error((Status::Unauthorized, "Session expired"));
} }
} }

122
src/api/core/accounts.rs

@ -1,34 +1,37 @@
use std::collections::HashSet; use std::collections::HashSet;
use crate::db::DbPool;
use chrono::Utc; use chrono::Utc;
use rocket::serde::json::Json; use rocket::{
http::Status,
request::{FromRequest, Outcome, Request},
serde::json::Json,
};
use serde_json::Value; use serde_json::Value;
use crate::{ use crate::{
CONFIG,
api::{ api::{
AnonymousNotify, ApiResult, EmptyResult, JsonResult, Notify, PasswordOrOtpData, UpdateType,
core::{accept_org_invite, log_user_event, two_factor::email}, core::{accept_org_invite, log_user_event, two_factor::email},
master_password_policy, register_push_device, unregister_push_device, AnonymousNotify, ApiResult, EmptyResult, master_password_policy, register_push_device, unregister_push_device,
JsonResult, Notify, PasswordOrOtpData, UpdateType,
}, },
auth::{decode_delete, decode_invite, decode_verify_email, ClientHeaders, Headers}, auth::{ClientHeaders, Headers, decode_delete, decode_invite, decode_verify_email},
crypto, crypto,
db::{ db::{
DbConn, DbPool,
models::{ models::{
AuthRequest, AuthRequestId, Cipher, CipherId, Device, DeviceId, DeviceType, EmergencyAccess, AuthRequest, AuthRequestId, Cipher, CipherId, Device, DeviceId, DeviceType, DeviceWithAuthRequest,
EmergencyAccessId, EventType, Folder, FolderId, Invitation, Membership, MembershipId, OrgPolicy, EmergencyAccess, EmergencyAccessId, EventType, Folder, FolderId, Invitation, Membership, MembershipId,
OrgPolicyType, Organization, OrganizationId, Send, SendId, User, UserId, UserKdfType, OrgPolicy, OrgPolicyType, Organization, OrganizationId, Send, SendId, User, UserId, UserKdfType,
}, },
DbConn,
}, },
mail, mail,
util::{deser_opt_nonempty_str, format_date, NumberOrString}, util::{NumberOrString, deser_opt_nonempty_str, format_date},
CONFIG,
}; };
use rocket::{ use super::{
http::Status, ciphers::{CipherData, update_cipher_from_data},
request::{FromRequest, Outcome, Request}, sends::{SendData, update_send_from_data},
}; };
pub fn routes() -> Vec<rocket::Route> { pub fn routes() -> Vec<rocket::Route> {
@ -54,9 +57,9 @@ pub fn routes() -> Vec<rocket::Route> {
delete_account, delete_account,
revision_date, revision_date,
password_hint, password_hint,
prelogin, post_prelogin,
verify_password, verify_password,
api_key, post_api_key,
rotate_api_key, rotate_api_key,
get_known_device, get_known_device,
get_all_devices, get_all_devices,
@ -142,7 +145,7 @@ fn clean_password_hint(password_hint: Option<&String>) -> Option<String> {
None => None, None => None,
Some(h) => match h.trim() { Some(h) => match h.trim() {
"" => None, "" => None,
ht => Some(ht.to_string()), ht => Some(ht.to_owned()),
}, },
} }
} }
@ -166,7 +169,7 @@ async fn is_email_2fa_required(member_id: Option<MembershipId>, conn: &DbConn) -
false false
} }
pub async fn _register(data: Json<RegisterData>, email_verification: bool, conn: DbConn) -> JsonResult { pub async fn register(data: Json<RegisterData>, email_verification: bool, conn: DbConn) -> JsonResult {
let mut data: RegisterData = data.into_inner(); let mut data: RegisterData = data.into_inner();
let email = data.email.to_lowercase(); let email = data.email.to_lowercase();
@ -237,11 +240,11 @@ pub async fn _register(data: Json<RegisterData>, email_verification: bool, conn:
// Check if the length of the username exceeds 50 characters (Same is Upstream Bitwarden) // Check if the length of the username exceeds 50 characters (Same is Upstream Bitwarden)
// This also prevents issues with very long usernames causing to large JWT's. See #2419 // This also prevents issues with very long usernames causing to large JWT's. See #2419
if let Some(ref name) = data.name { if let Some(ref name) = data.name
if name.len() > 50 { && name.len() > 50
{
err!("The field Name must be a string with a maximum length of 50."); err!("The field Name must be a string with a maximum length of 50.");
} }
}
// Check against the password hint setting here so if it fails, the user // Check against the password hint setting here so if it fails, the user
// can retry without losing their invitation below. // can retry without losing their invitation below.
@ -373,8 +376,10 @@ async fn post_set_password(data: Json<SetPasswordData>, headers: Headers, conn:
user.public_key = Some(keys.public_key); user.public_key = Some(keys.public_key);
} }
if let Some(identifier) = data.org_identifier { if let Some(identifier) = data.org_identifier
if identifier != crate::sso::FAKE_SSO_IDENTIFIER && identifier != crate::api::admin::FAKE_ADMIN_UUID { && identifier != crate::sso::FAKE_SSO_IDENTIFIER
&& identifier != crate::api::admin::FAKE_ADMIN_UUID
{
let Some(org) = Organization::find_by_uuid(&identifier.into(), &conn).await else { let Some(org) = Organization::find_by_uuid(&identifier.into(), &conn).await else {
err!("Failed to retrieve the associated organization") err!("Failed to retrieve the associated organization")
}; };
@ -385,7 +390,6 @@ async fn post_set_password(data: Json<SetPasswordData>, headers: Headers, conn:
accept_org_invite(&user, membership, None, &conn).await?; accept_org_invite(&user, membership, None, &conn).await?;
} }
}
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
mail::send_welcome(&user.email.to_lowercase()).await?; mail::send_welcome(&user.email.to_lowercase()).await?;
@ -451,11 +455,11 @@ async fn put_avatar(data: Json<AvatarData>, headers: Headers, conn: DbConn) -> J
// It looks like it only supports the 6 hex color format. // It looks like it only supports the 6 hex color format.
// If you try to add the short value it will not show that color. // If you try to add the short value it will not show that color.
// Check and force 7 chars, including the #. // Check and force 7 chars, including the #.
if let Some(color) = &data.avatar_color { if let Some(color) = &data.avatar_color
if color.len() != 7 { && color.len() != 7
{
err!("The field AvatarColor must be a HTML/Hex color code with a length of 7 characters") err!("The field AvatarColor must be a HTML/Hex color code with a length of 7 characters")
} }
}
let mut user = headers.user; let mut user = headers.user;
user.avatar_color = data.avatar_color; user.avatar_color = data.avatar_color;
@ -668,9 +672,6 @@ struct UpdateResetPasswordData {
reset_password_key: String, reset_password_key: String,
} }
use super::ciphers::CipherData;
use super::sends::{update_send_from_data, SendData};
#[derive(Deserialize)] #[derive(Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
struct KeyData { struct KeyData {
@ -840,7 +841,7 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, conn: DbConn, nt:
}; };
saved_folder.name = folder_data.name; saved_folder.name = folder_data.name;
saved_folder.save(&conn).await? saved_folder.save(&conn).await?;
} }
} }
@ -853,7 +854,7 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, conn: DbConn, nt:
}; };
saved_emergency_access.key_encrypted = Some(emergency_access_data.key_encrypted); saved_emergency_access.key_encrypted = Some(emergency_access_data.key_encrypted);
saved_emergency_access.save(&conn).await? saved_emergency_access.save(&conn).await?;
} }
// Update reset password data // Update reset password data
@ -865,7 +866,7 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, conn: DbConn, nt:
}; };
membership.reset_password_key = Some(reset_password_data.reset_password_key); membership.reset_password_key = Some(reset_password_data.reset_password_key);
membership.save(&conn).await? membership.save(&conn).await?;
} }
// Update send data // Update send data
@ -878,8 +879,6 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, conn: DbConn, nt:
} }
// Update cipher data // Update cipher data
use super::ciphers::update_cipher_from_data;
for cipher_data in data.account_data.ciphers { for cipher_data in data.account_data.ciphers {
if cipher_data.organization_id.is_none() { if cipher_data.organization_id.is_none() {
let Some(saved_cipher) = existing_ciphers.iter_mut().find(|c| &c.uuid == cipher_data.id.as_ref().unwrap()) let Some(saved_cipher) = existing_ciphers.iter_mut().find(|c| &c.uuid == cipher_data.id.as_ref().unwrap())
@ -890,7 +889,7 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, conn: DbConn, nt:
// Prevent triggering cipher updates via WebSockets by settings UpdateType::None // Prevent triggering cipher updates via WebSockets by settings UpdateType::None
// The user sessions are invalidated because all the ciphers were re-encrypted and thus triggering an update could cause issues. // The user sessions are invalidated because all the ciphers were re-encrypted and thus triggering an update could cause issues.
// We force the users to logout after the user has been saved to try and prevent these issues. // We force the users to logout after the user has been saved to try and prevent these issues.
update_cipher_from_data(saved_cipher, cipher_data, &headers, None, &conn, &nt, UpdateType::None).await? update_cipher_from_data(saved_cipher, cipher_data, &headers, None, &conn, &nt, UpdateType::None).await?;
} }
} }
@ -1020,24 +1019,22 @@ async fn post_email(data: Json<ChangeEmailData>, headers: Headers, conn: DbConn,
err!("Email already in use"); err!("Email already in use");
} }
match user.email_new { if let Some(ref val) = user.email_new {
Some(ref val) => {
if val != &data.new_email { if val != &data.new_email {
err!("Email change mismatch"); err!("Email change mismatch");
} }
} } else {
None => err!("No email change pending"), err!("No email change pending")
} }
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
// Only check the token if we sent out an email... // Only check the token if we sent out an email...
match user.email_new_token { if let Some(ref val) = user.email_new_token {
Some(ref val) => {
if *val != data.token.into_string() { if *val != data.token.into_string() {
err!("Token mismatch"); err!("Token mismatch");
} }
} } else {
None => err!("No email change pending"), err!("No email change pending")
} }
user.verified_at = Some(Utc::now().naive_utc()); user.verified_at = Some(Utc::now().naive_utc());
} else { } else {
@ -1114,11 +1111,11 @@ async fn post_delete_recover(data: Json<DeleteRecoverData>, conn: DbConn) -> Emp
let data: DeleteRecoverData = data.into_inner(); let data: DeleteRecoverData = data.into_inner();
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
if let Some(user) = User::find_by_mail(&data.email, &conn).await { if let Some(user) = User::find_by_mail(&data.email, &conn).await
if let Err(e) = mail::send_delete_account(&user.email, &user.uuid).await { && let Err(e) = mail::send_delete_account(&user.email, &user.uuid).await
{
error!("Error sending delete account email: {e:#?}"); error!("Error sending delete account email: {e:#?}");
} }
}
Ok(()) Ok(())
} else { } else {
// We don't support sending emails, but we shouldn't allow anybody // We don't support sending emails, but we shouldn't allow anybody
@ -1169,6 +1166,7 @@ async fn delete_account(data: Json<PasswordOrOtpData>, headers: Headers, conn: D
user.delete(&conn).await user.delete(&conn).await
} }
#[expect(clippy::needless_pass_by_value, reason = "Not beneficial for Headers")]
#[get("/accounts/revision-date")] #[get("/accounts/revision-date")]
fn revision_date(headers: Headers) -> JsonResult { fn revision_date(headers: Headers) -> JsonResult {
let revision_date = headers.user.updated_at.and_utc().timestamp_millis(); let revision_date = headers.user.updated_at.and_utc().timestamp_millis();
@ -1183,12 +1181,12 @@ struct PasswordHintData {
#[post("/accounts/password-hint", data = "<data>")] #[post("/accounts/password-hint", data = "<data>")]
async fn password_hint(data: Json<PasswordHintData>, conn: DbConn) -> EmptyResult { async fn password_hint(data: Json<PasswordHintData>, conn: DbConn) -> EmptyResult {
const NO_HINT: &str = "Sorry, you have no password hint...";
if !CONFIG.password_hints_allowed() || (!CONFIG.mail_enabled() && !CONFIG.show_password_hint()) { if !CONFIG.password_hints_allowed() || (!CONFIG.mail_enabled() && !CONFIG.show_password_hint()) {
err!("This server is not configured to provide password hints."); err!("This server is not configured to provide password hints.");
} }
const NO_HINT: &str = "Sorry, you have no password hint...";
let data: PasswordHintData = data.into_inner(); let data: PasswordHintData = data.into_inner();
let email = &data.email; let email = &data.email;
@ -1199,9 +1197,9 @@ async fn password_hint(data: Json<PasswordHintData>, conn: DbConn) -> EmptyResul
// There is still a timing side channel here in that the code // There is still a timing side channel here in that the code
// paths that send mail take noticeably longer than ones that // paths that send mail take noticeably longer than ones that
// don't. Add a randomized sleep to mitigate this somewhat. // don't. Add a randomized sleep to mitigate this somewhat.
use rand::{rngs::SmallRng, RngExt}; use rand::{RngExt, rngs::SmallRng};
let mut rng: SmallRng = rand::make_rng(); let mut rng: SmallRng = rand::make_rng();
let sleep_ms = rng.random_range(900..=1100) as u64; let sleep_ms: u64 = rng.random_range(900..=1100);
tokio::time::sleep(tokio::time::Duration::from_millis(sleep_ms)).await; tokio::time::sleep(tokio::time::Duration::from_millis(sleep_ms)).await;
Ok(()) Ok(())
} else { } else {
@ -1229,11 +1227,11 @@ pub struct PreloginData {
} }
#[post("/accounts/prelogin", data = "<data>")] #[post("/accounts/prelogin", data = "<data>")]
async fn prelogin(data: Json<PreloginData>, conn: DbConn) -> Json<Value> { async fn post_prelogin(data: Json<PreloginData>, conn: DbConn) -> Json<Value> {
_prelogin(data, conn).await prelogin(data, conn).await
} }
pub async fn _prelogin(data: Json<PreloginData>, conn: DbConn) -> Json<Value> { pub async fn prelogin(data: Json<PreloginData>, conn: DbConn) -> Json<Value> {
let data: PreloginData = data.into_inner(); let data: PreloginData = data.into_inner();
let (kdf_type, kdf_iter, kdf_mem, kdf_para) = match User::find_by_mail(&data.email, &conn).await { let (kdf_type, kdf_iter, kdf_mem, kdf_para) = match User::find_by_mail(&data.email, &conn).await {
@ -1283,9 +1281,7 @@ async fn verify_password(data: Json<SecretVerificationRequest>, headers: Headers
Ok(Json(master_password_policy(&user, &conn).await)) Ok(Json(master_password_policy(&user, &conn).await))
} }
async fn _api_key(data: Json<PasswordOrOtpData>, rotate: bool, headers: Headers, conn: DbConn) -> JsonResult { async fn update_api_key(data: Json<PasswordOrOtpData>, rotate: bool, headers: Headers, conn: DbConn) -> JsonResult {
use crate::util::format_date;
let data: PasswordOrOtpData = data.into_inner(); let data: PasswordOrOtpData = data.into_inner();
let mut user = headers.user; let mut user = headers.user;
@ -1304,13 +1300,13 @@ async fn _api_key(data: Json<PasswordOrOtpData>, rotate: bool, headers: Headers,
} }
#[post("/accounts/api-key", data = "<data>")] #[post("/accounts/api-key", data = "<data>")]
async fn api_key(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult { async fn post_api_key(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult {
_api_key(data, false, headers, conn).await update_api_key(data, false, headers, conn).await
} }
#[post("/accounts/rotate-api-key", data = "<data>")] #[post("/accounts/rotate-api-key", data = "<data>")]
async fn rotate_api_key(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult { async fn rotate_api_key(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult {
_api_key(data, true, headers, conn).await update_api_key(data, true, headers, conn).await
} }
#[get("/devices/knowndevice")] #[get("/devices/knowndevice")]
@ -1353,7 +1349,7 @@ impl<'r> FromRequest<'r> for KnownDevice {
}; };
let uuid = if let Some(uuid) = req.headers().get_one("X-Device-Identifier") { let uuid = if let Some(uuid) = req.headers().get_one("X-Device-Identifier") {
uuid.to_string().into() uuid.to_owned().into()
} else { } else {
return Outcome::Error((Status::BadRequest, "X-Device-Identifier value is required")); return Outcome::Error((Status::BadRequest, "X-Device-Identifier value is required"));
}; };
@ -1368,7 +1364,7 @@ impl<'r> FromRequest<'r> for KnownDevice {
#[get("/devices")] #[get("/devices")]
async fn get_all_devices(headers: Headers, conn: DbConn) -> JsonResult { async fn get_all_devices(headers: Headers, conn: DbConn) -> JsonResult {
let devices = Device::find_with_auth_request_by_user(&headers.user.uuid, &conn).await; let devices = Device::find_with_auth_request_by_user(&headers.user.uuid, &conn).await;
let devices = devices.iter().map(|device| device.to_json()).collect::<Vec<Value>>(); let devices = devices.iter().map(DeviceWithAuthRequest::to_json).collect::<Vec<Value>>();
Ok(Json(json!({ Ok(Json(json!({
"data": devices, "data": devices,
@ -1708,6 +1704,6 @@ pub async fn purge_auth_requests(pool: DbPool) {
if let Ok(conn) = pool.get().await { if let Ok(conn) = pool.get().await {
AuthRequest::purge_expired_auth_requests(&conn).await; AuthRequest::purge_expired_auth_requests(&conn).await;
} else { } else {
error!("Failed to get DB connection while purging auth requests") error!("Failed to get DB connection while purging auth requests");
} }
} }

170
src/api/core/ciphers.rs

@ -2,30 +2,30 @@ use std::collections::{HashMap, HashSet};
use chrono::{NaiveDateTime, Utc}; use chrono::{NaiveDateTime, Utc};
use num_traits::ToPrimitive; use num_traits::ToPrimitive;
use rocket::fs::TempFile;
use rocket::serde::json::Json;
use rocket::{ use rocket::{
form::{Form, FromForm},
Route, Route,
form::{Form, FromForm},
fs::TempFile,
serde::json::Json,
}; };
use serde_json::Value; use serde_json::Value;
use crate::auth::ClientVersion;
use crate::util::{deser_opt_nonempty_str, save_temp_file, NumberOrString};
use crate::{ use crate::{
api::{self, core::log_event, EmptyResult, JsonResult, Notify, PasswordOrOtpData, UpdateType}, CONFIG,
api::{self, EmptyResult, JsonResult, Notify, PasswordOrOtpData, UpdateType, core::log_event},
auth::ClientVersion,
auth::{Headers, OrgIdGuard, OwnerHeaders}, auth::{Headers, OrgIdGuard, OwnerHeaders},
config::PathType, config::PathType,
crypto, crypto,
db::{ db::{
DbConn, DbPool,
models::{ models::{
Archive, Attachment, AttachmentId, Cipher, CipherId, Collection, CollectionCipher, CollectionGroup, Archive, Attachment, AttachmentId, Cipher, CipherId, Collection, CollectionCipher, CollectionGroup,
CollectionId, CollectionUser, EventType, Favorite, Folder, FolderCipher, FolderId, Group, Membership, CollectionId, CollectionUser, EventType, Favorite, Folder, FolderCipher, FolderId, Group, Membership,
MembershipType, OrgPolicy, OrgPolicyType, OrganizationId, RepromptType, Send, UserId, MembershipType, OrgPolicy, OrgPolicyType, OrganizationId, RepromptType, Send, UserId,
}, },
DbConn, DbPool,
}, },
CONFIG, util::{NumberOrString, deser_opt_nonempty_str, save_temp_file},
}; };
use super::folders::FolderData; use super::folders::FolderData;
@ -108,7 +108,7 @@ pub async fn purge_trashed_ciphers(pool: DbPool) {
if let Ok(conn) = pool.get().await { if let Ok(conn) = pool.get().await {
Cipher::purge_trash(&conn).await; Cipher::purge_trash(&conn).await;
} else { } else {
error!("Failed to get DB connection while purging trashed ciphers") error!("Failed to get DB connection while purging trashed ciphers");
} }
} }
@ -164,7 +164,7 @@ async fn sync(data: SyncData, headers: Headers, client_version: Option<ClientVer
let domains_json = if data.exclude_domains { let domains_json = if data.exclude_domains {
Value::Null Value::Null
} else { } else {
api::core::_get_eq_domains(&headers, true).into_inner() api::core::get_eq_domains(&headers, true).into_inner()
}; };
// This is very similar to the the userDecryptionOptions sent in connect/token, // This is very similar to the the userDecryptionOptions sent in connect/token,
@ -401,12 +401,27 @@ pub async fn update_cipher_from_data(
nt: &Notify<'_>, nt: &Notify<'_>,
ut: UpdateType, ut: UpdateType,
) -> EmptyResult { ) -> EmptyResult {
// Cleanup cipher data, like removing the 'Response' key.
// This key is somewhere generated during Javascript so no way for us this fix this.
// Also, upstream only retrieves keys they actually want to store, and thus skip the 'Response' key.
// We do not mind which data is in it, the keep our model more flexible when there are upstream changes.
// But, we at least know we do not need to store and return this specific key.
fn clean_cipher_data(mut json_data: Value) -> Value {
if json_data.is_array() {
json_data.as_array_mut().unwrap().iter_mut().for_each(|ref mut f| {
f.as_object_mut().unwrap().remove("response");
});
}
json_data
}
enforce_personal_ownership_policy(Some(&data), headers, conn).await?; enforce_personal_ownership_policy(Some(&data), headers, conn).await?;
// Check that the client isn't updating an existing cipher with stale data. // Check that the client isn't updating an existing cipher with stale data.
// And only perform this check when not importing ciphers, else the date/time check will fail. // And only perform this check when not importing ciphers, else the date/time check will fail.
if ut != UpdateType::None { if ut != UpdateType::None
if let Some(dt) = data.last_known_revision_date { && let Some(dt) = data.last_known_revision_date
{
match NaiveDateTime::parse_from_str(&dt, "%+") { match NaiveDateTime::parse_from_str(&dt, "%+") {
// ISO 8601 format // ISO 8601 format
Err(err) => warn!("Error parsing LastKnownRevisionDate '{dt}': {err}"), Err(err) => warn!("Error parsing LastKnownRevisionDate '{dt}': {err}"),
@ -416,7 +431,6 @@ pub async fn update_cipher_from_data(
Ok(_) => (), Ok(_) => (),
} }
} }
}
if cipher.organization_uuid.is_some() && cipher.organization_uuid != data.organization_id { if cipher.organization_uuid.is_some() && cipher.organization_uuid != data.organization_id {
err!("Organization mismatch. Please resync the client before updating the cipher") err!("Organization mismatch. Please resync the client before updating the cipher")
@ -456,25 +470,22 @@ pub async fn update_cipher_from_data(
cipher.user_uuid = Some(headers.user.uuid.clone()); cipher.user_uuid = Some(headers.user.uuid.clone());
} }
if let Some(ref folder_id) = data.folder_id { if let Some(ref folder_id) = data.folder_id
if Folder::find_by_uuid_and_user(folder_id, &headers.user.uuid, conn).await.is_none() { && Folder::find_by_uuid_and_user(folder_id, &headers.user.uuid, conn).await.is_none()
{
err!("Invalid folder", "Folder does not exist or belongs to another user"); err!("Invalid folder", "Folder does not exist or belongs to another user");
} }
}
// Modify attachments name and keys when rotating // Modify attachments name and keys when rotating
if let Some(attachments) = data.attachments2 { if let Some(attachments) = data.attachments2 {
for (id, attachment) in attachments { for (id, attachment) in attachments {
let mut saved_att = match Attachment::find_by_id(&id, conn).await { let Some(mut saved_att) = Attachment::find_by_id(&id, conn).await else {
Some(att) => att,
None => {
// Warn and continue here. // Warn and continue here.
// A missing attachment means it was removed via an other client. // A missing attachment means it was removed via an other client.
// Also the Desktop Client supports removing attachments and save an update afterwards. // Also the Desktop Client supports removing attachments and save an update afterwards.
// Bitwarden it self ignores these mismatches server side. // Bitwarden it self ignores these mismatches server side.
warn!("Attachment {id} doesn't exist"); warn!("Attachment {id} doesn't exist");
continue; continue;
}
}; };
if saved_att.cipher_uuid != cipher.uuid { if saved_att.cipher_uuid != cipher.uuid {
@ -491,20 +502,6 @@ pub async fn update_cipher_from_data(
} }
} }
// Cleanup cipher data, like removing the 'Response' key.
// This key is somewhere generated during Javascript so no way for us this fix this.
// Also, upstream only retrieves keys they actually want to store, and thus skip the 'Response' key.
// We do not mind which data is in it, the keep our model more flexible when there are upstream changes.
// But, we at least know we do not need to store and return this specific key.
fn _clean_cipher_data(mut json_data: Value) -> Value {
if json_data.is_array() {
json_data.as_array_mut().unwrap().iter_mut().for_each(|ref mut f| {
f.as_object_mut().unwrap().remove("response");
});
};
json_data
}
let type_data_opt = match data.r#type { let type_data_opt = match data.r#type {
1 => data.login, 1 => data.login,
2 => data.secure_note, 2 => data.secure_note,
@ -514,23 +511,22 @@ pub async fn update_cipher_from_data(
_ => err!("Invalid type"), _ => err!("Invalid type"),
}; };
let type_data = match type_data_opt { let type_data = if let Some(mut data) = type_data_opt {
Some(mut data) => {
// Remove the 'Response' key from the base object. // Remove the 'Response' key from the base object.
data.as_object_mut().unwrap().remove("response"); data.as_object_mut().unwrap().remove("response");
// Remove the 'Response' key from every Uri. // Remove the 'Response' key from every Uri.
if data["uris"].is_array() { if data["uris"].is_array() {
data["uris"] = _clean_cipher_data(data["uris"].clone()); data["uris"] = clean_cipher_data(data["uris"].clone());
} }
data data
} } else {
None => err!("Data missing"), err!("Data missing")
}; };
cipher.key = data.key; cipher.key = data.key;
cipher.name = data.name; cipher.name = data.name;
cipher.notes = data.notes; cipher.notes = data.notes;
cipher.fields = data.fields.map(|f| _clean_cipher_data(f).to_string()); cipher.fields = data.fields.map(|f| clean_cipher_data(f).to_string());
cipher.data = type_data.to_string(); cipher.data = type_data.to_string();
cipher.password_history = data.password_history.map(|f| f.to_string()); cipher.password_history = data.password_history.map(|f| f.to_string());
cipher.reprompt = data.reprompt.filter(|r| *r == RepromptType::None as i32 || *r == RepromptType::Password as i32); cipher.reprompt = data.reprompt.filter(|r| *r == RepromptType::None as i32 || *r == RepromptType::Password as i32);
@ -612,7 +608,7 @@ async fn post_ciphers_import(data: Json<ImportData>, headers: Headers, conn: DbC
let existing_folders: HashSet<Option<FolderId>> = let existing_folders: HashSet<Option<FolderId>> =
Folder::find_by_user(&headers.user.uuid, &conn).await.into_iter().map(|f| Some(f.uuid)).collect(); Folder::find_by_user(&headers.user.uuid, &conn).await.into_iter().map(|f| Some(f.uuid)).collect();
let mut folders: Vec<FolderId> = Vec::with_capacity(data.folders.len()); let mut folders: Vec<FolderId> = Vec::with_capacity(data.folders.len());
for folder in data.folders.into_iter() { for folder in data.folders {
let folder_id = if existing_folders.contains(&folder.id) { let folder_id = if existing_folders.contains(&folder.id) {
folder.id.unwrap() folder.id.unwrap()
} else { } else {
@ -737,11 +733,11 @@ async fn put_cipher_partial(
err!("Cipher does not exist", "Cipher is not accessible for the current user") err!("Cipher does not exist", "Cipher is not accessible for the current user")
} }
if let Some(ref folder_id) = data.folder_id { if let Some(ref folder_id) = data.folder_id
if Folder::find_by_uuid_and_user(folder_id, &headers.user.uuid, &conn).await.is_none() { && Folder::find_by_uuid_and_user(folder_id, &headers.user.uuid, &conn).await.is_none()
{
err!("Invalid folder", "Folder does not exist or belongs to another user"); err!("Invalid folder", "Folder does not exist or belongs to another user");
} }
}
// Move cipher // Move cipher
cipher.move_to_folder(data.folder_id.clone(), &headers.user.uuid, &conn).await?; cipher.move_to_folder(data.folder_id.clone(), &headers.user.uuid, &conn).await?;
@ -1004,7 +1000,7 @@ async fn put_cipher_share_selected(
err!("You must select at least one collection.") err!("You must select at least one collection.")
} }
for cipher in data.ciphers.iter() { for cipher in &data.ciphers {
if cipher.id.is_none() { if cipher.id.is_none() {
err!("Request missing ids field") err!("Request missing ids field")
} }
@ -1016,11 +1012,10 @@ async fn put_cipher_share_selected(
collection_ids: data.collection_ids.clone(), collection_ids: data.collection_ids.clone(),
}; };
match shared_cipher_data.cipher.id.take() { if let Some(id) = shared_cipher_data.cipher.id.take() {
Some(id) => {
share_cipher_by_uuid(&id, shared_cipher_data, &headers, &conn, &nt, Some(UpdateType::None)).await? share_cipher_by_uuid(&id, shared_cipher_data, &headers, &conn, &nt, Some(UpdateType::None)).await?
} } else {
None => err!("Request missing ids field"), err!("Request missing ids field")
}; };
} }
@ -1038,15 +1033,14 @@ async fn share_cipher_by_uuid(
nt: &Notify<'_>, nt: &Notify<'_>,
override_ut: Option<UpdateType>, override_ut: Option<UpdateType>,
) -> JsonResult { ) -> JsonResult {
let mut cipher = match Cipher::find_by_uuid(cipher_id, conn).await { let mut cipher = if let Some(cipher) = Cipher::find_by_uuid(cipher_id, conn).await {
Some(cipher) => {
if cipher.is_write_accessible_to_user(&headers.user.uuid, conn).await { if cipher.is_write_accessible_to_user(&headers.user.uuid, conn).await {
cipher cipher
} else { } else {
err!("Cipher is not write accessible") err!("Cipher is not write accessible")
} }
} } else {
None => err!("Cipher doesn't exist"), err!("Cipher doesn't exist")
}; };
let mut shared_to_collections = vec![]; let mut shared_to_collections = vec![];
@ -1065,7 +1059,7 @@ async fn share_cipher_by_uuid(
} }
} }
} }
}; }
// When LastKnownRevisionDate is None, it is a new cipher, so send CipherCreate. // When LastKnownRevisionDate is None, it is a new cipher, so send CipherCreate.
// If there is an override, like when handling multiple items, we want to prevent a push notification for every single item // If there is an override, like when handling multiple items, we want to prevent a push notification for every single item
@ -1263,11 +1257,11 @@ async fn save_attachment(
err!("Cipher is neither owned by a user nor an organization"); err!("Cipher is neither owned by a user nor an organization");
}; };
if let Some(size_limit) = size_limit { if let Some(size_limit) = size_limit
if size > size_limit { && size > size_limit
{
err!("Attachment storage limit exceeded with this file"); err!("Attachment storage limit exceeded with this file");
} }
}
let file_id = match &attachment { let file_id = match &attachment {
Some(attachment) => attachment.id.clone(), // v2 API Some(attachment) => attachment.id.clone(), // v2 API
@ -1408,7 +1402,7 @@ async fn post_attachment_share(
conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> JsonResult { ) -> JsonResult {
_delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await?; delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await?;
post_attachment(cipher_id, data, headers, conn, nt).await post_attachment(cipher_id, data, headers, conn, nt).await
} }
@ -1442,7 +1436,7 @@ async fn delete_attachment(
conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> JsonResult { ) -> JsonResult {
_delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await
} }
#[delete("/ciphers/<cipher_id>/attachment/<attachment_id>/admin")] #[delete("/ciphers/<cipher_id>/attachment/<attachment_id>/admin")]
@ -1453,42 +1447,42 @@ async fn delete_attachment_admin(
conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> JsonResult { ) -> JsonResult {
_delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await
} }
#[post("/ciphers/<cipher_id>/delete")] #[post("/ciphers/<cipher_id>/delete")]
async fn delete_cipher_post(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult { async fn delete_cipher_post(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
// permanent delete // permanent delete
} }
#[post("/ciphers/<cipher_id>/delete-admin")] #[post("/ciphers/<cipher_id>/delete-admin")]
async fn delete_cipher_post_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult { async fn delete_cipher_post_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
// permanent delete // permanent delete
} }
#[put("/ciphers/<cipher_id>/delete")] #[put("/ciphers/<cipher_id>/delete")]
async fn delete_cipher_put(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult { async fn delete_cipher_put(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::SoftSingle, &nt).await delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::SoftSingle, &nt).await
// soft delete // soft delete
} }
#[put("/ciphers/<cipher_id>/delete-admin")] #[put("/ciphers/<cipher_id>/delete-admin")]
async fn delete_cipher_put_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult { async fn delete_cipher_put_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::SoftSingle, &nt).await delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::SoftSingle, &nt).await
// soft delete // soft delete
} }
#[delete("/ciphers/<cipher_id>")] #[delete("/ciphers/<cipher_id>")]
async fn delete_cipher(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult { async fn delete_cipher(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
// permanent delete // permanent delete
} }
#[delete("/ciphers/<cipher_id>/admin")] #[delete("/ciphers/<cipher_id>/admin")]
async fn delete_cipher_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult { async fn delete_cipher_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
// permanent delete // permanent delete
} }
@ -1499,7 +1493,7 @@ async fn delete_cipher_selected(
conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> EmptyResult { ) -> EmptyResult {
_delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await
// permanent delete // permanent delete
} }
@ -1510,7 +1504,7 @@ async fn delete_cipher_selected_post(
conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> EmptyResult { ) -> EmptyResult {
_delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await
// permanent delete // permanent delete
} }
@ -1521,7 +1515,7 @@ async fn delete_cipher_selected_put(
conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> EmptyResult { ) -> EmptyResult {
_delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::SoftMulti, nt).await delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::SoftMulti, nt).await
// soft delete // soft delete
} }
@ -1532,7 +1526,7 @@ async fn delete_cipher_selected_admin(
conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> EmptyResult { ) -> EmptyResult {
_delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await
// permanent delete // permanent delete
} }
@ -1543,7 +1537,7 @@ async fn delete_cipher_selected_post_admin(
conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> EmptyResult { ) -> EmptyResult {
_delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await
// permanent delete // permanent delete
} }
@ -1554,18 +1548,18 @@ async fn delete_cipher_selected_put_admin(
conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> EmptyResult { ) -> EmptyResult {
_delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::SoftMulti, nt).await delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::SoftMulti, nt).await
// soft delete // soft delete
} }
#[put("/ciphers/<cipher_id>/restore")] #[put("/ciphers/<cipher_id>/restore")]
async fn restore_cipher_put(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult { async fn restore_cipher_put(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
_restore_cipher_by_uuid(&cipher_id, &headers, false, &conn, &nt).await restore_cipher_by_uuid(&cipher_id, &headers, false, &conn, &nt).await
} }
#[put("/ciphers/<cipher_id>/restore-admin")] #[put("/ciphers/<cipher_id>/restore-admin")]
async fn restore_cipher_put_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult { async fn restore_cipher_put_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
_restore_cipher_by_uuid(&cipher_id, &headers, false, &conn, &nt).await restore_cipher_by_uuid(&cipher_id, &headers, false, &conn, &nt).await
} }
#[put("/ciphers/restore-admin", data = "<data>")] #[put("/ciphers/restore-admin", data = "<data>")]
@ -1575,7 +1569,7 @@ async fn restore_cipher_selected_admin(
conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> JsonResult { ) -> JsonResult {
_restore_multiple_ciphers(data, &headers, &conn, &nt).await restore_multiple_ciphers(data, &headers, &conn, &nt).await
} }
#[put("/ciphers/restore", data = "<data>")] #[put("/ciphers/restore", data = "<data>")]
@ -1585,7 +1579,7 @@ async fn restore_cipher_selected(
conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> JsonResult { ) -> JsonResult {
_restore_multiple_ciphers(data, &headers, &conn, &nt).await restore_multiple_ciphers(data, &headers, &conn, &nt).await
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -1606,11 +1600,11 @@ async fn move_cipher_selected(
let data = data.into_inner(); let data = data.into_inner();
let user_id = &headers.user.uuid; let user_id = &headers.user.uuid;
if let Some(ref folder_id) = data.folder_id { if let Some(ref folder_id) = data.folder_id
if Folder::find_by_uuid_and_user(folder_id, user_id, &conn).await.is_none() { && Folder::find_by_uuid_and_user(folder_id, user_id, &conn).await.is_none()
{
err!("Invalid folder", "Folder does not exist or belongs to another user"); err!("Invalid folder", "Folder does not exist or belongs to another user");
} }
}
let cipher_count = data.ids.len(); let cipher_count = data.ids.len();
let mut single_cipher: Option<Cipher> = None; let mut single_cipher: Option<Cipher> = None;
@ -1773,7 +1767,7 @@ pub enum CipherDeleteOptions {
HardMulti, HardMulti,
} }
async fn _delete_cipher_by_uuid( async fn delete_cipher_by_uuid(
cipher_id: &CipherId, cipher_id: &CipherId,
headers: &Headers, headers: &Headers,
conn: &DbConn, conn: &DbConn,
@ -1839,7 +1833,7 @@ struct CipherIdsData {
ids: Vec<CipherId>, ids: Vec<CipherId>,
} }
async fn _delete_multiple_ciphers( async fn delete_multiple_ciphers(
data: Json<CipherIdsData>, data: Json<CipherIdsData>,
headers: Headers, headers: Headers,
conn: DbConn, conn: DbConn,
@ -1849,9 +1843,9 @@ async fn _delete_multiple_ciphers(
let data = data.into_inner(); let data = data.into_inner();
for cipher_id in data.ids { for cipher_id in data.ids {
if let error @ Err(_) = _delete_cipher_by_uuid(&cipher_id, &headers, &conn, &delete_options, &nt).await { if let error @ Err(_) = delete_cipher_by_uuid(&cipher_id, &headers, &conn, &delete_options, &nt).await {
return error; return error;
}; }
} }
// Multi delete actions do not send out a push for each cipher, we need to send a general sync here // Multi delete actions do not send out a push for each cipher, we need to send a general sync here
@ -1860,7 +1854,7 @@ async fn _delete_multiple_ciphers(
Ok(()) Ok(())
} }
async fn _restore_cipher_by_uuid( async fn restore_cipher_by_uuid(
cipher_id: &CipherId, cipher_id: &CipherId,
headers: &Headers, headers: &Headers,
multi_restore: bool, multi_restore: bool,
@ -1906,7 +1900,7 @@ async fn _restore_cipher_by_uuid(
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, conn).await?)) Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, conn).await?))
} }
async fn _restore_multiple_ciphers( async fn restore_multiple_ciphers(
data: Json<CipherIdsData>, data: Json<CipherIdsData>,
headers: &Headers, headers: &Headers,
conn: &DbConn, conn: &DbConn,
@ -1916,7 +1910,7 @@ async fn _restore_multiple_ciphers(
let mut ciphers: Vec<Value> = Vec::new(); let mut ciphers: Vec<Value> = Vec::new();
for cipher_id in data.ids { for cipher_id in data.ids {
match _restore_cipher_by_uuid(&cipher_id, headers, true, conn, nt).await { match restore_cipher_by_uuid(&cipher_id, headers, true, conn, nt).await {
Ok(json) => ciphers.push(json.into_inner()), Ok(json) => ciphers.push(json.into_inner()),
err => return err, err => return err,
} }
@ -1932,7 +1926,7 @@ async fn _restore_multiple_ciphers(
}))) })))
} }
async fn _delete_cipher_attachment_by_id( async fn delete_cipher_attachment_by_id(
cipher_id: &CipherId, cipher_id: &CipherId,
attachment_id: &AttachmentId, attachment_id: &AttachmentId,
headers: &Headers, headers: &Headers,
@ -2206,11 +2200,11 @@ impl CipherSyncData {
}; };
Self { Self {
cipher_archives,
cipher_attachments, cipher_attachments,
cipher_folders, cipher_folders,
cipher_favorites, cipher_favorites,
cipher_collections, cipher_collections,
cipher_archives,
members, members,
user_collections, user_collections,
user_collections_groups, user_collections_groups,

48
src/api/core/emergency_access.rs

@ -1,23 +1,23 @@
use chrono::{TimeDelta, Utc}; use chrono::{TimeDelta, Utc};
use rocket::{serde::json::Json, Route}; use rocket::{Route, serde::json::Json};
use serde_json::Value; use serde_json::Value;
use crate::{ use crate::{
CONFIG,
api::{ api::{
core::{CipherSyncData, CipherSyncType},
EmptyResult, JsonResult, EmptyResult, JsonResult,
core::{CipherSyncData, CipherSyncType},
}, },
auth::{decode_emergency_access_invite, Headers}, auth::{Headers, decode_emergency_access_invite},
db::{ db::{
DbConn, DbPool,
models::{ models::{
Cipher, EmergencyAccess, EmergencyAccessId, EmergencyAccessStatus, EmergencyAccessType, Invitation, Cipher, EmergencyAccess, EmergencyAccessId, EmergencyAccessStatus, EmergencyAccessType, Invitation,
Membership, MembershipType, OrgPolicy, TwoFactor, User, UserId, Membership, MembershipType, OrgPolicy, TwoFactor, User, UserId,
}, },
DbConn, DbPool,
}, },
mail, mail,
util::NumberOrString, util::NumberOrString,
CONFIG,
}; };
pub fn routes() -> Vec<Route> { pub fn routes() -> Vec<Route> {
@ -55,7 +55,7 @@ async fn get_contacts(headers: Headers, conn: DbConn) -> Json<Value> {
let mut emergency_access_list_json = Vec::with_capacity(emergency_access_list.len()); let mut emergency_access_list_json = Vec::with_capacity(emergency_access_list.len());
for ea in emergency_access_list { for ea in emergency_access_list {
if let Some(grantee) = ea.to_json_grantee_details(&conn).await { if let Some(grantee) = ea.to_json_grantee_details(&conn).await {
emergency_access_list_json.push(grantee) emergency_access_list_json.push(grantee);
} }
} }
@ -89,11 +89,14 @@ async fn get_grantees(headers: Headers, conn: DbConn) -> Json<Value> {
async fn get_emergency_access(emer_id: EmergencyAccessId, headers: Headers, conn: DbConn) -> JsonResult { async fn get_emergency_access(emer_id: EmergencyAccessId, headers: Headers, conn: DbConn) -> JsonResult {
check_emergency_access_enabled()?; check_emergency_access_enabled()?;
match EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &conn).await { if let Some(emergency_access) =
Some(emergency_access) => Ok(Json( EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &conn).await
{
Ok(Json(
emergency_access.to_json_grantee_details(&conn).await.expect("Grantee user should exist but does not!"), emergency_access.to_json_grantee_details(&conn).await.expect("Grantee user should exist but does not!"),
)), ))
None => err!("Emergency access not valid."), } else {
err!("Emergency access not valid.")
} }
} }
@ -136,9 +139,10 @@ async fn post_emergency_access(
err!("Emergency access not valid.") err!("Emergency access not valid.")
}; };
let new_type = match EmergencyAccessType::from_str(&data.r#type.into_string()) { let new_type = if let Some(new_type) = EmergencyAccessType::from_str(&data.r#type.into_string()) {
Some(new_type) => new_type as i32, new_type as i32
None => err!("Invalid emergency access type."), } else {
err!("Invalid emergency access type.")
}; };
emergency_access.atype = new_type; emergency_access.atype = new_type;
@ -205,9 +209,10 @@ async fn send_invite(data: Json<EmergencyAccessInviteData>, headers: Headers, co
let emergency_access_status = EmergencyAccessStatus::Invited as i32; let emergency_access_status = EmergencyAccessStatus::Invited as i32;
let new_type = match EmergencyAccessType::from_str(&data.r#type.into_string()) { let new_type = if let Some(new_type) = EmergencyAccessType::from_str(&data.r#type.into_string()) {
Some(new_type) => new_type as i32, new_type as i32
None => err!("Invalid emergency access type."), } else {
err!("Invalid emergency access type.")
}; };
let grantor_user = headers.user; let grantor_user = headers.user;
@ -342,12 +347,11 @@ async fn accept_invite(
err!("Claim email does not match current users email") err!("Claim email does not match current users email")
} }
let grantee_user = match User::find_by_mail(&claims.email, &conn).await { let grantee_user = if let Some(user) = User::find_by_mail(&claims.email, &conn).await {
Some(user) => {
Invitation::take(&claims.email, &conn).await; Invitation::take(&claims.email, &conn).await;
user user
} } else {
None => err!("Invited user not found"), err!("Invited user not found")
}; };
// We need to search for the uuid in combination with the email, since we do not yet store the uuid of the grantee in the database. // We need to search for the uuid in combination with the email, since we do not yet store the uuid of the grantee in the database.
@ -766,7 +770,7 @@ pub async fn emergency_request_timeout_job(pool: DbPool) {
} }
} }
} else { } else {
error!("Failed to get DB connection while searching emergency request timed out") error!("Failed to get DB connection while searching emergency request timed out");
} }
} }
@ -825,6 +829,6 @@ pub async fn emergency_notification_reminder_job(pool: DbPool) {
} }
} }
} else { } else {
error!("Failed to get DB connection while searching emergency notification reminder") error!("Failed to get DB connection while searching emergency notification reminder");
} }
} }

82
src/api/core/events.rs

@ -1,18 +1,18 @@
use std::net::IpAddr; use std::net::IpAddr;
use chrono::NaiveDateTime; use chrono::NaiveDateTime;
use rocket::{form::FromForm, serde::json::Json, Route}; use rocket::{Route, form::FromForm, serde::json::Json};
use serde_json::Value; use serde_json::Value;
use crate::{ use crate::{
CONFIG,
api::{EmptyResult, JsonResult}, api::{EmptyResult, JsonResult},
auth::{AdminHeaders, Headers}, auth::{AdminHeaders, Headers},
db::{ db::{
models::{Cipher, CipherId, Event, Membership, MembershipId, OrganizationId, UserId},
DbConn, DbPool, DbConn, DbPool,
models::{Cipher, CipherId, Event, Membership, MembershipId, OrganizationId, UserId},
}, },
util::parse_date, util::parse_date,
CONFIG,
}; };
/// ############################################################################################################### /// ###############################################################################################################
@ -38,9 +38,7 @@ async fn get_org_events(org_id: OrganizationId, data: EventRange, headers: Admin
// Return an empty vec when we org events are disabled. // Return an empty vec when we org events are disabled.
// This prevents client errors // This prevents client errors
let events_json: Vec<Value> = if !CONFIG.org_events_enabled() { let events_json: Vec<Value> = if CONFIG.org_events_enabled() {
Vec::with_capacity(0)
} else {
let start_date = parse_date(&data.start); let start_date = parse_date(&data.start);
let end_date = if let Some(before_date) = &data.continuation_token { let end_date = if let Some(before_date) = &data.continuation_token {
parse_date(before_date) parse_date(before_date)
@ -51,8 +49,10 @@ async fn get_org_events(org_id: OrganizationId, data: EventRange, headers: Admin
Event::find_by_organization_uuid(&org_id, &start_date, &end_date, &conn) Event::find_by_organization_uuid(&org_id, &start_date, &end_date, &conn)
.await .await
.iter() .iter()
.map(|e| e.to_json()) .map(Event::to_json)
.collect() .collect()
} else {
Vec::with_capacity(0)
}; };
Ok(Json(json!({ Ok(Json(json!({
@ -64,13 +64,11 @@ async fn get_org_events(org_id: OrganizationId, data: EventRange, headers: Admin
#[get("/ciphers/<cipher_id>/events?<data..>")] #[get("/ciphers/<cipher_id>/events?<data..>")]
async fn get_cipher_events(cipher_id: CipherId, data: EventRange, headers: Headers, conn: DbConn) -> JsonResult { async fn get_cipher_events(cipher_id: CipherId, data: EventRange, headers: Headers, conn: DbConn) -> JsonResult {
// Return an empty vec when we org events are disabled. // Return an empty vec when org events are disabled.
// This prevents client errors // This prevents client errors
let events_json: Vec<Value> = if !CONFIG.org_events_enabled() { let events_json: Vec<Value> = if CONFIG.org_events_enabled()
Vec::with_capacity(0) && Membership::user_has_ge_admin_access_to_cipher(&headers.user.uuid, &cipher_id, &conn).await
} else { {
let mut events_json = Vec::with_capacity(0);
if Membership::user_has_ge_admin_access_to_cipher(&headers.user.uuid, &cipher_id, &conn).await {
let start_date = parse_date(&data.start); let start_date = parse_date(&data.start);
let end_date = if let Some(before_date) = &data.continuation_token { let end_date = if let Some(before_date) = &data.continuation_token {
parse_date(before_date) parse_date(before_date)
@ -78,13 +76,9 @@ async fn get_cipher_events(cipher_id: CipherId, data: EventRange, headers: Heade
parse_date(&data.end) parse_date(&data.end)
}; };
events_json = Event::find_by_cipher_uuid(&cipher_id, &start_date, &end_date, &conn) Event::find_by_cipher_uuid(&cipher_id, &start_date, &end_date, &conn).await.iter().map(Event::to_json).collect()
.await } else {
.iter() Vec::with_capacity(0)
.map(|e| e.to_json())
.collect()
}
events_json
}; };
Ok(Json(json!({ Ok(Json(json!({
@ -107,9 +101,7 @@ async fn get_user_events(
} }
// Return an empty vec when we org events are disabled. // Return an empty vec when we org events are disabled.
// This prevents client errors // This prevents client errors
let events_json: Vec<Value> = if !CONFIG.org_events_enabled() { let events_json: Vec<Value> = if CONFIG.org_events_enabled() {
Vec::with_capacity(0)
} else {
let start_date = parse_date(&data.start); let start_date = parse_date(&data.start);
let end_date = if let Some(before_date) = &data.continuation_token { let end_date = if let Some(before_date) = &data.continuation_token {
parse_date(before_date) parse_date(before_date)
@ -120,8 +112,10 @@ async fn get_user_events(
Event::find_by_org_and_member(&org_id, &member_id, &start_date, &end_date, &conn) Event::find_by_org_and_member(&org_id, &member_id, &start_date, &end_date, &conn)
.await .await
.iter() .iter()
.map(|e| e.to_json()) .map(Event::to_json)
.collect() .collect()
} else {
Vec::with_capacity(0)
}; };
Ok(Json(json!({ Ok(Json(json!({
@ -134,7 +128,8 @@ async fn get_user_events(
fn get_continuation_token(events_json: &[Value]) -> Option<&str> { fn get_continuation_token(events_json: &[Value]) -> Option<&str> {
// When the length of the vec equals the max page_size there probably is more data // When the length of the vec equals the max page_size there probably is more data
// When it is less, then all events are loaded. // When it is less, then all events are loaded.
if events_json.len() as i64 == Event::PAGE_SIZE { #[expect(clippy::cast_possible_truncation, reason = "PAGE_SIZE fits within usize")]
if events_json.len() == Event::PAGE_SIZE as usize {
if let Some(last_event) = events_json.last() { if let Some(last_event) = events_json.last() {
last_event["date"].as_str() last_event["date"].as_str()
} else { } else {
@ -176,7 +171,7 @@ async fn post_events_collect(data: Json<Vec<EventCollection>>, headers: Headers,
let event_date = parse_date(&event.date); let event_date = parse_date(&event.date);
match event.r#type { match event.r#type {
1000..=1099 => { 1000..=1099 => {
_log_user_event( log_user_event_impl(
event.r#type, event.r#type,
&headers.user.uuid, &headers.user.uuid,
headers.device.atype, headers.device.atype,
@ -188,7 +183,7 @@ async fn post_events_collect(data: Json<Vec<EventCollection>>, headers: Headers,
} }
1600..=1699 => { 1600..=1699 => {
if let Some(org_id) = &event.organization_id { if let Some(org_id) = &event.organization_id {
_log_event( log_event_impl(
event.r#type, event.r#type,
org_id, org_id,
org_id, org_id,
@ -202,10 +197,11 @@ async fn post_events_collect(data: Json<Vec<EventCollection>>, headers: Headers,
} }
} }
_ => { _ => {
if let Some(cipher_uuid) = &event.cipher_id { if let Some(cipher_uuid) = &event.cipher_id
if let Some(cipher) = Cipher::find_by_uuid(cipher_uuid, &conn).await { && let Some(cipher) = Cipher::find_by_uuid(cipher_uuid, &conn).await
if let Some(org_id) = cipher.organization_uuid { && let Some(org_id) = cipher.organization_uuid
_log_event( {
log_event_impl(
event.r#type, event.r#type,
cipher_uuid, cipher_uuid,
&org_id, &org_id,
@ -220,8 +216,6 @@ async fn post_events_collect(data: Json<Vec<EventCollection>>, headers: Headers,
} }
} }
} }
}
}
Ok(()) Ok(())
} }
@ -229,10 +223,10 @@ pub async fn log_user_event(event_type: i32, user_id: &UserId, device_type: i32,
if !CONFIG.org_events_enabled() { if !CONFIG.org_events_enabled() {
return; return;
} }
_log_user_event(event_type, user_id, device_type, None, ip, conn).await; log_user_event_impl(event_type, user_id, device_type, None, ip, conn).await;
} }
async fn _log_user_event( async fn log_user_event_impl(
event_type: i32, event_type: i32,
user_id: &UserId, user_id: &UserId,
device_type: i32, device_type: i32,
@ -278,11 +272,11 @@ pub async fn log_event(
if !CONFIG.org_events_enabled() { if !CONFIG.org_events_enabled() {
return; return;
} }
_log_event(event_type, source_uuid, org_id, act_user_id, device_type, None, ip, conn).await; log_event_impl(event_type, source_uuid, org_id, act_user_id, device_type, None, ip, conn).await;
} }
#[allow(clippy::too_many_arguments)] #[expect(clippy::too_many_arguments)]
async fn _log_event( async fn log_event_impl(
event_type: i32, event_type: i32,
source_uuid: &str, source_uuid: &str,
org_id: &OrganizationId, org_id: &OrganizationId,
@ -298,24 +292,24 @@ async fn _log_event(
// 1000..=1099 Are user events, they need to be logged via log_user_event() // 1000..=1099 Are user events, they need to be logged via log_user_event()
// Cipher Events // Cipher Events
1100..=1199 => { 1100..=1199 => {
event.cipher_uuid = Some(source_uuid.to_string().into()); event.cipher_uuid = Some(source_uuid.to_owned().into());
} }
// Collection Events // Collection Events
1300..=1399 => { 1300..=1399 => {
event.collection_uuid = Some(source_uuid.to_string().into()); event.collection_uuid = Some(source_uuid.to_owned().into());
} }
// Group Events // Group Events
1400..=1499 => { 1400..=1499 => {
event.group_uuid = Some(source_uuid.to_string().into()); event.group_uuid = Some(source_uuid.to_owned().into());
} }
// Org User Events // Org User Events
1500..=1599 => { 1500..=1599 => {
event.org_user_uuid = Some(source_uuid.to_string().into()); event.org_user_uuid = Some(source_uuid.to_owned().into());
} }
// 1600..=1699 Are organizational events, and they do not need the source_uuid // 1600..=1699 Are organizational events, and they do not need the source_uuid
// Policy Events // Policy Events
1700..=1799 => { 1700..=1799 => {
event.policy_uuid = Some(source_uuid.to_string().into()); event.policy_uuid = Some(source_uuid.to_owned().into());
} }
// Ignore others // Ignore others
_ => {} _ => {}
@ -338,6 +332,6 @@ pub async fn event_cleanup_job(pool: DbPool) {
if let Ok(conn) = pool.get().await { if let Ok(conn) = pool.get().await {
Event::clean_events(&conn).await.ok(); Event::clean_events(&conn).await.ok();
} else { } else {
error!("Failed to get DB connection while trying to cleanup the events table") error!("Failed to get DB connection while trying to cleanup the events table");
} }
} }

9
src/api/core/folders.rs

@ -5,8 +5,8 @@ use crate::{
api::{EmptyResult, JsonResult, Notify, UpdateType}, api::{EmptyResult, JsonResult, Notify, UpdateType},
auth::Headers, auth::Headers,
db::{ db::{
models::{Folder, FolderId},
DbConn, DbConn,
models::{Folder, FolderId},
}, },
util::deser_opt_nonempty_str, util::deser_opt_nonempty_str,
}; };
@ -29,9 +29,10 @@ async fn get_folders(headers: Headers, conn: DbConn) -> Json<Value> {
#[get("/folders/<folder_id>")] #[get("/folders/<folder_id>")]
async fn get_folder(folder_id: FolderId, headers: Headers, conn: DbConn) -> JsonResult { async fn get_folder(folder_id: FolderId, headers: Headers, conn: DbConn) -> JsonResult {
match Folder::find_by_uuid_and_user(&folder_id, &headers.user.uuid, &conn).await { if let Some(folder) = Folder::find_by_uuid_and_user(&folder_id, &headers.user.uuid, &conn).await {
Some(folder) => Ok(Json(folder.to_json())), Ok(Json(folder.to_json()))
_ => err!("Invalid folder", "Folder does not exist or belongs to another user"), } else {
err!("Invalid folder", "Folder does not exist or belongs to another user")
} }
} }

86
src/api/core/mod.rs

@ -1,4 +1,6 @@
pub mod accounts; pub mod accounts;
pub mod two_factor;
mod ciphers; mod ciphers;
mod emergency_access; mod emergency_access;
mod events; mod events;
@ -6,17 +8,32 @@ mod folders;
mod organizations; mod organizations;
mod public; mod public;
mod sends; mod sends;
pub mod two_factor;
pub use accounts::purge_auth_requests; pub use accounts::purge_auth_requests;
pub use ciphers::{purge_trashed_ciphers, CipherData, CipherSyncData, CipherSyncType}; pub use ciphers::{CipherData, CipherSyncData, CipherSyncType, purge_trashed_ciphers};
pub use emergency_access::{emergency_notification_reminder_job, emergency_request_timeout_job}; pub use emergency_access::{emergency_notification_reminder_job, emergency_request_timeout_job};
pub use events::{event_cleanup_job, log_event, log_user_event}; pub use events::{event_cleanup_job, log_event, log_user_event};
use reqwest::Method;
pub use sends::purge_sends; pub use sends::purge_sends;
use reqwest::Method;
use rocket::{Catcher, Route, serde::json::Json, serde::json::Value};
use crate::{
CONFIG,
api::{EmptyResult, JsonResult, Notify, UpdateType},
auth::Headers,
db::{
DbConn,
models::{Membership, MembershipStatus, OrgPolicy, Organization, User},
},
error::Error,
http_client::make_http_request,
mail,
util::{FeatureFlagFilter, parse_experimental_client_feature_flags},
};
pub fn routes() -> Vec<Route> { pub fn routes() -> Vec<Route> {
let mut eq_domains_routes = routes![get_eq_domains, post_eq_domains, put_eq_domains]; let mut eq_domains_routes = routes![get_settings_domains, post_settings_domains, put_settings_domains];
let mut hibp_routes = routes![hibp_breach]; let mut hibp_routes = routes![hibp_breach];
let mut meta_routes = routes![alive, now, version, config, get_api_webauthn]; let mut meta_routes = routes![alive, now, version, config, get_api_webauthn];
@ -44,25 +61,6 @@ pub fn events_routes() -> Vec<Route> {
routes routes
} }
//
// Move this somewhere else
//
use rocket::{serde::json::Json, serde::json::Value, Catcher, Route};
use crate::{
api::{EmptyResult, JsonResult, Notify, UpdateType},
auth::Headers,
db::{
models::{Membership, MembershipStatus, OrgPolicy, Organization, User},
DbConn,
},
error::Error,
http_client::make_http_request,
mail,
util::{parse_experimental_client_feature_flags, FeatureFlagFilter},
CONFIG,
};
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
struct GlobalDomain { struct GlobalDomain {
@ -73,15 +71,17 @@ struct GlobalDomain {
const GLOBAL_DOMAINS: &str = include_str!("../../static/global_domains.json"); const GLOBAL_DOMAINS: &str = include_str!("../../static/global_domains.json");
#[expect(clippy::needless_pass_by_value, reason = "Not beneficial for Headers")]
#[get("/settings/domains")] #[get("/settings/domains")]
fn get_eq_domains(headers: Headers) -> Json<Value> { fn get_settings_domains(headers: Headers) -> Json<Value> {
_get_eq_domains(&headers, false) get_eq_domains(&headers, false)
} }
fn _get_eq_domains(headers: &Headers, no_excluded: bool) -> Json<Value> { fn get_eq_domains(headers: &Headers, no_excluded: bool) -> Json<Value> {
let user = &headers.user;
use serde_json::from_str; use serde_json::from_str;
let user = &headers.user;
let equivalent_domains: Vec<Vec<String>> = from_str(&user.equivalent_domains).unwrap(); let equivalent_domains: Vec<Vec<String>> = from_str(&user.equivalent_domains).unwrap();
let excluded_globals: Vec<i32> = from_str(&user.excluded_globals).unwrap(); let excluded_globals: Vec<i32> = from_str(&user.excluded_globals).unwrap();
@ -110,17 +110,23 @@ struct EquivDomainData {
} }
#[post("/settings/domains", data = "<data>")] #[post("/settings/domains", data = "<data>")]
async fn post_eq_domains(data: Json<EquivDomainData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult { async fn post_settings_domains(
data: Json<EquivDomainData>,
headers: Headers,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
use serde_json::to_string;
let data: EquivDomainData = data.into_inner(); let data: EquivDomainData = data.into_inner();
let excluded_globals = data.excluded_global_equivalent_domains.unwrap_or_default(); let excluded_globals = data.excluded_global_equivalent_domains.unwrap_or_default();
let equivalent_domains = data.equivalent_domains.unwrap_or_default(); let equivalent_domains = data.equivalent_domains.unwrap_or_default();
let mut user = headers.user; let mut user = headers.user;
use serde_json::to_string;
user.excluded_globals = to_string(&excluded_globals).unwrap_or_else(|_| "[]".to_string()); user.excluded_globals = to_string(&excluded_globals).unwrap_or_else(|_| "[]".to_owned());
user.equivalent_domains = to_string(&equivalent_domains).unwrap_or_else(|_| "[]".to_string()); user.equivalent_domains = to_string(&equivalent_domains).unwrap_or_else(|_| "[]".to_owned());
user.save(&conn).await?; user.save(&conn).await?;
@ -130,8 +136,13 @@ async fn post_eq_domains(data: Json<EquivDomainData>, headers: Headers, conn: Db
} }
#[put("/settings/domains", data = "<data>")] #[put("/settings/domains", data = "<data>")]
async fn put_eq_domains(data: Json<EquivDomainData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult { async fn put_settings_domains(
post_eq_domains(data, headers, conn, nt).await data: Json<EquivDomainData>,
headers: Headers,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
post_settings_domains(data, headers, conn, nt).await
} }
#[get("/hibp/breach?<username>")] #[get("/hibp/breach?<username>")]
@ -206,9 +217,9 @@ fn config() -> Json<Value> {
// iOS (v2026.2.1): https://github.com/bitwarden/ios/blob/cdd9ba1770ca2ffc098d02d12cc3208e3a830454/BitwardenShared/Core/Platform/Models/Enum/FeatureFlag.swift#L7 // iOS (v2026.2.1): https://github.com/bitwarden/ios/blob/cdd9ba1770ca2ffc098d02d12cc3208e3a830454/BitwardenShared/Core/Platform/Models/Enum/FeatureFlag.swift#L7
let mut feature_states = parse_experimental_client_feature_flags( let mut feature_states = parse_experimental_client_feature_flags(
&CONFIG.experimental_client_feature_flags(), &CONFIG.experimental_client_feature_flags(),
FeatureFlagFilter::ValidOnly, &FeatureFlagFilter::ValidOnly,
); );
feature_states.insert("pm-19148-innovation-archive".to_string(), true); feature_states.insert("pm-19148-innovation-archive".to_owned(), true);
Json(json!({ Json(json!({
// Note: The clients use this version to handle backwards compatibility concerns // Note: The clients use this version to handle backwards compatibility concerns
@ -278,9 +289,8 @@ async fn accept_org_invite(
member.save(conn).await?; member.save(conn).await?;
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
let org = match Organization::find_by_uuid(&member.org_uuid, conn).await { let Some(org) = Organization::find_by_uuid(&member.org_uuid, conn).await else {
Some(org) => org, err!("Organization not found.")
None => err!("Organization not found."),
}; };
// User was invited to an organization, so they must be confirmed manually after acceptance // User was invited to an organization, so they must be confirmed manually after acceptance
mail::send_invite_accepted(&user.email, &member.invited_by_email.unwrap_or(org.billing_email), &org.name) mail::send_invite_accepted(&user.email, &member.invited_by_email.unwrap_or(org.billing_email), &org.name)

214
src/api/core/organizations.rs

@ -1,28 +1,28 @@
use std::collections::{HashMap, HashSet};
use num_traits::FromPrimitive; use num_traits::FromPrimitive;
use rocket::serde::json::Json; use rocket::{Route, serde::json::Json};
use rocket::Route;
use serde_json::Value; use serde_json::Value;
use std::collections::{HashMap, HashSet};
use crate::api::admin::FAKE_ADMIN_UUID;
use crate::{ use crate::{
CONFIG,
api::admin::FAKE_ADMIN_UUID,
api::{ api::{
core::{accept_org_invite, log_event, two_factor, CipherSyncData, CipherSyncType},
EmptyResult, JsonResult, Notify, PasswordOrOtpData, UpdateType, EmptyResult, JsonResult, Notify, PasswordOrOtpData, UpdateType,
core::{CipherSyncData, CipherSyncType, accept_org_invite, log_event, two_factor},
}, },
auth::{decode_invite, AdminHeaders, Headers, ManagerHeaders, ManagerHeadersLoose, OrgMemberHeaders, OwnerHeaders}, auth::{AdminHeaders, Headers, ManagerHeaders, ManagerHeadersLoose, OrgMemberHeaders, OwnerHeaders, decode_invite},
db::{ db::{
DbConn,
models::{ models::{
Cipher, CipherId, Collection, CollectionCipher, CollectionGroup, CollectionId, CollectionUser, EventType, Cipher, CipherId, Collection, CollectionCipher, CollectionGroup, CollectionId, CollectionUser, EventType,
Group, GroupId, GroupUser, Invitation, Membership, MembershipId, MembershipStatus, MembershipType, Group, GroupId, GroupUser, Invitation, Membership, MembershipId, MembershipStatus, MembershipType,
OrgPolicy, OrgPolicyType, Organization, OrganizationApiKey, OrganizationId, User, UserId, OrgPolicy, OrgPolicyType, Organization, OrganizationApiKey, OrganizationId, User, UserId,
}, },
DbConn,
}, },
mail, mail,
sso::FAKE_SSO_IDENTIFIER, sso::FAKE_SSO_IDENTIFIER,
util::{convert_json_key_lcase_first, NumberOrString}, util::{NumberOrString, convert_json_key_lcase_first},
CONFIG,
}; };
pub fn routes() -> Vec<Route> { pub fn routes() -> Vec<Route> {
@ -97,7 +97,7 @@ pub fn routes() -> Vec<Route> {
get_reset_password_details, get_reset_password_details,
put_reset_password, put_reset_password,
get_org_export, get_org_export,
api_key, post_api_key,
rotate_api_key, rotate_api_key,
get_billing_metadata, get_billing_metadata,
get_billing_warnings, get_billing_warnings,
@ -286,9 +286,10 @@ async fn get_organization(org_id: OrganizationId, headers: OwnerHeaders, conn: D
if org_id != headers.org_id { if org_id != headers.org_id {
err!("Organization not found", "Organization id's do not match"); err!("Organization not found", "Organization id's do not match");
} }
match Organization::find_by_uuid(&org_id, &conn).await { if let Some(organization) = Organization::find_by_uuid(&org_id, &conn).await {
Some(organization) => Ok(Json(organization.to_json())), Ok(Json(organization.to_json()))
None => err!("Can't find organization details"), } else {
err!("Can't find organization details")
} }
} }
@ -367,7 +368,7 @@ async fn get_auto_enroll_status(identifier: &str, headers: Headers, conn: DbConn
}; };
let (id, identifier, rp_auto_enroll) = match org { let (id, identifier, rp_auto_enroll) = match org {
None => (identifier.to_string(), identifier.to_string(), false), None => (identifier.to_owned(), identifier.to_owned(), false),
Some(org) => ( Some(org) => (
org.uuid.to_string(), org.uuid.to_string(),
org.uuid.to_string(), org.uuid.to_string(),
@ -393,7 +394,7 @@ async fn get_org_collections(org_id: OrganizationId, headers: ManagerHeadersLoos
} }
Ok(Json(json!({ Ok(Json(json!({
"data": _get_org_collections(&org_id, &conn).await, "data": get_org_collections_impl(&org_id, &conn).await,
"object": "list", "object": "list",
"continuationToken": null, "continuationToken": null,
}))) })))
@ -465,7 +466,7 @@ async fn get_org_collections_details(org_id: OrganizationId, headers: ManagerHea
CollectionGroup::find_by_collection(&col.uuid, &conn) CollectionGroup::find_by_collection(&col.uuid, &conn)
.await .await
.iter() .iter()
.map(|collection_group| collection_group.to_json_details_for_group()) .map(CollectionGroup::to_json_details_for_group)
.collect() .collect()
} else { } else {
Vec::with_capacity(0) Vec::with_capacity(0)
@ -477,7 +478,7 @@ async fn get_org_collections_details(org_id: OrganizationId, headers: ManagerHea
json_object["groups"] = json!(groups); json_object["groups"] = json!(groups);
json_object["object"] = json!("collectionAccessDetails"); json_object["object"] = json!("collectionAccessDetails");
json_object["unmanaged"] = json!(false); json_object["unmanaged"] = json!(false);
data.push(json_object) data.push(json_object);
} }
Ok(Json(json!({ Ok(Json(json!({
@ -487,7 +488,7 @@ async fn get_org_collections_details(org_id: OrganizationId, headers: ManagerHea
}))) })))
} }
async fn _get_org_collections(org_id: &OrganizationId, conn: &DbConn) -> Value { async fn get_org_collections_impl(org_id: &OrganizationId, conn: &DbConn) -> Value {
Collection::find_by_organization(org_id, conn).await.iter().map(Collection::to_json).collect::<Value>() Collection::find_by_organization(org_id, conn).await.iter().map(Collection::to_json).collect::<Value>()
} }
@ -573,7 +574,7 @@ async fn post_bulk_access_collections(
if Organization::find_by_uuid(&org_id, &conn).await.is_none() { if Organization::find_by_uuid(&org_id, &conn).await.is_none() {
err!("Can't find organization details") err!("Can't find organization details")
}; }
for col_id in data.collection_ids { for col_id in data.collection_ids {
let Some(collection) = Collection::find_by_uuid_and_org(&col_id, &org_id, &conn).await else { let Some(collection) = Collection::find_by_uuid_and_org(&col_id, &org_id, &conn).await else {
@ -650,7 +651,7 @@ async fn post_organization_collection_update(
if Organization::find_by_uuid(&org_id, &conn).await.is_none() { if Organization::find_by_uuid(&org_id, &conn).await.is_none() {
err!("Can't find organization details") err!("Can't find organization details")
}; }
let Some(mut collection) = Collection::find_by_uuid_and_org(&col_id, &org_id, &conn).await else { let Some(mut collection) = Collection::find_by_uuid_and_org(&col_id, &org_id, &conn).await else {
err!("Collection not found") err!("Collection not found")
@ -701,7 +702,7 @@ async fn post_organization_collection_update(
Ok(Json(collection.to_json_details(&headers.user.uuid, None, &conn).await)) Ok(Json(collection.to_json_details(&headers.user.uuid, None, &conn).await))
} }
async fn _delete_organization_collection( async fn delete_organization_collection_impl(
org_id: &OrganizationId, org_id: &OrganizationId,
col_id: &CollectionId, col_id: &CollectionId,
headers: &ManagerHeaders, headers: &ManagerHeaders,
@ -733,7 +734,7 @@ async fn delete_organization_collection(
headers: ManagerHeaders, headers: ManagerHeaders,
conn: DbConn, conn: DbConn,
) -> EmptyResult { ) -> EmptyResult {
_delete_organization_collection(&org_id, &col_id, &headers, &conn).await delete_organization_collection_impl(&org_id, &col_id, &headers, &conn).await
} }
#[post("/organizations/<org_id>/collections/<col_id>/delete")] #[post("/organizations/<org_id>/collections/<col_id>/delete")]
@ -743,7 +744,7 @@ async fn post_organization_collection_delete(
headers: ManagerHeaders, headers: ManagerHeaders,
conn: DbConn, conn: DbConn,
) -> EmptyResult { ) -> EmptyResult {
_delete_organization_collection(&org_id, &col_id, &headers, &conn).await delete_organization_collection_impl(&org_id, &col_id, &headers, &conn).await
} }
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
@ -769,7 +770,7 @@ async fn bulk_delete_organization_collections(
let headers = ManagerHeaders::from_loose(headers, &collections, &conn).await?; let headers = ManagerHeaders::from_loose(headers, &collections, &conn).await?;
for col_id in collections { for col_id in collections {
_delete_organization_collection(&org_id, &col_id, &headers, &conn).await? delete_organization_collection_impl(&org_id, &col_id, &headers, &conn).await?;
} }
Ok(()) Ok(())
} }
@ -799,7 +800,7 @@ async fn get_org_collection_detail(
CollectionGroup::find_by_collection(&collection.uuid, &conn) CollectionGroup::find_by_collection(&collection.uuid, &conn)
.await .await
.iter() .iter()
.map(|collection_group| collection_group.to_json_details_for_group()) .map(CollectionGroup::to_json_details_for_group)
.collect() .collect()
} else { } else {
// The Bitwarden clients seem to call this API regardless of whether groups are enabled, // The Bitwarden clients seem to call this API regardless of whether groups are enabled,
@ -886,13 +887,13 @@ async fn get_org_details(data: OrgIdData, headers: ManagerHeadersLoose, conn: Db
} }
Ok(Json(json!({ Ok(Json(json!({
"data": _get_org_details(&data.organization_id, &headers.host, &headers.user.uuid, &conn).await?, "data": get_org_details_impl(&data.organization_id, &headers.host, &headers.user.uuid, &conn).await?,
"object": "list", "object": "list",
"continuationToken": null, "continuationToken": null,
}))) })))
} }
async fn _get_org_details( async fn get_org_details_impl(
org_id: &OrganizationId, org_id: &OrganizationId,
host: &str, host: &str,
user_id: &UserId, user_id: &UserId,
@ -975,14 +976,13 @@ async fn post_org_keys(
} }
let data: OrgKeyData = data.into_inner(); let data: OrgKeyData = data.into_inner();
let mut org = match Organization::find_by_uuid(&org_id, &conn).await { let mut org = if let Some(organization) = Organization::find_by_uuid(&org_id, &conn).await {
Some(organization) => {
if organization.private_key.is_some() && organization.public_key.is_some() { if organization.private_key.is_some() && organization.public_key.is_some() {
err!("Organization Keys already exist") err!("Organization Keys already exist")
} }
organization organization
} } else {
None => err!("Can't find organization details"), err!("Can't find organization details")
}; };
org.private_key = Some(data.encrypted_private_key); org.private_key = Some(data.encrypted_private_key);
@ -1043,9 +1043,10 @@ async fn send_invite(
// The from_str() will convert the custom role type into a manager role type // The from_str() will convert the custom role type into a manager role type
let raw_type = &data.r#type.into_string(); let raw_type = &data.r#type.into_string();
// Membership::from_str will convert custom (4) to manager (3) // Membership::from_str will convert custom (4) to manager (3)
let new_type = match MembershipType::from_str(raw_type) { let new_type = if let Some(new_type) = MembershipType::from_str(raw_type) {
Some(new_type) => new_type as i32, new_type as i32
None => err!("Invalid type"), } else {
err!("Invalid type")
}; };
if new_type != MembershipType::User && headers.membership_type != MembershipType::Owner { if new_type != MembershipType::User && headers.membership_type != MembershipType::Owner {
@ -1062,7 +1063,7 @@ async fn send_invite(
&& data.permissions.get("createNewCollections") == Some(&json!(true))); && data.permissions.get("createNewCollections") == Some(&json!(true)));
let mut user_created: bool = false; let mut user_created: bool = false;
for email in data.emails.iter() { for email in &data.emails {
let mut member_status = MembershipStatus::Invited as i32; let mut member_status = MembershipStatus::Invited as i32;
let user = match User::find_by_mail(email, &conn).await { let user = match User::find_by_mail(email, &conn).await {
None => { None => {
@ -1086,14 +1087,14 @@ async fn send_invite(
Some(user) => { Some(user) => {
if Membership::find_by_user_and_org(&user.uuid, &org_id, &conn).await.is_some() { if Membership::find_by_user_and_org(&user.uuid, &org_id, &conn).await.is_some() {
err!(format!("User already in organization: {email}")) err!(format!("User already in organization: {email}"))
} else { }
// automatically accept existing users if mail is disabled // automatically accept existing users if mail is disabled
if !CONFIG.mail_enabled() && !user.password_hash.is_empty() { if !CONFIG.mail_enabled() && !user.password_hash.is_empty() {
member_status = MembershipStatus::Accepted as i32; member_status = MembershipStatus::Accepted as i32;
} }
user user
} }
}
}; };
let mut new_member = Membership::new(user.uuid.clone(), org_id.clone(), Some(headers.user.email.clone())); let mut new_member = Membership::new(user.uuid.clone(), org_id.clone(), Some(headers.user.email.clone()));
@ -1103,9 +1104,10 @@ async fn send_invite(
new_member.save(&conn).await?; new_member.save(&conn).await?;
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
let org_name = match Organization::find_by_uuid(&org_id, &conn).await { let org_name = if let Some(org) = Organization::find_by_uuid(&org_id, &conn).await {
Some(org) => org.name, org.name
None => err!("Error looking up organization"), } else {
err!("Error looking up organization")
}; };
if let Err(e) = mail::send_invite( if let Err(e) = mail::send_invite(
@ -1159,7 +1161,7 @@ async fn send_invite(
} }
} }
for group_id in data.groups.iter() { for group_id in &data.groups {
let mut group_entry = GroupUser::new(group_id.clone(), new_member.uuid.clone()); let mut group_entry = GroupUser::new(group_id.clone(), new_member.uuid.clone());
group_entry.save(&conn).await?; group_entry.save(&conn).await?;
} }
@ -1182,8 +1184,8 @@ async fn bulk_reinvite_members(
let mut bulk_response = Vec::new(); let mut bulk_response = Vec::new();
for member_id in data.ids { for member_id in data.ids {
let err_msg = match _reinvite_member(&org_id, &member_id, &headers.user.email, &conn).await { let err_msg = match reinvite_member_impl(&org_id, &member_id, &headers.user.email, &conn).await {
Ok(_) => String::new(), Ok(()) => String::new(),
Err(e) => format!("{e:?}"), Err(e) => format!("{e:?}"),
}; };
@ -1193,7 +1195,7 @@ async fn bulk_reinvite_members(
"id": member_id, "id": member_id,
"error": err_msg "error": err_msg
} }
)) ));
} }
Ok(Json(json!({ Ok(Json(json!({
@ -1213,10 +1215,10 @@ async fn reinvite_member(
if org_id != headers.org_id { if org_id != headers.org_id {
err!("Organization not found", "Organization id's do not match"); err!("Organization not found", "Organization id's do not match");
} }
_reinvite_member(&org_id, &member_id, &headers.user.email, &conn).await reinvite_member_impl(&org_id, &member_id, &headers.user.email, &conn).await
} }
async fn _reinvite_member( async fn reinvite_member_impl(
org_id: &OrganizationId, org_id: &OrganizationId,
member_id: &MembershipId, member_id: &MembershipId,
invited_by_email: &str, invited_by_email: &str,
@ -1238,13 +1240,14 @@ async fn _reinvite_member(
err!("Invitations are not allowed.") err!("Invitations are not allowed.")
} }
let org_name = match Organization::find_by_uuid(org_id, conn).await { let org_name = if let Some(org) = Organization::find_by_uuid(org_id, conn).await {
Some(org) => org.name, org.name
None => err!("Error looking up organization."), } else {
err!("Error looking up organization.")
}; };
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
mail::send_invite(&user, org_id.clone(), member.uuid, &org_name, Some(invited_by_email.to_string())).await?; mail::send_invite(&user, org_id.clone(), member.uuid, &org_name, Some(invited_by_email.to_owned())).await?;
} else if user.password_hash.is_empty() { } else if user.password_hash.is_empty() {
let invitation = Invitation::new(&user.email); let invitation = Invitation::new(&user.email);
invitation.save(conn).await?; invitation.save(conn).await?;
@ -1352,8 +1355,8 @@ async fn bulk_confirm_invite(
for invite in keys { for invite in keys {
let member_id = invite.id.unwrap(); let member_id = invite.id.unwrap();
let user_key = invite.key.unwrap_or_default(); let user_key = invite.key.unwrap_or_default();
let err_msg = match _confirm_invite(&org_id, &member_id, &user_key, &headers, &conn, &nt).await { let err_msg = match confirm_invite_impl(&org_id, &member_id, &user_key, &headers, &conn, &nt).await {
Ok(_) => String::new(), Ok(()) => String::new(),
Err(e) => format!("{e:?}"), Err(e) => format!("{e:?}"),
}; };
@ -1387,10 +1390,10 @@ async fn confirm_invite(
) -> EmptyResult { ) -> EmptyResult {
let data = data.into_inner(); let data = data.into_inner();
let user_key = data.key.unwrap_or_default(); let user_key = data.key.unwrap_or_default();
_confirm_invite(&org_id, &member_id, &user_key, &headers, &conn, &nt).await confirm_invite_impl(&org_id, &member_id, &user_key, &headers, &conn, &nt).await
} }
async fn _confirm_invite( async fn confirm_invite_impl(
org_id: &OrganizationId, org_id: &OrganizationId,
member_id: &MembershipId, member_id: &MembershipId,
key: &str, key: &str,
@ -1418,7 +1421,7 @@ async fn _confirm_invite(
} }
member_to_confirm.status = MembershipStatus::Confirmed as i32; member_to_confirm.status = MembershipStatus::Confirmed as i32;
member_to_confirm.akey = key.to_string(); member_to_confirm.akey = key.to_owned();
// This check is also done at accept_invite, _confirm_invite, _activate_member, edit_member, admin::update_membership_type // This check is also done at accept_invite, _confirm_invite, _activate_member, edit_member, admin::update_membership_type
OrgPolicy::check_user_allowed(&member_to_confirm, "confirm", conn).await?; OrgPolicy::check_user_allowed(&member_to_confirm, "confirm", conn).await?;
@ -1435,13 +1438,15 @@ async fn _confirm_invite(
.await; .await;
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
let org_name = match Organization::find_by_uuid(org_id, conn).await { let org_name = if let Some(org) = Organization::find_by_uuid(org_id, conn).await {
Some(org) => org.name, org.name
None => err!("Error looking up organization."), } else {
err!("Error looking up organization.")
}; };
let address = match User::find_by_uuid(&member_to_confirm.user_uuid, conn).await { let address = if let Some(user) = User::find_by_uuid(&member_to_confirm.user_uuid, conn).await {
Some(user) => user.email, user.email
None => err!("Error looking up user."), } else {
err!("Error looking up user.")
}; };
mail::send_invite_confirmed(&address, &org_name).await?; mail::send_invite_confirmed(&address, &org_name).await?;
} }
@ -1637,8 +1642,8 @@ async fn bulk_delete_member(
let mut bulk_response = Vec::new(); let mut bulk_response = Vec::new();
for member_id in data.ids { for member_id in data.ids {
let err_msg = match _delete_member(&org_id, &member_id, &headers, &conn, &nt).await { let err_msg = match delete_member_impl(&org_id, &member_id, &headers, &conn, &nt).await {
Ok(_) => String::new(), Ok(()) => String::new(),
Err(e) => format!("{e:?}"), Err(e) => format!("{e:?}"),
}; };
@ -1648,7 +1653,7 @@ async fn bulk_delete_member(
"id": member_id, "id": member_id,
"error": err_msg "error": err_msg
} }
)) ));
} }
Ok(Json(json!({ Ok(Json(json!({
@ -1666,10 +1671,10 @@ async fn delete_member(
conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> EmptyResult { ) -> EmptyResult {
_delete_member(&org_id, &member_id, &headers, &conn, &nt).await delete_member_impl(&org_id, &member_id, &headers, &conn, &nt).await
} }
async fn _delete_member( async fn delete_member_impl(
org_id: &OrganizationId, org_id: &OrganizationId,
member_id: &MembershipId, member_id: &MembershipId,
headers: &AdminHeaders, headers: &AdminHeaders,
@ -1753,8 +1758,8 @@ async fn bulk_public_keys(
}))) })))
} }
use super::ciphers::update_cipher_from_data;
use super::ciphers::CipherData; use super::ciphers::CipherData;
use super::ciphers::update_cipher_from_data;
#[derive(Deserialize)] #[derive(Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
@ -1902,11 +1907,12 @@ async fn post_bulk_collections(data: Json<BulkCollectionsData>, headers: Headers
} }
} }
for cipher_id in data.cipher_ids.iter() { for cipher_id in &data.cipher_ids {
// Only act on existing cipher uuid's // Only act on existing cipher uuid's
// Do not abort the operation just ignore it, it could be a cipher was just deleted for example // Do not abort the operation just ignore it, it could be a cipher was just deleted for example
if let Some(cipher) = Cipher::find_by_uuid_and_org(cipher_id, &data.organization_id, &conn).await { if let Some(cipher) = Cipher::find_by_uuid_and_org(cipher_id, &data.organization_id, &conn).await
if cipher.is_write_accessible_to_user(&headers.user.uuid, &conn).await { && cipher.is_write_accessible_to_user(&headers.user.uuid, &conn).await
{
// When selecting a specific collection from the left filter list, and use the bulk option, you can remove an item from that collection // When selecting a specific collection from the left filter list, and use the bulk option, you can remove an item from that collection
// In these cases the client will call this endpoint twice, once for adding the new collections and a second for deleting. // In these cases the client will call this endpoint twice, once for adding the new collections and a second for deleting.
if data.remove_collections { if data.remove_collections {
@ -1919,7 +1925,6 @@ async fn post_bulk_collections(data: Json<BulkCollectionsData>, headers: Headers
} }
} }
} }
};
} }
Ok(()) Ok(())
@ -1969,7 +1974,7 @@ async fn list_policies_token(org_id: OrganizationId, token: &str, conn: DbConn)
fn get_dummy_master_password_policy() -> JsonResult { fn get_dummy_master_password_policy() -> JsonResult {
let (enabled, data) = match CONFIG.sso_master_password_policy_value() { let (enabled, data) = match CONFIG.sso_master_password_policy_value() {
Some(policy) if CONFIG.sso_enabled() => (true, policy.to_string()), Some(policy) if CONFIG.sso_enabled() => (true, policy.to_string()),
_ => (false, "null".to_string()), _ => (false, "null".to_owned()),
}; };
let policy = OrgPolicy::new(FAKE_SSO_IDENTIFIER.into(), OrgPolicyType::MasterPassword, enabled, data); let policy = OrgPolicy::new(FAKE_SSO_IDENTIFIER.into(), OrgPolicyType::MasterPassword, enabled, data);
Ok(Json(policy.to_json())) Ok(Json(policy.to_json()))
@ -1982,7 +1987,7 @@ async fn get_master_password_policy(org_id: OrganizationId, _headers: OrgMemberH
OrgPolicy::find_by_org_and_type(&org_id, OrgPolicyType::MasterPassword, &conn).await.unwrap_or_else(|| { OrgPolicy::find_by_org_and_type(&org_id, OrgPolicyType::MasterPassword, &conn).await.unwrap_or_else(|| {
let (enabled, data) = match CONFIG.sso_master_password_policy_value() { let (enabled, data) = match CONFIG.sso_master_password_policy_value() {
Some(policy) if CONFIG.sso_enabled() => (true, policy.to_string()), Some(policy) if CONFIG.sso_enabled() => (true, policy.to_string()),
_ => (false, "null".to_string()), _ => (false, "null".to_owned()),
}; };
OrgPolicy::new(org_id, OrgPolicyType::MasterPassword, enabled, data) OrgPolicy::new(org_id, OrgPolicyType::MasterPassword, enabled, data)
@ -2003,7 +2008,7 @@ async fn get_policy(org_id: OrganizationId, pol_type: i32, headers: AdminHeaders
let policy = match OrgPolicy::find_by_org_and_type(&org_id, pol_type_enum, &conn).await { let policy = match OrgPolicy::find_by_org_and_type(&org_id, pol_type_enum, &conn).await {
Some(p) => p, Some(p) => p,
None => OrgPolicy::new(org_id.clone(), pol_type_enum, false, "null".to_string()), None => OrgPolicy::new(org_id.clone(), pol_type_enum, false, "null".to_owned()),
}; };
Ok(Json(policy.to_json())) Ok(Json(policy.to_json()))
@ -2078,7 +2083,7 @@ async fn put_policy(
// When enabling the SingleOrg policy, remove this org's members that are members of other orgs // When enabling the SingleOrg policy, remove this org's members that are members of other orgs
if pol_type_enum == OrgPolicyType::SingleOrg && data.enabled { if pol_type_enum == OrgPolicyType::SingleOrg && data.enabled {
for mut member in Membership::find_by_org(&org_id, &conn).await.into_iter() { for mut member in Membership::find_by_org(&org_id, &conn).await {
// Policy only applies to non-Owner/non-Admin members who have accepted joining the org // Policy only applies to non-Owner/non-Admin members who have accepted joining the org
// Exclude invited and revoked users when checking for this policy. // Exclude invited and revoked users when checking for this policy.
// Those users will not be allowed to accept or be activated because of the policy checks done there. // Those users will not be allowed to accept or be activated because of the policy checks done there.
@ -2113,7 +2118,7 @@ async fn put_policy(
let mut policy = match OrgPolicy::find_by_org_and_type(&org_id, pol_type_enum, &conn).await { let mut policy = match OrgPolicy::find_by_org_and_type(&org_id, pol_type_enum, &conn).await {
Some(p) => p, Some(p) => p,
None => OrgPolicy::new(org_id.clone(), pol_type_enum, false, "{}".to_string()), None => OrgPolicy::new(org_id.clone(), pol_type_enum, false, "{}".to_owned()),
}; };
policy.enabled = data.enabled; policy.enabled = data.enabled;
@ -2187,7 +2192,7 @@ fn get_plans() -> Json<Value> {
#[get("/organizations/<_org_id>/billing/metadata")] #[get("/organizations/<_org_id>/billing/metadata")]
fn get_billing_metadata(_org_id: OrganizationId, _headers: OrgMemberHeaders) -> Json<Value> { fn get_billing_metadata(_org_id: OrganizationId, _headers: OrgMemberHeaders) -> Json<Value> {
// Prevent a 404 error, which also causes Javascript errors. // Prevent a 404 error, which also causes Javascript errors.
Json(_empty_data_json()) Json(empty_data_json())
} }
#[get("/organizations/<_org_id>/billing/vnext/warnings")] #[get("/organizations/<_org_id>/billing/vnext/warnings")]
@ -2209,7 +2214,7 @@ fn get_self_host_billing_metadata(_org_id: OrganizationId, _headers: OrgMemberHe
})) }))
} }
fn _empty_data_json() -> Value { fn empty_data_json() -> Value {
json!({ json!({
"object": "list", "object": "list",
"data": [], "data": [],
@ -2230,7 +2235,7 @@ async fn revoke_member(
headers: AdminHeaders, headers: AdminHeaders,
conn: DbConn, conn: DbConn,
) -> EmptyResult { ) -> EmptyResult {
_revoke_member(&org_id, &member_id, &headers, &conn).await revoke_member_impl(&org_id, &member_id, &headers, &conn).await
} }
#[put("/organizations/<org_id>/users/revoke", data = "<data>")] #[put("/organizations/<org_id>/users/revoke", data = "<data>")]
@ -2249,8 +2254,8 @@ async fn bulk_revoke_members(
match data.ids { match data.ids {
Some(members) => { Some(members) => {
for member_id in members { for member_id in members {
let err_msg = match _revoke_member(&org_id, &member_id, &headers, &conn).await { let err_msg = match revoke_member_impl(&org_id, &member_id, &headers, &conn).await {
Ok(_) => String::new(), Ok(()) => String::new(),
Err(e) => format!("{e:?}"), Err(e) => format!("{e:?}"),
}; };
@ -2273,7 +2278,7 @@ async fn bulk_revoke_members(
}))) })))
} }
async fn _revoke_member( async fn revoke_member_impl(
org_id: &OrganizationId, org_id: &OrganizationId,
member_id: &MembershipId, member_id: &MembershipId,
headers: &AdminHeaders, headers: &AdminHeaders,
@ -2325,7 +2330,7 @@ async fn restore_member_vnext(
) -> EmptyResult { ) -> EmptyResult {
// Vaultwarden does not (yet) support the per User Collection linked to the `Enforce organization data ownership` policy. // Vaultwarden does not (yet) support the per User Collection linked to the `Enforce organization data ownership` policy.
// Therefor we ignore the `defaultUserCollectionName` data sent and just call restore_member // Therefor we ignore the `defaultUserCollectionName` data sent and just call restore_member
_restore_member(&org_id, &member_id, &headers, &conn).await restore_member_impl(&org_id, &member_id, &headers, &conn).await
} }
#[put("/organizations/<org_id>/users/<member_id>/restore")] #[put("/organizations/<org_id>/users/<member_id>/restore")]
@ -2335,7 +2340,7 @@ async fn restore_member(
headers: AdminHeaders, headers: AdminHeaders,
conn: DbConn, conn: DbConn,
) -> EmptyResult { ) -> EmptyResult {
_restore_member(&org_id, &member_id, &headers, &conn).await restore_member_impl(&org_id, &member_id, &headers, &conn).await
} }
#[put("/organizations/<org_id>/users/restore", data = "<data>")] #[put("/organizations/<org_id>/users/restore", data = "<data>")]
@ -2352,8 +2357,8 @@ async fn bulk_restore_members(
let mut bulk_response = Vec::new(); let mut bulk_response = Vec::new();
for member_id in data.ids { for member_id in data.ids {
let err_msg = match _restore_member(&org_id, &member_id, &headers, &conn).await { let err_msg = match restore_member_impl(&org_id, &member_id, &headers, &conn).await {
Ok(_) => String::new(), Ok(()) => String::new(),
Err(e) => format!("{e:?}"), Err(e) => format!("{e:?}"),
}; };
@ -2373,7 +2378,7 @@ async fn bulk_restore_members(
}))) })))
} }
async fn _restore_member( async fn restore_member_impl(
org_id: &OrganizationId, org_id: &OrganizationId,
member_id: &MembershipId, member_id: &MembershipId,
headers: &AdminHeaders, headers: &AdminHeaders,
@ -2429,11 +2434,11 @@ async fn get_groups_data(
if details { if details {
for g in groups { for g in groups {
groups_json.push(g.to_json_details(&conn).await) groups_json.push(g.to_json_details(&conn).await);
} }
} else { } else {
for g in groups { for g in groups {
groups_json.push(g.to_json()) groups_json.push(g.to_json());
} }
} }
groups_json groups_json
@ -2672,15 +2677,15 @@ async fn post_delete_group(
headers: AdminHeaders, headers: AdminHeaders,
conn: DbConn, conn: DbConn,
) -> EmptyResult { ) -> EmptyResult {
_delete_group(&org_id, &group_id, &headers, &conn).await delete_group_impl(&org_id, &group_id, &headers, &conn).await
} }
#[delete("/organizations/<org_id>/groups/<group_id>")] #[delete("/organizations/<org_id>/groups/<group_id>")]
async fn delete_group(org_id: OrganizationId, group_id: GroupId, headers: AdminHeaders, conn: DbConn) -> EmptyResult { async fn delete_group(org_id: OrganizationId, group_id: GroupId, headers: AdminHeaders, conn: DbConn) -> EmptyResult {
_delete_group(&org_id, &group_id, &headers, &conn).await delete_group_impl(&org_id, &group_id, &headers, &conn).await
} }
async fn _delete_group( async fn delete_group_impl(
org_id: &OrganizationId, org_id: &OrganizationId,
group_id: &GroupId, group_id: &GroupId,
headers: &AdminHeaders, headers: &AdminHeaders,
@ -2728,7 +2733,7 @@ async fn bulk_delete_groups(
let data: BulkGroupIds = data.into_inner(); let data: BulkGroupIds = data.into_inner();
for group_id in data.ids { for group_id in data.ids {
_delete_group(&org_id, &group_id, &headers, &conn).await? delete_group_impl(&org_id, &group_id, &headers, &conn).await?;
} }
Ok(()) Ok(())
} }
@ -2765,7 +2770,7 @@ async fn get_group_members(
if Group::find_by_uuid_and_org(&group_id, &org_id, &conn).await.is_none() { if Group::find_by_uuid_and_org(&group_id, &org_id, &conn).await.is_none() {
err!("Group could not be found!", "Group uuid is invalid or does not belong to the organization") err!("Group could not be found!", "Group uuid is invalid or does not belong to the organization")
}; }
let group_members: Vec<MembershipId> = GroupUser::find_by_group(&group_id, &org_id, &conn) let group_members: Vec<MembershipId> = GroupUser::find_by_group(&group_id, &org_id, &conn)
.await .await
@ -2793,7 +2798,7 @@ async fn put_group_members(
if Group::find_by_uuid_and_org(&group_id, &org_id, &conn).await.is_none() { if Group::find_by_uuid_and_org(&group_id, &org_id, &conn).await.is_none() {
err!("Group could not be found!", "Group uuid is invalid or does not belong to the organization") err!("Group could not be found!", "Group uuid is invalid or does not belong to the organization")
}; }
let assigned_members = data.into_inner(); let assigned_members = data.into_inner();
@ -3100,12 +3105,12 @@ async fn get_org_export(org_id: OrganizationId, headers: AdminHeaders, conn: DbC
} }
Ok(Json(json!({ Ok(Json(json!({
"collections": convert_json_key_lcase_first(_get_org_collections(&org_id, &conn).await), "collections": convert_json_key_lcase_first(get_org_collections_impl(&org_id, &conn).await),
"ciphers": convert_json_key_lcase_first(_get_org_details(&org_id, &headers.host, &headers.user.uuid, &conn).await?), "ciphers": convert_json_key_lcase_first(get_org_details_impl(&org_id, &headers.host, &headers.user.uuid, &conn).await?),
}))) })))
} }
async fn _api_key( async fn api_key(
org_id: &OrganizationId, org_id: &OrganizationId,
data: Json<PasswordOrOtpData>, data: Json<PasswordOrOtpData>,
rotate: bool, rotate: bool,
@ -3121,21 +3126,18 @@ async fn _api_key(
// Validate the admin users password/otp // Validate the admin users password/otp
data.validate(&user, true, &conn).await?; data.validate(&user, true, &conn).await?;
let org_api_key = match OrganizationApiKey::find_by_org_uuid(org_id, &conn).await { let org_api_key = if let Some(mut org_api_key) = OrganizationApiKey::find_by_org_uuid(org_id, &conn).await {
Some(mut org_api_key) => {
if rotate { if rotate {
org_api_key.api_key = crate::crypto::generate_api_key(); org_api_key.api_key = crate::crypto::generate_api_key();
org_api_key.revision_date = chrono::Utc::now().naive_utc(); org_api_key.revision_date = chrono::Utc::now().naive_utc();
org_api_key.save(&conn).await.expect("Error rotating organization API Key"); org_api_key.save(&conn).await.expect("Error rotating organization API Key");
} }
org_api_key org_api_key
} } else {
None => {
let api_key = crate::crypto::generate_api_key(); let api_key = crate::crypto::generate_api_key();
let new_org_api_key = OrganizationApiKey::new(org_id.clone(), api_key); let new_org_api_key = OrganizationApiKey::new(org_id.clone(), api_key);
new_org_api_key.save(&conn).await.expect("Error creating organization API Key"); new_org_api_key.save(&conn).await.expect("Error creating organization API Key");
new_org_api_key new_org_api_key
}
}; };
Ok(Json(json!({ Ok(Json(json!({
@ -3146,13 +3148,13 @@ async fn _api_key(
} }
#[post("/organizations/<org_id>/api-key", data = "<data>")] #[post("/organizations/<org_id>/api-key", data = "<data>")]
async fn api_key( async fn post_api_key(
org_id: OrganizationId, org_id: OrganizationId,
data: Json<PasswordOrOtpData>, data: Json<PasswordOrOtpData>,
headers: AdminHeaders, headers: AdminHeaders,
conn: DbConn, conn: DbConn,
) -> JsonResult { ) -> JsonResult {
_api_key(&org_id, data, false, headers, conn).await api_key(&org_id, data, false, headers, conn).await
} }
#[post("/organizations/<org_id>/rotate-api-key", data = "<data>")] #[post("/organizations/<org_id>/rotate-api-key", data = "<data>")]
@ -3162,5 +3164,5 @@ async fn rotate_api_key(
headers: AdminHeaders, headers: AdminHeaders,
conn: DbConn, conn: DbConn,
) -> JsonResult { ) -> JsonResult {
_api_key(&org_id, data, true, headers, conn).await api_key(&org_id, data, true, headers, conn).await
} }

77
src/api/core/public.rs

@ -1,23 +1,24 @@
use std::collections::HashSet;
use chrono::Utc; use chrono::Utc;
use rocket::{ use rocket::{
Request, Route,
request::{FromRequest, Outcome}, request::{FromRequest, Outcome},
serde::json::Json, serde::json::Json,
Request, Route,
}; };
use std::collections::HashSet;
use crate::{ use crate::{
CONFIG,
api::EmptyResult, api::EmptyResult,
auth, auth,
db::{ db::{
DbConn,
models::{ models::{
Group, GroupUser, Invitation, Membership, MembershipStatus, MembershipType, Organization, Group, GroupUser, Invitation, Membership, MembershipStatus, MembershipType, Organization,
OrganizationApiKey, OrganizationId, User, OrganizationApiKey, OrganizationId, User,
}, },
DbConn,
}, },
mail, CONFIG, mail,
}; };
pub fn routes() -> Vec<Route> { pub fn routes() -> Vec<Route> {
@ -90,9 +91,9 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, conn: DbConn
} }
} else { } else {
// If user is not part of the organization // If user is not part of the organization
let user = match User::find_by_mail(&user_data.email, &conn).await { let user = if let Some(user) = User::find_by_mail(&user_data.email, &conn).await {
Some(user) => user, // exists in vaultwarden user
None => { } else {
// User does not exist yet // User does not exist yet
let mut new_user = User::new(&user_data.email, None); let mut new_user = User::new(&user_data.email, None);
new_user.save(&conn).await?; new_user.save(&conn).await?;
@ -102,7 +103,6 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, conn: DbConn
} }
user_created = true; user_created = true;
new_user new_user
}
}; };
let member_status = if CONFIG.mail_enabled() || user.password_hash.is_empty() { let member_status = if CONFIG.mail_enabled() || user.password_hash.is_empty() {
MembershipStatus::Invited as i32 MembershipStatus::Invited as i32
@ -110,9 +110,10 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, conn: DbConn
MembershipStatus::Accepted as i32 // Automatically mark user as accepted if no email invites MembershipStatus::Accepted as i32 // Automatically mark user as accepted if no email invites
}; };
let (org_name, org_email) = match Organization::find_by_uuid(&org_id, &conn).await { let (org_name, org_email) = if let Some(org) = Organization::find_by_uuid(&org_id, &conn).await {
Some(org) => (org.name, org.billing_email), (org.name, org.billing_email)
None => err!("Error looking up organization"), } else {
err!("Error looking up organization")
}; };
let mut new_member = Membership::new(user.uuid.clone(), org_id.clone(), Some(org_email.clone())); let mut new_member = Membership::new(user.uuid.clone(), org_id.clone(), Some(org_email.clone()));
@ -123,8 +124,8 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, conn: DbConn
new_member.save(&conn).await?; new_member.save(&conn).await?;
if CONFIG.mail_enabled() { if CONFIG.mail_enabled()
if let Err(e) = && let Err(e) =
mail::send_invite(&user, org_id.clone(), new_member.uuid.clone(), &org_name, Some(org_email)).await mail::send_invite(&user, org_id.clone(), new_member.uuid.clone(), &org_name, Some(org_email)).await
{ {
// Upon error delete the user, invite and org member records when needed // Upon error delete the user, invite and org member records when needed
@ -138,22 +139,18 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, conn: DbConn
} }
} }
} }
}
if CONFIG.org_groups_enabled() { if CONFIG.org_groups_enabled() {
for group_data in &data.groups { for group_data in &data.groups {
let group_uuid = match Group::find_by_external_id_and_org(&group_data.external_id, &org_id, &conn).await { let group_uuid = if let Some(group) =
Some(group) => group.uuid, Group::find_by_external_id_and_org(&group_data.external_id, &org_id, &conn).await
None => { {
let mut group = Group::new( group.uuid
org_id.clone(), } else {
group_data.name.clone(), let mut group =
false, Group::new(org_id.clone(), group_data.name.clone(), false, Some(group_data.external_id.clone()));
Some(group_data.external_id.clone()),
);
group.save(&conn).await?; group.save(&conn).await?;
group.uuid group.uuid
}
}; };
GroupUser::delete_all_by_group(&group_uuid, &org_id, &conn).await?; GroupUser::delete_all_by_group(&group_uuid, &org_id, &conn).await?;
@ -174,12 +171,12 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, conn: DbConn
// Generate a HashSet to quickly verify if a member is listed or not. // Generate a HashSet to quickly verify if a member is listed or not.
let sync_members: HashSet<String> = data.members.into_iter().map(|m| m.external_id).collect(); let sync_members: HashSet<String> = data.members.into_iter().map(|m| m.external_id).collect();
for member in Membership::find_by_org(&org_id, &conn).await { for member in Membership::find_by_org(&org_id, &conn).await {
if let Some(ref user_external_id) = member.external_id { if let Some(ref user_external_id) = member.external_id
if !sync_members.contains(user_external_id) { && !sync_members.contains(user_external_id)
{
if member.atype == MembershipType::Owner && member.status == MembershipStatus::Confirmed as i32 { if member.atype == MembershipType::Owner && member.status == MembershipStatus::Confirmed as i32 {
// Removing owner, check that there is at least one other confirmed owner // Removing owner, check that there is at least one other confirmed owner
if Membership::count_confirmed_by_org_and_type(&org_id, MembershipType::Owner, &conn).await <= 1 if Membership::count_confirmed_by_org_and_type(&org_id, MembershipType::Owner, &conn).await <= 1 {
{
warn!("Can't delete the last owner"); warn!("Can't delete the last owner");
continue; continue;
} }
@ -188,7 +185,6 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, conn: DbConn
} }
} }
} }
}
Ok(()) Ok(())
} }
@ -202,12 +198,14 @@ impl<'r> FromRequest<'r> for PublicToken {
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let headers = request.headers(); let headers = request.headers();
// Get access_token // Get access_token
let access_token: &str = match headers.get_one("Authorization") { let access_token: &str = if let Some(a) = headers.get_one("Authorization") {
Some(a) => match a.rsplit("Bearer ").next() { if let Some(split) = a.rsplit("Bearer ").next() {
Some(split) => split, split
None => err_handler!("No access token provided"), } else {
}, err_handler!("No access token provided")
None => err_handler!("No access token provided"), }
} else {
err_handler!("No access token provided")
}; };
// Check JWT token is valid and get device and user from it // Check JWT token is valid and get device and user from it
let Ok(claims) = auth::decode_api_org(access_token) else { let Ok(claims) = auth::decode_api_org(access_token) else {
@ -229,14 +227,13 @@ impl<'r> FromRequest<'r> for PublicToken {
// Check if claims.sub is org_api_key.uuid // Check if claims.sub is org_api_key.uuid
// Check if claims.client_sub is org_api_key.org_uuid // Check if claims.client_sub is org_api_key.org_uuid
let conn = match DbConn::from_request(request).await { let Outcome::Success(conn) = DbConn::from_request(request).await else {
Outcome::Success(conn) => conn, err_handler!("Error getting DB")
_ => err_handler!("Error getting DB"),
}; };
let Some(org_id) = claims.client_id.strip_prefix("organization.") else { let Some(org_id) = claims.client_id.strip_prefix("organization.") else {
err_handler!("Malformed client_id") err_handler!("Malformed client_id")
}; };
let org_id: OrganizationId = org_id.to_string().into(); let org_id: OrganizationId = org_id.to_owned().into();
let Some(org_api_key) = OrganizationApiKey::find_by_org_uuid(&org_id, &conn).await else { let Some(org_api_key) = OrganizationApiKey::find_by_org_uuid(&org_id, &conn).await else {
err_handler!("Invalid client_id") err_handler!("Invalid client_id")
}; };

56
src/api/core/sends.rs

@ -10,15 +10,15 @@ use rocket::{
use serde_json::Value; use serde_json::Value;
use crate::{ use crate::{
CONFIG,
api::{ApiResult, EmptyResult, JsonResult, Notify, UpdateType}, api::{ApiResult, EmptyResult, JsonResult, Notify, UpdateType},
auth::{ClientIp, Headers, Host}, auth::{ClientIp, Headers, Host},
config::PathType, config::PathType,
db::{ db::{
models::{Device, OrgPolicy, OrgPolicyType, Send, SendFileId, SendId, SendType, UserId},
DbConn, DbPool, DbConn, DbPool,
models::{Device, OrgPolicy, OrgPolicyType, Send, SendFileId, SendId, SendType, UserId},
}, },
util::{save_temp_file, NumberOrString}, util::{NumberOrString, save_temp_file},
CONFIG,
}; };
const SEND_INACCESSIBLE_MSG: &str = "Send does not exist or is no longer available"; const SEND_INACCESSIBLE_MSG: &str = "Send does not exist or is no longer available";
@ -63,7 +63,7 @@ pub async fn purge_sends(pool: DbPool) {
if let Ok(conn) = pool.get().await { if let Ok(conn) = pool.get().await {
Send::purge(&conn).await; Send::purge(&conn).await;
} else { } else {
error!("Failed to get DB connection while purging sends") error!("Failed to get DB connection while purging sends");
} }
} }
@ -168,7 +168,7 @@ fn create_send(data: SendData, user_id: UserId) -> ApiResult<Send> {
#[get("/sends")] #[get("/sends")]
async fn get_sends(headers: Headers, conn: DbConn) -> Json<Value> { async fn get_sends(headers: Headers, conn: DbConn) -> Json<Value> {
let sends = Send::find_by_user(&headers.user.uuid, &conn); let sends = Send::find_by_user(&headers.user.uuid, &conn);
let sends_json: Vec<Value> = sends.await.iter().map(|s| s.to_json()).collect(); let sends_json: Vec<Value> = sends.await.iter().map(Send::to_json).collect();
Json(json!({ Json(json!({
"data": sends_json, "data": sends_json,
@ -179,9 +179,10 @@ async fn get_sends(headers: Headers, conn: DbConn) -> Json<Value> {
#[get("/sends/<send_id>")] #[get("/sends/<send_id>")]
async fn get_send(send_id: SendId, headers: Headers, conn: DbConn) -> JsonResult { async fn get_send(send_id: SendId, headers: Headers, conn: DbConn) -> JsonResult {
match Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &conn).await { if let Some(send) = Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &conn).await {
Some(send) => Ok(Json(send.to_json())), Ok(Json(send.to_json()))
None => err!("Send not found", "Invalid send uuid or does not belong to user"), } else {
err!("Send not found", "Invalid send uuid or does not belong to user")
} }
} }
@ -310,9 +311,10 @@ async fn post_send_file_v2(data: Json<SendData>, headers: Headers, conn: DbConn)
enforce_disable_hide_email_policy(&data, &headers, &conn).await?; enforce_disable_hide_email_policy(&data, &headers, &conn).await?;
let file_length = match &data.file_length { let file_length = if let Some(m) = &data.file_length {
Some(m) => m.into_i64()?, m.into_i64()?
_ => err!("Invalid send length"), } else {
err!("Invalid send length")
}; };
if file_length < 0 { if file_length < 0 {
err!("Send size can't be negative") err!("Send size can't be negative")
@ -457,17 +459,17 @@ async fn post_access(
err_code!(SEND_INACCESSIBLE_MSG, 404) err_code!(SEND_INACCESSIBLE_MSG, 404)
}; };
if let Some(max_access_count) = send.max_access_count { if let Some(max_access_count) = send.max_access_count
if send.access_count >= max_access_count { && send.access_count >= max_access_count
{
err_code!(SEND_INACCESSIBLE_MSG, 404); err_code!(SEND_INACCESSIBLE_MSG, 404);
} }
}
if let Some(expiration) = send.expiration_date { if let Some(expiration) = send.expiration_date
if Utc::now().naive_utc() >= expiration { && Utc::now().naive_utc() >= expiration
{
err_code!(SEND_INACCESSIBLE_MSG, 404) err_code!(SEND_INACCESSIBLE_MSG, 404)
} }
}
if Utc::now().naive_utc() >= send.deletion_date { if Utc::now().naive_utc() >= send.deletion_date {
err_code!(SEND_INACCESSIBLE_MSG, 404) err_code!(SEND_INACCESSIBLE_MSG, 404)
@ -517,17 +519,17 @@ async fn post_access_file(
err_code!(SEND_INACCESSIBLE_MSG, 404) err_code!(SEND_INACCESSIBLE_MSG, 404)
}; };
if let Some(max_access_count) = send.max_access_count { if let Some(max_access_count) = send.max_access_count
if send.access_count >= max_access_count { && send.access_count >= max_access_count
{
err_code!(SEND_INACCESSIBLE_MSG, 404) err_code!(SEND_INACCESSIBLE_MSG, 404)
} }
}
if let Some(expiration) = send.expiration_date { if let Some(expiration) = send.expiration_date
if Utc::now().naive_utc() >= expiration { && Utc::now().naive_utc() >= expiration
{
err_code!(SEND_INACCESSIBLE_MSG, 404) err_code!(SEND_INACCESSIBLE_MSG, 404)
} }
}
if Utc::now().naive_utc() >= send.deletion_date { if Utc::now().naive_utc() >= send.deletion_date {
err_code!(SEND_INACCESSIBLE_MSG, 404) err_code!(SEND_INACCESSIBLE_MSG, 404)
@ -572,7 +574,7 @@ async fn download_url(host: &Host, send_id: &SendId, file_id: &SendFileId) -> Re
let token_claims = crate::auth::generate_send_claims(send_id, file_id); let token_claims = crate::auth::generate_send_claims(send_id, file_id);
let token = crate::auth::encode_jwt(&token_claims); let token = crate::auth::encode_jwt(&token_claims);
Ok(format!("{}/api/sends/{send_id}/{file_id}?t={token}", &host.host)) Ok(format!("{}/api/sends/{send_id}/{file_id}?t={token}", host.host))
} else { } else {
Ok(operator.presign_read(&format!("{send_id}/{file_id}"), Duration::from_mins(5)).await?.uri().to_string()) Ok(operator.presign_read(&format!("{send_id}/{file_id}"), Duration::from_mins(5)).await?.uri().to_string())
} }
@ -580,11 +582,11 @@ async fn download_url(host: &Host, send_id: &SendId, file_id: &SendFileId) -> Re
#[get("/sends/<send_id>/<file_id>?<t>")] #[get("/sends/<send_id>/<file_id>?<t>")]
async fn download_send(send_id: SendId, file_id: SendFileId, t: &str) -> Option<NamedFile> { async fn download_send(send_id: SendId, file_id: SendFileId, t: &str) -> Option<NamedFile> {
if let Ok(claims) = crate::auth::decode_send(t) { if let Ok(claims) = crate::auth::decode_send(t)
if claims.sub == format!("{send_id}/{file_id}") { && claims.sub == format!("{send_id}/{file_id}")
{
return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(send_id).join(file_id)).await.ok(); return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(send_id).join(file_id)).await.ok();
} }
}
None None
} }

22
src/api/core/two_factor/authenticator.rs

@ -1,14 +1,13 @@
use data_encoding::BASE32; use data_encoding::BASE32;
use rocket::serde::json::Json; use rocket::{Route, serde::json::Json};
use rocket::Route;
use crate::{ use crate::{
api::{core::log_user_event, core::two_factor::_generate_recover_code, EmptyResult, JsonResult, PasswordOrOtpData}, api::{EmptyResult, JsonResult, PasswordOrOtpData, core::log_user_event, core::two_factor::generate_recover_code},
auth::{ClientIp, Headers}, auth::{ClientIp, Headers},
crypto, crypto,
db::{ db::{
models::{EventType, TwoFactor, TwoFactorType, UserId},
DbConn, DbConn,
models::{EventType, TwoFactor, TwoFactorType, UserId},
}, },
util::NumberOrString, util::NumberOrString,
}; };
@ -70,9 +69,10 @@ async fn activate_authenticator(data: Json<EnableAuthenticatorData>, headers: He
.await?; .await?;
// Validate key as base32 and 20 bytes length // Validate key as base32 and 20 bytes length
let decoded_key: Vec<u8> = match BASE32.decode(key.as_bytes()) { let decoded_key: Vec<u8> = if let Ok(decoded) = BASE32.decode(key.as_bytes()) {
Ok(decoded) => decoded, decoded
_ => err!("Invalid totp secret"), } else {
err!("Invalid totp secret")
}; };
if decoded_key.len() != 20 { if decoded_key.len() != 20 {
@ -82,7 +82,7 @@ async fn activate_authenticator(data: Json<EnableAuthenticatorData>, headers: He
// Validate the token provided with the key, and save new twofactor // Validate the token provided with the key, and save new twofactor
validate_totp_code(&user.uuid, &token, &key.to_uppercase(), &headers.ip, &conn).await?; validate_totp_code(&user.uuid, &token, &key.to_uppercase(), &headers.ip, &conn).await?;
_generate_recover_code(&mut user, &conn).await; generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await; log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
@ -119,7 +119,7 @@ pub async fn validate_totp_code(
ip: &ClientIp, ip: &ClientIp,
conn: &DbConn, conn: &DbConn,
) -> EmptyResult { ) -> EmptyResult {
use totp_lite::{totp_custom, Sha1}; use totp_lite::{Sha1, totp_custom};
let Ok(decoded_secret) = BASE32.decode(secret.as_bytes()) else { let Ok(decoded_secret) = BASE32.decode(secret.as_bytes()) else {
err!("Invalid TOTP secret") err!("Invalid TOTP secret")
@ -128,7 +128,7 @@ pub async fn validate_totp_code(
let mut twofactor = match TwoFactor::find_by_user_and_type(user_id, TwoFactorType::Authenticator as i32, conn).await let mut twofactor = match TwoFactor::find_by_user_and_type(user_id, TwoFactorType::Authenticator as i32, conn).await
{ {
Some(tf) => tf, Some(tf) => tf,
_ => TwoFactor::new(user_id.clone(), TwoFactorType::Authenticator, secret.to_string()), _ => TwoFactor::new(user_id.clone(), TwoFactorType::Authenticator, secret.to_owned()),
}; };
// The amount of steps back and forward in time // The amount of steps back and forward in time
@ -145,7 +145,7 @@ pub async fn validate_totp_code(
// We need to calculate the time offsite and cast it as an u64. // We need to calculate the time offsite and cast it as an u64.
// Since we only have times into the future and the totp generator needs an u64 instead of the default i64. // Since we only have times into the future and the totp generator needs an u64 instead of the default i64.
let time = (current_timestamp + step * 30i64) as u64; let time: u64 = (current_timestamp + step * 30i64).cast_unsigned();
let generated = totp_custom::<Sha1>(30, 6, &decoded_secret, time); let generated = totp_custom::<Sha1>(30, 6, &decoded_secret, time);
// Check the given code equals the generated and if the time_step is larger then the one last used. // Check the given code equals the generated and if the time_step is larger then the one last used.

33
src/api/core/two_factor/duo.rs

@ -1,22 +1,21 @@
use chrono::Utc; use chrono::Utc;
use data_encoding::BASE64; use data_encoding::BASE64;
use rocket::serde::json::Json; use rocket::{Route, serde::json::Json};
use rocket::Route;
use crate::{ use crate::{
CONFIG,
api::{ api::{
core::log_user_event, core::two_factor::_generate_recover_code, ApiResult, EmptyResult, JsonResult, ApiResult, EmptyResult, JsonResult, PasswordOrOtpData, core::log_user_event,
PasswordOrOtpData, core::two_factor::generate_recover_code,
}, },
auth::Headers, auth::Headers,
crypto, crypto,
db::{ db::{
models::{EventType, TwoFactor, TwoFactorType, User, UserId},
DbConn, DbConn,
models::{EventType, TwoFactor, TwoFactorType, User, UserId},
}, },
error::MapResult, error::MapResult,
http_client::make_http_request, http_client::make_http_request,
CONFIG,
}; };
pub fn routes() -> Vec<Route> { pub fn routes() -> Vec<Route> {
@ -82,8 +81,7 @@ enum DuoStatus {
impl DuoStatus { impl DuoStatus {
fn data(self) -> Option<DuoData> { fn data(self) -> Option<DuoData> {
match self { match self {
DuoStatus::Global(data) => Some(data), DuoStatus::Global(data) | DuoStatus::User(data) => Some(data),
DuoStatus::User(data) => Some(data),
DuoStatus::Disabled(_) => None, DuoStatus::Disabled(_) => None,
} }
} }
@ -182,7 +180,7 @@ async fn activate_duo(data: Json<EnableDuoData>, headers: Headers, conn: DbConn)
let twofactor = TwoFactor::new(user.uuid.clone(), type_, data_str); let twofactor = TwoFactor::new(user.uuid.clone(), type_, data_str);
twofactor.save(&conn).await?; twofactor.save(&conn).await?;
_generate_recover_code(&mut user, &conn).await; generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await; log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
@ -201,14 +199,14 @@ async fn activate_duo_put(data: Json<EnableDuoData>, headers: Headers, conn: DbC
} }
async fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) -> EmptyResult { async fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) -> EmptyResult {
use reqwest::{header, Method}; use reqwest::{Method, header};
use std::str::FromStr; use std::str::FromStr;
// https://duo.com/docs/authapi#api-details // https://duo.com/docs/authapi#api-details
let url = format!("https://{}{path}", &data.host); let url = format!("https://{}{path}", data.host);
let date = Utc::now().to_rfc2822(); let dt = Utc::now().to_rfc2822();
let username = &data.ik; let username = &data.ik;
let fields = [&date, method, &data.host, path, params]; let fields = [&dt, method, &data.host, path, params];
let password = crypto::hmac_sign(&data.sk, &fields.join("\n")); let password = crypto::hmac_sign(&data.sk, &fields.join("\n"));
let m = Method::from_str(method).unwrap_or_default(); let m = Method::from_str(method).unwrap_or_default();
@ -216,7 +214,7 @@ async fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData)
make_http_request(m, &url)? make_http_request(m, &url)?
.basic_auth(username, Some(password)) .basic_auth(username, Some(password))
.header(header::USER_AGENT, "vaultwarden:Duo/1.0 (Rust)") .header(header::USER_AGENT, "vaultwarden:Duo/1.0 (Rust)")
.header(header::DATE, date) .header(header::DATE, dt)
.send() .send()
.await? .await?
.error_for_status()?; .error_for_status()?;
@ -356,9 +354,10 @@ fn parse_duo_values(key: &str, val: &str, ikey: &str, prefix: &str, time: i64) -
err!("Invalid ikey") err!("Invalid ikey")
} }
let expire: i64 = match expire.parse() { let expire: i64 = if let Ok(e) = expire.parse() {
Ok(e) => e, e
Err(_) => err!("Invalid expire time"), } else {
err!("Invalid expire time")
}; };
if time >= expire { if time >= expire {

32
src/api/core/two_factor/duo_oidc.rs

@ -1,23 +1,24 @@
use std::collections::HashMap;
use chrono::Utc; use chrono::Utc;
use data_encoding::HEXLOWER; use data_encoding::HEXLOWER;
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation}; use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation};
use reqwest::{header, StatusCode}; use reqwest::{StatusCode, header};
use ring::digest::{digest, Digest, SHA512_256}; use ring::digest::{Digest, SHA512_256, digest};
use serde::Serialize; use serde::Serialize;
use std::collections::HashMap; use url::Url;
use crate::{ use crate::{
api::{core::two_factor::duo::get_duo_keys_email, EmptyResult}, CONFIG,
api::{EmptyResult, core::two_factor::duo::get_duo_keys_email},
crypto, crypto,
db::{ db::{
models::{DeviceId, EventType, TwoFactorDuoContext},
DbConn, DbPool, DbConn, DbPool,
models::{DeviceId, EventType, TwoFactorDuoContext},
}, },
error::Error, error::Error,
http_client::make_http_request, http_client::make_http_request,
CONFIG,
}; };
use url::Url;
// The location on this service that Duo should redirect users to. For us, this is a bridge // The location on this service that Duo should redirect users to. For us, this is a bridge
// built in to the Bitwarden clients. // built in to the Bitwarden clients.
@ -124,7 +125,7 @@ impl DuoClient {
ClientAssertion { ClientAssertion {
iss: self.client_id.clone(), iss: self.client_id.clone(),
sub: self.client_id.clone(), sub: self.client_id.clone(),
aud: url.to_string(), aud: url.to_owned(),
exp: now + JWT_VALIDITY_SECS, exp: now + JWT_VALIDITY_SECS,
jti: jwt_id, jti: jwt_id,
iat: now, iat: now,
@ -302,7 +303,7 @@ impl DuoClient {
if !(matching_nonces && matching_usernames) { if !(matching_nonces && matching_usernames) {
err!("Error validating Duo authorization, nonce or username mismatch.") err!("Error validating Duo authorization, nonce or username mismatch.")
}; }
Ok(()) Ok(())
} }
@ -347,7 +348,7 @@ pub async fn purge_duo_contexts(pool: DbPool) {
if let Ok(conn) = pool.get().await { if let Ok(conn) = pool.get().await {
TwoFactorDuoContext::purge_expired_duo_contexts(&conn).await; TwoFactorDuoContext::purge_expired_duo_contexts(&conn).await;
} else { } else {
error!("Failed to get DB connection while purging expired Duo authentications") error!("Failed to get DB connection while purging expired Duo authentications");
} }
} }
@ -394,7 +395,7 @@ pub async fn get_duo_auth_url(
match client.health_check().await { match client.health_check().await {
Ok(()) => {} Ok(()) => {}
Err(e) => return Err(e), Err(e) => return Err(e),
}; }
// Generate random OAuth2 state and OIDC Nonce // Generate random OAuth2 state and OIDC Nonce
let state: String = crypto::get_random_string_alphanum(STATE_LENGTH); let state: String = crypto::get_random_string_alphanum(STATE_LENGTH);
@ -438,16 +439,13 @@ pub async fn validate_duo_login(
// Get the context by the state reported by the client. If we don't have one, // Get the context by the state reported by the client. If we don't have one,
// it means the context is either missing or expired. // it means the context is either missing or expired.
let ctx = match extract_context(state, conn).await { let Some(ctx) = extract_context(state, conn).await else {
Some(c) => c,
None => {
err!( err!(
"Error validating duo authentication", "Error validating duo authentication",
ErrorEvent { ErrorEvent {
event: EventType::UserFailedLogIn2fa event: EventType::UserFailedLogIn2fa
} }
) )
}
}; };
// Context validation steps // Context validation steps
@ -476,13 +474,13 @@ pub async fn validate_duo_login(
match client.health_check().await { match client.health_check().await {
Ok(()) => {} Ok(()) => {}
Err(e) => return Err(e), Err(e) => return Err(e),
}; }
let d: Digest = digest(&SHA512_256, format!("{}{device_identifier}", ctx.nonce).as_bytes()); let d: Digest = digest(&SHA512_256, format!("{}{device_identifier}", ctx.nonce).as_bytes());
let hash: String = HEXLOWER.encode(d.as_ref()); let hash: String = HEXLOWER.encode(d.as_ref());
match client.exchange_authz_code_for_result(code, email, hash.as_str()).await { match client.exchange_authz_code_for_result(code, email, hash.as_str()).await {
Ok(_) => Ok(()), Ok(()) => Ok(()),
Err(_) => { Err(_) => {
err!( err!(
"Error validating duo authentication", "Error validating duo authentication",

34
src/api/core/two_factor/email.rs

@ -1,20 +1,20 @@
use chrono::{DateTime, TimeDelta, Utc}; use chrono::{DateTime, TimeDelta, Utc};
use rocket::serde::json::Json; use rocket::{Route, serde::json::Json};
use rocket::Route;
use crate::{ use crate::{
CONFIG,
api::{ api::{
core::{log_user_event, two_factor::_generate_recover_code},
EmptyResult, JsonResult, PasswordOrOtpData, EmptyResult, JsonResult, PasswordOrOtpData,
core::{log_user_event, two_factor::generate_recover_code},
}, },
auth::{ClientHeaders, Headers}, auth::{ClientHeaders, Headers},
crypto, crypto,
db::{ db::{
models::{AuthRequest, AuthRequestId, DeviceId, EventType, TwoFactor, TwoFactorType, User, UserId},
DbConn, DbConn,
models::{AuthRequest, AuthRequestId, DeviceId, EventType, TwoFactor, TwoFactorType, User, UserId},
}, },
error::{Error, MapResult}, error::{Error, MapResult},
mail, CONFIG, mail,
}; };
pub fn routes() -> Vec<Route> { pub fn routes() -> Vec<Route> {
@ -232,7 +232,7 @@ async fn email(data: Json<EmailData>, headers: Headers, conn: DbConn) -> JsonRes
twofactor.data = email_data.to_json(); twofactor.data = email_data.to_json();
twofactor.save(&conn).await?; twofactor.save(&conn).await?;
_generate_recover_code(&mut user, &conn).await; generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await; log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
@ -284,9 +284,9 @@ pub async fn validate_email_code_str(
twofactor.data = email_data.to_json(); twofactor.data = email_data.to_json();
twofactor.save(conn).await?; twofactor.save(conn).await?;
let date = DateTime::from_timestamp(email_data.token_sent, 0).expect("Email token timestamp invalid.").naive_utc(); let dt = DateTime::from_timestamp(email_data.token_sent, 0).expect("Email token timestamp invalid.").naive_utc();
let max_time = CONFIG.email_expiration_time() as i64; let max_time = CONFIG.email_expiration_time().cast_signed();
if date + TimeDelta::try_seconds(max_time).unwrap() < Utc::now().naive_utc() { if dt + TimeDelta::try_seconds(max_time).unwrap() < Utc::now().naive_utc() {
err!( err!(
"Token has expired", "Token has expired",
ErrorEvent { ErrorEvent {
@ -342,9 +342,10 @@ impl EmailTokenData {
pub fn from_json(string: &str) -> Result<EmailTokenData, Error> { pub fn from_json(string: &str) -> Result<EmailTokenData, Error> {
let res: Result<EmailTokenData, serde_json::Error> = serde_json::from_str(string); let res: Result<EmailTokenData, serde_json::Error> = serde_json::from_str(string);
match res { if let Ok(x) = res {
Ok(x) => Ok(x), Ok(x)
Err(_) => err!("Could not decode EmailTokenData from string"), } else {
err!("Could not decode EmailTokenData from string")
} }
} }
} }
@ -362,18 +363,17 @@ pub async fn activate_email_2fa(user: &User, conn: &DbConn) -> EmptyResult {
pub fn obscure_email(email: &str) -> String { pub fn obscure_email(email: &str) -> String {
let split: Vec<&str> = email.rsplitn(2, '@').collect(); let split: Vec<&str> = email.rsplitn(2, '@').collect();
let mut name = split[1].to_string(); let mut name = split[1].to_owned();
let domain = &split[0]; let domain = &split[0];
let name_size = name.chars().count(); let name_size = name.chars().count();
let new_name = match name_size { let new_name = if let 1..=3 = name_size {
1..=3 => "*".repeat(name_size), "*".repeat(name_size)
_ => { } else {
let stars = "*".repeat(name_size - 2); let stars = "*".repeat(name_size - 2);
name.truncate(2); name.truncate(2);
format!("{name}{stars}") format!("{name}{stars}")
}
}; };
format!("{new_name}@{domain}") format!("{new_name}@{domain}")

31
src/api/core/two_factor/mod.rs

@ -1,28 +1,27 @@
use chrono::{TimeDelta, Utc}; use chrono::{TimeDelta, Utc};
use data_encoding::BASE32; use data_encoding::BASE32;
use num_traits::FromPrimitive; use num_traits::FromPrimitive;
use rocket::serde::json::Json; use rocket::{Route, serde::json::Json};
use rocket::Route;
use serde::Deserialize; use serde::Deserialize;
use serde_json::Value; use serde_json::Value;
use crate::{ use crate::{
CONFIG,
api::{ api::{
core::{log_event, log_user_event},
EmptyResult, JsonResult, PasswordOrOtpData, EmptyResult, JsonResult, PasswordOrOtpData,
core::{log_event, log_user_event},
}, },
auth::Headers, auth::Headers,
crypto, crypto,
db::{ db::{
DbConn, DbPool,
models::{ models::{
DeviceType, EventType, Membership, MembershipType, OrgPolicyType, Organization, OrganizationId, TwoFactor, DeviceType, EventType, Membership, MembershipType, OrgPolicyType, Organization, OrganizationId, TwoFactor,
TwoFactorIncomplete, TwoFactorType, User, UserId, TwoFactorIncomplete, TwoFactorType, User, UserId,
}, },
DbConn, DbPool,
}, },
mail, mail,
util::NumberOrString, util::NumberOrString,
CONFIG,
}; };
pub mod authenticator; pub mod authenticator;
@ -37,7 +36,7 @@ fn has_global_duo_credentials() -> bool {
CONFIG._enable_duo() && CONFIG.duo_host().is_some() && CONFIG.duo_ikey().is_some() && CONFIG.duo_skey().is_some() CONFIG._enable_duo() && CONFIG.duo_host().is_some() && CONFIG.duo_ikey().is_some() && CONFIG.duo_skey().is_some()
} }
pub fn is_twofactor_provider_usable(provider_type: TwoFactorType, provider_data: Option<&str>) -> bool { pub fn is_twofactor_provider_usable(provider_type: &TwoFactorType, provider_data: Option<&str>) -> bool {
#[derive(Deserialize)] #[derive(Deserialize)]
struct DuoProviderData { struct DuoProviderData {
host: String, host: String,
@ -46,7 +45,7 @@ pub fn is_twofactor_provider_usable(provider_type: TwoFactorType, provider_data:
} }
match provider_type { match provider_type {
TwoFactorType::Authenticator => true, TwoFactorType::Authenticator | TwoFactorType::RecoveryCode => true,
TwoFactorType::Email => CONFIG._enable_email_2fa(), TwoFactorType::Email => CONFIG._enable_email_2fa(),
TwoFactorType::Duo | TwoFactorType::OrganizationDuo => { TwoFactorType::Duo | TwoFactorType::OrganizationDuo => {
provider_data provider_data
@ -59,7 +58,6 @@ pub fn is_twofactor_provider_usable(provider_type: TwoFactorType, provider_data:
} }
TwoFactorType::Webauthn => CONFIG.is_webauthn_2fa_supported(), TwoFactorType::Webauthn => CONFIG.is_webauthn_2fa_supported(),
TwoFactorType::Remember => !CONFIG.disable_2fa_remember(), TwoFactorType::Remember => !CONFIG.disable_2fa_remember(),
TwoFactorType::RecoveryCode => true,
TwoFactorType::U2f TwoFactorType::U2f
| TwoFactorType::U2fRegisterChallenge | TwoFactorType::U2fRegisterChallenge
| TwoFactorType::U2fLoginChallenge | TwoFactorType::U2fLoginChallenge
@ -96,7 +94,7 @@ async fn get_twofactor(headers: Headers, conn: DbConn) -> Json<Value> {
.iter() .iter()
.filter_map(|tf| { .filter_map(|tf| {
let provider_type = TwoFactorType::from_i32(tf.atype)?; let provider_type = TwoFactorType::from_i32(tf.atype)?;
is_twofactor_provider_usable(provider_type, Some(&tf.data)).then(|| TwoFactor::to_json_provider(tf)) is_twofactor_provider_usable(&provider_type, Some(&tf.data)).then(|| TwoFactor::to_json_provider(tf))
}) })
.collect(); .collect();
@ -120,7 +118,7 @@ async fn get_recover(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbCo
}))) })))
} }
async fn _generate_recover_code(user: &mut User, conn: &DbConn) { async fn generate_recover_code(user: &mut User, conn: &DbConn) {
if user.totp_recover.is_none() { if user.totp_recover.is_none() {
let totp_recover = crypto::encode_random_bytes::<20>(&BASE32); let totp_recover = crypto::encode_random_bytes::<20>(&BASE32);
user.totp_recover = Some(totp_recover); user.totp_recover = Some(totp_recover);
@ -180,9 +178,7 @@ pub async fn enforce_2fa_policy(
ip: &std::net::IpAddr, ip: &std::net::IpAddr,
conn: &DbConn, conn: &DbConn,
) -> EmptyResult { ) -> EmptyResult {
for member in for member in Membership::find_by_user_and_policy(&user.uuid, OrgPolicyType::TwoFactorAuthentication, conn).await {
Membership::find_by_user_and_policy(&user.uuid, OrgPolicyType::TwoFactorAuthentication, conn).await.into_iter()
{
// Policy only applies to non-Owner/non-Admin members who have accepted joining the org // Policy only applies to non-Owner/non-Admin members who have accepted joining the org
if member.atype < MembershipType::Admin { if member.atype < MembershipType::Admin {
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
@ -217,7 +213,7 @@ pub async fn enforce_2fa_policy_for_org(
conn: &DbConn, conn: &DbConn,
) -> EmptyResult { ) -> EmptyResult {
let org = Organization::find_by_uuid(org_id, conn).await.unwrap(); let org = Organization::find_by_uuid(org_id, conn).await.unwrap();
for member in Membership::find_confirmed_by_org(org_id, conn).await.into_iter() { for member in Membership::find_confirmed_by_org(org_id, conn).await {
// Don't enforce the policy for Admins and Owners. // Don't enforce the policy for Admins and Owners.
if member.atype < MembershipType::Admin && TwoFactor::find_by_user(&member.user_uuid, conn).await.is_empty() { if member.atype < MembershipType::Admin && TwoFactor::find_by_user(&member.user_uuid, conn).await.is_empty() {
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
@ -251,12 +247,9 @@ pub async fn send_incomplete_2fa_notifications(pool: DbPool) {
return; return;
} }
let conn = match pool.get().await { let Ok(conn) = pool.get().await else {
Ok(conn) => conn,
_ => {
error!("Failed to get DB connection in send_incomplete_2fa_notifications()"); error!("Failed to get DB connection in send_incomplete_2fa_notifications()");
return; return;
}
}; };
let now = Utc::now().naive_utc(); let now = Utc::now().naive_utc();
@ -278,7 +271,7 @@ pub async fn send_incomplete_2fa_notifications(pool: DbPool) {
) )
.await .await
{ {
Ok(_) => { Ok(()) => {
if let Err(e) = login.delete(&conn).await { if let Err(e) = login.delete(&conn).await {
error!("Error deleting incomplete 2FA record: {e:#?}"); error!("Error deleting incomplete 2FA record: {e:#?}");
} }

26
src/api/core/two_factor/protected_actions.rs

@ -1,16 +1,17 @@
use chrono::{naive::serde::ts_seconds, NaiveDateTime, TimeDelta, Utc}; use chrono::{NaiveDateTime, TimeDelta, Utc, naive::serde::ts_seconds};
use rocket::{serde::json::Json, Route}; use rocket::{Route, serde::json::Json};
use crate::{ use crate::{
CONFIG,
api::EmptyResult, api::EmptyResult,
auth::Headers, auth::Headers,
crypto, crypto,
db::{ db::{
models::{TwoFactor, TwoFactorType, UserId},
DbConn, DbConn,
models::{TwoFactor, TwoFactorType, UserId},
}, },
error::{Error, MapResult}, error::{Error, MapResult},
mail, CONFIG, mail,
}; };
pub fn routes() -> Vec<Route> { pub fn routes() -> Vec<Route> {
@ -44,9 +45,10 @@ impl ProtectedActionData {
pub fn from_json(string: &str) -> Result<Self, Error> { pub fn from_json(string: &str) -> Result<Self, Error> {
let res: Result<Self, serde_json::Error> = serde_json::from_str(string); let res: Result<Self, serde_json::Error> = serde_json::from_str(string);
match res { if let Ok(x) = res {
Ok(x) => Ok(x), Ok(x)
Err(_) => err!("Could not decode ProtectedActionData from string"), } else {
err!("Could not decode ProtectedActionData from string")
} }
} }
@ -62,7 +64,9 @@ impl ProtectedActionData {
#[post("/accounts/request-otp")] #[post("/accounts/request-otp")]
async fn request_otp(headers: Headers, conn: DbConn) -> EmptyResult { async fn request_otp(headers: Headers, conn: DbConn) -> EmptyResult {
if !CONFIG.mail_enabled() { if !CONFIG.mail_enabled() {
err!("Email is disabled for this server. Either enable email or login using your master password instead of login via device."); err!(
"Email is disabled for this server. Either enable email or login using your master password instead of login via device."
);
} }
let user = headers.user; let user = headers.user;
@ -102,7 +106,9 @@ struct ProtectedActionVerify {
#[post("/accounts/verify-otp", data = "<data>")] #[post("/accounts/verify-otp", data = "<data>")]
async fn verify_otp(data: Json<ProtectedActionVerify>, headers: Headers, conn: DbConn) -> EmptyResult { async fn verify_otp(data: Json<ProtectedActionVerify>, headers: Headers, conn: DbConn) -> EmptyResult {
if !CONFIG.mail_enabled() { if !CONFIG.mail_enabled() {
err!("Email is disabled for this server. Either enable email or login using your master password instead of login via device."); err!(
"Email is disabled for this server. Either enable email or login using your master password instead of login via device."
);
} }
let user = headers.user; let user = headers.user;
@ -133,7 +139,7 @@ pub async fn validate_protected_action_otp(
} }
// Check if the token has expired (Using the email 2fa expiration time) // Check if the token has expired (Using the email 2fa expiration time)
let max_time = CONFIG.email_expiration_time() as i64; let max_time = CONFIG.email_expiration_time().cast_signed();
if pa_data.time_since_sent().num_seconds() > max_time { if pa_data.time_since_sent().num_seconds() > max_time {
pa.delete(conn).await?; pa.delete(conn).await?;
err!("Token has expired") err!("Token has expired")

68
src/api/core/two_factor/webauthn.rs

@ -1,32 +1,33 @@
use std::{str::FromStr, sync::LazyLock, time::Duration};
use rocket::{Route, serde::json::Json};
use serde_json::Value;
use url::Url;
use uuid::Uuid;
use webauthn_rs::{
Webauthn, WebauthnBuilder,
prelude::{Base64UrlSafeData, Credential, Passkey, PasskeyAuthentication, PasskeyRegistration},
};
use webauthn_rs_proto::{
AuthenticationExtensionsClientOutputs, AuthenticatorAssertionResponseRaw, AuthenticatorAttestationResponseRaw,
PublicKeyCredential, RegisterPublicKeyCredential, RegistrationExtensionsClientOutputs,
RequestAuthenticationExtensions, UserVerificationPolicy,
};
use crate::{ use crate::{
CONFIG,
api::{ api::{
core::{log_user_event, two_factor::_generate_recover_code},
EmptyResult, JsonResult, PasswordOrOtpData, EmptyResult, JsonResult, PasswordOrOtpData,
core::{log_user_event, two_factor::generate_recover_code},
}, },
auth::Headers, auth::Headers,
crypto::ct_eq, crypto::ct_eq,
db::{ db::{
models::{EventType, TwoFactor, TwoFactorType, UserId},
DbConn, DbConn,
models::{EventType, TwoFactor, TwoFactorType, UserId},
}, },
error::Error, error::Error,
util::NumberOrString, util::NumberOrString,
CONFIG,
};
use rocket::serde::json::Json;
use rocket::Route;
use serde_json::Value;
use std::str::FromStr;
use std::sync::LazyLock;
use std::time::Duration;
use url::Url;
use uuid::Uuid;
use webauthn_rs::prelude::{Base64UrlSafeData, Credential, Passkey, PasskeyAuthentication, PasskeyRegistration};
use webauthn_rs::{Webauthn, WebauthnBuilder};
use webauthn_rs_proto::{
AuthenticationExtensionsClientOutputs, AuthenticatorAssertionResponseRaw, AuthenticatorAttestationResponseRaw,
PublicKeyCredential, RegisterPublicKeyCredential, RegistrationExtensionsClientOutputs,
RequestAuthenticationExtensions, UserVerificationPolicy,
}; };
static WEBAUTHN: LazyLock<Webauthn> = LazyLock::new(|| { static WEBAUTHN: LazyLock<Webauthn> = LazyLock::new(|| {
@ -149,7 +150,7 @@ async fn generate_webauthn_challenge(data: Json<PasswordOrOtpData>, headers: Hea
)?; )?;
let mut state = serde_json::to_value(&state)?; let mut state = serde_json::to_value(&state)?;
state["rs"]["policy"] = Value::String("discouraged".to_string()); state["rs"]["policy"] = Value::String("discouraged".to_owned());
state["rs"]["extensions"].as_object_mut().unwrap().clear(); state["rs"]["extensions"].as_object_mut().unwrap().clear();
let type_ = TwoFactorType::WebauthnRegisterChallenge; let type_ = TwoFactorType::WebauthnRegisterChallenge;
@ -265,13 +266,12 @@ async fn activate_webauthn(data: Json<EnableWebauthnData>, headers: Headers, con
// Retrieve and delete the saved challenge state // Retrieve and delete the saved challenge state
let type_ = TwoFactorType::WebauthnRegisterChallenge as i32; let type_ = TwoFactorType::WebauthnRegisterChallenge as i32;
let state = match TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await { let state = if let Some(tf) = TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await {
Some(tf) => {
let state: PasskeyRegistration = serde_json::from_str(&tf.data)?; let state: PasskeyRegistration = serde_json::from_str(&tf.data)?;
tf.delete(&conn).await?; tf.delete(&conn).await?;
state state
} } else {
None => err!("Can't recover challenge"), err!("Can't recover challenge")
}; };
// Verify the credentials with the saved state // Verify the credentials with the saved state
@ -291,7 +291,7 @@ async fn activate_webauthn(data: Json<EnableWebauthnData>, headers: Headers, con
TwoFactor::new(user.uuid.clone(), TwoFactorType::Webauthn, serde_json::to_string(&registrations)?) TwoFactor::new(user.uuid.clone(), TwoFactorType::Webauthn, serde_json::to_string(&registrations)?)
.save(&conn) .save(&conn)
.await?; .await?;
_generate_recover_code(&mut user, &conn).await; generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await; log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
@ -342,9 +342,10 @@ async fn delete_webauthn(data: Json<DeleteU2FData>, headers: Headers, conn: DbCo
// If entry is migrated from u2f, delete the u2f entry as well // If entry is migrated from u2f, delete the u2f entry as well
if let Some(mut u2f) = TwoFactor::find_by_user_and_type(&headers.user.uuid, TwoFactorType::U2f as i32, &conn).await if let Some(mut u2f) = TwoFactor::find_by_user_and_type(&headers.user.uuid, TwoFactorType::U2f as i32, &conn).await
{ {
let mut data: Vec<U2FRegistration> = match serde_json::from_str(&u2f.data) { let mut data: Vec<U2FRegistration> = if let Ok(d) = serde_json::from_str(&u2f.data) {
Ok(d) => d, d
Err(_) => err!("Error parsing U2F data"), } else {
err!("Error parsing U2F data")
}; };
data.retain(|r| r.reg.key_handle != removed_item.credential.cred_id().as_slice()); data.retain(|r| r.reg.key_handle != removed_item.credential.cred_id().as_slice());
@ -388,10 +389,10 @@ pub async fn generate_webauthn_login(user_id: &UserId, conn: &DbConn) -> JsonRes
// Modify to discourage user verification // Modify to discourage user verification
let mut state = serde_json::to_value(&state)?; let mut state = serde_json::to_value(&state)?;
state["ast"]["policy"] = Value::String("discouraged".to_string()); state["ast"]["policy"] = Value::String("discouraged".to_owned());
// Add appid, this is only needed for U2F compatibility, so maybe it can be removed as well // Add appid, this is only needed for U2F compatibility, so maybe it can be removed as well
let app_id = format!("{}/app-id.json", &CONFIG.domain()); let app_id = format!("{}/app-id.json", CONFIG.domain());
state["ast"]["appid"] = Value::String(app_id.clone()); state["ast"]["appid"] = Value::String(app_id.clone());
response.public_key.user_verification = UserVerificationPolicy::Discouraged_DO_NOT_USE; response.public_key.user_verification = UserVerificationPolicy::Discouraged_DO_NOT_USE;
@ -416,18 +417,17 @@ pub async fn generate_webauthn_login(user_id: &UserId, conn: &DbConn) -> JsonRes
pub async fn validate_webauthn_login(user_id: &UserId, response: &str, conn: &DbConn) -> EmptyResult { pub async fn validate_webauthn_login(user_id: &UserId, response: &str, conn: &DbConn) -> EmptyResult {
let type_ = TwoFactorType::WebauthnLoginChallenge as i32; let type_ = TwoFactorType::WebauthnLoginChallenge as i32;
let mut state = match TwoFactor::find_by_user_and_type(user_id, type_, conn).await { let mut state = if let Some(tf) = TwoFactor::find_by_user_and_type(user_id, type_, conn).await {
Some(tf) => {
let state: PasskeyAuthentication = serde_json::from_str(&tf.data)?; let state: PasskeyAuthentication = serde_json::from_str(&tf.data)?;
tf.delete(conn).await?; tf.delete(conn).await?;
state state
} } else {
None => err!( err!(
"Can't recover login challenge", "Can't recover login challenge",
ErrorEvent { ErrorEvent {
event: EventType::UserFailedLogIn2fa event: EventType::UserFailedLogIn2fa
} }
), )
}; };
let rsp: PublicKeyCredentialCopy = serde_json::from_str(response)?; let rsp: PublicKeyCredentialCopy = serde_json::from_str(response)?;

20
src/api/core/two_factor/yubikey.rs

@ -1,20 +1,19 @@
use rocket::serde::json::Json; use rocket::{Route, serde::json::Json};
use rocket::Route;
use serde_json::Value; use serde_json::Value;
use yubico::{config::Config, verify_async}; use yubico::{config::Config, verify_async};
use crate::{ use crate::{
CONFIG,
api::{ api::{
core::{log_user_event, two_factor::_generate_recover_code},
EmptyResult, JsonResult, PasswordOrOtpData, EmptyResult, JsonResult, PasswordOrOtpData,
core::{log_user_event, two_factor::generate_recover_code},
}, },
auth::Headers, auth::Headers,
db::{ db::{
models::{EventType, TwoFactor, TwoFactorType},
DbConn, DbConn,
models::{EventType, TwoFactor, TwoFactorType},
}, },
error::{Error, MapResult}, error::{Error, MapResult},
CONFIG,
}; };
pub fn routes() -> Vec<Route> { pub fn routes() -> Vec<Route> {
@ -46,7 +45,7 @@ pub struct YubikeyMetadata {
fn parse_yubikeys(data: &EnableYubikeyData) -> Vec<String> { fn parse_yubikeys(data: &EnableYubikeyData) -> Vec<String> {
let data_keys = [&data.key1, &data.key2, &data.key3, &data.key4, &data.key5]; let data_keys = [&data.key1, &data.key2, &data.key3, &data.key4, &data.key5];
data_keys.iter().filter_map(|e| e.as_ref().cloned()).collect() data_keys.into_iter().flatten().cloned().collect()
} }
fn jsonify_yubikeys(yubikeys: Vec<String>) -> Value { fn jsonify_yubikeys(yubikeys: Vec<String>) -> Value {
@ -64,9 +63,10 @@ fn get_yubico_credentials() -> Result<(String, String), Error> {
err!("Yubico support is disabled"); err!("Yubico support is disabled");
} }
match (CONFIG.yubico_client_id(), CONFIG.yubico_secret_key()) { if let (Some(id), Some(secret)) = (CONFIG.yubico_client_id(), CONFIG.yubico_secret_key()) {
(Some(id), Some(secret)) => Ok((id, secret)), Ok((id, secret))
_ => err!("`YUBICO_CLIENT_ID` or `YUBICO_SECRET_KEY` environment variable is not set. Yubikey OTP Disabled"), } else {
err!("`YUBICO_CLIENT_ID` or `YUBICO_SECRET_KEY` environment variable is not set. Yubikey OTP Disabled")
} }
} }
@ -162,7 +162,7 @@ async fn activate_yubikey(data: Json<EnableYubikeyData>, headers: Headers, conn:
yubikey_data.data = serde_json::to_string(&yubikey_metadata).unwrap(); yubikey_data.data = serde_json::to_string(&yubikey_metadata).unwrap();
yubikey_data.save(&conn).await?; yubikey_data.save(&conn).await?;
_generate_recover_code(&mut user, &conn).await; generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await; log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;

55
src/api/icons.rs

@ -6,28 +6,29 @@ use std::{
}; };
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::{stream::StreamExt, TryFutureExt}; use futures::{TryFutureExt, stream::StreamExt};
use html5gum::{Emitter, HtmlString, Readable, StringReader, Tokenizer}; use html5gum::{Emitter, HtmlString, Readable, StringReader, Tokenizer};
use regex::Regex; use regex::Regex;
use reqwest::{ use reqwest::{
header::{self, HeaderMap, HeaderValue},
Client, Response, Client, Response,
header::{self, HeaderMap, HeaderValue},
}; };
use rocket::{http::ContentType, response::Redirect, Route}; use rocket::{Route, http::ContentType, response::Redirect};
use svg_hush::{data_url_filter, Filter}; use svg_hush::{Filter, data_url_filter};
use crate::{ use crate::{
CONFIG,
config::PathType, config::PathType,
error::Error, error::Error,
http_client::{get_reqwest_client_builder, get_valid_host, should_block_host, CustomHttpClientError}, http_client::{CustomHttpClientError, get_reqwest_client_builder, get_valid_host, should_block_host},
util::Cached, util::Cached,
CONFIG,
}; };
pub fn routes() -> Vec<Route> { pub fn routes() -> Vec<Route> {
match CONFIG.icon_service().as_str() { if CONFIG.icon_service().as_str() == "internal" {
"internal" => routes![icon_internal], routes![icon_internal]
_ => routes![icon_external], } else {
routes![icon_external]
} }
} }
@ -147,7 +148,7 @@ async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
if let Some(icon) = get_cached_icon(&path).await { if let Some(icon) = get_cached_icon(&path).await {
let icon_type = get_icon_type(&icon).unwrap_or("x-icon"); let icon_type = get_icon_type(&icon).unwrap_or("x-icon");
return Some((icon, icon_type.to_string())); return Some((icon, icon_type.to_owned()));
} }
if CONFIG.disable_icon_download() { if CONFIG.disable_icon_download() {
@ -158,7 +159,7 @@ async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
match download_icon(domain).await { match download_icon(domain).await {
Ok((icon, icon_type)) => { Ok((icon, icon_type)) => {
save_icon(&path, icon.to_vec()).await; save_icon(&path, icon.to_vec()).await;
Some((icon.to_vec(), icon_type.unwrap_or("x-icon").to_string())) Some((icon.to_vec(), icon_type.unwrap_or("x-icon").to_owned()))
} }
Err(e) => { Err(e) => {
// If this error comes from the custom resolver, this means this is a blocked domain // If this error comes from the custom resolver, this means this is a blocked domain
@ -183,11 +184,11 @@ async fn get_cached_icon(path: &str) -> Option<Vec<u8>> {
} }
// Try to read the cached icon, and return it if it exists // Try to read the cached icon, and return it if it exists
if let Ok(operator) = CONFIG.opendal_operator_for_path_type(&PathType::IconCache) { if let Ok(operator) = CONFIG.opendal_operator_for_path_type(&PathType::IconCache)
if let Ok(buf) = operator.read(path).await { && let Ok(buf) = operator.read(path).await
{
return Some(buf.to_vec()); return Some(buf.to_vec());
} }
}
None None
} }
@ -280,8 +281,9 @@ fn get_favicons_node(dom: Tokenizer<StringReader<'_>, FaviconEmitter>, icons: &m
} }
for icon_tag in icon_tags { for icon_tag in icon_tags {
if let Some(icon_href) = icon_tag.attributes.get(ATTR_HREF) { if let Some(icon_href) = icon_tag.attributes.get(ATTR_HREF)
if let Ok(full_href) = base_url.join(std::str::from_utf8(icon_href).unwrap_or_default()) { && let Ok(full_href) = base_url.join(std::str::from_utf8(icon_href).unwrap_or_default())
{
let sizes = if let Some(v) = icon_tag.attributes.get(ATTR_SIZES) { let sizes = if let Some(v) = icon_tag.attributes.get(ATTR_SIZES) {
std::str::from_utf8(v).unwrap_or_default() std::str::from_utf8(v).unwrap_or_default()
} else { } else {
@ -290,7 +292,6 @@ fn get_favicons_node(dom: Tokenizer<StringReader<'_>, FaviconEmitter>, icons: &m
let priority = get_icon_priority(full_href.as_str(), sizes); let priority = get_icon_priority(full_href.as_str(), sizes);
icons.push(Icon::new(priority, full_href.to_string())); icons.push(Icon::new(priority, full_href.to_string()));
} }
};
} }
} }
@ -406,7 +407,7 @@ async fn get_page(url: &str) -> Result<Response, Error> {
async fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> { async fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
let mut client = CLIENT.get(url); let mut client = CLIENT.get(url);
if !referer.is_empty() { if !referer.is_empty() {
client = client.header("Referer", referer) client = client.header("Referer", referer);
} }
Ok(client.send().await?.error_for_status()?) Ok(client.send().await?.error_for_status()?)
@ -494,12 +495,10 @@ async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> {
let mut buffer = Bytes::new(); let mut buffer = Bytes::new();
let mut icon_type: Option<&str> = None; let mut icon_type: Option<&str> = None;
use data_url::DataUrl;
let mut icons = icon_result.iconlist.iter().take(5).peekable(); let mut icons = icon_result.iconlist.iter().take(5).peekable();
while let Some(icon) = icons.next() { while let Some(icon) = icons.next() {
if icon.href.starts_with("data:image") { if icon.href.starts_with("data:image") {
let Ok(datauri) = DataUrl::process(&icon.href) else { let Ok(datauri) = data_url::DataUrl::process(&icon.href) else {
continue; continue;
}; };
// Check if we are able to decode the data uri // Check if we are able to decode the data uri
@ -523,7 +522,7 @@ async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> {
} }
} }
_ => debug!("Extracted icon from data:image uri is invalid"), _ => debug!("Extracted icon from data:image uri is invalid"),
}; }
} else { } else {
debug!("Trying {}", icon.href); debug!("Trying {}", icon.href);
// Make sure all icons are checked before returning error // Make sure all icons are checked before returning error
@ -587,11 +586,11 @@ async fn save_icon(path: &str, icon: Vec<u8>) {
fn get_icon_type(bytes: &[u8]) -> Option<&'static str> { fn get_icon_type(bytes: &[u8]) -> Option<&'static str> {
fn check_svg_after_xml_declaration(bytes: &[u8]) -> Option<&'static str> { fn check_svg_after_xml_declaration(bytes: &[u8]) -> Option<&'static str> {
// Look for SVG tag within the first 1KB // Look for SVG tag within the first 1KB
if let Ok(content) = std::str::from_utf8(&bytes[..bytes.len().min(1024)]) { if let Ok(content) = std::str::from_utf8(&bytes[..bytes.len().min(1024)])
if content.contains("<svg") || content.contains("<SVG") { && (content.contains("<svg") || content.contains("<SVG"))
{
return Some("svg+xml"); return Some("svg+xml");
} }
}
None None
} }
@ -733,7 +732,7 @@ impl FaviconEmitter {
let rel_value = let rel_value =
std::str::from_utf8(token.tag.attributes.get(ATTR_REL).unwrap()).unwrap_or_default(); std::str::from_utf8(token.tag.attributes.get(ATTR_REL).unwrap()).unwrap_or_default();
if rel_value.contains("icon") && !rel_value.contains("mask-icon") { if rel_value.contains("icon") && !rel_value.contains("mask-icon") {
self.emit_token = true self.emit_token = true;
} }
} }
_ => (), _ => (),
@ -806,13 +805,13 @@ impl Emitter for FaviconEmitter {
fn push_attribute_name(&mut self, s: &[u8]) { fn push_attribute_name(&mut self, s: &[u8]) {
if let Some(attr) = &mut self.current_attribute { if let Some(attr) = &mut self.current_attribute {
attr.0.extend(s) attr.0.extend(s);
} }
} }
fn push_attribute_value(&mut self, s: &[u8]) { fn push_attribute_value(&mut self, s: &[u8]) {
if let Some(attr) = &mut self.current_attribute { if let Some(attr) = &mut self.current_attribute {
attr.1.extend(s) attr.1.extend(s);
} }
} }

206
src/api/identity.rs

@ -1,18 +1,20 @@
use chrono::Utc; use chrono::Utc;
use num_traits::FromPrimitive; use num_traits::FromPrimitive;
use rocket::{ use rocket::{
Route,
form::{Form, FromForm}, form::{Form, FromForm},
http::{Cookie, CookieJar, SameSite}, http::{Cookie, CookieJar, SameSite},
response::Redirect, response::Redirect,
serde::json::Json, serde::json::Json,
Route,
}; };
use serde_json::Value; use serde_json::Value;
use crate::{ use crate::{
CONFIG,
api::{ api::{
ApiResult, EmptyResult, JsonResult,
core::{ core::{
accounts::{_prelogin, _register, kdf_upgrade, PreloginData, RegisterData}, accounts::{PreloginData, RegisterData, kdf_upgrade, prelogin, register},
log_user_event, log_user_event,
two_factor::{ two_factor::{
authenticator, duo, duo_oidc, email, enforce_2fa_policy, is_twofactor_provider_usable, webauthn, authenticator, duo, duo_oidc, email, enforce_2fa_policy, is_twofactor_provider_usable, webauthn,
@ -21,29 +23,28 @@ use crate::{
}, },
master_password_policy, master_password_policy,
push::register_push_device, push::register_push_device,
ApiResult, EmptyResult, JsonResult,
}, },
auth, auth,
auth::{generate_organization_api_key_login_claims, AuthMethod, ClientHeaders, ClientIp, ClientVersion, Secure}, auth::{AuthMethod, ClientHeaders, ClientIp, ClientVersion, Secure, generate_organization_api_key_login_claims},
crypto, crypto,
db::{ db::{
DbConn,
models::{ models::{
AuthRequest, AuthRequestId, Device, DeviceId, EventType, Invitation, OIDCCodeResponseError, AuthRequest, AuthRequestId, Device, DeviceId, EventType, Invitation, OIDCCodeResponseError,
OrganizationApiKey, OrganizationId, SsoAuth, SsoUser, TwoFactor, TwoFactorIncomplete, TwoFactorType, User, OrganizationApiKey, OrganizationId, SsoAuth, SsoUser, TwoFactor, TwoFactorIncomplete, TwoFactorType, User,
UserId, UserId,
}, },
DbConn,
}, },
error::MapResult, error::MapResult,
mail, sso, mail, sso,
sso::{OIDCCode, OIDCCodeChallenge, OIDCCodeVerifier, OIDCState}, sso::{OIDCCode, OIDCCodeChallenge, OIDCCodeVerifier, OIDCState},
util, CONFIG, util,
}; };
pub fn routes() -> Vec<Route> { pub fn routes() -> Vec<Route> {
routes![ routes![
login, login,
prelogin, post_prelogin,
prelogin_password, prelogin_password,
identity_register, identity_register,
register_verification_email, register_verification_email,
@ -68,43 +69,43 @@ async fn login(
let login_result = match data.grant_type.as_ref() { let login_result = match data.grant_type.as_ref() {
"refresh_token" => { "refresh_token" => {
_check_is_some(data.refresh_token.as_ref(), "refresh_token cannot be blank")?; check_is_some(data.refresh_token.as_ref(), "refresh_token cannot be blank")?;
_refresh_login(data, &conn, &client_header.ip).await refresh_login(data, &conn, &client_header.ip).await
} }
"password" if CONFIG.sso_enabled() && CONFIG.sso_only() => err!("SSO sign-in is required"), "password" if CONFIG.sso_enabled() && CONFIG.sso_only() => err!("SSO sign-in is required"),
"password" => { "password" => {
_check_is_some(data.client_id.as_ref(), "client_id cannot be blank")?; check_is_some(data.client_id.as_ref(), "client_id cannot be blank")?;
_check_is_some(data.password.as_ref(), "password cannot be blank")?; check_is_some(data.password.as_ref(), "password cannot be blank")?;
_check_is_some(data.scope.as_ref(), "scope cannot be blank")?; check_is_some(data.scope.as_ref(), "scope cannot be blank")?;
_check_is_some(data.username.as_ref(), "username cannot be blank")?; check_is_some(data.username.as_ref(), "username cannot be blank")?;
_check_is_some(data.device_identifier.as_ref(), "device_identifier cannot be blank")?; check_is_some(data.device_identifier.as_ref(), "device_identifier cannot be blank")?;
_check_is_some(data.device_name.as_ref(), "device_name cannot be blank")?; check_is_some(data.device_name.as_ref(), "device_name cannot be blank")?;
_check_is_some(data.device_type.as_ref(), "device_type cannot be blank")?; check_is_some(data.device_type.as_ref(), "device_type cannot be blank")?;
_password_login(data, &mut user_id, &conn, &client_header.ip, client_version.as_ref()).await password_login(data, &mut user_id, &conn, &client_header.ip, client_version.as_ref()).await
} }
"client_credentials" => { "client_credentials" => {
_check_is_some(data.client_id.as_ref(), "client_id cannot be blank")?; check_is_some(data.client_id.as_ref(), "client_id cannot be blank")?;
_check_is_some(data.client_secret.as_ref(), "client_secret cannot be blank")?; check_is_some(data.client_secret.as_ref(), "client_secret cannot be blank")?;
_check_is_some(data.scope.as_ref(), "scope cannot be blank")?; check_is_some(data.scope.as_ref(), "scope cannot be blank")?;
_check_is_some(data.device_identifier.as_ref(), "device_identifier cannot be blank")?; check_is_some(data.device_identifier.as_ref(), "device_identifier cannot be blank")?;
_check_is_some(data.device_name.as_ref(), "device_name cannot be blank")?; check_is_some(data.device_name.as_ref(), "device_name cannot be blank")?;
_check_is_some(data.device_type.as_ref(), "device_type cannot be blank")?; check_is_some(data.device_type.as_ref(), "device_type cannot be blank")?;
_api_key_login(data, &mut user_id, &conn, &client_header.ip).await api_key_login(data, &mut user_id, &conn, &client_header.ip).await
} }
"authorization_code" if CONFIG.sso_enabled() => { "authorization_code" if CONFIG.sso_enabled() => {
_check_is_some(data.client_id.as_ref(), "client_id cannot be blank")?; check_is_some(data.client_id.as_ref(), "client_id cannot be blank")?;
_check_is_some(data.code.as_ref(), "code cannot be blank")?; check_is_some(data.code.as_ref(), "code cannot be blank")?;
_check_is_some(data.code_verifier.as_ref(), "code verifier cannot be blank")?; check_is_some(data.code_verifier.as_ref(), "code verifier cannot be blank")?;
_check_is_some(data.device_identifier.as_ref(), "device_identifier cannot be blank")?; check_is_some(data.device_identifier.as_ref(), "device_identifier cannot be blank")?;
_check_is_some(data.device_name.as_ref(), "device_name cannot be blank")?; check_is_some(data.device_name.as_ref(), "device_name cannot be blank")?;
_check_is_some(data.device_type.as_ref(), "device_type cannot be blank")?; check_is_some(data.device_type.as_ref(), "device_type cannot be blank")?;
_sso_login(data, &mut user_id, &conn, &client_header.ip, client_version.as_ref()).await sso_login(data, &mut user_id, &conn, &client_header.ip, client_version.as_ref()).await
} }
"authorization_code" => err!("SSO sign-in is not available"), "authorization_code" => err!("SSO sign-in is not available"),
t => err!("Invalid type", t), t => err!("Invalid type", t),
@ -125,7 +126,7 @@ async fn login(
Err(e) => { Err(e) => {
if let Some(ev) = e.get_event() { if let Some(ev) = e.get_event() {
log_user_event(ev.event as i32, &user_id, client_header.device_type, &client_header.ip.ip, &conn) log_user_event(ev.event as i32, &user_id, client_header.device_type, &client_header.ip.ip, &conn)
.await .await;
} }
} }
} }
@ -134,7 +135,7 @@ async fn login(
login_result login_result
} }
async fn _refresh_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> JsonResult { async fn refresh_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> JsonResult {
// When a refresh token is invalid or missing we need to respond with an HTTP BadRequest (400) // When a refresh token is invalid or missing we need to respond with an HTTP BadRequest (400)
// It also needs to return a json which holds at least a key `error` with the value `invalid_grant` // It also needs to return a json which holds at least a key `error` with the value `invalid_grant`
// See the link below for details // See the link below for details
@ -175,7 +176,7 @@ async fn _refresh_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> Json
} }
// After exchanging the code we need to check first if 2FA is needed before continuing // After exchanging the code we need to check first if 2FA is needed before continuing
async fn _sso_login( async fn sso_login(
data: ConnectData, data: ConnectData,
user_id: &mut Option<UserId>, user_id: &mut Option<UserId>,
conn: &DbConn, conn: &DbConn,
@ -344,7 +345,7 @@ async fn _sso_login(
authenticated_response(&user, &mut device, auth_tokens, twofactor_token, conn, ip).await authenticated_response(&user, &mut device, auth_tokens, twofactor_token, conn, ip).await
} }
async fn _password_login( async fn password_login(
data: ConnectData, data: ConnectData,
user_id: &mut Option<UserId>, user_id: &mut Option<UserId>,
conn: &DbConn, conn: &DbConn,
@ -428,9 +429,9 @@ async fn _password_login(
if user.verified_at.is_none() && CONFIG.mail_enabled() && CONFIG.signups_verify() { if user.verified_at.is_none() && CONFIG.mail_enabled() && CONFIG.signups_verify() {
if user.last_verifying_at.is_none() if user.last_verifying_at.is_none()
|| now.signed_duration_since(user.last_verifying_at.unwrap()).num_seconds() || now.signed_duration_since(user.last_verifying_at.unwrap()).num_seconds()
> CONFIG.signups_verify_resend_time() as i64 > CONFIG.signups_verify_resend_time().cast_signed()
{ {
let resend_limit = CONFIG.signups_verify_resend_limit() as i32; let resend_limit = CONFIG.signups_verify_resend_limit().cast_signed();
if resend_limit == 0 || user.login_verify_count < resend_limit { if resend_limit == 0 || user.login_verify_count < resend_limit {
// We want to send another email verification if we require signups to verify // We want to send another email verification if we require signups to verify
// their email address, and we haven't sent them a reminder in a while... // their email address, and we haven't sent them a reminder in a while...
@ -566,19 +567,19 @@ async fn authenticated_response(
Ok(Json(result)) Ok(Json(result))
} }
async fn _api_key_login(data: ConnectData, user_id: &mut Option<UserId>, conn: &DbConn, ip: &ClientIp) -> JsonResult { async fn api_key_login(data: ConnectData, user_id: &mut Option<UserId>, conn: &DbConn, ip: &ClientIp) -> JsonResult {
// Ratelimit the login // Ratelimit the login
crate::ratelimit::check_limit_login(&ip.ip)?; crate::ratelimit::check_limit_login(&ip.ip)?;
// Validate scope // Validate scope
match data.scope.as_ref() { match data.scope.as_ref() {
Some(scope) if scope == &AuthMethod::UserApiKey.scope() => _user_api_key_login(data, user_id, conn, ip).await, Some(scope) if scope == &AuthMethod::UserApiKey.scope() => user_api_key_login(data, user_id, conn, ip).await,
Some(scope) if scope == &AuthMethod::OrgApiKey.scope() => _organization_api_key_login(data, conn, ip).await, Some(scope) if scope == &AuthMethod::OrgApiKey.scope() => organization_api_key_login(data, conn, ip).await,
_ => err!("Scope not supported"), _ => err!("Scope not supported"),
} }
} }
async fn _user_api_key_login( async fn user_api_key_login(
data: ConnectData, data: ConnectData,
user_id: &mut Option<UserId>, user_id: &mut Option<UserId>,
conn: &DbConn, conn: &DbConn,
@ -710,13 +711,13 @@ async fn _user_api_key_login(
Ok(Json(result)) Ok(Json(result))
} }
async fn _organization_api_key_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> JsonResult { async fn organization_api_key_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> JsonResult {
// Get the org via the client_id // Get the org via the client_id
let client_id = data.client_id.as_ref().unwrap(); let client_id = data.client_id.as_ref().unwrap();
let Some(org_id) = client_id.strip_prefix("organization.") else { let Some(org_id) = client_id.strip_prefix("organization.") else {
err!("Malformed client_id", format!("IP: {}.", ip.ip)) err!("Malformed client_id", format!("IP: {}.", ip.ip))
}; };
let org_id: OrganizationId = org_id.to_string().into(); let org_id: OrganizationId = org_id.to_owned().into();
let Some(org_api_key) = OrganizationApiKey::find_by_org_uuid(&org_id, conn).await else { let Some(org_api_key) = OrganizationApiKey::find_by_org_uuid(&org_id, conn).await else {
err!("Invalid client_id", format!("IP: {}.", ip.ip)) err!("Invalid client_id", format!("IP: {}.", ip.ip))
}; };
@ -747,15 +748,14 @@ async fn get_device(data: &ConnectData, conn: &DbConn, user: &User) -> ApiResult
let device_name = data.device_name.clone().expect("No device name provided"); let device_name = data.device_name.clone().expect("No device name provided");
// Find device or create new // Find device or create new
match Device::find_by_uuid_and_user(&device_id, &user.uuid, conn).await { if let Some(device) = Device::find_by_uuid_and_user(&device_id, &user.uuid, conn).await {
Some(device) => Ok(device), Ok(device)
None => { } else {
let mut device = Device::new(device_id, user.uuid.clone(), device_name, device_type); let mut device = Device::new(device_id, user.uuid.clone(), device_name, device_type);
// save device without updating `device.updated_at` // save device without updating `device.updated_at`
device.save(false, conn).await?; device.save(false, conn).await?;
Ok(device) Ok(device)
} }
}
} }
async fn twofactor_auth( async fn twofactor_auth(
@ -780,7 +780,7 @@ async fn twofactor_auth(
.iter() .iter()
.filter_map(|tf| { .filter_map(|tf| {
let provider_type = TwoFactorType::from_i32(tf.atype)?; let provider_type = TwoFactorType::from_i32(tf.atype)?;
(tf.enabled && is_twofactor_provider_usable(provider_type, Some(&tf.data))).then_some(tf.atype) (tf.enabled && is_twofactor_provider_usable(&provider_type, Some(&tf.data))).then_some(tf.atype)
}) })
.collect(); .collect();
if twofactor_ids.is_empty() { if twofactor_ids.is_empty() {
@ -793,40 +793,33 @@ async fn twofactor_auth(
&& !twofactor_ids.contains(&selected_id) && !twofactor_ids.contains(&selected_id)
{ {
err_json!( err_json!(
_json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?, json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?,
"Invalid two factor provider" "Invalid two factor provider"
) )
} }
let twofactor_code = match data.two_factor_token { let Some(ref twofactor_code) = data.two_factor_token else {
Some(ref code) => code,
None => {
err_json!( err_json!(
_json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?, json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?,
"2FA token not provided" "2FA token not provided"
) )
}
}; };
let selected_twofactor = twofactors.into_iter().find(|tf| tf.atype == selected_id && tf.enabled); let selected_twofactor = twofactors.into_iter().find(|tf| tf.atype == selected_id && tf.enabled);
use crate::crypto::ct_eq; let selected_data = selected_data(selected_twofactor);
let selected_data = _selected_data(selected_twofactor);
match TwoFactorType::from_i32(selected_id) { match TwoFactorType::from_i32(selected_id) {
Some(TwoFactorType::Authenticator) => { Some(TwoFactorType::Authenticator) => {
authenticator::validate_totp_code_str(&user.uuid, twofactor_code, &selected_data?, ip, conn).await? authenticator::validate_totp_code_str(&user.uuid, twofactor_code, &selected_data?, ip, conn).await?;
} }
Some(TwoFactorType::Webauthn) => webauthn::validate_webauthn_login(&user.uuid, twofactor_code, conn).await?, Some(TwoFactorType::Webauthn) => webauthn::validate_webauthn_login(&user.uuid, twofactor_code, conn).await?,
Some(TwoFactorType::YubiKey) => yubikey::validate_yubikey_login(twofactor_code, &selected_data?).await?, Some(TwoFactorType::YubiKey) => yubikey::validate_yubikey_login(twofactor_code, &selected_data?).await?,
Some(TwoFactorType::Duo) => { Some(TwoFactorType::Duo) => {
match CONFIG.duo_use_iframe() { if CONFIG.duo_use_iframe() {
true => {
// Legacy iframe prompt flow // Legacy iframe prompt flow
duo::validate_duo_login(&user.email, twofactor_code, conn).await? duo::validate_duo_login(&user.email, twofactor_code, conn).await?;
} } else {
false => {
// OIDC based flow // OIDC based flow
duo_oidc::validate_duo_login( duo_oidc::validate_duo_login(
&user.email, &user.email,
@ -835,12 +828,11 @@ async fn twofactor_auth(
data.device_identifier.as_ref().unwrap(), data.device_identifier.as_ref().unwrap(),
conn, conn,
) )
.await? .await?;
}
} }
} }
Some(TwoFactorType::Email) => { Some(TwoFactorType::Email) => {
email::validate_email_code_str(&user.uuid, twofactor_code, &selected_data?, &ip.ip, conn).await? email::validate_email_code_str(&user.uuid, twofactor_code, &selected_data?, &ip.ip, conn).await?;
} }
Some(TwoFactorType::Remember) => { Some(TwoFactorType::Remember) => {
match device.twofactor_remember { match device.twofactor_remember {
@ -848,7 +840,7 @@ async fn twofactor_auth(
// If it is invalid we need to trigger the 2FA Login prompt // If it is invalid we need to trigger the 2FA Login prompt
Some(ref token) Some(ref token)
if !CONFIG.disable_2fa_remember() if !CONFIG.disable_2fa_remember()
&& (ct_eq(token, twofactor_code) && (crypto::ct_eq(token, twofactor_code)
&& auth::decode_2fa_remember(twofactor_code) && auth::decode_2fa_remember(twofactor_code)
.is_ok_and(|t| t.sub == device.uuid && t.user_uuid == user.uuid)) => {} .is_ok_and(|t| t.sub == device.uuid && t.user_uuid == user.uuid)) => {}
_ => { _ => {
@ -859,7 +851,7 @@ async fn twofactor_auth(
device.save(true, conn).await?; device.save(true, conn).await?;
} }
err_json!( err_json!(
_json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?, json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?,
"2FA Remember token not provided or expired" "2FA Remember token not provided or expired"
) )
} }
@ -900,11 +892,11 @@ async fn twofactor_auth(
Ok(two_factor) Ok(two_factor)
} }
fn _selected_data(tf: Option<TwoFactor>) -> ApiResult<String> { fn selected_data(tf: Option<TwoFactor>) -> ApiResult<String> {
tf.map(|t| t.data).map_res("Two factor doesn't exist") tf.map(|t| t.data).map_res("Two factor doesn't exist")
} }
async fn _json_err_twofactor( async fn json_err_twofactor(
providers: &[i32], providers: &[i32],
user_id: &UserId, user_id: &UserId,
data: &ConnectData, data: &ConnectData,
@ -925,29 +917,26 @@ async fn _json_err_twofactor(
result["TwoFactorProviders2"][provider.to_string()] = Value::Null; result["TwoFactorProviders2"][provider.to_string()] = Value::Null;
match TwoFactorType::from_i32(*provider) { match TwoFactorType::from_i32(*provider) {
Some(TwoFactorType::Authenticator) => { /* Nothing to do for TOTP */ }
Some(TwoFactorType::Webauthn) if CONFIG.is_webauthn_2fa_supported() => { Some(TwoFactorType::Webauthn) if CONFIG.is_webauthn_2fa_supported() => {
let request = webauthn::generate_webauthn_login(user_id, conn).await?; let request = webauthn::generate_webauthn_login(user_id, conn).await?;
result["TwoFactorProviders2"][provider.to_string()] = request.0; result["TwoFactorProviders2"][provider.to_string()] = request.0;
} }
Some(TwoFactorType::Duo) => { Some(TwoFactorType::Duo) => {
let email = match User::find_by_uuid(user_id, conn).await { let email = if let Some(u) = User::find_by_uuid(user_id, conn).await {
Some(u) => u.email, u.email
None => err!("User does not exist"), } else {
err!("User does not exist")
}; };
match CONFIG.duo_use_iframe() { if CONFIG.duo_use_iframe() {
true => {
// Legacy iframe prompt flow // Legacy iframe prompt flow
let (signature, host) = duo::generate_duo_signature(&email, conn).await?; let (signature, host) = duo::generate_duo_signature(&email, conn).await?;
result["TwoFactorProviders2"][provider.to_string()] = json!({ result["TwoFactorProviders2"][provider.to_string()] = json!({
"Host": host, "Host": host,
"Signature": signature, "Signature": signature,
}) });
} } else {
false => {
// OIDC based flow // OIDC based flow
let auth_url = duo_oidc::get_duo_auth_url( let auth_url = duo_oidc::get_duo_auth_url(
&email, &email,
@ -959,8 +948,7 @@ async fn _json_err_twofactor(
result["TwoFactorProviders2"][provider.to_string()] = json!({ result["TwoFactorProviders2"][provider.to_string()] = json!({
"AuthUrl": auth_url, "AuthUrl": auth_url,
}) });
}
} }
} }
@ -973,7 +961,7 @@ async fn _json_err_twofactor(
result["TwoFactorProviders2"][provider.to_string()] = json!({ result["TwoFactorProviders2"][provider.to_string()] = json!({
"Nfc": yubikey_metadata.nfc, "Nfc": yubikey_metadata.nfc,
}) });
} }
Some(tf_type @ TwoFactorType::Email) => { Some(tf_type @ TwoFactorType::Email) => {
@ -991,16 +979,30 @@ async fn _json_err_twofactor(
// Send email immediately if email is the only 2FA option. // Send email immediately if email is the only 2FA option.
if providers.len() == 1 && !disabled_send { if providers.len() == 1 && !disabled_send {
email::send_token(user_id, conn).await? email::send_token(user_id, conn).await?;
} }
let email_data = email::EmailTokenData::from_json(&twofactor.data)?; let email_data = email::EmailTokenData::from_json(&twofactor.data)?;
result["TwoFactorProviders2"][provider.to_string()] = json!({ result["TwoFactorProviders2"][provider.to_string()] = json!({
"Email": email::obscure_email(&email_data.email), "Email": email::obscure_email(&email_data.email),
}) });
} }
_ => {} None
| Some(
TwoFactorType::Authenticator
| TwoFactorType::EmailVerificationChallenge
| TwoFactorType::OrganizationDuo
| TwoFactorType::ProtectedActions
| TwoFactorType::RecoveryCode
| TwoFactorType::Remember
| TwoFactorType::U2f
| TwoFactorType::U2fLoginChallenge
| TwoFactorType::U2fRegisterChallenge
| TwoFactorType::Webauthn
| TwoFactorType::WebauthnLoginChallenge
| TwoFactorType::WebauthnRegisterChallenge,
) => { /* Nothing special to do for these providers */ }
} }
} }
@ -1008,18 +1010,18 @@ async fn _json_err_twofactor(
} }
#[post("/accounts/prelogin", data = "<data>")] #[post("/accounts/prelogin", data = "<data>")]
async fn prelogin(data: Json<PreloginData>, conn: DbConn) -> Json<Value> { async fn post_prelogin(data: Json<PreloginData>, conn: DbConn) -> Json<Value> {
_prelogin(data, conn).await prelogin(data, conn).await
} }
#[post("/accounts/prelogin/password", data = "<data>")] #[post("/accounts/prelogin/password", data = "<data>")]
async fn prelogin_password(data: Json<PreloginData>, conn: DbConn) -> Json<Value> { async fn prelogin_password(data: Json<PreloginData>, conn: DbConn) -> Json<Value> {
_prelogin(data, conn).await prelogin(data, conn).await
} }
#[post("/accounts/register", data = "<data>")] #[post("/accounts/register", data = "<data>")]
async fn identity_register(data: Json<RegisterData>, conn: DbConn) -> JsonResult { async fn identity_register(data: Json<RegisterData>, conn: DbConn) -> JsonResult {
_register(data, false, conn).await register(data, false, conn).await
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -1058,13 +1060,13 @@ async fn register_verification_email(
if should_send_mail { if should_send_mail {
let user = User::find_by_mail(&data.email, &conn).await; let user = User::find_by_mail(&data.email, &conn).await;
if user.filter(|u| u.private_key.is_some()).is_some() { if user.as_ref().is_some_and(|u| u.private_key.is_some()) {
// There is still a timing side channel here in that the code // There is still a timing side channel here in that the code
// paths that send mail take noticeably longer than ones that don't. // paths that send mail take noticeably longer than ones that don't.
// Add a randomized sleep to mitigate this somewhat. // Add a randomized sleep to mitigate this somewhat.
use rand::{rngs::SmallRng, RngExt}; use rand::{RngExt, rngs::SmallRng};
let mut rng: SmallRng = rand::make_rng(); let mut rng: SmallRng = rand::make_rng();
let sleep_ms = rng.random_range(900..=1100) as u64; let sleep_ms: u64 = rng.random_range(900..=1100);
tokio::time::sleep(tokio::time::Duration::from_millis(sleep_ms)).await; tokio::time::sleep(tokio::time::Duration::from_millis(sleep_ms)).await;
} else { } else {
mail::send_register_verify_email(&data.email, &token).await?; mail::send_register_verify_email(&data.email, &token).await?;
@ -1080,7 +1082,7 @@ async fn register_verification_email(
#[post("/accounts/register/finish", data = "<data>")] #[post("/accounts/register/finish", data = "<data>")]
async fn register_finish(data: Json<RegisterData>, conn: DbConn) -> JsonResult { async fn register_finish(data: Json<RegisterData>, conn: DbConn) -> JsonResult {
_register(data, true, conn).await register(data, true, conn).await
} }
// https://github.com/bitwarden/jslib/blob/master/common/src/models/request/tokenRequest.ts // https://github.com/bitwarden/jslib/blob/master/common/src/models/request/tokenRequest.ts
@ -1143,7 +1145,7 @@ struct ConnectData {
#[field(name = uncased("code_verifier"))] #[field(name = uncased("code_verifier"))]
code_verifier: Option<OIDCCodeVerifier>, code_verifier: Option<OIDCCodeVerifier>,
} }
fn _check_is_some<T>(value: Option<&T>, msg: &str) -> EmptyResult { fn check_is_some<T>(value: Option<&T>, msg: &str) -> EmptyResult {
if value.is_none() { if value.is_none() {
err!(msg) err!(msg)
} }
@ -1166,7 +1168,7 @@ const SSO_BINDING_COOKIE: &str = "VW_SSO_BINDING";
#[get("/connect/oidc-signin?<code>&<state>", rank = 1)] #[get("/connect/oidc-signin?<code>&<state>", rank = 1)]
async fn oidcsignin(code: OIDCCode, state: String, cookies: &CookieJar<'_>, mut conn: DbConn) -> ApiResult<Redirect> { async fn oidcsignin(code: OIDCCode, state: String, cookies: &CookieJar<'_>, mut conn: DbConn) -> ApiResult<Redirect> {
_oidcsignin_redirect(state, code, None, cookies, &mut conn).await oidcsignin_redirect(state, code, None, cookies, &mut conn).await
} }
// Bitwarden client appear to only care for code and state // Bitwarden client appear to only care for code and state
@ -1180,7 +1182,7 @@ async fn oidcsignin_error(
cookies: &CookieJar<'_>, cookies: &CookieJar<'_>,
mut conn: DbConn, mut conn: DbConn,
) -> ApiResult<Redirect> { ) -> ApiResult<Redirect> {
_oidcsignin_redirect( oidcsignin_redirect(
state.clone(), state.clone(),
state.into(), state.into(),
Some(OIDCCodeResponseError { Some(OIDCCodeResponseError {
@ -1195,7 +1197,8 @@ async fn oidcsignin_error(
// The state was encoded using Base64 to ensure no issue with providers. // The state was encoded using Base64 to ensure no issue with providers.
// iss and scope parameters are needed for redirection to work on IOS. // iss and scope parameters are needed for redirection to work on IOS.
async fn _oidcsignin_redirect( // We pass the state as the code to get it back later on.
async fn oidcsignin_redirect(
base64_state: String, base64_state: String,
code: OIDCCode, code: OIDCCode,
error: Option<OIDCCodeResponseError>, error: Option<OIDCCodeResponseError>,
@ -1204,14 +1207,13 @@ async fn _oidcsignin_redirect(
) -> ApiResult<Redirect> { ) -> ApiResult<Redirect> {
let state = sso::decode_state(&base64_state)?; let state = sso::decode_state(&base64_state)?;
let mut sso_auth = match SsoAuth::find(&state, conn).await { let Some(mut sso_auth) = SsoAuth::find(&state, conn).await else {
None => err!(format!("Cannot retrieve sso_auth for {state}")), err!(format!("Cannot retrieve sso_auth for {state}"))
Some(sso_auth) => sso_auth,
}; };
// Browser-binding check // Browser-binding check
// The cookie was set on /connect/authorize and must come from the same browser that initiated the flow. // The cookie was set on /connect/authorize and must come from the same browser that initiated the flow.
let cookie_value = cookies.get(SSO_BINDING_COOKIE).map(|c| c.value().to_string()); let cookie_value = cookies.get(SSO_BINDING_COOKIE).map(|c| c.value().to_owned());
let provided_hash = cookie_value.as_deref().map(|v| crypto::sha256_hex(v.as_bytes())); let provided_hash = cookie_value.as_deref().map(|v| crypto::sha256_hex(v.as_bytes()));
match (sso_auth.binding_hash.as_deref(), provided_hash.as_deref()) { match (sso_auth.binding_hash.as_deref(), provided_hash.as_deref()) {
(Some(expected), Some(actual)) if crypto::ct_eq(expected, actual) => {} (Some(expected), Some(actual)) if crypto::ct_eq(expected, actual) => {}

9
src/api/mod.rs

@ -32,11 +32,13 @@ pub use crate::api::{
web::routes as web_routes, web::routes as web_routes,
web::static_files, web::static_files,
}; };
use crate::db::{ use crate::{
models::{OrgPolicy, OrgPolicyType, User}, CONFIG,
db::{
DbConn, DbConn,
models::{OrgPolicy, OrgPolicyType, User},
},
}; };
use crate::CONFIG;
// Type aliases for API methods results // Type aliases for API methods results
pub type ApiResult<T> = Result<T, crate::error::Error>; pub type ApiResult<T> = Result<T, crate::error::Error>;
@ -74,6 +76,7 @@ impl PasswordOrOtpData {
} }
} }
#[expect(clippy::struct_excessive_bools, reason = "Bitwarden clients expect the data in this specific format")]
#[derive(Debug, Default, Deserialize, Serialize)] #[derive(Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct MasterPasswordPolicy { pub struct MasterPasswordPolicy {

35
src/api/notifications.rs

@ -6,17 +6,22 @@ use std::{
use chrono::{NaiveDateTime, Utc}; use chrono::{NaiveDateTime, Utc};
use rmpv::Value; use rmpv::Value;
use rocket::{futures::StreamExt, Route}; use rocket::{Route, futures::StreamExt};
use rocket_ws::{Message, WebSocket}; use rocket_ws::{Message, WebSocket};
use tokio::sync::mpsc::Sender; use tokio::sync::mpsc::Sender;
use crate::{ use crate::{
CONFIG, Error,
auth::{ClientIp, WsAccessTokenHeader}, auth::{ClientIp, WsAccessTokenHeader},
db::{ db::{
models::{AuthRequestId, Cipher, CollectionId, Device, DeviceId, Folder, PushId, Send as DbSend, User, UserId},
DbConn, DbConn,
models::{AuthRequestId, Cipher, CollectionId, Device, DeviceId, Folder, PushId, Send as DbSend, User, UserId},
}, },
Error, CONFIG, };
use super::{
push::push_auth_request, push::push_auth_response, push_cipher_update, push_folder_update, push_logout,
push_send_update, push_user_update,
}; };
pub static WS_USERS: LazyLock<Arc<WebSocketUsers>> = LazyLock::new(|| { pub static WS_USERS: LazyLock<Arc<WebSocketUsers>> = LazyLock::new(|| {
@ -31,11 +36,6 @@ pub static WS_ANONYMOUS_SUBSCRIPTIONS: LazyLock<Arc<AnonymousWebSocketSubscripti
}) })
}); });
use super::{
push::push_auth_request, push::push_auth_response, push_cipher_update, push_folder_update, push_logout,
push_send_update, push_user_update,
};
static NOTIFICATIONS_DISABLED: LazyLock<bool> = LazyLock::new(|| !CONFIG.enable_websocket() && !CONFIG.push_enabled()); static NOTIFICATIONS_DISABLED: LazyLock<bool> = LazyLock::new(|| !CONFIG.enable_websocket() && !CONFIG.push_enabled());
pub fn routes() -> Vec<Route> { pub fn routes() -> Vec<Route> {
@ -102,7 +102,7 @@ impl Drop for WSAnonymousEntryMapGuard {
} }
} }
#[allow(tail_expr_drop_order)] #[expect(tail_expr_drop_order)]
#[get("/hub?<data..>")] #[get("/hub?<data..>")]
fn websockets_hub<'r>( fn websockets_hub<'r>(
ws: WebSocket, ws: WebSocket,
@ -186,7 +186,7 @@ fn websockets_hub<'r>(
}) })
} }
#[allow(tail_expr_drop_order)] #[expect(tail_expr_drop_order)]
#[get("/anonymous-hub?<token..>")] #[get("/anonymous-hub?<token..>")]
fn anonymous_websockets_hub<'r>(ws: WebSocket, token: String, ip: ClientIp) -> Result<rocket_ws::Stream!['r], Error> { fn anonymous_websockets_hub<'r>(ws: WebSocket, token: String, ip: ClientIp) -> Result<rocket_ws::Stream!['r], Error> {
info!("Accepting Anonymous Rocket WS connection from {}", ip.ip); info!("Accepting Anonymous Rocket WS connection from {}", ip.ip);
@ -268,14 +268,15 @@ fn serialize(val: &Value) -> Vec<u8> {
let mut len_buf: Vec<u8> = Vec::new(); let mut len_buf: Vec<u8> = Vec::new();
loop { loop {
let mut size_part = size & 0x7f; #[expect(clippy::cast_possible_truncation, reason = "masked to 7 bits, fits u8")]
let mut size_part = (size & 0x7f) as u8;
size >>= 7; size >>= 7;
if size > 0 { if size > 0 {
size_part |= 0x80; size_part |= 0x80;
} }
len_buf.push(size_part as u8); len_buf.push(size_part);
if size == 0 { if size == 0 {
break; break;
@ -329,7 +330,7 @@ pub struct WebSocketUsers {
impl WebSocketUsers { impl WebSocketUsers {
async fn send_update(&self, user_id: &UserId, data: &[u8]) { async fn send_update(&self, user_id: &UserId, data: &[u8]) {
if let Some(user) = self.map.get(user_id.as_ref()).map(|v| v.clone()) { if let Some(user) = self.map.get(user_id.as_ref()).map(|v| v.clone()) {
for (_, sender) in user.iter() { for (_, sender) in &user {
if let Err(e) = sender.send(Message::binary(data)).await { if let Err(e) = sender.send(Message::binary(data)).await {
error!("Error sending WS update {e}"); error!("Error sending WS update {e}");
} }
@ -538,12 +539,12 @@ pub struct AnonymousWebSocketSubscriptions {
impl AnonymousWebSocketSubscriptions { impl AnonymousWebSocketSubscriptions {
async fn send_update(&self, token: &str, data: &[u8]) { async fn send_update(&self, token: &str, data: &[u8]) {
if let Some(sender) = self.map.get(token).map(|v| v.clone()) { if let Some(sender) = self.map.get(token).map(|v| v.clone())
if let Err(e) = sender.send(Message::binary(data)).await { && let Err(e) = sender.send(Message::binary(data)).await
{
error!("Error sending WS update {e}"); error!("Error sending WS update {e}");
} }
} }
}
pub async fn send_auth_response(&self, user_id: &UserId, auth_request_id: &AuthRequestId) { pub async fn send_auth_response(&self, user_id: &UserId, auth_request_id: &AuthRequestId) {
if !CONFIG.enable_websocket() { if !CONFIG.enable_websocket() {
@ -582,7 +583,7 @@ fn create_update(payload: Vec<(Value, Value)>, ut: UpdateType, acting_device_id:
V::Nil, V::Nil,
"ReceiveMessage".into(), "ReceiveMessage".into(),
V::Array(vec![V::Map(vec![ V::Array(vec![V::Map(vec![
("ContextId".into(), acting_device_id.map(|v| v.to_string().into()).unwrap_or_else(|| V::Nil)), ("ContextId".into(), acting_device_id.map_or(V::Nil, |v| v.to_string().into())),
("Type".into(), (ut as i32).into()), ("Type".into(), (ut as i32).into()),
("Payload".into(), payload.into()), ("Payload".into(), payload.into()),
])]), ])]),

22
src/api/push.rs

@ -4,21 +4,21 @@ use std::{
}; };
use reqwest::{ use reqwest::{
header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE},
Method, Method,
header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE},
}; };
use serde_json::Value; use serde_json::Value;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use crate::{ use crate::{
CONFIG,
api::{ApiResult, EmptyResult, UpdateType}, api::{ApiResult, EmptyResult, UpdateType},
db::{ db::{
models::{AuthRequestId, Cipher, Device, Folder, PushId, Send, User, UserId},
DbConn, DbConn,
models::{AuthRequestId, Cipher, Device, Folder, PushId, Send, User, UserId},
}, },
http_client::make_http_request, http_client::make_http_request,
util::{format_date, get_uuid}, util::{format_date, get_uuid},
CONFIG,
}; };
#[derive(Deserialize)] #[derive(Deserialize)]
@ -74,9 +74,9 @@ async fn get_auth_api_token() -> ApiResult<String> {
}; };
let mut api_token = API_TOKEN.write().await; let mut api_token = API_TOKEN.write().await;
api_token.valid_until = Instant::now() // Token valid for half the specified time
.checked_add(Duration::new((json_pushtoken.expires_in / 2) as u64, 0)) // Token valid for half the specified time let half_expires_in = u64::from((json_pushtoken.expires_in / 2).max(0).cast_unsigned());
.unwrap(); api_token.valid_until = Instant::now().checked_add(Duration::from_secs(half_expires_in)).unwrap();
api_token.access_token = json_pushtoken.access_token; api_token.access_token = json_pushtoken.access_token;
@ -161,7 +161,7 @@ pub async fn push_cipher_update(ut: UpdateType, cipher: &Cipher, device: &Device
// We shouldn't send a push notification on cipher update if the cipher belongs to an organization, this isn't implemented in the upstream server too. // We shouldn't send a push notification on cipher update if the cipher belongs to an organization, this isn't implemented in the upstream server too.
if cipher.organization_uuid.is_some() { if cipher.organization_uuid.is_some() {
return; return;
}; }
let Some(user_id) = &cipher.user_uuid else { let Some(user_id) = &cipher.user_uuid else {
debug!("Cipher has no uuid"); debug!("Cipher has no uuid");
return; return;
@ -244,8 +244,9 @@ pub async fn push_folder_update(ut: UpdateType, folder: &Folder, device: &Device
} }
pub async fn push_send_update(ut: UpdateType, send: &Send, device: &Device, conn: &DbConn) { pub async fn push_send_update(ut: UpdateType, send: &Send, device: &Device, conn: &DbConn) {
if let Some(s) = &send.user_uuid { if let Some(s) = &send.user_uuid
if Device::check_user_has_push_device(s, conn).await { && Device::check_user_has_push_device(s, conn).await
{
tokio::task::spawn(send_to_push_relay(json!({ tokio::task::spawn(send_to_push_relay(json!({
"userId": send.user_uuid, "userId": send.user_uuid,
"organizationId": null, "organizationId": null,
@ -261,7 +262,6 @@ pub async fn push_send_update(ut: UpdateType, send: &Send, device: &Device, conn
"installationId": null "installationId": null
}))); })));
} }
}
} }
async fn send_to_push_relay(notification_data: Value) { async fn send_to_push_relay(notification_data: Value) {
@ -296,7 +296,7 @@ async fn send_to_push_relay(notification_data: Value) {
.await .await
{ {
error!("An error occurred while sending a send update to the push relay: {e}"); error!("An error occurred while sending a send update to the push relay: {e}");
}; }
} }
pub async fn push_auth_request(user_id: &UserId, auth_request_id: &str, device: &Device, conn: &DbConn) { pub async fn push_auth_request(user_id: &UserId, auth_request_id: &str, device: &Device, conn: &DbConn) {

20
src/api/web.rs

@ -1,21 +1,24 @@
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use rocket::{ use rocket::{
Catcher, Route,
fs::NamedFile, fs::NamedFile,
http::ContentType, http::ContentType,
response::{content::RawCss as Css, content::RawHtml as Html, Redirect}, response::{Redirect, content::RawCss as Css, content::RawHtml as Html},
serde::json::Json, serde::json::Json,
Catcher, Route,
}; };
use serde_json::Value; use serde_json::Value;
use crate::{ use crate::{
api::{core::now, ApiResult, EmptyResult}, CONFIG,
api::{ApiResult, EmptyResult, core::now},
auth::decode_file_download, auth::decode_file_download,
db::models::{AttachmentId, CipherId}, db::{
DbConn,
models::{AttachmentId, CipherId},
},
error::Error, error::Error,
util::Cached, util::Cached,
CONFIG,
}; };
pub fn routes() -> Vec<Route> { pub fn routes() -> Vec<Route> {
@ -28,7 +31,7 @@ pub fn routes() -> Vec<Route> {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
if CONFIG.reload_templates() { if CONFIG.reload_templates() {
routes.append(&mut routes![_static_files_dev]); routes.append(&mut routes![static_files_dev]);
} }
routes routes
@ -178,7 +181,6 @@ async fn attachments(cipher_id: CipherId, file_id: AttachmentId, token: String)
} }
// We use DbConn here to let the alive healthcheck also verify the database connection. // We use DbConn here to let the alive healthcheck also verify the database connection.
use crate::db::DbConn;
#[get("/alive")] #[get("/alive")]
fn alive(_conn: DbConn) -> Json<String> { fn alive(_conn: DbConn) -> Json<String> {
now() now()
@ -197,7 +199,7 @@ fn alive_head(_conn: DbConn) -> EmptyResult {
// NOTE: Do not forget to add any new files added to the `static_files` function below! // NOTE: Do not forget to add any new files added to the `static_files` function below!
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
#[get("/vw_static/<filename>", rank = 1)] #[get("/vw_static/<filename>", rank = 1)]
pub async fn _static_files_dev(filename: PathBuf) -> Option<NamedFile> { pub async fn static_files_dev(filename: PathBuf) -> Option<NamedFile> {
warn!("LOADING STATIC FILES FROM DISK"); warn!("LOADING STATIC FILES FROM DISK");
let file = filename.to_str().unwrap_or_default(); let file = filename.to_str().unwrap_or_default();
let ext = filename.extension().unwrap_or_default(); let ext = filename.extension().unwrap_or_default();
@ -210,7 +212,7 @@ pub async fn _static_files_dev(filename: PathBuf) -> Option<NamedFile> {
if let Ok(path) = path { if let Ok(path) = path {
return NamedFile::open(path).await.ok(); return NamedFile::open(path).await.ok();
}; }
None None
} }

138
src/auth.rs

@ -5,21 +5,30 @@ use std::{
}; };
use chrono::{DateTime, TimeDelta, Utc}; use chrono::{DateTime, TimeDelta, Utc};
use jsonwebtoken::{errors::ErrorKind, Algorithm, DecodingKey, EncodingKey, Header}; use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, errors::ErrorKind};
use num_traits::FromPrimitive; use num_traits::FromPrimitive;
use openssl::rsa::Rsa; use openssl::rsa::Rsa;
use serde::de::DeserializeOwned; use serde::{de::DeserializeOwned, ser::Serialize};
use serde::ser::Serialize;
use rocket::{
outcome::try_outcome,
request::{FromRequest, Outcome, Request},
};
use crate::{ use crate::{
CONFIG,
api::ApiResult, api::ApiResult,
config::PathType, config::PathType,
db::models::{ db::{
AttachmentId, CipherId, CollectionId, DeviceId, DeviceType, EmergencyAccessId, MembershipId, OrgApiKeyId, DbConn,
OrganizationId, SendFileId, SendId, UserId, models::{
AttachmentId, CipherId, Collection, CollectionId, Device, DeviceId, DeviceType, EmergencyAccessId,
Membership, MembershipId, MembershipStatus, MembershipType, OrgApiKeyId, OrganizationId, SendFileId,
SendId, User, UserId, UserStampException,
},
}, },
error::Error, error::Error,
sso, CONFIG, sso,
}; };
const JWT_ALGORITHM: Algorithm = Algorithm::RS256; const JWT_ALGORITHM: Algorithm = Algorithm::RS256;
@ -52,12 +61,12 @@ static PRIVATE_RSA_KEY: OnceLock<EncodingKey> = OnceLock::new();
static PUBLIC_RSA_KEY: OnceLock<DecodingKey> = OnceLock::new(); static PUBLIC_RSA_KEY: OnceLock<DecodingKey> = OnceLock::new();
pub async fn initialize_keys() -> Result<(), Error> { pub async fn initialize_keys() -> Result<(), Error> {
use std::io::Error; use std::io::Error as IoError;
let rsa_key_filename = crate::storage::file_name(&CONFIG.private_rsa_key()) let rsa_key_filename = crate::storage::file_name(&CONFIG.private_rsa_key())
.ok_or_else(|| Error::other("Private RSA key path missing filename"))?; .ok_or_else(|| IoError::other("Private RSA key path missing filename"))?;
let operator = CONFIG.opendal_operator_for_path_type(&PathType::RsaKey).map_err(Error::other)?; let operator = CONFIG.opendal_operator_for_path_type(&PathType::RsaKey).map_err(IoError::other)?;
let priv_key_buffer = match operator.read(&rsa_key_filename).await { let priv_key_buffer = match operator.read(&rsa_key_filename).await {
Ok(buffer) => Some(buffer), Ok(buffer) => Some(buffer),
@ -226,7 +235,7 @@ impl LoginJwtClaims {
// let orgmanager: Vec<_> = orgs.iter().filter(|o| o.atype == 3).map(|o| o.org_uuid.clone()).collect(); // let orgmanager: Vec<_> = orgs.iter().filter(|o| o.atype == 3).map(|o| o.org_uuid.clone()).collect();
if exp <= (now + *BW_EXPIRATION).timestamp() { if exp <= (now + *BW_EXPIRATION).timestamp() {
warn!("Raise access_token lifetime to more than 5min.") warn!("Raise access_token lifetime to more than 5min.");
} }
// Create the JWT claims struct, to send to the client // Create the JWT claims struct, to send to the client
@ -253,7 +262,7 @@ impl LoginJwtClaims {
sstamp: user.security_stamp.clone(), sstamp: user.security_stamp.clone(),
device: device.uuid.clone(), device: device.uuid.clone(),
devicetype: DeviceType::from_i32(device.atype).to_string(), devicetype: DeviceType::from_i32(device.atype).to_string(),
client_id: client_id.unwrap_or("undefined".to_string()), client_id: client_id.unwrap_or("undefined".to_owned()),
scope, scope,
amr: vec!["Application".into()], amr: vec!["Application".into()],
} }
@ -506,7 +515,7 @@ pub fn generate_admin_claims() -> BasicJwtClaims {
nbf: time_now.timestamp(), nbf: time_now.timestamp(),
exp: (time_now + TimeDelta::try_minutes(CONFIG.admin_session_lifetime()).unwrap()).timestamp(), exp: (time_now + TimeDelta::try_minutes(CONFIG.admin_session_lifetime()).unwrap()).timestamp(),
iss: JWT_ADMIN_ISSUER.to_string(), iss: JWT_ADMIN_ISSUER.to_string(),
sub: "admin_panel".to_string(), sub: "admin_panel".to_owned(),
} }
} }
@ -523,16 +532,6 @@ pub fn generate_send_claims(send_id: &SendId, file_id: &SendFileId) -> BasicJwtC
// //
// Bearer token authentication // Bearer token authentication
// //
use rocket::{
outcome::try_outcome,
request::{FromRequest, Outcome, Request},
};
use crate::db::{
models::{Collection, Device, Membership, MembershipStatus, MembershipType, User, UserStampException},
DbConn,
};
pub struct Host { pub struct Host {
pub host: String, pub host: String,
} }
@ -548,7 +547,7 @@ impl<'r> FromRequest<'r> for Host {
let host = if CONFIG.domain_set() { let host = if CONFIG.domain_set() {
CONFIG.domain() CONFIG.domain()
} else if let Some(referer) = headers.get_one("Referer") { } else if let Some(referer) = headers.get_one("Referer") {
referer.to_string() referer.to_owned()
} else { } else {
// Try to guess from the headers // Try to guess from the headers
let protocol = if let Some(proto) = headers.get_one("X-Forwarded-Proto") { let protocol = if let Some(proto) = headers.get_one("X-Forwarded-Proto") {
@ -584,13 +583,15 @@ impl<'r> FromRequest<'r> for ClientHeaders {
type Error = &'static str; type Error = &'static str;
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let ip = match ClientIp::from_request(request).await { let Outcome::Success(ip) = ClientIp::from_request(request).await else {
Outcome::Success(ip) => ip, err_handler!("Error getting Client IP")
_ => err_handler!("Error getting Client IP"),
}; };
// When unknown or unable to parse, return 14, which is 'Unknown Browser' // When unknown or unable to parse, return 'UnknownBrowser'
let device_type: i32 = let device_type: i32 = request
request.headers().get_one("device-type").map(|d| d.parse().unwrap_or(14)).unwrap_or_else(|| 14); .headers()
.get_one("device-type")
.and_then(|d| d.parse().ok())
.unwrap_or(DeviceType::UnknownBrowser as i32);
Outcome::Success(ClientHeaders { Outcome::Success(ClientHeaders {
device_type, device_type,
@ -614,18 +615,19 @@ impl<'r> FromRequest<'r> for Headers {
let headers = request.headers(); let headers = request.headers();
let host = try_outcome!(Host::from_request(request).await).host; let host = try_outcome!(Host::from_request(request).await).host;
let ip = match ClientIp::from_request(request).await { let Outcome::Success(ip) = ClientIp::from_request(request).await else {
Outcome::Success(ip) => ip, err_handler!("Error getting Client IP")
_ => err_handler!("Error getting Client IP"),
}; };
// Get access_token // Get access_token
let access_token: &str = match headers.get_one("Authorization") { let access_token: &str = if let Some(a) = headers.get_one("Authorization") {
Some(a) => match a.rsplit("Bearer ").next() { if let Some(split) = a.rsplit("Bearer ").next() {
Some(split) => split, split
None => err_handler!("No access token provided"), } else {
}, err_handler!("No access token provided")
None => err_handler!("No access token provided"), }
} else {
err_handler!("No access token provided")
}; };
// Check JWT token is valid and get device and user from it // Check JWT token is valid and get device and user from it
@ -636,9 +638,8 @@ impl<'r> FromRequest<'r> for Headers {
let device_id = claims.device; let device_id = claims.device;
let user_id = claims.sub; let user_id = claims.sub;
let conn = match DbConn::from_request(request).await { let Outcome::Success(conn) = DbConn::from_request(request).await else {
Outcome::Success(conn) => conn, err_handler!("Error getting DB")
_ => err_handler!("Error getting DB"),
}; };
let Some(device) = Device::find_by_uuid_and_user(&device_id, &user_id, &conn).await else { let Some(device) = Device::find_by_uuid_and_user(&device_id, &user_id, &conn).await else {
@ -669,7 +670,7 @@ impl<'r> FromRequest<'r> for Headers {
error!("Error updating user: {e:#?}"); error!("Error updating user: {e:#?}");
} }
err_handler!("Stamp exception is expired") err_handler!("Stamp exception is expired")
} else if !stamp_exception.routes.contains(&current_route.to_string()) { } else if !stamp_exception.routes.contains(&current_route.to_owned()) {
err_handler!("Invalid security stamp: Current route and exception route do not match") err_handler!("Invalid security stamp: Current route and exception route do not match")
} else if stamp_exception.security_stamp != claims.sstamp { } else if stamp_exception.security_stamp != claims.sstamp {
err_handler!("Invalid security stamp for matched stamp exception") err_handler!("Invalid security stamp for matched stamp exception")
@ -757,9 +758,8 @@ impl<'r> FromRequest<'r> for OrgHeaders {
match url_org_id { match url_org_id {
Some(org_id) if uuid::Uuid::parse_str(&org_id).is_ok() => { Some(org_id) if uuid::Uuid::parse_str(&org_id).is_ok() => {
let conn = match DbConn::from_request(request).await { let Outcome::Success(conn) = DbConn::from_request(request).await else {
Outcome::Success(conn) => conn, err_handler!("Error getting DB")
_ => err_handler!("Error getting DB"),
}; };
let user = headers.user; let user = headers.user;
@ -831,17 +831,17 @@ impl<'r> FromRequest<'r> for AdminHeaders {
// but there could be cases where it is a query value. // but there could be cases where it is a query value.
// First check the path, if this is not a valid uuid, try the query values. // First check the path, if this is not a valid uuid, try the query values.
fn get_col_id(request: &Request<'_>) -> Option<CollectionId> { fn get_col_id(request: &Request<'_>) -> Option<CollectionId> {
if let Some(Ok(col_id)) = request.param::<String>(3) { if let Some(Ok(col_id)) = request.param::<String>(3)
if uuid::Uuid::parse_str(&col_id).is_ok() { && uuid::Uuid::parse_str(&col_id).is_ok()
{
return Some(col_id.into()); return Some(col_id.into());
} }
}
if let Some(Ok(col_id)) = request.query_value::<String>("collectionId") { if let Some(Ok(col_id)) = request.query_value::<String>("collectionId")
if uuid::Uuid::parse_str(&col_id).is_ok() { && uuid::Uuid::parse_str(&col_id).is_ok()
{
return Some(col_id.into()); return Some(col_id.into());
} }
}
None None
} }
@ -864,18 +864,16 @@ impl<'r> FromRequest<'r> for ManagerHeaders {
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let headers = try_outcome!(OrgHeaders::from_request(request).await); let headers = try_outcome!(OrgHeaders::from_request(request).await);
if headers.is_confirmed_and_manager() { if headers.is_confirmed_and_manager() {
match get_col_id(request) { if let Some(col_id) = get_col_id(request) {
Some(col_id) => { let Outcome::Success(conn) = DbConn::from_request(request).await else {
let conn = match DbConn::from_request(request).await { err_handler!("Error getting DB")
Outcome::Success(conn) => conn,
_ => err_handler!("Error getting DB"),
}; };
if !Collection::is_coll_manageable_by_user(&col_id, &headers.membership.user_uuid, &conn).await { if !Collection::is_coll_manageable_by_user(&col_id, &headers.membership.user_uuid, &conn).await {
err_handler!("The current user isn't a manager for this collection") err_handler!("The current user isn't a manager for this collection")
} }
} } else {
_ => err_handler!("Error getting the collection id"), err_handler!("Error getting the collection id")
} }
Outcome::Success(Self { Outcome::Success(Self {
@ -1036,7 +1034,7 @@ impl From<OrgMemberHeaders> for Headers {
// //
// Client IP address detection // Client IP address detection
// //
#[derive(Copy, Clone)]
pub struct ClientIp { pub struct ClientIp {
pub ip: IpAddr, pub ip: IpAddr,
} }
@ -1068,6 +1066,7 @@ impl<'r> FromRequest<'r> for ClientIp {
} }
} }
#[derive(Copy, Clone)]
pub struct Secure { pub struct Secure {
pub https: bool, pub https: bool,
} }
@ -1153,15 +1152,14 @@ pub enum AuthMethod {
impl AuthMethod { impl AuthMethod {
pub fn scope(&self) -> String { pub fn scope(&self) -> String {
match self { match self {
AuthMethod::OrgApiKey => "api.organization".to_string(), AuthMethod::OrgApiKey => "api.organization".to_owned(),
AuthMethod::Password => "api offline_access".to_string(), AuthMethod::UserApiKey => "api".to_owned(),
AuthMethod::Sso => "api offline_access".to_string(), AuthMethod::Password | AuthMethod::Sso => "api offline_access".to_owned(),
AuthMethod::UserApiKey => "api".to_string(),
} }
} }
pub fn scope_vec(&self) -> Vec<String> { pub fn scope_vec(&self) -> Vec<String> {
self.scope().split_whitespace().map(str::to_string).collect() self.scope().split_whitespace().map(str::to_owned).collect()
} }
pub fn check_scope(&self, scope: Option<&String>) -> ApiResult<String> { pub fn check_scope(&self, scope: Option<&String>) -> ApiResult<String> {
@ -1274,17 +1272,15 @@ pub async fn refresh_tokens(
}; };
// Get device by refresh token // Get device by refresh token
let mut device = match Device::find_by_refresh_token(&refresh_claims.device_token, conn).await { let Some(mut device) = Device::find_by_refresh_token(&refresh_claims.device_token, conn).await else {
None => err!("Invalid refresh token"), err!("Invalid refresh token")
Some(device) => device,
}; };
// Save to update `updated_at`. // Save to update `updated_at`.
device.save(true, conn).await?; device.save(true, conn).await?;
let user = match User::find_by_uuid(&device.user_uuid, conn).await { let Some(user) = User::find_by_uuid(&device.user_uuid, conn).await else {
None => err!("Impossible to find user"), err!("Impossible to find user")
Some(user) => user,
}; };
let auth_tokens = match refresh_claims.sub { let auth_tokens = match refresh_claims.sub {

239
src/config.rs

@ -3,8 +3,8 @@ use std::{
fmt, fmt,
process::exit, process::exit,
sync::{ sync::{
atomic::{AtomicBool, Ordering},
LazyLock, RwLock, LazyLock, RwLock,
atomic::{AtomicBool, Ordering},
}, },
}; };
@ -16,8 +16,8 @@ use crate::{
error::Error, error::Error,
storage, storage,
util::{ util::{
get_active_web_release, get_env, get_env_bool, is_valid_email, parse_experimental_client_feature_flags, FeatureFlagFilter, get_active_web_release, get_env, get_env_bool, is_valid_email,
FeatureFlagFilter, parse_experimental_client_feature_flags,
}, },
}; };
@ -27,10 +27,10 @@ static CONFIG_FILE: LazyLock<String> = LazyLock::new(|| {
}); });
static CONFIG_FILE_PARENT_DIR: LazyLock<String> = static CONFIG_FILE_PARENT_DIR: LazyLock<String> =
LazyLock::new(|| storage::parent(&CONFIG_FILE).unwrap_or_else(|| "data".to_string())); LazyLock::new(|| storage::parent(&CONFIG_FILE).unwrap_or_else(|| "data".to_owned()));
static CONFIG_FILENAME: LazyLock<String> = static CONFIG_FILENAME: LazyLock<String> =
LazyLock::new(|| storage::file_name(&CONFIG_FILE).unwrap_or_else(|| "config.json".to_string())); LazyLock::new(|| storage::file_name(&CONFIG_FILE).unwrap_or_else(|| "config.json".to_owned()));
pub static SKIP_CONFIG_VALIDATION: AtomicBool = AtomicBool::new(false); pub static SKIP_CONFIG_VALIDATION: AtomicBool = AtomicBool::new(false);
@ -360,13 +360,7 @@ macro_rules! make_config {
)+)+ )+)+
pub fn prepare_json(&self) -> serde_json::Value { pub fn prepare_json(&self) -> serde_json::Value {
let (def, cfg, overridden) = { fn get_form_type(rust_type: &'static str) -> &'static str {
// Lock the inner as short as possible and clone what is needed to prevent deadlocks
let inner = &self.inner.read().unwrap();
(inner._env.build(), inner.config.clone(), inner._overrides.clone())
};
fn _get_form_type(rust_type: &'static str) -> &'static str {
match rust_type { match rust_type {
"Pass" => "password", "Pass" => "password",
"String" => "text", "String" => "text",
@ -375,7 +369,7 @@ macro_rules! make_config {
} }
} }
fn _get_doc(doc_str: &'static str) -> ElementDoc { fn get_doc(doc_str: &'static str) -> ElementDoc {
let mut split = doc_str.split("|>").map(str::trim); let mut split = doc_str.split("|>").map(str::trim);
ElementDoc { ElementDoc {
name: split.next().unwrap_or_default(), name: split.next().unwrap_or_default(),
@ -383,6 +377,12 @@ macro_rules! make_config {
} }
} }
let (def, cfg, overridden) = {
// Lock the inner as short as possible and clone what is needed to prevent deadlocks
let inner = &self.inner.read().unwrap();
(inner._env.build(), inner.config.clone(), inner._overrides.clone())
};
let data: Vec<GroupData> = vec![ let data: Vec<GroupData> = vec![
$( // This repetition is for each group $( // This repetition is for each group
GroupData { GroupData {
@ -397,8 +397,8 @@ macro_rules! make_config {
name: stringify!($name), name: stringify!($name),
value: serde_json::to_value(&cfg.$name).unwrap_or_default(), value: serde_json::to_value(&cfg.$name).unwrap_or_default(),
default: serde_json::to_value(&def.$name).unwrap_or_default(), default: serde_json::to_value(&def.$name).unwrap_or_default(),
r#type: _get_form_type(stringify!($ty)), r#type: get_form_type(stringify!($ty)),
doc: _get_doc(concat!($($doc),+)), doc: get_doc(concat!($($doc),+)),
overridden: overridden.contains(&pastey::paste!(stringify!([<$name:upper>]))), overridden: overridden.contains(&pastey::paste!(stringify!([<$name:upper>]))),
}, },
)+], // End of elements repetition )+], // End of elements repetition
@ -408,9 +408,31 @@ macro_rules! make_config {
} }
pub fn get_support_json(&self) -> serde_json::Value { pub fn get_support_json(&self) -> serde_json::Value {
/// We map over the string and remove all alphanumeric, _ and - characters.
/// This is the fastest way (within micro-seconds) instead of using a regex (which takes mili-seconds)
fn privacy_mask(value: &str) -> String {
let mut n: u16 = 0;
let mut colon_match = false;
value
.chars()
.map(|c| {
n += 1;
match c {
':' if n <= 11 => {
colon_match = true;
c
}
'/' if n <= 13 && colon_match => c,
',' => c,
_ => '*',
}
})
.collect::<String>()
}
// Define which config keys need to be masked. // Define which config keys need to be masked.
// Pass types will always be masked and no need to put them in the list. // Pass types will always be masked and no need to put them in the list.
// Besides Pass, only String types will be masked via _privacy_mask. // Besides Pass, only String types will be masked via privacy_mask.
const PRIVACY_CONFIG: &[&str] = &[ const PRIVACY_CONFIG: &[&str] = &[
"allowed_connect_src", "allowed_connect_src",
"allowed_iframe_ancestors", "allowed_iframe_ancestors",
@ -437,28 +459,6 @@ macro_rules! make_config {
inner.config.clone() inner.config.clone()
}; };
/// We map over the string and remove all alphanumeric, _ and - characters.
/// This is the fastest way (within micro-seconds) instead of using a regex (which takes mili-seconds)
fn _privacy_mask(value: &str) -> String {
let mut n: u16 = 0;
let mut colon_match = false;
value
.chars()
.map(|c| {
n += 1;
match c {
':' if n <= 11 => {
colon_match = true;
c
}
'/' if n <= 13 && colon_match => c,
',' => c,
_ => '*',
}
})
.collect::<String>()
}
serde_json::Value::Object({ serde_json::Value::Object({
let mut json = serde_json::Map::new(); let mut json = serde_json::Map::new();
$($( $($(
@ -468,7 +468,7 @@ macro_rules! make_config {
for mask_key in PRIVACY_CONFIG { for mask_key in PRIVACY_CONFIG {
if let Some(value) = json.get_mut(*mask_key) { if let Some(value) = json.get_mut(*mask_key) {
if let Some(s) = value.as_str() { if let Some(s) = value.as_str() {
*value = _privacy_mask(s).into(); *value = privacy_mask(s).into();
} }
} }
} }
@ -502,7 +502,7 @@ macro_rules! make_config {
make_config! { make_config! {
folders { folders {
/// Data folder |> Main data folder /// Data folder |> Main data folder
data_folder: String, false, def, "data".to_string(); data_folder: String, false, def, "data".to_owned();
/// Database URL /// Database URL
database_url: String, false, auto, |c| format!("sqlite://{}", storage::join_path(&c.data_folder, "db.sqlite3")); database_url: String, false, auto, |c| format!("sqlite://{}", storage::join_path(&c.data_folder, "db.sqlite3"));
/// Icon cache folder /// Icon cache folder
@ -518,7 +518,7 @@ make_config! {
/// Session JWT key /// Session JWT key
rsa_key_filename: String, false, auto, |c| storage::join_path(&c.data_folder, "rsa_key"); rsa_key_filename: String, false, auto, |c| storage::join_path(&c.data_folder, "rsa_key");
/// Web vault folder /// Web vault folder
web_vault_folder: String, false, def, "web-vault/".to_string(); web_vault_folder: String, false, def, "web-vault/".to_owned();
}, },
ws { ws {
/// Enable websocket notifications /// Enable websocket notifications
@ -528,9 +528,9 @@ make_config! {
/// Enable push notifications /// Enable push notifications
push_enabled: bool, false, def, false; push_enabled: bool, false, def, false;
/// Push relay uri /// Push relay uri
push_relay_uri: String, false, def, "https://push.bitwarden.com".to_string(); push_relay_uri: String, false, def, "https://push.bitwarden.com".to_owned();
/// Push identity uri /// Push identity uri
push_identity_uri: String, false, def, "https://identity.bitwarden.com".to_string(); push_identity_uri: String, false, def, "https://identity.bitwarden.com".to_owned();
/// Installation id |> The installation id from https://bitwarden.com/host /// Installation id |> The installation id from https://bitwarden.com/host
push_installation_id: Pass, false, def, String::new(); push_installation_id: Pass, false, def, String::new();
/// Installation key |> The installation key from https://bitwarden.com/host /// Installation key |> The installation key from https://bitwarden.com/host
@ -542,38 +542,38 @@ make_config! {
job_poll_interval_ms: u64, false, def, 30_000; job_poll_interval_ms: u64, false, def, 30_000;
/// Send purge schedule |> Cron schedule of the job that checks for Sends past their deletion date. /// Send purge schedule |> Cron schedule of the job that checks for Sends past their deletion date.
/// Defaults to hourly. Set blank to disable this job. /// Defaults to hourly. Set blank to disable this job.
send_purge_schedule: String, false, def, "0 5 * * * *".to_string(); send_purge_schedule: String, false, def, "0 5 * * * *".to_owned();
/// Trash purge schedule |> Cron schedule of the job that checks for trashed items to delete permanently. /// Trash purge schedule |> Cron schedule of the job that checks for trashed items to delete permanently.
/// Defaults to daily. Set blank to disable this job. /// Defaults to daily. Set blank to disable this job.
trash_purge_schedule: String, false, def, "0 5 0 * * *".to_string(); trash_purge_schedule: String, false, def, "0 5 0 * * *".to_owned();
/// Incomplete 2FA login schedule |> Cron schedule of the job that checks for incomplete 2FA logins. /// Incomplete 2FA login schedule |> Cron schedule of the job that checks for incomplete 2FA logins.
/// Defaults to once every minute. Set blank to disable this job. /// Defaults to once every minute. Set blank to disable this job.
incomplete_2fa_schedule: String, false, def, "30 * * * * *".to_string(); incomplete_2fa_schedule: String, false, def, "30 * * * * *".to_owned();
/// Emergency notification reminder schedule |> Cron schedule of the job that sends expiration reminders to emergency access grantors. /// Emergency notification reminder schedule |> Cron schedule of the job that sends expiration reminders to emergency access grantors.
/// Defaults to hourly. (3 minutes after the hour) Set blank to disable this job. /// Defaults to hourly. (3 minutes after the hour) Set blank to disable this job.
emergency_notification_reminder_schedule: String, false, def, "0 3 * * * *".to_string(); emergency_notification_reminder_schedule: String, false, def, "0 3 * * * *".to_owned();
/// Emergency request timeout schedule |> Cron schedule of the job that grants emergency access requests that have met the required wait time. /// Emergency request timeout schedule |> Cron schedule of the job that grants emergency access requests that have met the required wait time.
/// Defaults to hourly. (7 minutes after the hour) Set blank to disable this job. /// Defaults to hourly. (7 minutes after the hour) Set blank to disable this job.
emergency_request_timeout_schedule: String, false, def, "0 7 * * * *".to_string(); emergency_request_timeout_schedule: String, false, def, "0 7 * * * *".to_owned();
/// Event cleanup schedule |> Cron schedule of the job that cleans old events from the event table. /// Event cleanup schedule |> Cron schedule of the job that cleans old events from the event table.
/// Defaults to daily. Set blank to disable this job. /// Defaults to daily. Set blank to disable this job.
event_cleanup_schedule: String, false, def, "0 10 0 * * *".to_string(); event_cleanup_schedule: String, false, def, "0 10 0 * * *".to_owned();
/// Auth Request cleanup schedule |> Cron schedule of the job that cleans old auth requests from the auth request. /// Auth Request cleanup schedule |> Cron schedule of the job that cleans old auth requests from the auth request.
/// Defaults to every minute. Set blank to disable this job. /// Defaults to every minute. Set blank to disable this job.
auth_request_purge_schedule: String, false, def, "30 * * * * *".to_string(); auth_request_purge_schedule: String, false, def, "30 * * * * *".to_owned();
/// Duo Auth context cleanup schedule |> Cron schedule of the job that cleans expired Duo contexts from the database. Does nothing if Duo MFA is disabled or set to use the legacy iframe prompt. /// Duo Auth context cleanup schedule |> Cron schedule of the job that cleans expired Duo contexts from the database. Does nothing if Duo MFA is disabled or set to use the legacy iframe prompt.
/// Defaults to once every minute. Set blank to disable this job. /// Defaults to once every minute. Set blank to disable this job.
duo_context_purge_schedule: String, false, def, "30 * * * * *".to_string(); duo_context_purge_schedule: String, false, def, "30 * * * * *".to_owned();
/// Purge incomplete SSO auth. |> Cron schedule of the job that cleans leftover auth in db due to incomplete SSO login. /// Purge incomplete SSO auth. |> Cron schedule of the job that cleans leftover auth in db due to incomplete SSO login.
/// Defaults to daily. Set blank to disable this job. /// Defaults to daily. Set blank to disable this job.
purge_incomplete_sso_auth: String, false, def, "0 20 0 * * *".to_string(); purge_incomplete_sso_auth: String, false, def, "0 20 0 * * *".to_owned();
}, },
/// General settings /// General settings
settings { settings {
/// Domain URL |> This needs to be set to the URL used to access the server, including 'http[s]://' /// Domain URL |> This needs to be set to the URL used to access the server, including 'http[s]://'
/// and port, if it's different than the default. Some server functions don't work correctly without this value /// and port, if it's different than the default. Some server functions don't work correctly without this value
domain: String, true, def, "http://localhost".to_string(); domain: String, true, def, "http://localhost".to_owned();
/// Domain Set |> Indicates if the domain is set by the admin. Otherwise the default will be used. /// Domain Set |> Indicates if the domain is set by the admin. Otherwise the default will be used.
domain_set: bool, false, def, false; domain_set: bool, false, def, false;
/// Domain origin |> Domain URL origin (in https://example.com:8443/path, https://example.com:8443 is the origin) /// Domain origin |> Domain URL origin (in https://example.com:8443/path, https://example.com:8443 is the origin)
@ -653,7 +653,7 @@ make_config! {
admin_token: Pass, true, option; admin_token: Pass, true, option;
/// Invitation organization name |> Name shown in the invitation emails that don't come from a specific organization /// Invitation organization name |> Name shown in the invitation emails that don't come from a specific organization
invitation_org_name: String, true, def, "Vaultwarden".to_string(); invitation_org_name: String, true, def, "Vaultwarden".to_owned();
/// Events days retain |> Number of days to retain events stored in the database. If unset, events are kept indefinitely. /// Events days retain |> Number of days to retain events stored in the database. If unset, events are kept indefinitely.
events_days_retain: i64, false, option; events_days_retain: i64, false, option;
@ -663,7 +663,7 @@ make_config! {
advanced { advanced {
/// Client IP header |> If not present, the remote IP is used. /// Client IP header |> If not present, the remote IP is used.
/// Set to the string "none" (without quotes), to disable any headers and just use the remote IP /// Set to the string "none" (without quotes), to disable any headers and just use the remote IP
ip_header: String, true, def, "X-Real-IP".to_string(); ip_header: String, true, def, "X-Real-IP".to_owned();
/// Internal IP header property, used to avoid recomputing each time /// Internal IP header property, used to avoid recomputing each time
_ip_header_enabled: bool, false, generated, |c| &c.ip_header.trim().to_lowercase() != "none"; _ip_header_enabled: bool, false, generated, |c| &c.ip_header.trim().to_lowercase() != "none";
/// Icon service |> The predefined icon services are: internal, bitwarden, duckduckgo, google. /// Icon service |> The predefined icon services are: internal, bitwarden, duckduckgo, google.
@ -672,7 +672,7 @@ make_config! {
/// `internal` refers to Vaultwarden's built-in icon fetching implementation. If an external /// `internal` refers to Vaultwarden's built-in icon fetching implementation. If an external
/// service is set, an icon request to Vaultwarden will return an HTTP redirect to the /// service is set, an icon request to Vaultwarden will return an HTTP redirect to the
/// corresponding icon at the external service. /// corresponding icon at the external service.
icon_service: String, false, def, "internal".to_string(); icon_service: String, false, def, "internal".to_owned();
/// _icon_service_url /// _icon_service_url
_icon_service_url: String, false, generated, |c| generate_icon_service_url(&c.icon_service); _icon_service_url: String, false, generated, |c| generate_icon_service_url(&c.icon_service);
/// _icon_service_csp /// _icon_service_csp
@ -723,14 +723,14 @@ make_config! {
/// Enable extended logging /// Enable extended logging
extended_logging: bool, false, def, true; extended_logging: bool, false, def, true;
/// Log timestamp format /// Log timestamp format
log_timestamp_format: String, true, def, "%Y-%m-%d %H:%M:%S.%3f".to_string(); log_timestamp_format: String, true, def, "%Y-%m-%d %H:%M:%S.%3f".to_owned();
/// Enable the log to output to Syslog /// Enable the log to output to Syslog
use_syslog: bool, false, def, false; use_syslog: bool, false, def, false;
/// Log file path /// Log file path
log_file: String, false, option; log_file: String, false, option;
/// Log level |> Valid values are "trace", "debug", "info", "warn", "error" and "off" /// Log level |> Valid values are "trace", "debug", "info", "warn", "error" and "off"
/// For a specific module append it as a comma separated value "info,path::to::module=debug" /// For a specific module append it as a comma separated value "info,path::to::module=debug"
log_level: String, false, def, "info".to_string(); log_level: String, false, def, "info".to_owned();
/// Enable DB WAL |> Turning this off might lead to worse performance, but might help if using vaultwarden on some exotic filesystems, /// Enable DB WAL |> Turning this off might lead to worse performance, but might help if using vaultwarden on some exotic filesystems,
/// that do not support WAL. Please make sure you read project wiki on the topic before changing this setting. /// that do not support WAL. Please make sure you read project wiki on the topic before changing this setting.
@ -812,7 +812,7 @@ make_config! {
/// Authority Server |> Base url of the OIDC provider discovery endpoint (without `/.well-known/openid-configuration`) /// Authority Server |> Base url of the OIDC provider discovery endpoint (without `/.well-known/openid-configuration`)
sso_authority: String, true, def, String::new(); sso_authority: String, true, def, String::new();
/// Authorization request scopes |> List the of the needed scope (`openid` is implicit) /// Authorization request scopes |> List the of the needed scope (`openid` is implicit)
sso_scopes: String, true, def, "email profile".to_string(); sso_scopes: String, true, def, "email profile".to_owned();
/// Authorization request extra parameters /// Authorization request extra parameters
sso_authorize_extra_params: String, true, def, String::new(); sso_authorize_extra_params: String, true, def, String::new();
/// Use PKCE during Authorization flow /// Use PKCE during Authorization flow
@ -880,7 +880,7 @@ make_config! {
/// From Address /// From Address
smtp_from: String, true, def, String::new(); smtp_from: String, true, def, String::new();
/// From Name /// From Name
smtp_from_name: String, true, def, "Vaultwarden".to_string(); smtp_from_name: String, true, def, "Vaultwarden".to_owned();
/// Username /// Username
smtp_username: String, true, option; smtp_username: String, true, option;
/// Password /// Password
@ -930,8 +930,9 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
let file_path = url.strip_prefix("sqlite://").unwrap_or(url); let file_path = url.strip_prefix("sqlite://").unwrap_or(url);
if file_path.contains('/') { if file_path.contains('/') {
let path = std::path::Path::new(file_path); let path = std::path::Path::new(file_path);
if let Some(parent) = path.parent() { if let Some(parent) = path.parent()
if !parent.is_dir() { && !parent.is_dir()
{
err!(format!( err!(format!(
"SQLite database directory `{}` does not exist or is not a directory", "SQLite database directory `{}` does not exist or is not a directory",
parent.display() parent.display()
@ -940,7 +941,6 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
} }
} }
} }
}
if cfg.password_iterations < 100_000 { if cfg.password_iterations < 100_000 {
err!("PASSWORD_ITERATIONS should be at least 100000 or higher. The default is 600000!"); err!("PASSWORD_ITERATIONS should be at least 100000 or higher. The default is 600000!");
@ -959,11 +959,11 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
err!(format!("`DATABASE_MIN_CONNS` must be smaller than or equal to `DATABASE_MAX_CONNS`.",)); err!(format!("`DATABASE_MIN_CONNS` must be smaller than or equal to `DATABASE_MAX_CONNS`.",));
} }
if let Some(log_file) = &cfg.log_file { if let Some(log_file) = &cfg.log_file
if std::fs::OpenOptions::new().append(true).create(true).open(log_file).is_err() { && std::fs::OpenOptions::new().append(true).create(true).open(log_file).is_err()
{
err!("Unable to write to log file", log_file); err!("Unable to write to log file", log_file);
} }
}
let dom = cfg.domain.to_lowercase(); let dom = cfg.domain.to_lowercase();
if !dom.starts_with("http://") && !dom.starts_with("https://") { if !dom.starts_with("http://") && !dom.starts_with("https://") {
@ -975,7 +975,9 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
let connect_src = cfg.allowed_connect_src.to_lowercase(); let connect_src = cfg.allowed_connect_src.to_lowercase();
for url in connect_src.split_whitespace() { for url in connect_src.split_whitespace() {
if !url.starts_with("https://") || Url::parse(url).is_err() { if !url.starts_with("https://") || Url::parse(url).is_err() {
err!("ALLOWED_CONNECT_SRC variable contains one or more invalid URLs. Only FQDN's starting with https are allowed"); err!(
"ALLOWED_CONNECT_SRC variable contains one or more invalid URLs. Only FQDN's starting with https are allowed"
);
} }
} }
@ -991,12 +993,13 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
err!("`ORG_CREATION_USERS` contains invalid email addresses"); err!("`ORG_CREATION_USERS` contains invalid email addresses");
} }
if let Some(ref token) = cfg.admin_token { if let Some(ref token) = cfg.admin_token
if token.trim().is_empty() && !cfg.disable_admin_token { && token.trim().is_empty()
&& !cfg.disable_admin_token
{
println!("[WARNING] `ADMIN_TOKEN` is enabled but has an empty value, so the admin page will be disabled."); println!("[WARNING] `ADMIN_TOKEN` is enabled but has an empty value, so the admin page will be disabled.");
println!("[WARNING] To enable the admin page without a token, use `DISABLE_ADMIN_TOKEN`."); println!("[WARNING] To enable the admin page without a token, use `DISABLE_ADMIN_TOKEN`.");
} }
}
if cfg.push_enabled && (cfg.push_installation_id == String::new() || cfg.push_installation_key == String::new()) { if cfg.push_enabled && (cfg.push_installation_id == String::new() || cfg.push_installation_key == String::new()) {
err!( err!(
@ -1029,38 +1032,42 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
} }
} }
let invalid_flags = let invalid_flags = parse_experimental_client_feature_flags(
parse_experimental_client_feature_flags(&cfg.experimental_client_feature_flags, FeatureFlagFilter::InvalidOnly); &cfg.experimental_client_feature_flags,
&FeatureFlagFilter::InvalidOnly,
);
if !invalid_flags.is_empty() { if !invalid_flags.is_empty() {
let feature_flags_error = format!("Unrecognized experimental client feature flags: {:?}.\n\ let feature_flags_error = format!(
"Unrecognized experimental client feature flags: {invalid_flags:?}.\n\
Please ensure all feature flags are spelled correctly and that they are supported in this version.\n\ Please ensure all feature flags are spelled correctly and that they are supported in this version.\n\
Supported flags: {:?}\n", invalid_flags, SUPPORTED_FEATURE_FLAGS); Supported flags: {SUPPORTED_FEATURE_FLAGS:?}\n"
);
if on_update { if on_update {
err!(feature_flags_error); err!(feature_flags_error);
} else {
println!("[WARNING] {feature_flags_error}");
} }
println!("[WARNING] {feature_flags_error}");
} }
#[expect(clippy::items_after_statements, reason = "Keep this close to where it is used")]
const MAX_FILESIZE_KB: i64 = i64::MAX >> 10; const MAX_FILESIZE_KB: i64 = i64::MAX >> 10;
if let Some(limit) = cfg.user_attachment_limit { if let Some(limit) = cfg.user_attachment_limit
if !(0i64..=MAX_FILESIZE_KB).contains(&limit) { && !(0i64..=MAX_FILESIZE_KB).contains(&limit)
{
err!("`USER_ATTACHMENT_LIMIT` is out of bounds"); err!("`USER_ATTACHMENT_LIMIT` is out of bounds");
} }
}
if let Some(limit) = cfg.org_attachment_limit { if let Some(limit) = cfg.org_attachment_limit
if !(0i64..=MAX_FILESIZE_KB).contains(&limit) { && !(0i64..=MAX_FILESIZE_KB).contains(&limit)
{
err!("`ORG_ATTACHMENT_LIMIT` is out of bounds"); err!("`ORG_ATTACHMENT_LIMIT` is out of bounds");
} }
}
if let Some(limit) = cfg.user_send_limit { if let Some(limit) = cfg.user_send_limit
if !(0i64..=MAX_FILESIZE_KB).contains(&limit) { && !(0i64..=MAX_FILESIZE_KB).contains(&limit)
{
err!("`USER_SEND_LIMIT` is out of bounds"); err!("`USER_SEND_LIMIT` is out of bounds");
} }
}
if cfg._enable_duo if cfg._enable_duo
&& (cfg.duo_host.is_some() || cfg.duo_ikey.is_some() || cfg.duo_skey.is_some()) && (cfg.duo_host.is_some() || cfg.duo_ikey.is_some() || cfg.duo_skey.is_some())
@ -1087,7 +1094,9 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
if let Some(yubico_server) = &cfg.yubico_server { if let Some(yubico_server) = &cfg.yubico_server {
let yubico_server = yubico_server.to_lowercase(); let yubico_server = yubico_server.to_lowercase();
if !yubico_server.starts_with("https://") { if !yubico_server.starts_with("https://") {
err!("`YUBICO_SERVER` must be a valid URL and start with 'https://'. Either unset this variable or provide a valid URL.") err!(
"`YUBICO_SERVER` must be a valid URL and start with 'https://'. Either unset this variable or provide a valid URL."
)
} }
} }
} }
@ -1139,7 +1148,9 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
} }
if cfg.smtp_username.is_some() != cfg.smtp_password.is_some() { if cfg.smtp_username.is_some() != cfg.smtp_password.is_some() {
err!("Both `SMTP_USERNAME` and `SMTP_PASSWORD` need to be set to enable email authentication without `USE_SENDMAIL`") err!(
"Both `SMTP_USERNAME` and `SMTP_PASSWORD` need to be set to enable email authentication without `USE_SENDMAIL`"
)
} }
} }
@ -1300,7 +1311,7 @@ fn extract_url_origin(url: &str) -> String {
/// All trailing '/' chars are trimmed, even if the path is a lone '/'. /// All trailing '/' chars are trimmed, even if the path is a lone '/'.
fn extract_url_path(url: &str) -> String { fn extract_url_path(url: &str) -> String {
match Url::parse(url) { match Url::parse(url) {
Ok(u) => u.path().trim_end_matches('/').to_string(), Ok(u) => u.path().trim_end_matches('/').to_owned(),
Err(_) => { Err(_) => {
// We already print it in the method above, no need to do it again // We already print it in the method above, no need to do it again
String::new() String::new()
@ -1310,7 +1321,7 @@ fn extract_url_path(url: &str) -> String {
fn generate_smtp_img_src(embed_images: bool, domain: &str) -> String { fn generate_smtp_img_src(embed_images: bool, domain: &str) -> String {
if embed_images { if embed_images {
"cid:".to_string() "cid:".to_owned()
} else { } else {
// normalize base_url // normalize base_url
let base_url = domain.trim_end_matches('/'); let base_url = domain.trim_end_matches('/');
@ -1329,10 +1340,10 @@ fn generate_sso_callback_path(domain: &str) -> String {
fn generate_icon_service_url(icon_service: &str) -> String { fn generate_icon_service_url(icon_service: &str) -> String {
match icon_service { match icon_service {
"internal" => String::new(), "internal" => String::new(),
"bitwarden" => "https://icons.bitwarden.net/{}/icon.png".to_string(), "bitwarden" => "https://icons.bitwarden.net/{}/icon.png".to_owned(),
"duckduckgo" => "https://icons.duckduckgo.com/ip3/{}.ico".to_string(), "duckduckgo" => "https://icons.duckduckgo.com/ip3/{}.ico".to_owned(),
"google" => "https://www.google.com/s2/favicons?domain={}&sz=32".to_string(), "google" => "https://www.google.com/s2/favicons?domain={}&sz=32".to_owned(),
_ => icon_service.to_string(), _ => icon_service.to_owned(),
} }
} }
@ -1341,7 +1352,7 @@ fn generate_icon_service_csp(icon_service: &str, icon_service_url: &str) -> Stri
// We split on the first '{', since that is the variable delimiter for an icon service URL. // We split on the first '{', since that is the variable delimiter for an icon service URL.
// Everything up until the first '{' should be fixed and can be used as an CSP string. // Everything up until the first '{' should be fixed and can be used as an CSP string.
let csp_string = match icon_service_url.split_once('{') { let csp_string = match icon_service_url.split_once('{') {
Some((c, _)) => c.to_string(), Some((c, _)) => c.to_owned(),
None => String::new(), None => String::new(),
}; };
@ -1358,12 +1369,12 @@ fn smtp_convert_deprecated_ssl_options(smtp_ssl: Option<bool>, smtp_explicit_tls
println!("[DEPRECATED]: `SMTP_SSL` or `SMTP_EXPLICIT_TLS` is set. Please use `SMTP_SECURITY` instead."); println!("[DEPRECATED]: `SMTP_SSL` or `SMTP_EXPLICIT_TLS` is set. Please use `SMTP_SECURITY` instead.");
} }
if smtp_explicit_tls.is_some() && smtp_explicit_tls.unwrap() { if smtp_explicit_tls.is_some() && smtp_explicit_tls.unwrap() {
return "force_tls".to_string(); return "force_tls".to_owned();
} else if smtp_ssl.is_some() && !smtp_ssl.unwrap() { } else if smtp_ssl.is_some() && !smtp_ssl.unwrap() {
return "off".to_string(); return "off".to_owned();
} }
// Return the default `starttls` in all other cases // Return the default `starttls` in all other cases
"starttls".to_string() "starttls".to_owned()
} }
pub enum PathType { pub enum PathType {
@ -1406,12 +1417,12 @@ pub const SUPPORTED_FEATURE_FLAGS: &[&str] = &[
impl Config { impl Config {
pub async fn load() -> Result<Self, Error> { pub async fn load() -> Result<Self, Error> {
// Loading from env and file // Loading from env and file
let _env = ConfigBuilder::from_env(); let env = ConfigBuilder::from_env();
let _usr = ConfigBuilder::from_file().await.unwrap_or_default(); let usr = ConfigBuilder::from_file().await.unwrap_or_default();
// Create merged config, config file overwrites env // Create merged config, config file overwrites env
let mut _overrides = Vec::new(); let mut overrides = Vec::new();
let builder = _env.merge(&_usr, true, &mut _overrides); let builder = env.merge(&usr, true, &mut overrides);
// Fill any missing with defaults // Fill any missing with defaults
let config = builder.build(); let config = builder.build();
@ -1424,9 +1435,9 @@ impl Config {
rocket_shutdown_handle: None, rocket_shutdown_handle: None,
templates: load_templates(&config.templates_folder), templates: load_templates(&config.templates_folder),
config, config,
_env, _env: env,
_usr, _usr: usr,
_overrides, _overrides: overrides,
}), }),
}) })
} }
@ -1472,8 +1483,8 @@ impl Config {
async fn update_config_partial(&self, other: ConfigBuilder) -> Result<(), Error> { async fn update_config_partial(&self, other: ConfigBuilder) -> Result<(), Error> {
let builder = { let builder = {
let usr = &self.inner.read().unwrap()._usr; let usr = &self.inner.read().unwrap()._usr;
let mut _overrides = Vec::new(); let mut overrides = Vec::new();
usr.merge(&other, false, &mut _overrides) usr.merge(&other, false, &mut overrides)
}; };
self.update_config(builder, false).await self.update_config(builder, false).await
} }
@ -1496,11 +1507,11 @@ impl Config {
/// Tests whether signup is allowed for an email address, taking into /// Tests whether signup is allowed for an email address, taking into
/// account the signups_allowed and signups_domains_whitelist settings. /// account the signups_allowed and signups_domains_whitelist settings.
pub fn is_signup_allowed(&self, email: &str) -> bool { pub fn is_signup_allowed(&self, email: &str) -> bool {
if !self.signups_domains_whitelist().is_empty() { if self.signups_domains_whitelist().is_empty() {
self.signups_allowed()
} else {
// The whitelist setting overrides the signups_allowed setting. // The whitelist setting overrides the signups_allowed setting.
self.is_email_domain_allowed(email) self.is_email_domain_allowed(email)
} else {
self.signups_allowed()
} }
} }
@ -1621,12 +1632,12 @@ impl Config {
} }
pub fn shutdown(&self) { pub fn shutdown(&self) {
if let Ok(mut c) = self.inner.write() { if let Ok(mut c) = self.inner.write()
if let Some(handle) = c.rocket_shutdown_handle.take() { && let Some(handle) = c.rocket_shutdown_handle.take()
{
handle.notify(); handle.notify();
} }
} }
}
pub fn sso_issuer_url(&self) -> Result<openidconnect::IssuerUrl, Error> { pub fn sso_issuer_url(&self) -> Result<openidconnect::IssuerUrl, Error> {
validate_internal_sso_issuer_url(&self.sso_authority()) validate_internal_sso_issuer_url(&self.sso_authority())
@ -1641,7 +1652,7 @@ impl Config {
} }
pub fn sso_scopes_vec(&self) -> Vec<String> { pub fn sso_scopes_vec(&self) -> Vec<String> {
self.sso_scopes().split_whitespace().map(str::to_string).collect() self.sso_scopes().split_whitespace().map(str::to_owned).collect()
} }
pub fn sso_authorize_extra_params_vec(&self) -> Vec<(String, String)> { pub fn sso_authorize_extra_params_vec(&self) -> Vec<(String, String)> {
@ -1751,7 +1762,7 @@ fn case_helper<'reg, 'rc>(
let value = param.value().clone(); let value = param.value().clone();
if h.params().iter().skip(1).any(|x| x.value() == &value) { if h.params().iter().skip(1).any(|x| x.value() == &value) {
h.template().map(|t| t.render(r, ctx, rc, out)).unwrap_or_else(|| Ok(())) h.template().map_or(Ok(()), |t| t.render(r, ctx, rc, out))
} else { } else {
Ok(()) Ok(())
} }

29
src/db/mod.rs

@ -6,25 +6,23 @@ use std::{
}; };
use diesel::{ use diesel::{
Connection, RunQueryDsl,
connection::SimpleConnection, connection::SimpleConnection,
r2d2::{CustomizeConnection, Pool, PooledConnection}, r2d2::{CustomizeConnection, Pool, PooledConnection},
Connection, RunQueryDsl,
}; };
use rocket::{ use rocket::{
Request,
http::Status, http::Status,
request::{FromRequest, Outcome}, request::{FromRequest, Outcome},
Request,
}; };
use tokio::{ use tokio::{
sync::{Mutex, OwnedSemaphorePermit, Semaphore}, sync::{Mutex, OwnedSemaphorePermit, Semaphore},
time::timeout, time::timeout,
}; };
use crate::{ use crate::{
error::{Error, MapResult},
CONFIG, CONFIG,
error::{Error, MapResult},
}; };
// These changes are based on Rocket 0.5-rc wrapper of Diesel: https://github.com/SergioBenitez/Rocket/blob/v0.5-rc/contrib/sync_db_pools // These changes are based on Rocket 0.5-rc wrapper of Diesel: https://github.com/SergioBenitez/Rocket/blob/v0.5-rc/contrib/sync_db_pools
@ -62,7 +60,7 @@ pub struct DbConnManager {
impl DbConnManager { impl DbConnManager {
pub fn new(database_url: &str) -> Self { pub fn new(database_url: &str) -> Self {
Self { Self {
database_url: database_url.to_string(), database_url: database_url.to_owned(),
} }
} }
@ -224,7 +222,7 @@ impl DbPool {
// Set a global to determine the database more easily throughout the rest of the code // Set a global to determine the database more easily throughout the rest of the code
if ACTIVE_DB_TYPE.set(conn_type).is_err() { if ACTIVE_DB_TYPE.set(conn_type).is_err() {
error!("Tried to set the active database connection type more than once.") error!("Tried to set the active database connection type more than once.");
} }
Ok(DbPool { Ok(DbPool {
@ -279,10 +277,10 @@ impl DbConnType {
#[cfg(not(sqlite))] #[cfg(not(sqlite))]
err!("`DATABASE_URL` is a SQLite URL, but the 'sqlite' feature is not enabled") err!("`DATABASE_URL` is a SQLite URL, but the 'sqlite' feature is not enabled")
}
// No recognized scheme — assume legacy bare-path SQLite, but the database file must already exist. // No recognized scheme — assume legacy bare-path SQLite, but the database file must already exist.
// This prevents misconfigured URLs (typos, quoted strings) from silently creating a new empty SQLite database. // This prevents misconfigured URLs (typos, quoted strings) from silently creating a new empty SQLite database.
} else {
#[cfg(sqlite)] #[cfg(sqlite)]
{ {
if std::path::Path::new(url).exists() { if std::path::Path::new(url).exists() {
@ -299,14 +297,13 @@ impl DbConnType {
#[cfg(not(sqlite))] #[cfg(not(sqlite))]
err!("`DATABASE_URL` does not match any known database scheme (mysql://, postgresql://, sqlite://)") err!("`DATABASE_URL` does not match any known database scheme (mysql://, postgresql://, sqlite://)")
} }
}
pub fn get_init_stmts(&self) -> String { pub fn get_init_stmts(&self) -> String {
let init_stmts = CONFIG.database_conn_init(); let init_stmts = CONFIG.database_conn_init();
if !init_stmts.is_empty() { if init_stmts.is_empty() {
init_stmts
} else {
self.default_init_stmts() self.default_init_stmts()
} else {
init_stmts
} }
} }
@ -317,7 +314,7 @@ impl DbConnType {
#[cfg(postgresql)] #[cfg(postgresql)]
Self::Postgresql => String::new(), Self::Postgresql => String::new(),
#[cfg(sqlite)] #[cfg(sqlite)]
Self::Sqlite => "PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;".to_string(), Self::Sqlite => "PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;".to_owned(),
} }
} }
} }
@ -408,7 +405,7 @@ pub fn backup_sqlite() -> Result<String, Error> {
use diesel::Connection; use diesel::Connection;
let db_url = CONFIG.database_url(); let db_url = CONFIG.database_url();
if DbConnType::from_url(&CONFIG.database_url()).map(|t| t == DbConnType::Sqlite).unwrap_or(false) { if DbConnType::from_url(&CONFIG.database_url()).is_ok_and(|t| t == DbConnType::Sqlite) {
// Strip the sqlite:// prefix if present to get the raw file path // Strip the sqlite:// prefix if present to get the raw file path
let file_path = db_url.strip_prefix("sqlite://").unwrap_or(&db_url); let file_path = db_url.strip_prefix("sqlite://").unwrap_or(&db_url);
// Open a read-only connection for the backup // Open a read-only connection for the backup
@ -443,12 +440,12 @@ pub async fn get_sql_server_version(conn: &DbConn) -> String {
postgresql,mysql { postgresql,mysql {
diesel::select(diesel::dsl::sql::<diesel::sql_types::Text>("version();")) diesel::select(diesel::dsl::sql::<diesel::sql_types::Text>("version();"))
.get_result::<String>(conn) .get_result::<String>(conn)
.unwrap_or_else(|_| "Unknown".to_string()) .unwrap_or_else(|_| "Unknown".to_owned())
} }
sqlite { sqlite {
diesel::select(diesel::dsl::sql::<diesel::sql_types::Text>("sqlite_version();")) diesel::select(diesel::dsl::sql::<diesel::sql_types::Text>("sqlite_version();"))
.get_result::<String>(conn) .get_result::<String>(conn)
.unwrap_or_else(|_| "Unknown".to_string()) .unwrap_or_else(|_| "Unknown".to_owned())
} }
} }
} }

32
src/db/models/archive.rs

@ -1,11 +1,13 @@
use chrono::NaiveDateTime; use chrono::NaiveDateTime;
use diesel::prelude::*; use diesel::prelude::*;
use crate::{
api::EmptyResult,
db::{DbConn, schema::archives},
error::MapResult,
};
use super::{CipherId, User, UserId}; use super::{CipherId, User, UserId};
use crate::api::EmptyResult;
use crate::db::schema::archives;
use crate::db::DbConn;
use crate::error::MapResult;
#[derive(Identifiable, Queryable, Insertable)] #[derive(Identifiable, Queryable, Insertable)]
#[diesel(table_name = archives)] #[diesel(table_name = archives)]
@ -19,13 +21,15 @@ pub struct Archive {
impl Archive { impl Archive {
// Returns the date the specified cipher was archived // Returns the date the specified cipher was archived
pub async fn get_archived_at(cipher_uuid: &CipherId, user_uuid: &UserId, conn: &DbConn) -> Option<NaiveDateTime> { pub async fn get_archived_at(cipher_uuid: &CipherId, user_uuid: &UserId, conn: &DbConn) -> Option<NaiveDateTime> {
db_run! { conn: { conn.run(move |conn| {
archives::table archives::table
.filter(archives::cipher_uuid.eq(cipher_uuid)) .filter(archives::cipher_uuid.eq(cipher_uuid))
.filter(archives::user_uuid.eq(user_uuid)) .filter(archives::user_uuid.eq(user_uuid))
.select(archives::archived_at) .select(archives::archived_at)
.first::<NaiveDateTime>(conn).ok() .first::<NaiveDateTime>(conn)
}} .ok()
})
.await
} }
// Saves (inserts or updates) an archive record with the provided timestamp // Saves (inserts or updates) an archive record with the provided timestamp
@ -66,26 +70,26 @@ impl Archive {
// Deletes an archive record for a specific cipher // Deletes an archive record for a specific cipher
pub async fn delete_by_cipher(user_uuid: &UserId, cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult { pub async fn delete_by_cipher(user_uuid: &UserId, cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(user_uuid, conn).await; User::update_uuid_revision(user_uuid, conn).await;
db_run! { conn: { conn.run(move |conn| {
diesel::delete( diesel::delete(
archives::table archives::table.filter(archives::user_uuid.eq(user_uuid)).filter(archives::cipher_uuid.eq(cipher_uuid)),
.filter(archives::user_uuid.eq(user_uuid))
.filter(archives::cipher_uuid.eq(cipher_uuid))
) )
.execute(conn) .execute(conn)
.map_res("Error deleting archive") .map_res("Error deleting archive")
}} })
.await
} }
/// Return a vec with (cipher_uuid, archived_at) /// Return a vec with (cipher_uuid, archived_at)
/// This is used during a full sync so we only need one query for all archive matches /// This is used during a full sync so we only need one query for all archive matches
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<(CipherId, NaiveDateTime)> { pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<(CipherId, NaiveDateTime)> {
db_run! { conn: { conn.run(move |conn| {
archives::table archives::table
.filter(archives::user_uuid.eq(user_uuid)) .filter(archives::user_uuid.eq(user_uuid))
.select((archives::cipher_uuid, archives::archived_at)) .select((archives::cipher_uuid, archives::archived_at))
.load::<(CipherId, NaiveDateTime)>(conn) .load::<(CipherId, NaiveDateTime)>(conn)
.unwrap_or_default() .unwrap_or_default()
}} })
.await
} }
} }

77
src/db/models/attachment.rs

@ -1,14 +1,25 @@
use std::time::Duration;
use bigdecimal::{BigDecimal, ToPrimitive}; use bigdecimal::{BigDecimal, ToPrimitive};
use derive_more::{AsRef, Deref, Display}; use derive_more::{AsRef, Deref, Display};
use diesel::prelude::*; use diesel::prelude::*;
use serde_json::Value; use serde_json::Value;
use std::time::Duration;
use super::{CipherId, OrganizationId, UserId}; use crate::{
use crate::db::schema::{attachments, ciphers}; CONFIG,
use crate::{config::PathType, CONFIG}; api::EmptyResult,
auth::{encode_jwt, generate_file_download_claims},
config::PathType,
db::{
DbConn,
schema::{attachments, ciphers},
},
error::MapResult,
};
use macros::IdFromParam; use macros::IdFromParam;
use super::{CipherId, OrganizationId, UserId};
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = attachments)] #[diesel(table_name = attachments)]
#[diesel(treat_none_as_null = true)] #[diesel(treat_none_as_null = true)]
@ -67,12 +78,6 @@ impl Attachment {
} }
} }
use crate::auth::{encode_jwt, generate_file_download_claims};
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
/// Database methods /// Database methods
impl Attachment { impl Attachment {
pub async fn save(&self, conn: &DbConn) -> EmptyResult { pub async fn save(&self, conn: &DbConn) -> EmptyResult {
@ -107,15 +112,15 @@ impl Attachment {
} }
pub async fn delete(&self, conn: &DbConn) -> EmptyResult { pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
crate::util::retry(|| crate::util::retry(
diesel::delete(attachments::table.filter(attachments::id.eq(&self.id))) || diesel::delete(attachments::table.filter(attachments::id.eq(&self.id))).execute(conn),
.execute(conn),
10, 10,
) )
.map(|_| ()) .map(|_| ())
.map_res("Error deleting attachment") .map_res("Error deleting attachment")
}}?; })
.await?;
let operator = CONFIG.opendal_operator_for_path_type(&PathType::Attachments)?; let operator = CONFIG.opendal_operator_for_path_type(&PathType::Attachments)?;
let file_path = self.get_file_path(); let file_path = self.get_file_path();
@ -139,25 +144,22 @@ impl Attachment {
} }
pub async fn find_by_id(id: &AttachmentId, conn: &DbConn) -> Option<Self> { pub async fn find_by_id(id: &AttachmentId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| attachments::table.filter(attachments::id.eq(id.to_lowercase())).first::<Self>(conn).ok())
attachments::table .await
.filter(attachments::id.eq(id.to_lowercase()))
.first::<Self>(conn)
.ok()
}}
} }
pub async fn find_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
attachments::table attachments::table
.filter(attachments::cipher_uuid.eq(cipher_uuid)) .filter(attachments::cipher_uuid.eq(cipher_uuid))
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading attachments") .expect("Error loading attachments")
}} })
.await
} }
pub async fn size_by_user(user_uuid: &UserId, conn: &DbConn) -> i64 { pub async fn size_by_user(user_uuid: &UserId, conn: &DbConn) -> i64 {
db_run! { conn: { conn.run(move |conn| {
let result: Option<BigDecimal> = attachments::table let result: Option<BigDecimal> = attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid))) .left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
.filter(ciphers::user_uuid.eq(user_uuid)) .filter(ciphers::user_uuid.eq(user_uuid))
@ -168,24 +170,26 @@ impl Attachment {
match result.map(|r| r.to_i64()) { match result.map(|r| r.to_i64()) {
Some(Some(r)) => r, Some(Some(r)) => r,
Some(None) => i64::MAX, Some(None) => i64::MAX,
None => 0 None => 0,
} }
}} })
.await
} }
pub async fn count_by_user(user_uuid: &UserId, conn: &DbConn) -> i64 { pub async fn count_by_user(user_uuid: &UserId, conn: &DbConn) -> i64 {
db_run! { conn: { conn.run(move |conn| {
attachments::table attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid))) .left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
.filter(ciphers::user_uuid.eq(user_uuid)) .filter(ciphers::user_uuid.eq(user_uuid))
.count() .count()
.first(conn) .first(conn)
.unwrap_or(0) .unwrap_or(0)
}} })
.await
} }
pub async fn size_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 { pub async fn size_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: { conn.run(move |conn| {
let result: Option<BigDecimal> = attachments::table let result: Option<BigDecimal> = attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid))) .left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
.filter(ciphers::organization_uuid.eq(org_uuid)) .filter(ciphers::organization_uuid.eq(org_uuid))
@ -196,20 +200,22 @@ impl Attachment {
match result.map(|r| r.to_i64()) { match result.map(|r| r.to_i64()) {
Some(Some(r)) => r, Some(Some(r)) => r,
Some(None) => i64::MAX, Some(None) => i64::MAX,
None => 0 None => 0,
} }
}} })
.await
} }
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 { pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: { conn.run(move |conn| {
attachments::table attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid))) .left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
.filter(ciphers::organization_uuid.eq(org_uuid)) .filter(ciphers::organization_uuid.eq(org_uuid))
.count() .count()
.first(conn) .first(conn)
.unwrap_or(0) .unwrap_or(0)
}} })
.await
} }
// This will return all attachments linked to the user or org // This will return all attachments linked to the user or org
@ -220,7 +226,7 @@ impl Attachment {
org_uuids: &Vec<OrganizationId>, org_uuids: &Vec<OrganizationId>,
conn: &DbConn, conn: &DbConn,
) -> Vec<Self> { ) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
attachments::table attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid))) .left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
.filter(ciphers::user_uuid.eq(user_uuid)) .filter(ciphers::user_uuid.eq(user_uuid))
@ -228,7 +234,8 @@ impl Attachment {
.select(attachments::all_columns) .select(attachments::all_columns)
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading attachments") .expect("Error loading attachments")
}} })
.await
} }
} }

52
src/db/models/auth_request.rs

@ -1,12 +1,19 @@
use super::{DeviceId, OrganizationId, UserId};
use crate::db::schema::auth_requests;
use crate::{crypto::ct_eq, util::format_date};
use chrono::{NaiveDateTime, Utc}; use chrono::{NaiveDateTime, Utc};
use derive_more::{AsRef, Deref, Display, From}; use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*; use diesel::prelude::*;
use macros::UuidFromParam;
use serde_json::Value; use serde_json::Value;
use crate::{
api::EmptyResult,
crypto::ct_eq,
db::{DbConn, schema::auth_requests},
error::MapResult,
util::format_date,
};
use macros::UuidFromParam;
use super::{DeviceId, OrganizationId, UserId};
#[derive(Identifiable, Queryable, Insertable, AsChangeset, Deserialize, Serialize)] #[derive(Identifiable, Queryable, Insertable, AsChangeset, Deserialize, Serialize)]
#[diesel(table_name = auth_requests)] #[diesel(table_name = auth_requests)]
#[diesel(treat_none_as_null = true)] #[diesel(treat_none_as_null = true)]
@ -74,11 +81,6 @@ impl AuthRequest {
} }
} }
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
impl AuthRequest { impl AuthRequest {
pub async fn save(&mut self, conn: &DbConn) -> EmptyResult { pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
db_run! { conn: db_run! { conn:
@ -112,31 +114,28 @@ impl AuthRequest {
} }
pub async fn find_by_uuid(uuid: &AuthRequestId, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &AuthRequestId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| auth_requests::table.filter(auth_requests::uuid.eq(uuid)).first::<Self>(conn).ok()).await
auth_requests::table
.filter(auth_requests::uuid.eq(uuid))
.first::<Self>(conn)
.ok()
}}
} }
pub async fn find_by_uuid_and_user(uuid: &AuthRequestId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid_and_user(uuid: &AuthRequestId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
auth_requests::table auth_requests::table
.filter(auth_requests::uuid.eq(uuid)) .filter(auth_requests::uuid.eq(uuid))
.filter(auth_requests::user_uuid.eq(user_uuid)) .filter(auth_requests::user_uuid.eq(user_uuid))
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
auth_requests::table auth_requests::table
.filter(auth_requests::user_uuid.eq(user_uuid)) .filter(auth_requests::user_uuid.eq(user_uuid))
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading auth_requests") .expect("Error loading auth_requests")
}} })
.await
} }
pub async fn find_by_user_and_requested_device( pub async fn find_by_user_and_requested_device(
@ -144,7 +143,7 @@ impl AuthRequest {
device_uuid: &DeviceId, device_uuid: &DeviceId,
conn: &DbConn, conn: &DbConn,
) -> Option<Self> { ) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
auth_requests::table auth_requests::table
.filter(auth_requests::user_uuid.eq(user_uuid)) .filter(auth_requests::user_uuid.eq(user_uuid))
.filter(auth_requests::request_device_identifier.eq(device_uuid)) .filter(auth_requests::request_device_identifier.eq(device_uuid))
@ -152,24 +151,27 @@ impl AuthRequest {
.order_by(auth_requests::creation_date.desc()) .order_by(auth_requests::creation_date.desc())
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn find_created_before(dt: &NaiveDateTime, conn: &DbConn) -> Vec<Self> { pub async fn find_created_before(dt: &NaiveDateTime, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
auth_requests::table auth_requests::table
.filter(auth_requests::creation_date.lt(dt)) .filter(auth_requests::creation_date.lt(dt))
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading auth_requests") .expect("Error loading auth_requests")
}} })
.await
} }
pub async fn delete(&self, conn: &DbConn) -> EmptyResult { pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
diesel::delete(auth_requests::table.filter(auth_requests::uuid.eq(&self.uuid))) diesel::delete(auth_requests::table.filter(auth_requests::uuid.eq(&self.uuid)))
.execute(conn) .execute(conn)
.map_res("Error deleting auth request") .map_res("Error deleting auth request")
}} })
.await
} }
pub fn check_access_code(&self, access_code: &str) -> bool { pub fn check_access_code(&self, access_code: &str) -> bool {

598
src/db/models/cipher.rs

@ -1,22 +1,32 @@
use crate::db::schema::{ use std::borrow::Cow;
ciphers, ciphers_collections, collections, collections_groups, folders, folders_ciphers, groups, groups_users,
users_collections, users_organizations,
};
use crate::util::LowerCase;
use crate::CONFIG;
use chrono::{NaiveDateTime, TimeDelta, Utc}; use chrono::{NaiveDateTime, TimeDelta, Utc};
use derive_more::{AsRef, Deref, Display, From}; use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*; use diesel::prelude::*;
use serde_json::Value; use serde_json::Value;
use crate::{
CONFIG,
api::{
EmptyResult,
core::{CipherData, CipherSyncData, CipherSyncType},
},
db::{
DbConn,
schema::{
ciphers, ciphers_collections, collections, collections_groups, folders, folders_ciphers, groups,
groups_users, users_collections, users_organizations,
},
},
error::MapResult,
util::LowerCase,
};
use macros::UuidFromParam;
use super::{ use super::{
Archive, Attachment, CollectionCipher, CollectionId, Favorite, FolderCipher, FolderId, Group, Membership, Archive, Attachment, CollectionCipher, CollectionId, Favorite, FolderCipher, FolderId, Group, Membership,
MembershipStatus, MembershipType, OrganizationId, User, UserId, MembershipStatus, MembershipType, OrganizationId, User, UserId,
}; };
use crate::api::core::{CipherData, CipherSyncData, CipherSyncType};
use macros::UuidFromParam;
use std::borrow::Cow;
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = ciphers)] #[diesel(table_name = ciphers)]
@ -91,18 +101,19 @@ impl Cipher {
format!("The field Notes exceeds the maximum encrypted value length of {max_note_size} characters."); format!("The field Notes exceeds the maximum encrypted value length of {max_note_size} characters.");
for (index, cipher) in cipher_data.iter().enumerate() { for (index, cipher) in cipher_data.iter().enumerate() {
// Validate the note size and if it is exceeded return a warning // Validate the note size and if it is exceeded return a warning
if let Some(note) = &cipher.notes { if let Some(note) = &cipher.notes
if note.len() > max_note_size { && note.len() > max_note_size
{
validation_errors validation_errors
.insert(format!("Ciphers[{index}].Notes"), serde_json::to_value([&max_note_size_msg]).unwrap()); .insert(format!("Ciphers[{index}].Notes"), serde_json::to_value([&max_note_size_msg]).unwrap());
} }
}
// Validate the password history if it contains `null` values and if so, return a warning // Validate the password history if it contains `null` values and if so, return a warning
if let Some(Value::Array(password_history)) = &cipher.password_history { if let Some(Value::Array(password_history)) = &cipher.password_history {
for pwh in password_history { for pwh in password_history {
if let Value::Object(pwo) = pwh { if let Value::Object(pwo) = pwh
if pwo.get("password").is_some_and(|p| !p.is_string()) { && pwo.get("password").is_some_and(|p| !p.is_string())
{
validation_errors.insert( validation_errors.insert(
format!("Ciphers[{index}].Notes"), format!("Ciphers[{index}].Notes"),
serde_json::to_value([ serde_json::to_value([
@ -115,7 +126,6 @@ impl Cipher {
} }
} }
} }
}
if !validation_errors.is_empty() { if !validation_errors.is_empty() {
let err_json = json!({ let err_json = json!({
@ -124,17 +134,12 @@ impl Cipher {
"object": "error" "object": "error"
}); });
err_json!(err_json, "Import validation errors") err_json!(err_json, "Import validation errors")
} else {
Ok(())
} }
Ok(())
} }
} }
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
/// Database methods /// Database methods
impl Cipher { impl Cipher {
pub async fn to_json( pub async fn to_json(
@ -149,15 +154,15 @@ impl Cipher {
let mut attachments_json: Value = Value::Null; let mut attachments_json: Value = Value::Null;
if let Some(cipher_sync_data) = cipher_sync_data { if let Some(cipher_sync_data) = cipher_sync_data {
if let Some(attachments) = cipher_sync_data.cipher_attachments.get(&self.uuid) { if let Some(attachments) = cipher_sync_data.cipher_attachments.get(&self.uuid)
if !attachments.is_empty() { && !attachments.is_empty()
{
let mut attachments_json_vec = vec![]; let mut attachments_json_vec = vec![];
for attachment in attachments { for attachment in attachments {
attachments_json_vec.push(attachment.to_json(host).await?); attachments_json_vec.push(attachment.to_json(host).await?);
} }
attachments_json = Value::Array(attachments_json_vec); attachments_json = Value::Array(attachments_json_vec);
} }
}
} else { } else {
let attachments = Attachment::find_by_cipher(&self.uuid, conn).await; let attachments = Attachment::find_by_cipher(&self.uuid, conn).await;
if !attachments.is_empty() { if !attachments.is_empty() {
@ -172,13 +177,12 @@ impl Cipher {
// We don't need these values at all for Organizational syncs // We don't need these values at all for Organizational syncs
// Skip any other database calls if this is the case and just return false. // Skip any other database calls if this is the case and just return false.
let (read_only, hide_passwords, _) = if sync_type == CipherSyncType::User { let (read_only, hide_passwords, _) = if sync_type == CipherSyncType::User {
match self.get_access_restrictions(user_uuid, cipher_sync_data, conn).await { if let Some((ro, hp, mn)) = self.get_access_restrictions(user_uuid, cipher_sync_data, conn).await {
Some((ro, hp, mn)) => (ro, hp, mn), (ro, hp, mn)
None => { } else {
error!("Cipher ownership assertion failure"); error!("Cipher ownership assertion failure");
(true, true, false) (true, true, false)
} }
}
} else { } else {
(false, false, false) (false, false, false)
}; };
@ -231,15 +235,14 @@ impl Cipher {
Some(p) if p.is_string() => Some(d.data), Some(p) if p.is_string() => Some(d.data),
_ => None, _ => None,
}) })
.map(|mut d| match d.get("lastUsedDate").and_then(|l| l.as_str()) { .map(|mut d| {
Some(l) => { let lud = if let Some(l) = d.get("lastUsedDate").and_then(|l| l.as_str()) {
d["lastUsedDate"] = json!(validate_and_format_date(l)); validate_and_format_date(l)
d } else {
} "1970-01-01T00:00:00.000000Z".to_owned()
_ => { };
d["lastUsedDate"] = json!("1970-01-01T00:00:00.000000Z"); d["lastUsedDate"] = json!(lud);
d d
}
}) })
.collect() .collect()
}) })
@ -247,19 +250,18 @@ impl Cipher {
// Get the type_data or a default to an empty json object '{}'. // Get the type_data or a default to an empty json object '{}'.
// If not passing an empty object, mobile clients will crash. // If not passing an empty object, mobile clients will crash.
let mut type_data_json = let mut type_data_json = serde_json::from_str::<LowerCase<Value>>(&self.data)
serde_json::from_str::<LowerCase<Value>>(&self.data).map(|d| d.data).unwrap_or_else(|_| { .inspect_err(|_| warn!("Error parsing data field for {}", self.uuid))
warn!("Error parsing data field for {}", self.uuid); .map_or_else(|_| Value::Object(serde_json::Map::new()), |d| d.data);
Value::Object(serde_json::Map::new())
});
// NOTE: This was marked as *Backwards Compatibility Code*, but as of January 2021 this is still being used by upstream // NOTE: This was marked as *Backwards Compatibility Code*, but as of January 2021 this is still being used by upstream
// Set the first element of the Uris array as Uri, this is needed several (mobile) clients. // Set the first element of the Uris array as Uri, this is needed several (mobile) clients.
if self.atype == 1 { if self.atype == 1 {
// Upstream always has an `uri` key/value // Upstream always has an `uri` key/value
type_data_json["uri"] = Value::Null; type_data_json["uri"] = Value::Null;
if let Some(uris) = type_data_json["uris"].as_array_mut() { if let Some(uris) = type_data_json["uris"].as_array_mut()
if !uris.is_empty() { && !uris.is_empty()
{
// Fix uri match values first, they are only allowed to be a number or null // Fix uri match values first, they are only allowed to be a number or null
// If it is a string, convert it to an int or null if that fails // If it is a string, convert it to an int or null if that fails
for uri in &mut *uris { for uri in &mut *uris {
@ -273,7 +275,6 @@ impl Cipher {
} }
type_data_json["uri"] = uris[0]["uri"].clone(); type_data_json["uri"] = uris[0]["uri"].clone();
} }
}
// Check if `passwordRevisionDate` is a valid date, else convert it // Check if `passwordRevisionDate` is a valid date, else convert it
if let Some(pw_revision) = type_data_json["passwordRevisionDate"].as_str() { if let Some(pw_revision) = type_data_json["passwordRevisionDate"].as_str() {
@ -285,7 +286,7 @@ impl Cipher {
// This breaks at least the native mobile clients // This breaks at least the native mobile clients
if self.atype == 2 { if self.atype == 2 {
match type_data_json { match type_data_json {
Value::Object(ref t) if t.get("type").is_some_and(|t| t.is_number()) => {} Value::Object(ref t) if t.get("type").is_some_and(Value::is_number) => {}
_ => { _ => {
type_data_json = json!({"type": 0}); type_data_json = json!({"type": 0});
} }
@ -297,9 +298,9 @@ impl Cipher {
// The only way to fix this is by setting type_data_json to `null` // The only way to fix this is by setting type_data_json to `null`
// Opening this ssh-key in the mobile client will probably crash the client, but you can edit, save and afterwards delete it // Opening this ssh-key in the mobile client will probably crash the client, but you can edit, save and afterwards delete it
if self.atype == 5 if self.atype == 5
&& (type_data_json["keyFingerprint"].as_str().is_none_or(|v| v.is_empty()) && (type_data_json["keyFingerprint"].as_str().is_none_or(str::is_empty)
|| type_data_json["privateKey"].as_str().is_none_or(|v| v.is_empty()) || type_data_json["privateKey"].as_str().is_none_or(str::is_empty)
|| type_data_json["publicKey"].as_str().is_none_or(|v| v.is_empty())) || type_data_json["publicKey"].as_str().is_none_or(str::is_empty))
{ {
warn!("Error parsing ssh-key, mandatory fields are invalid for {}", self.uuid); warn!("Error parsing ssh-key, mandatory fields are invalid for {}", self.uuid);
type_data_json = Value::Null; type_data_json = Value::Null;
@ -415,7 +416,7 @@ impl Cipher {
match self.user_uuid { match self.user_uuid {
Some(ref user_uuid) => { Some(ref user_uuid) => {
User::update_uuid_revision(user_uuid, conn).await; User::update_uuid_revision(user_uuid, conn).await;
user_uuids.push(user_uuid.clone()) user_uuids.push(user_uuid.clone());
} }
None => { None => {
// Belongs to Organization, need to update affected users // Belongs to Organization, need to update affected users
@ -430,11 +431,11 @@ impl Cipher {
} }
for member in collection_users { for member in collection_users {
User::update_uuid_revision(&member.user_uuid, conn).await; User::update_uuid_revision(&member.user_uuid, conn).await;
user_uuids.push(member.user_uuid.clone()) user_uuids.push(member.user_uuid.clone());
}
} }
} }
} }
};
user_uuids user_uuids
} }
@ -480,11 +481,12 @@ impl Cipher {
Attachment::delete_all_by_cipher(&self.uuid, conn).await?; Attachment::delete_all_by_cipher(&self.uuid, conn).await?;
Favorite::delete_all_by_cipher(&self.uuid, conn).await?; Favorite::delete_all_by_cipher(&self.uuid, conn).await?;
db_run! { conn: { conn.run(move |conn| {
diesel::delete(ciphers::table.filter(ciphers::uuid.eq(&self.uuid))) diesel::delete(ciphers::table.filter(ciphers::uuid.eq(&self.uuid)))
.execute(conn) .execute(conn)
.map_res("Error deleting cipher") .map_res("Error deleting cipher")
}} })
.await
} }
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
@ -531,9 +533,10 @@ impl Cipher {
// Remove from folder // Remove from folder
(Some(old_folder), None) => { (Some(old_folder), None) => {
match FolderCipher::find_by_folder_and_cipher(&old_folder, &self.uuid, conn).await { if let Some(old_folder) = FolderCipher::find_by_folder_and_cipher(&old_folder, &self.uuid, conn).await {
Some(old_folder) => old_folder.delete(conn).await, old_folder.delete(conn).await
None => err!("Couldn't move from previous folder"), } else {
err!("Couldn't move from previous folder")
} }
} }
@ -584,9 +587,8 @@ impl Cipher {
if let Some(ref org_uuid) = self.organization_uuid { if let Some(ref org_uuid) = self.organization_uuid {
if let Some(cipher_sync_data) = cipher_sync_data { if let Some(cipher_sync_data) = cipher_sync_data {
return cipher_sync_data.user_group_full_access_for_organizations.contains(org_uuid); return cipher_sync_data.user_group_full_access_for_organizations.contains(org_uuid);
} else {
return Group::is_in_full_access_group(user_uuid, org_uuid, conn).await;
} }
return Group::is_in_full_access_group(user_uuid, org_uuid, conn).await;
} }
false false
} }
@ -628,10 +630,10 @@ impl Cipher {
rows rows
} else { } else {
let user_permissions = self.get_user_collections_access_flags(user_uuid, conn).await; let user_permissions = self.get_user_collections_access_flags(user_uuid, conn).await;
if !user_permissions.is_empty() { if user_permissions.is_empty() {
user_permissions
} else {
self.get_group_collections_access_flags(user_uuid, conn).await self.get_group_collections_access_flags(user_uuid, conn).await
} else {
user_permissions
} }
}; };
@ -657,7 +659,7 @@ impl Cipher {
let mut read_only = true; let mut read_only = true;
let mut hide_passwords = true; let mut hide_passwords = true;
let mut manage = false; let mut manage = false;
for (ro, hp, mn) in rows.iter() { for (ro, hp, mn) in &rows {
read_only &= ro; read_only &= ro;
hide_passwords &= hp; hide_passwords &= hp;
manage |= mn; manage |= mn;
@ -667,51 +669,51 @@ impl Cipher {
} }
async fn get_user_collections_access_flags(&self, user_uuid: &UserId, conn: &DbConn) -> Vec<(bool, bool, bool)> { async fn get_user_collections_access_flags(&self, user_uuid: &UserId, conn: &DbConn) -> Vec<(bool, bool, bool)> {
db_run! { conn: { conn.run(move |conn| {
// Check whether this cipher is in any collections accessible to the // Check whether this cipher is in any collections accessible to the
// user. If so, retrieve the access flags for each collection. // user. If so, retrieve the access flags for each collection.
ciphers::table ciphers::table
.filter(ciphers::uuid.eq(&self.uuid)) .filter(ciphers::uuid.eq(&self.uuid))
.inner_join(ciphers_collections::table.on( .inner_join(ciphers_collections::table.on(ciphers::uuid.eq(ciphers_collections::cipher_uuid)))
ciphers::uuid.eq(ciphers_collections::cipher_uuid) .inner_join(
)) users_collections::table.on(ciphers_collections::collection_uuid
.inner_join(users_collections::table.on( .eq(users_collections::collection_uuid)
ciphers_collections::collection_uuid.eq(users_collections::collection_uuid) .and(users_collections::user_uuid.eq(user_uuid))),
.and(users_collections::user_uuid.eq(user_uuid)) )
))
.select((users_collections::read_only, users_collections::hide_passwords, users_collections::manage)) .select((users_collections::read_only, users_collections::hide_passwords, users_collections::manage))
.load::<(bool, bool, bool)>(conn) .load::<(bool, bool, bool)>(conn)
.expect("Error getting user access restrictions") .expect("Error getting user access restrictions")
}} })
.await
} }
async fn get_group_collections_access_flags(&self, user_uuid: &UserId, conn: &DbConn) -> Vec<(bool, bool, bool)> { async fn get_group_collections_access_flags(&self, user_uuid: &UserId, conn: &DbConn) -> Vec<(bool, bool, bool)> {
if !CONFIG.org_groups_enabled() { if !CONFIG.org_groups_enabled() {
return Vec::new(); return Vec::new();
} }
db_run! { conn: { conn.run(move |conn| {
ciphers::table ciphers::table
.filter(ciphers::uuid.eq(&self.uuid)) .filter(ciphers::uuid.eq(&self.uuid))
.inner_join(ciphers_collections::table.on( .inner_join(ciphers_collections::table.on(ciphers::uuid.eq(ciphers_collections::cipher_uuid)))
ciphers::uuid.eq(ciphers_collections::cipher_uuid) .inner_join(
)) collections_groups::table
.inner_join(collections_groups::table.on( .on(collections_groups::collections_uuid.eq(ciphers_collections::collection_uuid)),
collections_groups::collections_uuid.eq(ciphers_collections::collection_uuid) )
)) .inner_join(groups_users::table.on(groups_users::groups_uuid.eq(collections_groups::groups_uuid)))
.inner_join(groups_users::table.on( .inner_join(
groups_users::groups_uuid.eq(collections_groups::groups_uuid) users_organizations::table.on(users_organizations::uuid.eq(groups_users::users_organizations_uuid)),
)) )
.inner_join(users_organizations::table.on( .inner_join(
users_organizations::uuid.eq(groups_users::users_organizations_uuid) groups::table.on(groups::uuid
)) .eq(collections_groups::groups_uuid)
.inner_join(groups::table.on(groups::uuid.eq(collections_groups::groups_uuid) .and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
.and(groups::organizations_uuid.eq(users_organizations::org_uuid)) )
))
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
.select((collections_groups::read_only, collections_groups::hide_passwords, collections_groups::manage)) .select((collections_groups::read_only, collections_groups::hide_passwords, collections_groups::manage))
.load::<(bool, bool, bool)>(conn) .load::<(bool, bool, bool)>(conn)
.expect("Error getting group access restrictions") .expect("Error getting group access restrictions")
}} })
.await
} }
pub async fn is_write_accessible_to_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool { pub async fn is_write_accessible_to_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool {
@ -760,7 +762,7 @@ impl Cipher {
} }
pub async fn get_folder_uuid(&self, user_uuid: &UserId, conn: &DbConn) -> Option<FolderId> { pub async fn get_folder_uuid(&self, user_uuid: &UserId, conn: &DbConn) -> Option<FolderId> {
db_run! { conn: { conn.run(move |conn| {
folders_ciphers::table folders_ciphers::table
.inner_join(folders::table) .inner_join(folders::table)
.filter(folders::user_uuid.eq(&user_uuid)) .filter(folders::user_uuid.eq(&user_uuid))
@ -768,16 +770,12 @@ impl Cipher {
.select(folders_ciphers::folder_uuid) .select(folders_ciphers::folder_uuid)
.first::<FolderId>(conn) .first::<FolderId>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn find_by_uuid(uuid: &CipherId, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &CipherId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| ciphers::table.filter(ciphers::uuid.eq(uuid)).first::<Self>(conn).ok()).await
ciphers::table
.filter(ciphers::uuid.eq(uuid))
.first::<Self>(conn)
.ok()
}}
} }
pub async fn find_by_uuid_and_org( pub async fn find_by_uuid_and_org(
@ -785,13 +783,14 @@ impl Cipher {
org_uuid: &OrganizationId, org_uuid: &OrganizationId,
conn: &DbConn, conn: &DbConn,
) -> Option<Self> { ) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
ciphers::table ciphers::table
.filter(ciphers::uuid.eq(cipher_uuid)) .filter(ciphers::uuid.eq(cipher_uuid))
.filter(ciphers::organization_uuid.eq(org_uuid)) .filter(ciphers::organization_uuid.eq(org_uuid))
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
// Find all ciphers accessible or visible to the specified user. // Find all ciphers accessible or visible to the specified user.
@ -813,32 +812,35 @@ impl Cipher {
conn: &DbConn, conn: &DbConn,
) -> Vec<Self> { ) -> Vec<Self> {
if CONFIG.org_groups_enabled() { if CONFIG.org_groups_enabled() {
db_run! { conn: { conn.run(move |conn| {
let mut query = ciphers::table let mut query = ciphers::table
.left_join(ciphers_collections::table.on( .left_join(ciphers_collections::table.on(ciphers::uuid.eq(ciphers_collections::cipher_uuid)))
ciphers::uuid.eq(ciphers_collections::cipher_uuid) .left_join(
)) users_organizations::table.on(ciphers::organization_uuid
.left_join(users_organizations::table.on( .eq(users_organizations::org_uuid.nullable())
ciphers::organization_uuid.eq(users_organizations::org_uuid.nullable())
.and(users_organizations::user_uuid.eq(user_uuid)) .and(users_organizations::user_uuid.eq(user_uuid))
.and(users_organizations::status.eq(MembershipStatus::Confirmed as i32)) .and(users_organizations::status.eq(MembershipStatus::Confirmed as i32))),
)) )
.left_join(users_collections::table.on( .left_join(
ciphers_collections::collection_uuid.eq(users_collections::collection_uuid) users_collections::table.on(ciphers_collections::collection_uuid
.eq(users_collections::collection_uuid)
// Ensure that users_collections::user_uuid is NULL for unconfirmed users. // Ensure that users_collections::user_uuid is NULL for unconfirmed users.
.and(users_organizations::user_uuid.eq(users_collections::user_uuid)) .and(users_organizations::user_uuid.eq(users_collections::user_uuid))),
)) )
.left_join(groups_users::table.on( .left_join(
groups_users::users_organizations_uuid.eq(users_organizations::uuid) groups_users::table.on(groups_users::users_organizations_uuid.eq(users_organizations::uuid)),
)) )
.left_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid) .left_join(
groups::table.on(groups::uuid
.eq(groups_users::groups_uuid)
// Ensure that group and membership belong to the same org // Ensure that group and membership belong to the same org
.and(groups::organizations_uuid.eq(users_organizations::org_uuid)) .and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
)) )
.left_join(collections_groups::table.on( .left_join(
collections_groups::collections_uuid.eq(ciphers_collections::collection_uuid) collections_groups::table.on(collections_groups::collections_uuid
.and(collections_groups::groups_uuid.eq(groups::uuid)) .eq(ciphers_collections::collection_uuid)
)) .and(collections_groups::groups_uuid.eq(groups::uuid))),
)
.filter(ciphers::user_uuid.eq(user_uuid)) // Cipher owner .filter(ciphers::user_uuid.eq(user_uuid)) // Cipher owner
.or_filter(users_organizations::access_all.eq(true)) // access_all in org .or_filter(users_organizations::access_all.eq(true)) // access_all in org
.or_filter(users_collections::user_uuid.eq(user_uuid)) // Access to collection .or_filter(users_collections::user_uuid.eq(user_uuid)) // Access to collection
@ -848,39 +850,34 @@ impl Cipher {
if !visible_only { if !visible_only {
query = query.or_filter( query = query.or_filter(
users_organizations::atype.le(MembershipType::Admin as i32) // Org admin/owner users_organizations::atype.le(MembershipType::Admin as i32), // Org admin/owner
); );
} }
// Only filter for one specific cipher // Only filter for one specific cipher
if !cipher_uuids.is_empty() { if !cipher_uuids.is_empty() {
query = query.filter( query = query.filter(ciphers::uuid.eq_any(cipher_uuids));
ciphers::uuid.eq_any(cipher_uuids)
);
} }
query query.select(ciphers::all_columns).distinct().load::<Self>(conn).expect("Error loading ciphers")
.select(ciphers::all_columns) })
.distinct() .await
.load::<Self>(conn)
.expect("Error loading ciphers")
}}
} else { } else {
db_run! { conn: { conn.run(move |conn| {
let mut query = ciphers::table let mut query = ciphers::table
.left_join(ciphers_collections::table.on( .left_join(ciphers_collections::table.on(ciphers::uuid.eq(ciphers_collections::cipher_uuid)))
ciphers::uuid.eq(ciphers_collections::cipher_uuid) .left_join(
)) users_organizations::table.on(ciphers::organization_uuid
.left_join(users_organizations::table.on( .eq(users_organizations::org_uuid.nullable())
ciphers::organization_uuid.eq(users_organizations::org_uuid.nullable())
.and(users_organizations::user_uuid.eq(user_uuid)) .and(users_organizations::user_uuid.eq(user_uuid))
.and(users_organizations::status.eq(MembershipStatus::Confirmed as i32)) .and(users_organizations::status.eq(MembershipStatus::Confirmed as i32))),
)) )
.left_join(users_collections::table.on( .left_join(
ciphers_collections::collection_uuid.eq(users_collections::collection_uuid) users_collections::table.on(ciphers_collections::collection_uuid
.eq(users_collections::collection_uuid)
// Ensure that users_collections::user_uuid is NULL for unconfirmed users. // Ensure that users_collections::user_uuid is NULL for unconfirmed users.
.and(users_organizations::user_uuid.eq(users_collections::user_uuid)) .and(users_organizations::user_uuid.eq(users_collections::user_uuid))),
)) )
.filter(ciphers::user_uuid.eq(user_uuid)) // Cipher owner .filter(ciphers::user_uuid.eq(user_uuid)) // Cipher owner
.or_filter(users_organizations::access_all.eq(true)) // access_all in org .or_filter(users_organizations::access_all.eq(true)) // access_all in org
.or_filter(users_collections::user_uuid.eq(user_uuid)) // Access to collection .or_filter(users_collections::user_uuid.eq(user_uuid)) // Access to collection
@ -888,23 +885,18 @@ impl Cipher {
if !visible_only { if !visible_only {
query = query.or_filter( query = query.or_filter(
users_organizations::atype.le(MembershipType::Admin as i32) // Org admin/owner users_organizations::atype.le(MembershipType::Admin as i32), // Org admin/owner
); );
} }
// Only filter for one specific cipher // Only filter for one specific cipher
if !cipher_uuids.is_empty() { if !cipher_uuids.is_empty() {
query = query.filter( query = query.filter(ciphers::uuid.eq_any(cipher_uuids));
ciphers::uuid.eq_any(cipher_uuids)
);
} }
query query.select(ciphers::all_columns).distinct().load::<Self>(conn).expect("Error loading ciphers")
.select(ciphers::all_columns) })
.distinct() .await
.load::<Self>(conn)
.expect("Error loading ciphers")
}}
} }
} }
@ -927,193 +919,208 @@ impl Cipher {
// Find all ciphers directly owned by the specified user. // Find all ciphers directly owned by the specified user.
pub async fn find_owned_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> { pub async fn find_owned_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
ciphers::table ciphers::table
.filter( .filter(ciphers::user_uuid.eq(user_uuid).and(ciphers::organization_uuid.is_null()))
ciphers::user_uuid.eq(user_uuid)
.and(ciphers::organization_uuid.is_null())
)
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading ciphers") .expect("Error loading ciphers")
}} })
.await
} }
pub async fn count_owned_by_user(user_uuid: &UserId, conn: &DbConn) -> i64 { pub async fn count_owned_by_user(user_uuid: &UserId, conn: &DbConn) -> i64 {
db_run! { conn: { conn.run(move |conn| {
ciphers::table ciphers::table.filter(ciphers::user_uuid.eq(user_uuid)).count().first::<i64>(conn).ok().unwrap_or(0)
.filter(ciphers::user_uuid.eq(user_uuid)) })
.count() .await
.first::<i64>(conn)
.ok()
.unwrap_or(0)
}}
} }
pub async fn find_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
ciphers::table ciphers::table
.filter(ciphers::organization_uuid.eq(org_uuid)) .filter(ciphers::organization_uuid.eq(org_uuid))
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading ciphers") .expect("Error loading ciphers")
}} })
.await
} }
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 { pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: { conn.run(move |conn| {
ciphers::table ciphers::table.filter(ciphers::organization_uuid.eq(org_uuid)).count().first::<i64>(conn).ok().unwrap_or(0)
.filter(ciphers::organization_uuid.eq(org_uuid)) })
.count() .await
.first::<i64>(conn)
.ok()
.unwrap_or(0)
}}
} }
pub async fn find_by_folder(folder_uuid: &FolderId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_folder(folder_uuid: &FolderId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
folders_ciphers::table.inner_join(ciphers::table) folders_ciphers::table
.inner_join(ciphers::table)
.filter(folders_ciphers::folder_uuid.eq(folder_uuid)) .filter(folders_ciphers::folder_uuid.eq(folder_uuid))
.select(ciphers::all_columns) .select(ciphers::all_columns)
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading ciphers") .expect("Error loading ciphers")
}} })
.await
} }
/// Find all ciphers that were deleted before the specified datetime. /// Find all ciphers that were deleted before the specified datetime.
pub async fn find_deleted_before(dt: &NaiveDateTime, conn: &DbConn) -> Vec<Self> { pub async fn find_deleted_before(dt: &NaiveDateTime, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
ciphers::table ciphers::table.filter(ciphers::deleted_at.lt(dt)).load::<Self>(conn).expect("Error loading ciphers")
.filter(ciphers::deleted_at.lt(dt)) })
.load::<Self>(conn) .await
.expect("Error loading ciphers")
}}
} }
pub async fn get_collections(&self, user_uuid: UserId, conn: &DbConn) -> Vec<CollectionId> { pub async fn get_collections(&self, user_uuid: UserId, conn: &DbConn) -> Vec<CollectionId> {
if CONFIG.org_groups_enabled() { if CONFIG.org_groups_enabled() {
db_run! { conn: { conn.run(move |conn| {
ciphers_collections::table ciphers_collections::table
.filter(ciphers_collections::cipher_uuid.eq(&self.uuid)) .filter(ciphers_collections::cipher_uuid.eq(&self.uuid))
.inner_join(collections::table.on( .inner_join(collections::table.on(collections::uuid.eq(ciphers_collections::collection_uuid)))
collections::uuid.eq(ciphers_collections::collection_uuid) .left_join(
)) users_organizations::table.on(users_organizations::org_uuid
.left_join(users_organizations::table.on( .eq(collections::org_uuid)
users_organizations::org_uuid.eq(collections::org_uuid) .and(users_organizations::user_uuid.eq(user_uuid.clone()))),
.and(users_organizations::user_uuid.eq(user_uuid.clone())) )
)) .left_join(
.left_join(users_collections::table.on( users_collections::table.on(users_collections::collection_uuid
users_collections::collection_uuid.eq(ciphers_collections::collection_uuid) .eq(ciphers_collections::collection_uuid)
.and(users_collections::user_uuid.eq(user_uuid.clone())) .and(users_collections::user_uuid.eq(user_uuid.clone()))),
)) )
.left_join(groups_users::table.on( .left_join(
groups_users::users_organizations_uuid.eq(users_organizations::uuid) groups_users::table.on(groups_users::users_organizations_uuid.eq(users_organizations::uuid)),
)) )
.left_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid) .left_join(
.and(groups::organizations_uuid.eq(users_organizations::org_uuid)) groups::table.on(groups::uuid
)) .eq(groups_users::groups_uuid)
.left_join(collections_groups::table.on( .and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
collections_groups::collections_uuid.eq(ciphers_collections::collection_uuid) )
.and(collections_groups::groups_uuid.eq(groups::uuid)) .left_join(
)) collections_groups::table.on(collections_groups::collections_uuid
.filter(users_organizations::access_all.eq(true) // User has access all .eq(ciphers_collections::collection_uuid)
.or(users_collections::user_uuid.eq(user_uuid) // User has access to collection .and(collections_groups::groups_uuid.eq(groups::uuid))),
)
.filter(
users_organizations::access_all
.eq(true) // User has access all
.or(users_collections::user_uuid
.eq(user_uuid) // User has access to collection
.and(users_collections::read_only.eq(false))) .and(users_collections::read_only.eq(false)))
.or(groups::access_all.eq(true)) // Access via groups .or(groups::access_all.eq(true)) // Access via groups
.or(collections_groups::collections_uuid.is_not_null() // Access via groups .or(collections_groups::collections_uuid
.and(collections_groups::read_only.eq(false))) .is_not_null() // Access via groups
.and(collections_groups::read_only.eq(false))),
) )
.select(ciphers_collections::collection_uuid) .select(ciphers_collections::collection_uuid)
.load::<CollectionId>(conn) .load::<CollectionId>(conn)
.unwrap_or_default() .unwrap_or_default()
}} })
.await
} else { } else {
db_run! { conn: { conn.run(move |conn| {
ciphers_collections::table ciphers_collections::table
.filter(ciphers_collections::cipher_uuid.eq(&self.uuid)) .filter(ciphers_collections::cipher_uuid.eq(&self.uuid))
.inner_join(collections::table.on( .inner_join(collections::table.on(collections::uuid.eq(ciphers_collections::collection_uuid)))
collections::uuid.eq(ciphers_collections::collection_uuid) .inner_join(
)) users_organizations::table.on(users_organizations::org_uuid
.inner_join(users_organizations::table.on( .eq(collections::org_uuid)
users_organizations::org_uuid.eq(collections::org_uuid) .and(users_organizations::user_uuid.eq(user_uuid.clone()))),
.and(users_organizations::user_uuid.eq(user_uuid.clone())) )
)) .left_join(
.left_join(users_collections::table.on( users_collections::table.on(users_collections::collection_uuid
users_collections::collection_uuid.eq(ciphers_collections::collection_uuid) .eq(ciphers_collections::collection_uuid)
.and(users_collections::user_uuid.eq(user_uuid.clone())) .and(users_collections::user_uuid.eq(user_uuid.clone()))),
)) )
.filter(users_organizations::access_all.eq(true) // User has access all .filter(
.or(users_collections::user_uuid.eq(user_uuid) // User has access to collection users_organizations::access_all
.and(users_collections::read_only.eq(false))) .eq(true) // User has access all
.or(users_collections::user_uuid
.eq(user_uuid) // User has access to collection
.and(users_collections::read_only.eq(false))),
) )
.select(ciphers_collections::collection_uuid) .select(ciphers_collections::collection_uuid)
.load::<CollectionId>(conn) .load::<CollectionId>(conn)
.unwrap_or_default() .unwrap_or_default()
}} })
.await
} }
} }
pub async fn get_admin_collections(&self, user_uuid: UserId, conn: &DbConn) -> Vec<CollectionId> { pub async fn get_admin_collections(&self, user_uuid: UserId, conn: &DbConn) -> Vec<CollectionId> {
if CONFIG.org_groups_enabled() { if CONFIG.org_groups_enabled() {
db_run! { conn: { conn.run(move |conn| {
ciphers_collections::table ciphers_collections::table
.filter(ciphers_collections::cipher_uuid.eq(&self.uuid)) .filter(ciphers_collections::cipher_uuid.eq(&self.uuid))
.inner_join(collections::table.on( .inner_join(collections::table.on(collections::uuid.eq(ciphers_collections::collection_uuid)))
collections::uuid.eq(ciphers_collections::collection_uuid) .left_join(
)) users_organizations::table.on(users_organizations::org_uuid
.left_join(users_organizations::table.on( .eq(collections::org_uuid)
users_organizations::org_uuid.eq(collections::org_uuid) .and(users_organizations::user_uuid.eq(user_uuid.clone()))),
.and(users_organizations::user_uuid.eq(user_uuid.clone())) )
)) .left_join(
.left_join(users_collections::table.on( users_collections::table.on(users_collections::collection_uuid
users_collections::collection_uuid.eq(ciphers_collections::collection_uuid) .eq(ciphers_collections::collection_uuid)
.and(users_collections::user_uuid.eq(user_uuid.clone())) .and(users_collections::user_uuid.eq(user_uuid.clone()))),
)) )
.left_join(groups_users::table.on( .left_join(
groups_users::users_organizations_uuid.eq(users_organizations::uuid) groups_users::table.on(groups_users::users_organizations_uuid.eq(users_organizations::uuid)),
)) )
.left_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid) .left_join(
.and(groups::organizations_uuid.eq(users_organizations::org_uuid)) groups::table.on(groups::uuid
)) .eq(groups_users::groups_uuid)
.left_join(collections_groups::table.on( .and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
collections_groups::collections_uuid.eq(ciphers_collections::collection_uuid) )
.and(collections_groups::groups_uuid.eq(groups::uuid)) .left_join(
)) collections_groups::table.on(collections_groups::collections_uuid
.filter(users_organizations::access_all.eq(true) // User has access all .eq(ciphers_collections::collection_uuid)
.or(users_collections::user_uuid.eq(user_uuid) // User has access to collection .and(collections_groups::groups_uuid.eq(groups::uuid))),
)
.filter(
users_organizations::access_all
.eq(true) // User has access all
.or(users_collections::user_uuid
.eq(user_uuid) // User has access to collection
.and(users_collections::read_only.eq(false))) .and(users_collections::read_only.eq(false)))
.or(groups::access_all.eq(true)) // Access via groups .or(groups::access_all.eq(true)) // Access via groups
.or(collections_groups::collections_uuid.is_not_null() // Access via groups .or(collections_groups::collections_uuid
.is_not_null() // Access via groups
.and(collections_groups::read_only.eq(false))) .and(collections_groups::read_only.eq(false)))
.or(users_organizations::atype.le(MembershipType::Admin as i32)) // User is admin or owner .or(users_organizations::atype.le(MembershipType::Admin as i32)), // User is admin or owner
) )
.select(ciphers_collections::collection_uuid) .select(ciphers_collections::collection_uuid)
.load::<CollectionId>(conn) .load::<CollectionId>(conn)
.unwrap_or_default() .unwrap_or_default()
}} })
.await
} else { } else {
db_run! { conn: { conn.run(move |conn| {
ciphers_collections::table ciphers_collections::table
.filter(ciphers_collections::cipher_uuid.eq(&self.uuid)) .filter(ciphers_collections::cipher_uuid.eq(&self.uuid))
.inner_join(collections::table.on( .inner_join(collections::table.on(collections::uuid.eq(ciphers_collections::collection_uuid)))
collections::uuid.eq(ciphers_collections::collection_uuid) .inner_join(
)) users_organizations::table.on(users_organizations::org_uuid
.inner_join(users_organizations::table.on( .eq(collections::org_uuid)
users_organizations::org_uuid.eq(collections::org_uuid) .and(users_organizations::user_uuid.eq(user_uuid.clone()))),
.and(users_organizations::user_uuid.eq(user_uuid.clone())) )
)) .left_join(
.left_join(users_collections::table.on( users_collections::table.on(users_collections::collection_uuid
users_collections::collection_uuid.eq(ciphers_collections::collection_uuid) .eq(ciphers_collections::collection_uuid)
.and(users_collections::user_uuid.eq(user_uuid.clone())) .and(users_collections::user_uuid.eq(user_uuid.clone()))),
)) )
.filter(users_organizations::access_all.eq(true) // User has access all .filter(
.or(users_collections::user_uuid.eq(user_uuid) // User has access to collection users_organizations::access_all
.eq(true) // User has access all
.or(users_collections::user_uuid
.eq(user_uuid) // User has access to collection
.and(users_collections::read_only.eq(false))) .and(users_collections::read_only.eq(false)))
.or(users_organizations::atype.le(MembershipType::Admin as i32)) // User is admin or owner .or(users_organizations::atype.le(MembershipType::Admin as i32)), // User is admin or owner
) )
.select(ciphers_collections::collection_uuid) .select(ciphers_collections::collection_uuid)
.load::<CollectionId>(conn) .load::<CollectionId>(conn)
.unwrap_or_default() .unwrap_or_default()
}} })
.await
} }
} }
@ -1123,32 +1130,30 @@ impl Cipher {
user_uuid: UserId, user_uuid: UserId,
conn: &DbConn, conn: &DbConn,
) -> Vec<(CipherId, CollectionId)> { ) -> Vec<(CipherId, CollectionId)> {
db_run! { conn: { conn.run(move |conn| {
ciphers_collections::table ciphers_collections::table
.inner_join(collections::table.on( .inner_join(collections::table.on(collections::uuid.eq(ciphers_collections::collection_uuid)))
collections::uuid.eq(ciphers_collections::collection_uuid) .inner_join(
)) users_organizations::table.on(users_organizations::org_uuid
.inner_join(users_organizations::table.on( .eq(collections::org_uuid)
users_organizations::org_uuid.eq(collections::org_uuid).and( .and(users_organizations::user_uuid.eq(user_uuid.clone()))),
users_organizations::user_uuid.eq(user_uuid.clone()) )
.left_join(
users_collections::table.on(users_collections::collection_uuid
.eq(ciphers_collections::collection_uuid)
.and(users_collections::user_uuid.eq(user_uuid.clone()))),
) )
)) .left_join(groups_users::table.on(groups_users::users_organizations_uuid.eq(users_organizations::uuid)))
.left_join(users_collections::table.on( .left_join(
users_collections::collection_uuid.eq(ciphers_collections::collection_uuid).and( groups::table.on(groups::uuid
users_collections::user_uuid.eq(user_uuid.clone()) .eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
) )
)) .left_join(
.left_join(groups_users::table.on( collections_groups::table.on(collections_groups::collections_uuid
groups_users::users_organizations_uuid.eq(users_organizations::uuid) .eq(ciphers_collections::collection_uuid)
)) .and(collections_groups::groups_uuid.eq(groups::uuid))),
.left_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))
))
.left_join(collections_groups::table.on(
collections_groups::collections_uuid.eq(ciphers_collections::collection_uuid).and(
collections_groups::groups_uuid.eq(groups::uuid)
) )
))
.or_filter(users_collections::user_uuid.eq(user_uuid)) // User has access to collection .or_filter(users_collections::user_uuid.eq(user_uuid)) // User has access to collection
.or_filter(users_organizations::access_all.eq(true)) // User has access all .or_filter(users_organizations::access_all.eq(true)) // User has access all
.or_filter(users_organizations::atype.le(MembershipType::Admin as i32)) // User is admin or owner .or_filter(users_organizations::atype.le(MembershipType::Admin as i32)) // User is admin or owner
@ -1158,7 +1163,8 @@ impl Cipher {
.distinct() .distinct()
.load::<(CipherId, CollectionId)>(conn) .load::<(CipherId, CollectionId)>(conn)
.unwrap_or_default() .unwrap_or_default()
}} })
.await
} }
} }

633
src/db/models/collection.rs

@ -1,16 +1,25 @@
use derive_more::{AsRef, Deref, Display, From}; use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use serde_json::Value; use serde_json::Value;
use crate::{
CONFIG,
api::EmptyResult,
db::{
DbConn,
schema::{
ciphers_collections, collections, collections_groups, groups, groups_users, users_collections,
users_organizations,
},
},
error::MapResult,
};
use macros::UuidFromParam;
use super::{ use super::{
CipherId, CollectionGroup, GroupUser, Membership, MembershipId, MembershipStatus, MembershipType, OrganizationId, CipherId, CollectionGroup, GroupUser, Membership, MembershipId, MembershipStatus, MembershipType, OrganizationId,
User, UserId, User, UserId,
}; };
use crate::db::schema::{
ciphers_collections, collections, collections_groups, groups, groups_users, users_collections, users_organizations,
};
use crate::CONFIG;
use diesel::prelude::*;
use macros::UuidFromParam;
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = collections)] #[diesel(table_name = collections)]
@ -74,7 +83,7 @@ impl Collection {
if external_id.is_empty() { if external_id.is_empty() {
self.external_id = None; self.external_id = None;
} else { } else {
self.external_id = Some(external_id) self.external_id = Some(external_id);
} }
} }
None => self.external_id = None, None => self.external_id = None,
@ -147,11 +156,6 @@ impl Collection {
} }
} }
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
/// Database methods /// Database methods
impl Collection { impl Collection {
pub async fn save(&self, conn: &DbConn) -> EmptyResult { pub async fn save(&self, conn: &DbConn) -> EmptyResult {
@ -193,11 +197,12 @@ impl Collection {
CollectionUser::delete_all_by_collection(&self.uuid, conn).await?; CollectionUser::delete_all_by_collection(&self.uuid, conn).await?;
CollectionGroup::delete_all_by_collection(&self.uuid, &self.org_uuid, conn).await?; CollectionGroup::delete_all_by_collection(&self.uuid, &self.org_uuid, conn).await?;
db_run! { conn: { conn.run(move |conn| {
diesel::delete(collections::table.filter(collections::uuid.eq(self.uuid))) diesel::delete(collections::table.filter(collections::uuid.eq(self.uuid)))
.execute(conn) .execute(conn)
.map_res("Error deleting collection") .map_res("Error deleting collection")
}} })
.await
} }
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
@ -208,90 +213,90 @@ impl Collection {
} }
pub async fn update_users_revision(&self, conn: &DbConn) { pub async fn update_users_revision(&self, conn: &DbConn) {
for member in Membership::find_by_collection_and_org(&self.uuid, &self.org_uuid, conn).await.iter() { for member in &Membership::find_by_collection_and_org(&self.uuid, &self.org_uuid, conn).await {
User::update_uuid_revision(&member.user_uuid, conn).await; User::update_uuid_revision(&member.user_uuid, conn).await;
} }
} }
pub async fn find_by_uuid(uuid: &CollectionId, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &CollectionId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| collections::table.filter(collections::uuid.eq(uuid)).first::<Self>(conn).ok()).await
collections::table
.filter(collections::uuid.eq(uuid))
.first::<Self>(conn)
.ok()
}}
} }
pub async fn find_by_user_uuid(user_uuid: UserId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_user_uuid(user_uuid: UserId, conn: &DbConn) -> Vec<Self> {
if CONFIG.org_groups_enabled() { if CONFIG.org_groups_enabled() {
db_run! { conn: { conn.run(move |conn| {
collections::table collections::table
.left_join(users_collections::table.on( .left_join(
users_collections::collection_uuid.eq(collections::uuid).and( users_collections::table.on(users_collections::collection_uuid
users_collections::user_uuid.eq(user_uuid.clone()) .eq(collections::uuid)
) .and(users_collections::user_uuid.eq(user_uuid.clone()))),
)) )
.left_join(users_organizations::table.on( .left_join(
collections::org_uuid.eq(users_organizations::org_uuid).and( users_organizations::table.on(collections::org_uuid
users_organizations::user_uuid.eq(user_uuid.clone()) .eq(users_organizations::org_uuid)
) .and(users_organizations::user_uuid.eq(user_uuid.clone()))),
)) )
.left_join(groups_users::table.on( .left_join(
groups_users::users_organizations_uuid.eq(users_organizations::uuid) groups_users::table.on(groups_users::users_organizations_uuid.eq(users_organizations::uuid)),
)) )
.left_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid) .left_join(
.and(groups::organizations_uuid.eq(users_organizations::org_uuid)) groups::table.on(groups::uuid
)) .eq(groups_users::groups_uuid)
.left_join(collections_groups::table.on( .and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
collections_groups::groups_uuid.eq(groups_users::groups_uuid).and( )
collections_groups::collections_uuid.eq(collections::uuid) .left_join(
) collections_groups::table.on(collections_groups::groups_uuid
)) .eq(groups_users::groups_uuid)
.and(collections_groups::collections_uuid.eq(collections::uuid))),
)
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.filter( .filter(
users_organizations::status.eq(MembershipStatus::Confirmed as i32) users_collections::user_uuid
) .eq(user_uuid)
.filter( .or(
users_collections::user_uuid.eq(user_uuid).or( // Directly accessed collection // Directly accessed collection
users_organizations::access_all.eq(true) // access_all in Organization users_organizations::access_all.eq(true), // access_all in Organization
).or( )
groups::access_all.eq(true) // access_all in groups .or(
).or( // access via groups groups::access_all.eq(true), // access_all in groups
groups_users::users_organizations_uuid.eq(users_organizations::uuid).and( )
collections_groups::collections_uuid.is_not_null() .or(
) // access via groups
) groups_users::users_organizations_uuid
.eq(users_organizations::uuid)
.and(collections_groups::collections_uuid.is_not_null()),
),
) )
.select(collections::all_columns) .select(collections::all_columns)
.distinct() .distinct()
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading collections") .expect("Error loading collections")
}} })
.await
} else { } else {
db_run! { conn: { conn.run(move |conn| {
collections::table collections::table
.left_join(users_collections::table.on( .left_join(
users_collections::collection_uuid.eq(collections::uuid).and( users_collections::table.on(users_collections::collection_uuid
users_collections::user_uuid.eq(user_uuid.clone()) .eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid.clone()))),
) )
)) .left_join(
.left_join(users_organizations::table.on( users_organizations::table.on(collections::org_uuid
collections::org_uuid.eq(users_organizations::org_uuid).and( .eq(users_organizations::org_uuid)
users_organizations::user_uuid.eq(user_uuid.clone()) .and(users_organizations::user_uuid.eq(user_uuid.clone()))),
) )
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.filter(users_collections::user_uuid.eq(user_uuid).or(
// Directly accessed collection
users_organizations::access_all.eq(true), // access_all in Organization
)) ))
.filter(
users_organizations::status.eq(MembershipStatus::Confirmed as i32)
)
.filter(
users_collections::user_uuid.eq(user_uuid).or( // Directly accessed collection
users_organizations::access_all.eq(true) // access_all in Organization
)
)
.select(collections::all_columns) .select(collections::all_columns)
.distinct() .distinct()
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading collections") .expect("Error loading collections")
}} })
.await
} }
} }
@ -308,256 +313,311 @@ impl Collection {
} }
pub async fn find_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
collections::table collections::table
.filter(collections::org_uuid.eq(org_uuid)) .filter(collections::org_uuid.eq(org_uuid))
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading collections") .expect("Error loading collections")
}} })
.await
} }
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 { pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: { conn.run(move |conn| {
collections::table collections::table.filter(collections::org_uuid.eq(org_uuid)).count().first::<i64>(conn).ok().unwrap_or(0)
.filter(collections::org_uuid.eq(org_uuid)) })
.count() .await
.first::<i64>(conn)
.ok()
.unwrap_or(0)
}}
} }
pub async fn find_by_uuid_and_org(uuid: &CollectionId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid_and_org(uuid: &CollectionId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
collections::table collections::table
.filter(collections::uuid.eq(uuid)) .filter(collections::uuid.eq(uuid))
.filter(collections::org_uuid.eq(org_uuid)) .filter(collections::org_uuid.eq(org_uuid))
.select(collections::all_columns) .select(collections::all_columns)
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn find_by_uuid_and_user(uuid: &CollectionId, user_uuid: UserId, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid_and_user(uuid: &CollectionId, user_uuid: UserId, conn: &DbConn) -> Option<Self> {
if CONFIG.org_groups_enabled() { if CONFIG.org_groups_enabled() {
db_run! { conn: { conn.run(move |conn| {
collections::table collections::table
.left_join(users_collections::table.on( .left_join(
users_collections::collection_uuid.eq(collections::uuid).and( users_collections::table.on(users_collections::collection_uuid
users_collections::user_uuid.eq(user_uuid.clone()) .eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid.clone()))),
) )
)) .left_join(
.left_join(users_organizations::table.on( users_organizations::table.on(collections::org_uuid
collections::org_uuid.eq(users_organizations::org_uuid).and( .eq(users_organizations::org_uuid)
users_organizations::user_uuid.eq(user_uuid) .and(users_organizations::user_uuid.eq(user_uuid))),
) )
)) .left_join(
.left_join(groups_users::table.on( groups_users::table.on(groups_users::users_organizations_uuid.eq(users_organizations::uuid)),
groups_users::users_organizations_uuid.eq(users_organizations::uuid) )
)) .left_join(
.left_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid) groups::table.on(groups::uuid
.and(groups::organizations_uuid.eq(users_organizations::org_uuid)) .eq(groups_users::groups_uuid)
)) .and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
.left_join(collections_groups::table.on( )
collections_groups::groups_uuid.eq(groups_users::groups_uuid).and( .left_join(
collections_groups::collections_uuid.eq(collections::uuid) collections_groups::table.on(collections_groups::groups_uuid
.eq(groups_users::groups_uuid)
.and(collections_groups::collections_uuid.eq(collections::uuid))),
) )
))
.filter(collections::uuid.eq(uuid)) .filter(collections::uuid.eq(uuid))
.filter( .filter(
users_collections::collection_uuid.eq(uuid).or( // Directly accessed collection users_collections::collection_uuid
users_organizations::access_all.eq(true).or( // access_all in Organization .eq(uuid)
users_organizations::atype.le(MembershipType::Admin as i32) // Org admin or owner .or(
)).or( // Directly accessed collection
groups::access_all.eq(true) // access_all in groups users_organizations::access_all.eq(true).or(
).or( // access via groups // access_all in Organization
groups_users::users_organizations_uuid.eq(users_organizations::uuid).and( users_organizations::atype.le(MembershipType::Admin as i32), // Org admin or owner
collections_groups::collections_uuid.is_not_null() ),
) )
.or(
groups::access_all.eq(true), // access_all in groups
)
.or(
// access via groups
groups_users::users_organizations_uuid
.eq(users_organizations::uuid)
.and(collections_groups::collections_uuid.is_not_null()),
),
) )
).select(collections::all_columns) .select(collections::all_columns)
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} else { } else {
db_run! { conn: { conn.run(move |conn| {
collections::table collections::table
.left_join(users_collections::table.on( .left_join(
users_collections::collection_uuid.eq(collections::uuid).and( users_collections::table.on(users_collections::collection_uuid
users_collections::user_uuid.eq(user_uuid.clone()) .eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid.clone()))),
) )
)) .left_join(
.left_join(users_organizations::table.on( users_organizations::table.on(collections::org_uuid
collections::org_uuid.eq(users_organizations::org_uuid).and( .eq(users_organizations::org_uuid)
users_organizations::user_uuid.eq(user_uuid) .and(users_organizations::user_uuid.eq(user_uuid))),
) )
))
.filter(collections::uuid.eq(uuid)) .filter(collections::uuid.eq(uuid))
.filter( .filter(users_collections::collection_uuid.eq(uuid).or(
users_collections::collection_uuid.eq(uuid).or( // Directly accessed collection // Directly accessed collection
users_organizations::access_all.eq(true).or( // access_all in Organization users_organizations::access_all.eq(true).or(
users_organizations::atype.le(MembershipType::Admin as i32) // Org admin or owner // access_all in Organization
users_organizations::atype.le(MembershipType::Admin as i32), // Org admin or owner
),
)) ))
).select(collections::all_columns) .select(collections::all_columns)
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
} }
pub async fn is_writable_by_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool { pub async fn is_writable_by_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool {
let user_uuid = user_uuid.to_string(); let user_uuid = user_uuid.to_string();
if CONFIG.org_groups_enabled() { if CONFIG.org_groups_enabled() {
db_run! { conn: { conn.run(move |conn| {
collections::table collections::table
.filter(collections::uuid.eq(&self.uuid)) .filter(collections::uuid.eq(&self.uuid))
.inner_join(users_organizations::table.on( .inner_join(
collections::org_uuid.eq(users_organizations::org_uuid) users_organizations::table.on(collections::org_uuid
.and(users_organizations::user_uuid.eq(user_uuid.clone())) .eq(users_organizations::org_uuid)
)) .and(users_organizations::user_uuid.eq(user_uuid.clone()))),
.left_join(users_collections::table.on( )
users_collections::collection_uuid.eq(collections::uuid) .left_join(
.and(users_collections::user_uuid.eq(user_uuid)) users_collections::table.on(users_collections::collection_uuid
)) .eq(collections::uuid)
.left_join(groups_users::table.on( .and(users_collections::user_uuid.eq(user_uuid))),
groups_users::users_organizations_uuid.eq(users_organizations::uuid) )
)) .left_join(
.left_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid) groups_users::table.on(groups_users::users_organizations_uuid.eq(users_organizations::uuid)),
.and(groups::organizations_uuid.eq(users_organizations::org_uuid)) )
)) .left_join(
.left_join(collections_groups::table.on( groups::table.on(groups::uuid
collections_groups::groups_uuid.eq(groups_users::groups_uuid) .eq(groups_users::groups_uuid)
.and(collections_groups::collections_uuid.eq(collections::uuid)) .and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
)) )
.filter(users_organizations::atype.le(MembershipType::Admin as i32) // Org admin or owner .left_join(
collections_groups::table.on(collections_groups::groups_uuid
.eq(groups_users::groups_uuid)
.and(collections_groups::collections_uuid.eq(collections::uuid))),
)
.filter(
users_organizations::atype
.le(MembershipType::Admin as i32) // Org admin or owner
.or(users_organizations::access_all.eq(true)) // access_all via membership .or(users_organizations::access_all.eq(true)) // access_all via membership
.or(users_collections::collection_uuid.eq(&self.uuid) // write access given to collection .or(users_collections::collection_uuid
.eq(&self.uuid) // write access given to collection
.and(users_collections::read_only.eq(false))) .and(users_collections::read_only.eq(false)))
.or(groups::access_all.eq(true)) // access_all via group .or(groups::access_all.eq(true)) // access_all via group
.or(collections_groups::collections_uuid.is_not_null() // write access given via group .or(collections_groups::collections_uuid
.and(collections_groups::read_only.eq(false))) .is_not_null() // write access given via group
.and(collections_groups::read_only.eq(false))),
) )
.count() .count()
.first::<i64>(conn) .first::<i64>(conn)
.ok() .ok()
.unwrap_or(0) != 0 .unwrap_or(0)
}} != 0
})
.await
} else { } else {
db_run! { conn: { conn.run(move |conn| {
collections::table collections::table
.filter(collections::uuid.eq(&self.uuid)) .filter(collections::uuid.eq(&self.uuid))
.inner_join(users_organizations::table.on( .inner_join(
collections::org_uuid.eq(users_organizations::org_uuid) users_organizations::table.on(collections::org_uuid
.and(users_organizations::user_uuid.eq(user_uuid.clone())) .eq(users_organizations::org_uuid)
)) .and(users_organizations::user_uuid.eq(user_uuid.clone()))),
.left_join(users_collections::table.on( )
users_collections::collection_uuid.eq(collections::uuid) .left_join(
.and(users_collections::user_uuid.eq(user_uuid)) users_collections::table.on(users_collections::collection_uuid
)) .eq(collections::uuid)
.filter(users_organizations::atype.le(MembershipType::Admin as i32) // Org admin or owner .and(users_collections::user_uuid.eq(user_uuid))),
)
.filter(
users_organizations::atype
.le(MembershipType::Admin as i32) // Org admin or owner
.or(users_organizations::access_all.eq(true)) // access_all via membership .or(users_organizations::access_all.eq(true)) // access_all via membership
.or(users_collections::collection_uuid.eq(&self.uuid) // write access given to collection .or(users_collections::collection_uuid
.and(users_collections::read_only.eq(false))) .eq(&self.uuid) // write access given to collection
.and(users_collections::read_only.eq(false))),
) )
.count() .count()
.first::<i64>(conn) .first::<i64>(conn)
.ok() .ok()
.unwrap_or(0) != 0 .unwrap_or(0)
}} != 0
})
.await
} }
} }
pub async fn hide_passwords_for_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool { pub async fn hide_passwords_for_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool {
let user_uuid = user_uuid.to_string(); let user_uuid = user_uuid.to_string();
db_run! { conn: { conn.run(move |conn| {
collections::table collections::table
.left_join(users_collections::table.on( .left_join(
users_collections::collection_uuid.eq(collections::uuid).and( users_collections::table.on(users_collections::collection_uuid
users_collections::user_uuid.eq(user_uuid.clone()) .eq(collections::uuid)
) .and(users_collections::user_uuid.eq(user_uuid.clone()))),
)) )
.left_join(users_organizations::table.on( .left_join(
collections::org_uuid.eq(users_organizations::org_uuid).and( users_organizations::table.on(collections::org_uuid
users_organizations::user_uuid.eq(user_uuid) .eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid))),
)
.left_join(groups_users::table.on(groups_users::users_organizations_uuid.eq(users_organizations::uuid)))
.left_join(
groups::table.on(groups::uuid
.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
)
.left_join(
collections_groups::table.on(collections_groups::groups_uuid
.eq(groups_users::groups_uuid)
.and(collections_groups::collections_uuid.eq(collections::uuid))),
) )
))
.left_join(groups_users::table.on(
groups_users::users_organizations_uuid.eq(users_organizations::uuid)
))
.left_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))
))
.left_join(collections_groups::table.on(
collections_groups::groups_uuid.eq(groups_users::groups_uuid).and(
collections_groups::collections_uuid.eq(collections::uuid)
)
))
.filter(collections::uuid.eq(&self.uuid)) .filter(collections::uuid.eq(&self.uuid))
.filter( .filter(
users_collections::collection_uuid.eq(&self.uuid).and(users_collections::hide_passwords.eq(true)).or(// Directly accessed collection users_collections::collection_uuid
users_organizations::access_all.eq(true).or( // access_all in Organization .eq(&self.uuid)
users_organizations::atype.le(MembershipType::Admin as i32) // Org admin or owner .and(users_collections::hide_passwords.eq(true))
)).or( .or(
groups::access_all.eq(true) // access_all in groups // Directly accessed collection
).or( // access via groups users_organizations::access_all.eq(true).or(
// access_all in Organization
users_organizations::atype.le(MembershipType::Admin as i32), // Org admin or owner
),
)
.or(
groups::access_all.eq(true), // access_all in groups
)
.or(
// access via groups
groups_users::users_organizations_uuid.eq(users_organizations::uuid).and( groups_users::users_organizations_uuid.eq(users_organizations::uuid).and(
collections_groups::collections_uuid.is_not_null().and( collections_groups::collections_uuid
collections_groups::hide_passwords.eq(true)) .is_not_null()
) .and(collections_groups::hide_passwords.eq(true)),
) ),
),
) )
.count() .count()
.first::<i64>(conn) .first::<i64>(conn)
.ok() .ok()
.unwrap_or(0) != 0 .unwrap_or(0)
}} != 0
})
.await
} }
pub async fn is_coll_manageable_by_user(uuid: &CollectionId, user_uuid: &UserId, conn: &DbConn) -> bool { pub async fn is_coll_manageable_by_user(uuid: &CollectionId, user_uuid: &UserId, conn: &DbConn) -> bool {
let uuid = uuid.to_string(); let uuid = uuid.to_string();
let user_uuid = user_uuid.to_string(); let user_uuid = user_uuid.to_string();
db_run! { conn: { conn.run(move |conn| {
collections::table collections::table
.left_join(users_collections::table.on( .left_join(
users_collections::collection_uuid.eq(collections::uuid).and( users_collections::table.on(users_collections::collection_uuid
users_collections::user_uuid.eq(user_uuid.clone()) .eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid.clone()))),
)
.left_join(
users_organizations::table.on(collections::org_uuid
.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid))),
)
.left_join(groups_users::table.on(groups_users::users_organizations_uuid.eq(users_organizations::uuid)))
.left_join(
groups::table.on(groups::uuid
.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
)
.left_join(
collections_groups::table.on(collections_groups::groups_uuid
.eq(groups_users::groups_uuid)
.and(collections_groups::collections_uuid.eq(collections::uuid))),
) )
))
.left_join(users_organizations::table.on(
collections::org_uuid.eq(users_organizations::org_uuid).and(
users_organizations::user_uuid.eq(user_uuid)
)
))
.left_join(groups_users::table.on(
groups_users::users_organizations_uuid.eq(users_organizations::uuid)
))
.left_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))
))
.left_join(collections_groups::table.on(
collections_groups::groups_uuid.eq(groups_users::groups_uuid).and(
collections_groups::collections_uuid.eq(collections::uuid)
)
))
.filter(collections::uuid.eq(&uuid)) .filter(collections::uuid.eq(&uuid))
.filter( .filter(
users_collections::collection_uuid.eq(&uuid).and(users_collections::manage.eq(true)).or(// Directly accessed collection users_collections::collection_uuid
users_organizations::access_all.eq(true).or( // access_all in Organization .eq(&uuid)
users_organizations::atype.le(MembershipType::Admin as i32) // Org admin or owner .and(users_collections::manage.eq(true))
)).or( .or(
groups::access_all.eq(true) // access_all in groups // Directly accessed collection
).or( // access via groups users_organizations::access_all.eq(true).or(
// access_all in Organization
users_organizations::atype.le(MembershipType::Admin as i32), // Org admin or owner
),
)
.or(
groups::access_all.eq(true), // access_all in groups
)
.or(
// access via groups
groups_users::users_organizations_uuid.eq(users_organizations::uuid).and( groups_users::users_organizations_uuid.eq(users_organizations::uuid).and(
collections_groups::collections_uuid.is_not_null().and( collections_groups::collections_uuid
collections_groups::manage.eq(true)) .is_not_null()
) .and(collections_groups::manage.eq(true)),
) ),
),
) )
.count() .count()
.first::<i64>(conn) .first::<i64>(conn)
.ok() .ok()
.unwrap_or(0) != 0 .unwrap_or(0)
}} != 0
})
.await
} }
pub async fn is_manageable_by_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool { pub async fn is_manageable_by_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool {
@ -572,7 +632,7 @@ impl CollectionUser {
user_uuid: &UserId, user_uuid: &UserId,
conn: &DbConn, conn: &DbConn,
) -> Vec<Self> { ) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
users_collections::table users_collections::table
.filter(users_collections::user_uuid.eq(user_uuid)) .filter(users_collections::user_uuid.eq(user_uuid))
.inner_join(collections::table.on(collections::uuid.eq(users_collections::collection_uuid))) .inner_join(collections::table.on(collections::uuid.eq(users_collections::collection_uuid)))
@ -580,24 +640,35 @@ impl CollectionUser {
.select(users_collections::all_columns) .select(users_collections::all_columns)
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading users_collections") .expect("Error loading users_collections")
}} })
.await
} }
pub async fn find_by_organization_swap_user_uuid_with_member_uuid( pub async fn find_by_organization_swap_user_uuid_with_member_uuid(
org_uuid: &OrganizationId, org_uuid: &OrganizationId,
conn: &DbConn, conn: &DbConn,
) -> Vec<CollectionMembership> { ) -> Vec<CollectionMembership> {
let col_users = db_run! { conn: { let col_users = conn
.run(move |conn| {
users_collections::table users_collections::table
.inner_join(collections::table.on(collections::uuid.eq(users_collections::collection_uuid))) .inner_join(collections::table.on(collections::uuid.eq(users_collections::collection_uuid)))
.filter(collections::org_uuid.eq(org_uuid)) .filter(collections::org_uuid.eq(org_uuid))
.inner_join(users_organizations::table.on(users_organizations::user_uuid.eq(users_collections::user_uuid))) .inner_join(
users_organizations::table.on(users_organizations::user_uuid.eq(users_collections::user_uuid)),
)
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
.select((users_organizations::uuid, users_collections::collection_uuid, users_collections::read_only, users_collections::hide_passwords, users_collections::manage)) .select((
users_organizations::uuid,
users_collections::collection_uuid,
users_collections::read_only,
users_collections::hide_passwords,
users_collections::manage,
))
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading users_collections") .expect("Error loading users_collections")
}}; })
col_users.into_iter().map(|c| c.into()).collect() .await;
col_users.into_iter().map(Into::into).collect()
} }
pub async fn save( pub async fn save(
@ -666,7 +737,7 @@ impl CollectionUser {
pub async fn delete(self, conn: &DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.user_uuid, conn).await; User::update_uuid_revision(&self.user_uuid, conn).await;
db_run! { conn: { conn.run(move |conn| {
diesel::delete( diesel::delete(
users_collections::table users_collections::table
.filter(users_collections::user_uuid.eq(&self.user_uuid)) .filter(users_collections::user_uuid.eq(&self.user_uuid))
@ -674,17 +745,19 @@ impl CollectionUser {
) )
.execute(conn) .execute(conn)
.map_res("Error removing user from collection") .map_res("Error removing user from collection")
}} })
.await
} }
pub async fn find_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
users_collections::table users_collections::table
.filter(users_collections::collection_uuid.eq(collection_uuid)) .filter(users_collections::collection_uuid.eq(collection_uuid))
.select(users_collections::all_columns) .select(users_collections::all_columns)
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading users_collections") .expect("Error loading users_collections")
}} })
.await
} }
pub async fn find_by_org_and_coll_swap_user_uuid_with_member_uuid( pub async fn find_by_org_and_coll_swap_user_uuid_with_member_uuid(
@ -692,16 +765,26 @@ impl CollectionUser {
collection_uuid: &CollectionId, collection_uuid: &CollectionId,
conn: &DbConn, conn: &DbConn,
) -> Vec<CollectionMembership> { ) -> Vec<CollectionMembership> {
let col_users = db_run! { conn: { let col_users = conn
.run(move |conn| {
users_collections::table users_collections::table
.filter(users_collections::collection_uuid.eq(collection_uuid)) .filter(users_collections::collection_uuid.eq(collection_uuid))
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
.inner_join(users_organizations::table.on(users_organizations::user_uuid.eq(users_collections::user_uuid))) .inner_join(
.select((users_organizations::uuid, users_collections::collection_uuid, users_collections::read_only, users_collections::hide_passwords, users_collections::manage)) users_organizations::table.on(users_organizations::user_uuid.eq(users_collections::user_uuid)),
)
.select((
users_organizations::uuid,
users_collections::collection_uuid,
users_collections::read_only,
users_collections::hide_passwords,
users_collections::manage,
))
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading users_collections") .expect("Error loading users_collections")
}}; })
col_users.into_iter().map(|c| c.into()).collect() .await;
col_users.into_iter().map(Into::into).collect()
} }
pub async fn find_by_collection_and_user( pub async fn find_by_collection_and_user(
@ -709,36 +792,39 @@ impl CollectionUser {
user_uuid: &UserId, user_uuid: &UserId,
conn: &DbConn, conn: &DbConn,
) -> Option<Self> { ) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
users_collections::table users_collections::table
.filter(users_collections::collection_uuid.eq(collection_uuid)) .filter(users_collections::collection_uuid.eq(collection_uuid))
.filter(users_collections::user_uuid.eq(user_uuid)) .filter(users_collections::user_uuid.eq(user_uuid))
.select(users_collections::all_columns) .select(users_collections::all_columns)
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
users_collections::table users_collections::table
.filter(users_collections::user_uuid.eq(user_uuid)) .filter(users_collections::user_uuid.eq(user_uuid))
.select(users_collections::all_columns) .select(users_collections::all_columns)
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading users_collections") .expect("Error loading users_collections")
}} })
.await
} }
pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> EmptyResult {
for collection in CollectionUser::find_by_collection(collection_uuid, conn).await.iter() { for collection in &CollectionUser::find_by_collection(collection_uuid, conn).await {
User::update_uuid_revision(&collection.user_uuid, conn).await; User::update_uuid_revision(&collection.user_uuid, conn).await;
} }
db_run! { conn: { conn.run(move |conn| {
diesel::delete(users_collections::table.filter(users_collections::collection_uuid.eq(collection_uuid))) diesel::delete(users_collections::table.filter(users_collections::collection_uuid.eq(collection_uuid)))
.execute(conn) .execute(conn)
.map_res("Error deleting users from collection") .map_res("Error deleting users from collection")
}} })
.await
} }
pub async fn delete_all_by_user_and_org( pub async fn delete_all_by_user_and_org(
@ -748,17 +834,21 @@ impl CollectionUser {
) -> EmptyResult { ) -> EmptyResult {
let collectionusers = Self::find_by_organization_and_user_uuid(org_uuid, user_uuid, conn).await; let collectionusers = Self::find_by_organization_and_user_uuid(org_uuid, user_uuid, conn).await;
db_run! { conn: { conn.run(move |conn| {
for user in collectionusers { for user in collectionusers {
let _: () = diesel::delete(users_collections::table.filter( let _: () = diesel::delete(
users_collections::user_uuid.eq(user_uuid) users_collections::table.filter(
.and(users_collections::collection_uuid.eq(user.collection_uuid)) users_collections::user_uuid
)) .eq(user_uuid)
.and(users_collections::collection_uuid.eq(user.collection_uuid)),
),
)
.execute(conn) .execute(conn)
.map_res("Error removing user from collections")?; .map_res("Error removing user from collections")?;
} }
Ok(()) Ok(())
}} })
.await
} }
pub async fn has_access_to_collection_by_user(col_id: &CollectionId, user_uuid: &UserId, conn: &DbConn) -> bool { pub async fn has_access_to_collection_by_user(col_id: &CollectionId, user_uuid: &UserId, conn: &DbConn) -> bool {
@ -801,7 +891,7 @@ impl CollectionCipher {
pub async fn delete(cipher_uuid: &CipherId, collection_uuid: &CollectionId, conn: &DbConn) -> EmptyResult { pub async fn delete(cipher_uuid: &CipherId, collection_uuid: &CollectionId, conn: &DbConn) -> EmptyResult {
Self::update_users_revision(collection_uuid, conn).await; Self::update_users_revision(collection_uuid, conn).await;
db_run! { conn: { conn.run(move |conn| {
diesel::delete( diesel::delete(
ciphers_collections::table ciphers_collections::table
.filter(ciphers_collections::cipher_uuid.eq(cipher_uuid)) .filter(ciphers_collections::cipher_uuid.eq(cipher_uuid))
@ -809,23 +899,26 @@ impl CollectionCipher {
) )
.execute(conn) .execute(conn)
.map_res("Error deleting cipher from collection") .map_res("Error deleting cipher from collection")
}} })
.await
} }
pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
diesel::delete(ciphers_collections::table.filter(ciphers_collections::cipher_uuid.eq(cipher_uuid))) diesel::delete(ciphers_collections::table.filter(ciphers_collections::cipher_uuid.eq(cipher_uuid)))
.execute(conn) .execute(conn)
.map_res("Error removing cipher from collections") .map_res("Error removing cipher from collections")
}} })
.await
} }
pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
diesel::delete(ciphers_collections::table.filter(ciphers_collections::collection_uuid.eq(collection_uuid))) diesel::delete(ciphers_collections::table.filter(ciphers_collections::collection_uuid.eq(collection_uuid)))
.execute(conn) .execute(conn)
.map_res("Error removing ciphers from collection") .map_res("Error removing ciphers from collection")
}} })
.await
} }
pub async fn update_users_revision(collection_uuid: &CollectionId, conn: &DbConn) { pub async fn update_users_revision(collection_uuid: &CollectionId, conn: &DbConn) {

73
src/db/models/device.rs

@ -1,18 +1,20 @@
use chrono::{NaiveDateTime, Utc}; use chrono::{NaiveDateTime, Utc};
use data_encoding::BASE64URL; use data_encoding::BASE64URL;
use derive_more::{Display, From}; use derive_more::{Display, From};
use diesel::prelude::*;
use serde_json::Value; use serde_json::Value;
use super::{AuthRequest, UserId};
use crate::db::schema::devices;
use crate::{ use crate::{
api::EmptyResult,
crypto, crypto,
db::{DbConn, schema::devices},
error::MapResult,
util::{format_date, get_uuid}, util::{format_date, get_uuid},
}; };
use diesel::prelude::*;
use macros::{IdFromParam, UuidFromParam}; use macros::{IdFromParam, UuidFromParam};
use super::{AuthRequest, UserId};
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = devices)] #[diesel(table_name = devices)]
#[diesel(treat_none_as_null = true)] #[diesel(treat_none_as_null = true)]
@ -135,10 +137,6 @@ impl DeviceWithAuthRequest {
} }
} }
} }
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
/// Database methods /// Database methods
impl Device { impl Device {
@ -171,21 +169,23 @@ impl Device {
} }
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
diesel::delete(devices::table.filter(devices::user_uuid.eq(user_uuid))) diesel::delete(devices::table.filter(devices::user_uuid.eq(user_uuid)))
.execute(conn) .execute(conn)
.map_res("Error removing devices for user") .map_res("Error removing devices for user")
}} })
.await
} }
pub async fn find_by_uuid_and_user(uuid: &DeviceId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid_and_user(uuid: &DeviceId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
devices::table devices::table
.filter(devices::uuid.eq(uuid)) .filter(devices::uuid.eq(uuid))
.filter(devices::user_uuid.eq(user_uuid)) .filter(devices::user_uuid.eq(user_uuid))
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn find_with_auth_request_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<DeviceWithAuthRequest> { pub async fn find_with_auth_request_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<DeviceWithAuthRequest> {
@ -199,71 +199,65 @@ impl Device {
} }
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
devices::table devices::table.filter(devices::user_uuid.eq(user_uuid)).load::<Self>(conn).expect("Error loading devices")
.filter(devices::user_uuid.eq(user_uuid)) })
.load::<Self>(conn) .await
.expect("Error loading devices")
}}
} }
pub async fn find_by_uuid(uuid: &DeviceId, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &DeviceId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| devices::table.filter(devices::uuid.eq(uuid)).first::<Self>(conn).ok()).await
devices::table
.filter(devices::uuid.eq(uuid))
.first::<Self>(conn)
.ok()
}}
} }
pub async fn clear_push_token_by_uuid(uuid: &DeviceId, conn: &DbConn) -> EmptyResult { pub async fn clear_push_token_by_uuid(uuid: &DeviceId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
diesel::update(devices::table) diesel::update(devices::table)
.filter(devices::uuid.eq(uuid)) .filter(devices::uuid.eq(uuid))
.set(devices::push_token.eq::<Option<String>>(None)) .set(devices::push_token.eq::<Option<String>>(None))
.execute(conn) .execute(conn)
.map_res("Error removing push token") .map_res("Error removing push token")
}} })
.await
} }
pub async fn find_by_refresh_token(refresh_token: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_refresh_token(refresh_token: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| devices::table.filter(devices::refresh_token.eq(refresh_token)).first::<Self>(conn).ok())
devices::table .await
.filter(devices::refresh_token.eq(refresh_token))
.first::<Self>(conn)
.ok()
}}
} }
pub async fn find_latest_active_by_user(user_uuid: &UserId, conn: &DbConn) -> Option<Self> { pub async fn find_latest_active_by_user(user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
devices::table devices::table
.filter(devices::user_uuid.eq(user_uuid)) .filter(devices::user_uuid.eq(user_uuid))
.order(devices::updated_at.desc()) .order(devices::updated_at.desc())
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn find_push_devices_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> { pub async fn find_push_devices_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
devices::table devices::table
.filter(devices::user_uuid.eq(user_uuid)) .filter(devices::user_uuid.eq(user_uuid))
.filter(devices::push_token.is_not_null()) .filter(devices::push_token.is_not_null())
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading push devices") .expect("Error loading push devices")
}} })
.await
} }
pub async fn check_user_has_push_device(user_uuid: &UserId, conn: &DbConn) -> bool { pub async fn check_user_has_push_device(user_uuid: &UserId, conn: &DbConn) -> bool {
db_run! { conn: { conn.run(move |conn| {
devices::table devices::table
.filter(devices::user_uuid.eq(user_uuid)) .filter(devices::user_uuid.eq(user_uuid))
.filter(devices::push_token.is_not_null()) .filter(devices::push_token.is_not_null())
.count() .count()
.first::<i64>(conn) .first::<i64>(conn)
.ok() .ok()
.unwrap_or(0) != 0 .unwrap_or(0)
}} != 0
})
.await
} }
pub async fn rotate_refresh_tokens_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult { pub async fn rotate_refresh_tokens_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
@ -337,6 +331,7 @@ pub enum DeviceType {
} }
impl DeviceType { impl DeviceType {
#[expect(clippy::match_same_arms, reason = "Specifically define 14 and have a fallback for new types")]
pub fn from_i32(value: i32) -> DeviceType { pub fn from_i32(value: i32) -> DeviceType {
match value { match value {
0 => DeviceType::Android, 0 => DeviceType::Android,

98
src/db/models/emergency_access.rs

@ -1,13 +1,17 @@
use chrono::{NaiveDateTime, Utc}; use chrono::{NaiveDateTime, Utc};
use derive_more::{AsRef, Deref, Display, From}; use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use serde_json::Value; use serde_json::Value;
use super::{User, UserId}; use crate::{
use crate::db::schema::emergency_access; api::EmptyResult,
use crate::{api::EmptyResult, db::DbConn, error::MapResult}; db::{DbConn, schema::emergency_access},
use diesel::prelude::*; error::MapResult,
};
use macros::UuidFromParam; use macros::UuidFromParam;
use super::{User, UserId};
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = emergency_access)] #[diesel(table_name = emergency_access)]
#[diesel(treat_none_as_null = true)] #[diesel(treat_none_as_null = true)]
@ -87,14 +91,13 @@ impl EmergencyAccess {
User::find_by_uuid(grantee_uuid, conn).await.expect("Grantee user not found.") User::find_by_uuid(grantee_uuid, conn).await.expect("Grantee user not found.")
} else { } else {
let email = self.email.as_deref()?; let email = self.email.as_deref()?;
match User::find_by_mail(email, conn).await { if let Some(user) = User::find_by_mail(email, conn).await {
Some(user) => user, user
None => { } else {
// remove outstanding invitations which should not exist // remove outstanding invitations which should not exist
Self::delete_all_by_grantee_email(email, conn).await.ok(); Self::delete_all_by_grantee_email(email, conn).await.ok();
return None; return None;
} }
}
}; };
Some(json!({ Some(json!({
@ -183,28 +186,36 @@ impl EmergencyAccess {
self.status = status; self.status = status;
date.clone_into(&mut self.updated_at); date.clone_into(&mut self.updated_at);
db_run! { conn: { conn.run(move |conn| {
crate::util::retry(|| { crate::util::retry(
|| {
diesel::update(emergency_access::table.filter(emergency_access::uuid.eq(&self.uuid))) diesel::update(emergency_access::table.filter(emergency_access::uuid.eq(&self.uuid)))
.set((emergency_access::status.eq(status), emergency_access::updated_at.eq(date))) .set((emergency_access::status.eq(status), emergency_access::updated_at.eq(date)))
.execute(conn) .execute(conn)
}, 10) },
10,
)
.map_res("Error updating emergency access status") .map_res("Error updating emergency access status")
}} })
.await
} }
pub async fn update_last_notification_date_and_save(&mut self, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult { pub async fn update_last_notification_date_and_save(&mut self, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult {
self.last_notification_at = Some(date.to_owned()); self.last_notification_at = Some(date.to_owned());
date.clone_into(&mut self.updated_at); date.clone_into(&mut self.updated_at);
db_run! { conn: { conn.run(move |conn| {
crate::util::retry(|| { crate::util::retry(
|| {
diesel::update(emergency_access::table.filter(emergency_access::uuid.eq(&self.uuid))) diesel::update(emergency_access::table.filter(emergency_access::uuid.eq(&self.uuid)))
.set((emergency_access::last_notification_at.eq(date), emergency_access::updated_at.eq(date))) .set((emergency_access::last_notification_at.eq(date), emergency_access::updated_at.eq(date)))
.execute(conn) .execute(conn)
}, 10) },
10,
)
.map_res("Error updating emergency access status") .map_res("Error updating emergency access status")
}} })
.await
} }
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
@ -227,11 +238,12 @@ impl EmergencyAccess {
pub async fn delete(self, conn: &DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.grantor_uuid, conn).await; User::update_uuid_revision(&self.grantor_uuid, conn).await;
db_run! { conn: { conn.run(move |conn| {
diesel::delete(emergency_access::table.filter(emergency_access::uuid.eq(self.uuid))) diesel::delete(emergency_access::table.filter(emergency_access::uuid.eq(self.uuid)))
.execute(conn) .execute(conn)
.map_res("Error removing user from emergency access") .map_res("Error removing user from emergency access")
}} })
.await
} }
pub async fn find_by_grantor_uuid_and_grantee_uuid_or_email( pub async fn find_by_grantor_uuid_and_grantee_uuid_or_email(
@ -240,23 +252,25 @@ impl EmergencyAccess {
email: &str, email: &str,
conn: &DbConn, conn: &DbConn,
) -> Option<Self> { ) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
emergency_access::table emergency_access::table
.filter(emergency_access::grantor_uuid.eq(grantor_uuid)) .filter(emergency_access::grantor_uuid.eq(grantor_uuid))
.filter(emergency_access::grantee_uuid.eq(grantee_uuid).or(emergency_access::email.eq(email))) .filter(emergency_access::grantee_uuid.eq(grantee_uuid).or(emergency_access::email.eq(email)))
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn find_all_recoveries_initiated(conn: &DbConn) -> Vec<Self> { pub async fn find_all_recoveries_initiated(conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
emergency_access::table emergency_access::table
.filter(emergency_access::status.eq(EmergencyAccessStatus::RecoveryInitiated as i32)) .filter(emergency_access::status.eq(EmergencyAccessStatus::RecoveryInitiated as i32))
.filter(emergency_access::recovery_initiated_at.is_not_null()) .filter(emergency_access::recovery_initiated_at.is_not_null())
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading emergency_access") .expect("Error loading emergency_access")
}} })
.await
} }
pub async fn find_by_uuid_and_grantor_uuid( pub async fn find_by_uuid_and_grantor_uuid(
@ -264,13 +278,14 @@ impl EmergencyAccess {
grantor_uuid: &UserId, grantor_uuid: &UserId,
conn: &DbConn, conn: &DbConn,
) -> Option<Self> { ) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
emergency_access::table emergency_access::table
.filter(emergency_access::uuid.eq(uuid)) .filter(emergency_access::uuid.eq(uuid))
.filter(emergency_access::grantor_uuid.eq(grantor_uuid)) .filter(emergency_access::grantor_uuid.eq(grantor_uuid))
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn find_by_uuid_and_grantee_uuid( pub async fn find_by_uuid_and_grantee_uuid(
@ -278,13 +293,14 @@ impl EmergencyAccess {
grantee_uuid: &UserId, grantee_uuid: &UserId,
conn: &DbConn, conn: &DbConn,
) -> Option<Self> { ) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
emergency_access::table emergency_access::table
.filter(emergency_access::uuid.eq(uuid)) .filter(emergency_access::uuid.eq(uuid))
.filter(emergency_access::grantee_uuid.eq(grantee_uuid)) .filter(emergency_access::grantee_uuid.eq(grantee_uuid))
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn find_by_uuid_and_grantee_email( pub async fn find_by_uuid_and_grantee_email(
@ -292,61 +308,67 @@ impl EmergencyAccess {
grantee_email: &str, grantee_email: &str,
conn: &DbConn, conn: &DbConn,
) -> Option<Self> { ) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
emergency_access::table emergency_access::table
.filter(emergency_access::uuid.eq(uuid)) .filter(emergency_access::uuid.eq(uuid))
.filter(emergency_access::email.eq(grantee_email)) .filter(emergency_access::email.eq(grantee_email))
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn find_all_by_grantee_uuid(grantee_uuid: &UserId, conn: &DbConn) -> Vec<Self> { pub async fn find_all_by_grantee_uuid(grantee_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
emergency_access::table emergency_access::table
.filter(emergency_access::grantee_uuid.eq(grantee_uuid)) .filter(emergency_access::grantee_uuid.eq(grantee_uuid))
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading emergency_access") .expect("Error loading emergency_access")
}} })
.await
} }
pub async fn find_invited_by_grantee_email(grantee_email: &str, conn: &DbConn) -> Option<Self> { pub async fn find_invited_by_grantee_email(grantee_email: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
emergency_access::table emergency_access::table
.filter(emergency_access::email.eq(grantee_email)) .filter(emergency_access::email.eq(grantee_email))
.filter(emergency_access::status.eq(EmergencyAccessStatus::Invited as i32)) .filter(emergency_access::status.eq(EmergencyAccessStatus::Invited as i32))
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn find_all_invited_by_grantee_email(grantee_email: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_all_invited_by_grantee_email(grantee_email: &str, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
emergency_access::table emergency_access::table
.filter(emergency_access::email.eq(grantee_email)) .filter(emergency_access::email.eq(grantee_email))
.filter(emergency_access::status.eq(EmergencyAccessStatus::Invited as i32)) .filter(emergency_access::status.eq(EmergencyAccessStatus::Invited as i32))
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading emergency_access") .expect("Error loading emergency_access")
}} })
.await
} }
pub async fn find_all_by_grantor_uuid(grantor_uuid: &UserId, conn: &DbConn) -> Vec<Self> { pub async fn find_all_by_grantor_uuid(grantor_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
emergency_access::table emergency_access::table
.filter(emergency_access::grantor_uuid.eq(grantor_uuid)) .filter(emergency_access::grantor_uuid.eq(grantor_uuid))
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading emergency_access") .expect("Error loading emergency_access")
}} })
.await
} }
pub async fn find_all_confirmed_by_grantor_uuid(grantor_uuid: &UserId, conn: &DbConn) -> Vec<Self> { pub async fn find_all_confirmed_by_grantor_uuid(grantor_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
emergency_access::table emergency_access::table
.filter(emergency_access::grantor_uuid.eq(grantor_uuid)) .filter(emergency_access::grantor_uuid.eq(grantor_uuid))
.filter(emergency_access::status.ge(EmergencyAccessStatus::Confirmed as i32)) .filter(emergency_access::status.ge(EmergencyAccessStatus::Confirmed as i32))
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading emergency_access") .expect("Error loading emergency_access")
}} })
.await
} }
pub async fn accept_invite(&mut self, grantee_uuid: &UserId, grantee_email: &str, conn: &DbConn) -> EmptyResult { pub async fn accept_invite(&mut self, grantee_uuid: &UserId, grantee_email: &str, conn: &DbConn) -> EmptyResult {

62
src/db/models/event.rs

@ -1,11 +1,18 @@
use chrono::{NaiveDateTime, TimeDelta, Utc}; use chrono::{NaiveDateTime, TimeDelta, Utc};
//use derive_more::{AsRef, Deref, Display, From}; use diesel::prelude::*;
use serde_json::Value; use serde_json::Value;
use crate::{
CONFIG,
api::EmptyResult,
db::{
DbConn,
schema::{event, users_organizations},
},
error::MapResult,
};
use super::{CipherId, CollectionId, GroupId, MembershipId, OrgPolicyId, OrganizationId, UserId}; use super::{CipherId, CollectionId, GroupId, MembershipId, OrgPolicyId, OrganizationId, UserId};
use crate::db::schema::{event, users_organizations};
use crate::{api::EmptyResult, db::DbConn, error::MapResult, CONFIG};
use diesel::prelude::*;
// https://bitwarden.com/help/event-logs/ // https://bitwarden.com/help/event-logs/
@ -249,11 +256,10 @@ impl Event {
} }
pub async fn delete(self, conn: &DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
diesel::delete(event::table.filter(event::uuid.eq(self.uuid))) diesel::delete(event::table.filter(event::uuid.eq(self.uuid))).execute(conn).map_res("Error deleting event")
.execute(conn) })
.map_res("Error deleting event") .await
}}
} }
/// ############## /// ##############
@ -264,7 +270,7 @@ impl Event {
end: &NaiveDateTime, end: &NaiveDateTime,
conn: &DbConn, conn: &DbConn,
) -> Vec<Self> { ) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
event::table event::table
.filter(event::org_uuid.eq(org_uuid)) .filter(event::org_uuid.eq(org_uuid))
.filter(event::event_date.between(start, end)) .filter(event::event_date.between(start, end))
@ -272,18 +278,15 @@ impl Event {
.limit(Self::PAGE_SIZE) .limit(Self::PAGE_SIZE)
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error filtering events") .expect("Error filtering events")
}} })
.await
} }
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 { pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: { conn.run(move |conn| {
event::table event::table.filter(event::org_uuid.eq(org_uuid)).count().first::<i64>(conn).ok().unwrap_or(0)
.filter(event::org_uuid.eq(org_uuid)) })
.count() .await
.first::<i64>(conn)
.ok()
.unwrap_or(0)
}}
} }
pub async fn find_by_org_and_member( pub async fn find_by_org_and_member(
@ -293,18 +296,23 @@ impl Event {
end: &NaiveDateTime, end: &NaiveDateTime,
conn: &DbConn, conn: &DbConn,
) -> Vec<Self> { ) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
event::table event::table
.inner_join(users_organizations::table.on(users_organizations::uuid.eq(member_uuid))) .inner_join(users_organizations::table.on(users_organizations::uuid.eq(member_uuid)))
.filter(event::org_uuid.eq(org_uuid)) .filter(event::org_uuid.eq(org_uuid))
.filter(event::event_date.between(start, end)) .filter(event::event_date.between(start, end))
.filter(event::user_uuid.eq(users_organizations::user_uuid.nullable()).or(event::act_user_uuid.eq(users_organizations::user_uuid.nullable()))) .filter(
event::user_uuid
.eq(users_organizations::user_uuid.nullable())
.or(event::act_user_uuid.eq(users_organizations::user_uuid.nullable())),
)
.select(event::all_columns) .select(event::all_columns)
.order_by(event::event_date.desc()) .order_by(event::event_date.desc())
.limit(Self::PAGE_SIZE) .limit(Self::PAGE_SIZE)
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error filtering events") .expect("Error filtering events")
}} })
.await
} }
pub async fn find_by_cipher_uuid( pub async fn find_by_cipher_uuid(
@ -313,7 +321,7 @@ impl Event {
end: &NaiveDateTime, end: &NaiveDateTime,
conn: &DbConn, conn: &DbConn,
) -> Vec<Self> { ) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
event::table event::table
.filter(event::cipher_uuid.eq(cipher_uuid)) .filter(event::cipher_uuid.eq(cipher_uuid))
.filter(event::event_date.between(start, end)) .filter(event::event_date.between(start, end))
@ -321,17 +329,19 @@ impl Event {
.limit(Self::PAGE_SIZE) .limit(Self::PAGE_SIZE)
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error filtering events") .expect("Error filtering events")
}} })
.await
} }
pub async fn clean_events(conn: &DbConn) -> EmptyResult { pub async fn clean_events(conn: &DbConn) -> EmptyResult {
if let Some(days_to_retain) = CONFIG.events_days_retain() { if let Some(days_to_retain) = CONFIG.events_days_retain() {
let dt = Utc::now().naive_utc() - TimeDelta::try_days(days_to_retain).unwrap(); let dt = Utc::now().naive_utc() - TimeDelta::try_days(days_to_retain).unwrap();
db_run! { conn: { conn.run(move |conn| {
diesel::delete(event::table.filter(event::event_date.lt(dt))) diesel::delete(event::table.filter(event::event_date.lt(dt)))
.execute(conn) .execute(conn)
.map_res("Error cleaning old events") .map_res("Error cleaning old events")
}} })
.await
} else { } else {
Ok(()) Ok(())
} }

56
src/db/models/favorite.rs

@ -1,7 +1,13 @@
use super::{CipherId, User, UserId};
use crate::db::schema::favorites;
use diesel::prelude::*; use diesel::prelude::*;
use crate::{
api::EmptyResult,
db::{DbConn, schema::favorites},
error::MapResult,
};
use super::{CipherId, User, UserId};
#[derive(Identifiable, Queryable, Insertable)] #[derive(Identifiable, Queryable, Insertable)]
#[diesel(table_name = favorites)] #[diesel(table_name = favorites)]
#[diesel(primary_key(user_uuid, cipher_uuid))] #[diesel(primary_key(user_uuid, cipher_uuid))]
@ -10,24 +16,18 @@ pub struct Favorite {
pub cipher_uuid: CipherId, pub cipher_uuid: CipherId,
} }
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
impl Favorite { impl Favorite {
// Returns whether the specified cipher is a favorite of the specified user. // Returns whether the specified cipher is a favorite of the specified user.
pub async fn is_favorite(cipher_uuid: &CipherId, user_uuid: &UserId, conn: &DbConn) -> bool { pub async fn is_favorite(cipher_uuid: &CipherId, user_uuid: &UserId, conn: &DbConn) -> bool {
db_run! { conn: { conn.run(move |conn| {
let query = favorites::table let query = favorites::table
.filter(favorites::cipher_uuid.eq(cipher_uuid)) .filter(favorites::cipher_uuid.eq(cipher_uuid))
.filter(favorites::user_uuid.eq(user_uuid)) .filter(favorites::user_uuid.eq(user_uuid))
.count(); .count();
query.first::<i64>(conn) query.first::<i64>(conn).ok().unwrap_or(0) != 0
.ok() })
.unwrap_or(0) != 0 .await
}}
} }
// Sets whether the specified cipher is a favorite of the specified user. // Sets whether the specified cipher is a favorite of the specified user.
@ -41,27 +41,26 @@ impl Favorite {
match (old, new) { match (old, new) {
(false, true) => { (false, true) => {
User::update_uuid_revision(user_uuid, conn).await; User::update_uuid_revision(user_uuid, conn).await;
db_run! { conn: { conn.run(move |conn| {
diesel::insert_into(favorites::table) diesel::insert_into(favorites::table)
.values(( .values((favorites::user_uuid.eq(user_uuid), favorites::cipher_uuid.eq(cipher_uuid)))
favorites::user_uuid.eq(user_uuid),
favorites::cipher_uuid.eq(cipher_uuid),
))
.execute(conn) .execute(conn)
.map_res("Error adding favorite") .map_res("Error adding favorite")
}} })
.await
} }
(true, false) => { (true, false) => {
User::update_uuid_revision(user_uuid, conn).await; User::update_uuid_revision(user_uuid, conn).await;
db_run! { conn: { conn.run(move |conn| {
diesel::delete( diesel::delete(
favorites::table favorites::table
.filter(favorites::user_uuid.eq(user_uuid)) .filter(favorites::user_uuid.eq(user_uuid))
.filter(favorites::cipher_uuid.eq(cipher_uuid)) .filter(favorites::cipher_uuid.eq(cipher_uuid)),
) )
.execute(conn) .execute(conn)
.map_res("Error removing favorite") .map_res("Error removing favorite")
}} })
.await
} }
// Otherwise, the favorite status is already what it should be. // Otherwise, the favorite status is already what it should be.
_ => Ok(()), _ => Ok(()),
@ -70,31 +69,34 @@ impl Favorite {
// Delete all favorite entries associated with the specified cipher. // Delete all favorite entries associated with the specified cipher.
pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
diesel::delete(favorites::table.filter(favorites::cipher_uuid.eq(cipher_uuid))) diesel::delete(favorites::table.filter(favorites::cipher_uuid.eq(cipher_uuid)))
.execute(conn) .execute(conn)
.map_res("Error removing favorites by cipher") .map_res("Error removing favorites by cipher")
}} })
.await
} }
// Delete all favorite entries associated with the specified user. // Delete all favorite entries associated with the specified user.
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
diesel::delete(favorites::table.filter(favorites::user_uuid.eq(user_uuid))) diesel::delete(favorites::table.filter(favorites::user_uuid.eq(user_uuid)))
.execute(conn) .execute(conn)
.map_res("Error removing favorites by user") .map_res("Error removing favorites by user")
}} })
.await
} }
/// Return a vec with (cipher_uuid) this will only contain favorite flagged ciphers /// Return a vec with (cipher_uuid) this will only contain favorite flagged ciphers
/// This is used during a full sync so we only need one query for all favorite cipher matches. /// This is used during a full sync so we only need one query for all favorite cipher matches.
pub async fn get_all_cipher_uuid_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<CipherId> { pub async fn get_all_cipher_uuid_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<CipherId> {
db_run! { conn: { conn.run(move |conn| {
favorites::table favorites::table
.filter(favorites::user_uuid.eq(user_uuid)) .filter(favorites::user_uuid.eq(user_uuid))
.select(favorites::cipher_uuid) .select(favorites::cipher_uuid)
.load::<CipherId>(conn) .load::<CipherId>(conn)
.unwrap_or_default() .unwrap_or_default()
}} })
.await
} }
} }

71
src/db/models/folder.rs

@ -1,12 +1,20 @@
use chrono::{NaiveDateTime, Utc}; use chrono::{NaiveDateTime, Utc};
use derive_more::{AsRef, Deref, Display, From}; use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use serde_json::Value; use serde_json::Value;
use super::{CipherId, User, UserId}; use crate::{
use crate::db::schema::{folders, folders_ciphers}; api::EmptyResult,
use diesel::prelude::*; db::{
DbConn,
schema::{folders, folders_ciphers},
},
error::MapResult,
};
use macros::UuidFromParam; use macros::UuidFromParam;
use super::{CipherId, User, UserId};
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = folders)] #[diesel(table_name = folders)]
#[diesel(primary_key(uuid))] #[diesel(primary_key(uuid))]
@ -56,17 +64,12 @@ impl Folder {
impl FolderCipher { impl FolderCipher {
pub fn new(folder_uuid: FolderId, cipher_uuid: CipherId) -> Self { pub fn new(folder_uuid: FolderId, cipher_uuid: CipherId) -> Self {
Self { Self {
folder_uuid,
cipher_uuid, cipher_uuid,
folder_uuid,
} }
} }
} }
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
/// Database methods /// Database methods
impl Folder { impl Folder {
pub async fn save(&mut self, conn: &DbConn) -> EmptyResult { pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
@ -107,11 +110,12 @@ impl Folder {
User::update_uuid_revision(&self.user_uuid, conn).await; User::update_uuid_revision(&self.user_uuid, conn).await;
FolderCipher::delete_all_by_folder(&self.uuid, conn).await?; FolderCipher::delete_all_by_folder(&self.uuid, conn).await?;
db_run! { conn: { conn.run(move |conn| {
diesel::delete(folders::table.filter(folders::uuid.eq(&self.uuid))) diesel::delete(folders::table.filter(folders::uuid.eq(&self.uuid)))
.execute(conn) .execute(conn)
.map_res("Error deleting folder") .map_res("Error deleting folder")
}} })
.await
} }
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
@ -122,22 +126,21 @@ impl Folder {
} }
pub async fn find_by_uuid_and_user(uuid: &FolderId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid_and_user(uuid: &FolderId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
folders::table folders::table
.filter(folders::uuid.eq(uuid)) .filter(folders::uuid.eq(uuid))
.filter(folders::user_uuid.eq(user_uuid)) .filter(folders::user_uuid.eq(user_uuid))
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
folders::table folders::table.filter(folders::user_uuid.eq(user_uuid)).load::<Self>(conn).expect("Error loading folders")
.filter(folders::user_uuid.eq(user_uuid)) })
.load::<Self>(conn) .await
.expect("Error loading folders")
}}
} }
} }
@ -165,7 +168,7 @@ impl FolderCipher {
} }
pub async fn delete(self, conn: &DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
diesel::delete( diesel::delete(
folders_ciphers::table folders_ciphers::table
.filter(folders_ciphers::cipher_uuid.eq(self.cipher_uuid)) .filter(folders_ciphers::cipher_uuid.eq(self.cipher_uuid))
@ -173,23 +176,26 @@ impl FolderCipher {
) )
.execute(conn) .execute(conn)
.map_res("Error removing cipher from folder") .map_res("Error removing cipher from folder")
}} })
.await
} }
pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
diesel::delete(folders_ciphers::table.filter(folders_ciphers::cipher_uuid.eq(cipher_uuid))) diesel::delete(folders_ciphers::table.filter(folders_ciphers::cipher_uuid.eq(cipher_uuid)))
.execute(conn) .execute(conn)
.map_res("Error removing cipher from folders") .map_res("Error removing cipher from folders")
}} })
.await
} }
pub async fn delete_all_by_folder(folder_uuid: &FolderId, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_folder(folder_uuid: &FolderId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
diesel::delete(folders_ciphers::table.filter(folders_ciphers::folder_uuid.eq(folder_uuid))) diesel::delete(folders_ciphers::table.filter(folders_ciphers::folder_uuid.eq(folder_uuid)))
.execute(conn) .execute(conn)
.map_res("Error removing ciphers from folder") .map_res("Error removing ciphers from folder")
}} })
.await
} }
pub async fn find_by_folder_and_cipher( pub async fn find_by_folder_and_cipher(
@ -197,35 +203,38 @@ impl FolderCipher {
cipher_uuid: &CipherId, cipher_uuid: &CipherId,
conn: &DbConn, conn: &DbConn,
) -> Option<Self> { ) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
folders_ciphers::table folders_ciphers::table
.filter(folders_ciphers::folder_uuid.eq(folder_uuid)) .filter(folders_ciphers::folder_uuid.eq(folder_uuid))
.filter(folders_ciphers::cipher_uuid.eq(cipher_uuid)) .filter(folders_ciphers::cipher_uuid.eq(cipher_uuid))
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn find_by_folder(folder_uuid: &FolderId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_folder(folder_uuid: &FolderId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
folders_ciphers::table folders_ciphers::table
.filter(folders_ciphers::folder_uuid.eq(folder_uuid)) .filter(folders_ciphers::folder_uuid.eq(folder_uuid))
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading folders") .expect("Error loading folders")
}} })
.await
} }
/// Return a vec with (cipher_uuid, folder_uuid) /// Return a vec with (cipher_uuid, folder_uuid)
/// This is used during a full sync so we only need one query for all folder matches. /// This is used during a full sync so we only need one query for all folder matches.
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<(CipherId, FolderId)> { pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<(CipherId, FolderId)> {
db_run! { conn: { conn.run(move |conn| {
folders_ciphers::table folders_ciphers::table
.inner_join(folders::table) .inner_join(folders::table)
.filter(folders::user_uuid.eq(user_uuid)) .filter(folders::user_uuid.eq(user_uuid))
.select(folders_ciphers::all_columns) .select(folders_ciphers::all_columns)
.load::<(CipherId, FolderId)>(conn) .load::<(CipherId, FolderId)>(conn)
.unwrap_or_default() .unwrap_or_default()
}} })
.await
} }
} }

259
src/db/models/group.rs

@ -1,14 +1,20 @@
use super::{CollectionId, Membership, MembershipId, OrganizationId, User, UserId};
use crate::api::EmptyResult;
use crate::db::schema::{collections, collections_groups, groups, groups_users, users_organizations};
use crate::db::DbConn;
use crate::error::MapResult;
use chrono::{NaiveDateTime, Utc}; use chrono::{NaiveDateTime, Utc};
use derive_more::{AsRef, Deref, Display, From}; use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*; use diesel::prelude::*;
use macros::UuidFromParam;
use serde_json::Value; use serde_json::Value;
use crate::{
api::EmptyResult,
db::{
DbConn,
schema::{collections, collections_groups, groups, groups_users, users_organizations},
},
error::MapResult,
};
use macros::UuidFromParam;
use super::{CollectionId, Membership, MembershipId, OrganizationId, User, UserId};
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = groups)] #[diesel(table_name = groups)]
#[diesel(treat_none_as_null = true)] #[diesel(treat_none_as_null = true)]
@ -197,33 +203,31 @@ impl Group {
} }
pub async fn find_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
groups::table groups::table
.filter(groups::organizations_uuid.eq(org_uuid)) .filter(groups::organizations_uuid.eq(org_uuid))
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading groups") .expect("Error loading groups")
}} })
.await
} }
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 { pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: { conn.run(move |conn| {
groups::table groups::table.filter(groups::organizations_uuid.eq(org_uuid)).count().first::<i64>(conn).ok().unwrap_or(0)
.filter(groups::organizations_uuid.eq(org_uuid)) })
.count() .await
.first::<i64>(conn)
.ok()
.unwrap_or(0)
}}
} }
pub async fn find_by_uuid_and_org(uuid: &GroupId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid_and_org(uuid: &GroupId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
groups::table groups::table
.filter(groups::uuid.eq(uuid)) .filter(groups::uuid.eq(uuid))
.filter(groups::organizations_uuid.eq(org_uuid)) .filter(groups::organizations_uuid.eq(org_uuid))
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn find_by_external_id_and_org( pub async fn find_by_external_id_and_org(
@ -231,77 +235,85 @@ impl Group {
org_uuid: &OrganizationId, org_uuid: &OrganizationId,
conn: &DbConn, conn: &DbConn,
) -> Option<Self> { ) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
groups::table groups::table
.filter(groups::external_id.eq(external_id)) .filter(groups::external_id.eq(external_id))
.filter(groups::organizations_uuid.eq(org_uuid)) .filter(groups::organizations_uuid.eq(org_uuid))
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
//Returns all organizations the user has full access to //Returns all organizations the user has full access to
pub async fn get_orgs_by_user_with_full_access(user_uuid: &UserId, conn: &DbConn) -> Vec<OrganizationId> { pub async fn get_orgs_by_user_with_full_access(user_uuid: &UserId, conn: &DbConn) -> Vec<OrganizationId> {
db_run! { conn: { conn.run(move |conn| {
groups_users::table groups_users::table
.inner_join(users_organizations::table.on( .inner_join(
users_organizations::uuid.eq(groups_users::users_organizations_uuid) users_organizations::table.on(users_organizations::uuid.eq(groups_users::users_organizations_uuid)),
)) )
.inner_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid) .inner_join(
.and(groups::organizations_uuid.eq(users_organizations::org_uuid)) groups::table.on(groups::uuid
)) .eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
)
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
.filter(groups::access_all.eq(true)) .filter(groups::access_all.eq(true))
.select(groups::organizations_uuid) .select(groups::organizations_uuid)
.distinct() .distinct()
.load::<OrganizationId>(conn) .load::<OrganizationId>(conn)
.expect("Error loading organization group full access information for user") .expect("Error loading organization group full access information for user")
}} })
.await
} }
pub async fn is_in_full_access_group(user_uuid: &UserId, org_uuid: &OrganizationId, conn: &DbConn) -> bool { pub async fn is_in_full_access_group(user_uuid: &UserId, org_uuid: &OrganizationId, conn: &DbConn) -> bool {
db_run! { conn: { conn.run(move |conn| {
groups::table groups::table
.inner_join(groups_users::table.on( .inner_join(groups_users::table.on(groups_users::groups_uuid.eq(groups::uuid)))
groups_users::groups_uuid.eq(groups::uuid) .inner_join(
)) users_organizations::table.on(users_organizations::uuid.eq(groups_users::users_organizations_uuid)),
.inner_join(users_organizations::table.on( )
users_organizations::uuid.eq(groups_users::users_organizations_uuid)
))
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
.filter(groups::organizations_uuid.eq(org_uuid)) .filter(groups::organizations_uuid.eq(org_uuid))
.filter(groups::access_all.eq(true)) .filter(groups::access_all.eq(true))
.select(groups::access_all) .select(groups::access_all)
.first::<bool>(conn) .first::<bool>(conn)
.unwrap_or_default() .unwrap_or_default()
}} })
.await
} }
pub async fn delete(&self, org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult { pub async fn delete(&self, org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
CollectionGroup::delete_all_by_group(&self.uuid, org_uuid, conn).await?; CollectionGroup::delete_all_by_group(&self.uuid, org_uuid, conn).await?;
GroupUser::delete_all_by_group(&self.uuid, org_uuid, conn).await?; GroupUser::delete_all_by_group(&self.uuid, org_uuid, conn).await?;
db_run! { conn: { conn.run(move |conn| {
diesel::delete(groups::table.filter(groups::uuid.eq(&self.uuid))) diesel::delete(groups::table.filter(groups::uuid.eq(&self.uuid)))
.execute(conn) .execute(conn)
.map_res("Error deleting group") .map_res("Error deleting group")
}} })
.await
} }
pub async fn update_revision(uuid: &GroupId, conn: &DbConn) { pub async fn update_revision(uuid: &GroupId, conn: &DbConn) {
if let Err(e) = Self::_update_revision(uuid, &Utc::now().naive_utc(), conn).await { if let Err(e) = Self::update_revision_impl(uuid, &Utc::now().naive_utc(), conn).await {
warn!("Failed to update revision for {uuid}: {e:#?}"); warn!("Failed to update revision for {uuid}: {e:#?}");
} }
} }
async fn _update_revision(uuid: &GroupId, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult { async fn update_revision_impl(uuid: &GroupId, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
crate::util::retry(|| { crate::util::retry(
|| {
diesel::update(groups::table.filter(groups::uuid.eq(uuid))) diesel::update(groups::table.filter(groups::uuid.eq(uuid)))
.set(groups::revision_date.eq(date)) .set(groups::revision_date.eq(date))
.execute(conn) .execute(conn)
}, 10) },
10,
)
.map_res("Error updating group revision") .map_res("Error updating group revision")
}} })
.await
} }
} }
@ -366,60 +378,63 @@ impl CollectionGroup {
} }
pub async fn find_by_group(group_uuid: &GroupId, org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_group(group_uuid: &GroupId, org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
collections_groups::table collections_groups::table
.inner_join(groups::table.on( .inner_join(groups::table.on(groups::uuid.eq(collections_groups::groups_uuid)))
groups::uuid.eq(collections_groups::groups_uuid) .inner_join(
)) collections::table.on(collections::uuid
.inner_join(collections::table.on( .eq(collections_groups::collections_uuid)
collections::uuid.eq(collections_groups::collections_uuid) .and(collections::org_uuid.eq(groups::organizations_uuid))),
.and(collections::org_uuid.eq(groups::organizations_uuid)) )
))
.filter(collections_groups::groups_uuid.eq(group_uuid)) .filter(collections_groups::groups_uuid.eq(group_uuid))
.filter(collections::org_uuid.eq(org_uuid)) .filter(collections::org_uuid.eq(org_uuid))
.select(collections_groups::all_columns) .select(collections_groups::all_columns)
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading collection groups") .expect("Error loading collection groups")
}} })
.await
} }
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
collections_groups::table collections_groups::table
.inner_join(groups_users::table.on( .inner_join(groups_users::table.on(groups_users::groups_uuid.eq(collections_groups::groups_uuid)))
groups_users::groups_uuid.eq(collections_groups::groups_uuid) .inner_join(
)) users_organizations::table.on(users_organizations::uuid.eq(groups_users::users_organizations_uuid)),
.inner_join(users_organizations::table.on( )
users_organizations::uuid.eq(groups_users::users_organizations_uuid) .inner_join(
)) groups::table.on(groups::uuid
.inner_join(groups::table.on(groups::uuid.eq(collections_groups::groups_uuid) .eq(collections_groups::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid)) .and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
)) )
.inner_join(collections::table.on( .inner_join(
collections::uuid.eq(collections_groups::collections_uuid) collections::table.on(collections::uuid
.and(collections::org_uuid.eq(groups::organizations_uuid)) .eq(collections_groups::collections_uuid)
)) .and(collections::org_uuid.eq(groups::organizations_uuid))),
)
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
.select(collections_groups::all_columns) .select(collections_groups::all_columns)
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading user collection groups") .expect("Error loading user collection groups")
}} })
.await
} }
pub async fn find_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
collections_groups::table collections_groups::table
.filter(collections_groups::collections_uuid.eq(collection_uuid)) .filter(collections_groups::collections_uuid.eq(collection_uuid))
.inner_join(collections::table.on( .inner_join(collections::table.on(collections::uuid.eq(collections_groups::collections_uuid)))
collections::uuid.eq(collections_groups::collections_uuid) .inner_join(
)) groups::table.on(groups::uuid
.inner_join(groups::table.on(groups::uuid.eq(collections_groups::groups_uuid) .eq(collections_groups::groups_uuid)
.and(groups::organizations_uuid.eq(collections::org_uuid)) .and(groups::organizations_uuid.eq(collections::org_uuid))),
)) )
.select(collections_groups::all_columns) .select(collections_groups::all_columns)
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading collection groups") .expect("Error loading collection groups")
}} })
.await
} }
pub async fn delete(&self, org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult { pub async fn delete(&self, org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
@ -428,13 +443,14 @@ impl CollectionGroup {
group_user.update_user_revision(conn).await; group_user.update_user_revision(conn).await;
} }
db_run! { conn: { conn.run(move |conn| {
diesel::delete(collections_groups::table) diesel::delete(collections_groups::table)
.filter(collections_groups::collections_uuid.eq(&self.collections_uuid)) .filter(collections_groups::collections_uuid.eq(&self.collections_uuid))
.filter(collections_groups::groups_uuid.eq(&self.groups_uuid)) .filter(collections_groups::groups_uuid.eq(&self.groups_uuid))
.execute(conn) .execute(conn)
.map_res("Error deleting collection group") .map_res("Error deleting collection group")
}} })
.await
} }
pub async fn delete_all_by_group(group_uuid: &GroupId, org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_group(group_uuid: &GroupId, org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
@ -443,12 +459,13 @@ impl CollectionGroup {
group_user.update_user_revision(conn).await; group_user.update_user_revision(conn).await;
} }
db_run! { conn: { conn.run(move |conn| {
diesel::delete(collections_groups::table) diesel::delete(collections_groups::table)
.filter(collections_groups::groups_uuid.eq(group_uuid)) .filter(collections_groups::groups_uuid.eq(group_uuid))
.execute(conn) .execute(conn)
.map_res("Error deleting collection group") .map_res("Error deleting collection group")
}} })
.await
} }
pub async fn delete_all_by_collection( pub async fn delete_all_by_collection(
@ -464,12 +481,13 @@ impl CollectionGroup {
} }
} }
db_run! { conn: { conn.run(move |conn| {
diesel::delete(collections_groups::table) diesel::delete(collections_groups::table)
.filter(collections_groups::collections_uuid.eq(collection_uuid)) .filter(collections_groups::collections_uuid.eq(collection_uuid))
.execute(conn) .execute(conn)
.map_res("Error deleting collection group") .map_res("Error deleting collection group")
}} })
.await
} }
} }
@ -521,30 +539,31 @@ impl GroupUser {
} }
pub async fn find_by_group(group_uuid: &GroupId, org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_group(group_uuid: &GroupId, org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
groups_users::table groups_users::table
.inner_join(groups::table.on( .inner_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)))
groups::uuid.eq(groups_users::groups_uuid) .inner_join(
)) users_organizations::table.on(users_organizations::uuid
.inner_join(users_organizations::table.on( .eq(groups_users::users_organizations_uuid)
users_organizations::uuid.eq(groups_users::users_organizations_uuid) .and(users_organizations::org_uuid.eq(groups::organizations_uuid))),
.and(users_organizations::org_uuid.eq(groups::organizations_uuid)) )
))
.filter(groups_users::groups_uuid.eq(group_uuid)) .filter(groups_users::groups_uuid.eq(group_uuid))
.filter(groups::organizations_uuid.eq(org_uuid)) .filter(groups::organizations_uuid.eq(org_uuid))
.select(groups_users::all_columns) .select(groups_users::all_columns)
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading group users") .expect("Error loading group users")
}} })
.await
} }
pub async fn find_by_member(member_uuid: &MembershipId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_member(member_uuid: &MembershipId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
groups_users::table groups_users::table
.filter(groups_users::users_organizations_uuid.eq(member_uuid)) .filter(groups_users::users_organizations_uuid.eq(member_uuid))
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading groups for user") .expect("Error loading groups for user")
}} })
.await
} }
pub async fn has_access_to_collection_by_member( pub async fn has_access_to_collection_by_member(
@ -552,24 +571,23 @@ impl GroupUser {
member_uuid: &MembershipId, member_uuid: &MembershipId,
conn: &DbConn, conn: &DbConn,
) -> bool { ) -> bool {
db_run! { conn: { conn.run(move |conn| {
groups_users::table groups_users::table
.inner_join(collections_groups::table.on( .inner_join(collections_groups::table.on(collections_groups::groups_uuid.eq(groups_users::groups_uuid)))
collections_groups::groups_uuid.eq(groups_users::groups_uuid) .inner_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)))
)) .inner_join(
.inner_join(groups::table.on( collections::table.on(collections::uuid
groups::uuid.eq(groups_users::groups_uuid) .eq(collections_groups::collections_uuid)
)) .and(collections::org_uuid.eq(groups::organizations_uuid))),
.inner_join(collections::table.on( )
collections::uuid.eq(collections_groups::collections_uuid)
.and(collections::org_uuid.eq(groups::organizations_uuid))
))
.filter(collections_groups::collections_uuid.eq(collection_uuid)) .filter(collections_groups::collections_uuid.eq(collection_uuid))
.filter(groups_users::users_organizations_uuid.eq(member_uuid)) .filter(groups_users::users_organizations_uuid.eq(member_uuid))
.count() .count()
.first::<i64>(conn) .first::<i64>(conn)
.unwrap_or(0) != 0 .unwrap_or(0)
}} != 0
})
.await
} }
pub async fn has_full_access_by_member( pub async fn has_full_access_by_member(
@ -577,18 +595,18 @@ impl GroupUser {
member_uuid: &MembershipId, member_uuid: &MembershipId,
conn: &DbConn, conn: &DbConn,
) -> bool { ) -> bool {
db_run! { conn: { conn.run(move |conn| {
groups_users::table groups_users::table
.inner_join(groups::table.on( .inner_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)))
groups::uuid.eq(groups_users::groups_uuid)
))
.filter(groups::organizations_uuid.eq(org_uuid)) .filter(groups::organizations_uuid.eq(org_uuid))
.filter(groups::access_all.eq(true)) .filter(groups::access_all.eq(true))
.filter(groups_users::users_organizations_uuid.eq(member_uuid)) .filter(groups_users::users_organizations_uuid.eq(member_uuid))
.count() .count()
.first::<i64>(conn) .first::<i64>(conn)
.unwrap_or(0) != 0 .unwrap_or(0)
}} != 0
})
.await
} }
pub async fn update_user_revision(&self, conn: &DbConn) { pub async fn update_user_revision(&self, conn: &DbConn) {
@ -606,15 +624,16 @@ impl GroupUser {
match Membership::find_by_uuid(member_uuid, conn).await { match Membership::find_by_uuid(member_uuid, conn).await {
Some(member) => User::update_uuid_revision(&member.user_uuid, conn).await, Some(member) => User::update_uuid_revision(&member.user_uuid, conn).await,
None => warn!("Member could not be found!"), None => warn!("Member could not be found!"),
}; }
db_run! { conn: { conn.run(move |conn| {
diesel::delete(groups_users::table) diesel::delete(groups_users::table)
.filter(groups_users::groups_uuid.eq(group_uuid)) .filter(groups_users::groups_uuid.eq(group_uuid))
.filter(groups_users::users_organizations_uuid.eq(member_uuid)) .filter(groups_users::users_organizations_uuid.eq(member_uuid))
.execute(conn) .execute(conn)
.map_res("Error deleting group users") .map_res("Error deleting group users")
}} })
.await
} }
pub async fn delete_all_by_group(group_uuid: &GroupId, org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_group(group_uuid: &GroupId, org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
@ -623,12 +642,13 @@ impl GroupUser {
group_user.update_user_revision(conn).await; group_user.update_user_revision(conn).await;
} }
db_run! { conn: { conn.run(move |conn| {
diesel::delete(groups_users::table) diesel::delete(groups_users::table)
.filter(groups_users::groups_uuid.eq(group_uuid)) .filter(groups_users::groups_uuid.eq(group_uuid))
.execute(conn) .execute(conn)
.map_res("Error deleting group users") .map_res("Error deleting group users")
}} })
.await
} }
pub async fn delete_all_by_member(member_uuid: &MembershipId, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_member(member_uuid: &MembershipId, conn: &DbConn) -> EmptyResult {
@ -637,12 +657,13 @@ impl GroupUser {
None => warn!("Member could not be found!"), None => warn!("Member could not be found!"),
} }
db_run! { conn: { conn.run(move |conn| {
diesel::delete(groups_users::table) diesel::delete(groups_users::table)
.filter(groups_users::users_organizations_uuid.eq(member_uuid)) .filter(groups_users::users_organizations_uuid.eq(member_uuid))
.execute(conn) .execute(conn)
.map_res("Error deleting user groups") .map_res("Error deleting user groups")
}} })
.await
} }
} }

4
src/db/models/mod.rs

@ -23,7 +23,7 @@ pub use self::attachment::{Attachment, AttachmentId};
pub use self::auth_request::{AuthRequest, AuthRequestId}; pub use self::auth_request::{AuthRequest, AuthRequestId};
pub use self::cipher::{Cipher, CipherId, RepromptType}; pub use self::cipher::{Cipher, CipherId, RepromptType};
pub use self::collection::{Collection, CollectionCipher, CollectionId, CollectionUser}; pub use self::collection::{Collection, CollectionCipher, CollectionId, CollectionUser};
pub use self::device::{Device, DeviceId, DeviceType, PushId}; pub use self::device::{Device, DeviceId, DeviceType, DeviceWithAuthRequest, PushId};
pub use self::emergency_access::{EmergencyAccess, EmergencyAccessId, EmergencyAccessStatus, EmergencyAccessType}; pub use self::emergency_access::{EmergencyAccess, EmergencyAccessId, EmergencyAccessStatus, EmergencyAccessType};
pub use self::event::{Event, EventType}; pub use self::event::{Event, EventType};
pub use self::favorite::Favorite; pub use self::favorite::Favorite;
@ -35,8 +35,8 @@ pub use self::organization::{
OrganizationId, OrganizationId,
}; };
pub use self::send::{ pub use self::send::{
id::{SendFileId, SendId},
Send, SendType, Send, SendType,
id::{SendFileId, SendId},
}; };
pub use self::sso_auth::{OIDCAuthenticatedUser, OIDCCodeResponseError, SsoAuth}; pub use self::sso_auth::{OIDCAuthenticatedUser, OIDCCodeResponseError, SsoAuth};
pub use self::two_factor::{TwoFactor, TwoFactorType}; pub use self::two_factor::{TwoFactor, TwoFactorType};

118
src/db/models/org_policy.rs

@ -1,14 +1,17 @@
use derive_more::{AsRef, From}; use derive_more::{AsRef, From};
use diesel::prelude::*;
use serde::Deserialize; use serde::Deserialize;
use serde_json::Value; use serde_json::Value;
use crate::api::core::two_factor; use crate::{
use crate::api::EmptyResult; CONFIG,
use crate::db::schema::{org_policies, users_organizations}; api::{EmptyResult, core::two_factor},
use crate::db::DbConn; db::{
use crate::error::MapResult; DbConn,
use crate::CONFIG; schema::{org_policies, users_organizations},
use diesel::prelude::*; },
error::MapResult,
};
use super::{Membership, MembershipId, MembershipStatus, MembershipType, OrganizationId, TwoFactor, UserId}; use super::{Membership, MembershipId, MembershipStatus, MembershipType, OrganizationId, TwoFactor, UserId};
@ -148,37 +151,38 @@ impl OrgPolicy {
} }
pub async fn delete(self, conn: &DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
diesel::delete(org_policies::table.filter(org_policies::uuid.eq(self.uuid))) diesel::delete(org_policies::table.filter(org_policies::uuid.eq(self.uuid)))
.execute(conn) .execute(conn)
.map_res("Error deleting org_policy") .map_res("Error deleting org_policy")
}} })
.await
} }
pub async fn find_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
org_policies::table org_policies::table
.filter(org_policies::org_uuid.eq(org_uuid)) .filter(org_policies::org_uuid.eq(org_uuid))
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading org_policy") .expect("Error loading org_policy")
}} })
.await
} }
pub async fn find_confirmed_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> { pub async fn find_confirmed_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
org_policies::table org_policies::table
.inner_join( .inner_join(
users_organizations::table.on( users_organizations::table.on(users_organizations::org_uuid
users_organizations::org_uuid.eq(org_policies::org_uuid) .eq(org_policies::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid))) .and(users_organizations::user_uuid.eq(user_uuid))),
)
.filter(
users_organizations::status.eq(MembershipStatus::Confirmed as i32)
) )
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.select(org_policies::all_columns) .select(org_policies::all_columns)
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading org_policy") .expect("Error loading org_policy")
}} })
.await
} }
pub async fn find_by_org_and_type( pub async fn find_by_org_and_type(
@ -186,21 +190,23 @@ impl OrgPolicy {
policy_type: OrgPolicyType, policy_type: OrgPolicyType,
conn: &DbConn, conn: &DbConn,
) -> Option<Self> { ) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
org_policies::table org_policies::table
.filter(org_policies::org_uuid.eq(org_uuid)) .filter(org_policies::org_uuid.eq(org_uuid))
.filter(org_policies::atype.eq(policy_type as i32)) .filter(org_policies::atype.eq(policy_type as i32))
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
diesel::delete(org_policies::table.filter(org_policies::org_uuid.eq(org_uuid))) diesel::delete(org_policies::table.filter(org_policies::org_uuid.eq(org_uuid)))
.execute(conn) .execute(conn)
.map_res("Error deleting org_policy") .map_res("Error deleting org_policy")
}} })
.await
} }
pub async fn find_accepted_and_confirmed_by_user_and_active_policy( pub async fn find_accepted_and_confirmed_by_user_and_active_policy(
@ -208,25 +214,22 @@ impl OrgPolicy {
policy_type: OrgPolicyType, policy_type: OrgPolicyType,
conn: &DbConn, conn: &DbConn,
) -> Vec<Self> { ) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
org_policies::table org_policies::table
.inner_join( .inner_join(
users_organizations::table.on( users_organizations::table.on(users_organizations::org_uuid
users_organizations::org_uuid.eq(org_policies::org_uuid) .eq(org_policies::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid))) .and(users_organizations::user_uuid.eq(user_uuid))),
)
.filter(
users_organizations::status.eq(MembershipStatus::Accepted as i32)
)
.or_filter(
users_organizations::status.eq(MembershipStatus::Confirmed as i32)
) )
.filter(users_organizations::status.eq(MembershipStatus::Accepted as i32))
.or_filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.filter(org_policies::atype.eq(policy_type as i32)) .filter(org_policies::atype.eq(policy_type as i32))
.filter(org_policies::enabled.eq(true)) .filter(org_policies::enabled.eq(true))
.select(org_policies::all_columns) .select(org_policies::all_columns)
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading org_policy") .expect("Error loading org_policy")
}} })
.await
} }
pub async fn find_confirmed_by_user_and_active_policy( pub async fn find_confirmed_by_user_and_active_policy(
@ -234,22 +237,21 @@ impl OrgPolicy {
policy_type: OrgPolicyType, policy_type: OrgPolicyType,
conn: &DbConn, conn: &DbConn,
) -> Vec<Self> { ) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
org_policies::table org_policies::table
.inner_join( .inner_join(
users_organizations::table.on( users_organizations::table.on(users_organizations::org_uuid
users_organizations::org_uuid.eq(org_policies::org_uuid) .eq(org_policies::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid))) .and(users_organizations::user_uuid.eq(user_uuid))),
)
.filter(
users_organizations::status.eq(MembershipStatus::Confirmed as i32)
) )
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.filter(org_policies::atype.eq(policy_type as i32)) .filter(org_policies::atype.eq(policy_type as i32))
.filter(org_policies::enabled.eq(true)) .filter(org_policies::enabled.eq(true))
.select(org_policies::all_columns) .select(org_policies::all_columns)
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading org_policy") .expect("Error loading org_policy")
}} })
.await
} }
/// Returns true if the user belongs to an org that has enabled the specified policy type, /// Returns true if the user belongs to an org that has enabled the specified policy type,
@ -269,12 +271,12 @@ impl OrgPolicy {
continue; continue;
} }
if let Some(user) = Membership::find_confirmed_by_user_and_org(user_uuid, &policy.org_uuid, conn).await { if let Some(user) = Membership::find_confirmed_by_user_and_org(user_uuid, &policy.org_uuid, conn).await
if user.atype < MembershipType::Admin { && user.atype < MembershipType::Admin
{
return true; return true;
} }
} }
}
false false
} }
@ -282,15 +284,15 @@ impl OrgPolicy {
if m.atype < MembershipType::Admin && m.status > (MembershipStatus::Invited as i32) { if m.atype < MembershipType::Admin && m.status > (MembershipStatus::Invited as i32) {
// Enforce TwoFactor/TwoStep login // Enforce TwoFactor/TwoStep login
if let Some(p) = Self::find_by_org_and_type(&m.org_uuid, OrgPolicyType::TwoFactorAuthentication, conn).await if let Some(p) = Self::find_by_org_and_type(&m.org_uuid, OrgPolicyType::TwoFactorAuthentication, conn).await
&& p.enabled
&& TwoFactor::find_by_user(&m.user_uuid, conn).await.is_empty()
{ {
if p.enabled && TwoFactor::find_by_user(&m.user_uuid, conn).await.is_empty() {
if CONFIG.email_2fa_auto_fallback() { if CONFIG.email_2fa_auto_fallback() {
two_factor::email::find_and_activate_email_2fa(&m.user_uuid, conn).await?; two_factor::email::find_and_activate_email_2fa(&m.user_uuid, conn).await?;
} else { } else {
err!(format!("Cannot {} because 2FA is required (membership {})", action, m.uuid)); err!(format!("Cannot {} because 2FA is required (membership {})", action, m.uuid));
} }
} }
}
// Check if the user is part of another Organization with SingleOrg activated // Check if the user is part of another Organization with SingleOrg activated
if Self::is_applicable_to_user(&m.user_uuid, OrgPolicyType::SingleOrg, Some(&m.org_uuid), conn).await { if Self::is_applicable_to_user(&m.user_uuid, OrgPolicyType::SingleOrg, Some(&m.org_uuid), conn).await {
@ -300,12 +302,14 @@ impl OrgPolicy {
)); ));
} }
if let Some(p) = Self::find_by_org_and_type(&m.org_uuid, OrgPolicyType::SingleOrg, conn).await { if let Some(p) = Self::find_by_org_and_type(&m.org_uuid, OrgPolicyType::SingleOrg, conn).await
if p.enabled && p.enabled
&& Membership::count_accepted_and_confirmed_by_user(&m.user_uuid, &m.org_uuid, conn).await > 0 && Membership::count_accepted_and_confirmed_by_user(&m.user_uuid, &m.org_uuid, conn).await > 0
{ {
err!(format!("Cannot {} because the organization policy forbids being part of other organization (membership {})", action, m.uuid)); err!(format!(
} "Cannot {} because the organization policy forbids being part of other organization (membership {})",
action, m.uuid
));
} }
} }
@ -332,8 +336,9 @@ impl OrgPolicy {
for policy in for policy in
OrgPolicy::find_confirmed_by_user_and_active_policy(user_uuid, OrgPolicyType::SendOptions, conn).await OrgPolicy::find_confirmed_by_user_and_active_policy(user_uuid, OrgPolicyType::SendOptions, conn).await
{ {
if let Some(user) = Membership::find_confirmed_by_user_and_org(user_uuid, &policy.org_uuid, conn).await { if let Some(user) = Membership::find_confirmed_by_user_and_org(user_uuid, &policy.org_uuid, conn).await
if user.atype < MembershipType::Admin { && user.atype < MembershipType::Admin
{
match serde_json::from_str::<SendOptionsPolicyData>(&policy.data) { match serde_json::from_str::<SendOptionsPolicyData>(&policy.data) {
Ok(opts) => { Ok(opts) => {
if opts.disable_hide_email { if opts.disable_hide_email {
@ -344,16 +349,15 @@ impl OrgPolicy {
} }
} }
} }
}
false false
} }
pub async fn is_enabled_for_member(member_uuid: &MembershipId, policy_type: OrgPolicyType, conn: &DbConn) -> bool { pub async fn is_enabled_for_member(member_uuid: &MembershipId, policy_type: OrgPolicyType, conn: &DbConn) -> bool {
if let Some(member) = Membership::find_by_uuid(member_uuid, conn).await { if let Some(member) = Membership::find_by_uuid(member_uuid, conn).await
if let Some(policy) = OrgPolicy::find_by_org_and_type(&member.org_uuid, policy_type, conn).await { && let Some(policy) = OrgPolicy::find_by_org_and_type(&member.org_uuid, policy_type, conn).await
{
return policy.enabled; return policy.enabled;
} }
}
false false
} }
} }

335
src/db/models/organization.rs

@ -1,24 +1,33 @@
use std::{
cmp::Ordering,
collections::{HashMap, HashSet},
};
use chrono::{NaiveDateTime, Utc}; use chrono::{NaiveDateTime, Utc};
use derive_more::{AsRef, Deref, Display, From}; use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*; use diesel::prelude::*;
use num_traits::FromPrimitive; use num_traits::FromPrimitive;
use serde_json::Value; use serde_json::Value;
use std::{
cmp::Ordering,
collections::{HashMap, HashSet},
};
use super::{ use crate::{
CipherId, Collection, CollectionGroup, CollectionId, CollectionUser, Group, GroupId, GroupUser, OrgPolicy, CONFIG,
OrgPolicyType, TwoFactor, User, UserId, api::EmptyResult,
}; db::{
use crate::db::schema::{ DbConn,
schema::{
ciphers, ciphers_collections, collections_groups, groups, groups_users, org_policies, organization_api_key, ciphers, ciphers_collections, collections_groups, groups, groups_users, org_policies, organization_api_key,
organizations, users, users_collections, users_organizations, organizations, users, users_collections, users_organizations,
},
},
error::MapResult,
}; };
use crate::CONFIG;
use macros::UuidFromParam; use macros::UuidFromParam;
use super::{
Cipher, CipherId, Collection, CollectionGroup, CollectionId, CollectionUser, Group, GroupId, GroupUser, OrgPolicy,
OrgPolicyType, TwoFactor, User, UserId,
};
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = organizations)] #[diesel(table_name = organizations)]
#[diesel(treat_none_as_null = true)] #[diesel(treat_none_as_null = true)]
@ -93,6 +102,10 @@ pub enum MembershipType {
impl MembershipType { impl MembershipType {
pub fn from_str(s: &str) -> Option<Self> { pub fn from_str(s: &str) -> Option<Self> {
#[expect(
clippy::match_same_arms,
reason = "Specifically define `4|Custom` since this is a hack, not a default"
)]
match s { match s {
"0" | "Owner" => Some(MembershipType::Owner), "0" | "Owner" => Some(MembershipType::Owner),
"1" | "Admin" => Some(MembershipType::Admin), "1" | "Admin" => Some(MembershipType::Admin),
@ -321,11 +334,6 @@ impl OrganizationApiKey {
} }
} }
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
/// Database methods /// Database methods
impl Organization { impl Organization {
pub async fn save(&self, conn: &DbConn) -> EmptyResult { pub async fn save(&self, conn: &DbConn) -> EmptyResult {
@ -333,7 +341,7 @@ impl Organization {
err!(format!("BillingEmail {} is not a valid email address", self.billing_email)) err!(format!("BillingEmail {} is not a valid email address", self.billing_email))
} }
for member in Membership::find_by_org(&self.uuid, conn).await.iter() { for member in &Membership::find_by_org(&self.uuid, conn).await {
User::update_uuid_revision(&member.user_uuid, conn).await; User::update_uuid_revision(&member.user_uuid, conn).await;
} }
@ -369,8 +377,6 @@ impl Organization {
} }
pub async fn delete(self, conn: &DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
use super::{Cipher, Collection};
Cipher::delete_all_by_organization(&self.uuid, conn).await?; Cipher::delete_all_by_organization(&self.uuid, conn).await?;
Collection::delete_all_by_organization(&self.uuid, conn).await?; Collection::delete_all_by_organization(&self.uuid, conn).await?;
Membership::delete_all_by_organization(&self.uuid, conn).await?; Membership::delete_all_by_organization(&self.uuid, conn).await?;
@ -378,43 +384,30 @@ impl Organization {
Group::delete_all_by_organization(&self.uuid, conn).await?; Group::delete_all_by_organization(&self.uuid, conn).await?;
OrganizationApiKey::delete_all_by_organization(&self.uuid, conn).await?; OrganizationApiKey::delete_all_by_organization(&self.uuid, conn).await?;
db_run! { conn: { conn.run(move |conn| {
diesel::delete(organizations::table.filter(organizations::uuid.eq(self.uuid))) diesel::delete(organizations::table.filter(organizations::uuid.eq(self.uuid)))
.execute(conn) .execute(conn)
.map_res("Error saving organization") .map_res("Error saving organization")
}} })
.await
} }
pub async fn find_by_uuid(uuid: &OrganizationId, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| organizations::table.filter(organizations::uuid.eq(uuid)).first::<Self>(conn).ok()).await
organizations::table
.filter(organizations::uuid.eq(uuid))
.first::<Self>(conn)
.ok()
}}
} }
pub async fn find_by_name(name: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_name(name: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| organizations::table.filter(organizations::name.eq(name)).first::<Self>(conn).ok()).await
organizations::table
.filter(organizations::name.eq(name))
.first::<Self>(conn)
.ok()
}}
} }
pub async fn get_all(conn: &DbConn) -> Vec<Self> { pub async fn get_all(conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| organizations::table.load::<Self>(conn).expect("Error loading organizations")).await
organizations::table
.load::<Self>(conn)
.expect("Error loading organizations")
}}
} }
pub async fn find_main_org_user_email(user_email: &str, conn: &DbConn) -> Option<Self> { pub async fn find_main_org_user_email(user_email: &str, conn: &DbConn) -> Option<Self> {
let lower_mail = user_email.to_lowercase(); let lower_mail = user_email.to_lowercase();
db_run! { conn: { conn.run(move |conn| {
organizations::table organizations::table
.inner_join(users_organizations::table.on(users_organizations::org_uuid.eq(organizations::uuid))) .inner_join(users_organizations::table.on(users_organizations::org_uuid.eq(organizations::uuid)))
.inner_join(users::table.on(users::uuid.eq(users_organizations::user_uuid))) .inner_join(users::table.on(users::uuid.eq(users_organizations::user_uuid)))
@ -424,13 +417,14 @@ impl Organization {
.select(organizations::all_columns) .select(organizations::all_columns)
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn find_org_user_email(user_email: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_org_user_email(user_email: &str, conn: &DbConn) -> Vec<Self> {
let lower_mail = user_email.to_lowercase(); let lower_mail = user_email.to_lowercase();
db_run! { conn: { conn.run(move |conn| {
organizations::table organizations::table
.inner_join(users_organizations::table.on(users_organizations::org_uuid.eq(organizations::uuid))) .inner_join(users_organizations::table.on(users_organizations::org_uuid.eq(organizations::uuid)))
.inner_join(users::table.on(users::uuid.eq(users_organizations::user_uuid))) .inner_join(users::table.on(users::uuid.eq(users_organizations::user_uuid)))
@ -440,7 +434,8 @@ impl Organization {
.select(organizations::all_columns) .select(organizations::all_columns)
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading user orgs") .expect("Error loading user orgs")
}} })
.await
} }
} }
@ -780,11 +775,12 @@ impl Membership {
CollectionUser::delete_all_by_user_and_org(&self.user_uuid, &self.org_uuid, conn).await?; CollectionUser::delete_all_by_user_and_org(&self.user_uuid, &self.org_uuid, conn).await?;
GroupUser::delete_all_by_member(&self.uuid, conn).await?; GroupUser::delete_all_by_member(&self.uuid, conn).await?;
db_run! { conn: { conn.run(move |conn| {
diesel::delete(users_organizations::table.filter(users_organizations::uuid.eq(self.uuid))) diesel::delete(users_organizations::table.filter(users_organizations::uuid.eq(self.uuid)))
.execute(conn) .execute(conn)
.map_res("Error removing user from organization") .map_res("Error removing user from organization")
}} })
.await
} }
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
@ -802,11 +798,11 @@ impl Membership {
} }
pub async fn find_by_email_and_org(email: &str, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Membership> { pub async fn find_by_email_and_org(email: &str, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Membership> {
if let Some(user) = User::find_by_mail(email, conn).await { if let Some(user) = User::find_by_mail(email, conn).await
if let Some(member) = Membership::find_by_user_and_org(&user.uuid, org_uuid, conn).await { && let Some(member) = Membership::find_by_user_and_org(&user.uuid, org_uuid, conn).await
{
return Some(member); return Some(member);
} }
}
None None
} }
@ -824,64 +820,67 @@ impl Membership {
} }
pub async fn find_by_uuid(uuid: &MembershipId, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &MembershipId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
users_organizations::table users_organizations::table.filter(users_organizations::uuid.eq(uuid)).first::<Self>(conn).ok()
.filter(users_organizations::uuid.eq(uuid)) })
.first::<Self>(conn) .await
.ok()
}}
} }
pub async fn find_by_uuid_and_org(uuid: &MembershipId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid_and_org(uuid: &MembershipId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
users_organizations::table users_organizations::table
.filter(users_organizations::uuid.eq(uuid)) .filter(users_organizations::uuid.eq(uuid))
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn find_confirmed_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> { pub async fn find_confirmed_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
users_organizations::table users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32)) .filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.load::<Self>(conn) .load::<Self>(conn)
.unwrap_or_default() .unwrap_or_default()
}} })
.await
} }
pub async fn find_invited_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> { pub async fn find_invited_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
users_organizations::table users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::status.eq(MembershipStatus::Invited as i32)) .filter(users_organizations::status.eq(MembershipStatus::Invited as i32))
.load::<Self>(conn) .load::<Self>(conn)
.unwrap_or_default() .unwrap_or_default()
}} })
.await
} }
// Should be used only when email are disabled. // Should be used only when email are disabled.
// In Organizations::send_invite status is set to Accepted only if the user has a password. // In Organizations::send_invite status is set to Accepted only if the user has a password.
pub async fn accept_user_invitations(user_uuid: &UserId, conn: &DbConn) -> EmptyResult { pub async fn accept_user_invitations(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
diesel::update(users_organizations::table) diesel::update(users_organizations::table)
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::status.eq(MembershipStatus::Invited as i32)) .filter(users_organizations::status.eq(MembershipStatus::Invited as i32))
.set(users_organizations::status.eq(MembershipStatus::Accepted as i32)) .set(users_organizations::status.eq(MembershipStatus::Accepted as i32))
.execute(conn) .execute(conn)
.map_res("Error confirming invitations") .map_res("Error confirming invitations")
}} })
.await
} }
pub async fn find_any_state_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> { pub async fn find_any_state_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
users_organizations::table users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
.load::<Self>(conn) .load::<Self>(conn)
.unwrap_or_default() .unwrap_or_default()
}} })
.await
} }
pub async fn count_accepted_and_confirmed_by_user( pub async fn count_accepted_and_confirmed_by_user(
@ -889,70 +888,83 @@ impl Membership {
excluded_org: &OrganizationId, excluded_org: &OrganizationId,
conn: &DbConn, conn: &DbConn,
) -> i64 { ) -> i64 {
db_run! { conn: { conn.run(move |conn| {
users_organizations::table users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::org_uuid.ne(excluded_org)) .filter(users_organizations::org_uuid.ne(excluded_org))
.filter(users_organizations::status.eq(MembershipStatus::Accepted as i32).or(users_organizations::status.eq(MembershipStatus::Confirmed as i32))) .filter(
users_organizations::status
.eq(MembershipStatus::Accepted as i32)
.or(users_organizations::status.eq(MembershipStatus::Confirmed as i32)),
)
.count() .count()
.first::<i64>(conn) .first::<i64>(conn)
.unwrap_or(0) .unwrap_or(0)
}} })
.await
} }
pub async fn find_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
users_organizations::table users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading user organizations") .expect("Error loading user organizations")
}} })
.await
} }
pub async fn find_confirmed_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> { pub async fn find_confirmed_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
users_organizations::table users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32)) .filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.load::<Self>(conn) .load::<Self>(conn)
.unwrap_or_default() .unwrap_or_default()
}} })
.await
} }
// Get all users which are either owner or admin, or a manager which can manage/access all // Get all users which are either owner or admin, or a manager which can manage/access all
pub async fn find_confirmed_and_manage_all_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> { pub async fn find_confirmed_and_manage_all_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
users_organizations::table users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32)) .filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.filter( .filter(
users_organizations::atype.eq_any(vec![MembershipType::Owner as i32, MembershipType::Admin as i32]) users_organizations::atype
.or(users_organizations::atype.eq(MembershipType::Manager as i32).and(users_organizations::access_all.eq(true))) .eq_any(vec![MembershipType::Owner as i32, MembershipType::Admin as i32])
.or(users_organizations::atype
.eq(MembershipType::Manager as i32)
.and(users_organizations::access_all.eq(true))),
) )
.load::<Self>(conn) .load::<Self>(conn)
.unwrap_or_default() .unwrap_or_default()
}} })
.await
} }
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 { pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: { conn.run(move |conn| {
users_organizations::table users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
.count() .count()
.first::<i64>(conn) .first::<i64>(conn)
.ok() .ok()
.unwrap_or(0) .unwrap_or(0)
}} })
.await
} }
pub async fn find_by_org_and_type(org_uuid: &OrganizationId, atype: MembershipType, conn: &DbConn) -> Vec<Self> { pub async fn find_by_org_and_type(org_uuid: &OrganizationId, atype: MembershipType, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
users_organizations::table users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
.filter(users_organizations::atype.eq(atype as i32)) .filter(users_organizations::atype.eq(atype as i32))
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading user organizations") .expect("Error loading user organizations")
}} })
.await
} }
pub async fn count_confirmed_by_org_and_type( pub async fn count_confirmed_by_org_and_type(
@ -960,7 +972,7 @@ impl Membership {
atype: MembershipType, atype: MembershipType,
conn: &DbConn, conn: &DbConn,
) -> i64 { ) -> i64 {
db_run! { conn: { conn.run(move |conn| {
users_organizations::table users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
.filter(users_organizations::atype.eq(atype as i32)) .filter(users_organizations::atype.eq(atype as i32))
@ -968,17 +980,19 @@ impl Membership {
.count() .count()
.first::<i64>(conn) .first::<i64>(conn)
.unwrap_or(0) .unwrap_or(0)
}} })
.await
} }
pub async fn find_by_user_and_org(user_uuid: &UserId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> { pub async fn find_by_user_and_org(user_uuid: &UserId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
users_organizations::table users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn find_confirmed_by_user_and_org( pub async fn find_confirmed_by_user_and_org(
@ -986,78 +1000,76 @@ impl Membership {
org_uuid: &OrganizationId, org_uuid: &OrganizationId,
conn: &DbConn, conn: &DbConn,
) -> Option<Self> { ) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
users_organizations::table users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
.filter( .filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
users_organizations::status.eq(MembershipStatus::Confirmed as i32)
)
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
users_organizations::table users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading user organizations") .expect("Error loading user organizations")
}} })
.await
} }
pub async fn get_orgs_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<OrganizationId> { pub async fn get_orgs_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<OrganizationId> {
db_run! { conn: { conn.run(move |conn| {
users_organizations::table users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
.select(users_organizations::org_uuid) .select(users_organizations::org_uuid)
.load::<OrganizationId>(conn) .load::<OrganizationId>(conn)
.unwrap_or_default() .unwrap_or_default()
}} })
.await
} }
pub async fn find_by_user_and_policy(user_uuid: &UserId, policy_type: OrgPolicyType, conn: &DbConn) -> Vec<Self> { pub async fn find_by_user_and_policy(user_uuid: &UserId, policy_type: OrgPolicyType, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
users_organizations::table users_organizations::table
.inner_join( .inner_join(
org_policies::table.on( org_policies::table.on(org_policies::org_uuid
org_policies::org_uuid.eq(users_organizations::org_uuid) .eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid)) .and(users_organizations::user_uuid.eq(user_uuid))
.and(org_policies::atype.eq(policy_type as i32)) .and(org_policies::atype.eq(policy_type as i32))
.and(org_policies::enabled.eq(true))) .and(org_policies::enabled.eq(true))),
)
.filter(
users_organizations::status.eq(MembershipStatus::Confirmed as i32)
) )
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.select(users_organizations::all_columns) .select(users_organizations::all_columns)
.load::<Self>(conn) .load::<Self>(conn)
.unwrap_or_default() .unwrap_or_default()
}} })
.await
} }
pub async fn find_by_cipher_and_org(cipher_uuid: &CipherId, org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_cipher_and_org(cipher_uuid: &CipherId, org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
users_organizations::table users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
.left_join(users_collections::table.on( .left_join(users_collections::table.on(users_collections::user_uuid.eq(users_organizations::user_uuid)))
users_collections::user_uuid.eq(users_organizations::user_uuid) .left_join(
)) ciphers_collections::table.on(ciphers_collections::collection_uuid
.left_join(ciphers_collections::table.on( .eq(users_collections::collection_uuid)
ciphers_collections::collection_uuid.eq(users_collections::collection_uuid).and( .and(ciphers_collections::cipher_uuid.eq(&cipher_uuid))),
ciphers_collections::cipher_uuid.eq(&cipher_uuid)
) )
.filter(users_organizations::access_all.eq(true).or(
// AccessAll..
ciphers_collections::cipher_uuid.eq(&cipher_uuid), // ..or access to collection with cipher
)) ))
.filter(
users_organizations::access_all.eq(true).or( // AccessAll..
ciphers_collections::cipher_uuid.eq(&cipher_uuid) // ..or access to collection with cipher
)
)
.select(users_organizations::all_columns) .select(users_organizations::all_columns)
.distinct() .distinct()
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading user organizations") .expect("Error loading user organizations")
}} })
.await
} }
pub async fn find_by_cipher_and_org_with_group( pub async fn find_by_cipher_and_org_with_group(
@ -1065,45 +1077,54 @@ impl Membership {
org_uuid: &OrganizationId, org_uuid: &OrganizationId,
conn: &DbConn, conn: &DbConn,
) -> Vec<Self> { ) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
users_organizations::table users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
.inner_join(groups_users::table.on( .inner_join(
groups_users::users_organizations_uuid.eq(users_organizations::uuid) groups_users::table.on(groups_users::users_organizations_uuid.eq(users_organizations::uuid)),
))
.left_join(collections_groups::table.on(
collections_groups::groups_uuid.eq(groups_users::groups_uuid)
))
.left_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))
))
.left_join(ciphers_collections::table.on(
ciphers_collections::collection_uuid.eq(collections_groups::collections_uuid).and(ciphers_collections::cipher_uuid.eq(&cipher_uuid))
))
.filter(
groups::access_all.eq(true).or( // AccessAll via groups
ciphers_collections::cipher_uuid.eq(&cipher_uuid) // ..or access to collection via group
) )
.left_join(collections_groups::table.on(collections_groups::groups_uuid.eq(groups_users::groups_uuid)))
.left_join(
groups::table.on(groups::uuid
.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
) )
.left_join(
ciphers_collections::table.on(ciphers_collections::collection_uuid
.eq(collections_groups::collections_uuid)
.and(ciphers_collections::cipher_uuid.eq(&cipher_uuid))),
)
.filter(groups::access_all.eq(true).or(
// AccessAll via groups
ciphers_collections::cipher_uuid.eq(&cipher_uuid), // ..or access to collection via group
))
.select(users_organizations::all_columns) .select(users_organizations::all_columns)
.distinct() .distinct()
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading user organizations with groups") .expect("Error loading user organizations with groups")
}} })
.await
} }
pub async fn user_has_ge_admin_access_to_cipher(user_uuid: &UserId, cipher_uuid: &CipherId, conn: &DbConn) -> bool { pub async fn user_has_ge_admin_access_to_cipher(user_uuid: &UserId, cipher_uuid: &CipherId, conn: &DbConn) -> bool {
db_run! { conn: { conn.run(move |conn| {
users_organizations::table users_organizations::table
.inner_join(ciphers::table.on(ciphers::uuid.eq(cipher_uuid).and(ciphers::organization_uuid.eq(users_organizations::org_uuid.nullable())))) .inner_join(
ciphers::table.on(ciphers::uuid
.eq(cipher_uuid)
.and(ciphers::organization_uuid.eq(users_organizations::org_uuid.nullable()))),
)
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::atype.eq_any(vec![MembershipType::Owner as i32, MembershipType::Admin as i32])) .filter(
users_organizations::atype.eq_any(vec![MembershipType::Owner as i32, MembershipType::Admin as i32]),
)
.count() .count()
.first::<i64>(conn) .first::<i64>(conn)
.ok() .ok()
.unwrap_or(0) != 0 .unwrap_or(0)
}} != 0
})
.await
} }
pub async fn find_by_collection_and_org( pub async fn find_by_collection_and_org(
@ -1111,44 +1132,41 @@ impl Membership {
org_uuid: &OrganizationId, org_uuid: &OrganizationId,
conn: &DbConn, conn: &DbConn,
) -> Vec<Self> { ) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
users_organizations::table users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
.left_join(users_collections::table.on( .left_join(users_collections::table.on(users_collections::user_uuid.eq(users_organizations::user_uuid)))
users_collections::user_uuid.eq(users_organizations::user_uuid) .filter(users_organizations::access_all.eq(true).or(
// AccessAll..
users_collections::collection_uuid.eq(&collection_uuid), // ..or access to collection with cipher
)) ))
.filter(
users_organizations::access_all.eq(true).or( // AccessAll..
users_collections::collection_uuid.eq(&collection_uuid) // ..or access to collection with cipher
)
)
.select(users_organizations::all_columns) .select(users_organizations::all_columns)
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading user organizations") .expect("Error loading user organizations")
}} })
.await
} }
pub async fn find_by_external_id_and_org(ext_id: &str, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> { pub async fn find_by_external_id_and_org(ext_id: &str, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
users_organizations::table users_organizations::table
.filter( .filter(users_organizations::external_id.eq(ext_id).and(users_organizations::org_uuid.eq(org_uuid)))
users_organizations::external_id.eq(ext_id)
.and(users_organizations::org_uuid.eq(org_uuid))
)
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn find_main_user_org(user_uuid: &str, conn: &DbConn) -> Option<Self> { pub async fn find_main_user_org(user_uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
users_organizations::table users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::status.ne(MembershipStatus::Revoked as i32)) .filter(users_organizations::status.ne(MembershipStatus::Revoked as i32))
.order(users_organizations::atype.asc()) .order(users_organizations::atype.asc())
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
} }
@ -1186,20 +1204,19 @@ impl OrganizationApiKey {
} }
pub async fn find_by_org_uuid(org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> { pub async fn find_by_org_uuid(org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
organization_api_key::table organization_api_key::table.filter(organization_api_key::org_uuid.eq(org_uuid)).first::<Self>(conn).ok()
.filter(organization_api_key::org_uuid.eq(org_uuid)) })
.first::<Self>(conn) .await
.ok()
}}
} }
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
diesel::delete(organization_api_key::table.filter(organization_api_key::org_uuid.eq(org_uuid))) diesel::delete(organization_api_key::table.filter(organization_api_key::org_uuid.eq(org_uuid)))
.execute(conn) .execute(conn)
.map_res("Error removing organization api key from organization") .map_res("Error removing organization api key from organization")
}} })
.await
} }
} }

121
src/db/models/send.rs

@ -1,11 +1,19 @@
use chrono::{NaiveDateTime, Utc}; use chrono::{NaiveDateTime, Utc};
use data_encoding::BASE64URL_NOPAD;
use diesel::prelude::*;
use serde_json::Value; use serde_json::Value;
use uuid::Uuid;
use crate::{config::PathType, util::LowerCase, CONFIG}; use crate::{
CONFIG,
api::EmptyResult,
config::PathType,
db::{DbConn, schema::sends},
error::MapResult,
util::{LowerCase, NumberOrString, format_date},
};
use super::{OrganizationId, User, UserId}; use super::{OrganizationId, User, UserId};
use crate::db::schema::sends;
use diesel::prelude::*;
use id::SendId; use id::SendId;
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
@ -107,37 +115,33 @@ impl Send {
pub fn check_password(&self, password: &str) -> bool { pub fn check_password(&self, password: &str) -> bool {
match (&self.password_hash, &self.password_salt, self.password_iter) { match (&self.password_hash, &self.password_salt, self.password_iter) {
(Some(hash), Some(salt), Some(iter)) => { (Some(hash), Some(salt), Some(iter)) => {
crate::crypto::verify_password_hash(password.as_bytes(), salt, hash, iter as u32) crate::crypto::verify_password_hash(password.as_bytes(), salt, hash, iter.cast_unsigned())
} }
_ => false, _ => false,
} }
} }
pub async fn creator_identifier(&self, conn: &DbConn) -> Option<String> { pub async fn creator_identifier(&self, conn: &DbConn) -> Option<String> {
if let Some(hide_email) = self.hide_email { if let Some(hide_email) = self.hide_email
if hide_email { && hide_email
{
return None; return None;
} }
}
if let Some(user_uuid) = &self.user_uuid { if let Some(user_uuid) = &self.user_uuid
if let Some(user) = User::find_by_uuid(user_uuid, conn).await { && let Some(user) = User::find_by_uuid(user_uuid, conn).await
{
return Some(user.email); return Some(user.email);
} }
}
None None
} }
pub fn to_json(&self) -> Value { pub fn to_json(&self) -> Value {
use crate::util::format_date;
use data_encoding::BASE64URL_NOPAD;
use uuid::Uuid;
let mut data = serde_json::from_str::<LowerCase<Value>>(&self.data).map(|d| d.data).unwrap_or_default(); let mut data = serde_json::from_str::<LowerCase<Value>>(&self.data).map(|d| d.data).unwrap_or_default();
// Mobile clients expect size to be a string instead of a number // Mobile clients expect size to be a string instead of a number
if let Some(size) = data.get("size").and_then(|v| v.as_i64()) { if let Some(size) = data.get("size").and_then(Value::as_i64) {
data["size"] = Value::String(size.to_string()); data["size"] = Value::String(size.to_string());
} }
@ -167,12 +171,10 @@ impl Send {
} }
pub async fn to_json_access(&self, conn: &DbConn) -> Value { pub async fn to_json_access(&self, conn: &DbConn) -> Value {
use crate::util::format_date;
let mut data = serde_json::from_str::<LowerCase<Value>>(&self.data).map(|d| d.data).unwrap_or_default(); let mut data = serde_json::from_str::<LowerCase<Value>>(&self.data).map(|d| d.data).unwrap_or_default();
// Mobile clients expect size to be a string instead of a number // Mobile clients expect size to be a string instead of a number
if let Some(size) = data.get("size").and_then(|v| v.as_i64()) { if let Some(size) = data.get("size").and_then(Value::as_i64) {
data["size"] = Value::String(size.to_string()); data["size"] = Value::String(size.to_string());
} }
@ -191,12 +193,6 @@ impl Send {
} }
} }
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
use crate::util::NumberOrString;
impl Send { impl Send {
pub async fn save(&mut self, conn: &DbConn) -> EmptyResult { pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
self.update_users_revision(conn).await; self.update_users_revision(conn).await;
@ -240,11 +236,10 @@ impl Send {
operator.delete_with(&self.uuid).recursive(true).await.ok(); operator.delete_with(&self.uuid).recursive(true).await.ok();
} }
db_run! { conn: { conn.run(move |conn| {
diesel::delete(sends::table.filter(sends::uuid.eq(&self.uuid))) diesel::delete(sends::table.filter(sends::uuid.eq(&self.uuid))).execute(conn).map_res("Error deleting send")
.execute(conn) })
.map_res("Error deleting send") .await
}}
} }
/// Purge all sends that are past their deletion date. /// Purge all sends that are past their deletion date.
@ -256,15 +251,12 @@ impl Send {
pub async fn update_users_revision(&self, conn: &DbConn) -> Vec<UserId> { pub async fn update_users_revision(&self, conn: &DbConn) -> Vec<UserId> {
let mut user_uuids = Vec::new(); let mut user_uuids = Vec::new();
match &self.user_uuid { if let Some(user_uuid) = &self.user_uuid {
Some(user_uuid) => {
User::update_uuid_revision(user_uuid, conn).await; User::update_uuid_revision(user_uuid, conn).await;
user_uuids.push(user_uuid.clone()) user_uuids.push(user_uuid.clone());
} } else {
None => {
// Belongs to Organization, not implemented // Belongs to Organization, not implemented
} }
};
user_uuids user_uuids
} }
@ -276,9 +268,6 @@ impl Send {
} }
pub async fn find_by_access_id(access_id: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_access_id(access_id: &str, conn: &DbConn) -> Option<Self> {
use data_encoding::BASE64URL_NOPAD;
use uuid::Uuid;
let Ok(uuid_vec) = BASE64URL_NOPAD.decode(access_id.as_bytes()) else { let Ok(uuid_vec) = BASE64URL_NOPAD.decode(access_id.as_bytes()) else {
return None; return None;
}; };
@ -292,50 +281,38 @@ impl Send {
} }
pub async fn find_by_uuid(uuid: &SendId, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &SendId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| sends::table.filter(sends::uuid.eq(uuid)).first::<Self>(conn).ok()).await
sends::table
.filter(sends::uuid.eq(uuid))
.first::<Self>(conn)
.ok()
}}
} }
pub async fn find_by_uuid_and_user(uuid: &SendId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid_and_user(uuid: &SendId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
sends::table sends::table.filter(sends::uuid.eq(uuid)).filter(sends::user_uuid.eq(user_uuid)).first::<Self>(conn).ok()
.filter(sends::uuid.eq(uuid)) })
.filter(sends::user_uuid.eq(user_uuid)) .await
.first::<Self>(conn)
.ok()
}}
} }
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
sends::table sends::table.filter(sends::user_uuid.eq(user_uuid)).load::<Self>(conn).expect("Error loading sends")
.filter(sends::user_uuid.eq(user_uuid)) })
.load::<Self>(conn) .await
.expect("Error loading sends")
}}
} }
pub async fn size_by_user(user_uuid: &UserId, conn: &DbConn) -> Option<i64> { pub async fn size_by_user(user_uuid: &UserId, conn: &DbConn) -> Option<i64> {
let sends = Self::find_by_user(user_uuid, conn).await;
#[derive(serde::Deserialize)] #[derive(serde::Deserialize)]
struct FileData { struct FileData {
#[serde(rename = "size", alias = "Size")] #[serde(rename = "size", alias = "Size")]
size: NumberOrString, size: NumberOrString,
} }
let sends = Self::find_by_user(user_uuid, conn).await;
let mut total: i64 = 0; let mut total: i64 = 0;
for send in sends { for send in sends {
if send.atype == SendType::File as i32 { if send.atype == SendType::File as i32
if let Ok(size) = && let Ok(size) =
serde_json::from_str::<FileData>(&send.data).map_err(Into::into).and_then(|d| d.size.into_i64()) serde_json::from_str::<FileData>(&send.data).map_err(Into::into).and_then(|d| d.size.into_i64())
{ {
total = total.checked_add(size)?; total = total.checked_add(size)?;
};
} }
} }
@ -343,22 +320,18 @@ impl Send {
} }
pub async fn find_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
sends::table sends::table.filter(sends::organization_uuid.eq(org_uuid)).load::<Self>(conn).expect("Error loading sends")
.filter(sends::organization_uuid.eq(org_uuid)) })
.load::<Self>(conn) .await
.expect("Error loading sends")
}}
} }
pub async fn find_by_past_deletion_date(conn: &DbConn) -> Vec<Self> { pub async fn find_by_past_deletion_date(conn: &DbConn) -> Vec<Self> {
let now = Utc::now().naive_utc(); let now = Utc::now().naive_utc();
db_run! { conn: { conn.run(move |conn| {
sends::table sends::table.filter(sends::deletion_date.lt(now)).load::<Self>(conn).expect("Error loading sends")
.filter(sends::deletion_date.lt(now)) })
.load::<Self>(conn) .await
.expect("Error loading sends")
}}
} }
} }

40
src/db/models/sso_auth.rs

@ -1,17 +1,20 @@
use chrono::{NaiveDateTime, Utc};
use std::time::Duration; use std::time::Duration;
use crate::api::EmptyResult; use chrono::{NaiveDateTime, Utc};
use crate::db::schema::sso_auth; use diesel::{
use crate::db::{DbConn, DbPool}; deserialize::FromSql,
use crate::error::MapResult; expression::AsExpression,
use crate::sso::{OIDCCode, OIDCCodeChallenge, OIDCIdentifier, OIDCState, SSO_AUTH_EXPIRATION}; prelude::*,
serialize::{Output, ToSql},
sql_types::Text,
};
use diesel::deserialize::FromSql; use crate::{
use diesel::expression::AsExpression; api::EmptyResult,
use diesel::prelude::*; db::{DbConn, DbPool, schema::sso_auth},
use diesel::serialize::{Output, ToSql}; error::MapResult,
use diesel::sql_types::Text; sso::{OIDCCode, OIDCCodeChallenge, OIDCIdentifier, OIDCState, SSO_AUTH_EXPIRATION},
};
#[derive(AsExpression, Clone, Debug, Serialize, Deserialize, FromSqlRow)] #[derive(AsExpression, Clone, Debug, Serialize, Deserialize, FromSqlRow)]
#[diesel(sql_type = Text)] #[diesel(sql_type = Text)]
@ -106,13 +109,14 @@ impl SsoAuth {
pub async fn find(state: &OIDCState, conn: &DbConn) -> Option<Self> { pub async fn find(state: &OIDCState, conn: &DbConn) -> Option<Self> {
let oldest = Utc::now().naive_utc() - *SSO_AUTH_EXPIRATION; let oldest = Utc::now().naive_utc() - *SSO_AUTH_EXPIRATION;
db_run! { conn: { conn.run(move |conn| {
sso_auth::table sso_auth::table
.filter(sso_auth::state.eq(state)) .filter(sso_auth::state.eq(state))
.filter(sso_auth::created_at.ge(oldest)) .filter(sso_auth::created_at.ge(oldest))
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn find_by_code(code: &OIDCCode, conn: &DbConn) -> Option<Self> { pub async fn find_by_code(code: &OIDCCode, conn: &DbConn) -> Option<Self> {
@ -127,22 +131,24 @@ impl SsoAuth {
} }
pub async fn delete(self, conn: &DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! {conn: { conn.run(move |conn| {
diesel::delete(sso_auth::table.filter(sso_auth::state.eq(self.state))) diesel::delete(sso_auth::table.filter(sso_auth::state.eq(self.state)))
.execute(conn) .execute(conn)
.map_res("Error deleting sso_auth") .map_res("Error deleting sso_auth")
}} })
.await
} }
pub async fn delete_expired(pool: DbPool) -> EmptyResult { pub async fn delete_expired(pool: DbPool) -> EmptyResult {
debug!("Purging expired sso_auth"); debug!("Purging expired sso_auth");
if let Ok(conn) = pool.get().await { if let Ok(conn) = pool.get().await {
let oldest = Utc::now().naive_utc() - *SSO_AUTH_EXPIRATION; let oldest = Utc::now().naive_utc() - *SSO_AUTH_EXPIRATION;
db_run! { conn: { conn.run(move |conn| {
diesel::delete(sso_auth::table.filter(sso_auth::created_at.lt(oldest))) diesel::delete(sso_auth::table.filter(sso_auth::created_at.lt(oldest)))
.execute(conn) .execute(conn)
.map_res("Error deleting expired SSO nonce") .map_res("Error deleting expired SSO nonce")
}} })
.await
} else { } else {
err!("Failed to get DB connection while purging expired sso_auth") err!("Failed to get DB connection while purging expired sso_auth")
} }

55
src/db/models/two_factor.rs

@ -1,13 +1,17 @@
use super::UserId;
use crate::api::core::two_factor::webauthn::WebauthnRegistration;
use crate::db::schema::twofactor;
use crate::{api::EmptyResult, db::DbConn, error::MapResult};
use diesel::prelude::*; use diesel::prelude::*;
use serde_json::Value; use serde_json::Value;
use webauthn_rs::prelude::{Credential, ParsedAttestation}; use webauthn_rs::prelude::{Credential, ParsedAttestation};
use webauthn_rs_core::proto::CredentialV3; use webauthn_rs_core::proto::CredentialV3;
use webauthn_rs_proto::{AttestationFormat, RegisteredExtensions}; use webauthn_rs_proto::{AttestationFormat, RegisteredExtensions};
use crate::{
api::{EmptyResult, core::two_factor::webauthn::WebauthnRegistration},
db::{DbConn, schema::twofactor},
error::MapResult,
};
use super::UserId;
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = twofactor)] #[diesel(table_name = twofactor)]
#[diesel(primary_key(uuid))] #[diesel(primary_key(uuid))]
@ -114,53 +118,58 @@ impl TwoFactor {
} }
pub async fn delete(self, conn: &DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
diesel::delete(twofactor::table.filter(twofactor::uuid.eq(self.uuid))) diesel::delete(twofactor::table.filter(twofactor::uuid.eq(self.uuid)))
.execute(conn) .execute(conn)
.map_res("Error deleting twofactor") .map_res("Error deleting twofactor")
}} })
.await
} }
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> { pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
twofactor::table twofactor::table
.filter(twofactor::user_uuid.eq(user_uuid)) .filter(twofactor::user_uuid.eq(user_uuid))
.filter(twofactor::atype.lt(1000)) // Filter implementation types .filter(twofactor::atype.lt(1000)) // Filter implementation types
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading twofactor") .expect("Error loading twofactor")
}} })
.await
} }
pub async fn find_by_user_and_type(user_uuid: &UserId, atype: i32, conn: &DbConn) -> Option<Self> { pub async fn find_by_user_and_type(user_uuid: &UserId, atype: i32, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
twofactor::table twofactor::table
.filter(twofactor::user_uuid.eq(user_uuid)) .filter(twofactor::user_uuid.eq(user_uuid))
.filter(twofactor::atype.eq(atype)) .filter(twofactor::atype.eq(atype))
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
diesel::delete(twofactor::table.filter(twofactor::user_uuid.eq(user_uuid))) diesel::delete(twofactor::table.filter(twofactor::user_uuid.eq(user_uuid)))
.execute(conn) .execute(conn)
.map_res("Error deleting twofactors") .map_res("Error deleting twofactors")
}} })
.await
} }
pub async fn migrate_u2f_to_webauthn(conn: &DbConn) -> EmptyResult { pub async fn migrate_u2f_to_webauthn(conn: &DbConn) -> EmptyResult {
let u2f_factors = db_run! { conn: { use crate::api::core::two_factor::webauthn::{U2FRegistration, get_webauthn_registrations};
use webauthn_rs::prelude::{COSEEC2Key, COSEKey, COSEKeyType, ECDSACurve};
use webauthn_rs_proto::{COSEAlgorithm, UserVerificationPolicy};
let u2f_factors = conn
.run(move |conn| {
twofactor::table twofactor::table
.filter(twofactor::atype.eq(TwoFactorType::U2f as i32)) .filter(twofactor::atype.eq(TwoFactorType::U2f as i32))
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading twofactor") .expect("Error loading twofactor")
}}; })
.await;
use crate::api::core::two_factor::webauthn::U2FRegistration;
use crate::api::core::two_factor::webauthn::{get_webauthn_registrations, WebauthnRegistration};
use webauthn_rs::prelude::{COSEEC2Key, COSEKey, COSEKeyType, ECDSACurve};
use webauthn_rs_proto::{COSEAlgorithm, UserVerificationPolicy};
for mut u2f in u2f_factors { for mut u2f in u2f_factors {
let mut regs: Vec<U2FRegistration> = serde_json::from_str(&u2f.data)?; let mut regs: Vec<U2FRegistration> = serde_json::from_str(&u2f.data)?;
@ -227,12 +236,14 @@ impl TwoFactor {
} }
pub async fn migrate_credential_to_passkey(conn: &DbConn) -> EmptyResult { pub async fn migrate_credential_to_passkey(conn: &DbConn) -> EmptyResult {
let webauthn_factors = db_run! { conn: { let webauthn_factors = conn
.run(move |conn| {
twofactor::table twofactor::table
.filter(twofactor::atype.eq(TwoFactorType::Webauthn as i32)) .filter(twofactor::atype.eq(TwoFactorType::Webauthn as i32))
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading twofactor") .expect("Error loading twofactor")
}}; })
.await;
for webauthn_factor in webauthn_factors { for webauthn_factor in webauthn_factors {
// assume that a failure to parse into the old struct, means that it was already converted // assume that a failure to parse into the old struct, means that it was already converted
@ -241,7 +252,7 @@ impl TwoFactor {
continue; continue;
}; };
let regs = regs.into_iter().map(|r| r.into()).collect::<Vec<WebauthnRegistration>>(); let regs = regs.into_iter().map(Into::into).collect::<Vec<WebauthnRegistration>>();
TwoFactor::new(webauthn_factor.user_uuid.clone(), TwoFactorType::Webauthn, serde_json::to_string(&regs)?) TwoFactor::new(webauthn_factor.user_uuid.clone(), TwoFactorType::Webauthn, serde_json::to_string(&regs)?)
.save(conn) .save(conn)

42
src/db/models/two_factor_duo_context.rs

@ -1,9 +1,12 @@
use chrono::Utc; use chrono::Utc;
use crate::db::schema::twofactor_duo_ctx;
use crate::{api::EmptyResult, db::DbConn, error::MapResult};
use diesel::prelude::*; use diesel::prelude::*;
use crate::{
api::EmptyResult,
db::{DbConn, schema::twofactor_duo_ctx},
error::MapResult,
};
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = twofactor_duo_ctx)] #[diesel(table_name = twofactor_duo_ctx)]
#[diesel(primary_key(state))] #[diesel(primary_key(state))]
@ -16,12 +19,10 @@ pub struct TwoFactorDuoContext {
impl TwoFactorDuoContext { impl TwoFactorDuoContext {
pub async fn find_by_state(state: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_state(state: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
twofactor_duo_ctx::table twofactor_duo_ctx::table.filter(twofactor_duo_ctx::state.eq(state)).first::<Self>(conn).ok()
.filter(twofactor_duo_ctx::state.eq(state)) })
.first::<Self>(conn) .await
.ok()
}}
} }
pub async fn save(state: &str, user_email: &str, nonce: &str, ttl: i64, conn: &DbConn) -> EmptyResult { pub async fn save(state: &str, user_email: &str, nonce: &str, ttl: i64, conn: &DbConn) -> EmptyResult {
@ -29,41 +30,42 @@ impl TwoFactorDuoContext {
let exists = Self::find_by_state(state, conn).await; let exists = Self::find_by_state(state, conn).await;
if exists.is_some() { if exists.is_some() {
return Ok(()); return Ok(());
}; }
let exp = Utc::now().timestamp() + ttl; let exp = Utc::now().timestamp() + ttl;
db_run! { conn: { conn.run(move |conn| {
diesel::insert_into(twofactor_duo_ctx::table) diesel::insert_into(twofactor_duo_ctx::table)
.values(( .values((
twofactor_duo_ctx::state.eq(state), twofactor_duo_ctx::state.eq(state),
twofactor_duo_ctx::user_email.eq(user_email), twofactor_duo_ctx::user_email.eq(user_email),
twofactor_duo_ctx::nonce.eq(nonce), twofactor_duo_ctx::nonce.eq(nonce),
twofactor_duo_ctx::exp.eq(exp) twofactor_duo_ctx::exp.eq(exp),
)) ))
.execute(conn) .execute(conn)
.map_res("Error saving context to twofactor_duo_ctx") .map_res("Error saving context to twofactor_duo_ctx")
}} })
.await
} }
pub async fn find_expired(conn: &DbConn) -> Vec<Self> { pub async fn find_expired(conn: &DbConn) -> Vec<Self> {
let now = Utc::now().timestamp(); let now = Utc::now().timestamp();
db_run! { conn: { conn.run(move |conn| {
twofactor_duo_ctx::table twofactor_duo_ctx::table
.filter(twofactor_duo_ctx::exp.lt(now)) .filter(twofactor_duo_ctx::exp.lt(now))
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error finding expired contexts in twofactor_duo_ctx") .expect("Error finding expired contexts in twofactor_duo_ctx")
}} })
.await
} }
pub async fn delete(&self, conn: &DbConn) -> EmptyResult { pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
diesel::delete( diesel::delete(twofactor_duo_ctx::table.filter(twofactor_duo_ctx::state.eq(&self.state)))
twofactor_duo_ctx::table
.filter(twofactor_duo_ctx::state.eq(&self.state)))
.execute(conn) .execute(conn)
.map_res("Error deleting from twofactor_duo_ctx") .map_res("Error deleting from twofactor_duo_ctx")
}} })
.await
} }
pub async fn purge_expired_duo_contexts(conn: &DbConn) { pub async fn purge_expired_duo_contexts(conn: &DbConn) {

39
src/db/models/two_factor_incomplete.rs

@ -1,17 +1,17 @@
use chrono::{NaiveDateTime, Utc}; use chrono::{NaiveDateTime, Utc};
use diesel::prelude::*;
use crate::db::schema::twofactor_incomplete;
use crate::{ use crate::{
CONFIG,
api::EmptyResult, api::EmptyResult,
auth::ClientIp, auth::ClientIp,
db::{ db::{
models::{DeviceId, UserId},
DbConn, DbConn,
models::{DeviceId, UserId},
schema::twofactor_incomplete,
}, },
error::MapResult, error::MapResult,
CONFIG,
}; };
use diesel::prelude::*;
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = twofactor_incomplete)] #[diesel(table_name = twofactor_incomplete)]
@ -49,7 +49,7 @@ impl TwoFactorIncomplete {
return Ok(()); return Ok(());
} }
db_run! { conn: { conn.run(move |conn| {
diesel::insert_into(twofactor_incomplete::table) diesel::insert_into(twofactor_incomplete::table)
.values(( .values((
twofactor_incomplete::user_uuid.eq(user_uuid), twofactor_incomplete::user_uuid.eq(user_uuid),
@ -61,7 +61,8 @@ impl TwoFactorIncomplete {
)) ))
.execute(conn) .execute(conn)
.map_res("Error adding twofactor_incomplete record") .map_res("Error adding twofactor_incomplete record")
}} })
.await
} }
pub async fn mark_complete(user_uuid: &UserId, device_uuid: &DeviceId, conn: &DbConn) -> EmptyResult { pub async fn mark_complete(user_uuid: &UserId, device_uuid: &DeviceId, conn: &DbConn) -> EmptyResult {
@ -73,22 +74,24 @@ impl TwoFactorIncomplete {
} }
pub async fn find_by_user_and_device(user_uuid: &UserId, device_uuid: &DeviceId, conn: &DbConn) -> Option<Self> { pub async fn find_by_user_and_device(user_uuid: &UserId, device_uuid: &DeviceId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| {
twofactor_incomplete::table twofactor_incomplete::table
.filter(twofactor_incomplete::user_uuid.eq(user_uuid)) .filter(twofactor_incomplete::user_uuid.eq(user_uuid))
.filter(twofactor_incomplete::device_uuid.eq(device_uuid)) .filter(twofactor_incomplete::device_uuid.eq(device_uuid))
.first::<Self>(conn) .first::<Self>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn find_logins_before(dt: &NaiveDateTime, conn: &DbConn) -> Vec<Self> { pub async fn find_logins_before(dt: &NaiveDateTime, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { conn.run(move |conn| {
twofactor_incomplete::table twofactor_incomplete::table
.filter(twofactor_incomplete::login_time.lt(dt)) .filter(twofactor_incomplete::login_time.lt(dt))
.load::<Self>(conn) .load::<Self>(conn)
.expect("Error loading twofactor_incomplete") .expect("Error loading twofactor_incomplete")
}} })
.await
} }
pub async fn delete(self, conn: &DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
@ -96,20 +99,24 @@ impl TwoFactorIncomplete {
} }
pub async fn delete_by_user_and_device(user_uuid: &UserId, device_uuid: &DeviceId, conn: &DbConn) -> EmptyResult { pub async fn delete_by_user_and_device(user_uuid: &UserId, device_uuid: &DeviceId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
diesel::delete(twofactor_incomplete::table diesel::delete(
twofactor_incomplete::table
.filter(twofactor_incomplete::user_uuid.eq(user_uuid)) .filter(twofactor_incomplete::user_uuid.eq(user_uuid))
.filter(twofactor_incomplete::device_uuid.eq(device_uuid))) .filter(twofactor_incomplete::device_uuid.eq(device_uuid)),
)
.execute(conn) .execute(conn)
.map_res("Error in twofactor_incomplete::delete_by_user_and_device()") .map_res("Error in twofactor_incomplete::delete_by_user_and_device()")
}} })
.await
} }
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
diesel::delete(twofactor_incomplete::table.filter(twofactor_incomplete::user_uuid.eq(user_uuid))) diesel::delete(twofactor_incomplete::table.filter(twofactor_incomplete::user_uuid.eq(user_uuid)))
.execute(conn) .execute(conn)
.map_res("Error in twofactor_incomplete::delete_all_by_user()") .map_res("Error in twofactor_incomplete::delete_all_by_user()")
}} })
.await
} }
} }

123
src/db/models/user.rs

@ -1,23 +1,27 @@
use crate::db::schema::{invitations, sso_users, twofactor_incomplete, users};
use chrono::{NaiveDateTime, TimeDelta, Utc}; use chrono::{NaiveDateTime, TimeDelta, Utc};
use derive_more::{AsRef, Deref, Display, From}; use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*; use diesel::prelude::*;
use serde_json::Value; use serde_json::Value;
use super::{
Cipher, Device, EmergencyAccess, Favorite, Folder, Membership, MembershipType, TwoFactor, TwoFactorIncomplete,
};
use crate::{ use crate::{
CONFIG,
api::EmptyResult, api::EmptyResult,
crypto, crypto,
db::{models::DeviceId, DbConn}, db::{
DbConn,
models::DeviceId,
schema::{invitations, sso_users, twofactor_incomplete, users},
},
error::MapResult, error::MapResult,
sso::OIDCIdentifier, sso::OIDCIdentifier,
util::{format_date, get_uuid, retry}, util::{format_date, get_uuid, retry},
CONFIG,
}; };
use macros::UuidFromParam; use macros::UuidFromParam;
use super::{
Cipher, Device, EmergencyAccess, Favorite, Folder, Membership, MembershipType, TwoFactor, TwoFactorIncomplete,
};
#[derive(Identifiable, Queryable, Insertable, AsChangeset, Selectable)] #[derive(Identifiable, Queryable, Insertable, AsChangeset, Selectable)]
#[diesel(table_name = users)] #[diesel(table_name = users)]
#[diesel(treat_none_as_null = true)] #[diesel(treat_none_as_null = true)]
@ -137,8 +141,8 @@ impl User {
_totp_secret: None, _totp_secret: None,
totp_recover: None, totp_recover: None,
equivalent_domains: "[]".to_string(), equivalent_domains: "[]".to_owned(),
excluded_globals: "[]".to_string(), excluded_globals: "[]".to_owned(),
client_kdf_type: Self::CLIENT_KDF_TYPE_DEFAULT, client_kdf_type: Self::CLIENT_KDF_TYPE_DEFAULT,
client_kdf_iter: Self::CLIENT_KDF_ITER_DEFAULT, client_kdf_iter: Self::CLIENT_KDF_ITER_DEFAULT,
@ -158,7 +162,7 @@ impl User {
password.as_bytes(), password.as_bytes(),
&self.salt, &self.salt,
&self.password_hash, &self.password_hash,
self.password_iterations as u32, self.password_iterations.cast_unsigned(),
) )
} }
@ -193,7 +197,8 @@ impl User {
allow_next_route: Option<Vec<String>>, allow_next_route: Option<Vec<String>>,
conn: &DbConn, conn: &DbConn,
) -> EmptyResult { ) -> EmptyResult {
self.password_hash = crypto::hash_password(password.as_bytes(), &self.salt, self.password_iterations as u32); self.password_hash =
crypto::hash_password(password.as_bytes(), &self.salt, self.password_iterations.cast_unsigned());
if let Some(route) = allow_next_route { if let Some(route) = allow_next_route {
self.set_stamp_exception(route); self.set_stamp_exception(route);
@ -238,10 +243,10 @@ impl User {
pub fn display_name(&self) -> &str { pub fn display_name(&self) -> &str {
// default to email if name is empty // default to email if name is empty
if !&self.name.is_empty() { if self.name.is_empty() {
&self.name
} else {
&self.email &self.email
} else {
&self.name
} }
} }
} }
@ -337,15 +342,14 @@ impl User {
TwoFactorIncomplete::delete_all_by_user(&self.uuid, conn).await?; TwoFactorIncomplete::delete_all_by_user(&self.uuid, conn).await?;
Invitation::take(&self.email, conn).await; // Delete invitation if any Invitation::take(&self.email, conn).await; // Delete invitation if any
db_run! { conn: { conn.run(move |conn| {
diesel::delete(users::table.filter(users::uuid.eq(self.uuid))) diesel::delete(users::table.filter(users::uuid.eq(self.uuid))).execute(conn).map_res("Error deleting user")
.execute(conn) })
.map_res("Error deleting user") .await
}}
} }
pub async fn update_uuid_revision(uuid: &UserId, conn: &DbConn) { pub async fn update_uuid_revision(uuid: &UserId, conn: &DbConn) {
if let Err(e) = Self::_update_revision(uuid, &Utc::now().naive_utc(), conn).await { if let Err(e) = Self::update_revision_impl(uuid, &Utc::now().naive_utc(), conn).await {
warn!("Failed to update revision for {uuid}: {e:#?}"); warn!("Failed to update revision for {uuid}: {e:#?}");
} }
} }
@ -353,68 +357,62 @@ impl User {
pub async fn update_all_revisions(conn: &DbConn) -> EmptyResult { pub async fn update_all_revisions(conn: &DbConn) -> EmptyResult {
let updated_at = Utc::now().naive_utc(); let updated_at = Utc::now().naive_utc();
db_run! { conn: { conn.run(move |conn| {
retry(|| { retry(|| diesel::update(users::table).set(users::updated_at.eq(updated_at)).execute(conn), 10)
diesel::update(users::table)
.set(users::updated_at.eq(updated_at))
.execute(conn)
}, 10)
.map_res("Error updating revision date for all users") .map_res("Error updating revision date for all users")
}} })
.await
} }
pub async fn update_revision(&mut self, conn: &DbConn) -> EmptyResult { pub async fn update_revision(&mut self, conn: &DbConn) -> EmptyResult {
self.updated_at = Utc::now().naive_utc(); self.updated_at = Utc::now().naive_utc();
Self::_update_revision(&self.uuid, &self.updated_at, conn).await Self::update_revision_impl(&self.uuid, &self.updated_at, conn).await
} }
async fn _update_revision(uuid: &UserId, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult { async fn update_revision_impl(uuid: &UserId, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
retry(|| { retry(
|| {
diesel::update(users::table.filter(users::uuid.eq(uuid))) diesel::update(users::table.filter(users::uuid.eq(uuid)))
.set(users::updated_at.eq(date)) .set(users::updated_at.eq(date))
.execute(conn) .execute(conn)
}, 10) },
10,
)
.map_res("Error updating user revision") .map_res("Error updating user revision")
}} })
.await
} }
pub async fn find_by_mail(mail: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_mail(mail: &str, conn: &DbConn) -> Option<Self> {
let lower_mail = mail.to_lowercase(); let lower_mail = mail.to_lowercase();
db_run! { conn: { conn.run(move |conn| users::table.filter(users::email.eq(lower_mail)).first::<Self>(conn).ok()).await
users::table
.filter(users::email.eq(lower_mail))
.first::<Self>(conn)
.ok()
}}
} }
pub async fn find_by_uuid(uuid: &UserId, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { conn.run(move |conn| users::table.filter(users::uuid.eq(uuid)).first::<Self>(conn).ok()).await
users::table
.filter(users::uuid.eq(uuid))
.first::<Self>(conn)
.ok()
}}
} }
pub async fn find_by_device_for_email2fa(device_uuid: &DeviceId, conn: &DbConn) -> Option<Self> { pub async fn find_by_device_for_email2fa(device_uuid: &DeviceId, conn: &DbConn) -> Option<Self> {
if let Some(user_uuid) = db_run! ( conn: { if let Some(user_uuid) = conn
.run(move |conn| {
twofactor_incomplete::table twofactor_incomplete::table
.filter(twofactor_incomplete::device_uuid.eq(device_uuid)) .filter(twofactor_incomplete::device_uuid.eq(device_uuid))
.order_by(twofactor_incomplete::login_time.desc()) .order_by(twofactor_incomplete::login_time.desc())
.select(twofactor_incomplete::user_uuid) .select(twofactor_incomplete::user_uuid)
.first::<UserId>(conn) .first::<UserId>(conn)
.ok() .ok()
}) { })
.await
{
return Self::find_by_uuid(&user_uuid, conn).await; return Self::find_by_uuid(&user_uuid, conn).await;
} }
None None
} }
pub async fn get_all(conn: &DbConn) -> Vec<(Self, Option<SsoUser>)> { pub async fn get_all(conn: &DbConn) -> Vec<(Self, Option<SsoUser>)> {
db_run! { conn: { conn.run(move |conn| {
users::table users::table
.left_join(sso_users::table) .left_join(sso_users::table)
.select(<(Self, Option<SsoUser>)>::as_select()) .select(<(Self, Option<SsoUser>)>::as_select())
@ -422,7 +420,8 @@ impl User {
.expect("Error loading groups for user") .expect("Error loading groups for user")
.into_iter() .into_iter()
.collect() .collect()
}} })
.await
} }
pub async fn last_active(&self, conn: &DbConn) -> Option<NaiveDateTime> { pub async fn last_active(&self, conn: &DbConn) -> Option<NaiveDateTime> {
@ -467,21 +466,18 @@ impl Invitation {
} }
pub async fn delete(self, conn: &DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
diesel::delete(invitations::table.filter(invitations::email.eq(self.email))) diesel::delete(invitations::table.filter(invitations::email.eq(self.email)))
.execute(conn) .execute(conn)
.map_res("Error deleting invitation") .map_res("Error deleting invitation")
}} })
.await
} }
pub async fn find_by_mail(mail: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_mail(mail: &str, conn: &DbConn) -> Option<Self> {
let lower_mail = mail.to_lowercase(); let lower_mail = mail.to_lowercase();
db_run! { conn: { conn.run(move |conn| invitations::table.filter(invitations::email.eq(lower_mail)).first::<Self>(conn).ok())
invitations::table .await
.filter(invitations::email.eq(lower_mail))
.first::<Self>(conn)
.ok()
}}
} }
pub async fn take(mail: &str, conn: &DbConn) -> bool { pub async fn take(mail: &str, conn: &DbConn) -> bool {
@ -531,34 +527,37 @@ impl SsoUser {
} }
pub async fn find_by_identifier(identifier: &str, conn: &DbConn) -> Option<(User, Self)> { pub async fn find_by_identifier(identifier: &str, conn: &DbConn) -> Option<(User, Self)> {
db_run! { conn: { conn.run(move |conn| {
users::table users::table
.inner_join(sso_users::table) .inner_join(sso_users::table)
.select(<(User, Self)>::as_select()) .select(<(User, Self)>::as_select())
.filter(sso_users::identifier.eq(identifier)) .filter(sso_users::identifier.eq(identifier))
.first::<(User, Self)>(conn) .first::<(User, Self)>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn find_by_mail(mail: &str, conn: &DbConn) -> Option<(User, Option<Self>)> { pub async fn find_by_mail(mail: &str, conn: &DbConn) -> Option<(User, Option<Self>)> {
let lower_mail = mail.to_lowercase(); let lower_mail = mail.to_lowercase();
db_run! { conn: { conn.run(move |conn| {
users::table users::table
.left_join(sso_users::table) .left_join(sso_users::table)
.select(<(User, Option<Self>)>::as_select()) .select(<(User, Option<Self>)>::as_select())
.filter(users::email.eq(lower_mail)) .filter(users::email.eq(lower_mail))
.first::<(User, Option<Self>)>(conn) .first::<(User, Option<Self>)>(conn)
.ok() .ok()
}} })
.await
} }
pub async fn delete(user_uuid: &UserId, conn: &DbConn) -> EmptyResult { pub async fn delete(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { conn.run(move |conn| {
diesel::delete(sso_users::table.filter(sso_users::user_uuid.eq(user_uuid))) diesel::delete(sso_users::table.filter(sso_users::user_uuid.eq(user_uuid)))
.execute(conn) .execute(conn)
.map_res("Error deleting sso user") .map_res("Error deleting sso user")
}} })
.await
} }
} }

11
src/db/query_logger.rs

@ -1,6 +1,7 @@
use diesel::connection::{Instrumentation, InstrumentationEvent};
use std::{cell::RefCell, collections::HashMap, time::Instant}; use std::{cell::RefCell, collections::HashMap, time::Instant};
use diesel::connection::{Instrumentation, InstrumentationEvent};
thread_local! { thread_local! {
static QUERY_PERF_TRACKER: RefCell<HashMap<String, Instant>> = RefCell::new(HashMap::new()); static QUERY_PERF_TRACKER: RefCell<HashMap<String, Instant>> = RefCell::new(HashMap::new());
} }
@ -11,7 +12,7 @@ pub fn simple_logger() -> Option<Box<dyn Instrumentation>> {
url, url,
.. ..
} => { } => {
debug!("Establishing connection: {url}") debug!("Establishing connection: {url}");
} }
InstrumentationEvent::FinishEstablishConnection { InstrumentationEvent::FinishEstablishConnection {
url, url,
@ -19,9 +20,9 @@ pub fn simple_logger() -> Option<Box<dyn Instrumentation>> {
.. ..
} => { } => {
if let Some(e) = error { if let Some(e) = error {
error!("Error during establishing a connection with {url}: {e:?}") error!("Error during establishing a connection with {url}: {e:?}");
} else { } else {
debug!("Connection established: {url}") debug!("Connection established: {url}");
} }
} }
InstrumentationEvent::StartQuery { InstrumentationEvent::StartQuery {
@ -47,7 +48,7 @@ pub fn simple_logger() -> Option<Box<dyn Instrumentation>> {
} else if duration.as_secs() >= 1 { } else if duration.as_secs() >= 1 {
info!("SLOW QUERY [{:.2}s]: {}", duration.as_secs_f32(), query_string); info!("SLOW QUERY [{:.2}s]: {}", duration.as_secs_f32(), query_string);
} else { } else {
debug!("QUERY [{:?}]: {}", duration, query_string); debug!("QUERY [{duration:?}]: {query_string}");
} }
} }
}); });

100
src/error.rs

@ -1,10 +1,11 @@
// //
// Error generator macro // Error generator macro
// //
use std::error::Error as StdError;
use crate::db::models::EventType; use crate::db::models::EventType;
use crate::http_client::CustomHttpClientError; use crate::http_client::CustomHttpClientError;
use serde::ser::{Serialize, SerializeStruct, Serializer}; use serde::ser::{Serialize, SerializeStruct, Serializer};
use std::error::Error as StdError;
macro_rules! make_error { macro_rules! make_error {
( $( $name:ident ( $ty:ty ): $src_fn:expr, $usr_msg_fun:expr ),+ $(,)? ) => { ( $( $name:ident ( $ty:ty ): $src_fn:expr, $usr_msg_fun:expr ),+ $(,)? ) => {
@ -14,24 +15,24 @@ macro_rules! make_error {
#[derive(Debug)] #[derive(Debug)]
pub struct ErrorEvent { pub event: EventType } pub struct ErrorEvent { pub event: EventType }
pub struct Error { message: String, error: ErrorKind, error_code: u16, event: Option<ErrorEvent> } pub struct Error { message: String, kind: ErrorKind, code: u16, event: Option<ErrorEvent> }
$(impl From<$ty> for Error { $(impl From<$ty> for Error {
fn from(err: $ty) -> Self { Error::from((stringify!($name), err)) } fn from(err: $ty) -> Self { Error::from((stringify!($name), err)) }
})+ })+
$(impl<S: Into<String>> From<(S, $ty)> for Error { $(impl<S: Into<String>> From<(S, $ty)> for Error {
fn from(val: (S, $ty)) -> Self { fn from(val: (S, $ty)) -> Self {
Error { message: val.0.into(), error: ErrorKind::$name(val.1), error_code: BAD_REQUEST, event: None } Error { message: val.0.into(), kind: ErrorKind::$name(val.1), code: BAD_REQUEST, event: None }
} }
})+ })+
impl StdError for Error { impl StdError for Error {
fn source(&self) -> Option<&(dyn StdError + 'static)> { fn source(&self) -> Option<&(dyn StdError + 'static)> {
match &self.error {$( ErrorKind::$name(e) => $src_fn(e), )+} match &self.kind {$( ErrorKind::$name(e) => $src_fn(e), )+}
} }
} }
impl std::fmt::Display for Error { impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.error {$( match &self.kind {$(
ErrorKind::$name(e) => f.write_str(&$usr_msg_fun(e, &self.message)), ErrorKind::$name(e) => f.write_str(&$usr_msg_fun(e, &self.message)),
)+} )+}
} }
@ -39,10 +40,10 @@ macro_rules! make_error {
}; };
} }
use diesel::ConnectionError as DieselConErr;
use diesel::r2d2::Error as R2d2Err; use diesel::r2d2::Error as R2d2Err;
use diesel::r2d2::PoolError as R2d2PoolErr; use diesel::r2d2::PoolError as R2d2PoolErr;
use diesel::result::Error as DieselErr; use diesel::result::Error as DieselErr;
use diesel::ConnectionError as DieselConErr;
use handlebars::RenderError as HbErr; use handlebars::RenderError as HbErr;
use jsonwebtoken::errors::Error as JwtErr; use jsonwebtoken::errors::Error as JwtErr;
use lettre::address::AddressError as AddrErr; use lettre::address::AddressError as AddrErr;
@ -71,46 +72,46 @@ pub struct Compact {}
// The second one contains the function used to obtain the response sent to the client // The second one contains the function used to obtain the response sent to the client
make_error! { make_error! {
// Just an empty error // Just an empty error
Empty(Empty): _no_source, _serialize, Empty(Empty): no_source, serialize,
// Used to represent err! calls // Used to represent err! calls
Simple(String): _no_source, _api_error, Simple(String): no_source, api_error,
Compact(Compact): _no_source, _compact_api_error, Compact(Compact): no_source, compact_api_error,
// Used in our custom http client to handle non-global IPs and blocked domains // Used in our custom http client to handle non-global IPs and blocked domains
CustomHttpClient(CustomHttpClientError): _has_source, _api_error, CustomHttpClient(CustomHttpClientError): has_source, api_error,
// Used for special return values, like 2FA errors // Used for special return values, like 2FA errors
Json(Value): _no_source, _serialize, Json(Value): no_source, serialize,
Db(DieselErr): _has_source, _api_error, Db(DieselErr): has_source, api_error,
R2d2(R2d2Err): _has_source, _api_error, R2d2(R2d2Err): has_source, api_error,
R2d2Pool(R2d2PoolErr): _has_source, _api_error, R2d2Pool(R2d2PoolErr): has_source, api_error,
Serde(SerdeErr): _has_source, _api_error, Serde(SerdeErr): has_source, api_error,
JWt(JwtErr): _has_source, _api_error, JWt(JwtErr): has_source, api_error,
Handlebars(HbErr): _has_source, _api_error, Handlebars(HbErr): has_source, api_error,
Io(IoErr): _has_source, _api_error, Io(IoErr): has_source, api_error,
Time(TimeErr): _has_source, _api_error, Time(TimeErr): has_source, api_error,
Req(ReqErr): _has_source, _api_error, Req(ReqErr): has_source, api_error,
Regex(RegexErr): _has_source, _api_error, Regex(RegexErr): has_source, api_error,
Yubico(YubiErr): _has_source, _api_error, Yubico(YubiErr): has_source, api_error,
Lettre(LettreErr): _has_source, _api_error, Lettre(LettreErr): has_source, api_error,
Address(AddrErr): _has_source, _api_error, Address(AddrErr): has_source, api_error,
Smtp(SmtpErr): _has_source, _api_error, Smtp(SmtpErr): has_source, api_error,
OpenSSL(SSLErr): _has_source, _api_error, OpenSSL(SSLErr): has_source, api_error,
Rocket(RocketErr): _has_source, _api_error, Rocket(RocketErr): has_source, api_error,
DieselCon(DieselConErr): _has_source, _api_error, DieselCon(DieselConErr): has_source, api_error,
Webauthn(WebauthnErr): _has_source, _api_error, Webauthn(WebauthnErr): has_source, api_error,
OpenDAL(OpenDALErr): _has_source, _api_error, OpenDAL(OpenDALErr): has_source, api_error,
} }
impl std::fmt::Debug for Error { impl std::fmt::Debug for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.source() { match self.source() {
Some(e) => write!(f, "{}.\n[CAUSE] {:#?}", self.message, e), Some(e) => write!(f, "{}.\n[CAUSE] {:#?}", self.message, e),
None => match self.error { None => match self.kind {
ErrorKind::Empty(_) => Ok(()), ErrorKind::Empty(_) => Ok(()),
ErrorKind::Simple(ref s) => { ErrorKind::Simple(ref s) => {
if &self.message == s { if &self.message == s {
@ -135,6 +136,7 @@ impl Error {
(usr_msg.clone(), usr_msg.into()).into() (usr_msg.clone(), usr_msg.into()).into()
} }
#[must_use]
pub fn empty() -> Self { pub fn empty() -> Self {
Empty {}.into() Empty {}.into()
} }
@ -147,13 +149,13 @@ impl Error {
#[must_use] #[must_use]
pub fn with_kind(mut self, kind: ErrorKind) -> Self { pub fn with_kind(mut self, kind: ErrorKind) -> Self {
self.error = kind; self.kind = kind;
self self
} }
#[must_use] #[must_use]
pub const fn with_code(mut self, code: u16) -> Self { pub const fn with_code(mut self, code: u16) -> Self {
self.error_code = code; self.code = code;
self self
} }
@ -194,14 +196,14 @@ impl<S> MapResult<S> for Option<S> {
} }
} }
const fn _has_source<T>(e: T) -> Option<T> { const fn has_source<T>(e: T) -> Option<T> {
Some(e) Some(e)
} }
fn _no_source<T, S>(_: T) -> Option<S> { fn no_source<T, S>(_: T) -> Option<S> {
None None
} }
fn _serialize(e: &impl Serialize, _msg: &str) -> String { fn serialize(e: &impl Serialize, _msg: &str) -> String {
serde_json::to_string(e).unwrap() serde_json::to_string(e).unwrap()
} }
@ -280,14 +282,14 @@ struct ApiErrorResponse<'a>(ApiErrorMsg<'a>);
/// The custom serialization adds all other needed fields /// The custom serialization adds all other needed fields
struct CompactApiErrorResponse<'a>(ApiErrorMsg<'a>); struct CompactApiErrorResponse<'a>(ApiErrorMsg<'a>);
fn _api_error(_: &impl std::any::Any, msg: &str) -> String { fn api_error(_: &impl std::any::Any, msg: &str) -> String {
let response = ApiErrorMsg { let response = ApiErrorMsg {
message: msg, message: msg,
}; };
serde_json::to_string(&ApiErrorResponse(response)).unwrap() serde_json::to_string(&ApiErrorResponse(response)).unwrap()
} }
fn _compact_api_error(_: &impl std::any::Any, msg: &str) -> String { fn compact_api_error(_: &impl std::any::Any, msg: &str) -> String {
let response = ApiErrorMsg { let response = ApiErrorMsg {
message: msg, message: msg,
}; };
@ -299,18 +301,20 @@ fn _compact_api_error(_: &impl std::any::Any, msg: &str) -> String {
// //
use std::io::Cursor; use std::io::Cursor;
use rocket::http::{ContentType, Status}; use rocket::{
use rocket::request::Request; http::{ContentType, Status},
use rocket::response::{self, Responder, Response}; request::Request,
response::{self, Responder, Response},
};
impl Responder<'_, 'static> for Error { impl Responder<'_, 'static> for Error {
fn respond_to(self, _: &Request<'_>) -> response::Result<'static> { fn respond_to(self, _: &Request<'_>) -> response::Result<'static> {
match self.error { match self.kind {
ErrorKind::Empty(_) | ErrorKind::Simple(_) | ErrorKind::Compact(_) => {} // Don't print the error in this situation ErrorKind::Empty(_) | ErrorKind::Simple(_) | ErrorKind::Compact(_) => {} // Don't print the error in this situation
_ => error!(target: "error", "{self:#?}"), _ => error!(target: "error", "{self:#?}"),
}; }
let code = Status::from_code(self.error_code).unwrap_or(Status::BadRequest); let code = Status::from_code(self.code).unwrap_or(Status::BadRequest);
let body = self.to_string(); let body = self.to_string();
Response::build().status(code).header(ContentType::JSON).sized_body(Some(body.len()), Cursor::new(body)).ok() Response::build().status(code).header(ContentType::JSON).sized_body(Some(body.len()), Cursor::new(body)).ok()
} }

43
src/http_client.rs

@ -5,17 +5,21 @@ use std::{
time::Duration, time::Duration,
}; };
use hickory_resolver::{net::runtime::TokioRuntimeProvider, TokioResolver}; use hickory_resolver::{TokioResolver, net::runtime::TokioRuntimeProvider};
use regex::Regex; use regex::Regex;
use reqwest::{ use reqwest::{
Client, ClientBuilder,
dns::{Name, Resolve, Resolving}, dns::{Name, Resolve, Resolving},
header, Client, ClientBuilder, header,
}; };
use url::Host; use url::Host;
use crate::{util::is_global, CONFIG}; use crate::{CONFIG, util::is_global};
pub fn make_http_request(method: reqwest::Method, url: &str) -> Result<reqwest::RequestBuilder, crate::Error> { pub fn make_http_request(method: reqwest::Method, url: &str) -> Result<reqwest::RequestBuilder, crate::Error> {
static INSTANCE: LazyLock<Client> =
LazyLock::new(|| get_reqwest_client_builder().build().expect("Failed to build client"));
let Ok(url) = url::Url::parse(url) else { let Ok(url) = url::Url::parse(url) else {
err!("Invalid URL"); err!("Invalid URL");
}; };
@ -25,9 +29,6 @@ pub fn make_http_request(method: reqwest::Method, url: &str) -> Result<reqwest::
should_block_host(&host)?; should_block_host(&host)?;
static INSTANCE: LazyLock<Client> =
LazyLock::new(|| get_reqwest_client_builder().build().expect("Failed to build client"));
Ok(INSTANCE.request(method, url)) Ok(INSTANCE.request(method, url))
} }
@ -67,19 +68,20 @@ fn should_block_ip(ip: IpAddr) -> bool {
} }
fn should_block_address_regex(domain_or_ip: &str) -> bool { fn should_block_address_regex(domain_or_ip: &str) -> bool {
static COMPILED_REGEX: Mutex<Option<(String, Regex)>> = Mutex::new(None);
let Some(block_regex) = CONFIG.http_request_block_regex() else { let Some(block_regex) = CONFIG.http_request_block_regex() else {
return false; return false;
}; };
static COMPILED_REGEX: Mutex<Option<(String, Regex)>> = Mutex::new(None);
let mut guard = COMPILED_REGEX.lock().unwrap(); let mut guard = COMPILED_REGEX.lock().unwrap();
// If the stored regex is up to date, use it // If the stored regex is up to date, use it
if let Some((value, regex)) = &*guard { if let Some((value, regex)) = &*guard
if value == &block_regex { && value == &block_regex
{
return regex.is_match(domain_or_ip); return regex.is_match(domain_or_ip);
} }
}
// If we don't have a regex stored, or it's not up to date, recreate it // If we don't have a regex stored, or it's not up to date, recreate it
let regex = Regex::new(&block_regex).unwrap(); let regex = Regex::new(&block_regex).unwrap();
@ -92,7 +94,7 @@ fn should_block_address_regex(domain_or_ip: &str) -> bool {
pub fn get_valid_host(host: &str) -> Result<Host, CustomHttpClientError> { pub fn get_valid_host(host: &str) -> Result<Host, CustomHttpClientError> {
let Ok(host) = Host::parse(host) else { let Ok(host) = Host::parse(host) else {
return Err(CustomHttpClientError::Invalid { return Err(CustomHttpClientError::Invalid {
domain: host.to_string(), domain: host.to_owned(),
}); });
}; };
@ -136,17 +138,17 @@ pub fn should_block_host<S: AsRef<str>>(host: &Host<S>) -> Result<(), CustomHttp
let (ip, host_str): (Option<IpAddr>, String) = match host { let (ip, host_str): (Option<IpAddr>, String) = match host {
Host::Ipv4(ip) => (Some(IpAddr::V4(*ip)), ip.to_string()), Host::Ipv4(ip) => (Some(IpAddr::V4(*ip)), ip.to_string()),
Host::Ipv6(ip) => (Some(IpAddr::V6(*ip)), ip.to_string()), Host::Ipv6(ip) => (Some(IpAddr::V6(*ip)), ip.to_string()),
Host::Domain(d) => (None, d.as_ref().to_string()), Host::Domain(d) => (None, d.as_ref().to_owned()),
}; };
if let Some(ip) = ip { if let Some(ip) = ip
if should_block_ip(ip) { && should_block_ip(ip)
{
return Err(CustomHttpClientError::NonGlobalIp { return Err(CustomHttpClientError::NonGlobalIp {
domain: None, domain: None,
ip, ip,
}); });
} }
}
if should_block_address_regex(&host_str) { if should_block_address_regex(&host_str) {
return Err(CustomHttpClientError::Blocked { return Err(CustomHttpClientError::Blocked {
@ -233,8 +235,7 @@ impl CustomDnsResolver {
builder.build() builder.build()
}) })
.inspect_err(|e| warn!("Error creating Hickory resolver, falling back to default: {e:?}")) .inspect_err(|e| warn!("Error creating Hickory resolver, falling back to default: {e:?}"))
.map(|resolver| Arc::new(Self::Hickory(Arc::new(resolver)))) .map_or_else(|_| Arc::new(Self::Default()), |resolver| Arc::new(Self::Hickory(Arc::new(resolver))))
.unwrap_or_else(|_| Arc::new(Self::Default()))
} }
// Note that we get an iterator of addresses, but we only grab the first one for convenience // Note that we get an iterator of addresses, but we only grab the first one for convenience
@ -257,13 +258,13 @@ impl CustomDnsResolver {
fn pre_resolve(name: &str) -> Result<(), CustomHttpClientError> { fn pre_resolve(name: &str) -> Result<(), CustomHttpClientError> {
let Ok(host) = get_valid_host(name) else { let Ok(host) = get_valid_host(name) else {
return Err(CustomHttpClientError::Invalid { return Err(CustomHttpClientError::Invalid {
domain: name.to_string(), domain: name.to_owned(),
}); });
}; };
if should_block_host(&host).is_err() { if should_block_host(&host).is_err() {
return Err(CustomHttpClientError::Blocked { return Err(CustomHttpClientError::Blocked {
domain: name.to_string(), domain: name.to_owned(),
}); });
} }
@ -273,7 +274,7 @@ fn pre_resolve(name: &str) -> Result<(), CustomHttpClientError> {
fn post_resolve(name: &str, ip: IpAddr) -> Result<(), CustomHttpClientError> { fn post_resolve(name: &str, ip: IpAddr) -> Result<(), CustomHttpClientError> {
if should_block_ip(ip) { if should_block_ip(ip) {
Err(CustomHttpClientError::NonGlobalIp { Err(CustomHttpClientError::NonGlobalIp {
domain: Some(name.to_string()), domain: Some(name.to_owned()),
ip, ip,
}) })
} else { } else {
@ -318,7 +319,7 @@ pub(crate) mod aws {
let future = async move { let future = async move {
let method = reqwest::Method::from_bytes(request.method().as_bytes()) let method = reqwest::Method::from_bytes(request.method().as_bytes())
.map_err(|e| ConnectorError::user(Box::new(e)))?; .map_err(|e| ConnectorError::user(Box::new(e)))?;
let mut req_builder = client.request(method, request.uri().to_string()); let mut req_builder = client.request(method, request.uri().to_owned());
for (name, value) in request.headers() { for (name, value) in request.headers() {
req_builder = req_builder.header(name, value); req_builder = req_builder.header(name, value);

52
src/mail.rs

@ -1,16 +1,17 @@
use chrono::NaiveDateTime;
use percent_encoding::{percent_encode, NON_ALPHANUMERIC};
use std::{env::consts::EXE_SUFFIX, str::FromStr}; use std::{env::consts::EXE_SUFFIX, str::FromStr};
use chrono::NaiveDateTime;
use lettre::{ use lettre::{
Address, AsyncSendmailTransport, AsyncSmtpTransport, AsyncTransport, Tokio1Executor,
message::{Attachment, Body, Mailbox, Message, MultiPart, SinglePart}, message::{Attachment, Body, Mailbox, Message, MultiPart, SinglePart},
transport::smtp::authentication::{Credentials, Mechanism as SmtpAuthMechanism}, transport::smtp::authentication::{Credentials, Mechanism as SmtpAuthMechanism},
transport::smtp::client::{Tls, TlsParameters}, transport::smtp::client::{Tls, TlsParameters},
transport::smtp::extension::ClientId, transport::smtp::extension::ClientId,
Address, AsyncSendmailTransport, AsyncSmtpTransport, AsyncTransport, Tokio1Executor,
}; };
use percent_encoding::{NON_ALPHANUMERIC, percent_encode};
use crate::{ use crate::{
CONFIG,
api::EmptyResult, api::EmptyResult,
auth::{ auth::{
encode_jwt, generate_delete_claims, generate_emergency_access_invite_claims, generate_invite_claims, encode_jwt, generate_delete_claims, generate_emergency_access_invite_claims, generate_invite_claims,
@ -18,7 +19,7 @@ use crate::{
}, },
db::models::{Device, DeviceType, EmergencyAccessId, MembershipId, OrganizationId, User, UserId}, db::models::{Device, DeviceType, EmergencyAccessId, MembershipId, OrganizationId, User, UserId},
error::Error, error::Error,
CONFIG, util::upcase_first,
}; };
fn sendmail_transport() -> AsyncSendmailTransport<Tokio1Executor> { fn sendmail_transport() -> AsyncSendmailTransport<Tokio1Executor> {
@ -38,7 +39,9 @@ fn smtp_transport() -> AsyncSmtpTransport<Tokio1Executor> {
.timeout(Some(Duration::from_secs(CONFIG.smtp_timeout()))); .timeout(Some(Duration::from_secs(CONFIG.smtp_timeout())));
// Determine security // Determine security
let smtp_client = if CONFIG.smtp_security() != *"off" { let smtp_client = if CONFIG.smtp_security() == *"off" {
smtp_client
} else {
let mut tls_parameters = TlsParameters::builder(host); let mut tls_parameters = TlsParameters::builder(host);
if CONFIG.smtp_accept_invalid_hostnames() { if CONFIG.smtp_accept_invalid_hostnames() {
tls_parameters = tls_parameters.dangerous_accept_invalid_hostnames(true); tls_parameters = tls_parameters.dangerous_accept_invalid_hostnames(true);
@ -53,8 +56,6 @@ fn smtp_transport() -> AsyncSmtpTransport<Tokio1Executor> {
} else { } else {
smtp_client.tls(Tls::Required(tls_parameters)) smtp_client.tls(Tls::Required(tls_parameters))
} }
} else {
smtp_client
}; };
let smtp_client = match (CONFIG.smtp_username(), CONFIG.smtp_password()) { let smtp_client = match (CONFIG.smtp_username(), CONFIG.smtp_password()) {
@ -81,12 +82,12 @@ fn smtp_transport() -> AsyncSmtpTransport<Tokio1Executor> {
} }
} }
if !selected_mechanisms.is_empty() { if selected_mechanisms.is_empty() {
smtp_client.authentication(selected_mechanisms)
} else {
// Only show a warning, and return without setting an actual authentication mechanism // Only show a warning, and return without setting an actual authentication mechanism
warn!("No valid SMTP Auth mechanism found for '{mechanism}', using default values"); warn!("No valid SMTP Auth mechanism found for '{mechanism}', using default values");
smtp_client smtp_client
} else {
smtp_client.authentication(selected_mechanisms)
} }
} }
_ => smtp_client, _ => smtp_client,
@ -129,14 +130,16 @@ fn get_template(template_name: &str, data: &serde_json::Value) -> Result<(String
let text = CONFIG.render_template(template_name, data)?; let text = CONFIG.render_template(template_name, data)?;
let mut text_split = text.split("<!---------------->"); let mut text_split = text.split("<!---------------->");
let subject = match text_split.next() { let subject = if let Some(s) = text_split.next() {
Some(s) => s.trim().to_string(), s.trim().to_owned()
None => err!("Template doesn't contain subject"), } else {
err!("Template doesn't contain subject")
}; };
let body = match text_split.next() { let body = if let Some(s) = text_split.next() {
Some(s) => s.trim().to_string(), s.trim().to_owned()
None => err!("Template doesn't contain body"), } else {
err!("Template doesn't contain body")
}; };
if text_split.next().is_some() { if text_split.next().is_some() {
@ -204,9 +207,8 @@ pub async fn send_verify_email(address: &str, user_id: &UserId) -> EmptyResult {
pub async fn send_register_verify_email(email: &str, token: &str) -> EmptyResult { pub async fn send_register_verify_email(email: &str, token: &str) -> EmptyResult {
let mut query = url::Url::parse("https://query.builder").unwrap(); let mut query = url::Url::parse("https://query.builder").unwrap();
query.query_pairs_mut().append_pair("email", email).append_pair("token", token); query.query_pairs_mut().append_pair("email", email).append_pair("token", token);
let query_string = match query.query() { let Some(query_string) = query.query() else {
None => err!("Failed to build verify URL query parameters"), err!("Failed to build verify URL query parameters")
Some(query) => query,
}; };
let (subject, body_html, body_text) = get_text( let (subject, body_html, body_text) = get_text(
@ -504,8 +506,6 @@ pub async fn send_invite_confirmed(address: &str, org_name: &str) -> EmptyResult
} }
pub async fn send_new_device_logged_in(address: &str, ip: &str, dt: &NaiveDateTime, device: &Device) -> EmptyResult { pub async fn send_new_device_logged_in(address: &str, ip: &str, dt: &NaiveDateTime, device: &Device) -> EmptyResult {
use crate::util::upcase_first;
let fmt = "%A, %B %_d, %Y at %r %Z"; let fmt = "%A, %B %_d, %Y at %r %Z";
let (subject, body_html, body_text) = get_text( let (subject, body_html, body_text) = get_text(
"email/new_device_logged_in", "email/new_device_logged_in",
@ -529,8 +529,6 @@ pub async fn send_incomplete_2fa_login(
device_name: &str, device_name: &str,
device_type: &str, device_type: &str,
) -> EmptyResult { ) -> EmptyResult {
use crate::util::upcase_first;
let fmt = "%A, %B %_d, %Y at %r %Z"; let fmt = "%A, %B %_d, %Y at %r %Z";
let (subject, body_html, body_text) = get_text( let (subject, body_html, body_text) = get_text(
"email/incomplete_2fa_login", "email/incomplete_2fa_login",
@ -655,7 +653,7 @@ pub async fn send_protected_action_token(address: &str, token: &str) -> EmptyRes
async fn send_with_selected_transport(email: Message) -> EmptyResult { async fn send_with_selected_transport(email: Message) -> EmptyResult {
if CONFIG.use_sendmail() { if CONFIG.use_sendmail() {
match sendmail_transport().send(email).await { match sendmail_transport().send(email).await {
Ok(_) => Ok(()), Ok(()) => Ok(()),
// Match some common errors and make them more user friendly // Match some common errors and make them more user friendly
Err(e) => { Err(e) => {
if e.is_client() { if e.is_client() {
@ -664,12 +662,11 @@ async fn send_with_selected_transport(email: Message) -> EmptyResult {
} else if e.is_response() { } else if e.is_response() {
debug!("Sendmail response error: {e:?}"); debug!("Sendmail response error: {e:?}");
err!(format!("Sendmail response error: {e}")); err!(format!("Sendmail response error: {e}"));
} else { }
debug!("Sendmail error: {e:?}"); debug!("Sendmail error: {e:?}");
err!(format!("Sendmail error: {e}")); err!(format!("Sendmail error: {e}"));
} }
} }
}
} else { } else {
match smtp_transport().send(email).await { match smtp_transport().send(email).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
@ -695,13 +692,12 @@ async fn send_with_selected_transport(email: Message) -> EmptyResult {
} else if e.is_tls() { } else if e.is_tls() {
debug!("SMTP encryption error: {e:#?}"); debug!("SMTP encryption error: {e:#?}");
err!(format!("SMTP encryption error: {e}")); err!(format!("SMTP encryption error: {e}"));
} else { }
debug!("SMTP error: {e:#?}"); debug!("SMTP error: {e:#?}");
err!(format!("SMTP error: {e}")); err!(format!("SMTP error: {e}"));
} }
} }
} }
}
} }
async fn send_email(address: &str, subject: &str, body_html: String, body_text: String) -> EmptyResult { async fn send_email(address: &str, subject: &str, body_html: String, body_text: String) -> EmptyResult {

42
src/main.rs

@ -33,6 +33,7 @@ use std::{
path::Path, path::Path,
process::exit, process::exit,
str::FromStr, str::FromStr,
sync::{Arc, atomic::Ordering},
thread, thread,
}; };
@ -44,6 +45,8 @@ use tokio::{
#[cfg(unix)] #[cfg(unix)]
use tokio::signal::unix::SignalKind; use tokio::signal::unix::SignalKind;
use rocket::data::{Limits, ToByteUnit};
#[macro_use] #[macro_use]
mod error; mod error;
mod api; mod api;
@ -60,13 +63,11 @@ mod sso_client;
mod storage; mod storage;
mod util; mod util;
use crate::api::core::two_factor::duo_oidc::purge_duo_contexts; use crate::api::{
use crate::api::purge_auth_requests; WS_ANONYMOUS_SUBSCRIPTIONS, WS_USERS, core::two_factor::duo_oidc::purge_duo_contexts, purge_auth_requests,
use crate::api::{WS_ANONYMOUS_SUBSCRIPTIONS, WS_USERS}; };
pub use config::{PathType, CONFIG}; pub use config::{CONFIG, PathType};
pub use error::{Error, MapResult}; pub use error::{Error, MapResult};
use rocket::data::{Limits, ToByteUnit};
use std::sync::{atomic::Ordering, Arc};
pub use util::is_running_in_container; pub use util::is_running_in_container;
#[rocket::main] #[rocket::main]
@ -137,27 +138,24 @@ fn parse_args() {
if let Some(command) = pargs.subcommand().unwrap_or_default() { if let Some(command) = pargs.subcommand().unwrap_or_default() {
if command == "hash" { if command == "hash" {
use argon2::{ use argon2::{
password_hash::SaltString, Algorithm::Argon2id, Argon2, ParamsBuilder, PasswordHasher, Version::V0x13, Algorithm::Argon2id, Argon2, ParamsBuilder, PasswordHasher, Version::V0x13, password_hash::SaltString,
}; };
let mut argon2_params = ParamsBuilder::new(); let mut argon2_params = ParamsBuilder::new();
let preset: Option<String> = pargs.opt_value_from_str(["-p", "--preset"]).unwrap_or_default(); let preset: Option<String> = pargs.opt_value_from_str(["-p", "--preset"]).unwrap_or_default();
let selected_preset; let selected_preset;
match preset.as_deref() { if preset.as_deref() == Some("owasp") {
Some("owasp") => {
selected_preset = "owasp"; selected_preset = "owasp";
argon2_params.m_cost(19456); argon2_params.m_cost(19456);
argon2_params.t_cost(2); argon2_params.t_cost(2);
argon2_params.p_cost(1); argon2_params.p_cost(1);
} } else {
_ => {
// Bitwarden preset is the default // Bitwarden preset is the default
selected_preset = "bitwarden"; selected_preset = "bitwarden";
argon2_params.m_cost(65540); argon2_params.m_cost(65540);
argon2_params.t_cost(3); argon2_params.t_cost(3);
argon2_params.p_cost(4); argon2_params.p_cost(4);
} }
}
println!("Generate an Argon2id PHC string using the '{selected_preset}' preset:\n"); println!("Generate an Argon2id PHC string using the '{selected_preset}' preset:\n");
@ -247,7 +245,7 @@ fn init_logging() -> Result<log::LevelFilter, Error> {
let level = caps let level = caps
.get(1) .get(1)
.and_then(|m| log::LevelFilter::from_str(m.as_str()).ok()) .and_then(|m| log::LevelFilter::from_str(m.as_str()).ok())
.ok_or(Error::new("Failed to parse global log level".to_string(), ""))?; .ok_or(Error::new("Failed to parse global log level".to_owned(), ""))?;
let levels_override: Vec<(&str, log::LevelFilter)> = caps let levels_override: Vec<(&str, log::LevelFilter)> = caps
.get(2) .get(2)
@ -256,13 +254,13 @@ fn init_logging() -> Result<log::LevelFilter, Error> {
.split(',') .split(',')
.collect::<Vec<&str>>() .collect::<Vec<&str>>()
.into_iter() .into_iter()
.flat_map(|s| match s.split_once('=') { .filter_map(|s| match s.split_once('=') {
Some((log, lvl_str)) => log::LevelFilter::from_str(lvl_str).ok().map(|lvl| (log, lvl)), Some((log, lvl_str)) => log::LevelFilter::from_str(lvl_str).ok().map(|lvl| (log, lvl)),
_ => None, _ => None,
}) })
.collect() .collect()
}) })
.ok_or(Error::new("Failed to parse overrides".to_string(), ""))?; .ok_or(Error::new("Failed to parse overrides".to_owned(), ""))?;
(level, levels_override) (level, levels_override)
} else { } else {
@ -338,7 +336,7 @@ fn init_logging() -> Result<log::LevelFilter, Error> {
("vaultwarden::db::query_logger", log::LevelFilter::Off), ("vaultwarden::db::query_logger", log::LevelFilter::Off),
]); ]);
for (path, level) in levels_override.into_iter() { for (path, level) in levels_override {
let _ = default_levels.insert(path, level); let _ = default_levels.insert(path, level);
} }
@ -352,7 +350,7 @@ fn init_logging() -> Result<log::LevelFilter, Error> {
let mut logger = fern::Dispatch::new().level(level).chain(std::io::stdout()); let mut logger = fern::Dispatch::new().level(level).chain(std::io::stdout());
for (path, level) in default_levels { for (path, level) in default_levels {
logger = logger.level_for(path.to_string(), level); logger = logger.level_for(path.to_owned(), level);
} }
if CONFIG.extended_logging() { if CONFIG.extended_logging() {
@ -363,7 +361,7 @@ fn init_logging() -> Result<log::LevelFilter, Error> {
record.target(), record.target(),
record.level(), record.level(),
message message
)) ));
}); });
} else { } else {
logger = logger.format(|out, message, _| out.finish(format_args!("{message}"))); logger = logger.format(|out, message, _| out.finish(format_args!("{message}")));
@ -609,9 +607,7 @@ async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error>
#[cfg(all(unix, sqlite))] #[cfg(all(unix, sqlite))]
{ {
if db::ACTIVE_DB_TYPE.get() != Some(&db::DbConnType::Sqlite) { if db::ACTIVE_DB_TYPE.get() == Some(&db::DbConnType::Sqlite) {
debug!("PostgreSQL and MySQL/MariaDB do not support this backup feature, skip adding USR1 signal.");
} else {
tokio::spawn(async move { tokio::spawn(async move {
let mut signal_user1 = tokio::signal::unix::signal(SignalKind::user_defined1()).unwrap(); let mut signal_user1 = tokio::signal::unix::signal(SignalKind::user_defined1()).unwrap();
loop { loop {
@ -624,6 +620,8 @@ async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error>
} }
} }
}); });
} else {
debug!("PostgreSQL and MySQL/MariaDB do not support this backup feature, skip adding USR1 signal.");
} }
} }
@ -671,7 +669,7 @@ fn schedule_jobs(pool: db::DbPool) {
let runtime = tokio::runtime::Runtime::new().unwrap(); let runtime = tokio::runtime::Runtime::new().unwrap();
thread::Builder::new() thread::Builder::new()
.name("job-scheduler".to_string()) .name("job-scheduler".to_owned())
.spawn(move || { .spawn(move || {
use job_scheduler_ng::{Job, JobScheduler}; use job_scheduler_ng::{Job, JobScheduler};
let _runtime_guard = runtime.enter(); let _runtime_guard = runtime.enter();

8
src/ratelimit.rs

@ -1,8 +1,8 @@
use std::{net::IpAddr, num::NonZeroU32, sync::LazyLock, time::Duration}; use std::{net::IpAddr, num::NonZeroU32, sync::LazyLock, time::Duration};
use governor::{clock::DefaultClock, state::keyed::DashMapStateStore, Quota, RateLimiter}; use governor::{Quota, RateLimiter, clock::DefaultClock, state::keyed::DashMapStateStore};
use crate::{Error, CONFIG}; use crate::{CONFIG, Error};
type Limiter<T = IpAddr> = RateLimiter<T, DashMapStateStore<T>, DefaultClock>; type Limiter<T = IpAddr> = RateLimiter<T, DashMapStateStore<T>, DefaultClock>;
@ -20,7 +20,7 @@ static LIMITER_ADMIN: LazyLock<Limiter> = LazyLock::new(|| {
pub fn check_limit_login(ip: &IpAddr) -> Result<(), Error> { pub fn check_limit_login(ip: &IpAddr) -> Result<(), Error> {
match LIMITER_LOGIN.check_key(ip) { match LIMITER_LOGIN.check_key(ip) {
Ok(_) => Ok(()), Ok(()) => Ok(()),
Err(_e) => { Err(_e) => {
err_code!("Too many login requests", 429); err_code!("Too many login requests", 429);
} }
@ -29,7 +29,7 @@ pub fn check_limit_login(ip: &IpAddr) -> Result<(), Error> {
pub fn check_limit_admin(ip: &IpAddr) -> Result<(), Error> { pub fn check_limit_admin(ip: &IpAddr) -> Result<(), Error> {
match LIMITER_ADMIN.check_key(ip) { match LIMITER_ADMIN.check_key(ip) {
Ok(_) => Ok(()), Ok(()) => Ok(()),
Err(_e) => { Err(_e) => {
err_code!("Too many admin requests", 429); err_code!("Too many admin requests", 429);
} }

62
src/sso.rs

@ -6,15 +6,15 @@ use regex::Regex;
use url::Url; use url::Url;
use crate::{ use crate::{
CONFIG,
api::ApiResult, api::ApiResult,
auth, auth,
auth::{AuthMethod, AuthTokens, TokenWrapper, BW_EXPIRATION, DEFAULT_REFRESH_VALIDITY}, auth::{AuthMethod, AuthTokens, BW_EXPIRATION, DEFAULT_REFRESH_VALIDITY, TokenWrapper},
db::{ db::{
models::{Device, OIDCAuthenticatedUser, SsoAuth, SsoUser, User},
DbConn, DbConn,
models::{Device, OIDCAuthenticatedUser, SsoAuth, SsoUser, User},
}, },
sso_client::Client, sso_client::Client,
CONFIG,
}; };
pub static FAKE_SSO_IDENTIFIER: &str = "00000000-01DC-01DC-01DC-000000000000"; pub static FAKE_SSO_IDENTIFIER: &str = "00000000-01DC-01DC-01DC-000000000000";
@ -123,7 +123,7 @@ pub fn encode_ssotoken_claims() -> String {
nbf: time_now.timestamp(), nbf: time_now.timestamp(),
exp: (time_now + chrono::TimeDelta::try_minutes(2).unwrap()).timestamp(), exp: (time_now + chrono::TimeDelta::try_minutes(2).unwrap()).timestamp(),
iss: SSO_JWT_ISSUER.to_string(), iss: SSO_JWT_ISSUER.to_string(),
sub: "vaultwarden".to_string(), sub: "vaultwarden".to_owned(),
}; };
auth::encode_jwt(&claims) auth::encode_jwt(&claims)
@ -171,12 +171,14 @@ fn decode_token_claims(token_name: &str, token: &str) -> ApiResult<BasicTokenCla
} }
pub fn decode_state(base64_state: &str) -> ApiResult<OIDCState> { pub fn decode_state(base64_state: &str) -> ApiResult<OIDCState> {
let state = match data_encoding::BASE64.decode(base64_state.as_bytes()) { let state = if let Ok(vec) = data_encoding::BASE64.decode(base64_state.as_bytes()) {
Ok(vec) => match String::from_utf8(vec) { if let Ok(valid) = String::from_utf8(vec) {
Ok(valid) => OIDCState(valid), OIDCState(valid)
Err(_) => err!(format!("Invalid utf8 chars in {base64_state} after base64 decoding")), } else {
}, err!(format!("Invalid utf8 chars in {base64_state} after base64 decoding"))
Err(_) => err!(format!("Failed to decode {base64_state} using base64")), }
} else {
err!(format!("Failed to decode {base64_state} using base64"))
}; };
Ok(state) Ok(state)
@ -193,12 +195,15 @@ pub async fn authorize_url(
) -> ApiResult<Url> { ) -> ApiResult<Url> {
let redirect_uri = match client_id { let redirect_uri = match client_id {
"web" | "browser" => format!("{}/sso-connector.html", CONFIG.domain()), "web" | "browser" => format!("{}/sso-connector.html", CONFIG.domain()),
"desktop" | "mobile" => "bitwarden://sso-callback".to_string(), "desktop" | "mobile" => "bitwarden://sso-callback".to_owned(),
"cli" => { "cli" => {
let port_regex = Regex::new(r"^http://localhost:([0-9]{4})$").unwrap(); let port_regex = Regex::new(r"^http://localhost:([0-9]{4})$").unwrap();
match port_regex.captures(raw_redirect_uri).and_then(|captures| captures.get(1).map(|c| c.as_str())) { if let Some(port) =
Some(port) => format!("http://localhost:{port}"), port_regex.captures(raw_redirect_uri).and_then(|captures| captures.get(1).map(|c| c.as_str()))
None => err!("Failed to extract port number"), {
format!("http://localhost:{port}")
} else {
err!("Failed to extract port number")
} }
} }
_ => err!(format!("Unsupported client {client_id}")), _ => err!(format!("Unsupported client {client_id}")),
@ -246,9 +251,8 @@ pub async fn exchange_code(
) -> ApiResult<(SsoAuth, OIDCAuthenticatedUser)> { ) -> ApiResult<(SsoAuth, OIDCAuthenticatedUser)> {
use openidconnect::OAuth2TokenResponse; use openidconnect::OAuth2TokenResponse;
let mut sso_auth = match SsoAuth::find_by_code(code, conn).await { let Some(mut sso_auth) = SsoAuth::find_by_code(code, conn).await else {
None => err!(format!("Invalid code cannot retrieve sso auth")), err!("Invalid code cannot retrieve sso auth")
Some(sso_auth) => sso_auth,
}; };
if let Some(authenticated_user) = sso_auth.auth_response.clone() { if let Some(authenticated_user) = sso_auth.auth_response.clone() {
@ -286,8 +290,8 @@ pub async fn exchange_code(
let user_name = id_claims.preferred_username().or(user_info.preferred_username()).map(|un| un.to_string()); let user_name = id_claims.preferred_username().or(user_info.preferred_username()).map(|un| un.to_string());
let refresh_token = token_response.refresh_token().map(|t| t.secret()); let refresh_token = token_response.refresh_token().map(openidconnect::RefreshToken::secret);
if refresh_token.is_none() && CONFIG.sso_scopes_vec().contains(&"offline_access".to_string()) { if refresh_token.is_none() && CONFIG.sso_scopes_vec().contains(&"offline_access".to_owned()) {
error!("Scope offline_access is present but response contain no refresh_token"); error!("Scope offline_access is present but response contain no refresh_token");
} }
@ -331,7 +335,9 @@ pub async fn redeem(
user_sso.save(conn).await?; user_sso.save(conn).await?;
} }
if !CONFIG.sso_auth_only_not_session() { if CONFIG.sso_auth_only_not_session() {
Ok(AuthTokens::new(device, user, AuthMethod::Sso, client_id))
} else {
let now = Utc::now(); let now = Utc::now();
let (ap_nbf, ap_exp) = let (ap_nbf, ap_exp) =
@ -344,9 +350,7 @@ pub async fn redeem(
let access_claims = let access_claims =
auth::LoginJwtClaims::new(device, user, ap_nbf, ap_exp, AuthMethod::Sso.scope_vec(), client_id, now); auth::LoginJwtClaims::new(device, user, ap_nbf, ap_exp, AuthMethod::Sso.scope_vec(), client_id, now);
_create_auth_tokens(device, auth_user.refresh_token, access_claims, auth_user.access_token) create_auth_tokens_impl(device, auth_user.refresh_token, access_claims, auth_user.access_token)
} else {
Ok(AuthTokens::new(device, user, AuthMethod::Sso, client_id))
} }
} }
@ -360,7 +364,9 @@ pub fn create_auth_tokens(
access_token: String, access_token: String,
expires_in: Option<Duration>, expires_in: Option<Duration>,
) -> ApiResult<AuthTokens> { ) -> ApiResult<AuthTokens> {
if !CONFIG.sso_auth_only_not_session() { if CONFIG.sso_auth_only_not_session() {
Ok(AuthTokens::new(device, user, AuthMethod::Sso, client_id))
} else {
let now = Utc::now(); let now = Utc::now();
let (ap_nbf, ap_exp) = match (decode_token_claims("access_token", &access_token), expires_in) { let (ap_nbf, ap_exp) = match (decode_token_claims("access_token", &access_token), expires_in) {
@ -372,13 +378,11 @@ pub fn create_auth_tokens(
let access_claims = let access_claims =
auth::LoginJwtClaims::new(device, user, ap_nbf, ap_exp, AuthMethod::Sso.scope_vec(), client_id, now); auth::LoginJwtClaims::new(device, user, ap_nbf, ap_exp, AuthMethod::Sso.scope_vec(), client_id, now);
_create_auth_tokens(device, refresh_token, access_claims, access_token) create_auth_tokens_impl(device, refresh_token, access_claims, access_token)
} else {
Ok(AuthTokens::new(device, user, AuthMethod::Sso, client_id))
} }
} }
fn _create_auth_tokens( fn create_auth_tokens_impl(
device: &Device, device: &Device,
refresh_token: Option<String>, refresh_token: Option<String>,
access_claims: auth::LoginJwtClaims, access_claims: auth::LoginJwtClaims,
@ -462,7 +466,7 @@ pub async fn exchange_refresh_token(
now, now,
); );
_create_auth_tokens(device, None, access_claims, access_token) create_auth_tokens_impl(device, None, access_claims, access_token)
} }
None => err!("No token present while in SSO"), None => err!("No token present while in SSO"),
} }

50
src/sso_client.rs

@ -1,18 +1,31 @@
use std::{borrow::Cow, future::Future, pin::Pin, sync::LazyLock, time::Duration}; use std::{borrow::Cow, future::Future, pin::Pin, sync::LazyLock, time::Duration};
use openidconnect::{core::*, *}; use openidconnect::{
AccessToken, AsyncHttpClient, AuthDisplay, AuthPrompt, AuthenticationFlow, AuthorizationCode, AuthorizationRequest,
ClientId, ClientSecret, CsrfToken, EmptyAdditionalClaims, EmptyExtraTokenFields, EndpointNotSet, EndpointSet,
HttpClientError, HttpRequest, HttpResponse, IdTokenClaims, IdTokenFields, Nonce, OAuth2TokenResponse,
PkceCodeChallenge, PkceCodeVerifier, RefreshToken, ResponseType, Scope, StandardErrorResponse,
StandardTokenResponse,
core::{
CoreAuthDisplay, CoreAuthPrompt, CoreClient, CoreErrorResponseType, CoreGenderClaim, CoreIdTokenVerifier,
CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, CoreJwsSigningAlgorithm, CoreProviderMetadata,
CoreResponseType, CoreRevocableToken, CoreRevocationErrorResponse, CoreTokenIntrospectionResponse,
CoreTokenResponse, CoreTokenType, CoreUserInfoClaims,
},
http, url,
};
use regex::Regex; use regex::Regex;
use url::Url; use url::Url;
use crate::{ use crate::{
CONFIG,
api::{ApiResult, EmptyResult}, api::{ApiResult, EmptyResult},
db::models::SsoAuth, db::models::SsoAuth,
http_client::get_reqwest_client_builder, http_client::get_reqwest_client_builder,
sso::{OIDCCode, OIDCCodeChallenge, OIDCCodeVerifier, OIDCState}, sso::{OIDCCode, OIDCCodeChallenge, OIDCCodeVerifier, OIDCState},
CONFIG,
}; };
static CLIENT_CACHE_KEY: LazyLock<String> = LazyLock::new(|| "sso-client".to_string()); static CLIENT_CACHE_KEY: LazyLock<String> = LazyLock::new(|| "sso-client".to_owned());
static CLIENT_CACHE: LazyLock<moka::sync::Cache<String, Client>> = LazyLock::new(|| { static CLIENT_CACHE: LazyLock<moka::sync::Cache<String, Client>> = LazyLock::new(|| {
moka::sync::Cache::builder() moka::sync::Cache::builder()
.max_capacity(1) .max_capacity(1)
@ -85,7 +98,7 @@ impl<'c> AsyncHttpClient<'c> for OidcHttpClient {
impl Client { impl Client {
// Call the OpenId discovery endpoint to retrieve configuration // Call the OpenId discovery endpoint to retrieve configuration
async fn _get_client() -> ApiResult<Self> { async fn get_client() -> ApiResult<Self> {
let client_id = ClientId::new(CONFIG.sso_client_id()); let client_id = ClientId::new(CONFIG.sso_client_id());
let client_secret = ClientSecret::new(CONFIG.sso_client_secret()); let client_secret = ClientSecret::new(CONFIG.sso_client_secret());
@ -103,14 +116,16 @@ impl Client {
let base_client = CoreClient::from_provider_metadata(provider_metadata, client_id, Some(client_secret)); let base_client = CoreClient::from_provider_metadata(provider_metadata, client_id, Some(client_secret));
let token_uri = match base_client.token_uri() { let token_uri = if let Some(uri) = base_client.token_uri() {
Some(uri) => uri.clone(), uri.clone()
None => err!("Failed to discover token_url, cannot proceed"), } else {
err!("Failed to discover token_url, cannot proceed")
}; };
let user_info_url = match base_client.user_info_url() { let user_info_url = if let Some(url) = base_client.user_info_url() {
Some(url) => url.clone(), url.clone()
None => err!("Failed to discover user_info url, cannot proceed"), } else {
err!("Failed to discover user_info url, cannot proceed")
}; };
let core_client = base_client let core_client = base_client
@ -129,13 +144,13 @@ impl Client {
if CONFIG.sso_client_cache_expiration() > 0 { if CONFIG.sso_client_cache_expiration() > 0 {
match CLIENT_CACHE.get(&*CLIENT_CACHE_KEY) { match CLIENT_CACHE.get(&*CLIENT_CACHE_KEY) {
Some(client) => Ok(client), Some(client) => Ok(client),
None => Self::_get_client().await.inspect(|client| { None => Self::get_client().await.inspect(|client| {
debug!("Inserting new client in cache"); debug!("Inserting new client in cache");
CLIENT_CACHE.insert(CLIENT_CACHE_KEY.clone(), client.clone()); CLIENT_CACHE.insert(CLIENT_CACHE_KEY.clone(), client.clone());
}), }),
} }
} else { } else {
Self::_get_client().await Self::get_client().await
} }
} }
@ -214,15 +229,14 @@ impl Client {
Ok(token_response) => { Ok(token_response) => {
let oidc_nonce = Nonce::new(sso_auth.nonce.clone()); let oidc_nonce = Nonce::new(sso_auth.nonce.clone());
let id_token = match token_response.extra_fields().id_token() { let Some(id_token) = token_response.extra_fields().id_token() else {
None => err!("Token response did not contain an id_token"), err!("Token response did not contain an id_token")
Some(token) => token,
}; };
if CONFIG.sso_debug_tokens() { if CONFIG.sso_debug_tokens() {
debug!("Id token: {}", id_token.to_string()); debug!("Id token: {}", id_token.to_string());
debug!("Access token: {}", token_response.access_token().secret()); debug!("Access token: {}", token_response.access_token().secret());
debug!("Refresh token: {:?}", token_response.refresh_token().map(|t| t.secret())); debug!("Refresh token: {:?}", token_response.refresh_token().map(RefreshToken::secret));
debug!("Expiration time: {:?}", token_response.expires_in()); debug!("Expiration time: {:?}", token_response.expires_in());
} }
@ -275,12 +289,12 @@ impl Client {
let client = Client::cached().await?; let client = Client::cached().await?;
REFRESH_CACHE REFRESH_CACHE
.get_with(refresh_token.clone(), async move { client._exchange_refresh_token(refresh_token).await }) .get_with(refresh_token.clone(), async move { client.exchange_refresh_token_impl(refresh_token).await })
.await .await
.map_err(Into::into) .map_err(Into::into)
} }
async fn _exchange_refresh_token(&self, refresh_token: String) -> Result<RefreshTokenResponse, String> { async fn exchange_refresh_token_impl(&self, refresh_token: String) -> Result<RefreshTokenResponse, String> {
let rt = RefreshToken::new(refresh_token); let rt = RefreshToken::new(refresh_token);
match self.core_client.exchange_refresh_token(&rt).request_async(&self.http_client).await { match self.core_client.exchange_refresh_token(&rt).request_async(&self.http_client).await {

30
src/storage.rs

@ -9,9 +9,9 @@ pub(crate) fn join_path(base: &str, child: &str) -> String {
let base = base.trim_end_matches('/'); let base = base.trim_end_matches('/');
let child = child.trim_start_matches('/'); let child = child.trim_start_matches('/');
if base.is_empty() { if base.is_empty() {
child.to_string() child.to_owned()
} else if child.is_empty() { } else if child.is_empty() {
base.to_string() base.to_owned()
} else { } else {
format!("{base}/{child}") format!("{base}/{child}")
} }
@ -34,7 +34,7 @@ pub(crate) fn parent(path: &str) -> Option<String> {
return s3::parent(path); return s3::parent(path);
} }
std::path::Path::new(path).parent()?.to_str().map(ToString::to_string) std::path::Path::new(path).parent()?.to_str().map(str::to_owned)
} }
pub(crate) fn file_name(path: &str) -> Option<String> { pub(crate) fn file_name(path: &str) -> Option<String> {
@ -43,7 +43,7 @@ pub(crate) fn file_name(path: &str) -> Option<String> {
return s3::file_name(path); return s3::file_name(path);
} }
std::path::Path::new(path).file_name()?.to_str().map(ToString::to_string) std::path::Path::new(path).file_name()?.to_str().map(str::to_owned)
} }
pub(crate) fn is_fs_operator(operator: &opendal::Operator) -> bool { pub(crate) fn is_fs_operator(operator: &opendal::Operator) -> bool {
@ -70,7 +70,7 @@ pub(crate) fn operator_for_path(path: &str) -> Result<opendal::Operator, crate::
opendal::Operator::new(builder)?.finish() opendal::Operator::new(builder)?.finish()
}; };
OPERATORS_BY_PATH.insert(path.to_string(), operator.clone()); OPERATORS_BY_PATH.insert(path.to_owned(), operator.clone());
Ok(operator) Ok(operator)
} }
@ -88,7 +88,7 @@ mod s3 {
pub(super) fn join_path(base: &str, child: &str) -> String { pub(super) fn join_path(base: &str, child: &str) -> String {
if let Ok(mut url) = Url::parse(base) { if let Ok(mut url) = Url::parse(base) {
let mut segments = path_segments(&url); let mut segments = path_segments(&url);
segments.extend(child.split('/').filter(|segment| !segment.is_empty()).map(ToString::to_string)); segments.extend(child.split('/').filter(|segment| !segment.is_empty()).map(str::to_owned));
set_path_segments(&mut url, &segments); set_path_segments(&mut url, &segments);
return url.to_string(); return url.to_string();
} }
@ -96,9 +96,9 @@ mod s3 {
let base = base.trim_end_matches('/'); let base = base.trim_end_matches('/');
let child = child.trim_start_matches('/'); let child = child.trim_start_matches('/');
if base.is_empty() { if base.is_empty() {
child.to_string() child.to_owned()
} else if child.is_empty() { } else if child.is_empty() {
base.to_string() base.to_owned()
} else { } else {
format!("{base}/{child}") format!("{base}/{child}")
} }
@ -126,7 +126,7 @@ mod s3 {
return Some(url.to_string()); return Some(url.to_string());
} }
std::path::Path::new(path).parent()?.to_str().map(ToString::to_string) std::path::Path::new(path).parent()?.to_str().map(str::to_owned)
} }
pub(super) fn file_name(path: &str) -> Option<String> { pub(super) fn file_name(path: &str) -> Option<String> {
@ -134,12 +134,12 @@ mod s3 {
return path_segments(&url).pop(); return path_segments(&url).pop();
} }
std::path::Path::new(path).file_name()?.to_str().map(ToString::to_string) std::path::Path::new(path).file_name()?.to_str().map(str::to_owned)
} }
fn path_segments(url: &Url) -> Vec<String> { fn path_segments(url: &Url) -> Vec<String> {
url.path_segments() url.path_segments()
.map(|segments| segments.filter(|segment| !segment.is_empty()).map(ToString::to_string).collect()) .map(|segments| segments.filter(|segment| !segment.is_empty()).map(str::to_owned).collect())
.unwrap_or_default() .unwrap_or_default()
} }
@ -206,9 +206,9 @@ mod s3 {
}; };
Ok(Some(Credential { Ok(Some(Credential {
access_key_id: creds.access_key_id().to_string(), access_key_id: creds.access_key_id().to_owned(),
secret_access_key: creds.secret_access_key().to_string(), secret_access_key: creds.secret_access_key().to_owned(),
session_token: creds.session_token().map(|s| s.to_string()), session_token: creds.session_token().map(ToOwned::to_owned),
expires_in, expires_in,
})) }))
} }
@ -218,7 +218,7 @@ mod s3 {
let mut config = opendal::services::S3Config::from_uri(&uri)?; let mut config = opendal::services::S3Config::from_uri(&uri)?;
if !uri_has_option(&uri, &["default_storage_class"]) { if !uri_has_option(&uri, &["default_storage_class"]) {
config.default_storage_class = Some("INTELLIGENT_TIERING".to_string()); config.default_storage_class = Some("INTELLIGENT_TIERING".to_owned());
} }
if !uri_has_option( if !uri_has_option(

107
src/util.rs

@ -1,24 +1,28 @@
// //
// Web Headers and caching // Web Headers and caching
// //
use std::{collections::HashMap, io::Cursor, path::Path}; use std::{collections::HashMap, env, fmt, io::Cursor, path::Path, str::FromStr};
use chrono::{DateTime, Local, NaiveDateTime, TimeZone};
use num_traits::ToPrimitive; use num_traits::ToPrimitive;
use tokio::{
runtime::Handle,
time::{Duration, sleep},
};
use serde::de::{self, DeserializeOwned, Deserializer, MapAccess, SeqAccess, Visitor};
use serde_json::Value;
use rocket::{ use rocket::{
Data, Orbit, Request, Response, Rocket,
fairing::{Fairing, Info, Kind}, fairing::{Fairing, Info, Kind},
http::{ContentType, Header, HeaderMap, Method, Status}, http::{ContentType, Header, HeaderMap, Method, Status},
response::{self, Responder}, response::{self, Responder},
Data, Orbit, Request, Response, Rocket,
};
use tokio::{
runtime::Handle,
time::{sleep, Duration},
}; };
use crate::{ use crate::{
config::{PathType, SUPPORTED_FEATURE_FLAGS},
CONFIG, CONFIG,
config::{PathType, SUPPORTED_FEATURE_FLAGS},
}; };
pub struct AppHeaders(); pub struct AppHeaders();
@ -75,11 +79,16 @@ impl Fairing for AppHeaders {
// Do not send the Content-Security-Policy (CSP) Header and X-Frame-Options for the *-connector.html files. // Do not send the Content-Security-Policy (CSP) Header and X-Frame-Options for the *-connector.html files.
// This can cause issues when some MFA requests needs to open a popup or page within the clients like WebAuthn, or Duo. // This can cause issues when some MFA requests needs to open a popup or page within the clients like WebAuthn, or Duo.
// This is the same behavior as upstream Bitwarden. // This is the same behavior as upstream Bitwarden.
if !req_uri_path.ends_with("connector.html") { if req_uri_path.ends_with("connector.html") {
// It looks like this header get's set somewhere else also, make sure this is not sent for these files, it will cause MFA issues.
res.remove_header("X-Frame-Options");
} else {
let csp = if is_image { let csp = if is_image {
// Prevent scripts, frames, objects, etc., from loading with images, mainly for SVG images, since these could contain JavaScript and other unsafe items. // Prevent scripts, frames, objects, etc., from loading with images, mainly for SVG images, since these could contain JavaScript and other unsafe items.
// Even though we sanitize SVG images before storing and viewing them, it's better to prevent allowing these elements. // Even though we sanitize SVG images before storing and viewing them, it's better to prevent allowing these elements.
String::from("default-src 'none'; img-src 'self' data:; style-src 'unsafe-inline'; script-src 'none'; frame-src 'none'; object-src 'none") String::from(
"default-src 'none'; img-src 'self' data:; style-src 'unsafe-inline'; script-src 'none'; frame-src 'none'; object-src 'none",
)
} else { } else {
// # Frame Ancestors: // # Frame Ancestors:
// Chrome Web Store: https://chrome.google.com/webstore/detail/bitwarden-free-password-m/nngceckbapebfimnlniiiahkandclblb // Chrome Web Store: https://chrome.google.com/webstore/detail/bitwarden-free-password-m/nngceckbapebfimnlniiiahkandclblb
@ -129,9 +138,6 @@ impl Fairing for AppHeaders {
res.set_raw_header("Content-Security-Policy", csp); res.set_raw_header("Content-Security-Policy", csp);
res.set_raw_header("X-Frame-Options", "SAMEORIGIN"); res.set_raw_header("X-Frame-Options", "SAMEORIGIN");
} else {
// It looks like this header get's set somewhere else also, make sure this is not sent for these files, it will cause MFA issues.
res.remove_header("X-Frame-Options");
} }
// Disable cache unless otherwise specified // Disable cache unless otherwise specified
@ -146,7 +152,7 @@ pub struct Cors();
impl Cors { impl Cors {
fn get_header(headers: &HeaderMap<'_>, name: &str) -> String { fn get_header(headers: &HeaderMap<'_>, name: &str) -> String {
match headers.get_one(name) { match headers.get_one(name) {
Some(h) => h.to_string(), Some(h) => h.to_owned(),
_ => String::new(), _ => String::new(),
} }
} }
@ -212,7 +218,7 @@ impl<R> Cached<R> {
Self { Self {
response, response,
is_immutable, is_immutable,
ttl: 604800, // 7 days ttl: 604_800, // 7 days
} }
} }
@ -286,7 +292,7 @@ impl Fairing for BetterLogging {
} else { } else {
"http" "http"
}; };
let addr = format!("{scheme}://{}:{}", &config.address, &config.port); let addr = format!("{scheme}://{}:{}", config.address, config.port);
info!(target: "start", "Rocket has launched from {addr}"); info!(target: "start", "Rocket has launched from {addr}");
} }
@ -303,7 +309,7 @@ impl Fairing for BetterLogging {
match uri.query() { match uri.query() {
Some(q) => info!(target: "request", "{method} {uri_path_str}?{}", &q[..q.len().min(30)]), Some(q) => info!(target: "request", "{method} {uri_path_str}?{}", &q[..q.len().min(30)]),
None => info!(target: "request", "{method} {uri_path_str}"), None => info!(target: "request", "{method} {uri_path_str}"),
}; }
} }
} }
@ -316,10 +322,10 @@ impl Fairing for BetterLogging {
let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str); let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str);
if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) { if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) {
let status = response.status(); let status = response.status();
if let Some(ref route) = request.route() { if let Some(route) = request.route() {
info!(target: "response", "{route} => {status}") info!(target: "response", "{route} => {status}");
} else { } else {
info!(target: "response", "{status}") info!(target: "response", "{status}");
} }
} }
} }
@ -354,9 +360,6 @@ pub fn get_uuid() -> String {
// //
// String util methods // String util methods
// //
use std::str::FromStr;
#[inline] #[inline]
pub fn upcase_first(s: &str) -> String { pub fn upcase_first(s: &str) -> String {
let mut c = s.chars(); let mut c = s.chars();
@ -390,9 +393,6 @@ where
// //
// Env methods // Env methods
// //
use std::env;
pub fn get_env_str_value(key: &str) -> Option<String> { pub fn get_env_str_value(key: &str) -> Option<String> {
let key_file = format!("{key}_FILE"); let key_file = format!("{key}_FILE");
let value_from_env = env::var(key); let value_from_env = env::var(key);
@ -402,7 +402,7 @@ pub fn get_env_str_value(key: &str) -> Option<String> {
(Ok(_), Ok(_)) => panic!("You should not define both {key} and {key_file}!"), (Ok(_), Ok(_)) => panic!("You should not define both {key} and {key_file}!"),
(Ok(v_env), Err(_)) => Some(v_env), (Ok(v_env), Err(_)) => Some(v_env),
(Err(_), Ok(v_file)) => match std::fs::read_to_string(v_file) { (Err(_), Ok(v_file)) => match std::fs::read_to_string(v_file) {
Ok(content) => Some(content.trim().to_string()), Ok(content) => Some(content.trim().to_owned()),
Err(e) => panic!("Failed to load {key}: {e:?}"), Err(e) => panic!("Failed to load {key}: {e:?}"),
}, },
_ => None, _ => None,
@ -431,8 +431,6 @@ pub fn get_env_bool(key: &str) -> Option<bool> {
// Date util methods // Date util methods
// //
use chrono::{DateTime, Local, NaiveDateTime, TimeZone};
/// Formats a UTC-offset `NaiveDateTime` in the format used by Bitwarden API /// Formats a UTC-offset `NaiveDateTime` in the format used by Bitwarden API
/// responses with "date" fields (`CreationDate`, `RevisionDate`, etc.). /// responses with "date" fields (`CreationDate`, `RevisionDate`, etc.).
pub fn format_date(dt: &NaiveDateTime) -> String { pub fn format_date(dt: &NaiveDateTime) -> String {
@ -457,11 +455,11 @@ pub fn validate_and_format_date(dt: &str) -> String {
pub fn format_datetime_local(dt: &DateTime<Local>, fmt: &str) -> String { pub fn format_datetime_local(dt: &DateTime<Local>, fmt: &str) -> String {
// Try parsing the `TZ` environment variable to enable formatting `%Z` as // Try parsing the `TZ` environment variable to enable formatting `%Z` as
// a time zone abbreviation. // a time zone abbreviation.
if let Ok(tz) = env::var("TZ") { if let Ok(tz) = env::var("TZ")
if let Ok(tz) = tz.parse::<chrono_tz::Tz>() { && let Ok(tz) = tz.parse::<chrono_tz::Tz>()
{
return dt.with_timezone(&tz).format(fmt).to_string(); return dt.with_timezone(&tz).format(fmt).to_string();
} }
}
// Otherwise, fall back to formatting `%Z` as a UTC offset. // Otherwise, fall back to formatting `%Z` as a UTC offset.
dt.format(fmt).to_string() dt.format(fmt).to_string()
@ -512,6 +510,7 @@ pub fn is_valid_email(email: &str) -> bool {
// //
/// Returns true if the program is running in Docker, Podman or Kubernetes. /// Returns true if the program is running in Docker, Podman or Kubernetes.
#[must_use]
pub fn is_running_in_container() -> bool { pub fn is_running_in_container() -> bool {
Path::new("/.dockerenv").exists() Path::new("/.dockerenv").exists()
|| Path::new("/run/.containerenv").exists() || Path::new("/run/.containerenv").exists()
@ -543,12 +542,12 @@ pub fn get_active_web_release() -> String {
]; ];
for version_file in version_files { for version_file in version_files {
if let Ok(version_str) = std::fs::read_to_string(&version_file) { if let Ok(version_str) = std::fs::read_to_string(&version_file)
if let Ok(version) = serde_json::from_str::<WebVaultVersion>(&version_str) { && let Ok(version) = serde_json::from_str::<WebVaultVersion>(&version_str)
{
return String::from(version.version.trim_start_matches('v')); return String::from(version.version.trim_start_matches('v'));
} }
} }
}
String::from("Version file missing") String::from("Version file missing")
} }
@ -556,12 +555,6 @@ pub fn get_active_web_release() -> String {
// //
// Deserialization methods // Deserialization methods
// //
use std::fmt;
use serde::de::{self, DeserializeOwned, Deserializer, MapAccess, SeqAccess, Visitor};
use serde_json::Value;
pub type JsonMap = serde_json::Map<String, Value>; pub type JsonMap = serde_json::Map<String, Value>;
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
@ -605,7 +598,7 @@ impl<'de> Visitor<'de> for LowerCaseVisitor {
let mut result_map = JsonMap::new(); let mut result_map = JsonMap::new();
while let Some((key, value)) = map.next_entry()? { while let Some((key, value)) = map.next_entry()? {
result_map.insert(_process_key(key), convert_json_key_lcase_first(value)); result_map.insert(process_json_key(key), convert_json_key_lcase_first(value));
} }
Ok(Value::Object(result_map)) Ok(Value::Object(result_map))
@ -627,7 +620,7 @@ impl<'de> Visitor<'de> for LowerCaseVisitor {
// Inner function to handle a special case for the 'ssn' key. // Inner function to handle a special case for the 'ssn' key.
// This key is part of the Identity Cipher (Social Security Number) // This key is part of the Identity Cipher (Social Security Number)
fn _process_key(key: &str) -> String { fn process_json_key(key: &str) -> String {
match key.to_lowercase().as_ref() { match key.to_lowercase().as_ref() {
"ssn" => "ssn".into(), "ssn" => "ssn".into(),
_ => lcase_first(key), _ => lcase_first(key),
@ -664,21 +657,24 @@ impl NumberOrString {
} }
} }
#[allow(clippy::wrong_self_convention)] #[expect(clippy::wrong_self_convention)]
pub fn into_i32(&self) -> Result<i32, crate::Error> { pub fn into_i32(&self) -> Result<i32, crate::Error> {
use std::num::ParseIntError as PIE; use std::num::ParseIntError as PIE;
match self { match self {
NumberOrString::Number(n) => match n.to_i32() { NumberOrString::Number(n) => {
Some(n) => Ok(n), if let Some(n) = n.to_i32() {
None => err!("Number does not fit in i32"), Ok(n)
}, } else {
err!("Number does not fit in i32")
}
}
NumberOrString::String(s) => { NumberOrString::String(s) => {
s.parse().map_err(|e: PIE| crate::Error::new("Can't convert to number", e.to_string())) s.parse().map_err(|e: PIE| crate::Error::new("Can't convert to number", e.to_string()))
} }
} }
} }
#[allow(clippy::wrong_self_convention)] #[expect(clippy::wrong_self_convention)]
pub fn into_i64(&self) -> Result<i64, crate::Error> { pub fn into_i64(&self) -> Result<i64, crate::Error> {
use std::num::ParseIntError as PIE; use std::num::ParseIntError as PIE;
match self { match self {
@ -753,11 +749,11 @@ pub fn convert_json_key_lcase_first(src_json: Value) -> Value {
Value::Object(obj) => { Value::Object(obj) => {
let mut json_map = JsonMap::new(); let mut json_map = JsonMap::new();
for (key, value) in obj.into_iter() { for (key, value) in obj {
match (key, value) { match (key, value) {
(key, Value::Object(elm)) => { (key, Value::Object(elm)) => {
let inner_value = convert_json_key_lcase_first(Value::Object(elm)); let inner_value = convert_json_key_lcase_first(Value::Object(elm));
json_map.insert(_process_key(&key), inner_value); json_map.insert(process_json_key(&key), inner_value);
} }
(key, Value::Array(elm)) => { (key, Value::Array(elm)) => {
@ -767,11 +763,11 @@ pub fn convert_json_key_lcase_first(src_json: Value) -> Value {
inner_array.push(convert_json_key_lcase_first(inner_obj)); inner_array.push(convert_json_key_lcase_first(inner_obj));
} }
json_map.insert(_process_key(&key), Value::Array(inner_array)); json_map.insert(process_json_key(&key), Value::Array(inner_array));
} }
(key, value) => { (key, value) => {
json_map.insert(_process_key(&key), value); json_map.insert(process_json_key(&key), value);
} }
} }
} }
@ -793,7 +789,7 @@ pub enum FeatureFlagFilter {
/// Parses the experimental client feature flags string into a HashMap. /// Parses the experimental client feature flags string into a HashMap.
pub fn parse_experimental_client_feature_flags( pub fn parse_experimental_client_feature_flags(
experimental_client_feature_flags: &str, experimental_client_feature_flags: &str,
filter_mode: FeatureFlagFilter, filter_mode: &FeatureFlagFilter,
) -> HashMap<String, bool> { ) -> HashMap<String, bool> {
experimental_client_feature_flags experimental_client_feature_flags
.split(',') .split(',')
@ -811,7 +807,8 @@ pub fn parse_experimental_client_feature_flags(
/// TODO: This is extracted from IpAddr::is_global, which is unstable: /// TODO: This is extracted from IpAddr::is_global, which is unstable:
/// https://doc.rust-lang.org/nightly/std/net/enum.IpAddr.html#method.is_global /// https://doc.rust-lang.org/nightly/std/net/enum.IpAddr.html#method.is_global
/// Remove once https://github.com/rust-lang/rust/issues/27709 is merged /// Remove once https://github.com/rust-lang/rust/issues/27709 is merged
#[allow(clippy::nonminimal_bool)] // #[expect(clippy::nonminimal_bool, reason = "Mostly copy/paste from std, keep as-is")]
#[expect(clippy::decimal_bitwise_operands, reason = "Mostly copy/paste from std, keep as-is")]
#[cfg(any(not(feature = "unstable"), test))] #[cfg(any(not(feature = "unstable"), test))]
pub fn is_global_hardcoded(ip: std::net::IpAddr) -> bool { pub fn is_global_hardcoded(ip: std::net::IpAddr) -> bool {
match ip { match ip {

Loading…
Cancel
Save