Browse Source

Use Diesels MultiConnections Derive

With this PR we remove almost all custom macro's to create the multiple database type code. This is now handled by Diesel it self.

This removed the need of the following functions/macro's:
 - `db_object!`
 - `::to_db`
 - `.from_db()`

It is also possible to just use one schema instead of multiple per type.

Also done:
 - Refactored the SQLite backup function
 - Some formatting of queries so every call is one a separate line, this looks a bit better
 - Declare `conn` as mut inside each `db_run!` instead of having to declare it as `mut` in functions or calls
 - Added an `ACTIVE_DB_TYPE` static which holds the currently active database type
 - Removed `diesel_logger` crate and use Diesel's `set_default_instrumentation()`
   If you want debug queries you can now simply change the log level of `vaultwarden::db::query_logger`
 - Use PostgreSQL v17 in the Alpine images to match the Debian Trixie version
 - Optimized the Workflows since `diesel_logger` isn't needed anymore

And on the extra plus-side, this lowers the compile-time and binary size too.

Signed-off-by: BlackDex <black.dex@gmail.com>
pull/6279/head
BlackDex 1 month ago
parent
commit
03aa7e5090
No known key found for this signature in database GPG Key ID: 58C80A2AA6C765E1
  1. 22
      .github/workflows/build.yml
  2. 2
      .github/workflows/hadolint.yml
  3. 8
      .github/workflows/release.yml
  4. 4
      .github/workflows/trivy.yml
  5. 2
      .github/workflows/zizmor.yml
  6. 739
      Cargo.lock
  7. 38
      Cargo.toml
  8. 6
      build.rs
  9. 6
      docker/Dockerfile.alpine
  10. 6
      docker/Dockerfile.j2
  11. 2
      macros/Cargo.toml
  12. 189
      src/api/admin.rs
  13. 256
      src/api/core/accounts.rs
  14. 345
      src/api/core/ciphers.rs
  15. 172
      src/api/core/emergency_access.rs
  16. 41
      src/api/core/events.rs
  17. 35
      src/api/core/folders.rs
  18. 18
      src/api/core/mod.rs
  19. 659
      src/api/core/organizations.rs
  20. 53
      src/api/core/public.rs
  21. 129
      src/api/core/sends.rs
  22. 38
      src/api/core/two_factor/authenticator.rs
  23. 24
      src/api/core/two_factor/duo.rs
  24. 10
      src/api/core/two_factor/duo_oidc.rs
  25. 42
      src/api/core/two_factor/email.rs
  26. 56
      src/api/core/two_factor/mod.rs
  27. 15
      src/api/core/two_factor/protected_actions.rs
  28. 49
      src/api/core/two_factor/webauthn.rs
  29. 18
      src/api/core/two_factor/yubikey.rs
  30. 59
      src/api/identity.rs
  31. 2
      src/api/mod.rs
  32. 20
      src/api/notifications.rs
  33. 26
      src/api/push.rs
  34. 20
      src/auth.rs
  35. 12
      src/config.rs
  36. 379
      src/db/mod.rs
  37. 50
      src/db/models/attachment.rs
  38. 44
      src/db/models/auth_request.rs
  39. 133
      src/db/models/cipher.rs
  40. 128
      src/db/models/collection.rs
  41. 70
      src/db/models/device.rs
  42. 91
      src/db/models/emergency_access.rs
  43. 39
      src/db/models/event.rs
  44. 18
      src/db/models/favorite.rs
  45. 53
      src/db/models/folder.rs
  46. 95
      src/db/models/group.rs
  47. 56
      src/db/models/org_policy.rs
  48. 247
      src/db/models/organization.rs
  49. 56
      src/db/models/send.rs
  50. 16
      src/db/models/sso_nonce.rs
  51. 39
      src/db/models/two_factor.rs
  52. 44
      src/db/models/two_factor_duo_context.rs
  53. 32
      src/db/models/two_factor_incomplete.rs
  54. 89
      src/db/models/user.rs
  55. 57
      src/db/query_logger.rs
  56. 0
      src/db/schema.rs
  57. 395
      src/db/schemas/mysql/schema.rs
  58. 395
      src/db/schemas/sqlite/schema.rs
  59. 50
      src/main.rs
  60. 15
      src/sso.rs

22
.github/workflows/build.yml

@ -69,9 +69,9 @@ jobs:
CHANNEL: ${{ matrix.channel }}
run: |
if [[ "${CHANNEL}" == 'rust-toolchain' ]]; then
RUST_TOOLCHAIN="$(grep -oP 'channel.*"(\K.*?)(?=")' rust-toolchain.toml)"
RUST_TOOLCHAIN="$(grep -m1 -oP 'channel.*"(\K.*?)(?=")' rust-toolchain.toml)"
elif [[ "${CHANNEL}" == 'msrv' ]]; then
RUST_TOOLCHAIN="$(grep -oP 'rust-version.*"(\K.*?)(?=")' Cargo.toml)"
RUST_TOOLCHAIN="$(grep -m1 -oP 'rust-version\s.*"(\K.*?)(?=")' Cargo.toml)"
else
RUST_TOOLCHAIN="${CHANNEL}"
fi
@ -116,7 +116,7 @@ jobs:
# Enable Rust Caching
- name: Rust Caching
uses: Swatinem/rust-cache@98c8021b550208e191a6a3145459bfc9fb29c4c0 # v2.8.0
uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1
with:
# Use a custom prefix-key to force a fresh start. This is sometimes needed with bigger changes.
# Like changing the build host from Ubuntu 20.04 to 22.04 for example.
@ -126,18 +126,6 @@ jobs:
# Run cargo tests
# First test all features together, afterwards test them separately.
- name: "test features: sqlite,mysql,postgresql,enable_mimalloc,query_logger"
id: test_sqlite_mysql_postgresql_mimalloc_logger
if: ${{ !cancelled() }}
run: |
cargo test --features sqlite,mysql,postgresql,enable_mimalloc,query_logger
- name: "test features: sqlite,mysql,postgresql,enable_mimalloc"
id: test_sqlite_mysql_postgresql_mimalloc
if: ${{ !cancelled() }}
run: |
cargo test --features sqlite,mysql,postgresql,enable_mimalloc
- name: "test features: sqlite,mysql,postgresql"
id: test_sqlite_mysql_postgresql
if: ${{ !cancelled() }}
@ -187,8 +175,6 @@ jobs:
- name: "Some checks failed"
if: ${{ failure() }}
env:
TEST_DB_M_L: ${{ steps.test_sqlite_mysql_postgresql_mimalloc_logger.outcome }}
TEST_DB_M: ${{ steps.test_sqlite_mysql_postgresql_mimalloc.outcome }}
TEST_DB: ${{ steps.test_sqlite_mysql_postgresql.outcome }}
TEST_SQLITE: ${{ steps.test_sqlite.outcome }}
TEST_MYSQL: ${{ steps.test_mysql.outcome }}
@ -200,8 +186,6 @@ jobs:
echo "" >> "${GITHUB_STEP_SUMMARY}"
echo "|Job|Status|" >> "${GITHUB_STEP_SUMMARY}"
echo "|---|------|" >> "${GITHUB_STEP_SUMMARY}"
echo "|test (sqlite,mysql,postgresql,enable_mimalloc,query_logger)|${TEST_DB_M_L}|" >> "${GITHUB_STEP_SUMMARY}"
echo "|test (sqlite,mysql,postgresql,enable_mimalloc)|${TEST_DB_M}|" >> "${GITHUB_STEP_SUMMARY}"
echo "|test (sqlite,mysql,postgresql)|${TEST_DB}|" >> "${GITHUB_STEP_SUMMARY}"
echo "|test (sqlite)|${TEST_SQLITE}|" >> "${GITHUB_STEP_SUMMARY}"
echo "|test (mysql)|${TEST_MYSQL}|" >> "${GITHUB_STEP_SUMMARY}"

2
.github/workflows/hadolint.yml

@ -31,7 +31,7 @@ jobs:
sudo curl -L https://github.com/hadolint/hadolint/releases/download/v${HADOLINT_VERSION}/hadolint-$(uname -s)-$(uname -m) -o /usr/local/bin/hadolint && \
sudo chmod +x /usr/local/bin/hadolint
env:
HADOLINT_VERSION: 2.12.0
HADOLINT_VERSION: 2.13.1
# End Download hadolint
# Checkout the repo
- name: Checkout

8
.github/workflows/release.yml

@ -204,7 +204,7 @@ jobs:
# Attest container images
- name: Attest - docker.io - ${{ matrix.base_image }}
if: ${{ env.HAVE_DOCKERHUB_LOGIN == 'true' && steps.bake_vw.outputs.metadata != ''}}
uses: actions/attest-build-provenance@e8998f949152b193b063cb0ec769d69d929409be # v2.4.0
uses: actions/attest-build-provenance@977bb373ede98d70efdf65b84cb5f73e068dcc2a # v3.0.0
with:
subject-name: ${{ vars.DOCKERHUB_REPO }}
subject-digest: ${{ env.DIGEST_SHA }}
@ -212,7 +212,7 @@ jobs:
- name: Attest - ghcr.io - ${{ matrix.base_image }}
if: ${{ env.HAVE_GHCR_LOGIN == 'true' && steps.bake_vw.outputs.metadata != ''}}
uses: actions/attest-build-provenance@e8998f949152b193b063cb0ec769d69d929409be # v2.4.0
uses: actions/attest-build-provenance@977bb373ede98d70efdf65b84cb5f73e068dcc2a # v3.0.0
with:
subject-name: ${{ vars.GHCR_REPO }}
subject-digest: ${{ env.DIGEST_SHA }}
@ -220,7 +220,7 @@ jobs:
- name: Attest - quay.io - ${{ matrix.base_image }}
if: ${{ env.HAVE_QUAY_LOGIN == 'true' && steps.bake_vw.outputs.metadata != ''}}
uses: actions/attest-build-provenance@e8998f949152b193b063cb0ec769d69d929409be # v2.4.0
uses: actions/attest-build-provenance@977bb373ede98d70efdf65b84cb5f73e068dcc2a # v3.0.0
with:
subject-name: ${{ vars.QUAY_REPO }}
subject-digest: ${{ env.DIGEST_SHA }}
@ -299,7 +299,7 @@ jobs:
path: vaultwarden-armv6-${{ matrix.base_image }}
- name: "Attest artifacts ${{ matrix.base_image }}"
uses: actions/attest-build-provenance@e8998f949152b193b063cb0ec769d69d929409be # v2.4.0
uses: actions/attest-build-provenance@977bb373ede98d70efdf65b84cb5f73e068dcc2a # v3.0.0
with:
subject-path: vaultwarden-*
# End Upload artifacts to Github Actions

4
.github/workflows/trivy.yml

@ -36,7 +36,7 @@ jobs:
persist-credentials: false
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # v0.33.0 + b6643a2
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # 0.33.1
env:
TRIVY_DB_REPOSITORY: docker.io/aquasec/trivy-db:2,public.ecr.aws/aquasecurity/trivy-db:2,ghcr.io/aquasecurity/trivy-db:2
TRIVY_JAVA_DB_REPOSITORY: docker.io/aquasec/trivy-java-db:1,public.ecr.aws/aquasecurity/trivy-java-db:1,ghcr.io/aquasecurity/trivy-java-db:1
@ -48,6 +48,6 @@ jobs:
severity: CRITICAL,HIGH
- name: Upload Trivy scan results to GitHub Security tab
uses: github/codeql-action/upload-sarif@3c3833e0f8c1c83d449a7478aa59c036a9165498 # v3.29.11
uses: github/codeql-action/upload-sarif@192325c86100d080feab897ff886c34abd4c83a3 # v3.30.3
with:
sarif_file: 'trivy-results.sarif'

2
.github/workflows/zizmor.yml

@ -21,7 +21,7 @@ jobs:
persist-credentials: false
- name: Run zizmor
uses: zizmorcore/zizmor-action@5ca5fc7a4779c5263a3ffa0e1f693009994446d1 # v0.1.2
uses: zizmorcore/zizmor-action@e673c3917a1aef3c65c972347ed84ccd013ecda4 # v0.2.0
with:
# intentionally not scanning the entire repository,
# since it contains integration tests.

739
Cargo.lock

File diff suppressed because it is too large

38
Cargo.toml

@ -16,7 +16,11 @@ publish = false
build = "build.rs"
[features]
# default = ["sqlite"]
default = [
# "sqlite",
# "mysql",
# "postgresql",
]
# Empty to keep compatibility, prefer to set USE_SYSLOG=true
enable_syslog = []
mysql = ["diesel/mysql", "diesel_migrations/mysql"]
@ -27,11 +31,6 @@ vendored_openssl = ["openssl/vendored"]
# Enable MiMalloc memory allocator to replace the default malloc
# This can improve performance for Alpine builds
enable_mimalloc = ["dep:mimalloc"]
# This is a development dependency, and should only be used during development!
# It enables the usage of the diesel_logger crate, which is able to output the generated queries.
# You also need to set an env variable `QUERY_LOGGER=1` to fully activate this so you do not have to re-compile
# if you want to turn off the logging for a specific run.
query_logger = ["dep:diesel_logger"]
s3 = ["opendal/services-s3", "dep:aws-config", "dep:aws-credential-types", "dep:aws-smithy-runtime-api", "dep:anyhow", "dep:http", "dep:reqsign"]
# OIDC specific features
@ -50,7 +49,7 @@ syslog = "7.0.0"
macros = { path = "./macros" }
# Logging
log = "0.4.27"
log = "0.4.28"
fern = { version = "0.7.1", features = ["syslog-7", "reopen-1"] }
tracing = { version = "0.1.41", features = ["log"] } # Needed to have lettre and webauthn-rs trace logging to work
@ -81,13 +80,12 @@ tokio = { version = "1.47.1", features = ["rt-multi-thread", "fs", "io-util", "p
tokio-util = { version = "0.7.16", features = ["compat"]}
# A generic serialization/deserialization framework
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.143"
serde = { version = "1.0.225", features = ["derive"] }
serde_json = "1.0.145"
# A safe, extensible ORM and Query builder
diesel = { version = "2.2.12", features = ["chrono", "r2d2", "numeric"] }
diesel_migrations = "2.2.0"
diesel_logger = { version = "0.4.0", optional = true }
diesel = { version = "2.3.2", features = ["chrono", "r2d2", "numeric"] }
diesel_migrations = "2.3.0"
derive_more = { version = "2.0.1", features = ["from", "into", "as_ref", "deref", "display"] }
diesel-derive-newtype = "2.1.2"
@ -101,12 +99,12 @@ ring = "0.17.14"
subtle = "2.6.1"
# UUID generation
uuid = { version = "1.18.0", features = ["v4"] }
uuid = { version = "1.18.1", features = ["v4"] }
# Date and time libraries
chrono = { version = "0.4.41", features = ["clock", "serde"], default-features = false }
chrono = { version = "0.4.42", features = ["clock", "serde"], default-features = false }
chrono-tz = "0.10.4"
time = "0.3.41"
time = "0.3.44"
# Job scheduler
job_scheduler_ng = "2.3.0"
@ -157,7 +155,7 @@ cached = { version = "0.56.0", features = ["async"] }
# Used for custom short lived cookie jar during favicon extraction
cookie = "0.18.1"
cookie_store = "0.21.1"
cookie_store = "0.22.0"
# Used by U2F, JWT and PostgreSQL
openssl = "0.10.73"
@ -174,7 +172,7 @@ openidconnect = { version = "4.0.1", features = ["reqwest", "native-tls"] }
mini-moka = "0.10.3"
# Check client versions for specific features.
semver = "1.0.26"
semver = "1.0.27"
# Allow overriding the default memory allocator
# Mainly used for the musl builds, since the default musl malloc is very slow
@ -195,9 +193,9 @@ grass_compiler = { version = "0.13.4", default-features = false }
opendal = { version = "0.54.0", features = ["services-fs"], default-features = false }
# For retrieving AWS credentials, including temporary SSO credentials
anyhow = { version = "1.0.99", optional = true }
aws-config = { version = "1.8.5", features = ["behavior-version-latest", "rt-tokio", "credentials-process", "sso"], default-features = false, optional = true }
aws-credential-types = { version = "1.2.5", optional = true }
anyhow = { version = "1.0.100", optional = true }
aws-config = { version = "1.8.6", features = ["behavior-version-latest", "rt-tokio", "credentials-process", "sso"], default-features = false, optional = true }
aws-credential-types = { version = "1.2.6", optional = true }
aws-smithy-runtime-api = { version = "1.9.0", optional = true }
http = { version = "1.3.1", optional = true }
reqsign = { version = "0.16.5", optional = true }

6
build.rs

@ -9,8 +9,6 @@ fn main() {
println!("cargo:rustc-cfg=mysql");
#[cfg(feature = "postgresql")]
println!("cargo:rustc-cfg=postgresql");
#[cfg(feature = "query_logger")]
println!("cargo:rustc-cfg=query_logger");
#[cfg(feature = "s3")]
println!("cargo:rustc-cfg=s3");
@ -24,7 +22,6 @@ fn main() {
println!("cargo::rustc-check-cfg=cfg(sqlite)");
println!("cargo::rustc-check-cfg=cfg(mysql)");
println!("cargo::rustc-check-cfg=cfg(postgresql)");
println!("cargo::rustc-check-cfg=cfg(query_logger)");
println!("cargo::rustc-check-cfg=cfg(s3)");
// Rerun when these paths are changed.
@ -34,9 +31,6 @@ fn main() {
println!("cargo:rerun-if-changed=.git/index");
println!("cargo:rerun-if-changed=.git/refs/tags");
#[cfg(all(not(debug_assertions), feature = "query_logger"))]
compile_error!("Query Logging is only allowed during development, it is not intended for production usage!");
// Support $BWRS_VERSION for legacy compatibility, but default to $VW_VERSION.
// If neither exist, read from git.
let maybe_vaultwarden_version =

6
docker/Dockerfile.alpine

@ -53,9 +53,9 @@ ENV DEBIAN_FRONTEND=noninteractive \
TERM=xterm-256color \
CARGO_HOME="/root/.cargo" \
USER="root" \
# Use PostgreSQL v15 during Alpine/MUSL builds instead of the default v11
# Debian Bookworm already contains libpq v15
PQ_LIB_DIR="/usr/local/musl/pq15/lib"
# Use PostgreSQL v17 during Alpine/MUSL builds instead of the default v16
# Debian Trixie uses libpq v17
PQ_LIB_DIR="/usr/local/musl/pq17/lib"
# Create CARGO_HOME folder and don't download rust docs

6
docker/Dockerfile.j2

@ -63,9 +63,9 @@ ENV DEBIAN_FRONTEND=noninteractive \
CARGO_HOME="/root/.cargo" \
USER="root"
{%- if base == "alpine" %} \
# Use PostgreSQL v15 during Alpine/MUSL builds instead of the default v11
# Debian Bookworm already contains libpq v15
PQ_LIB_DIR="/usr/local/musl/pq15/lib"
# Use PostgreSQL v17 during Alpine/MUSL builds instead of the default v16
# Debian Trixie uses libpq v17
PQ_LIB_DIR="/usr/local/musl/pq17/lib"
{% endif %}
{% if base == "debian" %}

2
macros/Cargo.toml

@ -10,7 +10,7 @@ proc-macro = true
[dependencies]
quote = "1.0.40"
syn = "2.0.105"
syn = "2.0.106"
[lints]
workspace = true

189
src/api/admin.rs

@ -20,7 +20,14 @@ use crate::{
},
auth::{decode_admin, encode_jwt, generate_admin_claims, ClientIp, Secure},
config::ConfigBuilder,
db::{backup_database, get_sql_server_version, models::*, DbConn, DbConnType},
db::{
backup_sqlite, get_sql_server_version,
models::{
Attachment, Cipher, Collection, Device, Event, EventType, Group, Invitation, Membership, MembershipId,
MembershipType, OrgPolicy, OrgPolicyErr, Organization, OrganizationId, SsoUser, TwoFactor, User, UserId,
},
DbConn, DbConnType, ACTIVE_DB_TYPE,
},
error::{Error, MapResult},
http_client::make_http_request,
mail,
@ -75,18 +82,20 @@ pub fn catchers() -> Vec<Catcher> {
}
}
static DB_TYPE: Lazy<&str> = Lazy::new(|| {
DbConnType::from_url(&CONFIG.database_url())
.map(|t| match t {
DbConnType::sqlite => "SQLite",
DbConnType::mysql => "MySQL",
DbConnType::postgresql => "PostgreSQL",
})
.unwrap_or("Unknown")
static DB_TYPE: Lazy<&str> = Lazy::new(|| match ACTIVE_DB_TYPE.get() {
#[cfg(mysql)]
Some(DbConnType::Mysql) => "MySQL",
#[cfg(postgresql)]
Some(DbConnType::Postgresql) => "PostgreSQL",
#[cfg(sqlite)]
Some(DbConnType::Sqlite) => "SQLite",
_ => "Unknown",
});
static CAN_BACKUP: Lazy<bool> =
Lazy::new(|| DbConnType::from_url(&CONFIG.database_url()).map(|t| t == DbConnType::sqlite).unwrap_or(false));
#[cfg(sqlite)]
static CAN_BACKUP: Lazy<bool> = Lazy::new(|| ACTIVE_DB_TYPE.get().map(|t| *t == DbConnType::Sqlite).unwrap_or(false));
#[cfg(not(sqlite))]
static CAN_BACKUP: Lazy<bool> = Lazy::new(|| false);
#[get("/")]
fn admin_disabled() -> &'static str {
@ -284,7 +293,7 @@ struct InviteData {
email: String,
}
async fn get_user_or_404(user_id: &UserId, conn: &mut DbConn) -> ApiResult<User> {
async fn get_user_or_404(user_id: &UserId, conn: &DbConn) -> ApiResult<User> {
if let Some(user) = User::find_by_uuid(user_id, conn).await {
Ok(user)
} else {
@ -293,15 +302,15 @@ async fn get_user_or_404(user_id: &UserId, conn: &mut DbConn) -> ApiResult<User>
}
#[post("/invite", format = "application/json", data = "<data>")]
async fn invite_user(data: Json<InviteData>, _token: AdminToken, mut conn: DbConn) -> JsonResult {
async fn invite_user(data: Json<InviteData>, _token: AdminToken, conn: DbConn) -> JsonResult {
let data: InviteData = data.into_inner();
if User::find_by_mail(&data.email, &mut conn).await.is_some() {
if User::find_by_mail(&data.email, &conn).await.is_some() {
err_code!("User already exists", Status::Conflict.code)
}
let mut user = User::new(data.email, None);
async fn _generate_invite(user: &User, conn: &mut DbConn) -> EmptyResult {
async fn _generate_invite(user: &User, conn: &DbConn) -> EmptyResult {
if CONFIG.mail_enabled() {
let org_id: OrganizationId = FAKE_ADMIN_UUID.to_string().into();
let member_id: MembershipId = FAKE_ADMIN_UUID.to_string().into();
@ -312,10 +321,10 @@ async fn invite_user(data: Json<InviteData>, _token: AdminToken, mut conn: DbCon
}
}
_generate_invite(&user, &mut conn).await.map_err(|e| e.with_code(Status::InternalServerError.code))?;
user.save(&mut conn).await.map_err(|e| e.with_code(Status::InternalServerError.code))?;
_generate_invite(&user, &conn).await.map_err(|e| e.with_code(Status::InternalServerError.code))?;
user.save(&conn).await.map_err(|e| e.with_code(Status::InternalServerError.code))?;
Ok(Json(user.to_json(&mut conn).await))
Ok(Json(user.to_json(&conn).await))
}
#[post("/test/smtp", format = "application/json", data = "<data>")]
@ -336,14 +345,14 @@ fn logout(cookies: &CookieJar<'_>) -> Redirect {
}
#[get("/users")]
async fn get_users_json(_token: AdminToken, mut conn: DbConn) -> Json<Value> {
let users = User::get_all(&mut conn).await;
async fn get_users_json(_token: AdminToken, conn: DbConn) -> Json<Value> {
let users = User::get_all(&conn).await;
let mut users_json = Vec::with_capacity(users.len());
for (u, _) in users {
let mut usr = u.to_json(&mut conn).await;
let mut usr = u.to_json(&conn).await;
usr["userEnabled"] = json!(u.enabled);
usr["createdAt"] = json!(format_naive_datetime_local(&u.created_at, DT_FMT));
usr["lastActive"] = match u.last_active(&mut conn).await {
usr["lastActive"] = match u.last_active(&conn).await {
Some(dt) => json!(format_naive_datetime_local(&dt, DT_FMT)),
None => json!(None::<String>),
};
@ -354,17 +363,17 @@ async fn get_users_json(_token: AdminToken, mut conn: DbConn) -> Json<Value> {
}
#[get("/users/overview")]
async fn users_overview(_token: AdminToken, mut conn: DbConn) -> ApiResult<Html<String>> {
let users = User::get_all(&mut conn).await;
async fn users_overview(_token: AdminToken, conn: DbConn) -> ApiResult<Html<String>> {
let users = User::get_all(&conn).await;
let mut users_json = Vec::with_capacity(users.len());
for (u, sso_u) in users {
let mut usr = u.to_json(&mut conn).await;
usr["cipher_count"] = json!(Cipher::count_owned_by_user(&u.uuid, &mut conn).await);
usr["attachment_count"] = json!(Attachment::count_by_user(&u.uuid, &mut conn).await);
usr["attachment_size"] = json!(get_display_size(Attachment::size_by_user(&u.uuid, &mut conn).await));
let mut usr = u.to_json(&conn).await;
usr["cipher_count"] = json!(Cipher::count_owned_by_user(&u.uuid, &conn).await);
usr["attachment_count"] = json!(Attachment::count_by_user(&u.uuid, &conn).await);
usr["attachment_size"] = json!(get_display_size(Attachment::size_by_user(&u.uuid, &conn).await));
usr["user_enabled"] = json!(u.enabled);
usr["created_at"] = json!(format_naive_datetime_local(&u.created_at, DT_FMT));
usr["last_active"] = match u.last_active(&mut conn).await {
usr["last_active"] = match u.last_active(&conn).await {
Some(dt) => json!(format_naive_datetime_local(&dt, DT_FMT)),
None => json!("Never"),
};
@ -379,9 +388,9 @@ async fn users_overview(_token: AdminToken, mut conn: DbConn) -> ApiResult<Html<
}
#[get("/users/by-mail/<mail>")]
async fn get_user_by_mail_json(mail: &str, _token: AdminToken, mut conn: DbConn) -> JsonResult {
if let Some(u) = User::find_by_mail(mail, &mut conn).await {
let mut usr = u.to_json(&mut conn).await;
async fn get_user_by_mail_json(mail: &str, _token: AdminToken, conn: DbConn) -> JsonResult {
if let Some(u) = User::find_by_mail(mail, &conn).await {
let mut usr = u.to_json(&conn).await;
usr["userEnabled"] = json!(u.enabled);
usr["createdAt"] = json!(format_naive_datetime_local(&u.created_at, DT_FMT));
Ok(Json(usr))
@ -391,21 +400,21 @@ async fn get_user_by_mail_json(mail: &str, _token: AdminToken, mut conn: DbConn)
}
#[get("/users/<user_id>")]
async fn get_user_json(user_id: UserId, _token: AdminToken, mut conn: DbConn) -> JsonResult {
let u = get_user_or_404(&user_id, &mut conn).await?;
let mut usr = u.to_json(&mut conn).await;
async fn get_user_json(user_id: UserId, _token: AdminToken, conn: DbConn) -> JsonResult {
let u = get_user_or_404(&user_id, &conn).await?;
let mut usr = u.to_json(&conn).await;
usr["userEnabled"] = json!(u.enabled);
usr["createdAt"] = json!(format_naive_datetime_local(&u.created_at, DT_FMT));
Ok(Json(usr))
}
#[post("/users/<user_id>/delete", format = "application/json")]
async fn delete_user(user_id: UserId, token: AdminToken, mut conn: DbConn) -> EmptyResult {
let user = get_user_or_404(&user_id, &mut conn).await?;
async fn delete_user(user_id: UserId, token: AdminToken, conn: DbConn) -> EmptyResult {
let user = get_user_or_404(&user_id, &conn).await?;
// Get the membership records before deleting the actual user
let memberships = Membership::find_any_state_by_user(&user_id, &mut conn).await;
let res = user.delete(&mut conn).await;
let memberships = Membership::find_any_state_by_user(&user_id, &conn).await;
let res = user.delete(&conn).await;
for membership in memberships {
log_event(
@ -415,7 +424,7 @@ async fn delete_user(user_id: UserId, token: AdminToken, mut conn: DbConn) -> Em
&ACTING_ADMIN_USER.into(),
14, // Use UnknownBrowser type
&token.ip.ip,
&mut conn,
&conn,
)
.await;
}
@ -424,9 +433,9 @@ async fn delete_user(user_id: UserId, token: AdminToken, mut conn: DbConn) -> Em
}
#[delete("/users/<user_id>/sso", format = "application/json")]
async fn delete_sso_user(user_id: UserId, token: AdminToken, mut conn: DbConn) -> EmptyResult {
let memberships = Membership::find_any_state_by_user(&user_id, &mut conn).await;
let res = SsoUser::delete(&user_id, &mut conn).await;
async fn delete_sso_user(user_id: UserId, token: AdminToken, conn: DbConn) -> EmptyResult {
let memberships = Membership::find_any_state_by_user(&user_id, &conn).await;
let res = SsoUser::delete(&user_id, &conn).await;
for membership in memberships {
log_event(
@ -436,7 +445,7 @@ async fn delete_sso_user(user_id: UserId, token: AdminToken, mut conn: DbConn) -
&ACTING_ADMIN_USER.into(),
14, // Use UnknownBrowser type
&token.ip.ip,
&mut conn,
&conn,
)
.await;
}
@ -445,13 +454,13 @@ async fn delete_sso_user(user_id: UserId, token: AdminToken, mut conn: DbConn) -
}
#[post("/users/<user_id>/deauth", format = "application/json")]
async fn deauth_user(user_id: UserId, _token: AdminToken, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let mut user = get_user_or_404(&user_id, &mut conn).await?;
async fn deauth_user(user_id: UserId, _token: AdminToken, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let mut user = get_user_or_404(&user_id, &conn).await?;
nt.send_logout(&user, None, &mut conn).await;
nt.send_logout(&user, None, &conn).await;
if CONFIG.push_enabled() {
for device in Device::find_push_devices_by_user(&user.uuid, &mut conn).await {
for device in Device::find_push_devices_by_user(&user.uuid, &conn).await {
match unregister_push_device(&device.push_uuid).await {
Ok(r) => r,
Err(e) => error!("Unable to unregister devices from Bitwarden server: {e}"),
@ -459,46 +468,46 @@ async fn deauth_user(user_id: UserId, _token: AdminToken, mut conn: DbConn, nt:
}
}
Device::delete_all_by_user(&user.uuid, &mut conn).await?;
Device::delete_all_by_user(&user.uuid, &conn).await?;
user.reset_security_stamp();
user.save(&mut conn).await
user.save(&conn).await
}
#[post("/users/<user_id>/disable", format = "application/json")]
async fn disable_user(user_id: UserId, _token: AdminToken, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let mut user = get_user_or_404(&user_id, &mut conn).await?;
Device::delete_all_by_user(&user.uuid, &mut conn).await?;
async fn disable_user(user_id: UserId, _token: AdminToken, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let mut user = get_user_or_404(&user_id, &conn).await?;
Device::delete_all_by_user(&user.uuid, &conn).await?;
user.reset_security_stamp();
user.enabled = false;
let save_result = user.save(&mut conn).await;
let save_result = user.save(&conn).await;
nt.send_logout(&user, None, &mut conn).await;
nt.send_logout(&user, None, &conn).await;
save_result
}
#[post("/users/<user_id>/enable", format = "application/json")]
async fn enable_user(user_id: UserId, _token: AdminToken, mut conn: DbConn) -> EmptyResult {
let mut user = get_user_or_404(&user_id, &mut conn).await?;
async fn enable_user(user_id: UserId, _token: AdminToken, conn: DbConn) -> EmptyResult {
let mut user = get_user_or_404(&user_id, &conn).await?;
user.enabled = true;
user.save(&mut conn).await
user.save(&conn).await
}
#[post("/users/<user_id>/remove-2fa", format = "application/json")]
async fn remove_2fa(user_id: UserId, token: AdminToken, mut conn: DbConn) -> EmptyResult {
let mut user = get_user_or_404(&user_id, &mut conn).await?;
TwoFactor::delete_all_by_user(&user.uuid, &mut conn).await?;
two_factor::enforce_2fa_policy(&user, &ACTING_ADMIN_USER.into(), 14, &token.ip.ip, &mut conn).await?;
async fn remove_2fa(user_id: UserId, token: AdminToken, conn: DbConn) -> EmptyResult {
let mut user = get_user_or_404(&user_id, &conn).await?;
TwoFactor::delete_all_by_user(&user.uuid, &conn).await?;
two_factor::enforce_2fa_policy(&user, &ACTING_ADMIN_USER.into(), 14, &token.ip.ip, &conn).await?;
user.totp_recover = None;
user.save(&mut conn).await
user.save(&conn).await
}
#[post("/users/<user_id>/invite/resend", format = "application/json")]
async fn resend_user_invite(user_id: UserId, _token: AdminToken, mut conn: DbConn) -> EmptyResult {
if let Some(user) = User::find_by_uuid(&user_id, &mut conn).await {
async fn resend_user_invite(user_id: UserId, _token: AdminToken, conn: DbConn) -> EmptyResult {
if let Some(user) = User::find_by_uuid(&user_id, &conn).await {
//TODO: replace this with user.status check when it will be available (PR#3397)
if !user.password_hash.is_empty() {
err_code!("User already accepted invitation", Status::BadRequest.code);
@ -524,10 +533,10 @@ struct MembershipTypeData {
}
#[post("/users/org_type", format = "application/json", data = "<data>")]
async fn update_membership_type(data: Json<MembershipTypeData>, token: AdminToken, mut conn: DbConn) -> EmptyResult {
async fn update_membership_type(data: Json<MembershipTypeData>, token: AdminToken, conn: DbConn) -> EmptyResult {
let data: MembershipTypeData = data.into_inner();
let Some(mut member_to_edit) = Membership::find_by_user_and_org(&data.user_uuid, &data.org_uuid, &mut conn).await
let Some(mut member_to_edit) = Membership::find_by_user_and_org(&data.user_uuid, &data.org_uuid, &conn).await
else {
err!("The specified user isn't member of the organization")
};
@ -539,7 +548,7 @@ async fn update_membership_type(data: Json<MembershipTypeData>, token: AdminToke
if member_to_edit.atype == MembershipType::Owner && new_type != MembershipType::Owner {
// Removing owner permission, check that there is at least one other confirmed owner
if Membership::count_confirmed_by_org_and_type(&data.org_uuid, MembershipType::Owner, &mut conn).await <= 1 {
if Membership::count_confirmed_by_org_and_type(&data.org_uuid, MembershipType::Owner, &conn).await <= 1 {
err!("Can't change the type of the last owner")
}
}
@ -547,11 +556,11 @@ async fn update_membership_type(data: Json<MembershipTypeData>, token: AdminToke
// This check is also done at api::organizations::{accept_invite, _confirm_invite, _activate_member, edit_member}, update_membership_type
// It returns different error messages per function.
if new_type < MembershipType::Admin {
match OrgPolicy::is_user_allowed(&member_to_edit.user_uuid, &member_to_edit.org_uuid, true, &mut conn).await {
match OrgPolicy::is_user_allowed(&member_to_edit.user_uuid, &member_to_edit.org_uuid, true, &conn).await {
Ok(_) => {}
Err(OrgPolicyErr::TwoFactorMissing) => {
if CONFIG.email_2fa_auto_fallback() {
two_factor::email::find_and_activate_email_2fa(&member_to_edit.user_uuid, &mut conn).await?;
two_factor::email::find_and_activate_email_2fa(&member_to_edit.user_uuid, &conn).await?;
} else {
err!("You cannot modify this user to this type because they have not setup 2FA");
}
@ -569,32 +578,32 @@ async fn update_membership_type(data: Json<MembershipTypeData>, token: AdminToke
&ACTING_ADMIN_USER.into(),
14, // Use UnknownBrowser type
&token.ip.ip,
&mut conn,
&conn,
)
.await;
member_to_edit.atype = new_type;
member_to_edit.save(&mut conn).await
member_to_edit.save(&conn).await
}
#[post("/users/update_revision", format = "application/json")]
async fn update_revision_users(_token: AdminToken, mut conn: DbConn) -> EmptyResult {
User::update_all_revisions(&mut conn).await
async fn update_revision_users(_token: AdminToken, conn: DbConn) -> EmptyResult {
User::update_all_revisions(&conn).await
}
#[get("/organizations/overview")]
async fn organizations_overview(_token: AdminToken, mut conn: DbConn) -> ApiResult<Html<String>> {
let organizations = Organization::get_all(&mut conn).await;
async fn organizations_overview(_token: AdminToken, conn: DbConn) -> ApiResult<Html<String>> {
let organizations = Organization::get_all(&conn).await;
let mut organizations_json = Vec::with_capacity(organizations.len());
for o in organizations {
let mut org = o.to_json();
org["user_count"] = json!(Membership::count_by_org(&o.uuid, &mut conn).await);
org["cipher_count"] = json!(Cipher::count_by_org(&o.uuid, &mut conn).await);
org["collection_count"] = json!(Collection::count_by_org(&o.uuid, &mut conn).await);
org["group_count"] = json!(Group::count_by_org(&o.uuid, &mut conn).await);
org["event_count"] = json!(Event::count_by_org(&o.uuid, &mut conn).await);
org["attachment_count"] = json!(Attachment::count_by_org(&o.uuid, &mut conn).await);
org["attachment_size"] = json!(get_display_size(Attachment::size_by_org(&o.uuid, &mut conn).await));
org["user_count"] = json!(Membership::count_by_org(&o.uuid, &conn).await);
org["cipher_count"] = json!(Cipher::count_by_org(&o.uuid, &conn).await);
org["collection_count"] = json!(Collection::count_by_org(&o.uuid, &conn).await);
org["group_count"] = json!(Group::count_by_org(&o.uuid, &conn).await);
org["event_count"] = json!(Event::count_by_org(&o.uuid, &conn).await);
org["attachment_count"] = json!(Attachment::count_by_org(&o.uuid, &conn).await);
org["attachment_size"] = json!(get_display_size(Attachment::size_by_org(&o.uuid, &conn).await));
organizations_json.push(org);
}
@ -603,9 +612,9 @@ async fn organizations_overview(_token: AdminToken, mut conn: DbConn) -> ApiResu
}
#[post("/organizations/<org_id>/delete", format = "application/json")]
async fn delete_organization(org_id: OrganizationId, _token: AdminToken, mut conn: DbConn) -> EmptyResult {
let org = Organization::find_by_uuid(&org_id, &mut conn).await.map_res("Organization doesn't exist")?;
org.delete(&mut conn).await
async fn delete_organization(org_id: OrganizationId, _token: AdminToken, conn: DbConn) -> EmptyResult {
let org = Organization::find_by_uuid(&org_id, &conn).await.map_res("Organization doesn't exist")?;
org.delete(&conn).await
}
#[derive(Deserialize)]
@ -693,7 +702,7 @@ async fn get_ntp_time(has_http_access: bool) -> String {
}
#[get("/diagnostics")]
async fn diagnostics(_token: AdminToken, ip_header: IpHeader, mut conn: DbConn) -> ApiResult<Html<String>> {
async fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResult<Html<String>> {
use chrono::prelude::*;
use std::net::ToSocketAddrs;
@ -747,7 +756,7 @@ async fn diagnostics(_token: AdminToken, ip_header: IpHeader, mut conn: DbConn)
"uses_proxy": uses_proxy,
"enable_websocket": &CONFIG.enable_websocket(),
"db_type": *DB_TYPE,
"db_version": get_sql_server_version(&mut conn).await,
"db_version": get_sql_server_version(&conn).await,
"admin_url": format!("{}/diagnostics", admin_url()),
"overrides": &CONFIG.get_overrides().join(", "),
"host_arch": env::consts::ARCH,
@ -791,9 +800,9 @@ async fn delete_config(_token: AdminToken) -> EmptyResult {
}
#[post("/config/backup_db", format = "application/json")]
async fn backup_db(_token: AdminToken, mut conn: DbConn) -> ApiResult<String> {
fn backup_db(_token: AdminToken) -> ApiResult<String> {
if *CAN_BACKUP {
match backup_database(&mut conn).await {
match backup_sqlite() {
Ok(f) => Ok(format!("Backup to '{f}' was successful")),
Err(e) => err!(format!("Backup was unsuccessful {e}")),
}

256
src/api/core/accounts.rs

@ -13,7 +13,14 @@ use crate::{
},
auth::{decode_delete, decode_invite, decode_verify_email, ClientHeaders, Headers},
crypto,
db::{models::*, DbConn},
db::{
models::{
AuthRequest, AuthRequestId, Cipher, CipherId, Device, DeviceId, DeviceType, EmergencyAccess,
EmergencyAccessId, EventType, Folder, FolderId, Invitation, Membership, MembershipId, OrgPolicy,
OrgPolicyType, Organization, OrganizationId, Send, SendId, User, UserId, UserKdfType,
},
DbConn,
},
mail,
util::{format_date, NumberOrString},
CONFIG,
@ -142,7 +149,7 @@ fn enforce_password_hint_setting(password_hint: &Option<String>) -> EmptyResult
}
Ok(())
}
async fn is_email_2fa_required(member_id: Option<MembershipId>, conn: &mut DbConn) -> bool {
async fn is_email_2fa_required(member_id: Option<MembershipId>, conn: &DbConn) -> bool {
if !CONFIG._enable_email_2fa() {
return false;
}
@ -160,7 +167,7 @@ async fn register(data: Json<RegisterData>, conn: DbConn) -> JsonResult {
_register(data, false, conn).await
}
pub async fn _register(data: Json<RegisterData>, email_verification: bool, mut conn: DbConn) -> JsonResult {
pub async fn _register(data: Json<RegisterData>, email_verification: bool, conn: DbConn) -> JsonResult {
let mut data: RegisterData = data.into_inner();
let email = data.email.to_lowercase();
@ -242,7 +249,7 @@ pub async fn _register(data: Json<RegisterData>, email_verification: bool, mut c
let password_hint = clean_password_hint(&data.master_password_hint);
enforce_password_hint_setting(&password_hint)?;
let mut user = match User::find_by_mail(&email, &mut conn).await {
let mut user = match User::find_by_mail(&email, &conn).await {
Some(user) => {
if !user.password_hash.is_empty() {
err!("Registration not allowed or user already exists")
@ -257,12 +264,12 @@ pub async fn _register(data: Json<RegisterData>, email_verification: bool, mut c
} else {
err!("Registration email does not match invite email")
}
} else if Invitation::take(&email, &mut conn).await {
Membership::accept_user_invitations(&user.uuid, &mut conn).await?;
} else if Invitation::take(&email, &conn).await {
Membership::accept_user_invitations(&user.uuid, &conn).await?;
user
} else if CONFIG.is_signup_allowed(&email)
|| (CONFIG.emergency_access_allowed()
&& EmergencyAccess::find_invited_by_grantee_email(&email, &mut conn).await.is_some())
&& EmergencyAccess::find_invited_by_grantee_email(&email, &conn).await.is_some())
{
user
} else {
@ -273,7 +280,7 @@ pub async fn _register(data: Json<RegisterData>, email_verification: bool, mut c
// Order is important here; the invitation check must come first
// because the vaultwarden admin can invite anyone, regardless
// of other signup restrictions.
if Invitation::take(&email, &mut conn).await
if Invitation::take(&email, &conn).await
|| CONFIG.is_signup_allowed(&email)
|| pending_emergency_access.is_some()
{
@ -285,7 +292,7 @@ pub async fn _register(data: Json<RegisterData>, email_verification: bool, mut c
};
// Make sure we don't leave a lingering invitation.
Invitation::take(&email, &mut conn).await;
Invitation::take(&email, &conn).await;
set_kdf_data(&mut user, data.kdf)?;
@ -316,17 +323,17 @@ pub async fn _register(data: Json<RegisterData>, email_verification: bool, mut c
error!("Error sending welcome email: {e:#?}");
}
if email_verified && is_email_2fa_required(data.organization_user_id, &mut conn).await {
email::activate_email_2fa(&user, &mut conn).await.ok();
if email_verified && is_email_2fa_required(data.organization_user_id, &conn).await {
email::activate_email_2fa(&user, &conn).await.ok();
}
}
user.save(&mut conn).await?;
user.save(&conn).await?;
// accept any open emergency access invitations
if !CONFIG.mail_enabled() && CONFIG.emergency_access_allowed() {
for mut emergency_invite in EmergencyAccess::find_all_invited_by_grantee_email(&user.email, &mut conn).await {
emergency_invite.accept_invite(&user.uuid, &user.email, &mut conn).await.ok();
for mut emergency_invite in EmergencyAccess::find_all_invited_by_grantee_email(&user.email, &conn).await {
emergency_invite.accept_invite(&user.uuid, &user.email, &conn).await.ok();
}
}
@ -337,7 +344,7 @@ pub async fn _register(data: Json<RegisterData>, email_verification: bool, mut c
}
#[post("/accounts/set-password", data = "<data>")]
async fn post_set_password(data: Json<SetPasswordData>, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn post_set_password(data: Json<SetPasswordData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: SetPasswordData = data.into_inner();
let mut user = headers.user;
@ -367,30 +374,30 @@ async fn post_set_password(data: Json<SetPasswordData>, headers: Headers, mut co
if let Some(identifier) = data.org_identifier {
if identifier != crate::sso::FAKE_IDENTIFIER {
let org = match Organization::find_by_name(&identifier, &mut conn).await {
let org = match Organization::find_by_name(&identifier, &conn).await {
None => err!("Failed to retrieve the associated organization"),
Some(org) => org,
};
let membership = match Membership::find_by_user_and_org(&user.uuid, &org.uuid, &mut conn).await {
let membership = match Membership::find_by_user_and_org(&user.uuid, &org.uuid, &conn).await {
None => err!("Failed to retrieve the invitation"),
Some(org) => org,
};
accept_org_invite(&user, membership, None, &mut conn).await?;
accept_org_invite(&user, membership, None, &conn).await?;
}
}
if CONFIG.mail_enabled() {
mail::send_welcome(&user.email.to_lowercase()).await?;
} else {
Membership::accept_user_invitations(&user.uuid, &mut conn).await?;
Membership::accept_user_invitations(&user.uuid, &conn).await?;
}
log_user_event(EventType::UserChangedPassword as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &mut conn)
log_user_event(EventType::UserChangedPassword as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn)
.await;
user.save(&mut conn).await?;
user.save(&conn).await?;
Ok(Json(json!({
"Object": "set-password",
@ -399,8 +406,8 @@ async fn post_set_password(data: Json<SetPasswordData>, headers: Headers, mut co
}
#[get("/accounts/profile")]
async fn profile(headers: Headers, mut conn: DbConn) -> Json<Value> {
Json(headers.user.to_json(&mut conn).await)
async fn profile(headers: Headers, conn: DbConn) -> Json<Value> {
Json(headers.user.to_json(&conn).await)
}
#[derive(Debug, Deserialize)]
@ -416,7 +423,7 @@ async fn put_profile(data: Json<ProfileData>, headers: Headers, conn: DbConn) ->
}
#[post("/accounts/profile", data = "<data>")]
async fn post_profile(data: Json<ProfileData>, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn post_profile(data: Json<ProfileData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: ProfileData = data.into_inner();
// Check if the length of the username exceeds 50 characters (Same is Upstream Bitwarden)
@ -428,8 +435,8 @@ async fn post_profile(data: Json<ProfileData>, headers: Headers, mut conn: DbCon
let mut user = headers.user;
user.name = data.name;
user.save(&mut conn).await?;
Ok(Json(user.to_json(&mut conn).await))
user.save(&conn).await?;
Ok(Json(user.to_json(&conn).await))
}
#[derive(Deserialize)]
@ -439,7 +446,7 @@ struct AvatarData {
}
#[put("/accounts/avatar", data = "<data>")]
async fn put_avatar(data: Json<AvatarData>, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn put_avatar(data: Json<AvatarData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: AvatarData = data.into_inner();
// It looks like it only supports the 6 hex color format.
@ -454,13 +461,13 @@ async fn put_avatar(data: Json<AvatarData>, headers: Headers, mut conn: DbConn)
let mut user = headers.user;
user.avatar_color = data.avatar_color;
user.save(&mut conn).await?;
Ok(Json(user.to_json(&mut conn).await))
user.save(&conn).await?;
Ok(Json(user.to_json(&conn).await))
}
#[get("/users/<user_id>/public-key")]
async fn get_public_keys(user_id: UserId, _headers: Headers, mut conn: DbConn) -> JsonResult {
let user = match User::find_by_uuid(&user_id, &mut conn).await {
async fn get_public_keys(user_id: UserId, _headers: Headers, conn: DbConn) -> JsonResult {
let user = match User::find_by_uuid(&user_id, &conn).await {
Some(user) if user.public_key.is_some() => user,
Some(_) => err_code!("User has no public_key", Status::NotFound.code),
None => err_code!("User doesn't exist", Status::NotFound.code),
@ -474,7 +481,7 @@ async fn get_public_keys(user_id: UserId, _headers: Headers, mut conn: DbConn) -
}
#[post("/accounts/keys", data = "<data>")]
async fn post_keys(data: Json<KeysData>, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn post_keys(data: Json<KeysData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: KeysData = data.into_inner();
let mut user = headers.user;
@ -482,7 +489,7 @@ async fn post_keys(data: Json<KeysData>, headers: Headers, mut conn: DbConn) ->
user.private_key = Some(data.encrypted_private_key);
user.public_key = Some(data.public_key);
user.save(&mut conn).await?;
user.save(&conn).await?;
Ok(Json(json!({
"privateKey": user.private_key,
@ -501,7 +508,7 @@ struct ChangePassData {
}
#[post("/accounts/password", data = "<data>")]
async fn post_password(data: Json<ChangePassData>, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult {
async fn post_password(data: Json<ChangePassData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let data: ChangePassData = data.into_inner();
let mut user = headers.user;
@ -512,7 +519,7 @@ async fn post_password(data: Json<ChangePassData>, headers: Headers, mut conn: D
user.password_hint = clean_password_hint(&data.master_password_hint);
enforce_password_hint_setting(&user.password_hint)?;
log_user_event(EventType::UserChangedPassword as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &mut conn)
log_user_event(EventType::UserChangedPassword as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn)
.await;
user.set_password(
@ -527,12 +534,12 @@ async fn post_password(data: Json<ChangePassData>, headers: Headers, mut conn: D
]),
);
let save_result = user.save(&mut conn).await;
let save_result = user.save(&conn).await;
// Prevent logging out the client where the user requested this endpoint from.
// If you do logout the user it will causes issues at the client side.
// Adding the device uuid will prevent this.
nt.send_logout(&user, Some(headers.device.uuid.clone()), &mut conn).await;
nt.send_logout(&user, Some(headers.device.uuid.clone()), &conn).await;
save_result
}
@ -584,7 +591,7 @@ fn set_kdf_data(user: &mut User, data: KDFData) -> EmptyResult {
}
#[post("/accounts/kdf", data = "<data>")]
async fn post_kdf(data: Json<ChangeKdfData>, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult {
async fn post_kdf(data: Json<ChangeKdfData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let data: ChangeKdfData = data.into_inner();
let mut user = headers.user;
@ -595,9 +602,9 @@ async fn post_kdf(data: Json<ChangeKdfData>, headers: Headers, mut conn: DbConn,
set_kdf_data(&mut user, data.kdf)?;
user.set_password(&data.new_master_password_hash, Some(data.key), true, None);
let save_result = user.save(&mut conn).await;
let save_result = user.save(&conn).await;
nt.send_logout(&user, Some(headers.device.uuid.clone()), &mut conn).await;
nt.send_logout(&user, Some(headers.device.uuid.clone()), &conn).await;
save_result
}
@ -752,7 +759,7 @@ fn validate_keydata(
}
#[post("/accounts/key-management/rotate-user-account-keys", data = "<data>")]
async fn post_rotatekey(data: Json<KeyData>, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult {
async fn post_rotatekey(data: Json<KeyData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
// TODO: See if we can wrap everything within a SQL Transaction. If something fails it should revert everything.
let data: KeyData = data.into_inner();
@ -770,13 +777,13 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, mut conn: DbConn,
// TODO: Ideally we'd do everything after this point in a single transaction.
let mut existing_ciphers = Cipher::find_owned_by_user(user_id, &mut conn).await;
let mut existing_folders = Folder::find_by_user(user_id, &mut conn).await;
let mut existing_emergency_access = EmergencyAccess::find_all_by_grantor_uuid(user_id, &mut conn).await;
let mut existing_memberships = Membership::find_by_user(user_id, &mut conn).await;
let mut existing_ciphers = Cipher::find_owned_by_user(user_id, &conn).await;
let mut existing_folders = Folder::find_by_user(user_id, &conn).await;
let mut existing_emergency_access = EmergencyAccess::find_all_by_grantor_uuid(user_id, &conn).await;
let mut existing_memberships = Membership::find_by_user(user_id, &conn).await;
// We only rotate the reset password key if it is set.
existing_memberships.retain(|m| m.reset_password_key.is_some());
let mut existing_sends = Send::find_by_user(user_id, &mut conn).await;
let mut existing_sends = Send::find_by_user(user_id, &conn).await;
validate_keydata(
&data,
@ -798,7 +805,7 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, mut conn: DbConn,
};
saved_folder.name = folder_data.name;
saved_folder.save(&mut conn).await?
saved_folder.save(&conn).await?
}
}
@ -811,7 +818,7 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, mut conn: DbConn,
};
saved_emergency_access.key_encrypted = Some(emergency_access_data.key_encrypted);
saved_emergency_access.save(&mut conn).await?
saved_emergency_access.save(&conn).await?
}
// Update reset password data
@ -823,7 +830,7 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, mut conn: DbConn,
};
membership.reset_password_key = Some(reset_password_data.reset_password_key);
membership.save(&mut conn).await?
membership.save(&conn).await?
}
// Update send data
@ -832,7 +839,7 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, mut conn: DbConn,
err!("Send doesn't exist")
};
update_send_from_data(send, send_data, &headers, &mut conn, &nt, UpdateType::None).await?;
update_send_from_data(send, send_data, &headers, &conn, &nt, UpdateType::None).await?;
}
// Update cipher data
@ -848,7 +855,7 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, mut conn: DbConn,
// Prevent triggering cipher updates via WebSockets by settings UpdateType::None
// The user sessions are invalidated because all the ciphers were re-encrypted and thus triggering an update could cause issues.
// We force the users to logout after the user has been saved to try and prevent these issues.
update_cipher_from_data(saved_cipher, cipher_data, &headers, None, &mut conn, &nt, UpdateType::None).await?
update_cipher_from_data(saved_cipher, cipher_data, &headers, None, &conn, &nt, UpdateType::None).await?
}
}
@ -863,28 +870,28 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, mut conn: DbConn,
None,
);
let save_result = user.save(&mut conn).await;
let save_result = user.save(&conn).await;
// Prevent logging out the client where the user requested this endpoint from.
// If you do logout the user it will causes issues at the client side.
// Adding the device uuid will prevent this.
nt.send_logout(&user, Some(headers.device.uuid.clone()), &mut conn).await;
nt.send_logout(&user, Some(headers.device.uuid.clone()), &conn).await;
save_result
}
#[post("/accounts/security-stamp", data = "<data>")]
async fn post_sstamp(data: Json<PasswordOrOtpData>, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult {
async fn post_sstamp(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let data: PasswordOrOtpData = data.into_inner();
let mut user = headers.user;
data.validate(&user, true, &mut conn).await?;
data.validate(&user, true, &conn).await?;
Device::delete_all_by_user(&user.uuid, &mut conn).await?;
Device::delete_all_by_user(&user.uuid, &conn).await?;
user.reset_security_stamp();
let save_result = user.save(&mut conn).await;
let save_result = user.save(&conn).await;
nt.send_logout(&user, None, &mut conn).await;
nt.send_logout(&user, None, &conn).await;
save_result
}
@ -897,7 +904,7 @@ struct EmailTokenData {
}
#[post("/accounts/email-token", data = "<data>")]
async fn post_email_token(data: Json<EmailTokenData>, headers: Headers, mut conn: DbConn) -> EmptyResult {
async fn post_email_token(data: Json<EmailTokenData>, headers: Headers, conn: DbConn) -> EmptyResult {
if !CONFIG.email_change_allowed() {
err!("Email change is not allowed.");
}
@ -909,7 +916,7 @@ async fn post_email_token(data: Json<EmailTokenData>, headers: Headers, mut conn
err!("Invalid password")
}
if User::find_by_mail(&data.new_email, &mut conn).await.is_some() {
if User::find_by_mail(&data.new_email, &conn).await.is_some() {
if CONFIG.mail_enabled() {
if let Err(e) = mail::send_change_email_existing(&data.new_email, &user.email).await {
error!("Error sending change-email-existing email: {e:#?}");
@ -934,7 +941,7 @@ async fn post_email_token(data: Json<EmailTokenData>, headers: Headers, mut conn
user.email_new = Some(data.new_email);
user.email_new_token = Some(token);
user.save(&mut conn).await
user.save(&conn).await
}
#[derive(Deserialize)]
@ -949,7 +956,7 @@ struct ChangeEmailData {
}
#[post("/accounts/email", data = "<data>")]
async fn post_email(data: Json<ChangeEmailData>, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult {
async fn post_email(data: Json<ChangeEmailData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
if !CONFIG.email_change_allowed() {
err!("Email change is not allowed.");
}
@ -961,7 +968,7 @@ async fn post_email(data: Json<ChangeEmailData>, headers: Headers, mut conn: DbC
err!("Invalid password")
}
if User::find_by_mail(&data.new_email, &mut conn).await.is_some() {
if User::find_by_mail(&data.new_email, &conn).await.is_some() {
err!("Email already in use");
}
@ -995,9 +1002,9 @@ async fn post_email(data: Json<ChangeEmailData>, headers: Headers, mut conn: DbC
user.set_password(&data.new_master_password_hash, Some(data.key), true, None);
let save_result = user.save(&mut conn).await;
let save_result = user.save(&conn).await;
nt.send_logout(&user, None, &mut conn).await;
nt.send_logout(&user, None, &conn).await;
save_result
}
@ -1025,10 +1032,10 @@ struct VerifyEmailTokenData {
}
#[post("/accounts/verify-email-token", data = "<data>")]
async fn post_verify_email_token(data: Json<VerifyEmailTokenData>, mut conn: DbConn) -> EmptyResult {
async fn post_verify_email_token(data: Json<VerifyEmailTokenData>, conn: DbConn) -> EmptyResult {
let data: VerifyEmailTokenData = data.into_inner();
let Some(mut user) = User::find_by_uuid(&data.user_id, &mut conn).await else {
let Some(mut user) = User::find_by_uuid(&data.user_id, &conn).await else {
err!("User doesn't exist")
};
@ -1041,7 +1048,7 @@ async fn post_verify_email_token(data: Json<VerifyEmailTokenData>, mut conn: DbC
user.verified_at = Some(Utc::now().naive_utc());
user.last_verifying_at = None;
user.login_verify_count = 0;
if let Err(e) = user.save(&mut conn).await {
if let Err(e) = user.save(&conn).await {
error!("Error saving email verification: {e:#?}");
}
@ -1055,11 +1062,11 @@ struct DeleteRecoverData {
}
#[post("/accounts/delete-recover", data = "<data>")]
async fn post_delete_recover(data: Json<DeleteRecoverData>, mut conn: DbConn) -> EmptyResult {
async fn post_delete_recover(data: Json<DeleteRecoverData>, conn: DbConn) -> EmptyResult {
let data: DeleteRecoverData = data.into_inner();
if CONFIG.mail_enabled() {
if let Some(user) = User::find_by_mail(&data.email, &mut conn).await {
if let Some(user) = User::find_by_mail(&data.email, &conn).await {
if let Err(e) = mail::send_delete_account(&user.email, &user.uuid).await {
error!("Error sending delete account email: {e:#?}");
}
@ -1082,21 +1089,21 @@ struct DeleteRecoverTokenData {
}
#[post("/accounts/delete-recover-token", data = "<data>")]
async fn post_delete_recover_token(data: Json<DeleteRecoverTokenData>, mut conn: DbConn) -> EmptyResult {
async fn post_delete_recover_token(data: Json<DeleteRecoverTokenData>, conn: DbConn) -> EmptyResult {
let data: DeleteRecoverTokenData = data.into_inner();
let Ok(claims) = decode_delete(&data.token) else {
err!("Invalid claim")
};
let Some(user) = User::find_by_uuid(&data.user_id, &mut conn).await else {
let Some(user) = User::find_by_uuid(&data.user_id, &conn).await else {
err!("User doesn't exist")
};
if claims.sub != *user.uuid {
err!("Invalid claim");
}
user.delete(&mut conn).await
user.delete(&conn).await
}
#[post("/accounts/delete", data = "<data>")]
@ -1105,13 +1112,13 @@ async fn post_delete_account(data: Json<PasswordOrOtpData>, headers: Headers, co
}
#[delete("/accounts", data = "<data>")]
async fn delete_account(data: Json<PasswordOrOtpData>, headers: Headers, mut conn: DbConn) -> EmptyResult {
async fn delete_account(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> EmptyResult {
let data: PasswordOrOtpData = data.into_inner();
let user = headers.user;
data.validate(&user, true, &mut conn).await?;
data.validate(&user, true, &conn).await?;
user.delete(&mut conn).await
user.delete(&conn).await
}
#[get("/accounts/revision-date")]
@ -1127,7 +1134,7 @@ struct PasswordHintData {
}
#[post("/accounts/password-hint", data = "<data>")]
async fn password_hint(data: Json<PasswordHintData>, mut conn: DbConn) -> EmptyResult {
async fn password_hint(data: Json<PasswordHintData>, conn: DbConn) -> EmptyResult {
if !CONFIG.password_hints_allowed() || (!CONFIG.mail_enabled() && !CONFIG.show_password_hint()) {
err!("This server is not configured to provide password hints.");
}
@ -1137,7 +1144,7 @@ async fn password_hint(data: Json<PasswordHintData>, mut conn: DbConn) -> EmptyR
let data: PasswordHintData = data.into_inner();
let email = &data.email;
match User::find_by_mail(email, &mut conn).await {
match User::find_by_mail(email, &conn).await {
None => {
// To prevent user enumeration, act as if the user exists.
if CONFIG.mail_enabled() {
@ -1179,10 +1186,10 @@ async fn prelogin(data: Json<PreloginData>, conn: DbConn) -> Json<Value> {
_prelogin(data, conn).await
}
pub async fn _prelogin(data: Json<PreloginData>, mut conn: DbConn) -> Json<Value> {
pub async fn _prelogin(data: Json<PreloginData>, conn: DbConn) -> Json<Value> {
let data: PreloginData = data.into_inner();
let (kdf_type, kdf_iter, kdf_mem, kdf_para) = match User::find_by_mail(&data.email, &mut conn).await {
let (kdf_type, kdf_iter, kdf_mem, kdf_para) = match User::find_by_mail(&data.email, &conn).await {
Some(user) => (user.client_kdf_type, user.client_kdf_iter, user.client_kdf_memory, user.client_kdf_parallelism),
None => (User::CLIENT_KDF_TYPE_DEFAULT, User::CLIENT_KDF_ITER_DEFAULT, None, None),
};
@ -1203,7 +1210,7 @@ struct SecretVerificationRequest {
}
// Change the KDF Iterations if necessary
pub async fn kdf_upgrade(user: &mut User, pwd_hash: &str, conn: &mut DbConn) -> ApiResult<()> {
pub async fn kdf_upgrade(user: &mut User, pwd_hash: &str, conn: &DbConn) -> ApiResult<()> {
if user.password_iterations < CONFIG.password_iterations() {
user.password_iterations = CONFIG.password_iterations();
user.set_password(pwd_hash, None, false, None);
@ -1216,7 +1223,7 @@ pub async fn kdf_upgrade(user: &mut User, pwd_hash: &str, conn: &mut DbConn) ->
}
#[post("/accounts/verify-password", data = "<data>")]
async fn verify_password(data: Json<SecretVerificationRequest>, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn verify_password(data: Json<SecretVerificationRequest>, headers: Headers, conn: DbConn) -> JsonResult {
let data: SecretVerificationRequest = data.into_inner();
let mut user = headers.user;
@ -1224,22 +1231,22 @@ async fn verify_password(data: Json<SecretVerificationRequest>, headers: Headers
err!("Invalid password")
}
kdf_upgrade(&mut user, &data.master_password_hash, &mut conn).await?;
kdf_upgrade(&mut user, &data.master_password_hash, &conn).await?;
Ok(Json(master_password_policy(&user, &conn).await))
}
async fn _api_key(data: Json<PasswordOrOtpData>, rotate: bool, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn _api_key(data: Json<PasswordOrOtpData>, rotate: bool, headers: Headers, conn: DbConn) -> JsonResult {
use crate::util::format_date;
let data: PasswordOrOtpData = data.into_inner();
let mut user = headers.user;
data.validate(&user, true, &mut conn).await?;
data.validate(&user, true, &conn).await?;
if rotate || user.api_key.is_none() {
user.api_key = Some(crypto::generate_api_key());
user.save(&mut conn).await.expect("Error saving API key");
user.save(&conn).await.expect("Error saving API key");
}
Ok(Json(json!({
@ -1260,10 +1267,10 @@ async fn rotate_api_key(data: Json<PasswordOrOtpData>, headers: Headers, conn: D
}
#[get("/devices/knowndevice")]
async fn get_known_device(device: KnownDevice, mut conn: DbConn) -> JsonResult {
async fn get_known_device(device: KnownDevice, conn: DbConn) -> JsonResult {
let mut result = false;
if let Some(user) = User::find_by_mail(&device.email, &mut conn).await {
result = Device::find_by_uuid_and_user(&device.uuid, &user.uuid, &mut conn).await.is_some();
if let Some(user) = User::find_by_mail(&device.email, &conn).await {
result = Device::find_by_uuid_and_user(&device.uuid, &user.uuid, &conn).await.is_some();
}
Ok(Json(json!(result)))
}
@ -1306,8 +1313,8 @@ impl<'r> FromRequest<'r> for KnownDevice {
}
#[get("/devices")]
async fn get_all_devices(headers: Headers, mut conn: DbConn) -> JsonResult {
let devices = Device::find_with_auth_request_by_user(&headers.user.uuid, &mut conn).await;
async fn get_all_devices(headers: Headers, conn: DbConn) -> JsonResult {
let devices = Device::find_with_auth_request_by_user(&headers.user.uuid, &conn).await;
let devices = devices.iter().map(|device| device.to_json()).collect::<Vec<Value>>();
Ok(Json(json!({
@ -1318,8 +1325,8 @@ async fn get_all_devices(headers: Headers, mut conn: DbConn) -> JsonResult {
}
#[get("/devices/identifier/<device_id>")]
async fn get_device(device_id: DeviceId, headers: Headers, mut conn: DbConn) -> JsonResult {
let Some(device) = Device::find_by_uuid_and_user(&device_id, &headers.user.uuid, &mut conn).await else {
async fn get_device(device_id: DeviceId, headers: Headers, conn: DbConn) -> JsonResult {
let Some(device) = Device::find_by_uuid_and_user(&device_id, &headers.user.uuid, &conn).await else {
err!("No device found");
};
Ok(Json(device.to_json()))
@ -1337,17 +1344,11 @@ async fn post_device_token(device_id: DeviceId, data: Json<PushToken>, headers:
}
#[put("/devices/identifier/<device_id>/token", data = "<data>")]
async fn put_device_token(
device_id: DeviceId,
data: Json<PushToken>,
headers: Headers,
mut conn: DbConn,
) -> EmptyResult {
async fn put_device_token(device_id: DeviceId, data: Json<PushToken>, headers: Headers, conn: DbConn) -> EmptyResult {
let data = data.into_inner();
let token = data.push_token;
let Some(mut device) = Device::find_by_uuid_and_user(&headers.device.uuid, &headers.user.uuid, &mut conn).await
else {
let Some(mut device) = Device::find_by_uuid_and_user(&headers.device.uuid, &headers.user.uuid, &conn).await else {
err!(format!("Error: device {device_id} should be present before a token can be assigned"))
};
@ -1360,17 +1361,17 @@ async fn put_device_token(
}
device.push_token = Some(token);
if let Err(e) = device.save(&mut conn).await {
if let Err(e) = device.save(&conn).await {
err!(format!("An error occurred while trying to save the device push token: {e}"));
}
register_push_device(&mut device, &mut conn).await?;
register_push_device(&mut device, &conn).await?;
Ok(())
}
#[put("/devices/identifier/<device_id>/clear-token")]
async fn put_clear_device_token(device_id: DeviceId, mut conn: DbConn) -> EmptyResult {
async fn put_clear_device_token(device_id: DeviceId, conn: DbConn) -> EmptyResult {
// This only clears push token
// https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Api/Controllers/DevicesController.cs#L215
// https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Core/Services/Implementations/DeviceService.cs#L37
@ -1382,8 +1383,8 @@ async fn put_clear_device_token(device_id: DeviceId, mut conn: DbConn) -> EmptyR
return Ok(());
}
if let Some(device) = Device::find_by_uuid(&device_id, &mut conn).await {
Device::clear_push_token_by_uuid(&device_id, &mut conn).await?;
if let Some(device) = Device::find_by_uuid(&device_id, &conn).await {
Device::clear_push_token_by_uuid(&device_id, &conn).await?;
unregister_push_device(&device.push_uuid).await?;
}
@ -1412,17 +1413,17 @@ struct AuthRequestRequest {
async fn post_auth_request(
data: Json<AuthRequestRequest>,
client_headers: ClientHeaders,
mut conn: DbConn,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
let data = data.into_inner();
let Some(user) = User::find_by_mail(&data.email, &mut conn).await else {
let Some(user) = User::find_by_mail(&data.email, &conn).await else {
err!("AuthRequest doesn't exist", "User not found")
};
// Validate device uuid and type
let device = match Device::find_by_uuid_and_user(&data.device_identifier, &user.uuid, &mut conn).await {
let device = match Device::find_by_uuid_and_user(&data.device_identifier, &user.uuid, &conn).await {
Some(device) if device.atype == client_headers.device_type => device,
_ => err!("AuthRequest doesn't exist", "Device verification failed"),
};
@ -1435,16 +1436,16 @@ async fn post_auth_request(
data.access_code,
data.public_key,
);
auth_request.save(&mut conn).await?;
auth_request.save(&conn).await?;
nt.send_auth_request(&user.uuid, &auth_request.uuid, &device, &mut conn).await;
nt.send_auth_request(&user.uuid, &auth_request.uuid, &device, &conn).await;
log_user_event(
EventType::UserRequestedDeviceApproval as i32,
&user.uuid,
client_headers.device_type,
&client_headers.ip.ip,
&mut conn,
&conn,
)
.await;
@ -1464,8 +1465,8 @@ async fn post_auth_request(
}
#[get("/auth-requests/<auth_request_id>")]
async fn get_auth_request(auth_request_id: AuthRequestId, headers: Headers, mut conn: DbConn) -> JsonResult {
let Some(auth_request) = AuthRequest::find_by_uuid_and_user(&auth_request_id, &headers.user.uuid, &mut conn).await
async fn get_auth_request(auth_request_id: AuthRequestId, headers: Headers, conn: DbConn) -> JsonResult {
let Some(auth_request) = AuthRequest::find_by_uuid_and_user(&auth_request_id, &headers.user.uuid, &conn).await
else {
err!("AuthRequest doesn't exist", "Record not found or user uuid does not match")
};
@ -1501,13 +1502,12 @@ async fn put_auth_request(
auth_request_id: AuthRequestId,
data: Json<AuthResponseRequest>,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
ant: AnonymousNotify<'_>,
nt: Notify<'_>,
) -> JsonResult {
let data = data.into_inner();
let Some(mut auth_request) =
AuthRequest::find_by_uuid_and_user(&auth_request_id, &headers.user.uuid, &mut conn).await
let Some(mut auth_request) = AuthRequest::find_by_uuid_and_user(&auth_request_id, &headers.user.uuid, &conn).await
else {
err!("AuthRequest doesn't exist", "Record not found or user uuid does not match")
};
@ -1529,28 +1529,28 @@ async fn put_auth_request(
auth_request.master_password_hash = data.master_password_hash;
auth_request.response_device_id = Some(data.device_identifier.clone());
auth_request.response_date = Some(response_date);
auth_request.save(&mut conn).await?;
auth_request.save(&conn).await?;
ant.send_auth_response(&auth_request.user_uuid, &auth_request.uuid).await;
nt.send_auth_response(&auth_request.user_uuid, &auth_request.uuid, &headers.device, &mut conn).await;
nt.send_auth_response(&auth_request.user_uuid, &auth_request.uuid, &headers.device, &conn).await;
log_user_event(
EventType::OrganizationUserApprovedAuthRequest as i32,
&headers.user.uuid,
headers.device.atype,
&headers.ip.ip,
&mut conn,
&conn,
)
.await;
} else {
// If denied, there's no reason to keep the request
auth_request.delete(&mut conn).await?;
auth_request.delete(&conn).await?;
log_user_event(
EventType::OrganizationUserRejectedAuthRequest as i32,
&headers.user.uuid,
headers.device.atype,
&headers.ip.ip,
&mut conn,
&conn,
)
.await;
}
@ -1575,9 +1575,9 @@ async fn get_auth_request_response(
auth_request_id: AuthRequestId,
code: &str,
client_headers: ClientHeaders,
mut conn: DbConn,
conn: DbConn,
) -> JsonResult {
let Some(auth_request) = AuthRequest::find_by_uuid(&auth_request_id, &mut conn).await else {
let Some(auth_request) = AuthRequest::find_by_uuid(&auth_request_id, &conn).await else {
err!("AuthRequest doesn't exist", "User not found")
};
@ -1606,8 +1606,8 @@ async fn get_auth_request_response(
}
#[get("/auth-requests")]
async fn get_auth_requests(headers: Headers, mut conn: DbConn) -> JsonResult {
let auth_requests = AuthRequest::find_by_user(&headers.user.uuid, &mut conn).await;
async fn get_auth_requests(headers: Headers, conn: DbConn) -> JsonResult {
let auth_requests = AuthRequest::find_by_user(&headers.user.uuid, &conn).await;
Ok(Json(json!({
"data": auth_requests
@ -1637,8 +1637,8 @@ async fn get_auth_requests(headers: Headers, mut conn: DbConn) -> JsonResult {
pub async fn purge_auth_requests(pool: DbPool) {
debug!("Purging auth requests");
if let Ok(mut conn) = pool.get().await {
AuthRequest::purge_expired_auth_requests(&mut conn).await;
if let Ok(conn) = pool.get().await {
AuthRequest::purge_expired_auth_requests(&conn).await;
} else {
error!("Failed to get DB connection while purging trashed ciphers")
}

345
src/api/core/ciphers.rs

@ -17,7 +17,14 @@ use crate::{
auth::Headers,
config::PathType,
crypto,
db::{models::*, DbConn, DbPool},
db::{
models::{
Attachment, AttachmentId, Cipher, CipherId, Collection, CollectionCipher, CollectionGroup, CollectionId,
CollectionUser, EventType, Favorite, Folder, FolderCipher, FolderId, Group, Membership, MembershipType,
OrgPolicy, OrgPolicyType, OrganizationId, RepromptType, Send, UserId,
},
DbConn, DbPool,
},
CONFIG,
};
@ -93,8 +100,8 @@ pub fn routes() -> Vec<Route> {
pub async fn purge_trashed_ciphers(pool: DbPool) {
debug!("Purging trashed ciphers");
if let Ok(mut conn) = pool.get().await {
Cipher::purge_trash(&mut conn).await;
if let Ok(conn) = pool.get().await {
Cipher::purge_trash(&conn).await;
} else {
error!("Failed to get DB connection while purging trashed ciphers")
}
@ -107,11 +114,11 @@ struct SyncData {
}
#[get("/sync?<data..>")]
async fn sync(data: SyncData, headers: Headers, client_version: Option<ClientVersion>, mut conn: DbConn) -> JsonResult {
let user_json = headers.user.to_json(&mut conn).await;
async fn sync(data: SyncData, headers: Headers, client_version: Option<ClientVersion>, conn: DbConn) -> JsonResult {
let user_json = headers.user.to_json(&conn).await;
// Get all ciphers which are visible by the user
let mut ciphers = Cipher::find_by_user_visible(&headers.user.uuid, &mut conn).await;
let mut ciphers = Cipher::find_by_user_visible(&headers.user.uuid, &conn).await;
// Filter out SSH keys if the client version is less than 2024.12.0
let show_ssh_keys = if let Some(client_version) = client_version {
@ -124,31 +131,30 @@ async fn sync(data: SyncData, headers: Headers, client_version: Option<ClientVer
ciphers.retain(|c| c.atype != 5);
}
let cipher_sync_data = CipherSyncData::new(&headers.user.uuid, CipherSyncType::User, &mut conn).await;
let cipher_sync_data = CipherSyncData::new(&headers.user.uuid, CipherSyncType::User, &conn).await;
// Lets generate the ciphers_json using all the gathered info
let mut ciphers_json = Vec::with_capacity(ciphers.len());
for c in ciphers {
ciphers_json.push(
c.to_json(&headers.host, &headers.user.uuid, Some(&cipher_sync_data), CipherSyncType::User, &mut conn)
.await?,
c.to_json(&headers.host, &headers.user.uuid, Some(&cipher_sync_data), CipherSyncType::User, &conn).await?,
);
}
let collections = Collection::find_by_user_uuid(headers.user.uuid.clone(), &mut conn).await;
let collections = Collection::find_by_user_uuid(headers.user.uuid.clone(), &conn).await;
let mut collections_json = Vec::with_capacity(collections.len());
for c in collections {
collections_json.push(c.to_json_details(&headers.user.uuid, Some(&cipher_sync_data), &mut conn).await);
collections_json.push(c.to_json_details(&headers.user.uuid, Some(&cipher_sync_data), &conn).await);
}
let folders_json: Vec<Value> =
Folder::find_by_user(&headers.user.uuid, &mut conn).await.iter().map(Folder::to_json).collect();
Folder::find_by_user(&headers.user.uuid, &conn).await.iter().map(Folder::to_json).collect();
let sends_json: Vec<Value> =
Send::find_by_user(&headers.user.uuid, &mut conn).await.iter().map(Send::to_json).collect();
Send::find_by_user(&headers.user.uuid, &conn).await.iter().map(Send::to_json).collect();
let policies_json: Vec<Value> =
OrgPolicy::find_confirmed_by_user(&headers.user.uuid, &mut conn).await.iter().map(OrgPolicy::to_json).collect();
OrgPolicy::find_confirmed_by_user(&headers.user.uuid, &conn).await.iter().map(OrgPolicy::to_json).collect();
let domains_json = if data.exclude_domains {
Value::Null
@ -169,15 +175,14 @@ async fn sync(data: SyncData, headers: Headers, client_version: Option<ClientVer
}
#[get("/ciphers")]
async fn get_ciphers(headers: Headers, mut conn: DbConn) -> JsonResult {
let ciphers = Cipher::find_by_user_visible(&headers.user.uuid, &mut conn).await;
let cipher_sync_data = CipherSyncData::new(&headers.user.uuid, CipherSyncType::User, &mut conn).await;
async fn get_ciphers(headers: Headers, conn: DbConn) -> JsonResult {
let ciphers = Cipher::find_by_user_visible(&headers.user.uuid, &conn).await;
let cipher_sync_data = CipherSyncData::new(&headers.user.uuid, CipherSyncType::User, &conn).await;
let mut ciphers_json = Vec::with_capacity(ciphers.len());
for c in ciphers {
ciphers_json.push(
c.to_json(&headers.host, &headers.user.uuid, Some(&cipher_sync_data), CipherSyncType::User, &mut conn)
.await?,
c.to_json(&headers.host, &headers.user.uuid, Some(&cipher_sync_data), CipherSyncType::User, &conn).await?,
);
}
@ -189,16 +194,16 @@ async fn get_ciphers(headers: Headers, mut conn: DbConn) -> JsonResult {
}
#[get("/ciphers/<cipher_id>")]
async fn get_cipher(cipher_id: CipherId, headers: Headers, mut conn: DbConn) -> JsonResult {
let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &mut conn).await else {
async fn get_cipher(cipher_id: CipherId, headers: Headers, conn: DbConn) -> JsonResult {
let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &conn).await else {
err!("Cipher doesn't exist")
};
if !cipher.is_accessible_to_user(&headers.user.uuid, &mut conn).await {
if !cipher.is_accessible_to_user(&headers.user.uuid, &conn).await {
err!("Cipher is not owned by user")
}
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await?))
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &conn).await?))
}
#[get("/ciphers/<cipher_id>/admin")]
@ -291,7 +296,7 @@ async fn post_ciphers_admin(data: Json<ShareCipherData>, headers: Headers, conn:
async fn post_ciphers_create(
data: Json<ShareCipherData>,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
let mut data: ShareCipherData = data.into_inner();
@ -305,11 +310,11 @@ async fn post_ciphers_create(
// This check is usually only needed in update_cipher_from_data(), but we
// need it here as well to avoid creating an empty cipher in the call to
// cipher.save() below.
enforce_personal_ownership_policy(Some(&data.cipher), &headers, &mut conn).await?;
enforce_personal_ownership_policy(Some(&data.cipher), &headers, &conn).await?;
let mut cipher = Cipher::new(data.cipher.r#type, data.cipher.name.clone());
cipher.user_uuid = Some(headers.user.uuid.clone());
cipher.save(&mut conn).await?;
cipher.save(&conn).await?;
// When cloning a cipher, the Bitwarden clients seem to set this field
// based on the cipher being cloned (when creating a new cipher, it's set
@ -319,12 +324,12 @@ async fn post_ciphers_create(
// or otherwise), we can just ignore this field entirely.
data.cipher.last_known_revision_date = None;
share_cipher_by_uuid(&cipher.uuid, data, &headers, &mut conn, &nt, None).await
share_cipher_by_uuid(&cipher.uuid, data, &headers, &conn, &nt, None).await
}
/// Called when creating a new user-owned cipher.
#[post("/ciphers", data = "<data>")]
async fn post_ciphers(data: Json<CipherData>, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> JsonResult {
async fn post_ciphers(data: Json<CipherData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
let mut data: CipherData = data.into_inner();
// The web/browser clients set this field to null as expected, but the
@ -334,9 +339,9 @@ async fn post_ciphers(data: Json<CipherData>, headers: Headers, mut conn: DbConn
data.last_known_revision_date = None;
let mut cipher = Cipher::new(data.r#type, data.name.clone());
update_cipher_from_data(&mut cipher, data, &headers, None, &mut conn, &nt, UpdateType::SyncCipherCreate).await?;
update_cipher_from_data(&mut cipher, data, &headers, None, &conn, &nt, UpdateType::SyncCipherCreate).await?;
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await?))
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &conn).await?))
}
/// Enforces the personal ownership policy on user-owned ciphers, if applicable.
@ -346,11 +351,7 @@ async fn post_ciphers(data: Json<CipherData>, headers: Headers, mut conn: DbConn
/// allowed to delete or share such ciphers to an org, however.
///
/// Ref: https://bitwarden.com/help/article/policies/#personal-ownership
async fn enforce_personal_ownership_policy(
data: Option<&CipherData>,
headers: &Headers,
conn: &mut DbConn,
) -> EmptyResult {
async fn enforce_personal_ownership_policy(data: Option<&CipherData>, headers: &Headers, conn: &DbConn) -> EmptyResult {
if data.is_none() || data.unwrap().organization_id.is_none() {
let user_id = &headers.user.uuid;
let policy_type = OrgPolicyType::PersonalOwnership;
@ -366,7 +367,7 @@ pub async fn update_cipher_from_data(
data: CipherData,
headers: &Headers,
shared_to_collections: Option<Vec<CollectionId>>,
conn: &mut DbConn,
conn: &DbConn,
nt: &Notify<'_>,
ut: UpdateType,
) -> EmptyResult {
@ -559,13 +560,8 @@ struct RelationsData {
}
#[post("/ciphers/import", data = "<data>")]
async fn post_ciphers_import(
data: Json<ImportData>,
headers: Headers,
mut conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
enforce_personal_ownership_policy(None, &headers, &mut conn).await?;
async fn post_ciphers_import(data: Json<ImportData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
enforce_personal_ownership_policy(None, &headers, &conn).await?;
let data: ImportData = data.into_inner();
@ -577,14 +573,14 @@ async fn post_ciphers_import(
// Read and create the folders
let existing_folders: HashSet<Option<FolderId>> =
Folder::find_by_user(&headers.user.uuid, &mut conn).await.into_iter().map(|f| Some(f.uuid)).collect();
Folder::find_by_user(&headers.user.uuid, &conn).await.into_iter().map(|f| Some(f.uuid)).collect();
let mut folders: Vec<FolderId> = Vec::with_capacity(data.folders.len());
for folder in data.folders.into_iter() {
let folder_id = if existing_folders.contains(&folder.id) {
folder.id.unwrap()
} else {
let mut new_folder = Folder::new(headers.user.uuid.clone(), folder.name);
new_folder.save(&mut conn).await?;
new_folder.save(&conn).await?;
new_folder.uuid
};
@ -604,12 +600,12 @@ async fn post_ciphers_import(
cipher_data.folder_id = folder_id;
let mut cipher = Cipher::new(cipher_data.r#type, cipher_data.name.clone());
update_cipher_from_data(&mut cipher, cipher_data, &headers, None, &mut conn, &nt, UpdateType::None).await?;
update_cipher_from_data(&mut cipher, cipher_data, &headers, None, &conn, &nt, UpdateType::None).await?;
}
let mut user = headers.user;
user.update_revision(&mut conn).await?;
nt.send_user_update(UpdateType::SyncVault, &user, &headers.device.push_uuid, &mut conn).await;
user.update_revision(&conn).await?;
nt.send_user_update(UpdateType::SyncVault, &user, &headers.device.push_uuid, &conn).await;
Ok(())
}
@ -653,12 +649,12 @@ async fn put_cipher(
cipher_id: CipherId,
data: Json<CipherData>,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
let data: CipherData = data.into_inner();
let Some(mut cipher) = Cipher::find_by_uuid(&cipher_id, &mut conn).await else {
let Some(mut cipher) = Cipher::find_by_uuid(&cipher_id, &conn).await else {
err!("Cipher doesn't exist")
};
@ -667,13 +663,13 @@ async fn put_cipher(
// cipher itself, so the user shouldn't need write access to change these.
// Interestingly, upstream Bitwarden doesn't properly handle this either.
if !cipher.is_write_accessible_to_user(&headers.user.uuid, &mut conn).await {
if !cipher.is_write_accessible_to_user(&headers.user.uuid, &conn).await {
err!("Cipher is not write accessible")
}
update_cipher_from_data(&mut cipher, data, &headers, None, &mut conn, &nt, UpdateType::SyncCipherUpdate).await?;
update_cipher_from_data(&mut cipher, data, &headers, None, &conn, &nt, UpdateType::SyncCipherUpdate).await?;
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await?))
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &conn).await?))
}
#[post("/ciphers/<cipher_id>/partial", data = "<data>")]
@ -692,26 +688,26 @@ async fn put_cipher_partial(
cipher_id: CipherId,
data: Json<PartialCipherData>,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
) -> JsonResult {
let data: PartialCipherData = data.into_inner();
let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &mut conn).await else {
let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &conn).await else {
err!("Cipher doesn't exist")
};
if let Some(ref folder_id) = data.folder_id {
if Folder::find_by_uuid_and_user(folder_id, &headers.user.uuid, &mut conn).await.is_none() {
if Folder::find_by_uuid_and_user(folder_id, &headers.user.uuid, &conn).await.is_none() {
err!("Invalid folder", "Folder does not exist or belongs to another user");
}
}
// Move cipher
cipher.move_to_folder(data.folder_id.clone(), &headers.user.uuid, &mut conn).await?;
cipher.move_to_folder(data.folder_id.clone(), &headers.user.uuid, &conn).await?;
// Update favorite
cipher.set_favorite(Some(data.favorite), &headers.user.uuid, &mut conn).await?;
cipher.set_favorite(Some(data.favorite), &headers.user.uuid, &conn).await?;
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await?))
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &conn).await?))
}
#[derive(Deserialize)]
@ -764,35 +760,34 @@ async fn post_collections_update(
cipher_id: CipherId,
data: Json<CollectionsAdminData>,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
let data: CollectionsAdminData = data.into_inner();
let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &mut conn).await else {
let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &conn).await else {
err!("Cipher doesn't exist")
};
if !cipher.is_write_accessible_to_user(&headers.user.uuid, &mut conn).await {
if !cipher.is_write_accessible_to_user(&headers.user.uuid, &conn).await {
err!("Cipher is not write accessible")
}
let posted_collections = HashSet::<CollectionId>::from_iter(data.collection_ids);
let current_collections =
HashSet::<CollectionId>::from_iter(cipher.get_collections(headers.user.uuid.clone(), &mut conn).await);
HashSet::<CollectionId>::from_iter(cipher.get_collections(headers.user.uuid.clone(), &conn).await);
for collection in posted_collections.symmetric_difference(&current_collections) {
match Collection::find_by_uuid_and_org(collection, cipher.organization_uuid.as_ref().unwrap(), &mut conn).await
{
match Collection::find_by_uuid_and_org(collection, cipher.organization_uuid.as_ref().unwrap(), &conn).await {
None => err!("Invalid collection ID provided"),
Some(collection) => {
if collection.is_writable_by_user(&headers.user.uuid, &mut conn).await {
if collection.is_writable_by_user(&headers.user.uuid, &conn).await {
if posted_collections.contains(&collection.uuid) {
// Add to collection
CollectionCipher::save(&cipher.uuid, &collection.uuid, &mut conn).await?;
CollectionCipher::save(&cipher.uuid, &collection.uuid, &conn).await?;
} else {
// Remove from collection
CollectionCipher::delete(&cipher.uuid, &collection.uuid, &mut conn).await?;
CollectionCipher::delete(&cipher.uuid, &collection.uuid, &conn).await?;
}
} else {
err!("No rights to modify the collection")
@ -804,10 +799,10 @@ async fn post_collections_update(
nt.send_cipher_update(
UpdateType::SyncCipherUpdate,
&cipher,
&cipher.update_users_revision(&mut conn).await,
&cipher.update_users_revision(&conn).await,
&headers.device,
Some(Vec::from_iter(posted_collections)),
&mut conn,
&conn,
)
.await;
@ -818,11 +813,11 @@ async fn post_collections_update(
&headers.user.uuid,
headers.device.atype,
&headers.ip.ip,
&mut conn,
&conn,
)
.await;
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await?))
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &conn).await?))
}
#[put("/ciphers/<cipher_id>/collections-admin", data = "<data>")]
@ -841,35 +836,34 @@ async fn post_collections_admin(
cipher_id: CipherId,
data: Json<CollectionsAdminData>,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
let data: CollectionsAdminData = data.into_inner();
let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &mut conn).await else {
let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &conn).await else {
err!("Cipher doesn't exist")
};
if !cipher.is_write_accessible_to_user(&headers.user.uuid, &mut conn).await {
if !cipher.is_write_accessible_to_user(&headers.user.uuid, &conn).await {
err!("Cipher is not write accessible")
}
let posted_collections = HashSet::<CollectionId>::from_iter(data.collection_ids);
let current_collections =
HashSet::<CollectionId>::from_iter(cipher.get_admin_collections(headers.user.uuid.clone(), &mut conn).await);
HashSet::<CollectionId>::from_iter(cipher.get_admin_collections(headers.user.uuid.clone(), &conn).await);
for collection in posted_collections.symmetric_difference(&current_collections) {
match Collection::find_by_uuid_and_org(collection, cipher.organization_uuid.as_ref().unwrap(), &mut conn).await
{
match Collection::find_by_uuid_and_org(collection, cipher.organization_uuid.as_ref().unwrap(), &conn).await {
None => err!("Invalid collection ID provided"),
Some(collection) => {
if collection.is_writable_by_user(&headers.user.uuid, &mut conn).await {
if collection.is_writable_by_user(&headers.user.uuid, &conn).await {
if posted_collections.contains(&collection.uuid) {
// Add to collection
CollectionCipher::save(&cipher.uuid, &collection.uuid, &mut conn).await?;
CollectionCipher::save(&cipher.uuid, &collection.uuid, &conn).await?;
} else {
// Remove from collection
CollectionCipher::delete(&cipher.uuid, &collection.uuid, &mut conn).await?;
CollectionCipher::delete(&cipher.uuid, &collection.uuid, &conn).await?;
}
} else {
err!("No rights to modify the collection")
@ -881,10 +875,10 @@ async fn post_collections_admin(
nt.send_cipher_update(
UpdateType::SyncCipherUpdate,
&cipher,
&cipher.update_users_revision(&mut conn).await,
&cipher.update_users_revision(&conn).await,
&headers.device,
Some(Vec::from_iter(posted_collections)),
&mut conn,
&conn,
)
.await;
@ -895,7 +889,7 @@ async fn post_collections_admin(
&headers.user.uuid,
headers.device.atype,
&headers.ip.ip,
&mut conn,
&conn,
)
.await;
@ -916,12 +910,12 @@ async fn post_cipher_share(
cipher_id: CipherId,
data: Json<ShareCipherData>,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
let data: ShareCipherData = data.into_inner();
share_cipher_by_uuid(&cipher_id, data, &headers, &mut conn, &nt, None).await
share_cipher_by_uuid(&cipher_id, data, &headers, &conn, &nt, None).await
}
#[put("/ciphers/<cipher_id>/share", data = "<data>")]
@ -929,12 +923,12 @@ async fn put_cipher_share(
cipher_id: CipherId,
data: Json<ShareCipherData>,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
let data: ShareCipherData = data.into_inner();
share_cipher_by_uuid(&cipher_id, data, &headers, &mut conn, &nt, None).await
share_cipher_by_uuid(&cipher_id, data, &headers, &conn, &nt, None).await
}
#[derive(Deserialize)]
@ -948,7 +942,7 @@ struct ShareSelectedCipherData {
async fn put_cipher_share_selected(
data: Json<ShareSelectedCipherData>,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
let mut data: ShareSelectedCipherData = data.into_inner();
@ -975,14 +969,14 @@ async fn put_cipher_share_selected(
match shared_cipher_data.cipher.id.take() {
Some(id) => {
share_cipher_by_uuid(&id, shared_cipher_data, &headers, &mut conn, &nt, Some(UpdateType::None)).await?
share_cipher_by_uuid(&id, shared_cipher_data, &headers, &conn, &nt, Some(UpdateType::None)).await?
}
None => err!("Request missing ids field"),
};
}
// Multi share actions do not send out a push for each cipher, we need to send a general sync here
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, &headers.device.push_uuid, &mut conn).await;
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, &headers.device.push_uuid, &conn).await;
Ok(())
}
@ -991,7 +985,7 @@ async fn share_cipher_by_uuid(
cipher_id: &CipherId,
data: ShareCipherData,
headers: &Headers,
conn: &mut DbConn,
conn: &DbConn,
nt: &Notify<'_>,
override_ut: Option<UpdateType>,
) -> JsonResult {
@ -1050,17 +1044,17 @@ async fn get_attachment(
cipher_id: CipherId,
attachment_id: AttachmentId,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
) -> JsonResult {
let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &mut conn).await else {
let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &conn).await else {
err!("Cipher doesn't exist")
};
if !cipher.is_accessible_to_user(&headers.user.uuid, &mut conn).await {
if !cipher.is_accessible_to_user(&headers.user.uuid, &conn).await {
err!("Cipher is not accessible")
}
match Attachment::find_by_id(&attachment_id, &mut conn).await {
match Attachment::find_by_id(&attachment_id, &conn).await {
Some(attachment) if cipher_id == attachment.cipher_uuid => Ok(Json(attachment.to_json(&headers.host).await?)),
Some(_) => err!("Attachment doesn't belong to cipher"),
None => err!("Attachment doesn't exist"),
@ -1090,13 +1084,13 @@ async fn post_attachment_v2(
cipher_id: CipherId,
data: Json<AttachmentRequestData>,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
) -> JsonResult {
let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &mut conn).await else {
let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &conn).await else {
err!("Cipher doesn't exist")
};
if !cipher.is_write_accessible_to_user(&headers.user.uuid, &mut conn).await {
if !cipher.is_write_accessible_to_user(&headers.user.uuid, &conn).await {
err!("Cipher is not write accessible")
}
@ -1109,7 +1103,7 @@ async fn post_attachment_v2(
let attachment_id = crypto::generate_attachment_id();
let attachment =
Attachment::new(attachment_id.clone(), cipher.uuid.clone(), data.file_name, file_size, Some(data.key));
attachment.save(&mut conn).await.expect("Error saving attachment");
attachment.save(&conn).await.expect("Error saving attachment");
let url = format!("/ciphers/{}/attachment/{attachment_id}", cipher.uuid);
let response_key = match data.admin_request {
@ -1122,7 +1116,7 @@ async fn post_attachment_v2(
"attachmentId": attachment_id,
"url": url,
"fileUploadType": FileUploadType::Direct as i32,
response_key: cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await?,
response_key: cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &conn).await?,
})))
}
@ -1145,7 +1139,7 @@ async fn save_attachment(
cipher_id: CipherId,
data: Form<UploadData<'_>>,
headers: &Headers,
mut conn: DbConn,
conn: DbConn,
nt: Notify<'_>,
) -> Result<(Cipher, DbConn), crate::error::Error> {
let data = data.into_inner();
@ -1157,11 +1151,11 @@ async fn save_attachment(
err!("Attachment size can't be negative")
}
let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &mut conn).await else {
let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &conn).await else {
err!("Cipher doesn't exist")
};
if !cipher.is_write_accessible_to_user(&headers.user.uuid, &mut conn).await {
if !cipher.is_write_accessible_to_user(&headers.user.uuid, &conn).await {
err!("Cipher is not write accessible")
}
@ -1176,7 +1170,7 @@ async fn save_attachment(
match CONFIG.user_attachment_limit() {
Some(0) => err!("Attachments are disabled"),
Some(limit_kb) => {
let already_used = Attachment::size_by_user(user_id, &mut conn).await;
let already_used = Attachment::size_by_user(user_id, &conn).await;
let left = limit_kb
.checked_mul(1024)
.and_then(|l| l.checked_sub(already_used))
@ -1198,7 +1192,7 @@ async fn save_attachment(
match CONFIG.org_attachment_limit() {
Some(0) => err!("Attachments are disabled"),
Some(limit_kb) => {
let already_used = Attachment::size_by_org(org_id, &mut conn).await;
let already_used = Attachment::size_by_org(org_id, &conn).await;
let left = limit_kb
.checked_mul(1024)
.and_then(|l| l.checked_sub(already_used))
@ -1249,10 +1243,10 @@ async fn save_attachment(
if size != attachment.file_size {
// Update the attachment with the actual file size.
attachment.file_size = size;
attachment.save(&mut conn).await.expect("Error updating attachment");
attachment.save(&conn).await.expect("Error updating attachment");
}
} else {
attachment.delete(&mut conn).await.ok();
attachment.delete(&conn).await.ok();
err!(format!("Attachment size mismatch (expected within [{min_size}, {max_size}], got {size})"));
}
@ -1272,7 +1266,7 @@ async fn save_attachment(
}
let attachment =
Attachment::new(file_id.clone(), cipher_id.clone(), encrypted_filename.unwrap(), size, data.key);
attachment.save(&mut conn).await.expect("Error saving attachment");
attachment.save(&conn).await.expect("Error saving attachment");
}
save_temp_file(PathType::Attachments, &format!("{cipher_id}/{file_id}"), data.data, true).await?;
@ -1280,10 +1274,10 @@ async fn save_attachment(
nt.send_cipher_update(
UpdateType::SyncCipherUpdate,
&cipher,
&cipher.update_users_revision(&mut conn).await,
&cipher.update_users_revision(&conn).await,
&headers.device,
None,
&mut conn,
&conn,
)
.await;
@ -1295,7 +1289,7 @@ async fn save_attachment(
&headers.user.uuid,
headers.device.atype,
&headers.ip.ip,
&mut conn,
&conn,
)
.await;
}
@ -1313,10 +1307,10 @@ async fn post_attachment_v2_data(
attachment_id: AttachmentId,
data: Form<UploadData<'_>>,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
let attachment = match Attachment::find_by_id(&attachment_id, &mut conn).await {
let attachment = match Attachment::find_by_id(&attachment_id, &conn).await {
Some(attachment) if cipher_id == attachment.cipher_uuid => Some(attachment),
Some(_) => err!("Attachment doesn't belong to cipher"),
None => err!("Attachment doesn't exist"),
@ -1340,9 +1334,9 @@ async fn post_attachment(
// the attachment database record as well as saving the data to disk.
let attachment = None;
let (cipher, mut conn) = save_attachment(attachment, cipher_id, data, &headers, conn, nt).await?;
let (cipher, conn) = save_attachment(attachment, cipher_id, data, &headers, conn, nt).await?;
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await?))
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &conn).await?))
}
#[post("/ciphers/<cipher_id>/attachment-admin", format = "multipart/form-data", data = "<data>")]
@ -1362,10 +1356,10 @@ async fn post_attachment_share(
attachment_id: AttachmentId,
data: Form<UploadData<'_>>,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
_delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &mut conn, &nt).await?;
_delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await?;
post_attachment(cipher_id, data, headers, conn, nt).await
}
@ -1396,10 +1390,10 @@ async fn delete_attachment(
cipher_id: CipherId,
attachment_id: AttachmentId,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
_delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &mut conn, &nt).await
_delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await
}
#[delete("/ciphers/<cipher_id>/attachment/<attachment_id>/admin")]
@ -1407,55 +1401,45 @@ async fn delete_attachment_admin(
cipher_id: CipherId,
attachment_id: AttachmentId,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
_delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &mut conn, &nt).await
_delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await
}
#[post("/ciphers/<cipher_id>/delete")]
async fn delete_cipher_post(cipher_id: CipherId, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &mut conn, &CipherDeleteOptions::HardSingle, &nt).await
async fn delete_cipher_post(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
// permanent delete
}
#[post("/ciphers/<cipher_id>/delete-admin")]
async fn delete_cipher_post_admin(
cipher_id: CipherId,
headers: Headers,
mut conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &mut conn, &CipherDeleteOptions::HardSingle, &nt).await
async fn delete_cipher_post_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
// permanent delete
}
#[put("/ciphers/<cipher_id>/delete")]
async fn delete_cipher_put(cipher_id: CipherId, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &mut conn, &CipherDeleteOptions::SoftSingle, &nt).await
async fn delete_cipher_put(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::SoftSingle, &nt).await
// soft delete
}
#[put("/ciphers/<cipher_id>/delete-admin")]
async fn delete_cipher_put_admin(
cipher_id: CipherId,
headers: Headers,
mut conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &mut conn, &CipherDeleteOptions::SoftSingle, &nt).await
async fn delete_cipher_put_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::SoftSingle, &nt).await
// soft delete
}
#[delete("/ciphers/<cipher_id>")]
async fn delete_cipher(cipher_id: CipherId, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &mut conn, &CipherDeleteOptions::HardSingle, &nt).await
async fn delete_cipher(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
// permanent delete
}
#[delete("/ciphers/<cipher_id>/admin")]
async fn delete_cipher_admin(cipher_id: CipherId, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &mut conn, &CipherDeleteOptions::HardSingle, &nt).await
async fn delete_cipher_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
// permanent delete
}
@ -1526,38 +1510,33 @@ async fn delete_cipher_selected_put_admin(
}
#[put("/ciphers/<cipher_id>/restore")]
async fn restore_cipher_put(cipher_id: CipherId, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> JsonResult {
_restore_cipher_by_uuid(&cipher_id, &headers, false, &mut conn, &nt).await
async fn restore_cipher_put(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
_restore_cipher_by_uuid(&cipher_id, &headers, false, &conn, &nt).await
}
#[put("/ciphers/<cipher_id>/restore-admin")]
async fn restore_cipher_put_admin(
cipher_id: CipherId,
headers: Headers,
mut conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
_restore_cipher_by_uuid(&cipher_id, &headers, false, &mut conn, &nt).await
async fn restore_cipher_put_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
_restore_cipher_by_uuid(&cipher_id, &headers, false, &conn, &nt).await
}
#[put("/ciphers/restore-admin", data = "<data>")]
async fn restore_cipher_selected_admin(
data: Json<CipherIdsData>,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
_restore_multiple_ciphers(data, &headers, &mut conn, &nt).await
_restore_multiple_ciphers(data, &headers, &conn, &nt).await
}
#[put("/ciphers/restore", data = "<data>")]
async fn restore_cipher_selected(
data: Json<CipherIdsData>,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
_restore_multiple_ciphers(data, &headers, &mut conn, &nt).await
_restore_multiple_ciphers(data, &headers, &conn, &nt).await
}
#[derive(Deserialize)]
@ -1571,14 +1550,14 @@ struct MoveCipherData {
async fn move_cipher_selected(
data: Json<MoveCipherData>,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
let data = data.into_inner();
let user_id = &headers.user.uuid;
if let Some(ref folder_id) = data.folder_id {
if Folder::find_by_uuid_and_user(folder_id, user_id, &mut conn).await.is_none() {
if Folder::find_by_uuid_and_user(folder_id, user_id, &conn).await.is_none() {
err!("Invalid folder", "Folder does not exist or belongs to another user");
}
}
@ -1588,10 +1567,10 @@ async fn move_cipher_selected(
// TODO: Convert this to use a single query (or at least less) to update all items
// Find all ciphers a user has access to, all others will be ignored
let accessible_ciphers = Cipher::find_by_user_and_ciphers(user_id, &data.ids, &mut conn).await;
let accessible_ciphers = Cipher::find_by_user_and_ciphers(user_id, &data.ids, &conn).await;
let accessible_ciphers_count = accessible_ciphers.len();
for cipher in accessible_ciphers {
cipher.move_to_folder(data.folder_id.clone(), user_id, &mut conn).await?;
cipher.move_to_folder(data.folder_id.clone(), user_id, &conn).await?;
if cipher_count == 1 {
single_cipher = Some(cipher);
}
@ -1604,12 +1583,12 @@ async fn move_cipher_selected(
std::slice::from_ref(user_id),
&headers.device,
None,
&mut conn,
&conn,
)
.await;
} else {
// Multi move actions do not send out a push for each cipher, we need to send a general sync here
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, &headers.device.push_uuid, &mut conn).await;
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, &headers.device.push_uuid, &conn).await;
}
if cipher_count != accessible_ciphers_count {
@ -1642,23 +1621,23 @@ async fn delete_all(
organization: Option<OrganizationIdData>,
data: Json<PasswordOrOtpData>,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
let data: PasswordOrOtpData = data.into_inner();
let mut user = headers.user;
data.validate(&user, true, &mut conn).await?;
data.validate(&user, true, &conn).await?;
match organization {
Some(org_data) => {
// Organization ID in query params, purging organization vault
match Membership::find_by_user_and_org(&user.uuid, &org_data.org_id, &mut conn).await {
match Membership::find_by_user_and_org(&user.uuid, &org_data.org_id, &conn).await {
None => err!("You don't have permission to purge the organization vault"),
Some(member) => {
if member.atype == MembershipType::Owner {
Cipher::delete_all_by_organization(&org_data.org_id, &mut conn).await?;
nt.send_user_update(UpdateType::SyncVault, &user, &headers.device.push_uuid, &mut conn).await;
Cipher::delete_all_by_organization(&org_data.org_id, &conn).await?;
nt.send_user_update(UpdateType::SyncVault, &user, &headers.device.push_uuid, &conn).await;
log_event(
EventType::OrganizationPurgedVault as i32,
@ -1667,7 +1646,7 @@ async fn delete_all(
&user.uuid,
headers.device.atype,
&headers.ip.ip,
&mut conn,
&conn,
)
.await;
@ -1681,17 +1660,17 @@ async fn delete_all(
None => {
// No organization ID in query params, purging user vault
// Delete ciphers and their attachments
for cipher in Cipher::find_owned_by_user(&user.uuid, &mut conn).await {
cipher.delete(&mut conn).await?;
for cipher in Cipher::find_owned_by_user(&user.uuid, &conn).await {
cipher.delete(&conn).await?;
}
// Delete folders
for f in Folder::find_by_user(&user.uuid, &mut conn).await {
f.delete(&mut conn).await?;
for f in Folder::find_by_user(&user.uuid, &conn).await {
f.delete(&conn).await?;
}
user.update_revision(&mut conn).await?;
nt.send_user_update(UpdateType::SyncVault, &user, &headers.device.push_uuid, &mut conn).await;
user.update_revision(&conn).await?;
nt.send_user_update(UpdateType::SyncVault, &user, &headers.device.push_uuid, &conn).await;
Ok(())
}
@ -1709,7 +1688,7 @@ pub enum CipherDeleteOptions {
async fn _delete_cipher_by_uuid(
cipher_id: &CipherId,
headers: &Headers,
conn: &mut DbConn,
conn: &DbConn,
delete_options: &CipherDeleteOptions,
nt: &Notify<'_>,
) -> EmptyResult {
@ -1775,20 +1754,20 @@ struct CipherIdsData {
async fn _delete_multiple_ciphers(
data: Json<CipherIdsData>,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
delete_options: CipherDeleteOptions,
nt: Notify<'_>,
) -> EmptyResult {
let data = data.into_inner();
for cipher_id in data.ids {
if let error @ Err(_) = _delete_cipher_by_uuid(&cipher_id, &headers, &mut conn, &delete_options, &nt).await {
if let error @ Err(_) = _delete_cipher_by_uuid(&cipher_id, &headers, &conn, &delete_options, &nt).await {
return error;
};
}
// Multi delete actions do not send out a push for each cipher, we need to send a general sync here
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, &headers.device.push_uuid, &mut conn).await;
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, &headers.device.push_uuid, &conn).await;
Ok(())
}
@ -1797,7 +1776,7 @@ async fn _restore_cipher_by_uuid(
cipher_id: &CipherId,
headers: &Headers,
multi_restore: bool,
conn: &mut DbConn,
conn: &DbConn,
nt: &Notify<'_>,
) -> JsonResult {
let Some(mut cipher) = Cipher::find_by_uuid(cipher_id, conn).await else {
@ -1842,7 +1821,7 @@ async fn _restore_cipher_by_uuid(
async fn _restore_multiple_ciphers(
data: Json<CipherIdsData>,
headers: &Headers,
conn: &mut DbConn,
conn: &DbConn,
nt: &Notify<'_>,
) -> JsonResult {
let data = data.into_inner();
@ -1869,7 +1848,7 @@ async fn _delete_cipher_attachment_by_id(
cipher_id: &CipherId,
attachment_id: &AttachmentId,
headers: &Headers,
conn: &mut DbConn,
conn: &DbConn,
nt: &Notify<'_>,
) -> JsonResult {
let Some(attachment) = Attachment::find_by_id(attachment_id, conn).await else {
@ -1938,7 +1917,7 @@ pub enum CipherSyncType {
}
impl CipherSyncData {
pub async fn new(user_id: &UserId, sync_type: CipherSyncType, conn: &mut DbConn) -> Self {
pub async fn new(user_id: &UserId, sync_type: CipherSyncType, conn: &DbConn) -> Self {
let cipher_folders: HashMap<CipherId, FolderId>;
let cipher_favorites: HashSet<CipherId>;
match sync_type {

172
src/api/core/emergency_access.rs

@ -8,7 +8,13 @@ use crate::{
EmptyResult, JsonResult,
},
auth::{decode_emergency_access_invite, Headers},
db::{models::*, DbConn, DbPool},
db::{
models::{
Cipher, EmergencyAccess, EmergencyAccessId, EmergencyAccessStatus, EmergencyAccessType, Invitation,
Membership, MembershipType, OrgPolicy, TwoFactor, User, UserId,
},
DbConn, DbPool,
},
mail,
util::NumberOrString,
CONFIG,
@ -40,7 +46,7 @@ pub fn routes() -> Vec<Route> {
// region get
#[get("/emergency-access/trusted")]
async fn get_contacts(headers: Headers, mut conn: DbConn) -> Json<Value> {
async fn get_contacts(headers: Headers, conn: DbConn) -> Json<Value> {
if !CONFIG.emergency_access_allowed() {
return Json(json!({
"data": [{
@ -58,10 +64,10 @@ async fn get_contacts(headers: Headers, mut conn: DbConn) -> Json<Value> {
"continuationToken": null
}));
}
let emergency_access_list = EmergencyAccess::find_all_by_grantor_uuid(&headers.user.uuid, &mut conn).await;
let emergency_access_list = EmergencyAccess::find_all_by_grantor_uuid(&headers.user.uuid, &conn).await;
let mut emergency_access_list_json = Vec::with_capacity(emergency_access_list.len());
for ea in emergency_access_list {
if let Some(grantee) = ea.to_json_grantee_details(&mut conn).await {
if let Some(grantee) = ea.to_json_grantee_details(&conn).await {
emergency_access_list_json.push(grantee)
}
}
@ -74,15 +80,15 @@ async fn get_contacts(headers: Headers, mut conn: DbConn) -> Json<Value> {
}
#[get("/emergency-access/granted")]
async fn get_grantees(headers: Headers, mut conn: DbConn) -> Json<Value> {
async fn get_grantees(headers: Headers, conn: DbConn) -> Json<Value> {
let emergency_access_list = if CONFIG.emergency_access_allowed() {
EmergencyAccess::find_all_by_grantee_uuid(&headers.user.uuid, &mut conn).await
EmergencyAccess::find_all_by_grantee_uuid(&headers.user.uuid, &conn).await
} else {
Vec::new()
};
let mut emergency_access_list_json = Vec::with_capacity(emergency_access_list.len());
for ea in emergency_access_list {
emergency_access_list_json.push(ea.to_json_grantor_details(&mut conn).await);
emergency_access_list_json.push(ea.to_json_grantor_details(&conn).await);
}
Json(json!({
@ -93,12 +99,12 @@ async fn get_grantees(headers: Headers, mut conn: DbConn) -> Json<Value> {
}
#[get("/emergency-access/<emer_id>")]
async fn get_emergency_access(emer_id: EmergencyAccessId, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn get_emergency_access(emer_id: EmergencyAccessId, headers: Headers, conn: DbConn) -> JsonResult {
check_emergency_access_enabled()?;
match EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &mut conn).await {
match EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &conn).await {
Some(emergency_access) => Ok(Json(
emergency_access.to_json_grantee_details(&mut conn).await.expect("Grantee user should exist but does not!"),
emergency_access.to_json_grantee_details(&conn).await.expect("Grantee user should exist but does not!"),
)),
None => err!("Emergency access not valid."),
}
@ -131,14 +137,14 @@ async fn post_emergency_access(
emer_id: EmergencyAccessId,
data: Json<EmergencyAccessUpdateData>,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
) -> JsonResult {
check_emergency_access_enabled()?;
let data: EmergencyAccessUpdateData = data.into_inner();
let Some(mut emergency_access) =
EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &mut conn).await
EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &conn).await
else {
err!("Emergency access not valid.")
};
@ -154,7 +160,7 @@ async fn post_emergency_access(
emergency_access.key_encrypted = data.key_encrypted;
}
emergency_access.save(&mut conn).await?;
emergency_access.save(&conn).await?;
Ok(Json(emergency_access.to_json()))
}
@ -163,12 +169,12 @@ async fn post_emergency_access(
// region delete
#[delete("/emergency-access/<emer_id>")]
async fn delete_emergency_access(emer_id: EmergencyAccessId, headers: Headers, mut conn: DbConn) -> EmptyResult {
async fn delete_emergency_access(emer_id: EmergencyAccessId, headers: Headers, conn: DbConn) -> EmptyResult {
check_emergency_access_enabled()?;
let emergency_access = match (
EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &mut conn).await,
EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &headers.user.uuid, &mut conn).await,
EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &conn).await,
EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &headers.user.uuid, &conn).await,
) {
(Some(grantor_emer), None) => {
info!("Grantor deleted emergency access {emer_id}");
@ -181,7 +187,7 @@ async fn delete_emergency_access(emer_id: EmergencyAccessId, headers: Headers, m
_ => err!("Emergency access not valid."),
};
emergency_access.delete(&mut conn).await?;
emergency_access.delete(&conn).await?;
Ok(())
}
@ -203,7 +209,7 @@ struct EmergencyAccessInviteData {
}
#[post("/emergency-access/invite", data = "<data>")]
async fn send_invite(data: Json<EmergencyAccessInviteData>, headers: Headers, mut conn: DbConn) -> EmptyResult {
async fn send_invite(data: Json<EmergencyAccessInviteData>, headers: Headers, conn: DbConn) -> EmptyResult {
check_emergency_access_enabled()?;
let data: EmergencyAccessInviteData = data.into_inner();
@ -224,7 +230,7 @@ async fn send_invite(data: Json<EmergencyAccessInviteData>, headers: Headers, mu
err!("You can not set yourself as an emergency contact.")
}
let (grantee_user, new_user) = match User::find_by_mail(&email, &mut conn).await {
let (grantee_user, new_user) = match User::find_by_mail(&email, &conn).await {
None => {
if !CONFIG.invitations_allowed() {
err!(format!("Grantee user does not exist: {email}"))
@ -236,11 +242,11 @@ async fn send_invite(data: Json<EmergencyAccessInviteData>, headers: Headers, mu
if !CONFIG.mail_enabled() {
let invitation = Invitation::new(&email);
invitation.save(&mut conn).await?;
invitation.save(&conn).await?;
}
let mut user = User::new(email.clone(), None);
user.save(&mut conn).await?;
user.save(&conn).await?;
(user, true)
}
Some(user) if user.password_hash.is_empty() => (user, true),
@ -251,7 +257,7 @@ async fn send_invite(data: Json<EmergencyAccessInviteData>, headers: Headers, mu
&grantor_user.uuid,
&grantee_user.uuid,
&grantee_user.email,
&mut conn,
&conn,
)
.await
.is_some()
@ -261,7 +267,7 @@ async fn send_invite(data: Json<EmergencyAccessInviteData>, headers: Headers, mu
let mut new_emergency_access =
EmergencyAccess::new(grantor_user.uuid, grantee_user.email, emergency_access_status, new_type, wait_time_days);
new_emergency_access.save(&mut conn).await?;
new_emergency_access.save(&conn).await?;
if CONFIG.mail_enabled() {
mail::send_emergency_access_invite(
@ -274,18 +280,18 @@ async fn send_invite(data: Json<EmergencyAccessInviteData>, headers: Headers, mu
.await?;
} else if !new_user {
// if mail is not enabled immediately accept the invitation for existing users
new_emergency_access.accept_invite(&grantee_user.uuid, &email, &mut conn).await?;
new_emergency_access.accept_invite(&grantee_user.uuid, &email, &conn).await?;
}
Ok(())
}
#[post("/emergency-access/<emer_id>/reinvite")]
async fn resend_invite(emer_id: EmergencyAccessId, headers: Headers, mut conn: DbConn) -> EmptyResult {
async fn resend_invite(emer_id: EmergencyAccessId, headers: Headers, conn: DbConn) -> EmptyResult {
check_emergency_access_enabled()?;
let Some(mut emergency_access) =
EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &mut conn).await
EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &conn).await
else {
err!("Emergency access not valid.")
};
@ -298,7 +304,7 @@ async fn resend_invite(emer_id: EmergencyAccessId, headers: Headers, mut conn: D
err!("Email not valid.")
};
let Some(grantee_user) = User::find_by_mail(&email, &mut conn).await else {
let Some(grantee_user) = User::find_by_mail(&email, &conn).await else {
err!("Grantee user not found.")
};
@ -315,10 +321,10 @@ async fn resend_invite(emer_id: EmergencyAccessId, headers: Headers, mut conn: D
.await?;
} else if !grantee_user.password_hash.is_empty() {
// accept the invitation for existing user
emergency_access.accept_invite(&grantee_user.uuid, &email, &mut conn).await?;
} else if CONFIG.invitations_allowed() && Invitation::find_by_mail(&email, &mut conn).await.is_none() {
emergency_access.accept_invite(&grantee_user.uuid, &email, &conn).await?;
} else if CONFIG.invitations_allowed() && Invitation::find_by_mail(&email, &conn).await.is_none() {
let invitation = Invitation::new(&email);
invitation.save(&mut conn).await?;
invitation.save(&conn).await?;
}
Ok(())
@ -335,7 +341,7 @@ async fn accept_invite(
emer_id: EmergencyAccessId,
data: Json<AcceptData>,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
) -> EmptyResult {
check_emergency_access_enabled()?;
@ -349,9 +355,9 @@ async fn accept_invite(
err!("Claim email does not match current users email")
}
let grantee_user = match User::find_by_mail(&claims.email, &mut conn).await {
let grantee_user = match User::find_by_mail(&claims.email, &conn).await {
Some(user) => {
Invitation::take(&claims.email, &mut conn).await;
Invitation::take(&claims.email, &conn).await;
user
}
None => err!("Invited user not found"),
@ -360,13 +366,13 @@ async fn accept_invite(
// We need to search for the uuid in combination with the email, since we do not yet store the uuid of the grantee in the database.
// The uuid of the grantee gets stored once accepted.
let Some(mut emergency_access) =
EmergencyAccess::find_by_uuid_and_grantee_email(&emer_id, &headers.user.email, &mut conn).await
EmergencyAccess::find_by_uuid_and_grantee_email(&emer_id, &headers.user.email, &conn).await
else {
err!("Emergency access not valid.")
};
// get grantor user to send Accepted email
let Some(grantor_user) = User::find_by_uuid(&emergency_access.grantor_uuid, &mut conn).await else {
let Some(grantor_user) = User::find_by_uuid(&emergency_access.grantor_uuid, &conn).await else {
err!("Grantor user not found.")
};
@ -374,7 +380,7 @@ async fn accept_invite(
&& grantor_user.name == claims.grantor_name
&& grantor_user.email == claims.grantor_email
{
emergency_access.accept_invite(&grantee_user.uuid, &grantee_user.email, &mut conn).await?;
emergency_access.accept_invite(&grantee_user.uuid, &grantee_user.email, &conn).await?;
if CONFIG.mail_enabled() {
mail::send_emergency_access_invite_accepted(&grantor_user.email, &grantee_user.email).await?;
@ -397,7 +403,7 @@ async fn confirm_emergency_access(
emer_id: EmergencyAccessId,
data: Json<ConfirmData>,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
) -> JsonResult {
check_emergency_access_enabled()?;
@ -406,7 +412,7 @@ async fn confirm_emergency_access(
let key = data.key;
let Some(mut emergency_access) =
EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &confirming_user.uuid, &mut conn).await
EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &confirming_user.uuid, &conn).await
else {
err!("Emergency access not valid.")
};
@ -417,12 +423,12 @@ async fn confirm_emergency_access(
err!("Emergency access not valid.")
}
let Some(grantor_user) = User::find_by_uuid(&confirming_user.uuid, &mut conn).await else {
let Some(grantor_user) = User::find_by_uuid(&confirming_user.uuid, &conn).await else {
err!("Grantor user not found.")
};
if let Some(grantee_uuid) = emergency_access.grantee_uuid.as_ref() {
let Some(grantee_user) = User::find_by_uuid(grantee_uuid, &mut conn).await else {
let Some(grantee_user) = User::find_by_uuid(grantee_uuid, &conn).await else {
err!("Grantee user not found.")
};
@ -430,7 +436,7 @@ async fn confirm_emergency_access(
emergency_access.key_encrypted = Some(key);
emergency_access.email = None;
emergency_access.save(&mut conn).await?;
emergency_access.save(&conn).await?;
if CONFIG.mail_enabled() {
mail::send_emergency_access_invite_confirmed(&grantee_user.email, &grantor_user.name).await?;
@ -446,12 +452,12 @@ async fn confirm_emergency_access(
// region access emergency access
#[post("/emergency-access/<emer_id>/initiate")]
async fn initiate_emergency_access(emer_id: EmergencyAccessId, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn initiate_emergency_access(emer_id: EmergencyAccessId, headers: Headers, conn: DbConn) -> JsonResult {
check_emergency_access_enabled()?;
let initiating_user = headers.user;
let Some(mut emergency_access) =
EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &initiating_user.uuid, &mut conn).await
EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &initiating_user.uuid, &conn).await
else {
err!("Emergency access not valid.")
};
@ -460,7 +466,7 @@ async fn initiate_emergency_access(emer_id: EmergencyAccessId, headers: Headers,
err!("Emergency access not valid.")
}
let Some(grantor_user) = User::find_by_uuid(&emergency_access.grantor_uuid, &mut conn).await else {
let Some(grantor_user) = User::find_by_uuid(&emergency_access.grantor_uuid, &conn).await else {
err!("Grantor user not found.")
};
@ -469,7 +475,7 @@ async fn initiate_emergency_access(emer_id: EmergencyAccessId, headers: Headers,
emergency_access.updated_at = now;
emergency_access.recovery_initiated_at = Some(now);
emergency_access.last_notification_at = Some(now);
emergency_access.save(&mut conn).await?;
emergency_access.save(&conn).await?;
if CONFIG.mail_enabled() {
mail::send_emergency_access_recovery_initiated(
@ -484,11 +490,11 @@ async fn initiate_emergency_access(emer_id: EmergencyAccessId, headers: Headers,
}
#[post("/emergency-access/<emer_id>/approve")]
async fn approve_emergency_access(emer_id: EmergencyAccessId, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn approve_emergency_access(emer_id: EmergencyAccessId, headers: Headers, conn: DbConn) -> JsonResult {
check_emergency_access_enabled()?;
let Some(mut emergency_access) =
EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &mut conn).await
EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &conn).await
else {
err!("Emergency access not valid.")
};
@ -497,17 +503,17 @@ async fn approve_emergency_access(emer_id: EmergencyAccessId, headers: Headers,
err!("Emergency access not valid.")
}
let Some(grantor_user) = User::find_by_uuid(&headers.user.uuid, &mut conn).await else {
let Some(grantor_user) = User::find_by_uuid(&headers.user.uuid, &conn).await else {
err!("Grantor user not found.")
};
if let Some(grantee_uuid) = emergency_access.grantee_uuid.as_ref() {
let Some(grantee_user) = User::find_by_uuid(grantee_uuid, &mut conn).await else {
let Some(grantee_user) = User::find_by_uuid(grantee_uuid, &conn).await else {
err!("Grantee user not found.")
};
emergency_access.status = EmergencyAccessStatus::RecoveryApproved as i32;
emergency_access.save(&mut conn).await?;
emergency_access.save(&conn).await?;
if CONFIG.mail_enabled() {
mail::send_emergency_access_recovery_approved(&grantee_user.email, &grantor_user.name).await?;
@ -519,11 +525,11 @@ async fn approve_emergency_access(emer_id: EmergencyAccessId, headers: Headers,
}
#[post("/emergency-access/<emer_id>/reject")]
async fn reject_emergency_access(emer_id: EmergencyAccessId, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn reject_emergency_access(emer_id: EmergencyAccessId, headers: Headers, conn: DbConn) -> JsonResult {
check_emergency_access_enabled()?;
let Some(mut emergency_access) =
EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &mut conn).await
EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &conn).await
else {
err!("Emergency access not valid.")
};
@ -535,12 +541,12 @@ async fn reject_emergency_access(emer_id: EmergencyAccessId, headers: Headers, m
}
if let Some(grantee_uuid) = emergency_access.grantee_uuid.as_ref() {
let Some(grantee_user) = User::find_by_uuid(grantee_uuid, &mut conn).await else {
let Some(grantee_user) = User::find_by_uuid(grantee_uuid, &conn).await else {
err!("Grantee user not found.")
};
emergency_access.status = EmergencyAccessStatus::Confirmed as i32;
emergency_access.save(&mut conn).await?;
emergency_access.save(&conn).await?;
if CONFIG.mail_enabled() {
mail::send_emergency_access_recovery_rejected(&grantee_user.email, &headers.user.name).await?;
@ -556,11 +562,11 @@ async fn reject_emergency_access(emer_id: EmergencyAccessId, headers: Headers, m
// region action
#[post("/emergency-access/<emer_id>/view")]
async fn view_emergency_access(emer_id: EmergencyAccessId, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn view_emergency_access(emer_id: EmergencyAccessId, headers: Headers, conn: DbConn) -> JsonResult {
check_emergency_access_enabled()?;
let Some(emergency_access) =
EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &headers.user.uuid, &mut conn).await
EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &headers.user.uuid, &conn).await
else {
err!("Emergency access not valid.")
};
@ -569,8 +575,8 @@ async fn view_emergency_access(emer_id: EmergencyAccessId, headers: Headers, mut
err!("Emergency access not valid.")
}
let ciphers = Cipher::find_owned_by_user(&emergency_access.grantor_uuid, &mut conn).await;
let cipher_sync_data = CipherSyncData::new(&emergency_access.grantor_uuid, CipherSyncType::User, &mut conn).await;
let ciphers = Cipher::find_owned_by_user(&emergency_access.grantor_uuid, &conn).await;
let cipher_sync_data = CipherSyncData::new(&emergency_access.grantor_uuid, CipherSyncType::User, &conn).await;
let mut ciphers_json = Vec::with_capacity(ciphers.len());
for c in ciphers {
@ -580,7 +586,7 @@ async fn view_emergency_access(emer_id: EmergencyAccessId, headers: Headers, mut
&emergency_access.grantor_uuid,
Some(&cipher_sync_data),
CipherSyncType::User,
&mut conn,
&conn,
)
.await?,
);
@ -594,12 +600,12 @@ async fn view_emergency_access(emer_id: EmergencyAccessId, headers: Headers, mut
}
#[post("/emergency-access/<emer_id>/takeover")]
async fn takeover_emergency_access(emer_id: EmergencyAccessId, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn takeover_emergency_access(emer_id: EmergencyAccessId, headers: Headers, conn: DbConn) -> JsonResult {
check_emergency_access_enabled()?;
let requesting_user = headers.user;
let Some(emergency_access) =
EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &requesting_user.uuid, &mut conn).await
EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &requesting_user.uuid, &conn).await
else {
err!("Emergency access not valid.")
};
@ -608,7 +614,7 @@ async fn takeover_emergency_access(emer_id: EmergencyAccessId, headers: Headers,
err!("Emergency access not valid.")
}
let Some(grantor_user) = User::find_by_uuid(&emergency_access.grantor_uuid, &mut conn).await else {
let Some(grantor_user) = User::find_by_uuid(&emergency_access.grantor_uuid, &conn).await else {
err!("Grantor user not found.")
};
@ -636,7 +642,7 @@ async fn password_emergency_access(
emer_id: EmergencyAccessId,
data: Json<EmergencyAccessPasswordData>,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
) -> EmptyResult {
check_emergency_access_enabled()?;
@ -646,7 +652,7 @@ async fn password_emergency_access(
let requesting_user = headers.user;
let Some(emergency_access) =
EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &requesting_user.uuid, &mut conn).await
EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &requesting_user.uuid, &conn).await
else {
err!("Emergency access not valid.")
};
@ -655,21 +661,21 @@ async fn password_emergency_access(
err!("Emergency access not valid.")
}
let Some(mut grantor_user) = User::find_by_uuid(&emergency_access.grantor_uuid, &mut conn).await else {
let Some(mut grantor_user) = User::find_by_uuid(&emergency_access.grantor_uuid, &conn).await else {
err!("Grantor user not found.")
};
// change grantor_user password
grantor_user.set_password(new_master_password_hash, Some(data.key), true, None);
grantor_user.save(&mut conn).await?;
grantor_user.save(&conn).await?;
// Disable TwoFactor providers since they will otherwise block logins
TwoFactor::delete_all_by_user(&grantor_user.uuid, &mut conn).await?;
TwoFactor::delete_all_by_user(&grantor_user.uuid, &conn).await?;
// Remove grantor from all organisations unless Owner
for member in Membership::find_any_state_by_user(&grantor_user.uuid, &mut conn).await {
for member in Membership::find_any_state_by_user(&grantor_user.uuid, &conn).await {
if member.atype != MembershipType::Owner as i32 {
member.delete(&mut conn).await?;
member.delete(&conn).await?;
}
}
Ok(())
@ -678,10 +684,10 @@ async fn password_emergency_access(
// endregion
#[get("/emergency-access/<emer_id>/policies")]
async fn policies_emergency_access(emer_id: EmergencyAccessId, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn policies_emergency_access(emer_id: EmergencyAccessId, headers: Headers, conn: DbConn) -> JsonResult {
let requesting_user = headers.user;
let Some(emergency_access) =
EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &requesting_user.uuid, &mut conn).await
EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &requesting_user.uuid, &conn).await
else {
err!("Emergency access not valid.")
};
@ -690,11 +696,11 @@ async fn policies_emergency_access(emer_id: EmergencyAccessId, headers: Headers,
err!("Emergency access not valid.")
}
let Some(grantor_user) = User::find_by_uuid(&emergency_access.grantor_uuid, &mut conn).await else {
let Some(grantor_user) = User::find_by_uuid(&emergency_access.grantor_uuid, &conn).await else {
err!("Grantor user not found.")
};
let policies = OrgPolicy::find_confirmed_by_user(&grantor_user.uuid, &mut conn);
let policies = OrgPolicy::find_confirmed_by_user(&grantor_user.uuid, &conn);
let policies_json: Vec<Value> = policies.await.iter().map(OrgPolicy::to_json).collect();
Ok(Json(json!({
@ -728,8 +734,8 @@ pub async fn emergency_request_timeout_job(pool: DbPool) {
return;
}
if let Ok(mut conn) = pool.get().await {
let emergency_access_list = EmergencyAccess::find_all_recoveries_initiated(&mut conn).await;
if let Ok(conn) = pool.get().await {
let emergency_access_list = EmergencyAccess::find_all_recoveries_initiated(&conn).await;
if emergency_access_list.is_empty() {
debug!("No emergency request timeout to approve");
@ -743,18 +749,18 @@ pub async fn emergency_request_timeout_job(pool: DbPool) {
if recovery_allowed_at.le(&now) {
// Only update the access status
// Updating the whole record could cause issues when the emergency_notification_reminder_job is also active
emer.update_access_status_and_save(EmergencyAccessStatus::RecoveryApproved as i32, &now, &mut conn)
emer.update_access_status_and_save(EmergencyAccessStatus::RecoveryApproved as i32, &now, &conn)
.await
.expect("Unable to update emergency access status");
if CONFIG.mail_enabled() {
// get grantor user to send Accepted email
let grantor_user =
User::find_by_uuid(&emer.grantor_uuid, &mut conn).await.expect("Grantor user not found");
User::find_by_uuid(&emer.grantor_uuid, &conn).await.expect("Grantor user not found");
// get grantee user to send Accepted email
let grantee_user =
User::find_by_uuid(&emer.grantee_uuid.clone().expect("Grantee user invalid"), &mut conn)
User::find_by_uuid(&emer.grantee_uuid.clone().expect("Grantee user invalid"), &conn)
.await
.expect("Grantee user not found");
@ -783,8 +789,8 @@ pub async fn emergency_notification_reminder_job(pool: DbPool) {
return;
}
if let Ok(mut conn) = pool.get().await {
let emergency_access_list = EmergencyAccess::find_all_recoveries_initiated(&mut conn).await;
if let Ok(conn) = pool.get().await {
let emergency_access_list = EmergencyAccess::find_all_recoveries_initiated(&conn).await;
if emergency_access_list.is_empty() {
debug!("No emergency request reminder notification to send");
@ -805,18 +811,18 @@ pub async fn emergency_notification_reminder_job(pool: DbPool) {
if final_recovery_reminder_at.le(&now) && next_recovery_reminder_at.le(&now) {
// Only update the last notification date
// Updating the whole record could cause issues when the emergency_request_timeout_job is also active
emer.update_last_notification_date_and_save(&now, &mut conn)
emer.update_last_notification_date_and_save(&now, &conn)
.await
.expect("Unable to update emergency access notification date");
if CONFIG.mail_enabled() {
// get grantor user to send Accepted email
let grantor_user =
User::find_by_uuid(&emer.grantor_uuid, &mut conn).await.expect("Grantor user not found");
User::find_by_uuid(&emer.grantor_uuid, &conn).await.expect("Grantor user not found");
// get grantee user to send Accepted email
let grantee_user =
User::find_by_uuid(&emer.grantee_uuid.clone().expect("Grantee user invalid"), &mut conn)
User::find_by_uuid(&emer.grantee_uuid.clone().expect("Grantee user invalid"), &conn)
.await
.expect("Grantee user not found");

41
src/api/core/events.rs

@ -31,12 +31,7 @@ struct EventRange {
// Upstream: https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Api/AdminConsole/Controllers/EventsController.cs#L87
#[get("/organizations/<org_id>/events?<data..>")]
async fn get_org_events(
org_id: OrganizationId,
data: EventRange,
headers: AdminHeaders,
mut conn: DbConn,
) -> JsonResult {
async fn get_org_events(org_id: OrganizationId, data: EventRange, headers: AdminHeaders, conn: DbConn) -> JsonResult {
if org_id != headers.org_id {
err!("Organization not found", "Organization id's do not match");
}
@ -53,7 +48,7 @@ async fn get_org_events(
parse_date(&data.end)
};
Event::find_by_organization_uuid(&org_id, &start_date, &end_date, &mut conn)
Event::find_by_organization_uuid(&org_id, &start_date, &end_date, &conn)
.await
.iter()
.map(|e| e.to_json())
@ -68,14 +63,14 @@ async fn get_org_events(
}
#[get("/ciphers/<cipher_id>/events?<data..>")]
async fn get_cipher_events(cipher_id: CipherId, data: EventRange, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn get_cipher_events(cipher_id: CipherId, data: EventRange, headers: Headers, conn: DbConn) -> JsonResult {
// Return an empty vec when we org events are disabled.
// This prevents client errors
let events_json: Vec<Value> = if !CONFIG.org_events_enabled() {
Vec::with_capacity(0)
} else {
let mut events_json = Vec::with_capacity(0);
if Membership::user_has_ge_admin_access_to_cipher(&headers.user.uuid, &cipher_id, &mut conn).await {
if Membership::user_has_ge_admin_access_to_cipher(&headers.user.uuid, &cipher_id, &conn).await {
let start_date = parse_date(&data.start);
let end_date = if let Some(before_date) = &data.continuation_token {
parse_date(before_date)
@ -83,7 +78,7 @@ async fn get_cipher_events(cipher_id: CipherId, data: EventRange, headers: Heade
parse_date(&data.end)
};
events_json = Event::find_by_cipher_uuid(&cipher_id, &start_date, &end_date, &mut conn)
events_json = Event::find_by_cipher_uuid(&cipher_id, &start_date, &end_date, &conn)
.await
.iter()
.map(|e| e.to_json())
@ -105,7 +100,7 @@ async fn get_user_events(
member_id: MembershipId,
data: EventRange,
headers: AdminHeaders,
mut conn: DbConn,
conn: DbConn,
) -> JsonResult {
if org_id != headers.org_id {
err!("Organization not found", "Organization id's do not match");
@ -122,7 +117,7 @@ async fn get_user_events(
parse_date(&data.end)
};
Event::find_by_org_and_member(&org_id, &member_id, &start_date, &end_date, &mut conn)
Event::find_by_org_and_member(&org_id, &member_id, &start_date, &end_date, &conn)
.await
.iter()
.map(|e| e.to_json())
@ -172,7 +167,7 @@ struct EventCollection {
// https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Events/Controllers/CollectController.cs
// https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Core/AdminConsole/Services/Implementations/EventService.cs
#[post("/collect", format = "application/json", data = "<data>")]
async fn post_events_collect(data: Json<Vec<EventCollection>>, headers: Headers, mut conn: DbConn) -> EmptyResult {
async fn post_events_collect(data: Json<Vec<EventCollection>>, headers: Headers, conn: DbConn) -> EmptyResult {
if !CONFIG.org_events_enabled() {
return Ok(());
}
@ -187,7 +182,7 @@ async fn post_events_collect(data: Json<Vec<EventCollection>>, headers: Headers,
headers.device.atype,
Some(event_date),
&headers.ip.ip,
&mut conn,
&conn,
)
.await;
}
@ -201,14 +196,14 @@ async fn post_events_collect(data: Json<Vec<EventCollection>>, headers: Headers,
headers.device.atype,
Some(event_date),
&headers.ip.ip,
&mut conn,
&conn,
)
.await;
}
}
_ => {
if let Some(cipher_uuid) = &event.cipher_id {
if let Some(cipher) = Cipher::find_by_uuid(cipher_uuid, &mut conn).await {
if let Some(cipher) = Cipher::find_by_uuid(cipher_uuid, &conn).await {
if let Some(org_id) = cipher.organization_uuid {
_log_event(
event.r#type,
@ -218,7 +213,7 @@ async fn post_events_collect(data: Json<Vec<EventCollection>>, headers: Headers,
headers.device.atype,
Some(event_date),
&headers.ip.ip,
&mut conn,
&conn,
)
.await;
}
@ -230,7 +225,7 @@ async fn post_events_collect(data: Json<Vec<EventCollection>>, headers: Headers,
Ok(())
}
pub async fn log_user_event(event_type: i32, user_id: &UserId, device_type: i32, ip: &IpAddr, conn: &mut DbConn) {
pub async fn log_user_event(event_type: i32, user_id: &UserId, device_type: i32, ip: &IpAddr, conn: &DbConn) {
if !CONFIG.org_events_enabled() {
return;
}
@ -243,7 +238,7 @@ async fn _log_user_event(
device_type: i32,
event_date: Option<NaiveDateTime>,
ip: &IpAddr,
conn: &mut DbConn,
conn: &DbConn,
) {
let memberships = Membership::find_by_user(user_id, conn).await;
let mut events: Vec<Event> = Vec::with_capacity(memberships.len() + 1); // We need an event per org and one without an org
@ -278,7 +273,7 @@ pub async fn log_event(
act_user_id: &UserId,
device_type: i32,
ip: &IpAddr,
conn: &mut DbConn,
conn: &DbConn,
) {
if !CONFIG.org_events_enabled() {
return;
@ -295,7 +290,7 @@ async fn _log_event(
device_type: i32,
event_date: Option<NaiveDateTime>,
ip: &IpAddr,
conn: &mut DbConn,
conn: &DbConn,
) {
// Create a new empty event
let mut event = Event::new(event_type, event_date);
@ -340,8 +335,8 @@ pub async fn event_cleanup_job(pool: DbPool) {
return;
}
if let Ok(mut conn) = pool.get().await {
Event::clean_events(&mut conn).await.ok();
if let Ok(conn) = pool.get().await {
Event::clean_events(&conn).await.ok();
} else {
error!("Failed to get DB connection while trying to cleanup the events table")
}

35
src/api/core/folders.rs

@ -4,7 +4,10 @@ use serde_json::Value;
use crate::{
api::{EmptyResult, JsonResult, Notify, UpdateType},
auth::Headers,
db::{models::*, DbConn},
db::{
models::{Folder, FolderId},
DbConn,
},
};
pub fn routes() -> Vec<rocket::Route> {
@ -12,8 +15,8 @@ pub fn routes() -> Vec<rocket::Route> {
}
#[get("/folders")]
async fn get_folders(headers: Headers, mut conn: DbConn) -> Json<Value> {
let folders = Folder::find_by_user(&headers.user.uuid, &mut conn).await;
async fn get_folders(headers: Headers, conn: DbConn) -> Json<Value> {
let folders = Folder::find_by_user(&headers.user.uuid, &conn).await;
let folders_json: Vec<Value> = folders.iter().map(Folder::to_json).collect();
Json(json!({
@ -24,8 +27,8 @@ async fn get_folders(headers: Headers, mut conn: DbConn) -> Json<Value> {
}
#[get("/folders/<folder_id>")]
async fn get_folder(folder_id: FolderId, headers: Headers, mut conn: DbConn) -> JsonResult {
match Folder::find_by_uuid_and_user(&folder_id, &headers.user.uuid, &mut conn).await {
async fn get_folder(folder_id: FolderId, headers: Headers, conn: DbConn) -> JsonResult {
match Folder::find_by_uuid_and_user(&folder_id, &headers.user.uuid, &conn).await {
Some(folder) => Ok(Json(folder.to_json())),
_ => err!("Invalid folder", "Folder does not exist or belongs to another user"),
}
@ -39,13 +42,13 @@ pub struct FolderData {
}
#[post("/folders", data = "<data>")]
async fn post_folders(data: Json<FolderData>, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> JsonResult {
async fn post_folders(data: Json<FolderData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
let data: FolderData = data.into_inner();
let mut folder = Folder::new(headers.user.uuid, data.name);
folder.save(&mut conn).await?;
nt.send_folder_update(UpdateType::SyncFolderCreate, &folder, &headers.device, &mut conn).await;
folder.save(&conn).await?;
nt.send_folder_update(UpdateType::SyncFolderCreate, &folder, &headers.device, &conn).await;
Ok(Json(folder.to_json()))
}
@ -66,19 +69,19 @@ async fn put_folder(
folder_id: FolderId,
data: Json<FolderData>,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
let data: FolderData = data.into_inner();
let Some(mut folder) = Folder::find_by_uuid_and_user(&folder_id, &headers.user.uuid, &mut conn).await else {
let Some(mut folder) = Folder::find_by_uuid_and_user(&folder_id, &headers.user.uuid, &conn).await else {
err!("Invalid folder", "Folder does not exist or belongs to another user")
};
folder.name = data.name;
folder.save(&mut conn).await?;
nt.send_folder_update(UpdateType::SyncFolderUpdate, &folder, &headers.device, &mut conn).await;
folder.save(&conn).await?;
nt.send_folder_update(UpdateType::SyncFolderUpdate, &folder, &headers.device, &conn).await;
Ok(Json(folder.to_json()))
}
@ -89,14 +92,14 @@ async fn delete_folder_post(folder_id: FolderId, headers: Headers, conn: DbConn,
}
#[delete("/folders/<folder_id>")]
async fn delete_folder(folder_id: FolderId, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let Some(folder) = Folder::find_by_uuid_and_user(&folder_id, &headers.user.uuid, &mut conn).await else {
async fn delete_folder(folder_id: FolderId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let Some(folder) = Folder::find_by_uuid_and_user(&folder_id, &headers.user.uuid, &conn).await else {
err!("Invalid folder", "Folder does not exist or belongs to another user")
};
// Delete the actual folder entry
folder.delete(&mut conn).await?;
folder.delete(&conn).await?;
nt.send_folder_update(UpdateType::SyncFolderDelete, &folder, &headers.device, &mut conn).await;
nt.send_folder_update(UpdateType::SyncFolderDelete, &folder, &headers.device, &conn).await;
Ok(())
}

18
src/api/core/mod.rs

@ -52,7 +52,10 @@ use rocket::{serde::json::Json, serde::json::Value, Catcher, Route};
use crate::{
api::{EmptyResult, JsonResult, Notify, UpdateType},
auth::Headers,
db::{models::*, DbConn},
db::{
models::{Membership, MembershipStatus, MembershipType, OrgPolicy, OrgPolicyErr, Organization, User},
DbConn,
},
error::Error,
http_client::make_http_request,
mail,
@ -106,12 +109,7 @@ struct EquivDomainData {
}
#[post("/settings/domains", data = "<data>")]
async fn post_eq_domains(
data: Json<EquivDomainData>,
headers: Headers,
mut conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
async fn post_eq_domains(data: Json<EquivDomainData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
let data: EquivDomainData = data.into_inner();
let excluded_globals = data.excluded_global_equivalent_domains.unwrap_or_default();
@ -123,9 +121,9 @@ async fn post_eq_domains(
user.excluded_globals = to_string(&excluded_globals).unwrap_or_else(|_| "[]".to_string());
user.equivalent_domains = to_string(&equivalent_domains).unwrap_or_else(|_| "[]".to_string());
user.save(&mut conn).await?;
user.save(&conn).await?;
nt.send_user_update(UpdateType::SyncSettings, &user, &headers.device.push_uuid, &mut conn).await;
nt.send_user_update(UpdateType::SyncSettings, &user, &headers.device.push_uuid, &conn).await;
Ok(Json(json!({})))
}
@ -265,7 +263,7 @@ async fn accept_org_invite(
user: &User,
mut member: Membership,
reset_password_key: Option<String>,
conn: &mut DbConn,
conn: &DbConn,
) -> EmptyResult {
if member.status != MembershipStatus::Invited as i32 {
err!("User already accepted the invitation");

659
src/api/core/organizations.rs

File diff suppressed because it is too large

53
src/api/core/public.rs

@ -10,7 +10,13 @@ use std::collections::HashSet;
use crate::{
api::EmptyResult,
auth,
db::{models::*, DbConn},
db::{
models::{
Group, GroupUser, Invitation, Membership, MembershipStatus, MembershipType, Organization,
OrganizationApiKey, OrganizationId, User,
},
DbConn,
},
mail, CONFIG,
};
@ -44,7 +50,7 @@ struct OrgImportData {
}
#[post("/public/organization/import", data = "<data>")]
async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, mut conn: DbConn) -> EmptyResult {
async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, conn: DbConn) -> EmptyResult {
// Most of the logic for this function can be found here
// https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs#L1203
@ -55,13 +61,12 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, mut conn: Db
let mut user_created: bool = false;
if user_data.deleted {
// If user is marked for deletion and it exists, revoke it
if let Some(mut member) = Membership::find_by_email_and_org(&user_data.email, &org_id, &mut conn).await {
if let Some(mut member) = Membership::find_by_email_and_org(&user_data.email, &org_id, &conn).await {
// Only revoke a user if it is not the last confirmed owner
let revoked = if member.atype == MembershipType::Owner
&& member.status == MembershipStatus::Confirmed as i32
{
if Membership::count_confirmed_by_org_and_type(&org_id, MembershipType::Owner, &mut conn).await <= 1
{
if Membership::count_confirmed_by_org_and_type(&org_id, MembershipType::Owner, &conn).await <= 1 {
warn!("Can't revoke the last owner");
false
} else {
@ -73,27 +78,27 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, mut conn: Db
let ext_modified = member.set_external_id(Some(user_data.external_id.clone()));
if revoked || ext_modified {
member.save(&mut conn).await?;
member.save(&conn).await?;
}
}
// If user is part of the organization, restore it
} else if let Some(mut member) = Membership::find_by_email_and_org(&user_data.email, &org_id, &mut conn).await {
} else if let Some(mut member) = Membership::find_by_email_and_org(&user_data.email, &org_id, &conn).await {
let restored = member.restore();
let ext_modified = member.set_external_id(Some(user_data.external_id.clone()));
if restored || ext_modified {
member.save(&mut conn).await?;
member.save(&conn).await?;
}
} else {
// If user is not part of the organization
let user = match User::find_by_mail(&user_data.email, &mut conn).await {
let user = match User::find_by_mail(&user_data.email, &conn).await {
Some(user) => user, // exists in vaultwarden
None => {
// User does not exist yet
let mut new_user = User::new(user_data.email.clone(), None);
new_user.save(&mut conn).await?;
new_user.save(&conn).await?;
if !CONFIG.mail_enabled() {
Invitation::new(&new_user.email).save(&mut conn).await?;
Invitation::new(&new_user.email).save(&conn).await?;
}
user_created = true;
new_user
@ -105,7 +110,7 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, mut conn: Db
MembershipStatus::Accepted as i32 // Automatically mark user as accepted if no email invites
};
let (org_name, org_email) = match Organization::find_by_uuid(&org_id, &mut conn).await {
let (org_name, org_email) = match Organization::find_by_uuid(&org_id, &conn).await {
Some(org) => (org.name, org.billing_email),
None => err!("Error looking up organization"),
};
@ -116,7 +121,7 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, mut conn: Db
new_member.atype = MembershipType::User as i32;
new_member.status = member_status;
new_member.save(&mut conn).await?;
new_member.save(&conn).await?;
if CONFIG.mail_enabled() {
if let Err(e) =
@ -124,9 +129,9 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, mut conn: Db
{
// Upon error delete the user, invite and org member records when needed
if user_created {
user.delete(&mut conn).await?;
user.delete(&conn).await?;
} else {
new_member.delete(&mut conn).await?;
new_member.delete(&conn).await?;
}
err!(format!("Error sending invite: {e:?} "));
@ -137,8 +142,7 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, mut conn: Db
if CONFIG.org_groups_enabled() {
for group_data in &data.groups {
let group_uuid = match Group::find_by_external_id_and_org(&group_data.external_id, &org_id, &mut conn).await
{
let group_uuid = match Group::find_by_external_id_and_org(&group_data.external_id, &org_id, &conn).await {
Some(group) => group.uuid,
None => {
let mut group = Group::new(
@ -147,17 +151,17 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, mut conn: Db
false,
Some(group_data.external_id.clone()),
);
group.save(&mut conn).await?;
group.save(&conn).await?;
group.uuid
}
};
GroupUser::delete_all_by_group(&group_uuid, &mut conn).await?;
GroupUser::delete_all_by_group(&group_uuid, &conn).await?;
for ext_id in &group_data.member_external_ids {
if let Some(member) = Membership::find_by_external_id_and_org(ext_id, &org_id, &mut conn).await {
if let Some(member) = Membership::find_by_external_id_and_org(ext_id, &org_id, &conn).await {
let mut group_user = GroupUser::new(group_uuid.clone(), member.uuid.clone());
group_user.save(&mut conn).await?;
group_user.save(&conn).await?;
}
}
}
@ -169,19 +173,18 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, mut conn: Db
if data.overwrite_existing {
// Generate a HashSet to quickly verify if a member is listed or not.
let sync_members: HashSet<String> = data.members.into_iter().map(|m| m.external_id).collect();
for member in Membership::find_by_org(&org_id, &mut conn).await {
for member in Membership::find_by_org(&org_id, &conn).await {
if let Some(ref user_external_id) = member.external_id {
if !sync_members.contains(user_external_id) {
if member.atype == MembershipType::Owner && member.status == MembershipStatus::Confirmed as i32 {
// Removing owner, check that there is at least one other confirmed owner
if Membership::count_confirmed_by_org_and_type(&org_id, MembershipType::Owner, &mut conn).await
<= 1
if Membership::count_confirmed_by_org_and_type(&org_id, MembershipType::Owner, &conn).await <= 1
{
warn!("Can't delete the last owner");
continue;
}
}
member.delete(&mut conn).await?;
member.delete(&conn).await?;
}
}
}

129
src/api/core/sends.rs

@ -14,7 +14,10 @@ use crate::{
api::{ApiResult, EmptyResult, JsonResult, Notify, UpdateType},
auth::{ClientIp, Headers, Host},
config::PathType,
db::{models::*, DbConn, DbPool},
db::{
models::{Device, OrgPolicy, OrgPolicyType, Send, SendFileId, SendId, SendType, UserId},
DbConn, DbPool,
},
util::{save_temp_file, NumberOrString},
CONFIG,
};
@ -58,8 +61,8 @@ pub fn routes() -> Vec<rocket::Route> {
pub async fn purge_sends(pool: DbPool) {
debug!("Purging sends");
if let Ok(mut conn) = pool.get().await {
Send::purge(&mut conn).await;
if let Ok(conn) = pool.get().await {
Send::purge(&conn).await;
} else {
error!("Failed to get DB connection while purging sends")
}
@ -96,7 +99,7 @@ pub struct SendData {
///
/// There is also a Vaultwarden-specific `sends_allowed` config setting that
/// controls this policy globally.
async fn enforce_disable_send_policy(headers: &Headers, conn: &mut DbConn) -> EmptyResult {
async fn enforce_disable_send_policy(headers: &Headers, conn: &DbConn) -> EmptyResult {
let user_id = &headers.user.uuid;
if !CONFIG.sends_allowed()
|| OrgPolicy::is_applicable_to_user(user_id, OrgPolicyType::DisableSend, None, conn).await
@ -112,7 +115,7 @@ async fn enforce_disable_send_policy(headers: &Headers, conn: &mut DbConn) -> Em
/// but is allowed to remove this option from an existing Send.
///
/// Ref: https://bitwarden.com/help/article/policies/#send-options
async fn enforce_disable_hide_email_policy(data: &SendData, headers: &Headers, conn: &mut DbConn) -> EmptyResult {
async fn enforce_disable_hide_email_policy(data: &SendData, headers: &Headers, conn: &DbConn) -> EmptyResult {
let user_id = &headers.user.uuid;
let hide_email = data.hide_email.unwrap_or(false);
if hide_email && OrgPolicy::is_hide_email_disabled(user_id, conn).await {
@ -164,8 +167,8 @@ fn create_send(data: SendData, user_id: UserId) -> ApiResult<Send> {
}
#[get("/sends")]
async fn get_sends(headers: Headers, mut conn: DbConn) -> Json<Value> {
let sends = Send::find_by_user(&headers.user.uuid, &mut conn);
async fn get_sends(headers: Headers, conn: DbConn) -> Json<Value> {
let sends = Send::find_by_user(&headers.user.uuid, &conn);
let sends_json: Vec<Value> = sends.await.iter().map(|s| s.to_json()).collect();
Json(json!({
@ -176,32 +179,32 @@ async fn get_sends(headers: Headers, mut conn: DbConn) -> Json<Value> {
}
#[get("/sends/<send_id>")]
async fn get_send(send_id: SendId, headers: Headers, mut conn: DbConn) -> JsonResult {
match Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &mut conn).await {
async fn get_send(send_id: SendId, headers: Headers, conn: DbConn) -> JsonResult {
match Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &conn).await {
Some(send) => Ok(Json(send.to_json())),
None => err!("Send not found", "Invalid send uuid or does not belong to user"),
}
}
#[post("/sends", data = "<data>")]
async fn post_send(data: Json<SendData>, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> JsonResult {
enforce_disable_send_policy(&headers, &mut conn).await?;
async fn post_send(data: Json<SendData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
enforce_disable_send_policy(&headers, &conn).await?;
let data: SendData = data.into_inner();
enforce_disable_hide_email_policy(&data, &headers, &mut conn).await?;
enforce_disable_hide_email_policy(&data, &headers, &conn).await?;
if data.r#type == SendType::File as i32 {
err!("File sends should use /api/sends/file")
}
let mut send = create_send(data, headers.user.uuid)?;
send.save(&mut conn).await?;
send.save(&conn).await?;
nt.send_send_update(
UpdateType::SyncSendCreate,
&send,
&send.update_users_revision(&mut conn).await,
&send.update_users_revision(&conn).await,
&headers.device,
&mut conn,
&conn,
)
.await;
@ -225,8 +228,8 @@ struct UploadDataV2<'f> {
// 2025: This endpoint doesn't seem to exists anymore in the latest version
// See: https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Api/Tools/Controllers/SendsController.cs
#[post("/sends/file", format = "multipart/form-data", data = "<data>")]
async fn post_send_file(data: Form<UploadData<'_>>, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> JsonResult {
enforce_disable_send_policy(&headers, &mut conn).await?;
async fn post_send_file(data: Form<UploadData<'_>>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
enforce_disable_send_policy(&headers, &conn).await?;
let UploadData {
model,
@ -241,12 +244,12 @@ async fn post_send_file(data: Form<UploadData<'_>>, headers: Headers, mut conn:
err!("Send size can't be negative")
}
enforce_disable_hide_email_policy(&model, &headers, &mut conn).await?;
enforce_disable_hide_email_policy(&model, &headers, &conn).await?;
let size_limit = match CONFIG.user_send_limit() {
Some(0) => err!("File uploads are disabled"),
Some(limit_kb) => {
let Some(already_used) = Send::size_by_user(&headers.user.uuid, &mut conn).await else {
let Some(already_used) = Send::size_by_user(&headers.user.uuid, &conn).await else {
err!("Existing sends overflow")
};
let Some(left) = limit_kb.checked_mul(1024).and_then(|l| l.checked_sub(already_used)) else {
@ -282,13 +285,13 @@ async fn post_send_file(data: Form<UploadData<'_>>, headers: Headers, mut conn:
send.data = serde_json::to_string(&data_value)?;
// Save the changes in the database
send.save(&mut conn).await?;
send.save(&conn).await?;
nt.send_send_update(
UpdateType::SyncSendCreate,
&send,
&send.update_users_revision(&mut conn).await,
&send.update_users_revision(&conn).await,
&headers.device,
&mut conn,
&conn,
)
.await;
@ -297,8 +300,8 @@ async fn post_send_file(data: Form<UploadData<'_>>, headers: Headers, mut conn:
// Upstream: https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Api/Tools/Controllers/SendsController.cs#L165
#[post("/sends/file/v2", data = "<data>")]
async fn post_send_file_v2(data: Json<SendData>, headers: Headers, mut conn: DbConn) -> JsonResult {
enforce_disable_send_policy(&headers, &mut conn).await?;
async fn post_send_file_v2(data: Json<SendData>, headers: Headers, conn: DbConn) -> JsonResult {
enforce_disable_send_policy(&headers, &conn).await?;
let data = data.into_inner();
@ -306,7 +309,7 @@ async fn post_send_file_v2(data: Json<SendData>, headers: Headers, mut conn: DbC
err!("Send content is not a file");
}
enforce_disable_hide_email_policy(&data, &headers, &mut conn).await?;
enforce_disable_hide_email_policy(&data, &headers, &conn).await?;
let file_length = match &data.file_length {
Some(m) => m.into_i64()?,
@ -319,7 +322,7 @@ async fn post_send_file_v2(data: Json<SendData>, headers: Headers, mut conn: DbC
let size_limit = match CONFIG.user_send_limit() {
Some(0) => err!("File uploads are disabled"),
Some(limit_kb) => {
let Some(already_used) = Send::size_by_user(&headers.user.uuid, &mut conn).await else {
let Some(already_used) = Send::size_by_user(&headers.user.uuid, &conn).await else {
err!("Existing sends overflow")
};
let Some(left) = limit_kb.checked_mul(1024).and_then(|l| l.checked_sub(already_used)) else {
@ -348,7 +351,7 @@ async fn post_send_file_v2(data: Json<SendData>, headers: Headers, mut conn: DbC
o.insert(String::from("sizeName"), Value::String(crate::util::get_display_size(file_length)));
}
send.data = serde_json::to_string(&data_value)?;
send.save(&mut conn).await?;
send.save(&conn).await?;
Ok(Json(json!({
"fileUploadType": 0, // 0 == Direct | 1 == Azure
@ -373,14 +376,14 @@ async fn post_send_file_v2_data(
file_id: SendFileId,
data: Form<UploadDataV2<'_>>,
headers: Headers,
mut conn: DbConn,
conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
enforce_disable_send_policy(&headers, &mut conn).await?;
enforce_disable_send_policy(&headers, &conn).await?;
let data = data.into_inner();
let Some(send) = Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &mut conn).await else {
let Some(send) = Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &conn).await else {
err!("Send not found. Unable to save the file.", "Invalid send uuid or does not belong to user.")
};
@ -428,9 +431,9 @@ async fn post_send_file_v2_data(
nt.send_send_update(
UpdateType::SyncSendCreate,
&send,
&send.update_users_revision(&mut conn).await,
&send.update_users_revision(&conn).await,
&headers.device,
&mut conn,
&conn,
)
.await;
@ -447,11 +450,11 @@ pub struct SendAccessData {
async fn post_access(
access_id: &str,
data: Json<SendAccessData>,
mut conn: DbConn,
conn: DbConn,
ip: ClientIp,
nt: Notify<'_>,
) -> JsonResult {
let Some(mut send) = Send::find_by_access_id(access_id, &mut conn).await else {
let Some(mut send) = Send::find_by_access_id(access_id, &conn).await else {
err_code!(SEND_INACCESSIBLE_MSG, 404)
};
@ -488,18 +491,18 @@ async fn post_access(
send.access_count += 1;
}
send.save(&mut conn).await?;
send.save(&conn).await?;
nt.send_send_update(
UpdateType::SyncSendUpdate,
&send,
&send.update_users_revision(&mut conn).await,
&send.update_users_revision(&conn).await,
&ANON_PUSH_DEVICE,
&mut conn,
&conn,
)
.await;
Ok(Json(send.to_json_access(&mut conn).await))
Ok(Json(send.to_json_access(&conn).await))
}
#[post("/sends/<send_id>/access/file/<file_id>", data = "<data>")]
@ -508,10 +511,10 @@ async fn post_access_file(
file_id: SendFileId,
data: Json<SendAccessData>,
host: Host,
mut conn: DbConn,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
let Some(mut send) = Send::find_by_uuid(&send_id, &mut conn).await else {
let Some(mut send) = Send::find_by_uuid(&send_id, &conn).await else {
err_code!(SEND_INACCESSIBLE_MSG, 404)
};
@ -545,14 +548,14 @@ async fn post_access_file(
send.access_count += 1;
send.save(&mut conn).await?;
send.save(&conn).await?;
nt.send_send_update(
UpdateType::SyncSendUpdate,
&send,
&send.update_users_revision(&mut conn).await,
&send.update_users_revision(&conn).await,
&ANON_PUSH_DEVICE,
&mut conn,
&conn,
)
.await;
@ -587,23 +590,17 @@ async fn download_send(send_id: SendId, file_id: SendFileId, t: &str) -> Option<
}
#[put("/sends/<send_id>", data = "<data>")]
async fn put_send(
send_id: SendId,
data: Json<SendData>,
headers: Headers,
mut conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
enforce_disable_send_policy(&headers, &mut conn).await?;
async fn put_send(send_id: SendId, data: Json<SendData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
enforce_disable_send_policy(&headers, &conn).await?;
let data: SendData = data.into_inner();
enforce_disable_hide_email_policy(&data, &headers, &mut conn).await?;
enforce_disable_hide_email_policy(&data, &headers, &conn).await?;
let Some(mut send) = Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &mut conn).await else {
let Some(mut send) = Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &conn).await else {
err!("Send not found", "Send send_id is invalid or does not belong to user")
};
update_send_from_data(&mut send, data, &headers, &mut conn, &nt, UpdateType::SyncSendUpdate).await?;
update_send_from_data(&mut send, data, &headers, &conn, &nt, UpdateType::SyncSendUpdate).await?;
Ok(Json(send.to_json()))
}
@ -612,7 +609,7 @@ pub async fn update_send_from_data(
send: &mut Send,
data: SendData,
headers: &Headers,
conn: &mut DbConn,
conn: &DbConn,
nt: &Notify<'_>,
ut: UpdateType,
) -> EmptyResult {
@ -667,18 +664,18 @@ pub async fn update_send_from_data(
}
#[delete("/sends/<send_id>")]
async fn delete_send(send_id: SendId, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let Some(send) = Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &mut conn).await else {
async fn delete_send(send_id: SendId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let Some(send) = Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &conn).await else {
err!("Send not found", "Invalid send uuid, or does not belong to user")
};
send.delete(&mut conn).await?;
send.delete(&conn).await?;
nt.send_send_update(
UpdateType::SyncSendDelete,
&send,
&send.update_users_revision(&mut conn).await,
&send.update_users_revision(&conn).await,
&headers.device,
&mut conn,
&conn,
)
.await;
@ -686,21 +683,21 @@ async fn delete_send(send_id: SendId, headers: Headers, mut conn: DbConn, nt: No
}
#[put("/sends/<send_id>/remove-password")]
async fn put_remove_password(send_id: SendId, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> JsonResult {
enforce_disable_send_policy(&headers, &mut conn).await?;
async fn put_remove_password(send_id: SendId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
enforce_disable_send_policy(&headers, &conn).await?;
let Some(mut send) = Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &mut conn).await else {
let Some(mut send) = Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &conn).await else {
err!("Send not found", "Invalid send uuid, or does not belong to user")
};
send.set_password(None);
send.save(&mut conn).await?;
send.save(&conn).await?;
nt.send_send_update(
UpdateType::SyncSendUpdate,
&send,
&send.update_users_revision(&mut conn).await,
&send.update_users_revision(&conn).await,
&headers.device,
&mut conn,
&conn,
)
.await;

38
src/api/core/two_factor/authenticator.rs

@ -20,14 +20,14 @@ pub fn routes() -> Vec<Route> {
}
#[post("/two-factor/get-authenticator", data = "<data>")]
async fn generate_authenticator(data: Json<PasswordOrOtpData>, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn generate_authenticator(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: PasswordOrOtpData = data.into_inner();
let user = headers.user;
data.validate(&user, false, &mut conn).await?;
data.validate(&user, false, &conn).await?;
let type_ = TwoFactorType::Authenticator as i32;
let twofactor = TwoFactor::find_by_user_and_type(&user.uuid, type_, &mut conn).await;
let twofactor = TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await;
let (enabled, key) = match twofactor {
Some(tf) => (true, tf.data),
@ -55,7 +55,7 @@ struct EnableAuthenticatorData {
}
#[post("/two-factor/authenticator", data = "<data>")]
async fn activate_authenticator(data: Json<EnableAuthenticatorData>, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn activate_authenticator(data: Json<EnableAuthenticatorData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: EnableAuthenticatorData = data.into_inner();
let key = data.key;
let token = data.token.into_string();
@ -66,7 +66,7 @@ async fn activate_authenticator(data: Json<EnableAuthenticatorData>, headers: He
master_password_hash: data.master_password_hash,
otp: data.otp,
}
.validate(&user, true, &mut conn)
.validate(&user, true, &conn)
.await?;
// Validate key as base32 and 20 bytes length
@ -80,11 +80,11 @@ async fn activate_authenticator(data: Json<EnableAuthenticatorData>, headers: He
}
// Validate the token provided with the key, and save new twofactor
validate_totp_code(&user.uuid, &token, &key.to_uppercase(), &headers.ip, &mut conn).await?;
validate_totp_code(&user.uuid, &token, &key.to_uppercase(), &headers.ip, &conn).await?;
_generate_recover_code(&mut user, &mut conn).await;
_generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &mut conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
Ok(Json(json!({
"enabled": true,
@ -103,7 +103,7 @@ pub async fn validate_totp_code_str(
totp_code: &str,
secret: &str,
ip: &ClientIp,
conn: &mut DbConn,
conn: &DbConn,
) -> EmptyResult {
if !totp_code.chars().all(char::is_numeric) {
err!("TOTP code is not a number");
@ -117,7 +117,7 @@ pub async fn validate_totp_code(
totp_code: &str,
secret: &str,
ip: &ClientIp,
conn: &mut DbConn,
conn: &DbConn,
) -> EmptyResult {
use totp_lite::{totp_custom, Sha1};
@ -189,7 +189,7 @@ struct DisableAuthenticatorData {
}
#[delete("/two-factor/authenticator", data = "<data>")]
async fn disable_authenticator(data: Json<DisableAuthenticatorData>, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn disable_authenticator(data: Json<DisableAuthenticatorData>, headers: Headers, conn: DbConn) -> JsonResult {
let user = headers.user;
let type_ = data.r#type.into_i32()?;
@ -197,24 +197,18 @@ async fn disable_authenticator(data: Json<DisableAuthenticatorData>, headers: He
err!("Invalid password");
}
if let Some(twofactor) = TwoFactor::find_by_user_and_type(&user.uuid, type_, &mut conn).await {
if let Some(twofactor) = TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await {
if twofactor.data == data.key {
twofactor.delete(&mut conn).await?;
log_user_event(
EventType::UserDisabled2fa as i32,
&user.uuid,
headers.device.atype,
&headers.ip.ip,
&mut conn,
)
twofactor.delete(&conn).await?;
log_user_event(EventType::UserDisabled2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn)
.await;
} else {
err!(format!("TOTP key for user {} does not match recorded value, cannot deactivate", &user.email));
}
}
if TwoFactor::find_by_user(&user.uuid, &mut conn).await.is_empty() {
super::enforce_2fa_policy(&user, &user.uuid, headers.device.atype, &headers.ip.ip, &mut conn).await?;
if TwoFactor::find_by_user(&user.uuid, &conn).await.is_empty() {
super::enforce_2fa_policy(&user, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await?;
}
Ok(Json(json!({

24
src/api/core/two_factor/duo.rs

@ -92,13 +92,13 @@ impl DuoStatus {
const DISABLED_MESSAGE_DEFAULT: &str = "<To use the global Duo keys, please leave these fields untouched>";
#[post("/two-factor/get-duo", data = "<data>")]
async fn get_duo(data: Json<PasswordOrOtpData>, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn get_duo(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: PasswordOrOtpData = data.into_inner();
let user = headers.user;
data.validate(&user, false, &mut conn).await?;
data.validate(&user, false, &conn).await?;
let data = get_user_duo_data(&user.uuid, &mut conn).await;
let data = get_user_duo_data(&user.uuid, &conn).await;
let (enabled, data) = match data {
DuoStatus::Global(_) => (true, Some(DuoData::secret())),
@ -158,7 +158,7 @@ fn check_duo_fields_custom(data: &EnableDuoData) -> bool {
}
#[post("/two-factor/duo", data = "<data>")]
async fn activate_duo(data: Json<EnableDuoData>, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn activate_duo(data: Json<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: EnableDuoData = data.into_inner();
let mut user = headers.user;
@ -166,7 +166,7 @@ async fn activate_duo(data: Json<EnableDuoData>, headers: Headers, mut conn: DbC
master_password_hash: data.master_password_hash.clone(),
otp: data.otp.clone(),
}
.validate(&user, true, &mut conn)
.validate(&user, true, &conn)
.await?;
let (data, data_str) = if check_duo_fields_custom(&data) {
@ -180,11 +180,11 @@ async fn activate_duo(data: Json<EnableDuoData>, headers: Headers, mut conn: DbC
let type_ = TwoFactorType::Duo;
let twofactor = TwoFactor::new(user.uuid.clone(), type_, data_str);
twofactor.save(&mut conn).await?;
twofactor.save(&conn).await?;
_generate_recover_code(&mut user, &mut conn).await;
_generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &mut conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
Ok(Json(json!({
"enabled": true,
@ -231,7 +231,7 @@ const AUTH_PREFIX: &str = "AUTH";
const DUO_PREFIX: &str = "TX";
const APP_PREFIX: &str = "APP";
async fn get_user_duo_data(user_id: &UserId, conn: &mut DbConn) -> DuoStatus {
async fn get_user_duo_data(user_id: &UserId, conn: &DbConn) -> DuoStatus {
let type_ = TwoFactorType::Duo as i32;
// If the user doesn't have an entry, disabled
@ -254,7 +254,7 @@ async fn get_user_duo_data(user_id: &UserId, conn: &mut DbConn) -> DuoStatus {
}
// let (ik, sk, ak, host) = get_duo_keys();
pub(crate) async fn get_duo_keys_email(email: &str, conn: &mut DbConn) -> ApiResult<(String, String, String, String)> {
pub(crate) async fn get_duo_keys_email(email: &str, conn: &DbConn) -> ApiResult<(String, String, String, String)> {
let data = match User::find_by_mail(email, conn).await {
Some(u) => get_user_duo_data(&u.uuid, conn).await.data(),
_ => DuoData::global(),
@ -264,7 +264,7 @@ pub(crate) async fn get_duo_keys_email(email: &str, conn: &mut DbConn) -> ApiRes
Ok((data.ik, data.sk, CONFIG.get_duo_akey().await, data.host))
}
pub async fn generate_duo_signature(email: &str, conn: &mut DbConn) -> ApiResult<(String, String)> {
pub async fn generate_duo_signature(email: &str, conn: &DbConn) -> ApiResult<(String, String)> {
let now = Utc::now().timestamp();
let (ik, sk, ak, host) = get_duo_keys_email(email, conn).await?;
@ -282,7 +282,7 @@ fn sign_duo_values(key: &str, email: &str, ikey: &str, prefix: &str, expire: i64
format!("{cookie}|{}", crypto::hmac_sign(key, &cookie))
}
pub async fn validate_duo_login(email: &str, response: &str, conn: &mut DbConn) -> EmptyResult {
pub async fn validate_duo_login(email: &str, response: &str, conn: &DbConn) -> EmptyResult {
let split: Vec<&str> = response.split(':').collect();
if split.len() != 2 {
err!(

10
src/api/core/two_factor/duo_oidc.rs

@ -317,7 +317,7 @@ struct DuoAuthContext {
// Given a state string, retrieve the associated Duo auth context and
// delete the retrieved state from the database.
async fn extract_context(state: &str, conn: &mut DbConn) -> Option<DuoAuthContext> {
async fn extract_context(state: &str, conn: &DbConn) -> Option<DuoAuthContext> {
let ctx: TwoFactorDuoContext = match TwoFactorDuoContext::find_by_state(state, conn).await {
Some(c) => c,
None => return None,
@ -344,8 +344,8 @@ async fn extract_context(state: &str, conn: &mut DbConn) -> Option<DuoAuthContex
// Task to clean up expired Duo authentication contexts that may have accumulated in the database.
pub async fn purge_duo_contexts(pool: DbPool) {
debug!("Purging Duo authentication contexts");
if let Ok(mut conn) = pool.get().await {
TwoFactorDuoContext::purge_expired_duo_contexts(&mut conn).await;
if let Ok(conn) = pool.get().await {
TwoFactorDuoContext::purge_expired_duo_contexts(&conn).await;
} else {
error!("Failed to get DB connection while purging expired Duo authentications")
}
@ -380,7 +380,7 @@ pub async fn get_duo_auth_url(
email: &str,
client_id: &str,
device_identifier: &DeviceId,
conn: &mut DbConn,
conn: &DbConn,
) -> Result<String, Error> {
let (ik, sk, _, host) = get_duo_keys_email(email, conn).await?;
@ -418,7 +418,7 @@ pub async fn validate_duo_login(
two_factor_token: &str,
client_id: &str,
device_identifier: &DeviceId,
conn: &mut DbConn,
conn: &DbConn,
) -> EmptyResult {
// Result supplied to us by clients in the form "<authz code>|<state>"
let split: Vec<&str> = two_factor_token.split('|').collect();

42
src/api/core/two_factor/email.rs

@ -39,13 +39,13 @@ struct SendEmailLoginData {
/// User is trying to login and wants to use email 2FA.
/// Does not require Bearer token
#[post("/two-factor/send-email-login", data = "<data>")] // JsonResult
async fn send_email_login(data: Json<SendEmailLoginData>, mut conn: DbConn) -> EmptyResult {
async fn send_email_login(data: Json<SendEmailLoginData>, conn: DbConn) -> EmptyResult {
let data: SendEmailLoginData = data.into_inner();
use crate::db::models::User;
// Get the user
let Some(user) = User::find_by_device_id(&data.device_identifier, &mut conn).await else {
let Some(user) = User::find_by_device_id(&data.device_identifier, &conn).await else {
err!("Cannot find user. Try again.")
};
@ -53,13 +53,13 @@ async fn send_email_login(data: Json<SendEmailLoginData>, mut conn: DbConn) -> E
err!("Email 2FA is disabled")
}
send_token(&user.uuid, &mut conn).await?;
send_token(&user.uuid, &conn).await?;
Ok(())
}
/// Generate the token, save the data for later verification and send email to user
pub async fn send_token(user_id: &UserId, conn: &mut DbConn) -> EmptyResult {
pub async fn send_token(user_id: &UserId, conn: &DbConn) -> EmptyResult {
let type_ = TwoFactorType::Email as i32;
let mut twofactor = TwoFactor::find_by_user_and_type(user_id, type_, conn).await.map_res("Two factor not found")?;
@ -77,14 +77,14 @@ pub async fn send_token(user_id: &UserId, conn: &mut DbConn) -> EmptyResult {
/// When user clicks on Manage email 2FA show the user the related information
#[post("/two-factor/get-email", data = "<data>")]
async fn get_email(data: Json<PasswordOrOtpData>, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn get_email(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: PasswordOrOtpData = data.into_inner();
let user = headers.user;
data.validate(&user, false, &mut conn).await?;
data.validate(&user, false, &conn).await?;
let (enabled, mfa_email) =
match TwoFactor::find_by_user_and_type(&user.uuid, TwoFactorType::Email as i32, &mut conn).await {
match TwoFactor::find_by_user_and_type(&user.uuid, TwoFactorType::Email as i32, &conn).await {
Some(x) => {
let twofactor_data = EmailTokenData::from_json(&x.data)?;
(true, json!(twofactor_data.email))
@ -110,7 +110,7 @@ struct SendEmailData {
/// Send a verification email to the specified email address to check whether it exists/belongs to user.
#[post("/two-factor/send-email", data = "<data>")]
async fn send_email(data: Json<SendEmailData>, headers: Headers, mut conn: DbConn) -> EmptyResult {
async fn send_email(data: Json<SendEmailData>, headers: Headers, conn: DbConn) -> EmptyResult {
let data: SendEmailData = data.into_inner();
let user = headers.user;
@ -118,7 +118,7 @@ async fn send_email(data: Json<SendEmailData>, headers: Headers, mut conn: DbCon
master_password_hash: data.master_password_hash,
otp: data.otp,
}
.validate(&user, false, &mut conn)
.validate(&user, false, &conn)
.await?;
if !CONFIG._enable_email_2fa() {
@ -127,8 +127,8 @@ async fn send_email(data: Json<SendEmailData>, headers: Headers, mut conn: DbCon
let type_ = TwoFactorType::Email as i32;
if let Some(tf) = TwoFactor::find_by_user_and_type(&user.uuid, type_, &mut conn).await {
tf.delete(&mut conn).await?;
if let Some(tf) = TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await {
tf.delete(&conn).await?;
}
let generated_token = crypto::generate_email_token(CONFIG.email_token_size());
@ -136,7 +136,7 @@ async fn send_email(data: Json<SendEmailData>, headers: Headers, mut conn: DbCon
// Uses EmailVerificationChallenge as type to show that it's not verified yet.
let twofactor = TwoFactor::new(user.uuid, TwoFactorType::EmailVerificationChallenge, twofactor_data.to_json());
twofactor.save(&mut conn).await?;
twofactor.save(&conn).await?;
mail::send_token(&twofactor_data.email, &twofactor_data.last_token.map_res("Token is empty")?).await?;
@ -154,7 +154,7 @@ struct EmailData {
/// Verify email belongs to user and can be used for 2FA email codes.
#[put("/two-factor/email", data = "<data>")]
async fn email(data: Json<EmailData>, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn email(data: Json<EmailData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: EmailData = data.into_inner();
let mut user = headers.user;
@ -163,12 +163,12 @@ async fn email(data: Json<EmailData>, headers: Headers, mut conn: DbConn) -> Jso
master_password_hash: data.master_password_hash,
otp: data.otp,
}
.validate(&user, true, &mut conn)
.validate(&user, true, &conn)
.await?;
let type_ = TwoFactorType::EmailVerificationChallenge as i32;
let mut twofactor =
TwoFactor::find_by_user_and_type(&user.uuid, type_, &mut conn).await.map_res("Two factor not found")?;
TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await.map_res("Two factor not found")?;
let mut email_data = EmailTokenData::from_json(&twofactor.data)?;
@ -183,11 +183,11 @@ async fn email(data: Json<EmailData>, headers: Headers, mut conn: DbConn) -> Jso
email_data.reset_token();
twofactor.atype = TwoFactorType::Email as i32;
twofactor.data = email_data.to_json();
twofactor.save(&mut conn).await?;
twofactor.save(&conn).await?;
_generate_recover_code(&mut user, &mut conn).await;
_generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &mut conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
Ok(Json(json!({
"email": email_data.email,
@ -202,7 +202,7 @@ pub async fn validate_email_code_str(
token: &str,
data: &str,
ip: &std::net::IpAddr,
conn: &mut DbConn,
conn: &DbConn,
) -> EmptyResult {
let mut email_data = EmailTokenData::from_json(data)?;
let mut twofactor = TwoFactor::find_by_user_and_type(user_id, TwoFactorType::Email as i32, conn)
@ -302,7 +302,7 @@ impl EmailTokenData {
}
}
pub async fn activate_email_2fa(user: &User, conn: &mut DbConn) -> EmptyResult {
pub async fn activate_email_2fa(user: &User, conn: &DbConn) -> EmptyResult {
if user.verified_at.is_none() {
err!("Auto-enabling of email 2FA failed because the users email address has not been verified!");
}
@ -332,7 +332,7 @@ pub fn obscure_email(email: &str) -> String {
format!("{new_name}@{domain}")
}
pub async fn find_and_activate_email_2fa(user_id: &UserId, conn: &mut DbConn) -> EmptyResult {
pub async fn find_and_activate_email_2fa(user_id: &UserId, conn: &DbConn) -> EmptyResult {
if let Some(user) = User::find_by_uuid(user_id, conn).await {
activate_email_2fa(&user, conn).await
} else {

56
src/api/core/two_factor/mod.rs

@ -11,7 +11,13 @@ use crate::{
},
auth::{ClientHeaders, Headers},
crypto,
db::{models::*, DbConn, DbPool},
db::{
models::{
DeviceType, EventType, Membership, MembershipType, OrgPolicyType, Organization, OrganizationId, TwoFactor,
TwoFactorIncomplete, User, UserId,
},
DbConn, DbPool,
},
mail,
util::NumberOrString,
CONFIG,
@ -46,8 +52,8 @@ pub fn routes() -> Vec<Route> {
}
#[get("/two-factor")]
async fn get_twofactor(headers: Headers, mut conn: DbConn) -> Json<Value> {
let twofactors = TwoFactor::find_by_user(&headers.user.uuid, &mut conn).await;
async fn get_twofactor(headers: Headers, conn: DbConn) -> Json<Value> {
let twofactors = TwoFactor::find_by_user(&headers.user.uuid, &conn).await;
let twofactors_json: Vec<Value> = twofactors.iter().map(TwoFactor::to_json_provider).collect();
Json(json!({
@ -58,11 +64,11 @@ async fn get_twofactor(headers: Headers, mut conn: DbConn) -> Json<Value> {
}
#[post("/two-factor/get-recover", data = "<data>")]
async fn get_recover(data: Json<PasswordOrOtpData>, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn get_recover(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: PasswordOrOtpData = data.into_inner();
let user = headers.user;
data.validate(&user, true, &mut conn).await?;
data.validate(&user, true, &conn).await?;
Ok(Json(json!({
"code": user.totp_recover,
@ -79,13 +85,13 @@ struct RecoverTwoFactor {
}
#[post("/two-factor/recover", data = "<data>")]
async fn recover(data: Json<RecoverTwoFactor>, client_headers: ClientHeaders, mut conn: DbConn) -> JsonResult {
async fn recover(data: Json<RecoverTwoFactor>, client_headers: ClientHeaders, conn: DbConn) -> JsonResult {
let data: RecoverTwoFactor = data.into_inner();
use crate::db::models::User;
// Get the user
let Some(mut user) = User::find_by_mail(&data.email, &mut conn).await else {
let Some(mut user) = User::find_by_mail(&data.email, &conn).await else {
err!("Username or password is incorrect. Try again.")
};
@ -100,25 +106,25 @@ async fn recover(data: Json<RecoverTwoFactor>, client_headers: ClientHeaders, mu
}
// Remove all twofactors from the user
TwoFactor::delete_all_by_user(&user.uuid, &mut conn).await?;
enforce_2fa_policy(&user, &user.uuid, client_headers.device_type, &client_headers.ip.ip, &mut conn).await?;
TwoFactor::delete_all_by_user(&user.uuid, &conn).await?;
enforce_2fa_policy(&user, &user.uuid, client_headers.device_type, &client_headers.ip.ip, &conn).await?;
log_user_event(
EventType::UserRecovered2fa as i32,
&user.uuid,
client_headers.device_type,
&client_headers.ip.ip,
&mut conn,
&conn,
)
.await;
// Remove the recovery code, not needed without twofactors
user.totp_recover = None;
user.save(&mut conn).await?;
user.save(&conn).await?;
Ok(Json(Value::Object(serde_json::Map::new())))
}
async fn _generate_recover_code(user: &mut User, conn: &mut DbConn) {
async fn _generate_recover_code(user: &mut User, conn: &DbConn) {
if user.totp_recover.is_none() {
let totp_recover = crypto::encode_random_bytes::<20>(BASE32);
user.totp_recover = Some(totp_recover);
@ -135,7 +141,7 @@ struct DisableTwoFactorData {
}
#[post("/two-factor/disable", data = "<data>")]
async fn disable_twofactor(data: Json<DisableTwoFactorData>, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn disable_twofactor(data: Json<DisableTwoFactorData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: DisableTwoFactorData = data.into_inner();
let user = headers.user;
@ -144,19 +150,19 @@ async fn disable_twofactor(data: Json<DisableTwoFactorData>, headers: Headers, m
master_password_hash: data.master_password_hash,
otp: data.otp,
}
.validate(&user, true, &mut conn)
.validate(&user, true, &conn)
.await?;
let type_ = data.r#type.into_i32()?;
if let Some(twofactor) = TwoFactor::find_by_user_and_type(&user.uuid, type_, &mut conn).await {
twofactor.delete(&mut conn).await?;
log_user_event(EventType::UserDisabled2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &mut conn)
if let Some(twofactor) = TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await {
twofactor.delete(&conn).await?;
log_user_event(EventType::UserDisabled2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn)
.await;
}
if TwoFactor::find_by_user(&user.uuid, &mut conn).await.is_empty() {
enforce_2fa_policy(&user, &user.uuid, headers.device.atype, &headers.ip.ip, &mut conn).await?;
if TwoFactor::find_by_user(&user.uuid, &conn).await.is_empty() {
enforce_2fa_policy(&user, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await?;
}
Ok(Json(json!({
@ -176,7 +182,7 @@ pub async fn enforce_2fa_policy(
act_user_id: &UserId,
device_type: i32,
ip: &std::net::IpAddr,
conn: &mut DbConn,
conn: &DbConn,
) -> EmptyResult {
for member in
Membership::find_by_user_and_policy(&user.uuid, OrgPolicyType::TwoFactorAuthentication, conn).await.into_iter()
@ -212,7 +218,7 @@ pub async fn enforce_2fa_policy_for_org(
act_user_id: &UserId,
device_type: i32,
ip: &std::net::IpAddr,
conn: &mut DbConn,
conn: &DbConn,
) -> EmptyResult {
let org = Organization::find_by_uuid(org_id, conn).await.unwrap();
for member in Membership::find_confirmed_by_org(org_id, conn).await.into_iter() {
@ -249,7 +255,7 @@ pub async fn send_incomplete_2fa_notifications(pool: DbPool) {
return;
}
let mut conn = match pool.get().await {
let conn = match pool.get().await {
Ok(conn) => conn,
_ => {
error!("Failed to get DB connection in send_incomplete_2fa_notifications()");
@ -260,9 +266,9 @@ pub async fn send_incomplete_2fa_notifications(pool: DbPool) {
let now = Utc::now().naive_utc();
let time_limit = TimeDelta::try_minutes(CONFIG.incomplete_2fa_time_limit()).unwrap();
let time_before = now - time_limit;
let incomplete_logins = TwoFactorIncomplete::find_logins_before(&time_before, &mut conn).await;
let incomplete_logins = TwoFactorIncomplete::find_logins_before(&time_before, &conn).await;
for login in incomplete_logins {
let user = User::find_by_uuid(&login.user_uuid, &mut conn).await.expect("User not found");
let user = User::find_by_uuid(&login.user_uuid, &conn).await.expect("User not found");
info!(
"User {} did not complete a 2FA login within the configured time limit. IP: {}",
user.email, login.ip_address
@ -277,7 +283,7 @@ pub async fn send_incomplete_2fa_notifications(pool: DbPool) {
.await
{
Ok(_) => {
if let Err(e) = login.delete(&mut conn).await {
if let Err(e) = login.delete(&conn).await {
error!("Error deleting incomplete 2FA record: {e:#?}");
}
}

15
src/api/core/two_factor/protected_actions.rs

@ -55,7 +55,7 @@ impl ProtectedActionData {
}
#[post("/accounts/request-otp")]
async fn request_otp(headers: Headers, mut conn: DbConn) -> EmptyResult {
async fn request_otp(headers: Headers, conn: DbConn) -> EmptyResult {
if !CONFIG.mail_enabled() {
err!("Email is disabled for this server. Either enable email or login using your master password instead of login via device.");
}
@ -63,10 +63,9 @@ async fn request_otp(headers: Headers, mut conn: DbConn) -> EmptyResult {
let user = headers.user;
// Only one Protected Action per user is allowed to take place, delete the previous one
if let Some(pa) =
TwoFactor::find_by_user_and_type(&user.uuid, TwoFactorType::ProtectedActions as i32, &mut conn).await
if let Some(pa) = TwoFactor::find_by_user_and_type(&user.uuid, TwoFactorType::ProtectedActions as i32, &conn).await
{
pa.delete(&mut conn).await?;
pa.delete(&conn).await?;
}
let generated_token = crypto::generate_email_token(CONFIG.email_token_size());
@ -74,7 +73,7 @@ async fn request_otp(headers: Headers, mut conn: DbConn) -> EmptyResult {
// Uses EmailVerificationChallenge as type to show that it's not verified yet.
let twofactor = TwoFactor::new(user.uuid, TwoFactorType::ProtectedActions, pa_data.to_json());
twofactor.save(&mut conn).await?;
twofactor.save(&conn).await?;
mail::send_protected_action_token(&user.email, &pa_data.token).await?;
@ -89,7 +88,7 @@ struct ProtectedActionVerify {
}
#[post("/accounts/verify-otp", data = "<data>")]
async fn verify_otp(data: Json<ProtectedActionVerify>, headers: Headers, mut conn: DbConn) -> EmptyResult {
async fn verify_otp(data: Json<ProtectedActionVerify>, headers: Headers, conn: DbConn) -> EmptyResult {
if !CONFIG.mail_enabled() {
err!("Email is disabled for this server. Either enable email or login using your master password instead of login via device.");
}
@ -99,14 +98,14 @@ async fn verify_otp(data: Json<ProtectedActionVerify>, headers: Headers, mut con
// Delete the token after one validation attempt
// This endpoint only gets called for the vault export, and doesn't need a second attempt
validate_protected_action_otp(&data.otp, &user.uuid, true, &mut conn).await
validate_protected_action_otp(&data.otp, &user.uuid, true, &conn).await
}
pub async fn validate_protected_action_otp(
otp: &str,
user_id: &UserId,
delete_if_valid: bool,
conn: &mut DbConn,
conn: &DbConn,
) -> EmptyResult {
let pa = TwoFactor::find_by_user_and_type(user_id, TwoFactorType::ProtectedActions as i32, conn)
.await

49
src/api/core/two_factor/webauthn.rs

@ -107,7 +107,7 @@ impl WebauthnRegistration {
}
#[post("/two-factor/get-webauthn", data = "<data>")]
async fn get_webauthn(data: Json<PasswordOrOtpData>, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn get_webauthn(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult {
if !CONFIG.domain_set() {
err!("`DOMAIN` environment variable is not set. Webauthn disabled")
}
@ -115,9 +115,9 @@ async fn get_webauthn(data: Json<PasswordOrOtpData>, headers: Headers, mut conn:
let data: PasswordOrOtpData = data.into_inner();
let user = headers.user;
data.validate(&user, false, &mut conn).await?;
data.validate(&user, false, &conn).await?;
let (enabled, registrations) = get_webauthn_registrations(&user.uuid, &mut conn).await?;
let (enabled, registrations) = get_webauthn_registrations(&user.uuid, &conn).await?;
let registrations_json: Vec<Value> = registrations.iter().map(WebauthnRegistration::to_json).collect();
Ok(Json(json!({
@ -128,13 +128,13 @@ async fn get_webauthn(data: Json<PasswordOrOtpData>, headers: Headers, mut conn:
}
#[post("/two-factor/get-webauthn-challenge", data = "<data>")]
async fn generate_webauthn_challenge(data: Json<PasswordOrOtpData>, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn generate_webauthn_challenge(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: PasswordOrOtpData = data.into_inner();
let user = headers.user;
data.validate(&user, false, &mut conn).await?;
data.validate(&user, false, &conn).await?;
let registrations = get_webauthn_registrations(&user.uuid, &mut conn)
let registrations = get_webauthn_registrations(&user.uuid, &conn)
.await?
.1
.into_iter()
@ -153,7 +153,7 @@ async fn generate_webauthn_challenge(data: Json<PasswordOrOtpData>, headers: Hea
state["rs"]["extensions"].as_object_mut().unwrap().clear();
let type_ = TwoFactorType::WebauthnRegisterChallenge;
TwoFactor::new(user.uuid.clone(), type_, serde_json::to_string(&state)?).save(&mut conn).await?;
TwoFactor::new(user.uuid.clone(), type_, serde_json::to_string(&state)?).save(&conn).await?;
// Because for this flow we abuse the passkeys as 2FA, and use it more like a securitykey
// we need to modify some of the default settings defined by `start_passkey_registration()`.
@ -252,7 +252,7 @@ impl From<PublicKeyCredentialCopy> for PublicKeyCredential {
}
#[post("/two-factor/webauthn", data = "<data>")]
async fn activate_webauthn(data: Json<EnableWebauthnData>, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn activate_webauthn(data: Json<EnableWebauthnData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: EnableWebauthnData = data.into_inner();
let mut user = headers.user;
@ -260,15 +260,15 @@ async fn activate_webauthn(data: Json<EnableWebauthnData>, headers: Headers, mut
master_password_hash: data.master_password_hash,
otp: data.otp,
}
.validate(&user, true, &mut conn)
.validate(&user, true, &conn)
.await?;
// Retrieve and delete the saved challenge state
let type_ = TwoFactorType::WebauthnRegisterChallenge as i32;
let state = match TwoFactor::find_by_user_and_type(&user.uuid, type_, &mut conn).await {
let state = match TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await {
Some(tf) => {
let state: PasskeyRegistration = serde_json::from_str(&tf.data)?;
tf.delete(&mut conn).await?;
tf.delete(&conn).await?;
state
}
None => err!("Can't recover challenge"),
@ -277,7 +277,7 @@ async fn activate_webauthn(data: Json<EnableWebauthnData>, headers: Headers, mut
// Verify the credentials with the saved state
let credential = WEBAUTHN.finish_passkey_registration(&data.device_response.into(), &state)?;
let mut registrations: Vec<_> = get_webauthn_registrations(&user.uuid, &mut conn).await?.1;
let mut registrations: Vec<_> = get_webauthn_registrations(&user.uuid, &conn).await?.1;
// TODO: Check for repeated ID's
registrations.push(WebauthnRegistration {
id: data.id.into_i32()?,
@ -289,11 +289,11 @@ async fn activate_webauthn(data: Json<EnableWebauthnData>, headers: Headers, mut
// Save the registrations and return them
TwoFactor::new(user.uuid.clone(), TwoFactorType::Webauthn, serde_json::to_string(&registrations)?)
.save(&mut conn)
.save(&conn)
.await?;
_generate_recover_code(&mut user, &mut conn).await;
_generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &mut conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
let keys_json: Vec<Value> = registrations.iter().map(WebauthnRegistration::to_json).collect();
Ok(Json(json!({
@ -316,14 +316,14 @@ struct DeleteU2FData {
}
#[delete("/two-factor/webauthn", data = "<data>")]
async fn delete_webauthn(data: Json<DeleteU2FData>, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn delete_webauthn(data: Json<DeleteU2FData>, headers: Headers, conn: DbConn) -> JsonResult {
let id = data.id.into_i32()?;
if !headers.user.check_valid_password(&data.master_password_hash) {
err!("Invalid password");
}
let Some(mut tf) =
TwoFactor::find_by_user_and_type(&headers.user.uuid, TwoFactorType::Webauthn as i32, &mut conn).await
TwoFactor::find_by_user_and_type(&headers.user.uuid, TwoFactorType::Webauthn as i32, &conn).await
else {
err!("Webauthn data not found!")
};
@ -336,12 +336,11 @@ async fn delete_webauthn(data: Json<DeleteU2FData>, headers: Headers, mut conn:
let removed_item = data.remove(item_pos);
tf.data = serde_json::to_string(&data)?;
tf.save(&mut conn).await?;
tf.save(&conn).await?;
drop(tf);
// If entry is migrated from u2f, delete the u2f entry as well
if let Some(mut u2f) =
TwoFactor::find_by_user_and_type(&headers.user.uuid, TwoFactorType::U2f as i32, &mut conn).await
if let Some(mut u2f) = TwoFactor::find_by_user_and_type(&headers.user.uuid, TwoFactorType::U2f as i32, &conn).await
{
let mut data: Vec<U2FRegistration> = match serde_json::from_str(&u2f.data) {
Ok(d) => d,
@ -352,7 +351,7 @@ async fn delete_webauthn(data: Json<DeleteU2FData>, headers: Headers, mut conn:
let new_data_str = serde_json::to_string(&data)?;
u2f.data = new_data_str;
u2f.save(&mut conn).await?;
u2f.save(&conn).await?;
}
let keys_json: Vec<Value> = data.iter().map(WebauthnRegistration::to_json).collect();
@ -366,7 +365,7 @@ async fn delete_webauthn(data: Json<DeleteU2FData>, headers: Headers, mut conn:
pub async fn get_webauthn_registrations(
user_id: &UserId,
conn: &mut DbConn,
conn: &DbConn,
) -> Result<(bool, Vec<WebauthnRegistration>), Error> {
let type_ = TwoFactorType::Webauthn as i32;
match TwoFactor::find_by_user_and_type(user_id, type_, conn).await {
@ -375,7 +374,7 @@ pub async fn get_webauthn_registrations(
}
}
pub async fn generate_webauthn_login(user_id: &UserId, conn: &mut DbConn) -> JsonResult {
pub async fn generate_webauthn_login(user_id: &UserId, conn: &DbConn) -> JsonResult {
// Load saved credentials
let creds: Vec<Passkey> =
get_webauthn_registrations(user_id, conn).await?.1.into_iter().map(|r| r.credential).collect();
@ -415,7 +414,7 @@ pub async fn generate_webauthn_login(user_id: &UserId, conn: &mut DbConn) -> Jso
Ok(Json(serde_json::to_value(response.public_key)?))
}
pub async fn validate_webauthn_login(user_id: &UserId, response: &str, conn: &mut DbConn) -> EmptyResult {
pub async fn validate_webauthn_login(user_id: &UserId, response: &str, conn: &DbConn) -> EmptyResult {
let type_ = TwoFactorType::WebauthnLoginChallenge as i32;
let mut state = match TwoFactor::find_by_user_and_type(user_id, type_, conn).await {
Some(tf) => {
@ -469,7 +468,7 @@ async fn check_and_update_backup_eligible(
rsp: &PublicKeyCredential,
registrations: &mut Vec<WebauthnRegistration>,
state: &mut PasskeyAuthentication,
conn: &mut DbConn,
conn: &DbConn,
) -> EmptyResult {
// The feature flags from the response
// For details see: https://www.w3.org/TR/webauthn-3/#sctn-authenticator-data

18
src/api/core/two_factor/yubikey.rs

@ -83,19 +83,19 @@ async fn verify_yubikey_otp(otp: String) -> EmptyResult {
}
#[post("/two-factor/get-yubikey", data = "<data>")]
async fn generate_yubikey(data: Json<PasswordOrOtpData>, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn generate_yubikey(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult {
// Make sure the credentials are set
get_yubico_credentials()?;
let data: PasswordOrOtpData = data.into_inner();
let user = headers.user;
data.validate(&user, false, &mut conn).await?;
data.validate(&user, false, &conn).await?;
let user_id = &user.uuid;
let yubikey_type = TwoFactorType::YubiKey as i32;
let r = TwoFactor::find_by_user_and_type(user_id, yubikey_type, &mut conn).await;
let r = TwoFactor::find_by_user_and_type(user_id, yubikey_type, &conn).await;
if let Some(r) = r {
let yubikey_metadata: YubikeyMetadata = serde_json::from_str(&r.data)?;
@ -116,7 +116,7 @@ async fn generate_yubikey(data: Json<PasswordOrOtpData>, headers: Headers, mut c
}
#[post("/two-factor/yubikey", data = "<data>")]
async fn activate_yubikey(data: Json<EnableYubikeyData>, headers: Headers, mut conn: DbConn) -> JsonResult {
async fn activate_yubikey(data: Json<EnableYubikeyData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: EnableYubikeyData = data.into_inner();
let mut user = headers.user;
@ -124,12 +124,12 @@ async fn activate_yubikey(data: Json<EnableYubikeyData>, headers: Headers, mut c
master_password_hash: data.master_password_hash.clone(),
otp: data.otp.clone(),
}
.validate(&user, true, &mut conn)
.validate(&user, true, &conn)
.await?;
// Check if we already have some data
let mut yubikey_data =
match TwoFactor::find_by_user_and_type(&user.uuid, TwoFactorType::YubiKey as i32, &mut conn).await {
match TwoFactor::find_by_user_and_type(&user.uuid, TwoFactorType::YubiKey as i32, &conn).await {
Some(data) => data,
None => TwoFactor::new(user.uuid.clone(), TwoFactorType::YubiKey, String::new()),
};
@ -160,11 +160,11 @@ async fn activate_yubikey(data: Json<EnableYubikeyData>, headers: Headers, mut c
};
yubikey_data.data = serde_json::to_string(&yubikey_metadata).unwrap();
yubikey_data.save(&mut conn).await?;
yubikey_data.save(&conn).await?;
_generate_recover_code(&mut user, &mut conn).await;
_generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &mut conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
let mut result = jsonify_yubikeys(yubikey_metadata.keys);

59
src/api/identity.rs

@ -22,7 +22,13 @@ use crate::{
},
auth,
auth::{generate_organization_api_key_login_claims, AuthMethod, ClientHeaders, ClientIp, ClientVersion},
db::{models::*, DbConn},
db::{
models::{
AuthRequest, AuthRequestId, Device, DeviceId, EventType, Invitation, OrganizationApiKey, OrganizationId,
SsoNonce, SsoUser, TwoFactor, TwoFactorIncomplete, TwoFactorType, User, UserId,
},
DbConn,
},
error::MapResult,
mail, sso,
sso::{OIDCCode, OIDCState},
@ -48,7 +54,7 @@ async fn login(
data: Form<ConnectData>,
client_header: ClientHeaders,
client_version: Option<ClientVersion>,
mut conn: DbConn,
conn: DbConn,
) -> JsonResult {
let data: ConnectData = data.into_inner();
@ -57,7 +63,7 @@ async fn login(
let login_result = match data.grant_type.as_ref() {
"refresh_token" => {
_check_is_some(&data.refresh_token, "refresh_token cannot be blank")?;
_refresh_login(data, &mut conn, &client_header.ip).await
_refresh_login(data, &conn, &client_header.ip).await
}
"password" if CONFIG.sso_enabled() && CONFIG.sso_only() => err!("SSO sign-in is required"),
"password" => {
@ -70,7 +76,7 @@ async fn login(
_check_is_some(&data.device_name, "device_name cannot be blank")?;
_check_is_some(&data.device_type, "device_type cannot be blank")?;
_password_login(data, &mut user_id, &mut conn, &client_header.ip, &client_version).await
_password_login(data, &mut user_id, &conn, &client_header.ip, &client_version).await
}
"client_credentials" => {
_check_is_some(&data.client_id, "client_id cannot be blank")?;
@ -81,7 +87,7 @@ async fn login(
_check_is_some(&data.device_name, "device_name cannot be blank")?;
_check_is_some(&data.device_type, "device_type cannot be blank")?;
_api_key_login(data, &mut user_id, &mut conn, &client_header.ip).await
_api_key_login(data, &mut user_id, &conn, &client_header.ip).await
}
"authorization_code" if CONFIG.sso_enabled() => {
_check_is_some(&data.client_id, "client_id cannot be blank")?;
@ -91,7 +97,7 @@ async fn login(
_check_is_some(&data.device_name, "device_name cannot be blank")?;
_check_is_some(&data.device_type, "device_type cannot be blank")?;
_sso_login(data, &mut user_id, &mut conn, &client_header.ip, &client_version).await
_sso_login(data, &mut user_id, &conn, &client_header.ip, &client_version).await
}
"authorization_code" => err!("SSO sign-in is not available"),
t => err!("Invalid type", t),
@ -105,19 +111,13 @@ async fn login(
&user_id,
client_header.device_type,
&client_header.ip.ip,
&mut conn,
&conn,
)
.await;
}
Err(e) => {
if let Some(ev) = e.get_event() {
log_user_event(
ev.event as i32,
&user_id,
client_header.device_type,
&client_header.ip.ip,
&mut conn,
)
log_user_event(ev.event as i32, &user_id, client_header.device_type, &client_header.ip.ip, &conn)
.await
}
}
@ -128,7 +128,7 @@ async fn login(
}
// Return Status::Unauthorized to trigger logout
async fn _refresh_login(data: ConnectData, conn: &mut DbConn, ip: &ClientIp) -> JsonResult {
async fn _refresh_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> JsonResult {
// Extract token
let refresh_token = match data.refresh_token {
Some(token) => token,
@ -166,7 +166,7 @@ async fn _refresh_login(data: ConnectData, conn: &mut DbConn, ip: &ClientIp) ->
async fn _sso_login(
data: ConnectData,
user_id: &mut Option<UserId>,
conn: &mut DbConn,
conn: &DbConn,
ip: &ClientIp,
client_version: &Option<ClientVersion>,
) -> JsonResult {
@ -319,7 +319,7 @@ async fn _sso_login(
async fn _password_login(
data: ConnectData,
user_id: &mut Option<UserId>,
conn: &mut DbConn,
conn: &DbConn,
ip: &ClientIp,
client_version: &Option<ClientVersion>,
) -> JsonResult {
@ -444,7 +444,7 @@ async fn authenticated_response(
auth_tokens: auth::AuthTokens,
twofactor_token: Option<String>,
now: &NaiveDateTime,
conn: &mut DbConn,
conn: &DbConn,
ip: &ClientIp,
) -> JsonResult {
if CONFIG.mail_enabled() && device.is_new() {
@ -504,12 +504,7 @@ async fn authenticated_response(
Ok(Json(result))
}
async fn _api_key_login(
data: ConnectData,
user_id: &mut Option<UserId>,
conn: &mut DbConn,
ip: &ClientIp,
) -> JsonResult {
async fn _api_key_login(data: ConnectData, user_id: &mut Option<UserId>, conn: &DbConn, ip: &ClientIp) -> JsonResult {
// Ratelimit the login
crate::ratelimit::check_limit_login(&ip.ip)?;
@ -524,7 +519,7 @@ async fn _api_key_login(
async fn _user_api_key_login(
data: ConnectData,
user_id: &mut Option<UserId>,
conn: &mut DbConn,
conn: &DbConn,
ip: &ClientIp,
) -> JsonResult {
// Get the user via the client_id
@ -614,7 +609,7 @@ async fn _user_api_key_login(
Ok(Json(result))
}
async fn _organization_api_key_login(data: ConnectData, conn: &mut DbConn, ip: &ClientIp) -> JsonResult {
async fn _organization_api_key_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> JsonResult {
// Get the org via the client_id
let client_id = data.client_id.as_ref().unwrap();
let Some(org_id) = client_id.strip_prefix("organization.") else {
@ -643,7 +638,7 @@ async fn _organization_api_key_login(data: ConnectData, conn: &mut DbConn, ip: &
}
/// Retrieves an existing device or creates a new device from ConnectData and the User
async fn get_device(data: &ConnectData, conn: &mut DbConn, user: &User) -> ApiResult<Device> {
async fn get_device(data: &ConnectData, conn: &DbConn, user: &User) -> ApiResult<Device> {
// On iOS, device_type sends "iOS", on others it sends a number
// When unknown or unable to parse, return 14, which is 'Unknown Browser'
let device_type = util::try_parse_string(data.device_type.as_ref()).unwrap_or(14);
@ -663,7 +658,7 @@ async fn twofactor_auth(
device: &mut Device,
ip: &ClientIp,
client_version: &Option<ClientVersion>,
conn: &mut DbConn,
conn: &DbConn,
) -> ApiResult<Option<String>> {
let twofactors = TwoFactor::find_by_user(&user.uuid, conn).await;
@ -780,7 +775,7 @@ async fn _json_err_twofactor(
user_id: &UserId,
data: &ConnectData,
client_version: &Option<ClientVersion>,
conn: &mut DbConn,
conn: &DbConn,
) -> ApiResult<Value> {
let mut result = json!({
"error" : "invalid_grant",
@ -905,13 +900,13 @@ enum RegisterVerificationResponse {
#[post("/accounts/register/send-verification-email", data = "<data>")]
async fn register_verification_email(
data: Json<RegisterVerificationData>,
mut conn: DbConn,
conn: DbConn,
) -> ApiResult<RegisterVerificationResponse> {
let data = data.into_inner();
// the registration can only continue if signup is allowed or there exists an invitation
if !(CONFIG.is_signup_allowed(&data.email)
|| (!CONFIG.mail_enabled() && Invitation::find_by_mail(&data.email, &mut conn).await.is_some()))
|| (!CONFIG.mail_enabled() && Invitation::find_by_mail(&data.email, &conn).await.is_some()))
{
err!("Registration not allowed or user already exists")
}
@ -922,7 +917,7 @@ async fn register_verification_email(
let token = auth::encode_jwt(&token_claims);
if should_send_mail {
let user = User::find_by_mail(&data.email, &mut conn).await;
let user = User::find_by_mail(&data.email, &conn).await;
if user.filter(|u| u.private_key.is_some()).is_some() {
// There is still a timing side channel here in that the code
// paths that send mail take noticeably longer than ones that

2
src/api/mod.rs

@ -55,7 +55,7 @@ impl PasswordOrOtpData {
/// Tokens used via this struct can be used multiple times during the process
/// First for the validation to continue, after that to enable or validate the following actions
/// This is different per caller, so it can be adjusted to delete the token or not
pub async fn validate(&self, user: &User, delete_if_valid: bool, conn: &mut DbConn) -> EmptyResult {
pub async fn validate(&self, user: &User, delete_if_valid: bool, conn: &DbConn) -> EmptyResult {
use crate::api::core::two_factor::protected_actions::validate_protected_action_otp;
match (self.master_password_hash.as_deref(), self.otp.as_deref()) {

20
src/api/notifications.rs

@ -339,7 +339,7 @@ impl WebSocketUsers {
}
// NOTE: The last modified date needs to be updated before calling these methods
pub async fn send_user_update(&self, ut: UpdateType, user: &User, push_uuid: &Option<PushId>, conn: &mut DbConn) {
pub async fn send_user_update(&self, ut: UpdateType, user: &User, push_uuid: &Option<PushId>, conn: &DbConn) {
// Skip any processing if both WebSockets and Push are not active
if *NOTIFICATIONS_DISABLED {
return;
@ -359,7 +359,7 @@ impl WebSocketUsers {
}
}
pub async fn send_logout(&self, user: &User, acting_device_id: Option<DeviceId>, conn: &mut DbConn) {
pub async fn send_logout(&self, user: &User, acting_device_id: Option<DeviceId>, conn: &DbConn) {
// Skip any processing if both WebSockets and Push are not active
if *NOTIFICATIONS_DISABLED {
return;
@ -379,7 +379,7 @@ impl WebSocketUsers {
}
}
pub async fn send_folder_update(&self, ut: UpdateType, folder: &Folder, device: &Device, conn: &mut DbConn) {
pub async fn send_folder_update(&self, ut: UpdateType, folder: &Folder, device: &Device, conn: &DbConn) {
// Skip any processing if both WebSockets and Push are not active
if *NOTIFICATIONS_DISABLED {
return;
@ -410,7 +410,7 @@ impl WebSocketUsers {
user_ids: &[UserId],
device: &Device,
collection_uuids: Option<Vec<CollectionId>>,
conn: &mut DbConn,
conn: &DbConn,
) {
// Skip any processing if both WebSockets and Push are not active
if *NOTIFICATIONS_DISABLED {
@ -458,7 +458,7 @@ impl WebSocketUsers {
send: &DbSend,
user_ids: &[UserId],
device: &Device,
conn: &mut DbConn,
conn: &DbConn,
) {
// Skip any processing if both WebSockets and Push are not active
if *NOTIFICATIONS_DISABLED {
@ -486,13 +486,7 @@ impl WebSocketUsers {
}
}
pub async fn send_auth_request(
&self,
user_id: &UserId,
auth_request_uuid: &str,
device: &Device,
conn: &mut DbConn,
) {
pub async fn send_auth_request(&self, user_id: &UserId, auth_request_uuid: &str, device: &Device, conn: &DbConn) {
// Skip any processing if both WebSockets and Push are not active
if *NOTIFICATIONS_DISABLED {
return;
@ -516,7 +510,7 @@ impl WebSocketUsers {
user_id: &UserId,
auth_request_id: &AuthRequestId,
device: &Device,
conn: &mut DbConn,
conn: &DbConn,
) {
// Skip any processing if both WebSockets and Push are not active
if *NOTIFICATIONS_DISABLED {

26
src/api/push.rs

@ -7,7 +7,10 @@ use tokio::sync::RwLock;
use crate::{
api::{ApiResult, EmptyResult, UpdateType},
db::models::{AuthRequestId, Cipher, Device, DeviceId, Folder, PushId, Send, User, UserId},
db::{
models::{AuthRequestId, Cipher, Device, DeviceId, Folder, PushId, Send, User, UserId},
DbConn,
},
http_client::make_http_request,
util::{format_date, get_uuid},
CONFIG,
@ -79,7 +82,7 @@ async fn get_auth_api_token() -> ApiResult<String> {
Ok(api_token.access_token.clone())
}
pub async fn register_push_device(device: &mut Device, conn: &mut crate::db::DbConn) -> EmptyResult {
pub async fn register_push_device(device: &mut Device, conn: &DbConn) -> EmptyResult {
if !CONFIG.push_enabled() || !device.is_push_device() {
return Ok(());
}
@ -152,7 +155,7 @@ pub async fn unregister_push_device(push_id: &Option<PushId>) -> EmptyResult {
Ok(())
}
pub async fn push_cipher_update(ut: UpdateType, cipher: &Cipher, device: &Device, conn: &mut crate::db::DbConn) {
pub async fn push_cipher_update(ut: UpdateType, cipher: &Cipher, device: &Device, conn: &DbConn) {
// We shouldn't send a push notification on cipher update if the cipher belongs to an organization, this isn't implemented in the upstream server too.
if cipher.organization_uuid.is_some() {
return;
@ -183,7 +186,7 @@ pub async fn push_cipher_update(ut: UpdateType, cipher: &Cipher, device: &Device
}
}
pub async fn push_logout(user: &User, acting_device_id: Option<DeviceId>, conn: &mut crate::db::DbConn) {
pub async fn push_logout(user: &User, acting_device_id: Option<DeviceId>, conn: &DbConn) {
let acting_device_id: Value = acting_device_id.map(|v| v.to_string().into()).unwrap_or_else(|| Value::Null);
if Device::check_user_has_push_device(&user.uuid, conn).await {
@ -203,7 +206,7 @@ pub async fn push_logout(user: &User, acting_device_id: Option<DeviceId>, conn:
}
}
pub async fn push_user_update(ut: UpdateType, user: &User, push_uuid: &Option<PushId>, conn: &mut crate::db::DbConn) {
pub async fn push_user_update(ut: UpdateType, user: &User, push_uuid: &Option<PushId>, conn: &DbConn) {
if Device::check_user_has_push_device(&user.uuid, conn).await {
tokio::task::spawn(send_to_push_relay(json!({
"userId": user.uuid,
@ -221,7 +224,7 @@ pub async fn push_user_update(ut: UpdateType, user: &User, push_uuid: &Option<Pu
}
}
pub async fn push_folder_update(ut: UpdateType, folder: &Folder, device: &Device, conn: &mut crate::db::DbConn) {
pub async fn push_folder_update(ut: UpdateType, folder: &Folder, device: &Device, conn: &DbConn) {
if Device::check_user_has_push_device(&folder.user_uuid, conn).await {
tokio::task::spawn(send_to_push_relay(json!({
"userId": folder.user_uuid,
@ -240,7 +243,7 @@ pub async fn push_folder_update(ut: UpdateType, folder: &Folder, device: &Device
}
}
pub async fn push_send_update(ut: UpdateType, send: &Send, device: &Device, conn: &mut crate::db::DbConn) {
pub async fn push_send_update(ut: UpdateType, send: &Send, device: &Device, conn: &DbConn) {
if let Some(s) = &send.user_uuid {
if Device::check_user_has_push_device(s, conn).await {
tokio::task::spawn(send_to_push_relay(json!({
@ -296,7 +299,7 @@ async fn send_to_push_relay(notification_data: Value) {
};
}
pub async fn push_auth_request(user_id: &UserId, auth_request_id: &str, device: &Device, conn: &mut crate::db::DbConn) {
pub async fn push_auth_request(user_id: &UserId, auth_request_id: &str, device: &Device, conn: &DbConn) {
if Device::check_user_has_push_device(user_id, conn).await {
tokio::task::spawn(send_to_push_relay(json!({
"userId": user_id,
@ -314,12 +317,7 @@ pub async fn push_auth_request(user_id: &UserId, auth_request_id: &str, device:
}
}
pub async fn push_auth_response(
user_id: &UserId,
auth_request_id: &AuthRequestId,
device: &Device,
conn: &mut crate::db::DbConn,
) {
pub async fn push_auth_response(user_id: &UserId, auth_request_id: &AuthRequestId, device: &Device, conn: &DbConn) {
if Device::check_user_has_push_device(user_id, conn).await {
tokio::task::spawn(send_to_push_relay(json!({
"userId": user_id,

20
src/auth.rs

@ -604,16 +604,16 @@ impl<'r> FromRequest<'r> for Headers {
let device_id = claims.device;
let user_id = claims.sub;
let mut conn = match DbConn::from_request(request).await {
let conn = match DbConn::from_request(request).await {
Outcome::Success(conn) => conn,
_ => err_handler!("Error getting DB"),
};
let Some(device) = Device::find_by_uuid_and_user(&device_id, &user_id, &mut conn).await else {
let Some(device) = Device::find_by_uuid_and_user(&device_id, &user_id, &conn).await else {
err_handler!("Invalid device id")
};
let Some(user) = User::find_by_uuid(&user_id, &mut conn).await else {
let Some(user) = User::find_by_uuid(&user_id, &conn).await else {
err_handler!("Device has no user associated")
};
@ -633,7 +633,7 @@ impl<'r> FromRequest<'r> for Headers {
// This prevents checking this stamp exception for new requests.
let mut user = user;
user.reset_stamp_exception();
if let Err(e) = user.save(&mut conn).await {
if let Err(e) = user.save(&conn).await {
error!("Error updating user: {e:#?}");
}
err_handler!("Stamp exception is expired")
@ -706,13 +706,13 @@ impl<'r> FromRequest<'r> for OrgHeaders {
match url_org_id {
Some(org_id) if uuid::Uuid::parse_str(&org_id).is_ok() => {
let mut conn = match DbConn::from_request(request).await {
let conn = match DbConn::from_request(request).await {
Outcome::Success(conn) => conn,
_ => err_handler!("Error getting DB"),
};
let user = headers.user;
let Some(membership) = Membership::find_by_user_and_org(&user.uuid, &org_id, &mut conn).await else {
let Some(membership) = Membership::find_by_user_and_org(&user.uuid, &org_id, &conn).await else {
err_handler!("The current user isn't member of the organization");
};
@ -815,12 +815,12 @@ impl<'r> FromRequest<'r> for ManagerHeaders {
if headers.is_confirmed_and_manager() {
match get_col_id(request) {
Some(col_id) => {
let mut conn = match DbConn::from_request(request).await {
let conn = match DbConn::from_request(request).await {
Outcome::Success(conn) => conn,
_ => err_handler!("Error getting DB"),
};
if !Collection::can_access_collection(&headers.membership, &col_id, &mut conn).await {
if !Collection::can_access_collection(&headers.membership, &col_id, &conn).await {
err_handler!("The current user isn't a manager for this collection")
}
}
@ -896,7 +896,7 @@ impl ManagerHeaders {
pub async fn from_loose(
h: ManagerHeadersLoose,
collections: &Vec<CollectionId>,
conn: &mut DbConn,
conn: &DbConn,
) -> Result<ManagerHeaders, Error> {
for col_id in collections {
if uuid::Uuid::parse_str(col_id.as_ref()).is_err() {
@ -1200,7 +1200,7 @@ pub async fn refresh_tokens(
ip: &ClientIp,
refresh_token: &str,
client_id: Option<String>,
conn: &mut DbConn,
conn: &DbConn,
) -> ApiResult<(Device, AuthTokens)> {
let refresh_claims = match decode_refresh(refresh_token) {
Err(err) => {

12
src/config.rs

@ -12,7 +12,6 @@ use once_cell::sync::Lazy;
use reqwest::Url;
use crate::{
db::DbConnType,
error::Error,
util::{get_env, get_env_bool, get_web_vault_version, is_valid_email, parse_experimental_client_feature_flags},
};
@ -815,12 +814,19 @@ make_config! {
fn validate_config(cfg: &ConfigItems) -> Result<(), Error> {
// Validate connection URL is valid and DB feature is enabled
#[cfg(sqlite)]
{
use crate::db::DbConnType;
let url = &cfg.database_url;
if DbConnType::from_url(url)? == DbConnType::sqlite && url.contains('/') {
if DbConnType::from_url(url)? == DbConnType::Sqlite && url.contains('/') {
let path = std::path::Path::new(&url);
if let Some(parent) = path.parent() {
if !parent.is_dir() {
err!(format!("SQLite database directory `{}` does not exist or is not a directory", parent.display()));
err!(format!(
"SQLite database directory `{}` does not exist or is not a directory",
parent.display()
));
}
}
}
}

379
src/db/mod.rs

@ -1,8 +1,14 @@
use std::{sync::Arc, time::Duration};
mod query_logger;
use std::{
sync::{Arc, OnceLock},
time::Duration,
};
use diesel::{
connection::SimpleConnection,
r2d2::{ConnectionManager, CustomizeConnection, Pool, PooledConnection},
Connection, RunQueryDsl,
};
use rocket::{
@ -21,20 +27,7 @@ use crate::{
CONFIG,
};
#[cfg(sqlite)]
#[path = "schemas/sqlite/schema.rs"]
pub mod __sqlite_schema;
#[cfg(mysql)]
#[path = "schemas/mysql/schema.rs"]
pub mod __mysql_schema;
#[cfg(postgresql)]
#[path = "schemas/postgresql/schema.rs"]
pub mod __postgresql_schema;
// These changes are based on Rocket 0.5-rc wrapper of Diesel: https://github.com/SergioBenitez/Rocket/blob/v0.5-rc/contrib/sync_db_pools
// A wrapper around spawn_blocking that propagates panics to the calling code.
pub async fn run_blocking<F, R>(job: F) -> R
where
@ -51,47 +44,54 @@ where
}
// This is used to generate the main DbConn and DbPool enums, which contain one variant for each database supported
macro_rules! generate_connections {
( $( $name:ident: $ty:ty ),+ ) => {
#[allow(non_camel_case_types, dead_code)]
#[derive(diesel::MultiConnection)]
pub enum DbConnInner {
#[cfg(mysql)]
Mysql(diesel::mysql::MysqlConnection),
#[cfg(postgresql)]
Postgresql(diesel::pg::PgConnection),
#[cfg(sqlite)]
Sqlite(diesel::sqlite::SqliteConnection),
}
#[derive(Eq, PartialEq)]
pub enum DbConnType { $( $name, )+ }
pub enum DbConnType {
#[cfg(mysql)]
Mysql,
#[cfg(postgresql)]
Postgresql,
#[cfg(sqlite)]
Sqlite,
}
pub static ACTIVE_DB_TYPE: OnceLock<DbConnType> = OnceLock::new();
pub struct DbConn {
conn: Arc<Mutex<Option<DbConnInner>>>,
conn: Arc<Mutex<Option<PooledConnection<ConnectionManager<DbConnInner>>>>>,
permit: Option<OwnedSemaphorePermit>,
}
#[allow(non_camel_case_types)]
pub enum DbConnInner { $( #[cfg($name)] $name(PooledConnection<ConnectionManager< $ty >>), )+ }
#[derive(Debug)]
pub struct DbConnOptions {
pub init_stmts: String,
}
$( // Based on <https://stackoverflow.com/a/57717533>.
#[cfg($name)]
impl CustomizeConnection<$ty, diesel::r2d2::Error> for DbConnOptions {
fn on_acquire(&self, conn: &mut $ty) -> Result<(), diesel::r2d2::Error> {
impl CustomizeConnection<DbConnInner, diesel::r2d2::Error> for DbConnOptions {
fn on_acquire(&self, conn: &mut DbConnInner) -> Result<(), diesel::r2d2::Error> {
if !self.init_stmts.is_empty() {
conn.batch_execute(&self.init_stmts).map_err(diesel::r2d2::Error::QueryError)?;
}
Ok(())
}
})+
}
#[derive(Clone)]
pub struct DbPool {
// This is an 'Option' so that we can drop the pool in a 'spawn_blocking'.
pool: Option<DbPoolInner>,
semaphore: Arc<Semaphore>
pool: Option<Pool<ConnectionManager<DbConnInner>>>,
semaphore: Arc<Semaphore>,
}
#[allow(non_camel_case_types)]
#[derive(Clone)]
pub enum DbPoolInner { $( #[cfg($name)] $name(Pool<ConnectionManager< $ty >>), )+ }
impl Drop for DbConn {
fn drop(&mut self) {
let conn = Arc::clone(&self.conn);
@ -116,7 +116,11 @@ macro_rules! generate_connections {
impl Drop for DbPool {
fn drop(&mut self) {
let pool = self.pool.take();
tokio::task::spawn_blocking(move || drop(pool));
// Only use spawn_blocking if the Tokio runtime is still available
// Otherwise the pool will be dropped on the current thread
if let Ok(handle) = tokio::runtime::Handle::try_current() {
handle.spawn_blocking(move || drop(pool));
}
}
}
@ -126,32 +130,53 @@ macro_rules! generate_connections {
let url = CONFIG.database_url();
let conn_type = DbConnType::from_url(&url)?;
match conn_type { $(
DbConnType::$name => {
#[cfg($name)]
// Only set the default instrumentation if the log level is specifically set to either warn, info or debug
if log_enabled!(target: "vaultwarden::db::query_logger", log::Level::Warn)
|| log_enabled!(target: "vaultwarden::db::query_logger", log::Level::Info)
|| log_enabled!(target: "vaultwarden::db::query_logger", log::Level::Debug)
{
pastey::paste!{ [< $name _migrations >]::run_migrations()?; }
let manager = ConnectionManager::new(&url);
drop(diesel::connection::set_default_instrumentation(query_logger::simple_logger));
}
match conn_type {
#[cfg(mysql)]
DbConnType::Mysql => {
mysql_migrations::run_migrations(&url)?;
}
#[cfg(postgresql)]
DbConnType::Postgresql => {
postgresql_migrations::run_migrations(&url)?;
}
#[cfg(sqlite)]
DbConnType::Sqlite => {
sqlite_migrations::run_migrations(&url)?;
}
}
let max_conns = CONFIG.database_max_conns();
let manager = ConnectionManager::<DbConnInner>::new(&url);
let pool = Pool::builder()
.max_size(CONFIG.database_max_conns())
.max_size(max_conns)
.min_idle(Some(CONFIG.database_min_conns()))
.idle_timeout(Some(Duration::from_secs(CONFIG.database_idle_timeout())))
.connection_timeout(Duration::from_secs(CONFIG.database_timeout()))
.connection_customizer(Box::new(DbConnOptions {
init_stmts: conn_type.get_init_stmts()
init_stmts: conn_type.get_init_stmts(),
}))
.build(manager)
.map_res("Failed to create pool")?;
// Set a global to determine the database more easily throughout the rest of the code
if ACTIVE_DB_TYPE.set(conn_type).is_err() {
error!("Tried to set the active database connection type more than once.")
}
Ok(DbPool {
pool: Some(DbPoolInner::$name(pool)),
semaphore: Arc::new(Semaphore::new(CONFIG.database_max_conns() as usize)),
pool: Some(pool),
semaphore: Arc::new(Semaphore::new(max_conns as usize)),
})
}
#[cfg(not($name))]
unreachable!("Trying to use a DB backend when it's feature is disabled")
},
)+ }
}
// Get a connection from the pool
pub async fn get(&self) -> Result<DbConn, Error> {
let duration = Duration::from_secs(CONFIG.database_timeout());
@ -162,51 +187,31 @@ macro_rules! generate_connections {
}
};
match self.pool.as_ref().expect("DbPool.pool should always be Some()") { $(
#[cfg($name)]
DbPoolInner::$name(p) => {
let p = self.pool.as_ref().expect("DbPool.pool should always be Some()");
let pool = p.clone();
let c = run_blocking(move || pool.get_timeout(duration)).await.map_res("Error retrieving connection from pool")?;
let c =
run_blocking(move || pool.get_timeout(duration)).await.map_res("Error retrieving connection from pool")?;
Ok(DbConn {
conn: Arc::new(Mutex::new(Some(DbConnInner::$name(c)))),
permit: Some(permit)
conn: Arc::new(Mutex::new(Some(c))),
permit: Some(permit),
})
},
)+ }
}
}
};
}
#[cfg(not(query_logger))]
generate_connections! {
sqlite: diesel::sqlite::SqliteConnection,
mysql: diesel::mysql::MysqlConnection,
postgresql: diesel::pg::PgConnection
}
#[cfg(query_logger)]
generate_connections! {
sqlite: diesel_logger::LoggingConnection<diesel::sqlite::SqliteConnection>,
mysql: diesel_logger::LoggingConnection<diesel::mysql::MysqlConnection>,
postgresql: diesel_logger::LoggingConnection<diesel::pg::PgConnection>
}
impl DbConnType {
pub fn from_url(url: &str) -> Result<DbConnType, Error> {
pub fn from_url(url: &str) -> Result<Self, Error> {
// Mysql
if url.starts_with("mysql:") {
if url.len() > 6 && &url[..6] == "mysql:" {
#[cfg(mysql)]
return Ok(DbConnType::mysql);
return Ok(DbConnType::Mysql);
#[cfg(not(mysql))]
err!("`DATABASE_URL` is a MySQL URL, but the 'mysql' feature is not enabled")
// Postgres
} else if url.starts_with("postgresql:") || url.starts_with("postgres:") {
// Postgresql
} else if url.len() > 11 && (&url[..11] == "postgresql:" || &url[..9] == "postgres:") {
#[cfg(postgresql)]
return Ok(DbConnType::postgresql);
return Ok(DbConnType::Postgresql);
#[cfg(not(postgresql))]
err!("`DATABASE_URL` is a PostgreSQL URL, but the 'postgresql' feature is not enabled")
@ -214,7 +219,7 @@ impl DbConnType {
//Sqlite
} else {
#[cfg(sqlite)]
return Ok(DbConnType::sqlite);
return Ok(DbConnType::Sqlite);
#[cfg(not(sqlite))]
err!("`DATABASE_URL` looks like a SQLite URL, but 'sqlite' feature is not enabled")
@ -232,175 +237,102 @@ impl DbConnType {
pub fn default_init_stmts(&self) -> String {
match self {
Self::sqlite => "PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;".to_string(),
Self::mysql => String::new(),
Self::postgresql => String::new(),
#[cfg(mysql)]
Self::Mysql => String::new(),
#[cfg(postgresql)]
Self::Postgresql => String::new(),
#[cfg(sqlite)]
Self::Sqlite => "PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;".to_string(),
}
}
}
#[macro_export]
macro_rules! db_run {
// Same for all dbs
( $conn:ident: $body:block ) => {
db_run! { $conn: sqlite, mysql, postgresql $body }
};
( @raw $conn:ident: $body:block ) => {
db_run! { @raw $conn: sqlite, mysql, postgresql $body }
};
// Different code for each db
( $conn:ident: $( $($db:ident),+ $body:block )+ ) => {{
#[allow(unused)] use diesel::prelude::*;
#[allow(unused)] use $crate::db::FromDb;
let conn = $conn.conn.clone();
macro_rules! db_run_base {
( $conn:ident ) => {
let conn = std::sync::Arc::clone(&$conn.conn);
let mut conn = conn.lock_owned().await;
match conn.as_mut().expect("internal invariant broken: self.connection is Some") {
$($(
#[cfg($db)]
$crate::db::DbConnInner::$db($conn) => {
pastey::paste! {
#[allow(unused)] use $crate::db::[<__ $db _schema>]::{self as schema, *};
#[allow(unused)] use [<__ $db _model>]::*;
let $conn = conn.as_mut().expect("internal invariant broken: self.conn is Some");
};
}
tokio::task::block_in_place(move || { $body }) // Run blocking can't be used due to the 'static limitation, use block_in_place instead
},
)+)+
}
#[macro_export]
macro_rules! db_run {
( $conn:ident: $body:block ) => {{
db_run_base!($conn);
// Run blocking can't be used due to the 'static limitation, use block_in_place instead
tokio::task::block_in_place(move || $body )
}};
( @raw $conn:ident: $( $($db:ident),+ $body:block )+ ) => {{
#[allow(unused)] use diesel::prelude::*;
#[allow(unused)] use $crate::db::FromDb;
let conn = $conn.conn.clone();
let mut conn = conn.lock_owned().await;
match conn.as_mut().expect("internal invariant broken: self.connection is Some") {
( $conn:ident: $( $($db:ident),+ $body:block )+ ) => {{
db_run_base!($conn);
match std::ops::DerefMut::deref_mut($conn) {
$($(
#[cfg($db)]
$crate::db::DbConnInner::$db($conn) => {
pastey::paste! {
#[allow(unused)] use $crate::db::[<__ $db _schema>]::{self as schema, *};
// @ RAW: #[allow(unused)] use [<__ $db _model>]::*;
}
tokio::task::block_in_place(move || { $body }) // Run blocking can't be used due to the 'static limitation, use block_in_place instead
pastey::paste!(&mut $crate::db::DbConnInner::[<$db:camel>](ref mut $conn)) => {
// Run blocking can't be used due to the 'static limitation, use block_in_place instead
tokio::task::block_in_place(move || $body )
},
)+)+
}
)+)+}
}};
}
pub trait FromDb {
type Output;
#[allow(clippy::wrong_self_convention)]
fn from_db(self) -> Self::Output;
}
impl<T: FromDb> FromDb for Vec<T> {
type Output = Vec<T::Output>;
#[inline(always)]
fn from_db(self) -> Self::Output {
self.into_iter().map(FromDb::from_db).collect()
}
}
impl<T: FromDb> FromDb for Option<T> {
type Output = Option<T::Output>;
#[inline(always)]
fn from_db(self) -> Self::Output {
self.map(FromDb::from_db)
}
}
// For each struct eg. Cipher, we create a CipherDb inside a module named __$db_model (where $db is sqlite, mysql or postgresql),
// to implement the Diesel traits. We also provide methods to convert between them and the basic structs. Later, that module will be auto imported when using db_run!
#[macro_export]
macro_rules! db_object {
( $(
$( #[$attr:meta] )*
pub struct $name:ident {
$( $( #[$field_attr:meta] )* $vis:vis $field:ident : $typ:ty ),+
$(,)?
}
)+ ) => {
// Create the normal struct, without attributes
$( pub struct $name { $( /*$( #[$field_attr] )**/ $vis $field : $typ, )+ } )+
#[cfg(sqlite)]
pub mod __sqlite_model { $( db_object! { @db sqlite | $( #[$attr] )* | $name | $( $( #[$field_attr] )* $field : $typ ),+ } )+ }
#[cfg(mysql)]
pub mod __mysql_model { $( db_object! { @db mysql | $( #[$attr] )* | $name | $( $( #[$field_attr] )* $field : $typ ),+ } )+ }
#[cfg(postgresql)]
pub mod __postgresql_model { $( db_object! { @db postgresql | $( #[$attr] )* | $name | $( $( #[$field_attr] )* $field : $typ ),+ } )+ }
};
( @db $db:ident | $( #[$attr:meta] )* | $name:ident | $( $( #[$field_attr:meta] )* $vis:vis $field:ident : $typ:ty),+) => {
pastey::paste! {
#[allow(unused)] use super::*;
#[allow(unused)] use diesel::prelude::*;
#[allow(unused)] use $crate::db::[<__ $db _schema>]::*;
$( #[$attr] )*
pub struct [<$name Db>] { $(
$( #[$field_attr] )* $vis $field : $typ,
)+ }
impl [<$name Db>] {
#[allow(clippy::wrong_self_convention)]
#[inline(always)] pub fn to_db(x: &super::$name) -> Self { Self { $( $field: x.$field.clone(), )+ } }
}
impl $crate::db::FromDb for [<$name Db>] {
type Output = super::$name;
#[allow(clippy::wrong_self_convention)]
#[inline(always)] fn from_db(self) -> Self::Output { super::$name { $( $field: self.$field, )+ } }
}
}
};
}
pub mod schema;
// Reexport the models, needs to be after the macros are defined so it can access them
pub mod models;
/// Creates a back-up of the sqlite database
/// MySQL/MariaDB and PostgreSQL are not supported.
pub async fn backup_database(conn: &mut DbConn) -> Result<String, Error> {
db_run! {@raw conn:
postgresql, mysql {
let _ = conn;
err!("PostgreSQL and MySQL/MariaDB do not support this backup feature");
}
sqlite {
#[cfg(sqlite)]
pub fn backup_sqlite() -> Result<String, Error> {
use diesel::Connection;
use std::{fs::File, io::Write};
let db_url = CONFIG.database_url();
if DbConnType::from_url(&CONFIG.database_url()).map(|t| t == DbConnType::Sqlite).unwrap_or(false) {
// Since we do not allow any schema for sqlite database_url's like `file:` or `sqlite:` to be set, we can assume here it isn't
// This way we can set a readonly flag on the opening mode without issues.
let mut conn = diesel::sqlite::SqliteConnection::establish(&format!("sqlite://{db_url}?mode=ro"))?;
let db_path = std::path::Path::new(&db_url).parent().unwrap();
let backup_file = db_path
.join(format!("db_{}.sqlite3", chrono::Utc::now().format("%Y%m%d_%H%M%S")))
.to_string_lossy()
.into_owned();
diesel::sql_query(format!("VACUUM INTO '{backup_file}'")).execute(conn)?;
match File::create(backup_file.clone()) {
Ok(mut f) => {
let serialized_db = conn.serialize_database_to_buffer();
f.write_all(serialized_db.as_slice()).expect("Error writing SQLite backup");
Ok(backup_file)
}
Err(e) => {
err_silent!(format!("Unable to save SQLite backup: {e:?}"))
}
}
} else {
err_silent!("The database type is not SQLite. Backups only works for SQLite databases")
}
}
#[cfg(not(sqlite))]
pub fn backup_sqlite() -> Result<String, Error> {
err_silent!("The database type is not SQLite. Backups only works for SQLite databases")
}
/// Get the SQL Server version
pub async fn get_sql_server_version(conn: &mut DbConn) -> String {
db_run! {@raw conn:
pub async fn get_sql_server_version(conn: &DbConn) -> String {
db_run! { conn:
postgresql,mysql {
define_sql_function!{
fn version() -> diesel::sql_types::Text;
}
diesel::select(version()).get_result::<String>(conn).unwrap_or_else(|_| "Unknown".to_string())
diesel::select(diesel::dsl::sql::<diesel::sql_types::Text>("version();"))
.get_result::<String>(conn)
.unwrap_or_else(|_| "Unknown".to_string())
}
sqlite {
define_sql_function!{
fn sqlite_version() -> diesel::sql_types::Text;
}
diesel::select(sqlite_version()).get_result::<String>(conn).unwrap_or_else(|_| "Unknown".to_string())
diesel::select(diesel::dsl::sql::<diesel::sql_types::Text>("sqlite_version();"))
.get_result::<String>(conn)
.unwrap_or_else(|_| "Unknown".to_string())
}
}
}
@ -428,16 +360,14 @@ impl<'r> FromRequest<'r> for DbConn {
// https://docs.rs/diesel_migrations/*/diesel_migrations/macro.embed_migrations.html
#[cfg(sqlite)]
mod sqlite_migrations {
use diesel::{Connection, RunQueryDsl};
use diesel_migrations::{EmbeddedMigrations, MigrationHarness};
pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/sqlite");
pub fn run_migrations() -> Result<(), super::Error> {
use diesel::{Connection, RunQueryDsl};
let url = crate::CONFIG.database_url();
pub fn run_migrations(url: &str) -> Result<(), super::Error> {
// Establish a connection to the sqlite database (this will create a new one, if it does
// not exist, and exit if there is an error).
let mut connection = diesel::sqlite::SqliteConnection::establish(&url)?;
let mut connection = diesel::sqlite::SqliteConnection::establish(url)?;
// Run the migrations after successfully establishing a connection
// Disable Foreign Key Checks during migration
@ -458,15 +388,15 @@ mod sqlite_migrations {
#[cfg(mysql)]
mod mysql_migrations {
use diesel::{Connection, RunQueryDsl};
use diesel_migrations::{EmbeddedMigrations, MigrationHarness};
pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/mysql");
pub fn run_migrations() -> Result<(), super::Error> {
use diesel::{Connection, RunQueryDsl};
pub fn run_migrations(url: &str) -> Result<(), super::Error> {
// Make sure the database is up to date (create if it doesn't exist, or run the migrations)
let mut connection = diesel::mysql::MysqlConnection::establish(&crate::CONFIG.database_url())?;
// Disable Foreign Key Checks during migration
let mut connection = diesel::mysql::MysqlConnection::establish(url)?;
// Disable Foreign Key Checks during migration
// Scoped to a connection/session.
diesel::sql_query("SET FOREIGN_KEY_CHECKS = 0")
.execute(&mut connection)
@ -479,13 +409,14 @@ mod mysql_migrations {
#[cfg(postgresql)]
mod postgresql_migrations {
use diesel::Connection;
use diesel_migrations::{EmbeddedMigrations, MigrationHarness};
pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/postgresql");
pub fn run_migrations() -> Result<(), super::Error> {
use diesel::Connection;
pub fn run_migrations(url: &str) -> Result<(), super::Error> {
// Make sure the database is up to date (create if it doesn't exist, or run the migrations)
let mut connection = diesel::pg::PgConnection::establish(&crate::CONFIG.database_url())?;
let mut connection = diesel::pg::PgConnection::establish(url)?;
connection.run_pending_migrations(MIGRATIONS).expect("Error running migrations");
Ok(())
}

50
src/db/models/attachment.rs

@ -1,14 +1,14 @@
use std::time::Duration;
use bigdecimal::{BigDecimal, ToPrimitive};
use derive_more::{AsRef, Deref, Display};
use diesel::prelude::*;
use serde_json::Value;
use std::time::Duration;
use super::{CipherId, OrganizationId, UserId};
use crate::db::schema::{attachments, ciphers};
use crate::{config::PathType, CONFIG};
use macros::IdFromParam;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = attachments)]
#[diesel(treat_none_as_null = true)]
@ -20,7 +20,6 @@ db_object! {
pub file_size: i64,
pub akey: Option<String>,
}
}
/// Local methods
impl Attachment {
@ -76,11 +75,11 @@ use crate::error::MapResult;
/// Database methods
impl Attachment {
pub async fn save(&self, conn: &mut DbConn) -> EmptyResult {
pub async fn save(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn:
sqlite, mysql {
match diesel::replace_into(attachments::table)
.values(AttachmentDb::to_db(self))
.values(self)
.execute(conn)
{
Ok(_) => Ok(()),
@ -88,7 +87,7 @@ impl Attachment {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(attachments::table)
.filter(attachments::id.eq(&self.id))
.set(AttachmentDb::to_db(self))
.set(self)
.execute(conn)
.map_res("Error saving attachment")
}
@ -96,22 +95,22 @@ impl Attachment {
}.map_res("Error saving attachment")
}
postgresql {
let value = AttachmentDb::to_db(self);
diesel::insert_into(attachments::table)
.values(&value)
.values(self)
.on_conflict(attachments::id)
.do_update()
.set(&value)
.set(self)
.execute(conn)
.map_res("Error saving attachment")
}
}
}
pub async fn delete(&self, conn: &mut DbConn) -> EmptyResult {
pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
crate::util::retry(
|| diesel::delete(attachments::table.filter(attachments::id.eq(&self.id))).execute(conn),
crate::util::retry(||
diesel::delete(attachments::table.filter(attachments::id.eq(&self.id)))
.execute(conn),
10,
)
.map(|_| ())
@ -132,34 +131,32 @@ impl Attachment {
Ok(())
}
pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult {
for attachment in Attachment::find_by_cipher(cipher_uuid, conn).await {
attachment.delete(conn).await?;
}
Ok(())
}
pub async fn find_by_id(id: &AttachmentId, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_id(id: &AttachmentId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
attachments::table
.filter(attachments::id.eq(id.to_lowercase()))
.first::<AttachmentDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
pub async fn find_by_cipher(cipher_uuid: &CipherId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
attachments::table
.filter(attachments::cipher_uuid.eq(cipher_uuid))
.load::<AttachmentDb>(conn)
.load::<Self>(conn)
.expect("Error loading attachments")
.from_db()
}}
}
pub async fn size_by_user(user_uuid: &UserId, conn: &mut DbConn) -> i64 {
pub async fn size_by_user(user_uuid: &UserId, conn: &DbConn) -> i64 {
db_run! { conn: {
let result: Option<BigDecimal> = attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
@ -176,7 +173,7 @@ impl Attachment {
}}
}
pub async fn count_by_user(user_uuid: &UserId, conn: &mut DbConn) -> i64 {
pub async fn count_by_user(user_uuid: &UserId, conn: &DbConn) -> i64 {
db_run! { conn: {
attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
@ -187,7 +184,7 @@ impl Attachment {
}}
}
pub async fn size_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> i64 {
pub async fn size_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: {
let result: Option<BigDecimal> = attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
@ -204,7 +201,7 @@ impl Attachment {
}}
}
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> i64 {
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: {
attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
@ -221,7 +218,7 @@ impl Attachment {
pub async fn find_all_by_user_and_orgs(
user_uuid: &UserId,
org_uuids: &Vec<OrganizationId>,
conn: &mut DbConn,
conn: &DbConn,
) -> Vec<Self> {
db_run! { conn: {
attachments::table
@ -229,9 +226,8 @@ impl Attachment {
.filter(ciphers::user_uuid.eq(user_uuid))
.or_filter(ciphers::organization_uuid.eq_any(org_uuids))
.select(attachments::all_columns)
.load::<AttachmentDb>(conn)
.load::<Self>(conn)
.expect("Error loading attachments")
.from_db()
}}
}
}

44
src/db/models/auth_request.rs

@ -1,11 +1,12 @@
use super::{DeviceId, OrganizationId, UserId};
use crate::db::schema::auth_requests;
use crate::{crypto::ct_eq, util::format_date};
use chrono::{NaiveDateTime, Utc};
use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use macros::UuidFromParam;
use serde_json::Value;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset, Deserialize, Serialize)]
#[diesel(table_name = auth_requests)]
#[diesel(treat_none_as_null = true)]
@ -33,7 +34,6 @@ db_object! {
pub authentication_date: Option<NaiveDateTime>,
}
}
impl AuthRequest {
pub fn new(
@ -80,11 +80,11 @@ use crate::api::EmptyResult;
use crate::error::MapResult;
impl AuthRequest {
pub async fn save(&mut self, conn: &mut DbConn) -> EmptyResult {
pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
db_run! { conn:
sqlite, mysql {
match diesel::replace_into(auth_requests::table)
.values(AuthRequestDb::to_db(self))
.values(&*self)
.execute(conn)
{
Ok(_) => Ok(()),
@ -92,7 +92,7 @@ impl AuthRequest {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(auth_requests::table)
.filter(auth_requests::uuid.eq(&self.uuid))
.set(AuthRequestDb::to_db(self))
.set(&*self)
.execute(conn)
.map_res("Error auth_request")
}
@ -100,51 +100,49 @@ impl AuthRequest {
}.map_res("Error auth_request")
}
postgresql {
let value = AuthRequestDb::to_db(self);
diesel::insert_into(auth_requests::table)
.values(&value)
.values(&*self)
.on_conflict(auth_requests::uuid)
.do_update()
.set(&value)
.set(&*self)
.execute(conn)
.map_res("Error saving auth_request")
}
}
}
pub async fn find_by_uuid(uuid: &AuthRequestId, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_uuid(uuid: &AuthRequestId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
auth_requests::table
.filter(auth_requests::uuid.eq(uuid))
.first::<AuthRequestDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
pub async fn find_by_uuid_and_user(uuid: &AuthRequestId, user_uuid: &UserId, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_uuid_and_user(uuid: &AuthRequestId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
auth_requests::table
.filter(auth_requests::uuid.eq(uuid))
.filter(auth_requests::user_uuid.eq(user_uuid))
.first::<AuthRequestDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
pub async fn find_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
auth_requests::table
.filter(auth_requests::user_uuid.eq(user_uuid))
.load::<AuthRequestDb>(conn).expect("Error loading auth_requests").from_db()
.load::<Self>(conn)
.expect("Error loading auth_requests")
}}
}
pub async fn find_by_user_and_requested_device(
user_uuid: &UserId,
device_uuid: &DeviceId,
conn: &mut DbConn,
conn: &DbConn,
) -> Option<Self> {
db_run! { conn: {
auth_requests::table
@ -152,19 +150,21 @@ impl AuthRequest {
.filter(auth_requests::request_device_identifier.eq(device_uuid))
.filter(auth_requests::approved.is_null())
.order_by(auth_requests::creation_date.desc())
.first::<AuthRequestDb>(conn).ok().from_db()
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_created_before(dt: &NaiveDateTime, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_created_before(dt: &NaiveDateTime, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
auth_requests::table
.filter(auth_requests::creation_date.lt(dt))
.load::<AuthRequestDb>(conn).expect("Error loading auth_requests").from_db()
.load::<Self>(conn)
.expect("Error loading auth_requests")
}}
}
pub async fn delete(&self, conn: &mut DbConn) -> EmptyResult {
pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(auth_requests::table.filter(auth_requests::uuid.eq(&self.uuid)))
.execute(conn)
@ -176,7 +176,7 @@ impl AuthRequest {
ct_eq(&self.access_code, access_code)
}
pub async fn purge_expired_auth_requests(conn: &mut DbConn) {
pub async fn purge_expired_auth_requests(conn: &DbConn) {
let expiry_time = Utc::now().naive_utc() - chrono::TimeDelta::try_minutes(5).unwrap(); //after 5 minutes, clients reject the request
for auth_request in Self::find_created_before(&expiry_time, conn).await {
auth_request.delete(conn).await.ok();

133
src/db/models/cipher.rs

@ -1,7 +1,12 @@
use crate::db::schema::{
ciphers, ciphers_collections, collections, collections_groups, folders, folders_ciphers, groups, groups_users,
users_collections, users_organizations,
};
use crate::util::LowerCase;
use crate::CONFIG;
use chrono::{NaiveDateTime, TimeDelta, Utc};
use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use serde_json::Value;
use super::{
@ -13,7 +18,6 @@ use macros::UuidFromParam;
use std::borrow::Cow;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = ciphers)]
#[diesel(treat_none_as_null = true)]
@ -46,7 +50,6 @@ db_object! {
pub deleted_at: Option<NaiveDateTime>,
pub reprompt: Option<i32>,
}
}
pub enum RepromptType {
None = 0,
@ -140,7 +143,7 @@ impl Cipher {
user_uuid: &UserId,
cipher_sync_data: Option<&CipherSyncData>,
sync_type: CipherSyncType,
conn: &mut DbConn,
conn: &DbConn,
) -> Result<Value, crate::Error> {
use crate::util::{format_date, validate_and_format_date};
@ -402,7 +405,7 @@ impl Cipher {
Ok(json_object)
}
pub async fn update_users_revision(&self, conn: &mut DbConn) -> Vec<UserId> {
pub async fn update_users_revision(&self, conn: &DbConn) -> Vec<UserId> {
let mut user_uuids = Vec::new();
match self.user_uuid {
Some(ref user_uuid) => {
@ -430,14 +433,14 @@ impl Cipher {
user_uuids
}
pub async fn save(&mut self, conn: &mut DbConn) -> EmptyResult {
pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
self.update_users_revision(conn).await;
self.updated_at = Utc::now().naive_utc();
db_run! { conn:
sqlite, mysql {
match diesel::replace_into(ciphers::table)
.values(CipherDb::to_db(self))
.values(&*self)
.execute(conn)
{
Ok(_) => Ok(()),
@ -445,7 +448,7 @@ impl Cipher {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(ciphers::table)
.filter(ciphers::uuid.eq(&self.uuid))
.set(CipherDb::to_db(self))
.set(&*self)
.execute(conn)
.map_res("Error saving cipher")
}
@ -453,19 +456,18 @@ impl Cipher {
}.map_res("Error saving cipher")
}
postgresql {
let value = CipherDb::to_db(self);
diesel::insert_into(ciphers::table)
.values(&value)
.values(&*self)
.on_conflict(ciphers::uuid)
.do_update()
.set(&value)
.set(&*self)
.execute(conn)
.map_res("Error saving cipher")
}
}
}
pub async fn delete(&self, conn: &mut DbConn) -> EmptyResult {
pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
self.update_users_revision(conn).await;
FolderCipher::delete_all_by_cipher(&self.uuid, conn).await?;
@ -480,7 +482,7 @@ impl Cipher {
}}
}
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
// TODO: Optimize this by executing a DELETE directly on the database, instead of first fetching.
for cipher in Self::find_by_org(org_uuid, conn).await {
cipher.delete(conn).await?;
@ -488,7 +490,7 @@ impl Cipher {
Ok(())
}
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
for cipher in Self::find_owned_by_user(user_uuid, conn).await {
cipher.delete(conn).await?;
}
@ -496,7 +498,7 @@ impl Cipher {
}
/// Purge all ciphers that are old enough to be auto-deleted.
pub async fn purge_trash(conn: &mut DbConn) {
pub async fn purge_trash(conn: &DbConn) {
if let Some(auto_delete_days) = CONFIG.trash_auto_delete_days() {
let now = Utc::now().naive_utc();
let dt = now - TimeDelta::try_days(auto_delete_days).unwrap();
@ -510,7 +512,7 @@ impl Cipher {
&self,
folder_uuid: Option<FolderId>,
user_uuid: &UserId,
conn: &mut DbConn,
conn: &DbConn,
) -> EmptyResult {
User::update_uuid_revision(user_uuid, conn).await;
@ -550,7 +552,7 @@ impl Cipher {
&self,
user_uuid: &UserId,
cipher_sync_data: Option<&CipherSyncData>,
conn: &mut DbConn,
conn: &DbConn,
) -> bool {
if let Some(ref org_uuid) = self.organization_uuid {
if let Some(cipher_sync_data) = cipher_sync_data {
@ -569,7 +571,7 @@ impl Cipher {
&self,
user_uuid: &UserId,
cipher_sync_data: Option<&CipherSyncData>,
conn: &mut DbConn,
conn: &DbConn,
) -> bool {
if !CONFIG.org_groups_enabled() {
return false;
@ -593,7 +595,7 @@ impl Cipher {
&self,
user_uuid: &UserId,
cipher_sync_data: Option<&CipherSyncData>,
conn: &mut DbConn,
conn: &DbConn,
) -> Option<(bool, bool, bool)> {
// Check whether this cipher is directly owned by the user, or is in
// a collection that the user has full access to. If so, there are no
@ -659,11 +661,7 @@ impl Cipher {
Some((read_only, hide_passwords, manage))
}
async fn get_user_collections_access_flags(
&self,
user_uuid: &UserId,
conn: &mut DbConn,
) -> Vec<(bool, bool, bool)> {
async fn get_user_collections_access_flags(&self, user_uuid: &UserId, conn: &DbConn) -> Vec<(bool, bool, bool)> {
db_run! { conn: {
// Check whether this cipher is in any collections accessible to the
// user. If so, retrieve the access flags for each collection.
@ -680,11 +678,7 @@ impl Cipher {
}}
}
async fn get_group_collections_access_flags(
&self,
user_uuid: &UserId,
conn: &mut DbConn,
) -> Vec<(bool, bool, bool)> {
async fn get_group_collections_access_flags(&self, user_uuid: &UserId, conn: &DbConn) -> Vec<(bool, bool, bool)> {
if !CONFIG.org_groups_enabled() {
return Vec::new();
}
@ -710,31 +704,31 @@ impl Cipher {
}}
}
pub async fn is_write_accessible_to_user(&self, user_uuid: &UserId, conn: &mut DbConn) -> bool {
pub async fn is_write_accessible_to_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool {
match self.get_access_restrictions(user_uuid, None, conn).await {
Some((read_only, _hide_passwords, manage)) => !read_only || manage,
None => false,
}
}
pub async fn is_accessible_to_user(&self, user_uuid: &UserId, conn: &mut DbConn) -> bool {
pub async fn is_accessible_to_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool {
self.get_access_restrictions(user_uuid, None, conn).await.is_some()
}
// Returns whether this cipher is a favorite of the specified user.
pub async fn is_favorite(&self, user_uuid: &UserId, conn: &mut DbConn) -> bool {
pub async fn is_favorite(&self, user_uuid: &UserId, conn: &DbConn) -> bool {
Favorite::is_favorite(&self.uuid, user_uuid, conn).await
}
// Sets whether this cipher is a favorite of the specified user.
pub async fn set_favorite(&self, favorite: Option<bool>, user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult {
pub async fn set_favorite(&self, favorite: Option<bool>, user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
match favorite {
None => Ok(()), // No change requested.
Some(status) => Favorite::set_favorite(status, &self.uuid, user_uuid, conn).await,
}
}
pub async fn get_folder_uuid(&self, user_uuid: &UserId, conn: &mut DbConn) -> Option<FolderId> {
pub async fn get_folder_uuid(&self, user_uuid: &UserId, conn: &DbConn) -> Option<FolderId> {
db_run! { conn: {
folders_ciphers::table
.inner_join(folders::table)
@ -746,28 +740,26 @@ impl Cipher {
}}
}
pub async fn find_by_uuid(uuid: &CipherId, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_uuid(uuid: &CipherId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
ciphers::table
.filter(ciphers::uuid.eq(uuid))
.first::<CipherDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
pub async fn find_by_uuid_and_org(
cipher_uuid: &CipherId,
org_uuid: &OrganizationId,
conn: &mut DbConn,
conn: &DbConn,
) -> Option<Self> {
db_run! { conn: {
ciphers::table
.filter(ciphers::uuid.eq(cipher_uuid))
.filter(ciphers::organization_uuid.eq(org_uuid))
.first::<CipherDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
@ -787,7 +779,7 @@ impl Cipher {
user_uuid: &UserId,
visible_only: bool,
cipher_uuids: &Vec<CipherId>,
conn: &mut DbConn,
conn: &DbConn,
) -> Vec<Self> {
if CONFIG.org_groups_enabled() {
db_run! { conn: {
@ -839,7 +831,8 @@ impl Cipher {
query
.select(ciphers::all_columns)
.distinct()
.load::<CipherDb>(conn).expect("Error loading ciphers").from_db()
.load::<Self>(conn)
.expect("Error loading ciphers")
}}
} else {
db_run! { conn: {
@ -878,45 +871,43 @@ impl Cipher {
query
.select(ciphers::all_columns)
.distinct()
.load::<CipherDb>(conn).expect("Error loading ciphers").from_db()
.load::<Self>(conn)
.expect("Error loading ciphers")
}}
}
}
// Find all ciphers visible to the specified user.
pub async fn find_by_user_visible(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_user_visible(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
Self::find_by_user(user_uuid, true, &vec![], conn).await
}
pub async fn find_by_user_and_ciphers(
user_uuid: &UserId,
cipher_uuids: &Vec<CipherId>,
conn: &mut DbConn,
conn: &DbConn,
) -> Vec<Self> {
Self::find_by_user(user_uuid, true, cipher_uuids, conn).await
}
pub async fn find_by_user_and_cipher(
user_uuid: &UserId,
cipher_uuid: &CipherId,
conn: &mut DbConn,
) -> Option<Self> {
pub async fn find_by_user_and_cipher(user_uuid: &UserId, cipher_uuid: &CipherId, conn: &DbConn) -> Option<Self> {
Self::find_by_user(user_uuid, true, &vec![cipher_uuid.clone()], conn).await.pop()
}
// Find all ciphers directly owned by the specified user.
pub async fn find_owned_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_owned_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
ciphers::table
.filter(
ciphers::user_uuid.eq(user_uuid)
.and(ciphers::organization_uuid.is_null())
)
.load::<CipherDb>(conn).expect("Error loading ciphers").from_db()
.load::<Self>(conn)
.expect("Error loading ciphers")
}}
}
pub async fn count_owned_by_user(user_uuid: &UserId, conn: &mut DbConn) -> i64 {
pub async fn count_owned_by_user(user_uuid: &UserId, conn: &DbConn) -> i64 {
db_run! { conn: {
ciphers::table
.filter(ciphers::user_uuid.eq(user_uuid))
@ -927,15 +918,16 @@ impl Cipher {
}}
}
pub async fn find_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
ciphers::table
.filter(ciphers::organization_uuid.eq(org_uuid))
.load::<CipherDb>(conn).expect("Error loading ciphers").from_db()
.load::<Self>(conn)
.expect("Error loading ciphers")
}}
}
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> i64 {
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: {
ciphers::table
.filter(ciphers::organization_uuid.eq(org_uuid))
@ -946,25 +938,27 @@ impl Cipher {
}}
}
pub async fn find_by_folder(folder_uuid: &FolderId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_folder(folder_uuid: &FolderId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
folders_ciphers::table.inner_join(ciphers::table)
.filter(folders_ciphers::folder_uuid.eq(folder_uuid))
.select(ciphers::all_columns)
.load::<CipherDb>(conn).expect("Error loading ciphers").from_db()
.load::<Self>(conn)
.expect("Error loading ciphers")
}}
}
/// Find all ciphers that were deleted before the specified datetime.
pub async fn find_deleted_before(dt: &NaiveDateTime, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_deleted_before(dt: &NaiveDateTime, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
ciphers::table
.filter(ciphers::deleted_at.lt(dt))
.load::<CipherDb>(conn).expect("Error loading ciphers").from_db()
.load::<Self>(conn)
.expect("Error loading ciphers")
}}
}
pub async fn get_collections(&self, user_uuid: UserId, conn: &mut DbConn) -> Vec<CollectionId> {
pub async fn get_collections(&self, user_uuid: UserId, conn: &DbConn) -> Vec<CollectionId> {
if CONFIG.org_groups_enabled() {
db_run! { conn: {
ciphers_collections::table
@ -996,7 +990,8 @@ impl Cipher {
.and(collections_groups::read_only.eq(false)))
)
.select(ciphers_collections::collection_uuid)
.load::<CollectionId>(conn).unwrap_or_default()
.load::<CollectionId>(conn)
.unwrap_or_default()
}}
} else {
db_run! { conn: {
@ -1018,12 +1013,13 @@ impl Cipher {
.and(users_collections::read_only.eq(false)))
)
.select(ciphers_collections::collection_uuid)
.load::<CollectionId>(conn).unwrap_or_default()
.load::<CollectionId>(conn)
.unwrap_or_default()
}}
}
}
pub async fn get_admin_collections(&self, user_uuid: UserId, conn: &mut DbConn) -> Vec<CollectionId> {
pub async fn get_admin_collections(&self, user_uuid: UserId, conn: &DbConn) -> Vec<CollectionId> {
if CONFIG.org_groups_enabled() {
db_run! { conn: {
ciphers_collections::table
@ -1056,7 +1052,8 @@ impl Cipher {
.or(users_organizations::atype.le(MembershipType::Admin as i32)) // User is admin or owner
)
.select(ciphers_collections::collection_uuid)
.load::<CollectionId>(conn).unwrap_or_default()
.load::<CollectionId>(conn)
.unwrap_or_default()
}}
} else {
db_run! { conn: {
@ -1079,7 +1076,8 @@ impl Cipher {
.or(users_organizations::atype.le(MembershipType::Admin as i32)) // User is admin or owner
)
.select(ciphers_collections::collection_uuid)
.load::<CollectionId>(conn).unwrap_or_default()
.load::<CollectionId>(conn)
.unwrap_or_default()
}}
}
}
@ -1088,7 +1086,7 @@ impl Cipher {
/// This is used during a full sync so we only need one query for all collections accessible.
pub async fn get_collections_with_cipher_by_user(
user_uuid: UserId,
conn: &mut DbConn,
conn: &DbConn,
) -> Vec<(CipherId, CollectionId)> {
db_run! { conn: {
ciphers_collections::table
@ -1123,7 +1121,8 @@ impl Cipher {
.or_filter(collections_groups::collections_uuid.is_not_null()) //Access via group
.select(ciphers_collections::all_columns)
.distinct()
.load::<(CipherId, CollectionId)>(conn).unwrap_or_default()
.load::<(CipherId, CollectionId)>(conn)
.unwrap_or_default()
}}
}
}

128
src/db/models/collection.rs

@ -5,10 +5,13 @@ use super::{
CipherId, CollectionGroup, GroupUser, Membership, MembershipId, MembershipStatus, MembershipType, OrganizationId,
User, UserId,
};
use crate::db::schema::{
ciphers_collections, collections, collections_groups, groups, groups_users, users_collections, users_organizations,
};
use crate::CONFIG;
use diesel::prelude::*;
use macros::UuidFromParam;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = collections)]
#[diesel(treat_none_as_null = true)]
@ -38,7 +41,6 @@ db_object! {
pub cipher_uuid: CipherId,
pub collection_uuid: CollectionId,
}
}
/// Local methods
impl Collection {
@ -83,7 +85,7 @@ impl Collection {
&self,
user_uuid: &UserId,
cipher_sync_data: Option<&crate::api::core::CipherSyncData>,
conn: &mut DbConn,
conn: &DbConn,
) -> Value {
let (read_only, hide_passwords, manage) = if let Some(cipher_sync_data) = cipher_sync_data {
match cipher_sync_data.members.get(&self.org_uuid) {
@ -135,7 +137,7 @@ impl Collection {
json_object
}
pub async fn can_access_collection(member: &Membership, col_id: &CollectionId, conn: &mut DbConn) -> bool {
pub async fn can_access_collection(member: &Membership, col_id: &CollectionId, conn: &DbConn) -> bool {
member.has_status(MembershipStatus::Confirmed)
&& (member.has_full_access()
|| CollectionUser::has_access_to_collection_by_user(col_id, &member.user_uuid, conn).await
@ -152,13 +154,13 @@ use crate::error::MapResult;
/// Database methods
impl Collection {
pub async fn save(&self, conn: &mut DbConn) -> EmptyResult {
pub async fn save(&self, conn: &DbConn) -> EmptyResult {
self.update_users_revision(conn).await;
db_run! { conn:
sqlite, mysql {
match diesel::replace_into(collections::table)
.values(CollectionDb::to_db(self))
.values(self)
.execute(conn)
{
Ok(_) => Ok(()),
@ -166,7 +168,7 @@ impl Collection {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(collections::table)
.filter(collections::uuid.eq(&self.uuid))
.set(CollectionDb::to_db(self))
.set(self)
.execute(conn)
.map_res("Error saving collection")
}
@ -174,19 +176,18 @@ impl Collection {
}.map_res("Error saving collection")
}
postgresql {
let value = CollectionDb::to_db(self);
diesel::insert_into(collections::table)
.values(&value)
.values(self)
.on_conflict(collections::uuid)
.do_update()
.set(&value)
.set(self)
.execute(conn)
.map_res("Error saving collection")
}
}
}
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult {
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
self.update_users_revision(conn).await;
CollectionCipher::delete_all_by_collection(&self.uuid, conn).await?;
CollectionUser::delete_all_by_collection(&self.uuid, conn).await?;
@ -199,30 +200,29 @@ impl Collection {
}}
}
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
for collection in Self::find_by_organization(org_uuid, conn).await {
collection.delete(conn).await?;
}
Ok(())
}
pub async fn update_users_revision(&self, conn: &mut DbConn) {
pub async fn update_users_revision(&self, conn: &DbConn) {
for member in Membership::find_by_collection_and_org(&self.uuid, &self.org_uuid, conn).await.iter() {
User::update_uuid_revision(&member.user_uuid, conn).await;
}
}
pub async fn find_by_uuid(uuid: &CollectionId, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_uuid(uuid: &CollectionId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
collections::table
.filter(collections::uuid.eq(uuid))
.first::<CollectionDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
pub async fn find_by_user_uuid(user_uuid: UserId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_user_uuid(user_uuid: UserId, conn: &DbConn) -> Vec<Self> {
if CONFIG.org_groups_enabled() {
db_run! { conn: {
collections::table
@ -263,7 +263,8 @@ impl Collection {
)
.select(collections::all_columns)
.distinct()
.load::<CollectionDb>(conn).expect("Error loading collections").from_db()
.load::<Self>(conn)
.expect("Error loading collections")
}}
} else {
db_run! { conn: {
@ -288,7 +289,8 @@ impl Collection {
)
.select(collections::all_columns)
.distinct()
.load::<CollectionDb>(conn).expect("Error loading collections").from_db()
.load::<Self>(conn)
.expect("Error loading collections")
}}
}
}
@ -296,7 +298,7 @@ impl Collection {
pub async fn find_by_organization_and_user_uuid(
org_uuid: &OrganizationId,
user_uuid: &UserId,
conn: &mut DbConn,
conn: &DbConn,
) -> Vec<Self> {
Self::find_by_user_uuid(user_uuid.to_owned(), conn)
.await
@ -305,17 +307,16 @@ impl Collection {
.collect()
}
pub async fn find_by_organization(org_uuid: &OrganizationId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
collections::table
.filter(collections::org_uuid.eq(org_uuid))
.load::<CollectionDb>(conn)
.load::<Self>(conn)
.expect("Error loading collections")
.from_db()
}}
}
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> i64 {
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: {
collections::table
.filter(collections::org_uuid.eq(org_uuid))
@ -326,23 +327,18 @@ impl Collection {
}}
}
pub async fn find_by_uuid_and_org(
uuid: &CollectionId,
org_uuid: &OrganizationId,
conn: &mut DbConn,
) -> Option<Self> {
pub async fn find_by_uuid_and_org(uuid: &CollectionId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
collections::table
.filter(collections::uuid.eq(uuid))
.filter(collections::org_uuid.eq(org_uuid))
.select(collections::all_columns)
.first::<CollectionDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
pub async fn find_by_uuid_and_user(uuid: &CollectionId, user_uuid: UserId, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_uuid_and_user(uuid: &CollectionId, user_uuid: UserId, conn: &DbConn) -> Option<Self> {
if CONFIG.org_groups_enabled() {
db_run! { conn: {
collections::table
@ -380,8 +376,8 @@ impl Collection {
)
)
).select(collections::all_columns)
.first::<CollectionDb>(conn).ok()
.from_db()
.first::<Self>(conn)
.ok()
}}
} else {
db_run! { conn: {
@ -403,13 +399,13 @@ impl Collection {
users_organizations::atype.le(MembershipType::Admin as i32) // Org admin or owner
))
).select(collections::all_columns)
.first::<CollectionDb>(conn).ok()
.from_db()
.first::<Self>(conn)
.ok()
}}
}
}
pub async fn is_writable_by_user(&self, user_uuid: &UserId, conn: &mut DbConn) -> bool {
pub async fn is_writable_by_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool {
let user_uuid = user_uuid.to_string();
if CONFIG.org_groups_enabled() {
db_run! { conn: {
@ -471,7 +467,7 @@ impl Collection {
}
}
pub async fn hide_passwords_for_user(&self, user_uuid: &UserId, conn: &mut DbConn) -> bool {
pub async fn hide_passwords_for_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool {
let user_uuid = user_uuid.to_string();
db_run! { conn: {
collections::table
@ -517,7 +513,7 @@ impl Collection {
}}
}
pub async fn is_manageable_by_user(&self, user_uuid: &UserId, conn: &mut DbConn) -> bool {
pub async fn is_manageable_by_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool {
let user_uuid = user_uuid.to_string();
db_run! { conn: {
collections::table
@ -569,7 +565,7 @@ impl CollectionUser {
pub async fn find_by_organization_and_user_uuid(
org_uuid: &OrganizationId,
user_uuid: &UserId,
conn: &mut DbConn,
conn: &DbConn,
) -> Vec<Self> {
db_run! { conn: {
users_collections::table
@ -577,15 +573,14 @@ impl CollectionUser {
.inner_join(collections::table.on(collections::uuid.eq(users_collections::collection_uuid)))
.filter(collections::org_uuid.eq(org_uuid))
.select(users_collections::all_columns)
.load::<CollectionUserDb>(conn)
.load::<Self>(conn)
.expect("Error loading users_collections")
.from_db()
}}
}
pub async fn find_by_organization_swap_user_uuid_with_member_uuid(
org_uuid: &OrganizationId,
conn: &mut DbConn,
conn: &DbConn,
) -> Vec<CollectionMembership> {
let col_users = db_run! { conn: {
users_collections::table
@ -594,9 +589,8 @@ impl CollectionUser {
.inner_join(users_organizations::table.on(users_organizations::user_uuid.eq(users_collections::user_uuid)))
.filter(users_organizations::org_uuid.eq(org_uuid))
.select((users_organizations::uuid, users_collections::collection_uuid, users_collections::read_only, users_collections::hide_passwords, users_collections::manage))
.load::<CollectionUserDb>(conn)
.load::<Self>(conn)
.expect("Error loading users_collections")
.from_db()
}};
col_users.into_iter().map(|c| c.into()).collect()
}
@ -607,7 +601,7 @@ impl CollectionUser {
read_only: bool,
hide_passwords: bool,
manage: bool,
conn: &mut DbConn,
conn: &DbConn,
) -> EmptyResult {
User::update_uuid_revision(user_uuid, conn).await;
@ -664,7 +658,7 @@ impl CollectionUser {
}
}
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult {
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.user_uuid, conn).await;
db_run! { conn: {
@ -678,21 +672,20 @@ impl CollectionUser {
}}
}
pub async fn find_by_collection(collection_uuid: &CollectionId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
users_collections::table
.filter(users_collections::collection_uuid.eq(collection_uuid))
.select(users_collections::all_columns)
.load::<CollectionUserDb>(conn)
.load::<Self>(conn)
.expect("Error loading users_collections")
.from_db()
}}
}
pub async fn find_by_org_and_coll_swap_user_uuid_with_member_uuid(
org_uuid: &OrganizationId,
collection_uuid: &CollectionId,
conn: &mut DbConn,
conn: &DbConn,
) -> Vec<CollectionMembership> {
let col_users = db_run! { conn: {
users_collections::table
@ -700,9 +693,8 @@ impl CollectionUser {
.filter(users_organizations::org_uuid.eq(org_uuid))
.inner_join(users_organizations::table.on(users_organizations::user_uuid.eq(users_collections::user_uuid)))
.select((users_organizations::uuid, users_collections::collection_uuid, users_collections::read_only, users_collections::hide_passwords, users_collections::manage))
.load::<CollectionUserDb>(conn)
.load::<Self>(conn)
.expect("Error loading users_collections")
.from_db()
}};
col_users.into_iter().map(|c| c.into()).collect()
}
@ -710,31 +702,29 @@ impl CollectionUser {
pub async fn find_by_collection_and_user(
collection_uuid: &CollectionId,
user_uuid: &UserId,
conn: &mut DbConn,
conn: &DbConn,
) -> Option<Self> {
db_run! { conn: {
users_collections::table
.filter(users_collections::collection_uuid.eq(collection_uuid))
.filter(users_collections::user_uuid.eq(user_uuid))
.select(users_collections::all_columns)
.first::<CollectionUserDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
pub async fn find_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
users_collections::table
.filter(users_collections::user_uuid.eq(user_uuid))
.select(users_collections::all_columns)
.load::<CollectionUserDb>(conn)
.load::<Self>(conn)
.expect("Error loading users_collections")
.from_db()
}}
}
pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> EmptyResult {
for collection in CollectionUser::find_by_collection(collection_uuid, conn).await.iter() {
User::update_uuid_revision(&collection.user_uuid, conn).await;
}
@ -749,7 +739,7 @@ impl CollectionUser {
pub async fn delete_all_by_user_and_org(
user_uuid: &UserId,
org_uuid: &OrganizationId,
conn: &mut DbConn,
conn: &DbConn,
) -> EmptyResult {
let collectionusers = Self::find_by_organization_and_user_uuid(org_uuid, user_uuid, conn).await;
@ -766,18 +756,14 @@ impl CollectionUser {
}}
}
pub async fn has_access_to_collection_by_user(
col_id: &CollectionId,
user_uuid: &UserId,
conn: &mut DbConn,
) -> bool {
pub async fn has_access_to_collection_by_user(col_id: &CollectionId, user_uuid: &UserId, conn: &DbConn) -> bool {
Self::find_by_collection_and_user(col_id, user_uuid, conn).await.is_some()
}
}
/// Database methods
impl CollectionCipher {
pub async fn save(cipher_uuid: &CipherId, collection_uuid: &CollectionId, conn: &mut DbConn) -> EmptyResult {
pub async fn save(cipher_uuid: &CipherId, collection_uuid: &CollectionId, conn: &DbConn) -> EmptyResult {
Self::update_users_revision(collection_uuid, conn).await;
db_run! { conn:
@ -807,7 +793,7 @@ impl CollectionCipher {
}
}
pub async fn delete(cipher_uuid: &CipherId, collection_uuid: &CollectionId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete(cipher_uuid: &CipherId, collection_uuid: &CollectionId, conn: &DbConn) -> EmptyResult {
Self::update_users_revision(collection_uuid, conn).await;
db_run! { conn: {
@ -821,7 +807,7 @@ impl CollectionCipher {
}}
}
pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(ciphers_collections::table.filter(ciphers_collections::cipher_uuid.eq(cipher_uuid)))
.execute(conn)
@ -829,7 +815,7 @@ impl CollectionCipher {
}}
}
pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(ciphers_collections::table.filter(ciphers_collections::collection_uuid.eq(collection_uuid)))
.execute(conn)
@ -837,7 +823,7 @@ impl CollectionCipher {
}}
}
pub async fn update_users_revision(collection_uuid: &CollectionId, conn: &mut DbConn) {
pub async fn update_users_revision(collection_uuid: &CollectionId, conn: &DbConn) {
if let Some(collection) = Collection::find_by_uuid(collection_uuid, conn).await {
collection.update_users_revision(conn).await;
}

70
src/db/models/device.rs

@ -5,13 +5,14 @@ use derive_more::{Display, From};
use serde_json::Value;
use super::{AuthRequest, UserId};
use crate::db::schema::devices;
use crate::{
crypto,
util::{format_date, get_uuid},
};
use diesel::prelude::*;
use macros::{IdFromParam, UuidFromParam};
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = devices)]
#[diesel(treat_none_as_null = true)]
@ -31,7 +32,6 @@ db_object! {
pub refresh_token: String,
pub twofactor_remember: Option<String>,
}
}
/// Local methods
impl Device {
@ -115,13 +115,7 @@ use crate::error::MapResult;
/// Database methods
impl Device {
pub async fn new(
uuid: DeviceId,
user_uuid: UserId,
name: String,
atype: i32,
conn: &mut DbConn,
) -> ApiResult<Device> {
pub async fn new(uuid: DeviceId, user_uuid: UserId, name: String, atype: i32, conn: &DbConn) -> ApiResult<Device> {
let now = Utc::now().naive_utc();
let device = Self {
@ -142,18 +136,24 @@ impl Device {
device.inner_save(conn).await.map(|()| device)
}
async fn inner_save(&self, conn: &mut DbConn) -> EmptyResult {
async fn inner_save(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn:
sqlite, mysql {
crate::util::retry(
|| diesel::replace_into(devices::table).values(DeviceDb::to_db(self)).execute(conn),
crate::util::retry(||
diesel::replace_into(devices::table)
.values(self)
.execute(conn),
10,
).map_res("Error saving device")
}
postgresql {
let value = DeviceDb::to_db(self);
crate::util::retry(
|| diesel::insert_into(devices::table).values(&value).on_conflict((devices::uuid, devices::user_uuid)).do_update().set(&value).execute(conn),
crate::util::retry(||
diesel::insert_into(devices::table)
.values(self)
.on_conflict((devices::uuid, devices::user_uuid))
.do_update()
.set(self)
.execute(conn),
10,
).map_res("Error saving device")
}
@ -161,12 +161,12 @@ impl Device {
}
// Should only be called after user has passed authentication
pub async fn save(&mut self, conn: &mut DbConn) -> EmptyResult {
pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
self.updated_at = Utc::now().naive_utc();
self.inner_save(conn).await
}
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(devices::table.filter(devices::user_uuid.eq(user_uuid)))
.execute(conn)
@ -174,18 +174,17 @@ impl Device {
}}
}
pub async fn find_by_uuid_and_user(uuid: &DeviceId, user_uuid: &UserId, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_uuid_and_user(uuid: &DeviceId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
devices::table
.filter(devices::uuid.eq(uuid))
.filter(devices::user_uuid.eq(user_uuid))
.first::<DeviceDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
pub async fn find_with_auth_request_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<DeviceWithAuthRequest> {
pub async fn find_with_auth_request_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<DeviceWithAuthRequest> {
let devices = Self::find_by_user(user_uuid, conn).await;
let mut result = Vec::new();
for device in devices {
@ -195,27 +194,25 @@ impl Device {
result
}
pub async fn find_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
devices::table
.filter(devices::user_uuid.eq(user_uuid))
.load::<DeviceDb>(conn)
.load::<Self>(conn)
.expect("Error loading devices")
.from_db()
}}
}
pub async fn find_by_uuid(uuid: &DeviceId, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_uuid(uuid: &DeviceId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
devices::table
.filter(devices::uuid.eq(uuid))
.first::<DeviceDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
pub async fn clear_push_token_by_uuid(uuid: &DeviceId, conn: &mut DbConn) -> EmptyResult {
pub async fn clear_push_token_by_uuid(uuid: &DeviceId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::update(devices::table)
.filter(devices::uuid.eq(uuid))
@ -224,39 +221,36 @@ impl Device {
.map_res("Error removing push token")
}}
}
pub async fn find_by_refresh_token(refresh_token: &str, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_refresh_token(refresh_token: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
devices::table
.filter(devices::refresh_token.eq(refresh_token))
.first::<DeviceDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
pub async fn find_latest_active_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Option<Self> {
pub async fn find_latest_active_by_user(user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
devices::table
.filter(devices::user_uuid.eq(user_uuid))
.order(devices::updated_at.desc())
.first::<DeviceDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
pub async fn find_push_devices_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_push_devices_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
devices::table
.filter(devices::user_uuid.eq(user_uuid))
.filter(devices::push_token.is_not_null())
.load::<DeviceDb>(conn)
.load::<Self>(conn)
.expect("Error loading push devices")
.from_db()
}}
}
pub async fn check_user_has_push_device(user_uuid: &UserId, conn: &mut DbConn) -> bool {
pub async fn check_user_has_push_device(user_uuid: &UserId, conn: &DbConn) -> bool {
db_run! { conn: {
devices::table
.filter(devices::user_uuid.eq(user_uuid))

91
src/db/models/emergency_access.rs

@ -3,10 +3,11 @@ use derive_more::{AsRef, Deref, Display, From};
use serde_json::Value;
use super::{User, UserId};
use crate::db::schema::emergency_access;
use crate::{api::EmptyResult, db::DbConn, error::MapResult};
use diesel::prelude::*;
use macros::UuidFromParam;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = emergency_access)]
#[diesel(treat_none_as_null = true)]
@ -25,10 +26,8 @@ db_object! {
pub updated_at: NaiveDateTime,
pub created_at: NaiveDateTime,
}
}
// Local methods
impl EmergencyAccess {
pub fn new(grantor_uuid: UserId, email: String, status: i32, atype: i32, wait_time_days: i32) -> Self {
let now = Utc::now().naive_utc();
@ -67,7 +66,7 @@ impl EmergencyAccess {
})
}
pub async fn to_json_grantor_details(&self, conn: &mut DbConn) -> Value {
pub async fn to_json_grantor_details(&self, conn: &DbConn) -> Value {
let grantor_user = User::find_by_uuid(&self.grantor_uuid, conn).await.expect("Grantor user not found.");
json!({
@ -83,7 +82,7 @@ impl EmergencyAccess {
})
}
pub async fn to_json_grantee_details(&self, conn: &mut DbConn) -> Option<Value> {
pub async fn to_json_grantee_details(&self, conn: &DbConn) -> Option<Value> {
let grantee_user = if let Some(grantee_uuid) = &self.grantee_uuid {
User::find_by_uuid(grantee_uuid, conn).await.expect("Grantee user not found.")
} else if let Some(email) = self.email.as_deref() {
@ -140,14 +139,14 @@ pub enum EmergencyAccessStatus {
// region Database methods
impl EmergencyAccess {
pub async fn save(&mut self, conn: &mut DbConn) -> EmptyResult {
pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.grantor_uuid, conn).await;
self.updated_at = Utc::now().naive_utc();
db_run! { conn:
sqlite, mysql {
match diesel::replace_into(emergency_access::table)
.values(EmergencyAccessDb::to_db(self))
.values(&*self)
.execute(conn)
{
Ok(_) => Ok(()),
@ -155,7 +154,7 @@ impl EmergencyAccess {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(emergency_access::table)
.filter(emergency_access::uuid.eq(&self.uuid))
.set(EmergencyAccessDb::to_db(self))
.set(&*self)
.execute(conn)
.map_res("Error updating emergency access")
}
@ -163,12 +162,11 @@ impl EmergencyAccess {
}.map_res("Error saving emergency access")
}
postgresql {
let value = EmergencyAccessDb::to_db(self);
diesel::insert_into(emergency_access::table)
.values(&value)
.values(&*self)
.on_conflict(emergency_access::uuid)
.do_update()
.set(&value)
.set(&*self)
.execute(conn)
.map_res("Error saving emergency access")
}
@ -179,7 +177,7 @@ impl EmergencyAccess {
&mut self,
status: i32,
date: &NaiveDateTime,
conn: &mut DbConn,
conn: &DbConn,
) -> EmptyResult {
// Update the grantee so that it will refresh it's status.
User::update_uuid_revision(self.grantee_uuid.as_ref().expect("Error getting grantee"), conn).await;
@ -196,11 +194,7 @@ impl EmergencyAccess {
}}
}
pub async fn update_last_notification_date_and_save(
&mut self,
date: &NaiveDateTime,
conn: &mut DbConn,
) -> EmptyResult {
pub async fn update_last_notification_date_and_save(&mut self, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult {
self.last_notification_at = Some(date.to_owned());
date.clone_into(&mut self.updated_at);
@ -214,7 +208,7 @@ impl EmergencyAccess {
}}
}
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
for ea in Self::find_all_by_grantor_uuid(user_uuid, conn).await {
ea.delete(conn).await?;
}
@ -224,14 +218,14 @@ impl EmergencyAccess {
Ok(())
}
pub async fn delete_all_by_grantee_email(grantee_email: &str, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_grantee_email(grantee_email: &str, conn: &DbConn) -> EmptyResult {
for ea in Self::find_all_invited_by_grantee_email(grantee_email, conn).await {
ea.delete(conn).await?;
}
Ok(())
}
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult {
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.grantor_uuid, conn).await;
db_run! { conn: {
@ -245,109 +239,108 @@ impl EmergencyAccess {
grantor_uuid: &UserId,
grantee_uuid: &UserId,
email: &str,
conn: &mut DbConn,
conn: &DbConn,
) -> Option<Self> {
db_run! { conn: {
emergency_access::table
.filter(emergency_access::grantor_uuid.eq(grantor_uuid))
.filter(emergency_access::grantee_uuid.eq(grantee_uuid).or(emergency_access::email.eq(email)))
.first::<EmergencyAccessDb>(conn)
.ok().from_db()
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_all_recoveries_initiated(conn: &mut DbConn) -> Vec<Self> {
pub async fn find_all_recoveries_initiated(conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
emergency_access::table
.filter(emergency_access::status.eq(EmergencyAccessStatus::RecoveryInitiated as i32))
.filter(emergency_access::recovery_initiated_at.is_not_null())
.load::<EmergencyAccessDb>(conn).expect("Error loading emergency_access").from_db()
.load::<Self>(conn)
.expect("Error loading emergency_access")
}}
}
pub async fn find_by_uuid_and_grantor_uuid(
uuid: &EmergencyAccessId,
grantor_uuid: &UserId,
conn: &mut DbConn,
conn: &DbConn,
) -> Option<Self> {
db_run! { conn: {
emergency_access::table
.filter(emergency_access::uuid.eq(uuid))
.filter(emergency_access::grantor_uuid.eq(grantor_uuid))
.first::<EmergencyAccessDb>(conn)
.ok().from_db()
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_by_uuid_and_grantee_uuid(
uuid: &EmergencyAccessId,
grantee_uuid: &UserId,
conn: &mut DbConn,
conn: &DbConn,
) -> Option<Self> {
db_run! { conn: {
emergency_access::table
.filter(emergency_access::uuid.eq(uuid))
.filter(emergency_access::grantee_uuid.eq(grantee_uuid))
.first::<EmergencyAccessDb>(conn)
.ok().from_db()
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_by_uuid_and_grantee_email(
uuid: &EmergencyAccessId,
grantee_email: &str,
conn: &mut DbConn,
conn: &DbConn,
) -> Option<Self> {
db_run! { conn: {
emergency_access::table
.filter(emergency_access::uuid.eq(uuid))
.filter(emergency_access::email.eq(grantee_email))
.first::<EmergencyAccessDb>(conn)
.ok().from_db()
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_all_by_grantee_uuid(grantee_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_all_by_grantee_uuid(grantee_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
emergency_access::table
.filter(emergency_access::grantee_uuid.eq(grantee_uuid))
.load::<EmergencyAccessDb>(conn).expect("Error loading emergency_access").from_db()
.load::<Self>(conn)
.expect("Error loading emergency_access")
}}
}
pub async fn find_invited_by_grantee_email(grantee_email: &str, conn: &mut DbConn) -> Option<Self> {
pub async fn find_invited_by_grantee_email(grantee_email: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
emergency_access::table
.filter(emergency_access::email.eq(grantee_email))
.filter(emergency_access::status.eq(EmergencyAccessStatus::Invited as i32))
.first::<EmergencyAccessDb>(conn)
.ok().from_db()
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_all_invited_by_grantee_email(grantee_email: &str, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_all_invited_by_grantee_email(grantee_email: &str, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
emergency_access::table
.filter(emergency_access::email.eq(grantee_email))
.filter(emergency_access::status.eq(EmergencyAccessStatus::Invited as i32))
.load::<EmergencyAccessDb>(conn).expect("Error loading emergency_access").from_db()
.load::<Self>(conn)
.expect("Error loading emergency_access")
}}
}
pub async fn find_all_by_grantor_uuid(grantor_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_all_by_grantor_uuid(grantor_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
emergency_access::table
.filter(emergency_access::grantor_uuid.eq(grantor_uuid))
.load::<EmergencyAccessDb>(conn).expect("Error loading emergency_access").from_db()
.load::<Self>(conn)
.expect("Error loading emergency_access")
}}
}
pub async fn accept_invite(
&mut self,
grantee_uuid: &UserId,
grantee_email: &str,
conn: &mut DbConn,
) -> EmptyResult {
pub async fn accept_invite(&mut self, grantee_uuid: &UserId, grantee_email: &str, conn: &DbConn) -> EmptyResult {
if self.email.is_none() || self.email.as_ref().unwrap() != grantee_email {
err!("User email does not match invite.");
}

39
src/db/models/event.rs

@ -3,11 +3,12 @@ use chrono::{NaiveDateTime, TimeDelta, Utc};
use serde_json::Value;
use super::{CipherId, CollectionId, GroupId, MembershipId, OrgPolicyId, OrganizationId, UserId};
use crate::db::schema::{event, users_organizations};
use crate::{api::EmptyResult, db::DbConn, error::MapResult, CONFIG};
use diesel::prelude::*;
// https://bitwarden.com/help/event-logs/
db_object! {
// Upstream: https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Core/AdminConsole/Services/Implementations/EventService.cs
// Upstream: https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Api/AdminConsole/Public/Models/Response/EventResponseModel.cs
// Upstream SQL: https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Sql/dbo/Tables/Event.sql
@ -34,7 +35,6 @@ db_object! {
pub provider_user_uuid: Option<String>,
pub provider_org_uuid: Option<String>,
}
}
// Upstream enum: https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Core/AdminConsole/Enums/EventType.cs
#[derive(Debug, Copy, Clone)]
@ -193,27 +193,27 @@ impl Event {
/// #############
/// Basic Queries
pub async fn save(&self, conn: &mut DbConn) -> EmptyResult {
pub async fn save(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn:
sqlite, mysql {
diesel::replace_into(event::table)
.values(EventDb::to_db(self))
.values(self)
.execute(conn)
.map_res("Error saving event")
}
postgresql {
diesel::insert_into(event::table)
.values(EventDb::to_db(self))
.values(self)
.on_conflict(event::uuid)
.do_update()
.set(EventDb::to_db(self))
.set(self)
.execute(conn)
.map_res("Error saving event")
}
}
}
pub async fn save_user_event(events: Vec<Event>, conn: &mut DbConn) -> EmptyResult {
pub async fn save_user_event(events: Vec<Event>, conn: &DbConn) -> EmptyResult {
// Special save function which is able to handle multiple events.
// SQLite doesn't support the DEFAULT argument, and does not support inserting multiple values at the same time.
// MySQL and PostgreSQL do.
@ -224,14 +224,13 @@ impl Event {
sqlite {
for event in events {
diesel::insert_or_ignore_into(event::table)
.values(EventDb::to_db(&event))
.values(&event)
.execute(conn)
.unwrap_or_default();
}
Ok(())
}
mysql {
let events: Vec<EventDb> = events.iter().map(EventDb::to_db).collect();
diesel::insert_or_ignore_into(event::table)
.values(&events)
.execute(conn)
@ -239,7 +238,6 @@ impl Event {
Ok(())
}
postgresql {
let events: Vec<EventDb> = events.iter().map(EventDb::to_db).collect();
diesel::insert_into(event::table)
.values(&events)
.on_conflict_do_nothing()
@ -250,7 +248,7 @@ impl Event {
}
}
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult {
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(event::table.filter(event::uuid.eq(self.uuid)))
.execute(conn)
@ -264,7 +262,7 @@ impl Event {
org_uuid: &OrganizationId,
start: &NaiveDateTime,
end: &NaiveDateTime,
conn: &mut DbConn,
conn: &DbConn,
) -> Vec<Self> {
db_run! { conn: {
event::table
@ -272,13 +270,12 @@ impl Event {
.filter(event::event_date.between(start, end))
.order_by(event::event_date.desc())
.limit(Self::PAGE_SIZE)
.load::<EventDb>(conn)
.load::<Self>(conn)
.expect("Error filtering events")
.from_db()
}}
}
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> i64 {
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: {
event::table
.filter(event::org_uuid.eq(org_uuid))
@ -294,7 +291,7 @@ impl Event {
member_uuid: &MembershipId,
start: &NaiveDateTime,
end: &NaiveDateTime,
conn: &mut DbConn,
conn: &DbConn,
) -> Vec<Self> {
db_run! { conn: {
event::table
@ -305,9 +302,8 @@ impl Event {
.select(event::all_columns)
.order_by(event::event_date.desc())
.limit(Self::PAGE_SIZE)
.load::<EventDb>(conn)
.load::<Self>(conn)
.expect("Error filtering events")
.from_db()
}}
}
@ -315,7 +311,7 @@ impl Event {
cipher_uuid: &CipherId,
start: &NaiveDateTime,
end: &NaiveDateTime,
conn: &mut DbConn,
conn: &DbConn,
) -> Vec<Self> {
db_run! { conn: {
event::table
@ -323,13 +319,12 @@ impl Event {
.filter(event::event_date.between(start, end))
.order_by(event::event_date.desc())
.limit(Self::PAGE_SIZE)
.load::<EventDb>(conn)
.load::<Self>(conn)
.expect("Error filtering events")
.from_db()
}}
}
pub async fn clean_events(conn: &mut DbConn) -> EmptyResult {
pub async fn clean_events(conn: &DbConn) -> EmptyResult {
if let Some(days_to_retain) = CONFIG.events_days_retain() {
let dt = Utc::now().naive_utc() - TimeDelta::try_days(days_to_retain).unwrap();
db_run! { conn: {

18
src/db/models/favorite.rs

@ -1,6 +1,7 @@
use super::{CipherId, User, UserId};
use crate::db::schema::favorites;
use diesel::prelude::*;
db_object! {
#[derive(Identifiable, Queryable, Insertable)]
#[diesel(table_name = favorites)]
#[diesel(primary_key(user_uuid, cipher_uuid))]
@ -8,7 +9,6 @@ db_object! {
pub user_uuid: UserId,
pub cipher_uuid: CipherId,
}
}
use crate::db::DbConn;
@ -17,14 +17,16 @@ use crate::error::MapResult;
impl Favorite {
// Returns whether the specified cipher is a favorite of the specified user.
pub async fn is_favorite(cipher_uuid: &CipherId, user_uuid: &UserId, conn: &mut DbConn) -> bool {
pub async fn is_favorite(cipher_uuid: &CipherId, user_uuid: &UserId, conn: &DbConn) -> bool {
db_run! { conn: {
let query = favorites::table
.filter(favorites::cipher_uuid.eq(cipher_uuid))
.filter(favorites::user_uuid.eq(user_uuid))
.count();
query.first::<i64>(conn).ok().unwrap_or(0) != 0
query.first::<i64>(conn)
.ok()
.unwrap_or(0) != 0
}}
}
@ -33,7 +35,7 @@ impl Favorite {
favorite: bool,
cipher_uuid: &CipherId,
user_uuid: &UserId,
conn: &mut DbConn,
conn: &DbConn,
) -> EmptyResult {
let (old, new) = (Self::is_favorite(cipher_uuid, user_uuid, conn).await, favorite);
match (old, new) {
@ -67,7 +69,7 @@ impl Favorite {
}
// Delete all favorite entries associated with the specified cipher.
pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(favorites::table.filter(favorites::cipher_uuid.eq(cipher_uuid)))
.execute(conn)
@ -76,7 +78,7 @@ impl Favorite {
}
// Delete all favorite entries associated with the specified user.
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(favorites::table.filter(favorites::user_uuid.eq(user_uuid)))
.execute(conn)
@ -86,7 +88,7 @@ impl Favorite {
/// Return a vec with (cipher_uuid) this will only contain favorite flagged ciphers
/// This is used during a full sync so we only need one query for all favorite cipher matches.
pub async fn get_all_cipher_uuid_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<CipherId> {
pub async fn get_all_cipher_uuid_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<CipherId> {
db_run! { conn: {
favorites::table
.filter(favorites::user_uuid.eq(user_uuid))

53
src/db/models/folder.rs

@ -3,9 +3,10 @@ use derive_more::{AsRef, Deref, Display, From};
use serde_json::Value;
use super::{CipherId, User, UserId};
use crate::db::schema::{folders, folders_ciphers};
use diesel::prelude::*;
use macros::UuidFromParam;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = folders)]
#[diesel(primary_key(uuid))]
@ -24,7 +25,6 @@ db_object! {
pub cipher_uuid: CipherId,
pub folder_uuid: FolderId,
}
}
/// Local methods
impl Folder {
@ -69,14 +69,14 @@ use crate::error::MapResult;
/// Database methods
impl Folder {
pub async fn save(&mut self, conn: &mut DbConn) -> EmptyResult {
pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.user_uuid, conn).await;
self.updated_at = Utc::now().naive_utc();
db_run! { conn:
sqlite, mysql {
match diesel::replace_into(folders::table)
.values(FolderDb::to_db(self))
.values(&*self)
.execute(conn)
{
Ok(_) => Ok(()),
@ -84,7 +84,7 @@ impl Folder {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(folders::table)
.filter(folders::uuid.eq(&self.uuid))
.set(FolderDb::to_db(self))
.set(&*self)
.execute(conn)
.map_res("Error saving folder")
}
@ -92,19 +92,18 @@ impl Folder {
}.map_res("Error saving folder")
}
postgresql {
let value = FolderDb::to_db(self);
diesel::insert_into(folders::table)
.values(&value)
.values(&*self)
.on_conflict(folders::uuid)
.do_update()
.set(&value)
.set(&*self)
.execute(conn)
.map_res("Error saving folder")
}
}
}
pub async fn delete(&self, conn: &mut DbConn) -> EmptyResult {
pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.user_uuid, conn).await;
FolderCipher::delete_all_by_folder(&self.uuid, conn).await?;
@ -115,50 +114,48 @@ impl Folder {
}}
}
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
for folder in Self::find_by_user(user_uuid, conn).await {
folder.delete(conn).await?;
}
Ok(())
}
pub async fn find_by_uuid_and_user(uuid: &FolderId, user_uuid: &UserId, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_uuid_and_user(uuid: &FolderId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
folders::table
.filter(folders::uuid.eq(uuid))
.filter(folders::user_uuid.eq(user_uuid))
.first::<FolderDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
pub async fn find_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
folders::table
.filter(folders::user_uuid.eq(user_uuid))
.load::<FolderDb>(conn)
.load::<Self>(conn)
.expect("Error loading folders")
.from_db()
}}
}
}
impl FolderCipher {
pub async fn save(&self, conn: &mut DbConn) -> EmptyResult {
pub async fn save(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn:
sqlite, mysql {
// Not checking for ForeignKey Constraints here.
// Table folders_ciphers does not have ForeignKey Constraints which would cause conflicts.
// This table has no constraints pointing to itself, but only to others.
diesel::replace_into(folders_ciphers::table)
.values(FolderCipherDb::to_db(self))
.values(self)
.execute(conn)
.map_res("Error adding cipher to folder")
}
postgresql {
diesel::insert_into(folders_ciphers::table)
.values(FolderCipherDb::to_db(self))
.values(self)
.on_conflict((folders_ciphers::cipher_uuid, folders_ciphers::folder_uuid))
.do_nothing()
.execute(conn)
@ -167,7 +164,7 @@ impl FolderCipher {
}
}
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult {
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(
folders_ciphers::table
@ -179,7 +176,7 @@ impl FolderCipher {
}}
}
pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(folders_ciphers::table.filter(folders_ciphers::cipher_uuid.eq(cipher_uuid)))
.execute(conn)
@ -187,7 +184,7 @@ impl FolderCipher {
}}
}
pub async fn delete_all_by_folder(folder_uuid: &FolderId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_folder(folder_uuid: &FolderId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(folders_ciphers::table.filter(folders_ciphers::folder_uuid.eq(folder_uuid)))
.execute(conn)
@ -198,31 +195,29 @@ impl FolderCipher {
pub async fn find_by_folder_and_cipher(
folder_uuid: &FolderId,
cipher_uuid: &CipherId,
conn: &mut DbConn,
conn: &DbConn,
) -> Option<Self> {
db_run! { conn: {
folders_ciphers::table
.filter(folders_ciphers::folder_uuid.eq(folder_uuid))
.filter(folders_ciphers::cipher_uuid.eq(cipher_uuid))
.first::<FolderCipherDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
pub async fn find_by_folder(folder_uuid: &FolderId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_folder(folder_uuid: &FolderId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
folders_ciphers::table
.filter(folders_ciphers::folder_uuid.eq(folder_uuid))
.load::<FolderCipherDb>(conn)
.load::<Self>(conn)
.expect("Error loading folders")
.from_db()
}}
}
/// Return a vec with (cipher_uuid, folder_uuid)
/// This is used during a full sync so we only need one query for all folder matches.
pub async fn find_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<(CipherId, FolderId)> {
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<(CipherId, FolderId)> {
db_run! { conn: {
folders_ciphers::table
.inner_join(folders::table)

95
src/db/models/group.rs

@ -1,13 +1,14 @@
use super::{CollectionId, Membership, MembershipId, OrganizationId, User, UserId};
use crate::api::EmptyResult;
use crate::db::schema::{collections_groups, groups, groups_users, users_organizations};
use crate::db::DbConn;
use crate::error::MapResult;
use chrono::{NaiveDateTime, Utc};
use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use macros::UuidFromParam;
use serde_json::Value;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = groups)]
#[diesel(treat_none_as_null = true)]
@ -38,8 +39,7 @@ db_object! {
#[diesel(primary_key(groups_uuid, users_organizations_uuid))]
pub struct GroupUser {
pub groups_uuid: GroupId,
pub users_organizations_uuid: MembershipId
}
pub users_organizations_uuid: MembershipId,
}
/// Local methods
@ -77,7 +77,7 @@ impl Group {
})
}
pub async fn to_json_details(&self, conn: &mut DbConn) -> Value {
pub async fn to_json_details(&self, conn: &DbConn) -> Value {
// If both read_only and hide_passwords are false, then manage should be true
// You can't have an entry with read_only and manage, or hide_passwords and manage
// Or an entry with everything to false
@ -156,13 +156,13 @@ impl GroupUser {
/// Database methods
impl Group {
pub async fn save(&mut self, conn: &mut DbConn) -> EmptyResult {
pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
self.revision_date = Utc::now().naive_utc();
db_run! { conn:
sqlite, mysql {
match diesel::replace_into(groups::table)
.values(GroupDb::to_db(self))
.values(&*self)
.execute(conn)
{
Ok(_) => Ok(()),
@ -170,7 +170,7 @@ impl Group {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(groups::table)
.filter(groups::uuid.eq(&self.uuid))
.set(GroupDb::to_db(self))
.set(&*self)
.execute(conn)
.map_res("Error saving group")
}
@ -178,36 +178,34 @@ impl Group {
}.map_res("Error saving group")
}
postgresql {
let value = GroupDb::to_db(self);
diesel::insert_into(groups::table)
.values(&value)
.values(&*self)
.on_conflict(groups::uuid)
.do_update()
.set(&value)
.set(&*self)
.execute(conn)
.map_res("Error saving group")
}
}
}
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
for group in Self::find_by_organization(org_uuid, conn).await {
group.delete(conn).await?;
}
Ok(())
}
pub async fn find_by_organization(org_uuid: &OrganizationId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
groups::table
.filter(groups::organizations_uuid.eq(org_uuid))
.load::<GroupDb>(conn)
.load::<Self>(conn)
.expect("Error loading groups")
.from_db()
}}
}
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> i64 {
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: {
groups::table
.filter(groups::organizations_uuid.eq(org_uuid))
@ -218,33 +216,31 @@ impl Group {
}}
}
pub async fn find_by_uuid_and_org(uuid: &GroupId, org_uuid: &OrganizationId, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_uuid_and_org(uuid: &GroupId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
groups::table
.filter(groups::uuid.eq(uuid))
.filter(groups::organizations_uuid.eq(org_uuid))
.first::<GroupDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
pub async fn find_by_external_id_and_org(
external_id: &str,
org_uuid: &OrganizationId,
conn: &mut DbConn,
conn: &DbConn,
) -> Option<Self> {
db_run! { conn: {
groups::table
.filter(groups::external_id.eq(external_id))
.filter(groups::organizations_uuid.eq(org_uuid))
.first::<GroupDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
//Returns all organizations the user has full access to
pub async fn get_orgs_by_user_with_full_access(user_uuid: &UserId, conn: &mut DbConn) -> Vec<OrganizationId> {
pub async fn get_orgs_by_user_with_full_access(user_uuid: &UserId, conn: &DbConn) -> Vec<OrganizationId> {
db_run! { conn: {
groups_users::table
.inner_join(users_organizations::table.on(
@ -262,7 +258,7 @@ impl Group {
}}
}
pub async fn is_in_full_access_group(user_uuid: &UserId, org_uuid: &OrganizationId, conn: &mut DbConn) -> bool {
pub async fn is_in_full_access_group(user_uuid: &UserId, org_uuid: &OrganizationId, conn: &DbConn) -> bool {
db_run! { conn: {
groups::table
.inner_join(groups_users::table.on(
@ -280,7 +276,7 @@ impl Group {
}}
}
pub async fn delete(&self, conn: &mut DbConn) -> EmptyResult {
pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
CollectionGroup::delete_all_by_group(&self.uuid, conn).await?;
GroupUser::delete_all_by_group(&self.uuid, conn).await?;
@ -291,13 +287,13 @@ impl Group {
}}
}
pub async fn update_revision(uuid: &GroupId, conn: &mut DbConn) {
pub async fn update_revision(uuid: &GroupId, conn: &DbConn) {
if let Err(e) = Self::_update_revision(uuid, &Utc::now().naive_utc(), conn).await {
warn!("Failed to update revision for {uuid}: {e:#?}");
}
}
async fn _update_revision(uuid: &GroupId, date: &NaiveDateTime, conn: &mut DbConn) -> EmptyResult {
async fn _update_revision(uuid: &GroupId, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
crate::util::retry(|| {
diesel::update(groups::table.filter(groups::uuid.eq(uuid)))
@ -310,7 +306,7 @@ impl Group {
}
impl CollectionGroup {
pub async fn save(&mut self, conn: &mut DbConn) -> EmptyResult {
pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
let group_users = GroupUser::find_by_group(&self.groups_uuid, conn).await;
for group_user in group_users {
group_user.update_user_revision(conn).await;
@ -369,17 +365,16 @@ impl CollectionGroup {
}
}
pub async fn find_by_group(group_uuid: &GroupId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_group(group_uuid: &GroupId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
collections_groups::table
.filter(collections_groups::groups_uuid.eq(group_uuid))
.load::<CollectionGroupDb>(conn)
.load::<Self>(conn)
.expect("Error loading collection groups")
.from_db()
}}
}
pub async fn find_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
collections_groups::table
.inner_join(groups_users::table.on(
@ -390,24 +385,22 @@ impl CollectionGroup {
))
.filter(users_organizations::user_uuid.eq(user_uuid))
.select(collections_groups::all_columns)
.load::<CollectionGroupDb>(conn)
.load::<Self>(conn)
.expect("Error loading user collection groups")
.from_db()
}}
}
pub async fn find_by_collection(collection_uuid: &CollectionId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
collections_groups::table
.filter(collections_groups::collections_uuid.eq(collection_uuid))
.select(collections_groups::all_columns)
.load::<CollectionGroupDb>(conn)
.load::<Self>(conn)
.expect("Error loading collection groups")
.from_db()
}}
}
pub async fn delete(&self, conn: &mut DbConn) -> EmptyResult {
pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
let group_users = GroupUser::find_by_group(&self.groups_uuid, conn).await;
for group_user in group_users {
group_user.update_user_revision(conn).await;
@ -422,7 +415,7 @@ impl CollectionGroup {
}}
}
pub async fn delete_all_by_group(group_uuid: &GroupId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_group(group_uuid: &GroupId, conn: &DbConn) -> EmptyResult {
let group_users = GroupUser::find_by_group(group_uuid, conn).await;
for group_user in group_users {
group_user.update_user_revision(conn).await;
@ -436,7 +429,7 @@ impl CollectionGroup {
}}
}
pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> EmptyResult {
let collection_assigned_to_groups = CollectionGroup::find_by_collection(collection_uuid, conn).await;
for collection_assigned_to_group in collection_assigned_to_groups {
let group_users = GroupUser::find_by_group(&collection_assigned_to_group.groups_uuid, conn).await;
@ -455,7 +448,7 @@ impl CollectionGroup {
}
impl GroupUser {
pub async fn save(&mut self, conn: &mut DbConn) -> EmptyResult {
pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
self.update_user_revision(conn).await;
db_run! { conn:
@ -501,30 +494,28 @@ impl GroupUser {
}
}
pub async fn find_by_group(group_uuid: &GroupId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_group(group_uuid: &GroupId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
groups_users::table
.filter(groups_users::groups_uuid.eq(group_uuid))
.load::<GroupUserDb>(conn)
.load::<Self>(conn)
.expect("Error loading group users")
.from_db()
}}
}
pub async fn find_by_member(member_uuid: &MembershipId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_member(member_uuid: &MembershipId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
groups_users::table
.filter(groups_users::users_organizations_uuid.eq(member_uuid))
.load::<GroupUserDb>(conn)
.load::<Self>(conn)
.expect("Error loading groups for user")
.from_db()
}}
}
pub async fn has_access_to_collection_by_member(
collection_uuid: &CollectionId,
member_uuid: &MembershipId,
conn: &mut DbConn,
conn: &DbConn,
) -> bool {
db_run! { conn: {
groups_users::table
@ -542,7 +533,7 @@ impl GroupUser {
pub async fn has_full_access_by_member(
org_uuid: &OrganizationId,
member_uuid: &MembershipId,
conn: &mut DbConn,
conn: &DbConn,
) -> bool {
db_run! { conn: {
groups_users::table
@ -558,7 +549,7 @@ impl GroupUser {
}}
}
pub async fn update_user_revision(&self, conn: &mut DbConn) {
pub async fn update_user_revision(&self, conn: &DbConn) {
match Membership::find_by_uuid(&self.users_organizations_uuid, conn).await {
Some(member) => User::update_uuid_revision(&member.user_uuid, conn).await,
None => warn!("Member could not be found!"),
@ -568,7 +559,7 @@ impl GroupUser {
pub async fn delete_by_group_and_member(
group_uuid: &GroupId,
member_uuid: &MembershipId,
conn: &mut DbConn,
conn: &DbConn,
) -> EmptyResult {
match Membership::find_by_uuid(member_uuid, conn).await {
Some(member) => User::update_uuid_revision(&member.user_uuid, conn).await,
@ -584,7 +575,7 @@ impl GroupUser {
}}
}
pub async fn delete_all_by_group(group_uuid: &GroupId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_group(group_uuid: &GroupId, conn: &DbConn) -> EmptyResult {
let group_users = GroupUser::find_by_group(group_uuid, conn).await;
for group_user in group_users {
group_user.update_user_revision(conn).await;
@ -598,7 +589,7 @@ impl GroupUser {
}}
}
pub async fn delete_all_by_member(member_uuid: &MembershipId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_member(member_uuid: &MembershipId, conn: &DbConn) -> EmptyResult {
match Membership::find_by_uuid(member_uuid, conn).await {
Some(member) => User::update_uuid_revision(&member.user_uuid, conn).await,
None => warn!("Member could not be found!"),

56
src/db/models/org_policy.rs

@ -3,12 +3,13 @@ use serde::Deserialize;
use serde_json::Value;
use crate::api::EmptyResult;
use crate::db::schema::{org_policies, users_organizations};
use crate::db::DbConn;
use crate::error::MapResult;
use diesel::prelude::*;
use super::{Membership, MembershipId, MembershipStatus, MembershipType, OrganizationId, TwoFactor, UserId};
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = org_policies)]
#[diesel(primary_key(uuid))]
@ -19,7 +20,6 @@ db_object! {
pub enabled: bool,
pub data: String,
}
}
// https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Core/AdminConsole/Enums/PolicyType.cs
#[derive(Copy, Clone, Eq, PartialEq, num_derive::FromPrimitive)]
@ -106,11 +106,11 @@ impl OrgPolicy {
/// Database methods
impl OrgPolicy {
pub async fn save(&self, conn: &mut DbConn) -> EmptyResult {
pub async fn save(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn:
sqlite, mysql {
match diesel::replace_into(org_policies::table)
.values(OrgPolicyDb::to_db(self))
.values(self)
.execute(conn)
{
Ok(_) => Ok(()),
@ -118,7 +118,7 @@ impl OrgPolicy {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(org_policies::table)
.filter(org_policies::uuid.eq(&self.uuid))
.set(OrgPolicyDb::to_db(self))
.set(self)
.execute(conn)
.map_res("Error saving org_policy")
}
@ -126,7 +126,6 @@ impl OrgPolicy {
}.map_res("Error saving org_policy")
}
postgresql {
let value = OrgPolicyDb::to_db(self);
// We need to make sure we're not going to violate the unique constraint on org_uuid and atype.
// This happens automatically on other DBMS backends due to replace_into(). PostgreSQL does
// not support multiple constraints on ON CONFLICT clauses.
@ -139,17 +138,17 @@ impl OrgPolicy {
.map_res("Error deleting org_policy for insert")?;
diesel::insert_into(org_policies::table)
.values(&value)
.values(self)
.on_conflict(org_policies::uuid)
.do_update()
.set(&value)
.set(self)
.execute(conn)
.map_res("Error saving org_policy")
}
}
}
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult {
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(org_policies::table.filter(org_policies::uuid.eq(self.uuid)))
.execute(conn)
@ -157,17 +156,16 @@ impl OrgPolicy {
}}
}
pub async fn find_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
org_policies::table
.filter(org_policies::org_uuid.eq(org_uuid))
.load::<OrgPolicyDb>(conn)
.load::<Self>(conn)
.expect("Error loading org_policy")
.from_db()
}}
}
pub async fn find_confirmed_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_confirmed_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
org_policies::table
.inner_join(
@ -179,28 +177,26 @@ impl OrgPolicy {
users_organizations::status.eq(MembershipStatus::Confirmed as i32)
)
.select(org_policies::all_columns)
.load::<OrgPolicyDb>(conn)
.load::<Self>(conn)
.expect("Error loading org_policy")
.from_db()
}}
}
pub async fn find_by_org_and_type(
org_uuid: &OrganizationId,
policy_type: OrgPolicyType,
conn: &mut DbConn,
conn: &DbConn,
) -> Option<Self> {
db_run! { conn: {
org_policies::table
.filter(org_policies::org_uuid.eq(org_uuid))
.filter(org_policies::atype.eq(policy_type as i32))
.first::<OrgPolicyDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(org_policies::table.filter(org_policies::org_uuid.eq(org_uuid)))
.execute(conn)
@ -229,16 +225,15 @@ impl OrgPolicy {
.filter(org_policies::atype.eq(policy_type as i32))
.filter(org_policies::enabled.eq(true))
.select(org_policies::all_columns)
.load::<OrgPolicyDb>(conn)
.load::<Self>(conn)
.expect("Error loading org_policy")
.from_db()
}}
}
pub async fn find_confirmed_by_user_and_active_policy(
user_uuid: &UserId,
policy_type: OrgPolicyType,
conn: &mut DbConn,
conn: &DbConn,
) -> Vec<Self> {
db_run! { conn: {
org_policies::table
@ -253,9 +248,8 @@ impl OrgPolicy {
.filter(org_policies::atype.eq(policy_type as i32))
.filter(org_policies::enabled.eq(true))
.select(org_policies::all_columns)
.load::<OrgPolicyDb>(conn)
.load::<Self>(conn)
.expect("Error loading org_policy")
.from_db()
}}
}
@ -266,7 +260,7 @@ impl OrgPolicy {
user_uuid: &UserId,
policy_type: OrgPolicyType,
exclude_org_uuid: Option<&OrganizationId>,
conn: &mut DbConn,
conn: &DbConn,
) -> bool {
for policy in
OrgPolicy::find_accepted_and_confirmed_by_user_and_active_policy(user_uuid, policy_type, conn).await
@ -289,7 +283,7 @@ impl OrgPolicy {
user_uuid: &UserId,
org_uuid: &OrganizationId,
exclude_current_org: bool,
conn: &mut DbConn,
conn: &DbConn,
) -> OrgPolicyResult {
// Enforce TwoFactor/TwoStep login
if TwoFactor::find_by_user(user_uuid, conn).await.is_empty() {
@ -315,7 +309,7 @@ impl OrgPolicy {
Ok(())
}
pub async fn org_is_reset_password_auto_enroll(org_uuid: &OrganizationId, conn: &mut DbConn) -> bool {
pub async fn org_is_reset_password_auto_enroll(org_uuid: &OrganizationId, conn: &DbConn) -> bool {
match OrgPolicy::find_by_org_and_type(org_uuid, OrgPolicyType::ResetPassword, conn).await {
Some(policy) => match serde_json::from_str::<ResetPasswordDataModel>(&policy.data) {
Ok(opts) => {
@ -331,7 +325,7 @@ impl OrgPolicy {
/// Returns true if the user belongs to an org that has enabled the `DisableHideEmail`
/// option of the `Send Options` policy, and the user is not an owner or admin of that org.
pub async fn is_hide_email_disabled(user_uuid: &UserId, conn: &mut DbConn) -> bool {
pub async fn is_hide_email_disabled(user_uuid: &UserId, conn: &DbConn) -> bool {
for policy in
OrgPolicy::find_confirmed_by_user_and_active_policy(user_uuid, OrgPolicyType::SendOptions, conn).await
{
@ -351,11 +345,7 @@ impl OrgPolicy {
false
}
pub async fn is_enabled_for_member(
member_uuid: &MembershipId,
policy_type: OrgPolicyType,
conn: &mut DbConn,
) -> bool {
pub async fn is_enabled_for_member(member_uuid: &MembershipId, policy_type: OrgPolicyType, conn: &DbConn) -> bool {
if let Some(member) = Membership::find_by_uuid(member_uuid, conn).await {
if let Some(policy) = OrgPolicy::find_by_org_and_type(&member.org_uuid, policy_type, conn).await {
return policy.enabled;

247
src/db/models/organization.rs

@ -1,5 +1,6 @@
use chrono::{NaiveDateTime, Utc};
use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use num_traits::FromPrimitive;
use serde_json::Value;
use std::{
@ -11,10 +12,13 @@ use super::{
CipherId, Collection, CollectionGroup, CollectionId, CollectionUser, Group, GroupId, GroupUser, OrgPolicy,
OrgPolicyType, TwoFactor, User, UserId,
};
use crate::db::schema::{
ciphers, ciphers_collections, collections_groups, groups, groups_users, org_policies, organization_api_key,
organizations, users, users_collections, users_organizations,
};
use crate::CONFIG;
use macros::UuidFromParam;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = organizations)]
#[diesel(treat_none_as_null = true)]
@ -56,7 +60,6 @@ db_object! {
pub api_key: String,
pub revision_date: NaiveDateTime,
}
}
// https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Core/AdminConsole/Enums/OrganizationUserStatusType.cs
#[derive(PartialEq)]
@ -325,7 +328,7 @@ use crate::error::MapResult;
/// Database methods
impl Organization {
pub async fn save(&self, conn: &mut DbConn) -> EmptyResult {
pub async fn save(&self, conn: &DbConn) -> EmptyResult {
if !crate::util::is_valid_email(&self.billing_email) {
err!(format!("BillingEmail {} is not a valid email address", self.billing_email))
}
@ -337,7 +340,7 @@ impl Organization {
db_run! { conn:
sqlite, mysql {
match diesel::replace_into(organizations::table)
.values(OrganizationDb::to_db(self))
.values(self)
.execute(conn)
{
Ok(_) => Ok(()),
@ -345,7 +348,7 @@ impl Organization {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(organizations::table)
.filter(organizations::uuid.eq(&self.uuid))
.set(OrganizationDb::to_db(self))
.set(self)
.execute(conn)
.map_res("Error saving organization")
}
@ -354,19 +357,18 @@ impl Organization {
}
postgresql {
let value = OrganizationDb::to_db(self);
diesel::insert_into(organizations::table)
.values(&value)
.values(self)
.on_conflict(organizations::uuid)
.do_update()
.set(&value)
.set(self)
.execute(conn)
.map_res("Error saving organization")
}
}
}
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult {
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
use super::{Cipher, Collection};
Cipher::delete_all_by_organization(&self.uuid, conn).await?;
@ -383,31 +385,33 @@ impl Organization {
}}
}
pub async fn find_by_uuid(uuid: &OrganizationId, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_uuid(uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
organizations::table
.filter(organizations::uuid.eq(uuid))
.first::<OrganizationDb>(conn)
.ok().from_db()
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_by_name(name: &str, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_name(name: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
organizations::table
.filter(organizations::name.eq(name))
.first::<OrganizationDb>(conn)
.ok().from_db()
.first::<Self>(conn)
.ok()
}}
}
pub async fn get_all(conn: &mut DbConn) -> Vec<Self> {
pub async fn get_all(conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
organizations::table.load::<OrganizationDb>(conn).expect("Error loading organizations").from_db()
organizations::table
.load::<Self>(conn)
.expect("Error loading organizations")
}}
}
pub async fn find_main_org_user_email(user_email: &str, conn: &mut DbConn) -> Option<Organization> {
pub async fn find_main_org_user_email(user_email: &str, conn: &DbConn) -> Option<Self> {
let lower_mail = user_email.to_lowercase();
db_run! { conn: {
@ -418,12 +422,12 @@ impl Organization {
.filter(users_organizations::status.ne(MembershipStatus::Revoked as i32))
.order(users_organizations::atype.asc())
.select(organizations::all_columns)
.first::<OrganizationDb>(conn)
.ok().from_db()
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_org_user_email(user_email: &str, conn: &mut DbConn) -> Vec<Organization> {
pub async fn find_org_user_email(user_email: &str, conn: &DbConn) -> Vec<Self> {
let lower_mail = user_email.to_lowercase();
db_run! { conn: {
@ -434,15 +438,14 @@ impl Organization {
.filter(users_organizations::status.ne(MembershipStatus::Revoked as i32))
.order(users_organizations::atype.asc())
.select(organizations::all_columns)
.load::<OrganizationDb>(conn)
.load::<Self>(conn)
.expect("Error loading user orgs")
.from_db()
}}
}
}
impl Membership {
pub async fn to_json(&self, conn: &mut DbConn) -> Value {
pub async fn to_json(&self, conn: &DbConn) -> Value {
let org = Organization::find_by_uuid(&self.org_uuid, conn).await.unwrap();
// HACK: Convert the manager type to a custom type
@ -533,12 +536,7 @@ impl Membership {
})
}
pub async fn to_json_user_details(
&self,
include_collections: bool,
include_groups: bool,
conn: &mut DbConn,
) -> Value {
pub async fn to_json_user_details(&self, include_collections: bool, include_groups: bool, conn: &DbConn) -> Value {
let user = User::find_by_uuid(&self.user_uuid, conn).await.unwrap();
// Because BitWarden want the status to be -1 for revoked users we need to catch that here.
@ -680,7 +678,7 @@ impl Membership {
})
}
pub async fn to_json_details(&self, conn: &mut DbConn) -> Value {
pub async fn to_json_details(&self, conn: &DbConn) -> Value {
let coll_uuids = if self.access_all {
vec![] // If we have complete access, no need to fill the array
} else {
@ -720,7 +718,7 @@ impl Membership {
})
}
pub async fn to_json_mini_details(&self, conn: &mut DbConn) -> Value {
pub async fn to_json_mini_details(&self, conn: &DbConn) -> Value {
let user = User::find_by_uuid(&self.user_uuid, conn).await.unwrap();
// Because Bitwarden wants the status to be -1 for revoked users we need to catch that here.
@ -742,13 +740,13 @@ impl Membership {
})
}
pub async fn save(&self, conn: &mut DbConn) -> EmptyResult {
pub async fn save(&self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.user_uuid, conn).await;
db_run! { conn:
sqlite, mysql {
match diesel::replace_into(users_organizations::table)
.values(MembershipDb::to_db(self))
.values(self)
.execute(conn)
{
Ok(_) => Ok(()),
@ -756,7 +754,7 @@ impl Membership {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(users_organizations::table)
.filter(users_organizations::uuid.eq(&self.uuid))
.set(MembershipDb::to_db(self))
.set(self)
.execute(conn)
.map_res("Error adding user to organization")
},
@ -764,19 +762,18 @@ impl Membership {
}.map_res("Error adding user to organization")
}
postgresql {
let value = MembershipDb::to_db(self);
diesel::insert_into(users_organizations::table)
.values(&value)
.values(self)
.on_conflict(users_organizations::uuid)
.do_update()
.set(&value)
.set(self)
.execute(conn)
.map_res("Error adding user to organization")
}
}
}
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult {
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.user_uuid, conn).await;
CollectionUser::delete_all_by_user_and_org(&self.user_uuid, &self.org_uuid, conn).await?;
@ -789,25 +786,21 @@ impl Membership {
}}
}
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
for member in Self::find_by_org(org_uuid, conn).await {
member.delete(conn).await?;
}
Ok(())
}
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
for member in Self::find_any_state_by_user(user_uuid, conn).await {
member.delete(conn).await?;
}
Ok(())
}
pub async fn find_by_email_and_org(
email: &str,
org_uuid: &OrganizationId,
conn: &mut DbConn,
) -> Option<Membership> {
pub async fn find_by_email_and_org(email: &str, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Membership> {
if let Some(user) = User::find_by_mail(email, conn).await {
if let Some(member) = Membership::find_by_user_and_org(&user.uuid, org_uuid, conn).await {
return Some(member);
@ -829,52 +822,48 @@ impl Membership {
(self.access_all || self.atype >= MembershipType::Admin) && self.has_status(MembershipStatus::Confirmed)
}
pub async fn find_by_uuid(uuid: &MembershipId, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_uuid(uuid: &MembershipId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::uuid.eq(uuid))
.first::<MembershipDb>(conn)
.ok().from_db()
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_by_uuid_and_org(
uuid: &MembershipId,
org_uuid: &OrganizationId,
conn: &mut DbConn,
) -> Option<Self> {
pub async fn find_by_uuid_and_org(uuid: &MembershipId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::uuid.eq(uuid))
.filter(users_organizations::org_uuid.eq(org_uuid))
.first::<MembershipDb>(conn)
.ok().from_db()
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_confirmed_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_confirmed_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.load::<MembershipDb>(conn)
.unwrap_or_default().from_db()
.load::<Self>(conn)
.unwrap_or_default()
}}
}
pub async fn find_invited_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_invited_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::status.eq(MembershipStatus::Invited as i32))
.load::<MembershipDb>(conn)
.unwrap_or_default().from_db()
.load::<Self>(conn)
.unwrap_or_default()
}}
}
// Should be used only when email are disabled.
// In Organizations::send_invite status is set to Accepted only if the user has a password.
pub async fn accept_user_invitations(user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult {
pub async fn accept_user_invitations(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::update(users_organizations::table)
.filter(users_organizations::user_uuid.eq(user_uuid))
@ -885,16 +874,16 @@ impl Membership {
}}
}
pub async fn find_any_state_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_any_state_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.load::<MembershipDb>(conn)
.unwrap_or_default().from_db()
.load::<Self>(conn)
.unwrap_or_default()
}}
}
pub async fn count_accepted_and_confirmed_by_user(user_uuid: &UserId, conn: &mut DbConn) -> i64 {
pub async fn count_accepted_and_confirmed_by_user(user_uuid: &UserId, conn: &DbConn) -> i64 {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
@ -905,27 +894,27 @@ impl Membership {
}}
}
pub async fn find_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
.load::<MembershipDb>(conn)
.expect("Error loading user organizations").from_db()
.load::<Self>(conn)
.expect("Error loading user organizations")
}}
}
pub async fn find_confirmed_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_confirmed_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.load::<MembershipDb>(conn)
.unwrap_or_default().from_db()
.load::<Self>(conn)
.unwrap_or_default()
}}
}
// Get all users which are either owner or admin, or a manager which can manage/access all
pub async fn find_confirmed_and_manage_all_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_confirmed_and_manage_all_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
@ -934,12 +923,12 @@ impl Membership {
users_organizations::atype.eq_any(vec![MembershipType::Owner as i32, MembershipType::Admin as i32])
.or(users_organizations::atype.eq(MembershipType::Manager as i32).and(users_organizations::access_all.eq(true)))
)
.load::<MembershipDb>(conn)
.unwrap_or_default().from_db()
.load::<Self>(conn)
.unwrap_or_default()
}}
}
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> i64 {
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
@ -950,24 +939,20 @@ impl Membership {
}}
}
pub async fn find_by_org_and_type(
org_uuid: &OrganizationId,
atype: MembershipType,
conn: &mut DbConn,
) -> Vec<Self> {
pub async fn find_by_org_and_type(org_uuid: &OrganizationId, atype: MembershipType, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
.filter(users_organizations::atype.eq(atype as i32))
.load::<MembershipDb>(conn)
.expect("Error loading user organizations").from_db()
.load::<Self>(conn)
.expect("Error loading user organizations")
}}
}
pub async fn count_confirmed_by_org_and_type(
org_uuid: &OrganizationId,
atype: MembershipType,
conn: &mut DbConn,
conn: &DbConn,
) -> i64 {
db_run! { conn: {
users_organizations::table
@ -980,24 +965,20 @@ impl Membership {
}}
}
pub async fn find_by_user_and_org(
user_uuid: &UserId,
org_uuid: &OrganizationId,
conn: &mut DbConn,
) -> Option<Self> {
pub async fn find_by_user_and_org(user_uuid: &UserId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::org_uuid.eq(org_uuid))
.first::<MembershipDb>(conn)
.ok().from_db()
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_confirmed_by_user_and_org(
user_uuid: &UserId,
org_uuid: &OrganizationId,
conn: &mut DbConn,
conn: &DbConn,
) -> Option<Self> {
db_run! { conn: {
users_organizations::table
@ -1006,21 +987,21 @@ impl Membership {
.filter(
users_organizations::status.eq(MembershipStatus::Confirmed as i32)
)
.first::<MembershipDb>(conn)
.ok().from_db()
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.load::<MembershipDb>(conn)
.expect("Error loading user organizations").from_db()
.load::<Self>(conn)
.expect("Error loading user organizations")
}}
}
pub async fn get_orgs_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<OrganizationId> {
pub async fn get_orgs_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<OrganizationId> {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
@ -1030,11 +1011,7 @@ impl Membership {
}}
}
pub async fn find_by_user_and_policy(
user_uuid: &UserId,
policy_type: OrgPolicyType,
conn: &mut DbConn,
) -> Vec<Self> {
pub async fn find_by_user_and_policy(user_uuid: &UserId, policy_type: OrgPolicyType, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
users_organizations::table
.inner_join(
@ -1048,16 +1025,12 @@ impl Membership {
users_organizations::status.eq(MembershipStatus::Confirmed as i32)
)
.select(users_organizations::all_columns)
.load::<MembershipDb>(conn)
.unwrap_or_default().from_db()
.load::<Self>(conn)
.unwrap_or_default()
}}
}
pub async fn find_by_cipher_and_org(
cipher_uuid: &CipherId,
org_uuid: &OrganizationId,
conn: &mut DbConn,
) -> Vec<Self> {
pub async fn find_by_cipher_and_org(cipher_uuid: &CipherId, org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
@ -1076,14 +1049,15 @@ impl Membership {
)
.select(users_organizations::all_columns)
.distinct()
.load::<MembershipDb>(conn).expect("Error loading user organizations").from_db()
.load::<Self>(conn)
.expect("Error loading user organizations")
}}
}
pub async fn find_by_cipher_and_org_with_group(
cipher_uuid: &CipherId,
org_uuid: &OrganizationId,
conn: &mut DbConn,
conn: &DbConn,
) -> Vec<Self> {
db_run! { conn: {
users_organizations::table
@ -1106,15 +1080,12 @@ impl Membership {
)
.select(users_organizations::all_columns)
.distinct()
.load::<MembershipDb>(conn).expect("Error loading user organizations with groups").from_db()
.load::<Self>(conn)
.expect("Error loading user organizations with groups")
}}
}
pub async fn user_has_ge_admin_access_to_cipher(
user_uuid: &UserId,
cipher_uuid: &CipherId,
conn: &mut DbConn,
) -> bool {
pub async fn user_has_ge_admin_access_to_cipher(user_uuid: &UserId, cipher_uuid: &CipherId, conn: &DbConn) -> bool {
db_run! { conn: {
users_organizations::table
.inner_join(ciphers::table.on(ciphers::uuid.eq(cipher_uuid).and(ciphers::organization_uuid.eq(users_organizations::org_uuid.nullable()))))
@ -1122,14 +1093,15 @@ impl Membership {
.filter(users_organizations::atype.eq_any(vec![MembershipType::Owner as i32, MembershipType::Admin as i32]))
.count()
.first::<i64>(conn)
.ok().unwrap_or(0) != 0
.ok()
.unwrap_or(0) != 0
}}
}
pub async fn find_by_collection_and_org(
collection_uuid: &CollectionId,
org_uuid: &OrganizationId,
conn: &mut DbConn,
conn: &DbConn,
) -> Vec<Self> {
db_run! { conn: {
users_organizations::table
@ -1143,33 +1115,31 @@ impl Membership {
)
)
.select(users_organizations::all_columns)
.load::<MembershipDb>(conn).expect("Error loading user organizations").from_db()
.load::<Self>(conn)
.expect("Error loading user organizations")
}}
}
pub async fn find_by_external_id_and_org(
ext_id: &str,
org_uuid: &OrganizationId,
conn: &mut DbConn,
) -> Option<Self> {
pub async fn find_by_external_id_and_org(ext_id: &str, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
users_organizations::table
.filter(
users_organizations::external_id.eq(ext_id)
.and(users_organizations::org_uuid.eq(org_uuid))
)
.first::<MembershipDb>(conn).ok().from_db()
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_main_user_org(user_uuid: &str, conn: &mut DbConn) -> Option<Self> {
pub async fn find_main_user_org(user_uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::status.ne(MembershipStatus::Revoked as i32))
.order(users_organizations::atype.asc())
.first::<MembershipDb>(conn)
.ok().from_db()
.first::<Self>(conn)
.ok()
}}
}
}
@ -1179,7 +1149,7 @@ impl OrganizationApiKey {
db_run! { conn:
sqlite, mysql {
match diesel::replace_into(organization_api_key::table)
.values(OrganizationApiKeyDb::to_db(self))
.values(self)
.execute(conn)
{
Ok(_) => Ok(()),
@ -1187,7 +1157,7 @@ impl OrganizationApiKey {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(organization_api_key::table)
.filter(organization_api_key::uuid.eq(&self.uuid))
.set(OrganizationApiKeyDb::to_db(self))
.set(self)
.execute(conn)
.map_res("Error saving organization")
}
@ -1196,12 +1166,11 @@ impl OrganizationApiKey {
}
postgresql {
let value = OrganizationApiKeyDb::to_db(self);
diesel::insert_into(organization_api_key::table)
.values(&value)
.values(self)
.on_conflict((organization_api_key::uuid, organization_api_key::org_uuid))
.do_update()
.set(&value)
.set(self)
.execute(conn)
.map_res("Error saving organization")
}
@ -1212,12 +1181,12 @@ impl OrganizationApiKey {
db_run! { conn: {
organization_api_key::table
.filter(organization_api_key::org_uuid.eq(org_uuid))
.first::<OrganizationApiKeyDb>(conn)
.ok().from_db()
.first::<Self>(conn)
.ok()
}}
}
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(organization_api_key::table.filter(organization_api_key::org_uuid.eq(org_uuid)))
.execute(conn)

56
src/db/models/send.rs

@ -4,9 +4,10 @@ use serde_json::Value;
use crate::{config::PathType, util::LowerCase, CONFIG};
use super::{OrganizationId, User, UserId};
use crate::db::schema::sends;
use diesel::prelude::*;
use id::SendId;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = sends)]
#[diesel(treat_none_as_null = true)]
@ -38,7 +39,6 @@ db_object! {
pub disabled: bool,
pub hide_email: Option<bool>,
}
}
#[derive(Copy, Clone, PartialEq, Eq, num_derive::FromPrimitive)]
pub enum SendType {
@ -103,7 +103,7 @@ impl Send {
}
}
pub async fn creator_identifier(&self, conn: &mut DbConn) -> Option<String> {
pub async fn creator_identifier(&self, conn: &DbConn) -> Option<String> {
if let Some(hide_email) = self.hide_email {
if hide_email {
return None;
@ -155,7 +155,7 @@ impl Send {
})
}
pub async fn to_json_access(&self, conn: &mut DbConn) -> Value {
pub async fn to_json_access(&self, conn: &DbConn) -> Value {
use crate::util::format_date;
let mut data = serde_json::from_str::<LowerCase<Value>>(&self.data).map(|d| d.data).unwrap_or_default();
@ -187,14 +187,14 @@ use crate::error::MapResult;
use crate::util::NumberOrString;
impl Send {
pub async fn save(&mut self, conn: &mut DbConn) -> EmptyResult {
pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
self.update_users_revision(conn).await;
self.revision_date = Utc::now().naive_utc();
db_run! { conn:
sqlite, mysql {
match diesel::replace_into(sends::table)
.values(SendDb::to_db(self))
.values(&*self)
.execute(conn)
{
Ok(_) => Ok(()),
@ -202,7 +202,7 @@ impl Send {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(sends::table)
.filter(sends::uuid.eq(&self.uuid))
.set(SendDb::to_db(self))
.set(&*self)
.execute(conn)
.map_res("Error saving send")
}
@ -210,19 +210,18 @@ impl Send {
}.map_res("Error saving send")
}
postgresql {
let value = SendDb::to_db(self);
diesel::insert_into(sends::table)
.values(&value)
.values(&*self)
.on_conflict(sends::uuid)
.do_update()
.set(&value)
.set(&*self)
.execute(conn)
.map_res("Error saving send")
}
}
}
pub async fn delete(&self, conn: &mut DbConn) -> EmptyResult {
pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
self.update_users_revision(conn).await;
if self.atype == SendType::File as i32 {
@ -238,13 +237,13 @@ impl Send {
}
/// Purge all sends that are past their deletion date.
pub async fn purge(conn: &mut DbConn) {
pub async fn purge(conn: &DbConn) {
for send in Self::find_by_past_deletion_date(conn).await {
send.delete(conn).await.ok();
}
}
pub async fn update_users_revision(&self, conn: &mut DbConn) -> Vec<UserId> {
pub async fn update_users_revision(&self, conn: &DbConn) -> Vec<UserId> {
let mut user_uuids = Vec::new();
match &self.user_uuid {
Some(user_uuid) => {
@ -258,14 +257,14 @@ impl Send {
user_uuids
}
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
for send in Self::find_by_user(user_uuid, conn).await {
send.delete(conn).await?;
}
Ok(())
}
pub async fn find_by_access_id(access_id: &str, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_access_id(access_id: &str, conn: &DbConn) -> Option<Self> {
use data_encoding::BASE64URL_NOPAD;
use uuid::Uuid;
@ -281,36 +280,35 @@ impl Send {
Self::find_by_uuid(&uuid, conn).await
}
pub async fn find_by_uuid(uuid: &SendId, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_uuid(uuid: &SendId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
sends::table
.filter(sends::uuid.eq(uuid))
.first::<SendDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
pub async fn find_by_uuid_and_user(uuid: &SendId, user_uuid: &UserId, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_uuid_and_user(uuid: &SendId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
sends::table
.filter(sends::uuid.eq(uuid))
.filter(sends::user_uuid.eq(user_uuid))
.first::<SendDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
pub async fn find_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
sends::table
.filter(sends::user_uuid.eq(user_uuid))
.load::<SendDb>(conn).expect("Error loading sends").from_db()
.load::<Self>(conn)
.expect("Error loading sends")
}}
}
pub async fn size_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Option<i64> {
pub async fn size_by_user(user_uuid: &UserId, conn: &DbConn) -> Option<i64> {
let sends = Self::find_by_user(user_uuid, conn).await;
#[derive(serde::Deserialize)]
@ -333,20 +331,22 @@ impl Send {
Some(total)
}
pub async fn find_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
sends::table
.filter(sends::organization_uuid.eq(org_uuid))
.load::<SendDb>(conn).expect("Error loading sends").from_db()
.load::<Self>(conn)
.expect("Error loading sends")
}}
}
pub async fn find_by_past_deletion_date(conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_past_deletion_date(conn: &DbConn) -> Vec<Self> {
let now = Utc::now().naive_utc();
db_run! { conn: {
sends::table
.filter(sends::deletion_date.lt(now))
.load::<SendDb>(conn).expect("Error loading sends").from_db()
.load::<Self>(conn)
.expect("Error loading sends")
}}
}
}

16
src/db/models/sso_nonce.rs

@ -1,11 +1,12 @@
use chrono::{NaiveDateTime, Utc};
use crate::api::EmptyResult;
use crate::db::schema::sso_nonce;
use crate::db::{DbConn, DbPool};
use crate::error::MapResult;
use crate::sso::{OIDCState, NONCE_EXPIRATION};
use diesel::prelude::*;
db_object! {
#[derive(Identifiable, Queryable, Insertable)]
#[diesel(table_name = sso_nonce)]
#[diesel(primary_key(state))]
@ -16,7 +17,6 @@ db_object! {
pub redirect_uri: String,
pub created_at: NaiveDateTime,
}
}
/// Local methods
impl SsoNonce {
@ -35,25 +35,24 @@ impl SsoNonce {
/// Database methods
impl SsoNonce {
pub async fn save(&self, conn: &mut DbConn) -> EmptyResult {
pub async fn save(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn:
sqlite, mysql {
diesel::replace_into(sso_nonce::table)
.values(SsoNonceDb::to_db(self))
.values(self)
.execute(conn)
.map_res("Error saving SSO nonce")
}
postgresql {
let value = SsoNonceDb::to_db(self);
diesel::insert_into(sso_nonce::table)
.values(&value)
.values(self)
.execute(conn)
.map_res("Error saving SSO nonce")
}
}
}
pub async fn delete(state: &OIDCState, conn: &mut DbConn) -> EmptyResult {
pub async fn delete(state: &OIDCState, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(sso_nonce::table.filter(sso_nonce::state.eq(state)))
.execute(conn)
@ -67,9 +66,8 @@ impl SsoNonce {
sso_nonce::table
.filter(sso_nonce::state.eq(state))
.filter(sso_nonce::created_at.ge(oldest))
.first::<SsoNonceDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}

39
src/db/models/two_factor.rs

@ -1,12 +1,13 @@
use super::UserId;
use crate::api::core::two_factor::webauthn::WebauthnRegistration;
use crate::db::schema::twofactor;
use crate::{api::EmptyResult, db::DbConn, error::MapResult};
use diesel::prelude::*;
use serde_json::Value;
use webauthn_rs::prelude::{Credential, ParsedAttestation};
use webauthn_rs_core::proto::CredentialV3;
use webauthn_rs_proto::{AttestationFormat, RegisteredExtensions};
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = twofactor)]
#[diesel(primary_key(uuid))]
@ -18,7 +19,6 @@ db_object! {
pub data: String,
pub last_used: i64,
}
}
#[allow(dead_code)]
#[derive(num_derive::FromPrimitive)]
@ -76,11 +76,11 @@ impl TwoFactor {
/// Database methods
impl TwoFactor {
pub async fn save(&self, conn: &mut DbConn) -> EmptyResult {
pub async fn save(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn:
sqlite, mysql {
match diesel::replace_into(twofactor::table)
.values(TwoFactorDb::to_db(self))
.values(self)
.execute(conn)
{
Ok(_) => Ok(()),
@ -88,7 +88,7 @@ impl TwoFactor {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(twofactor::table)
.filter(twofactor::uuid.eq(&self.uuid))
.set(TwoFactorDb::to_db(self))
.set(self)
.execute(conn)
.map_res("Error saving twofactor")
}
@ -96,7 +96,6 @@ impl TwoFactor {
}.map_res("Error saving twofactor")
}
postgresql {
let value = TwoFactorDb::to_db(self);
// We need to make sure we're not going to violate the unique constraint on user_uuid and atype.
// This happens automatically on other DBMS backends due to replace_into(). PostgreSQL does
// not support multiple constraints on ON CONFLICT clauses.
@ -105,17 +104,17 @@ impl TwoFactor {
.map_res("Error deleting twofactor for insert")?;
diesel::insert_into(twofactor::table)
.values(&value)
.values(self)
.on_conflict(twofactor::uuid)
.do_update()
.set(&value)
.set(self)
.execute(conn)
.map_res("Error saving twofactor")
}
}
}
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult {
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(twofactor::table.filter(twofactor::uuid.eq(self.uuid)))
.execute(conn)
@ -123,29 +122,27 @@ impl TwoFactor {
}}
}
pub async fn find_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
twofactor::table
.filter(twofactor::user_uuid.eq(user_uuid))
.filter(twofactor::atype.lt(1000)) // Filter implementation types
.load::<TwoFactorDb>(conn)
.load::<Self>(conn)
.expect("Error loading twofactor")
.from_db()
}}
}
pub async fn find_by_user_and_type(user_uuid: &UserId, atype: i32, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_user_and_type(user_uuid: &UserId, atype: i32, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
twofactor::table
.filter(twofactor::user_uuid.eq(user_uuid))
.filter(twofactor::atype.eq(atype))
.first::<TwoFactorDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(twofactor::table.filter(twofactor::user_uuid.eq(user_uuid)))
.execute(conn)
@ -153,13 +150,12 @@ impl TwoFactor {
}}
}
pub async fn migrate_u2f_to_webauthn(conn: &mut DbConn) -> EmptyResult {
pub async fn migrate_u2f_to_webauthn(conn: &DbConn) -> EmptyResult {
let u2f_factors = db_run! { conn: {
twofactor::table
.filter(twofactor::atype.eq(TwoFactorType::U2f as i32))
.load::<TwoFactorDb>(conn)
.load::<Self>(conn)
.expect("Error loading twofactor")
.from_db()
}};
use crate::api::core::two_factor::webauthn::U2FRegistration;
@ -231,13 +227,12 @@ impl TwoFactor {
Ok(())
}
pub async fn migrate_credential_to_passkey(conn: &mut DbConn) -> EmptyResult {
pub async fn migrate_credential_to_passkey(conn: &DbConn) -> EmptyResult {
let webauthn_factors = db_run! { conn: {
twofactor::table
.filter(twofactor::atype.eq(TwoFactorType::Webauthn as i32))
.load::<TwoFactorDb>(conn)
.load::<Self>(conn)
.expect("Error loading twofactor")
.from_db()
}};
for webauthn_factor in webauthn_factors {

44
src/db/models/two_factor_duo_context.rs

@ -1,8 +1,9 @@
use chrono::Utc;
use crate::db::schema::twofactor_duo_ctx;
use crate::{api::EmptyResult, db::DbConn, error::MapResult};
use diesel::prelude::*;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = twofactor_duo_ctx)]
#[diesel(primary_key(state))]
@ -12,22 +13,18 @@ db_object! {
pub nonce: String,
pub exp: i64,
}
}
impl TwoFactorDuoContext {
pub async fn find_by_state(state: &str, conn: &mut DbConn) -> Option<Self> {
db_run! {
conn: {
pub async fn find_by_state(state: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
twofactor_duo_ctx::table
.filter(twofactor_duo_ctx::state.eq(state))
.first::<TwoFactorDuoContextDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}
}
}}
}
pub async fn save(state: &str, user_email: &str, nonce: &str, ttl: i64, conn: &mut DbConn) -> EmptyResult {
pub async fn save(state: &str, user_email: &str, nonce: &str, ttl: i64, conn: &DbConn) -> EmptyResult {
// A saved context should never be changed, only created or deleted.
let exists = Self::find_by_state(state, conn).await;
if exists.is_some() {
@ -36,8 +33,7 @@ impl TwoFactorDuoContext {
let exp = Utc::now().timestamp() + ttl;
db_run! {
conn: {
db_run! { conn: {
diesel::insert_into(twofactor_duo_ctx::table)
.values((
twofactor_duo_ctx::state.eq(state),
@ -47,36 +43,30 @@ impl TwoFactorDuoContext {
))
.execute(conn)
.map_res("Error saving context to twofactor_duo_ctx")
}
}
}}
}
pub async fn find_expired(conn: &mut DbConn) -> Vec<Self> {
pub async fn find_expired(conn: &DbConn) -> Vec<Self> {
let now = Utc::now().timestamp();
db_run! {
conn: {
db_run! { conn: {
twofactor_duo_ctx::table
.filter(twofactor_duo_ctx::exp.lt(now))
.load::<TwoFactorDuoContextDb>(conn)
.load::<Self>(conn)
.expect("Error finding expired contexts in twofactor_duo_ctx")
.from_db()
}
}
}}
}
pub async fn delete(&self, conn: &mut DbConn) -> EmptyResult {
db_run! {
conn: {
pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(
twofactor_duo_ctx::table
.filter(twofactor_duo_ctx::state.eq(&self.state)))
.execute(conn)
.map_res("Error deleting from twofactor_duo_ctx")
}
}
}}
}
pub async fn purge_expired_duo_contexts(conn: &mut DbConn) {
pub async fn purge_expired_duo_contexts(conn: &DbConn) {
for context in Self::find_expired(conn).await {
context.delete(conn).await.ok();
}

32
src/db/models/two_factor_incomplete.rs

@ -1,5 +1,6 @@
use chrono::{NaiveDateTime, Utc};
use crate::db::schema::twofactor_incomplete;
use crate::{
api::EmptyResult,
auth::ClientIp,
@ -10,8 +11,8 @@ use crate::{
error::MapResult,
CONFIG,
};
use diesel::prelude::*;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = twofactor_incomplete)]
#[diesel(primary_key(user_uuid, device_uuid))]
@ -26,7 +27,6 @@ db_object! {
pub login_time: NaiveDateTime,
pub ip_address: String,
}
}
impl TwoFactorIncomplete {
pub async fn mark_incomplete(
@ -35,7 +35,7 @@ impl TwoFactorIncomplete {
device_name: &str,
device_type: i32,
ip: &ClientIp,
conn: &mut DbConn,
conn: &DbConn,
) -> EmptyResult {
if CONFIG.incomplete_2fa_time_limit() <= 0 || !CONFIG.mail_enabled() {
return Ok(());
@ -64,7 +64,7 @@ impl TwoFactorIncomplete {
}}
}
pub async fn mark_complete(user_uuid: &UserId, device_uuid: &DeviceId, conn: &mut DbConn) -> EmptyResult {
pub async fn mark_complete(user_uuid: &UserId, device_uuid: &DeviceId, conn: &DbConn) -> EmptyResult {
if CONFIG.incomplete_2fa_time_limit() <= 0 || !CONFIG.mail_enabled() {
return Ok(());
}
@ -72,40 +72,30 @@ impl TwoFactorIncomplete {
Self::delete_by_user_and_device(user_uuid, device_uuid, conn).await
}
pub async fn find_by_user_and_device(
user_uuid: &UserId,
device_uuid: &DeviceId,
conn: &mut DbConn,
) -> Option<Self> {
pub async fn find_by_user_and_device(user_uuid: &UserId, device_uuid: &DeviceId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
twofactor_incomplete::table
.filter(twofactor_incomplete::user_uuid.eq(user_uuid))
.filter(twofactor_incomplete::device_uuid.eq(device_uuid))
.first::<TwoFactorIncompleteDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
pub async fn find_logins_before(dt: &NaiveDateTime, conn: &mut DbConn) -> Vec<Self> {
pub async fn find_logins_before(dt: &NaiveDateTime, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
twofactor_incomplete::table
.filter(twofactor_incomplete::login_time.lt(dt))
.load::<TwoFactorIncompleteDb>(conn)
.load::<Self>(conn)
.expect("Error loading twofactor_incomplete")
.from_db()
}}
}
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult {
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
Self::delete_by_user_and_device(&self.user_uuid, &self.device_uuid, conn).await
}
pub async fn delete_by_user_and_device(
user_uuid: &UserId,
device_uuid: &DeviceId,
conn: &mut DbConn,
) -> EmptyResult {
pub async fn delete_by_user_and_device(user_uuid: &UserId, device_uuid: &DeviceId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(twofactor_incomplete::table
.filter(twofactor_incomplete::user_uuid.eq(user_uuid))
@ -115,7 +105,7 @@ impl TwoFactorIncomplete {
}}
}
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(twofactor_incomplete::table.filter(twofactor_incomplete::user_uuid.eq(user_uuid)))
.execute(conn)

89
src/db/models/user.rs

@ -1,5 +1,7 @@
use crate::db::schema::{devices, invitations, sso_users, users};
use chrono::{NaiveDateTime, TimeDelta, Utc};
use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use serde_json::Value;
use super::{
@ -17,7 +19,6 @@ use crate::{
};
use macros::UuidFromParam;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset, Selectable)]
#[diesel(table_name = users)]
#[diesel(treat_none_as_null = true)]
@ -81,7 +82,6 @@ db_object! {
pub user_uuid: UserId,
pub identifier: OIDCIdentifier,
}
}
pub enum UserKdfType {
Pbkdf2 = 0,
@ -236,7 +236,7 @@ impl User {
/// Database methods
impl User {
pub async fn to_json(&self, conn: &mut DbConn) -> Value {
pub async fn to_json(&self, conn: &DbConn) -> Value {
let mut orgs_json = Vec::new();
for c in Membership::find_confirmed_by_user(&self.uuid, conn).await {
orgs_json.push(c.to_json(conn).await);
@ -275,7 +275,7 @@ impl User {
})
}
pub async fn save(&mut self, conn: &mut DbConn) -> EmptyResult {
pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
if !crate::util::is_valid_email(&self.email) {
err!(format!("User email {} is not a valid email address", self.email))
}
@ -285,7 +285,7 @@ impl User {
db_run! { conn:
sqlite, mysql {
match diesel::replace_into(users::table)
.values(UserDb::to_db(self))
.values(&*self)
.execute(conn)
{
Ok(_) => Ok(()),
@ -293,7 +293,7 @@ impl User {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(users::table)
.filter(users::uuid.eq(&self.uuid))
.set(UserDb::to_db(self))
.set(&*self)
.execute(conn)
.map_res("Error saving user")
}
@ -301,19 +301,18 @@ impl User {
}.map_res("Error saving user")
}
postgresql {
let value = UserDb::to_db(self);
diesel::insert_into(users::table) // Insert or update
.values(&value)
.values(&*self)
.on_conflict(users::uuid)
.do_update()
.set(&value)
.set(&*self)
.execute(conn)
.map_res("Error saving user")
}
}
}
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult {
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
for member in Membership::find_confirmed_by_user(&self.uuid, conn).await {
if member.atype == MembershipType::Owner
&& Membership::count_confirmed_by_org_and_type(&member.org_uuid, MembershipType::Owner, conn).await <= 1
@ -341,13 +340,13 @@ impl User {
}}
}
pub async fn update_uuid_revision(uuid: &UserId, conn: &mut DbConn) {
pub async fn update_uuid_revision(uuid: &UserId, conn: &DbConn) {
if let Err(e) = Self::_update_revision(uuid, &Utc::now().naive_utc(), conn).await {
warn!("Failed to update revision for {uuid}: {e:#?}");
}
}
pub async fn update_all_revisions(conn: &mut DbConn) -> EmptyResult {
pub async fn update_all_revisions(conn: &DbConn) -> EmptyResult {
let updated_at = Utc::now().naive_utc();
db_run! { conn: {
@ -360,13 +359,13 @@ impl User {
}}
}
pub async fn update_revision(&mut self, conn: &mut DbConn) -> EmptyResult {
pub async fn update_revision(&mut self, conn: &DbConn) -> EmptyResult {
self.updated_at = Utc::now().naive_utc();
Self::_update_revision(&self.uuid, &self.updated_at, conn).await
}
async fn _update_revision(uuid: &UserId, date: &NaiveDateTime, conn: &mut DbConn) -> EmptyResult {
async fn _update_revision(uuid: &UserId, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
retry(|| {
diesel::update(users::table.filter(users::uuid.eq(uuid)))
@ -377,49 +376,49 @@ impl User {
}}
}
pub async fn find_by_mail(mail: &str, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_mail(mail: &str, conn: &DbConn) -> Option<Self> {
let lower_mail = mail.to_lowercase();
db_run! { conn: {
users::table
.filter(users::email.eq(lower_mail))
.first::<UserDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
pub async fn find_by_uuid(uuid: &UserId, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_uuid(uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
users::table.filter(users::uuid.eq(uuid)).first::<UserDb>(conn).ok().from_db()
users::table
.filter(users::uuid.eq(uuid))
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_by_device_id(device_uuid: &DeviceId, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_device_id(device_uuid: &DeviceId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
users::table
.inner_join(devices::table.on(devices::user_uuid.eq(users::uuid)))
.filter(devices::uuid.eq(device_uuid))
.select(users::all_columns)
.first::<UserDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
pub async fn get_all(conn: &mut DbConn) -> Vec<(User, Option<SsoUser>)> {
pub async fn get_all(conn: &DbConn) -> Vec<(Self, Option<SsoUser>)> {
db_run! { conn: {
users::table
.left_join(sso_users::table)
.select(<(UserDb, Option<SsoUserDb>)>::as_select())
.select(<(Self, Option<SsoUser>)>::as_select())
.load(conn)
.expect("Error loading groups for user")
.into_iter()
.map(|(user, sso_user)| { (user.from_db(), sso_user.from_db()) })
.collect()
}}
}
pub async fn last_active(&self, conn: &mut DbConn) -> Option<NaiveDateTime> {
pub async fn last_active(&self, conn: &DbConn) -> Option<NaiveDateTime> {
match Device::find_latest_active_by_user(&self.uuid, conn).await {
Some(device) => Some(device.updated_at),
None => None,
@ -435,7 +434,7 @@ impl Invitation {
}
}
pub async fn save(&self, conn: &mut DbConn) -> EmptyResult {
pub async fn save(&self, conn: &DbConn) -> EmptyResult {
if !crate::util::is_valid_email(&self.email) {
err!(format!("Invitation email {} is not a valid email address", self.email))
}
@ -445,13 +444,13 @@ impl Invitation {
// Not checking for ForeignKey Constraints here
// Table invitations does not have any ForeignKey Constraints.
diesel::replace_into(invitations::table)
.values(InvitationDb::to_db(self))
.values(self)
.execute(conn)
.map_res("Error saving invitation")
}
postgresql {
diesel::insert_into(invitations::table)
.values(InvitationDb::to_db(self))
.values(self)
.on_conflict(invitations::email)
.do_nothing()
.execute(conn)
@ -460,7 +459,7 @@ impl Invitation {
}
}
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult {
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(invitations::table.filter(invitations::email.eq(self.email)))
.execute(conn)
@ -468,18 +467,17 @@ impl Invitation {
}}
}
pub async fn find_by_mail(mail: &str, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_mail(mail: &str, conn: &DbConn) -> Option<Self> {
let lower_mail = mail.to_lowercase();
db_run! { conn: {
invitations::table
.filter(invitations::email.eq(lower_mail))
.first::<InvitationDb>(conn)
.first::<Self>(conn)
.ok()
.from_db()
}}
}
pub async fn take(mail: &str, conn: &mut DbConn) -> bool {
pub async fn take(mail: &str, conn: &DbConn) -> bool {
match Self::find_by_mail(mail, conn).await {
Some(invitation) => invitation.delete(conn).await.is_ok(),
None => false,
@ -508,51 +506,48 @@ impl Invitation {
pub struct UserId(String);
impl SsoUser {
pub async fn save(&self, conn: &mut DbConn) -> EmptyResult {
pub async fn save(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn:
sqlite, mysql {
diesel::replace_into(sso_users::table)
.values(SsoUserDb::to_db(self))
.values(self)
.execute(conn)
.map_res("Error saving SSO user")
}
postgresql {
let value = SsoUserDb::to_db(self);
diesel::insert_into(sso_users::table)
.values(&value)
.values(self)
.execute(conn)
.map_res("Error saving SSO user")
}
}
}
pub async fn find_by_identifier(identifier: &str, conn: &DbConn) -> Option<(User, SsoUser)> {
pub async fn find_by_identifier(identifier: &str, conn: &DbConn) -> Option<(User, Self)> {
db_run! { conn: {
users::table
.inner_join(sso_users::table)
.select(<(UserDb, SsoUserDb)>::as_select())
.select(<(User, Self)>::as_select())
.filter(sso_users::identifier.eq(identifier))
.first::<(UserDb, SsoUserDb)>(conn)
.first::<(User, Self)>(conn)
.ok()
.map(|(user, sso_user)| { (user.from_db(), sso_user.from_db()) })
}}
}
pub async fn find_by_mail(mail: &str, conn: &DbConn) -> Option<(User, Option<SsoUser>)> {
pub async fn find_by_mail(mail: &str, conn: &DbConn) -> Option<(User, Option<Self>)> {
let lower_mail = mail.to_lowercase();
db_run! { conn: {
users::table
.left_join(sso_users::table)
.select(<(UserDb, Option<SsoUserDb>)>::as_select())
.select(<(User, Option<Self>)>::as_select())
.filter(users::email.eq(lower_mail))
.first::<(UserDb, Option<SsoUserDb>)>(conn)
.first::<(User, Option<Self>)>(conn)
.ok()
.map(|(user, sso_user)| { (user.from_db(), sso_user.from_db()) })
}}
}
pub async fn delete(user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult {
pub async fn delete(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(sso_users::table.filter(sso_users::user_uuid.eq(user_uuid)))
.execute(conn)

57
src/db/query_logger.rs

@ -0,0 +1,57 @@
use dashmap::DashMap;
use diesel::connection::{Instrumentation, InstrumentationEvent};
use std::{
sync::{Arc, LazyLock},
thread,
time::Instant,
};
pub static QUERY_PERF_TRACKER: LazyLock<Arc<DashMap<(thread::ThreadId, String), Instant>>> =
LazyLock::new(|| Arc::new(DashMap::new()));
pub fn simple_logger() -> Option<Box<dyn Instrumentation>> {
Some(Box::new(|event: InstrumentationEvent<'_>| match event {
InstrumentationEvent::StartEstablishConnection {
url,
..
} => {
debug!("Establishing connection: {url}")
}
InstrumentationEvent::FinishEstablishConnection {
url,
error,
..
} => {
if let Some(e) = error {
error!("Error during establishing a connection with {url}: {e:?}")
} else {
debug!("Connection established: {url}")
}
}
InstrumentationEvent::StartQuery {
query,
..
} => {
let query_string = format!("{query:?}");
let start = Instant::now();
QUERY_PERF_TRACKER.insert((thread::current().id(), query_string), start);
}
InstrumentationEvent::FinishQuery {
query,
..
} => {
let query_string = format!("{query:?}");
if let Some((_, start)) = QUERY_PERF_TRACKER.remove(&(thread::current().id(), query_string.clone())) {
let duration = start.elapsed();
if duration.as_secs() >= 5 {
warn!("SLOW QUERY [{:.2}s]: {}", duration.as_secs_f32(), query_string);
} else if duration.as_secs() >= 1 {
info!("SLOW QUERY [{:.2}s]: {}", duration.as_secs_f32(), query_string);
} else {
debug!("QUERY [{:?}]: {}", duration, query_string);
}
}
}
_ => {}
}))
}

0
src/db/schemas/postgresql/schema.rs → src/db/schema.rs

395
src/db/schemas/mysql/schema.rs

@ -1,395 +0,0 @@
table! {
attachments (id) {
id -> Text,
cipher_uuid -> Text,
file_name -> Text,
file_size -> BigInt,
akey -> Nullable<Text>,
}
}
table! {
ciphers (uuid) {
uuid -> Text,
created_at -> Datetime,
updated_at -> Datetime,
user_uuid -> Nullable<Text>,
organization_uuid -> Nullable<Text>,
key -> Nullable<Text>,
atype -> Integer,
name -> Text,
notes -> Nullable<Text>,
fields -> Nullable<Text>,
data -> Text,
password_history -> Nullable<Text>,
deleted_at -> Nullable<Datetime>,
reprompt -> Nullable<Integer>,
}
}
table! {
ciphers_collections (cipher_uuid, collection_uuid) {
cipher_uuid -> Text,
collection_uuid -> Text,
}
}
table! {
collections (uuid) {
uuid -> Text,
org_uuid -> Text,
name -> Text,
external_id -> Nullable<Text>,
}
}
table! {
devices (uuid, user_uuid) {
uuid -> Text,
created_at -> Datetime,
updated_at -> Datetime,
user_uuid -> Text,
name -> Text,
atype -> Integer,
push_uuid -> Nullable<Text>,
push_token -> Nullable<Text>,
refresh_token -> Text,
twofactor_remember -> Nullable<Text>,
}
}
table! {
event (uuid) {
uuid -> Varchar,
event_type -> Integer,
user_uuid -> Nullable<Varchar>,
org_uuid -> Nullable<Varchar>,
cipher_uuid -> Nullable<Varchar>,
collection_uuid -> Nullable<Varchar>,
group_uuid -> Nullable<Varchar>,
org_user_uuid -> Nullable<Varchar>,
act_user_uuid -> Nullable<Varchar>,
device_type -> Nullable<Integer>,
ip_address -> Nullable<Text>,
event_date -> Timestamp,
policy_uuid -> Nullable<Varchar>,
provider_uuid -> Nullable<Varchar>,
provider_user_uuid -> Nullable<Varchar>,
provider_org_uuid -> Nullable<Varchar>,
}
}
table! {
favorites (user_uuid, cipher_uuid) {
user_uuid -> Text,
cipher_uuid -> Text,
}
}
table! {
folders (uuid) {
uuid -> Text,
created_at -> Datetime,
updated_at -> Datetime,
user_uuid -> Text,
name -> Text,
}
}
table! {
folders_ciphers (cipher_uuid, folder_uuid) {
cipher_uuid -> Text,
folder_uuid -> Text,
}
}
table! {
invitations (email) {
email -> Text,
}
}
table! {
org_policies (uuid) {
uuid -> Text,
org_uuid -> Text,
atype -> Integer,
enabled -> Bool,
data -> Text,
}
}
table! {
organizations (uuid) {
uuid -> Text,
name -> Text,
billing_email -> Text,
private_key -> Nullable<Text>,
public_key -> Nullable<Text>,
}
}
table! {
sends (uuid) {
uuid -> Text,
user_uuid -> Nullable<Text>,
organization_uuid -> Nullable<Text>,
name -> Text,
notes -> Nullable<Text>,
atype -> Integer,
data -> Text,
akey -> Text,
password_hash -> Nullable<Binary>,
password_salt -> Nullable<Binary>,
password_iter -> Nullable<Integer>,
max_access_count -> Nullable<Integer>,
access_count -> Integer,
creation_date -> Datetime,
revision_date -> Datetime,
expiration_date -> Nullable<Datetime>,
deletion_date -> Datetime,
disabled -> Bool,
hide_email -> Nullable<Bool>,
}
}
table! {
twofactor (uuid) {
uuid -> Text,
user_uuid -> Text,
atype -> Integer,
enabled -> Bool,
data -> Text,
last_used -> BigInt,
}
}
table! {
twofactor_incomplete (user_uuid, device_uuid) {
user_uuid -> Text,
device_uuid -> Text,
device_name -> Text,
device_type -> Integer,
login_time -> Timestamp,
ip_address -> Text,
}
}
table! {
twofactor_duo_ctx (state) {
state -> Text,
user_email -> Text,
nonce -> Text,
exp -> BigInt,
}
}
table! {
users (uuid) {
uuid -> Text,
enabled -> Bool,
created_at -> Datetime,
updated_at -> Datetime,
verified_at -> Nullable<Datetime>,
last_verifying_at -> Nullable<Datetime>,
login_verify_count -> Integer,
email -> Text,
email_new -> Nullable<Text>,
email_new_token -> Nullable<Text>,
name -> Text,
password_hash -> Binary,
salt -> Binary,
password_iterations -> Integer,
password_hint -> Nullable<Text>,
akey -> Text,
private_key -> Nullable<Text>,
public_key -> Nullable<Text>,
totp_secret -> Nullable<Text>,
totp_recover -> Nullable<Text>,
security_stamp -> Text,
stamp_exception -> Nullable<Text>,
equivalent_domains -> Text,
excluded_globals -> Text,
client_kdf_type -> Integer,
client_kdf_iter -> Integer,
client_kdf_memory -> Nullable<Integer>,
client_kdf_parallelism -> Nullable<Integer>,
api_key -> Nullable<Text>,
avatar_color -> Nullable<Text>,
external_id -> Nullable<Text>,
}
}
table! {
users_collections (user_uuid, collection_uuid) {
user_uuid -> Text,
collection_uuid -> Text,
read_only -> Bool,
hide_passwords -> Bool,
manage -> Bool,
}
}
table! {
users_organizations (uuid) {
uuid -> Text,
user_uuid -> Text,
org_uuid -> Text,
invited_by_email -> Nullable<Text>,
access_all -> Bool,
akey -> Text,
status -> Integer,
atype -> Integer,
reset_password_key -> Nullable<Text>,
external_id -> Nullable<Text>,
}
}
table! {
organization_api_key (uuid, org_uuid) {
uuid -> Text,
org_uuid -> Text,
atype -> Integer,
api_key -> Text,
revision_date -> Timestamp,
}
}
table! {
sso_nonce (state) {
state -> Text,
nonce -> Text,
verifier -> Nullable<Text>,
redirect_uri -> Text,
created_at -> Timestamp,
}
}
table! {
sso_users (user_uuid) {
user_uuid -> Text,
identifier -> Text,
}
}
table! {
emergency_access (uuid) {
uuid -> Text,
grantor_uuid -> Text,
grantee_uuid -> Nullable<Text>,
email -> Nullable<Text>,
key_encrypted -> Nullable<Text>,
atype -> Integer,
status -> Integer,
wait_time_days -> Integer,
recovery_initiated_at -> Nullable<Timestamp>,
last_notification_at -> Nullable<Timestamp>,
updated_at -> Timestamp,
created_at -> Timestamp,
}
}
table! {
groups (uuid) {
uuid -> Text,
organizations_uuid -> Text,
name -> Text,
access_all -> Bool,
external_id -> Nullable<Text>,
creation_date -> Timestamp,
revision_date -> Timestamp,
}
}
table! {
groups_users (groups_uuid, users_organizations_uuid) {
groups_uuid -> Text,
users_organizations_uuid -> Text,
}
}
table! {
collections_groups (collections_uuid, groups_uuid) {
collections_uuid -> Text,
groups_uuid -> Text,
read_only -> Bool,
hide_passwords -> Bool,
manage -> Bool,
}
}
table! {
auth_requests (uuid) {
uuid -> Text,
user_uuid -> Text,
organization_uuid -> Nullable<Text>,
request_device_identifier -> Text,
device_type -> Integer,
request_ip -> Text,
response_device_id -> Nullable<Text>,
access_code -> Text,
public_key -> Text,
enc_key -> Nullable<Text>,
master_password_hash -> Nullable<Text>,
approved -> Nullable<Bool>,
creation_date -> Timestamp,
response_date -> Nullable<Timestamp>,
authentication_date -> Nullable<Timestamp>,
}
}
joinable!(attachments -> ciphers (cipher_uuid));
joinable!(ciphers -> organizations (organization_uuid));
joinable!(ciphers -> users (user_uuid));
joinable!(ciphers_collections -> ciphers (cipher_uuid));
joinable!(ciphers_collections -> collections (collection_uuid));
joinable!(collections -> organizations (org_uuid));
joinable!(devices -> users (user_uuid));
joinable!(folders -> users (user_uuid));
joinable!(folders_ciphers -> ciphers (cipher_uuid));
joinable!(folders_ciphers -> folders (folder_uuid));
joinable!(org_policies -> organizations (org_uuid));
joinable!(sends -> organizations (organization_uuid));
joinable!(sends -> users (user_uuid));
joinable!(twofactor -> users (user_uuid));
joinable!(users_collections -> collections (collection_uuid));
joinable!(users_collections -> users (user_uuid));
joinable!(users_organizations -> organizations (org_uuid));
joinable!(users_organizations -> users (user_uuid));
joinable!(users_organizations -> ciphers (org_uuid));
joinable!(organization_api_key -> organizations (org_uuid));
joinable!(emergency_access -> users (grantor_uuid));
joinable!(groups -> organizations (organizations_uuid));
joinable!(groups_users -> users_organizations (users_organizations_uuid));
joinable!(groups_users -> groups (groups_uuid));
joinable!(collections_groups -> collections (collections_uuid));
joinable!(collections_groups -> groups (groups_uuid));
joinable!(event -> users_organizations (uuid));
joinable!(auth_requests -> users (user_uuid));
joinable!(sso_users -> users (user_uuid));
allow_tables_to_appear_in_same_query!(
attachments,
ciphers,
ciphers_collections,
collections,
devices,
folders,
folders_ciphers,
invitations,
org_policies,
organizations,
sends,
sso_users,
twofactor,
users,
users_collections,
users_organizations,
organization_api_key,
emergency_access,
groups,
groups_users,
collections_groups,
event,
auth_requests,
);

395
src/db/schemas/sqlite/schema.rs

@ -1,395 +0,0 @@
table! {
attachments (id) {
id -> Text,
cipher_uuid -> Text,
file_name -> Text,
file_size -> BigInt,
akey -> Nullable<Text>,
}
}
table! {
ciphers (uuid) {
uuid -> Text,
created_at -> Timestamp,
updated_at -> Timestamp,
user_uuid -> Nullable<Text>,
organization_uuid -> Nullable<Text>,
key -> Nullable<Text>,
atype -> Integer,
name -> Text,
notes -> Nullable<Text>,
fields -> Nullable<Text>,
data -> Text,
password_history -> Nullable<Text>,
deleted_at -> Nullable<Timestamp>,
reprompt -> Nullable<Integer>,
}
}
table! {
ciphers_collections (cipher_uuid, collection_uuid) {
cipher_uuid -> Text,
collection_uuid -> Text,
}
}
table! {
collections (uuid) {
uuid -> Text,
org_uuid -> Text,
name -> Text,
external_id -> Nullable<Text>,
}
}
table! {
devices (uuid, user_uuid) {
uuid -> Text,
created_at -> Timestamp,
updated_at -> Timestamp,
user_uuid -> Text,
name -> Text,
atype -> Integer,
push_uuid -> Nullable<Text>,
push_token -> Nullable<Text>,
refresh_token -> Text,
twofactor_remember -> Nullable<Text>,
}
}
table! {
event (uuid) {
uuid -> Text,
event_type -> Integer,
user_uuid -> Nullable<Text>,
org_uuid -> Nullable<Text>,
cipher_uuid -> Nullable<Text>,
collection_uuid -> Nullable<Text>,
group_uuid -> Nullable<Text>,
org_user_uuid -> Nullable<Text>,
act_user_uuid -> Nullable<Text>,
device_type -> Nullable<Integer>,
ip_address -> Nullable<Text>,
event_date -> Timestamp,
policy_uuid -> Nullable<Text>,
provider_uuid -> Nullable<Text>,
provider_user_uuid -> Nullable<Text>,
provider_org_uuid -> Nullable<Text>,
}
}
table! {
favorites (user_uuid, cipher_uuid) {
user_uuid -> Text,
cipher_uuid -> Text,
}
}
table! {
folders (uuid) {
uuid -> Text,
created_at -> Timestamp,
updated_at -> Timestamp,
user_uuid -> Text,
name -> Text,
}
}
table! {
folders_ciphers (cipher_uuid, folder_uuid) {
cipher_uuid -> Text,
folder_uuid -> Text,
}
}
table! {
invitations (email) {
email -> Text,
}
}
table! {
org_policies (uuid) {
uuid -> Text,
org_uuid -> Text,
atype -> Integer,
enabled -> Bool,
data -> Text,
}
}
table! {
organizations (uuid) {
uuid -> Text,
name -> Text,
billing_email -> Text,
private_key -> Nullable<Text>,
public_key -> Nullable<Text>,
}
}
table! {
sends (uuid) {
uuid -> Text,
user_uuid -> Nullable<Text>,
organization_uuid -> Nullable<Text>,
name -> Text,
notes -> Nullable<Text>,
atype -> Integer,
data -> Text,
akey -> Text,
password_hash -> Nullable<Binary>,
password_salt -> Nullable<Binary>,
password_iter -> Nullable<Integer>,
max_access_count -> Nullable<Integer>,
access_count -> Integer,
creation_date -> Timestamp,
revision_date -> Timestamp,
expiration_date -> Nullable<Timestamp>,
deletion_date -> Timestamp,
disabled -> Bool,
hide_email -> Nullable<Bool>,
}
}
table! {
twofactor (uuid) {
uuid -> Text,
user_uuid -> Text,
atype -> Integer,
enabled -> Bool,
data -> Text,
last_used -> BigInt,
}
}
table! {
twofactor_incomplete (user_uuid, device_uuid) {
user_uuid -> Text,
device_uuid -> Text,
device_name -> Text,
device_type -> Integer,
login_time -> Timestamp,
ip_address -> Text,
}
}
table! {
twofactor_duo_ctx (state) {
state -> Text,
user_email -> Text,
nonce -> Text,
exp -> BigInt,
}
}
table! {
users (uuid) {
uuid -> Text,
enabled -> Bool,
created_at -> Timestamp,
updated_at -> Timestamp,
verified_at -> Nullable<Timestamp>,
last_verifying_at -> Nullable<Timestamp>,
login_verify_count -> Integer,
email -> Text,
email_new -> Nullable<Text>,
email_new_token -> Nullable<Text>,
name -> Text,
password_hash -> Binary,
salt -> Binary,
password_iterations -> Integer,
password_hint -> Nullable<Text>,
akey -> Text,
private_key -> Nullable<Text>,
public_key -> Nullable<Text>,
totp_secret -> Nullable<Text>,
totp_recover -> Nullable<Text>,
security_stamp -> Text,
stamp_exception -> Nullable<Text>,
equivalent_domains -> Text,
excluded_globals -> Text,
client_kdf_type -> Integer,
client_kdf_iter -> Integer,
client_kdf_memory -> Nullable<Integer>,
client_kdf_parallelism -> Nullable<Integer>,
api_key -> Nullable<Text>,
avatar_color -> Nullable<Text>,
external_id -> Nullable<Text>,
}
}
table! {
users_collections (user_uuid, collection_uuid) {
user_uuid -> Text,
collection_uuid -> Text,
read_only -> Bool,
hide_passwords -> Bool,
manage -> Bool,
}
}
table! {
users_organizations (uuid) {
uuid -> Text,
user_uuid -> Text,
org_uuid -> Text,
invited_by_email -> Nullable<Text>,
access_all -> Bool,
akey -> Text,
status -> Integer,
atype -> Integer,
reset_password_key -> Nullable<Text>,
external_id -> Nullable<Text>,
}
}
table! {
organization_api_key (uuid, org_uuid) {
uuid -> Text,
org_uuid -> Text,
atype -> Integer,
api_key -> Text,
revision_date -> Timestamp,
}
}
table! {
sso_nonce (state) {
state -> Text,
nonce -> Text,
verifier -> Nullable<Text>,
redirect_uri -> Text,
created_at -> Timestamp,
}
}
table! {
sso_users (user_uuid) {
user_uuid -> Text,
identifier -> Text,
}
}
table! {
emergency_access (uuid) {
uuid -> Text,
grantor_uuid -> Text,
grantee_uuid -> Nullable<Text>,
email -> Nullable<Text>,
key_encrypted -> Nullable<Text>,
atype -> Integer,
status -> Integer,
wait_time_days -> Integer,
recovery_initiated_at -> Nullable<Timestamp>,
last_notification_at -> Nullable<Timestamp>,
updated_at -> Timestamp,
created_at -> Timestamp,
}
}
table! {
groups (uuid) {
uuid -> Text,
organizations_uuid -> Text,
name -> Text,
access_all -> Bool,
external_id -> Nullable<Text>,
creation_date -> Timestamp,
revision_date -> Timestamp,
}
}
table! {
groups_users (groups_uuid, users_organizations_uuid) {
groups_uuid -> Text,
users_organizations_uuid -> Text,
}
}
table! {
collections_groups (collections_uuid, groups_uuid) {
collections_uuid -> Text,
groups_uuid -> Text,
read_only -> Bool,
hide_passwords -> Bool,
manage -> Bool,
}
}
table! {
auth_requests (uuid) {
uuid -> Text,
user_uuid -> Text,
organization_uuid -> Nullable<Text>,
request_device_identifier -> Text,
device_type -> Integer,
request_ip -> Text,
response_device_id -> Nullable<Text>,
access_code -> Text,
public_key -> Text,
enc_key -> Nullable<Text>,
master_password_hash -> Nullable<Text>,
approved -> Nullable<Bool>,
creation_date -> Timestamp,
response_date -> Nullable<Timestamp>,
authentication_date -> Nullable<Timestamp>,
}
}
joinable!(attachments -> ciphers (cipher_uuid));
joinable!(ciphers -> organizations (organization_uuid));
joinable!(ciphers -> users (user_uuid));
joinable!(ciphers_collections -> ciphers (cipher_uuid));
joinable!(ciphers_collections -> collections (collection_uuid));
joinable!(collections -> organizations (org_uuid));
joinable!(devices -> users (user_uuid));
joinable!(folders -> users (user_uuid));
joinable!(folders_ciphers -> ciphers (cipher_uuid));
joinable!(folders_ciphers -> folders (folder_uuid));
joinable!(org_policies -> organizations (org_uuid));
joinable!(sends -> organizations (organization_uuid));
joinable!(sends -> users (user_uuid));
joinable!(twofactor -> users (user_uuid));
joinable!(users_collections -> collections (collection_uuid));
joinable!(users_collections -> users (user_uuid));
joinable!(users_organizations -> organizations (org_uuid));
joinable!(users_organizations -> users (user_uuid));
joinable!(users_organizations -> ciphers (org_uuid));
joinable!(organization_api_key -> organizations (org_uuid));
joinable!(emergency_access -> users (grantor_uuid));
joinable!(groups -> organizations (organizations_uuid));
joinable!(groups_users -> users_organizations (users_organizations_uuid));
joinable!(groups_users -> groups (groups_uuid));
joinable!(collections_groups -> collections (collections_uuid));
joinable!(collections_groups -> groups (groups_uuid));
joinable!(event -> users_organizations (uuid));
joinable!(auth_requests -> users (user_uuid));
joinable!(sso_users -> users (user_uuid));
allow_tables_to_appear_in_same_query!(
attachments,
ciphers,
ciphers_collections,
collections,
devices,
folders,
folders_ciphers,
invitations,
org_policies,
organizations,
sends,
sso_users,
twofactor,
users,
users_collections,
users_organizations,
organization_api_key,
emergency_access,
groups,
groups_users,
collections_groups,
event,
auth_requests,
);

50
src/main.rs

@ -71,7 +71,7 @@ pub use util::is_running_in_container;
#[rocket::main]
async fn main() -> Result<(), Error> {
parse_args().await;
parse_args();
launch_info();
let level = init_logging()?;
@ -87,8 +87,8 @@ async fn main() -> Result<(), Error> {
let pool = create_db_pool().await;
schedule_jobs(pool.clone());
db::models::TwoFactor::migrate_u2f_to_webauthn(&mut pool.get().await.unwrap()).await.unwrap();
db::models::TwoFactor::migrate_credential_to_passkey(&mut pool.get().await.unwrap()).await.unwrap();
db::models::TwoFactor::migrate_u2f_to_webauthn(&pool.get().await.unwrap()).await.unwrap();
db::models::TwoFactor::migrate_credential_to_passkey(&pool.get().await.unwrap()).await.unwrap();
let extra_debug = matches!(level, log::LevelFilter::Trace | log::LevelFilter::Debug);
launch_rocket(pool, extra_debug).await // Blocks until program termination.
@ -117,7 +117,7 @@ PRESETS: m= t= p=
pub const VERSION: Option<&str> = option_env!("VW_VERSION");
async fn parse_args() {
fn parse_args() {
let mut pargs = pico_args::Arguments::from_env();
let version = VERSION.unwrap_or("(Version info from Git not present)");
@ -188,7 +188,7 @@ async fn parse_args() {
exit(1);
}
} else if command == "backup" {
match backup_sqlite().await {
match db::backup_sqlite() {
Ok(f) => {
println!("Backup to '{f}' was successful");
exit(0);
@ -203,23 +203,6 @@ async fn parse_args() {
}
}
async fn backup_sqlite() -> Result<String, Error> {
use crate::db::{backup_database, DbConnType};
if DbConnType::from_url(&CONFIG.database_url()).map(|t| t == DbConnType::sqlite).unwrap_or(false) {
// Establish a connection to the sqlite database
let mut conn = db::DbPool::from_config()
.expect("SQLite database connection failed")
.get()
.await
.expect("Unable to get SQLite db pool");
let backup_file = backup_database(&mut conn).await?;
Ok(backup_file)
} else {
err_silent!("The database type is not SQLite. Backups only works for SQLite databases")
}
}
fn launch_info() {
println!(
"\
@ -285,13 +268,6 @@ fn init_logging() -> Result<log::LevelFilter, Error> {
log::LevelFilter::Off
};
let diesel_logger_level: log::LevelFilter =
if cfg!(feature = "query_logger") && std::env::var("QUERY_LOGGER").is_ok() {
log::LevelFilter::Debug
} else {
log::LevelFilter::Off
};
// Only show Rocket underscore `_` logs when the level is Debug or higher
// Else this will bloat the log output with useless messages.
let rocket_underscore_level = if level >= log::LevelFilter::Debug {
@ -342,9 +318,15 @@ fn init_logging() -> Result<log::LevelFilter, Error> {
// Variable level for hickory used by reqwest
("hickory_resolver::name_server::name_server", hickory_level),
("hickory_proto::xfer", hickory_level),
("diesel_logger", diesel_logger_level),
// SMTP
("lettre::transport::smtp", smtp_log_level),
// Set query_logger default to Off, but can be overwritten manually
// You can set LOG_LEVEL=info,vaultwarden::db::query_logger=<LEVEL> to overwrite it.
// This makes it possible to do the following:
// warn = Print slow queries only, 5 seconds or longer
// info = Print slow queries only, 1 second or longer
// debug = Print all queries
("vaultwarden::db::query_logger", log::LevelFilter::Off),
]);
for (path, level) in levels_override.into_iter() {
@ -614,21 +596,25 @@ async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error>
CONFIG.shutdown();
});
#[cfg(unix)]
#[cfg(all(unix, sqlite))]
{
if db::ACTIVE_DB_TYPE.get() != Some(&db::DbConnType::Sqlite) {
debug!("PostgreSQL and MySQL/MariaDB do not support this backup feature, skip adding USR1 signal.");
} else {
tokio::spawn(async move {
let mut signal_user1 = tokio::signal::unix::signal(SignalKind::user_defined1()).unwrap();
loop {
// If we need more signals to act upon, we might want to use select! here.
// With only one item to listen for this is enough.
let _ = signal_user1.recv().await;
match backup_sqlite().await {
match db::backup_sqlite() {
Ok(f) => info!("Backup to '{f}' was successful"),
Err(e) => error!("Backup failed. {e:?}"),
}
}
});
}
}
instance.launch().await?;

15
src/sso.rs

@ -165,12 +165,7 @@ pub fn decode_state(base64_state: String) -> ApiResult<OIDCState> {
// The `nonce` allow to protect against replay attacks
// redirect_uri from: https://github.com/bitwarden/server/blob/main/src/Identity/IdentityServer/ApiClient.cs
pub async fn authorize_url(
state: OIDCState,
client_id: &str,
raw_redirect_uri: &str,
mut conn: DbConn,
) -> ApiResult<Url> {
pub async fn authorize_url(state: OIDCState, client_id: &str, raw_redirect_uri: &str, conn: DbConn) -> ApiResult<Url> {
let redirect_uri = match client_id {
"web" | "browser" => format!("{}/sso-connector.html", CONFIG.domain()),
"desktop" | "mobile" => "bitwarden://sso-callback".to_string(),
@ -185,7 +180,7 @@ pub async fn authorize_url(
};
let (auth_url, nonce) = Client::authorize_url(state, redirect_uri).await?;
nonce.save(&mut conn).await?;
nonce.save(&conn).await?;
Ok(auth_url)
}
@ -235,7 +230,7 @@ pub struct UserInformation {
pub user_name: Option<String>,
}
async fn decode_code_claims(code: &str, conn: &mut DbConn) -> ApiResult<(OIDCCode, OIDCState)> {
async fn decode_code_claims(code: &str, conn: &DbConn) -> ApiResult<(OIDCCode, OIDCState)> {
match auth::decode_jwt::<OIDCCodeClaims>(code, SSO_JWT_ISSUER.to_string()) {
Ok(code_claims) => match code_claims.code {
OIDCCodeWrapper::Ok {
@ -265,7 +260,7 @@ async fn decode_code_claims(code: &str, conn: &mut DbConn) -> ApiResult<(OIDCCod
// - second time we will rely on the `AC_CACHE` since the `code` has already been exchanged.
// The `nonce` will ensure that the user is authorized only once.
// We return only the `UserInformation` to force calling `redeem` to obtain the `refresh_token`.
pub async fn exchange_code(wrapped_code: &str, conn: &mut DbConn) -> ApiResult<UserInformation> {
pub async fn exchange_code(wrapped_code: &str, conn: &DbConn) -> ApiResult<UserInformation> {
use openidconnect::OAuth2TokenResponse;
let (code, state) = decode_code_claims(wrapped_code, conn).await?;
@ -330,7 +325,7 @@ pub async fn exchange_code(wrapped_code: &str, conn: &mut DbConn) -> ApiResult<U
}
// User has passed 2FA flow we can delete `nonce` and clear the cache.
pub async fn redeem(state: &OIDCState, conn: &mut DbConn) -> ApiResult<AuthenticatedUser> {
pub async fn redeem(state: &OIDCState, conn: &DbConn) -> ApiResult<AuthenticatedUser> {
if let Err(err) = SsoNonce::delete(state, conn).await {
error!("Failed to delete database sso_nonce using {state}: {err}")
}

Loading…
Cancel
Save